Checkpoint

This commit is contained in:
2025-10-01 15:37:55 +02:00
parent 4a60059ff2
commit 03352e3312
31 changed files with 2611 additions and 384 deletions

179
policy/func.go Normal file
View File

@@ -0,0 +1,179 @@
package policy
import (
"errors"
"fmt"
"net"
"os"
"strconv"
"strings"
"github.com/open-policy-agent/opa/v1/ast"
"github.com/open-policy-agent/opa/v1/rego"
"github.com/open-policy-agent/opa/v1/types"
"git.maze.io/maze/styx/dataset"
"git.maze.io/maze/styx/logger"
)
var domainContainsDecl = types.NewFunction(
types.Args(
types.Named("list", types.S).Description("Domain list to check against"),
types.Named("name", types.S).Description("Host name to check"),
),
types.Named("result", types.B).Description("`true` if `name` is contained within `list`"),
)
func domainContainsImpl(bc rego.BuiltinContext, listTerm, nameTerm *ast.Term) (*ast.Term, error) {
log := logger.StandardLog.Value("func", "styx.in_domains")
list, err := parseDomainListTerm(listTerm)
if err != nil {
log.Err(err).Debug("Call function failed")
return nil, err
}
name, err := parseStringTerm(nameTerm)
if err != nil {
return nil, err
}
log.Values(logger.Values{
"list": listTerm.Value,
"name": name,
}).Trace("Calling function")
return ast.BooleanTerm(list.Contains(name)), nil
}
var networkContainsDecl = types.NewFunction(
types.Args(
types.Named("list", types.S).Description("Network list to check against"),
types.Named("ip", types.S).Description("IP address to check"),
),
types.Named("result", types.B).Description("`true` if `ip` is contained within `list`"),
)
func networkContainsImpl(bc rego.BuiltinContext, listTerm, ipTerm *ast.Term) (*ast.Term, error) {
log := logger.StandardLog.Value("func", "styx.in_networks")
list, err := parseNetworkListTerm(listTerm)
if err != nil {
log.Err(err).Debug("Call function failed")
return nil, err
}
ip, err := parseIPTerm(ipTerm)
if err != nil {
log.Value("list", listTerm.Value).Err(err).Debug("Call function failed")
return nil, err
}
log.Values(logger.Values{
"list": listTerm.Value,
"ip": ip.String(),
}).Trace("Calling function")
return ast.BooleanTerm(list.Contains(ip)), nil
}
func parseDomainListTerm(term *ast.Term) (*dataset.DomainTree, error) {
nameArg, ok := term.Value.(ast.String)
if !ok {
return nil, errors.New("expected string argument")
}
name := strings.Trim(nameArg.String(), `"`)
fn, ok := dataset.Domains[name]
if !ok {
return nil, fmt.Errorf("no such domain list: %q", name)
}
return fn, nil
}
func parseNetworkListTerm(term *ast.Term) (*dataset.NetworkTree, error) {
nameArg, ok := term.Value.(ast.String)
if !ok {
return nil, errors.New("expected string argument")
}
name := strings.Trim(nameArg.String(), `"`)
fn, ok := dataset.Networks[name]
if !ok {
return nil, fmt.Errorf("no such network list: %q", name)
}
return fn, nil
}
func parseStringTerm(term *ast.Term) (string, error) {
ipArg, ok := term.Value.(ast.String)
if !ok {
return "", errors.New("expected string argument")
}
return strings.Trim(ipArg.String(), `"`), nil
}
func parseIPTerm(term *ast.Term) (net.IP, error) {
ipArg, ok := term.Value.(ast.String)
if !ok {
return nil, errors.New("expected string argument")
}
ip := strings.Trim(ipArg.String(), `"`)
if ip := net.ParseIP(ip); ip != nil {
return ip, nil
}
return nil, fmt.Errorf("invalid IP address %q", ip)
}
type ListReturner func() ([]string, error)
var (
domains = map[string]ListReturner{}
networks = map[string]ListReturner{}
)
func AddDomainList(name string, fn ListReturner) {
domains[name] = fn
}
func AddNetworkList(name string, fn ListReturner) {
networks[name] = fn
}
func listLookupImpl(kind string, m map[string]ListReturner) func(rego.BuiltinContext, *ast.Term) (*ast.Term, error) {
return func(bc rego.BuiltinContext, inf *ast.Term) (*ast.Term, error) {
log := logger.StandardLog.Values(logger.V{
"where": inf.Location.File + ":" + strconv.Itoa(inf.Location.Row) + "," + strconv.Itoa(inf.Location.Col),
"func": kind,
})
nameArg, ok := inf.Value.(ast.String)
if !ok {
return nil, errors.New("expected string argument")
}
name := strings.Trim(nameArg.String(), `"`)
log = log.Value("type", name)
log.Trace("Looking up list in policy")
fn, ok := m[name]
if !ok {
log.Error("No such list exists")
return nil, os.ErrNotExist
}
list, err := fn()
if err != nil {
log.Err(err).Error("Error retrieving list")
return nil, err
}
astList := make([]*ast.Term, 0, len(list))
for _, item := range list {
astList = append(astList, ast.StringTerm(item))
}
log.Tracef("Returning list with %d items", len(astList))
return ast.NewTerm(ast.NewArray(astList...)), nil
}
}

