Better trie implementations
This commit is contained in:
@@ -45,8 +45,11 @@ type Context interface {
|
||||
// Response is the response that will be sent back to the client.
|
||||
Response() *http.Response
|
||||
|
||||
// Logger for this context.
|
||||
Logger() logger.Structured
|
||||
|
||||
// Client group.
|
||||
Client() (dataset.Client, error)
|
||||
Storage() dataset.Storage
|
||||
}
|
||||
|
||||
type WithCertificateAuthority interface {
|
||||
@@ -91,11 +94,10 @@ type proxyContext struct {
|
||||
idleTimeout time.Duration
|
||||
ca ca.CertificateAuthority
|
||||
storage dataset.Storage
|
||||
client dataset.Client
|
||||
}
|
||||
|
||||
// NewContext returns an initialized context for the provided [net.Conn].
|
||||
func NewContext(c net.Conn) Context {
|
||||
func NewContext(c net.Conn, storage dataset.Storage) Context {
|
||||
if c, ok := c.(*proxyContext); ok {
|
||||
return c
|
||||
}
|
||||
@@ -106,12 +108,13 @@ func NewContext(c net.Conn) Context {
|
||||
cr := &countingReader{reader: c}
|
||||
cw := &countingWriter{writer: c}
|
||||
return &proxyContext{
|
||||
Conn: c,
|
||||
id: binary.BigEndian.Uint64(b),
|
||||
cr: cr,
|
||||
br: bufio.NewReader(cr),
|
||||
cw: cw,
|
||||
res: &http.Response{StatusCode: 200},
|
||||
Conn: c,
|
||||
id: binary.BigEndian.Uint64(b),
|
||||
cr: cr,
|
||||
br: bufio.NewReader(cr),
|
||||
cw: cw,
|
||||
res: &http.Response{StatusCode: 200},
|
||||
storage: storage,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -128,7 +131,7 @@ func (c *proxyContext) AccessLogEntry() logger.Structured {
|
||||
return entry
|
||||
}
|
||||
|
||||
func (c *proxyContext) LogEntry() logger.Structured {
|
||||
func (c *proxyContext) Logger() logger.Structured {
|
||||
var id [8]byte
|
||||
binary.BigEndian.PutUint64(id[:], c.id)
|
||||
return ServerLog.Values(logger.Values{
|
||||
@@ -234,24 +237,8 @@ func (c *proxyContext) CertificateAuthority() ca.CertificateAuthority {
|
||||
return c.ca
|
||||
}
|
||||
|
||||
func (c *proxyContext) Client() (dataset.Client, error) {
|
||||
if c.storage == nil {
|
||||
return dataset.Client{}, dataset.ErrNotExist{Object: "client"}
|
||||
}
|
||||
if !c.client.CreatedAt.Equal(time.Time{}) {
|
||||
return c.client, nil
|
||||
}
|
||||
|
||||
var err error
|
||||
switch addr := c.Conn.RemoteAddr().(type) {
|
||||
case *net.TCPAddr:
|
||||
c.client, err = c.storage.ClientByIP(addr.IP)
|
||||
case *net.UDPAddr:
|
||||
c.client, err = c.storage.ClientByIP(addr.IP)
|
||||
default:
|
||||
err = dataset.ErrNotExist{Object: "client"}
|
||||
}
|
||||
return c.client, err
|
||||
func (c *proxyContext) Storage() dataset.Storage {
|
||||
return c.storage
|
||||
}
|
||||
|
||||
var _ Context = (*proxyContext)(nil)
|
||||
|
@@ -144,18 +144,21 @@ func Transparent(port int) ConnHandler {
|
||||
return nctx, nil
|
||||
}
|
||||
|
||||
b := new(bytes.Buffer)
|
||||
hello, err := cryptutil.ReadClientHello(io.TeeReader(netutil.ReadOnlyConn{Reader: ctx.br}, b))
|
||||
var (
|
||||
b = new(bytes.Buffer)
|
||||
hello, err = cryptutil.ReadClientHello(io.TeeReader(netutil.ReadOnlyConn{Reader: ctx.br}, b))
|
||||
log = ctx.Logger()
|
||||
)
|
||||
if err != nil {
|
||||
if _, ok := err.(tls.RecordHeaderError); !ok {
|
||||
ctx.LogEntry().Err(err).Value("error_type", fmt.Sprintf("%T", err)).Warn("TLS sniff error")
|
||||
log.Err(err).Value("error_type", fmt.Sprintf("%T", err)).Warn("TLS sniff error")
|
||||
return nil, err
|
||||
}
|
||||
// Not a TLS connection, moving on to regular HTTP request handling...
|
||||
ctx.LogEntry().Debug("HTTP connection on transparent port")
|
||||
log.Debug("HTTP connection on transparent port")
|
||||
ctx.transparent = port
|
||||
} else {
|
||||
ctx.LogEntry().Value("target", hello.ServerName).Debug("TLS connection on transparent port")
|
||||
log.Value("target", hello.ServerName).Debug("TLS connection on transparent port")
|
||||
ctx.transparent = port
|
||||
ctx.transparentTLS = true
|
||||
ctx.serverName = hello.ServerName
|
||||
|
178
proxy/policy.go
Normal file
178
proxy/policy.go
Normal file
@@ -0,0 +1,178 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
|
||||
"git.maze.io/maze/styx/dataset"
|
||||
"git.maze.io/maze/styx/internal/netutil"
|
||||
"github.com/open-policy-agent/opa/v1/ast"
|
||||
"github.com/open-policy-agent/opa/v1/rego"
|
||||
"github.com/open-policy-agent/opa/v1/types"
|
||||
)
|
||||
|
||||
// PolicyQueryOptions generates the Rego query functions for the provided [Context].
|
||||
func PolicyQueryOptions(ctx Context) (options []func(*rego.Rego)) {
|
||||
var (
|
||||
log = ctx.Logger()
|
||||
storage = ctx.Storage()
|
||||
)
|
||||
|
||||
addr, err := netip.ParseAddr(netutil.Host(ctx.RemoteAddr().String()))
|
||||
if err != nil {
|
||||
log.Err(err).Error("Error resolving remote address")
|
||||
return
|
||||
}
|
||||
|
||||
client, err := storage.ClientByAddr(addr)
|
||||
if err != nil {
|
||||
log.Err(err).Warn("Error resolving client")
|
||||
return
|
||||
}
|
||||
|
||||
var (
|
||||
permitDomains []*dataset.DomainTrie
|
||||
rejectDomains []*dataset.DomainTrie
|
||||
permitNetworks []*dataset.NetworkTrie
|
||||
rejectNetworks []*dataset.NetworkTrie
|
||||
)
|
||||
for _, group := range client.Groups {
|
||||
lists, err := storage.ListsByGroup(group)
|
||||
if err != nil {
|
||||
log.Err(err).Warn("Error resolving lists")
|
||||
return
|
||||
}
|
||||
for _, list := range lists {
|
||||
switch list.Type {
|
||||
case dataset.ListTypeDomain:
|
||||
trie, err := list.Domains()
|
||||
if err != nil {
|
||||
log.Err(err).Warn("Error resolving domain trie")
|
||||
}
|
||||
if list.Permit {
|
||||
permitDomains = append(permitDomains, trie)
|
||||
} else {
|
||||
rejectDomains = append(rejectDomains, trie)
|
||||
}
|
||||
case dataset.ListTypeNetwork:
|
||||
trie, err := list.Networks()
|
||||
if err != nil {
|
||||
log.Err(err).Warn("Error resolving domain trie")
|
||||
}
|
||||
if list.Permit {
|
||||
permitNetworks = append(permitNetworks, trie)
|
||||
} else {
|
||||
rejectNetworks = append(rejectNetworks, trie)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
options = append(options,
|
||||
rego.Function1(®o.Function{
|
||||
Name: "styx.reject_domain",
|
||||
Description: "Check if the domain is to be rejected",
|
||||
Decl: domainFunctionDecl,
|
||||
Nondeterministic: true,
|
||||
Memoize: true,
|
||||
}, domainFunctionImpl(rejectDomains)),
|
||||
rego.Function1(®o.Function{
|
||||
Name: "styx.permit_domain",
|
||||
Description: "Check if the domain is to be permitted",
|
||||
Decl: domainFunctionDecl,
|
||||
Nondeterministic: true,
|
||||
Memoize: true,
|
||||
}, domainFunctionImpl(permitDomains)),
|
||||
rego.Function1(®o.Function{
|
||||
Name: "styx.reject_network",
|
||||
Description: "Check if the IP, IP:port, host or host:port is to be rejected",
|
||||
Decl: networkFunctionDecl,
|
||||
Nondeterministic: true,
|
||||
Memoize: true,
|
||||
}, networkFunctionImpl(rejectNetworks)),
|
||||
rego.Function1(®o.Function{
|
||||
Name: "styx.permit_network",
|
||||
Description: "Check if the IP, IP:port, host or host:port is to be permitted",
|
||||
Decl: networkFunctionDecl,
|
||||
Nondeterministic: true,
|
||||
Memoize: true,
|
||||
}, networkFunctionImpl(permitNetworks)),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
var domainFunctionDecl = types.NewFunction(
|
||||
types.Args(types.Named("domain", types.S).Description("Domain to lookup")),
|
||||
types.Named("result", types.B).Description("`true` if domain matches"),
|
||||
)
|
||||
|
||||
func domainFunctionImpl(tries []*dataset.DomainTrie) rego.Builtin1 {
|
||||
return func(ctx rego.BuiltinContext, domainTerm *ast.Term) (*ast.Term, error) {
|
||||
domain, err := parseStringTerm(domainTerm)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, trie := range tries {
|
||||
if trie.Contains(domain) {
|
||||
return ast.NewTerm(ast.Boolean(true)), nil
|
||||
}
|
||||
}
|
||||
return ast.NewTerm(ast.Boolean(false)), nil
|
||||
}
|
||||
}
|
||||
|
||||
var networkFunctionDecl = types.NewFunction(
|
||||
types.Args(types.Named("ip", types.S).Description("IP, IP:port, host or host:port to lookup")),
|
||||
types.Named("result", types.B).Description("`true` if IP matches"),
|
||||
)
|
||||
|
||||
func networkFunctionImpl(tries []*dataset.NetworkTrie) rego.Builtin1 {
|
||||
return func(ctx rego.BuiltinContext, ipTerm *ast.Term) (*ast.Term, error) {
|
||||
ips, err := parseAddrTerm(ipTerm)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, trie := range tries {
|
||||
for _, ip := range ips {
|
||||
if trie.Contains(ip) {
|
||||
return ast.NewTerm(ast.Boolean(true)), nil
|
||||
}
|
||||
}
|
||||
}
|
||||
return ast.NewTerm(ast.Boolean(false)), nil
|
||||
}
|
||||
}
|
||||
|
||||
func parseAddrTerm(term *ast.Term) (addrs []netip.Addr, err error) {
|
||||
s, err := parseStringTerm(term)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if addr, err := netip.ParseAddr(netutil.Host(s)); err == nil {
|
||||
// Input was "ip" or "ip:port"
|
||||
return []netip.Addr{addr}, nil
|
||||
}
|
||||
|
||||
ips, err := net.LookupIP(netutil.Host(s))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, ip := range ips {
|
||||
if addr, ok := netip.AddrFromSlice(ip); ok {
|
||||
addrs = append(addrs, addr)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func parseStringTerm(term *ast.Term) (string, error) {
|
||||
value, ok := term.Value.(ast.String)
|
||||
if !ok {
|
||||
return "", errors.New("expected string argument")
|
||||
}
|
||||
return strings.Trim(value.String(), `"`), nil
|
||||
}
|
@@ -302,14 +302,15 @@ func (p *Proxy) Serve(l net.Listener) error {
|
||||
func (p *Proxy) handle(nc net.Conn) {
|
||||
var (
|
||||
start = time.Now()
|
||||
ctx = NewContext(nc).(*proxyContext)
|
||||
ctx = NewContext(nc, p.Storage).(*proxyContext)
|
||||
log = ctx.Logger()
|
||||
err error
|
||||
)
|
||||
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
if err, ok := r.(error); ok {
|
||||
ctx.LogEntry().Err(err).Warn("Bug in code, recovered from panic!")
|
||||
log.Err(err).Warn("Bug in code, recovered from panic!")
|
||||
}
|
||||
_ = nc.Close()
|
||||
}
|
||||
@@ -360,16 +361,6 @@ func (p *Proxy) handle(nc net.Conn) {
|
||||
}
|
||||
}
|
||||
|
||||
log := ctx.LogEntry()
|
||||
if p.Storage != nil {
|
||||
if client, err := p.Storage.ClientByIP(nc.RemoteAddr().(*net.TCPAddr).IP); err == nil {
|
||||
log = log.Values(logger.Values{
|
||||
"client_id": client.ID,
|
||||
"client_network": client.String(),
|
||||
"client_description": client.Description,
|
||||
})
|
||||
}
|
||||
}
|
||||
for {
|
||||
if ctx.transparentTLS {
|
||||
ctx.req = &http.Request{
|
||||
@@ -448,7 +439,9 @@ func (p *Proxy) handleError(ctx *proxyContext, err error, sendResponse bool) {
|
||||
if res == nil && sendResponse {
|
||||
res = NewErrorResponse(err, ctx.Request())
|
||||
}
|
||||
ctx.LogEntry().Value("count", len(p.OnError)).Trace("Running error handlers")
|
||||
|
||||
log := ctx.Logger()
|
||||
log.Value("count", len(p.OnError)).Trace("Running error handlers")
|
||||
for _, f := range p.OnError {
|
||||
if newRes := f.HandleError(ctx, err); newRes != nil {
|
||||
res = newRes
|
||||
@@ -464,7 +457,7 @@ func (p *Proxy) handleError(ctx *proxyContext, err error, sendResponse bool) {
|
||||
func (p *Proxy) handleRequest(ctx *proxyContext) (err error) {
|
||||
switch {
|
||||
case ctx.req == nil:
|
||||
ctx.LogEntry().Warn("Request is nil in handleRequest!?")
|
||||
ctx.Logger().Warn("Request is nil in handleRequest!?")
|
||||
return errors.New("proxy: request is nil?")
|
||||
|
||||
case headerContains(ctx.req.Header, HeaderConnection, "upgrade"):
|
||||
@@ -527,7 +520,7 @@ func (p *Proxy) serve(ctx *proxyContext) (err error) {
|
||||
}
|
||||
|
||||
func (p *Proxy) serveConnect(ctx *proxyContext) (err error) {
|
||||
log := ctx.LogEntry()
|
||||
log := ctx.Logger()
|
||||
|
||||
// Most browsers expect to get a 200 OK after firing a HTTP CONNECT request; if the upstream
|
||||
// encounters any errors, we'll inform the client after reading the HTTP request that follows.
|
||||
@@ -571,13 +564,13 @@ func (p *Proxy) serveConnect(ctx *proxyContext) (err error) {
|
||||
}
|
||||
|
||||
ctx.res = NewResponse(http.StatusOK, nil, ctx.req)
|
||||
srv := NewContext(c).(*proxyContext)
|
||||
srv := NewContext(c, p.Storage).(*proxyContext)
|
||||
srv.SetIdleTimeout(p.IdleTimeout)
|
||||
return p.multiplex(ctx, srv)
|
||||
}
|
||||
|
||||
func (p *Proxy) serveForward(ctx *proxyContext) (err error) {
|
||||
log := ctx.LogEntry()
|
||||
log := ctx.Logger()
|
||||
log.Value("target", ctx.req.URL.String()).Debugf("%s forward request", ctx.req.Proto)
|
||||
|
||||
var res *http.Response
|
||||
@@ -609,7 +602,8 @@ func (p *Proxy) serveForward(ctx *proxyContext) (err error) {
|
||||
}
|
||||
|
||||
func (p *Proxy) serveWebSocket(ctx *proxyContext) (err error) {
|
||||
log := ctx.LogEntry().Value("target", ctx.req.URL.String())
|
||||
log := ctx.Logger()
|
||||
log.Value("target", ctx.req.URL.String())
|
||||
|
||||
switch ctx.req.URL.Scheme {
|
||||
case "http":
|
||||
@@ -632,7 +626,7 @@ func (p *Proxy) serveWebSocket(ctx *proxyContext) (err error) {
|
||||
}
|
||||
cancel()
|
||||
|
||||
srv := NewContext(c).(*proxyContext)
|
||||
srv := NewContext(c, p.Storage).(*proxyContext)
|
||||
srv.SetIdleTimeout(p.IdleTimeout)
|
||||
if err = ctx.req.Write(srv); err != nil {
|
||||
ctx.res = NewErrorResponse(err, ctx.req)
|
||||
@@ -662,7 +656,7 @@ func (p *Proxy) serveWebSocket(ctx *proxyContext) (err error) {
|
||||
|
||||
func (p *Proxy) multiplex(ctx, srv *proxyContext) (err error) {
|
||||
var (
|
||||
log = ctx.LogEntry().Value("server", srv.RemoteAddr().String())
|
||||
log = ctx.Logger().Value("server", srv.RemoteAddr().String())
|
||||
errs = make(chan error, 1)
|
||||
done = make(chan struct{}, 1)
|
||||
)
|
||||
|
@@ -1,19 +1,14 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"expvar"
|
||||
"strconv"
|
||||
|
||||
"git.maze.io/maze/styx/db/stats"
|
||||
)
|
||||
|
||||
func countStatus(code int) {
|
||||
k := "http:status:" + strconv.Itoa(code)
|
||||
v := expvar.Get(k)
|
||||
if v == nil {
|
||||
//v = stats.NewCounter("120s1s", "15m10s", "1h1m", "4w1d", "1y4w")
|
||||
v = stats.NewCounter(k, stats.Minutely, stats.Hourly, stats.Daily, stats.Yearly)
|
||||
expvar.Publish(k, v)
|
||||
}
|
||||
v.(stats.Metric).Add(1)
|
||||
/*
|
||||
k := "http:status:" + strconv.Itoa(code)
|
||||
v := expvar.Get(k)
|
||||
if v == nil {
|
||||
//v = stats.NewCounter("120s1s", "15m10s", "1h1m", "4w1d", "1y4w")
|
||||
v = stats.NewCounter(k, stats.Minutely, stats.Hourly, stats.Daily, stats.Yearly)
|
||||
expvar.Publish(k, v)
|
||||
}
|
||||
v.(stats.Metric).Add(1)
|
||||
*/
|
||||
}
|
||||
|
Reference in New Issue
Block a user