Checkpoint
This commit is contained in:
@@ -99,6 +99,20 @@ type Proxy struct {
|
||||
// Request filters are executed sequentially in the order they are added.
|
||||
OnRequest []RequestHandler
|
||||
|
||||
// OnDial is a list of dial filters that are applied in order when a new outbound
|
||||
// connection is about to be made.
|
||||
OnDial []DialHandler
|
||||
|
||||
// OnForward is a list of request forward filters that are applied in order when
|
||||
// a new HTTP proxy forward request is about to be made.
|
||||
//
|
||||
// Forward filters can be used to return a response directly without forwarding
|
||||
// the request to the upstream server.
|
||||
//
|
||||
// Forward filters should return a non-nil error if they want to terminate the
|
||||
// connection.
|
||||
OnForward []ForwardHandler
|
||||
|
||||
// OnResponse is a list of response filters that are applied in order when a
|
||||
// response is received from the upstream server.
|
||||
//
|
||||
@@ -267,6 +281,7 @@ func (p *Proxy) handle(nc net.Conn) {
|
||||
}
|
||||
}
|
||||
|
||||
log := ctx.LogEntry()
|
||||
for {
|
||||
if ctx.transparentTLS {
|
||||
ctx.req = &http.Request{
|
||||
@@ -297,10 +312,11 @@ func (p *Proxy) handle(nc net.Conn) {
|
||||
ctx.transparent = 0
|
||||
}
|
||||
|
||||
log.Value("count", len(p.OnRequest)).Trace("Running request handlers")
|
||||
for _, f := range p.OnRequest {
|
||||
newReq, newRes := f.HandleRequest(ctx)
|
||||
if newReq != nil {
|
||||
ServerLog.Values(logger.Values{
|
||||
log.Values(logger.Values{
|
||||
"filter": fmt.Sprintf("%T", f),
|
||||
"old_method": ctx.req.Method,
|
||||
"old_url": ctx.req.URL,
|
||||
@@ -310,7 +326,7 @@ func (p *Proxy) handle(nc net.Conn) {
|
||||
ctx.req = newReq
|
||||
}
|
||||
if newRes != nil {
|
||||
log := ServerLog.Values(logger.Values{
|
||||
log := log.Values(logger.Values{
|
||||
"filter": fmt.Sprintf("%T", f),
|
||||
"response": newRes.StatusCode,
|
||||
"status": newRes.Status,
|
||||
@@ -344,6 +360,7 @@ func (p *Proxy) handleError(ctx *proxyContext, err error, sendResponse bool) {
|
||||
if res == nil && sendResponse {
|
||||
res = NewErrorResponse(err, ctx.Request())
|
||||
}
|
||||
ctx.LogEntry().Value("count", len(p.OnError)).Trace("Running error handlers")
|
||||
for _, f := range p.OnError {
|
||||
if newRes := f.HandleError(ctx, err); newRes != nil {
|
||||
res = newRes
|
||||
@@ -438,18 +455,32 @@ func (p *Proxy) serveConnect(ctx *proxyContext) (err error) {
|
||||
}
|
||||
log.Value("target", ctx.req.URL.String()).Debugf("%s CONNECT request", ctx.req.Proto)
|
||||
|
||||
var (
|
||||
timeout, cancel = context.WithTimeout(context.Background(), p.DialTimeout)
|
||||
c net.Conn
|
||||
)
|
||||
if c, err = p.dial(timeout, ctx.req); err != nil {
|
||||
cancel()
|
||||
ctx.res = NewErrorResponse(err, ctx.req)
|
||||
_ = p.writeResponse(ctx)
|
||||
_ = ctx.Close()
|
||||
return fmt.Errorf("proxy: dial %s error: %w", ctx.req.URL, err)
|
||||
var c net.Conn
|
||||
log.Value("count", len(p.OnDial)).Trace("Running dial handlers")
|
||||
for _, f := range p.OnDial {
|
||||
if c, err = f.HandleDial(ctx, ctx.req); err != nil {
|
||||
return
|
||||
} else if c != nil {
|
||||
ServerLog.Values(logger.Values{
|
||||
"filter": fmt.Sprintf("%T", f),
|
||||
"target": ctx.req.URL.String(),
|
||||
"remote": c.RemoteAddr().String(),
|
||||
}).Debug("Replacing connection from filter")
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if c == nil {
|
||||
timeout, cancel := context.WithTimeout(context.Background(), p.DialTimeout)
|
||||
if c, err = p.dial(timeout, ctx.req); err != nil {
|
||||
cancel()
|
||||
ctx.res = NewErrorResponse(err, ctx.req)
|
||||
_ = p.writeResponse(ctx)
|
||||
_ = ctx.Close()
|
||||
return fmt.Errorf("proxy: dial %s error: %w", ctx.req.URL, err)
|
||||
}
|
||||
cancel()
|
||||
}
|
||||
cancel()
|
||||
|
||||
ctx.res = NewResponse(http.StatusOK, nil, ctx.req)
|
||||
srv := NewContext(c).(*proxyContext)
|
||||
@@ -461,13 +492,29 @@ func (p *Proxy) serveForward(ctx *proxyContext) (err error) {
|
||||
log := ctx.LogEntry()
|
||||
log.Value("target", ctx.req.URL.String()).Debugf("%s forward request", ctx.req.Proto)
|
||||
|
||||
if ctx.res, err = p.RoundTripper.RoundTrip(ctx.req); err != nil {
|
||||
// log.Printf("%s forward request error: %v", ctx, err)
|
||||
ctx.res = NewErrorResponse(err, ctx.req)
|
||||
_ = p.writeResponse(ctx)
|
||||
_ = ctx.Close()
|
||||
return fmt.Errorf("proxy: forward %s error: %w", ctx.req.URL, err)
|
||||
var res *http.Response
|
||||
log.Value("count", len(p.OnForward)).Trace("Running forward handlers")
|
||||
for _, f := range p.OnForward {
|
||||
if res, err = f.HandleForward(ctx, ctx.req); err != nil {
|
||||
return
|
||||
} else if res != nil {
|
||||
log.Debug("Replacing response from forward filter")
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if res == nil {
|
||||
if res, err = p.RoundTripper.RoundTrip(ctx.req); err != nil {
|
||||
// log.Printf("%s forward request error: %v", ctx, err)
|
||||
ctx.res = NewErrorResponse(err, ctx.req)
|
||||
_ = p.writeResponse(ctx)
|
||||
_ = ctx.Close()
|
||||
return fmt.Errorf("proxy: forward %s error: %w", ctx.req.URL, err)
|
||||
}
|
||||
} else {
|
||||
ctx.res = res
|
||||
}
|
||||
|
||||
p.applyResponseHandler(ctx)
|
||||
return p.writeResponse(ctx)
|
||||
}
|
||||
|
Reference in New Issue
Block a user