Switch to new test harness
This commit is contained in:
@@ -25,7 +25,6 @@ func detectHTTPRequest(dir Direction, data []byte, srcPort, dstPort int) (proto
|
||||
confidence = -.1
|
||||
}
|
||||
|
||||
if Strict {
|
||||
var (
|
||||
b = append(data, '\r', '\n')
|
||||
r = bufio.NewReader(bytes.NewReader(b))
|
||||
@@ -40,6 +39,8 @@ func detectHTTPRequest(dir Direction, data []byte, srcPort, dstPort int) (proto
|
||||
},
|
||||
}, confidence + .85
|
||||
}
|
||||
|
||||
if Strict {
|
||||
return nil, 0
|
||||
}
|
||||
|
||||
|
@@ -1,7 +1,6 @@
|
||||
package protocol
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
)
|
||||
|
||||
@@ -17,6 +16,32 @@ func TestDetectHTTPRequest(t *testing.T) {
|
||||
// An invalid HTTP request
|
||||
sshBanner := []byte("SSH-2.0-OpenSSH_8.9p1 Ubuntu-3ubuntu0.4\r\n")
|
||||
|
||||
tests := []*testCase{
|
||||
{
|
||||
Name: "HTTP/1.0 GET",
|
||||
Direction: Client,
|
||||
Data: http10Request,
|
||||
DstPort: 80,
|
||||
WantProto: ProtocolHTTP,
|
||||
WantConfidence: .95,
|
||||
},
|
||||
{
|
||||
Name: "HTTP/1.1 GET",
|
||||
Direction: Client,
|
||||
Data: getRequest,
|
||||
DstPort: 80,
|
||||
WantProto: ProtocolHTTP,
|
||||
WantConfidence: .95,
|
||||
},
|
||||
{
|
||||
Name: "Invalid SSH",
|
||||
Direction: Client,
|
||||
Data: sshBanner,
|
||||
DstPort: 80,
|
||||
WantError: ErrUnknown,
|
||||
},
|
||||
}
|
||||
|
||||
defer func() { Strict = false }()
|
||||
for _, strict := range []bool{false, true} {
|
||||
Strict = strict
|
||||
@@ -27,40 +52,7 @@ func TestDetectHTTPRequest(t *testing.T) {
|
||||
}
|
||||
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Run("HTTP/1.0 GET", func(t *testing.T) {
|
||||
p, c, err := Detect(Client, http10Request, 1234, 80)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
t.Logf("detected %s version %s confidence %g%%", p.Name, p.Version, c*100)
|
||||
if p.Name != ProtocolHTTP {
|
||||
t.Fatalf("expected http protocol, got %s", p.Name)
|
||||
return
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("HTTP/1.1 GET", func(t *testing.T) {
|
||||
p, c, err := Detect(Client, getRequest, 1234, 80)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
t.Logf("detected %s version %s confidence %g%%", p.Name, p.Version, c*100)
|
||||
if p.Name != ProtocolHTTP {
|
||||
t.Fatalf("expected http protocol, got %s", p.Name)
|
||||
return
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Invalid SSH", func(t *testing.T) {
|
||||
_, _, err := Detect(Server, sshBanner, 1234, 22)
|
||||
if !errors.Is(err, ErrUnknown) {
|
||||
t.Fatalf("expected unknown format, got error %T: %q", err, err)
|
||||
} else {
|
||||
t.Logf("error %q, as expected", err)
|
||||
}
|
||||
})
|
||||
testRunner(t, tests)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -83,72 +75,67 @@ func TestDetectHTTPResponse(t *testing.T) {
|
||||
// An invalid banner (SSH)
|
||||
sshBanner := []byte("SSH-2.0-OpenSSH_8.9p1 Ubuntu-3ubuntu0.4\r\n")
|
||||
|
||||
tests := []*testCase{
|
||||
{
|
||||
Name: "HTTP/1.0 403",
|
||||
Direction: Server,
|
||||
Data: http10Response,
|
||||
SrcPort: 80,
|
||||
WantProto: ProtocolHTTP,
|
||||
},
|
||||
{
|
||||
Name: "HTTP/1.1 200",
|
||||
Direction: Server,
|
||||
Data: responseOK,
|
||||
SrcPort: 80,
|
||||
WantProto: ProtocolHTTP,
|
||||
},
|
||||
{
|
||||
Name: "HTTP/1.1 404",
|
||||
Direction: Server,
|
||||
Data: responseNotFound,
|
||||
SrcPort: 80,
|
||||
WantProto: ProtocolHTTP,
|
||||
},
|
||||
{
|
||||
Name: "Invalid HTTP/1.1 GET",
|
||||
Direction: Server,
|
||||
Data: getRequest,
|
||||
SrcPort: 80,
|
||||
WantError: ErrUnknown,
|
||||
},
|
||||
{
|
||||
Name: "Invalid SSH",
|
||||
Direction: Server,
|
||||
Data: sshBanner,
|
||||
SrcPort: 80,
|
||||
WantError: ErrUnknown,
|
||||
},
|
||||
}
|
||||
|
||||
defer func() { Strict = false }()
|
||||
for _, strict := range []bool{false, true} {
|
||||
Strict = strict
|
||||
|
||||
name := "loose"
|
||||
var name string
|
||||
if strict {
|
||||
name = "strict"
|
||||
for _, test := range tests {
|
||||
if test.WantError == nil {
|
||||
test.WantConfidence = .95
|
||||
}
|
||||
}
|
||||
} else {
|
||||
name = "loose"
|
||||
for _, test := range tests {
|
||||
if test.WantError == nil {
|
||||
test.WantConfidence = .85
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Run("HTTP/1.0 403", func(t *testing.T) {
|
||||
p, c, err := Detect(Server, http10Response, 80, 1234)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
t.Logf("detected %s version %s confidence %g%%", p.Name, p.Version, c*100)
|
||||
if p.Name != ProtocolHTTP {
|
||||
t.Fatalf("expected http protocol, got %s", p.Name)
|
||||
return
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("HTTP/1.1 200", func(t *testing.T) {
|
||||
p, c, err := Detect(Server, responseOK, 80, 1234)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
t.Logf("detected %s version %s confidence %g%%", p.Name, p.Version, c*100)
|
||||
if p.Name != ProtocolHTTP {
|
||||
t.Fatalf("expected http protocol, got %s", p.Name)
|
||||
return
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("HTTP/1.1 404", func(t *testing.T) {
|
||||
p, c, err := Detect(Server, responseNotFound, 80, 1234)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
t.Logf("detected %s version %s confidence %g%%", p.Name, p.Version, c*100)
|
||||
if p.Name != ProtocolHTTP {
|
||||
t.Fatalf("expected http protocol, got %s", p.Name)
|
||||
return
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Invalid HTTP/1.1 GET", func(t *testing.T) {
|
||||
_, _, err := Detect(Server, getRequest, 1234, 80)
|
||||
if !errors.Is(err, ErrUnknown) {
|
||||
t.Fatalf("expected unknown format, got error %T: %q", err, err)
|
||||
} else {
|
||||
t.Logf("error %q, as expected", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Invalid SSH", func(t *testing.T) {
|
||||
_, _, err := Detect(Server, sshBanner, 22, 1234)
|
||||
if !errors.Is(err, ErrUnknown) {
|
||||
t.Fatalf("expected unknown format, got error %T: %q", err, err)
|
||||
} else {
|
||||
t.Logf("error %q, as expected", err)
|
||||
}
|
||||
})
|
||||
testRunner(t, tests)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@@ -1,7 +1,6 @@
|
||||
package protocol
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
)
|
||||
|
||||
@@ -14,7 +13,8 @@ func TestDetectMySQL(t *testing.T) {
|
||||
0x04, 0x5a, 0x56, 0x5f, 0x3e, 0x6e, 0x76, 0x27, 0x00, 0xff, 0xff, 0xff,
|
||||
0x02, 0x00, 0xff, 0xc7, 0x15, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
0x00, 0x00, 0x00, 0x63, 0x61, 0x63, 0x68, 0x69, 0x6e, 0x67, 0x5f, 0x73,
|
||||
0x68, 0x61, 0x32, 0x5f, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x00,
|
||||
0x68, 0x61, 0x32, 0x5f, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64,
|
||||
0x00,
|
||||
}
|
||||
|
||||
// 2. A valid MariaDB banner (protocol-compatible)
|
||||
@@ -37,46 +37,43 @@ func TestDetectMySQL(t *testing.T) {
|
||||
// 5. A slice that starts correctly but is malformed (no null terminator)
|
||||
malformedSlice := []byte{0x0a, 0x38, 0x2e, 0x30, 0x2e, 0x30, 0x01, 0x02, 0x03, 0x04, 0x05}
|
||||
|
||||
t.Run("MySQL 8", func(t *testing.T) {
|
||||
p, c, _ := Detect(Server, mysql8Banner, 3306, 0)
|
||||
if p == nil {
|
||||
t.Fatal("expected MySQL protocol, got nil")
|
||||
}
|
||||
t.Logf("detected %s version %s confidence %g%%", p.Name, p.Version, c*100)
|
||||
})
|
||||
|
||||
t.Run("MariaDB", func(t *testing.T) {
|
||||
p, c, _ := Detect(Server, mariaDBBanner, 3306, 0)
|
||||
if p == nil {
|
||||
t.Fatal("expected MySQL protocol, got nil")
|
||||
}
|
||||
t.Logf("detected %s version %s confidence %g%%", p.Name, p.Version, c*100)
|
||||
})
|
||||
|
||||
t.Run("Invalid HTTP", func(t *testing.T) {
|
||||
_, _, err := Detect(Server, httpBanner, 1234, 80)
|
||||
if !errors.Is(err, ErrUnknown) {
|
||||
t.Fatalf("expected unknown format, got error %T: %q", err, err)
|
||||
} else {
|
||||
t.Logf("error %q, as expected", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Too short", func(t *testing.T) {
|
||||
_, _, err := Detect(Server, shortSlice, 3306, 1234)
|
||||
if !errors.Is(err, ErrUnknown) {
|
||||
t.Fatalf("expected unknown format, got error %T: %q", err, err)
|
||||
} else {
|
||||
t.Logf("error %q, as expected", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Malformed", func(t *testing.T) {
|
||||
_, _, err := Detect(Server, malformedSlice, 3306, 1234)
|
||||
if !errors.Is(err, ErrUnknown) {
|
||||
t.Fatalf("expected unknown format, got error %T: %q", err, err)
|
||||
} else {
|
||||
t.Logf("error %q, as expected", err)
|
||||
}
|
||||
testRunner(t, []*testCase{
|
||||
{
|
||||
Name: "MySQL server",
|
||||
Direction: Server,
|
||||
Data: mysql8Banner,
|
||||
SrcPort: 3306,
|
||||
WantProto: ProtocolMySQL,
|
||||
WantConfidence: .85,
|
||||
},
|
||||
{
|
||||
Name: "MariaDB server",
|
||||
Direction: Server,
|
||||
Data: mariaDBBanner,
|
||||
SrcPort: 3306,
|
||||
WantProto: ProtocolMySQL,
|
||||
WantConfidence: .85,
|
||||
},
|
||||
{
|
||||
Name: "Invalid HTTP",
|
||||
Direction: Server,
|
||||
Data: httpBanner,
|
||||
SrcPort: 80,
|
||||
WantError: ErrUnknown,
|
||||
},
|
||||
{
|
||||
Name: "Invalid too short",
|
||||
Direction: Server,
|
||||
Data: shortSlice,
|
||||
SrcPort: 3306,
|
||||
WantError: ErrUnknown,
|
||||
},
|
||||
{
|
||||
Name: "Invalid malformed",
|
||||
Direction: Server,
|
||||
Data: malformedSlice,
|
||||
SrcPort: 3306,
|
||||
WantError: ErrUnknown,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
@@ -21,17 +21,31 @@ func TestDetectPostgreSQLClient(t *testing.T) {
|
||||
'd', 'a', 't', 'a', 'b', 'a', 's', 'e', 0x00, 't', 'e', 's', 't', 0x00, 0x00,
|
||||
}
|
||||
|
||||
t.Run("Protocol 3.0", func(t *testing.T) {
|
||||
p, c, err := Detect(Client, pgClientStartup, 0, 5432)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
t.Logf("detected %s version %s confidence %g%%", p.Name, p.Version, 100*c)
|
||||
if p.Name != ProtocolPostgreSQL {
|
||||
t.Fatalf("expected postgres protocol, got %s", p.Name)
|
||||
return
|
||||
mysqlBanner := []byte{
|
||||
0x0a, 0x38, 0x2e, 0x30, 0x2e, 0x33, 0x32, 0x00, 0x0d, 0x00, 0x00, 0x00,
|
||||
0x04, 0x5a, 0x56, 0x5f, 0x3e, 0x6e, 0x76, 0x27, 0x00, 0xff, 0xff, 0xff,
|
||||
0x02, 0x00, 0xff, 0xc7, 0x15, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
0x00, 0x00, 0x00, 0x63, 0x61, 0x63, 0x68, 0x69, 0x6e, 0x67, 0x5f, 0x73,
|
||||
0x68, 0x61, 0x32, 0x5f, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64,
|
||||
0x00,
|
||||
}
|
||||
|
||||
testRunner(t, []*testCase{
|
||||
{
|
||||
Name: "PostgeSQL protocol 3.0",
|
||||
Direction: Client,
|
||||
Data: pgClientStartup,
|
||||
DstPort: 5432,
|
||||
WantProto: ProtocolPostgreSQL,
|
||||
WantConfidence: .85,
|
||||
},
|
||||
{
|
||||
Name: "Invalid MySQL server",
|
||||
Direction: Server,
|
||||
Data: mysqlBanner,
|
||||
SrcPort: 3306,
|
||||
WantError: ErrUnknown,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
@@ -57,6 +71,24 @@ func TestDetectPostgreSQLServer(t *testing.T) {
|
||||
// Invalid data (HTTP GET request)
|
||||
httpBanner := []byte("GET / HTTP/1.1\r\nHost: example.com\r\n\r\n")
|
||||
|
||||
testRunner(t, []*testCase{
|
||||
{
|
||||
Name: "PostgeSQL Sever AuthentcationOk",
|
||||
Direction: Server,
|
||||
Data: pgServerAuthOK,
|
||||
DstPort: 5432,
|
||||
WantProto: ProtocolPostgreSQL,
|
||||
WantConfidence: .65,
|
||||
},
|
||||
{
|
||||
Name: "Invalid HTTP request",
|
||||
Direction: Server,
|
||||
Data: httpBanner,
|
||||
SrcPort: 3306,
|
||||
WantError: ErrUnknown,
|
||||
},
|
||||
})
|
||||
|
||||
t.Run("AuthenticationOk", func(t *testing.T) {
|
||||
p, c, err := Detect(Server, pgServerAuthOK, 5432, 0)
|
||||
if err != nil {
|
||||
|
@@ -2,17 +2,37 @@ package protocol
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func init() {
|
||||
// We can't match on SSH-?.? here, because the client or server may send a banner prior
|
||||
// to sending the SSH handshake.
|
||||
Register(Both, "", detectSSH)
|
||||
}
|
||||
|
||||
// The required prefix for the SSH protocol identification line.
|
||||
const (
|
||||
ssh199Prefix = "SSH-1.99-"
|
||||
ssh20Prefix = "SSH-2.0-"
|
||||
)
|
||||
|
||||
func init() {
|
||||
Register(Both, "", detectSSH)
|
||||
}
|
||||
var (
|
||||
commonPort = map[int]bool{
|
||||
22: true,
|
||||
2200: true,
|
||||
2222: true,
|
||||
}
|
||||
commonImplementations = []string{
|
||||
"OpenSSH_",
|
||||
"PuTTY",
|
||||
"libssh",
|
||||
"dropbear",
|
||||
"Go",
|
||||
"paramiko",
|
||||
"Cyberduck",
|
||||
}
|
||||
)
|
||||
|
||||
func detectSSH(dir Direction, data []byte, srcPort, dstPort int) (proto *Protocol, confidence float64) {
|
||||
// The data must be at least as long as the prefix itself.
|
||||
@@ -20,34 +40,52 @@ func detectSSH(dir Direction, data []byte, srcPort, dstPort int) (proto *Protoco
|
||||
return nil, 0
|
||||
}
|
||||
|
||||
if dstPort == 22 || dstPort == 2200 || dstPort == 2222 {
|
||||
confidence = .1
|
||||
if commonPort[srcPort] || commonPort[dstPort] {
|
||||
confidence += .1
|
||||
}
|
||||
|
||||
// The protocol allows for pre-banner text, so we have to check all lines.
|
||||
for _, line := range bytes.Split(data, []byte{'\n'}) {
|
||||
line = bytes.TrimSuffix(line, []byte{'\r'})
|
||||
|
||||
// Match the most common SSH 2.0 protocol.
|
||||
if bytes.HasPrefix(line, []byte(ssh20Prefix)) {
|
||||
implementation := string(line[len(ssh20Prefix):])
|
||||
for _, prefix := range commonImplementations {
|
||||
if strings.HasPrefix(implementation, prefix) {
|
||||
confidence += .2
|
||||
break
|
||||
}
|
||||
}
|
||||
return &Protocol{
|
||||
Name: ProtocolSSH,
|
||||
Version: Version{
|
||||
Major: 2,
|
||||
Minor: 0,
|
||||
Patch: -1,
|
||||
Extra: string(line[len(ssh20Prefix):]),
|
||||
Extra: implementation,
|
||||
},
|
||||
}, confidence + 0.75
|
||||
}, confidence + 0.65
|
||||
}
|
||||
|
||||
// Match the (far) less common SSH 1.99 protocol.
|
||||
if bytes.HasPrefix(line, []byte(ssh199Prefix)) {
|
||||
implementation := string(line[len(ssh20Prefix):])
|
||||
for _, prefix := range commonImplementations {
|
||||
if strings.HasPrefix(implementation, prefix) {
|
||||
confidence += .2
|
||||
break
|
||||
}
|
||||
}
|
||||
return &Protocol{
|
||||
Name: ProtocolSSH,
|
||||
Version: Version{
|
||||
Major: 1,
|
||||
Minor: 99,
|
||||
Patch: -1,
|
||||
Extra: string(line[len(ssh20Prefix):]),
|
||||
Extra: implementation,
|
||||
},
|
||||
}, confidence + 0.75
|
||||
}, confidence + 0.65
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -1,7 +1,6 @@
|
||||
package protocol
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
)
|
||||
|
||||
@@ -31,13 +30,63 @@ func TestDetectSSH(t *testing.T) {
|
||||
// 5. A simple HTTP request
|
||||
httpBanner := []byte("GET / HTTP/1.1\r\nHost: example.com\r\n\r\n")
|
||||
|
||||
testRunner(t, []*testCase{
|
||||
{
|
||||
Name: "OpenSSH client",
|
||||
Direction: Client,
|
||||
Data: openSSHBanner,
|
||||
DstPort: 22,
|
||||
WantProto: ProtocolSSH,
|
||||
WantConfidence: .95,
|
||||
},
|
||||
{
|
||||
Name: "OpenSSH server",
|
||||
Direction: Server,
|
||||
Data: openSSHBanner,
|
||||
SrcPort: 22,
|
||||
WantProto: ProtocolSSH,
|
||||
WantConfidence: .95,
|
||||
},
|
||||
{
|
||||
Name: "OpenSSH server with banner",
|
||||
Direction: Server,
|
||||
Data: preBannerSSH,
|
||||
SrcPort: 22,
|
||||
WantProto: ProtocolSSH,
|
||||
WantConfidence: .95,
|
||||
},
|
||||
{
|
||||
Name: "Dropbear server",
|
||||
Direction: Server,
|
||||
Data: dropbearBanner,
|
||||
SrcPort: 22,
|
||||
WantProto: ProtocolSSH,
|
||||
WantConfidence: .95,
|
||||
},
|
||||
{
|
||||
Name: "Invalid MySQL",
|
||||
Direction: Server,
|
||||
Data: mysqlBanner,
|
||||
SrcPort: 3306,
|
||||
WantError: ErrUnknown,
|
||||
},
|
||||
{
|
||||
Name: "Invalid HTTP",
|
||||
Direction: Client,
|
||||
Data: httpBanner,
|
||||
DstPort: 80,
|
||||
WantError: ErrUnknown,
|
||||
},
|
||||
})
|
||||
|
||||
/*
|
||||
t.Run("OpenSSH client", func(t *testing.T) {
|
||||
p, _, err := Detect(Server, openSSHBanner, 0, 22)
|
||||
p, c, err := Detect(Server, openSSHBanner, 0, 22)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
t.Logf("detected %s version %s", p.Name, p.Version)
|
||||
t.Logf("detected %s version %s confidence %g%%", p.Name, p.Version, c*100)
|
||||
if p.Name != ProtocolSSH {
|
||||
t.Fatalf("expected ssh protocol, got %s", p.Name)
|
||||
return
|
||||
@@ -45,12 +94,12 @@ func TestDetectSSH(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("OpenSSH server", func(t *testing.T) {
|
||||
p, _, err := Detect(Server, openSSHBanner, 0, 22)
|
||||
p, c, err := Detect(Server, openSSHBanner, 0, 22)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
t.Logf("detected %s version %s", p.Name, p.Version)
|
||||
t.Logf("detected %s version %s confidence %g%%", p.Name, p.Version, c*100)
|
||||
if p.Name != ProtocolSSH {
|
||||
t.Fatalf("expected ssh protocol, got %s", p.Name)
|
||||
return
|
||||
@@ -58,12 +107,12 @@ func TestDetectSSH(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("OpenSSH server with banner", func(t *testing.T) {
|
||||
p, _, err := Detect(Server, preBannerSSH, 0, 22)
|
||||
p, c, err := Detect(Server, preBannerSSH, 0, 22)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
t.Logf("detected %s version %s", p.Name, p.Version)
|
||||
t.Logf("detected %s version %s confidence %g%%", p.Name, p.Version, c*100)
|
||||
if p.Name != ProtocolSSH {
|
||||
t.Fatalf("expected ssh protocol, got %s", p.Name)
|
||||
return
|
||||
@@ -71,12 +120,12 @@ func TestDetectSSH(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("Dropbear server", func(t *testing.T) {
|
||||
p, _, err := Detect(Server, dropbearBanner, 0, 22)
|
||||
p, c, err := Detect(Server, dropbearBanner, 0, 22)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
t.Logf("detected %s version %s", p.Name, p.Version)
|
||||
t.Logf("detected %s version %s confidence %g%%", p.Name, p.Version, c*100)
|
||||
if p.Name != ProtocolSSH {
|
||||
t.Fatalf("expected ssh protocol, got %s", p.Name)
|
||||
return
|
||||
@@ -100,4 +149,5 @@ func TestDetectSSH(t *testing.T) {
|
||||
t.Logf("error %q, as expected", err)
|
||||
}
|
||||
})
|
||||
*/
|
||||
}
|
||||
|
@@ -2,7 +2,6 @@ package protocol
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
@@ -167,15 +166,74 @@ func TestDetectTLS(t *testing.T) {
|
||||
0x00, 0x00, 0x00, 0x25, 0x00, 0x03, 0x00, 0x00,
|
||||
}
|
||||
|
||||
tests := []*testCase{
|
||||
{
|
||||
Name: "SSLv3",
|
||||
Direction: Client,
|
||||
Data: sslV3ClientHello,
|
||||
DstPort: 443,
|
||||
WantProto: ProtocolSSL,
|
||||
WantConfidence: .95,
|
||||
},
|
||||
{
|
||||
Name: "TLS 1.1",
|
||||
Direction: Client,
|
||||
Data: tls11ClientHello,
|
||||
DstPort: 443,
|
||||
WantProto: ProtocolTLS,
|
||||
WantConfidence: .95,
|
||||
},
|
||||
{
|
||||
Name: "TLS 1.2",
|
||||
Direction: Client,
|
||||
Data: tls12ClientHello,
|
||||
DstPort: 443,
|
||||
WantProto: ProtocolTLS,
|
||||
WantConfidence: .95,
|
||||
},
|
||||
{
|
||||
Name: "TLS 1.3",
|
||||
Direction: Client,
|
||||
Data: tls13ClientHello,
|
||||
DstPort: 443,
|
||||
WantProto: ProtocolTLS,
|
||||
WantConfidence: .95,
|
||||
},
|
||||
{
|
||||
Name: "Invalid PostgreSQL",
|
||||
Direction: Client,
|
||||
Data: pgClientStartup,
|
||||
DstPort: 5432,
|
||||
WantError: ErrUnknown,
|
||||
},
|
||||
}
|
||||
|
||||
defer func() { Strict = false }()
|
||||
for _, strict := range []bool{false, true} {
|
||||
Strict = strict
|
||||
|
||||
name := "loose"
|
||||
if strict {
|
||||
name = "strict"
|
||||
t.Run("strict", func(t *testing.T) {
|
||||
testRunner(t, tests)
|
||||
})
|
||||
} else {
|
||||
// Strict runner doesn't allow for partial packet matching:
|
||||
t.Run("loose", func(t *testing.T) {
|
||||
testRunner(t, append([]*testCase{
|
||||
{
|
||||
Name: "TLS 1.1 partial",
|
||||
Direction: Client,
|
||||
Data: tls11ClientHelloPartial,
|
||||
DstPort: 443,
|
||||
WantProto: ProtocolTLS,
|
||||
WantConfidence: .50,
|
||||
},
|
||||
}, tests...))
|
||||
})
|
||||
}
|
||||
|
||||
/*
|
||||
|
||||
t.Run(name, func(t *testing.T) {
|
||||
|
||||
t.Run("SSLv3 Client Hello", func(t *testing.T) {
|
||||
@@ -260,6 +318,7 @@ func TestDetectTLS(t *testing.T) {
|
||||
}
|
||||
})
|
||||
})
|
||||
*/
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -1,10 +1,109 @@
|
||||
package protocol
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"math/rand"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type testCase struct {
|
||||
Name string
|
||||
Direction Direction
|
||||
Data []byte
|
||||
SrcPort int
|
||||
DstPort int
|
||||
WantProto string
|
||||
WantConfidence float64
|
||||
WantError error
|
||||
}
|
||||
|
||||
func testRunner(t *testing.T, tests []*testCase) {
|
||||
t.Helper()
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.Name, func(t *testing.T) {
|
||||
if test.SrcPort == 0 {
|
||||
test.SrcPort = 1024 + rand.Intn(65535-1024)
|
||||
}
|
||||
if test.DstPort == 0 {
|
||||
test.DstPort = 1024 + rand.Intn(65535-1024)
|
||||
}
|
||||
|
||||
proto, confidence, err := Detect(test.Direction, test.Data, test.SrcPort, test.DstPort)
|
||||
|
||||
// Process error first
|
||||
if err != nil {
|
||||
if test.WantError == nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
} else if !errors.Is(err, test.WantError) {
|
||||
t.Fatalf("Detect(%s, %s, %d, %d) returned error %q, expected %q",
|
||||
test.Direction, testBytesSample(test.Data, 8), test.SrcPort, test.DstPort,
|
||||
err, test.WantError)
|
||||
} else {
|
||||
t.Logf("Detect(%s, %s, %d, %d) returned error %q as expected",
|
||||
test.Direction, testBytesSample(test.Data, 8), test.SrcPort, test.DstPort,
|
||||
err)
|
||||
}
|
||||
return
|
||||
} else if test.WantError != nil {
|
||||
t.Fatalf("Detect(%s, %s, %d, %d) returned protocol %q version %s, expected error %q",
|
||||
test.Direction, testBytesSample(test.Data, 8), test.SrcPort, test.DstPort,
|
||||
proto.Name, proto.Version, test.WantError)
|
||||
return
|
||||
}
|
||||
|
||||
// Process protocol
|
||||
if proto == nil {
|
||||
t.Fatalf("Detect(%s, %s, %d, %d) returned nil, expected protocol %q",
|
||||
test.Direction, testBytesSample(test.Data, 8), test.SrcPort, test.DstPort,
|
||||
test.WantProto)
|
||||
return
|
||||
}
|
||||
|
||||
t.Logf("Detect(%s, %s, %d, %d) returned protocol %q version %s with confidence %g%%",
|
||||
test.Direction, testBytesSample(test.Data, 4), test.SrcPort, test.DstPort,
|
||||
proto.Name, proto.Version, confidence*100)
|
||||
|
||||
if proto.Name != test.WantProto {
|
||||
t.Errorf("Expected protocol %q", test.WantProto)
|
||||
}
|
||||
if !testAlmostEqual(confidence, test.WantConfidence) {
|
||||
t.Errorf("Expected confidence %g%%", test.WantConfidence*100)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func testBytesSample(b []byte, n int) string {
|
||||
if b == nil {
|
||||
return "<nil>"
|
||||
}
|
||||
var (
|
||||
hex []string
|
||||
etc string
|
||||
)
|
||||
for i, l := 0, len(b); i < l && i < n; i++ {
|
||||
if strconv.IsPrint(rune(b[i])) {
|
||||
hex = append(hex, fmt.Sprintf("%c", b[i]))
|
||||
} else {
|
||||
hex = append(hex, fmt.Sprintf("\\x%02X", b[i]))
|
||||
}
|
||||
if i == (n-1) && l > (n-1) {
|
||||
etc = fmt.Sprintf(" … (%d more)", l-n)
|
||||
}
|
||||
}
|
||||
return fmt.Sprintf(`"%s"%s`, strings.Join(hex, ""), etc)
|
||||
}
|
||||
|
||||
func testAlmostEqual(a, b float64) bool {
|
||||
const e = 1e-9
|
||||
return math.Abs(a-b) < e
|
||||
}
|
||||
|
||||
func TestCompareFloats(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
Reference in New Issue
Block a user