Better trie implementations

This commit is contained in:
2025-10-08 20:57:13 +02:00
parent 5f0f4aa96b
commit 582163d4be
26 changed files with 2482 additions and 122 deletions

62
dataset/dnstrie/name.go Normal file
View File

@@ -0,0 +1,62 @@
package dnstrie
import (
"strings"
"unicode"
)
// isValidDomainName reports whether host is a syntactically valid DNS hostname.
//
// A valid hostname is a series of labels separated by dots, optionally
// terminated by a single trailing dot (FQDN form). Each label must:
//   - be between 1 and 63 bytes long;
//   - consist only of ASCII letters ('a'-'z', 'A'-'Z'), digits ('0'-'9') and
//     hyphens ('-');
//   - not start or end with a hyphen.
//
// The total length, not counting the optional trailing dot, must not exceed
// 253 bytes. Non-ASCII input (for example an internationalized name that has
// not been converted to its ASCII form, or invalid UTF-8) is rejected.
func isValidDomainName(host string) bool {
	// Strip the optional root dot first so that it does not count towards the
	// 253-byte limit. Only one dot is removed, so "example.com.." still fails
	// below because it leaves an empty label.
	host = strings.TrimSuffix(host, ".")

	// Rejects "", "." and over-long names in one place.
	if host == "" || len(host) > 253 {
		return false
	}

	for _, label := range strings.Split(host, ".") {
		// An empty label means a leading dot or two consecutive dots.
		if len(label) < 1 || len(label) > 63 {
			return false
		}
		if strings.HasPrefix(label, "-") || strings.HasSuffix(label, "-") {
			return false
		}
		for _, char := range label {
			// unicode.IsLetter and unicode.IsDigit accept letters and digits
			// from every script, so anything outside ASCII is rejected first.
			// Invalid UTF-8 decodes to U+FFFD and is rejected here as well.
			if char > unicode.MaxASCII {
				return false
			}
			if !unicode.IsLetter(char) && !unicode.IsDigit(char) && char != '-' {
				return false
			}
		}
	}
	return true
}

View File

@@ -0,0 +1,42 @@
package dnstrie
import (
"strings"
"testing"
)
// TestIsValidDomainName runs isValidDomainName against a table of hostnames
// and checks the verdict for each one. Every row becomes its own subtest so a
// failure names the offending case.
func TestIsValidDomainName(t *testing.T) {
	// run builds a label made of n copies of ch.
	run := func(ch string, n int) string { return strings.Repeat(ch, n) }

	type testCase struct {
		name string
		host string
		want bool
	}

	cases := []testCase{
		{name: "Valid Hostname", host: "example.com", want: true},
		{name: "Valid with Subdomain", host: "sub.domain.co.uk", want: true},
		{name: "Valid with Hyphen", host: "my-host-name.org", want: true},
		{name: "Valid with Digits", host: "app1.server2.net", want: true},
		{name: "Valid FQDN", host: "example.com.", want: true},
		{name: "Valid Single Label", host: "localhost", want: true},
		{
			name: "Valid: Long but within limits",
			host: strings.Join([]string{run("a", 63), run("b", 63), "com"}, "."),
			want: true,
		},
		{name: "Invalid: Label Too Long", host: run("a", 64) + ".com", want: false},
		{
			// Five 60-byte labels plus ".com" is well past the 253-byte limit
			// while every individual label stays legal.
			name: "Invalid: Total Length Too Long",
			host: strings.Join([]string{run("a", 60), run("b", 60), run("c", 60), run("d", 60), run("e", 60), "com"}, "."),
			want: false,
		},
		{name: "Invalid: Starts with Hyphen", host: "-invalid.com", want: false},
		{name: "Invalid: Ends with Hyphen", host: "invalid-.com", want: false},
		{name: "Invalid: Contains Underscore", host: "my_host.com", want: false},
		{name: "Invalid: Contains Space", host: "my host.com", want: false},
		{name: "Invalid: Double Dot", host: "example..com", want: false},
		{name: "Invalid: Starts with Dot", host: ".example.com", want: false},
		{name: "Invalid: Empty Label", host: "sub..domain.com", want: false},
		{name: "Invalid: Just a Dot", host: ".", want: false},
		{name: "Invalid: Empty String", host: "", want: false},
	}

	for i := range cases {
		tc := cases[i]
		t.Run(tc.name, func(t *testing.T) {
			if got := isValidDomainName(tc.host); got != tc.want {
				t.Errorf("isValidDomainName(%q) = %v; want %v", tc.host, got, tc.want)
			}
		})
	}
}

