Files
dpi/protocol/limit.go

90 lines
2.4 KiB
Go

package protocol
import (
"net"
"sync"
"sync/atomic"
)
// AcceptFunc decides whether a limited connection may proceed. It is handed the
// direction of the data that was inspected and the protocol detected in that
// data; protocol is nil when no protocol could be detected. Returning a non-nil
// error rejects the connection.
type AcceptFunc func(dir Direction, protocol *Protocol) error
// Limit wraps conn so that its protocol is checked once, using the first data
// that is read from or written to the returned connection, whichever comes
// first. Data read from conn is inspected as the Server direction, data written
// to it as the Client direction, and accept is called at most once with that
// direction and the detected protocol.
//
// If no protocol could be detected, accept is called with a nil protocol so it
// can decide whether to proceed.
//
// If accept returns an error, all future reads and writes on the returned
// connection return that error.
//
// If accept is nil there is nothing to enforce and conn is returned unchanged.
func Limit(conn net.Conn, accept AcceptFunc) net.Conn {
	if accept == nil {
		// Nothing to do here.
		return conn
	}
	return &connLimiter{
		Conn:   conn,
		accept: accept,
	}
}
// connLimiter is the net.Conn returned by Limit. It runs protocol detection and
// the accept check once, on the first data read from or written to the
// connection, and afterwards fails every Read/Write if the check was rejected.
type connLimiter struct {
	// Conn is the wrapped connection; its methods are promoted.
	net.Conn
	// accept decides whether the detected protocol may proceed.
	accept AcceptFunc
	// acceptOnce makes detection and the accept call happen at most once.
	acceptOnce sync.Once
	// acceptError holds the error returned by accept, and stays unset when accept
	// succeeded. Read loads it before detection may have run, so it is kept in an
	// atomic.Value rather than a plain field.
	acceptError atomic.Value
}
// init runs protocol detection and the accept check exactly once. Exactly one of
// readData/writeData is expected to be set: readData (non-nil) means the first
// data arrived through Read and is treated as coming from the server; otherwise
// writeData is the first data passed to Write and is treated as coming from the
// client. A rejection from accept is recorded in acceptError.
func (limiter *connLimiter) init(readData, writeData []byte) {
	limiter.acceptOnce.Do(func() {
		local := getPortFromAddr(limiter.LocalAddr())
		remote := getPortFromAddr(limiter.RemoteAddr())

		var (
			dir      Direction
			protocol *Protocol
		)
		switch {
		case readData != nil:
			dir = Server
			protocol, _, _ = Detect(dir, readData, local, remote)
		default:
			dir = Client
			// NOTE(review): the port arguments are swapped relative to the read
			// branch; confirm against Detect's parameter order.
			protocol, _, _ = Detect(dir, writeData, remote, local)
		}

		err := limiter.accept(dir, protocol)
		if err == nil {
			return
		}
		limiter.acceptError.Store(err)
	})
}
// NetConn returns the wrapped [net.Conn]. It mirrors [crypto/tls.Conn.NetConn]
// so callers can unwrap a limited connection the same way they unwrap a TLS one.
func (limiter *connLimiter) NetConn() net.Conn {
	return limiter.Conn
}
// Read reads from the underlying connection. The first non-empty chunk it
// returns is used for protocol detection (unless a Write already triggered it).
//
// If the protocol is rejected, the chunk that was just read is discarded and the
// accept error is returned instead, so rejected payload never reaches the
// caller. This matches Write, which never forwards rejected data. Once rejected,
// every Read returns that error without touching the underlying connection.
func (limiter *connLimiter) Read(p []byte) (n int, err error) {
	if acceptErr, ok := limiter.acceptError.Load().(error); ok {
		return 0, acceptErr
	}
	n, err = limiter.Conn.Read(p)
	if n > 0 {
		limiter.init(p[:n], nil)
	}
	// Re-check after the read: detection above, or a concurrent Write while this
	// call was blocked, may have rejected the connection.
	if acceptErr, ok := limiter.acceptError.Load().(error); ok {
		return 0, acceptErr
	}
	return n, err
}
// Write writes to the underlying connection. The first non-empty write is used
// for protocol detection (unless a Read already triggered it) before any of its
// data is sent. If the protocol is rejected, nothing is written and the accept
// error is returned. Once rejected, every Write returns that error without
// touching the underlying connection.
func (limiter *connLimiter) Write(p []byte) (n int, err error) {
	// Only detect on real payload, like Read does. Otherwise a zero-length write
	// would consume the one-shot detection on empty data and let later writes
	// through unchecked.
	if len(p) > 0 {
		limiter.init(nil, p)
	}
	if acceptErr, ok := limiter.acceptError.Load().(error); ok {
		return 0, acceptErr
	}
	return limiter.Conn.Write(p)
}