Switch to new test harness

This commit is contained in:
2025-10-09 15:37:17 +02:00
parent 170a038612
commit fd55412020
8 changed files with 567 additions and 304 deletions

View File

@@ -25,7 +25,6 @@ func detectHTTPRequest(dir Direction, data []byte, srcPort, dstPort int) (proto
confidence = -.1 confidence = -.1
} }
if Strict {
var ( var (
b = append(data, '\r', '\n') b = append(data, '\r', '\n')
r = bufio.NewReader(bytes.NewReader(b)) r = bufio.NewReader(bytes.NewReader(b))
@@ -40,6 +39,8 @@ func detectHTTPRequest(dir Direction, data []byte, srcPort, dstPort int) (proto
}, },
}, confidence + .85 }, confidence + .85
} }
if Strict {
return nil, 0 return nil, 0
} }

View File

@@ -1,7 +1,6 @@
package protocol package protocol
import ( import (
"errors"
"testing" "testing"
) )
@@ -17,6 +16,32 @@ func TestDetectHTTPRequest(t *testing.T) {
// An invalid HTTP request // An invalid HTTP request
sshBanner := []byte("SSH-2.0-OpenSSH_8.9p1 Ubuntu-3ubuntu0.4\r\n") sshBanner := []byte("SSH-2.0-OpenSSH_8.9p1 Ubuntu-3ubuntu0.4\r\n")
tests := []*testCase{
{
Name: "HTTP/1.0 GET",
Direction: Client,
Data: http10Request,
DstPort: 80,
WantProto: ProtocolHTTP,
WantConfidence: .95,
},
{
Name: "HTTP/1.1 GET",
Direction: Client,
Data: getRequest,
DstPort: 80,
WantProto: ProtocolHTTP,
WantConfidence: .95,
},
{
Name: "Invalid SSH",
Direction: Client,
Data: sshBanner,
DstPort: 80,
WantError: ErrUnknown,
},
}
defer func() { Strict = false }() defer func() { Strict = false }()
for _, strict := range []bool{false, true} { for _, strict := range []bool{false, true} {
Strict = strict Strict = strict
@@ -27,40 +52,7 @@ func TestDetectHTTPRequest(t *testing.T) {
} }
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
t.Run("HTTP/1.0 GET", func(t *testing.T) { testRunner(t, tests)
p, c, err := Detect(Client, http10Request, 1234, 80)
if err != nil {
t.Fatal(err)
return
}
t.Logf("detected %s version %s confidence %g%%", p.Name, p.Version, c*100)
if p.Name != ProtocolHTTP {
t.Fatalf("expected http protocol, got %s", p.Name)
return
}
})
t.Run("HTTP/1.1 GET", func(t *testing.T) {
p, c, err := Detect(Client, getRequest, 1234, 80)
if err != nil {
t.Fatal(err)
return
}
t.Logf("detected %s version %s confidence %g%%", p.Name, p.Version, c*100)
if p.Name != ProtocolHTTP {
t.Fatalf("expected http protocol, got %s", p.Name)
return
}
})
t.Run("Invalid SSH", func(t *testing.T) {
_, _, err := Detect(Server, sshBanner, 1234, 22)
if !errors.Is(err, ErrUnknown) {
t.Fatalf("expected unknown format, got error %T: %q", err, err)
} else {
t.Logf("error %q, as expected", err)
}
})
}) })
} }
} }
@@ -83,72 +75,67 @@ func TestDetectHTTPResponse(t *testing.T) {
// An invalid banner (SSH) // An invalid banner (SSH)
sshBanner := []byte("SSH-2.0-OpenSSH_8.9p1 Ubuntu-3ubuntu0.4\r\n") sshBanner := []byte("SSH-2.0-OpenSSH_8.9p1 Ubuntu-3ubuntu0.4\r\n")
tests := []*testCase{
{
Name: "HTTP/1.0 403",
Direction: Server,
Data: http10Response,
SrcPort: 80,
WantProto: ProtocolHTTP,
},
{
Name: "HTTP/1.1 200",
Direction: Server,
Data: responseOK,
SrcPort: 80,
WantProto: ProtocolHTTP,
},
{
Name: "HTTP/1.1 404",
Direction: Server,
Data: responseNotFound,
SrcPort: 80,
WantProto: ProtocolHTTP,
},
{
Name: "Invalid HTTP/1.1 GET",
Direction: Server,
Data: getRequest,
SrcPort: 80,
WantError: ErrUnknown,
},
{
Name: "Invalid SSH",
Direction: Server,
Data: sshBanner,
SrcPort: 80,
WantError: ErrUnknown,
},
}
defer func() { Strict = false }() defer func() { Strict = false }()
for _, strict := range []bool{false, true} { for _, strict := range []bool{false, true} {
Strict = strict Strict = strict
name := "loose" var name string
if strict { if strict {
name = "strict" name = "strict"
for _, test := range tests {
if test.WantError == nil {
test.WantConfidence = .95
}
}
} else {
name = "loose"
for _, test := range tests {
if test.WantError == nil {
test.WantConfidence = .85
}
}
} }
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
t.Run("HTTP/1.0 403", func(t *testing.T) { testRunner(t, tests)
p, c, err := Detect(Server, http10Response, 80, 1234)
if err != nil {
t.Fatal(err)
return
}
t.Logf("detected %s version %s confidence %g%%", p.Name, p.Version, c*100)
if p.Name != ProtocolHTTP {
t.Fatalf("expected http protocol, got %s", p.Name)
return
}
})
t.Run("HTTP/1.1 200", func(t *testing.T) {
p, c, err := Detect(Server, responseOK, 80, 1234)
if err != nil {
t.Fatal(err)
return
}
t.Logf("detected %s version %s confidence %g%%", p.Name, p.Version, c*100)
if p.Name != ProtocolHTTP {
t.Fatalf("expected http protocol, got %s", p.Name)
return
}
})
t.Run("HTTP/1.1 404", func(t *testing.T) {
p, c, err := Detect(Server, responseNotFound, 80, 1234)
if err != nil {
t.Fatal(err)
return
}
t.Logf("detected %s version %s confidence %g%%", p.Name, p.Version, c*100)
if p.Name != ProtocolHTTP {
t.Fatalf("expected http protocol, got %s", p.Name)
return
}
})
t.Run("Invalid HTTP/1.1 GET", func(t *testing.T) {
_, _, err := Detect(Server, getRequest, 1234, 80)
if !errors.Is(err, ErrUnknown) {
t.Fatalf("expected unknown format, got error %T: %q", err, err)
} else {
t.Logf("error %q, as expected", err)
}
})
t.Run("Invalid SSH", func(t *testing.T) {
_, _, err := Detect(Server, sshBanner, 22, 1234)
if !errors.Is(err, ErrUnknown) {
t.Fatalf("expected unknown format, got error %T: %q", err, err)
} else {
t.Logf("error %q, as expected", err)
}
})
}) })
} }
} }

