diff --git a/protocol/limit.go b/protocol/limit.go index 251ede6..9b72877 100644 --- a/protocol/limit.go +++ b/protocol/limit.go @@ -15,7 +15,8 @@ type AcceptFunc func(Direction, *Protocol) error // If no protocol could be detected, the accept function is called with a nil // argument to check if we should proceed. // -// If the accept function returns false, the connection will be closed. +// If the accept function returns an error, all future reads and writes on the +// returned connection will return that error. func Limit(conn net.Conn, accept AcceptFunc) net.Conn { if accept == nil { // Nothing to do here. @@ -38,24 +39,33 @@ type connLimiter struct { func (limiter *connLimiter) init(readData, writeData []byte) { limiter.acceptOnce.Do(func() { var ( - dir Direction - data []byte - srcPort, dstPort int + dir Direction + localPort = getPortFromAddr(limiter.LocalAddr()) + remotePort = getPortFromAddr(limiter.RemoteAddr()) + protocol *Protocol ) if readData != nil { - // init called by initial read - dir, data, srcPort, dstPort = Server, readData, getPortFromAddr(limiter.LocalAddr()), getPortFromAddr(limiter.RemoteAddr()) + // init called by initial read (from the server) + dir = Server + protocol, _, _ = Detect(dir, readData, localPort, remotePort) } else { - // init called by initial write - dir, data, srcPort, dstPort = Client, writeData, getPortFromAddr(limiter.RemoteAddr()), getPortFromAddr(limiter.LocalAddr()) + // init called by initial write (from the client) + dir = Client + protocol, _, _ = Detect(dir, writeData, remotePort, localPort) } - protocol, _, _ := Detect(dir, data, srcPort, dstPort) if err := limiter.accept(dir, protocol); err != nil { limiter.acceptError.Store(err) } }) } +// NetConn implements the same method as [net/tls.Conn] to obtain the underlying [net.Conn]. +func (limiter *connLimiter) NetConn() net.Conn { + return limiter.Conn +} + +// Read from the connection, if this is the first read then the data returned by the underlying +// connection is used for protocol detection. func (limiter *connLimiter) Read(p []byte) (n int, err error) { var ok bool if err, ok = limiter.acceptError.Load().(error); ok && err != nil { @@ -67,6 +77,8 @@ func (limiter *connLimiter) Read(p []byte) (n int, err error) { return } +// Write to the connection, if this is the first write then the data is used for protocol detection +// before it gets sent to the underlying connection. func (limiter *connLimiter) Write(p []byte) (n int, err error) { limiter.init(nil, p) var ok bool