package proxy import ( "bufio" "bytes" "context" "crypto/tls" "errors" "fmt" "io" "log" "net" "net/http" "net/url" "slices" "strconv" "strings" "syscall" "time" "git.maze.io/maze/styx/internal/netutil" "git.maze.io/maze/styx/logger" "git.maze.io/maze/styx/stats" ) // Common HTTP headers. const ( HeaderConnection = "Connection" HeaderContentType = "Content-Type" HeaderDate = "Date" HeaderForwarded = "Forwarded" HeaderForwardedFor = "X-Forwarded-For" HeaderForwardedHost = "X-Forwarded-Host" HeaderForwardedPort = "X-Forwarded-Port" HeaderForwardedProto = "X-Forwarded-Proto" HeaderLocation = "Location" HeaderRealIP = "X-Real-Ip" HeaderUpgrade = "Upgrade" HeaderVia = "Via" ) // Safe defaults. const ( DefaultDialTimeout = 15 * time.Second DefaultIdleTimeout = 10 * time.Second DefaultWebSocketIdleTimeout = 30 * time.Second ) var ( // AccessLog is used for logging requests to the proxy. AccessLog = logger.Get() // ServerLog is used for logging server log messages. ServerLog = logger.Get() ) // Proxy implements a HTTP(S) proxy. type Proxy struct { // RoundTripper is used to make outbound HTTP requests. It defaults to a [http.Transport] // with a custom DialContext that uses the configured [Dialer]s. // // Only override this if you know what you are doing. RoundTripper http.RoundTripper // Dialer is a map of protocol names to [Dialer] implementations. The default [Dialer] // corresponds to an empty string key. // // Only override the default [Dialer] if you know what you are doing. Dialer map[string]Dialer // OnConnect is a list of connection filters that are applied in order when a new // connection is established. // // Connection filters can be used to implement custom authentication, logging, // rate limiting, etc. // // Connection filters are applied before any HTTP request is read from the connection. // // Connection filters should return a non-nil error if they want to terminate the // connection. Returning a non-nil [net.Conn] will replace the existing connection // with the returned one. // // Connection filters should not modify the connection in any way (e.g. wrapping it // in a TLS connection) as this will interfere with the proxy's ability to read // HTTP requests from the connection. // // Connection filters are executed sequentially in the order they are added. OnConnect []ConnHandler // OnRequest is a list of request filters that are applied in order when a new // HTTP request is read from the connection. // // Request filters can be used to modify the request, or to return a response // directly without forwarding the request to the upstream server. // // Request filters should return a non-nil error if they want to terminate the // connection. // // Request filters are executed sequentially in the order they are added. OnRequest []RequestHandler // OnResponse is a list of response filters that are applied in order when a // response is received from the upstream server. // // Response filters can be used to modify the response before it is sent to // the client. // // Response filters should return a non-nil error if they want to terminate the // connection. // // Response filters are executed sequentially in the order they are added. OnResponse []ResponseHandler // OnError is a list of error handlers that are applied in order when an // error occurs during request processing. // // Error handlers can be used to log errors, or to return a custom response // to the client. // // Error handlers should return a non-nil error if they want to terminate the // connection. // // Error handlers are executed sequentially in the order they are added. OnError []ErrorHandler // DialTimeout is the timeout for establishing new connections to upstream servers. DialTimeout time.Duration // IdleTimeout is the timeout for idle connections. IdleTimeout time.Duration // WebSocketIdleTimeout is the timeout for idle WebSocket connections. WebSocketIdleTimeout time.Duration mux *http.ServeMux } // New [Proxy] with somewhat sane defaults. func New() *Proxy { p := &Proxy{ Dialer: map[string]Dialer{"": defaultDialer{}}, DialTimeout: DefaultDialTimeout, IdleTimeout: DefaultIdleTimeout, WebSocketIdleTimeout: DefaultWebSocketIdleTimeout, mux: http.NewServeMux(), } // Make sure the roundtripper uses our dialers. p.RoundTripper = &http.Transport{ TLSNextProto: make(map[string]func(string, *tls.Conn) http.RoundTripper), Proxy: http.ProxyFromEnvironment, TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: time.Second, DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { return p.Dialer[""].DialContext(ctx, &http.Request{ URL: &url.URL{ Scheme: network, Host: addr, }, }) }, } p.Handle("/stats", stats.Handler(stats.Exposed)) p.Handle("/stats.json", stats.JSONHandler(stats.Exposed)) return p } // Handle installs a [http.Handler] into the internal mux. func (p *Proxy) Handle(pattern string, handler http.Handler) { p.mux.Handle(pattern, handler) } // HandleFunc installs a [http.HandlerFunc] into the internal mux. func (p *Proxy) HandleFunc(pattern string, handler func(http.ResponseWriter, *http.Request)) { p.mux.HandleFunc(pattern, handler) } // SetDialer specifies a [Dialer] for the specified protocol. The default [Dialer] corresponds // to an empty string. Only override the default [Dialer] if you know what you are doing. func (p *Proxy) SetDialer(proto string, dialer Dialer) { if dialer == nil { if proto != "" { delete(p.Dialer, proto) } } else { p.Dialer[proto] = dialer } } func (p *Proxy) dial(ctx context.Context, req *http.Request) (net.Conn, error) { d, ok := p.Dialer[req.URL.Scheme] if !ok { d = p.Dialer[""] } return d.DialContext(ctx, req) } // Serve proxied connections on the specified listener. func (p *Proxy) Serve(l net.Listener) error { for { c, err := l.Accept() if err != nil { return err } go p.handle(c) } } func (p *Proxy) handle(nc net.Conn) { var ( start = time.Now() ctx = NewContext(nc).(*proxyContext) err error ) defer func() { if r := recover(); r != nil { if err, ok := r.(error); ok { ctx.LogEntry().Err(err).Warn("Bug in code, recovered from panic!") } _ = nc.Close() } }() defer func() { if cerr := ctx.Close(); cerr != nil && err == nil && !netutil.IsClosing(cerr) { err = cerr } log := ctx.AccessLogEntry().Value("duration", time.Since(start)) if err != nil && !netutil.IsClosing(err) { log = log.Err(err) } if req := ctx.Request(); req != nil { log = log.Values(logger.Values{ "method": req.Method, "request": req.URL.String(), }) } if res := ctx.Response(); res != nil { //countStatus(res.StatusCode) log.Values(logger.Values{ "response": res.StatusCode, }).Info(res.Status) } else { //countStatus(0) log.Info("No response") } }() // Propagate timeouts ctx.SetIdleTimeout(p.IdleTimeout) for _, f := range p.OnConnect { fc, err := f.HandleConn(ctx) if err != nil { ServerLog.Value("filter", fmt.Sprintf("%T", f)).Err(err).Warn("Error in conn filter") p.handleError(ctx, err, true) _ = nc.Close() return } else if fc != nil { ServerLog.Value("filter", fmt.Sprintf("%T", f)).Debug("Replacing connection from filter") ctx.Conn = fc ctx.br = bufio.NewReader(fc) } } for { if ctx.transparentTLS { ctx.req = &http.Request{ Method: http.MethodConnect, URL: &url.URL{Scheme: "tcp", Host: net.JoinHostPort(ctx.serverName, strconv.Itoa(ctx.transparent))}, Host: net.JoinHostPort(ctx.serverName, strconv.Itoa(ctx.transparent)), Proto: "HTTP/1.1", ProtoMajor: 1, ProtoMinor: 1, Close: true, } } else if ctx.req, err = http.ReadRequest(ctx.Reader()); err != nil { if !(errors.Is(err, io.EOF) || errors.Is(err, syscall.ECONNRESET)) { ServerLog.Err(err).Debug("Error reading request") } p.handleError(ctx, err, true) return } if ctx.transparent > 0 { // Canonicallize to absolute URL if ctx.req.URL.Host == "" { ctx.req.URL.Host = ctx.req.Host } if ctx.req.URL.Scheme == "" { ctx.req.URL.Scheme = "http" } ctx.transparent = 0 } for _, f := range p.OnRequest { newReq, newRes := f.HandleRequest(ctx) if newReq != nil { ServerLog.Values(logger.Values{ "filter": fmt.Sprintf("%T", f), "old_method": ctx.req.Method, "old_url": ctx.req.URL, "new_method": newReq.Method, "new_url": newReq.URL, }).Debug("Replacing request from filter") ctx.req = newReq } if newRes != nil { log := ServerLog.Values(logger.Values{ "filter": fmt.Sprintf("%T", f), "response": newRes.StatusCode, "status": newRes.Status, }) log.Debug("Replacing response from filter") ctx.res = newRes if err = p.writeResponse(ctx); err != nil { if netutil.IsClosing(err) { return } log.Err(err).Warn("Error overriding repsonse") } continue } } if err = p.handleRequest(ctx); err != nil { p.handleError(ctx, err, true) return } // Only once if ctx.transparent > 0 || ctx.transparentTLS || ctx.req.Method == http.MethodConnect { return } } } func (p *Proxy) handleError(ctx *proxyContext, err error, sendResponse bool) { res := ctx.Response() if res == nil && sendResponse { res = NewErrorResponse(err, ctx.Request()) } for _, f := range p.OnError { if newRes := f.HandleError(ctx, err); newRes != nil { res = newRes } } if sendResponse && res != nil { if werr := p.writeResponse(ctx); werr != nil && !netutil.IsClosing(err) { ServerLog.Err(werr).Warn("Error writing error response") } } } func (p *Proxy) handleRequest(ctx *proxyContext) (err error) { switch { case ctx.req == nil: ctx.LogEntry().Warn("Request is nil in handleRequest!?") return errors.New("proxy: request is nil?") case headerContains(ctx.req.Header, HeaderConnection, "upgrade"): if headerContains(ctx.req.Header, HeaderUpgrade, "websocket") { return p.serveWebSocket(ctx) } ctx.res = NewResponse(http.StatusBadRequest, nil, ctx.req) return p.writeResponse(ctx) case ctx.req.Method == http.MethodConnect: return p.serveConnect(ctx) case ctx.req.URL.IsAbs(): return p.serveForward(ctx) default: return p.serve(ctx) } } func (p *Proxy) applyResponseHandler(ctx *proxyContext) { for _, f := range p.OnResponse { if newRes := f.HandleResponse(ctx); newRes != nil { if ctx.res.Body != nil { _ = ctx.res.Body.Close() } ctx.res = newRes } } } func (p *Proxy) serve(ctx *proxyContext) (err error) { var ( b = new(bytes.Buffer) cw = ctx.cw ) // This is where our response headers etc. are captured ctx.res = NewResponse(http.StatusOK, nil, ctx.req) // This is where our response body is captured ctx.cw = &countingWriter{writer: b, bytes: ctx.cw.bytes} // Pass ServeHTTP call to mux handler(s) p.mux.ServeHTTP(ctx, ctx.req) // Expose body ctx.res.Body = io.NopCloser(b) // Correct headers if ctx.res.Header.Get(HeaderDate) == "" { ctx.res.Header.Set(HeaderDate, time.Now().UTC().Format("Mon, 2 Jan 2006 15:04:05")+" GMT") } if ctx.res.Header.Get(HeaderContentType) == "" && b.Len() > 0 { ctx.res.Header.Set(HeaderContentType, "text/html; charset=utf-8") } // Restore writer for the call to writeResponse ctx.cw = cw return p.writeResponse(ctx) } func (p *Proxy) serveConnect(ctx *proxyContext) (err error) { log := ctx.LogEntry() // Most browsers expect to get a 200 OK after firing a HTTP CONNECT request; if the upstream // encounters any errors, we'll inform the client after reading the HTTP request that follows. if !(ctx.transparent > 0 || ctx.transparentTLS) { if _, err = io.WriteString(ctx, "HTTP/1.1 200 Connection Established\r\n\r\n"); err != nil { return } } switch ctx.req.URL.Scheme { case "": ctx.req.URL.Scheme = "tcp" } log.Value("target", ctx.req.URL.String()).Debugf("%s CONNECT request", ctx.req.Proto) var ( timeout, cancel = context.WithTimeout(context.Background(), p.DialTimeout) c net.Conn ) if c, err = p.dial(timeout, ctx.req); err != nil { cancel() ctx.res = NewErrorResponse(err, ctx.req) _ = p.writeResponse(ctx) _ = ctx.Close() return fmt.Errorf("proxy: dial %s error: %w", ctx.req.URL, err) } cancel() ctx.res = NewResponse(http.StatusOK, nil, ctx.req) srv := NewContext(c).(*proxyContext) srv.SetIdleTimeout(p.IdleTimeout) return p.multiplex(ctx, srv) } func (p *Proxy) serveForward(ctx *proxyContext) (err error) { log := ctx.LogEntry() log.Value("target", ctx.req.URL.String()).Debugf("%s forward request", ctx.req.Proto) if ctx.res, err = p.RoundTripper.RoundTrip(ctx.req); err != nil { // log.Printf("%s forward request error: %v", ctx, err) ctx.res = NewErrorResponse(err, ctx.req) _ = p.writeResponse(ctx) _ = ctx.Close() return fmt.Errorf("proxy: forward %s error: %w", ctx.req.URL, err) } p.applyResponseHandler(ctx) return p.writeResponse(ctx) } func (p *Proxy) serveWebSocket(ctx *proxyContext) (err error) { log := ctx.LogEntry().Value("target", ctx.req.URL.String()) switch ctx.req.URL.Scheme { case "http": ctx.req.URL.Scheme = "ws" case "https": ctx.req.URL.Scheme = "wss" } log.Debugf("%s websocket request", ctx.req.Proto) var ( timeout, cancel = context.WithTimeout(context.Background(), p.DialTimeout) c net.Conn ) if c, err = p.dial(timeout, ctx.req); err != nil { cancel() ctx.res = NewErrorResponse(err, ctx.req) _ = p.writeResponse(ctx) _ = ctx.Close() return fmt.Errorf("proxy: dial %s error: %w", ctx.req.URL, err) } cancel() srv := NewContext(c).(*proxyContext) srv.SetIdleTimeout(p.IdleTimeout) if err = ctx.req.Write(srv); err != nil { ctx.res = NewErrorResponse(err, ctx.req) _ = p.writeResponse(ctx) _ = ctx.Close() return fmt.Errorf("proxy: failed to write request to upstream: %w", err) } if ctx.res, err = http.ReadResponse(srv.Reader(), ctx.req); err != nil { ctx.res = NewErrorResponse(err, ctx.req) _ = p.writeResponse(ctx) _ = ctx.Close() return fmt.Errorf("proxy: failed to read response from upstream: %w", err) } log.Values(logger.Values{ "response": ctx.res.StatusCode, "status": ctx.res.Status, }).Debug("WebSocket response from upstream") if err = p.writeResponse(ctx); err != nil { _ = ctx.Close() return } ctx.SetIdleTimeout(p.WebSocketIdleTimeout) return p.multiplex(ctx, srv) } func (p *Proxy) multiplex(ctx, srv Context) (err error) { var ( errs = make(chan error, 1) done = make(chan struct{}, 1) ) go func(errs chan<- error) { defer close(done) if _, err := io.Copy(srv, ctx); err != nil { errs <- err } }(errs) go func(errs chan<- error) { if _, err := io.Copy(ctx, srv); err != nil { errs <- err } }(errs) select { case err = <-errs: return case <-done: return } } func (p *Proxy) writeResponse(ctx *proxyContext) (err error) { res := ctx.Response() for _, f := range p.OnResponse { if newRes := f.HandleResponse(ctx); newRes != nil { log.Printf("Filter returned response HTTP %s", newRes.Status) if res.Body != nil { _ = res.Body.Close() } res = newRes } } ServerLog.Values(logger.Values{ "close": res.Close, "header": res.Header, }).Debug("Writing response") if err = res.Write(ctx); err != nil { return } if res.Close || ctx.res.Close || strings.ToLower(ctx.res.Header.Get(HeaderConnection)) != "keep-alive" { // Force closing of connection. if err = ctx.Close(); err != nil { return } return io.EOF } return } func headerContains(h http.Header, k, v string) bool { vs := h[http.CanonicalHeaderKey(k)] return slices.ContainsFunc(vs, func(e string) bool { return strings.EqualFold(e, v) }) }