package proxy

import (
	"bufio"
	"bytes"
	"context"
	"crypto/tls"
	"errors"
	"fmt"
	"io"
	"log"
	"net"
	"net/http"
	"net/url"
	"slices"
	"strings"
	"syscall"
	"time"

	"git.maze.io/maze/styx/internal/netutil"
	"git.maze.io/maze/styx/stats"
	"github.com/sirupsen/logrus"
)

// Common HTTP headers.
const (
	HeaderConnection     = "Connection"
	HeaderContentType    = "Content-Type"
	HeaderDate           = "Date"
	HeaderForwarded      = "Forwarded"
	HeaderForwardedFor   = "X-Forwarded-For"
	HeaderForwardedHost  = "X-Forwarded-Host"
	HeaderForwardedPort  = "X-Forwarded-Port"
	HeaderForwardedProto = "X-Forwarded-Proto"
	HeaderRealIP         = "X-Real-Ip"
	HeaderUpgrade        = "Upgrade"
	HeaderVia            = "Via"
)

// Safe defaults.
const (
	// DefaultDialTimeout bounds how long dialing an upstream may take.
	DefaultDialTimeout = 15 * time.Second

	// DefaultIdleTimeout is the idle timeout set on client and upstream connections.
	DefaultIdleTimeout = 10 * time.Second

	// DefaultWebSocketIdleTimeout is the idle timeout set on a connection after
	// a WebSocket upgrade has been relayed.
	DefaultWebSocketIdleTimeout = 30 * time.Second
)

var (
	// AccessLog is used for logging requests to the proxy.
	AccessLog = logrus.StandardLogger()

	// ServerLog is used for logging server log messages.
	ServerLog = logrus.StandardLogger()
)

// Proxy implements a HTTP(S) proxy.
type Proxy struct {
	// rt performs forward (absolute URL) requests.
	rt http.RoundTripper

	// dialer maps a URL scheme to the Dialer used for it; the empty key
	// holds the default Dialer.
	dialer map[string]Dialer

	// connFilter, requestFilter and responseFilter are the filter stacks,
	// applied in the order the filters were added.
	connFilter     []ConnFilter
	requestFilter  []RequestFilter
	responseFilter []ResponseFilter

	dialTimeout          time.Duration
	idleTimeout          time.Duration
	webSocketIdleTimeout time.Duration

	// mux serves requests that are addressed to the proxy itself.
	mux *http.ServeMux
}

// New [Proxy] with somewhat sane defaults.
//
// The returned proxy uses the default dialer, the Default* timeouts and has
// the /stats and /stats.json handlers installed.
func New() *Proxy {
	p := &Proxy{
		dialer:               map[string]Dialer{"": defaultDialer{}},
		dialTimeout:          DefaultDialTimeout,
		idleTimeout:          DefaultIdleTimeout,
		webSocketIdleTimeout: DefaultWebSocketIdleTimeout,
		mux:                  http.NewServeMux(),
	}

	// Make sure the roundtripper uses our dialers.
	p.rt = &http.Transport{
		// A non-nil, empty TLSNextProto map keeps the transport on HTTP/1.x.
		TLSNextProto:          make(map[string]func(string, *tls.Conn) http.RoundTripper),
		Proxy:                 http.ProxyFromEnvironment,
		TLSHandshakeTimeout:   10 * time.Second,
		ExpectContinueTimeout: time.Second,
		DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
			// Bound the dial with the same timeout that serveConnect and
			// serveWebSocket apply, so a forward request cannot hang on an
			// unreachable upstream for longer than a tunnel would.
			if p.dialTimeout > 0 {
				var cancel context.CancelFunc
				ctx, cancel = context.WithTimeout(ctx, p.dialTimeout)
				defer cancel()
			}
			// The default dialer is looked up per dial, so a later
			// SetDialer("", ...) is honored.
			return p.dialer[""].DialContext(ctx, &http.Request{
				URL: &url.URL{
					Scheme: network,
					Host:   addr,
				},
			})
		},
	}

	p.Handle("/stats", stats.Handler(stats.Exposed))
	p.Handle("/stats.json", stats.JSONHandler(stats.Exposed))

	return p
}

