package proxy import ( "bytes" "context" "crypto/tls" "errors" "fmt" "io" "net" "net/http" "git.maze.io/maze/styx/ca" "git.maze.io/maze/styx/internal/cryptutil" "git.maze.io/maze/styx/internal/netutil" ) type Dialer interface { DialContext(context.Context, *http.Request) (net.Conn, error) } type defaultDialer struct{} func (defaultDialer) DialContext(ctx context.Context, req *http.Request) (net.Conn, error) { if host := netutil.Host(req.URL.Host); host == "" { return nil, errors.New("proxy: host missing in address") } var d = net.Dialer{ Resolver: net.DefaultResolver, FallbackDelay: -1, } // Ensure we have a port. switch req.URL.Scheme { case "http", "ws": req.URL.Host = netutil.EnsurePort(req.URL.Host, "80") case "https", "wss": req.URL.Host = netutil.EnsurePort(req.URL.Host, "443") } // Resolve the host. if ips, err := d.Resolver.LookupIP(ctx, "ip", netutil.Host(req.URL.Host)); err != nil { return nil, err } else { for _, ip := range ips { switch { case ip.IsUnspecified(): return nil, fmt.Errorf("proxy: host %s resolves to unspecified address (blocked by DNS?)", netutil.Host(req.URL.Host)) case ip.IsLoopback(): return nil, fmt.Errorf("proxy: host %s resolves to loopback address (blocked by DNS?)", netutil.Host(req.URL.Host)) } } } // Make the connection. switch req.URL.Scheme { case "tcp", "http", "ws": // Plain TCP client connection. return d.DialContext(ctx, "tcp", req.URL.Host) case "https", "wss": // Secure TLS client connection. c, err := d.DialContext(ctx, "tcp", req.URL.Host) if err != nil { return nil, err } s := tls.Client(c, new(tls.Config)) return s, s.Handshake() default: return nil, fmt.Errorf("proxy: can't dial %s protocol", req.URL.Scheme) } } // ConnFilter is called when a new connection has been accepted by the proxy. type ConnFilter interface { FilterConn(Context) (net.Conn, error) } // ConnFilterFunc is a function that implements the [ConnFilter] interface. type ConnFilterFunc func(Context) (net.Conn, error) func (f ConnFilterFunc) FilterConn(ctx Context) (net.Conn, error) { return f(ctx) } // TLS starts a TLS handshake on the accepted connection. func TLS(certs []tls.Certificate) ConnFilter { return ConnFilterFunc(func(ctx Context) (net.Conn, error) { s := tls.Server(ctx, &tls.Config{ Certificates: certs, NextProtos: []string{"http/1.1"}, }) if err := s.Handshake(); err != nil { return nil, err } return s, nil }) } // TLSInterceptor can generate certificates on-the-fly for clients that use a compatible TLS version. func TLSInterceptor(ca ca.CertificateAuthority) ConnFilter { return ConnFilterFunc(func(ctx Context) (net.Conn, error) { s := tls.Server(ctx, &tls.Config{ GetCertificate: func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) { ips := []net.IP{net.ParseIP(netutil.Host(ctx.RemoteAddr().String()))} return ca.GetCertificate(hello.ServerName, []string{hello.ServerName}, ips) }, NextProtos: []string{"http/1.1"}, }) if err := s.Handshake(); err != nil { return nil, err } return s, nil }) } // Transparent can handle transparent HTTP(S) requests on the port. // // When a new [net.Conn] is made, this function will inspect the initial request packet for a // TLS handshake. If a TLS handshake is detected, the connection will make a feaux HTTP CONNECT // request using TLS, if no handshake is detected, it will make a feaux plain HTTP CONNECT request. func Transparent() ConnFilter { return ConnFilterFunc(func(nctx Context) (net.Conn, error) { ctx, ok := nctx.(*proxyContext) if !ok { return nctx, nil } b := new(bytes.Buffer) hello, err := cryptutil.ReadClientHello(io.TeeReader(ctx, b)) if err != nil { if _, ok := err.(tls.RecordHeaderError); !ok { ctx.LogEntry().WithError(err).WithField("error_type", fmt.Sprintf("%T", err)).Warn("TLS sniff error") return nil, err } // Not a TLS connection, moving on to regular HTTP request handling... ctx.LogEntry().Debug("HTTP connection on transparent port") ctx.isTransparent = true } else { ctx.LogEntry().WithField("target", hello.ServerName).Debug("TLS connection on transparent port") ctx.isTransparentTLS = true ctx.serverName = hello.ServerName } return netutil.ReaderConn{ Conn: ctx.Conn, Reader: io.MultiReader(bytes.NewReader(b.Bytes()), ctx.Conn), }, nil }) } // RequestFilter can filter HTTP requests coming to the proxy. type RequestFilter interface { // FilterRequest filters a HTTP request made to the proxy. The current request may be obtained // from [Context.Request]. If a previous RequestFilter provided a HTTP response, it is available // from [Context.Response]. // // Modifications to the current request can be made to the Request returned by [Context.Request] // and do not require returning a new [http.Request]. // // If the filter returns a non-nil [http.Response], then the [Request] will not be proxied to // any upstream server(s). FilterRequest(Context) (*http.Request, *http.Response) } // RequestFilterFunc is a function that implements the [RequestFilter] interface. type RequestFilterFunc func(Context) (*http.Request, *http.Response) func (f RequestFilterFunc) FilterRequest(ctx Context) (*http.Request, *http.Response) { return f(ctx) } // ResponseFilter can filter HTTP responses coming from the proxy. type ResponseFilter interface { // FilterResponse filters a HTTP response coming from the proxy. The current response may be // obtained from [Context.Response]. // // Modifications to the current response can be made to the [Response] returned by [Context.Response]. FilterResponse(Context) *http.Response } // ResponseFilterFunc is a function that implements the [ResponseFilter] interface. type ResponseFilterFunc func(Context) *http.Response func (f ResponseFilterFunc) FilterResponse(ctx Context) *http.Response { return f(ctx) } // CleanRequestProxyHeaders removes all headers added by downstream proxies from the [http.Request]. func CleanRequestProxyHeaders() RequestFilter { return RequestFilterFunc(func(ctx Context) (*http.Request, *http.Response) { if req := ctx.Request(); req != nil { cleanProxyHeaders(req.Header) } return nil, nil }) } // CleanRequestProxyHeaders removes all headers for upstream proxies from the [http.Response]. func CleanResponseProxyHeaders() ResponseFilter { return ResponseFilterFunc(func(ctx Context) *http.Response { if res := ctx.Response(); res != nil { cleanProxyHeaders(res.Header) } return nil }) } // AddRequestHeaders adds headers to the [http.Request]. Any existing headers with the same // key will remain intact. func AddRequestHeaders(h http.Header) RequestFilter { return RequestFilterFunc(func(ctx Context) (*http.Request, *http.Response) { if req := ctx.Request(); req != nil { if req.Header == nil { req.Header = make(http.Header) } addHeaders(req.Header, h) } return nil, nil }) } // SetRequestHeaders sets headers to the [http.Request]. Any existing headers with the same // key will be removed. func SetRequestHeaders(h http.Header) RequestFilter { return RequestFilterFunc(func(ctx Context) (*http.Request, *http.Response) { if req := ctx.Request(); req != nil { if req.Header == nil { req.Header = make(http.Header) } setHeaders(req.Header, h) } return nil, nil }) } // AddResponseHeaders adds headers to the [http.Response]. Any existing headers with the same // key will remain intact. func AddResponseHeaders(h http.Header) ResponseFilter { return ResponseFilterFunc(func(ctx Context) *http.Response { if res := ctx.Response(); res != nil { if res.Header == nil { res.Header = make(http.Header) } addHeaders(res.Header, h) } return nil }) } // SetResponseHeaders sets headers to the [http.Response]. Any existing headers with the same // key will be removed. func SetResponseHeaders(h http.Header) ResponseFilter { return ResponseFilterFunc(func(ctx Context) *http.Response { if res := ctx.Response(); res != nil { if res.Header == nil { res.Header = make(http.Header) } setHeaders(res.Header, h) } return nil }) } // cleanProxyHeaders removes all headers commonly used by (reverse) HTTP proxies. func cleanProxyHeaders(h http.Header) { if h == nil { return } for _, k := range []string{ HeaderForwarded, HeaderForwardedFor, HeaderForwardedHost, HeaderForwardedPort, HeaderForwardedProto, HeaderRealIP, HeaderVia, } { h.Del(k) } } // addHeaders adds to the current existing headers. func addHeaders(dst, src http.Header) { if src == nil { return } for k, vv := range src { for _, v := range vv { dst.Add(k, v) } } } // setHeaders replaces all previous values. func setHeaders(dst, src http.Header) { if src == nil { return } for k, vv := range src { dst.Del(k) for _, v := range vv { dst.Add(k, v) } } }