package nettrie

import "net/netip"

// Node is one node of the path-compressed trie.
//
// Each node stands for prefix and has at most two children. Both insert and
// delete select a child with getBit(addr, prefix.Bits()), presumably the
// address bit immediately following this node's prefix. Nodes created only
// to join two diverging branches carry isValue == false.
type Node struct {
	// children holds the subtrees for the two values of the branching bit;
	// either entry may be nil.
	children [2]*Node
	// prefix is the full prefix represented by the path to this node.
	prefix netip.Prefix
	// isValue marks if this node represents an explicitly inserted prefix.
	isValue bool
}

// Trie is a path-compressed radix trie that stores network prefixes.
//
// IPv4 prefixes and all other (IPv6, including IPv4-mapped IPv6) prefixes are
// kept in separate trees. The zero value is an empty, ready-to-use Trie.
// Trie performs no internal locking; callers must synchronize concurrent
// access that includes writes.
type Trie struct {
	rootV4 *Node
	rootV6 *Node
}

// New creates and initializes a new, empty Trie.
func New() *Trie {
	return &Trie{}
}

// rootFor returns a pointer to the root of the tree that holds p's address
// family: rootV4 for IPv4 addresses, rootV6 for everything else. Every public
// method routes through here so they always agree on which tree to use.
func (t *Trie) rootFor(p netip.Prefix) **Node {
	if p.Addr().Is4() {
		return &t.rootV4
	}
	return &t.rootV6
}

// Insert adds p to the trie. Host bits beyond p.Bits() are masked off first,
// so 10.1.2.3/8 is stored as 10.0.0.0/8. Inserting a prefix that is already
// present has no further effect. Invalid prefixes (such as the zero
// netip.Prefix) are ignored.
func (t *Trie) Insert(p netip.Prefix) {
	if !p.IsValid() {
		// An invalid prefix has Bits() == -1 and a zero Addr. Storing it
		// would plant a node with a negative length in the IPv6 tree, which
		// the bit-indexing logic below is not designed to handle.
		return
	}
	p = p.Masked()
	root := t.rootFor(p)
	*root = t.insert(*root, p)
}

// insert adds the masked prefix p to the subtree rooted at node and returns
// the (possibly new) root of that subtree. p must belong to the same address
// family as node.
func (t *Trie) insert(node *Node, p netip.Prefix) *Node {
	if node == nil {
		return &Node{prefix: p, isValue: true}
	}

	addr := p.Addr()
	pBits := p.Bits()
	nodeBits := node.prefix.Bits()

	// Leading bits shared by p and node's prefix, capped at the shorter of the
	// two lengths: bits past a prefix's length are not part of that prefix.
	commonLen := commonPrefixLen(addr, node.prefix.Addr())
	if commonLen > pBits {
		commonLen = pBits
	}
	if commonLen > nodeBits {
		commonLen = nodeBits
	}

	if commonLen == nodeBits && commonLen == pBits {
		// Exact match, mark the node as a value node.
		node.isValue = true
		return node
	}

	if commonLen < nodeBits {
		// p diverges from node's prefix, or is a strict prefix of it: put a
		// new node for the shared part above node. commonLen lies in
		// [0, nodeBits), so Addr.Prefix cannot return an error here.
		commonP, _ := node.prefix.Addr().Prefix(commonLen)
		split := &Node{prefix: commonP}
		split.children[getBit(node.prefix.Addr(), commonLen)] = node
		if commonLen == pBits {
			// p is itself the shared part, so the split node holds p.
			split.isValue = true
		} else {
			// The two prefixes branch at bit commonLen; p gets its own leaf.
			split.children[getBit(addr, commonLen)] = &Node{prefix: p, isValue: true}
		}
		return split
	}

	// commonLen == nodeBits < pBits: node's prefix covers p, so descend.
	bit := getBit(addr, nodeBits)
	node.children[bit] = t.insert(node.children[bit], p)
	return node
}

