package protocol import ( "errors" "fmt" "math" "slices" "sync" "sync/atomic" ) // Strict mode requires a full, compliant packet to be captured. This is only // implemented by some detectors. var Strict bool // Common errors. var ( ErrTimeout = errors.New("timeout") ErrUnknown = errors.New("unknown protocol") ) // Direction indicates the communcation direction. type Direction int // Directions supported by this package. const ( // Unknown direction is the default value and is not a valid Direction. Unknown Direction = iota // Client initiated. Client // Server initiated. Server // Both is either client or server initiated. Both ) // Contains checks if the provided other direction is included in this direction. func (dir Direction) Contains(other Direction) bool { switch dir { case Client: return other == Client || other == Both case Server: return other == Server || other == Both case Both: return other == Client || other == Server || other == Both default: return false } } var directionName = map[Direction]string{ Client: "client", Server: "server", Both: "both", } // IsValid checks if dir has a value recognized by this library. // // Also Unknown direction is not considered valid. func (dir Direction) IsValid() bool { return dir > Unknown && dir <= Both } func (dir Direction) String() string { if s, ok := directionName[dir]; ok { return s } return fmt.Sprintf("invalid (%d)", int(dir)) } type format struct { dir Direction magic string detect DetectFunc } // Formats is the list of registered formats. var ( formatsMu sync.Mutex atomicFormats atomic.Value ) type detectResult struct { // Protocol detected, nil if no detection. Protocol *Protocol // Confidence level [0..1]. Confidence float64 } // DetectFunc is a function which runs the in-depth protcol detection logic. // // The confidence score should be between 0 and 0.99. Score boundaries are not hard enforced by this library. 
type DetectFunc func(dir Direction, data []byte, srcPort, dstPort int) (proto *Protocol, confidence float64)

// Register adds a protocol detector to the global registry.
//
// dir selects which traffic the detector is consulted for: [Detect] only runs
// it when dir overlaps the direction Detect was called with (see
// [Direction.Contains]). Registrations with an invalid direction (including
// [Unknown]) or a nil detect function are silently discarded.
//
// The magic string is used to quickly analyze if the []byte slice passed to [Detect] qualifies
// for further inspection by the [DetectFunc]. See the [Match] function documentation for how
// magic strings are matched against the input.
func Register(dir Direction, magic string, detect DetectFunc) {
	// A nil detector would only surface later as a panic inside Detect, so
	// reject it here just like an invalid direction.
	if !dir.IsValid() || detect == nil {
		return
	}

	formatsMu.Lock()
	defer formatsMu.Unlock()

	formats, _ := atomicFormats.Load().([]format)
	// Clip first so every stored snapshot gets its own backing array: a slice
	// already handed out to Detect is never written to again.
	atomicFormats.Store(append(slices.Clip(formats), format{dir, magic, detect}))
}

// Detect identifies the protocol carried in data.
//
// Every registered detector whose direction overlaps dir and whose magic
// string matches data (see [Match]) is run. The candidate with the highest
// confidence, as ordered by compareFloats, is returned. On a tie the detector
// registered first wins. If no detector recognizes the data, Detect returns
// ErrUnknown.
func Detect(dir Direction, data []byte, srcPort, dstPort int) (proto *Protocol, confidence float64, err error) {
	formats, _ := atomicFormats.Load().([]format)

	// best.Protocol stays nil until some detector reports a match.
	var best detectResult
	for _, f := range formats {
		// Cheap direction and magic-string prefilters before the full detector.
		if !f.dir.Contains(dir) || !Match(f.magic, data) {
			continue
		}

		p, c := f.detect(dir, data, srcPort, dstPort)
		if p == nil {
			continue
		}
		// Replace only on a strictly higher confidence, so earlier
		// registrations win ties.
		if best.Protocol == nil || compareFloats(c, best.Confidence) > 0 {
			best = detectResult{Protocol: p, Confidence: c}
		}
	}

	if best.Protocol == nil {
		return nil, 0, ErrUnknown
	}
	return best.Protocol, best.Confidence, nil
}

// compareFloats compares two float64 numbers with tolerance for floating-point precision.
//
// The result is -1 when a is less than b, 1 when a is greater than b, and 0
// when the two differ by less than an absolute tolerance of 1e-9.
//
// Special values are ordered explicitly: NaN sorts below every other value
// (including -Inf) and compares equal to another NaN, and two infinities of
// the same sign compare equal.
func compareFloats(a, b float64) int {
	const epsilon = 1e-9

	// NaN is unordered under <, > and ==, so give it a fixed position first.
	switch aNaN, bNaN := math.IsNaN(a), math.IsNaN(b); {
	case aNaN && bNaN:
		return 0
	case aNaN:
		return -1
	case bNaN:
		return 1
	}

	// Infinities are compared directly: Inf - Inf is NaN, which the tolerance
	// check below cannot classify.
	if math.IsInf(a, 0) || math.IsInf(b, 0) {
		switch {
		case a < b:
			return -1
		case a > b:
			return 1
		default:
			return 0
		}
	}

	delta := a - b
	switch {
	case math.Abs(delta) < epsilon:
		return 0
	case delta < 0:
		return -1
	default:
		return 1
	}
}