617 lines
15 KiB
Go
617 lines
15 KiB
Go
package proxy
|
|
|
|
import (
|
|
"bufio"
|
|
"bytes"
|
|
"context"
|
|
"crypto/tls"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"net/http"
|
|
"strings"
|
|
"sync"
|
|
"syscall"
|
|
"time"
|
|
|
|
"git.maze.io/maze/styx/internal/log"
|
|
"git.maze.io/maze/styx/internal/netutil"
|
|
"git.maze.io/maze/styx/proxy/mitm"
|
|
"git.maze.io/maze/styx/proxy/policy"
|
|
"git.maze.io/maze/styx/proxy/resolver"
|
|
"git.maze.io/maze/styx/proxy/stats"
|
|
)
|
|
|
|
// Default settings for the proxy.
const (
	// DefaultListenAddr is the default address the proxy listens on.
	DefaultListenAddr = ":3128"

	// DefaultBindAddr is the default bind address; empty means none.
	DefaultBindAddr = ""

	// DefaultDialTimeout bounds upstream dials when Config.DialTimeout is not
	// positive.
	DefaultDialTimeout = 30 * time.Second

	// DefaultKeepAlivePeriod is the TCP keep-alive period set on accepted
	// client connections.
	DefaultKeepAlivePeriod = time.Minute
)
|
|
|
|
// Canonical names of the HTTP headers the proxy inspects or rewrites.
const (
	HeaderAcceptEncoding = "Accept-Encoding" // negotiated content codings
	HeaderConnection     = "Connection"      // per-hop connection options
	HeaderContentLength  = "Content-Length"  // body size in bytes
	HeaderContentType    = "Content-Type"    // media type of the body
	HeaderUpgrade        = "Upgrade"         // protocol switch, e.g. WebSocket
)
|
|
|
|
var (
	// ErrClosed is returned once the proxy has been shut down: by a second
	// call to Close, and by request reads interrupted by shutdown.
	ErrClosed = errors.New("proxy: shutdown")

	// ErrClientCert is returned to an upstream TLS server that asks for a
	// client certificate; the proxy never presents one.
	ErrClientCert = errors.New("tls: client certificate requested")
)
|
|
|
|
type Proxy struct {
|
|
addr *net.TCPAddr
|
|
bind *net.TCPAddr
|
|
resolver resolver.Resolver
|
|
transport *http.Transport
|
|
dial func(network, address string) (net.Conn, error)
|
|
config *Config
|
|
authority mitm.Authority
|
|
policy *policy.Policy
|
|
admin *Admin
|
|
stats *stats.Stats
|
|
closed chan struct{}
|
|
onConnect ConnectHandler
|
|
onRequest RequestHandler
|
|
onResponse ResponseHandler
|
|
onError ErrorHandler
|
|
}
|
|
|
|
func New(config *Config, ca mitm.Authority) (*Proxy, error) {
|
|
if config == nil {
|
|
return nil, errors.New("proxy: config can't be nil")
|
|
}
|
|
|
|
p := &Proxy{
|
|
transport: newTransport(),
|
|
config: config,
|
|
resolver: resolver.Default,
|
|
authority: ca,
|
|
policy: config.Policy,
|
|
closed: make(chan struct{}),
|
|
onConnect: config.ConnectHandler,
|
|
onRequest: config.RequestHandler,
|
|
onResponse: config.ResponseHandler,
|
|
onError: config.ErrorHandler,
|
|
}
|
|
|
|
var err error
|
|
if config.Listen == "" {
|
|
p.addr, err = net.ResolveTCPAddr("tcp", DefaultBindAddr)
|
|
} else {
|
|
p.addr, err = net.ResolveTCPAddr("tcp", config.Listen)
|
|
}
|
|
if err != nil {
|
|
return nil, fmt.Errorf("proxy: invalid listen addres: %w", err)
|
|
}
|
|
if config.Bind != "" {
|
|
if p.bind, err = net.ResolveTCPAddr("tcp", config.Bind+":0"); err != nil {
|
|
return nil, fmt.Errorf("proxy: invalid bind address: %w", err)
|
|
}
|
|
} else if config.Interface != "" {
|
|
if err = resolveInterfaceAddr(config.Interface); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
if p.bind != nil {
|
|
/* FIXME
|
|
var c *net.TCPConn
|
|
if c, err = net.DialTCP("tcp", p.bind, p.bind); err != nil && errors.Is(err, syscall.EADDRNOTAVAIL) {
|
|
return nil, fmt.Errorf("proxy: invalid bind address: %w", syscall.EADDRNOTAVAIL)
|
|
} else if c != nil {
|
|
_ = c.Close()
|
|
}
|
|
*/
|
|
}
|
|
if config.Resolver != nil {
|
|
p.resolver = config.Resolver
|
|
}
|
|
|
|
dialTimeout := DefaultDialTimeout
|
|
if config.DialTimeout > 0 {
|
|
dialTimeout = config.DialTimeout
|
|
}
|
|
p.dial = (&net.Dialer{
|
|
Timeout: dialTimeout,
|
|
KeepAlive: dialTimeout,
|
|
LocalAddr: p.bind,
|
|
}).Dial
|
|
|
|
p.admin = NewAdmin(p)
|
|
|
|
if p.stats, err = stats.New(); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return p, nil
|
|
}
|
|
|
|
// newTransport returns the http.Transport used to forward plain HTTP requests.
//
// HTTP/2 is disabled: a non-nil, empty TLSNextProto map stops the transport
// from negotiating "h2". Upstream proxies configured in the environment are
// honoured through http.ProxyFromEnvironment. Idle keep-alive connections are
// closed after IdleConnTimeout. Otherwise (zero means "no limit"), a proxy
// that talks to many distinct hosts accumulates idle connections
// indefinitely.
func newTransport() *http.Transport {
	return &http.Transport{
		TLSNextProto:          make(map[string]func(authority string, c *tls.Conn) http.RoundTripper),
		Proxy:                 http.ProxyFromEnvironment,
		TLSHandshakeTimeout:   15 * time.Second,
		ExpectContinueTimeout: 5 * time.Second,
		IdleConnTimeout:       90 * time.Second,
	}
}
|
|
|
|
func (p *Proxy) Close() error {
|
|
select {
|
|
case <-p.closed:
|
|
return ErrClosed
|
|
default:
|
|
close(p.closed)
|
|
return nil
|
|
}
|
|
}
|
|
|
|
func (p *Proxy) Start() error {
|
|
l, err := net.ListenTCP("tcp", p.addr)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
go p.Serve(l)
|
|
return nil
|
|
}
|
|
|
|
func (p *Proxy) Serve(listener net.Listener) error {
|
|
defer func() { _ = listener.Close() }()
|
|
|
|
log.Info().Str("addr", listener.Addr().String()).Msg("proxy server listening")
|
|
for {
|
|
select {
|
|
case <-p.closed:
|
|
return nil
|
|
default:
|
|
}
|
|
|
|
c, err := listener.Accept()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
rw := bufio.NewReadWriter(bufio.NewReader(c), bufio.NewWriter(c))
|
|
ctx := newContext(c, rw, nil)
|
|
|
|
if c, ok := c.(*net.TCPConn); ok {
|
|
_ = c.SetKeepAlive(true)
|
|
_ = c.SetKeepAlivePeriod(DefaultKeepAlivePeriod)
|
|
}
|
|
|
|
go p.handle(ctx)
|
|
}
|
|
}
|
|
|
|
func (p *Proxy) handle(ctx *Context) {
|
|
logger := ctx.log()
|
|
defer log.OnCloseError(logger.Debug(), ctx.conn)
|
|
logger.Info().Str("client", ctx.RemoteAddr().String()).Msg("new client connection")
|
|
|
|
last := int64(0)
|
|
for {
|
|
select {
|
|
case <-p.closed:
|
|
return
|
|
|
|
default:
|
|
ses, err := p.handleRequest(ctx)
|
|
if ses != nil {
|
|
log := ses.log()
|
|
log.Info().
|
|
Str("method", ses.request.Method).
|
|
Str("url", ses.request.URL.String()).
|
|
Str("status", ses.response.Status).
|
|
Int64("size", ctx.conn.bytes-last).
|
|
Msg("handled request")
|
|
|
|
p.stats.AddLog(&stats.Log{
|
|
ClientIP: netutil.Host(ses.request.RemoteAddr),
|
|
Request: stats.FromRequest(ses.request),
|
|
Response: stats.FromResponse(ses.response).SetSize(ctx.conn.bytes - last),
|
|
})
|
|
|
|
last = ctx.conn.bytes
|
|
}
|
|
if err != nil && !isClosing(err) || (ses != nil && ses.response != nil && ses.response.Close) {
|
|
event := logger.Debug()
|
|
if ctx.conn.bytes > 0 {
|
|
event = event.Int64("size", ctx.conn.bytes)
|
|
}
|
|
event.Msg("closing client connection")
|
|
return
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (p *Proxy) handleRequest(ctx *Context) (ses *Session, err error) {
|
|
logger := ctx.log()
|
|
|
|
var request *http.Request
|
|
if request, err = p.readRequest(ctx); err != nil {
|
|
return
|
|
}
|
|
|
|
ses = newSession(ctx, request)
|
|
p.cleanRequest(ses, request)
|
|
|
|
logger.Debug().Str("method", request.Method).Str("url", request.URL.String()).Msg("handle request")
|
|
|
|
if p.onRequest != nil {
|
|
newRequest, newResponse := p.onRequest.HandleRequest(ses)
|
|
if newRequest != nil {
|
|
logger.Debug().Str("method", newRequest.Method).Str("url", newRequest.URL.String()).Msg("request override")
|
|
ses.request = newRequest
|
|
}
|
|
if newResponse != nil {
|
|
logger.Debug().Str("status", newResponse.Status).Msg("response override")
|
|
ses.response = newResponse
|
|
}
|
|
}
|
|
|
|
if ses.response == nil {
|
|
// WebSocket request
|
|
if ses.request.Header.Get(HeaderUpgrade) == "websocket" {
|
|
return ses, p.handleTunnel(ses)
|
|
}
|
|
|
|
cleanHopByHopHeaders(ses.request.Header)
|
|
|
|
// Proxy CONNECT request
|
|
if ses.request.Method == http.MethodConnect {
|
|
return p.handleConnect(ses)
|
|
}
|
|
|
|
if netutil.Port(ses.request.URL.Host) == p.addr.Port {
|
|
// Plain API request
|
|
ses.request.URL.Host = ses.request.Host
|
|
return ses, p.admin.handleRequest(ses)
|
|
|
|
} else if ses.response, err = p.transport.RoundTrip(ses.request); err != nil {
|
|
// Plain HTTP request
|
|
if p.config.ErrorHandler != nil {
|
|
p.config.ErrorHandler.HandleError(ses, err)
|
|
}
|
|
ses.response = ErrorResponse(ses.request, err)
|
|
}
|
|
|
|
logger.Debug().Str("status", ses.response.Status).Msg("received response")
|
|
cleanHopByHopHeaders(ses.response.Header)
|
|
}
|
|
|
|
ses.response.Close = true
|
|
defer log.OnCloseError(logger.Debug(), ses.response.Body)
|
|
return ses, p.writeResponse(ses)
|
|
}
|
|
|
|
func (p *Proxy) handleConnect(ses *Session) (next *Session, err error) {
|
|
next = ses
|
|
|
|
logger := ses.log()
|
|
logger.Debug().Msgf("connecting to %s", ses.request.URL.Host)
|
|
|
|
var c net.Conn
|
|
if c, err = p.connect(ses, "tcp", ses.request.URL.Host); err != nil {
|
|
logger.Error().Err(err).Msg("connect failed")
|
|
if p.onError != nil {
|
|
p.onError.HandleError(ses, err)
|
|
}
|
|
|
|
ses.response = ErrorResponse(ses.request, err)
|
|
defer log.OnCloseError(logger.Debug(), ses.response.Body)
|
|
_ = p.writeResponse(ses)
|
|
|
|
return
|
|
}
|
|
|
|
defer func() {
|
|
if err := c.Close(); err != nil {
|
|
if p.onError != nil {
|
|
p.onError.HandleError(ses, err)
|
|
}
|
|
}
|
|
}()
|
|
|
|
if p.canIntercept(ses.request) {
|
|
logger.Debug().Msg("intercepting connection")
|
|
ses.response = NewResponse(http.StatusOK, nil, ses.request)
|
|
err = p.writeResponse(ses)
|
|
log.OnCloseError(logger.Debug(), ses.response.Body)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
// Peek first byte
|
|
b := make([]byte, 1)
|
|
if _, err = io.ReadFull(ses.ctx.rw, b); err != nil {
|
|
logger.Error().Err(err).Msg("error peeking CONNECT byte")
|
|
return
|
|
}
|
|
|
|
// Drain buffered bytes
|
|
b = append(b, make([]byte, ses.ctx.rw.Reader.Buffered())...)
|
|
ses.ctx.rw.Reader.Read(b[1:])
|
|
|
|
r := &connReader{
|
|
Conn: ses.ctx.conn,
|
|
Reader: io.MultiReader(bytes.NewBuffer(b), ses.ctx.conn),
|
|
}
|
|
if b[0] == 22 { // TLS handshake: https://tools.ietf.org/html/rfc5246#section-6.2.1
|
|
secure := tls.Server(r, p.authority.TLSConfig(ses.request.URL.Host))
|
|
if err = secure.Handshake(); err != nil {
|
|
logger.Error().Err(err).Msg("error intercepting TLS connection: client handshake failed")
|
|
return
|
|
}
|
|
|
|
rw := bufio.NewReadWriter(bufio.NewReader(secure), bufio.NewWriter(secure))
|
|
ctx := newContext(secure, rw, ses)
|
|
return p.handleRequest(ctx)
|
|
}
|
|
|
|
rw := bufio.NewReadWriter(bufio.NewReader(r), bufio.NewWriter(r))
|
|
ctx := newContext(r, rw, ses)
|
|
return p.handleRequest(ctx)
|
|
}
|
|
|
|
ses.response = NewResponse(http.StatusOK, nil, ses.request)
|
|
defer log.OnCloseError(logger.Debug(), ses.response.Body)
|
|
ses.response.ContentLength = -1
|
|
if err = p.writeResponse(ses); err != nil {
|
|
return
|
|
}
|
|
|
|
logger.Debug().Msg("established CONNECT tunnel, proxying traffic")
|
|
var wait sync.WaitGroup
|
|
wait.Go(func() { copyStream(ses, c, ses.ctx.conn) })
|
|
wait.Go(func() { copyStream(ses, ses.ctx.conn, c) })
|
|
wait.Wait()
|
|
logger.Debug().Msg("closed CONNECT tunnel")
|
|
return
|
|
}
|
|
|
|
func (p *Proxy) handleTunnel(ses *Session) (err error) {
|
|
logger := ses.log()
|
|
logger.Debug().Msgf("connecting to %s", ses.request.URL.Host)
|
|
|
|
var c net.Conn
|
|
if c, err = p.connect(ses, "tcp", ses.request.URL.Host); err != nil {
|
|
logger.Error().Err(err).Msg("connect failed")
|
|
if p.onError != nil {
|
|
p.onError.HandleError(ses, err)
|
|
}
|
|
|
|
ses.response = ErrorResponse(ses.request, err)
|
|
defer log.OnCloseError(logger.Debug(), ses.response.Body)
|
|
_ = p.writeResponse(ses)
|
|
|
|
return
|
|
}
|
|
|
|
defer log.OnCloseError(logger.Debug(), c)
|
|
|
|
if ses.ctx.IsTLS() {
|
|
// Open a TLS client connection
|
|
secure := tls.Client(c, &tls.Config{
|
|
ServerName: ses.request.URL.Host,
|
|
GetClientCertificate: func(cri *tls.CertificateRequestInfo) (*tls.Certificate, error) {
|
|
return nil, ErrClientCert
|
|
},
|
|
})
|
|
if err = secure.Handshake(); err != nil {
|
|
logger.Error().Err(err).Msg("TLS handshake failed")
|
|
return
|
|
}
|
|
c = secure
|
|
}
|
|
|
|
if err = ses.request.Write(c); err != nil {
|
|
logger.Error().Err(err).Msg("failed to write request")
|
|
return
|
|
}
|
|
|
|
logger.Debug().Msg("established tunnel, proxying traffic")
|
|
var wait sync.WaitGroup
|
|
wait.Go(func() { copyStream(ses, c, ses.ctx.conn) })
|
|
wait.Go(func() { copyStream(ses, ses.ctx.conn, c) })
|
|
wait.Wait()
|
|
logger.Debug().Msg("closed tunnel")
|
|
return
|
|
}
|
|
|
|
func (p *Proxy) canIntercept(request *http.Request) bool {
|
|
if permit := p.policy.PermitIntercept(request); permit != nil {
|
|
return *permit
|
|
}
|
|
return true
|
|
}
|
|
|
|
/*
|
|
func (p *Proxy) handleAPIRequest(ses *Session) error {
|
|
if ses.request.URL.Path == "/ca.crt" && p.authority != nil {
|
|
b := pem.EncodeToMemory(&pem.Block{
|
|
Type: "CERTIFICATE",
|
|
Bytes: p.authority.Certificate().Raw,
|
|
})
|
|
|
|
ses.response = NewResponse(http.StatusOK, bytes.NewReader(b), ses.request)
|
|
defer log.OnCloseError(logger.Debug(), ses.response.Body)
|
|
|
|
ses.response.Close = true
|
|
ses.response.Header.Set("Content-Type", "application/x-x509-ca-cert")
|
|
ses.response.ContentLength = int64(len(b))
|
|
return p.writeResponse(ses)
|
|
}
|
|
|
|
ses.response = ErrorResponse(ses.request, errors.New("invalid API endpoint"))
|
|
defer log.OnCloseError(logger.Debug(), ses.response.Body)
|
|
ses.response.Close = true
|
|
return p.writeResponse(ses)
|
|
}
|
|
*/
|
|
|
|
func (p *Proxy) readRequest(ctx *Context) (request *http.Request, err error) {
|
|
var (
|
|
done = make(chan *http.Request, 1)
|
|
errs = make(chan error, 1)
|
|
)
|
|
|
|
go func() {
|
|
r, err := http.ReadRequest(ctx.rw.Reader)
|
|
if err != nil {
|
|
errs <- err
|
|
} else {
|
|
done <- r
|
|
}
|
|
}()
|
|
|
|
select {
|
|
case <-p.closed:
|
|
return nil, ErrClosed
|
|
case request = <-done:
|
|
return
|
|
case err = <-errs:
|
|
return
|
|
}
|
|
}
|
|
|
|
func (p *Proxy) cleanRequest(ses *Session, request *http.Request) {
|
|
if request.URL.Host == "" {
|
|
request.URL.Host = request.Host
|
|
}
|
|
|
|
// Ensure proper URL scheme
|
|
if !strings.HasPrefix(request.URL.Scheme, "http") {
|
|
request.URL.Scheme = "http"
|
|
}
|
|
if ses.ctx.IsTLS() {
|
|
state := ses.ctx.conn.Conn.(*tls.Conn).ConnectionState()
|
|
request.TLS = &state
|
|
request.URL.Scheme = "https"
|
|
}
|
|
|
|
// Ensure proper RemoteAddr
|
|
request.RemoteAddr = ses.ctx.RemoteAddr().String()
|
|
|
|
// Ensure proper encoding
|
|
if request.Header.Get(HeaderAcceptEncoding) != "" {
|
|
// We only support gzip
|
|
request.Header.Set(HeaderAcceptEncoding, "gzip")
|
|
}
|
|
}
|
|
|
|
func (p *Proxy) writeResponse(ses *Session) (err error) {
|
|
log := ses.log()
|
|
|
|
if p.onResponse != nil {
|
|
response := p.onResponse.HandleResponse(ses)
|
|
if response != nil {
|
|
log.Debug().Str("status", response.Status).Msg("response override")
|
|
ses.response = response
|
|
}
|
|
}
|
|
|
|
if err = ses.response.Write(ses.ctx); err != nil {
|
|
log.Error().Err(err).Msg("error writing response back to client")
|
|
} else if err = ses.ctx.Flush(); err != nil {
|
|
log.Error().Err(err).Msg("error flushing response back to client")
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
func (p *Proxy) connect(ses *Session, network, address string) (c net.Conn, err error) {
|
|
log := ses.log()
|
|
log.Debug().Msgf("connect to %s://%s", network, address)
|
|
|
|
if p.onConnect != nil {
|
|
if c = p.onConnect.HandleConnect(ses, network, address); c != nil {
|
|
log.Debug().Msg("connect override")
|
|
return
|
|
}
|
|
}
|
|
|
|
var host, port string
|
|
if host, port, err = net.SplitHostPort(address); err != nil {
|
|
return
|
|
}
|
|
|
|
var hosts []string
|
|
if hosts, err = p.resolver.Lookup(context.Background(), host); err != nil {
|
|
log.Warn().Err(err).Msg("connect failed: DNS lookup error")
|
|
return
|
|
}
|
|
|
|
log.Debug().Str("address", hosts[0]).Msg("connect resolved address")
|
|
return p.dial(network, net.JoinHostPort(hosts[0], port))
|
|
}
|
|
|
|
var hopByHopHeaders = []string{
|
|
HeaderConnection,
|
|
"Keep-Alive",
|
|
"Proxy-Authenticate",
|
|
"Proxy-Authorization",
|
|
"Proxy-Connection", // Non-standard, but required for HTTP/2.
|
|
"Te",
|
|
"Trailer",
|
|
"Transfer-Encoding",
|
|
HeaderUpgrade,
|
|
}
|
|
|
|
func cleanHopByHopHeaders(header http.Header) {
|
|
// Additional hop-by-hop headers may be specified in `Connection` headers.
|
|
// http://tools.ietf.org/html/draft-ietf-httpbis-p1-messaging-14#section-9.1
|
|
for _, values := range header[HeaderConnection] {
|
|
for _, key := range strings.Split(values, ",") {
|
|
header.Del(key)
|
|
}
|
|
}
|
|
for _, key := range hopByHopHeaders {
|
|
header.Del(key)
|
|
}
|
|
}
|
|
|
|
// copyStream copies data from reader to writer
|
|
func copyStream(ses *Session, w io.Writer, r io.Reader) {
|
|
log := ses.log()
|
|
if _, err := io.Copy(w, r); err != nil && !isClosing(err) {
|
|
log.Error().Err(err).Msg("failed CONNECT tunnel")
|
|
} else {
|
|
log.Debug().Msg("finished copying CONNECT tunnel")
|
|
}
|
|
}
|
|
|
|
func isClosing(err error) bool {
|
|
if errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) || errors.Is(err, syscall.ECONNRESET) || err == ErrClosed {
|
|
return true
|
|
}
|
|
if err, ok := err.(net.Error); ok && err.Timeout() {
|
|
return true
|
|
}
|
|
// log.Debug().Msgf("not a closing error %T: %#+v", err, err)
|
|
return false
|
|
}
|
|
|
|
func resolveInterfaceAddr(name string) (err error) {
|
|
var iface *net.Interface
|
|
if iface, err = net.InterfaceByName(name); err != nil {
|
|
return
|
|
}
|
|
|
|
var addrs []net.Addr
|
|
if addrs, err = iface.Addrs(); err != nil {
|
|
return
|
|
}
|
|
|
|
for _, addr := range addrs {
|
|
if addr, ok := addr.(*net.IPNet); ok && !addr.IP.IsUnspecified() {
|
|
log.Warn().Msgf("addr %T: %s", addr, addr)
|
|
}
|
|
}
|
|
return errors.New("nope; TODO")
|
|
}
|