Checkpoint

This commit is contained in:
2025-10-06 22:25:23 +02:00
parent a23259cfdc
commit a254b306f2
48 changed files with 3327 additions and 212 deletions

View File

@@ -14,6 +14,8 @@ import (
"sync/atomic"
"time"
"git.maze.io/maze/styx/ca"
"git.maze.io/maze/styx/dataset"
"git.maze.io/maze/styx/logger"
)
@@ -42,6 +44,13 @@ type Context interface {
// Response is the response that will be sent back to the client.
Response() *http.Response
// Client group.
Client() (dataset.Client, error)
}
type WithCertificateAuthority interface {
CertificateAuthority() ca.CertificateAuthority
}
type countingReader struct {
@@ -80,6 +89,9 @@ type proxyContext struct {
req *http.Request
res *http.Response
idleTimeout time.Duration
ca ca.CertificateAuthority
storage dataset.Storage
client dataset.Client
}
// NewContext returns an initialized context for the provided [net.Conn].
@@ -218,4 +230,28 @@ func (c *proxyContext) WriteHeader(code int) {
//return c.res.Header.Write(c)
}
func (c *proxyContext) CertificateAuthority() ca.CertificateAuthority {
return c.ca
}
func (c *proxyContext) Client() (dataset.Client, error) {
if c.storage == nil {
return dataset.Client{}, dataset.ErrNotExist{Object: "client"}
}
if !c.client.CreatedAt.Equal(time.Time{}) {
return c.client, nil
}
var err error
switch addr := c.Conn.RemoteAddr().(type) {
case *net.TCPAddr:
c.client, err = c.storage.ClientByIP(addr.IP)
case *net.UDPAddr:
c.client, err = c.storage.ClientByIP(addr.IP)
default:
err = dataset.ErrNotExist{Object: "client"}
}
return c.client, err
}
var _ Context = (*proxyContext)(nil)

View File

@@ -15,9 +15,12 @@ import (
"slices"
"strconv"
"strings"
"sync"
"syscall"
"time"
"git.maze.io/maze/styx/ca"
"git.maze.io/maze/styx/dataset"
"git.maze.io/maze/styx/internal/netutil"
"git.maze.io/maze/styx/logger"
"git.maze.io/maze/styx/stats"
@@ -26,6 +29,7 @@ import (
// Common HTTP headers.
const (
HeaderConnection = "Connection"
HeaderContentLength = "Content-Length"
HeaderContentType = "Content-Type"
HeaderDate = "Date"
HeaderForwarded = "Forwarded"
@@ -146,7 +150,17 @@ type Proxy struct {
// WebSocketIdleTimeout is the timeout for idle WebSocket connections.
WebSocketIdleTimeout time.Duration
mux *http.ServeMux
// CertificateAuthority can issue certificates for man-in-the-middle connections.
CertificateAuthority ca.CertificateAuthority
// Storage for resolving clients/groups
Storage dataset.Storage
mux *http.ServeMux
closed chan struct{}
closeOnce sync.Once
mu sync.RWMutex
listeners []net.Listener
}
// New [Proxy] with somewhat sane defaults.
@@ -157,6 +171,7 @@ func New() *Proxy {
IdleTimeout: DefaultIdleTimeout,
WebSocketIdleTimeout: DefaultWebSocketIdleTimeout,
mux: http.NewServeMux(),
closed: make(chan struct{}, 1),
}
// Make sure the roundtripper uses our dialers.
@@ -181,6 +196,55 @@ func New() *Proxy {
return p
}
func (p *Proxy) Close() error {
var closeListeners bool
p.closeOnce.Do(func() {
close(p.closed)
closeListeners = true
})
if closeListeners {
p.mu.RLock()
for _, l := range p.listeners {
_ = l.Close()
}
p.mu.RUnlock()
}
return nil
}
func (p *Proxy) isClosed() bool {
select {
case <-p.closed:
return true
default:
return false
}
}
func (p *Proxy) addListener(l net.Listener) {
if l == nil {
return
}
p.mu.Lock()
p.listeners = append(p.listeners, l)
p.mu.Unlock()
}
func (p *Proxy) removeListener(l net.Listener) {
if l == nil {
return
}
p.mu.Lock()
listeners := make([]net.Listener, 0, len(p.listeners)-1)
for _, o := range p.listeners {
if o != l {
listeners = append(listeners, o)
}
}
p.listeners = listeners
p.mu.Unlock()
}
// Handle installs a [http.Handler] into the internal mux.
func (p *Proxy) Handle(pattern string, handler http.Handler) {
p.mux.Handle(pattern, handler)
@@ -214,11 +278,23 @@ func (p *Proxy) dial(ctx context.Context, req *http.Request) (net.Conn, error) {
// Serve proxied connections on the specified listener.
func (p *Proxy) Serve(l net.Listener) error {
p.addListener(l)
defer p.removeListener(l)
for {
if p.isClosed() {
return nil
}
c, err := l.Accept()
if err != nil {
return err
}
if p.isClosed() {
_ = c.Close()
return nil
}
go p.handle(c)
}
}
@@ -229,6 +305,7 @@ func (p *Proxy) handle(nc net.Conn) {
ctx = NewContext(nc).(*proxyContext)
err error
)
defer func() {
if r := recover(); r != nil {
if err, ok := r.(error); ok {
@@ -266,6 +343,8 @@ func (p *Proxy) handle(nc net.Conn) {
// Propagate timeouts
ctx.SetIdleTimeout(p.IdleTimeout)
ctx.ca = p.CertificateAuthority
ctx.storage = p.Storage
for _, f := range p.OnConnect {
fc, err := f.HandleConn(ctx)
@@ -282,6 +361,15 @@ func (p *Proxy) handle(nc net.Conn) {
}
log := ctx.LogEntry()
if p.Storage != nil {
if client, err := p.Storage.ClientByIP(nc.RemoteAddr().(*net.TCPAddr).IP); err == nil {
log = log.Values(logger.Values{
"client_id": client.ID,
"client_network": client.String(),
"client_description": client.Description,
})
}
}
for {
if ctx.transparentTLS {
ctx.req = &http.Request{
@@ -344,7 +432,7 @@ func (p *Proxy) handle(nc net.Conn) {
}
if err = p.handleRequest(ctx); err != nil {
p.handleError(ctx, err, true)
p.handleError(ctx, err, !netutil.IsClosing(err))
return
}
@@ -511,7 +599,8 @@ func (p *Proxy) serveForward(ctx *proxyContext) (err error) {
_ = ctx.Close()
return fmt.Errorf("proxy: forward %s error: %w", ctx.req.URL, err)
}
} else {
}
if res != nil {
ctx.res = res
}
@@ -571,28 +660,44 @@ func (p *Proxy) serveWebSocket(ctx *proxyContext) (err error) {
return p.multiplex(ctx, srv)
}
func (p *Proxy) multiplex(ctx, srv Context) (err error) {
func (p *Proxy) multiplex(ctx, srv *proxyContext) (err error) {
var (
log = ctx.LogEntry().Value("server", srv.RemoteAddr().String())
errs = make(chan error, 1)
done = make(chan struct{}, 1)
)
go func(errs chan<- error) {
defer close(done)
if _, err := io.Copy(srv, ctx); err != nil {
if _, err := io.Copy(ctx, srv); err != nil && !netutil.IsClosing(err) {
log.Err(err).Trace("Multiplexing closed in client->server")
errs <- err
} else {
log.Trace("Multiplexing closed in client->server")
}
}(errs)
go func(errs chan<- error) {
if _, err := io.Copy(ctx, srv); err != nil {
defer close(done)
if _, err := io.Copy(srv, ctx); err != nil && !netutil.IsClosing(err) {
log.Err(err).Trace("Multiplexing closed in server->client")
errs <- err
} else {
log.Trace("Multiplexing closed in server->client")
}
}(errs)
defer func() {
log.Trace("Multiplexing done, force-closing client and server connections")
_ = ctx.Close()
_ = srv.Close()
}()
select {
case err = <-errs:
return
case <-done:
return
return io.EOF // multiplexing never recycles connection
case <-p.closed:
return io.EOF // server closed
}
}