Checkpoint
This commit is contained in:
179
policy/func.go
Normal file
179
policy/func.go
Normal file
@@ -0,0 +1,179 @@
|
||||
package policy
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/open-policy-agent/opa/v1/ast"
|
||||
"github.com/open-policy-agent/opa/v1/rego"
|
||||
"github.com/open-policy-agent/opa/v1/types"
|
||||
|
||||
"git.maze.io/maze/styx/dataset"
|
||||
"git.maze.io/maze/styx/logger"
|
||||
)
|
||||
|
||||
var domainContainsDecl = types.NewFunction(
|
||||
types.Args(
|
||||
types.Named("list", types.S).Description("Domain list to check against"),
|
||||
types.Named("name", types.S).Description("Host name to check"),
|
||||
),
|
||||
types.Named("result", types.B).Description("`true` if `name` is contained within `list`"),
|
||||
)
|
||||
|
||||
func domainContainsImpl(bc rego.BuiltinContext, listTerm, nameTerm *ast.Term) (*ast.Term, error) {
|
||||
log := logger.StandardLog.Value("func", "styx.in_domains")
|
||||
|
||||
list, err := parseDomainListTerm(listTerm)
|
||||
if err != nil {
|
||||
log.Err(err).Debug("Call function failed")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
name, err := parseStringTerm(nameTerm)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.Values(logger.Values{
|
||||
"list": listTerm.Value,
|
||||
"name": name,
|
||||
}).Trace("Calling function")
|
||||
return ast.BooleanTerm(list.Contains(name)), nil
|
||||
}
|
||||
|
||||
var networkContainsDecl = types.NewFunction(
|
||||
types.Args(
|
||||
types.Named("list", types.S).Description("Network list to check against"),
|
||||
types.Named("ip", types.S).Description("IP address to check"),
|
||||
),
|
||||
types.Named("result", types.B).Description("`true` if `ip` is contained within `list`"),
|
||||
)
|
||||
|
||||
func networkContainsImpl(bc rego.BuiltinContext, listTerm, ipTerm *ast.Term) (*ast.Term, error) {
|
||||
log := logger.StandardLog.Value("func", "styx.in_networks")
|
||||
|
||||
list, err := parseNetworkListTerm(listTerm)
|
||||
if err != nil {
|
||||
log.Err(err).Debug("Call function failed")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ip, err := parseIPTerm(ipTerm)
|
||||
if err != nil {
|
||||
log.Value("list", listTerm.Value).Err(err).Debug("Call function failed")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.Values(logger.Values{
|
||||
"list": listTerm.Value,
|
||||
"ip": ip.String(),
|
||||
}).Trace("Calling function")
|
||||
return ast.BooleanTerm(list.Contains(ip)), nil
|
||||
}
|
||||
|
||||
func parseDomainListTerm(term *ast.Term) (*dataset.DomainTree, error) {
|
||||
nameArg, ok := term.Value.(ast.String)
|
||||
if !ok {
|
||||
return nil, errors.New("expected string argument")
|
||||
}
|
||||
name := strings.Trim(nameArg.String(), `"`)
|
||||
|
||||
fn, ok := dataset.Domains[name]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no such domain list: %q", name)
|
||||
}
|
||||
|
||||
return fn, nil
|
||||
}
|
||||
|
||||
func parseNetworkListTerm(term *ast.Term) (*dataset.NetworkTree, error) {
|
||||
nameArg, ok := term.Value.(ast.String)
|
||||
if !ok {
|
||||
return nil, errors.New("expected string argument")
|
||||
}
|
||||
name := strings.Trim(nameArg.String(), `"`)
|
||||
|
||||
fn, ok := dataset.Networks[name]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no such network list: %q", name)
|
||||
}
|
||||
|
||||
return fn, nil
|
||||
}
|
||||
|
||||
func parseStringTerm(term *ast.Term) (string, error) {
|
||||
ipArg, ok := term.Value.(ast.String)
|
||||
if !ok {
|
||||
return "", errors.New("expected string argument")
|
||||
}
|
||||
return strings.Trim(ipArg.String(), `"`), nil
|
||||
}
|
||||
|
||||
func parseIPTerm(term *ast.Term) (net.IP, error) {
|
||||
ipArg, ok := term.Value.(ast.String)
|
||||
if !ok {
|
||||
return nil, errors.New("expected string argument")
|
||||
}
|
||||
ip := strings.Trim(ipArg.String(), `"`)
|
||||
if ip := net.ParseIP(ip); ip != nil {
|
||||
return ip, nil
|
||||
}
|
||||
return nil, fmt.Errorf("invalid IP address %q", ip)
|
||||
}
|
||||
|
||||
type ListReturner func() ([]string, error)
|
||||
|
||||
var (
|
||||
domains = map[string]ListReturner{}
|
||||
networks = map[string]ListReturner{}
|
||||
)
|
||||
|
||||
func AddDomainList(name string, fn ListReturner) {
|
||||
domains[name] = fn
|
||||
}
|
||||
|
||||
func AddNetworkList(name string, fn ListReturner) {
|
||||
networks[name] = fn
|
||||
}
|
||||
|
||||
func listLookupImpl(kind string, m map[string]ListReturner) func(rego.BuiltinContext, *ast.Term) (*ast.Term, error) {
|
||||
return func(bc rego.BuiltinContext, inf *ast.Term) (*ast.Term, error) {
|
||||
log := logger.StandardLog.Values(logger.V{
|
||||
"where": inf.Location.File + ":" + strconv.Itoa(inf.Location.Row) + "," + strconv.Itoa(inf.Location.Col),
|
||||
"func": kind,
|
||||
})
|
||||
|
||||
nameArg, ok := inf.Value.(ast.String)
|
||||
if !ok {
|
||||
return nil, errors.New("expected string argument")
|
||||
}
|
||||
name := strings.Trim(nameArg.String(), `"`)
|
||||
|
||||
log = log.Value("type", name)
|
||||
log.Trace("Looking up list in policy")
|
||||
|
||||
fn, ok := m[name]
|
||||
if !ok {
|
||||
log.Error("No such list exists")
|
||||
return nil, os.ErrNotExist
|
||||
}
|
||||
|
||||
list, err := fn()
|
||||
if err != nil {
|
||||
log.Err(err).Error("Error retrieving list")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
astList := make([]*ast.Term, 0, len(list))
|
||||
for _, item := range list {
|
||||
astList = append(astList, ast.StringTerm(item))
|
||||
}
|
||||
|
||||
log.Tracef("Returning list with %d items", len(astList))
|
||||
return ast.NewTerm(ast.NewArray(astList...)), nil
|
||||
}
|
||||
}
|
44
policy/handler.go
Normal file
44
policy/handler.go
Normal file
@@ -0,0 +1,44 @@
|
||||
package policy
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"git.maze.io/maze/styx/logger"
|
||||
proxy "git.maze.io/maze/styx/proxy"
|
||||
)
|
||||
|
||||
func NewRequestHandler(p *Policy) proxy.RequestHandler {
|
||||
log := logger.StandardLog.Value("policy", p.name)
|
||||
return proxy.RequestHandlerFunc(func(ctx proxy.Context) (*http.Request, *http.Response) {
|
||||
input := NewInputFromRequest(ctx, ctx.Request())
|
||||
input.logValues(log).Trace("Running request handler")
|
||||
result, err := p.Query(input)
|
||||
if err != nil {
|
||||
log.Err(err).Error("Error evaulating policy")
|
||||
return nil, nil
|
||||
}
|
||||
r, err := result.Response(ctx)
|
||||
if err != nil {
|
||||
log.Err(err).Error("Error generating response")
|
||||
return nil, nil
|
||||
}
|
||||
return nil, r
|
||||
})
|
||||
}
|
||||
|
||||
func NewResponseHandler(p *Policy) proxy.ResponseHandler {
|
||||
return proxy.ResponseHandlerFunc(func(ctx proxy.Context) *http.Response {
|
||||
input := NewInputFromResponse(ctx, ctx.Response())
|
||||
result, err := p.Query(input)
|
||||
if err != nil {
|
||||
logger.StandardLog.Err(err).Error("Error evaulating policy")
|
||||
return nil
|
||||
}
|
||||
r, err := result.Response(ctx)
|
||||
if err != nil {
|
||||
logger.StandardLog.Err(err).Error("Error generating response")
|
||||
return nil
|
||||
}
|
||||
return r
|
||||
})
|
||||
}
|
394
policy/input.go
Normal file
394
policy/input.go
Normal file
@@ -0,0 +1,394 @@
|
||||
package policy
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
|
||||
"git.maze.io/maze/styx/internal/netutil"
|
||||
"git.maze.io/maze/styx/logger"
|
||||
)
|
||||
|
||||
// Input represents the input to the policy query.
|
||||
type Input struct {
|
||||
Client *Client `json:"client"`
|
||||
TLS *TLS `json:"tls"`
|
||||
Request *Request `json:"request"`
|
||||
Response *Response `json:"response"`
|
||||
}
|
||||
|
||||
func (i *Input) logValues(log logger.Structured) logger.Structured {
|
||||
log = i.Client.logValues(log)
|
||||
log = i.TLS.logValues(log)
|
||||
log = i.Request.logValues(log)
|
||||
log = i.Response.logValues(log)
|
||||
return log
|
||||
}
|
||||
|
||||
func NewInputFromConn(c net.Conn) *Input {
|
||||
if c == nil {
|
||||
return new(Input)
|
||||
}
|
||||
return &Input{
|
||||
Client: NewClientFromConn(c),
|
||||
TLS: NewTLSFromConn(c),
|
||||
}
|
||||
}
|
||||
|
||||
func NewInputFromRequest(c net.Conn, r *http.Request) *Input {
|
||||
if r == nil {
|
||||
return nil
|
||||
}
|
||||
input := NewInputFromConn(c)
|
||||
input.Request = NewRequest(r)
|
||||
return input
|
||||
}
|
||||
|
||||
func NewInputFromResponse(c net.Conn, r *http.Response) *Input {
|
||||
if r == nil {
|
||||
return nil
|
||||
}
|
||||
input := NewInputFromConn(c)
|
||||
input.Response = NewResponse(r)
|
||||
return input
|
||||
}
|
||||
|
||||
type Client struct {
|
||||
Network string `json:"network"`
|
||||
IP string `json:"ip"`
|
||||
Port int `json:"int"`
|
||||
}
|
||||
|
||||
func (i *Client) logValues(log logger.Structured) logger.Structured {
|
||||
if i != nil {
|
||||
log = log.Values(logger.Values{
|
||||
"client_network": i.Network,
|
||||
"client_ip": i.IP,
|
||||
"client_port": i.Port,
|
||||
})
|
||||
}
|
||||
return log
|
||||
}
|
||||
|
||||
func NewClient(network, address string) *Client {
|
||||
if host, port, err := net.SplitHostPort(address); err == nil {
|
||||
p, _ := net.LookupPort(network, port)
|
||||
return &Client{
|
||||
Network: network,
|
||||
IP: host,
|
||||
Port: p,
|
||||
}
|
||||
}
|
||||
return &Client{
|
||||
Network: network,
|
||||
IP: address,
|
||||
}
|
||||
}
|
||||
|
||||
func NewClientFromConn(c net.Conn) *Client {
|
||||
if c == nil {
|
||||
return nil
|
||||
}
|
||||
return NewClientFromAddr(c.RemoteAddr())
|
||||
}
|
||||
|
||||
func NewClientFromAddr(addr net.Addr) *Client {
|
||||
switch addr := addr.(type) {
|
||||
case *net.TCPAddr:
|
||||
return &Client{
|
||||
Network: addr.Network(),
|
||||
IP: addr.IP.String(),
|
||||
Port: addr.Port,
|
||||
}
|
||||
case *net.UDPAddr:
|
||||
return &Client{
|
||||
Network: addr.Network(),
|
||||
IP: addr.IP.String(),
|
||||
Port: addr.Port,
|
||||
}
|
||||
case *net.IPAddr:
|
||||
return &Client{
|
||||
Network: addr.Network(),
|
||||
IP: addr.IP.String(),
|
||||
}
|
||||
default:
|
||||
if host, port, err := net.SplitHostPort(addr.String()); err == nil {
|
||||
return &Client{
|
||||
Network: addr.Network(),
|
||||
IP: host,
|
||||
Port: func() int { p, _ := net.LookupPort(addr.Network(), port); return p }(),
|
||||
}
|
||||
}
|
||||
return &Client{
|
||||
Network: addr.Network(),
|
||||
IP: addr.String(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type TLS struct {
|
||||
Version string `json:"version"`
|
||||
CipherSuite string `json:"cipher_suite"`
|
||||
ServerName string `json:"server_name"`
|
||||
Certificates []*Certificate `json:"certificates"`
|
||||
}
|
||||
|
||||
func (i *TLS) logValues(log logger.Structured) logger.Structured {
|
||||
if i != nil {
|
||||
cns := make([]string, len(i.Certificates))
|
||||
for j, cert := range i.Certificates {
|
||||
cns[j] = cert.Subject.CommonName
|
||||
}
|
||||
log = log.Values(logger.Values{
|
||||
"tls_version": i.Version,
|
||||
"tls_cipher": i.CipherSuite,
|
||||
"tls_server_name": i.ServerName,
|
||||
"tls_certificates": cns,
|
||||
})
|
||||
}
|
||||
return log
|
||||
}
|
||||
|
||||
func NewTLS(state *tls.ConnectionState) *TLS {
|
||||
if state == nil {
|
||||
return nil
|
||||
}
|
||||
tls := &TLS{
|
||||
Version: tls.VersionName(state.Version),
|
||||
CipherSuite: tls.CipherSuiteName(state.CipherSuite),
|
||||
ServerName: state.ServerName,
|
||||
}
|
||||
for _, cert := range state.PeerCertificates {
|
||||
if cert := NewCertificate(cert); cert != nil {
|
||||
tls.Certificates = append(tls.Certificates, cert)
|
||||
}
|
||||
}
|
||||
return tls
|
||||
}
|
||||
|
||||
type tlsConnectionStater interface {
|
||||
ConnectionState() tls.ConnectionState
|
||||
}
|
||||
|
||||
func NewTLSFromConn(c net.Conn) *TLS {
|
||||
if c == nil {
|
||||
return nil
|
||||
}
|
||||
if s, ok := c.(tlsConnectionStater); ok {
|
||||
cs := s.ConnectionState()
|
||||
return NewTLS(&cs)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type Certificate struct {
|
||||
SerialNumber string `json:"serial_number"`
|
||||
Subject PKIXName `json:"subject"`
|
||||
Issuer PKIXName `json:"issuer"`
|
||||
NotBefore int64 `json:"not_before"`
|
||||
NotAfter int64 `json:"not_after"`
|
||||
}
|
||||
|
||||
func NewCertificate(cert *x509.Certificate) *Certificate {
|
||||
if cert == nil {
|
||||
return nil
|
||||
}
|
||||
return &Certificate{
|
||||
SerialNumber: cert.SerialNumber.String(),
|
||||
Subject: MakePKIXName(cert.Subject),
|
||||
Issuer: MakePKIXName(cert.Issuer),
|
||||
NotBefore: cert.NotBefore.UnixNano(),
|
||||
NotAfter: cert.NotAfter.UnixNano(),
|
||||
}
|
||||
}
|
||||
|
||||
type PKIXName struct {
|
||||
CommonName string `json:"cn,omitempty"`
|
||||
Country string `json:"country,omitempty"`
|
||||
Organization string `json:"organization,omitempty"`
|
||||
OrganizationalUnit string `json:"ou,omitempty"`
|
||||
Locality string `json:"locality,omitempty"`
|
||||
Province string `json:"province,omitempty"`
|
||||
StreetAddress string `json:"address,omitempty"`
|
||||
PostalCode string `json:"postalcode,omitempty"`
|
||||
}
|
||||
|
||||
func MakePKIXName(name pkix.Name) PKIXName {
|
||||
return PKIXName{
|
||||
CommonName: name.CommonName,
|
||||
Country: pick(name.Country...),
|
||||
Organization: pick(name.Organization...),
|
||||
OrganizationalUnit: pick(name.OrganizationalUnit...),
|
||||
Locality: pick(name.Locality...),
|
||||
Province: pick(name.Province...),
|
||||
StreetAddress: pick(name.StreetAddress...),
|
||||
PostalCode: pick(name.PostalCode...),
|
||||
}
|
||||
}
|
||||
|
||||
// Request represents an HTTP request.
|
||||
type Request struct {
|
||||
Method string `json:"method"`
|
||||
URL *URL `json:"url"`
|
||||
Proto string `json:"proto"`
|
||||
Header map[string]string `json:"header"`
|
||||
Host string `json:"host"`
|
||||
Port int `json:"port"`
|
||||
RequestURI string `json:"request_uri"`
|
||||
}
|
||||
|
||||
func (i *Request) logValues(log logger.Structured) logger.Structured {
|
||||
if i != nil {
|
||||
log = log.Values(logger.Values{
|
||||
"request_method": i.Method,
|
||||
"request_url": i.URL.String(),
|
||||
"request_proto": i.Proto,
|
||||
"request_header": i.Header,
|
||||
"request_host": i.Host,
|
||||
"request_port": i.Port,
|
||||
})
|
||||
}
|
||||
return log
|
||||
}
|
||||
|
||||
func NewRequest(r *http.Request) *Request {
|
||||
if r == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
header := make(map[string]string)
|
||||
for key := range r.Header {
|
||||
header[key] = r.Header.Get(key)
|
||||
}
|
||||
|
||||
host, portName, err := net.SplitHostPort(r.URL.Host)
|
||||
if err != nil {
|
||||
host = netutil.Host(r.URL.Host)
|
||||
portName = "80"
|
||||
if r.URL.Scheme == "https" || r.URL.Scheme == "wss" || r.TLS != nil {
|
||||
portName = "443"
|
||||
}
|
||||
}
|
||||
|
||||
var port int
|
||||
if port, err = strconv.Atoi(portName); err != nil {
|
||||
port, _ = net.LookupPort("tcp", portName)
|
||||
}
|
||||
|
||||
return &Request{
|
||||
Method: r.Method,
|
||||
URL: NewURL(r.URL),
|
||||
Proto: r.Proto,
|
||||
Header: header,
|
||||
Host: host,
|
||||
Port: port,
|
||||
RequestURI: r.RequestURI,
|
||||
}
|
||||
}
|
||||
|
||||
// Response represents an HTTP response.
|
||||
type Response struct {
|
||||
Status string `json:"status"`
|
||||
StatusCode int `json:"status_code"`
|
||||
Proto string `json:"proto"`
|
||||
Header map[string]string `json:"header"`
|
||||
ContentLength int64 `json:"content_length"`
|
||||
Close bool `json:"close"`
|
||||
Request *Request `json:"request"`
|
||||
TLS *TLS `json:"tls"`
|
||||
}
|
||||
|
||||
func (i *Response) logValues(log logger.Structured) logger.Structured {
|
||||
if i != nil {
|
||||
log = log.Values(logger.Values{
|
||||
"response_status": i.StatusCode,
|
||||
"response_proto": i.Proto,
|
||||
"response_header": i.Header,
|
||||
"response_close": i.Close,
|
||||
"response_tls": i.TLS != nil,
|
||||
})
|
||||
}
|
||||
return log
|
||||
}
|
||||
|
||||
func NewResponse(r *http.Response) *Response {
|
||||
if r == nil {
|
||||
return nil
|
||||
}
|
||||
header := make(map[string]string)
|
||||
for key := range r.Header {
|
||||
header[key] = r.Header.Get(key)
|
||||
}
|
||||
return &Response{
|
||||
Status: r.Status,
|
||||
StatusCode: r.StatusCode,
|
||||
Proto: r.Proto,
|
||||
Header: header,
|
||||
ContentLength: r.ContentLength,
|
||||
Close: r.Close,
|
||||
Request: NewRequest(r.Request),
|
||||
TLS: NewTLS(r.TLS),
|
||||
}
|
||||
}
|
||||
|
||||
type URL struct {
|
||||
Scheme string `json:"scheme"`
|
||||
Host string `json:"host"`
|
||||
Path string `json:"path"`
|
||||
Query map[string]string `json:"query"`
|
||||
}
|
||||
|
||||
func (i *URL) String() string {
|
||||
if i == nil {
|
||||
return "<nil>"
|
||||
}
|
||||
|
||||
s := fmt.Sprintf("%s://%s%s", i.Scheme, i.Host, i.Path)
|
||||
if len(i.Query) > 0 {
|
||||
s += "?"
|
||||
for k, v := range i.Query {
|
||||
s += k + "=" + url.QueryEscape(v)
|
||||
}
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func ParseURL(rawurl string) (*URL, error) {
|
||||
parsed, err := url.Parse(rawurl)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewURL(parsed), nil
|
||||
}
|
||||
|
||||
func NewURL(url *url.URL) *URL {
|
||||
if url == nil {
|
||||
return nil
|
||||
}
|
||||
query := make(map[string]string)
|
||||
for key, values := range url.Query() {
|
||||
if len(values) > 0 {
|
||||
query[key] = values[0]
|
||||
}
|
||||
}
|
||||
return &URL{
|
||||
Scheme: url.Scheme,
|
||||
Host: url.Host,
|
||||
Path: url.Path,
|
||||
Query: query,
|
||||
}
|
||||
}
|
||||
|
||||
func pick(values ...string) string {
|
||||
for _, v := range values {
|
||||
if v != "" {
|
||||
return v
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
189
policy/policy.go
Normal file
189
policy/policy.go
Normal file
@@ -0,0 +1,189 @@
|
||||
package policy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"html/template"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
"github.com/go-viper/mapstructure/v2"
|
||||
"github.com/open-policy-agent/opa/v1/rego"
|
||||
regoprint "github.com/open-policy-agent/opa/v1/topdown/print"
|
||||
|
||||
"git.maze.io/maze/styx/logger"
|
||||
proxy "git.maze.io/maze/styx/proxy"
|
||||
)
|
||||
|
||||
const DefaultPackageName = "styx"
|
||||
|
||||
var ErrNoResult = errors.New("policy: no result")
|
||||
|
||||
type Policy struct {
|
||||
name string
|
||||
options []func(*rego.Rego)
|
||||
}
|
||||
|
||||
func New(name, pkg string) (*Policy, error) {
|
||||
p := &Policy{
|
||||
name: name,
|
||||
options: newRego(rego.Load([]string{name}, nil), pkg),
|
||||
}
|
||||
if _, err := p.Query(&Input{}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func NewFromString(module, pkg string) (*Policy, error) {
|
||||
p := &Policy{
|
||||
name: "<inline>",
|
||||
options: newRego(rego.Module("styx", module), pkg),
|
||||
}
|
||||
if _, err := p.Query(&Input{}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func newRego(option func(*rego.Rego), pkg string) []func(*rego.Rego) {
|
||||
if pkg == "" {
|
||||
pkg = DefaultPackageName
|
||||
}
|
||||
return []func(*rego.Rego){
|
||||
rego.Dump(os.Stderr),
|
||||
rego.Query("data." + pkg),
|
||||
rego.Strict(true),
|
||||
rego.Function2(®o.Function{
|
||||
Name: "styx.in_domains",
|
||||
Decl: domainContainsDecl,
|
||||
Nondeterministic: true,
|
||||
}, domainContainsImpl),
|
||||
rego.Function2(®o.Function{
|
||||
Name: "styx.in_networks",
|
||||
Decl: networkContainsDecl,
|
||||
Nondeterministic: true,
|
||||
}, networkContainsImpl),
|
||||
rego.PrintHook(printHook{}),
|
||||
option,
|
||||
}
|
||||
}
|
||||
|
||||
type printHook struct{}
|
||||
|
||||
func (printHook) Print(ctx regoprint.Context, message string) error {
|
||||
logger.StandardLog.Values(logger.Values{
|
||||
"where": fmt.Sprintf("%s:%d,%d", ctx.Location.File, ctx.Location.Row, ctx.Location.Col),
|
||||
"from": string(ctx.Location.Text),
|
||||
}).Debug(message)
|
||||
return nil
|
||||
}
|
||||
|
||||
type Result struct {
|
||||
// Reject signals explicit rejection.
|
||||
Reject int `json:"reject" mapstructure:"reject"`
|
||||
|
||||
// Permit signals explicit permission.
|
||||
Permit *bool `json:"permit" mapstructure:"permit"`
|
||||
|
||||
// Redirect to this URL.
|
||||
Redirect string `json:"redirect" mapstructure:"redirect"`
|
||||
|
||||
// Template to render as response body.
|
||||
Template string `json:"template" mapstructure:"template"`
|
||||
|
||||
// Errors contains error messages.
|
||||
Errors []string `json:"errors" mapstructure:"errors,omitempty"`
|
||||
}
|
||||
|
||||
func (r *Result) Response(ctx proxy.Context) (*http.Response, error) {
|
||||
for _, text := range r.Errors {
|
||||
logger.StandardLog.Values(logger.Values{
|
||||
"id": ctx.ID(),
|
||||
"client": ctx.RemoteAddr().String(),
|
||||
}).Err(errors.New(text)).Warn("Error from policy")
|
||||
}
|
||||
|
||||
switch {
|
||||
case r.Redirect != "":
|
||||
response := proxy.NewResponse(http.StatusFound, nil, ctx.Request())
|
||||
response.Header.Set("Server", "styx")
|
||||
response.Header.Set(proxy.HeaderLocation, r.Redirect)
|
||||
return response, nil
|
||||
|
||||
case r.Template != "":
|
||||
b := new(bytes.Buffer)
|
||||
t, err := template.New("policy").ParseFiles(r.Template)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = t.Execute(b, map[string]any{"context": ctx}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
response := proxy.NewResponse(http.StatusFound, io.NopCloser(b), ctx.Request())
|
||||
response.Header.Set("Server", "styx")
|
||||
response.Header.Set(proxy.HeaderContentType, "text/html")
|
||||
return response, nil
|
||||
|
||||
case r.Reject > 0:
|
||||
body := io.NopCloser(bytes.NewBufferString(http.StatusText(r.Reject)))
|
||||
response := proxy.NewResponse(r.Reject, body, ctx.Request())
|
||||
response.Header.Set(proxy.HeaderContentType, "text/plain")
|
||||
return response, nil
|
||||
|
||||
case r.Permit != nil && !*r.Permit:
|
||||
body := io.NopCloser(bytes.NewBufferString(http.StatusText(http.StatusForbidden)))
|
||||
response := proxy.NewResponse(http.StatusForbidden, body, ctx.Request())
|
||||
response.Header.Set(proxy.HeaderContentType, "text/plain")
|
||||
return response, nil
|
||||
|
||||
default:
|
||||
return nil, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Policy) Query(input *Input) (*Result, error) {
|
||||
/*
|
||||
e := json.NewEncoder(os.Stdout)
|
||||
e.SetIndent("", " ")
|
||||
e.Encode(doc)
|
||||
*/
|
||||
|
||||
log := logger.StandardLog.Value("policy", p.name)
|
||||
log.Trace("Evaluating policy")
|
||||
|
||||
r := rego.New(append(p.options, rego.Input(input))...)
|
||||
|
||||
ctx := context.Background()
|
||||
/*
|
||||
query, err := p.rego.PrepareForEval(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rs, err := query.Eval(ctx, rego.EvalInput(input))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
*/
|
||||
rs, err := r.Eval(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(rs) == 0 || len(rs[0].Expressions) == 0 {
|
||||
return nil, ErrNoResult
|
||||
}
|
||||
result := &Result{}
|
||||
for _, expr := range rs[0].Expressions {
|
||||
if m, ok := expr.Value.(map[string]any); ok {
|
||||
log.Values(m).Trace("Policy result expression")
|
||||
if err = mapstructure.Decode(m, result); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
Reference in New Issue
Block a user