Checkpoint
This commit is contained in:
121
proxy/proxy.go
121
proxy/proxy.go
@@ -15,9 +15,12 @@ import (
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"git.maze.io/maze/styx/ca"
|
||||
"git.maze.io/maze/styx/dataset"
|
||||
"git.maze.io/maze/styx/internal/netutil"
|
||||
"git.maze.io/maze/styx/logger"
|
||||
"git.maze.io/maze/styx/stats"
|
||||
@@ -26,6 +29,7 @@ import (
|
||||
// Common HTTP headers.
|
||||
const (
|
||||
HeaderConnection = "Connection"
|
||||
HeaderContentLength = "Content-Length"
|
||||
HeaderContentType = "Content-Type"
|
||||
HeaderDate = "Date"
|
||||
HeaderForwarded = "Forwarded"
|
||||
@@ -146,7 +150,17 @@ type Proxy struct {
|
||||
// WebSocketIdleTimeout is the timeout for idle WebSocket connections.
|
||||
WebSocketIdleTimeout time.Duration
|
||||
|
||||
mux *http.ServeMux
|
||||
// CertificateAuthority can issue certificates for man-in-the-middle connections.
|
||||
CertificateAuthority ca.CertificateAuthority
|
||||
|
||||
// Storage for resolving clients/groups
|
||||
Storage dataset.Storage
|
||||
|
||||
mux *http.ServeMux
|
||||
closed chan struct{}
|
||||
closeOnce sync.Once
|
||||
mu sync.RWMutex
|
||||
listeners []net.Listener
|
||||
}
|
||||
|
||||
// New [Proxy] with somewhat sane defaults.
|
||||
@@ -157,6 +171,7 @@ func New() *Proxy {
|
||||
IdleTimeout: DefaultIdleTimeout,
|
||||
WebSocketIdleTimeout: DefaultWebSocketIdleTimeout,
|
||||
mux: http.NewServeMux(),
|
||||
closed: make(chan struct{}, 1),
|
||||
}
|
||||
|
||||
// Make sure the roundtripper uses our dialers.
|
||||
@@ -181,6 +196,55 @@ func New() *Proxy {
|
||||
return p
|
||||
}
|
||||
|
||||
func (p *Proxy) Close() error {
|
||||
var closeListeners bool
|
||||
p.closeOnce.Do(func() {
|
||||
close(p.closed)
|
||||
closeListeners = true
|
||||
})
|
||||
if closeListeners {
|
||||
p.mu.RLock()
|
||||
for _, l := range p.listeners {
|
||||
_ = l.Close()
|
||||
}
|
||||
p.mu.RUnlock()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Proxy) isClosed() bool {
|
||||
select {
|
||||
case <-p.closed:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Proxy) addListener(l net.Listener) {
|
||||
if l == nil {
|
||||
return
|
||||
}
|
||||
p.mu.Lock()
|
||||
p.listeners = append(p.listeners, l)
|
||||
p.mu.Unlock()
|
||||
}
|
||||
|
||||
func (p *Proxy) removeListener(l net.Listener) {
|
||||
if l == nil {
|
||||
return
|
||||
}
|
||||
p.mu.Lock()
|
||||
listeners := make([]net.Listener, 0, len(p.listeners)-1)
|
||||
for _, o := range p.listeners {
|
||||
if o != l {
|
||||
listeners = append(listeners, o)
|
||||
}
|
||||
}
|
||||
p.listeners = listeners
|
||||
p.mu.Unlock()
|
||||
}
|
||||
|
||||
// Handle installs a [http.Handler] into the internal mux.
|
||||
func (p *Proxy) Handle(pattern string, handler http.Handler) {
|
||||
p.mux.Handle(pattern, handler)
|
||||
@@ -214,11 +278,23 @@ func (p *Proxy) dial(ctx context.Context, req *http.Request) (net.Conn, error) {
|
||||
|
||||
// Serve proxied connections on the specified listener.
|
||||
func (p *Proxy) Serve(l net.Listener) error {
|
||||
p.addListener(l)
|
||||
defer p.removeListener(l)
|
||||
for {
|
||||
if p.isClosed() {
|
||||
return nil
|
||||
}
|
||||
|
||||
c, err := l.Accept()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if p.isClosed() {
|
||||
_ = c.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
go p.handle(c)
|
||||
}
|
||||
}
|
||||
@@ -229,6 +305,7 @@ func (p *Proxy) handle(nc net.Conn) {
|
||||
ctx = NewContext(nc).(*proxyContext)
|
||||
err error
|
||||
)
|
||||
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
if err, ok := r.(error); ok {
|
||||
@@ -266,6 +343,8 @@ func (p *Proxy) handle(nc net.Conn) {
|
||||
|
||||
// Propagate timeouts
|
||||
ctx.SetIdleTimeout(p.IdleTimeout)
|
||||
ctx.ca = p.CertificateAuthority
|
||||
ctx.storage = p.Storage
|
||||
|
||||
for _, f := range p.OnConnect {
|
||||
fc, err := f.HandleConn(ctx)
|
||||
@@ -282,6 +361,15 @@ func (p *Proxy) handle(nc net.Conn) {
|
||||
}
|
||||
|
||||
log := ctx.LogEntry()
|
||||
if p.Storage != nil {
|
||||
if client, err := p.Storage.ClientByIP(nc.RemoteAddr().(*net.TCPAddr).IP); err == nil {
|
||||
log = log.Values(logger.Values{
|
||||
"client_id": client.ID,
|
||||
"client_network": client.String(),
|
||||
"client_description": client.Description,
|
||||
})
|
||||
}
|
||||
}
|
||||
for {
|
||||
if ctx.transparentTLS {
|
||||
ctx.req = &http.Request{
|
||||
@@ -344,7 +432,7 @@ func (p *Proxy) handle(nc net.Conn) {
|
||||
}
|
||||
|
||||
if err = p.handleRequest(ctx); err != nil {
|
||||
p.handleError(ctx, err, true)
|
||||
p.handleError(ctx, err, !netutil.IsClosing(err))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -511,7 +599,8 @@ func (p *Proxy) serveForward(ctx *proxyContext) (err error) {
|
||||
_ = ctx.Close()
|
||||
return fmt.Errorf("proxy: forward %s error: %w", ctx.req.URL, err)
|
||||
}
|
||||
} else {
|
||||
}
|
||||
if res != nil {
|
||||
ctx.res = res
|
||||
}
|
||||
|
||||
@@ -571,28 +660,44 @@ func (p *Proxy) serveWebSocket(ctx *proxyContext) (err error) {
|
||||
return p.multiplex(ctx, srv)
|
||||
}
|
||||
|
||||
func (p *Proxy) multiplex(ctx, srv Context) (err error) {
|
||||
func (p *Proxy) multiplex(ctx, srv *proxyContext) (err error) {
|
||||
var (
|
||||
log = ctx.LogEntry().Value("server", srv.RemoteAddr().String())
|
||||
errs = make(chan error, 1)
|
||||
done = make(chan struct{}, 1)
|
||||
)
|
||||
go func(errs chan<- error) {
|
||||
defer close(done)
|
||||
if _, err := io.Copy(srv, ctx); err != nil {
|
||||
if _, err := io.Copy(ctx, srv); err != nil && !netutil.IsClosing(err) {
|
||||
log.Err(err).Trace("Multiplexing closed in client->server")
|
||||
errs <- err
|
||||
} else {
|
||||
log.Trace("Multiplexing closed in client->server")
|
||||
}
|
||||
}(errs)
|
||||
|
||||
go func(errs chan<- error) {
|
||||
if _, err := io.Copy(ctx, srv); err != nil {
|
||||
defer close(done)
|
||||
if _, err := io.Copy(srv, ctx); err != nil && !netutil.IsClosing(err) {
|
||||
log.Err(err).Trace("Multiplexing closed in server->client")
|
||||
errs <- err
|
||||
} else {
|
||||
log.Trace("Multiplexing closed in server->client")
|
||||
}
|
||||
}(errs)
|
||||
|
||||
defer func() {
|
||||
log.Trace("Multiplexing done, force-closing client and server connections")
|
||||
_ = ctx.Close()
|
||||
_ = srv.Close()
|
||||
}()
|
||||
|
||||
select {
|
||||
case err = <-errs:
|
||||
return
|
||||
case <-done:
|
||||
return
|
||||
return io.EOF // multiplexing never recycles connection
|
||||
case <-p.closed:
|
||||
return io.EOF // server closed
|
||||
}
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user