Initial import
This commit is contained in:
113
protocol/detect.go
Normal file
113
protocol/detect.go
Normal file
@@ -0,0 +1,113 @@
|
||||
package protocol
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
// Strict mode requires a full, compliant packet to be captured. This is only
|
||||
// implemented by some detectors.
|
||||
var Strict bool
|
||||
|
||||
// Common errors.
|
||||
var (
|
||||
ErrTimeout = errors.New("timeout")
|
||||
ErrUnknown = errors.New("unknown protocol")
|
||||
)
|
||||
|
||||
// Direction indicates the communcation direction.
|
||||
type Direction int
|
||||
|
||||
// Directions supported by this package.
|
||||
const (
|
||||
Unknown Direction = iota
|
||||
Client
|
||||
Server
|
||||
Both
|
||||
)
|
||||
|
||||
func (dir Direction) Contains(other Direction) bool {
|
||||
switch dir {
|
||||
case Client:
|
||||
return other == Client || other == Both
|
||||
case Server:
|
||||
return other == Server || other == Both
|
||||
case Both:
|
||||
return other == Client || other == Server
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
var directionName = map[Direction]string{
|
||||
Client: "client",
|
||||
Server: "server",
|
||||
Both: "both",
|
||||
}
|
||||
|
||||
func (dir Direction) String() string {
|
||||
if s, ok := directionName[dir]; ok {
|
||||
return s
|
||||
}
|
||||
return fmt.Sprintf("invalid (%d)", int(dir))
|
||||
}
|
||||
|
||||
type format struct {
|
||||
dir Direction
|
||||
magic string
|
||||
detect DetectFunc
|
||||
}
|
||||
|
||||
// Formats is the list of registered formats.
|
||||
var (
|
||||
formatsMu sync.Mutex
|
||||
atomicFormats atomic.Value
|
||||
)
|
||||
|
||||
type DetectFunc func(Direction, []byte) *Protocol
|
||||
|
||||
func Register(dir Direction, magic string, detect DetectFunc) {
|
||||
formatsMu.Lock()
|
||||
formats, _ := atomicFormats.Load().([]format)
|
||||
atomicFormats.Store(append(formats, format{dir, magic, detect}))
|
||||
formatsMu.Unlock()
|
||||
}
|
||||
|
||||
func matchMagic(magic string, data []byte) bool {
|
||||
// Empty magic means the detector will always run.
|
||||
if len(magic) == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
// The buffer should contain at least the same number of bytes
|
||||
// as our magic.
|
||||
if len(data) < len(magic) {
|
||||
return false
|
||||
}
|
||||
|
||||
// Match bytes in magic with bytes in data.
|
||||
for i, b := range []byte(magic) {
|
||||
if b != '?' && data[i] != b {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// Detect a protocol based on the provided data.
|
||||
func Detect(dir Direction, data []byte) (*Protocol, error) {
|
||||
formats, _ := atomicFormats.Load().([]format)
|
||||
for _, f := range formats {
|
||||
if f.dir.Contains(dir) {
|
||||
// Check the buffer to see if we have sufficient bytes
|
||||
if matchMagic(f.magic, data) {
|
||||
if p := f.detect(dir, data); p != nil {
|
||||
return p, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil, ErrUnknown
|
||||
}
|
94
protocol/detect_http.go
Normal file
94
protocol/detect_http.go
Normal file
@@ -0,0 +1,94 @@
|
||||
package protocol
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"fmt"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func init() {
|
||||
Register(Client, "", detectHTTPRequest)
|
||||
Register(Server, "HTTP/?.", detectHTTPResponse)
|
||||
}
|
||||
|
||||
func detectHTTPRequest(dir Direction, data []byte) *Protocol {
|
||||
// A minimal request "GET / HTTP/1.0\r\n" is > 8 bytes.
|
||||
if len(data) < 8 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if Strict {
|
||||
var (
|
||||
b = append(data, '\r', '\n')
|
||||
r = bufio.NewReader(bytes.NewReader(b))
|
||||
)
|
||||
if request, err := http.ReadRequest(r); err == nil {
|
||||
return &Protocol{
|
||||
Name: ProtocolHTTP,
|
||||
Version: Version{
|
||||
Major: request.ProtoMajor,
|
||||
Minor: request.ProtoMinor,
|
||||
Patch: -1,
|
||||
},
|
||||
}
|
||||
}
|
||||
r.Reset(bytes.NewReader(b))
|
||||
if response, err := http.ReadResponse(r, nil); err == nil {
|
||||
return &Protocol{
|
||||
Name: ProtocolHTTP,
|
||||
Version: Version{
|
||||
Major: response.ProtoMajor,
|
||||
Minor: response.ProtoMinor,
|
||||
Patch: -1,
|
||||
},
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
crlfIndex := bytes.IndexFunc(data, func(r rune) bool {
|
||||
return r == '\r' || r == '\n'
|
||||
})
|
||||
if crlfIndex == -1 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// A request has three, space-separated parts.
|
||||
part := bytes.Split(data[:crlfIndex], []byte(" "))
|
||||
if len(part) != 3 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// The last part starts with "HTTP/".
|
||||
if !bytes.HasPrefix(part[2], []byte("HTTP/1")) {
|
||||
return nil
|
||||
}
|
||||
|
||||
var version = Version{Patch: -1}
|
||||
fmt.Sscanf(string(part[2]), "HTTP/%d.%d ", &version.Major, &version.Minor)
|
||||
|
||||
return &Protocol{
|
||||
Name: ProtocolHTTP,
|
||||
Version: version,
|
||||
}
|
||||
}
|
||||
|
||||
func detectHTTPResponse(dir Direction, data []byte) *Protocol {
|
||||
if !dir.Contains(Server) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// A minimal response "HTTP/1.0 200 OK\r\n" is > 8 bytes.
|
||||
if len(data) < 8 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var version = Version{Patch: -1}
|
||||
fmt.Sscanf(string(data), "HTTP/%d.%d ", &version.Major, &version.Minor)
|
||||
|
||||
return &Protocol{
|
||||
Name: ProtocolHTTP,
|
||||
Version: version,
|
||||
}
|
||||
}
|
154
protocol/detect_http_test.go
Normal file
154
protocol/detect_http_test.go
Normal file
@@ -0,0 +1,154 @@
|
||||
package protocol
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestDetectHTTPRequest(t *testing.T) {
|
||||
atomicFormats.Store([]format{{Client, "", detectHTTPRequest}})
|
||||
|
||||
// A valid HTTP/1.0 GET request
|
||||
http10Request := []byte("GET /old-page.html HTTP/1.0\r\nUser-Agent: NCSA_Mosaic/1.0\r\n\r\n")
|
||||
|
||||
// A valid HTTP/1.1 GET request
|
||||
getRequest := []byte("GET /resource/item?id=123 HTTP/1.1\r\nHost: example.com\r\n\r\n")
|
||||
|
||||
// An invalid HTTP request
|
||||
sshBanner := []byte("SSH-2.0-OpenSSH_8.9p1 Ubuntu-3ubuntu0.4\r\n")
|
||||
|
||||
defer func() { Strict = false }()
|
||||
for _, strict := range []bool{false, true} {
|
||||
Strict = strict
|
||||
|
||||
name := "loose"
|
||||
if strict {
|
||||
name = "strict"
|
||||
}
|
||||
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Run("HTTP/1.0 GET", func(t *testing.T) {
|
||||
p, err := Detect(Client, http10Request)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
t.Logf("detected %s version %s", p.Name, p.Version)
|
||||
if p.Name != ProtocolHTTP {
|
||||
t.Fatalf("expected http protocol, got %s", p.Name)
|
||||
return
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("HTTP/1.1 GET", func(t *testing.T) {
|
||||
p, err := Detect(Client, getRequest)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
t.Logf("detected %s version %s", p.Name, p.Version)
|
||||
if p.Name != ProtocolHTTP {
|
||||
t.Fatalf("expected http protocol, got %s", p.Name)
|
||||
return
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Invalid SSH", func(t *testing.T) {
|
||||
_, err := Detect(Server, sshBanner)
|
||||
if !errors.Is(err, ErrUnknown) {
|
||||
t.Fatalf("expected unknown format, got error %T: %q", err, err)
|
||||
} else {
|
||||
t.Logf("error %q, as expected", err)
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDetectHTTPResponse(t *testing.T) {
|
||||
atomicFormats.Store([]format{{Server, "HTTP/?.? ", detectHTTPResponse}})
|
||||
|
||||
// A valid HTTP/1.0 403 Forbidden response
|
||||
http10Response := []byte("HTTP/1.0 403 Forbidden\r\nServer: CERN/3.0\r\n\r\n")
|
||||
|
||||
// A valid HTTP/1.1 200 OK response
|
||||
responseOK := []byte("HTTP/1.1 200 OK\r\nContent-Type: text/html\r\n\r\n<html>...</html>")
|
||||
|
||||
// A valid HTTP/1.1 404 Not Found response
|
||||
responseNotFound := []byte("HTTP/1.1 404 Not Found\r\nContent-Length: 0\r\n\r\n")
|
||||
|
||||
// An invalid HTTP GET request
|
||||
getRequest := []byte("GET /resource/item?id=123 HTTP/1.1\r\nHost: example.com\r\n\r\n")
|
||||
|
||||
// An invalid banner (SSH)
|
||||
sshBanner := []byte("SSH-2.0-OpenSSH_8.9p1 Ubuntu-3ubuntu0.4\r\n")
|
||||
|
||||
defer func() { Strict = false }()
|
||||
for _, strict := range []bool{false, true} {
|
||||
Strict = strict
|
||||
|
||||
name := "loose"
|
||||
if strict {
|
||||
name = "strict"
|
||||
}
|
||||
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Run("HTTP/1.0 403", func(t *testing.T) {
|
||||
p, err := Detect(Server, http10Response)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
t.Logf("detected %s version %s", p.Name, p.Version)
|
||||
if p.Name != ProtocolHTTP {
|
||||
t.Fatalf("expected http protocol, got %s", p.Name)
|
||||
return
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("HTTP/1.1 200", func(t *testing.T) {
|
||||
p, err := Detect(Server, responseOK)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
t.Logf("detected %s version %s", p.Name, p.Version)
|
||||
if p.Name != ProtocolHTTP {
|
||||
t.Fatalf("expected http protocol, got %s", p.Name)
|
||||
return
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("HTTP/1.1 404", func(t *testing.T) {
|
||||
p, err := Detect(Server, responseNotFound)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
t.Logf("detected %s version %s", p.Name, p.Version)
|
||||
if p.Name != ProtocolHTTP {
|
||||
t.Fatalf("expected http protocol, got %s", p.Name)
|
||||
return
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Invalid HTTP/1.1 GET", func(t *testing.T) {
|
||||
_, err := Detect(Server, getRequest)
|
||||
if !errors.Is(err, ErrUnknown) {
|
||||
t.Fatalf("expected unknown format, got error %T: %q", err, err)
|
||||
} else {
|
||||
t.Logf("error %q, as expected", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Invalid SSH", func(t *testing.T) {
|
||||
_, err := Detect(Server, sshBanner)
|
||||
if !errors.Is(err, ErrUnknown) {
|
||||
t.Fatalf("expected unknown format, got error %T: %q", err, err)
|
||||
} else {
|
||||
t.Logf("error %q, as expected", err)
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
51
protocol/detect_mysql.go
Normal file
51
protocol/detect_mysql.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package protocol
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
func init() {
|
||||
Register(Server, "\x0a", detectMySQL)
|
||||
}
|
||||
|
||||
func detectMySQL(dir Direction, data []byte) *Protocol {
|
||||
if len(data) < 7 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// The first byte of the handshake packet is the protocol version.
|
||||
// For MySQL, this is 10 (0x0A).
|
||||
if data[0] != 0x0A {
|
||||
return nil
|
||||
}
|
||||
|
||||
// After the protocol version, there is a null-terminated server version string.
|
||||
// We search for the null byte starting from the second byte (index 1).
|
||||
nullIndex := bytes.IndexByte(data[1:], 0x00)
|
||||
|
||||
// If no null byte is found, it's not a valid banner.
|
||||
if nullIndex == -1 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// The position of the null byte is relative to the start of the whole slice.
|
||||
// It's 1 (for the protocol byte) + nullIndex.
|
||||
serverVersionEndPos := 1 + nullIndex
|
||||
|
||||
// After the null-terminated version string, there must be at least 4 bytes
|
||||
// for the connection ID, plus more data for capabilities, auth, etc.
|
||||
// We'll check for the 4-byte connection ID as a minimum requirement.
|
||||
const connectionIDLength = 4
|
||||
if len(data) < serverVersionEndPos+1+connectionIDLength {
|
||||
return nil
|
||||
}
|
||||
|
||||
var version Version
|
||||
fmt.Sscanf(string(data[1:serverVersionEndPos]), "%d.%d.%d-%s", &version.Major, &version.Minor, &version.Patch, &version.Extra)
|
||||
|
||||
return &Protocol{
|
||||
Name: ProtocolMySQL,
|
||||
Version: version,
|
||||
}
|
||||
}
|
82
protocol/detect_mysql_test.go
Normal file
82
protocol/detect_mysql_test.go
Normal file
@@ -0,0 +1,82 @@
|
||||
package protocol
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestDetectMySQL(t *testing.T) {
|
||||
atomicFormats.Store([]format{{Server, "\x0a", detectMySQL}})
|
||||
|
||||
// 1. A valid MySQL 8.0 banner
|
||||
mysql8Banner := []byte{
|
||||
0x0a, 0x38, 0x2e, 0x30, 0x2e, 0x33, 0x32, 0x00, 0x0d, 0x00, 0x00, 0x00,
|
||||
0x04, 0x5a, 0x56, 0x5f, 0x3e, 0x6e, 0x76, 0x27, 0x00, 0xff, 0xff, 0xff,
|
||||
0x02, 0x00, 0xff, 0xc7, 0x15, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
0x00, 0x00, 0x00, 0x63, 0x61, 0x63, 0x68, 0x69, 0x6e, 0x67, 0x5f, 0x73,
|
||||
0x68, 0x61, 0x32, 0x5f, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x00,
|
||||
}
|
||||
|
||||
// 2. A valid MariaDB banner (protocol-compatible)
|
||||
mariaDBBanner := []byte{
|
||||
0x0a, 0x35, 0x2e, 0x35, 0x2e, 0x35, 0x2d, 0x31, 0x30, 0x2e, 0x36, 0x2e,
|
||||
0x35, 0x2d, 0x4d, 0x61, 0x72, 0x69, 0x61, 0x44, 0x42, 0x2d, 0x6c, 0x6f,
|
||||
0x67, 0x00, 0x1a, 0x00, 0x00, 0x00, 0x4e, 0x5c, 0x32, 0x7b, 0x45, 0x3b,
|
||||
0x40, 0x60, 0x00, 0xff, 0xf7, 0x08, 0x02, 0x00, 0xff, 0x81, 0x15, 0x00,
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x6d, 0x79, 0x73,
|
||||
0x71, 0x6c, 0x5f, 0x6e, 0x61, 0x74, 0x69, 0x76, 0x65, 0x5f, 0x70, 0x61,
|
||||
0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x00,
|
||||
}
|
||||
|
||||
// 3. An invalid banner (e.g., an HTTP request)
|
||||
httpBanner := []byte("GET / HTTP/1.1\r\nHost: example.com\r\n\r\n")
|
||||
|
||||
// 4. A short, invalid slice
|
||||
shortSlice := []byte{0x0a, 0x31, 0x32, 0x33}
|
||||
|
||||
// 5. A slice that starts correctly but is malformed (no null terminator)
|
||||
malformedSlice := []byte{0x0a, 0x38, 0x2e, 0x30, 0x2e, 0x30, 0x01, 0x02, 0x03, 0x04, 0x05}
|
||||
|
||||
t.Run("MySQL 8", func(t *testing.T) {
|
||||
p, _ := Detect(Server, mysql8Banner)
|
||||
if p == nil {
|
||||
t.Fatal("expected MySQL protocol, got nil")
|
||||
}
|
||||
t.Logf("detected %s version %s", p.Name, p.Version)
|
||||
})
|
||||
|
||||
t.Run("MariaDB", func(t *testing.T) {
|
||||
p, _ := Detect(Server, mariaDBBanner)
|
||||
if p == nil {
|
||||
t.Fatal("expected MySQL protocol, got nil")
|
||||
}
|
||||
t.Logf("detected %s version %s", p.Name, p.Version)
|
||||
})
|
||||
|
||||
t.Run("Invalid HTTP", func(t *testing.T) {
|
||||
_, err := Detect(Server, httpBanner)
|
||||
if !errors.Is(err, ErrUnknown) {
|
||||
t.Fatalf("expected unknown format, got error %T: %q", err, err)
|
||||
} else {
|
||||
t.Logf("error %q, as expected", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Too short", func(t *testing.T) {
|
||||
_, err := Detect(Server, shortSlice)
|
||||
if !errors.Is(err, ErrUnknown) {
|
||||
t.Fatalf("expected unknown format, got error %T: %q", err, err)
|
||||
} else {
|
||||
t.Logf("error %q, as expected", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Malformed", func(t *testing.T) {
|
||||
_, err := Detect(Server, malformedSlice)
|
||||
if !errors.Is(err, ErrUnknown) {
|
||||
t.Fatalf("expected unknown format, got error %T: %q", err, err)
|
||||
} else {
|
||||
t.Logf("error %q, as expected", err)
|
||||
}
|
||||
})
|
||||
}
|
70
protocol/detect_postgres.go
Normal file
70
protocol/detect_postgres.go
Normal file
@@ -0,0 +1,70 @@
|
||||
package protocol
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"log"
|
||||
)
|
||||
|
||||
func init() {
|
||||
registerPostgreSQL()
|
||||
}
|
||||
|
||||
func registerPostgreSQL() {
|
||||
Register(Server, "R\x00???", detectPostgreSQLServer) // Authentication request
|
||||
Register(Server, "K\x00???", detectPostgreSQLServer) // BackendKeyData
|
||||
Register(Server, "S\x00???", detectPostgreSQLServer) // ParameterStatus
|
||||
Register(Server, "Z\x00???", detectPostgreSQLServer) // ReadyForQuery
|
||||
Register(Server, "E\x00???", detectPostgreSQLServer) // ErrorResponse
|
||||
Register(Server, "N\x00???", detectPostgreSQLServer) // NoticeResponse
|
||||
Register(Client, "????\x00\x02\x00\x00", detectPostgreSQLClient) // Startup packet, protocol 2.0
|
||||
Register(Client, "????\x00\x03\x00\x00", detectPostgreSQLClient) // Startup packet, protocol 3.0
|
||||
}
|
||||
|
||||
func detectPostgreSQLClient(dir Direction, data []byte) *Protocol {
|
||||
// A client startup message needs at least 8 bytes (length + protocol version).
|
||||
if len(data) < 8 {
|
||||
return nil
|
||||
}
|
||||
|
||||
length := int(binary.BigEndian.Uint32(data[0:]))
|
||||
if len(data) != length {
|
||||
log.Printf("not postgres %q: %d != %d", data, len(data), length)
|
||||
return nil
|
||||
}
|
||||
|
||||
major := int(binary.BigEndian.Uint16(data[4:]))
|
||||
minor := int(binary.BigEndian.Uint16(data[6:]))
|
||||
if major == 2 || major == 3 {
|
||||
return &Protocol{
|
||||
Name: ProtocolPostgreSQL,
|
||||
Version: Version{
|
||||
Major: major,
|
||||
Minor: minor,
|
||||
Patch: -1,
|
||||
},
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func detectPostgreSQLServer(dir Direction, data []byte) *Protocol {
|
||||
// A server message needs at least 5 bytes (type + length).
|
||||
if len(data) < 5 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// All server messages (and subsequent client messages) are tagged with a single-byte type.
|
||||
firstByte := data[0]
|
||||
switch firstByte {
|
||||
case 'R', // Authentication request
|
||||
'K', // BackendKeyData
|
||||
'S', // ParameterStatus
|
||||
'Z', // ReadyForQuery
|
||||
'E', // ErrorResponse
|
||||
'N': // NoticeResponse
|
||||
return &Protocol{Name: ProtocolPostgreSQL}
|
||||
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
94
protocol/detect_postgres_test.go
Normal file
94
protocol/detect_postgres_test.go
Normal file
@@ -0,0 +1,94 @@
|
||||
package protocol
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestDetectPostgreSQLClient(t *testing.T) {
|
||||
atomicFormats.Store([]format{})
|
||||
registerPostgreSQL()
|
||||
|
||||
// 1. A valid PostgreSQL client startup message
|
||||
// Format: len (4b), proto (4b), params (n-bytes)
|
||||
// Here, user=mazeio, database=test
|
||||
// The message is: "user\0mazeio\0database\0test\0\0"
|
||||
// Total length: 4 (len) + 4 (proto) + 27 (params) = 35 (0x23)
|
||||
pgClientStartup := []byte{
|
||||
0x00, 0x00, 0x00, 0x23, // Length: 35
|
||||
0x00, 0x03, 0x00, 0x00, // Protocol Version 3.0
|
||||
'u', 's', 'e', 'r', 0x00, 'm', 'a', 'z', 'e', 'i', 'o', 0x00,
|
||||
'd', 'a', 't', 'a', 'b', 'a', 's', 'e', 0x00, 't', 'e', 's', 't', 0x00, 0x00,
|
||||
}
|
||||
|
||||
t.Run("Protocol 3.0", func(t *testing.T) {
|
||||
p, err := Detect(Client, pgClientStartup)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
t.Logf("detected %s version %s", p.Name, p.Version)
|
||||
if p.Name != ProtocolPostgreSQL {
|
||||
t.Fatalf("expected postgres protocol, got %s", p.Name)
|
||||
return
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestDetectPostgreSQLServer(t *testing.T) {
|
||||
atomicFormats.Store([]format{})
|
||||
registerPostgreSQL()
|
||||
|
||||
// A valid PostgreSQL server AuthenticationOk response
|
||||
// Format: type (1b), len (4b), content (4b)
|
||||
pgServerAuthOK := []byte{
|
||||
'R', // Type: Authentication
|
||||
0x00, 0x00, 0x00, 0x08, // Length: 8
|
||||
0x00, 0x00, 0x00, 0x00, // Auth OK (0)
|
||||
}
|
||||
|
||||
// A valid PostgreSQL server ErrorResponse
|
||||
pgServerError := []byte{
|
||||
'E', // Type: ErrorResponse
|
||||
0x00, 0x00, 0x00, 0x31, // Length
|
||||
'S', 'E', 'R', 'R', 'O', 'R', 0x00, // ... and so on
|
||||
}
|
||||
|
||||
// Invalid data (HTTP GET request)
|
||||
httpBanner := []byte("GET / HTTP/1.1\r\nHost: example.com\r\n\r\n")
|
||||
|
||||
t.Run("AuthenticationOk", func(t *testing.T) {
|
||||
p, err := Detect(Server, pgServerAuthOK)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
t.Logf("detected %s version %s", p.Name, p.Version)
|
||||
if p.Name != ProtocolPostgreSQL {
|
||||
t.Fatalf("expected postgres protocol, got %s", p.Name)
|
||||
return
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ErrorResponse", func(t *testing.T) {
|
||||
p, err := Detect(Server, pgServerError)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
t.Logf("detected %s version %s", p.Name, p.Version)
|
||||
if p.Name != ProtocolPostgreSQL {
|
||||
t.Fatalf("expected postgres protocol, got %s", p.Name)
|
||||
return
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Invalid HTTP", func(t *testing.T) {
|
||||
_, err := Detect(Server, httpBanner)
|
||||
if !errors.Is(err, ErrUnknown) {
|
||||
t.Fatalf("expected unknown format, got error %T: %q", err, err)
|
||||
} else {
|
||||
t.Logf("error %q, as expected", err)
|
||||
}
|
||||
})
|
||||
}
|
51
protocol/detect_ssh.go
Normal file
51
protocol/detect_ssh.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package protocol
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
)
|
||||
|
||||
// The required prefix for the SSH protocol identification line.
|
||||
const (
|
||||
ssh199Prefix = "SSH-1.99-"
|
||||
ssh20Prefix = "SSH-2.0-"
|
||||
)
|
||||
|
||||
func init() {
|
||||
Register(Both, "", detectSSH)
|
||||
}
|
||||
|
||||
func detectSSH(dir Direction, data []byte) *Protocol {
|
||||
// The data must be at least as long as the prefix itself.
|
||||
if len(data) < len(ssh20Prefix) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// The protocol allows for pre-banner text, so we have to check all lines.
|
||||
for _, line := range bytes.Split(data, []byte{'\n'}) {
|
||||
line = bytes.TrimSuffix(line, []byte{'\r'})
|
||||
if bytes.HasPrefix(line, []byte(ssh20Prefix)) {
|
||||
return &Protocol{
|
||||
Name: ProtocolSSH,
|
||||
Version: Version{
|
||||
Major: 2,
|
||||
Minor: 0,
|
||||
Patch: -1,
|
||||
Extra: string(line[len(ssh20Prefix):]),
|
||||
},
|
||||
}
|
||||
}
|
||||
if bytes.HasPrefix(line, []byte(ssh199Prefix)) {
|
||||
return &Protocol{
|
||||
Name: ProtocolSSH,
|
||||
Version: Version{
|
||||
Major: 1,
|
||||
Minor: 99,
|
||||
Patch: -1,
|
||||
Extra: string(line[len(ssh20Prefix):]),
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
103
protocol/detect_ssh_test.go
Normal file
103
protocol/detect_ssh_test.go
Normal file
@@ -0,0 +1,103 @@
|
||||
package protocol
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestDetectSSH(t *testing.T) {
|
||||
atomicFormats.Store([]format{{Both, "", detectSSH}})
|
||||
|
||||
// 1. A standard OpenSSH banner
|
||||
openSSHBanner := []byte("SSH-2.0-OpenSSH_8.9p1 Ubuntu-3ubuntu0.4\r\n")
|
||||
|
||||
// 2. An SSH banner with a pre-banner legal notice
|
||||
preBannerSSH := []byte(
|
||||
"*******************************************************************\r\n" +
|
||||
"* W A R N I N G *\r\n" +
|
||||
"* This system is for the use of authorized users only. *\r\n" +
|
||||
"*******************************************************************\r\n" +
|
||||
"SSH-2.0-OpenSSH_7.6p1\r\n",
|
||||
)
|
||||
|
||||
// 3. A different SSH implementation (Dropbear)
|
||||
dropbearBanner := []byte("SSH-2.0-dropbear_2020.81\r\n")
|
||||
|
||||
// 4. An invalid banner (e.g., the MySQL banner from the previous example)
|
||||
mysqlBanner := []byte{
|
||||
0x0a, 0x38, 0x2e, 0x30, 0x2e, 0x33, 0x32, 0x00, 0x0d, 0x00, 0x00, 0x00,
|
||||
}
|
||||
|
||||
// 5. A simple HTTP request
|
||||
httpBanner := []byte("GET / HTTP/1.1\r\nHost: example.com\r\n\r\n")
|
||||
|
||||
t.Run("OpenSSH client", func(t *testing.T) {
|
||||
p, err := Detect(Server, openSSHBanner)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
t.Logf("detected %s version %s", p.Name, p.Version)
|
||||
if p.Name != ProtocolSSH {
|
||||
t.Fatalf("expected ssh protocol, got %s", p.Name)
|
||||
return
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("OpenSSH server", func(t *testing.T) {
|
||||
p, err := Detect(Server, openSSHBanner)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
t.Logf("detected %s version %s", p.Name, p.Version)
|
||||
if p.Name != ProtocolSSH {
|
||||
t.Fatalf("expected ssh protocol, got %s", p.Name)
|
||||
return
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("OpenSSH server with banner", func(t *testing.T) {
|
||||
p, err := Detect(Server, preBannerSSH)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
t.Logf("detected %s version %s", p.Name, p.Version)
|
||||
if p.Name != ProtocolSSH {
|
||||
t.Fatalf("expected ssh protocol, got %s", p.Name)
|
||||
return
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Dropbear server", func(t *testing.T) {
|
||||
p, err := Detect(Server, dropbearBanner)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
t.Logf("detected %s version %s", p.Name, p.Version)
|
||||
if p.Name != ProtocolSSH {
|
||||
t.Fatalf("expected ssh protocol, got %s", p.Name)
|
||||
return
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Invalid MySQL banner", func(t *testing.T) {
|
||||
_, err := Detect(Server, mysqlBanner)
|
||||
if !errors.Is(err, ErrUnknown) {
|
||||
t.Fatalf("expected unknown format, got error %T: %q", err, err)
|
||||
} else {
|
||||
t.Logf("error %q, as expected", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Invalid HTTP banner", func(t *testing.T) {
|
||||
_, err := Detect(Server, httpBanner)
|
||||
if !errors.Is(err, ErrUnknown) {
|
||||
t.Fatalf("expected unknown format, got error %T: %q", err, err)
|
||||
} else {
|
||||
t.Logf("error %q, as expected", err)
|
||||
}
|
||||
})
|
||||
}
|
98
protocol/detect_tls.go
Normal file
98
protocol/detect_tls.go
Normal file
@@ -0,0 +1,98 @@
|
||||
package protocol
|
||||
|
||||
import (
|
||||
"golang.org/x/crypto/cryptobyte"
|
||||
|
||||
"git.maze.io/go/dpi"
|
||||
)
|
||||
|
||||
func init() {
|
||||
registerTLS()
|
||||
}
|
||||
|
||||
func registerTLS() {
|
||||
Register(Both, "\x16\x03\x00", detectTLS) // SSLv3
|
||||
Register(Both, "\x16\x03\x01", detectTLS) // TLSv1.0
|
||||
Register(Both, "\x16\x03\x02", detectTLS) // TLSv1.1
|
||||
Register(Both, "\x16\x03\x03", detectTLS) // TLSv1.2
|
||||
}
|
||||
|
||||
func detectTLS(dir Direction, data []byte) *Protocol {
|
||||
stream := cryptobyte.String(data)
|
||||
|
||||
// A TLS packet always has a content type (1 byte), version (2 bytes) and length (2 bytes).
|
||||
if len(stream) < 5 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check for TLS Handshake (type 22)
|
||||
var header struct {
|
||||
Type uint8
|
||||
Version uint16
|
||||
Length uint32
|
||||
}
|
||||
if !stream.ReadUint8(&header.Type) || header.Type != 0x16 {
|
||||
return nil
|
||||
}
|
||||
if !stream.ReadUint16(&header.Version) {
|
||||
return nil
|
||||
}
|
||||
if !stream.ReadUint24(&header.Length) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Detected SSL/TLS version
|
||||
var version dpi.TLSVersion
|
||||
|
||||
// Attempt to decode the full TLS Client Hello handshake
|
||||
if version == 0 {
|
||||
if hello, err := dpi.DecodeTLSClientHelloHandshake(data); err == nil {
|
||||
version = hello.Version
|
||||
}
|
||||
}
|
||||
|
||||
// Attempt to decode the full TLS Server Hello handshake
|
||||
if version == 0 {
|
||||
if hello, err := dpi.DecodeTLSServerHello(data); err == nil {
|
||||
version = hello.Version
|
||||
}
|
||||
}
|
||||
|
||||
// Attempt to decode at least the handshake protocol and version.
|
||||
if version == 0 && !Strict {
|
||||
var handshakeType uint8
|
||||
if stream.ReadUint8(&handshakeType) && (handshakeType == 1 || handshakeType == 2) {
|
||||
var (
|
||||
length uint32
|
||||
versionWord uint16
|
||||
)
|
||||
if stream.ReadUint24(&length) && stream.ReadUint16(&versionWord) {
|
||||
version = dpi.TLSVersion(versionWord)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to the version in the TLS record header, this is less accurate
|
||||
if version == 0 && !Strict {
|
||||
version = dpi.TLSVersion(header.Version)
|
||||
}
|
||||
|
||||
// We're "multi protocol", in that SSL is its own protocol
|
||||
if version == dpi.VersionSSL30 {
|
||||
return &Protocol{
|
||||
Name: ProtocolSSL,
|
||||
Version: Version{Major: 3, Minor: 0, Patch: -1},
|
||||
}
|
||||
} else if version >= dpi.VersionTLS10 && version <= dpi.VersionTLS13 {
|
||||
return &Protocol{
|
||||
Name: ProtocolTLS,
|
||||
Version: Version{Major: 1, Minor: int(uint8(version) - 1), Patch: -1},
|
||||
}
|
||||
} else if version >= dpi.VersionTLS13Draft && version <= dpi.VersionTLS13Draft23 {
|
||||
return &Protocol{
|
||||
Name: ProtocolTLS,
|
||||
Version: Version{Major: 1, Minor: 3, Patch: -1},
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
280
protocol/detect_tls_test.go
Normal file
280
protocol/detect_tls_test.go
Normal file
@@ -0,0 +1,280 @@
|
||||
package protocol
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestDetectTLS(t *testing.T) {
|
||||
atomicFormats.Store([]format{})
|
||||
registerTLS()
|
||||
|
||||
// A SSLv3 Client Hello
|
||||
sslV3ClientHello := testMustDecodeHexString(`
|
||||
16 03 00 00 51 01 00 00 4d 03
|
||||
00 50 42 b2 29 1f cf 52 a0 94 87 05 e7 0b 63 08
|
||||
12 a2 6c 59 f7 f5 72 2b 57 14 a7 07 95 cb ce e5
|
||||
e4 00 00 26 00 04 00 05 00 2f 00 33 00 32 00 0a
|
||||
fe ff 00 16 00 13 00 66 00 09 fe fe 00 15 00 12
|
||||
00 03 00 08 00 06 00 14 00 11 01 00`)
|
||||
|
||||
// A synthesized TLS 1.1 ClientHello
|
||||
tls11ClientHello := []byte{
|
||||
// --- Record Layer (5 bytes) ---
|
||||
0x16, // Content Type: Handshake (22)
|
||||
0x03, 0x02, // Version: TLS 1.1 (Major=3, Minor=2)
|
||||
0x00, 0x45, // Length of the handshake message below (69 bytes)
|
||||
|
||||
// --- Handshake Protocol: ClientHello (69 bytes) ---
|
||||
0x01, // Handshake Type: ClientHello (1)
|
||||
0x00, 0x00, 0x41, // Length of the rest of the message (65 bytes)
|
||||
0x03, 0x02, // Client Version: TLS 1.1 (Major=3, Minor=2)
|
||||
|
||||
// Random (32 bytes):
|
||||
// In TLS 1.1+, the entire 32 bytes are fully random.
|
||||
// The timestamp structure from SSLv3 is removed.
|
||||
0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88,
|
||||
0x99, 0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF, 0x00,
|
||||
0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88,
|
||||
0x99, 0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF, 0x00,
|
||||
|
||||
// Session ID
|
||||
0x00, // Session ID Length: 0 (new session)
|
||||
|
||||
// Cipher Suites
|
||||
0x00, 0x04, // Cipher Suites Length: 4 bytes (2 suites)
|
||||
0x00, 0x2F, // Cipher Suite: TLS_RSA_WITH_AES_128_CBC_SHA
|
||||
0x00, 0x35, // Cipher Suite: TLS_RSA_WITH_AES_256_CBC_SHA
|
||||
|
||||
// Compression Methods
|
||||
0x01, // Compression Methods Length: 1 byte
|
||||
0x00, // Compression Method: NULL (0)
|
||||
|
||||
// --- Extensions (20 bytes) ---
|
||||
0x00, 0x14, // Extensions Length: 20 bytes
|
||||
// Extension: Server Name Indication (SNI)
|
||||
0x00, 0x00, // Extension Type: server_name (0)
|
||||
0x00, 0x10, // Extension Length: 16 bytes
|
||||
0x00, 0x0E, // Server Name List Length: 14 bytes
|
||||
0x00, // Server Name Type: host_name (0)
|
||||
0x00, 0x0B, // Server Name Length: 11 bytes
|
||||
// Server Name: "example.com"
|
||||
'e', 'x', 'a', 'm', 'p', 'l', 'e', '.', 'c', 'o', 'm',
|
||||
}
|
||||
|
||||
// A synthesized partial TLS 1.1 ClientHello
|
||||
tls11ClientHelloPartial := []byte{
|
||||
// --- Record Layer (5 bytes) ---
|
||||
0x16, // Content Type: Handshake (22)
|
||||
0x03, 0x02, // Version: TLS 1.1 (Major=3, Minor=2)
|
||||
0x00, 0x45, // Length of the handshake message below (69 bytes)
|
||||
|
||||
// --- Handshake Protocol: ClientHello (69 bytes) ---
|
||||
0x01, // Handshake Type: ClientHello (1)
|
||||
0x00, 0x00, 0x41, // Length of the rest of the message (65 bytes)
|
||||
0x03, 0x02, // Client Version: TLS 1.1 (Major=3, Minor=2)
|
||||
}
|
||||
|
||||
// A synthesized TLSv1.2 ClientHello
|
||||
tls12ClientHello := []byte{
|
||||
// --- Record Layer (5 bytes) ---
|
||||
0x16, // Content Type: Handshake (22)
|
||||
0x03, 0x03, // Version: TLS 1.2 (Major=3, Minor=3)
|
||||
0x00, 0x61, // Length of handshake message below (97 bytes)
|
||||
|
||||
// --- Handshake Protocol: ClientHello (97 bytes) ---
|
||||
0x01, // Handshake Type: ClientHello (1)
|
||||
0x00, 0x00, 0x5D, // Length of the rest of the message (93 bytes)
|
||||
0x03, 0x03, // Client Version: TLS 1.2 (Major=3, Minor=3)
|
||||
|
||||
// Random (32 bytes)
|
||||
0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF, 0x00, 0x11,
|
||||
0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99,
|
||||
0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF, 0x00, 0x11,
|
||||
0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99,
|
||||
|
||||
// Session ID
|
||||
0x00, // Session ID Length: 0 (new session)
|
||||
|
||||
// Cipher Suites (using modern GCM suites)
|
||||
0x00, 0x04, // Cipher Suites Length: 4 bytes (2 suites)
|
||||
0xC0, 0x2F, // Cipher Suite: TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256
|
||||
0xC0, 0x30, // Cipher Suite: TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384
|
||||
|
||||
// Compression Methods
|
||||
0x01, // Compression Methods Length: 1 byte
|
||||
0x00, // Compression Method: NULL (0)
|
||||
|
||||
// --- Extensions (48 bytes) ---
|
||||
0x00, 0x30, // Extensions Length: 48 bytes
|
||||
// Extension: Server Name Indication (SNI) (20 bytes)
|
||||
0x00, 0x00, 0x00, 0x10, 0x00, 0x0E, 0x00, 0x00, 0x0B,
|
||||
'e', 'x', 'a', 'm', 'p', 'l', 'e', '.', 'c', 'o', 'm',
|
||||
// Extension: Signature Algorithms (10 bytes) - CRITICAL for TLS 1.2
|
||||
0x00, 0x0D, // Extension Type: signature_algorithms (13)
|
||||
0x00, 0x06, // Extension Length: 6 bytes
|
||||
0x00, 0x04, // Hash/Signature Algorithm List Length: 4 bytes
|
||||
0x04, 0x01, // Algorithm: rsa_pkcs1_sha256
|
||||
0x05, 0x01, // Algorithm: rsa_pkcs1_sha384
|
||||
// Extension: Application-Layer Protocol Negotiation (ALPN) (18 bytes)
|
||||
0x00, 0x10, // Extension Type: application_layer_protocol_negotiation (16)
|
||||
0x00, 0x0E, // Extension Length: 14 bytes
|
||||
0x00, 0x0C, // ALPN Extension Length: 12 bytes
|
||||
// ALPN Protocol: "h2" (HTTP/2)
|
||||
0x02, 'h', '2',
|
||||
// ALPN Protocol: "http/1.1"
|
||||
0x08, 'h', 't', 't', 'p', '/', '1', '.', '1',
|
||||
}
|
||||
|
||||
// A valid TLS 1.3 Client Hello (captured from a real connection)
|
||||
tls13ClientHello := []byte{
|
||||
0x16, // Content Type: Handshake (22)
|
||||
0x03, 0x01, // Version: TLS 1.0 (for compatibility in a 1.3 hello)
|
||||
0x01, 0x3a, // Length
|
||||
0x01, 0x00, 0x01, 0x36, 0x03, 0x03, 0xb1,
|
||||
0x40, 0xd3, 0xf1, 0x7d, 0xa3, 0xb8, 0x33, 0xac, 0xad, 0x21, 0x79, 0x9c,
|
||||
0xbe, 0x39, 0x96, 0x08, 0x49, 0x3b, 0x53, 0x75, 0xa0, 0x1b, 0xee, 0x6e,
|
||||
0x6a, 0xbe, 0x6c, 0x41, 0xdf, 0x6c, 0xf4, 0x20, 0xa4, 0xaa, 0x0c, 0xca,
|
||||
0xd4, 0x37, 0x76, 0x5f, 0x49, 0xc6, 0x06, 0x9b, 0xac, 0x90, 0x89, 0x76,
|
||||
0x1c, 0xc7, 0xc4, 0x12, 0xb4, 0x4a, 0xe0, 0x27, 0x72, 0x89, 0x97, 0x85,
|
||||
0x76, 0xf8, 0xc8, 0x83, 0x00, 0x62, 0x13, 0x03, 0x13, 0x02, 0x13, 0x01,
|
||||
0xcc, 0xa9, 0xcc, 0xa8, 0xcc, 0xaa, 0xc0, 0x30, 0xc0, 0x2c, 0xc0, 0x28,
|
||||
0xc0, 0x24, 0xc0, 0x14, 0xc0, 0x0a, 0x00, 0x9f, 0x00, 0x6b, 0x00, 0x39,
|
||||
0xff, 0x85, 0x00, 0xc4, 0x00, 0x88, 0x00, 0x81, 0x00, 0x9d, 0x00, 0x3d,
|
||||
0x00, 0x35, 0x00, 0xc0, 0x00, 0x84, 0xc0, 0x2f, 0xc0, 0x2b, 0xc0, 0x27,
|
||||
0xc0, 0x23, 0xc0, 0x13, 0xc0, 0x09, 0x00, 0x9e, 0x00, 0x67, 0x00, 0x33,
|
||||
0x00, 0xbe, 0x00, 0x45, 0x00, 0x9c, 0x00, 0x3c, 0x00, 0x2f, 0x00, 0xba,
|
||||
0x00, 0x41, 0xc0, 0x11, 0xc0, 0x07, 0x00, 0x05, 0x00, 0x04, 0xc0, 0x12,
|
||||
0xc0, 0x08, 0x00, 0x16, 0x00, 0x0a, 0x00, 0xff, 0x01, 0x00, 0x00, 0x8b,
|
||||
0x00, 0x2b, 0x00, 0x09, 0x08, 0x03, 0x04, 0x03, 0x03, 0x03, 0x02, 0x03,
|
||||
0x01, 0x00, 0x33, 0x00, 0x26, 0x00, 0x24, 0x00, 0x1d, 0x00, 0x20, 0x2c,
|
||||
0x4b, 0xaa, 0xb4, 0xb3, 0xc8, 0x93, 0xcd, 0x5c, 0x24, 0xb9, 0x9b, 0xd4,
|
||||
0x59, 0x04, 0xfe, 0x69, 0xaf, 0x68, 0xb9, 0xa6, 0x36, 0xbb, 0xab, 0x87,
|
||||
0xfa, 0x15, 0x59, 0xea, 0xdd, 0x38, 0x68, 0x00, 0x00, 0x00, 0x0e, 0x00,
|
||||
0x0c, 0x00, 0x00, 0x09, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x68, 0x6f, 0x73,
|
||||
0x74, 0x00, 0x0b, 0x00, 0x02, 0x01, 0x00, 0x00, 0x0a, 0x00, 0x0a, 0x00,
|
||||
0x08, 0x00, 0x1d, 0x00, 0x17, 0x00, 0x18, 0x00, 0x19, 0x00, 0x0d, 0x00,
|
||||
0x18, 0x00, 0x16, 0x08, 0x06, 0x06, 0x01, 0x06, 0x03, 0x08, 0x05, 0x05,
|
||||
0x01, 0x05, 0x03, 0x08, 0x04, 0x04, 0x01, 0x04, 0x03, 0x02, 0x01, 0x02,
|
||||
0x03, 0x00, 0x10, 0x00, 0x0e, 0x00, 0x0c, 0x02, 0x68, 0x32, 0x08, 0x68,
|
||||
0x74, 0x74, 0x70, 0x2f, 0x31, 0x2e, 0x31,
|
||||
}
|
||||
|
||||
// Invalid data (Postgres client startup)
|
||||
pgClientStartup := []byte{
|
||||
0x00, 0x00, 0x00, 0x25, 0x00, 0x03, 0x00, 0x00,
|
||||
}
|
||||
|
||||
defer func() { Strict = false }()
|
||||
for _, strict := range []bool{false, true} {
|
||||
Strict = strict
|
||||
|
||||
name := "loose"
|
||||
if strict {
|
||||
name = "strict"
|
||||
}
|
||||
|
||||
t.Run(name, func(t *testing.T) {
|
||||
|
||||
t.Run("SSLv3 Client Hello", func(t *testing.T) {
|
||||
p, err := Detect(Client, sslV3ClientHello)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
t.Logf("detected %s version %s", p.Name, p.Version)
|
||||
if p.Name != ProtocolSSL {
|
||||
t.Fatalf("expected ssl protocol, got %s", p.Name)
|
||||
return
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("TLS 1.1 Client Hello", func(t *testing.T) {
|
||||
p, err := Detect(Client, tls11ClientHello)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
t.Logf("detected %s version %s", p.Name, p.Version)
|
||||
if p.Name != ProtocolTLS {
|
||||
t.Fatalf("expected tls protocol, got %s", p.Name)
|
||||
return
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("TLS 1.1 partial Client Hello", func(t *testing.T) {
|
||||
p, err := Detect(Client, tls11ClientHelloPartial)
|
||||
if strict {
|
||||
if !errors.Is(err, ErrUnknown) {
|
||||
t.Fatalf("expected unknown format, got error %T: %q", err, err)
|
||||
} else {
|
||||
t.Logf("error %q, as expected", err)
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
t.Logf("detected %s version %s", p.Name, p.Version)
|
||||
if p.Name != ProtocolTLS {
|
||||
t.Fatalf("expected tls protocol, got %s", p.Name)
|
||||
return
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("TLS 1.2 Client Hello", func(t *testing.T) {
|
||||
p, err := Detect(Client, tls12ClientHello)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
t.Logf("detected %s version %s", p.Name, p.Version)
|
||||
if p.Name != ProtocolTLS {
|
||||
t.Fatalf("expected tls protocol, got %s", p.Name)
|
||||
return
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("TLS 1.3 Client Hello", func(t *testing.T) {
|
||||
p, err := Detect(Client, tls13ClientHello)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
t.Logf("detected %s version %s", p.Name, p.Version)
|
||||
if p.Name != ProtocolTLS {
|
||||
t.Fatalf("expected tls protocol, got %s", p.Name)
|
||||
return
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Invalid PostgreSQL", func(t *testing.T) {
|
||||
_, err := Detect(Server, pgClientStartup)
|
||||
if !errors.Is(err, ErrUnknown) {
|
||||
t.Fatalf("expected unknown format, got error %T: %q", err, err)
|
||||
} else {
|
||||
t.Logf("error %q, as expected", err)
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func testDecodeHexString(s string) ([]byte, error) {
|
||||
s = strings.TrimSpace(s)
|
||||
s = strings.ReplaceAll(s, " ", "")
|
||||
s = strings.ReplaceAll(s, "\n", "")
|
||||
s = strings.ReplaceAll(s, "\t", "")
|
||||
return hex.DecodeString(s)
|
||||
}
|
||||
|
||||
func testMustDecodeHexString(s string) []byte {
|
||||
b, err := testDecodeHexString(s)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return b
|
||||
}
|
128
protocol/intercept.go
Normal file
128
protocol/intercept.go
Normal file
@@ -0,0 +1,128 @@
|
||||
package protocol
|
||||
|
||||
import (
|
||||
"net"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Intercepted is the result returned by [Interceptor.Detect].
|
||||
type Intercepted struct {
|
||||
Direction Direction
|
||||
Protocol *Protocol
|
||||
Error error
|
||||
}
|
||||
|
||||
// Interceptor intercepts reads from client or server.
|
||||
type Interceptor struct {
|
||||
clientBytes chan []byte
|
||||
clientReader *readInterceptor
|
||||
serverBytes chan []byte
|
||||
serverReader *readInterceptor
|
||||
}
|
||||
|
||||
// NewInterceptor creates a new (transparent) protocol interceptor.
|
||||
func NewInterceptor() *Interceptor {
|
||||
return &Interceptor{
|
||||
clientBytes: make(chan []byte, 1),
|
||||
serverBytes: make(chan []byte, 1),
|
||||
}
|
||||
}
|
||||
|
||||
type readInterceptor struct {
|
||||
net.Conn
|
||||
bytes chan []byte
|
||||
once atomic.Bool
|
||||
}
|
||||
|
||||
func newReadInterceptor(c net.Conn, bytes chan []byte) *readInterceptor {
|
||||
return &readInterceptor{
|
||||
Conn: c,
|
||||
bytes: bytes,
|
||||
}
|
||||
}
|
||||
|
||||
// Cancel any future Read interceptions and closes the channel.
|
||||
func (r *readInterceptor) Cancel() {
|
||||
if r == nil {
|
||||
return
|
||||
}
|
||||
r.once.Store(true)
|
||||
close(r.bytes)
|
||||
}
|
||||
|
||||
func (r *readInterceptor) Read(p []byte) (n int, err error) {
|
||||
if r.once.CompareAndSwap(false, true) {
|
||||
if n, err = r.Conn.Read(p); n > 0 {
|
||||
// We create a copy, since the Read caller may modify p
|
||||
// immediately after reading.
|
||||
data := make([]byte, n)
|
||||
copy(data, p[:n])
|
||||
// Buffer the bytes in the channel.
|
||||
r.bytes <- data
|
||||
}
|
||||
return
|
||||
}
|
||||
return r.Conn.Read(p)
|
||||
}
|
||||
|
||||
// Client binds the client connection to the interceptor.
|
||||
func (i *Interceptor) Client(c net.Conn) net.Conn {
|
||||
if ri, ok := c.(*readInterceptor); ok {
|
||||
return ri
|
||||
}
|
||||
i.clientReader = newReadInterceptor(c, i.clientBytes)
|
||||
return i.clientReader
|
||||
}
|
||||
|
||||
// Server binds the server connection to the interceptor.
|
||||
func (i *Interceptor) Server(c net.Conn) net.Conn {
|
||||
if ri, ok := c.(*readInterceptor); ok {
|
||||
return ri
|
||||
}
|
||||
i.serverReader = newReadInterceptor(c, i.serverBytes)
|
||||
return i.serverReader
|
||||
}
|
||||
|
||||
// Detect runs protocol detection on the previously bound Client and Server connection.
|
||||
//
|
||||
// It waits until either the client or the server performs a read operation,
|
||||
// which is then used for running protocol detection. If the read operation
|
||||
// takes longer than timeout, an error is returned.
|
||||
//
|
||||
// The returned channel always yields one result and is then closed.
|
||||
func (i *Interceptor) Detect(timeout time.Duration) <-chan *Intercepted {
|
||||
var interceptc = make(chan *Intercepted, 1)
|
||||
|
||||
go func() {
|
||||
// Make sure all channels are closed once we finish processing.
|
||||
defer close(interceptc)
|
||||
defer i.clientReader.Cancel()
|
||||
defer i.serverReader.Cancel()
|
||||
|
||||
select {
|
||||
case <-time.After(timeout): // timeout
|
||||
interceptc <- &Intercepted{
|
||||
Error: ErrTimeout,
|
||||
}
|
||||
|
||||
case data := <-i.clientBytes: // client sent banner
|
||||
p, err := Detect(Client, data)
|
||||
interceptc <- &Intercepted{
|
||||
Direction: Client,
|
||||
Protocol: p,
|
||||
Error: err,
|
||||
}
|
||||
|
||||
case data := <-i.serverBytes: // server sent banner
|
||||
p, err := Detect(Server, data)
|
||||
interceptc <- &Intercepted{
|
||||
Direction: Server,
|
||||
Protocol: p,
|
||||
Error: err,
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return interceptc
|
||||
}
|
76
protocol/limit.go
Normal file
76
protocol/limit.go
Normal file
@@ -0,0 +1,76 @@
|
||||
package protocol
|
||||
|
||||
import (
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
// AcceptFunc receives a direction and a detected protocol.
|
||||
type AcceptFunc func(Direction, *Protocol) error
|
||||
|
||||
// Limit the connection protocol, by running a detection after either side sends
|
||||
// a banner within timeout.
|
||||
//
|
||||
// If no protocol could be detected, the accept function is called with a nil
|
||||
// argument to check if we should proceed.
|
||||
//
|
||||
// If the accept function returns false, the connection will be closed.
|
||||
func Limit(conn net.Conn, accept AcceptFunc) net.Conn {
|
||||
if accept == nil {
|
||||
// Nothing to do here.
|
||||
return conn
|
||||
}
|
||||
|
||||
return &connLimiter{
|
||||
Conn: conn,
|
||||
accept: accept,
|
||||
}
|
||||
}
|
||||
|
||||
type connLimiter struct {
|
||||
net.Conn
|
||||
accept AcceptFunc
|
||||
acceptOnce sync.Once
|
||||
acceptError atomic.Value
|
||||
}
|
||||
|
||||
func (l *connLimiter) init(readData, writeData []byte) {
|
||||
l.acceptOnce.Do(func() {
|
||||
var (
|
||||
dir Direction
|
||||
data []byte
|
||||
)
|
||||
if readData != nil {
|
||||
// init called by initial read
|
||||
dir, data = Server, readData
|
||||
} else {
|
||||
// init called by initial write
|
||||
dir, data = Client, writeData
|
||||
}
|
||||
protocol, _ := Detect(dir, data)
|
||||
if err := l.accept(dir, protocol); err != nil {
|
||||
l.acceptError.Store(err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (l *connLimiter) Read(p []byte) (n int, err error) {
|
||||
var ok bool
|
||||
if err, ok = l.acceptError.Load().(error); ok && err != nil {
|
||||
return
|
||||
}
|
||||
if n, err = l.Conn.Read(p); n > 0 {
|
||||
l.init(p[:n], nil)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (l *connLimiter) Write(p []byte) (n int, err error) {
|
||||
l.init(nil, p)
|
||||
var ok bool
|
||||
if err, ok = l.acceptError.Load().(error); ok && err != nil {
|
||||
return
|
||||
}
|
||||
return l.Conn.Write(p)
|
||||
}
|
47
protocol/protocol.go
Normal file
47
protocol/protocol.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package protocol
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Protocols supported by this package.
|
||||
const (
|
||||
ProtocolDNS = "dns"
|
||||
ProtocolHTTP = "http"
|
||||
ProtocolMySQL = "mysql"
|
||||
ProtocolPostgreSQL = "postgresql"
|
||||
ProtocolSSH = "ssh"
|
||||
ProtocolSSL = "ssl"
|
||||
ProtocolTLS = "tls"
|
||||
)
|
||||
|
||||
type Protocol struct {
|
||||
Name string
|
||||
Version Version
|
||||
}
|
||||
|
||||
type Version struct {
|
||||
Major int
|
||||
Minor int
|
||||
Patch int
|
||||
Extra string
|
||||
}
|
||||
|
||||
func (v Version) String() string {
|
||||
p := make([]string, 0, 3)
|
||||
if v.Major >= 0 {
|
||||
p = append(p, strconv.Itoa(v.Major))
|
||||
if v.Minor >= 0 {
|
||||
p = append(p, strconv.Itoa(v.Minor))
|
||||
if v.Patch >= 0 {
|
||||
p = append(p, strconv.Itoa(v.Patch))
|
||||
}
|
||||
}
|
||||
}
|
||||
s := strings.Join(p, ".")
|
||||
if v.Extra != "" {
|
||||
return s + "-" + v.Extra
|
||||
}
|
||||
return s
|
||||
}
|
Reference in New Issue
Block a user