Files
dpi/tls.go
maze 81a3829382 Refactoring
Refactored Protocol.Name -> Protocol.Type; added Encapsulation field
Refactored TLS parsing; added support for ALPN
2025-10-10 12:41:44 +02:00

448 lines
11 KiB
Go

package dpi
import (
"encoding/binary"
"fmt"
"io"
"time"
"golang.org/x/crypto/cryptobyte"
)
// TLSExtension is a single extension from a TLS hello message, as it appears
// on the wire: a 2-byte extension type followed by an opaque body.
type TLSExtension struct {
	// Type is the 2-byte extension type code point (e.g. server_name, ALPN).
	Type uint16
	// Data is meant to hold the extension body, i.e. the bytes that follow
	// the type and 2-byte length fields.
	Data []byte
}
// TLSRecord is a single TLS record-layer record: a 5-byte header (content
// type, version, length) followed by the record payload.
type TLSRecord struct {
	// Raw holds the raw bytes the record was decoded from, starting at the
	// record header.
	Raw []byte
	// Type is the record content type from the header (e.g. handshake).
	Type uint8
	// Version is the 2-byte protocol version from the record header.
	Version TLSVersion
	// Length is the payload length declared in the record header.
	Length uint16
	// Data is the record payload: the Length bytes following the header.
	Data []byte
}
// DecodeTLSRecord decodes a single TLS record from the start of data.
//
// A record is a 5-byte header (1-byte content type, 2-byte version, 2-byte
// length) followed by Length bytes of payload. data may contain further bytes
// after the record, for example subsequent records in the same segment; they
// are left unconsumed. Raw is set to exactly the bytes of this record, header
// included, so len(record.Raw) is the number of bytes consumed from data.
// Raw and Data are sub-slices of data and are not copied.
//
// If data is too short for the header or for the declared payload, a
// DecodeError with Err set to io.ErrUnexpectedEOF is returned.
func DecodeTLSRecord(data []byte) (*TLSRecord, error) {
	var (
		stream  = cryptobyte.String(data)
		record  = new(TLSRecord)
		version uint16
	)
	if !stream.ReadUint8(&record.Type) ||
		!stream.ReadUint16(&version) ||
		!stream.ReadUint16(&record.Length) {
		return nil, DecodeError{
			Reason: "invalid TLS record header",
			Err:    io.ErrUnexpectedEOF,
		}
	}
	record.Version = TLSVersion(version)
	if !stream.ReadBytes(&record.Data, int(record.Length)) {
		return nil, DecodeError{
			Reason: "invalid TLS record data",
			Err:    io.ErrUnexpectedEOF,
		}
	}
	// stream now holds only the bytes after this record. Trim Raw so that it
	// does not include trailing data belonging to the next record(s).
	record.Raw = data[:len(data)-len(stream)]
	return record, nil
}
// TLSClientHello is a decoded TLS ClientHello handshake message, the first
// message a client sends in the TLS handshake.
type TLSClientHello struct {
	// Raw is the encoded handshake message (type, length and body) the hello
	// was decoded from.
	Raw []byte
	// Version is the version field from the hello body (legacy_version in
	// TLS 1.3; TLS 1.3 clients list their real versions in SupportedVersions).
	Version TLSVersion
	// Random is the 32-byte client random.
	Random []byte
	// SessionID is the (legacy) session ID sent by the client; may be empty.
	SessionID []byte
	// CipherSuites lists the offered cipher suite code points in wire order.
	CipherSuites []uint16
	// CompressionMethods lists the offered compression method identifiers.
	CompressionMethods []uint8
	// Extensions lists every extension in wire order, recognised or not.
	Extensions []TLSExtension
	// ServerName is the host name from the server_name (SNI) extension;
	// empty if the extension is absent. RFC 6066, Section 3.
	ServerName string
	// SupportedCurves lists the supported_groups extension entries
	// (aka "Supported Groups"). RFC 8446, Section 4.2.7.
	SupportedCurves []uint16
	// SupportedSignatureAlgorithms lists the signature_algorithms entries.
	// RFC 5246, Section 7.4.1.4.1.
	SupportedSignatureAlgorithms []TLSSignatureScheme
	// ALPNProtocols lists the offered application protocols.
	// RFC 7301, Section 3.1.
	ALPNProtocols []string
	// SupportedVersions lists the supported_versions extension entries.
	// RFC 8446, Section 4.2.1.
	SupportedVersions []TLSVersion
}
// DecodeTLSClientHello decodes a single TLS handshake message that must be a
// ClientHello.
//
// data must hold exactly one handshake message: a 1-byte handshake type, a
// 3-byte body length and the body itself. Bytes beyond the declared length are
// rejected. The returned hello keeps references into data: Raw is data itself,
// and Random, SessionID, CompressionMethods and every Extensions[i].Data are
// sub-slices of it. data must therefore not be modified while the hello is in
// use.
//
// Every extension is recorded in Extensions in wire order, with its raw body
// in Data. The server_name, supported_groups, signature_algorithms, ALPN and
// supported_versions extensions are also decoded into their dedicated fields.
// A malformed body for any of those makes the whole decode fail.
//
// On failure a DecodeError is returned. Its Err is io.ErrUnexpectedEOF for
// truncated or malformed fields, or ErrInvalid for a wrong handshake type or
// trailing bytes after the message.
func DecodeTLSClientHello(data []byte) (*TLSClientHello, error) {
	var (
		stream = cryptobyte.String(data)
		hello  = &TLSClientHello{Raw: data}
	)

	// Handshake header: msg_type (1 byte) + uint24 body length.
	var handshakeType uint8
	if !stream.ReadUint8(&handshakeType) || handshakeType != tlsTypeClientHello {
		return nil, DecodeError{
			Reason: fmt.Sprintf("expected a TLS ClientHello (0x%02X), got 0x%02X", tlsTypeClientHello, handshakeType),
			Err:    ErrInvalid,
		}
	}
	var record cryptobyte.String
	if !stream.ReadUint24LengthPrefixed(&record) {
		return nil, DecodeError{
			Reason: "incomplete TLS record",
			Err:    io.ErrUnexpectedEOF,
		}
	}
	// The declared length must cover the input exactly.
	if !stream.Empty() {
		return nil, DecodeError{
			Reason: "invalid TLS record length",
			Err:    ErrInvalid,
		}
	}

	// Parse client version.
	var version uint16
	if !record.ReadUint16(&version) {
		return nil, DecodeError{
			Reason: "incomplete TLS version",
			Err:    io.ErrUnexpectedEOF,
		}
	}
	hello.Version = TLSVersion(version)

	// Parse random (32 bytes).
	if !record.ReadBytes(&hello.Random, 32) {
		return nil, DecodeError{
			Reason: "incomplete TLS random bytes",
			Err:    io.ErrUnexpectedEOF,
		}
	}

	// Parse session ID (uint8 length-prefixed).
	var sessionID cryptobyte.String
	if !record.ReadUint8LengthPrefixed(&sessionID) {
		return nil, DecodeError{
			Reason: "incomplete TLS session ID",
			Err:    io.ErrUnexpectedEOF,
		}
	}
	hello.SessionID = sessionID

	// Parse cipher suites (uint16 length-prefixed list of uint16).
	var cipherSuites cryptobyte.String
	if !record.ReadUint16LengthPrefixed(&cipherSuites) {
		return nil, DecodeError{
			Reason: "incomplete TLS cipher suites bytes",
			Err:    io.ErrUnexpectedEOF,
		}
	}
	for !cipherSuites.Empty() {
		var cipherSuite uint16
		if !cipherSuites.ReadUint16(&cipherSuite) {
			return nil, DecodeError{
				Reason: "incomplete TLS cipher suite",
				Err:    io.ErrUnexpectedEOF,
			}
		}
		hello.CipherSuites = append(hello.CipherSuites, cipherSuite)
	}

	// Parse compression methods (uint8 length-prefixed list of uint8).
	var compressionMethods cryptobyte.String
	if !record.ReadUint8LengthPrefixed(&compressionMethods) {
		return nil, DecodeError{
			Reason: "incomplete TLS compression methods bytes",
			Err:    io.ErrUnexpectedEOF,
		}
	}
	hello.CompressionMethods = compressionMethods

	// Parse extensions (optional: a hello may end right after the
	// compression methods).
	if record.Empty() {
		return hello, nil
	}
	var extensions cryptobyte.String
	if !record.ReadUint16LengthPrefixed(&extensions) {
		return nil, DecodeError{
			Reason: "incomplete TLS extensions",
			Err:    io.ErrUnexpectedEOF,
		}
	}
	for !extensions.Empty() {
		var (
			extension     TLSExtension
			extensionData cryptobyte.String
		)
		if !extensions.ReadUint16(&extension.Type) || !extensions.ReadUint16LengthPrefixed(&extensionData) {
			return nil, DecodeError{
				Reason: "incomplete TLS extension record data",
				Err:    io.ErrUnexpectedEOF,
			}
		}
		// Keep the raw body so callers can inspect extensions that are not
		// decoded below. The read* helpers take extensionData by value, so
		// consuming it there does not shorten Data.
		extension.Data = extensionData
		hello.Extensions = append(hello.Extensions, extension)

		switch extension.Type {
		case tlsExtensionServerName:
			// RFC 6066, Section 3
			if !readTLSServerName(extensionData, &hello.ServerName) {
				return nil, DecodeError{
					Reason: "invalid TLS server name extension data",
					Err:    io.ErrUnexpectedEOF,
				}
			}
		case tlsExtensionSupportedGroups:
			// RFC 4492, Section 5.1.1
			// RFC 8446, Section 4.2.7
			if !readTLSSupportedGroups(extensionData, &hello.SupportedCurves) {
				return nil, DecodeError{
					Reason: "invalid TLS supported groups extension data",
					Err:    io.ErrUnexpectedEOF,
				}
			}
		case tlsExtensionSignatureAlgorithms:
			// RFC 5246, Section 7.4.1.4.1
			if !readTLSSignatureAlgorithms(extensionData, &hello.SupportedSignatureAlgorithms) {
				return nil, DecodeError{
					Reason: "invalid TLS supported signature algorithms extension data",
					Err:    io.ErrUnexpectedEOF,
				}
			}
		case tlsExtensionALPN:
			// RFC 7301, Section 3.1
			if !readTLSALPN(extensionData, &hello.ALPNProtocols) {
				return nil, DecodeError{
					Reason: "invalid TLS ALPN extension data",
					Err:    io.ErrUnexpectedEOF,
				}
			}
		case tlsExtensionSupportedVersions:
			// RFC 8446, Section 4.2.1
			if !readTLSSupportedVersions(extensionData, &hello.SupportedVersions) {
				return nil, DecodeError{
					Reason: "invalid TLS supported versions extension data",
					Err:    io.ErrUnexpectedEOF,
				}
			}
		}
	}
	return hello, nil
}
// DecodeTLSClientHelloHandshake decodes a TLS record from data and, provided
// it is a handshake record, decodes its payload as a ClientHello.
//
// Errors from DecodeTLSRecord and DecodeTLSClientHello are returned unchanged.
// A record of any other content type yields a DecodeError with Err set to
// ErrInvalid.
func DecodeTLSClientHelloHandshake(data []byte) (*TLSClientHello, error) {
	rec, err := DecodeTLSRecord(data)
	switch {
	case err != nil:
		return nil, err
	case rec.Type != tlsRecordTypeHandshake:
		return nil, DecodeError{
			Reason: fmt.Sprintf("expected TLS handshake record type (0x%02X), got 0x%02X", tlsRecordTypeHandshake, rec.Type),
			Err:    ErrInvalid,
		}
	default:
		return DecodeTLSClientHello(rec.Data)
	}
}
// TLSServerHello is a decoded TLS ServerHello handshake message, the server's
// reply to a ClientHello.
type TLSServerHello struct {
	// Raw is the encoded handshake message (type, length and body) the hello
	// was decoded from.
	Raw []byte
	// Version is the version field from the hello body (legacy_version in
	// TLS 1.3, where the negotiated version is carried in an extension).
	Version TLSVersion
	// RandomTimestamp is the first four bytes of Random read as a big-endian
	// Unix time (gmt_unix_time in RFC 5246, Section 7.4.1.2). TLS 1.3 and many
	// TLS 1.2 servers fill Random entirely with random bytes, so treat this
	// value as unreliable.
	RandomTimestamp time.Time
	// Random is the 32-byte server random.
	Random []byte
	// SessionID is the session ID echoed or assigned by the server.
	SessionID []byte
	// CipherSuite is the cipher suite selected by the server.
	CipherSuite uint16
	// CompressionMethod is the compression method selected by the server.
	CompressionMethod uint8
	// Extensions lists every extension in wire order, recognised or not.
	Extensions []TLSExtension
	// ALPNProtocols holds the protocol(s) from the ALPN extension.
	// RFC 7301, Section 3.1.
	ALPNProtocols []string
}
// DecodeTLSServerHello decodes a single TLS handshake message that must be a
// ServerHello.
//
// data must hold exactly one handshake message: a 1-byte handshake type, a
// 3-byte body length and the body itself. Bytes beyond the declared length are
// rejected. The returned hello keeps references into data: Raw is data itself,
// and Random, SessionID and every Extensions[i].Data are sub-slices of it.
//
// Every extension is recorded in Extensions in wire order, with its raw body
// in Data. Only ALPN is decoded further, and only on a best-effort basis (see
// below).
//
// On failure a DecodeError is returned. Its Err is io.ErrUnexpectedEOF for
// truncated fields, or ErrInvalid for a wrong handshake type or trailing
// bytes after the message.
func DecodeTLSServerHello(data []byte) (*TLSServerHello, error) {
	var (
		stream = cryptobyte.String(data)
		hello  = &TLSServerHello{Raw: data}
	)

	// Handshake header: msg_type (1 byte) + uint24 body length.
	var handshakeType uint8
	if !stream.ReadUint8(&handshakeType) || handshakeType != tlsTypeServerHello {
		return nil, DecodeError{
			Reason: fmt.Sprintf("expected a TLS ServerHello (0x%02X), got 0x%02X", tlsTypeServerHello, handshakeType),
			Err:    ErrInvalid,
		}
	}
	var record cryptobyte.String
	if !stream.ReadUint24LengthPrefixed(&record) {
		return nil, DecodeError{
			Reason: "incomplete TLS record",
			Err:    io.ErrUnexpectedEOF,
		}
	}
	// The declared length must cover the input exactly.
	if !stream.Empty() {
		return nil, DecodeError{
			Reason: "invalid TLS record length",
			Err:    ErrInvalid,
		}
	}

	// Parse server version.
	var version uint16
	if !record.ReadUint16(&version) {
		return nil, DecodeError{
			Reason: "incomplete TLS version",
			Err:    io.ErrUnexpectedEOF,
		}
	}
	hello.Version = TLSVersion(version)

	// Parse random (32 bytes).
	if !record.ReadBytes(&hello.Random, 32) {
		return nil, DecodeError{
			Reason: "incomplete TLS random bytes",
			Err:    io.ErrUnexpectedEOF,
		}
	}
	// Interpret the first 4 random bytes as gmt_unix_time (RFC 5246). Modern
	// servers send purely random bytes here, so this value may be meaningless.
	hello.RandomTimestamp = time.Unix(int64(binary.BigEndian.Uint32(hello.Random)), 0)

	// Parse session ID (uint8 length-prefixed).
	var sessionID cryptobyte.String
	if !record.ReadUint8LengthPrefixed(&sessionID) {
		return nil, DecodeError{
			Reason: "incomplete TLS session ID",
			Err:    io.ErrUnexpectedEOF,
		}
	}
	hello.SessionID = sessionID

	// Parse the selected cipher suite.
	if !record.ReadUint16(&hello.CipherSuite) {
		return nil, DecodeError{
			Reason: "incomplete TLS cipher suite",
			Err:    io.ErrUnexpectedEOF,
		}
	}

	// Parse the selected compression method.
	if !record.ReadUint8(&hello.CompressionMethod) {
		return nil, DecodeError{
			Reason: "incomplete TLS compression method",
			Err:    io.ErrUnexpectedEOF,
		}
	}

	// Parse extensions (optional).
	if record.Empty() {
		return hello, nil
	}
	var extensions cryptobyte.String
	if !record.ReadUint16LengthPrefixed(&extensions) {
		return nil, DecodeError{
			Reason: "incomplete TLS extensions",
			Err:    io.ErrUnexpectedEOF,
		}
	}
	for !extensions.Empty() {
		var (
			extension     TLSExtension
			extensionData cryptobyte.String
		)
		if !extensions.ReadUint16(&extension.Type) || !extensions.ReadUint16LengthPrefixed(&extensionData) {
			return nil, DecodeError{
				Reason: "incomplete TLS extension record data",
				Err:    io.ErrUnexpectedEOF,
			}
		}
		// Keep the raw body so callers can inspect any extension. The read*
		// helper takes extensionData by value, so it does not shorten Data.
		extension.Data = extensionData
		hello.Extensions = append(hello.Extensions, extension)

		switch extension.Type {
		case tlsExtensionALPN:
			// RFC 7301, Section 3.1. The result is intentionally discarded: a
			// malformed ALPN body does not fail the ServerHello decode. In that
			// case ALPNProtocols may be empty or hold only the entries read
			// before the malformed one.
			_ = readTLSALPN(extensionData, &hello.ALPNProtocols)
		}
	}
	return hello, nil
}
// readTLSServerName decodes the body of a server_name extension (RFC 6066,
// Section 3).
//
// The body is a non-empty, uint16 length-prefixed list of entries. Each entry
// is a 1-byte name type followed by a non-empty, uint16 length-prefixed name.
// Only host_name entries (type 0) are used. The first one is stored in
// *serverName, but only if *serverName is still empty. Entries of other types
// are skipped. It returns false if the list is missing or empty, or if any
// entry is truncated or has an empty name. Bytes after the list are ignored.
func readTLSServerName(data cryptobyte.String, serverName *string) bool {
	var entries cryptobyte.String
	if !data.ReadUint16LengthPrefixed(&entries) {
		return false
	}
	if entries.Empty() {
		return false
	}
	for len(entries) > 0 {
		var (
			kind uint8
			host cryptobyte.String
		)
		ok := entries.ReadUint8(&kind) && entries.ReadUint16LengthPrefixed(&host) && !host.Empty()
		if !ok {
			return false
		}
		if kind == 0 && *serverName == "" {
			*serverName = string(host)
		}
	}
	return true
}
// readTLSSupportedGroups decodes the body of a supported_groups extension
// (RFC 8446, Section 4.2.7). The body is a non-empty, uint16 length-prefixed
// list of uint16 group identifiers.
//
// Each identifier is appended to *supported in wire order. It returns false
// if the list is missing or empty, or if it ends in a partial entry. Entries
// decoded before a failure stay appended. Bytes after the list are ignored.
func readTLSSupportedGroups(data cryptobyte.String, supported *[]uint16) bool {
	var ids cryptobyte.String
	if ok := data.ReadUint16LengthPrefixed(&ids); !ok {
		return false
	}
	if len(ids) == 0 {
		return false
	}
	for len(ids) != 0 {
		var id uint16
		if ok := ids.ReadUint16(&id); !ok {
			return false
		}
		*supported = append(*supported, id)
	}
	return true
}
// readTLSSignatureAlgorithms decodes the body of a signature_algorithms
// extension. The body is a non-empty, uint16 length-prefixed list of uint16
// signature scheme code points.
//
// Each scheme is appended to *supported in wire order. It returns false if the
// list is missing or empty, or if it ends in a partial entry. Entries decoded
// before a failure stay appended. Bytes after the list are ignored.
func readTLSSignatureAlgorithms(data cryptobyte.String, supported *[]TLSSignatureScheme) bool {
	var schemes cryptobyte.String
	if ok := data.ReadUint16LengthPrefixed(&schemes); !ok {
		return false
	}
	if len(schemes) == 0 {
		return false
	}
	for len(schemes) != 0 {
		var code uint16
		if ok := schemes.ReadUint16(&code); !ok {
			return false
		}
		*supported = append(*supported, TLSSignatureScheme(code))
	}
	return true
}
// readTLSALPN decodes the body of an application_layer_protocol_negotiation
// extension (RFC 7301, Section 3.1). The body is a non-empty, uint16
// length-prefixed list of non-empty, uint8 length-prefixed protocol names.
//
// Each name is appended to *alpnProtocols in wire order. It returns false if
// the list is missing or empty, or if any name is truncated or empty. Names
// decoded before a failure stay appended. Bytes after the list are ignored.
func readTLSALPN(data cryptobyte.String, alpnProtocols *[]string) bool {
	var names cryptobyte.String
	if !data.ReadUint16LengthPrefixed(&names) {
		return false
	}
	if names.Empty() {
		return false
	}
	for len(names) > 0 {
		var name cryptobyte.String
		if ok := names.ReadUint8LengthPrefixed(&name) && !name.Empty(); !ok {
			return false
		}
		*alpnProtocols = append(*alpnProtocols, string(name))
	}
	return true
}
// readTLSSupportedVersions decodes the ClientHello form of a
// supported_versions extension (RFC 8446, Section 4.2.1). The body is a
// non-empty, uint8 length-prefixed list of uint16 protocol versions.
//
// Each version is appended to *versions in wire order. It returns false if the
// list is missing or empty, or if it ends in a partial entry. Entries decoded
// before a failure stay appended. Bytes after the list are ignored.
func readTLSSupportedVersions(data cryptobyte.String, versions *[]TLSVersion) bool {
	var offered cryptobyte.String
	if !data.ReadUint8LengthPrefixed(&offered) {
		return false
	}
	if offered.Empty() {
		return false
	}
	for len(offered) > 0 {
		var v uint16
		if ok := offered.ReadUint16(&v); !ok {
			return false
		}
		*versions = append(*versions, TLSVersion(v))
	}
	return true
}