package dnstrie

import (
	"fmt"

	"github.com/miekg/dns"
)

// ValueNode represents a single node in the prefix trie, using generics for the value type.
type ValueNode[T any] struct {
	// children maps the next label (e.g., "example" in "www.example.com")
	// to the next node in the trie.
	children map[string]*ValueNode[T]
	// value is the data of generic type T associated with the domain name ending at this node.
	value T
	// isEndOfDomain marks if this node represents the end of a valid domain.
	isEndOfDomain bool
}

// ValueTrie holds the root of the domain prefix trie, using generics.
//
// Tries should be created with NewValue. A zero-value ValueTrie is also
// tolerated: Insert and Merge lazily create the root, and lookups on an empty
// trie report "not found" instead of panicking.
type ValueTrie[T any] struct {
	root *ValueNode[T]
}

// NewValue creates and initializes a new ValueTrie.
func NewValue[T any]() *ValueTrie[T] {
	return &ValueTrie[T]{
		root: &ValueNode[T]{
			children: make(map[string]*ValueNode[T]),
		},
	}
}

// valueLabelAt returns the i-th label of canonical, without its trailing dot.
// canonical must be a fully qualified name (as produced by dns.CanonicalName),
// and offsets must be its label start positions as returned by dns.Split.
//
// For "www.example.com." with offsets {0, 4, 12}: index 0 yields "www",
// index 1 yields "example" and index 2 yields "com". The result is a substring
// of canonical, so no new string is allocated.
func valueLabelAt(canonical string, offsets []int, i int) string {
	// The last label runs up to (but excluding) the final dot.
	end := len(canonical) - 1
	if i+1 < len(offsets) {
		// An intermediate label ends just before the dot that precedes the next label.
		end = offsets[i+1] - 1
	}
	return canonical[offsets[i]:end]
}

// Insert adds a domain name and its associated generic value to the trie.
// It canonicalizes the domain name before insertion. Inserting a domain that is
// already present overwrites its value.
//
// It returns an error if domain is not a valid domain name or cannot be split
// into labels.
func (t *ValueTrie[T]) Insert(domain string, value T) error {
	// Ensure the string is a valid domain name.
	if !isValidDomainName(domain) {
		return fmt.Errorf("'%s' is not a valid domain name", domain)
	}

	// Canonicalize the domain name (lowercase, ensure trailing dot).
	canonicalDomain := dns.CanonicalName(domain)

	// Split the domain into label offsets. This avoids allocating new strings for labels.
	// For "www.example.com.", this returns []int{0, 4, 12}
	offsets := dns.Split(canonicalDomain)
	if len(offsets) == 0 {
		return fmt.Errorf("could not split domain name: %s", domain)
	}

	// Support zero-value tries that were not created with NewValue.
	if t.root == nil {
		t.root = &ValueNode[T]{children: make(map[string]*ValueNode[T])}
	}

	node := t.root
	// Walk labels from the TLD to the most specific label (right to left),
	// creating missing nodes along the way.
	for i := len(offsets) - 1; i >= 0; i-- {
		label := valueLabelAt(canonicalDomain, offsets, i)
		child, ok := node.children[label]
		if !ok {
			child = &ValueNode[T]{children: make(map[string]*ValueNode[T])}
			node.children[label] = child
		}
		node = child
	}

	// Mark the final node as the end of a domain and store the value.
	node.isEndOfDomain = true
	node.value = value
	return nil
}

// lookup follows the labels of domain (right to left) from the root and
// returns the node reached, or nil if domain is invalid, cannot be split, or
// some label along its path is missing.
//
// The returned node is not necessarily terminal. Callers must check
// isEndOfDomain so that "example.com" does not match when only
// "www.example.com" was inserted.
func (t *ValueTrie[T]) lookup(domain string) *ValueNode[T] {
	if t == nil || t.root == nil || !isValidDomainName(domain) {
		return nil
	}

	canonicalDomain := dns.CanonicalName(domain)
	offsets := dns.Split(canonicalDomain)
	if len(offsets) == 0 {
		return nil
	}

	node := t.root
	for i := len(offsets) - 1; i >= 0; i-- {
		next, ok := node.children[valueLabelAt(canonicalDomain, offsets, i)]
		if !ok {
			// A label in the path was not found, so the domain doesn't exist.
			return nil
		}
		node = next
	}
	return node
}

// Contains checks if a domain name exists in the trie without returning its value.
// Invalid domain names are reported as absent.
func (t *ValueTrie[T]) Contains(domain string) bool {
	node := t.lookup(domain)
	return node != nil && node.isEndOfDomain
}

// Search looks for a domain name in the trie.
// It returns the associated generic value and a boolean indicating if the domain was found.
// When the domain is absent or invalid, the zero value of T and false are returned.
func (t *ValueTrie[T]) Search(domain string) (T, bool) {
	// Only terminal nodes carry a value; an intermediate node on the path of a
	// longer inserted name does not count as a match.
	if node := t.lookup(domain); node != nil && node.isEndOfDomain {
		return node.value, true
	}
	var zero T
	return zero, false
}

// Merge combines another trie into the current one.
//
// If a domain exists in both tries, the value from the 'other' trie is used.
// Every domain present in other becomes present in t. Subtrees copied from
// other are deep-copied, so after Merge the two tries share no nodes and later
// modifications to either do not affect the other. Values of type T are copied
// by assignment, so reference-typed values (pointers, maps, slices) are still
// shared.
func (t *ValueTrie[T]) Merge(other *ValueTrie[T]) {
	if other == nil || other.root == nil {
		return
	}
	// Support zero-value tries that were not created with NewValue.
	if t.root == nil {
		t.root = &ValueNode[T]{children: make(map[string]*ValueNode[T])}
	}
	mergeValueNodes(t.root, other.root)
}

// mergeValueNodes is a recursive helper that merges otherNode (and its
// subtree) into localNode. otherNode is never modified.
func mergeValueNodes[T any](localNode, otherNode *ValueNode[T]) {
	// A domain that ends here in other must also end here locally; its value
	// overwrites any local one. Setting isEndOfDomain is required so that a
	// merged-in domain is visible to Search/Contains even when the local trie
	// only had this node as an intermediate path element.
	if otherNode.isEndOfDomain {
		localNode.isEndOfDomain = true
		localNode.value = otherNode.value
	}

	for label, otherChild := range otherNode.children {
		if localChild, ok := localNode.children[label]; ok {
			// The child already exists locally: merge recursively.
			mergeValueNodes(localChild, otherChild)
			continue
		}
		// Attach a deep copy rather than otherChild itself, so the tries do
		// not end up sharing (and mutating) each other's nodes.
		localNode.children[label] = otherChild.clone()
	}
}

// clone returns a deep copy of the subtree rooted at n. Node structure is
// duplicated. Values of type T are copied by assignment.
func (n *ValueNode[T]) clone() *ValueNode[T] {
	c := &ValueNode[T]{
		children:      make(map[string]*ValueNode[T], len(n.children)),
		value:         n.value,
		isEndOfDomain: n.isEndOfDomain,
	}
	for label, child := range n.children {
		c.children[label] = child.clone()
	}
	return c
}