Better trie implementations
This commit is contained in:
62
dataset/dnstrie/name.go
Normal file
62
dataset/dnstrie/name.go
Normal file
@@ -0,0 +1,62 @@
|
||||
package dnstrie
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"unicode"
|
||||
)
|
||||
|
||||
// isValidDomainName validates if the given string is a valid DNS hostname.
|
||||
// A valid hostname consists of a series of labels separated by dots.
|
||||
// Each label must:
|
||||
// - Be between 1 and 63 characters long.
|
||||
// - Consist only of ASCII letters ('a'-'z', 'A'-'Z'), digits ('0'-'9'), and hyphens ('-').
|
||||
// - Not start or end with a hyphen.
|
||||
// The total length of the hostname (including dots) must not exceed 253 characters.
|
||||
func isValidDomainName(host string) bool {
|
||||
// 1. Check total length. The maximum length of a full hostname is 253 characters.
|
||||
if len(host) > 253 {
|
||||
return false
|
||||
}
|
||||
|
||||
// An empty string is not a valid hostname.
|
||||
if host == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
// 2. Handle optional trailing dot for FQDNs.
|
||||
// If the hostname ends with a dot, we remove it for validation purposes.
|
||||
if strings.HasSuffix(host, ".") {
|
||||
host = host[:len(host)-1]
|
||||
}
|
||||
|
||||
// After removing a potential trailing dot, the string might be empty.
|
||||
if host == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
// 3. Split the hostname into labels.
|
||||
labels := strings.Split(host, ".")
|
||||
|
||||
// 4. Validate each label.
|
||||
for _, label := range labels {
|
||||
// a. Check label length (1 to 63 characters).
|
||||
if len(label) < 1 || len(label) > 63 {
|
||||
return false
|
||||
}
|
||||
|
||||
// b. Check if label starts or ends with a hyphen.
|
||||
if strings.HasPrefix(label, "-") || strings.HasSuffix(label, "-") {
|
||||
return false
|
||||
}
|
||||
|
||||
// c. Check for allowed characters in the label.
|
||||
for _, char := range label {
|
||||
if !unicode.IsLetter(char) && !unicode.IsDigit(char) && char != '-' {
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If all checks pass, the hostname is valid.
|
||||
return true
|
||||
}
|
42
dataset/dnstrie/name_test.go
Normal file
42
dataset/dnstrie/name_test.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package dnstrie
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestIsValidDomainName(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
host string
|
||||
expected bool
|
||||
}{
|
||||
{"Valid Hostname", "example.com", true},
|
||||
{"Valid with Subdomain", "sub.domain.co.uk", true},
|
||||
{"Valid with Hyphen", "my-host-name.org", true},
|
||||
{"Valid with Digits", "app1.server2.net", true},
|
||||
{"Valid FQDN", "example.com.", true},
|
||||
{"Valid Single Label", "localhost", true},
|
||||
{"Valid: Long but within limits", strings.Repeat("a", 63) + "." + strings.Repeat("b", 63) + ".com", true},
|
||||
{"Invalid: Label Too Long", strings.Repeat("a", 64) + ".com", false},
|
||||
{"Invalid: Total Length Too Long", strings.Repeat("a", 60) + "." + strings.Repeat("b", 60) + "." + strings.Repeat("c", 60) + "." + strings.Repeat("d", 60) + "." + strings.Repeat("e", 60) + ".com", false},
|
||||
{"Invalid: Starts with Hyphen", "-invalid.com", false},
|
||||
{"Invalid: Ends with Hyphen", "invalid-.com", false},
|
||||
{"Invalid: Contains Underscore", "my_host.com", false},
|
||||
{"Invalid: Contains Space", "my host.com", false},
|
||||
{"Invalid: Double Dot", "example..com", false},
|
||||
{"Invalid: Starts with Dot", ".example.com", false},
|
||||
{"Invalid: Empty Label", "sub..domain.com", false},
|
||||
{"Invalid: Just a Dot", ".", false},
|
||||
{"Invalid: Empty String", "", false},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
got := isValidDomainName(test.host)
|
||||
if got != test.expected {
|
||||
t.Errorf("isValidDomainName(%q) = %v; want %v", test.host, got, test.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
139
dataset/dnstrie/trie.go
Normal file
139
dataset/dnstrie/trie.go
Normal file
@@ -0,0 +1,139 @@
|
||||
package dnstrie
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// Node represents a single node in the prefix trie.
|
||||
type Node struct {
|
||||
// children maps the next label (e.g., "example" in "www.example.com")
|
||||
// to the next node in the trie.
|
||||
children map[string]*Node
|
||||
// isEndOfDomain marks if this node represents the end of a valid domain.
|
||||
isEndOfDomain bool
|
||||
}
|
||||
|
||||
// Trie holds the root of the domain prefix trie.
|
||||
type Trie struct {
|
||||
root *Node
|
||||
}
|
||||
|
||||
// New creates and initializes a new Trie.
|
||||
func New() *Trie {
|
||||
return &Trie{
|
||||
root: &Node{
|
||||
children: make(map[string]*Node),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Insert adds a domain name to the trie.
|
||||
// It canonicalizes the domain name before insertion.
|
||||
func (t *Trie) Insert(domain string) error {
|
||||
// Ensure the string is a valid domain name.
|
||||
if !isValidDomainName(domain) {
|
||||
return fmt.Errorf("'%s' is not a valid domain name", domain)
|
||||
}
|
||||
|
||||
// Canonicalize the domain name (lowercase, ensure trailing dot).
|
||||
canonicalDomain := dns.CanonicalName(domain)
|
||||
|
||||
// Split the domain into label offsets. This avoids allocating new strings for labels.
|
||||
// For "www.example.com.", this returns []int{0, 4, 12}
|
||||
offsets := dns.Split(canonicalDomain)
|
||||
if len(offsets) == 0 {
|
||||
return fmt.Errorf("could not split domain name: %s", domain)
|
||||
}
|
||||
|
||||
currentNode := t.root
|
||||
// Iterate through labels from TLD to the most specific label (right to left).
|
||||
for i := len(offsets) - 1; i >= 0; i-- {
|
||||
start := offsets[i]
|
||||
var end int
|
||||
if i == len(offsets)-1 {
|
||||
// Last label, from its start to the end of the string, excluding the final dot.
|
||||
end = len(canonicalDomain) - 1
|
||||
} else {
|
||||
// Intermediate label, from its start to the character before the next label's starting dot.
|
||||
end = offsets[i+1] - 1
|
||||
}
|
||||
label := canonicalDomain[start:end]
|
||||
|
||||
if _, exists := currentNode.children[label]; !exists {
|
||||
// If the label does not exist in the children map, create a new node.
|
||||
currentNode.children[label] = &Node{children: make(map[string]*Node)}
|
||||
}
|
||||
// Move to the next node.
|
||||
currentNode = currentNode.children[label]
|
||||
}
|
||||
|
||||
// Mark the final node as the end of a domain.
|
||||
currentNode.isEndOfDomain = true
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Contains checks if a domain name exists in the trie.
|
||||
func (t *Trie) Contains(domain string) bool {
|
||||
if !isValidDomainName(domain) {
|
||||
return false
|
||||
}
|
||||
|
||||
canonicalDomain := dns.CanonicalName(domain)
|
||||
offsets := dns.Split(canonicalDomain)
|
||||
if len(offsets) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
currentNode := t.root
|
||||
for i := len(offsets) - 1; i >= 0; i-- {
|
||||
start := offsets[i]
|
||||
var end int
|
||||
if i == len(offsets)-1 {
|
||||
end = len(canonicalDomain) - 1
|
||||
} else {
|
||||
end = offsets[i+1] - 1
|
||||
}
|
||||
label := canonicalDomain[start:end]
|
||||
|
||||
nextNode, exists := currentNode.children[label]
|
||||
if !exists {
|
||||
return false
|
||||
}
|
||||
currentNode = nextNode
|
||||
}
|
||||
|
||||
return currentNode.isEndOfDomain
|
||||
}
|
||||
|
||||
// Merge combines another trie into the current one.
|
||||
//
|
||||
// Domains from the 'other' trie will be added to the current trie.
|
||||
func (t *Trie) Merge(other *Trie) {
|
||||
if other == nil || other.root == nil {
|
||||
return
|
||||
}
|
||||
mergeNodes(t.root, other.root)
|
||||
}
|
||||
|
||||
// mergeNodes is a recursive helper function to merge nodes from another trie.
|
||||
func mergeNodes(localNode, otherNode *Node) {
|
||||
// If the other node marks the end of a domain, the local node should too.
|
||||
if otherNode.isEndOfDomain {
|
||||
localNode.isEndOfDomain = true
|
||||
}
|
||||
|
||||
// Iterate over the children of the other node and merge them.
|
||||
for label, otherChildNode := range otherNode.children {
|
||||
localChildNode, exists := localNode.children[label]
|
||||
if !exists {
|
||||
// If the child doesn't exist in the local trie, just attach the other's child.
|
||||
localNode.children[label] = otherChildNode
|
||||
} else {
|
||||
// If the child already exists, recurse to merge them.
|
||||
mergeNodes(localChildNode, otherChildNode)
|
||||
}
|
||||
}
|
||||
}
|
117
dataset/dnstrie/trie_test.go
Normal file
117
dataset/dnstrie/trie_test.go
Normal file
@@ -0,0 +1,117 @@
|
||||
package dnstrie
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestTrie(t *testing.T) {
|
||||
t.Run("InsertAndContains", func(t *testing.T) {
|
||||
trie := New()
|
||||
domain := "www.example.com"
|
||||
err := trie.Insert(domain)
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error on insert, got %v", err)
|
||||
}
|
||||
|
||||
if !trie.Contains(domain) {
|
||||
t.Errorf("Expected Contains('%s') to be true, but it was false", domain)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ContainsNotFound", func(t *testing.T) {
|
||||
trie := New()
|
||||
err := trie.Insert("example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("Insert failed: %v", err)
|
||||
}
|
||||
|
||||
if trie.Contains("nonexistent.com") {
|
||||
t.Error("Expected not to find domain 'nonexistent.com', but did")
|
||||
}
|
||||
|
||||
// Check for a path that exists but is not a terminal node
|
||||
if trie.Contains("com") {
|
||||
t.Error("Expected not to find non-terminal path 'com', but did")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("InsertInvalidDomain", func(t *testing.T) {
|
||||
trie := New()
|
||||
err := trie.Insert("not-a-valid-domain-")
|
||||
if err == nil {
|
||||
t.Error("Expected an error when inserting an invalid domain, but got nil")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Canonicalization", func(t *testing.T) {
|
||||
trie := New()
|
||||
// Insert lowercase with trailing dot
|
||||
err := trie.Insert("case.example.org.")
|
||||
if err != nil {
|
||||
t.Fatalf("Insert failed: %v", err)
|
||||
}
|
||||
|
||||
// Check contains with uppercase without trailing dot
|
||||
if !trie.Contains("CASE.EXAMPLE.ORG") {
|
||||
t.Fatal("Failed to find domain with different case and no trailing dot")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("MultipleInsertions", func(t *testing.T) {
|
||||
trie := New()
|
||||
domains := []string{
|
||||
"example.com",
|
||||
"www.example.com",
|
||||
"api.example.com",
|
||||
"google.com",
|
||||
}
|
||||
for _, domain := range domains {
|
||||
if err := trie.Insert(domain); err != nil {
|
||||
t.Fatalf("Insert failed for %s: %v", domain, err)
|
||||
}
|
||||
}
|
||||
|
||||
for _, domain := range domains {
|
||||
if !trie.Contains(domain) {
|
||||
t.Errorf("Expected to find %s, but did not", domain)
|
||||
}
|
||||
}
|
||||
|
||||
if trie.Contains("ftp.example.com") {
|
||||
t.Error("Found domain 'ftp.example.com' which was not inserted")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("MergeTries", func(t *testing.T) {
|
||||
trie1 := New()
|
||||
trie1.Insert("example.com")
|
||||
trie1.Insert("sub.example.com")
|
||||
|
||||
trie2 := New()
|
||||
trie2.Insert("google.com")
|
||||
trie2.Insert("sub.example.com") // Overlapping domain
|
||||
trie2.Insert("another.net")
|
||||
|
||||
trie1.Merge(trie2)
|
||||
|
||||
// Test domains from both tries
|
||||
if !trie1.Contains("example.com") {
|
||||
t.Error("Merge failed: trie1 should contain 'example.com'")
|
||||
}
|
||||
if !trie1.Contains("google.com") {
|
||||
t.Error("Merge failed: trie1 should contain 'google.com'")
|
||||
}
|
||||
if !trie1.Contains("sub.example.com") {
|
||||
t.Error("Merge failed: trie1 should contain overlapping 'sub.example.com'")
|
||||
}
|
||||
if !trie1.Contains("another.net") {
|
||||
t.Error("Merge failed: trie1 should contain 'another.net'")
|
||||
}
|
||||
|
||||
// Ensure trie2 is not modified
|
||||
if !trie2.Contains("google.com") {
|
||||
t.Error("Source trie (trie2) should not be modified after merge")
|
||||
}
|
||||
if trie2.Contains("example.com") {
|
||||
t.Error("Source trie (trie2) was modified after merge")
|
||||
}
|
||||
})
|
||||
}
|
186
dataset/dnstrie/valuetrie.go
Normal file
186
dataset/dnstrie/valuetrie.go
Normal file
@@ -0,0 +1,186 @@
|
||||
package dnstrie
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// ValueNode represents a single node in the prefix trie, using generics for the value type.
|
||||
type ValueNode[T any] struct {
|
||||
// children maps the next label (e.g., "example" in "www.example.com")
|
||||
// to the next node in the trie.
|
||||
children map[string]*ValueNode[T]
|
||||
|
||||
// value is the data of generic type T associated with the domain name ending at this node.
|
||||
value T
|
||||
|
||||
// isEndOfDomain marks if this node represents the end of a valid domain.
|
||||
isEndOfDomain bool
|
||||
}
|
||||
|
||||
// ValueTrie holds the root of the domain prefix trie, using generics.
|
||||
type ValueTrie[T any] struct {
|
||||
root *ValueNode[T]
|
||||
}
|
||||
|
||||
// NewValue creates and initializes a new ValueTrie.
|
||||
func NewValue[T any]() *ValueTrie[T] {
|
||||
return &ValueTrie[T]{
|
||||
root: &ValueNode[T]{
|
||||
children: make(map[string]*ValueNode[T]),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Insert adds a domain name and its associated generic value to the trie.
|
||||
// It canonicalizes the domain name before insertion.
|
||||
func (t *ValueTrie[T]) Insert(domain string, value T) error {
|
||||
// Ensure the string is a valid domain name.
|
||||
if !isValidDomainName(domain) {
|
||||
return fmt.Errorf("'%s' is not a valid domain name", domain)
|
||||
}
|
||||
|
||||
// Canonicalize the domain name (lowercase, ensure trailing dot).
|
||||
canonicalDomain := dns.CanonicalName(domain)
|
||||
|
||||
// Split the domain into label offsets. This avoids allocating new strings for labels.
|
||||
// For "www.example.com.", this returns []int{0, 4, 12}
|
||||
offsets := dns.Split(canonicalDomain)
|
||||
if len(offsets) == 0 {
|
||||
return fmt.Errorf("could not split domain name: %s", domain)
|
||||
}
|
||||
|
||||
currentNode := t.root
|
||||
// Iterate through labels from TLD to the most specific label (right to left).
|
||||
for i := len(offsets) - 1; i >= 0; i-- {
|
||||
start := offsets[i]
|
||||
var end int
|
||||
if i == len(offsets)-1 {
|
||||
// Last label, from its start to the end of the string, excluding the final dot.
|
||||
end = len(canonicalDomain) - 1
|
||||
} else {
|
||||
// Intermediate label, from its start to the character before the next label's starting dot.
|
||||
end = offsets[i+1] - 1
|
||||
}
|
||||
label := canonicalDomain[start:end]
|
||||
|
||||
if _, exists := currentNode.children[label]; !exists {
|
||||
// If the label does not exist in the children map, create a new node.
|
||||
currentNode.children[label] = &ValueNode[T]{children: make(map[string]*ValueNode[T])}
|
||||
}
|
||||
// Move to the next node.
|
||||
currentNode = currentNode.children[label]
|
||||
}
|
||||
|
||||
// Mark the final node as the end of a domain and store the value.
|
||||
currentNode.isEndOfDomain = true
|
||||
currentNode.value = value
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Contains checks if a domain name exists in the trie without returning its value.
|
||||
func (t *ValueTrie[T]) Contains(domain string) bool {
|
||||
if !isValidDomainName(domain) {
|
||||
return false
|
||||
}
|
||||
|
||||
canonicalDomain := dns.CanonicalName(domain)
|
||||
offsets := dns.Split(canonicalDomain)
|
||||
if len(offsets) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
currentNode := t.root
|
||||
for i := len(offsets) - 1; i >= 0; i-- {
|
||||
start := offsets[i]
|
||||
var end int
|
||||
if i == len(offsets)-1 {
|
||||
end = len(canonicalDomain) - 1
|
||||
} else {
|
||||
end = offsets[i+1] - 1
|
||||
}
|
||||
label := canonicalDomain[start:end]
|
||||
|
||||
nextNode, exists := currentNode.children[label]
|
||||
if !exists {
|
||||
return false
|
||||
}
|
||||
currentNode = nextNode
|
||||
}
|
||||
|
||||
return currentNode.isEndOfDomain
|
||||
}
|
||||
|
||||
// Search looks for a domain name in the trie.
|
||||
// It returns the associated generic value and a boolean indicating if the domain was found.
|
||||
func (t *ValueTrie[T]) Search(domain string) (T, bool) {
|
||||
var zero T // The zero value for the generic type T.
|
||||
if !isValidDomainName(domain) {
|
||||
return zero, false
|
||||
}
|
||||
|
||||
canonicalDomain := dns.CanonicalName(domain)
|
||||
offsets := dns.Split(canonicalDomain)
|
||||
if len(offsets) == 0 {
|
||||
return zero, false
|
||||
}
|
||||
|
||||
currentNode := t.root
|
||||
for i := len(offsets) - 1; i >= 0; i-- {
|
||||
start := offsets[i]
|
||||
var end int
|
||||
if i == len(offsets)-1 {
|
||||
end = len(canonicalDomain) - 1
|
||||
} else {
|
||||
end = offsets[i+1] - 1
|
||||
}
|
||||
label := canonicalDomain[start:end]
|
||||
|
||||
nextNode, exists := currentNode.children[label]
|
||||
if !exists {
|
||||
// A label in the path was not found, so the domain doesn't exist.
|
||||
return zero, false
|
||||
}
|
||||
currentNode = nextNode
|
||||
}
|
||||
|
||||
// The full path was found, but we must also check if it's a terminal node.
|
||||
// This prevents matching "example.com" if only "www.example.com" was inserted.
|
||||
if currentNode.isEndOfDomain {
|
||||
return currentNode.value, true
|
||||
}
|
||||
|
||||
return zero, false
|
||||
}
|
||||
|
||||
// Merge combines another trie into the current one.
|
||||
//
|
||||
// If a domain exists in both tries, the value from the 'other' trie is used.
|
||||
func (t *ValueTrie[T]) Merge(other *ValueTrie[T]) {
|
||||
if other == nil || other.root == nil {
|
||||
return
|
||||
}
|
||||
mergeValueNodes(t.root, other.root)
|
||||
}
|
||||
|
||||
// mergeNodes is a recursive helper function to merge nodes from another trie.
|
||||
func mergeValueNodes[T any](localNode, otherNode *ValueNode[T]) {
|
||||
// The other node value overwrites the local one.
|
||||
if otherNode.isEndOfDomain {
|
||||
localNode.value = otherNode.value
|
||||
}
|
||||
|
||||
// Iterate over the children of the other node and merge them.
|
||||
for label, otherChildNode := range otherNode.children {
|
||||
localChildNode, exists := localNode.children[label]
|
||||
if !exists {
|
||||
// If the child doesn't exist in the local trie, attach the other's child.
|
||||
localNode.children[label] = otherChildNode
|
||||
} else {
|
||||
// If the child already exists, recurse to merge them.
|
||||
mergeValueNodes(localChildNode, otherChildNode)
|
||||
}
|
||||
}
|
||||
}
|
182
dataset/dnstrie/valuetrie_test.go
Normal file
182
dataset/dnstrie/valuetrie_test.go
Normal file
@@ -0,0 +1,182 @@
|
||||
package dnstrie
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestValueTrie(t *testing.T) {
|
||||
t.Run("InsertAndSearchStrings", func(t *testing.T) {
|
||||
trie := NewValue[string]()
|
||||
domain := "www.example.com"
|
||||
value := "192.0.2.1"
|
||||
err := trie.Insert(domain, value)
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error on insert, got %v", err)
|
||||
}
|
||||
|
||||
foundValue, found := trie.Search(domain)
|
||||
if !found {
|
||||
t.Fatalf("Expected to find domain '%s', but did not", domain)
|
||||
}
|
||||
if foundValue != value {
|
||||
t.Errorf("Expected value '%s', got '%s'", value, foundValue)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("SearchNotFound", func(t *testing.T) {
|
||||
trie := NewValue[string]()
|
||||
err := trie.Insert("example.com", "value")
|
||||
if err != nil {
|
||||
t.Fatalf("Insert failed: %v", err)
|
||||
}
|
||||
|
||||
_, found := trie.Search("nonexistent.com")
|
||||
if found {
|
||||
t.Error("Expected not to find domain 'nonexistent.com', but did")
|
||||
}
|
||||
|
||||
// Search for a path that exists but is not a terminal node
|
||||
_, found = trie.Search("com")
|
||||
if found {
|
||||
t.Error("Expected not to find non-terminal path 'com', but did")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("InsertInvalidDomain", func(t *testing.T) {
|
||||
trie := NewValue[string]()
|
||||
err := trie.Insert("not-a-valid-domain-", "value")
|
||||
if err == nil {
|
||||
t.Error("Expected an error when inserting an invalid domain, but got nil")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("OverwriteValue", func(t *testing.T) {
|
||||
trie := NewValue[string]()
|
||||
domain := "overwrite.com"
|
||||
initialValue := "first"
|
||||
overwriteValue := "second"
|
||||
|
||||
err := trie.Insert(domain, initialValue)
|
||||
if err != nil {
|
||||
t.Fatalf("Initial insert failed: %v", err)
|
||||
}
|
||||
err = trie.Insert(domain, overwriteValue)
|
||||
if err != nil {
|
||||
t.Fatalf("Overwrite insert failed: %v", err)
|
||||
}
|
||||
|
||||
foundValue, found := trie.Search(domain)
|
||||
if !found {
|
||||
t.Fatalf("Expected to find domain '%s' after overwrite", domain)
|
||||
}
|
||||
if foundValue != overwriteValue {
|
||||
t.Errorf("Expected overwritten value '%s', got '%s'", overwriteValue, foundValue)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Canonicalization", func(t *testing.T) {
|
||||
trie := NewValue[string]()
|
||||
value := "canonical"
|
||||
// Insert lowercase with trailing dot
|
||||
err := trie.Insert("case.example.org.", value)
|
||||
if err != nil {
|
||||
t.Fatalf("Insert failed: %v", err)
|
||||
}
|
||||
|
||||
// Search uppercase without trailing dot
|
||||
foundValue, found := trie.Search("CASE.EXAMPLE.ORG")
|
||||
if !found {
|
||||
t.Fatal("Failed to find domain with different case and no trailing dot")
|
||||
}
|
||||
if foundValue != value {
|
||||
t.Errorf("Expected value '%s' for canonical search, got '%s'", value, foundValue)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("InsertAndSearchIntegers", func(t *testing.T) {
|
||||
trie := NewValue[int]()
|
||||
domain := "int.example.com"
|
||||
value := 12345
|
||||
|
||||
err := trie.Insert(domain, value)
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error on insert for int trie, got %v", err)
|
||||
}
|
||||
|
||||
foundValue, found := trie.Search(domain)
|
||||
if !found {
|
||||
t.Fatalf("Expected to find domain '%s' in int trie, but did not", domain)
|
||||
}
|
||||
if foundValue != value {
|
||||
t.Errorf("Expected int value %d, got %d", value, foundValue)
|
||||
}
|
||||
|
||||
// Search for a non-existent domain in the int trie
|
||||
_, found = trie.Search("nonexistent.int.example.com")
|
||||
if found {
|
||||
t.Error("Found a nonexistent domain in the int trie")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Contains", func(t *testing.T) {
|
||||
trie := NewValue[string]()
|
||||
domain := "contains.example.com"
|
||||
err := trie.Insert(domain, "some-value")
|
||||
if err != nil {
|
||||
t.Fatalf("Insert failed: %v", err)
|
||||
}
|
||||
|
||||
if !trie.Contains(domain) {
|
||||
t.Errorf("Expected Contains('%s') to be true, but it was false", domain)
|
||||
}
|
||||
|
||||
if trie.Contains("nonexistent." + domain) {
|
||||
t.Errorf("Expected Contains for nonexistent domain to be false, but it was true")
|
||||
}
|
||||
|
||||
if trie.Contains("example.com") {
|
||||
t.Error("Expected Contains for a non-terminal path to be false, but it was true")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("MergeTries", func(t *testing.T) {
|
||||
trie1 := NewValue[int]()
|
||||
trie1.Insert("example.com", 100)
|
||||
trie1.Insert("sub.example.com", 200)
|
||||
|
||||
trie2 := NewValue[int]()
|
||||
trie2.Insert("google.com", 300)
|
||||
trie2.Insert("sub.example.com", 999) // Overlapping domain, new value
|
||||
trie2.Insert("another.net", 400)
|
||||
|
||||
trie1.Merge(trie2)
|
||||
|
||||
// Test domains from both tries are present
|
||||
if !trie1.Contains("example.com") {
|
||||
t.Error("Merge failed: trie1 should contain 'example.com'")
|
||||
}
|
||||
if !trie1.Contains("google.com") {
|
||||
t.Error("Merge failed: trie1 should contain 'google.com'")
|
||||
}
|
||||
if !trie1.Contains("another.net") {
|
||||
t.Error("Merge failed: trie1 should contain 'another.net'")
|
||||
}
|
||||
|
||||
// Test that overlapping value was updated from trie2
|
||||
val, found := trie1.Search("sub.example.com")
|
||||
if !found || val != 999 {
|
||||
t.Errorf("Expected value for overlapping domain to be 999, but got %d", val)
|
||||
}
|
||||
|
||||
// Test that non-overlapping value from trie1 is intact
|
||||
val, found = trie1.Search("example.com")
|
||||
if !found || val != 100 {
|
||||
t.Errorf("Expected value for 'example.com' to be 100, but got %d", val)
|
||||
}
|
||||
|
||||
// Ensure trie2 is not modified
|
||||
if _, found := trie2.Search("example.com"); found {
|
||||
t.Error("Source trie (trie2) was modified after merge")
|
||||
}
|
||||
})
|
||||
}
|
@@ -1,11 +1,29 @@
|
||||
package dataset
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"git.maze.io/maze/styx/dataset/dnstrie"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
type DomainTrie struct {
|
||||
*dnstrie.ValueTrie[bool]
|
||||
}
|
||||
|
||||
func NewDomainTrie(permit bool, domains ...string) (*DomainTrie, error) {
|
||||
trie := &DomainTrie{
|
||||
ValueTrie: dnstrie.NewValue[bool](),
|
||||
}
|
||||
for _, domain := range domains {
|
||||
if err := trie.Insert(domain, permit); err != nil {
|
||||
return nil, fmt.Errorf("dataset: error inserting %s: %w", domain, err)
|
||||
}
|
||||
}
|
||||
return trie, nil
|
||||
}
|
||||
|
||||
type DomainTree struct {
|
||||
root *domainTreeNode
|
||||
}
|
||||
|
253
dataset/nettrie/trie.go
Normal file
253
dataset/nettrie/trie.go
Normal file
@@ -0,0 +1,253 @@
|
||||
package nettrie
|
||||
|
||||
import "net/netip"
|
||||
|
||||
// Node represents a node in the path-compressed trie.
|
||||
// Each node represents a prefix and can have up to two children.
|
||||
type Node struct {
|
||||
children [2]*Node
|
||||
|
||||
// prefix is the full prefix represented by the path to this node.
|
||||
prefix netip.Prefix
|
||||
|
||||
// isValue marks if this node represents an explicitly inserted prefix.
|
||||
isValue bool
|
||||
}
|
||||
|
||||
// Trie is a path-compressed radix trie that stores network prefixes.
|
||||
type Trie struct {
|
||||
rootV4 *Node
|
||||
rootV6 *Node
|
||||
}
|
||||
|
||||
// New creates and initializes a new Trie.
|
||||
func New() *Trie {
|
||||
return &Trie{}
|
||||
}
|
||||
|
||||
// Insert adds a prefix to the trie.
|
||||
func (t *Trie) Insert(p netip.Prefix) {
|
||||
p = p.Masked()
|
||||
addr := p.Addr()
|
||||
|
||||
if addr.Is4() {
|
||||
t.rootV4 = t.insert(t.rootV4, p)
|
||||
} else {
|
||||
t.rootV6 = t.insert(t.rootV6, p)
|
||||
}
|
||||
}
|
||||
|
||||
// insert is the recursive helper for inserting a prefix into the trie.
|
||||
func (t *Trie) insert(node *Node, p netip.Prefix) *Node {
|
||||
if node == nil {
|
||||
return &Node{prefix: p, isValue: true}
|
||||
}
|
||||
|
||||
addr := p.Addr()
|
||||
commonLen := commonPrefixLen(addr, node.prefix.Addr())
|
||||
pBits := p.Bits()
|
||||
nodeBits := node.prefix.Bits()
|
||||
|
||||
if commonLen > pBits {
|
||||
commonLen = pBits
|
||||
}
|
||||
if commonLen > nodeBits {
|
||||
commonLen = nodeBits
|
||||
}
|
||||
|
||||
if commonLen == nodeBits && commonLen == pBits {
|
||||
// Exact match, mark the node as a value node.
|
||||
node.isValue = true
|
||||
return node
|
||||
}
|
||||
|
||||
if commonLen < nodeBits {
|
||||
// The new prefix diverges from the current node's prefix.
|
||||
// We must split the current node.
|
||||
commonP, _ := node.prefix.Addr().Prefix(commonLen)
|
||||
splitNode := &Node{prefix: commonP}
|
||||
|
||||
// The existing node becomes a child of the new split node.
|
||||
bit := getBit(node.prefix.Addr(), commonLen)
|
||||
splitNode.children[bit] = node
|
||||
|
||||
if commonLen == pBits {
|
||||
// The inserted prefix is a prefix of the node's original prefix.
|
||||
// The new split node represents the inserted prefix.
|
||||
splitNode.isValue = true
|
||||
} else {
|
||||
// The two prefixes diverge. Create a new child for the new prefix.
|
||||
bit := getBit(addr, commonLen)
|
||||
splitNode.children[bit] = &Node{prefix: p, isValue: true}
|
||||
}
|
||||
return splitNode
|
||||
}
|
||||
|
||||
// commonLen == nodeBits, meaning the current node's prefix is a prefix of the new one.
|
||||
// We need to descend to a child.
|
||||
bit := getBit(addr, commonLen)
|
||||
node.children[bit] = t.insert(node.children[bit], p)
|
||||
return node
|
||||
}
|
||||
|
||||
// Delete removes a prefix from the trie. It returns true if the prefix was found and removed.
|
||||
func (t *Trie) Delete(p netip.Prefix) bool {
|
||||
p = p.Masked()
|
||||
addr := p.Addr()
|
||||
|
||||
var changed bool
|
||||
if addr.Is4() {
|
||||
t.rootV4, changed = t.delete(t.rootV4, p)
|
||||
} else {
|
||||
t.rootV6, changed = t.delete(t.rootV6, p)
|
||||
}
|
||||
return changed
|
||||
}
|
||||
|
||||
// delete is the recursive helper for removing a prefix from the trie.
|
||||
func (t *Trie) delete(node *Node, p netip.Prefix) (*Node, bool) {
|
||||
if node == nil {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
addr := p.Addr()
|
||||
pBits := p.Bits()
|
||||
nodeBits := node.prefix.Bits()
|
||||
commonLen := commonPrefixLen(addr, node.prefix.Addr())
|
||||
|
||||
// The prefix is not on this path.
|
||||
if commonLen < nodeBits || commonLen < pBits && pBits < nodeBits {
|
||||
return node, false
|
||||
}
|
||||
|
||||
var changed bool
|
||||
if pBits > nodeBits {
|
||||
// The prefix to delete is deeper in the trie. Recurse.
|
||||
bit := getBit(addr, nodeBits)
|
||||
node.children[bit], changed = t.delete(node.children[bit], p)
|
||||
} else if pBits == nodeBits {
|
||||
// This is the node to delete. Unset its value.
|
||||
if !node.isValue {
|
||||
return node, false // Prefix wasn't actually in the trie.
|
||||
}
|
||||
node.isValue = false
|
||||
changed = true
|
||||
} else { // pBits < nodeBits
|
||||
return node, false // Prefix to delete is shorter, so can't be here.
|
||||
}
|
||||
|
||||
if !changed {
|
||||
return node, false
|
||||
}
|
||||
|
||||
// Post-deletion cleanup:
|
||||
// If the node has no value and can be merged with a single child, do so.
|
||||
if !node.isValue {
|
||||
if node.children[0] != nil && node.children[1] == nil {
|
||||
return node.children[0], true
|
||||
}
|
||||
if node.children[0] == nil && node.children[1] != nil {
|
||||
return node.children[1], true
|
||||
}
|
||||
}
|
||||
|
||||
// If the node is now a leaf without a value, it can be removed entirely.
|
||||
if !node.isValue && node.children[0] == nil && node.children[1] == nil {
|
||||
return nil, true
|
||||
}
|
||||
|
||||
return node, true
|
||||
}
|
||||
|
||||
// ContainsPrefix checks if the exact prefix exists in the trie.
|
||||
func (t *Trie) ContainsPrefix(p netip.Prefix) bool {
|
||||
p = p.Masked()
|
||||
addr := p.Addr()
|
||||
pBits := p.Bits()
|
||||
|
||||
node := t.rootV4
|
||||
if addr.Is6() {
|
||||
node = t.rootV6
|
||||
}
|
||||
|
||||
for node != nil {
|
||||
commonLen := commonPrefixLen(addr, node.prefix.Addr())
|
||||
nodeBits := node.prefix.Bits()
|
||||
|
||||
if commonLen < nodeBits {
|
||||
// Path has diverged. The prefix cannot be in this subtree.
|
||||
return false
|
||||
}
|
||||
|
||||
if pBits < nodeBits {
|
||||
// The search prefix is shorter than the node's prefix,
|
||||
// but they share a prefix. e.g. search /16, node is /24.
|
||||
// The /16 is not explicitly in the trie.
|
||||
return false
|
||||
}
|
||||
|
||||
if pBits == nodeBits {
|
||||
// Found a node with the exact same prefix length.
|
||||
// Because we also know commonLen >= nodeBits, the prefixes are identical.
|
||||
return node.isValue
|
||||
}
|
||||
|
||||
// pBits > nodeBits, so we need to go deeper.
|
||||
bit := getBit(addr, nodeBits)
|
||||
node = node.children[bit]
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// Contains checks if the exact IP address exists in the trie as a full-length prefix.
|
||||
func (t *Trie) Contains(addr netip.Addr) bool {
|
||||
prefix := netip.PrefixFrom(addr, addr.BitLen())
|
||||
return t.ContainsPrefix(prefix)
|
||||
}
|
||||
|
||||
// WalkFunc is a function called for each prefix in the trie during a walk.
|
||||
// Returning false from the function will stop the walk.
|
||||
type WalkFunc func(p netip.Prefix) bool
|
||||
|
||||
// walk is the recursive helper for traversing the trie.
|
||||
func walk(node *Node, f WalkFunc) bool {
|
||||
if node == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
if node.isValue {
|
||||
if !f(node.prefix) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
if node.children[0] != nil {
|
||||
if !walk(node.children[0], f) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
if node.children[1] != nil {
|
||||
if !walk(node.children[1], f) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// Walk traverses the trie and calls the given function for each prefix.
|
||||
// If the function returns false, the walk is stopped. The order is not guaranteed.
|
||||
func (t *Trie) Walk(f WalkFunc) {
|
||||
if !walk(t.rootV4, f) {
|
||||
return
|
||||
}
|
||||
walk(t.rootV6, f)
|
||||
}
|
||||
|
||||
// Merge inserts all prefixes from another Trie into this one.
|
||||
func (t *Trie) Merge(other *Trie) {
|
||||
other.Walk(func(p netip.Prefix) bool {
|
||||
t.Insert(p)
|
||||
return true // continue walking
|
||||
})
|
||||
}
|
248
dataset/nettrie/trie_test.go
Normal file
248
dataset/nettrie/trie_test.go
Normal file
@@ -0,0 +1,248 @@
|
||||
package nettrie
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"sort"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestTrie_InsertAndContains_IPv4 tests insertion and verifies presence
|
||||
// using ContainsPrefix and Contains for IPv4.
|
||||
func TestTrie_InsertAndContains_IPv4(t *testing.T) {
|
||||
trie := New()
|
||||
|
||||
prefixes := []string{
|
||||
"0.0.0.0/0",
|
||||
"10.0.0.0/8",
|
||||
"192.168.0.0/16",
|
||||
"192.168.1.0/24",
|
||||
"192.168.1.128/25",
|
||||
"192.168.1.200/32",
|
||||
}
|
||||
for _, s := range prefixes {
|
||||
trie.Insert(netip.MustParsePrefix(s))
|
||||
}
|
||||
|
||||
t.Run("Check existing prefixes", func(t *testing.T) {
|
||||
for _, s := range prefixes {
|
||||
p := netip.MustParsePrefix(s)
|
||||
if !trie.ContainsPrefix(p) {
|
||||
t.Errorf("expected trie to contain prefix %s, but it did not", p)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Check non-existing prefixes", func(t *testing.T) {
|
||||
nonExistentPrefixes := []string{
|
||||
"10.0.0.0/9",
|
||||
"192.168.1.0/23",
|
||||
"1.2.3.4/32",
|
||||
}
|
||||
for _, s := range nonExistentPrefixes {
|
||||
p := netip.MustParsePrefix(s)
|
||||
if trie.ContainsPrefix(p) {
|
||||
t.Errorf("expected trie to not contain prefix %s, but it did", p)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Check host addresses with Contains", func(t *testing.T) {
|
||||
if !trie.Contains(netip.MustParseAddr("192.168.1.200")) {
|
||||
t.Error("expected trie to contain host address 192.168.1.200, but it did not")
|
||||
}
|
||||
if trie.Contains(netip.MustParseAddr("192.168.1.201")) {
|
||||
t.Error("expected trie to not contain host address 192.168.1.201, but it did")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestTrie_InsertAndContains_IPv6 tests insertion and verifies presence
|
||||
// using ContainsPrefix and Contains for IPv6.
|
||||
func TestTrie_InsertAndContains_IPv6(t *testing.T) {
|
||||
trie := New()
|
||||
|
||||
prefixes := []string{
|
||||
"::/0",
|
||||
"2001:db8::/32",
|
||||
"2001:db8:acad::/48",
|
||||
"2001:db8:acad:1::/64",
|
||||
"2001:db8:acad:1::1/128",
|
||||
}
|
||||
for _, s := range prefixes {
|
||||
trie.Insert(netip.MustParsePrefix(s))
|
||||
}
|
||||
|
||||
t.Run("Check existing prefixes", func(t *testing.T) {
|
||||
for _, s := range prefixes {
|
||||
p := netip.MustParsePrefix(s)
|
||||
if !trie.ContainsPrefix(p) {
|
||||
t.Errorf("expected trie to contain prefix %s, but it did not", p)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Check non-existing prefixes", func(t *testing.T) {
|
||||
nonExistentPrefixes := []string{
|
||||
"2001:db8::/33",
|
||||
"2001:db8:acad::/47",
|
||||
"2002::/16",
|
||||
}
|
||||
for _, s := range nonExistentPrefixes {
|
||||
p := netip.MustParsePrefix(s)
|
||||
if trie.ContainsPrefix(p) {
|
||||
t.Errorf("expected trie to not contain prefix %s, but it did", p)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Check host addresses with Contains", func(t *testing.T) {
|
||||
if !trie.Contains(netip.MustParseAddr("2001:db8:acad:1::1")) {
|
||||
t.Error("expected trie to contain host address 2001:db8:acad:1::1, but it did not")
|
||||
}
|
||||
if trie.Contains(netip.MustParseAddr("2001:db8:acad:1::2")) {
|
||||
t.Error("expected trie to not contain host address 2001:db8:acad:1::2, but it did")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestTrie_Delete(t *testing.T) {
|
||||
trie := New()
|
||||
|
||||
p8 := netip.MustParsePrefix("10.0.0.0/8")
|
||||
p24 := netip.MustParsePrefix("10.0.0.0/24")
|
||||
pOther := netip.MustParsePrefix("10.0.1.0/24")
|
||||
|
||||
trie.Insert(p8)
|
||||
trie.Insert(p24)
|
||||
trie.Insert(pOther)
|
||||
|
||||
t.Run("Delete Non-Existent", func(t *testing.T) {
|
||||
if trie.Delete(netip.MustParsePrefix("1.2.3.4/32")) {
|
||||
t.Error("Delete returned true for a non-existent prefix")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Delete Leaf Node", func(t *testing.T) {
|
||||
if !trie.Delete(p24) {
|
||||
t.Fatal("Delete returned false for an existing prefix")
|
||||
}
|
||||
if trie.ContainsPrefix(p24) {
|
||||
t.Error("ContainsPrefix should be false after delete")
|
||||
}
|
||||
if !trie.ContainsPrefix(p8) {
|
||||
t.Error("parent prefix should not be affected by child delete")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Delete Intermediate Node", func(t *testing.T) {
|
||||
// Re-insert p24 so we can delete its parent p8
|
||||
trie.Insert(p24)
|
||||
|
||||
if !trie.Delete(p8) {
|
||||
t.Fatal("Delete returned false for an existing intermediate prefix")
|
||||
}
|
||||
if trie.ContainsPrefix(p8) {
|
||||
t.Error("ContainsPrefix for deleted intermediate node should be false")
|
||||
}
|
||||
if !trie.ContainsPrefix(p24) {
|
||||
t.Error("child prefix should still exist after parent was deleted")
|
||||
}
|
||||
if !trie.ContainsPrefix(pOther) {
|
||||
t.Error("other prefix should still exist after parent was deleted")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestTrie_Walk(t *testing.T) {
|
||||
trie := New()
|
||||
prefixes := []string{
|
||||
"10.0.0.0/8",
|
||||
"192.168.1.0/24",
|
||||
"2001:db8::/32",
|
||||
"172.16.0.0/12",
|
||||
"2001:db8:acad::/48",
|
||||
}
|
||||
for _, s := range prefixes {
|
||||
trie.Insert(netip.MustParsePrefix(s))
|
||||
}
|
||||
|
||||
t.Run("Walk all prefixes", func(t *testing.T) {
|
||||
var walkedPrefixes []string
|
||||
trie.Walk(func(p netip.Prefix) bool {
|
||||
walkedPrefixes = append(walkedPrefixes, p.String())
|
||||
return true
|
||||
})
|
||||
|
||||
if len(walkedPrefixes) != len(prefixes) {
|
||||
t.Fatalf("expected to walk %d prefixes, but got %d", len(prefixes), len(walkedPrefixes))
|
||||
}
|
||||
|
||||
sort.Strings(prefixes)
|
||||
sort.Strings(walkedPrefixes)
|
||||
|
||||
for i := range prefixes {
|
||||
if prefixes[i] != walkedPrefixes[i] {
|
||||
t.Errorf("walked prefixes mismatch: expected %v, got %v", prefixes, walkedPrefixes)
|
||||
break
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Stop walk early", func(t *testing.T) {
|
||||
count := 0
|
||||
trie.Walk(func(p netip.Prefix) bool {
|
||||
count++
|
||||
return count < 3 // Stop after visiting 3 prefixes
|
||||
})
|
||||
|
||||
if count != 3 {
|
||||
t.Errorf("expected walk to stop after 3 prefixes, but it visited %d", count)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestTrie_Merge(t *testing.T) {
|
||||
trieA := New()
|
||||
trieA.Insert(netip.MustParsePrefix("10.0.0.0/8"))
|
||||
trieA.Insert(netip.MustParsePrefix("192.168.1.0/24"))
|
||||
trieA.Insert(netip.MustParsePrefix("2001:db8::/32")) // v6
|
||||
|
||||
trieB := New()
|
||||
trieB.Insert(netip.MustParsePrefix("10.1.0.0/16"))
|
||||
trieB.Insert(netip.MustParsePrefix("192.168.1.0/24")) // Overlap
|
||||
trieB.Insert(netip.MustParsePrefix("172.16.0.0/12"))
|
||||
trieB.Insert(netip.MustParsePrefix("2001:db8:acad::/48")) // v6
|
||||
|
||||
trieA.Merge(trieB)
|
||||
|
||||
expectedPrefixes := []string{
|
||||
// From A
|
||||
"10.0.0.0/8",
|
||||
"192.168.1.0/24",
|
||||
"2001:db8::/32",
|
||||
// From B
|
||||
"10.1.0.0/16",
|
||||
"172.16.0.0/12",
|
||||
"2001:db8:acad::/48",
|
||||
}
|
||||
|
||||
for _, s := range expectedPrefixes {
|
||||
p := netip.MustParsePrefix(s)
|
||||
if !trieA.ContainsPrefix(p) {
|
||||
t.Errorf("after merge, trieA is missing prefix %s", p)
|
||||
}
|
||||
}
|
||||
|
||||
// Check a prefix that should not be there
|
||||
if trieA.ContainsPrefix(netip.MustParsePrefix("9.9.9.9/32")) {
|
||||
t.Error("trieA contains a prefix it should not have")
|
||||
}
|
||||
|
||||
// Verify trieB is unchanged
|
||||
if !trieB.ContainsPrefix(netip.MustParsePrefix("172.16.0.0/12")) {
|
||||
t.Error("trieB was modified during merge")
|
||||
}
|
||||
if trieB.ContainsPrefix(netip.MustParsePrefix("10.0.0.0/8")) {
|
||||
t.Error("trieB contains a prefix from trieA")
|
||||
}
|
||||
}
|
328
dataset/nettrie/valuetrie.go
Normal file
328
dataset/nettrie/valuetrie.go
Normal file
@@ -0,0 +1,328 @@
|
||||
package nettrie
|
||||
|
||||
import (
|
||||
"math/bits"
|
||||
"net/netip"
|
||||
)
|
||||
|
||||
// getBit returns the n-th bit of an IP address (0-indexed).
|
||||
func getBit(addr netip.Addr, n int) byte {
|
||||
slice := addr.AsSlice()
|
||||
byteIndex := n / 8
|
||||
bitIndex := 7 - (n % 8)
|
||||
return (slice[byteIndex] >> bitIndex) & 1
|
||||
}
|
||||
|
||||
// commonPrefixLen computes the number of leading bits that are the same for two addresses.
|
||||
func commonPrefixLen(a, b netip.Addr) int {
|
||||
if a.Is4() != b.Is4() {
|
||||
return 0
|
||||
}
|
||||
aSlice := a.AsSlice()
|
||||
bSlice := b.AsSlice()
|
||||
|
||||
commonLen := 0
|
||||
for i := 0; i < len(aSlice); i++ {
|
||||
xor := aSlice[i] ^ bSlice[i]
|
||||
if xor == 0 {
|
||||
commonLen += 8
|
||||
} else {
|
||||
commonLen += bits.LeadingZeros8(xor)
|
||||
return commonLen
|
||||
}
|
||||
}
|
||||
return commonLen
|
||||
}
|
||||
|
||||
// ValueNode represents a node in the path-compressed trie.
|
||||
// Each node represents a prefix and can have up to two children.
|
||||
type ValueNode[T any] struct {
|
||||
children [2]*ValueNode[T]
|
||||
|
||||
// prefix is the full prefix represented by the path to this node.
|
||||
prefix netip.Prefix
|
||||
|
||||
value T
|
||||
isValue bool
|
||||
}
|
||||
|
||||
// ValueTrie is a path-compressed radix trie that stores network prefixes and their values.
|
||||
type ValueTrie[T any] struct {
|
||||
rootV4 *ValueNode[T]
|
||||
rootV6 *ValueNode[T]
|
||||
}
|
||||
|
||||
// NewValue creates and initializes a new ValueTrie.
|
||||
func NewValue[T any]() *ValueTrie[T] {
|
||||
return &ValueTrie[T]{}
|
||||
}
|
||||
|
||||
// Insert adds or updates a prefix in the trie with the given value.
|
||||
func (t *ValueTrie[T]) Insert(p netip.Prefix, value T) {
|
||||
p = p.Masked()
|
||||
addr := p.Addr()
|
||||
|
||||
if addr.Is4() {
|
||||
t.rootV4 = t.insert(t.rootV4, p, value)
|
||||
} else {
|
||||
t.rootV6 = t.insert(t.rootV6, p, value)
|
||||
}
|
||||
}
|
||||
|
||||
// insert is the recursive helper for inserting a prefix into the trie.
|
||||
func (t *ValueTrie[T]) insert(node *ValueNode[T], p netip.Prefix, value T) *ValueNode[T] {
|
||||
if node == nil {
|
||||
return &ValueNode[T]{prefix: p, value: value, isValue: true}
|
||||
}
|
||||
|
||||
addr := p.Addr()
|
||||
commonLen := commonPrefixLen(addr, node.prefix.Addr())
|
||||
pBits := p.Bits()
|
||||
nodeBits := node.prefix.Bits()
|
||||
|
||||
if commonLen > pBits {
|
||||
commonLen = pBits
|
||||
}
|
||||
if commonLen > nodeBits {
|
||||
commonLen = nodeBits
|
||||
}
|
||||
|
||||
if commonLen == nodeBits && commonLen == pBits {
|
||||
// Exact match, update the value.
|
||||
node.value = value
|
||||
node.isValue = true
|
||||
return node
|
||||
}
|
||||
|
||||
if commonLen < nodeBits {
|
||||
// The new prefix diverges from the current node's prefix.
|
||||
// We must split the current node.
|
||||
commonP, _ := node.prefix.Addr().Prefix(commonLen)
|
||||
splitNode := &ValueNode[T]{prefix: commonP}
|
||||
|
||||
// The existing node becomes a child of the new split node.
|
||||
bit := getBit(node.prefix.Addr(), commonLen)
|
||||
splitNode.children[bit] = node
|
||||
|
||||
if commonLen == pBits {
|
||||
// The inserted prefix is a prefix of the node's original prefix.
|
||||
// The new split node represents the inserted prefix and gets the value.
|
||||
splitNode.value = value
|
||||
splitNode.isValue = true
|
||||
} else {
|
||||
// The two prefixes diverge. Create a new child for the new prefix.
|
||||
bit = getBit(addr, commonLen)
|
||||
splitNode.children[bit] = &ValueNode[T]{prefix: p, value: value, isValue: true}
|
||||
}
|
||||
return splitNode
|
||||
}
|
||||
|
||||
// commonLen == nodeBits, meaning the current node's prefix is a prefix of the new one.
|
||||
// We need to descend to a child.
|
||||
bit := getBit(addr, commonLen)
|
||||
node.children[bit] = t.insert(node.children[bit], p, value)
|
||||
return node
|
||||
}
|
||||
|
||||
// Lookup finds the value associated with the most specific prefix that contains the given IP address.
|
||||
func (t *ValueTrie[T]) Lookup(addr netip.Addr) (value T, ok bool) {
|
||||
node := t.rootV4
|
||||
if addr.Is6() {
|
||||
node = t.rootV6
|
||||
}
|
||||
|
||||
var lastFoundValue T
|
||||
var found bool
|
||||
|
||||
for node != nil {
|
||||
commonLen := commonPrefixLen(addr, node.prefix.Addr())
|
||||
nodeBits := node.prefix.Bits()
|
||||
|
||||
// If the address doesn't share a prefix with the node, we can't be in this subtree.
|
||||
if commonLen < nodeBits {
|
||||
break
|
||||
}
|
||||
|
||||
// The address is within this node's prefix. If the node holds a value,
|
||||
// it's our current best match.
|
||||
if node.isValue {
|
||||
lastFoundValue = node.value
|
||||
found = true
|
||||
}
|
||||
|
||||
// We've matched the whole address, can't go deeper.
|
||||
if commonLen == addr.BitLen() {
|
||||
break
|
||||
}
|
||||
|
||||
// Descend to the next child based on the next bit after the node's prefix.
|
||||
bit := getBit(addr, nodeBits)
|
||||
node = node.children[bit]
|
||||
}
|
||||
|
||||
return lastFoundValue, found
|
||||
}
|
||||
|
||||
// Delete removes a prefix from the trie. It returns true if the prefix was found and removed.
|
||||
func (t *ValueTrie[T]) Delete(p netip.Prefix) bool {
|
||||
p = p.Masked()
|
||||
addr := p.Addr()
|
||||
|
||||
var changed bool
|
||||
if addr.Is4() {
|
||||
t.rootV4, changed = t.delete(t.rootV4, p)
|
||||
} else {
|
||||
t.rootV6, changed = t.delete(t.rootV6, p)
|
||||
}
|
||||
return changed
|
||||
}
|
||||
|
||||
// delete is the recursive helper for removing a prefix from the trie.
|
||||
func (t *ValueTrie[T]) delete(node *ValueNode[T], p netip.Prefix) (*ValueNode[T], bool) {
|
||||
if node == nil {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
addr := p.Addr()
|
||||
pBits := p.Bits()
|
||||
nodeBits := node.prefix.Bits()
|
||||
commonLen := commonPrefixLen(addr, node.prefix.Addr())
|
||||
|
||||
// The prefix is not on this path.
|
||||
if commonLen < nodeBits || commonLen < pBits && pBits < nodeBits {
|
||||
return node, false
|
||||
}
|
||||
|
||||
var changed bool
|
||||
if pBits > nodeBits {
|
||||
// The prefix to delete is deeper in the trie. Recurse.
|
||||
bit := getBit(addr, nodeBits)
|
||||
node.children[bit], changed = t.delete(node.children[bit], p)
|
||||
} else if pBits == nodeBits {
|
||||
// This is the node to delete. Unset its value.
|
||||
if !node.isValue {
|
||||
return node, false // Prefix wasn't actually in the trie.
|
||||
}
|
||||
node.isValue = false
|
||||
var zero T
|
||||
node.value = zero
|
||||
changed = true
|
||||
} else { // pBits < nodeBits
|
||||
return node, false // Prefix to delete is shorter, so can't be here.
|
||||
}
|
||||
|
||||
if !changed {
|
||||
return node, false
|
||||
}
|
||||
|
||||
// Post-deletion cleanup:
|
||||
// If the node has no value and can be merged with a single child, do so.
|
||||
if !node.isValue {
|
||||
if node.children[0] != nil && node.children[1] == nil {
|
||||
return node.children[0], true
|
||||
}
|
||||
if node.children[0] == nil && node.children[1] != nil {
|
||||
return node.children[1], true
|
||||
}
|
||||
}
|
||||
|
||||
// If the node is now a leaf without a value, it can be removed entirely.
|
||||
if !node.isValue && node.children[0] == nil && node.children[1] == nil {
|
||||
return nil, true
|
||||
}
|
||||
|
||||
return node, true
|
||||
}
|
||||
|
||||
// Contains checks if the exact IP address exists in the trie as a full-length prefix.
|
||||
func (t *ValueTrie[T]) Contains(addr netip.Addr) bool {
|
||||
prefix := netip.PrefixFrom(addr, addr.BitLen())
|
||||
return t.ContainsPrefix(prefix)
|
||||
}
|
||||
|
||||
// ContainsPrefix checks if the exact prefix exists in the trie.
|
||||
func (t *ValueTrie[T]) ContainsPrefix(p netip.Prefix) bool {
|
||||
p = p.Masked()
|
||||
addr := p.Addr()
|
||||
pBits := p.Bits()
|
||||
|
||||
node := t.rootV4
|
||||
if addr.Is6() {
|
||||
node = t.rootV6
|
||||
}
|
||||
|
||||
for node != nil {
|
||||
commonLen := commonPrefixLen(addr, node.prefix.Addr())
|
||||
nodeBits := node.prefix.Bits()
|
||||
|
||||
if commonLen < nodeBits {
|
||||
// Path has diverged. The prefix cannot be in this subtree.
|
||||
return false
|
||||
}
|
||||
|
||||
if pBits < nodeBits {
|
||||
// The search prefix is shorter than the node's prefix,
|
||||
// but they share a prefix. e.g. search /16, node is /24.
|
||||
// The /16 is not explicitly in the trie.
|
||||
return false
|
||||
}
|
||||
|
||||
if pBits == nodeBits {
|
||||
// Found a node with the exact same prefix length.
|
||||
// Because we also know commonLen >= nodeBits, the prefixes are identical.
|
||||
return node.isValue
|
||||
}
|
||||
|
||||
// pBits > nodeBits, so we need to go deeper.
|
||||
bit := getBit(addr, nodeBits)
|
||||
node = node.children[bit]
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// WalkValueFunc is a function called for each prefix in the trie during a walk.
|
||||
// Returning false from the function will stop the walk.
|
||||
type WalkValueFunc[T any] func(p netip.Prefix, v T) bool
|
||||
|
||||
// walk is the recursive helper for traversing the trie.
|
||||
func walkValue[T any](node *ValueNode[T], f WalkValueFunc[T]) bool {
|
||||
if node == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
if node.isValue {
|
||||
if !f(node.prefix, node.value) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
if node.children[0] != nil {
|
||||
if !walkValue(node.children[0], f) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
if node.children[1] != nil {
|
||||
if !walkValue(node.children[1], f) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// Walk traverses the trie and calls the given function for each prefix and its value.
|
||||
// If the function returns false, the walk is stopped. The order is not guaranteed.
|
||||
func (t *ValueTrie[T]) Walk(f WalkValueFunc[T]) {
|
||||
if !walkValue(t.rootV4, f) {
|
||||
return
|
||||
}
|
||||
walkValue(t.rootV6, f)
|
||||
}
|
||||
|
||||
// Merge inserts all prefixes from another Trie into this one.
|
||||
func (t *ValueTrie[T]) Merge(other *ValueTrie[T]) {
|
||||
other.Walk(func(p netip.Prefix, v T) bool {
|
||||
t.Insert(p, v)
|
||||
return true // continue walking
|
||||
})
|
||||
}
|
340
dataset/nettrie/valuetrie_test.go
Normal file
340
dataset/nettrie/valuetrie_test.go
Normal file
@@ -0,0 +1,340 @@
|
||||
package nettrie
|
||||
|
||||
import (
|
||||
"maps"
|
||||
"net/netip"
|
||||
"reflect"
|
||||
"slices"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestValueTrie_IPv4(t *testing.T) {
|
||||
trie := NewValue[string]()
|
||||
|
||||
routes := map[string]string{
|
||||
"0.0.0.0/0": "Default",
|
||||
"10.0.0.0/8": "Private A",
|
||||
"192.168.0.0/16": "Private B",
|
||||
"192.168.1.0/24": "LAN",
|
||||
"192.168.1.128/25": "Subnet",
|
||||
"192.168.1.200/32": "Host",
|
||||
}
|
||||
|
||||
for s, v := range routes {
|
||||
trie.Insert(netip.MustParsePrefix(s), v)
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
lookupAddr string
|
||||
expectedVal string
|
||||
expectedOK bool
|
||||
}{
|
||||
{"Exact Host Match", "192.168.1.200", "Host", true},
|
||||
{"Subnet Match", "192.168.1.201", "Subnet", true},
|
||||
{"LAN Match", "192.168.1.50", "LAN", true},
|
||||
{"Private B Match", "192.168.255.255", "Private B", true},
|
||||
{"Private A Match", "10.255.255.255", "Private A", true},
|
||||
{"Default Match", "8.8.8.8", "Default", true},
|
||||
{"No Match", "203.0.113.1", "Default", true}, // Falls back to default
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
addr := netip.MustParseAddr(tc.lookupAddr)
|
||||
val, ok := trie.Lookup(addr)
|
||||
|
||||
if ok != tc.expectedOK {
|
||||
t.Errorf("expected ok=%v, got ok=%v", tc.expectedOK, ok)
|
||||
}
|
||||
if val != tc.expectedVal {
|
||||
t.Errorf("expected value=%q, got %q", tc.expectedVal, val)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValueTrie_IPv6(t *testing.T) {
|
||||
trie := NewValue[string]()
|
||||
|
||||
routes := map[string]string{
|
||||
"::/0": "Default",
|
||||
"2001:db8::/32": "Global Unicast",
|
||||
"2001:db8:acad::/48": "Academic",
|
||||
"2001:db8:acad:1::/64": "CS Department",
|
||||
"2001:db8:acad:1::1/128": "Host Route",
|
||||
}
|
||||
|
||||
for s, v := range routes {
|
||||
trie.Insert(netip.MustParsePrefix(s), v)
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
lookupAddr string
|
||||
expectedVal string
|
||||
expectedOK bool
|
||||
}{
|
||||
{"Exact Host Match", "2001:db8:acad:1::1", "Host Route", true},
|
||||
{"CS Dept Match", "2001:db8:acad:1::2", "CS Department", true},
|
||||
{"Academic Match", "2001:db8:acad:2::1", "Academic", true},
|
||||
{"Global Unicast Match", "2001:db8:cafe::1", "Global Unicast", true},
|
||||
{"Default Match", "2606:4700:4700::1111", "Default", true},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
addr := netip.MustParseAddr(tc.lookupAddr)
|
||||
val, ok := trie.Lookup(addr)
|
||||
|
||||
if ok != tc.expectedOK {
|
||||
t.Errorf("expected ok=%v, got ok=%v", tc.expectedOK, ok)
|
||||
}
|
||||
if val != tc.expectedVal {
|
||||
t.Errorf("expected value=%q, got %q", tc.expectedVal, val)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValueTrie_UpdateValue(t *testing.T) {
|
||||
trie := NewValue[string]()
|
||||
prefix := netip.MustParsePrefix("192.168.1.0/24")
|
||||
addr := netip.MustParseAddr("192.168.1.1")
|
||||
|
||||
trie.Insert(prefix, "Initial Value")
|
||||
val, _ := trie.Lookup(addr)
|
||||
if val != "Initial Value" {
|
||||
t.Fatalf("expected initial value, got %q", val)
|
||||
}
|
||||
|
||||
trie.Insert(prefix, "Updated Value")
|
||||
val, _ = trie.Lookup(addr)
|
||||
if val != "Updated Value" {
|
||||
t.Fatalf("expected updated value, got %q", val)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValueTrie_Delete(t *testing.T) {
|
||||
trie := NewValue[string]()
|
||||
|
||||
trie.Insert(netip.MustParsePrefix("10.0.0.0/8"), "A")
|
||||
trie.Insert(netip.MustParsePrefix("10.0.0.0/24"), "B")
|
||||
trie.Insert(netip.MustParsePrefix("10.0.1.0/24"), "C")
|
||||
|
||||
t.Run("Delete Non-Existent", func(t *testing.T) {
|
||||
if trie.Delete(netip.MustParsePrefix("1.2.3.4/32")) {
|
||||
t.Error("deleted a non-existent prefix")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Delete Leaf Node", func(t *testing.T) {
|
||||
// Delete 10.0.0.0/24, which is a leaf.
|
||||
if !trie.Delete(netip.MustParsePrefix("10.0.0.0/24")) {
|
||||
t.Fatal("failed to delete 10.0.0.0/24")
|
||||
}
|
||||
// Lookup should now resolve to the parent 10.0.0.0/8
|
||||
val, ok := trie.Lookup(netip.MustParseAddr("10.0.0.1"))
|
||||
if !ok || val != "A" {
|
||||
t.Errorf("lookup failed after delete, got val=%q ok=%v, want val=\"A\" ok=true", val, ok)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Delete and Merge", func(t *testing.T) {
|
||||
// Insert a new prefix that causes a split
|
||||
trie.Insert(netip.MustParsePrefix("10.0.0.0/8"), "A-updated")
|
||||
|
||||
// We now have 10.0.0.0/8 and 10.0.1.0/24. Deleting 10.0.1.0/24 should
|
||||
// cause the split node to be removed and merged back into 10.0.0.0/8.
|
||||
if !trie.Delete(netip.MustParsePrefix("10.0.1.0/24")) {
|
||||
t.Fatal("failed to delete 10.0.1.0/24")
|
||||
}
|
||||
|
||||
// A lookup for 10.0.1.1 should now match 10.0.0.0/8
|
||||
val, ok := trie.Lookup(netip.MustParseAddr("10.0.1.1"))
|
||||
if !ok || val != "A-updated" {
|
||||
t.Errorf("lookup failed after merge, got val=%q ok=%v, want val=\"A-updated\" ok=true", val, ok)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Delete Prefix Of Another", func(t *testing.T) {
|
||||
// Delete 10.0.0.0/8
|
||||
trie.Insert(netip.MustParsePrefix("10.0.0.0/8"), "A")
|
||||
trie.Insert(netip.MustParsePrefix("10.1.0.0/16"), "D")
|
||||
|
||||
if !trie.Delete(netip.MustParsePrefix("10.0.0.0/8")) {
|
||||
t.Fatal("failed to delete 10.0.0.0/8")
|
||||
}
|
||||
|
||||
// Lookup for 10.0.0.1 should now fail (no default route)
|
||||
_, ok := trie.Lookup(netip.MustParseAddr("10.0.0.1"))
|
||||
if ok {
|
||||
t.Error("lookup for 10.0.0.1 should have failed")
|
||||
}
|
||||
|
||||
// Lookup for 10.1.0.1 should still succeed
|
||||
val, ok := trie.Lookup(netip.MustParseAddr("10.1.0.1"))
|
||||
if !ok || val != "D" {
|
||||
t.Error("lookup for 10.1.0.1 failed unexpectedly")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestValueTrie_Contains(t *testing.T) {
|
||||
trie := NewValue[string]()
|
||||
|
||||
prefixes := []string{
|
||||
"10.0.0.0/8",
|
||||
"192.168.1.0/24",
|
||||
"192.168.1.200/32",
|
||||
"2001:db8::/32",
|
||||
"2001:db8:acad:1::1/128",
|
||||
}
|
||||
|
||||
for _, s := range prefixes {
|
||||
trie.Insert(netip.MustParsePrefix(s), "present")
|
||||
}
|
||||
|
||||
t.Run("ContainsPrefix", func(t *testing.T) {
|
||||
testCases := []struct {
|
||||
prefix string
|
||||
want bool
|
||||
}{
|
||||
{"10.0.0.0/8", true},
|
||||
{"192.168.1.200/32", true},
|
||||
{"2001:db8::/32", true},
|
||||
{"2001:db8:acad:1::1/128", true},
|
||||
{"10.0.0.0/9", false}, // shorter parent, but not exact
|
||||
{"192.168.1.0/25", false}, // non-existent child
|
||||
{"172.16.0.0/12", false}, // completely different prefix
|
||||
{"2001:db8::/48", false}, // non-existent child
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
p := netip.MustParsePrefix(tc.prefix)
|
||||
if got := trie.ContainsPrefix(p); got != tc.want {
|
||||
t.Errorf("ContainsPrefix(%q) = %v, want %v", tc.prefix, got, tc.want)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Contains", func(t *testing.T) {
|
||||
testCases := []struct {
|
||||
addr string
|
||||
want bool
|
||||
}{
|
||||
{"192.168.1.200", true},
|
||||
{"2001:db8:acad:1::1", true},
|
||||
{"192.168.1.201", false}, // In /24 range, but not a /32 host route
|
||||
{"10.0.0.1", false}, // In /8 range, but not a /32 host route
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
a := netip.MustParseAddr(tc.addr)
|
||||
if got := trie.Contains(a); got != tc.want {
|
||||
t.Errorf("Contains(%q) = %v, want %v", tc.addr, got, tc.want)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestValueTrie_Walk(t *testing.T) {
|
||||
trie := NewValue[string]()
|
||||
prefixes := map[string]string{
|
||||
"10.0.0.0/8": "A",
|
||||
"192.168.1.0/24": "B",
|
||||
"2001:db8::/32": "C",
|
||||
"172.16.0.0/12": "D",
|
||||
"2001:db8:acad::/48": "E",
|
||||
}
|
||||
for s, v := range prefixes {
|
||||
trie.Insert(netip.MustParsePrefix(s), v)
|
||||
}
|
||||
|
||||
t.Run("Walk all prefixes", func(t *testing.T) {
|
||||
walked := make(map[string]string)
|
||||
trie.Walk(func(p netip.Prefix, v string) bool {
|
||||
walked[p.String()] = v
|
||||
return true
|
||||
})
|
||||
|
||||
if !maps.Equal(walked, prefixes) {
|
||||
t.Errorf("walked prefixes mismatch:\nexpected: %v\ngot: %v", prefixes, walked)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Stop walk early", func(t *testing.T) {
|
||||
count := 0
|
||||
trie.Walk(func(p netip.Prefix, v string) bool {
|
||||
count++
|
||||
return count < 3 // Stop after visiting 3 prefixes
|
||||
})
|
||||
|
||||
if count != 3 {
|
||||
t.Errorf("expected walk to stop after 3 prefixes, but it visited %d", count)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Stop walk between families", func(t *testing.T) {
|
||||
stopTrie := NewValue[int]()
|
||||
stopTrie.Insert(netip.MustParsePrefix("10.0.0.0/8"), 1)
|
||||
stopTrie.Insert(netip.MustParsePrefix("2001:db8::/32"), 2)
|
||||
|
||||
count := 0
|
||||
stopTrie.Walk(func(p netip.Prefix, v int) bool {
|
||||
count++
|
||||
return false
|
||||
})
|
||||
|
||||
if count != 1 {
|
||||
t.Errorf("expected walk to stop after 1 prefix, but it visited %d", count)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestValueTrie_Merge(t *testing.T) {
|
||||
trieA := NewValue[string]()
|
||||
trieA.Insert(netip.MustParsePrefix("10.0.0.0/8"), "net10")
|
||||
trieA.Insert(netip.MustParsePrefix("192.168.1.0/24"), "lan_A")
|
||||
trieA.Insert(netip.MustParsePrefix("2001:db8::/32"), "v6_A")
|
||||
|
||||
trieB := NewValue[string]()
|
||||
trieB.Insert(netip.MustParsePrefix("10.1.0.0/16"), "net10_subnet")
|
||||
trieB.Insert(netip.MustParsePrefix("192.168.1.0/24"), "lan_B_override") // Overlap
|
||||
trieB.Insert(netip.MustParsePrefix("172.16.0.0/12"), "corp")
|
||||
trieB.Insert(netip.MustParsePrefix("2001:db8:acad::/48"), "v6_B")
|
||||
|
||||
trieA.Merge(trieB)
|
||||
|
||||
expected := map[string]string{
|
||||
"10.0.0.0/8": "net10",
|
||||
"192.168.1.0/24": "lan_B_override", // Value should be from trieB
|
||||
"2001:db8::/32": "v6_A",
|
||||
"10.1.0.0/16": "net10_subnet",
|
||||
"172.16.0.0/12": "corp",
|
||||
"2001:db8:acad::/48": "v6_B",
|
||||
}
|
||||
|
||||
actual := make(map[string]string)
|
||||
trieA.Walk(func(p netip.Prefix, v string) bool {
|
||||
actual[p.String()] = v
|
||||
return true
|
||||
})
|
||||
|
||||
if !reflect.DeepEqual(expected, actual) {
|
||||
t.Errorf("Merge result incorrect.\nExpected: %v\nGot: %v", expected, actual)
|
||||
}
|
||||
|
||||
// Verify trieB is unchanged by collecting its prefixes
|
||||
bPrefixes := []string{}
|
||||
trieB.Walk(func(p netip.Prefix, v string) bool {
|
||||
bPrefixes = append(bPrefixes, p.String())
|
||||
return true
|
||||
})
|
||||
expectedBPrefixes := []string{"10.1.0.0/16", "172.16.0.0/12", "192.168.1.0/24", "2001:db8:acad::/48"}
|
||||
slices.Sort(bPrefixes)
|
||||
slices.Sort(expectedBPrefixes)
|
||||
if !reflect.DeepEqual(bPrefixes, expectedBPrefixes) {
|
||||
t.Errorf("trieB was modified during merge.\nExpected: %v\nGot: %v", expectedBPrefixes, bPrefixes)
|
||||
}
|
||||
}
|
@@ -2,10 +2,37 @@ package dataset
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
|
||||
"git.maze.io/maze/styx/dataset/nettrie"
|
||||
"github.com/yl2chen/cidranger"
|
||||
)
|
||||
|
||||
type NetworkTrie struct {
|
||||
*nettrie.Trie
|
||||
}
|
||||
|
||||
func NewNetworkTrie(prefixes ...netip.Prefix) *NetworkTrie {
|
||||
trie := &NetworkTrie{
|
||||
Trie: nettrie.New(),
|
||||
}
|
||||
for _, prefix := range prefixes {
|
||||
trie.Insert(prefix)
|
||||
}
|
||||
return trie
|
||||
}
|
||||
|
||||
func (trie *NetworkTrie) ContainsIP(ip net.IP) bool {
|
||||
if ip == nil {
|
||||
return false
|
||||
}
|
||||
addr, ok := netip.AddrFromSlice(ip)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return trie.Contains(addr)
|
||||
}
|
||||
|
||||
type NetworkTree struct {
|
||||
ranger cidranger.Ranger
|
||||
}
|
||||
|
@@ -1,21 +1,20 @@
|
||||
package dataset
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"os"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"git.maze.io/maze/styx/dataset/parser"
|
||||
_ "github.com/mattn/go-sqlite3" // SQLite3 driver
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
type Storage interface {
|
||||
@@ -27,11 +26,13 @@ type Storage interface {
|
||||
|
||||
Clients() (Clients, error)
|
||||
ClientByID(int64) (Client, error)
|
||||
ClientByIP(net.IP) (Client, error)
|
||||
ClientByAddr(netip.Addr) (Client, error)
|
||||
// ClientByIP(net.IP) (Client, error)
|
||||
SaveClient(*Client) error
|
||||
DeleteClient(Client) error
|
||||
|
||||
Lists() ([]List, error)
|
||||
ListsByGroup(Group) ([]List, error)
|
||||
ListByID(int64) (List, error)
|
||||
SaveList(*List) error
|
||||
DeleteList(List) error
|
||||
@@ -44,7 +45,6 @@ type Group struct {
|
||||
Description string `json:"description"`
|
||||
CreatedAt time.Time `json:"created_at" bstore:"nonzero"`
|
||||
UpdatedAt time.Time `json:"updated_at" bstore:"nonzero"`
|
||||
Storage Storage `json:"-" bstore:"-"`
|
||||
}
|
||||
|
||||
type Client struct {
|
||||
@@ -56,11 +56,6 @@ type Client struct {
|
||||
Groups []Group `json:"groups,omitempty" bstore:"-"`
|
||||
CreatedAt time.Time `json:"created_at" bstore:"nonzero"`
|
||||
UpdatedAt time.Time `json:"updated_at" bstore:"nonzero"`
|
||||
Storage Storage `json:"-" bstore:"-"`
|
||||
}
|
||||
|
||||
type WithClient interface {
|
||||
Client() (Client, error)
|
||||
}
|
||||
|
||||
type ClientGroup struct {
|
||||
@@ -80,6 +75,15 @@ func (c *Client) ContainsIP(ip net.IP) bool {
|
||||
return ipnet.Contains(ip)
|
||||
}
|
||||
|
||||
func (c *Client) ContainsAddr(ip netip.Addr) bool {
|
||||
return c.Prefix().Contains(ip)
|
||||
}
|
||||
|
||||
func (c Client) Prefix() netip.Prefix {
|
||||
ip, _ := netip.ParseAddr(c.IP)
|
||||
return netip.PrefixFrom(ip, c.Mask)
|
||||
}
|
||||
|
||||
func (c *Client) String() string {
|
||||
ipnet := &net.IPNet{
|
||||
IP: net.ParseIP(c.IP),
|
||||
@@ -136,28 +140,29 @@ type List struct {
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
func (list *List) Domains() (*DomainTree, error) {
|
||||
func (list *List) Networks() (*NetworkTrie, error) {
|
||||
if list.Type != ListTypeNetwork {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
prefixes, _, err := parser.ParseNetworks(bytes.NewReader(list.Cache))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewNetworkTrie(prefixes...), nil
|
||||
}
|
||||
|
||||
func (list *List) Domains() (*DomainTrie, error) {
|
||||
if list.Type != ListTypeDomain {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var (
|
||||
tree = NewDomainList()
|
||||
scan = bufio.NewScanner(bytes.NewReader(list.Cache))
|
||||
)
|
||||
for scan.Scan() {
|
||||
line := strings.TrimSpace(scan.Text())
|
||||
if line == "" || line[0] == '#' {
|
||||
continue
|
||||
}
|
||||
if labels, ok := dns.IsDomainName(line); ok && labels >= 2 {
|
||||
tree.Add(line)
|
||||
}
|
||||
}
|
||||
if err := scan.Err(); err != nil {
|
||||
domains, _, err := parser.ParseDomains(bytes.NewReader(list.Cache))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return tree, nil
|
||||
|
||||
return NewDomainTrie(list.Permit, domains...)
|
||||
}
|
||||
|
||||
func (list *List) Update() (updated bool, err error) {
|
||||
|
@@ -5,6 +5,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strings"
|
||||
@@ -20,13 +21,18 @@ type bstoreStorage struct {
|
||||
}
|
||||
|
||||
func OpenBStore(name string) (Storage, error) {
|
||||
log := logger.StandardLog.Value("database", name)
|
||||
|
||||
if !filepath.IsAbs(name) {
|
||||
var err error
|
||||
if name, err = filepath.Abs(name); err != nil {
|
||||
log.Err(err).Error("Opening BoltDB storage failed; invalid path")
|
||||
return nil, err
|
||||
}
|
||||
log = log.Value("database", name)
|
||||
}
|
||||
|
||||
log.Debug("Opening BoltDB storage")
|
||||
ctx := context.Background()
|
||||
db, err := bstore.Open(ctx, name, nil,
|
||||
Group{},
|
||||
@@ -36,6 +42,7 @@ func OpenBStore(name string) (Storage, error) {
|
||||
ListGroup{},
|
||||
)
|
||||
if err != nil {
|
||||
log.Err(err).Error("Opening BoltDB storage failed")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -47,6 +54,7 @@ func OpenBStore(name string) (Storage, error) {
|
||||
)
|
||||
|
||||
if defaultGroup, err = s.GroupByName("Default"); errors.Is(err, bstore.ErrAbsent) {
|
||||
log.Debug("Creating default group")
|
||||
defaultGroup = Group{
|
||||
Name: "Default",
|
||||
IsEnabled: true,
|
||||
@@ -63,6 +71,7 @@ func OpenBStore(name string) (Storage, error) {
|
||||
FilterFn(func(client Client) bool {
|
||||
return net.ParseIP(client.IP).Equal(net.ParseIP("0.0.0.0")) && client.Mask == 0
|
||||
}).Get(); errors.Is(err, bstore.ErrAbsent) {
|
||||
log.Debug("Creating default IPv4 clients")
|
||||
defaultClient4 = Client{
|
||||
Network: "ipv4",
|
||||
IP: "0.0.0.0",
|
||||
@@ -83,6 +92,7 @@ func OpenBStore(name string) (Storage, error) {
|
||||
FilterFn(func(client Client) bool {
|
||||
return net.ParseIP(client.IP).Equal(net.ParseIP("::")) && client.Mask == 0
|
||||
}).Get(); errors.Is(err, bstore.ErrAbsent) {
|
||||
log.Debug("Creating default IPv6 clients")
|
||||
defaultClient6 = Client{
|
||||
Network: "ipv6",
|
||||
IP: "::",
|
||||
@@ -100,6 +110,7 @@ func OpenBStore(name string) (Storage, error) {
|
||||
}
|
||||
|
||||
// Start updater
|
||||
log.Trace("Starting list updater")
|
||||
NewUpdater(s)
|
||||
|
||||
return s, nil
|
||||
@@ -197,7 +208,12 @@ func (s *bstoreStorage) ClientByID(id int64) (Client, error) {
|
||||
}
|
||||
|
||||
func (s *bstoreStorage) ClientByIP(ip net.IP) (Client, error) {
|
||||
if ip == nil {
|
||||
addr, _ := netip.AddrFromSlice(ip)
|
||||
return s.ClientByAddr(addr)
|
||||
}
|
||||
|
||||
func (s *bstoreStorage) ClientByAddr(addr netip.Addr) (Client, error) {
|
||||
if !addr.IsValid() {
|
||||
return Client{}, ErrNotExist{Object: "client"}
|
||||
}
|
||||
var (
|
||||
@@ -205,9 +221,9 @@ func (s *bstoreStorage) ClientByIP(ip net.IP) (Client, error) {
|
||||
clients Clients
|
||||
network string
|
||||
)
|
||||
if ip4 := ip.To4(); ip4 != nil {
|
||||
if addr.Is4() {
|
||||
network = "ipv4"
|
||||
} else if ip6 := ip.To16(); ip6 != nil {
|
||||
} else {
|
||||
network = "ipv6"
|
||||
}
|
||||
if network == "" {
|
||||
@@ -216,7 +232,7 @@ func (s *bstoreStorage) ClientByIP(ip net.IP) (Client, error) {
|
||||
for client, err := range bstore.QueryDB[Client](ctx, s.db).
|
||||
FilterEqual("Network", network).
|
||||
FilterFn(func(client Client) bool {
|
||||
return client.ContainsIP(ip)
|
||||
return client.ContainsAddr(addr)
|
||||
}).All() {
|
||||
if err != nil {
|
||||
return Client{}, err
|
||||
@@ -320,6 +336,26 @@ func (s *bstoreStorage) Lists() ([]List, error) {
|
||||
return lists, nil
|
||||
}
|
||||
|
||||
func (s *bstoreStorage) ListsByGroup(group Group) ([]List, error) {
|
||||
ctx := context.Background()
|
||||
ids := make([]int64, 0)
|
||||
for item, err := range bstore.QueryDB[ListGroup](ctx, s.db).FilterEqual("GroupID", group.ID).All() {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ids = append(ids, item.ListID)
|
||||
}
|
||||
|
||||
var lists []List
|
||||
for list, err := range bstore.QueryDB[List](ctx, s.db).FilterIDs(ids).All() {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
lists = append(lists, list)
|
||||
}
|
||||
return lists, nil
|
||||
}
|
||||
|
||||
func (s *bstoreStorage) ListByID(id int64) (List, error) {
|
||||
ctx := context.Background()
|
||||
list, err := bstore.QueryDB[List](ctx, s.db).FilterID(id).Get()
|
||||
|
142
dataset/storage_cache.go
Normal file
142
dataset/storage_cache.go
Normal file
@@ -0,0 +1,142 @@
|
||||
package dataset
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"git.maze.io/maze/styx/logger"
|
||||
)
|
||||
|
||||
const MinCacheExpire = 10 * time.Second
|
||||
|
||||
type cache struct {
|
||||
Storage
|
||||
expire time.Duration
|
||||
groupByID sync.Map
|
||||
clientByAddr sync.Map
|
||||
listByGroup sync.Map
|
||||
closed chan struct{}
|
||||
}
|
||||
|
||||
type cacheItem struct {
|
||||
cachedAt time.Time
|
||||
value any
|
||||
}
|
||||
|
||||
// Cache items returned from a Storage for the specified duration.
|
||||
//
|
||||
// Does not cache negative hits.
|
||||
func Cache(storage Storage, expire time.Duration) Storage {
|
||||
if expire < MinCacheExpire {
|
||||
expire = MinCacheExpire
|
||||
}
|
||||
|
||||
logger.StandardLog.Value("expire", expire).Debug("Caching Storage responses")
|
||||
s := &cache{
|
||||
Storage: storage,
|
||||
expire: expire,
|
||||
closed: make(chan struct{}, 1),
|
||||
}
|
||||
go s.cleanUpTimer()
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *cache) cleanUpTimer() {
|
||||
ticker := time.NewTicker(s.expire)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-s.closed:
|
||||
return
|
||||
|
||||
case now := <-ticker.C:
|
||||
logger.StandardLog.Trace("Cache cleanup running")
|
||||
s.cleanUp(now, &s.groupByID)
|
||||
s.cleanUp(now, &s.clientByAddr)
|
||||
s.cleanUp(now, &s.listByGroup)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *cache) cleanUp(now time.Time, cacheMap *sync.Map) {
|
||||
cacheMap.Range(func(key, item any) bool {
|
||||
cached := item.(cacheItem)
|
||||
if ago := now.Sub(cached.cachedAt); ago >= s.expire {
|
||||
logger.StandardLog.Values(logger.Values{
|
||||
"ago": ago,
|
||||
"type": fmt.Sprintf("%T", cached.value),
|
||||
"item": fmt.Sprintf("%s", cached.value),
|
||||
}).Debug("Cache removing expired item")
|
||||
cacheMap.Delete(key)
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
func (s *cache) load(now time.Time, cacheMap *sync.Map, key any) (value any, ok bool) {
|
||||
var item any
|
||||
if item, ok = cacheMap.Load(key); !ok {
|
||||
return
|
||||
}
|
||||
|
||||
cached := item.(cacheItem)
|
||||
if now.Sub(cached.cachedAt) < s.expire {
|
||||
return cached.value, true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func (s *cache) save(now time.Time, cacheMap *sync.Map, key, value any) {
|
||||
cacheMap.Store(key, cacheItem{
|
||||
cachedAt: now,
|
||||
value: value,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *cache) GroupByID(id int64) (Group, error) {
|
||||
now := time.Now()
|
||||
if value, ok := s.load(now, &s.groupByID, id); ok {
|
||||
return value.(Group), nil
|
||||
}
|
||||
|
||||
group, err := s.Storage.GroupByID(id)
|
||||
if err == nil {
|
||||
s.save(now, &s.groupByID, id, group)
|
||||
}
|
||||
return group, err
|
||||
}
|
||||
|
||||
func (s *cache) ClientByIP(ip net.IP) (Client, error) {
|
||||
addr, _ := netip.AddrFromSlice(ip)
|
||||
return s.ClientByAddr(addr)
|
||||
}
|
||||
|
||||
func (s *cache) ClientByAddr(ip netip.Addr) (Client, error) {
|
||||
now := time.Now()
|
||||
if value, ok := s.load(now, &s.clientByAddr, ip); ok {
|
||||
return value.(Client), nil
|
||||
}
|
||||
|
||||
client, err := s.Storage.ClientByAddr(ip)
|
||||
if err == nil {
|
||||
s.save(now, &s.clientByAddr, ip, client)
|
||||
}
|
||||
return client, err
|
||||
}
|
||||
|
||||
func (s *cache) ListsByGroup(group Group) ([]List, error) {
|
||||
now := time.Now()
|
||||
if value, ok := s.load(now, &s.listByGroup, group.ID); ok {
|
||||
return value.([]List), nil
|
||||
}
|
||||
|
||||
lists, err := s.Storage.ListsByGroup(group)
|
||||
if err == nil {
|
||||
s.save(now, &s.listByGroup, group.ID, lists)
|
||||
}
|
||||
return lists, err
|
||||
}
|
Reference in New Issue
Block a user