139
dataset/dnstrie/trie.go Normal file
View File

@@ -0,0 +1,139 @@
package dnstrie
import (
"fmt"
"github.com/miekg/dns"
)
// Node represents a single node in the prefix trie.
//
// Each node corresponds to one DNS label. Insert walks a name from its TLD
// towards the most specific label, so the root's children are keyed by TLD and
// every level below adds one label to the left of the name.
type Node struct {
// children maps the next label (e.g., "example" in "www.example.com")
// to the next node in the trie. Keys are single labels without dots, taken
// from the canonicalized form of the name.
children map[string]*Node
// isEndOfDomain marks if this node represents the end of a valid domain.
// It is false for nodes that exist only as a path to a longer name
// (e.g. "com" when only "example.com" was inserted).
isEndOfDomain bool
}
// Trie holds the root of the domain prefix trie.
//
// It stores a set of domain names with no associated values. Use New to
// obtain a usable instance: Insert writes into the root's children map, which
// the zero value of Trie does not have.
type Trie struct {
// root is the label-less top of the tree; its children are the TLD labels.
root *Node
}
// New returns an empty Trie whose root node is allocated and ready to accept
// insertions.
func New() *Trie {
	root := &Node{children: map[string]*Node{}}
	return &Trie{root: root}
}
// Insert adds a domain name to the trie.
//
// The name is first checked with isValidDomainName and then canonicalized
// (lowercased, trailing dot ensured) so that lookups are insensitive to case
// and to the presence of the root dot. Labels are stored from the TLD down to
// the most specific label, creating nodes on demand.
//
// It returns an error if domain is not a valid domain name or cannot be split
// into labels; the trie is left unchanged in that case.
func (t *Trie) Insert(domain string) error {
	if !isValidDomainName(domain) {
		return fmt.Errorf("'%s' is not a valid domain name", domain)
	}

	name := dns.CanonicalName(domain)
	// labelStarts holds the byte offset at which each label begins, e.g.
	// []int{0, 4, 12} for "www.example.com.". Working with offsets means the
	// labels below are substrings of name rather than fresh allocations.
	labelStarts := dns.Split(name)
	if len(labelStarts) == 0 {
		return fmt.Errorf("could not split domain name: %s", domain)
	}

	node := t.root
	// labelEnd is the exclusive end of the label being processed. The
	// right-most label stops just before the trailing root dot.
	labelEnd := len(name) - 1
	for i := len(labelStarts) - 1; i >= 0; i-- {
		label := name[labelStarts[i]:labelEnd]

		child, found := node.children[label]
		if !found {
			child = &Node{children: make(map[string]*Node)}
			node.children[label] = child
		}
		node = child

		// The next label to the left ends one byte before this one starts,
		// which skips the separating dot.
		labelEnd = labelStarts[i] - 1
	}

	// The node reached last corresponds to the complete name.
	node.isEndOfDomain = true
	return nil
}
// Contains reports whether domain was previously inserted into the trie.
//
// The name is validated and canonicalized the same way Insert does it, so the
// lookup ignores case and an optional trailing dot. A name that only exists as
// a path towards a longer inserted name (e.g. "com" when only "example.com"
// was inserted) is not considered contained. Invalid names yield false.
func (t *Trie) Contains(domain string) bool {
	if !isValidDomainName(domain) {
		return false
	}

	name := dns.CanonicalName(domain)
	labelStarts := dns.Split(name)
	if len(labelStarts) == 0 {
		return false
	}

	node := t.root
	// Walk labels right to left; labelEnd is the exclusive end of the current
	// label and starts just before the trailing root dot.
	labelEnd := len(name) - 1
	for i := len(labelStarts) - 1; i >= 0; i-- {
		child, found := node.children[name[labelStarts[i]:labelEnd]]
		if !found {
			return false
		}
		node = child
		// Step over the dot that separates this label from the one before it.
		labelEnd = labelStarts[i] - 1
	}

	// The whole path exists; it only counts if a name actually ends here.
	return node.isEndOfDomain
}
// Merge adds every domain stored in other to the receiver.
//
// Domains already present in the receiver are kept. A nil other, or one
// without a root node, is a no-op.
func (t *Trie) Merge(other *Trie) {
	if other != nil && other.root != nil {
		mergeNodes(t.root, other.root)
	}
}
// mergeNodes recursively merges the subtree rooted at otherNode into
// localNode.
//
// A node that terminates a domain in otherNode also terminates one in
// localNode afterwards. Labels present on both sides are merged recursively.
// Labels present only in otherNode are deep-copied into localNode rather than
// attached by pointer, so the two tries never share nodes and a later Insert
// into one of them cannot change the other. otherNode is never modified.
func mergeNodes(localNode, otherNode *Node) {
	// If the other node marks the end of a domain, the local node should too.
	if otherNode.isEndOfDomain {
		localNode.isEndOfDomain = true
	}
	// Writing to a nil map panics; allocate lazily for nodes built without one.
	if localNode.children == nil && len(otherNode.children) > 0 {
		localNode.children = make(map[string]*Node, len(otherNode.children))
	}
	for label, otherChildNode := range otherNode.children {
		if localChildNode, exists := localNode.children[label]; exists {
			// Both sides know this label: merge the two subtrees.
			mergeNodes(localChildNode, otherChildNode)
			continue
		}
		// Only the other side knows this label: take a private copy of it.
		localNode.children[label] = cloneNode(otherChildNode)
	}
}

