Checkpoint

This commit is contained in:
2025-10-06 22:25:23 +02:00
parent a23259cfdc
commit a254b306f2
48 changed files with 3327 additions and 212 deletions

View File

@@ -15,9 +15,12 @@ import (
"slices"
"strconv"
"strings"
"sync"
"syscall"
"time"
"git.maze.io/maze/styx/ca"
"git.maze.io/maze/styx/dataset"
"git.maze.io/maze/styx/internal/netutil"
"git.maze.io/maze/styx/logger"
"git.maze.io/maze/styx/stats"
@@ -26,6 +29,7 @@ import (
// Common HTTP headers.
const (
HeaderConnection = "Connection"
HeaderContentLength = "Content-Length"
HeaderContentType = "Content-Type"
HeaderDate = "Date"
HeaderForwarded = "Forwarded"
@@ -146,7 +150,17 @@ type Proxy struct {
// WebSocketIdleTimeout is the timeout for idle WebSocket connections.
WebSocketIdleTimeout time.Duration
mux *http.ServeMux
// CertificateAuthority can issue certificates for man-in-the-middle connections.
CertificateAuthority ca.CertificateAuthority
// Storage for resolving clients/groups
Storage dataset.Storage
mux *http.ServeMux
closed chan struct{}
closeOnce sync.Once
mu sync.RWMutex
listeners []net.Listener
}
// New [Proxy] with somewhat sane defaults.
@@ -157,6 +171,7 @@ func New() *Proxy {
IdleTimeout: DefaultIdleTimeout,
WebSocketIdleTimeout: DefaultWebSocketIdleTimeout,
mux: http.NewServeMux(),
closed: make(chan struct{}, 1),
}
// Make sure the roundtripper uses our dialers.
@@ -181,6 +196,55 @@ func New() *Proxy {
return p
}
func (p *Proxy) Close() error {
var closeListeners bool
p.closeOnce.Do(func() {
close(p.closed)
closeListeners = true
})
if closeListeners {
p.mu.RLock()
for _, l := range p.listeners {
_ = l.Close()
}
p.mu.RUnlock()
}
return nil
}
func (p *Proxy) isClosed() bool {
select {
case <-p.closed:
return true
default:
return false
}
}
func (p *Proxy) addListener(l net.Listener) {
if l == nil {
return
}
p.mu.Lock()
p.listeners = append(p.listeners, l)
p.mu.Unlock()
}
func (p *Proxy) removeListener(l net.Listener) {
if l == nil {
return
}
p.mu.Lock()
listeners := make([]net.Listener, 0, len(p.listeners)-1)
for _, o := range p.listeners {
if o != l {
listeners = append(listeners, o)
}
}
p.listeners = listeners
p.mu.Unlock()
}
// Handle installs a [http.Handler] into the internal mux.
func (p *Proxy) Handle(pattern string, handler http.Handler) {
p.mux.Handle(pattern, handler)
@@ -214,11 +278,23 @@ func (p *Proxy) dial(ctx context.Context, req *http.Request) (net.Conn, error) {
// Serve proxied connections on the specified listener.
func (p *Proxy) Serve(l net.Listener) error {
p.addListener(l)
defer p.removeListener(l)
for {
if p.isClosed() {
return nil
}
c, err := l.Accept()
if err != nil {
return err
}
if p.isClosed() {
_ = c.Close()
return nil
}
go p.handle(c)
}
}
@@ -229,6 +305,7 @@ func (p *Proxy) handle(nc net.Conn) {
ctx = NewContext(nc).(*proxyContext)
err error
)
defer func() {
if r := recover(); r != nil {
if err, ok := r.(error); ok {
@@ -266,6 +343,8 @@ func (p *Proxy) handle(nc net.Conn) {
// Propagate timeouts
ctx.SetIdleTimeout(p.IdleTimeout)
ctx.ca = p.CertificateAuthority
ctx.storage = p.Storage
for _, f := range p.OnConnect {
fc, err := f.HandleConn(ctx)
@@ -282,6 +361,15 @@ func (p *Proxy) handle(nc net.Conn) {
}
log := ctx.LogEntry()
if p.Storage != nil {
if client, err := p.Storage.ClientByIP(nc.RemoteAddr().(*net.TCPAddr).IP); err == nil {
log = log.Values(logger.Values{
"client_id": client.ID,
"client_network": client.String(),
"client_description": client.Description,
})
}
}
for {
if ctx.transparentTLS {
ctx.req = &http.Request{
@@ -344,7 +432,7 @@ func (p *Proxy) handle(nc net.Conn) {
}
if err = p.handleRequest(ctx); err != nil {
p.handleError(ctx, err, true)
p.handleError(ctx, err, !netutil.IsClosing(err))
return
}
@@ -511,7 +599,8 @@ func (p *Proxy) serveForward(ctx *proxyContext) (err error) {
_ = ctx.Close()
return fmt.Errorf("proxy: forward %s error: %w", ctx.req.URL, err)
}
} else {
}
if res != nil {
ctx.res = res
}
@@ -571,28 +660,44 @@ func (p *Proxy) serveWebSocket(ctx *proxyContext) (err error) {
return p.multiplex(ctx, srv)
}
func (p *Proxy) multiplex(ctx, srv Context) (err error) {
func (p *Proxy) multiplex(ctx, srv *proxyContext) (err error) {
var (
log = ctx.LogEntry().Value("server", srv.RemoteAddr().String())
errs = make(chan error, 1)
done = make(chan struct{}, 1)
)
go func(errs chan<- error) {
defer close(done)
if _, err := io.Copy(srv, ctx); err != nil {
if _, err := io.Copy(ctx, srv); err != nil && !netutil.IsClosing(err) {
log.Err(err).Trace("Multiplexing closed in client->server")
errs <- err
} else {
log.Trace("Multiplexing closed in client->server")
}
}(errs)
go func(errs chan<- error) {
if _, err := io.Copy(ctx, srv); err != nil {
defer close(done)
if _, err := io.Copy(srv, ctx); err != nil && !netutil.IsClosing(err) {
log.Err(err).Trace("Multiplexing closed in server->client")
errs <- err
} else {
log.Trace("Multiplexing closed in server->client")
}
}(errs)
defer func() {
log.Trace("Multiplexing done, force-closing client and server connections")
_ = ctx.Close()
_ = srv.Close()
}()
select {
case err = <-errs:
return
case <-done:
return
return io.EOF // multiplexing never recycles connection
case <-p.closed:
return io.EOF // server closed
}
}