Files
dpi/protocol/detect/dns/detect.go
2025-10-09 17:58:40 +02:00

162 lines
3.7 KiB
Go

// Package dns implements DNS protocol detection.
//
// This package doesn't expose any public functions, but registers itself for use in protocol detection.
//
// # How to use this package
//
// Import this package into your project in order to enable DNS protocol detection:
//
//	import _ "git.maze.io/go/dpi/protocol/detect/dns" // Register DNS protocol detection
package dns
import (
"encoding/hex"
"log"
"math"
"strings"
"github.com/miekg/dns"
"git.maze.io/go/dpi/protocol"
)
// protocolName is the DNS protocol name.
const protocolName = "dns"

// Score adjustments that detectDNS applies to the class and type of the
// question when a message carries exactly one question. Classes and types
// that are not listed receive the corresponding *Unknown adjustment instead.
var (
	// classTypeScoreUnknown is added for a question class not in classTypeScore.
	classTypeScoreUnknown = -.15
	// classTypeScore maps a question class to its score adjustment.
	classTypeScore = map[uint16]float64{
		dns.ClassINET:  .1,
		dns.ClassCHAOS: .05,
	}

	// recordTypeScoreUnknown is added for a question type not in recordTypeScore.
	recordTypeScoreUnknown = -.15
	// recordTypeScore maps a question type to its score adjustment.
	recordTypeScore = map[uint16]float64{
		// Common types
		dns.TypeA:     .2,
		dns.TypeAAAA:  .2,
		dns.TypeCNAME: .2,
		dns.TypeMX:    .2,
		dns.TypeNS:    .2,
		dns.TypePTR:   .2,
		dns.TypeSOA:   .15,
		dns.TypeTXT:   .15,
		dns.TypeSRV:   .15,
		dns.TypeCAA:   .15,
		// Service binding types (RFC 9460). HTTPS queries are routinely sent
		// next to A/AAAA lookups, so they must not be scored as unknown.
		dns.TypeHTTPS: .15,
		dns.TypeSVCB:  .1,
		// Common types related to DNSSEC
		dns.TypeDNSKEY: .1,
		dns.TypeDS:     .1,
		dns.TypeNSEC:   .1,
		dns.TypeNSEC3:  .1,
		// Rare/obsolete
		dns.TypeHINFO: -.1,
		dns.TypeNULL:  -.1,
	}
)
func init() {
// Every DNS packet (query or answer) has a 12-byte header.
log.Println("register DetectDNS")
protocol.Register(protocol.Both, "????????????", detectDNS)
}
// debugDetect, when true, makes detectDNS log a hex dump of every payload it
// inspects. Keep it false outside of local troubleshooting: detection runs for
// every candidate payload, and the dump may copy sensitive traffic into logs.
const debugDetect = false

// detectDNS can detect DNS queries and answers from the provided data.
//
// data is unpacked with miekg/dns; anything that does not parse as a DNS
// message yields (nil, 0). For a parsed message the confidence is a heuristic
// score built from:
//   - the ports (53 on either side),
//   - the header (opcode and rcode),
//   - the question section (count, class and type),
//   - whether the answer and additional sections are consistent with the QR bit.
//
// The returned confidence is clipped to [0, 0.99]. dir is currently unused.
func detectDNS(dir protocol.Direction, data []byte, srcPort, dstPort int) (proto *protocol.Protocol, confidence float64) {
	if debugDetect {
		log.Printf("detect dns: %q", hex.EncodeToString(data))
	}

	// Parsing using miekg/dns
	msg := new(dns.Msg)
	if msg.Unpack(data) != nil {
		return nil, 0
	}

	if srcPort == 53 || dstPort == 53 {
		confidence = .1
	}

	// Base confidence for a DNS packet; a lot of things may look like DNS, so our initial
	// confidence isn't very high.
	confidence += .45

	if msg.Opcode == dns.OpcodeQuery {
		confidence += .1
	} else {
		confidence -= .2
	}

	switch msg.Rcode {
	case dns.RcodeSuccess: // NOERROR
		confidence += .1
	case dns.RcodeNameError: // NXDOMAIN
		confidence += .05
	default:
		confidence -= .05
	}

	if len(msg.Question) == 1 {
		// Contains exactly one question, this is most common.
		confidence += .2
		q := msg.Question[0]
		if score, ok := classTypeScore[q.Qclass]; ok {
			confidence += score
		} else {
			confidence += classTypeScoreUnknown
		}
		if score, ok := recordTypeScore[q.Qtype]; ok {
			confidence += score
		} else {
			confidence += recordTypeScoreUnknown
		}
	} else {
		// Contains zero or more than one question, this is very uncommon.
		confidence -= .2
	}

	if msg.Response {
		if msg.Authoritative {
			confidence += .05
		}
		if len(msg.Question) == 1 && len(msg.Answer) > 0 {
			question := msg.Question[0].Name
			answer := msg.Answer[0].Header().Name
			switch {
			case strings.EqualFold(question, answer):
				confidence += .1
			case hasCNAME(msg.Answer):
				// The first answer is for a different name, which is expected
				// when the queried name is an alias.
				confidence += .05
			default:
				confidence -= .1
			}
		}
	} else if len(msg.Answer) > 0 || hasNonPseudoRR(msg.Extra) {
		// This wasn't a reply but we have answer/additional records anyway.
		// OPT (EDNS0) and TSIG records are excluded: clients legitimately put
		// them in the additional section of a query.
		confidence -= .2
	} else {
		confidence += .1
	}

	// Clip the confidence between [0 .. 0.99].
	confidence = math.Max(confidence, 0)
	confidence = math.Min(confidence, .99)

	// We don't have a lower threshold for confidence capping in this function, because
	// that is really up to the caller to determine if this wasn't some kind of attempt
	// to exfiltrate data using malicious queries, etc.
	return &protocol.Protocol{
		Name: protocolName,
	}, confidence
}

// hasCNAME reports whether any record in rrs is a CNAME record.
func hasCNAME(rrs []dns.RR) bool {
	for _, rr := range rrs {
		if rr.Header().Rrtype == dns.TypeCNAME {
			return true
		}
	}
	return false
}

// hasNonPseudoRR reports whether rrs contains a record other than an OPT
// (EDNS0) or TSIG record, the two kinds a client may attach to a query.
func hasNonPseudoRR(rrs []dns.RR) bool {
	for _, rr := range rrs {
		switch rr.Header().Rrtype {
		case dns.TypeOPT, dns.TypeTSIG:
			// Expected in queries; not evidence against DNS.
		default:
			return true
		}
	}
	return false
}