Fixed up the documentation and refactored Limit
This commit is contained in:
@@ -15,7 +15,8 @@ type AcceptFunc func(Direction, *Protocol) error
|
||||
// If no protocol could be detected, the accept function is called with a nil
|
||||
// argument to check if we should proceed.
|
||||
//
|
||||
// If the accept function returns false, the connection will be closed.
|
||||
// If the accept function returns an error, all future reads and writes on the
|
||||
// returned connection will return that error.
|
||||
func Limit(conn net.Conn, accept AcceptFunc) net.Conn {
|
||||
if accept == nil {
|
||||
// Nothing to do here.
|
||||
@@ -39,23 +40,32 @@ func (limiter *connLimiter) init(readData, writeData []byte) {
|
||||
limiter.acceptOnce.Do(func() {
|
||||
var (
|
||||
dir Direction
|
||||
data []byte
|
||||
srcPort, dstPort int
|
||||
localPort = getPortFromAddr(limiter.LocalAddr())
|
||||
remotePort = getPortFromAddr(limiter.RemoteAddr())
|
||||
protocol *Protocol
|
||||
)
|
||||
if readData != nil {
|
||||
// init called by initial read
|
||||
dir, data, srcPort, dstPort = Server, readData, getPortFromAddr(limiter.LocalAddr()), getPortFromAddr(limiter.RemoteAddr())
|
||||
// init called by initial read (from the server)
|
||||
dir = Server
|
||||
protocol, _, _ = Detect(dir, readData, localPort, remotePort)
|
||||
} else {
|
||||
// init called by initial write
|
||||
dir, data, srcPort, dstPort = Client, writeData, getPortFromAddr(limiter.RemoteAddr()), getPortFromAddr(limiter.LocalAddr())
|
||||
// init called by initial write (from the client)
|
||||
dir = Client
|
||||
protocol, _, _ = Detect(dir, writeData, remotePort, localPort)
|
||||
}
|
||||
protocol, _, _ := Detect(dir, data, srcPort, dstPort)
|
||||
if err := limiter.accept(dir, protocol); err != nil {
|
||||
limiter.acceptError.Store(err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// NetConn implements the same method as [net/tls.Conn] to obtain the underlying [net.Conn].
|
||||
func (limiter *connLimiter) NetConn() net.Conn {
|
||||
return limiter.Conn
|
||||
}
|
||||
|
||||
// Read from the connection, if this is the first read then the data returned by the underlying
|
||||
// connection is used for protocol detection.
|
||||
func (limiter *connLimiter) Read(p []byte) (n int, err error) {
|
||||
var ok bool
|
||||
if err, ok = limiter.acceptError.Load().(error); ok && err != nil {
|
||||
@@ -67,6 +77,8 @@ func (limiter *connLimiter) Read(p []byte) (n int, err error) {
|
||||
return
|
||||
}
|
||||
|
||||
// Write to the connection, if this is the first write then the data is used for protocol detection
|
||||
// before it gets sent to the underlying connection.
|
||||
func (limiter *connLimiter) Write(p []byte) (n int, err error) {
|
||||
limiter.init(nil, p)
|
||||
var ok bool
|
||||
|
Reference in New Issue
Block a user