Checkpoint
This commit is contained in:
@@ -14,8 +14,7 @@ import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"git.maze.io/maze/styx/internal/netutil/arp"
|
||||
"github.com/sirupsen/logrus"
|
||||
"git.maze.io/maze/styx/logger"
|
||||
)
|
||||
|
||||
// Context provides convenience functions for the current ongoing HTTP proxy transaction (request).
|
||||
@@ -71,17 +70,16 @@ func (w *countingWriter) Write(p []byte) (n int, err error) {
|
||||
|
||||
type proxyContext struct {
|
||||
net.Conn
|
||||
id uint64
|
||||
mac net.HardwareAddr
|
||||
cr *countingReader
|
||||
br *bufio.Reader
|
||||
cw *countingWriter
|
||||
isTransparent bool
|
||||
isTransparentTLS bool
|
||||
serverName string
|
||||
req *http.Request
|
||||
res *http.Response
|
||||
idleTimeout time.Duration
|
||||
id uint64
|
||||
cr *countingReader
|
||||
br *bufio.Reader
|
||||
cw *countingWriter
|
||||
transparent int
|
||||
transparentTLS bool
|
||||
serverName string
|
||||
req *http.Request
|
||||
res *http.Response
|
||||
idleTimeout time.Duration
|
||||
}
|
||||
|
||||
// NewContext returns an initialized context for the provided [net.Conn].
|
||||
@@ -98,7 +96,6 @@ func NewContext(c net.Conn) Context {
|
||||
return &proxyContext{
|
||||
Conn: c,
|
||||
id: binary.BigEndian.Uint64(b),
|
||||
mac: arp.Get(c.RemoteAddr()),
|
||||
cr: cr,
|
||||
br: bufio.NewReader(cr),
|
||||
cw: cw,
|
||||
@@ -106,26 +103,23 @@ func NewContext(c net.Conn) Context {
|
||||
}
|
||||
}
|
||||
|
||||
func (c *proxyContext) AccessLogEntry() *logrus.Entry {
|
||||
func (c *proxyContext) AccessLogEntry() logger.Structured {
|
||||
var id [8]byte
|
||||
binary.BigEndian.PutUint64(id[:], c.id)
|
||||
entry := AccessLog.WithFields(logrus.Fields{
|
||||
entry := AccessLog.Values(logger.Values{
|
||||
"client": c.RemoteAddr().String(),
|
||||
"server": c.LocalAddr().String(),
|
||||
"id": hex.EncodeToString(id[:]),
|
||||
"bytes_rx": c.BytesRead(),
|
||||
"bytes_tx": c.BytesSent(),
|
||||
})
|
||||
if c.mac != nil {
|
||||
return entry.WithField("client_mac", c.mac.String())
|
||||
}
|
||||
return entry
|
||||
}
|
||||
|
||||
func (c *proxyContext) LogEntry() *logrus.Entry {
|
||||
func (c *proxyContext) LogEntry() logger.Structured {
|
||||
var id [8]byte
|
||||
binary.BigEndian.PutUint64(id[:], c.id)
|
||||
return ServerLog.WithFields(logrus.Fields{
|
||||
return ServerLog.Values(logger.Values{
|
||||
"client": c.RemoteAddr().String(),
|
||||
"server": c.LocalAddr().String(),
|
||||
"id": hex.EncodeToString(id[:]),
|
||||
|
115
proxy/handler.go
115
proxy/handler.go
@@ -15,7 +15,12 @@ import (
|
||||
"git.maze.io/maze/styx/internal/netutil"
|
||||
)
|
||||
|
||||
// Dialer can make outbound connections to upstream servers.
|
||||
type Dialer interface {
|
||||
// DialContext makes a new connection to the address specified in the [http.Request].
|
||||
//
|
||||
// The [http.Request] contains the URL scheme (http, https, ws, wss) and host (with optional port)
|
||||
// to connect to. The [context.Context] may be used for cancellation and timeouts.
|
||||
DialContext(context.Context, *http.Request) (net.Conn, error)
|
||||
}
|
||||
|
||||
@@ -71,25 +76,38 @@ func (defaultDialer) DialContext(ctx context.Context, req *http.Request) (net.Co
|
||||
}
|
||||
}
|
||||
|
||||
// ConnFilter is called when a new connection has been accepted by the proxy.
|
||||
type ConnFilter interface {
|
||||
FilterConn(Context) (net.Conn, error)
|
||||
// ErrorHandler can handle errors that occur during proxying.
|
||||
type ErrorHandler interface {
|
||||
// HandleError handles an error that occurred during proxying. If the method returns a non-nil
|
||||
// [http.Response], it will be sent to the client as-is. If it returns nil, a generic HTTP 502
|
||||
// Bad Gateway response will be sent to the client.
|
||||
//
|
||||
// The [Context] may be inspected to obtain information about the request that caused the error.
|
||||
// However, the [Context.Request] and [Context.Response] may be nil depending on when the error
|
||||
// occurred.
|
||||
HandleError(Context, error) *http.Response
|
||||
}
|
||||
|
||||
// ConnFilterFunc is a function that implements the [ConnFilter] interface.
|
||||
type ConnFilterFunc func(Context) (net.Conn, error)
|
||||
// ConnHandler is called when a new connection has been accepted by the proxy.
|
||||
type ConnHandler interface {
|
||||
HandleConn(Context) (net.Conn, error)
|
||||
}
|
||||
|
||||
func (f ConnFilterFunc) FilterConn(ctx Context) (net.Conn, error) {
|
||||
// ConnHandlerFunc is a function that implements the [ConnHandler] interface.
|
||||
type ConnHandlerFunc func(Context) (net.Conn, error)
|
||||
|
||||
func (f ConnHandlerFunc) HandleConn(ctx Context) (net.Conn, error) {
|
||||
return f(ctx)
|
||||
}
|
||||
|
||||
// TLS starts a TLS handshake on the accepted connection.
|
||||
func TLS(certs []tls.Certificate) ConnFilter {
|
||||
return ConnFilterFunc(func(ctx Context) (net.Conn, error) {
|
||||
s := tls.Server(ctx, &tls.Config{
|
||||
Certificates: certs,
|
||||
NextProtos: []string{"http/1.1"},
|
||||
})
|
||||
func TLS(config *tls.Config) ConnHandler {
|
||||
if config == nil {
|
||||
config = new(tls.Config)
|
||||
}
|
||||
config.NextProtos = []string{"http/1.1"}
|
||||
return ConnHandlerFunc(func(ctx Context) (net.Conn, error) {
|
||||
s := tls.Server(ctx, config)
|
||||
if err := s.Handshake(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -98,8 +116,8 @@ func TLS(certs []tls.Certificate) ConnFilter {
|
||||
}
|
||||
|
||||
// TLSInterceptor can generate certificates on-the-fly for clients that use a compatible TLS version.
|
||||
func TLSInterceptor(ca ca.CertificateAuthority) ConnFilter {
|
||||
return ConnFilterFunc(func(ctx Context) (net.Conn, error) {
|
||||
func TLSInterceptor(ca ca.CertificateAuthority) ConnHandler {
|
||||
return ConnHandlerFunc(func(ctx Context) (net.Conn, error) {
|
||||
s := tls.Server(ctx, &tls.Config{
|
||||
GetCertificate: func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
ips := []net.IP{net.ParseIP(netutil.Host(ctx.RemoteAddr().String()))}
|
||||
@@ -119,26 +137,27 @@ func TLSInterceptor(ca ca.CertificateAuthority) ConnFilter {
|
||||
// When a new [net.Conn] is made, this function will inspect the initial request packet for a
|
||||
// TLS handshake. If a TLS handshake is detected, the connection will make a feaux HTTP CONNECT
|
||||
// request using TLS, if no handshake is detected, it will make a feaux plain HTTP CONNECT request.
|
||||
func Transparent() ConnFilter {
|
||||
return ConnFilterFunc(func(nctx Context) (net.Conn, error) {
|
||||
func Transparent(port int) ConnHandler {
|
||||
return ConnHandlerFunc(func(nctx Context) (net.Conn, error) {
|
||||
ctx, ok := nctx.(*proxyContext)
|
||||
if !ok {
|
||||
return nctx, nil
|
||||
}
|
||||
|
||||
b := new(bytes.Buffer)
|
||||
hello, err := cryptutil.ReadClientHello(io.TeeReader(ctx, b))
|
||||
hello, err := cryptutil.ReadClientHello(io.TeeReader(netutil.ReadOnlyConn{Reader: ctx.br}, b))
|
||||
if err != nil {
|
||||
if _, ok := err.(tls.RecordHeaderError); !ok {
|
||||
ctx.LogEntry().WithError(err).WithField("error_type", fmt.Sprintf("%T", err)).Warn("TLS sniff error")
|
||||
ctx.LogEntry().Err(err).Value("error_type", fmt.Sprintf("%T", err)).Warn("TLS sniff error")
|
||||
return nil, err
|
||||
}
|
||||
// Not a TLS connection, moving on to regular HTTP request handling...
|
||||
ctx.LogEntry().Debug("HTTP connection on transparent port")
|
||||
ctx.isTransparent = true
|
||||
ctx.transparent = port
|
||||
} else {
|
||||
ctx.LogEntry().WithField("target", hello.ServerName).Debug("TLS connection on transparent port")
|
||||
ctx.isTransparentTLS = true
|
||||
ctx.LogEntry().Value("target", hello.ServerName).Debug("TLS connection on transparent port")
|
||||
ctx.transparent = port
|
||||
ctx.transparentTLS = true
|
||||
ctx.serverName = hello.ServerName
|
||||
}
|
||||
|
||||
@@ -149,10 +168,10 @@ func Transparent() ConnFilter {
|
||||
})
|
||||
}
|
||||
|
||||
// RequestFilter can filter HTTP requests coming to the proxy.
|
||||
type RequestFilter interface {
|
||||
// FilterRequest filters a HTTP request made to the proxy. The current request may be obtained
|
||||
// from [Context.Request]. If a previous RequestFilter provided a HTTP response, it is available
|
||||
// RequestHandler can filter HTTP requests coming to the proxy.
|
||||
type RequestHandler interface {
|
||||
// HandlerRequest filters a HTTP request made to the proxy. The current request may be obtained
|
||||
// from [Context.Request]. If a previous RequestHandler provided a HTTP response, it is available
|
||||
// from [Context.Response].
|
||||
//
|
||||
// Modifications to the current request can be made to the Request returned by [Context.Request]
|
||||
@@ -160,35 +179,35 @@ type RequestFilter interface {
|
||||
//
|
||||
// If the filter returns a non-nil [http.Response], then the [Request] will not be proxied to
|
||||
// any upstream server(s).
|
||||
FilterRequest(Context) (*http.Request, *http.Response)
|
||||
HandleRequest(Context) (*http.Request, *http.Response)
|
||||
}
|
||||
|
||||
// RequestFilterFunc is a function that implements the [RequestFilter] interface.
|
||||
type RequestFilterFunc func(Context) (*http.Request, *http.Response)
|
||||
// RequestHandlerFunc is a function that implements the [RequestHandler] interface.
|
||||
type RequestHandlerFunc func(Context) (*http.Request, *http.Response)
|
||||
|
||||
func (f RequestFilterFunc) FilterRequest(ctx Context) (*http.Request, *http.Response) {
|
||||
func (f RequestHandlerFunc) HandleRequest(ctx Context) (*http.Request, *http.Response) {
|
||||
return f(ctx)
|
||||
}
|
||||
|
||||
// ResponseFilter can filter HTTP responses coming from the proxy.
|
||||
type ResponseFilter interface {
|
||||
// FilterResponse filters a HTTP response coming from the proxy. The current response may be
|
||||
// ResponseHandler can filter HTTP responses coming from the proxy.
|
||||
type ResponseHandler interface {
|
||||
// HandlerResponse filters a HTTP response coming from the proxy. The current response may be
|
||||
// obtained from [Context.Response].
|
||||
//
|
||||
// Modifications to the current response can be made to the [Response] returned by [Context.Response].
|
||||
FilterResponse(Context) *http.Response
|
||||
HandleResponse(Context) *http.Response
|
||||
}
|
||||
|
||||
// ResponseFilterFunc is a function that implements the [ResponseFilter] interface.
|
||||
type ResponseFilterFunc func(Context) *http.Response
|
||||
// ResponseHandlerFunc is a function that implements the [ResponseHandler] interface.
|
||||
type ResponseHandlerFunc func(Context) *http.Response
|
||||
|
||||
func (f ResponseFilterFunc) FilterResponse(ctx Context) *http.Response {
|
||||
func (f ResponseHandlerFunc) HandleResponse(ctx Context) *http.Response {
|
||||
return f(ctx)
|
||||
}
|
||||
|
||||
// CleanRequestProxyHeaders removes all headers added by downstream proxies from the [http.Request].
|
||||
func CleanRequestProxyHeaders() RequestFilter {
|
||||
return RequestFilterFunc(func(ctx Context) (*http.Request, *http.Response) {
|
||||
func CleanRequestProxyHeaders() RequestHandler {
|
||||
return RequestHandlerFunc(func(ctx Context) (*http.Request, *http.Response) {
|
||||
if req := ctx.Request(); req != nil {
|
||||
cleanProxyHeaders(req.Header)
|
||||
}
|
||||
@@ -197,8 +216,8 @@ func CleanRequestProxyHeaders() RequestFilter {
|
||||
}
|
||||
|
||||
// CleanRequestProxyHeaders removes all headers for upstream proxies from the [http.Response].
|
||||
func CleanResponseProxyHeaders() ResponseFilter {
|
||||
return ResponseFilterFunc(func(ctx Context) *http.Response {
|
||||
func CleanResponseProxyHeaders() ResponseHandler {
|
||||
return ResponseHandlerFunc(func(ctx Context) *http.Response {
|
||||
if res := ctx.Response(); res != nil {
|
||||
cleanProxyHeaders(res.Header)
|
||||
}
|
||||
@@ -208,8 +227,8 @@ func CleanResponseProxyHeaders() ResponseFilter {
|
||||
|
||||
// AddRequestHeaders adds headers to the [http.Request]. Any existing headers with the same
|
||||
// key will remain intact.
|
||||
func AddRequestHeaders(h http.Header) RequestFilter {
|
||||
return RequestFilterFunc(func(ctx Context) (*http.Request, *http.Response) {
|
||||
func AddRequestHeaders(h http.Header) RequestHandler {
|
||||
return RequestHandlerFunc(func(ctx Context) (*http.Request, *http.Response) {
|
||||
if req := ctx.Request(); req != nil {
|
||||
if req.Header == nil {
|
||||
req.Header = make(http.Header)
|
||||
@@ -222,8 +241,8 @@ func AddRequestHeaders(h http.Header) RequestFilter {
|
||||
|
||||
// SetRequestHeaders sets headers to the [http.Request]. Any existing headers with the same
|
||||
// key will be removed.
|
||||
func SetRequestHeaders(h http.Header) RequestFilter {
|
||||
return RequestFilterFunc(func(ctx Context) (*http.Request, *http.Response) {
|
||||
func SetRequestHeaders(h http.Header) RequestHandler {
|
||||
return RequestHandlerFunc(func(ctx Context) (*http.Request, *http.Response) {
|
||||
if req := ctx.Request(); req != nil {
|
||||
if req.Header == nil {
|
||||
req.Header = make(http.Header)
|
||||
@@ -236,8 +255,8 @@ func SetRequestHeaders(h http.Header) RequestFilter {
|
||||
|
||||
// AddResponseHeaders adds headers to the [http.Response]. Any existing headers with the same
|
||||
// key will remain intact.
|
||||
func AddResponseHeaders(h http.Header) ResponseFilter {
|
||||
return ResponseFilterFunc(func(ctx Context) *http.Response {
|
||||
func AddResponseHeaders(h http.Header) ResponseHandler {
|
||||
return ResponseHandlerFunc(func(ctx Context) *http.Response {
|
||||
if res := ctx.Response(); res != nil {
|
||||
if res.Header == nil {
|
||||
res.Header = make(http.Header)
|
||||
@@ -250,8 +269,8 @@ func AddResponseHeaders(h http.Header) ResponseFilter {
|
||||
|
||||
// SetResponseHeaders sets headers to the [http.Response]. Any existing headers with the same
|
||||
// key will be removed.
|
||||
func SetResponseHeaders(h http.Header) ResponseFilter {
|
||||
return ResponseFilterFunc(func(ctx Context) *http.Response {
|
||||
func SetResponseHeaders(h http.Header) ResponseHandler {
|
||||
return ResponseHandlerFunc(func(ctx Context) *http.Response {
|
||||
if res := ctx.Response(); res != nil {
|
||||
if res.Header == nil {
|
||||
res.Header = make(http.Header)
|
||||
|
270
proxy/proxy.go
270
proxy/proxy.go
@@ -13,13 +13,14 @@ import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"git.maze.io/maze/styx/internal/netutil"
|
||||
"git.maze.io/maze/styx/logger"
|
||||
"git.maze.io/maze/styx/stats"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// Common HTTP headers.
|
||||
@@ -32,6 +33,7 @@ const (
|
||||
HeaderForwardedHost = "X-Forwarded-Host"
|
||||
HeaderForwardedPort = "X-Forwarded-Port"
|
||||
HeaderForwardedProto = "X-Forwarded-Proto"
|
||||
HeaderLocation = "Location"
|
||||
HeaderRealIP = "X-Real-Ip"
|
||||
HeaderUpgrade = "Upgrade"
|
||||
HeaderVia = "Via"
|
||||
@@ -46,43 +48,111 @@ const (
|
||||
|
||||
var (
|
||||
// AccessLog is used for logging requests to the proxy.
|
||||
AccessLog = logrus.StandardLogger()
|
||||
AccessLog = logger.Get()
|
||||
|
||||
// ServerLog is used for logging server log messages.
|
||||
ServerLog = logrus.StandardLogger()
|
||||
ServerLog = logger.Get()
|
||||
)
|
||||
|
||||
// Proxy implements a HTTP(S) proxy.
|
||||
type Proxy struct {
|
||||
rt http.RoundTripper
|
||||
dialer map[string]Dialer
|
||||
connFilter []ConnFilter
|
||||
requestFilter []RequestFilter
|
||||
responseFilter []ResponseFilter
|
||||
dialTimeout time.Duration
|
||||
idleTimeout time.Duration
|
||||
webSocketIdleTimeout time.Duration
|
||||
mux *http.ServeMux
|
||||
// RoundTripper is used to make outbound HTTP requests. It defaults to a [http.Transport]
|
||||
// with a custom DialContext that uses the configured [Dialer]s.
|
||||
//
|
||||
// Only override this if you know what you are doing.
|
||||
RoundTripper http.RoundTripper
|
||||
|
||||
// Dialer is a map of protocol names to [Dialer] implementations. The default [Dialer]
|
||||
// corresponds to an empty string key.
|
||||
//
|
||||
// Only override the default [Dialer] if you know what you are doing.
|
||||
Dialer map[string]Dialer
|
||||
|
||||
// OnConnect is a list of connection filters that are applied in order when a new
|
||||
// connection is established.
|
||||
//
|
||||
// Connection filters can be used to implement custom authentication, logging,
|
||||
// rate limiting, etc.
|
||||
//
|
||||
// Connection filters are applied before any HTTP request is read from the connection.
|
||||
//
|
||||
// Connection filters should return a non-nil error if they want to terminate the
|
||||
// connection. Returning a non-nil [net.Conn] will replace the existing connection
|
||||
// with the returned one.
|
||||
//
|
||||
// Connection filters should not modify the connection in any way (e.g. wrapping it
|
||||
// in a TLS connection) as this will interfere with the proxy's ability to read
|
||||
// HTTP requests from the connection.
|
||||
//
|
||||
// Connection filters are executed sequentially in the order they are added.
|
||||
OnConnect []ConnHandler
|
||||
|
||||
// OnRequest is a list of request filters that are applied in order when a new
|
||||
// HTTP request is read from the connection.
|
||||
//
|
||||
// Request filters can be used to modify the request, or to return a response
|
||||
// directly without forwarding the request to the upstream server.
|
||||
//
|
||||
// Request filters should return a non-nil error if they want to terminate the
|
||||
// connection.
|
||||
//
|
||||
// Request filters are executed sequentially in the order they are added.
|
||||
OnRequest []RequestHandler
|
||||
|
||||
// OnResponse is a list of response filters that are applied in order when a
|
||||
// response is received from the upstream server.
|
||||
//
|
||||
// Response filters can be used to modify the response before it is sent to
|
||||
// the client.
|
||||
//
|
||||
// Response filters should return a non-nil error if they want to terminate the
|
||||
// connection.
|
||||
//
|
||||
// Response filters are executed sequentially in the order they are added.
|
||||
OnResponse []ResponseHandler
|
||||
|
||||
// OnError is a list of error handlers that are applied in order when an
|
||||
// error occurs during request processing.
|
||||
//
|
||||
// Error handlers can be used to log errors, or to return a custom response
|
||||
// to the client.
|
||||
//
|
||||
// Error handlers should return a non-nil error if they want to terminate the
|
||||
// connection.
|
||||
//
|
||||
// Error handlers are executed sequentially in the order they are added.
|
||||
OnError []ErrorHandler
|
||||
|
||||
// DialTimeout is the timeout for establishing new connections to upstream servers.
|
||||
DialTimeout time.Duration
|
||||
|
||||
// IdleTimeout is the timeout for idle connections.
|
||||
IdleTimeout time.Duration
|
||||
|
||||
// WebSocketIdleTimeout is the timeout for idle WebSocket connections.
|
||||
WebSocketIdleTimeout time.Duration
|
||||
|
||||
mux *http.ServeMux
|
||||
}
|
||||
|
||||
// New [Proxy] with somewhat sane defaults.
|
||||
func New() *Proxy {
|
||||
p := &Proxy{
|
||||
dialer: map[string]Dialer{"": defaultDialer{}},
|
||||
dialTimeout: DefaultDialTimeout,
|
||||
idleTimeout: DefaultIdleTimeout,
|
||||
webSocketIdleTimeout: DefaultWebSocketIdleTimeout,
|
||||
Dialer: map[string]Dialer{"": defaultDialer{}},
|
||||
DialTimeout: DefaultDialTimeout,
|
||||
IdleTimeout: DefaultIdleTimeout,
|
||||
WebSocketIdleTimeout: DefaultWebSocketIdleTimeout,
|
||||
mux: http.NewServeMux(),
|
||||
}
|
||||
|
||||
// Make sure the roundtripper uses our dialers.
|
||||
p.rt = &http.Transport{
|
||||
p.RoundTripper = &http.Transport{
|
||||
TLSNextProto: make(map[string]func(string, *tls.Conn) http.RoundTripper),
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
ExpectContinueTimeout: time.Second,
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return p.dialer[""].DialContext(ctx, &http.Request{
|
||||
return p.Dialer[""].DialContext(ctx, &http.Request{
|
||||
URL: &url.URL{
|
||||
Scheme: network,
|
||||
Host: addr,
|
||||
@@ -112,41 +182,17 @@ func (p *Proxy) HandleFunc(pattern string, handler func(http.ResponseWriter, *ht
|
||||
func (p *Proxy) SetDialer(proto string, dialer Dialer) {
|
||||
if dialer == nil {
|
||||
if proto != "" {
|
||||
delete(p.dialer, proto)
|
||||
delete(p.Dialer, proto)
|
||||
}
|
||||
} else {
|
||||
p.dialer[proto] = dialer
|
||||
p.Dialer[proto] = dialer
|
||||
}
|
||||
}
|
||||
|
||||
// AddConnFilter adds a connection filter to the stack.
|
||||
func (p *Proxy) AddConnFilter(f ConnFilter) {
|
||||
if f == nil {
|
||||
return
|
||||
}
|
||||
p.connFilter = append(p.connFilter, f)
|
||||
}
|
||||
|
||||
// AddRequestFilter adds a request filter to the stack.
|
||||
func (p *Proxy) AddRequestFilter(f RequestFilter) {
|
||||
if f == nil {
|
||||
return
|
||||
}
|
||||
p.requestFilter = append(p.requestFilter, f)
|
||||
}
|
||||
|
||||
// AddResponseFilter adds a response filter to the stack.
|
||||
func (p *Proxy) AddResponseFilter(f ResponseFilter) {
|
||||
if f == nil {
|
||||
return
|
||||
}
|
||||
p.responseFilter = append(p.responseFilter, f)
|
||||
}
|
||||
|
||||
func (p *Proxy) dial(ctx context.Context, req *http.Request) (net.Conn, error) {
|
||||
d, ok := p.dialer[req.URL.Scheme]
|
||||
d, ok := p.Dialer[req.URL.Scheme]
|
||||
if !ok {
|
||||
d = p.dialer[""]
|
||||
d = p.Dialer[""]
|
||||
}
|
||||
|
||||
return d.DialContext(ctx, req)
|
||||
@@ -170,23 +216,32 @@ func (p *Proxy) handle(nc net.Conn) {
|
||||
err error
|
||||
)
|
||||
defer func() {
|
||||
if cerr := ctx.Close(); cerr != nil && err == nil {
|
||||
if r := recover(); r != nil {
|
||||
if err, ok := r.(error); ok {
|
||||
ctx.LogEntry().Err(err).Warn("Bug in code, recovered from panic!")
|
||||
}
|
||||
_ = nc.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
defer func() {
|
||||
if cerr := ctx.Close(); cerr != nil && err == nil && !netutil.IsClosing(cerr) {
|
||||
err = cerr
|
||||
}
|
||||
|
||||
log := ctx.AccessLogEntry().WithField("duration", time.Since(start))
|
||||
log := ctx.AccessLogEntry().Value("duration", time.Since(start))
|
||||
if err != nil && !netutil.IsClosing(err) {
|
||||
log = log.WithError(err)
|
||||
log = log.Err(err)
|
||||
}
|
||||
if req := ctx.Request(); req != nil {
|
||||
log = log.WithFields(logrus.Fields{
|
||||
log = log.Values(logger.Values{
|
||||
"method": req.Method,
|
||||
"request": req.URL.String(),
|
||||
})
|
||||
}
|
||||
if res := ctx.Response(); res != nil {
|
||||
//countStatus(res.StatusCode)
|
||||
log.WithFields(logrus.Fields{
|
||||
log.Values(logger.Values{
|
||||
"response": res.StatusCode,
|
||||
}).Info(res.Status)
|
||||
} else {
|
||||
@@ -196,38 +251,42 @@ func (p *Proxy) handle(nc net.Conn) {
|
||||
}()
|
||||
|
||||
// Propagate timeouts
|
||||
ctx.SetIdleTimeout(p.idleTimeout)
|
||||
ctx.SetIdleTimeout(p.IdleTimeout)
|
||||
|
||||
for _, f := range p.connFilter {
|
||||
fc, err := f.FilterConn(ctx)
|
||||
for _, f := range p.OnConnect {
|
||||
fc, err := f.HandleConn(ctx)
|
||||
if err != nil {
|
||||
ServerLog.WithField("filter", fmt.Sprintf("%T", f)).WithError(err).Warn("error in conn filter")
|
||||
ServerLog.Value("filter", fmt.Sprintf("%T", f)).Err(err).Warn("Error in conn filter")
|
||||
p.handleError(ctx, err, true)
|
||||
_ = nc.Close()
|
||||
return
|
||||
} else if fc != nil {
|
||||
ServerLog.WithField("filter", fmt.Sprintf("%T", f)).Debug("replacing connection from filter")
|
||||
ServerLog.Value("filter", fmt.Sprintf("%T", f)).Debug("Replacing connection from filter")
|
||||
ctx.Conn = fc
|
||||
ctx.br = bufio.NewReader(fc)
|
||||
}
|
||||
}
|
||||
|
||||
for {
|
||||
if ctx.isTransparentTLS {
|
||||
if ctx.transparentTLS {
|
||||
ctx.req = &http.Request{
|
||||
Method: http.MethodConnect,
|
||||
URL: &url.URL{
|
||||
Scheme: "tcp",
|
||||
Host: net.JoinHostPort(ctx.serverName, "443"),
|
||||
},
|
||||
Method: http.MethodConnect,
|
||||
URL: &url.URL{Scheme: "tcp", Host: net.JoinHostPort(ctx.serverName, strconv.Itoa(ctx.transparent))},
|
||||
Host: net.JoinHostPort(ctx.serverName, strconv.Itoa(ctx.transparent)),
|
||||
Proto: "HTTP/1.1",
|
||||
ProtoMajor: 1,
|
||||
ProtoMinor: 1,
|
||||
Close: true,
|
||||
}
|
||||
} else if ctx.req, err = http.ReadRequest(ctx.Reader()); err != nil {
|
||||
if !(errors.Is(err, io.EOF) || errors.Is(err, syscall.ECONNRESET)) {
|
||||
ServerLog.WithError(err).Debug("error reading request")
|
||||
ServerLog.Err(err).Debug("Error reading request")
|
||||
}
|
||||
p.handleError(ctx, err, true)
|
||||
return
|
||||
}
|
||||
|
||||
if ctx.isTransparent {
|
||||
if ctx.transparent > 0 {
|
||||
// Canonicallize to absolute URL
|
||||
if ctx.req.URL.Host == "" {
|
||||
ctx.req.URL.Host = ctx.req.Host
|
||||
@@ -235,47 +294,68 @@ func (p *Proxy) handle(nc net.Conn) {
|
||||
if ctx.req.URL.Scheme == "" {
|
||||
ctx.req.URL.Scheme = "http"
|
||||
}
|
||||
ctx.isTransparent = false
|
||||
ctx.transparent = 0
|
||||
}
|
||||
|
||||
for _, f := range p.requestFilter {
|
||||
newReq, newRes := f.FilterRequest(ctx)
|
||||
for _, f := range p.OnRequest {
|
||||
newReq, newRes := f.HandleRequest(ctx)
|
||||
if newReq != nil {
|
||||
ServerLog.WithFields(logrus.Fields{
|
||||
ServerLog.Values(logger.Values{
|
||||
"filter": fmt.Sprintf("%T", f),
|
||||
"old_method": ctx.req.Method,
|
||||
"old_url": ctx.req.URL,
|
||||
"new_method": newReq.Method,
|
||||
"new_url": newReq.URL,
|
||||
}).Debug("replacing request from filter")
|
||||
}).Debug("Replacing request from filter")
|
||||
ctx.req = newReq
|
||||
}
|
||||
if newRes != nil {
|
||||
log := ServerLog.WithFields(logrus.Fields{
|
||||
log := ServerLog.Values(logger.Values{
|
||||
"filter": fmt.Sprintf("%T", f),
|
||||
"response": newRes.StatusCode,
|
||||
"status": newRes.Status,
|
||||
})
|
||||
log.Debug("replacing response from filter")
|
||||
log.Debug("Replacing response from filter")
|
||||
ctx.res = newRes
|
||||
if err = p.writeResponse(ctx); err != nil {
|
||||
log.WithError(err).Warn("error overriding repsonse")
|
||||
if netutil.IsClosing(err) {
|
||||
return
|
||||
}
|
||||
log.Err(err).Warn("Error overriding repsonse")
|
||||
}
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
if err = p.handleRequest(ctx); err != nil {
|
||||
p.handleError(ctx, err, true)
|
||||
return
|
||||
}
|
||||
|
||||
// Only once
|
||||
if ctx.isTransparent || ctx.isTransparentTLS {
|
||||
if ctx.transparent > 0 || ctx.transparentTLS || ctx.req.Method == http.MethodConnect {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Proxy) handleError(ctx *proxyContext, err error, sendResponse bool) {
|
||||
res := ctx.Response()
|
||||
if res == nil && sendResponse {
|
||||
res = NewErrorResponse(err, ctx.Request())
|
||||
}
|
||||
for _, f := range p.OnError {
|
||||
if newRes := f.HandleError(ctx, err); newRes != nil {
|
||||
res = newRes
|
||||
}
|
||||
}
|
||||
if sendResponse && res != nil {
|
||||
if werr := p.writeResponse(ctx); werr != nil && !netutil.IsClosing(err) {
|
||||
ServerLog.Err(werr).Warn("Error writing error response")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Proxy) handleRequest(ctx *proxyContext) (err error) {
|
||||
switch {
|
||||
case ctx.req == nil:
|
||||
@@ -300,9 +380,9 @@ func (p *Proxy) handleRequest(ctx *proxyContext) (err error) {
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Proxy) applyResponseFilter(ctx *proxyContext) {
|
||||
for _, f := range p.responseFilter {
|
||||
if newRes := f.FilterResponse(ctx); newRes != nil {
|
||||
func (p *Proxy) applyResponseHandler(ctx *proxyContext) {
|
||||
for _, f := range p.OnResponse {
|
||||
if newRes := f.HandleResponse(ctx); newRes != nil {
|
||||
if ctx.res.Body != nil {
|
||||
_ = ctx.res.Body.Close()
|
||||
}
|
||||
@@ -346,7 +426,7 @@ func (p *Proxy) serveConnect(ctx *proxyContext) (err error) {
|
||||
|
||||
// Most browsers expect to get a 200 OK after firing a HTTP CONNECT request; if the upstream
|
||||
// encounters any errors, we'll inform the client after reading the HTTP request that follows.
|
||||
if !(ctx.isTransparent || ctx.isTransparentTLS) {
|
||||
if !(ctx.transparent > 0 || ctx.transparentTLS) {
|
||||
if _, err = io.WriteString(ctx, "HTTP/1.1 200 Connection Established\r\n\r\n"); err != nil {
|
||||
return
|
||||
}
|
||||
@@ -356,10 +436,10 @@ func (p *Proxy) serveConnect(ctx *proxyContext) (err error) {
|
||||
case "":
|
||||
ctx.req.URL.Scheme = "tcp"
|
||||
}
|
||||
log.WithField("target", ctx.req.URL.String()).Debug("http CONNECT request")
|
||||
log.Value("target", ctx.req.URL.String()).Debugf("%s CONNECT request", ctx.req.Proto)
|
||||
|
||||
var (
|
||||
timeout, cancel = context.WithTimeout(context.Background(), p.dialTimeout)
|
||||
timeout, cancel = context.WithTimeout(context.Background(), p.DialTimeout)
|
||||
c net.Conn
|
||||
)
|
||||
if c, err = p.dial(timeout, ctx.req); err != nil {
|
||||
@@ -373,27 +453,27 @@ func (p *Proxy) serveConnect(ctx *proxyContext) (err error) {
|
||||
|
||||
ctx.res = NewResponse(http.StatusOK, nil, ctx.req)
|
||||
srv := NewContext(c).(*proxyContext)
|
||||
srv.SetIdleTimeout(p.idleTimeout)
|
||||
srv.SetIdleTimeout(p.IdleTimeout)
|
||||
return p.multiplex(ctx, srv)
|
||||
}
|
||||
|
||||
func (p *Proxy) serveForward(ctx *proxyContext) (err error) {
|
||||
log := ctx.LogEntry()
|
||||
log.WithField("target", ctx.req.URL.String()).Debug("http forward request")
|
||||
log.Value("target", ctx.req.URL.String()).Debugf("%s forward request", ctx.req.Proto)
|
||||
|
||||
if ctx.res, err = p.rt.RoundTrip(ctx.req); err != nil {
|
||||
if ctx.res, err = p.RoundTripper.RoundTrip(ctx.req); err != nil {
|
||||
// log.Printf("%s forward request error: %v", ctx, err)
|
||||
ctx.res = NewErrorResponse(err, ctx.req)
|
||||
_ = p.writeResponse(ctx)
|
||||
_ = ctx.Close()
|
||||
return fmt.Errorf("proxy: forward %s error: %w", ctx.req.URL, err)
|
||||
}
|
||||
p.applyResponseFilter(ctx)
|
||||
p.applyResponseHandler(ctx)
|
||||
return p.writeResponse(ctx)
|
||||
}
|
||||
|
||||
func (p *Proxy) serveWebSocket(ctx *proxyContext) (err error) {
|
||||
log := ctx.LogEntry().WithField("target", ctx.req.URL.String())
|
||||
log := ctx.LogEntry().Value("target", ctx.req.URL.String())
|
||||
|
||||
switch ctx.req.URL.Scheme {
|
||||
case "http":
|
||||
@@ -402,9 +482,9 @@ func (p *Proxy) serveWebSocket(ctx *proxyContext) (err error) {
|
||||
ctx.req.URL.Scheme = "wss"
|
||||
}
|
||||
|
||||
log.Debug("http websocket request")
|
||||
log.Debugf("%s websocket request", ctx.req.Proto)
|
||||
var (
|
||||
timeout, cancel = context.WithTimeout(context.Background(), p.dialTimeout)
|
||||
timeout, cancel = context.WithTimeout(context.Background(), p.DialTimeout)
|
||||
c net.Conn
|
||||
)
|
||||
if c, err = p.dial(timeout, ctx.req); err != nil {
|
||||
@@ -417,7 +497,7 @@ func (p *Proxy) serveWebSocket(ctx *proxyContext) (err error) {
|
||||
cancel()
|
||||
|
||||
srv := NewContext(c).(*proxyContext)
|
||||
srv.SetIdleTimeout(p.idleTimeout)
|
||||
srv.SetIdleTimeout(p.IdleTimeout)
|
||||
if err = ctx.req.Write(srv); err != nil {
|
||||
ctx.res = NewErrorResponse(err, ctx.req)
|
||||
_ = p.writeResponse(ctx)
|
||||
@@ -432,15 +512,15 @@ func (p *Proxy) serveWebSocket(ctx *proxyContext) (err error) {
|
||||
return fmt.Errorf("proxy: failed to read response from upstream: %w", err)
|
||||
}
|
||||
|
||||
log.WithFields(logrus.Fields{
|
||||
log.Values(logger.Values{
|
||||
"response": ctx.res.StatusCode,
|
||||
"status": ctx.res.Status,
|
||||
}).Debug("websocket response from upstream")
|
||||
}).Debug("WebSocket response from upstream")
|
||||
if err = p.writeResponse(ctx); err != nil {
|
||||
_ = ctx.Close()
|
||||
return
|
||||
}
|
||||
ctx.SetIdleTimeout(p.webSocketIdleTimeout)
|
||||
ctx.SetIdleTimeout(p.WebSocketIdleTimeout)
|
||||
return p.multiplex(ctx, srv)
|
||||
}
|
||||
|
||||
@@ -471,19 +551,19 @@ func (p *Proxy) multiplex(ctx, srv Context) (err error) {
|
||||
|
||||
func (p *Proxy) writeResponse(ctx *proxyContext) (err error) {
|
||||
res := ctx.Response()
|
||||
for _, f := range p.responseFilter {
|
||||
if newRes := f.FilterResponse(ctx); newRes != nil {
|
||||
log.Printf("filter returned response HTTP %s", newRes.Status)
|
||||
for _, f := range p.OnResponse {
|
||||
if newRes := f.HandleResponse(ctx); newRes != nil {
|
||||
log.Printf("Filter returned response HTTP %s", newRes.Status)
|
||||
if res.Body != nil {
|
||||
_ = res.Body.Close()
|
||||
}
|
||||
res = newRes
|
||||
}
|
||||
}
|
||||
ServerLog.WithFields(logrus.Fields{
|
||||
ServerLog.Values(logger.Values{
|
||||
"close": res.Close,
|
||||
"header": res.Header,
|
||||
}).Debug("writing response")
|
||||
}).Debug("Writing response")
|
||||
if err = res.Write(ctx); err != nil {
|
||||
return
|
||||
}
|
||||
|
Reference in New Issue
Block a user