Files
dpi/protocol/detect.go

191 lines
4.5 KiB
Go

package protocol
import (
"errors"
"fmt"
"math"
"slices"
"sync"
"sync/atomic"
)
// Strict enables strict mode, which requires a full, compliant packet to be
// captured. Only some detectors implement it. It defaults to false.
//
// NOTE(review): this is a plain, unsynchronized bool; set it before any
// concurrent use of the package.
var Strict bool
// Sentinel errors of this package; compare them with errors.Is.
var (
	// ErrTimeout is the sentinel error for timeouts.
	ErrTimeout = errors.New("timeout")

	// ErrUnknown is returned by Detect when no registered detector
	// recognizes the data.
	ErrUnknown = errors.New("unknown protocol")
)
// Direction indicates the communication direction.
type Direction int

// Directions supported by this package.
const (
	// Unknown direction is the default value and is not a valid Direction.
	Unknown Direction = iota
	// Client initiated.
	Client
	// Server initiated.
	Server
	// Both is either client or server initiated.
	Both
)

// Contains reports whether the directions dir and other overlap.
//
// Both overlaps every valid direction and every valid direction overlaps
// Both. So Client.Contains(Both) and Both.Contains(Client) are both true,
// while Client.Contains(Server) is false. Invalid directions, including
// Unknown, never overlap anything.
func (dir Direction) Contains(other Direction) bool {
	if !dir.IsValid() || !other.IsValid() {
		return false
	}
	return dir == Both || other == Both || dir == other
}

// directionName maps each valid Direction to its String form.
var directionName = map[Direction]string{
	Client: "client",
	Server: "server",
	Both:   "both",
}

// IsValid reports whether dir is Client, Server or Both.
//
// Unknown and any value outside the declared constants are not valid.
func (dir Direction) IsValid() bool {
	return dir >= Client && dir <= Both
}

// String returns the lowercase name of dir, or "invalid (N)" for values
// that have no name.
func (dir Direction) String() string {
	if s, ok := directionName[dir]; ok {
		return s
	}
	return fmt.Sprintf("invalid (%d)", int(dir))
}
// format is a single registry entry created by Register.
type format struct {
	dir    Direction  // directions this detector applies to
	magic  string     // cheap prefilter pattern, checked with Match
	detect DetectFunc // in-depth detector, run only after magic matched
}
// Detector registry. Writers (Register) serialize on formatsMu and publish a
// new []format slice through atomicFormats. Readers (Detect) load the
// current slice without taking the lock.
var (
	formatsMu     sync.Mutex
	atomicFormats atomic.Value // holds []format; nil until the first Register
)
// detectResult pairs a protocol reported by a DetectFunc with the
// confidence the detector assigned to it.
type detectResult struct {
	Protocol   *Protocol // detected protocol; nil means no detection
	Confidence float64   // detector confidence, nominally in [0, 1]
}
// DetectFunc runs the in-depth protocol detection logic for one registered
// format.
//
// Detect calls it only after the format's direction and magic string have
// matched. The direction, data and ports are forwarded unchanged from the
// Detect call.
//
// A DetectFunc that does not recognize data must return a nil proto; the
// confidence is then ignored. Otherwise, the confidence should be between
// 0 and 0.99, and the highest confidence wins when several detectors match.
// Score boundaries are not enforced by this library.
type DetectFunc func(dir Direction, data []byte, srcPort, dstPort int) (proto *Protocol, confidence float64)
// Register adds a protocol detector to the registry consulted by [Detect].
//
// The direction indicates in what direction we'll inspect the []byte slice
// passed to [Detect]. The detector runs only when its direction overlaps the
// one passed to [Detect] (see [Direction.Contains]).
//
// The call is silently ignored if the direction is invalid (including
// [Unknown]) or detect is nil. A nil detector would otherwise panic later,
// inside [Detect].
//
// The magic string is used to quickly analyze whether the []byte slice passed
// to [Detect] qualifies for further inspection by the [DetectFunc]. See the
// [Match] function documentation for how magic strings are matched against
// the input.
//
// Registrations are serialized by a mutex and published atomically, so
// Register may be called concurrently with Detect.
func Register(dir Direction, magic string, detect DetectFunc) {
	if !dir.IsValid() || detect == nil {
		return
	}

	formatsMu.Lock()
	defer formatsMu.Unlock()

	formats, _ := atomicFormats.Load().([]format)
	atomicFormats.Store(append(formats, format{dir: dir, magic: magic, detect: detect}))
}
// Detect identifies the protocol carried by data.
//
// It runs every registered detector whose direction overlaps dir (see
// [Direction.Contains]) and whose magic string matches data (see [Match]).
// Among the detectors that return a non-nil protocol, the one with the
// highest confidence wins. Confidences within a small tolerance count as
// equal, and the earliest-registered detector wins such ties.
//
// If no detector recognizes data, Detect returns a nil protocol, zero
// confidence and [ErrUnknown].
func Detect(dir Direction, data []byte, srcPort, dstPort int) (proto *Protocol, confidence float64, err error) {
	formats, _ := atomicFormats.Load().([]format)

	var results []detectResult
	for _, f := range formats {
		// The magic string is a cheap prefilter; only candidates that pass it
		// pay for the in-depth detector.
		if !f.dir.Contains(dir) || !Match(f.magic, data) {
			continue
		}
		if p, c := f.detect(dir, data, srcPort, dstPort); p != nil {
			results = append(results, detectResult{Protocol: p, Confidence: c})
		}
	}
	if len(results) == 0 {
		return nil, 0, ErrUnknown
	}

	// Only the best result is needed, so a linear scan replaces a full sort.
	// MaxFunc returns the first maximal element, so ties still go to the
	// detector that was registered first.
	best := slices.MaxFunc(results, func(a, b detectResult) int {
		return compareFloats(a.Confidence, b.Confidence)
	})
	return best.Protocol, best.Confidence, nil
}
// compareFloats is a three-way comparison of a and b that treats values
// closer than 1e-9 as equal.
//
// The result is -1 when a < b, 0 when they are equal within the tolerance,
// and +1 when a > b. NaN sorts below every number and equals another NaN.
// Infinities compare exactly, so +Inf equals +Inf.
func compareFloats(a, b float64) int {
	const epsilon = 1e-9

	aIsNaN, bIsNaN := math.IsNaN(a), math.IsNaN(b)
	switch {
	case aIsNaN && bIsNaN:
		return 0
	case aIsNaN:
		return -1
	case bIsNaN:
		return 1
	}

	// Inf-Inf would be NaN, so infinities are ordered directly instead of
	// going through the tolerance check below.
	if math.IsInf(a, 0) || math.IsInf(b, 0) {
		switch {
		case a < b:
			return -1
		case a > b:
			return 1
		default:
			return 0
		}
	}

	delta := a - b
	switch {
	case math.Abs(delta) < epsilon:
		return 0
	case delta < 0:
		return -1
	default:
		return 1
	}
}