Files
dpi/protocol/intercept.go
2025-10-08 20:53:56 +02:00

129 lines
3.0 KiB
Go

package protocol
import (
"net"
"sync/atomic"
"time"
)
// Intercepted is the single result delivered on the channel returned by
// [Interceptor.Detect].
type Intercepted struct {
// Direction is the side (Client or Server) whose first intercepted read was
// passed to detection. It is left at its zero value when detection timed out.
Direction Direction
// Protocol is what the package-level Detect reported for that read; nil when
// detection timed out.
Protocol *Protocol
// Error is ErrTimeout when no data was intercepted before the timeout,
// otherwise whatever error the package-level Detect returned.
Error error
}
// Interceptor intercepts the first read on a bound client and/or server
// connection so that data can be used for protocol detection.
//
// Create it with [NewInterceptor], wrap connections with [Interceptor.Client]
// and [Interceptor.Server], and then call [Interceptor.Detect].
type Interceptor struct {
// clientBytes carries a copy of the client's first intercepted read.
clientBytes chan []byte
// clientReader is the wrapper most recently created by Client; nil until
// Client has wrapped a connection.
clientReader *readInterceptor
// serverBytes carries a copy of the server's first intercepted read.
serverBytes chan []byte
// serverReader is the wrapper most recently created by Server; nil until
// Server has wrapped a connection.
serverReader *readInterceptor
}
// NewInterceptor returns a transparent protocol interceptor with no
// connections bound yet. Bind connections with [Interceptor.Client] and
// [Interceptor.Server], then call [Interceptor.Detect].
func NewInterceptor() *Interceptor {
	// Each side gets a one-slot buffer so the first intercepted read can hand
	// off its data without waiting for Detect to be listening.
	i := new(Interceptor)
	i.clientBytes = make(chan []byte, 1)
	i.serverBytes = make(chan []byte, 1)
	return i
}
// readInterceptor wraps a net.Conn and offers a copy of the data returned by
// its first Read on the bytes channel, where [Interceptor.Detect] picks it up
// for protocol detection. All other reads pass straight through.
type readInterceptor struct {
	net.Conn

	// bytes receives a copy of the data returned by the first Read, if that
	// Read returned any. The channel belongs to the Interceptor and may be
	// shared by several readInterceptors (e.g. when Client is called twice).
	bytes chan []byte

	// once is set by the first Read (which claims the interception) or by
	// Cancel; afterwards Read never sends on bytes again.
	once atomic.Bool
}

func newReadInterceptor(c net.Conn, bytes chan []byte) *readInterceptor {
	return &readInterceptor{
		Conn:  c,
		bytes: bytes,
	}
}

// Cancel stops any future Read from being intercepted. It is safe to call on a
// nil receiver and to call more than once.
//
// Cancel deliberately does not close the bytes channel. A Read that claimed
// the interception before Cancel may still be blocked in Conn.Read, and its
// later send would panic on a closed channel. The same channel may also be
// shared with other readInterceptors, so closing it here could double-close.
// An unclosed, unreferenced channel is simply garbage collected.
func (r *readInterceptor) Cancel() {
	if r == nil {
		return
	}
	r.once.Store(true)
}

// Read implements net.Conn. The first call claims the interception: it reads
// from the underlying connection and, if any data arrived, offers a private
// copy of it on the bytes channel before returning n and err unchanged. Every
// later call, and every call after Cancel, reads straight from the connection.
//
// If the first read returns no data (n == 0), nothing is offered and no later
// read is intercepted.
//
// The hand-off never blocks: if the channel buffer is already full, the copy
// is dropped so the proxied connection is not stalled.
func (r *readInterceptor) Read(p []byte) (int, error) {
	if !r.once.CompareAndSwap(false, true) {
		return r.Conn.Read(p)
	}

	n, err := r.Conn.Read(p)
	if n > 0 {
		// Copy, since the caller may reuse p as soon as Read returns.
		data := make([]byte, n)
		copy(data, p[:n])
		select {
		case r.bytes <- data:
		default:
			// Buffer already full (e.g. a shared channel nobody is draining);
			// drop the copy rather than block the caller's Read forever.
		}
	}
	return n, err
}
// Client wraps the client connection c so its first Read is captured for
// [Interceptor.Detect], records the wrapper on the Interceptor, and returns it.
//
// If c is already a wrapper produced by Client or Server, it is returned
// as-is and is not recorded again: its first read still goes to whichever
// Interceptor side originally wrapped it.
func (i *Interceptor) Client(c net.Conn) net.Conn {
	ri, wrapped := c.(*readInterceptor)
	if !wrapped {
		ri = newReadInterceptor(c, i.clientBytes)
		i.clientReader = ri
	}
	return ri
}
// Server wraps the server connection c so its first Read is captured for
// [Interceptor.Detect], records the wrapper on the Interceptor, and returns it.
//
// If c is already a wrapper produced by Client or Server, it is returned
// as-is and is not recorded again: its first read still goes to whichever
// Interceptor side originally wrapped it.
func (i *Interceptor) Server(c net.Conn) net.Conn {
	ri, wrapped := c.(*readInterceptor)
	if wrapped {
		return ri
	}
	wrapper := newReadInterceptor(c, i.serverBytes)
	i.serverReader = wrapper
	return wrapper
}
// Detect runs protocol detection on whichever bound connection (see
// [Interceptor.Client] and [Interceptor.Server]) delivers data first.
//
// A background goroutine waits for the first intercepted read from either
// side and passes its bytes, together with the originating direction, to the
// package-level Detect. If both sides have data ready at once, one of them is
// picked at random. If no data arrives within timeout, the result carries
// ErrTimeout instead.
//
// The returned channel always yields exactly one result and is then closed.
// Once the result has been produced, both bound readers are cancelled so
// later reads are no longer intercepted.
func (i *Interceptor) Detect(timeout time.Duration) <-chan *Intercepted {
	resultc := make(chan *Intercepted, 1)

	go func() {
		// Deferred calls run in reverse: the server reader is cancelled
		// first, then the client reader, and finally resultc is closed.
		defer close(resultc)
		defer i.clientReader.Cancel()
		defer i.serverReader.Cancel()

		res := new(Intercepted)
		var data []byte
		select {
		case <-time.After(timeout):
			res.Error = ErrTimeout
			resultc <- res
			return
		case data = <-i.clientBytes: // client spoke first
			res.Direction = Client
		case data = <-i.serverBytes: // server spoke first (banner)
			res.Direction = Server
		}

		res.Protocol, res.Error = Detect(res.Direction, data)
		resultc <- res
	}()

	return resultc
}