360 lines
11 KiB
Go
360 lines
11 KiB
Go
package proxy
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"crypto/tls"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"net/http"
|
|
|
|
"git.maze.io/maze/styx/ca"
|
|
"git.maze.io/maze/styx/internal/cryptutil"
|
|
"git.maze.io/maze/styx/internal/netutil"
|
|
)
|
|
|
|
// Dialer makes outbound connections to upstream servers.
type Dialer interface {
	// DialContext opens a new connection to the address described by req.
	//
	// The URL of the [http.Request] carries the scheme (http, https, ws, wss) and the host
	// (with an optional port) to connect to. The [context.Context] may be used for
	// cancellation and timeouts.
	DialContext(ctx context.Context, req *http.Request) (net.Conn, error)
}
|
|
|
|
type defaultDialer struct{}
|
|
|
|
func (defaultDialer) DialContext(ctx context.Context, req *http.Request) (net.Conn, error) {
|
|
if host := netutil.Host(req.URL.Host); host == "" {
|
|
return nil, errors.New("proxy: host missing in address")
|
|
}
|
|
|
|
var d = net.Dialer{
|
|
Resolver: net.DefaultResolver,
|
|
FallbackDelay: -1,
|
|
}
|
|
|
|
// Ensure we have a port.
|
|
switch req.URL.Scheme {
|
|
case "http", "ws":
|
|
req.URL.Host = netutil.EnsurePort(req.URL.Host, "80")
|
|
case "https", "wss":
|
|
req.URL.Host = netutil.EnsurePort(req.URL.Host, "443")
|
|
}
|
|
|
|
// Resolve the host.
|
|
if ips, err := d.Resolver.LookupIP(ctx, "ip", netutil.Host(req.URL.Host)); err != nil {
|
|
return nil, err
|
|
} else {
|
|
for _, ip := range ips {
|
|
switch {
|
|
case ip.IsUnspecified():
|
|
return nil, fmt.Errorf("proxy: host %s resolves to unspecified address (blocked by DNS?)", netutil.Host(req.URL.Host))
|
|
case ip.IsLoopback():
|
|
return nil, fmt.Errorf("proxy: host %s resolves to loopback address (blocked by DNS?)", netutil.Host(req.URL.Host))
|
|
}
|
|
}
|
|
}
|
|
|
|
// Make the connection.
|
|
switch req.URL.Scheme {
|
|
case "tcp", "http", "ws":
|
|
// Plain TCP client connection.
|
|
return d.DialContext(ctx, "tcp", req.URL.Host)
|
|
case "https", "wss":
|
|
// Secure TLS client connection.
|
|
c, err := d.DialContext(ctx, "tcp", req.URL.Host)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
s := tls.Client(c, new(tls.Config))
|
|
return s, s.Handshake()
|
|
default:
|
|
return nil, fmt.Errorf("proxy: can't dial %s protocol", req.URL.Scheme)
|
|
}
|
|
}
|
|
|
|
// ErrorHandler can handle errors that occur during proxying.
|
|
type ErrorHandler interface {
|
|
// HandleError handles an error that occurred during proxying. If the method returns a non-nil
|
|
// [http.Response], it will be sent to the client as-is. If it returns nil, a generic HTTP 502
|
|
// Bad Gateway response will be sent to the client.
|
|
//
|
|
// The [Context] may be inspected to obtain information about the request that caused the error.
|
|
// However, the [Context.Request] and [Context.Response] may be nil depending on when the error
|
|
// occurred.
|
|
HandleError(Context, error) *http.Response
|
|
}
|
|
|
|
// ConnHandler is called when a new connection has been accepted by the proxy.
|
|
type ConnHandler interface {
|
|
HandleConn(Context) (net.Conn, error)
|
|
}
|
|
|
|
// ConnHandlerFunc is a function that implements the [ConnHandler] interface.
|
|
type ConnHandlerFunc func(Context) (net.Conn, error)
|
|
|
|
func (f ConnHandlerFunc) HandleConn(ctx Context) (net.Conn, error) {
|
|
return f(ctx)
|
|
}
|
|
|
|
// TLS starts a TLS handshake on the accepted connection.
|
|
func TLS(config *tls.Config) ConnHandler {
|
|
if config == nil {
|
|
config = new(tls.Config)
|
|
}
|
|
config.NextProtos = []string{"http/1.1"}
|
|
return ConnHandlerFunc(func(ctx Context) (net.Conn, error) {
|
|
s := tls.Server(ctx, config)
|
|
if err := s.Handshake(); err != nil {
|
|
return nil, err
|
|
}
|
|
return s, nil
|
|
})
|
|
}
|
|
|
|
// TLSInterceptor can generate certificates on-the-fly for clients that use a compatible TLS version.
|
|
func TLSInterceptor(ca ca.CertificateAuthority) ConnHandler {
|
|
return ConnHandlerFunc(func(ctx Context) (net.Conn, error) {
|
|
s := tls.Server(ctx, &tls.Config{
|
|
GetCertificate: func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
|
ips := []net.IP{net.ParseIP(netutil.Host(ctx.RemoteAddr().String()))}
|
|
return ca.GetCertificate(hello.ServerName, []string{hello.ServerName}, ips)
|
|
},
|
|
NextProtos: []string{"http/1.1"},
|
|
})
|
|
if err := s.Handshake(); err != nil {
|
|
return nil, err
|
|
}
|
|
return s, nil
|
|
})
|
|
}
|
|
|
|
// Transparent can handle transparent HTTP(S) requests on the port.
|
|
//
|
|
// When a new [net.Conn] is made, this function will inspect the initial request packet for a
|
|
// TLS handshake. If a TLS handshake is detected, the connection will make a feaux HTTP CONNECT
|
|
// request using TLS, if no handshake is detected, it will make a feaux plain HTTP CONNECT request.
|
|
func Transparent(port int) ConnHandler {
|
|
return ConnHandlerFunc(func(nctx Context) (net.Conn, error) {
|
|
ctx, ok := nctx.(*proxyContext)
|
|
if !ok {
|
|
return nctx, nil
|
|
}
|
|
|
|
var (
|
|
b = new(bytes.Buffer)
|
|
hello, err = cryptutil.ReadClientHello(io.TeeReader(netutil.ReadOnlyConn{Reader: ctx.br}, b))
|
|
log = ctx.Logger()
|
|
)
|
|
if err != nil {
|
|
if _, ok := err.(tls.RecordHeaderError); !ok {
|
|
log.Err(err).Value("error_type", fmt.Sprintf("%T", err)).Warn("TLS sniff error")
|
|
return nil, err
|
|
}
|
|
// Not a TLS connection, moving on to regular HTTP request handling...
|
|
log.Debug("HTTP connection on transparent port")
|
|
ctx.transparent = port
|
|
} else {
|
|
log.Value("target", hello.ServerName).Debug("TLS connection on transparent port")
|
|
ctx.transparent = port
|
|
ctx.transparentTLS = true
|
|
ctx.serverName = hello.ServerName
|
|
}
|
|
|
|
return netutil.ReaderConn{
|
|
Conn: ctx.Conn,
|
|
Reader: io.MultiReader(bytes.NewReader(b.Bytes()), ctx.Conn),
|
|
}, nil
|
|
})
|
|
}
|
|
|
|
// DialHandler can filter network dial requests coming from the proxy.
|
|
type DialHandler interface {
|
|
// HandleDial filters an outbound dial request made by the proxy.
|
|
//
|
|
// The handler may decide to intercept the dial request and return a new [net.Conn]
|
|
// that will be used instead of dialing the target. The handler can also return
|
|
// nil, in which case the normal dial will proceed.
|
|
HandleDial(Context, *http.Request) (net.Conn, error)
|
|
}
|
|
|
|
// DialHandlerFunc is a function that implements the [DialHandler] interface.
|
|
type DialHandlerFunc func(Context, *http.Request) (net.Conn, error)
|
|
|
|
func (f DialHandlerFunc) HandleDial(ctx Context, req *http.Request) (net.Conn, error) {
|
|
return f(ctx, req)
|
|
}
|
|
|
|
// ForwardHandler can filter forward HTTP proxy requests.
|
|
type ForwardHandler interface {
|
|
HandleForward(Context, *http.Request) (*http.Response, error)
|
|
}
|
|
|
|
type ForwardHandlerFunc func(Context, *http.Request) (*http.Response, error)
|
|
|
|
func (f ForwardHandlerFunc) HandleForward(ctx Context, req *http.Request) (*http.Response, error) {
|
|
return f(ctx, req)
|
|
}
|
|
|
|
// RequestHandler can filter HTTP requests coming to the proxy.
|
|
type RequestHandler interface {
|
|
// HandlerRequest filters a HTTP request made to the proxy. The current request may be obtained
|
|
// from [Context.Request]. If a previous RequestHandler provided a HTTP response, it is available
|
|
// from [Context.Response].
|
|
//
|
|
// Modifications to the current request can be made to the Request returned by [Context.Request]
|
|
// and do not require returning a new [http.Request].
|
|
//
|
|
// If the filter returns a non-nil [http.Response], then the [Request] will not be proxied to
|
|
// any upstream server(s).
|
|
HandleRequest(Context) (*http.Request, *http.Response)
|
|
}
|
|
|
|
// RequestHandlerFunc is a function that implements the [RequestHandler] interface.
|
|
type RequestHandlerFunc func(Context) (*http.Request, *http.Response)
|
|
|
|
func (f RequestHandlerFunc) HandleRequest(ctx Context) (*http.Request, *http.Response) {
|
|
return f(ctx)
|
|
}
|
|
|
|
// ResponseHandler can filter HTTP responses coming from the proxy.
|
|
type ResponseHandler interface {
|
|
// HandlerResponse filters a HTTP response coming from the proxy. The current response may be
|
|
// obtained from [Context.Response].
|
|
//
|
|
// Modifications to the current response can be made to the [Response] returned by [Context.Response].
|
|
HandleResponse(Context) *http.Response
|
|
}
|
|
|
|
// ResponseHandlerFunc is a function that implements the [ResponseHandler] interface.
|
|
type ResponseHandlerFunc func(Context) *http.Response
|
|
|
|
func (f ResponseHandlerFunc) HandleResponse(ctx Context) *http.Response {
|
|
return f(ctx)
|
|
}
|
|
|
|
// CleanRequestProxyHeaders removes all headers added by downstream proxies from the [http.Request].
|
|
func CleanRequestProxyHeaders() RequestHandler {
|
|
return RequestHandlerFunc(func(ctx Context) (*http.Request, *http.Response) {
|
|
if req := ctx.Request(); req != nil {
|
|
cleanProxyHeaders(req.Header)
|
|
}
|
|
return nil, nil
|
|
})
|
|
}
|
|
|
|
// CleanRequestProxyHeaders removes all headers for upstream proxies from the [http.Response].
|
|
func CleanResponseProxyHeaders() ResponseHandler {
|
|
return ResponseHandlerFunc(func(ctx Context) *http.Response {
|
|
if res := ctx.Response(); res != nil {
|
|
cleanProxyHeaders(res.Header)
|
|
}
|
|
return nil
|
|
})
|
|
}
|
|
|
|
// AddRequestHeaders adds headers to the [http.Request]. Any existing headers with the same
|
|
// key will remain intact.
|
|
func AddRequestHeaders(h http.Header) RequestHandler {
|
|
return RequestHandlerFunc(func(ctx Context) (*http.Request, *http.Response) {
|
|
if req := ctx.Request(); req != nil {
|
|
if req.Header == nil {
|
|
req.Header = make(http.Header)
|
|
}
|
|
addHeaders(req.Header, h)
|
|
}
|
|
return nil, nil
|
|
})
|
|
}
|
|
|
|
// SetRequestHeaders sets headers to the [http.Request]. Any existing headers with the same
|
|
// key will be removed.
|
|
func SetRequestHeaders(h http.Header) RequestHandler {
|
|
return RequestHandlerFunc(func(ctx Context) (*http.Request, *http.Response) {
|
|
if req := ctx.Request(); req != nil {
|
|
if req.Header == nil {
|
|
req.Header = make(http.Header)
|
|
}
|
|
setHeaders(req.Header, h)
|
|
}
|
|
return nil, nil
|
|
})
|
|
}
|
|
|
|
// AddResponseHeaders adds headers to the [http.Response]. Any existing headers with the same
|
|
// key will remain intact.
|
|
func AddResponseHeaders(h http.Header) ResponseHandler {
|
|
return ResponseHandlerFunc(func(ctx Context) *http.Response {
|
|
if res := ctx.Response(); res != nil {
|
|
if res.Header == nil {
|
|
res.Header = make(http.Header)
|
|
}
|
|
addHeaders(res.Header, h)
|
|
}
|
|
return nil
|
|
})
|
|
}
|
|
|
|
// SetResponseHeaders sets headers to the [http.Response]. Any existing headers with the same
|
|
// key will be removed.
|
|
func SetResponseHeaders(h http.Header) ResponseHandler {
|
|
return ResponseHandlerFunc(func(ctx Context) *http.Response {
|
|
if res := ctx.Response(); res != nil {
|
|
if res.Header == nil {
|
|
res.Header = make(http.Header)
|
|
}
|
|
setHeaders(res.Header, h)
|
|
}
|
|
return nil
|
|
})
|
|
}
|
|
|
|
// cleanProxyHeaders removes all headers commonly used by (reverse) HTTP proxies.
|
|
func cleanProxyHeaders(h http.Header) {
|
|
if h == nil {
|
|
return
|
|
}
|
|
|
|
for _, k := range []string{
|
|
HeaderForwarded,
|
|
HeaderForwardedFor,
|
|
HeaderForwardedHost,
|
|
HeaderForwardedPort,
|
|
HeaderForwardedProto,
|
|
HeaderRealIP,
|
|
HeaderVia,
|
|
} {
|
|
h.Del(k)
|
|
}
|
|
}
|
|
|
|
// addHeaders appends every value in src to dst under the same key, keeping the values
// that dst already holds. A nil or empty src leaves dst unchanged.
func addHeaders(dst, src http.Header) {
	// Ranging over a nil map performs no iterations, so no explicit nil check is needed.
	for key, values := range src {
		for i := range values {
			dst.Add(key, values[i])
		}
	}
}
|
|
|
|
// setHeaders copies every key of src into dst, discarding whatever values dst previously
// held for that key. A key whose value list in src is empty is simply removed from dst.
// A nil or empty src leaves dst unchanged.
func setHeaders(dst, src http.Header) {
	// Ranging over a nil map performs no iterations, so no explicit nil check is needed.
	for key, values := range src {
		dst.Del(key)
		for i := range values {
			dst.Add(key, values[i])
		}
	}
}
|