package policy import ( "bytes" "context" "errors" "fmt" "html/template" "io" "net/http" "os" "path/filepath" "github.com/go-viper/mapstructure/v2" "github.com/open-policy-agent/opa/v1/ast" "github.com/open-policy-agent/opa/v1/rego" regoprint "github.com/open-policy-agent/opa/v1/topdown/print" "git.maze.io/maze/styx/logger" proxy "git.maze.io/maze/styx/proxy" ) const DefaultPackageName = "styx" var ErrNoResult = errors.New("policy: no result") type Policy struct { name string options []func(*rego.Rego) } func New(name, pkg string) (*Policy, error) { p := &Policy{ name: name, options: newRego(rego.Load([]string{name}, nil), pkg), } if _, err := p.Query(&Input{}); err != nil { return nil, err } return p, nil } func NewFromString(module, pkg string) (*Policy, error) { p := &Policy{ name: "", options: newRego(rego.Module("styx", module), pkg), } if _, err := p.Query(&Input{}); err != nil { return nil, err } return p, nil } func newRego(option func(*rego.Rego), pkg string) []func(*rego.Rego) { if pkg == "" { pkg = DefaultPackageName } capabilities := &ast.Capabilities{ Builtins: ast.DefaultBuiltins[:], // all builtins Features: ast.Features, // all features AllowNet: nil, // allow all } return []func(*rego.Rego){ rego.Dump(os.Stderr), rego.Query("data." + pkg), rego.Strict(true), rego.Capabilities(capabilities), rego.Function2(®o.Function{ Name: "styx.in_domains", Decl: domainContainsDecl, Memoize: true, Nondeterministic: true, }, domainContainsImpl), rego.Function2(®o.Function{ Name: "styx.in_networks", Decl: networkContainsDecl, Memoize: true, Nondeterministic: true, }, networkContainsImpl), rego.Function1(®o.Function{ Name: "styx.lookup_ip_addr", // override builtin Decl: netLookupIPAddrDecl, Memoize: true, Nondeterministic: true, }, netLookupIPAddrImpl), rego.PrintHook(printHook{}), option, } } type printHook struct{} func (printHook) Print(ctx regoprint.Context, message string) error { logger.StandardLog.Values(logger.Values{ "where": fmt.Sprintf("%s:%d,%d", ctx.Location.File, ctx.Location.Row, ctx.Location.Col), "from": string(ctx.Location.Text), }).Debug(message) return nil } type Result struct { // Reject signals explicit rejection. Reject int `json:"reject" mapstructure:"reject"` // Permit signals explicit permission. Permit *bool `json:"permit" mapstructure:"permit"` // Redirect to this URL. Redirect string `json:"redirect" mapstructure:"redirect"` // Template to render as response body. Template string `json:"template" mapstructure:"template"` // Errors contains error messages. Errors []string `json:"errors" mapstructure:"errors,omitempty"` } func (r *Result) Response(ctx proxy.Context) (*http.Response, error) { log := logger.StandardLog.Values(logger.Values{ "id": ctx.ID(), "client": ctx.RemoteAddr().String(), }) for _, text := range r.Errors { log.Err(errors.New(text)).Warn("Error from policy") } switch { case r.Redirect != "": response := proxy.NewResponse(http.StatusFound, nil, ctx.Request()) response.Header.Set("Server", "styx") response.Header.Set(proxy.HeaderLocation, r.Redirect) return response, nil case r.Template != "": b := new(bytes.Buffer) t, err := template.New(filepath.Base(r.Template)).ParseFiles(r.Template) if err != nil { log.Value("template", r.Template).Err(err).Warn("Error loading template in response") return nil, err } t = t.Funcs(template.FuncMap{ "tohex": func(v any) string { return fmt.Sprintf("%x", v) }, }) if err = t.Execute(b, map[string]any{ "Context": ctx, "Request": ctx.Request(), "Response": ctx.Response(), "Errors": r.Errors, }); err != nil { log.Value("template", r.Template).Err(err).Warn("Error rendering template response") return nil, err } response := proxy.NewResponse(http.StatusFound, io.NopCloser(b), ctx.Request()) response.Header.Set("Server", "styx") response.Header.Set(proxy.HeaderContentType, "text/html") return response, nil case r.Reject > 0: body := io.NopCloser(bytes.NewBufferString(http.StatusText(r.Reject))) response := proxy.NewResponse(r.Reject, body, ctx.Request()) response.Header.Set(proxy.HeaderContentType, "text/plain") return response, nil case r.Permit != nil && !*r.Permit: body := io.NopCloser(bytes.NewBufferString(http.StatusText(http.StatusForbidden))) response := proxy.NewResponse(http.StatusForbidden, body, ctx.Request()) response.Header.Set(proxy.HeaderContentType, "text/plain") return response, nil default: return nil, nil } } func (p *Policy) Query(input *Input) (*Result, error) { /* e := json.NewEncoder(os.Stdout) e.SetIndent("", " ") e.Encode(doc) */ log := logger.StandardLog.Value("policy", p.name) log.Trace("Evaluating policy") r := rego.New(append(p.options, rego.Input(input))...) ctx := context.Background() /* query, err := p.rego.PrepareForEval(ctx) if err != nil { return nil, err } rs, err := query.Eval(ctx, rego.EvalInput(input)) if err != nil { return nil, err } */ rs, err := r.Eval(ctx) if err != nil { return nil, err } if len(rs) == 0 || len(rs[0].Expressions) == 0 { return nil, ErrNoResult } result := &Result{} for _, expr := range rs[0].Expressions { if m, ok := expr.Value.(map[string]any); ok { log.Values(m).Trace("Policy result expression") if err = mapstructure.Decode(m, result); err != nil { return nil, err } } } return result, nil }