44
policy/handler.go Normal file
View File

@@ -0,0 +1,44 @@
package policy
import (
"net/http"
"git.maze.io/maze/styx/logger"
proxy "git.maze.io/maze/styx/proxy"
)
func NewRequestHandler(p *Policy) proxy.RequestHandler {
log := logger.StandardLog.Value("policy", p.name)
return proxy.RequestHandlerFunc(func(ctx proxy.Context) (*http.Request, *http.Response) {
input := NewInputFromRequest(ctx, ctx.Request())
input.logValues(log).Trace("Running request handler")
result, err := p.Query(input)
if err != nil {
log.Err(err).Error("Error evaulating policy")
return nil, nil
}
r, err := result.Response(ctx)
if err != nil {
log.Err(err).Error("Error generating response")
return nil, nil
}
return nil, r
})
}
func NewResponseHandler(p *Policy) proxy.ResponseHandler {
return proxy.ResponseHandlerFunc(func(ctx proxy.Context) *http.Response {
input := NewInputFromResponse(ctx, ctx.Response())
result, err := p.Query(input)
if err != nil {
logger.StandardLog.Err(err).Error("Error evaulating policy")
return nil
}
r, err := result.Response(ctx)
if err != nil {
logger.StandardLog.Err(err).Error("Error generating response")
return nil
}
return r
})
}

394
policy/input.go Normal file
View File

