Initial import

This commit is contained in:
2025-09-26 08:49:53 +02:00
commit a76650da35
35 changed files with 4660 additions and 0 deletions

35
internal/netutil/addr.go Normal file
View File

@@ -0,0 +1,35 @@
package netutil
import (
"net"
"strconv"
)
// EnsurePort makes sure the address in [host] contains a port.
func EnsurePort(host, port string) string {
if _, _, err := net.SplitHostPort(host); err == nil {
return host
}
return net.JoinHostPort(host, port)
}
// Host returns the bare host (without port).
func Host(name string) string {
host, _, err := net.SplitHostPort(name)
if err == nil {
return host
}
return name
}
// Port returns the port number.
func Port(name string) int {
_, port, err := net.SplitHostPort(name)
if err != nil {
return 0
}
// TODO: name resolution for ports?
i, _ := strconv.Atoi(port)
return i
}

View File

@@ -0,0 +1,99 @@
package netutil
import (
"strings"
"github.com/miekg/dns"
)
type DomainTree struct {
root *domainTreeNode
}
type domainTreeNode struct {
leaf map[string]*domainTreeNode
isEnd bool
}
func NewDomainList(domains ...string) *DomainTree {
tree := &DomainTree{
root: &domainTreeNode{leaf: make(map[string]*domainTreeNode)},
}
for _, domain := range domains {
tree.Add(domain)
}
return tree
}
func (tree *DomainTree) Add(domain string) {
domain = normalizeDomain(domain)
if domain == "" {
return
}
labels := dns.SplitDomainName(domain)
if len(labels) == 0 {
return
}
node := tree.root
for i := len(labels) - 1; i >= 0; i-- {
label := labels[i]
if label == "" {
continue
}
if node.leaf == nil {
node.leaf = make(map[string]*domainTreeNode)
}
if node.leaf[label] == nil {
node.leaf[label] = &domainTreeNode{}
}
node = node.leaf[label]
}
node.isEnd = true
}
func (tree *DomainTree) Contains(domain string) bool {
domain = normalizeDomain(domain)
if domain == "" {
return false
}
labels := dns.SplitDomainName(domain)
if len(labels) == 0 {
return false
}
node := tree.root
for i := len(labels) - 1; i >= 0; i-- {
if node.isEnd {
return true
}
if node.leaf == nil {
return false
}
label := labels[i]
if node = node.leaf[label]; node == nil {
return false
}
}
return node.isEnd
}
func normalizeDomain(domain string) string {
domain = strings.ToLower(strings.TrimSpace(domain))
if domain == "" {
return ""
}
// Remove trailing dot if present, dns.Fqdn will add it back properly
domain = strings.TrimSuffix(domain, ".")
if domain == "" {
return ""
}
return dns.Fqdn(domain)
}

View File

