Fixed up the documentation and refactored Limit

This commit is contained in:
2025-10-10 14:39:21 +02:00
parent ddb0af36bb
commit 7b5578859e

View File

@@ -15,7 +15,8 @@ type AcceptFunc func(Direction, *Protocol) error
// If no protocol could be detected, the accept function is called with a nil // If no protocol could be detected, the accept function is called with a nil
// argument to check if we should proceed. // argument to check if we should proceed.
// //
// If the accept function returns false, the connection will be closed. // If the accept function returns an error, all future reads and writes on the
// returned connection will return that error.
func Limit(conn net.Conn, accept AcceptFunc) net.Conn { func Limit(conn net.Conn, accept AcceptFunc) net.Conn {
if accept == nil { if accept == nil {
// Nothing to do here. // Nothing to do here.
@@ -39,23 +40,32 @@ func (limiter *connLimiter) init(readData, writeData []byte) {
limiter.acceptOnce.Do(func() { limiter.acceptOnce.Do(func() {
var ( var (
dir Direction dir Direction
data []byte localPort = getPortFromAddr(limiter.LocalAddr())
srcPort, dstPort int remotePort = getPortFromAddr(limiter.RemoteAddr())
protocol *Protocol
) )
if readData != nil { if readData != nil {
// init called by initial read // init called by initial read (from the server)
dir, data, srcPort, dstPort = Server, readData, getPortFromAddr(limiter.LocalAddr()), getPortFromAddr(limiter.RemoteAddr()) dir = Server
protocol, _, _ = Detect(dir, readData, localPort, remotePort)
} else { } else {
// init called by initial write // init called by initial write (from the client)
dir, data, srcPort, dstPort = Client, writeData, getPortFromAddr(limiter.RemoteAddr()), getPortFromAddr(limiter.LocalAddr()) dir = Client
protocol, _, _ = Detect(dir, writeData, remotePort, localPort)
} }
protocol, _, _ := Detect(dir, data, srcPort, dstPort)
if err := limiter.accept(dir, protocol); err != nil { if err := limiter.accept(dir, protocol); err != nil {
limiter.acceptError.Store(err) limiter.acceptError.Store(err)
} }
}) })
} }
// NetConn implements the same method as [net/tls.Conn] to obtain the underlying [net.Conn].
func (limiter *connLimiter) NetConn() net.Conn {
return limiter.Conn
}
// Read from the connection, if this is the first read then the data returned by the underlying
// connection is used for protocol detection.
func (limiter *connLimiter) Read(p []byte) (n int, err error) { func (limiter *connLimiter) Read(p []byte) (n int, err error) {
var ok bool var ok bool
if err, ok = limiter.acceptError.Load().(error); ok && err != nil { if err, ok = limiter.acceptError.Load().(error); ok && err != nil {
@@ -67,6 +77,8 @@ func (limiter *connLimiter) Read(p []byte) (n int, err error) {
return return
} }
// Write to the connection, if this is the first write then the data is used for protocol detection
// before it gets sent to the underlying connection.
func (limiter *connLimiter) Write(p []byte) (n int, err error) { func (limiter *connLimiter) Write(p []byte) (n int, err error) {
limiter.init(nil, p) limiter.init(nil, p)
var ok bool var ok bool