Better trie implementations

This commit is contained in:
2025-10-08 20:57:13 +02:00
parent 5f0f4aa96b
commit 582163d4be
26 changed files with 2482 additions and 122 deletions

253
dataset/nettrie/trie.go Normal file
View File

@@ -0,0 +1,253 @@
package nettrie
import "net/netip"
// Node represents a node in the path-compressed trie.
// Each node represents a prefix and can have up to two children.
type Node struct {
children [2]*Node
// prefix is the full prefix represented by the path to this node.
prefix netip.Prefix
// isValue marks if this node represents an explicitly inserted prefix.
isValue bool
}
// Trie is a path-compressed radix trie that stores network prefixes.
type Trie struct {
rootV4 *Node
rootV6 *Node
}
// New creates and initializes a new Trie.
func New() *Trie {
return &Trie{}
}
// Insert adds a prefix to the trie.
func (t *Trie) Insert(p netip.Prefix) {
p = p.Masked()
addr := p.Addr()
if addr.Is4() {
t.rootV4 = t.insert(t.rootV4, p)
} else {
t.rootV6 = t.insert(t.rootV6, p)
}
}
// insert is the recursive helper for inserting a prefix into the trie.
func (t *Trie) insert(node *Node, p netip.Prefix) *Node {
if node == nil {
return &Node{prefix: p, isValue: true}
}
addr := p.Addr()
commonLen := commonPrefixLen(addr, node.prefix.Addr())
pBits := p.Bits()
nodeBits := node.prefix.Bits()
if commonLen > pBits {
commonLen = pBits
}
if commonLen > nodeBits {
commonLen = nodeBits
}
if commonLen == nodeBits && commonLen == pBits {
// Exact match, mark the node as a value node.
node.isValue = true
return node
}
if commonLen < nodeBits {
// The new prefix diverges from the current node's prefix.
// We must split the current node.
commonP, _ := node.prefix.Addr().Prefix(commonLen)
splitNode := &Node{prefix: commonP}
// The existing node becomes a child of the new split node.
bit := getBit(node.prefix.Addr(), commonLen)
splitNode.children[bit] = node
if commonLen == pBits {
// The inserted prefix is a prefix of the node's original prefix.
// The new split node represents the inserted prefix.
splitNode.isValue = true
} else {
// The two prefixes diverge. Create a new child for the new prefix.
bit := getBit(addr, commonLen)
splitNode.children[bit] = &Node{prefix: p, isValue: true}
}
return splitNode
}
// commonLen == nodeBits, meaning the current node's prefix is a prefix of the new one.
// We need to descend to a child.
bit := getBit(addr, commonLen)
node.children[bit] = t.insert(node.children[bit], p)
return node
}
// Delete removes a prefix from the trie. It returns true if the prefix was found and removed.
func (t *Trie) Delete(p netip.Prefix) bool {
p = p.Masked()
addr := p.Addr()
var changed bool
if addr.Is4() {
t.rootV4, changed = t.delete(t.rootV4, p)
} else {
t.rootV6, changed = t.delete(t.rootV6, p)
}
return changed
}
// delete is the recursive helper for removing a prefix from the trie.
func (t *Trie) delete(node *Node, p netip.Prefix) (*Node, bool) {
if node == nil {
return nil, false
}
addr := p.Addr()
pBits := p.Bits()
nodeBits := node.prefix.Bits()
commonLen := commonPrefixLen(addr, node.prefix.Addr())
// The prefix is not on this path.
if commonLen < nodeBits || commonLen < pBits && pBits < nodeBits {
return node, false
}
var changed bool
if pBits > nodeBits {
// The prefix to delete is deeper in the trie. Recurse.
bit := getBit(addr, nodeBits)
node.children[bit], changed = t.delete(node.children[bit], p)
} else if pBits == nodeBits {
// This is the node to delete. Unset its value.
if !node.isValue {
return node, false // Prefix wasn't actually in the trie.
}
node.isValue = false
changed = true
} else { // pBits < nodeBits
return node, false // Prefix to delete is shorter, so can't be here.
}
if !changed {
return node, false
}
// Post-deletion cleanup:
// If the node has no value and can be merged with a single child, do so.
if !node.isValue {
if node.children[0] != nil && node.children[1] == nil {
return node.children[0], true
}
if node.children[0] == nil && node.children[1] != nil {
return node.children[1], true
}
}
// If the node is now a leaf without a value, it can be removed entirely.
if !node.isValue && node.children[0] == nil && node.children[1] == nil {
return nil, true
}
return node, true
}
// ContainsPrefix checks if the exact prefix exists in the trie.
func (t *Trie) ContainsPrefix(p netip.Prefix) bool {
p = p.Masked()
addr := p.Addr()
pBits := p.Bits()
node := t.rootV4
if addr.Is6() {
node = t.rootV6
}
for node != nil {
commonLen := commonPrefixLen(addr, node.prefix.Addr())
nodeBits := node.prefix.Bits()
if commonLen < nodeBits {
// Path has diverged. The prefix cannot be in this subtree.
return false
}
if pBits < nodeBits {
// The search prefix is shorter than the node's prefix,
// but they share a prefix. e.g. search /16, node is /24.
// The /16 is not explicitly in the trie.
return false
}
if pBits == nodeBits {
// Found a node with the exact same prefix length.
// Because we also know commonLen >= nodeBits, the prefixes are identical.
return node.isValue
}
// pBits > nodeBits, so we need to go deeper.
bit := getBit(addr, nodeBits)
node = node.children[bit]
}
return false
}
// Contains checks if the exact IP address exists in the trie as a full-length prefix.
func (t *Trie) Contains(addr netip.Addr) bool {
prefix := netip.PrefixFrom(addr, addr.BitLen())
return t.ContainsPrefix(prefix)
}
// WalkFunc is a function called for each prefix in the trie during a walk.
// Returning false from the function will stop the walk.
type WalkFunc func(p netip.Prefix) bool
// walk is the recursive helper for traversing the trie.
func walk(node *Node, f WalkFunc) bool {
if node == nil {
return true
}
if node.isValue {
if !f(node.prefix) {
return false
}
}
if node.children[0] != nil {
if !walk(node.children[0], f) {
return false
}
}
if node.children[1] != nil {
if !walk(node.children[1], f) {
return false
}
}
return true
}
// Walk traverses the trie and calls the given function for each prefix.
// If the function returns false, the walk is stopped. The order is not guaranteed.
func (t *Trie) Walk(f WalkFunc) {
if !walk(t.rootV4, f) {
return
}
walk(t.rootV6, f)
}
// Merge inserts all prefixes from another Trie into this one.
func (t *Trie) Merge(other *Trie) {
other.Walk(func(p netip.Prefix) bool {
t.Insert(p)
return true // continue walking
})
}

