Refactored detection logic to include ports and a confidence score

This commit is contained in:
2025-10-09 11:54:43 +02:00
parent 2081d684ed
commit 2ab59437fa
17 changed files with 795 additions and 129 deletions

View File

@@ -2,21 +2,25 @@ package protocol
import (
"net"
"strconv"
"sync/atomic"
"time"
)
// Intercepted is the result returned by [Interceptor.Detect].
type Intercepted struct {
Direction Direction
Protocol *Protocol
Error error
Direction Direction
Protocol *Protocol
Confidence float64
Error error
}
// Interceptor intercepts reads from client or server.
type Interceptor struct {
clientPort int
clientBytes chan []byte
clientReader *readInterceptor
serverPort int
serverBytes chan []byte
serverReader *readInterceptor
}
@@ -71,6 +75,7 @@ func (i *Interceptor) Client(c net.Conn) net.Conn {
if ri, ok := c.(*readInterceptor); ok {
return ri
}
i.clientPort = getPortFromAddr(c.RemoteAddr())
i.clientReader = newReadInterceptor(c, i.clientBytes)
return i.clientReader
}
@@ -80,6 +85,7 @@ func (i *Interceptor) Server(c net.Conn) net.Conn {
if ri, ok := c.(*readInterceptor); ok {
return ri
}
i.serverPort = getPortFromAddr(c.RemoteAddr())
i.serverReader = newReadInterceptor(c, i.serverBytes)
return i.serverReader
}
@@ -107,22 +113,49 @@ func (i *Interceptor) Detect(timeout time.Duration) <-chan *Intercepted {
}
case data := <-i.clientBytes: // client sent banner
p, err := Detect(Client, data)
p, c, err := Detect(Client, data, i.clientPort, i.serverPort)
interceptc <- &Intercepted{
Direction: Client,
Protocol: p,
Error: err,
Direction: Client,
Protocol: p,
Confidence: c,
Error: err,
}
case data := <-i.serverBytes: // server sent banner
p, err := Detect(Server, data)
p, c, err := Detect(Server, data, i.serverPort, i.clientPort)
interceptc <- &Intercepted{
Direction: Server,
Protocol: p,
Error: err,
Direction: Server,
Protocol: p,
Confidence: c,
Error: err,
}
}
}()
return interceptc
}
func getPortFromAddr(addr net.Addr) int {
switch a := addr.(type) {
case *net.TCPAddr:
return a.Port
case *net.UDPAddr:
return a.Port
case *net.IPAddr:
// IPAddr doesn't have a port
return 0
default:
// Fallback to parsing
_, service, err := net.SplitHostPort(addr.String())
if err != nil {
return 0
}
if port, err := strconv.Atoi(service); err == nil {
return port
}
if port, err := net.LookupPort(addr.Network(), service); err == nil {
return port
}
return 0
}
}