package dpi import ( "encoding/binary" "fmt" "io" "time" "golang.org/x/crypto/cryptobyte" ) // TLSExtension is a TLS extension. type TLSExtension struct { Type uint16 Data []byte } // TLSRecord is a TLS record. type TLSRecord struct { Raw []byte Type uint8 Version uint16 Length uint16 Data []byte } func DecodeTLSRecord(data []byte) (*TLSRecord, error) { var ( stream = cryptobyte.String(data) record = &TLSRecord{Raw: data} ) if !stream.ReadUint8(&record.Type) || !stream.ReadUint16(&record.Version) || !stream.ReadUint16(&record.Length) { return nil, DecodeError{ Reason: "invalid TLS record header", Err: io.ErrUnexpectedEOF, } } if !stream.ReadBytes(&record.Data, int(record.Length)) { return nil, DecodeError{ Reason: "invalid TLS record data", Err: io.ErrUnexpectedEOF, } } if !stream.Empty() { return nil, DecodeError{ Reason: "extraneous data after TLS record", Err: ErrInvalid, } } return record, nil } // TLSClientHello is a TLS ClientHello packet as part of the TLS handshake. type TLSClientHello struct { Raw []byte Version TLSVersion Random []byte SessionID []byte CipherSuites []uint16 CompressionMethods []uint8 Extensions []TLSExtension ServerName string SupportedCurves []uint16 // aka "Supported Groups" SupportedSignatureAlgorithms []TLSSignatureScheme // RFC 5246, Section 7.4.1.4.1 ALPNProtocols []string // RFC 7301, Section 3.1 SupportedVersions []TLSVersion // RFC 8446, Section 4.2.1z } func DecodeTLSClientHello(data []byte) (*TLSClientHello, error) { var ( stream = cryptobyte.String(data) hello = &TLSClientHello{Raw: data} ) // Read header (4 bytes) var handshakeType uint8 if !stream.ReadUint8(&handshakeType) || handshakeType != tlsTypeClientHello { return nil, DecodeError{ Reason: fmt.Sprintf("expected a TLS ClientHello (0x%02X), got 0x%02X", tlsTypeClientHello, handshakeType), Err: ErrInvalid, } } var record cryptobyte.String if !stream.ReadUint24LengthPrefixed(&record) { return nil, DecodeError{ Reason: "incomplete TLS record", Err: io.ErrUnexpectedEOF, } } if !stream.Empty() { return nil, DecodeError{ Reason: "invalid TLS record length", Err: ErrInvalid, } } // Parser client version. var version uint16 if !record.ReadUint16(&version) { return nil, DecodeError{ Reason: "incomplete TLS version", Err: io.ErrUnexpectedEOF, } } hello.Version = TLSVersion(version) // Parse random (32 bytes) if !record.ReadBytes(&hello.Random, 32) { return nil, DecodeError{ Reason: "incomplete TLS random bytes", Err: io.ErrUnexpectedEOF, } } // Parse session ID var sessionID cryptobyte.String if !record.ReadUint8LengthPrefixed(&sessionID) { return nil, DecodeError{ Reason: "incomplete TLS session ID", Err: io.ErrUnexpectedEOF, } } hello.SessionID = sessionID // Parse cipher suites var cipherSuites cryptobyte.String if !record.ReadUint16LengthPrefixed(&cipherSuites) { return nil, DecodeError{ Reason: "incomplete TLS cipher suites bytes", Err: io.ErrUnexpectedEOF, } } for !cipherSuites.Empty() { var cipherSuite uint16 if !cipherSuites.ReadUint16(&cipherSuite) { return nil, DecodeError{ Reason: "incomplete TLS cipher suite", Err: io.ErrUnexpectedEOF, } } hello.CipherSuites = append(hello.CipherSuites, cipherSuite) } // Parse compression methods var compressionMethods cryptobyte.String if !record.ReadUint8LengthPrefixed(&compressionMethods) { return nil, DecodeError{ Reason: "incomplete TLS compression methods bytes", Err: io.ErrUnexpectedEOF, } } hello.CompressionMethods = compressionMethods // Parse extensions (optional) if record.Empty() { return hello, nil } var extensions cryptobyte.String if !record.ReadUint16LengthPrefixed(&extensions) { return nil, DecodeError{ Reason: "incomplete TLS extensions", Err: io.ErrUnexpectedEOF, } } if !record.Empty() { return nil, DecodeError{ Reason: "extraneous TLS extension data", Err: io.ErrUnexpectedEOF, } } for !extensions.Empty() { var ( extension TLSExtension extensionData = cryptobyte.String(extension.Data) ) if !extensions.ReadUint16(&extension.Type) || !extensions.ReadUint16LengthPrefixed(&extensionData) { return nil, DecodeError{ Reason: "incomplete TLS extension record data", Err: io.ErrUnexpectedEOF, } } hello.Extensions = append(hello.Extensions, extension) switch extension.Type { case tlsExtensionServerName: // RFC 6066, Section 3 if !readTLSServerName(extensionData, &hello.ServerName) { return nil, DecodeError{ Reason: "invalid TLS server name extension data", Err: io.ErrUnexpectedEOF, } } case tlsExtensionSupportedGroups: // RFC 4492, Section 5.1.1 // RFC 8446, Section 4.2.7 if !readTLSSupportedGroups(extensionData, &hello.SupportedCurves) { return nil, DecodeError{ Reason: "invalid TLS supported groups extension data", Err: io.ErrUnexpectedEOF, } } case tlsExtensionSignatureAlgorithms: // RFC 5246, Section 7.4.1 if !readTLSSignatureAlgorithms(extensionData, &hello.SupportedSignatureAlgorithms) { return nil, DecodeError{ Reason: "invalid TLS supported signature algorithms extension data", Err: io.ErrUnexpectedEOF, } } case tlsExtensionALPN: if !readTLSALPN(extensionData, &hello.ALPNProtocols) { return nil, DecodeError{ Reason: "invalid TLS ALPN extension data", Err: io.ErrUnexpectedEOF, } } case tlsExtensionSupportedVersions: if !readTLSSupportedVersions(extensionData, &hello.SupportedVersions) { return nil, DecodeError{ Reason: "invalid TLS supported versions extension data", Err: io.ErrUnexpectedEOF, } } } } return hello, nil } func DecodeTLSClientHelloHandshake(data []byte) (*TLSClientHello, error) { record, err := DecodeTLSRecord(data) if err != nil { return nil, err } if record.Type != tlsRecordTypeHandshake { return nil, DecodeError{ Reason: fmt.Sprintf("expected TLS handshake record type (0x%02X), got 0x%02X", tlsRecordTypeHandshake, record.Type), Err: ErrInvalid, } } return DecodeTLSClientHello(record.Data) } // TLSServerHello is a TLS ServerHello packet as part of the TLS handshake. type TLSServerHello struct { Raw []byte Version TLSVersion RandomTimestamp time.Time Random []byte SessionID []byte CipherSuite uint16 CompressionMethod uint8 Extensions []TLSExtension } func DecodeTLSServerHello(data []byte) (*TLSServerHello, error) { var ( stream = cryptobyte.String(data) hello = &TLSServerHello{Raw: data} ) // Read header (4 bytes) var handshakeType uint8 if !stream.ReadUint8(&handshakeType) || handshakeType != tlsTypeServerHello { return nil, DecodeError{ Reason: fmt.Sprintf("expected a TLS ServerHello (0x%02X), got 0x%02X", tlsTypeServerHello, handshakeType), Err: ErrInvalid, } } var record cryptobyte.String if !stream.ReadUint24LengthPrefixed(&record) { return nil, DecodeError{ Reason: "incomplete TLS record", Err: io.ErrUnexpectedEOF, } } if !stream.Empty() { return nil, DecodeError{ Reason: "invalid TLS record length", Err: ErrInvalid, } } // Parser server version. var version uint16 if !record.ReadUint16(&version) { return nil, DecodeError{ Reason: "incomplete TLS version", Err: io.ErrUnexpectedEOF, } } hello.Version = TLSVersion(version) // Parse random (32 bytes) if !record.ReadBytes(&hello.Random, 32) { return nil, DecodeError{ Reason: "incomplete TLS random bytes", Err: io.ErrUnexpectedEOF, } } hello.RandomTimestamp = time.Unix(int64(binary.BigEndian.Uint32(hello.Random)), 0) // Parse session ID var sessionID cryptobyte.String if !record.ReadUint8LengthPrefixed(&sessionID) { return nil, DecodeError{ Reason: "incomplete TLS session ID", Err: io.ErrUnexpectedEOF, } } hello.SessionID = sessionID // Parse cipher suite if !record.ReadUint16(&hello.CipherSuite) { return nil, DecodeError{ Reason: "incomplete TLS cipher suite", Err: io.ErrUnexpectedEOF, } } // Parse compression method if !record.ReadUint8(&hello.CompressionMethod) { return nil, DecodeError{ Reason: "incomplete TLS compression method", Err: io.ErrUnexpectedEOF, } } // Parse extensions (optional) if record.Empty() { return hello, nil } var extensions cryptobyte.String if !record.ReadUint16LengthPrefixed(&extensions) { return nil, DecodeError{ Reason: "incomplete TLS extensions", Err: io.ErrUnexpectedEOF, } } if !record.Empty() { return nil, DecodeError{ Reason: "extraneous TLS extension data", Err: io.ErrUnexpectedEOF, } } for !extensions.Empty() { var ( extension TLSExtension extensionData = cryptobyte.String(extension.Data) ) if !extensions.ReadUint16(&extension.Type) || !extensions.ReadUint16LengthPrefixed(&extensionData) { return nil, DecodeError{ Reason: "incomplete TLS extension record data", Err: io.ErrUnexpectedEOF, } } hello.Extensions = append(hello.Extensions, extension) } return hello, nil } func readTLSServerName(data cryptobyte.String, serverName *string) bool { var list cryptobyte.String if !data.ReadUint16LengthPrefixed(&list) || list.Empty() { return false } for !list.Empty() { var ( nameType uint8 name cryptobyte.String ) if !list.ReadUint8(&nameType) || !list.ReadUint16LengthPrefixed(&name) || name.Empty() { return false } if nameType != 0 { continue } if *serverName == "" { *serverName = string(name) } } return true } func readTLSSupportedGroups(data cryptobyte.String, supported *[]uint16) bool { var groups cryptobyte.String if !data.ReadUint16LengthPrefixed(&groups) || groups.Empty() { return false } for !groups.Empty() { var group uint16 if !groups.ReadUint16(&group) { return false } *supported = append(*supported, group) } return true } func readTLSSignatureAlgorithms(data cryptobyte.String, supported *[]TLSSignatureScheme) bool { var algorithms cryptobyte.String if !data.ReadUint16LengthPrefixed(&algorithms) || algorithms.Empty() { return false } for !algorithms.Empty() { var algorithm uint16 if !algorithms.ReadUint16(&algorithm) { return false } *supported = append(*supported, TLSSignatureScheme(algorithm)) } return true } func readTLSALPN(data cryptobyte.String, alpnProtocols *[]string) bool { var list cryptobyte.String if !data.ReadUint16LengthPrefixed(&list) || list.Empty() { return false } for !list.Empty() { var proto cryptobyte.String if !list.ReadUint8LengthPrefixed(&proto) || proto.Empty() { return false } *alpnProtocols = append(*alpnProtocols, string(proto)) } return true } func readTLSSupportedVersions(data cryptobyte.String, versions *[]TLSVersion) bool { var list cryptobyte.String if !data.ReadUint8LengthPrefixed(&list) || list.Empty() { return false } for !list.Empty() { var version uint16 if !list.ReadUint16(&version) { return false } *versions = append(*versions, TLSVersion(version)) } return true }