package netutil

import (
	"net"
	"testing"
)

// TestNewNetworkTree checks that NewNetworkTree builds a usable tree both with
// no initial networks and with a list of initial CIDRs.
func TestNewNetworkTree(t *testing.T) {
	// Empty construction.
	tree, err := NewNetworkTree()
	if err != nil {
		t.Fatalf("NewNetworkTree() failed: %v", err)
	}
	if tree == nil {
		t.Fatal("NewNetworkTree() returned nil")
	}
	if tree.ranger == nil {
		t.Error("NetworkTree ranger should not be nil")
	}

	// Construction with initial networks.
	tree, err = NewNetworkTree("192.168.1.0/24", "10.0.0.0/8")
	if err != nil {
		t.Fatalf("NewNetworkTree() with networks failed: %v", err)
	}
	if tree == nil {
		t.Fatal("NewNetworkTree() with networks returned nil")
	}
	// A non-nil tree is not enough: the initial networks must actually be loaded.
	for _, s := range []string{"192.168.1.1", "10.1.2.3"} {
		if !tree.Contains(net.ParseIP(s)) {
			t.Errorf("Contains(%q) = false after NewNetworkTree with a covering network, want true", s)
		}
	}
}

// TestNewNetworkTree_InvalidNetworks checks that construction fails when any
// of the supplied CIDRs is malformed, even if others are valid.
func TestNewNetworkTree_InvalidNetworks(t *testing.T) {
	// A single invalid network.
	_, err := NewNetworkTree("invalid-cidr")
	if err == nil {
		t.Error("NewNetworkTree() with invalid CIDR should have failed")
	}

	// An invalid network sandwiched between valid ones.
	_, err = NewNetworkTree("192.168.1.0/24", "invalid-cidr", "10.0.0.0/8")
	if err == nil {
		t.Error("NewNetworkTree() with mixed valid/invalid CIDRs should have failed")
	}
}

// TestNetworkTree_AddCIDR_Valid checks that AddCIDR accepts well-formed IPv4
// and IPv6 CIDRs, including the /0 "match everything" ranges.
func TestNetworkTree_AddCIDR_Valid(t *testing.T) {
	tree, err := NewNetworkTree()
	if err != nil {
		t.Fatalf("NewNetworkTree() failed: %v", err)
	}

	tests := []struct {
		cidr string
		desc string
	}{
		{"192.168.1.0/24", "IPv4 CIDR"},
		{"10.0.0.0/8", "IPv4 large range"},
		{"2001:db8::/32", "IPv6 CIDR"},
		{"::1/128", "IPv6 localhost"},
		{"0.0.0.0/0", "IPv4 entire internet"},
		{"::/0", "IPv6 entire internet"},
	}

	for _, tt := range tests {
		t.Run(tt.desc, func(t *testing.T) {
			if err := tree.AddCIDR(tt.cidr); err != nil {
				t.Errorf("AddCIDR(%q) failed: %v", tt.cidr, err)
			}
		})
	}
}

// TestNetworkTree_AddCIDR_Invalid checks that AddCIDR rejects malformed input:
// bare addresses, out-of-range prefix lengths, empty strings and non-IPs.
func TestNetworkTree_AddCIDR_Invalid(t *testing.T) {
	tree, err := NewNetworkTree()
	if err != nil {
		t.Fatalf("NewNetworkTree() failed: %v", err)
	}

	invalidCIDRs := []string{
		"invalid-cidr",
		"192.168.1.1",    // missing mask
		"192.168.1.0/33", // invalid mask for IPv4
		"2001:db8::/129", // invalid mask for IPv6
		"",
		"not-an-ip/24",
	}

	for _, cidr := range invalidCIDRs {
		t.Run(cidr, func(t *testing.T) {
			if err := tree.AddCIDR(cidr); err == nil {
				t.Errorf("AddCIDR(%q) should have failed but didn't", cidr)
			}
		})
	}
}

