diff --git a/cmd/dpi-protocol-probe/main.go b/cmd/dpi-protocol-probe/main.go index 65c35a5..9af0960 100644 --- a/cmd/dpi-protocol-probe/main.go +++ b/cmd/dpi-protocol-probe/main.go @@ -41,5 +41,5 @@ func main() { os.Exit(3) } - fmt.Printf("Protocol at address %q is %s version %s (confidence %g%%)\n", address, protocol.Name, protocol.Version, confidence*100) + fmt.Printf("Protocol at address %q is %s version %s (confidence %g%%)\n", address, protocol.Type, protocol.Version, confidence*100) } diff --git a/cmd/protodial/main.go b/cmd/protodial/main.go index 923a1ec..433bb1b 100644 --- a/cmd/protodial/main.go +++ b/cmd/protodial/main.go @@ -42,11 +42,11 @@ func main() { if p == nil { return errors.New("no protocol detected") } - if !accept[p.Name] { - return fmt.Errorf("protocol %s is not accepted", p.Name) + if !accept[p.Type] { + return fmt.Errorf("protocol %s is not accepted", p.Type) } fmt.Fprintf(os.Stderr, "Accepting protocol %s version %s initiated by %s\n", - p.Name, p.Version, dir) + p.Type, p.Version, dir) return nil }) defer func() { _ = c.Close() }() diff --git a/cmd/protoproxy/main.go b/cmd/protoproxy/main.go index 7ab5063..69269e5 100644 --- a/cmd/protoproxy/main.go +++ b/cmd/protoproxy/main.go @@ -81,7 +81,7 @@ func proxy(client net.Conn, target string) { log.Printf("protocol detection failed: %v", result.Error) } else { log.Printf("detected protocol %s version %s initiated by %s", - result.Protocol.Name, result.Protocol.Version, result.Direction) + result.Protocol.Type, result.Protocol.Version, result.Direction) } // Wait for the multiplexing to finish. diff --git a/protocol/detect_http.go b/protocol/detect_http.go index f656ce2..92620d8 100644 --- a/protocol/detect_http.go +++ b/protocol/detect_http.go @@ -31,7 +31,7 @@ func detectHTTPRequest(dir Direction, data []byte, srcPort, dstPort int) (proto ) if request, err := http.ReadRequest(r); err == nil { return &Protocol{ - Name: ProtocolHTTP, + Type: TypeHTTP, Version: Version{ Major: request.ProtoMajor, Minor: request.ProtoMinor, @@ -66,7 +66,7 @@ func detectHTTPRequest(dir Direction, data []byte, srcPort, dstPort int) (proto _, _ = fmt.Sscanf(string(part[2]), "HTTP/%d.%d ", &version.Major, &version.Minor) return &Protocol{ - Name: ProtocolHTTP, + Type: TypeHTTP, Version: version, }, confidence + .75 } @@ -95,7 +95,7 @@ func detectHTTPResponse(dir Direction, data []byte, srcPort, dstPort int) (proto ) if response, err := http.ReadResponse(r, nil); err == nil { return &Protocol{ - Name: ProtocolHTTP, + Type: TypeHTTP, Version: Version{ Major: response.ProtoMajor, Minor: response.ProtoMinor, @@ -110,7 +110,7 @@ func detectHTTPResponse(dir Direction, data []byte, srcPort, dstPort int) (proto _, _ = fmt.Sscanf(string(data), "HTTP/%d.%d ", &version.Major, &version.Minor) return &Protocol{ - Name: ProtocolHTTP, + Type: TypeHTTP, Version: version, }, confidence + .75 } diff --git a/protocol/detect_http_test.go b/protocol/detect_http_test.go index dcbf890..7fa4dd2 100644 --- a/protocol/detect_http_test.go +++ b/protocol/detect_http_test.go @@ -22,7 +22,7 @@ func TestDetectHTTPRequest(t *testing.T) { Direction: Client, Data: http10Request, DstPort: 80, - WantProto: ProtocolHTTP, + WantType: TypeHTTP, WantConfidence: .95, }, { @@ -30,7 +30,7 @@ func TestDetectHTTPRequest(t *testing.T) { Direction: Client, Data: getRequest, DstPort: 80, - WantProto: ProtocolHTTP, + WantType: TypeHTTP, WantConfidence: .95, }, { @@ -81,21 +81,21 @@ func TestDetectHTTPResponse(t *testing.T) { Direction: Server, Data: http10Response, SrcPort: 80, - WantProto: ProtocolHTTP, + WantType: TypeHTTP, }, { Name: "HTTP/1.1 200", Direction: Server, Data: responseOK, SrcPort: 80, - WantProto: ProtocolHTTP, + WantType: TypeHTTP, }, { Name: "HTTP/1.1 404", Direction: Server, Data: responseNotFound, SrcPort: 80, - WantProto: ProtocolHTTP, + WantType: TypeHTTP, }, { Name: "Invalid HTTP/1.1 GET", diff --git a/protocol/detect_mqtt.go b/protocol/detect_mqtt.go index 24d3733..2e68d69 100644 --- a/protocol/detect_mqtt.go +++ b/protocol/detect_mqtt.go @@ -32,7 +32,7 @@ func detectMQTT(dir Direction, data []byte, srcPort, dstPort int) (proto *Protoc // We are reasonabily sure this is MQTT now. proto = &Protocol{ - Name: ProtocolMQTT, + Type: TypeMQTT, } confidence = 0.5 diff --git a/protocol/detect_mqtt_test.go b/protocol/detect_mqtt_test.go index 29a174a..f3dd964 100644 --- a/protocol/detect_mqtt_test.go +++ b/protocol/detect_mqtt_test.go @@ -53,7 +53,7 @@ func TestDetectMQTT(t *testing.T) { Direction: Client, Data: validSimplePacket, DstPort: 1883, - WantProto: ProtocolMQTT, + WantType: TypeMQTT, WantConfidence: .99, }, { @@ -61,7 +61,7 @@ func TestDetectMQTT(t *testing.T) { Direction: Client, Data: validFullPacket, DstPort: 1883, - WantProto: ProtocolMQTT, + WantType: TypeMQTT, WantConfidence: .99, }, { @@ -69,7 +69,7 @@ func TestDetectMQTT(t *testing.T) { Direction: Client, Data: partialPacket, DstPort: 1883, - WantProto: ProtocolMQTT, + WantType: TypeMQTT, WantConfidence: .5, }, { @@ -77,7 +77,7 @@ func TestDetectMQTT(t *testing.T) { Direction: Client, Data: trailingGarbagePacket, DstPort: 1883, - WantProto: ProtocolMQTT, + WantType: TypeMQTT, WantConfidence: .75, }, } diff --git a/protocol/detect_mysql.go b/protocol/detect_mysql.go index 638bba4..3a82bdd 100644 --- a/protocol/detect_mysql.go +++ b/protocol/detect_mysql.go @@ -49,7 +49,7 @@ func detectMySQL(dir Direction, data []byte, srcPort, dstPort int) (proto *Proto _, _ = fmt.Sscanf(string(data[1:serverVersionEndPos]), "%d.%d.%d-%s", &version.Major, &version.Minor, &version.Patch, &version.Extra) return &Protocol{ - Name: ProtocolMySQL, + Type: TypeMySQL, Version: version, }, confidence + .75 } diff --git a/protocol/detect_mysql_test.go b/protocol/detect_mysql_test.go index 277b407..9538785 100644 --- a/protocol/detect_mysql_test.go +++ b/protocol/detect_mysql_test.go @@ -43,7 +43,7 @@ func TestDetectMySQL(t *testing.T) { Direction: Server, Data: mysql8Banner, SrcPort: 3306, - WantProto: ProtocolMySQL, + WantType: TypeMySQL, WantConfidence: .85, }, { @@ -51,7 +51,7 @@ func TestDetectMySQL(t *testing.T) { Direction: Server, Data: mariaDBBanner, SrcPort: 3306, - WantProto: ProtocolMySQL, + WantType: TypeMySQL, WantConfidence: .85, }, { diff --git a/protocol/detect_postgres.go b/protocol/detect_postgres.go index d6405ce..e70a660 100644 --- a/protocol/detect_postgres.go +++ b/protocol/detect_postgres.go @@ -40,7 +40,7 @@ func detectPostgreSQLClient(dir Direction, data []byte, srcPort, dstPort int) (p minor := int(binary.BigEndian.Uint16(data[6:])) if major == 2 || major == 3 { return &Protocol{ - Name: ProtocolPostgreSQL, + Type: TypePostgreSQL, Version: Version{ Major: major, Minor: minor, @@ -70,7 +70,7 @@ func detectPostgreSQLServer(dir Direction, data []byte, srcPort, dstPort int) (p 'Z', // ReadyForQuery 'E', // ErrorResponse 'N': // NoticeResponse - return &Protocol{Name: ProtocolPostgreSQL}, confidence + .65 + return &Protocol{Type: TypePostgreSQL}, confidence + .65 default: return nil, 0 diff --git a/protocol/detect_postgres_test.go b/protocol/detect_postgres_test.go index f1f5cba..2fd3839 100644 --- a/protocol/detect_postgres_test.go +++ b/protocol/detect_postgres_test.go @@ -36,7 +36,7 @@ func TestDetectPostgreSQLClient(t *testing.T) { Direction: Client, Data: pgClientStartup, DstPort: 5432, - WantProto: ProtocolPostgreSQL, + WantType: TypePostgreSQL, WantConfidence: .85, }, { @@ -77,7 +77,7 @@ func TestDetectPostgreSQLServer(t *testing.T) { Direction: Server, Data: pgServerAuthOK, DstPort: 5432, - WantProto: ProtocolPostgreSQL, + WantType: TypePostgreSQL, WantConfidence: .65, }, { @@ -95,9 +95,9 @@ func TestDetectPostgreSQLServer(t *testing.T) { t.Fatal(err) return } - t.Logf("detected %s version %s confidence %g%%", p.Name, p.Version, c*100) - if p.Name != ProtocolPostgreSQL { - t.Fatalf("expected postgres protocol, got %s", p.Name) + t.Logf("detected %s version %s confidence %g%%", p.Type, p.Version, c*100) + if p.Type != TypePostgreSQL { + t.Fatalf("expected postgres protocol, got %s", p.Type) return } }) @@ -108,9 +108,9 @@ func TestDetectPostgreSQLServer(t *testing.T) { t.Fatal(err) return } - t.Logf("detected %s version %s confidence %g%%", p.Name, p.Version, c*100) - if p.Name != ProtocolPostgreSQL { - t.Fatalf("expected postgres protocol, got %s", p.Name) + t.Logf("detected %s version %s confidence %g%%", p.Type, p.Version, c*100) + if p.Type != TypePostgreSQL { + t.Fatalf("expected postgres protocol, got %s", p.Type) return } }) diff --git a/protocol/detect_ssh.go b/protocol/detect_ssh.go index 2265942..8fe309f 100644 --- a/protocol/detect_ssh.go +++ b/protocol/detect_ssh.go @@ -58,7 +58,7 @@ func detectSSH(dir Direction, data []byte, srcPort, dstPort int) (proto *Protoco } } return &Protocol{ - Name: ProtocolSSH, + Type: TypeSSH, Version: Version{ Major: 2, Minor: 0, @@ -78,7 +78,7 @@ func detectSSH(dir Direction, data []byte, srcPort, dstPort int) (proto *Protoco } } return &Protocol{ - Name: ProtocolSSH, + Type: TypeSSH, Version: Version{ Major: 1, Minor: 99, diff --git a/protocol/detect_ssh_test.go b/protocol/detect_ssh_test.go index 9d10abf..912637b 100644 --- a/protocol/detect_ssh_test.go +++ b/protocol/detect_ssh_test.go @@ -36,7 +36,7 @@ func TestDetectSSH(t *testing.T) { Direction: Client, Data: openSSHBanner, DstPort: 22, - WantProto: ProtocolSSH, + WantType: TypeSSH, WantConfidence: .95, }, { @@ -44,7 +44,7 @@ func TestDetectSSH(t *testing.T) { Direction: Server, Data: openSSHBanner, SrcPort: 22, - WantProto: ProtocolSSH, + WantType: TypeSSH, WantConfidence: .95, }, { @@ -52,7 +52,7 @@ func TestDetectSSH(t *testing.T) { Direction: Server, Data: preBannerSSH, SrcPort: 22, - WantProto: ProtocolSSH, + WantType: TypeSSH, WantConfidence: .95, }, { @@ -60,7 +60,7 @@ func TestDetectSSH(t *testing.T) { Direction: Server, Data: dropbearBanner, SrcPort: 22, - WantProto: ProtocolSSH, + WantType: TypeSSH, WantConfidence: .95, }, { diff --git a/protocol/detect_tls.go b/protocol/detect_tls.go index 2d64892..5623504 100644 --- a/protocol/detect_tls.go +++ b/protocol/detect_tls.go @@ -1,6 +1,9 @@ package protocol import ( + "slices" + "strings" + "golang.org/x/crypto/cryptobyte" "git.maze.io/go/dpi" @@ -11,12 +14,18 @@ func init() { } func registerTLS() { - Register(Both, "\x16\x03\x00", detectTLS) // SSLv3 - Register(Both, "\x16\x03\x01", detectTLS) // TLSv1.0 - Register(Both, "\x16\x03\x02", detectTLS) // TLSv1.1 - Register(Both, "\x16\x03\x03", detectTLS) // TLSv1.2 + Register(Both, "\x16\x03\x00", detectTLS) // SSL 3.0 + Register(Both, "\x16\x03\x01", detectTLS) // TLS 1.0 + Register(Both, "\x16\x03\x02", detectTLS) // TLS 1.1 + Register(Both, "\x16\x03\x03", detectTLS) // TLS 1.2 } +const ( + tlsRecordTypeHandshake uint8 = 22 + tlsTypeClientHello uint8 = 1 + tlsTypeServerHello uint8 = 2 +) + func detectTLS(dir Direction, data []byte, _, _ int) (proto *Protocol, confidence float64) { stream := cryptobyte.String(data) @@ -44,27 +53,53 @@ func detectTLS(dir Direction, data []byte, _, _ int) (proto *Protocol, confidenc // Initial confidence confidence = 0.5 - // Detected SSL/TLS version - var version dpi.TLSVersion + // Detected SSL/TLS tlsVersion + var tlsVersion dpi.TLSVersion // Attempt to decode the full TLS Client Hello handshake - if version == 0 { - if hello, err := dpi.DecodeTLSClientHelloHandshake(data); err == nil { - version = hello.Version - confidence += .45 - } - } + if tlsVersion == 0 { + if record, err := dpi.DecodeTLSRecord(data); err == nil && record.Type == tlsRecordTypeHandshake && len(record.Data) > 1 { + tlsVersion = record.Version + confidence += .15 - // Attempt to decode the full TLS Server Hello handshake - if version == 0 { - if hello, err := dpi.DecodeTLSServerHello(data); err == nil { - version = hello.Version - confidence += .45 + switch record.Data[0] { + case tlsTypeClientHello: // TLS ClientHello + if hello, err := dpi.DecodeTLSClientHello(record.Data); err == nil { + tlsVersion = hello.Version + confidence += .3 + + slices.SortStableFunc(hello.ALPNProtocols, func(a, b string) int { + return strings.Compare(b, a) + }) + + for _, id := range hello.ALPNProtocols { + if proto = ALPNProtocol[id]; proto != nil { + return + } + } + } + + case tlsTypeServerHello: // TLS ServerHello + if hello, err := dpi.DecodeTLSServerHello(record.Data); err == nil { + tlsVersion = hello.Version + confidence += .3 + + slices.SortStableFunc(hello.ALPNProtocols, func(a, b string) int { + return strings.Compare(b, a) + }) + + for _, id := range hello.ALPNProtocols { + if proto = ALPNProtocol[id]; proto != nil { + return + } + } + } + } } } // Attempt to decode at least the handshake protocol and version. - if version == 0 && !Strict { + if tlsVersion == 0 && !Strict { var handshakeType uint8 if stream.ReadUint8(&handshakeType) && (handshakeType == 1 || handshakeType == 2) { var ( @@ -72,33 +107,68 @@ func detectTLS(dir Direction, data []byte, _, _ int) (proto *Protocol, confidenc versionWord uint16 ) if stream.ReadUint24(&length) && stream.ReadUint16(&versionWord) { - version = dpi.TLSVersion(versionWord) + tlsVersion = dpi.TLSVersion(versionWord) confidence += .25 } } } // Fall back to the version in the TLS record header, this is less accurate - if version == 0 && !Strict { - version = dpi.TLSVersion(header.Version) + if tlsVersion == 0 && !Strict { + tlsVersion = dpi.TLSVersion(header.Version) } // We're "multi protocol", in that SSL is its own protocol - if version == dpi.VersionSSL30 { + if tlsVersion == dpi.VersionSSL30 { return &Protocol{ - Name: ProtocolSSL, + Type: TypeSSL, Version: Version{Major: 3, Minor: 0, Patch: -1}, }, confidence - } else if version >= dpi.VersionTLS10 && version <= dpi.VersionTLS13 { + } else if tlsVersion >= dpi.VersionTLS10 && tlsVersion <= dpi.VersionTLS13 { return &Protocol{ - Name: ProtocolTLS, - Version: Version{Major: 1, Minor: int(uint8(version) - 1), Patch: -1}, + Type: TypeTLS, + Version: Version{Major: 1, Minor: int(uint8(tlsVersion) - 1), Patch: -1}, }, confidence - } else if version >= dpi.VersionTLS13Draft && version <= dpi.VersionTLS13Draft23 { + } else if tlsVersion >= dpi.VersionTLS13Draft && tlsVersion <= dpi.VersionTLS13Draft23 { return &Protocol{ - Name: ProtocolTLS, + Type: TypeTLS, Version: Version{Major: 1, Minor: 3, Patch: -1}, }, confidence } return nil, 0 } + +// ALPNProtocol is a map of TLS Application-Layer Protocol Negotiation (ALPN) Protocol identifier to [Protocol]. +// +// See https://www.iana.org/assignments/tls-extensiontype-values/tls-extensiontype-values.xhtml#alpn-protocol-ids\ +var ALPNProtocol = map[string]*Protocol{ + "acme-tls/1": {Type: TypeACME, Encapsulation: TypeTLS, Version: Version{Major: 1}}, // ACME-TLS/1 + "co": {Type: TypeCoAP, Encapsulation: TypeTLS}, // CoAP (over DTLS) + "coap": {Type: TypeCoAP, Encapsulation: TypeTLS}, // CoAP (over TLS) + "c-webrtc": {Type: TypeWebRTC, Encapsulation: TypeTLS}, // Confidential WebRTC Media and Data + "dot": {Type: TypeDNS, Encapsulation: TypeTLS}, // DNS-over-TLS + "ftp": {Type: TypeFTP, Encapsulation: TypeTLS}, // FTP + "http/0.9": {Type: TypeHTTP, Encapsulation: TypeTLS, Version: Version{Major: 0, Minor: 9, Patch: -1}}, // HTTP/0.9 + "http/1.0": {Type: TypeHTTP, Encapsulation: TypeTLS, Version: Version{Major: 1, Minor: 0, Patch: -1}}, // HTTP/1.0 + "http/1.1": {Type: TypeHTTP, Encapsulation: TypeTLS, Version: Version{Major: 1, Minor: 1, Patch: -1}}, // HTTP/1.1 + "h2": {Type: TypeHTTP, Encapsulation: TypeTLS, Version: Version{Major: 2, Minor: -1, Patch: -1}}, // HTTP/2 (over TLS) + "h2c": {Type: TypeHTTP, Encapsulation: TypeTLS, Version: Version{Major: 2, Minor: -1, Patch: -1}}, // HTTP/2 (over TCP) + "h3": {Type: TypeHTTP, Encapsulation: TypeTLS, Version: Version{Major: 3, Minor: -1, Patch: -1}}, // HTTP/3 + "irc": {Type: TypeIRC, Encapsulation: TypeTLS}, // IRC + "imap": {Type: TypeIMAP, Encapsulation: TypeTLS}, // IMAP + "managesieve": {Type: TypeManageSieve, Encapsulation: TypeTLS}, // ManageSieve + "mqtt": {Type: TypeMQTT, Encapsulation: TypeTLS}, // MQTT + "nntp": {Type: TypeNNTP, Encapsulation: TypeTLS}, // NNTP (reading) + "nnsp": {Type: TypeNNTP, Encapsulation: TypeTLS}, // NNTP (transit) + "postgresql": {Type: TypePostgreSQL, Encapsulation: TypeTLS}, // PostgreSQL + "pop3": {Type: TypePOP3, Encapsulation: TypeTLS}, // POP3 + "radius/1.0": {Type: TypeRADIUS, Encapsulation: TypeTLS, Version: Version{Major: 1, Minor: 0, Patch: -1}}, // RADIUS/1.0 + "radius/1.1": {Type: TypeRADIUS, Encapsulation: TypeTLS, Version: Version{Major: 1, Minor: 1, Patch: -1}}, // RADIUS/1.1 + "smb": {Type: TypeSMB, Encapsulation: TypeTLS, Version: Version{Major: 2, Minor: -1, Patch: -1}}, // SMB2 + "stun.nat-discovery": {Type: TypeSTUN, Encapsulation: TypeTLS}, // NAT discovery using Session Traversal Utilities for NAT (STUN) + "stun.turn": {Type: TypeSTUN, Encapsulation: TypeTLS}, // Traversal Using Relays around NAT (TURN) + "sunrpc": {Type: TypeSunRPC, Encapsulation: TypeTLS}, // SunRPC + "webrtc": {Type: TypeWebRTC, Encapsulation: TypeTLS}, // WebRTC Media and Data + "xmpp-client": {Type: TypeXMPP, Encapsulation: TypeTLS}, // XMPP jabber:client namespace + "xmpp-server": {Type: TypeXMPP, Encapsulation: TypeTLS}, // XMPP jabber:server namespace +} diff --git a/protocol/detect_tls_test.go b/protocol/detect_tls_test.go index b5dec4c..a3a5cac 100644 --- a/protocol/detect_tls_test.go +++ b/protocol/detect_tls_test.go @@ -76,6 +76,45 @@ func TestDetectTLS(t *testing.T) { 0x03, 0x02, // Client Version: TLS 1.1 (Major=3, Minor=2) } + // A valid TLS 1.1 ServerHello + tls11ServerHello := []byte{ + // --- Record Layer (5 bytes) --- + 0x16, // Content Type: Handshake (22) + 0x03, 0x02, // Version: TLS 1.1 (for compatibility in a 1.3 hello) + 0x00, 0x40, // Length of the handshake message below (64 bytes) + + // --- Handshake Protocol: ServerHello (64 bytes) --- + 0x02, // Handshake Type: ServerHello (2) + 0x00, 0x00, 0x3c, // Length of the rest of the message (60 bytes) + 0x03, 0x02, // Server Version: TLS 1.1 (Major=3, Minor=2) + + // Random (32 bytes) + 0xb7, 0xa8, 0xdf, 0xd5, 0x17, 0xb1, 0x50, 0xb4, + 0x28, 0xb7, 0xf6, 0xf3, 0xb9, 0x83, 0xcf, 0x9f, + 0x31, 0x55, 0x79, 0x1f, 0x3b, 0x07, 0x6d, 0x17, + 0x44, 0x4f, 0x57, 0x4e, 0x47, 0x52, 0x44, 0x00, + + // Session ID + 0x00, // Session ID Length: 0 (new session) + + 0xc0, 0x09, // Cipher Suite + 0x00, // Compression Method + + // --- Extensions (20 bytes) --- + 0x00, 0x14, // Extensions Length: 20 bytes + 0xff, 0x01, // Extension Type: renegotiation_info + 0x00, 0x01, // Extension Length: 1 byte + 0x00, // Renegotiation Info Length: 0 bytes + 0x00, 0x10, // Extension Type: alpn + 0x00, 0x05, // Extension Length: 5 bytes + 0x00, 0x03, // ALPN Extension Length: 3 bytes + 0x02, 'h', '2', + 0x0, 0x0b, // Extension Type: ec_points_format + 0x0, 0x02, // Extension Length: 2 bytes + 0x01, // EC Points Format Length: 1 byte + 0x00, // EC Point Format: uncompressed + } + // A synthesized TLSv1.2 ClientHello tls12ClientHello := []byte{ // --- Record Layer (5 bytes) --- @@ -129,35 +168,62 @@ func TestDetectTLS(t *testing.T) { // A valid TLS 1.3 Client Hello (captured from a real connection) tls13ClientHello := []byte{ + // --- Record Layer (5 bytes) --- 0x16, // Content Type: Handshake (22) 0x03, 0x01, // Version: TLS 1.0 (for compatibility in a 1.3 hello) 0x01, 0x3a, // Length - 0x01, 0x00, 0x01, 0x36, 0x03, 0x03, 0xb1, - 0x40, 0xd3, 0xf1, 0x7d, 0xa3, 0xb8, 0x33, 0xac, 0xad, 0x21, 0x79, 0x9c, - 0xbe, 0x39, 0x96, 0x08, 0x49, 0x3b, 0x53, 0x75, 0xa0, 0x1b, 0xee, 0x6e, - 0x6a, 0xbe, 0x6c, 0x41, 0xdf, 0x6c, 0xf4, 0x20, 0xa4, 0xaa, 0x0c, 0xca, - 0xd4, 0x37, 0x76, 0x5f, 0x49, 0xc6, 0x06, 0x9b, 0xac, 0x90, 0x89, 0x76, - 0x1c, 0xc7, 0xc4, 0x12, 0xb4, 0x4a, 0xe0, 0x27, 0x72, 0x89, 0x97, 0x85, - 0x76, 0xf8, 0xc8, 0x83, 0x00, 0x62, 0x13, 0x03, 0x13, 0x02, 0x13, 0x01, - 0xcc, 0xa9, 0xcc, 0xa8, 0xcc, 0xaa, 0xc0, 0x30, 0xc0, 0x2c, 0xc0, 0x28, - 0xc0, 0x24, 0xc0, 0x14, 0xc0, 0x0a, 0x00, 0x9f, 0x00, 0x6b, 0x00, 0x39, - 0xff, 0x85, 0x00, 0xc4, 0x00, 0x88, 0x00, 0x81, 0x00, 0x9d, 0x00, 0x3d, - 0x00, 0x35, 0x00, 0xc0, 0x00, 0x84, 0xc0, 0x2f, 0xc0, 0x2b, 0xc0, 0x27, - 0xc0, 0x23, 0xc0, 0x13, 0xc0, 0x09, 0x00, 0x9e, 0x00, 0x67, 0x00, 0x33, - 0x00, 0xbe, 0x00, 0x45, 0x00, 0x9c, 0x00, 0x3c, 0x00, 0x2f, 0x00, 0xba, - 0x00, 0x41, 0xc0, 0x11, 0xc0, 0x07, 0x00, 0x05, 0x00, 0x04, 0xc0, 0x12, - 0xc0, 0x08, 0x00, 0x16, 0x00, 0x0a, 0x00, 0xff, 0x01, 0x00, 0x00, 0x8b, - 0x00, 0x2b, 0x00, 0x09, 0x08, 0x03, 0x04, 0x03, 0x03, 0x03, 0x02, 0x03, - 0x01, 0x00, 0x33, 0x00, 0x26, 0x00, 0x24, 0x00, 0x1d, 0x00, 0x20, 0x2c, - 0x4b, 0xaa, 0xb4, 0xb3, 0xc8, 0x93, 0xcd, 0x5c, 0x24, 0xb9, 0x9b, 0xd4, - 0x59, 0x04, 0xfe, 0x69, 0xaf, 0x68, 0xb9, 0xa6, 0x36, 0xbb, 0xab, 0x87, - 0xfa, 0x15, 0x59, 0xea, 0xdd, 0x38, 0x68, 0x00, 0x00, 0x00, 0x0e, 0x00, - 0x0c, 0x00, 0x00, 0x09, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x68, 0x6f, 0x73, - 0x74, 0x00, 0x0b, 0x00, 0x02, 0x01, 0x00, 0x00, 0x0a, 0x00, 0x0a, 0x00, - 0x08, 0x00, 0x1d, 0x00, 0x17, 0x00, 0x18, 0x00, 0x19, 0x00, 0x0d, 0x00, - 0x18, 0x00, 0x16, 0x08, 0x06, 0x06, 0x01, 0x06, 0x03, 0x08, 0x05, 0x05, - 0x01, 0x05, 0x03, 0x08, 0x04, 0x04, 0x01, 0x04, 0x03, 0x02, 0x01, 0x02, - 0x03, 0x00, 0x10, 0x00, 0x0e, 0x00, 0x0c, 0x02, 0x68, 0x32, 0x08, 0x68, + + // --- Handshake Protocol: ClientHello --- + 0x01, // Handshake Type: ClientHello (1) + 0x00, 0x01, 0x36, // Length of the rest of the message + 0x03, 0x03, // Client Version: TLS 1.2 (Major=3, Minor=3) + + // Random (32 bytes) + 0xb1, 0x40, 0xd3, 0xf1, 0x7d, 0xa3, 0xb8, 0x33, + 0xac, 0xad, 0x21, 0x79, 0x9c, 0xbe, 0x39, 0x96, + 0x08, 0x49, 0x3b, 0x53, 0x75, 0xa0, 0x1b, 0xee, + 0x6e, 0x6a, 0xbe, 0x6c, 0x41, 0xdf, 0x6c, 0xf4, + + // Session ID + 0x20, // Session ID Length: 32 + 0xa4, 0xaa, 0x0c, 0xca, 0xd4, 0x37, 0x76, 0x5f, // + 0x49, 0xc6, 0x06, 0x9b, 0xac, 0x90, 0x89, 0x76, // + 0x1c, 0xc7, 0xc4, 0x12, 0xb4, 0x4a, 0xe0, 0x27, // + 0x72, 0x89, 0x97, 0x85, 0x76, 0xf8, 0xc8, 0x83, // + 0x00, 0x62, 0x13, 0x03, 0x13, 0x02, 0x13, 0x01, // + + // Cipher Suites + 0xcc, 0xa9, 0xcc, 0xa8, 0xcc, 0xaa, 0xc0, 0x30, + 0xc0, 0x2c, 0xc0, 0x28, 0xc0, 0x24, 0xc0, 0x14, + 0xc0, 0x0a, 0x00, 0x9f, 0x00, 0x6b, 0x00, 0x39, + 0xff, 0x85, 0x00, 0xc4, 0x00, 0x88, 0x00, 0x81, + 0x00, 0x9d, 0x00, 0x3d, 0x00, 0x35, 0x00, 0xc0, + 0x00, 0x84, 0xc0, 0x2f, 0xc0, 0x2b, 0xc0, 0x27, + 0xc0, 0x23, 0xc0, 0x13, 0xc0, 0x09, 0x00, 0x9e, + 0x00, 0x67, 0x00, 0x33, 0x00, 0xbe, 0x00, 0x45, + 0x00, 0x9c, 0x00, 0x3c, 0x00, 0x2f, 0x00, 0xba, + 0x00, 0x41, 0xc0, 0x11, 0xc0, 0x07, 0x00, 0x05, + 0x00, 0x04, 0xc0, 0x12, 0xc0, 0x08, 0x00, 0x16, + 0x00, 0x0a, 0x00, 0xff, 0x01, 0x00, 0x00, 0x8b, + 0x00, 0x2b, 0x00, 0x09, 0x08, 0x03, 0x04, 0x03, + 0x03, 0x03, 0x02, 0x03, 0x01, 0x00, 0x33, 0x00, + 0x26, 0x00, 0x24, 0x00, 0x1d, 0x00, 0x20, 0x2c, + 0x4b, 0xaa, 0xb4, 0xb3, 0xc8, 0x93, 0xcd, 0x5c, + 0x24, 0xb9, 0x9b, 0xd4, 0x59, 0x04, 0xfe, 0x69, + 0xaf, 0x68, 0xb9, 0xa6, 0x36, 0xbb, 0xab, 0x87, + 0xfa, 0x15, 0x59, 0xea, 0xdd, 0x38, 0x68, 0x00, + 0x00, 0x00, 0x0e, 0x00, 0x0c, 0x00, 0x00, 0x09, + 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x68, 0x6f, 0x73, + 0x74, 0x00, 0x0b, 0x00, 0x02, 0x01, 0x00, 0x00, + 0x0a, 0x00, 0x0a, 0x00, 0x08, 0x00, 0x1d, 0x00, + 0x17, 0x00, 0x18, 0x00, 0x19, 0x00, 0x0d, 0x00, + 0x18, 0x00, 0x16, 0x08, 0x06, 0x06, 0x01, 0x06, + 0x03, 0x08, 0x05, 0x05, + + // Compression Methods, etc. + 0x01, 0x05, 0x03, 0x08, 0x04, 0x04, 0x01, 0x04, + 0x03, 0x02, 0x01, 0x02, 0x03, 0x00, 0x10, 0x00, + 0x0e, 0x00, 0x0c, 0x02, 0x68, 0x32, 0x08, 0x68, 0x74, 0x74, 0x70, 0x2f, 0x31, 0x2e, 0x31, } @@ -172,15 +238,23 @@ func TestDetectTLS(t *testing.T) { Direction: Client, Data: sslV3ClientHello, DstPort: 443, - WantProto: ProtocolSSL, + WantType: TypeSSL, WantConfidence: .95, }, { - Name: "TLS 1.1", + Name: "TLS 1.1 ClientHello", Direction: Client, Data: tls11ClientHello, DstPort: 443, - WantProto: ProtocolTLS, + WantType: TypeTLS, + WantConfidence: .95, + }, + { + Name: "TLS 1.1 ServerHello", + Direction: Server, + Data: tls11ServerHello, + SrcPort: 443, + WantType: TypeHTTP, WantConfidence: .95, }, { @@ -188,7 +262,7 @@ func TestDetectTLS(t *testing.T) { Direction: Client, Data: tls12ClientHello, DstPort: 443, - WantProto: ProtocolTLS, + WantType: TypeHTTP, WantConfidence: .95, }, { @@ -196,7 +270,7 @@ func TestDetectTLS(t *testing.T) { Direction: Client, Data: tls13ClientHello, DstPort: 443, - WantProto: ProtocolTLS, + WantType: TypeHTTP, WantConfidence: .95, }, { @@ -225,7 +299,7 @@ func TestDetectTLS(t *testing.T) { Direction: Client, Data: tls11ClientHelloPartial, DstPort: 443, - WantProto: ProtocolTLS, + WantType: TypeTLS, WantConfidence: .50, }, }, tests...)) diff --git a/protocol/detest_test.go b/protocol/detest_test.go index a61dcb1..c92449c 100644 --- a/protocol/detest_test.go +++ b/protocol/detest_test.go @@ -16,7 +16,7 @@ type testCase struct { Data []byte SrcPort int DstPort int - WantProto string + WantType string WantConfidence float64 WantError error } @@ -52,7 +52,7 @@ func testRunner(t *testing.T, tests []*testCase) { } else if test.WantError != nil { t.Fatalf("Detect(%s, %s, %d, %d) returned protocol %q version %s, expected error %q", test.Direction, testBytesSample(test.Data, 8), test.SrcPort, test.DstPort, - proto.Name, proto.Version, test.WantError) + proto.Type, proto.Version, test.WantError) return } @@ -60,16 +60,16 @@ func testRunner(t *testing.T, tests []*testCase) { if proto == nil { t.Fatalf("Detect(%s, %s, %d, %d) returned nil, expected protocol %q", test.Direction, testBytesSample(test.Data, 8), test.SrcPort, test.DstPort, - test.WantProto) + test.WantType) return } - t.Logf("Detect(%s, %s, %d, %d) returned protocol %q version %s with confidence %g%%", + t.Logf("Detect(%s, %s, %d, %d) returned protocol %s with confidence %g%%", test.Direction, testBytesSample(test.Data, 4), test.SrcPort, test.DstPort, - proto.Name, proto.Version, confidence*100) + proto, confidence*100) - if proto.Name != test.WantProto { - t.Errorf("Expected protocol %q", test.WantProto) + if proto.Type != test.WantType { + t.Errorf("Expected protocol %q, got %q", test.WantType, proto.Type) } if !testAlmostEqual(confidence, test.WantConfidence) { t.Errorf("Expected confidence %g%%", test.WantConfidence*100) diff --git a/protocol/protocol.go b/protocol/protocol.go index 845e9e7..452ef35 100644 --- a/protocol/protocol.go +++ b/protocol/protocol.go @@ -7,21 +7,63 @@ import ( // Protocols supported by this package. const ( - ProtocolDNS = "dns" - ProtocolHTTP = "http" - ProtocolMQTT = "mqtt" - ProtocolMySQL = "mysql" - ProtocolPostgreSQL = "postgresql" - ProtocolSSH = "ssh" - ProtocolSSL = "ssl" - ProtocolTLS = "tls" + TypeACME = "ACME" + TypeCoAP = "CoAP" + TypeDNS = "DNS" + TypeFTP = "FTP" + TypeHTTP = "HTTP" + TypeIRC = "IRC" + TypeIMAP = "IMAP" + TypeJabber = TypeXMPP + TypeManageSieve = "ManageSieve" + TypeMosquitto = TypeMQTT + TypeMQTT = "MQTT" + TypeMySQL = "MySQL" + TypeNNTP = "NNTP" + TypePOP3 = "POP3" + TypePgSQL = TypePostgreSQL + TypePostgreSQL = "PostgreSQL" + TypeRADIUS = "RADIUS" + TypeSamba = TypeSMB + TypeSIP = "SIP" + TypeSMB = "SMB" + TypeSSH = "SSH" + TypeSSL = "SSL" + TypeSTUN = "STUN" + TypeSunRPC = "SunRPC" + TypeTLS = "TLS" + TypeWebRTC = "WebRTC" + TypeXMPP = "XMPP" ) +// Protocol description. type Protocol struct { - Name string + // Type of protocol, usually one of the constants defined in this package. + Type string + + // Encapsulation type, usually one of the constants defined in this package. + // + // Empty if there is no encapsulation. + Encapsulation string + + // Version of the protocol. Unknown versions are marked with [UnknownVersion]. Version Version } +func (proto Protocol) String() string { + var s string + if proto.Encapsulation != "" { + s = proto.Type + " (over " + proto.Encapsulation + ")" + } else { + s = proto.Type + } + if proto.Version == UnknownVersion { + return s + } + return s + " version " + proto.Version.String() +} + +// Version of a protocol. type Version struct { Major int Minor int @@ -29,15 +71,20 @@ type Version struct { Extra string } +// UnknownVersion +var UnknownVersion Version + func (v Version) String() string { + if v == UnknownVersion { + return "unknown" + } + p := make([]string, 0, 3) - if v.Major >= 0 { - p = append(p, strconv.Itoa(v.Major)) - if v.Minor >= 0 { - p = append(p, strconv.Itoa(v.Minor)) - if v.Patch >= 0 { - p = append(p, strconv.Itoa(v.Patch)) - } + p = append(p, strconv.Itoa(v.Major)) + if v.Minor >= 0 { + p = append(p, strconv.Itoa(v.Minor)) + if v.Patch >= 0 { + p = append(p, strconv.Itoa(v.Patch)) } } s := strings.Join(p, ".") diff --git a/testdata/tls13-clienthello.bin b/testdata/dump/tls13-clienthello.bin similarity index 100% rename from testdata/tls13-clienthello.bin rename to testdata/dump/tls13-clienthello.bin diff --git a/tls.go b/tls.go index f02e774..037feb4 100644 --- a/tls.go +++ b/tls.go @@ -19,7 +19,7 @@ type TLSExtension struct { type TLSRecord struct { Raw []byte Type uint8 - Version uint16 + Version TLSVersion Length uint16 Data []byte } @@ -30,26 +30,24 @@ func DecodeTLSRecord(data []byte) (*TLSRecord, error) { record = &TLSRecord{Raw: data} ) + var version uint16 if !stream.ReadUint8(&record.Type) || - !stream.ReadUint16(&record.Version) || + !stream.ReadUint16(&version) || !stream.ReadUint16(&record.Length) { return nil, DecodeError{ Reason: "invalid TLS record header", Err: io.ErrUnexpectedEOF, } } + record.Version = TLSVersion(version) + if !stream.ReadBytes(&record.Data, int(record.Length)) { return nil, DecodeError{ Reason: "invalid TLS record data", Err: io.ErrUnexpectedEOF, } } - if !stream.Empty() { - return nil, DecodeError{ - Reason: "extraneous data after TLS record", - Err: ErrInvalid, - } - } + return record, nil } @@ -166,12 +164,6 @@ func DecodeTLSClientHello(data []byte) (*TLSClientHello, error) { Err: io.ErrUnexpectedEOF, } } - if !record.Empty() { - return nil, DecodeError{ - Reason: "extraneous TLS extension data", - Err: io.ErrUnexpectedEOF, - } - } for !extensions.Empty() { var ( @@ -260,6 +252,7 @@ type TLSServerHello struct { CipherSuite uint16 CompressionMethod uint8 Extensions []TLSExtension + ALPNProtocols []string // RFC 7301, Section 3.1 } func DecodeTLSServerHello(data []byte) (*TLSServerHello, error) { @@ -347,12 +340,6 @@ func DecodeTLSServerHello(data []byte) (*TLSServerHello, error) { Err: io.ErrUnexpectedEOF, } } - if !record.Empty() { - return nil, DecodeError{ - Reason: "extraneous TLS extension data", - Err: io.ErrUnexpectedEOF, - } - } for !extensions.Empty() { var ( @@ -366,6 +353,11 @@ func DecodeTLSServerHello(data []byte) (*TLSServerHello, error) { } } hello.Extensions = append(hello.Extensions, extension) + + switch extension.Type { + case tlsExtensionALPN: + _ = readTLSALPN(extensionData, &hello.ALPNProtocols) + } } return hello, nil diff --git a/tls_test.go b/tls_test.go index 6224101..d8315f1 100644 --- a/tls_test.go +++ b/tls_test.go @@ -72,6 +72,39 @@ func TestDecodeTLSServerHello(t *testing.T) { t.Fatalf("failed to decode test ServerHello: %s", err) } + tls11ServerHello := []byte{ + // --- Handshake Protocol: ServerHello (64 bytes) --- + 0x02, // Handshake Type: ServerHello (2) + 0x00, 0x00, 0x3c, // Length of the rest of the message (60 bytes) + 0x03, 0x02, // Server Version: TLS 1.1 (Major=3, Minor=2) + + // Random (32 bytes) + 0xb7, 0xa8, 0xdf, 0xd5, 0x17, 0xb1, 0x50, 0xb4, + 0x28, 0xb7, 0xf6, 0xf3, 0xb9, 0x83, 0xcf, 0x9f, + 0x31, 0x55, 0x79, 0x1f, 0x3b, 0x07, 0x6d, 0x17, + 0x44, 0x4f, 0x57, 0x4e, 0x47, 0x52, 0x44, 0x00, + + // Session ID + 0x00, // Session ID Length: 0 (new session) + + 0xc0, 0x09, // Cipher Suite + 0x00, // Compression Method + + // --- Extensions (20 bytes) --- + 0x00, 0x14, // Extensions Length: 20 bytes + 0xff, 0x01, // Extension Type: renegotiation_info + 0x00, 0x01, // Extension Length: 1 byte + 0x00, // Renegotiation Info Length: 0 bytes + 0x00, 0x10, // Extension Type: alpn + 0x00, 0x05, // Extension Length: 5 bytes + 0x00, 0x03, // ALPN Extension Length: 3 bytes + 0x02, 'h', '2', + 0x0, 0x0b, // Extension Type: ec_points_format + 0x0, 0x02, // Extension Length: 2 bytes + 0x01, // EC Points Format Length: 1 byte + 0x00, // EC Point Format: uncompressed + } + t.Run("Server Hello", func(t *testing.T) { hello, err := DecodeTLSServerHello(serverHelloBytes) if err != nil { @@ -80,6 +113,15 @@ func TestDecodeTLSServerHello(t *testing.T) { } t.Logf("%#+v", hello) }) + + t.Run("TLS 1.1 Server Hello", func(t *testing.T) { + hello, err := DecodeTLSServerHello(tls11ServerHello) + if err != nil { + t.Fatal(err) + return + } + t.Logf("%#+v", hello) + }) } func testDecodeHexString(s string) ([]byte, error) {