Fixed up the documentation and refactored Limit
This commit is contained in:
@@ -15,7 +15,8 @@ type AcceptFunc func(Direction, *Protocol) error
|
|||||||
// If no protocol could be detected, the accept function is called with a nil
|
// If no protocol could be detected, the accept function is called with a nil
|
||||||
// argument to check if we should proceed.
|
// argument to check if we should proceed.
|
||||||
//
|
//
|
||||||
// If the accept function returns false, the connection will be closed.
|
// If the accept function returns an error, all future reads and writes on the
|
||||||
|
// returned connection will return that error.
|
||||||
func Limit(conn net.Conn, accept AcceptFunc) net.Conn {
|
func Limit(conn net.Conn, accept AcceptFunc) net.Conn {
|
||||||
if accept == nil {
|
if accept == nil {
|
||||||
// Nothing to do here.
|
// Nothing to do here.
|
||||||
@@ -39,23 +40,32 @@ func (limiter *connLimiter) init(readData, writeData []byte) {
|
|||||||
limiter.acceptOnce.Do(func() {
|
limiter.acceptOnce.Do(func() {
|
||||||
var (
|
var (
|
||||||
dir Direction
|
dir Direction
|
||||||
data []byte
|
localPort = getPortFromAddr(limiter.LocalAddr())
|
||||||
srcPort, dstPort int
|
remotePort = getPortFromAddr(limiter.RemoteAddr())
|
||||||
|
protocol *Protocol
|
||||||
)
|
)
|
||||||
if readData != nil {
|
if readData != nil {
|
||||||
// init called by initial read
|
// init called by initial read (from the server)
|
||||||
dir, data, srcPort, dstPort = Server, readData, getPortFromAddr(limiter.LocalAddr()), getPortFromAddr(limiter.RemoteAddr())
|
dir = Server
|
||||||
|
protocol, _, _ = Detect(dir, readData, localPort, remotePort)
|
||||||
} else {
|
} else {
|
||||||
// init called by initial write
|
// init called by initial write (from the client)
|
||||||
dir, data, srcPort, dstPort = Client, writeData, getPortFromAddr(limiter.RemoteAddr()), getPortFromAddr(limiter.LocalAddr())
|
dir = Client
|
||||||
|
protocol, _, _ = Detect(dir, writeData, remotePort, localPort)
|
||||||
}
|
}
|
||||||
protocol, _, _ := Detect(dir, data, srcPort, dstPort)
|
|
||||||
if err := limiter.accept(dir, protocol); err != nil {
|
if err := limiter.accept(dir, protocol); err != nil {
|
||||||
limiter.acceptError.Store(err)
|
limiter.acceptError.Store(err)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NetConn implements the same method as [net/tls.Conn] to obtain the underlying [net.Conn].
|
||||||
|
func (limiter *connLimiter) NetConn() net.Conn {
|
||||||
|
return limiter.Conn
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read from the connection, if this is the first read then the data returned by the underlying
|
||||||
|
// connection is used for protocol detection.
|
||||||
func (limiter *connLimiter) Read(p []byte) (n int, err error) {
|
func (limiter *connLimiter) Read(p []byte) (n int, err error) {
|
||||||
var ok bool
|
var ok bool
|
||||||
if err, ok = limiter.acceptError.Load().(error); ok && err != nil {
|
if err, ok = limiter.acceptError.Load().(error); ok && err != nil {
|
||||||
@@ -67,6 +77,8 @@ func (limiter *connLimiter) Read(p []byte) (n int, err error) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Write to the connection, if this is the first write then the data is used for protocol detection
|
||||||
|
// before it gets sent to the underlying connection.
|
||||||
func (limiter *connLimiter) Write(p []byte) (n int, err error) {
|
func (limiter *connLimiter) Write(p []byte) (n int, err error) {
|
||||||
limiter.init(nil, p)
|
limiter.init(nil, p)
|
||||||
var ok bool
|
var ok bool
|
||||||
|
Reference in New Issue
Block a user