Initial import
This commit is contained in:
94
protocol/detect_postgres_test.go
Normal file
94
protocol/detect_postgres_test.go
Normal file
@@ -0,0 +1,94 @@
|
||||
package protocol
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestDetectPostgreSQLClient(t *testing.T) {
|
||||
atomicFormats.Store([]format{})
|
||||
registerPostgreSQL()
|
||||
|
||||
// 1. A valid PostgreSQL client startup message
|
||||
// Format: len (4b), proto (4b), params (n-bytes)
|
||||
// Here, user=mazeio, database=test
|
||||
// The message is: "user\0mazeio\0database\0test\0\0"
|
||||
// Total length: 4 (len) + 4 (proto) + 27 (params) = 35 (0x23)
|
||||
pgClientStartup := []byte{
|
||||
0x00, 0x00, 0x00, 0x23, // Length: 35
|
||||
0x00, 0x03, 0x00, 0x00, // Protocol Version 3.0
|
||||
'u', 's', 'e', 'r', 0x00, 'm', 'a', 'z', 'e', 'i', 'o', 0x00,
|
||||
'd', 'a', 't', 'a', 'b', 'a', 's', 'e', 0x00, 't', 'e', 's', 't', 0x00, 0x00,
|
||||
}
|
||||
|
||||
t.Run("Protocol 3.0", func(t *testing.T) {
|
||||
p, err := Detect(Client, pgClientStartup)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
t.Logf("detected %s version %s", p.Name, p.Version)
|
||||
if p.Name != ProtocolPostgreSQL {
|
||||
t.Fatalf("expected postgres protocol, got %s", p.Name)
|
||||
return
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestDetectPostgreSQLServer(t *testing.T) {
|
||||
atomicFormats.Store([]format{})
|
||||
registerPostgreSQL()
|
||||
|
||||
// A valid PostgreSQL server AuthenticationOk response
|
||||
// Format: type (1b), len (4b), content (4b)
|
||||
pgServerAuthOK := []byte{
|
||||
'R', // Type: Authentication
|
||||
0x00, 0x00, 0x00, 0x08, // Length: 8
|
||||
0x00, 0x00, 0x00, 0x00, // Auth OK (0)
|
||||
}
|
||||
|
||||
// A valid PostgreSQL server ErrorResponse
|
||||
pgServerError := []byte{
|
||||
'E', // Type: ErrorResponse
|
||||
0x00, 0x00, 0x00, 0x31, // Length
|
||||
'S', 'E', 'R', 'R', 'O', 'R', 0x00, // ... and so on
|
||||
}
|
||||
|
||||
// Invalid data (HTTP GET request)
|
||||
httpBanner := []byte("GET / HTTP/1.1\r\nHost: example.com\r\n\r\n")
|
||||
|
||||
t.Run("AuthenticationOk", func(t *testing.T) {
|
||||
p, err := Detect(Server, pgServerAuthOK)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
t.Logf("detected %s version %s", p.Name, p.Version)
|
||||
if p.Name != ProtocolPostgreSQL {
|
||||
t.Fatalf("expected postgres protocol, got %s", p.Name)
|
||||
return
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ErrorResponse", func(t *testing.T) {
|
||||
p, err := Detect(Server, pgServerError)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
t.Logf("detected %s version %s", p.Name, p.Version)
|
||||
if p.Name != ProtocolPostgreSQL {
|
||||
t.Fatalf("expected postgres protocol, got %s", p.Name)
|
||||
return
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Invalid HTTP", func(t *testing.T) {
|
||||
_, err := Detect(Server, httpBanner)
|
||||
if !errors.Is(err, ErrUnknown) {
|
||||
t.Fatalf("expected unknown format, got error %T: %q", err, err)
|
||||
} else {
|
||||
t.Logf("error %q, as expected", err)
|
||||
}
|
||||
})
|
||||
}
|
Reference in New Issue
Block a user