Files
styx/dataset/dnstrie/trie.go
2025-10-08 20:57:13 +02:00

140 lines
3.7 KiB
Go

package dnstrie
import (
"fmt"
"github.com/miekg/dns"
)
// Node represents a single node in the domain trie. Each node stands for a
// domain suffix: children of the root are TLDs, and each further level adds
// one label to the left (Insert walks labels from right to left).
type Node struct {
// children maps the next label to the left onto its child node. For example,
// under the node for "com." the key "example" leads to the node for
// "example.com.". Keys are single labels without dots, taken from the
// canonicalized (lowercased, fully-qualified) name built by Insert.
children map[string]*Node
// isEndOfDomain is true when the suffix ending at this node was itself
// inserted (or merged in) as a domain, rather than only being an ancestor
// of some longer inserted domain.
isEndOfDomain bool
}
// Trie stores a set of domain names as a label trie that is keyed from the
// TLD toward the leftmost label.
//
// Create tries with New. Insert and Contains dereference root without a nil
// check, so a zero-value Trie is not ready for use.
type Trie struct {
// root is the node for the DNS root. Insert never marks it as a domain end,
// because it always descends at least one label before marking.
root *Node
}
// New returns an empty Trie. Its root node already has an allocated
// children map, so Insert, Contains and Merge can be used immediately.
func New() *Trie {
	root := &Node{children: map[string]*Node{}}
	return &Trie{root: root}
}
// Insert adds domain to the trie.
//
// The name is first checked with isValidDomainName. It is then canonicalized
// with dns.CanonicalName (lowercased, trailing dot added), so differently
// cased or non-fully-qualified spellings of a name map to the same entry.
// Labels are inserted from the TLD toward the leftmost label, creating any
// missing nodes along the way, and the last node is marked as a domain end.
//
// It returns an error if domain fails validation, or if dns.Split yields no
// labels to insert (for example the root name ".", if isValidDomainName
// accepts it).
func (t *Trie) Insert(domain string) error {
	if !isValidDomainName(domain) {
		return fmt.Errorf("'%s' is not a valid domain name", domain)
	}
	canonicalDomain := dns.CanonicalName(domain)
	// dns.Split returns label start offsets instead of allocating label
	// strings; for "www.example.com." it returns []int{0, 4, 12}.
	offsets := dns.Split(canonicalDomain)
	if len(offsets) == 0 {
		return fmt.Errorf("could not split domain name: %s", domain)
	}
	node := t.root
	// end is the exclusive end of the label being visited. The rightmost
	// label stops before the trailing root dot. Every other label stops
	// before the dot that precedes the label to its right.
	end := len(canonicalDomain) - 1
	for i := len(offsets) - 1; i >= 0; i-- {
		start := offsets[i]
		label := canonicalDomain[start:end]
		// One comma-ok lookup either reuses the existing child or creates it.
		child, ok := node.children[label]
		if !ok {
			child = &Node{children: make(map[string]*Node)}
			node.children[label] = child
		}
		node = child
		end = start - 1
	}
	node.isEndOfDomain = true
	return nil
}
// Contains reports whether domain was stored in the trie as a complete
// domain. Only exact matches count: an ancestor suffix, or a longer name
// beneath a stored domain, does not. The lookup uses the same validation
// and canonicalization as Insert, and invalid names are never contained.
func (t *Trie) Contains(domain string) bool {
	if !isValidDomainName(domain) {
		return false
	}
	name := dns.CanonicalName(domain)
	idx := dns.Split(name)
	if len(idx) == 0 {
		return false
	}
	// Walk labels right to left. stop is the exclusive end of the current
	// label: first the trailing root dot, then the dot just before the
	// previously visited label.
	node := t.root
	stop := len(name) - 1
	for k := len(idx) - 1; k >= 0; k-- {
		child, ok := node.children[name[idx[k]:stop]]
		if !ok {
			return false
		}
		node = child
		stop = idx[k] - 1
	}
	return node.isEndOfDomain
}
// Merge adds every domain stored in other to t.
//
// A nil other, or one without a root node, leaves t unchanged. The node-level
// work is done by mergeNodes.
func (t *Trie) Merge(other *Trie) {
	if other != nil && other.root != nil {
		mergeNodes(t.root, other.root)
	}
}
// mergeNodes recursively merges the subtree rooted at otherNode into
// localNode.
//
// An end-of-domain mark on otherNode is copied to localNode, and a mark
// already set on localNode is never cleared. Children present in both tries
// are merged recursively. Children present only in the other trie are
// deep-copied rather than attached by pointer. Attaching them would make the
// two tries share mutable nodes, so a later Insert into either trie would
// silently change the other.
func mergeNodes(localNode, otherNode *Node) {
	if otherNode.isEndOfDomain {
		localNode.isEndOfDomain = true
	}
	for label, otherChild := range otherNode.children {
		if localChild, ok := localNode.children[label]; ok {
			mergeNodes(localChild, otherChild)
			continue
		}
		localNode.children[label] = cloneNode(otherChild)
	}
}

// cloneNode returns a deep copy of n and all of its descendants.
func cloneNode(n *Node) *Node {
	c := &Node{
		children:      make(map[string]*Node, len(n.children)),
		isEndOfDomain: n.isEndOfDomain,
	}
	for label, child := range n.children {
		c.children[label] = cloneNode(child)
	}
	return c
}