506 lines
13 KiB
Go
506 lines
13 KiB
Go
package proxy
|
|
|
|
import (
|
|
"bufio"
|
|
"bytes"
|
|
"context"
|
|
"crypto/tls"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"log"
|
|
"net"
|
|
"net/http"
|
|
"net/url"
|
|
"slices"
|
|
"strings"
|
|
"syscall"
|
|
"time"
|
|
|
|
"git.maze.io/maze/styx/internal/netutil"
|
|
"git.maze.io/maze/styx/stats"
|
|
"github.com/sirupsen/logrus"
|
|
)
|
|
|
|
// Common HTTP headers.
const (
	// Connection management.
	HeaderConnection = "Connection"
	HeaderUpgrade    = "Upgrade"

	// Representation and general headers.
	HeaderContentType = "Content-Type"
	HeaderDate        = "Date"

	// Proxy / forwarding headers.
	HeaderForwarded      = "Forwarded"
	HeaderForwardedFor   = "X-Forwarded-For"
	HeaderForwardedHost  = "X-Forwarded-Host"
	HeaderForwardedPort  = "X-Forwarded-Port"
	HeaderForwardedProto = "X-Forwarded-Proto"
	HeaderRealIP         = "X-Real-Ip"
	HeaderVia            = "Via"
)
|
|
|
|
// Safe defaults, applied by New.
const (
	// DefaultDialTimeout bounds how long dialing an upstream connection may take.
	DefaultDialTimeout = time.Second * 15

	// DefaultIdleTimeout is the idle timeout set on client and upstream connections.
	DefaultIdleTimeout = time.Second * 10

	// DefaultWebSocketIdleTimeout is the idle timeout used once a websocket upgrade succeeded.
	DefaultWebSocketIdleTimeout = time.Second * 30
)
|
|
|
|
var (
|
|
// AccessLog is used for logging requests to the proxy.
|
|
AccessLog = logrus.StandardLogger()
|
|
|
|
// ServerLog is used for logging server log messages.
|
|
ServerLog = logrus.StandardLogger()
|
|
)
|
|
|
|
// Proxy implements a HTTP(S) proxy.
|
|
type Proxy struct {
|
|
rt http.RoundTripper
|
|
dialer map[string]Dialer
|
|
connFilter []ConnFilter
|
|
requestFilter []RequestFilter
|
|
responseFilter []ResponseFilter
|
|
dialTimeout time.Duration
|
|
idleTimeout time.Duration
|
|
webSocketIdleTimeout time.Duration
|
|
mux *http.ServeMux
|
|
}
|
|
|
|
// New [Proxy] with somewhat sane defaults.
|
|
func New() *Proxy {
|
|
p := &Proxy{
|
|
dialer: map[string]Dialer{"": defaultDialer{}},
|
|
dialTimeout: DefaultDialTimeout,
|
|
idleTimeout: DefaultIdleTimeout,
|
|
webSocketIdleTimeout: DefaultWebSocketIdleTimeout,
|
|
mux: http.NewServeMux(),
|
|
}
|
|
|
|
// Make sure the roundtripper uses our dialers.
|
|
p.rt = &http.Transport{
|
|
TLSNextProto: make(map[string]func(string, *tls.Conn) http.RoundTripper),
|
|
Proxy: http.ProxyFromEnvironment,
|
|
TLSHandshakeTimeout: 10 * time.Second,
|
|
ExpectContinueTimeout: time.Second,
|
|
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
|
return p.dialer[""].DialContext(ctx, &http.Request{
|
|
URL: &url.URL{
|
|
Scheme: network,
|
|
Host: addr,
|
|
},
|
|
})
|
|
},
|
|
}
|
|
|
|
p.Handle("/stats", stats.Handler(stats.Exposed))
|
|
p.Handle("/stats.json", stats.JSONHandler(stats.Exposed))
|
|
|
|
return p
|
|
}
|
|
|
|
// Handle installs a [http.Handler] into the internal mux.
|
|
func (p *Proxy) Handle(pattern string, handler http.Handler) {
|
|
p.mux.Handle(pattern, handler)
|
|
}
|
|
|
|
// HandleFunc installs a [http.HandlerFunc] into the internal mux.
|
|
func (p *Proxy) HandleFunc(pattern string, handler func(http.ResponseWriter, *http.Request)) {
|
|
p.mux.HandleFunc(pattern, handler)
|
|
}
|
|
|
|
// SetDialer specifies a [Dialer] for the specified protocol. The default [Dialer] corresponds
|
|
// to an empty string. Only override the default [Dialer] if you know what you are doing.
|
|
func (p *Proxy) SetDialer(proto string, dialer Dialer) {
|
|
if dialer == nil {
|
|
if proto != "" {
|
|
delete(p.dialer, proto)
|
|
}
|
|
} else {
|
|
p.dialer[proto] = dialer
|
|
}
|
|
}
|
|
|
|
// AddConnFilter adds a connection filter to the stack.
|
|
func (p *Proxy) AddConnFilter(f ConnFilter) {
|
|
if f == nil {
|
|
return
|
|
}
|
|
p.connFilter = append(p.connFilter, f)
|
|
}
|
|
|
|
// AddRequestFilter adds a request filter to the stack.
|
|
func (p *Proxy) AddRequestFilter(f RequestFilter) {
|
|
if f == nil {
|
|
return
|
|
}
|
|
p.requestFilter = append(p.requestFilter, f)
|
|
}
|
|
|
|
// AddResponseFilter adds a response filter to the stack.
|
|
func (p *Proxy) AddResponseFilter(f ResponseFilter) {
|
|
if f == nil {
|
|
return
|
|
}
|
|
p.responseFilter = append(p.responseFilter, f)
|
|
}
|
|
|
|
func (p *Proxy) dial(ctx context.Context, req *http.Request) (net.Conn, error) {
|
|
d, ok := p.dialer[req.URL.Scheme]
|
|
if !ok {
|
|
d = p.dialer[""]
|
|
}
|
|
|
|
return d.DialContext(ctx, req)
|
|
}
|
|
|
|
// Serve proxied connections on the specified listener.
|
|
func (p *Proxy) Serve(l net.Listener) error {
|
|
for {
|
|
c, err := l.Accept()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
go p.handle(c)
|
|
}
|
|
}
|
|
|
|
func (p *Proxy) handle(nc net.Conn) {
|
|
var (
|
|
start = time.Now()
|
|
ctx = NewContext(nc).(*proxyContext)
|
|
err error
|
|
)
|
|
defer func() {
|
|
if cerr := ctx.Close(); cerr != nil && err == nil {
|
|
err = cerr
|
|
}
|
|
|
|
log := ctx.AccessLogEntry().WithField("duration", time.Since(start))
|
|
if err != nil && !netutil.IsClosing(err) {
|
|
log = log.WithError(err)
|
|
}
|
|
if req := ctx.Request(); req != nil {
|
|
log = log.WithFields(logrus.Fields{
|
|
"method": req.Method,
|
|
"request": req.URL.String(),
|
|
})
|
|
}
|
|
if res := ctx.Response(); res != nil {
|
|
//countStatus(res.StatusCode)
|
|
log.WithFields(logrus.Fields{
|
|
"response": res.StatusCode,
|
|
}).Info(res.Status)
|
|
} else {
|
|
//countStatus(0)
|
|
log.Info("No response")
|
|
}
|
|
}()
|
|
|
|
// Propagate timeouts
|
|
ctx.SetIdleTimeout(p.idleTimeout)
|
|
|
|
for _, f := range p.connFilter {
|
|
fc, err := f.FilterConn(ctx)
|
|
if err != nil {
|
|
ServerLog.WithField("filter", fmt.Sprintf("%T", f)).WithError(err).Warn("error in conn filter")
|
|
_ = nc.Close()
|
|
return
|
|
} else if fc != nil {
|
|
ServerLog.WithField("filter", fmt.Sprintf("%T", f)).Debug("replacing connection from filter")
|
|
ctx.Conn = fc
|
|
ctx.br = bufio.NewReader(fc)
|
|
}
|
|
}
|
|
|
|
for {
|
|
if ctx.isTransparentTLS {
|
|
ctx.req = &http.Request{
|
|
Method: http.MethodConnect,
|
|
URL: &url.URL{
|
|
Scheme: "tcp",
|
|
Host: net.JoinHostPort(ctx.serverName, "443"),
|
|
},
|
|
}
|
|
} else if ctx.req, err = http.ReadRequest(ctx.Reader()); err != nil {
|
|
if !(errors.Is(err, io.EOF) || errors.Is(err, syscall.ECONNRESET)) {
|
|
ServerLog.WithError(err).Debug("error reading request")
|
|
}
|
|
return
|
|
}
|
|
|
|
if ctx.isTransparent {
|
|
// Canonicallize to absolute URL
|
|
if ctx.req.URL.Host == "" {
|
|
ctx.req.URL.Host = ctx.req.Host
|
|
}
|
|
if ctx.req.URL.Scheme == "" {
|
|
ctx.req.URL.Scheme = "http"
|
|
}
|
|
ctx.isTransparent = false
|
|
}
|
|
|
|
for _, f := range p.requestFilter {
|
|
newReq, newRes := f.FilterRequest(ctx)
|
|
if newReq != nil {
|
|
ServerLog.WithFields(logrus.Fields{
|
|
"filter": fmt.Sprintf("%T", f),
|
|
"old_method": ctx.req.Method,
|
|
"old_url": ctx.req.URL,
|
|
"new_method": newReq.Method,
|
|
"new_url": newReq.URL,
|
|
}).Debug("replacing request from filter")
|
|
ctx.req = newReq
|
|
}
|
|
if newRes != nil {
|
|
log := ServerLog.WithFields(logrus.Fields{
|
|
"filter": fmt.Sprintf("%T", f),
|
|
"response": newRes.StatusCode,
|
|
"status": newRes.Status,
|
|
})
|
|
log.Debug("replacing response from filter")
|
|
ctx.res = newRes
|
|
if err = p.writeResponse(ctx); err != nil {
|
|
log.WithError(err).Warn("error overriding repsonse")
|
|
}
|
|
continue
|
|
}
|
|
}
|
|
|
|
if err = p.handleRequest(ctx); err != nil {
|
|
return
|
|
}
|
|
|
|
// Only once
|
|
if ctx.isTransparent || ctx.isTransparentTLS {
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
func (p *Proxy) handleRequest(ctx *proxyContext) (err error) {
|
|
switch {
|
|
case ctx.req == nil:
|
|
ctx.LogEntry().Warn("Request is nil in handleRequest!?")
|
|
return errors.New("proxy: request is nil?")
|
|
|
|
case headerContains(ctx.req.Header, HeaderConnection, "upgrade"):
|
|
if headerContains(ctx.req.Header, HeaderUpgrade, "websocket") {
|
|
return p.serveWebSocket(ctx)
|
|
}
|
|
ctx.res = NewResponse(http.StatusBadRequest, nil, ctx.req)
|
|
return p.writeResponse(ctx)
|
|
|
|
case ctx.req.Method == http.MethodConnect:
|
|
return p.serveConnect(ctx)
|
|
|
|
case ctx.req.URL.IsAbs():
|
|
return p.serveForward(ctx)
|
|
|
|
default:
|
|
return p.serve(ctx)
|
|
}
|
|
}
|
|
|
|
func (p *Proxy) applyResponseFilter(ctx *proxyContext) {
|
|
for _, f := range p.responseFilter {
|
|
if newRes := f.FilterResponse(ctx); newRes != nil {
|
|
if ctx.res.Body != nil {
|
|
_ = ctx.res.Body.Close()
|
|
}
|
|
ctx.res = newRes
|
|
}
|
|
}
|
|
}
|
|
|
|
func (p *Proxy) serve(ctx *proxyContext) (err error) {
|
|
var (
|
|
b = new(bytes.Buffer)
|
|
cw = ctx.cw
|
|
)
|
|
// This is where our response headers etc. are captured
|
|
ctx.res = NewResponse(http.StatusOK, nil, ctx.req)
|
|
|
|
// This is where our response body is captured
|
|
ctx.cw = &countingWriter{writer: b, bytes: ctx.cw.bytes}
|
|
|
|
// Pass ServeHTTP call to mux handler(s)
|
|
p.mux.ServeHTTP(ctx, ctx.req)
|
|
|
|
// Expose body
|
|
ctx.res.Body = io.NopCloser(b)
|
|
|
|
// Correct headers
|
|
if ctx.res.Header.Get(HeaderDate) == "" {
|
|
ctx.res.Header.Set(HeaderDate, time.Now().UTC().Format("Mon, 2 Jan 2006 15:04:05")+" GMT")
|
|
}
|
|
if ctx.res.Header.Get(HeaderContentType) == "" && b.Len() > 0 {
|
|
ctx.res.Header.Set(HeaderContentType, "text/html; charset=utf-8")
|
|
}
|
|
|
|
// Restore writer for the call to writeResponse
|
|
ctx.cw = cw
|
|
return p.writeResponse(ctx)
|
|
}
|
|
|
|
func (p *Proxy) serveConnect(ctx *proxyContext) (err error) {
|
|
log := ctx.LogEntry()
|
|
|
|
// Most browsers expect to get a 200 OK after firing a HTTP CONNECT request; if the upstream
|
|
// encounters any errors, we'll inform the client after reading the HTTP request that follows.
|
|
if !(ctx.isTransparent || ctx.isTransparentTLS) {
|
|
if _, err = io.WriteString(ctx, "HTTP/1.1 200 Connection Established\r\n\r\n"); err != nil {
|
|
return
|
|
}
|
|
}
|
|
|
|
switch ctx.req.URL.Scheme {
|
|
case "":
|
|
ctx.req.URL.Scheme = "tcp"
|
|
}
|
|
log.WithField("target", ctx.req.URL.String()).Debug("http CONNECT request")
|
|
|
|
var (
|
|
timeout, cancel = context.WithTimeout(context.Background(), p.dialTimeout)
|
|
c net.Conn
|
|
)
|
|
if c, err = p.dial(timeout, ctx.req); err != nil {
|
|
cancel()
|
|
ctx.res = NewErrorResponse(err, ctx.req)
|
|
_ = p.writeResponse(ctx)
|
|
_ = ctx.Close()
|
|
return fmt.Errorf("proxy: dial %s error: %w", ctx.req.URL, err)
|
|
}
|
|
cancel()
|
|
|
|
ctx.res = NewResponse(http.StatusOK, nil, ctx.req)
|
|
srv := NewContext(c).(*proxyContext)
|
|
srv.SetIdleTimeout(p.idleTimeout)
|
|
return p.multiplex(ctx, srv)
|
|
}
|
|
|
|
func (p *Proxy) serveForward(ctx *proxyContext) (err error) {
|
|
log := ctx.LogEntry()
|
|
log.WithField("target", ctx.req.URL.String()).Debug("http forward request")
|
|
|
|
if ctx.res, err = p.rt.RoundTrip(ctx.req); err != nil {
|
|
// log.Printf("%s forward request error: %v", ctx, err)
|
|
ctx.res = NewErrorResponse(err, ctx.req)
|
|
_ = p.writeResponse(ctx)
|
|
_ = ctx.Close()
|
|
return fmt.Errorf("proxy: forward %s error: %w", ctx.req.URL, err)
|
|
}
|
|
p.applyResponseFilter(ctx)
|
|
return p.writeResponse(ctx)
|
|
}
|
|
|
|
func (p *Proxy) serveWebSocket(ctx *proxyContext) (err error) {
|
|
log := ctx.LogEntry().WithField("target", ctx.req.URL.String())
|
|
|
|
switch ctx.req.URL.Scheme {
|
|
case "http":
|
|
ctx.req.URL.Scheme = "ws"
|
|
case "https":
|
|
ctx.req.URL.Scheme = "wss"
|
|
}
|
|
|
|
log.Debug("http websocket request")
|
|
var (
|
|
timeout, cancel = context.WithTimeout(context.Background(), p.dialTimeout)
|
|
c net.Conn
|
|
)
|
|
if c, err = p.dial(timeout, ctx.req); err != nil {
|
|
cancel()
|
|
ctx.res = NewErrorResponse(err, ctx.req)
|
|
_ = p.writeResponse(ctx)
|
|
_ = ctx.Close()
|
|
return fmt.Errorf("proxy: dial %s error: %w", ctx.req.URL, err)
|
|
}
|
|
cancel()
|
|
|
|
srv := NewContext(c).(*proxyContext)
|
|
srv.SetIdleTimeout(p.idleTimeout)
|
|
if err = ctx.req.Write(srv); err != nil {
|
|
ctx.res = NewErrorResponse(err, ctx.req)
|
|
_ = p.writeResponse(ctx)
|
|
_ = ctx.Close()
|
|
return fmt.Errorf("proxy: failed to write request to upstream: %w", err)
|
|
}
|
|
|
|
if ctx.res, err = http.ReadResponse(srv.Reader(), ctx.req); err != nil {
|
|
ctx.res = NewErrorResponse(err, ctx.req)
|
|
_ = p.writeResponse(ctx)
|
|
_ = ctx.Close()
|
|
return fmt.Errorf("proxy: failed to read response from upstream: %w", err)
|
|
}
|
|
|
|
log.WithFields(logrus.Fields{
|
|
"response": ctx.res.StatusCode,
|
|
"status": ctx.res.Status,
|
|
}).Debug("websocket response from upstream")
|
|
if err = p.writeResponse(ctx); err != nil {
|
|
_ = ctx.Close()
|
|
return
|
|
}
|
|
ctx.SetIdleTimeout(p.webSocketIdleTimeout)
|
|
return p.multiplex(ctx, srv)
|
|
}
|
|
|
|
func (p *Proxy) multiplex(ctx, srv Context) (err error) {
|
|
var (
|
|
errs = make(chan error, 1)
|
|
done = make(chan struct{}, 1)
|
|
)
|
|
go func(errs chan<- error) {
|
|
defer close(done)
|
|
if _, err := io.Copy(srv, ctx); err != nil {
|
|
errs <- err
|
|
}
|
|
}(errs)
|
|
go func(errs chan<- error) {
|
|
if _, err := io.Copy(ctx, srv); err != nil {
|
|
errs <- err
|
|
}
|
|
}(errs)
|
|
|
|
select {
|
|
case err = <-errs:
|
|
return
|
|
case <-done:
|
|
return
|
|
}
|
|
}
|
|
|
|
func (p *Proxy) writeResponse(ctx *proxyContext) (err error) {
|
|
res := ctx.Response()
|
|
for _, f := range p.responseFilter {
|
|
if newRes := f.FilterResponse(ctx); newRes != nil {
|
|
log.Printf("filter returned response HTTP %s", newRes.Status)
|
|
if res.Body != nil {
|
|
_ = res.Body.Close()
|
|
}
|
|
res = newRes
|
|
}
|
|
}
|
|
ServerLog.WithFields(logrus.Fields{
|
|
"close": res.Close,
|
|
"header": res.Header,
|
|
}).Debug("writing response")
|
|
if err = res.Write(ctx); err != nil {
|
|
return
|
|
}
|
|
if res.Close || ctx.res.Close || strings.ToLower(ctx.res.Header.Get(HeaderConnection)) != "keep-alive" {
|
|
// Force closing of connection.
|
|
if err = ctx.Close(); err != nil {
|
|
return
|
|
}
|
|
return io.EOF
|
|
}
|
|
return
|
|
}
|
|
|
|
// headerContains reports whether header k of h contains the value v, compared
// case-insensitively.
//
// A value matches in either of two ways:
//   - as a whole header line (the original behaviour);
//   - as one element of a comma-separated list, as used by list-valued headers such
//     as Connection and Upgrade. For example, "keep-alive, Upgrade" contains
//     "upgrade".
func headerContains(h http.Header, k, v string) bool {
	return slices.ContainsFunc(h.Values(k), func(line string) bool {
		if strings.EqualFold(line, v) {
			return true
		}
		for _, elem := range strings.Split(line, ",") {
			if strings.EqualFold(strings.TrimSpace(elem), v) {
				return true
			}
		}
		return false
	})
}
|