Refactored detection logic to include ports and a confidence score
This commit is contained in:
		
							
								
								
									
										2
									
								
								go.mod
									
									
									
									
									
								
							
							
						
						
									
										2
									
								
								go.mod
									
									
									
									
									
								
							@@ -2,4 +2,4 @@ module git.maze.io/go/dpi
 | 
			
		||||
 | 
			
		||||
go 1.25
 | 
			
		||||
 | 
			
		||||
require golang.org/x/crypto v0.42.0 // indirect
 | 
			
		||||
require golang.org/x/crypto v0.42.0
 | 
			
		||||
 
 | 
			
		||||
@@ -3,6 +3,8 @@ package protocol
 | 
			
		||||
import (
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"math"
 | 
			
		||||
	"slices"
 | 
			
		||||
	"sync"
 | 
			
		||||
	"sync/atomic"
 | 
			
		||||
)
 | 
			
		||||
@@ -66,7 +68,15 @@ var (
 | 
			
		||||
	atomicFormats atomic.Value
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type DetectFunc func(Direction, []byte) *Protocol
 | 
			
		||||
type detectResult struct {
 | 
			
		||||
	// Protocol detected, nil if no detection.
 | 
			
		||||
	Protocol *Protocol
 | 
			
		||||
 | 
			
		||||
	// Confidence level [0..1].
 | 
			
		||||
	Confidence float64
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type DetectFunc func(dir Direction, data []byte, srcPort, dstPort int) (proto *Protocol, confidence float64)
 | 
			
		||||
 | 
			
		||||
func Register(dir Direction, magic string, detect DetectFunc) {
 | 
			
		||||
	formatsMu.Lock()
 | 
			
		||||
@@ -97,17 +107,77 @@ func matchMagic(magic string, data []byte) bool {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Detect a protocol based on the provided data.
 | 
			
		||||
func Detect(dir Direction, data []byte) (*Protocol, error) {
 | 
			
		||||
	formats, _ := atomicFormats.Load().([]format)
 | 
			
		||||
	for _, f := range formats {
 | 
			
		||||
		if f.dir.Contains(dir) {
 | 
			
		||||
func Detect(dir Direction, data []byte, srcPort, dstPort int) (proto *Protocol, confidence float64, err error) {
 | 
			
		||||
	var (
 | 
			
		||||
		formats, _ = atomicFormats.Load().([]format)
 | 
			
		||||
		results    []detectResult
 | 
			
		||||
	)
 | 
			
		||||
	for _, format := range formats {
 | 
			
		||||
		if format.dir.Contains(dir) {
 | 
			
		||||
			// Check the buffer to see if we have sufficient bytes
 | 
			
		||||
			if matchMagic(f.magic, data) {
 | 
			
		||||
				if p := f.detect(dir, data); p != nil {
 | 
			
		||||
					return p, nil
 | 
			
		||||
			if matchMagic(format.magic, data) {
 | 
			
		||||
				if proto, confidence := format.detect(dir, data, srcPort, dstPort); proto != nil {
 | 
			
		||||
					results = append(results, detectResult{proto, confidence})
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return nil, ErrUnknown
 | 
			
		||||
 | 
			
		||||
	if len(results) > 0 {
 | 
			
		||||
		slices.SortStableFunc(results, func(a, b detectResult) int {
 | 
			
		||||
			return compareFloats(b.Confidence, a.Confidence)
 | 
			
		||||
		})
 | 
			
		||||
		return results[0].Protocol, results[0].Confidence, nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil, 0, ErrUnknown
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// compareFloats compares two float64 numbers with tolerance for floating-point precision.
 | 
			
		||||
//
 | 
			
		||||
// Returns:
 | 
			
		||||
//
 | 
			
		||||
//	-1 if a < b
 | 
			
		||||
//	 0 if a == b (within tolerance)
 | 
			
		||||
//	 1 if a > b
 | 
			
		||||
func compareFloats(a, b float64) int {
 | 
			
		||||
	// Define the tolerance for floating-point comparison
 | 
			
		||||
	const tolerance = 1e-9
 | 
			
		||||
 | 
			
		||||
	// Handle special cases: NaN and Inf
 | 
			
		||||
	if math.IsNaN(a) || math.IsNaN(b) {
 | 
			
		||||
		// NaN is considered equal to itself, otherwise not equal
 | 
			
		||||
		if math.IsNaN(a) && math.IsNaN(b) {
 | 
			
		||||
			return 0
 | 
			
		||||
		}
 | 
			
		||||
		if math.IsNaN(a) {
 | 
			
		||||
			return -1 // NaN is considered less than any number
 | 
			
		||||
		}
 | 
			
		||||
		return 1 // Any number is greater than NaN
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Handle infinity cases
 | 
			
		||||
	if math.IsInf(a, 0) || math.IsInf(b, 0) {
 | 
			
		||||
		if a < b {
 | 
			
		||||
			return -1
 | 
			
		||||
		}
 | 
			
		||||
		if a > b {
 | 
			
		||||
			return 1
 | 
			
		||||
		}
 | 
			
		||||
		return 0 // Both are same infinity
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Compare with tolerance for regular numbers
 | 
			
		||||
	diff := a - b
 | 
			
		||||
 | 
			
		||||
	// If the absolute difference is within tolerance, consider them equal
 | 
			
		||||
	if math.Abs(diff) < tolerance {
 | 
			
		||||
		return 0
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Otherwise return the comparison result
 | 
			
		||||
	if diff < 0 {
 | 
			
		||||
		return -1
 | 
			
		||||
	}
 | 
			
		||||
	return 1
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -12,10 +12,17 @@ func init() {
 | 
			
		||||
	Register(Server, "HTTP/?.", detectHTTPResponse)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func detectHTTPRequest(dir Direction, data []byte) *Protocol {
 | 
			
		||||
func detectHTTPRequest(dir Direction, data []byte, srcPort, dstPort int) (proto *Protocol, confidence float64) {
 | 
			
		||||
	// A minimal request "GET / HTTP/1.0\r\n" is > 8 bytes.
 | 
			
		||||
	if len(data) < 8 {
 | 
			
		||||
		return nil
 | 
			
		||||
		return nil, 0
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	switch dstPort {
 | 
			
		||||
	case 80, 8080: // Common HTTP ports
 | 
			
		||||
		confidence = +.1
 | 
			
		||||
	case 3128: // Common HTTP proxy port
 | 
			
		||||
		confidence = -.1
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if Strict {
 | 
			
		||||
@@ -31,38 +38,27 @@ func detectHTTPRequest(dir Direction, data []byte) *Protocol {
 | 
			
		||||
					Minor: request.ProtoMinor,
 | 
			
		||||
					Patch: -1,
 | 
			
		||||
				},
 | 
			
		||||
			}
 | 
			
		||||
			}, confidence + .85
 | 
			
		||||
		}
 | 
			
		||||
		r.Reset(bytes.NewReader(b))
 | 
			
		||||
		if response, err := http.ReadResponse(r, nil); err == nil {
 | 
			
		||||
			return &Protocol{
 | 
			
		||||
				Name: ProtocolHTTP,
 | 
			
		||||
				Version: Version{
 | 
			
		||||
					Major: response.ProtoMajor,
 | 
			
		||||
					Minor: response.ProtoMinor,
 | 
			
		||||
					Patch: -1,
 | 
			
		||||
				},
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		return nil
 | 
			
		||||
		return nil, 0
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	crlfIndex := bytes.IndexFunc(data, func(r rune) bool {
 | 
			
		||||
		return r == '\r' || r == '\n'
 | 
			
		||||
	})
 | 
			
		||||
	if crlfIndex == -1 {
 | 
			
		||||
		return nil
 | 
			
		||||
		return nil, 0
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// A request has three, space-separated parts.
 | 
			
		||||
	part := bytes.Split(data[:crlfIndex], []byte(" "))
 | 
			
		||||
	if len(part) != 3 {
 | 
			
		||||
		return nil
 | 
			
		||||
		return nil, 0
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// The last part starts with "HTTP/".
 | 
			
		||||
	if !bytes.HasPrefix(part[2], []byte("HTTP/1")) {
 | 
			
		||||
		return nil
 | 
			
		||||
		return nil, 0
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var version = Version{Patch: -1}
 | 
			
		||||
@@ -71,17 +67,42 @@ func detectHTTPRequest(dir Direction, data []byte) *Protocol {
 | 
			
		||||
	return &Protocol{
 | 
			
		||||
		Name:    ProtocolHTTP,
 | 
			
		||||
		Version: version,
 | 
			
		||||
	}
 | 
			
		||||
	}, confidence + .75
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func detectHTTPResponse(dir Direction, data []byte) *Protocol {
 | 
			
		||||
func detectHTTPResponse(dir Direction, data []byte, srcPort, dstPort int) (proto *Protocol, confidence float64) {
 | 
			
		||||
	if !dir.Contains(Server) {
 | 
			
		||||
		return nil
 | 
			
		||||
		return nil, 0
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// A minimal response "HTTP/1.0 200 OK\r\n" is > 8 bytes.
 | 
			
		||||
	if len(data) < 8 {
 | 
			
		||||
		return nil
 | 
			
		||||
		return nil, 0
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	switch srcPort {
 | 
			
		||||
	case 80, 8080: // Common HTTP ports
 | 
			
		||||
		confidence = +.1
 | 
			
		||||
	case 3128: // Common HTTP proxy port
 | 
			
		||||
		confidence = -.1
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if Strict {
 | 
			
		||||
		var (
 | 
			
		||||
			b = append(data, '\r', '\n')
 | 
			
		||||
			r = bufio.NewReader(bytes.NewReader(b))
 | 
			
		||||
		)
 | 
			
		||||
		if response, err := http.ReadResponse(r, nil); err == nil {
 | 
			
		||||
			return &Protocol{
 | 
			
		||||
				Name: ProtocolHTTP,
 | 
			
		||||
				Version: Version{
 | 
			
		||||
					Major: response.ProtoMajor,
 | 
			
		||||
					Minor: response.ProtoMinor,
 | 
			
		||||
					Patch: -1,
 | 
			
		||||
				},
 | 
			
		||||
			}, confidence + .85
 | 
			
		||||
		}
 | 
			
		||||
		return nil, 0
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var version = Version{Patch: -1}
 | 
			
		||||
@@ -90,5 +111,5 @@ func detectHTTPResponse(dir Direction, data []byte) *Protocol {
 | 
			
		||||
	return &Protocol{
 | 
			
		||||
		Name:    ProtocolHTTP,
 | 
			
		||||
		Version: version,
 | 
			
		||||
	}
 | 
			
		||||
	}, confidence + .75
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -28,12 +28,12 @@ func TestDetectHTTPRequest(t *testing.T) {
 | 
			
		||||
 | 
			
		||||
		t.Run(name, func(t *testing.T) {
 | 
			
		||||
			t.Run("HTTP/1.0 GET", func(t *testing.T) {
 | 
			
		||||
				p, err := Detect(Client, http10Request)
 | 
			
		||||
				p, c, err := Detect(Client, http10Request, 1234, 80)
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					t.Fatal(err)
 | 
			
		||||
					return
 | 
			
		||||
				}
 | 
			
		||||
				t.Logf("detected %s version %s", p.Name, p.Version)
 | 
			
		||||
				t.Logf("detected %s version %s confidence %g%%", p.Name, p.Version, c*100)
 | 
			
		||||
				if p.Name != ProtocolHTTP {
 | 
			
		||||
					t.Fatalf("expected http protocol, got %s", p.Name)
 | 
			
		||||
					return
 | 
			
		||||
@@ -41,12 +41,12 @@ func TestDetectHTTPRequest(t *testing.T) {
 | 
			
		||||
			})
 | 
			
		||||
 | 
			
		||||
			t.Run("HTTP/1.1 GET", func(t *testing.T) {
 | 
			
		||||
				p, err := Detect(Client, getRequest)
 | 
			
		||||
				p, c, err := Detect(Client, getRequest, 1234, 80)
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					t.Fatal(err)
 | 
			
		||||
					return
 | 
			
		||||
				}
 | 
			
		||||
				t.Logf("detected %s version %s", p.Name, p.Version)
 | 
			
		||||
				t.Logf("detected %s version %s confidence %g%%", p.Name, p.Version, c*100)
 | 
			
		||||
				if p.Name != ProtocolHTTP {
 | 
			
		||||
					t.Fatalf("expected http protocol, got %s", p.Name)
 | 
			
		||||
					return
 | 
			
		||||
@@ -54,7 +54,7 @@ func TestDetectHTTPRequest(t *testing.T) {
 | 
			
		||||
			})
 | 
			
		||||
 | 
			
		||||
			t.Run("Invalid SSH", func(t *testing.T) {
 | 
			
		||||
				_, err := Detect(Server, sshBanner)
 | 
			
		||||
				_, _, err := Detect(Server, sshBanner, 1234, 22)
 | 
			
		||||
				if !errors.Is(err, ErrUnknown) {
 | 
			
		||||
					t.Fatalf("expected unknown format, got error %T: %q", err, err)
 | 
			
		||||
				} else {
 | 
			
		||||
@@ -94,12 +94,12 @@ func TestDetectHTTPResponse(t *testing.T) {
 | 
			
		||||
 | 
			
		||||
		t.Run(name, func(t *testing.T) {
 | 
			
		||||
			t.Run("HTTP/1.0 403", func(t *testing.T) {
 | 
			
		||||
				p, err := Detect(Server, http10Response)
 | 
			
		||||
				p, c, err := Detect(Server, http10Response, 80, 1234)
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					t.Fatal(err)
 | 
			
		||||
					return
 | 
			
		||||
				}
 | 
			
		||||
				t.Logf("detected %s version %s", p.Name, p.Version)
 | 
			
		||||
				t.Logf("detected %s version %s confidence %g%%", p.Name, p.Version, c*100)
 | 
			
		||||
				if p.Name != ProtocolHTTP {
 | 
			
		||||
					t.Fatalf("expected http protocol, got %s", p.Name)
 | 
			
		||||
					return
 | 
			
		||||
@@ -107,12 +107,12 @@ func TestDetectHTTPResponse(t *testing.T) {
 | 
			
		||||
			})
 | 
			
		||||
 | 
			
		||||
			t.Run("HTTP/1.1 200", func(t *testing.T) {
 | 
			
		||||
				p, err := Detect(Server, responseOK)
 | 
			
		||||
				p, c, err := Detect(Server, responseOK, 80, 1234)
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					t.Fatal(err)
 | 
			
		||||
					return
 | 
			
		||||
				}
 | 
			
		||||
				t.Logf("detected %s version %s", p.Name, p.Version)
 | 
			
		||||
				t.Logf("detected %s version %s confidence %g%%", p.Name, p.Version, c*100)
 | 
			
		||||
				if p.Name != ProtocolHTTP {
 | 
			
		||||
					t.Fatalf("expected http protocol, got %s", p.Name)
 | 
			
		||||
					return
 | 
			
		||||
@@ -120,12 +120,12 @@ func TestDetectHTTPResponse(t *testing.T) {
 | 
			
		||||
			})
 | 
			
		||||
 | 
			
		||||
			t.Run("HTTP/1.1 404", func(t *testing.T) {
 | 
			
		||||
				p, err := Detect(Server, responseNotFound)
 | 
			
		||||
				p, c, err := Detect(Server, responseNotFound, 80, 1234)
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					t.Fatal(err)
 | 
			
		||||
					return
 | 
			
		||||
				}
 | 
			
		||||
				t.Logf("detected %s version %s", p.Name, p.Version)
 | 
			
		||||
				t.Logf("detected %s version %s confidence %g%%", p.Name, p.Version, c*100)
 | 
			
		||||
				if p.Name != ProtocolHTTP {
 | 
			
		||||
					t.Fatalf("expected http protocol, got %s", p.Name)
 | 
			
		||||
					return
 | 
			
		||||
@@ -133,7 +133,7 @@ func TestDetectHTTPResponse(t *testing.T) {
 | 
			
		||||
			})
 | 
			
		||||
 | 
			
		||||
			t.Run("Invalid HTTP/1.1 GET", func(t *testing.T) {
 | 
			
		||||
				_, err := Detect(Server, getRequest)
 | 
			
		||||
				_, _, err := Detect(Server, getRequest, 1234, 80)
 | 
			
		||||
				if !errors.Is(err, ErrUnknown) {
 | 
			
		||||
					t.Fatalf("expected unknown format, got error %T: %q", err, err)
 | 
			
		||||
				} else {
 | 
			
		||||
@@ -142,7 +142,7 @@ func TestDetectHTTPResponse(t *testing.T) {
 | 
			
		||||
			})
 | 
			
		||||
 | 
			
		||||
			t.Run("Invalid SSH", func(t *testing.T) {
 | 
			
		||||
				_, err := Detect(Server, sshBanner)
 | 
			
		||||
				_, _, err := Detect(Server, sshBanner, 22, 1234)
 | 
			
		||||
				if !errors.Is(err, ErrUnknown) {
 | 
			
		||||
					t.Fatalf("expected unknown format, got error %T: %q", err, err)
 | 
			
		||||
				} else {
 | 
			
		||||
 
 | 
			
		||||
@@ -9,15 +9,19 @@ func init() {
 | 
			
		||||
	Register(Server, "\x0a", detectMySQL)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func detectMySQL(dir Direction, data []byte) *Protocol {
 | 
			
		||||
func detectMySQL(dir Direction, data []byte, srcPort, dstPort int) (proto *Protocol, confidence float64) {
 | 
			
		||||
	if len(data) < 7 {
 | 
			
		||||
		return nil
 | 
			
		||||
		return nil, 0
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// The first byte of the handshake packet is the protocol version.
 | 
			
		||||
	// For MySQL, this is 10 (0x0A).
 | 
			
		||||
	if data[0] != 0x0A {
 | 
			
		||||
		return nil
 | 
			
		||||
		return nil, 0
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if srcPort == 3306 {
 | 
			
		||||
		confidence = .1
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// After the protocol version, there is a null-terminated server version string.
 | 
			
		||||
@@ -26,7 +30,7 @@ func detectMySQL(dir Direction, data []byte) *Protocol {
 | 
			
		||||
 | 
			
		||||
	// If no null byte is found, it's not a valid banner.
 | 
			
		||||
	if nullIndex == -1 {
 | 
			
		||||
		return nil
 | 
			
		||||
		return nil, 0
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// The position of the null byte is relative to the start of the whole slice.
 | 
			
		||||
@@ -38,7 +42,7 @@ func detectMySQL(dir Direction, data []byte) *Protocol {
 | 
			
		||||
	// We'll check for the 4-byte connection ID as a minimum requirement.
 | 
			
		||||
	const connectionIDLength = 4
 | 
			
		||||
	if len(data) < serverVersionEndPos+1+connectionIDLength {
 | 
			
		||||
		return nil
 | 
			
		||||
		return nil, 0
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var version Version
 | 
			
		||||
@@ -47,5 +51,5 @@ func detectMySQL(dir Direction, data []byte) *Protocol {
 | 
			
		||||
	return &Protocol{
 | 
			
		||||
		Name:    ProtocolMySQL,
 | 
			
		||||
		Version: version,
 | 
			
		||||
	}
 | 
			
		||||
	}, confidence + .75
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -38,23 +38,23 @@ func TestDetectMySQL(t *testing.T) {
 | 
			
		||||
	malformedSlice := []byte{0x0a, 0x38, 0x2e, 0x30, 0x2e, 0x30, 0x01, 0x02, 0x03, 0x04, 0x05}
 | 
			
		||||
 | 
			
		||||
	t.Run("MySQL 8", func(t *testing.T) {
 | 
			
		||||
		p, _ := Detect(Server, mysql8Banner)
 | 
			
		||||
		p, c, _ := Detect(Server, mysql8Banner, 3306, 0)
 | 
			
		||||
		if p == nil {
 | 
			
		||||
			t.Fatal("expected MySQL protocol, got nil")
 | 
			
		||||
		}
 | 
			
		||||
		t.Logf("detected %s version %s", p.Name, p.Version)
 | 
			
		||||
		t.Logf("detected %s version %s confidence %g%%", p.Name, p.Version, c*100)
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	t.Run("MariaDB", func(t *testing.T) {
 | 
			
		||||
		p, _ := Detect(Server, mariaDBBanner)
 | 
			
		||||
		p, c, _ := Detect(Server, mariaDBBanner, 3306, 0)
 | 
			
		||||
		if p == nil {
 | 
			
		||||
			t.Fatal("expected MySQL protocol, got nil")
 | 
			
		||||
		}
 | 
			
		||||
		t.Logf("detected %s version %s", p.Name, p.Version)
 | 
			
		||||
		t.Logf("detected %s version %s confidence %g%%", p.Name, p.Version, c*100)
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	t.Run("Invalid HTTP", func(t *testing.T) {
 | 
			
		||||
		_, err := Detect(Server, httpBanner)
 | 
			
		||||
		_, _, err := Detect(Server, httpBanner, 1234, 80)
 | 
			
		||||
		if !errors.Is(err, ErrUnknown) {
 | 
			
		||||
			t.Fatalf("expected unknown format, got error %T: %q", err, err)
 | 
			
		||||
		} else {
 | 
			
		||||
@@ -63,7 +63,7 @@ func TestDetectMySQL(t *testing.T) {
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	t.Run("Too short", func(t *testing.T) {
 | 
			
		||||
		_, err := Detect(Server, shortSlice)
 | 
			
		||||
		_, _, err := Detect(Server, shortSlice, 3306, 1234)
 | 
			
		||||
		if !errors.Is(err, ErrUnknown) {
 | 
			
		||||
			t.Fatalf("expected unknown format, got error %T: %q", err, err)
 | 
			
		||||
		} else {
 | 
			
		||||
@@ -72,7 +72,7 @@ func TestDetectMySQL(t *testing.T) {
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	t.Run("Malformed", func(t *testing.T) {
 | 
			
		||||
		_, err := Detect(Server, malformedSlice)
 | 
			
		||||
		_, _, err := Detect(Server, malformedSlice, 3306, 1234)
 | 
			
		||||
		if !errors.Is(err, ErrUnknown) {
 | 
			
		||||
			t.Fatalf("expected unknown format, got error %T: %q", err, err)
 | 
			
		||||
		} else {
 | 
			
		||||
 
 | 
			
		||||
@@ -20,16 +20,20 @@ func registerPostgreSQL() {
 | 
			
		||||
	Register(Client, "????\x00\x03\x00\x00", detectPostgreSQLClient) // Startup packet, protocol 3.0
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func detectPostgreSQLClient(dir Direction, data []byte) *Protocol {
 | 
			
		||||
func detectPostgreSQLClient(dir Direction, data []byte, srcPort, dstPort int) (proto *Protocol, confidence float64) {
 | 
			
		||||
	// A client startup message needs at least 8 bytes (length + protocol version).
 | 
			
		||||
	if len(data) < 8 {
 | 
			
		||||
		return nil
 | 
			
		||||
		return nil, 0
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	length := int(binary.BigEndian.Uint32(data[0:]))
 | 
			
		||||
	if len(data) != length {
 | 
			
		||||
		log.Printf("not postgres %q: %d != %d", data, len(data), length)
 | 
			
		||||
		return nil
 | 
			
		||||
		return nil, 0
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if dstPort == 5432 {
 | 
			
		||||
		confidence = .1
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	major := int(binary.BigEndian.Uint16(data[4:]))
 | 
			
		||||
@@ -42,15 +46,19 @@ func detectPostgreSQLClient(dir Direction, data []byte) *Protocol {
 | 
			
		||||
				Minor: minor,
 | 
			
		||||
				Patch: -1,
 | 
			
		||||
			},
 | 
			
		||||
		}
 | 
			
		||||
		}, confidence + .75
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
	return nil, 0
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func detectPostgreSQLServer(dir Direction, data []byte) *Protocol {
 | 
			
		||||
func detectPostgreSQLServer(dir Direction, data []byte, srcPort, dstPort int) (proto *Protocol, confidence float64) {
 | 
			
		||||
	// A server message needs at least 5 bytes (type + length).
 | 
			
		||||
	if len(data) < 5 {
 | 
			
		||||
		return nil
 | 
			
		||||
		return nil, 0
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if srcPort == 5432 {
 | 
			
		||||
		confidence = .1
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// All server messages (and subsequent client messages) are tagged with a single-byte type.
 | 
			
		||||
@@ -62,9 +70,9 @@ func detectPostgreSQLServer(dir Direction, data []byte) *Protocol {
 | 
			
		||||
		'Z', // ReadyForQuery
 | 
			
		||||
		'E', // ErrorResponse
 | 
			
		||||
		'N': // NoticeResponse
 | 
			
		||||
		return &Protocol{Name: ProtocolPostgreSQL}
 | 
			
		||||
		return &Protocol{Name: ProtocolPostgreSQL}, confidence + .65
 | 
			
		||||
 | 
			
		||||
	default:
 | 
			
		||||
		return nil
 | 
			
		||||
		return nil, 0
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -22,12 +22,12 @@ func TestDetectPostgreSQLClient(t *testing.T) {
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	t.Run("Protocol 3.0", func(t *testing.T) {
 | 
			
		||||
		p, err := Detect(Client, pgClientStartup)
 | 
			
		||||
		p, c, err := Detect(Client, pgClientStartup, 0, 5432)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			t.Fatal(err)
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		t.Logf("detected %s version %s", p.Name, p.Version)
 | 
			
		||||
		t.Logf("detected %s version %s confidence %g%%", p.Name, p.Version, 100*c)
 | 
			
		||||
		if p.Name != ProtocolPostgreSQL {
 | 
			
		||||
			t.Fatalf("expected postgres protocol, got %s", p.Name)
 | 
			
		||||
			return
 | 
			
		||||
@@ -58,12 +58,12 @@ func TestDetectPostgreSQLServer(t *testing.T) {
 | 
			
		||||
	httpBanner := []byte("GET / HTTP/1.1\r\nHost: example.com\r\n\r\n")
 | 
			
		||||
 | 
			
		||||
	t.Run("AuthenticationOk", func(t *testing.T) {
 | 
			
		||||
		p, err := Detect(Server, pgServerAuthOK)
 | 
			
		||||
		p, c, err := Detect(Server, pgServerAuthOK, 5432, 0)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			t.Fatal(err)
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		t.Logf("detected %s version %s", p.Name, p.Version)
 | 
			
		||||
		t.Logf("detected %s version %s confidence %g%%", p.Name, p.Version, c*100)
 | 
			
		||||
		if p.Name != ProtocolPostgreSQL {
 | 
			
		||||
			t.Fatalf("expected postgres protocol, got %s", p.Name)
 | 
			
		||||
			return
 | 
			
		||||
@@ -71,12 +71,12 @@ func TestDetectPostgreSQLServer(t *testing.T) {
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	t.Run("ErrorResponse", func(t *testing.T) {
 | 
			
		||||
		p, err := Detect(Server, pgServerError)
 | 
			
		||||
		p, c, err := Detect(Server, pgServerError, 5432, 0)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			t.Fatal(err)
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		t.Logf("detected %s version %s", p.Name, p.Version)
 | 
			
		||||
		t.Logf("detected %s version %s confidence %g%%", p.Name, p.Version, c*100)
 | 
			
		||||
		if p.Name != ProtocolPostgreSQL {
 | 
			
		||||
			t.Fatalf("expected postgres protocol, got %s", p.Name)
 | 
			
		||||
			return
 | 
			
		||||
@@ -84,7 +84,7 @@ func TestDetectPostgreSQLServer(t *testing.T) {
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	t.Run("Invalid HTTP", func(t *testing.T) {
 | 
			
		||||
		_, err := Detect(Server, httpBanner)
 | 
			
		||||
		_, _, err := Detect(Server, httpBanner, 0, 80)
 | 
			
		||||
		if !errors.Is(err, ErrUnknown) {
 | 
			
		||||
			t.Fatalf("expected unknown format, got error %T: %q", err, err)
 | 
			
		||||
		} else {
 | 
			
		||||
 
 | 
			
		||||
@@ -14,10 +14,14 @@ func init() {
 | 
			
		||||
	Register(Both, "", detectSSH)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func detectSSH(dir Direction, data []byte) *Protocol {
 | 
			
		||||
func detectSSH(dir Direction, data []byte, srcPort, dstPort int) (proto *Protocol, confidence float64) {
 | 
			
		||||
	// The data must be at least as long as the prefix itself.
 | 
			
		||||
	if len(data) < len(ssh20Prefix) {
 | 
			
		||||
		return nil
 | 
			
		||||
		return nil, 0
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if dstPort == 22 || dstPort == 2200 || dstPort == 2222 {
 | 
			
		||||
		confidence = .1
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// The protocol allows for pre-banner text, so we have to check all lines.
 | 
			
		||||
@@ -32,7 +36,7 @@ func detectSSH(dir Direction, data []byte) *Protocol {
 | 
			
		||||
					Patch: -1,
 | 
			
		||||
					Extra: string(line[len(ssh20Prefix):]),
 | 
			
		||||
				},
 | 
			
		||||
			}
 | 
			
		||||
			}, confidence + 0.75
 | 
			
		||||
		}
 | 
			
		||||
		if bytes.HasPrefix(line, []byte(ssh199Prefix)) {
 | 
			
		||||
			return &Protocol{
 | 
			
		||||
@@ -43,9 +47,9 @@ func detectSSH(dir Direction, data []byte) *Protocol {
 | 
			
		||||
					Patch: -1,
 | 
			
		||||
					Extra: string(line[len(ssh20Prefix):]),
 | 
			
		||||
				},
 | 
			
		||||
			}
 | 
			
		||||
			}, confidence + 0.75
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
	return nil, 0
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -32,7 +32,7 @@ func TestDetectSSH(t *testing.T) {
 | 
			
		||||
	httpBanner := []byte("GET / HTTP/1.1\r\nHost: example.com\r\n\r\n")
 | 
			
		||||
 | 
			
		||||
	t.Run("OpenSSH client", func(t *testing.T) {
 | 
			
		||||
		p, err := Detect(Server, openSSHBanner)
 | 
			
		||||
		p, _, err := Detect(Server, openSSHBanner, 0, 22)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			t.Fatal(err)
 | 
			
		||||
			return
 | 
			
		||||
@@ -45,7 +45,7 @@ func TestDetectSSH(t *testing.T) {
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	t.Run("OpenSSH server", func(t *testing.T) {
 | 
			
		||||
		p, err := Detect(Server, openSSHBanner)
 | 
			
		||||
		p, _, err := Detect(Server, openSSHBanner, 0, 22)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			t.Fatal(err)
 | 
			
		||||
			return
 | 
			
		||||
@@ -58,7 +58,7 @@ func TestDetectSSH(t *testing.T) {
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	t.Run("OpenSSH server with banner", func(t *testing.T) {
 | 
			
		||||
		p, err := Detect(Server, preBannerSSH)
 | 
			
		||||
		p, _, err := Detect(Server, preBannerSSH, 0, 22)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			t.Fatal(err)
 | 
			
		||||
			return
 | 
			
		||||
@@ -71,7 +71,7 @@ func TestDetectSSH(t *testing.T) {
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	t.Run("Dropbear server", func(t *testing.T) {
 | 
			
		||||
		p, err := Detect(Server, dropbearBanner)
 | 
			
		||||
		p, _, err := Detect(Server, dropbearBanner, 0, 22)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			t.Fatal(err)
 | 
			
		||||
			return
 | 
			
		||||
@@ -84,7 +84,7 @@ func TestDetectSSH(t *testing.T) {
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	t.Run("Invalid MySQL banner", func(t *testing.T) {
 | 
			
		||||
		_, err := Detect(Server, mysqlBanner)
 | 
			
		||||
		_, _, err := Detect(Server, mysqlBanner, 0, 3306)
 | 
			
		||||
		if !errors.Is(err, ErrUnknown) {
 | 
			
		||||
			t.Fatalf("expected unknown format, got error %T: %q", err, err)
 | 
			
		||||
		} else {
 | 
			
		||||
@@ -93,7 +93,7 @@ func TestDetectSSH(t *testing.T) {
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	t.Run("Invalid HTTP banner", func(t *testing.T) {
 | 
			
		||||
		_, err := Detect(Server, httpBanner)
 | 
			
		||||
		_, _, err := Detect(Server, httpBanner, 0, 80)
 | 
			
		||||
		if !errors.Is(err, ErrUnknown) {
 | 
			
		||||
			t.Fatalf("expected unknown format, got error %T: %q", err, err)
 | 
			
		||||
		} else {
 | 
			
		||||
 
 | 
			
		||||
@@ -17,12 +17,12 @@ func registerTLS() {
 | 
			
		||||
	Register(Both, "\x16\x03\x03", detectTLS) // TLSv1.2
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func detectTLS(dir Direction, data []byte) *Protocol {
 | 
			
		||||
func detectTLS(dir Direction, data []byte, _, _ int) (proto *Protocol, confidence float64) {
 | 
			
		||||
	stream := cryptobyte.String(data)
 | 
			
		||||
 | 
			
		||||
	// A TLS packet always has a content type (1 byte), version (2 bytes) and length (2 bytes).
 | 
			
		||||
	if len(stream) < 5 {
 | 
			
		||||
		return nil
 | 
			
		||||
		return nil, 0
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Check for TLS Handshake (type 22)
 | 
			
		||||
@@ -32,15 +32,18 @@ func detectTLS(dir Direction, data []byte) *Protocol {
 | 
			
		||||
		Length  uint32
 | 
			
		||||
	}
 | 
			
		||||
	if !stream.ReadUint8(&header.Type) || header.Type != 0x16 {
 | 
			
		||||
		return nil
 | 
			
		||||
		return nil, 0
 | 
			
		||||
	}
 | 
			
		||||
	if !stream.ReadUint16(&header.Version) {
 | 
			
		||||
		return nil
 | 
			
		||||
		return nil, 0
 | 
			
		||||
	}
 | 
			
		||||
	if !stream.ReadUint24(&header.Length) {
 | 
			
		||||
		return nil
 | 
			
		||||
		return nil, 0
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Initial confidence
 | 
			
		||||
	confidence = 0.5
 | 
			
		||||
 | 
			
		||||
	// Detected SSL/TLS version
 | 
			
		||||
	var version dpi.TLSVersion
 | 
			
		||||
 | 
			
		||||
@@ -48,6 +51,7 @@ func detectTLS(dir Direction, data []byte) *Protocol {
 | 
			
		||||
	if version == 0 {
 | 
			
		||||
		if hello, err := dpi.DecodeTLSClientHelloHandshake(data); err == nil {
 | 
			
		||||
			version = hello.Version
 | 
			
		||||
			confidence += .45
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
@@ -55,6 +59,7 @@ func detectTLS(dir Direction, data []byte) *Protocol {
 | 
			
		||||
	if version == 0 {
 | 
			
		||||
		if hello, err := dpi.DecodeTLSServerHello(data); err == nil {
 | 
			
		||||
			version = hello.Version
 | 
			
		||||
			confidence += .45
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
@@ -68,6 +73,7 @@ func detectTLS(dir Direction, data []byte) *Protocol {
 | 
			
		||||
			)
 | 
			
		||||
			if stream.ReadUint24(&length) && stream.ReadUint16(&versionWord) {
 | 
			
		||||
				version = dpi.TLSVersion(versionWord)
 | 
			
		||||
				confidence += .25
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
@@ -82,17 +88,17 @@ func detectTLS(dir Direction, data []byte) *Protocol {
 | 
			
		||||
		return &Protocol{
 | 
			
		||||
			Name:    ProtocolSSL,
 | 
			
		||||
			Version: Version{Major: 3, Minor: 0, Patch: -1},
 | 
			
		||||
		}
 | 
			
		||||
		}, confidence
 | 
			
		||||
	} else if version >= dpi.VersionTLS10 && version <= dpi.VersionTLS13 {
 | 
			
		||||
		return &Protocol{
 | 
			
		||||
			Name:    ProtocolTLS,
 | 
			
		||||
			Version: Version{Major: 1, Minor: int(uint8(version) - 1), Patch: -1},
 | 
			
		||||
		}
 | 
			
		||||
		}, confidence
 | 
			
		||||
	} else if version >= dpi.VersionTLS13Draft && version <= dpi.VersionTLS13Draft23 {
 | 
			
		||||
		return &Protocol{
 | 
			
		||||
			Name:    ProtocolTLS,
 | 
			
		||||
			Version: Version{Major: 1, Minor: 3, Patch: -1},
 | 
			
		||||
		}
 | 
			
		||||
		}, confidence
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
	return nil, 0
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -179,7 +179,7 @@ func TestDetectTLS(t *testing.T) {
 | 
			
		||||
		t.Run(name, func(t *testing.T) {
 | 
			
		||||
 | 
			
		||||
			t.Run("SSLv3 Client Hello", func(t *testing.T) {
 | 
			
		||||
				p, err := Detect(Client, sslV3ClientHello)
 | 
			
		||||
				p, _, err := Detect(Client, sslV3ClientHello, 0, 0)
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					t.Fatal(err)
 | 
			
		||||
					return
 | 
			
		||||
@@ -192,7 +192,7 @@ func TestDetectTLS(t *testing.T) {
 | 
			
		||||
			})
 | 
			
		||||
 | 
			
		||||
			t.Run("TLS 1.1 Client Hello", func(t *testing.T) {
 | 
			
		||||
				p, err := Detect(Client, tls11ClientHello)
 | 
			
		||||
				p, _, err := Detect(Client, tls11ClientHello, 0, 0)
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					t.Fatal(err)
 | 
			
		||||
					return
 | 
			
		||||
@@ -205,7 +205,7 @@ func TestDetectTLS(t *testing.T) {
 | 
			
		||||
			})
 | 
			
		||||
 | 
			
		||||
			t.Run("TLS 1.1 partial Client Hello", func(t *testing.T) {
 | 
			
		||||
				p, err := Detect(Client, tls11ClientHelloPartial)
 | 
			
		||||
				p, _, err := Detect(Client, tls11ClientHelloPartial, 0, 0)
 | 
			
		||||
				if strict {
 | 
			
		||||
					if !errors.Is(err, ErrUnknown) {
 | 
			
		||||
						t.Fatalf("expected unknown format, got error %T: %q", err, err)
 | 
			
		||||
@@ -226,7 +226,7 @@ func TestDetectTLS(t *testing.T) {
 | 
			
		||||
			})
 | 
			
		||||
 | 
			
		||||
			t.Run("TLS 1.2 Client Hello", func(t *testing.T) {
 | 
			
		||||
				p, err := Detect(Client, tls12ClientHello)
 | 
			
		||||
				p, _, err := Detect(Client, tls12ClientHello, 0, 0)
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					t.Fatal(err)
 | 
			
		||||
					return
 | 
			
		||||
@@ -239,7 +239,7 @@ func TestDetectTLS(t *testing.T) {
 | 
			
		||||
			})
 | 
			
		||||
 | 
			
		||||
			t.Run("TLS 1.3 Client Hello", func(t *testing.T) {
 | 
			
		||||
				p, err := Detect(Client, tls13ClientHello)
 | 
			
		||||
				p, _, err := Detect(Client, tls13ClientHello, 0, 0)
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					t.Fatal(err)
 | 
			
		||||
					return
 | 
			
		||||
@@ -252,7 +252,7 @@ func TestDetectTLS(t *testing.T) {
 | 
			
		||||
			})
 | 
			
		||||
 | 
			
		||||
			t.Run("Invalid PostgreSQL", func(t *testing.T) {
 | 
			
		||||
				_, err := Detect(Server, pgClientStartup)
 | 
			
		||||
				_, _, err := Detect(Server, pgClientStartup, 0, 0)
 | 
			
		||||
				if !errors.Is(err, ErrUnknown) {
 | 
			
		||||
					t.Fatalf("expected unknown format, got error %T: %q", err, err)
 | 
			
		||||
				} else {
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										215
									
								
								protocol/detest_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										215
									
								
								protocol/detest_test.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,215 @@
 | 
			
		||||
package protocol
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"math"
 | 
			
		||||
	"testing"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func TestCompareFloats(t *testing.T) {
 | 
			
		||||
	tests := []struct {
 | 
			
		||||
		name     string
 | 
			
		||||
		a, b     float64
 | 
			
		||||
		expected int
 | 
			
		||||
	}{
 | 
			
		||||
		// Basic comparisons
 | 
			
		||||
		{
 | 
			
		||||
			name:     "a less than b",
 | 
			
		||||
			a:        1.0,
 | 
			
		||||
			b:        2.0,
 | 
			
		||||
			expected: -1,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:     "a greater than b",
 | 
			
		||||
			a:        2.0,
 | 
			
		||||
			b:        1.0,
 | 
			
		||||
			expected: 1,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:     "a equals b exact",
 | 
			
		||||
			a:        1.0,
 | 
			
		||||
			b:        1.0,
 | 
			
		||||
			expected: 0,
 | 
			
		||||
		},
 | 
			
		||||
 | 
			
		||||
		// Floating-point precision cases
 | 
			
		||||
		{
 | 
			
		||||
			name:     "famous 0.1 + 0.2 equals 0.3 within tolerance",
 | 
			
		||||
			a:        0.1 + 0.2,
 | 
			
		||||
			b:        0.3,
 | 
			
		||||
			expected: 0,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:     "very close numbers within tolerance",
 | 
			
		||||
			a:        1.0000000001,
 | 
			
		||||
			b:        1.0000000002,
 | 
			
		||||
			expected: 0,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:     "numbers outside tolerance a < b",
 | 
			
		||||
			a:        1.0,
 | 
			
		||||
			b:        1.0001,
 | 
			
		||||
			expected: -1,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:     "numbers outside tolerance a > b",
 | 
			
		||||
			a:        1.0001,
 | 
			
		||||
			b:        1.0,
 | 
			
		||||
			expected: 1,
 | 
			
		||||
		},
 | 
			
		||||
 | 
			
		||||
		// Edge cases with very small numbers
 | 
			
		||||
		{
 | 
			
		||||
			name:     "very small numbers equal",
 | 
			
		||||
			a:        1e-20,
 | 
			
		||||
			b:        1e-20,
 | 
			
		||||
			expected: 0,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:     "very small numbers a < b",
 | 
			
		||||
			a:        1e-15,
 | 
			
		||||
			b:        2e-15,
 | 
			
		||||
			expected: -1,
 | 
			
		||||
		},
 | 
			
		||||
 | 
			
		||||
		// Zero and negative zero
 | 
			
		||||
		{
 | 
			
		||||
			name:     "zero equals zero",
 | 
			
		||||
			a:        0.0,
 | 
			
		||||
			b:        0.0,
 | 
			
		||||
			expected: 0,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:     "zero equals negative zero",
 | 
			
		||||
			a:        0.0,
 | 
			
		||||
			b:        -0.0,
 | 
			
		||||
			expected: 0,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:     "zero less than small positive",
 | 
			
		||||
			a:        0.0,
 | 
			
		||||
			b:        1e-20,
 | 
			
		||||
			expected: -1,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:     "zero greater than small negative",
 | 
			
		||||
			a:        0.0,
 | 
			
		||||
			b:        -1e-20,
 | 
			
		||||
			expected: 1,
 | 
			
		||||
		},
 | 
			
		||||
 | 
			
		||||
		// Negative numbers
 | 
			
		||||
		{
 | 
			
		||||
			name:     "negative numbers a > b",
 | 
			
		||||
			a:        -1.0,
 | 
			
		||||
			b:        -2.0,
 | 
			
		||||
			expected: 1,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:     "negative numbers a < b",
 | 
			
		||||
			a:        -2.0,
 | 
			
		||||
			b:        -1.0,
 | 
			
		||||
			expected: -1,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:     "negative numbers equal",
 | 
			
		||||
			a:        -1.0,
 | 
			
		||||
			b:        -1.0,
 | 
			
		||||
			expected: 0,
 | 
			
		||||
		},
 | 
			
		||||
 | 
			
		||||
		// Mixed signs
 | 
			
		||||
		{
 | 
			
		||||
			name:     "negative less than positive",
 | 
			
		||||
			a:        -1.0,
 | 
			
		||||
			b:        1.0,
 | 
			
		||||
			expected: -1,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:     "positive greater than negative",
 | 
			
		||||
			a:        1.0,
 | 
			
		||||
			b:        -1.0,
 | 
			
		||||
			expected: 1,
 | 
			
		||||
		},
 | 
			
		||||
 | 
			
		||||
		// Special values: NaN
 | 
			
		||||
		{
 | 
			
		||||
			name:     "NaN equals NaN",
 | 
			
		||||
			a:        math.NaN(),
 | 
			
		||||
			b:        math.NaN(),
 | 
			
		||||
			expected: 0,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:     "NaN less than number",
 | 
			
		||||
			a:        math.NaN(),
 | 
			
		||||
			b:        1.0,
 | 
			
		||||
			expected: -1,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:     "number greater than NaN",
 | 
			
		||||
			a:        1.0,
 | 
			
		||||
			b:        math.NaN(),
 | 
			
		||||
			expected: 1,
 | 
			
		||||
		},
 | 
			
		||||
 | 
			
		||||
		// Special values: Infinity
 | 
			
		||||
		{
 | 
			
		||||
			name:     "positive infinity equals positive infinity",
 | 
			
		||||
			a:        math.Inf(1),
 | 
			
		||||
			b:        math.Inf(1),
 | 
			
		||||
			expected: 0,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:     "negative infinity equals negative infinity",
 | 
			
		||||
			a:        math.Inf(-1),
 | 
			
		||||
			b:        math.Inf(-1),
 | 
			
		||||
			expected: 0,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:     "positive infinity greater than negative infinity",
 | 
			
		||||
			a:        math.Inf(1),
 | 
			
		||||
			b:        math.Inf(-1),
 | 
			
		||||
			expected: 1,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:     "negative infinity less than positive infinity",
 | 
			
		||||
			a:        math.Inf(-1),
 | 
			
		||||
			b:        math.Inf(1),
 | 
			
		||||
			expected: -1,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:     "positive infinity greater than large number",
 | 
			
		||||
			a:        math.Inf(1),
 | 
			
		||||
			b:        1e308,
 | 
			
		||||
			expected: 1,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:     "negative infinity less than small number",
 | 
			
		||||
			a:        math.Inf(-1),
 | 
			
		||||
			b:        -1e308,
 | 
			
		||||
			expected: -1,
 | 
			
		||||
		},
 | 
			
		||||
 | 
			
		||||
		// Large numbers
 | 
			
		||||
		{
 | 
			
		||||
			name:     "large numbers equal",
 | 
			
		||||
			a:        1e15,
 | 
			
		||||
			b:        1e15,
 | 
			
		||||
			expected: 0,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:     "large numbers a < b",
 | 
			
		||||
			a:        1e15,
 | 
			
		||||
			b:        2e15,
 | 
			
		||||
			expected: -1,
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, test := range tests {
 | 
			
		||||
		t.Run(test.name, func(t *testing.T) {
 | 
			
		||||
			result := compareFloats(test.a, test.b)
 | 
			
		||||
			if result != test.expected {
 | 
			
		||||
				t.Errorf("compareFloats(%g, %g) = %d, want %d", test.a, test.b, result, test.expected)
 | 
			
		||||
			}
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
@@ -2,21 +2,25 @@ package protocol
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"net"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"sync/atomic"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// Intercepted is the result returned by [Interceptor.Detect].
 | 
			
		||||
type Intercepted struct {
 | 
			
		||||
	Direction Direction
 | 
			
		||||
	Protocol  *Protocol
 | 
			
		||||
	Error     error
 | 
			
		||||
	Direction  Direction
 | 
			
		||||
	Protocol   *Protocol
 | 
			
		||||
	Confidence float64
 | 
			
		||||
	Error      error
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Interceptor intercepts reads from client or server.
 | 
			
		||||
type Interceptor struct {
 | 
			
		||||
	clientPort   int
 | 
			
		||||
	clientBytes  chan []byte
 | 
			
		||||
	clientReader *readInterceptor
 | 
			
		||||
	serverPort   int
 | 
			
		||||
	serverBytes  chan []byte
 | 
			
		||||
	serverReader *readInterceptor
 | 
			
		||||
}
 | 
			
		||||
@@ -71,6 +75,7 @@ func (i *Interceptor) Client(c net.Conn) net.Conn {
 | 
			
		||||
	if ri, ok := c.(*readInterceptor); ok {
 | 
			
		||||
		return ri
 | 
			
		||||
	}
 | 
			
		||||
	i.clientPort = getPortFromAddr(c.RemoteAddr())
 | 
			
		||||
	i.clientReader = newReadInterceptor(c, i.clientBytes)
 | 
			
		||||
	return i.clientReader
 | 
			
		||||
}
 | 
			
		||||
@@ -80,6 +85,7 @@ func (i *Interceptor) Server(c net.Conn) net.Conn {
 | 
			
		||||
	if ri, ok := c.(*readInterceptor); ok {
 | 
			
		||||
		return ri
 | 
			
		||||
	}
 | 
			
		||||
	i.serverPort = getPortFromAddr(c.RemoteAddr())
 | 
			
		||||
	i.serverReader = newReadInterceptor(c, i.serverBytes)
 | 
			
		||||
	return i.serverReader
 | 
			
		||||
}
 | 
			
		||||
@@ -107,22 +113,49 @@ func (i *Interceptor) Detect(timeout time.Duration) <-chan *Intercepted {
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
		case data := <-i.clientBytes: // client sent banner
 | 
			
		||||
			p, err := Detect(Client, data)
 | 
			
		||||
			p, c, err := Detect(Client, data, i.clientPort, i.serverPort)
 | 
			
		||||
			interceptc <- &Intercepted{
 | 
			
		||||
				Direction: Client,
 | 
			
		||||
				Protocol:  p,
 | 
			
		||||
				Error:     err,
 | 
			
		||||
				Direction:  Client,
 | 
			
		||||
				Protocol:   p,
 | 
			
		||||
				Confidence: c,
 | 
			
		||||
				Error:      err,
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
		case data := <-i.serverBytes: // server sent banner
 | 
			
		||||
			p, err := Detect(Server, data)
 | 
			
		||||
			p, c, err := Detect(Server, data, i.serverPort, i.clientPort)
 | 
			
		||||
			interceptc <- &Intercepted{
 | 
			
		||||
				Direction: Server,
 | 
			
		||||
				Protocol:  p,
 | 
			
		||||
				Error:     err,
 | 
			
		||||
				Direction:  Server,
 | 
			
		||||
				Protocol:   p,
 | 
			
		||||
				Confidence: c,
 | 
			
		||||
				Error:      err,
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	return interceptc
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func getPortFromAddr(addr net.Addr) int {
 | 
			
		||||
	switch a := addr.(type) {
 | 
			
		||||
	case *net.TCPAddr:
 | 
			
		||||
		return a.Port
 | 
			
		||||
	case *net.UDPAddr:
 | 
			
		||||
		return a.Port
 | 
			
		||||
	case *net.IPAddr:
 | 
			
		||||
		// IPAddr doesn't have a port
 | 
			
		||||
		return 0
 | 
			
		||||
	default:
 | 
			
		||||
		// Fallback to parsing
 | 
			
		||||
		_, service, err := net.SplitHostPort(addr.String())
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return 0
 | 
			
		||||
		}
 | 
			
		||||
		if port, err := strconv.Atoi(service); err == nil {
 | 
			
		||||
			return port
 | 
			
		||||
		}
 | 
			
		||||
		if port, err := net.LookupPort(addr.Network(), service); err == nil {
 | 
			
		||||
			return port
 | 
			
		||||
		}
 | 
			
		||||
		return 0
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -35,42 +35,43 @@ type connLimiter struct {
 | 
			
		||||
	acceptError atomic.Value
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (l *connLimiter) init(readData, writeData []byte) {
 | 
			
		||||
	l.acceptOnce.Do(func() {
 | 
			
		||||
func (limiter *connLimiter) init(readData, writeData []byte) {
 | 
			
		||||
	limiter.acceptOnce.Do(func() {
 | 
			
		||||
		var (
 | 
			
		||||
			dir  Direction
 | 
			
		||||
			data []byte
 | 
			
		||||
			dir              Direction
 | 
			
		||||
			data             []byte
 | 
			
		||||
			srcPort, dstPort int
 | 
			
		||||
		)
 | 
			
		||||
		if readData != nil {
 | 
			
		||||
			// init called by initial read
 | 
			
		||||
			dir, data = Server, readData
 | 
			
		||||
			dir, data, srcPort, dstPort = Server, readData, getPortFromAddr(limiter.LocalAddr()), getPortFromAddr(limiter.RemoteAddr())
 | 
			
		||||
		} else {
 | 
			
		||||
			// init called by initial write
 | 
			
		||||
			dir, data = Client, writeData
 | 
			
		||||
			dir, data, srcPort, dstPort = Client, writeData, getPortFromAddr(limiter.RemoteAddr()), getPortFromAddr(limiter.LocalAddr())
 | 
			
		||||
		}
 | 
			
		||||
		protocol, _ := Detect(dir, data)
 | 
			
		||||
		if err := l.accept(dir, protocol); err != nil {
 | 
			
		||||
			l.acceptError.Store(err)
 | 
			
		||||
		protocol, _, _ := Detect(dir, data, srcPort, dstPort)
 | 
			
		||||
		if err := limiter.accept(dir, protocol); err != nil {
 | 
			
		||||
			limiter.acceptError.Store(err)
 | 
			
		||||
		}
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (l *connLimiter) Read(p []byte) (n int, err error) {
 | 
			
		||||
func (limiter *connLimiter) Read(p []byte) (n int, err error) {
 | 
			
		||||
	var ok bool
 | 
			
		||||
	if err, ok = l.acceptError.Load().(error); ok && err != nil {
 | 
			
		||||
	if err, ok = limiter.acceptError.Load().(error); ok && err != nil {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	if n, err = l.Conn.Read(p); n > 0 {
 | 
			
		||||
		l.init(p[:n], nil)
 | 
			
		||||
	if n, err = limiter.Conn.Read(p); n > 0 {
 | 
			
		||||
		limiter.init(p[:n], nil)
 | 
			
		||||
	}
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (l *connLimiter) Write(p []byte) (n int, err error) {
 | 
			
		||||
	l.init(nil, p)
 | 
			
		||||
func (limiter *connLimiter) Write(p []byte) (n int, err error) {
 | 
			
		||||
	limiter.init(nil, p)
 | 
			
		||||
	var ok bool
 | 
			
		||||
	if err, ok = l.acceptError.Load().(error); ok && err != nil {
 | 
			
		||||
	if err, ok = limiter.acceptError.Load().(error); ok && err != nil {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	return l.Conn.Write(p)
 | 
			
		||||
	return limiter.Conn.Write(p)
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										77
									
								
								protocol/match.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										77
									
								
								protocol/match.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,77 @@
 | 
			
		||||
package protocol
 | 
			
		||||
 | 
			
		||||
// MatchPattern checks if the byte slice matches the magic string pattern.
 | 
			
		||||
//
 | 
			
		||||
// '?' matches any single character
 | 
			
		||||
// '*' matches zero or more characters
 | 
			
		||||
// '\' escapes special characters ('?', '*', '\')
 | 
			
		||||
// All other characters must match exactly
 | 
			
		||||
//
 | 
			
		||||
// Returns true if all magic bytes are matched, even if input has extra bytes.
 | 
			
		||||
func Match(magic string, input []byte) bool {
 | 
			
		||||
	return match(magic, input, 0, 0)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// match is a recursive helper function that implements the matching logic
 | 
			
		||||
func match(magic string, input []byte, magicIndex, inputIndex int) bool {
 | 
			
		||||
	// If we've reached the end of magic string, we've successfully matched all magic bytes
 | 
			
		||||
	// It doesn't matter if there are extra bytes in the input
 | 
			
		||||
	if magicIndex == len(magic) {
 | 
			
		||||
		return true
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Handle escape character
 | 
			
		||||
	if magic[magicIndex] == '\\' {
 | 
			
		||||
		// Check if there's a next character in magic
 | 
			
		||||
		if magicIndex+1 >= len(magic) {
 | 
			
		||||
			// Backslash at end of magic string - treat as literal backslash
 | 
			
		||||
			if inputIndex >= len(input) || input[inputIndex] != '\\' {
 | 
			
		||||
				return false
 | 
			
		||||
			}
 | 
			
		||||
			return match(magic, input, magicIndex+1, inputIndex+1)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// Escape the next character - we need to match it literally
 | 
			
		||||
		escapedChar := magic[magicIndex+1]
 | 
			
		||||
		if inputIndex >= len(input) || input[inputIndex] != escapedChar {
 | 
			
		||||
			return false
 | 
			
		||||
		}
 | 
			
		||||
		// Skip both the backslash and the escaped character in magic, move one in input
 | 
			
		||||
		return match(magic, input, magicIndex+2, inputIndex+1)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// If we've reached the end of input but not magic string
 | 
			
		||||
	if inputIndex == len(input) {
 | 
			
		||||
		// If we have '*' at the current position, it can match zero characters
 | 
			
		||||
		if magic[magicIndex] == '*' {
 | 
			
		||||
			return match(magic, input, magicIndex+1, inputIndex)
 | 
			
		||||
		}
 | 
			
		||||
		return false
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Handle '*' character - matches zero or more characters
 | 
			
		||||
	if magic[magicIndex] == '*' {
 | 
			
		||||
		// Try matching zero characters
 | 
			
		||||
		if match(magic, input, magicIndex+1, inputIndex) {
 | 
			
		||||
			return true
 | 
			
		||||
		}
 | 
			
		||||
		// Try matching one or more characters
 | 
			
		||||
		if inputIndex < len(input) && match(magic, input, magicIndex, inputIndex+1) {
 | 
			
		||||
			return true
 | 
			
		||||
		}
 | 
			
		||||
		return false
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Handle '?' character - matches any single character
 | 
			
		||||
	if magic[magicIndex] == '?' {
 | 
			
		||||
		return match(magic, input, magicIndex+1, inputIndex+1)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Handle exact character match
 | 
			
		||||
	if magic[magicIndex] == input[inputIndex] {
 | 
			
		||||
		return match(magic, input, magicIndex+1, inputIndex+1)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// No match found
 | 
			
		||||
	return false
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										227
									
								
								protocol/match_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										227
									
								
								protocol/match_test.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,227 @@
 | 
			
		||||
package protocol
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"testing"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func TestMatch(t *testing.T) {
 | 
			
		||||
	tests := []struct {
 | 
			
		||||
		name     string
 | 
			
		||||
		magic    string
 | 
			
		||||
		input    []byte
 | 
			
		||||
		expected bool
 | 
			
		||||
	}{
 | 
			
		||||
		// Basic escaping tests
 | 
			
		||||
		{
 | 
			
		||||
			name:     "escape star",
 | 
			
		||||
			magic:    "\\*",
 | 
			
		||||
			input:    []byte("*"),
 | 
			
		||||
			expected: true,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:     "escape star no match",
 | 
			
		||||
			magic:    "\\*",
 | 
			
		||||
			input:    []byte("a"),
 | 
			
		||||
			expected: false,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:     "escape star with longer input",
 | 
			
		||||
			magic:    "\\*",
 | 
			
		||||
			input:    []byte("*extra"),
 | 
			
		||||
			expected: true,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:     "escape question mark",
 | 
			
		||||
			magic:    "\\?",
 | 
			
		||||
			input:    []byte("?"),
 | 
			
		||||
			expected: true,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:     "escape question mark no match",
 | 
			
		||||
			magic:    "\\?",
 | 
			
		||||
			input:    []byte("a"),
 | 
			
		||||
			expected: false,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:     "escape backslash",
 | 
			
		||||
			magic:    "\\\\",
 | 
			
		||||
			input:    []byte("\\"),
 | 
			
		||||
			expected: true,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:     "escape backslash no match",
 | 
			
		||||
			magic:    "\\\\",
 | 
			
		||||
			input:    []byte("a"),
 | 
			
		||||
			expected: false,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:     "escape backslash with longer input",
 | 
			
		||||
			magic:    "\\\\",
 | 
			
		||||
			input:    []byte("\\extra"),
 | 
			
		||||
			expected: true,
 | 
			
		||||
		},
 | 
			
		||||
 | 
			
		||||
		// Multiple escaped characters
 | 
			
		||||
		{
 | 
			
		||||
			name:     "multiple escaped characters",
 | 
			
		||||
			magic:    "\\*\\?\\\\",
 | 
			
		||||
			input:    []byte("*?\\"),
 | 
			
		||||
			expected: true,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:     "multiple escaped characters with longer input",
 | 
			
		||||
			magic:    "\\*\\?\\\\",
 | 
			
		||||
			input:    []byte("*?\\extra"),
 | 
			
		||||
			expected: true,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:     "mixed escaped characters",
 | 
			
		||||
			magic:    "a\\*b\\?c\\\\d",
 | 
			
		||||
			input:    []byte("a*b?c\\d"),
 | 
			
		||||
			expected: true,
 | 
			
		||||
		},
 | 
			
		||||
 | 
			
		||||
		// Escaping combined with wildcards
 | 
			
		||||
		{
 | 
			
		||||
			name:     "star then escaped star",
 | 
			
		||||
			magic:    "*\\*",
 | 
			
		||||
			input:    []byte("anything*"),
 | 
			
		||||
			expected: true,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:     "star then escaped star must end with star",
 | 
			
		||||
			magic:    "*\\*",
 | 
			
		||||
			input:    []byte("anything"),
 | 
			
		||||
			expected: false,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:     "star then escaped question",
 | 
			
		||||
			magic:    "*\\?",
 | 
			
		||||
			input:    []byte("hello?"),
 | 
			
		||||
			expected: true,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:     "question then escaped star",
 | 
			
		||||
			magic:    "?\\*",
 | 
			
		||||
			input:    []byte("a*"),
 | 
			
		||||
			expected: true,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:     "question then escaped star wrong second char",
 | 
			
		||||
			magic:    "?\\*",
 | 
			
		||||
			input:    []byte("aa"),
 | 
			
		||||
			expected: false,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:     "wildcards between escaped characters",
 | 
			
		||||
			magic:    "*\\\\*",
 | 
			
		||||
			input:    []byte("path\\to\\file"),
 | 
			
		||||
			expected: true,
 | 
			
		||||
		},
 | 
			
		||||
 | 
			
		||||
		// Real-world escaping scenarios
 | 
			
		||||
		{
 | 
			
		||||
			name:     "file pattern with literal star",
 | 
			
		||||
			magic:    "file\\*.txt",
 | 
			
		||||
			input:    []byte("file*.txt"),
 | 
			
		||||
			expected: true,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:     "file pattern with literal star no match",
 | 
			
		||||
			magic:    "file\\*.txt",
 | 
			
		||||
			input:    []byte("filex.txt"),
 | 
			
		||||
			expected: false,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:     "pattern with literal question",
 | 
			
		||||
			magic:    "what\\?*",
 | 
			
		||||
			input:    []byte("what? is this"),
 | 
			
		||||
			expected: true,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:     "pattern with literal question must have question",
 | 
			
		||||
			magic:    "what\\?*",
 | 
			
		||||
			input:    []byte("what is this"),
 | 
			
		||||
			expected: false,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:     "database like pattern",
 | 
			
		||||
			magic:    "table_\\*_\\?",
 | 
			
		||||
			input:    []byte("table_*_?"),
 | 
			
		||||
			expected: true,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:     "database like pattern with longer input",
 | 
			
		||||
			magic:    "table_\\*_\\?",
 | 
			
		||||
			input:    []byte("table_*_?backup"),
 | 
			
		||||
			expected: true,
 | 
			
		||||
		},
 | 
			
		||||
 | 
			
		||||
		// Edge cases with escaping
 | 
			
		||||
		{
 | 
			
		||||
			name:     "backslash at end of magic",
 | 
			
		||||
			magic:    "test\\",
 | 
			
		||||
			input:    []byte("test\\"),
 | 
			
		||||
			expected: true,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:     "backslash at end of magic no match",
 | 
			
		||||
			magic:    "test\\",
 | 
			
		||||
			input:    []byte("test"),
 | 
			
		||||
			expected: false,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:     "only backslash",
 | 
			
		||||
			magic:    "\\",
 | 
			
		||||
			input:    []byte("\\"),
 | 
			
		||||
			expected: true,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:     "consecutive backslashes",
 | 
			
		||||
			magic:    "\\\\\\\\",
 | 
			
		||||
			input:    []byte("\\\\"),
 | 
			
		||||
			expected: true,
 | 
			
		||||
		},
 | 
			
		||||
 | 
			
		||||
		// Mixed scenarios with both escaping and wildcards
 | 
			
		||||
		{
 | 
			
		||||
			name:     "escaped wildcards in middle",
 | 
			
		||||
			magic:    "a*\\?b*\\*c",
 | 
			
		||||
			input:    []byte("aanything?banything*c"),
 | 
			
		||||
			expected: true,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:     "escaped wildcards pattern",
 | 
			
		||||
			magic:    "select * from \\*",
 | 
			
		||||
			input:    []byte("select * from *"),
 | 
			
		||||
			expected: true,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:     "escaped wildcards pattern with longer input",
 | 
			
		||||
			magic:    "select * from *\\*",
 | 
			
		||||
			input:    []byte("select name from users*"),
 | 
			
		||||
			expected: true,
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, test := range tests {
 | 
			
		||||
		t.Run(test.name, func(t *testing.T) {
 | 
			
		||||
			result := Match(test.magic, test.input)
 | 
			
		||||
			if result != test.expected {
 | 
			
		||||
				t.Errorf("Match(%q, %q) = %v, want %v",
 | 
			
		||||
					test.magic, string(test.input), result, test.expected)
 | 
			
		||||
			}
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Benchmark test with escaping
 | 
			
		||||
func BenchmarkMatch(b *testing.B) {
 | 
			
		||||
	magic := "file\\*\\?*\\\\*.txt"
 | 
			
		||||
	input := []byte("file*?name\\backup.txt")
 | 
			
		||||
 | 
			
		||||
	b.ResetTimer()
 | 
			
		||||
	for b.Loop() {
 | 
			
		||||
		Match(magic, input)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
		Reference in New Issue
	
	Block a user