Checkpoint
This commit is contained in:
179
policy/func.go
Normal file
179
policy/func.go
Normal file
@@ -0,0 +1,179 @@
|
||||
package policy
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/open-policy-agent/opa/v1/ast"
|
||||
"github.com/open-policy-agent/opa/v1/rego"
|
||||
"github.com/open-policy-agent/opa/v1/types"
|
||||
|
||||
"git.maze.io/maze/styx/dataset"
|
||||
"git.maze.io/maze/styx/logger"
|
||||
)
|
||||
|
||||
var domainContainsDecl = types.NewFunction(
|
||||
types.Args(
|
||||
types.Named("list", types.S).Description("Domain list to check against"),
|
||||
types.Named("name", types.S).Description("Host name to check"),
|
||||
),
|
||||
types.Named("result", types.B).Description("`true` if `name` is contained within `list`"),
|
||||
)
|
||||
|
||||
func domainContainsImpl(bc rego.BuiltinContext, listTerm, nameTerm *ast.Term) (*ast.Term, error) {
|
||||
log := logger.StandardLog.Value("func", "styx.in_domains")
|
||||
|
||||
list, err := parseDomainListTerm(listTerm)
|
||||
if err != nil {
|
||||
log.Err(err).Debug("Call function failed")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
name, err := parseStringTerm(nameTerm)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.Values(logger.Values{
|
||||
"list": listTerm.Value,
|
||||
"name": name,
|
||||
}).Trace("Calling function")
|
||||
return ast.BooleanTerm(list.Contains(name)), nil
|
||||
}
|
||||
|
||||
var networkContainsDecl = types.NewFunction(
|
||||
types.Args(
|
||||
types.Named("list", types.S).Description("Network list to check against"),
|
||||
types.Named("ip", types.S).Description("IP address to check"),
|
||||
),
|
||||
types.Named("result", types.B).Description("`true` if `ip` is contained within `list`"),
|
||||
)
|
||||
|
||||
func networkContainsImpl(bc rego.BuiltinContext, listTerm, ipTerm *ast.Term) (*ast.Term, error) {
|
||||
log := logger.StandardLog.Value("func", "styx.in_networks")
|
||||
|
||||
list, err := parseNetworkListTerm(listTerm)
|
||||
if err != nil {
|
||||
log.Err(err).Debug("Call function failed")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ip, err := parseIPTerm(ipTerm)
|
||||
if err != nil {
|
||||
log.Value("list", listTerm.Value).Err(err).Debug("Call function failed")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.Values(logger.Values{
|
||||
"list": listTerm.Value,
|
||||
"ip": ip.String(),
|
||||
}).Trace("Calling function")
|
||||
return ast.BooleanTerm(list.Contains(ip)), nil
|
||||
}
|
||||
|
||||
func parseDomainListTerm(term *ast.Term) (*dataset.DomainTree, error) {
|
||||
nameArg, ok := term.Value.(ast.String)
|
||||
if !ok {
|
||||
return nil, errors.New("expected string argument")
|
||||
}
|
||||
name := strings.Trim(nameArg.String(), `"`)
|
||||
|
||||
fn, ok := dataset.Domains[name]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no such domain list: %q", name)
|
||||
}
|
||||
|
||||
return fn, nil
|
||||
}
|
||||
|
||||
func parseNetworkListTerm(term *ast.Term) (*dataset.NetworkTree, error) {
|
||||
nameArg, ok := term.Value.(ast.String)
|
||||
if !ok {
|
||||
return nil, errors.New("expected string argument")
|
||||
}
|
||||
name := strings.Trim(nameArg.String(), `"`)
|
||||
|
||||
fn, ok := dataset.Networks[name]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no such network list: %q", name)
|
||||
}
|
||||
|
||||
return fn, nil
|
||||
}
|
||||
|
||||
func parseStringTerm(term *ast.Term) (string, error) {
|
||||
ipArg, ok := term.Value.(ast.String)
|
||||
if !ok {
|
||||
return "", errors.New("expected string argument")
|
||||
}
|
||||
return strings.Trim(ipArg.String(), `"`), nil
|
||||
}
|
||||
|
||||
func parseIPTerm(term *ast.Term) (net.IP, error) {
|
||||
ipArg, ok := term.Value.(ast.String)
|
||||
if !ok {
|
||||
return nil, errors.New("expected string argument")
|
||||
}
|
||||
ip := strings.Trim(ipArg.String(), `"`)
|
||||
if ip := net.ParseIP(ip); ip != nil {
|
||||
return ip, nil
|
||||
}
|
||||
return nil, fmt.Errorf("invalid IP address %q", ip)
|
||||
}
|
||||
|
||||
type ListReturner func() ([]string, error)
|
||||
|
||||
var (
|
||||
domains = map[string]ListReturner{}
|
||||
networks = map[string]ListReturner{}
|
||||
)
|
||||
|
||||
func AddDomainList(name string, fn ListReturner) {
|
||||
domains[name] = fn
|
||||
}
|
||||
|
||||
func AddNetworkList(name string, fn ListReturner) {
|
||||
networks[name] = fn
|
||||
}
|
||||
|
||||
func listLookupImpl(kind string, m map[string]ListReturner) func(rego.BuiltinContext, *ast.Term) (*ast.Term, error) {
|
||||
return func(bc rego.BuiltinContext, inf *ast.Term) (*ast.Term, error) {
|
||||
log := logger.StandardLog.Values(logger.V{
|
||||
"where": inf.Location.File + ":" + strconv.Itoa(inf.Location.Row) + "," + strconv.Itoa(inf.Location.Col),
|
||||
"func": kind,
|
||||
})
|
||||
|
||||
nameArg, ok := inf.Value.(ast.String)
|
||||
if !ok {
|
||||
return nil, errors.New("expected string argument")
|
||||
}
|
||||
name := strings.Trim(nameArg.String(), `"`)
|
||||
|
||||
log = log.Value("type", name)
|
||||
log.Trace("Looking up list in policy")
|
||||
|
||||
fn, ok := m[name]
|
||||
if !ok {
|
||||
log.Error("No such list exists")
|
||||
return nil, os.ErrNotExist
|
||||
}
|
||||
|
||||
list, err := fn()
|
||||
if err != nil {
|
||||
log.Err(err).Error("Error retrieving list")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
astList := make([]*ast.Term, 0, len(list))
|
||||
for _, item := range list {
|
||||
astList = append(astList, ast.StringTerm(item))
|
||||
}
|
||||
|
||||
log.Tracef("Returning list with %d items", len(astList))
|
||||
return ast.NewTerm(ast.NewArray(astList...)), nil
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user