329 lines
8.3 KiB
Go
329 lines
8.3 KiB
Go
package nettrie
|
|
|
|
import (
|
|
"math/bits"
|
|
"net/netip"
|
|
)
|
|
|
|
// getBit returns bit n of addr, counting from the most significant bit
// (bit 0). n is interpreted within the address's own width: 0-31 for an IPv4
// address, 0-127 for an IPv6 address, so n must be in [0, addr.BitLen()).
func getBit(addr netip.Addr, n int) byte {
	// As16 returns a fixed-size array by value, so unlike AsSlice it never
	// needs a fresh slice per call. An IPv4 address occupies the last four
	// bytes of its 16-byte (::ffff:a.b.c.d) form, hence the 96-bit offset.
	b := addr.As16()
	if addr.Is4() {
		n += 96
	}
	return (b[n/8] >> (7 - n%8)) & 1
}
|
|
|
|
// commonPrefixLen returns the number of leading bits that a and b have in
// common. Addresses of different widths share no prefix and yield 0. This
// covers IPv4 vs IPv6, and the zero Addr (width 0) vs anything, in either
// argument order. IPv4-mapped IPv6 addresses are compared as 128-bit IPv6
// addresses.
func commonPrefixLen(a, b netip.Addr) int {
	width := a.BitLen()
	if width == 0 || width != b.BitLen() {
		return 0
	}

	// Compare the 16-byte forms, which avoid per-call slices. For IPv4 only the
	// last four bytes carry the address, so start past the mapping prefix.
	aa, bb := a.As16(), b.As16()
	common := 0
	for i := 16 - width/8; i < 16; i++ {
		if x := aa[i] ^ bb[i]; x != 0 {
			return common + bits.LeadingZeros8(x)
		}
		common += 8
	}
	return common
}
|
|
|
|
// ValueNode represents a node in the path-compressed trie.
// Each node represents a prefix and can have up to two children.
//
// A node either holds an explicitly inserted value (isValue == true) or is a
// pure branch point created where two stored prefixes diverge; branch points
// carry the zero T in value.
type ValueNode[T any] struct {
	// children[b] is the subtree whose prefixes continue with bit b at the
	// position immediately after this node's prefix length.
	children [2]*ValueNode[T]

	// prefix is the full prefix represented by the path to this node.
	prefix netip.Prefix

	// value is the payload stored for prefix; only meaningful when isValue is set.
	value T

	// isValue reports whether prefix was inserted by a caller, as opposed to
	// being an intermediate split node.
	isValue bool
}
|
|
|
|
// ValueTrie is a path-compressed radix trie that stores network prefixes and their values.
//
// IPv4 prefixes and IPv6 prefixes live in separate trees; IPv4-mapped IPv6
// prefixes (::ffff:a.b.c.d/n) are filed under the IPv6 tree. A nil root means
// the corresponding tree is empty, so the zero value is an empty, usable trie.
// The type contains no internal synchronization.
type ValueTrie[T any] struct {
	// rootV4 is the root of the tree holding IPv4 prefixes (nil when empty).
	rootV4 *ValueNode[T]

	// rootV6 is the root of the tree holding IPv6 prefixes (nil when empty).
	rootV6 *ValueNode[T]
}
|
|
|
|
// NewValue creates and initializes a new ValueTrie.
|
|
func NewValue[T any]() *ValueTrie[T] {
|
|
return &ValueTrie[T]{}
|
|
}
|
|
|
|
// Insert adds or updates a prefix in the trie with the given value.
|
|
func (t *ValueTrie[T]) Insert(p netip.Prefix, value T) {
|
|
p = p.Masked()
|
|
addr := p.Addr()
|
|
|
|
if addr.Is4() {
|
|
t.rootV4 = t.insert(t.rootV4, p, value)
|
|
} else {
|
|
t.rootV6 = t.insert(t.rootV6, p, value)
|
|
}
|
|
}
|
|
|
|
// insert is the recursive helper for inserting a prefix into the trie.
|
|
func (t *ValueTrie[T]) insert(node *ValueNode[T], p netip.Prefix, value T) *ValueNode[T] {
|
|
if node == nil {
|
|
return &ValueNode[T]{prefix: p, value: value, isValue: true}
|
|
}
|
|
|
|
addr := p.Addr()
|
|
commonLen := commonPrefixLen(addr, node.prefix.Addr())
|
|
pBits := p.Bits()
|
|
nodeBits := node.prefix.Bits()
|
|
|
|
if commonLen > pBits {
|
|
commonLen = pBits
|
|
}
|
|
if commonLen > nodeBits {
|
|
commonLen = nodeBits
|
|
}
|
|
|
|
if commonLen == nodeBits && commonLen == pBits {
|
|
// Exact match, update the value.
|
|
node.value = value
|
|
node.isValue = true
|
|
return node
|
|
}
|
|
|
|
if commonLen < nodeBits {
|
|
// The new prefix diverges from the current node's prefix.
|
|
// We must split the current node.
|
|
commonP, _ := node.prefix.Addr().Prefix(commonLen)
|
|
splitNode := &ValueNode[T]{prefix: commonP}
|
|
|
|
// The existing node becomes a child of the new split node.
|
|
bit := getBit(node.prefix.Addr(), commonLen)
|
|
splitNode.children[bit] = node
|
|
|
|
if commonLen == pBits {
|
|
// The inserted prefix is a prefix of the node's original prefix.
|
|
// The new split node represents the inserted prefix and gets the value.
|
|
splitNode.value = value
|
|
splitNode.isValue = true
|
|
} else {
|
|
// The two prefixes diverge. Create a new child for the new prefix.
|
|
bit = getBit(addr, commonLen)
|
|
splitNode.children[bit] = &ValueNode[T]{prefix: p, value: value, isValue: true}
|
|
}
|
|
return splitNode
|
|
}
|
|
|
|
// commonLen == nodeBits, meaning the current node's prefix is a prefix of the new one.
|
|
// We need to descend to a child.
|
|
bit := getBit(addr, commonLen)
|
|
node.children[bit] = t.insert(node.children[bit], p, value)
|
|
return node
|
|
}
|
|
|
|
// Lookup finds the value associated with the most specific prefix that contains the given IP address.
|
|
func (t *ValueTrie[T]) Lookup(addr netip.Addr) (value T, ok bool) {
|
|
node := t.rootV4
|
|
if addr.Is6() {
|
|
node = t.rootV6
|
|
}
|
|
|
|
var lastFoundValue T
|
|
var found bool
|
|
|
|
for node != nil {
|
|
commonLen := commonPrefixLen(addr, node.prefix.Addr())
|
|
nodeBits := node.prefix.Bits()
|
|
|
|
// If the address doesn't share a prefix with the node, we can't be in this subtree.
|
|
if commonLen < nodeBits {
|
|
break
|
|
}
|
|
|
|
// The address is within this node's prefix. If the node holds a value,
|
|
// it's our current best match.
|
|
if node.isValue {
|
|
lastFoundValue = node.value
|
|
found = true
|
|
}
|
|
|
|
// We've matched the whole address, can't go deeper.
|
|
if commonLen == addr.BitLen() {
|
|
break
|
|
}
|
|
|
|
// Descend to the next child based on the next bit after the node's prefix.
|
|
bit := getBit(addr, nodeBits)
|
|
node = node.children[bit]
|
|
}
|
|
|
|
return lastFoundValue, found
|
|
}
|
|
|
|
// Delete removes a prefix from the trie. It returns true if the prefix was found and removed.
|
|
func (t *ValueTrie[T]) Delete(p netip.Prefix) bool {
|
|
p = p.Masked()
|
|
addr := p.Addr()
|
|
|
|
var changed bool
|
|
if addr.Is4() {
|
|
t.rootV4, changed = t.delete(t.rootV4, p)
|
|
} else {
|
|
t.rootV6, changed = t.delete(t.rootV6, p)
|
|
}
|
|
return changed
|
|
}
|
|
|
|
// delete is the recursive helper for removing a prefix from the trie.
|
|
func (t *ValueTrie[T]) delete(node *ValueNode[T], p netip.Prefix) (*ValueNode[T], bool) {
|
|
if node == nil {
|
|
return nil, false
|
|
}
|
|
|
|
addr := p.Addr()
|
|
pBits := p.Bits()
|
|
nodeBits := node.prefix.Bits()
|
|
commonLen := commonPrefixLen(addr, node.prefix.Addr())
|
|
|
|
// The prefix is not on this path.
|
|
if commonLen < nodeBits || commonLen < pBits && pBits < nodeBits {
|
|
return node, false
|
|
}
|
|
|
|
var changed bool
|
|
if pBits > nodeBits {
|
|
// The prefix to delete is deeper in the trie. Recurse.
|
|
bit := getBit(addr, nodeBits)
|
|
node.children[bit], changed = t.delete(node.children[bit], p)
|
|
} else if pBits == nodeBits {
|
|
// This is the node to delete. Unset its value.
|
|
if !node.isValue {
|
|
return node, false // Prefix wasn't actually in the trie.
|
|
}
|
|
node.isValue = false
|
|
var zero T
|
|
node.value = zero
|
|
changed = true
|
|
} else { // pBits < nodeBits
|
|
return node, false // Prefix to delete is shorter, so can't be here.
|
|
}
|
|
|
|
if !changed {
|
|
return node, false
|
|
}
|
|
|
|
// Post-deletion cleanup:
|
|
// If the node has no value and can be merged with a single child, do so.
|
|
if !node.isValue {
|
|
if node.children[0] != nil && node.children[1] == nil {
|
|
return node.children[0], true
|
|
}
|
|
if node.children[0] == nil && node.children[1] != nil {
|
|
return node.children[1], true
|
|
}
|
|
}
|
|
|
|
// If the node is now a leaf without a value, it can be removed entirely.
|
|
if !node.isValue && node.children[0] == nil && node.children[1] == nil {
|
|
return nil, true
|
|
}
|
|
|
|
return node, true
|
|
}
|
|
|
|
// Contains checks if the exact IP address exists in the trie as a full-length prefix.
|
|
func (t *ValueTrie[T]) Contains(addr netip.Addr) bool {
|
|
prefix := netip.PrefixFrom(addr, addr.BitLen())
|
|
return t.ContainsPrefix(prefix)
|
|
}
|
|
|
|
// ContainsPrefix checks if the exact prefix exists in the trie.
|
|
func (t *ValueTrie[T]) ContainsPrefix(p netip.Prefix) bool {
|
|
p = p.Masked()
|
|
addr := p.Addr()
|
|
pBits := p.Bits()
|
|
|
|
node := t.rootV4
|
|
if addr.Is6() {
|
|
node = t.rootV6
|
|
}
|
|
|
|
for node != nil {
|
|
commonLen := commonPrefixLen(addr, node.prefix.Addr())
|
|
nodeBits := node.prefix.Bits()
|
|
|
|
if commonLen < nodeBits {
|
|
// Path has diverged. The prefix cannot be in this subtree.
|
|
return false
|
|
}
|
|
|
|
if pBits < nodeBits {
|
|
// The search prefix is shorter than the node's prefix,
|
|
// but they share a prefix. e.g. search /16, node is /24.
|
|
// The /16 is not explicitly in the trie.
|
|
return false
|
|
}
|
|
|
|
if pBits == nodeBits {
|
|
// Found a node with the exact same prefix length.
|
|
// Because we also know commonLen >= nodeBits, the prefixes are identical.
|
|
return node.isValue
|
|
}
|
|
|
|
// pBits > nodeBits, so we need to go deeper.
|
|
bit := getBit(addr, nodeBits)
|
|
node = node.children[bit]
|
|
}
|
|
|
|
return false
|
|
}
|
|
|
|
// WalkValueFunc is a function called for each prefix in the trie during a walk.
// It receives a stored prefix p and the value v associated with it; split
// nodes that hold no value are not reported.
// Returning false from the function will stop the walk.
type WalkValueFunc[T any] func(p netip.Prefix, v T) bool
|
|
|
|
// walk is the recursive helper for traversing the trie.
|
|
func walkValue[T any](node *ValueNode[T], f WalkValueFunc[T]) bool {
|
|
if node == nil {
|
|
return true
|
|
}
|
|
|
|
if node.isValue {
|
|
if !f(node.prefix, node.value) {
|
|
return false
|
|
}
|
|
}
|
|
|
|
if node.children[0] != nil {
|
|
if !walkValue(node.children[0], f) {
|
|
return false
|
|
}
|
|
}
|
|
if node.children[1] != nil {
|
|
if !walkValue(node.children[1], f) {
|
|
return false
|
|
}
|
|
}
|
|
return true
|
|
}
|
|
|
|
// Walk traverses the trie and calls the given function for each prefix and its value.
|
|
// If the function returns false, the walk is stopped. The order is not guaranteed.
|
|
func (t *ValueTrie[T]) Walk(f WalkValueFunc[T]) {
|
|
if !walkValue(t.rootV4, f) {
|
|
return
|
|
}
|
|
walkValue(t.rootV6, f)
|
|
}
|
|
|
|
// Merge inserts all prefixes from another Trie into this one.
|
|
func (t *ValueTrie[T]) Merge(other *ValueTrie[T]) {
|
|
other.Walk(func(p netip.Prefix, v T) bool {
|
|
t.Insert(p, v)
|
|
return true // continue walking
|
|
})
|
|
}
|