Files
styx/proxy/handler.go
2025-10-01 21:10:48 +02:00

357 lines
11 KiB
Go

package proxy
import (
"bytes"
"context"
"crypto/tls"
"errors"
"fmt"
"io"
"net"
"net/http"
"git.maze.io/maze/styx/ca"
"git.maze.io/maze/styx/internal/cryptutil"
"git.maze.io/maze/styx/internal/netutil"
)
// Dialer can make outbound connections to upstream servers.
type Dialer interface {
	// DialContext makes a new connection to the address specified in the [http.Request].
	//
	// The [http.Request] contains the URL scheme (tcp, http, https, ws, wss) and host (with
	// optional port) to connect to. The [context.Context] may be used for cancellation and timeouts.
	DialContext(context.Context, *http.Request) (net.Conn, error)
}

// defaultDialer is the built-in [Dialer]. It resolves the target host, refuses hosts that
// resolve to an unspecified or loopback address, and then dials plain TCP (tcp, http, ws) or
// TCP+TLS (https, wss) depending on the request URL scheme.
type defaultDialer struct{}

// DialContext implements [Dialer].
//
// For http/ws and https/wss requests without an explicit port, req.URL.Host is rewritten in
// place to include the default port (80 or 443), so the caller observes the dialed address.
// On failure no connection is returned and any partially established connection is closed.
func (defaultDialer) DialContext(ctx context.Context, req *http.Request) (net.Conn, error) {
	if netutil.Host(req.URL.Host) == "" {
		return nil, errors.New("proxy: host missing in address")
	}

	d := net.Dialer{
		Resolver:      net.DefaultResolver,
		FallbackDelay: -1, // negative disables the dual-stack "fast fallback" race
	}

	// Ensure we have a port.
	switch req.URL.Scheme {
	case "http", "ws":
		req.URL.Host = netutil.EnsurePort(req.URL.Host, "80")
	case "https", "wss":
		req.URL.Host = netutil.EnsurePort(req.URL.Host, "443")
	}
	host := netutil.Host(req.URL.Host)

	// Refuse targets whose DNS answer points nowhere or back at this machine; per the error
	// text, such answers are treated as a sign the name was blocked by DNS.
	ips, err := d.Resolver.LookupIP(ctx, "ip", host)
	if err != nil {
		return nil, err
	}
	for _, ip := range ips {
		switch {
		case ip.IsUnspecified():
			return nil, fmt.Errorf("proxy: host %s resolves to unspecified address (blocked by DNS?)", host)
		case ip.IsLoopback():
			return nil, fmt.Errorf("proxy: host %s resolves to loopback address (blocked by DNS?)", host)
		}
	}

	// Make the connection.
	switch req.URL.Scheme {
	case "tcp", "http", "ws":
		// Plain TCP client connection.
		return d.DialContext(ctx, "tcp", req.URL.Host)
	case "https", "wss":
		// Secure TLS client connection.
		c, err := d.DialContext(ctx, "tcp", req.URL.Host)
		if err != nil {
			return nil, err
		}
		// tls.Client does not derive ServerName from the connection. Without it, the client
		// handshake always fails ("either ServerName or InsecureSkipVerify must be specified")
		// and no SNI would be sent.
		s := tls.Client(c, &tls.Config{ServerName: host})
		if err := s.HandshakeContext(ctx); err != nil {
			_ = c.Close() // best-effort cleanup; the handshake error is the one worth reporting
			return nil, err
		}
		return s, nil
	default:
		return nil, fmt.Errorf("proxy: can't dial %s protocol", req.URL.Scheme)
	}
}
// ErrorHandler can handle errors that occur during proxying.
type ErrorHandler interface {
	// HandleError is consulted when proxying fails with err. A non-nil [http.Response] it
	// returns is sent to the client as-is; returning nil makes the proxy fall back to a
	// generic HTTP 502 Bad Gateway response.
	//
	// ctx describes the request that failed. Depending on how far processing got before the
	// error, [Context.Request] and [Context.Response] may still be nil.
	HandleError(ctx Context, err error) *http.Response
}
// ConnHandler is called when a new connection has been accepted by the proxy. It returns
// the connection to continue with, or an error.
type ConnHandler interface {
	HandleConn(ctx Context) (net.Conn, error)
}