View File

@@ -1,7 +1,6 @@
package protocol package protocol
import ( import (
"errors"
"testing" "testing"
) )
@@ -14,7 +13,8 @@ func TestDetectMySQL(t *testing.T) {
0x04, 0x5a, 0x56, 0x5f, 0x3e, 0x6e, 0x76, 0x27, 0x00, 0xff, 0xff, 0xff, 0x04, 0x5a, 0x56, 0x5f, 0x3e, 0x6e, 0x76, 0x27, 0x00, 0xff, 0xff, 0xff,
0x02, 0x00, 0xff, 0xc7, 0x15, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x00, 0xff, 0xc7, 0x15, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x63, 0x61, 0x63, 0x68, 0x69, 0x6e, 0x67, 0x5f, 0x73, 0x00, 0x00, 0x00, 0x63, 0x61, 0x63, 0x68, 0x69, 0x6e, 0x67, 0x5f, 0x73,
0x68, 0x61, 0x32, 0x5f, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x00, 0x68, 0x61, 0x32, 0x5f, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64,
0x00,
} }
// 2. A valid MariaDB banner (protocol-compatible) // 2. A valid MariaDB banner (protocol-compatible)
@@ -37,46 +37,43 @@ func TestDetectMySQL(t *testing.T) {
// 5. A slice that starts correctly but is malformed (no null terminator) // 5. A slice that starts correctly but is malformed (no null terminator)
malformedSlice := []byte{0x0a, 0x38, 0x2e, 0x30, 0x2e, 0x30, 0x01, 0x02, 0x03, 0x04, 0x05} malformedSlice := []byte{0x0a, 0x38, 0x2e, 0x30, 0x2e, 0x30, 0x01, 0x02, 0x03, 0x04, 0x05}
t.Run("MySQL 8", func(t *testing.T) { testRunner(t, []*testCase{
p, c, _ := Detect(Server, mysql8Banner, 3306, 0) {
if p == nil { Name: "MySQL server",
t.Fatal("expected MySQL protocol, got nil") Direction: Server,
} Data: mysql8Banner,
t.Logf("detected %s version %s confidence %g%%", p.Name, p.Version, c*100) SrcPort: 3306,
}) WantProto: ProtocolMySQL,
WantConfidence: .85,
t.Run("MariaDB", func(t *testing.T) { },
p, c, _ := Detect(Server, mariaDBBanner, 3306, 0) {
if p == nil { Name: "MariaDB server",
t.Fatal("expected MySQL protocol, got nil") Direction: Server,
} Data: mariaDBBanner,
t.Logf("detected %s version %s confidence %g%%", p.Name, p.Version, c*100) SrcPort: 3306,
}) WantProto: ProtocolMySQL,
WantConfidence: .85,
t.Run("Invalid HTTP", func(t *testing.T) { },
_, _, err := Detect(Server, httpBanner, 1234, 80) {
if !errors.Is(err, ErrUnknown) { Name: "Invalid HTTP",
t.Fatalf("expected unknown format, got error %T: %q", err, err) Direction: Server,
} else { Data: httpBanner,
t.Logf("error %q, as expected", err) SrcPort: 80,
} WantError: ErrUnknown,
}) },
{
t.Run("Too short", func(t *testing.T) { Name: "Invalid too short",
_, _, err := Detect(Server, shortSlice, 3306, 1234) Direction: Server,
if !errors.Is(err, ErrUnknown) { Data: shortSlice,
t.Fatalf("expected unknown format, got error %T: %q", err, err) SrcPort: 3306,
} else { WantError: ErrUnknown,
t.Logf("error %q, as expected", err) },
} {
}) Name: "Invalid malformed",
Direction: Server,
t.Run("Malformed", func(t *testing.T) { Data: malformedSlice,
_, _, err := Detect(Server, malformedSlice, 3306, 1234) SrcPort: 3306,
if !errors.Is(err, ErrUnknown) { WantError: ErrUnknown,
t.Fatalf("expected unknown format, got error %T: %q", err, err) },
} else {
t.Logf("error %q, as expected", err)
}
}) })
} }

