Files
styx/policy/func.go
2025-10-06 22:25:23 +02:00

292 lines
7.4 KiB
Go

package policy
import (
"bytes"
"errors"
"fmt"
"net"
"os"
"slices"
"strconv"
"strings"
"github.com/open-policy-agent/opa/v1/ast"
"github.com/open-policy-agent/opa/v1/rego"
"github.com/open-policy-agent/opa/v1/types"
"git.maze.io/maze/styx/dataset"
"git.maze.io/maze/styx/internal/timeutil"
"git.maze.io/maze/styx/logger"
)
// Registration and type declaration for the styx.lookup_ip_addr builtin.
var (
	// lookupIPAddrFunc registers styx.lookup_ip_addr. Results are memoized by OPA and the
	// function is flagged nondeterministic because DNS answers may change between evaluations.
	lookupIPAddrFunc = &rego.Function{
		Name:             "styx.lookup_ip_addr",
		Decl:             lookupIPAddrDecl,
		Nondeterministic: true,
		Memoize:          true,
	}

	// lookupIPAddrDecl declares the builtin's signature: (name string) -> set(string).
	lookupIPAddrDecl = types.NewFunction(
		types.Args(
			types.Named("name", types.S).Description("Host name to lookup"),
		),
		types.Named("result", types.SetOfStr).Description("set(string) of IP address"),
	)
)
// lookupIPAddr implements styx.lookup_ip_addr: it resolves the host name in nameTerm to the
// set of its IP addresses, rendered as strings.
//
// If the argument already parses as an IP address it is returned directly (normalized by
// net.IP.String) without any DNS query. Otherwise the name is resolved with the default
// resolver, bound to the evaluation context in bc so that cancelling or timing out the policy
// query also aborts a slow lookup. Argument and resolution errors are returned to the caller.
func lookupIPAddr(bc rego.BuiltinContext, nameTerm *ast.Term) (*ast.Term, error) {
	log := logger.StandardLog.Value("func", "styx.lookup_ip_addr")
	log.Trace("Call function")

	name, err := parseStringTerm(nameTerm)
	if err != nil {
		log.Err(err).Debug("Call function failed")
		return nil, err
	}

	// Literal addresses need no resolution.
	if ip := net.ParseIP(name); ip != nil {
		return ast.SetTerm(ast.StringTerm(ip.String())), nil
	}

	var ips []net.IP
	if bc.Context != nil {
		// "ip" resolves both IPv4 and IPv6, matching net.LookupIP, but honours cancellation.
		ips, err = net.DefaultResolver.LookupIP(bc.Context, "ip", name)
	} else {
		// No evaluation context available; fall back to a context-free lookup.
		ips, err = net.LookupIP(name)
	}
	if err != nil {
		log.Err(err).Debug("IP resolution failed")
		return nil, err
	}

	// Sort so the trace line below lists addresses in a stable order.
	slices.SortStableFunc(ips, func(a, b net.IP) int {
		return bytes.Compare(a, b)
	})

	var (
		terms = make([]*ast.Term, len(ips))
		strs  = make([]string, len(ips))
	)
	for i, ip := range ips {
		strs[i] = ip.String()
		terms[i] = ast.StringTerm(strs[i])
	}
	log.Tracef("Resolved %s to %s", name, strings.Join(strs, ", "))
	return ast.SetTerm(terms...), nil
}
// timebetweenFunc registers the styx.time_between builtin.
//
// The result depends on the current wall-clock time, so the function must be flagged
// nondeterministic. If it were declared deterministic, OPA could evaluate it ahead of time
// (for example during partial evaluation) and bake a stale true/false into the resulting policy.
var timebetweenFunc = &rego.Function{
	Name:             "styx.time_between",
	Decl:             timeBetweenDecl,
	Nondeterministic: true,
}

