184 lines
3.9 KiB
Go
184 lines
3.9 KiB
Go
package protocol
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"math"
|
|
"slices"
|
|
"sync"
|
|
"sync/atomic"
|
|
)
|
|
|
|
// Package-level configuration and sentinel errors.
var (
	// Strict mode requires a full, compliant packet to be captured before a
	// detector reports a match. Only some detectors honour this setting.
	Strict bool

	// ErrTimeout signals that an operation timed out.
	ErrTimeout = errors.New("timeout")

	// ErrUnknown is returned by Detect when no registered detector
	// recognises the data.
	ErrUnknown = errors.New("unknown protocol")
)
|
|
|
|
// Direction indicates the communication direction.
type Direction int

// Directions supported by this package.
const (
	// Unknown is the zero value. It matches no direction in Contains.
	Unknown Direction = iota
	Client
	Server
	Both
)

// Contains reports whether dir overlaps other.
//
// Client and Server each match themselves and Both; Both matches Client,
// Server and Both. The relation is symmetric. Unknown, and any value outside
// the declared constants, matches nothing.
func (dir Direction) Contains(other Direction) bool {
	switch dir {
	case Client:
		return other == Client || other == Both
	case Server:
		return other == Server || other == Both
	case Both:
		// Both must also match Both; otherwise Client.Contains(Both) would be
		// true while Both.Contains(Both) is false, and detectors registered
		// for Both would never run when the caller asks for Both.
		return other == Client || other == Server || other == Both
	default:
		return false
	}
}

// directionName maps each declared Direction to its String form.
var directionName = map[Direction]string{
	Unknown: "unknown",
	Client:  "client",
	Server:  "server",
	Both:    "both",
}

// String returns the lower-case name of dir, or "invalid (N)" for values
// outside the declared constants.
func (dir Direction) String() string {
	if s, ok := directionName[dir]; ok {
		return s
	}
	return fmt.Sprintf("invalid (%d)", int(dir))
}
|
|
|
|
type format struct {
|
|
dir Direction
|
|
magic string
|
|
detect DetectFunc
|
|
}
|
|
|
|
// Registry of detectors. atomicFormats holds the current []format snapshot,
// which Detect loads without locking. Register serialises writers with
// formatsMu and publishes an extended slice with Store.
var (
	formatsMu     sync.Mutex
	atomicFormats atomic.Value // holds []format
)
|
|
|
|
type detectResult struct {
|
|
// Protocol detected, nil if no detection.
|
|
Protocol *Protocol
|
|
|
|
// Confidence level [0..1].
|
|
Confidence float64
|
|
}
|
|
|
|
// DetectFunc inspects data seen in direction dir between srcPort and dstPort.
// It returns the detected protocol, or nil if the data is not recognised,
// together with a confidence level (expected in [0..1]). Detect prefers the
// result with the highest confidence.
type DetectFunc func(dir Direction, data []byte, srcPort int, dstPort int) (proto *Protocol, confidence float64)
|
|
|
|
func Register(dir Direction, magic string, detect DetectFunc) {
|
|
formatsMu.Lock()
|
|
formats, _ := atomicFormats.Load().([]format)
|
|
atomicFormats.Store(append(formats, format{dir, magic, detect}))
|
|
formatsMu.Unlock()
|
|
}
|
|
|
|
// matchMagic reports whether data begins with magic. A '?' in magic matches
// any byte. An empty magic matches any data, including nil, so the
// associated detector always runs.
func matchMagic(magic string, data []byte) bool {
	if magic == "" {
		return true
	}
	// data must be at least as long as the pattern.
	if len(magic) > len(data) {
		return false
	}
	for i := 0; i < len(magic); i++ {
		if want := magic[i]; want != '?' && want != data[i] {
			return false
		}
	}
	return true
}
|
|
|
|
// Detect a protocol based on the provided data.
|
|
func Detect(dir Direction, data []byte, srcPort, dstPort int) (proto *Protocol, confidence float64, err error) {
|
|
var (
|
|
formats, _ = atomicFormats.Load().([]format)
|
|
results []detectResult
|
|
)
|
|
for _, format := range formats {
|
|
if format.dir.Contains(dir) {
|
|
// Check the buffer to see if we have sufficient bytes
|
|
if matchMagic(format.magic, data) {
|
|
if proto, confidence := format.detect(dir, data, srcPort, dstPort); proto != nil {
|
|
results = append(results, detectResult{proto, confidence})
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
if len(results) > 0 {
|
|
slices.SortStableFunc(results, func(a, b detectResult) int {
|
|
return compareFloats(b.Confidence, a.Confidence)
|
|
})
|
|
return results[0].Protocol, results[0].Confidence, nil
|
|
}
|
|
|
|
return nil, 0, ErrUnknown
|
|
}
|
|
|
|
// compareFloats orders a and b, treating values whose absolute difference is
// below 1e-9 as equal. It returns -1 if a < b, 0 if they are equal within
// tolerance, and 1 if a > b.
//
// NaN sorts below every number and equals only another NaN. Infinities are
// compared exactly, because a-b is NaN when both are the same infinity.
func compareFloats(a, b float64) int {
	const tolerance = 1e-9

	aNaN, bNaN := math.IsNaN(a), math.IsNaN(b)
	switch {
	case aNaN && bNaN:
		return 0
	case aNaN:
		return -1
	case bNaN:
		return 1
	}

	if math.IsInf(a, 0) || math.IsInf(b, 0) {
		switch {
		case a < b:
			return -1
		case a > b:
			return 1
		default:
			return 0
		}
	}

	switch d := a - b; {
	case math.Abs(d) < tolerance:
		return 0
	case d < 0:
		return -1
	default:
		return 1
	}
}
|