View File

@@ -0,0 +1,248 @@
package nettrie
import (
"net/netip"
"sort"
"testing"
)
// TestTrie_InsertAndContains_IPv4 tests insertion and verifies presence
// using ContainsPrefix and Contains for IPv4.
func TestTrie_InsertAndContains_IPv4(t *testing.T) {
trie := New()
prefixes := []string{
"0.0.0.0/0",
"10.0.0.0/8",
"192.168.0.0/16",
"192.168.1.0/24",
"192.168.1.128/25",
"192.168.1.200/32",
}
for _, s := range prefixes {
trie.Insert(netip.MustParsePrefix(s))
}
t.Run("Check existing prefixes", func(t *testing.T) {
for _, s := range prefixes {
p := netip.MustParsePrefix(s)
if !trie.ContainsPrefix(p) {
t.Errorf("expected trie to contain prefix %s, but it did not", p)
}
}
})
t.Run("Check non-existing prefixes", func(t *testing.T) {
nonExistentPrefixes := []string{
"10.0.0.0/9",
"192.168.1.0/23",
"1.2.3.4/32",
}
for _, s := range nonExistentPrefixes {
p := netip.MustParsePrefix(s)
if trie.ContainsPrefix(p) {
t.Errorf("expected trie to not contain prefix %s, but it did", p)
}
}
})
t.Run("Check host addresses with Contains", func(t *testing.T) {
if !trie.Contains(netip.MustParseAddr("192.168.1.200")) {
t.Error("expected trie to contain host address 192.168.1.200, but it did not")
}
if trie.Contains(netip.MustParseAddr("192.168.1.201")) {
t.Error("expected trie to not contain host address 192.168.1.201, but it did")
}
})
}
// TestTrie_InsertAndContains_IPv6 tests insertion and verifies presence
// using ContainsPrefix and Contains for IPv6.
func TestTrie_InsertAndContains_IPv6(t *testing.T) {
trie := New()
prefixes := []string{
"::/0",
"2001:db8::/32",
"2001:db8:acad::/48",
"2001:db8:acad:1::/64",
"2001:db8:acad:1::1/128",
}
for _, s := range prefixes {
trie.Insert(netip.MustParsePrefix(s))
}
t.Run("Check existing prefixes", func(t *testing.T) {
for _, s := range prefixes {
p := netip.MustParsePrefix(s)
if !trie.ContainsPrefix(p) {
t.Errorf("expected trie to contain prefix %s, but it did not", p)
}
}
})
t.Run("Check non-existing prefixes", func(t *testing.T) {
nonExistentPrefixes := []string{
"2001:db8::/33",
"2001:db8:acad::/47",
"2002::/16",
}
for _, s := range nonExistentPrefixes {
p := netip.MustParsePrefix(s)
if trie.ContainsPrefix(p) {
t.Errorf("expected trie to not contain prefix %s, but it did", p)
}
}
})
t.Run("Check host addresses with Contains", func(t *testing.T) {
if !trie.Contains(netip.MustParseAddr("2001:db8:acad:1::1")) {
t.Error("expected trie to contain host address 2001:db8:acad:1::1, but it did not")
}
if trie.Contains(netip.MustParseAddr("2001:db8:acad:1::2")) {
t.Error("expected trie to not contain host address 2001:db8:acad:1::2, but it did")
}
})
}
func TestTrie_Delete(t *testing.T) {
trie := New()
p8 := netip.MustParsePrefix("10.0.0.0/8")
p24 := netip.MustParsePrefix("10.0.0.0/24")
pOther := netip.MustParsePrefix("10.0.1.0/24")
trie.Insert(p8)
trie.Insert(p24)
trie.Insert(pOther)
t.Run("Delete Non-Existent", func(t *testing.T) {
if trie.Delete(netip.MustParsePrefix("1.2.3.4/32")) {
t.Error("Delete returned true for a non-existent prefix")
}
})
t.Run("Delete Leaf Node", func(t *testing.T) {
if !trie.Delete(p24) {
t.Fatal("Delete returned false for an existing prefix")
}
if trie.ContainsPrefix(p24) {
t.Error("ContainsPrefix should be false after delete")
}
if !trie.ContainsPrefix(p8) {
t.Error("parent prefix should not be affected by child delete")
}
})
t.Run("Delete Intermediate Node", func(t *testing.T) {
// Re-insert p24 so we can delete its parent p8
trie.Insert(p24)
if !trie.Delete(p8) {
t.Fatal("Delete returned false for an existing intermediate prefix")
}
if trie.ContainsPrefix(p8) {
t.Error("ContainsPrefix for deleted intermediate node should be false")
}
if !trie.ContainsPrefix(p24) {
t.Error("child prefix should still exist after parent was deleted")
}
if !trie.ContainsPrefix(pOther) {
t.Error("other prefix should still exist after parent was deleted")
}
})
}
func TestTrie_Walk(t *testing.T) {
trie := New()
prefixes := []string{
"10.0.0.0/8",
"192.168.1.0/24",
"2001:db8::/32",
"172.16.0.0/12",
"2001:db8:acad::/48",
}
for _, s := range prefixes {
trie.Insert(netip.MustParsePrefix(s))
}
t.Run("Walk all prefixes", func(t *testing.T) {
var walkedPrefixes []string
trie.Walk(func(p netip.Prefix) bool {
walkedPrefixes = append(walkedPrefixes, p.String())
return true
})
if len(walkedPrefixes) != len(prefixes) {
t.Fatalf("expected to walk %d prefixes, but got %d", len(prefixes), len(walkedPrefixes))
}
sort.Strings(prefixes)
sort.Strings(walkedPrefixes)
for i := range prefixes {
if prefixes[i] != walkedPrefixes[i] {
t.Errorf("walked prefixes mismatch: expected %v, got %v", prefixes, walkedPrefixes)
break
}
}
})
t.Run("Stop walk early", func(t *testing.T) {
count := 0
trie.Walk(func(p netip.Prefix) bool {
count++
return count < 3 // Stop after visiting 3 prefixes
})
if count != 3 {
t.Errorf("expected walk to stop after 3 prefixes, but it visited %d", count)
}
})
}
func TestTrie_Merge(t *testing.T) {
trieA := New()
trieA.Insert(netip.MustParsePrefix("10.0.0.0/8"))
trieA.Insert(netip.MustParsePrefix("192.168.1.0/24"))
trieA.Insert(netip.MustParsePrefix("2001:db8::/32")) // v6
trieB := New()
trieB.Insert(netip.MustParsePrefix("10.1.0.0/16"))
trieB.Insert(netip.MustParsePrefix("192.168.1.0/24")) // Overlap
trieB.Insert(netip.MustParsePrefix("172.16.0.0/12"))
trieB.Insert(netip.MustParsePrefix("2001:db8:acad::/48")) // v6
trieA.Merge(trieB)
expectedPrefixes := []string{
// From A
"10.0.0.0/8",
"192.168.1.0/24",
"2001:db8::/32",
// From B
"10.1.0.0/16",
"172.16.0.0/12",
"2001:db8:acad::/48",
}
for _, s := range expectedPrefixes {
p := netip.MustParsePrefix(s)
if !trieA.ContainsPrefix(p) {
t.Errorf("after merge, trieA is missing prefix %s", p)
}
}
// Check a prefix that should not be there
if trieA.ContainsPrefix(netip.MustParsePrefix("9.9.9.9/32")) {
t.Error("trieA contains a prefix it should not have")
}
// Verify trieB is unchanged
if !trieB.ContainsPrefix(netip.MustParsePrefix("172.16.0.0/12")) {
t.Error("trieB was modified during merge")
}
if trieB.ContainsPrefix(netip.MustParsePrefix("10.0.0.0/8")) {
t.Error("trieB contains a prefix from trieA")
}
}