// Handle installs a [http.Handler] into the internal mux.
func (p *Proxy) Handle(pattern string, handler http.Handler) {
	p.mux.Handle(pattern, handler)
}

// HandleFunc installs a [http.HandlerFunc] into the internal mux.
func (p *Proxy) HandleFunc(pattern string, handler func(http.ResponseWriter, *http.Request)) {
	p.mux.HandleFunc(pattern, handler)
}

// SetDialer specifies a [Dialer] for the specified protocol. The default [Dialer] corresponds
// to an empty string. Only override the default [Dialer] if you know what you are doing.
//
// Passing a nil dialer removes the [Dialer] registered for proto. The default
// [Dialer] can be replaced but not removed: a nil dialer with an empty proto
// is ignored.
func (p *Proxy) SetDialer(proto string, dialer Dialer) {
	if dialer != nil {
		p.dialer[proto] = dialer
		return
	}
	if proto != "" {
		delete(p.dialer, proto)
	}
}

// AddConnFilter adds a connection filter to the stack. A nil filter is ignored.
func (p *Proxy) AddConnFilter(f ConnFilter) {
	if f == nil {
		return
	}
	p.connFilter = append(p.connFilter, f)
}

// AddRequestFilter adds a request filter to the stack. A nil filter is ignored.
func (p *Proxy) AddRequestFilter(f RequestFilter) {
	if f == nil {
		return
	}
	p.requestFilter = append(p.requestFilter, f)
}

// AddResponseFilter adds a response filter to the stack. A nil filter is ignored.
func (p *Proxy) AddResponseFilter(f ResponseFilter) {
	if f == nil {
		return
	}
	p.responseFilter = append(p.responseFilter, f)
}

// dial connects to the upstream for req using the Dialer registered for the
// request URL scheme, falling back to the default Dialer.
func (p *Proxy) dial(ctx context.Context, req *http.Request) (net.Conn, error) {
	d, ok := p.dialer[req.URL.Scheme]
	if !ok {
		d = p.dialer[""]
	}
	return d.DialContext(ctx, req)
}

