Better trie implementations
This commit is contained in:
62
dataset/dnstrie/name.go
Normal file
62
dataset/dnstrie/name.go
Normal file
@@ -0,0 +1,62 @@
|
||||
package dnstrie
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"unicode"
|
||||
)
|
||||
|
||||
// isValidDomainName validates if the given string is a valid DNS hostname.
|
||||
// A valid hostname consists of a series of labels separated by dots.
|
||||
// Each label must:
|
||||
// - Be between 1 and 63 characters long.
|
||||
// - Consist only of ASCII letters ('a'-'z', 'A'-'Z'), digits ('0'-'9'), and hyphens ('-').
|
||||
// - Not start or end with a hyphen.
|
||||
// The total length of the hostname (including dots) must not exceed 253 characters.
|
||||
func isValidDomainName(host string) bool {
|
||||
// 1. Check total length. The maximum length of a full hostname is 253 characters.
|
||||
if len(host) > 253 {
|
||||
return false
|
||||
}
|
||||
|
||||
// An empty string is not a valid hostname.
|
||||
if host == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
// 2. Handle optional trailing dot for FQDNs.
|
||||
// If the hostname ends with a dot, we remove it for validation purposes.
|
||||
if strings.HasSuffix(host, ".") {
|
||||
host = host[:len(host)-1]
|
||||
}
|
||||
|
||||
// After removing a potential trailing dot, the string might be empty.
|
||||
if host == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
// 3. Split the hostname into labels.
|
||||
labels := strings.Split(host, ".")
|
||||
|
||||
// 4. Validate each label.
|
||||
for _, label := range labels {
|
||||
// a. Check label length (1 to 63 characters).
|
||||
if len(label) < 1 || len(label) > 63 {
|
||||
return false
|
||||
}
|
||||
|
||||
// b. Check if label starts or ends with a hyphen.
|
||||
if strings.HasPrefix(label, "-") || strings.HasSuffix(label, "-") {
|
||||
return false
|
||||
}
|
||||
|
||||
// c. Check for allowed characters in the label.
|
||||
for _, char := range label {
|
||||
if !unicode.IsLetter(char) && !unicode.IsDigit(char) && char != '-' {
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If all checks pass, the hostname is valid.
|
||||
return true
|
||||
}
|
42
dataset/dnstrie/name_test.go
Normal file
42
dataset/dnstrie/name_test.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package dnstrie
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestIsValidDomainName(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
host string
|
||||
expected bool
|
||||
}{
|
||||
{"Valid Hostname", "example.com", true},
|
||||
{"Valid with Subdomain", "sub.domain.co.uk", true},
|
||||
{"Valid with Hyphen", "my-host-name.org", true},
|
||||
{"Valid with Digits", "app1.server2.net", true},
|
||||
{"Valid FQDN", "example.com.", true},
|
||||
{"Valid Single Label", "localhost", true},
|
||||
{"Valid: Long but within limits", strings.Repeat("a", 63) + "." + strings.Repeat("b", 63) + ".com", true},
|
||||
{"Invalid: Label Too Long", strings.Repeat("a", 64) + ".com", false},
|
||||
{"Invalid: Total Length Too Long", strings.Repeat("a", 60) + "." + strings.Repeat("b", 60) + "." + strings.Repeat("c", 60) + "." + strings.Repeat("d", 60) + "." + strings.Repeat("e", 60) + ".com", false},
|
||||
{"Invalid: Starts with Hyphen", "-invalid.com", false},
|
||||
{"Invalid: Ends with Hyphen", "invalid-.com", false},
|
||||
{"Invalid: Contains Underscore", "my_host.com", false},
|
||||
{"Invalid: Contains Space", "my host.com", false},
|
||||
{"Invalid: Double Dot", "example..com", false},
|
||||
{"Invalid: Starts with Dot", ".example.com", false},
|
||||
{"Invalid: Empty Label", "sub..domain.com", false},
|
||||
{"Invalid: Just a Dot", ".", false},
|
||||
{"Invalid: Empty String", "", false},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
got := isValidDomainName(test.host)
|
||||
if got != test.expected {
|
||||
t.Errorf("isValidDomainName(%q) = %v; want %v", test.host, got, test.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
139
dataset/dnstrie/trie.go
Normal file
139
dataset/dnstrie/trie.go
Normal file
@@ -0,0 +1,139 @@
|
||||
package dnstrie
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// Node represents a single node in the prefix trie.
|
||||
type Node struct {
|
||||
// children maps the next label (e.g., "example" in "www.example.com")
|
||||
// to the next node in the trie.
|
||||
children map[string]*Node
|
||||
// isEndOfDomain marks if this node represents the end of a valid domain.
|
||||
isEndOfDomain bool
|
||||
}
|
||||
|
||||
// Trie holds the root of the domain prefix trie.
|
||||
type Trie struct {
|
||||
root *Node
|
||||
}
|
||||
|
||||
// New creates and initializes a new Trie.
|
||||
func New() *Trie {
|
||||
return &Trie{
|
||||
root: &Node{
|
||||
children: make(map[string]*Node),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Insert adds a domain name to the trie.
|
||||
// It canonicalizes the domain name before insertion.
|
||||
func (t *Trie) Insert(domain string) error {
|
||||
// Ensure the string is a valid domain name.
|
||||
if !isValidDomainName(domain) {
|
||||
return fmt.Errorf("'%s' is not a valid domain name", domain)
|
||||
}
|
||||
|
||||
// Canonicalize the domain name (lowercase, ensure trailing dot).
|
||||
canonicalDomain := dns.CanonicalName(domain)
|
||||
|
||||
// Split the domain into label offsets. This avoids allocating new strings for labels.
|
||||
// For "www.example.com.", this returns []int{0, 4, 12}
|
||||
offsets := dns.Split(canonicalDomain)
|
||||
if len(offsets) == 0 {
|
||||
return fmt.Errorf("could not split domain name: %s", domain)
|
||||
}
|
||||
|
||||
currentNode := t.root
|
||||
// Iterate through labels from TLD to the most specific label (right to left).
|
||||
for i := len(offsets) - 1; i >= 0; i-- {
|
||||
start := offsets[i]
|
||||
var end int
|
||||
if i == len(offsets)-1 {
|
||||
// Last label, from its start to the end of the string, excluding the final dot.
|
||||
end = len(canonicalDomain) - 1
|
||||
} else {
|
||||
// Intermediate label, from its start to the character before the next label's starting dot.
|
||||
end = offsets[i+1] - 1
|
||||
}
|
||||
label := canonicalDomain[start:end]
|
||||
|
||||
if _, exists := currentNode.children[label]; !exists {
|
||||
// If the label does not exist in the children map, create a new node.
|
||||
currentNode.children[label] = &Node{children: make(map[string]*Node)}
|
||||
}
|
||||
// Move to the next node.
|
||||
currentNode = currentNode.children[label]
|
||||
}
|
||||
|
||||
// Mark the final node as the end of a domain.
|
||||
currentNode.isEndOfDomain = true
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Contains checks if a domain name exists in the trie.
|
||||
func (t *Trie) Contains(domain string) bool {
|
||||
if !isValidDomainName(domain) {
|
||||
return false
|
||||
}
|
||||
|
||||
canonicalDomain := dns.CanonicalName(domain)
|
||||
offsets := dns.Split(canonicalDomain)
|
||||
if len(offsets) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
currentNode := t.root
|
||||
for i := len(offsets) - 1; i >= 0; i-- {
|
||||
start := offsets[i]
|
||||
var end int
|
||||
if i == len(offsets)-1 {
|
||||
end = len(canonicalDomain) - 1
|
||||
} else {
|
||||
end = offsets[i+1] - 1
|
||||
}
|
||||
label := canonicalDomain[start:end]
|
||||
|
||||
nextNode, exists := currentNode.children[label]
|
||||
if !exists {
|
||||
return false
|
||||
}
|
||||
currentNode = nextNode
|
||||
}
|
||||
|
||||
return currentNode.isEndOfDomain
|
||||
}
|
||||
|
||||
// Merge combines another trie into the current one.
|
||||
//
|
||||
// Domains from the 'other' trie will be added to the current trie.
|
||||
func (t *Trie) Merge(other *Trie) {
|
||||
if other == nil || other.root == nil {
|
||||
return
|
||||
}
|
||||
mergeNodes(t.root, other.root)
|
||||
}
|
||||
|
||||
// mergeNodes is a recursive helper function to merge nodes from another trie.
|
||||
func mergeNodes(localNode, otherNode *Node) {
|
||||
// If the other node marks the end of a domain, the local node should too.
|
||||
if otherNode.isEndOfDomain {
|
||||
localNode.isEndOfDomain = true
|
||||
}
|
||||
|
||||
// Iterate over the children of the other node and merge them.
|
||||
for label, otherChildNode := range otherNode.children {
|
||||
localChildNode, exists := localNode.children[label]
|
||||
if !exists {
|
||||
// If the child doesn't exist in the local trie, just attach the other's child.
|
||||
localNode.children[label] = otherChildNode
|
||||
} else {
|
||||
// If the child already exists, recurse to merge them.
|
||||
mergeNodes(localChildNode, otherChildNode)
|
||||
}
|
||||
}
|
||||
}
|
117
dataset/dnstrie/trie_test.go
Normal file
117
dataset/dnstrie/trie_test.go
Normal file
@@ -0,0 +1,117 @@
|
||||
package dnstrie
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestTrie(t *testing.T) {
|
||||
t.Run("InsertAndContains", func(t *testing.T) {
|
||||
trie := New()
|
||||
domain := "www.example.com"
|
||||
err := trie.Insert(domain)
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error on insert, got %v", err)
|
||||
}
|
||||
|
||||
if !trie.Contains(domain) {
|
||||
t.Errorf("Expected Contains('%s') to be true, but it was false", domain)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ContainsNotFound", func(t *testing.T) {
|
||||
trie := New()
|
||||
err := trie.Insert("example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("Insert failed: %v", err)
|
||||
}
|
||||
|
||||
if trie.Contains("nonexistent.com") {
|
||||
t.Error("Expected not to find domain 'nonexistent.com', but did")
|
||||
}
|
||||
|
||||
// Check for a path that exists but is not a terminal node
|
||||
if trie.Contains("com") {
|
||||
t.Error("Expected not to find non-terminal path 'com', but did")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("InsertInvalidDomain", func(t *testing.T) {
|
||||
trie := New()
|
||||
err := trie.Insert("not-a-valid-domain-")
|
||||
if err == nil {
|
||||
t.Error("Expected an error when inserting an invalid domain, but got nil")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Canonicalization", func(t *testing.T) {
|
||||
trie := New()
|
||||
// Insert lowercase with trailing dot
|
||||
err := trie.Insert("case.example.org.")
|
||||
if err != nil {
|
||||
t.Fatalf("Insert failed: %v", err)
|
||||
}
|
||||
|
||||
// Check contains with uppercase without trailing dot
|
||||
if !trie.Contains("CASE.EXAMPLE.ORG") {
|
||||
t.Fatal("Failed to find domain with different case and no trailing dot")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("MultipleInsertions", func(t *testing.T) {
|
||||
trie := New()
|
||||
domains := []string{
|
||||
"example.com",
|
||||
"www.example.com",
|
||||
"api.example.com",
|
||||
"google.com",
|
||||
}
|
||||
for _, domain := range domains {
|
||||
if err := trie.Insert(domain); err != nil {
|
||||
t.Fatalf("Insert failed for %s: %v", domain, err)
|
||||
}
|
||||
}
|
||||
|
||||
for _, domain := range domains {
|
||||
if !trie.Contains(domain) {
|
||||
t.Errorf("Expected to find %s, but did not", domain)
|
||||
}
|
||||
}
|
||||
|
||||
if trie.Contains("ftp.example.com") {
|
||||
t.Error("Found domain 'ftp.example.com' which was not inserted")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("MergeTries", func(t *testing.T) {
|
||||
trie1 := New()
|
||||
trie1.Insert("example.com")
|
||||
trie1.Insert("sub.example.com")
|
||||
|
||||
trie2 := New()
|
||||
trie2.Insert("google.com")
|
||||
trie2.Insert("sub.example.com") // Overlapping domain
|
||||
trie2.Insert("another.net")
|
||||
|
||||
trie1.Merge(trie2)
|
||||
|
||||
// Test domains from both tries
|
||||
if !trie1.Contains("example.com") {
|
||||
t.Error("Merge failed: trie1 should contain 'example.com'")
|
||||
}
|
||||
if !trie1.Contains("google.com") {
|
||||
t.Error("Merge failed: trie1 should contain 'google.com'")
|
||||
}
|
||||
if !trie1.Contains("sub.example.com") {
|
||||
t.Error("Merge failed: trie1 should contain overlapping 'sub.example.com'")
|
||||
}
|
||||
if !trie1.Contains("another.net") {
|
||||
t.Error("Merge failed: trie1 should contain 'another.net'")
|
||||
}
|
||||
|
||||
// Ensure trie2 is not modified
|
||||
if !trie2.Contains("google.com") {
|
||||
t.Error("Source trie (trie2) should not be modified after merge")
|
||||
}
|
||||
if trie2.Contains("example.com") {
|
||||
t.Error("Source trie (trie2) was modified after merge")
|
||||
}
|
||||
})
|
||||
}
|
186
dataset/dnstrie/valuetrie.go
Normal file
186
dataset/dnstrie/valuetrie.go
Normal file
@@ -0,0 +1,186 @@
|
||||
package dnstrie
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// ValueNode represents a single node in the prefix trie, using generics for the value type.
|
||||
type ValueNode[T any] struct {
|
||||
// children maps the next label (e.g., "example" in "www.example.com")
|
||||
// to the next node in the trie.
|
||||
children map[string]*ValueNode[T]
|
||||
|
||||
// value is the data of generic type T associated with the domain name ending at this node.
|
||||
value T
|
||||
|
||||
// isEndOfDomain marks if this node represents the end of a valid domain.
|
||||
isEndOfDomain bool
|
||||
}
|
||||
|
||||
// ValueTrie holds the root of the domain prefix trie, using generics.
|
||||
type ValueTrie[T any] struct {
|
||||
root *ValueNode[T]
|
||||
}
|
||||
|
||||
// NewValue creates and initializes a new ValueTrie.
|
||||
func NewValue[T any]() *ValueTrie[T] {
|
||||
return &ValueTrie[T]{
|
||||
root: &ValueNode[T]{
|
||||
children: make(map[string]*ValueNode[T]),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Insert adds a domain name and its associated generic value to the trie.
|
||||
// It canonicalizes the domain name before insertion.
|
||||
func (t *ValueTrie[T]) Insert(domain string, value T) error {
|
||||
// Ensure the string is a valid domain name.
|
||||
if !isValidDomainName(domain) {
|
||||
return fmt.Errorf("'%s' is not a valid domain name", domain)
|
||||
}
|
||||
|
||||
// Canonicalize the domain name (lowercase, ensure trailing dot).
|
||||
canonicalDomain := dns.CanonicalName(domain)
|
||||
|
||||
// Split the domain into label offsets. This avoids allocating new strings for labels.
|
||||
// For "www.example.com.", this returns []int{0, 4, 12}
|
||||
offsets := dns.Split(canonicalDomain)
|
||||
if len(offsets) == 0 {
|
||||
return fmt.Errorf("could not split domain name: %s", domain)
|
||||
}
|
||||
|
||||
currentNode := t.root
|
||||
// Iterate through labels from TLD to the most specific label (right to left).
|
||||
for i := len(offsets) - 1; i >= 0; i-- {
|
||||
start := offsets[i]
|
||||
var end int
|
||||
if i == len(offsets)-1 {
|
||||
// Last label, from its start to the end of the string, excluding the final dot.
|
||||
end = len(canonicalDomain) - 1
|
||||
} else {
|
||||
// Intermediate label, from its start to the character before the next label's starting dot.
|
||||
end = offsets[i+1] - 1
|
||||
}
|
||||
label := canonicalDomain[start:end]
|
||||
|
||||
if _, exists := currentNode.children[label]; !exists {
|
||||
// If the label does not exist in the children map, create a new node.
|
||||
currentNode.children[label] = &ValueNode[T]{children: make(map[string]*ValueNode[T])}
|
||||
}
|
||||
// Move to the next node.
|
||||
currentNode = currentNode.children[label]
|
||||
}
|
||||
|
||||
// Mark the final node as the end of a domain and store the value.
|
||||
currentNode.isEndOfDomain = true
|
||||
currentNode.value = value
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Contains checks if a domain name exists in the trie without returning its value.
|
||||
func (t *ValueTrie[T]) Contains(domain string) bool {
|
||||
if !isValidDomainName(domain) {
|
||||
return false
|
||||
}
|
||||
|
||||
canonicalDomain := dns.CanonicalName(domain)
|
||||
offsets := dns.Split(canonicalDomain)
|
||||
if len(offsets) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
currentNode := t.root
|
||||
for i := len(offsets) - 1; i >= 0; i-- {
|
||||
start := offsets[i]
|
||||
var end int
|
||||
if i == len(offsets)-1 {
|
||||
end = len(canonicalDomain) - 1
|
||||
} else {
|
||||
end = offsets[i+1] - 1
|
||||
}
|
||||
label := canonicalDomain[start:end]
|
||||
|
||||
nextNode, exists := currentNode.children[label]
|
||||
if !exists {
|
||||
return false
|
||||
}
|
||||
currentNode = nextNode
|
||||
}
|
||||
|
||||
return currentNode.isEndOfDomain
|
||||
}
|
||||
|
||||
// Search looks for a domain name in the trie.
|
||||
// It returns the associated generic value and a boolean indicating if the domain was found.
|
||||
func (t *ValueTrie[T]) Search(domain string) (T, bool) {
|
||||
var zero T // The zero value for the generic type T.
|
||||
if !isValidDomainName(domain) {
|
||||
return zero, false
|
||||
}
|
||||
|
||||
canonicalDomain := dns.CanonicalName(domain)
|
||||
offsets := dns.Split(canonicalDomain)
|
||||
if len(offsets) == 0 {
|
||||
return zero, false
|
||||
}
|
||||
|
||||
currentNode := t.root
|
||||
for i := len(offsets) - 1; i >= 0; i-- {
|
||||
start := offsets[i]
|
||||
var end int
|
||||
if i == len(offsets)-1 {
|
||||
end = len(canonicalDomain) - 1
|
||||
} else {
|
||||
end = offsets[i+1] - 1
|
||||
}
|
||||
label := canonicalDomain[start:end]
|
||||
|
||||
nextNode, exists := currentNode.children[label]
|
||||
if !exists {
|
||||
// A label in the path was not found, so the domain doesn't exist.
|
||||
return zero, false
|
||||
}
|
||||
currentNode = nextNode
|
||||
}
|
||||
|
||||
// The full path was found, but we must also check if it's a terminal node.
|
||||
// This prevents matching "example.com" if only "www.example.com" was inserted.
|
||||
if currentNode.isEndOfDomain {
|
||||
return currentNode.value, true
|
||||
}
|
||||
|
||||
return zero, false
|
||||
}
|
||||
|
||||
// Merge combines another trie into the current one.
|
||||
//
|
||||
// If a domain exists in both tries, the value from the 'other' trie is used.
|
||||
func (t *ValueTrie[T]) Merge(other *ValueTrie[T]) {
|
||||
if other == nil || other.root == nil {
|
||||
return
|
||||
}
|
||||
mergeValueNodes(t.root, other.root)
|
||||
}
|
||||
|
||||
// mergeNodes is a recursive helper function to merge nodes from another trie.
|
||||
func mergeValueNodes[T any](localNode, otherNode *ValueNode[T]) {
|
||||
// The other node value overwrites the local one.
|
||||
if otherNode.isEndOfDomain {
|
||||
localNode.value = otherNode.value
|
||||
}
|
||||
|
||||
// Iterate over the children of the other node and merge them.
|
||||
for label, otherChildNode := range otherNode.children {
|
||||
localChildNode, exists := localNode.children[label]
|
||||
if !exists {
|
||||
// If the child doesn't exist in the local trie, attach the other's child.
|
||||
localNode.children[label] = otherChildNode
|
||||
} else {
|
||||
// If the child already exists, recurse to merge them.
|
||||
mergeValueNodes(localChildNode, otherChildNode)
|
||||
}
|
||||
}
|
||||
}
|
182
dataset/dnstrie/valuetrie_test.go
Normal file
182
dataset/dnstrie/valuetrie_test.go
Normal file
@@ -0,0 +1,182 @@
|
||||
package dnstrie
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestValueTrie(t *testing.T) {
|
||||
t.Run("InsertAndSearchStrings", func(t *testing.T) {
|
||||
trie := NewValue[string]()
|
||||
domain := "www.example.com"
|
||||
value := "192.0.2.1"
|
||||
err := trie.Insert(domain, value)
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error on insert, got %v", err)
|
||||
}
|
||||
|
||||
foundValue, found := trie.Search(domain)
|
||||
if !found {
|
||||
t.Fatalf("Expected to find domain '%s', but did not", domain)
|
||||
}
|
||||
if foundValue != value {
|
||||
t.Errorf("Expected value '%s', got '%s'", value, foundValue)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("SearchNotFound", func(t *testing.T) {
|
||||
trie := NewValue[string]()
|
||||
err := trie.Insert("example.com", "value")
|
||||
if err != nil {
|
||||
t.Fatalf("Insert failed: %v", err)
|
||||
}
|
||||
|
||||
_, found := trie.Search("nonexistent.com")
|
||||
if found {
|
||||
t.Error("Expected not to find domain 'nonexistent.com', but did")
|
||||
}
|
||||
|
||||
// Search for a path that exists but is not a terminal node
|
||||
_, found = trie.Search("com")
|
||||
if found {
|
||||
t.Error("Expected not to find non-terminal path 'com', but did")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("InsertInvalidDomain", func(t *testing.T) {
|
||||
trie := NewValue[string]()
|
||||
err := trie.Insert("not-a-valid-domain-", "value")
|
||||
if err == nil {
|
||||
t.Error("Expected an error when inserting an invalid domain, but got nil")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("OverwriteValue", func(t *testing.T) {
|
||||
trie := NewValue[string]()
|
||||
domain := "overwrite.com"
|
||||
initialValue := "first"
|
||||
overwriteValue := "second"
|
||||
|
||||
err := trie.Insert(domain, initialValue)
|
||||
if err != nil {
|
||||
t.Fatalf("Initial insert failed: %v", err)
|
||||
}
|
||||
err = trie.Insert(domain, overwriteValue)
|
||||
if err != nil {
|
||||
t.Fatalf("Overwrite insert failed: %v", err)
|
||||
}
|
||||
|
||||
foundValue, found := trie.Search(domain)
|
||||
if !found {
|
||||
t.Fatalf("Expected to find domain '%s' after overwrite", domain)
|
||||
}
|
||||
if foundValue != overwriteValue {
|
||||
t.Errorf("Expected overwritten value '%s', got '%s'", overwriteValue, foundValue)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Canonicalization", func(t *testing.T) {
|
||||
trie := NewValue[string]()
|
||||
value := "canonical"
|
||||
// Insert lowercase with trailing dot
|
||||
err := trie.Insert("case.example.org.", value)
|
||||
if err != nil {
|
||||
t.Fatalf("Insert failed: %v", err)
|
||||
}
|
||||
|
||||
// Search uppercase without trailing dot
|
||||
foundValue, found := trie.Search("CASE.EXAMPLE.ORG")
|
||||
if !found {
|
||||
t.Fatal("Failed to find domain with different case and no trailing dot")
|
||||
}
|
||||
if foundValue != value {
|
||||
t.Errorf("Expected value '%s' for canonical search, got '%s'", value, foundValue)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("InsertAndSearchIntegers", func(t *testing.T) {
|
||||
trie := NewValue[int]()
|
||||
domain := "int.example.com"
|
||||
value := 12345
|
||||
|
||||
err := trie.Insert(domain, value)
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error on insert for int trie, got %v", err)
|
||||
}
|
||||
|
||||
foundValue, found := trie.Search(domain)
|
||||
if !found {
|
||||
t.Fatalf("Expected to find domain '%s' in int trie, but did not", domain)
|
||||
}
|
||||
if foundValue != value {
|
||||
t.Errorf("Expected int value %d, got %d", value, foundValue)
|
||||
}
|
||||
|
||||
// Search for a non-existent domain in the int trie
|
||||
_, found = trie.Search("nonexistent.int.example.com")
|
||||
if found {
|
||||
t.Error("Found a nonexistent domain in the int trie")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Contains", func(t *testing.T) {
|
||||
trie := NewValue[string]()
|
||||
domain := "contains.example.com"
|
||||
err := trie.Insert(domain, "some-value")
|
||||
if err != nil {
|
||||
t.Fatalf("Insert failed: %v", err)
|
||||
}
|
||||
|
||||
if !trie.Contains(domain) {
|
||||
t.Errorf("Expected Contains('%s') to be true, but it was false", domain)
|
||||
}
|
||||
|
||||
if trie.Contains("nonexistent." + domain) {
|
||||
t.Errorf("Expected Contains for nonexistent domain to be false, but it was true")
|
||||
}
|
||||
|
||||
if trie.Contains("example.com") {
|
||||
t.Error("Expected Contains for a non-terminal path to be false, but it was true")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("MergeTries", func(t *testing.T) {
|
||||
trie1 := NewValue[int]()
|
||||
trie1.Insert("example.com", 100)
|
||||
trie1.Insert("sub.example.com", 200)
|
||||
|
||||
trie2 := NewValue[int]()
|
||||
trie2.Insert("google.com", 300)
|
||||
trie2.Insert("sub.example.com", 999) // Overlapping domain, new value
|
||||
trie2.Insert("another.net", 400)
|
||||
|
||||
trie1.Merge(trie2)
|
||||
|
||||
// Test domains from both tries are present
|
||||
if !trie1.Contains("example.com") {
|
||||
t.Error("Merge failed: trie1 should contain 'example.com'")
|
||||
}
|
||||
if !trie1.Contains("google.com") {
|
||||
t.Error("Merge failed: trie1 should contain 'google.com'")
|
||||
}
|
||||
if !trie1.Contains("another.net") {
|
||||
t.Error("Merge failed: trie1 should contain 'another.net'")
|
||||
}
|
||||
|
||||
// Test that overlapping value was updated from trie2
|
||||
val, found := trie1.Search("sub.example.com")
|
||||
if !found || val != 999 {
|
||||
t.Errorf("Expected value for overlapping domain to be 999, but got %d", val)
|
||||
}
|
||||
|
||||
// Test that non-overlapping value from trie1 is intact
|
||||
val, found = trie1.Search("example.com")
|
||||
if !found || val != 100 {
|
||||
t.Errorf("Expected value for 'example.com' to be 100, but got %d", val)
|
||||
}
|
||||
|
||||
// Ensure trie2 is not modified
|
||||
if _, found := trie2.Search("example.com"); found {
|
||||
t.Error("Source trie (trie2) was modified after merge")
|
||||
}
|
||||
})
|
||||
}
|
Reference in New Issue
Block a user