Better trie implementations
This commit is contained in:
253
dataset/nettrie/trie.go
Normal file
253
dataset/nettrie/trie.go
Normal file
@@ -0,0 +1,253 @@
|
||||
package nettrie
|
||||
|
||||
import "net/netip"
|
||||
|
||||
// Node represents a node in the path-compressed trie.
|
||||
// Each node represents a prefix and can have up to two children.
|
||||
type Node struct {
|
||||
children [2]*Node
|
||||
|
||||
// prefix is the full prefix represented by the path to this node.
|
||||
prefix netip.Prefix
|
||||
|
||||
// isValue marks if this node represents an explicitly inserted prefix.
|
||||
isValue bool
|
||||
}
|
||||
|
||||
// Trie is a path-compressed radix trie that stores network prefixes.
|
||||
type Trie struct {
|
||||
rootV4 *Node
|
||||
rootV6 *Node
|
||||
}
|
||||
|
||||
// New creates and initializes a new Trie.
|
||||
func New() *Trie {
|
||||
return &Trie{}
|
||||
}
|
||||
|
||||
// Insert adds a prefix to the trie.
|
||||
func (t *Trie) Insert(p netip.Prefix) {
|
||||
p = p.Masked()
|
||||
addr := p.Addr()
|
||||
|
||||
if addr.Is4() {
|
||||
t.rootV4 = t.insert(t.rootV4, p)
|
||||
} else {
|
||||
t.rootV6 = t.insert(t.rootV6, p)
|
||||
}
|
||||
}
|
||||
|
||||
// insert is the recursive helper for inserting a prefix into the trie.
|
||||
func (t *Trie) insert(node *Node, p netip.Prefix) *Node {
|
||||
if node == nil {
|
||||
return &Node{prefix: p, isValue: true}
|
||||
}
|
||||
|
||||
addr := p.Addr()
|
||||
commonLen := commonPrefixLen(addr, node.prefix.Addr())
|
||||
pBits := p.Bits()
|
||||
nodeBits := node.prefix.Bits()
|
||||
|
||||
if commonLen > pBits {
|
||||
commonLen = pBits
|
||||
}
|
||||
if commonLen > nodeBits {
|
||||
commonLen = nodeBits
|
||||
}
|
||||
|
||||
if commonLen == nodeBits && commonLen == pBits {
|
||||
// Exact match, mark the node as a value node.
|
||||
node.isValue = true
|
||||
return node
|
||||
}
|
||||
|
||||
if commonLen < nodeBits {
|
||||
// The new prefix diverges from the current node's prefix.
|
||||
// We must split the current node.
|
||||
commonP, _ := node.prefix.Addr().Prefix(commonLen)
|
||||
splitNode := &Node{prefix: commonP}
|
||||
|
||||
// The existing node becomes a child of the new split node.
|
||||
bit := getBit(node.prefix.Addr(), commonLen)
|
||||
splitNode.children[bit] = node
|
||||
|
||||
if commonLen == pBits {
|
||||
// The inserted prefix is a prefix of the node's original prefix.
|
||||
// The new split node represents the inserted prefix.
|
||||
splitNode.isValue = true
|
||||
} else {
|
||||
// The two prefixes diverge. Create a new child for the new prefix.
|
||||
bit := getBit(addr, commonLen)
|
||||
splitNode.children[bit] = &Node{prefix: p, isValue: true}
|
||||
}
|
||||
return splitNode
|
||||
}
|
||||
|
||||
// commonLen == nodeBits, meaning the current node's prefix is a prefix of the new one.
|
||||
// We need to descend to a child.
|
||||
bit := getBit(addr, commonLen)
|
||||
node.children[bit] = t.insert(node.children[bit], p)
|
||||
return node
|
||||
}
|
||||
|
||||
// Delete removes a prefix from the trie. It returns true if the prefix was found and removed.
|
||||
func (t *Trie) Delete(p netip.Prefix) bool {
|
||||
p = p.Masked()
|
||||
addr := p.Addr()
|
||||
|
||||
var changed bool
|
||||
if addr.Is4() {
|
||||
t.rootV4, changed = t.delete(t.rootV4, p)
|
||||
} else {
|
||||
t.rootV6, changed = t.delete(t.rootV6, p)
|
||||
}
|
||||
return changed
|
||||
}
|
||||
|
||||
// delete is the recursive helper for removing a prefix from the trie.
|
||||
func (t *Trie) delete(node *Node, p netip.Prefix) (*Node, bool) {
|
||||
if node == nil {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
addr := p.Addr()
|
||||
pBits := p.Bits()
|
||||
nodeBits := node.prefix.Bits()
|
||||
commonLen := commonPrefixLen(addr, node.prefix.Addr())
|
||||
|
||||
// The prefix is not on this path.
|
||||
if commonLen < nodeBits || commonLen < pBits && pBits < nodeBits {
|
||||
return node, false
|
||||
}
|
||||
|
||||
var changed bool
|
||||
if pBits > nodeBits {
|
||||
// The prefix to delete is deeper in the trie. Recurse.
|
||||
bit := getBit(addr, nodeBits)
|
||||
node.children[bit], changed = t.delete(node.children[bit], p)
|
||||
} else if pBits == nodeBits {
|
||||
// This is the node to delete. Unset its value.
|
||||
if !node.isValue {
|
||||
return node, false // Prefix wasn't actually in the trie.
|
||||
}
|
||||
node.isValue = false
|
||||
changed = true
|
||||
} else { // pBits < nodeBits
|
||||
return node, false // Prefix to delete is shorter, so can't be here.
|
||||
}
|
||||
|
||||
if !changed {
|
||||
return node, false
|
||||
}
|
||||
|
||||
// Post-deletion cleanup:
|
||||
// If the node has no value and can be merged with a single child, do so.
|
||||
if !node.isValue {
|
||||
if node.children[0] != nil && node.children[1] == nil {
|
||||
return node.children[0], true
|
||||
}
|
||||
if node.children[0] == nil && node.children[1] != nil {
|
||||
return node.children[1], true
|
||||
}
|
||||
}
|
||||
|
||||
// If the node is now a leaf without a value, it can be removed entirely.
|
||||
if !node.isValue && node.children[0] == nil && node.children[1] == nil {
|
||||
return nil, true
|
||||
}
|
||||
|
||||
return node, true
|
||||
}
|
||||
|
||||
// ContainsPrefix checks if the exact prefix exists in the trie.
|
||||
func (t *Trie) ContainsPrefix(p netip.Prefix) bool {
|
||||
p = p.Masked()
|
||||
addr := p.Addr()
|
||||
pBits := p.Bits()
|
||||
|
||||
node := t.rootV4
|
||||
if addr.Is6() {
|
||||
node = t.rootV6
|
||||
}
|
||||
|
||||
for node != nil {
|
||||
commonLen := commonPrefixLen(addr, node.prefix.Addr())
|
||||
nodeBits := node.prefix.Bits()
|
||||
|
||||
if commonLen < nodeBits {
|
||||
// Path has diverged. The prefix cannot be in this subtree.
|
||||
return false
|
||||
}
|
||||
|
||||
if pBits < nodeBits {
|
||||
// The search prefix is shorter than the node's prefix,
|
||||
// but they share a prefix. e.g. search /16, node is /24.
|
||||
// The /16 is not explicitly in the trie.
|
||||
return false
|
||||
}
|
||||
|
||||
if pBits == nodeBits {
|
||||
// Found a node with the exact same prefix length.
|
||||
// Because we also know commonLen >= nodeBits, the prefixes are identical.
|
||||
return node.isValue
|
||||
}
|
||||
|
||||
// pBits > nodeBits, so we need to go deeper.
|
||||
bit := getBit(addr, nodeBits)
|
||||
node = node.children[bit]
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// Contains checks if the exact IP address exists in the trie as a full-length prefix.
|
||||
func (t *Trie) Contains(addr netip.Addr) bool {
|
||||
prefix := netip.PrefixFrom(addr, addr.BitLen())
|
||||
return t.ContainsPrefix(prefix)
|
||||
}
|
||||
|
||||
// WalkFunc is a function called for each prefix in the trie during a walk.
|
||||
// Returning false from the function will stop the walk.
|
||||
type WalkFunc func(p netip.Prefix) bool
|
||||
|
||||
// walk is the recursive helper for traversing the trie.
|
||||
func walk(node *Node, f WalkFunc) bool {
|
||||
if node == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
if node.isValue {
|
||||
if !f(node.prefix) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
if node.children[0] != nil {
|
||||
if !walk(node.children[0], f) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
if node.children[1] != nil {
|
||||
if !walk(node.children[1], f) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// Walk traverses the trie and calls the given function for each prefix.
|
||||
// If the function returns false, the walk is stopped. The order is not guaranteed.
|
||||
func (t *Trie) Walk(f WalkFunc) {
|
||||
if !walk(t.rootV4, f) {
|
||||
return
|
||||
}
|
||||
walk(t.rootV6, f)
|
||||
}
|
||||
|
||||
// Merge inserts all prefixes from another Trie into this one.
|
||||
func (t *Trie) Merge(other *Trie) {
|
||||
other.Walk(func(p netip.Prefix) bool {
|
||||
t.Insert(p)
|
||||
return true // continue walking
|
||||
})
|
||||
}
|
Reference in New Issue
Block a user