// cloneNode returns a deep copy of the subtree rooted at n.
func cloneNode(n *Node) *Node {
	clone := &Node{
		children:      make(map[string]*Node, len(n.children)),
		isEndOfDomain: n.isEndOfDomain,
	}
	for label, child := range n.children {
		clone.children[label] = cloneNode(child)
	}
	return clone
}

View File

@@ -0,0 +1,117 @@
package dnstrie
import "testing"
func TestTrie(t *testing.T) {
t.Run("InsertAndContains", func(t *testing.T) {
trie := New()
domain := "www.example.com"
err := trie.Insert(domain)
if err != nil {
t.Fatalf("Expected no error on insert, got %v", err)
}
if !trie.Contains(domain) {
t.Errorf("Expected Contains('%s') to be true, but it was false", domain)
}
})
t.Run("ContainsNotFound", func(t *testing.T) {
trie := New()
err := trie.Insert("example.com")
if err != nil {
t.Fatalf("Insert failed: %v", err)
}
if trie.Contains("nonexistent.com") {
t.Error("Expected not to find domain 'nonexistent.com', but did")
}
// Check for a path that exists but is not a terminal node
if trie.Contains("com") {
t.Error("Expected not to find non-terminal path 'com', but did")
}
})
t.Run("InsertInvalidDomain", func(t *testing.T) {
trie := New()
err := trie.Insert("not-a-valid-domain-")
if err == nil {
t.Error("Expected an error when inserting an invalid domain, but got nil")
}
})
t.Run("Canonicalization", func(t *testing.T) {
trie := New()
// Insert lowercase with trailing dot
err := trie.Insert("case.example.org.")
if err != nil {
t.Fatalf("Insert failed: %v", err)
}
// Check contains with uppercase without trailing dot
if !trie.Contains("CASE.EXAMPLE.ORG") {
t.Fatal("Failed to find domain with different case and no trailing dot")
}
})
t.Run("MultipleInsertions", func(t *testing.T) {
trie := New()
domains := []string{
"example.com",
"www.example.com",
"api.example.com",
"google.com",
}
for _, domain := range domains {
if err := trie.Insert(domain); err != nil {
t.Fatalf("Insert failed for %s: %v", domain, err)
}
}
for _, domain := range domains {
if !trie.Contains(domain) {
t.Errorf("Expected to find %s, but did not", domain)
}
}
if trie.Contains("ftp.example.com") {
t.Error("Found domain 'ftp.example.com' which was not inserted")
}
})
t.Run("MergeTries", func(t *testing.T) {
trie1 := New()
trie1.Insert("example.com")
trie1.Insert("sub.example.com")
trie2 := New()
trie2.Insert("google.com")
trie2.Insert("sub.example.com") // Overlapping domain
trie2.Insert("another.net")
trie1.Merge(trie2)
// Test domains from both tries
if !trie1.Contains("example.com") {
t.Error("Merge failed: trie1 should contain 'example.com'")
}
if !trie1.Contains("google.com") {
t.Error("Merge failed: trie1 should contain 'google.com'")
}
if !trie1.Contains("sub.example.com") {
t.Error("Merge failed: trie1 should contain overlapping 'sub.example.com'")
}
if !trie1.Contains("another.net") {
t.Error("Merge failed: trie1 should contain 'another.net'")
}
// Ensure trie2 is not modified
if !trie2.Contains("google.com") {
t.Error("Source trie (trie2) should not be modified after merge")
}
if trie2.Contains("example.com") {
t.Error("Source trie (trie2) was modified after merge")
}
})
}