// ConnHandlerFunc adapts an ordinary function to the [ConnHandler] interface.
type ConnHandlerFunc func(ctx Context) (net.Conn, error)

// HandleConn calls fn(ctx).
func (fn ConnHandlerFunc) HandleConn(ctx Context) (net.Conn, error) {
	return fn(ctx)
}
// TLS returns a [ConnHandler] that starts a server-side TLS handshake on the accepted
// connection and, on success, returns the resulting TLS connection.
//
// A nil config is treated as an empty [tls.Config]. A non-nil config is cloned, so the
// caller's value is never modified. The handler always advertises only "http/1.1" via ALPN,
// overriding any NextProtos in config.
func TLS(config *tls.Config) ConnHandler {
	// Clone returns nil for a nil receiver, hence the fallback. Cloning matters because
	// crypto/tls forbids modifying a Config after it has been handed to a TLS function, and
	// the caller's config may already be in use elsewhere.
	cfg := config.Clone()
	if cfg == nil {
		cfg = new(tls.Config)
	}
	cfg.NextProtos = []string{"http/1.1"}
	return ConnHandlerFunc(func(ctx Context) (net.Conn, error) {
		s := tls.Server(ctx, cfg)
		if err := s.Handshake(); err != nil {
			return nil, err
		}
		return s, nil
	})
}
// TLSInterceptor returns a [ConnHandler] that terminates TLS on the accepted connection.
// Each handshake obtains its certificate from ca, keyed on the SNI server name the client
// sent. Only "http/1.1" is offered via ALPN.
func TLSInterceptor(ca ca.CertificateAuthority) ConnHandler {
	return ConnHandlerFunc(func(ctx Context) (net.Conn, error) {
		config := &tls.Config{
			NextProtos: []string{"http/1.1"},
			GetCertificate: func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
				// NOTE(review): RemoteAddr of the accepted connection is the client's address, so
				// the client's IP is what gets passed as the certificate IP list here. Confirm
				// this is intended rather than the local/target address.
				peer := net.ParseIP(netutil.Host(ctx.RemoteAddr().String()))
				names := []string{hello.ServerName}
				return ca.GetCertificate(hello.ServerName, names, []net.IP{peer})
			},
		}
		conn := tls.Server(ctx, config)
		if err := conn.Handshake(); err != nil {
			return nil, err
		}
		return conn, nil
	})
}
// Transparent can handle transparent HTTP(S) requests on the port.
//
// When a new [net.Conn] is made, this handler inspects the initial bytes for a TLS ClientHello.
// If one is found, the connection is marked as transparent TLS, with the ClientHello's server
// name recorded, so it can be treated as a faux HTTP CONNECT request over TLS. Otherwise it is
// marked as transparent plain HTTP. Either way, the sniffed bytes are replayed in front of the
// returned connection so the next reader sees the stream from its first byte.
//
// Connections whose [Context] is not the proxy's own context are passed through unchanged.
func Transparent(port int) ConnHandler {
	return ConnHandlerFunc(func(nctx Context) (net.Conn, error) {
		ctx, ok := nctx.(*proxyContext)
		if !ok {
			return nctx, nil
		}

		// Capture every byte consumed while sniffing so it can be replayed afterwards.
		b := new(bytes.Buffer)
		hello, err := cryptutil.ReadClientHello(io.TeeReader(netutil.ReadOnlyConn{Reader: ctx.br}, b))
		if err != nil {
			// A record-header error means the first bytes are not a TLS record at all, i.e. plain
			// HTTP. errors.As also recognises it when ReadClientHello wraps the error.
			var recordErr tls.RecordHeaderError
			if !errors.As(err, &recordErr) {
				ctx.LogEntry().Err(err).Value("error_type", fmt.Sprintf("%T", err)).Warn("TLS sniff error")
				return nil, err
			}
			// Not a TLS connection, moving on to regular HTTP request handling...
			ctx.LogEntry().Debug("HTTP connection on transparent port")
		} else {
			ctx.LogEntry().Value("target", hello.ServerName).Debug("TLS connection on transparent port")
			ctx.transparentTLS = true
			ctx.serverName = hello.ServerName
		}
		ctx.transparent = port

		// NOTE(review): sniffing reads from ctx.br, but the remainder below is read from ctx.Conn.
		// If ctx.br is a buffered reader over ctx.Conn, any bytes it buffered beyond what the
		// sniffer consumed would be skipped. Confirm ctx.br's relationship to ctx.Conn.
		return netutil.ReaderConn{
			Conn:   ctx.Conn,
			Reader: io.MultiReader(bytes.NewReader(b.Bytes()), ctx.Conn),
		}, nil
	})
}
// DialHandler can filter network dial requests coming from the proxy.
type DialHandler interface {
	// HandleDial gets a chance to intercept an outbound dial for req. Returning a non-nil
	// [net.Conn] makes the proxy use that connection instead of dialing the target itself.
	// Returning a nil connection lets the normal dial proceed.
	HandleDial(ctx Context, req *http.Request) (net.Conn, error)
}

