Initial import
This commit is contained in:
616
proxy/proxy.go
Normal file
616
proxy/proxy.go
Normal file
@@ -0,0 +1,616 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"git.maze.io/maze/styx/internal/log"
|
||||
"git.maze.io/maze/styx/internal/netutil"
|
||||
"git.maze.io/maze/styx/proxy/mitm"
|
||||
"git.maze.io/maze/styx/proxy/policy"
|
||||
"git.maze.io/maze/styx/proxy/resolver"
|
||||
"git.maze.io/maze/styx/proxy/stats"
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultListenAddr = ":3128"
|
||||
DefaultBindAddr = ""
|
||||
DefaultDialTimeout = 30 * time.Second
|
||||
DefaultKeepAlivePeriod = 1 * time.Minute
|
||||
)
|
||||
|
||||
const (
|
||||
HeaderAcceptEncoding = "Accept-Encoding"
|
||||
HeaderConnection = "Connection"
|
||||
HeaderContentLength = "Content-Length"
|
||||
HeaderContentType = "Content-Type"
|
||||
HeaderUpgrade = "Upgrade"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrClosed = errors.New("proxy: shutdown")
|
||||
ErrClientCert = errors.New("tls: client certificate requested")
|
||||
)
|
||||
|
||||
type Proxy struct {
|
||||
addr *net.TCPAddr
|
||||
bind *net.TCPAddr
|
||||
resolver resolver.Resolver
|
||||
transport *http.Transport
|
||||
dial func(network, address string) (net.Conn, error)
|
||||
config *Config
|
||||
authority mitm.Authority
|
||||
policy *policy.Policy
|
||||
admin *Admin
|
||||
stats *stats.Stats
|
||||
closed chan struct{}
|
||||
onConnect ConnectHandler
|
||||
onRequest RequestHandler
|
||||
onResponse ResponseHandler
|
||||
onError ErrorHandler
|
||||
}
|
||||
|
||||
func New(config *Config, ca mitm.Authority) (*Proxy, error) {
|
||||
if config == nil {
|
||||
return nil, errors.New("proxy: config can't be nil")
|
||||
}
|
||||
|
||||
p := &Proxy{
|
||||
transport: newTransport(),
|
||||
config: config,
|
||||
resolver: resolver.Default,
|
||||
authority: ca,
|
||||
policy: config.Policy,
|
||||
closed: make(chan struct{}),
|
||||
onConnect: config.ConnectHandler,
|
||||
onRequest: config.RequestHandler,
|
||||
onResponse: config.ResponseHandler,
|
||||
onError: config.ErrorHandler,
|
||||
}
|
||||
|
||||
var err error
|
||||
if config.Listen == "" {
|
||||
p.addr, err = net.ResolveTCPAddr("tcp", DefaultBindAddr)
|
||||
} else {
|
||||
p.addr, err = net.ResolveTCPAddr("tcp", config.Listen)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("proxy: invalid listen addres: %w", err)
|
||||
}
|
||||
if config.Bind != "" {
|
||||
if p.bind, err = net.ResolveTCPAddr("tcp", config.Bind+":0"); err != nil {
|
||||
return nil, fmt.Errorf("proxy: invalid bind address: %w", err)
|
||||
}
|
||||
} else if config.Interface != "" {
|
||||
if err = resolveInterfaceAddr(config.Interface); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if p.bind != nil {
|
||||
/* FIXME
|
||||
var c *net.TCPConn
|
||||
if c, err = net.DialTCP("tcp", p.bind, p.bind); err != nil && errors.Is(err, syscall.EADDRNOTAVAIL) {
|
||||
return nil, fmt.Errorf("proxy: invalid bind address: %w", syscall.EADDRNOTAVAIL)
|
||||
} else if c != nil {
|
||||
_ = c.Close()
|
||||
}
|
||||
*/
|
||||
}
|
||||
if config.Resolver != nil {
|
||||
p.resolver = config.Resolver
|
||||
}
|
||||
|
||||
dialTimeout := DefaultDialTimeout
|
||||
if config.DialTimeout > 0 {
|
||||
dialTimeout = config.DialTimeout
|
||||
}
|
||||
p.dial = (&net.Dialer{
|
||||
Timeout: dialTimeout,
|
||||
KeepAlive: dialTimeout,
|
||||
LocalAddr: p.bind,
|
||||
}).Dial
|
||||
|
||||
p.admin = NewAdmin(p)
|
||||
|
||||
if p.stats, err = stats.New(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func newTransport() *http.Transport {
|
||||
return &http.Transport{
|
||||
TLSNextProto: make(map[string]func(authority string, c *tls.Conn) http.RoundTripper),
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
TLSHandshakeTimeout: 15 * time.Second,
|
||||
ExpectContinueTimeout: 5 * time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Proxy) Close() error {
|
||||
select {
|
||||
case <-p.closed:
|
||||
return ErrClosed
|
||||
default:
|
||||
close(p.closed)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Proxy) Start() error {
|
||||
l, err := net.ListenTCP("tcp", p.addr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
go p.Serve(l)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Proxy) Serve(listener net.Listener) error {
|
||||
defer func() { _ = listener.Close() }()
|
||||
|
||||
log.Info().Str("addr", listener.Addr().String()).Msg("proxy server listening")
|
||||
for {
|
||||
select {
|
||||
case <-p.closed:
|
||||
return nil
|
||||
default:
|
||||
}
|
||||
|
||||
c, err := listener.Accept()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
rw := bufio.NewReadWriter(bufio.NewReader(c), bufio.NewWriter(c))
|
||||
ctx := newContext(c, rw, nil)
|
||||
|
||||
if c, ok := c.(*net.TCPConn); ok {
|
||||
_ = c.SetKeepAlive(true)
|
||||
_ = c.SetKeepAlivePeriod(DefaultKeepAlivePeriod)
|
||||
}
|
||||
|
||||
go p.handle(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Proxy) handle(ctx *Context) {
|
||||
logger := ctx.log()
|
||||
defer log.OnCloseError(logger.Debug(), ctx.conn)
|
||||
logger.Info().Str("client", ctx.RemoteAddr().String()).Msg("new client connection")
|
||||
|
||||
last := int64(0)
|
||||
for {
|
||||
select {
|
||||
case <-p.closed:
|
||||
return
|
||||
|
||||
default:
|
||||
ses, err := p.handleRequest(ctx)
|
||||
if ses != nil {
|
||||
log := ses.log()
|
||||
log.Info().
|
||||
Str("method", ses.request.Method).
|
||||
Str("url", ses.request.URL.String()).
|
||||
Str("status", ses.response.Status).
|
||||
Int64("size", ctx.conn.bytes-last).
|
||||
Msg("handled request")
|
||||
|
||||
p.stats.AddLog(&stats.Log{
|
||||
ClientIP: netutil.Host(ses.request.RemoteAddr),
|
||||
Request: stats.FromRequest(ses.request),
|
||||
Response: stats.FromResponse(ses.response).SetSize(ctx.conn.bytes - last),
|
||||
})
|
||||
|
||||
last = ctx.conn.bytes
|
||||
}
|
||||
if err != nil && !isClosing(err) || (ses != nil && ses.response != nil && ses.response.Close) {
|
||||
event := logger.Debug()
|
||||
if ctx.conn.bytes > 0 {
|
||||
event = event.Int64("size", ctx.conn.bytes)
|
||||
}
|
||||
event.Msg("closing client connection")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Proxy) handleRequest(ctx *Context) (ses *Session, err error) {
|
||||
logger := ctx.log()
|
||||
|
||||
var request *http.Request
|
||||
if request, err = p.readRequest(ctx); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
ses = newSession(ctx, request)
|
||||
p.cleanRequest(ses, request)
|
||||
|
||||
logger.Debug().Str("method", request.Method).Str("url", request.URL.String()).Msg("handle request")
|
||||
|
||||
if p.onRequest != nil {
|
||||
newRequest, newResponse := p.onRequest.HandleRequest(ses)
|
||||
if newRequest != nil {
|
||||
logger.Debug().Str("method", newRequest.Method).Str("url", newRequest.URL.String()).Msg("request override")
|
||||
ses.request = newRequest
|
||||
}
|
||||
if newResponse != nil {
|
||||
logger.Debug().Str("status", newResponse.Status).Msg("response override")
|
||||
ses.response = newResponse
|
||||
}
|
||||
}
|
||||
|
||||
if ses.response == nil {
|
||||
// WebSocket request
|
||||
if ses.request.Header.Get(HeaderUpgrade) == "websocket" {
|
||||
return ses, p.handleTunnel(ses)
|
||||
}
|
||||
|
||||
cleanHopByHopHeaders(ses.request.Header)
|
||||
|
||||
// Proxy CONNECT request
|
||||
if ses.request.Method == http.MethodConnect {
|
||||
return p.handleConnect(ses)
|
||||
}
|
||||
|
||||
if netutil.Port(ses.request.URL.Host) == p.addr.Port {
|
||||
// Plain API request
|
||||
ses.request.URL.Host = ses.request.Host
|
||||
return ses, p.admin.handleRequest(ses)
|
||||
|
||||
} else if ses.response, err = p.transport.RoundTrip(ses.request); err != nil {
|
||||
// Plain HTTP request
|
||||
if p.config.ErrorHandler != nil {
|
||||
p.config.ErrorHandler.HandleError(ses, err)
|
||||
}
|
||||
ses.response = ErrorResponse(ses.request, err)
|
||||
}
|
||||
|
||||
logger.Debug().Str("status", ses.response.Status).Msg("received response")
|
||||
cleanHopByHopHeaders(ses.response.Header)
|
||||
}
|
||||
|
||||
ses.response.Close = true
|
||||
defer log.OnCloseError(logger.Debug(), ses.response.Body)
|
||||
return ses, p.writeResponse(ses)
|
||||
}
|
||||
|
||||
func (p *Proxy) handleConnect(ses *Session) (next *Session, err error) {
|
||||
next = ses
|
||||
|
||||
logger := ses.log()
|
||||
logger.Debug().Msgf("connecting to %s", ses.request.URL.Host)
|
||||
|
||||
var c net.Conn
|
||||
if c, err = p.connect(ses, "tcp", ses.request.URL.Host); err != nil {
|
||||
logger.Error().Err(err).Msg("connect failed")
|
||||
if p.onError != nil {
|
||||
p.onError.HandleError(ses, err)
|
||||
}
|
||||
|
||||
ses.response = ErrorResponse(ses.request, err)
|
||||
defer log.OnCloseError(logger.Debug(), ses.response.Body)
|
||||
_ = p.writeResponse(ses)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err := c.Close(); err != nil {
|
||||
if p.onError != nil {
|
||||
p.onError.HandleError(ses, err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
if p.canIntercept(ses.request) {
|
||||
logger.Debug().Msg("intercepting connection")
|
||||
ses.response = NewResponse(http.StatusOK, nil, ses.request)
|
||||
err = p.writeResponse(ses)
|
||||
log.OnCloseError(logger.Debug(), ses.response.Body)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Peek first byte
|
||||
b := make([]byte, 1)
|
||||
if _, err = io.ReadFull(ses.ctx.rw, b); err != nil {
|
||||
logger.Error().Err(err).Msg("error peeking CONNECT byte")
|
||||
return
|
||||
}
|
||||
|
||||
// Drain buffered bytes
|
||||
b = append(b, make([]byte, ses.ctx.rw.Reader.Buffered())...)
|
||||
ses.ctx.rw.Reader.Read(b[1:])
|
||||
|
||||
r := &connReader{
|
||||
Conn: ses.ctx.conn,
|
||||
Reader: io.MultiReader(bytes.NewBuffer(b), ses.ctx.conn),
|
||||
}
|
||||
if b[0] == 22 { // TLS handshake: https://tools.ietf.org/html/rfc5246#section-6.2.1
|
||||
secure := tls.Server(r, p.authority.TLSConfig(ses.request.URL.Host))
|
||||
if err = secure.Handshake(); err != nil {
|
||||
logger.Error().Err(err).Msg("error intercepting TLS connection: client handshake failed")
|
||||
return
|
||||
}
|
||||
|
||||
rw := bufio.NewReadWriter(bufio.NewReader(secure), bufio.NewWriter(secure))
|
||||
ctx := newContext(secure, rw, ses)
|
||||
return p.handleRequest(ctx)
|
||||
}
|
||||
|
||||
rw := bufio.NewReadWriter(bufio.NewReader(r), bufio.NewWriter(r))
|
||||
ctx := newContext(r, rw, ses)
|
||||
return p.handleRequest(ctx)
|
||||
}
|
||||
|
||||
ses.response = NewResponse(http.StatusOK, nil, ses.request)
|
||||
defer log.OnCloseError(logger.Debug(), ses.response.Body)
|
||||
ses.response.ContentLength = -1
|
||||
if err = p.writeResponse(ses); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
logger.Debug().Msg("established CONNECT tunnel, proxying traffic")
|
||||
var wait sync.WaitGroup
|
||||
wait.Go(func() { copyStream(ses, c, ses.ctx.conn) })
|
||||
wait.Go(func() { copyStream(ses, ses.ctx.conn, c) })
|
||||
wait.Wait()
|
||||
logger.Debug().Msg("closed CONNECT tunnel")
|
||||
return
|
||||
}
|
||||
|
||||
func (p *Proxy) handleTunnel(ses *Session) (err error) {
|
||||
logger := ses.log()
|
||||
logger.Debug().Msgf("connecting to %s", ses.request.URL.Host)
|
||||
|
||||
var c net.Conn
|
||||
if c, err = p.connect(ses, "tcp", ses.request.URL.Host); err != nil {
|
||||
logger.Error().Err(err).Msg("connect failed")
|
||||
if p.onError != nil {
|
||||
p.onError.HandleError(ses, err)
|
||||
}
|
||||
|
||||
ses.response = ErrorResponse(ses.request, err)
|
||||
defer log.OnCloseError(logger.Debug(), ses.response.Body)
|
||||
_ = p.writeResponse(ses)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
defer log.OnCloseError(logger.Debug(), c)
|
||||
|
||||
if ses.ctx.IsTLS() {
|
||||
// Open a TLS client connection
|
||||
secure := tls.Client(c, &tls.Config{
|
||||
ServerName: ses.request.URL.Host,
|
||||
GetClientCertificate: func(cri *tls.CertificateRequestInfo) (*tls.Certificate, error) {
|
||||
return nil, ErrClientCert
|
||||
},
|
||||
})
|
||||
if err = secure.Handshake(); err != nil {
|
||||
logger.Error().Err(err).Msg("TLS handshake failed")
|
||||
return
|
||||
}
|
||||
c = secure
|
||||
}
|
||||
|
||||
if err = ses.request.Write(c); err != nil {
|
||||
logger.Error().Err(err).Msg("failed to write request")
|
||||
return
|
||||
}
|
||||
|
||||
logger.Debug().Msg("established tunnel, proxying traffic")
|
||||
var wait sync.WaitGroup
|
||||
wait.Go(func() { copyStream(ses, c, ses.ctx.conn) })
|
||||
wait.Go(func() { copyStream(ses, ses.ctx.conn, c) })
|
||||
wait.Wait()
|
||||
logger.Debug().Msg("closed tunnel")
|
||||
return
|
||||
}
|
||||
|
||||
func (p *Proxy) canIntercept(request *http.Request) bool {
|
||||
if permit := p.policy.PermitIntercept(request); permit != nil {
|
||||
return *permit
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
/*
|
||||
func (p *Proxy) handleAPIRequest(ses *Session) error {
|
||||
if ses.request.URL.Path == "/ca.crt" && p.authority != nil {
|
||||
b := pem.EncodeToMemory(&pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: p.authority.Certificate().Raw,
|
||||
})
|
||||
|
||||
ses.response = NewResponse(http.StatusOK, bytes.NewReader(b), ses.request)
|
||||
defer log.OnCloseError(logger.Debug(), ses.response.Body)
|
||||
|
||||
ses.response.Close = true
|
||||
ses.response.Header.Set("Content-Type", "application/x-x509-ca-cert")
|
||||
ses.response.ContentLength = int64(len(b))
|
||||
return p.writeResponse(ses)
|
||||
}
|
||||
|
||||
ses.response = ErrorResponse(ses.request, errors.New("invalid API endpoint"))
|
||||
defer log.OnCloseError(logger.Debug(), ses.response.Body)
|
||||
ses.response.Close = true
|
||||
return p.writeResponse(ses)
|
||||
}
|
||||
*/
|
||||
|
||||
func (p *Proxy) readRequest(ctx *Context) (request *http.Request, err error) {
|
||||
var (
|
||||
done = make(chan *http.Request, 1)
|
||||
errs = make(chan error, 1)
|
||||
)
|
||||
|
||||
go func() {
|
||||
r, err := http.ReadRequest(ctx.rw.Reader)
|
||||
if err != nil {
|
||||
errs <- err
|
||||
} else {
|
||||
done <- r
|
||||
}
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-p.closed:
|
||||
return nil, ErrClosed
|
||||
case request = <-done:
|
||||
return
|
||||
case err = <-errs:
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Proxy) cleanRequest(ses *Session, request *http.Request) {
|
||||
if request.URL.Host == "" {
|
||||
request.URL.Host = request.Host
|
||||
}
|
||||
|
||||
// Ensure proper URL scheme
|
||||
if !strings.HasPrefix(request.URL.Scheme, "http") {
|
||||
request.URL.Scheme = "http"
|
||||
}
|
||||
if ses.ctx.IsTLS() {
|
||||
state := ses.ctx.conn.Conn.(*tls.Conn).ConnectionState()
|
||||
request.TLS = &state
|
||||
request.URL.Scheme = "https"
|
||||
}
|
||||
|
||||
// Ensure proper RemoteAddr
|
||||
request.RemoteAddr = ses.ctx.RemoteAddr().String()
|
||||
|
||||
// Ensure proper encoding
|
||||
if request.Header.Get(HeaderAcceptEncoding) != "" {
|
||||
// We only support gzip
|
||||
request.Header.Set(HeaderAcceptEncoding, "gzip")
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Proxy) writeResponse(ses *Session) (err error) {
|
||||
log := ses.log()
|
||||
|
||||
if p.onResponse != nil {
|
||||
response := p.onResponse.HandleResponse(ses)
|
||||
if response != nil {
|
||||
log.Debug().Str("status", response.Status).Msg("response override")
|
||||
ses.response = response
|
||||
}
|
||||
}
|
||||
|
||||
if err = ses.response.Write(ses.ctx); err != nil {
|
||||
log.Error().Err(err).Msg("error writing response back to client")
|
||||
} else if err = ses.ctx.Flush(); err != nil {
|
||||
log.Error().Err(err).Msg("error flushing response back to client")
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (p *Proxy) connect(ses *Session, network, address string) (c net.Conn, err error) {
|
||||
log := ses.log()
|
||||
log.Debug().Msgf("connect to %s://%s", network, address)
|
||||
|
||||
if p.onConnect != nil {
|
||||
if c = p.onConnect.HandleConnect(ses, network, address); c != nil {
|
||||
log.Debug().Msg("connect override")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
var host, port string
|
||||
if host, port, err = net.SplitHostPort(address); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var hosts []string
|
||||
if hosts, err = p.resolver.Lookup(context.Background(), host); err != nil {
|
||||
log.Warn().Err(err).Msg("connect failed: DNS lookup error")
|
||||
return
|
||||
}
|
||||
|
||||
log.Debug().Str("address", hosts[0]).Msg("connect resolved address")
|
||||
return p.dial(network, net.JoinHostPort(hosts[0], port))
|
||||
}
|
||||
|
||||
var hopByHopHeaders = []string{
|
||||
HeaderConnection,
|
||||
"Keep-Alive",
|
||||
"Proxy-Authenticate",
|
||||
"Proxy-Authorization",
|
||||
"Proxy-Connection", // Non-standard, but required for HTTP/2.
|
||||
"Te",
|
||||
"Trailer",
|
||||
"Transfer-Encoding",
|
||||
HeaderUpgrade,
|
||||
}
|
||||
|
||||
func cleanHopByHopHeaders(header http.Header) {
|
||||
// Additional hop-by-hop headers may be specified in `Connection` headers.
|
||||
// http://tools.ietf.org/html/draft-ietf-httpbis-p1-messaging-14#section-9.1
|
||||
for _, values := range header[HeaderConnection] {
|
||||
for _, key := range strings.Split(values, ",") {
|
||||
header.Del(key)
|
||||
}
|
||||
}
|
||||
for _, key := range hopByHopHeaders {
|
||||
header.Del(key)
|
||||
}
|
||||
}
|
||||
|
||||
// copyStream copies data from reader to writer
|
||||
func copyStream(ses *Session, w io.Writer, r io.Reader) {
|
||||
log := ses.log()
|
||||
if _, err := io.Copy(w, r); err != nil && !isClosing(err) {
|
||||
log.Error().Err(err).Msg("failed CONNECT tunnel")
|
||||
} else {
|
||||
log.Debug().Msg("finished copying CONNECT tunnel")
|
||||
}
|
||||
}
|
||||
|
||||
func isClosing(err error) bool {
|
||||
if errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) || errors.Is(err, syscall.ECONNRESET) || err == ErrClosed {
|
||||
return true
|
||||
}
|
||||
if err, ok := err.(net.Error); ok && err.Timeout() {
|
||||
return true
|
||||
}
|
||||
// log.Debug().Msgf("not a closing error %T: %#+v", err, err)
|
||||
return false
|
||||
}
|
||||
|
||||
func resolveInterfaceAddr(name string) (err error) {
|
||||
var iface *net.Interface
|
||||
if iface, err = net.InterfaceByName(name); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var addrs []net.Addr
|
||||
if addrs, err = iface.Addrs(); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
for _, addr := range addrs {
|
||||
if addr, ok := addr.(*net.IPNet); ok && !addr.IP.IsUnspecified() {
|
||||
log.Warn().Msgf("addr %T: %s", addr, addr)
|
||||
}
|
||||
}
|
||||
return errors.New("nope; TODO")
|
||||
}
|
Reference in New Issue
Block a user