Checkpoint

This commit is contained in:
2025-10-01 15:37:55 +02:00
parent 4a60059ff2
commit 03352e3312
31 changed files with 2611 additions and 384 deletions

View File

@@ -14,8 +14,7 @@ import (
"sync/atomic"
"time"
"git.maze.io/maze/styx/internal/netutil/arp"
"github.com/sirupsen/logrus"
"git.maze.io/maze/styx/logger"
)
// Context provides convenience functions for the current ongoing HTTP proxy transaction (request).
@@ -71,17 +70,16 @@ func (w *countingWriter) Write(p []byte) (n int, err error) {
type proxyContext struct {
net.Conn
id uint64
mac net.HardwareAddr
cr *countingReader
br *bufio.Reader
cw *countingWriter
isTransparent bool
isTransparentTLS bool
serverName string
req *http.Request
res *http.Response
idleTimeout time.Duration
id uint64
cr *countingReader
br *bufio.Reader
cw *countingWriter
transparent int
transparentTLS bool
serverName string
req *http.Request
res *http.Response
idleTimeout time.Duration
}
// NewContext returns an initialized context for the provided [net.Conn].
@@ -98,7 +96,6 @@ func NewContext(c net.Conn) Context {
return &proxyContext{
Conn: c,
id: binary.BigEndian.Uint64(b),
mac: arp.Get(c.RemoteAddr()),
cr: cr,
br: bufio.NewReader(cr),
cw: cw,
@@ -106,26 +103,23 @@ func NewContext(c net.Conn) Context {
}
}
func (c *proxyContext) AccessLogEntry() *logrus.Entry {
func (c *proxyContext) AccessLogEntry() logger.Structured {
var id [8]byte
binary.BigEndian.PutUint64(id[:], c.id)
entry := AccessLog.WithFields(logrus.Fields{
entry := AccessLog.Values(logger.Values{
"client": c.RemoteAddr().String(),
"server": c.LocalAddr().String(),
"id": hex.EncodeToString(id[:]),
"bytes_rx": c.BytesRead(),
"bytes_tx": c.BytesSent(),
})
if c.mac != nil {
return entry.WithField("client_mac", c.mac.String())
}
return entry
}
func (c *proxyContext) LogEntry() *logrus.Entry {
func (c *proxyContext) LogEntry() logger.Structured {
var id [8]byte
binary.BigEndian.PutUint64(id[:], c.id)
return ServerLog.WithFields(logrus.Fields{
return ServerLog.Values(logger.Values{
"client": c.RemoteAddr().String(),
"server": c.LocalAddr().String(),
"id": hex.EncodeToString(id[:]),

View File

@@ -15,7 +15,12 @@ import (
"git.maze.io/maze/styx/internal/netutil"
)
// Dialer can make outbound connections to upstream servers.
type Dialer interface {
// DialContext makes a new connection to the address specified in the [http.Request].
//
// The [http.Request] contains the URL scheme (http, https, ws, wss) and host (with optional port)
// to connect to. The [context.Context] may be used for cancellation and timeouts.
DialContext(context.Context, *http.Request) (net.Conn, error)
}
@@ -71,25 +76,38 @@ func (defaultDialer) DialContext(ctx context.Context, req *http.Request) (net.Co
}
}
// ConnFilter is called when a new connection has been accepted by the proxy.
type ConnFilter interface {
FilterConn(Context) (net.Conn, error)
// ErrorHandler can handle errors that occur during proxying.
type ErrorHandler interface {
// HandleError handles an error that occurred during proxying. If the method returns a non-nil
// [http.Response], it will be sent to the client as-is. If it returns nil, a generic HTTP 502
// Bad Gateway response will be sent to the client.
//
// The [Context] may be inspected to obtain information about the request that caused the error.
// However, the [Context.Request] and [Context.Response] may be nil depending on when the error
// occurred.
HandleError(Context, error) *http.Response
}
// ConnFilterFunc is a function that implements the [ConnFilter] interface.
type ConnFilterFunc func(Context) (net.Conn, error)
// ConnHandler is called when a new connection has been accepted by the proxy.
type ConnHandler interface {
HandleConn(Context) (net.Conn, error)
}
func (f ConnFilterFunc) FilterConn(ctx Context) (net.Conn, error) {
// ConnHandlerFunc is a function that implements the [ConnHandler] interface.
type ConnHandlerFunc func(Context) (net.Conn, error)
func (f ConnHandlerFunc) HandleConn(ctx Context) (net.Conn, error) {
return f(ctx)
}
// TLS starts a TLS handshake on the accepted connection.
func TLS(certs []tls.Certificate) ConnFilter {
return ConnFilterFunc(func(ctx Context) (net.Conn, error) {
s := tls.Server(ctx, &tls.Config{
Certificates: certs,
NextProtos: []string{"http/1.1"},
})
func TLS(config *tls.Config) ConnHandler {
if config == nil {
config = new(tls.Config)
}
config.NextProtos = []string{"http/1.1"}
return ConnHandlerFunc(func(ctx Context) (net.Conn, error) {
s := tls.Server(ctx, config)
if err := s.Handshake(); err != nil {
return nil, err
}
@@ -98,8 +116,8 @@ func TLS(certs []tls.Certificate) ConnFilter {
}
// TLSInterceptor can generate certificates on-the-fly for clients that use a compatible TLS version.
func TLSInterceptor(ca ca.CertificateAuthority) ConnFilter {
return ConnFilterFunc(func(ctx Context) (net.Conn, error) {
func TLSInterceptor(ca ca.CertificateAuthority) ConnHandler {
return ConnHandlerFunc(func(ctx Context) (net.Conn, error) {
s := tls.Server(ctx, &tls.Config{
GetCertificate: func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
ips := []net.IP{net.ParseIP(netutil.Host(ctx.RemoteAddr().String()))}
@@ -119,26 +137,27 @@ func TLSInterceptor(ca ca.CertificateAuthority) ConnFilter {
// When a new [net.Conn] is made, this function will inspect the initial request packet for a
// TLS handshake. If a TLS handshake is detected, the connection will make a feaux HTTP CONNECT
// request using TLS, if no handshake is detected, it will make a feaux plain HTTP CONNECT request.
func Transparent() ConnFilter {
return ConnFilterFunc(func(nctx Context) (net.Conn, error) {
func Transparent(port int) ConnHandler {
return ConnHandlerFunc(func(nctx Context) (net.Conn, error) {
ctx, ok := nctx.(*proxyContext)
if !ok {
return nctx, nil
}
b := new(bytes.Buffer)
hello, err := cryptutil.ReadClientHello(io.TeeReader(ctx, b))
hello, err := cryptutil.ReadClientHello(io.TeeReader(netutil.ReadOnlyConn{Reader: ctx.br}, b))
if err != nil {
if _, ok := err.(tls.RecordHeaderError); !ok {
ctx.LogEntry().WithError(err).WithField("error_type", fmt.Sprintf("%T", err)).Warn("TLS sniff error")
ctx.LogEntry().Err(err).Value("error_type", fmt.Sprintf("%T", err)).Warn("TLS sniff error")
return nil, err
}
// Not a TLS connection, moving on to regular HTTP request handling...
ctx.LogEntry().Debug("HTTP connection on transparent port")
ctx.isTransparent = true
ctx.transparent = port
} else {
ctx.LogEntry().WithField("target", hello.ServerName).Debug("TLS connection on transparent port")
ctx.isTransparentTLS = true
ctx.LogEntry().Value("target", hello.ServerName).Debug("TLS connection on transparent port")
ctx.transparent = port
ctx.transparentTLS = true
ctx.serverName = hello.ServerName
}
@@ -149,10 +168,10 @@ func Transparent() ConnFilter {
})
}
// RequestFilter can filter HTTP requests coming to the proxy.
type RequestFilter interface {
// FilterRequest filters a HTTP request made to the proxy. The current request may be obtained
// from [Context.Request]. If a previous RequestFilter provided a HTTP response, it is available
// RequestHandler can filter HTTP requests coming to the proxy.
type RequestHandler interface {
// HandlerRequest filters a HTTP request made to the proxy. The current request may be obtained
// from [Context.Request]. If a previous RequestHandler provided a HTTP response, it is available
// from [Context.Response].
//
// Modifications to the current request can be made to the Request returned by [Context.Request]
@@ -160,35 +179,35 @@ type RequestFilter interface {
//
// If the filter returns a non-nil [http.Response], then the [Request] will not be proxied to
// any upstream server(s).
FilterRequest(Context) (*http.Request, *http.Response)
HandleRequest(Context) (*http.Request, *http.Response)
}
// RequestFilterFunc is a function that implements the [RequestFilter] interface.
type RequestFilterFunc func(Context) (*http.Request, *http.Response)
// RequestHandlerFunc is a function that implements the [RequestHandler] interface.
type RequestHandlerFunc func(Context) (*http.Request, *http.Response)
func (f RequestFilterFunc) FilterRequest(ctx Context) (*http.Request, *http.Response) {
func (f RequestHandlerFunc) HandleRequest(ctx Context) (*http.Request, *http.Response) {
return f(ctx)
}
// ResponseFilter can filter HTTP responses coming from the proxy.
type ResponseFilter interface {
// FilterResponse filters a HTTP response coming from the proxy. The current response may be
// ResponseHandler can filter HTTP responses coming from the proxy.
type ResponseHandler interface {
// HandlerResponse filters a HTTP response coming from the proxy. The current response may be
// obtained from [Context.Response].
//
// Modifications to the current response can be made to the [Response] returned by [Context.Response].
FilterResponse(Context) *http.Response
HandleResponse(Context) *http.Response
}
// ResponseFilterFunc is a function that implements the [ResponseFilter] interface.
type ResponseFilterFunc func(Context) *http.Response
// ResponseHandlerFunc is a function that implements the [ResponseHandler] interface.
type ResponseHandlerFunc func(Context) *http.Response
func (f ResponseFilterFunc) FilterResponse(ctx Context) *http.Response {
func (f ResponseHandlerFunc) HandleResponse(ctx Context) *http.Response {
return f(ctx)
}
// CleanRequestProxyHeaders removes all headers added by downstream proxies from the [http.Request].
func CleanRequestProxyHeaders() RequestFilter {
return RequestFilterFunc(func(ctx Context) (*http.Request, *http.Response) {
func CleanRequestProxyHeaders() RequestHandler {
return RequestHandlerFunc(func(ctx Context) (*http.Request, *http.Response) {
if req := ctx.Request(); req != nil {
cleanProxyHeaders(req.Header)
}
@@ -197,8 +216,8 @@ func CleanRequestProxyHeaders() RequestFilter {
}
// CleanRequestProxyHeaders removes all headers for upstream proxies from the [http.Response].
func CleanResponseProxyHeaders() ResponseFilter {
return ResponseFilterFunc(func(ctx Context) *http.Response {
func CleanResponseProxyHeaders() ResponseHandler {
return ResponseHandlerFunc(func(ctx Context) *http.Response {
if res := ctx.Response(); res != nil {
cleanProxyHeaders(res.Header)
}
@@ -208,8 +227,8 @@ func CleanResponseProxyHeaders() ResponseFilter {
// AddRequestHeaders adds headers to the [http.Request]. Any existing headers with the same
// key will remain intact.
func AddRequestHeaders(h http.Header) RequestFilter {
return RequestFilterFunc(func(ctx Context) (*http.Request, *http.Response) {
func AddRequestHeaders(h http.Header) RequestHandler {
return RequestHandlerFunc(func(ctx Context) (*http.Request, *http.Response) {
if req := ctx.Request(); req != nil {
if req.Header == nil {
req.Header = make(http.Header)
@@ -222,8 +241,8 @@ func AddRequestHeaders(h http.Header) RequestFilter {
// SetRequestHeaders sets headers to the [http.Request]. Any existing headers with the same
// key will be removed.
func SetRequestHeaders(h http.Header) RequestFilter {
return RequestFilterFunc(func(ctx Context) (*http.Request, *http.Response) {
func SetRequestHeaders(h http.Header) RequestHandler {
return RequestHandlerFunc(func(ctx Context) (*http.Request, *http.Response) {
if req := ctx.Request(); req != nil {
if req.Header == nil {
req.Header = make(http.Header)
@@ -236,8 +255,8 @@ func SetRequestHeaders(h http.Header) RequestFilter {
// AddResponseHeaders adds headers to the [http.Response]. Any existing headers with the same
// key will remain intact.
func AddResponseHeaders(h http.Header) ResponseFilter {
return ResponseFilterFunc(func(ctx Context) *http.Response {
func AddResponseHeaders(h http.Header) ResponseHandler {
return ResponseHandlerFunc(func(ctx Context) *http.Response {
if res := ctx.Response(); res != nil {
if res.Header == nil {
res.Header = make(http.Header)
@@ -250,8 +269,8 @@ func AddResponseHeaders(h http.Header) ResponseFilter {
// SetResponseHeaders sets headers to the [http.Response]. Any existing headers with the same
// key will be removed.
func SetResponseHeaders(h http.Header) ResponseFilter {
return ResponseFilterFunc(func(ctx Context) *http.Response {
func SetResponseHeaders(h http.Header) ResponseHandler {
return ResponseHandlerFunc(func(ctx Context) *http.Response {
if res := ctx.Response(); res != nil {
if res.Header == nil {
res.Header = make(http.Header)

View File

@@ -13,13 +13,14 @@ import (
"net/http"
"net/url"
"slices"
"strconv"
"strings"
"syscall"
"time"
"git.maze.io/maze/styx/internal/netutil"
"git.maze.io/maze/styx/logger"
"git.maze.io/maze/styx/stats"
"github.com/sirupsen/logrus"
)
// Common HTTP headers.
@@ -32,6 +33,7 @@ const (
HeaderForwardedHost = "X-Forwarded-Host"
HeaderForwardedPort = "X-Forwarded-Port"
HeaderForwardedProto = "X-Forwarded-Proto"
HeaderLocation = "Location"
HeaderRealIP = "X-Real-Ip"
HeaderUpgrade = "Upgrade"
HeaderVia = "Via"
@@ -46,43 +48,111 @@ const (
var (
// AccessLog is used for logging requests to the proxy.
AccessLog = logrus.StandardLogger()
AccessLog = logger.Get()
// ServerLog is used for logging server log messages.
ServerLog = logrus.StandardLogger()
ServerLog = logger.Get()
)
// Proxy implements a HTTP(S) proxy.
type Proxy struct {
rt http.RoundTripper
dialer map[string]Dialer
connFilter []ConnFilter
requestFilter []RequestFilter
responseFilter []ResponseFilter
dialTimeout time.Duration
idleTimeout time.Duration
webSocketIdleTimeout time.Duration
mux *http.ServeMux
// RoundTripper is used to make outbound HTTP requests. It defaults to a [http.Transport]
// with a custom DialContext that uses the configured [Dialer]s.
//
// Only override this if you know what you are doing.
RoundTripper http.RoundTripper
// Dialer is a map of protocol names to [Dialer] implementations. The default [Dialer]
// corresponds to an empty string key.
//
// Only override the default [Dialer] if you know what you are doing.
Dialer map[string]Dialer
// OnConnect is a list of connection filters that are applied in order when a new
// connection is established.
//
// Connection filters can be used to implement custom authentication, logging,
// rate limiting, etc.
//
// Connection filters are applied before any HTTP request is read from the connection.
//
// Connection filters should return a non-nil error if they want to terminate the
// connection. Returning a non-nil [net.Conn] will replace the existing connection
// with the returned one.
//
// Connection filters should not modify the connection in any way (e.g. wrapping it
// in a TLS connection) as this will interfere with the proxy's ability to read
// HTTP requests from the connection.
//
// Connection filters are executed sequentially in the order they are added.
OnConnect []ConnHandler
// OnRequest is a list of request filters that are applied in order when a new
// HTTP request is read from the connection.
//
// Request filters can be used to modify the request, or to return a response
// directly without forwarding the request to the upstream server.
//
// Request filters should return a non-nil error if they want to terminate the
// connection.
//
// Request filters are executed sequentially in the order they are added.
OnRequest []RequestHandler
// OnResponse is a list of response filters that are applied in order when a
// response is received from the upstream server.
//
// Response filters can be used to modify the response before it is sent to
// the client.
//
// Response filters should return a non-nil error if they want to terminate the
// connection.
//
// Response filters are executed sequentially in the order they are added.
OnResponse []ResponseHandler
// OnError is a list of error handlers that are applied in order when an
// error occurs during request processing.
//
// Error handlers can be used to log errors, or to return a custom response
// to the client.
//
// Error handlers should return a non-nil error if they want to terminate the
// connection.
//
// Error handlers are executed sequentially in the order they are added.
OnError []ErrorHandler
// DialTimeout is the timeout for establishing new connections to upstream servers.
DialTimeout time.Duration
// IdleTimeout is the timeout for idle connections.
IdleTimeout time.Duration
// WebSocketIdleTimeout is the timeout for idle WebSocket connections.
WebSocketIdleTimeout time.Duration
mux *http.ServeMux
}
// New [Proxy] with somewhat sane defaults.
func New() *Proxy {
p := &Proxy{
dialer: map[string]Dialer{"": defaultDialer{}},
dialTimeout: DefaultDialTimeout,
idleTimeout: DefaultIdleTimeout,
webSocketIdleTimeout: DefaultWebSocketIdleTimeout,
Dialer: map[string]Dialer{"": defaultDialer{}},
DialTimeout: DefaultDialTimeout,
IdleTimeout: DefaultIdleTimeout,
WebSocketIdleTimeout: DefaultWebSocketIdleTimeout,
mux: http.NewServeMux(),
}
// Make sure the roundtripper uses our dialers.
p.rt = &http.Transport{
p.RoundTripper = &http.Transport{
TLSNextProto: make(map[string]func(string, *tls.Conn) http.RoundTripper),
Proxy: http.ProxyFromEnvironment,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: time.Second,
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return p.dialer[""].DialContext(ctx, &http.Request{
return p.Dialer[""].DialContext(ctx, &http.Request{
URL: &url.URL{
Scheme: network,
Host: addr,
@@ -112,41 +182,17 @@ func (p *Proxy) HandleFunc(pattern string, handler func(http.ResponseWriter, *ht
func (p *Proxy) SetDialer(proto string, dialer Dialer) {
if dialer == nil {
if proto != "" {
delete(p.dialer, proto)
delete(p.Dialer, proto)
}
} else {
p.dialer[proto] = dialer
p.Dialer[proto] = dialer
}
}
// AddConnFilter adds a connection filter to the stack.
func (p *Proxy) AddConnFilter(f ConnFilter) {
if f == nil {
return
}
p.connFilter = append(p.connFilter, f)
}
// AddRequestFilter adds a request filter to the stack.
func (p *Proxy) AddRequestFilter(f RequestFilter) {
if f == nil {
return
}
p.requestFilter = append(p.requestFilter, f)
}
// AddResponseFilter adds a response filter to the stack.
func (p *Proxy) AddResponseFilter(f ResponseFilter) {
if f == nil {
return
}
p.responseFilter = append(p.responseFilter, f)
}
func (p *Proxy) dial(ctx context.Context, req *http.Request) (net.Conn, error) {
d, ok := p.dialer[req.URL.Scheme]
d, ok := p.Dialer[req.URL.Scheme]
if !ok {
d = p.dialer[""]
d = p.Dialer[""]
}
return d.DialContext(ctx, req)
@@ -170,23 +216,32 @@ func (p *Proxy) handle(nc net.Conn) {
err error
)
defer func() {
if cerr := ctx.Close(); cerr != nil && err == nil {
if r := recover(); r != nil {
if err, ok := r.(error); ok {
ctx.LogEntry().Err(err).Warn("Bug in code, recovered from panic!")
}
_ = nc.Close()
}
}()
defer func() {
if cerr := ctx.Close(); cerr != nil && err == nil && !netutil.IsClosing(cerr) {
err = cerr
}
log := ctx.AccessLogEntry().WithField("duration", time.Since(start))
log := ctx.AccessLogEntry().Value("duration", time.Since(start))
if err != nil && !netutil.IsClosing(err) {
log = log.WithError(err)
log = log.Err(err)
}
if req := ctx.Request(); req != nil {
log = log.WithFields(logrus.Fields{
log = log.Values(logger.Values{
"method": req.Method,
"request": req.URL.String(),
})
}
if res := ctx.Response(); res != nil {
//countStatus(res.StatusCode)
log.WithFields(logrus.Fields{
log.Values(logger.Values{
"response": res.StatusCode,
}).Info(res.Status)
} else {
@@ -196,38 +251,42 @@ func (p *Proxy) handle(nc net.Conn) {
}()
// Propagate timeouts
ctx.SetIdleTimeout(p.idleTimeout)
ctx.SetIdleTimeout(p.IdleTimeout)
for _, f := range p.connFilter {
fc, err := f.FilterConn(ctx)
for _, f := range p.OnConnect {
fc, err := f.HandleConn(ctx)
if err != nil {
ServerLog.WithField("filter", fmt.Sprintf("%T", f)).WithError(err).Warn("error in conn filter")
ServerLog.Value("filter", fmt.Sprintf("%T", f)).Err(err).Warn("Error in conn filter")
p.handleError(ctx, err, true)
_ = nc.Close()
return
} else if fc != nil {
ServerLog.WithField("filter", fmt.Sprintf("%T", f)).Debug("replacing connection from filter")
ServerLog.Value("filter", fmt.Sprintf("%T", f)).Debug("Replacing connection from filter")
ctx.Conn = fc
ctx.br = bufio.NewReader(fc)
}
}
for {
if ctx.isTransparentTLS {
if ctx.transparentTLS {
ctx.req = &http.Request{
Method: http.MethodConnect,
URL: &url.URL{
Scheme: "tcp",
Host: net.JoinHostPort(ctx.serverName, "443"),
},
Method: http.MethodConnect,
URL: &url.URL{Scheme: "tcp", Host: net.JoinHostPort(ctx.serverName, strconv.Itoa(ctx.transparent))},
Host: net.JoinHostPort(ctx.serverName, strconv.Itoa(ctx.transparent)),
Proto: "HTTP/1.1",
ProtoMajor: 1,
ProtoMinor: 1,
Close: true,
}
} else if ctx.req, err = http.ReadRequest(ctx.Reader()); err != nil {
if !(errors.Is(err, io.EOF) || errors.Is(err, syscall.ECONNRESET)) {
ServerLog.WithError(err).Debug("error reading request")
ServerLog.Err(err).Debug("Error reading request")
}
p.handleError(ctx, err, true)
return
}
if ctx.isTransparent {
if ctx.transparent > 0 {
// Canonicallize to absolute URL
if ctx.req.URL.Host == "" {
ctx.req.URL.Host = ctx.req.Host
@@ -235,47 +294,68 @@ func (p *Proxy) handle(nc net.Conn) {
if ctx.req.URL.Scheme == "" {
ctx.req.URL.Scheme = "http"
}
ctx.isTransparent = false
ctx.transparent = 0
}
for _, f := range p.requestFilter {
newReq, newRes := f.FilterRequest(ctx)
for _, f := range p.OnRequest {
newReq, newRes := f.HandleRequest(ctx)
if newReq != nil {
ServerLog.WithFields(logrus.Fields{
ServerLog.Values(logger.Values{
"filter": fmt.Sprintf("%T", f),
"old_method": ctx.req.Method,
"old_url": ctx.req.URL,
"new_method": newReq.Method,
"new_url": newReq.URL,
}).Debug("replacing request from filter")
}).Debug("Replacing request from filter")
ctx.req = newReq
}
if newRes != nil {
log := ServerLog.WithFields(logrus.Fields{
log := ServerLog.Values(logger.Values{
"filter": fmt.Sprintf("%T", f),
"response": newRes.StatusCode,
"status": newRes.Status,
})
log.Debug("replacing response from filter")
log.Debug("Replacing response from filter")
ctx.res = newRes
if err = p.writeResponse(ctx); err != nil {
log.WithError(err).Warn("error overriding repsonse")
if netutil.IsClosing(err) {
return
}
log.Err(err).Warn("Error overriding repsonse")
}
continue
}
}
if err = p.handleRequest(ctx); err != nil {
p.handleError(ctx, err, true)
return
}
// Only once
if ctx.isTransparent || ctx.isTransparentTLS {
if ctx.transparent > 0 || ctx.transparentTLS || ctx.req.Method == http.MethodConnect {
return
}
}
}
func (p *Proxy) handleError(ctx *proxyContext, err error, sendResponse bool) {
res := ctx.Response()
if res == nil && sendResponse {
res = NewErrorResponse(err, ctx.Request())
}
for _, f := range p.OnError {
if newRes := f.HandleError(ctx, err); newRes != nil {
res = newRes
}
}
if sendResponse && res != nil {
if werr := p.writeResponse(ctx); werr != nil && !netutil.IsClosing(err) {
ServerLog.Err(werr).Warn("Error writing error response")
}
}
}
func (p *Proxy) handleRequest(ctx *proxyContext) (err error) {
switch {
case ctx.req == nil:
@@ -300,9 +380,9 @@ func (p *Proxy) handleRequest(ctx *proxyContext) (err error) {
}
}
func (p *Proxy) applyResponseFilter(ctx *proxyContext) {
for _, f := range p.responseFilter {
if newRes := f.FilterResponse(ctx); newRes != nil {
func (p *Proxy) applyResponseHandler(ctx *proxyContext) {
for _, f := range p.OnResponse {
if newRes := f.HandleResponse(ctx); newRes != nil {
if ctx.res.Body != nil {
_ = ctx.res.Body.Close()
}
@@ -346,7 +426,7 @@ func (p *Proxy) serveConnect(ctx *proxyContext) (err error) {
// Most browsers expect to get a 200 OK after firing a HTTP CONNECT request; if the upstream
// encounters any errors, we'll inform the client after reading the HTTP request that follows.
if !(ctx.isTransparent || ctx.isTransparentTLS) {
if !(ctx.transparent > 0 || ctx.transparentTLS) {
if _, err = io.WriteString(ctx, "HTTP/1.1 200 Connection Established\r\n\r\n"); err != nil {
return
}
@@ -356,10 +436,10 @@ func (p *Proxy) serveConnect(ctx *proxyContext) (err error) {
case "":
ctx.req.URL.Scheme = "tcp"
}
log.WithField("target", ctx.req.URL.String()).Debug("http CONNECT request")
log.Value("target", ctx.req.URL.String()).Debugf("%s CONNECT request", ctx.req.Proto)
var (
timeout, cancel = context.WithTimeout(context.Background(), p.dialTimeout)
timeout, cancel = context.WithTimeout(context.Background(), p.DialTimeout)
c net.Conn
)
if c, err = p.dial(timeout, ctx.req); err != nil {
@@ -373,27 +453,27 @@ func (p *Proxy) serveConnect(ctx *proxyContext) (err error) {
ctx.res = NewResponse(http.StatusOK, nil, ctx.req)
srv := NewContext(c).(*proxyContext)
srv.SetIdleTimeout(p.idleTimeout)
srv.SetIdleTimeout(p.IdleTimeout)
return p.multiplex(ctx, srv)
}
func (p *Proxy) serveForward(ctx *proxyContext) (err error) {
log := ctx.LogEntry()
log.WithField("target", ctx.req.URL.String()).Debug("http forward request")
log.Value("target", ctx.req.URL.String()).Debugf("%s forward request", ctx.req.Proto)
if ctx.res, err = p.rt.RoundTrip(ctx.req); err != nil {
if ctx.res, err = p.RoundTripper.RoundTrip(ctx.req); err != nil {
// log.Printf("%s forward request error: %v", ctx, err)
ctx.res = NewErrorResponse(err, ctx.req)
_ = p.writeResponse(ctx)
_ = ctx.Close()
return fmt.Errorf("proxy: forward %s error: %w", ctx.req.URL, err)
}
p.applyResponseFilter(ctx)
p.applyResponseHandler(ctx)
return p.writeResponse(ctx)
}
func (p *Proxy) serveWebSocket(ctx *proxyContext) (err error) {
log := ctx.LogEntry().WithField("target", ctx.req.URL.String())
log := ctx.LogEntry().Value("target", ctx.req.URL.String())
switch ctx.req.URL.Scheme {
case "http":
@@ -402,9 +482,9 @@ func (p *Proxy) serveWebSocket(ctx *proxyContext) (err error) {
ctx.req.URL.Scheme = "wss"
}
log.Debug("http websocket request")
log.Debugf("%s websocket request", ctx.req.Proto)
var (
timeout, cancel = context.WithTimeout(context.Background(), p.dialTimeout)
timeout, cancel = context.WithTimeout(context.Background(), p.DialTimeout)
c net.Conn
)
if c, err = p.dial(timeout, ctx.req); err != nil {
@@ -417,7 +497,7 @@ func (p *Proxy) serveWebSocket(ctx *proxyContext) (err error) {
cancel()
srv := NewContext(c).(*proxyContext)
srv.SetIdleTimeout(p.idleTimeout)
srv.SetIdleTimeout(p.IdleTimeout)
if err = ctx.req.Write(srv); err != nil {
ctx.res = NewErrorResponse(err, ctx.req)
_ = p.writeResponse(ctx)
@@ -432,15 +512,15 @@ func (p *Proxy) serveWebSocket(ctx *proxyContext) (err error) {
return fmt.Errorf("proxy: failed to read response from upstream: %w", err)
}
log.WithFields(logrus.Fields{
log.Values(logger.Values{
"response": ctx.res.StatusCode,
"status": ctx.res.Status,
}).Debug("websocket response from upstream")
}).Debug("WebSocket response from upstream")
if err = p.writeResponse(ctx); err != nil {
_ = ctx.Close()
return
}
ctx.SetIdleTimeout(p.webSocketIdleTimeout)
ctx.SetIdleTimeout(p.WebSocketIdleTimeout)
return p.multiplex(ctx, srv)
}
@@ -471,19 +551,19 @@ func (p *Proxy) multiplex(ctx, srv Context) (err error) {
func (p *Proxy) writeResponse(ctx *proxyContext) (err error) {
res := ctx.Response()
for _, f := range p.responseFilter {
if newRes := f.FilterResponse(ctx); newRes != nil {
log.Printf("filter returned response HTTP %s", newRes.Status)
for _, f := range p.OnResponse {
if newRes := f.HandleResponse(ctx); newRes != nil {
log.Printf("Filter returned response HTTP %s", newRes.Status)
if res.Body != nil {
_ = res.Body.Close()
}
res = newRes
}
}
ServerLog.WithFields(logrus.Fields{
ServerLog.Values(logger.Values{
"close": res.Close,
"header": res.Header,
}).Debug("writing response")
}).Debug("Writing response")
if err = res.Write(ctx); err != nil {
return
}