180 lines
4.4 KiB
Go
180 lines
4.4 KiB
Go
package policy
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"net"
|
|
"os"
|
|
"strconv"
|
|
"strings"
|
|
|
|
"github.com/open-policy-agent/opa/v1/ast"
|
|
"github.com/open-policy-agent/opa/v1/rego"
|
|
"github.com/open-policy-agent/opa/v1/types"
|
|
|
|
"git.maze.io/maze/styx/dataset"
|
|
"git.maze.io/maze/styx/logger"
|
|
)
|
|
|
|
var domainContainsDecl = types.NewFunction(
|
|
types.Args(
|
|
types.Named("list", types.S).Description("Domain list to check against"),
|
|
types.Named("name", types.S).Description("Host name to check"),
|
|
),
|
|
types.Named("result", types.B).Description("`true` if `name` is contained within `list`"),
|
|
)
|
|
|
|
func domainContainsImpl(bc rego.BuiltinContext, listTerm, nameTerm *ast.Term) (*ast.Term, error) {
|
|
log := logger.StandardLog.Value("func", "styx.in_domains")
|
|
|
|
list, err := parseDomainListTerm(listTerm)
|
|
if err != nil {
|
|
log.Err(err).Debug("Call function failed")
|
|
return nil, err
|
|
}
|
|
|
|
name, err := parseStringTerm(nameTerm)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
log.Values(logger.Values{
|
|
"list": listTerm.Value,
|
|
"name": name,
|
|
}).Trace("Calling function")
|
|
return ast.BooleanTerm(list.Contains(name)), nil
|
|
}
|
|
|
|
var networkContainsDecl = types.NewFunction(
|
|
types.Args(
|
|
types.Named("list", types.S).Description("Network list to check against"),
|
|
types.Named("ip", types.S).Description("IP address to check"),
|
|
),
|
|
types.Named("result", types.B).Description("`true` if `ip` is contained within `list`"),
|
|
)
|
|
|
|
func networkContainsImpl(bc rego.BuiltinContext, listTerm, ipTerm *ast.Term) (*ast.Term, error) {
|
|
log := logger.StandardLog.Value("func", "styx.in_networks")
|
|
|
|
list, err := parseNetworkListTerm(listTerm)
|
|
if err != nil {
|
|
log.Err(err).Debug("Call function failed")
|
|
return nil, err
|
|
}
|
|
|
|
ip, err := parseIPTerm(ipTerm)
|
|
if err != nil {
|
|
log.Value("list", listTerm.Value).Err(err).Debug("Call function failed")
|
|
return nil, err
|
|
}
|
|
|
|
log.Values(logger.Values{
|
|
"list": listTerm.Value,
|
|
"ip": ip.String(),
|
|
}).Trace("Calling function")
|
|
return ast.BooleanTerm(list.Contains(ip)), nil
|
|
}
|
|
|
|
func parseDomainListTerm(term *ast.Term) (*dataset.DomainTree, error) {
|
|
nameArg, ok := term.Value.(ast.String)
|
|
if !ok {
|
|
return nil, errors.New("expected string argument")
|
|
}
|
|
name := strings.Trim(nameArg.String(), `"`)
|
|
|
|
fn, ok := dataset.Domains[name]
|
|
if !ok {
|
|
return nil, fmt.Errorf("no such domain list: %q", name)
|
|
}
|
|
|
|
return fn, nil
|
|
}
|
|
|
|
func parseNetworkListTerm(term *ast.Term) (*dataset.NetworkTree, error) {
|
|
nameArg, ok := term.Value.(ast.String)
|
|
if !ok {
|
|
return nil, errors.New("expected string argument")
|
|
}
|
|
name := strings.Trim(nameArg.String(), `"`)
|
|
|
|
fn, ok := dataset.Networks[name]
|
|
if !ok {
|
|
return nil, fmt.Errorf("no such network list: %q", name)
|
|
}
|
|
|
|
return fn, nil
|
|
}
|
|
|
|
func parseStringTerm(term *ast.Term) (string, error) {
|
|
ipArg, ok := term.Value.(ast.String)
|
|
if !ok {
|
|
return "", errors.New("expected string argument")
|
|
}
|
|
return strings.Trim(ipArg.String(), `"`), nil
|
|
}
|
|
|
|
func parseIPTerm(term *ast.Term) (net.IP, error) {
|
|
ipArg, ok := term.Value.(ast.String)
|
|
if !ok {
|
|
return nil, errors.New("expected string argument")
|
|
}
|
|
ip := strings.Trim(ipArg.String(), `"`)
|
|
if ip := net.ParseIP(ip); ip != nil {
|
|
return ip, nil
|
|
}
|
|
return nil, fmt.Errorf("invalid IP address %q", ip)
|
|
}
|
|
|
|
// ListReturner supplies the entries of a named list. listLookupImpl calls it
// on every lookup, rather than caching its result.
type ListReturner func() ([]string, error)

// Registered list providers, keyed by list name. domains backs domain lists
// and networks backs network lists.
//
// NOTE(review): these maps are not guarded by a mutex. Registration presumably
// happens during initialisation, before any policy evaluation; confirm.
var (
	domains  = make(map[string]ListReturner)
	networks = make(map[string]ListReturner)
)

// AddDomainList registers fn as the provider of the domain list called name.
// A provider already registered under that name is replaced.
func AddDomainList(name string, fn ListReturner) {
	domains[name] = fn
}

// AddNetworkList registers fn as the provider of the network list called
// name. A provider already registered under that name is replaced.
func AddNetworkList(name string, fn ListReturner) {
	networks[name] = fn
}
|
|
|
|
func listLookupImpl(kind string, m map[string]ListReturner) func(rego.BuiltinContext, *ast.Term) (*ast.Term, error) {
|
|
return func(bc rego.BuiltinContext, inf *ast.Term) (*ast.Term, error) {
|
|
log := logger.StandardLog.Values(logger.V{
|
|
"where": inf.Location.File + ":" + strconv.Itoa(inf.Location.Row) + "," + strconv.Itoa(inf.Location.Col),
|
|
"func": kind,
|
|
})
|
|
|
|
nameArg, ok := inf.Value.(ast.String)
|
|
if !ok {
|
|
return nil, errors.New("expected string argument")
|
|
}
|
|
name := strings.Trim(nameArg.String(), `"`)
|
|
|
|
log = log.Value("type", name)
|
|
log.Trace("Looking up list in policy")
|
|
|
|
fn, ok := m[name]
|
|
if !ok {
|
|
log.Error("No such list exists")
|
|
return nil, os.ErrNotExist
|
|
}
|
|
|
|
list, err := fn()
|
|
if err != nil {
|
|
log.Err(err).Error("Error retrieving list")
|
|
return nil, err
|
|
}
|
|
|
|
astList := make([]*ast.Term, 0, len(list))
|
|
for _, item := range list {
|
|
astList = append(astList, ast.StringTerm(item))
|
|
}
|
|
|
|
log.Tracef("Returning list with %d items", len(astList))
|
|
return ast.NewTerm(ast.NewArray(astList...)), nil
|
|
}
|
|
}
|