package policy

import (
	"errors"
	"fmt"
	"net"
	"os"
	"strconv"
	"strings"

	"github.com/open-policy-agent/opa/v1/ast"
	"github.com/open-policy-agent/opa/v1/rego"
	"github.com/open-policy-agent/opa/v1/types"

	"git.maze.io/maze/styx/dataset"
	"git.maze.io/maze/styx/logger"
)

// domainContainsDecl is the Rego type signature for domainContainsImpl:
// two string arguments (list name, host name) and a boolean result.
var domainContainsDecl = types.NewFunction(
	types.Args(
		types.Named("list", types.S).Description("Domain list to check against"),
		types.Named("name", types.S).Description("Host name to check"),
	),
	types.Named("result", types.B).Description("`true` if `name` is contained within `list`"),
)

// domainContainsImpl reports whether the host name in nameTerm is contained
// in the domain list named by listTerm. The list is resolved through
// dataset.Domains. An error is returned when either term is not a string or
// when no list with that name exists; every failure is logged at debug level
// under func=styx.in_domains before it is returned.
func domainContainsImpl(bc rego.BuiltinContext, listTerm, nameTerm *ast.Term) (*ast.Term, error) {
	log := logger.StandardLog.Value("func", "styx.in_domains")

	list, err := parseDomainListTerm(listTerm)
	if err != nil {
		log.Err(err).Debug("Call function failed")
		return nil, err
	}

	name, err := parseStringTerm(nameTerm)
	if err != nil {
		// Logged the same way networkContainsImpl logs a bad second argument.
		log.Value("list", listTerm.Value).Err(err).Debug("Call function failed")
		return nil, err
	}

	log.Values(logger.Values{
		"list": listTerm.Value,
		"name": name,
	}).Trace("Calling function")

	return ast.BooleanTerm(list.Contains(name)), nil
}

// networkContainsDecl is the Rego type signature for networkContainsImpl:
// two string arguments (list name, IP address) and a boolean result.
var networkContainsDecl = types.NewFunction(
	types.Args(
		types.Named("list", types.S).Description("Network list to check against"),
		types.Named("ip", types.S).Description("IP address to check"),
	),
	types.Named("result", types.B).Description("`true` if `ip` is contained within `list`"),
)

// networkContainsImpl reports whether the IP address in ipTerm is contained
// in the network list named by listTerm. The list is resolved through
// dataset.Networks. An error is returned when either term is not a string,
// when no list with that name exists, or when the address does not parse;
// every failure is logged at debug level under func=styx.in_networks.
func networkContainsImpl(bc rego.BuiltinContext, listTerm, ipTerm *ast.Term) (*ast.Term, error) {
	log := logger.StandardLog.Value("func", "styx.in_networks")

	list, err := parseNetworkListTerm(listTerm)
	if err != nil {
		log.Err(err).Debug("Call function failed")
		return nil, err
	}

	ip, err := parseIPTerm(ipTerm)
	if err != nil {
		log.Value("list", listTerm.Value).Err(err).Debug("Call function failed")
		return nil, err
	}

	log.Values(logger.Values{
		"list": listTerm.Value,
		"ip":   ip.String(),
	}).Trace("Calling function")

	return ast.BooleanTerm(list.Contains(ip)), nil
}

// parseDomainListTerm resolves a string term to the entry of that name in
// dataset.Domains. It fails when the term is not a string or the name is not
// a key of that map.
func parseDomainListTerm(term *ast.Term) (*dataset.DomainTree, error) {
	name, err := parseStringTerm(term)
	if err != nil {
		return nil, err
	}

	tree, ok := dataset.Domains[name]
	if !ok {
		return nil, fmt.Errorf("no such domain list: %q", name)
	}
	return tree, nil
}

// parseNetworkListTerm resolves a string term to the entry of that name in
// dataset.Networks. It fails when the term is not a string or the name is not
// a key of that map.
func parseNetworkListTerm(term *ast.Term) (*dataset.NetworkTree, error) {
	name, err := parseStringTerm(term)
	if err != nil {
		return nil, err
	}

	tree, ok := dataset.Networks[name]
	if !ok {
		return nil, fmt.Errorf("no such network list: %q", name)
	}
	return tree, nil
}

