77 lines
1.6 KiB
Go
77 lines
1.6 KiB
Go
package protocol
|
|
|
|
import (
|
|
"net"
|
|
"sync"
|
|
"sync/atomic"
|
|
)
|
|
|
|
// AcceptFunc receives a direction and a detected protocol.
|
|
type AcceptFunc func(Direction, *Protocol) error
|
|
|
|
// Limit the connection protocol, by running a detection after either side sends
|
|
// a banner within timeout.
|
|
//
|
|
// If no protocol could be detected, the accept function is called with a nil
|
|
// argument to check if we should proceed.
|
|
//
|
|
// If the accept function returns false, the connection will be closed.
|
|
func Limit(conn net.Conn, accept AcceptFunc) net.Conn {
|
|
if accept == nil {
|
|
// Nothing to do here.
|
|
return conn
|
|
}
|
|
|
|
return &connLimiter{
|
|
Conn: conn,
|
|
accept: accept,
|
|
}
|
|
}
|
|
|
|
type connLimiter struct {
|
|
net.Conn
|
|
accept AcceptFunc
|
|
acceptOnce sync.Once
|
|
acceptError atomic.Value
|
|
}
|
|
|
|
func (l *connLimiter) init(readData, writeData []byte) {
|
|
l.acceptOnce.Do(func() {
|
|
var (
|
|
dir Direction
|
|
data []byte
|
|
)
|
|
if readData != nil {
|
|
// init called by initial read
|
|
dir, data = Server, readData
|
|
} else {
|
|
// init called by initial write
|
|
dir, data = Client, writeData
|
|
}
|
|
protocol, _ := Detect(dir, data)
|
|
if err := l.accept(dir, protocol); err != nil {
|
|
l.acceptError.Store(err)
|
|
}
|
|
})
|
|
}
|
|
|
|
func (l *connLimiter) Read(p []byte) (n int, err error) {
|
|
var ok bool
|
|
if err, ok = l.acceptError.Load().(error); ok && err != nil {
|
|
return
|
|
}
|
|
if n, err = l.Conn.Read(p); n > 0 {
|
|
l.init(p[:n], nil)
|
|
}
|
|
return
|
|
}
|
|
|
|
func (l *connLimiter) Write(p []byte) (n int, err error) {
|
|
l.init(nil, p)
|
|
var ok bool
|
|
if err, ok = l.acceptError.Load().(error); ok && err != nil {
|
|
return
|
|
}
|
|
return l.Conn.Write(p)
|
|
}
|