// Serve proxied connections on the specified listener. Each accepted
// connection is handled in its own goroutine; Serve returns the first error
// reported by Accept.
func (p *Proxy) Serve(l net.Listener) error { for { c, err := l.Accept() if err != nil { return err } go p.handle(c) } } func (p *Proxy) handle(nc net.Conn) { var ( start = time.Now() ctx = NewContext(nc).(*proxyContext) err error ) defer func() { if cerr := ctx.Close(); cerr != nil && err == nil { err = cerr } log := ctx.AccessLogEntry().WithField("duration", time.Since(start)) if err != nil && !netutil.IsClosing(err) { log = log.WithError(err) } if req := ctx.Request(); req != nil { log = log.WithFields(logrus.Fields{ "method": req.Method, "request": req.URL.String(), }) } if res := ctx.Response(); res != nil { //countStatus(res.StatusCode) log.WithFields(logrus.Fields{ "response": res.StatusCode, }).Info(res.Status) } else { //countStatus(0) log.Info("No response") } }() // Propagate timeouts ctx.SetIdleTimeout(p.idleTimeout) for _, f := range p.connFilter { fc, err := f.FilterConn(ctx) if err != nil { ServerLog.WithField("filter", fmt.Sprintf("%T", f)).WithError(err).Warn("error in conn filter") _ = nc.Close() return } else if fc != nil { ServerLog.WithField("filter", fmt.Sprintf("%T", f)).Debug("replacing connection from filter") ctx.Conn = fc ctx.br = bufio.NewReader(fc) } } for { if ctx.isTransparentTLS { ctx.req = &http.Request{ Method: http.MethodConnect, URL: &url.URL{ Scheme: "tcp", Host: net.JoinHostPort(ctx.serverName, "443"), }, } } else if ctx.req, err = http.ReadRequest(ctx.Reader()); err != nil { if !(errors.Is(err, io.EOF) || errors.Is(err, syscall.ECONNRESET)) { ServerLog.WithError(err).Debug("error reading request") } return } if ctx.isTransparent { // Canonicallize to absolute URL if ctx.req.URL.Host == "" { ctx.req.URL.Host = ctx.req.Host } if ctx.req.URL.Scheme == "" { ctx.req.URL.Scheme = "http" } ctx.isTransparent = false } for _, f := range p.requestFilter { newReq, newRes := f.FilterRequest(ctx) if newReq != nil { ServerLog.WithFields(logrus.Fields{ "filter": fmt.Sprintf("%T", f), "old_method": ctx.req.Method, "old_url": ctx.req.URL, 
"new_method": newReq.Method, "new_url": newReq.URL, }).Debug("replacing request from filter") ctx.req = newReq } if newRes != nil { log := ServerLog.WithFields(logrus.Fields{ "filter": fmt.Sprintf("%T", f), "response": newRes.StatusCode, "status": newRes.Status, }) log.Debug("replacing response from filter") ctx.res = newRes if err = p.writeResponse(ctx); err != nil { log.WithError(err).Warn("error overriding repsonse") } continue } } if err = p.handleRequest(ctx); err != nil { return } // Only once if ctx.isTransparent || ctx.isTransparentTLS { return } } } func (p *Proxy) handleRequest(ctx *proxyContext) (err error) { switch { case ctx.req == nil: ctx.LogEntry().Warn("Request is nil in handleRequest!?") return errors.New("proxy: request is nil?") case headerContains(ctx.req.Header, HeaderConnection, "upgrade"): if headerContains(ctx.req.Header, HeaderUpgrade, "websocket") { return p.serveWebSocket(ctx) } ctx.res = NewResponse(http.StatusBadRequest, nil, ctx.req) return p.writeResponse(ctx) case ctx.req.Method == http.MethodConnect: return p.serveConnect(ctx) case ctx.req.URL.IsAbs(): return p.serveForward(ctx) default: return p.serve(ctx) } } func (p *Proxy) applyResponseFilter(ctx *proxyContext) { for _, f := range p.responseFilter { if newRes := f.FilterResponse(ctx); newRes != nil { if ctx.res.Body != nil { _ = ctx.res.Body.Close() } ctx.res = newRes } } } func (p *Proxy) serve(ctx *proxyContext) (err error) { var ( b = new(bytes.Buffer) cw = ctx.cw ) // This is where our response headers etc. 
are captured ctx.res = NewResponse(http.StatusOK, nil, ctx.req) // This is where our response body is captured ctx.cw = &countingWriter{writer: b, bytes: ctx.cw.bytes} // Pass ServeHTTP call to mux handler(s) p.mux.ServeHTTP(ctx, ctx.req) // Expose body ctx.res.Body = io.NopCloser(b) // Correct headers if ctx.res.Header.Get(HeaderDate) == "" { ctx.res.Header.Set(HeaderDate, time.Now().UTC().Format("Mon, 2 Jan 2006 15:04:05")+" GMT") } if ctx.res.Header.Get(HeaderContentType) == "" && b.Len() > 0 { ctx.res.Header.Set(HeaderContentType, "text/html; charset=utf-8") } // Restore writer for the call to writeResponse ctx.cw = cw return p.writeResponse(ctx) } func (p *Proxy) serveConnect(ctx *proxyContext) (err error) { log := ctx.LogEntry() // Most browsers expect to get a 200 OK after firing a HTTP CONNECT request; if the upstream // encounters any errors, we'll inform the client after reading the HTTP request that follows. if !(ctx.isTransparent || ctx.isTransparentTLS) { if _, err = io.WriteString(ctx, "HTTP/1.1 200 Connection Established\r\n\r\n"); err != nil { return } } switch ctx.req.URL.Scheme { case "": ctx.req.URL.Scheme = "tcp" } log.WithField("target", ctx.req.URL.String()).Debug("http CONNECT request") var ( timeout, cancel = context.WithTimeout(context.Background(), p.dialTimeout) c net.Conn ) if c, err = p.dial(timeout, ctx.req); err != nil { cancel() ctx.res = NewErrorResponse(err, ctx.req) _ = p.writeResponse(ctx) _ = ctx.Close() return fmt.Errorf("proxy: dial %s error: %w", ctx.req.URL, err) } cancel() ctx.res = NewResponse(http.StatusOK, nil, ctx.req) srv := NewContext(c).(*proxyContext) srv.SetIdleTimeout(p.idleTimeout) return p.multiplex(ctx, srv) } func (p *Proxy) serveForward(ctx *proxyContext) (err error) { log := ctx.LogEntry() log.WithField("target", ctx.req.URL.String()).Debug("http forward request") if ctx.res, err = p.rt.RoundTrip(ctx.req); err != nil { // log.Printf("%s forward request error: %v", ctx, err) ctx.res = NewErrorResponse(err, 
ctx.req) _ = p.writeResponse(ctx) _ = ctx.Close() return fmt.Errorf("proxy: forward %s error: %w", ctx.req.URL, err) } p.applyResponseFilter(ctx) return p.writeResponse(ctx) } func (p *Proxy) serveWebSocket(ctx *proxyContext) (err error) { log := ctx.LogEntry().WithField("target", ctx.req.URL.String()) switch ctx.req.URL.Scheme { case "http": ctx.req.URL.Scheme = "ws" case "https": ctx.req.URL.Scheme = "wss" } log.Debug("http websocket request") var ( timeout, cancel = context.WithTimeout(context.Background(), p.dialTimeout) c net.Conn ) if c, err = p.dial(timeout, ctx.req); err != nil { cancel() ctx.res = NewErrorResponse(err, ctx.req) _ = p.writeResponse(ctx) _ = ctx.Close() return fmt.Errorf("proxy: dial %s error: %w", ctx.req.URL, err) } cancel() srv := NewContext(c).(*proxyContext) srv.SetIdleTimeout(p.idleTimeout) if err = ctx.req.Write(srv); err != nil { ctx.res = NewErrorResponse(err, ctx.req) _ = p.writeResponse(ctx) _ = ctx.Close() return fmt.Errorf("proxy: failed to write request to upstream: %w", err) } if ctx.res, err = http.ReadResponse(srv.Reader(), ctx.req); err != nil { ctx.res = NewErrorResponse(err, ctx.req) _ = p.writeResponse(ctx) _ = ctx.Close() return fmt.Errorf("proxy: failed to read response from upstream: %w", err) } log.WithFields(logrus.Fields{ "response": ctx.res.StatusCode, "status": ctx.res.Status, }).Debug("websocket response from upstream") if err = p.writeResponse(ctx); err != nil { _ = ctx.Close() return } ctx.SetIdleTimeout(p.webSocketIdleTimeout) return p.multiplex(ctx, srv) } func (p *Proxy) multiplex(ctx, srv Context) (err error) { var ( errs = make(chan error, 1) done = make(chan struct{}, 1) ) go func(errs chan<- error) { defer close(done) if _, err := io.Copy(srv, ctx); err != nil { errs <- err } }(errs) go func(errs chan<- error) { if _, err := io.Copy(ctx, srv); err != nil { errs <- err } }(errs) select { case err = <-errs: return case <-done: return } } func (p *Proxy) writeResponse(ctx *proxyContext) (err error) { res 
:= ctx.Response() for _, f := range p.responseFilter { if newRes := f.FilterResponse(ctx); newRes != nil { log.Printf("filter returned response HTTP %s", newRes.Status) if res.Body != nil { _ = res.Body.Close() } res = newRes } } ServerLog.WithFields(logrus.Fields{ "close": res.Close, "header": res.Header, }).Debug("writing response") if err = res.Write(ctx); err != nil { return } if res.Close || ctx.res.Close || strings.ToLower(ctx.res.Header.Get(HeaderConnection)) != "keep-alive" { // Force closing of connection. if err = ctx.Close(); err != nil { return } return io.EOF } return } func headerContains(h http.Header, k, v string) bool { vs := h[http.CanonicalHeaderKey(k)] return slices.ContainsFunc(vs, func(e string) bool { return strings.EqualFold(e, v) }) }