// DialHandlerFunc adapts an ordinary function to the [DialHandler] interface.
type DialHandlerFunc func(ctx Context, req *http.Request) (net.Conn, error)

// HandleDial calls fn(ctx, req).
func (fn DialHandlerFunc) HandleDial(ctx Context, req *http.Request) (net.Conn, error) {
	return fn(ctx, req)
}
// ForwardHandler can filter forward HTTP proxy requests.
type ForwardHandler interface {
	// HandleForward filters a forward-proxy request.
	//
	// NOTE(review): how a nil [http.Response] or a non-nil error is acted upon is decided by
	// the caller, which is not in this file. Document those semantics here once confirmed.
	HandleForward(ctx Context, req *http.Request) (*http.Response, error)
}

// ForwardHandlerFunc adapts an ordinary function to the [ForwardHandler] interface.
type ForwardHandlerFunc func(ctx Context, req *http.Request) (*http.Response, error)

// HandleForward calls fn(ctx, req).
func (fn ForwardHandlerFunc) HandleForward(ctx Context, req *http.Request) (*http.Response, error) {
	return fn(ctx, req)
}
// RequestHandler can filter HTTP requests coming to the proxy.
type RequestHandler interface {
	// HandleRequest filters an HTTP request made to the proxy. The current request is
	// available from [Context.Request]. If an earlier RequestHandler already produced an HTTP
	// response, that response is available from [Context.Response].
	//
	// The request returned by [Context.Request] may be modified in place. Returning a new
	// [http.Request] is not required for such changes to take effect.
	//
	// Returning a non-nil [http.Response] means the request will not be proxied to any
	// upstream server(s).
	HandleRequest(ctx Context) (*http.Request, *http.Response)
}

// RequestHandlerFunc adapts an ordinary function to the [RequestHandler] interface.
type RequestHandlerFunc func(ctx Context) (*http.Request, *http.Response)

// HandleRequest calls fn(ctx).
func (fn RequestHandlerFunc) HandleRequest(ctx Context) (*http.Request, *http.Response) {
	return fn(ctx)
}
// ResponseHandler can filter HTTP responses coming from the proxy.
type ResponseHandler interface {
	// HandleResponse filters an HTTP response coming from the proxy. The current response is
	// available from [Context.Response] and may be modified in place there.
	HandleResponse(ctx Context) *http.Response
}

// ResponseHandlerFunc adapts an ordinary function to the [ResponseHandler] interface.
type ResponseHandlerFunc func(ctx Context) *http.Response

