Files
styx/policy/func.go
2025-10-01 15:37:55 +02:00

180 lines
4.4 KiB
Go

package policy
import (
"errors"
"fmt"
"net"
"os"
"strconv"
"strings"
"github.com/open-policy-agent/opa/v1/ast"
"github.com/open-policy-agent/opa/v1/rego"
"github.com/open-policy-agent/opa/v1/types"
"git.maze.io/maze/styx/dataset"
"git.maze.io/maze/styx/logger"
)
// domainContainsDecl is the Rego type declaration for a builtin that takes two
// string arguments, `list` (the name of a domain list) and `name` (a host
// name), and yields a boolean `result`.
var domainContainsDecl = func() *types.Function {
	list := types.Named("list", types.S).Description("Domain list to check against")
	host := types.Named("name", types.S).Description("Host name to check")
	result := types.Named("result", types.B).Description("`true` if `name` is contained within `list`")
	return types.NewFunction(types.Args(list, host), result)
}()
// domainContainsImpl implements the "styx.in_domains" builtin. It reports
// whether the host name in nameTerm is contained in the domain list named by
// listTerm.
//
// Parameters:
//   - bc: builtin context supplied by the Rego evaluator (unused).
//   - listTerm: string term naming a list in dataset.Domains.
//   - nameTerm: string term holding the host name to test.
//
// It returns a boolean term. Argument errors (a non-string term or an unknown
// list) are logged at debug level and returned to the evaluator.
func domainContainsImpl(bc rego.BuiltinContext, listTerm, nameTerm *ast.Term) (*ast.Term, error) {
	log := logger.StandardLog.Value("func", "styx.in_domains")
	list, err := parseDomainListTerm(listTerm)
	if err != nil {
		log.Err(err).Debug("Call function failed")
		return nil, err
	}
	name, err := parseStringTerm(nameTerm)
	if err != nil {
		// Log this failure the same way as the list failure above (and as the
		// network builtin does) so a bad host argument is not silent.
		log.Value("list", listTerm.Value).Err(err).Debug("Call function failed")
		return nil, err
	}
	log.Values(logger.Values{
		"list": listTerm.Value,
		"name": name,
	}).Trace("Calling function")
	return ast.BooleanTerm(list.Contains(name)), nil
}
// networkContainsDecl is the Rego type declaration for a builtin that takes
// two string arguments, `list` (the name of a network list) and `ip` (an IP
// address), and yields a boolean `result`.
var networkContainsDecl = func() *types.Function {
	list := types.Named("list", types.S).Description("Network list to check against")
	addr := types.Named("ip", types.S).Description("IP address to check")
	result := types.Named("result", types.B).Description("`true` if `ip` is contained within `list`")
	return types.NewFunction(types.Args(list, addr), result)
}()
// networkContainsImpl implements the "styx.in_networks" builtin. It reports
// whether the IP address in ipTerm falls within the network list named by
// listTerm.
//
// Parameters:
//   - bc: builtin context supplied by the Rego evaluator (unused).
//   - listTerm: string term naming a list in dataset.Networks.
//   - ipTerm: string term holding a textual IP address.
//
// It returns a boolean term. Argument errors are logged at debug level and
// returned to the evaluator.
func networkContainsImpl(bc rego.BuiltinContext, listTerm, ipTerm *ast.Term) (*ast.Term, error) {
	log := logger.StandardLog.Value("func", "styx.in_networks")

	nets, err := parseNetworkListTerm(listTerm)
	if err != nil {
		log.Err(err).Debug("Call function failed")
		return nil, err
	}

	addr, err := parseIPTerm(ipTerm)
	if err != nil {
		log.Value("list", listTerm.Value).Err(err).Debug("Call function failed")
		return nil, err
	}

	fields := logger.Values{
		"list": listTerm.Value,
		"ip":   addr.String(),
	}
	log.Values(fields).Trace("Calling function")

	found := nets.Contains(addr)
	return ast.BooleanTerm(found), nil
}
// parseDomainListTerm resolves a Rego string term naming a domain list to the
// corresponding entry in dataset.Domains.
//
// It returns an error if the term is not a string, or if dataset.Domains has
// no entry under that name.
func parseDomainListTerm(term *ast.Term) (*dataset.DomainTree, error) {
	nameArg, ok := term.Value.(ast.String)
	if !ok {
		return nil, errors.New("expected string argument")
	}
	// Convert the raw value directly. ast.String.String() renders a quoted,
	// escaped Rego literal, and trimming quotes off that form would keep
	// escape sequences and drop genuine leading/trailing quote characters.
	name := string(nameArg)
	tree, ok := dataset.Domains[name]
	if !ok {
		return nil, fmt.Errorf("no such domain list: %q", name)
	}
	return tree, nil
}
// parseNetworkListTerm resolves a Rego string term naming a network list to
// the corresponding entry in dataset.Networks.
//
// It returns an error if the term is not a string, or if dataset.Networks has
// no entry under that name.
func parseNetworkListTerm(term *ast.Term) (*dataset.NetworkTree, error) {
	nameArg, ok := term.Value.(ast.String)
	if !ok {
		return nil, errors.New("expected string argument")
	}
	// Convert the raw value directly. ast.String.String() renders a quoted,
	// escaped Rego literal, and trimming quotes off that form would keep
	// escape sequences and drop genuine leading/trailing quote characters.
	name := string(nameArg)
	tree, ok := dataset.Networks[name]
	if !ok {
		return nil, fmt.Errorf("no such network list: %q", name)
	}
	return tree, nil
}
// parseStringTerm returns the contents of a Rego string term, or an error if
// the term holds a value of any other type.
//
// The raw value is converted directly rather than going through
// ast.String.String(). That method renders a quoted, escaped literal, and
// trimming quotes off it would leave escape sequences in the result and strip
// quote characters that genuinely begin or end the value.
func parseStringTerm(term *ast.Term) (string, error) {
	strArg, ok := term.Value.(ast.String)
	if !ok {
		return "", errors.New("expected string argument")
	}
	return string(strArg), nil
}
// parseIPTerm extracts an IP address from a Rego string term.
//
// It returns an error if the term is not a string, or if its text is not an
// address that net.ParseIP accepts. The error echoes the offending text.
func parseIPTerm(term *ast.Term) (net.IP, error) {
	str, isString := term.Value.(ast.String)
	if !isString {
		return nil, errors.New("expected string argument")
	}
	// String() yields the double-quoted literal form; strip the quotes.
	raw := strings.Trim(str.String(), `"`)
	addr := net.ParseIP(raw)
	if addr == nil {
		return nil, fmt.Errorf("invalid IP address %q", raw)
	}
	return addr, nil
}
// ListReturner produces the current items of a named list as plain strings,
// or an error if the list cannot be retrieved.
type ListReturner func() ([]string, error)

