Files
styx/proxy/proxy.go
2025-09-30 08:08:22 +02:00

506 lines
13 KiB
Go

package proxy
import (
"bufio"
"bytes"
"context"
"crypto/tls"
"errors"
"fmt"
"io"
"log"
"net"
"net/http"
"net/url"
"slices"
"strings"
"syscall"
"time"
"git.maze.io/maze/styx/internal/netutil"
"git.maze.io/maze/styx/stats"
"github.com/sirupsen/logrus"
)
// Common HTTP headers.
//
// Each name is spelled in canonical MIME header form (the form produced by
// http.CanonicalHeaderKey), so it can be used directly as a key into an
// http.Header map as well as with Header.Get/Set.
const (
	HeaderConnection     = "Connection"
	HeaderContentType    = "Content-Type"
	HeaderDate           = "Date"
	HeaderForwarded      = "Forwarded"
	HeaderForwardedFor   = "X-Forwarded-For"
	HeaderForwardedHost  = "X-Forwarded-Host"
	HeaderForwardedPort  = "X-Forwarded-Port"
	HeaderForwardedProto = "X-Forwarded-Proto"
	HeaderRealIP         = "X-Real-Ip"
	HeaderUpgrade        = "Upgrade"
	HeaderVia            = "Via"
)
// Safe defaults, used by [New] to initialise a [Proxy].
const (
	// DefaultDialTimeout bounds how long dialing the upstream of a CONNECT or
	// WebSocket request may take.
	DefaultDialTimeout = 15 * time.Second
	// DefaultIdleTimeout is passed to SetIdleTimeout on every client connection
	// and on the upstream connections opened for CONNECT and WebSocket requests.
	DefaultIdleTimeout = 10 * time.Second
	// DefaultWebSocketIdleTimeout replaces the client connection's idle timeout
	// once an upstream WebSocket response has been relayed to the client.
	DefaultWebSocketIdleTimeout = 30 * time.Second
)
var (
	// AccessLog is used for logging requests to the proxy.
	// It defaults to the logrus standard logger, i.e. the same logger as
	// ServerLog; assign a different logger to separate the two streams.
	AccessLog = logrus.StandardLogger()
	// ServerLog is used for logging server log messages (filter activity,
	// request read errors, response writes).
	// It defaults to the logrus standard logger, shared with AccessLog.
	ServerLog = logrus.StandardLogger()
)
// Proxy implements a HTTP(S) proxy.
//
// Create one with [New]: the zero value is not usable, because its dialer map
// and internal mux are nil (SetDialer and Handle would panic).
type Proxy struct {
	// rt forwards requests whose URL is absolute (see serveForward).
	rt http.RoundTripper
	// dialer maps a URL scheme to the Dialer used for it; the "" entry is the
	// fallback used by dial and by rt, and cannot be removed via SetDialer.
	dialer map[string]Dialer
	// connFilter runs once per accepted connection and may replace it.
	connFilter []ConnFilter
	// requestFilter runs for every request; a filter may replace the request
	// or answer it with its own response.
	requestFilter []RequestFilter
	// responseFilter runs before a response is written to the client.
	responseFilter []ResponseFilter
	// dialTimeout bounds upstream dials for CONNECT and WebSocket requests.
	dialTimeout time.Duration
	// idleTimeout is set on client connections and on upstream tunnels.
	idleTimeout time.Duration
	// webSocketIdleTimeout is set on the client connection after a WebSocket
	// upgrade response has been relayed.
	webSocketIdleTimeout time.Duration
	// mux serves requests addressed to the proxy itself (e.g. /stats).
	mux *http.ServeMux
}
// New returns a [Proxy] with somewhat sane defaults.
//
// The returned proxy dials through defaultDialer (registered for the empty
// protocol), uses the Default* timeouts and serves /stats and /stats.json
// from its internal mux. Requests forwarded upstream go through an
// [http.Transport] that honours the proxy environment variables, has HTTP/2
// disabled (a non-nil, empty TLSNextProto map) and opens its connections via
// whichever [Dialer] is registered for the empty protocol at dial time.
func New() *Proxy {
	p := &Proxy{
		mux:                  http.NewServeMux(),
		webSocketIdleTimeout: DefaultWebSocketIdleTimeout,
		idleTimeout:          DefaultIdleTimeout,
		dialTimeout:          DefaultDialTimeout,
		dialer:               map[string]Dialer{"": defaultDialer{}},
	}

	// The Dialer interface works on requests, so describe the connection the
	// transport wants as a request whose URL carries the network and address.
	dialDefault := func(dialCtx context.Context, network, addr string) (net.Conn, error) {
		target := &url.URL{Scheme: network, Host: addr}
		return p.dialer[""].DialContext(dialCtx, &http.Request{URL: target})
	}

	p.rt = &http.Transport{
		Proxy:                 http.ProxyFromEnvironment,
		DialContext:           dialDefault,
		TLSHandshakeTimeout:   10 * time.Second,
		ExpectContinueTimeout: time.Second,
		TLSNextProto:          make(map[string]func(string, *tls.Conn) http.RoundTripper),
	}

	p.Handle("/stats", stats.Handler(stats.Exposed))
	p.Handle("/stats.json", stats.JSONHandler(stats.Exposed))
	return p
}
// Handle registers handler for pattern on the proxy's internal [http.ServeMux].
// The mux answers the requests handleRequest neither tunnels nor forwards:
// non-CONNECT, non-upgrade requests whose URL is not absolute.
func (p *Proxy) Handle(pattern string, handler http.Handler) {
	mux := p.mux
	mux.Handle(pattern, handler)
}
// HandleFunc registers the handler function for pattern on the proxy's
// internal [http.ServeMux]; it is the function-typed counterpart of Handle.
func (p *Proxy) HandleFunc(pattern string, handler func(http.ResponseWriter, *http.Request)) {
	mux := p.mux
	mux.HandleFunc(pattern, handler)
}
// SetDialer specifies a [Dialer] for the specified protocol (URL scheme). The
// default [Dialer] corresponds to an empty string. Only override the default
// [Dialer] if you know what you are doing.
//
// Passing a nil dialer removes the protocol-specific dialer so that protocol
// falls back to the default; a nil dialer for the empty protocol is ignored,
// so a default dialer always remains registered.
func (p *Proxy) SetDialer(proto string, dialer Dialer) {
	switch {
	case dialer != nil:
		p.dialer[proto] = dialer
	case proto != "":
		delete(p.dialer, proto)
	}
}
// AddConnFilter appends a connection filter to the stack; filters run in the
// order they were added. A nil filter is ignored.
func (p *Proxy) AddConnFilter(f ConnFilter) {
	if f != nil {
		p.connFilter = append(p.connFilter, f)
	}
}
// AddRequestFilter appends a request filter to the stack; filters run in the
// order they were added. A nil filter is ignored.
func (p *Proxy) AddRequestFilter(f RequestFilter) {
	if f != nil {
		p.requestFilter = append(p.requestFilter, f)
	}
}
// AddResponseFilter appends a response filter to the stack; filters run in
// the order they were added. A nil filter is ignored.
func (p *Proxy) AddResponseFilter(f ResponseFilter) {
	if f != nil {
		p.responseFilter = append(p.responseFilter, f)
	}
}
// dial opens an upstream connection for req using the Dialer registered for
// req.URL.Scheme, falling back to the default ("") Dialer when none is.
func (p *Proxy) dial(ctx context.Context, req *http.Request) (net.Conn, error) {
	if d, ok := p.dialer[req.URL.Scheme]; ok {
		return d.DialContext(ctx, req)
	}
	return p.dialer[""].DialContext(ctx, req)
}
// Serve accepts connections on l and serves each one in its own goroutine.
//
// It only returns when Accept fails, and then returns that (non-nil) error.
// Serve does not close l.
func (p *Proxy) Serve(l net.Listener) error {
	var (
		conn net.Conn
		err  error
	)
	for conn, err = l.Accept(); err == nil; conn, err = l.Accept() {
		go p.handle(conn)
	}
	return err
}
// handle serves one client connection from accept to close and emits a single
// access log entry for it when it is done.
//
// Connection filters run first and may swap the underlying connection. Then,
// in a loop, a request is read from the client (or, for a transparently
// intercepted TLS connection, a CONNECT request to ctx.serverName on port 443
// is synthesized), passed through the request filters and dispatched by
// handleRequest. A request filter that supplies its own response answers the
// request: the remaining filters and handleRequest are skipped. The loop ends
// on the first error, and after a single request on a transparent TLS
// connection.
func (p *Proxy) handle(nc net.Conn) {
	var (
		start = time.Now()
		ctx   = NewContext(nc).(*proxyContext)
		err   error
	)
	defer func() {
		if cerr := ctx.Close(); cerr != nil && err == nil {
			err = cerr
		}
		log := ctx.AccessLogEntry().WithField("duration", time.Since(start))
		if err != nil && !netutil.IsClosing(err) {
			log = log.WithError(err)
		}
		if req := ctx.Request(); req != nil {
			log = log.WithFields(logrus.Fields{
				"method":  req.Method,
				"request": req.URL.String(),
			})
		}
		if res := ctx.Response(); res != nil {
			//countStatus(res.StatusCode)
			log.WithFields(logrus.Fields{
				"response": res.StatusCode,
			}).Info(res.Status)
		} else {
			//countStatus(0)
			log.Info("No response")
		}
	}()

	// Propagate timeouts
	ctx.SetIdleTimeout(p.idleTimeout)

	for _, f := range p.connFilter {
		fc, ferr := f.FilterConn(ctx)
		if ferr != nil {
			// Record the failure in err (rather than a shadowed local) so the
			// deferred access log entry shows why the connection was dropped.
			err = ferr
			ServerLog.WithField("filter", fmt.Sprintf("%T", f)).WithError(err).Warn("error in conn filter")
			_ = nc.Close()
			return
		}
		if fc != nil {
			ServerLog.WithField("filter", fmt.Sprintf("%T", f)).Debug("replacing connection from filter")
			ctx.Conn = fc
			ctx.br = bufio.NewReader(fc)
		}
	}

	// ctx.isTransparent is cleared after the first request, but every request
	// on a transparently intercepted connection arrives in origin-form and must
	// be canonicalized, so remember the flag for the whole connection.
	transparent := ctx.isTransparent

	for {
		if ctx.isTransparentTLS {
			ctx.req = &http.Request{
				Method: http.MethodConnect,
				URL: &url.URL{
					Scheme: "tcp",
					Host:   net.JoinHostPort(ctx.serverName, "443"),
				},
			}
		} else if ctx.req, err = http.ReadRequest(ctx.Reader()); err != nil {
			if !(errors.Is(err, io.EOF) || errors.Is(err, syscall.ECONNRESET)) {
				ServerLog.WithError(err).Debug("error reading request")
			}
			return
		}

		if transparent {
			// Canonicalize to an absolute URL so the request is forwarded upstream
			// instead of being served by the internal mux.
			if ctx.req.URL.Host == "" {
				ctx.req.URL.Host = ctx.req.Host
			}
			if ctx.req.URL.Scheme == "" {
				ctx.req.URL.Scheme = "http"
			}
			ctx.isTransparent = false
		}

		answered := false
		for _, f := range p.requestFilter {
			newReq, newRes := f.FilterRequest(ctx)
			if newReq != nil {
				ServerLog.WithFields(logrus.Fields{
					"filter":     fmt.Sprintf("%T", f),
					"old_method": ctx.req.Method,
					"old_url":    ctx.req.URL,
					"new_method": newReq.Method,
					"new_url":    newReq.URL,
				}).Debug("replacing request from filter")
				ctx.req = newReq
			}
			if newRes != nil {
				log := ServerLog.WithFields(logrus.Fields{
					"filter":   fmt.Sprintf("%T", f),
					"response": newRes.StatusCode,
					"status":   newRes.Status,
				})
				log.Debug("replacing response from filter")
				ctx.res = newRes
				if err = p.writeResponse(ctx); err != nil {
					// io.EOF only signals that writeResponse closed the connection.
					if !errors.Is(err, io.EOF) {
						log.WithError(err).Warn("error overriding response")
					}
					return
				}
				// The filter answered the request: skip the remaining filters and
				// do not contact the upstream.
				answered = true
				break
			}
		}

		if !answered {
			if err = p.handleRequest(ctx); err != nil {
				return
			}
		}

		// Only once: a transparent TLS connection carries exactly one
		// synthesized CONNECT request.
		if ctx.isTransparent || ctx.isTransparentTLS {
			return
		}
	}
}
// handleRequest dispatches ctx.req to the handler for its kind:
//   - an upgrade to WebSocket is relayed by serveWebSocket, and any other
//     upgrade is refused with 400 Bad Request;
//   - a CONNECT request opens a tunnel via serveConnect;
//   - a request with an absolute URL is forwarded upstream by serveForward;
//   - anything else is answered by the internal mux via serve.
//
// A nil request is reported as an error.
func (p *Proxy) handleRequest(ctx *proxyContext) (err error) {
	req := ctx.req
	if req == nil {
		ctx.LogEntry().Warn("Request is nil in handleRequest!?")
		return errors.New("proxy: request is nil?")
	}

	if headerContains(req.Header, HeaderConnection, "upgrade") {
		if !headerContains(req.Header, HeaderUpgrade, "websocket") {
			ctx.res = NewResponse(http.StatusBadRequest, nil, req)
			return p.writeResponse(ctx)
		}
		return p.serveWebSocket(ctx)
	}

	if req.Method == http.MethodConnect {
		return p.serveConnect(ctx)
	}
	if req.URL.IsAbs() {
		return p.serveForward(ctx)
	}
	return p.serve(ctx)
}
// applyResponseFilter runs the response filters in order against ctx. Each
// non-nil response a filter returns replaces ctx.res, and the body of the
// replaced response is closed.
//
// A filter returning nil, or returning the current response itself (e.g. after
// editing its headers in place), keeps the current response. Its body is not
// closed, because it is still going to be written.
func (p *Proxy) applyResponseFilter(ctx *proxyContext) {
	for _, f := range p.responseFilter {
		newRes := f.FilterResponse(ctx)
		if newRes == nil || newRes == ctx.res {
			continue
		}
		if ctx.res != nil && ctx.res.Body != nil {
			_ = ctx.res.Body.Close()
		}
		ctx.res = newRes
	}
}
// serve answers a request addressed to the proxy itself with the internal mux.
//
// The handler's status and headers are captured in a fresh 200 response and
// its body in a buffer. The completed response is then written by
// writeResponse, after a Date header is added if missing and, for a non-empty
// body, a default Content-Type.
func (p *Proxy) serve(ctx *proxyContext) (err error) {
	var (
		b  = new(bytes.Buffer)
		cw = ctx.cw
	)

	// This is where our response headers etc. are captured
	ctx.res = NewResponse(http.StatusOK, nil, ctx.req)

	// Capture the body in b. The bytes field of the original writer is reused;
	// presumably it is the shared byte counter (TODO confirm).
	ctx.cw = &countingWriter{writer: b, bytes: ctx.cw.bytes}

	// Pass ServeHTTP call to mux handler(s)
	p.mux.ServeHTTP(ctx, ctx.req)

	// Expose body
	ctx.res.Body = io.NopCloser(b)

	// Correct headers
	if ctx.res.Header.Get(HeaderDate) == "" {
		// http.TimeFormat is the RFC 7231 IMF-fixdate layout, which requires a
		// zero-padded two-digit day ("Mon, 02 Jan 2006 15:04:05 GMT").
		ctx.res.Header.Set(HeaderDate, time.Now().UTC().Format(http.TimeFormat))
	}
	if ctx.res.Header.Get(HeaderContentType) == "" && b.Len() > 0 {
		ctx.res.Header.Set(HeaderContentType, "text/html; charset=utf-8")
	}

	// Restore writer for the call to writeResponse
	ctx.cw = cw
	return p.writeResponse(ctx)
}
// serveConnect answers a CONNECT request by dialing the target in ctx.req.URL
// and tunnelling raw bytes between the client and the upstream via multiplex.
//
// Explicit proxy clients are sent "200 Connection Established" before the
// upstream is dialed, because most browsers expect it right away. If the dial
// then fails, an error response is written onto that same connection
// immediately, and the client connection is closed. Transparently intercepted
// connections never receive the 200 line.
func (p *Proxy) serveConnect(ctx *proxyContext) (err error) {
	log := ctx.LogEntry()

	if transparent := ctx.isTransparent || ctx.isTransparentTLS; !transparent {
		if _, err = io.WriteString(ctx, "HTTP/1.1 200 Connection Established\r\n\r\n"); err != nil {
			return
		}
	}

	// Default a missing scheme to "tcp", the same network name the forwarding
	// transport in New passes to the default dialer.
	if ctx.req.URL.Scheme == "" {
		ctx.req.URL.Scheme = "tcp"
	}
	log.WithField("target", ctx.req.URL.String()).Debug("http CONNECT request")

	dialCtx, cancel := context.WithTimeout(context.Background(), p.dialTimeout)
	c, err := p.dial(dialCtx, ctx.req)
	cancel()
	if err != nil {
		ctx.res = NewErrorResponse(err, ctx.req)
		_ = p.writeResponse(ctx)
		_ = ctx.Close()
		return fmt.Errorf("proxy: dial %s error: %w", ctx.req.URL, err)
	}

	ctx.res = NewResponse(http.StatusOK, nil, ctx.req)
	srv := NewContext(c).(*proxyContext)
	srv.SetIdleTimeout(p.idleTimeout)
	return p.multiplex(ctx, srv)
}
// serveForward forwards a request with an absolute URL upstream through p.rt
// and writes the upstream response to the client.
//
// If the round trip fails, an error response is written, the client connection
// is closed and a wrapped error is returned.
func (p *Proxy) serveForward(ctx *proxyContext) (err error) {
	log := ctx.LogEntry()
	log.WithField("target", ctx.req.URL.String()).Debug("http forward request")
	if ctx.res, err = p.rt.RoundTrip(ctx.req); err != nil {
		ctx.res = NewErrorResponse(err, ctx.req)
		_ = p.writeResponse(ctx)
		_ = ctx.Close()
		return fmt.Errorf("proxy: forward %s error: %w", ctx.req.URL, err)
	}
	// writeResponse runs the response filters itself. Running them here as well
	// would apply every filter twice to each forwarded response.
	return p.writeResponse(ctx)
}
// serveWebSocket relays a WebSocket upgrade request. It dials the upstream
// (selecting the dialer by the ws/wss scheme), forwards the client's request,
// relays the upstream response to the client and then multiplexes the two
// raw connections.
//
// On any failure an error response is written to the client (except when
// writing the upstream response itself failed), the client connection and any
// upstream connection already opened are closed, and an error is returned.
func (p *Proxy) serveWebSocket(ctx *proxyContext) (err error) {
	log := ctx.LogEntry().WithField("target", ctx.req.URL.String())

	// Map the HTTP scheme to its WebSocket counterpart so dial can select a
	// scheme-specific dialer.
	switch ctx.req.URL.Scheme {
	case "http":
		ctx.req.URL.Scheme = "ws"
	case "https":
		ctx.req.URL.Scheme = "wss"
	}
	log.Debug("http websocket request")

	dialCtx, cancel := context.WithTimeout(context.Background(), p.dialTimeout)
	c, err := p.dial(dialCtx, ctx.req)
	cancel()
	if err != nil {
		ctx.res = NewErrorResponse(err, ctx.req)
		_ = p.writeResponse(ctx)
		_ = ctx.Close()
		return fmt.Errorf("proxy: dial %s error: %w", ctx.req.URL, err)
	}

	srv := NewContext(c).(*proxyContext)
	srv.SetIdleTimeout(p.idleTimeout)

	if err = ctx.req.Write(srv); err != nil {
		// Nothing else owns the upstream connection yet; close it here.
		_ = srv.Close()
		ctx.res = NewErrorResponse(err, ctx.req)
		_ = p.writeResponse(ctx)
		_ = ctx.Close()
		return fmt.Errorf("proxy: failed to write request to upstream: %w", err)
	}
	if ctx.res, err = http.ReadResponse(srv.Reader(), ctx.req); err != nil {
		_ = srv.Close()
		ctx.res = NewErrorResponse(err, ctx.req)
		_ = p.writeResponse(ctx)
		_ = ctx.Close()
		return fmt.Errorf("proxy: failed to read response from upstream: %w", err)
	}
	log.WithFields(logrus.Fields{
		"response": ctx.res.StatusCode,
		"status":   ctx.res.Status,
	}).Debug("websocket response from upstream")

	if err = p.writeResponse(ctx); err != nil {
		_ = srv.Close()
		_ = ctx.Close()
		return
	}

	ctx.SetIdleTimeout(p.webSocketIdleTimeout)
	return p.multiplex(ctx, srv)
}
// multiplex copies bytes in both directions between the client connection ctx
// and the upstream connection srv until either direction finishes. It then
// closes srv (when srv implements io.Closer).
//
// It returns the error of the first direction to finish, or io.EOF when that
// direction ended cleanly. The result is therefore never nil: after a tunnel
// the client connection cannot carry further HTTP requests, so the caller must
// stop using it.
func (p *Proxy) multiplex(ctx, srv Context) (err error) {
	// Each copier sends exactly one value. With room for both, the copier that
	// finishes second never blocks after multiplex has returned.
	results := make(chan error, 2)
	relay := func(dst io.Writer, src io.Reader) {
		_, cerr := io.Copy(dst, src)
		if cerr == nil {
			// io.Copy reports a clean end of stream as nil.
			cerr = io.EOF
		}
		results <- cerr
	}
	go relay(srv, ctx)
	go relay(ctx, srv)

	err = <-results

	// Tear down the upstream so the remaining copier fails promptly instead of
	// waiting for an idle timeout. The client side is closed by handle once it
	// sees the non-nil error.
	if c, ok := srv.(io.Closer); ok {
		_ = c.Close()
	}
	return err
}
// writeResponse runs the response filters over ctx's response, writes the
// result to the client and decides whether the client connection stays open.
//
// Each replacement returned by a filter is stored in ctx.res, so later filters
// and the access log see the response that is actually sent. The body of the
// replaced response is closed. A filter returning nil or the current response
// itself keeps the current response.
//
// After writing:
//   - a 101 Switching Protocols response leaves the connection open, because
//     the caller hands it over to the upgraded protocol;
//   - otherwise, if the response requests Close or lacks
//     "Connection: keep-alive", the client connection is closed and io.EOF is
//     returned.
func (p *Proxy) writeResponse(ctx *proxyContext) (err error) {
	res := ctx.Response()
	for _, f := range p.responseFilter {
		newRes := f.FilterResponse(ctx)
		if newRes == nil || newRes == res {
			continue
		}
		log.Printf("filter returned response HTTP %s", newRes.Status)
		if res != nil && res.Body != nil {
			_ = res.Body.Close()
		}
		res = newRes
		ctx.res = newRes
	}

	ServerLog.WithFields(logrus.Fields{
		"close":  res.Close,
		"header": res.Header,
	}).Debug("writing response")
	if err = res.Write(ctx); err != nil {
		return
	}

	if res.StatusCode == http.StatusSwitchingProtocols {
		// The response carries "Connection: Upgrade"; closing here would kill
		// the connection the caller is about to relay (e.g. a WebSocket).
		return nil
	}
	if res.Close || !strings.EqualFold(res.Header.Get(HeaderConnection), "keep-alive") {
		// Force closing of connection.
		if err = ctx.Close(); err != nil {
			return
		}
		return io.EOF
	}
	return nil
}
// headerContains reports whether header k in h lists the token v.
//
// Every value of the header is treated as a comma-separated token list, which
// is how Connection and Upgrade are defined (RFC 9110 §7.6.1, §7.8). Tokens are
// trimmed of surrounding whitespace and compared case-insensitively, so
// "Connection: keep-alive, Upgrade" contains "upgrade". The header name k is
// canonicalized before lookup.
func headerContains(h http.Header, k, v string) bool {
	return slices.ContainsFunc(h.Values(k), func(line string) bool {
		return slices.ContainsFunc(strings.Split(line, ","), func(token string) bool {
			return strings.EqualFold(strings.TrimSpace(token), v)
		})
	})
}