@@ -0,0 +1,394 @@
package policy
import (
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"fmt"
"net"
"net/http"
"net/url"
"strconv"
"git.maze.io/maze/styx/internal/netutil"
"git.maze.io/maze/styx/logger"
)
// Input represents the input to the policy query.
type Input struct {
Client *Client `json:"client"`
TLS *TLS `json:"tls"`
Request *Request `json:"request"`
Response *Response `json:"response"`
}
func (i *Input) logValues(log logger.Structured) logger.Structured {
log = i.Client.logValues(log)
log = i.TLS.logValues(log)
log = i.Request.logValues(log)
log = i.Response.logValues(log)
return log
}
func NewInputFromConn(c net.Conn) *Input {
if c == nil {
return new(Input)
}
return &Input{
Client: NewClientFromConn(c),
TLS: NewTLSFromConn(c),
}
}
func NewInputFromRequest(c net.Conn, r *http.Request) *Input {
if r == nil {
return nil
}
input := NewInputFromConn(c)
input.Request = NewRequest(r)
return input
}
func NewInputFromResponse(c net.Conn, r *http.Response) *Input {
if r == nil {
return nil
}
input := NewInputFromConn(c)
input.Response = NewResponse(r)
return input
}
type Client struct {
Network string `json:"network"`
IP string `json:"ip"`
Port int `json:"int"`
}
func (i *Client) logValues(log logger.Structured) logger.Structured {
if i != nil {
log = log.Values(logger.Values{
"client_network": i.Network,
"client_ip": i.IP,
"client_port": i.Port,
})
}
return log
}
func NewClient(network, address string) *Client {
if host, port, err := net.SplitHostPort(address); err == nil {
p, _ := net.LookupPort(network, port)
return &Client{
Network: network,
IP: host,
Port: p,
}
}
return &Client{
Network: network,
IP: address,
}
}
func NewClientFromConn(c net.Conn) *Client {
if c == nil {
return nil
}
return NewClientFromAddr(c.RemoteAddr())
}
func NewClientFromAddr(addr net.Addr) *Client {
switch addr := addr.(type) {
case *net.TCPAddr:
return &Client{
Network: addr.Network(),
IP: addr.IP.String(),
Port: addr.Port,
}
case *net.UDPAddr:
return &Client{
Network: addr.Network(),
IP: addr.IP.String(),
Port: addr.Port,
}
case *net.IPAddr:
return &Client{
Network: addr.Network(),
IP: addr.IP.String(),
}
default:
if host, port, err := net.SplitHostPort(addr.String()); err == nil {
return &Client{
Network: addr.Network(),
IP: host,
Port: func() int { p, _ := net.LookupPort(addr.Network(), port); return p }(),
}
}
return &Client{
Network: addr.Network(),
IP: addr.String(),
}
}
}
type TLS struct {
Version string `json:"version"`
CipherSuite string `json:"cipher_suite"`
ServerName string `json:"server_name"`
Certificates []*Certificate `json:"certificates"`
}
func (i *TLS) logValues(log logger.Structured) logger.Structured {
if i != nil {
cns := make([]string, len(i.Certificates))
for j, cert := range i.Certificates {
cns[j] = cert.Subject.CommonName
}
log = log.Values(logger.Values{
"tls_version": i.Version,
"tls_cipher": i.CipherSuite,
"tls_server_name": i.ServerName,
"tls_certificates": cns,
})
}
return log
}
func NewTLS(state *tls.ConnectionState) *TLS {
if state == nil {
return nil
}
tls := &TLS{
Version: tls.VersionName(state.Version),
CipherSuite: tls.CipherSuiteName(state.CipherSuite),
ServerName: state.ServerName,
}
for _, cert := range state.PeerCertificates {
if cert := NewCertificate(cert); cert != nil {
tls.Certificates = append(tls.Certificates, cert)
}
}
return tls
}
type tlsConnectionStater interface {
ConnectionState() tls.ConnectionState
}
func NewTLSFromConn(c net.Conn) *TLS {
if c == nil {
return nil
}
if s, ok := c.(tlsConnectionStater); ok {
cs := s.ConnectionState()
return NewTLS(&cs)
}
return nil
}
type Certificate struct {
SerialNumber string `json:"serial_number"`
Subject PKIXName `json:"subject"`
Issuer PKIXName `json:"issuer"`
NotBefore int64 `json:"not_before"`
NotAfter int64 `json:"not_after"`
}
func NewCertificate(cert *x509.Certificate) *Certificate {
if cert == nil {
return nil
}
return &Certificate{
SerialNumber: cert.SerialNumber.String(),
Subject: MakePKIXName(cert.Subject),
Issuer: MakePKIXName(cert.Issuer),
NotBefore: cert.NotBefore.UnixNano(),
NotAfter: cert.NotAfter.UnixNano(),
}
}
type PKIXName struct {
CommonName string `json:"cn,omitempty"`
Country string `json:"country,omitempty"`
Organization string `json:"organization,omitempty"`
OrganizationalUnit string `json:"ou,omitempty"`
Locality string `json:"locality,omitempty"`
Province string `json:"province,omitempty"`
StreetAddress string `json:"address,omitempty"`
PostalCode string `json:"postalcode,omitempty"`
}
func MakePKIXName(name pkix.Name) PKIXName {
return PKIXName{
CommonName: name.CommonName,
Country: pick(name.Country...),
Organization: pick(name.Organization...),
OrganizationalUnit: pick(name.OrganizationalUnit...),
Locality: pick(name.Locality...),
Province: pick(name.Province...),
StreetAddress: pick(name.StreetAddress...),
PostalCode: pick(name.PostalCode...),
}
}
// Request represents an HTTP request.
type Request struct {
Method string `json:"method"`
URL *URL `json:"url"`
Proto string `json:"proto"`
Header map[string]string `json:"header"`
Host string `json:"host"`
Port int `json:"port"`
RequestURI string `json:"request_uri"`
}
func (i *Request) logValues(log logger.Structured) logger.Structured {
if i != nil {
log = log.Values(logger.Values{
"request_method": i.Method,
"request_url": i.URL.String(),
"request_proto": i.Proto,
"request_header": i.Header,
"request_host": i.Host,
"request_port": i.Port,
})
}
return log
}
func NewRequest(r *http.Request) *Request {
if r == nil {
return nil
}
header := make(map[string]string)
for key := range r.Header {
header[key] = r.Header.Get(key)
}
host, portName, err := net.SplitHostPort(r.URL.Host)
if err != nil {
host = netutil.Host(r.URL.Host)
portName = "80"
if r.URL.Scheme == "https" || r.URL.Scheme == "wss" || r.TLS != nil {
portName = "443"
}
}
var port int
if port, err = strconv.Atoi(portName); err != nil {
port, _ = net.LookupPort("tcp", portName)
}
return &Request{
Method: r.Method,
URL: NewURL(r.URL),
Proto: r.Proto,
Header: header,
Host: host,
Port: port,
RequestURI: r.RequestURI,
}
}
// Response represents an HTTP response.
type Response struct {
Status string `json:"status"`
StatusCode int `json:"status_code"`
Proto string `json:"proto"`
Header map[string]string `json:"header"`
ContentLength int64 `json:"content_length"`
Close bool `json:"close"`
Request *Request `json:"request"`
TLS *TLS `json:"tls"`
}
func (i *Response) logValues(log logger.Structured) logger.Structured {
if i != nil {
log = log.Values(logger.Values{
"response_status": i.StatusCode,
"response_proto": i.Proto,
"response_header": i.Header,
"response_close": i.Close,
"response_tls": i.TLS != nil,
})
}
return log
}
func NewResponse(r *http.Response) *Response {
if r == nil {
return nil
}
header := make(map[string]string)
for key := range r.Header {
header[key] = r.Header.Get(key)
}
return &Response{
Status: r.Status,
StatusCode: r.StatusCode,
Proto: r.Proto,
Header: header,
ContentLength: r.ContentLength,
Close: r.Close,
Request: NewRequest(r.Request),
TLS: NewTLS(r.TLS),
}
}
type URL struct {
Scheme string `json:"scheme"`
Host string `json:"host"`
Path string `json:"path"`
Query map[string]string `json:"query"`
}
func (i *URL) String() string {
if i == nil {
return "<nil>"
}
s := fmt.Sprintf("%s://%s%s", i.Scheme, i.Host, i.Path)
if len(i.Query) > 0 {
s += "?"
for k, v := range i.Query {
s += k + "=" + url.QueryEscape(v)
}
}
return s
}
func ParseURL(rawurl string) (*URL, error) {
parsed, err := url.Parse(rawurl)
if err != nil {
return nil, err
}
return NewURL(parsed), nil
}
func NewURL(url *url.URL) *URL {
if url == nil {
return nil
}
query := make(map[string]string)
for key, values := range url.Query() {
if len(values) > 0 {
query[key] = values[0]
}
}
return &URL{
Scheme: url.Scheme,
Host: url.Host,
Path: url.Path,
Query: query,
}
}
func pick(values ...string) string {
for _, v := range values {
if v != "" {
return v
}
}
return ""
}

