Initial import

This commit is contained in:
2025-09-26 08:49:53 +02:00
commit a76650da35
35 changed files with 4660 additions and 0 deletions

616
proxy/proxy.go Normal file
View File

@@ -0,0 +1,616 @@
package proxy
import (
"bufio"
"bytes"
"context"
"crypto/tls"
"errors"
"fmt"
"io"
"net"
"net/http"
"strings"
"sync"
"syscall"
"time"
"git.maze.io/maze/styx/internal/log"
"git.maze.io/maze/styx/internal/netutil"
"git.maze.io/maze/styx/proxy/mitm"
"git.maze.io/maze/styx/proxy/policy"
"git.maze.io/maze/styx/proxy/resolver"
"git.maze.io/maze/styx/proxy/stats"
)
const (
DefaultListenAddr = ":3128"
DefaultBindAddr = ""
DefaultDialTimeout = 30 * time.Second
DefaultKeepAlivePeriod = 1 * time.Minute
)
const (
HeaderAcceptEncoding = "Accept-Encoding"
HeaderConnection = "Connection"
HeaderContentLength = "Content-Length"
HeaderContentType = "Content-Type"
HeaderUpgrade = "Upgrade"
)
var (
ErrClosed = errors.New("proxy: shutdown")
ErrClientCert = errors.New("tls: client certificate requested")
)
type Proxy struct {
addr *net.TCPAddr
bind *net.TCPAddr
resolver resolver.Resolver
transport *http.Transport
dial func(network, address string) (net.Conn, error)
config *Config
authority mitm.Authority
policy *policy.Policy
admin *Admin
stats *stats.Stats
closed chan struct{}
onConnect ConnectHandler
onRequest RequestHandler
onResponse ResponseHandler
onError ErrorHandler
}
func New(config *Config, ca mitm.Authority) (*Proxy, error) {
if config == nil {
return nil, errors.New("proxy: config can't be nil")
}
p := &Proxy{
transport: newTransport(),
config: config,
resolver: resolver.Default,
authority: ca,
policy: config.Policy,
closed: make(chan struct{}),
onConnect: config.ConnectHandler,
onRequest: config.RequestHandler,
onResponse: config.ResponseHandler,
onError: config.ErrorHandler,
}
var err error
if config.Listen == "" {
p.addr, err = net.ResolveTCPAddr("tcp", DefaultBindAddr)
} else {
p.addr, err = net.ResolveTCPAddr("tcp", config.Listen)
}
if err != nil {
return nil, fmt.Errorf("proxy: invalid listen addres: %w", err)
}
if config.Bind != "" {
if p.bind, err = net.ResolveTCPAddr("tcp", config.Bind+":0"); err != nil {
return nil, fmt.Errorf("proxy: invalid bind address: %w", err)
}
} else if config.Interface != "" {
if err = resolveInterfaceAddr(config.Interface); err != nil {
return nil, err
}
}
if p.bind != nil {
/* FIXME
var c *net.TCPConn
if c, err = net.DialTCP("tcp", p.bind, p.bind); err != nil && errors.Is(err, syscall.EADDRNOTAVAIL) {
return nil, fmt.Errorf("proxy: invalid bind address: %w", syscall.EADDRNOTAVAIL)
} else if c != nil {
_ = c.Close()
}
*/
}
if config.Resolver != nil {
p.resolver = config.Resolver
}
dialTimeout := DefaultDialTimeout
if config.DialTimeout > 0 {
dialTimeout = config.DialTimeout
}
p.dial = (&net.Dialer{
Timeout: dialTimeout,
KeepAlive: dialTimeout,
LocalAddr: p.bind,
}).Dial
p.admin = NewAdmin(p)
if p.stats, err = stats.New(); err != nil {
return nil, err
}
return p, nil
}
func newTransport() *http.Transport {
return &http.Transport{
TLSNextProto: make(map[string]func(authority string, c *tls.Conn) http.RoundTripper),
Proxy: http.ProxyFromEnvironment,
TLSHandshakeTimeout: 15 * time.Second,
ExpectContinueTimeout: 5 * time.Second,
}
}
func (p *Proxy) Close() error {
select {
case <-p.closed:
return ErrClosed
default:
close(p.closed)
return nil
}
}
func (p *Proxy) Start() error {
l, err := net.ListenTCP("tcp", p.addr)
if err != nil {
return err
}
go p.Serve(l)
return nil
}
func (p *Proxy) Serve(listener net.Listener) error {
defer func() { _ = listener.Close() }()
log.Info().Str("addr", listener.Addr().String()).Msg("proxy server listening")
for {
select {
case <-p.closed:
return nil
default:
}
c, err := listener.Accept()
if err != nil {
return err
}
rw := bufio.NewReadWriter(bufio.NewReader(c), bufio.NewWriter(c))
ctx := newContext(c, rw, nil)
if c, ok := c.(*net.TCPConn); ok {
_ = c.SetKeepAlive(true)
_ = c.SetKeepAlivePeriod(DefaultKeepAlivePeriod)
}
go p.handle(ctx)
}
}
func (p *Proxy) handle(ctx *Context) {
logger := ctx.log()
defer log.OnCloseError(logger.Debug(), ctx.conn)
logger.Info().Str("client", ctx.RemoteAddr().String()).Msg("new client connection")
last := int64(0)
for {
select {
case <-p.closed:
return
default:
ses, err := p.handleRequest(ctx)
if ses != nil {
log := ses.log()
log.Info().
Str("method", ses.request.Method).
Str("url", ses.request.URL.String()).
Str("status", ses.response.Status).
Int64("size", ctx.conn.bytes-last).
Msg("handled request")
p.stats.AddLog(&stats.Log{
ClientIP: netutil.Host(ses.request.RemoteAddr),
Request: stats.FromRequest(ses.request),
Response: stats.FromResponse(ses.response).SetSize(ctx.conn.bytes - last),
})
last = ctx.conn.bytes
}
if err != nil && !isClosing(err) || (ses != nil && ses.response != nil && ses.response.Close) {
event := logger.Debug()
if ctx.conn.bytes > 0 {
event = event.Int64("size", ctx.conn.bytes)
}
event.Msg("closing client connection")
return
}
}
}
}
func (p *Proxy) handleRequest(ctx *Context) (ses *Session, err error) {
logger := ctx.log()
var request *http.Request
if request, err = p.readRequest(ctx); err != nil {
return
}
ses = newSession(ctx, request)
p.cleanRequest(ses, request)
logger.Debug().Str("method", request.Method).Str("url", request.URL.String()).Msg("handle request")
if p.onRequest != nil {
newRequest, newResponse := p.onRequest.HandleRequest(ses)
if newRequest != nil {
logger.Debug().Str("method", newRequest.Method).Str("url", newRequest.URL.String()).Msg("request override")
ses.request = newRequest
}
if newResponse != nil {
logger.Debug().Str("status", newResponse.Status).Msg("response override")
ses.response = newResponse
}
}
if ses.response == nil {
// WebSocket request
if ses.request.Header.Get(HeaderUpgrade) == "websocket" {
return ses, p.handleTunnel(ses)
}
cleanHopByHopHeaders(ses.request.Header)
// Proxy CONNECT request
if ses.request.Method == http.MethodConnect {
return p.handleConnect(ses)
}
if netutil.Port(ses.request.URL.Host) == p.addr.Port {
// Plain API request
ses.request.URL.Host = ses.request.Host
return ses, p.admin.handleRequest(ses)
} else if ses.response, err = p.transport.RoundTrip(ses.request); err != nil {
// Plain HTTP request
if p.config.ErrorHandler != nil {
p.config.ErrorHandler.HandleError(ses, err)
}
ses.response = ErrorResponse(ses.request, err)
}
logger.Debug().Str("status", ses.response.Status).Msg("received response")
cleanHopByHopHeaders(ses.response.Header)
}
ses.response.Close = true
defer log.OnCloseError(logger.Debug(), ses.response.Body)
return ses, p.writeResponse(ses)
}
func (p *Proxy) handleConnect(ses *Session) (next *Session, err error) {
next = ses
logger := ses.log()
logger.Debug().Msgf("connecting to %s", ses.request.URL.Host)
var c net.Conn
if c, err = p.connect(ses, "tcp", ses.request.URL.Host); err != nil {
logger.Error().Err(err).Msg("connect failed")
if p.onError != nil {
p.onError.HandleError(ses, err)
}
ses.response = ErrorResponse(ses.request, err)
defer log.OnCloseError(logger.Debug(), ses.response.Body)
_ = p.writeResponse(ses)
return
}
defer func() {
if err := c.Close(); err != nil {
if p.onError != nil {
p.onError.HandleError(ses, err)
}
}
}()
if p.canIntercept(ses.request) {
logger.Debug().Msg("intercepting connection")
ses.response = NewResponse(http.StatusOK, nil, ses.request)
err = p.writeResponse(ses)
log.OnCloseError(logger.Debug(), ses.response.Body)
if err != nil {
return
}
// Peek first byte
b := make([]byte, 1)
if _, err = io.ReadFull(ses.ctx.rw, b); err != nil {
logger.Error().Err(err).Msg("error peeking CONNECT byte")
return
}
// Drain buffered bytes
b = append(b, make([]byte, ses.ctx.rw.Reader.Buffered())...)
ses.ctx.rw.Reader.Read(b[1:])
r := &connReader{
Conn: ses.ctx.conn,
Reader: io.MultiReader(bytes.NewBuffer(b), ses.ctx.conn),
}
if b[0] == 22 { // TLS handshake: https://tools.ietf.org/html/rfc5246#section-6.2.1
secure := tls.Server(r, p.authority.TLSConfig(ses.request.URL.Host))
if err = secure.Handshake(); err != nil {
logger.Error().Err(err).Msg("error intercepting TLS connection: client handshake failed")
return
}
rw := bufio.NewReadWriter(bufio.NewReader(secure), bufio.NewWriter(secure))
ctx := newContext(secure, rw, ses)
return p.handleRequest(ctx)
}
rw := bufio.NewReadWriter(bufio.NewReader(r), bufio.NewWriter(r))
ctx := newContext(r, rw, ses)
return p.handleRequest(ctx)
}
ses.response = NewResponse(http.StatusOK, nil, ses.request)
defer log.OnCloseError(logger.Debug(), ses.response.Body)
ses.response.ContentLength = -1
if err = p.writeResponse(ses); err != nil {
return
}
logger.Debug().Msg("established CONNECT tunnel, proxying traffic")
var wait sync.WaitGroup
wait.Go(func() { copyStream(ses, c, ses.ctx.conn) })
wait.Go(func() { copyStream(ses, ses.ctx.conn, c) })
wait.Wait()
logger.Debug().Msg("closed CONNECT tunnel")
return
}
func (p *Proxy) handleTunnel(ses *Session) (err error) {
logger := ses.log()
logger.Debug().Msgf("connecting to %s", ses.request.URL.Host)
var c net.Conn
if c, err = p.connect(ses, "tcp", ses.request.URL.Host); err != nil {
logger.Error().Err(err).Msg("connect failed")
if p.onError != nil {
p.onError.HandleError(ses, err)
}
ses.response = ErrorResponse(ses.request, err)
defer log.OnCloseError(logger.Debug(), ses.response.Body)
_ = p.writeResponse(ses)
return
}
defer log.OnCloseError(logger.Debug(), c)
if ses.ctx.IsTLS() {
// Open a TLS client connection
secure := tls.Client(c, &tls.Config{
ServerName: ses.request.URL.Host,
GetClientCertificate: func(cri *tls.CertificateRequestInfo) (*tls.Certificate, error) {
return nil, ErrClientCert
},
})
if err = secure.Handshake(); err != nil {
logger.Error().Err(err).Msg("TLS handshake failed")
return
}
c = secure
}
if err = ses.request.Write(c); err != nil {
logger.Error().Err(err).Msg("failed to write request")
return
}
logger.Debug().Msg("established tunnel, proxying traffic")
var wait sync.WaitGroup
wait.Go(func() { copyStream(ses, c, ses.ctx.conn) })
wait.Go(func() { copyStream(ses, ses.ctx.conn, c) })
wait.Wait()
logger.Debug().Msg("closed tunnel")
return
}
func (p *Proxy) canIntercept(request *http.Request) bool {
if permit := p.policy.PermitIntercept(request); permit != nil {
return *permit
}
return true
}
/*
func (p *Proxy) handleAPIRequest(ses *Session) error {
if ses.request.URL.Path == "/ca.crt" && p.authority != nil {
b := pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE",
Bytes: p.authority.Certificate().Raw,
})
ses.response = NewResponse(http.StatusOK, bytes.NewReader(b), ses.request)
defer log.OnCloseError(logger.Debug(), ses.response.Body)
ses.response.Close = true
ses.response.Header.Set("Content-Type", "application/x-x509-ca-cert")
ses.response.ContentLength = int64(len(b))
return p.writeResponse(ses)
}
ses.response = ErrorResponse(ses.request, errors.New("invalid API endpoint"))
defer log.OnCloseError(logger.Debug(), ses.response.Body)
ses.response.Close = true
return p.writeResponse(ses)
}
*/
func (p *Proxy) readRequest(ctx *Context) (request *http.Request, err error) {
var (
done = make(chan *http.Request, 1)
errs = make(chan error, 1)
)
go func() {
r, err := http.ReadRequest(ctx.rw.Reader)
if err != nil {
errs <- err
} else {
done <- r
}
}()
select {
case <-p.closed:
return nil, ErrClosed
case request = <-done:
return
case err = <-errs:
return
}
}
func (p *Proxy) cleanRequest(ses *Session, request *http.Request) {
if request.URL.Host == "" {
request.URL.Host = request.Host
}
// Ensure proper URL scheme
if !strings.HasPrefix(request.URL.Scheme, "http") {
request.URL.Scheme = "http"
}
if ses.ctx.IsTLS() {
state := ses.ctx.conn.Conn.(*tls.Conn).ConnectionState()
request.TLS = &state
request.URL.Scheme = "https"
}
// Ensure proper RemoteAddr
request.RemoteAddr = ses.ctx.RemoteAddr().String()
// Ensure proper encoding
if request.Header.Get(HeaderAcceptEncoding) != "" {
// We only support gzip
request.Header.Set(HeaderAcceptEncoding, "gzip")
}
}
func (p *Proxy) writeResponse(ses *Session) (err error) {
log := ses.log()
if p.onResponse != nil {
response := p.onResponse.HandleResponse(ses)
if response != nil {
log.Debug().Str("status", response.Status).Msg("response override")
ses.response = response
}
}
if err = ses.response.Write(ses.ctx); err != nil {
log.Error().Err(err).Msg("error writing response back to client")
} else if err = ses.ctx.Flush(); err != nil {
log.Error().Err(err).Msg("error flushing response back to client")
}
return
}
func (p *Proxy) connect(ses *Session, network, address string) (c net.Conn, err error) {
log := ses.log()
log.Debug().Msgf("connect to %s://%s", network, address)
if p.onConnect != nil {
if c = p.onConnect.HandleConnect(ses, network, address); c != nil {
log.Debug().Msg("connect override")
return
}
}
var host, port string
if host, port, err = net.SplitHostPort(address); err != nil {
return
}
var hosts []string
if hosts, err = p.resolver.Lookup(context.Background(), host); err != nil {
log.Warn().Err(err).Msg("connect failed: DNS lookup error")
return
}
log.Debug().Str("address", hosts[0]).Msg("connect resolved address")
return p.dial(network, net.JoinHostPort(hosts[0], port))
}
var hopByHopHeaders = []string{
HeaderConnection,
"Keep-Alive",
"Proxy-Authenticate",
"Proxy-Authorization",
"Proxy-Connection", // Non-standard, but required for HTTP/2.
"Te",
"Trailer",
"Transfer-Encoding",
HeaderUpgrade,
}
func cleanHopByHopHeaders(header http.Header) {
// Additional hop-by-hop headers may be specified in `Connection` headers.
// http://tools.ietf.org/html/draft-ietf-httpbis-p1-messaging-14#section-9.1
for _, values := range header[HeaderConnection] {
for _, key := range strings.Split(values, ",") {
header.Del(key)
}
}
for _, key := range hopByHopHeaders {
header.Del(key)
}
}
// copyStream copies data from reader to writer
func copyStream(ses *Session, w io.Writer, r io.Reader) {
log := ses.log()
if _, err := io.Copy(w, r); err != nil && !isClosing(err) {
log.Error().Err(err).Msg("failed CONNECT tunnel")
} else {
log.Debug().Msg("finished copying CONNECT tunnel")
}
}
func isClosing(err error) bool {
if errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) || errors.Is(err, syscall.ECONNRESET) || err == ErrClosed {
return true
}
if err, ok := err.(net.Error); ok && err.Timeout() {
return true
}
// log.Debug().Msgf("not a closing error %T: %#+v", err, err)
return false
}
func resolveInterfaceAddr(name string) (err error) {
var iface *net.Interface
if iface, err = net.InterfaceByName(name); err != nil {
return
}
var addrs []net.Addr
if addrs, err = iface.Addrs(); err != nil {
return
}
for _, addr := range addrs {
if addr, ok := addr.(*net.IPNet); ok && !addr.IP.IsUnspecified() {
log.Warn().Msgf("addr %T: %s", addr, addr)
}
}
return errors.New("nope; TODO")
}