package protocol import ( "errors" "fmt" "math" "math/rand" "strconv" "strings" "testing" ) type testCase struct { Name string Direction Direction Data []byte SrcPort int DstPort int WantType string WantConfidence float64 WantError error } func testRunner(t *testing.T, tests []*testCase) { t.Helper() for _, test := range tests { t.Run(test.Name, func(t *testing.T) { if test.SrcPort == 0 { test.SrcPort = 1024 + rand.Intn(65535-1024) } if test.DstPort == 0 { test.DstPort = 1024 + rand.Intn(65535-1024) } proto, confidence, err := Detect(test.Direction, test.Data, test.SrcPort, test.DstPort) // Process error first if err != nil { if test.WantError == nil { t.Fatalf("unexpected error: %v", err) } else if !errors.Is(err, test.WantError) { t.Fatalf("Detect(%s, %s, %d, %d) returned error %q, expected %q", test.Direction, testBytesSample(test.Data, 8), test.SrcPort, test.DstPort, err, test.WantError) } else { t.Logf("Detect(%s, %s, %d, %d) returned error %q as expected", test.Direction, testBytesSample(test.Data, 8), test.SrcPort, test.DstPort, err) } return } else if test.WantError != nil { t.Fatalf("Detect(%s, %s, %d, %d) returned protocol %q version %s, expected error %q", test.Direction, testBytesSample(test.Data, 8), test.SrcPort, test.DstPort, proto.Type, proto.Version, test.WantError) return } // Process protocol if proto == nil { t.Fatalf("Detect(%s, %s, %d, %d) returned nil, expected protocol %q", test.Direction, testBytesSample(test.Data, 8), test.SrcPort, test.DstPort, test.WantType) return } t.Logf("Detect(%s, %s, %d, %d) returned protocol %s with confidence %g%%", test.Direction, testBytesSample(test.Data, 4), test.SrcPort, test.DstPort, proto, confidence*100) if proto.Type != test.WantType { t.Errorf("Expected protocol %q, got %q", test.WantType, proto.Type) } if !testAlmostEqual(confidence, test.WantConfidence) { t.Errorf("Expected confidence %g%%", test.WantConfidence*100) } }) } } func testBytesSample(b []byte, n int) string { if b == nil { return "" } var ( hex []string etc string ) for i, l := 0, len(b); i < l && i < n; i++ { if strconv.IsPrint(rune(b[i])) { hex = append(hex, fmt.Sprintf("%c", b[i])) } else { hex = append(hex, fmt.Sprintf("\\x%02X", b[i])) } if i == (n-1) && l > (n-1) { etc = fmt.Sprintf(" … (%d more)", l-n) } } return fmt.Sprintf(`"%s"%s`, strings.Join(hex, ""), etc) } func testAlmostEqual(a, b float64) bool { const e = 1e-9 return math.Abs(a-b) < e } func TestCompareFloats(t *testing.T) { tests := []struct { name string a, b float64 expected int }{ // Basic comparisons { name: "a less than b", a: 1.0, b: 2.0, expected: -1, }, { name: "a greater than b", a: 2.0, b: 1.0, expected: 1, }, { name: "a equals b exact", a: 1.0, b: 1.0, expected: 0, }, // Floating-point precision cases { name: "famous 0.1 + 0.2 equals 0.3 within tolerance", a: 0.1 + 0.2, b: 0.3, expected: 0, }, { name: "very close numbers within tolerance", a: 1.0000000001, b: 1.0000000002, expected: 0, }, { name: "numbers outside tolerance a < b", a: 1.0, b: 1.0001, expected: -1, }, { name: "numbers outside tolerance a > b", a: 1.0001, b: 1.0, expected: 1, }, // Edge cases with very small numbers { name: "very small numbers equal", a: 1e-20, b: 1e-20, expected: 0, }, { name: "very small numbers a < b", a: 1e-9, b: 2e-9, expected: -1, }, // Zero and negative zero { name: "zero equals zero", a: 0.0, b: 0.0, expected: 0, }, { name: "zero equals negative zero", a: 0.0, b: -0.000000000000000000001, expected: 0, }, { name: "zero less than small positive", a: 0.0, b: 1e-9, expected: -1, }, { name: "zero greater than small negative", a: 0.0, b: -1e-9, expected: 1, }, // Negative numbers { name: "negative numbers a > b", a: -1.0, b: -2.0, expected: 1, }, { name: "negative numbers a < b", a: -2.0, b: -1.0, expected: -1, }, { name: "negative numbers equal", a: -1.0, b: -1.0, expected: 0, }, // Mixed signs { name: "negative less than positive", a: -1.0, b: 1.0, expected: -1, }, { name: "positive greater than negative", a: 1.0, b: -1.0, expected: 1, }, // Special values: NaN { name: "NaN equals NaN", a: math.NaN(), b: math.NaN(), expected: 0, }, { name: "NaN less than number", a: math.NaN(), b: 1.0, expected: -1, }, { name: "number greater than NaN", a: 1.0, b: math.NaN(), expected: 1, }, // Special values: Infinity { name: "positive infinity equals positive infinity", a: math.Inf(1), b: math.Inf(1), expected: 0, }, { name: "negative infinity equals negative infinity", a: math.Inf(-1), b: math.Inf(-1), expected: 0, }, { name: "positive infinity greater than negative infinity", a: math.Inf(1), b: math.Inf(-1), expected: 1, }, { name: "negative infinity less than positive infinity", a: math.Inf(-1), b: math.Inf(1), expected: -1, }, { name: "positive infinity greater than large number", a: math.Inf(1), b: 1e308, expected: 1, }, { name: "negative infinity less than small number", a: math.Inf(-1), b: -1e308, expected: -1, }, // Large numbers { name: "large numbers equal", a: 1e15, b: 1e15, expected: 0, }, { name: "large numbers a < b", a: 1e15, b: 2e15, expected: -1, }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { result := compareFloats(test.a, test.b) if result != test.expected { t.Errorf("compareFloats(%g, %g) = %d, want %d", test.a, test.b, result, test.expected) } }) } }