Refactored detection logic to include ports and a confidence score
This commit is contained in:
@@ -3,6 +3,8 @@ package protocol
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"slices"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
)
|
||||
@@ -66,7 +68,15 @@ var (
|
||||
atomicFormats atomic.Value
|
||||
)
|
||||
|
||||
type DetectFunc func(Direction, []byte) *Protocol
|
||||
type detectResult struct {
|
||||
// Protocol detected, nil if no detection.
|
||||
Protocol *Protocol
|
||||
|
||||
// Confidence level [0..1].
|
||||
Confidence float64
|
||||
}
|
||||
|
||||
type DetectFunc func(dir Direction, data []byte, srcPort, dstPort int) (proto *Protocol, confidence float64)
|
||||
|
||||
func Register(dir Direction, magic string, detect DetectFunc) {
|
||||
formatsMu.Lock()
|
||||
@@ -97,17 +107,77 @@ func matchMagic(magic string, data []byte) bool {
|
||||
}
|
||||
|
||||
// Detect a protocol based on the provided data.
|
||||
func Detect(dir Direction, data []byte) (*Protocol, error) {
|
||||
formats, _ := atomicFormats.Load().([]format)
|
||||
for _, f := range formats {
|
||||
if f.dir.Contains(dir) {
|
||||
func Detect(dir Direction, data []byte, srcPort, dstPort int) (proto *Protocol, confidence float64, err error) {
|
||||
var (
|
||||
formats, _ = atomicFormats.Load().([]format)
|
||||
results []detectResult
|
||||
)
|
||||
for _, format := range formats {
|
||||
if format.dir.Contains(dir) {
|
||||
// Check the buffer to see if we have sufficient bytes
|
||||
if matchMagic(f.magic, data) {
|
||||
if p := f.detect(dir, data); p != nil {
|
||||
return p, nil
|
||||
if matchMagic(format.magic, data) {
|
||||
if proto, confidence := format.detect(dir, data, srcPort, dstPort); proto != nil {
|
||||
results = append(results, detectResult{proto, confidence})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil, ErrUnknown
|
||||
|
||||
if len(results) > 0 {
|
||||
slices.SortStableFunc(results, func(a, b detectResult) int {
|
||||
return compareFloats(b.Confidence, a.Confidence)
|
||||
})
|
||||
return results[0].Protocol, results[0].Confidence, nil
|
||||
}
|
||||
|
||||
return nil, 0, ErrUnknown
|
||||
}
|
||||
|
||||
// compareFloats compares two float64 numbers with tolerance for floating-point precision.
|
||||
//
|
||||
// Returns:
|
||||
//
|
||||
// -1 if a < b
|
||||
// 0 if a == b (within tolerance)
|
||||
// 1 if a > b
|
||||
func compareFloats(a, b float64) int {
|
||||
// Define the tolerance for floating-point comparison
|
||||
const tolerance = 1e-9
|
||||
|
||||
// Handle special cases: NaN and Inf
|
||||
if math.IsNaN(a) || math.IsNaN(b) {
|
||||
// NaN is considered equal to itself, otherwise not equal
|
||||
if math.IsNaN(a) && math.IsNaN(b) {
|
||||
return 0
|
||||
}
|
||||
if math.IsNaN(a) {
|
||||
return -1 // NaN is considered less than any number
|
||||
}
|
||||
return 1 // Any number is greater than NaN
|
||||
}
|
||||
|
||||
// Handle infinity cases
|
||||
if math.IsInf(a, 0) || math.IsInf(b, 0) {
|
||||
if a < b {
|
||||
return -1
|
||||
}
|
||||
if a > b {
|
||||
return 1
|
||||
}
|
||||
return 0 // Both are same infinity
|
||||
}
|
||||
|
||||
// Compare with tolerance for regular numbers
|
||||
diff := a - b
|
||||
|
||||
// If the absolute difference is within tolerance, consider them equal
|
||||
if math.Abs(diff) < tolerance {
|
||||
return 0
|
||||
}
|
||||
|
||||
// Otherwise return the comparison result
|
||||
if diff < 0 {
|
||||
return -1
|
||||
}
|
||||
return 1
|
||||
}
|
||||
|
Reference in New Issue
Block a user