// HandleResponse calls fn(ctx).
func (fn ResponseHandlerFunc) HandleResponse(ctx Context) *http.Response {
	return fn(ctx)
}
// CleanRequestProxyHeaders returns a [RequestHandler] that strips headers added by downstream
// proxies from the current [http.Request]. These are the Forwarded and Via headers, the
// real-IP header, and the forwarded for/host/port/proto family. The handler never
// short-circuits the request; it always returns nil, nil.
func CleanRequestProxyHeaders() RequestHandler {
	return RequestHandlerFunc(func(ctx Context) (*http.Request, *http.Response) {
		req := ctx.Request()
		if req == nil {
			return nil, nil
		}
		cleanProxyHeaders(req.Header)
		return nil, nil
	})
}
// CleanResponseProxyHeaders returns a [ResponseHandler] that strips headers set by upstream
// proxies from the current [http.Response]. These are the Forwarded and Via headers, the
// real-IP header, and the forwarded for/host/port/proto family. The handler always returns
// nil, leaving the (modified) current response in place.
func CleanResponseProxyHeaders() ResponseHandler {
	return ResponseHandlerFunc(func(ctx Context) *http.Response {
		if res := ctx.Response(); res != nil {
			cleanProxyHeaders(res.Header)
		}
		return nil
	})
}
// AddRequestHeaders returns a [RequestHandler] that appends every value in h to the current
// [http.Request]. Values already present under the same key are kept. A nil request header
// map is allocated first. The handler always returns nil, nil.
func AddRequestHeaders(h http.Header) RequestHandler {
	return RequestHandlerFunc(func(ctx Context) (*http.Request, *http.Response) {
		req := ctx.Request()
		if req == nil {
			return nil, nil
		}
		if req.Header == nil {
			req.Header = http.Header{}
		}
		addHeaders(req.Header, h)
		return nil, nil
	})
}
// SetRequestHeaders returns a [RequestHandler] that writes every key in h onto the current
// [http.Request]. For each such key, existing values are dropped and replaced by h's values.
// A nil request header map is allocated first. The handler always returns nil, nil.
func SetRequestHeaders(h http.Header) RequestHandler {
	return RequestHandlerFunc(func(ctx Context) (*http.Request, *http.Response) {
		req := ctx.Request()
		if req == nil {
			return nil, nil
		}
		if req.Header == nil {
			req.Header = http.Header{}
		}
		setHeaders(req.Header, h)
		return nil, nil
	})
}
// AddResponseHeaders returns a [ResponseHandler] that appends every value in h to the current
// [http.Response]. Values already present under the same key are kept. A nil response header
// map is allocated first. The handler always returns nil.
func AddResponseHeaders(h http.Header) ResponseHandler {
	return ResponseHandlerFunc(func(ctx Context) *http.Response {
		res := ctx.Response()
		if res == nil {
			return nil
		}
		if res.Header == nil {
			res.Header = http.Header{}
		}
		addHeaders(res.Header, h)
		return nil
	})
}
// SetResponseHeaders returns a [ResponseHandler] that writes every key in h onto the current
// [http.Response]. For each such key, existing values are dropped and replaced by h's values.
// A nil response header map is allocated first. The handler always returns nil.
func SetResponseHeaders(h http.Header) ResponseHandler {
	return ResponseHandlerFunc(func(ctx Context) *http.Response {
		res := ctx.Response()
		if res == nil {
			return nil
		}
		if res.Header == nil {
			res.Header = http.Header{}
		}
		setHeaders(res.Header, h)
		return nil
	})
}
// cleanProxyHeaders deletes from h every header in the proxy-related set commonly used by
// (reverse) HTTP proxies. These are the Forwarded and Via headers, the real-IP header, and the
// forwarded for/host/port/proto family. A nil h is left alone.
func cleanProxyHeaders(h http.Header) {
	if h != nil {
		proxyHeaders := []string{
			HeaderForwarded,
			HeaderForwardedFor,
			HeaderForwardedHost,
			HeaderForwardedPort,
			HeaderForwardedProto,
			HeaderRealIP,
			HeaderVia,
		}
		for _, name := range proxyHeaders {
			h.Del(name)
		}
	}
}
// addHeaders appends every value in src to dst under the same (canonicalized) key, keeping
// whatever dst already holds for that key. A nil src is a no-op. dst must be non-nil whenever
// src contains at least one value, since adding to a nil map panics.
func addHeaders(dst, src http.Header) {
	if src != nil {
		for key, values := range src {
			for i := range values {
				dst.Add(key, values[i])
			}
		}
	}
}
// setHeaders overwrites dst with src on a per-key basis. For every key present in src, any
// values dst held under that (canonicalized) key are removed and replaced by src's values.
// Keys absent from src are untouched. A nil src is a no-op.
func setHeaders(dst, src http.Header) {
	if src != nil {
		for key, values := range src {
			dst.Del(key)
			for i := range values {
				dst.Add(key, values[i])
			}
		}
	}
}