Initial import

This commit is contained in:
2025-10-08 20:53:56 +02:00
commit 2081d684ed
25 changed files with 3486 additions and 0 deletions

113
protocol/detect.go Normal file
View File

@@ -0,0 +1,113 @@
package protocol
import (
"errors"
"fmt"
"sync"
"sync/atomic"
)
// Strict mode requires a full, compliant packet to be captured. This is only
// implemented by some detectors.
var Strict bool
// Common errors.
var (
ErrTimeout = errors.New("timeout")
ErrUnknown = errors.New("unknown protocol")
)
// Direction indicates the communcation direction.
type Direction int
// Directions supported by this package.
const (
Unknown Direction = iota
Client
Server
Both
)
func (dir Direction) Contains(other Direction) bool {
switch dir {
case Client:
return other == Client || other == Both
case Server:
return other == Server || other == Both
case Both:
return other == Client || other == Server
default:
return false
}
}
var directionName = map[Direction]string{
Client: "client",
Server: "server",
Both: "both",
}
func (dir Direction) String() string {
if s, ok := directionName[dir]; ok {
return s
}
return fmt.Sprintf("invalid (%d)", int(dir))
}
type format struct {
dir Direction
magic string
detect DetectFunc
}
// Formats is the list of registered formats.
var (
formatsMu sync.Mutex
atomicFormats atomic.Value
)
type DetectFunc func(Direction, []byte) *Protocol
func Register(dir Direction, magic string, detect DetectFunc) {
formatsMu.Lock()
formats, _ := atomicFormats.Load().([]format)
atomicFormats.Store(append(formats, format{dir, magic, detect}))
formatsMu.Unlock()
}
func matchMagic(magic string, data []byte) bool {
// Empty magic means the detector will always run.
if len(magic) == 0 {
return true
}
// The buffer should contain at least the same number of bytes
// as our magic.
if len(data) < len(magic) {
return false
}
// Match bytes in magic with bytes in data.
for i, b := range []byte(magic) {
if b != '?' && data[i] != b {
return false
}
}
return true
}
// Detect a protocol based on the provided data.
func Detect(dir Direction, data []byte) (*Protocol, error) {
formats, _ := atomicFormats.Load().([]format)
for _, f := range formats {
if f.dir.Contains(dir) {
// Check the buffer to see if we have sufficient bytes
if matchMagic(f.magic, data) {
if p := f.detect(dir, data); p != nil {
return p, nil
}
}
}
}
return nil, ErrUnknown
}

94
protocol/detect_http.go Normal file
View File

@@ -0,0 +1,94 @@
package protocol
import (
"bufio"
"bytes"
"fmt"
"net/http"
)
func init() {
Register(Client, "", detectHTTPRequest)
Register(Server, "HTTP/?.", detectHTTPResponse)
}
func detectHTTPRequest(dir Direction, data []byte) *Protocol {
// A minimal request "GET / HTTP/1.0\r\n" is > 8 bytes.
if len(data) < 8 {
return nil
}
if Strict {
var (
b = append(data, '\r', '\n')
r = bufio.NewReader(bytes.NewReader(b))
)
if request, err := http.ReadRequest(r); err == nil {
return &Protocol{
Name: ProtocolHTTP,
Version: Version{
Major: request.ProtoMajor,
Minor: request.ProtoMinor,
Patch: -1,
},
}
}
r.Reset(bytes.NewReader(b))
if response, err := http.ReadResponse(r, nil); err == nil {
return &Protocol{
Name: ProtocolHTTP,
Version: Version{
Major: response.ProtoMajor,
Minor: response.ProtoMinor,
Patch: -1,
},
}
}
return nil
}
crlfIndex := bytes.IndexFunc(data, func(r rune) bool {
return r == '\r' || r == '\n'
})
if crlfIndex == -1 {
return nil
}
// A request has three, space-separated parts.
part := bytes.Split(data[:crlfIndex], []byte(" "))
if len(part) != 3 {
return nil
}
// The last part starts with "HTTP/".
if !bytes.HasPrefix(part[2], []byte("HTTP/1")) {
return nil
}
var version = Version{Patch: -1}
fmt.Sscanf(string(part[2]), "HTTP/%d.%d ", &version.Major, &version.Minor)
return &Protocol{
Name: ProtocolHTTP,
Version: version,
}
}
func detectHTTPResponse(dir Direction, data []byte) *Protocol {
if !dir.Contains(Server) {
return nil
}
// A minimal response "HTTP/1.0 200 OK\r\n" is > 8 bytes.
if len(data) < 8 {
return nil
}
var version = Version{Patch: -1}
fmt.Sscanf(string(data), "HTTP/%d.%d ", &version.Major, &version.Minor)
return &Protocol{
Name: ProtocolHTTP,
Version: version,
}
}

View File