189
policy/policy.go Normal file
View File

@@ -0,0 +1,189 @@
package policy
import (
"bytes"
"context"
"errors"
"fmt"
"html/template"
"io"
"net/http"
"os"
"github.com/go-viper/mapstructure/v2"
"github.com/open-policy-agent/opa/v1/rego"
regoprint "github.com/open-policy-agent/opa/v1/topdown/print"
"git.maze.io/maze/styx/logger"
proxy "git.maze.io/maze/styx/proxy"
)
const DefaultPackageName = "styx"
var ErrNoResult = errors.New("policy: no result")
type Policy struct {
name string
options []func(*rego.Rego)
}
func New(name, pkg string) (*Policy, error) {
p := &Policy{
name: name,
options: newRego(rego.Load([]string{name}, nil), pkg),
}
if _, err := p.Query(&Input{}); err != nil {
return nil, err
}
return p, nil
}
func NewFromString(module, pkg string) (*Policy, error) {
p := &Policy{
name: "<inline>",
options: newRego(rego.Module("styx", module), pkg),
}
if _, err := p.Query(&Input{}); err != nil {
return nil, err
}
return p, nil
}
func newRego(option func(*rego.Rego), pkg string) []func(*rego.Rego) {
if pkg == "" {
pkg = DefaultPackageName
}
return []func(*rego.Rego){
rego.Dump(os.Stderr),
rego.Query("data." + pkg),
rego.Strict(true),
rego.Function2(&rego.Function{
Name: "styx.in_domains",
Decl: domainContainsDecl,
Nondeterministic: true,
}, domainContainsImpl),
rego.Function2(&rego.Function{
Name: "styx.in_networks",
Decl: networkContainsDecl,
Nondeterministic: true,
}, networkContainsImpl),
rego.PrintHook(printHook{}),
option,
}
}
type printHook struct{}
func (printHook) Print(ctx regoprint.Context, message string) error {
logger.StandardLog.Values(logger.Values{
"where": fmt.Sprintf("%s:%d,%d", ctx.Location.File, ctx.Location.Row, ctx.Location.Col),
"from": string(ctx.Location.Text),
}).Debug(message)
return nil
}
type Result struct {
// Reject signals explicit rejection.
Reject int `json:"reject" mapstructure:"reject"`
// Permit signals explicit permission.
Permit *bool `json:"permit" mapstructure:"permit"`
// Redirect to this URL.
Redirect string `json:"redirect" mapstructure:"redirect"`
// Template to render as response body.
Template string `json:"template" mapstructure:"template"`
// Errors contains error messages.
Errors []string `json:"errors" mapstructure:"errors,omitempty"`
}
func (r *Result) Response(ctx proxy.Context) (*http.Response, error) {
for _, text := range r.Errors {
logger.StandardLog.Values(logger.Values{
"id": ctx.ID(),
"client": ctx.RemoteAddr().String(),
}).Err(errors.New(text)).Warn("Error from policy")
}
switch {
case r.Redirect != "":
response := proxy.NewResponse(http.StatusFound, nil, ctx.Request())
response.Header.Set("Server", "styx")
response.Header.Set(proxy.HeaderLocation, r.Redirect)
return response, nil
case r.Template != "":
b := new(bytes.Buffer)
t, err := template.New("policy").ParseFiles(r.Template)
if err != nil {
return nil, err
}
if err = t.Execute(b, map[string]any{"context": ctx}); err != nil {
return nil, err
}
response := proxy.NewResponse(http.StatusFound, io.NopCloser(b), ctx.Request())
response.Header.Set("Server", "styx")
response.Header.Set(proxy.HeaderContentType, "text/html")
return response, nil
case r.Reject > 0:
body := io.NopCloser(bytes.NewBufferString(http.StatusText(r.Reject)))
response := proxy.NewResponse(r.Reject, body, ctx.Request())
response.Header.Set(proxy.HeaderContentType, "text/plain")
return response, nil
case r.Permit != nil && !*r.Permit:
body := io.NopCloser(bytes.NewBufferString(http.StatusText(http.StatusForbidden)))
response := proxy.NewResponse(http.StatusForbidden, body, ctx.Request())
response.Header.Set(proxy.HeaderContentType, "text/plain")
return response, nil
default:
return nil, nil
}
}
func (p *Policy) Query(input *Input) (*Result, error) {
/*
e := json.NewEncoder(os.Stdout)
e.SetIndent("", " ")
e.Encode(doc)
*/
log := logger.StandardLog.Value("policy", p.name)
log.Trace("Evaluating policy")
r := rego.New(append(p.options, rego.Input(input))...)
ctx := context.Background()
/*
query, err := p.rego.PrepareForEval(ctx)
if err != nil {
return nil, err
}
rs, err := query.Eval(ctx, rego.EvalInput(input))
if err != nil {
return nil, err
}
*/
rs, err := r.Eval(ctx)
if err != nil {
return nil, err
}
if len(rs) == 0 || len(rs[0].Expressions) == 0 {
return nil, ErrNoResult
}
result := &Result{}
for _, expr := range rs[0].Expressions {
if m, ok := expr.Value.(map[string]any); ok {
log.Values(m).Trace("Policy result expression")
if err = mapstructure.Decode(m, result); err != nil {
return nil, err
}
}
}
return result, nil
}