package nettrie

import (
	"math/bits"
	"net/netip"
)

// getBit returns bit n of addr (0 or 1), where bit 0 is the most significant
// bit of the first byte of the address. Callers must pass n in
// [0, addr.BitLen()).
func getBit(addr netip.Addr, n int) byte {
	b := addr.AsSlice()[n/8]
	return (b >> (7 - n%8)) & 1
}

// commonPrefixLen returns the number of leading bits that a and b share.
// Addresses of different lengths (IPv4 vs IPv6, or either being the invalid
// zero Addr) share no bits and yield 0.
func commonPrefixLen(a, b netip.Addr) int {
	// Compare BitLen rather than Is4: the zero Addr and an IPv6 address both
	// report !Is4, but the zero Addr's AsSlice is nil, which would make the
	// byte loop below index out of range.
	if a.BitLen() != b.BitLen() {
		return 0
	}
	as, bs := a.AsSlice(), b.AsSlice()
	for i := range as {
		if x := as[i] ^ bs[i]; x != 0 {
			return i*8 + bits.LeadingZeros8(x)
		}
	}
	return len(as) * 8
}

// ValueNode is a node of a ValueTrie. The prefix of a child always extends its
// parent's prefix; children[0] and children[1] are chosen by the first bit
// that follows the parent's prefix.
type ValueNode[T any] struct {
	children [2]*ValueNode[T]
	// prefix is the masked prefix represented by the path to this node.
	prefix netip.Prefix
	// value is meaningful only when isValue is true.
	value T
	// isValue is true for prefixes that were inserted; branch nodes created
	// when splitting a compressed path have isValue == false.
	isValue bool
}

// ValueTrie is a path-compressed binary radix trie that maps network prefixes
// to values of type T. IPv4 and IPv6 prefixes live in separate trees. The zero
// value is an empty trie ready for use. ValueTrie performs no internal locking.
type ValueTrie[T any] struct {
	rootV4 *ValueNode[T]
	rootV6 *ValueNode[T]
}

// NewValue returns a new, empty ValueTrie.
func NewValue[T any]() *ValueTrie[T] {
	return &ValueTrie[T]{}
}

// Insert stores value under prefix p, replacing any value already stored for p.
// Host bits are masked off first, so 10.1.2.3/8 is stored as 10.0.0.0/8.
// IPv4-mapped IPv6 prefixes are stored with the IPv6 prefixes. Invalid prefixes
// (for example the zero netip.Prefix) are ignored.
func (t *ValueTrie[T]) Insert(p netip.Prefix, value T) {
	// An invalid prefix has Bits() == -1 and a zero Addr; storing it would
	// corrupt the IPv6 tree that it would otherwise be routed to.
	if !p.IsValid() {
		return
	}
	p = p.Masked()
	if p.Addr().Is4() {
		t.rootV4 = t.insert(t.rootV4, p, value)
	} else {
		t.rootV6 = t.insert(t.rootV6, p, value)
	}
}

