162 lines
3.7 KiB
Go
162 lines
3.7 KiB
Go
package protocol
|
|
|
|
import (
|
|
"net"
|
|
"strconv"
|
|
"sync/atomic"
|
|
"time"
|
|
)
|
|
|
|
// Intercepted is the result returned by [Interceptor.Detect].
//
// On timeout only Error is set (to ErrTimeout); otherwise Direction,
// Protocol, Confidence and Error carry the outcome of protocol detection on
// the first intercepted bytes.
type Intercepted struct {
	// Direction is the side (Client or Server) whose intercepted bytes were
	// used for detection. It is left at its zero value on timeout.
	Direction Direction

	// Protocol is the protocol reported by detection, if any.
	Protocol *Protocol

	// Confidence is the confidence value reported by detection alongside
	// Protocol.
	Confidence float64

	// Error is ErrTimeout when no data was intercepted in time, otherwise the
	// error (if any) reported by detection.
	Error error
}
|
|
|
|
// Interceptor intercepts reads from client or server.
//
// Bind the two sides with Client and Server, then call Detect to wait for the
// first data read on either side and run protocol detection on it.
//
// NOTE(review): the fields below are written by Client/Server and read from
// the goroutine started by Detect without any locking, so Client and Server
// presumably must be called before Detect — confirm with callers.
type Interceptor struct {
	// clientPort is the port parsed from the client connection's RemoteAddr
	// when it was bound via Client (0 if it could not be determined).
	clientPort int

	// clientBytes receives a copy of the first data read from the client
	// connection. NewInterceptor gives it a buffer of one.
	clientBytes chan []byte

	// clientReader is the wrapper returned by Client. It stays nil until Client
	// is called with a connection that is not already a *readInterceptor.
	clientReader *readInterceptor

	// serverPort is the port parsed from the server connection's RemoteAddr
	// when it was bound via Server (0 if it could not be determined).
	serverPort int

	// serverBytes receives a copy of the first data read from the server
	// connection. NewInterceptor gives it a buffer of one.
	serverBytes chan []byte

	// serverReader is the wrapper returned by Server. It stays nil until Server
	// is called with a connection that is not already a *readInterceptor.
	serverReader *readInterceptor
}
|
|
|
|
// NewInterceptor creates a new (transparent) protocol interceptor.
|
|
func NewInterceptor() *Interceptor {
|
|
return &Interceptor{
|
|
clientBytes: make(chan []byte, 1),
|
|
serverBytes: make(chan []byte, 1),
|
|
}
|
|
}
|
|
|
|
type readInterceptor struct {
|
|
net.Conn
|
|
bytes chan []byte
|
|
once atomic.Bool
|
|
}
|
|
|
|
func newReadInterceptor(c net.Conn, bytes chan []byte) *readInterceptor {
|
|
return &readInterceptor{
|
|
Conn: c,
|
|
bytes: bytes,
|
|
}
|
|
}
|
|
|
|
// Cancel any future Read interceptions and closes the channel.
|
|
func (r *readInterceptor) Cancel() {
|
|
if r == nil {
|
|
return
|
|
}
|
|
r.once.Store(true)
|
|
close(r.bytes)
|
|
}
|
|
|
|
func (r *readInterceptor) Read(p []byte) (n int, err error) {
|
|
if r.once.CompareAndSwap(false, true) {
|
|
if n, err = r.Conn.Read(p); n > 0 {
|
|
// We create a copy, since the Read caller may modify p
|
|
// immediately after reading.
|
|
data := make([]byte, n)
|
|
copy(data, p[:n])
|
|
// Buffer the bytes in the channel.
|
|
r.bytes <- data
|
|
}
|
|
return
|
|
}
|
|
return r.Conn.Read(p)
|
|
}
|
|
|
|
// Client binds the client connection to the interceptor and returns the
// wrapped connection; the data of its first Read (if any) becomes available
// to Detect.
//
// A connection that already is a *readInterceptor is returned as-is, without
// being bound to this interceptor.
// NOTE(review): in that case Detect never sees client bytes from it — confirm
// this is intended.
func (i *Interceptor) Client(c net.Conn) net.Conn {
	if wrapped, already := c.(*readInterceptor); already {
		return wrapped
	}

	port := getPortFromAddr(c.RemoteAddr())
	reader := newReadInterceptor(c, i.clientBytes)
	i.clientPort, i.clientReader = port, reader
	return reader
}
|
|
|
|
// Server binds the server connection to the interceptor and returns the
// wrapped connection; the data of its first Read (if any) becomes available
// to Detect.
//
// A connection that already is a *readInterceptor is returned as-is, without
// being bound to this interceptor.
// NOTE(review): in that case Detect never sees server bytes from it — confirm
// this is intended.
func (i *Interceptor) Server(c net.Conn) net.Conn {
	if wrapped, already := c.(*readInterceptor); already {
		return wrapped
	}

	port := getPortFromAddr(c.RemoteAddr())
	reader := newReadInterceptor(c, i.serverBytes)
	i.serverPort, i.serverReader = port, reader
	return reader
}
|
|
|
|
// Detect runs protocol detection on the previously bound Client and Server connection.
|
|
//
|
|
// It waits until either the client or the server performs a read operation,
|
|
// which is then used for running protocol detection. If the read operation
|
|
// takes longer than timeout, an error is returned.
|
|
//
|
|
// The returned channel always yields one result and is then closed.
|
|
func (i *Interceptor) Detect(timeout time.Duration) <-chan *Intercepted {
|
|
var interceptc = make(chan *Intercepted, 1)
|
|
|
|
go func() {
|
|
// Make sure all channels are closed once we finish processing.
|
|
defer close(interceptc)
|
|
defer i.clientReader.Cancel()
|
|
defer i.serverReader.Cancel()
|
|
|
|
select {
|
|
case <-time.After(timeout): // timeout
|
|
interceptc <- &Intercepted{
|
|
Error: ErrTimeout,
|
|
}
|
|
|
|
case data := <-i.clientBytes: // client sent banner
|
|
p, c, err := Detect(Client, data, i.clientPort, i.serverPort)
|
|
interceptc <- &Intercepted{
|
|
Direction: Client,
|
|
Protocol: p,
|
|
Confidence: c,
|
|
Error: err,
|
|
}
|
|
|
|
case data := <-i.serverBytes: // server sent banner
|
|
p, c, err := Detect(Server, data, i.serverPort, i.clientPort)
|
|
interceptc <- &Intercepted{
|
|
Direction: Server,
|
|
Protocol: p,
|
|
Confidence: c,
|
|
Error: err,
|
|
}
|
|
}
|
|
}()
|
|
|
|
return interceptc
|
|
}
|
|
|
|
// getPortFromAddr extracts the port number from addr. It returns 0 when addr
// is nil, carries no port, or its port cannot be determined.
//
// TCP and UDP addresses are read directly. Any other net.Addr is parsed from
// its String form as "host:port", where the port may be numeric or a service
// name resolved with net.LookupPort for addr.Network().
func getPortFromAddr(addr net.Addr) int {
	switch a := addr.(type) {
	case nil:
		// A nil interface would otherwise panic on addr.String() below.
		return 0
	case *net.TCPAddr:
		if a == nil {
			return 0
		}
		return a.Port
	case *net.UDPAddr:
		if a == nil {
			return 0
		}
		return a.Port
	case *net.IPAddr:
		// IPAddr doesn't have a port.
		return 0
	}

	// Fallback: parse the textual "host:port" form.
	_, service, err := net.SplitHostPort(addr.String())
	if err != nil {
		return 0
	}
	if port, err := strconv.Atoi(service); err == nil {
		// Ports are 16-bit; reject anything outside that range.
		if port < 0 || port > 65535 {
			return 0
		}
		return port
	}
	if port, err := net.LookupPort(addr.Network(), service); err == nil {
		return port
	}
	return 0
}
|