@@ -0,0 +1,276 @@
package netutil
import (
"testing"
)
func TestDomainList(t *testing.T) {
tests := []struct {
name string
domains []string
hostname string
expected bool
}{
// Basic exact matches
{
name: "exact match",
domains: []string{"example.com"},
hostname: "example.com",
expected: true,
},
{
name: "exact match with subdomain in list",
domains: []string{"api.example.com"},
hostname: "api.example.com",
expected: true,
},
// Suffix matching - if domain is in list, all subdomains should match
{
name: "subdomain matches parent domain",
domains: []string{"example.com"},
hostname: "sub.example.com",
expected: true,
},
{
name: "multiple subdomain levels match",
domains: []string{"example.com"},
hostname: "deep.nested.sub.example.com",
expected: true,
},
{
name: "subdomain matches intermediate domain",
domains: []string{"api.example.com", "example.com"},
hostname: "sub.api.example.com",
expected: true,
},
// Multi-level TLDs
{
name: "co.uk domain exact match",
domains: []string{"domain.co.uk"},
hostname: "domain.co.uk",
expected: true,
},
{
name: "subdomain of co.uk domain",
domains: []string{"domain.co.uk"},
hostname: "sub.domain.co.uk",
expected: true,
},
// Case sensitivity
{
name: "case insensitive match",
domains: []string{"Example.COM"},
hostname: "example.com",
expected: true,
},
{
name: "case insensitive hostname",
domains: []string{"example.com"},
hostname: "EXAMPLE.COM",
expected: true,
},
// Trailing dots
{
name: "domain with trailing dot",
domains: []string{"example.com."},
hostname: "example.com",
expected: true,
},
{
name: "hostname with trailing dot",
domains: []string{"example.com"},
hostname: "example.com.",
expected: true,
},
// Non-matches
{
name: "different TLD",
domains: []string{"example.com"},
hostname: "example.org",
expected: false,
},
{
name: "different domain",
domains: []string{"example.com"},
hostname: "test.com",
expected: false,
},
{
name: "partial match but not suffix",
domains: []string{"example.com"},
hostname: "com",
expected: false,
},
{
name: "empty hostname",
domains: []string{"example.com"},
hostname: "",
expected: false,
},
// Multiple domains in list
{
name: "matches first domain in list",
domains: []string{"test.org", "example.com"},
hostname: "example.com",
expected: true,
},
{
name: "matches second domain in list",
domains: []string{"test.org", "example.com"},
hostname: "test.org",
expected: true,
},
{
name: "subdomain matches any domain in list",
domains: []string{"test.org", "example.com"},
hostname: "sub.example.com",
expected: true,
},
// Edge cases
{
name: "empty domain list",
domains: []string{},
hostname: "example.com",
expected: false,
},
{
name: "invalid domain in list",
domains: []string{""},
hostname: "example.com",
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
list := NewDomainList(tt.domains...)
result := list.Contains(tt.hostname)
if result != tt.expected {
t.Errorf("Contains(%q) = %v, expected %v (domains: %v)",
tt.hostname, result, tt.expected, tt.domains)
}
})
}
}
func TestDomainList_Performance(t *testing.T) {
// Test with a large number of domains to ensure performance
domains := make([]string, 1000)
for i := 0; i < 1000; i++ {
domains[i] = string(rune('a'+(i%26))) + ".com"
}
domains = append(domains, "example.com") // Add our test domain
list := NewDomainList(domains...)
// These should be fast even with many domains
if !list.Contains("example.com") {
t.Error("Should match exact domain")
}
if !list.Contains("sub.example.com") {
t.Error("Should match subdomain")
}
if list.Contains("notfound.com") {
t.Error("Should not match unrelated domain")
}
}
func TestDomainList_ComplexDomains(t *testing.T) {
domains := []string{
"very.long.domain.name.with.many.labels.com",
"example.co.uk",
"sub.domain.example.com",
"a.b.c.d.e.f.com",
}
list := NewDomainList(domains...)
tests := []struct {
hostname string
expected bool
}{
{"very.long.domain.name.with.many.labels.com", true},
{"sub.very.long.domain.name.with.many.labels.com", true},
{"example.co.uk", true},
{"www.example.co.uk", true},
{"sub.domain.example.com", true},
{"another.sub.domain.example.com", true},
{"a.b.c.d.e.f.com", true},
{"x.a.b.c.d.e.f.com", true},
{"not.matching.com", false},
{"com", false},
{"uk", false},
}
for _, tt := range tests {
t.Run(tt.hostname, func(t *testing.T) {
result := list.Contains(tt.hostname)
if result != tt.expected {
t.Errorf("Contains(%q) = %v, expected %v", tt.hostname, result, tt.expected)
}
})
}
}
func TestDomainList_SpecialCases(t *testing.T) {
t.Run("domain with asterisk treated literally", func(t *testing.T) {
list := NewDomainList("*.example.com")
// The asterisk should be treated as a literal label, not a wildcard
if !list.Contains("*.example.com") {
t.Error("Asterisk should be treated literally, not as wildcard")
}
if list.Contains("test.example.com") {
t.Error("Should not match subdomain with literal asterisk domain")
}
})
t.Run("domains with hyphens and numbers", func(t *testing.T) {
list := NewDomainList("test-123.example.com", "123abc.org")
if !list.Contains("test-123.example.com") {
t.Error("Should match domain with hyphens and numbers")
}
if !list.Contains("sub.test-123.example.com") {
t.Error("Should match subdomain of hyphenated domain")
}
if !list.Contains("123abc.org") {
t.Error("Should match domain starting with numbers")
}
if !list.Contains("www.123abc.org") {
t.Error("Should match subdomain of numeric domain")
}
})
}
func BenchmarkDomainList(b *testing.B) {
// Benchmark with realistic domain list
domains := []string{
"google.com",
"github.com",
"example.org",
"sub.domain.com",
"api.service.co.uk",
"very.long.domain.name.example.com",
}
list := NewDomainList(domains...)
b.ResetTimer()
for b.Loop() {
// Mix of matches and non-matches
list.Contains("sub.example.org")
list.Contains("api.github.com")
list.Contains("nonexistent.com")
list.Contains("deep.nested.sub.domain.com")
list.Contains("service.co.uk")
}
}