View File

@@ -0,0 +1,186 @@
package dnstrie
import (
"fmt"
"github.com/miekg/dns"
)
// ValueNode represents a single node in the prefix trie, using generics for the value type.
//
// Each node corresponds to one DNS label. Insert walks a name from its TLD
// towards the most specific label, so the root's children are keyed by TLD and
// every level below adds one label to the left of the name.
type ValueNode[T any] struct {
// children maps the next label (e.g., "example" in "www.example.com")
// to the next node in the trie. Keys are single labels without dots, taken
// from the canonicalized form of the name.
children map[string]*ValueNode[T]
// value is the data of generic type T associated with the domain name ending at this node.
// It is only meaningful when isEndOfDomain is true; Search never returns it otherwise.
value T
// isEndOfDomain marks if this node represents the end of a valid domain.
// It is false for nodes that exist only as a path to a longer name.
isEndOfDomain bool
}
// ValueTrie holds the root of the domain prefix trie, using generics.
//
// It maps domain names to values of type T. Use NewValue to obtain a usable
// instance: Insert writes into the root's children map, which the zero value
// of ValueTrie does not have.
type ValueTrie[T any] struct {
// root is the label-less top of the tree; its children are the TLD labels.
root *ValueNode[T]
}
// NewValue returns an empty ValueTrie whose root node is allocated and ready
// to accept insertions.
func NewValue[T any]() *ValueTrie[T] {
	root := &ValueNode[T]{children: map[string]*ValueNode[T]{}}
	return &ValueTrie[T]{root: root}
}
// Insert stores value under the given domain name.
//
// The name is first checked with isValidDomainName and then canonicalized
// (lowercased, trailing dot ensured) so that lookups are insensitive to case
// and to the presence of the root dot. Labels are stored from the TLD down to
// the most specific label, creating nodes on demand. Inserting a name that is
// already present overwrites its value.
//
// It returns an error if domain is not a valid domain name or cannot be split
// into labels; the trie is left unchanged in that case.
func (t *ValueTrie[T]) Insert(domain string, value T) error {
	if !isValidDomainName(domain) {
		return fmt.Errorf("'%s' is not a valid domain name", domain)
	}

	name := dns.CanonicalName(domain)
	// labelStarts holds the byte offset at which each label begins, e.g.
	// []int{0, 4, 12} for "www.example.com.". Working with offsets means the
	// labels below are substrings of name rather than fresh allocations.
	labelStarts := dns.Split(name)
	if len(labelStarts) == 0 {
		return fmt.Errorf("could not split domain name: %s", domain)
	}

	node := t.root
	// labelEnd is the exclusive end of the label being processed. The
	// right-most label stops just before the trailing root dot.
	labelEnd := len(name) - 1
	for i := len(labelStarts) - 1; i >= 0; i-- {
		label := name[labelStarts[i]:labelEnd]

		child, found := node.children[label]
		if !found {
			child = &ValueNode[T]{children: make(map[string]*ValueNode[T])}
			node.children[label] = child
		}
		node = child

		// The next label to the left ends one byte before this one starts,
		// which skips the separating dot.
		labelEnd = labelStarts[i] - 1
	}

	// The node reached last corresponds to the complete name.
	node.isEndOfDomain = true
	node.value = value
	return nil
}
// Contains reports whether domain was previously inserted into the trie,
// without returning the value stored for it.
//
// The name is validated and canonicalized the same way Insert does it, so the
// lookup ignores case and an optional trailing dot. A name that only exists as
// a path towards a longer inserted name is not considered contained. Invalid
// names yield false.
func (t *ValueTrie[T]) Contains(domain string) bool {
	if !isValidDomainName(domain) {
		return false
	}

	name := dns.CanonicalName(domain)
	labelStarts := dns.Split(name)
	if len(labelStarts) == 0 {
		return false
	}

	node := t.root
	// Walk labels right to left; labelEnd is the exclusive end of the current
	// label and starts just before the trailing root dot.
	labelEnd := len(name) - 1
	for i := len(labelStarts) - 1; i >= 0; i-- {
		child, found := node.children[name[labelStarts[i]:labelEnd]]
		if !found {
			return false
		}
		node = child
		// Step over the dot that separates this label from the one before it.
		labelEnd = labelStarts[i] - 1
	}

	// The whole path exists; it only counts if a name actually ends here.
	return node.isEndOfDomain
}
// Search looks up domain and returns the value stored for it.
//
// The boolean result is true only when exactly this name was inserted. For an
// invalid name, an unknown name, or a name that merely lies on the path to a
// longer inserted name, Search returns the zero value of T and false. The
// lookup ignores case and an optional trailing dot.
func (t *ValueTrie[T]) Search(domain string) (T, bool) {
	var notFound T // zero value handed back on every miss

	if !isValidDomainName(domain) {
		return notFound, false
	}

	name := dns.CanonicalName(domain)
	labelStarts := dns.Split(name)
	if len(labelStarts) == 0 {
		return notFound, false
	}

	node := t.root
	// Walk labels right to left; labelEnd is the exclusive end of the current
	// label and starts just before the trailing root dot.
	labelEnd := len(name) - 1
	for i := len(labelStarts) - 1; i >= 0; i-- {
		child, found := node.children[name[labelStarts[i]:labelEnd]]
		if !found {
			// A label on the path is missing, so the name was never inserted.
			return notFound, false
		}
		node = child
		// Step over the dot that separates this label from the one before it.
		labelEnd = labelStarts[i] - 1
	}

	// Reaching the node is not enough: without the terminal marker this is
	// only a prefix of some longer name (e.g. "example.com" when only
	// "www.example.com" was inserted).
	if !node.isEndOfDomain {
		return notFound, false
	}
	return node.value, true
}
// Merge adds the contents of other to the receiver.
//
// If a domain exists in both tries, the value from the 'other' trie is used.
// A nil other, or one without a root node, is a no-op.
func (t *ValueTrie[T]) Merge(other *ValueTrie[T]) {
	if other != nil && other.root != nil {
		mergeValueNodes(t.root, other.root)
	}
}
// mergeValueNodes recursively merges the subtree rooted at otherNode into
// localNode.
//
// When otherNode terminates a domain, localNode is marked as terminal as well
// and takes over otherNode's value, overwriting any value it already had.
// Setting the terminal marker matters when localNode was only a path prefix
// so far: without it the merged domain would not be found by Search or
// Contains.
//
// Labels present on both sides are merged recursively. Labels present only in
// otherNode are deep-copied into localNode rather than attached by pointer, so
// the two tries never share nodes and a later Insert into one of them cannot
// change the other. Values themselves are copied by assignment; if T is a
// pointer, slice or map, both tries still refer to the same underlying data.
// otherNode is never modified.
func mergeValueNodes[T any](localNode, otherNode *ValueNode[T]) {
	// The other node's value overwrites the local one.
	if otherNode.isEndOfDomain {
		localNode.isEndOfDomain = true
		localNode.value = otherNode.value
	}
	// Writing to a nil map panics; allocate lazily for nodes built without one.
	if localNode.children == nil && len(otherNode.children) > 0 {
		localNode.children = make(map[string]*ValueNode[T], len(otherNode.children))
	}
	for label, otherChildNode := range otherNode.children {
		if localChildNode, exists := localNode.children[label]; exists {
			// Both sides know this label: merge the two subtrees.
			mergeValueNodes(localChildNode, otherChildNode)
			continue
		}
		// Only the other side knows this label: take a private copy of it.
		localNode.children[label] = cloneValueNode(otherChildNode)
	}
}

