package proxy import ( "bytes" "context" "crypto/tls" "errors" "fmt" "io" "net" "net/http" "git.maze.io/maze/styx/ca" "git.maze.io/maze/styx/internal/cryptutil" "git.maze.io/maze/styx/internal/netutil" ) // Dialer can make outbound connections to upstream servers. type Dialer interface { // DialContext makes a new connection to the address specified in the [http.Request]. // // The [http.Request] contains the URL scheme (http, https, ws, wss) and host (with optional port) // to connect to. The [context.Context] may be used for cancellation and timeouts. DialContext(context.Context, *http.Request) (net.Conn, error) } type defaultDialer struct{} func (defaultDialer) DialContext(ctx context.Context, req *http.Request) (net.Conn, error) { if host := netutil.Host(req.URL.Host); host == "" { return nil, errors.New("proxy: host missing in address") } var d = net.Dialer{ Resolver: net.DefaultResolver, FallbackDelay: -1, } // Ensure we have a port. switch req.URL.Scheme { case "http", "ws": req.URL.Host = netutil.EnsurePort(req.URL.Host, "80") case "https", "wss": req.URL.Host = netutil.EnsurePort(req.URL.Host, "443") } // Resolve the host. if ips, err := d.Resolver.LookupIP(ctx, "ip", netutil.Host(req.URL.Host)); err != nil { return nil, err } else { for _, ip := range ips { switch { case ip.IsUnspecified(): return nil, fmt.Errorf("proxy: host %s resolves to unspecified address (blocked by DNS?)", netutil.Host(req.URL.Host)) case ip.IsLoopback(): return nil, fmt.Errorf("proxy: host %s resolves to loopback address (blocked by DNS?)", netutil.Host(req.URL.Host)) } } } // Make the connection. switch req.URL.Scheme { case "tcp", "http", "ws": // Plain TCP client connection. return d.DialContext(ctx, "tcp", req.URL.Host) case "https", "wss": // Secure TLS client connection. c, err := d.DialContext(ctx, "tcp", req.URL.Host) if err != nil { return nil, err } s := tls.Client(c, new(tls.Config)) return s, s.Handshake() default: return nil, fmt.Errorf("proxy: can't dial %s protocol", req.URL.Scheme) } } // ErrorHandler can handle errors that occur during proxying. type ErrorHandler interface { // HandleError handles an error that occurred during proxying. If the method returns a non-nil // [http.Response], it will be sent to the client as-is. If it returns nil, a generic HTTP 502 // Bad Gateway response will be sent to the client. // // The [Context] may be inspected to obtain information about the request that caused the error. // However, the [Context.Request] and [Context.Response] may be nil depending on when the error // occurred. HandleError(Context, error) *http.Response } // ConnHandler is called when a new connection has been accepted by the proxy. type ConnHandler interface { HandleConn(Context) (net.Conn, error) } // ConnHandlerFunc is a function that implements the [ConnHandler] interface. type ConnHandlerFunc func(Context) (net.Conn, error) func (f ConnHandlerFunc) HandleConn(ctx Context) (net.Conn, error) { return f(ctx) } // TLS starts a TLS handshake on the accepted connection. func TLS(config *tls.Config) ConnHandler { if config == nil { config = new(tls.Config) } config.NextProtos = []string{"http/1.1"} return ConnHandlerFunc(func(ctx Context) (net.Conn, error) { s := tls.Server(ctx, config) if err := s.Handshake(); err != nil { return nil, err } return s, nil }) } // TLSInterceptor can generate certificates on-the-fly for clients that use a compatible TLS version. func TLSInterceptor(ca ca.CertificateAuthority) ConnHandler { return ConnHandlerFunc(func(ctx Context) (net.Conn, error) { s := tls.Server(ctx, &tls.Config{ GetCertificate: func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) { ips := []net.IP{net.ParseIP(netutil.Host(ctx.RemoteAddr().String()))} return ca.GetCertificate(hello.ServerName, []string{hello.ServerName}, ips) }, NextProtos: []string{"http/1.1"}, }) if err := s.Handshake(); err != nil { return nil, err } return s, nil }) } // Transparent can handle transparent HTTP(S) requests on the port. // // When a new [net.Conn] is made, this function will inspect the initial request packet for a // TLS handshake. If a TLS handshake is detected, the connection will make a feaux HTTP CONNECT // request using TLS, if no handshake is detected, it will make a feaux plain HTTP CONNECT request. func Transparent(port int) ConnHandler { return ConnHandlerFunc(func(nctx Context) (net.Conn, error) { ctx, ok := nctx.(*proxyContext) if !ok { return nctx, nil } var ( b = new(bytes.Buffer) hello, err = cryptutil.ReadClientHello(io.TeeReader(netutil.ReadOnlyConn{Reader: ctx.br}, b)) log = ctx.Logger() ) if err != nil { if _, ok := err.(tls.RecordHeaderError); !ok { log.Err(err).Value("error_type", fmt.Sprintf("%T", err)).Warn("TLS sniff error") return nil, err } // Not a TLS connection, moving on to regular HTTP request handling... log.Debug("HTTP connection on transparent port") ctx.transparent = port } else { log.Value("target", hello.ServerName).Debug("TLS connection on transparent port") ctx.transparent = port ctx.transparentTLS = true ctx.serverName = hello.ServerName } return netutil.ReaderConn{ Conn: ctx.Conn, Reader: io.MultiReader(bytes.NewReader(b.Bytes()), ctx.Conn), }, nil }) } // DialHandler can filter network dial requests coming from the proxy. type DialHandler interface { // HandleDial filters an outbound dial request made by the proxy. // // The handler may decide to intercept the dial request and return a new [net.Conn] // that will be used instead of dialing the target. The handler can also return // nil, in which case the normal dial will proceed. HandleDial(Context, *http.Request) (net.Conn, error) } // DialHandlerFunc is a function that implements the [DialHandler] interface. type DialHandlerFunc func(Context, *http.Request) (net.Conn, error) func (f DialHandlerFunc) HandleDial(ctx Context, req *http.Request) (net.Conn, error) { return f(ctx, req) } // ForwardHandler can filter forward HTTP proxy requests. type ForwardHandler interface { HandleForward(Context, *http.Request) (*http.Response, error) } type ForwardHandlerFunc func(Context, *http.Request) (*http.Response, error) func (f ForwardHandlerFunc) HandleForward(ctx Context, req *http.Request) (*http.Response, error) { return f(ctx, req) } // RequestHandler can filter HTTP requests coming to the proxy. type RequestHandler interface { // HandlerRequest filters a HTTP request made to the proxy. The current request may be obtained // from [Context.Request]. If a previous RequestHandler provided a HTTP response, it is available // from [Context.Response]. // // Modifications to the current request can be made to the Request returned by [Context.Request] // and do not require returning a new [http.Request]. // // If the filter returns a non-nil [http.Response], then the [Request] will not be proxied to // any upstream server(s). HandleRequest(Context) (*http.Request, *http.Response) } // RequestHandlerFunc is a function that implements the [RequestHandler] interface. type RequestHandlerFunc func(Context) (*http.Request, *http.Response) func (f RequestHandlerFunc) HandleRequest(ctx Context) (*http.Request, *http.Response) { return f(ctx) } // ResponseHandler can filter HTTP responses coming from the proxy. type ResponseHandler interface { // HandlerResponse filters a HTTP response coming from the proxy. The current response may be // obtained from [Context.Response]. // // Modifications to the current response can be made to the [Response] returned by [Context.Response]. HandleResponse(Context) *http.Response } // ResponseHandlerFunc is a function that implements the [ResponseHandler] interface. type ResponseHandlerFunc func(Context) *http.Response func (f ResponseHandlerFunc) HandleResponse(ctx Context) *http.Response { return f(ctx) } // CleanRequestProxyHeaders removes all headers added by downstream proxies from the [http.Request]. func CleanRequestProxyHeaders() RequestHandler { return RequestHandlerFunc(func(ctx Context) (*http.Request, *http.Response) { if req := ctx.Request(); req != nil { cleanProxyHeaders(req.Header) } return nil, nil }) } // CleanRequestProxyHeaders removes all headers for upstream proxies from the [http.Response]. func CleanResponseProxyHeaders() ResponseHandler { return ResponseHandlerFunc(func(ctx Context) *http.Response { if res := ctx.Response(); res != nil { cleanProxyHeaders(res.Header) } return nil }) } // AddRequestHeaders adds headers to the [http.Request]. Any existing headers with the same // key will remain intact. func AddRequestHeaders(h http.Header) RequestHandler { return RequestHandlerFunc(func(ctx Context) (*http.Request, *http.Response) { if req := ctx.Request(); req != nil { if req.Header == nil { req.Header = make(http.Header) } addHeaders(req.Header, h) } return nil, nil }) } // SetRequestHeaders sets headers to the [http.Request]. Any existing headers with the same // key will be removed. func SetRequestHeaders(h http.Header) RequestHandler { return RequestHandlerFunc(func(ctx Context) (*http.Request, *http.Response) { if req := ctx.Request(); req != nil { if req.Header == nil { req.Header = make(http.Header) } setHeaders(req.Header, h) } return nil, nil }) } // AddResponseHeaders adds headers to the [http.Response]. Any existing headers with the same // key will remain intact. func AddResponseHeaders(h http.Header) ResponseHandler { return ResponseHandlerFunc(func(ctx Context) *http.Response { if res := ctx.Response(); res != nil { if res.Header == nil { res.Header = make(http.Header) } addHeaders(res.Header, h) } return nil }) } // SetResponseHeaders sets headers to the [http.Response]. Any existing headers with the same // key will be removed. func SetResponseHeaders(h http.Header) ResponseHandler { return ResponseHandlerFunc(func(ctx Context) *http.Response { if res := ctx.Response(); res != nil { if res.Header == nil { res.Header = make(http.Header) } setHeaders(res.Header, h) } return nil }) } // cleanProxyHeaders removes all headers commonly used by (reverse) HTTP proxies. func cleanProxyHeaders(h http.Header) { if h == nil { return } for _, k := range []string{ HeaderForwarded, HeaderForwardedFor, HeaderForwardedHost, HeaderForwardedPort, HeaderForwardedProto, HeaderRealIP, HeaderVia, } { h.Del(k) } } // addHeaders adds to the current existing headers. func addHeaders(dst, src http.Header) { if src == nil { return } for k, vv := range src { for _, v := range vv { dst.Add(k, v) } } } // setHeaders replaces all previous values. func setHeaders(dst, src http.Header) { if src == nil { return } for k, vv := range src { dst.Del(k) for _, v := range vv { dst.Add(k, v) } } }