diff --git a/protocol/detect_http.go b/protocol/detect_http.go index b7a4044..f656ce2 100644 --- a/protocol/detect_http.go +++ b/protocol/detect_http.go @@ -25,21 +25,22 @@ func detectHTTPRequest(dir Direction, data []byte, srcPort, dstPort int) (proto confidence = -.1 } + var ( + b = append(data, '\r', '\n') + r = bufio.NewReader(bytes.NewReader(b)) + ) + if request, err := http.ReadRequest(r); err == nil { + return &Protocol{ + Name: ProtocolHTTP, + Version: Version{ + Major: request.ProtoMajor, + Minor: request.ProtoMinor, + Patch: -1, + }, + }, confidence + .85 + } + if Strict { - var ( - b = append(data, '\r', '\n') - r = bufio.NewReader(bytes.NewReader(b)) - ) - if request, err := http.ReadRequest(r); err == nil { - return &Protocol{ - Name: ProtocolHTTP, - Version: Version{ - Major: request.ProtoMajor, - Minor: request.ProtoMinor, - Patch: -1, - }, - }, confidence + .85 - } return nil, 0 } diff --git a/protocol/detect_http_test.go b/protocol/detect_http_test.go index 0942d28..dcbf890 100644 --- a/protocol/detect_http_test.go +++ b/protocol/detect_http_test.go @@ -1,7 +1,6 @@ package protocol import ( - "errors" "testing" ) @@ -17,6 +16,32 @@ func TestDetectHTTPRequest(t *testing.T) { // An invalid HTTP request sshBanner := []byte("SSH-2.0-OpenSSH_8.9p1 Ubuntu-3ubuntu0.4\r\n") + tests := []*testCase{ + { + Name: "HTTP/1.0 GET", + Direction: Client, + Data: http10Request, + DstPort: 80, + WantProto: ProtocolHTTP, + WantConfidence: .95, + }, + { + Name: "HTTP/1.1 GET", + Direction: Client, + Data: getRequest, + DstPort: 80, + WantProto: ProtocolHTTP, + WantConfidence: .95, + }, + { + Name: "Invalid SSH", + Direction: Client, + Data: sshBanner, + DstPort: 80, + WantError: ErrUnknown, + }, + } + defer func() { Strict = false }() for _, strict := range []bool{false, true} { Strict = strict @@ -27,40 +52,7 @@ func TestDetectHTTPRequest(t *testing.T) { } t.Run(name, func(t *testing.T) { - t.Run("HTTP/1.0 GET", func(t *testing.T) { - p, c, err := Detect(Client, http10Request, 1234, 80) - if err != nil { - t.Fatal(err) - return - } - t.Logf("detected %s version %s confidence %g%%", p.Name, p.Version, c*100) - if p.Name != ProtocolHTTP { - t.Fatalf("expected http protocol, got %s", p.Name) - return - } - }) - - t.Run("HTTP/1.1 GET", func(t *testing.T) { - p, c, err := Detect(Client, getRequest, 1234, 80) - if err != nil { - t.Fatal(err) - return - } - t.Logf("detected %s version %s confidence %g%%", p.Name, p.Version, c*100) - if p.Name != ProtocolHTTP { - t.Fatalf("expected http protocol, got %s", p.Name) - return - } - }) - - t.Run("Invalid SSH", func(t *testing.T) { - _, _, err := Detect(Server, sshBanner, 1234, 22) - if !errors.Is(err, ErrUnknown) { - t.Fatalf("expected unknown format, got error %T: %q", err, err) - } else { - t.Logf("error %q, as expected", err) - } - }) + testRunner(t, tests) }) } } @@ -83,72 +75,67 @@ func TestDetectHTTPResponse(t *testing.T) { // An invalid banner (SSH) sshBanner := []byte("SSH-2.0-OpenSSH_8.9p1 Ubuntu-3ubuntu0.4\r\n") + tests := []*testCase{ + { + Name: "HTTP/1.0 403", + Direction: Server, + Data: http10Response, + SrcPort: 80, + WantProto: ProtocolHTTP, + }, + { + Name: "HTTP/1.1 200", + Direction: Server, + Data: responseOK, + SrcPort: 80, + WantProto: ProtocolHTTP, + }, + { + Name: "HTTP/1.1 404", + Direction: Server, + Data: responseNotFound, + SrcPort: 80, + WantProto: ProtocolHTTP, + }, + { + Name: "Invalid HTTP/1.1 GET", + Direction: Server, + Data: getRequest, + SrcPort: 80, + WantError: ErrUnknown, + }, + { + Name: "Invalid SSH", + Direction: Server, + Data: sshBanner, + SrcPort: 80, + WantError: ErrUnknown, + }, + } + defer func() { Strict = false }() for _, strict := range []bool{false, true} { Strict = strict - name := "loose" + var name string if strict { name = "strict" + for _, test := range tests { + if test.WantError == nil { + test.WantConfidence = .95 + } + } + } else { + name = "loose" + for _, test := range tests { + if test.WantError == nil { + test.WantConfidence = .85 + } + } } t.Run(name, func(t *testing.T) { - t.Run("HTTP/1.0 403", func(t *testing.T) { - p, c, err := Detect(Server, http10Response, 80, 1234) - if err != nil { - t.Fatal(err) - return - } - t.Logf("detected %s version %s confidence %g%%", p.Name, p.Version, c*100) - if p.Name != ProtocolHTTP { - t.Fatalf("expected http protocol, got %s", p.Name) - return - } - }) - - t.Run("HTTP/1.1 200", func(t *testing.T) { - p, c, err := Detect(Server, responseOK, 80, 1234) - if err != nil { - t.Fatal(err) - return - } - t.Logf("detected %s version %s confidence %g%%", p.Name, p.Version, c*100) - if p.Name != ProtocolHTTP { - t.Fatalf("expected http protocol, got %s", p.Name) - return - } - }) - - t.Run("HTTP/1.1 404", func(t *testing.T) { - p, c, err := Detect(Server, responseNotFound, 80, 1234) - if err != nil { - t.Fatal(err) - return - } - t.Logf("detected %s version %s confidence %g%%", p.Name, p.Version, c*100) - if p.Name != ProtocolHTTP { - t.Fatalf("expected http protocol, got %s", p.Name) - return - } - }) - - t.Run("Invalid HTTP/1.1 GET", func(t *testing.T) { - _, _, err := Detect(Server, getRequest, 1234, 80) - if !errors.Is(err, ErrUnknown) { - t.Fatalf("expected unknown format, got error %T: %q", err, err) - } else { - t.Logf("error %q, as expected", err) - } - }) - - t.Run("Invalid SSH", func(t *testing.T) { - _, _, err := Detect(Server, sshBanner, 22, 1234) - if !errors.Is(err, ErrUnknown) { - t.Fatalf("expected unknown format, got error %T: %q", err, err) - } else { - t.Logf("error %q, as expected", err) - } - }) + testRunner(t, tests) }) } } diff --git a/protocol/detect_mysql_test.go b/protocol/detect_mysql_test.go index 866a50a..277b407 100644 --- a/protocol/detect_mysql_test.go +++ b/protocol/detect_mysql_test.go @@ -1,7 +1,6 @@ package protocol import ( - "errors" "testing" ) @@ -14,7 +13,8 @@ func TestDetectMySQL(t *testing.T) { 0x04, 0x5a, 0x56, 0x5f, 0x3e, 0x6e, 0x76, 0x27, 0x00, 0xff, 0xff, 0xff, 0x02, 0x00, 0xff, 0xc7, 0x15, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x63, 0x61, 0x63, 0x68, 0x69, 0x6e, 0x67, 0x5f, 0x73, - 0x68, 0x61, 0x32, 0x5f, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x00, + 0x68, 0x61, 0x32, 0x5f, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, + 0x00, } // 2. A valid MariaDB banner (protocol-compatible) @@ -37,46 +37,43 @@ func TestDetectMySQL(t *testing.T) { // 5. A slice that starts correctly but is malformed (no null terminator) malformedSlice := []byte{0x0a, 0x38, 0x2e, 0x30, 0x2e, 0x30, 0x01, 0x02, 0x03, 0x04, 0x05} - t.Run("MySQL 8", func(t *testing.T) { - p, c, _ := Detect(Server, mysql8Banner, 3306, 0) - if p == nil { - t.Fatal("expected MySQL protocol, got nil") - } - t.Logf("detected %s version %s confidence %g%%", p.Name, p.Version, c*100) - }) - - t.Run("MariaDB", func(t *testing.T) { - p, c, _ := Detect(Server, mariaDBBanner, 3306, 0) - if p == nil { - t.Fatal("expected MySQL protocol, got nil") - } - t.Logf("detected %s version %s confidence %g%%", p.Name, p.Version, c*100) - }) - - t.Run("Invalid HTTP", func(t *testing.T) { - _, _, err := Detect(Server, httpBanner, 1234, 80) - if !errors.Is(err, ErrUnknown) { - t.Fatalf("expected unknown format, got error %T: %q", err, err) - } else { - t.Logf("error %q, as expected", err) - } - }) - - t.Run("Too short", func(t *testing.T) { - _, _, err := Detect(Server, shortSlice, 3306, 1234) - if !errors.Is(err, ErrUnknown) { - t.Fatalf("expected unknown format, got error %T: %q", err, err) - } else { - t.Logf("error %q, as expected", err) - } - }) - - t.Run("Malformed", func(t *testing.T) { - _, _, err := Detect(Server, malformedSlice, 3306, 1234) - if !errors.Is(err, ErrUnknown) { - t.Fatalf("expected unknown format, got error %T: %q", err, err) - } else { - t.Logf("error %q, as expected", err) - } + testRunner(t, []*testCase{ + { + Name: "MySQL server", + Direction: Server, + Data: mysql8Banner, + SrcPort: 3306, + WantProto: ProtocolMySQL, + WantConfidence: .85, + }, + { + Name: "MariaDB server", + Direction: Server, + Data: mariaDBBanner, + SrcPort: 3306, + WantProto: ProtocolMySQL, + WantConfidence: .85, + }, + { + Name: "Invalid HTTP", + Direction: Server, + Data: httpBanner, + SrcPort: 80, + WantError: ErrUnknown, + }, + { + Name: "Invalid too short", + Direction: Server, + Data: shortSlice, + SrcPort: 3306, + WantError: ErrUnknown, + }, + { + Name: "Invalid malformed", + Direction: Server, + Data: malformedSlice, + SrcPort: 3306, + WantError: ErrUnknown, + }, }) } diff --git a/protocol/detect_postgres_test.go b/protocol/detect_postgres_test.go index 89d8035..f1f5cba 100644 --- a/protocol/detect_postgres_test.go +++ b/protocol/detect_postgres_test.go @@ -21,17 +21,31 @@ func TestDetectPostgreSQLClient(t *testing.T) { 'd', 'a', 't', 'a', 'b', 'a', 's', 'e', 0x00, 't', 'e', 's', 't', 0x00, 0x00, } - t.Run("Protocol 3.0", func(t *testing.T) { - p, c, err := Detect(Client, pgClientStartup, 0, 5432) - if err != nil { - t.Fatal(err) - return - } - t.Logf("detected %s version %s confidence %g%%", p.Name, p.Version, 100*c) - if p.Name != ProtocolPostgreSQL { - t.Fatalf("expected postgres protocol, got %s", p.Name) - return - } + mysqlBanner := []byte{ + 0x0a, 0x38, 0x2e, 0x30, 0x2e, 0x33, 0x32, 0x00, 0x0d, 0x00, 0x00, 0x00, + 0x04, 0x5a, 0x56, 0x5f, 0x3e, 0x6e, 0x76, 0x27, 0x00, 0xff, 0xff, 0xff, + 0x02, 0x00, 0xff, 0xc7, 0x15, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x63, 0x61, 0x63, 0x68, 0x69, 0x6e, 0x67, 0x5f, 0x73, + 0x68, 0x61, 0x32, 0x5f, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, + 0x00, + } + + testRunner(t, []*testCase{ + { + Name: "PostgeSQL protocol 3.0", + Direction: Client, + Data: pgClientStartup, + DstPort: 5432, + WantProto: ProtocolPostgreSQL, + WantConfidence: .85, + }, + { + Name: "Invalid MySQL server", + Direction: Server, + Data: mysqlBanner, + SrcPort: 3306, + WantError: ErrUnknown, + }, }) } @@ -57,6 +71,24 @@ func TestDetectPostgreSQLServer(t *testing.T) { // Invalid data (HTTP GET request) httpBanner := []byte("GET / HTTP/1.1\r\nHost: example.com\r\n\r\n") + testRunner(t, []*testCase{ + { + Name: "PostgeSQL Sever AuthentcationOk", + Direction: Server, + Data: pgServerAuthOK, + DstPort: 5432, + WantProto: ProtocolPostgreSQL, + WantConfidence: .65, + }, + { + Name: "Invalid HTTP request", + Direction: Server, + Data: httpBanner, + SrcPort: 3306, + WantError: ErrUnknown, + }, + }) + t.Run("AuthenticationOk", func(t *testing.T) { p, c, err := Detect(Server, pgServerAuthOK, 5432, 0) if err != nil { diff --git a/protocol/detect_ssh.go b/protocol/detect_ssh.go index bb4e9a1..2265942 100644 --- a/protocol/detect_ssh.go +++ b/protocol/detect_ssh.go @@ -2,17 +2,37 @@ package protocol import ( "bytes" + "strings" ) +func init() { + // We can't match on SSH-?.? here, because the client or server may send a banner prior + // to sending the SSH handshake. + Register(Both, "", detectSSH) +} + // The required prefix for the SSH protocol identification line. const ( ssh199Prefix = "SSH-1.99-" ssh20Prefix = "SSH-2.0-" ) -func init() { - Register(Both, "", detectSSH) -} +var ( + commonPort = map[int]bool{ + 22: true, + 2200: true, + 2222: true, + } + commonImplementations = []string{ + "OpenSSH_", + "PuTTY", + "libssh", + "dropbear", + "Go", + "paramiko", + "Cyberduck", + } +) func detectSSH(dir Direction, data []byte, srcPort, dstPort int) (proto *Protocol, confidence float64) { // The data must be at least as long as the prefix itself. @@ -20,34 +40,52 @@ func detectSSH(dir Direction, data []byte, srcPort, dstPort int) (proto *Protoco return nil, 0 } - if dstPort == 22 || dstPort == 2200 || dstPort == 2222 { - confidence = .1 + if commonPort[srcPort] || commonPort[dstPort] { + confidence += .1 } // The protocol allows for pre-banner text, so we have to check all lines. for _, line := range bytes.Split(data, []byte{'\n'}) { line = bytes.TrimSuffix(line, []byte{'\r'}) + + // Match the most common SSH 2.0 protocol. if bytes.HasPrefix(line, []byte(ssh20Prefix)) { + implementation := string(line[len(ssh20Prefix):]) + for _, prefix := range commonImplementations { + if strings.HasPrefix(implementation, prefix) { + confidence += .2 + break + } + } return &Protocol{ Name: ProtocolSSH, Version: Version{ Major: 2, Minor: 0, Patch: -1, - Extra: string(line[len(ssh20Prefix):]), + Extra: implementation, }, - }, confidence + 0.75 + }, confidence + 0.65 } + + // Match the (far) less common SSH 1.99 protocol. if bytes.HasPrefix(line, []byte(ssh199Prefix)) { + implementation := string(line[len(ssh20Prefix):]) + for _, prefix := range commonImplementations { + if strings.HasPrefix(implementation, prefix) { + confidence += .2 + break + } + } return &Protocol{ Name: ProtocolSSH, Version: Version{ Major: 1, Minor: 99, Patch: -1, - Extra: string(line[len(ssh20Prefix):]), + Extra: implementation, }, - }, confidence + 0.75 + }, confidence + 0.65 } } diff --git a/protocol/detect_ssh_test.go b/protocol/detect_ssh_test.go index 1009f95..9d10abf 100644 --- a/protocol/detect_ssh_test.go +++ b/protocol/detect_ssh_test.go @@ -1,7 +1,6 @@ package protocol import ( - "errors" "testing" ) @@ -31,73 +30,124 @@ func TestDetectSSH(t *testing.T) { // 5. A simple HTTP request httpBanner := []byte("GET / HTTP/1.1\r\nHost: example.com\r\n\r\n") - t.Run("OpenSSH client", func(t *testing.T) { - p, _, err := Detect(Server, openSSHBanner, 0, 22) - if err != nil { - t.Fatal(err) - return - } - t.Logf("detected %s version %s", p.Name, p.Version) - if p.Name != ProtocolSSH { - t.Fatalf("expected ssh protocol, got %s", p.Name) - return - } + testRunner(t, []*testCase{ + { + Name: "OpenSSH client", + Direction: Client, + Data: openSSHBanner, + DstPort: 22, + WantProto: ProtocolSSH, + WantConfidence: .95, + }, + { + Name: "OpenSSH server", + Direction: Server, + Data: openSSHBanner, + SrcPort: 22, + WantProto: ProtocolSSH, + WantConfidence: .95, + }, + { + Name: "OpenSSH server with banner", + Direction: Server, + Data: preBannerSSH, + SrcPort: 22, + WantProto: ProtocolSSH, + WantConfidence: .95, + }, + { + Name: "Dropbear server", + Direction: Server, + Data: dropbearBanner, + SrcPort: 22, + WantProto: ProtocolSSH, + WantConfidence: .95, + }, + { + Name: "Invalid MySQL", + Direction: Server, + Data: mysqlBanner, + SrcPort: 3306, + WantError: ErrUnknown, + }, + { + Name: "Invalid HTTP", + Direction: Client, + Data: httpBanner, + DstPort: 80, + WantError: ErrUnknown, + }, }) - t.Run("OpenSSH server", func(t *testing.T) { - p, _, err := Detect(Server, openSSHBanner, 0, 22) - if err != nil { - t.Fatal(err) - return - } - t.Logf("detected %s version %s", p.Name, p.Version) - if p.Name != ProtocolSSH { - t.Fatalf("expected ssh protocol, got %s", p.Name) - return - } - }) + /* + t.Run("OpenSSH client", func(t *testing.T) { + p, c, err := Detect(Server, openSSHBanner, 0, 22) + if err != nil { + t.Fatal(err) + return + } + t.Logf("detected %s version %s confidence %g%%", p.Name, p.Version, c*100) + if p.Name != ProtocolSSH { + t.Fatalf("expected ssh protocol, got %s", p.Name) + return + } + }) - t.Run("OpenSSH server with banner", func(t *testing.T) { - p, _, err := Detect(Server, preBannerSSH, 0, 22) - if err != nil { - t.Fatal(err) - return - } - t.Logf("detected %s version %s", p.Name, p.Version) - if p.Name != ProtocolSSH { - t.Fatalf("expected ssh protocol, got %s", p.Name) - return - } - }) + t.Run("OpenSSH server", func(t *testing.T) { + p, c, err := Detect(Server, openSSHBanner, 0, 22) + if err != nil { + t.Fatal(err) + return + } + t.Logf("detected %s version %s confidence %g%%", p.Name, p.Version, c*100) + if p.Name != ProtocolSSH { + t.Fatalf("expected ssh protocol, got %s", p.Name) + return + } + }) - t.Run("Dropbear server", func(t *testing.T) { - p, _, err := Detect(Server, dropbearBanner, 0, 22) - if err != nil { - t.Fatal(err) - return - } - t.Logf("detected %s version %s", p.Name, p.Version) - if p.Name != ProtocolSSH { - t.Fatalf("expected ssh protocol, got %s", p.Name) - return - } - }) + t.Run("OpenSSH server with banner", func(t *testing.T) { + p, c, err := Detect(Server, preBannerSSH, 0, 22) + if err != nil { + t.Fatal(err) + return + } + t.Logf("detected %s version %s confidence %g%%", p.Name, p.Version, c*100) + if p.Name != ProtocolSSH { + t.Fatalf("expected ssh protocol, got %s", p.Name) + return + } + }) - t.Run("Invalid MySQL banner", func(t *testing.T) { - _, _, err := Detect(Server, mysqlBanner, 0, 3306) - if !errors.Is(err, ErrUnknown) { - t.Fatalf("expected unknown format, got error %T: %q", err, err) - } else { - t.Logf("error %q, as expected", err) - } - }) + t.Run("Dropbear server", func(t *testing.T) { + p, c, err := Detect(Server, dropbearBanner, 0, 22) + if err != nil { + t.Fatal(err) + return + } + t.Logf("detected %s version %s confidence %g%%", p.Name, p.Version, c*100) + if p.Name != ProtocolSSH { + t.Fatalf("expected ssh protocol, got %s", p.Name) + return + } + }) - t.Run("Invalid HTTP banner", func(t *testing.T) { - _, _, err := Detect(Server, httpBanner, 0, 80) - if !errors.Is(err, ErrUnknown) { - t.Fatalf("expected unknown format, got error %T: %q", err, err) - } else { - t.Logf("error %q, as expected", err) - } - }) + t.Run("Invalid MySQL banner", func(t *testing.T) { + _, _, err := Detect(Server, mysqlBanner, 0, 3306) + if !errors.Is(err, ErrUnknown) { + t.Fatalf("expected unknown format, got error %T: %q", err, err) + } else { + t.Logf("error %q, as expected", err) + } + }) + + t.Run("Invalid HTTP banner", func(t *testing.T) { + _, _, err := Detect(Server, httpBanner, 0, 80) + if !errors.Is(err, ErrUnknown) { + t.Fatalf("expected unknown format, got error %T: %q", err, err) + } else { + t.Logf("error %q, as expected", err) + } + }) + */ } diff --git a/protocol/detect_tls_test.go b/protocol/detect_tls_test.go index 4f0c8cf..b5dec4c 100644 --- a/protocol/detect_tls_test.go +++ b/protocol/detect_tls_test.go @@ -2,7 +2,6 @@ package protocol import ( "encoding/hex" - "errors" "strings" "testing" ) @@ -167,52 +166,91 @@ func TestDetectTLS(t *testing.T) { 0x00, 0x00, 0x00, 0x25, 0x00, 0x03, 0x00, 0x00, } + tests := []*testCase{ + { + Name: "SSLv3", + Direction: Client, + Data: sslV3ClientHello, + DstPort: 443, + WantProto: ProtocolSSL, + WantConfidence: .95, + }, + { + Name: "TLS 1.1", + Direction: Client, + Data: tls11ClientHello, + DstPort: 443, + WantProto: ProtocolTLS, + WantConfidence: .95, + }, + { + Name: "TLS 1.2", + Direction: Client, + Data: tls12ClientHello, + DstPort: 443, + WantProto: ProtocolTLS, + WantConfidence: .95, + }, + { + Name: "TLS 1.3", + Direction: Client, + Data: tls13ClientHello, + DstPort: 443, + WantProto: ProtocolTLS, + WantConfidence: .95, + }, + { + Name: "Invalid PostgreSQL", + Direction: Client, + Data: pgClientStartup, + DstPort: 5432, + WantError: ErrUnknown, + }, + } + defer func() { Strict = false }() for _, strict := range []bool{false, true} { Strict = strict - name := "loose" if strict { - name = "strict" + t.Run("strict", func(t *testing.T) { + testRunner(t, tests) + }) + } else { + // Strict runner doesn't allow for partial packet matching: + t.Run("loose", func(t *testing.T) { + testRunner(t, append([]*testCase{ + { + Name: "TLS 1.1 partial", + Direction: Client, + Data: tls11ClientHelloPartial, + DstPort: 443, + WantProto: ProtocolTLS, + WantConfidence: .50, + }, + }, tests...)) + }) } - t.Run(name, func(t *testing.T) { + /* - t.Run("SSLv3 Client Hello", func(t *testing.T) { - p, _, err := Detect(Client, sslV3ClientHello, 0, 0) - if err != nil { - t.Fatal(err) - return - } - t.Logf("detected %s version %s", p.Name, p.Version) - if p.Name != ProtocolSSL { - t.Fatalf("expected ssl protocol, got %s", p.Name) - return - } - }) + t.Run(name, func(t *testing.T) { - t.Run("TLS 1.1 Client Hello", func(t *testing.T) { - p, _, err := Detect(Client, tls11ClientHello, 0, 0) - if err != nil { - t.Fatal(err) - return - } - t.Logf("detected %s version %s", p.Name, p.Version) - if p.Name != ProtocolTLS { - t.Fatalf("expected tls protocol, got %s", p.Name) - return - } - }) - - t.Run("TLS 1.1 partial Client Hello", func(t *testing.T) { - p, _, err := Detect(Client, tls11ClientHelloPartial, 0, 0) - if strict { - if !errors.Is(err, ErrUnknown) { - t.Fatalf("expected unknown format, got error %T: %q", err, err) - } else { - t.Logf("error %q, as expected", err) + t.Run("SSLv3 Client Hello", func(t *testing.T) { + p, _, err := Detect(Client, sslV3ClientHello, 0, 0) + if err != nil { + t.Fatal(err) + return } - } else { + t.Logf("detected %s version %s", p.Name, p.Version) + if p.Name != ProtocolSSL { + t.Fatalf("expected ssl protocol, got %s", p.Name) + return + } + }) + + t.Run("TLS 1.1 Client Hello", func(t *testing.T) { + p, _, err := Detect(Client, tls11ClientHello, 0, 0) if err != nil { t.Fatal(err) return @@ -222,44 +260,65 @@ func TestDetectTLS(t *testing.T) { t.Fatalf("expected tls protocol, got %s", p.Name) return } - } - }) + }) - t.Run("TLS 1.2 Client Hello", func(t *testing.T) { - p, _, err := Detect(Client, tls12ClientHello, 0, 0) - if err != nil { - t.Fatal(err) - return - } - t.Logf("detected %s version %s", p.Name, p.Version) - if p.Name != ProtocolTLS { - t.Fatalf("expected tls protocol, got %s", p.Name) - return - } - }) + t.Run("TLS 1.1 partial Client Hello", func(t *testing.T) { + p, _, err := Detect(Client, tls11ClientHelloPartial, 0, 0) + if strict { + if !errors.Is(err, ErrUnknown) { + t.Fatalf("expected unknown format, got error %T: %q", err, err) + } else { + t.Logf("error %q, as expected", err) + } + } else { + if err != nil { + t.Fatal(err) + return + } + t.Logf("detected %s version %s", p.Name, p.Version) + if p.Name != ProtocolTLS { + t.Fatalf("expected tls protocol, got %s", p.Name) + return + } + } + }) - t.Run("TLS 1.3 Client Hello", func(t *testing.T) { - p, _, err := Detect(Client, tls13ClientHello, 0, 0) - if err != nil { - t.Fatal(err) - return - } - t.Logf("detected %s version %s", p.Name, p.Version) - if p.Name != ProtocolTLS { - t.Fatalf("expected tls protocol, got %s", p.Name) - return - } - }) + t.Run("TLS 1.2 Client Hello", func(t *testing.T) { + p, _, err := Detect(Client, tls12ClientHello, 0, 0) + if err != nil { + t.Fatal(err) + return + } + t.Logf("detected %s version %s", p.Name, p.Version) + if p.Name != ProtocolTLS { + t.Fatalf("expected tls protocol, got %s", p.Name) + return + } + }) - t.Run("Invalid PostgreSQL", func(t *testing.T) { - _, _, err := Detect(Server, pgClientStartup, 0, 0) - if !errors.Is(err, ErrUnknown) { - t.Fatalf("expected unknown format, got error %T: %q", err, err) - } else { - t.Logf("error %q, as expected", err) - } + t.Run("TLS 1.3 Client Hello", func(t *testing.T) { + p, _, err := Detect(Client, tls13ClientHello, 0, 0) + if err != nil { + t.Fatal(err) + return + } + t.Logf("detected %s version %s", p.Name, p.Version) + if p.Name != ProtocolTLS { + t.Fatalf("expected tls protocol, got %s", p.Name) + return + } + }) + + t.Run("Invalid PostgreSQL", func(t *testing.T) { + _, _, err := Detect(Server, pgClientStartup, 0, 0) + if !errors.Is(err, ErrUnknown) { + t.Fatalf("expected unknown format, got error %T: %q", err, err) + } else { + t.Logf("error %q, as expected", err) + } + }) }) - }) + */ } } diff --git a/protocol/detest_test.go b/protocol/detest_test.go index 8eb7ee6..a61dcb1 100644 --- a/protocol/detest_test.go +++ b/protocol/detest_test.go @@ -1,10 +1,109 @@ package protocol import ( + "errors" + "fmt" "math" + "math/rand" + "strconv" + "strings" "testing" ) +type testCase struct { + Name string + Direction Direction + Data []byte + SrcPort int + DstPort int + WantProto string + WantConfidence float64 + WantError error +} + +func testRunner(t *testing.T, tests []*testCase) { + t.Helper() + + for _, test := range tests { + t.Run(test.Name, func(t *testing.T) { + if test.SrcPort == 0 { + test.SrcPort = 1024 + rand.Intn(65535-1024) + } + if test.DstPort == 0 { + test.DstPort = 1024 + rand.Intn(65535-1024) + } + + proto, confidence, err := Detect(test.Direction, test.Data, test.SrcPort, test.DstPort) + + // Process error first + if err != nil { + if test.WantError == nil { + t.Fatalf("unexpected error: %v", err) + } else if !errors.Is(err, test.WantError) { + t.Fatalf("Detect(%s, %s, %d, %d) returned error %q, expected %q", + test.Direction, testBytesSample(test.Data, 8), test.SrcPort, test.DstPort, + err, test.WantError) + } else { + t.Logf("Detect(%s, %s, %d, %d) returned error %q as expected", + test.Direction, testBytesSample(test.Data, 8), test.SrcPort, test.DstPort, + err) + } + return + } else if test.WantError != nil { + t.Fatalf("Detect(%s, %s, %d, %d) returned protocol %q version %s, expected error %q", + test.Direction, testBytesSample(test.Data, 8), test.SrcPort, test.DstPort, + proto.Name, proto.Version, test.WantError) + return + } + + // Process protocol + if proto == nil { + t.Fatalf("Detect(%s, %s, %d, %d) returned nil, expected protocol %q", + test.Direction, testBytesSample(test.Data, 8), test.SrcPort, test.DstPort, + test.WantProto) + return + } + + t.Logf("Detect(%s, %s, %d, %d) returned protocol %q version %s with confidence %g%%", + test.Direction, testBytesSample(test.Data, 4), test.SrcPort, test.DstPort, + proto.Name, proto.Version, confidence*100) + + if proto.Name != test.WantProto { + t.Errorf("Expected protocol %q", test.WantProto) + } + if !testAlmostEqual(confidence, test.WantConfidence) { + t.Errorf("Expected confidence %g%%", test.WantConfidence*100) + } + }) + } +} + +func testBytesSample(b []byte, n int) string { + if b == nil { + return "" + } + var ( + hex []string + etc string + ) + for i, l := 0, len(b); i < l && i < n; i++ { + if strconv.IsPrint(rune(b[i])) { + hex = append(hex, fmt.Sprintf("%c", b[i])) + } else { + hex = append(hex, fmt.Sprintf("\\x%02X", b[i])) + } + if i == (n-1) && l > (n-1) { + etc = fmt.Sprintf(" … (%d more)", l-n) + } + } + return fmt.Sprintf(`"%s"%s`, strings.Join(hex, ""), etc) +} + +func testAlmostEqual(a, b float64) bool { + const e = 1e-9 + return math.Abs(a-b) < e +} + func TestCompareFloats(t *testing.T) { tests := []struct { name string