Checkpoint
This commit is contained in:
@@ -14,6 +14,8 @@ import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"git.maze.io/maze/styx/ca"
|
||||
"git.maze.io/maze/styx/dataset"
|
||||
"git.maze.io/maze/styx/logger"
|
||||
)
|
||||
|
||||
@@ -42,6 +44,13 @@ type Context interface {
|
||||
|
||||
// Response is the response that will be sent back to the client.
|
||||
Response() *http.Response
|
||||
|
||||
// Client group.
|
||||
Client() (dataset.Client, error)
|
||||
}
|
||||
|
||||
type WithCertificateAuthority interface {
|
||||
CertificateAuthority() ca.CertificateAuthority
|
||||
}
|
||||
|
||||
type countingReader struct {
|
||||
@@ -80,6 +89,9 @@ type proxyContext struct {
|
||||
req *http.Request
|
||||
res *http.Response
|
||||
idleTimeout time.Duration
|
||||
ca ca.CertificateAuthority
|
||||
storage dataset.Storage
|
||||
client dataset.Client
|
||||
}
|
||||
|
||||
// NewContext returns an initialized context for the provided [net.Conn].
|
||||
@@ -218,4 +230,28 @@ func (c *proxyContext) WriteHeader(code int) {
|
||||
//return c.res.Header.Write(c)
|
||||
}
|
||||
|
||||
func (c *proxyContext) CertificateAuthority() ca.CertificateAuthority {
|
||||
return c.ca
|
||||
}
|
||||
|
||||
func (c *proxyContext) Client() (dataset.Client, error) {
|
||||
if c.storage == nil {
|
||||
return dataset.Client{}, dataset.ErrNotExist{Object: "client"}
|
||||
}
|
||||
if !c.client.CreatedAt.Equal(time.Time{}) {
|
||||
return c.client, nil
|
||||
}
|
||||
|
||||
var err error
|
||||
switch addr := c.Conn.RemoteAddr().(type) {
|
||||
case *net.TCPAddr:
|
||||
c.client, err = c.storage.ClientByIP(addr.IP)
|
||||
case *net.UDPAddr:
|
||||
c.client, err = c.storage.ClientByIP(addr.IP)
|
||||
default:
|
||||
err = dataset.ErrNotExist{Object: "client"}
|
||||
}
|
||||
return c.client, err
|
||||
}
|
||||
|
||||
var _ Context = (*proxyContext)(nil)
|
||||
|
121
proxy/proxy.go
121
proxy/proxy.go
@@ -15,9 +15,12 @@ import (
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"git.maze.io/maze/styx/ca"
|
||||
"git.maze.io/maze/styx/dataset"
|
||||
"git.maze.io/maze/styx/internal/netutil"
|
||||
"git.maze.io/maze/styx/logger"
|
||||
"git.maze.io/maze/styx/stats"
|
||||
@@ -26,6 +29,7 @@ import (
|
||||
// Common HTTP headers.
|
||||
const (
|
||||
HeaderConnection = "Connection"
|
||||
HeaderContentLength = "Content-Length"
|
||||
HeaderContentType = "Content-Type"
|
||||
HeaderDate = "Date"
|
||||
HeaderForwarded = "Forwarded"
|
||||
@@ -146,7 +150,17 @@ type Proxy struct {
|
||||
// WebSocketIdleTimeout is the timeout for idle WebSocket connections.
|
||||
WebSocketIdleTimeout time.Duration
|
||||
|
||||
mux *http.ServeMux
|
||||
// CertificateAuthority can issue certificates for man-in-the-middle connections.
|
||||
CertificateAuthority ca.CertificateAuthority
|
||||
|
||||
// Storage for resolving clients/groups
|
||||
Storage dataset.Storage
|
||||
|
||||
mux *http.ServeMux
|
||||
closed chan struct{}
|
||||
closeOnce sync.Once
|
||||
mu sync.RWMutex
|
||||
listeners []net.Listener
|
||||
}
|
||||
|
||||
// New [Proxy] with somewhat sane defaults.
|
||||
@@ -157,6 +171,7 @@ func New() *Proxy {
|
||||
IdleTimeout: DefaultIdleTimeout,
|
||||
WebSocketIdleTimeout: DefaultWebSocketIdleTimeout,
|
||||
mux: http.NewServeMux(),
|
||||
closed: make(chan struct{}, 1),
|
||||
}
|
||||
|
||||
// Make sure the roundtripper uses our dialers.
|
||||
@@ -181,6 +196,55 @@ func New() *Proxy {
|
||||
return p
|
||||
}
|
||||
|
||||
func (p *Proxy) Close() error {
|
||||
var closeListeners bool
|
||||
p.closeOnce.Do(func() {
|
||||
close(p.closed)
|
||||
closeListeners = true
|
||||
})
|
||||
if closeListeners {
|
||||
p.mu.RLock()
|
||||
for _, l := range p.listeners {
|
||||
_ = l.Close()
|
||||
}
|
||||
p.mu.RUnlock()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Proxy) isClosed() bool {
|
||||
select {
|
||||
case <-p.closed:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Proxy) addListener(l net.Listener) {
|
||||
if l == nil {
|
||||
return
|
||||
}
|
||||
p.mu.Lock()
|
||||
p.listeners = append(p.listeners, l)
|
||||
p.mu.Unlock()
|
||||
}
|
||||
|
||||
func (p *Proxy) removeListener(l net.Listener) {
|
||||
if l == nil {
|
||||
return
|
||||
}
|
||||
p.mu.Lock()
|
||||
listeners := make([]net.Listener, 0, len(p.listeners)-1)
|
||||
for _, o := range p.listeners {
|
||||
if o != l {
|
||||
listeners = append(listeners, o)
|
||||
}
|
||||
}
|
||||
p.listeners = listeners
|
||||
p.mu.Unlock()
|
||||
}
|
||||
|
||||
// Handle installs a [http.Handler] into the internal mux.
|
||||
func (p *Proxy) Handle(pattern string, handler http.Handler) {
|
||||
p.mux.Handle(pattern, handler)
|
||||
@@ -214,11 +278,23 @@ func (p *Proxy) dial(ctx context.Context, req *http.Request) (net.Conn, error) {
|
||||
|
||||
// Serve proxied connections on the specified listener.
|
||||
func (p *Proxy) Serve(l net.Listener) error {
|
||||
p.addListener(l)
|
||||
defer p.removeListener(l)
|
||||
for {
|
||||
if p.isClosed() {
|
||||
return nil
|
||||
}
|
||||
|
||||
c, err := l.Accept()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if p.isClosed() {
|
||||
_ = c.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
go p.handle(c)
|
||||
}
|
||||
}
|
||||
@@ -229,6 +305,7 @@ func (p *Proxy) handle(nc net.Conn) {
|
||||
ctx = NewContext(nc).(*proxyContext)
|
||||
err error
|
||||
)
|
||||
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
if err, ok := r.(error); ok {
|
||||
@@ -266,6 +343,8 @@ func (p *Proxy) handle(nc net.Conn) {
|
||||
|
||||
// Propagate timeouts
|
||||
ctx.SetIdleTimeout(p.IdleTimeout)
|
||||
ctx.ca = p.CertificateAuthority
|
||||
ctx.storage = p.Storage
|
||||
|
||||
for _, f := range p.OnConnect {
|
||||
fc, err := f.HandleConn(ctx)
|
||||
@@ -282,6 +361,15 @@ func (p *Proxy) handle(nc net.Conn) {
|
||||
}
|
||||
|
||||
log := ctx.LogEntry()
|
||||
if p.Storage != nil {
|
||||
if client, err := p.Storage.ClientByIP(nc.RemoteAddr().(*net.TCPAddr).IP); err == nil {
|
||||
log = log.Values(logger.Values{
|
||||
"client_id": client.ID,
|
||||
"client_network": client.String(),
|
||||
"client_description": client.Description,
|
||||
})
|
||||
}
|
||||
}
|
||||
for {
|
||||
if ctx.transparentTLS {
|
||||
ctx.req = &http.Request{
|
||||
@@ -344,7 +432,7 @@ func (p *Proxy) handle(nc net.Conn) {
|
||||
}
|
||||
|
||||
if err = p.handleRequest(ctx); err != nil {
|
||||
p.handleError(ctx, err, true)
|
||||
p.handleError(ctx, err, !netutil.IsClosing(err))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -511,7 +599,8 @@ func (p *Proxy) serveForward(ctx *proxyContext) (err error) {
|
||||
_ = ctx.Close()
|
||||
return fmt.Errorf("proxy: forward %s error: %w", ctx.req.URL, err)
|
||||
}
|
||||
} else {
|
||||
}
|
||||
if res != nil {
|
||||
ctx.res = res
|
||||
}
|
||||
|
||||
@@ -571,28 +660,44 @@ func (p *Proxy) serveWebSocket(ctx *proxyContext) (err error) {
|
||||
return p.multiplex(ctx, srv)
|
||||
}
|
||||
|
||||
func (p *Proxy) multiplex(ctx, srv Context) (err error) {
|
||||
func (p *Proxy) multiplex(ctx, srv *proxyContext) (err error) {
|
||||
var (
|
||||
log = ctx.LogEntry().Value("server", srv.RemoteAddr().String())
|
||||
errs = make(chan error, 1)
|
||||
done = make(chan struct{}, 1)
|
||||
)
|
||||
go func(errs chan<- error) {
|
||||
defer close(done)
|
||||
if _, err := io.Copy(srv, ctx); err != nil {
|
||||
if _, err := io.Copy(ctx, srv); err != nil && !netutil.IsClosing(err) {
|
||||
log.Err(err).Trace("Multiplexing closed in client->server")
|
||||
errs <- err
|
||||
} else {
|
||||
log.Trace("Multiplexing closed in client->server")
|
||||
}
|
||||
}(errs)
|
||||
|
||||
go func(errs chan<- error) {
|
||||
if _, err := io.Copy(ctx, srv); err != nil {
|
||||
defer close(done)
|
||||
if _, err := io.Copy(srv, ctx); err != nil && !netutil.IsClosing(err) {
|
||||
log.Err(err).Trace("Multiplexing closed in server->client")
|
||||
errs <- err
|
||||
} else {
|
||||
log.Trace("Multiplexing closed in server->client")
|
||||
}
|
||||
}(errs)
|
||||
|
||||
defer func() {
|
||||
log.Trace("Multiplexing done, force-closing client and server connections")
|
||||
_ = ctx.Close()
|
||||
_ = srv.Close()
|
||||
}()
|
||||
|
||||
select {
|
||||
case err = <-errs:
|
||||
return
|
||||
case <-done:
|
||||
return
|
||||
return io.EOF // multiplexing never recycles connection
|
||||
case <-p.closed:
|
||||
return io.EOF // server closed
|
||||
}
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user