@@ -0,0 +1,154 @@
package protocol
import (
"errors"
"testing"
)
func TestDetectHTTPRequest(t *testing.T) {
atomicFormats.Store([]format{{Client, "", detectHTTPRequest}})
// A valid HTTP/1.0 GET request
http10Request := []byte("GET /old-page.html HTTP/1.0\r\nUser-Agent: NCSA_Mosaic/1.0\r\n\r\n")
// A valid HTTP/1.1 GET request
getRequest := []byte("GET /resource/item?id=123 HTTP/1.1\r\nHost: example.com\r\n\r\n")
// An invalid HTTP request
sshBanner := []byte("SSH-2.0-OpenSSH_8.9p1 Ubuntu-3ubuntu0.4\r\n")
defer func() { Strict = false }()
for _, strict := range []bool{false, true} {
Strict = strict
name := "loose"
if strict {
name = "strict"
}
t.Run(name, func(t *testing.T) {
t.Run("HTTP/1.0 GET", func(t *testing.T) {
p, err := Detect(Client, http10Request)
if err != nil {
t.Fatal(err)
return
}
t.Logf("detected %s version %s", p.Name, p.Version)
if p.Name != ProtocolHTTP {
t.Fatalf("expected http protocol, got %s", p.Name)
return
}
})
t.Run("HTTP/1.1 GET", func(t *testing.T) {
p, err := Detect(Client, getRequest)
if err != nil {
t.Fatal(err)
return
}
t.Logf("detected %s version %s", p.Name, p.Version)
if p.Name != ProtocolHTTP {
t.Fatalf("expected http protocol, got %s", p.Name)
return
}
})
t.Run("Invalid SSH", func(t *testing.T) {
_, err := Detect(Server, sshBanner)
if !errors.Is(err, ErrUnknown) {
t.Fatalf("expected unknown format, got error %T: %q", err, err)
} else {
t.Logf("error %q, as expected", err)
}
})
})
}
}
func TestDetectHTTPResponse(t *testing.T) {
atomicFormats.Store([]format{{Server, "HTTP/?.? ", detectHTTPResponse}})
// A valid HTTP/1.0 403 Forbidden response
http10Response := []byte("HTTP/1.0 403 Forbidden\r\nServer: CERN/3.0\r\n\r\n")
// A valid HTTP/1.1 200 OK response
responseOK := []byte("HTTP/1.1 200 OK\r\nContent-Type: text/html\r\n\r\n<html>...</html>")
// A valid HTTP/1.1 404 Not Found response
responseNotFound := []byte("HTTP/1.1 404 Not Found\r\nContent-Length: 0\r\n\r\n")
// An invalid HTTP GET request
getRequest := []byte("GET /resource/item?id=123 HTTP/1.1\r\nHost: example.com\r\n\r\n")
// An invalid banner (SSH)
sshBanner := []byte("SSH-2.0-OpenSSH_8.9p1 Ubuntu-3ubuntu0.4\r\n")
defer func() { Strict = false }()
for _, strict := range []bool{false, true} {
Strict = strict
name := "loose"
if strict {
name = "strict"
}
t.Run(name, func(t *testing.T) {
t.Run("HTTP/1.0 403", func(t *testing.T) {
p, err := Detect(Server, http10Response)
if err != nil {
t.Fatal(err)
return
}
t.Logf("detected %s version %s", p.Name, p.Version)
if p.Name != ProtocolHTTP {
t.Fatalf("expected http protocol, got %s", p.Name)
return
}
})
t.Run("HTTP/1.1 200", func(t *testing.T) {
p, err := Detect(Server, responseOK)
if err != nil {
t.Fatal(err)
return
}
t.Logf("detected %s version %s", p.Name, p.Version)
if p.Name != ProtocolHTTP {
t.Fatalf("expected http protocol, got %s", p.Name)
return
}
})
t.Run("HTTP/1.1 404", func(t *testing.T) {
p, err := Detect(Server, responseNotFound)
if err != nil {
t.Fatal(err)
return
}
t.Logf("detected %s version %s", p.Name, p.Version)
if p.Name != ProtocolHTTP {
t.Fatalf("expected http protocol, got %s", p.Name)
return
}
})
t.Run("Invalid HTTP/1.1 GET", func(t *testing.T) {
_, err := Detect(Server, getRequest)
if !errors.Is(err, ErrUnknown) {
t.Fatalf("expected unknown format, got error %T: %q", err, err)
} else {
t.Logf("error %q, as expected", err)
}
})
t.Run("Invalid SSH", func(t *testing.T) {
_, err := Detect(Server, sshBanner)
if !errors.Is(err, ErrUnknown) {
t.Fatalf("expected unknown format, got error %T: %q", err, err)
} else {
t.Logf("error %q, as expected", err)
}
})
})
}
}

51
protocol/detect_mysql.go Normal file
View File

@@ -0,0 +1,51 @@
package protocol
import (
"bytes"
"fmt"
)
func init() {
Register(Server, "\x0a", detectMySQL)
}
func detectMySQL(dir Direction, data []byte) *Protocol {
if len(data) < 7 {
return nil
}
// The first byte of the handshake packet is the protocol version.
// For MySQL, this is 10 (0x0A).
if data[0] != 0x0A {
return nil
}
// After the protocol version, there is a null-terminated server version string.
// We search for the null byte starting from the second byte (index 1).
nullIndex := bytes.IndexByte(data[1:], 0x00)
// If no null byte is found, it's not a valid banner.
if nullIndex == -1 {
return nil
}
// The position of the null byte is relative to the start of the whole slice.
// It's 1 (for the protocol byte) + nullIndex.
serverVersionEndPos := 1 + nullIndex
// After the null-terminated version string, there must be at least 4 bytes
// for the connection ID, plus more data for capabilities, auth, etc.
// We'll check for the 4-byte connection ID as a minimum requirement.
const connectionIDLength = 4
if len(data) < serverVersionEndPos+1+connectionIDLength {
return nil
}
var version Version
fmt.Sscanf(string(data[1:serverVersionEndPos]), "%d.%d.%d-%s", &version.Major, &version.Minor, &version.Patch, &version.Extra)
return &Protocol{
Name: ProtocolMySQL,
Version: version,
}
}

View File

