package protocol

import (
	"net"
	"strconv"
	"sync"
	"sync/atomic"
	"time"
)

// Intercepted is the result returned by [Interceptor.Detect].
type Intercepted struct {
	Direction  Direction
	Protocol   *Protocol
	Confidence float64
	Error      error
}

// Interceptor intercepts reads from client or server.
//
// The first successful Read on each bound connection is copied into a
// per-direction channel, which [Interceptor.Detect] consumes to run protocol
// detection. All other reads pass through to the underlying connection.
type Interceptor struct {
	clientPort   int
	clientBytes  chan []byte
	clientReader *readInterceptor

	serverPort   int
	serverBytes  chan []byte
	serverReader *readInterceptor
}

// NewInterceptor creates a new (transparent) protocol interceptor.
func NewInterceptor() *Interceptor {
	return &Interceptor{
		// Capacity 1: a readInterceptor publishes at most one buffer, so that
		// publish does not have to wait for Detect to be receiving.
		clientBytes: make(chan []byte, 1),
		serverBytes: make(chan []byte, 1),
	}
}

// readInterceptor wraps a net.Conn and publishes a copy of the bytes returned
// by its first successful Read to bytes. Every other method, and every later
// Read, is served by the embedded Conn.
type readInterceptor struct {
	net.Conn
	bytes chan []byte

	// once is set when the first Read claims the interception, or when Cancel
	// is called; after that Read goes straight to the embedded Conn.
	once atomic.Bool

	// mu guards closed and serializes the publishing send in Read with
	// close(bytes) in Cancel. Without it, a Read that is still blocked in
	// Conn.Read when Cancel runs would later send on a closed channel and panic.
	mu     sync.Mutex
	closed bool
}

func newReadInterceptor(c net.Conn, bytes chan []byte) *readInterceptor {
	return &readInterceptor{
		Conn:  c,
		bytes: bytes,
	}
}

// Cancel disables any future Read interception and closes the channel.
//
// It is safe to call on a nil receiver and safe to call more than once; only
// the first call closes the channel.
func (r *readInterceptor) Cancel() {
	if r == nil {
		return
	}
	r.once.Store(true)

	r.mu.Lock()
	defer r.mu.Unlock()
	if r.closed {
		return
	}
	r.closed = true
	close(r.bytes)
}

// Read reads from the embedded Conn. The first call additionally publishes a
// copy of the bytes it read (if any) to the interceptor channel.
func (r *readInterceptor) Read(p []byte) (n int, err error) {
	if !r.once.CompareAndSwap(false, true) {
		return r.Conn.Read(p)
	}

	n, err = r.Conn.Read(p)
	if n > 0 {
		// We create a copy, since the Read caller may modify p
		// immediately after reading.
		data := make([]byte, n)
		copy(data, p[:n])
		r.publish(data)
	}
	return n, err
}

// publish hands data to the interceptor channel unless Cancel has already
// closed it. The send is non-blocking so the lock is never held while waiting
// on a full buffer (e.g. if another reader already filled the shared channel);
// in that case data is dropped.
func (r *readInterceptor) publish(data []byte) {
	r.mu.Lock()
	defer r.mu.Unlock()
	if r.closed {
		return
	}
	select {
	case r.bytes <- data:
	default:
	}
}

// Client binds the client connection to the interceptor.
//
// If c is already an intercepted connection it is returned unchanged and the
// interceptor's client binding is left as is.
func (i *Interceptor) Client(c net.Conn) net.Conn {
	if ri, ok := c.(*readInterceptor); ok {
		return ri
	}
	i.clientPort = getPortFromAddr(c.RemoteAddr())
	i.clientReader = newReadInterceptor(c, i.clientBytes)
	return i.clientReader
}
func (i *Interceptor) Server(c net.Conn) net.Conn { if ri, ok := c.(*readInterceptor); ok { return ri } i.serverPort = getPortFromAddr(c.RemoteAddr()) i.serverReader = newReadInterceptor(c, i.serverBytes) return i.serverReader } // Detect runs protocol detection on the previously bound Client and Server connection. // // It waits until either the client or the server performs a read operation, // which is then used for running protocol detection. If the read operation // takes longer than timeout, an error is returned. // // The returned channel always yields one result and is then closed. func (i *Interceptor) Detect(timeout time.Duration) <-chan *Intercepted { var interceptc = make(chan *Intercepted, 1) go func() { // Make sure all channels are closed once we finish processing. defer close(interceptc) defer i.clientReader.Cancel() defer i.serverReader.Cancel() select { case <-time.After(timeout): // timeout interceptc <- &Intercepted{ Error: ErrTimeout, } case data := <-i.clientBytes: // client sent banner p, c, err := Detect(Client, data, i.clientPort, i.serverPort) interceptc <- &Intercepted{ Direction: Client, Protocol: p, Confidence: c, Error: err, } case data := <-i.serverBytes: // server sent banner p, c, err := Detect(Server, data, i.serverPort, i.clientPort) interceptc <- &Intercepted{ Direction: Server, Protocol: p, Confidence: c, Error: err, } } }() return interceptc } func getPortFromAddr(addr net.Addr) int { switch a := addr.(type) { case *net.TCPAddr: return a.Port case *net.UDPAddr: return a.Port case *net.IPAddr: // IPAddr doesn't have a port return 0 default: // Fallback to parsing _, service, err := net.SplitHostPort(addr.String()) if err != nil { return 0 } if port, err := strconv.Atoi(service); err == nil { return port } if port, err := net.LookupPort(addr.Network(), service); err == nil { return port } return 0 } }