Files
styx/proxy/proxy.go
2025-10-01 15:37:55 +02:00

586 lines
16 KiB
Go

package proxy
import (
"bufio"
"bytes"
"context"
"crypto/tls"
"errors"
"fmt"
"io"
"log"
"net"
"net/http"
"net/url"
"slices"
"strconv"
"strings"
"syscall"
"time"
"git.maze.io/maze/styx/internal/netutil"
"git.maze.io/maze/styx/logger"
"git.maze.io/maze/styx/stats"
)
// Common HTTP headers.
const (
	// Connection management.
	HeaderConnection = "Connection"
	HeaderUpgrade    = "Upgrade"

	// Message metadata.
	HeaderContentType = "Content-Type"
	HeaderDate        = "Date"
	HeaderLocation    = "Location"

	// Forwarding and client-address information.
	HeaderForwarded      = "Forwarded"
	HeaderForwardedFor   = "X-Forwarded-For"
	HeaderForwardedHost  = "X-Forwarded-Host"
	HeaderForwardedPort  = "X-Forwarded-Port"
	HeaderForwardedProto = "X-Forwarded-Proto"
	HeaderRealIP         = "X-Real-Ip"
	HeaderVia            = "Via"
)

// Safe defaults, installed by New.
const (
	// DefaultDialTimeout bounds how long dialing an upstream may take.
	DefaultDialTimeout time.Duration = 15 * time.Second
	// DefaultIdleTimeout is the idle timeout applied to ordinary connections.
	DefaultIdleTimeout time.Duration = 10 * time.Second
	// DefaultWebSocketIdleTimeout is the idle timeout applied once a WebSocket is tunnelled.
	DefaultWebSocketIdleTimeout time.Duration = 30 * time.Second
)
// AccessLog is used for logging requests to the proxy.
var AccessLog = logger.Get()