@@ -0,0 +1,82 @@
package protocol
import (
"errors"
"testing"
)
func TestDetectMySQL(t *testing.T) {
atomicFormats.Store([]format{{Server, "\x0a", detectMySQL}})
// 1. A valid MySQL 8.0 banner
mysql8Banner := []byte{
0x0a, 0x38, 0x2e, 0x30, 0x2e, 0x33, 0x32, 0x00, 0x0d, 0x00, 0x00, 0x00,
0x04, 0x5a, 0x56, 0x5f, 0x3e, 0x6e, 0x76, 0x27, 0x00, 0xff, 0xff, 0xff,
0x02, 0x00, 0xff, 0xc7, 0x15, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x63, 0x61, 0x63, 0x68, 0x69, 0x6e, 0x67, 0x5f, 0x73,
0x68, 0x61, 0x32, 0x5f, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x00,
}
// 2. A valid MariaDB banner (protocol-compatible)
mariaDBBanner := []byte{
0x0a, 0x35, 0x2e, 0x35, 0x2e, 0x35, 0x2d, 0x31, 0x30, 0x2e, 0x36, 0x2e,
0x35, 0x2d, 0x4d, 0x61, 0x72, 0x69, 0x61, 0x44, 0x42, 0x2d, 0x6c, 0x6f,
0x67, 0x00, 0x1a, 0x00, 0x00, 0x00, 0x4e, 0x5c, 0x32, 0x7b, 0x45, 0x3b,
0x40, 0x60, 0x00, 0xff, 0xf7, 0x08, 0x02, 0x00, 0xff, 0x81, 0x15, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x6d, 0x79, 0x73,
0x71, 0x6c, 0x5f, 0x6e, 0x61, 0x74, 0x69, 0x76, 0x65, 0x5f, 0x70, 0x61,
0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x00,
}
// 3. An invalid banner (e.g., an HTTP request)
httpBanner := []byte("GET / HTTP/1.1\r\nHost: example.com\r\n\r\n")
// 4. A short, invalid slice
shortSlice := []byte{0x0a, 0x31, 0x32, 0x33}
// 5. A slice that starts correctly but is malformed (no null terminator)
malformedSlice := []byte{0x0a, 0x38, 0x2e, 0x30, 0x2e, 0x30, 0x01, 0x02, 0x03, 0x04, 0x05}
t.Run("MySQL 8", func(t *testing.T) {
p, _ := Detect(Server, mysql8Banner)
if p == nil {
t.Fatal("expected MySQL protocol, got nil")
}
t.Logf("detected %s version %s", p.Name, p.Version)
})
t.Run("MariaDB", func(t *testing.T) {
p, _ := Detect(Server, mariaDBBanner)
if p == nil {
t.Fatal("expected MySQL protocol, got nil")
}
t.Logf("detected %s version %s", p.Name, p.Version)
})
t.Run("Invalid HTTP", func(t *testing.T) {
_, err := Detect(Server, httpBanner)
if !errors.Is(err, ErrUnknown) {
t.Fatalf("expected unknown format, got error %T: %q", err, err)
} else {
t.Logf("error %q, as expected", err)
}
})
t.Run("Too short", func(t *testing.T) {
_, err := Detect(Server, shortSlice)
if !errors.Is(err, ErrUnknown) {
t.Fatalf("expected unknown format, got error %T: %q", err, err)
} else {
t.Logf("error %q, as expected", err)
}
})
t.Run("Malformed", func(t *testing.T) {
_, err := Detect(Server, malformedSlice)
if !errors.Is(err, ErrUnknown) {
t.Fatalf("expected unknown format, got error %T: %q", err, err)
} else {
t.Logf("error %q, as expected", err)
}
})
}

View File

@@ -0,0 +1,70 @@
package protocol
import (
"encoding/binary"
"log"
)
func init() {
registerPostgreSQL()
}
func registerPostgreSQL() {
Register(Server, "R\x00???", detectPostgreSQLServer) // Authentication request
Register(Server, "K\x00???", detectPostgreSQLServer) // BackendKeyData
Register(Server, "S\x00???", detectPostgreSQLServer) // ParameterStatus
Register(Server, "Z\x00???", detectPostgreSQLServer) // ReadyForQuery
Register(Server, "E\x00???", detectPostgreSQLServer) // ErrorResponse
Register(Server, "N\x00???", detectPostgreSQLServer) // NoticeResponse
Register(Client, "????\x00\x02\x00\x00", detectPostgreSQLClient) // Startup packet, protocol 2.0
Register(Client, "????\x00\x03\x00\x00", detectPostgreSQLClient) // Startup packet, protocol 3.0
}
func detectPostgreSQLClient(dir Direction, data []byte) *Protocol {
// A client startup message needs at least 8 bytes (length + protocol version).
if len(data) < 8 {
return nil
}
length := int(binary.BigEndian.Uint32(data[0:]))
if len(data) != length {
log.Printf("not postgres %q: %d != %d", data, len(data), length)
return nil
}
major := int(binary.BigEndian.Uint16(data[4:]))
minor := int(binary.BigEndian.Uint16(data[6:]))
if major == 2 || major == 3 {
return &Protocol{
Name: ProtocolPostgreSQL,
Version: Version{
Major: major,
Minor: minor,
Patch: -1,
},
}
}
return nil
}
func detectPostgreSQLServer(dir Direction, data []byte) *Protocol {
// A server message needs at least 5 bytes (type + length).
if len(data) < 5 {
return nil
}
// All server messages (and subsequent client messages) are tagged with a single-byte type.
firstByte := data[0]
switch firstByte {
case 'R', // Authentication request
'K', // BackendKeyData
'S', // ParameterStatus
'Z', // ReadyForQuery
'E', // ErrorResponse
'N': // NoticeResponse
return &Protocol{Name: ProtocolPostgreSQL}
default:
return nil
}
}

View File

