Refactored detection logic to include ports and a confidence score
This commit is contained in:
@@ -3,6 +3,8 @@ package protocol
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"slices"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
)
|
||||
@@ -66,7 +68,15 @@ var (
|
||||
atomicFormats atomic.Value
|
||||
)
|
||||
|
||||
type DetectFunc func(Direction, []byte) *Protocol
|
||||
type detectResult struct {
|
||||
// Protocol detected, nil if no detection.
|
||||
Protocol *Protocol
|
||||
|
||||
// Confidence level [0..1].
|
||||
Confidence float64
|
||||
}
|
||||
|
||||
type DetectFunc func(dir Direction, data []byte, srcPort, dstPort int) (proto *Protocol, confidence float64)
|
||||
|
||||
func Register(dir Direction, magic string, detect DetectFunc) {
|
||||
formatsMu.Lock()
|
||||
@@ -97,17 +107,77 @@ func matchMagic(magic string, data []byte) bool {
|
||||
}
|
||||
|
||||
// Detect a protocol based on the provided data.
|
||||
func Detect(dir Direction, data []byte) (*Protocol, error) {
|
||||
formats, _ := atomicFormats.Load().([]format)
|
||||
for _, f := range formats {
|
||||
if f.dir.Contains(dir) {
|
||||
func Detect(dir Direction, data []byte, srcPort, dstPort int) (proto *Protocol, confidence float64, err error) {
|
||||
var (
|
||||
formats, _ = atomicFormats.Load().([]format)
|
||||
results []detectResult
|
||||
)
|
||||
for _, format := range formats {
|
||||
if format.dir.Contains(dir) {
|
||||
// Check the buffer to see if we have sufficient bytes
|
||||
if matchMagic(f.magic, data) {
|
||||
if p := f.detect(dir, data); p != nil {
|
||||
return p, nil
|
||||
if matchMagic(format.magic, data) {
|
||||
if proto, confidence := format.detect(dir, data, srcPort, dstPort); proto != nil {
|
||||
results = append(results, detectResult{proto, confidence})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil, ErrUnknown
|
||||
|
||||
if len(results) > 0 {
|
||||
slices.SortStableFunc(results, func(a, b detectResult) int {
|
||||
return compareFloats(b.Confidence, a.Confidence)
|
||||
})
|
||||
return results[0].Protocol, results[0].Confidence, nil
|
||||
}
|
||||
|
||||
return nil, 0, ErrUnknown
|
||||
}
|
||||
|
||||
// compareFloats compares two float64 numbers with tolerance for floating-point precision.
|
||||
//
|
||||
// Returns:
|
||||
//
|
||||
// -1 if a < b
|
||||
// 0 if a == b (within tolerance)
|
||||
// 1 if a > b
|
||||
func compareFloats(a, b float64) int {
|
||||
// Define the tolerance for floating-point comparison
|
||||
const tolerance = 1e-9
|
||||
|
||||
// Handle special cases: NaN and Inf
|
||||
if math.IsNaN(a) || math.IsNaN(b) {
|
||||
// NaN is considered equal to itself, otherwise not equal
|
||||
if math.IsNaN(a) && math.IsNaN(b) {
|
||||
return 0
|
||||
}
|
||||
if math.IsNaN(a) {
|
||||
return -1 // NaN is considered less than any number
|
||||
}
|
||||
return 1 // Any number is greater than NaN
|
||||
}
|
||||
|
||||
// Handle infinity cases
|
||||
if math.IsInf(a, 0) || math.IsInf(b, 0) {
|
||||
if a < b {
|
||||
return -1
|
||||
}
|
||||
if a > b {
|
||||
return 1
|
||||
}
|
||||
return 0 // Both are same infinity
|
||||
}
|
||||
|
||||
// Compare with tolerance for regular numbers
|
||||
diff := a - b
|
||||
|
||||
// If the absolute difference is within tolerance, consider them equal
|
||||
if math.Abs(diff) < tolerance {
|
||||
return 0
|
||||
}
|
||||
|
||||
// Otherwise return the comparison result
|
||||
if diff < 0 {
|
||||
return -1
|
||||
}
|
||||
return 1
|
||||
}
|
||||
|
@@ -12,10 +12,17 @@ func init() {
|
||||
Register(Server, "HTTP/?.", detectHTTPResponse)
|
||||
}
|
||||
|
||||
func detectHTTPRequest(dir Direction, data []byte) *Protocol {
|
||||
func detectHTTPRequest(dir Direction, data []byte, srcPort, dstPort int) (proto *Protocol, confidence float64) {
|
||||
// A minimal request "GET / HTTP/1.0\r\n" is > 8 bytes.
|
||||
if len(data) < 8 {
|
||||
return nil
|
||||
return nil, 0
|
||||
}
|
||||
|
||||
switch dstPort {
|
||||
case 80, 8080: // Common HTTP ports
|
||||
confidence = +.1
|
||||
case 3128: // Common HTTP proxy port
|
||||
confidence = -.1
|
||||
}
|
||||
|
||||
if Strict {
|
||||
@@ -31,38 +38,27 @@ func detectHTTPRequest(dir Direction, data []byte) *Protocol {
|
||||
Minor: request.ProtoMinor,
|
||||
Patch: -1,
|
||||
},
|
||||
}
|
||||
}, confidence + .85
|
||||
}
|
||||
r.Reset(bytes.NewReader(b))
|
||||
if response, err := http.ReadResponse(r, nil); err == nil {
|
||||
return &Protocol{
|
||||
Name: ProtocolHTTP,
|
||||
Version: Version{
|
||||
Major: response.ProtoMajor,
|
||||
Minor: response.ProtoMinor,
|
||||
Patch: -1,
|
||||
},
|
||||
}
|
||||
}
|
||||
return nil
|
||||
return nil, 0
|
||||
}
|
||||
|
||||
crlfIndex := bytes.IndexFunc(data, func(r rune) bool {
|
||||
return r == '\r' || r == '\n'
|
||||
})
|
||||
if crlfIndex == -1 {
|
||||
return nil
|
||||
return nil, 0
|
||||
}
|
||||
|
||||
// A request has three, space-separated parts.
|
||||
part := bytes.Split(data[:crlfIndex], []byte(" "))
|
||||
if len(part) != 3 {
|
||||
return nil
|
||||
return nil, 0
|
||||
}
|
||||
|
||||
// The last part starts with "HTTP/".
|
||||
if !bytes.HasPrefix(part[2], []byte("HTTP/1")) {
|
||||
return nil
|
||||
return nil, 0
|
||||
}
|
||||
|
||||
var version = Version{Patch: -1}
|
||||
@@ -71,17 +67,42 @@ func detectHTTPRequest(dir Direction, data []byte) *Protocol {
|
||||
return &Protocol{
|
||||
Name: ProtocolHTTP,
|
||||
Version: version,
|
||||
}
|
||||
}, confidence + .75
|
||||
}
|
||||
|
||||
func detectHTTPResponse(dir Direction, data []byte) *Protocol {
|
||||
func detectHTTPResponse(dir Direction, data []byte, srcPort, dstPort int) (proto *Protocol, confidence float64) {
|
||||
if !dir.Contains(Server) {
|
||||
return nil
|
||||
return nil, 0
|
||||
}
|
||||
|
||||
// A minimal response "HTTP/1.0 200 OK\r\n" is > 8 bytes.
|
||||
if len(data) < 8 {
|
||||
return nil
|
||||
return nil, 0
|
||||
}
|
||||
|
||||
switch srcPort {
|
||||
case 80, 8080: // Common HTTP ports
|
||||
confidence = +.1
|
||||
case 3128: // Common HTTP proxy port
|
||||
confidence = -.1
|
||||
}
|
||||
|
||||
if Strict {
|
||||
var (
|
||||
b = append(data, '\r', '\n')
|
||||
r = bufio.NewReader(bytes.NewReader(b))
|
||||
)
|
||||
if response, err := http.ReadResponse(r, nil); err == nil {
|
||||
return &Protocol{
|
||||
Name: ProtocolHTTP,
|
||||
Version: Version{
|
||||
Major: response.ProtoMajor,
|
||||
Minor: response.ProtoMinor,
|
||||
Patch: -1,
|
||||
},
|
||||
}, confidence + .85
|
||||
}
|
||||
return nil, 0
|
||||
}
|
||||
|
||||
var version = Version{Patch: -1}
|
||||
@@ -90,5 +111,5 @@ func detectHTTPResponse(dir Direction, data []byte) *Protocol {
|
||||
return &Protocol{
|
||||
Name: ProtocolHTTP,
|
||||
Version: version,
|
||||
}
|
||||
}, confidence + .75
|
||||
}
|
||||
|
@@ -28,12 +28,12 @@ func TestDetectHTTPRequest(t *testing.T) {
|
||||
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Run("HTTP/1.0 GET", func(t *testing.T) {
|
||||
p, err := Detect(Client, http10Request)
|
||||
p, c, err := Detect(Client, http10Request, 1234, 80)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
t.Logf("detected %s version %s", p.Name, p.Version)
|
||||
t.Logf("detected %s version %s confidence %g%%", p.Name, p.Version, c*100)
|
||||
if p.Name != ProtocolHTTP {
|
||||
t.Fatalf("expected http protocol, got %s", p.Name)
|
||||
return
|
||||
@@ -41,12 +41,12 @@ func TestDetectHTTPRequest(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("HTTP/1.1 GET", func(t *testing.T) {
|
||||
p, err := Detect(Client, getRequest)
|
||||
p, c, err := Detect(Client, getRequest, 1234, 80)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
t.Logf("detected %s version %s", p.Name, p.Version)
|
||||
t.Logf("detected %s version %s confidence %g%%", p.Name, p.Version, c*100)
|
||||
if p.Name != ProtocolHTTP {
|
||||
t.Fatalf("expected http protocol, got %s", p.Name)
|
||||
return
|
||||
@@ -54,7 +54,7 @@ func TestDetectHTTPRequest(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("Invalid SSH", func(t *testing.T) {
|
||||
_, err := Detect(Server, sshBanner)
|
||||
_, _, err := Detect(Server, sshBanner, 1234, 22)
|
||||
if !errors.Is(err, ErrUnknown) {
|
||||
t.Fatalf("expected unknown format, got error %T: %q", err, err)
|
||||
} else {
|
||||
@@ -94,12 +94,12 @@ func TestDetectHTTPResponse(t *testing.T) {
|
||||
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Run("HTTP/1.0 403", func(t *testing.T) {
|
||||
p, err := Detect(Server, http10Response)
|
||||
p, c, err := Detect(Server, http10Response, 80, 1234)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
t.Logf("detected %s version %s", p.Name, p.Version)
|
||||
t.Logf("detected %s version %s confidence %g%%", p.Name, p.Version, c*100)
|
||||
if p.Name != ProtocolHTTP {
|
||||
t.Fatalf("expected http protocol, got %s", p.Name)
|
||||
return
|
||||
@@ -107,12 +107,12 @@ func TestDetectHTTPResponse(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("HTTP/1.1 200", func(t *testing.T) {
|
||||
p, err := Detect(Server, responseOK)
|
||||
p, c, err := Detect(Server, responseOK, 80, 1234)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
t.Logf("detected %s version %s", p.Name, p.Version)
|
||||
t.Logf("detected %s version %s confidence %g%%", p.Name, p.Version, c*100)
|
||||
if p.Name != ProtocolHTTP {
|
||||
t.Fatalf("expected http protocol, got %s", p.Name)
|
||||
return
|
||||
@@ -120,12 +120,12 @@ func TestDetectHTTPResponse(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("HTTP/1.1 404", func(t *testing.T) {
|
||||
p, err := Detect(Server, responseNotFound)
|
||||
p, c, err := Detect(Server, responseNotFound, 80, 1234)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
t.Logf("detected %s version %s", p.Name, p.Version)
|
||||
t.Logf("detected %s version %s confidence %g%%", p.Name, p.Version, c*100)
|
||||
if p.Name != ProtocolHTTP {
|
||||
t.Fatalf("expected http protocol, got %s", p.Name)
|
||||
return
|
||||
@@ -133,7 +133,7 @@ func TestDetectHTTPResponse(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("Invalid HTTP/1.1 GET", func(t *testing.T) {
|
||||
_, err := Detect(Server, getRequest)
|
||||
_, _, err := Detect(Server, getRequest, 1234, 80)
|
||||
if !errors.Is(err, ErrUnknown) {
|
||||
t.Fatalf("expected unknown format, got error %T: %q", err, err)
|
||||
} else {
|
||||
@@ -142,7 +142,7 @@ func TestDetectHTTPResponse(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("Invalid SSH", func(t *testing.T) {
|
||||
_, err := Detect(Server, sshBanner)
|
||||
_, _, err := Detect(Server, sshBanner, 22, 1234)
|
||||
if !errors.Is(err, ErrUnknown) {
|
||||
t.Fatalf("expected unknown format, got error %T: %q", err, err)
|
||||
} else {
|
||||
|
@@ -9,15 +9,19 @@ func init() {
|
||||
Register(Server, "\x0a", detectMySQL)
|
||||
}
|
||||
|
||||
func detectMySQL(dir Direction, data []byte) *Protocol {
|
||||
func detectMySQL(dir Direction, data []byte, srcPort, dstPort int) (proto *Protocol, confidence float64) {
|
||||
if len(data) < 7 {
|
||||
return nil
|
||||
return nil, 0
|
||||
}
|
||||
|
||||
// The first byte of the handshake packet is the protocol version.
|
||||
// For MySQL, this is 10 (0x0A).
|
||||
if data[0] != 0x0A {
|
||||
return nil
|
||||
return nil, 0
|
||||
}
|
||||
|
||||
if srcPort == 3306 {
|
||||
confidence = .1
|
||||
}
|
||||
|
||||
// After the protocol version, there is a null-terminated server version string.
|
||||
@@ -26,7 +30,7 @@ func detectMySQL(dir Direction, data []byte) *Protocol {
|
||||
|
||||
// If no null byte is found, it's not a valid banner.
|
||||
if nullIndex == -1 {
|
||||
return nil
|
||||
return nil, 0
|
||||
}
|
||||
|
||||
// The position of the null byte is relative to the start of the whole slice.
|
||||
@@ -38,7 +42,7 @@ func detectMySQL(dir Direction, data []byte) *Protocol {
|
||||
// We'll check for the 4-byte connection ID as a minimum requirement.
|
||||
const connectionIDLength = 4
|
||||
if len(data) < serverVersionEndPos+1+connectionIDLength {
|
||||
return nil
|
||||
return nil, 0
|
||||
}
|
||||
|
||||
var version Version
|
||||
@@ -47,5 +51,5 @@ func detectMySQL(dir Direction, data []byte) *Protocol {
|
||||
return &Protocol{
|
||||
Name: ProtocolMySQL,
|
||||
Version: version,
|
||||
}
|
||||
}, confidence + .75
|
||||
}
|
||||
|
@@ -38,23 +38,23 @@ func TestDetectMySQL(t *testing.T) {
|
||||
malformedSlice := []byte{0x0a, 0x38, 0x2e, 0x30, 0x2e, 0x30, 0x01, 0x02, 0x03, 0x04, 0x05}
|
||||
|
||||
t.Run("MySQL 8", func(t *testing.T) {
|
||||
p, _ := Detect(Server, mysql8Banner)
|
||||
p, c, _ := Detect(Server, mysql8Banner, 3306, 0)
|
||||
if p == nil {
|
||||
t.Fatal("expected MySQL protocol, got nil")
|
||||
}
|
||||
t.Logf("detected %s version %s", p.Name, p.Version)
|
||||
t.Logf("detected %s version %s confidence %g%%", p.Name, p.Version, c*100)
|
||||
})
|
||||
|
||||
t.Run("MariaDB", func(t *testing.T) {
|
||||
p, _ := Detect(Server, mariaDBBanner)
|
||||
p, c, _ := Detect(Server, mariaDBBanner, 3306, 0)
|
||||
if p == nil {
|
||||
t.Fatal("expected MySQL protocol, got nil")
|
||||
}
|
||||
t.Logf("detected %s version %s", p.Name, p.Version)
|
||||
t.Logf("detected %s version %s confidence %g%%", p.Name, p.Version, c*100)
|
||||
})
|
||||
|
||||
t.Run("Invalid HTTP", func(t *testing.T) {
|
||||
_, err := Detect(Server, httpBanner)
|
||||
_, _, err := Detect(Server, httpBanner, 1234, 80)
|
||||
if !errors.Is(err, ErrUnknown) {
|
||||
t.Fatalf("expected unknown format, got error %T: %q", err, err)
|
||||
} else {
|
||||
@@ -63,7 +63,7 @@ func TestDetectMySQL(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("Too short", func(t *testing.T) {
|
||||
_, err := Detect(Server, shortSlice)
|
||||
_, _, err := Detect(Server, shortSlice, 3306, 1234)
|
||||
if !errors.Is(err, ErrUnknown) {
|
||||
t.Fatalf("expected unknown format, got error %T: %q", err, err)
|
||||
} else {
|
||||
@@ -72,7 +72,7 @@ func TestDetectMySQL(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("Malformed", func(t *testing.T) {
|
||||
_, err := Detect(Server, malformedSlice)
|
||||
_, _, err := Detect(Server, malformedSlice, 3306, 1234)
|
||||
if !errors.Is(err, ErrUnknown) {
|
||||
t.Fatalf("expected unknown format, got error %T: %q", err, err)
|
||||
} else {
|
||||
|
@@ -20,16 +20,20 @@ func registerPostgreSQL() {
|
||||
Register(Client, "????\x00\x03\x00\x00", detectPostgreSQLClient) // Startup packet, protocol 3.0
|
||||
}
|
||||
|
||||
func detectPostgreSQLClient(dir Direction, data []byte) *Protocol {
|
||||
func detectPostgreSQLClient(dir Direction, data []byte, srcPort, dstPort int) (proto *Protocol, confidence float64) {
|
||||
// A client startup message needs at least 8 bytes (length + protocol version).
|
||||
if len(data) < 8 {
|
||||
return nil
|
||||
return nil, 0
|
||||
}
|
||||
|
||||
length := int(binary.BigEndian.Uint32(data[0:]))
|
||||
if len(data) != length {
|
||||
log.Printf("not postgres %q: %d != %d", data, len(data), length)
|
||||
return nil
|
||||
return nil, 0
|
||||
}
|
||||
|
||||
if dstPort == 5432 {
|
||||
confidence = .1
|
||||
}
|
||||
|
||||
major := int(binary.BigEndian.Uint16(data[4:]))
|
||||
@@ -42,15 +46,19 @@ func detectPostgreSQLClient(dir Direction, data []byte) *Protocol {
|
||||
Minor: minor,
|
||||
Patch: -1,
|
||||
},
|
||||
}
|
||||
}, confidence + .75
|
||||
}
|
||||
return nil
|
||||
return nil, 0
|
||||
}
|
||||
|
||||
func detectPostgreSQLServer(dir Direction, data []byte) *Protocol {
|
||||
func detectPostgreSQLServer(dir Direction, data []byte, srcPort, dstPort int) (proto *Protocol, confidence float64) {
|
||||
// A server message needs at least 5 bytes (type + length).
|
||||
if len(data) < 5 {
|
||||
return nil
|
||||
return nil, 0
|
||||
}
|
||||
|
||||
if srcPort == 5432 {
|
||||
confidence = .1
|
||||
}
|
||||
|
||||
// All server messages (and subsequent client messages) are tagged with a single-byte type.
|
||||
@@ -62,9 +70,9 @@ func detectPostgreSQLServer(dir Direction, data []byte) *Protocol {
|
||||
'Z', // ReadyForQuery
|
||||
'E', // ErrorResponse
|
||||
'N': // NoticeResponse
|
||||
return &Protocol{Name: ProtocolPostgreSQL}
|
||||
return &Protocol{Name: ProtocolPostgreSQL}, confidence + .65
|
||||
|
||||
default:
|
||||
return nil
|
||||
return nil, 0
|
||||
}
|
||||
}
|
||||
|
@@ -22,12 +22,12 @@ func TestDetectPostgreSQLClient(t *testing.T) {
|
||||
}
|
||||
|
||||
t.Run("Protocol 3.0", func(t *testing.T) {
|
||||
p, err := Detect(Client, pgClientStartup)
|
||||
p, c, err := Detect(Client, pgClientStartup, 0, 5432)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
t.Logf("detected %s version %s", p.Name, p.Version)
|
||||
t.Logf("detected %s version %s confidence %g%%", p.Name, p.Version, 100*c)
|
||||
if p.Name != ProtocolPostgreSQL {
|
||||
t.Fatalf("expected postgres protocol, got %s", p.Name)
|
||||
return
|
||||
@@ -58,12 +58,12 @@ func TestDetectPostgreSQLServer(t *testing.T) {
|
||||
httpBanner := []byte("GET / HTTP/1.1\r\nHost: example.com\r\n\r\n")
|
||||
|
||||
t.Run("AuthenticationOk", func(t *testing.T) {
|
||||
p, err := Detect(Server, pgServerAuthOK)
|
||||
p, c, err := Detect(Server, pgServerAuthOK, 5432, 0)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
t.Logf("detected %s version %s", p.Name, p.Version)
|
||||
t.Logf("detected %s version %s confidence %g%%", p.Name, p.Version, c*100)
|
||||
if p.Name != ProtocolPostgreSQL {
|
||||
t.Fatalf("expected postgres protocol, got %s", p.Name)
|
||||
return
|
||||
@@ -71,12 +71,12 @@ func TestDetectPostgreSQLServer(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("ErrorResponse", func(t *testing.T) {
|
||||
p, err := Detect(Server, pgServerError)
|
||||
p, c, err := Detect(Server, pgServerError, 5432, 0)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
t.Logf("detected %s version %s", p.Name, p.Version)
|
||||
t.Logf("detected %s version %s confidence %g%%", p.Name, p.Version, c*100)
|
||||
if p.Name != ProtocolPostgreSQL {
|
||||
t.Fatalf("expected postgres protocol, got %s", p.Name)
|
||||
return
|
||||
@@ -84,7 +84,7 @@ func TestDetectPostgreSQLServer(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("Invalid HTTP", func(t *testing.T) {
|
||||
_, err := Detect(Server, httpBanner)
|
||||
_, _, err := Detect(Server, httpBanner, 0, 80)
|
||||
if !errors.Is(err, ErrUnknown) {
|
||||
t.Fatalf("expected unknown format, got error %T: %q", err, err)
|
||||
} else {
|
||||
|
@@ -14,10 +14,14 @@ func init() {
|
||||
Register(Both, "", detectSSH)
|
||||
}
|
||||
|
||||
func detectSSH(dir Direction, data []byte) *Protocol {
|
||||
func detectSSH(dir Direction, data []byte, srcPort, dstPort int) (proto *Protocol, confidence float64) {
|
||||
// The data must be at least as long as the prefix itself.
|
||||
if len(data) < len(ssh20Prefix) {
|
||||
return nil
|
||||
return nil, 0
|
||||
}
|
||||
|
||||
if dstPort == 22 || dstPort == 2200 || dstPort == 2222 {
|
||||
confidence = .1
|
||||
}
|
||||
|
||||
// The protocol allows for pre-banner text, so we have to check all lines.
|
||||
@@ -32,7 +36,7 @@ func detectSSH(dir Direction, data []byte) *Protocol {
|
||||
Patch: -1,
|
||||
Extra: string(line[len(ssh20Prefix):]),
|
||||
},
|
||||
}
|
||||
}, confidence + 0.75
|
||||
}
|
||||
if bytes.HasPrefix(line, []byte(ssh199Prefix)) {
|
||||
return &Protocol{
|
||||
@@ -43,9 +47,9 @@ func detectSSH(dir Direction, data []byte) *Protocol {
|
||||
Patch: -1,
|
||||
Extra: string(line[len(ssh20Prefix):]),
|
||||
},
|
||||
}
|
||||
}, confidence + 0.75
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
return nil, 0
|
||||
}
|
||||
|
@@ -32,7 +32,7 @@ func TestDetectSSH(t *testing.T) {
|
||||
httpBanner := []byte("GET / HTTP/1.1\r\nHost: example.com\r\n\r\n")
|
||||
|
||||
t.Run("OpenSSH client", func(t *testing.T) {
|
||||
p, err := Detect(Server, openSSHBanner)
|
||||
p, _, err := Detect(Server, openSSHBanner, 0, 22)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
@@ -45,7 +45,7 @@ func TestDetectSSH(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("OpenSSH server", func(t *testing.T) {
|
||||
p, err := Detect(Server, openSSHBanner)
|
||||
p, _, err := Detect(Server, openSSHBanner, 0, 22)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
@@ -58,7 +58,7 @@ func TestDetectSSH(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("OpenSSH server with banner", func(t *testing.T) {
|
||||
p, err := Detect(Server, preBannerSSH)
|
||||
p, _, err := Detect(Server, preBannerSSH, 0, 22)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
@@ -71,7 +71,7 @@ func TestDetectSSH(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("Dropbear server", func(t *testing.T) {
|
||||
p, err := Detect(Server, dropbearBanner)
|
||||
p, _, err := Detect(Server, dropbearBanner, 0, 22)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
@@ -84,7 +84,7 @@ func TestDetectSSH(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("Invalid MySQL banner", func(t *testing.T) {
|
||||
_, err := Detect(Server, mysqlBanner)
|
||||
_, _, err := Detect(Server, mysqlBanner, 0, 3306)
|
||||
if !errors.Is(err, ErrUnknown) {
|
||||
t.Fatalf("expected unknown format, got error %T: %q", err, err)
|
||||
} else {
|
||||
@@ -93,7 +93,7 @@ func TestDetectSSH(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("Invalid HTTP banner", func(t *testing.T) {
|
||||
_, err := Detect(Server, httpBanner)
|
||||
_, _, err := Detect(Server, httpBanner, 0, 80)
|
||||
if !errors.Is(err, ErrUnknown) {
|
||||
t.Fatalf("expected unknown format, got error %T: %q", err, err)
|
||||
} else {
|
||||
|
@@ -17,12 +17,12 @@ func registerTLS() {
|
||||
Register(Both, "\x16\x03\x03", detectTLS) // TLSv1.2
|
||||
}
|
||||
|
||||
func detectTLS(dir Direction, data []byte) *Protocol {
|
||||
func detectTLS(dir Direction, data []byte, _, _ int) (proto *Protocol, confidence float64) {
|
||||
stream := cryptobyte.String(data)
|
||||
|
||||
// A TLS packet always has a content type (1 byte), version (2 bytes) and length (2 bytes).
|
||||
if len(stream) < 5 {
|
||||
return nil
|
||||
return nil, 0
|
||||
}
|
||||
|
||||
// Check for TLS Handshake (type 22)
|
||||
@@ -32,15 +32,18 @@ func detectTLS(dir Direction, data []byte) *Protocol {
|
||||
Length uint32
|
||||
}
|
||||
if !stream.ReadUint8(&header.Type) || header.Type != 0x16 {
|
||||
return nil
|
||||
return nil, 0
|
||||
}
|
||||
if !stream.ReadUint16(&header.Version) {
|
||||
return nil
|
||||
return nil, 0
|
||||
}
|
||||
if !stream.ReadUint24(&header.Length) {
|
||||
return nil
|
||||
return nil, 0
|
||||
}
|
||||
|
||||
// Initial confidence
|
||||
confidence = 0.5
|
||||
|
||||
// Detected SSL/TLS version
|
||||
var version dpi.TLSVersion
|
||||
|
||||
@@ -48,6 +51,7 @@ func detectTLS(dir Direction, data []byte) *Protocol {
|
||||
if version == 0 {
|
||||
if hello, err := dpi.DecodeTLSClientHelloHandshake(data); err == nil {
|
||||
version = hello.Version
|
||||
confidence += .45
|
||||
}
|
||||
}
|
||||
|
||||
@@ -55,6 +59,7 @@ func detectTLS(dir Direction, data []byte) *Protocol {
|
||||
if version == 0 {
|
||||
if hello, err := dpi.DecodeTLSServerHello(data); err == nil {
|
||||
version = hello.Version
|
||||
confidence += .45
|
||||
}
|
||||
}
|
||||
|
||||
@@ -68,6 +73,7 @@ func detectTLS(dir Direction, data []byte) *Protocol {
|
||||
)
|
||||
if stream.ReadUint24(&length) && stream.ReadUint16(&versionWord) {
|
||||
version = dpi.TLSVersion(versionWord)
|
||||
confidence += .25
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -82,17 +88,17 @@ func detectTLS(dir Direction, data []byte) *Protocol {
|
||||
return &Protocol{
|
||||
Name: ProtocolSSL,
|
||||
Version: Version{Major: 3, Minor: 0, Patch: -1},
|
||||
}
|
||||
}, confidence
|
||||
} else if version >= dpi.VersionTLS10 && version <= dpi.VersionTLS13 {
|
||||
return &Protocol{
|
||||
Name: ProtocolTLS,
|
||||
Version: Version{Major: 1, Minor: int(uint8(version) - 1), Patch: -1},
|
||||
}
|
||||
}, confidence
|
||||
} else if version >= dpi.VersionTLS13Draft && version <= dpi.VersionTLS13Draft23 {
|
||||
return &Protocol{
|
||||
Name: ProtocolTLS,
|
||||
Version: Version{Major: 1, Minor: 3, Patch: -1},
|
||||
}
|
||||
}, confidence
|
||||
}
|
||||
return nil
|
||||
return nil, 0
|
||||
}
|
||||
|
@@ -179,7 +179,7 @@ func TestDetectTLS(t *testing.T) {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
|
||||
t.Run("SSLv3 Client Hello", func(t *testing.T) {
|
||||
p, err := Detect(Client, sslV3ClientHello)
|
||||
p, _, err := Detect(Client, sslV3ClientHello, 0, 0)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
@@ -192,7 +192,7 @@ func TestDetectTLS(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("TLS 1.1 Client Hello", func(t *testing.T) {
|
||||
p, err := Detect(Client, tls11ClientHello)
|
||||
p, _, err := Detect(Client, tls11ClientHello, 0, 0)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
@@ -205,7 +205,7 @@ func TestDetectTLS(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("TLS 1.1 partial Client Hello", func(t *testing.T) {
|
||||
p, err := Detect(Client, tls11ClientHelloPartial)
|
||||
p, _, err := Detect(Client, tls11ClientHelloPartial, 0, 0)
|
||||
if strict {
|
||||
if !errors.Is(err, ErrUnknown) {
|
||||
t.Fatalf("expected unknown format, got error %T: %q", err, err)
|
||||
@@ -226,7 +226,7 @@ func TestDetectTLS(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("TLS 1.2 Client Hello", func(t *testing.T) {
|
||||
p, err := Detect(Client, tls12ClientHello)
|
||||
p, _, err := Detect(Client, tls12ClientHello, 0, 0)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
@@ -239,7 +239,7 @@ func TestDetectTLS(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("TLS 1.3 Client Hello", func(t *testing.T) {
|
||||
p, err := Detect(Client, tls13ClientHello)
|
||||
p, _, err := Detect(Client, tls13ClientHello, 0, 0)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
@@ -252,7 +252,7 @@ func TestDetectTLS(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("Invalid PostgreSQL", func(t *testing.T) {
|
||||
_, err := Detect(Server, pgClientStartup)
|
||||
_, _, err := Detect(Server, pgClientStartup, 0, 0)
|
||||
if !errors.Is(err, ErrUnknown) {
|
||||
t.Fatalf("expected unknown format, got error %T: %q", err, err)
|
||||
} else {
|
||||
|
215
protocol/detest_test.go
Normal file
215
protocol/detest_test.go
Normal file
@@ -0,0 +1,215 @@
|
||||
package protocol
|
||||
|
||||
import (
|
||||
"math"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCompareFloats(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
a, b float64
|
||||
expected int
|
||||
}{
|
||||
// Basic comparisons
|
||||
{
|
||||
name: "a less than b",
|
||||
a: 1.0,
|
||||
b: 2.0,
|
||||
expected: -1,
|
||||
},
|
||||
{
|
||||
name: "a greater than b",
|
||||
a: 2.0,
|
||||
b: 1.0,
|
||||
expected: 1,
|
||||
},
|
||||
{
|
||||
name: "a equals b exact",
|
||||
a: 1.0,
|
||||
b: 1.0,
|
||||
expected: 0,
|
||||
},
|
||||
|
||||
// Floating-point precision cases
|
||||
{
|
||||
name: "famous 0.1 + 0.2 equals 0.3 within tolerance",
|
||||
a: 0.1 + 0.2,
|
||||
b: 0.3,
|
||||
expected: 0,
|
||||
},
|
||||
{
|
||||
name: "very close numbers within tolerance",
|
||||
a: 1.0000000001,
|
||||
b: 1.0000000002,
|
||||
expected: 0,
|
||||
},
|
||||
{
|
||||
name: "numbers outside tolerance a < b",
|
||||
a: 1.0,
|
||||
b: 1.0001,
|
||||
expected: -1,
|
||||
},
|
||||
{
|
||||
name: "numbers outside tolerance a > b",
|
||||
a: 1.0001,
|
||||
b: 1.0,
|
||||
expected: 1,
|
||||
},
|
||||
|
||||
// Edge cases with very small numbers
|
||||
{
|
||||
name: "very small numbers equal",
|
||||
a: 1e-20,
|
||||
b: 1e-20,
|
||||
expected: 0,
|
||||
},
|
||||
{
|
||||
name: "very small numbers a < b",
|
||||
a: 1e-15,
|
||||
b: 2e-15,
|
||||
expected: -1,
|
||||
},
|
||||
|
||||
// Zero and negative zero
|
||||
{
|
||||
name: "zero equals zero",
|
||||
a: 0.0,
|
||||
b: 0.0,
|
||||
expected: 0,
|
||||
},
|
||||
{
|
||||
name: "zero equals negative zero",
|
||||
a: 0.0,
|
||||
b: -0.0,
|
||||
expected: 0,
|
||||
},
|
||||
{
|
||||
name: "zero less than small positive",
|
||||
a: 0.0,
|
||||
b: 1e-20,
|
||||
expected: -1,
|
||||
},
|
||||
{
|
||||
name: "zero greater than small negative",
|
||||
a: 0.0,
|
||||
b: -1e-20,
|
||||
expected: 1,
|
||||
},
|
||||
|
||||
// Negative numbers
|
||||
{
|
||||
name: "negative numbers a > b",
|
||||
a: -1.0,
|
||||
b: -2.0,
|
||||
expected: 1,
|
||||
},
|
||||
{
|
||||
name: "negative numbers a < b",
|
||||
a: -2.0,
|
||||
b: -1.0,
|
||||
expected: -1,
|
||||
},
|
||||
{
|
||||
name: "negative numbers equal",
|
||||
a: -1.0,
|
||||
b: -1.0,
|
||||
expected: 0,
|
||||
},
|
||||
|
||||
// Mixed signs
|
||||
{
|
||||
name: "negative less than positive",
|
||||
a: -1.0,
|
||||
b: 1.0,
|
||||
expected: -1,
|
||||
},
|
||||
{
|
||||
name: "positive greater than negative",
|
||||
a: 1.0,
|
||||
b: -1.0,
|
||||
expected: 1,
|
||||
},
|
||||
|
||||
// Special values: NaN
|
||||
{
|
||||
name: "NaN equals NaN",
|
||||
a: math.NaN(),
|
||||
b: math.NaN(),
|
||||
expected: 0,
|
||||
},
|
||||
{
|
||||
name: "NaN less than number",
|
||||
a: math.NaN(),
|
||||
b: 1.0,
|
||||
expected: -1,
|
||||
},
|
||||
{
|
||||
name: "number greater than NaN",
|
||||
a: 1.0,
|
||||
b: math.NaN(),
|
||||
expected: 1,
|
||||
},
|
||||
|
||||
// Special values: Infinity
|
||||
{
|
||||
name: "positive infinity equals positive infinity",
|
||||
a: math.Inf(1),
|
||||
b: math.Inf(1),
|
||||
expected: 0,
|
||||
},
|
||||
{
|
||||
name: "negative infinity equals negative infinity",
|
||||
a: math.Inf(-1),
|
||||
b: math.Inf(-1),
|
||||
expected: 0,
|
||||
},
|
||||
{
|
||||
name: "positive infinity greater than negative infinity",
|
||||
a: math.Inf(1),
|
||||
b: math.Inf(-1),
|
||||
expected: 1,
|
||||
},
|
||||
{
|
||||
name: "negative infinity less than positive infinity",
|
||||
a: math.Inf(-1),
|
||||
b: math.Inf(1),
|
||||
expected: -1,
|
||||
},
|
||||
{
|
||||
name: "positive infinity greater than large number",
|
||||
a: math.Inf(1),
|
||||
b: 1e308,
|
||||
expected: 1,
|
||||
},
|
||||
{
|
||||
name: "negative infinity less than small number",
|
||||
a: math.Inf(-1),
|
||||
b: -1e308,
|
||||
expected: -1,
|
||||
},
|
||||
|
||||
// Large numbers
|
||||
{
|
||||
name: "large numbers equal",
|
||||
a: 1e15,
|
||||
b: 1e15,
|
||||
expected: 0,
|
||||
},
|
||||
{
|
||||
name: "large numbers a < b",
|
||||
a: 1e15,
|
||||
b: 2e15,
|
||||
expected: -1,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
result := compareFloats(test.a, test.b)
|
||||
if result != test.expected {
|
||||
t.Errorf("compareFloats(%g, %g) = %d, want %d", test.a, test.b, result, test.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
@@ -2,21 +2,25 @@ package protocol
|
||||
|
||||
import (
|
||||
"net"
|
||||
"strconv"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Intercepted is the result returned by [Interceptor.Detect].
|
||||
type Intercepted struct {
|
||||
Direction Direction
|
||||
Protocol *Protocol
|
||||
Error error
|
||||
Direction Direction
|
||||
Protocol *Protocol
|
||||
Confidence float64
|
||||
Error error
|
||||
}
|
||||
|
||||
// Interceptor intercepts reads from client or server.
|
||||
type Interceptor struct {
|
||||
clientPort int
|
||||
clientBytes chan []byte
|
||||
clientReader *readInterceptor
|
||||
serverPort int
|
||||
serverBytes chan []byte
|
||||
serverReader *readInterceptor
|
||||
}
|
||||
@@ -71,6 +75,7 @@ func (i *Interceptor) Client(c net.Conn) net.Conn {
|
||||
if ri, ok := c.(*readInterceptor); ok {
|
||||
return ri
|
||||
}
|
||||
i.clientPort = getPortFromAddr(c.RemoteAddr())
|
||||
i.clientReader = newReadInterceptor(c, i.clientBytes)
|
||||
return i.clientReader
|
||||
}
|
||||
@@ -80,6 +85,7 @@ func (i *Interceptor) Server(c net.Conn) net.Conn {
|
||||
if ri, ok := c.(*readInterceptor); ok {
|
||||
return ri
|
||||
}
|
||||
i.serverPort = getPortFromAddr(c.RemoteAddr())
|
||||
i.serverReader = newReadInterceptor(c, i.serverBytes)
|
||||
return i.serverReader
|
||||
}
|
||||
@@ -107,22 +113,49 @@ func (i *Interceptor) Detect(timeout time.Duration) <-chan *Intercepted {
|
||||
}
|
||||
|
||||
case data := <-i.clientBytes: // client sent banner
|
||||
p, err := Detect(Client, data)
|
||||
p, c, err := Detect(Client, data, i.clientPort, i.serverPort)
|
||||
interceptc <- &Intercepted{
|
||||
Direction: Client,
|
||||
Protocol: p,
|
||||
Error: err,
|
||||
Direction: Client,
|
||||
Protocol: p,
|
||||
Confidence: c,
|
||||
Error: err,
|
||||
}
|
||||
|
||||
case data := <-i.serverBytes: // server sent banner
|
||||
p, err := Detect(Server, data)
|
||||
p, c, err := Detect(Server, data, i.serverPort, i.clientPort)
|
||||
interceptc <- &Intercepted{
|
||||
Direction: Server,
|
||||
Protocol: p,
|
||||
Error: err,
|
||||
Direction: Server,
|
||||
Protocol: p,
|
||||
Confidence: c,
|
||||
Error: err,
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return interceptc
|
||||
}
|
||||
|
||||
func getPortFromAddr(addr net.Addr) int {
|
||||
switch a := addr.(type) {
|
||||
case *net.TCPAddr:
|
||||
return a.Port
|
||||
case *net.UDPAddr:
|
||||
return a.Port
|
||||
case *net.IPAddr:
|
||||
// IPAddr doesn't have a port
|
||||
return 0
|
||||
default:
|
||||
// Fallback to parsing
|
||||
_, service, err := net.SplitHostPort(addr.String())
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
if port, err := strconv.Atoi(service); err == nil {
|
||||
return port
|
||||
}
|
||||
if port, err := net.LookupPort(addr.Network(), service); err == nil {
|
||||
return port
|
||||
}
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
@@ -35,42 +35,43 @@ type connLimiter struct {
|
||||
acceptError atomic.Value
|
||||
}
|
||||
|
||||
func (l *connLimiter) init(readData, writeData []byte) {
|
||||
l.acceptOnce.Do(func() {
|
||||
func (limiter *connLimiter) init(readData, writeData []byte) {
|
||||
limiter.acceptOnce.Do(func() {
|
||||
var (
|
||||
dir Direction
|
||||
data []byte
|
||||
dir Direction
|
||||
data []byte
|
||||
srcPort, dstPort int
|
||||
)
|
||||
if readData != nil {
|
||||
// init called by initial read
|
||||
dir, data = Server, readData
|
||||
dir, data, srcPort, dstPort = Server, readData, getPortFromAddr(limiter.LocalAddr()), getPortFromAddr(limiter.RemoteAddr())
|
||||
} else {
|
||||
// init called by initial write
|
||||
dir, data = Client, writeData
|
||||
dir, data, srcPort, dstPort = Client, writeData, getPortFromAddr(limiter.RemoteAddr()), getPortFromAddr(limiter.LocalAddr())
|
||||
}
|
||||
protocol, _ := Detect(dir, data)
|
||||
if err := l.accept(dir, protocol); err != nil {
|
||||
l.acceptError.Store(err)
|
||||
protocol, _, _ := Detect(dir, data, srcPort, dstPort)
|
||||
if err := limiter.accept(dir, protocol); err != nil {
|
||||
limiter.acceptError.Store(err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (l *connLimiter) Read(p []byte) (n int, err error) {
|
||||
func (limiter *connLimiter) Read(p []byte) (n int, err error) {
|
||||
var ok bool
|
||||
if err, ok = l.acceptError.Load().(error); ok && err != nil {
|
||||
if err, ok = limiter.acceptError.Load().(error); ok && err != nil {
|
||||
return
|
||||
}
|
||||
if n, err = l.Conn.Read(p); n > 0 {
|
||||
l.init(p[:n], nil)
|
||||
if n, err = limiter.Conn.Read(p); n > 0 {
|
||||
limiter.init(p[:n], nil)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (l *connLimiter) Write(p []byte) (n int, err error) {
|
||||
l.init(nil, p)
|
||||
func (limiter *connLimiter) Write(p []byte) (n int, err error) {
|
||||
limiter.init(nil, p)
|
||||
var ok bool
|
||||
if err, ok = l.acceptError.Load().(error); ok && err != nil {
|
||||
if err, ok = limiter.acceptError.Load().(error); ok && err != nil {
|
||||
return
|
||||
}
|
||||
return l.Conn.Write(p)
|
||||
return limiter.Conn.Write(p)
|
||||
}
|
||||
|
77
protocol/match.go
Normal file
77
protocol/match.go
Normal file
@@ -0,0 +1,77 @@
|
||||
package protocol
|
||||
|
||||
// MatchPattern checks if the byte slice matches the magic string pattern.
|
||||
//
|
||||
// '?' matches any single character
|
||||
// '*' matches zero or more characters
|
||||
// '\' escapes special characters ('?', '*', '\')
|
||||
// All other characters must match exactly
|
||||
//
|
||||
// Returns true if all magic bytes are matched, even if input has extra bytes.
|
||||
func Match(magic string, input []byte) bool {
|
||||
return match(magic, input, 0, 0)
|
||||
}
|
||||
|
||||
// match is a recursive helper function that implements the matching logic
|
||||
func match(magic string, input []byte, magicIndex, inputIndex int) bool {
|
||||
// If we've reached the end of magic string, we've successfully matched all magic bytes
|
||||
// It doesn't matter if there are extra bytes in the input
|
||||
if magicIndex == len(magic) {
|
||||
return true
|
||||
}
|
||||
|
||||
// Handle escape character
|
||||
if magic[magicIndex] == '\\' {
|
||||
// Check if there's a next character in magic
|
||||
if magicIndex+1 >= len(magic) {
|
||||
// Backslash at end of magic string - treat as literal backslash
|
||||
if inputIndex >= len(input) || input[inputIndex] != '\\' {
|
||||
return false
|
||||
}
|
||||
return match(magic, input, magicIndex+1, inputIndex+1)
|
||||
}
|
||||
|
||||
// Escape the next character - we need to match it literally
|
||||
escapedChar := magic[magicIndex+1]
|
||||
if inputIndex >= len(input) || input[inputIndex] != escapedChar {
|
||||
return false
|
||||
}
|
||||
// Skip both the backslash and the escaped character in magic, move one in input
|
||||
return match(magic, input, magicIndex+2, inputIndex+1)
|
||||
}
|
||||
|
||||
// If we've reached the end of input but not magic string
|
||||
if inputIndex == len(input) {
|
||||
// If we have '*' at the current position, it can match zero characters
|
||||
if magic[magicIndex] == '*' {
|
||||
return match(magic, input, magicIndex+1, inputIndex)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Handle '*' character - matches zero or more characters
|
||||
if magic[magicIndex] == '*' {
|
||||
// Try matching zero characters
|
||||
if match(magic, input, magicIndex+1, inputIndex) {
|
||||
return true
|
||||
}
|
||||
// Try matching one or more characters
|
||||
if inputIndex < len(input) && match(magic, input, magicIndex, inputIndex+1) {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Handle '?' character - matches any single character
|
||||
if magic[magicIndex] == '?' {
|
||||
return match(magic, input, magicIndex+1, inputIndex+1)
|
||||
}
|
||||
|
||||
// Handle exact character match
|
||||
if magic[magicIndex] == input[inputIndex] {
|
||||
return match(magic, input, magicIndex+1, inputIndex+1)
|
||||
}
|
||||
|
||||
// No match found
|
||||
return false
|
||||
}
|
227
protocol/match_test.go
Normal file
227
protocol/match_test.go
Normal file
@@ -0,0 +1,227 @@
|
||||
package protocol
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestMatch(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
magic string
|
||||
input []byte
|
||||
expected bool
|
||||
}{
|
||||
// Basic escaping tests
|
||||
{
|
||||
name: "escape star",
|
||||
magic: "\\*",
|
||||
input: []byte("*"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "escape star no match",
|
||||
magic: "\\*",
|
||||
input: []byte("a"),
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "escape star with longer input",
|
||||
magic: "\\*",
|
||||
input: []byte("*extra"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "escape question mark",
|
||||
magic: "\\?",
|
||||
input: []byte("?"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "escape question mark no match",
|
||||
magic: "\\?",
|
||||
input: []byte("a"),
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "escape backslash",
|
||||
magic: "\\\\",
|
||||
input: []byte("\\"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "escape backslash no match",
|
||||
magic: "\\\\",
|
||||
input: []byte("a"),
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "escape backslash with longer input",
|
||||
magic: "\\\\",
|
||||
input: []byte("\\extra"),
|
||||
expected: true,
|
||||
},
|
||||
|
||||
// Multiple escaped characters
|
||||
{
|
||||
name: "multiple escaped characters",
|
||||
magic: "\\*\\?\\\\",
|
||||
input: []byte("*?\\"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "multiple escaped characters with longer input",
|
||||
magic: "\\*\\?\\\\",
|
||||
input: []byte("*?\\extra"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "mixed escaped characters",
|
||||
magic: "a\\*b\\?c\\\\d",
|
||||
input: []byte("a*b?c\\d"),
|
||||
expected: true,
|
||||
},
|
||||
|
||||
// Escaping combined with wildcards
|
||||
{
|
||||
name: "star then escaped star",
|
||||
magic: "*\\*",
|
||||
input: []byte("anything*"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "star then escaped star must end with star",
|
||||
magic: "*\\*",
|
||||
input: []byte("anything"),
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "star then escaped question",
|
||||
magic: "*\\?",
|
||||
input: []byte("hello?"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "question then escaped star",
|
||||
magic: "?\\*",
|
||||
input: []byte("a*"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "question then escaped star wrong second char",
|
||||
magic: "?\\*",
|
||||
input: []byte("aa"),
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "wildcards between escaped characters",
|
||||
magic: "*\\\\*",
|
||||
input: []byte("path\\to\\file"),
|
||||
expected: true,
|
||||
},
|
||||
|
||||
// Real-world escaping scenarios
|
||||
{
|
||||
name: "file pattern with literal star",
|
||||
magic: "file\\*.txt",
|
||||
input: []byte("file*.txt"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "file pattern with literal star no match",
|
||||
magic: "file\\*.txt",
|
||||
input: []byte("filex.txt"),
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "pattern with literal question",
|
||||
magic: "what\\?*",
|
||||
input: []byte("what? is this"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "pattern with literal question must have question",
|
||||
magic: "what\\?*",
|
||||
input: []byte("what is this"),
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "database like pattern",
|
||||
magic: "table_\\*_\\?",
|
||||
input: []byte("table_*_?"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "database like pattern with longer input",
|
||||
magic: "table_\\*_\\?",
|
||||
input: []byte("table_*_?backup"),
|
||||
expected: true,
|
||||
},
|
||||
|
||||
// Edge cases with escaping
|
||||
{
|
||||
name: "backslash at end of magic",
|
||||
magic: "test\\",
|
||||
input: []byte("test\\"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "backslash at end of magic no match",
|
||||
magic: "test\\",
|
||||
input: []byte("test"),
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "only backslash",
|
||||
magic: "\\",
|
||||
input: []byte("\\"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "consecutive backslashes",
|
||||
magic: "\\\\\\\\",
|
||||
input: []byte("\\\\"),
|
||||
expected: true,
|
||||
},
|
||||
|
||||
// Mixed scenarios with both escaping and wildcards
|
||||
{
|
||||
name: "escaped wildcards in middle",
|
||||
magic: "a*\\?b*\\*c",
|
||||
input: []byte("aanything?banything*c"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "escaped wildcards pattern",
|
||||
magic: "select * from \\*",
|
||||
input: []byte("select * from *"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "escaped wildcards pattern with longer input",
|
||||
magic: "select * from *\\*",
|
||||
input: []byte("select name from users*"),
|
||||
expected: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
result := Match(test.magic, test.input)
|
||||
if result != test.expected {
|
||||
t.Errorf("Match(%q, %q) = %v, want %v",
|
||||
test.magic, string(test.input), result, test.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Benchmark test with escaping
|
||||
func BenchmarkMatch(b *testing.B) {
|
||||
magic := "file\\*\\?*\\\\*.txt"
|
||||
input := []byte("file*?name\\backup.txt")
|
||||
|
||||
b.ResetTimer()
|
||||
for b.Loop() {
|
||||
Match(magic, input)
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user