package proxy

import (
	"context"
	"errors"
	"net"
	"net/netip"
	"strings"

	"git.maze.io/maze/styx/dataset"
	"git.maze.io/maze/styx/internal/netutil"
	"github.com/open-policy-agent/opa/v1/ast"
	"github.com/open-policy-agent/opa/v1/rego"
	"github.com/open-policy-agent/opa/v1/types"
)

// PolicyQueryOptions generates the Rego query functions for the provided [Context].
//
// The client is looked up by the host part of ctx.RemoteAddr(). Every domain and
// network list attached to the client's groups is sorted into a permit or reject
// set (by list.Permit) and exposed to policies through four builtins:
//
//	styx.reject_domain(domain)   styx.permit_domain(domain)
//	styx.reject_network(addr)    styx.permit_network(addr)
//
// A list whose trie cannot be loaded is logged and skipped.
//
// NOTE(review): if the remote address cannot be parsed, the client cannot be
// resolved, or any group's lists cannot be loaded, nil is returned and none of
// the builtins above are registered for the query. Confirm the caller copes
// with policies that reference them in that case.
func PolicyQueryOptions(ctx Context) (options []func(*rego.Rego)) {
	log := ctx.Logger()
	storage := ctx.Storage()

	addr, err := netip.ParseAddr(netutil.Host(ctx.RemoteAddr().String()))
	if err != nil {
		log.Err(err).Error("Error resolving remote address")
		return nil
	}

	client, err := storage.ClientByAddr(addr)
	if err != nil {
		log.Err(err).Warn("Error resolving client")
		return nil
	}

	var (
		permitDomains  []*dataset.DomainTrie
		rejectDomains  []*dataset.DomainTrie
		permitNetworks []*dataset.NetworkTrie
		rejectNetworks []*dataset.NetworkTrie
	)
	for _, group := range client.Groups {
		lists, err := storage.ListsByGroup(group)
		if err != nil {
			log.Err(err).Warn("Error resolving lists")
			return nil
		}

		for _, list := range lists {
			switch list.Type {
			case dataset.ListTypeDomain:
				trie, err := list.Domains()
				if err != nil {
					// The trie returned alongside an error (presumably nil) is
					// not usable; appending it would break every later lookup.
					log.Err(err).Warn("Error resolving domain trie")
					continue
				}
				if list.Permit {
					permitDomains = append(permitDomains, trie)
				} else {
					rejectDomains = append(rejectDomains, trie)
				}

			case dataset.ListTypeNetwork:
				trie, err := list.Networks()
				if err != nil {
					log.Err(err).Warn("Error resolving network trie")
					continue
				}
				if list.Permit {
					permitNetworks = append(permitNetworks, trie)
				} else {
					rejectNetworks = append(rejectNetworks, trie)
				}
			}
		}
	}

	options = append(options,
		rego.Function1(&rego.Function{
			Name:             "styx.reject_domain",
			Description:      "Check if the domain is to be rejected",
			Decl:             domainFunctionDecl,
			Nondeterministic: true,
			Memoize:          true,
		}, domainFunctionImpl(rejectDomains)),
		rego.Function1(&rego.Function{
			Name:             "styx.permit_domain",
			Description:      "Check if the domain is to be permitted",
			Decl:             domainFunctionDecl,
			Nondeterministic: true,
			Memoize:          true,
		}, domainFunctionImpl(permitDomains)),
		rego.Function1(&rego.Function{
			Name:             "styx.reject_network",
			Description:      "Check if the IP, IP:port, host or host:port is to be rejected",
			Decl:             networkFunctionDecl,
			Nondeterministic: true,
			Memoize:          true,
		}, networkFunctionImpl(rejectNetworks)),
		rego.Function1(&rego.Function{
			Name:             "styx.permit_network",
			Description:      "Check if the IP, IP:port, host or host:port is to be permitted",
			Decl:             networkFunctionDecl,
			Nondeterministic: true,
			Memoize:          true,
		}, networkFunctionImpl(permitNetworks)),
	)
	return options
}

// domainFunctionDecl is the Rego signature shared by styx.reject_domain and
// styx.permit_domain: string -> boolean.
var domainFunctionDecl = types.NewFunction(
	types.Args(types.Named("domain", types.S).Description("Domain to lookup")),
	types.Named("result", types.B).Description("`true` if domain matches"),
)