@@ -0,0 +1,94 @@
package protocol
import (
"errors"
"testing"
)
func TestDetectPostgreSQLClient(t *testing.T) {
atomicFormats.Store([]format{})
registerPostgreSQL()
// 1. A valid PostgreSQL client startup message
// Format: len (4b), proto (4b), params (n-bytes)
// Here, user=mazeio, database=test
// The message is: "user\0mazeio\0database\0test\0\0"
// Total length: 4 (len) + 4 (proto) + 27 (params) = 35 (0x23)
pgClientStartup := []byte{
0x00, 0x00, 0x00, 0x23, // Length: 35
0x00, 0x03, 0x00, 0x00, // Protocol Version 3.0
'u', 's', 'e', 'r', 0x00, 'm', 'a', 'z', 'e', 'i', 'o', 0x00,
'd', 'a', 't', 'a', 'b', 'a', 's', 'e', 0x00, 't', 'e', 's', 't', 0x00, 0x00,
}
t.Run("Protocol 3.0", func(t *testing.T) {
p, err := Detect(Client, pgClientStartup)
if err != nil {
t.Fatal(err)
return
}
t.Logf("detected %s version %s", p.Name, p.Version)
if p.Name != ProtocolPostgreSQL {
t.Fatalf("expected postgres protocol, got %s", p.Name)
return
}
})
}
func TestDetectPostgreSQLServer(t *testing.T) {
atomicFormats.Store([]format{})
registerPostgreSQL()
// A valid PostgreSQL server AuthenticationOk response
// Format: type (1b), len (4b), content (4b)
pgServerAuthOK := []byte{
'R', // Type: Authentication
0x00, 0x00, 0x00, 0x08, // Length: 8
0x00, 0x00, 0x00, 0x00, // Auth OK (0)
}
// A valid PostgreSQL server ErrorResponse
pgServerError := []byte{
'E', // Type: ErrorResponse
0x00, 0x00, 0x00, 0x31, // Length
'S', 'E', 'R', 'R', 'O', 'R', 0x00, // ... and so on
}
// Invalid data (HTTP GET request)
httpBanner := []byte("GET / HTTP/1.1\r\nHost: example.com\r\n\r\n")
t.Run("AuthenticationOk", func(t *testing.T) {
p, err := Detect(Server, pgServerAuthOK)
if err != nil {
t.Fatal(err)
return
}
t.Logf("detected %s version %s", p.Name, p.Version)
if p.Name != ProtocolPostgreSQL {
t.Fatalf("expected postgres protocol, got %s", p.Name)
return
}
})
t.Run("ErrorResponse", func(t *testing.T) {
p, err := Detect(Server, pgServerError)
if err != nil {
t.Fatal(err)
return
}
t.Logf("detected %s version %s", p.Name, p.Version)
if p.Name != ProtocolPostgreSQL {
t.Fatalf("expected postgres protocol, got %s", p.Name)
return
}
})
t.Run("Invalid HTTP", func(t *testing.T) {
_, err := Detect(Server, httpBanner)
if !errors.Is(err, ErrUnknown) {
t.Fatalf("expected unknown format, got error %T: %q", err, err)
} else {
t.Logf("error %q, as expected", err)
}
})
}

51
protocol/detect_ssh.go Normal file
View File

@@ -0,0 +1,51 @@
package protocol
import (
"bytes"
)
// The required prefix for the SSH protocol identification line.
const (
ssh199Prefix = "SSH-1.99-"
ssh20Prefix = "SSH-2.0-"
)
func init() {
Register(Both, "", detectSSH)
}
func detectSSH(dir Direction, data []byte) *Protocol {
// The data must be at least as long as the prefix itself.
if len(data) < len(ssh20Prefix) {
return nil
}
// The protocol allows for pre-banner text, so we have to check all lines.
for _, line := range bytes.Split(data, []byte{'\n'}) {
line = bytes.TrimSuffix(line, []byte{'\r'})
if bytes.HasPrefix(line, []byte(ssh20Prefix)) {
return &Protocol{
Name: ProtocolSSH,
Version: Version{
Major: 2,
Minor: 0,
Patch: -1,
Extra: string(line[len(ssh20Prefix):]),
},
}
}
if bytes.HasPrefix(line, []byte(ssh199Prefix)) {
return &Protocol{
Name: ProtocolSSH,
Version: Version{
Major: 1,
Minor: 99,
Patch: -1,
Extra: string(line[len(ssh20Prefix):]),
},
}
}
}
return nil
}

103
protocol/detect_ssh_test.go Normal file
View File

@@ -0,0 +1,103 @@
package protocol
import (
"errors"
"testing"
)
func TestDetectSSH(t *testing.T) {
atomicFormats.Store([]format{{Both, "", detectSSH}})
// 1. A standard OpenSSH banner
openSSHBanner := []byte("SSH-2.0-OpenSSH_8.9p1 Ubuntu-3ubuntu0.4\r\n")
// 2. An SSH banner with a pre-banner legal notice
preBannerSSH := []byte(
"*******************************************************************\r\n" +
"* W A R N I N G *\r\n" +
"* This system is for the use of authorized users only. *\r\n" +
"*******************************************************************\r\n" +
"SSH-2.0-OpenSSH_7.6p1\r\n",
)
// 3. A different SSH implementation (Dropbear)
dropbearBanner := []byte("SSH-2.0-dropbear_2020.81\r\n")
// 4. An invalid banner (e.g., the MySQL banner from the previous example)
mysqlBanner := []byte{
0x0a, 0x38, 0x2e, 0x30, 0x2e, 0x33, 0x32, 0x00, 0x0d, 0x00, 0x00, 0x00,
}
// 5. A simple HTTP request
httpBanner := []byte("GET / HTTP/1.1\r\nHost: example.com\r\n\r\n")
t.Run("OpenSSH client", func(t *testing.T) {
p, err := Detect(Server, openSSHBanner)
if err != nil {
t.Fatal(err)
return
}
t.Logf("detected %s version %s", p.Name, p.Version)
if p.Name != ProtocolSSH {
t.Fatalf("expected ssh protocol, got %s", p.Name)
return
}
})
t.Run("OpenSSH server", func(t *testing.T) {
p, err := Detect(Server, openSSHBanner)
if err != nil {
t.Fatal(err)
return
}
t.Logf("detected %s version %s", p.Name, p.Version)
if p.Name != ProtocolSSH {
t.Fatalf("expected ssh protocol, got %s", p.Name)
return
}
})
t.Run("OpenSSH server with banner", func(t *testing.T) {
p, err := Detect(Server, preBannerSSH)
if err != nil {
t.Fatal(err)
return
}
t.Logf("detected %s version %s", p.Name, p.Version)
if p.Name != ProtocolSSH {
t.Fatalf("expected ssh protocol, got %s", p.Name)
return
}
})
t.Run("Dropbear server", func(t *testing.T) {
p, err := Detect(Server, dropbearBanner)
if err != nil {
t.Fatal(err)
return
}
t.Logf("detected %s version %s", p.Name, p.Version)
if p.Name != ProtocolSSH {
t.Fatalf("expected ssh protocol, got %s", p.Name)
return
}
})
t.Run("Invalid MySQL banner", func(t *testing.T) {
_, err := Detect(Server, mysqlBanner)
if !errors.Is(err, ErrUnknown) {
t.Fatalf("expected unknown format, got error %T: %q", err, err)
} else {
t.Logf("error %q, as expected", err)
}
})
t.Run("Invalid HTTP banner", func(t *testing.T) {
_, err := Detect(Server, httpBanner)
if !errors.Is(err, ErrUnknown) {
t.Fatalf("expected unknown format, got error %T: %q", err, err)
} else {
t.Logf("error %q, as expected", err)
}
})
}

