package proxy import ( "bufio" "bytes" "context" "crypto/tls" "errors" "fmt" "io" "net" "net/http" "strings" "sync" "syscall" "time" "git.maze.io/maze/styx/internal/log" "git.maze.io/maze/styx/internal/netutil" "git.maze.io/maze/styx/proxy/mitm" "git.maze.io/maze/styx/proxy/policy" "git.maze.io/maze/styx/proxy/resolver" "git.maze.io/maze/styx/proxy/stats" ) const ( DefaultListenAddr = ":3128" DefaultBindAddr = "" DefaultDialTimeout = 30 * time.Second DefaultKeepAlivePeriod = 1 * time.Minute ) const ( HeaderAcceptEncoding = "Accept-Encoding" HeaderConnection = "Connection" HeaderContentLength = "Content-Length" HeaderContentType = "Content-Type" HeaderUpgrade = "Upgrade" ) var ( ErrClosed = errors.New("proxy: shutdown") ErrClientCert = errors.New("tls: client certificate requested") ) type Proxy struct { addr *net.TCPAddr bind *net.TCPAddr resolver resolver.Resolver transport *http.Transport dial func(network, address string) (net.Conn, error) config *Config authority mitm.Authority policy *policy.Policy admin *Admin stats *stats.Stats closed chan struct{} onConnect ConnectHandler onRequest RequestHandler onResponse ResponseHandler onError ErrorHandler } func New(config *Config, ca mitm.Authority) (*Proxy, error) { if config == nil { return nil, errors.New("proxy: config can't be nil") } p := &Proxy{ transport: newTransport(), config: config, resolver: resolver.Default, authority: ca, policy: config.Policy, closed: make(chan struct{}), onConnect: config.ConnectHandler, onRequest: config.RequestHandler, onResponse: config.ResponseHandler, onError: config.ErrorHandler, } var err error if config.Listen == "" { p.addr, err = net.ResolveTCPAddr("tcp", DefaultBindAddr) } else { p.addr, err = net.ResolveTCPAddr("tcp", config.Listen) } if err != nil { return nil, fmt.Errorf("proxy: invalid listen addres: %w", err) } if config.Bind != "" { if p.bind, err = net.ResolveTCPAddr("tcp", config.Bind+":0"); err != nil { return nil, fmt.Errorf("proxy: invalid bind address: %w", err) } } else if config.Interface != "" { if err = resolveInterfaceAddr(config.Interface); err != nil { return nil, err } } if p.bind != nil { /* FIXME var c *net.TCPConn if c, err = net.DialTCP("tcp", p.bind, p.bind); err != nil && errors.Is(err, syscall.EADDRNOTAVAIL) { return nil, fmt.Errorf("proxy: invalid bind address: %w", syscall.EADDRNOTAVAIL) } else if c != nil { _ = c.Close() } */ } if config.Resolver != nil { p.resolver = config.Resolver } dialTimeout := DefaultDialTimeout if config.DialTimeout > 0 { dialTimeout = config.DialTimeout } p.dial = (&net.Dialer{ Timeout: dialTimeout, KeepAlive: dialTimeout, LocalAddr: p.bind, }).Dial p.admin = NewAdmin(p) if p.stats, err = stats.New(); err != nil { return nil, err } return p, nil } func newTransport() *http.Transport { return &http.Transport{ TLSNextProto: make(map[string]func(authority string, c *tls.Conn) http.RoundTripper), Proxy: http.ProxyFromEnvironment, TLSHandshakeTimeout: 15 * time.Second, ExpectContinueTimeout: 5 * time.Second, } } func (p *Proxy) Close() error { select { case <-p.closed: return ErrClosed default: close(p.closed) return nil } } func (p *Proxy) Start() error { l, err := net.ListenTCP("tcp", p.addr) if err != nil { return err } go p.Serve(l) return nil } func (p *Proxy) Serve(listener net.Listener) error { defer func() { _ = listener.Close() }() log.Info().Str("addr", listener.Addr().String()).Msg("proxy server listening") for { select { case <-p.closed: return nil default: } c, err := listener.Accept() if err != nil { return err } rw := bufio.NewReadWriter(bufio.NewReader(c), bufio.NewWriter(c)) ctx := newContext(c, rw, nil) if c, ok := c.(*net.TCPConn); ok { _ = c.SetKeepAlive(true) _ = c.SetKeepAlivePeriod(DefaultKeepAlivePeriod) } go p.handle(ctx) } } func (p *Proxy) handle(ctx *Context) { logger := ctx.log() defer log.OnCloseError(logger.Debug(), ctx.conn) logger.Info().Str("client", ctx.RemoteAddr().String()).Msg("new client connection") last := int64(0) for { select { case <-p.closed: return default: ses, err := p.handleRequest(ctx) if ses != nil { log := ses.log() log.Info(). Str("method", ses.request.Method). Str("url", ses.request.URL.String()). Str("status", ses.response.Status). Int64("size", ctx.conn.bytes-last). Msg("handled request") p.stats.AddLog(&stats.Log{ ClientIP: netutil.Host(ses.request.RemoteAddr), Request: stats.FromRequest(ses.request), Response: stats.FromResponse(ses.response).SetSize(ctx.conn.bytes - last), }) last = ctx.conn.bytes } if err != nil && !isClosing(err) || (ses != nil && ses.response != nil && ses.response.Close) { event := logger.Debug() if ctx.conn.bytes > 0 { event = event.Int64("size", ctx.conn.bytes) } event.Msg("closing client connection") return } } } } func (p *Proxy) handleRequest(ctx *Context) (ses *Session, err error) { logger := ctx.log() var request *http.Request if request, err = p.readRequest(ctx); err != nil { return } ses = newSession(ctx, request) p.cleanRequest(ses, request) logger.Debug().Str("method", request.Method).Str("url", request.URL.String()).Msg("handle request") if p.onRequest != nil { newRequest, newResponse := p.onRequest.HandleRequest(ses) if newRequest != nil { logger.Debug().Str("method", newRequest.Method).Str("url", newRequest.URL.String()).Msg("request override") ses.request = newRequest } if newResponse != nil { logger.Debug().Str("status", newResponse.Status).Msg("response override") ses.response = newResponse } } if ses.response == nil { // WebSocket request if ses.request.Header.Get(HeaderUpgrade) == "websocket" { return ses, p.handleTunnel(ses) } cleanHopByHopHeaders(ses.request.Header) // Proxy CONNECT request if ses.request.Method == http.MethodConnect { return p.handleConnect(ses) } if netutil.Port(ses.request.URL.Host) == p.addr.Port { // Plain API request ses.request.URL.Host = ses.request.Host return ses, p.admin.handleRequest(ses) } else if ses.response, err = p.transport.RoundTrip(ses.request); err != nil { // Plain HTTP request if p.config.ErrorHandler != nil { p.config.ErrorHandler.HandleError(ses, err) } ses.response = ErrorResponse(ses.request, err) } logger.Debug().Str("status", ses.response.Status).Msg("received response") cleanHopByHopHeaders(ses.response.Header) } ses.response.Close = true defer log.OnCloseError(logger.Debug(), ses.response.Body) return ses, p.writeResponse(ses) } func (p *Proxy) handleConnect(ses *Session) (next *Session, err error) { next = ses logger := ses.log() logger.Debug().Msgf("connecting to %s", ses.request.URL.Host) var c net.Conn if c, err = p.connect(ses, "tcp", ses.request.URL.Host); err != nil { logger.Error().Err(err).Msg("connect failed") if p.onError != nil { p.onError.HandleError(ses, err) } ses.response = ErrorResponse(ses.request, err) defer log.OnCloseError(logger.Debug(), ses.response.Body) _ = p.writeResponse(ses) return } defer func() { if err := c.Close(); err != nil { if p.onError != nil { p.onError.HandleError(ses, err) } } }() if p.canIntercept(ses.request) { logger.Debug().Msg("intercepting connection") ses.response = NewResponse(http.StatusOK, nil, ses.request) err = p.writeResponse(ses) log.OnCloseError(logger.Debug(), ses.response.Body) if err != nil { return } // Peek first byte b := make([]byte, 1) if _, err = io.ReadFull(ses.ctx.rw, b); err != nil { logger.Error().Err(err).Msg("error peeking CONNECT byte") return } // Drain buffered bytes b = append(b, make([]byte, ses.ctx.rw.Reader.Buffered())...) ses.ctx.rw.Reader.Read(b[1:]) r := &connReader{ Conn: ses.ctx.conn, Reader: io.MultiReader(bytes.NewBuffer(b), ses.ctx.conn), } if b[0] == 22 { // TLS handshake: https://tools.ietf.org/html/rfc5246#section-6.2.1 secure := tls.Server(r, p.authority.TLSConfig(ses.request.URL.Host)) if err = secure.Handshake(); err != nil { logger.Error().Err(err).Msg("error intercepting TLS connection: client handshake failed") return } rw := bufio.NewReadWriter(bufio.NewReader(secure), bufio.NewWriter(secure)) ctx := newContext(secure, rw, ses) return p.handleRequest(ctx) } rw := bufio.NewReadWriter(bufio.NewReader(r), bufio.NewWriter(r)) ctx := newContext(r, rw, ses) return p.handleRequest(ctx) } ses.response = NewResponse(http.StatusOK, nil, ses.request) defer log.OnCloseError(logger.Debug(), ses.response.Body) ses.response.ContentLength = -1 if err = p.writeResponse(ses); err != nil { return } logger.Debug().Msg("established CONNECT tunnel, proxying traffic") var wait sync.WaitGroup wait.Go(func() { copyStream(ses, c, ses.ctx.conn) }) wait.Go(func() { copyStream(ses, ses.ctx.conn, c) }) wait.Wait() logger.Debug().Msg("closed CONNECT tunnel") return } func (p *Proxy) handleTunnel(ses *Session) (err error) { logger := ses.log() logger.Debug().Msgf("connecting to %s", ses.request.URL.Host) var c net.Conn if c, err = p.connect(ses, "tcp", ses.request.URL.Host); err != nil { logger.Error().Err(err).Msg("connect failed") if p.onError != nil { p.onError.HandleError(ses, err) } ses.response = ErrorResponse(ses.request, err) defer log.OnCloseError(logger.Debug(), ses.response.Body) _ = p.writeResponse(ses) return } defer log.OnCloseError(logger.Debug(), c) if ses.ctx.IsTLS() { // Open a TLS client connection secure := tls.Client(c, &tls.Config{ ServerName: ses.request.URL.Host, GetClientCertificate: func(cri *tls.CertificateRequestInfo) (*tls.Certificate, error) { return nil, ErrClientCert }, }) if err = secure.Handshake(); err != nil { logger.Error().Err(err).Msg("TLS handshake failed") return } c = secure } if err = ses.request.Write(c); err != nil { logger.Error().Err(err).Msg("failed to write request") return } logger.Debug().Msg("established tunnel, proxying traffic") var wait sync.WaitGroup wait.Go(func() { copyStream(ses, c, ses.ctx.conn) }) wait.Go(func() { copyStream(ses, ses.ctx.conn, c) }) wait.Wait() logger.Debug().Msg("closed tunnel") return } func (p *Proxy) canIntercept(request *http.Request) bool { if permit := p.policy.PermitIntercept(request); permit != nil { return *permit } return true } /* func (p *Proxy) handleAPIRequest(ses *Session) error { if ses.request.URL.Path == "/ca.crt" && p.authority != nil { b := pem.EncodeToMemory(&pem.Block{ Type: "CERTIFICATE", Bytes: p.authority.Certificate().Raw, }) ses.response = NewResponse(http.StatusOK, bytes.NewReader(b), ses.request) defer log.OnCloseError(logger.Debug(), ses.response.Body) ses.response.Close = true ses.response.Header.Set("Content-Type", "application/x-x509-ca-cert") ses.response.ContentLength = int64(len(b)) return p.writeResponse(ses) } ses.response = ErrorResponse(ses.request, errors.New("invalid API endpoint")) defer log.OnCloseError(logger.Debug(), ses.response.Body) ses.response.Close = true return p.writeResponse(ses) } */ func (p *Proxy) readRequest(ctx *Context) (request *http.Request, err error) { var ( done = make(chan *http.Request, 1) errs = make(chan error, 1) ) go func() { r, err := http.ReadRequest(ctx.rw.Reader) if err != nil { errs <- err } else { done <- r } }() select { case <-p.closed: return nil, ErrClosed case request = <-done: return case err = <-errs: return } } func (p *Proxy) cleanRequest(ses *Session, request *http.Request) { if request.URL.Host == "" { request.URL.Host = request.Host } // Ensure proper URL scheme if !strings.HasPrefix(request.URL.Scheme, "http") { request.URL.Scheme = "http" } if ses.ctx.IsTLS() { state := ses.ctx.conn.Conn.(*tls.Conn).ConnectionState() request.TLS = &state request.URL.Scheme = "https" } // Ensure proper RemoteAddr request.RemoteAddr = ses.ctx.RemoteAddr().String() // Ensure proper encoding if request.Header.Get(HeaderAcceptEncoding) != "" { // We only support gzip request.Header.Set(HeaderAcceptEncoding, "gzip") } } func (p *Proxy) writeResponse(ses *Session) (err error) { log := ses.log() if p.onResponse != nil { response := p.onResponse.HandleResponse(ses) if response != nil { log.Debug().Str("status", response.Status).Msg("response override") ses.response = response } } if err = ses.response.Write(ses.ctx); err != nil { log.Error().Err(err).Msg("error writing response back to client") } else if err = ses.ctx.Flush(); err != nil { log.Error().Err(err).Msg("error flushing response back to client") } return } func (p *Proxy) connect(ses *Session, network, address string) (c net.Conn, err error) { log := ses.log() log.Debug().Msgf("connect to %s://%s", network, address) if p.onConnect != nil { if c = p.onConnect.HandleConnect(ses, network, address); c != nil { log.Debug().Msg("connect override") return } } var host, port string if host, port, err = net.SplitHostPort(address); err != nil { return } var hosts []string if hosts, err = p.resolver.Lookup(context.Background(), host); err != nil { log.Warn().Err(err).Msg("connect failed: DNS lookup error") return } log.Debug().Str("address", hosts[0]).Msg("connect resolved address") return p.dial(network, net.JoinHostPort(hosts[0], port)) } var hopByHopHeaders = []string{ HeaderConnection, "Keep-Alive", "Proxy-Authenticate", "Proxy-Authorization", "Proxy-Connection", // Non-standard, but required for HTTP/2. "Te", "Trailer", "Transfer-Encoding", HeaderUpgrade, } func cleanHopByHopHeaders(header http.Header) { // Additional hop-by-hop headers may be specified in `Connection` headers. // http://tools.ietf.org/html/draft-ietf-httpbis-p1-messaging-14#section-9.1 for _, values := range header[HeaderConnection] { for _, key := range strings.Split(values, ",") { header.Del(key) } } for _, key := range hopByHopHeaders { header.Del(key) } } // copyStream copies data from reader to writer func copyStream(ses *Session, w io.Writer, r io.Reader) { log := ses.log() if _, err := io.Copy(w, r); err != nil && !isClosing(err) { log.Error().Err(err).Msg("failed CONNECT tunnel") } else { log.Debug().Msg("finished copying CONNECT tunnel") } } func isClosing(err error) bool { if errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) || errors.Is(err, syscall.ECONNRESET) || err == ErrClosed { return true } if err, ok := err.(net.Error); ok && err.Timeout() { return true } // log.Debug().Msgf("not a closing error %T: %#+v", err, err) return false } func resolveInterfaceAddr(name string) (err error) { var iface *net.Interface if iface, err = net.InterfaceByName(name); err != nil { return } var addrs []net.Addr if addrs, err = iface.Addrs(); err != nil { return } for _, addr := range addrs { if addr, ok := addr.(*net.IPNet); ok && !addr.IP.IsUnspecified() { log.Warn().Msgf("addr %T: %s", addr, addr) } } return errors.New("nope; TODO") }