Better trie implementations
This commit is contained in:
178
proxy/policy.go
Normal file
178
proxy/policy.go
Normal file
@@ -0,0 +1,178 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
|
||||
"git.maze.io/maze/styx/dataset"
|
||||
"git.maze.io/maze/styx/internal/netutil"
|
||||
"github.com/open-policy-agent/opa/v1/ast"
|
||||
"github.com/open-policy-agent/opa/v1/rego"
|
||||
"github.com/open-policy-agent/opa/v1/types"
|
||||
)
|
||||
|
||||
// PolicyQueryOptions generates the Rego query functions for the provided [Context].
|
||||
func PolicyQueryOptions(ctx Context) (options []func(*rego.Rego)) {
|
||||
var (
|
||||
log = ctx.Logger()
|
||||
storage = ctx.Storage()
|
||||
)
|
||||
|
||||
addr, err := netip.ParseAddr(netutil.Host(ctx.RemoteAddr().String()))
|
||||
if err != nil {
|
||||
log.Err(err).Error("Error resolving remote address")
|
||||
return
|
||||
}
|
||||
|
||||
client, err := storage.ClientByAddr(addr)
|
||||
if err != nil {
|
||||
log.Err(err).Warn("Error resolving client")
|
||||
return
|
||||
}
|
||||
|
||||
var (
|
||||
permitDomains []*dataset.DomainTrie
|
||||
rejectDomains []*dataset.DomainTrie
|
||||
permitNetworks []*dataset.NetworkTrie
|
||||
rejectNetworks []*dataset.NetworkTrie
|
||||
)
|
||||
for _, group := range client.Groups {
|
||||
lists, err := storage.ListsByGroup(group)
|
||||
if err != nil {
|
||||
log.Err(err).Warn("Error resolving lists")
|
||||
return
|
||||
}
|
||||
for _, list := range lists {
|
||||
switch list.Type {
|
||||
case dataset.ListTypeDomain:
|
||||
trie, err := list.Domains()
|
||||
if err != nil {
|
||||
log.Err(err).Warn("Error resolving domain trie")
|
||||
}
|
||||
if list.Permit {
|
||||
permitDomains = append(permitDomains, trie)
|
||||
} else {
|
||||
rejectDomains = append(rejectDomains, trie)
|
||||
}
|
||||
case dataset.ListTypeNetwork:
|
||||
trie, err := list.Networks()
|
||||
if err != nil {
|
||||
log.Err(err).Warn("Error resolving domain trie")
|
||||
}
|
||||
if list.Permit {
|
||||
permitNetworks = append(permitNetworks, trie)
|
||||
} else {
|
||||
rejectNetworks = append(rejectNetworks, trie)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
options = append(options,
|
||||
rego.Function1(®o.Function{
|
||||
Name: "styx.reject_domain",
|
||||
Description: "Check if the domain is to be rejected",
|
||||
Decl: domainFunctionDecl,
|
||||
Nondeterministic: true,
|
||||
Memoize: true,
|
||||
}, domainFunctionImpl(rejectDomains)),
|
||||
rego.Function1(®o.Function{
|
||||
Name: "styx.permit_domain",
|
||||
Description: "Check if the domain is to be permitted",
|
||||
Decl: domainFunctionDecl,
|
||||
Nondeterministic: true,
|
||||
Memoize: true,
|
||||
}, domainFunctionImpl(permitDomains)),
|
||||
rego.Function1(®o.Function{
|
||||
Name: "styx.reject_network",
|
||||
Description: "Check if the IP, IP:port, host or host:port is to be rejected",
|
||||
Decl: networkFunctionDecl,
|
||||
Nondeterministic: true,
|
||||
Memoize: true,
|
||||
}, networkFunctionImpl(rejectNetworks)),
|
||||
rego.Function1(®o.Function{
|
||||
Name: "styx.permit_network",
|
||||
Description: "Check if the IP, IP:port, host or host:port is to be permitted",
|
||||
Decl: networkFunctionDecl,
|
||||
Nondeterministic: true,
|
||||
Memoize: true,
|
||||
}, networkFunctionImpl(permitNetworks)),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
var domainFunctionDecl = types.NewFunction(
|
||||
types.Args(types.Named("domain", types.S).Description("Domain to lookup")),
|
||||
types.Named("result", types.B).Description("`true` if domain matches"),
|
||||
)
|
||||
|
||||
func domainFunctionImpl(tries []*dataset.DomainTrie) rego.Builtin1 {
|
||||
return func(ctx rego.BuiltinContext, domainTerm *ast.Term) (*ast.Term, error) {
|
||||
domain, err := parseStringTerm(domainTerm)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, trie := range tries {
|
||||
if trie.Contains(domain) {
|
||||
return ast.NewTerm(ast.Boolean(true)), nil
|
||||
}
|
||||
}
|
||||
return ast.NewTerm(ast.Boolean(false)), nil
|
||||
}
|
||||
}
|
||||
|
||||
var networkFunctionDecl = types.NewFunction(
|
||||
types.Args(types.Named("ip", types.S).Description("IP, IP:port, host or host:port to lookup")),
|
||||
types.Named("result", types.B).Description("`true` if IP matches"),
|
||||
)
|
||||
|
||||
func networkFunctionImpl(tries []*dataset.NetworkTrie) rego.Builtin1 {
|
||||
return func(ctx rego.BuiltinContext, ipTerm *ast.Term) (*ast.Term, error) {
|
||||
ips, err := parseAddrTerm(ipTerm)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, trie := range tries {
|
||||
for _, ip := range ips {
|
||||
if trie.Contains(ip) {
|
||||
return ast.NewTerm(ast.Boolean(true)), nil
|
||||
}
|
||||
}
|
||||
}
|
||||
return ast.NewTerm(ast.Boolean(false)), nil
|
||||
}
|
||||
}
|
||||
|
||||
func parseAddrTerm(term *ast.Term) (addrs []netip.Addr, err error) {
|
||||
s, err := parseStringTerm(term)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if addr, err := netip.ParseAddr(netutil.Host(s)); err == nil {
|
||||
// Input was "ip" or "ip:port"
|
||||
return []netip.Addr{addr}, nil
|
||||
}
|
||||
|
||||
ips, err := net.LookupIP(netutil.Host(s))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, ip := range ips {
|
||||
if addr, ok := netip.AddrFromSlice(ip); ok {
|
||||
addrs = append(addrs, addr)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func parseStringTerm(term *ast.Term) (string, error) {
|
||||
value, ok := term.Value.(ast.String)
|
||||
if !ok {
|
||||
return "", errors.New("expected string argument")
|
||||
}
|
||||
return strings.Trim(value.String(), `"`), nil
|
||||
}
|
Reference in New Issue
Block a user