package policy import ( "bytes" "errors" "fmt" "net" "os" "slices" "strconv" "strings" "github.com/open-policy-agent/opa/v1/ast" "github.com/open-policy-agent/opa/v1/rego" "github.com/open-policy-agent/opa/v1/types" "git.maze.io/maze/styx/dataset" "git.maze.io/maze/styx/internal/timeutil" "git.maze.io/maze/styx/logger" ) var lookupIPAddrFunc = ®o.Function{ Name: "styx.lookup_ip_addr", Decl: lookupIPAddrDecl, Memoize: true, Nondeterministic: true, } var lookupIPAddrDecl = types.NewFunction( types.Args( types.Named("name", types.S).Description("Host name to lookup"), ), types.Named("result", types.SetOfStr).Description("set(string) of IP address"), ) func lookupIPAddr(bc rego.BuiltinContext, nameTerm *ast.Term) (*ast.Term, error) { log := logger.StandardLog.Value("func", "styx.lookup_ip_addr") log.Trace("Call function") name, err := parseStringTerm(nameTerm) if err != nil { log.Err(err).Debug("Call function failed") return nil, err } if ip := net.ParseIP(name); ip != nil { return ast.SetTerm(ast.StringTerm(ip.String())), nil } ips, err := net.LookupIP(name) if err != nil { log.Err(err).Debug("IP resolution failed") return nil, err } var ( terms = make([]*ast.Term, len(ips)) strs = make([]string, len(ips)) ) slices.SortStableFunc(ips, func(a, b net.IP) int { return bytes.Compare(a, b) }) for i, ip := range ips { terms[i] = ast.StringTerm(ip.String()) strs[i] = ip.String() } log.Tracef("Resolved %s to %s", name, strings.Join(strs, ", ")) return ast.SetTerm(terms...), nil } var timebetweenFunc = ®o.Function{ Name: "styx.time_between", Decl: timeBetweenDecl, Nondeterministic: false, } var timeBetweenDecl = types.NewFunction( types.Args( types.Named("start", types.S).Description("Start time"), types.Named("end", types.S).Description("End time"), ), types.Named("result", types.B).Description("`true` if the current local time is between `start` and `end`"), ) func timeBetween(bc rego.BuiltinContext, startTerm, endTerm *ast.Term) (*ast.Term, error) { log := logger.StandardLog.Value("func", "styx.time_between") log.Trace("Call function") start, err := parseTimeTerm(startTerm) if err != nil { log.Err(err).Debug("Invalid start time") return nil, err } end, err := parseTimeTerm(endTerm) if err != nil { log.Err(err).Debug("Invalid end time") return nil, err } now := timeutil.Now() if start.Before(end) { return ast.BooleanTerm((now.Eq(start) || now.After(start)) && now.Before(end)), nil } return ast.BooleanTerm(now.Eq(end) || now.After(end) || now.Before(start)), nil } func parseTimeTerm(term *ast.Term) (timeutil.Time, error) { timeArg, ok := term.Value.(ast.String) if !ok { return timeutil.Time{}, errors.New("expected string argument") } return timeutil.ParseTime(strings.Trim(timeArg.String(), `"`)) } var domainContainsFunc = ®o.Function{ Name: "styx.domains_contain", Decl: domainContainsDecl, Memoize: true, Nondeterministic: true, } var domainContainsDecl = types.NewFunction( types.Args( types.Named("list", types.S).Description("Domain list to check against"), types.Named("name", types.S).Description("Host name to check"), ), types.Named("result", types.B).Description("`true` if `name` is contained within `list`"), ) func domainContains(bc rego.BuiltinContext, listTerm, nameTerm *ast.Term) (*ast.Term, error) { log := logger.StandardLog.Value("func", "styx.domains_contain") log.Trace("Call function") list, err := parseDomainListTerm(listTerm) if err != nil { log.Err(err).Debug("Call function failed") return nil, err } name, err := parseStringTerm(nameTerm) if err != nil { return nil, err } log.Values(logger.Values{ "list": listTerm.Value, "name": name, }).Trace("Calling function") return ast.BooleanTerm(list.Contains(name)), nil } var networkContainsFunc = ®o.Function{ Name: "styx.networks_contain", Decl: networkContainsDecl, Memoize: true, Nondeterministic: true, } var networkContainsDecl = types.NewFunction( types.Args( types.Named("list", types.S).Description("Network list to check against"), types.Named("ip", types.S).Description("IP address to check"), ), types.Named("result", types.B).Description("`true` if `ip` is contained within `list`"), ) func networkContains(bc rego.BuiltinContext, listTerm, ipTerm *ast.Term) (*ast.Term, error) { log := logger.StandardLog.Value("func", "styx.networks_contain") list, err := parseNetworkListTerm(listTerm) if err != nil { log.Err(err).Debug("Call function failed") return nil, err } ip, err := parseIPTerm(ipTerm) if err != nil { log.Value("list", listTerm.Value).Err(err).Debug("Call function failed") return nil, err } log.Values(logger.Values{ "list": listTerm.Value, "ip": ip.String(), }).Trace("Calling function") return ast.BooleanTerm(list.Contains(ip)), nil } func parseDomainListTerm(term *ast.Term) (*dataset.DomainTree, error) { nameArg, ok := term.Value.(ast.String) if !ok { return nil, errors.New("expected string argument") } name := strings.Trim(nameArg.String(), `"`) fn, ok := dataset.Domains[name] if !ok { return nil, fmt.Errorf("no such domain list: %q", name) } return fn, nil } func parseNetworkListTerm(term *ast.Term) (*dataset.NetworkTree, error) { nameArg, ok := term.Value.(ast.String) if !ok { return nil, errors.New("expected string argument") } name := strings.Trim(nameArg.String(), `"`) fn, ok := dataset.Networks[name] if !ok { return nil, fmt.Errorf("no such network list: %q", name) } return fn, nil } func parseStringTerm(term *ast.Term) (string, error) { ipArg, ok := term.Value.(ast.String) if !ok { return "", errors.New("expected string argument") } return strings.Trim(ipArg.String(), `"`), nil } func parseIPTerm(term *ast.Term) (net.IP, error) { ipArg, ok := term.Value.(ast.String) if !ok { return nil, errors.New("expected string argument") } ip := strings.Trim(ipArg.String(), `"`) if ip := net.ParseIP(ip); ip != nil { return ip, nil } return nil, fmt.Errorf("invalid IP address %q", ip) } type ListReturner func() ([]string, error) var ( domains = map[string]ListReturner{} networks = map[string]ListReturner{} ) func AddDomainList(name string, fn ListReturner) { domains[name] = fn } func AddNetworkList(name string, fn ListReturner) { networks[name] = fn } func listLookupImpl(kind string, m map[string]ListReturner) func(rego.BuiltinContext, *ast.Term) (*ast.Term, error) { return func(bc rego.BuiltinContext, inf *ast.Term) (*ast.Term, error) { log := logger.StandardLog.Values(logger.V{ "where": inf.Location.File + ":" + strconv.Itoa(inf.Location.Row) + "," + strconv.Itoa(inf.Location.Col), "func": kind, }) nameArg, ok := inf.Value.(ast.String) if !ok { return nil, errors.New("expected string argument") } name := strings.Trim(nameArg.String(), `"`) log = log.Value("type", name) log.Trace("Looking up list in policy") fn, ok := m[name] if !ok { log.Error("No such list exists") return nil, os.ErrNotExist } list, err := fn() if err != nil { log.Err(err).Error("Error retrieving list") return nil, err } astList := make([]*ast.Term, 0, len(list)) for _, item := range list { astList = append(astList, ast.StringTerm(item)) } log.Tracef("Returning list with %d items", len(astList)) return ast.NewTerm(ast.NewArray(astList...)), nil } }