// TestNetworkTree_Add checks that adding a parsed *net.IPNet neither panics nor
// is silently dropped: an address inside the network must match afterwards.
func TestNetworkTree_Add(t *testing.T) {
	tree, err := NewNetworkTree()
	if err != nil {
		t.Fatalf("NewNetworkTree() failed: %v", err)
	}

	tests := []struct {
		cidr string
		ip   string // an address inside cidr, used to confirm the network was stored
		desc string
	}{
		{"192.168.1.0/24", "192.168.1.10", "IPv4 network"},
		{"2001:db8::/32", "2001:db8::10", "IPv6 network"},
	}

	for _, tt := range tests {
		t.Run(tt.desc, func(t *testing.T) {
			_, ipNet, err := net.ParseCIDR(tt.cidr)
			if err != nil {
				t.Fatalf("ParseCIDR failed: %v", err)
			}

			// Should not panic.
			tree.Add(ipNet)

			ip := net.ParseIP(tt.ip)
			if ip == nil {
				t.Fatalf("ParseIP(%q) returned nil", tt.ip)
			}
			if !tree.Contains(ip) {
				t.Errorf("Contains(%q) = false after Add(%s), want true", tt.ip, tt.cidr)
			}
		})
	}
}

// TestNetworkTree_Contains_IPv4 checks IPv4 membership at network edges
// (first/last/broadcast addresses) and just outside each network.
func TestNetworkTree_Contains_IPv4(t *testing.T) {
	tree, err := NewNetworkTree("192.168.1.0/24", "10.0.0.0/8", "172.16.0.0/12")
	if err != nil {
		t.Fatalf("NewNetworkTree() failed: %v", err)
	}

	tests := []struct {
		ip   string
		want bool
		desc string
	}{
		// IPs that should match
		{"192.168.1.1", true, "in 192.168.1.0/24"},
		{"192.168.1.255", true, "broadcast in 192.168.1.0/24"},
		{"10.0.0.1", true, "in 10.0.0.0/8"},
		{"10.255.255.255", true, "max in 10.0.0.0/8"},
		{"172.16.0.1", true, "in 172.16.0.0/12"},
		{"172.31.255.255", true, "max in 172.16.0.0/12"},

		// IPs that should not match
		{"192.168.2.1", false, "outside 192.168.1.0/24"},
		{"11.0.0.1", false, "outside 10.0.0.0/8"},
		{"172.32.0.1", false, "outside 172.16.0.0/12"},
		{"8.8.8.8", false, "public DNS"},
	}

	for _, tt := range tests {
		t.Run(tt.desc, func(t *testing.T) {
			ip := net.ParseIP(tt.ip)
			if ip == nil {
				t.Fatalf("ParseIP(%q) returned nil", tt.ip)
			}
			if got := tree.Contains(ip); got != tt.want {
				t.Errorf("Contains(%q) = %v, want %v", tt.ip, got, tt.want)
			}
		})
	}
}

// TestNetworkTree_Contains_IPv6 checks IPv6 membership, including a /128 host
// route and addresses that differ from a stored prefix by a single bit group.
func TestNetworkTree_Contains_IPv6(t *testing.T) {
	tree, err := NewNetworkTree("2001:db8::/32", "2001:db8:abcd::/48", "::1/128")
	if err != nil {
		t.Fatalf("NewNetworkTree() failed: %v", err)
	}

	tests := []struct {
		ip   string
		want bool
		desc string
	}{
		// IPs that should match
		{"2001:db8::1", true, "in 2001:db8::/32"},
		{"2001:db8:ffff:ffff:ffff:ffff:ffff:ffff", true, "max in 2001:db8::/32"},
		{"2001:db8:abcd::1", true, "in 2001:db8:abcd::/48"},
		{"::1", true, "localhost"},

		// IPs that should not match
		{"2001:db9::1", false, "outside 2001:db8::/32"},
		{"2001:db9:abcd::1", false, "outside 2001:db8:abcd::/48"},
		{"::2", false, "outside ::1/128"},
		{"2001:4860:4860::8888", false, "public DNS"},
	}

	for _, tt := range tests {
		t.Run(tt.desc, func(t *testing.T) {
			ip := net.ParseIP(tt.ip)
			if ip == nil {
				t.Fatalf("ParseIP(%q) returned nil", tt.ip)
			}
			if got := tree.Contains(ip); got != tt.want {
				t.Errorf("Contains(%q) = %v, want %v", tt.ip, got, tt.want)
			}
		})
	}
}

