Better trie implementations

This commit is contained in:
2025-10-08 20:57:13 +02:00
parent 5f0f4aa96b
commit 582163d4be
26 changed files with 2482 additions and 122 deletions

View File

@@ -0,0 +1,186 @@
package dnstrie
import (
"fmt"
"github.com/miekg/dns"
)
// ValueNode represents a single node in the prefix trie, using generics for the value type.
type ValueNode[T any] struct {
// children maps the next label (e.g., "example" in "www.example.com")
// to the next node in the trie.
children map[string]*ValueNode[T]
// value is the data of generic type T associated with the domain name ending at this node.
value T
// isEndOfDomain marks if this node represents the end of a valid domain.
isEndOfDomain bool
}
// ValueTrie holds the root of the domain prefix trie, using generics.
type ValueTrie[T any] struct {
root *ValueNode[T]
}
// NewValue creates and initializes a new ValueTrie.
func NewValue[T any]() *ValueTrie[T] {
return &ValueTrie[T]{
root: &ValueNode[T]{
children: make(map[string]*ValueNode[T]),
},
}
}
// Insert adds a domain name and its associated generic value to the trie.
// It canonicalizes the domain name before insertion.
func (t *ValueTrie[T]) Insert(domain string, value T) error {
// Ensure the string is a valid domain name.
if !isValidDomainName(domain) {
return fmt.Errorf("'%s' is not a valid domain name", domain)
}
// Canonicalize the domain name (lowercase, ensure trailing dot).
canonicalDomain := dns.CanonicalName(domain)
// Split the domain into label offsets. This avoids allocating new strings for labels.
// For "www.example.com.", this returns []int{0, 4, 12}
offsets := dns.Split(canonicalDomain)
if len(offsets) == 0 {
return fmt.Errorf("could not split domain name: %s", domain)
}
currentNode := t.root
// Iterate through labels from TLD to the most specific label (right to left).
for i := len(offsets) - 1; i >= 0; i-- {
start := offsets[i]
var end int
if i == len(offsets)-1 {
// Last label, from its start to the end of the string, excluding the final dot.
end = len(canonicalDomain) - 1
} else {
// Intermediate label, from its start to the character before the next label's starting dot.
end = offsets[i+1] - 1
}
label := canonicalDomain[start:end]
if _, exists := currentNode.children[label]; !exists {
// If the label does not exist in the children map, create a new node.
currentNode.children[label] = &ValueNode[T]{children: make(map[string]*ValueNode[T])}
}
// Move to the next node.
currentNode = currentNode.children[label]
}
// Mark the final node as the end of a domain and store the value.
currentNode.isEndOfDomain = true
currentNode.value = value
return nil
}
// Contains checks if a domain name exists in the trie without returning its value.
func (t *ValueTrie[T]) Contains(domain string) bool {
if !isValidDomainName(domain) {
return false
}
canonicalDomain := dns.CanonicalName(domain)
offsets := dns.Split(canonicalDomain)
if len(offsets) == 0 {
return false
}
currentNode := t.root
for i := len(offsets) - 1; i >= 0; i-- {
start := offsets[i]
var end int
if i == len(offsets)-1 {
end = len(canonicalDomain) - 1
} else {
end = offsets[i+1] - 1
}
label := canonicalDomain[start:end]
nextNode, exists := currentNode.children[label]
if !exists {
return false
}
currentNode = nextNode
}
return currentNode.isEndOfDomain
}
// Search looks for a domain name in the trie.
// It returns the associated generic value and a boolean indicating if the domain was found.
func (t *ValueTrie[T]) Search(domain string) (T, bool) {
var zero T // The zero value for the generic type T.
if !isValidDomainName(domain) {
return zero, false
}
canonicalDomain := dns.CanonicalName(domain)
offsets := dns.Split(canonicalDomain)
if len(offsets) == 0 {
return zero, false
}
currentNode := t.root
for i := len(offsets) - 1; i >= 0; i-- {
start := offsets[i]
var end int
if i == len(offsets)-1 {
end = len(canonicalDomain) - 1
} else {
end = offsets[i+1] - 1
}
label := canonicalDomain[start:end]
nextNode, exists := currentNode.children[label]
if !exists {
// A label in the path was not found, so the domain doesn't exist.
return zero, false
}
currentNode = nextNode
}
// The full path was found, but we must also check if it's a terminal node.
// This prevents matching "example.com" if only "www.example.com" was inserted.
if currentNode.isEndOfDomain {
return currentNode.value, true
}
return zero, false
}
// Merge combines another trie into the current one.
//
// If a domain exists in both tries, the value from the 'other' trie is used.
func (t *ValueTrie[T]) Merge(other *ValueTrie[T]) {
if other == nil || other.root == nil {
return
}
mergeValueNodes(t.root, other.root)
}
// mergeNodes is a recursive helper function to merge nodes from another trie.
func mergeValueNodes[T any](localNode, otherNode *ValueNode[T]) {
// The other node value overwrites the local one.
if otherNode.isEndOfDomain {
localNode.value = otherNode.value
}
// Iterate over the children of the other node and merge them.
for label, otherChildNode := range otherNode.children {
localChildNode, exists := localNode.children[label]
if !exists {
// If the child doesn't exist in the local trie, attach the other's child.
localNode.children[label] = otherChildNode
} else {
// If the child already exists, recurse to merge them.
mergeValueNodes(localChildNode, otherChildNode)
}
}
}