78 lines
1.9 KiB
Go
78 lines
1.9 KiB
Go
package protocol
|
|
|
|
import (
|
|
"net"
|
|
"sync"
|
|
"sync/atomic"
|
|
)
|
|
|
|
// AcceptFunc receives a direction and a detected protocol.
|
|
type AcceptFunc func(Direction, *Protocol) error
|
|
|
|
// Limit the connection protocol, by running a detection after either side sends
|
|
// a banner within timeout.
|
|
//
|
|
// If no protocol could be detected, the accept function is called with a nil
|
|
// argument to check if we should proceed.
|
|
//
|
|
// If the accept function returns false, the connection will be closed.
|
|
func Limit(conn net.Conn, accept AcceptFunc) net.Conn {
|
|
if accept == nil {
|
|
// Nothing to do here.
|
|
return conn
|
|
}
|
|
|
|
return &connLimiter{
|
|
Conn: conn,
|
|
accept: accept,
|
|
}
|
|
}
|
|
|
|
// connLimiter is the net.Conn returned by Limit. It forwards all traffic to
// the embedded connection and runs protocol detection plus the accept check
// once, on the first data seen in either direction.
type connLimiter struct {
	// The wrapped connection; methods not overridden here are promoted from it.
	net.Conn
	// Callback that approves or rejects the detected protocol.
	accept AcceptFunc
	// Guards init so detection and accept run at most once.
	acceptOnce sync.Once
	// Holds the error returned by accept, if any. Atomic because Read and Write
	// load it outside of acceptOnce.
	acceptError atomic.Value
}
|
|
|
|
func (limiter *connLimiter) init(readData, writeData []byte) {
|
|
limiter.acceptOnce.Do(func() {
|
|
var (
|
|
dir Direction
|
|
data []byte
|
|
srcPort, dstPort int
|
|
)
|
|
if readData != nil {
|
|
// init called by initial read
|
|
dir, data, srcPort, dstPort = Server, readData, getPortFromAddr(limiter.LocalAddr()), getPortFromAddr(limiter.RemoteAddr())
|
|
} else {
|
|
// init called by initial write
|
|
dir, data, srcPort, dstPort = Client, writeData, getPortFromAddr(limiter.RemoteAddr()), getPortFromAddr(limiter.LocalAddr())
|
|
}
|
|
protocol, _, _ := Detect(dir, data, srcPort, dstPort)
|
|
if err := limiter.accept(dir, protocol); err != nil {
|
|
limiter.acceptError.Store(err)
|
|
}
|
|
})
|
|
}
|
|
|
|
func (limiter *connLimiter) Read(p []byte) (n int, err error) {
|
|
var ok bool
|
|
if err, ok = limiter.acceptError.Load().(error); ok && err != nil {
|
|
return
|
|
}
|
|
if n, err = limiter.Conn.Read(p); n > 0 {
|
|
limiter.init(p[:n], nil)
|
|
}
|
|
return
|
|
}
|
|
|
|
func (limiter *connLimiter) Write(p []byte) (n int, err error) {
|
|
limiter.init(nil, p)
|
|
var ok bool
|
|
if err, ok = limiter.acceptError.Load().(error); ok && err != nil {
|
|
return
|
|
}
|
|
return limiter.Conn.Write(p)
|
|
}
|