Refactored detection logic to include ports and a confidence score

This commit is contained in:
2025-10-09 11:54:43 +02:00
parent 2081d684ed
commit 2ab59437fa
17 changed files with 795 additions and 129 deletions

2
go.mod
View File

@@ -2,4 +2,4 @@ module git.maze.io/go/dpi
go 1.25
require golang.org/x/crypto v0.42.0 // indirect
require golang.org/x/crypto v0.42.0

View File

@@ -3,6 +3,8 @@ package protocol
import (
"errors"
"fmt"
"math"
"slices"
"sync"
"sync/atomic"
)
@@ -66,7 +68,15 @@ var (
atomicFormats atomic.Value
)
type DetectFunc func(Direction, []byte) *Protocol
type detectResult struct {
// Protocol detected, nil if no detection.
Protocol *Protocol
// Confidence level [0..1].
Confidence float64
}
type DetectFunc func(dir Direction, data []byte, srcPort, dstPort int) (proto *Protocol, confidence float64)
func Register(dir Direction, magic string, detect DetectFunc) {
formatsMu.Lock()
@@ -97,17 +107,77 @@ func matchMagic(magic string, data []byte) bool {
}
// Detect a protocol based on the provided data.
func Detect(dir Direction, data []byte) (*Protocol, error) {
formats, _ := atomicFormats.Load().([]format)
for _, f := range formats {
if f.dir.Contains(dir) {
func Detect(dir Direction, data []byte, srcPort, dstPort int) (proto *Protocol, confidence float64, err error) {
var (
formats, _ = atomicFormats.Load().([]format)
results []detectResult
)
for _, format := range formats {
if format.dir.Contains(dir) {
// Check the buffer to see if we have sufficient bytes
if matchMagic(f.magic, data) {
if p := f.detect(dir, data); p != nil {
return p, nil
if matchMagic(format.magic, data) {
if proto, confidence := format.detect(dir, data, srcPort, dstPort); proto != nil {
results = append(results, detectResult{proto, confidence})
}
}
}
}
return nil, ErrUnknown
if len(results) > 0 {
slices.SortStableFunc(results, func(a, b detectResult) int {
return compareFloats(b.Confidence, a.Confidence)
})
return results[0].Protocol, results[0].Confidence, nil
}
return nil, 0, ErrUnknown
}
// compareFloats compares two float64 numbers with tolerance for floating-point precision.
//
// Returns:
//
// -1 if a < b
// 0 if a == b (within tolerance)
// 1 if a > b
func compareFloats(a, b float64) int {
// Define the tolerance for floating-point comparison
const tolerance = 1e-9
// Handle special cases: NaN and Inf
if math.IsNaN(a) || math.IsNaN(b) {
// NaN is considered equal to itself, otherwise not equal
if math.IsNaN(a) && math.IsNaN(b) {
return 0
}
if math.IsNaN(a) {
return -1 // NaN is considered less than any number
}
return 1 // Any number is greater than NaN
}
// Handle infinity cases
if math.IsInf(a, 0) || math.IsInf(b, 0) {
if a < b {
return -1
}
if a > b {
return 1
}
return 0 // Both are same infinity
}
// Compare with tolerance for regular numbers
diff := a - b
// If the absolute difference is within tolerance, consider them equal
if math.Abs(diff) < tolerance {
return 0
}
// Otherwise return the comparison result
if diff < 0 {
return -1
}
return 1
}

View File

@@ -12,10 +12,17 @@ func init() {
Register(Server, "HTTP/?.", detectHTTPResponse)
}
func detectHTTPRequest(dir Direction, data []byte) *Protocol {
func detectHTTPRequest(dir Direction, data []byte, srcPort, dstPort int) (proto *Protocol, confidence float64) {
// A minimal request "GET / HTTP/1.0\r\n" is > 8 bytes.
if len(data) < 8 {
return nil
return nil, 0
}
switch dstPort {
case 80, 8080: // Common HTTP ports
confidence = +.1
case 3128: // Common HTTP proxy port
confidence = -.1
}
if Strict {
@@ -31,38 +38,27 @@ func detectHTTPRequest(dir Direction, data []byte) *Protocol {
Minor: request.ProtoMinor,
Patch: -1,
},
}, confidence + .85
}
}
r.Reset(bytes.NewReader(b))
if response, err := http.ReadResponse(r, nil); err == nil {
return &Protocol{
Name: ProtocolHTTP,
Version: Version{
Major: response.ProtoMajor,
Minor: response.ProtoMinor,
Patch: -1,
},
}
}
return nil
return nil, 0
}
crlfIndex := bytes.IndexFunc(data, func(r rune) bool {
return r == '\r' || r == '\n'
})
if crlfIndex == -1 {
return nil
return nil, 0
}
// A request has three, space-separated parts.
part := bytes.Split(data[:crlfIndex], []byte(" "))
if len(part) != 3 {
return nil
return nil, 0
}
// The last part starts with "HTTP/".
if !bytes.HasPrefix(part[2], []byte("HTTP/1")) {
return nil
return nil, 0
}
var version = Version{Patch: -1}
@@ -71,17 +67,42 @@ func detectHTTPRequest(dir Direction, data []byte) *Protocol {
return &Protocol{
Name: ProtocolHTTP,
Version: version,
}
}, confidence + .75
}
func detectHTTPResponse(dir Direction, data []byte) *Protocol {
func detectHTTPResponse(dir Direction, data []byte, srcPort, dstPort int) (proto *Protocol, confidence float64) {
if !dir.Contains(Server) {
return nil
return nil, 0
}
// A minimal response "HTTP/1.0 200 OK\r\n" is > 8 bytes.
if len(data) < 8 {
return nil
return nil, 0
}
switch srcPort {
case 80, 8080: // Common HTTP ports
confidence = +.1
case 3128: // Common HTTP proxy port
confidence = -.1
}
if Strict {
var (
b = append(data, '\r', '\n')
r = bufio.NewReader(bytes.NewReader(b))
)
if response, err := http.ReadResponse(r, nil); err == nil {
return &Protocol{
Name: ProtocolHTTP,
Version: Version{
Major: response.ProtoMajor,
Minor: response.ProtoMinor,
Patch: -1,
},
}, confidence + .85
}
return nil, 0
}
var version = Version{Patch: -1}
@@ -90,5 +111,5 @@ func detectHTTPResponse(dir Direction, data []byte) *Protocol {
return &Protocol{
Name: ProtocolHTTP,
Version: version,
}
}, confidence + .75
}

View File

@@ -28,12 +28,12 @@ func TestDetectHTTPRequest(t *testing.T) {
t.Run(name, func(t *testing.T) {
t.Run("HTTP/1.0 GET", func(t *testing.T) {
p, err := Detect(Client, http10Request)
p, c, err := Detect(Client, http10Request, 1234, 80)
if err != nil {
t.Fatal(err)
return
}
t.Logf("detected %s version %s", p.Name, p.Version)
t.Logf("detected %s version %s confidence %g%%", p.Name, p.Version, c*100)
if p.Name != ProtocolHTTP {
t.Fatalf("expected http protocol, got %s", p.Name)
return
@@ -41,12 +41,12 @@ func TestDetectHTTPRequest(t *testing.T) {
})
t.Run("HTTP/1.1 GET", func(t *testing.T) {
p, err := Detect(Client, getRequest)
p, c, err := Detect(Client, getRequest, 1234, 80)
if err != nil {
t.Fatal(err)
return
}
t.Logf("detected %s version %s", p.Name, p.Version)
t.Logf("detected %s version %s confidence %g%%", p.Name, p.Version, c*100)
if p.Name != ProtocolHTTP {
t.Fatalf("expected http protocol, got %s", p.Name)
return
@@ -54,7 +54,7 @@ func TestDetectHTTPRequest(t *testing.T) {
})
t.Run("Invalid SSH", func(t *testing.T) {
_, err := Detect(Server, sshBanner)
_, _, err := Detect(Server, sshBanner, 1234, 22)
if !errors.Is(err, ErrUnknown) {
t.Fatalf("expected unknown format, got error %T: %q", err, err)
} else {
@@ -94,12 +94,12 @@ func TestDetectHTTPResponse(t *testing.T) {
t.Run(name, func(t *testing.T) {
t.Run("HTTP/1.0 403", func(t *testing.T) {
p, err := Detect(Server, http10Response)
p, c, err := Detect(Server, http10Response, 80, 1234)
if err != nil {
t.Fatal(err)
return
}
t.Logf("detected %s version %s", p.Name, p.Version)
t.Logf("detected %s version %s confidence %g%%", p.Name, p.Version, c*100)
if p.Name != ProtocolHTTP {
t.Fatalf("expected http protocol, got %s", p.Name)
return
@@ -107,12 +107,12 @@ func TestDetectHTTPResponse(t *testing.T) {
})
t.Run("HTTP/1.1 200", func(t *testing.T) {
p, err := Detect(Server, responseOK)
p, c, err := Detect(Server, responseOK, 80, 1234)
if err != nil {
t.Fatal(err)
return
}
t.Logf("detected %s version %s", p.Name, p.Version)
t.Logf("detected %s version %s confidence %g%%", p.Name, p.Version, c*100)
if p.Name != ProtocolHTTP {
t.Fatalf("expected http protocol, got %s", p.Name)
return
@@ -120,12 +120,12 @@ func TestDetectHTTPResponse(t *testing.T) {
})
t.Run("HTTP/1.1 404", func(t *testing.T) {
p, err := Detect(Server, responseNotFound)
p, c, err := Detect(Server, responseNotFound, 80, 1234)
if err != nil {
t.Fatal(err)
return
}
t.Logf("detected %s version %s", p.Name, p.Version)
t.Logf("detected %s version %s confidence %g%%", p.Name, p.Version, c*100)
if p.Name != ProtocolHTTP {
t.Fatalf("expected http protocol, got %s", p.Name)
return
@@ -133,7 +133,7 @@ func TestDetectHTTPResponse(t *testing.T) {
})
t.Run("Invalid HTTP/1.1 GET", func(t *testing.T) {
_, err := Detect(Server, getRequest)
_, _, err := Detect(Server, getRequest, 1234, 80)
if !errors.Is(err, ErrUnknown) {
t.Fatalf("expected unknown format, got error %T: %q", err, err)
} else {
@@ -142,7 +142,7 @@ func TestDetectHTTPResponse(t *testing.T) {
})
t.Run("Invalid SSH", func(t *testing.T) {
_, err := Detect(Server, sshBanner)
_, _, err := Detect(Server, sshBanner, 22, 1234)
if !errors.Is(err, ErrUnknown) {
t.Fatalf("expected unknown format, got error %T: %q", err, err)
} else {

View File

@@ -9,15 +9,19 @@ func init() {
Register(Server, "\x0a", detectMySQL)
}
func detectMySQL(dir Direction, data []byte) *Protocol {
func detectMySQL(dir Direction, data []byte, srcPort, dstPort int) (proto *Protocol, confidence float64) {
if len(data) < 7 {
return nil
return nil, 0
}
// The first byte of the handshake packet is the protocol version.
// For MySQL, this is 10 (0x0A).
if data[0] != 0x0A {
return nil
return nil, 0
}
if srcPort == 3306 {
confidence = .1
}
// After the protocol version, there is a null-terminated server version string.
@@ -26,7 +30,7 @@ func detectMySQL(dir Direction, data []byte) *Protocol {
// If no null byte is found, it's not a valid banner.
if nullIndex == -1 {
return nil
return nil, 0
}
// The position of the null byte is relative to the start of the whole slice.
@@ -38,7 +42,7 @@ func detectMySQL(dir Direction, data []byte) *Protocol {
// We'll check for the 4-byte connection ID as a minimum requirement.
const connectionIDLength = 4
if len(data) < serverVersionEndPos+1+connectionIDLength {
return nil
return nil, 0
}
var version Version
@@ -47,5 +51,5 @@ func detectMySQL(dir Direction, data []byte) *Protocol {
return &Protocol{
Name: ProtocolMySQL,
Version: version,
}
}, confidence + .75
}

View File

@@ -38,23 +38,23 @@ func TestDetectMySQL(t *testing.T) {
malformedSlice := []byte{0x0a, 0x38, 0x2e, 0x30, 0x2e, 0x30, 0x01, 0x02, 0x03, 0x04, 0x05}
t.Run("MySQL 8", func(t *testing.T) {
p, _ := Detect(Server, mysql8Banner)
p, c, _ := Detect(Server, mysql8Banner, 3306, 0)
if p == nil {
t.Fatal("expected MySQL protocol, got nil")
}
t.Logf("detected %s version %s", p.Name, p.Version)
t.Logf("detected %s version %s confidence %g%%", p.Name, p.Version, c*100)
})
t.Run("MariaDB", func(t *testing.T) {
p, _ := Detect(Server, mariaDBBanner)
p, c, _ := Detect(Server, mariaDBBanner, 3306, 0)
if p == nil {
t.Fatal("expected MySQL protocol, got nil")
}
t.Logf("detected %s version %s", p.Name, p.Version)
t.Logf("detected %s version %s confidence %g%%", p.Name, p.Version, c*100)
})
t.Run("Invalid HTTP", func(t *testing.T) {
_, err := Detect(Server, httpBanner)
_, _, err := Detect(Server, httpBanner, 1234, 80)
if !errors.Is(err, ErrUnknown) {
t.Fatalf("expected unknown format, got error %T: %q", err, err)
} else {
@@ -63,7 +63,7 @@ func TestDetectMySQL(t *testing.T) {
})
t.Run("Too short", func(t *testing.T) {
_, err := Detect(Server, shortSlice)
_, _, err := Detect(Server, shortSlice, 3306, 1234)
if !errors.Is(err, ErrUnknown) {
t.Fatalf("expected unknown format, got error %T: %q", err, err)
} else {
@@ -72,7 +72,7 @@ func TestDetectMySQL(t *testing.T) {
})
t.Run("Malformed", func(t *testing.T) {
_, err := Detect(Server, malformedSlice)
_, _, err := Detect(Server, malformedSlice, 3306, 1234)
if !errors.Is(err, ErrUnknown) {
t.Fatalf("expected unknown format, got error %T: %q", err, err)
} else {

View File

@@ -20,16 +20,20 @@ func registerPostgreSQL() {
Register(Client, "????\x00\x03\x00\x00", detectPostgreSQLClient) // Startup packet, protocol 3.0
}
func detectPostgreSQLClient(dir Direction, data []byte) *Protocol {
func detectPostgreSQLClient(dir Direction, data []byte, srcPort, dstPort int) (proto *Protocol, confidence float64) {
// A client startup message needs at least 8 bytes (length + protocol version).
if len(data) < 8 {
return nil
return nil, 0
}
length := int(binary.BigEndian.Uint32(data[0:]))
if len(data) != length {
log.Printf("not postgres %q: %d != %d", data, len(data), length)
return nil
return nil, 0
}
if dstPort == 5432 {
confidence = .1
}
major := int(binary.BigEndian.Uint16(data[4:]))
@@ -42,15 +46,19 @@ func detectPostgreSQLClient(dir Direction, data []byte) *Protocol {
Minor: minor,
Patch: -1,
},
}, confidence + .75
}
}
return nil
return nil, 0
}
func detectPostgreSQLServer(dir Direction, data []byte) *Protocol {
func detectPostgreSQLServer(dir Direction, data []byte, srcPort, dstPort int) (proto *Protocol, confidence float64) {
// A server message needs at least 5 bytes (type + length).
if len(data) < 5 {
return nil
return nil, 0
}
if srcPort == 5432 {
confidence = .1
}
// All server messages (and subsequent client messages) are tagged with a single-byte type.
@@ -62,9 +70,9 @@ func detectPostgreSQLServer(dir Direction, data []byte) *Protocol {
'Z', // ReadyForQuery
'E', // ErrorResponse
'N': // NoticeResponse
return &Protocol{Name: ProtocolPostgreSQL}
return &Protocol{Name: ProtocolPostgreSQL}, confidence + .65
default:
return nil
return nil, 0
}
}

View File

@@ -22,12 +22,12 @@ func TestDetectPostgreSQLClient(t *testing.T) {
}
t.Run("Protocol 3.0", func(t *testing.T) {
p, err := Detect(Client, pgClientStartup)
p, c, err := Detect(Client, pgClientStartup, 0, 5432)
if err != nil {
t.Fatal(err)
return
}
t.Logf("detected %s version %s", p.Name, p.Version)
t.Logf("detected %s version %s confidence %g%%", p.Name, p.Version, 100*c)
if p.Name != ProtocolPostgreSQL {
t.Fatalf("expected postgres protocol, got %s", p.Name)
return
@@ -58,12 +58,12 @@ func TestDetectPostgreSQLServer(t *testing.T) {
httpBanner := []byte("GET / HTTP/1.1\r\nHost: example.com\r\n\r\n")
t.Run("AuthenticationOk", func(t *testing.T) {
p, err := Detect(Server, pgServerAuthOK)
p, c, err := Detect(Server, pgServerAuthOK, 5432, 0)
if err != nil {
t.Fatal(err)
return
}
t.Logf("detected %s version %s", p.Name, p.Version)
t.Logf("detected %s version %s confidence %g%%", p.Name, p.Version, c*100)
if p.Name != ProtocolPostgreSQL {
t.Fatalf("expected postgres protocol, got %s", p.Name)
return
@@ -71,12 +71,12 @@ func TestDetectPostgreSQLServer(t *testing.T) {
})
t.Run("ErrorResponse", func(t *testing.T) {
p, err := Detect(Server, pgServerError)
p, c, err := Detect(Server, pgServerError, 5432, 0)
if err != nil {
t.Fatal(err)
return
}
t.Logf("detected %s version %s", p.Name, p.Version)
t.Logf("detected %s version %s confidence %g%%", p.Name, p.Version, c*100)
if p.Name != ProtocolPostgreSQL {
t.Fatalf("expected postgres protocol, got %s", p.Name)
return
@@ -84,7 +84,7 @@ func TestDetectPostgreSQLServer(t *testing.T) {
})
t.Run("Invalid HTTP", func(t *testing.T) {
_, err := Detect(Server, httpBanner)
_, _, err := Detect(Server, httpBanner, 0, 80)
if !errors.Is(err, ErrUnknown) {
t.Fatalf("expected unknown format, got error %T: %q", err, err)
} else {

View File

@@ -14,10 +14,14 @@ func init() {
Register(Both, "", detectSSH)
}
func detectSSH(dir Direction, data []byte) *Protocol {
func detectSSH(dir Direction, data []byte, srcPort, dstPort int) (proto *Protocol, confidence float64) {
// The data must be at least as long as the prefix itself.
if len(data) < len(ssh20Prefix) {
return nil
return nil, 0
}
if dstPort == 22 || dstPort == 2200 || dstPort == 2222 {
confidence = .1
}
// The protocol allows for pre-banner text, so we have to check all lines.
@@ -32,7 +36,7 @@ func detectSSH(dir Direction, data []byte) *Protocol {
Patch: -1,
Extra: string(line[len(ssh20Prefix):]),
},
}
}, confidence + 0.75
}
if bytes.HasPrefix(line, []byte(ssh199Prefix)) {
return &Protocol{
@@ -43,9 +47,9 @@ func detectSSH(dir Direction, data []byte) *Protocol {
Patch: -1,
Extra: string(line[len(ssh20Prefix):]),
},
}
}, confidence + 0.75
}
}
return nil
return nil, 0
}

View File

@@ -32,7 +32,7 @@ func TestDetectSSH(t *testing.T) {
httpBanner := []byte("GET / HTTP/1.1\r\nHost: example.com\r\n\r\n")
t.Run("OpenSSH client", func(t *testing.T) {
p, err := Detect(Server, openSSHBanner)
p, _, err := Detect(Server, openSSHBanner, 0, 22)
if err != nil {
t.Fatal(err)
return
@@ -45,7 +45,7 @@ func TestDetectSSH(t *testing.T) {
})
t.Run("OpenSSH server", func(t *testing.T) {
p, err := Detect(Server, openSSHBanner)
p, _, err := Detect(Server, openSSHBanner, 0, 22)
if err != nil {
t.Fatal(err)
return
@@ -58,7 +58,7 @@ func TestDetectSSH(t *testing.T) {
})
t.Run("OpenSSH server with banner", func(t *testing.T) {
p, err := Detect(Server, preBannerSSH)
p, _, err := Detect(Server, preBannerSSH, 0, 22)
if err != nil {
t.Fatal(err)
return
@@ -71,7 +71,7 @@ func TestDetectSSH(t *testing.T) {
})
t.Run("Dropbear server", func(t *testing.T) {
p, err := Detect(Server, dropbearBanner)
p, _, err := Detect(Server, dropbearBanner, 0, 22)
if err != nil {
t.Fatal(err)
return
@@ -84,7 +84,7 @@ func TestDetectSSH(t *testing.T) {
})
t.Run("Invalid MySQL banner", func(t *testing.T) {
_, err := Detect(Server, mysqlBanner)
_, _, err := Detect(Server, mysqlBanner, 0, 3306)
if !errors.Is(err, ErrUnknown) {
t.Fatalf("expected unknown format, got error %T: %q", err, err)
} else {
@@ -93,7 +93,7 @@ func TestDetectSSH(t *testing.T) {
})
t.Run("Invalid HTTP banner", func(t *testing.T) {
_, err := Detect(Server, httpBanner)
_, _, err := Detect(Server, httpBanner, 0, 80)
if !errors.Is(err, ErrUnknown) {
t.Fatalf("expected unknown format, got error %T: %q", err, err)
} else {

View File

@@ -17,12 +17,12 @@ func registerTLS() {
Register(Both, "\x16\x03\x03", detectTLS) // TLSv1.2
}
func detectTLS(dir Direction, data []byte) *Protocol {
func detectTLS(dir Direction, data []byte, _, _ int) (proto *Protocol, confidence float64) {
stream := cryptobyte.String(data)
// A TLS packet always has a content type (1 byte), version (2 bytes) and length (2 bytes).
if len(stream) < 5 {
return nil
return nil, 0
}
// Check for TLS Handshake (type 22)
@@ -32,15 +32,18 @@ func detectTLS(dir Direction, data []byte) *Protocol {
Length uint32
}
if !stream.ReadUint8(&header.Type) || header.Type != 0x16 {
return nil
return nil, 0
}
if !stream.ReadUint16(&header.Version) {
return nil
return nil, 0
}
if !stream.ReadUint24(&header.Length) {
return nil
return nil, 0
}
// Initial confidence
confidence = 0.5
// Detected SSL/TLS version
var version dpi.TLSVersion
@@ -48,6 +51,7 @@ func detectTLS(dir Direction, data []byte) *Protocol {
if version == 0 {
if hello, err := dpi.DecodeTLSClientHelloHandshake(data); err == nil {
version = hello.Version
confidence += .45
}
}
@@ -55,6 +59,7 @@ func detectTLS(dir Direction, data []byte) *Protocol {
if version == 0 {
if hello, err := dpi.DecodeTLSServerHello(data); err == nil {
version = hello.Version
confidence += .45
}
}
@@ -68,6 +73,7 @@ func detectTLS(dir Direction, data []byte) *Protocol {
)
if stream.ReadUint24(&length) && stream.ReadUint16(&versionWord) {
version = dpi.TLSVersion(versionWord)
confidence += .25
}
}
}
@@ -82,17 +88,17 @@ func detectTLS(dir Direction, data []byte) *Protocol {
return &Protocol{
Name: ProtocolSSL,
Version: Version{Major: 3, Minor: 0, Patch: -1},
}
}, confidence
} else if version >= dpi.VersionTLS10 && version <= dpi.VersionTLS13 {
return &Protocol{
Name: ProtocolTLS,
Version: Version{Major: 1, Minor: int(uint8(version) - 1), Patch: -1},
}
}, confidence
} else if version >= dpi.VersionTLS13Draft && version <= dpi.VersionTLS13Draft23 {
return &Protocol{
Name: ProtocolTLS,
Version: Version{Major: 1, Minor: 3, Patch: -1},
}, confidence
}
}
return nil
return nil, 0
}

View File

@@ -179,7 +179,7 @@ func TestDetectTLS(t *testing.T) {
t.Run(name, func(t *testing.T) {
t.Run("SSLv3 Client Hello", func(t *testing.T) {
p, err := Detect(Client, sslV3ClientHello)
p, _, err := Detect(Client, sslV3ClientHello, 0, 0)
if err != nil {
t.Fatal(err)
return
@@ -192,7 +192,7 @@ func TestDetectTLS(t *testing.T) {
})
t.Run("TLS 1.1 Client Hello", func(t *testing.T) {
p, err := Detect(Client, tls11ClientHello)
p, _, err := Detect(Client, tls11ClientHello, 0, 0)
if err != nil {
t.Fatal(err)
return
@@ -205,7 +205,7 @@ func TestDetectTLS(t *testing.T) {
})
t.Run("TLS 1.1 partial Client Hello", func(t *testing.T) {
p, err := Detect(Client, tls11ClientHelloPartial)
p, _, err := Detect(Client, tls11ClientHelloPartial, 0, 0)
if strict {
if !errors.Is(err, ErrUnknown) {
t.Fatalf("expected unknown format, got error %T: %q", err, err)
@@ -226,7 +226,7 @@ func TestDetectTLS(t *testing.T) {
})
t.Run("TLS 1.2 Client Hello", func(t *testing.T) {
p, err := Detect(Client, tls12ClientHello)
p, _, err := Detect(Client, tls12ClientHello, 0, 0)
if err != nil {
t.Fatal(err)
return
@@ -239,7 +239,7 @@ func TestDetectTLS(t *testing.T) {
})
t.Run("TLS 1.3 Client Hello", func(t *testing.T) {
p, err := Detect(Client, tls13ClientHello)
p, _, err := Detect(Client, tls13ClientHello, 0, 0)
if err != nil {
t.Fatal(err)
return
@@ -252,7 +252,7 @@ func TestDetectTLS(t *testing.T) {
})
t.Run("Invalid PostgreSQL", func(t *testing.T) {
_, err := Detect(Server, pgClientStartup)
_, _, err := Detect(Server, pgClientStartup, 0, 0)
if !errors.Is(err, ErrUnknown) {
t.Fatalf("expected unknown format, got error %T: %q", err, err)
} else {

215
protocol/detest_test.go Normal file
View File

@@ -0,0 +1,215 @@
package protocol
import (
"math"
"testing"
)
func TestCompareFloats(t *testing.T) {
tests := []struct {
name string
a, b float64
expected int
}{
// Basic comparisons
{
name: "a less than b",
a: 1.0,
b: 2.0,
expected: -1,
},
{
name: "a greater than b",
a: 2.0,
b: 1.0,
expected: 1,
},
{
name: "a equals b exact",
a: 1.0,
b: 1.0,
expected: 0,
},
// Floating-point precision cases
{
name: "famous 0.1 + 0.2 equals 0.3 within tolerance",
a: 0.1 + 0.2,
b: 0.3,
expected: 0,
},
{
name: "very close numbers within tolerance",
a: 1.0000000001,
b: 1.0000000002,
expected: 0,
},
{
name: "numbers outside tolerance a < b",
a: 1.0,
b: 1.0001,
expected: -1,
},
{
name: "numbers outside tolerance a > b",
a: 1.0001,
b: 1.0,
expected: 1,
},
// Edge cases with very small numbers
{
name: "very small numbers equal",
a: 1e-20,
b: 1e-20,
expected: 0,
},
{
name: "very small numbers a < b",
a: 1e-15,
b: 2e-15,
expected: -1,
},
// Zero and negative zero
{
name: "zero equals zero",
a: 0.0,
b: 0.0,
expected: 0,
},
{
name: "zero equals negative zero",
a: 0.0,
b: -0.0,
expected: 0,
},
{
name: "zero less than small positive",
a: 0.0,
b: 1e-20,
expected: -1,
},
{
name: "zero greater than small negative",
a: 0.0,
b: -1e-20,
expected: 1,
},
// Negative numbers
{
name: "negative numbers a > b",
a: -1.0,
b: -2.0,
expected: 1,
},
{
name: "negative numbers a < b",
a: -2.0,
b: -1.0,
expected: -1,
},
{
name: "negative numbers equal",
a: -1.0,
b: -1.0,
expected: 0,
},
// Mixed signs
{
name: "negative less than positive",
a: -1.0,
b: 1.0,
expected: -1,
},
{
name: "positive greater than negative",
a: 1.0,
b: -1.0,
expected: 1,
},
// Special values: NaN
{
name: "NaN equals NaN",
a: math.NaN(),
b: math.NaN(),
expected: 0,
},
{
name: "NaN less than number",
a: math.NaN(),
b: 1.0,
expected: -1,
},
{
name: "number greater than NaN",
a: 1.0,
b: math.NaN(),
expected: 1,
},
// Special values: Infinity
{
name: "positive infinity equals positive infinity",
a: math.Inf(1),
b: math.Inf(1),
expected: 0,
},
{
name: "negative infinity equals negative infinity",
a: math.Inf(-1),
b: math.Inf(-1),
expected: 0,
},
{
name: "positive infinity greater than negative infinity",
a: math.Inf(1),
b: math.Inf(-1),
expected: 1,
},
{
name: "negative infinity less than positive infinity",
a: math.Inf(-1),
b: math.Inf(1),
expected: -1,
},
{
name: "positive infinity greater than large number",
a: math.Inf(1),
b: 1e308,
expected: 1,
},
{
name: "negative infinity less than small number",
a: math.Inf(-1),
b: -1e308,
expected: -1,
},
// Large numbers
{
name: "large numbers equal",
a: 1e15,
b: 1e15,
expected: 0,
},
{
name: "large numbers a < b",
a: 1e15,
b: 2e15,
expected: -1,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
result := compareFloats(test.a, test.b)
if result != test.expected {
t.Errorf("compareFloats(%g, %g) = %d, want %d", test.a, test.b, result, test.expected)
}
})
}
}

View File

@@ -2,6 +2,7 @@ package protocol
import (
"net"
"strconv"
"sync/atomic"
"time"
)
@@ -10,13 +11,16 @@ import (
type Intercepted struct {
Direction Direction
Protocol *Protocol
Confidence float64
Error error
}
// Interceptor intercepts reads from client or server.
type Interceptor struct {
clientPort int
clientBytes chan []byte
clientReader *readInterceptor
serverPort int
serverBytes chan []byte
serverReader *readInterceptor
}
@@ -71,6 +75,7 @@ func (i *Interceptor) Client(c net.Conn) net.Conn {
if ri, ok := c.(*readInterceptor); ok {
return ri
}
i.clientPort = getPortFromAddr(c.RemoteAddr())
i.clientReader = newReadInterceptor(c, i.clientBytes)
return i.clientReader
}
@@ -80,6 +85,7 @@ func (i *Interceptor) Server(c net.Conn) net.Conn {
if ri, ok := c.(*readInterceptor); ok {
return ri
}
i.serverPort = getPortFromAddr(c.RemoteAddr())
i.serverReader = newReadInterceptor(c, i.serverBytes)
return i.serverReader
}
@@ -107,18 +113,20 @@ func (i *Interceptor) Detect(timeout time.Duration) <-chan *Intercepted {
}
case data := <-i.clientBytes: // client sent banner
p, err := Detect(Client, data)
p, c, err := Detect(Client, data, i.clientPort, i.serverPort)
interceptc <- &Intercepted{
Direction: Client,
Protocol: p,
Confidence: c,
Error: err,
}
case data := <-i.serverBytes: // server sent banner
p, err := Detect(Server, data)
p, c, err := Detect(Server, data, i.serverPort, i.clientPort)
interceptc <- &Intercepted{
Direction: Server,
Protocol: p,
Confidence: c,
Error: err,
}
}
@@ -126,3 +134,28 @@ func (i *Interceptor) Detect(timeout time.Duration) <-chan *Intercepted {
return interceptc
}
func getPortFromAddr(addr net.Addr) int {
switch a := addr.(type) {
case *net.TCPAddr:
return a.Port
case *net.UDPAddr:
return a.Port
case *net.IPAddr:
// IPAddr doesn't have a port
return 0
default:
// Fallback to parsing
_, service, err := net.SplitHostPort(addr.String())
if err != nil {
return 0
}
if port, err := strconv.Atoi(service); err == nil {
return port
}
if port, err := net.LookupPort(addr.Network(), service); err == nil {
return port
}
return 0
}
}

View File

@@ -35,42 +35,43 @@ type connLimiter struct {
acceptError atomic.Value
}
func (l *connLimiter) init(readData, writeData []byte) {
l.acceptOnce.Do(func() {
func (limiter *connLimiter) init(readData, writeData []byte) {
limiter.acceptOnce.Do(func() {
var (
dir Direction
data []byte
srcPort, dstPort int
)
if readData != nil {
// init called by initial read
dir, data = Server, readData
dir, data, srcPort, dstPort = Server, readData, getPortFromAddr(limiter.LocalAddr()), getPortFromAddr(limiter.RemoteAddr())
} else {
// init called by initial write
dir, data = Client, writeData
dir, data, srcPort, dstPort = Client, writeData, getPortFromAddr(limiter.RemoteAddr()), getPortFromAddr(limiter.LocalAddr())
}
protocol, _ := Detect(dir, data)
if err := l.accept(dir, protocol); err != nil {
l.acceptError.Store(err)
protocol, _, _ := Detect(dir, data, srcPort, dstPort)
if err := limiter.accept(dir, protocol); err != nil {
limiter.acceptError.Store(err)
}
})
}
func (l *connLimiter) Read(p []byte) (n int, err error) {
func (limiter *connLimiter) Read(p []byte) (n int, err error) {
var ok bool
if err, ok = l.acceptError.Load().(error); ok && err != nil {
if err, ok = limiter.acceptError.Load().(error); ok && err != nil {
return
}
if n, err = l.Conn.Read(p); n > 0 {
l.init(p[:n], nil)
if n, err = limiter.Conn.Read(p); n > 0 {
limiter.init(p[:n], nil)
}
return
}
func (l *connLimiter) Write(p []byte) (n int, err error) {
l.init(nil, p)
func (limiter *connLimiter) Write(p []byte) (n int, err error) {
limiter.init(nil, p)
var ok bool
if err, ok = l.acceptError.Load().(error); ok && err != nil {
if err, ok = limiter.acceptError.Load().(error); ok && err != nil {
return
}
return l.Conn.Write(p)
return limiter.Conn.Write(p)
}

77
protocol/match.go Normal file
View File

@@ -0,0 +1,77 @@
package protocol
// MatchPattern checks if the byte slice matches the magic string pattern.
//
// '?' matches any single character
// '*' matches zero or more characters
// '\' escapes special characters ('?', '*', '\')
// All other characters must match exactly
//
// Returns true if all magic bytes are matched, even if input has extra bytes.
func Match(magic string, input []byte) bool {
return match(magic, input, 0, 0)
}
// match is a recursive helper function that implements the matching logic
func match(magic string, input []byte, magicIndex, inputIndex int) bool {
// If we've reached the end of magic string, we've successfully matched all magic bytes
// It doesn't matter if there are extra bytes in the input
if magicIndex == len(magic) {
return true
}
// Handle escape character
if magic[magicIndex] == '\\' {
// Check if there's a next character in magic
if magicIndex+1 >= len(magic) {
// Backslash at end of magic string - treat as literal backslash
if inputIndex >= len(input) || input[inputIndex] != '\\' {
return false
}
return match(magic, input, magicIndex+1, inputIndex+1)
}
// Escape the next character - we need to match it literally
escapedChar := magic[magicIndex+1]
if inputIndex >= len(input) || input[inputIndex] != escapedChar {
return false
}
// Skip both the backslash and the escaped character in magic, move one in input
return match(magic, input, magicIndex+2, inputIndex+1)
}
// If we've reached the end of input but not magic string
if inputIndex == len(input) {
// If we have '*' at the current position, it can match zero characters
if magic[magicIndex] == '*' {
return match(magic, input, magicIndex+1, inputIndex)
}
return false
}
// Handle '*' character - matches zero or more characters
if magic[magicIndex] == '*' {
// Try matching zero characters
if match(magic, input, magicIndex+1, inputIndex) {
return true
}
// Try matching one or more characters
if inputIndex < len(input) && match(magic, input, magicIndex, inputIndex+1) {
return true
}
return false
}
// Handle '?' character - matches any single character
if magic[magicIndex] == '?' {
return match(magic, input, magicIndex+1, inputIndex+1)
}
// Handle exact character match
if magic[magicIndex] == input[inputIndex] {
return match(magic, input, magicIndex+1, inputIndex+1)
}
// No match found
return false
}

227
protocol/match_test.go Normal file
View File

@@ -0,0 +1,227 @@
package protocol
import (
"testing"
)
func TestMatch(t *testing.T) {
tests := []struct {
name string
magic string
input []byte
expected bool
}{
// Basic escaping tests
{
name: "escape star",
magic: "\\*",
input: []byte("*"),
expected: true,
},
{
name: "escape star no match",
magic: "\\*",
input: []byte("a"),
expected: false,
},
{
name: "escape star with longer input",
magic: "\\*",
input: []byte("*extra"),
expected: true,
},
{
name: "escape question mark",
magic: "\\?",
input: []byte("?"),
expected: true,
},
{
name: "escape question mark no match",
magic: "\\?",
input: []byte("a"),
expected: false,
},
{
name: "escape backslash",
magic: "\\\\",
input: []byte("\\"),
expected: true,
},
{
name: "escape backslash no match",
magic: "\\\\",
input: []byte("a"),
expected: false,
},
{
name: "escape backslash with longer input",
magic: "\\\\",
input: []byte("\\extra"),
expected: true,
},
// Multiple escaped characters
{
name: "multiple escaped characters",
magic: "\\*\\?\\\\",
input: []byte("*?\\"),
expected: true,
},
{
name: "multiple escaped characters with longer input",
magic: "\\*\\?\\\\",
input: []byte("*?\\extra"),
expected: true,
},
{
name: "mixed escaped characters",
magic: "a\\*b\\?c\\\\d",
input: []byte("a*b?c\\d"),
expected: true,
},
// Escaping combined with wildcards
{
name: "star then escaped star",
magic: "*\\*",
input: []byte("anything*"),
expected: true,
},
{
name: "star then escaped star must end with star",
magic: "*\\*",
input: []byte("anything"),
expected: false,
},
{
name: "star then escaped question",
magic: "*\\?",
input: []byte("hello?"),
expected: true,
},
{
name: "question then escaped star",
magic: "?\\*",
input: []byte("a*"),
expected: true,
},
{
name: "question then escaped star wrong second char",
magic: "?\\*",
input: []byte("aa"),
expected: false,
},
{
name: "wildcards between escaped characters",
magic: "*\\\\*",
input: []byte("path\\to\\file"),
expected: true,
},
// Real-world escaping scenarios
{
name: "file pattern with literal star",
magic: "file\\*.txt",
input: []byte("file*.txt"),
expected: true,
},
{
name: "file pattern with literal star no match",
magic: "file\\*.txt",
input: []byte("filex.txt"),
expected: false,
},
{
name: "pattern with literal question",
magic: "what\\?*",
input: []byte("what? is this"),
expected: true,
},
{
name: "pattern with literal question must have question",
magic: "what\\?*",
input: []byte("what is this"),
expected: false,
},
{
name: "database like pattern",
magic: "table_\\*_\\?",
input: []byte("table_*_?"),
expected: true,
},
{
name: "database like pattern with longer input",
magic: "table_\\*_\\?",
input: []byte("table_*_?backup"),
expected: true,
},
// Edge cases with escaping
{
name: "backslash at end of magic",
magic: "test\\",
input: []byte("test\\"),
expected: true,
},
{
name: "backslash at end of magic no match",
magic: "test\\",
input: []byte("test"),
expected: false,
},
{
name: "only backslash",
magic: "\\",
input: []byte("\\"),
expected: true,
},
{
name: "consecutive backslashes",
magic: "\\\\\\\\",
input: []byte("\\\\"),
expected: true,
},
// Mixed scenarios with both escaping and wildcards
{
name: "escaped wildcards in middle",
magic: "a*\\?b*\\*c",
input: []byte("aanything?banything*c"),
expected: true,
},
{
name: "escaped wildcards pattern",
magic: "select * from \\*",
input: []byte("select * from *"),
expected: true,
},
{
name: "escaped wildcards pattern with longer input",
magic: "select * from *\\*",
input: []byte("select name from users*"),
expected: true,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
result := Match(test.magic, test.input)
if result != test.expected {
t.Errorf("Match(%q, %q) = %v, want %v",
test.magic, string(test.input), result, test.expected)
}
})
}
}
// Benchmark test with escaping
func BenchmarkMatch(b *testing.B) {
magic := "file\\*\\?*\\\\*.txt"
input := []byte("file*?name\\backup.txt")
b.ResetTimer()
for b.Loop() {
Match(magic, input)
}
}