98
protocol/detect_tls.go Normal file
View File

@@ -0,0 +1,98 @@
package protocol
import (
"golang.org/x/crypto/cryptobyte"
"git.maze.io/go/dpi"
)
func init() {
registerTLS()
}
func registerTLS() {
Register(Both, "\x16\x03\x00", detectTLS) // SSLv3
Register(Both, "\x16\x03\x01", detectTLS) // TLSv1.0
Register(Both, "\x16\x03\x02", detectTLS) // TLSv1.1
Register(Both, "\x16\x03\x03", detectTLS) // TLSv1.2
}
func detectTLS(dir Direction, data []byte) *Protocol {
stream := cryptobyte.String(data)
// A TLS packet always has a content type (1 byte), version (2 bytes) and length (2 bytes).
if len(stream) < 5 {
return nil
}
// Check for TLS Handshake (type 22)
var header struct {
Type uint8
Version uint16
Length uint32
}
if !stream.ReadUint8(&header.Type) || header.Type != 0x16 {
return nil
}
if !stream.ReadUint16(&header.Version) {
return nil
}
if !stream.ReadUint24(&header.Length) {
return nil
}
// Detected SSL/TLS version
var version dpi.TLSVersion
// Attempt to decode the full TLS Client Hello handshake
if version == 0 {
if hello, err := dpi.DecodeTLSClientHelloHandshake(data); err == nil {
version = hello.Version
}
}
// Attempt to decode the full TLS Server Hello handshake
if version == 0 {
if hello, err := dpi.DecodeTLSServerHello(data); err == nil {
version = hello.Version
}
}
// Attempt to decode at least the handshake protocol and version.
if version == 0 && !Strict {
var handshakeType uint8
if stream.ReadUint8(&handshakeType) && (handshakeType == 1 || handshakeType == 2) {
var (
length uint32
versionWord uint16
)
if stream.ReadUint24(&length) && stream.ReadUint16(&versionWord) {
version = dpi.TLSVersion(versionWord)
}
}
}
// Fall back to the version in the TLS record header, this is less accurate
if version == 0 && !Strict {
version = dpi.TLSVersion(header.Version)
}
// We're "multi protocol", in that SSL is its own protocol
if version == dpi.VersionSSL30 {
return &Protocol{
Name: ProtocolSSL,
Version: Version{Major: 3, Minor: 0, Patch: -1},
}
} else if version >= dpi.VersionTLS10 && version <= dpi.VersionTLS13 {
return &Protocol{
Name: ProtocolTLS,
Version: Version{Major: 1, Minor: int(uint8(version) - 1), Patch: -1},
}
} else if version >= dpi.VersionTLS13Draft && version <= dpi.VersionTLS13Draft23 {
return &Protocol{
Name: ProtocolTLS,
Version: Version{Major: 1, Minor: 3, Patch: -1},
}
}
return nil
}

280
protocol/detect_tls_test.go Normal file
View File

