187 lines
5.2 KiB
Go
187 lines
5.2 KiB
Go
package dnstrie
|
|
|
|
import (
|
|
"fmt"
|
|
|
|
"github.com/miekg/dns"
|
|
)
|
|
|
|
// ValueNode is one node of the label trie. Every node stands for a single
// DNS label; T is the type of the payload that can be stored on it.
type ValueNode[T any] struct {
	// children holds the nodes reachable from this one, keyed by the text of
	// the next label (for "www.example.com", "example" is one such label).
	children map[string]*ValueNode[T]

	// value is the payload stored for the domain name whose path ends here.
	value T

	// isEndOfDomain is true when a domain name terminates at this node.
	isEndOfDomain bool
}
|
|
|
|
// ValueTrie holds the root of the domain prefix trie, using generics.
|
|
type ValueTrie[T any] struct {
|
|
root *ValueNode[T]
|
|
}
|
|
|
|
// NewValue creates and initializes a new ValueTrie.
|
|
func NewValue[T any]() *ValueTrie[T] {
|
|
return &ValueTrie[T]{
|
|
root: &ValueNode[T]{
|
|
children: make(map[string]*ValueNode[T]),
|
|
},
|
|
}
|
|
}
|
|
|
|
// Insert adds a domain name and its associated generic value to the trie.
|
|
// It canonicalizes the domain name before insertion.
|
|
func (t *ValueTrie[T]) Insert(domain string, value T) error {
|
|
// Ensure the string is a valid domain name.
|
|
if !isValidDomainName(domain) {
|
|
return fmt.Errorf("'%s' is not a valid domain name", domain)
|
|
}
|
|
|
|
// Canonicalize the domain name (lowercase, ensure trailing dot).
|
|
canonicalDomain := dns.CanonicalName(domain)
|
|
|
|
// Split the domain into label offsets. This avoids allocating new strings for labels.
|
|
// For "www.example.com.", this returns []int{0, 4, 12}
|
|
offsets := dns.Split(canonicalDomain)
|
|
if len(offsets) == 0 {
|
|
return fmt.Errorf("could not split domain name: %s", domain)
|
|
}
|
|
|
|
currentNode := t.root
|
|
// Iterate through labels from TLD to the most specific label (right to left).
|
|
for i := len(offsets) - 1; i >= 0; i-- {
|
|
start := offsets[i]
|
|
var end int
|
|
if i == len(offsets)-1 {
|
|
// Last label, from its start to the end of the string, excluding the final dot.
|
|
end = len(canonicalDomain) - 1
|
|
} else {
|
|
// Intermediate label, from its start to the character before the next label's starting dot.
|
|
end = offsets[i+1] - 1
|
|
}
|
|
label := canonicalDomain[start:end]
|
|
|
|
if _, exists := currentNode.children[label]; !exists {
|
|
// If the label does not exist in the children map, create a new node.
|
|
currentNode.children[label] = &ValueNode[T]{children: make(map[string]*ValueNode[T])}
|
|
}
|
|
// Move to the next node.
|
|
currentNode = currentNode.children[label]
|
|
}
|
|
|
|
// Mark the final node as the end of a domain and store the value.
|
|
currentNode.isEndOfDomain = true
|
|
currentNode.value = value
|
|
|
|
return nil
|
|
}
|
|
|
|
// Contains checks if a domain name exists in the trie without returning its value.
|
|
func (t *ValueTrie[T]) Contains(domain string) bool {
|
|
if !isValidDomainName(domain) {
|
|
return false
|
|
}
|
|
|
|
canonicalDomain := dns.CanonicalName(domain)
|
|
offsets := dns.Split(canonicalDomain)
|
|
if len(offsets) == 0 {
|
|
return false
|
|
}
|
|
|
|
currentNode := t.root
|
|
for i := len(offsets) - 1; i >= 0; i-- {
|
|
start := offsets[i]
|
|
var end int
|
|
if i == len(offsets)-1 {
|
|
end = len(canonicalDomain) - 1
|
|
} else {
|
|
end = offsets[i+1] - 1
|
|
}
|
|
label := canonicalDomain[start:end]
|
|
|
|
nextNode, exists := currentNode.children[label]
|
|
if !exists {
|
|
return false
|
|
}
|
|
currentNode = nextNode
|
|
}
|
|
|
|
return currentNode.isEndOfDomain
|
|
}
|
|
|
|
// Search looks for a domain name in the trie.
|
|
// It returns the associated generic value and a boolean indicating if the domain was found.
|
|
func (t *ValueTrie[T]) Search(domain string) (T, bool) {
|
|
var zero T // The zero value for the generic type T.
|
|
if !isValidDomainName(domain) {
|
|
return zero, false
|
|
}
|
|
|
|
canonicalDomain := dns.CanonicalName(domain)
|
|
offsets := dns.Split(canonicalDomain)
|
|
if len(offsets) == 0 {
|
|
return zero, false
|
|
}
|
|
|
|
currentNode := t.root
|
|
for i := len(offsets) - 1; i >= 0; i-- {
|
|
start := offsets[i]
|
|
var end int
|
|
if i == len(offsets)-1 {
|
|
end = len(canonicalDomain) - 1
|
|
} else {
|
|
end = offsets[i+1] - 1
|
|
}
|
|
label := canonicalDomain[start:end]
|
|
|
|
nextNode, exists := currentNode.children[label]
|
|
if !exists {
|
|
// A label in the path was not found, so the domain doesn't exist.
|
|
return zero, false
|
|
}
|
|
currentNode = nextNode
|
|
}
|
|
|
|
// The full path was found, but we must also check if it's a terminal node.
|
|
// This prevents matching "example.com" if only "www.example.com" was inserted.
|
|
if currentNode.isEndOfDomain {
|
|
return currentNode.value, true
|
|
}
|
|
|
|
return zero, false
|
|
}
|
|
|
|
// Merge combines another trie into the current one.
|
|
//
|
|
// If a domain exists in both tries, the value from the 'other' trie is used.
|
|
func (t *ValueTrie[T]) Merge(other *ValueTrie[T]) {
|
|
if other == nil || other.root == nil {
|
|
return
|
|
}
|
|
mergeValueNodes(t.root, other.root)
|
|
}
|
|
|
|
// mergeNodes is a recursive helper function to merge nodes from another trie.
|
|
func mergeValueNodes[T any](localNode, otherNode *ValueNode[T]) {
|
|
// The other node value overwrites the local one.
|
|
if otherNode.isEndOfDomain {
|
|
localNode.value = otherNode.value
|
|
}
|
|
|
|
// Iterate over the children of the other node and merge them.
|
|
for label, otherChildNode := range otherNode.children {
|
|
localChildNode, exists := localNode.children[label]
|
|
if !exists {
|
|
// If the child doesn't exist in the local trie, attach the other's child.
|
|
localNode.children[label] = otherChildNode
|
|
} else {
|
|
// If the child already exists, recurse to merge them.
|
|
mergeValueNodes(localChildNode, otherChildNode)
|
|
}
|
|
}
|
|
}
|