Files
styx/proxy/policy/rule.go
2025-09-26 08:49:53 +02:00

369 lines
8.7 KiB
Go

package policy
import (
"fmt"
"net"
"net/http"
"strings"
"time"
"git.maze.io/maze/styx/internal/netutil"
"git.maze.io/maze/styx/proxy/match"
"github.com/google/uuid"
"github.com/hashicorp/hcl/v2"
"github.com/hashicorp/hcl/v2/gohcl"
)
// Rule is a policy rule.
//
// Configure prepares a decoded rule for use. rawRule.Configure calls it right
// after decoding the rule's HCL body. The implementations in this file
// validate their settings and look up the matchers or nested rules they refer
// to, returning an error if the rule cannot be used.
type Rule interface {
	Configure(match.Matchers) error
}
// InterceptRule can make policy rule decisions on intercept requests.
//
// PermitIntercept returns nil when the rule has no opinion about r. Otherwise
// it returns the permit decision; interceptRule, for example, returns its
// configured Permit value when it matches r.
type InterceptRule interface {
	PermitIntercept(r *http.Request) *bool
}
// RequestRule can make policy rule decisions on HTTP CONNECT requests.
//
// PermitRequest returns nil when the rule has no opinion about r. Otherwise
// it returns the permit decision; requestRule, for example, returns its
// configured Permit value when it matches r.
type RequestRule interface {
	PermitRequest(r *http.Request) *bool
}
// rawRule is an undecoded rule block, e.g. one of the "on" blocks of an
// allRule. The block label (Type) selects the concrete rule type, and
// Configure decodes the remaining Body into it.
type rawRule struct {
	// Type is the block label: "intercept", "request", "days", "time" or "all".
	Type string `hcl:"type,label" json:"type"`
	// Body holds the not-yet-decoded remainder of the block.
	Body hcl.Body `hcl:",remain" json:"-"`
	// Rule is the concrete rule created by Configure. Only Configure is part
	// of the Rule interface; callers presumably type-assert it to
	// InterceptRule/RequestRule (not visible here — confirm).
	Rule `json:"rule"`
}
// Configure instantiates the concrete rule selected by the block label
// (r.Type), decodes the remaining HCL body into it and then configures it
// against matchers.
//
// It returns an error for an unknown rule type. HCL decoding errors are
// returned as hcl.Diagnostics. Otherwise it returns whatever the concrete
// rule's Configure returns.
func (r *rawRule) Configure(matchers match.Matchers) error {
	switch r.Type {
	case "intercept":
		r.Rule = new(interceptRule)
	case "request":
		r.Rule = new(requestRule)
	case "days":
		r.Rule = new(daysRule)
	case "time":
		r.Rule = new(timeRule)
	case "all":
		r.Rule = new(allRule)
	default:
		return fmt.Errorf("policy: invalid event type %q", r.Type)
	}
	// hcl.Diagnostics implements error; return it so a malformed rule body is
	// reported to the caller instead of being accepted half-decoded.
	if diags := gohcl.DecodeBody(r.Body, nil, r.Rule); diags.HasErrors() {
		return diags
	}
	return r.Rule.Configure(matchers)
}
// allRule groups nested rules given as "on" blocks.
//
// NOTE(review): no code in this file reads Permit, and allRule implements
// neither InterceptRule nor RequestRule; confirm how (or whether) it is
// evaluated.
type allRule struct {
	// Rules are the nested rule blocks, each decoded further by its own
	// rawRule.Configure.
	Rules []*rawRule `hcl:"on,block"`
	// Permit is a required HCL attribute (its tag has no ",optional").
	Permit *bool `hcl:"permit"`
}
// Configure configures every nested rule in order and stops at the first
// failure. Without this, each nested rawRule keeps a nil Rule because nothing
// else decodes its body.
//
// NOTE(review): assumes no other code (e.g. Policy.Configure) already
// configures allRule's nested rules; configuring them twice would duplicate
// their matchers.
func (r *allRule) Configure(matchers match.Matchers) error {
	for i, rule := range r.Rules {
		if rule == nil {
			continue
		}
		if err := rule.Configure(matchers); err != nil {
			return fmt.Errorf("all: rule %d: %w", i, err)
		}
	}
	return nil
}
// domainOrNetworkRule holds the resolved matchers shared by interceptRule and
// requestRule.
//
// matchers and isSource are parallel slices. isSource[i] tells matchesRequest
// whether an IP matcher matchers[i] is tested against the client address
// (true) or against the requested target (false).
type domainOrNetworkRule struct {
	matchers []match.Matcher
	isSource []bool
}
// configure looks up the named matchers for a rule of the given kind (used
// only as an error-message prefix, e.g. "intercept") and appends them to r.
//
// domains are looked up as "domain" matchers. sources and targets are looked
// up as "network" matchers; sources are tested against the client address and
// targets against the requested host. At least one name must be given in
// total. The lookup error from matchers is wrapped in the returned error.
//
// If id is non-nil and empty, a freshly generated UUID is stored in it. A
// non-empty ID is kept, matching timeRule and daysRule. v is currently unused.
func (r *domainOrNetworkRule) configure(kind string, matchers match.Matchers, domains, sources, targets []string, v any, id *string) error {
	// add resolves names as matchers of matcherType and records them with the
	// given source/target role. what is used in error messages.
	add := func(matcherType, what string, names []string, isSource bool) error {
		for _, name := range names {
			m, err := matchers.Get(matcherType, name)
			if err != nil {
				return fmt.Errorf("%s: unknown %s %q: %w", kind, what, name, err)
			}
			r.matchers = append(r.matchers, m)
			r.isSource = append(r.isSource, isSource)
		}
		return nil
	}

	if err := add("domain", "domain", domains, false); err != nil {
		return err
	}
	if err := add("network", "source network", sources, true); err != nil {
		return err
	}
	if err := add("network", "target network", targets, false); err != nil {
		return err
	}
	if len(r.matchers) == 0 {
		return fmt.Errorf("%s: missing any of domain, source, target", kind)
	}
	if id != nil && *id == "" {
		*id = uuid.NewString()
	}
	return nil
}
// matchesRequest reports whether any of the rule's matchers matches q.
//
// A matcher implementing match.Request is given the request itself. A matcher
// implementing match.IP is tested in one of two ways:
//   - if it was configured as a source, against the client IP parsed from
//     q.RemoteAddr;
//   - otherwise, against the IP(s) of the requested host.
//
// A target host name is resolved with net.LookupIP at most once per call, and
// only when a target IP matcher is actually reached.
func (r *domainOrNetworkRule) matchesRequest(q *http.Request) bool {
	clientIP := net.ParseIP(netutil.Host(q.RemoteAddr))

	var (
		targetIPs []net.IP
		resolved  bool
	)
	targets := func() []net.IP {
		if resolved {
			return targetIPs
		}
		resolved = true
		var hostport string
		if q.URL != nil {
			hostport = q.URL.Host
		}
		if hostport == "" {
			// For origin-form request lines ("GET /path") net/http leaves
			// URL.Host empty; the target is then only known from the Host header.
			hostport = q.Host
		}
		host := netutil.Host(hostport)
		if ip := net.ParseIP(host); ip != nil {
			targetIPs = []net.IP{ip}
		} else if host != "" {
			// Best effort: a name that fails to resolve matches no IP matcher.
			targetIPs, _ = net.LookupIP(host)
		}
		return targetIPs
	}

	for i, m := range r.matchers {
		if rm, ok := m.(match.Request); ok && rm.MatchesRequest(q) {
			return true
		}
		im, ok := m.(match.IP)
		if !ok {
			continue
		}
		if r.isSource[i] {
			if im.MatchesIP(clientIP) {
				return true
			}
			continue
		}
		for _, ip := range targets() {
			if im.MatchesIP(ip) {
				return true
			}
		}
	}
	return false
}
// interceptRule decides intercept requests based on the request's domain, its
// source network (client address) and/or its target network. Configure
// requires at least one of Domain, Source or Target to be set.
type interceptRule struct {
	// ID has no hcl tag, so it is not read from HCL; Configure fills it with a
	// generated UUID when it is empty.
	ID string `json:"id,omitempty"`
	// Domain lists names of "domain" matchers.
	Domain []string `hcl:"domain,optional" json:"domain,omitempty"`
	// Source lists names of "network" matchers tested against the client IP.
	Source []string `hcl:"source,optional" json:"source,omitempty"`
	// Target lists names of "network" matchers tested against the target IP.
	Target []string `hcl:"target,optional" json:"target,omitempty"`
	// Permit is returned by PermitIntercept when the rule matches; it is a
	// required HCL attribute (no ",optional").
	Permit              *bool `hcl:"permit" json:"permit"`
	domainOrNetworkRule `json:"-"`
}
// Configure resolves the rule's domain, source and target matchers via the
// embedded domainOrNetworkRule, which also takes care of the rule's ID.
func (r *interceptRule) Configure(matchers match.Matchers) error {
	const kind = "intercept"
	return r.domainOrNetworkRule.configure(kind, matchers, r.Domain, r.Source, r.Target, r, &r.ID)
}
// PermitIntercept returns the rule's Permit value when q matches the rule,
// and nil (no opinion) otherwise.
func (r *interceptRule) PermitIntercept(q *http.Request) *bool {
	if !r.matchesRequest(q) {
		return nil
	}
	return r.Permit
}
// requestRule decides requests based on the request's domain, its source
// network (client address) and/or its target network. Configure requires at
// least one of Domain, Source or Target to be set.
type requestRule struct {
	// ID has no hcl tag, so it is not read from HCL; Configure fills it with a
	// generated UUID when it is empty.
	ID string `json:"id,omitempty"`
	// Domain lists names of "domain" matchers.
	Domain []string `hcl:"domain,optional" json:"domain,omitempty"`
	// Source lists names of "network" matchers tested against the client IP.
	Source []string `hcl:"source,optional" json:"source,omitempty"`
	// Target lists names of "network" matchers tested against the target IP.
	Target []string `hcl:"target,optional" json:"target,omitempty"`
	// Permit is returned by PermitRequest when the rule matches; it is a
	// required HCL attribute (no ",optional").
	Permit              *bool `hcl:"permit" json:"permit"`
	domainOrNetworkRule `json:"-"`
}
// Configure resolves the rule's domain, source and target matchers via the
// embedded domainOrNetworkRule, which also takes care of the rule's ID.
func (r *requestRule) Configure(matchers match.Matchers) error {
	const kind = "request"
	return r.domainOrNetworkRule.configure(kind, matchers, r.Domain, r.Source, r.Target, r, &r.ID)
}
// PermitRequest returns the rule's Permit value when q matches the rule, and
// nil (no opinion) otherwise.
func (r *requestRule) PermitRequest(q *http.Request) *bool {
	if !r.matchesRequest(q) {
		return nil
	}
	return r.Permit
}
// timeRule applies a nested policy only during a time-of-day window.
type timeRule struct {
	// ID has no hcl tag; Configure fills it with a generated UUID when empty.
	ID string `json:"id,omitempty"`
	// Time must hold exactly two entries, [start, end], parsed with ParseTime.
	Time []string `hcl:"time" json:"time"`
	// Permit is a required HCL attribute but is not read by the methods in
	// this file.
	Permit *bool `hcl:"permit" json:"permit"`
	// Body is the rest of the block; Configure decodes it into Rules.
	Body hcl.Body `hcl:",remain" json:"-"`
	// Rules is the nested policy consulted while the window is active.
	Rules *Policy `json:"rules"`
	// Start and End are the parsed window bounds.
	Start Time `json:"start"`
	End   Time `json:"end"`
}
// isActive reports whether the current time, as returned by Now, falls inside
// the rule's window. A nil rule is always active.
//
// The window is half-open, [Start, End): it includes the start instant and
// excludes the end. When Start is after End the window wraps around midnight
// (e.g. 18:00-06:00) and is active from Start onwards or before End.
func (r *timeRule) isActive() bool {
	if r == nil {
		return true
	}
	now := Now()
	// "not before" means "at or after", so the window is open at Start itself.
	afterStart := !now.Before(r.Start)
	beforeEnd := now.Before(r.End)
	if r.Start.After(r.End) { // ie: 18:00-06:00
		return afterStart || beforeEnd
	}
	return afterStart && beforeEnd
}
// Configure parses the [start, end] window and decodes and configures the
// nested policy. HCL decoding errors are returned as hcl.Diagnostics. The
// nested policy's Matchers are cleared once it has been configured, and the
// rule receives a generated ID if it has none.
func (r *timeRule) Configure(matchers match.Matchers) error {
	if n := len(r.Time); n != 2 {
		return fmt.Errorf("invalid time %s, need [start, stop]", r.Time)
	}

	var err error
	r.Start, err = ParseTime(r.Time[0])
	if err != nil {
		return fmt.Errorf("invalid start %q: %w", r.Time[0], err)
	}
	r.End, err = ParseTime(r.Time[1])
	if err != nil {
		return fmt.Errorf("invalid end %q: %w", r.Time[1], err)
	}

	nested := new(Policy)
	r.Rules = nested
	diags := gohcl.DecodeBody(r.Body, nil, nested)
	if diags.HasErrors() {
		return diags
	}
	if err = nested.Configure(matchers); err != nil {
		return err
	}
	nested.Matchers = nil

	if r.ID == "" {
		r.ID = uuid.NewString()
	}
	return nil
}
// PermitIntercept defers to the nested policy while the time window is
// active, and abstains (returns nil) outside of it.
func (r *timeRule) PermitIntercept(q *http.Request) *bool {
	if active := r.isActive(); active {
		return r.Rules.PermitIntercept(q)
	}
	return nil
}
// PermitRequest defers to the nested policy while the time window is active,
// and abstains (returns nil) outside of it.
func (r *timeRule) PermitRequest(q *http.Request) *bool {
	if active := r.isActive(); active {
		return r.Rules.PermitRequest(q)
	}
	return nil
}
// daysRule applies a nested policy only on selected days of the week.
type daysRule struct {
	// ID has no hcl tag; Configure fills it with a generated UUID when empty.
	ID string `json:"id,omitempty"`
	// Days is a comma-separated day specification parsed by parseOnCond.
	Days string `hcl:"days" json:"days"`
	// Permit is a required HCL attribute but is not read by the methods in
	// this file.
	Permit *bool `hcl:"permit" json:"permit"`
	// Body is the rest of the block; Configure decodes it into Rules.
	Body hcl.Body `hcl:",remain" json:"-"`
	// Rules is the nested policy consulted on matching days.
	Rules *Policy `json:"rules"`
	// cond holds the parsed Days; any one matching condition activates the rule.
	cond []onCond
}
// isActive reports whether the current time (time.Now) satisfies at least one
// of the rule's day conditions. A nil rule, or one without conditions, is
// always active.
func (r *daysRule) isActive() bool {
	if r == nil || len(r.cond) == 0 {
		return true
	}
	today := time.Now()
	active := false
	for i := 0; i < len(r.cond) && !active; i++ {
		active = r.cond[i](today)
	}
	return active
}
// Configure parses the day specification and decodes and configures the
// nested policy. HCL decoding errors are returned as hcl.Diagnostics. The
// nested policy's Matchers are cleared once it has been configured, and the
// rule receives a generated ID if it has none.
func (r *daysRule) Configure(matchers match.Matchers) error {
	var err error
	r.cond, err = parseOnCond(r.Days)
	if err != nil {
		return err
	}

	nested := new(Policy)
	r.Rules = nested
	if diags := gohcl.DecodeBody(r.Body, nil, nested); diags.HasErrors() {
		return diags
	}
	if err = nested.Configure(matchers); err != nil {
		return err
	}
	nested.Matchers = nil

	if r.ID == "" {
		r.ID = uuid.NewString()
	}
	return nil
}
// PermitIntercept defers to the nested policy on matching days, and abstains
// (returns nil) on all other days.
func (r *daysRule) PermitIntercept(q *http.Request) *bool {
	if active := r.isActive(); active {
		return r.Rules.PermitIntercept(q)
	}
	return nil
}
// PermitRequest defers to the nested policy on matching days, and abstains
// (returns nil) on all other days.
func (r *daysRule) PermitRequest(q *http.Request) *bool {
	if active := r.isActive(); active {
		return r.Rules.PermitRequest(q)
	}
	return nil
}
// onCond reports whether a rule is active at the given instant.
type onCond func(time.Time) bool