// insert adds the masked prefix p below node and returns the new root of that
// subtree, which is a newly allocated node when node is nil or must be split.
func (t *ValueTrie[T]) insert(node *ValueNode[T], p netip.Prefix, value T) *ValueNode[T] { if node == nil { return &ValueNode[T]{prefix: p, value: value, isValue: true} } addr := p.Addr() commonLen := commonPrefixLen(addr, node.prefix.Addr()) pBits := p.Bits() nodeBits := node.prefix.Bits() if commonLen > pBits { commonLen = pBits } if commonLen > nodeBits { commonLen = nodeBits } if commonLen == nodeBits && commonLen == pBits { // Exact match, update the value. node.value = value node.isValue = true return node } if commonLen < nodeBits { // The new prefix diverges from the current node's prefix. // We must split the current node. commonP, _ := node.prefix.Addr().Prefix(commonLen) splitNode := &ValueNode[T]{prefix: commonP} // The existing node becomes a child of the new split node. bit := getBit(node.prefix.Addr(), commonLen) splitNode.children[bit] = node if commonLen == pBits { // The inserted prefix is a prefix of the node's original prefix. // The new split node represents the inserted prefix and gets the value. splitNode.value = value splitNode.isValue = true } else { // The two prefixes diverge. Create a new child for the new prefix. bit = getBit(addr, commonLen) splitNode.children[bit] = &ValueNode[T]{prefix: p, value: value, isValue: true} } return splitNode } // commonLen == nodeBits, meaning the current node's prefix is a prefix of the new one. // We need to descend to a child. bit := getBit(addr, commonLen) node.children[bit] = t.insert(node.children[bit], p, value) return node } // Lookup finds the value associated with the most specific prefix that contains the given IP address. func (t *ValueTrie[T]) Lookup(addr netip.Addr) (value T, ok bool) { node := t.rootV4 if addr.Is6() { node = t.rootV6 } var lastFoundValue T var found bool for node != nil { commonLen := commonPrefixLen(addr, node.prefix.Addr()) nodeBits := node.prefix.Bits() // If the address doesn't share a prefix with the node, we can't be in this subtree. 
if commonLen < nodeBits { break } // The address is within this node's prefix. If the node holds a value, // it's our current best match. if node.isValue { lastFoundValue = node.value found = true } // We've matched the whole address, can't go deeper. if commonLen == addr.BitLen() { break } // Descend to the next child based on the next bit after the node's prefix. bit := getBit(addr, nodeBits) node = node.children[bit] } return lastFoundValue, found } // Delete removes a prefix from the trie. It returns true if the prefix was found and removed. func (t *ValueTrie[T]) Delete(p netip.Prefix) bool { p = p.Masked() addr := p.Addr() var changed bool if addr.Is4() { t.rootV4, changed = t.delete(t.rootV4, p) } else { t.rootV6, changed = t.delete(t.rootV6, p) } return changed } // delete is the recursive helper for removing a prefix from the trie. func (t *ValueTrie[T]) delete(node *ValueNode[T], p netip.Prefix) (*ValueNode[T], bool) { if node == nil { return nil, false } addr := p.Addr() pBits := p.Bits() nodeBits := node.prefix.Bits() commonLen := commonPrefixLen(addr, node.prefix.Addr()) // The prefix is not on this path. if commonLen < nodeBits || commonLen < pBits && pBits < nodeBits { return node, false } var changed bool if pBits > nodeBits { // The prefix to delete is deeper in the trie. Recurse. bit := getBit(addr, nodeBits) node.children[bit], changed = t.delete(node.children[bit], p) } else if pBits == nodeBits { // This is the node to delete. Unset its value. if !node.isValue { return node, false // Prefix wasn't actually in the trie. } node.isValue = false var zero T node.value = zero changed = true } else { // pBits < nodeBits return node, false // Prefix to delete is shorter, so can't be here. } if !changed { return node, false } // Post-deletion cleanup: // If the node has no value and can be merged with a single child, do so. 
if !node.isValue { if node.children[0] != nil && node.children[1] == nil { return node.children[0], true } if node.children[0] == nil && node.children[1] != nil { return node.children[1], true } } // If the node is now a leaf without a value, it can be removed entirely. if !node.isValue && node.children[0] == nil && node.children[1] == nil { return nil, true } return node, true } // Contains checks if the exact IP address exists in the trie as a full-length prefix. func (t *ValueTrie[T]) Contains(addr netip.Addr) bool { prefix := netip.PrefixFrom(addr, addr.BitLen()) return t.ContainsPrefix(prefix) } // ContainsPrefix checks if the exact prefix exists in the trie. func (t *ValueTrie[T]) ContainsPrefix(p netip.Prefix) bool { p = p.Masked() addr := p.Addr() pBits := p.Bits() node := t.rootV4 if addr.Is6() { node = t.rootV6 } for node != nil { commonLen := commonPrefixLen(addr, node.prefix.Addr()) nodeBits := node.prefix.Bits() if commonLen < nodeBits { // Path has diverged. The prefix cannot be in this subtree. return false } if pBits < nodeBits { // The search prefix is shorter than the node's prefix, // but they share a prefix. e.g. search /16, node is /24. // The /16 is not explicitly in the trie. return false } if pBits == nodeBits { // Found a node with the exact same prefix length. // Because we also know commonLen >= nodeBits, the prefixes are identical. return node.isValue } // pBits > nodeBits, so we need to go deeper. bit := getBit(addr, nodeBits) node = node.children[bit] } return false } // WalkValueFunc is a function called for each prefix in the trie during a walk. // Returning false from the function will stop the walk. type WalkValueFunc[T any] func(p netip.Prefix, v T) bool // walk is the recursive helper for traversing the trie. 
func walkValue[T any](node *ValueNode[T], f WalkValueFunc[T]) bool { if node == nil { return true } if node.isValue { if !f(node.prefix, node.value) { return false } } if node.children[0] != nil { if !walkValue(node.children[0], f) { return false } } if node.children[1] != nil { if !walkValue(node.children[1], f) { return false } } return true } // Walk traverses the trie and calls the given function for each prefix and its value. // If the function returns false, the walk is stopped. The order is not guaranteed. func (t *ValueTrie[T]) Walk(f WalkValueFunc[T]) { if !walkValue(t.rootV4, f) { return } walkValue(t.rootV6, f) } // Merge inserts all prefixes from another Trie into this one. func (t *ValueTrie[T]) Merge(other *ValueTrie[T]) { other.Walk(func(p netip.Prefix, v T) bool { t.Insert(p, v) return true // continue walking }) }