Initial import
This commit is contained in:
76
protocol/limit.go
Normal file
76
protocol/limit.go
Normal file
@@ -0,0 +1,76 @@
|
||||
package protocol
|
||||
|
||||
import (
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
// AcceptFunc receives a direction and a detected protocol.
|
||||
type AcceptFunc func(Direction, *Protocol) error
|
||||
|
||||
// Limit the connection protocol, by running a detection after either side sends
|
||||
// a banner within timeout.
|
||||
//
|
||||
// If no protocol could be detected, the accept function is called with a nil
|
||||
// argument to check if we should proceed.
|
||||
//
|
||||
// If the accept function returns false, the connection will be closed.
|
||||
func Limit(conn net.Conn, accept AcceptFunc) net.Conn {
|
||||
if accept == nil {
|
||||
// Nothing to do here.
|
||||
return conn
|
||||
}
|
||||
|
||||
return &connLimiter{
|
||||
Conn: conn,
|
||||
accept: accept,
|
||||
}
|
||||
}
|
||||
|
||||
type connLimiter struct {
|
||||
net.Conn
|
||||
accept AcceptFunc
|
||||
acceptOnce sync.Once
|
||||
acceptError atomic.Value
|
||||
}
|
||||
|
||||
func (l *connLimiter) init(readData, writeData []byte) {
|
||||
l.acceptOnce.Do(func() {
|
||||
var (
|
||||
dir Direction
|
||||
data []byte
|
||||
)
|
||||
if readData != nil {
|
||||
// init called by initial read
|
||||
dir, data = Server, readData
|
||||
} else {
|
||||
// init called by initial write
|
||||
dir, data = Client, writeData
|
||||
}
|
||||
protocol, _ := Detect(dir, data)
|
||||
if err := l.accept(dir, protocol); err != nil {
|
||||
l.acceptError.Store(err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (l *connLimiter) Read(p []byte) (n int, err error) {
|
||||
var ok bool
|
||||
if err, ok = l.acceptError.Load().(error); ok && err != nil {
|
||||
return
|
||||
}
|
||||
if n, err = l.Conn.Read(p); n > 0 {
|
||||
l.init(p[:n], nil)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (l *connLimiter) Write(p []byte) (n int, err error) {
|
||||
l.init(nil, p)
|
||||
var ok bool
|
||||
if err, ok = l.acceptError.Load().(error); ok && err != nil {
|
||||
return
|
||||
}
|
||||
return l.Conn.Write(p)
|
||||
}
|
Reference in New Issue
Block a user