package protocol import ( "net" "sync/atomic" "time" ) // Intercepted is the result returned by [Interceptor.Detect]. type Intercepted struct { Direction Direction Protocol *Protocol Error error } // Interceptor intercepts reads from client or server. type Interceptor struct { clientBytes chan []byte clientReader *readInterceptor serverBytes chan []byte serverReader *readInterceptor } // NewInterceptor creates a new (transparent) protocol interceptor. func NewInterceptor() *Interceptor { return &Interceptor{ clientBytes: make(chan []byte, 1), serverBytes: make(chan []byte, 1), } } type readInterceptor struct { net.Conn bytes chan []byte once atomic.Bool } func newReadInterceptor(c net.Conn, bytes chan []byte) *readInterceptor { return &readInterceptor{ Conn: c, bytes: bytes, } } // Cancel any future Read interceptions and closes the channel. func (r *readInterceptor) Cancel() { if r == nil { return } r.once.Store(true) close(r.bytes) } func (r *readInterceptor) Read(p []byte) (n int, err error) { if r.once.CompareAndSwap(false, true) { if n, err = r.Conn.Read(p); n > 0 { // We create a copy, since the Read caller may modify p // immediately after reading. data := make([]byte, n) copy(data, p[:n]) // Buffer the bytes in the channel. r.bytes <- data } return } return r.Conn.Read(p) } // Client binds the client connection to the interceptor. func (i *Interceptor) Client(c net.Conn) net.Conn { if ri, ok := c.(*readInterceptor); ok { return ri } i.clientReader = newReadInterceptor(c, i.clientBytes) return i.clientReader } // Server binds the server connection to the interceptor. func (i *Interceptor) Server(c net.Conn) net.Conn { if ri, ok := c.(*readInterceptor); ok { return ri } i.serverReader = newReadInterceptor(c, i.serverBytes) return i.serverReader } // Detect runs protocol detection on the previously bound Client and Server connection. // // It waits until either the client or the server performs a read operation, // which is then used for running protocol detection. If the read operation // takes longer than timeout, an error is returned. // // The returned channel always yields one result and is then closed. func (i *Interceptor) Detect(timeout time.Duration) <-chan *Intercepted { var interceptc = make(chan *Intercepted, 1) go func() { // Make sure all channels are closed once we finish processing. defer close(interceptc) defer i.clientReader.Cancel() defer i.serverReader.Cancel() select { case <-time.After(timeout): // timeout interceptc <- &Intercepted{ Error: ErrTimeout, } case data := <-i.clientBytes: // client sent banner p, err := Detect(Client, data) interceptc <- &Intercepted{ Direction: Client, Protocol: p, Error: err, } case data := <-i.serverBytes: // server sent banner p, err := Detect(Server, data) interceptc <- &Intercepted{ Direction: Server, Protocol: p, Error: err, } } }() return interceptc }