Better trie implementations
This commit is contained in:
		@@ -3,6 +3,10 @@ package main
 | 
			
		||||
import (
 | 
			
		||||
	"crypto/tls"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"os"
 | 
			
		||||
	"path/filepath"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/hashicorp/hcl/v2"
 | 
			
		||||
	"github.com/hashicorp/hcl/v2/gohcl"
 | 
			
		||||
@@ -24,8 +28,25 @@ type Config struct {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c Config) Proxies(log logger.Structured) ([]*proxy.Proxy, error) {
 | 
			
		||||
	log.Debug("Loading policies")
 | 
			
		||||
	policies := make(map[string]*policy.Policy)
 | 
			
		||||
	for _, pc := range c.Policy {
 | 
			
		||||
		if !filepath.IsAbs(pc.Path) {
 | 
			
		||||
			var err error
 | 
			
		||||
			if pc.Path, err = filepath.Abs(pc.Path); err != nil {
 | 
			
		||||
				return nil, fmt.Errorf("invalid policy path: %w", err)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		if pc.Package == "" {
 | 
			
		||||
			var err error
 | 
			
		||||
			if pc.Package, err = policy.PackageFromFile(pc.Path); err != nil {
 | 
			
		||||
				return nil, fmt.Errorf("can't determine package in %s: %w", pc.Path, err)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		log.Values(logger.Values{
 | 
			
		||||
			"path":    pc.Path,
 | 
			
		||||
			"package": pc.Package,
 | 
			
		||||
		}).Debug("Loading policy definition")
 | 
			
		||||
		p, err := policy.New(pc.Path, pc.Package)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, fmt.Errorf("policy %s: %w", pc.Name, err)
 | 
			
		||||
@@ -39,6 +60,7 @@ func (c Config) Proxies(log logger.Structured) ([]*proxy.Proxy, error) {
 | 
			
		||||
		onForward  []proxy.ForwardHandler
 | 
			
		||||
		onResponse []proxy.ResponseHandler
 | 
			
		||||
	)
 | 
			
		||||
	log.Debug("Resolving policy handlers")
 | 
			
		||||
	for _, name := range c.Proxy.On.Request {
 | 
			
		||||
		log.Value("policy", name).Debug("Resolving request policy")
 | 
			
		||||
		p, ok := policies[name]
 | 
			
		||||
@@ -109,10 +131,18 @@ type PortTLSConfig struct {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c PortConfig) Proxy() (*proxy.Proxy, error) {
 | 
			
		||||
	p := proxy.New()
 | 
			
		||||
	log := logger.StandardLog.Value("port", c.Listen)
 | 
			
		||||
	port := proxy.New()
 | 
			
		||||
	if c.Transparent > 0 {
 | 
			
		||||
		p.OnConnect = append(p.OnConnect, proxy.Transparent(c.Transparent))
 | 
			
		||||
		log.Debug("Configuring transparent proxy handler")
 | 
			
		||||
		port.OnConnect = append(port.OnConnect, proxy.Transparent(c.Transparent))
 | 
			
		||||
	} else if c.TLS != nil {
 | 
			
		||||
		if strings.ContainsRune(c.TLS.Cert, os.PathSeparator) {
 | 
			
		||||
			log = log.Value("cert", c.TLS.Cert)
 | 
			
		||||
		} else {
 | 
			
		||||
			log = log.Value("cert", "<data>")
 | 
			
		||||
		}
 | 
			
		||||
		log.Debug("Configuring TLS handler")
 | 
			
		||||
		cert, err := cryptutil.LoadTLSCertificate(c.TLS.Cert, c.TLS.Key)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
@@ -121,6 +151,7 @@ func (c PortConfig) Proxy() (*proxy.Proxy, error) {
 | 
			
		||||
		config := new(tls.Config)
 | 
			
		||||
		config.Certificates = []tls.Certificate{cert}
 | 
			
		||||
		if c.TLS.CA != "" {
 | 
			
		||||
			log.Value("ca", c.TLS.CA).Debug("Loading trusted roots")
 | 
			
		||||
			roots, err := cryptutil.LoadRoots(c.TLS.CA)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return nil, err
 | 
			
		||||
@@ -128,9 +159,9 @@ func (c PortConfig) Proxy() (*proxy.Proxy, error) {
 | 
			
		||||
			config.RootCAs = roots
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		p.OnConnect = append(p.OnConnect, proxy.TLS(config))
 | 
			
		||||
		port.OnConnect = append(port.OnConnect, proxy.TLS(config))
 | 
			
		||||
	}
 | 
			
		||||
	return p, nil
 | 
			
		||||
	return port, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type ProxyPolicyConfig struct {
 | 
			
		||||
@@ -177,17 +208,23 @@ func (c DataConfig) Configure() error {
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c DataConfig) OpenStorage() (dataset.Storage, error) {
 | 
			
		||||
func (c DataConfig) OpenStorage() (s dataset.Storage, err error) {
 | 
			
		||||
	var cache time.Duration
 | 
			
		||||
	switch c.Storage.Type {
 | 
			
		||||
	case "", "bolt", "boltdb":
 | 
			
		||||
		var config struct {
 | 
			
		||||
			Path string `hcl:"path"`
 | 
			
		||||
			Path  string  `hcl:"path"`
 | 
			
		||||
			Cache float64 `hcl:"cache,optional"`
 | 
			
		||||
		}
 | 
			
		||||
		if diag := gohcl.DecodeBody(c.Storage.Body, nil, &config); diag.HasErrors() {
 | 
			
		||||
			return nil, diag
 | 
			
		||||
		}
 | 
			
		||||
		//return dataset.OpenBolt(config.Path)
 | 
			
		||||
		return dataset.OpenBStore(config.Path)
 | 
			
		||||
		if s, err = dataset.OpenBStore(config.Path); err != nil {
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		if config.Cache > 0 {
 | 
			
		||||
			cache = time.Duration(config.Cache * float64(time.Second))
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
	/*
 | 
			
		||||
		case "sqlite", "sqlite3":
 | 
			
		||||
@@ -203,6 +240,11 @@ func (c DataConfig) OpenStorage() (dataset.Storage, error) {
 | 
			
		||||
	default:
 | 
			
		||||
		return nil, fmt.Errorf("storage: no %q driver", c.Storage.Type)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if s != nil && cache > 0 {
 | 
			
		||||
		return dataset.Cache(s, cache), nil
 | 
			
		||||
	}
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type DataStorageConfig struct {
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										62
									
								
								dataset/dnstrie/name.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										62
									
								
								dataset/dnstrie/name.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,62 @@
 | 
			
		||||
package dnstrie
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"strings"
 | 
			
		||||
	"unicode"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// isValidDomainName validates if the given string is a valid DNS hostname.
 | 
			
		||||
// A valid hostname consists of a series of labels separated by dots.
 | 
			
		||||
// Each label must:
 | 
			
		||||
// - Be between 1 and 63 characters long.
 | 
			
		||||
// - Consist only of ASCII letters ('a'-'z', 'A'-'Z'), digits ('0'-'9'), and hyphens ('-').
 | 
			
		||||
// - Not start or end with a hyphen.
 | 
			
		||||
// The total length of the hostname (including dots) must not exceed 253 characters.
 | 
			
		||||
func isValidDomainName(host string) bool {
 | 
			
		||||
	// 1. Check total length. The maximum length of a full hostname is 253 characters.
 | 
			
		||||
	if len(host) > 253 {
 | 
			
		||||
		return false
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// An empty string is not a valid hostname.
 | 
			
		||||
	if host == "" {
 | 
			
		||||
		return false
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 2. Handle optional trailing dot for FQDNs.
 | 
			
		||||
	// If the hostname ends with a dot, we remove it for validation purposes.
 | 
			
		||||
	if strings.HasSuffix(host, ".") {
 | 
			
		||||
		host = host[:len(host)-1]
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// After removing a potential trailing dot, the string might be empty.
 | 
			
		||||
	if host == "" {
 | 
			
		||||
		return false
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 3. Split the hostname into labels.
 | 
			
		||||
	labels := strings.Split(host, ".")
 | 
			
		||||
 | 
			
		||||
	// 4. Validate each label.
 | 
			
		||||
	for _, label := range labels {
 | 
			
		||||
		// a. Check label length (1 to 63 characters).
 | 
			
		||||
		if len(label) < 1 || len(label) > 63 {
 | 
			
		||||
			return false
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// b. Check if label starts or ends with a hyphen.
 | 
			
		||||
		if strings.HasPrefix(label, "-") || strings.HasSuffix(label, "-") {
 | 
			
		||||
			return false
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// c. Check for allowed characters in the label.
 | 
			
		||||
		for _, char := range label {
 | 
			
		||||
			if !unicode.IsLetter(char) && !unicode.IsDigit(char) && char != '-' {
 | 
			
		||||
				return false
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// If all checks pass, the hostname is valid.
 | 
			
		||||
	return true
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										42
									
								
								dataset/dnstrie/name_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										42
									
								
								dataset/dnstrie/name_test.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,42 @@
 | 
			
		||||
package dnstrie
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"strings"
 | 
			
		||||
	"testing"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func TestIsValidDomainName(t *testing.T) {
 | 
			
		||||
	tests := []struct {
 | 
			
		||||
		name     string
 | 
			
		||||
		host     string
 | 
			
		||||
		expected bool
 | 
			
		||||
	}{
 | 
			
		||||
		{"Valid Hostname", "example.com", true},
 | 
			
		||||
		{"Valid with Subdomain", "sub.domain.co.uk", true},
 | 
			
		||||
		{"Valid with Hyphen", "my-host-name.org", true},
 | 
			
		||||
		{"Valid with Digits", "app1.server2.net", true},
 | 
			
		||||
		{"Valid FQDN", "example.com.", true},
 | 
			
		||||
		{"Valid Single Label", "localhost", true},
 | 
			
		||||
		{"Valid: Long but within limits", strings.Repeat("a", 63) + "." + strings.Repeat("b", 63) + ".com", true},
 | 
			
		||||
		{"Invalid: Label Too Long", strings.Repeat("a", 64) + ".com", false},
 | 
			
		||||
		{"Invalid: Total Length Too Long", strings.Repeat("a", 60) + "." + strings.Repeat("b", 60) + "." + strings.Repeat("c", 60) + "." + strings.Repeat("d", 60) + "." + strings.Repeat("e", 60) + ".com", false},
 | 
			
		||||
		{"Invalid: Starts with Hyphen", "-invalid.com", false},
 | 
			
		||||
		{"Invalid: Ends with Hyphen", "invalid-.com", false},
 | 
			
		||||
		{"Invalid: Contains Underscore", "my_host.com", false},
 | 
			
		||||
		{"Invalid: Contains Space", "my host.com", false},
 | 
			
		||||
		{"Invalid: Double Dot", "example..com", false},
 | 
			
		||||
		{"Invalid: Starts with Dot", ".example.com", false},
 | 
			
		||||
		{"Invalid: Empty Label", "sub..domain.com", false},
 | 
			
		||||
		{"Invalid: Just a Dot", ".", false},
 | 
			
		||||
		{"Invalid: Empty String", "", false},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, test := range tests {
 | 
			
		||||
		t.Run(test.name, func(t *testing.T) {
 | 
			
		||||
			got := isValidDomainName(test.host)
 | 
			
		||||
			if got != test.expected {
 | 
			
		||||
				t.Errorf("isValidDomainName(%q) = %v; want %v", test.host, got, test.expected)
 | 
			
		||||
			}
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										139
									
								
								dataset/dnstrie/trie.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										139
									
								
								dataset/dnstrie/trie.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,139 @@
 | 
			
		||||
package dnstrie
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
 | 
			
		||||
	"github.com/miekg/dns"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// Node represents a single node in the prefix trie.
 | 
			
		||||
type Node struct {
 | 
			
		||||
	// children maps the next label (e.g., "example" in "www.example.com")
 | 
			
		||||
	// to the next node in the trie.
 | 
			
		||||
	children map[string]*Node
 | 
			
		||||
	// isEndOfDomain marks if this node represents the end of a valid domain.
 | 
			
		||||
	isEndOfDomain bool
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Trie holds the root of the domain prefix trie.
 | 
			
		||||
type Trie struct {
 | 
			
		||||
	root *Node
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// New creates and initializes a new Trie.
 | 
			
		||||
func New() *Trie {
 | 
			
		||||
	return &Trie{
 | 
			
		||||
		root: &Node{
 | 
			
		||||
			children: make(map[string]*Node),
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Insert adds a domain name to the trie.
 | 
			
		||||
// It canonicalizes the domain name before insertion.
 | 
			
		||||
func (t *Trie) Insert(domain string) error {
 | 
			
		||||
	// Ensure the string is a valid domain name.
 | 
			
		||||
	if !isValidDomainName(domain) {
 | 
			
		||||
		return fmt.Errorf("'%s' is not a valid domain name", domain)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Canonicalize the domain name (lowercase, ensure trailing dot).
 | 
			
		||||
	canonicalDomain := dns.CanonicalName(domain)
 | 
			
		||||
 | 
			
		||||
	// Split the domain into label offsets. This avoids allocating new strings for labels.
 | 
			
		||||
	// For "www.example.com.", this returns []int{0, 4, 12}
 | 
			
		||||
	offsets := dns.Split(canonicalDomain)
 | 
			
		||||
	if len(offsets) == 0 {
 | 
			
		||||
		return fmt.Errorf("could not split domain name: %s", domain)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	currentNode := t.root
 | 
			
		||||
	// Iterate through labels from TLD to the most specific label (right to left).
 | 
			
		||||
	for i := len(offsets) - 1; i >= 0; i-- {
 | 
			
		||||
		start := offsets[i]
 | 
			
		||||
		var end int
 | 
			
		||||
		if i == len(offsets)-1 {
 | 
			
		||||
			// Last label, from its start to the end of the string, excluding the final dot.
 | 
			
		||||
			end = len(canonicalDomain) - 1
 | 
			
		||||
		} else {
 | 
			
		||||
			// Intermediate label, from its start to the character before the next label's starting dot.
 | 
			
		||||
			end = offsets[i+1] - 1
 | 
			
		||||
		}
 | 
			
		||||
		label := canonicalDomain[start:end]
 | 
			
		||||
 | 
			
		||||
		if _, exists := currentNode.children[label]; !exists {
 | 
			
		||||
			// If the label does not exist in the children map, create a new node.
 | 
			
		||||
			currentNode.children[label] = &Node{children: make(map[string]*Node)}
 | 
			
		||||
		}
 | 
			
		||||
		// Move to the next node.
 | 
			
		||||
		currentNode = currentNode.children[label]
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Mark the final node as the end of a domain.
 | 
			
		||||
	currentNode.isEndOfDomain = true
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Contains checks if a domain name exists in the trie.
 | 
			
		||||
func (t *Trie) Contains(domain string) bool {
 | 
			
		||||
	if !isValidDomainName(domain) {
 | 
			
		||||
		return false
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	canonicalDomain := dns.CanonicalName(domain)
 | 
			
		||||
	offsets := dns.Split(canonicalDomain)
 | 
			
		||||
	if len(offsets) == 0 {
 | 
			
		||||
		return false
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	currentNode := t.root
 | 
			
		||||
	for i := len(offsets) - 1; i >= 0; i-- {
 | 
			
		||||
		start := offsets[i]
 | 
			
		||||
		var end int
 | 
			
		||||
		if i == len(offsets)-1 {
 | 
			
		||||
			end = len(canonicalDomain) - 1
 | 
			
		||||
		} else {
 | 
			
		||||
			end = offsets[i+1] - 1
 | 
			
		||||
		}
 | 
			
		||||
		label := canonicalDomain[start:end]
 | 
			
		||||
 | 
			
		||||
		nextNode, exists := currentNode.children[label]
 | 
			
		||||
		if !exists {
 | 
			
		||||
			return false
 | 
			
		||||
		}
 | 
			
		||||
		currentNode = nextNode
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return currentNode.isEndOfDomain
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Merge combines another trie into the current one.
 | 
			
		||||
//
 | 
			
		||||
// Domains from the 'other' trie will be added to the current trie.
 | 
			
		||||
func (t *Trie) Merge(other *Trie) {
 | 
			
		||||
	if other == nil || other.root == nil {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	mergeNodes(t.root, other.root)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// mergeNodes is a recursive helper function to merge nodes from another trie.
 | 
			
		||||
func mergeNodes(localNode, otherNode *Node) {
 | 
			
		||||
	// If the other node marks the end of a domain, the local node should too.
 | 
			
		||||
	if otherNode.isEndOfDomain {
 | 
			
		||||
		localNode.isEndOfDomain = true
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Iterate over the children of the other node and merge them.
 | 
			
		||||
	for label, otherChildNode := range otherNode.children {
 | 
			
		||||
		localChildNode, exists := localNode.children[label]
 | 
			
		||||
		if !exists {
 | 
			
		||||
			// If the child doesn't exist in the local trie, just attach the other's child.
 | 
			
		||||
			localNode.children[label] = otherChildNode
 | 
			
		||||
		} else {
 | 
			
		||||
			// If the child already exists, recurse to merge them.
 | 
			
		||||
			mergeNodes(localChildNode, otherChildNode)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										117
									
								
								dataset/dnstrie/trie_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										117
									
								
								dataset/dnstrie/trie_test.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,117 @@
 | 
			
		||||
package dnstrie
 | 
			
		||||
 | 
			
		||||
import "testing"
 | 
			
		||||
 | 
			
		||||
func TestTrie(t *testing.T) {
 | 
			
		||||
	t.Run("InsertAndContains", func(t *testing.T) {
 | 
			
		||||
		trie := New()
 | 
			
		||||
		domain := "www.example.com"
 | 
			
		||||
		err := trie.Insert(domain)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			t.Fatalf("Expected no error on insert, got %v", err)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if !trie.Contains(domain) {
 | 
			
		||||
			t.Errorf("Expected Contains('%s') to be true, but it was false", domain)
 | 
			
		||||
		}
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	t.Run("ContainsNotFound", func(t *testing.T) {
 | 
			
		||||
		trie := New()
 | 
			
		||||
		err := trie.Insert("example.com")
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			t.Fatalf("Insert failed: %v", err)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if trie.Contains("nonexistent.com") {
 | 
			
		||||
			t.Error("Expected not to find domain 'nonexistent.com', but did")
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// Check for a path that exists but is not a terminal node
 | 
			
		||||
		if trie.Contains("com") {
 | 
			
		||||
			t.Error("Expected not to find non-terminal path 'com', but did")
 | 
			
		||||
		}
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	t.Run("InsertInvalidDomain", func(t *testing.T) {
 | 
			
		||||
		trie := New()
 | 
			
		||||
		err := trie.Insert("not-a-valid-domain-")
 | 
			
		||||
		if err == nil {
 | 
			
		||||
			t.Error("Expected an error when inserting an invalid domain, but got nil")
 | 
			
		||||
		}
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	t.Run("Canonicalization", func(t *testing.T) {
 | 
			
		||||
		trie := New()
 | 
			
		||||
		// Insert lowercase with trailing dot
 | 
			
		||||
		err := trie.Insert("case.example.org.")
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			t.Fatalf("Insert failed: %v", err)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// Check contains with uppercase without trailing dot
 | 
			
		||||
		if !trie.Contains("CASE.EXAMPLE.ORG") {
 | 
			
		||||
			t.Fatal("Failed to find domain with different case and no trailing dot")
 | 
			
		||||
		}
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	t.Run("MultipleInsertions", func(t *testing.T) {
 | 
			
		||||
		trie := New()
 | 
			
		||||
		domains := []string{
 | 
			
		||||
			"example.com",
 | 
			
		||||
			"www.example.com",
 | 
			
		||||
			"api.example.com",
 | 
			
		||||
			"google.com",
 | 
			
		||||
		}
 | 
			
		||||
		for _, domain := range domains {
 | 
			
		||||
			if err := trie.Insert(domain); err != nil {
 | 
			
		||||
				t.Fatalf("Insert failed for %s: %v", domain, err)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		for _, domain := range domains {
 | 
			
		||||
			if !trie.Contains(domain) {
 | 
			
		||||
				t.Errorf("Expected to find %s, but did not", domain)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if trie.Contains("ftp.example.com") {
 | 
			
		||||
			t.Error("Found domain 'ftp.example.com' which was not inserted")
 | 
			
		||||
		}
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	t.Run("MergeTries", func(t *testing.T) {
 | 
			
		||||
		trie1 := New()
 | 
			
		||||
		trie1.Insert("example.com")
 | 
			
		||||
		trie1.Insert("sub.example.com")
 | 
			
		||||
 | 
			
		||||
		trie2 := New()
 | 
			
		||||
		trie2.Insert("google.com")
 | 
			
		||||
		trie2.Insert("sub.example.com") // Overlapping domain
 | 
			
		||||
		trie2.Insert("another.net")
 | 
			
		||||
 | 
			
		||||
		trie1.Merge(trie2)
 | 
			
		||||
 | 
			
		||||
		// Test domains from both tries
 | 
			
		||||
		if !trie1.Contains("example.com") {
 | 
			
		||||
			t.Error("Merge failed: trie1 should contain 'example.com'")
 | 
			
		||||
		}
 | 
			
		||||
		if !trie1.Contains("google.com") {
 | 
			
		||||
			t.Error("Merge failed: trie1 should contain 'google.com'")
 | 
			
		||||
		}
 | 
			
		||||
		if !trie1.Contains("sub.example.com") {
 | 
			
		||||
			t.Error("Merge failed: trie1 should contain overlapping 'sub.example.com'")
 | 
			
		||||
		}
 | 
			
		||||
		if !trie1.Contains("another.net") {
 | 
			
		||||
			t.Error("Merge failed: trie1 should contain 'another.net'")
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// Ensure trie2 is not modified
 | 
			
		||||
		if !trie2.Contains("google.com") {
 | 
			
		||||
			t.Error("Source trie (trie2) should not be modified after merge")
 | 
			
		||||
		}
 | 
			
		||||
		if trie2.Contains("example.com") {
 | 
			
		||||
			t.Error("Source trie (trie2) was modified after merge")
 | 
			
		||||
		}
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										186
									
								
								dataset/dnstrie/valuetrie.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										186
									
								
								dataset/dnstrie/valuetrie.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,186 @@
 | 
			
		||||
package dnstrie
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
 | 
			
		||||
	"github.com/miekg/dns"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// ValueNode represents a single node in the prefix trie, using generics for the value type.
 | 
			
		||||
type ValueNode[T any] struct {
 | 
			
		||||
	// children maps the next label (e.g., "example" in "www.example.com")
 | 
			
		||||
	// to the next node in the trie.
 | 
			
		||||
	children map[string]*ValueNode[T]
 | 
			
		||||
 | 
			
		||||
	// value is the data of generic type T associated with the domain name ending at this node.
 | 
			
		||||
	value T
 | 
			
		||||
 | 
			
		||||
	// isEndOfDomain marks if this node represents the end of a valid domain.
 | 
			
		||||
	isEndOfDomain bool
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ValueTrie holds the root of the domain prefix trie, using generics.
 | 
			
		||||
type ValueTrie[T any] struct {
 | 
			
		||||
	root *ValueNode[T]
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// NewValue creates and initializes a new ValueTrie.
 | 
			
		||||
func NewValue[T any]() *ValueTrie[T] {
 | 
			
		||||
	return &ValueTrie[T]{
 | 
			
		||||
		root: &ValueNode[T]{
 | 
			
		||||
			children: make(map[string]*ValueNode[T]),
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Insert adds a domain name and its associated generic value to the trie.
 | 
			
		||||
// It canonicalizes the domain name before insertion.
 | 
			
		||||
func (t *ValueTrie[T]) Insert(domain string, value T) error {
 | 
			
		||||
	// Ensure the string is a valid domain name.
 | 
			
		||||
	if !isValidDomainName(domain) {
 | 
			
		||||
		return fmt.Errorf("'%s' is not a valid domain name", domain)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Canonicalize the domain name (lowercase, ensure trailing dot).
 | 
			
		||||
	canonicalDomain := dns.CanonicalName(domain)
 | 
			
		||||
 | 
			
		||||
	// Split the domain into label offsets. This avoids allocating new strings for labels.
 | 
			
		||||
	// For "www.example.com.", this returns []int{0, 4, 12}
 | 
			
		||||
	offsets := dns.Split(canonicalDomain)
 | 
			
		||||
	if len(offsets) == 0 {
 | 
			
		||||
		return fmt.Errorf("could not split domain name: %s", domain)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	currentNode := t.root
 | 
			
		||||
	// Iterate through labels from TLD to the most specific label (right to left).
 | 
			
		||||
	for i := len(offsets) - 1; i >= 0; i-- {
 | 
			
		||||
		start := offsets[i]
 | 
			
		||||
		var end int
 | 
			
		||||
		if i == len(offsets)-1 {
 | 
			
		||||
			// Last label, from its start to the end of the string, excluding the final dot.
 | 
			
		||||
			end = len(canonicalDomain) - 1
 | 
			
		||||
		} else {
 | 
			
		||||
			// Intermediate label, from its start to the character before the next label's starting dot.
 | 
			
		||||
			end = offsets[i+1] - 1
 | 
			
		||||
		}
 | 
			
		||||
		label := canonicalDomain[start:end]
 | 
			
		||||
 | 
			
		||||
		if _, exists := currentNode.children[label]; !exists {
 | 
			
		||||
			// If the label does not exist in the children map, create a new node.
 | 
			
		||||
			currentNode.children[label] = &ValueNode[T]{children: make(map[string]*ValueNode[T])}
 | 
			
		||||
		}
 | 
			
		||||
		// Move to the next node.
 | 
			
		||||
		currentNode = currentNode.children[label]
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Mark the final node as the end of a domain and store the value.
 | 
			
		||||
	currentNode.isEndOfDomain = true
 | 
			
		||||
	currentNode.value = value
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Contains checks if a domain name exists in the trie without returning its value.
 | 
			
		||||
func (t *ValueTrie[T]) Contains(domain string) bool {
 | 
			
		||||
	if !isValidDomainName(domain) {
 | 
			
		||||
		return false
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	canonicalDomain := dns.CanonicalName(domain)
 | 
			
		||||
	offsets := dns.Split(canonicalDomain)
 | 
			
		||||
	if len(offsets) == 0 {
 | 
			
		||||
		return false
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	currentNode := t.root
 | 
			
		||||
	for i := len(offsets) - 1; i >= 0; i-- {
 | 
			
		||||
		start := offsets[i]
 | 
			
		||||
		var end int
 | 
			
		||||
		if i == len(offsets)-1 {
 | 
			
		||||
			end = len(canonicalDomain) - 1
 | 
			
		||||
		} else {
 | 
			
		||||
			end = offsets[i+1] - 1
 | 
			
		||||
		}
 | 
			
		||||
		label := canonicalDomain[start:end]
 | 
			
		||||
 | 
			
		||||
		nextNode, exists := currentNode.children[label]
 | 
			
		||||
		if !exists {
 | 
			
		||||
			return false
 | 
			
		||||
		}
 | 
			
		||||
		currentNode = nextNode
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return currentNode.isEndOfDomain
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Search looks for a domain name in the trie.
 | 
			
		||||
// It returns the associated generic value and a boolean indicating if the domain was found.
 | 
			
		||||
func (t *ValueTrie[T]) Search(domain string) (T, bool) {
 | 
			
		||||
	var zero T // The zero value for the generic type T.
 | 
			
		||||
	if !isValidDomainName(domain) {
 | 
			
		||||
		return zero, false
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	canonicalDomain := dns.CanonicalName(domain)
 | 
			
		||||
	offsets := dns.Split(canonicalDomain)
 | 
			
		||||
	if len(offsets) == 0 {
 | 
			
		||||
		return zero, false
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	currentNode := t.root
 | 
			
		||||
	for i := len(offsets) - 1; i >= 0; i-- {
 | 
			
		||||
		start := offsets[i]
 | 
			
		||||
		var end int
 | 
			
		||||
		if i == len(offsets)-1 {
 | 
			
		||||
			end = len(canonicalDomain) - 1
 | 
			
		||||
		} else {
 | 
			
		||||
			end = offsets[i+1] - 1
 | 
			
		||||
		}
 | 
			
		||||
		label := canonicalDomain[start:end]
 | 
			
		||||
 | 
			
		||||
		nextNode, exists := currentNode.children[label]
 | 
			
		||||
		if !exists {
 | 
			
		||||
			// A label in the path was not found, so the domain doesn't exist.
 | 
			
		||||
			return zero, false
 | 
			
		||||
		}
 | 
			
		||||
		currentNode = nextNode
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// The full path was found, but we must also check if it's a terminal node.
 | 
			
		||||
	// This prevents matching "example.com" if only "www.example.com" was inserted.
 | 
			
		||||
	if currentNode.isEndOfDomain {
 | 
			
		||||
		return currentNode.value, true
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return zero, false
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Merge combines another trie into the current one.
 | 
			
		||||
//
 | 
			
		||||
// If a domain exists in both tries, the value from the 'other' trie is used.
 | 
			
		||||
func (t *ValueTrie[T]) Merge(other *ValueTrie[T]) {
 | 
			
		||||
	if other == nil || other.root == nil {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	mergeValueNodes(t.root, other.root)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// mergeNodes is a recursive helper function to merge nodes from another trie.
 | 
			
		||||
func mergeValueNodes[T any](localNode, otherNode *ValueNode[T]) {
 | 
			
		||||
	// The other node value overwrites the local one.
 | 
			
		||||
	if otherNode.isEndOfDomain {
 | 
			
		||||
		localNode.value = otherNode.value
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Iterate over the children of the other node and merge them.
 | 
			
		||||
	for label, otherChildNode := range otherNode.children {
 | 
			
		||||
		localChildNode, exists := localNode.children[label]
 | 
			
		||||
		if !exists {
 | 
			
		||||
			// If the child doesn't exist in the local trie, attach the other's child.
 | 
			
		||||
			localNode.children[label] = otherChildNode
 | 
			
		||||
		} else {
 | 
			
		||||
			// If the child already exists, recurse to merge them.
 | 
			
		||||
			mergeValueNodes(localChildNode, otherChildNode)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										182
									
								
								dataset/dnstrie/valuetrie_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										182
									
								
								dataset/dnstrie/valuetrie_test.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,182 @@
 | 
			
		||||
package dnstrie
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"testing"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func TestValueTrie(t *testing.T) {
 | 
			
		||||
	t.Run("InsertAndSearchStrings", func(t *testing.T) {
 | 
			
		||||
		trie := NewValue[string]()
 | 
			
		||||
		domain := "www.example.com"
 | 
			
		||||
		value := "192.0.2.1"
 | 
			
		||||
		err := trie.Insert(domain, value)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			t.Fatalf("Expected no error on insert, got %v", err)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		foundValue, found := trie.Search(domain)
 | 
			
		||||
		if !found {
 | 
			
		||||
			t.Fatalf("Expected to find domain '%s', but did not", domain)
 | 
			
		||||
		}
 | 
			
		||||
		if foundValue != value {
 | 
			
		||||
			t.Errorf("Expected value '%s', got '%s'", value, foundValue)
 | 
			
		||||
		}
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	t.Run("SearchNotFound", func(t *testing.T) {
 | 
			
		||||
		trie := NewValue[string]()
 | 
			
		||||
		err := trie.Insert("example.com", "value")
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			t.Fatalf("Insert failed: %v", err)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		_, found := trie.Search("nonexistent.com")
 | 
			
		||||
		if found {
 | 
			
		||||
			t.Error("Expected not to find domain 'nonexistent.com', but did")
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// Search for a path that exists but is not a terminal node
 | 
			
		||||
		_, found = trie.Search("com")
 | 
			
		||||
		if found {
 | 
			
		||||
			t.Error("Expected not to find non-terminal path 'com', but did")
 | 
			
		||||
		}
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	t.Run("InsertInvalidDomain", func(t *testing.T) {
 | 
			
		||||
		trie := NewValue[string]()
 | 
			
		||||
		err := trie.Insert("not-a-valid-domain-", "value")
 | 
			
		||||
		if err == nil {
 | 
			
		||||
			t.Error("Expected an error when inserting an invalid domain, but got nil")
 | 
			
		||||
		}
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	t.Run("OverwriteValue", func(t *testing.T) {
 | 
			
		||||
		trie := NewValue[string]()
 | 
			
		||||
		domain := "overwrite.com"
 | 
			
		||||
		initialValue := "first"
 | 
			
		||||
		overwriteValue := "second"
 | 
			
		||||
 | 
			
		||||
		err := trie.Insert(domain, initialValue)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			t.Fatalf("Initial insert failed: %v", err)
 | 
			
		||||
		}
 | 
			
		||||
		err = trie.Insert(domain, overwriteValue)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			t.Fatalf("Overwrite insert failed: %v", err)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		foundValue, found := trie.Search(domain)
 | 
			
		||||
		if !found {
 | 
			
		||||
			t.Fatalf("Expected to find domain '%s' after overwrite", domain)
 | 
			
		||||
		}
 | 
			
		||||
		if foundValue != overwriteValue {
 | 
			
		||||
			t.Errorf("Expected overwritten value '%s', got '%s'", overwriteValue, foundValue)
 | 
			
		||||
		}
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	t.Run("Canonicalization", func(t *testing.T) {
 | 
			
		||||
		trie := NewValue[string]()
 | 
			
		||||
		value := "canonical"
 | 
			
		||||
		// Insert lowercase with trailing dot
 | 
			
		||||
		err := trie.Insert("case.example.org.", value)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			t.Fatalf("Insert failed: %v", err)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// Search uppercase without trailing dot
 | 
			
		||||
		foundValue, found := trie.Search("CASE.EXAMPLE.ORG")
 | 
			
		||||
		if !found {
 | 
			
		||||
			t.Fatal("Failed to find domain with different case and no trailing dot")
 | 
			
		||||
		}
 | 
			
		||||
		if foundValue != value {
 | 
			
		||||
			t.Errorf("Expected value '%s' for canonical search, got '%s'", value, foundValue)
 | 
			
		||||
		}
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	t.Run("InsertAndSearchIntegers", func(t *testing.T) {
 | 
			
		||||
		trie := NewValue[int]()
 | 
			
		||||
		domain := "int.example.com"
 | 
			
		||||
		value := 12345
 | 
			
		||||
 | 
			
		||||
		err := trie.Insert(domain, value)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			t.Fatalf("Expected no error on insert for int trie, got %v", err)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		foundValue, found := trie.Search(domain)
 | 
			
		||||
		if !found {
 | 
			
		||||
			t.Fatalf("Expected to find domain '%s' in int trie, but did not", domain)
 | 
			
		||||
		}
 | 
			
		||||
		if foundValue != value {
 | 
			
		||||
			t.Errorf("Expected int value %d, got %d", value, foundValue)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// Search for a non-existent domain in the int trie
 | 
			
		||||
		_, found = trie.Search("nonexistent.int.example.com")
 | 
			
		||||
		if found {
 | 
			
		||||
			t.Error("Found a nonexistent domain in the int trie")
 | 
			
		||||
		}
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	t.Run("Contains", func(t *testing.T) {
 | 
			
		||||
		trie := NewValue[string]()
 | 
			
		||||
		domain := "contains.example.com"
 | 
			
		||||
		err := trie.Insert(domain, "some-value")
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			t.Fatalf("Insert failed: %v", err)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if !trie.Contains(domain) {
 | 
			
		||||
			t.Errorf("Expected Contains('%s') to be true, but it was false", domain)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if trie.Contains("nonexistent." + domain) {
 | 
			
		||||
			t.Errorf("Expected Contains for nonexistent domain to be false, but it was true")
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if trie.Contains("example.com") {
 | 
			
		||||
			t.Error("Expected Contains for a non-terminal path to be false, but it was true")
 | 
			
		||||
		}
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	t.Run("MergeTries", func(t *testing.T) {
 | 
			
		||||
		trie1 := NewValue[int]()
 | 
			
		||||
		trie1.Insert("example.com", 100)
 | 
			
		||||
		trie1.Insert("sub.example.com", 200)
 | 
			
		||||
 | 
			
		||||
		trie2 := NewValue[int]()
 | 
			
		||||
		trie2.Insert("google.com", 300)
 | 
			
		||||
		trie2.Insert("sub.example.com", 999) // Overlapping domain, new value
 | 
			
		||||
		trie2.Insert("another.net", 400)
 | 
			
		||||
 | 
			
		||||
		trie1.Merge(trie2)
 | 
			
		||||
 | 
			
		||||
		// Test domains from both tries are present
 | 
			
		||||
		if !trie1.Contains("example.com") {
 | 
			
		||||
			t.Error("Merge failed: trie1 should contain 'example.com'")
 | 
			
		||||
		}
 | 
			
		||||
		if !trie1.Contains("google.com") {
 | 
			
		||||
			t.Error("Merge failed: trie1 should contain 'google.com'")
 | 
			
		||||
		}
 | 
			
		||||
		if !trie1.Contains("another.net") {
 | 
			
		||||
			t.Error("Merge failed: trie1 should contain 'another.net'")
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// Test that overlapping value was updated from trie2
 | 
			
		||||
		val, found := trie1.Search("sub.example.com")
 | 
			
		||||
		if !found || val != 999 {
 | 
			
		||||
			t.Errorf("Expected value for overlapping domain to be 999, but got %d", val)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// Test that non-overlapping value from trie1 is intact
 | 
			
		||||
		val, found = trie1.Search("example.com")
 | 
			
		||||
		if !found || val != 100 {
 | 
			
		||||
			t.Errorf("Expected value for 'example.com' to be 100, but got %d", val)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// Ensure trie2 is not modified
 | 
			
		||||
		if _, found := trie2.Search("example.com"); found {
 | 
			
		||||
			t.Error("Source trie (trie2) was modified after merge")
 | 
			
		||||
		}
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
@@ -1,11 +1,29 @@
 | 
			
		||||
package dataset
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"strings"
 | 
			
		||||
 | 
			
		||||
	"git.maze.io/maze/styx/dataset/dnstrie"
 | 
			
		||||
	"github.com/miekg/dns"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type DomainTrie struct {
 | 
			
		||||
	*dnstrie.ValueTrie[bool]
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewDomainTrie(permit bool, domains ...string) (*DomainTrie, error) {
 | 
			
		||||
	trie := &DomainTrie{
 | 
			
		||||
		ValueTrie: dnstrie.NewValue[bool](),
 | 
			
		||||
	}
 | 
			
		||||
	for _, domain := range domains {
 | 
			
		||||
		if err := trie.Insert(domain, permit); err != nil {
 | 
			
		||||
			return nil, fmt.Errorf("dataset: error inserting %s: %w", domain, err)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return trie, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type DomainTree struct {
 | 
			
		||||
	root *domainTreeNode
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										253
									
								
								dataset/nettrie/trie.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										253
									
								
								dataset/nettrie/trie.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,253 @@
 | 
			
		||||
package nettrie
 | 
			
		||||
 | 
			
		||||
import "net/netip"
 | 
			
		||||
 | 
			
		||||
// Node represents a node in the path-compressed trie.
 | 
			
		||||
// Each node represents a prefix and can have up to two children.
 | 
			
		||||
type Node struct {
 | 
			
		||||
	children [2]*Node
 | 
			
		||||
 | 
			
		||||
	// prefix is the full prefix represented by the path to this node.
 | 
			
		||||
	prefix netip.Prefix
 | 
			
		||||
 | 
			
		||||
	// isValue marks if this node represents an explicitly inserted prefix.
 | 
			
		||||
	isValue bool
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Trie is a path-compressed radix trie that stores network prefixes.
 | 
			
		||||
type Trie struct {
 | 
			
		||||
	rootV4 *Node
 | 
			
		||||
	rootV6 *Node
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// New creates and initializes a new Trie.
 | 
			
		||||
func New() *Trie {
 | 
			
		||||
	return &Trie{}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Insert adds a prefix to the trie.
 | 
			
		||||
func (t *Trie) Insert(p netip.Prefix) {
 | 
			
		||||
	p = p.Masked()
 | 
			
		||||
	addr := p.Addr()
 | 
			
		||||
 | 
			
		||||
	if addr.Is4() {
 | 
			
		||||
		t.rootV4 = t.insert(t.rootV4, p)
 | 
			
		||||
	} else {
 | 
			
		||||
		t.rootV6 = t.insert(t.rootV6, p)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// insert is the recursive helper for inserting a prefix into the trie.
 | 
			
		||||
func (t *Trie) insert(node *Node, p netip.Prefix) *Node {
 | 
			
		||||
	if node == nil {
 | 
			
		||||
		return &Node{prefix: p, isValue: true}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	addr := p.Addr()
 | 
			
		||||
	commonLen := commonPrefixLen(addr, node.prefix.Addr())
 | 
			
		||||
	pBits := p.Bits()
 | 
			
		||||
	nodeBits := node.prefix.Bits()
 | 
			
		||||
 | 
			
		||||
	if commonLen > pBits {
 | 
			
		||||
		commonLen = pBits
 | 
			
		||||
	}
 | 
			
		||||
	if commonLen > nodeBits {
 | 
			
		||||
		commonLen = nodeBits
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if commonLen == nodeBits && commonLen == pBits {
 | 
			
		||||
		// Exact match, mark the node as a value node.
 | 
			
		||||
		node.isValue = true
 | 
			
		||||
		return node
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if commonLen < nodeBits {
 | 
			
		||||
		// The new prefix diverges from the current node's prefix.
 | 
			
		||||
		// We must split the current node.
 | 
			
		||||
		commonP, _ := node.prefix.Addr().Prefix(commonLen)
 | 
			
		||||
		splitNode := &Node{prefix: commonP}
 | 
			
		||||
 | 
			
		||||
		// The existing node becomes a child of the new split node.
 | 
			
		||||
		bit := getBit(node.prefix.Addr(), commonLen)
 | 
			
		||||
		splitNode.children[bit] = node
 | 
			
		||||
 | 
			
		||||
		if commonLen == pBits {
 | 
			
		||||
			// The inserted prefix is a prefix of the node's original prefix.
 | 
			
		||||
			// The new split node represents the inserted prefix.
 | 
			
		||||
			splitNode.isValue = true
 | 
			
		||||
		} else {
 | 
			
		||||
			// The two prefixes diverge. Create a new child for the new prefix.
 | 
			
		||||
			bit := getBit(addr, commonLen)
 | 
			
		||||
			splitNode.children[bit] = &Node{prefix: p, isValue: true}
 | 
			
		||||
		}
 | 
			
		||||
		return splitNode
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// commonLen == nodeBits, meaning the current node's prefix is a prefix of the new one.
 | 
			
		||||
	// We need to descend to a child.
 | 
			
		||||
	bit := getBit(addr, commonLen)
 | 
			
		||||
	node.children[bit] = t.insert(node.children[bit], p)
 | 
			
		||||
	return node
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Delete removes a prefix from the trie. It returns true if the prefix was found and removed.
 | 
			
		||||
func (t *Trie) Delete(p netip.Prefix) bool {
 | 
			
		||||
	p = p.Masked()
 | 
			
		||||
	addr := p.Addr()
 | 
			
		||||
 | 
			
		||||
	var changed bool
 | 
			
		||||
	if addr.Is4() {
 | 
			
		||||
		t.rootV4, changed = t.delete(t.rootV4, p)
 | 
			
		||||
	} else {
 | 
			
		||||
		t.rootV6, changed = t.delete(t.rootV6, p)
 | 
			
		||||
	}
 | 
			
		||||
	return changed
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// delete is the recursive helper for removing a prefix from the trie.
 | 
			
		||||
func (t *Trie) delete(node *Node, p netip.Prefix) (*Node, bool) {
 | 
			
		||||
	if node == nil {
 | 
			
		||||
		return nil, false
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	addr := p.Addr()
 | 
			
		||||
	pBits := p.Bits()
 | 
			
		||||
	nodeBits := node.prefix.Bits()
 | 
			
		||||
	commonLen := commonPrefixLen(addr, node.prefix.Addr())
 | 
			
		||||
 | 
			
		||||
	// The prefix is not on this path.
 | 
			
		||||
	if commonLen < nodeBits || commonLen < pBits && pBits < nodeBits {
 | 
			
		||||
		return node, false
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var changed bool
 | 
			
		||||
	if pBits > nodeBits {
 | 
			
		||||
		// The prefix to delete is deeper in the trie. Recurse.
 | 
			
		||||
		bit := getBit(addr, nodeBits)
 | 
			
		||||
		node.children[bit], changed = t.delete(node.children[bit], p)
 | 
			
		||||
	} else if pBits == nodeBits {
 | 
			
		||||
		// This is the node to delete. Unset its value.
 | 
			
		||||
		if !node.isValue {
 | 
			
		||||
			return node, false // Prefix wasn't actually in the trie.
 | 
			
		||||
		}
 | 
			
		||||
		node.isValue = false
 | 
			
		||||
		changed = true
 | 
			
		||||
	} else { // pBits < nodeBits
 | 
			
		||||
		return node, false // Prefix to delete is shorter, so can't be here.
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if !changed {
 | 
			
		||||
		return node, false
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Post-deletion cleanup:
 | 
			
		||||
	// If the node has no value and can be merged with a single child, do so.
 | 
			
		||||
	if !node.isValue {
 | 
			
		||||
		if node.children[0] != nil && node.children[1] == nil {
 | 
			
		||||
			return node.children[0], true
 | 
			
		||||
		}
 | 
			
		||||
		if node.children[0] == nil && node.children[1] != nil {
 | 
			
		||||
			return node.children[1], true
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// If the node is now a leaf without a value, it can be removed entirely.
 | 
			
		||||
	if !node.isValue && node.children[0] == nil && node.children[1] == nil {
 | 
			
		||||
		return nil, true
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return node, true
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ContainsPrefix checks if the exact prefix exists in the trie.
 | 
			
		||||
func (t *Trie) ContainsPrefix(p netip.Prefix) bool {
 | 
			
		||||
	p = p.Masked()
 | 
			
		||||
	addr := p.Addr()
 | 
			
		||||
	pBits := p.Bits()
 | 
			
		||||
 | 
			
		||||
	node := t.rootV4
 | 
			
		||||
	if addr.Is6() {
 | 
			
		||||
		node = t.rootV6
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for node != nil {
 | 
			
		||||
		commonLen := commonPrefixLen(addr, node.prefix.Addr())
 | 
			
		||||
		nodeBits := node.prefix.Bits()
 | 
			
		||||
 | 
			
		||||
		if commonLen < nodeBits {
 | 
			
		||||
			// Path has diverged. The prefix cannot be in this subtree.
 | 
			
		||||
			return false
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if pBits < nodeBits {
 | 
			
		||||
			// The search prefix is shorter than the node's prefix,
 | 
			
		||||
			// but they share a prefix. e.g. search /16, node is /24.
 | 
			
		||||
			// The /16 is not explicitly in the trie.
 | 
			
		||||
			return false
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if pBits == nodeBits {
 | 
			
		||||
			// Found a node with the exact same prefix length.
 | 
			
		||||
			// Because we also know commonLen >= nodeBits, the prefixes are identical.
 | 
			
		||||
			return node.isValue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// pBits > nodeBits, so we need to go deeper.
 | 
			
		||||
		bit := getBit(addr, nodeBits)
 | 
			
		||||
		node = node.children[bit]
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return false
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Contains checks if the exact IP address exists in the trie as a full-length prefix.
 | 
			
		||||
func (t *Trie) Contains(addr netip.Addr) bool {
 | 
			
		||||
	prefix := netip.PrefixFrom(addr, addr.BitLen())
 | 
			
		||||
	return t.ContainsPrefix(prefix)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// WalkFunc is a function called for each prefix in the trie during a walk.
 | 
			
		||||
// Returning false from the function will stop the walk.
 | 
			
		||||
type WalkFunc func(p netip.Prefix) bool
 | 
			
		||||
 | 
			
		||||
// walk is the recursive helper for traversing the trie.
 | 
			
		||||
func walk(node *Node, f WalkFunc) bool {
 | 
			
		||||
	if node == nil {
 | 
			
		||||
		return true
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if node.isValue {
 | 
			
		||||
		if !f(node.prefix) {
 | 
			
		||||
			return false
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if node.children[0] != nil {
 | 
			
		||||
		if !walk(node.children[0], f) {
 | 
			
		||||
			return false
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	if node.children[1] != nil {
 | 
			
		||||
		if !walk(node.children[1], f) {
 | 
			
		||||
			return false
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return true
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Walk traverses the trie and calls the given function for each prefix.
 | 
			
		||||
// If the function returns false, the walk is stopped. The order is not guaranteed.
 | 
			
		||||
func (t *Trie) Walk(f WalkFunc) {
 | 
			
		||||
	if !walk(t.rootV4, f) {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	walk(t.rootV6, f)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Merge inserts all prefixes from another Trie into this one.
 | 
			
		||||
func (t *Trie) Merge(other *Trie) {
 | 
			
		||||
	other.Walk(func(p netip.Prefix) bool {
 | 
			
		||||
		t.Insert(p)
 | 
			
		||||
		return true // continue walking
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										248
									
								
								dataset/nettrie/trie_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										248
									
								
								dataset/nettrie/trie_test.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,248 @@
 | 
			
		||||
package nettrie
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"net/netip"
 | 
			
		||||
	"sort"
 | 
			
		||||
	"testing"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// TestTrie_InsertAndContains_IPv4 tests insertion and verifies presence
 | 
			
		||||
// using ContainsPrefix and Contains for IPv4.
 | 
			
		||||
func TestTrie_InsertAndContains_IPv4(t *testing.T) {
 | 
			
		||||
	trie := New()
 | 
			
		||||
 | 
			
		||||
	prefixes := []string{
 | 
			
		||||
		"0.0.0.0/0",
 | 
			
		||||
		"10.0.0.0/8",
 | 
			
		||||
		"192.168.0.0/16",
 | 
			
		||||
		"192.168.1.0/24",
 | 
			
		||||
		"192.168.1.128/25",
 | 
			
		||||
		"192.168.1.200/32",
 | 
			
		||||
	}
 | 
			
		||||
	for _, s := range prefixes {
 | 
			
		||||
		trie.Insert(netip.MustParsePrefix(s))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	t.Run("Check existing prefixes", func(t *testing.T) {
 | 
			
		||||
		for _, s := range prefixes {
 | 
			
		||||
			p := netip.MustParsePrefix(s)
 | 
			
		||||
			if !trie.ContainsPrefix(p) {
 | 
			
		||||
				t.Errorf("expected trie to contain prefix %s, but it did not", p)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	t.Run("Check non-existing prefixes", func(t *testing.T) {
 | 
			
		||||
		nonExistentPrefixes := []string{
 | 
			
		||||
			"10.0.0.0/9",
 | 
			
		||||
			"192.168.1.0/23",
 | 
			
		||||
			"1.2.3.4/32",
 | 
			
		||||
		}
 | 
			
		||||
		for _, s := range nonExistentPrefixes {
 | 
			
		||||
			p := netip.MustParsePrefix(s)
 | 
			
		||||
			if trie.ContainsPrefix(p) {
 | 
			
		||||
				t.Errorf("expected trie to not contain prefix %s, but it did", p)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	t.Run("Check host addresses with Contains", func(t *testing.T) {
 | 
			
		||||
		if !trie.Contains(netip.MustParseAddr("192.168.1.200")) {
 | 
			
		||||
			t.Error("expected trie to contain host address 192.168.1.200, but it did not")
 | 
			
		||||
		}
 | 
			
		||||
		if trie.Contains(netip.MustParseAddr("192.168.1.201")) {
 | 
			
		||||
			t.Error("expected trie to not contain host address 192.168.1.201, but it did")
 | 
			
		||||
		}
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// TestTrie_InsertAndContains_IPv6 tests insertion and verifies presence
 | 
			
		||||
// using ContainsPrefix and Contains for IPv6.
 | 
			
		||||
func TestTrie_InsertAndContains_IPv6(t *testing.T) {
 | 
			
		||||
	trie := New()
 | 
			
		||||
 | 
			
		||||
	prefixes := []string{
 | 
			
		||||
		"::/0",
 | 
			
		||||
		"2001:db8::/32",
 | 
			
		||||
		"2001:db8:acad::/48",
 | 
			
		||||
		"2001:db8:acad:1::/64",
 | 
			
		||||
		"2001:db8:acad:1::1/128",
 | 
			
		||||
	}
 | 
			
		||||
	for _, s := range prefixes {
 | 
			
		||||
		trie.Insert(netip.MustParsePrefix(s))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	t.Run("Check existing prefixes", func(t *testing.T) {
 | 
			
		||||
		for _, s := range prefixes {
 | 
			
		||||
			p := netip.MustParsePrefix(s)
 | 
			
		||||
			if !trie.ContainsPrefix(p) {
 | 
			
		||||
				t.Errorf("expected trie to contain prefix %s, but it did not", p)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	t.Run("Check non-existing prefixes", func(t *testing.T) {
 | 
			
		||||
		nonExistentPrefixes := []string{
 | 
			
		||||
			"2001:db8::/33",
 | 
			
		||||
			"2001:db8:acad::/47",
 | 
			
		||||
			"2002::/16",
 | 
			
		||||
		}
 | 
			
		||||
		for _, s := range nonExistentPrefixes {
 | 
			
		||||
			p := netip.MustParsePrefix(s)
 | 
			
		||||
			if trie.ContainsPrefix(p) {
 | 
			
		||||
				t.Errorf("expected trie to not contain prefix %s, but it did", p)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	t.Run("Check host addresses with Contains", func(t *testing.T) {
 | 
			
		||||
		if !trie.Contains(netip.MustParseAddr("2001:db8:acad:1::1")) {
 | 
			
		||||
			t.Error("expected trie to contain host address 2001:db8:acad:1::1, but it did not")
 | 
			
		||||
		}
 | 
			
		||||
		if trie.Contains(netip.MustParseAddr("2001:db8:acad:1::2")) {
 | 
			
		||||
			t.Error("expected trie to not contain host address 2001:db8:acad:1::2, but it did")
 | 
			
		||||
		}
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestTrie_Delete(t *testing.T) {
 | 
			
		||||
	trie := New()
 | 
			
		||||
 | 
			
		||||
	p8 := netip.MustParsePrefix("10.0.0.0/8")
 | 
			
		||||
	p24 := netip.MustParsePrefix("10.0.0.0/24")
 | 
			
		||||
	pOther := netip.MustParsePrefix("10.0.1.0/24")
 | 
			
		||||
 | 
			
		||||
	trie.Insert(p8)
 | 
			
		||||
	trie.Insert(p24)
 | 
			
		||||
	trie.Insert(pOther)
 | 
			
		||||
 | 
			
		||||
	t.Run("Delete Non-Existent", func(t *testing.T) {
 | 
			
		||||
		if trie.Delete(netip.MustParsePrefix("1.2.3.4/32")) {
 | 
			
		||||
			t.Error("Delete returned true for a non-existent prefix")
 | 
			
		||||
		}
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	t.Run("Delete Leaf Node", func(t *testing.T) {
 | 
			
		||||
		if !trie.Delete(p24) {
 | 
			
		||||
			t.Fatal("Delete returned false for an existing prefix")
 | 
			
		||||
		}
 | 
			
		||||
		if trie.ContainsPrefix(p24) {
 | 
			
		||||
			t.Error("ContainsPrefix should be false after delete")
 | 
			
		||||
		}
 | 
			
		||||
		if !trie.ContainsPrefix(p8) {
 | 
			
		||||
			t.Error("parent prefix should not be affected by child delete")
 | 
			
		||||
		}
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	t.Run("Delete Intermediate Node", func(t *testing.T) {
 | 
			
		||||
		// Re-insert p24 so we can delete its parent p8
 | 
			
		||||
		trie.Insert(p24)
 | 
			
		||||
 | 
			
		||||
		if !trie.Delete(p8) {
 | 
			
		||||
			t.Fatal("Delete returned false for an existing intermediate prefix")
 | 
			
		||||
		}
 | 
			
		||||
		if trie.ContainsPrefix(p8) {
 | 
			
		||||
			t.Error("ContainsPrefix for deleted intermediate node should be false")
 | 
			
		||||
		}
 | 
			
		||||
		if !trie.ContainsPrefix(p24) {
 | 
			
		||||
			t.Error("child prefix should still exist after parent was deleted")
 | 
			
		||||
		}
 | 
			
		||||
		if !trie.ContainsPrefix(pOther) {
 | 
			
		||||
			t.Error("other prefix should still exist after parent was deleted")
 | 
			
		||||
		}
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestTrie_Walk(t *testing.T) {
 | 
			
		||||
	trie := New()
 | 
			
		||||
	prefixes := []string{
 | 
			
		||||
		"10.0.0.0/8",
 | 
			
		||||
		"192.168.1.0/24",
 | 
			
		||||
		"2001:db8::/32",
 | 
			
		||||
		"172.16.0.0/12",
 | 
			
		||||
		"2001:db8:acad::/48",
 | 
			
		||||
	}
 | 
			
		||||
	for _, s := range prefixes {
 | 
			
		||||
		trie.Insert(netip.MustParsePrefix(s))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	t.Run("Walk all prefixes", func(t *testing.T) {
 | 
			
		||||
		var walkedPrefixes []string
 | 
			
		||||
		trie.Walk(func(p netip.Prefix) bool {
 | 
			
		||||
			walkedPrefixes = append(walkedPrefixes, p.String())
 | 
			
		||||
			return true
 | 
			
		||||
		})
 | 
			
		||||
 | 
			
		||||
		if len(walkedPrefixes) != len(prefixes) {
 | 
			
		||||
			t.Fatalf("expected to walk %d prefixes, but got %d", len(prefixes), len(walkedPrefixes))
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		sort.Strings(prefixes)
 | 
			
		||||
		sort.Strings(walkedPrefixes)
 | 
			
		||||
 | 
			
		||||
		for i := range prefixes {
 | 
			
		||||
			if prefixes[i] != walkedPrefixes[i] {
 | 
			
		||||
				t.Errorf("walked prefixes mismatch: expected %v, got %v", prefixes, walkedPrefixes)
 | 
			
		||||
				break
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	t.Run("Stop walk early", func(t *testing.T) {
 | 
			
		||||
		count := 0
 | 
			
		||||
		trie.Walk(func(p netip.Prefix) bool {
 | 
			
		||||
			count++
 | 
			
		||||
			return count < 3 // Stop after visiting 3 prefixes
 | 
			
		||||
		})
 | 
			
		||||
 | 
			
		||||
		if count != 3 {
 | 
			
		||||
			t.Errorf("expected walk to stop after 3 prefixes, but it visited %d", count)
 | 
			
		||||
		}
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestTrie_Merge(t *testing.T) {
 | 
			
		||||
	trieA := New()
 | 
			
		||||
	trieA.Insert(netip.MustParsePrefix("10.0.0.0/8"))
 | 
			
		||||
	trieA.Insert(netip.MustParsePrefix("192.168.1.0/24"))
 | 
			
		||||
	trieA.Insert(netip.MustParsePrefix("2001:db8::/32")) // v6
 | 
			
		||||
 | 
			
		||||
	trieB := New()
 | 
			
		||||
	trieB.Insert(netip.MustParsePrefix("10.1.0.0/16"))
 | 
			
		||||
	trieB.Insert(netip.MustParsePrefix("192.168.1.0/24")) // Overlap
 | 
			
		||||
	trieB.Insert(netip.MustParsePrefix("172.16.0.0/12"))
 | 
			
		||||
	trieB.Insert(netip.MustParsePrefix("2001:db8:acad::/48")) // v6
 | 
			
		||||
 | 
			
		||||
	trieA.Merge(trieB)
 | 
			
		||||
 | 
			
		||||
	expectedPrefixes := []string{
 | 
			
		||||
		// From A
 | 
			
		||||
		"10.0.0.0/8",
 | 
			
		||||
		"192.168.1.0/24",
 | 
			
		||||
		"2001:db8::/32",
 | 
			
		||||
		// From B
 | 
			
		||||
		"10.1.0.0/16",
 | 
			
		||||
		"172.16.0.0/12",
 | 
			
		||||
		"2001:db8:acad::/48",
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, s := range expectedPrefixes {
 | 
			
		||||
		p := netip.MustParsePrefix(s)
 | 
			
		||||
		if !trieA.ContainsPrefix(p) {
 | 
			
		||||
			t.Errorf("after merge, trieA is missing prefix %s", p)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Check a prefix that should not be there
 | 
			
		||||
	if trieA.ContainsPrefix(netip.MustParsePrefix("9.9.9.9/32")) {
 | 
			
		||||
		t.Error("trieA contains a prefix it should not have")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Verify trieB is unchanged
 | 
			
		||||
	if !trieB.ContainsPrefix(netip.MustParsePrefix("172.16.0.0/12")) {
 | 
			
		||||
		t.Error("trieB was modified during merge")
 | 
			
		||||
	}
 | 
			
		||||
	if trieB.ContainsPrefix(netip.MustParsePrefix("10.0.0.0/8")) {
 | 
			
		||||
		t.Error("trieB contains a prefix from trieA")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										328
									
								
								dataset/nettrie/valuetrie.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										328
									
								
								dataset/nettrie/valuetrie.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,328 @@
 | 
			
		||||
package nettrie
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"math/bits"
 | 
			
		||||
	"net/netip"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// getBit returns the n-th bit of an IP address (0-indexed).
 | 
			
		||||
func getBit(addr netip.Addr, n int) byte {
 | 
			
		||||
	slice := addr.AsSlice()
 | 
			
		||||
	byteIndex := n / 8
 | 
			
		||||
	bitIndex := 7 - (n % 8)
 | 
			
		||||
	return (slice[byteIndex] >> bitIndex) & 1
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// commonPrefixLen computes the number of leading bits that are the same for two addresses.
 | 
			
		||||
func commonPrefixLen(a, b netip.Addr) int {
 | 
			
		||||
	if a.Is4() != b.Is4() {
 | 
			
		||||
		return 0
 | 
			
		||||
	}
 | 
			
		||||
	aSlice := a.AsSlice()
 | 
			
		||||
	bSlice := b.AsSlice()
 | 
			
		||||
 | 
			
		||||
	commonLen := 0
 | 
			
		||||
	for i := 0; i < len(aSlice); i++ {
 | 
			
		||||
		xor := aSlice[i] ^ bSlice[i]
 | 
			
		||||
		if xor == 0 {
 | 
			
		||||
			commonLen += 8
 | 
			
		||||
		} else {
 | 
			
		||||
			commonLen += bits.LeadingZeros8(xor)
 | 
			
		||||
			return commonLen
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return commonLen
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ValueNode represents a node in the path-compressed trie.
 | 
			
		||||
// Each node represents a prefix and can have up to two children.
 | 
			
		||||
type ValueNode[T any] struct {
 | 
			
		||||
	children [2]*ValueNode[T]
 | 
			
		||||
 | 
			
		||||
	// prefix is the full prefix represented by the path to this node.
 | 
			
		||||
	prefix netip.Prefix
 | 
			
		||||
 | 
			
		||||
	value   T
 | 
			
		||||
	isValue bool
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ValueTrie is a path-compressed radix trie that stores network prefixes and their values.
 | 
			
		||||
type ValueTrie[T any] struct {
 | 
			
		||||
	rootV4 *ValueNode[T]
 | 
			
		||||
	rootV6 *ValueNode[T]
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// NewValue creates and initializes a new ValueTrie.
 | 
			
		||||
func NewValue[T any]() *ValueTrie[T] {
 | 
			
		||||
	return &ValueTrie[T]{}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Insert adds or updates a prefix in the trie with the given value.
 | 
			
		||||
func (t *ValueTrie[T]) Insert(p netip.Prefix, value T) {
 | 
			
		||||
	p = p.Masked()
 | 
			
		||||
	addr := p.Addr()
 | 
			
		||||
 | 
			
		||||
	if addr.Is4() {
 | 
			
		||||
		t.rootV4 = t.insert(t.rootV4, p, value)
 | 
			
		||||
	} else {
 | 
			
		||||
		t.rootV6 = t.insert(t.rootV6, p, value)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// insert is the recursive helper for inserting a prefix into the trie.
 | 
			
		||||
func (t *ValueTrie[T]) insert(node *ValueNode[T], p netip.Prefix, value T) *ValueNode[T] {
 | 
			
		||||
	if node == nil {
 | 
			
		||||
		return &ValueNode[T]{prefix: p, value: value, isValue: true}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	addr := p.Addr()
 | 
			
		||||
	commonLen := commonPrefixLen(addr, node.prefix.Addr())
 | 
			
		||||
	pBits := p.Bits()
 | 
			
		||||
	nodeBits := node.prefix.Bits()
 | 
			
		||||
 | 
			
		||||
	if commonLen > pBits {
 | 
			
		||||
		commonLen = pBits
 | 
			
		||||
	}
 | 
			
		||||
	if commonLen > nodeBits {
 | 
			
		||||
		commonLen = nodeBits
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if commonLen == nodeBits && commonLen == pBits {
 | 
			
		||||
		// Exact match, update the value.
 | 
			
		||||
		node.value = value
 | 
			
		||||
		node.isValue = true
 | 
			
		||||
		return node
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if commonLen < nodeBits {
 | 
			
		||||
		// The new prefix diverges from the current node's prefix.
 | 
			
		||||
		// We must split the current node.
 | 
			
		||||
		commonP, _ := node.prefix.Addr().Prefix(commonLen)
 | 
			
		||||
		splitNode := &ValueNode[T]{prefix: commonP}
 | 
			
		||||
 | 
			
		||||
		// The existing node becomes a child of the new split node.
 | 
			
		||||
		bit := getBit(node.prefix.Addr(), commonLen)
 | 
			
		||||
		splitNode.children[bit] = node
 | 
			
		||||
 | 
			
		||||
		if commonLen == pBits {
 | 
			
		||||
			// The inserted prefix is a prefix of the node's original prefix.
 | 
			
		||||
			// The new split node represents the inserted prefix and gets the value.
 | 
			
		||||
			splitNode.value = value
 | 
			
		||||
			splitNode.isValue = true
 | 
			
		||||
		} else {
 | 
			
		||||
			// The two prefixes diverge. Create a new child for the new prefix.
 | 
			
		||||
			bit = getBit(addr, commonLen)
 | 
			
		||||
			splitNode.children[bit] = &ValueNode[T]{prefix: p, value: value, isValue: true}
 | 
			
		||||
		}
 | 
			
		||||
		return splitNode
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// commonLen == nodeBits, meaning the current node's prefix is a prefix of the new one.
 | 
			
		||||
	// We need to descend to a child.
 | 
			
		||||
	bit := getBit(addr, commonLen)
 | 
			
		||||
	node.children[bit] = t.insert(node.children[bit], p, value)
 | 
			
		||||
	return node
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Lookup finds the value associated with the most specific prefix that contains the given IP address.
 | 
			
		||||
func (t *ValueTrie[T]) Lookup(addr netip.Addr) (value T, ok bool) {
 | 
			
		||||
	node := t.rootV4
 | 
			
		||||
	if addr.Is6() {
 | 
			
		||||
		node = t.rootV6
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var lastFoundValue T
 | 
			
		||||
	var found bool
 | 
			
		||||
 | 
			
		||||
	for node != nil {
 | 
			
		||||
		commonLen := commonPrefixLen(addr, node.prefix.Addr())
 | 
			
		||||
		nodeBits := node.prefix.Bits()
 | 
			
		||||
 | 
			
		||||
		// If the address doesn't share a prefix with the node, we can't be in this subtree.
 | 
			
		||||
		if commonLen < nodeBits {
 | 
			
		||||
			break
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// The address is within this node's prefix. If the node holds a value,
 | 
			
		||||
		// it's our current best match.
 | 
			
		||||
		if node.isValue {
 | 
			
		||||
			lastFoundValue = node.value
 | 
			
		||||
			found = true
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// We've matched the whole address, can't go deeper.
 | 
			
		||||
		if commonLen == addr.BitLen() {
 | 
			
		||||
			break
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// Descend to the next child based on the next bit after the node's prefix.
 | 
			
		||||
		bit := getBit(addr, nodeBits)
 | 
			
		||||
		node = node.children[bit]
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return lastFoundValue, found
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Delete removes a prefix from the trie. It returns true if the prefix was found and removed.
 | 
			
		||||
func (t *ValueTrie[T]) Delete(p netip.Prefix) bool {
 | 
			
		||||
	p = p.Masked()
 | 
			
		||||
	addr := p.Addr()
 | 
			
		||||
 | 
			
		||||
	var changed bool
 | 
			
		||||
	if addr.Is4() {
 | 
			
		||||
		t.rootV4, changed = t.delete(t.rootV4, p)
 | 
			
		||||
	} else {
 | 
			
		||||
		t.rootV6, changed = t.delete(t.rootV6, p)
 | 
			
		||||
	}
 | 
			
		||||
	return changed
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// delete is the recursive helper for removing a prefix from the trie.
 | 
			
		||||
func (t *ValueTrie[T]) delete(node *ValueNode[T], p netip.Prefix) (*ValueNode[T], bool) {
 | 
			
		||||
	if node == nil {
 | 
			
		||||
		return nil, false
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	addr := p.Addr()
 | 
			
		||||
	pBits := p.Bits()
 | 
			
		||||
	nodeBits := node.prefix.Bits()
 | 
			
		||||
	commonLen := commonPrefixLen(addr, node.prefix.Addr())
 | 
			
		||||
 | 
			
		||||
	// The prefix is not on this path.
 | 
			
		||||
	if commonLen < nodeBits || commonLen < pBits && pBits < nodeBits {
 | 
			
		||||
		return node, false
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var changed bool
 | 
			
		||||
	if pBits > nodeBits {
 | 
			
		||||
		// The prefix to delete is deeper in the trie. Recurse.
 | 
			
		||||
		bit := getBit(addr, nodeBits)
 | 
			
		||||
		node.children[bit], changed = t.delete(node.children[bit], p)
 | 
			
		||||
	} else if pBits == nodeBits {
 | 
			
		||||
		// This is the node to delete. Unset its value.
 | 
			
		||||
		if !node.isValue {
 | 
			
		||||
			return node, false // Prefix wasn't actually in the trie.
 | 
			
		||||
		}
 | 
			
		||||
		node.isValue = false
 | 
			
		||||
		var zero T
 | 
			
		||||
		node.value = zero
 | 
			
		||||
		changed = true
 | 
			
		||||
	} else { // pBits < nodeBits
 | 
			
		||||
		return node, false // Prefix to delete is shorter, so can't be here.
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if !changed {
 | 
			
		||||
		return node, false
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Post-deletion cleanup:
 | 
			
		||||
	// If the node has no value and can be merged with a single child, do so.
 | 
			
		||||
	if !node.isValue {
 | 
			
		||||
		if node.children[0] != nil && node.children[1] == nil {
 | 
			
		||||
			return node.children[0], true
 | 
			
		||||
		}
 | 
			
		||||
		if node.children[0] == nil && node.children[1] != nil {
 | 
			
		||||
			return node.children[1], true
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// If the node is now a leaf without a value, it can be removed entirely.
 | 
			
		||||
	if !node.isValue && node.children[0] == nil && node.children[1] == nil {
 | 
			
		||||
		return nil, true
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return node, true
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Contains checks if the exact IP address exists in the trie as a full-length prefix.
 | 
			
		||||
func (t *ValueTrie[T]) Contains(addr netip.Addr) bool {
 | 
			
		||||
	prefix := netip.PrefixFrom(addr, addr.BitLen())
 | 
			
		||||
	return t.ContainsPrefix(prefix)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ContainsPrefix checks if the exact prefix exists in the trie.
 | 
			
		||||
func (t *ValueTrie[T]) ContainsPrefix(p netip.Prefix) bool {
 | 
			
		||||
	p = p.Masked()
 | 
			
		||||
	addr := p.Addr()
 | 
			
		||||
	pBits := p.Bits()
 | 
			
		||||
 | 
			
		||||
	node := t.rootV4
 | 
			
		||||
	if addr.Is6() {
 | 
			
		||||
		node = t.rootV6
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for node != nil {
 | 
			
		||||
		commonLen := commonPrefixLen(addr, node.prefix.Addr())
 | 
			
		||||
		nodeBits := node.prefix.Bits()
 | 
			
		||||
 | 
			
		||||
		if commonLen < nodeBits {
 | 
			
		||||
			// Path has diverged. The prefix cannot be in this subtree.
 | 
			
		||||
			return false
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if pBits < nodeBits {
 | 
			
		||||
			// The search prefix is shorter than the node's prefix,
 | 
			
		||||
			// but they share a prefix. e.g. search /16, node is /24.
 | 
			
		||||
			// The /16 is not explicitly in the trie.
 | 
			
		||||
			return false
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if pBits == nodeBits {
 | 
			
		||||
			// Found a node with the exact same prefix length.
 | 
			
		||||
			// Because we also know commonLen >= nodeBits, the prefixes are identical.
 | 
			
		||||
			return node.isValue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// pBits > nodeBits, so we need to go deeper.
 | 
			
		||||
		bit := getBit(addr, nodeBits)
 | 
			
		||||
		node = node.children[bit]
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return false
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// WalkValueFunc is a function called for each prefix in the trie during a walk.
 | 
			
		||||
// Returning false from the function will stop the walk.
 | 
			
		||||
type WalkValueFunc[T any] func(p netip.Prefix, v T) bool
 | 
			
		||||
 | 
			
		||||
// walk is the recursive helper for traversing the trie.
 | 
			
		||||
func walkValue[T any](node *ValueNode[T], f WalkValueFunc[T]) bool {
 | 
			
		||||
	if node == nil {
 | 
			
		||||
		return true
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if node.isValue {
 | 
			
		||||
		if !f(node.prefix, node.value) {
 | 
			
		||||
			return false
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if node.children[0] != nil {
 | 
			
		||||
		if !walkValue(node.children[0], f) {
 | 
			
		||||
			return false
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	if node.children[1] != nil {
 | 
			
		||||
		if !walkValue(node.children[1], f) {
 | 
			
		||||
			return false
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return true
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Walk traverses the trie and calls the given function for each prefix and its value.
 | 
			
		||||
// If the function returns false, the walk is stopped. The order is not guaranteed.
 | 
			
		||||
func (t *ValueTrie[T]) Walk(f WalkValueFunc[T]) {
 | 
			
		||||
	if !walkValue(t.rootV4, f) {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	walkValue(t.rootV6, f)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Merge inserts all prefixes from another Trie into this one.
 | 
			
		||||
func (t *ValueTrie[T]) Merge(other *ValueTrie[T]) {
 | 
			
		||||
	other.Walk(func(p netip.Prefix, v T) bool {
 | 
			
		||||
		t.Insert(p, v)
 | 
			
		||||
		return true // continue walking
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										340
									
								
								dataset/nettrie/valuetrie_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										340
									
								
								dataset/nettrie/valuetrie_test.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,340 @@
 | 
			
		||||
package nettrie
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"maps"
 | 
			
		||||
	"net/netip"
 | 
			
		||||
	"reflect"
 | 
			
		||||
	"slices"
 | 
			
		||||
	"testing"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func TestValueTrie_IPv4(t *testing.T) {
 | 
			
		||||
	trie := NewValue[string]()
 | 
			
		||||
 | 
			
		||||
	routes := map[string]string{
 | 
			
		||||
		"0.0.0.0/0":        "Default",
 | 
			
		||||
		"10.0.0.0/8":       "Private A",
 | 
			
		||||
		"192.168.0.0/16":   "Private B",
 | 
			
		||||
		"192.168.1.0/24":   "LAN",
 | 
			
		||||
		"192.168.1.128/25": "Subnet",
 | 
			
		||||
		"192.168.1.200/32": "Host",
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for s, v := range routes {
 | 
			
		||||
		trie.Insert(netip.MustParsePrefix(s), v)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	testCases := []struct {
 | 
			
		||||
		name        string
 | 
			
		||||
		lookupAddr  string
 | 
			
		||||
		expectedVal string
 | 
			
		||||
		expectedOK  bool
 | 
			
		||||
	}{
 | 
			
		||||
		{"Exact Host Match", "192.168.1.200", "Host", true},
 | 
			
		||||
		{"Subnet Match", "192.168.1.201", "Subnet", true},
 | 
			
		||||
		{"LAN Match", "192.168.1.50", "LAN", true},
 | 
			
		||||
		{"Private B Match", "192.168.255.255", "Private B", true},
 | 
			
		||||
		{"Private A Match", "10.255.255.255", "Private A", true},
 | 
			
		||||
		{"Default Match", "8.8.8.8", "Default", true},
 | 
			
		||||
		{"No Match", "203.0.113.1", "Default", true}, // Falls back to default
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, tc := range testCases {
 | 
			
		||||
		t.Run(tc.name, func(t *testing.T) {
 | 
			
		||||
			addr := netip.MustParseAddr(tc.lookupAddr)
 | 
			
		||||
			val, ok := trie.Lookup(addr)
 | 
			
		||||
 | 
			
		||||
			if ok != tc.expectedOK {
 | 
			
		||||
				t.Errorf("expected ok=%v, got ok=%v", tc.expectedOK, ok)
 | 
			
		||||
			}
 | 
			
		||||
			if val != tc.expectedVal {
 | 
			
		||||
				t.Errorf("expected value=%q, got %q", tc.expectedVal, val)
 | 
			
		||||
			}
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestValueTrie_IPv6(t *testing.T) {
 | 
			
		||||
	trie := NewValue[string]()
 | 
			
		||||
 | 
			
		||||
	routes := map[string]string{
 | 
			
		||||
		"::/0":                   "Default",
 | 
			
		||||
		"2001:db8::/32":          "Global Unicast",
 | 
			
		||||
		"2001:db8:acad::/48":     "Academic",
 | 
			
		||||
		"2001:db8:acad:1::/64":   "CS Department",
 | 
			
		||||
		"2001:db8:acad:1::1/128": "Host Route",
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for s, v := range routes {
 | 
			
		||||
		trie.Insert(netip.MustParsePrefix(s), v)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	testCases := []struct {
 | 
			
		||||
		name        string
 | 
			
		||||
		lookupAddr  string
 | 
			
		||||
		expectedVal string
 | 
			
		||||
		expectedOK  bool
 | 
			
		||||
	}{
 | 
			
		||||
		{"Exact Host Match", "2001:db8:acad:1::1", "Host Route", true},
 | 
			
		||||
		{"CS Dept Match", "2001:db8:acad:1::2", "CS Department", true},
 | 
			
		||||
		{"Academic Match", "2001:db8:acad:2::1", "Academic", true},
 | 
			
		||||
		{"Global Unicast Match", "2001:db8:cafe::1", "Global Unicast", true},
 | 
			
		||||
		{"Default Match", "2606:4700:4700::1111", "Default", true},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, tc := range testCases {
 | 
			
		||||
		t.Run(tc.name, func(t *testing.T) {
 | 
			
		||||
			addr := netip.MustParseAddr(tc.lookupAddr)
 | 
			
		||||
			val, ok := trie.Lookup(addr)
 | 
			
		||||
 | 
			
		||||
			if ok != tc.expectedOK {
 | 
			
		||||
				t.Errorf("expected ok=%v, got ok=%v", tc.expectedOK, ok)
 | 
			
		||||
			}
 | 
			
		||||
			if val != tc.expectedVal {
 | 
			
		||||
				t.Errorf("expected value=%q, got %q", tc.expectedVal, val)
 | 
			
		||||
			}
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestValueTrie_UpdateValue(t *testing.T) {
 | 
			
		||||
	trie := NewValue[string]()
 | 
			
		||||
	prefix := netip.MustParsePrefix("192.168.1.0/24")
 | 
			
		||||
	addr := netip.MustParseAddr("192.168.1.1")
 | 
			
		||||
 | 
			
		||||
	trie.Insert(prefix, "Initial Value")
 | 
			
		||||
	val, _ := trie.Lookup(addr)
 | 
			
		||||
	if val != "Initial Value" {
 | 
			
		||||
		t.Fatalf("expected initial value, got %q", val)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	trie.Insert(prefix, "Updated Value")
 | 
			
		||||
	val, _ = trie.Lookup(addr)
 | 
			
		||||
	if val != "Updated Value" {
 | 
			
		||||
		t.Fatalf("expected updated value, got %q", val)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestValueTrie_Delete(t *testing.T) {
 | 
			
		||||
	trie := NewValue[string]()
 | 
			
		||||
 | 
			
		||||
	trie.Insert(netip.MustParsePrefix("10.0.0.0/8"), "A")
 | 
			
		||||
	trie.Insert(netip.MustParsePrefix("10.0.0.0/24"), "B")
 | 
			
		||||
	trie.Insert(netip.MustParsePrefix("10.0.1.0/24"), "C")
 | 
			
		||||
 | 
			
		||||
	t.Run("Delete Non-Existent", func(t *testing.T) {
 | 
			
		||||
		if trie.Delete(netip.MustParsePrefix("1.2.3.4/32")) {
 | 
			
		||||
			t.Error("deleted a non-existent prefix")
 | 
			
		||||
		}
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	t.Run("Delete Leaf Node", func(t *testing.T) {
 | 
			
		||||
		// Delete 10.0.0.0/24, which is a leaf.
 | 
			
		||||
		if !trie.Delete(netip.MustParsePrefix("10.0.0.0/24")) {
 | 
			
		||||
			t.Fatal("failed to delete 10.0.0.0/24")
 | 
			
		||||
		}
 | 
			
		||||
		// Lookup should now resolve to the parent 10.0.0.0/8
 | 
			
		||||
		val, ok := trie.Lookup(netip.MustParseAddr("10.0.0.1"))
 | 
			
		||||
		if !ok || val != "A" {
 | 
			
		||||
			t.Errorf("lookup failed after delete, got val=%q ok=%v, want val=\"A\" ok=true", val, ok)
 | 
			
		||||
		}
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	t.Run("Delete and Merge", func(t *testing.T) {
 | 
			
		||||
		// Insert a new prefix that causes a split
 | 
			
		||||
		trie.Insert(netip.MustParsePrefix("10.0.0.0/8"), "A-updated")
 | 
			
		||||
 | 
			
		||||
		// We now have 10.0.0.0/8 and 10.0.1.0/24. Deleting 10.0.1.0/24 should
 | 
			
		||||
		// cause the split node to be removed and merged back into 10.0.0.0/8.
 | 
			
		||||
		if !trie.Delete(netip.MustParsePrefix("10.0.1.0/24")) {
 | 
			
		||||
			t.Fatal("failed to delete 10.0.1.0/24")
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// A lookup for 10.0.1.1 should now match 10.0.0.0/8
 | 
			
		||||
		val, ok := trie.Lookup(netip.MustParseAddr("10.0.1.1"))
 | 
			
		||||
		if !ok || val != "A-updated" {
 | 
			
		||||
			t.Errorf("lookup failed after merge, got val=%q ok=%v, want val=\"A-updated\" ok=true", val, ok)
 | 
			
		||||
		}
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	t.Run("Delete Prefix Of Another", func(t *testing.T) {
 | 
			
		||||
		// Delete 10.0.0.0/8
 | 
			
		||||
		trie.Insert(netip.MustParsePrefix("10.0.0.0/8"), "A")
 | 
			
		||||
		trie.Insert(netip.MustParsePrefix("10.1.0.0/16"), "D")
 | 
			
		||||
 | 
			
		||||
		if !trie.Delete(netip.MustParsePrefix("10.0.0.0/8")) {
 | 
			
		||||
			t.Fatal("failed to delete 10.0.0.0/8")
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// Lookup for 10.0.0.1 should now fail (no default route)
 | 
			
		||||
		_, ok := trie.Lookup(netip.MustParseAddr("10.0.0.1"))
 | 
			
		||||
		if ok {
 | 
			
		||||
			t.Error("lookup for 10.0.0.1 should have failed")
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// Lookup for 10.1.0.1 should still succeed
 | 
			
		||||
		val, ok := trie.Lookup(netip.MustParseAddr("10.1.0.1"))
 | 
			
		||||
		if !ok || val != "D" {
 | 
			
		||||
			t.Error("lookup for 10.1.0.1 failed unexpectedly")
 | 
			
		||||
		}
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestValueTrie_Contains(t *testing.T) {
 | 
			
		||||
	trie := NewValue[string]()
 | 
			
		||||
 | 
			
		||||
	prefixes := []string{
 | 
			
		||||
		"10.0.0.0/8",
 | 
			
		||||
		"192.168.1.0/24",
 | 
			
		||||
		"192.168.1.200/32",
 | 
			
		||||
		"2001:db8::/32",
 | 
			
		||||
		"2001:db8:acad:1::1/128",
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, s := range prefixes {
 | 
			
		||||
		trie.Insert(netip.MustParsePrefix(s), "present")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	t.Run("ContainsPrefix", func(t *testing.T) {
 | 
			
		||||
		testCases := []struct {
 | 
			
		||||
			prefix string
 | 
			
		||||
			want   bool
 | 
			
		||||
		}{
 | 
			
		||||
			{"10.0.0.0/8", true},
 | 
			
		||||
			{"192.168.1.200/32", true},
 | 
			
		||||
			{"2001:db8::/32", true},
 | 
			
		||||
			{"2001:db8:acad:1::1/128", true},
 | 
			
		||||
			{"10.0.0.0/9", false},     // shorter parent, but not exact
 | 
			
		||||
			{"192.168.1.0/25", false}, // non-existent child
 | 
			
		||||
			{"172.16.0.0/12", false},  // completely different prefix
 | 
			
		||||
			{"2001:db8::/48", false},  // non-existent child
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		for _, tc := range testCases {
 | 
			
		||||
			p := netip.MustParsePrefix(tc.prefix)
 | 
			
		||||
			if got := trie.ContainsPrefix(p); got != tc.want {
 | 
			
		||||
				t.Errorf("ContainsPrefix(%q) = %v, want %v", tc.prefix, got, tc.want)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	t.Run("Contains", func(t *testing.T) {
 | 
			
		||||
		testCases := []struct {
 | 
			
		||||
			addr string
 | 
			
		||||
			want bool
 | 
			
		||||
		}{
 | 
			
		||||
			{"192.168.1.200", true},
 | 
			
		||||
			{"2001:db8:acad:1::1", true},
 | 
			
		||||
			{"192.168.1.201", false}, // In /24 range, but not a /32 host route
 | 
			
		||||
			{"10.0.0.1", false},      // In /8 range, but not a /32 host route
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		for _, tc := range testCases {
 | 
			
		||||
			a := netip.MustParseAddr(tc.addr)
 | 
			
		||||
			if got := trie.Contains(a); got != tc.want {
 | 
			
		||||
				t.Errorf("Contains(%q) = %v, want %v", tc.addr, got, tc.want)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestValueTrie_Walk(t *testing.T) {
 | 
			
		||||
	trie := NewValue[string]()
 | 
			
		||||
	prefixes := map[string]string{
 | 
			
		||||
		"10.0.0.0/8":         "A",
 | 
			
		||||
		"192.168.1.0/24":     "B",
 | 
			
		||||
		"2001:db8::/32":      "C",
 | 
			
		||||
		"172.16.0.0/12":      "D",
 | 
			
		||||
		"2001:db8:acad::/48": "E",
 | 
			
		||||
	}
 | 
			
		||||
	for s, v := range prefixes {
 | 
			
		||||
		trie.Insert(netip.MustParsePrefix(s), v)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	t.Run("Walk all prefixes", func(t *testing.T) {
 | 
			
		||||
		walked := make(map[string]string)
 | 
			
		||||
		trie.Walk(func(p netip.Prefix, v string) bool {
 | 
			
		||||
			walked[p.String()] = v
 | 
			
		||||
			return true
 | 
			
		||||
		})
 | 
			
		||||
 | 
			
		||||
		if !maps.Equal(walked, prefixes) {
 | 
			
		||||
			t.Errorf("walked prefixes mismatch:\nexpected: %v\ngot:      %v", prefixes, walked)
 | 
			
		||||
		}
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	t.Run("Stop walk early", func(t *testing.T) {
 | 
			
		||||
		count := 0
 | 
			
		||||
		trie.Walk(func(p netip.Prefix, v string) bool {
 | 
			
		||||
			count++
 | 
			
		||||
			return count < 3 // Stop after visiting 3 prefixes
 | 
			
		||||
		})
 | 
			
		||||
 | 
			
		||||
		if count != 3 {
 | 
			
		||||
			t.Errorf("expected walk to stop after 3 prefixes, but it visited %d", count)
 | 
			
		||||
		}
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	t.Run("Stop walk between families", func(t *testing.T) {
 | 
			
		||||
		stopTrie := NewValue[int]()
 | 
			
		||||
		stopTrie.Insert(netip.MustParsePrefix("10.0.0.0/8"), 1)
 | 
			
		||||
		stopTrie.Insert(netip.MustParsePrefix("2001:db8::/32"), 2)
 | 
			
		||||
 | 
			
		||||
		count := 0
 | 
			
		||||
		stopTrie.Walk(func(p netip.Prefix, v int) bool {
 | 
			
		||||
			count++
 | 
			
		||||
			return false
 | 
			
		||||
		})
 | 
			
		||||
 | 
			
		||||
		if count != 1 {
 | 
			
		||||
			t.Errorf("expected walk to stop after 1 prefix, but it visited %d", count)
 | 
			
		||||
		}
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestValueTrie_Merge(t *testing.T) {
 | 
			
		||||
	trieA := NewValue[string]()
 | 
			
		||||
	trieA.Insert(netip.MustParsePrefix("10.0.0.0/8"), "net10")
 | 
			
		||||
	trieA.Insert(netip.MustParsePrefix("192.168.1.0/24"), "lan_A")
 | 
			
		||||
	trieA.Insert(netip.MustParsePrefix("2001:db8::/32"), "v6_A")
 | 
			
		||||
 | 
			
		||||
	trieB := NewValue[string]()
 | 
			
		||||
	trieB.Insert(netip.MustParsePrefix("10.1.0.0/16"), "net10_subnet")
 | 
			
		||||
	trieB.Insert(netip.MustParsePrefix("192.168.1.0/24"), "lan_B_override") // Overlap
 | 
			
		||||
	trieB.Insert(netip.MustParsePrefix("172.16.0.0/12"), "corp")
 | 
			
		||||
	trieB.Insert(netip.MustParsePrefix("2001:db8:acad::/48"), "v6_B")
 | 
			
		||||
 | 
			
		||||
	trieA.Merge(trieB)
 | 
			
		||||
 | 
			
		||||
	expected := map[string]string{
 | 
			
		||||
		"10.0.0.0/8":         "net10",
 | 
			
		||||
		"192.168.1.0/24":     "lan_B_override", // Value should be from trieB
 | 
			
		||||
		"2001:db8::/32":      "v6_A",
 | 
			
		||||
		"10.1.0.0/16":        "net10_subnet",
 | 
			
		||||
		"172.16.0.0/12":      "corp",
 | 
			
		||||
		"2001:db8:acad::/48": "v6_B",
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	actual := make(map[string]string)
 | 
			
		||||
	trieA.Walk(func(p netip.Prefix, v string) bool {
 | 
			
		||||
		actual[p.String()] = v
 | 
			
		||||
		return true
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	if !reflect.DeepEqual(expected, actual) {
 | 
			
		||||
		t.Errorf("Merge result incorrect.\nExpected: %v\nGot: %v", expected, actual)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Verify trieB is unchanged by collecting its prefixes
 | 
			
		||||
	bPrefixes := []string{}
 | 
			
		||||
	trieB.Walk(func(p netip.Prefix, v string) bool {
 | 
			
		||||
		bPrefixes = append(bPrefixes, p.String())
 | 
			
		||||
		return true
 | 
			
		||||
	})
 | 
			
		||||
	expectedBPrefixes := []string{"10.1.0.0/16", "172.16.0.0/12", "192.168.1.0/24", "2001:db8:acad::/48"}
 | 
			
		||||
	slices.Sort(bPrefixes)
 | 
			
		||||
	slices.Sort(expectedBPrefixes)
 | 
			
		||||
	if !reflect.DeepEqual(bPrefixes, expectedBPrefixes) {
 | 
			
		||||
		t.Errorf("trieB was modified during merge.\nExpected: %v\nGot: %v", expectedBPrefixes, bPrefixes)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
@@ -2,10 +2,37 @@ package dataset
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"net"
 | 
			
		||||
	"net/netip"
 | 
			
		||||
 | 
			
		||||
	"git.maze.io/maze/styx/dataset/nettrie"
 | 
			
		||||
	"github.com/yl2chen/cidranger"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type NetworkTrie struct {
 | 
			
		||||
	*nettrie.Trie
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewNetworkTrie(prefixes ...netip.Prefix) *NetworkTrie {
 | 
			
		||||
	trie := &NetworkTrie{
 | 
			
		||||
		Trie: nettrie.New(),
 | 
			
		||||
	}
 | 
			
		||||
	for _, prefix := range prefixes {
 | 
			
		||||
		trie.Insert(prefix)
 | 
			
		||||
	}
 | 
			
		||||
	return trie
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (trie *NetworkTrie) ContainsIP(ip net.IP) bool {
 | 
			
		||||
	if ip == nil {
 | 
			
		||||
		return false
 | 
			
		||||
	}
 | 
			
		||||
	addr, ok := netip.AddrFromSlice(ip)
 | 
			
		||||
	if !ok {
 | 
			
		||||
		return false
 | 
			
		||||
	}
 | 
			
		||||
	return trie.Contains(addr)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type NetworkTree struct {
 | 
			
		||||
	ranger cidranger.Ranger
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -1,21 +1,20 @@
 | 
			
		||||
package dataset
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bufio"
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"io"
 | 
			
		||||
	"io/fs"
 | 
			
		||||
	"net"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"net/netip"
 | 
			
		||||
	"net/url"
 | 
			
		||||
	"os"
 | 
			
		||||
	"slices"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"git.maze.io/maze/styx/dataset/parser"
 | 
			
		||||
	_ "github.com/mattn/go-sqlite3" // SQLite3 driver
 | 
			
		||||
	"github.com/miekg/dns"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type Storage interface {
 | 
			
		||||
@@ -27,11 +26,13 @@ type Storage interface {
 | 
			
		||||
 | 
			
		||||
	Clients() (Clients, error)
 | 
			
		||||
	ClientByID(int64) (Client, error)
 | 
			
		||||
	ClientByIP(net.IP) (Client, error)
 | 
			
		||||
	ClientByAddr(netip.Addr) (Client, error)
 | 
			
		||||
	// ClientByIP(net.IP) (Client, error)
 | 
			
		||||
	SaveClient(*Client) error
 | 
			
		||||
	DeleteClient(Client) error
 | 
			
		||||
 | 
			
		||||
	Lists() ([]List, error)
 | 
			
		||||
	ListsByGroup(Group) ([]List, error)
 | 
			
		||||
	ListByID(int64) (List, error)
 | 
			
		||||
	SaveList(*List) error
 | 
			
		||||
	DeleteList(List) error
 | 
			
		||||
@@ -44,7 +45,6 @@ type Group struct {
 | 
			
		||||
	Description string    `json:"description"`
 | 
			
		||||
	CreatedAt   time.Time `json:"created_at" bstore:"nonzero"`
 | 
			
		||||
	UpdatedAt   time.Time `json:"updated_at" bstore:"nonzero"`
 | 
			
		||||
	Storage     Storage   `json:"-" bstore:"-"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Client struct {
 | 
			
		||||
@@ -56,11 +56,6 @@ type Client struct {
 | 
			
		||||
	Groups      []Group   `json:"groups,omitempty" bstore:"-"`
 | 
			
		||||
	CreatedAt   time.Time `json:"created_at" bstore:"nonzero"`
 | 
			
		||||
	UpdatedAt   time.Time `json:"updated_at" bstore:"nonzero"`
 | 
			
		||||
	Storage     Storage   `json:"-" bstore:"-"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type WithClient interface {
 | 
			
		||||
	Client() (Client, error)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type ClientGroup struct {
 | 
			
		||||
@@ -80,6 +75,15 @@ func (c *Client) ContainsIP(ip net.IP) bool {
 | 
			
		||||
	return ipnet.Contains(ip)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *Client) ContainsAddr(ip netip.Addr) bool {
 | 
			
		||||
	return c.Prefix().Contains(ip)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c Client) Prefix() netip.Prefix {
 | 
			
		||||
	ip, _ := netip.ParseAddr(c.IP)
 | 
			
		||||
	return netip.PrefixFrom(ip, c.Mask)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *Client) String() string {
 | 
			
		||||
	ipnet := &net.IPNet{
 | 
			
		||||
		IP:   net.ParseIP(c.IP),
 | 
			
		||||
@@ -136,28 +140,29 @@ type List struct {
 | 
			
		||||
	UpdatedAt    time.Time     `json:"updated_at"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (list *List) Domains() (*DomainTree, error) {
 | 
			
		||||
func (list *List) Networks() (*NetworkTrie, error) {
 | 
			
		||||
	if list.Type != ListTypeNetwork {
 | 
			
		||||
		return nil, nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	prefixes, _, err := parser.ParseNetworks(bytes.NewReader(list.Cache))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	return NewNetworkTrie(prefixes...), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (list *List) Domains() (*DomainTrie, error) {
 | 
			
		||||
	if list.Type != ListTypeDomain {
 | 
			
		||||
		return nil, nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var (
 | 
			
		||||
		tree = NewDomainList()
 | 
			
		||||
		scan = bufio.NewScanner(bytes.NewReader(list.Cache))
 | 
			
		||||
	)
 | 
			
		||||
	for scan.Scan() {
 | 
			
		||||
		line := strings.TrimSpace(scan.Text())
 | 
			
		||||
		if line == "" || line[0] == '#' {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		if labels, ok := dns.IsDomainName(line); ok && labels >= 2 {
 | 
			
		||||
			tree.Add(line)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	if err := scan.Err(); err != nil {
 | 
			
		||||
	domains, _, err := parser.ParseDomains(bytes.NewReader(list.Cache))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	return tree, nil
 | 
			
		||||
 | 
			
		||||
	return NewDomainTrie(list.Permit, domains...)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (list *List) Update() (updated bool, err error) {
 | 
			
		||||
 
 | 
			
		||||
@@ -5,6 +5,7 @@ import (
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"net"
 | 
			
		||||
	"net/netip"
 | 
			
		||||
	"path/filepath"
 | 
			
		||||
	"slices"
 | 
			
		||||
	"strings"
 | 
			
		||||
@@ -20,13 +21,18 @@ type bstoreStorage struct {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func OpenBStore(name string) (Storage, error) {
 | 
			
		||||
	log := logger.StandardLog.Value("database", name)
 | 
			
		||||
 | 
			
		||||
	if !filepath.IsAbs(name) {
 | 
			
		||||
		var err error
 | 
			
		||||
		if name, err = filepath.Abs(name); err != nil {
 | 
			
		||||
			log.Err(err).Error("Opening BoltDB storage failed; invalid path")
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
		log = log.Value("database", name)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	log.Debug("Opening BoltDB storage")
 | 
			
		||||
	ctx := context.Background()
 | 
			
		||||
	db, err := bstore.Open(ctx, name, nil,
 | 
			
		||||
		Group{},
 | 
			
		||||
@@ -36,6 +42,7 @@ func OpenBStore(name string) (Storage, error) {
 | 
			
		||||
		ListGroup{},
 | 
			
		||||
	)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		log.Err(err).Error("Opening BoltDB storage failed")
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
@@ -47,6 +54,7 @@ func OpenBStore(name string) (Storage, error) {
 | 
			
		||||
	)
 | 
			
		||||
 | 
			
		||||
	if defaultGroup, err = s.GroupByName("Default"); errors.Is(err, bstore.ErrAbsent) {
 | 
			
		||||
		log.Debug("Creating default group")
 | 
			
		||||
		defaultGroup = Group{
 | 
			
		||||
			Name:        "Default",
 | 
			
		||||
			IsEnabled:   true,
 | 
			
		||||
@@ -63,6 +71,7 @@ func OpenBStore(name string) (Storage, error) {
 | 
			
		||||
		FilterFn(func(client Client) bool {
 | 
			
		||||
			return net.ParseIP(client.IP).Equal(net.ParseIP("0.0.0.0")) && client.Mask == 0
 | 
			
		||||
		}).Get(); errors.Is(err, bstore.ErrAbsent) {
 | 
			
		||||
		log.Debug("Creating default IPv4 clients")
 | 
			
		||||
		defaultClient4 = Client{
 | 
			
		||||
			Network:     "ipv4",
 | 
			
		||||
			IP:          "0.0.0.0",
 | 
			
		||||
@@ -83,6 +92,7 @@ func OpenBStore(name string) (Storage, error) {
 | 
			
		||||
		FilterFn(func(client Client) bool {
 | 
			
		||||
			return net.ParseIP(client.IP).Equal(net.ParseIP("::")) && client.Mask == 0
 | 
			
		||||
		}).Get(); errors.Is(err, bstore.ErrAbsent) {
 | 
			
		||||
		log.Debug("Creating default IPv6 clients")
 | 
			
		||||
		defaultClient6 = Client{
 | 
			
		||||
			Network:     "ipv6",
 | 
			
		||||
			IP:          "::",
 | 
			
		||||
@@ -100,6 +110,7 @@ func OpenBStore(name string) (Storage, error) {
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Start updater
 | 
			
		||||
	log.Trace("Starting list updater")
 | 
			
		||||
	NewUpdater(s)
 | 
			
		||||
 | 
			
		||||
	return s, nil
 | 
			
		||||
@@ -197,7 +208,12 @@ func (s *bstoreStorage) ClientByID(id int64) (Client, error) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *bstoreStorage) ClientByIP(ip net.IP) (Client, error) {
 | 
			
		||||
	if ip == nil {
 | 
			
		||||
	addr, _ := netip.AddrFromSlice(ip)
 | 
			
		||||
	return s.ClientByAddr(addr)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *bstoreStorage) ClientByAddr(addr netip.Addr) (Client, error) {
 | 
			
		||||
	if !addr.IsValid() {
 | 
			
		||||
		return Client{}, ErrNotExist{Object: "client"}
 | 
			
		||||
	}
 | 
			
		||||
	var (
 | 
			
		||||
@@ -205,9 +221,9 @@ func (s *bstoreStorage) ClientByIP(ip net.IP) (Client, error) {
 | 
			
		||||
		clients Clients
 | 
			
		||||
		network string
 | 
			
		||||
	)
 | 
			
		||||
	if ip4 := ip.To4(); ip4 != nil {
 | 
			
		||||
	if addr.Is4() {
 | 
			
		||||
		network = "ipv4"
 | 
			
		||||
	} else if ip6 := ip.To16(); ip6 != nil {
 | 
			
		||||
	} else {
 | 
			
		||||
		network = "ipv6"
 | 
			
		||||
	}
 | 
			
		||||
	if network == "" {
 | 
			
		||||
@@ -216,7 +232,7 @@ func (s *bstoreStorage) ClientByIP(ip net.IP) (Client, error) {
 | 
			
		||||
	for client, err := range bstore.QueryDB[Client](ctx, s.db).
 | 
			
		||||
		FilterEqual("Network", network).
 | 
			
		||||
		FilterFn(func(client Client) bool {
 | 
			
		||||
			return client.ContainsIP(ip)
 | 
			
		||||
			return client.ContainsAddr(addr)
 | 
			
		||||
		}).All() {
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return Client{}, err
 | 
			
		||||
@@ -320,6 +336,26 @@ func (s *bstoreStorage) Lists() ([]List, error) {
 | 
			
		||||
	return lists, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *bstoreStorage) ListsByGroup(group Group) ([]List, error) {
 | 
			
		||||
	ctx := context.Background()
 | 
			
		||||
	ids := make([]int64, 0)
 | 
			
		||||
	for item, err := range bstore.QueryDB[ListGroup](ctx, s.db).FilterEqual("GroupID", group.ID).All() {
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
		ids = append(ids, item.ListID)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var lists []List
 | 
			
		||||
	for list, err := range bstore.QueryDB[List](ctx, s.db).FilterIDs(ids).All() {
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
		lists = append(lists, list)
 | 
			
		||||
	}
 | 
			
		||||
	return lists, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *bstoreStorage) ListByID(id int64) (List, error) {
 | 
			
		||||
	ctx := context.Background()
 | 
			
		||||
	list, err := bstore.QueryDB[List](ctx, s.db).FilterID(id).Get()
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										142
									
								
								dataset/storage_cache.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										142
									
								
								dataset/storage_cache.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,142 @@
 | 
			
		||||
package dataset
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"net"
 | 
			
		||||
	"net/netip"
 | 
			
		||||
	"sync"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"git.maze.io/maze/styx/logger"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const MinCacheExpire = 10 * time.Second
 | 
			
		||||
 | 
			
		||||
type cache struct {
 | 
			
		||||
	Storage
 | 
			
		||||
	expire       time.Duration
 | 
			
		||||
	groupByID    sync.Map
 | 
			
		||||
	clientByAddr sync.Map
 | 
			
		||||
	listByGroup  sync.Map
 | 
			
		||||
	closed       chan struct{}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type cacheItem struct {
 | 
			
		||||
	cachedAt time.Time
 | 
			
		||||
	value    any
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Cache items returned from a Storage for the specified duration.
 | 
			
		||||
//
 | 
			
		||||
// Does not cache negative hits.
 | 
			
		||||
func Cache(storage Storage, expire time.Duration) Storage {
 | 
			
		||||
	if expire < MinCacheExpire {
 | 
			
		||||
		expire = MinCacheExpire
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	logger.StandardLog.Value("expire", expire).Debug("Caching Storage responses")
 | 
			
		||||
	s := &cache{
 | 
			
		||||
		Storage: storage,
 | 
			
		||||
		expire:  expire,
 | 
			
		||||
		closed:  make(chan struct{}, 1),
 | 
			
		||||
	}
 | 
			
		||||
	go s.cleanUpTimer()
 | 
			
		||||
	return s
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *cache) cleanUpTimer() {
 | 
			
		||||
	ticker := time.NewTicker(s.expire)
 | 
			
		||||
	defer ticker.Stop()
 | 
			
		||||
 | 
			
		||||
	for {
 | 
			
		||||
		select {
 | 
			
		||||
		case <-s.closed:
 | 
			
		||||
			return
 | 
			
		||||
 | 
			
		||||
		case now := <-ticker.C:
 | 
			
		||||
			logger.StandardLog.Trace("Cache cleanup running")
 | 
			
		||||
			s.cleanUp(now, &s.groupByID)
 | 
			
		||||
			s.cleanUp(now, &s.clientByAddr)
 | 
			
		||||
			s.cleanUp(now, &s.listByGroup)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *cache) cleanUp(now time.Time, cacheMap *sync.Map) {
 | 
			
		||||
	cacheMap.Range(func(key, item any) bool {
 | 
			
		||||
		cached := item.(cacheItem)
 | 
			
		||||
		if ago := now.Sub(cached.cachedAt); ago >= s.expire {
 | 
			
		||||
			logger.StandardLog.Values(logger.Values{
 | 
			
		||||
				"ago":  ago,
 | 
			
		||||
				"type": fmt.Sprintf("%T", cached.value),
 | 
			
		||||
				"item": fmt.Sprintf("%s", cached.value),
 | 
			
		||||
			}).Debug("Cache removing expired item")
 | 
			
		||||
			cacheMap.Delete(key)
 | 
			
		||||
		}
 | 
			
		||||
		return true
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *cache) load(now time.Time, cacheMap *sync.Map, key any) (value any, ok bool) {
 | 
			
		||||
	var item any
 | 
			
		||||
	if item, ok = cacheMap.Load(key); !ok {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	cached := item.(cacheItem)
 | 
			
		||||
	if now.Sub(cached.cachedAt) < s.expire {
 | 
			
		||||
		return cached.value, true
 | 
			
		||||
	}
 | 
			
		||||
	return nil, false
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *cache) save(now time.Time, cacheMap *sync.Map, key, value any) {
 | 
			
		||||
	cacheMap.Store(key, cacheItem{
 | 
			
		||||
		cachedAt: now,
 | 
			
		||||
		value:    value,
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *cache) GroupByID(id int64) (Group, error) {
 | 
			
		||||
	now := time.Now()
 | 
			
		||||
	if value, ok := s.load(now, &s.groupByID, id); ok {
 | 
			
		||||
		return value.(Group), nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	group, err := s.Storage.GroupByID(id)
 | 
			
		||||
	if err == nil {
 | 
			
		||||
		s.save(now, &s.groupByID, id, group)
 | 
			
		||||
	}
 | 
			
		||||
	return group, err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *cache) ClientByIP(ip net.IP) (Client, error) {
 | 
			
		||||
	addr, _ := netip.AddrFromSlice(ip)
 | 
			
		||||
	return s.ClientByAddr(addr)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *cache) ClientByAddr(ip netip.Addr) (Client, error) {
 | 
			
		||||
	now := time.Now()
 | 
			
		||||
	if value, ok := s.load(now, &s.clientByAddr, ip); ok {
 | 
			
		||||
		return value.(Client), nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	client, err := s.Storage.ClientByAddr(ip)
 | 
			
		||||
	if err == nil {
 | 
			
		||||
		s.save(now, &s.clientByAddr, ip, client)
 | 
			
		||||
	}
 | 
			
		||||
	return client, err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *cache) ListsByGroup(group Group) ([]List, error) {
 | 
			
		||||
	now := time.Now()
 | 
			
		||||
	if value, ok := s.load(now, &s.listByGroup, group.ID); ok {
 | 
			
		||||
		return value.([]List), nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	lists, err := s.Storage.ListsByGroup(group)
 | 
			
		||||
	if err == nil {
 | 
			
		||||
		s.save(now, &s.listByGroup, group.ID, lists)
 | 
			
		||||
	}
 | 
			
		||||
	return lists, err
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										3
									
								
								go.mod
									
									
									
									
									
								
							
							
						
						
									
										3
									
								
								go.mod
									
									
									
									
									
								
							@@ -11,6 +11,7 @@ require (
 | 
			
		||||
	github.com/open-policy-agent/opa v1.9.0
 | 
			
		||||
	github.com/rs/zerolog v1.34.0
 | 
			
		||||
	github.com/sirupsen/logrus v1.9.4-0.20230606125235-dd1b4c2e81af
 | 
			
		||||
	github.com/stretchr/testify v1.11.1
 | 
			
		||||
	github.com/yl2chen/cidranger v1.0.2
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
@@ -20,6 +21,7 @@ require (
 | 
			
		||||
	github.com/apparentlymart/go-textseg/v15 v15.0.0 // indirect
 | 
			
		||||
	github.com/beorn7/perks v1.0.1 // indirect
 | 
			
		||||
	github.com/cespare/xxhash/v2 v2.3.0 // indirect
 | 
			
		||||
	github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
 | 
			
		||||
	github.com/decred/dcrd/dcrec/secp256k1/v4 v4.4.0 // indirect
 | 
			
		||||
	github.com/dgraph-io/ristretto/v2 v2.3.0 // indirect
 | 
			
		||||
	github.com/go-ini/ini v1.67.0 // indirect
 | 
			
		||||
@@ -42,6 +44,7 @@ require (
 | 
			
		||||
	github.com/mattn/go-isatty v0.0.20 // indirect
 | 
			
		||||
	github.com/mitchellh/go-wordwrap v1.0.1 // indirect
 | 
			
		||||
	github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
 | 
			
		||||
	github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
 | 
			
		||||
	github.com/prometheus/client_golang v1.23.2 // indirect
 | 
			
		||||
	github.com/prometheus/client_model v0.6.2 // indirect
 | 
			
		||||
	github.com/prometheus/common v0.66.1 // indirect
 | 
			
		||||
 
 | 
			
		||||
@@ -12,12 +12,12 @@ import (
 | 
			
		||||
	proxy "git.maze.io/maze/styx/proxy"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func NewRequestHandler(p *Policy) proxy.RequestHandler {
 | 
			
		||||
	log := logger.StandardLog.Value("policy", p.name)
 | 
			
		||||
func NewRequestHandler(policy *Policy) proxy.RequestHandler {
 | 
			
		||||
	log := logger.StandardLog.Value("policy", policy.name)
 | 
			
		||||
	return proxy.RequestHandlerFunc(func(ctx proxy.Context) (*http.Request, *http.Response) {
 | 
			
		||||
		input := NewInputFromRequest(ctx, ctx.Request())
 | 
			
		||||
		input.logValues(log).Trace("Running request handler")
 | 
			
		||||
		result, err := p.Query(input)
 | 
			
		||||
		result, err := policy.Query(input, proxy.PolicyQueryOptions(ctx)...)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			log.Err(err).Error("Error evaulating policy")
 | 
			
		||||
			return nil, nil
 | 
			
		||||
@@ -32,12 +32,12 @@ func NewRequestHandler(p *Policy) proxy.RequestHandler {
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewDialHandler(p *Policy) proxy.DialHandler {
 | 
			
		||||
	log := logger.StandardLog.Value("policy", p.name)
 | 
			
		||||
func NewDialHandler(policy *Policy) proxy.DialHandler {
 | 
			
		||||
	log := logger.StandardLog.Value("policy", policy.name)
 | 
			
		||||
	return proxy.DialHandlerFunc(func(ctx proxy.Context, req *http.Request) (net.Conn, error) {
 | 
			
		||||
		input := NewInputFromRequest(ctx, req)
 | 
			
		||||
		input.logValues(log).Trace("Running dial handler")
 | 
			
		||||
		result, err := p.Query(input)
 | 
			
		||||
		result, err := policy.Query(input, proxy.PolicyQueryOptions(ctx)...)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			log.Err(err).Error("Error evaulating policy")
 | 
			
		||||
			return nil, nil
 | 
			
		||||
 
 | 
			
		||||
@@ -10,7 +10,6 @@ import (
 | 
			
		||||
	"net/url"
 | 
			
		||||
	"strconv"
 | 
			
		||||
 | 
			
		||||
	"git.maze.io/maze/styx/dataset"
 | 
			
		||||
	"git.maze.io/maze/styx/internal/netutil"
 | 
			
		||||
	"git.maze.io/maze/styx/logger"
 | 
			
		||||
	proxy "git.maze.io/maze/styx/proxy"
 | 
			
		||||
@@ -48,14 +47,16 @@ func NewInputFromConn(c net.Conn) *Input {
 | 
			
		||||
		TLS:     NewTLSFromConn(c),
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if wcl, ok := c.(dataset.WithClient); ok {
 | 
			
		||||
		client, err := wcl.Client()
 | 
			
		||||
		if err == nil {
 | 
			
		||||
			input.Context["client_id"] = client.ID
 | 
			
		||||
			input.Context["client_description"] = client.Description
 | 
			
		||||
			input.Context["groups"] = client.Groups
 | 
			
		||||
	/*
 | 
			
		||||
		if wcl, ok := c.(dataset.WithClient); ok {
 | 
			
		||||
			client, err := wcl.Client()
 | 
			
		||||
			if err == nil {
 | 
			
		||||
				input.Context["client_id"] = client.ID
 | 
			
		||||
				input.Context["client_description"] = client.Description
 | 
			
		||||
				input.Context["groups"] = client.Groups
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	*/
 | 
			
		||||
 | 
			
		||||
	if ctx, ok := c.(proxy.Context); ok {
 | 
			
		||||
		input.Context["local"] = NewClientFromAddr(ctx.LocalAddr())
 | 
			
		||||
 
 | 
			
		||||
@@ -1,6 +1,7 @@
 | 
			
		||||
package policy
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bufio"
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"context"
 | 
			
		||||
	"errors"
 | 
			
		||||
@@ -10,6 +11,7 @@ import (
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"os"
 | 
			
		||||
	"path/filepath"
 | 
			
		||||
	"strings"
 | 
			
		||||
 | 
			
		||||
	"github.com/go-viper/mapstructure/v2"
 | 
			
		||||
	"github.com/open-policy-agent/opa/v1/ast"
 | 
			
		||||
@@ -168,12 +170,17 @@ func (r *Result) Response(ctx proxy.Context) (*http.Response, error) {
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (p *Policy) Query(input *Input) (*Result, error) {
 | 
			
		||||
func (p *Policy) Query(input *Input, options ...func(*rego.Rego)) (*Result, error) {
 | 
			
		||||
	log := logger.StandardLog.Value("policy", p.name)
 | 
			
		||||
	log.Trace("Evaluating policy")
 | 
			
		||||
 | 
			
		||||
	var regoOptions = append(p.options, rego.Input(input))
 | 
			
		||||
	for _, option := range options {
 | 
			
		||||
		regoOptions = append(regoOptions, option)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var (
 | 
			
		||||
		rego    = rego.New(append(p.options, rego.Input(input))...)
 | 
			
		||||
		rego    = rego.New(regoOptions...)
 | 
			
		||||
		ctx     = context.Background()
 | 
			
		||||
		rs, err = rego.Eval(ctx)
 | 
			
		||||
	)
 | 
			
		||||
@@ -200,3 +207,27 @@ func (p *Policy) Query(input *Input) (*Result, error) {
 | 
			
		||||
	}
 | 
			
		||||
	return result, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// PackageFromFile reads the "package" stanza from the provided Rego policy file.
 | 
			
		||||
//
 | 
			
		||||
// If no stanza can be found, an error is returned.
 | 
			
		||||
func PackageFromFile(name string) (string, error) {
 | 
			
		||||
	f, err := os.Open(name)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", err
 | 
			
		||||
	}
 | 
			
		||||
	defer func() { _ = f.Close() }()
 | 
			
		||||
 | 
			
		||||
	scanner := bufio.NewScanner(f)
 | 
			
		||||
	for scanner.Scan() {
 | 
			
		||||
		text := strings.TrimSpace(scanner.Text())
 | 
			
		||||
		part := strings.Fields(text)
 | 
			
		||||
		if len(part) > 1 && part[0] == "package" {
 | 
			
		||||
			return part[1], nil
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	if err := scanner.Err(); err != nil {
 | 
			
		||||
		return "", err
 | 
			
		||||
	}
 | 
			
		||||
	return "", fmt.Errorf("policy: can't detemine package name of %s", name)
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -45,8 +45,11 @@ type Context interface {
 | 
			
		||||
	// Response is the response that will be sent back to the client.
 | 
			
		||||
	Response() *http.Response
 | 
			
		||||
 | 
			
		||||
	// Logger for this context.
 | 
			
		||||
	Logger() logger.Structured
 | 
			
		||||
 | 
			
		||||
	// Client group.
 | 
			
		||||
	Client() (dataset.Client, error)
 | 
			
		||||
	Storage() dataset.Storage
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type WithCertificateAuthority interface {
 | 
			
		||||
@@ -91,11 +94,10 @@ type proxyContext struct {
 | 
			
		||||
	idleTimeout    time.Duration
 | 
			
		||||
	ca             ca.CertificateAuthority
 | 
			
		||||
	storage        dataset.Storage
 | 
			
		||||
	client         dataset.Client
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// NewContext returns an initialized context for the provided [net.Conn].
 | 
			
		||||
func NewContext(c net.Conn) Context {
 | 
			
		||||
func NewContext(c net.Conn, storage dataset.Storage) Context {
 | 
			
		||||
	if c, ok := c.(*proxyContext); ok {
 | 
			
		||||
		return c
 | 
			
		||||
	}
 | 
			
		||||
@@ -106,12 +108,13 @@ func NewContext(c net.Conn) Context {
 | 
			
		||||
	cr := &countingReader{reader: c}
 | 
			
		||||
	cw := &countingWriter{writer: c}
 | 
			
		||||
	return &proxyContext{
 | 
			
		||||
		Conn: c,
 | 
			
		||||
		id:   binary.BigEndian.Uint64(b),
 | 
			
		||||
		cr:   cr,
 | 
			
		||||
		br:   bufio.NewReader(cr),
 | 
			
		||||
		cw:   cw,
 | 
			
		||||
		res:  &http.Response{StatusCode: 200},
 | 
			
		||||
		Conn:    c,
 | 
			
		||||
		id:      binary.BigEndian.Uint64(b),
 | 
			
		||||
		cr:      cr,
 | 
			
		||||
		br:      bufio.NewReader(cr),
 | 
			
		||||
		cw:      cw,
 | 
			
		||||
		res:     &http.Response{StatusCode: 200},
 | 
			
		||||
		storage: storage,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -128,7 +131,7 @@ func (c *proxyContext) AccessLogEntry() logger.Structured {
 | 
			
		||||
	return entry
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *proxyContext) LogEntry() logger.Structured {
 | 
			
		||||
func (c *proxyContext) Logger() logger.Structured {
 | 
			
		||||
	var id [8]byte
 | 
			
		||||
	binary.BigEndian.PutUint64(id[:], c.id)
 | 
			
		||||
	return ServerLog.Values(logger.Values{
 | 
			
		||||
@@ -234,24 +237,8 @@ func (c *proxyContext) CertificateAuthority() ca.CertificateAuthority {
 | 
			
		||||
	return c.ca
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *proxyContext) Client() (dataset.Client, error) {
 | 
			
		||||
	if c.storage == nil {
 | 
			
		||||
		return dataset.Client{}, dataset.ErrNotExist{Object: "client"}
 | 
			
		||||
	}
 | 
			
		||||
	if !c.client.CreatedAt.Equal(time.Time{}) {
 | 
			
		||||
		return c.client, nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var err error
 | 
			
		||||
	switch addr := c.Conn.RemoteAddr().(type) {
 | 
			
		||||
	case *net.TCPAddr:
 | 
			
		||||
		c.client, err = c.storage.ClientByIP(addr.IP)
 | 
			
		||||
	case *net.UDPAddr:
 | 
			
		||||
		c.client, err = c.storage.ClientByIP(addr.IP)
 | 
			
		||||
	default:
 | 
			
		||||
		err = dataset.ErrNotExist{Object: "client"}
 | 
			
		||||
	}
 | 
			
		||||
	return c.client, err
 | 
			
		||||
func (c *proxyContext) Storage() dataset.Storage {
 | 
			
		||||
	return c.storage
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var _ Context = (*proxyContext)(nil)
 | 
			
		||||
 
 | 
			
		||||
@@ -144,18 +144,21 @@ func Transparent(port int) ConnHandler {
 | 
			
		||||
			return nctx, nil
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		b := new(bytes.Buffer)
 | 
			
		||||
		hello, err := cryptutil.ReadClientHello(io.TeeReader(netutil.ReadOnlyConn{Reader: ctx.br}, b))
 | 
			
		||||
		var (
 | 
			
		||||
			b          = new(bytes.Buffer)
 | 
			
		||||
			hello, err = cryptutil.ReadClientHello(io.TeeReader(netutil.ReadOnlyConn{Reader: ctx.br}, b))
 | 
			
		||||
			log        = ctx.Logger()
 | 
			
		||||
		)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			if _, ok := err.(tls.RecordHeaderError); !ok {
 | 
			
		||||
				ctx.LogEntry().Err(err).Value("error_type", fmt.Sprintf("%T", err)).Warn("TLS sniff error")
 | 
			
		||||
				log.Err(err).Value("error_type", fmt.Sprintf("%T", err)).Warn("TLS sniff error")
 | 
			
		||||
				return nil, err
 | 
			
		||||
			}
 | 
			
		||||
			// Not a TLS connection, moving on to regular HTTP request handling...
 | 
			
		||||
			ctx.LogEntry().Debug("HTTP connection on transparent port")
 | 
			
		||||
			log.Debug("HTTP connection on transparent port")
 | 
			
		||||
			ctx.transparent = port
 | 
			
		||||
		} else {
 | 
			
		||||
			ctx.LogEntry().Value("target", hello.ServerName).Debug("TLS connection on transparent port")
 | 
			
		||||
			log.Value("target", hello.ServerName).Debug("TLS connection on transparent port")
 | 
			
		||||
			ctx.transparent = port
 | 
			
		||||
			ctx.transparentTLS = true
 | 
			
		||||
			ctx.serverName = hello.ServerName
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										178
									
								
								proxy/policy.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										178
									
								
								proxy/policy.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,178 @@
 | 
			
		||||
package proxy
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"errors"
 | 
			
		||||
	"net"
 | 
			
		||||
	"net/netip"
 | 
			
		||||
	"strings"
 | 
			
		||||
 | 
			
		||||
	"git.maze.io/maze/styx/dataset"
 | 
			
		||||
	"git.maze.io/maze/styx/internal/netutil"
 | 
			
		||||
	"github.com/open-policy-agent/opa/v1/ast"
 | 
			
		||||
	"github.com/open-policy-agent/opa/v1/rego"
 | 
			
		||||
	"github.com/open-policy-agent/opa/v1/types"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// PolicyQueryOptions generates the Rego query functions for the provided [Context].
 | 
			
		||||
func PolicyQueryOptions(ctx Context) (options []func(*rego.Rego)) {
 | 
			
		||||
	var (
 | 
			
		||||
		log     = ctx.Logger()
 | 
			
		||||
		storage = ctx.Storage()
 | 
			
		||||
	)
 | 
			
		||||
 | 
			
		||||
	addr, err := netip.ParseAddr(netutil.Host(ctx.RemoteAddr().String()))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		log.Err(err).Error("Error resolving remote address")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	client, err := storage.ClientByAddr(addr)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		log.Err(err).Warn("Error resolving client")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var (
 | 
			
		||||
		permitDomains  []*dataset.DomainTrie
 | 
			
		||||
		rejectDomains  []*dataset.DomainTrie
 | 
			
		||||
		permitNetworks []*dataset.NetworkTrie
 | 
			
		||||
		rejectNetworks []*dataset.NetworkTrie
 | 
			
		||||
	)
 | 
			
		||||
	for _, group := range client.Groups {
 | 
			
		||||
		lists, err := storage.ListsByGroup(group)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			log.Err(err).Warn("Error resolving lists")
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		for _, list := range lists {
 | 
			
		||||
			switch list.Type {
 | 
			
		||||
			case dataset.ListTypeDomain:
 | 
			
		||||
				trie, err := list.Domains()
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					log.Err(err).Warn("Error resolving domain trie")
 | 
			
		||||
				}
 | 
			
		||||
				if list.Permit {
 | 
			
		||||
					permitDomains = append(permitDomains, trie)
 | 
			
		||||
				} else {
 | 
			
		||||
					rejectDomains = append(rejectDomains, trie)
 | 
			
		||||
				}
 | 
			
		||||
			case dataset.ListTypeNetwork:
 | 
			
		||||
				trie, err := list.Networks()
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					log.Err(err).Warn("Error resolving domain trie")
 | 
			
		||||
				}
 | 
			
		||||
				if list.Permit {
 | 
			
		||||
					permitNetworks = append(permitNetworks, trie)
 | 
			
		||||
				} else {
 | 
			
		||||
					rejectNetworks = append(rejectNetworks, trie)
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	options = append(options,
 | 
			
		||||
		rego.Function1(®o.Function{
 | 
			
		||||
			Name:             "styx.reject_domain",
 | 
			
		||||
			Description:      "Check if the domain is to be rejected",
 | 
			
		||||
			Decl:             domainFunctionDecl,
 | 
			
		||||
			Nondeterministic: true,
 | 
			
		||||
			Memoize:          true,
 | 
			
		||||
		}, domainFunctionImpl(rejectDomains)),
 | 
			
		||||
		rego.Function1(®o.Function{
 | 
			
		||||
			Name:             "styx.permit_domain",
 | 
			
		||||
			Description:      "Check if the domain is to be permitted",
 | 
			
		||||
			Decl:             domainFunctionDecl,
 | 
			
		||||
			Nondeterministic: true,
 | 
			
		||||
			Memoize:          true,
 | 
			
		||||
		}, domainFunctionImpl(permitDomains)),
 | 
			
		||||
		rego.Function1(®o.Function{
 | 
			
		||||
			Name:             "styx.reject_network",
 | 
			
		||||
			Description:      "Check if the IP, IP:port, host or host:port is to be rejected",
 | 
			
		||||
			Decl:             networkFunctionDecl,
 | 
			
		||||
			Nondeterministic: true,
 | 
			
		||||
			Memoize:          true,
 | 
			
		||||
		}, networkFunctionImpl(rejectNetworks)),
 | 
			
		||||
		rego.Function1(®o.Function{
 | 
			
		||||
			Name:             "styx.permit_network",
 | 
			
		||||
			Description:      "Check if the IP, IP:port, host or host:port is to be permitted",
 | 
			
		||||
			Decl:             networkFunctionDecl,
 | 
			
		||||
			Nondeterministic: true,
 | 
			
		||||
			Memoize:          true,
 | 
			
		||||
		}, networkFunctionImpl(permitNetworks)),
 | 
			
		||||
	)
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var domainFunctionDecl = types.NewFunction(
 | 
			
		||||
	types.Args(types.Named("domain", types.S).Description("Domain to lookup")),
 | 
			
		||||
	types.Named("result", types.B).Description("`true` if domain matches"),
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func domainFunctionImpl(tries []*dataset.DomainTrie) rego.Builtin1 {
 | 
			
		||||
	return func(ctx rego.BuiltinContext, domainTerm *ast.Term) (*ast.Term, error) {
 | 
			
		||||
		domain, err := parseStringTerm(domainTerm)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
		for _, trie := range tries {
 | 
			
		||||
			if trie.Contains(domain) {
 | 
			
		||||
				return ast.NewTerm(ast.Boolean(true)), nil
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		return ast.NewTerm(ast.Boolean(false)), nil
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var networkFunctionDecl = types.NewFunction(
 | 
			
		||||
	types.Args(types.Named("ip", types.S).Description("IP, IP:port, host or host:port to lookup")),
 | 
			
		||||
	types.Named("result", types.B).Description("`true` if IP matches"),
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func networkFunctionImpl(tries []*dataset.NetworkTrie) rego.Builtin1 {
 | 
			
		||||
	return func(ctx rego.BuiltinContext, ipTerm *ast.Term) (*ast.Term, error) {
 | 
			
		||||
		ips, err := parseAddrTerm(ipTerm)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		for _, trie := range tries {
 | 
			
		||||
			for _, ip := range ips {
 | 
			
		||||
				if trie.Contains(ip) {
 | 
			
		||||
					return ast.NewTerm(ast.Boolean(true)), nil
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		return ast.NewTerm(ast.Boolean(false)), nil
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func parseAddrTerm(term *ast.Term) (addrs []netip.Addr, err error) {
 | 
			
		||||
	s, err := parseStringTerm(term)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	if addr, err := netip.ParseAddr(netutil.Host(s)); err == nil {
 | 
			
		||||
		// Input was "ip" or "ip:port"
 | 
			
		||||
		return []netip.Addr{addr}, nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	ips, err := net.LookupIP(netutil.Host(s))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, ip := range ips {
 | 
			
		||||
		if addr, ok := netip.AddrFromSlice(ip); ok {
 | 
			
		||||
			addrs = append(addrs, addr)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func parseStringTerm(term *ast.Term) (string, error) {
 | 
			
		||||
	value, ok := term.Value.(ast.String)
 | 
			
		||||
	if !ok {
 | 
			
		||||
		return "", errors.New("expected string argument")
 | 
			
		||||
	}
 | 
			
		||||
	return strings.Trim(value.String(), `"`), nil
 | 
			
		||||
}
 | 
			
		||||
@@ -302,14 +302,15 @@ func (p *Proxy) Serve(l net.Listener) error {
 | 
			
		||||
func (p *Proxy) handle(nc net.Conn) {
 | 
			
		||||
	var (
 | 
			
		||||
		start = time.Now()
 | 
			
		||||
		ctx   = NewContext(nc).(*proxyContext)
 | 
			
		||||
		ctx   = NewContext(nc, p.Storage).(*proxyContext)
 | 
			
		||||
		log   = ctx.Logger()
 | 
			
		||||
		err   error
 | 
			
		||||
	)
 | 
			
		||||
 | 
			
		||||
	defer func() {
 | 
			
		||||
		if r := recover(); r != nil {
 | 
			
		||||
			if err, ok := r.(error); ok {
 | 
			
		||||
				ctx.LogEntry().Err(err).Warn("Bug in code, recovered from panic!")
 | 
			
		||||
				log.Err(err).Warn("Bug in code, recovered from panic!")
 | 
			
		||||
			}
 | 
			
		||||
			_ = nc.Close()
 | 
			
		||||
		}
 | 
			
		||||
@@ -360,16 +361,6 @@ func (p *Proxy) handle(nc net.Conn) {
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	log := ctx.LogEntry()
 | 
			
		||||
	if p.Storage != nil {
 | 
			
		||||
		if client, err := p.Storage.ClientByIP(nc.RemoteAddr().(*net.TCPAddr).IP); err == nil {
 | 
			
		||||
			log = log.Values(logger.Values{
 | 
			
		||||
				"client_id":          client.ID,
 | 
			
		||||
				"client_network":     client.String(),
 | 
			
		||||
				"client_description": client.Description,
 | 
			
		||||
			})
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	for {
 | 
			
		||||
		if ctx.transparentTLS {
 | 
			
		||||
			ctx.req = &http.Request{
 | 
			
		||||
@@ -448,7 +439,9 @@ func (p *Proxy) handleError(ctx *proxyContext, err error, sendResponse bool) {
 | 
			
		||||
	if res == nil && sendResponse {
 | 
			
		||||
		res = NewErrorResponse(err, ctx.Request())
 | 
			
		||||
	}
 | 
			
		||||
	ctx.LogEntry().Value("count", len(p.OnError)).Trace("Running error handlers")
 | 
			
		||||
 | 
			
		||||
	log := ctx.Logger()
 | 
			
		||||
	log.Value("count", len(p.OnError)).Trace("Running error handlers")
 | 
			
		||||
	for _, f := range p.OnError {
 | 
			
		||||
		if newRes := f.HandleError(ctx, err); newRes != nil {
 | 
			
		||||
			res = newRes
 | 
			
		||||
@@ -464,7 +457,7 @@ func (p *Proxy) handleError(ctx *proxyContext, err error, sendResponse bool) {
 | 
			
		||||
func (p *Proxy) handleRequest(ctx *proxyContext) (err error) {
 | 
			
		||||
	switch {
 | 
			
		||||
	case ctx.req == nil:
 | 
			
		||||
		ctx.LogEntry().Warn("Request is nil in handleRequest!?")
 | 
			
		||||
		ctx.Logger().Warn("Request is nil in handleRequest!?")
 | 
			
		||||
		return errors.New("proxy: request is nil?")
 | 
			
		||||
 | 
			
		||||
	case headerContains(ctx.req.Header, HeaderConnection, "upgrade"):
 | 
			
		||||
@@ -527,7 +520,7 @@ func (p *Proxy) serve(ctx *proxyContext) (err error) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (p *Proxy) serveConnect(ctx *proxyContext) (err error) {
 | 
			
		||||
	log := ctx.LogEntry()
 | 
			
		||||
	log := ctx.Logger()
 | 
			
		||||
 | 
			
		||||
	// Most browsers expect to get a 200 OK after firing a HTTP CONNECT request; if the upstream
 | 
			
		||||
	// encounters any errors, we'll inform the client after reading the HTTP request that follows.
 | 
			
		||||
@@ -571,13 +564,13 @@ func (p *Proxy) serveConnect(ctx *proxyContext) (err error) {
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	ctx.res = NewResponse(http.StatusOK, nil, ctx.req)
 | 
			
		||||
	srv := NewContext(c).(*proxyContext)
 | 
			
		||||
	srv := NewContext(c, p.Storage).(*proxyContext)
 | 
			
		||||
	srv.SetIdleTimeout(p.IdleTimeout)
 | 
			
		||||
	return p.multiplex(ctx, srv)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (p *Proxy) serveForward(ctx *proxyContext) (err error) {
 | 
			
		||||
	log := ctx.LogEntry()
 | 
			
		||||
	log := ctx.Logger()
 | 
			
		||||
	log.Value("target", ctx.req.URL.String()).Debugf("%s forward request", ctx.req.Proto)
 | 
			
		||||
 | 
			
		||||
	var res *http.Response
 | 
			
		||||
@@ -609,7 +602,8 @@ func (p *Proxy) serveForward(ctx *proxyContext) (err error) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (p *Proxy) serveWebSocket(ctx *proxyContext) (err error) {
 | 
			
		||||
	log := ctx.LogEntry().Value("target", ctx.req.URL.String())
 | 
			
		||||
	log := ctx.Logger()
 | 
			
		||||
	log.Value("target", ctx.req.URL.String())
 | 
			
		||||
 | 
			
		||||
	switch ctx.req.URL.Scheme {
 | 
			
		||||
	case "http":
 | 
			
		||||
@@ -632,7 +626,7 @@ func (p *Proxy) serveWebSocket(ctx *proxyContext) (err error) {
 | 
			
		||||
	}
 | 
			
		||||
	cancel()
 | 
			
		||||
 | 
			
		||||
	srv := NewContext(c).(*proxyContext)
 | 
			
		||||
	srv := NewContext(c, p.Storage).(*proxyContext)
 | 
			
		||||
	srv.SetIdleTimeout(p.IdleTimeout)
 | 
			
		||||
	if err = ctx.req.Write(srv); err != nil {
 | 
			
		||||
		ctx.res = NewErrorResponse(err, ctx.req)
 | 
			
		||||
@@ -662,7 +656,7 @@ func (p *Proxy) serveWebSocket(ctx *proxyContext) (err error) {
 | 
			
		||||
 | 
			
		||||
func (p *Proxy) multiplex(ctx, srv *proxyContext) (err error) {
 | 
			
		||||
	var (
 | 
			
		||||
		log  = ctx.LogEntry().Value("server", srv.RemoteAddr().String())
 | 
			
		||||
		log  = ctx.Logger().Value("server", srv.RemoteAddr().String())
 | 
			
		||||
		errs = make(chan error, 1)
 | 
			
		||||
		done = make(chan struct{}, 1)
 | 
			
		||||
	)
 | 
			
		||||
 
 | 
			
		||||
@@ -1,19 +1,14 @@
 | 
			
		||||
package proxy
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"expvar"
 | 
			
		||||
	"strconv"
 | 
			
		||||
 | 
			
		||||
	"git.maze.io/maze/styx/db/stats"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func countStatus(code int) {
 | 
			
		||||
	k := "http:status:" + strconv.Itoa(code)
 | 
			
		||||
	v := expvar.Get(k)
 | 
			
		||||
	if v == nil {
 | 
			
		||||
		//v = stats.NewCounter("120s1s", "15m10s", "1h1m", "4w1d", "1y4w")
 | 
			
		||||
		v = stats.NewCounter(k, stats.Minutely, stats.Hourly, stats.Daily, stats.Yearly)
 | 
			
		||||
		expvar.Publish(k, v)
 | 
			
		||||
	}
 | 
			
		||||
	v.(stats.Metric).Add(1)
 | 
			
		||||
	/*
 | 
			
		||||
		k := "http:status:" + strconv.Itoa(code)
 | 
			
		||||
		v := expvar.Get(k)
 | 
			
		||||
		if v == nil {
 | 
			
		||||
			//v = stats.NewCounter("120s1s", "15m10s", "1h1m", "4w1d", "1y4w")
 | 
			
		||||
			v = stats.NewCounter(k, stats.Minutely, stats.Hourly, stats.Daily, stats.Yearly)
 | 
			
		||||
			expvar.Publish(k, v)
 | 
			
		||||
		}
 | 
			
		||||
		v.(stats.Metric).Add(1)
 | 
			
		||||
	*/
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user