Checkpoint
This commit is contained in:
@@ -1,10 +1,12 @@
|
||||
package policy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
@@ -16,6 +18,49 @@ import (
|
||||
"git.maze.io/maze/styx/logger"
|
||||
)
|
||||
|
||||
var netLookupIPAddrDecl = types.NewFunction(
|
||||
types.Args(
|
||||
types.Named("name", types.S).Description("Host name to lookup"),
|
||||
),
|
||||
types.Named("result", types.SetOfStr).Description("set(string) of IP address"),
|
||||
)
|
||||
|
||||
func netLookupIPAddrImpl(bc rego.BuiltinContext, nameTerm *ast.Term) (*ast.Term, error) {
|
||||
log := logger.StandardLog.Value("func", "styx.lookup_ip_addr")
|
||||
log.Trace("Call function")
|
||||
|
||||
name, err := parseStringTerm(nameTerm)
|
||||
if err != nil {
|
||||
log.Err(err).Debug("Call function failed")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if ip := net.ParseIP(name); ip != nil {
|
||||
return ast.SetTerm(ast.StringTerm(ip.String())), nil
|
||||
}
|
||||
|
||||
ips, err := net.LookupIP(name)
|
||||
if err != nil {
|
||||
log.Err(err).Debug("IP resolution failed")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var (
|
||||
terms = make([]*ast.Term, len(ips))
|
||||
strs = make([]string, len(ips))
|
||||
)
|
||||
slices.SortStableFunc(ips, func(a, b net.IP) int {
|
||||
return bytes.Compare(a, b)
|
||||
})
|
||||
for i, ip := range ips {
|
||||
terms[i] = ast.StringTerm(ip.String())
|
||||
strs[i] = ip.String()
|
||||
}
|
||||
|
||||
log.Tracef("Resolved %s to %s", name, strings.Join(strs, ", "))
|
||||
return ast.SetTerm(terms...), nil
|
||||
}
|
||||
|
||||
var domainContainsDecl = types.NewFunction(
|
||||
types.Args(
|
||||
types.Named("list", types.S).Description("Domain list to check against"),
|
||||
@@ -26,6 +71,7 @@ var domainContainsDecl = types.NewFunction(
|
||||
|
||||
func domainContainsImpl(bc rego.BuiltinContext, listTerm, nameTerm *ast.Term) (*ast.Term, error) {
|
||||
log := logger.StandardLog.Value("func", "styx.in_domains")
|
||||
log.Trace("Call function")
|
||||
|
||||
list, err := parseDomainListTerm(listTerm)
|
||||
if err != nil {
|
||||
|
@@ -1,8 +1,10 @@
|
||||
package policy
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
|
||||
"git.maze.io/maze/styx/internal/netutil"
|
||||
"git.maze.io/maze/styx/logger"
|
||||
proxy "git.maze.io/maze/styx/proxy"
|
||||
)
|
||||
@@ -26,17 +28,66 @@ func NewRequestHandler(p *Policy) proxy.RequestHandler {
|
||||
})
|
||||
}
|
||||
|
||||
func NewDialHandler(p *Policy) proxy.DialHandler {
|
||||
log := logger.StandardLog.Value("policy", p.name)
|
||||
return proxy.DialHandlerFunc(func(ctx proxy.Context, req *http.Request) (net.Conn, error) {
|
||||
input := NewInputFromRequest(ctx, req)
|
||||
input.logValues(log).Trace("Running dial handler")
|
||||
result, err := p.Query(input)
|
||||
if err != nil {
|
||||
log.Err(err).Error("Error evaulating policy")
|
||||
return nil, nil
|
||||
}
|
||||
r, err := result.Response(ctx)
|
||||
if err != nil {
|
||||
log.Err(err).Error("Error generating response")
|
||||
return nil, nil
|
||||
}
|
||||
if r == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
c := netutil.NewLoopback()
|
||||
|
||||
go func(c net.Conn) {
|
||||
s := &http.Server{
|
||||
Handler: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
r.Write(w)
|
||||
}),
|
||||
}
|
||||
_ = s.Serve(&netutil.AcceptOnce{Conn: c})
|
||||
}(c.Server)
|
||||
|
||||
return c.Client, nil
|
||||
})
|
||||
}
|
||||
|
||||
func NewForwardHandler(p *Policy) proxy.ForwardHandler {
|
||||
log := logger.StandardLog.Value("policy", p.name)
|
||||
return proxy.ForwardHandlerFunc(func(ctx proxy.Context, req *http.Request) (*http.Response, error) {
|
||||
input := NewInputFromRequest(ctx, req)
|
||||
input.logValues(log).Trace("Running forward handler")
|
||||
result, err := p.Query(input)
|
||||
if err != nil {
|
||||
log.Err(err).Error("Error evaulating policy")
|
||||
return nil, nil
|
||||
}
|
||||
return result.Response(ctx)
|
||||
})
|
||||
}
|
||||
|
||||
func NewResponseHandler(p *Policy) proxy.ResponseHandler {
|
||||
log := logger.StandardLog.Value("policy", p.name)
|
||||
return proxy.ResponseHandlerFunc(func(ctx proxy.Context) *http.Response {
|
||||
input := NewInputFromResponse(ctx, ctx.Response())
|
||||
result, err := p.Query(input)
|
||||
if err != nil {
|
||||
logger.StandardLog.Err(err).Error("Error evaulating policy")
|
||||
log.Err(err).Error("Error evaulating policy")
|
||||
return nil
|
||||
}
|
||||
r, err := result.Response(ctx)
|
||||
if err != nil {
|
||||
logger.StandardLog.Err(err).Error("Error generating response")
|
||||
log.Err(err).Error("Error generating response")
|
||||
return nil
|
||||
}
|
||||
return r
|
||||
|
@@ -9,8 +9,10 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/go-viper/mapstructure/v2"
|
||||
"github.com/open-policy-agent/opa/v1/ast"
|
||||
"github.com/open-policy-agent/opa/v1/rego"
|
||||
regoprint "github.com/open-policy-agent/opa/v1/topdown/print"
|
||||
|
||||
@@ -53,20 +55,36 @@ func newRego(option func(*rego.Rego), pkg string) []func(*rego.Rego) {
|
||||
if pkg == "" {
|
||||
pkg = DefaultPackageName
|
||||
}
|
||||
|
||||
capabilities := &ast.Capabilities{
|
||||
Builtins: ast.DefaultBuiltins[:], // all builtins
|
||||
Features: ast.Features, // all features
|
||||
AllowNet: nil, // allow all
|
||||
}
|
||||
|
||||
return []func(*rego.Rego){
|
||||
rego.Dump(os.Stderr),
|
||||
rego.Query("data." + pkg),
|
||||
rego.Strict(true),
|
||||
rego.Capabilities(capabilities),
|
||||
rego.Function2(®o.Function{
|
||||
Name: "styx.in_domains",
|
||||
Decl: domainContainsDecl,
|
||||
Memoize: true,
|
||||
Nondeterministic: true,
|
||||
}, domainContainsImpl),
|
||||
rego.Function2(®o.Function{
|
||||
Name: "styx.in_networks",
|
||||
Decl: networkContainsDecl,
|
||||
Memoize: true,
|
||||
Nondeterministic: true,
|
||||
}, networkContainsImpl),
|
||||
rego.Function1(®o.Function{
|
||||
Name: "styx.lookup_ip_addr", // override builtin
|
||||
Decl: netLookupIPAddrDecl,
|
||||
Memoize: true,
|
||||
Nondeterministic: true,
|
||||
}, netLookupIPAddrImpl),
|
||||
rego.PrintHook(printHook{}),
|
||||
option,
|
||||
}
|
||||
@@ -100,11 +118,12 @@ type Result struct {
|
||||
}
|
||||
|
||||
func (r *Result) Response(ctx proxy.Context) (*http.Response, error) {
|
||||
log := logger.StandardLog.Values(logger.Values{
|
||||
"id": ctx.ID(),
|
||||
"client": ctx.RemoteAddr().String(),
|
||||
})
|
||||
for _, text := range r.Errors {
|
||||
logger.StandardLog.Values(logger.Values{
|
||||
"id": ctx.ID(),
|
||||
"client": ctx.RemoteAddr().String(),
|
||||
}).Err(errors.New(text)).Warn("Error from policy")
|
||||
log.Err(errors.New(text)).Warn("Error from policy")
|
||||
}
|
||||
|
||||
switch {
|
||||
@@ -116,11 +135,21 @@ func (r *Result) Response(ctx proxy.Context) (*http.Response, error) {
|
||||
|
||||
case r.Template != "":
|
||||
b := new(bytes.Buffer)
|
||||
t, err := template.New("policy").ParseFiles(r.Template)
|
||||
t, err := template.New(filepath.Base(r.Template)).ParseFiles(r.Template)
|
||||
if err != nil {
|
||||
log.Value("template", r.Template).Err(err).Warn("Error loading template in response")
|
||||
return nil, err
|
||||
}
|
||||
if err = t.Execute(b, map[string]any{"context": ctx}); err != nil {
|
||||
t = t.Funcs(template.FuncMap{
|
||||
"tohex": func(v any) string { return fmt.Sprintf("%x", v) },
|
||||
})
|
||||
if err = t.Execute(b, map[string]any{
|
||||
"Context": ctx,
|
||||
"Request": ctx.Request(),
|
||||
"Response": ctx.Response(),
|
||||
"Errors": r.Errors,
|
||||
}); err != nil {
|
||||
log.Value("template", r.Template).Err(err).Warn("Error rendering template response")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user