View File

@@ -0,0 +1,44 @@
package netutil
import (
"net"
"github.com/yl2chen/cidranger"
)
type NetworkTree struct {
ranger cidranger.Ranger
}
func NewNetworkTree(networks ...string) (*NetworkTree, error) {
tree := &NetworkTree{
ranger: cidranger.NewPCTrieRanger(),
}
for _, cidr := range networks {
if err := tree.AddCIDR(cidr); err != nil {
return nil, err
}
}
return tree, nil
}
func (tree *NetworkTree) Add(ipnet *net.IPNet) {
if ipnet == nil {
return
}
tree.ranger.Insert(cidranger.NewBasicRangerEntry(*ipnet))
}
func (tree *NetworkTree) AddCIDR(cidr string) error {
_, ipnet, err := net.ParseCIDR(cidr)
if err != nil {
return err
}
tree.ranger.Insert(cidranger.NewBasicRangerEntry(*ipnet))
return nil
}
func (tree *NetworkTree) Contains(ip net.IP) bool {
contains, _ := tree.ranger.Contains(ip)
return contains
}

View File

@@ -0,0 +1,410 @@
package netutil
import (
"net"
"testing"
)
func TestNewNetworkTree(t *testing.T) {
// Test empty creation
nl, err := NewNetworkTree()
if err != nil {
t.Fatalf("NewNetworkTree() failed: %v", err)
}
if nl == nil {
t.Fatal("NewNetworkTree() returned nil")
}
if nl.ranger == nil {
t.Error("NetworkTree ranger should not be nil")
}
// Test creation with networks
nl, err = NewNetworkTree("192.168.1.0/24", "10.0.0.0/8")
if err != nil {
t.Fatalf("NewNetworkTree() with networks failed: %v", err)
}
if nl == nil {
t.Fatal("NewNetworkTree() with networks returned nil")
}
}
func TestNewNetworkTree_InvalidNetworks(t *testing.T) {
// Test with invalid network
_, err := NewNetworkTree("invalid-cidr")
if err == nil {
t.Error("NewNetworkTree() with invalid CIDR should have failed")
}
// Test with mix of valid and invalid networks
_, err = NewNetworkTree("192.168.1.0/24", "invalid-cidr", "10.0.0.0/8")
if err == nil {
t.Error("NewNetworkTree() with mixed valid/invalid CIDRs should have failed")
}
}
func TestNetworkTree_AddCIDR_Valid(t *testing.T) {
nl, err := NewNetworkTree()
if err != nil {
t.Fatalf("NewNetworkTree() failed: %v", err)
}
tests := []struct {
cidr string
desc string
}{
{"192.168.1.0/24", "IPv4 CIDR"},
{"10.0.0.0/8", "IPv4 large range"},
{"2001:db8::/32", "IPv6 CIDR"},
{"::1/128", "IPv6 localhost"},
{"0.0.0.0/0", "IPv4 entire internet"},
{"::/0", "IPv6 entire internet"},
}
for _, tt := range tests {
t.Run(tt.desc, func(t *testing.T) {
if err := nl.AddCIDR(tt.cidr); err != nil {
t.Errorf("AddCIDR(%q) failed: %v", tt.cidr, err)
}
})
}
}
func TestNetworkTree_AddCIDR_Invalid(t *testing.T) {
nl, err := NewNetworkTree()
if err != nil {
t.Fatalf("NewNetworkTree() failed: %v", err)
}
invalidCIDRs := []string{
"invalid-cidr",
"192.168.1.1", // missing mask
"192.168.1.0/33", // invalid mask for IPv4
"2001:db8::/129", // invalid mask for IPv6
"",
"not-an-ip/24",
}
for _, cidr := range invalidCIDRs {
t.Run(cidr, func(t *testing.T) {
if err := nl.AddCIDR(cidr); err == nil {
t.Errorf("AddCIDR(%q) should have failed but didn't", cidr)
}
})
}
}
func TestNetworkTree_Add(t *testing.T) {
nl, err := NewNetworkTree()
if err != nil {
t.Fatalf("NewNetworkTree() failed: %v", err)
}
tests := []struct {
cidr string
desc string
}{
{"192.168.1.0/24", "IPv4 network"},
{"2001:db8::/32", "IPv6 network"},
}
for _, tt := range tests {
t.Run(tt.desc, func(t *testing.T) {
_, ipNet, err := net.ParseCIDR(tt.cidr)
if err != nil {
t.Fatalf("ParseCIDR failed: %v", err)
}
// Should not panic
nl.Add(ipNet)
})
}
}
func TestNetworkTree_Contains_IPv4(t *testing.T) {
nl, err := NewNetworkTree("192.168.1.0/24", "10.0.0.0/8", "172.16.0.0/12")
if err != nil {
t.Fatalf("NewNetworkTree() failed: %v", err)
}
tests := []struct {
ip string
want bool
desc string
}{
// IPs that should match
{"192.168.1.1", true, "in 192.168.1.0/24"},
{"192.168.1.255", true, "broadcast in 192.168.1.0/24"},
{"10.0.0.1", true, "in 10.0.0.0/8"},
{"10.255.255.255", true, "max in 10.0.0.0/8"},
{"172.16.0.1", true, "in 172.16.0.0/12"},
{"172.31.255.255", true, "max in 172.16.0.0/12"},
// IPs that should not match
{"192.168.2.1", false, "outside 192.168.1.0/24"},
{"11.0.0.1", false, "outside 10.0.0.0/8"},
{"172.32.0.1", false, "outside 172.16.0.0/12"},
{"8.8.8.8", false, "public DNS"},
}
for _, tt := range tests {
t.Run(tt.desc, func(t *testing.T) {
ip := net.ParseIP(tt.ip)
if ip == nil {
t.Fatalf("ParseIP(%q) returned nil", tt.ip)
}
got := nl.Contains(ip)
if got != tt.want {
t.Errorf("Contains(%q) = %v, want %v", tt.ip, got, tt.want)
}
})
}
}
func TestNetworkTree_Contains_IPv6(t *testing.T) {
nl, err := NewNetworkTree("2001:db8::/32", "2001:db8:abcd::/48", "::1/128")
if err != nil {
t.Fatalf("NewNetworkTree() failed: %v", err)
}
tests := []struct {
ip string
want bool
desc string
}{
// IPs that should match
{"2001:db8::1", true, "in 2001:db8::/32"},
{"2001:db8:ffff:ffff:ffff:ffff:ffff:ffff", true, "max in 2001:db8::/32"},
{"2001:db8:abcd::1", true, "in 2001:db8:abcd::/48"},
{"::1", true, "localhost"},
// IPs that should not match
{"2001:db9::1", false, "outside 2001:db8::/32"},
{"2001:db9:abcd::1", false, "outside 2001:db8:abcd::/48"},
{"::2", false, "outside ::1/128"},
{"2001:4860:4860::8888", false, "public DNS"},
}
for _, tt := range tests {
t.Run(tt.desc, func(t *testing.T) {
ip := net.ParseIP(tt.ip)
if ip == nil {
t.Fatalf("ParseIP(%q) returned nil", tt.ip)
}
got := nl.Contains(ip)
if got != tt.want {
t.Errorf("Contains(%q) = %v, want %v", tt.ip, got, tt.want)
}
})
}
}
func TestNetworkTree_Contains_EdgeCases(t *testing.T) {
nl, err := NewNetworkTree()
if err != nil {
t.Fatalf("NewNetworkTree() failed: %v", err)
}
// Test with nil IP
if nl.Contains(nil) != false {
t.Error("Contains(nil) should return false")
}
// Test with empty list
ip := net.ParseIP("192.168.1.1")
if nl.Contains(ip) != false {
t.Error("Contains() on empty list should return false")
}
}
func TestNetworkTree_Contains_OverlappingRanges(t *testing.T) {
nl, err := NewNetworkTree("192.168.0.0/16", "192.168.1.0/24", "192.168.1.128/25")
if err != nil {
t.Fatalf("NewNetworkTree() failed: %v", err)
}
// All these should match because we have overlapping ranges
tests := []string{
"192.168.1.1",
"192.168.1.129",
"192.168.2.1",
}
for _, ipStr := range tests {
t.Run(ipStr, func(t *testing.T) {
ip := net.ParseIP(ipStr)
if !nl.Contains(ip) {
t.Errorf("Contains(%q) should return true for overlapping ranges", ipStr)
}
})
}
}
func TestNetworkTree_Contains_EntireInternet(t *testing.T) {
nl, err := NewNetworkTree("0.0.0.0/0", "::/0")
if err != nil {
t.Fatalf("NewNetworkTree() failed: %v", err)
}
tests := []struct {
ip string
desc string
}{
{"192.168.1.1", "IPv4 private"},
{"8.8.8.8", "IPv4 public"},
{"2001:db8::1", "IPv6"},
{"::1", "IPv6 localhost"},
}
for _, tt := range tests {
t.Run(tt.desc, func(t *testing.T) {
ip := net.ParseIP(tt.ip)
if !nl.Contains(ip) {
t.Errorf("Contains(%q) should return true for entire internet range", tt.ip)
}
})
}
}
func TestNetworkTree_MixedIPv4AndIPv6(t *testing.T) {
nl, err := NewNetworkTree("192.168.1.0/24", "2001:db8::/32")
if err != nil {
t.Fatalf("NewNetworkTree() failed: %v", err)
}
// Test IPv4 in IPv6 format (should still work due to normalization)
ipv4InIPv6 := net.ParseIP("::ffff:192.168.1.1") // IPv4-mapped IPv6
if !nl.Contains(ipv4InIPv6) {
t.Error("Contains() should handle IPv4-mapped IPv6 addresses")
}
// Regular IPv4 should work
ipv4 := net.ParseIP("192.168.1.1")
if !nl.Contains(ipv4) {
t.Error("Contains() should handle regular IPv4 addresses")
}
// IPv6 should work
ipv6 := net.ParseIP("2001:db8::1")
if !nl.Contains(ipv6) {
t.Error("Contains() should handle IPv6 addresses")
}
}
func TestNetworkTree_Add_InvalidIPNet(t *testing.T) {
nl, err := NewNetworkTree()
if err != nil {
t.Fatalf("NewNetworkTree() failed: %v", err)
}
// Create an invalid IPNet (nil IP)
invalidIPNet := &net.IPNet{
IP: nil,
Mask: net.CIDRMask(24, 32),
}
// This should not panic
nl.Add(invalidIPNet)
// Verify that it doesn't affect Contains results
ip := net.ParseIP("192.168.1.1")
if nl.Contains(ip) {
t.Error("Contains() should return false after adding invalid IPNet")
}
}
func TestNetworkTree_InitializationWithNetworks(t *testing.T) {
networks := []string{
"10.0.0.0/8",
"172.16.0.0/12",
"192.168.0.0/16",
"2001:db8::/32",
}
nl, err := NewNetworkTree(networks...)
if err != nil {
t.Fatalf("NewNetworkTree() with multiple networks failed: %v", err)
}
// Test that all networks were added correctly
testCases := []struct {
ip string
want bool
}{
{"10.1.2.3", true},
{"172.16.1.1", true},
{"192.168.1.1", true},
{"2001:db8::1", true},
{"8.8.8.8", false},
}
for _, tc := range testCases {
ip := net.ParseIP(tc.ip)
if got := nl.Contains(ip); got != tc.want {
t.Errorf("Contains(%q) = %v, want %v", tc.ip, got, tc.want)
}
}
}
func BenchmarkNetworkTree_Contains(b *testing.B) {
nl, err := NewNetworkTree(
"10.0.0.0/8",
"172.16.0.0/12",
"192.168.0.0/16",
"2001:db8::/32",
)
if err != nil {
b.Fatalf("NewNetworkTree() failed: %v", err)
}
testIPs := []net.IP{
net.ParseIP("10.1.2.3"),
net.ParseIP("192.168.1.1"),
net.ParseIP("2001:db8::1"),
net.ParseIP("8.8.8.8"),
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
ip := testIPs[i%len(testIPs)]
nl.Contains(ip)
}
}
func BenchmarkNetworkTree_NewNetworkTree(b *testing.B) {
cidrs := []string{
"10.0.0.0/8",
"172.16.0.0/12",
"192.168.0.0/16",
"2001:db8::/32",
}
b.ResetTimer()
for b.Loop() {
_, err := NewNetworkTree(cidrs...)
if err != nil {
b.Fatalf("NewNetworkTree() failed: %v", err)
}
}
}
func BenchmarkNetworkTree_AddCIDR(b *testing.B) {
cidrs := []string{
"10.0.0.0/8",
"172.16.0.0/12",
"192.168.0.0/16",
"2001:db8::/32",
}
b.ResetTimer()
for b.Loop() {
nl, err := NewNetworkTree()
if err != nil {
b.Fatalf("NewNetworkTree() failed: %v", err)
}
for _, cidr := range cidrs {
nl.AddCIDR(cidr)
}
}
}