View File

@@ -21,17 +21,31 @@ func TestDetectPostgreSQLClient(t *testing.T) {
'd', 'a', 't', 'a', 'b', 'a', 's', 'e', 0x00, 't', 'e', 's', 't', 0x00, 0x00, 'd', 'a', 't', 'a', 'b', 'a', 's', 'e', 0x00, 't', 'e', 's', 't', 0x00, 0x00,
} }
t.Run("Protocol 3.0", func(t *testing.T) { mysqlBanner := []byte{
p, c, err := Detect(Client, pgClientStartup, 0, 5432) 0x0a, 0x38, 0x2e, 0x30, 0x2e, 0x33, 0x32, 0x00, 0x0d, 0x00, 0x00, 0x00,
if err != nil { 0x04, 0x5a, 0x56, 0x5f, 0x3e, 0x6e, 0x76, 0x27, 0x00, 0xff, 0xff, 0xff,
t.Fatal(err) 0x02, 0x00, 0xff, 0xc7, 0x15, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
return 0x00, 0x00, 0x00, 0x63, 0x61, 0x63, 0x68, 0x69, 0x6e, 0x67, 0x5f, 0x73,
} 0x68, 0x61, 0x32, 0x5f, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64,
t.Logf("detected %s version %s confidence %g%%", p.Name, p.Version, 100*c) 0x00,
if p.Name != ProtocolPostgreSQL {
t.Fatalf("expected postgres protocol, got %s", p.Name)
return
} }
testRunner(t, []*testCase{
{
Name: "PostgeSQL protocol 3.0",
Direction: Client,
Data: pgClientStartup,
DstPort: 5432,
WantProto: ProtocolPostgreSQL,
WantConfidence: .85,
},
{
Name: "Invalid MySQL server",
Direction: Server,
Data: mysqlBanner,
SrcPort: 3306,
WantError: ErrUnknown,
},
}) })
} }
@@ -57,6 +71,24 @@ func TestDetectPostgreSQLServer(t *testing.T) {
// Invalid data (HTTP GET request) // Invalid data (HTTP GET request)
httpBanner := []byte("GET / HTTP/1.1\r\nHost: example.com\r\n\r\n") httpBanner := []byte("GET / HTTP/1.1\r\nHost: example.com\r\n\r\n")
testRunner(t, []*testCase{
{
Name: "PostgeSQL Sever AuthentcationOk",
Direction: Server,
Data: pgServerAuthOK,
DstPort: 5432,
WantProto: ProtocolPostgreSQL,
WantConfidence: .65,
},
{
Name: "Invalid HTTP request",
Direction: Server,
Data: httpBanner,
SrcPort: 3306,
WantError: ErrUnknown,
},
})
t.Run("AuthenticationOk", func(t *testing.T) { t.Run("AuthenticationOk", func(t *testing.T) {
p, c, err := Detect(Server, pgServerAuthOK, 5432, 0) p, c, err := Detect(Server, pgServerAuthOK, 5432, 0)
if err != nil { if err != nil {

View File

@@ -2,17 +2,37 @@ package protocol
import ( import (
"bytes" "bytes"
"strings"
) )
func init() {
// We can't match on SSH-?.? here, because the client or server may send a banner prior
// to sending the SSH handshake.
Register(Both, "", detectSSH)
}
// The required prefix for the SSH protocol identification line. // The required prefix for the SSH protocol identification line.
const ( const (
ssh199Prefix = "SSH-1.99-" ssh199Prefix = "SSH-1.99-"
ssh20Prefix = "SSH-2.0-" ssh20Prefix = "SSH-2.0-"
) )
func init() { var (
Register(Both, "", detectSSH) commonPort = map[int]bool{
} 22: true,
2200: true,
2222: true,
}
commonImplementations = []string{
"OpenSSH_",
"PuTTY",
"libssh",
"dropbear",
"Go",
"paramiko",
"Cyberduck",
}
)
func detectSSH(dir Direction, data []byte, srcPort, dstPort int) (proto *Protocol, confidence float64) { func detectSSH(dir Direction, data []byte, srcPort, dstPort int) (proto *Protocol, confidence float64) {
// The data must be at least as long as the prefix itself. // The data must be at least as long as the prefix itself.
@@ -20,34 +40,52 @@ func detectSSH(dir Direction, data []byte, srcPort, dstPort int) (proto *Protoco
return nil, 0 return nil, 0
} }
if dstPort == 22 || dstPort == 2200 || dstPort == 2222 { if commonPort[srcPort] || commonPort[dstPort] {
confidence = .1 confidence += .1
} }
// The protocol allows for pre-banner text, so we have to check all lines. // The protocol allows for pre-banner text, so we have to check all lines.
for _, line := range bytes.Split(data, []byte{'\n'}) { for _, line := range bytes.Split(data, []byte{'\n'}) {
line = bytes.TrimSuffix(line, []byte{'\r'}) line = bytes.TrimSuffix(line, []byte{'\r'})
// Match the most common SSH 2.0 protocol.
if bytes.HasPrefix(line, []byte(ssh20Prefix)) { if bytes.HasPrefix(line, []byte(ssh20Prefix)) {
implementation := string(line[len(ssh20Prefix):])
for _, prefix := range commonImplementations {
if strings.HasPrefix(implementation, prefix) {
confidence += .2
break
}
}
return &Protocol{ return &Protocol{
Name: ProtocolSSH, Name: ProtocolSSH,
Version: Version{ Version: Version{
Major: 2, Major: 2,
Minor: 0, Minor: 0,
Patch: -1, Patch: -1,
Extra: string(line[len(ssh20Prefix):]), Extra: implementation,
}, },
}, confidence + 0.75 }, confidence + 0.65
} }
// Match the (far) less common SSH 1.99 protocol.
if bytes.HasPrefix(line, []byte(ssh199Prefix)) { if bytes.HasPrefix(line, []byte(ssh199Prefix)) {
implementation := string(line[len(ssh20Prefix):])
for _, prefix := range commonImplementations {
if strings.HasPrefix(implementation, prefix) {
confidence += .2
break
}
}
return &Protocol{ return &Protocol{
Name: ProtocolSSH, Name: ProtocolSSH,
Version: Version{ Version: Version{
Major: 1, Major: 1,
Minor: 99, Minor: 99,
Patch: -1, Patch: -1,
Extra: string(line[len(ssh20Prefix):]), Extra: implementation,
}, },
}, confidence + 0.75 }, confidence + 0.65
} }
} }

View File

@@ -1,7 +1,6 @@
package protocol package protocol
import ( import (
"errors"
"testing" "testing"
) )
@@ -31,13 +30,63 @@ func TestDetectSSH(t *testing.T) {
// 5. A simple HTTP request // 5. A simple HTTP request
httpBanner := []byte("GET / HTTP/1.1\r\nHost: example.com\r\n\r\n") httpBanner := []byte("GET / HTTP/1.1\r\nHost: example.com\r\n\r\n")
testRunner(t, []*testCase{
{
Name: "OpenSSH client",
Direction: Client,
Data: openSSHBanner,
DstPort: 22,
WantProto: ProtocolSSH,
WantConfidence: .95,
},
{
Name: "OpenSSH server",
Direction: Server,
Data: openSSHBanner,
SrcPort: 22,
WantProto: ProtocolSSH,
WantConfidence: .95,
},
{
Name: "OpenSSH server with banner",
Direction: Server,
Data: preBannerSSH,
SrcPort: 22,
WantProto: ProtocolSSH,
WantConfidence: .95,
},
{
Name: "Dropbear server",
Direction: Server,
Data: dropbearBanner,
SrcPort: 22,
WantProto: ProtocolSSH,
WantConfidence: .95,
},
{
Name: "Invalid MySQL",
Direction: Server,
Data: mysqlBanner,
SrcPort: 3306,
WantError: ErrUnknown,
},
{
Name: "Invalid HTTP",
Direction: Client,
Data: httpBanner,
DstPort: 80,
WantError: ErrUnknown,
},
})
/*
t.Run("OpenSSH client", func(t *testing.T) { t.Run("OpenSSH client", func(t *testing.T) {
p, _, err := Detect(Server, openSSHBanner, 0, 22) p, c, err := Detect(Server, openSSHBanner, 0, 22)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
return return
} }
t.Logf("detected %s version %s", p.Name, p.Version) t.Logf("detected %s version %s confidence %g%%", p.Name, p.Version, c*100)
if p.Name != ProtocolSSH { if p.Name != ProtocolSSH {
t.Fatalf("expected ssh protocol, got %s", p.Name) t.Fatalf("expected ssh protocol, got %s", p.Name)
return return
@@ -45,12 +94,12 @@ func TestDetectSSH(t *testing.T) {
}) })
t.Run("OpenSSH server", func(t *testing.T) { t.Run("OpenSSH server", func(t *testing.T) {
p, _, err := Detect(Server, openSSHBanner, 0, 22) p, c, err := Detect(Server, openSSHBanner, 0, 22)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
return return
} }
t.Logf("detected %s version %s", p.Name, p.Version) t.Logf("detected %s version %s confidence %g%%", p.Name, p.Version, c*100)
if p.Name != ProtocolSSH { if p.Name != ProtocolSSH {
t.Fatalf("expected ssh protocol, got %s", p.Name) t.Fatalf("expected ssh protocol, got %s", p.Name)
return return
@@ -58,12 +107,12 @@ func TestDetectSSH(t *testing.T) {
}) })
t.Run("OpenSSH server with banner", func(t *testing.T) { t.Run("OpenSSH server with banner", func(t *testing.T) {
p, _, err := Detect(Server, preBannerSSH, 0, 22) p, c, err := Detect(Server, preBannerSSH, 0, 22)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
return return
} }
t.Logf("detected %s version %s", p.Name, p.Version) t.Logf("detected %s version %s confidence %g%%", p.Name, p.Version, c*100)
if p.Name != ProtocolSSH { if p.Name != ProtocolSSH {
t.Fatalf("expected ssh protocol, got %s", p.Name) t.Fatalf("expected ssh protocol, got %s", p.Name)
return return
@@ -71,12 +120,12 @@ func TestDetectSSH(t *testing.T) {
}) })
t.Run("Dropbear server", func(t *testing.T) { t.Run("Dropbear server", func(t *testing.T) {
p, _, err := Detect(Server, dropbearBanner, 0, 22) p, c, err := Detect(Server, dropbearBanner, 0, 22)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
return return
} }
t.Logf("detected %s version %s", p.Name, p.Version) t.Logf("detected %s version %s confidence %g%%", p.Name, p.Version, c*100)
if p.Name != ProtocolSSH { if p.Name != ProtocolSSH {
t.Fatalf("expected ssh protocol, got %s", p.Name) t.Fatalf("expected ssh protocol, got %s", p.Name)
return return
@@ -100,4 +149,5 @@ func TestDetectSSH(t *testing.T) {
t.Logf("error %q, as expected", err) t.Logf("error %q, as expected", err)
} }
}) })
*/
} }

View File

@@ -2,7 +2,6 @@ package protocol
import ( import (
"encoding/hex" "encoding/hex"
"errors"
"strings" "strings"
"testing" "testing"
) )
@@ -167,15 +166,74 @@ func TestDetectTLS(t *testing.T) {
0x00, 0x00, 0x00, 0x25, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x25, 0x00, 0x03, 0x00, 0x00,
} }
tests := []*testCase{
{
Name: "SSLv3",
Direction: Client,
Data: sslV3ClientHello,
DstPort: 443,
WantProto: ProtocolSSL,
WantConfidence: .95,
},
{
Name: "TLS 1.1",
Direction: Client,
Data: tls11ClientHello,
DstPort: 443,
WantProto: ProtocolTLS,
WantConfidence: .95,
},
{
Name: "TLS 1.2",
Direction: Client,
Data: tls12ClientHello,
DstPort: 443,
WantProto: ProtocolTLS,
WantConfidence: .95,
},
{
Name: "TLS 1.3",
Direction: Client,
Data: tls13ClientHello,
DstPort: 443,
WantProto: ProtocolTLS,
WantConfidence: .95,
},
{
Name: "Invalid PostgreSQL",
Direction: Client,
Data: pgClientStartup,
DstPort: 5432,
WantError: ErrUnknown,
},
}
defer func() { Strict = false }() defer func() { Strict = false }()
for _, strict := range []bool{false, true} { for _, strict := range []bool{false, true} {
Strict = strict Strict = strict
name := "loose"
if strict { if strict {
name = "strict" t.Run("strict", func(t *testing.T) {
testRunner(t, tests)
})
} else {
// Strict runner doesn't allow for partial packet matching:
t.Run("loose", func(t *testing.T) {
testRunner(t, append([]*testCase{
{
Name: "TLS 1.1 partial",
Direction: Client,
Data: tls11ClientHelloPartial,
DstPort: 443,
WantProto: ProtocolTLS,
WantConfidence: .50,
},
}, tests...))
})
} }
/*
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
t.Run("SSLv3 Client Hello", func(t *testing.T) { t.Run("SSLv3 Client Hello", func(t *testing.T) {
@@ -260,6 +318,7 @@ func TestDetectTLS(t *testing.T) {
} }
}) })
}) })
*/
} }
} }

View File

@@ -1,10 +1,109 @@
package protocol package protocol
import ( import (
"errors"
"fmt"
"math" "math"
"math/rand"
"strconv"
"strings"
"testing" "testing"
) )
type testCase struct {
Name string
Direction Direction
Data []byte
SrcPort int
DstPort int
WantProto string
WantConfidence float64
WantError error
}
func testRunner(t *testing.T, tests []*testCase) {
t.Helper()
for _, test := range tests {
t.Run(test.Name, func(t *testing.T) {
if test.SrcPort == 0 {
test.SrcPort = 1024 + rand.Intn(65535-1024)
}
if test.DstPort == 0 {
test.DstPort = 1024 + rand.Intn(65535-1024)
}
proto, confidence, err := Detect(test.Direction, test.Data, test.SrcPort, test.DstPort)
// Process error first
if err != nil {
if test.WantError == nil {
t.Fatalf("unexpected error: %v", err)
} else if !errors.Is(err, test.WantError) {
t.Fatalf("Detect(%s, %s, %d, %d) returned error %q, expected %q",
test.Direction, testBytesSample(test.Data, 8), test.SrcPort, test.DstPort,
err, test.WantError)
} else {
t.Logf("Detect(%s, %s, %d, %d) returned error %q as expected",
test.Direction, testBytesSample(test.Data, 8), test.SrcPort, test.DstPort,
err)
}
return
} else if test.WantError != nil {
t.Fatalf("Detect(%s, %s, %d, %d) returned protocol %q version %s, expected error %q",
test.Direction, testBytesSample(test.Data, 8), test.SrcPort, test.DstPort,
proto.Name, proto.Version, test.WantError)
return
}
// Process protocol
if proto == nil {
t.Fatalf("Detect(%s, %s, %d, %d) returned nil, expected protocol %q",
test.Direction, testBytesSample(test.Data, 8), test.SrcPort, test.DstPort,
test.WantProto)
return
}
t.Logf("Detect(%s, %s, %d, %d) returned protocol %q version %s with confidence %g%%",
test.Direction, testBytesSample(test.Data, 4), test.SrcPort, test.DstPort,
proto.Name, proto.Version, confidence*100)
if proto.Name != test.WantProto {
t.Errorf("Expected protocol %q", test.WantProto)
}
if !testAlmostEqual(confidence, test.WantConfidence) {
t.Errorf("Expected confidence %g%%", test.WantConfidence*100)
}
})
}
}
func testBytesSample(b []byte, n int) string {
if b == nil {
return "<nil>"
}
var (
hex []string
etc string
)
for i, l := 0, len(b); i < l && i < n; i++ {
if strconv.IsPrint(rune(b[i])) {
hex = append(hex, fmt.Sprintf("%c", b[i]))
} else {
hex = append(hex, fmt.Sprintf("\\x%02X", b[i]))
}
if i == (n-1) && l > (n-1) {
etc = fmt.Sprintf(" … (%d more)", l-n)
}
}
return fmt.Sprintf(`"%s"%s`, strings.Join(hex, ""), etc)
}
func testAlmostEqual(a, b float64) bool {
const e = 1e-9
return math.Abs(a-b) < e
}
func TestCompareFloats(t *testing.T) { func TestCompareFloats(t *testing.T) {
tests := []struct { tests := []struct {
name string name string