Better trie implementations

This commit is contained in:
2025-10-08 20:57:13 +02:00
parent 5f0f4aa96b
commit 582163d4be
26 changed files with 2482 additions and 122 deletions

View File

@@ -45,8 +45,11 @@ type Context interface {
// Response is the response that will be sent back to the client.
Response() *http.Response
// Logger for this context.
Logger() logger.Structured
// Client group.
Client() (dataset.Client, error)
Storage() dataset.Storage
}
type WithCertificateAuthority interface {
@@ -91,11 +94,10 @@ type proxyContext struct {
idleTimeout time.Duration
ca ca.CertificateAuthority
storage dataset.Storage
client dataset.Client
}
// NewContext returns an initialized context for the provided [net.Conn].
func NewContext(c net.Conn) Context {
func NewContext(c net.Conn, storage dataset.Storage) Context {
if c, ok := c.(*proxyContext); ok {
return c
}
@@ -106,12 +108,13 @@ func NewContext(c net.Conn) Context {
cr := &countingReader{reader: c}
cw := &countingWriter{writer: c}
return &proxyContext{
Conn: c,
id: binary.BigEndian.Uint64(b),
cr: cr,
br: bufio.NewReader(cr),
cw: cw,
res: &http.Response{StatusCode: 200},
Conn: c,
id: binary.BigEndian.Uint64(b),
cr: cr,
br: bufio.NewReader(cr),
cw: cw,
res: &http.Response{StatusCode: 200},
storage: storage,
}
}
@@ -128,7 +131,7 @@ func (c *proxyContext) AccessLogEntry() logger.Structured {
return entry
}
func (c *proxyContext) LogEntry() logger.Structured {
func (c *proxyContext) Logger() logger.Structured {
var id [8]byte
binary.BigEndian.PutUint64(id[:], c.id)
return ServerLog.Values(logger.Values{
@@ -234,24 +237,8 @@ func (c *proxyContext) CertificateAuthority() ca.CertificateAuthority {
return c.ca
}
func (c *proxyContext) Client() (dataset.Client, error) {
if c.storage == nil {
return dataset.Client{}, dataset.ErrNotExist{Object: "client"}
}
if !c.client.CreatedAt.Equal(time.Time{}) {
return c.client, nil
}
var err error
switch addr := c.Conn.RemoteAddr().(type) {
case *net.TCPAddr:
c.client, err = c.storage.ClientByIP(addr.IP)
case *net.UDPAddr:
c.client, err = c.storage.ClientByIP(addr.IP)
default:
err = dataset.ErrNotExist{Object: "client"}
}
return c.client, err
func (c *proxyContext) Storage() dataset.Storage {
return c.storage
}
var _ Context = (*proxyContext)(nil)

View File

@@ -144,18 +144,21 @@ func Transparent(port int) ConnHandler {
return nctx, nil
}
b := new(bytes.Buffer)
hello, err := cryptutil.ReadClientHello(io.TeeReader(netutil.ReadOnlyConn{Reader: ctx.br}, b))
var (
b = new(bytes.Buffer)
hello, err = cryptutil.ReadClientHello(io.TeeReader(netutil.ReadOnlyConn{Reader: ctx.br}, b))
log = ctx.Logger()
)
if err != nil {
if _, ok := err.(tls.RecordHeaderError); !ok {
ctx.LogEntry().Err(err).Value("error_type", fmt.Sprintf("%T", err)).Warn("TLS sniff error")
log.Err(err).Value("error_type", fmt.Sprintf("%T", err)).Warn("TLS sniff error")
return nil, err
}
// Not a TLS connection, moving on to regular HTTP request handling...
ctx.LogEntry().Debug("HTTP connection on transparent port")
log.Debug("HTTP connection on transparent port")
ctx.transparent = port
} else {
ctx.LogEntry().Value("target", hello.ServerName).Debug("TLS connection on transparent port")
log.Value("target", hello.ServerName).Debug("TLS connection on transparent port")
ctx.transparent = port
ctx.transparentTLS = true
ctx.serverName = hello.ServerName

178
proxy/policy.go Normal file
View File

@@ -0,0 +1,178 @@
package proxy
import (
"errors"
"net"
"net/netip"
"strings"
"git.maze.io/maze/styx/dataset"
"git.maze.io/maze/styx/internal/netutil"
"github.com/open-policy-agent/opa/v1/ast"
"github.com/open-policy-agent/opa/v1/rego"
"github.com/open-policy-agent/opa/v1/types"
)
// PolicyQueryOptions generates the Rego query functions for the provided [Context].
func PolicyQueryOptions(ctx Context) (options []func(*rego.Rego)) {
var (
log = ctx.Logger()
storage = ctx.Storage()
)
addr, err := netip.ParseAddr(netutil.Host(ctx.RemoteAddr().String()))
if err != nil {
log.Err(err).Error("Error resolving remote address")
return
}
client, err := storage.ClientByAddr(addr)
if err != nil {
log.Err(err).Warn("Error resolving client")
return
}
var (
permitDomains []*dataset.DomainTrie
rejectDomains []*dataset.DomainTrie
permitNetworks []*dataset.NetworkTrie
rejectNetworks []*dataset.NetworkTrie
)
for _, group := range client.Groups {
lists, err := storage.ListsByGroup(group)
if err != nil {
log.Err(err).Warn("Error resolving lists")
return
}
for _, list := range lists {
switch list.Type {
case dataset.ListTypeDomain:
trie, err := list.Domains()
if err != nil {
log.Err(err).Warn("Error resolving domain trie")
}
if list.Permit {
permitDomains = append(permitDomains, trie)
} else {
rejectDomains = append(rejectDomains, trie)
}
case dataset.ListTypeNetwork:
trie, err := list.Networks()
if err != nil {
log.Err(err).Warn("Error resolving domain trie")
}
if list.Permit {
permitNetworks = append(permitNetworks, trie)
} else {
rejectNetworks = append(rejectNetworks, trie)
}
}
}
}
options = append(options,
rego.Function1(&rego.Function{
Name: "styx.reject_domain",
Description: "Check if the domain is to be rejected",
Decl: domainFunctionDecl,
Nondeterministic: true,
Memoize: true,
}, domainFunctionImpl(rejectDomains)),
rego.Function1(&rego.Function{
Name: "styx.permit_domain",
Description: "Check if the domain is to be permitted",
Decl: domainFunctionDecl,
Nondeterministic: true,
Memoize: true,
}, domainFunctionImpl(permitDomains)),
rego.Function1(&rego.Function{
Name: "styx.reject_network",
Description: "Check if the IP, IP:port, host or host:port is to be rejected",
Decl: networkFunctionDecl,
Nondeterministic: true,
Memoize: true,
}, networkFunctionImpl(rejectNetworks)),
rego.Function1(&rego.Function{
Name: "styx.permit_network",
Description: "Check if the IP, IP:port, host or host:port is to be permitted",
Decl: networkFunctionDecl,
Nondeterministic: true,
Memoize: true,
}, networkFunctionImpl(permitNetworks)),
)
return
}
var domainFunctionDecl = types.NewFunction(
types.Args(types.Named("domain", types.S).Description("Domain to lookup")),
types.Named("result", types.B).Description("`true` if domain matches"),
)
func domainFunctionImpl(tries []*dataset.DomainTrie) rego.Builtin1 {
return func(ctx rego.BuiltinContext, domainTerm *ast.Term) (*ast.Term, error) {
domain, err := parseStringTerm(domainTerm)
if err != nil {
return nil, err
}
for _, trie := range tries {
if trie.Contains(domain) {
return ast.NewTerm(ast.Boolean(true)), nil
}
}
return ast.NewTerm(ast.Boolean(false)), nil
}
}
var networkFunctionDecl = types.NewFunction(
types.Args(types.Named("ip", types.S).Description("IP, IP:port, host or host:port to lookup")),
types.Named("result", types.B).Description("`true` if IP matches"),
)
func networkFunctionImpl(tries []*dataset.NetworkTrie) rego.Builtin1 {
return func(ctx rego.BuiltinContext, ipTerm *ast.Term) (*ast.Term, error) {
ips, err := parseAddrTerm(ipTerm)
if err != nil {
return nil, err
}
for _, trie := range tries {
for _, ip := range ips {
if trie.Contains(ip) {
return ast.NewTerm(ast.Boolean(true)), nil
}
}
}
return ast.NewTerm(ast.Boolean(false)), nil
}
}
func parseAddrTerm(term *ast.Term) (addrs []netip.Addr, err error) {
s, err := parseStringTerm(term)
if err != nil {
return nil, err
}
if addr, err := netip.ParseAddr(netutil.Host(s)); err == nil {
// Input was "ip" or "ip:port"
return []netip.Addr{addr}, nil
}
ips, err := net.LookupIP(netutil.Host(s))
if err != nil {
return nil, err
}
for _, ip := range ips {
if addr, ok := netip.AddrFromSlice(ip); ok {
addrs = append(addrs, addr)
}
}
return
}
func parseStringTerm(term *ast.Term) (string, error) {
value, ok := term.Value.(ast.String)
if !ok {
return "", errors.New("expected string argument")
}
return strings.Trim(value.String(), `"`), nil
}

View File

@@ -302,14 +302,15 @@ func (p *Proxy) Serve(l net.Listener) error {
func (p *Proxy) handle(nc net.Conn) {
var (
start = time.Now()
ctx = NewContext(nc).(*proxyContext)
ctx = NewContext(nc, p.Storage).(*proxyContext)
log = ctx.Logger()
err error
)
defer func() {
if r := recover(); r != nil {
if err, ok := r.(error); ok {
ctx.LogEntry().Err(err).Warn("Bug in code, recovered from panic!")
log.Err(err).Warn("Bug in code, recovered from panic!")
}
_ = nc.Close()
}
@@ -360,16 +361,6 @@ func (p *Proxy) handle(nc net.Conn) {
}
}
log := ctx.LogEntry()
if p.Storage != nil {
if client, err := p.Storage.ClientByIP(nc.RemoteAddr().(*net.TCPAddr).IP); err == nil {
log = log.Values(logger.Values{
"client_id": client.ID,
"client_network": client.String(),
"client_description": client.Description,
})
}
}
for {
if ctx.transparentTLS {
ctx.req = &http.Request{
@@ -448,7 +439,9 @@ func (p *Proxy) handleError(ctx *proxyContext, err error, sendResponse bool) {
if res == nil && sendResponse {
res = NewErrorResponse(err, ctx.Request())
}
ctx.LogEntry().Value("count", len(p.OnError)).Trace("Running error handlers")
log := ctx.Logger()
log.Value("count", len(p.OnError)).Trace("Running error handlers")
for _, f := range p.OnError {
if newRes := f.HandleError(ctx, err); newRes != nil {
res = newRes
@@ -464,7 +457,7 @@ func (p *Proxy) handleError(ctx *proxyContext, err error, sendResponse bool) {
func (p *Proxy) handleRequest(ctx *proxyContext) (err error) {
switch {
case ctx.req == nil:
ctx.LogEntry().Warn("Request is nil in handleRequest!?")
ctx.Logger().Warn("Request is nil in handleRequest!?")
return errors.New("proxy: request is nil?")
case headerContains(ctx.req.Header, HeaderConnection, "upgrade"):
@@ -527,7 +520,7 @@ func (p *Proxy) serve(ctx *proxyContext) (err error) {
}
func (p *Proxy) serveConnect(ctx *proxyContext) (err error) {
log := ctx.LogEntry()
log := ctx.Logger()
// Most browsers expect to get a 200 OK after firing a HTTP CONNECT request; if the upstream
// encounters any errors, we'll inform the client after reading the HTTP request that follows.
@@ -571,13 +564,13 @@ func (p *Proxy) serveConnect(ctx *proxyContext) (err error) {
}
ctx.res = NewResponse(http.StatusOK, nil, ctx.req)
srv := NewContext(c).(*proxyContext)
srv := NewContext(c, p.Storage).(*proxyContext)
srv.SetIdleTimeout(p.IdleTimeout)
return p.multiplex(ctx, srv)
}
func (p *Proxy) serveForward(ctx *proxyContext) (err error) {
log := ctx.LogEntry()
log := ctx.Logger()
log.Value("target", ctx.req.URL.String()).Debugf("%s forward request", ctx.req.Proto)
var res *http.Response
@@ -609,7 +602,8 @@ func (p *Proxy) serveForward(ctx *proxyContext) (err error) {
}
func (p *Proxy) serveWebSocket(ctx *proxyContext) (err error) {
log := ctx.LogEntry().Value("target", ctx.req.URL.String())
log := ctx.Logger()
log.Value("target", ctx.req.URL.String())
switch ctx.req.URL.Scheme {
case "http":
@@ -632,7 +626,7 @@ func (p *Proxy) serveWebSocket(ctx *proxyContext) (err error) {
}
cancel()
srv := NewContext(c).(*proxyContext)
srv := NewContext(c, p.Storage).(*proxyContext)
srv.SetIdleTimeout(p.IdleTimeout)
if err = ctx.req.Write(srv); err != nil {
ctx.res = NewErrorResponse(err, ctx.req)
@@ -662,7 +656,7 @@ func (p *Proxy) serveWebSocket(ctx *proxyContext) (err error) {
func (p *Proxy) multiplex(ctx, srv *proxyContext) (err error) {
var (
log = ctx.LogEntry().Value("server", srv.RemoteAddr().String())
log = ctx.Logger().Value("server", srv.RemoteAddr().String())
errs = make(chan error, 1)
done = make(chan struct{}, 1)
)

View File

@@ -1,19 +1,14 @@
package proxy
import (
"expvar"
"strconv"
"git.maze.io/maze/styx/db/stats"
)
func countStatus(code int) {
k := "http:status:" + strconv.Itoa(code)
v := expvar.Get(k)
if v == nil {
//v = stats.NewCounter("120s1s", "15m10s", "1h1m", "4w1d", "1y4w")
v = stats.NewCounter(k, stats.Minutely, stats.Hourly, stats.Daily, stats.Yearly)
expvar.Publish(k, v)
}
v.(stats.Metric).Add(1)
/*
k := "http:status:" + strconv.Itoa(code)
v := expvar.Get(k)
if v == nil {
//v = stats.NewCounter("120s1s", "15m10s", "1h1m", "4w1d", "1y4w")
v = stats.NewCounter(k, stats.Minutely, stats.Hourly, stats.Daily, stats.Yearly)
expvar.Publish(k, v)
}
v.(stats.Metric).Add(1)
*/
}