From 582163d4be7a9846a8169acd1d2d9d4a15cff1eb Mon Sep 17 00:00:00 2001 From: maze Date: Wed, 8 Oct 2025 20:57:13 +0200 Subject: [PATCH] Better trie implementations --- cmd/styx/config.go | 58 ++++- dataset/dnstrie/name.go | 62 ++++++ dataset/dnstrie/name_test.go | 42 ++++ dataset/dnstrie/trie.go | 139 ++++++++++++ dataset/dnstrie/trie_test.go | 117 ++++++++++ dataset/dnstrie/valuetrie.go | 186 ++++++++++++++++ dataset/dnstrie/valuetrie_test.go | 182 ++++++++++++++++ dataset/domain.go | 18 ++ dataset/nettrie/trie.go | 253 ++++++++++++++++++++++ dataset/nettrie/trie_test.go | 248 ++++++++++++++++++++++ dataset/nettrie/valuetrie.go | 328 ++++++++++++++++++++++++++++ dataset/nettrie/valuetrie_test.go | 340 ++++++++++++++++++++++++++++++ dataset/network.go | 27 +++ dataset/storage.go | 57 ++--- dataset/storage_bstore.go | 44 +++- dataset/storage_cache.go | 142 +++++++++++++ go.mod | 3 + policy/handler.go | 12 +- policy/input.go | 17 +- policy/policy.go | 35 ++- proxy/context.go | 43 ++-- proxy/handler.go | 13 +- proxy/policy.go | 178 ++++++++++++++++ proxy/proxy.go | 34 ++- proxy/stats.go | 25 +-- styx.hcl | 1 + 26 files changed, 2482 insertions(+), 122 deletions(-) create mode 100644 dataset/dnstrie/name.go create mode 100644 dataset/dnstrie/name_test.go create mode 100644 dataset/dnstrie/trie.go create mode 100644 dataset/dnstrie/trie_test.go create mode 100644 dataset/dnstrie/valuetrie.go create mode 100644 dataset/dnstrie/valuetrie_test.go create mode 100644 dataset/nettrie/trie.go create mode 100644 dataset/nettrie/trie_test.go create mode 100644 dataset/nettrie/valuetrie.go create mode 100644 dataset/nettrie/valuetrie_test.go create mode 100644 dataset/storage_cache.go create mode 100644 proxy/policy.go diff --git a/cmd/styx/config.go b/cmd/styx/config.go index 3bc970d..eb78d53 100644 --- a/cmd/styx/config.go +++ b/cmd/styx/config.go @@ -3,6 +3,10 @@ package main import ( "crypto/tls" "fmt" + "os" + "path/filepath" + "strings" + "time" 
"github.com/hashicorp/hcl/v2" "github.com/hashicorp/hcl/v2/gohcl" @@ -24,8 +28,25 @@ type Config struct { } func (c Config) Proxies(log logger.Structured) ([]*proxy.Proxy, error) { + log.Debug("Loading policies") policies := make(map[string]*policy.Policy) for _, pc := range c.Policy { + if !filepath.IsAbs(pc.Path) { + var err error + if pc.Path, err = filepath.Abs(pc.Path); err != nil { + return nil, fmt.Errorf("invalid policy path: %w", err) + } + } + if pc.Package == "" { + var err error + if pc.Package, err = policy.PackageFromFile(pc.Path); err != nil { + return nil, fmt.Errorf("can't determine package in %s: %w", pc.Path, err) + } + } + log.Values(logger.Values{ + "path": pc.Path, + "package": pc.Package, + }).Debug("Loading policy definition") p, err := policy.New(pc.Path, pc.Package) if err != nil { return nil, fmt.Errorf("policy %s: %w", pc.Name, err) @@ -39,6 +60,7 @@ func (c Config) Proxies(log logger.Structured) ([]*proxy.Proxy, error) { onForward []proxy.ForwardHandler onResponse []proxy.ResponseHandler ) + log.Debug("Resolving policy handlers") for _, name := range c.Proxy.On.Request { log.Value("policy", name).Debug("Resolving request policy") p, ok := policies[name] @@ -109,10 +131,18 @@ type PortTLSConfig struct { } func (c PortConfig) Proxy() (*proxy.Proxy, error) { - p := proxy.New() + log := logger.StandardLog.Value("port", c.Listen) + port := proxy.New() if c.Transparent > 0 { - p.OnConnect = append(p.OnConnect, proxy.Transparent(c.Transparent)) + log.Debug("Configuring transparent proxy handler") + port.OnConnect = append(port.OnConnect, proxy.Transparent(c.Transparent)) } else if c.TLS != nil { + if strings.ContainsRune(c.TLS.Cert, os.PathSeparator) { + log = log.Value("cert", c.TLS.Cert) + } else { + log = log.Value("cert", "") + } + log.Debug("Configuring TLS handler") cert, err := cryptutil.LoadTLSCertificate(c.TLS.Cert, c.TLS.Key) if err != nil { return nil, err @@ -121,6 +151,7 @@ func (c PortConfig) Proxy() (*proxy.Proxy, error) { 
config := new(tls.Config) config.Certificates = []tls.Certificate{cert} if c.TLS.CA != "" { + log.Value("ca", c.TLS.CA).Debug("Loading trusted roots") roots, err := cryptutil.LoadRoots(c.TLS.CA) if err != nil { return nil, err @@ -128,9 +159,9 @@ func (c PortConfig) Proxy() (*proxy.Proxy, error) { config.RootCAs = roots } - p.OnConnect = append(p.OnConnect, proxy.TLS(config)) + port.OnConnect = append(port.OnConnect, proxy.TLS(config)) } - return p, nil + return port, nil } type ProxyPolicyConfig struct { @@ -177,17 +208,23 @@ func (c DataConfig) Configure() error { return nil } -func (c DataConfig) OpenStorage() (dataset.Storage, error) { +func (c DataConfig) OpenStorage() (s dataset.Storage, err error) { + var cache time.Duration switch c.Storage.Type { case "", "bolt", "boltdb": var config struct { - Path string `hcl:"path"` + Path string `hcl:"path"` + Cache float64 `hcl:"cache,optional"` } if diag := gohcl.DecodeBody(c.Storage.Body, nil, &config); diag.HasErrors() { return nil, diag } - //return dataset.OpenBolt(config.Path) - return dataset.OpenBStore(config.Path) + if s, err = dataset.OpenBStore(config.Path); err != nil { + return + } + if config.Cache > 0 { + cache = time.Duration(config.Cache * float64(time.Second)) + } /* case "sqlite", "sqlite3": @@ -203,6 +240,11 @@ func (c DataConfig) OpenStorage() (dataset.Storage, error) { default: return nil, fmt.Errorf("storage: no %q driver", c.Storage.Type) } + + if s != nil && cache > 0 { + return dataset.Cache(s, cache), nil + } + return } type DataStorageConfig struct { diff --git a/dataset/dnstrie/name.go b/dataset/dnstrie/name.go new file mode 100644 index 0000000..2281ad4 --- /dev/null +++ b/dataset/dnstrie/name.go @@ -0,0 +1,62 @@ +package dnstrie + +import ( + "strings" + "unicode" +) + +// isValidDomainName validates if the given string is a valid DNS hostname. +// A valid hostname consists of a series of labels separated by dots. +// Each label must: +// - Be between 1 and 63 characters long. 
+// - Consist only of ASCII letters ('a'-'z', 'A'-'Z'), digits ('0'-'9'), and hyphens ('-'). +// - Not start or end with a hyphen. +// The total length of the hostname (including dots) must not exceed 253 characters. +func isValidDomainName(host string) bool { + // 1. Check total length. The maximum length of a full hostname is 253 characters. + if len(host) > 253 { + return false + } + + // An empty string is not a valid hostname. + if host == "" { + return false + } + + // 2. Handle optional trailing dot for FQDNs. + // If the hostname ends with a dot, we remove it for validation purposes. + if strings.HasSuffix(host, ".") { + host = host[:len(host)-1] + } + + // After removing a potential trailing dot, the string might be empty. + if host == "" { + return false + } + + // 3. Split the hostname into labels. + labels := strings.Split(host, ".") + + // 4. Validate each label. + for _, label := range labels { + // a. Check label length (1 to 63 characters). + if len(label) < 1 || len(label) > 63 { + return false + } + + // b. Check if label starts or ends with a hyphen. + if strings.HasPrefix(label, "-") || strings.HasSuffix(label, "-") { + return false + } + + // c. Check for allowed characters in the label. + for _, char := range label { + if !unicode.IsLetter(char) && !unicode.IsDigit(char) && char != '-' { + return false + } + } + } + + // If all checks pass, the hostname is valid. 
+ return true +} diff --git a/dataset/dnstrie/name_test.go b/dataset/dnstrie/name_test.go new file mode 100644 index 0000000..fbceaaa --- /dev/null +++ b/dataset/dnstrie/name_test.go @@ -0,0 +1,42 @@ +package dnstrie + +import ( + "strings" + "testing" +) + +func TestIsValidDomainName(t *testing.T) { + tests := []struct { + name string + host string + expected bool + }{ + {"Valid Hostname", "example.com", true}, + {"Valid with Subdomain", "sub.domain.co.uk", true}, + {"Valid with Hyphen", "my-host-name.org", true}, + {"Valid with Digits", "app1.server2.net", true}, + {"Valid FQDN", "example.com.", true}, + {"Valid Single Label", "localhost", true}, + {"Valid: Long but within limits", strings.Repeat("a", 63) + "." + strings.Repeat("b", 63) + ".com", true}, + {"Invalid: Label Too Long", strings.Repeat("a", 64) + ".com", false}, + {"Invalid: Total Length Too Long", strings.Repeat("a", 60) + "." + strings.Repeat("b", 60) + "." + strings.Repeat("c", 60) + "." + strings.Repeat("d", 60) + "." + strings.Repeat("e", 60) + ".com", false}, + {"Invalid: Starts with Hyphen", "-invalid.com", false}, + {"Invalid: Ends with Hyphen", "invalid-.com", false}, + {"Invalid: Contains Underscore", "my_host.com", false}, + {"Invalid: Contains Space", "my host.com", false}, + {"Invalid: Double Dot", "example..com", false}, + {"Invalid: Starts with Dot", ".example.com", false}, + {"Invalid: Empty Label", "sub..domain.com", false}, + {"Invalid: Just a Dot", ".", false}, + {"Invalid: Empty String", "", false}, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + got := isValidDomainName(test.host) + if got != test.expected { + t.Errorf("isValidDomainName(%q) = %v; want %v", test.host, got, test.expected) + } + }) + } +} diff --git a/dataset/dnstrie/trie.go b/dataset/dnstrie/trie.go new file mode 100644 index 0000000..c8f5b45 --- /dev/null +++ b/dataset/dnstrie/trie.go @@ -0,0 +1,139 @@ +package dnstrie + +import ( + "fmt" + + "github.com/miekg/dns" +) + +// Node 
represents a single node in the prefix trie. +type Node struct { + // children maps the next label (e.g., "example" in "www.example.com") + // to the next node in the trie. + children map[string]*Node + // isEndOfDomain marks if this node represents the end of a valid domain. + isEndOfDomain bool +} + +// Trie holds the root of the domain prefix trie. +type Trie struct { + root *Node +} + +// New creates and initializes a new Trie. +func New() *Trie { + return &Trie{ + root: &Node{ + children: make(map[string]*Node), + }, + } +} + +// Insert adds a domain name to the trie. +// It canonicalizes the domain name before insertion. +func (t *Trie) Insert(domain string) error { + // Ensure the string is a valid domain name. + if !isValidDomainName(domain) { + return fmt.Errorf("'%s' is not a valid domain name", domain) + } + + // Canonicalize the domain name (lowercase, ensure trailing dot). + canonicalDomain := dns.CanonicalName(domain) + + // Split the domain into label offsets. This avoids allocating new strings for labels. + // For "www.example.com.", this returns []int{0, 4, 12} + offsets := dns.Split(canonicalDomain) + if len(offsets) == 0 { + return fmt.Errorf("could not split domain name: %s", domain) + } + + currentNode := t.root + // Iterate through labels from TLD to the most specific label (right to left). + for i := len(offsets) - 1; i >= 0; i-- { + start := offsets[i] + var end int + if i == len(offsets)-1 { + // Last label, from its start to the end of the string, excluding the final dot. + end = len(canonicalDomain) - 1 + } else { + // Intermediate label, from its start to the character before the next label's starting dot. + end = offsets[i+1] - 1 + } + label := canonicalDomain[start:end] + + if _, exists := currentNode.children[label]; !exists { + // If the label does not exist in the children map, create a new node. + currentNode.children[label] = &Node{children: make(map[string]*Node)} + } + // Move to the next node. 
+ currentNode = currentNode.children[label] + } + + // Mark the final node as the end of a domain. + currentNode.isEndOfDomain = true + + return nil +} + +// Contains checks if a domain name exists in the trie. +func (t *Trie) Contains(domain string) bool { + if !isValidDomainName(domain) { + return false + } + + canonicalDomain := dns.CanonicalName(domain) + offsets := dns.Split(canonicalDomain) + if len(offsets) == 0 { + return false + } + + currentNode := t.root + for i := len(offsets) - 1; i >= 0; i-- { + start := offsets[i] + var end int + if i == len(offsets)-1 { + end = len(canonicalDomain) - 1 + } else { + end = offsets[i+1] - 1 + } + label := canonicalDomain[start:end] + + nextNode, exists := currentNode.children[label] + if !exists { + return false + } + currentNode = nextNode + } + + return currentNode.isEndOfDomain +} + +// Merge combines another trie into the current one. +// +// Domains from the 'other' trie will be added to the current trie. +func (t *Trie) Merge(other *Trie) { + if other == nil || other.root == nil { + return + } + mergeNodes(t.root, other.root) +} + +// mergeNodes is a recursive helper function to merge nodes from another trie. +func mergeNodes(localNode, otherNode *Node) { + // If the other node marks the end of a domain, the local node should too. + if otherNode.isEndOfDomain { + localNode.isEndOfDomain = true + } + + // Iterate over the children of the other node and merge them. + for label, otherChildNode := range otherNode.children { + localChildNode, exists := localNode.children[label] + if !exists { + // If the child doesn't exist in the local trie, just attach the other's child. + localNode.children[label] = otherChildNode + } else { + // If the child already exists, recurse to merge them. 
+ mergeNodes(localChildNode, otherChildNode) + } + } +} diff --git a/dataset/dnstrie/trie_test.go b/dataset/dnstrie/trie_test.go new file mode 100644 index 0000000..2a9fae7 --- /dev/null +++ b/dataset/dnstrie/trie_test.go @@ -0,0 +1,117 @@ +package dnstrie + +import "testing" + +func TestTrie(t *testing.T) { + t.Run("InsertAndContains", func(t *testing.T) { + trie := New() + domain := "www.example.com" + err := trie.Insert(domain) + if err != nil { + t.Fatalf("Expected no error on insert, got %v", err) + } + + if !trie.Contains(domain) { + t.Errorf("Expected Contains('%s') to be true, but it was false", domain) + } + }) + + t.Run("ContainsNotFound", func(t *testing.T) { + trie := New() + err := trie.Insert("example.com") + if err != nil { + t.Fatalf("Insert failed: %v", err) + } + + if trie.Contains("nonexistent.com") { + t.Error("Expected not to find domain 'nonexistent.com', but did") + } + + // Check for a path that exists but is not a terminal node + if trie.Contains("com") { + t.Error("Expected not to find non-terminal path 'com', but did") + } + }) + + t.Run("InsertInvalidDomain", func(t *testing.T) { + trie := New() + err := trie.Insert("not-a-valid-domain-") + if err == nil { + t.Error("Expected an error when inserting an invalid domain, but got nil") + } + }) + + t.Run("Canonicalization", func(t *testing.T) { + trie := New() + // Insert lowercase with trailing dot + err := trie.Insert("case.example.org.") + if err != nil { + t.Fatalf("Insert failed: %v", err) + } + + // Check contains with uppercase without trailing dot + if !trie.Contains("CASE.EXAMPLE.ORG") { + t.Fatal("Failed to find domain with different case and no trailing dot") + } + }) + + t.Run("MultipleInsertions", func(t *testing.T) { + trie := New() + domains := []string{ + "example.com", + "www.example.com", + "api.example.com", + "google.com", + } + for _, domain := range domains { + if err := trie.Insert(domain); err != nil { + t.Fatalf("Insert failed for %s: %v", domain, err) + } + } + + 
for _, domain := range domains { + if !trie.Contains(domain) { + t.Errorf("Expected to find %s, but did not", domain) + } + } + + if trie.Contains("ftp.example.com") { + t.Error("Found domain 'ftp.example.com' which was not inserted") + } + }) + + t.Run("MergeTries", func(t *testing.T) { + trie1 := New() + trie1.Insert("example.com") + trie1.Insert("sub.example.com") + + trie2 := New() + trie2.Insert("google.com") + trie2.Insert("sub.example.com") // Overlapping domain + trie2.Insert("another.net") + + trie1.Merge(trie2) + + // Test domains from both tries + if !trie1.Contains("example.com") { + t.Error("Merge failed: trie1 should contain 'example.com'") + } + if !trie1.Contains("google.com") { + t.Error("Merge failed: trie1 should contain 'google.com'") + } + if !trie1.Contains("sub.example.com") { + t.Error("Merge failed: trie1 should contain overlapping 'sub.example.com'") + } + if !trie1.Contains("another.net") { + t.Error("Merge failed: trie1 should contain 'another.net'") + } + + // Ensure trie2 is not modified + if !trie2.Contains("google.com") { + t.Error("Source trie (trie2) should not be modified after merge") + } + if trie2.Contains("example.com") { + t.Error("Source trie (trie2) was modified after merge") + } + }) +} diff --git a/dataset/dnstrie/valuetrie.go b/dataset/dnstrie/valuetrie.go new file mode 100644 index 0000000..8efe460 --- /dev/null +++ b/dataset/dnstrie/valuetrie.go @@ -0,0 +1,186 @@ +package dnstrie + +import ( + "fmt" + + "github.com/miekg/dns" +) + +// ValueNode represents a single node in the prefix trie, using generics for the value type. +type ValueNode[T any] struct { + // children maps the next label (e.g., "example" in "www.example.com") + // to the next node in the trie. + children map[string]*ValueNode[T] + + // value is the data of generic type T associated with the domain name ending at this node. + value T + + // isEndOfDomain marks if this node represents the end of a valid domain. 
+ isEndOfDomain bool +} + +// ValueTrie holds the root of the domain prefix trie, using generics. +type ValueTrie[T any] struct { + root *ValueNode[T] +} + +// NewValue creates and initializes a new ValueTrie. +func NewValue[T any]() *ValueTrie[T] { + return &ValueTrie[T]{ + root: &ValueNode[T]{ + children: make(map[string]*ValueNode[T]), + }, + } +} + +// Insert adds a domain name and its associated generic value to the trie. +// It canonicalizes the domain name before insertion. +func (t *ValueTrie[T]) Insert(domain string, value T) error { + // Ensure the string is a valid domain name. + if !isValidDomainName(domain) { + return fmt.Errorf("'%s' is not a valid domain name", domain) + } + + // Canonicalize the domain name (lowercase, ensure trailing dot). + canonicalDomain := dns.CanonicalName(domain) + + // Split the domain into label offsets. This avoids allocating new strings for labels. + // For "www.example.com.", this returns []int{0, 4, 12} + offsets := dns.Split(canonicalDomain) + if len(offsets) == 0 { + return fmt.Errorf("could not split domain name: %s", domain) + } + + currentNode := t.root + // Iterate through labels from TLD to the most specific label (right to left). + for i := len(offsets) - 1; i >= 0; i-- { + start := offsets[i] + var end int + if i == len(offsets)-1 { + // Last label, from its start to the end of the string, excluding the final dot. + end = len(canonicalDomain) - 1 + } else { + // Intermediate label, from its start to the character before the next label's starting dot. + end = offsets[i+1] - 1 + } + label := canonicalDomain[start:end] + + if _, exists := currentNode.children[label]; !exists { + // If the label does not exist in the children map, create a new node. + currentNode.children[label] = &ValueNode[T]{children: make(map[string]*ValueNode[T])} + } + // Move to the next node. + currentNode = currentNode.children[label] + } + + // Mark the final node as the end of a domain and store the value. 
+ currentNode.isEndOfDomain = true + currentNode.value = value + + return nil +} + +// Contains checks if a domain name exists in the trie without returning its value. +func (t *ValueTrie[T]) Contains(domain string) bool { + if !isValidDomainName(domain) { + return false + } + + canonicalDomain := dns.CanonicalName(domain) + offsets := dns.Split(canonicalDomain) + if len(offsets) == 0 { + return false + } + + currentNode := t.root + for i := len(offsets) - 1; i >= 0; i-- { + start := offsets[i] + var end int + if i == len(offsets)-1 { + end = len(canonicalDomain) - 1 + } else { + end = offsets[i+1] - 1 + } + label := canonicalDomain[start:end] + + nextNode, exists := currentNode.children[label] + if !exists { + return false + } + currentNode = nextNode + } + + return currentNode.isEndOfDomain +} + +// Search looks for a domain name in the trie. +// It returns the associated generic value and a boolean indicating if the domain was found. +func (t *ValueTrie[T]) Search(domain string) (T, bool) { + var zero T // The zero value for the generic type T. + if !isValidDomainName(domain) { + return zero, false + } + + canonicalDomain := dns.CanonicalName(domain) + offsets := dns.Split(canonicalDomain) + if len(offsets) == 0 { + return zero, false + } + + currentNode := t.root + for i := len(offsets) - 1; i >= 0; i-- { + start := offsets[i] + var end int + if i == len(offsets)-1 { + end = len(canonicalDomain) - 1 + } else { + end = offsets[i+1] - 1 + } + label := canonicalDomain[start:end] + + nextNode, exists := currentNode.children[label] + if !exists { + // A label in the path was not found, so the domain doesn't exist. + return zero, false + } + currentNode = nextNode + } + + // The full path was found, but we must also check if it's a terminal node. + // This prevents matching "example.com" if only "www.example.com" was inserted. 
+ if currentNode.isEndOfDomain { + return currentNode.value, true + } + + return zero, false +} + +// Merge combines another trie into the current one. +// +// If a domain exists in both tries, the value from the 'other' trie is used. +func (t *ValueTrie[T]) Merge(other *ValueTrie[T]) { + if other == nil || other.root == nil { + return + } + mergeValueNodes(t.root, other.root) +} + +// mergeNodes is a recursive helper function to merge nodes from another trie. +func mergeValueNodes[T any](localNode, otherNode *ValueNode[T]) { + // The other node value overwrites the local one. + if otherNode.isEndOfDomain { + localNode.value = otherNode.value + } + + // Iterate over the children of the other node and merge them. + for label, otherChildNode := range otherNode.children { + localChildNode, exists := localNode.children[label] + if !exists { + // If the child doesn't exist in the local trie, attach the other's child. + localNode.children[label] = otherChildNode + } else { + // If the child already exists, recurse to merge them. 
+ mergeValueNodes(localChildNode, otherChildNode) + } + } +} diff --git a/dataset/dnstrie/valuetrie_test.go b/dataset/dnstrie/valuetrie_test.go new file mode 100644 index 0000000..63f226e --- /dev/null +++ b/dataset/dnstrie/valuetrie_test.go @@ -0,0 +1,182 @@ +package dnstrie + +import ( + "testing" +) + +func TestValueTrie(t *testing.T) { + t.Run("InsertAndSearchStrings", func(t *testing.T) { + trie := NewValue[string]() + domain := "www.example.com" + value := "192.0.2.1" + err := trie.Insert(domain, value) + if err != nil { + t.Fatalf("Expected no error on insert, got %v", err) + } + + foundValue, found := trie.Search(domain) + if !found { + t.Fatalf("Expected to find domain '%s', but did not", domain) + } + if foundValue != value { + t.Errorf("Expected value '%s', got '%s'", value, foundValue) + } + }) + + t.Run("SearchNotFound", func(t *testing.T) { + trie := NewValue[string]() + err := trie.Insert("example.com", "value") + if err != nil { + t.Fatalf("Insert failed: %v", err) + } + + _, found := trie.Search("nonexistent.com") + if found { + t.Error("Expected not to find domain 'nonexistent.com', but did") + } + + // Search for a path that exists but is not a terminal node + _, found = trie.Search("com") + if found { + t.Error("Expected not to find non-terminal path 'com', but did") + } + }) + + t.Run("InsertInvalidDomain", func(t *testing.T) { + trie := NewValue[string]() + err := trie.Insert("not-a-valid-domain-", "value") + if err == nil { + t.Error("Expected an error when inserting an invalid domain, but got nil") + } + }) + + t.Run("OverwriteValue", func(t *testing.T) { + trie := NewValue[string]() + domain := "overwrite.com" + initialValue := "first" + overwriteValue := "second" + + err := trie.Insert(domain, initialValue) + if err != nil { + t.Fatalf("Initial insert failed: %v", err) + } + err = trie.Insert(domain, overwriteValue) + if err != nil { + t.Fatalf("Overwrite insert failed: %v", err) + } + + foundValue, found := trie.Search(domain) + if !found 
{ + t.Fatalf("Expected to find domain '%s' after overwrite", domain) + } + if foundValue != overwriteValue { + t.Errorf("Expected overwritten value '%s', got '%s'", overwriteValue, foundValue) + } + }) + + t.Run("Canonicalization", func(t *testing.T) { + trie := NewValue[string]() + value := "canonical" + // Insert lowercase with trailing dot + err := trie.Insert("case.example.org.", value) + if err != nil { + t.Fatalf("Insert failed: %v", err) + } + + // Search uppercase without trailing dot + foundValue, found := trie.Search("CASE.EXAMPLE.ORG") + if !found { + t.Fatal("Failed to find domain with different case and no trailing dot") + } + if foundValue != value { + t.Errorf("Expected value '%s' for canonical search, got '%s'", value, foundValue) + } + }) + + t.Run("InsertAndSearchIntegers", func(t *testing.T) { + trie := NewValue[int]() + domain := "int.example.com" + value := 12345 + + err := trie.Insert(domain, value) + if err != nil { + t.Fatalf("Expected no error on insert for int trie, got %v", err) + } + + foundValue, found := trie.Search(domain) + if !found { + t.Fatalf("Expected to find domain '%s' in int trie, but did not", domain) + } + if foundValue != value { + t.Errorf("Expected int value %d, got %d", value, foundValue) + } + + // Search for a non-existent domain in the int trie + _, found = trie.Search("nonexistent.int.example.com") + if found { + t.Error("Found a nonexistent domain in the int trie") + } + }) + + t.Run("Contains", func(t *testing.T) { + trie := NewValue[string]() + domain := "contains.example.com" + err := trie.Insert(domain, "some-value") + if err != nil { + t.Fatalf("Insert failed: %v", err) + } + + if !trie.Contains(domain) { + t.Errorf("Expected Contains('%s') to be true, but it was false", domain) + } + + if trie.Contains("nonexistent." 
+ domain) { + t.Errorf("Expected Contains for nonexistent domain to be false, but it was true") + } + + if trie.Contains("example.com") { + t.Error("Expected Contains for a non-terminal path to be false, but it was true") + } + }) + + t.Run("MergeTries", func(t *testing.T) { + trie1 := NewValue[int]() + trie1.Insert("example.com", 100) + trie1.Insert("sub.example.com", 200) + + trie2 := NewValue[int]() + trie2.Insert("google.com", 300) + trie2.Insert("sub.example.com", 999) // Overlapping domain, new value + trie2.Insert("another.net", 400) + + trie1.Merge(trie2) + + // Test domains from both tries are present + if !trie1.Contains("example.com") { + t.Error("Merge failed: trie1 should contain 'example.com'") + } + if !trie1.Contains("google.com") { + t.Error("Merge failed: trie1 should contain 'google.com'") + } + if !trie1.Contains("another.net") { + t.Error("Merge failed: trie1 should contain 'another.net'") + } + + // Test that overlapping value was updated from trie2 + val, found := trie1.Search("sub.example.com") + if !found || val != 999 { + t.Errorf("Expected value for overlapping domain to be 999, but got %d", val) + } + + // Test that non-overlapping value from trie1 is intact + val, found = trie1.Search("example.com") + if !found || val != 100 { + t.Errorf("Expected value for 'example.com' to be 100, but got %d", val) + } + + // Ensure trie2 is not modified + if _, found := trie2.Search("example.com"); found { + t.Error("Source trie (trie2) was modified after merge") + } + }) +} diff --git a/dataset/domain.go b/dataset/domain.go index d5bfabd..2b043f5 100644 --- a/dataset/domain.go +++ b/dataset/domain.go @@ -1,11 +1,29 @@ package dataset import ( + "fmt" "strings" + "git.maze.io/maze/styx/dataset/dnstrie" "github.com/miekg/dns" ) +type DomainTrie struct { + *dnstrie.ValueTrie[bool] +} + +func NewDomainTrie(permit bool, domains ...string) (*DomainTrie, error) { + trie := &DomainTrie{ + ValueTrie: dnstrie.NewValue[bool](), + } + for _, domain := range 
domains { + if err := trie.Insert(domain, permit); err != nil { + return nil, fmt.Errorf("dataset: error inserting %s: %w", domain, err) + } + } + return trie, nil +} + type DomainTree struct { root *domainTreeNode } diff --git a/dataset/nettrie/trie.go b/dataset/nettrie/trie.go new file mode 100644 index 0000000..c33512e --- /dev/null +++ b/dataset/nettrie/trie.go @@ -0,0 +1,253 @@ +package nettrie + +import "net/netip" + +// Node represents a node in the path-compressed trie. +// Each node represents a prefix and can have up to two children. +type Node struct { + children [2]*Node + + // prefix is the full prefix represented by the path to this node. + prefix netip.Prefix + + // isValue marks if this node represents an explicitly inserted prefix. + isValue bool +} + +// Trie is a path-compressed radix trie that stores network prefixes. +type Trie struct { + rootV4 *Node + rootV6 *Node +} + +// New creates and initializes a new Trie. +func New() *Trie { + return &Trie{} +} + +// Insert adds a prefix to the trie. +func (t *Trie) Insert(p netip.Prefix) { + p = p.Masked() + addr := p.Addr() + + if addr.Is4() { + t.rootV4 = t.insert(t.rootV4, p) + } else { + t.rootV6 = t.insert(t.rootV6, p) + } +} + +// insert is the recursive helper for inserting a prefix into the trie. +func (t *Trie) insert(node *Node, p netip.Prefix) *Node { + if node == nil { + return &Node{prefix: p, isValue: true} + } + + addr := p.Addr() + commonLen := commonPrefixLen(addr, node.prefix.Addr()) + pBits := p.Bits() + nodeBits := node.prefix.Bits() + + if commonLen > pBits { + commonLen = pBits + } + if commonLen > nodeBits { + commonLen = nodeBits + } + + if commonLen == nodeBits && commonLen == pBits { + // Exact match, mark the node as a value node. + node.isValue = true + return node + } + + if commonLen < nodeBits { + // The new prefix diverges from the current node's prefix. + // We must split the current node. 
+		commonP, _ := node.prefix.Addr().Prefix(commonLen)
+		splitNode := &Node{prefix: commonP}
+
+		// The existing node becomes a child of the new split node.
+		bit := getBit(node.prefix.Addr(), commonLen)
+		splitNode.children[bit] = node
+
+		if commonLen == pBits {
+			// The inserted prefix is a prefix of the node's original prefix.
+			// The new split node represents the inserted prefix.
+			splitNode.isValue = true
+		} else {
+			// The two prefixes diverge. Create a new child for the new prefix.
+			bit := getBit(addr, commonLen)
+			splitNode.children[bit] = &Node{prefix: p, isValue: true}
+		}
+		return splitNode
+	}
+
+	// commonLen == nodeBits, meaning the current node's prefix is a prefix of the new one.
+	// We need to descend to a child.
+	bit := getBit(addr, commonLen)
+	node.children[bit] = t.insert(node.children[bit], p)
+	return node
+}
+
+// Delete removes a prefix from the trie. It returns true if the prefix was found and removed.
+func (t *Trie) Delete(p netip.Prefix) bool {
+	p = p.Masked()
+	addr := p.Addr()
+
+	var changed bool
+	if addr.Is4() {
+		t.rootV4, changed = t.delete(t.rootV4, p)
+	} else {
+		t.rootV6, changed = t.delete(t.rootV6, p)
+	}
+	return changed
+}
+
+// delete is the recursive helper for removing a prefix from the trie.
+func (t *Trie) delete(node *Node, p netip.Prefix) (*Node, bool) {
+	if node == nil {
+		return nil, false
+	}
+
+	addr := p.Addr()
+	pBits := p.Bits()
+	nodeBits := node.prefix.Bits()
+	commonLen := commonPrefixLen(addr, node.prefix.Addr())
+
+	// The node's prefix does not cover p, so p cannot be in this subtree.
+	if commonLen < nodeBits {
+		return node, false
+	}
+
+	var changed bool
+	if pBits > nodeBits {
+		// The prefix to delete is deeper in the trie. Recurse.
+		bit := getBit(addr, nodeBits)
+		node.children[bit], changed = t.delete(node.children[bit], p)
+	} else if pBits == nodeBits {
+		// This is the node to delete. Unset its value.
+		if !node.isValue {
+			return node, false // Prefix wasn't actually in the trie.
+		}
+		node.isValue = false
+		changed = true
+	} else { // pBits < nodeBits
+		return node, false // Prefix to delete is shorter, so can't be here.
+	}
+
+	if !changed {
+		return node, false
+	}
+
+	// Post-deletion cleanup (restores path compression):
+	// If the node has no value and can be merged with a single child, do so.
+	if !node.isValue {
+		if node.children[0] != nil && node.children[1] == nil {
+			return node.children[0], true
+		}
+		if node.children[0] == nil && node.children[1] != nil {
+			return node.children[1], true
+		}
+	}
+
+	// If the node is now a leaf without a value, it can be removed entirely.
+	if !node.isValue && node.children[0] == nil && node.children[1] == nil {
+		return nil, true
+	}
+
+	return node, true
+}
+
+// ContainsPrefix checks if the exact prefix exists in the trie.
+func (t *Trie) ContainsPrefix(p netip.Prefix) bool {
+	p = p.Masked()
+	addr := p.Addr()
+	pBits := p.Bits()
+
+	node := t.rootV4
+	if addr.Is6() {
+		node = t.rootV6
+	}
+
+	for node != nil {
+		commonLen := commonPrefixLen(addr, node.prefix.Addr())
+		nodeBits := node.prefix.Bits()
+
+		if commonLen < nodeBits {
+			// Path has diverged. The prefix cannot be in this subtree.
+			return false
+		}
+
+		if pBits < nodeBits {
+			// The search prefix is shorter than the node's prefix,
+			// but they share a prefix. e.g. search /16, node is /24.
+			// The /16 is not explicitly in the trie.
+			return false
+		}
+
+		if pBits == nodeBits {
+			// Found a node with the exact same prefix length.
+			// Because we also know commonLen >= nodeBits, the prefixes are identical.
+			return node.isValue
+		}
+
+		// pBits > nodeBits, so we need to go deeper.
+		bit := getBit(addr, nodeBits)
+		node = node.children[bit]
+	}
+
+	return false
+}
+
+// Contains checks if the exact IP address exists in the trie as a full-length prefix.
+func (t *Trie) Contains(addr netip.Addr) bool {
+	prefix := netip.PrefixFrom(addr, addr.BitLen())
+	return t.ContainsPrefix(prefix)
+}
+
+// WalkFunc is a function called for each prefix in the trie during a walk.
+// Returning false from the function will stop the walk.
+type WalkFunc func(p netip.Prefix) bool
+
+// walk is the recursive helper for traversing the trie.
+func walk(node *Node, f WalkFunc) bool {
+	if node == nil {
+		return true
+	}
+
+	if node.isValue {
+		if !f(node.prefix) {
+			return false
+		}
+	}
+
+	if node.children[0] != nil {
+		if !walk(node.children[0], f) {
+			return false
+		}
+	}
+	if node.children[1] != nil {
+		if !walk(node.children[1], f) {
+			return false
+		}
+	}
+	return true
+}
+
+// Walk traverses the trie and calls the given function for each prefix.
+// If the function returns false, the walk is stopped. The order is not guaranteed.
+func (t *Trie) Walk(f WalkFunc) {
+	if !walk(t.rootV4, f) {
+		return
+	}
+	walk(t.rootV6, f)
+}
+
+// Merge inserts all prefixes from another Trie into this one.
+func (t *Trie) Merge(other *Trie) {
+	other.Walk(func(p netip.Prefix) bool {
+		t.Insert(p)
+		return true // continue walking
+	})
+}
diff --git a/dataset/nettrie/trie_test.go b/dataset/nettrie/trie_test.go
new file mode 100644
index 0000000..bf979be
--- /dev/null
+++ b/dataset/nettrie/trie_test.go
@@ -0,0 +1,248 @@
+package nettrie
+
+import (
+	"net/netip"
+	"sort"
+	"testing"
+)
+
+// TestTrie_InsertAndContains_IPv4 tests insertion and verifies presence
+// using ContainsPrefix and Contains for IPv4.
+func TestTrie_InsertAndContains_IPv4(t *testing.T) { + trie := New() + + prefixes := []string{ + "0.0.0.0/0", + "10.0.0.0/8", + "192.168.0.0/16", + "192.168.1.0/24", + "192.168.1.128/25", + "192.168.1.200/32", + } + for _, s := range prefixes { + trie.Insert(netip.MustParsePrefix(s)) + } + + t.Run("Check existing prefixes", func(t *testing.T) { + for _, s := range prefixes { + p := netip.MustParsePrefix(s) + if !trie.ContainsPrefix(p) { + t.Errorf("expected trie to contain prefix %s, but it did not", p) + } + } + }) + + t.Run("Check non-existing prefixes", func(t *testing.T) { + nonExistentPrefixes := []string{ + "10.0.0.0/9", + "192.168.1.0/23", + "1.2.3.4/32", + } + for _, s := range nonExistentPrefixes { + p := netip.MustParsePrefix(s) + if trie.ContainsPrefix(p) { + t.Errorf("expected trie to not contain prefix %s, but it did", p) + } + } + }) + + t.Run("Check host addresses with Contains", func(t *testing.T) { + if !trie.Contains(netip.MustParseAddr("192.168.1.200")) { + t.Error("expected trie to contain host address 192.168.1.200, but it did not") + } + if trie.Contains(netip.MustParseAddr("192.168.1.201")) { + t.Error("expected trie to not contain host address 192.168.1.201, but it did") + } + }) +} + +// TestTrie_InsertAndContains_IPv6 tests insertion and verifies presence +// using ContainsPrefix and Contains for IPv6. 
+func TestTrie_InsertAndContains_IPv6(t *testing.T) { + trie := New() + + prefixes := []string{ + "::/0", + "2001:db8::/32", + "2001:db8:acad::/48", + "2001:db8:acad:1::/64", + "2001:db8:acad:1::1/128", + } + for _, s := range prefixes { + trie.Insert(netip.MustParsePrefix(s)) + } + + t.Run("Check existing prefixes", func(t *testing.T) { + for _, s := range prefixes { + p := netip.MustParsePrefix(s) + if !trie.ContainsPrefix(p) { + t.Errorf("expected trie to contain prefix %s, but it did not", p) + } + } + }) + + t.Run("Check non-existing prefixes", func(t *testing.T) { + nonExistentPrefixes := []string{ + "2001:db8::/33", + "2001:db8:acad::/47", + "2002::/16", + } + for _, s := range nonExistentPrefixes { + p := netip.MustParsePrefix(s) + if trie.ContainsPrefix(p) { + t.Errorf("expected trie to not contain prefix %s, but it did", p) + } + } + }) + + t.Run("Check host addresses with Contains", func(t *testing.T) { + if !trie.Contains(netip.MustParseAddr("2001:db8:acad:1::1")) { + t.Error("expected trie to contain host address 2001:db8:acad:1::1, but it did not") + } + if trie.Contains(netip.MustParseAddr("2001:db8:acad:1::2")) { + t.Error("expected trie to not contain host address 2001:db8:acad:1::2, but it did") + } + }) +} + +func TestTrie_Delete(t *testing.T) { + trie := New() + + p8 := netip.MustParsePrefix("10.0.0.0/8") + p24 := netip.MustParsePrefix("10.0.0.0/24") + pOther := netip.MustParsePrefix("10.0.1.0/24") + + trie.Insert(p8) + trie.Insert(p24) + trie.Insert(pOther) + + t.Run("Delete Non-Existent", func(t *testing.T) { + if trie.Delete(netip.MustParsePrefix("1.2.3.4/32")) { + t.Error("Delete returned true for a non-existent prefix") + } + }) + + t.Run("Delete Leaf Node", func(t *testing.T) { + if !trie.Delete(p24) { + t.Fatal("Delete returned false for an existing prefix") + } + if trie.ContainsPrefix(p24) { + t.Error("ContainsPrefix should be false after delete") + } + if !trie.ContainsPrefix(p8) { + t.Error("parent prefix should not be affected by 
child delete") + } + }) + + t.Run("Delete Intermediate Node", func(t *testing.T) { + // Re-insert p24 so we can delete its parent p8 + trie.Insert(p24) + + if !trie.Delete(p8) { + t.Fatal("Delete returned false for an existing intermediate prefix") + } + if trie.ContainsPrefix(p8) { + t.Error("ContainsPrefix for deleted intermediate node should be false") + } + if !trie.ContainsPrefix(p24) { + t.Error("child prefix should still exist after parent was deleted") + } + if !trie.ContainsPrefix(pOther) { + t.Error("other prefix should still exist after parent was deleted") + } + }) +} + +func TestTrie_Walk(t *testing.T) { + trie := New() + prefixes := []string{ + "10.0.0.0/8", + "192.168.1.0/24", + "2001:db8::/32", + "172.16.0.0/12", + "2001:db8:acad::/48", + } + for _, s := range prefixes { + trie.Insert(netip.MustParsePrefix(s)) + } + + t.Run("Walk all prefixes", func(t *testing.T) { + var walkedPrefixes []string + trie.Walk(func(p netip.Prefix) bool { + walkedPrefixes = append(walkedPrefixes, p.String()) + return true + }) + + if len(walkedPrefixes) != len(prefixes) { + t.Fatalf("expected to walk %d prefixes, but got %d", len(prefixes), len(walkedPrefixes)) + } + + sort.Strings(prefixes) + sort.Strings(walkedPrefixes) + + for i := range prefixes { + if prefixes[i] != walkedPrefixes[i] { + t.Errorf("walked prefixes mismatch: expected %v, got %v", prefixes, walkedPrefixes) + break + } + } + }) + + t.Run("Stop walk early", func(t *testing.T) { + count := 0 + trie.Walk(func(p netip.Prefix) bool { + count++ + return count < 3 // Stop after visiting 3 prefixes + }) + + if count != 3 { + t.Errorf("expected walk to stop after 3 prefixes, but it visited %d", count) + } + }) +} + +func TestTrie_Merge(t *testing.T) { + trieA := New() + trieA.Insert(netip.MustParsePrefix("10.0.0.0/8")) + trieA.Insert(netip.MustParsePrefix("192.168.1.0/24")) + trieA.Insert(netip.MustParsePrefix("2001:db8::/32")) // v6 + + trieB := New() + trieB.Insert(netip.MustParsePrefix("10.1.0.0/16")) + 
trieB.Insert(netip.MustParsePrefix("192.168.1.0/24")) // Overlap
+	trieB.Insert(netip.MustParsePrefix("172.16.0.0/12"))
+	trieB.Insert(netip.MustParsePrefix("2001:db8:acad::/48")) // v6
+
+	trieA.Merge(trieB)
+
+	expectedPrefixes := []string{
+		// From A
+		"10.0.0.0/8",
+		"192.168.1.0/24",
+		"2001:db8::/32",
+		// From B
+		"10.1.0.0/16",
+		"172.16.0.0/12",
+		"2001:db8:acad::/48",
+	}
+
+	for _, s := range expectedPrefixes {
+		p := netip.MustParsePrefix(s)
+		if !trieA.ContainsPrefix(p) {
+			t.Errorf("after merge, trieA is missing prefix %s", p)
+		}
+	}
+
+	// Check a prefix that should not be there
+	if trieA.ContainsPrefix(netip.MustParsePrefix("9.9.9.9/32")) {
+		t.Error("trieA contains a prefix it should not have")
+	}
+
+	// Verify trieB is unchanged
+	if !trieB.ContainsPrefix(netip.MustParsePrefix("172.16.0.0/12")) {
+		t.Error("trieB was modified during merge")
+	}
+	if trieB.ContainsPrefix(netip.MustParsePrefix("10.0.0.0/8")) {
+		t.Error("trieB contains a prefix from trieA")
+	}
+}
diff --git a/dataset/nettrie/valuetrie.go b/dataset/nettrie/valuetrie.go
new file mode 100644
index 0000000..d5977d6
--- /dev/null
+++ b/dataset/nettrie/valuetrie.go
@@ -0,0 +1,353 @@
+package nettrie
+
+import (
+	"math/bits"
+	"net/netip"
+)
+
+// getBit returns the n-th bit of an IP address (0-indexed from the most
+// significant bit). n must be in the range [0, addr.BitLen()).
+func getBit(addr netip.Addr, n int) byte {
+	slice := addr.AsSlice()
+	byteIndex := n / 8
+	bitIndex := 7 - (n % 8)
+	return (slice[byteIndex] >> bitIndex) & 1
+}
+
+// commonPrefixLen computes the number of leading bits that are the same for two addresses.
+// Addresses of different bit lengths (IPv4 vs. IPv6, or valid vs. invalid)
+// have nothing in common, so 0 is returned for them.
+func commonPrefixLen(a, b netip.Addr) int {
+	// Equal bit lengths also guarantee both slices below have the same length.
+	if a.BitLen() != b.BitLen() {
+		return 0
+	}
+	aSlice := a.AsSlice()
+	bSlice := b.AsSlice()
+
+	commonLen := 0
+	for i := 0; i < len(aSlice); i++ {
+		xor := aSlice[i] ^ bSlice[i]
+		if xor == 0 {
+			commonLen += 8
+		} else {
+			commonLen += bits.LeadingZeros8(xor)
+			return commonLen
+		}
+	}
+	return commonLen
+}
+
+// ValueNode represents a node in the path-compressed trie.
+// Each node represents a prefix and can have up to two children.
+type ValueNode[T any] struct {
+	children [2]*ValueNode[T]
+
+	// prefix is the full prefix represented by the path to this node.
+	prefix netip.Prefix
+
+	value   T
+	isValue bool
+}
+
+// ValueTrie is a path-compressed radix trie that stores network prefixes and their values.
+type ValueTrie[T any] struct {
+	rootV4 *ValueNode[T]
+	rootV6 *ValueNode[T]
+}
+
+// NewValue creates and initializes a new ValueTrie.
+func NewValue[T any]() *ValueTrie[T] {
+	return &ValueTrie[T]{}
+}
+
+// Insert adds or updates a prefix in the trie with the given value.
+// Invalid prefixes are ignored.
+func (t *ValueTrie[T]) Insert(p netip.Prefix, value T) {
+	p = p.Masked()
+	if !p.IsValid() {
+		// An invalid prefix has no address family and Bits() == -1; storing
+		// it would plant a node in the IPv6 tree that matches every address.
+		return
+	}
+	addr := p.Addr()
+
+	if addr.Is4() {
+		t.rootV4 = t.insert(t.rootV4, p, value)
+	} else {
+		t.rootV6 = t.insert(t.rootV6, p, value)
+	}
+}
+
+// insert is the recursive helper for inserting a prefix into the trie.
+func (t *ValueTrie[T]) insert(node *ValueNode[T], p netip.Prefix, value T) *ValueNode[T] {
+	if node == nil {
+		return &ValueNode[T]{prefix: p, value: value, isValue: true}
+	}
+
+	addr := p.Addr()
+	commonLen := commonPrefixLen(addr, node.prefix.Addr())
+	pBits := p.Bits()
+	nodeBits := node.prefix.Bits()
+
+	if commonLen > pBits {
+		commonLen = pBits
+	}
+	if commonLen > nodeBits {
+		commonLen = nodeBits
+	}
+
+	if commonLen == nodeBits && commonLen == pBits {
+		// Exact match, update the value.
+		node.value = value
+		node.isValue = true
+		return node
+	}
+
+	if commonLen < nodeBits {
+		// The new prefix diverges from the current node's prefix.
+		// We must split the current node.
+		commonP, _ := node.prefix.Addr().Prefix(commonLen)
+		splitNode := &ValueNode[T]{prefix: commonP}
+
+		// The existing node becomes a child of the new split node.
+		bit := getBit(node.prefix.Addr(), commonLen)
+		splitNode.children[bit] = node
+
+		if commonLen == pBits {
+			// The inserted prefix is a prefix of the node's original prefix.
+			// The new split node represents the inserted prefix and gets the value.
+			splitNode.value = value
+			splitNode.isValue = true
+		} else {
+			// The two prefixes diverge. Create a new child for the new prefix.
+			bit = getBit(addr, commonLen)
+			splitNode.children[bit] = &ValueNode[T]{prefix: p, value: value, isValue: true}
+		}
+		return splitNode
+	}
+
+	// commonLen == nodeBits, meaning the current node's prefix is a prefix of the new one.
+	// We need to descend to a child.
+	bit := getBit(addr, commonLen)
+	node.children[bit] = t.insert(node.children[bit], p, value)
+	return node
+}
+
+// Lookup finds the value associated with the most specific prefix that contains the given IP address.
+// It reports ok == false when no stored prefix contains addr, including when
+// addr is invalid. IPv4-mapped IPv6 addresses are looked up in the IPv6 tree;
+// callers that want IPv4 semantics should Unmap them first.
+func (t *ValueTrie[T]) Lookup(addr netip.Addr) (value T, ok bool) {
+	if !addr.IsValid() {
+		return value, false
+	}
+
+	node := t.rootV4
+	if addr.Is6() {
+		node = t.rootV6
+	}
+
+	var lastFoundValue T
+	var found bool
+
+	for node != nil {
+		commonLen := commonPrefixLen(addr, node.prefix.Addr())
+		nodeBits := node.prefix.Bits()
+
+		// If the address doesn't share a prefix with the node, we can't be in this subtree.
+		if commonLen < nodeBits {
+			break
+		}
+
+		// The address is within this node's prefix. If the node holds a value,
+		// it's our current best match.
+		if node.isValue {
+			lastFoundValue = node.value
+			found = true
+		}
+
+		// A full-length (host) node has no children, so we can't go deeper.
+		// The node's length is tested rather than the number of matching bits:
+		// when addr equals the node's network address all bits match, yet more
+		// specific child prefixes may still contain addr.
+		if nodeBits >= addr.BitLen() {
+			break
+		}
+
+		// Descend to the next child based on the next bit after the node's prefix.
+		bit := getBit(addr, nodeBits)
+		node = node.children[bit]
+	}
+
+	return lastFoundValue, found
+}
+
+// Delete removes a prefix from the trie. It returns true if the prefix was found and removed.
+func (t *ValueTrie[T]) Delete(p netip.Prefix) bool {
+	p = p.Masked()
+	addr := p.Addr()
+
+	var changed bool
+	if addr.Is4() {
+		t.rootV4, changed = t.delete(t.rootV4, p)
+	} else {
+		t.rootV6, changed = t.delete(t.rootV6, p)
+	}
+	return changed
+}
+
+// delete is the recursive helper for removing a prefix from the trie.
+func (t *ValueTrie[T]) delete(node *ValueNode[T], p netip.Prefix) (*ValueNode[T], bool) {
+	if node == nil {
+		return nil, false
+	}
+
+	addr := p.Addr()
+	pBits := p.Bits()
+	nodeBits := node.prefix.Bits()
+	commonLen := commonPrefixLen(addr, node.prefix.Addr())
+
+	// The node's prefix does not cover p, so p cannot be in this subtree.
+	if commonLen < nodeBits {
+		return node, false
+	}
+
+	var changed bool
+	if pBits > nodeBits {
+		// The prefix to delete is deeper in the trie. Recurse.
+		bit := getBit(addr, nodeBits)
+		node.children[bit], changed = t.delete(node.children[bit], p)
+	} else if pBits == nodeBits {
+		// This is the node to delete. Unset its value.
+		if !node.isValue {
+			return node, false // Prefix wasn't actually in the trie.
+		}
+		node.isValue = false
+		var zero T
+		node.value = zero
+		changed = true
+	} else { // pBits < nodeBits
+		return node, false // Prefix to delete is shorter, so can't be here.
+	}
+
+	if !changed {
+		return node, false
+	}
+
+	// Post-deletion cleanup:
+	// If the node has no value and can be merged with a single child, do so.
+	if !node.isValue {
+		if node.children[0] != nil && node.children[1] == nil {
+			return node.children[0], true
+		}
+		if node.children[0] == nil && node.children[1] != nil {
+			return node.children[1], true
+		}
+	}
+
+	// If the node is now a leaf without a value, it can be removed entirely.
+	if !node.isValue && node.children[0] == nil && node.children[1] == nil {
+		return nil, true
+	}
+
+	return node, true
+}
+
+// Contains checks if the exact IP address exists in the trie as a full-length prefix.
+func (t *ValueTrie[T]) Contains(addr netip.Addr) bool {
+	prefix := netip.PrefixFrom(addr, addr.BitLen())
+	return t.ContainsPrefix(prefix)
+}
+
+// ContainsPrefix checks if the exact prefix exists in the trie.
+func (t *ValueTrie[T]) ContainsPrefix(p netip.Prefix) bool {
+	p = p.Masked()
+	addr := p.Addr()
+	pBits := p.Bits()
+
+	node := t.rootV4
+	if addr.Is6() {
+		node = t.rootV6
+	}
+
+	for node != nil {
+		commonLen := commonPrefixLen(addr, node.prefix.Addr())
+		nodeBits := node.prefix.Bits()
+
+		if commonLen < nodeBits {
+			// Path has diverged. The prefix cannot be in this subtree.
+			return false
+		}
+
+		if pBits < nodeBits {
+			// The search prefix is shorter than the node's prefix,
+			// but they share a prefix. e.g. search /16, node is /24.
+			// The /16 is not explicitly in the trie.
+			return false
+		}
+
+		if pBits == nodeBits {
+			// Found a node with the exact same prefix length.
+			// Because we also know commonLen >= nodeBits, the prefixes are identical.
+			return node.isValue
+		}
+
+		// pBits > nodeBits, so we need to go deeper.
+		bit := getBit(addr, nodeBits)
+		node = node.children[bit]
+	}
+
+	return false
+}
+
+// WalkValueFunc is a function called for each prefix in the trie during a walk.
+// Returning false from the function will stop the walk.
+type WalkValueFunc[T any] func(p netip.Prefix, v T) bool
+
+// walkValue is the recursive helper for traversing the trie.
+func walkValue[T any](node *ValueNode[T], f WalkValueFunc[T]) bool {
+	if node == nil {
+		return true
+	}
+
+	if node.isValue {
+		if !f(node.prefix, node.value) {
+			return false
+		}
+	}
+
+	if node.children[0] != nil {
+		if !walkValue(node.children[0], f) {
+			return false
+		}
+	}
+	if node.children[1] != nil {
+		if !walkValue(node.children[1], f) {
+			return false
+		}
+	}
+	return true
+}
+
+// Walk traverses the trie and calls the given function for each prefix and its value.
+// If the function returns false, the walk is stopped. The order is not guaranteed.
+func (t *ValueTrie[T]) Walk(f WalkValueFunc[T]) {
+	if !walkValue(t.rootV4, f) {
+		return
+	}
+	walkValue(t.rootV6, f)
+}
+
+// Merge inserts all prefixes from another Trie into this one.
+// A value in other replaces the value already stored under the same prefix.
+// A nil other is a no-op.
+func (t *ValueTrie[T]) Merge(other *ValueTrie[T]) {
+	if other == nil {
+		return
+	}
+	other.Walk(func(p netip.Prefix, v T) bool {
+		t.Insert(p, v)
+		return true // continue walking
+	})
+}
diff --git a/dataset/nettrie/valuetrie_test.go b/dataset/nettrie/valuetrie_test.go
new file mode 100644
index 0000000..cac0825
--- /dev/null
+++ b/dataset/nettrie/valuetrie_test.go
@@ -0,0 +1,340 @@
+package nettrie
+
+import (
+	"maps"
+	"net/netip"
+	"reflect"
+	"slices"
+	"testing"
+)
+
+func TestValueTrie_IPv4(t *testing.T) {
+	trie := NewValue[string]()
+
+	routes := map[string]string{
+		"0.0.0.0/0":        "Default",
+		"10.0.0.0/8":       "Private A",
+		"192.168.0.0/16":   "Private B",
+		"192.168.1.0/24":   "LAN",
+		"192.168.1.128/25": "Subnet",
+		"192.168.1.200/32": "Host",
+	}
+
+	for s, v := range routes {
+		trie.Insert(netip.MustParsePrefix(s), v)
+	}
+
+	testCases := []struct {
+		name        string
+		lookupAddr  string
+		expectedVal string
+		expectedOK  bool
+	}{
+		{"Exact Host Match", "192.168.1.200", "Host", true},
+		{"Subnet Match", "192.168.1.201", "Subnet", true},
+		{"LAN Match", "192.168.1.50", "LAN", true},
+		{"Private B Match", "192.168.255.255", "Private B", true},
+		{"Private A Match", "10.255.255.255", "Private A", true},
+		{"Default Match", "8.8.8.8", "Default", true},
+		{"No Match", "203.0.113.1", "Default", true}, // Falls back to default
+	}
+
+	for _, tc := range testCases {
+		t.Run(tc.name, func(t *testing.T) {
+			addr := netip.MustParseAddr(tc.lookupAddr)
+			val, ok := trie.Lookup(addr)
+
+			if ok != tc.expectedOK {
+				t.Errorf("expected ok=%v, got ok=%v", tc.expectedOK, ok)
+			}
+			if val != tc.expectedVal {
+				t.Errorf("expected value=%q, got %q", tc.expectedVal, val)
+			}
+		})
+	}
+}
+
+func TestValueTrie_IPv6(t *testing.T) {
+	trie := NewValue[string]()
+
+	routes := map[string]string{
+		"::/0":                   "Default",
+		"2001:db8::/32":          "Global Unicast",
+		"2001:db8:acad::/48":     "Academic",
+		"2001:db8:acad:1::/64":   "CS Department",
+		"2001:db8:acad:1::1/128": "Host Route",
+	}
+
+	for s, v := range routes {
+
trie.Insert(netip.MustParsePrefix(s), v) + } + + testCases := []struct { + name string + lookupAddr string + expectedVal string + expectedOK bool + }{ + {"Exact Host Match", "2001:db8:acad:1::1", "Host Route", true}, + {"CS Dept Match", "2001:db8:acad:1::2", "CS Department", true}, + {"Academic Match", "2001:db8:acad:2::1", "Academic", true}, + {"Global Unicast Match", "2001:db8:cafe::1", "Global Unicast", true}, + {"Default Match", "2606:4700:4700::1111", "Default", true}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + addr := netip.MustParseAddr(tc.lookupAddr) + val, ok := trie.Lookup(addr) + + if ok != tc.expectedOK { + t.Errorf("expected ok=%v, got ok=%v", tc.expectedOK, ok) + } + if val != tc.expectedVal { + t.Errorf("expected value=%q, got %q", tc.expectedVal, val) + } + }) + } +} + +func TestValueTrie_UpdateValue(t *testing.T) { + trie := NewValue[string]() + prefix := netip.MustParsePrefix("192.168.1.0/24") + addr := netip.MustParseAddr("192.168.1.1") + + trie.Insert(prefix, "Initial Value") + val, _ := trie.Lookup(addr) + if val != "Initial Value" { + t.Fatalf("expected initial value, got %q", val) + } + + trie.Insert(prefix, "Updated Value") + val, _ = trie.Lookup(addr) + if val != "Updated Value" { + t.Fatalf("expected updated value, got %q", val) + } +} + +func TestValueTrie_Delete(t *testing.T) { + trie := NewValue[string]() + + trie.Insert(netip.MustParsePrefix("10.0.0.0/8"), "A") + trie.Insert(netip.MustParsePrefix("10.0.0.0/24"), "B") + trie.Insert(netip.MustParsePrefix("10.0.1.0/24"), "C") + + t.Run("Delete Non-Existent", func(t *testing.T) { + if trie.Delete(netip.MustParsePrefix("1.2.3.4/32")) { + t.Error("deleted a non-existent prefix") + } + }) + + t.Run("Delete Leaf Node", func(t *testing.T) { + // Delete 10.0.0.0/24, which is a leaf. 
+		if !trie.Delete(netip.MustParsePrefix("10.0.0.0/24")) {
+			t.Fatal("failed to delete 10.0.0.0/24")
+		}
+		// Lookup should now resolve to the parent 10.0.0.0/8
+		val, ok := trie.Lookup(netip.MustParseAddr("10.0.0.1"))
+		if !ok || val != "A" {
+			t.Errorf("lookup failed after delete, got val=%q ok=%v, want val=\"A\" ok=true", val, ok)
+		}
+	})
+
+	t.Run("Delete and Merge", func(t *testing.T) {
+		// Re-insert the existing 10.0.0.0/8 to replace its value (this subtest relies on the state left by the previous one).
+		trie.Insert(netip.MustParsePrefix("10.0.0.0/8"), "A-updated")
+
+		// We now have 10.0.0.0/8 and 10.0.1.0/24. Deleting 10.0.1.0/24 should
+		// remove that leaf and leave 10.0.0.0/8 as the only stored prefix.
+		if !trie.Delete(netip.MustParsePrefix("10.0.1.0/24")) {
+			t.Fatal("failed to delete 10.0.1.0/24")
+		}
+
+		// A lookup for 10.0.1.1 should now match 10.0.0.0/8
+		val, ok := trie.Lookup(netip.MustParseAddr("10.0.1.1"))
+		if !ok || val != "A-updated" {
+			t.Errorf("lookup failed after merge, got val=%q ok=%v, want val=\"A-updated\" ok=true", val, ok)
+		}
+	})
+
+	t.Run("Delete Prefix Of Another", func(t *testing.T) {
+		// Reset 10.0.0.0/8, add the more specific 10.1.0.0/16, then delete the /8.
+		trie.Insert(netip.MustParsePrefix("10.0.0.0/8"), "A")
+		trie.Insert(netip.MustParsePrefix("10.1.0.0/16"), "D")
+
+		if !trie.Delete(netip.MustParsePrefix("10.0.0.0/8")) {
+			t.Fatal("failed to delete 10.0.0.0/8")
+		}
+
+		// Lookup for 10.0.0.1 should now fail (no default route)
+		_, ok := trie.Lookup(netip.MustParseAddr("10.0.0.1"))
+		if ok {
+			t.Error("lookup for 10.0.0.1 should have failed")
+		}
+
+		// Lookup for 10.1.0.1 should still succeed
+		val, ok := trie.Lookup(netip.MustParseAddr("10.1.0.1"))
+		if !ok || val != "D" {
+			t.Error("lookup for 10.1.0.1 failed unexpectedly")
+		}
+	})
+}
+
+func TestValueTrie_Contains(t *testing.T) {
+	trie := NewValue[string]()
+
+	prefixes := []string{
+		"10.0.0.0/8",
+		"192.168.1.0/24",
+		"192.168.1.200/32",
+		"2001:db8::/32",
+		"2001:db8:acad:1::1/128",
+	}
+
+	for _, s := range prefixes {
+		trie.Insert(netip.MustParsePrefix(s), "present")
+	}
+
+
t.Run("ContainsPrefix", func(t *testing.T) {
+		testCases := []struct {
+			prefix string
+			want   bool
+		}{
+			{"10.0.0.0/8", true},
+			{"192.168.1.200/32", true},
+			{"2001:db8::/32", true},
+			{"2001:db8:acad:1::1/128", true},
+			{"10.0.0.0/9", false},     // more specific than the stored /8, not an exact match
+			{"192.168.1.0/25", false}, // non-existent child
+			{"172.16.0.0/12", false},  // completely different prefix
+			{"2001:db8::/48", false},  // non-existent child
+		}
+
+		for _, tc := range testCases {
+			p := netip.MustParsePrefix(tc.prefix)
+			if got := trie.ContainsPrefix(p); got != tc.want {
+				t.Errorf("ContainsPrefix(%q) = %v, want %v", tc.prefix, got, tc.want)
+			}
+		}
+	})
+
+	t.Run("Contains", func(t *testing.T) {
+		testCases := []struct {
+			addr string
+			want bool
+		}{
+			{"192.168.1.200", true},
+			{"2001:db8:acad:1::1", true},
+			{"192.168.1.201", false}, // In /24 range, but not a /32 host route
+			{"10.0.0.1", false},      // In /8 range, but not a /32 host route
+		}
+
+		for _, tc := range testCases {
+			a := netip.MustParseAddr(tc.addr)
+			if got := trie.Contains(a); got != tc.want {
+				t.Errorf("Contains(%q) = %v, want %v", tc.addr, got, tc.want)
+			}
+		}
+	})
+}
+
+func TestValueTrie_Walk(t *testing.T) {
+	trie := NewValue[string]()
+	prefixes := map[string]string{
+		"10.0.0.0/8":         "A",
+		"192.168.1.0/24":     "B",
+		"2001:db8::/32":      "C",
+		"172.16.0.0/12":      "D",
+		"2001:db8:acad::/48": "E",
+	}
+	for s, v := range prefixes {
+		trie.Insert(netip.MustParsePrefix(s), v)
+	}
+
+	t.Run("Walk all prefixes", func(t *testing.T) {
+		walked := make(map[string]string)
+		trie.Walk(func(p netip.Prefix, v string) bool {
+			walked[p.String()] = v
+			return true
+		})
+
+		if !maps.Equal(walked, prefixes) {
+			t.Errorf("walked prefixes mismatch:\nexpected: %v\ngot: %v", prefixes, walked)
+		}
+	})
+
+	t.Run("Stop walk early", func(t *testing.T) {
+		count := 0
+		trie.Walk(func(p netip.Prefix, v string) bool {
+			count++
+			return count < 3 // Stop after visiting 3 prefixes
+		})
+
+		if count != 3 {
+			t.Errorf("expected walk
to stop after 3 prefixes, but it visited %d", count) + } + }) + + t.Run("Stop walk between families", func(t *testing.T) { + stopTrie := NewValue[int]() + stopTrie.Insert(netip.MustParsePrefix("10.0.0.0/8"), 1) + stopTrie.Insert(netip.MustParsePrefix("2001:db8::/32"), 2) + + count := 0 + stopTrie.Walk(func(p netip.Prefix, v int) bool { + count++ + return false + }) + + if count != 1 { + t.Errorf("expected walk to stop after 1 prefix, but it visited %d", count) + } + }) +} + +func TestValueTrie_Merge(t *testing.T) { + trieA := NewValue[string]() + trieA.Insert(netip.MustParsePrefix("10.0.0.0/8"), "net10") + trieA.Insert(netip.MustParsePrefix("192.168.1.0/24"), "lan_A") + trieA.Insert(netip.MustParsePrefix("2001:db8::/32"), "v6_A") + + trieB := NewValue[string]() + trieB.Insert(netip.MustParsePrefix("10.1.0.0/16"), "net10_subnet") + trieB.Insert(netip.MustParsePrefix("192.168.1.0/24"), "lan_B_override") // Overlap + trieB.Insert(netip.MustParsePrefix("172.16.0.0/12"), "corp") + trieB.Insert(netip.MustParsePrefix("2001:db8:acad::/48"), "v6_B") + + trieA.Merge(trieB) + + expected := map[string]string{ + "10.0.0.0/8": "net10", + "192.168.1.0/24": "lan_B_override", // Value should be from trieB + "2001:db8::/32": "v6_A", + "10.1.0.0/16": "net10_subnet", + "172.16.0.0/12": "corp", + "2001:db8:acad::/48": "v6_B", + } + + actual := make(map[string]string) + trieA.Walk(func(p netip.Prefix, v string) bool { + actual[p.String()] = v + return true + }) + + if !reflect.DeepEqual(expected, actual) { + t.Errorf("Merge result incorrect.\nExpected: %v\nGot: %v", expected, actual) + } + + // Verify trieB is unchanged by collecting its prefixes + bPrefixes := []string{} + trieB.Walk(func(p netip.Prefix, v string) bool { + bPrefixes = append(bPrefixes, p.String()) + return true + }) + expectedBPrefixes := []string{"10.1.0.0/16", "172.16.0.0/12", "192.168.1.0/24", "2001:db8:acad::/48"} + slices.Sort(bPrefixes) + slices.Sort(expectedBPrefixes) + if !reflect.DeepEqual(bPrefixes, 
expectedBPrefixes) {
+		t.Errorf("trieB was modified during merge.\nExpected: %v\nGot: %v", expectedBPrefixes, bPrefixes)
+	}
+}
diff --git a/dataset/network.go b/dataset/network.go
index b377362..3511c33 100644
--- a/dataset/network.go
+++ b/dataset/network.go
@@ -2,10 +2,45 @@ package dataset
 
 import (
 	"net"
+	"net/netip"
 
+	"git.maze.io/maze/styx/dataset/nettrie"
 	"github.com/yl2chen/cidranger"
 )
 
+// NetworkTrie is a set of network prefixes backed by a nettrie.Trie.
+type NetworkTrie struct {
+	*nettrie.Trie
+}
+
+// NewNetworkTrie returns a NetworkTrie holding the given prefixes.
+func NewNetworkTrie(prefixes ...netip.Prefix) *NetworkTrie {
+	trie := &NetworkTrie{
+		Trie: nettrie.New(),
+	}
+	for _, prefix := range prefixes {
+		trie.Insert(prefix)
+	}
+	return trie
+}
+
+// ContainsIP reports whether ip is stored in the trie as a full-length
+// (/32 or /128) prefix. A nil or malformed ip is never contained.
+//
+// net.IP usually carries IPv4 addresses in 16-byte form, which
+// netip.AddrFromSlice returns as an IPv4-mapped IPv6 address; Unmap converts
+// it back so that it is matched against IPv4 prefixes.
+func (trie *NetworkTrie) ContainsIP(ip net.IP) bool {
+	if ip == nil {
+		return false
+	}
+	addr, ok := netip.AddrFromSlice(ip)
+	if !ok {
+		return false
+	}
+	return trie.Contains(addr.Unmap())
+}
+
 type NetworkTree struct {
 	ranger cidranger.Ranger
 }
diff --git a/dataset/storage.go b/dataset/storage.go
index a5366a8..52fca27 100644
--- a/dataset/storage.go
+++ b/dataset/storage.go
@@ -1,21 +1,20 @@
 package dataset
 
 import (
-	"bufio"
 	"bytes"
 	"fmt"
 	"io"
 	"io/fs"
 	"net"
 	"net/http"
+	"net/netip"
 	"net/url"
 	"os"
 	"slices"
-	"strings"
 	"time"
 
+	"git.maze.io/maze/styx/dataset/parser"
 	_ "github.com/mattn/go-sqlite3" // SQLite3 driver
-	"github.com/miekg/dns"
 )
 
 type Storage interface {
@@ -27,11 +26,13 @@ type Storage interface {
 
 	Clients() (Clients, error)
 	ClientByID(int64) (Client, error)
-	ClientByIP(net.IP) (Client, error)
+	ClientByAddr(netip.Addr) (Client, error)
+	// ClientByIP(net.IP) (Client, error)
 	SaveClient(*Client) error
 	DeleteClient(Client) error
 
 	Lists() ([]List, error)
+	ListsByGroup(Group) ([]List, error)
 	ListByID(int64) (List, error)
 	SaveList(*List) error
 	DeleteList(List) error
@@ -44,7 +45,6 @@ type Group struct {
 	Description string    `json:"description"`
 	CreatedAt   time.Time `json:"created_at" bstore:"nonzero"`
 	UpdatedAt   time.Time `json:"updated_at" bstore:"nonzero"`
-	Storage     Storage   `json:"-" bstore:"-"`
 }
 
 type Client struct {
@@ -56,11
+56,6 @@ type Client struct { Groups []Group `json:"groups,omitempty" bstore:"-"` CreatedAt time.Time `json:"created_at" bstore:"nonzero"` UpdatedAt time.Time `json:"updated_at" bstore:"nonzero"` - Storage Storage `json:"-" bstore:"-"` -} - -type WithClient interface { - Client() (Client, error) } type ClientGroup struct { @@ -80,6 +75,15 @@ func (c *Client) ContainsIP(ip net.IP) bool { return ipnet.Contains(ip) } +func (c *Client) ContainsAddr(ip netip.Addr) bool { + return c.Prefix().Contains(ip) +} + +func (c Client) Prefix() netip.Prefix { + ip, _ := netip.ParseAddr(c.IP) + return netip.PrefixFrom(ip, c.Mask) +} + func (c *Client) String() string { ipnet := &net.IPNet{ IP: net.ParseIP(c.IP), @@ -136,28 +140,29 @@ type List struct { UpdatedAt time.Time `json:"updated_at"` } -func (list *List) Domains() (*DomainTree, error) { +func (list *List) Networks() (*NetworkTrie, error) { + if list.Type != ListTypeNetwork { + return nil, nil + } + + prefixes, _, err := parser.ParseNetworks(bytes.NewReader(list.Cache)) + if err != nil { + return nil, err + } + return NewNetworkTrie(prefixes...), nil +} + +func (list *List) Domains() (*DomainTrie, error) { if list.Type != ListTypeDomain { return nil, nil } - var ( - tree = NewDomainList() - scan = bufio.NewScanner(bytes.NewReader(list.Cache)) - ) - for scan.Scan() { - line := strings.TrimSpace(scan.Text()) - if line == "" || line[0] == '#' { - continue - } - if labels, ok := dns.IsDomainName(line); ok && labels >= 2 { - tree.Add(line) - } - } - if err := scan.Err(); err != nil { + domains, _, err := parser.ParseDomains(bytes.NewReader(list.Cache)) + if err != nil { return nil, err } - return tree, nil + + return NewDomainTrie(list.Permit, domains...) 
} func (list *List) Update() (updated bool, err error) { diff --git a/dataset/storage_bstore.go b/dataset/storage_bstore.go index 89168a2..68748b3 100644 --- a/dataset/storage_bstore.go +++ b/dataset/storage_bstore.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "net" + "net/netip" "path/filepath" "slices" "strings" @@ -20,13 +21,18 @@ type bstoreStorage struct { } func OpenBStore(name string) (Storage, error) { + log := logger.StandardLog.Value("database", name) + if !filepath.IsAbs(name) { var err error if name, err = filepath.Abs(name); err != nil { + log.Err(err).Error("Opening BoltDB storage failed; invalid path") return nil, err } + log = log.Value("database", name) } + log.Debug("Opening BoltDB storage") ctx := context.Background() db, err := bstore.Open(ctx, name, nil, Group{}, @@ -36,6 +42,7 @@ func OpenBStore(name string) (Storage, error) { ListGroup{}, ) if err != nil { + log.Err(err).Error("Opening BoltDB storage failed") return nil, err } @@ -47,6 +54,7 @@ func OpenBStore(name string) (Storage, error) { ) if defaultGroup, err = s.GroupByName("Default"); errors.Is(err, bstore.ErrAbsent) { + log.Debug("Creating default group") defaultGroup = Group{ Name: "Default", IsEnabled: true, @@ -63,6 +71,7 @@ func OpenBStore(name string) (Storage, error) { FilterFn(func(client Client) bool { return net.ParseIP(client.IP).Equal(net.ParseIP("0.0.0.0")) && client.Mask == 0 }).Get(); errors.Is(err, bstore.ErrAbsent) { + log.Debug("Creating default IPv4 clients") defaultClient4 = Client{ Network: "ipv4", IP: "0.0.0.0", @@ -83,6 +92,7 @@ func OpenBStore(name string) (Storage, error) { FilterFn(func(client Client) bool { return net.ParseIP(client.IP).Equal(net.ParseIP("::")) && client.Mask == 0 }).Get(); errors.Is(err, bstore.ErrAbsent) { + log.Debug("Creating default IPv6 clients") defaultClient6 = Client{ Network: "ipv6", IP: "::", @@ -100,6 +110,7 @@ func OpenBStore(name string) (Storage, error) { } // Start updater + log.Trace("Starting list updater") NewUpdater(s) return 
s, nil @@ -197,7 +208,12 @@ func (s *bstoreStorage) ClientByID(id int64) (Client, error) { } func (s *bstoreStorage) ClientByIP(ip net.IP) (Client, error) { - if ip == nil { + addr, _ := netip.AddrFromSlice(ip) + return s.ClientByAddr(addr) +} + +func (s *bstoreStorage) ClientByAddr(addr netip.Addr) (Client, error) { + if !addr.IsValid() { return Client{}, ErrNotExist{Object: "client"} } var ( @@ -205,9 +221,9 @@ func (s *bstoreStorage) ClientByIP(ip net.IP) (Client, error) { clients Clients network string ) - if ip4 := ip.To4(); ip4 != nil { + if addr.Is4() { network = "ipv4" - } else if ip6 := ip.To16(); ip6 != nil { + } else { network = "ipv6" } if network == "" { @@ -216,7 +232,7 @@ func (s *bstoreStorage) ClientByIP(ip net.IP) (Client, error) { for client, err := range bstore.QueryDB[Client](ctx, s.db). FilterEqual("Network", network). FilterFn(func(client Client) bool { - return client.ContainsIP(ip) + return client.ContainsAddr(addr) }).All() { if err != nil { return Client{}, err @@ -320,6 +336,26 @@ func (s *bstoreStorage) Lists() ([]List, error) { return lists, nil } +func (s *bstoreStorage) ListsByGroup(group Group) ([]List, error) { + ctx := context.Background() + ids := make([]int64, 0) + for item, err := range bstore.QueryDB[ListGroup](ctx, s.db).FilterEqual("GroupID", group.ID).All() { + if err != nil { + return nil, err + } + ids = append(ids, item.ListID) + } + + var lists []List + for list, err := range bstore.QueryDB[List](ctx, s.db).FilterIDs(ids).All() { + if err != nil { + return nil, err + } + lists = append(lists, list) + } + return lists, nil +} + func (s *bstoreStorage) ListByID(id int64) (List, error) { ctx := context.Background() list, err := bstore.QueryDB[List](ctx, s.db).FilterID(id).Get() diff --git a/dataset/storage_cache.go b/dataset/storage_cache.go new file mode 100644 index 0000000..11b566f --- /dev/null +++ b/dataset/storage_cache.go @@ -0,0 +1,142 @@ +package dataset + +import ( + "fmt" + "net" + "net/netip" + "sync" + "time" 
+ + "git.maze.io/maze/styx/logger" +) + +const MinCacheExpire = 10 * time.Second + +type cache struct { + Storage + expire time.Duration + groupByID sync.Map + clientByAddr sync.Map + listByGroup sync.Map + closed chan struct{} +} + +type cacheItem struct { + cachedAt time.Time + value any +} + +// Cache items returned from a Storage for the specified duration. +// +// Does not cache negative hits. +func Cache(storage Storage, expire time.Duration) Storage { + if expire < MinCacheExpire { + expire = MinCacheExpire + } + + logger.StandardLog.Value("expire", expire).Debug("Caching Storage responses") + s := &cache{ + Storage: storage, + expire: expire, + closed: make(chan struct{}, 1), + } + go s.cleanUpTimer() + return s +} + +func (s *cache) cleanUpTimer() { + ticker := time.NewTicker(s.expire) + defer ticker.Stop() + + for { + select { + case <-s.closed: + return + + case now := <-ticker.C: + logger.StandardLog.Trace("Cache cleanup running") + s.cleanUp(now, &s.groupByID) + s.cleanUp(now, &s.clientByAddr) + s.cleanUp(now, &s.listByGroup) + } + } +} + +func (s *cache) cleanUp(now time.Time, cacheMap *sync.Map) { + cacheMap.Range(func(key, item any) bool { + cached := item.(cacheItem) + if ago := now.Sub(cached.cachedAt); ago >= s.expire { + logger.StandardLog.Values(logger.Values{ + "ago": ago, + "type": fmt.Sprintf("%T", cached.value), + "item": fmt.Sprintf("%s", cached.value), + }).Debug("Cache removing expired item") + cacheMap.Delete(key) + } + return true + }) +} + +func (s *cache) load(now time.Time, cacheMap *sync.Map, key any) (value any, ok bool) { + var item any + if item, ok = cacheMap.Load(key); !ok { + return + } + + cached := item.(cacheItem) + if now.Sub(cached.cachedAt) < s.expire { + return cached.value, true + } + return nil, false +} + +func (s *cache) save(now time.Time, cacheMap *sync.Map, key, value any) { + cacheMap.Store(key, cacheItem{ + cachedAt: now, + value: value, + }) +} + +func (s *cache) GroupByID(id int64) (Group, error) { + now := 
time.Now() + if value, ok := s.load(now, &s.groupByID, id); ok { + return value.(Group), nil + } + + group, err := s.Storage.GroupByID(id) + if err == nil { + s.save(now, &s.groupByID, id, group) + } + return group, err +} + +func (s *cache) ClientByIP(ip net.IP) (Client, error) { + addr, _ := netip.AddrFromSlice(ip) + return s.ClientByAddr(addr) +} + +func (s *cache) ClientByAddr(ip netip.Addr) (Client, error) { + now := time.Now() + if value, ok := s.load(now, &s.clientByAddr, ip); ok { + return value.(Client), nil + } + + client, err := s.Storage.ClientByAddr(ip) + if err == nil { + s.save(now, &s.clientByAddr, ip, client) + } + return client, err +} + +func (s *cache) ListsByGroup(group Group) ([]List, error) { + now := time.Now() + if value, ok := s.load(now, &s.listByGroup, group.ID); ok { + return value.([]List), nil + } + + lists, err := s.Storage.ListsByGroup(group) + if err == nil { + s.save(now, &s.listByGroup, group.ID, lists) + } + return lists, err +} diff --git a/go.mod b/go.mod index c941b30..42d5ff9 100644 --- a/go.mod +++ b/go.mod @@ -11,6 +11,7 @@ require ( github.com/open-policy-agent/opa v1.9.0 github.com/rs/zerolog v1.34.0 github.com/sirupsen/logrus v1.9.4-0.20230606125235-dd1b4c2e81af + github.com/stretchr/testify v1.11.1 github.com/yl2chen/cidranger v1.0.2 ) @@ -20,6 +21,7 @@ require ( github.com/apparentlymart/go-textseg/v15 v15.0.0 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect + github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/decred/dcrd/dcrec/secp256k1/v4 v4.4.0 // indirect github.com/dgraph-io/ristretto/v2 v2.3.0 // indirect github.com/go-ini/ini v1.67.0 // indirect @@ -42,6 +44,7 @@ require ( github.com/mattn/go-isatty v0.0.20 // indirect github.com/mitchellh/go-wordwrap v1.0.1 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect + github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // 
indirect github.com/prometheus/client_golang v1.23.2 // indirect github.com/prometheus/client_model v0.6.2 // indirect github.com/prometheus/common v0.66.1 // indirect diff --git a/policy/handler.go b/policy/handler.go index f32c729..b7b4f4f 100644 --- a/policy/handler.go +++ b/policy/handler.go @@ -12,12 +12,12 @@ import ( proxy "git.maze.io/maze/styx/proxy" ) -func NewRequestHandler(p *Policy) proxy.RequestHandler { - log := logger.StandardLog.Value("policy", p.name) +func NewRequestHandler(policy *Policy) proxy.RequestHandler { + log := logger.StandardLog.Value("policy", policy.name) return proxy.RequestHandlerFunc(func(ctx proxy.Context) (*http.Request, *http.Response) { input := NewInputFromRequest(ctx, ctx.Request()) input.logValues(log).Trace("Running request handler") - result, err := p.Query(input) + result, err := policy.Query(input, proxy.PolicyQueryOptions(ctx)...) if err != nil { log.Err(err).Error("Error evaulating policy") return nil, nil @@ -32,12 +32,12 @@ func NewRequestHandler(p *Policy) proxy.RequestHandler { }) } -func NewDialHandler(p *Policy) proxy.DialHandler { - log := logger.StandardLog.Value("policy", p.name) +func NewDialHandler(policy *Policy) proxy.DialHandler { + log := logger.StandardLog.Value("policy", policy.name) return proxy.DialHandlerFunc(func(ctx proxy.Context, req *http.Request) (net.Conn, error) { input := NewInputFromRequest(ctx, req) input.logValues(log).Trace("Running dial handler") - result, err := p.Query(input) + result, err := policy.Query(input, proxy.PolicyQueryOptions(ctx)...) 
if err != nil { log.Err(err).Error("Error evaulating policy") return nil, nil diff --git a/policy/input.go b/policy/input.go index 84327c6..c2e3c17 100644 --- a/policy/input.go +++ b/policy/input.go @@ -10,7 +10,6 @@ import ( "net/url" "strconv" - "git.maze.io/maze/styx/dataset" "git.maze.io/maze/styx/internal/netutil" "git.maze.io/maze/styx/logger" proxy "git.maze.io/maze/styx/proxy" @@ -48,14 +47,16 @@ func NewInputFromConn(c net.Conn) *Input { TLS: NewTLSFromConn(c), } - if wcl, ok := c.(dataset.WithClient); ok { - client, err := wcl.Client() - if err == nil { - input.Context["client_id"] = client.ID - input.Context["client_description"] = client.Description - input.Context["groups"] = client.Groups + /* + if wcl, ok := c.(dataset.WithClient); ok { + client, err := wcl.Client() + if err == nil { + input.Context["client_id"] = client.ID + input.Context["client_description"] = client.Description + input.Context["groups"] = client.Groups + } } - } + */ if ctx, ok := c.(proxy.Context); ok { input.Context["local"] = NewClientFromAddr(ctx.LocalAddr()) diff --git a/policy/policy.go b/policy/policy.go index 6db76c5..da84ada 100644 --- a/policy/policy.go +++ b/policy/policy.go @@ -1,6 +1,7 @@ package policy import ( + "bufio" "bytes" "context" "errors" @@ -10,6 +11,7 @@ import ( "net/http" "os" "path/filepath" + "strings" "github.com/go-viper/mapstructure/v2" "github.com/open-policy-agent/opa/v1/ast" @@ -168,12 +170,17 @@ func (r *Result) Response(ctx proxy.Context) (*http.Response, error) { } } -func (p *Policy) Query(input *Input) (*Result, error) { +func (p *Policy) Query(input *Input, options ...func(*rego.Rego)) (*Result, error) { log := logger.StandardLog.Value("policy", p.name) log.Trace("Evaluating policy") + var regoOptions = append(p.options, rego.Input(input)) + for _, option := range options { + regoOptions = append(regoOptions, option) + } + var ( - rego = rego.New(append(p.options, rego.Input(input))...) + rego = rego.New(regoOptions...) 
ctx = context.Background() rs, err = rego.Eval(ctx) ) @@ -200,3 +207,27 @@ func (p *Policy) Query(input *Input) (*Result, error) { } return result, nil } + +// PackageFromFile reads the "package" stanza from the provided Rego policy file. +// +// If no stanza can be found, an error is returned. +func PackageFromFile(name string) (string, error) { + f, err := os.Open(name) + if err != nil { + return "", err + } + defer func() { _ = f.Close() }() + + scanner := bufio.NewScanner(f) + for scanner.Scan() { + text := strings.TrimSpace(scanner.Text()) + part := strings.Fields(text) + if len(part) > 1 && part[0] == "package" { + return part[1], nil + } + } + if err := scanner.Err(); err != nil { + return "", err + } + return "", fmt.Errorf("policy: can't detemine package name of %s", name) +} diff --git a/proxy/context.go b/proxy/context.go index a984be8..1126c9f 100644 --- a/proxy/context.go +++ b/proxy/context.go @@ -45,8 +45,11 @@ type Context interface { // Response is the response that will be sent back to the client. Response() *http.Response + // Logger for this context. + Logger() logger.Structured + // Client group. - Client() (dataset.Client, error) + Storage() dataset.Storage } type WithCertificateAuthority interface { @@ -91,11 +94,10 @@ type proxyContext struct { idleTimeout time.Duration ca ca.CertificateAuthority storage dataset.Storage - client dataset.Client } // NewContext returns an initialized context for the provided [net.Conn]. 
-func NewContext(c net.Conn) Context { +func NewContext(c net.Conn, storage dataset.Storage) Context { if c, ok := c.(*proxyContext); ok { return c } @@ -106,12 +108,13 @@ func NewContext(c net.Conn) Context { cr := &countingReader{reader: c} cw := &countingWriter{writer: c} return &proxyContext{ - Conn: c, - id: binary.BigEndian.Uint64(b), - cr: cr, - br: bufio.NewReader(cr), - cw: cw, - res: &http.Response{StatusCode: 200}, + Conn: c, + id: binary.BigEndian.Uint64(b), + cr: cr, + br: bufio.NewReader(cr), + cw: cw, + res: &http.Response{StatusCode: 200}, + storage: storage, } } @@ -128,7 +131,7 @@ func (c *proxyContext) AccessLogEntry() logger.Structured { return entry } -func (c *proxyContext) LogEntry() logger.Structured { +func (c *proxyContext) Logger() logger.Structured { var id [8]byte binary.BigEndian.PutUint64(id[:], c.id) return ServerLog.Values(logger.Values{ @@ -234,24 +237,8 @@ func (c *proxyContext) CertificateAuthority() ca.CertificateAuthority { return c.ca } -func (c *proxyContext) Client() (dataset.Client, error) { - if c.storage == nil { - return dataset.Client{}, dataset.ErrNotExist{Object: "client"} - } - if !c.client.CreatedAt.Equal(time.Time{}) { - return c.client, nil - } - - var err error - switch addr := c.Conn.RemoteAddr().(type) { - case *net.TCPAddr: - c.client, err = c.storage.ClientByIP(addr.IP) - case *net.UDPAddr: - c.client, err = c.storage.ClientByIP(addr.IP) - default: - err = dataset.ErrNotExist{Object: "client"} - } - return c.client, err +func (c *proxyContext) Storage() dataset.Storage { + return c.storage } var _ Context = (*proxyContext)(nil) diff --git a/proxy/handler.go b/proxy/handler.go index 84db558..4f32791 100644 --- a/proxy/handler.go +++ b/proxy/handler.go @@ -144,18 +144,21 @@ func Transparent(port int) ConnHandler { return nctx, nil } - b := new(bytes.Buffer) - hello, err := cryptutil.ReadClientHello(io.TeeReader(netutil.ReadOnlyConn{Reader: ctx.br}, b)) + var ( + b = new(bytes.Buffer) + hello, err = 
cryptutil.ReadClientHello(io.TeeReader(netutil.ReadOnlyConn{Reader: ctx.br}, b)) + log = ctx.Logger() + ) if err != nil { if _, ok := err.(tls.RecordHeaderError); !ok { - ctx.LogEntry().Err(err).Value("error_type", fmt.Sprintf("%T", err)).Warn("TLS sniff error") + log.Err(err).Value("error_type", fmt.Sprintf("%T", err)).Warn("TLS sniff error") return nil, err } // Not a TLS connection, moving on to regular HTTP request handling... - ctx.LogEntry().Debug("HTTP connection on transparent port") + log.Debug("HTTP connection on transparent port") ctx.transparent = port } else { - ctx.LogEntry().Value("target", hello.ServerName).Debug("TLS connection on transparent port") + log.Value("target", hello.ServerName).Debug("TLS connection on transparent port") ctx.transparent = port ctx.transparentTLS = true ctx.serverName = hello.ServerName diff --git a/proxy/policy.go b/proxy/policy.go new file mode 100644 index 0000000..f2c7dcb --- /dev/null +++ b/proxy/policy.go @@ -0,0 +1,178 @@ +package proxy + +import ( + "errors" + "net" + "net/netip" + "strings" + + "git.maze.io/maze/styx/dataset" + "git.maze.io/maze/styx/internal/netutil" + "github.com/open-policy-agent/opa/v1/ast" + "github.com/open-policy-agent/opa/v1/rego" + "github.com/open-policy-agent/opa/v1/types" +) + +// PolicyQueryOptions generates the Rego query functions for the provided [Context]. 
+func PolicyQueryOptions(ctx Context) (options []func(*rego.Rego)) { + var ( + log = ctx.Logger() + storage = ctx.Storage() + ) + + addr, err := netip.ParseAddr(netutil.Host(ctx.RemoteAddr().String())) + if err != nil { + log.Err(err).Error("Error resolving remote address") + return + } + + client, err := storage.ClientByAddr(addr) + if err != nil { + log.Err(err).Warn("Error resolving client") + return + } + + var ( + permitDomains []*dataset.DomainTrie + rejectDomains []*dataset.DomainTrie + permitNetworks []*dataset.NetworkTrie + rejectNetworks []*dataset.NetworkTrie + ) + for _, group := range client.Groups { + lists, err := storage.ListsByGroup(group) + if err != nil { + log.Err(err).Warn("Error resolving lists") + return + } + for _, list := range lists { + switch list.Type { + case dataset.ListTypeDomain: + trie, err := list.Domains() + if err != nil { + log.Err(err).Warn("Error resolving domain trie") + } + if list.Permit { + permitDomains = append(permitDomains, trie) + } else { + rejectDomains = append(rejectDomains, trie) + } + case dataset.ListTypeNetwork: + trie, err := list.Networks() + if err != nil { + log.Err(err).Warn("Error resolving domain trie") + } + if list.Permit { + permitNetworks = append(permitNetworks, trie) + } else { + rejectNetworks = append(rejectNetworks, trie) + } + } + } + } + + options = append(options, + rego.Function1(®o.Function{ + Name: "styx.reject_domain", + Description: "Check if the domain is to be rejected", + Decl: domainFunctionDecl, + Nondeterministic: true, + Memoize: true, + }, domainFunctionImpl(rejectDomains)), + rego.Function1(®o.Function{ + Name: "styx.permit_domain", + Description: "Check if the domain is to be permitted", + Decl: domainFunctionDecl, + Nondeterministic: true, + Memoize: true, + }, domainFunctionImpl(permitDomains)), + rego.Function1(®o.Function{ + Name: "styx.reject_network", + Description: "Check if the IP, IP:port, host or host:port is to be rejected", + Decl: networkFunctionDecl, + 
Nondeterministic: true, + Memoize: true, + }, networkFunctionImpl(rejectNetworks)), + rego.Function1(®o.Function{ + Name: "styx.permit_network", + Description: "Check if the IP, IP:port, host or host:port is to be permitted", + Decl: networkFunctionDecl, + Nondeterministic: true, + Memoize: true, + }, networkFunctionImpl(permitNetworks)), + ) + return +} + +var domainFunctionDecl = types.NewFunction( + types.Args(types.Named("domain", types.S).Description("Domain to lookup")), + types.Named("result", types.B).Description("`true` if domain matches"), +) + +func domainFunctionImpl(tries []*dataset.DomainTrie) rego.Builtin1 { + return func(ctx rego.BuiltinContext, domainTerm *ast.Term) (*ast.Term, error) { + domain, err := parseStringTerm(domainTerm) + if err != nil { + return nil, err + } + for _, trie := range tries { + if trie.Contains(domain) { + return ast.NewTerm(ast.Boolean(true)), nil + } + } + return ast.NewTerm(ast.Boolean(false)), nil + } +} + +var networkFunctionDecl = types.NewFunction( + types.Args(types.Named("ip", types.S).Description("IP, IP:port, host or host:port to lookup")), + types.Named("result", types.B).Description("`true` if IP matches"), +) + +func networkFunctionImpl(tries []*dataset.NetworkTrie) rego.Builtin1 { + return func(ctx rego.BuiltinContext, ipTerm *ast.Term) (*ast.Term, error) { + ips, err := parseAddrTerm(ipTerm) + if err != nil { + return nil, err + } + + for _, trie := range tries { + for _, ip := range ips { + if trie.Contains(ip) { + return ast.NewTerm(ast.Boolean(true)), nil + } + } + } + return ast.NewTerm(ast.Boolean(false)), nil + } +} + +func parseAddrTerm(term *ast.Term) (addrs []netip.Addr, err error) { + s, err := parseStringTerm(term) + if err != nil { + return nil, err + } + if addr, err := netip.ParseAddr(netutil.Host(s)); err == nil { + // Input was "ip" or "ip:port" + return []netip.Addr{addr}, nil + } + + ips, err := net.LookupIP(netutil.Host(s)) + if err != nil { + return nil, err + } + + for _, ip := range ips 
{ + if addr, ok := netip.AddrFromSlice(ip); ok { + addrs = append(addrs, addr) + } + } + return +} + +func parseStringTerm(term *ast.Term) (string, error) { + value, ok := term.Value.(ast.String) + if !ok { + return "", errors.New("expected string argument") + } + return strings.Trim(value.String(), `"`), nil +} diff --git a/proxy/proxy.go b/proxy/proxy.go index 7c8e365..5b009d8 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -302,14 +302,15 @@ func (p *Proxy) Serve(l net.Listener) error { func (p *Proxy) handle(nc net.Conn) { var ( start = time.Now() - ctx = NewContext(nc).(*proxyContext) + ctx = NewContext(nc, p.Storage).(*proxyContext) + log = ctx.Logger() err error ) defer func() { if r := recover(); r != nil { if err, ok := r.(error); ok { - ctx.LogEntry().Err(err).Warn("Bug in code, recovered from panic!") + log.Err(err).Warn("Bug in code, recovered from panic!") } _ = nc.Close() } @@ -360,16 +361,6 @@ func (p *Proxy) handle(nc net.Conn) { } } - log := ctx.LogEntry() - if p.Storage != nil { - if client, err := p.Storage.ClientByIP(nc.RemoteAddr().(*net.TCPAddr).IP); err == nil { - log = log.Values(logger.Values{ - "client_id": client.ID, - "client_network": client.String(), - "client_description": client.Description, - }) - } - } for { if ctx.transparentTLS { ctx.req = &http.Request{ @@ -448,7 +439,9 @@ func (p *Proxy) handleError(ctx *proxyContext, err error, sendResponse bool) { if res == nil && sendResponse { res = NewErrorResponse(err, ctx.Request()) } - ctx.LogEntry().Value("count", len(p.OnError)).Trace("Running error handlers") + + log := ctx.Logger() + log.Value("count", len(p.OnError)).Trace("Running error handlers") for _, f := range p.OnError { if newRes := f.HandleError(ctx, err); newRes != nil { res = newRes @@ -464,7 +457,7 @@ func (p *Proxy) handleError(ctx *proxyContext, err error, sendResponse bool) { func (p *Proxy) handleRequest(ctx *proxyContext) (err error) { switch { case ctx.req == nil: - ctx.LogEntry().Warn("Request is nil in 
handleRequest!?") + ctx.Logger().Warn("Request is nil in handleRequest!?") return errors.New("proxy: request is nil?") case headerContains(ctx.req.Header, HeaderConnection, "upgrade"): @@ -527,7 +520,7 @@ func (p *Proxy) serve(ctx *proxyContext) (err error) { } func (p *Proxy) serveConnect(ctx *proxyContext) (err error) { - log := ctx.LogEntry() + log := ctx.Logger() // Most browsers expect to get a 200 OK after firing a HTTP CONNECT request; if the upstream // encounters any errors, we'll inform the client after reading the HTTP request that follows. @@ -571,13 +564,13 @@ func (p *Proxy) serveConnect(ctx *proxyContext) (err error) { } ctx.res = NewResponse(http.StatusOK, nil, ctx.req) - srv := NewContext(c).(*proxyContext) + srv := NewContext(c, p.Storage).(*proxyContext) srv.SetIdleTimeout(p.IdleTimeout) return p.multiplex(ctx, srv) } func (p *Proxy) serveForward(ctx *proxyContext) (err error) { - log := ctx.LogEntry() + log := ctx.Logger() log.Value("target", ctx.req.URL.String()).Debugf("%s forward request", ctx.req.Proto) var res *http.Response @@ -609,7 +602,8 @@ func (p *Proxy) serveForward(ctx *proxyContext) (err error) { } func (p *Proxy) serveWebSocket(ctx *proxyContext) (err error) { - log := ctx.LogEntry().Value("target", ctx.req.URL.String()) + log := ctx.Logger() + log.Value("target", ctx.req.URL.String()) switch ctx.req.URL.Scheme { case "http": @@ -632,7 +626,7 @@ func (p *Proxy) serveWebSocket(ctx *proxyContext) (err error) { } cancel() - srv := NewContext(c).(*proxyContext) + srv := NewContext(c, p.Storage).(*proxyContext) srv.SetIdleTimeout(p.IdleTimeout) if err = ctx.req.Write(srv); err != nil { ctx.res = NewErrorResponse(err, ctx.req) @@ -662,7 +656,7 @@ func (p *Proxy) serveWebSocket(ctx *proxyContext) (err error) { func (p *Proxy) multiplex(ctx, srv *proxyContext) (err error) { var ( - log = ctx.LogEntry().Value("server", srv.RemoteAddr().String()) + log = ctx.Logger().Value("server", srv.RemoteAddr().String()) errs = make(chan error, 1) done 
= make(chan struct{}, 1) ) diff --git a/proxy/stats.go b/proxy/stats.go index 2dd8052..78f2828 100644 --- a/proxy/stats.go +++ b/proxy/stats.go @@ -1,19 +1,14 @@ package proxy -import ( - "expvar" - "strconv" - - "git.maze.io/maze/styx/db/stats" -) - func countStatus(code int) { - k := "http:status:" + strconv.Itoa(code) - v := expvar.Get(k) - if v == nil { - //v = stats.NewCounter("120s1s", "15m10s", "1h1m", "4w1d", "1y4w") - v = stats.NewCounter(k, stats.Minutely, stats.Hourly, stats.Daily, stats.Yearly) - expvar.Publish(k, v) - } - v.(stats.Metric).Add(1) + /* + k := "http:status:" + strconv.Itoa(code) + v := expvar.Get(k) + if v == nil { + //v = stats.NewCounter("120s1s", "15m10s", "1h1m", "4w1d", "1y4w") + v = stats.NewCounter(k, stats.Minutely, stats.Hourly, stats.Daily, stats.Yearly) + expvar.Publish(k, v) + } + v.(stats.Metric).Add(1) + */ } diff --git a/styx.hcl b/styx.hcl index 89a9f29..cc8785b 100644 --- a/styx.hcl +++ b/styx.hcl @@ -62,6 +62,7 @@ data { storage { type = "bolt" path = "testdata/styx.bolt" + cache = 10 #type = "sqlite" #path = "testdata/styx.db" }