Checkpoint

This commit is contained in:
2025-10-01 21:10:48 +02:00
parent 03352e3312
commit a23259cfdc
52 changed files with 2214 additions and 39 deletions

View File

@@ -168,6 +168,34 @@ func Transparent(port int) ConnHandler {
})
}
// DialHandler can filter network dial requests coming from the proxy.
type DialHandler interface {
// HandleDial filters an outbound dial request made by the proxy.
//
// The handler may decide to intercept the dial request and return a new [net.Conn]
// that will be used instead of dialing the target. The handler can also return
// nil, in which case the normal dial will proceed.
HandleDial(Context, *http.Request) (net.Conn, error)
}
// DialHandlerFunc is a function that implements the [DialHandler] interface.
type DialHandlerFunc func(Context, *http.Request) (net.Conn, error)
func (f DialHandlerFunc) HandleDial(ctx Context, req *http.Request) (net.Conn, error) {
return f(ctx, req)
}
// ForwardHandler can filter forward HTTP proxy requests.
type ForwardHandler interface {
HandleForward(Context, *http.Request) (*http.Response, error)
}
type ForwardHandlerFunc func(Context, *http.Request) (*http.Response, error)
func (f ForwardHandlerFunc) HandleForward(ctx Context, req *http.Request) (*http.Response, error) {
return f(ctx, req)
}
// RequestHandler can filter HTTP requests coming to the proxy.
type RequestHandler interface {
// HandlerRequest filters a HTTP request made to the proxy. The current request may be obtained

View File

@@ -99,6 +99,20 @@ type Proxy struct {
// Request filters are executed sequentially in the order they are added.
OnRequest []RequestHandler
// OnDial is a list of dial filters that are applied in order when a new outbound
// connection is about to be made.
OnDial []DialHandler
// OnForward is a list of request forward filters that are applied in order when
// a new HTTP proxy forward request is about to be made.
//
// Forward filters can be used to return a response directly without forwarding
// the request to the upstream server.
//
// Forward filters should return a non-nil error if they want to terminate the
// connection.
OnForward []ForwardHandler
// OnResponse is a list of response filters that are applied in order when a
// response is received from the upstream server.
//
@@ -267,6 +281,7 @@ func (p *Proxy) handle(nc net.Conn) {
}
}
log := ctx.LogEntry()
for {
if ctx.transparentTLS {
ctx.req = &http.Request{
@@ -297,10 +312,11 @@ func (p *Proxy) handle(nc net.Conn) {
ctx.transparent = 0
}
log.Value("count", len(p.OnRequest)).Trace("Running request handlers")
for _, f := range p.OnRequest {
newReq, newRes := f.HandleRequest(ctx)
if newReq != nil {
ServerLog.Values(logger.Values{
log.Values(logger.Values{
"filter": fmt.Sprintf("%T", f),
"old_method": ctx.req.Method,
"old_url": ctx.req.URL,
@@ -310,7 +326,7 @@ func (p *Proxy) handle(nc net.Conn) {
ctx.req = newReq
}
if newRes != nil {
log := ServerLog.Values(logger.Values{
log := log.Values(logger.Values{
"filter": fmt.Sprintf("%T", f),
"response": newRes.StatusCode,
"status": newRes.Status,
@@ -344,6 +360,7 @@ func (p *Proxy) handleError(ctx *proxyContext, err error, sendResponse bool) {
if res == nil && sendResponse {
res = NewErrorResponse(err, ctx.Request())
}
ctx.LogEntry().Value("count", len(p.OnError)).Trace("Running error handlers")
for _, f := range p.OnError {
if newRes := f.HandleError(ctx, err); newRes != nil {
res = newRes
@@ -438,18 +455,32 @@ func (p *Proxy) serveConnect(ctx *proxyContext) (err error) {
}
log.Value("target", ctx.req.URL.String()).Debugf("%s CONNECT request", ctx.req.Proto)
var (
timeout, cancel = context.WithTimeout(context.Background(), p.DialTimeout)
c net.Conn
)
if c, err = p.dial(timeout, ctx.req); err != nil {
cancel()
ctx.res = NewErrorResponse(err, ctx.req)
_ = p.writeResponse(ctx)
_ = ctx.Close()
return fmt.Errorf("proxy: dial %s error: %w", ctx.req.URL, err)
var c net.Conn
log.Value("count", len(p.OnDial)).Trace("Running dial handlers")
for _, f := range p.OnDial {
if c, err = f.HandleDial(ctx, ctx.req); err != nil {
return
} else if c != nil {
ServerLog.Values(logger.Values{
"filter": fmt.Sprintf("%T", f),
"target": ctx.req.URL.String(),
"remote": c.RemoteAddr().String(),
}).Debug("Replacing connection from filter")
break
}
}
if c == nil {
timeout, cancel := context.WithTimeout(context.Background(), p.DialTimeout)
if c, err = p.dial(timeout, ctx.req); err != nil {
cancel()
ctx.res = NewErrorResponse(err, ctx.req)
_ = p.writeResponse(ctx)
_ = ctx.Close()
return fmt.Errorf("proxy: dial %s error: %w", ctx.req.URL, err)
}
cancel()
}
cancel()
ctx.res = NewResponse(http.StatusOK, nil, ctx.req)
srv := NewContext(c).(*proxyContext)
@@ -461,13 +492,29 @@ func (p *Proxy) serveForward(ctx *proxyContext) (err error) {
log := ctx.LogEntry()
log.Value("target", ctx.req.URL.String()).Debugf("%s forward request", ctx.req.Proto)
if ctx.res, err = p.RoundTripper.RoundTrip(ctx.req); err != nil {
// log.Printf("%s forward request error: %v", ctx, err)
ctx.res = NewErrorResponse(err, ctx.req)
_ = p.writeResponse(ctx)
_ = ctx.Close()
return fmt.Errorf("proxy: forward %s error: %w", ctx.req.URL, err)
var res *http.Response
log.Value("count", len(p.OnForward)).Trace("Running forward handlers")
for _, f := range p.OnForward {
if res, err = f.HandleForward(ctx, ctx.req); err != nil {
return
} else if res != nil {
log.Debug("Replacing response from forward filter")
break
}
}
if res == nil {
if res, err = p.RoundTripper.RoundTrip(ctx.req); err != nil {
// log.Printf("%s forward request error: %v", ctx, err)
ctx.res = NewErrorResponse(err, ctx.req)
_ = p.writeResponse(ctx)
_ = ctx.Close()
return fmt.Errorf("proxy: forward %s error: %w", ctx.req.URL, err)
}
} else {
ctx.res = res
}
p.applyResponseHandler(ctx)
return p.writeResponse(ctx)
}