// cloneValueNode returns a deep copy of the subtree rooted at n. Node values
// are copied by assignment.
func cloneValueNode[T any](n *ValueNode[T]) *ValueNode[T] {
	clone := &ValueNode[T]{
		children:      make(map[string]*ValueNode[T], len(n.children)),
		value:         n.value,
		isEndOfDomain: n.isEndOfDomain,
	}
	for label, child := range n.children {
		clone.children[label] = cloneValueNode(child)
	}
	return clone
}

View File

@@ -0,0 +1,182 @@
package dnstrie
import (
"testing"
)
func TestValueTrie(t *testing.T) {
t.Run("InsertAndSearchStrings", func(t *testing.T) {
trie := NewValue[string]()
domain := "www.example.com"
value := "192.0.2.1"
err := trie.Insert(domain, value)
if err != nil {
t.Fatalf("Expected no error on insert, got %v", err)
}
foundValue, found := trie.Search(domain)
if !found {
t.Fatalf("Expected to find domain '%s', but did not", domain)
}
if foundValue != value {
t.Errorf("Expected value '%s', got '%s'", value, foundValue)
}
})
t.Run("SearchNotFound", func(t *testing.T) {
trie := NewValue[string]()
err := trie.Insert("example.com", "value")
if err != nil {
t.Fatalf("Insert failed: %v", err)
}
_, found := trie.Search("nonexistent.com")
if found {
t.Error("Expected not to find domain 'nonexistent.com', but did")
}
// Search for a path that exists but is not a terminal node
_, found = trie.Search("com")
if found {
t.Error("Expected not to find non-terminal path 'com', but did")
}
})
t.Run("InsertInvalidDomain", func(t *testing.T) {
trie := NewValue[string]()
err := trie.Insert("not-a-valid-domain-", "value")
if err == nil {
t.Error("Expected an error when inserting an invalid domain, but got nil")
}
})
t.Run("OverwriteValue", func(t *testing.T) {
trie := NewValue[string]()
domain := "overwrite.com"
initialValue := "first"
overwriteValue := "second"
err := trie.Insert(domain, initialValue)
if err != nil {
t.Fatalf("Initial insert failed: %v", err)
}
err = trie.Insert(domain, overwriteValue)
if err != nil {
t.Fatalf("Overwrite insert failed: %v", err)
}
foundValue, found := trie.Search(domain)
if !found {
t.Fatalf("Expected to find domain '%s' after overwrite", domain)
}
if foundValue != overwriteValue {
t.Errorf("Expected overwritten value '%s', got '%s'", overwriteValue, foundValue)
}
})
t.Run("Canonicalization", func(t *testing.T) {
trie := NewValue[string]()
value := "canonical"
// Insert lowercase with trailing dot
err := trie.Insert("case.example.org.", value)
if err != nil {
t.Fatalf("Insert failed: %v", err)
}
// Search uppercase without trailing dot
foundValue, found := trie.Search("CASE.EXAMPLE.ORG")
if !found {
t.Fatal("Failed to find domain with different case and no trailing dot")
}
if foundValue != value {
t.Errorf("Expected value '%s' for canonical search, got '%s'", value, foundValue)
}
})
t.Run("InsertAndSearchIntegers", func(t *testing.T) {
trie := NewValue[int]()
domain := "int.example.com"
value := 12345
err := trie.Insert(domain, value)
if err != nil {
t.Fatalf("Expected no error on insert for int trie, got %v", err)
}
foundValue, found := trie.Search(domain)
if !found {
t.Fatalf("Expected to find domain '%s' in int trie, but did not", domain)
}
if foundValue != value {
t.Errorf("Expected int value %d, got %d", value, foundValue)
}
// Search for a non-existent domain in the int trie
_, found = trie.Search("nonexistent.int.example.com")
if found {
t.Error("Found a nonexistent domain in the int trie")
}
})
t.Run("Contains", func(t *testing.T) {
trie := NewValue[string]()
domain := "contains.example.com"
err := trie.Insert(domain, "some-value")
if err != nil {
t.Fatalf("Insert failed: %v", err)
}
if !trie.Contains(domain) {
t.Errorf("Expected Contains('%s') to be true, but it was false", domain)
}
if trie.Contains("nonexistent." + domain) {
t.Errorf("Expected Contains for nonexistent domain to be false, but it was true")
}
if trie.Contains("example.com") {
t.Error("Expected Contains for a non-terminal path to be false, but it was true")
}
})
t.Run("MergeTries", func(t *testing.T) {
trie1 := NewValue[int]()
trie1.Insert("example.com", 100)
trie1.Insert("sub.example.com", 200)
trie2 := NewValue[int]()
trie2.Insert("google.com", 300)
trie2.Insert("sub.example.com", 999) // Overlapping domain, new value
trie2.Insert("another.net", 400)
trie1.Merge(trie2)
// Test domains from both tries are present
if !trie1.Contains("example.com") {
t.Error("Merge failed: trie1 should contain 'example.com'")
}
if !trie1.Contains("google.com") {
t.Error("Merge failed: trie1 should contain 'google.com'")
}
if !trie1.Contains("another.net") {
t.Error("Merge failed: trie1 should contain 'another.net'")
}
// Test that overlapping value was updated from trie2
val, found := trie1.Search("sub.example.com")
if !found || val != 999 {
t.Errorf("Expected value for overlapping domain to be 999, but got %d", val)
}
// Test that non-overlapping value from trie1 is intact
val, found = trie1.Search("example.com")
if !found || val != 100 {
t.Errorf("Expected value for 'example.com' to be 100, but got %d", val)
}
// Ensure trie2 is not modified
if _, found := trie2.Search("example.com"); found {
t.Error("Source trie (trie2) was modified after merge")
}
})
}