@@ -0,0 +1,280 @@
package protocol
import (
"encoding/hex"
"errors"
"strings"
"testing"
)
func TestDetectTLS(t *testing.T) {
atomicFormats.Store([]format{})
registerTLS()
// A SSLv3 Client Hello
sslV3ClientHello := testMustDecodeHexString(`
16 03 00 00 51 01 00 00 4d 03
00 50 42 b2 29 1f cf 52 a0 94 87 05 e7 0b 63 08
12 a2 6c 59 f7 f5 72 2b 57 14 a7 07 95 cb ce e5
e4 00 00 26 00 04 00 05 00 2f 00 33 00 32 00 0a
fe ff 00 16 00 13 00 66 00 09 fe fe 00 15 00 12
00 03 00 08 00 06 00 14 00 11 01 00`)
// A synthesized TLS 1.1 ClientHello
tls11ClientHello := []byte{
// --- Record Layer (5 bytes) ---
0x16, // Content Type: Handshake (22)
0x03, 0x02, // Version: TLS 1.1 (Major=3, Minor=2)
0x00, 0x45, // Length of the handshake message below (69 bytes)
// --- Handshake Protocol: ClientHello (69 bytes) ---
0x01, // Handshake Type: ClientHello (1)
0x00, 0x00, 0x41, // Length of the rest of the message (65 bytes)
0x03, 0x02, // Client Version: TLS 1.1 (Major=3, Minor=2)
// Random (32 bytes):
// In TLS 1.1+, the entire 32 bytes are fully random.
// The timestamp structure from SSLv3 is removed.
0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88,
0x99, 0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF, 0x00,
0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88,
0x99, 0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF, 0x00,
// Session ID
0x00, // Session ID Length: 0 (new session)
// Cipher Suites
0x00, 0x04, // Cipher Suites Length: 4 bytes (2 suites)
0x00, 0x2F, // Cipher Suite: TLS_RSA_WITH_AES_128_CBC_SHA
0x00, 0x35, // Cipher Suite: TLS_RSA_WITH_AES_256_CBC_SHA
// Compression Methods
0x01, // Compression Methods Length: 1 byte
0x00, // Compression Method: NULL (0)
// --- Extensions (20 bytes) ---
0x00, 0x14, // Extensions Length: 20 bytes
// Extension: Server Name Indication (SNI)
0x00, 0x00, // Extension Type: server_name (0)
0x00, 0x10, // Extension Length: 16 bytes
0x00, 0x0E, // Server Name List Length: 14 bytes
0x00, // Server Name Type: host_name (0)
0x00, 0x0B, // Server Name Length: 11 bytes
// Server Name: "example.com"
'e', 'x', 'a', 'm', 'p', 'l', 'e', '.', 'c', 'o', 'm',
}
// A synthesized partial TLS 1.1 ClientHello
tls11ClientHelloPartial := []byte{
// --- Record Layer (5 bytes) ---
0x16, // Content Type: Handshake (22)
0x03, 0x02, // Version: TLS 1.1 (Major=3, Minor=2)
0x00, 0x45, // Length of the handshake message below (69 bytes)
// --- Handshake Protocol: ClientHello (69 bytes) ---
0x01, // Handshake Type: ClientHello (1)
0x00, 0x00, 0x41, // Length of the rest of the message (65 bytes)
0x03, 0x02, // Client Version: TLS 1.1 (Major=3, Minor=2)
}
// A synthesized TLSv1.2 ClientHello
tls12ClientHello := []byte{
// --- Record Layer (5 bytes) ---
0x16, // Content Type: Handshake (22)
0x03, 0x03, // Version: TLS 1.2 (Major=3, Minor=3)
0x00, 0x61, // Length of handshake message below (97 bytes)
// --- Handshake Protocol: ClientHello (97 bytes) ---
0x01, // Handshake Type: ClientHello (1)
0x00, 0x00, 0x5D, // Length of the rest of the message (93 bytes)
0x03, 0x03, // Client Version: TLS 1.2 (Major=3, Minor=3)
// Random (32 bytes)
0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF, 0x00, 0x11,
0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99,
0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF, 0x00, 0x11,
0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99,
// Session ID
0x00, // Session ID Length: 0 (new session)
// Cipher Suites (using modern GCM suites)
0x00, 0x04, // Cipher Suites Length: 4 bytes (2 suites)
0xC0, 0x2F, // Cipher Suite: TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256
0xC0, 0x30, // Cipher Suite: TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384
// Compression Methods
0x01, // Compression Methods Length: 1 byte
0x00, // Compression Method: NULL (0)
// --- Extensions (48 bytes) ---
0x00, 0x30, // Extensions Length: 48 bytes
// Extension: Server Name Indication (SNI) (20 bytes)
0x00, 0x00, 0x00, 0x10, 0x00, 0x0E, 0x00, 0x00, 0x0B,
'e', 'x', 'a', 'm', 'p', 'l', 'e', '.', 'c', 'o', 'm',
// Extension: Signature Algorithms (10 bytes) - CRITICAL for TLS 1.2
0x00, 0x0D, // Extension Type: signature_algorithms (13)
0x00, 0x06, // Extension Length: 6 bytes
0x00, 0x04, // Hash/Signature Algorithm List Length: 4 bytes
0x04, 0x01, // Algorithm: rsa_pkcs1_sha256
0x05, 0x01, // Algorithm: rsa_pkcs1_sha384
// Extension: Application-Layer Protocol Negotiation (ALPN) (18 bytes)
0x00, 0x10, // Extension Type: application_layer_protocol_negotiation (16)
0x00, 0x0E, // Extension Length: 14 bytes
0x00, 0x0C, // ALPN Extension Length: 12 bytes
// ALPN Protocol: "h2" (HTTP/2)
0x02, 'h', '2',
// ALPN Protocol: "http/1.1"
0x08, 'h', 't', 't', 'p', '/', '1', '.', '1',
}
// A valid TLS 1.3 Client Hello (captured from a real connection)
tls13ClientHello := []byte{
0x16, // Content Type: Handshake (22)
0x03, 0x01, // Version: TLS 1.0 (for compatibility in a 1.3 hello)
0x01, 0x3a, // Length
0x01, 0x00, 0x01, 0x36, 0x03, 0x03, 0xb1,
0x40, 0xd3, 0xf1, 0x7d, 0xa3, 0xb8, 0x33, 0xac, 0xad, 0x21, 0x79, 0x9c,
0xbe, 0x39, 0x96, 0x08, 0x49, 0x3b, 0x53, 0x75, 0xa0, 0x1b, 0xee, 0x6e,
0x6a, 0xbe, 0x6c, 0x41, 0xdf, 0x6c, 0xf4, 0x20, 0xa4, 0xaa, 0x0c, 0xca,
0xd4, 0x37, 0x76, 0x5f, 0x49, 0xc6, 0x06, 0x9b, 0xac, 0x90, 0x89, 0x76,
0x1c, 0xc7, 0xc4, 0x12, 0xb4, 0x4a, 0xe0, 0x27, 0x72, 0x89, 0x97, 0x85,
0x76, 0xf8, 0xc8, 0x83, 0x00, 0x62, 0x13, 0x03, 0x13, 0x02, 0x13, 0x01,
0xcc, 0xa9, 0xcc, 0xa8, 0xcc, 0xaa, 0xc0, 0x30, 0xc0, 0x2c, 0xc0, 0x28,
0xc0, 0x24, 0xc0, 0x14, 0xc0, 0x0a, 0x00, 0x9f, 0x00, 0x6b, 0x00, 0x39,
0xff, 0x85, 0x00, 0xc4, 0x00, 0x88, 0x00, 0x81, 0x00, 0x9d, 0x00, 0x3d,
0x00, 0x35, 0x00, 0xc0, 0x00, 0x84, 0xc0, 0x2f, 0xc0, 0x2b, 0xc0, 0x27,
0xc0, 0x23, 0xc0, 0x13, 0xc0, 0x09, 0x00, 0x9e, 0x00, 0x67, 0x00, 0x33,
0x00, 0xbe, 0x00, 0x45, 0x00, 0x9c, 0x00, 0x3c, 0x00, 0x2f, 0x00, 0xba,
0x00, 0x41, 0xc0, 0x11, 0xc0, 0x07, 0x00, 0x05, 0x00, 0x04, 0xc0, 0x12,
0xc0, 0x08, 0x00, 0x16, 0x00, 0x0a, 0x00, 0xff, 0x01, 0x00, 0x00, 0x8b,
0x00, 0x2b, 0x00, 0x09, 0x08, 0x03, 0x04, 0x03, 0x03, 0x03, 0x02, 0x03,
0x01, 0x00, 0x33, 0x00, 0x26, 0x00, 0x24, 0x00, 0x1d, 0x00, 0x20, 0x2c,
0x4b, 0xaa, 0xb4, 0xb3, 0xc8, 0x93, 0xcd, 0x5c, 0x24, 0xb9, 0x9b, 0xd4,
0x59, 0x04, 0xfe, 0x69, 0xaf, 0x68, 0xb9, 0xa6, 0x36, 0xbb, 0xab, 0x87,
0xfa, 0x15, 0x59, 0xea, 0xdd, 0x38, 0x68, 0x00, 0x00, 0x00, 0x0e, 0x00,
0x0c, 0x00, 0x00, 0x09, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x68, 0x6f, 0x73,
0x74, 0x00, 0x0b, 0x00, 0x02, 0x01, 0x00, 0x00, 0x0a, 0x00, 0x0a, 0x00,
0x08, 0x00, 0x1d, 0x00, 0x17, 0x00, 0x18, 0x00, 0x19, 0x00, 0x0d, 0x00,
0x18, 0x00, 0x16, 0x08, 0x06, 0x06, 0x01, 0x06, 0x03, 0x08, 0x05, 0x05,
0x01, 0x05, 0x03, 0x08, 0x04, 0x04, 0x01, 0x04, 0x03, 0x02, 0x01, 0x02,
0x03, 0x00, 0x10, 0x00, 0x0e, 0x00, 0x0c, 0x02, 0x68, 0x32, 0x08, 0x68,
0x74, 0x74, 0x70, 0x2f, 0x31, 0x2e, 0x31,
}
// Invalid data (Postgres client startup)
pgClientStartup := []byte{
0x00, 0x00, 0x00, 0x25, 0x00, 0x03, 0x00, 0x00,
}
defer func() { Strict = false }()
for _, strict := range []bool{false, true} {
Strict = strict
name := "loose"
if strict {
name = "strict"
}
t.Run(name, func(t *testing.T) {
t.Run("SSLv3 Client Hello", func(t *testing.T) {
p, err := Detect(Client, sslV3ClientHello)
if err != nil {
t.Fatal(err)
return
}
t.Logf("detected %s version %s", p.Name, p.Version)
if p.Name != ProtocolSSL {
t.Fatalf("expected ssl protocol, got %s", p.Name)
return
}
})
t.Run("TLS 1.1 Client Hello", func(t *testing.T) {
p, err := Detect(Client, tls11ClientHello)
if err != nil {
t.Fatal(err)
return
}
t.Logf("detected %s version %s", p.Name, p.Version)
if p.Name != ProtocolTLS {
t.Fatalf("expected tls protocol, got %s", p.Name)
return
}
})
t.Run("TLS 1.1 partial Client Hello", func(t *testing.T) {
p, err := Detect(Client, tls11ClientHelloPartial)
if strict {
if !errors.Is(err, ErrUnknown) {
t.Fatalf("expected unknown format, got error %T: %q", err, err)
} else {
t.Logf("error %q, as expected", err)
}
} else {
if err != nil {
t.Fatal(err)
return
}
t.Logf("detected %s version %s", p.Name, p.Version)
if p.Name != ProtocolTLS {
t.Fatalf("expected tls protocol, got %s", p.Name)
return
}
}
})
t.Run("TLS 1.2 Client Hello", func(t *testing.T) {
p, err := Detect(Client, tls12ClientHello)
if err != nil {
t.Fatal(err)
return
}
t.Logf("detected %s version %s", p.Name, p.Version)
if p.Name != ProtocolTLS {
t.Fatalf("expected tls protocol, got %s", p.Name)
return
}
})
t.Run("TLS 1.3 Client Hello", func(t *testing.T) {
p, err := Detect(Client, tls13ClientHello)
if err != nil {
t.Fatal(err)
return
}
t.Logf("detected %s version %s", p.Name, p.Version)
if p.Name != ProtocolTLS {
t.Fatalf("expected tls protocol, got %s", p.Name)
return
}
})
t.Run("Invalid PostgreSQL", func(t *testing.T) {
_, err := Detect(Server, pgClientStartup)
if !errors.Is(err, ErrUnknown) {
t.Fatalf("expected unknown format, got error %T: %q", err, err)
} else {
t.Logf("error %q, as expected", err)
}
})
})
}
}
func testDecodeHexString(s string) ([]byte, error) {
s = strings.TrimSpace(s)
s = strings.ReplaceAll(s, " ", "")
s = strings.ReplaceAll(s, "\n", "")
s = strings.ReplaceAll(s, "\t", "")
return hex.DecodeString(s)
}
func testMustDecodeHexString(s string) []byte {
b, err := testDecodeHexString(s)
if err != nil {
panic(err)
}
return b
}