// timeBetweenDecl declares the signature of styx.time_between: (start, end string) -> boolean.
var timeBetweenDecl = types.NewFunction(
	types.Args(
		types.Named("start", types.S).Description("Start time"),
		types.Named("end", types.S).Description("End time"),
	),
	types.Named("result", types.B).Description("`true` if the current local time is between `start` and `end`"),
)
// timeBetween implements styx.time_between: it reports whether the current time (timeutil.Now)
// lies in the half-open window [start, end).
//
// When start is before end the window is a plain interval. Otherwise the window is treated as
// wrapping around (e.g. "22:00" to "06:00"), and matches from start onwards OR before end.
// NOTE(review): this assumes timeutil.Time is a time-of-day value — confirm in internal/timeutil.
// Both arguments must be strings accepted by timeutil.ParseTime; parse errors are returned.
func timeBetween(bc rego.BuiltinContext, startTerm, endTerm *ast.Term) (*ast.Term, error) {
	log := logger.StandardLog.Value("func", "styx.time_between")
	log.Trace("Call function")

	start, err := parseTimeTerm(startTerm)
	if err != nil {
		log.Err(err).Debug("Invalid start time")
		return nil, err
	}
	end, err := parseTimeTerm(endTerm)
	if err != nil {
		log.Err(err).Debug("Invalid end time")
		return nil, err
	}

	now := timeutil.Now()
	atOrAfterStart := now.Eq(start) || now.After(start)
	beforeEnd := now.Before(end)

	if start.Before(end) {
		// Regular window: start <= now < end.
		return ast.BooleanTerm(atOrAfterStart && beforeEnd), nil
	}
	// Wrapping window: now >= start (late part) or now < end (early part).
	// The previous `now >= end || now < start` was true for every time of day.
	return ast.BooleanTerm(atOrAfterStart || beforeEnd), nil
}
// parseTimeTerm parses a Rego string term into a timeutil.Time using timeutil.ParseTime.
//
// It returns an error if the term is not a string, or whatever error ParseTime reports for
// malformed input.
func parseTimeTerm(term *ast.Term) (timeutil.Time, error) {
	timeArg, ok := term.Value.(ast.String)
	if !ok {
		return timeutil.Time{}, errors.New("expected string argument")
	}
	// ast.String holds the raw value; its String() method returns a quoted form, so convert
	// directly rather than trimming quotes (which leaves escape sequences in place).
	return timeutil.ParseTime(string(timeArg))
}
// Registration and type declaration for the styx.domains_contain builtin.
var (
	// domainContainsFunc registers styx.domains_contain with memoization enabled. It is flagged
	// nondeterministic, presumably because the backing domain lists can change at runtime.
	domainContainsFunc = &rego.Function{
		Name:             "styx.domains_contain",
		Decl:             domainContainsDecl,
		Nondeterministic: true,
		Memoize:          true,
	}

	// domainContainsDecl declares the signature: (list string, name string) -> boolean.
	domainContainsDecl = types.NewFunction(
		types.Args(
			types.Named("list", types.S).Description("Domain list to check against"),
			types.Named("name", types.S).Description("Host name to check"),
		),
		types.Named("result", types.B).Description("`true` if `name` is contained within `list`"),
	)
)
// domainContains implements styx.domains_contain: it reports whether the host name in nameTerm
// is contained in the dataset domain list named by listTerm, as decided by the list's Contains
// method.
//
// It returns an error if either argument is not a string or if no domain list with the given
// name exists in dataset.Domains.
func domainContains(bc rego.BuiltinContext, listTerm, nameTerm *ast.Term) (*ast.Term, error) {
	log := logger.StandardLog.Value("func", "styx.domains_contain")
	log.Trace("Call function")

	list, err := parseDomainListTerm(listTerm)
	if err != nil {
		log.Err(err).Debug("Call function failed")
		return nil, err
	}

	name, err := parseStringTerm(nameTerm)
	if err != nil {
		// Log like the other argument failures so a bad host-name argument is visible in debug logs.
		log.Value("list", listTerm.Value).Err(err).Debug("Call function failed")
		return nil, err
	}

	log.Values(logger.Values{
		"list": listTerm.Value,
		"name": name,
	}).Trace("Calling function")
	return ast.BooleanTerm(list.Contains(name)), nil
}
// Registration and type declaration for the styx.networks_contain builtin.
var (
	// networkContainsFunc registers styx.networks_contain with memoization enabled. It is flagged
	// nondeterministic, presumably because the backing network lists can change at runtime.
	networkContainsFunc = &rego.Function{
		Name:             "styx.networks_contain",
		Decl:             networkContainsDecl,
		Nondeterministic: true,
		Memoize:          true,
	}

	// networkContainsDecl declares the signature: (list string, ip string) -> boolean.
	networkContainsDecl = types.NewFunction(
		types.Args(
			types.Named("list", types.S).Description("Network list to check against"),
			types.Named("ip", types.S).Description("IP address to check"),
		),
		types.Named("result", types.B).Description("`true` if `ip` is contained within `list`"),
	)
)
// networkContains implements styx.networks_contain. It looks up the dataset network list named
// by listTerm and reports whether the IP address in ipTerm falls inside it.
//
// An error is returned when the list name is not a string or unknown, or when ipTerm is not a
// string holding a valid IP address.
func networkContains(bc rego.BuiltinContext, listTerm, ipTerm *ast.Term) (*ast.Term, error) {
	fnLog := logger.StandardLog.Value("func", "styx.networks_contain")

	tree, parseErr := parseNetworkListTerm(listTerm)
	if parseErr != nil {
		fnLog.Err(parseErr).Debug("Call function failed")
		return nil, parseErr
	}

	addr, parseErr := parseIPTerm(ipTerm)
	if parseErr != nil {
		fnLog.Value("list", listTerm.Value).Err(parseErr).Debug("Call function failed")
		return nil, parseErr
	}

	fields := logger.Values{
		"list": listTerm.Value,
		"ip":   addr.String(),
	}
	fnLog.Values(fields).Trace("Calling function")

	contained := tree.Contains(addr)
	return ast.BooleanTerm(contained), nil
}
// parseDomainListTerm resolves a Rego string term naming a domain list to the matching
// *dataset.DomainTree in dataset.Domains.
//
// It returns an error if the term is not a string or if no domain list with that name exists.
func parseDomainListTerm(term *ast.Term) (*dataset.DomainTree, error) {
	nameArg, ok := term.Value.(ast.String)
	if !ok {
		return nil, errors.New("expected string argument")
	}
	// ast.String holds the raw value; its String() method returns a quoted form, so convert
	// directly rather than trimming quotes (which leaves escape sequences in place).
	name := string(nameArg)
	tree, ok := dataset.Domains[name]
	if !ok {
		return nil, fmt.Errorf("no such domain list: %q", name)
	}
	return tree, nil
}
// parseNetworkListTerm resolves a Rego string term naming a network list to the matching
// *dataset.NetworkTree in dataset.Networks.
//
// It returns an error if the term is not a string or if no network list with that name exists.
func parseNetworkListTerm(term *ast.Term) (*dataset.NetworkTree, error) {
	nameArg, ok := term.Value.(ast.String)
	if !ok {
		return nil, errors.New("expected string argument")
	}
	// ast.String holds the raw value; its String() method returns a quoted form, so convert
	// directly rather than trimming quotes (which leaves escape sequences in place).
	name := string(nameArg)
	tree, ok := dataset.Networks[name]
	if !ok {
		return nil, fmt.Errorf("no such network list: %q", name)
	}
	return tree, nil
}
// parseStringTerm returns the Go string held by a Rego string term.
//
// It returns an error if the term's value is not a string.
func parseStringTerm(term *ast.Term) (string, error) {
	strArg, ok := term.Value.(ast.String)
	if !ok {
		return "", errors.New("expected string argument")
	}
	// ast.String holds the raw value; its String() method returns a quoted form, so convert
	// directly rather than trimming quotes (which leaves escape sequences in place).
	return string(strArg), nil
}
// parseIPTerm parses a Rego string term holding an IPv4 or IPv6 address (as accepted by
// net.ParseIP).
//
// It returns an error if the term is not a string or does not contain a valid IP address.
func parseIPTerm(term *ast.Term) (net.IP, error) {
	ipArg, ok := term.Value.(ast.String)
	if !ok {
		return nil, errors.New("expected string argument")
	}
	// Convert directly: ast.String.String() returns a quoted form that trimming does not unescape.
	raw := string(ipArg)
	ip := net.ParseIP(raw)
	if ip == nil {
		return nil, fmt.Errorf("invalid IP address %q", raw)
	}
	return ip, nil
}
// ListReturner returns the current items of a named list, or an error if they cannot be
// produced.
type ListReturner func() ([]string, error)

