From 05686075c4bbaec4714a8a7bb651f18d80885c24 Mon Sep 17 00:00:00 2001 From: maze Date: Thu, 9 Oct 2025 17:42:55 +0200 Subject: [PATCH] Added MQTT detection --- protocol/detect_mqtt.go | 150 +++++++++++++++++++++++++++++++++++ protocol/detect_mqtt_test.go | 86 ++++++++++++++++++++ protocol/protocol.go | 1 + 3 files changed, 237 insertions(+) create mode 100644 protocol/detect_mqtt.go create mode 100644 protocol/detect_mqtt_test.go diff --git a/protocol/detect_mqtt.go b/protocol/detect_mqtt.go new file mode 100644 index 0000000..24d3733 --- /dev/null +++ b/protocol/detect_mqtt.go @@ -0,0 +1,150 @@ +package protocol + +import ( + "log" + + "golang.org/x/crypto/cryptobyte" +) + +func init() { + Register(Both, "\x10?*MQTT", detectMQTT) +} + +func detectMQTT(dir Direction, data []byte, srcPort, dstPort int) (proto *Protocol, confidence float64) { + stream := cryptobyte.String(data) + + var packetType byte + if !stream.ReadUint8(&packetType) || packetType != 0x10 { + // Not an MQTT CONNECT control packet. + return + } + + // We read the value but only check in the end if we have bytes remaining. + var remainingLength uint32 + if !readMQTTVariableLengthUint32(&stream, &remainingLength) { + return + } + + var protocolName cryptobyte.String + if !stream.ReadUint16LengthPrefixed(&protocolName) || string(protocolName) != "MQTT" { + return + } + + // We are reasonabily sure this is MQTT now. + proto = &Protocol{ + Name: ProtocolMQTT, + } + confidence = 0.5 + + var protocolVersion uint8 + if !stream.ReadUint8(&protocolVersion) { + return + } + switch protocolVersion { + case 4: + proto.Version = Version{Major: 3, Minor: 1, Patch: 1} + confidence += .2 + case 5: + proto.Version = Version{Major: 5, Minor: 0, Patch: -1} + confidence += .2 + } + + var connectFlags uint8 + if !stream.ReadUint8(&connectFlags) { // read connect flags + return + } + confidence += .05 + var ( + hasUsername = (connectFlags&0b10000000)>>7 == 1 + hasPassword = (connectFlags&0b01000000)>>6 == 1 + willFlag = (connectFlags&0b00000100)>>2 == 1 + reserved = (connectFlags&0b00000001)>>0 == 1 + discard cryptobyte.String + ) + if reserved { + // Reserved bit is supposed to be not set. + confidence -= .2 + } + + var keepAlive uint16 + if !stream.ReadUint16(&keepAlive) { + return + } + + if !stream.ReadUint16LengthPrefixed(&discard) { // read client ID + return + } + confidence += .05 + discard = discard[:0] + + if willFlag { + if !stream.ReadUint16LengthPrefixed(&discard) { // read will topic + log.Println("will topic fail") + return + } + discard = discard[:0] + if !stream.ReadUint16LengthPrefixed(&discard) { // read will message + log.Println("will message fail") + return + } + discard = discard[:0] + confidence += .05 + } else { + confidence += .05 + } + + if hasUsername { + if !stream.ReadUint16LengthPrefixed(&discard) { // read username + log.Println("user fail") + return + } + discard = discard[:0] + confidence += .05 + } else { + confidence += .05 + } + + if hasPassword { + if !stream.ReadUint16LengthPrefixed(&discard) { // read password + log.Println("pass fail") + return + } + confidence += .05 + } else { + confidence += .05 + } + + if !stream.Empty() { + confidence -= .2 + } else { + confidence += .04 + } + + return +} + +func readMQTTVariableLengthUint32(stream *cryptobyte.String, value *uint32) bool { + var ( + multiplier uint32 = 1 + read int + ) + for { + var encodedByte byte + if !stream.ReadUint8(&encodedByte) { + return false + } + read++ + + *value += uint32(encodedByte&0x7F) * multiplier + if encodedByte&0x80 == 0 { + // Last byte + break + } + + if read >= 4 { + return false + } + multiplier *= 128 + } + return true +} diff --git a/protocol/detect_mqtt_test.go b/protocol/detect_mqtt_test.go new file mode 100644 index 0000000..29a174a --- /dev/null +++ b/protocol/detect_mqtt_test.go @@ -0,0 +1,86 @@ +package protocol + +import "testing" + +func TestDetectMQTT(t *testing.T) { + // A valid CONNECT packet + validSimplePacket := []byte{ + 0x10, 0x15, // Fixed Header: Type=CONNECT, Remaining Length=21 + 0x00, 0x04, 'M', 'Q', 'T', 'T', // Protocol Name: "MQTT" + 0x04, // Protocol Version: 4 (v3.1.1) + 0x02, // Connect Flags: Clean Session=true + 0x00, 0x3C, // Keep Alive: 60 seconds + 0x00, 0x07, 'm', 'a', 'z', 'e', '.', 'i', 'o', // Client ID: "maze.io" + } + + // A valid packet with Will, Username, and Password + validFullPacket := []byte{ + 0x10, 0x3E, // Type=CONNECT, Remaining Length=62 + 0x00, 0x04, 'M', 'Q', 'T', 'T', + 0x05, // Version + 0b11001110, // Flags: User, Pass, Will(QoS=1, Retain=false), Clean + 0x00, 0x1E, // Keep Alive: 30s + 0x00, 0x0B, 't', 'e', 's', 't', '-', 'c', 'l', 'i', 'e', 'n', 't', // Client ID + 0x00, 0x0D, 's', 't', 'a', 't', 'u', 's', '/', 'd', 'e', 'v', 'i', 'c', 'e', // Will Topic + 0x00, 0x07, 'o', 'f', 'f', 'l', 'i', 'n', 'e', // Will Message + 0x00, 0x04, 'u', 's', 'e', 'r', // Username + 0x00, 0x09, 's', 'e', 'c', 'r', 'e', 't', '1', '2', '3', // Password + } + + // Invalid packet (not a CONNECT packet) + //notConnectPacket := []byte{0x20, 0x02, 0x00, 0x00} // A CONNACK packet + + // Partial packet (bad length) + partialPacket := []byte{ + 0x10, 0x0B, // Remaining length is 11, but we provide more data + 0x00, 0x04, 'M', 'Q', 'T', 'T', // Protocol Name: "MQ" + } + + // Trailing garbage packet (bad remaining length) + trailingGarbagePacket := []byte{ + 0x10, 0x0B, // Remaining length is 11, but we provide more data + 0x00, 0x04, 'M', 'Q', 'T', 'T', // Protocol Name: "MQTT" + 0x04, // Protocol Version: 4 (v3.1.1) + 0b00000010, // Flags: Clean Session=true + 0x00, 0x3C, // Keep Alive: 60 seconds + 0x00, 0x07, 'm', 'a', 'z', 'e', '.', 'i', 'o', // Client ID: "maze.io" + 0x00, 0x01, 0x02, 0x03, // garbage + } + + tests := []*testCase{ + { + Name: "MQTT simple packet", + Direction: Client, + Data: validSimplePacket, + DstPort: 1883, + WantProto: ProtocolMQTT, + WantConfidence: .99, + }, + { + Name: "MQTT full packet", + Direction: Client, + Data: validFullPacket, + DstPort: 1883, + WantProto: ProtocolMQTT, + WantConfidence: .99, + }, + { + Name: "MQTT partial packet", + Direction: Client, + Data: partialPacket, + DstPort: 1883, + WantProto: ProtocolMQTT, + WantConfidence: .5, + }, + { + Name: "MQTT trailing garbage packet", + Direction: Client, + Data: trailingGarbagePacket, + DstPort: 1883, + WantProto: ProtocolMQTT, + WantConfidence: .75, + }, + } + + testRunner(t, tests) +} diff --git a/protocol/protocol.go b/protocol/protocol.go index 7c7d8db..845e9e7 100644 --- a/protocol/protocol.go +++ b/protocol/protocol.go @@ -9,6 +9,7 @@ import ( const ( ProtocolDNS = "dns" ProtocolHTTP = "http" + ProtocolMQTT = "mqtt" ProtocolMySQL = "mysql" ProtocolPostgreSQL = "postgresql" ProtocolSSH = "ssh"