128
protocol/intercept.go Normal file
View File

@@ -0,0 +1,128 @@
package protocol
import (
"net"
"sync/atomic"
"time"
)
// Intercepted is the result returned by [Interceptor.Detect].
type Intercepted struct {
Direction Direction
Protocol *Protocol
Error error
}
// Interceptor intercepts reads from client or server.
type Interceptor struct {
clientBytes chan []byte
clientReader *readInterceptor
serverBytes chan []byte
serverReader *readInterceptor
}
// NewInterceptor creates a new (transparent) protocol interceptor.
func NewInterceptor() *Interceptor {
return &Interceptor{
clientBytes: make(chan []byte, 1),
serverBytes: make(chan []byte, 1),
}
}
type readInterceptor struct {
net.Conn
bytes chan []byte
once atomic.Bool
}
func newReadInterceptor(c net.Conn, bytes chan []byte) *readInterceptor {
return &readInterceptor{
Conn: c,
bytes: bytes,
}
}
// Cancel any future Read interceptions and closes the channel.
func (r *readInterceptor) Cancel() {
if r == nil {
return
}
r.once.Store(true)
close(r.bytes)
}
func (r *readInterceptor) Read(p []byte) (n int, err error) {
if r.once.CompareAndSwap(false, true) {
if n, err = r.Conn.Read(p); n > 0 {
// We create a copy, since the Read caller may modify p
// immediately after reading.
data := make([]byte, n)
copy(data, p[:n])
// Buffer the bytes in the channel.
r.bytes <- data
}
return
}
return r.Conn.Read(p)
}
// Client binds the client connection to the interceptor.
func (i *Interceptor) Client(c net.Conn) net.Conn {
if ri, ok := c.(*readInterceptor); ok {
return ri
}
i.clientReader = newReadInterceptor(c, i.clientBytes)
return i.clientReader
}
// Server binds the server connection to the interceptor.
func (i *Interceptor) Server(c net.Conn) net.Conn {
if ri, ok := c.(*readInterceptor); ok {
return ri
}
i.serverReader = newReadInterceptor(c, i.serverBytes)
return i.serverReader
}
// Detect runs protocol detection on the previously bound Client and Server connection.
//
// It waits until either the client or the server performs a read operation,
// which is then used for running protocol detection. If the read operation
// takes longer than timeout, an error is returned.
//
// The returned channel always yields one result and is then closed.
func (i *Interceptor) Detect(timeout time.Duration) <-chan *Intercepted {
var interceptc = make(chan *Intercepted, 1)
go func() {
// Make sure all channels are closed once we finish processing.
defer close(interceptc)
defer i.clientReader.Cancel()
defer i.serverReader.Cancel()
select {
case <-time.After(timeout): // timeout
interceptc <- &Intercepted{
Error: ErrTimeout,
}
case data := <-i.clientBytes: // client sent banner
p, err := Detect(Client, data)
interceptc <- &Intercepted{
Direction: Client,
Protocol: p,
Error: err,
}
case data := <-i.serverBytes: // server sent banner
p, err := Detect(Server, data)
interceptc <- &Intercepted{
Direction: Server,
Protocol: p,
Error: err,
}
}
}()
return interceptc
}

