586 lines
16 KiB
Go
586 lines
16 KiB
Go
package proxy
|
|
|
|
import (
|
|
"bufio"
|
|
"bytes"
|
|
"context"
|
|
"crypto/tls"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"log"
|
|
"net"
|
|
"net/http"
|
|
"net/url"
|
|
"slices"
|
|
"strconv"
|
|
"strings"
|
|
"syscall"
|
|
"time"
|
|
|
|
"git.maze.io/maze/styx/internal/netutil"
|
|
"git.maze.io/maze/styx/logger"
|
|
"git.maze.io/maze/styx/stats"
|
|
)
|
|
|
|
// Common HTTP header names, spelled in canonical MIME header form
// (as produced by [http.CanonicalHeaderKey]).
const (
	// Standard headers.
	HeaderConnection  = "Connection"
	HeaderContentType = "Content-Type"
	HeaderDate        = "Date"
	HeaderForwarded   = "Forwarded"
	HeaderLocation    = "Location"
	HeaderUpgrade     = "Upgrade"
	HeaderVia         = "Via"

	// De-facto proxy headers.
	HeaderForwardedFor   = "X-Forwarded-For"
	HeaderForwardedHost  = "X-Forwarded-Host"
	HeaderForwardedPort  = "X-Forwarded-Port"
	HeaderForwardedProto = "X-Forwarded-Proto"
	HeaderRealIP         = "X-Real-Ip"
)
|
|
|
|
// Safe defaults, installed by [New].
const (
	// DefaultDialTimeout is the default for [Proxy.DialTimeout].
	DefaultDialTimeout = 15 * time.Second

	// DefaultIdleTimeout is the default for [Proxy.IdleTimeout].
	DefaultIdleTimeout = 10 * time.Second

	// DefaultWebSocketIdleTimeout is the default for [Proxy.WebSocketIdleTimeout].
	DefaultWebSocketIdleTimeout = 30 * time.Second
)
|
|
|
|
var (
|
|
// AccessLog is used for logging requests to the proxy.
|
|
AccessLog = logger.Get()
|
|
|
|
// ServerLog is used for logging server log messages.
|
|
ServerLog = logger.Get()
|
|
)
|
|
|
|
// Proxy implements a HTTP(S) proxy.
|
|
type Proxy struct {
|
|
// RoundTripper is used to make outbound HTTP requests. It defaults to a [http.Transport]
|
|
// with a custom DialContext that uses the configured [Dialer]s.
|
|
//
|
|
// Only override this if you know what you are doing.
|
|
RoundTripper http.RoundTripper
|
|
|
|
// Dialer is a map of protocol names to [Dialer] implementations. The default [Dialer]
|
|
// corresponds to an empty string key.
|
|
//
|
|
// Only override the default [Dialer] if you know what you are doing.
|
|
Dialer map[string]Dialer
|
|
|
|
// OnConnect is a list of connection filters that are applied in order when a new
|
|
// connection is established.
|
|
//
|
|
// Connection filters can be used to implement custom authentication, logging,
|
|
// rate limiting, etc.
|
|
//
|
|
// Connection filters are applied before any HTTP request is read from the connection.
|
|
//
|
|
// Connection filters should return a non-nil error if they want to terminate the
|
|
// connection. Returning a non-nil [net.Conn] will replace the existing connection
|
|
// with the returned one.
|
|
//
|
|
// Connection filters should not modify the connection in any way (e.g. wrapping it
|
|
// in a TLS connection) as this will interfere with the proxy's ability to read
|
|
// HTTP requests from the connection.
|
|
//
|
|
// Connection filters are executed sequentially in the order they are added.
|
|
OnConnect []ConnHandler
|
|
|
|
// OnRequest is a list of request filters that are applied in order when a new
|
|
// HTTP request is read from the connection.
|
|
//
|
|
// Request filters can be used to modify the request, or to return a response
|
|
// directly without forwarding the request to the upstream server.
|
|
//
|
|
// Request filters should return a non-nil error if they want to terminate the
|
|
// connection.
|
|
//
|
|
// Request filters are executed sequentially in the order they are added.
|
|
OnRequest []RequestHandler
|
|
|
|
// OnResponse is a list of response filters that are applied in order when a
|
|
// response is received from the upstream server.
|
|
//
|
|
// Response filters can be used to modify the response before it is sent to
|
|
// the client.
|
|
//
|
|
// Response filters should return a non-nil error if they want to terminate the
|
|
// connection.
|
|
//
|
|
// Response filters are executed sequentially in the order they are added.
|
|
OnResponse []ResponseHandler
|
|
|
|
// OnError is a list of error handlers that are applied in order when an
|
|
// error occurs during request processing.
|
|
//
|
|
// Error handlers can be used to log errors, or to return a custom response
|
|
// to the client.
|
|
//
|
|
// Error handlers should return a non-nil error if they want to terminate the
|
|
// connection.
|
|
//
|
|
// Error handlers are executed sequentially in the order they are added.
|
|
OnError []ErrorHandler
|
|
|
|
// DialTimeout is the timeout for establishing new connections to upstream servers.
|
|
DialTimeout time.Duration
|
|
|
|
// IdleTimeout is the timeout for idle connections.
|
|
IdleTimeout time.Duration
|
|
|
|
// WebSocketIdleTimeout is the timeout for idle WebSocket connections.
|
|
WebSocketIdleTimeout time.Duration
|
|
|
|
mux *http.ServeMux
|
|
}
|
|
|
|
// New [Proxy] with somewhat sane defaults.
|
|
func New() *Proxy {
|
|
p := &Proxy{
|
|
Dialer: map[string]Dialer{"": defaultDialer{}},
|
|
DialTimeout: DefaultDialTimeout,
|
|
IdleTimeout: DefaultIdleTimeout,
|
|
WebSocketIdleTimeout: DefaultWebSocketIdleTimeout,
|
|
mux: http.NewServeMux(),
|
|
}
|
|
|
|
// Make sure the roundtripper uses our dialers.
|
|
p.RoundTripper = &http.Transport{
|
|
TLSNextProto: make(map[string]func(string, *tls.Conn) http.RoundTripper),
|
|
Proxy: http.ProxyFromEnvironment,
|
|
TLSHandshakeTimeout: 10 * time.Second,
|
|
ExpectContinueTimeout: time.Second,
|
|
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
|
return p.Dialer[""].DialContext(ctx, &http.Request{
|
|
URL: &url.URL{
|
|
Scheme: network,
|
|
Host: addr,
|
|
},
|
|
})
|
|
},
|
|
}
|
|
|
|
p.Handle("/stats", stats.Handler(stats.Exposed))
|
|
p.Handle("/stats.json", stats.JSONHandler(stats.Exposed))
|
|
|
|
return p
|
|
}
|
|
|
|
// Handle installs a [http.Handler] into the internal mux.
|
|
func (p *Proxy) Handle(pattern string, handler http.Handler) {
|
|
p.mux.Handle(pattern, handler)
|
|
}
|
|
|
|
// HandleFunc installs a [http.HandlerFunc] into the internal mux.
|
|
func (p *Proxy) HandleFunc(pattern string, handler func(http.ResponseWriter, *http.Request)) {
|
|
p.mux.HandleFunc(pattern, handler)
|
|
}
|
|
|
|
// SetDialer specifies a [Dialer] for the specified protocol. The default [Dialer] corresponds
|
|
// to an empty string. Only override the default [Dialer] if you know what you are doing.
|
|
func (p *Proxy) SetDialer(proto string, dialer Dialer) {
|
|
if dialer == nil {
|
|
if proto != "" {
|
|
delete(p.Dialer, proto)
|
|
}
|
|
} else {
|
|
p.Dialer[proto] = dialer
|
|
}
|
|
}
|
|
|
|
func (p *Proxy) dial(ctx context.Context, req *http.Request) (net.Conn, error) {
|
|
d, ok := p.Dialer[req.URL.Scheme]
|
|
if !ok {
|
|
d = p.Dialer[""]
|
|
}
|
|
|
|
return d.DialContext(ctx, req)
|
|
}
|
|
|
|
// Serve proxied connections on the specified listener.
|
|
func (p *Proxy) Serve(l net.Listener) error {
|
|
for {
|
|
c, err := l.Accept()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
go p.handle(c)
|
|
}
|
|
}
|
|
|
|
func (p *Proxy) handle(nc net.Conn) {
|
|
var (
|
|
start = time.Now()
|
|
ctx = NewContext(nc).(*proxyContext)
|
|
err error
|
|
)
|
|
defer func() {
|
|
if r := recover(); r != nil {
|
|
if err, ok := r.(error); ok {
|
|
ctx.LogEntry().Err(err).Warn("Bug in code, recovered from panic!")
|
|
}
|
|
_ = nc.Close()
|
|
}
|
|
}()
|
|
|
|
defer func() {
|
|
if cerr := ctx.Close(); cerr != nil && err == nil && !netutil.IsClosing(cerr) {
|
|
err = cerr
|
|
}
|
|
|
|
log := ctx.AccessLogEntry().Value("duration", time.Since(start))
|
|
if err != nil && !netutil.IsClosing(err) {
|
|
log = log.Err(err)
|
|
}
|
|
if req := ctx.Request(); req != nil {
|
|
log = log.Values(logger.Values{
|
|
"method": req.Method,
|
|
"request": req.URL.String(),
|
|
})
|
|
}
|
|
if res := ctx.Response(); res != nil {
|
|
//countStatus(res.StatusCode)
|
|
log.Values(logger.Values{
|
|
"response": res.StatusCode,
|
|
}).Info(res.Status)
|
|
} else {
|
|
//countStatus(0)
|
|
log.Info("No response")
|
|
}
|
|
}()
|
|
|
|
// Propagate timeouts
|
|
ctx.SetIdleTimeout(p.IdleTimeout)
|
|
|
|
for _, f := range p.OnConnect {
|
|
fc, err := f.HandleConn(ctx)
|
|
if err != nil {
|
|
ServerLog.Value("filter", fmt.Sprintf("%T", f)).Err(err).Warn("Error in conn filter")
|
|
p.handleError(ctx, err, true)
|
|
_ = nc.Close()
|
|
return
|
|
} else if fc != nil {
|
|
ServerLog.Value("filter", fmt.Sprintf("%T", f)).Debug("Replacing connection from filter")
|
|
ctx.Conn = fc
|
|
ctx.br = bufio.NewReader(fc)
|
|
}
|
|
}
|
|
|
|
for {
|
|
if ctx.transparentTLS {
|
|
ctx.req = &http.Request{
|
|
Method: http.MethodConnect,
|
|
URL: &url.URL{Scheme: "tcp", Host: net.JoinHostPort(ctx.serverName, strconv.Itoa(ctx.transparent))},
|
|
Host: net.JoinHostPort(ctx.serverName, strconv.Itoa(ctx.transparent)),
|
|
Proto: "HTTP/1.1",
|
|
ProtoMajor: 1,
|
|
ProtoMinor: 1,
|
|
Close: true,
|
|
}
|
|
} else if ctx.req, err = http.ReadRequest(ctx.Reader()); err != nil {
|
|
if !(errors.Is(err, io.EOF) || errors.Is(err, syscall.ECONNRESET)) {
|
|
ServerLog.Err(err).Debug("Error reading request")
|
|
}
|
|
p.handleError(ctx, err, true)
|
|
return
|
|
}
|
|
|
|
if ctx.transparent > 0 {
|
|
// Canonicallize to absolute URL
|
|
if ctx.req.URL.Host == "" {
|
|
ctx.req.URL.Host = ctx.req.Host
|
|
}
|
|
if ctx.req.URL.Scheme == "" {
|
|
ctx.req.URL.Scheme = "http"
|
|
}
|
|
ctx.transparent = 0
|
|
}
|
|
|
|
for _, f := range p.OnRequest {
|
|
newReq, newRes := f.HandleRequest(ctx)
|
|
if newReq != nil {
|
|
ServerLog.Values(logger.Values{
|
|
"filter": fmt.Sprintf("%T", f),
|
|
"old_method": ctx.req.Method,
|
|
"old_url": ctx.req.URL,
|
|
"new_method": newReq.Method,
|
|
"new_url": newReq.URL,
|
|
}).Debug("Replacing request from filter")
|
|
ctx.req = newReq
|
|
}
|
|
if newRes != nil {
|
|
log := ServerLog.Values(logger.Values{
|
|
"filter": fmt.Sprintf("%T", f),
|
|
"response": newRes.StatusCode,
|
|
"status": newRes.Status,
|
|
})
|
|
log.Debug("Replacing response from filter")
|
|
ctx.res = newRes
|
|
if err = p.writeResponse(ctx); err != nil {
|
|
if netutil.IsClosing(err) {
|
|
return
|
|
}
|
|
log.Err(err).Warn("Error overriding repsonse")
|
|
}
|
|
continue
|
|
}
|
|
}
|
|
|
|
if err = p.handleRequest(ctx); err != nil {
|
|
p.handleError(ctx, err, true)
|
|
return
|
|
}
|
|
|
|
// Only once
|
|
if ctx.transparent > 0 || ctx.transparentTLS || ctx.req.Method == http.MethodConnect {
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
func (p *Proxy) handleError(ctx *proxyContext, err error, sendResponse bool) {
|
|
res := ctx.Response()
|
|
if res == nil && sendResponse {
|
|
res = NewErrorResponse(err, ctx.Request())
|
|
}
|
|
for _, f := range p.OnError {
|
|
if newRes := f.HandleError(ctx, err); newRes != nil {
|
|
res = newRes
|
|
}
|
|
}
|
|
if sendResponse && res != nil {
|
|
if werr := p.writeResponse(ctx); werr != nil && !netutil.IsClosing(err) {
|
|
ServerLog.Err(werr).Warn("Error writing error response")
|
|
}
|
|
}
|
|
}
|
|
|
|
func (p *Proxy) handleRequest(ctx *proxyContext) (err error) {
|
|
switch {
|
|
case ctx.req == nil:
|
|
ctx.LogEntry().Warn("Request is nil in handleRequest!?")
|
|
return errors.New("proxy: request is nil?")
|
|
|
|
case headerContains(ctx.req.Header, HeaderConnection, "upgrade"):
|
|
if headerContains(ctx.req.Header, HeaderUpgrade, "websocket") {
|
|
return p.serveWebSocket(ctx)
|
|
}
|
|
ctx.res = NewResponse(http.StatusBadRequest, nil, ctx.req)
|
|
return p.writeResponse(ctx)
|
|
|
|
case ctx.req.Method == http.MethodConnect:
|
|
return p.serveConnect(ctx)
|
|
|
|
case ctx.req.URL.IsAbs():
|
|
return p.serveForward(ctx)
|
|
|
|
default:
|
|
return p.serve(ctx)
|
|
}
|
|
}
|
|
|
|
func (p *Proxy) applyResponseHandler(ctx *proxyContext) {
|
|
for _, f := range p.OnResponse {
|
|
if newRes := f.HandleResponse(ctx); newRes != nil {
|
|
if ctx.res.Body != nil {
|
|
_ = ctx.res.Body.Close()
|
|
}
|
|
ctx.res = newRes
|
|
}
|
|
}
|
|
}
|
|
|
|
func (p *Proxy) serve(ctx *proxyContext) (err error) {
|
|
var (
|
|
b = new(bytes.Buffer)
|
|
cw = ctx.cw
|
|
)
|
|
// This is where our response headers etc. are captured
|
|
ctx.res = NewResponse(http.StatusOK, nil, ctx.req)
|
|
|
|
// This is where our response body is captured
|
|
ctx.cw = &countingWriter{writer: b, bytes: ctx.cw.bytes}
|
|
|
|
// Pass ServeHTTP call to mux handler(s)
|
|
p.mux.ServeHTTP(ctx, ctx.req)
|
|
|
|
// Expose body
|
|
ctx.res.Body = io.NopCloser(b)
|
|
|
|
// Correct headers
|
|
if ctx.res.Header.Get(HeaderDate) == "" {
|
|
ctx.res.Header.Set(HeaderDate, time.Now().UTC().Format("Mon, 2 Jan 2006 15:04:05")+" GMT")
|
|
}
|
|
if ctx.res.Header.Get(HeaderContentType) == "" && b.Len() > 0 {
|
|
ctx.res.Header.Set(HeaderContentType, "text/html; charset=utf-8")
|
|
}
|
|
|
|
// Restore writer for the call to writeResponse
|
|
ctx.cw = cw
|
|
return p.writeResponse(ctx)
|
|
}
|
|
|
|
func (p *Proxy) serveConnect(ctx *proxyContext) (err error) {
|
|
log := ctx.LogEntry()
|
|
|
|
// Most browsers expect to get a 200 OK after firing a HTTP CONNECT request; if the upstream
|
|
// encounters any errors, we'll inform the client after reading the HTTP request that follows.
|
|
if !(ctx.transparent > 0 || ctx.transparentTLS) {
|
|
if _, err = io.WriteString(ctx, "HTTP/1.1 200 Connection Established\r\n\r\n"); err != nil {
|
|
return
|
|
}
|
|
}
|
|
|
|
switch ctx.req.URL.Scheme {
|
|
case "":
|
|
ctx.req.URL.Scheme = "tcp"
|
|
}
|
|
log.Value("target", ctx.req.URL.String()).Debugf("%s CONNECT request", ctx.req.Proto)
|
|
|
|
var (
|
|
timeout, cancel = context.WithTimeout(context.Background(), p.DialTimeout)
|
|
c net.Conn
|
|
)
|
|
if c, err = p.dial(timeout, ctx.req); err != nil {
|
|
cancel()
|
|
ctx.res = NewErrorResponse(err, ctx.req)
|
|
_ = p.writeResponse(ctx)
|
|
_ = ctx.Close()
|
|
return fmt.Errorf("proxy: dial %s error: %w", ctx.req.URL, err)
|
|
}
|
|
cancel()
|
|
|
|
ctx.res = NewResponse(http.StatusOK, nil, ctx.req)
|
|
srv := NewContext(c).(*proxyContext)
|
|
srv.SetIdleTimeout(p.IdleTimeout)
|
|
return p.multiplex(ctx, srv)
|
|
}
|
|
|
|
func (p *Proxy) serveForward(ctx *proxyContext) (err error) {
|
|
log := ctx.LogEntry()
|
|
log.Value("target", ctx.req.URL.String()).Debugf("%s forward request", ctx.req.Proto)
|
|
|
|
if ctx.res, err = p.RoundTripper.RoundTrip(ctx.req); err != nil {
|
|
// log.Printf("%s forward request error: %v", ctx, err)
|
|
ctx.res = NewErrorResponse(err, ctx.req)
|
|
_ = p.writeResponse(ctx)
|
|
_ = ctx.Close()
|
|
return fmt.Errorf("proxy: forward %s error: %w", ctx.req.URL, err)
|
|
}
|
|
p.applyResponseHandler(ctx)
|
|
return p.writeResponse(ctx)
|
|
}
|
|
|
|
func (p *Proxy) serveWebSocket(ctx *proxyContext) (err error) {
|
|
log := ctx.LogEntry().Value("target", ctx.req.URL.String())
|
|
|
|
switch ctx.req.URL.Scheme {
|
|
case "http":
|
|
ctx.req.URL.Scheme = "ws"
|
|
case "https":
|
|
ctx.req.URL.Scheme = "wss"
|
|
}
|
|
|
|
log.Debugf("%s websocket request", ctx.req.Proto)
|
|
var (
|
|
timeout, cancel = context.WithTimeout(context.Background(), p.DialTimeout)
|
|
c net.Conn
|
|
)
|
|
if c, err = p.dial(timeout, ctx.req); err != nil {
|
|
cancel()
|
|
ctx.res = NewErrorResponse(err, ctx.req)
|
|
_ = p.writeResponse(ctx)
|
|
_ = ctx.Close()
|
|
return fmt.Errorf("proxy: dial %s error: %w", ctx.req.URL, err)
|
|
}
|
|
cancel()
|
|
|
|
srv := NewContext(c).(*proxyContext)
|
|
srv.SetIdleTimeout(p.IdleTimeout)
|
|
if err = ctx.req.Write(srv); err != nil {
|
|
ctx.res = NewErrorResponse(err, ctx.req)
|
|
_ = p.writeResponse(ctx)
|
|
_ = ctx.Close()
|
|
return fmt.Errorf("proxy: failed to write request to upstream: %w", err)
|
|
}
|
|
|
|
if ctx.res, err = http.ReadResponse(srv.Reader(), ctx.req); err != nil {
|
|
ctx.res = NewErrorResponse(err, ctx.req)
|
|
_ = p.writeResponse(ctx)
|
|
_ = ctx.Close()
|
|
return fmt.Errorf("proxy: failed to read response from upstream: %w", err)
|
|
}
|
|
|
|
log.Values(logger.Values{
|
|
"response": ctx.res.StatusCode,
|
|
"status": ctx.res.Status,
|
|
}).Debug("WebSocket response from upstream")
|
|
if err = p.writeResponse(ctx); err != nil {
|
|
_ = ctx.Close()
|
|
return
|
|
}
|
|
ctx.SetIdleTimeout(p.WebSocketIdleTimeout)
|
|
return p.multiplex(ctx, srv)
|
|
}
|
|
|
|
func (p *Proxy) multiplex(ctx, srv Context) (err error) {
|
|
var (
|
|
errs = make(chan error, 1)
|
|
done = make(chan struct{}, 1)
|
|
)
|
|
go func(errs chan<- error) {
|
|
defer close(done)
|
|
if _, err := io.Copy(srv, ctx); err != nil {
|
|
errs <- err
|
|
}
|
|
}(errs)
|
|
go func(errs chan<- error) {
|
|
if _, err := io.Copy(ctx, srv); err != nil {
|
|
errs <- err
|
|
}
|
|
}(errs)
|
|
|
|
select {
|
|
case err = <-errs:
|
|
return
|
|
case <-done:
|
|
return
|
|
}
|
|
}
|
|
|
|
func (p *Proxy) writeResponse(ctx *proxyContext) (err error) {
|
|
res := ctx.Response()
|
|
for _, f := range p.OnResponse {
|
|
if newRes := f.HandleResponse(ctx); newRes != nil {
|
|
log.Printf("Filter returned response HTTP %s", newRes.Status)
|
|
if res.Body != nil {
|
|
_ = res.Body.Close()
|
|
}
|
|
res = newRes
|
|
}
|
|
}
|
|
ServerLog.Values(logger.Values{
|
|
"close": res.Close,
|
|
"header": res.Header,
|
|
}).Debug("Writing response")
|
|
if err = res.Write(ctx); err != nil {
|
|
return
|
|
}
|
|
if res.Close || ctx.res.Close || strings.ToLower(ctx.res.Header.Get(HeaderConnection)) != "keep-alive" {
|
|
// Force closing of connection.
|
|
if err = ctx.Close(); err != nil {
|
|
return
|
|
}
|
|
return io.EOF
|
|
}
|
|
return
|
|
}
|
|
|
|
// headerContains reports whether any value of header k in h contains the
// token v, compared case-insensitively. Header values are treated as
// comma-separated token lists (RFC 9110 §5.6.1). This matters for
// "Connection: keep-alive, Upgrade", which Firefox sends for WebSocket
// upgrades and which must be found to contain "upgrade".
func headerContains(h http.Header, k, v string) bool {
	return slices.ContainsFunc(h.Values(k), func(line string) bool {
		for _, token := range strings.Split(line, ",") {
			if strings.EqualFold(strings.TrimSpace(token), v) {
				return true
			}
		}
		return false
	})
}
|