// Registries of named lists, keyed by list name. These are plain maps with no locking, so
// lists should be registered before any policy that references them is evaluated.
var (
	domains  = map[string]ListReturner{}
	networks = map[string]ListReturner{}
)

// AddDomainList registers fn as the source of the domain list called name, replacing any list
// previously registered under that name. Passing a nil fn removes the registration instead of
// storing a function that would panic when called.
func AddDomainList(name string, fn ListReturner) {
	if fn == nil {
		delete(domains, name)
		return
	}
	domains[name] = fn
}

// AddNetworkList registers fn as the source of the network list called name, replacing any list
// previously registered under that name. Passing a nil fn removes the registration instead of
// storing a function that would panic when called.
func AddNetworkList(name string, fn ListReturner) {
	if fn == nil {
		delete(networks, name)
		return
	}
	networks[name] = fn
}
// listLookupImpl builds a one-argument Rego builtin that looks up a named list in m and returns
// its items as an array of strings.
//
// kind is only used as the "func" field of log entries. The returned builtin fails when its
// argument is not a string, when no (non-nil) list is registered under that name (the error
// wraps os.ErrNotExist), or when the registered ListReturner returns an error.
func listLookupImpl(kind string, m map[string]ListReturner) func(rego.BuiltinContext, *ast.Term) (*ast.Term, error) {
	return func(bc rego.BuiltinContext, inf *ast.Term) (*ast.Term, error) {
		fields := logger.V{"func": kind}
		// Not every term carries a source location, so only report "where" when one is present
		// instead of dereferencing a nil *ast.Location.
		if loc := inf.Location; loc != nil {
			fields["where"] = loc.File + ":" + strconv.Itoa(loc.Row) + "," + strconv.Itoa(loc.Col)
		}
		log := logger.StandardLog.Values(fields)

		nameArg, ok := inf.Value.(ast.String)
		if !ok {
			return nil, errors.New("expected string argument")
		}
		// Convert directly: ast.String.String() returns a quoted form that trimming does not unescape.
		name := string(nameArg)
		log = log.Value("type", name)
		log.Trace("Looking up list in policy")

		fn, ok := m[name]
		if !ok || fn == nil {
			log.Error("No such list exists")
			// Keep os.ErrNotExist in the chain for errors.Is, but name the missing list.
			return nil, fmt.Errorf("no such list %q: %w", name, os.ErrNotExist)
		}

		list, err := fn()
		if err != nil {
			log.Err(err).Error("Error retrieving list")
			return nil, err
		}

		astList := make([]*ast.Term, 0, len(list))
		for _, item := range list {
			astList = append(astList, ast.StringTerm(item))
		}
		log.Tracef("Returning list with %d items", len(astList))
		return ast.NewTerm(ast.NewArray(astList...)), nil
	}
}