// Delete removes p from the trie after masking its host bits, and reports
// whether p was present. Nodes left without a value and with fewer than two
// children are pruned so the trie stays path-compressed. Invalid prefixes are
// never stored, so Delete returns false for them.
func (t *Trie) Delete(p netip.Prefix) bool {
	if !p.IsValid() {
		return false
	}
	p = p.Masked()
	root := t.rootFor(p)
	var removed bool
	*root, removed = t.delete(*root, p)
	return removed
}

// delete removes the masked prefix p from the subtree rooted at node. It
// returns the (possibly replaced) subtree root and whether p was found.
func (t *Trie) delete(node *Node, p netip.Prefix) (*Node, bool) {
	if node == nil {
		return nil, false
	}

	addr := p.Addr()
	pBits := p.Bits()
	nodeBits := node.prefix.Bits()

	// p does not extend node's prefix, so it is not on this path.
	if commonPrefixLen(addr, node.prefix.Addr()) < nodeBits {
		return node, false
	}

	switch {
	case pBits < nodeBits:
		// Had p been inserted, it would sit above node, not below it.
		return node, false
	case pBits == nodeBits:
		// This is p's node. Only a value node actually stores p.
		if !node.isValue {
			return node, false
		}
		node.isValue = false
	default: // pBits > nodeBits: p is deeper in the trie.
		bit := getBit(addr, nodeBits)
		var removed bool
		node.children[bit], removed = t.delete(node.children[bit], p)
		if !removed {
			return node, false
		}
	}

	// Post-deletion cleanup: a valueless node is only worth keeping as a
	// branch point with two children.
	if !node.isValue {
		switch {
		case node.children[0] == nil:
			return node.children[1], true // nil when the node was a bare leaf
		case node.children[1] == nil:
			return node.children[0], true
		}
	}
	return node, true
}

// ContainsPrefix reports whether exactly p (after masking its host bits) was
// inserted into the trie. It does not report whether p is merely covered by a
// shorter stored prefix. Invalid prefixes are never contained.
func (t *Trie) ContainsPrefix(p netip.Prefix) bool {
	if !p.IsValid() {
		return false
	}
	p = p.Masked()
	addr := p.Addr()
	pBits := p.Bits()

	node := *t.rootFor(p)
	for node != nil {
		nodeBits := node.prefix.Bits()
		// Either the path diverged from p, or p is shorter than this node's
		// prefix. In both cases p cannot be stored in this subtree.
		if commonPrefixLen(addr, node.prefix.Addr()) < nodeBits || pBits < nodeBits {
			return false
		}
		if pBits == nodeBits {
			// Same length and the node's bits all match: identical prefix.
			return node.isValue
		}
		// pBits > nodeBits, so we need to go deeper.
		node = node.children[getBit(addr, nodeBits)]
	}
	return false
}

// Contains reports whether addr itself was inserted as a full-length prefix
// (/32 for IPv4, /128 for IPv6). It is an exact lookup, not a longest-prefix
// match: an address covered only by a shorter prefix is not contained.
func (t *Trie) Contains(addr netip.Addr) bool {
	return t.ContainsPrefix(netip.PrefixFrom(addr, addr.BitLen()))
}

// WalkFunc is a function called for each prefix in the trie during a walk.
// Returning false from the function will stop the walk.
type WalkFunc func(p netip.Prefix) bool

// walk visits node's subtree (node first, then children[0], then children[1])
// and calls f for every value node. It returns false as soon as f does, which
// propagates the stop request up the recursion.
func walk(node *Node, f WalkFunc) bool {
	if node == nil {
		return true
	}
	if node.isValue && !f(node.prefix) {
		return false
	}
	return walk(node.children[0], f) && walk(node.children[1], f)
}

// Walk traverses the trie and calls the given function for each prefix.
// If the function returns false, the walk is stopped. The order is not guaranteed.
func (t *Trie) Walk(f WalkFunc) { if !walk(t.rootV4, f) { return } walk(t.rootV6, f) } // Merge inserts all prefixes from another Trie into this one. func (t *Trie) Merge(other *Trie) { other.Walk(func(p netip.Prefix) bool { t.Insert(p) return true // continue walking }) }