Checkpoint

This commit is contained in:
2025-10-01 15:37:55 +02:00
parent 4a60059ff2
commit 03352e3312
31 changed files with 2611 additions and 384 deletions

View File

@@ -15,7 +15,12 @@ import (
"git.maze.io/maze/styx/internal/netutil"
)
// Dialer can make outbound connections to upstream servers.
type Dialer interface {
// DialContext makes a new connection to the address specified in the [http.Request].
//
// The [http.Request] contains the URL scheme (http, https, ws, wss) and host (with optional port)
// to connect to. The [context.Context] may be used for cancellation and timeouts.
DialContext(context.Context, *http.Request) (net.Conn, error)
}
@@ -71,25 +76,38 @@ func (defaultDialer) DialContext(ctx context.Context, req *http.Request) (net.Co
}
}
// ConnFilter is called when a new connection has been accepted by the proxy.
type ConnFilter interface {
FilterConn(Context) (net.Conn, error)
// ErrorHandler can handle errors that occur during proxying.
type ErrorHandler interface {
// HandleError handles an error that occurred during proxying. If the method returns a non-nil
// [http.Response], it will be sent to the client as-is. If it returns nil, a generic HTTP 502
// Bad Gateway response will be sent to the client.
//
// The [Context] may be inspected to obtain information about the request that caused the error.
// However, the [Context.Request] and [Context.Response] may be nil depending on when the error
// occurred.
HandleError(Context, error) *http.Response
}
// ConnFilterFunc is a function that implements the [ConnFilter] interface.
type ConnFilterFunc func(Context) (net.Conn, error)
// ConnHandler is called when a new connection has been accepted by the proxy.
type ConnHandler interface {
HandleConn(Context) (net.Conn, error)
}
func (f ConnFilterFunc) FilterConn(ctx Context) (net.Conn, error) {
// ConnHandlerFunc is a function that implements the [ConnHandler] interface.
type ConnHandlerFunc func(Context) (net.Conn, error)
func (f ConnHandlerFunc) HandleConn(ctx Context) (net.Conn, error) {
return f(ctx)
}
// TLS starts a TLS handshake on the accepted connection.
func TLS(certs []tls.Certificate) ConnFilter {
return ConnFilterFunc(func(ctx Context) (net.Conn, error) {
s := tls.Server(ctx, &tls.Config{
Certificates: certs,
NextProtos: []string{"http/1.1"},
})
func TLS(config *tls.Config) ConnHandler {
if config == nil {
config = new(tls.Config)
}
config.NextProtos = []string{"http/1.1"}
return ConnHandlerFunc(func(ctx Context) (net.Conn, error) {
s := tls.Server(ctx, config)
if err := s.Handshake(); err != nil {
return nil, err
}
@@ -98,8 +116,8 @@ func TLS(certs []tls.Certificate) ConnFilter {
}
// TLSInterceptor can generate certificates on-the-fly for clients that use a compatible TLS version.
func TLSInterceptor(ca ca.CertificateAuthority) ConnFilter {
return ConnFilterFunc(func(ctx Context) (net.Conn, error) {
func TLSInterceptor(ca ca.CertificateAuthority) ConnHandler {
return ConnHandlerFunc(func(ctx Context) (net.Conn, error) {
s := tls.Server(ctx, &tls.Config{
GetCertificate: func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
ips := []net.IP{net.ParseIP(netutil.Host(ctx.RemoteAddr().String()))}
@@ -119,26 +137,27 @@ func TLSInterceptor(ca ca.CertificateAuthority) ConnFilter {
// When a new [net.Conn] is made, this function will inspect the initial request packet for a
// TLS handshake. If a TLS handshake is detected, the connection will make a feaux HTTP CONNECT
// request using TLS, if no handshake is detected, it will make a feaux plain HTTP CONNECT request.
func Transparent() ConnFilter {
return ConnFilterFunc(func(nctx Context) (net.Conn, error) {
func Transparent(port int) ConnHandler {
return ConnHandlerFunc(func(nctx Context) (net.Conn, error) {
ctx, ok := nctx.(*proxyContext)
if !ok {
return nctx, nil
}
b := new(bytes.Buffer)
hello, err := cryptutil.ReadClientHello(io.TeeReader(ctx, b))
hello, err := cryptutil.ReadClientHello(io.TeeReader(netutil.ReadOnlyConn{Reader: ctx.br}, b))
if err != nil {
if _, ok := err.(tls.RecordHeaderError); !ok {
ctx.LogEntry().WithError(err).WithField("error_type", fmt.Sprintf("%T", err)).Warn("TLS sniff error")
ctx.LogEntry().Err(err).Value("error_type", fmt.Sprintf("%T", err)).Warn("TLS sniff error")
return nil, err
}
// Not a TLS connection, moving on to regular HTTP request handling...
ctx.LogEntry().Debug("HTTP connection on transparent port")
ctx.isTransparent = true
ctx.transparent = port
} else {
ctx.LogEntry().WithField("target", hello.ServerName).Debug("TLS connection on transparent port")
ctx.isTransparentTLS = true
ctx.LogEntry().Value("target", hello.ServerName).Debug("TLS connection on transparent port")
ctx.transparent = port
ctx.transparentTLS = true
ctx.serverName = hello.ServerName
}
@@ -149,10 +168,10 @@ func Transparent() ConnFilter {
})
}
// RequestFilter can filter HTTP requests coming to the proxy.
type RequestFilter interface {
// FilterRequest filters a HTTP request made to the proxy. The current request may be obtained
// from [Context.Request]. If a previous RequestFilter provided a HTTP response, it is available
// RequestHandler can filter HTTP requests coming to the proxy.
type RequestHandler interface {
// HandlerRequest filters a HTTP request made to the proxy. The current request may be obtained
// from [Context.Request]. If a previous RequestHandler provided a HTTP response, it is available
// from [Context.Response].
//
// Modifications to the current request can be made to the Request returned by [Context.Request]
@@ -160,35 +179,35 @@ type RequestFilter interface {
//
// If the filter returns a non-nil [http.Response], then the [Request] will not be proxied to
// any upstream server(s).
FilterRequest(Context) (*http.Request, *http.Response)
HandleRequest(Context) (*http.Request, *http.Response)
}
// RequestFilterFunc is a function that implements the [RequestFilter] interface.
type RequestFilterFunc func(Context) (*http.Request, *http.Response)
// RequestHandlerFunc is a function that implements the [RequestHandler] interface.
type RequestHandlerFunc func(Context) (*http.Request, *http.Response)
func (f RequestFilterFunc) FilterRequest(ctx Context) (*http.Request, *http.Response) {
func (f RequestHandlerFunc) HandleRequest(ctx Context) (*http.Request, *http.Response) {
return f(ctx)
}
// ResponseFilter can filter HTTP responses coming from the proxy.
type ResponseFilter interface {
// FilterResponse filters a HTTP response coming from the proxy. The current response may be
// ResponseHandler can filter HTTP responses coming from the proxy.
type ResponseHandler interface {
// HandlerResponse filters a HTTP response coming from the proxy. The current response may be
// obtained from [Context.Response].
//
// Modifications to the current response can be made to the [Response] returned by [Context.Response].
FilterResponse(Context) *http.Response
HandleResponse(Context) *http.Response
}
// ResponseFilterFunc is a function that implements the [ResponseFilter] interface.
type ResponseFilterFunc func(Context) *http.Response
// ResponseHandlerFunc is a function that implements the [ResponseHandler] interface.
type ResponseHandlerFunc func(Context) *http.Response
func (f ResponseFilterFunc) FilterResponse(ctx Context) *http.Response {
func (f ResponseHandlerFunc) HandleResponse(ctx Context) *http.Response {
return f(ctx)
}
// CleanRequestProxyHeaders removes all headers added by downstream proxies from the [http.Request].
func CleanRequestProxyHeaders() RequestFilter {
return RequestFilterFunc(func(ctx Context) (*http.Request, *http.Response) {
func CleanRequestProxyHeaders() RequestHandler {
return RequestHandlerFunc(func(ctx Context) (*http.Request, *http.Response) {
if req := ctx.Request(); req != nil {
cleanProxyHeaders(req.Header)
}
@@ -197,8 +216,8 @@ func CleanRequestProxyHeaders() RequestFilter {
}
// CleanRequestProxyHeaders removes all headers for upstream proxies from the [http.Response].
func CleanResponseProxyHeaders() ResponseFilter {
return ResponseFilterFunc(func(ctx Context) *http.Response {
func CleanResponseProxyHeaders() ResponseHandler {
return ResponseHandlerFunc(func(ctx Context) *http.Response {
if res := ctx.Response(); res != nil {
cleanProxyHeaders(res.Header)
}
@@ -208,8 +227,8 @@ func CleanResponseProxyHeaders() ResponseFilter {
// AddRequestHeaders adds headers to the [http.Request]. Any existing headers with the same
// key will remain intact.
func AddRequestHeaders(h http.Header) RequestFilter {
return RequestFilterFunc(func(ctx Context) (*http.Request, *http.Response) {
func AddRequestHeaders(h http.Header) RequestHandler {
return RequestHandlerFunc(func(ctx Context) (*http.Request, *http.Response) {
if req := ctx.Request(); req != nil {
if req.Header == nil {
req.Header = make(http.Header)
@@ -222,8 +241,8 @@ func AddRequestHeaders(h http.Header) RequestFilter {
// SetRequestHeaders sets headers to the [http.Request]. Any existing headers with the same
// key will be removed.
func SetRequestHeaders(h http.Header) RequestFilter {
return RequestFilterFunc(func(ctx Context) (*http.Request, *http.Response) {
func SetRequestHeaders(h http.Header) RequestHandler {
return RequestHandlerFunc(func(ctx Context) (*http.Request, *http.Response) {
if req := ctx.Request(); req != nil {
if req.Header == nil {
req.Header = make(http.Header)
@@ -236,8 +255,8 @@ func SetRequestHeaders(h http.Header) RequestFilter {
// AddResponseHeaders adds headers to the [http.Response]. Any existing headers with the same
// key will remain intact.
func AddResponseHeaders(h http.Header) ResponseFilter {
return ResponseFilterFunc(func(ctx Context) *http.Response {
func AddResponseHeaders(h http.Header) ResponseHandler {
return ResponseHandlerFunc(func(ctx Context) *http.Response {
if res := ctx.Response(); res != nil {
if res.Header == nil {
res.Header = make(http.Header)
@@ -250,8 +269,8 @@ func AddResponseHeaders(h http.Header) ResponseFilter {
// SetResponseHeaders sets headers to the [http.Response]. Any existing headers with the same
// key will be removed.
func SetResponseHeaders(h http.Header) ResponseFilter {
return ResponseFilterFunc(func(ctx Context) *http.Response {
func SetResponseHeaders(h http.Header) ResponseHandler {
return ResponseHandlerFunc(func(ctx Context) *http.Response {
if res := ctx.Response(); res != nil {
if res.Header == nil {
res.Header = make(http.Header)