Files
dpi/protocol/limit.go
2025-10-08 20:53:56 +02:00

77 lines
1.6 KiB
Go

package protocol
import (
"net"
"sync"
"sync/atomic"
)
// AcceptFunc is consulted with the direction of the first observed payload on
// a limited connection and the protocol detected in that payload. protocol is
// nil when nothing could be detected. Returning a non-nil error rejects the
// connection; returning nil lets traffic through.
type AcceptFunc func(dir Direction, protocol *Protocol) error
// Limit wraps conn so that its traffic is gated by protocol detection.
//
// Detection runs once, on the first payload that crosses the connection:
// data returned by a Read that yields bytes is inspected in the Server
// direction, data passed to a Write is inspected in the Client direction,
// whichever happens first. The detection result is handed to accept. If no
// protocol could be detected, accept is called with a nil protocol so it can
// decide whether to proceed.
//
// If accept returns a non-nil error, subsequent Read and Write calls on the
// returned conn fail with that error. Limit does not close the underlying
// connection itself; the caller remains responsible for closing it.
//
// If accept is nil, conn is returned unwrapped.
func Limit(conn net.Conn, accept AcceptFunc) net.Conn {
	if accept == nil {
		// No verdict callback: there is nothing to enforce.
		return conn
	}
	return &connLimiter{
		Conn:   conn,
		accept: accept,
	}
}
// connLimiter is a net.Conn that gates its traffic on a one-time protocol
// verdict obtained from accept.
type connLimiter struct {
	net.Conn // underlying connection; methods not overridden pass through

	accept AcceptFunc // verdict callback, run at most once via acceptOnce

	acceptOnce  sync.Once    // guards the single detection/accept run
	acceptError atomic.Value // holds the error returned by accept, if any
}
// init runs protocol detection exactly once per connection, on the first
// payload it is given. A non-nil readData means the caller is Read and the
// data is inspected in the Server direction; otherwise writeData (from Write)
// is inspected in the Client direction. A rejection returned by accept is
// stored in acceptError for Read and Write to report.
func (l *connLimiter) init(readData, writeData []byte) {
	l.acceptOnce.Do(func() {
		// Assume the write path, then switch if this was triggered by a read.
		dir, data := Direction(Client), writeData
		if readData != nil {
			dir, data = Server, readData
		}

		// Detect's error is deliberately dropped: an undetected protocol is
		// passed to accept as nil (presumably what Detect returns on failure
		// — verify), and accept alone decides whether that is acceptable.
		protocol, _ := Detect(dir, data)

		if err := l.accept(dir, protocol); err != nil {
			l.acceptError.Store(err)
		}
	})
}
// Read reads from the underlying connection. The first read that returns
// data triggers protocol detection in the Server direction.
//
// If accept rejected the connection, either earlier or on the data just
// read, Read returns the rejection error and reports zero bytes. Data from
// a rejected connection is therefore never handed to the caller, matching
// Write, which refuses to send once rejected.
func (l *connLimiter) Read(p []byte) (int, error) {
	if err, ok := l.acceptError.Load().(error); ok && err != nil {
		return 0, err
	}

	n, err := l.Conn.Read(p)
	if n > 0 {
		l.init(p[:n], nil)
		// Re-check after init: the verdict may have been rendered on this
		// very chunk, or by a concurrent Write while we were blocked in
		// Conn.Read. init waits for any in-flight detection to finish, so
		// the stored error is visible here.
		if aerr, ok := l.acceptError.Load().(error); ok && aerr != nil {
			return 0, aerr
		}
	}
	return n, err
}
// Write writes p to the underlying connection. The first non-empty write
// triggers protocol detection in the Client direction before any of its
// bytes are sent. If accept rejected the connection, Write returns the
// rejection error without writing anything.
//
// Empty writes do not consume the one-shot detection. This mirrors Read,
// which only detects once data has actually arrived.
func (l *connLimiter) Write(p []byte) (int, error) {
	if len(p) > 0 {
		l.init(nil, p)
	}
	if err, ok := l.acceptError.Load().(error); ok && err != nil {
		return 0, err
	}
	return l.Conn.Write(p)
}