View File

@@ -0,0 +1,328 @@
package nettrie
import (
"math/bits"
"net/netip"
)
// getBit returns the n-th bit of an IP address (0-indexed).
func getBit(addr netip.Addr, n int) byte {
slice := addr.AsSlice()
byteIndex := n / 8
bitIndex := 7 - (n % 8)
return (slice[byteIndex] >> bitIndex) & 1
}
// commonPrefixLen computes the number of leading bits that are the same for two addresses.
func commonPrefixLen(a, b netip.Addr) int {
if a.Is4() != b.Is4() {
return 0
}
aSlice := a.AsSlice()
bSlice := b.AsSlice()
commonLen := 0
for i := 0; i < len(aSlice); i++ {
xor := aSlice[i] ^ bSlice[i]
if xor == 0 {
commonLen += 8
} else {
commonLen += bits.LeadingZeros8(xor)
return commonLen
}
}
return commonLen
}
// ValueNode represents a node in the path-compressed trie.
// Each node represents a prefix and can have up to two children.
type ValueNode[T any] struct {
children [2]*ValueNode[T]
// prefix is the full prefix represented by the path to this node.
prefix netip.Prefix
value T
isValue bool
}
// ValueTrie is a path-compressed radix trie that stores network prefixes and their values.
type ValueTrie[T any] struct {
rootV4 *ValueNode[T]
rootV6 *ValueNode[T]
}
// NewValue creates and initializes a new ValueTrie.
func NewValue[T any]() *ValueTrie[T] {
return &ValueTrie[T]{}
}
// Insert adds or updates a prefix in the trie with the given value.
func (t *ValueTrie[T]) Insert(p netip.Prefix, value T) {
p = p.Masked()
addr := p.Addr()
if addr.Is4() {
t.rootV4 = t.insert(t.rootV4, p, value)
} else {
t.rootV6 = t.insert(t.rootV6, p, value)
}
}
// insert is the recursive helper for inserting a prefix into the trie.
func (t *ValueTrie[T]) insert(node *ValueNode[T], p netip.Prefix, value T) *ValueNode[T] {
if node == nil {
return &ValueNode[T]{prefix: p, value: value, isValue: true}
}
addr := p.Addr()
commonLen := commonPrefixLen(addr, node.prefix.Addr())
pBits := p.Bits()
nodeBits := node.prefix.Bits()
if commonLen > pBits {
commonLen = pBits
}
if commonLen > nodeBits {
commonLen = nodeBits
}
if commonLen == nodeBits && commonLen == pBits {
// Exact match, update the value.
node.value = value
node.isValue = true
return node
}
if commonLen < nodeBits {
// The new prefix diverges from the current node's prefix.
// We must split the current node.
commonP, _ := node.prefix.Addr().Prefix(commonLen)
splitNode := &ValueNode[T]{prefix: commonP}
// The existing node becomes a child of the new split node.
bit := getBit(node.prefix.Addr(), commonLen)
splitNode.children[bit] = node
if commonLen == pBits {
// The inserted prefix is a prefix of the node's original prefix.
// The new split node represents the inserted prefix and gets the value.
splitNode.value = value
splitNode.isValue = true
} else {
// The two prefixes diverge. Create a new child for the new prefix.
bit = getBit(addr, commonLen)
splitNode.children[bit] = &ValueNode[T]{prefix: p, value: value, isValue: true}
}
return splitNode
}
// commonLen == nodeBits, meaning the current node's prefix is a prefix of the new one.
// We need to descend to a child.
bit := getBit(addr, commonLen)
node.children[bit] = t.insert(node.children[bit], p, value)
return node
}
// Lookup finds the value associated with the most specific prefix that contains the given IP address.
func (t *ValueTrie[T]) Lookup(addr netip.Addr) (value T, ok bool) {
node := t.rootV4
if addr.Is6() {
node = t.rootV6
}
var lastFoundValue T
var found bool
for node != nil {
commonLen := commonPrefixLen(addr, node.prefix.Addr())
nodeBits := node.prefix.Bits()
// If the address doesn't share a prefix with the node, we can't be in this subtree.
if commonLen < nodeBits {
break
}
// The address is within this node's prefix. If the node holds a value,
// it's our current best match.
if node.isValue {
lastFoundValue = node.value
found = true
}
// We've matched the whole address, can't go deeper.
if commonLen == addr.BitLen() {
break
}
// Descend to the next child based on the next bit after the node's prefix.
bit := getBit(addr, nodeBits)
node = node.children[bit]
}
return lastFoundValue, found
}
// Delete removes a prefix from the trie. It returns true if the prefix was found and removed.
func (t *ValueTrie[T]) Delete(p netip.Prefix) bool {
p = p.Masked()
addr := p.Addr()
var changed bool
if addr.Is4() {
t.rootV4, changed = t.delete(t.rootV4, p)
} else {
t.rootV6, changed = t.delete(t.rootV6, p)
}
return changed
}
// delete is the recursive helper for removing a prefix from the trie.
func (t *ValueTrie[T]) delete(node *ValueNode[T], p netip.Prefix) (*ValueNode[T], bool) {
if node == nil {
return nil, false
}
addr := p.Addr()
pBits := p.Bits()
nodeBits := node.prefix.Bits()
commonLen := commonPrefixLen(addr, node.prefix.Addr())
// The prefix is not on this path.
if commonLen < nodeBits || commonLen < pBits && pBits < nodeBits {
return node, false
}
var changed bool
if pBits > nodeBits {
// The prefix to delete is deeper in the trie. Recurse.
bit := getBit(addr, nodeBits)
node.children[bit], changed = t.delete(node.children[bit], p)
} else if pBits == nodeBits {
// This is the node to delete. Unset its value.
if !node.isValue {
return node, false // Prefix wasn't actually in the trie.
}
node.isValue = false
var zero T
node.value = zero
changed = true
} else { // pBits < nodeBits
return node, false // Prefix to delete is shorter, so can't be here.
}
if !changed {
return node, false
}
// Post-deletion cleanup:
// If the node has no value and can be merged with a single child, do so.
if !node.isValue {
if node.children[0] != nil && node.children[1] == nil {
return node.children[0], true
}
if node.children[0] == nil && node.children[1] != nil {
return node.children[1], true
}
}
// If the node is now a leaf without a value, it can be removed entirely.
if !node.isValue && node.children[0] == nil && node.children[1] == nil {
return nil, true
}
return node, true
}
// Contains checks if the exact IP address exists in the trie as a full-length prefix.
func (t *ValueTrie[T]) Contains(addr netip.Addr) bool {
prefix := netip.PrefixFrom(addr, addr.BitLen())
return t.ContainsPrefix(prefix)
}
// ContainsPrefix checks if the exact prefix exists in the trie.
func (t *ValueTrie[T]) ContainsPrefix(p netip.Prefix) bool {
p = p.Masked()
addr := p.Addr()
pBits := p.Bits()
node := t.rootV4
if addr.Is6() {
node = t.rootV6
}
for node != nil {
commonLen := commonPrefixLen(addr, node.prefix.Addr())
nodeBits := node.prefix.Bits()
if commonLen < nodeBits {
// Path has diverged. The prefix cannot be in this subtree.
return false
}
if pBits < nodeBits {
// The search prefix is shorter than the node's prefix,
// but they share a prefix. e.g. search /16, node is /24.
// The /16 is not explicitly in the trie.
return false
}
if pBits == nodeBits {
// Found a node with the exact same prefix length.
// Because we also know commonLen >= nodeBits, the prefixes are identical.
return node.isValue
}
// pBits > nodeBits, so we need to go deeper.
bit := getBit(addr, nodeBits)
node = node.children[bit]
}
return false
}
// WalkValueFunc is a function called for each prefix in the trie during a walk.
// Returning false from the function will stop the walk.
type WalkValueFunc[T any] func(p netip.Prefix, v T) bool
// walk is the recursive helper for traversing the trie.
func walkValue[T any](node *ValueNode[T], f WalkValueFunc[T]) bool {
if node == nil {
return true
}
if node.isValue {
if !f(node.prefix, node.value) {
return false
}
}
if node.children[0] != nil {
if !walkValue(node.children[0], f) {
return false
}
}
if node.children[1] != nil {
if !walkValue(node.children[1], f) {
return false
}
}
return true
}
// Walk traverses the trie and calls the given function for each prefix and its value.
// If the function returns false, the walk is stopped. The order is not guaranteed.
func (t *ValueTrie[T]) Walk(f WalkValueFunc[T]) {
if !walkValue(t.rootV4, f) {
return
}
walkValue(t.rootV6, f)
}
// Merge inserts all prefixes from another Trie into this one.
func (t *ValueTrie[T]) Merge(other *ValueTrie[T]) {
other.Walk(func(p netip.Prefix, v T) bool {
t.Insert(p, v)
return true // continue walking
})
}

