package protocol import ( "net" "sync" "sync/atomic" ) // AcceptFunc receives a direction and a detected protocol. type AcceptFunc func(Direction, *Protocol) error // Limit the connection protocol, by running a detection after either side sends // a banner within timeout. // // If no protocol could be detected, the accept function is called with a nil // argument to check if we should proceed. // // If the accept function returns an error, all future reads and writes on the // returned connection will return that error. func Limit(conn net.Conn, accept AcceptFunc) net.Conn { if accept == nil { // Nothing to do here. return conn } return &connLimiter{ Conn: conn, accept: accept, } } type connLimiter struct { net.Conn accept AcceptFunc acceptOnce sync.Once acceptError atomic.Value } func (limiter *connLimiter) init(readData, writeData []byte) { limiter.acceptOnce.Do(func() { var ( dir Direction localPort = getPortFromAddr(limiter.LocalAddr()) remotePort = getPortFromAddr(limiter.RemoteAddr()) protocol *Protocol ) if readData != nil { // init called by initial read (from the server) dir = Server protocol, _, _ = Detect(dir, readData, localPort, remotePort) } else { // init called by initial write (from the client) dir = Client protocol, _, _ = Detect(dir, writeData, remotePort, localPort) } if err := limiter.accept(dir, protocol); err != nil { limiter.acceptError.Store(err) } }) } // NetConn implements the same method as [net/tls.Conn] to obtain the underlying [net.Conn]. func (limiter *connLimiter) NetConn() net.Conn { return limiter.Conn } // Read from the connection, if this is the first read then the data returned by the underlying // connection is used for protocol detection. func (limiter *connLimiter) Read(p []byte) (n int, err error) { var ok bool if err, ok = limiter.acceptError.Load().(error); ok && err != nil { return } if n, err = limiter.Conn.Read(p); n > 0 { limiter.init(p[:n], nil) } return } // Write to the connection, if this is the first write then the data is used for protocol detection // before it gets sent to the underlying connection. func (limiter *connLimiter) Write(p []byte) (n int, err error) { limiter.init(nil, p) var ok bool if err, ok = limiter.acceptError.Load().(error); ok && err != nil { return } return limiter.Conn.Write(p) }