76
protocol/limit.go Normal file
View File

@@ -0,0 +1,76 @@
package protocol
import (
"net"
"sync"
"sync/atomic"
)
// AcceptFunc receives a direction and a detected protocol.
type AcceptFunc func(Direction, *Protocol) error
// Limit the connection protocol, by running a detection after either side sends
// a banner within timeout.
//
// If no protocol could be detected, the accept function is called with a nil
// argument to check if we should proceed.
//
// If the accept function returns false, the connection will be closed.
func Limit(conn net.Conn, accept AcceptFunc) net.Conn {
if accept == nil {
// Nothing to do here.
return conn
}
return &connLimiter{
Conn: conn,
accept: accept,
}
}
type connLimiter struct {
net.Conn
accept AcceptFunc
acceptOnce sync.Once
acceptError atomic.Value
}
func (l *connLimiter) init(readData, writeData []byte) {
l.acceptOnce.Do(func() {
var (
dir Direction
data []byte
)
if readData != nil {
// init called by initial read
dir, data = Server, readData
} else {
// init called by initial write
dir, data = Client, writeData
}
protocol, _ := Detect(dir, data)
if err := l.accept(dir, protocol); err != nil {
l.acceptError.Store(err)
}
})
}
func (l *connLimiter) Read(p []byte) (n int, err error) {
var ok bool
if err, ok = l.acceptError.Load().(error); ok && err != nil {
return
}
if n, err = l.Conn.Read(p); n > 0 {
l.init(p[:n], nil)
}
return
}
func (l *connLimiter) Write(p []byte) (n int, err error) {
l.init(nil, p)
var ok bool
if err, ok = l.acceptError.Load().(error); ok && err != nil {
return
}
return l.Conn.Write(p)
}

47
protocol/protocol.go Normal file
View File

@@ -0,0 +1,47 @@
package protocol
import (
"strconv"
"strings"
)
// Protocols supported by this package.
const (
ProtocolDNS = "dns"
ProtocolHTTP = "http"
ProtocolMySQL = "mysql"
ProtocolPostgreSQL = "postgresql"
ProtocolSSH = "ssh"
ProtocolSSL = "ssl"
ProtocolTLS = "tls"
)
type Protocol struct {
Name string
Version Version
}
type Version struct {
Major int
Minor int
Patch int
Extra string
}
func (v Version) String() string {
p := make([]string, 0, 3)
if v.Major >= 0 {
p = append(p, strconv.Itoa(v.Major))
if v.Minor >= 0 {
p = append(p, strconv.Itoa(v.Minor))
if v.Patch >= 0 {
p = append(p, strconv.Itoa(v.Patch))
}
}
}
s := strings.Join(p, ".")
if v.Extra != "" {
return s + "-" + v.Extra
}
return s
}