package dnstrie import ( "fmt" "github.com/miekg/dns" ) // Node represents a single node in the prefix trie. type Node struct { // children maps the next label (e.g., "example" in "www.example.com") // to the next node in the trie. children map[string]*Node // isEndOfDomain marks if this node represents the end of a valid domain. isEndOfDomain bool } // Trie holds the root of the domain prefix trie. type Trie struct { root *Node } // New creates and initializes a new Trie. func New() *Trie { return &Trie{ root: &Node{ children: make(map[string]*Node), }, } } // Insert adds a domain name to the trie. // It canonicalizes the domain name before insertion. func (t *Trie) Insert(domain string) error { // Ensure the string is a valid domain name. if !isValidDomainName(domain) { return fmt.Errorf("'%s' is not a valid domain name", domain) } // Canonicalize the domain name (lowercase, ensure trailing dot). canonicalDomain := dns.CanonicalName(domain) // Split the domain into label offsets. This avoids allocating new strings for labels. // For "www.example.com.", this returns []int{0, 4, 12} offsets := dns.Split(canonicalDomain) if len(offsets) == 0 { return fmt.Errorf("could not split domain name: %s", domain) } currentNode := t.root // Iterate through labels from TLD to the most specific label (right to left). for i := len(offsets) - 1; i >= 0; i-- { start := offsets[i] var end int if i == len(offsets)-1 { // Last label, from its start to the end of the string, excluding the final dot. end = len(canonicalDomain) - 1 } else { // Intermediate label, from its start to the character before the next label's starting dot. end = offsets[i+1] - 1 } label := canonicalDomain[start:end] if _, exists := currentNode.children[label]; !exists { // If the label does not exist in the children map, create a new node. currentNode.children[label] = &Node{children: make(map[string]*Node)} } // Move to the next node. currentNode = currentNode.children[label] } // Mark the final node as the end of a domain. currentNode.isEndOfDomain = true return nil } // Contains checks if a domain name exists in the trie. func (t *Trie) Contains(domain string) bool { if !isValidDomainName(domain) { return false } canonicalDomain := dns.CanonicalName(domain) offsets := dns.Split(canonicalDomain) if len(offsets) == 0 { return false } currentNode := t.root for i := len(offsets) - 1; i >= 0; i-- { start := offsets[i] var end int if i == len(offsets)-1 { end = len(canonicalDomain) - 1 } else { end = offsets[i+1] - 1 } label := canonicalDomain[start:end] nextNode, exists := currentNode.children[label] if !exists { return false } currentNode = nextNode } return currentNode.isEndOfDomain } // Merge combines another trie into the current one. // // Domains from the 'other' trie will be added to the current trie. func (t *Trie) Merge(other *Trie) { if other == nil || other.root == nil { return } mergeNodes(t.root, other.root) } // mergeNodes is a recursive helper function to merge nodes from another trie. func mergeNodes(localNode, otherNode *Node) { // If the other node marks the end of a domain, the local node should too. if otherNode.isEndOfDomain { localNode.isEndOfDomain = true } // Iterate over the children of the other node and merge them. for label, otherChildNode := range otherNode.children { localChildNode, exists := localNode.children[label] if !exists { // If the child doesn't exist in the local trie, just attach the other's child. localNode.children[label] = otherChildNode } else { // If the child already exists, recurse to merge them. mergeNodes(localChildNode, otherChildNode) } } }