Checkpoint
This commit is contained in:
@@ -15,17 +15,25 @@ import (
|
||||
"github.com/open-policy-agent/opa/v1/types"
|
||||
|
||||
"git.maze.io/maze/styx/dataset"
|
||||
"git.maze.io/maze/styx/internal/timeutil"
|
||||
"git.maze.io/maze/styx/logger"
|
||||
)
|
||||
|
||||
var netLookupIPAddrDecl = types.NewFunction(
|
||||
var lookupIPAddrFunc = ®o.Function{
|
||||
Name: "styx.lookup_ip_addr",
|
||||
Decl: lookupIPAddrDecl,
|
||||
Memoize: true,
|
||||
Nondeterministic: true,
|
||||
}
|
||||
|
||||
var lookupIPAddrDecl = types.NewFunction(
|
||||
types.Args(
|
||||
types.Named("name", types.S).Description("Host name to lookup"),
|
||||
),
|
||||
types.Named("result", types.SetOfStr).Description("set(string) of IP address"),
|
||||
)
|
||||
|
||||
func netLookupIPAddrImpl(bc rego.BuiltinContext, nameTerm *ast.Term) (*ast.Term, error) {
|
||||
func lookupIPAddr(bc rego.BuiltinContext, nameTerm *ast.Term) (*ast.Term, error) {
|
||||
log := logger.StandardLog.Value("func", "styx.lookup_ip_addr")
|
||||
log.Trace("Call function")
|
||||
|
||||
@@ -61,6 +69,57 @@ func netLookupIPAddrImpl(bc rego.BuiltinContext, nameTerm *ast.Term) (*ast.Term,
|
||||
return ast.SetTerm(terms...), nil
|
||||
}
|
||||
|
||||
var timebetweenFunc = ®o.Function{
|
||||
Name: "styx.time_between",
|
||||
Decl: timeBetweenDecl,
|
||||
Nondeterministic: false,
|
||||
}
|
||||
|
||||
var timeBetweenDecl = types.NewFunction(
|
||||
types.Args(
|
||||
types.Named("start", types.S).Description("Start time"),
|
||||
types.Named("end", types.S).Description("End time"),
|
||||
),
|
||||
types.Named("result", types.B).Description("`true` if the current local time is between `start` and `end`"),
|
||||
)
|
||||
|
||||
func timeBetween(bc rego.BuiltinContext, startTerm, endTerm *ast.Term) (*ast.Term, error) {
|
||||
log := logger.StandardLog.Value("func", "styx.time_between")
|
||||
log.Trace("Call function")
|
||||
|
||||
start, err := parseTimeTerm(startTerm)
|
||||
if err != nil {
|
||||
log.Err(err).Debug("Invalid start time")
|
||||
return nil, err
|
||||
}
|
||||
end, err := parseTimeTerm(endTerm)
|
||||
if err != nil {
|
||||
log.Err(err).Debug("Invalid end time")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
now := timeutil.Now()
|
||||
if start.Before(end) {
|
||||
return ast.BooleanTerm((now.Eq(start) || now.After(start)) && now.Before(end)), nil
|
||||
}
|
||||
return ast.BooleanTerm(now.Eq(end) || now.After(end) || now.Before(start)), nil
|
||||
}
|
||||
|
||||
func parseTimeTerm(term *ast.Term) (timeutil.Time, error) {
|
||||
timeArg, ok := term.Value.(ast.String)
|
||||
if !ok {
|
||||
return timeutil.Time{}, errors.New("expected string argument")
|
||||
}
|
||||
return timeutil.ParseTime(strings.Trim(timeArg.String(), `"`))
|
||||
}
|
||||
|
||||
var domainContainsFunc = ®o.Function{
|
||||
Name: "styx.domains_contain",
|
||||
Decl: domainContainsDecl,
|
||||
Memoize: true,
|
||||
Nondeterministic: true,
|
||||
}
|
||||
|
||||
var domainContainsDecl = types.NewFunction(
|
||||
types.Args(
|
||||
types.Named("list", types.S).Description("Domain list to check against"),
|
||||
@@ -69,8 +128,8 @@ var domainContainsDecl = types.NewFunction(
|
||||
types.Named("result", types.B).Description("`true` if `name` is contained within `list`"),
|
||||
)
|
||||
|
||||
func domainContainsImpl(bc rego.BuiltinContext, listTerm, nameTerm *ast.Term) (*ast.Term, error) {
|
||||
log := logger.StandardLog.Value("func", "styx.in_domains")
|
||||
func domainContains(bc rego.BuiltinContext, listTerm, nameTerm *ast.Term) (*ast.Term, error) {
|
||||
log := logger.StandardLog.Value("func", "styx.domains_contain")
|
||||
log.Trace("Call function")
|
||||
|
||||
list, err := parseDomainListTerm(listTerm)
|
||||
@@ -91,6 +150,13 @@ func domainContainsImpl(bc rego.BuiltinContext, listTerm, nameTerm *ast.Term) (*
|
||||
return ast.BooleanTerm(list.Contains(name)), nil
|
||||
}
|
||||
|
||||
var networkContainsFunc = ®o.Function{
|
||||
Name: "styx.networks_contain",
|
||||
Decl: networkContainsDecl,
|
||||
Memoize: true,
|
||||
Nondeterministic: true,
|
||||
}
|
||||
|
||||
var networkContainsDecl = types.NewFunction(
|
||||
types.Args(
|
||||
types.Named("list", types.S).Description("Network list to check against"),
|
||||
@@ -99,8 +165,8 @@ var networkContainsDecl = types.NewFunction(
|
||||
types.Named("result", types.B).Description("`true` if `ip` is contained within `list`"),
|
||||
)
|
||||
|
||||
func networkContainsImpl(bc rego.BuiltinContext, listTerm, ipTerm *ast.Term) (*ast.Term, error) {
|
||||
log := logger.StandardLog.Value("func", "styx.in_networks")
|
||||
func networkContains(bc rego.BuiltinContext, listTerm, ipTerm *ast.Term) (*ast.Term, error) {
|
||||
log := logger.StandardLog.Value("func", "styx.networks_contain")
|
||||
|
||||
list, err := parseNetworkListTerm(listTerm)
|
||||
if err != nil {
|
||||
|
@@ -1,9 +1,12 @@
|
||||
package policy
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"crypto/tls"
|
||||
"net"
|
||||
"net/http"
|
||||
|
||||
"git.maze.io/maze/styx/ca"
|
||||
"git.maze.io/maze/styx/internal/netutil"
|
||||
"git.maze.io/maze/styx/logger"
|
||||
proxy "git.maze.io/maze/styx/proxy"
|
||||
@@ -24,6 +27,7 @@ func NewRequestHandler(p *Policy) proxy.RequestHandler {
|
||||
log.Err(err).Error("Error generating response")
|
||||
return nil, nil
|
||||
}
|
||||
log.Debug("Replacing HTTP response from policy")
|
||||
return nil, r
|
||||
})
|
||||
}
|
||||
@@ -47,21 +51,52 @@ func NewDialHandler(p *Policy) proxy.DialHandler {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
c := netutil.NewLoopback()
|
||||
// Create a fake loopback connection
|
||||
pipe := netutil.NewLoopback()
|
||||
|
||||
go func(c net.Conn) {
|
||||
s := &http.Server{
|
||||
Handler: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
r.Write(w)
|
||||
}),
|
||||
defer func() { _ = c.Close() }()
|
||||
if req.URL.Scheme == "https" || req.URL.Scheme == "wss" || netutil.Port(req.URL.Host) == 443 {
|
||||
c = maybeUpgradeToTLS(c, ctx, req, log)
|
||||
}
|
||||
_ = s.Serve(&netutil.AcceptOnce{Conn: c})
|
||||
}(c.Server)
|
||||
|
||||
return c.Client, nil
|
||||
br := bufio.NewReader(c)
|
||||
if _, err := http.ReadRequest(br); err != nil {
|
||||
log.Err(err).Warn("Malformed HTTP request in MITM connection")
|
||||
}
|
||||
_ = r.Write(c)
|
||||
}(pipe.Server)
|
||||
|
||||
return pipe.Client, nil
|
||||
})
|
||||
}
|
||||
|
||||
func maybeUpgradeToTLS(c net.Conn, ctx proxy.Context, req *http.Request, log logger.Structured) net.Conn {
|
||||
var ca ca.CertificateAuthority
|
||||
if caCtx, ok := ctx.(proxy.WithCertificateAuthority); ok {
|
||||
ca = caCtx.CertificateAuthority()
|
||||
}
|
||||
if ca == nil {
|
||||
return c
|
||||
}
|
||||
|
||||
secure := tls.Server(c, &tls.Config{
|
||||
GetCertificate: func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
log.Values(logger.Values{
|
||||
"cn": req.URL.Host,
|
||||
"names": hello.ServerName,
|
||||
}).Debug("Requesting certificate from CA")
|
||||
return ca.GetCertificate(netutil.Host(req.URL.Host), []string{hello.ServerName}, nil)
|
||||
},
|
||||
NextProtos: []string{"http/1.1"},
|
||||
})
|
||||
if err := secure.Handshake(); err != nil {
|
||||
log.Err(err).Warn("Failed to pretend secure HTTP")
|
||||
return c
|
||||
}
|
||||
return secure
|
||||
}
|
||||
|
||||
func NewForwardHandler(p *Policy) proxy.ForwardHandler {
|
||||
log := logger.StandardLog.Value("policy", p.name)
|
||||
return proxy.ForwardHandlerFunc(func(ctx proxy.Context, req *http.Request) (*http.Response, error) {
|
||||
@@ -72,7 +107,15 @@ func NewForwardHandler(p *Policy) proxy.ForwardHandler {
|
||||
log.Err(err).Error("Error evaulating policy")
|
||||
return nil, nil
|
||||
}
|
||||
return result.Response(ctx)
|
||||
r, err := result.Response(ctx)
|
||||
if err != nil {
|
||||
log.Err(err).Error("Error generating response")
|
||||
return nil, err
|
||||
}
|
||||
if r != nil {
|
||||
log.Debug("Replacing HTTP response from policy")
|
||||
}
|
||||
return r, nil
|
||||
})
|
||||
}
|
||||
|
||||
@@ -80,6 +123,7 @@ func NewResponseHandler(p *Policy) proxy.ResponseHandler {
|
||||
log := logger.StandardLog.Value("policy", p.name)
|
||||
return proxy.ResponseHandlerFunc(func(ctx proxy.Context) *http.Response {
|
||||
input := NewInputFromResponse(ctx, ctx.Response())
|
||||
input.logValues(log).Trace("Running response handler")
|
||||
result, err := p.Query(input)
|
||||
if err != nil {
|
||||
log.Err(err).Error("Error evaulating policy")
|
||||
@@ -90,6 +134,9 @@ func NewResponseHandler(p *Policy) proxy.ResponseHandler {
|
||||
log.Err(err).Error("Error generating response")
|
||||
return nil
|
||||
}
|
||||
if r != nil {
|
||||
log.Debug("Replacing HTTP response from policy")
|
||||
}
|
||||
return r
|
||||
})
|
||||
}
|
||||
|
@@ -10,19 +10,26 @@ import (
|
||||
"net/url"
|
||||
"strconv"
|
||||
|
||||
"git.maze.io/maze/styx/dataset"
|
||||
"git.maze.io/maze/styx/internal/netutil"
|
||||
"git.maze.io/maze/styx/logger"
|
||||
proxy "git.maze.io/maze/styx/proxy"
|
||||
)
|
||||
|
||||
// Input represents the input to the policy query.
|
||||
type Input struct {
|
||||
Client *Client `json:"client"`
|
||||
TLS *TLS `json:"tls"`
|
||||
Request *Request `json:"request"`
|
||||
Response *Response `json:"response"`
|
||||
Context map[string]any `json:"context"`
|
||||
Client *Client `json:"client"`
|
||||
Groups []*Group `json:"groups"`
|
||||
TLS *TLS `json:"tls"`
|
||||
Request *Request `json:"request"`
|
||||
Response *Response `json:"response"`
|
||||
}
|
||||
|
||||
func (i *Input) logValues(log logger.Structured) logger.Structured {
|
||||
if i.Context != nil {
|
||||
log = log.Values(i.Context)
|
||||
}
|
||||
log = i.Client.logValues(log)
|
||||
log = i.TLS.logValues(log)
|
||||
log = i.Request.logValues(log)
|
||||
@@ -34,10 +41,29 @@ func NewInputFromConn(c net.Conn) *Input {
|
||||
if c == nil {
|
||||
return new(Input)
|
||||
}
|
||||
return &Input{
|
||||
Client: NewClientFromConn(c),
|
||||
TLS: NewTLSFromConn(c),
|
||||
|
||||
input := &Input{
|
||||
Context: make(map[string]any),
|
||||
Client: NewClientFromConn(c),
|
||||
TLS: NewTLSFromConn(c),
|
||||
}
|
||||
|
||||
if wcl, ok := c.(dataset.WithClient); ok {
|
||||
client, err := wcl.Client()
|
||||
if err == nil {
|
||||
input.Context["client_id"] = client.ID
|
||||
input.Context["client_description"] = client.Description
|
||||
input.Context["groups"] = client.Groups
|
||||
}
|
||||
}
|
||||
|
||||
if ctx, ok := c.(proxy.Context); ok {
|
||||
input.Context["local"] = NewClientFromAddr(ctx.LocalAddr())
|
||||
input.Context["bytes_rx"] = ctx.BytesRead()
|
||||
input.Context["bytes_tx"] = ctx.BytesSent()
|
||||
}
|
||||
|
||||
return input
|
||||
}
|
||||
|
||||
func NewInputFromRequest(c net.Conn, r *http.Request) *Input {
|
||||
@@ -131,6 +157,10 @@ func NewClientFromAddr(addr net.Addr) *Client {
|
||||
}
|
||||
}
|
||||
|
||||
type Group struct {
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
type TLS struct {
|
||||
Version string `json:"version"`
|
||||
CipherSuite string `json:"cipher_suite"`
|
||||
|
@@ -67,24 +67,10 @@ func newRego(option func(*rego.Rego), pkg string) []func(*rego.Rego) {
|
||||
rego.Query("data." + pkg),
|
||||
rego.Strict(true),
|
||||
rego.Capabilities(capabilities),
|
||||
rego.Function2(®o.Function{
|
||||
Name: "styx.in_domains",
|
||||
Decl: domainContainsDecl,
|
||||
Memoize: true,
|
||||
Nondeterministic: true,
|
||||
}, domainContainsImpl),
|
||||
rego.Function2(®o.Function{
|
||||
Name: "styx.in_networks",
|
||||
Decl: networkContainsDecl,
|
||||
Memoize: true,
|
||||
Nondeterministic: true,
|
||||
}, networkContainsImpl),
|
||||
rego.Function1(®o.Function{
|
||||
Name: "styx.lookup_ip_addr", // override builtin
|
||||
Decl: netLookupIPAddrDecl,
|
||||
Memoize: true,
|
||||
Nondeterministic: true,
|
||||
}, netLookupIPAddrImpl),
|
||||
rego.Function2(domainContainsFunc, domainContains),
|
||||
rego.Function2(networkContainsFunc, networkContains),
|
||||
rego.Function1(lookupIPAddrFunc, lookupIPAddr),
|
||||
rego.Function2(timebetweenFunc, timeBetween),
|
||||
rego.PrintHook(printHook{}),
|
||||
option,
|
||||
}
|
||||
@@ -128,16 +114,20 @@ func (r *Result) Response(ctx proxy.Context) (*http.Response, error) {
|
||||
|
||||
switch {
|
||||
case r.Redirect != "":
|
||||
log.Value("location", r.Redirect).Trace("Creating a HTTP redirect response")
|
||||
response := proxy.NewResponse(http.StatusFound, nil, ctx.Request())
|
||||
response.Header.Set("Server", "styx")
|
||||
response.Header.Set(proxy.HeaderLocation, r.Redirect)
|
||||
return response, nil
|
||||
|
||||
case r.Template != "":
|
||||
log = log.Value("template", r.Template)
|
||||
log.Trace("Creating a HTTP template response")
|
||||
|
||||
b := new(bytes.Buffer)
|
||||
t, err := template.New(filepath.Base(r.Template)).ParseFiles(r.Template)
|
||||
if err != nil {
|
||||
log.Value("template", r.Template).Err(err).Warn("Error loading template in response")
|
||||
log.Err(err).Warn("Error loading template in response")
|
||||
return nil, err
|
||||
}
|
||||
t = t.Funcs(template.FuncMap{
|
||||
@@ -149,7 +139,7 @@ func (r *Result) Response(ctx proxy.Context) (*http.Response, error) {
|
||||
"Response": ctx.Response(),
|
||||
"Errors": r.Errors,
|
||||
}); err != nil {
|
||||
log.Value("template", r.Template).Err(err).Warn("Error rendering template response")
|
||||
log.Err(err).Warn("Error rendering template response")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -159,46 +149,34 @@ func (r *Result) Response(ctx proxy.Context) (*http.Response, error) {
|
||||
return response, nil
|
||||
|
||||
case r.Reject > 0:
|
||||
log.Value("code", r.Reject).Trace("Creating a HTTP reject response")
|
||||
body := io.NopCloser(bytes.NewBufferString(http.StatusText(r.Reject)))
|
||||
response := proxy.NewResponse(r.Reject, body, ctx.Request())
|
||||
response.Header.Set(proxy.HeaderContentType, "text/plain")
|
||||
return response, nil
|
||||
|
||||
case r.Permit != nil && !*r.Permit:
|
||||
log.Trace("Creating a HTTP reject response due to explicit not permit")
|
||||
body := io.NopCloser(bytes.NewBufferString(http.StatusText(http.StatusForbidden)))
|
||||
response := proxy.NewResponse(http.StatusForbidden, body, ctx.Request())
|
||||
response.Header.Set(proxy.HeaderContentType, "text/plain")
|
||||
return response, nil
|
||||
|
||||
default:
|
||||
log.Trace("Not creating a HTTP response")
|
||||
return nil, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Policy) Query(input *Input) (*Result, error) {
|
||||
/*
|
||||
e := json.NewEncoder(os.Stdout)
|
||||
e.SetIndent("", " ")
|
||||
e.Encode(doc)
|
||||
*/
|
||||
|
||||
log := logger.StandardLog.Value("policy", p.name)
|
||||
log.Trace("Evaluating policy")
|
||||
|
||||
r := rego.New(append(p.options, rego.Input(input))...)
|
||||
|
||||
ctx := context.Background()
|
||||
/*
|
||||
query, err := p.rego.PrepareForEval(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rs, err := query.Eval(ctx, rego.EvalInput(input))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
*/
|
||||
rs, err := r.Eval(ctx)
|
||||
var (
|
||||
rego = rego.New(append(p.options, rego.Input(input))...)
|
||||
ctx = context.Background()
|
||||
rs, err = rego.Eval(ctx)
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -208,6 +186,12 @@ func (p *Policy) Query(input *Input) (*Result, error) {
|
||||
result := &Result{}
|
||||
for _, expr := range rs[0].Expressions {
|
||||
if m, ok := expr.Value.(map[string]any); ok {
|
||||
// Remove private variables.
|
||||
for k := range m {
|
||||
if len(k) > 0 && k[0] == '_' {
|
||||
delete(m, k)
|
||||
}
|
||||
}
|
||||
log.Values(m).Trace("Policy result expression")
|
||||
if err = mapstructure.Decode(m, result); err != nil {
|
||||
return nil, err
|
||||
|
Reference in New Issue
Block a user