// TestNetworkTree_Contains_EdgeCases checks that a nil IP and an empty tree
// both yield "not contained".
func TestNetworkTree_Contains_EdgeCases(t *testing.T) {
	tree, err := NewNetworkTree()
	if err != nil {
		t.Fatalf("NewNetworkTree() failed: %v", err)
	}

	if tree.Contains(nil) {
		t.Error("Contains(nil) should return false")
	}

	// An empty tree matches nothing.
	if tree.Contains(net.ParseIP("192.168.1.1")) {
		t.Error("Contains() on empty list should return false")
	}
}

// TestNetworkTree_Contains_OverlappingRanges checks that nested networks do not
// shadow each other: an address matches if any stored network covers it.
func TestNetworkTree_Contains_OverlappingRanges(t *testing.T) {
	tree, err := NewNetworkTree("192.168.0.0/16", "192.168.1.0/24", "192.168.1.128/25")
	if err != nil {
		t.Fatalf("NewNetworkTree() failed: %v", err)
	}

	// All these should match because we have overlapping ranges.
	for _, ipStr := range []string{"192.168.1.1", "192.168.1.129", "192.168.2.1"} {
		t.Run(ipStr, func(t *testing.T) {
			ip := net.ParseIP(ipStr)
			if ip == nil {
				t.Fatalf("ParseIP(%q) returned nil", ipStr)
			}
			if !tree.Contains(ip) {
				t.Errorf("Contains(%q) should return true for overlapping ranges", ipStr)
			}
		})
	}
}

// TestNetworkTree_Contains_EntireInternet checks that the /0 networks for both
// address families match arbitrary addresses of that family.
func TestNetworkTree_Contains_EntireInternet(t *testing.T) {
	tree, err := NewNetworkTree("0.0.0.0/0", "::/0")
	if err != nil {
		t.Fatalf("NewNetworkTree() failed: %v", err)
	}

	tests := []struct {
		ip   string
		desc string
	}{
		{"192.168.1.1", "IPv4 private"},
		{"8.8.8.8", "IPv4 public"},
		{"2001:db8::1", "IPv6"},
		{"::1", "IPv6 localhost"},
	}

	for _, tt := range tests {
		t.Run(tt.desc, func(t *testing.T) {
			ip := net.ParseIP(tt.ip)
			if ip == nil {
				t.Fatalf("ParseIP(%q) returned nil", tt.ip)
			}
			if !tree.Contains(ip) {
				t.Errorf("Contains(%q) should return true for entire internet range", tt.ip)
			}
		})
	}
}

// TestNetworkTree_MixedIPv4AndIPv6 checks a tree holding both families,
// including lookup of an IPv4 network via its IPv4-mapped IPv6 form.
func TestNetworkTree_MixedIPv4AndIPv6(t *testing.T) {
	tree, err := NewNetworkTree("192.168.1.0/24", "2001:db8::/32")
	if err != nil {
		t.Fatalf("NewNetworkTree() failed: %v", err)
	}

	tests := []struct {
		ip  string
		msg string
	}{
		// IPv4-mapped IPv6 is expected to match the IPv4 network.
		{"::ffff:192.168.1.1", "Contains() should handle IPv4-mapped IPv6 addresses"},
		{"192.168.1.1", "Contains() should handle regular IPv4 addresses"},
		{"2001:db8::1", "Contains() should handle IPv6 addresses"},
	}

	for _, tt := range tests {
		ip := net.ParseIP(tt.ip)
		if ip == nil {
			t.Fatalf("ParseIP(%q) returned nil", tt.ip)
		}
		if !tree.Contains(ip) {
			t.Error(tt.msg)
		}
	}
}

