Initial import
This commit is contained in:
114
internal/cryptutil/key.go
Normal file
114
internal/cryptutil/key.go
Normal file
@@ -0,0 +1,114 @@
|
||||
package cryptutil
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"crypto/ed25519"
|
||||
"crypto/elliptic"
|
||||
"crypto/rsa"
|
||||
"crypto/sha1"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
)
|
||||
|
||||
type publicKeyer interface {
|
||||
Public() any
|
||||
}
|
||||
|
||||
// PublicKey returns the public part of a crypto.PrivateKey.
|
||||
func PublicKey(key crypto.PrivateKey) any {
|
||||
switch key := key.(type) {
|
||||
case ed25519.PublicKey:
|
||||
return key
|
||||
case ed25519.PrivateKey:
|
||||
return key.Public()
|
||||
case ecdsa.PublicKey:
|
||||
return &key
|
||||
case *ecdsa.PublicKey:
|
||||
return key
|
||||
case *ecdsa.PrivateKey:
|
||||
return &key.PublicKey
|
||||
case rsa.PublicKey:
|
||||
return &key
|
||||
case *rsa.PublicKey:
|
||||
return key
|
||||
case *rsa.PrivateKey:
|
||||
return &key.PublicKey
|
||||
default:
|
||||
if p, ok := key.(publicKeyer); ok {
|
||||
return p.Public()
|
||||
}
|
||||
panic(fmt.Sprintf("don't know how to extract a public key from %T", key))
|
||||
}
|
||||
}
|
||||
|
||||
// LoadPrivateKey loads a private key from disk.
|
||||
func LoadPrivateKey(name string) (crypto.PrivateKey, error) {
|
||||
b, err := os.ReadFile(name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return decodePEMPrivateKey(b)
|
||||
}
|
||||
|
||||
func decodePEMPrivateKey(b []byte) (key crypto.PrivateKey, err error) {
|
||||
var (
|
||||
rest = b
|
||||
block *pem.Block
|
||||
)
|
||||
for {
|
||||
if block, rest = pem.Decode(rest); block == nil {
|
||||
return nil, errors.New("mitm: no private key PEM block could be decoded")
|
||||
}
|
||||
switch block.Type {
|
||||
case "EC PRIVATE KEY":
|
||||
return x509.ParseECPrivateKey(block.Bytes)
|
||||
|
||||
case "RSA PRIVATE KEY":
|
||||
return x509.ParsePKCS1PrivateKey(block.Bytes)
|
||||
|
||||
case "PRIVATE KEY":
|
||||
return x509.ParsePKCS8PrivateKey(block.Bytes)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateKeyID generates the PKIX public key ID.
|
||||
func GenerateKeyID(key crypto.PublicKey) []byte {
|
||||
b, err := x509.MarshalPKIXPublicKey(key)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return sha1.New().Sum(b)
|
||||
}
|
||||
|
||||
func keyType(key any) string {
|
||||
switch key := key.(type) {
|
||||
case ed25519.PrivateKey:
|
||||
return "ed25519"
|
||||
case *ecdsa.PrivateKey:
|
||||
return "ecdsa (" + curveType(key.Curve) + ")"
|
||||
case *rsa.PrivateKey:
|
||||
return "rsa"
|
||||
default:
|
||||
return fmt.Sprintf("%T", key)
|
||||
}
|
||||
}
|
||||
|
||||
func curveType(c elliptic.Curve) string {
|
||||
switch c {
|
||||
case elliptic.P224():
|
||||
return "p224"
|
||||
case elliptic.P256():
|
||||
return "p256"
|
||||
case elliptic.P384():
|
||||
return "p384"
|
||||
case elliptic.P521():
|
||||
return "p521"
|
||||
default:
|
||||
return fmt.Sprintf("%T", c)
|
||||
}
|
||||
}
|
242
internal/cryptutil/x509.go
Normal file
242
internal/cryptutil/x509.go
Normal file
@@ -0,0 +1,242 @@
|
||||
package cryptutil
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"crypto/ed25519"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"git.maze.io/maze/styx/internal/log"
|
||||
)
|
||||
|
||||
// Supported key types.
|
||||
const (
|
||||
TypeRSA = "rsa"
|
||||
TypeECDSA = "ecdsa"
|
||||
TypeED25519 = "ed25519"
|
||||
)
|
||||
|
||||
// Supported PEM block types.
|
||||
const (
|
||||
pemTypeCert = "CERTIFICATE"
|
||||
pemTypeRSA = "RSA PRIVATE KEY"
|
||||
pemTypeECDSA = "EC PRIVATE KEY"
|
||||
pemTypeAny = "PRIVATE KEY"
|
||||
)
|
||||
|
||||
// LoadKeyPair loads a certificate and private key, certdata and keydata can be a PEM encoded block or a file.
|
||||
//
|
||||
// If [keydata] is empty, then the private key is assumed to be contained in [certdata].
|
||||
func LoadKeyPair(certdata, keydata string) (cert *x509.Certificate, key crypto.PrivateKey, err error) {
|
||||
if keydata == "" {
|
||||
keydata = certdata
|
||||
}
|
||||
if strings.Contains(certdata, "-----BEGIN "+pemTypeCert) {
|
||||
log.Trace().Msg("parsing X.509 certificate")
|
||||
if cert, err = decodePEMBCertificate([]byte(certdata)); err != nil {
|
||||
return
|
||||
}
|
||||
} else {
|
||||
log.Trace().Str("name", certdata).Msg("loading X.509 certificate")
|
||||
if cert, err = LoadCertificate(certdata); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
if strings.Contains(keydata, pemTypeAny+"-----") {
|
||||
log.Trace().Msg("parsing private key")
|
||||
if key, err = decodePEMPrivateKey([]byte(keydata)); err != nil {
|
||||
return
|
||||
}
|
||||
} else if key, err = LoadPrivateKey(keydata); err != nil {
|
||||
log.Trace().Str("name", keydata).Msg("loading private key")
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// SaveKeyPair saves a certificate and private key in PEM encoding.
|
||||
//
|
||||
// If [keyFile] is empty, then the private key is stored in [certFile] alongside the certificate.
|
||||
//
|
||||
// Attempts are made to use secure file modes for files that contains private keys.
|
||||
func SaveKeyPair(cert *x509.Certificate, key crypto.PrivateKey, certFile, keyFile string) (err error) {
|
||||
var (
|
||||
keyDER []byte
|
||||
keyPEMType = pemTypeAny
|
||||
)
|
||||
switch key := key.(type) {
|
||||
case *ecdsa.PrivateKey:
|
||||
if keyDER, err = x509.MarshalECPrivateKey(key); err != nil {
|
||||
return
|
||||
}
|
||||
keyPEMType = pemTypeECDSA
|
||||
case ed25519.PrivateKey:
|
||||
if keyDER, err = x509.MarshalPKCS8PrivateKey(key); err != nil {
|
||||
return
|
||||
}
|
||||
case *rsa.PrivateKey:
|
||||
keyDER = x509.MarshalPKCS1PrivateKey(key)
|
||||
keyPEMType = pemTypeRSA
|
||||
default:
|
||||
return fmt.Errorf("mitm: don't know how to marshal %T", key)
|
||||
}
|
||||
|
||||
var certf, keyf *os.File
|
||||
if certf, err = os.OpenFile(certFile, os.O_CREATE|os.O_TRUNC|os.O_RDWR, 0o644); err != nil {
|
||||
return
|
||||
}
|
||||
defer func() { _ = certf.Close() }()
|
||||
|
||||
if filepath.Clean(certFile) == filepath.Clean(keyFile) || keyFile == "" {
|
||||
if err = certf.Chmod(0o600); err != nil {
|
||||
return
|
||||
}
|
||||
keyf, keyFile = certf, certFile
|
||||
} else {
|
||||
if keyf, err = os.OpenFile(keyFile, os.O_CREATE|os.O_TRUNC|os.O_RDWR, 0o600); err != nil {
|
||||
return
|
||||
}
|
||||
defer func() { _ = keyf.Close() }()
|
||||
}
|
||||
|
||||
log.Debug().Str("file", certFile).Msg("saving X.509 certificate")
|
||||
if err = pem.Encode(certf, &pem.Block{Type: pemTypeCert, Bytes: cert.Raw}); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
log.Debug().Str("fiile", keyFile).Msg("saving private key")
|
||||
if err = pem.Encode(keyf, &pem.Block{Type: keyPEMType, Bytes: keyDER}); err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// GenerateKeyPair generates a private key and self-signed certificate.
|
||||
func GenerateKeyPair(name pkix.Name, days int, keyType string, keyBits int) (cert *x509.Certificate, key crypto.PrivateKey, err error) {
|
||||
if key, err = GeneratePrivateKey(keyType, keyBits); err != nil {
|
||||
return
|
||||
}
|
||||
if cert, err = GenerateCertificateAuthority(name, days, key); err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func GenerateCertificateAuthority(name pkix.Name, days int, key crypto.PrivateKey) (cert *x509.Certificate, err error) {
|
||||
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
|
||||
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("mtim: failed to generate serial number: %w", err)
|
||||
}
|
||||
|
||||
keyUsage := x509.KeyUsageCertSign
|
||||
if _, ok := key.(*rsa.PrivateKey); ok {
|
||||
keyUsage |= x509.KeyUsageDigitalSignature
|
||||
}
|
||||
|
||||
notBefore := roundToDay(time.Now())
|
||||
notAfter := notBefore.Add(time.Duration(days) * 24 * time.Hour)
|
||||
|
||||
template := &x509.Certificate{
|
||||
Subject: name,
|
||||
SerialNumber: serialNumber,
|
||||
KeyUsage: keyUsage,
|
||||
SubjectKeyId: GenerateKeyID(key),
|
||||
IsCA: true,
|
||||
BasicConstraintsValid: true,
|
||||
NotBefore: notBefore,
|
||||
NotAfter: notAfter,
|
||||
}
|
||||
|
||||
log.Info().
|
||||
Str("name", name.CommonName).
|
||||
Int("days", days).
|
||||
Str("key", keyType(key)).
|
||||
Str("serial", serialNumber.String()).
|
||||
Msg("generating self-signed CA certificate")
|
||||
|
||||
var der []byte
|
||||
if der, err = x509.CreateCertificate(rand.Reader, template, template, PublicKey(key), key); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return x509.ParseCertificate(der)
|
||||
}
|
||||
|
||||
func GeneratePrivateKey(kind string, bits int) (key crypto.PrivateKey, err error) {
|
||||
switch strings.ToLower(kind) {
|
||||
case TypeRSA, "":
|
||||
if bits == 0 {
|
||||
bits = 2048
|
||||
}
|
||||
log.Trace().Int("bits", bits).Str("type", TypeRSA).Msg("generating private key")
|
||||
return rsa.GenerateKey(rand.Reader, bits)
|
||||
|
||||
case TypeECDSA, "ec", "ecc":
|
||||
if bits == 0 {
|
||||
bits = 256
|
||||
}
|
||||
|
||||
var curve elliptic.Curve
|
||||
switch bits {
|
||||
case 224:
|
||||
curve = elliptic.P224()
|
||||
case 256:
|
||||
curve = elliptic.P256()
|
||||
case 384:
|
||||
curve = elliptic.P384()
|
||||
case 521:
|
||||
curve = elliptic.P521()
|
||||
default:
|
||||
return nil, fmt.Errorf("mitm: elliptic curve %d bits not supported", bits)
|
||||
}
|
||||
log.Trace().Int("bits", bits).Str("type", TypeECDSA).Msg("generating private key")
|
||||
return ecdsa.GenerateKey(curve, rand.Reader)
|
||||
|
||||
case TypeED25519:
|
||||
log.Trace().Str("type", TypeED25519).Msg("generating ED25519 private key")
|
||||
_, key, err = ed25519.GenerateKey(rand.Reader)
|
||||
return
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("mitm: don't know how to generate %s private key", kind)
|
||||
}
|
||||
}
|
||||
|
||||
func decodePEMBCertificate(b []byte) (cert *x509.Certificate, err error) {
|
||||
var (
|
||||
rest = b
|
||||
block *pem.Block
|
||||
)
|
||||
for {
|
||||
if block, rest = pem.Decode(rest); block == nil {
|
||||
return nil, errors.New("mitm: no CERTIFICATE PEM block could be decoded")
|
||||
} else if block.Type == "CERTIFICATE" {
|
||||
return x509.ParseCertificate(block.Bytes)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func LoadCertificate(name string) (*x509.Certificate, error) {
|
||||
b, err := os.ReadFile(name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return decodePEMBCertificate(b)
|
||||
}
|
||||
|
||||
func roundToDay(t time.Time) time.Time {
|
||||
return time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, time.UTC)
|
||||
}
|
44
internal/log/log.go
Normal file
44
internal/log/log.go
Normal file
@@ -0,0 +1,44 @@
|
||||
package log
|
||||
|
||||
import (
|
||||
"io"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
// Aliases
|
||||
const (
|
||||
TraceLevel = zerolog.TraceLevel
|
||||
DebugLevel = zerolog.DebugLevel
|
||||
InfoLevel = zerolog.InfoLevel
|
||||
WarnLevel = zerolog.WarnLevel
|
||||
FatalLevel = zerolog.FatalLevel
|
||||
)
|
||||
|
||||
// Aliases
|
||||
type (
|
||||
Event = zerolog.Event
|
||||
Logger = zerolog.Logger
|
||||
)
|
||||
|
||||
// Console logger.
|
||||
var Console = zerolog.New(zerolog.NewConsoleWriter()).With().Timestamp().Logger()
|
||||
|
||||
func SetLevel(level zerolog.Level) {
|
||||
zerolog.SetGlobalLevel(level)
|
||||
//Console = Console.Level(level)
|
||||
}
|
||||
|
||||
func Trace() *Event { return Console.Trace() }
|
||||
func Debug() *Event { return Console.Debug() }
|
||||
func Info() *Event { return Console.Info() }
|
||||
func Warn() *Event { return Console.Warn() }
|
||||
func Error() *Event { return Console.Error() }
|
||||
func Fatal() *Event { return Console.Fatal() }
|
||||
func Panic() *Event { return Console.Panic() }
|
||||
|
||||
func OnCloseError(event *Event, closer io.Closer) {
|
||||
if err := closer.Close(); err != nil {
|
||||
event.Err(err).Msg("close failed")
|
||||
}
|
||||
}
|
35
internal/netutil/addr.go
Normal file
35
internal/netutil/addr.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package netutil
|
||||
|
||||
import (
|
||||
"net"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
// EnsurePort makes sure the address in [host] contains a port.
|
||||
func EnsurePort(host, port string) string {
|
||||
if _, _, err := net.SplitHostPort(host); err == nil {
|
||||
return host
|
||||
}
|
||||
return net.JoinHostPort(host, port)
|
||||
}
|
||||
|
||||
// Host returns the bare host (without port).
|
||||
func Host(name string) string {
|
||||
host, _, err := net.SplitHostPort(name)
|
||||
if err == nil {
|
||||
return host
|
||||
}
|
||||
return name
|
||||
}
|
||||
|
||||
// Port returns the port number.
|
||||
func Port(name string) int {
|
||||
_, port, err := net.SplitHostPort(name)
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
// TODO: name resolution for ports?
|
||||
i, _ := strconv.Atoi(port)
|
||||
return i
|
||||
}
|
99
internal/netutil/domain.go
Normal file
99
internal/netutil/domain.go
Normal file
@@ -0,0 +1,99 @@
|
||||
package netutil
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
type DomainTree struct {
|
||||
root *domainTreeNode
|
||||
}
|
||||
|
||||
type domainTreeNode struct {
|
||||
leaf map[string]*domainTreeNode
|
||||
isEnd bool
|
||||
}
|
||||
|
||||
func NewDomainList(domains ...string) *DomainTree {
|
||||
tree := &DomainTree{
|
||||
root: &domainTreeNode{leaf: make(map[string]*domainTreeNode)},
|
||||
}
|
||||
for _, domain := range domains {
|
||||
tree.Add(domain)
|
||||
}
|
||||
return tree
|
||||
}
|
||||
|
||||
func (tree *DomainTree) Add(domain string) {
|
||||
domain = normalizeDomain(domain)
|
||||
if domain == "" {
|
||||
return
|
||||
}
|
||||
|
||||
labels := dns.SplitDomainName(domain)
|
||||
if len(labels) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
node := tree.root
|
||||
for i := len(labels) - 1; i >= 0; i-- {
|
||||
label := labels[i]
|
||||
if label == "" {
|
||||
continue
|
||||
}
|
||||
if node.leaf == nil {
|
||||
node.leaf = make(map[string]*domainTreeNode)
|
||||
}
|
||||
if node.leaf[label] == nil {
|
||||
node.leaf[label] = &domainTreeNode{}
|
||||
}
|
||||
node = node.leaf[label]
|
||||
}
|
||||
node.isEnd = true
|
||||
}
|
||||
|
||||
func (tree *DomainTree) Contains(domain string) bool {
|
||||
domain = normalizeDomain(domain)
|
||||
if domain == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
labels := dns.SplitDomainName(domain)
|
||||
if len(labels) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
node := tree.root
|
||||
for i := len(labels) - 1; i >= 0; i-- {
|
||||
if node.isEnd {
|
||||
return true
|
||||
}
|
||||
|
||||
if node.leaf == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
label := labels[i]
|
||||
if node = node.leaf[label]; node == nil {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return node.isEnd
|
||||
}
|
||||
|
||||
func normalizeDomain(domain string) string {
|
||||
domain = strings.ToLower(strings.TrimSpace(domain))
|
||||
if domain == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Remove trailing dot if present, dns.Fqdn will add it back properly
|
||||
domain = strings.TrimSuffix(domain, ".")
|
||||
|
||||
if domain == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
return dns.Fqdn(domain)
|
||||
}
|
276
internal/netutil/domain_test.go
Normal file
276
internal/netutil/domain_test.go
Normal file
@@ -0,0 +1,276 @@
|
||||
package netutil
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestDomainList(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
domains []string
|
||||
hostname string
|
||||
expected bool
|
||||
}{
|
||||
// Basic exact matches
|
||||
{
|
||||
name: "exact match",
|
||||
domains: []string{"example.com"},
|
||||
hostname: "example.com",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "exact match with subdomain in list",
|
||||
domains: []string{"api.example.com"},
|
||||
hostname: "api.example.com",
|
||||
expected: true,
|
||||
},
|
||||
|
||||
// Suffix matching - if domain is in list, all subdomains should match
|
||||
{
|
||||
name: "subdomain matches parent domain",
|
||||
domains: []string{"example.com"},
|
||||
hostname: "sub.example.com",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "multiple subdomain levels match",
|
||||
domains: []string{"example.com"},
|
||||
hostname: "deep.nested.sub.example.com",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "subdomain matches intermediate domain",
|
||||
domains: []string{"api.example.com", "example.com"},
|
||||
hostname: "sub.api.example.com",
|
||||
expected: true,
|
||||
},
|
||||
|
||||
// Multi-level TLDs
|
||||
{
|
||||
name: "co.uk domain exact match",
|
||||
domains: []string{"domain.co.uk"},
|
||||
hostname: "domain.co.uk",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "subdomain of co.uk domain",
|
||||
domains: []string{"domain.co.uk"},
|
||||
hostname: "sub.domain.co.uk",
|
||||
expected: true,
|
||||
},
|
||||
|
||||
// Case sensitivity
|
||||
{
|
||||
name: "case insensitive match",
|
||||
domains: []string{"Example.COM"},
|
||||
hostname: "example.com",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "case insensitive hostname",
|
||||
domains: []string{"example.com"},
|
||||
hostname: "EXAMPLE.COM",
|
||||
expected: true,
|
||||
},
|
||||
|
||||
// Trailing dots
|
||||
{
|
||||
name: "domain with trailing dot",
|
||||
domains: []string{"example.com."},
|
||||
hostname: "example.com",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "hostname with trailing dot",
|
||||
domains: []string{"example.com"},
|
||||
hostname: "example.com.",
|
||||
expected: true,
|
||||
},
|
||||
|
||||
// Non-matches
|
||||
{
|
||||
name: "different TLD",
|
||||
domains: []string{"example.com"},
|
||||
hostname: "example.org",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "different domain",
|
||||
domains: []string{"example.com"},
|
||||
hostname: "test.com",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "partial match but not suffix",
|
||||
domains: []string{"example.com"},
|
||||
hostname: "com",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "empty hostname",
|
||||
domains: []string{"example.com"},
|
||||
hostname: "",
|
||||
expected: false,
|
||||
},
|
||||
|
||||
// Multiple domains in list
|
||||
{
|
||||
name: "matches first domain in list",
|
||||
domains: []string{"test.org", "example.com"},
|
||||
hostname: "example.com",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "matches second domain in list",
|
||||
domains: []string{"test.org", "example.com"},
|
||||
hostname: "test.org",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "subdomain matches any domain in list",
|
||||
domains: []string{"test.org", "example.com"},
|
||||
hostname: "sub.example.com",
|
||||
expected: true,
|
||||
},
|
||||
|
||||
// Edge cases
|
||||
{
|
||||
name: "empty domain list",
|
||||
domains: []string{},
|
||||
hostname: "example.com",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "invalid domain in list",
|
||||
domains: []string{""},
|
||||
hostname: "example.com",
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
list := NewDomainList(tt.domains...)
|
||||
result := list.Contains(tt.hostname)
|
||||
|
||||
if result != tt.expected {
|
||||
t.Errorf("Contains(%q) = %v, expected %v (domains: %v)",
|
||||
tt.hostname, result, tt.expected, tt.domains)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDomainList_Performance(t *testing.T) {
|
||||
// Test with a large number of domains to ensure performance
|
||||
domains := make([]string, 1000)
|
||||
for i := 0; i < 1000; i++ {
|
||||
domains[i] = string(rune('a'+(i%26))) + ".com"
|
||||
}
|
||||
domains = append(domains, "example.com") // Add our test domain
|
||||
|
||||
list := NewDomainList(domains...)
|
||||
|
||||
// These should be fast even with many domains
|
||||
if !list.Contains("example.com") {
|
||||
t.Error("Should match exact domain")
|
||||
}
|
||||
if !list.Contains("sub.example.com") {
|
||||
t.Error("Should match subdomain")
|
||||
}
|
||||
if list.Contains("notfound.com") {
|
||||
t.Error("Should not match unrelated domain")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDomainList_ComplexDomains(t *testing.T) {
|
||||
domains := []string{
|
||||
"very.long.domain.name.with.many.labels.com",
|
||||
"example.co.uk",
|
||||
"sub.domain.example.com",
|
||||
"a.b.c.d.e.f.com",
|
||||
}
|
||||
|
||||
list := NewDomainList(domains...)
|
||||
|
||||
tests := []struct {
|
||||
hostname string
|
||||
expected bool
|
||||
}{
|
||||
{"very.long.domain.name.with.many.labels.com", true},
|
||||
{"sub.very.long.domain.name.with.many.labels.com", true},
|
||||
{"example.co.uk", true},
|
||||
{"www.example.co.uk", true},
|
||||
{"sub.domain.example.com", true},
|
||||
{"another.sub.domain.example.com", true},
|
||||
{"a.b.c.d.e.f.com", true},
|
||||
{"x.a.b.c.d.e.f.com", true},
|
||||
{"not.matching.com", false},
|
||||
{"com", false},
|
||||
{"uk", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.hostname, func(t *testing.T) {
|
||||
result := list.Contains(tt.hostname)
|
||||
if result != tt.expected {
|
||||
t.Errorf("Contains(%q) = %v, expected %v", tt.hostname, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDomainList_SpecialCases(t *testing.T) {
|
||||
t.Run("domain with asterisk treated literally", func(t *testing.T) {
|
||||
list := NewDomainList("*.example.com")
|
||||
|
||||
// The asterisk should be treated as a literal label, not a wildcard
|
||||
if !list.Contains("*.example.com") {
|
||||
t.Error("Asterisk should be treated literally, not as wildcard")
|
||||
}
|
||||
if list.Contains("test.example.com") {
|
||||
t.Error("Should not match subdomain with literal asterisk domain")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("domains with hyphens and numbers", func(t *testing.T) {
|
||||
list := NewDomainList("test-123.example.com", "123abc.org")
|
||||
|
||||
if !list.Contains("test-123.example.com") {
|
||||
t.Error("Should match domain with hyphens and numbers")
|
||||
}
|
||||
if !list.Contains("sub.test-123.example.com") {
|
||||
t.Error("Should match subdomain of hyphenated domain")
|
||||
}
|
||||
if !list.Contains("123abc.org") {
|
||||
t.Error("Should match domain starting with numbers")
|
||||
}
|
||||
if !list.Contains("www.123abc.org") {
|
||||
t.Error("Should match subdomain of numeric domain")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkDomainList(b *testing.B) {
|
||||
// Benchmark with realistic domain list
|
||||
domains := []string{
|
||||
"google.com",
|
||||
"github.com",
|
||||
"example.org",
|
||||
"sub.domain.com",
|
||||
"api.service.co.uk",
|
||||
"very.long.domain.name.example.com",
|
||||
}
|
||||
|
||||
list := NewDomainList(domains...)
|
||||
|
||||
b.ResetTimer()
|
||||
for b.Loop() {
|
||||
// Mix of matches and non-matches
|
||||
list.Contains("sub.example.org")
|
||||
list.Contains("api.github.com")
|
||||
list.Contains("nonexistent.com")
|
||||
list.Contains("deep.nested.sub.domain.com")
|
||||
list.Contains("service.co.uk")
|
||||
}
|
||||
}
|
44
internal/netutil/network.go
Normal file
44
internal/netutil/network.go
Normal file
@@ -0,0 +1,44 @@
|
||||
package netutil
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
"github.com/yl2chen/cidranger"
|
||||
)
|
||||
|
||||
type NetworkTree struct {
|
||||
ranger cidranger.Ranger
|
||||
}
|
||||
|
||||
func NewNetworkTree(networks ...string) (*NetworkTree, error) {
|
||||
tree := &NetworkTree{
|
||||
ranger: cidranger.NewPCTrieRanger(),
|
||||
}
|
||||
for _, cidr := range networks {
|
||||
if err := tree.AddCIDR(cidr); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return tree, nil
|
||||
}
|
||||
|
||||
func (tree *NetworkTree) Add(ipnet *net.IPNet) {
|
||||
if ipnet == nil {
|
||||
return
|
||||
}
|
||||
tree.ranger.Insert(cidranger.NewBasicRangerEntry(*ipnet))
|
||||
}
|
||||
|
||||
func (tree *NetworkTree) AddCIDR(cidr string) error {
|
||||
_, ipnet, err := net.ParseCIDR(cidr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
tree.ranger.Insert(cidranger.NewBasicRangerEntry(*ipnet))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tree *NetworkTree) Contains(ip net.IP) bool {
|
||||
contains, _ := tree.ranger.Contains(ip)
|
||||
return contains
|
||||
}
|
410
internal/netutil/network_test.go
Normal file
410
internal/netutil/network_test.go
Normal file
@@ -0,0 +1,410 @@
|
||||
package netutil
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNewNetworkTree(t *testing.T) {
|
||||
// Test empty creation
|
||||
nl, err := NewNetworkTree()
|
||||
if err != nil {
|
||||
t.Fatalf("NewNetworkTree() failed: %v", err)
|
||||
}
|
||||
if nl == nil {
|
||||
t.Fatal("NewNetworkTree() returned nil")
|
||||
}
|
||||
if nl.ranger == nil {
|
||||
t.Error("NetworkTree ranger should not be nil")
|
||||
}
|
||||
|
||||
// Test creation with networks
|
||||
nl, err = NewNetworkTree("192.168.1.0/24", "10.0.0.0/8")
|
||||
if err != nil {
|
||||
t.Fatalf("NewNetworkTree() with networks failed: %v", err)
|
||||
}
|
||||
if nl == nil {
|
||||
t.Fatal("NewNetworkTree() with networks returned nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewNetworkTree_InvalidNetworks(t *testing.T) {
|
||||
// Test with invalid network
|
||||
_, err := NewNetworkTree("invalid-cidr")
|
||||
if err == nil {
|
||||
t.Error("NewNetworkTree() with invalid CIDR should have failed")
|
||||
}
|
||||
|
||||
// Test with mix of valid and invalid networks
|
||||
_, err = NewNetworkTree("192.168.1.0/24", "invalid-cidr", "10.0.0.0/8")
|
||||
if err == nil {
|
||||
t.Error("NewNetworkTree() with mixed valid/invalid CIDRs should have failed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNetworkTree_AddCIDR_Valid(t *testing.T) {
|
||||
nl, err := NewNetworkTree()
|
||||
if err != nil {
|
||||
t.Fatalf("NewNetworkTree() failed: %v", err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
cidr string
|
||||
desc string
|
||||
}{
|
||||
{"192.168.1.0/24", "IPv4 CIDR"},
|
||||
{"10.0.0.0/8", "IPv4 large range"},
|
||||
{"2001:db8::/32", "IPv6 CIDR"},
|
||||
{"::1/128", "IPv6 localhost"},
|
||||
{"0.0.0.0/0", "IPv4 entire internet"},
|
||||
{"::/0", "IPv6 entire internet"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.desc, func(t *testing.T) {
|
||||
if err := nl.AddCIDR(tt.cidr); err != nil {
|
||||
t.Errorf("AddCIDR(%q) failed: %v", tt.cidr, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNetworkTree_AddCIDR_Invalid(t *testing.T) {
|
||||
nl, err := NewNetworkTree()
|
||||
if err != nil {
|
||||
t.Fatalf("NewNetworkTree() failed: %v", err)
|
||||
}
|
||||
|
||||
invalidCIDRs := []string{
|
||||
"invalid-cidr",
|
||||
"192.168.1.1", // missing mask
|
||||
"192.168.1.0/33", // invalid mask for IPv4
|
||||
"2001:db8::/129", // invalid mask for IPv6
|
||||
"",
|
||||
"not-an-ip/24",
|
||||
}
|
||||
|
||||
for _, cidr := range invalidCIDRs {
|
||||
t.Run(cidr, func(t *testing.T) {
|
||||
if err := nl.AddCIDR(cidr); err == nil {
|
||||
t.Errorf("AddCIDR(%q) should have failed but didn't", cidr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNetworkTree_Add(t *testing.T) {
|
||||
nl, err := NewNetworkTree()
|
||||
if err != nil {
|
||||
t.Fatalf("NewNetworkTree() failed: %v", err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
cidr string
|
||||
desc string
|
||||
}{
|
||||
{"192.168.1.0/24", "IPv4 network"},
|
||||
{"2001:db8::/32", "IPv6 network"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.desc, func(t *testing.T) {
|
||||
_, ipNet, err := net.ParseCIDR(tt.cidr)
|
||||
if err != nil {
|
||||
t.Fatalf("ParseCIDR failed: %v", err)
|
||||
}
|
||||
|
||||
// Should not panic
|
||||
nl.Add(ipNet)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNetworkTree_Contains_IPv4(t *testing.T) {
|
||||
nl, err := NewNetworkTree("192.168.1.0/24", "10.0.0.0/8", "172.16.0.0/12")
|
||||
if err != nil {
|
||||
t.Fatalf("NewNetworkTree() failed: %v", err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
ip string
|
||||
want bool
|
||||
desc string
|
||||
}{
|
||||
// IPs that should match
|
||||
{"192.168.1.1", true, "in 192.168.1.0/24"},
|
||||
{"192.168.1.255", true, "broadcast in 192.168.1.0/24"},
|
||||
{"10.0.0.1", true, "in 10.0.0.0/8"},
|
||||
{"10.255.255.255", true, "max in 10.0.0.0/8"},
|
||||
{"172.16.0.1", true, "in 172.16.0.0/12"},
|
||||
{"172.31.255.255", true, "max in 172.16.0.0/12"},
|
||||
|
||||
// IPs that should not match
|
||||
{"192.168.2.1", false, "outside 192.168.1.0/24"},
|
||||
{"11.0.0.1", false, "outside 10.0.0.0/8"},
|
||||
{"172.32.0.1", false, "outside 172.16.0.0/12"},
|
||||
{"8.8.8.8", false, "public DNS"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.desc, func(t *testing.T) {
|
||||
ip := net.ParseIP(tt.ip)
|
||||
if ip == nil {
|
||||
t.Fatalf("ParseIP(%q) returned nil", tt.ip)
|
||||
}
|
||||
|
||||
got := nl.Contains(ip)
|
||||
if got != tt.want {
|
||||
t.Errorf("Contains(%q) = %v, want %v", tt.ip, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNetworkTree_Contains_IPv6(t *testing.T) {
|
||||
nl, err := NewNetworkTree("2001:db8::/32", "2001:db8:abcd::/48", "::1/128")
|
||||
if err != nil {
|
||||
t.Fatalf("NewNetworkTree() failed: %v", err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
ip string
|
||||
want bool
|
||||
desc string
|
||||
}{
|
||||
// IPs that should match
|
||||
{"2001:db8::1", true, "in 2001:db8::/32"},
|
||||
{"2001:db8:ffff:ffff:ffff:ffff:ffff:ffff", true, "max in 2001:db8::/32"},
|
||||
{"2001:db8:abcd::1", true, "in 2001:db8:abcd::/48"},
|
||||
{"::1", true, "localhost"},
|
||||
|
||||
// IPs that should not match
|
||||
{"2001:db9::1", false, "outside 2001:db8::/32"},
|
||||
{"2001:db9:abcd::1", false, "outside 2001:db8:abcd::/48"},
|
||||
{"::2", false, "outside ::1/128"},
|
||||
{"2001:4860:4860::8888", false, "public DNS"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.desc, func(t *testing.T) {
|
||||
ip := net.ParseIP(tt.ip)
|
||||
if ip == nil {
|
||||
t.Fatalf("ParseIP(%q) returned nil", tt.ip)
|
||||
}
|
||||
|
||||
got := nl.Contains(ip)
|
||||
if got != tt.want {
|
||||
t.Errorf("Contains(%q) = %v, want %v", tt.ip, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNetworkTree_Contains_EdgeCases(t *testing.T) {
|
||||
nl, err := NewNetworkTree()
|
||||
if err != nil {
|
||||
t.Fatalf("NewNetworkTree() failed: %v", err)
|
||||
}
|
||||
|
||||
// Test with nil IP
|
||||
if nl.Contains(nil) != false {
|
||||
t.Error("Contains(nil) should return false")
|
||||
}
|
||||
|
||||
// Test with empty list
|
||||
ip := net.ParseIP("192.168.1.1")
|
||||
if nl.Contains(ip) != false {
|
||||
t.Error("Contains() on empty list should return false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNetworkTree_Contains_OverlappingRanges(t *testing.T) {
|
||||
nl, err := NewNetworkTree("192.168.0.0/16", "192.168.1.0/24", "192.168.1.128/25")
|
||||
if err != nil {
|
||||
t.Fatalf("NewNetworkTree() failed: %v", err)
|
||||
}
|
||||
|
||||
// All these should match because we have overlapping ranges
|
||||
tests := []string{
|
||||
"192.168.1.1",
|
||||
"192.168.1.129",
|
||||
"192.168.2.1",
|
||||
}
|
||||
|
||||
for _, ipStr := range tests {
|
||||
t.Run(ipStr, func(t *testing.T) {
|
||||
ip := net.ParseIP(ipStr)
|
||||
if !nl.Contains(ip) {
|
||||
t.Errorf("Contains(%q) should return true for overlapping ranges", ipStr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNetworkTree_Contains_EntireInternet(t *testing.T) {
|
||||
nl, err := NewNetworkTree("0.0.0.0/0", "::/0")
|
||||
if err != nil {
|
||||
t.Fatalf("NewNetworkTree() failed: %v", err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
ip string
|
||||
desc string
|
||||
}{
|
||||
{"192.168.1.1", "IPv4 private"},
|
||||
{"8.8.8.8", "IPv4 public"},
|
||||
{"2001:db8::1", "IPv6"},
|
||||
{"::1", "IPv6 localhost"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.desc, func(t *testing.T) {
|
||||
ip := net.ParseIP(tt.ip)
|
||||
if !nl.Contains(ip) {
|
||||
t.Errorf("Contains(%q) should return true for entire internet range", tt.ip)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNetworkTree_MixedIPv4AndIPv6(t *testing.T) {
|
||||
nl, err := NewNetworkTree("192.168.1.0/24", "2001:db8::/32")
|
||||
if err != nil {
|
||||
t.Fatalf("NewNetworkTree() failed: %v", err)
|
||||
}
|
||||
|
||||
// Test IPv4 in IPv6 format (should still work due to normalization)
|
||||
ipv4InIPv6 := net.ParseIP("::ffff:192.168.1.1") // IPv4-mapped IPv6
|
||||
if !nl.Contains(ipv4InIPv6) {
|
||||
t.Error("Contains() should handle IPv4-mapped IPv6 addresses")
|
||||
}
|
||||
|
||||
// Regular IPv4 should work
|
||||
ipv4 := net.ParseIP("192.168.1.1")
|
||||
if !nl.Contains(ipv4) {
|
||||
t.Error("Contains() should handle regular IPv4 addresses")
|
||||
}
|
||||
|
||||
// IPv6 should work
|
||||
ipv6 := net.ParseIP("2001:db8::1")
|
||||
if !nl.Contains(ipv6) {
|
||||
t.Error("Contains() should handle IPv6 addresses")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNetworkTree_Add_InvalidIPNet(t *testing.T) {
|
||||
nl, err := NewNetworkTree()
|
||||
if err != nil {
|
||||
t.Fatalf("NewNetworkTree() failed: %v", err)
|
||||
}
|
||||
|
||||
// Create an invalid IPNet (nil IP)
|
||||
invalidIPNet := &net.IPNet{
|
||||
IP: nil,
|
||||
Mask: net.CIDRMask(24, 32),
|
||||
}
|
||||
|
||||
// This should not panic
|
||||
nl.Add(invalidIPNet)
|
||||
|
||||
// Verify that it doesn't affect Contains results
|
||||
ip := net.ParseIP("192.168.1.1")
|
||||
if nl.Contains(ip) {
|
||||
t.Error("Contains() should return false after adding invalid IPNet")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNetworkTree_InitializationWithNetworks(t *testing.T) {
|
||||
networks := []string{
|
||||
"10.0.0.0/8",
|
||||
"172.16.0.0/12",
|
||||
"192.168.0.0/16",
|
||||
"2001:db8::/32",
|
||||
}
|
||||
|
||||
nl, err := NewNetworkTree(networks...)
|
||||
if err != nil {
|
||||
t.Fatalf("NewNetworkTree() with multiple networks failed: %v", err)
|
||||
}
|
||||
|
||||
// Test that all networks were added correctly
|
||||
testCases := []struct {
|
||||
ip string
|
||||
want bool
|
||||
}{
|
||||
{"10.1.2.3", true},
|
||||
{"172.16.1.1", true},
|
||||
{"192.168.1.1", true},
|
||||
{"2001:db8::1", true},
|
||||
{"8.8.8.8", false},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
ip := net.ParseIP(tc.ip)
|
||||
if got := nl.Contains(ip); got != tc.want {
|
||||
t.Errorf("Contains(%q) = %v, want %v", tc.ip, got, tc.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkNetworkTree_Contains(b *testing.B) {
|
||||
nl, err := NewNetworkTree(
|
||||
"10.0.0.0/8",
|
||||
"172.16.0.0/12",
|
||||
"192.168.0.0/16",
|
||||
"2001:db8::/32",
|
||||
)
|
||||
if err != nil {
|
||||
b.Fatalf("NewNetworkTree() failed: %v", err)
|
||||
}
|
||||
|
||||
testIPs := []net.IP{
|
||||
net.ParseIP("10.1.2.3"),
|
||||
net.ParseIP("192.168.1.1"),
|
||||
net.ParseIP("2001:db8::1"),
|
||||
net.ParseIP("8.8.8.8"),
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
ip := testIPs[i%len(testIPs)]
|
||||
nl.Contains(ip)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkNetworkTree_NewNetworkTree(b *testing.B) {
|
||||
cidrs := []string{
|
||||
"10.0.0.0/8",
|
||||
"172.16.0.0/12",
|
||||
"192.168.0.0/16",
|
||||
"2001:db8::/32",
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for b.Loop() {
|
||||
_, err := NewNetworkTree(cidrs...)
|
||||
if err != nil {
|
||||
b.Fatalf("NewNetworkTree() failed: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkNetworkTree_AddCIDR(b *testing.B) {
|
||||
cidrs := []string{
|
||||
"10.0.0.0/8",
|
||||
"172.16.0.0/12",
|
||||
"192.168.0.0/16",
|
||||
"2001:db8::/32",
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for b.Loop() {
|
||||
nl, err := NewNetworkTree()
|
||||
if err != nil {
|
||||
b.Fatalf("NewNetworkTree() failed: %v", err)
|
||||
}
|
||||
for _, cidr := range cidrs {
|
||||
nl.AddCIDR(cidr)
|
||||
}
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user