// domainFunctionImpl returns a builtin that reports whether its string argument
// is contained in any of tries.
//
// Before the lookup, the argument is lower-cased and a trailing root dot is
// removed. DNS names are case-insensitive, and "example.com." names the same host
// as "example.com". Without this normalization a client could sidestep a reject
// list simply by respelling the host.
// NOTE(review): assumes the tries hold lower-case names without a trailing dot;
// confirm against how dataset builds them.
func domainFunctionImpl(tries []*dataset.DomainTrie) rego.Builtin1 {
	return func(_ rego.BuiltinContext, domainTerm *ast.Term) (*ast.Term, error) {
		domain, err := parseStringTerm(domainTerm)
		if err != nil {
			return nil, err
		}
		domain = strings.TrimSuffix(strings.ToLower(domain), ".")

		for _, trie := range tries {
			if trie != nil && trie.Contains(domain) {
				return ast.NewTerm(ast.Boolean(true)), nil
			}
		}
		return ast.NewTerm(ast.Boolean(false)), nil
	}
}

// networkFunctionDecl is the Rego signature shared by styx.reject_network and
// styx.permit_network: string -> boolean.
var networkFunctionDecl = types.NewFunction(
	types.Args(types.Named("ip", types.S).Description("IP, IP:port, host or host:port to lookup")),
	types.Named("result", types.B).Description("`true` if IP matches"),
)

// networkFunctionImpl returns a builtin that resolves its argument (IP, IP:port,
// host or host:port) and reports whether any resulting address is contained in
// any of tries. Host name resolution honours the evaluation's context.
func networkFunctionImpl(tries []*dataset.NetworkTrie) rego.Builtin1 {
	return func(bctx rego.BuiltinContext, ipTerm *ast.Term) (*ast.Term, error) {
		ips, err := lookupAddrTerm(bctx.Context, ipTerm)
		if err != nil {
			return nil, err
		}

		for _, trie := range tries {
			if trie == nil {
				continue
			}
			for _, ip := range ips {
				if trie.Contains(ip) {
					return ast.NewTerm(ast.Boolean(true)), nil
				}
			}
		}
		return ast.NewTerm(ast.Boolean(false)), nil
	}
}

// parseAddrTerm is [lookupAddrTerm] with a background context, kept for callers
// that have no evaluation context at hand.
func parseAddrTerm(term *ast.Term) ([]netip.Addr, error) {
	return lookupAddrTerm(context.Background(), term)
}

// lookupAddrTerm converts a Rego string term holding an IP, IP:port, host or
// host:port into the addresses it denotes.
//
// A literal IP is returned as a single address. Any other host is resolved with
// [net.DefaultResolver] under ctx, so a cancelled or timed-out policy evaluation
// also aborts the DNS query. A nil ctx falls back to [context.Background].
//
// All returned addresses are unmapped: netip treats ::ffff:a.b.c.d and a.b.c.d
// as different addresses (only the latter is Is4), so an IPv4-mapped result
// would otherwise fail to match IPv4 prefixes.
func lookupAddrTerm(ctx context.Context, term *ast.Term) ([]netip.Addr, error) {
	s, err := parseStringTerm(term)
	if err != nil {
		return nil, err
	}

	host := netutil.Host(s)
	if addr, err := netip.ParseAddr(host); err == nil {
		// Input was "ip" or "ip:port".
		return []netip.Addr{addr.Unmap()}, nil
	}

	if ctx == nil {
		ctx = context.Background()
	}
	addrs, err := net.DefaultResolver.LookupNetIP(ctx, "ip", host)
	if err != nil {
		return nil, err
	}
	for i := range addrs {
		addrs[i] = addrs[i].Unmap()
	}
	return addrs, nil
}

// parseStringTerm returns the Go string held by a Rego string term, or an error
// if term is not a string.
func parseStringTerm(term *ast.Term) (string, error) {
	value, ok := term.Value.(ast.String)
	if !ok {
		return "", errors.New("expected string argument")
	}
	// ast.String is a named string type, so a conversion yields the raw value.
	// value.String() returns the quoted, escaped Rego literal instead. Trimming
	// quotes off that would leave escapes such as \" or \\ behind and also strip
	// quotes that are part of the value.
	return string(value), nil
}