Checkpoint
This commit is contained in:
309
proxy/handler.go
Normal file
309
proxy/handler.go
Normal file
@@ -0,0 +1,309 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
|
||||
"git.maze.io/maze/styx/ca"
|
||||
"git.maze.io/maze/styx/internal/cryptutil"
|
||||
"git.maze.io/maze/styx/internal/netutil"
|
||||
)
|
||||
|
||||
type Dialer interface {
|
||||
DialContext(context.Context, *http.Request) (net.Conn, error)
|
||||
}
|
||||
|
||||
type defaultDialer struct{}
|
||||
|
||||
func (defaultDialer) DialContext(ctx context.Context, req *http.Request) (net.Conn, error) {
|
||||
if host := netutil.Host(req.URL.Host); host == "" {
|
||||
return nil, errors.New("proxy: host missing in address")
|
||||
}
|
||||
|
||||
var d = net.Dialer{
|
||||
Resolver: net.DefaultResolver,
|
||||
FallbackDelay: -1,
|
||||
}
|
||||
|
||||
// Ensure we have a port.
|
||||
switch req.URL.Scheme {
|
||||
case "http", "ws":
|
||||
req.URL.Host = netutil.EnsurePort(req.URL.Host, "80")
|
||||
case "https", "wss":
|
||||
req.URL.Host = netutil.EnsurePort(req.URL.Host, "443")
|
||||
}
|
||||
|
||||
// Resolve the host.
|
||||
if ips, err := d.Resolver.LookupIP(ctx, "ip", netutil.Host(req.URL.Host)); err != nil {
|
||||
return nil, err
|
||||
} else {
|
||||
for _, ip := range ips {
|
||||
switch {
|
||||
case ip.IsUnspecified():
|
||||
return nil, fmt.Errorf("proxy: host %s resolves to unspecified address (blocked by DNS?)", netutil.Host(req.URL.Host))
|
||||
case ip.IsLoopback():
|
||||
return nil, fmt.Errorf("proxy: host %s resolves to loopback address (blocked by DNS?)", netutil.Host(req.URL.Host))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Make the connection.
|
||||
switch req.URL.Scheme {
|
||||
case "tcp", "http", "ws":
|
||||
// Plain TCP client connection.
|
||||
return d.DialContext(ctx, "tcp", req.URL.Host)
|
||||
case "https", "wss":
|
||||
// Secure TLS client connection.
|
||||
c, err := d.DialContext(ctx, "tcp", req.URL.Host)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s := tls.Client(c, new(tls.Config))
|
||||
return s, s.Handshake()
|
||||
default:
|
||||
return nil, fmt.Errorf("proxy: can't dial %s protocol", req.URL.Scheme)
|
||||
}
|
||||
}
|
||||
|
||||
// ConnFilter is called when a new connection has been accepted by the proxy.
|
||||
type ConnFilter interface {
|
||||
FilterConn(Context) (net.Conn, error)
|
||||
}
|
||||
|
||||
// ConnFilterFunc is a function that implements the [ConnFilter] interface.
|
||||
type ConnFilterFunc func(Context) (net.Conn, error)
|
||||
|
||||
func (f ConnFilterFunc) FilterConn(ctx Context) (net.Conn, error) {
|
||||
return f(ctx)
|
||||
}
|
||||
|
||||
// TLS starts a TLS handshake on the accepted connection.
|
||||
func TLS(certs []tls.Certificate) ConnFilter {
|
||||
return ConnFilterFunc(func(ctx Context) (net.Conn, error) {
|
||||
s := tls.Server(ctx, &tls.Config{
|
||||
Certificates: certs,
|
||||
NextProtos: []string{"http/1.1"},
|
||||
})
|
||||
if err := s.Handshake(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, nil
|
||||
})
|
||||
}
|
||||
|
||||
// TLSInterceptor can generate certificates on-the-fly for clients that use a compatible TLS version.
|
||||
func TLSInterceptor(ca ca.CertificateAuthority) ConnFilter {
|
||||
return ConnFilterFunc(func(ctx Context) (net.Conn, error) {
|
||||
s := tls.Server(ctx, &tls.Config{
|
||||
GetCertificate: func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
ips := []net.IP{net.ParseIP(netutil.Host(ctx.RemoteAddr().String()))}
|
||||
return ca.GetCertificate(hello.ServerName, []string{hello.ServerName}, ips)
|
||||
},
|
||||
NextProtos: []string{"http/1.1"},
|
||||
})
|
||||
if err := s.Handshake(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, nil
|
||||
})
|
||||
}
|
||||
|
||||
// Transparent can handle transparent HTTP(S) requests on the port.
|
||||
//
|
||||
// When a new [net.Conn] is made, this function will inspect the initial request packet for a
|
||||
// TLS handshake. If a TLS handshake is detected, the connection will make a feaux HTTP CONNECT
|
||||
// request using TLS, if no handshake is detected, it will make a feaux plain HTTP CONNECT request.
|
||||
func Transparent() ConnFilter {
|
||||
return ConnFilterFunc(func(nctx Context) (net.Conn, error) {
|
||||
ctx, ok := nctx.(*proxyContext)
|
||||
if !ok {
|
||||
return nctx, nil
|
||||
}
|
||||
|
||||
b := new(bytes.Buffer)
|
||||
hello, err := cryptutil.ReadClientHello(io.TeeReader(ctx, b))
|
||||
if err != nil {
|
||||
if _, ok := err.(tls.RecordHeaderError); !ok {
|
||||
ctx.LogEntry().WithError(err).WithField("error_type", fmt.Sprintf("%T", err)).Warn("TLS sniff error")
|
||||
return nil, err
|
||||
}
|
||||
// Not a TLS connection, moving on to regular HTTP request handling...
|
||||
ctx.LogEntry().Debug("HTTP connection on transparent port")
|
||||
ctx.isTransparent = true
|
||||
} else {
|
||||
ctx.LogEntry().WithField("target", hello.ServerName).Debug("TLS connection on transparent port")
|
||||
ctx.isTransparentTLS = true
|
||||
ctx.serverName = hello.ServerName
|
||||
}
|
||||
|
||||
return netutil.ReaderConn{
|
||||
Conn: ctx.Conn,
|
||||
Reader: io.MultiReader(bytes.NewReader(b.Bytes()), ctx.Conn),
|
||||
}, nil
|
||||
})
|
||||
}
|
||||
|
||||
// RequestFilter can filter HTTP requests coming to the proxy.
|
||||
type RequestFilter interface {
|
||||
// FilterRequest filters a HTTP request made to the proxy. The current request may be obtained
|
||||
// from [Context.Request]. If a previous RequestFilter provided a HTTP response, it is available
|
||||
// from [Context.Response].
|
||||
//
|
||||
// Modifications to the current request can be made to the Request returned by [Context.Request]
|
||||
// and do not require returning a new [http.Request].
|
||||
//
|
||||
// If the filter returns a non-nil [http.Response], then the [Request] will not be proxied to
|
||||
// any upstream server(s).
|
||||
FilterRequest(Context) (*http.Request, *http.Response)
|
||||
}
|
||||
|
||||
// RequestFilterFunc is a function that implements the [RequestFilter] interface.
|
||||
type RequestFilterFunc func(Context) (*http.Request, *http.Response)
|
||||
|
||||
func (f RequestFilterFunc) FilterRequest(ctx Context) (*http.Request, *http.Response) {
|
||||
return f(ctx)
|
||||
}
|
||||
|
||||
// ResponseFilter can filter HTTP responses coming from the proxy.
|
||||
type ResponseFilter interface {
|
||||
// FilterResponse filters a HTTP response coming from the proxy. The current response may be
|
||||
// obtained from [Context.Response].
|
||||
//
|
||||
// Modifications to the current response can be made to the [Response] returned by [Context.Response].
|
||||
FilterResponse(Context) *http.Response
|
||||
}
|
||||
|
||||
// ResponseFilterFunc is a function that implements the [ResponseFilter] interface.
|
||||
type ResponseFilterFunc func(Context) *http.Response
|
||||
|
||||
func (f ResponseFilterFunc) FilterResponse(ctx Context) *http.Response {
|
||||
return f(ctx)
|
||||
}
|
||||
|
||||
// CleanRequestProxyHeaders removes all headers added by downstream proxies from the [http.Request].
|
||||
func CleanRequestProxyHeaders() RequestFilter {
|
||||
return RequestFilterFunc(func(ctx Context) (*http.Request, *http.Response) {
|
||||
if req := ctx.Request(); req != nil {
|
||||
cleanProxyHeaders(req.Header)
|
||||
}
|
||||
return nil, nil
|
||||
})
|
||||
}
|
||||
|
||||
// CleanRequestProxyHeaders removes all headers for upstream proxies from the [http.Response].
|
||||
func CleanResponseProxyHeaders() ResponseFilter {
|
||||
return ResponseFilterFunc(func(ctx Context) *http.Response {
|
||||
if res := ctx.Response(); res != nil {
|
||||
cleanProxyHeaders(res.Header)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// AddRequestHeaders adds headers to the [http.Request]. Any existing headers with the same
|
||||
// key will remain intact.
|
||||
func AddRequestHeaders(h http.Header) RequestFilter {
|
||||
return RequestFilterFunc(func(ctx Context) (*http.Request, *http.Response) {
|
||||
if req := ctx.Request(); req != nil {
|
||||
if req.Header == nil {
|
||||
req.Header = make(http.Header)
|
||||
}
|
||||
addHeaders(req.Header, h)
|
||||
}
|
||||
return nil, nil
|
||||
})
|
||||
}
|
||||
|
||||
// SetRequestHeaders sets headers to the [http.Request]. Any existing headers with the same
|
||||
// key will be removed.
|
||||
func SetRequestHeaders(h http.Header) RequestFilter {
|
||||
return RequestFilterFunc(func(ctx Context) (*http.Request, *http.Response) {
|
||||
if req := ctx.Request(); req != nil {
|
||||
if req.Header == nil {
|
||||
req.Header = make(http.Header)
|
||||
}
|
||||
setHeaders(req.Header, h)
|
||||
}
|
||||
return nil, nil
|
||||
})
|
||||
}
|
||||
|
||||
// AddResponseHeaders adds headers to the [http.Response]. Any existing headers with the same
|
||||
// key will remain intact.
|
||||
func AddResponseHeaders(h http.Header) ResponseFilter {
|
||||
return ResponseFilterFunc(func(ctx Context) *http.Response {
|
||||
if res := ctx.Response(); res != nil {
|
||||
if res.Header == nil {
|
||||
res.Header = make(http.Header)
|
||||
}
|
||||
addHeaders(res.Header, h)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// SetResponseHeaders sets headers to the [http.Response]. Any existing headers with the same
|
||||
// key will be removed.
|
||||
func SetResponseHeaders(h http.Header) ResponseFilter {
|
||||
return ResponseFilterFunc(func(ctx Context) *http.Response {
|
||||
if res := ctx.Response(); res != nil {
|
||||
if res.Header == nil {
|
||||
res.Header = make(http.Header)
|
||||
}
|
||||
setHeaders(res.Header, h)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// cleanProxyHeaders removes all headers commonly used by (reverse) HTTP proxies.
|
||||
func cleanProxyHeaders(h http.Header) {
|
||||
if h == nil {
|
||||
return
|
||||
}
|
||||
|
||||
for _, k := range []string{
|
||||
HeaderForwarded,
|
||||
HeaderForwardedFor,
|
||||
HeaderForwardedHost,
|
||||
HeaderForwardedPort,
|
||||
HeaderForwardedProto,
|
||||
HeaderRealIP,
|
||||
HeaderVia,
|
||||
} {
|
||||
h.Del(k)
|
||||
}
|
||||
}
|
||||
|
||||
// addHeaders adds to the current existing headers.
|
||||
func addHeaders(dst, src http.Header) {
|
||||
if src == nil {
|
||||
return
|
||||
}
|
||||
|
||||
for k, vv := range src {
|
||||
for _, v := range vv {
|
||||
dst.Add(k, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// setHeaders replaces all previous values.
|
||||
func setHeaders(dst, src http.Header) {
|
||||
if src == nil {
|
||||
return
|
||||
}
|
||||
|
||||
for k, vv := range src {
|
||||
dst.Del(k)
|
||||
for _, v := range vv {
|
||||
dst.Add(k, v)
|
||||
}
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user