From 4a60059ff2e79195c800ceb57181f9500cf21058 Mon Sep 17 00:00:00 2001 From: maze Date: Tue, 30 Sep 2025 08:08:22 +0200 Subject: [PATCH] Checkpoint --- proxy/admin.go | 145 ------ proxy/cache/config.go | 8 - proxy/config.go | 88 ---- proxy/context.go | 227 +++++++++ proxy/doc.go | 2 + proxy/handler.go | 309 +++++++++++ proxy/match/config.go | 324 ------------ proxy/match/match.go | 45 -- proxy/match/util.go | 11 - proxy/mitm/authority.go | 231 --------- proxy/mitm/cache.go | 233 --------- proxy/mitm/cache_test.go | 25 - proxy/mitm/config.go | 89 ---- proxy/policy/policy.go | 53 -- proxy/policy/policy_test.go | 139 ----- proxy/policy/rule.go | 368 -------------- proxy/policy/time.go | 53 -- proxy/proxy.go | 985 ++++++++++++++++-------------------- proxy/resolver/resolver.go | 148 ------ proxy/response.go | 99 ++-- proxy/session.go | 151 ------ proxy/stats.go | 19 + proxy/stats/stats.go | 225 -------- proxy/util.go | 16 - 24 files changed, 1034 insertions(+), 2959 deletions(-) delete mode 100644 proxy/admin.go delete mode 100644 proxy/cache/config.go delete mode 100644 proxy/config.go create mode 100644 proxy/context.go create mode 100644 proxy/doc.go create mode 100644 proxy/handler.go delete mode 100644 proxy/match/config.go delete mode 100644 proxy/match/match.go delete mode 100644 proxy/match/util.go delete mode 100644 proxy/mitm/authority.go delete mode 100644 proxy/mitm/cache.go delete mode 100644 proxy/mitm/cache_test.go delete mode 100644 proxy/mitm/config.go delete mode 100644 proxy/policy/policy.go delete mode 100644 proxy/policy/policy_test.go delete mode 100644 proxy/policy/rule.go delete mode 100644 proxy/policy/time.go delete mode 100644 proxy/resolver/resolver.go delete mode 100644 proxy/session.go create mode 100644 proxy/stats.go delete mode 100644 proxy/stats/stats.go delete mode 100644 proxy/util.go diff --git a/proxy/admin.go b/proxy/admin.go deleted file mode 100644 index b04f652..0000000 --- a/proxy/admin.go +++ /dev/null @@ -1,145 +0,0 @@ -package proxy - -import ( - "bytes" - "encoding/json" - "encoding/pem" - "errors" - "net/http" - "os" - "strconv" - "strings" - "time" - - "git.maze.io/maze/styx/internal/log" -) - -type Admin struct { - *Proxy -} - -func NewAdmin(proxy *Proxy) *Admin { - a := &Admin{ - Proxy: proxy, - } - return a -} - -func (a *Admin) handleRequest(ses *Session) error { - var ( - logger = ses.log() - err error - ) - switch ses.request.URL.Path { - case "/ca.crt": - err = a.handleCACert(ses) - case "/api/v1/policy": - err = a.apiPolicy(ses) - case "/api/v1/policy/matcher": - err = a.apiPolicyMatcher(ses) - case "/api/v1/stats/log": - err = a.apiStatsLog(ses) - case "/api/v1/stats/status": - err = a.apiStatsStatus(ses) - default: - if strings.HasPrefix(ses.request.URL.Path, "/api") { - err = errors.New("invalid endpoint") - } else { - err = os.ErrNotExist - } - } - if err != nil { - logger.Warn().Err(err).Msg("admin error") - ses.response = ErrorResponse(ses.request, err) - defer log.OnCloseError(logger.Debug(), ses.response.Body) - ses.response.Close = true - return a.writeResponse(ses) - } - return err -} - -func (a *Admin) handleCACert(ses *Session) error { - b := pem.EncodeToMemory(&pem.Block{ - Type: "CERTIFICATE", - Bytes: a.authority.Certificate().Raw, - }) - - ses.response = NewResponse(http.StatusOK, bytes.NewReader(b), ses.request) - defer log.OnCloseError(log.Debug(), ses.response.Body) - - ses.response.Close = true - ses.response.Header.Set("Content-Type", "application/x-x509-ca-cert") - ses.response.ContentLength = int64(len(b)) - return a.writeResponse(ses) -} - -func (a *Admin) apiPolicy(ses *Session) error { - var ( - b = new(bytes.Buffer) - e = json.NewEncoder(b) - ) - e.SetIndent("", " ") - if err := e.Encode(a.config.Policy); err != nil { - return err - } - - ses.response = NewJSONResponse(http.StatusOK, b, ses.request) - defer log.OnCloseError(log.Debug(), ses.response.Body) - ses.response.Close = true - return a.writeResponse(ses) -} - -func (a *Admin) apiPolicyMatcher(ses *Session) error { - var ( - b = new(bytes.Buffer) - e = json.NewEncoder(b) - ) - e.SetIndent("", " ") - if err := e.Encode(a.config.Policy.Matchers); err != nil { - return err - } - - ses.response = NewJSONResponse(http.StatusOK, b, ses.request) - defer log.OnCloseError(log.Debug(), ses.response.Body) - ses.response.Close = true - return a.writeResponse(ses) -} - -func (a *Admin) apiResponse(ses *Session, v any, err error) error { - if err != nil { - return err - } - var ( - b = new(bytes.Buffer) - e = json.NewEncoder(b) - ) - e.SetIndent("", " ") - if err := e.Encode(v); err != nil { - return err - } - - ses.response = NewJSONResponse(http.StatusOK, b, ses.request) - defer log.OnCloseError(log.Debug(), ses.response.Body) - ses.response.Close = true - return a.writeResponse(ses) - -} - -func (a *Admin) apiStatsLog(ses *Session) error { - var ( - query = ses.request.URL.Query() - offset, _ = strconv.Atoi(query.Get("offset")) - limit, _ = strconv.Atoi(query.Get("limit")) - ) - if limit > 100 { - limit = 100 - } - - s, err := a.stats.QueryLog(offset, limit) - return a.apiResponse(ses, s, err) -} - -func (a *Admin) apiStatsStatus(ses *Session) error { - s, err := a.stats.QueryStatus(time.Time{}) - return a.apiResponse(ses, s, err) -} diff --git a/proxy/cache/config.go b/proxy/cache/config.go deleted file mode 100644 index cb46c72..0000000 --- a/proxy/cache/config.go +++ /dev/null @@ -1,8 +0,0 @@ -package cache - -import "github.com/hashicorp/hcl/v2" - -type Config struct { - Type string `hcl:"type"` - Body hcl.Body `hcl:",remain"` -} diff --git a/proxy/config.go b/proxy/config.go deleted file mode 100644 index 4e5c62c..0000000 --- a/proxy/config.go +++ /dev/null @@ -1,88 +0,0 @@ -package proxy - -import ( - "net" - "net/http" - "time" - - "git.maze.io/maze/styx/proxy/policy" - "git.maze.io/maze/styx/proxy/resolver" -) - -type ConnectHandler interface { - HandleConnect(session *Session, network, address string) net.Conn -} - -// ConnectHandlerFunc is called when the proxy receives a new HTTP CONNECT request. -type ConnectHandlerFunc func(session *Session, network, address string) net.Conn - -func (f ConnectHandlerFunc) HandleConnect(session *Session, network, address string) net.Conn { - return f(session, network, address) -} - -type RequestHandler interface { - HandleRequest(session *Session) (*http.Request, *http.Response) -} - -// RequestHandlerFunc is called when the proxy receives a new request. -type RequestHandlerFunc func(session *Session) (*http.Request, *http.Response) - -func (f RequestHandlerFunc) HandleRequest(session *Session) (*http.Request, *http.Response) { - return f(session) -} - -type ResponseHandler interface { - HandleResponse(session *Session) *http.Response -} - -// ResponseHandler is called when the proxy receives a response. -type ResponseHandlerFunc func(session *Session) *http.Response - -func (f ResponseHandlerFunc) HandleResponse(session *Session) *http.Response { - return f(session) -} - -type ErrorHandler interface { - HandleError(session *Session, err error) -} - -type ErrorHandlerFunc func(session *Session, err error) - -func (f ErrorHandlerFunc) HandleError(session *Session, err error) { - f(session, err) -} - -type Config struct { - // Listen address. - Listen string `hcl:"listen,optional"` - - // Bind address for outgoing connections. - Bind string `hcl:"bind,optional"` - - // Interface for outgoing connections. - Interface string `hcl:"interface,optional"` - - // Upstream proxy servers. - Upstream []string `hcl:"upstream,optional"` - - // DialTimeout for establishing new connections. - DialTimeout time.Duration `hcl:"dial_timeout,optional"` - - // Policy for the proxy. - Policy *policy.Policy `hcl:"policy,block"` - - // Resolver for the proxy. - Resolver resolver.Resolver - - ConnectHandler ConnectHandler - RequestHandler RequestHandler - ResponseHandler ResponseHandler - ErrorHandler ErrorHandler -} - -var ( - _ ConnectHandler = (ConnectHandlerFunc)(nil) - _ RequestHandler = (RequestHandlerFunc)(nil) - _ ResponseHandler = (ResponseHandlerFunc)(nil) - _ ErrorHandler = (ErrorHandlerFunc)(nil) -) diff --git a/proxy/context.go b/proxy/context.go new file mode 100644 index 0000000..943bb2b --- /dev/null +++ b/proxy/context.go @@ -0,0 +1,227 @@ +package proxy + +import ( + "bufio" + "crypto/rand" + "crypto/tls" + "encoding/binary" + "encoding/hex" + "fmt" + "io" + "net" + "net/http" + "strconv" + "sync/atomic" + "time" + + "git.maze.io/maze/styx/internal/netutil/arp" + "github.com/sirupsen/logrus" +) + +// Context provides convenience functions for the current ongoing HTTP proxy transaction (request). +type Context interface { + // Conn is the backing connection for this context. + net.Conn + + // ID is a unique connection identifier. + ID() uint64 + + // Reader returns a buffered reader on top of the [net.Conn]. + Reader() *bufio.Reader + + // BytesRead returns the number of bytes read. + BytesRead() int64 + + // BytesSent returns the number of bytes written. + BytesSent() int64 + + // TLSState returns the TLS connection state, it returns nil if the connection is not a TLS connection. + TLSState() *tls.ConnectionState + + // Request is the request made to the proxy. + Request() *http.Request + + // Response is the response that will be sent back to the client. + Response() *http.Response +} + +type countingReader struct { + reader io.Reader + bytes int64 +} + +func (r *countingReader) Read(p []byte) (n int, err error) { + if n, err = r.reader.Read(p); n > 0 { + atomic.AddInt64(&r.bytes, int64(n)) + } + return +} + +type countingWriter struct { + writer io.Writer + bytes int64 +} + +func (w *countingWriter) Write(p []byte) (n int, err error) { + if n, err = w.writer.Write(p); n > 0 { + atomic.AddInt64(&w.bytes, int64(n)) + } + return +} + +type proxyContext struct { + net.Conn + id uint64 + mac net.HardwareAddr + cr *countingReader + br *bufio.Reader + cw *countingWriter + isTransparent bool + isTransparentTLS bool + serverName string + req *http.Request + res *http.Response + idleTimeout time.Duration +} + +// NewContext returns an initialized context for the provided [net.Conn]. +func NewContext(c net.Conn) Context { + if c, ok := c.(*proxyContext); ok { + return c + } + + b := make([]byte, 8) + io.ReadFull(rand.Reader, b) + + cr := &countingReader{reader: c} + cw := &countingWriter{writer: c} + return &proxyContext{ + Conn: c, + id: binary.BigEndian.Uint64(b), + mac: arp.Get(c.RemoteAddr()), + cr: cr, + br: bufio.NewReader(cr), + cw: cw, + res: &http.Response{StatusCode: 200}, + } +} + +func (c *proxyContext) AccessLogEntry() *logrus.Entry { + var id [8]byte + binary.BigEndian.PutUint64(id[:], c.id) + entry := AccessLog.WithFields(logrus.Fields{ + "client": c.RemoteAddr().String(), + "server": c.LocalAddr().String(), + "id": hex.EncodeToString(id[:]), + "bytes_rx": c.BytesRead(), + "bytes_tx": c.BytesSent(), + }) + if c.mac != nil { + return entry.WithField("client_mac", c.mac.String()) + } + return entry +} + +func (c *proxyContext) LogEntry() *logrus.Entry { + var id [8]byte + binary.BigEndian.PutUint64(id[:], c.id) + return ServerLog.WithFields(logrus.Fields{ + "client": c.RemoteAddr().String(), + "server": c.LocalAddr().String(), + "id": hex.EncodeToString(id[:]), + }) +} + +func (c *proxyContext) String() string { + return fmt.Sprintf("client=%s server=%s id=%#08x", + c.RemoteAddr().String(), + c.LocalAddr().String(), + c.id) +} + +func (c *proxyContext) ID() uint64 { + return c.id +} + +func (c *proxyContext) BytesRead() int64 { + return atomic.LoadInt64(&c.cr.bytes) +} + +func (c *proxyContext) BytesSent() int64 { + return atomic.LoadInt64(&c.cw.bytes) +} + +func (c *proxyContext) Read(p []byte) (n int, err error) { + if c.idleTimeout > 0 { + if err = c.SetReadDeadline(time.Now().Add(c.idleTimeout)); err != nil { + return + } + } + return c.br.Read(p) +} + +func (c *proxyContext) Write(p []byte) (n int, err error) { + if c.idleTimeout > 0 { + if err = c.SetWriteDeadline(time.Now().Add(c.idleTimeout)); err != nil { + return + } + } + return c.cw.Write(p) +} + +func (c *proxyContext) Reader() *bufio.Reader { + return c.br +} + +func (c *proxyContext) Request() *http.Request { + return c.req +} + +func (c *proxyContext) SetRequest(req *http.Request) { + c.req = req +} + +func (c *proxyContext) Response() *http.Response { + return c.res +} + +func (c *proxyContext) SetIdleTimeout(t time.Duration) { + c.idleTimeout = t +} + +type connectionStater interface { + ConnectionState() tls.ConnectionState +} + +func (c *proxyContext) TLSState() *tls.ConnectionState { + if s, ok := c.Conn.(connectionStater); ok { + state := s.ConnectionState() + return &state + } + return nil +} + +// http.ResponseWriter interface: + +func (c *proxyContext) Header() http.Header { + if c.res == nil { + c.res = NewResponse(http.StatusOK, nil, c.req) + } + return c.res.Header +} + +func (c *proxyContext) WriteHeader(code int) { + if c.res == nil { + c.res = NewResponse(code, nil, c.req) + } else { + if text := http.StatusText(code); text != "" { + c.res.Status = strconv.Itoa(code) + " " + text + } else { + c.res.Status = strconv.Itoa(code) + } + c.res.StatusCode = code + } + //return c.res.Header.Write(c) +} + +var _ Context = (*proxyContext)(nil) diff --git a/proxy/doc.go b/proxy/doc.go new file mode 100644 index 0000000..889785c --- /dev/null +++ b/proxy/doc.go @@ -0,0 +1,2 @@ +// Package proxy contains a HTTP(s) (transparent) proxy server. +package proxy diff --git a/proxy/handler.go b/proxy/handler.go new file mode 100644 index 0000000..f064658 --- /dev/null +++ b/proxy/handler.go @@ -0,0 +1,309 @@ +package proxy + +import ( + "bytes" + "context" + "crypto/tls" + "errors" + "fmt" + "io" + "net" + "net/http" + + "git.maze.io/maze/styx/ca" + "git.maze.io/maze/styx/internal/cryptutil" + "git.maze.io/maze/styx/internal/netutil" +) + +type Dialer interface { + DialContext(context.Context, *http.Request) (net.Conn, error) +} + +type defaultDialer struct{} + +func (defaultDialer) DialContext(ctx context.Context, req *http.Request) (net.Conn, error) { + if host := netutil.Host(req.URL.Host); host == "" { + return nil, errors.New("proxy: host missing in address") + } + + var d = net.Dialer{ + Resolver: net.DefaultResolver, + FallbackDelay: -1, + } + + // Ensure we have a port. + switch req.URL.Scheme { + case "http", "ws": + req.URL.Host = netutil.EnsurePort(req.URL.Host, "80") + case "https", "wss": + req.URL.Host = netutil.EnsurePort(req.URL.Host, "443") + } + + // Resolve the host. + if ips, err := d.Resolver.LookupIP(ctx, "ip", netutil.Host(req.URL.Host)); err != nil { + return nil, err + } else { + for _, ip := range ips { + switch { + case ip.IsUnspecified(): + return nil, fmt.Errorf("proxy: host %s resolves to unspecified address (blocked by DNS?)", netutil.Host(req.URL.Host)) + case ip.IsLoopback(): + return nil, fmt.Errorf("proxy: host %s resolves to loopback address (blocked by DNS?)", netutil.Host(req.URL.Host)) + } + } + } + + // Make the connection. + switch req.URL.Scheme { + case "tcp", "http", "ws": + // Plain TCP client connection. + return d.DialContext(ctx, "tcp", req.URL.Host) + case "https", "wss": + // Secure TLS client connection. + c, err := d.DialContext(ctx, "tcp", req.URL.Host) + if err != nil { + return nil, err + } + s := tls.Client(c, new(tls.Config)) + return s, s.Handshake() + default: + return nil, fmt.Errorf("proxy: can't dial %s protocol", req.URL.Scheme) + } +} + +// ConnFilter is called when a new connection has been accepted by the proxy. +type ConnFilter interface { + FilterConn(Context) (net.Conn, error) +} + +// ConnFilterFunc is a function that implements the [ConnFilter] interface. +type ConnFilterFunc func(Context) (net.Conn, error) + +func (f ConnFilterFunc) FilterConn(ctx Context) (net.Conn, error) { + return f(ctx) +} + +// TLS starts a TLS handshake on the accepted connection. +func TLS(certs []tls.Certificate) ConnFilter { + return ConnFilterFunc(func(ctx Context) (net.Conn, error) { + s := tls.Server(ctx, &tls.Config{ + Certificates: certs, + NextProtos: []string{"http/1.1"}, + }) + if err := s.Handshake(); err != nil { + return nil, err + } + return s, nil + }) +} + +// TLSInterceptor can generate certificates on-the-fly for clients that use a compatible TLS version. +func TLSInterceptor(ca ca.CertificateAuthority) ConnFilter { + return ConnFilterFunc(func(ctx Context) (net.Conn, error) { + s := tls.Server(ctx, &tls.Config{ + GetCertificate: func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) { + ips := []net.IP{net.ParseIP(netutil.Host(ctx.RemoteAddr().String()))} + return ca.GetCertificate(hello.ServerName, []string{hello.ServerName}, ips) + }, + NextProtos: []string{"http/1.1"}, + }) + if err := s.Handshake(); err != nil { + return nil, err + } + return s, nil + }) +} + +// Transparent can handle transparent HTTP(S) requests on the port. +// +// When a new [net.Conn] is made, this function will inspect the initial request packet for a +// TLS handshake. If a TLS handshake is detected, the connection will make a feaux HTTP CONNECT +// request using TLS, if no handshake is detected, it will make a feaux plain HTTP CONNECT request. +func Transparent() ConnFilter { + return ConnFilterFunc(func(nctx Context) (net.Conn, error) { + ctx, ok := nctx.(*proxyContext) + if !ok { + return nctx, nil + } + + b := new(bytes.Buffer) + hello, err := cryptutil.ReadClientHello(io.TeeReader(ctx, b)) + if err != nil { + if _, ok := err.(tls.RecordHeaderError); !ok { + ctx.LogEntry().WithError(err).WithField("error_type", fmt.Sprintf("%T", err)).Warn("TLS sniff error") + return nil, err + } + // Not a TLS connection, moving on to regular HTTP request handling... + ctx.LogEntry().Debug("HTTP connection on transparent port") + ctx.isTransparent = true + } else { + ctx.LogEntry().WithField("target", hello.ServerName).Debug("TLS connection on transparent port") + ctx.isTransparentTLS = true + ctx.serverName = hello.ServerName + } + + return netutil.ReaderConn{ + Conn: ctx.Conn, + Reader: io.MultiReader(bytes.NewReader(b.Bytes()), ctx.Conn), + }, nil + }) +} + +// RequestFilter can filter HTTP requests coming to the proxy. +type RequestFilter interface { + // FilterRequest filters a HTTP request made to the proxy. The current request may be obtained + // from [Context.Request]. If a previous RequestFilter provided a HTTP response, it is available + // from [Context.Response]. + // + // Modifications to the current request can be made to the Request returned by [Context.Request] + // and do not require returning a new [http.Request]. + // + // If the filter returns a non-nil [http.Response], then the [Request] will not be proxied to + // any upstream server(s). + FilterRequest(Context) (*http.Request, *http.Response) +} + +// RequestFilterFunc is a function that implements the [RequestFilter] interface. +type RequestFilterFunc func(Context) (*http.Request, *http.Response) + +func (f RequestFilterFunc) FilterRequest(ctx Context) (*http.Request, *http.Response) { + return f(ctx) +} + +// ResponseFilter can filter HTTP responses coming from the proxy. +type ResponseFilter interface { + // FilterResponse filters a HTTP response coming from the proxy. The current response may be + // obtained from [Context.Response]. + // + // Modifications to the current response can be made to the [Response] returned by [Context.Response]. + FilterResponse(Context) *http.Response +} + +// ResponseFilterFunc is a function that implements the [ResponseFilter] interface. +type ResponseFilterFunc func(Context) *http.Response + +func (f ResponseFilterFunc) FilterResponse(ctx Context) *http.Response { + return f(ctx) +} + +// CleanRequestProxyHeaders removes all headers added by downstream proxies from the [http.Request]. +func CleanRequestProxyHeaders() RequestFilter { + return RequestFilterFunc(func(ctx Context) (*http.Request, *http.Response) { + if req := ctx.Request(); req != nil { + cleanProxyHeaders(req.Header) + } + return nil, nil + }) +} + +// CleanRequestProxyHeaders removes all headers for upstream proxies from the [http.Response]. +func CleanResponseProxyHeaders() ResponseFilter { + return ResponseFilterFunc(func(ctx Context) *http.Response { + if res := ctx.Response(); res != nil { + cleanProxyHeaders(res.Header) + } + return nil + }) +} + +// AddRequestHeaders adds headers to the [http.Request]. Any existing headers with the same +// key will remain intact. +func AddRequestHeaders(h http.Header) RequestFilter { + return RequestFilterFunc(func(ctx Context) (*http.Request, *http.Response) { + if req := ctx.Request(); req != nil { + if req.Header == nil { + req.Header = make(http.Header) + } + addHeaders(req.Header, h) + } + return nil, nil + }) +} + +// SetRequestHeaders sets headers to the [http.Request]. Any existing headers with the same +// key will be removed. +func SetRequestHeaders(h http.Header) RequestFilter { + return RequestFilterFunc(func(ctx Context) (*http.Request, *http.Response) { + if req := ctx.Request(); req != nil { + if req.Header == nil { + req.Header = make(http.Header) + } + setHeaders(req.Header, h) + } + return nil, nil + }) +} + +// AddResponseHeaders adds headers to the [http.Response]. Any existing headers with the same +// key will remain intact. +func AddResponseHeaders(h http.Header) ResponseFilter { + return ResponseFilterFunc(func(ctx Context) *http.Response { + if res := ctx.Response(); res != nil { + if res.Header == nil { + res.Header = make(http.Header) + } + addHeaders(res.Header, h) + } + return nil + }) +} + +// SetResponseHeaders sets headers to the [http.Response]. Any existing headers with the same +// key will be removed. +func SetResponseHeaders(h http.Header) ResponseFilter { + return ResponseFilterFunc(func(ctx Context) *http.Response { + if res := ctx.Response(); res != nil { + if res.Header == nil { + res.Header = make(http.Header) + } + setHeaders(res.Header, h) + } + return nil + }) +} + +// cleanProxyHeaders removes all headers commonly used by (reverse) HTTP proxies. +func cleanProxyHeaders(h http.Header) { + if h == nil { + return + } + + for _, k := range []string{ + HeaderForwarded, + HeaderForwardedFor, + HeaderForwardedHost, + HeaderForwardedPort, + HeaderForwardedProto, + HeaderRealIP, + HeaderVia, + } { + h.Del(k) + } +} + +// addHeaders adds to the current existing headers. +func addHeaders(dst, src http.Header) { + if src == nil { + return + } + + for k, vv := range src { + for _, v := range vv { + dst.Add(k, v) + } + } +} + +// setHeaders replaces all previous values. +func setHeaders(dst, src http.Header) { + if src == nil { + return + } + + for k, vv := range src { + dst.Del(k) + for _, v := range vv { + dst.Add(k, v) + } + } +} diff --git a/proxy/match/config.go b/proxy/match/config.go deleted file mode 100644 index 832a207..0000000 --- a/proxy/match/config.go +++ /dev/null @@ -1,324 +0,0 @@ -package match - -import ( - "fmt" - "net" - "net/http" - "os" - "regexp" - "slices" - "strconv" - "strings" - "time" - - "git.maze.io/maze/styx/internal/log" - "git.maze.io/maze/styx/internal/netutil" - "github.com/hashicorp/hcl/v2" - "github.com/hashicorp/hcl/v2/gohcl" -) - -type Config struct { - Path string `hcl:"path,optional"` - Refresh time.Duration `hcl:"refresh,optional"` - Domain []*Domain `hcl:"domain,block"` - Network []*Network `hcl:"network,block"` - Content []*Content `hcl:"content,block"` -} - -func (config Config) Matchers() (Matchers, error) { - all := make(Matchers) - if config.Domain != nil { - all["domain"] = make(map[string]Matcher) - for _, domain := range config.Domain { - m, err := domain.Matcher() - if err != nil { - return nil, fmt.Errorf("matcher domain %q invalid: %w", domain.Name, err) - } - all["domain"][domain.Name] = m - } - } - if config.Network != nil { - all["network"] = make(map[string]Matcher) - for _, network := range config.Network { - m, err := network.Matcher(true) - if err != nil { - return nil, fmt.Errorf("matcher network %q invalid: %w", network.Name, err) - } - all["network"][network.Name] = m - } - } - return all, nil -} - -type Content struct { - Name string `hcl:"name,label"` - Type string `hcl:"type"` - Body hcl.Body `hcl:",remain"` -} - -type contentHeader struct { - Key string `hcl:"name"` - Value string `hcl:"value,optional"` - List []string `hcl:"list,optional"` - name string - keyRe *regexp.Regexp - valueRe *regexp.Regexp -} - -func (m contentHeader) Name() string { return m.name } -func (m contentHeader) MatchesResponse(r *http.Response) bool { - for k, vv := range r.Header { - if m.keyRe.MatchString(k) { - for _, v := range vv { - if slices.Contains(m.List, v) { - return true - } - if m.valueRe != nil && m.valueRe.MatchString(v) { - return true - } - } - } - } - return false -} - -type contentType struct { - List []string `hcl:"list"` - name string -} - -func (m contentType) Name() string { return m.name } -func (m contentType) MatchesResponse(r *http.Response) bool { - return slices.Contains(m.List, r.Header.Get("Content-Type")) -} - -type contentSizeLargerThan struct { - Size int64 `hcl:"size"` - name string -} - -func (m contentSizeLargerThan) Name() string { return m.name } -func (m contentSizeLargerThan) MatchesResponse(r *http.Response) bool { - size, err := strconv.ParseInt(r.Header.Get("Content-Length"), 10, 64) - if err != nil { - return false - } - return size >= m.Size -} - -type contentStatus struct { - Code []int `hcl:"code"` - name string -} - -func (m contentStatus) Name() string { return m.name } -func (m contentStatus) MatchesResponse(r *http.Response) bool { - return slices.Contains(m.Code, r.StatusCode) -} - -func (config Content) Matcher() (Response, error) { - switch strings.ToLower(config.Type) { - case "content", "contenttype", "content-type", "type": - var matcher = contentType{name: config.Name} - if err := gohcl.DecodeBody(config.Body, nil, &matcher); err != nil { - return nil, err - } - return matcher, nil - - case "header": - var ( - matcher = contentHeader{name: config.Name} - err error - ) - if err = gohcl.DecodeBody(config.Body, nil, &matcher); err != nil { - return nil, err - } - if matcher.Value == "" && len(matcher.List) == 0 { - return nil, fmt.Errorf("invalid content %q: must contain either list or value", config.Name) - } - if matcher.keyRe, err = regexp.Compile(matcher.Key); err != nil { - return nil, fmt.Errorf("invalid regular expression on content %q key: %w", config.Name, err) - } - if matcher.Value != "" { - if matcher.valueRe, err = regexp.Compile(matcher.Value); err != nil { - return nil, fmt.Errorf("invalid regular expression on content %q value: %w", config.Name, err) - } - } - return matcher, nil - - case "size": - var matcher = contentSizeLargerThan{name: config.Name} - if err := gohcl.DecodeBody(config.Body, nil, &matcher); err != nil { - return nil, err - } - return matcher, nil - - case "status": - var matcher = contentStatus{name: config.Name} - if err := gohcl.DecodeBody(config.Body, nil, &matcher); err != nil { - return nil, err - } - return matcher, nil - - default: - return nil, fmt.Errorf("unknown content matcher type %q", config.Type) - } -} - -type Domain struct { - Name string `hcl:"name,label"` - Type string `hcl:"type"` - Body hcl.Body `hcl:",remain"` -} - -func (config Domain) Matcher() (Request, error) { - switch config.Type { - case "list": - var matcher = domainList{Title: config.Name} - if err := gohcl.DecodeBody(config.Body, nil, &matcher); err != nil { - return nil, err - } - matcher.list = netutil.NewDomainList(matcher.List...) - return matcher, nil - - case "adblock", "dnsmasq", "hosts", "detect", "domains": - var matcher = DomainFile{ - Title: config.Name, - Type: config.Type, - } - if err := gohcl.DecodeBody(config.Body, nil, &matcher); err != nil { - return nil, err - } - if matcher.Path == "" && matcher.From == "" { - return nil, fmt.Errorf("matcher: domain %q must have either file or from configured", config.Name) - } - if err := matcher.Update(); err != nil { - return nil, err - } - return matcher, nil - - default: - return nil, fmt.Errorf("unknown domain matcher type %q", config.Type) - } - -} - -type domainList struct { - Title string `json:"title"` - List []string `hcl:"list" json:"list"` - list *netutil.DomainTree -} - -func (m domainList) Name() string { - return m.Title -} - -func (m domainList) MatchesRequest(r *http.Request) bool { - host := netutil.Host(r.URL.Host) - log.Debug().Str("host", host).Msgf("match domain list (%d domains)", len(m.List)) - return m.list.Contains(host) -} - -type DomainFile struct { - Title string `json:"name"` - Type string `json:"type"` - Path string `hcl:"path,optional" json:"path,omitempty"` - From string `hcl:"from,optional" json:"from,omitempty"` - Refresh time.Duration `hcl:"refresh,optional" json:"refresh"` -} - -func (m DomainFile) Name() string { - return m.Title -} - -func (m DomainFile) MatchesRequest(_ *http.Request) bool { - return false -} - -func (m *DomainFile) Update() (err error) { - var data []byte - if m.Path != "" { - if data, err = os.ReadFile(m.Path); err != nil { - return - } - } else { - /* - var response *http.Response - if response, err = http.DefaultClient.Get(m.From); err != nil { - return - } - defer func() { _ = response.Body.Close() }() - if response.StatusCode != http.StatusOK { - return fmt.Errorf("match: domain %q update failed: %s", m.name, response.Status) - } - if data, err = io.ReadAll(response.Body); err != nil { - return - } - */ - } - - switch m.Type { - case "hosts": - } - - _ = data - return nil -} - -type Network struct { - Name string `hcl:"name,label"` - Type string `hcl:"type"` - Body hcl.Body `hcl:",remain"` -} - -func (config *Network) Matcher(target bool) (Matcher, error) { - switch config.Type { - case "list": - var ( - matcher = networkList{Title: config.Name} - err error - ) - if diag := gohcl.DecodeBody(config.Body, nil, &matcher); diag.HasErrors() { - return nil, diag - } - if matcher.tree, err = netutil.NewNetworkTree(matcher.List...); err != nil { - return nil, err - } - return &matcher, nil - - default: - return nil, fmt.Errorf("unknown network matcher type %q", config.Type) - } -} - -type networkList struct { - Title string `json:"name"` - List []string `hcl:"list" json:"list"` - tree *netutil.NetworkTree - target bool -} - -func (m *networkList) Name() string { - return m.Title -} - -func (m *networkList) MatchesIP(ip net.IP) bool { - return m.tree.Contains(ip) -} - -func (m *networkList) MatchesRequest(r *http.Request) bool { - var ( - host string - err error - ) - if m.target { - host, _, err = net.SplitHostPort(r.URL.Host) - } else { - host, _, err = net.SplitHostPort(r.RemoteAddr) - } - if err != nil { - return false - } - ip := net.ParseIP(host) - return m.MatchesIP(ip) -} diff --git a/proxy/match/match.go b/proxy/match/match.go deleted file mode 100644 index 98a5e17..0000000 --- a/proxy/match/match.go +++ /dev/null @@ -1,45 +0,0 @@ -package match - -import ( - "fmt" - "net" - "net/http" -) - -type Matchers map[string]map[string]Matcher - -func (all Matchers) Get(kind, name string) (m Matcher, err error) { - if typeMatchers, ok := all[kind]; ok { - if m, ok = typeMatchers[name]; ok { - return - } - return nil, fmt.Errorf("no %s matcher named %q found", kind, name) - } - return nil, fmt.Errorf("no %s matcher found", kind) -} - -type Matcher interface { - Name() string -} - -type Updater interface { - Update() error -} - -type IP interface { - Matcher - - MatchesIP(net.IP) bool -} - -type Request interface { - Matcher - - MatchesRequest(*http.Request) bool -} - -type Response interface { - Matcher - - MatchesResponse(*http.Response) bool -} diff --git a/proxy/match/util.go b/proxy/match/util.go deleted file mode 100644 index 210b1e5..0000000 --- a/proxy/match/util.go +++ /dev/null @@ -1,11 +0,0 @@ -package match - -import "net" - -func onlyHost(name string) string { - host, _, err := net.SplitHostPort(name) - if err != nil { - return name - } - return host -} diff --git a/proxy/mitm/authority.go b/proxy/mitm/authority.go deleted file mode 100644 index 37b55ca..0000000 --- a/proxy/mitm/authority.go +++ /dev/null @@ -1,231 +0,0 @@ -package mitm - -import ( - "crypto" - "crypto/rand" - "crypto/tls" - "crypto/x509" - "crypto/x509/pkix" - "errors" - "fmt" - "math/big" - "os" - "strings" - "time" - - "git.maze.io/maze/styx/internal/cryptutil" - "git.maze.io/maze/styx/internal/log" - "github.com/miekg/dns" -) - -const DefaultValidity = 24 * time.Hour - -type Authority interface { - Certificate() *x509.Certificate - TLSConfig(name string) *tls.Config -} - -type authority struct { - pool *x509.CertPool - cert *x509.Certificate - key crypto.PrivateKey - keyID []byte - keyPool chan crypto.PrivateKey - cache Cache -} - -func New(config *Config) (Authority, error) { - cache, err := NewCache(config.Cache) - if err != nil { - return nil, err - } - - caConfig := config.CA - if caConfig == nil { - caConfig = new(CAConfig) - } - - cert, key, err := cryptutil.LoadKeyPair(caConfig.Cert, caConfig.Key) - if os.IsNotExist(err) { - days := caConfig.Days - if days == 0 { - days = DefaultDays - } - if cert, key, err = cryptutil.GenerateKeyPair(caConfig.DN(), days, caConfig.KeyType, caConfig.Bits); err != nil { - return nil, err - } - if strings.ContainsRune(caConfig.Cert, os.PathSeparator) { - if err = cryptutil.SaveKeyPair(cert, key, caConfig.Cert, caConfig.Key); err != nil { - return nil, err - } - } - } else if err != nil { - return nil, err - } - - pool := x509.NewCertPool() - pool.AddCert(cert) - - keyConfig := config.Key - if keyConfig == nil { - keyConfig = &defaultKeyConfig - } - - keyPoolSize := defaultKeyConfig.Pool - if keyConfig.Pool > 0 { - keyPoolSize = keyConfig.Pool - } - keyPool := make(chan crypto.PrivateKey, keyPoolSize) - if key, err := cryptutil.GeneratePrivateKey(keyConfig.Type, keyConfig.Bits); err != nil { - return nil, fmt.Errorf("mitm: invalid key configuration: %w", err) - } else { - keyPool <- key - } - - go func(pool chan<- crypto.PrivateKey) { - for { - key, err := cryptutil.GeneratePrivateKey(keyConfig.Type, keyConfig.Bits) - if err != nil { - log.Panic().Err(err).Msg("error generating private key") - } - pool <- key - } - }(keyPool) - - return &authority{ - pool: pool, - cert: cert, - key: key, - keyID: cryptutil.GenerateKeyID(cryptutil.PublicKey(key)), - keyPool: keyPool, - cache: cache, - }, nil -} - -func (ca *authority) log() log.Logger { - return log.Console.With(). - Str("ca", ca.cert.Subject.String()). - Logger() -} - -func (ca *authority) Certificate() *x509.Certificate { - return ca.cert -} - -func (ca *authority) TLSConfig(name string) *tls.Config { - return &tls.Config{ - GetCertificate: func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) { - log := ca.log() - if hello.ServerName != "" { - name = strings.ToLower(hello.ServerName) - log.Debug().Msg("requesting certificate for server name (SNI)") - } else { - log.Debug().Msg("requesting certificate for hostname") - } - if cert, ok := ca.getCached(name); ok { - log.Debug(). - Str("subject", cert.Leaf.Subject.String()). - Str("serial", cert.Leaf.SerialNumber.String()). - Time("valid", cert.Leaf.NotAfter). - Msg("using cached certificate") - return cert, nil - } - return ca.issueFor(name) - }, - NextProtos: []string{"http/1.1"}, - } -} - -func (ca *authority) getCached(name string) (cert *tls.Certificate, ok bool) { - log := ca.log() - - if cert = ca.cache.Certificate(name); cert == nil { - if baseDomain(name) != name { - cert = ca.cache.Certificate(baseDomain(name)) - } - } - if cert != nil { - if _, err := cert.Leaf.Verify(x509.VerifyOptions{ - DNSName: name, - Roots: ca.pool, - }); err != nil { - log.Debug().Err(err).Str("name", name).Msg("deleting invalid certificate from cache") - } else { - ok = true - } - } - return -} - -func (ca *authority) issueFor(name string) (*tls.Certificate, error) { - var ( - log = ca.log().With().Str("name", name).Logger() - key crypto.PrivateKey - ) - select { - case key = <-ca.keyPool: - case <-time.After(5 * time.Second): - return nil, errors.New("mitm: timeout waiting for private key generator to catch up") - } - if key == nil { - panic("key pool returned nil key") - } - - serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) - serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) - if err != nil { - return nil, fmt.Errorf("mtim: failed to generate serial number: %w", err) - } - - if part := dns.SplitDomainName(name); len(part) > 2 { - name = strings.Join(part[1:], ".") - log.Debug().Msgf("abbreviated name to %s (*.%s)", name, name) - } - - now := time.Now() - template := &x509.Certificate{ - SerialNumber: serialNumber, - Subject: pkix.Name{CommonName: name}, - KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyAgreement, - ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, - DNSNames: []string{name, "*." + name}, - BasicConstraintsValid: true, - NotBefore: now.Add(-DefaultValidity), - NotAfter: now.Add(+DefaultValidity), - } - der, err := x509.CreateCertificate(rand.Reader, template, ca.cert, cryptutil.PublicKey(key), ca.key) - if err != nil { - return nil, err - } - cert, err := x509.ParseCertificate(der) - if err != nil { - return nil, err - } - - log.Debug().Str("serial", serialNumber.String()).Msg("generated certificate") - out := &tls.Certificate{ - Certificate: [][]byte{der}, - Leaf: cert, - PrivateKey: key, - } - //ca.cache[name] = out - ca.cache.SaveCertificate(name, out) - return out, nil -} - -func containsValidCertificate(cert *tls.Certificate) bool { - if cert == nil || len(cert.Certificate) == 0 { - return false - } - - if cert.Leaf == nil { - var err error - if cert.Leaf, err = x509.ParseCertificate(cert.Certificate[0]); err != nil { - return false - } - } - - now := time.Now() - - return !(cert.Leaf.NotBefore.Before(now) || cert.Leaf.NotAfter.After(now)) -} diff --git a/proxy/mitm/cache.go b/proxy/mitm/cache.go deleted file mode 100644 index e0221ad..0000000 --- a/proxy/mitm/cache.go +++ /dev/null @@ -1,233 +0,0 @@ -package mitm - -import ( - "crypto/tls" - "fmt" - "io/fs" - "os" - "path/filepath" - "slices" - "strings" - "time" - - "github.com/hashicorp/golang-lru/v2/expirable" - "github.com/hashicorp/hcl/v2/gohcl" - "github.com/miekg/dns" - - "git.maze.io/maze/styx/internal/cryptutil" - "git.maze.io/maze/styx/internal/log" -) - -type Cache interface { - Certificate(name string) *tls.Certificate - SaveCertificate(name string, cert *tls.Certificate) error - RemoveCertificate(name string) -} - -func NewCache(config *CacheConfig) (Cache, error) { - if config == nil { - return NewCache(&CacheConfig{Type: "memory"}) - } - switch config.Type { - case "memory": - var cacheConfig = new(MemoryCacheConfig) - if err := gohcl.DecodeBody(config.Body, nil, cacheConfig); err != nil { - return nil, err - } - return NewMemoryCache(cacheConfig.Size), nil - case "disk": - var cacheConfig = new(DiskCacheConfig) - if err := gohcl.DecodeBody(config.Body, nil, cacheConfig); err != nil { - return nil, err - } - return NewDiskCache(cacheConfig.Path, time.Duration(cacheConfig.Expire*float64(time.Second))) - default: - return nil, fmt.Errorf("mitm: cache type %q is not supported", config.Type) - } -} - -type memoryCache struct { - cache *expirable.LRU[string, *tls.Certificate] -} - -func NewMemoryCache(size int) Cache { - return memoryCache{ - cache: expirable.NewLRU(size, func(key string, value *tls.Certificate) { - log.Debug().Str("name", key).Msg("certificate evicted from cache") - }, time.Hour*24), - } -} - -func (c memoryCache) Certificate(name string) (cert *tls.Certificate) { - var ok bool - if cert, ok = c.cache.Get(name); !ok { - cert, _ = c.cache.Get(baseDomain(name)) - } - return -} - -func (c memoryCache) SaveCertificate(name string, cert *tls.Certificate) error { - c.cache.Add(name, cert) - log.Debug().Str("name", name).Msg("certificate added to cache") - return nil -} - -func (c memoryCache) RemoveCertificate(name string) { - c.cache.Remove(name) -} - -type diskCache string - -func NewDiskCache(dir string, expire time.Duration) (Cache, error) { - if !filepath.IsAbs(dir) { - var err error - if dir, err = filepath.Abs(dir); err != nil { - return nil, err - } - } - if err := os.MkdirAll(dir, 0o750); err != nil { - return nil, err - } - info, err := os.Stat(dir) - if err != nil { - return nil, err - } - if info.Mode()&os.ModePerm|0o057 != 0 { - if err := os.Chmod(dir, 0o750); err != nil { - return nil, err - } - } - - if expire > 0 { - go expireDiskCache(dir, expire) - } - - return diskCache(dir), nil -} - -func expireDiskCache(root string, expire time.Duration) { - log.Debug().Str("path", root).Dur("expire", expire).Msg("disk cache expire loop starting") - ticker := time.NewTicker(expire) - defer ticker.Stop() - for { - now := <-ticker.C - log.Debug().Str("path", root).Dur("expire", expire).Msg("expire disk cache") - filepath.WalkDir(root, func(path string, d fs.DirEntry, err error) error { - if err != nil { - return err - } - if d.IsDir() { - // Remove the directory; this will fail if it's not empty, which is fine. - _ = os.Remove(path) - return nil - } - - cert, err := cryptutil.LoadCertificate(path) - if err != nil { - log.Debug().Str("path", path).Err(err).Msg("expire removing invalid certificate file") - _ = os.Remove(path) - return nil - } else if cert.NotAfter.Before(now) { - log.Debug().Str("path", path).Dur("expired", now.Sub(cert.NotAfter)).Msg("expire removing expired certificate") - _ = os.Remove(path) - return nil - } - return nil - }) - } -} - -func (c diskCache) path(name string) string { - part := dns.SplitDomainName(strings.ToLower(name)) - // x,com -> com,x - // www,maze,io -> io,maze,www - slices.Reverse(part) - // com,x -> com,x,x.com - // io,maze,www -> io,m,ma,maze,www.maze.io - if len(part) > 2 { - if len(part[1]) > 1 { - part = []string{ - part[0], - part[1][:1], - part[1][:2], - part[1], - name, - } - } else { - part = []string{ - part[0], - part[1][:1], - part[1], - name, - } - } - } else if len(part) > 1 { - if len(part[1]) > 1 { - part = []string{ - part[0], - part[1][:1], - part[1][:2], - name, - } - } else { - part = []string{ - part[0], - part[1][:1], - name, - } - } - } - part[len(part)-1] += ".crt" - return filepath.Join(append([]string{string(c)}, part...)...) -} - -func (c diskCache) Certificate(name string) (cert *tls.Certificate) { - if cert, key, err := cryptutil.LoadKeyPair(c.path(name), ""); err == nil { - return &tls.Certificate{ - Certificate: [][]byte{cert.Raw}, - Leaf: cert, - PrivateKey: key, - } - } - if cert, key, err := cryptutil.LoadKeyPair(c.path(baseDomain(name)), ""); err == nil { - return &tls.Certificate{ - Certificate: [][]byte{cert.Raw}, - Leaf: cert, - PrivateKey: key, - } - } - log.Debug().Str("path", string(c)).Str("name", name).Msg("cache miss") - return nil -} - -func (c diskCache) SaveCertificate(name string, cert *tls.Certificate) error { - dir, name := filepath.Split(c.path(name)) - if err := os.MkdirAll(dir, 0o750); err != nil { - return err - } - if err := cryptutil.SaveKeyPair(cert.Leaf, cert.PrivateKey, filepath.Join(dir, name), ""); err != nil { - return err - } - log.Debug().Str("name", name).Msg("certificate added to cache") - return nil -} - -func (c diskCache) RemoveCertificate(name string) { - path := c.path(name) - if err := os.Remove(path); err != nil { - if os.IsNotExist(err) { - return - } - log.Error().Err(err).Msg("certificate remove from cache failed") - } - _ = os.Remove(filepath.Dir(path)) - log.Debug().Str("name", name).Msg("certificate removed from cache") -} - -func baseDomain(name string) string { - name = strings.ToLower(name) - if part := dns.SplitDomainName(name); len(part) > 2 { - return strings.Join(part[1:], ".") - } - return name -} diff --git a/proxy/mitm/cache_test.go b/proxy/mitm/cache_test.go deleted file mode 100644 index 6f6d364..0000000 --- a/proxy/mitm/cache_test.go +++ /dev/null @@ -1,25 +0,0 @@ -package mitm - -import "testing" - -func TestDiskCachePath(t *testing.T) { - cache := diskCache("testdata") - tests := []struct { - test string - want string - }{ - {"x.com", "testdata/com/x/x.com.crt"}, - {"feed.x.com", "testdata/com/x/x/feed.x.com.crt"}, - {"nu.nl", "testdata/nl/n/nu/nu.nl.crt"}, - {"maze.io", "testdata/io/m/ma/maze.io.crt"}, - {"lab.maze.io", "testdata/io/m/ma/maze/lab.maze.io.crt"}, - {"dev.lab.maze.io", "testdata/io/m/ma/maze/dev.lab.maze.io.crt"}, - } - for _, test := range tests { - t.Run(test.test, func(it *testing.T) { - if v := cache.path(test.test); v != test.want { - it.Errorf("expected %q to resolve to %q, got %q", test.test, test.want, v) - } - }) - } -} diff --git a/proxy/mitm/config.go b/proxy/mitm/config.go deleted file mode 100644 index 0c67650..0000000 --- a/proxy/mitm/config.go +++ /dev/null @@ -1,89 +0,0 @@ -package mitm - -import ( - "crypto/x509/pkix" - - "github.com/hashicorp/hcl/v2" -) - -const ( - DefaultCommonName = "Styx Certificate Authority" - DefaultDays = 3 -) - -type Config struct { - CA *CAConfig `hcl:"ca,block"` - Key *KeyConfig `hcl:"key,block"` - Cache *CacheConfig `hcl:"cache,block"` -} - -type CAConfig struct { - Cert string `hcl:"cert"` - Key string `hcl:"key,optional"` - Days int `hcl:"days,optional"` - KeyType string `hcl:"key_type,optional"` - Bits int `hcl:"bits,optional"` - Name string `hcl:"name,optional"` - Country string `hcl:"country,optional"` - Organization string `hcl:"organization,optional"` - Unit string `hcl:"unit,optional"` - Locality string `hcl:"locality,optional"` - Province string `hcl:"province,optional"` - Address []string `hcl:"address,optional"` - PostalCode string `hcl:"postal_code,optional"` -} - -func (config CAConfig) DN() pkix.Name { - var name = pkix.Name{ - CommonName: config.Name, - StreetAddress: config.Address, - } - if config.Name == "" { - name.CommonName = DefaultCommonName - } - if config.Country != "" { - name.Country = append(name.Country, config.Country) - } - if config.Organization != "" { - name.Organization = append(name.Organization, config.Organization) - } - if config.Unit != "" { - name.OrganizationalUnit = append(name.OrganizationalUnit, config.Unit) - } - if config.Locality != "" { - name.Locality = append(name.Locality, config.Locality) - } - if config.Province != "" { - name.Province = append(name.Province, config.Province) - } - if config.PostalCode != "" { - name.PostalCode = append(name.PostalCode, config.PostalCode) - } - return name -} - -type KeyConfig struct { - Type string `hcl:"type,optional"` - Bits int `hcl:"bits,optional"` - Pool int `hcl:"pool,optional"` -} - -var defaultKeyConfig = KeyConfig{ - Type: "rsa", - Bits: 2048, - Pool: 5, -} - -type CacheConfig struct { - Type string `hcl:"type"` - Body hcl.Body `hcl:",remain"` -} - -type MemoryCacheConfig struct { - Size int `hcl:"size,optional"` -} - -type DiskCacheConfig struct { - Path string `hcl:"path"` - Expire float64 `hcl:"expire,optional"` -} diff --git a/proxy/policy/policy.go b/proxy/policy/policy.go deleted file mode 100644 index 539b47e..0000000 --- a/proxy/policy/policy.go +++ /dev/null @@ -1,53 +0,0 @@ -package policy - -import ( - "net/http" - - "git.maze.io/maze/styx/proxy/match" -) - -// Policy contains rules that make up the policy. -// -// Some policy rules contain nested policies. -type Policy struct { - Rules []*rawRule `hcl:"on,block" json:"rules"` - Permit *bool `hcl:"permit" json:"permit"` - Matchers match.Matchers `json:"matchers"` // Matchers for the policy - -} - -func (p *Policy) Configure(matchers match.Matchers) (err error) { - for _, r := range p.Rules { - if err = r.Configure(matchers); err != nil { - return - } - } - p.Matchers = matchers - return -} - -func (p *Policy) PermitIntercept(r *http.Request) *bool { - if p != nil { - for _, rule := range p.Rules { - if rule, ok := rule.Rule.(InterceptRule); ok { - if permit := rule.PermitIntercept(r); permit != nil { - return permit - } - } - } - } - return p.Permit -} - -func (p *Policy) PermitRequest(r *http.Request) *bool { - if p != nil { - for _, rule := range p.Rules { - if rule, ok := rule.Rule.(RequestRule); ok { - if permit := rule.PermitRequest(r); permit != nil { - return permit - } - } - } - } - return p.Permit -} diff --git a/proxy/policy/policy_test.go b/proxy/policy/policy_test.go deleted file mode 100644 index 307bab0..0000000 --- a/proxy/policy/policy_test.go +++ /dev/null @@ -1,139 +0,0 @@ -package policy - -import ( - "net" - "net/http" - "net/url" - "testing" - - "git.maze.io/maze/styx/internal/netutil" - "git.maze.io/maze/styx/proxy/match" - "github.com/miekg/dns" -) - -type testInDomainList struct { - t *testing.T - list []string -} - -func (testInDomainList) Name() string { return "testInDomainList" } -func (l testInDomainList) MatchesRequest(r *http.Request) bool { - for _, domain := range l.list { - if dns.IsSubDomain(domain, netutil.Host(r.URL.Host)) { - l.t.Logf("domain %s contains %s", domain, r.URL.Host) - return true - } - l.t.Logf("domain %s does not contain %s", domain, r.URL.Host) - } - return false -} - -func testInDomain(t *testing.T, domains ...string) match.Matcher { - return &testInDomainList{t: t, list: domains} -} - -type testInNetworkList struct { - t *testing.T - list []*net.IPNet -} - -func (testInNetworkList) Name() string { return "testInNetworkList" } -func (l testInNetworkList) MatchesIP(ip net.IP) bool { - for _, ipnet := range l.list { - if ipnet.Contains(ip) { - l.t.Logf("network %s contains %s", ipnet, ip) - return true - } - l.t.Logf("network %s does not contain %s", ipnet, ip) - } - return false -} - -func testInNetwork(t *testing.T, cidr string) match.Matcher { - t.Helper() - _, ipnet, err := net.ParseCIDR(cidr) - if err != nil { - panic(err) - } - return testInNetworkList{t: t, list: []*net.IPNet{ipnet}} -} - -func TestPolicy(t *testing.T) { - var ( - yes = true - nope = false - ) - p := &Policy{ - Rules: []*rawRule{ - { - Rule: &requestRule{ - domainOrNetworkRule: domainOrNetworkRule{ - matchers: []match.Matcher{testInNetwork(t, "127.0.0.0/8")}, - isSource: []bool{true}, - }, - }, - }, - { - Rule: &requestRule{ - domainOrNetworkRule: domainOrNetworkRule{ - matchers: []match.Matcher{testInNetwork(t, "127.0.0.0/8")}, - isSource: []bool{false}, - }, - Permit: &yes, - }, - }, - { - Rule: &requestRule{ - domainOrNetworkRule: domainOrNetworkRule{ - matchers: []match.Matcher{testInDomain(t, "maze.io", "maze.engineering")}, - }, - Permit: &yes, - }, - }, - { - Rule: &requestRule{ - domainOrNetworkRule: domainOrNetworkRule{ - matchers: []match.Matcher{testInDomain(t, "google.com")}, - }, - Permit: &nope, - }, - }, - }, - } - - r := &http.Request{ - URL: &url.URL{Scheme: "http", Host: "golang.org:80"}, - RemoteAddr: "127.0.0.1:1234", - } - if v := p.PermitRequest(r); v != nil { - t.Errorf("expected request to return no verdict, got %t", *v) - } - - p.Rules[0].Rule.(*requestRule).Permit = &yes - if v := p.PermitRequest(r); v == nil || *v != yes { - t.Errorf("expected request to return %t, %v", yes, v) - } - - r.RemoteAddr = "192.168.1.2:3456" - if v := p.PermitRequest(r); v != nil { - t.Errorf("expected request to return no verdict, got %t", *v) - } - if v := p.PermitIntercept(r); v != nil { - t.Errorf("expected request to return no verdict, got %t", *v) - } - - r.URL.Host = "maze.io" - if v := p.PermitRequest(r); v == nil || *v != yes { - t.Errorf("expected request to return %t, %v", yes, v) - } - - r.URL.Host = "google.com" - if v := p.PermitRequest(r); v == nil || *v != nope { - t.Errorf("expected request to return %t, %v", nope, v) - } - - r.URL.Host = "localhost:80" - if v := p.PermitRequest(r); v == nil || *v != yes { - t.Errorf("expected request to return %t, %v", yes, v) - } -} diff --git a/proxy/policy/rule.go b/proxy/policy/rule.go deleted file mode 100644 index 7cbe559..0000000 --- a/proxy/policy/rule.go +++ /dev/null @@ -1,368 +0,0 @@ -package policy - -import ( - "fmt" - "net" - "net/http" - "strings" - "time" - - "git.maze.io/maze/styx/internal/netutil" - "git.maze.io/maze/styx/proxy/match" - "github.com/google/uuid" - "github.com/hashicorp/hcl/v2" - "github.com/hashicorp/hcl/v2/gohcl" -) - -// Rule is a policy rule. -type Rule interface { - Configure(match.Matchers) error -} - -// InterceptRule can make policy rule decisions on intercept requests. -type InterceptRule interface { - PermitIntercept(r *http.Request) *bool -} - -// RequestRule can make policy rule decisions on HTTP CONNECT requests. -type RequestRule interface { - PermitRequest(r *http.Request) *bool -} - -type rawRule struct { - Type string `hcl:"type,label" json:"type"` - Body hcl.Body `hcl:",remain" json:"-"` - Rule `json:"rule"` -} - -func (r *rawRule) Configure(matchers match.Matchers) (err error) { - switch r.Type { - case "intercept": - r.Rule = new(interceptRule) - case "request": - r.Rule = new(requestRule) - case "days": - r.Rule = new(daysRule) - case "time": - r.Rule = new(timeRule) - case "all": - r.Rule = new(allRule) - default: - return fmt.Errorf("policy: invalid event type %q", r.Type) - } - - if diag := gohcl.DecodeBody(r.Body, nil, r.Rule); diag.HasErrors() { - return err - } - - return r.Rule.Configure(matchers) -} - -type allRule struct { - Rules []*rawRule `hcl:"on,block"` - Permit *bool `hcl:"permit"` -} - -func (r *allRule) Configure(matchers match.Matchers) (err error) { - return -} - -type domainOrNetworkRule struct { - matchers []match.Matcher - isSource []bool -} - -func (r *domainOrNetworkRule) configure(kind string, matchers match.Matchers, domains, sources, targets []string, v any, id *string) (err error) { - var m match.Matcher - for _, domain := range domains { - if m, err = matchers.Get("domain", domain); err != nil { - return fmt.Errorf("%s: unknown domain %q", kind, domain) - } - r.matchers = append(r.matchers, m) - r.isSource = append(r.isSource, false) - } - for _, network := range sources { - if m, err = matchers.Get("network", network); err != nil { - return fmt.Errorf("%s: unknown source network %q", kind, network) - } - r.matchers = append(r.matchers, m) - r.isSource = append(r.isSource, true) - } - for _, network := range targets { - if m, err = matchers.Get("network", network); err != nil { - return fmt.Errorf("%s: unknown target network %q", kind, network) - } - r.matchers = append(r.matchers, m) - r.isSource = append(r.isSource, false) - } - if len(r.matchers) == 0 { - return fmt.Errorf("%s: missing any of domain, source, target", kind) - } - if id != nil { - *id = uuid.NewString() - } - return -} - -func (r *domainOrNetworkRule) matchesRequest(q *http.Request) bool { - for i, m := range r.matchers { - if m, ok := m.(match.Request); ok { - if m.MatchesRequest(q) { - return true - } - } - if m, ok := m.(match.IP); ok { - if r.isSource[i] { - if m.MatchesIP(net.ParseIP(netutil.Host(q.RemoteAddr))) { - return true - } - } else { - var ( - host = netutil.Host(q.URL.Host) - ips []net.IP - ) - if ip := net.ParseIP(host); ip != nil { - ips = append(ips, ip) - } else { - ips, _ = net.LookupIP(host) - } - for _, ip := range ips { - if m.MatchesIP(ip) { - return true - } - } - } - } - } - return false -} - -type interceptRule struct { - ID string `json:"id,omitempty"` - Domain []string `hcl:"domain,optional" json:"domain,omitempty"` - Source []string `hcl:"source,optional" json:"source,omitempty"` - Target []string `hcl:"target,optional" json:"target,omitempty"` - Permit *bool `hcl:"permit" json:"permit"` - domainOrNetworkRule `json:"-"` -} - -func (r *interceptRule) Configure(matchers match.Matchers) (err error) { - return r.configure("intercept", matchers, r.Domain, r.Source, r.Target, r, &r.ID) -} - -func (r *interceptRule) PermitIntercept(q *http.Request) *bool { - if r.matchesRequest(q) { - return r.Permit - } - return nil -} - -type requestRule struct { - ID string `json:"id,omitempty"` - Domain []string `hcl:"domain,optional" json:"domain,omitempty"` - Source []string `hcl:"source,optional" json:"source,omitempty"` - Target []string `hcl:"target,optional" json:"target,omitempty"` - Permit *bool `hcl:"permit" json:"permit"` - domainOrNetworkRule `json:"-"` -} - -func (r *requestRule) Configure(matchers match.Matchers) (err error) { - return r.configure("request", matchers, r.Domain, r.Source, r.Target, r, &r.ID) -} - -func (r *requestRule) PermitRequest(q *http.Request) *bool { - if r.matchesRequest(q) { - return r.Permit - } - return nil -} - -type timeRule struct { - ID string `json:"id,omitempty"` - Time []string `hcl:"time" json:"time"` - Permit *bool `hcl:"permit" json:"permit"` - Body hcl.Body `hcl:",remain" json:"-"` - Rules *Policy `json:"rules"` - Start Time `json:"start"` - End Time `json:"end"` -} - -func (r *timeRule) isActive() bool { - if r == nil { - return true - } - - now := Now() - if r.Start.After(r.End) { // ie: 18:00-06:00 - return now.After(r.Start) || now.Before(r.End) - } - return now.After(r.Start) && now.Before(r.End) -} - -func (r *timeRule) Configure(matchers match.Matchers) (err error) { - if len(r.Time) != 2 { - return fmt.Errorf("invalid time %s, need [start, stop]", r.Time) - } - if r.Start, err = ParseTime(r.Time[0]); err != nil { - return fmt.Errorf("invalid start %q: %w", r.Time[0], err) - } - if r.End, err = ParseTime(r.Time[1]); err != nil { - return fmt.Errorf("invalid end %q: %w", r.Time[1], err) - } - - r.Rules = new(Policy) - if diag := gohcl.DecodeBody(r.Body, nil, r.Rules); diag.HasErrors() { - return diag - } - - if err = r.Rules.Configure(matchers); err != nil { - return - } - r.Rules.Matchers = nil - - if r.ID == "" { - r.ID = uuid.NewString() - } - - return -} - -func (r *timeRule) PermitIntercept(q *http.Request) *bool { - if !r.isActive() { - return nil - } - return r.Rules.PermitIntercept(q) -} - -func (r *timeRule) PermitRequest(q *http.Request) *bool { - if !r.isActive() { - return nil - } - return r.Rules.PermitRequest(q) -} - -type daysRule struct { - ID string `json:"id,omitempty"` - Days string `hcl:"days" json:"days"` - Permit *bool `hcl:"permit" json:"permit"` - Body hcl.Body `hcl:",remain" json:"-"` - Rules *Policy `json:"rules"` - cond []onCond -} - -func (r *daysRule) isActive() bool { - if r == nil || len(r.cond) == 0 { - return true - } - - now := time.Now() - for _, cond := range r.cond { - if cond(now) { - return true - } - } - return false -} - -func (r *daysRule) Configure(matchers match.Matchers) (err error) { - if r.cond, err = parseOnCond(r.Days); err != nil { - return - } - - r.Rules = new(Policy) - if diag := gohcl.DecodeBody(r.Body, nil, r.Rules); diag.HasErrors() { - return diag - } - if err = r.Rules.Configure(matchers); err != nil { - return - } - r.Rules.Matchers = nil - - if r.ID == "" { - r.ID = uuid.NewString() - } - - return -} - -func (r *daysRule) PermitIntercept(q *http.Request) *bool { - if !r.isActive() { - return nil - } - return r.Rules.PermitIntercept(q) -} - -func (r *daysRule) PermitRequest(q *http.Request) *bool { - if !r.isActive() { - return nil - } - return r.Rules.PermitRequest(q) -} - -type onCond func(time.Time) bool - -var weekdays = map[string]time.Weekday{ - "sun": time.Sunday, - "mon": time.Monday, - "tue": time.Tuesday, - "wed": time.Wednesday, - "thu": time.Thursday, - "fri": time.Friday, - "sat": time.Saturday, -} - -func parseOnCond(when string) (conds []onCond, err error) { - for _, spec := range strings.Split(when, ",") { - spec = strings.ToLower(strings.TrimSpace(spec)) - if d, ok := weekdays[spec]; ok { - conds = append(conds, onWeekday(d)) - } else if spec == "weekend" || spec == "weekends" { - conds = append(conds, onWeekend) - } else if spec == "workday" || spec == "workdays" { - conds = append(conds, onWorkday) - } else if strings.ContainsRune(spec, '-') { - var ( - part = strings.SplitN(spec, "-", 2) - from, upto time.Weekday - ok bool - ) - if from, ok = weekdays[part[0]]; !ok { - return nil, fmt.Errorf("on %q: invalid weekday %q", spec, part[0]) - } - if upto, ok = weekdays[part[1]]; !ok { - return nil, fmt.Errorf("on %q: invalid weekday %q", spec, part[1]) - } - if from < upto { - for d := from; d < upto; d++ { - conds = append(conds, onWeekday(d)) - } - } else { - for d := time.Sunday; d < from; d++ { - conds = append(conds, onWeekday(d)) - } - for d := upto; d <= time.Saturday; d++ { - conds = append(conds, onWeekday(d)) - } - } - } else { - return nil, fmt.Errorf("on %q: invalid condition", spec) - } - } - return -} - -func onWeekday(weekday time.Weekday) onCond { - return func(t time.Time) bool { - return t.Weekday() == weekday - } -} - -func onWeekend(t time.Time) bool { - d := t.Weekday() - return d == time.Saturday || d == time.Sunday -} - -func onWorkday(t time.Time) bool { - d := t.Weekday() - return !(d == time.Saturday || d == time.Sunday) -} diff --git a/proxy/policy/time.go b/proxy/policy/time.go deleted file mode 100644 index 957092b..0000000 --- a/proxy/policy/time.go +++ /dev/null @@ -1,53 +0,0 @@ -package policy - -import ( - "fmt" - "time" -) - -type Time struct { - Hour int - Minute int - Second int -} - -func (t Time) Eq(other Time) bool { - return t.Hour == other.Hour && t.Minute == other.Minute && t.Second == other.Second -} - -func (t Time) After(other Time) bool { - return t.Seconds() > other.Seconds() -} - -func (t Time) Before(other Time) bool { - return t.Seconds() < other.Seconds() -} - -func (t Time) Seconds() int { - return t.Hour*3600 + t.Minute*60 + t.Second -} - -func (t Time) MarshalJSON() ([]byte, error) { - return []byte(fmt.Sprintf(`"%02d:%02d:%02d"`, t.Hour, t.Minute, t.Second)), nil -} - -var timeFormats = []string{ - time.TimeOnly, - "15:04", - time.Kitchen, -} - -func Now() Time { - now := time.Now() - return Time{now.Hour(), now.Minute(), now.Second()} -} - -func ParseTime(s string) (t Time, err error) { - var tt time.Time - for _, layout := range timeFormats { - if tt, err = time.Parse(layout, s); err == nil { - return Time{tt.Hour(), tt.Minute(), tt.Second()}, nil - } - } - return Time{}, fmt.Errorf("time: invalid time %q", s) -} diff --git a/proxy/proxy.go b/proxy/proxy.go index 91cdc2d..fe7a14f 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -8,609 +8,498 @@ import ( "errors" "fmt" "io" + "log" "net" "net/http" + "net/url" + "slices" "strings" - "sync" "syscall" "time" - "git.maze.io/maze/styx/internal/log" "git.maze.io/maze/styx/internal/netutil" - "git.maze.io/maze/styx/proxy/mitm" - "git.maze.io/maze/styx/proxy/policy" - "git.maze.io/maze/styx/proxy/resolver" - "git.maze.io/maze/styx/proxy/stats" + "git.maze.io/maze/styx/stats" + "github.com/sirupsen/logrus" ) +// Common HTTP headers. const ( - DefaultListenAddr = ":3128" - DefaultBindAddr = "" - DefaultDialTimeout = 30 * time.Second - DefaultKeepAlivePeriod = 1 * time.Minute -) - -const ( - HeaderAcceptEncoding = "Accept-Encoding" HeaderConnection = "Connection" - HeaderContentLength = "Content-Length" HeaderContentType = "Content-Type" + HeaderDate = "Date" + HeaderForwarded = "Forwarded" + HeaderForwardedFor = "X-Forwarded-For" + HeaderForwardedHost = "X-Forwarded-Host" + HeaderForwardedPort = "X-Forwarded-Port" + HeaderForwardedProto = "X-Forwarded-Proto" + HeaderRealIP = "X-Real-Ip" HeaderUpgrade = "Upgrade" + HeaderVia = "Via" +) + +// Safe defaults. +const ( + DefaultDialTimeout = 15 * time.Second + DefaultIdleTimeout = 10 * time.Second + DefaultWebSocketIdleTimeout = 30 * time.Second ) var ( - ErrClosed = errors.New("proxy: shutdown") - ErrClientCert = errors.New("tls: client certificate requested") + // AccessLog is used for logging requests to the proxy. + AccessLog = logrus.StandardLogger() + + // ServerLog is used for logging server log messages. + ServerLog = logrus.StandardLogger() ) +// Proxy implements a HTTP(S) proxy. type Proxy struct { - addr *net.TCPAddr - bind *net.TCPAddr - resolver resolver.Resolver - transport *http.Transport - dial func(network, address string) (net.Conn, error) - config *Config - authority mitm.Authority - policy *policy.Policy - admin *Admin - stats *stats.Stats - closed chan struct{} - onConnect ConnectHandler - onRequest RequestHandler - onResponse ResponseHandler - onError ErrorHandler + rt http.RoundTripper + dialer map[string]Dialer + connFilter []ConnFilter + requestFilter []RequestFilter + responseFilter []ResponseFilter + dialTimeout time.Duration + idleTimeout time.Duration + webSocketIdleTimeout time.Duration + mux *http.ServeMux } -func New(config *Config, ca mitm.Authority) (*Proxy, error) { - if config == nil { - return nil, errors.New("proxy: config can't be nil") - } - +// New [Proxy] with somewhat sane defaults. +func New() *Proxy { p := &Proxy{ - transport: newTransport(), - config: config, - resolver: resolver.Default, - authority: ca, - policy: config.Policy, - closed: make(chan struct{}), - onConnect: config.ConnectHandler, - onRequest: config.RequestHandler, - onResponse: config.ResponseHandler, - onError: config.ErrorHandler, + dialer: map[string]Dialer{"": defaultDialer{}}, + dialTimeout: DefaultDialTimeout, + idleTimeout: DefaultIdleTimeout, + webSocketIdleTimeout: DefaultWebSocketIdleTimeout, + mux: http.NewServeMux(), } - var err error - if config.Listen == "" { - p.addr, err = net.ResolveTCPAddr("tcp", DefaultBindAddr) - } else { - p.addr, err = net.ResolveTCPAddr("tcp", config.Listen) - } - if err != nil { - return nil, fmt.Errorf("proxy: invalid listen addres: %w", err) - } - if config.Bind != "" { - if p.bind, err = net.ResolveTCPAddr("tcp", config.Bind+":0"); err != nil { - return nil, fmt.Errorf("proxy: invalid bind address: %w", err) - } - } else if config.Interface != "" { - if err = resolveInterfaceAddr(config.Interface); err != nil { - return nil, err - } - } - if p.bind != nil { - /* FIXME - var c *net.TCPConn - if c, err = net.DialTCP("tcp", p.bind, p.bind); err != nil && errors.Is(err, syscall.EADDRNOTAVAIL) { - return nil, fmt.Errorf("proxy: invalid bind address: %w", syscall.EADDRNOTAVAIL) - } else if c != nil { - _ = c.Close() - } - */ - } - if config.Resolver != nil { - p.resolver = config.Resolver - } - - dialTimeout := DefaultDialTimeout - if config.DialTimeout > 0 { - dialTimeout = config.DialTimeout - } - p.dial = (&net.Dialer{ - Timeout: dialTimeout, - KeepAlive: dialTimeout, - LocalAddr: p.bind, - }).Dial - - p.admin = NewAdmin(p) - - if p.stats, err = stats.New(); err != nil { - return nil, err - } - - return p, nil -} - -func newTransport() *http.Transport { - return &http.Transport{ - TLSNextProto: make(map[string]func(authority string, c *tls.Conn) http.RoundTripper), + // Make sure the roundtripper uses our dialers. + p.rt = &http.Transport{ + TLSNextProto: make(map[string]func(string, *tls.Conn) http.RoundTripper), Proxy: http.ProxyFromEnvironment, - TLSHandshakeTimeout: 15 * time.Second, - ExpectContinueTimeout: 5 * time.Second, - } -} - -func (p *Proxy) Close() error { - select { - case <-p.closed: - return ErrClosed - default: - close(p.closed) - return nil - } -} - -func (p *Proxy) Start() error { - l, err := net.ListenTCP("tcp", p.addr) - if err != nil { - return err + TLSHandshakeTimeout: 10 * time.Second, + ExpectContinueTimeout: time.Second, + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + return p.dialer[""].DialContext(ctx, &http.Request{ + URL: &url.URL{ + Scheme: network, + Host: addr, + }, + }) + }, } - go p.Serve(l) - return nil + p.Handle("/stats", stats.Handler(stats.Exposed)) + p.Handle("/stats.json", stats.JSONHandler(stats.Exposed)) + + return p } -func (p *Proxy) Serve(listener net.Listener) error { - defer func() { _ = listener.Close() }() +// Handle installs a [http.Handler] into the internal mux. +func (p *Proxy) Handle(pattern string, handler http.Handler) { + p.mux.Handle(pattern, handler) +} - log.Info().Str("addr", listener.Addr().String()).Msg("proxy server listening") - for { - select { - case <-p.closed: - return nil - default: +// HandleFunc installs a [http.HandlerFunc] into the internal mux. +func (p *Proxy) HandleFunc(pattern string, handler func(http.ResponseWriter, *http.Request)) { + p.mux.HandleFunc(pattern, handler) +} + +// SetDialer specifies a [Dialer] for the specified protocol. The default [Dialer] corresponds +// to an empty string. Only override the default [Dialer] if you know what you are doing. +func (p *Proxy) SetDialer(proto string, dialer Dialer) { + if dialer == nil { + if proto != "" { + delete(p.dialer, proto) } + } else { + p.dialer[proto] = dialer + } +} - c, err := listener.Accept() +// AddConnFilter adds a connection filter to the stack. +func (p *Proxy) AddConnFilter(f ConnFilter) { + if f == nil { + return + } + p.connFilter = append(p.connFilter, f) +} + +// AddRequestFilter adds a request filter to the stack. +func (p *Proxy) AddRequestFilter(f RequestFilter) { + if f == nil { + return + } + p.requestFilter = append(p.requestFilter, f) +} + +// AddResponseFilter adds a response filter to the stack. +func (p *Proxy) AddResponseFilter(f ResponseFilter) { + if f == nil { + return + } + p.responseFilter = append(p.responseFilter, f) +} + +func (p *Proxy) dial(ctx context.Context, req *http.Request) (net.Conn, error) { + d, ok := p.dialer[req.URL.Scheme] + if !ok { + d = p.dialer[""] + } + + return d.DialContext(ctx, req) +} + +// Serve proxied connections on the specified listener. +func (p *Proxy) Serve(l net.Listener) error { + for { + c, err := l.Accept() if err != nil { return err } - - rw := bufio.NewReadWriter(bufio.NewReader(c), bufio.NewWriter(c)) - ctx := newContext(c, rw, nil) - - if c, ok := c.(*net.TCPConn); ok { - _ = c.SetKeepAlive(true) - _ = c.SetKeepAlivePeriod(DefaultKeepAlivePeriod) - } - - go p.handle(ctx) + go p.handle(c) } } -func (p *Proxy) handle(ctx *Context) { - logger := ctx.log() - defer log.OnCloseError(logger.Debug(), ctx.conn) - logger.Info().Str("client", ctx.RemoteAddr().String()).Msg("new client connection") - - last := int64(0) - for { - select { - case <-p.closed: - return - - default: - ses, err := p.handleRequest(ctx) - if ses != nil { - log := ses.log() - log.Info(). - Str("method", ses.request.Method). - Str("url", ses.request.URL.String()). - Str("status", ses.response.Status). - Int64("size", ctx.conn.bytes-last). - Msg("handled request") - - p.stats.AddLog(&stats.Log{ - ClientIP: netutil.Host(ses.request.RemoteAddr), - Request: stats.FromRequest(ses.request), - Response: stats.FromResponse(ses.response).SetSize(ctx.conn.bytes - last), - }) - - last = ctx.conn.bytes - } - if err != nil && !isClosing(err) || (ses != nil && ses.response != nil && ses.response.Close) { - event := logger.Debug() - if ctx.conn.bytes > 0 { - event = event.Int64("size", ctx.conn.bytes) - } - event.Msg("closing client connection") - return - } - } - } -} - -func (p *Proxy) handleRequest(ctx *Context) (ses *Session, err error) { - logger := ctx.log() - - var request *http.Request - if request, err = p.readRequest(ctx); err != nil { - return - } - - ses = newSession(ctx, request) - p.cleanRequest(ses, request) - - logger.Debug().Str("method", request.Method).Str("url", request.URL.String()).Msg("handle request") - - if p.onRequest != nil { - newRequest, newResponse := p.onRequest.HandleRequest(ses) - if newRequest != nil { - logger.Debug().Str("method", newRequest.Method).Str("url", newRequest.URL.String()).Msg("request override") - ses.request = newRequest - } - if newResponse != nil { - logger.Debug().Str("status", newResponse.Status).Msg("response override") - ses.response = newResponse - } - } - - if ses.response == nil { - // WebSocket request - if ses.request.Header.Get(HeaderUpgrade) == "websocket" { - return ses, p.handleTunnel(ses) - } - - cleanHopByHopHeaders(ses.request.Header) - - // Proxy CONNECT request - if ses.request.Method == http.MethodConnect { - return p.handleConnect(ses) - } - - if netutil.Port(ses.request.URL.Host) == p.addr.Port { - // Plain API request - ses.request.URL.Host = ses.request.Host - return ses, p.admin.handleRequest(ses) - - } else if ses.response, err = p.transport.RoundTrip(ses.request); err != nil { - // Plain HTTP request - if p.config.ErrorHandler != nil { - p.config.ErrorHandler.HandleError(ses, err) - } - ses.response = ErrorResponse(ses.request, err) - } - - logger.Debug().Str("status", ses.response.Status).Msg("received response") - cleanHopByHopHeaders(ses.response.Header) - } - - ses.response.Close = true - defer log.OnCloseError(logger.Debug(), ses.response.Body) - return ses, p.writeResponse(ses) -} - -func (p *Proxy) handleConnect(ses *Session) (next *Session, err error) { - next = ses - - logger := ses.log() - logger.Debug().Msgf("connecting to %s", ses.request.URL.Host) - - var c net.Conn - if c, err = p.connect(ses, "tcp", ses.request.URL.Host); err != nil { - logger.Error().Err(err).Msg("connect failed") - if p.onError != nil { - p.onError.HandleError(ses, err) - } - - ses.response = ErrorResponse(ses.request, err) - defer log.OnCloseError(logger.Debug(), ses.response.Body) - _ = p.writeResponse(ses) - - return - } - - defer func() { - if err := c.Close(); err != nil { - if p.onError != nil { - p.onError.HandleError(ses, err) - } - } - }() - - if p.canIntercept(ses.request) { - logger.Debug().Msg("intercepting connection") - ses.response = NewResponse(http.StatusOK, nil, ses.request) - err = p.writeResponse(ses) - log.OnCloseError(logger.Debug(), ses.response.Body) - if err != nil { - return - } - - // Peek first byte - b := make([]byte, 1) - if _, err = io.ReadFull(ses.ctx.rw, b); err != nil { - logger.Error().Err(err).Msg("error peeking CONNECT byte") - return - } - - // Drain buffered bytes - b = append(b, make([]byte, ses.ctx.rw.Reader.Buffered())...) - ses.ctx.rw.Reader.Read(b[1:]) - - r := &connReader{ - Conn: ses.ctx.conn, - Reader: io.MultiReader(bytes.NewBuffer(b), ses.ctx.conn), - } - if b[0] == 22 { // TLS handshake: https://tools.ietf.org/html/rfc5246#section-6.2.1 - secure := tls.Server(r, p.authority.TLSConfig(ses.request.URL.Host)) - if err = secure.Handshake(); err != nil { - logger.Error().Err(err).Msg("error intercepting TLS connection: client handshake failed") - return - } - - rw := bufio.NewReadWriter(bufio.NewReader(secure), bufio.NewWriter(secure)) - ctx := newContext(secure, rw, ses) - return p.handleRequest(ctx) - } - - rw := bufio.NewReadWriter(bufio.NewReader(r), bufio.NewWriter(r)) - ctx := newContext(r, rw, ses) - return p.handleRequest(ctx) - } - - ses.response = NewResponse(http.StatusOK, nil, ses.request) - defer log.OnCloseError(logger.Debug(), ses.response.Body) - ses.response.ContentLength = -1 - if err = p.writeResponse(ses); err != nil { - return - } - - logger.Debug().Msg("established CONNECT tunnel, proxying traffic") - var wait sync.WaitGroup - wait.Go(func() { copyStream(ses, c, ses.ctx.conn) }) - wait.Go(func() { copyStream(ses, ses.ctx.conn, c) }) - wait.Wait() - logger.Debug().Msg("closed CONNECT tunnel") - return -} - -func (p *Proxy) handleTunnel(ses *Session) (err error) { - logger := ses.log() - logger.Debug().Msgf("connecting to %s", ses.request.URL.Host) - - var c net.Conn - if c, err = p.connect(ses, "tcp", ses.request.URL.Host); err != nil { - logger.Error().Err(err).Msg("connect failed") - if p.onError != nil { - p.onError.HandleError(ses, err) - } - - ses.response = ErrorResponse(ses.request, err) - defer log.OnCloseError(logger.Debug(), ses.response.Body) - _ = p.writeResponse(ses) - - return - } - - defer log.OnCloseError(logger.Debug(), c) - - if ses.ctx.IsTLS() { - // Open a TLS client connection - secure := tls.Client(c, &tls.Config{ - ServerName: ses.request.URL.Host, - GetClientCertificate: func(cri *tls.CertificateRequestInfo) (*tls.Certificate, error) { - return nil, ErrClientCert - }, - }) - if err = secure.Handshake(); err != nil { - logger.Error().Err(err).Msg("TLS handshake failed") - return - } - c = secure - } - - if err = ses.request.Write(c); err != nil { - logger.Error().Err(err).Msg("failed to write request") - return - } - - logger.Debug().Msg("established tunnel, proxying traffic") - var wait sync.WaitGroup - wait.Go(func() { copyStream(ses, c, ses.ctx.conn) }) - wait.Go(func() { copyStream(ses, ses.ctx.conn, c) }) - wait.Wait() - logger.Debug().Msg("closed tunnel") - return -} - -func (p *Proxy) canIntercept(request *http.Request) bool { - if permit := p.policy.PermitIntercept(request); permit != nil { - return *permit - } - return true -} - -/* -func (p *Proxy) handleAPIRequest(ses *Session) error { - if ses.request.URL.Path == "/ca.crt" && p.authority != nil { - b := pem.EncodeToMemory(&pem.Block{ - Type: "CERTIFICATE", - Bytes: p.authority.Certificate().Raw, - }) - - ses.response = NewResponse(http.StatusOK, bytes.NewReader(b), ses.request) - defer log.OnCloseError(logger.Debug(), ses.response.Body) - - ses.response.Close = true - ses.response.Header.Set("Content-Type", "application/x-x509-ca-cert") - ses.response.ContentLength = int64(len(b)) - return p.writeResponse(ses) - } - - ses.response = ErrorResponse(ses.request, errors.New("invalid API endpoint")) - defer log.OnCloseError(logger.Debug(), ses.response.Body) - ses.response.Close = true - return p.writeResponse(ses) -} -*/ - -func (p *Proxy) readRequest(ctx *Context) (request *http.Request, err error) { +func (p *Proxy) handle(nc net.Conn) { var ( - done = make(chan *http.Request, 1) - errs = make(chan error, 1) + start = time.Now() + ctx = NewContext(nc).(*proxyContext) + err error ) + defer func() { + if cerr := ctx.Close(); cerr != nil && err == nil { + err = cerr + } - go func() { - r, err := http.ReadRequest(ctx.rw.Reader) - if err != nil { - errs <- err + log := ctx.AccessLogEntry().WithField("duration", time.Since(start)) + if err != nil && !netutil.IsClosing(err) { + log = log.WithError(err) + } + if req := ctx.Request(); req != nil { + log = log.WithFields(logrus.Fields{ + "method": req.Method, + "request": req.URL.String(), + }) + } + if res := ctx.Response(); res != nil { + //countStatus(res.StatusCode) + log.WithFields(logrus.Fields{ + "response": res.StatusCode, + }).Info(res.Status) } else { - done <- r + //countStatus(0) + log.Info("No response") } }() + // Propagate timeouts + ctx.SetIdleTimeout(p.idleTimeout) + + for _, f := range p.connFilter { + fc, err := f.FilterConn(ctx) + if err != nil { + ServerLog.WithField("filter", fmt.Sprintf("%T", f)).WithError(err).Warn("error in conn filter") + _ = nc.Close() + return + } else if fc != nil { + ServerLog.WithField("filter", fmt.Sprintf("%T", f)).Debug("replacing connection from filter") + ctx.Conn = fc + ctx.br = bufio.NewReader(fc) + } + } + + for { + if ctx.isTransparentTLS { + ctx.req = &http.Request{ + Method: http.MethodConnect, + URL: &url.URL{ + Scheme: "tcp", + Host: net.JoinHostPort(ctx.serverName, "443"), + }, + } + } else if ctx.req, err = http.ReadRequest(ctx.Reader()); err != nil { + if !(errors.Is(err, io.EOF) || errors.Is(err, syscall.ECONNRESET)) { + ServerLog.WithError(err).Debug("error reading request") + } + return + } + + if ctx.isTransparent { + // Canonicallize to absolute URL + if ctx.req.URL.Host == "" { + ctx.req.URL.Host = ctx.req.Host + } + if ctx.req.URL.Scheme == "" { + ctx.req.URL.Scheme = "http" + } + ctx.isTransparent = false + } + + for _, f := range p.requestFilter { + newReq, newRes := f.FilterRequest(ctx) + if newReq != nil { + ServerLog.WithFields(logrus.Fields{ + "filter": fmt.Sprintf("%T", f), + "old_method": ctx.req.Method, + "old_url": ctx.req.URL, + "new_method": newReq.Method, + "new_url": newReq.URL, + }).Debug("replacing request from filter") + ctx.req = newReq + } + if newRes != nil { + log := ServerLog.WithFields(logrus.Fields{ + "filter": fmt.Sprintf("%T", f), + "response": newRes.StatusCode, + "status": newRes.Status, + }) + log.Debug("replacing response from filter") + ctx.res = newRes + if err = p.writeResponse(ctx); err != nil { + log.WithError(err).Warn("error overriding repsonse") + } + continue + } + } + + if err = p.handleRequest(ctx); err != nil { + return + } + + // Only once + if ctx.isTransparent || ctx.isTransparentTLS { + return + } + } +} + +func (p *Proxy) handleRequest(ctx *proxyContext) (err error) { + switch { + case ctx.req == nil: + ctx.LogEntry().Warn("Request is nil in handleRequest!?") + return errors.New("proxy: request is nil?") + + case headerContains(ctx.req.Header, HeaderConnection, "upgrade"): + if headerContains(ctx.req.Header, HeaderUpgrade, "websocket") { + return p.serveWebSocket(ctx) + } + ctx.res = NewResponse(http.StatusBadRequest, nil, ctx.req) + return p.writeResponse(ctx) + + case ctx.req.Method == http.MethodConnect: + return p.serveConnect(ctx) + + case ctx.req.URL.IsAbs(): + return p.serveForward(ctx) + + default: + return p.serve(ctx) + } +} + +func (p *Proxy) applyResponseFilter(ctx *proxyContext) { + for _, f := range p.responseFilter { + if newRes := f.FilterResponse(ctx); newRes != nil { + if ctx.res.Body != nil { + _ = ctx.res.Body.Close() + } + ctx.res = newRes + } + } +} + +func (p *Proxy) serve(ctx *proxyContext) (err error) { + var ( + b = new(bytes.Buffer) + cw = ctx.cw + ) + // This is where our response headers etc. are captured + ctx.res = NewResponse(http.StatusOK, nil, ctx.req) + + // This is where our response body is captured + ctx.cw = &countingWriter{writer: b, bytes: ctx.cw.bytes} + + // Pass ServeHTTP call to mux handler(s) + p.mux.ServeHTTP(ctx, ctx.req) + + // Expose body + ctx.res.Body = io.NopCloser(b) + + // Correct headers + if ctx.res.Header.Get(HeaderDate) == "" { + ctx.res.Header.Set(HeaderDate, time.Now().UTC().Format("Mon, 2 Jan 2006 15:04:05")+" GMT") + } + if ctx.res.Header.Get(HeaderContentType) == "" && b.Len() > 0 { + ctx.res.Header.Set(HeaderContentType, "text/html; charset=utf-8") + } + + // Restore writer for the call to writeResponse + ctx.cw = cw + return p.writeResponse(ctx) +} + +func (p *Proxy) serveConnect(ctx *proxyContext) (err error) { + log := ctx.LogEntry() + + // Most browsers expect to get a 200 OK after firing a HTTP CONNECT request; if the upstream + // encounters any errors, we'll inform the client after reading the HTTP request that follows. + if !(ctx.isTransparent || ctx.isTransparentTLS) { + if _, err = io.WriteString(ctx, "HTTP/1.1 200 Connection Established\r\n\r\n"); err != nil { + return + } + } + + switch ctx.req.URL.Scheme { + case "": + ctx.req.URL.Scheme = "tcp" + } + log.WithField("target", ctx.req.URL.String()).Debug("http CONNECT request") + + var ( + timeout, cancel = context.WithTimeout(context.Background(), p.dialTimeout) + c net.Conn + ) + if c, err = p.dial(timeout, ctx.req); err != nil { + cancel() + ctx.res = NewErrorResponse(err, ctx.req) + _ = p.writeResponse(ctx) + _ = ctx.Close() + return fmt.Errorf("proxy: dial %s error: %w", ctx.req.URL, err) + } + cancel() + + ctx.res = NewResponse(http.StatusOK, nil, ctx.req) + srv := NewContext(c).(*proxyContext) + srv.SetIdleTimeout(p.idleTimeout) + return p.multiplex(ctx, srv) +} + +func (p *Proxy) serveForward(ctx *proxyContext) (err error) { + log := ctx.LogEntry() + log.WithField("target", ctx.req.URL.String()).Debug("http forward request") + + if ctx.res, err = p.rt.RoundTrip(ctx.req); err != nil { + // log.Printf("%s forward request error: %v", ctx, err) + ctx.res = NewErrorResponse(err, ctx.req) + _ = p.writeResponse(ctx) + _ = ctx.Close() + return fmt.Errorf("proxy: forward %s error: %w", ctx.req.URL, err) + } + p.applyResponseFilter(ctx) + return p.writeResponse(ctx) +} + +func (p *Proxy) serveWebSocket(ctx *proxyContext) (err error) { + log := ctx.LogEntry().WithField("target", ctx.req.URL.String()) + + switch ctx.req.URL.Scheme { + case "http": + ctx.req.URL.Scheme = "ws" + case "https": + ctx.req.URL.Scheme = "wss" + } + + log.Debug("http websocket request") + var ( + timeout, cancel = context.WithTimeout(context.Background(), p.dialTimeout) + c net.Conn + ) + if c, err = p.dial(timeout, ctx.req); err != nil { + cancel() + ctx.res = NewErrorResponse(err, ctx.req) + _ = p.writeResponse(ctx) + _ = ctx.Close() + return fmt.Errorf("proxy: dial %s error: %w", ctx.req.URL, err) + } + cancel() + + srv := NewContext(c).(*proxyContext) + srv.SetIdleTimeout(p.idleTimeout) + if err = ctx.req.Write(srv); err != nil { + ctx.res = NewErrorResponse(err, ctx.req) + _ = p.writeResponse(ctx) + _ = ctx.Close() + return fmt.Errorf("proxy: failed to write request to upstream: %w", err) + } + + if ctx.res, err = http.ReadResponse(srv.Reader(), ctx.req); err != nil { + ctx.res = NewErrorResponse(err, ctx.req) + _ = p.writeResponse(ctx) + _ = ctx.Close() + return fmt.Errorf("proxy: failed to read response from upstream: %w", err) + } + + log.WithFields(logrus.Fields{ + "response": ctx.res.StatusCode, + "status": ctx.res.Status, + }).Debug("websocket response from upstream") + if err = p.writeResponse(ctx); err != nil { + _ = ctx.Close() + return + } + ctx.SetIdleTimeout(p.webSocketIdleTimeout) + return p.multiplex(ctx, srv) +} + +func (p *Proxy) multiplex(ctx, srv Context) (err error) { + var ( + errs = make(chan error, 1) + done = make(chan struct{}, 1) + ) + go func(errs chan<- error) { + defer close(done) + if _, err := io.Copy(srv, ctx); err != nil { + errs <- err + } + }(errs) + go func(errs chan<- error) { + if _, err := io.Copy(ctx, srv); err != nil { + errs <- err + } + }(errs) + select { - case <-p.closed: - return nil, ErrClosed - case request = <-done: - return case err = <-errs: return + case <-done: + return } } -func (p *Proxy) cleanRequest(ses *Session, request *http.Request) { - if request.URL.Host == "" { - request.URL.Host = request.Host - } - - // Ensure proper URL scheme - if !strings.HasPrefix(request.URL.Scheme, "http") { - request.URL.Scheme = "http" - } - if ses.ctx.IsTLS() { - state := ses.ctx.conn.Conn.(*tls.Conn).ConnectionState() - request.TLS = &state - request.URL.Scheme = "https" - } - - // Ensure proper RemoteAddr - request.RemoteAddr = ses.ctx.RemoteAddr().String() - - // Ensure proper encoding - if request.Header.Get(HeaderAcceptEncoding) != "" { - // We only support gzip - request.Header.Set(HeaderAcceptEncoding, "gzip") - } -} - -func (p *Proxy) writeResponse(ses *Session) (err error) { - log := ses.log() - - if p.onResponse != nil { - response := p.onResponse.HandleResponse(ses) - if response != nil { - log.Debug().Str("status", response.Status).Msg("response override") - ses.response = response +func (p *Proxy) writeResponse(ctx *proxyContext) (err error) { + res := ctx.Response() + for _, f := range p.responseFilter { + if newRes := f.FilterResponse(ctx); newRes != nil { + log.Printf("filter returned response HTTP %s", newRes.Status) + if res.Body != nil { + _ = res.Body.Close() + } + res = newRes } } - - if err = ses.response.Write(ses.ctx); err != nil { - log.Error().Err(err).Msg("error writing response back to client") - } else if err = ses.ctx.Flush(); err != nil { - log.Error().Err(err).Msg("error flushing response back to client") + ServerLog.WithFields(logrus.Fields{ + "close": res.Close, + "header": res.Header, + }).Debug("writing response") + if err = res.Write(ctx); err != nil { + return + } + if res.Close || ctx.res.Close || strings.ToLower(ctx.res.Header.Get(HeaderConnection)) != "keep-alive" { + // Force closing of connection. + if err = ctx.Close(); err != nil { + return + } + return io.EOF } - return } -func (p *Proxy) connect(ses *Session, network, address string) (c net.Conn, err error) { - log := ses.log() - log.Debug().Msgf("connect to %s://%s", network, address) - - if p.onConnect != nil { - if c = p.onConnect.HandleConnect(ses, network, address); c != nil { - log.Debug().Msg("connect override") - return - } - } - - var host, port string - if host, port, err = net.SplitHostPort(address); err != nil { - return - } - - var hosts []string - if hosts, err = p.resolver.Lookup(context.Background(), host); err != nil { - log.Warn().Err(err).Msg("connect failed: DNS lookup error") - return - } - - log.Debug().Str("address", hosts[0]).Msg("connect resolved address") - return p.dial(network, net.JoinHostPort(hosts[0], port)) -} - -var hopByHopHeaders = []string{ - HeaderConnection, - "Keep-Alive", - "Proxy-Authenticate", - "Proxy-Authorization", - "Proxy-Connection", // Non-standard, but required for HTTP/2. - "Te", - "Trailer", - "Transfer-Encoding", - HeaderUpgrade, -} - -func cleanHopByHopHeaders(header http.Header) { - // Additional hop-by-hop headers may be specified in `Connection` headers. - // http://tools.ietf.org/html/draft-ietf-httpbis-p1-messaging-14#section-9.1 - for _, values := range header[HeaderConnection] { - for _, key := range strings.Split(values, ",") { - header.Del(key) - } - } - for _, key := range hopByHopHeaders { - header.Del(key) - } -} - -// copyStream copies data from reader to writer -func copyStream(ses *Session, w io.Writer, r io.Reader) { - log := ses.log() - if _, err := io.Copy(w, r); err != nil && !isClosing(err) { - log.Error().Err(err).Msg("failed CONNECT tunnel") - } else { - log.Debug().Msg("finished copying CONNECT tunnel") - } -} - -func isClosing(err error) bool { - if errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) || errors.Is(err, syscall.ECONNRESET) || err == ErrClosed { - return true - } - if err, ok := err.(net.Error); ok && err.Timeout() { - return true - } - // log.Debug().Msgf("not a closing error %T: %#+v", err, err) - return false -} - -func resolveInterfaceAddr(name string) (err error) { - var iface *net.Interface - if iface, err = net.InterfaceByName(name); err != nil { - return - } - - var addrs []net.Addr - if addrs, err = iface.Addrs(); err != nil { - return - } - - for _, addr := range addrs { - if addr, ok := addr.(*net.IPNet); ok && !addr.IP.IsUnspecified() { - log.Warn().Msgf("addr %T: %s", addr, addr) - } - } - return errors.New("nope; TODO") +func headerContains(h http.Header, k, v string) bool { + vs := h[http.CanonicalHeaderKey(k)] + return slices.ContainsFunc(vs, func(e string) bool { + return strings.EqualFold(e, v) + }) } diff --git a/proxy/resolver/resolver.go b/proxy/resolver/resolver.go deleted file mode 100644 index 9d43c30..0000000 --- a/proxy/resolver/resolver.go +++ /dev/null @@ -1,148 +0,0 @@ -// Package resolver implements a caching DNS resolver -package resolver - -import ( - "context" - "math/rand/v2" - "net" - "strings" - "time" - - "git.maze.io/maze/styx/internal/netutil" - "github.com/hashicorp/golang-lru/v2/expirable" -) - -const ( - DefaultSize = 1024 - DefaultTTL = 5 * time.Minute - DefaultTimeout = 10 * time.Second -) - -var ( - // DefaultConfig are the defaults for the Default resolver. - DefaultConfig = Config{ - Size: DefaultSize, - TTL: DefaultTTL.Seconds(), - Timeout: DefaultTimeout.Seconds(), - } - - // Default resolver. - Default = New(DefaultConfig) -) - -type Resolver interface { - // Lookup returns resolved IPs for given hostname/ips. - Lookup(context.Context, string) ([]string, error) -} - -type netResolver struct { - resolver *net.Resolver - timeout time.Duration - noIPv6 bool - cache *expirable.LRU[string, []string] -} - -type Config struct { - // Size is our cache size in number of entries. - Size int `hcl:"size,optional"` - - // TTL is the cache time to live in seconds. - TTL float64 `hcl:"ttl,optional"` - - // Timeout is the cache timeout in seconds. - Timeout float64 `hcl:"timeout,optional"` - - // Server are alternative DNS servers. - Server []string `hcl:"server,optional"` - - // NoIPv6 disables IPv6 DNS resolution. - NoIPv6 bool `hcl:"noipv6,optional"` -} - -func New(config Config) Resolver { - var ( - size = config.Size - ttl = time.Duration(float64(time.Second) * config.TTL) - timeout = time.Duration(float64(time.Second) * config.Timeout) - ) - if size <= 0 { - size = DefaultSize - } - if ttl <= 0 { - ttl = DefaultTTL - } - if timeout <= 0 { - timeout = 0 - } - - var resolver = new(net.Resolver) - if len(config.Server) > 0 { - var dialer net.Dialer - resolver.Dial = func(ctx context.Context, network, address string) (net.Conn, error) { - server := netutil.EnsurePort(config.Server[rand.IntN(len(config.Server))], "53") - return dialer.DialContext(ctx, network, server) - } - } - - return &netResolver{ - resolver: resolver, - timeout: timeout, - noIPv6: config.NoIPv6, - cache: expirable.NewLRU[string, []string](size, nil, ttl), - } -} - -func (r *netResolver) Lookup(ctx context.Context, host string) ([]string, error) { - host = strings.ToLower(strings.TrimSpace(host)) - if hosts, ok := r.cache.Get(host); ok { - rand.Shuffle(len(hosts), func(i, j int) { - hosts[i], hosts[j] = hosts[j], hosts[i] - }) - return hosts, nil - } - - hosts, err := r.lookup(ctx, host) - if err != nil { - return nil, err - } - r.cache.Add(host, hosts) - return hosts, nil -} - -func (r *netResolver) lookup(ctx context.Context, host string) ([]string, error) { - if r.timeout > 0 { - var cancel func() - ctx, cancel = context.WithTimeout(ctx, r.timeout) - defer cancel() - } - - if net.ParseIP(host) == nil { - addrs, err := r.resolver.LookupHost(ctx, host) - if err != nil { - return nil, err - } - if r.noIPv6 { - var addrs4 []string - for _, addr := range addrs { - if net.ParseIP(addr).To4() != nil { - addrs4 = append(addrs4, addr) - } - } - return addrs4, nil - } - return addrs, nil - } - - addrs, err := r.resolver.LookupIPAddr(ctx, host) - if err != nil { - return nil, err - } - - hosts := make([]string, len(addrs)) - for i, addr := range addrs { - if !r.noIPv6 || addr.IP.To4() != nil { - hosts[i] = addr.IP.String() - } - } - return hosts, nil -} diff --git a/proxy/response.go b/proxy/response.go index a4168d0..539d44b 100644 --- a/proxy/response.go +++ b/proxy/response.go @@ -2,77 +2,58 @@ package proxy import ( "bytes" - "fmt" "io" "net/http" "os" "strconv" - - "git.maze.io/maze/styx/internal/log" ) -func NewResponse(code int, body io.Reader, request *http.Request) *http.Response { - if body == nil { - body = new(bytes.Buffer) - } - - rc, ok := body.(io.ReadCloser) - if !ok { - rc = io.NopCloser(body) - } - - response := &http.Response{ - Status: strconv.Itoa(code) + " " + http.StatusText(code), - StatusCode: code, - Proto: "HTTP/1.1", - ProtoMajor: 1, - ProtoMinor: 1, - Header: make(http.Header), - Body: rc, - Request: request, - } - - if request != nil { - response.Close = request.Close - response.Proto = request.Proto - response.ProtoMajor = request.ProtoMajor - response.ProtoMinor = request.ProtoMinor - } - - return response -} - -type withLen interface { - Len() int -} - -type withSize interface { +type sizer interface { Size() int64 } -func NewJSONResponse(code int, body io.Reader, request *http.Request) *http.Response { - response := NewResponse(code, body, request) - response.Header.Set(HeaderContentType, "application/json") - if s, ok := body.(withLen); ok { - response.Header.Set(HeaderContentLength, strconv.Itoa(s.Len())) - } else if s, ok := body.(withSize); ok { - response.Header.Set(HeaderContentLength, strconv.FormatInt(s.Size(), 10)) - } else { - log.Trace().Str("type", fmt.Sprintf("%T", body)).Msg("can't detemine body size") +// NewResponse prepares a net [http.Response], based on the status code, optional body and +// optional [http.Request]. +func NewResponse(code int, body io.ReadCloser, req *http.Request) *http.Response { + res := &http.Response{ + StatusCode: code, + Header: make(http.Header), + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, } - response.Close = true - return response + + if text := http.StatusText(code); text != "" { + res.Status = strconv.Itoa(code) + " " + text + } else { + res.Status = strconv.Itoa(code) + } + + if body == nil && code >= 400 { + body = io.NopCloser(bytes.NewBufferString(http.StatusText(code))) + } + + res.Body = body + + if s, ok := body.(sizer); ok { + res.ContentLength = s.Size() + } + + if req != nil { + res.Close = req.Close + res.Proto = req.Proto + res.ProtoMajor = req.ProtoMajor + res.ProtoMinor = req.ProtoMinor + } + + return res } -func ErrorResponse(request *http.Request, err error) *http.Response { - response := NewResponse(http.StatusBadGateway, nil, request) +func NewErrorResponse(err error, req *http.Request) *http.Response { switch { - case os.IsNotExist(err): - response.StatusCode = http.StatusNotFound - case os.IsPermission(err): - response.StatusCode = http.StatusForbidden + case os.IsTimeout(err): + return NewResponse(http.StatusGatewayTimeout, nil, req) + default: + return NewResponse(http.StatusBadGateway, nil, req) } - response.Status = http.StatusText(response.StatusCode) - response.Close = true - return response } diff --git a/proxy/session.go b/proxy/session.go deleted file mode 100644 index de7ab6d..0000000 --- a/proxy/session.go +++ /dev/null @@ -1,151 +0,0 @@ -package proxy - -import ( - "bufio" - "crypto/tls" - "encoding/binary" - "encoding/hex" - "math/rand" - "net" - "net/http" - "sync/atomic" - "time" - - "git.maze.io/maze/styx/internal/log" -) - -var seed = rand.NewSource(time.Now().UnixNano()) - -type Context struct { - id int64 - conn *wrappedConn - rw *bufio.ReadWriter - parent *Session - data map[string]any -} - -func newContext(conn net.Conn, rw *bufio.ReadWriter, parent *Session) *Context { - if wrapped, ok := conn.(*wrappedConn); ok { - conn = wrapped.Conn - } - - ctx := &Context{ - id: seed.Int63(), - conn: &wrappedConn{Conn: conn}, - rw: rw, - parent: parent, - data: make(map[string]any), - } - - return ctx -} - -func (ctx *Context) log() log.Logger { - return log.Console.With(). - Str("context", ctx.ID()). - Str("addr", ctx.RemoteAddr().String()). - Logger() -} - -func (ctx *Context) ID() string { - var b [8]byte - binary.BigEndian.PutUint64(b[:], uint64(ctx.id)) - if ctx.parent != nil { - return ctx.parent.ID() + "-" + hex.EncodeToString(b[:]) - } - return hex.EncodeToString(b[:]) -} - -func (ctx *Context) IsTLS() bool { - _, ok := ctx.conn.Conn.(*tls.Conn) - return ok && ctx.parent != nil -} - -func (ctx *Context) RemoteAddr() net.Addr { - if ctx.parent != nil { - return ctx.parent.ctx.RemoteAddr() - } - return ctx.conn.RemoteAddr() -} - -func (ctx *Context) SetDeadline(t time.Time) error { - if ctx.parent != nil { - return ctx.parent.ctx.SetDeadline(t) - } - return ctx.conn.SetDeadline(t) -} - -func (ctx *Context) Set(key string, value any) { - ctx.data[key] = value -} - -func (ctx *Context) Get(key string) (value any, ok bool) { - value, ok = ctx.data[key] - return -} - -func (ctx *Context) Flush() error { - return ctx.rw.Flush() -} - -func (ctx *Context) Write(p []byte) (n int, err error) { - if n, err = ctx.rw.Write(p); n > 0 { - atomic.AddInt64(&ctx.conn.bytes, int64(n)) - } - return -} - -type Session struct { - id int64 - ctx *Context - request *http.Request - response *http.Response - data map[string]any -} - -func newSession(ctx *Context, request *http.Request) *Session { - return &Session{ - id: seed.Int63(), - ctx: ctx, - request: request, - data: make(map[string]any), - } -} - -func (ses *Session) log() log.Logger { - return log.Console.With(). - Str("context", ses.ctx.ID()). - Str("session", ses.ID()). - Str("addr", ses.ctx.RemoteAddr().String()). - Logger() -} - -func (ses *Session) ID() string { - var b [8]byte - binary.BigEndian.PutUint64(b[:], uint64(ses.id)) - return hex.EncodeToString(b[:]) -} - -func (ses *Session) Context() *Context { - return ses.ctx -} - -func (ses *Session) Request() *http.Request { - return ses.request -} - -func (ses *Session) Response() *http.Response { - return ses.response -} - -type wrappedConn struct { - net.Conn - bytes int64 -} - -func (c *wrappedConn) Write(p []byte) (n int, err error) { - if n, err = c.Conn.Write(p); n > 0 { - atomic.AddInt64(&c.bytes, int64(n)) - } - return -} diff --git a/proxy/stats.go b/proxy/stats.go new file mode 100644 index 0000000..2dd8052 --- /dev/null +++ b/proxy/stats.go @@ -0,0 +1,19 @@ +package proxy + +import ( + "expvar" + "strconv" + + "git.maze.io/maze/styx/db/stats" +) + +func countStatus(code int) { + k := "http:status:" + strconv.Itoa(code) + v := expvar.Get(k) + if v == nil { + //v = stats.NewCounter("120s1s", "15m10s", "1h1m", "4w1d", "1y4w") + v = stats.NewCounter(k, stats.Minutely, stats.Hourly, stats.Daily, stats.Yearly) + expvar.Publish(k, v) + } + v.(stats.Metric).Add(1) +} diff --git a/proxy/stats/stats.go b/proxy/stats/stats.go deleted file mode 100644 index 183a898..0000000 --- a/proxy/stats/stats.go +++ /dev/null @@ -1,225 +0,0 @@ -package stats - -import ( - "database/sql" - "database/sql/driver" - "encoding/json" - "fmt" - "net/http" - "os" - "os/user" - "path/filepath" - "time" - - "git.maze.io/maze/styx/internal/log" - _ "github.com/mattn/go-sqlite3" -) - -type Stats struct { - db *sql.DB -} - -func New() (*Stats, error) { - u, err := user.Current() - if err != nil { - return nil, err - } - - path := filepath.Join(u.HomeDir, ".styx", "stats.db") - if err = os.MkdirAll(filepath.Dir(path), 0o750); err != nil { - return nil, err - } - - db, err := sql.Open("sqlite3", path+"?_journal_mode=WAL") - if err != nil { - return nil, err - } - - for _, table := range []string{ - createLog, - createDomainStat, - createStatusStat, - } { - if _, err = db.Exec(table); err != nil { - return nil, err - } - } - - return &Stats{db: db}, nil -} - -func (s *Stats) AddLog(entry *Log) error { - var ( - request []byte - response []byte - err error - ) - if request, err = json.Marshal(entry.Request); err != nil { - return err - } - if response, err = json.Marshal(entry.Response); err != nil { - return err - } - - tx, err := s.db.Begin() - if err != nil { - return err - } - stmt, err := tx.Prepare("insert into styx_log(client_ip, request, response) values(?, ?, ?)") - if err != nil { - return err - } - defer stmt.Close() - if _, err = stmt.Exec(entry.ClientIP, request, response); err != nil { - return err - } - return tx.Commit() -} - -func (s *Stats) QueryLog(offset, limit int) ([]*Log, error) { - if limit == 0 { - limit = 50 - } - - rows, err := s.db.Query("select dt, client_ip, request, response from styx_log limit ?, ?", offset, limit) - if err != nil { - return nil, err - } - defer rows.Close() - - var logs []*Log - for rows.Next() { - var entry = new(Log) - if err = rows.Scan(&entry.Time, &entry.ClientIP, &entry.Request, &entry.Response); err != nil { - return nil, err - } - logs = append(logs, entry) - } - - return logs, nil -} - -type Status struct { - Code int `json:"code"` - Count int `json:"count"` -} - -var timeZero time.Time - -func (s *Stats) QueryStatus(since time.Time) ([]*Status, error) { - if since.Equal(timeZero) { - since = time.Now().Add(-24 * time.Hour) - } - - rows, err := s.db.Query("select response->'status', count(*) from styx_log where dt >= ? group by response->'status' order by response->'status'", since) - if err != nil { - return nil, err - } - - var stats []*Status - for rows.Next() { - var entry = new(Status) - if err = rows.Scan(&entry.Code, &entry.Count); err != nil { - return nil, err - } - stats = append(stats, entry) - } - return stats, nil -} - -const createLog = `CREATE TABLE IF NOT EXISTS styx_log ( - id INT PRIMARY KEY, - dt DATETIME DEFAULT CURRENT_TIMESTAMP, - client_ip TEXT NOT NULL, - request JSONB NOT NULL, - response JSONB NOT NULL -);` - -type Log struct { - Time time.Time `json:"time"` - ClientIP string `json:"client_ip"` - Request *Request `json:"request"` - Response *Response `json:"response"` -} - -type Request struct { - URL string `json:"url"` - Host string `json:"host"` - Method string `json:"method"` - Proto string `json:"proto"` - Header http.Header `json:"header"` -} - -func (r *Request) Scan(value any) error { - switch v := value.(type) { - case string: - return json.Unmarshal([]byte(v), r) - case []byte: - return json.Unmarshal(v, r) - default: - log.Error().Str("type", fmt.Sprintf("%T", value)).Msg("scan request unknown type") - return nil - } -} - -func (r *Request) Value() (driver.Value, error) { - b, err := json.Marshal(r) - return string(b), err -} - -func FromRequest(r *http.Request) *Request { - return &Request{ - URL: r.URL.String(), - Host: r.Host, - Method: r.Method, - Proto: r.Proto, - Header: r.Header, - } -} - -type Response struct { - Status int `json:"status"` - Size int64 `json:"size"` - Header http.Header `json:"header"` -} - -func (r *Response) Scan(value any) error { - switch v := value.(type) { - case string: - return json.Unmarshal([]byte(v), r) - case []byte: - return json.Unmarshal(v, r) - default: - log.Error().Str("type", fmt.Sprintf("%T", value)).Msg("scan response unknown type") - return nil - } -} - -func (r *Response) Value() (driver.Value, error) { - b, err := json.Marshal(r) - return string(b), err -} - -func (r *Response) SetSize(size int64) *Response { - r.Size = size - return r -} - -func FromResponse(r *http.Response) *Response { - return &Response{ - Status: r.StatusCode, - Header: r.Header, - } -} - -const createStatusStat = `CREATE TABLE IF NOT EXISTS styx_stat_status ( - id INT PRIMARY KEY, - dt DATETIME DEFAULT CURRENT_TIMESTAMP, - status INT NOT NULL -);` - -const createDomainStat = `CREATE TABLE IF NOT EXISTS styx_stat_domain ( - id INT PRIMARY KEY, - dt DATETIME DEFAULT CURRENT_TIMESTAMP, - domain TEXT NOT NULL -);` diff --git a/proxy/util.go b/proxy/util.go deleted file mode 100644 index 8f1a4ae..0000000 --- a/proxy/util.go +++ /dev/null @@ -1,16 +0,0 @@ -package proxy - -import ( - "io" - "net" -) - -// connReader is a net.Conn with a separate reader. -type connReader struct { - net.Conn - io.Reader -} - -func (c connReader) Read(p []byte) (int, error) { - return c.Reader.Read(p) -}