Checkpoint

This commit is contained in:
2025-10-01 15:37:55 +02:00
parent 4a60059ff2
commit 03352e3312
31 changed files with 2611 additions and 384 deletions

View File

@@ -13,13 +13,14 @@ import (
"net/http"
"net/url"
"slices"
"strconv"
"strings"
"syscall"
"time"
"git.maze.io/maze/styx/internal/netutil"
"git.maze.io/maze/styx/logger"
"git.maze.io/maze/styx/stats"
"github.com/sirupsen/logrus"
)
// Common HTTP headers.
@@ -32,6 +33,7 @@ const (
HeaderForwardedHost = "X-Forwarded-Host"
HeaderForwardedPort = "X-Forwarded-Port"
HeaderForwardedProto = "X-Forwarded-Proto"
HeaderLocation = "Location"
HeaderRealIP = "X-Real-Ip"
HeaderUpgrade = "Upgrade"
HeaderVia = "Via"
@@ -46,43 +48,111 @@ const (
var (
// AccessLog is used for logging requests to the proxy.
AccessLog = logrus.StandardLogger()
AccessLog = logger.Get()
// ServerLog is used for logging server log messages.
ServerLog = logrus.StandardLogger()
ServerLog = logger.Get()
)
// Proxy implements a HTTP(S) proxy.
type Proxy struct {
rt http.RoundTripper
dialer map[string]Dialer
connFilter []ConnFilter
requestFilter []RequestFilter
responseFilter []ResponseFilter
dialTimeout time.Duration
idleTimeout time.Duration
webSocketIdleTimeout time.Duration
mux *http.ServeMux
// RoundTripper is used to make outbound HTTP requests. It defaults to a [http.Transport]
// with a custom DialContext that uses the configured [Dialer]s.
//
// Only override this if you know what you are doing.
RoundTripper http.RoundTripper
// Dialer is a map of protocol names to [Dialer] implementations. The default [Dialer]
// corresponds to an empty string key.
//
// Only override the default [Dialer] if you know what you are doing.
Dialer map[string]Dialer
// OnConnect is a list of connection filters that are applied in order when a new
// connection is established.
//
// Connection filters can be used to implement custom authentication, logging,
// rate limiting, etc.
//
// Connection filters are applied before any HTTP request is read from the connection.
//
// Connection filters should return a non-nil error if they want to terminate the
// connection. Returning a non-nil [net.Conn] will replace the existing connection
// with the returned one.
//
// Connection filters should not modify the connection in any way (e.g. wrapping it
// in a TLS connection) as this will interfere with the proxy's ability to read
// HTTP requests from the connection.
//
// Connection filters are executed sequentially in the order they are added.
OnConnect []ConnHandler
// OnRequest is a list of request filters that are applied in order when a new
// HTTP request is read from the connection.
//
// Request filters can be used to modify the request, or to return a response
// directly without forwarding the request to the upstream server.
//
// Request filters should return a non-nil error if they want to terminate the
// connection.
//
// Request filters are executed sequentially in the order they are added.
OnRequest []RequestHandler
// OnResponse is a list of response filters that are applied in order when a
// response is received from the upstream server.
//
// Response filters can be used to modify the response before it is sent to
// the client.
//
// Response filters should return a non-nil error if they want to terminate the
// connection.
//
// Response filters are executed sequentially in the order they are added.
OnResponse []ResponseHandler
// OnError is a list of error handlers that are applied in order when an
// error occurs during request processing.
//
// Error handlers can be used to log errors, or to return a custom response
// to the client.
//
// Error handlers should return a non-nil error if they want to terminate the
// connection.
//
// Error handlers are executed sequentially in the order they are added.
OnError []ErrorHandler
// DialTimeout is the timeout for establishing new connections to upstream servers.
DialTimeout time.Duration
// IdleTimeout is the timeout for idle connections.
IdleTimeout time.Duration
// WebSocketIdleTimeout is the timeout for idle WebSocket connections.
WebSocketIdleTimeout time.Duration
mux *http.ServeMux
}
// New [Proxy] with somewhat sane defaults.
func New() *Proxy {
p := &Proxy{
dialer: map[string]Dialer{"": defaultDialer{}},
dialTimeout: DefaultDialTimeout,
idleTimeout: DefaultIdleTimeout,
webSocketIdleTimeout: DefaultWebSocketIdleTimeout,
Dialer: map[string]Dialer{"": defaultDialer{}},
DialTimeout: DefaultDialTimeout,
IdleTimeout: DefaultIdleTimeout,
WebSocketIdleTimeout: DefaultWebSocketIdleTimeout,
mux: http.NewServeMux(),
}
// Make sure the roundtripper uses our dialers.
p.rt = &http.Transport{
p.RoundTripper = &http.Transport{
TLSNextProto: make(map[string]func(string, *tls.Conn) http.RoundTripper),
Proxy: http.ProxyFromEnvironment,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: time.Second,
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return p.dialer[""].DialContext(ctx, &http.Request{
return p.Dialer[""].DialContext(ctx, &http.Request{
URL: &url.URL{
Scheme: network,
Host: addr,
@@ -112,41 +182,17 @@ func (p *Proxy) HandleFunc(pattern string, handler func(http.ResponseWriter, *ht
func (p *Proxy) SetDialer(proto string, dialer Dialer) {
if dialer == nil {
if proto != "" {
delete(p.dialer, proto)
delete(p.Dialer, proto)
}
} else {
p.dialer[proto] = dialer
p.Dialer[proto] = dialer
}
}
// AddConnFilter adds a connection filter to the stack.
func (p *Proxy) AddConnFilter(f ConnFilter) {
if f == nil {
return
}
p.connFilter = append(p.connFilter, f)
}
// AddRequestFilter adds a request filter to the stack.
func (p *Proxy) AddRequestFilter(f RequestFilter) {
if f == nil {
return
}
p.requestFilter = append(p.requestFilter, f)
}
// AddResponseFilter adds a response filter to the stack.
func (p *Proxy) AddResponseFilter(f ResponseFilter) {
if f == nil {
return
}
p.responseFilter = append(p.responseFilter, f)
}
func (p *Proxy) dial(ctx context.Context, req *http.Request) (net.Conn, error) {
d, ok := p.dialer[req.URL.Scheme]
d, ok := p.Dialer[req.URL.Scheme]
if !ok {
d = p.dialer[""]
d = p.Dialer[""]
}
return d.DialContext(ctx, req)
@@ -170,23 +216,32 @@ func (p *Proxy) handle(nc net.Conn) {
err error
)
defer func() {
if cerr := ctx.Close(); cerr != nil && err == nil {
if r := recover(); r != nil {
if err, ok := r.(error); ok {
ctx.LogEntry().Err(err).Warn("Bug in code, recovered from panic!")
}
_ = nc.Close()
}
}()
defer func() {
if cerr := ctx.Close(); cerr != nil && err == nil && !netutil.IsClosing(cerr) {
err = cerr
}
log := ctx.AccessLogEntry().WithField("duration", time.Since(start))
log := ctx.AccessLogEntry().Value("duration", time.Since(start))
if err != nil && !netutil.IsClosing(err) {
log = log.WithError(err)
log = log.Err(err)
}
if req := ctx.Request(); req != nil {
log = log.WithFields(logrus.Fields{
log = log.Values(logger.Values{
"method": req.Method,
"request": req.URL.String(),
})
}
if res := ctx.Response(); res != nil {
//countStatus(res.StatusCode)
log.WithFields(logrus.Fields{
log.Values(logger.Values{
"response": res.StatusCode,
}).Info(res.Status)
} else {
@@ -196,38 +251,42 @@ func (p *Proxy) handle(nc net.Conn) {
}()
// Propagate timeouts
ctx.SetIdleTimeout(p.idleTimeout)
ctx.SetIdleTimeout(p.IdleTimeout)
for _, f := range p.connFilter {
fc, err := f.FilterConn(ctx)
for _, f := range p.OnConnect {
fc, err := f.HandleConn(ctx)
if err != nil {
ServerLog.WithField("filter", fmt.Sprintf("%T", f)).WithError(err).Warn("error in conn filter")
ServerLog.Value("filter", fmt.Sprintf("%T", f)).Err(err).Warn("Error in conn filter")
p.handleError(ctx, err, true)
_ = nc.Close()
return
} else if fc != nil {
ServerLog.WithField("filter", fmt.Sprintf("%T", f)).Debug("replacing connection from filter")
ServerLog.Value("filter", fmt.Sprintf("%T", f)).Debug("Replacing connection from filter")
ctx.Conn = fc
ctx.br = bufio.NewReader(fc)
}
}
for {
if ctx.isTransparentTLS {
if ctx.transparentTLS {
ctx.req = &http.Request{
Method: http.MethodConnect,
URL: &url.URL{
Scheme: "tcp",
Host: net.JoinHostPort(ctx.serverName, "443"),
},
Method: http.MethodConnect,
URL: &url.URL{Scheme: "tcp", Host: net.JoinHostPort(ctx.serverName, strconv.Itoa(ctx.transparent))},
Host: net.JoinHostPort(ctx.serverName, strconv.Itoa(ctx.transparent)),
Proto: "HTTP/1.1",
ProtoMajor: 1,
ProtoMinor: 1,
Close: true,
}
} else if ctx.req, err = http.ReadRequest(ctx.Reader()); err != nil {
if !(errors.Is(err, io.EOF) || errors.Is(err, syscall.ECONNRESET)) {
ServerLog.WithError(err).Debug("error reading request")
ServerLog.Err(err).Debug("Error reading request")
}
p.handleError(ctx, err, true)
return
}
if ctx.isTransparent {
if ctx.transparent > 0 {
// Canonicallize to absolute URL
if ctx.req.URL.Host == "" {
ctx.req.URL.Host = ctx.req.Host
@@ -235,47 +294,68 @@ func (p *Proxy) handle(nc net.Conn) {
if ctx.req.URL.Scheme == "" {
ctx.req.URL.Scheme = "http"
}
ctx.isTransparent = false
ctx.transparent = 0
}
for _, f := range p.requestFilter {
newReq, newRes := f.FilterRequest(ctx)
for _, f := range p.OnRequest {
newReq, newRes := f.HandleRequest(ctx)
if newReq != nil {
ServerLog.WithFields(logrus.Fields{
ServerLog.Values(logger.Values{
"filter": fmt.Sprintf("%T", f),
"old_method": ctx.req.Method,
"old_url": ctx.req.URL,
"new_method": newReq.Method,
"new_url": newReq.URL,
}).Debug("replacing request from filter")
}).Debug("Replacing request from filter")
ctx.req = newReq
}
if newRes != nil {
log := ServerLog.WithFields(logrus.Fields{
log := ServerLog.Values(logger.Values{
"filter": fmt.Sprintf("%T", f),
"response": newRes.StatusCode,
"status": newRes.Status,
})
log.Debug("replacing response from filter")
log.Debug("Replacing response from filter")
ctx.res = newRes
if err = p.writeResponse(ctx); err != nil {
log.WithError(err).Warn("error overriding repsonse")
if netutil.IsClosing(err) {
return
}
log.Err(err).Warn("Error overriding repsonse")
}
continue
}
}
if err = p.handleRequest(ctx); err != nil {
p.handleError(ctx, err, true)
return
}
// Only once
if ctx.isTransparent || ctx.isTransparentTLS {
if ctx.transparent > 0 || ctx.transparentTLS || ctx.req.Method == http.MethodConnect {
return
}
}
}
func (p *Proxy) handleError(ctx *proxyContext, err error, sendResponse bool) {
res := ctx.Response()
if res == nil && sendResponse {
res = NewErrorResponse(err, ctx.Request())
}
for _, f := range p.OnError {
if newRes := f.HandleError(ctx, err); newRes != nil {
res = newRes
}
}
if sendResponse && res != nil {
if werr := p.writeResponse(ctx); werr != nil && !netutil.IsClosing(err) {
ServerLog.Err(werr).Warn("Error writing error response")
}
}
}
func (p *Proxy) handleRequest(ctx *proxyContext) (err error) {
switch {
case ctx.req == nil:
@@ -300,9 +380,9 @@ func (p *Proxy) handleRequest(ctx *proxyContext) (err error) {
}
}
func (p *Proxy) applyResponseFilter(ctx *proxyContext) {
for _, f := range p.responseFilter {
if newRes := f.FilterResponse(ctx); newRes != nil {
func (p *Proxy) applyResponseHandler(ctx *proxyContext) {
for _, f := range p.OnResponse {
if newRes := f.HandleResponse(ctx); newRes != nil {
if ctx.res.Body != nil {
_ = ctx.res.Body.Close()
}
@@ -346,7 +426,7 @@ func (p *Proxy) serveConnect(ctx *proxyContext) (err error) {
// Most browsers expect to get a 200 OK after firing a HTTP CONNECT request; if the upstream
// encounters any errors, we'll inform the client after reading the HTTP request that follows.
if !(ctx.isTransparent || ctx.isTransparentTLS) {
if !(ctx.transparent > 0 || ctx.transparentTLS) {
if _, err = io.WriteString(ctx, "HTTP/1.1 200 Connection Established\r\n\r\n"); err != nil {
return
}
@@ -356,10 +436,10 @@ func (p *Proxy) serveConnect(ctx *proxyContext) (err error) {
case "":
ctx.req.URL.Scheme = "tcp"
}
log.WithField("target", ctx.req.URL.String()).Debug("http CONNECT request")
log.Value("target", ctx.req.URL.String()).Debugf("%s CONNECT request", ctx.req.Proto)
var (
timeout, cancel = context.WithTimeout(context.Background(), p.dialTimeout)
timeout, cancel = context.WithTimeout(context.Background(), p.DialTimeout)
c net.Conn
)
if c, err = p.dial(timeout, ctx.req); err != nil {
@@ -373,27 +453,27 @@ func (p *Proxy) serveConnect(ctx *proxyContext) (err error) {
ctx.res = NewResponse(http.StatusOK, nil, ctx.req)
srv := NewContext(c).(*proxyContext)
srv.SetIdleTimeout(p.idleTimeout)
srv.SetIdleTimeout(p.IdleTimeout)
return p.multiplex(ctx, srv)
}
func (p *Proxy) serveForward(ctx *proxyContext) (err error) {
log := ctx.LogEntry()
log.WithField("target", ctx.req.URL.String()).Debug("http forward request")
log.Value("target", ctx.req.URL.String()).Debugf("%s forward request", ctx.req.Proto)
if ctx.res, err = p.rt.RoundTrip(ctx.req); err != nil {
if ctx.res, err = p.RoundTripper.RoundTrip(ctx.req); err != nil {
// log.Printf("%s forward request error: %v", ctx, err)
ctx.res = NewErrorResponse(err, ctx.req)
_ = p.writeResponse(ctx)
_ = ctx.Close()
return fmt.Errorf("proxy: forward %s error: %w", ctx.req.URL, err)
}
p.applyResponseFilter(ctx)
p.applyResponseHandler(ctx)
return p.writeResponse(ctx)
}
func (p *Proxy) serveWebSocket(ctx *proxyContext) (err error) {
log := ctx.LogEntry().WithField("target", ctx.req.URL.String())
log := ctx.LogEntry().Value("target", ctx.req.URL.String())
switch ctx.req.URL.Scheme {
case "http":
@@ -402,9 +482,9 @@ func (p *Proxy) serveWebSocket(ctx *proxyContext) (err error) {
ctx.req.URL.Scheme = "wss"
}
log.Debug("http websocket request")
log.Debugf("%s websocket request", ctx.req.Proto)
var (
timeout, cancel = context.WithTimeout(context.Background(), p.dialTimeout)
timeout, cancel = context.WithTimeout(context.Background(), p.DialTimeout)
c net.Conn
)
if c, err = p.dial(timeout, ctx.req); err != nil {
@@ -417,7 +497,7 @@ func (p *Proxy) serveWebSocket(ctx *proxyContext) (err error) {
cancel()
srv := NewContext(c).(*proxyContext)
srv.SetIdleTimeout(p.idleTimeout)
srv.SetIdleTimeout(p.IdleTimeout)
if err = ctx.req.Write(srv); err != nil {
ctx.res = NewErrorResponse(err, ctx.req)
_ = p.writeResponse(ctx)
@@ -432,15 +512,15 @@ func (p *Proxy) serveWebSocket(ctx *proxyContext) (err error) {
return fmt.Errorf("proxy: failed to read response from upstream: %w", err)
}
log.WithFields(logrus.Fields{
log.Values(logger.Values{
"response": ctx.res.StatusCode,
"status": ctx.res.Status,
}).Debug("websocket response from upstream")
}).Debug("WebSocket response from upstream")
if err = p.writeResponse(ctx); err != nil {
_ = ctx.Close()
return
}
ctx.SetIdleTimeout(p.webSocketIdleTimeout)
ctx.SetIdleTimeout(p.WebSocketIdleTimeout)
return p.multiplex(ctx, srv)
}
@@ -471,19 +551,19 @@ func (p *Proxy) multiplex(ctx, srv Context) (err error) {
func (p *Proxy) writeResponse(ctx *proxyContext) (err error) {
res := ctx.Response()
for _, f := range p.responseFilter {
if newRes := f.FilterResponse(ctx); newRes != nil {
log.Printf("filter returned response HTTP %s", newRes.Status)
for _, f := range p.OnResponse {
if newRes := f.HandleResponse(ctx); newRes != nil {
log.Printf("Filter returned response HTTP %s", newRes.Status)
if res.Body != nil {
_ = res.Body.Close()
}
res = newRes
}
}
ServerLog.WithFields(logrus.Fields{
ServerLog.Values(logger.Values{
"close": res.Close,
"header": res.Header,
}).Debug("writing response")
}).Debug("Writing response")
if err = res.Write(ctx); err != nil {
return
}