// Registries of list providers keyed by list name, populated through
// AddDomainList and AddNetworkList respectively.
var (
	domains  = make(map[string]ListReturner)
	networks = make(map[string]ListReturner)
)
// AddDomainList registers fn as the provider of the domain list called name,
// replacing any provider previously registered under the same name.
//
// NOTE(review): the registry is a plain map with no locking, so this
// presumably must be called before policies are evaluated — confirm.
func AddDomainList(name string, fn ListReturner) {
	domains[name] = fn
}
// AddNetworkList registers fn as the provider of the network list called name,
// replacing any provider previously registered under the same name.
//
// NOTE(review): the registry is a plain map with no locking, so this
// presumably must be called before policies are evaluated — confirm.
func AddNetworkList(name string, fn ListReturner) {
	networks[name] = fn
}
// listLookupImpl builds a one-argument Rego builtin that looks up a list by
// name in m and returns its items as a Rego array of strings.
//
// kind is attached to every log entry as the "func" value. The returned
// builtin:
//   - fails with "expected string argument" if its operand is not a string;
//   - returns os.ErrNotExist if m has no entry under that name;
//   - passes through any error returned by the registered ListReturner.
func listLookupImpl(kind string, m map[string]ListReturner) func(rego.BuiltinContext, *ast.Term) (*ast.Term, error) {
	return func(bc rego.BuiltinContext, inf *ast.Term) (*ast.Term, error) {
		// ast.Term.Location is a pointer and may be nil for operands created
		// during evaluation. Dereferencing it unconditionally could panic, so
		// fall back to the builtin's call site, then to a placeholder.
		loc := inf.Location
		if loc == nil {
			loc = bc.Location
		}
		where := "unknown"
		if loc != nil {
			where = loc.File + ":" + strconv.Itoa(loc.Row) + "," + strconv.Itoa(loc.Col)
		}
		log := logger.StandardLog.Values(logger.V{
			"where": where,
			"func":  kind,
		})
		nameArg, ok := inf.Value.(ast.String)
		if !ok {
			return nil, errors.New("expected string argument")
		}
		// Use the raw value: ast.String.String() returns a quoted, escaped
		// literal, so trimming quotes off it does not recover the original.
		name := string(nameArg)
		log = log.Value("type", name)
		log.Trace("Looking up list in policy")
		fn, ok := m[name]
		if !ok {
			log.Error("No such list exists")
			return nil, os.ErrNotExist
		}
		list, err := fn()
		if err != nil {
			log.Err(err).Error("Error retrieving list")
			return nil, err
		}
		astList := make([]*ast.Term, 0, len(list))
		for _, item := range list {
			astList = append(astList, ast.StringTerm(item))
		}
		log.Tracef("Returning list with %d items", len(astList))
		return ast.NewTerm(ast.NewArray(astList...)), nil
	}
}