// ServerLog is used for logging server log messages.
var ServerLog = logger.Get()
// Proxy implements a HTTP(S) proxy.
//
// The zero value is not ready for use (for one, the internal mux is only
// created by [New]); construct proxies with [New] and then adjust the fields.
type Proxy struct {
	// RoundTripper is used to make outbound HTTP requests for requests that
	// carry an absolute URL. [New] sets it to an [http.Transport] whose
	// DialContext always goes through the default [Dialer] (the entry under
	// the empty-string key), looked up each time a connection is dialed.
	//
	// Only override this if you know what you are doing.
	RoundTripper http.RoundTripper
	// Dialer maps protocol names (request URL schemes) to [Dialer]
	// implementations. The default [Dialer] corresponds to an empty string key
	// and is used for schemes that have no entry of their own.
	//
	// Only override the default [Dialer] if you know what you are doing.
	Dialer map[string]Dialer
	// OnConnect is a list of connection filters that are applied in order when a new
	// connection is established.
	//
	// Connection filters can be used to implement custom authentication, logging,
	// rate limiting, etc.
	//
	// Connection filters are applied before any HTTP request is read from the connection.
	//
	// A filter that returns a non-nil error terminates the connection: the error
	// goes through the proxy's error handling (see OnError) and the connection
	// is closed. Returning a non-nil [net.Conn] replaces the existing connection
	// with the returned one; subsequent requests are read from it.
	//
	// Connection filters should not modify the connection in any way (e.g. wrapping it
	// in a TLS connection) as this will interfere with the proxy's ability to read
	// HTTP requests from the connection.
	//
	// Connection filters are executed sequentially in the order they are added.
	OnConnect []ConnHandler
	// OnRequest is a list of request filters that are applied in order when a new
	// HTTP request is read from the connection.
	//
	// Request filters can be used to modify the request, or to return a response
	// directly without forwarding the request to the upstream server.
	//
	// Each filter returns an optional replacement request and an optional
	// response. A non-nil request replaces the current request; a non-nil
	// response is written to the client.
	//
	// Request filters are executed sequentially in the order they are added.
	OnRequest []RequestHandler
	// OnResponse is a list of response filters that are applied in order when a
	// response is about to be sent to the client.
	//
	// Response filters can be used to modify the response before it is sent to
	// the client.
	//
	// Each filter returns an optional replacement response. A non-nil value
	// replaces the current response, whose body is closed.
	//
	// Response filters are executed sequentially in the order they are added.
	OnResponse []ResponseHandler
	// OnError is a list of error handlers that are applied in order when an
	// error occurs during request processing.
	//
	// Error handlers can be used to log errors, or to return a custom response
	// to the client.
	//
	// Each handler receives the error and returns an optional response; a
	// non-nil value replaces the response chosen for the client.
	//
	// Error handlers are executed sequentially in the order they are added.
	OnError []ErrorHandler
	// DialTimeout bounds dialing the upstream for CONNECT tunnels and
	// WebSocket upgrades.
	DialTimeout time.Duration
	// IdleTimeout is the idle timeout applied to client connections and to
	// the upstream connections opened for tunnels and WebSocket upgrades.
	IdleTimeout time.Duration
	// WebSocketIdleTimeout is the idle timeout applied to the client
	// connection once a WebSocket upgrade response has been relayed.
	WebSocketIdleTimeout time.Duration
	// mux serves requests addressed to the proxy itself (non-absolute URLs).
	mux *http.ServeMux
}
// New returns a [Proxy] with somewhat sane defaults.
//
// The result has:
//   - the default [Dialer] under the empty-string key,
//   - the Default* timeouts,
//   - an [http.Transport] for forwarded requests that dials through the
//     default [Dialer],
//   - the /stats and /stats.json handlers on the internal mux.
func New() *Proxy {
	p := new(Proxy)
	p.Dialer = map[string]Dialer{"": defaultDialer{}}
	p.DialTimeout = DefaultDialTimeout
	p.IdleTimeout = DefaultIdleTimeout
	p.WebSocketIdleTimeout = DefaultWebSocketIdleTimeout
	p.mux = http.NewServeMux()

	// Every transport dial goes through the default dialer. The map lookup
	// happens per dial, so a later SetDialer("", ...) is honoured.
	dialDefault := func(ctx context.Context, network, addr string) (net.Conn, error) {
		target := &url.URL{Scheme: network, Host: addr}
		return p.Dialer[""].DialContext(ctx, &http.Request{URL: target})
	}

	p.RoundTripper = &http.Transport{
		// A non-nil (even empty) TLSNextProto keeps HTTP/2 disabled.
		TLSNextProto:          make(map[string]func(string, *tls.Conn) http.RoundTripper),
		Proxy:                 http.ProxyFromEnvironment,
		DialContext:           dialDefault,
		TLSHandshakeTimeout:   10 * time.Second,
		ExpectContinueTimeout: time.Second,
	}

	p.Handle("/stats", stats.Handler(stats.Exposed))
	p.Handle("/stats.json", stats.JSONHandler(stats.Exposed))
	return p
}
// Handle registers handler for pattern on the proxy's internal
// [http.ServeMux], which answers requests addressed to the proxy itself.
func (p *Proxy) Handle(pattern string, handler http.Handler) {
	mux := p.mux
	mux.Handle(pattern, handler)
}
// HandleFunc registers the handler function for pattern on the proxy's
// internal [http.ServeMux]; it is the function-valued form of [Proxy.Handle].
func (p *Proxy) HandleFunc(pattern string, handler func(http.ResponseWriter, *http.Request)) {
	mux := p.mux
	mux.HandleFunc(pattern, handler)
}
// SetDialer specifies a [Dialer] for the specified protocol. The default [Dialer] corresponds
// to an empty string. Only override the default [Dialer] if you know what you are doing.
//
// Passing a nil dialer removes the entry for proto; the default [Dialer] itself
// cannot be removed this way. The dialer map is created on demand, so calling
// SetDialer on a Proxy whose Dialer map is nil does not panic.
func (p *Proxy) SetDialer(proto string, dialer Dialer) {
	if dialer == nil {
		if proto != "" {
			delete(p.Dialer, proto) // deleting from a nil map is a no-op
		}
		return
	}
	if p.Dialer == nil {
		p.Dialer = make(map[string]Dialer)
	}
	p.Dialer[proto] = dialer
}
// dial connects to the target of req using the [Dialer] registered for
// req.URL.Scheme. If there is none, it falls back to the default [Dialer]
// (empty-string key).
//
// It returns an error rather than panicking on a nil interface when no usable
// dialer is configured.
func (p *Proxy) dial(ctx context.Context, req *http.Request) (net.Conn, error) {
	d, ok := p.Dialer[req.URL.Scheme]
	if !ok || d == nil {
		d = p.Dialer[""]
	}
	if d == nil {
		return nil, fmt.Errorf("proxy: no dialer configured for scheme %q", req.URL.Scheme)
	}
	return d.DialContext(ctx, req)
}
// Serve accepts connections on l and serves each one on its own goroutine.
//
// It only returns when Accept fails, and then returns that error.
func (p *Proxy) Serve(l net.Listener) (err error) {
	var conn net.Conn
	for err == nil {
		if conn, err = l.Accept(); err == nil {
			go p.handle(conn)
		}
	}
	return err
}
// handle serves one client connection from accept until it is closed.
//
// It applies the OnConnect filters once, then reads requests in a loop. Each
// request passes through the OnRequest filters and is then dispatched by
// handleRequest, unless a filter already answered it.
//
// The loop ends when:
//   - reading or serving fails, or
//   - the connection was handed over to a tunnel (CONNECT, transparent TLS)
//     or to an upgraded protocol (101 response).
//
// A deferred access-log entry records the duration, the last request and
// response, and the error stored in err, if any.
func (p *Proxy) handle(nc net.Conn) {
	var (
		start = time.Now()
		ctx   = NewContext(nc).(*proxyContext)
		err   error
	)
	defer func() {
		if r := recover(); r != nil {
			perr, ok := r.(error)
			if !ok {
				// Panics with non-error values (strings, ...) must be logged too.
				perr = fmt.Errorf("panic: %v", r)
			}
			ctx.LogEntry().Err(perr).Warn("Bug in code, recovered from panic!")
			_ = nc.Close()
		}
	}()
	defer func() {
		if cerr := ctx.Close(); cerr != nil && err == nil && !netutil.IsClosing(cerr) {
			err = cerr
		}
		log := ctx.AccessLogEntry().Value("duration", time.Since(start))
		if err != nil && !netutil.IsClosing(err) {
			log = log.Err(err)
		}
		if req := ctx.Request(); req != nil {
			log = log.Values(logger.Values{
				"method":  req.Method,
				"request": req.URL.String(),
			})
		}
		if res := ctx.Response(); res != nil {
			//countStatus(res.StatusCode)
			log.Values(logger.Values{
				"response": res.StatusCode,
			}).Info(res.Status)
		} else {
			//countStatus(0)
			log.Info("No response")
		}
	}()

	// Propagate timeouts
	ctx.SetIdleTimeout(p.IdleTimeout)

	for _, f := range p.OnConnect {
		fc, ferr := f.HandleConn(ctx)
		if ferr != nil {
			// Assign (instead of shadowing) err so the access log reports the rejection.
			err = ferr
			ServerLog.Value("filter", fmt.Sprintf("%T", f)).Err(err).Warn("Error in conn filter")
			p.handleError(ctx, err, true)
			_ = nc.Close()
			return
		}
		if fc != nil {
			ServerLog.Value("filter", fmt.Sprintf("%T", f)).Debug("Replacing connection from filter")
			ctx.Conn = fc
			ctx.br = bufio.NewReader(fc)
		}
	}

	// handedOver reports whether the connection stops carrying HTTP/1.x
	// requests after the current exchange. That is the case for tunnels
	// (CONNECT, transparent TLS) and switched protocols (101, e.g. WebSocket).
	handedOver := func() bool {
		return ctx.transparentTLS ||
			ctx.req.Method == http.MethodConnect ||
			(ctx.res != nil && ctx.res.StatusCode == http.StatusSwitchingProtocols)
	}

	// ctx.transparent is reset once the first request has been canonicalised.
	// Every request on a transparently intercepted connection arrives in
	// origin-form, though, so remember the mode for the whole connection.
	transparent := ctx.transparent > 0

requests:
	for {
		if ctx.transparentTLS {
			ctx.req = &http.Request{
				Method:     http.MethodConnect,
				URL:        &url.URL{Scheme: "tcp", Host: net.JoinHostPort(ctx.serverName, strconv.Itoa(ctx.transparent))},
				Host:       net.JoinHostPort(ctx.serverName, strconv.Itoa(ctx.transparent)),
				Proto:      "HTTP/1.1",
				ProtoMajor: 1,
				ProtoMinor: 1,
				Close:      true,
			}
		} else if ctx.req, err = http.ReadRequest(ctx.Reader()); err != nil {
			disconnected := errors.Is(err, io.EOF) || errors.Is(err, syscall.ECONNRESET)
			if !disconnected {
				ServerLog.Err(err).Debug("Error reading request")
			}
			if disconnected || netutil.IsClosing(err) {
				// The client went away between requests; there is nobody to answer.
				p.handleError(ctx, err, false)
				return
			}
			// Any response still attached to ctx belongs to the previous exchange.
			// Answer the unreadable request with a fresh error response instead of
			// replaying that one.
			ctx.res = NewErrorResponse(err, ctx.Request())
			p.handleError(ctx, err, true)
			return
		}

		if transparent {
			// Canonicalize to absolute URL
			if ctx.req.URL.Host == "" {
				ctx.req.URL.Host = ctx.req.Host
			}
			if ctx.req.URL.Scheme == "" {
				ctx.req.URL.Scheme = "http"
			}
			ctx.transparent = 0
		}

		for _, f := range p.OnRequest {
			newReq, newRes := f.HandleRequest(ctx)
			if newReq != nil {
				ServerLog.Values(logger.Values{
					"filter":     fmt.Sprintf("%T", f),
					"old_method": ctx.req.Method,
					"old_url":    ctx.req.URL,
					"new_method": newReq.Method,
					"new_url":    newReq.URL,
				}).Debug("Replacing request from filter")
				ctx.req = newReq
			}
			if newRes == nil {
				continue
			}
			log := ServerLog.Values(logger.Values{
				"filter":   fmt.Sprintf("%T", f),
				"response": newRes.StatusCode,
				"status":   newRes.Status,
			})
			log.Debug("Replacing response from filter")
			ctx.res = newRes
			if err = p.writeResponse(ctx); err != nil {
				// writeResponse returns io.EOF after closing the connection on purpose.
				if !errors.Is(err, io.EOF) && !netutil.IsClosing(err) {
					log.Err(err).Warn("Error overriding response")
				}
				return
			}
			if handedOver() {
				return
			}
			// The filter answered the request: skip the remaining filters and the
			// upstream, and wait for the next request on this connection.
			continue requests
		}

		if err = p.handleRequest(ctx); err != nil {
			p.handleError(ctx, err, true)
			return
		}
		// Only once: tunnels and upgraded protocols own the connection from here on.
		if handedOver() {
			return
		}
	}
}
// handleError passes err to the OnError handlers and, if sendResponse is set,
// writes a response for it to the client.
//
// The response starts out as the one already attached to ctx. If there is
// none, a fresh error response is built from err. Any OnError handler may
// replace it.
//
// The chosen response is stored on ctx before writing, because writeResponse
// always writes ctx's response. Without that, a nil ctx response would be
// dereferenced and handler replacements would be dropped.
//
// NOTE(review): if ctx's response was already written by the failing code
// path, it is written a second time here.
func (p *Proxy) handleError(ctx *proxyContext, err error, sendResponse bool) {
	res := ctx.Response()
	if res == nil && sendResponse {
		res = NewErrorResponse(err, ctx.Request())
	}
	for _, f := range p.OnError {
		if newRes := f.HandleError(ctx, err); newRes != nil {
			res = newRes
		}
	}
	if !sendResponse || res == nil {
		return
	}
	ctx.res = res
	// io.EOF means writeResponse closed the connection deliberately after writing.
	if werr := p.writeResponse(ctx); werr != nil && !errors.Is(werr, io.EOF) && !netutil.IsClosing(werr) {
		ServerLog.Err(werr).Warn("Error writing error response")
	}
}
// handleRequest dispatches the current request on ctx, checking in order:
//   - upgrade requests: WebSocket upgrades are proxied; any other upgrade
//     gets a 400 response;
//   - CONNECT requests: become tunnels;
//   - absolute URLs: are forwarded upstream;
//   - anything else: is served by the internal mux.
//
// A nil request is reported as an error.
func (p *Proxy) handleRequest(ctx *proxyContext) (err error) {
	req := ctx.req
	if req == nil {
		ctx.LogEntry().Warn("Request is nil in handleRequest!?")
		return errors.New("proxy: request is nil?")
	}

	if headerContains(req.Header, HeaderConnection, "upgrade") {
		if !headerContains(req.Header, HeaderUpgrade, "websocket") {
			ctx.res = NewResponse(http.StatusBadRequest, nil, req)
			return p.writeResponse(ctx)
		}
		return p.serveWebSocket(ctx)
	}

	if req.Method == http.MethodConnect {
		return p.serveConnect(ctx)
	}
	if req.URL.IsAbs() {
		return p.serveForward(ctx)
	}
	return p.serve(ctx)
}
// applyResponseHandler runs the OnResponse filters over ctx's response, in order.
//
// A non-nil filter result replaces ctx's response, so later filters see it,
// and the body of the replaced response is closed. A filter that returns the
// very response it was given (modified in place) keeps it intact; closing its
// body would break the response about to be sent.
func (p *Proxy) applyResponseHandler(ctx *proxyContext) {
	for _, f := range p.OnResponse {
		newRes := f.HandleResponse(ctx)
		if newRes == nil || newRes == ctx.res {
			continue
		}
		if ctx.res != nil && ctx.res.Body != nil {
			_ = ctx.res.Body.Close()
		}
		ctx.res = newRes
	}
}
// serve answers a request addressed to the proxy itself using the internal mux.
//
// The handler's output is captured in a buffer by temporarily swapping ctx's
// counting writer. It is then sent as a complete response with a known
// Content-Length.
//
// As net/http's own ResponseWriter does, a missing Date header is filled in
// (IMF-fixdate via [http.TimeFormat]) and a missing Content-Type is sniffed
// from the body with [http.DetectContentType].
func (p *Proxy) serve(ctx *proxyContext) (err error) {
	var (
		body = new(bytes.Buffer)
		cw   = ctx.cw
	)
	// This is where our response headers etc. are captured
	ctx.res = NewResponse(http.StatusOK, nil, ctx.req)
	// This is where our response body is captured
	ctx.cw = &countingWriter{writer: body, bytes: ctx.cw.bytes}
	// Pass ServeHTTP call to mux handler(s)
	p.mux.ServeHTTP(ctx, ctx.req)
	// Restore writer for the call to writeResponse
	ctx.cw = cw

	// Expose body; its length is fully known, so Response.Write need not guess it.
	ctx.res.Body = io.NopCloser(body)
	ctx.res.ContentLength = int64(body.Len())

	// Correct headers
	if ctx.res.Header.Get(HeaderDate) == "" {
		ctx.res.Header.Set(HeaderDate, time.Now().UTC().Format(http.TimeFormat))
	}
	if ctx.res.Header.Get(HeaderContentType) == "" && body.Len() > 0 {
		ctx.res.Header.Set(HeaderContentType, http.DetectContentType(body.Bytes()))
	}
	return p.writeResponse(ctx)
}
// serveConnect sets up a tunnel for a CONNECT request, or for a transparently
// intercepted TLS connection. It then relays bytes in both directions until
// the tunnel ends.
//
// The upstream connection is closed when serveConnect returns.
func (p *Proxy) serveConnect(ctx *proxyContext) (err error) {
	log := ctx.LogEntry()

	// Most browsers expect to get a 200 OK after firing a HTTP CONNECT request, so it is sent
	// before dialing. If dialing the upstream fails, an error response is written onto the
	// tunnel and the client connection is closed.
	if !(ctx.transparent > 0 || ctx.transparentTLS) {
		if _, err = io.WriteString(ctx, "HTTP/1.1 200 Connection Established\r\n\r\n"); err != nil {
			return
		}
	}

	if ctx.req.URL.Scheme == "" {
		ctx.req.URL.Scheme = "tcp"
	}
	log.Value("target", ctx.req.URL.String()).Debugf("%s CONNECT request", ctx.req.Proto)

	var (
		timeout, cancel = context.WithTimeout(context.Background(), p.DialTimeout)
		c               net.Conn
	)
	if c, err = p.dial(timeout, ctx.req); err != nil {
		cancel()
		ctx.res = NewErrorResponse(err, ctx.req)
		_ = p.writeResponse(ctx)
		_ = ctx.Close()
		return fmt.Errorf("proxy: dial %s error: %w", ctx.req.URL, err)
	}
	cancel()
	ctx.res = NewResponse(http.StatusOK, nil, ctx.req)

	srv := NewContext(c).(*proxyContext)
	// Nothing else closes the upstream leg; release it as soon as the tunnel ends.
	defer func() { _ = srv.Close() }()
	srv.SetIdleTimeout(p.IdleTimeout)
	return p.multiplex(ctx, srv)
}
// serveForward forwards an absolute-URL request upstream through the
// RoundTripper and relays the response to the client.
//
// If the round trip fails, an error response is written, the client
// connection is closed and the wrapped error is returned.
func (p *Proxy) serveForward(ctx *proxyContext) (err error) {
	log := ctx.LogEntry()
	log.Value("target", ctx.req.URL.String()).Debugf("%s forward request", ctx.req.Proto)
	if ctx.res, err = p.RoundTripper.RoundTrip(ctx.req); err != nil {
		ctx.res = NewErrorResponse(err, ctx.req)
		_ = p.writeResponse(ctx)
		_ = ctx.Close()
		return fmt.Errorf("proxy: forward %s error: %w", ctx.req.URL, err)
	}
	// writeResponse runs the OnResponse filters itself. Applying them here as
	// well would run every filter twice on forwarded responses.
	return p.writeResponse(ctx)
}
// serveWebSocket relays a WebSocket upgrade request to the upstream server.
// If the upstream switches protocols (101), bytes are then tunnelled between
// client and upstream until one side stops.
//
// Failures to dial, to write the request or to read the upstream response
// are answered with an error response and the client connection is closed.
// The upstream connection is always closed when serveWebSocket returns.
func (p *Proxy) serveWebSocket(ctx *proxyContext) (err error) {
	log := ctx.LogEntry().Value("target", ctx.req.URL.String())
	switch ctx.req.URL.Scheme {
	case "http":
		ctx.req.URL.Scheme = "ws"
	case "https":
		ctx.req.URL.Scheme = "wss"
	}
	log.Debugf("%s websocket request", ctx.req.Proto)

	var (
		timeout, cancel = context.WithTimeout(context.Background(), p.DialTimeout)
		c               net.Conn
	)
	if c, err = p.dial(timeout, ctx.req); err != nil {
		cancel()
		ctx.res = NewErrorResponse(err, ctx.req)
		_ = p.writeResponse(ctx)
		_ = ctx.Close()
		return fmt.Errorf("proxy: dial %s error: %w", ctx.req.URL, err)
	}
	cancel()

	srv := NewContext(c).(*proxyContext)
	// The upstream connection serves this exchange only; nothing else closes it.
	defer func() { _ = srv.Close() }()
	srv.SetIdleTimeout(p.IdleTimeout)

	if err = ctx.req.Write(srv); err != nil {
		ctx.res = NewErrorResponse(err, ctx.req)
		_ = p.writeResponse(ctx)
		_ = ctx.Close()
		return fmt.Errorf("proxy: failed to write request to upstream: %w", err)
	}
	if ctx.res, err = http.ReadResponse(srv.Reader(), ctx.req); err != nil {
		ctx.res = NewErrorResponse(err, ctx.req)
		_ = p.writeResponse(ctx)
		_ = ctx.Close()
		return fmt.Errorf("proxy: failed to read response from upstream: %w", err)
	}
	log.Values(logger.Values{
		"response": ctx.res.StatusCode,
		"status":   ctx.res.Status,
	}).Debug("WebSocket response from upstream")

	if err = p.writeResponse(ctx); err != nil {
		_ = ctx.Close()
		return
	}
	if ctx.res.StatusCode != http.StatusSwitchingProtocols {
		// The upgrade was refused (e.g. by the upstream). The response has been
		// relayed as-is and there is no tunnel to set up.
		return nil
	}

	// Both legs now carry WebSocket traffic, so both get the WebSocket idle
	// timeout. Otherwise the upstream leg would still time out after IdleTimeout.
	ctx.SetIdleTimeout(p.WebSocketIdleTimeout)
	srv.SetIdleTimeout(p.WebSocketIdleTimeout)
	return p.multiplex(ctx, srv)
}
// multiplex copies bytes between the client (ctx) and the upstream (srv) in
// both directions until either direction stops.
//
// Each copy goroutine reports its result on a channel with room for both
// results. A goroutine that finishes after multiplex has returned therefore
// never blocks forever.
//
// As soon as one direction ends, both ends are closed (when they implement
// [io.Closer]). This unblocks the other direction and keeps anyone from
// reading the client connection concurrently. The first direction's copy
// error is returned; a clean EOF yields nil.
func (p *Proxy) multiplex(ctx, srv Context) (err error) {
	results := make(chan error, 2)
	pipe := func(dst io.Writer, src io.Reader) {
		_, cerr := io.Copy(dst, src)
		results <- cerr
	}
	go pipe(srv, ctx) // client -> upstream
	go pipe(ctx, srv) // upstream -> client

	err = <-results
	for _, end := range []Context{srv, ctx} {
		if c, ok := end.(io.Closer); ok {
			_ = c.Close()
		}
	}
	return err
}
// writeResponse runs the OnResponse filters over ctx's response, writes the
// result to the client and decides whether the connection stays open.
//
// A filter's non-nil result replaces the current response, whose body is
// closed. The replacement is stored on ctx, so later filters, the keep-alive
// decision and the access log all see the response that is actually sent.
// A filter returning the same response (modified in place) keeps it intact.
//
// 101 Switching Protocols responses leave the connection open: it is being
// handed to another protocol, e.g. a WebSocket tunnel.
//
// Any other response closes the connection when it carries the Close flag or
// lacks "Connection: keep-alive". In that case:
//   - the response is marked Close, so the client is told about it;
//   - ctx is closed;
//   - io.EOF is returned, so callers stop reading requests.
func (p *Proxy) writeResponse(ctx *proxyContext) (err error) {
	if ctx.res == nil {
		return errors.New("proxy: no response to write")
	}
	for _, f := range p.OnResponse {
		newRes := f.HandleResponse(ctx)
		if newRes == nil || newRes == ctx.res {
			continue
		}
		log.Printf("Filter returned response HTTP %s", newRes.Status)
		if ctx.res.Body != nil {
			_ = ctx.res.Body.Close()
		}
		ctx.res = newRes
	}

	res := ctx.res
	closing := res.StatusCode != http.StatusSwitchingProtocols &&
		(res.Close || !strings.EqualFold(res.Header.Get(HeaderConnection), "keep-alive"))
	if closing {
		// Makes Response.Write announce "Connection: close" to the client.
		res.Close = true
	}

	ServerLog.Values(logger.Values{
		"close":  res.Close,
		"header": res.Header,
	}).Debug("Writing response")
	if err = res.Write(ctx); err != nil {
		// Release the body (e.g. an upstream connection) even if Write failed
		// before reaching it; a second Close is harmless.
		if res.Body != nil {
			_ = res.Body.Close()
		}
		return
	}
	if !closing {
		return nil
	}
	// Force closing of connection.
	if err = ctx.Close(); err != nil {
		return
	}
	return io.EOF
}
// headerContains reports whether header k of h contains the token v,
// compared case-insensitively.
//
// Headers such as Connection and Upgrade carry comma-separated token lists,
// e.g. "Connection: keep-alive, Upgrade", which may also be spread over
// several header lines. Every line is therefore split on commas, and each
// whitespace-trimmed token is compared to v.
func headerContains(h http.Header, k, v string) bool {
	return slices.ContainsFunc(h[http.CanonicalHeaderKey(k)], func(line string) bool {
		return slices.ContainsFunc(strings.Split(line, ","), func(token string) bool {
			return strings.EqualFold(strings.TrimSpace(token), v)
		})
	})
}