From 2ab59437fac57c26529c086dbe92f33faf09bfce Mon Sep 17 00:00:00 2001 From: maze Date: Thu, 9 Oct 2025 11:54:43 +0200 Subject: [PATCH] Refactored detection logic to include ports and a confidence score --- go.mod | 2 +- protocol/detect.go | 88 ++++++++++-- protocol/detect_http.go | 67 +++++---- protocol/detect_http_test.go | 26 ++-- protocol/detect_mysql.go | 16 ++- protocol/detect_mysql_test.go | 14 +- protocol/detect_postgres.go | 26 ++-- protocol/detect_postgres_test.go | 14 +- protocol/detect_ssh.go | 14 +- protocol/detect_ssh_test.go | 12 +- protocol/detect_tls.go | 24 ++-- protocol/detect_tls_test.go | 12 +- protocol/detest_test.go | 215 +++++++++++++++++++++++++++++ protocol/intercept.go | 55 ++++++-- protocol/limit.go | 35 ++--- protocol/match.go | 77 +++++++++++ protocol/match_test.go | 227 +++++++++++++++++++++++++++++++ 17 files changed, 795 insertions(+), 129 deletions(-) create mode 100644 protocol/detest_test.go create mode 100644 protocol/match.go create mode 100644 protocol/match_test.go diff --git a/go.mod b/go.mod index ce95ba7..e5b104f 100644 --- a/go.mod +++ b/go.mod @@ -2,4 +2,4 @@ module git.maze.io/go/dpi go 1.25 -require golang.org/x/crypto v0.42.0 // indirect +require golang.org/x/crypto v0.42.0 diff --git a/protocol/detect.go b/protocol/detect.go index bc1f641..8aafd9a 100644 --- a/protocol/detect.go +++ b/protocol/detect.go @@ -3,6 +3,8 @@ package protocol import ( "errors" "fmt" + "math" + "slices" "sync" "sync/atomic" ) @@ -66,7 +68,15 @@ var ( atomicFormats atomic.Value ) -type DetectFunc func(Direction, []byte) *Protocol +type detectResult struct { + // Protocol detected, nil if no detection. + Protocol *Protocol + + // Confidence level [0..1]. + Confidence float64 +} + +type DetectFunc func(dir Direction, data []byte, srcPort, dstPort int) (proto *Protocol, confidence float64) func Register(dir Direction, magic string, detect DetectFunc) { formatsMu.Lock() @@ -97,17 +107,77 @@ func matchMagic(magic string, data []byte) bool { } // Detect a protocol based on the provided data. -func Detect(dir Direction, data []byte) (*Protocol, error) { - formats, _ := atomicFormats.Load().([]format) - for _, f := range formats { - if f.dir.Contains(dir) { +func Detect(dir Direction, data []byte, srcPort, dstPort int) (proto *Protocol, confidence float64, err error) { + var ( + formats, _ = atomicFormats.Load().([]format) + results []detectResult + ) + for _, format := range formats { + if format.dir.Contains(dir) { // Check the buffer to see if we have sufficient bytes - if matchMagic(f.magic, data) { - if p := f.detect(dir, data); p != nil { - return p, nil + if matchMagic(format.magic, data) { + if proto, confidence := format.detect(dir, data, srcPort, dstPort); proto != nil { + results = append(results, detectResult{proto, confidence}) } } } } - return nil, ErrUnknown + + if len(results) > 0 { + slices.SortStableFunc(results, func(a, b detectResult) int { + return compareFloats(b.Confidence, a.Confidence) + }) + return results[0].Protocol, results[0].Confidence, nil + } + + return nil, 0, ErrUnknown +} + +// compareFloats compares two float64 numbers with tolerance for floating-point precision. +// +// Returns: +// +// -1 if a < b +// 0 if a == b (within tolerance) +// 1 if a > b +func compareFloats(a, b float64) int { + // Define the tolerance for floating-point comparison + const tolerance = 1e-9 + + // Handle special cases: NaN and Inf + if math.IsNaN(a) || math.IsNaN(b) { + // NaN is considered equal to itself, otherwise not equal + if math.IsNaN(a) && math.IsNaN(b) { + return 0 + } + if math.IsNaN(a) { + return -1 // NaN is considered less than any number + } + return 1 // Any number is greater than NaN + } + + // Handle infinity cases + if math.IsInf(a, 0) || math.IsInf(b, 0) { + if a < b { + return -1 + } + if a > b { + return 1 + } + return 0 // Both are same infinity + } + + // Compare with tolerance for regular numbers + diff := a - b + + // If the absolute difference is within tolerance, consider them equal + if math.Abs(diff) < tolerance { + return 0 + } + + // Otherwise return the comparison result + if diff < 0 { + return -1 + } + return 1 } diff --git a/protocol/detect_http.go b/protocol/detect_http.go index f487924..a5863c1 100644 --- a/protocol/detect_http.go +++ b/protocol/detect_http.go @@ -12,10 +12,17 @@ func init() { Register(Server, "HTTP/?.", detectHTTPResponse) } -func detectHTTPRequest(dir Direction, data []byte) *Protocol { +func detectHTTPRequest(dir Direction, data []byte, srcPort, dstPort int) (proto *Protocol, confidence float64) { // A minimal request "GET / HTTP/1.0\r\n" is > 8 bytes. if len(data) < 8 { - return nil + return nil, 0 + } + + switch dstPort { + case 80, 8080: // Common HTTP ports + confidence = +.1 + case 3128: // Common HTTP proxy port + confidence = -.1 } if Strict { @@ -31,38 +38,27 @@ func detectHTTPRequest(dir Direction, data []byte) *Protocol { Minor: request.ProtoMinor, Patch: -1, }, - } + }, confidence + .85 } - r.Reset(bytes.NewReader(b)) - if response, err := http.ReadResponse(r, nil); err == nil { - return &Protocol{ - Name: ProtocolHTTP, - Version: Version{ - Major: response.ProtoMajor, - Minor: response.ProtoMinor, - Patch: -1, - }, - } - } - return nil + return nil, 0 } crlfIndex := bytes.IndexFunc(data, func(r rune) bool { return r == '\r' || r == '\n' }) if crlfIndex == -1 { - return nil + return nil, 0 } // A request has three, space-separated parts. part := bytes.Split(data[:crlfIndex], []byte(" ")) if len(part) != 3 { - return nil + return nil, 0 } // The last part starts with "HTTP/". if !bytes.HasPrefix(part[2], []byte("HTTP/1")) { - return nil + return nil, 0 } var version = Version{Patch: -1} @@ -71,17 +67,42 @@ func detectHTTPRequest(dir Direction, data []byte) *Protocol { return &Protocol{ Name: ProtocolHTTP, Version: version, - } + }, confidence + .75 } -func detectHTTPResponse(dir Direction, data []byte) *Protocol { +func detectHTTPResponse(dir Direction, data []byte, srcPort, dstPort int) (proto *Protocol, confidence float64) { if !dir.Contains(Server) { - return nil + return nil, 0 } // A minimal response "HTTP/1.0 200 OK\r\n" is > 8 bytes. if len(data) < 8 { - return nil + return nil, 0 + } + + switch srcPort { + case 80, 8080: // Common HTTP ports + confidence = +.1 + case 3128: // Common HTTP proxy port + confidence = -.1 + } + + if Strict { + var ( + b = append(data, '\r', '\n') + r = bufio.NewReader(bytes.NewReader(b)) + ) + if response, err := http.ReadResponse(r, nil); err == nil { + return &Protocol{ + Name: ProtocolHTTP, + Version: Version{ + Major: response.ProtoMajor, + Minor: response.ProtoMinor, + Patch: -1, + }, + }, confidence + .85 + } + return nil, 0 } var version = Version{Patch: -1} @@ -90,5 +111,5 @@ func detectHTTPResponse(dir Direction, data []byte) *Protocol { return &Protocol{ Name: ProtocolHTTP, Version: version, - } + }, confidence + .75 } diff --git a/protocol/detect_http_test.go b/protocol/detect_http_test.go index 4374871..0942d28 100644 --- a/protocol/detect_http_test.go +++ b/protocol/detect_http_test.go @@ -28,12 +28,12 @@ func TestDetectHTTPRequest(t *testing.T) { t.Run(name, func(t *testing.T) { t.Run("HTTP/1.0 GET", func(t *testing.T) { - p, err := Detect(Client, http10Request) + p, c, err := Detect(Client, http10Request, 1234, 80) if err != nil { t.Fatal(err) return } - t.Logf("detected %s version %s", p.Name, p.Version) + t.Logf("detected %s version %s confidence %g%%", p.Name, p.Version, c*100) if p.Name != ProtocolHTTP { t.Fatalf("expected http protocol, got %s", p.Name) return @@ -41,12 +41,12 @@ func TestDetectHTTPRequest(t *testing.T) { }) t.Run("HTTP/1.1 GET", func(t *testing.T) { - p, err := Detect(Client, getRequest) + p, c, err := Detect(Client, getRequest, 1234, 80) if err != nil { t.Fatal(err) return } - t.Logf("detected %s version %s", p.Name, p.Version) + t.Logf("detected %s version %s confidence %g%%", p.Name, p.Version, c*100) if p.Name != ProtocolHTTP { t.Fatalf("expected http protocol, got %s", p.Name) return @@ -54,7 +54,7 @@ func TestDetectHTTPRequest(t *testing.T) { }) t.Run("Invalid SSH", func(t *testing.T) { - _, err := Detect(Server, sshBanner) + _, _, err := Detect(Server, sshBanner, 1234, 22) if !errors.Is(err, ErrUnknown) { t.Fatalf("expected unknown format, got error %T: %q", err, err) } else { @@ -94,12 +94,12 @@ func TestDetectHTTPResponse(t *testing.T) { t.Run(name, func(t *testing.T) { t.Run("HTTP/1.0 403", func(t *testing.T) { - p, err := Detect(Server, http10Response) + p, c, err := Detect(Server, http10Response, 80, 1234) if err != nil { t.Fatal(err) return } - t.Logf("detected %s version %s", p.Name, p.Version) + t.Logf("detected %s version %s confidence %g%%", p.Name, p.Version, c*100) if p.Name != ProtocolHTTP { t.Fatalf("expected http protocol, got %s", p.Name) return @@ -107,12 +107,12 @@ func TestDetectHTTPResponse(t *testing.T) { }) t.Run("HTTP/1.1 200", func(t *testing.T) { - p, err := Detect(Server, responseOK) + p, c, err := Detect(Server, responseOK, 80, 1234) if err != nil { t.Fatal(err) return } - t.Logf("detected %s version %s", p.Name, p.Version) + t.Logf("detected %s version %s confidence %g%%", p.Name, p.Version, c*100) if p.Name != ProtocolHTTP { t.Fatalf("expected http protocol, got %s", p.Name) return @@ -120,12 +120,12 @@ func TestDetectHTTPResponse(t *testing.T) { }) t.Run("HTTP/1.1 404", func(t *testing.T) { - p, err := Detect(Server, responseNotFound) + p, c, err := Detect(Server, responseNotFound, 80, 1234) if err != nil { t.Fatal(err) return } - t.Logf("detected %s version %s", p.Name, p.Version) + t.Logf("detected %s version %s confidence %g%%", p.Name, p.Version, c*100) if p.Name != ProtocolHTTP { t.Fatalf("expected http protocol, got %s", p.Name) return @@ -133,7 +133,7 @@ func TestDetectHTTPResponse(t *testing.T) { }) t.Run("Invalid HTTP/1.1 GET", func(t *testing.T) { - _, err := Detect(Server, getRequest) + _, _, err := Detect(Server, getRequest, 1234, 80) if !errors.Is(err, ErrUnknown) { t.Fatalf("expected unknown format, got error %T: %q", err, err) } else { @@ -142,7 +142,7 @@ func TestDetectHTTPResponse(t *testing.T) { }) t.Run("Invalid SSH", func(t *testing.T) { - _, err := Detect(Server, sshBanner) + _, _, err := Detect(Server, sshBanner, 22, 1234) if !errors.Is(err, ErrUnknown) { t.Fatalf("expected unknown format, got error %T: %q", err, err) } else { diff --git a/protocol/detect_mysql.go b/protocol/detect_mysql.go index 03cb3cd..da65176 100644 --- a/protocol/detect_mysql.go +++ b/protocol/detect_mysql.go @@ -9,15 +9,19 @@ func init() { Register(Server, "\x0a", detectMySQL) } -func detectMySQL(dir Direction, data []byte) *Protocol { +func detectMySQL(dir Direction, data []byte, srcPort, dstPort int) (proto *Protocol, confidence float64) { if len(data) < 7 { - return nil + return nil, 0 } // The first byte of the handshake packet is the protocol version. // For MySQL, this is 10 (0x0A). if data[0] != 0x0A { - return nil + return nil, 0 + } + + if srcPort == 3306 { + confidence = .1 } // After the protocol version, there is a null-terminated server version string. @@ -26,7 +30,7 @@ func detectMySQL(dir Direction, data []byte) *Protocol { // If no null byte is found, it's not a valid banner. if nullIndex == -1 { - return nil + return nil, 0 } // The position of the null byte is relative to the start of the whole slice. @@ -38,7 +42,7 @@ func detectMySQL(dir Direction, data []byte) *Protocol { // We'll check for the 4-byte connection ID as a minimum requirement. const connectionIDLength = 4 if len(data) < serverVersionEndPos+1+connectionIDLength { - return nil + return nil, 0 } var version Version @@ -47,5 +51,5 @@ func detectMySQL(dir Direction, data []byte) *Protocol { return &Protocol{ Name: ProtocolMySQL, Version: version, - } + }, confidence + .75 } diff --git a/protocol/detect_mysql_test.go b/protocol/detect_mysql_test.go index 3e019b2..866a50a 100644 --- a/protocol/detect_mysql_test.go +++ b/protocol/detect_mysql_test.go @@ -38,23 +38,23 @@ func TestDetectMySQL(t *testing.T) { malformedSlice := []byte{0x0a, 0x38, 0x2e, 0x30, 0x2e, 0x30, 0x01, 0x02, 0x03, 0x04, 0x05} t.Run("MySQL 8", func(t *testing.T) { - p, _ := Detect(Server, mysql8Banner) + p, c, _ := Detect(Server, mysql8Banner, 3306, 0) if p == nil { t.Fatal("expected MySQL protocol, got nil") } - t.Logf("detected %s version %s", p.Name, p.Version) + t.Logf("detected %s version %s confidence %g%%", p.Name, p.Version, c*100) }) t.Run("MariaDB", func(t *testing.T) { - p, _ := Detect(Server, mariaDBBanner) + p, c, _ := Detect(Server, mariaDBBanner, 3306, 0) if p == nil { t.Fatal("expected MySQL protocol, got nil") } - t.Logf("detected %s version %s", p.Name, p.Version) + t.Logf("detected %s version %s confidence %g%%", p.Name, p.Version, c*100) }) t.Run("Invalid HTTP", func(t *testing.T) { - _, err := Detect(Server, httpBanner) + _, _, err := Detect(Server, httpBanner, 1234, 80) if !errors.Is(err, ErrUnknown) { t.Fatalf("expected unknown format, got error %T: %q", err, err) } else { @@ -63,7 +63,7 @@ func TestDetectMySQL(t *testing.T) { }) t.Run("Too short", func(t *testing.T) { - _, err := Detect(Server, shortSlice) + _, _, err := Detect(Server, shortSlice, 3306, 1234) if !errors.Is(err, ErrUnknown) { t.Fatalf("expected unknown format, got error %T: %q", err, err) } else { @@ -72,7 +72,7 @@ func TestDetectMySQL(t *testing.T) { }) t.Run("Malformed", func(t *testing.T) { - _, err := Detect(Server, malformedSlice) + _, _, err := Detect(Server, malformedSlice, 3306, 1234) if !errors.Is(err, ErrUnknown) { t.Fatalf("expected unknown format, got error %T: %q", err, err) } else { diff --git a/protocol/detect_postgres.go b/protocol/detect_postgres.go index 4235534..d6405ce 100644 --- a/protocol/detect_postgres.go +++ b/protocol/detect_postgres.go @@ -20,16 +20,20 @@ func registerPostgreSQL() { Register(Client, "????\x00\x03\x00\x00", detectPostgreSQLClient) // Startup packet, protocol 3.0 } -func detectPostgreSQLClient(dir Direction, data []byte) *Protocol { +func detectPostgreSQLClient(dir Direction, data []byte, srcPort, dstPort int) (proto *Protocol, confidence float64) { // A client startup message needs at least 8 bytes (length + protocol version). if len(data) < 8 { - return nil + return nil, 0 } length := int(binary.BigEndian.Uint32(data[0:])) if len(data) != length { log.Printf("not postgres %q: %d != %d", data, len(data), length) - return nil + return nil, 0 + } + + if dstPort == 5432 { + confidence = .1 } major := int(binary.BigEndian.Uint16(data[4:])) @@ -42,15 +46,19 @@ func detectPostgreSQLClient(dir Direction, data []byte) *Protocol { Minor: minor, Patch: -1, }, - } + }, confidence + .75 } - return nil + return nil, 0 } -func detectPostgreSQLServer(dir Direction, data []byte) *Protocol { +func detectPostgreSQLServer(dir Direction, data []byte, srcPort, dstPort int) (proto *Protocol, confidence float64) { // A server message needs at least 5 bytes (type + length). if len(data) < 5 { - return nil + return nil, 0 + } + + if srcPort == 5432 { + confidence = .1 } // All server messages (and subsequent client messages) are tagged with a single-byte type. @@ -62,9 +70,9 @@ func detectPostgreSQLServer(dir Direction, data []byte) *Protocol { 'Z', // ReadyForQuery 'E', // ErrorResponse 'N': // NoticeResponse - return &Protocol{Name: ProtocolPostgreSQL} + return &Protocol{Name: ProtocolPostgreSQL}, confidence + .65 default: - return nil + return nil, 0 } } diff --git a/protocol/detect_postgres_test.go b/protocol/detect_postgres_test.go index c35f338..89d8035 100644 --- a/protocol/detect_postgres_test.go +++ b/protocol/detect_postgres_test.go @@ -22,12 +22,12 @@ func TestDetectPostgreSQLClient(t *testing.T) { } t.Run("Protocol 3.0", func(t *testing.T) { - p, err := Detect(Client, pgClientStartup) + p, c, err := Detect(Client, pgClientStartup, 0, 5432) if err != nil { t.Fatal(err) return } - t.Logf("detected %s version %s", p.Name, p.Version) + t.Logf("detected %s version %s confidence %g%%", p.Name, p.Version, 100*c) if p.Name != ProtocolPostgreSQL { t.Fatalf("expected postgres protocol, got %s", p.Name) return @@ -58,12 +58,12 @@ func TestDetectPostgreSQLServer(t *testing.T) { httpBanner := []byte("GET / HTTP/1.1\r\nHost: example.com\r\n\r\n") t.Run("AuthenticationOk", func(t *testing.T) { - p, err := Detect(Server, pgServerAuthOK) + p, c, err := Detect(Server, pgServerAuthOK, 5432, 0) if err != nil { t.Fatal(err) return } - t.Logf("detected %s version %s", p.Name, p.Version) + t.Logf("detected %s version %s confidence %g%%", p.Name, p.Version, c*100) if p.Name != ProtocolPostgreSQL { t.Fatalf("expected postgres protocol, got %s", p.Name) return @@ -71,12 +71,12 @@ func TestDetectPostgreSQLServer(t *testing.T) { }) t.Run("ErrorResponse", func(t *testing.T) { - p, err := Detect(Server, pgServerError) + p, c, err := Detect(Server, pgServerError, 5432, 0) if err != nil { t.Fatal(err) return } - t.Logf("detected %s version %s", p.Name, p.Version) + t.Logf("detected %s version %s confidence %g%%", p.Name, p.Version, c*100) if p.Name != ProtocolPostgreSQL { t.Fatalf("expected postgres protocol, got %s", p.Name) return @@ -84,7 +84,7 @@ func TestDetectPostgreSQLServer(t *testing.T) { }) t.Run("Invalid HTTP", func(t *testing.T) { - _, err := Detect(Server, httpBanner) + _, _, err := Detect(Server, httpBanner, 0, 80) if !errors.Is(err, ErrUnknown) { t.Fatalf("expected unknown format, got error %T: %q", err, err) } else { diff --git a/protocol/detect_ssh.go b/protocol/detect_ssh.go index 8953900..bb4e9a1 100644 --- a/protocol/detect_ssh.go +++ b/protocol/detect_ssh.go @@ -14,10 +14,14 @@ func init() { Register(Both, "", detectSSH) } -func detectSSH(dir Direction, data []byte) *Protocol { +func detectSSH(dir Direction, data []byte, srcPort, dstPort int) (proto *Protocol, confidence float64) { // The data must be at least as long as the prefix itself. if len(data) < len(ssh20Prefix) { - return nil + return nil, 0 + } + + if dstPort == 22 || dstPort == 2200 || dstPort == 2222 { + confidence = .1 } // The protocol allows for pre-banner text, so we have to check all lines. @@ -32,7 +36,7 @@ func detectSSH(dir Direction, data []byte) *Protocol { Patch: -1, Extra: string(line[len(ssh20Prefix):]), }, - } + }, confidence + 0.75 } if bytes.HasPrefix(line, []byte(ssh199Prefix)) { return &Protocol{ @@ -43,9 +47,9 @@ func detectSSH(dir Direction, data []byte) *Protocol { Patch: -1, Extra: string(line[len(ssh20Prefix):]), }, - } + }, confidence + 0.75 } } - return nil + return nil, 0 } diff --git a/protocol/detect_ssh_test.go b/protocol/detect_ssh_test.go index 82fd982..1009f95 100644 --- a/protocol/detect_ssh_test.go +++ b/protocol/detect_ssh_test.go @@ -32,7 +32,7 @@ func TestDetectSSH(t *testing.T) { httpBanner := []byte("GET / HTTP/1.1\r\nHost: example.com\r\n\r\n") t.Run("OpenSSH client", func(t *testing.T) { - p, err := Detect(Server, openSSHBanner) + p, _, err := Detect(Server, openSSHBanner, 0, 22) if err != nil { t.Fatal(err) return @@ -45,7 +45,7 @@ func TestDetectSSH(t *testing.T) { }) t.Run("OpenSSH server", func(t *testing.T) { - p, err := Detect(Server, openSSHBanner) + p, _, err := Detect(Server, openSSHBanner, 0, 22) if err != nil { t.Fatal(err) return @@ -58,7 +58,7 @@ func TestDetectSSH(t *testing.T) { }) t.Run("OpenSSH server with banner", func(t *testing.T) { - p, err := Detect(Server, preBannerSSH) + p, _, err := Detect(Server, preBannerSSH, 0, 22) if err != nil { t.Fatal(err) return @@ -71,7 +71,7 @@ func TestDetectSSH(t *testing.T) { }) t.Run("Dropbear server", func(t *testing.T) { - p, err := Detect(Server, dropbearBanner) + p, _, err := Detect(Server, dropbearBanner, 0, 22) if err != nil { t.Fatal(err) return @@ -84,7 +84,7 @@ func TestDetectSSH(t *testing.T) { }) t.Run("Invalid MySQL banner", func(t *testing.T) { - _, err := Detect(Server, mysqlBanner) + _, _, err := Detect(Server, mysqlBanner, 0, 3306) if !errors.Is(err, ErrUnknown) { t.Fatalf("expected unknown format, got error %T: %q", err, err) } else { @@ -93,7 +93,7 @@ func TestDetectSSH(t *testing.T) { }) t.Run("Invalid HTTP banner", func(t *testing.T) { - _, err := Detect(Server, httpBanner) + _, _, err := Detect(Server, httpBanner, 0, 80) if !errors.Is(err, ErrUnknown) { t.Fatalf("expected unknown format, got error %T: %q", err, err) } else { diff --git a/protocol/detect_tls.go b/protocol/detect_tls.go index 5838dfe..2d64892 100644 --- a/protocol/detect_tls.go +++ b/protocol/detect_tls.go @@ -17,12 +17,12 @@ func registerTLS() { Register(Both, "\x16\x03\x03", detectTLS) // TLSv1.2 } -func detectTLS(dir Direction, data []byte) *Protocol { +func detectTLS(dir Direction, data []byte, _, _ int) (proto *Protocol, confidence float64) { stream := cryptobyte.String(data) // A TLS packet always has a content type (1 byte), version (2 bytes) and length (2 bytes). if len(stream) < 5 { - return nil + return nil, 0 } // Check for TLS Handshake (type 22) @@ -32,15 +32,18 @@ func detectTLS(dir Direction, data []byte) *Protocol { Length uint32 } if !stream.ReadUint8(&header.Type) || header.Type != 0x16 { - return nil + return nil, 0 } if !stream.ReadUint16(&header.Version) { - return nil + return nil, 0 } if !stream.ReadUint24(&header.Length) { - return nil + return nil, 0 } + // Initial confidence + confidence = 0.5 + // Detected SSL/TLS version var version dpi.TLSVersion @@ -48,6 +51,7 @@ func detectTLS(dir Direction, data []byte) *Protocol { if version == 0 { if hello, err := dpi.DecodeTLSClientHelloHandshake(data); err == nil { version = hello.Version + confidence += .45 } } @@ -55,6 +59,7 @@ func detectTLS(dir Direction, data []byte) *Protocol { if version == 0 { if hello, err := dpi.DecodeTLSServerHello(data); err == nil { version = hello.Version + confidence += .45 } } @@ -68,6 +73,7 @@ func detectTLS(dir Direction, data []byte) *Protocol { ) if stream.ReadUint24(&length) && stream.ReadUint16(&versionWord) { version = dpi.TLSVersion(versionWord) + confidence += .25 } } } @@ -82,17 +88,17 @@ func detectTLS(dir Direction, data []byte) *Protocol { return &Protocol{ Name: ProtocolSSL, Version: Version{Major: 3, Minor: 0, Patch: -1}, - } + }, confidence } else if version >= dpi.VersionTLS10 && version <= dpi.VersionTLS13 { return &Protocol{ Name: ProtocolTLS, Version: Version{Major: 1, Minor: int(uint8(version) - 1), Patch: -1}, - } + }, confidence } else if version >= dpi.VersionTLS13Draft && version <= dpi.VersionTLS13Draft23 { return &Protocol{ Name: ProtocolTLS, Version: Version{Major: 1, Minor: 3, Patch: -1}, - } + }, confidence } - return nil + return nil, 0 } diff --git a/protocol/detect_tls_test.go b/protocol/detect_tls_test.go index 452d4e8..4f0c8cf 100644 --- a/protocol/detect_tls_test.go +++ b/protocol/detect_tls_test.go @@ -179,7 +179,7 @@ func TestDetectTLS(t *testing.T) { t.Run(name, func(t *testing.T) { t.Run("SSLv3 Client Hello", func(t *testing.T) { - p, err := Detect(Client, sslV3ClientHello) + p, _, err := Detect(Client, sslV3ClientHello, 0, 0) if err != nil { t.Fatal(err) return @@ -192,7 +192,7 @@ func TestDetectTLS(t *testing.T) { }) t.Run("TLS 1.1 Client Hello", func(t *testing.T) { - p, err := Detect(Client, tls11ClientHello) + p, _, err := Detect(Client, tls11ClientHello, 0, 0) if err != nil { t.Fatal(err) return @@ -205,7 +205,7 @@ func TestDetectTLS(t *testing.T) { }) t.Run("TLS 1.1 partial Client Hello", func(t *testing.T) { - p, err := Detect(Client, tls11ClientHelloPartial) + p, _, err := Detect(Client, tls11ClientHelloPartial, 0, 0) if strict { if !errors.Is(err, ErrUnknown) { t.Fatalf("expected unknown format, got error %T: %q", err, err) @@ -226,7 +226,7 @@ func TestDetectTLS(t *testing.T) { }) t.Run("TLS 1.2 Client Hello", func(t *testing.T) { - p, err := Detect(Client, tls12ClientHello) + p, _, err := Detect(Client, tls12ClientHello, 0, 0) if err != nil { t.Fatal(err) return @@ -239,7 +239,7 @@ func TestDetectTLS(t *testing.T) { }) t.Run("TLS 1.3 Client Hello", func(t *testing.T) { - p, err := Detect(Client, tls13ClientHello) + p, _, err := Detect(Client, tls13ClientHello, 0, 0) if err != nil { t.Fatal(err) return @@ -252,7 +252,7 @@ func TestDetectTLS(t *testing.T) { }) t.Run("Invalid PostgreSQL", func(t *testing.T) { - _, err := Detect(Server, pgClientStartup) + _, _, err := Detect(Server, pgClientStartup, 0, 0) if !errors.Is(err, ErrUnknown) { t.Fatalf("expected unknown format, got error %T: %q", err, err) } else { diff --git a/protocol/detest_test.go b/protocol/detest_test.go new file mode 100644 index 0000000..0c0bf5d --- /dev/null +++ b/protocol/detest_test.go @@ -0,0 +1,215 @@ +package protocol + +import ( + "math" + "testing" +) + +func TestCompareFloats(t *testing.T) { + tests := []struct { + name string + a, b float64 + expected int + }{ + // Basic comparisons + { + name: "a less than b", + a: 1.0, + b: 2.0, + expected: -1, + }, + { + name: "a greater than b", + a: 2.0, + b: 1.0, + expected: 1, + }, + { + name: "a equals b exact", + a: 1.0, + b: 1.0, + expected: 0, + }, + + // Floating-point precision cases + { + name: "famous 0.1 + 0.2 equals 0.3 within tolerance", + a: 0.1 + 0.2, + b: 0.3, + expected: 0, + }, + { + name: "very close numbers within tolerance", + a: 1.0000000001, + b: 1.0000000002, + expected: 0, + }, + { + name: "numbers outside tolerance a < b", + a: 1.0, + b: 1.0001, + expected: -1, + }, + { + name: "numbers outside tolerance a > b", + a: 1.0001, + b: 1.0, + expected: 1, + }, + + // Edge cases with very small numbers + { + name: "very small numbers equal", + a: 1e-20, + b: 1e-20, + expected: 0, + }, + { + name: "very small numbers a < b", + a: 1e-15, + b: 2e-15, + expected: -1, + }, + + // Zero and negative zero + { + name: "zero equals zero", + a: 0.0, + b: 0.0, + expected: 0, + }, + { + name: "zero equals negative zero", + a: 0.0, + b: -0.0, + expected: 0, + }, + { + name: "zero less than small positive", + a: 0.0, + b: 1e-20, + expected: -1, + }, + { + name: "zero greater than small negative", + a: 0.0, + b: -1e-20, + expected: 1, + }, + + // Negative numbers + { + name: "negative numbers a > b", + a: -1.0, + b: -2.0, + expected: 1, + }, + { + name: "negative numbers a < b", + a: -2.0, + b: -1.0, + expected: -1, + }, + { + name: "negative numbers equal", + a: -1.0, + b: -1.0, + expected: 0, + }, + + // Mixed signs + { + name: "negative less than positive", + a: -1.0, + b: 1.0, + expected: -1, + }, + { + name: "positive greater than negative", + a: 1.0, + b: -1.0, + expected: 1, + }, + + // Special values: NaN + { + name: "NaN equals NaN", + a: math.NaN(), + b: math.NaN(), + expected: 0, + }, + { + name: "NaN less than number", + a: math.NaN(), + b: 1.0, + expected: -1, + }, + { + name: "number greater than NaN", + a: 1.0, + b: math.NaN(), + expected: 1, + }, + + // Special values: Infinity + { + name: "positive infinity equals positive infinity", + a: math.Inf(1), + b: math.Inf(1), + expected: 0, + }, + { + name: "negative infinity equals negative infinity", + a: math.Inf(-1), + b: math.Inf(-1), + expected: 0, + }, + { + name: "positive infinity greater than negative infinity", + a: math.Inf(1), + b: math.Inf(-1), + expected: 1, + }, + { + name: "negative infinity less than positive infinity", + a: math.Inf(-1), + b: math.Inf(1), + expected: -1, + }, + { + name: "positive infinity greater than large number", + a: math.Inf(1), + b: 1e308, + expected: 1, + }, + { + name: "negative infinity less than small number", + a: math.Inf(-1), + b: -1e308, + expected: -1, + }, + + // Large numbers + { + name: "large numbers equal", + a: 1e15, + b: 1e15, + expected: 0, + }, + { + name: "large numbers a < b", + a: 1e15, + b: 2e15, + expected: -1, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + result := compareFloats(test.a, test.b) + if result != test.expected { + t.Errorf("compareFloats(%g, %g) = %d, want %d", test.a, test.b, result, test.expected) + } + }) + } +} diff --git a/protocol/intercept.go b/protocol/intercept.go index 50e155e..7ac0b34 100644 --- a/protocol/intercept.go +++ b/protocol/intercept.go @@ -2,21 +2,25 @@ package protocol import ( "net" + "strconv" "sync/atomic" "time" ) // Intercepted is the result returned by [Interceptor.Detect]. type Intercepted struct { - Direction Direction - Protocol *Protocol - Error error + Direction Direction + Protocol *Protocol + Confidence float64 + Error error } // Interceptor intercepts reads from client or server. type Interceptor struct { + clientPort int clientBytes chan []byte clientReader *readInterceptor + serverPort int serverBytes chan []byte serverReader *readInterceptor } @@ -71,6 +75,7 @@ func (i *Interceptor) Client(c net.Conn) net.Conn { if ri, ok := c.(*readInterceptor); ok { return ri } + i.clientPort = getPortFromAddr(c.RemoteAddr()) i.clientReader = newReadInterceptor(c, i.clientBytes) return i.clientReader } @@ -80,6 +85,7 @@ func (i *Interceptor) Server(c net.Conn) net.Conn { if ri, ok := c.(*readInterceptor); ok { return ri } + i.serverPort = getPortFromAddr(c.RemoteAddr()) i.serverReader = newReadInterceptor(c, i.serverBytes) return i.serverReader } @@ -107,22 +113,49 @@ func (i *Interceptor) Detect(timeout time.Duration) <-chan *Intercepted { } case data := <-i.clientBytes: // client sent banner - p, err := Detect(Client, data) + p, c, err := Detect(Client, data, i.clientPort, i.serverPort) interceptc <- &Intercepted{ - Direction: Client, - Protocol: p, - Error: err, + Direction: Client, + Protocol: p, + Confidence: c, + Error: err, } case data := <-i.serverBytes: // server sent banner - p, err := Detect(Server, data) + p, c, err := Detect(Server, data, i.serverPort, i.clientPort) interceptc <- &Intercepted{ - Direction: Server, - Protocol: p, - Error: err, + Direction: Server, + Protocol: p, + Confidence: c, + Error: err, } } }() return interceptc } + +func getPortFromAddr(addr net.Addr) int { + switch a := addr.(type) { + case *net.TCPAddr: + return a.Port + case *net.UDPAddr: + return a.Port + case *net.IPAddr: + // IPAddr doesn't have a port + return 0 + default: + // Fallback to parsing + _, service, err := net.SplitHostPort(addr.String()) + if err != nil { + return 0 + } + if port, err := strconv.Atoi(service); err == nil { + return port + } + if port, err := net.LookupPort(addr.Network(), service); err == nil { + return port + } + return 0 + } +} diff --git a/protocol/limit.go b/protocol/limit.go index 2ec10e9..251ede6 100644 --- a/protocol/limit.go +++ b/protocol/limit.go @@ -35,42 +35,43 @@ type connLimiter struct { acceptError atomic.Value } -func (l *connLimiter) init(readData, writeData []byte) { - l.acceptOnce.Do(func() { +func (limiter *connLimiter) init(readData, writeData []byte) { + limiter.acceptOnce.Do(func() { var ( - dir Direction - data []byte + dir Direction + data []byte + srcPort, dstPort int ) if readData != nil { // init called by initial read - dir, data = Server, readData + dir, data, srcPort, dstPort = Server, readData, getPortFromAddr(limiter.LocalAddr()), getPortFromAddr(limiter.RemoteAddr()) } else { // init called by initial write - dir, data = Client, writeData + dir, data, srcPort, dstPort = Client, writeData, getPortFromAddr(limiter.RemoteAddr()), getPortFromAddr(limiter.LocalAddr()) } - protocol, _ := Detect(dir, data) - if err := l.accept(dir, protocol); err != nil { - l.acceptError.Store(err) + protocol, _, _ := Detect(dir, data, srcPort, dstPort) + if err := limiter.accept(dir, protocol); err != nil { + limiter.acceptError.Store(err) } }) } -func (l *connLimiter) Read(p []byte) (n int, err error) { +func (limiter *connLimiter) Read(p []byte) (n int, err error) { var ok bool - if err, ok = l.acceptError.Load().(error); ok && err != nil { + if err, ok = limiter.acceptError.Load().(error); ok && err != nil { return } - if n, err = l.Conn.Read(p); n > 0 { - l.init(p[:n], nil) + if n, err = limiter.Conn.Read(p); n > 0 { + limiter.init(p[:n], nil) } return } -func (l *connLimiter) Write(p []byte) (n int, err error) { - l.init(nil, p) +func (limiter *connLimiter) Write(p []byte) (n int, err error) { + limiter.init(nil, p) var ok bool - if err, ok = l.acceptError.Load().(error); ok && err != nil { + if err, ok = limiter.acceptError.Load().(error); ok && err != nil { return } - return l.Conn.Write(p) + return limiter.Conn.Write(p) } diff --git a/protocol/match.go b/protocol/match.go new file mode 100644 index 0000000..9a00ef3 --- /dev/null +++ b/protocol/match.go @@ -0,0 +1,77 @@ +package protocol + +// MatchPattern checks if the byte slice matches the magic string pattern. +// +// '?' matches any single character +// '*' matches zero or more characters +// '\' escapes special characters ('?', '*', '\') +// All other characters must match exactly +// +// Returns true if all magic bytes are matched, even if input has extra bytes. +func Match(magic string, input []byte) bool { + return match(magic, input, 0, 0) +} + +// match is a recursive helper function that implements the matching logic +func match(magic string, input []byte, magicIndex, inputIndex int) bool { + // If we've reached the end of magic string, we've successfully matched all magic bytes + // It doesn't matter if there are extra bytes in the input + if magicIndex == len(magic) { + return true + } + + // Handle escape character + if magic[magicIndex] == '\\' { + // Check if there's a next character in magic + if magicIndex+1 >= len(magic) { + // Backslash at end of magic string - treat as literal backslash + if inputIndex >= len(input) || input[inputIndex] != '\\' { + return false + } + return match(magic, input, magicIndex+1, inputIndex+1) + } + + // Escape the next character - we need to match it literally + escapedChar := magic[magicIndex+1] + if inputIndex >= len(input) || input[inputIndex] != escapedChar { + return false + } + // Skip both the backslash and the escaped character in magic, move one in input + return match(magic, input, magicIndex+2, inputIndex+1) + } + + // If we've reached the end of input but not magic string + if inputIndex == len(input) { + // If we have '*' at the current position, it can match zero characters + if magic[magicIndex] == '*' { + return match(magic, input, magicIndex+1, inputIndex) + } + return false + } + + // Handle '*' character - matches zero or more characters + if magic[magicIndex] == '*' { + // Try matching zero characters + if match(magic, input, magicIndex+1, inputIndex) { + return true + } + // Try matching one or more characters + if inputIndex < len(input) && match(magic, input, magicIndex, inputIndex+1) { + return true + } + return false + } + + // Handle '?' character - matches any single character + if magic[magicIndex] == '?' { + return match(magic, input, magicIndex+1, inputIndex+1) + } + + // Handle exact character match + if magic[magicIndex] == input[inputIndex] { + return match(magic, input, magicIndex+1, inputIndex+1) + } + + // No match found + return false +} diff --git a/protocol/match_test.go b/protocol/match_test.go new file mode 100644 index 0000000..9111d9f --- /dev/null +++ b/protocol/match_test.go @@ -0,0 +1,227 @@ +package protocol + +import ( + "testing" +) + +func TestMatch(t *testing.T) { + tests := []struct { + name string + magic string + input []byte + expected bool + }{ + // Basic escaping tests + { + name: "escape star", + magic: "\\*", + input: []byte("*"), + expected: true, + }, + { + name: "escape star no match", + magic: "\\*", + input: []byte("a"), + expected: false, + }, + { + name: "escape star with longer input", + magic: "\\*", + input: []byte("*extra"), + expected: true, + }, + { + name: "escape question mark", + magic: "\\?", + input: []byte("?"), + expected: true, + }, + { + name: "escape question mark no match", + magic: "\\?", + input: []byte("a"), + expected: false, + }, + { + name: "escape backslash", + magic: "\\\\", + input: []byte("\\"), + expected: true, + }, + { + name: "escape backslash no match", + magic: "\\\\", + input: []byte("a"), + expected: false, + }, + { + name: "escape backslash with longer input", + magic: "\\\\", + input: []byte("\\extra"), + expected: true, + }, + + // Multiple escaped characters + { + name: "multiple escaped characters", + magic: "\\*\\?\\\\", + input: []byte("*?\\"), + expected: true, + }, + { + name: "multiple escaped characters with longer input", + magic: "\\*\\?\\\\", + input: []byte("*?\\extra"), + expected: true, + }, + { + name: "mixed escaped characters", + magic: "a\\*b\\?c\\\\d", + input: []byte("a*b?c\\d"), + expected: true, + }, + + // Escaping combined with wildcards + { + name: "star then escaped star", + magic: "*\\*", + input: []byte("anything*"), + expected: true, + }, + { + name: "star then escaped star must end with star", + magic: "*\\*", + input: []byte("anything"), + expected: false, + }, + { + name: "star then escaped question", + magic: "*\\?", + input: []byte("hello?"), + expected: true, + }, + { + name: "question then escaped star", + magic: "?\\*", + input: []byte("a*"), + expected: true, + }, + { + name: "question then escaped star wrong second char", + magic: "?\\*", + input: []byte("aa"), + expected: false, + }, + { + name: "wildcards between escaped characters", + magic: "*\\\\*", + input: []byte("path\\to\\file"), + expected: true, + }, + + // Real-world escaping scenarios + { + name: "file pattern with literal star", + magic: "file\\*.txt", + input: []byte("file*.txt"), + expected: true, + }, + { + name: "file pattern with literal star no match", + magic: "file\\*.txt", + input: []byte("filex.txt"), + expected: false, + }, + { + name: "pattern with literal question", + magic: "what\\?*", + input: []byte("what? is this"), + expected: true, + }, + { + name: "pattern with literal question must have question", + magic: "what\\?*", + input: []byte("what is this"), + expected: false, + }, + { + name: "database like pattern", + magic: "table_\\*_\\?", + input: []byte("table_*_?"), + expected: true, + }, + { + name: "database like pattern with longer input", + magic: "table_\\*_\\?", + input: []byte("table_*_?backup"), + expected: true, + }, + + // Edge cases with escaping + { + name: "backslash at end of magic", + magic: "test\\", + input: []byte("test\\"), + expected: true, + }, + { + name: "backslash at end of magic no match", + magic: "test\\", + input: []byte("test"), + expected: false, + }, + { + name: "only backslash", + magic: "\\", + input: []byte("\\"), + expected: true, + }, + { + name: "consecutive backslashes", + magic: "\\\\\\\\", + input: []byte("\\\\"), + expected: true, + }, + + // Mixed scenarios with both escaping and wildcards + { + name: "escaped wildcards in middle", + magic: "a*\\?b*\\*c", + input: []byte("aanything?banything*c"), + expected: true, + }, + { + name: "escaped wildcards pattern", + magic: "select * from \\*", + input: []byte("select * from *"), + expected: true, + }, + { + name: "escaped wildcards pattern with longer input", + magic: "select * from *\\*", + input: []byte("select name from users*"), + expected: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + result := Match(test.magic, test.input) + if result != test.expected { + t.Errorf("Match(%q, %q) = %v, want %v", + test.magic, string(test.input), result, test.expected) + } + }) + } +} + +// Benchmark test with escaping +func BenchmarkMatch(b *testing.B) { + magic := "file\\*\\?*\\\\*.txt" + input := []byte("file*?name\\backup.txt") + + b.ResetTimer() + for b.Loop() { + Match(magic, input) + } +}