package protocol

import (
	"errors"
	"testing"
)

// TestDetectMySQL checks that Detect, with only the MySQL detector
// registered, recognizes MySQL-family server greetings and reports
// ErrUnknown for inputs that are not one.
func TestDetectMySQL(t *testing.T) {
	// Register only the MySQL detector so no other format can claim the
	// inputs. The registry is package-level state, so put the previous value
	// back when the test finishes to keep other tests unaffected. The
	// assertion's ok is false only when nothing was registered yet; restoring
	// a nil []format then still leaves the registry empty for readers.
	prev, _ := atomicFormats.Load().([]format)
	t.Cleanup(func() { atomicFormats.Store(prev) })
	atomicFormats.Store([]format{{Server, "\x0a", detectMySQL}})

	// The two banners below are HandshakeV10 payloads without the 4-byte
	// packet header; the first byte is the protocol version (10 == 0x0a),
	// which is also the registered magic prefix.
	//
	// NOTE(review): neither banner contains the 13-byte auth-plugin-data
	// part 2 that the protocol places between the reserved bytes and the
	// plugin name when CLIENT_SECURE_CONNECTION is set (it is, in both
	// capability words). The valid cases therefore rely on detectMySQL
	// tolerating its absence — confirm that is intended.

	// 1. A valid MySQL 8.0 banner.
	mysql8Banner := []byte{
		// protocol version
		0x0a,
		// server version "8.0.32", NUL-terminated
		0x38, 0x2e, 0x30, 0x2e, 0x33, 0x32, 0x00,
		// connection id
		0x0d, 0x00, 0x00, 0x00,
		// auth-plugin-data part 1, then filler
		0x04, 0x5a, 0x56, 0x5f, 0x3e, 0x6e, 0x76, 0x27,
		0x00,
		// capability flags (lower), character set, status flags,
		// capability flags (upper), auth-plugin-data length
		0xff, 0xff, 0xff, 0x02, 0x00, 0xff, 0xc7, 0x15,
		// reserved
		0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
		// auth plugin name "caching_sha2_password", NUL-terminated
		0x63, 0x61, 0x63, 0x68, 0x69, 0x6e, 0x67, 0x5f, 0x73, 0x68, 0x61,
		0x32, 0x5f, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x00,
	}

	// 2. A valid MariaDB banner (protocol-compatible).
	mariaDBBanner := []byte{
		// protocol version
		0x0a,
		// server version "5.5.5-10.6.5-MariaDB-log", NUL-terminated
		0x35, 0x2e, 0x35, 0x2e, 0x35, 0x2d, 0x31, 0x30, 0x2e, 0x36, 0x2e, 0x35,
		0x2d, 0x4d, 0x61, 0x72, 0x69, 0x61, 0x44, 0x42, 0x2d, 0x6c, 0x6f, 0x67,
		0x00,
		// connection id
		0x1a, 0x00, 0x00, 0x00,
		// auth-plugin-data part 1, then filler
		0x4e, 0x5c, 0x32, 0x7b, 0x45, 0x3b, 0x40, 0x60,
		0x00,
		// capability flags (lower), character set, status flags,
		// capability flags (upper), auth-plugin-data length
		0xff, 0xf7, 0x08, 0x02, 0x00, 0xff, 0x81, 0x15,
		// reserved
		0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
		// auth plugin name "mysql_native_password", NUL-terminated
		0x6d, 0x79, 0x73, 0x71, 0x6c, 0x5f, 0x6e, 0x61, 0x74, 0x69, 0x76,
		0x65, 0x5f, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x00,
	}

	// 3. An invalid banner (e.g., an HTTP request).
	httpBanner := []byte("GET / HTTP/1.1\r\nHost: example.com\r\n\r\n")

	// 4. A short, invalid slice: correct first byte, version "123" with no
	// NUL terminator and nothing after it.
	shortSlice := []byte{0x0a, 0x31, 0x32, 0x33}

	// 5. A slice that starts correctly ("8.0.0") but is malformed: the
	// version string is never NUL-terminated.
	malformedSlice := []byte{0x0a, 0x38, 0x2e, 0x30, 0x2e, 0x30, 0x01, 0x02, 0x03, 0x04, 0x05}

	// Subtests run synchronously (no t.Parallel), so capturing tc in the
	// closure is safe regardless of the Go version's loop-variable rules.
	valid := []struct {
		name   string
		banner []byte
	}{
		{"MySQL 8", mysql8Banner},
		{"MariaDB", mariaDBBanner},
	}
	for _, tc := range valid {
		t.Run(tc.name, func(t *testing.T) {
			p, err := Detect(Server, tc.banner)
			if err != nil {
				t.Fatalf("expected MySQL protocol, got error %v", err)
			}
			if p == nil {
				t.Fatal("expected MySQL protocol, got nil")
			}
			t.Logf("detected %s version %s", p.Name, p.Version)
		})
	}

	invalid := []struct {
		name   string
		banner []byte
	}{
		{"Invalid HTTP", httpBanner},
		{"Too short", shortSlice},
		{"Malformed", malformedSlice},
	}
	for _, tc := range invalid {
		t.Run(tc.name, func(t *testing.T) {
			_, err := Detect(Server, tc.banner)
			if !errors.Is(err, ErrUnknown) {
				// %v rather than %q: err may be nil here, and %q of a nil
				// error renders as "%!q(<nil>)".
				t.Fatalf("expected unknown format, got error %T: %v", err, err)
			}
			t.Logf("error %q, as expected", err)
		})
	}
}