// weekdays maps the lower-case, three-letter day abbreviations accepted by
// parseOnCond to their time.Weekday.
var weekdays = map[string]time.Weekday{
	"sun": time.Sunday,
	"mon": time.Monday,
	"tue": time.Tuesday,
	"wed": time.Wednesday,
	"thu": time.Thursday,
	"fri": time.Friday,
	"sat": time.Saturday,
}

// parseOnCond parses a comma-separated list of day specifications into
// conditions. A time satisfies the list if it satisfies any one condition.
//
// Each element is trimmed and lower-cased, and must be one of:
//   - a day abbreviation: "sun", "mon", ..., "sat";
//   - "weekend" or "weekends": Saturday and Sunday;
//   - "workday" or "workdays": Monday through Friday;
//   - an inclusive day range "from-upto", e.g. "mon-fri". Ranges may wrap
//     around the end of the week: "fri-mon" is Friday, Saturday, Sunday and
//     Monday. A range whose ends are the same day selects only that day.
//
// Any other element, including an empty one, is an error.
func parseOnCond(when string) ([]onCond, error) {
	var conds []onCond
	for _, spec := range strings.Split(when, ",") {
		spec = strings.ToLower(strings.TrimSpace(spec))
		if d, ok := weekdays[spec]; ok {
			conds = append(conds, onWeekday(d))
			continue
		}
		switch spec {
		case "weekend", "weekends":
			conds = append(conds, onWeekend)
		case "workday", "workdays":
			conds = append(conds, onWorkday)
		default:
			first, last, isRange := strings.Cut(spec, "-")
			if !isRange {
				return nil, fmt.Errorf("on %q: invalid condition", spec)
			}
			from, ok := weekdays[first]
			if !ok {
				return nil, fmt.Errorf("on %q: invalid weekday %q", spec, first)
			}
			upto, ok := weekdays[last]
			if !ok {
				return nil, fmt.Errorf("on %q: invalid weekday %q", spec, last)
			}
			// Walk forward from `from`, wrapping Saturday -> Sunday, until `upto`
			// has been included. This covers plain, wrapping and single-day ranges.
			for d := from; ; d = (d + 1) % 7 {
				conds = append(conds, onWeekday(d))
				if d == upto {
					break
				}
			}
		}
	}
	return conds, nil
}

// onWeekday returns a condition that holds on the given day of the week.
func onWeekday(weekday time.Weekday) onCond {
	return func(t time.Time) bool {
		return t.Weekday() == weekday
	}
}

// onWeekend holds on Saturdays and Sundays.
func onWeekend(t time.Time) bool {
	d := t.Weekday()
	return d == time.Saturday || d == time.Sunday
}

// onWorkday holds Monday through Friday.
func onWorkday(t time.Time) bool {
	return !onWeekend(t)
}