Initial import
This commit is contained in:
128
protocol/intercept.go
Normal file
128
protocol/intercept.go
Normal file
@@ -0,0 +1,128 @@
|
||||
package protocol
|
||||
|
||||
import (
|
||||
"net"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Intercepted is the result returned by [Interceptor.Detect].
|
||||
type Intercepted struct {
|
||||
Direction Direction
|
||||
Protocol *Protocol
|
||||
Error error
|
||||
}
|
||||
|
||||
// Interceptor intercepts reads from client or server.
|
||||
type Interceptor struct {
|
||||
clientBytes chan []byte
|
||||
clientReader *readInterceptor
|
||||
serverBytes chan []byte
|
||||
serverReader *readInterceptor
|
||||
}
|
||||
|
||||
// NewInterceptor creates a new (transparent) protocol interceptor.
|
||||
func NewInterceptor() *Interceptor {
|
||||
return &Interceptor{
|
||||
clientBytes: make(chan []byte, 1),
|
||||
serverBytes: make(chan []byte, 1),
|
||||
}
|
||||
}
|
||||
|
||||
type readInterceptor struct {
|
||||
net.Conn
|
||||
bytes chan []byte
|
||||
once atomic.Bool
|
||||
}
|
||||
|
||||
func newReadInterceptor(c net.Conn, bytes chan []byte) *readInterceptor {
|
||||
return &readInterceptor{
|
||||
Conn: c,
|
||||
bytes: bytes,
|
||||
}
|
||||
}
|
||||
|
||||
// Cancel any future Read interceptions and closes the channel.
|
||||
func (r *readInterceptor) Cancel() {
|
||||
if r == nil {
|
||||
return
|
||||
}
|
||||
r.once.Store(true)
|
||||
close(r.bytes)
|
||||
}
|
||||
|
||||
func (r *readInterceptor) Read(p []byte) (n int, err error) {
|
||||
if r.once.CompareAndSwap(false, true) {
|
||||
if n, err = r.Conn.Read(p); n > 0 {
|
||||
// We create a copy, since the Read caller may modify p
|
||||
// immediately after reading.
|
||||
data := make([]byte, n)
|
||||
copy(data, p[:n])
|
||||
// Buffer the bytes in the channel.
|
||||
r.bytes <- data
|
||||
}
|
||||
return
|
||||
}
|
||||
return r.Conn.Read(p)
|
||||
}
|
||||
|
||||
// Client binds the client connection to the interceptor.
|
||||
func (i *Interceptor) Client(c net.Conn) net.Conn {
|
||||
if ri, ok := c.(*readInterceptor); ok {
|
||||
return ri
|
||||
}
|
||||
i.clientReader = newReadInterceptor(c, i.clientBytes)
|
||||
return i.clientReader
|
||||
}
|
||||
|
||||
// Server binds the server connection to the interceptor.
|
||||
func (i *Interceptor) Server(c net.Conn) net.Conn {
|
||||
if ri, ok := c.(*readInterceptor); ok {
|
||||
return ri
|
||||
}
|
||||
i.serverReader = newReadInterceptor(c, i.serverBytes)
|
||||
return i.serverReader
|
||||
}
|
||||
|
||||
// Detect runs protocol detection on the previously bound Client and Server connection.
|
||||
//
|
||||
// It waits until either the client or the server performs a read operation,
|
||||
// which is then used for running protocol detection. If the read operation
|
||||
// takes longer than timeout, an error is returned.
|
||||
//
|
||||
// The returned channel always yields one result and is then closed.
|
||||
func (i *Interceptor) Detect(timeout time.Duration) <-chan *Intercepted {
|
||||
var interceptc = make(chan *Intercepted, 1)
|
||||
|
||||
go func() {
|
||||
// Make sure all channels are closed once we finish processing.
|
||||
defer close(interceptc)
|
||||
defer i.clientReader.Cancel()
|
||||
defer i.serverReader.Cancel()
|
||||
|
||||
select {
|
||||
case <-time.After(timeout): // timeout
|
||||
interceptc <- &Intercepted{
|
||||
Error: ErrTimeout,
|
||||
}
|
||||
|
||||
case data := <-i.clientBytes: // client sent banner
|
||||
p, err := Detect(Client, data)
|
||||
interceptc <- &Intercepted{
|
||||
Direction: Client,
|
||||
Protocol: p,
|
||||
Error: err,
|
||||
}
|
||||
|
||||
case data := <-i.serverBytes: // server sent banner
|
||||
p, err := Detect(Server, data)
|
||||
interceptc <- &Intercepted{
|
||||
Direction: Server,
|
||||
Protocol: p,
|
||||
Error: err,
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return interceptc
|
||||
}
|
Reference in New Issue
Block a user