View File

@@ -0,0 +1,340 @@
package nettrie
import (
"maps"
"net/netip"
"reflect"
"slices"
"testing"
)
func TestValueTrie_IPv4(t *testing.T) {
trie := NewValue[string]()
routes := map[string]string{
"0.0.0.0/0": "Default",
"10.0.0.0/8": "Private A",
"192.168.0.0/16": "Private B",
"192.168.1.0/24": "LAN",
"192.168.1.128/25": "Subnet",
"192.168.1.200/32": "Host",
}
for s, v := range routes {
trie.Insert(netip.MustParsePrefix(s), v)
}
testCases := []struct {
name string
lookupAddr string
expectedVal string
expectedOK bool
}{
{"Exact Host Match", "192.168.1.200", "Host", true},
{"Subnet Match", "192.168.1.201", "Subnet", true},
{"LAN Match", "192.168.1.50", "LAN", true},
{"Private B Match", "192.168.255.255", "Private B", true},
{"Private A Match", "10.255.255.255", "Private A", true},
{"Default Match", "8.8.8.8", "Default", true},
{"No Match", "203.0.113.1", "Default", true}, // Falls back to default
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
addr := netip.MustParseAddr(tc.lookupAddr)
val, ok := trie.Lookup(addr)
if ok != tc.expectedOK {
t.Errorf("expected ok=%v, got ok=%v", tc.expectedOK, ok)
}
if val != tc.expectedVal {
t.Errorf("expected value=%q, got %q", tc.expectedVal, val)
}
})
}
}
func TestValueTrie_IPv6(t *testing.T) {
trie := NewValue[string]()
routes := map[string]string{
"::/0": "Default",
"2001:db8::/32": "Global Unicast",
"2001:db8:acad::/48": "Academic",
"2001:db8:acad:1::/64": "CS Department",
"2001:db8:acad:1::1/128": "Host Route",
}
for s, v := range routes {
trie.Insert(netip.MustParsePrefix(s), v)
}
testCases := []struct {
name string
lookupAddr string
expectedVal string
expectedOK bool
}{
{"Exact Host Match", "2001:db8:acad:1::1", "Host Route", true},
{"CS Dept Match", "2001:db8:acad:1::2", "CS Department", true},
{"Academic Match", "2001:db8:acad:2::1", "Academic", true},
{"Global Unicast Match", "2001:db8:cafe::1", "Global Unicast", true},
{"Default Match", "2606:4700:4700::1111", "Default", true},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
addr := netip.MustParseAddr(tc.lookupAddr)
val, ok := trie.Lookup(addr)
if ok != tc.expectedOK {
t.Errorf("expected ok=%v, got ok=%v", tc.expectedOK, ok)
}
if val != tc.expectedVal {
t.Errorf("expected value=%q, got %q", tc.expectedVal, val)
}
})
}
}
func TestValueTrie_UpdateValue(t *testing.T) {
trie := NewValue[string]()
prefix := netip.MustParsePrefix("192.168.1.0/24")
addr := netip.MustParseAddr("192.168.1.1")
trie.Insert(prefix, "Initial Value")
val, _ := trie.Lookup(addr)
if val != "Initial Value" {
t.Fatalf("expected initial value, got %q", val)
}
trie.Insert(prefix, "Updated Value")
val, _ = trie.Lookup(addr)
if val != "Updated Value" {
t.Fatalf("expected updated value, got %q", val)
}
}
func TestValueTrie_Delete(t *testing.T) {
trie := NewValue[string]()
trie.Insert(netip.MustParsePrefix("10.0.0.0/8"), "A")
trie.Insert(netip.MustParsePrefix("10.0.0.0/24"), "B")
trie.Insert(netip.MustParsePrefix("10.0.1.0/24"), "C")
t.Run("Delete Non-Existent", func(t *testing.T) {
if trie.Delete(netip.MustParsePrefix("1.2.3.4/32")) {
t.Error("deleted a non-existent prefix")
}
})
t.Run("Delete Leaf Node", func(t *testing.T) {
// Delete 10.0.0.0/24, which is a leaf.
if !trie.Delete(netip.MustParsePrefix("10.0.0.0/24")) {
t.Fatal("failed to delete 10.0.0.0/24")
}
// Lookup should now resolve to the parent 10.0.0.0/8
val, ok := trie.Lookup(netip.MustParseAddr("10.0.0.1"))
if !ok || val != "A" {
t.Errorf("lookup failed after delete, got val=%q ok=%v, want val=\"A\" ok=true", val, ok)
}
})
t.Run("Delete and Merge", func(t *testing.T) {
// Insert a new prefix that causes a split
trie.Insert(netip.MustParsePrefix("10.0.0.0/8"), "A-updated")
// We now have 10.0.0.0/8 and 10.0.1.0/24. Deleting 10.0.1.0/24 should
// cause the split node to be removed and merged back into 10.0.0.0/8.
if !trie.Delete(netip.MustParsePrefix("10.0.1.0/24")) {
t.Fatal("failed to delete 10.0.1.0/24")
}
// A lookup for 10.0.1.1 should now match 10.0.0.0/8
val, ok := trie.Lookup(netip.MustParseAddr("10.0.1.1"))
if !ok || val != "A-updated" {
t.Errorf("lookup failed after merge, got val=%q ok=%v, want val=\"A-updated\" ok=true", val, ok)
}
})
t.Run("Delete Prefix Of Another", func(t *testing.T) {
// Delete 10.0.0.0/8
trie.Insert(netip.MustParsePrefix("10.0.0.0/8"), "A")
trie.Insert(netip.MustParsePrefix("10.1.0.0/16"), "D")
if !trie.Delete(netip.MustParsePrefix("10.0.0.0/8")) {
t.Fatal("failed to delete 10.0.0.0/8")
}
// Lookup for 10.0.0.1 should now fail (no default route)
_, ok := trie.Lookup(netip.MustParseAddr("10.0.0.1"))
if ok {
t.Error("lookup for 10.0.0.1 should have failed")
}
// Lookup for 10.1.0.1 should still succeed
val, ok := trie.Lookup(netip.MustParseAddr("10.1.0.1"))
if !ok || val != "D" {
t.Error("lookup for 10.1.0.1 failed unexpectedly")
}
})
}
func TestValueTrie_Contains(t *testing.T) {
trie := NewValue[string]()
prefixes := []string{
"10.0.0.0/8",
"192.168.1.0/24",
"192.168.1.200/32",
"2001:db8::/32",
"2001:db8:acad:1::1/128",
}
for _, s := range prefixes {
trie.Insert(netip.MustParsePrefix(s), "present")
}
t.Run("ContainsPrefix", func(t *testing.T) {
testCases := []struct {
prefix string
want bool
}{
{"10.0.0.0/8", true},
{"192.168.1.200/32", true},
{"2001:db8::/32", true},
{"2001:db8:acad:1::1/128", true},
{"10.0.0.0/9", false}, // shorter parent, but not exact
{"192.168.1.0/25", false}, // non-existent child
{"172.16.0.0/12", false}, // completely different prefix
{"2001:db8::/48", false}, // non-existent child
}
for _, tc := range testCases {
p := netip.MustParsePrefix(tc.prefix)
if got := trie.ContainsPrefix(p); got != tc.want {
t.Errorf("ContainsPrefix(%q) = %v, want %v", tc.prefix, got, tc.want)
}
}
})
t.Run("Contains", func(t *testing.T) {
testCases := []struct {
addr string
want bool
}{
{"192.168.1.200", true},
{"2001:db8:acad:1::1", true},
{"192.168.1.201", false}, // In /24 range, but not a /32 host route
{"10.0.0.1", false}, // In /8 range, but not a /32 host route
}
for _, tc := range testCases {
a := netip.MustParseAddr(tc.addr)
if got := trie.Contains(a); got != tc.want {
t.Errorf("Contains(%q) = %v, want %v", tc.addr, got, tc.want)
}
}
})
}
func TestValueTrie_Walk(t *testing.T) {
trie := NewValue[string]()
prefixes := map[string]string{
"10.0.0.0/8": "A",
"192.168.1.0/24": "B",
"2001:db8::/32": "C",
"172.16.0.0/12": "D",
"2001:db8:acad::/48": "E",
}
for s, v := range prefixes {
trie.Insert(netip.MustParsePrefix(s), v)
}
t.Run("Walk all prefixes", func(t *testing.T) {
walked := make(map[string]string)
trie.Walk(func(p netip.Prefix, v string) bool {
walked[p.String()] = v
return true
})
if !maps.Equal(walked, prefixes) {
t.Errorf("walked prefixes mismatch:\nexpected: %v\ngot: %v", prefixes, walked)
}
})
t.Run("Stop walk early", func(t *testing.T) {
count := 0
trie.Walk(func(p netip.Prefix, v string) bool {
count++
return count < 3 // Stop after visiting 3 prefixes
})
if count != 3 {
t.Errorf("expected walk to stop after 3 prefixes, but it visited %d", count)
}
})
t.Run("Stop walk between families", func(t *testing.T) {
stopTrie := NewValue[int]()
stopTrie.Insert(netip.MustParsePrefix("10.0.0.0/8"), 1)
stopTrie.Insert(netip.MustParsePrefix("2001:db8::/32"), 2)
count := 0
stopTrie.Walk(func(p netip.Prefix, v int) bool {
count++
return false
})
if count != 1 {
t.Errorf("expected walk to stop after 1 prefix, but it visited %d", count)
}
})
}
func TestValueTrie_Merge(t *testing.T) {
trieA := NewValue[string]()
trieA.Insert(netip.MustParsePrefix("10.0.0.0/8"), "net10")
trieA.Insert(netip.MustParsePrefix("192.168.1.0/24"), "lan_A")
trieA.Insert(netip.MustParsePrefix("2001:db8::/32"), "v6_A")
trieB := NewValue[string]()
trieB.Insert(netip.MustParsePrefix("10.1.0.0/16"), "net10_subnet")
trieB.Insert(netip.MustParsePrefix("192.168.1.0/24"), "lan_B_override") // Overlap
trieB.Insert(netip.MustParsePrefix("172.16.0.0/12"), "corp")
trieB.Insert(netip.MustParsePrefix("2001:db8:acad::/48"), "v6_B")
trieA.Merge(trieB)
expected := map[string]string{
"10.0.0.0/8": "net10",
"192.168.1.0/24": "lan_B_override", // Value should be from trieB
"2001:db8::/32": "v6_A",
"10.1.0.0/16": "net10_subnet",
"172.16.0.0/12": "corp",
"2001:db8:acad::/48": "v6_B",
}
actual := make(map[string]string)
trieA.Walk(func(p netip.Prefix, v string) bool {
actual[p.String()] = v
return true
})
if !reflect.DeepEqual(expected, actual) {
t.Errorf("Merge result incorrect.\nExpected: %v\nGot: %v", expected, actual)
}
// Verify trieB is unchanged by collecting its prefixes
bPrefixes := []string{}
trieB.Walk(func(p netip.Prefix, v string) bool {
bPrefixes = append(bPrefixes, p.String())
return true
})
expectedBPrefixes := []string{"10.1.0.0/16", "172.16.0.0/12", "192.168.1.0/24", "2001:db8:acad::/48"}
slices.Sort(bPrefixes)
slices.Sort(expectedBPrefixes)
if !reflect.DeepEqual(bPrefixes, expectedBPrefixes) {
t.Errorf("trieB was modified during merge.\nExpected: %v\nGot: %v", expectedBPrefixes, bPrefixes)
}
}