Checkpoint

This commit is contained in:
2025-09-30 08:08:22 +02:00
parent a76650da35
commit 4a60059ff2
24 changed files with 1034 additions and 2959 deletions

309
proxy/handler.go Normal file
View File

@@ -0,0 +1,309 @@
package proxy
import (
"bytes"
"context"
"crypto/tls"
"errors"
"fmt"
"io"
"net"
"net/http"
"git.maze.io/maze/styx/ca"
"git.maze.io/maze/styx/internal/cryptutil"
"git.maze.io/maze/styx/internal/netutil"
)
type Dialer interface {
DialContext(context.Context, *http.Request) (net.Conn, error)
}
type defaultDialer struct{}
func (defaultDialer) DialContext(ctx context.Context, req *http.Request) (net.Conn, error) {
if host := netutil.Host(req.URL.Host); host == "" {
return nil, errors.New("proxy: host missing in address")
}
var d = net.Dialer{
Resolver: net.DefaultResolver,
FallbackDelay: -1,
}
// Ensure we have a port.
switch req.URL.Scheme {
case "http", "ws":
req.URL.Host = netutil.EnsurePort(req.URL.Host, "80")
case "https", "wss":
req.URL.Host = netutil.EnsurePort(req.URL.Host, "443")
}
// Resolve the host.
if ips, err := d.Resolver.LookupIP(ctx, "ip", netutil.Host(req.URL.Host)); err != nil {
return nil, err
} else {
for _, ip := range ips {
switch {
case ip.IsUnspecified():
return nil, fmt.Errorf("proxy: host %s resolves to unspecified address (blocked by DNS?)", netutil.Host(req.URL.Host))
case ip.IsLoopback():
return nil, fmt.Errorf("proxy: host %s resolves to loopback address (blocked by DNS?)", netutil.Host(req.URL.Host))
}
}
}
// Make the connection.
switch req.URL.Scheme {
case "tcp", "http", "ws":
// Plain TCP client connection.
return d.DialContext(ctx, "tcp", req.URL.Host)
case "https", "wss":
// Secure TLS client connection.
c, err := d.DialContext(ctx, "tcp", req.URL.Host)
if err != nil {
return nil, err
}
s := tls.Client(c, new(tls.Config))
return s, s.Handshake()
default:
return nil, fmt.Errorf("proxy: can't dial %s protocol", req.URL.Scheme)
}
}
// ConnFilter is called when a new connection has been accepted by the proxy.
type ConnFilter interface {
FilterConn(Context) (net.Conn, error)
}
// ConnFilterFunc is a function that implements the [ConnFilter] interface.
type ConnFilterFunc func(Context) (net.Conn, error)
func (f ConnFilterFunc) FilterConn(ctx Context) (net.Conn, error) {
return f(ctx)
}
// TLS starts a TLS handshake on the accepted connection.
func TLS(certs []tls.Certificate) ConnFilter {
return ConnFilterFunc(func(ctx Context) (net.Conn, error) {
s := tls.Server(ctx, &tls.Config{
Certificates: certs,
NextProtos: []string{"http/1.1"},
})
if err := s.Handshake(); err != nil {
return nil, err
}
return s, nil
})
}
// TLSInterceptor can generate certificates on-the-fly for clients that use a compatible TLS version.
func TLSInterceptor(ca ca.CertificateAuthority) ConnFilter {
return ConnFilterFunc(func(ctx Context) (net.Conn, error) {
s := tls.Server(ctx, &tls.Config{
GetCertificate: func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
ips := []net.IP{net.ParseIP(netutil.Host(ctx.RemoteAddr().String()))}
return ca.GetCertificate(hello.ServerName, []string{hello.ServerName}, ips)
},
NextProtos: []string{"http/1.1"},
})
if err := s.Handshake(); err != nil {
return nil, err
}
return s, nil
})
}
// Transparent can handle transparent HTTP(S) requests on the port.
//
// When a new [net.Conn] is made, this function will inspect the initial request packet for a
// TLS handshake. If a TLS handshake is detected, the connection will make a feaux HTTP CONNECT
// request using TLS, if no handshake is detected, it will make a feaux plain HTTP CONNECT request.
func Transparent() ConnFilter {
return ConnFilterFunc(func(nctx Context) (net.Conn, error) {
ctx, ok := nctx.(*proxyContext)
if !ok {
return nctx, nil
}
b := new(bytes.Buffer)
hello, err := cryptutil.ReadClientHello(io.TeeReader(ctx, b))
if err != nil {
if _, ok := err.(tls.RecordHeaderError); !ok {
ctx.LogEntry().WithError(err).WithField("error_type", fmt.Sprintf("%T", err)).Warn("TLS sniff error")
return nil, err
}
// Not a TLS connection, moving on to regular HTTP request handling...
ctx.LogEntry().Debug("HTTP connection on transparent port")
ctx.isTransparent = true
} else {
ctx.LogEntry().WithField("target", hello.ServerName).Debug("TLS connection on transparent port")
ctx.isTransparentTLS = true
ctx.serverName = hello.ServerName
}
return netutil.ReaderConn{
Conn: ctx.Conn,
Reader: io.MultiReader(bytes.NewReader(b.Bytes()), ctx.Conn),
}, nil
})
}
// RequestFilter can filter HTTP requests coming to the proxy.
type RequestFilter interface {
// FilterRequest filters a HTTP request made to the proxy. The current request may be obtained
// from [Context.Request]. If a previous RequestFilter provided a HTTP response, it is available
// from [Context.Response].
//
// Modifications to the current request can be made to the Request returned by [Context.Request]
// and do not require returning a new [http.Request].
//
// If the filter returns a non-nil [http.Response], then the [Request] will not be proxied to
// any upstream server(s).
FilterRequest(Context) (*http.Request, *http.Response)
}
// RequestFilterFunc is a function that implements the [RequestFilter] interface.
type RequestFilterFunc func(Context) (*http.Request, *http.Response)
func (f RequestFilterFunc) FilterRequest(ctx Context) (*http.Request, *http.Response) {
return f(ctx)
}
// ResponseFilter can filter HTTP responses coming from the proxy.
type ResponseFilter interface {
// FilterResponse filters a HTTP response coming from the proxy. The current response may be
// obtained from [Context.Response].
//
// Modifications to the current response can be made to the [Response] returned by [Context.Response].
FilterResponse(Context) *http.Response
}
// ResponseFilterFunc is a function that implements the [ResponseFilter] interface.
type ResponseFilterFunc func(Context) *http.Response
func (f ResponseFilterFunc) FilterResponse(ctx Context) *http.Response {
return f(ctx)
}
// CleanRequestProxyHeaders removes all headers added by downstream proxies from the [http.Request].
func CleanRequestProxyHeaders() RequestFilter {
return RequestFilterFunc(func(ctx Context) (*http.Request, *http.Response) {
if req := ctx.Request(); req != nil {
cleanProxyHeaders(req.Header)
}
return nil, nil
})
}
// CleanRequestProxyHeaders removes all headers for upstream proxies from the [http.Response].
func CleanResponseProxyHeaders() ResponseFilter {
return ResponseFilterFunc(func(ctx Context) *http.Response {
if res := ctx.Response(); res != nil {
cleanProxyHeaders(res.Header)
}
return nil
})
}
// AddRequestHeaders adds headers to the [http.Request]. Any existing headers with the same
// key will remain intact.
func AddRequestHeaders(h http.Header) RequestFilter {
return RequestFilterFunc(func(ctx Context) (*http.Request, *http.Response) {
if req := ctx.Request(); req != nil {
if req.Header == nil {
req.Header = make(http.Header)
}
addHeaders(req.Header, h)
}
return nil, nil
})
}
// SetRequestHeaders sets headers to the [http.Request]. Any existing headers with the same
// key will be removed.
func SetRequestHeaders(h http.Header) RequestFilter {
return RequestFilterFunc(func(ctx Context) (*http.Request, *http.Response) {
if req := ctx.Request(); req != nil {
if req.Header == nil {
req.Header = make(http.Header)
}
setHeaders(req.Header, h)
}
return nil, nil
})
}
// AddResponseHeaders adds headers to the [http.Response]. Any existing headers with the same
// key will remain intact.
func AddResponseHeaders(h http.Header) ResponseFilter {
return ResponseFilterFunc(func(ctx Context) *http.Response {
if res := ctx.Response(); res != nil {
if res.Header == nil {
res.Header = make(http.Header)
}
addHeaders(res.Header, h)
}
return nil
})
}
// SetResponseHeaders sets headers to the [http.Response]. Any existing headers with the same
// key will be removed.
func SetResponseHeaders(h http.Header) ResponseFilter {
return ResponseFilterFunc(func(ctx Context) *http.Response {
if res := ctx.Response(); res != nil {
if res.Header == nil {
res.Header = make(http.Header)
}
setHeaders(res.Header, h)
}
return nil
})
}
// cleanProxyHeaders removes all headers commonly used by (reverse) HTTP proxies.
func cleanProxyHeaders(h http.Header) {
if h == nil {
return
}
for _, k := range []string{
HeaderForwarded,
HeaderForwardedFor,
HeaderForwardedHost,
HeaderForwardedPort,
HeaderForwardedProto,
HeaderRealIP,
HeaderVia,
} {
h.Del(k)
}
}
// addHeaders adds to the current existing headers.
func addHeaders(dst, src http.Header) {
if src == nil {
return
}
for k, vv := range src {
for _, v := range vv {
dst.Add(k, v)
}
}
}
// setHeaders replaces all previous values.
func setHeaders(dst, src http.Header) {
if src == nil {
return
}
for k, vv := range src {
dst.Del(k)
for _, v := range vv {
dst.Add(k, v)
}
}
}