package protocol

import (
	"errors"
	"testing"
)

// TestDetectPostgreSQLClient checks that a client-side PostgreSQL protocol 3.0
// StartupMessage is detected as PostgreSQL, and that a MySQL server greeting
// is rejected with ErrUnknown when only the PostgreSQL format is registered.
func TestDetectPostgreSQLClient(t *testing.T) {
	// Reset the stored formats and register only PostgreSQL, so no other
	// detector can claim the inputs below.
	atomicFormats.Store([]format{})
	registerPostgreSQL()

	// A valid PostgreSQL client StartupMessage.
	// Format: length (4 bytes, counts itself), protocol version (4 bytes),
	// then NUL-terminated parameter name/value pairs and a final NUL.
	// Here the parameters are user=mazeio and database=test, i.e.
	// "user\0mazeio\0database\0test\0\0" (27 bytes).
	// Total length: 4 (length) + 4 (protocol) + 27 (parameters) = 35 (0x23).
	pgClientStartup := []byte{
		0x00, 0x00, 0x00, 0x23, // Length: 35
		0x00, 0x03, 0x00, 0x00, // Protocol version 3.0
		'u', 's', 'e', 'r', 0x00,
		'm', 'a', 'z', 'e', 'i', 'o', 0x00,
		'd', 'a', 't', 'a', 'b', 'a', 's', 'e', 0x00,
		't', 'e', 's', 't', 0x00,
		0x00, // End of parameters
	}

	// Payload of a MySQL server greeting (no 4-byte packet header); used as a
	// negative sample that the PostgreSQL detector must not accept.
	mysqlBanner := []byte{
		// Protocol version 10, server version "8.0.32\0".
		0x0a, 0x38, 0x2e, 0x30, 0x2e, 0x33, 0x32, 0x00,
		// Connection ID, auth plugin data (part 1), filler.
		0x0d, 0x00, 0x00, 0x00,
		0x04, 0x5a, 0x56, 0x5f, 0x3e, 0x6e, 0x76, 0x27,
		0x00,
		// Capability flags (lower), character set, status flags,
		// capability flags (upper), auth plugin data length.
		0xff, 0xff, 0xff, 0x02, 0x00, 0xff, 0xc7, 0x15,
		// Reserved.
		0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
		// Auth plugin name "caching_sha2_password\0".
		0x63, 0x61, 0x63, 0x68, 0x69, 0x6e, 0x67, 0x5f, 0x73, 0x68, 0x61,
		0x32, 0x5f, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x00,
	}

	testRunner(t, []*testCase{
		{
			Name:           "PostgreSQL protocol 3.0",
			Direction:      Client,
			Data:           pgClientStartup,
			DstPort:        5432,
			WantProto:      ProtocolPostgreSQL,
			WantConfidence: .85,
		},
		{
			Name:      "Invalid MySQL server",
			Direction: Server,
			Data:      mysqlBanner,
			SrcPort:   3306,
			WantError: ErrUnknown,
		},
	})
}

// TestDetectPostgreSQLServer checks that server-side PostgreSQL messages
// (AuthenticationOk and ErrorResponse) are detected as PostgreSQL, and that an
// HTTP request is rejected with ErrUnknown when only the PostgreSQL format is
// registered.
func TestDetectPostgreSQLServer(t *testing.T) {
	// Reset the stored formats and register only PostgreSQL.
	atomicFormats.Store([]format{})
	registerPostgreSQL()

	// A valid PostgreSQL server AuthenticationOk response.
	// Format: type (1 byte), length (4 bytes, counts itself), auth code (4 bytes).
	pgServerAuthOK := []byte{
		'R',                    // Type: Authentication
		0x00, 0x00, 0x00, 0x08, // Length: 8
		0x00, 0x00, 0x00, 0x00, // Auth OK (0)
	}

	// A complete PostgreSQL server ErrorResponse.
	// Format: type (1 byte), length (4 bytes, counts itself), then fields, each
	// a 1-byte field code followed by a NUL-terminated string, and a final NUL.
	// Fields: 7 (S) + 7 (V) + 7 (C) + 23 (M) + 1 (terminator) = 45 bytes, so
	// length = 4 + 45 = 49 (0x31).
	pgServerError := []byte{
		'E',                    // Type: ErrorResponse
		0x00, 0x00, 0x00, 0x31, // Length: 49
		// Severity: "ERROR"
		'S', 'E', 'R', 'R', 'O', 'R', 0x00,
		// Severity (non-localized): "ERROR"
		'V', 'E', 'R', 'R', 'O', 'R', 0x00,
		// SQLSTATE code: "28P01"
		'C', '2', '8', 'P', '0', '1', 0x00,
		// Message: "authentication failed"
		'M', 'a', 'u', 't', 'h', 'e', 'n', 't', 'i', 'c', 'a', 't', 'i', 'o', 'n',
		' ', 'f', 'a', 'i', 'l', 'e', 'd', 0x00,
		// End of fields
		0x00,
	}

	// Invalid data (HTTP GET request).
	httpBanner := []byte("GET / HTTP/1.1\r\nHost: example.com\r\n\r\n")

	testRunner(t, []*testCase{
		{
			Name:      "PostgreSQL server AuthenticationOk",
			Direction: Server,
			Data:      pgServerAuthOK,
			// NOTE(review): the other Server-direction cases set SrcPort;
			// confirm DstPort is intended here (it may affect confidence).
			DstPort:        5432,
			WantProto:      ProtocolPostgreSQL,
			WantConfidence: .65,
		},
		{
			Name:      "Invalid HTTP request",
			Direction: Server,
			Data:      httpBanner,
			// NOTE(review): 3306 looks copied from the MySQL case; the port
			// is irrelevant to the expected ErrUnknown outcome.
			SrcPort:   3306,
			WantError: ErrUnknown,
		},
	})

	// Both server messages must be recognized as PostgreSQL. t.Run runs each
	// subtest synchronously, so capturing the loop variable is safe.
	for _, tc := range []struct {
		name string
		data []byte
	}{
		{name: "AuthenticationOk", data: pgServerAuthOK},
		{name: "ErrorResponse", data: pgServerError},
	} {
		t.Run(tc.name, func(t *testing.T) {
			p, c, err := Detect(Server, tc.data, 5432, 0)
			if err != nil {
				t.Fatal(err)
			}
			t.Logf("detected %s version %s confidence %g%%", p.Name, p.Version, c*100)
			if p.Name != ProtocolPostgreSQL {
				t.Fatalf("expected postgres protocol, got %s", p.Name)
			}
		})
	}

	t.Run("Invalid HTTP", func(t *testing.T) {
		_, _, err := Detect(Server, httpBanner, 0, 80)
		if !errors.Is(err, ErrUnknown) {
			t.Fatalf("expected unknown format, got error %T: %q", err, err)
		}
		t.Logf("error %q, as expected", err)
	})
}