// TestNetworkTree_Add_InvalidIPNet checks that adding an IPNet with a nil IP
// does not panic and does not cause unrelated addresses to match.
func TestNetworkTree_Add_InvalidIPNet(t *testing.T) {
	tree, err := NewNetworkTree()
	if err != nil {
		t.Fatalf("NewNetworkTree() failed: %v", err)
	}

	// An IPNet with a nil IP but a plausible /24 mask.
	invalidIPNet := &net.IPNet{
		IP:   nil,
		Mask: net.CIDRMask(24, 32),
	}

	// This should not panic.
	tree.Add(invalidIPNet)

	// The bogus entry must not make an arbitrary address match.
	if tree.Contains(net.ParseIP("192.168.1.1")) {
		t.Error("Contains() should return false after adding invalid IPNet")
	}
}

// TestNetworkTree_InitializationWithNetworks checks that every network passed
// to NewNetworkTree is loaded, across both address families.
func TestNetworkTree_InitializationWithNetworks(t *testing.T) {
	networks := []string{
		"10.0.0.0/8",
		"172.16.0.0/12",
		"192.168.0.0/16",
		"2001:db8::/32",
	}

	tree, err := NewNetworkTree(networks...)
	if err != nil {
		t.Fatalf("NewNetworkTree() with multiple networks failed: %v", err)
	}

	testCases := []struct {
		ip   string
		want bool
	}{
		{"10.1.2.3", true},
		{"172.16.1.1", true},
		{"192.168.1.1", true},
		{"2001:db8::1", true},
		{"8.8.8.8", false},
	}

	for _, tc := range testCases {
		ip := net.ParseIP(tc.ip)
		if ip == nil {
			t.Fatalf("ParseIP(%q) returned nil", tc.ip)
		}
		if got := tree.Contains(ip); got != tc.want {
			t.Errorf("Contains(%q) = %v, want %v", tc.ip, got, tc.want)
		}
	}
}

// BenchmarkNetworkTree_Contains measures lookups against a small mixed-family
// tree, cycling through hits and a miss.
func BenchmarkNetworkTree_Contains(b *testing.B) {
	tree, err := NewNetworkTree(
		"10.0.0.0/8",
		"172.16.0.0/12",
		"192.168.0.0/16",
		"2001:db8::/32",
	)
	if err != nil {
		b.Fatalf("NewNetworkTree() failed: %v", err)
	}

	testIPs := []net.IP{
		net.ParseIP("10.1.2.3"),
		net.ParseIP("192.168.1.1"),
		net.ParseIP("2001:db8::1"),
		net.ParseIP("8.8.8.8"),
	}

	// b.Loop resets the timer on its first call (so setup above is excluded)
	// and keeps the Contains call from being optimized away.
	i := 0
	for b.Loop() {
		tree.Contains(testIPs[i%len(testIPs)])
		i++
	}
}

// BenchmarkNetworkTree_NewNetworkTree measures building a tree from a fixed
// set of CIDRs.
func BenchmarkNetworkTree_NewNetworkTree(b *testing.B) {
	cidrs := []string{
		"10.0.0.0/8",
		"172.16.0.0/12",
		"192.168.0.0/16",
		"2001:db8::/32",
	}

	for b.Loop() {
		if _, err := NewNetworkTree(cidrs...); err != nil {
			b.Fatalf("NewNetworkTree() failed: %v", err)
		}
	}
}

// BenchmarkNetworkTree_AddCIDR measures creating an empty tree and adding a
// fixed set of CIDRs one by one.
func BenchmarkNetworkTree_AddCIDR(b *testing.B) {
	cidrs := []string{
		"10.0.0.0/8",
		"172.16.0.0/12",
		"192.168.0.0/16",
		"2001:db8::/32",
	}

	for b.Loop() {
		tree, err := NewNetworkTree()
		if err != nil {
			b.Fatalf("NewNetworkTree() failed: %v", err)
		}
		for _, cidr := range cidrs {
			// A failing AddCIDR would otherwise make this benchmark measure nothing useful.
			if err := tree.AddCIDR(cidr); err != nil {
				b.Fatalf("AddCIDR(%q) failed: %v", cidr, err)
			}
		}
	}
}