179 lines
4.7 KiB
Go
179 lines
4.7 KiB
Go
package proxy
|
|
|
|
import (
|
|
"errors"
|
|
"net"
|
|
"net/netip"
|
|
"strings"
|
|
|
|
"git.maze.io/maze/styx/dataset"
|
|
"git.maze.io/maze/styx/internal/netutil"
|
|
"github.com/open-policy-agent/opa/v1/ast"
|
|
"github.com/open-policy-agent/opa/v1/rego"
|
|
"github.com/open-policy-agent/opa/v1/types"
|
|
)
|
|
|
|
// PolicyQueryOptions generates the Rego query functions for the provided [Context].
|
|
func PolicyQueryOptions(ctx Context) (options []func(*rego.Rego)) {
|
|
var (
|
|
log = ctx.Logger()
|
|
storage = ctx.Storage()
|
|
)
|
|
|
|
addr, err := netip.ParseAddr(netutil.Host(ctx.RemoteAddr().String()))
|
|
if err != nil {
|
|
log.Err(err).Error("Error resolving remote address")
|
|
return
|
|
}
|
|
|
|
client, err := storage.ClientByAddr(addr)
|
|
if err != nil {
|
|
log.Err(err).Warn("Error resolving client")
|
|
return
|
|
}
|
|
|
|
var (
|
|
permitDomains []*dataset.DomainTrie
|
|
rejectDomains []*dataset.DomainTrie
|
|
permitNetworks []*dataset.NetworkTrie
|
|
rejectNetworks []*dataset.NetworkTrie
|
|
)
|
|
for _, group := range client.Groups {
|
|
lists, err := storage.ListsByGroup(group)
|
|
if err != nil {
|
|
log.Err(err).Warn("Error resolving lists")
|
|
return
|
|
}
|
|
for _, list := range lists {
|
|
switch list.Type {
|
|
case dataset.ListTypeDomain:
|
|
trie, err := list.Domains()
|
|
if err != nil {
|
|
log.Err(err).Warn("Error resolving domain trie")
|
|
}
|
|
if list.Permit {
|
|
permitDomains = append(permitDomains, trie)
|
|
} else {
|
|
rejectDomains = append(rejectDomains, trie)
|
|
}
|
|
case dataset.ListTypeNetwork:
|
|
trie, err := list.Networks()
|
|
if err != nil {
|
|
log.Err(err).Warn("Error resolving domain trie")
|
|
}
|
|
if list.Permit {
|
|
permitNetworks = append(permitNetworks, trie)
|
|
} else {
|
|
rejectNetworks = append(rejectNetworks, trie)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
options = append(options,
|
|
rego.Function1(®o.Function{
|
|
Name: "styx.reject_domain",
|
|
Description: "Check if the domain is to be rejected",
|
|
Decl: domainFunctionDecl,
|
|
Nondeterministic: true,
|
|
Memoize: true,
|
|
}, domainFunctionImpl(rejectDomains)),
|
|
rego.Function1(®o.Function{
|
|
Name: "styx.permit_domain",
|
|
Description: "Check if the domain is to be permitted",
|
|
Decl: domainFunctionDecl,
|
|
Nondeterministic: true,
|
|
Memoize: true,
|
|
}, domainFunctionImpl(permitDomains)),
|
|
rego.Function1(®o.Function{
|
|
Name: "styx.reject_network",
|
|
Description: "Check if the IP, IP:port, host or host:port is to be rejected",
|
|
Decl: networkFunctionDecl,
|
|
Nondeterministic: true,
|
|
Memoize: true,
|
|
}, networkFunctionImpl(rejectNetworks)),
|
|
rego.Function1(®o.Function{
|
|
Name: "styx.permit_network",
|
|
Description: "Check if the IP, IP:port, host or host:port is to be permitted",
|
|
Decl: networkFunctionDecl,
|
|
Nondeterministic: true,
|
|
Memoize: true,
|
|
}, networkFunctionImpl(permitNetworks)),
|
|
)
|
|
return
|
|
}
|
|
|
|
var domainFunctionDecl = types.NewFunction(
|
|
types.Args(types.Named("domain", types.S).Description("Domain to lookup")),
|
|
types.Named("result", types.B).Description("`true` if domain matches"),
|
|
)
|
|
|
|
func domainFunctionImpl(tries []*dataset.DomainTrie) rego.Builtin1 {
|
|
return func(ctx rego.BuiltinContext, domainTerm *ast.Term) (*ast.Term, error) {
|
|
domain, err := parseStringTerm(domainTerm)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
for _, trie := range tries {
|
|
if trie.Contains(domain) {
|
|
return ast.NewTerm(ast.Boolean(true)), nil
|
|
}
|
|
}
|
|
return ast.NewTerm(ast.Boolean(false)), nil
|
|
}
|
|
}
|
|
|
|
var networkFunctionDecl = types.NewFunction(
|
|
types.Args(types.Named("ip", types.S).Description("IP, IP:port, host or host:port to lookup")),
|
|
types.Named("result", types.B).Description("`true` if IP matches"),
|
|
)
|
|
|
|
func networkFunctionImpl(tries []*dataset.NetworkTrie) rego.Builtin1 {
|
|
return func(ctx rego.BuiltinContext, ipTerm *ast.Term) (*ast.Term, error) {
|
|
ips, err := parseAddrTerm(ipTerm)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
for _, trie := range tries {
|
|
for _, ip := range ips {
|
|
if trie.Contains(ip) {
|
|
return ast.NewTerm(ast.Boolean(true)), nil
|
|
}
|
|
}
|
|
}
|
|
return ast.NewTerm(ast.Boolean(false)), nil
|
|
}
|
|
}
|
|
|
|
func parseAddrTerm(term *ast.Term) (addrs []netip.Addr, err error) {
|
|
s, err := parseStringTerm(term)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if addr, err := netip.ParseAddr(netutil.Host(s)); err == nil {
|
|
// Input was "ip" or "ip:port"
|
|
return []netip.Addr{addr}, nil
|
|
}
|
|
|
|
ips, err := net.LookupIP(netutil.Host(s))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
for _, ip := range ips {
|
|
if addr, ok := netip.AddrFromSlice(ip); ok {
|
|
addrs = append(addrs, addr)
|
|
}
|
|
}
|
|
return
|
|
}
|
|
|
|
func parseStringTerm(term *ast.Term) (string, error) {
|
|
value, ok := term.Value.(ast.String)
|
|
if !ok {
|
|
return "", errors.New("expected string argument")
|
|
}
|
|
return strings.Trim(value.String(), `"`), nil
|
|
}
|