// parseStringTerm extracts the text of a string term. It is the single place
// where terms are converted to Go strings; all other helpers in this file go
// through it. A nil term or a non-string value yields an error.
func parseStringTerm(term *ast.Term) (string, error) {
	if term == nil {
		return "", errors.New("expected string argument")
	}

	value, ok := term.Value.(ast.String)
	if !ok {
		return "", errors.New("expected string argument")
	}

	// The rendered form is stripped of its surrounding double quotes.
	// NOTE(review): this also strips quotes that are part of the value and
	// presumably leaves any escape sequences from the rendering in place;
	// string(value) would avoid both — confirm before changing.
	return strings.Trim(value.String(), `"`), nil
}

// parseIPTerm extracts a string term and parses it with net.ParseIP. It fails
// when the term is not a string or the text is not a valid IP address.
func parseIPTerm(term *ast.Term) (net.IP, error) {
	text, err := parseStringTerm(term)
	if err != nil {
		return nil, err
	}

	ip := net.ParseIP(text)
	if ip == nil {
		return nil, fmt.Errorf("invalid IP address %q", text)
	}
	return ip, nil
}

// ListReturner produces the entries of a named list on demand, or an error
// when the list cannot be retrieved.
type ListReturner func() ([]string, error)

// Registries filled by AddDomainList and AddNetworkList. They are plain maps;
// no locking is performed around them in this file.
var (
	domains  = map[string]ListReturner{}
	networks = map[string]ListReturner{}
)

// AddDomainList registers fn under name in the domain list registry,
// replacing any entry previously stored under the same name.
func AddDomainList(name string, fn ListReturner) {
	domains[name] = fn
}

// AddNetworkList registers fn under name in the network list registry,
// replacing any entry previously stored under the same name.
func AddNetworkList(name string, fn ListReturner) {
	networks[name] = fn
}

// termLocationString formats the source position of term as "file:row,col"
// for logging. It returns "unknown" when the term carries no location, so
// building the log context never dereferences a nil pointer.
func termLocationString(term *ast.Term) string {
	if term == nil || term.Location == nil {
		return "unknown"
	}

	loc := term.Location
	return loc.File + ":" + strconv.Itoa(loc.Row) + "," + strconv.Itoa(loc.Col)
}

// listLookupImpl builds a one-argument builtin that looks up a list by name
// in m and returns its entries as an array of strings. kind is used only as
// the "func" value in log output.
//
// The returned function fails when the argument is not a string, when no
// usable entry is registered under the name (the error wraps
// os.ErrNotExist), or when the registered ListReturner itself fails.
func listLookupImpl(kind string, m map[string]ListReturner) func(rego.BuiltinContext, *ast.Term) (*ast.Term, error) {
	return func(bc rego.BuiltinContext, inf *ast.Term) (*ast.Term, error) {
		log := logger.StandardLog.Values(logger.V{
			"where": termLocationString(inf),
			"func":  kind,
		})

		name, err := parseStringTerm(inf)
		if err != nil {
			return nil, err
		}

		log = log.Value("type", name)
		log.Trace("Looking up list in policy")

		// A nil ListReturner is treated like a missing entry rather than
		// being called and panicking.
		fn, ok := m[name]
		if !ok || fn == nil {
			log.Error("No such list exists")
			return nil, fmt.Errorf("no such list %q: %w", name, os.ErrNotExist)
		}

		list, err := fn()
		if err != nil {
			log.Err(err).Error("Error retrieving list")
			return nil, err
		}

		astList := make([]*ast.Term, 0, len(list))
		for _, item := range list {
			astList = append(astList, ast.StringTerm(item))
		}

		log.Tracef("Returning list with %d items", len(astList))
		return ast.NewTerm(ast.NewArray(astList...)), nil
	}
}