Files
styx/proxy/policy.go
2025-10-08 20:57:13 +02:00

179 lines
4.7 KiB
Go

package proxy
import (
"errors"
"net"
"net/netip"
"strings"
"git.maze.io/maze/styx/dataset"
"git.maze.io/maze/styx/internal/netutil"
"github.com/open-policy-agent/opa/v1/ast"
"github.com/open-policy-agent/opa/v1/rego"
"github.com/open-policy-agent/opa/v1/types"
)
// PolicyQueryOptions generates the Rego query functions for the provided [Context].
//
// The client is looked up by the host part of ctx.RemoteAddr(). The domain and
// network lists of every group the client belongs to are then loaded from
// storage. The returned options register four boolean built-ins. Each one
// matches its single argument against the lists of one kind:
//
//   - styx.permit_domain / styx.reject_domain: domain lists with Permit set / unset
//   - styx.permit_network / styx.reject_network: network lists with Permit set / unset
//
// If the remote address, the client or a group's lists cannot be resolved, the
// error is logged and nil is returned, so none of the styx.* built-ins are
// registered. A single list whose trie fails to load is logged and skipped.
func PolicyQueryOptions(ctx Context) (options []func(*rego.Rego)) {
	var (
		log     = ctx.Logger()
		storage = ctx.Storage()
	)

	addr, err := netip.ParseAddr(netutil.Host(ctx.RemoteAddr().String()))
	if err != nil {
		log.Err(err).Error("Error resolving remote address")
		return nil
	}

	client, err := storage.ClientByAddr(addr)
	if err != nil {
		log.Err(err).Warn("Error resolving client")
		return nil
	}

	var (
		permitDomains, rejectDomains   []*dataset.DomainTrie
		permitNetworks, rejectNetworks []*dataset.NetworkTrie
	)
	for _, group := range client.Groups {
		lists, err := storage.ListsByGroup(group)
		if err != nil {
			log.Err(err).Warn("Error resolving lists")
			return nil
		}
		for _, list := range lists {
			switch list.Type {
			case dataset.ListTypeDomain:
				trie, err := list.Domains()
				if err != nil {
					// Skip the list rather than appending an unusable trie
					// that the built-in would later call Contains on.
					log.Err(err).Warn("Error resolving domain trie")
					continue
				}
				if list.Permit {
					permitDomains = append(permitDomains, trie)
				} else {
					rejectDomains = append(rejectDomains, trie)
				}
			case dataset.ListTypeNetwork:
				trie, err := list.Networks()
				if err != nil {
					// Same as above: never register a trie that failed to load.
					log.Err(err).Warn("Error resolving network trie")
					continue
				}
				if list.Permit {
					permitNetworks = append(permitNetworks, trie)
				} else {
					rejectNetworks = append(rejectNetworks, trie)
				}
			}
		}
	}

	return []func(*rego.Rego){
		rego.Function1(&rego.Function{
			Name:             "styx.reject_domain",
			Description:      "Check if the domain is to be rejected",
			Decl:             domainFunctionDecl,
			Nondeterministic: true,
			Memoize:          true,
		}, domainFunctionImpl(rejectDomains)),
		rego.Function1(&rego.Function{
			Name:             "styx.permit_domain",
			Description:      "Check if the domain is to be permitted",
			Decl:             domainFunctionDecl,
			Nondeterministic: true,
			Memoize:          true,
		}, domainFunctionImpl(permitDomains)),
		rego.Function1(&rego.Function{
			Name:             "styx.reject_network",
			Description:      "Check if the IP, IP:port, host or host:port is to be rejected",
			Decl:             networkFunctionDecl,
			Nondeterministic: true,
			Memoize:          true,
		}, networkFunctionImpl(rejectNetworks)),
		rego.Function1(&rego.Function{
			Name:             "styx.permit_network",
			Description:      "Check if the IP, IP:port, host or host:port is to be permitted",
			Decl:             networkFunctionDecl,
			Nondeterministic: true,
			Memoize:          true,
		}, networkFunctionImpl(permitNetworks)),
	}
}
// domainFunctionDecl is the Rego type signature used by the
// styx.reject_domain and styx.permit_domain built-ins. It declares a single
// string argument named "domain" and a boolean result.
var domainFunctionDecl = types.NewFunction(
	types.Args(
		types.Named("domain", types.S).Description("Domain to lookup"),
	),
	types.Named("result", types.B).
		Description("`true` if domain matches"),
)
// domainFunctionImpl returns the implementation of a domain-matching Rego
// built-in. The built-in yields true when its string argument is contained in
// any of tries, and false otherwise. A non-string argument is reported as an
// error to the evaluator.
//
// nil entries in tries are skipped, so a trie that was never loaded is not
// dereferenced.
func domainFunctionImpl(tries []*dataset.DomainTrie) rego.Builtin1 {
	return func(_ rego.BuiltinContext, domainTerm *ast.Term) (*ast.Term, error) {
		domain, err := parseStringTerm(domainTerm)
		if err != nil {
			return nil, err
		}
		for _, trie := range tries {
			if trie != nil && trie.Contains(domain) {
				return ast.NewTerm(ast.Boolean(true)), nil
			}
		}
		return ast.NewTerm(ast.Boolean(false)), nil
	}
}
// networkFunctionDecl is the Rego type signature used by the
// styx.reject_network and styx.permit_network built-ins. It declares a single
// string argument named "ip" (an IP, IP:port, host or host:port) and a boolean
// result.
var networkFunctionDecl = types.NewFunction(
	types.Args(
		types.Named("ip", types.S).Description("IP, IP:port, host or host:port to lookup"),
	),
	types.Named("result", types.B).
		Description("`true` if IP matches"),
)
// networkFunctionImpl returns the implementation of a network-matching Rego
// built-in. The argument (an IP, IP:port, host or host:port) is turned into
// addresses by parseAddrTerm. The built-in yields true when any of those
// addresses is contained in any of tries, and false otherwise. Parse and
// resolution errors are returned to the evaluator.
//
// nil entries in tries are skipped, so a trie that was never loaded is not
// dereferenced.
func networkFunctionImpl(tries []*dataset.NetworkTrie) rego.Builtin1 {
	return func(_ rego.BuiltinContext, ipTerm *ast.Term) (*ast.Term, error) {
		ips, err := parseAddrTerm(ipTerm)
		if err != nil {
			return nil, err
		}
		for _, trie := range tries {
			if trie == nil {
				continue
			}
			for _, ip := range ips {
				if trie.Contains(ip) {
					return ast.NewTerm(ast.Boolean(true)), nil
				}
			}
		}
		return ast.NewTerm(ast.Boolean(false)), nil
	}
}
// parseAddrTerm converts a Rego string term holding an IP, IP:port, host or
// host:port into the IP addresses it denotes.
//
// The port, if any, is removed with netutil.Host. A literal IP is returned
// directly; anything else is resolved as a host name with [net.LookupIP].
//
// All addresses are unmapped (::ffff:a.b.c.d becomes a.b.c.d). net.LookupIP
// can hand back IPv4 addresses in 16-byte form, and [netip.AddrFromSlice]
// turns those into IPv4-mapped IPv6 addresses, which do not match IPv4
// networks.
func parseAddrTerm(term *ast.Term) ([]netip.Addr, error) {
	s, err := parseStringTerm(term)
	if err != nil {
		return nil, err
	}

	host := netutil.Host(s)
	if addr, err := netip.ParseAddr(host); err == nil {
		// Input was "ip" or "ip:port".
		return []netip.Addr{addr.Unmap()}, nil
	}

	ips, err := net.LookupIP(host)
	if err != nil {
		return nil, err
	}
	addrs := make([]netip.Addr, 0, len(ips))
	for _, ip := range ips {
		if addr, ok := netip.AddrFromSlice(ip); ok {
			addrs = append(addrs, addr.Unmap())
		}
	}
	return addrs, nil
}
// parseStringTerm returns the string held by a Rego string term, with
// surrounding whitespace removed. It returns an error when term is nil or does
// not hold an ast.String.
//
// The raw value is used instead of ast.String.String(), which returns a quoted
// and escaped literal. Trimming quote characters off that form would leave
// escape sequences in the result and would also drop quotes that belong to
// the value itself. Whitespace is trimmed because the callers here feed the
// result into domain and address lookups, where it is never meaningful.
func parseStringTerm(term *ast.Term) (string, error) {
	if term == nil {
		return "", errors.New("expected string argument")
	}
	value, ok := term.Value.(ast.String)
	if !ok {
		return "", errors.New("expected string argument")
	}
	return strings.TrimSpace(string(value)), nil
}