Checkpoint
This commit is contained in:
168
internal/cryptutil/tls.go
Normal file
168
internal/cryptutil/tls.go
Normal file
@@ -0,0 +1,168 @@
|
||||
package cryptutil
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"git.maze.io/maze/styx/internal/netutil"
|
||||
"git.maze.io/maze/styx/internal/sliceutil"
|
||||
"git.maze.io/maze/styx/logger"
|
||||
)
|
||||
|
||||
var (
|
||||
supportedCipherSuites = tls.CipherSuites()
|
||||
supportedCipherSuite = make(map[uint16]bool)
|
||||
supportedVersions = []uint16{
|
||||
tls.VersionTLS13,
|
||||
tls.VersionTLS12,
|
||||
tls.VersionTLS11,
|
||||
tls.VersionTLS10,
|
||||
}
|
||||
)
|
||||
|
||||
func init() {
|
||||
for _, suite := range supportedCipherSuites {
|
||||
supportedCipherSuite[suite.ID] = true
|
||||
}
|
||||
}
|
||||
|
||||
func DecodeTLSCertificate(b []byte) (tls.Certificate, error) {
|
||||
var (
|
||||
cert tls.Certificate
|
||||
chain []*x509.Certificate
|
||||
rest = b
|
||||
block *pem.Block
|
||||
err error
|
||||
)
|
||||
for {
|
||||
if block, rest = pem.Decode(rest); block == nil {
|
||||
break
|
||||
}
|
||||
switch block.Type {
|
||||
case "CERTIFICATE":
|
||||
cert.Certificate = append(cert.Certificate, block.Bytes)
|
||||
if x509Cert, err := x509.ParseCertificate(block.Bytes); err != nil {
|
||||
return tls.Certificate{}, err
|
||||
} else {
|
||||
chain = append(chain, x509Cert)
|
||||
cert.Leaf = x509Cert
|
||||
}
|
||||
case "PRIVATE KEY":
|
||||
if cert.PrivateKey, err = x509.ParsePKCS8PrivateKey(block.Bytes); err != nil {
|
||||
return tls.Certificate{}, err
|
||||
}
|
||||
case "RSA PRIVATE KEY":
|
||||
if cert.PrivateKey, err = x509.ParsePKCS1PrivateKey(block.Bytes); err != nil {
|
||||
return tls.Certificate{}, err
|
||||
}
|
||||
case "EC PRIVATE KEY":
|
||||
if cert.PrivateKey, err = x509.ParseECPrivateKey(block.Bytes); err != nil {
|
||||
return tls.Certificate{}, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return cert, nil
|
||||
}
|
||||
|
||||
func LoadTLSCertificate(certFile, keyFile string) (tls.Certificate, error) {
|
||||
var (
|
||||
b []byte
|
||||
err error
|
||||
)
|
||||
if strings.Contains(certFile, "-----BEGIN") {
|
||||
logger.StandardLog.Trace("Loading X.509 certificate")
|
||||
b = []byte(certFile)
|
||||
} else {
|
||||
logger.StandardLog.Value("name", certFile).Trace("Loading X.509 certificate")
|
||||
if b, err = os.ReadFile(certFile); err != nil {
|
||||
return tls.Certificate{}, err
|
||||
}
|
||||
}
|
||||
if strings.Contains(keyFile, "-----BEGIN") {
|
||||
logger.StandardLog.Trace("Loading private key")
|
||||
b = append(b, []byte(keyFile)...)
|
||||
} else if keyFile != "" {
|
||||
logger.StandardLog.Value("name", keyFile).Trace("Loading private key")
|
||||
var k []byte
|
||||
if k, err = os.ReadFile(keyFile); err != nil {
|
||||
return tls.Certificate{}, err
|
||||
}
|
||||
b = append(b, k...)
|
||||
}
|
||||
return DecodeTLSCertificate(b)
|
||||
}
|
||||
|
||||
// CheckTLSBuffer is like [CheckTLSHandshake] but restores the original buffered reader.
|
||||
func CheckTLSBuffer(r *bufio.Reader) (bool, error) {
|
||||
b, err := r.ReadByte()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if err = r.UnreadByte(); err != nil {
|
||||
return false, err
|
||||
}
|
||||
return b == 0x16, nil
|
||||
}
|
||||
|
||||
// CheckTLSHandshake checks if the next byte available in r looks like a TLS handshake.
|
||||
func CheckTLSHandshake(r io.Reader) (bool, error) {
|
||||
// Peek first byte received in tunneled connection, client initiates the TLS connection or plain HTTP request
|
||||
b := make([]byte, 1)
|
||||
if _, err := io.ReadFull(r, b); err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
// TLS handshake: https://tools.ietf.org/html/rfc5246#section-6.2.1
|
||||
return b[0] == 0x16, nil
|
||||
}
|
||||
|
||||
// SniffClientHello uses ReadClientHello to sniff the TLS handshake and returns a new [net.Conn] that
|
||||
// contains the original byte sequences.
|
||||
func SniffClientHello(c net.Conn) (net.Conn, *tls.ClientHelloInfo, error) {
|
||||
b := new(bytes.Buffer)
|
||||
h, err := ReadClientHello(io.TeeReader(c, b))
|
||||
return netutil.ReaderConn{
|
||||
Conn: c,
|
||||
Reader: io.MultiReader(b, c),
|
||||
}, h, err
|
||||
}
|
||||
|
||||
// ReadClientHello reads a TLS client hello message from the TLS handshake.
|
||||
func ReadClientHello(r io.Reader) (*tls.ClientHelloInfo, error) {
|
||||
var hello *tls.ClientHelloInfo
|
||||
err := tls.Server(netutil.ReadOnlyConn{Reader: r}, &tls.Config{
|
||||
GetConfigForClient: func(clientHello *tls.ClientHelloInfo) (*tls.Config, error) {
|
||||
hello = new(tls.ClientHelloInfo)
|
||||
*hello = *clientHello
|
||||
return nil, nil
|
||||
},
|
||||
}).Handshake()
|
||||
if hello == nil {
|
||||
return nil, err
|
||||
}
|
||||
return hello, nil
|
||||
}
|
||||
|
||||
// IsSupportedCipherSuite checks if Go can support the cipher suite.
|
||||
func IsSupportedCipherSuite(id uint16) bool {
|
||||
return supportedCipherSuite[id]
|
||||
}
|
||||
|
||||
// IsSupportedVersion checks if Go can support the TLS version.
|
||||
func IsSupportedVersion(version uint16) bool {
|
||||
return slices.Contains(supportedVersions, version)
|
||||
}
|
||||
|
||||
// OnlySecureCipherSuites removes any cipher suite that isn't supported by Go.
|
||||
func OnlySecureCipherSuites(ids []uint16) []uint16 {
|
||||
return sliceutil.Filter(ids, IsSupportedCipherSuite)
|
||||
}
|
@@ -18,7 +18,8 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"git.maze.io/maze/styx/internal/log"
|
||||
"git.maze.io/maze/styx/logger"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// Supported key types.
|
||||
@@ -36,6 +37,62 @@ const (
|
||||
pemTypeAny = "PRIVATE KEY"
|
||||
)
|
||||
|
||||
// DecodeRoots loads all PEM encoded certificates from b.
|
||||
func DecodeRoots(b []byte) (*x509.CertPool, error) {
|
||||
var (
|
||||
pool = x509.NewCertPool()
|
||||
rest = b
|
||||
block *pem.Block
|
||||
)
|
||||
for {
|
||||
if block, rest = pem.Decode(rest); block == nil {
|
||||
break
|
||||
} else if block.Type == pemTypeCert {
|
||||
cert, err := x509.ParseCertificate(block.Bytes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if cert.IsCA {
|
||||
pool.AddCert(cert)
|
||||
}
|
||||
}
|
||||
}
|
||||
return pool, nil
|
||||
}
|
||||
|
||||
// LoadRoots loads a certificate authority bundle.
|
||||
func LoadRoots(roots string) (*x509.CertPool, error) {
|
||||
if strings.Contains(roots, "-----BEGIN CERTIFICATE") {
|
||||
logger.StandardLog.Trace("Parsing X.509 certificates")
|
||||
return DecodeRoots([]byte(roots))
|
||||
}
|
||||
var b []byte
|
||||
i, err := os.Stat(roots)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if i.IsDir() {
|
||||
logger.StandardLog.Value("path", roots).Trace("Loading X.509 certificates from *.crt *.pem")
|
||||
for _, ext := range []string{"*.crt", "*.pem"} {
|
||||
var files []string
|
||||
if files, err = filepath.Glob(filepath.Join(roots, ext)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, file := range files {
|
||||
var v []byte
|
||||
if v, err = os.ReadFile(file); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
b = append(b, v...)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
logger.StandardLog.Value("path", roots).Trace("Loading X.509 certificates")
|
||||
if b, err = os.ReadFile(roots); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return DecodeRoots(b)
|
||||
}
|
||||
|
||||
// LoadKeyPair loads a certificate and private key, certdata and keydata can be a PEM encoded block or a file.
|
||||
//
|
||||
// If [keydata] is empty, then the private key is assumed to be contained in [certdata].
|
||||
@@ -44,23 +101,23 @@ func LoadKeyPair(certdata, keydata string) (cert *x509.Certificate, key crypto.P
|
||||
keydata = certdata
|
||||
}
|
||||
if strings.Contains(certdata, "-----BEGIN "+pemTypeCert) {
|
||||
log.Trace().Msg("parsing X.509 certificate")
|
||||
logger.StandardLog.Trace("Parsing X.509 certificate")
|
||||
if cert, err = decodePEMBCertificate([]byte(certdata)); err != nil {
|
||||
return
|
||||
}
|
||||
} else {
|
||||
log.Trace().Str("name", certdata).Msg("loading X.509 certificate")
|
||||
logger.StandardLog.Value("name", certdata).Trace("Loading X.509 certificate")
|
||||
if cert, err = LoadCertificate(certdata); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
if strings.Contains(keydata, pemTypeAny+"-----") {
|
||||
log.Trace().Msg("parsing private key")
|
||||
logger.StandardLog.Trace("Parsing private key")
|
||||
if key, err = decodePEMPrivateKey([]byte(keydata)); err != nil {
|
||||
return
|
||||
}
|
||||
} else if key, err = LoadPrivateKey(keydata); err != nil {
|
||||
log.Trace().Str("name", keydata).Msg("loading private key")
|
||||
logger.StandardLog.Value("name", keydata).Trace("Loading private key")
|
||||
return
|
||||
}
|
||||
return
|
||||
|
@@ -1,44 +0,0 @@
|
||||
package log
|
||||
|
||||
import (
|
||||
"io"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
// Aliases
|
||||
const (
|
||||
TraceLevel = zerolog.TraceLevel
|
||||
DebugLevel = zerolog.DebugLevel
|
||||
InfoLevel = zerolog.InfoLevel
|
||||
WarnLevel = zerolog.WarnLevel
|
||||
FatalLevel = zerolog.FatalLevel
|
||||
)
|
||||
|
||||
// Aliases
|
||||
type (
|
||||
Event = zerolog.Event
|
||||
Logger = zerolog.Logger
|
||||
)
|
||||
|
||||
// Console logger.
|
||||
var Console = zerolog.New(zerolog.NewConsoleWriter()).With().Timestamp().Logger()
|
||||
|
||||
func SetLevel(level zerolog.Level) {
|
||||
zerolog.SetGlobalLevel(level)
|
||||
//Console = Console.Level(level)
|
||||
}
|
||||
|
||||
func Trace() *Event { return Console.Trace() }
|
||||
func Debug() *Event { return Console.Debug() }
|
||||
func Info() *Event { return Console.Info() }
|
||||
func Warn() *Event { return Console.Warn() }
|
||||
func Error() *Event { return Console.Error() }
|
||||
func Fatal() *Event { return Console.Fatal() }
|
||||
func Panic() *Event { return Console.Panic() }
|
||||
|
||||
func OnCloseError(event *Event, closer io.Closer) {
|
||||
if err := closer.Close(); err != nil {
|
||||
event.Err(err).Msg("close failed")
|
||||
}
|
||||
}
|
@@ -29,6 +29,10 @@ func Port(name string) int {
|
||||
return 0
|
||||
}
|
||||
|
||||
if i, err := net.LookupPort("tcp", port); err == nil {
|
||||
return i
|
||||
}
|
||||
|
||||
// TODO: name resolution for ports?
|
||||
i, _ := strconv.Atoi(port)
|
||||
return i
|
||||
|
63
internal/netutil/arp/arp.go
Normal file
63
internal/netutil/arp/arp.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package arp
|
||||
|
||||
import (
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func init() {
|
||||
go func() {
|
||||
t := time.NewTicker(time.Second * 5)
|
||||
for {
|
||||
refresh()
|
||||
<-t.C
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
var table sync.Map
|
||||
|
||||
func refresh() {
|
||||
t, err := lookup()
|
||||
if err != nil {
|
||||
logrus.StandardLogger().WithError(err).Warn("arp cache refresh failed")
|
||||
} else {
|
||||
for k, v := range t {
|
||||
logrus.StandardLogger().WithFields(logrus.Fields{
|
||||
"mac": v,
|
||||
"ip": k,
|
||||
}).Debug("Updating ARP cache")
|
||||
table.Store(k, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func Get(addr net.Addr) net.HardwareAddr {
|
||||
if addr == nil {
|
||||
logrus.StandardLogger().Trace("No address found, can't lookup IP for MAC")
|
||||
return nil
|
||||
}
|
||||
|
||||
var ip net.IP
|
||||
switch addr := addr.(type) {
|
||||
case *net.TCPAddr:
|
||||
ip = addr.IP
|
||||
case *net.UDPAddr:
|
||||
ip = addr.IP
|
||||
}
|
||||
if ip == nil {
|
||||
logrus.StandardLogger().WithField("addr", addr.String()).Trace("No IP address found, can't lookup MAC")
|
||||
return nil
|
||||
}
|
||||
|
||||
if v, ok := table.Load(ip.String()); ok {
|
||||
logrus.StandardLogger().WithField("ip", ip.String()).Tracef("%s is at %s", ip, v.(net.HardwareAddr).String())
|
||||
return v.(net.HardwareAddr)
|
||||
}
|
||||
|
||||
logrus.StandardLogger().WithField("ip", ip.String()).Trace("No MAC address found")
|
||||
return nil
|
||||
}
|
32
internal/netutil/arp/arp_linux.go
Normal file
32
internal/netutil/arp/arp_linux.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package arp
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"net"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func lookup() (map[string]net.HardwareAddr, error) {
|
||||
f, err := os.Open("/proc/net/arp")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = f.Close() }()
|
||||
|
||||
t := make(map[string]net.HardwareAddr)
|
||||
s := bufio.NewScanner(f)
|
||||
for i := 0; s.Scan(); i++ {
|
||||
if i == 0 {
|
||||
continue
|
||||
}
|
||||
line := strings.Fields(s.Text())
|
||||
if len(line) < 4 {
|
||||
continue
|
||||
}
|
||||
if mac, err := net.ParseMAC(line[3]); err == nil {
|
||||
t[line[0]] = mac
|
||||
}
|
||||
}
|
||||
return t, nil
|
||||
}
|
37
internal/netutil/arp/arp_unix.go
Normal file
37
internal/netutil/arp/arp_unix.go
Normal file
@@ -0,0 +1,37 @@
|
||||
//go:build !linux
|
||||
// +build !linux
|
||||
|
||||
// ^ Linux isn't Unix anyway :P
|
||||
|
||||
package arp
|
||||
|
||||
import (
|
||||
"net"
|
||||
"os/exec"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func lookup() (map[string]net.HardwareAddr, error) {
|
||||
data, err := exec.Command("arp", "-an").Output()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
t := make(map[string]net.HardwareAddr)
|
||||
for _, line := range strings.Split(string(data), "\n") {
|
||||
fields := strings.Fields(line)
|
||||
if len(fields) < 3 {
|
||||
continue
|
||||
}
|
||||
|
||||
// strip brackets around IP
|
||||
ip := strings.ReplaceAll(fields[1], "(", "")
|
||||
ip = strings.ReplaceAll(ip, ")", "")
|
||||
|
||||
if mac, err := net.ParseMAC(fields[3]); err == nil {
|
||||
t[ip] = mac
|
||||
}
|
||||
}
|
||||
|
||||
return t, nil
|
||||
}
|
93
internal/netutil/conn.go
Normal file
93
internal/netutil/conn.go
Normal file
@@ -0,0 +1,93 @@
|
||||
package netutil
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"syscall"
|
||||
"time"
|
||||
)
|
||||
|
||||
// BufferedConn uses byte buffers for Read and Write operations on a [net.Conn].
|
||||
type BufferedConn struct {
|
||||
net.Conn
|
||||
Reader *bufio.Reader
|
||||
Writer *bufio.Writer
|
||||
}
|
||||
|
||||
func NewBufferedConn(c net.Conn) *BufferedConn {
|
||||
if b, ok := c.(*BufferedConn); ok {
|
||||
return b
|
||||
}
|
||||
return &BufferedConn{
|
||||
Conn: c,
|
||||
Reader: bufio.NewReader(c),
|
||||
Writer: bufio.NewWriter(c),
|
||||
}
|
||||
}
|
||||
|
||||
func (conn BufferedConn) Read(p []byte) (int, error) { return conn.Reader.Read(p) }
|
||||
func (conn BufferedConn) Write(p []byte) (int, error) { return conn.Writer.Write(p) }
|
||||
func (conn BufferedConn) Flush() error { return conn.Writer.Flush() }
|
||||
func (conn BufferedConn) NetConn() net.Conn { return conn.Conn }
|
||||
|
||||
// ReaderConn is a [net.Conn] with a separate [io.Reader] to read from.
|
||||
type ReaderConn struct {
|
||||
net.Conn
|
||||
io.Reader
|
||||
}
|
||||
|
||||
func (conn ReaderConn) Read(p []byte) (int, error) { return conn.Reader.Read(p) }
|
||||
func (conn ReaderConn) NetConn() net.Conn { return conn.Conn }
|
||||
|
||||
// ReadOnlyConn only allows reading, all other operations will fail.
|
||||
type ReadOnlyConn struct {
|
||||
io.Reader
|
||||
}
|
||||
|
||||
func (conn ReadOnlyConn) Read(p []byte) (int, error) { return conn.Reader.Read(p) }
|
||||
func (conn ReadOnlyConn) Write(p []byte) (int, error) { return 0, io.ErrClosedPipe }
|
||||
func (conn ReadOnlyConn) Close() error { return nil }
|
||||
func (conn ReadOnlyConn) LocalAddr() net.Addr { return nil }
|
||||
func (conn ReadOnlyConn) RemoteAddr() net.Addr { return nil }
|
||||
func (conn ReadOnlyConn) SetDeadline(t time.Time) error { return nil }
|
||||
func (conn ReadOnlyConn) SetReadDeadline(t time.Time) error { return nil }
|
||||
func (conn ReadOnlyConn) SetWriteDeadline(t time.Time) error { return nil }
|
||||
|
||||
func (conn ReadOnlyConn) NetConn() net.Conn {
|
||||
if c, ok := conn.Reader.(net.Conn); ok {
|
||||
return c
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func IsClosing(err error) bool {
|
||||
if errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) || errors.Is(err, syscall.ECONNRESET) || err.Error() != "proxy: shutdown" {
|
||||
return true
|
||||
}
|
||||
if err, ok := err.(net.Error); ok && err.Timeout() {
|
||||
return true
|
||||
}
|
||||
// log.Debug().Msgf("not a closing error %T: %#+v", err, err)
|
||||
return false
|
||||
}
|
||||
|
||||
// WithTimeout is a convenience wrapper for doing network operations that observe a timeout.
|
||||
func WithTimeout(c net.Conn, timeout time.Duration, do func() error) error {
|
||||
if timeout <= 0 {
|
||||
return do()
|
||||
}
|
||||
|
||||
if err := c.SetDeadline(time.Now().Add(timeout)); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := do(); err != nil {
|
||||
_ = c.SetDeadline(time.Time{})
|
||||
return err
|
||||
}
|
||||
if err := c.SetDeadline(time.Time{}); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
@@ -1,99 +0,0 @@
|
||||
package netutil
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
type DomainTree struct {
|
||||
root *domainTreeNode
|
||||
}
|
||||
|
||||
type domainTreeNode struct {
|
||||
leaf map[string]*domainTreeNode
|
||||
isEnd bool
|
||||
}
|
||||
|
||||
func NewDomainList(domains ...string) *DomainTree {
|
||||
tree := &DomainTree{
|
||||
root: &domainTreeNode{leaf: make(map[string]*domainTreeNode)},
|
||||
}
|
||||
for _, domain := range domains {
|
||||
tree.Add(domain)
|
||||
}
|
||||
return tree
|
||||
}
|
||||
|
||||
func (tree *DomainTree) Add(domain string) {
|
||||
domain = normalizeDomain(domain)
|
||||
if domain == "" {
|
||||
return
|
||||
}
|
||||
|
||||
labels := dns.SplitDomainName(domain)
|
||||
if len(labels) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
node := tree.root
|
||||
for i := len(labels) - 1; i >= 0; i-- {
|
||||
label := labels[i]
|
||||
if label == "" {
|
||||
continue
|
||||
}
|
||||
if node.leaf == nil {
|
||||
node.leaf = make(map[string]*domainTreeNode)
|
||||
}
|
||||
if node.leaf[label] == nil {
|
||||
node.leaf[label] = &domainTreeNode{}
|
||||
}
|
||||
node = node.leaf[label]
|
||||
}
|
||||
node.isEnd = true
|
||||
}
|
||||
|
||||
func (tree *DomainTree) Contains(domain string) bool {
|
||||
domain = normalizeDomain(domain)
|
||||
if domain == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
labels := dns.SplitDomainName(domain)
|
||||
if len(labels) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
node := tree.root
|
||||
for i := len(labels) - 1; i >= 0; i-- {
|
||||
if node.isEnd {
|
||||
return true
|
||||
}
|
||||
|
||||
if node.leaf == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
label := labels[i]
|
||||
if node = node.leaf[label]; node == nil {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return node.isEnd
|
||||
}
|
||||
|
||||
func normalizeDomain(domain string) string {
|
||||
domain = strings.ToLower(strings.TrimSpace(domain))
|
||||
if domain == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Remove trailing dot if present, dns.Fqdn will add it back properly
|
||||
domain = strings.TrimSuffix(domain, ".")
|
||||
|
||||
if domain == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
return dns.Fqdn(domain)
|
||||
}
|
@@ -1,276 +0,0 @@
|
||||
package netutil
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestDomainList(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
domains []string
|
||||
hostname string
|
||||
expected bool
|
||||
}{
|
||||
// Basic exact matches
|
||||
{
|
||||
name: "exact match",
|
||||
domains: []string{"example.com"},
|
||||
hostname: "example.com",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "exact match with subdomain in list",
|
||||
domains: []string{"api.example.com"},
|
||||
hostname: "api.example.com",
|
||||
expected: true,
|
||||
},
|
||||
|
||||
// Suffix matching - if domain is in list, all subdomains should match
|
||||
{
|
||||
name: "subdomain matches parent domain",
|
||||
domains: []string{"example.com"},
|
||||
hostname: "sub.example.com",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "multiple subdomain levels match",
|
||||
domains: []string{"example.com"},
|
||||
hostname: "deep.nested.sub.example.com",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "subdomain matches intermediate domain",
|
||||
domains: []string{"api.example.com", "example.com"},
|
||||
hostname: "sub.api.example.com",
|
||||
expected: true,
|
||||
},
|
||||
|
||||
// Multi-level TLDs
|
||||
{
|
||||
name: "co.uk domain exact match",
|
||||
domains: []string{"domain.co.uk"},
|
||||
hostname: "domain.co.uk",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "subdomain of co.uk domain",
|
||||
domains: []string{"domain.co.uk"},
|
||||
hostname: "sub.domain.co.uk",
|
||||
expected: true,
|
||||
},
|
||||
|
||||
// Case sensitivity
|
||||
{
|
||||
name: "case insensitive match",
|
||||
domains: []string{"Example.COM"},
|
||||
hostname: "example.com",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "case insensitive hostname",
|
||||
domains: []string{"example.com"},
|
||||
hostname: "EXAMPLE.COM",
|
||||
expected: true,
|
||||
},
|
||||
|
||||
// Trailing dots
|
||||
{
|
||||
name: "domain with trailing dot",
|
||||
domains: []string{"example.com."},
|
||||
hostname: "example.com",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "hostname with trailing dot",
|
||||
domains: []string{"example.com"},
|
||||
hostname: "example.com.",
|
||||
expected: true,
|
||||
},
|
||||
|
||||
// Non-matches
|
||||
{
|
||||
name: "different TLD",
|
||||
domains: []string{"example.com"},
|
||||
hostname: "example.org",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "different domain",
|
||||
domains: []string{"example.com"},
|
||||
hostname: "test.com",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "partial match but not suffix",
|
||||
domains: []string{"example.com"},
|
||||
hostname: "com",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "empty hostname",
|
||||
domains: []string{"example.com"},
|
||||
hostname: "",
|
||||
expected: false,
|
||||
},
|
||||
|
||||
// Multiple domains in list
|
||||
{
|
||||
name: "matches first domain in list",
|
||||
domains: []string{"test.org", "example.com"},
|
||||
hostname: "example.com",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "matches second domain in list",
|
||||
domains: []string{"test.org", "example.com"},
|
||||
hostname: "test.org",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "subdomain matches any domain in list",
|
||||
domains: []string{"test.org", "example.com"},
|
||||
hostname: "sub.example.com",
|
||||
expected: true,
|
||||
},
|
||||
|
||||
// Edge cases
|
||||
{
|
||||
name: "empty domain list",
|
||||
domains: []string{},
|
||||
hostname: "example.com",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "invalid domain in list",
|
||||
domains: []string{""},
|
||||
hostname: "example.com",
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
list := NewDomainList(tt.domains...)
|
||||
result := list.Contains(tt.hostname)
|
||||
|
||||
if result != tt.expected {
|
||||
t.Errorf("Contains(%q) = %v, expected %v (domains: %v)",
|
||||
tt.hostname, result, tt.expected, tt.domains)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDomainList_Performance(t *testing.T) {
|
||||
// Test with a large number of domains to ensure performance
|
||||
domains := make([]string, 1000)
|
||||
for i := 0; i < 1000; i++ {
|
||||
domains[i] = string(rune('a'+(i%26))) + ".com"
|
||||
}
|
||||
domains = append(domains, "example.com") // Add our test domain
|
||||
|
||||
list := NewDomainList(domains...)
|
||||
|
||||
// These should be fast even with many domains
|
||||
if !list.Contains("example.com") {
|
||||
t.Error("Should match exact domain")
|
||||
}
|
||||
if !list.Contains("sub.example.com") {
|
||||
t.Error("Should match subdomain")
|
||||
}
|
||||
if list.Contains("notfound.com") {
|
||||
t.Error("Should not match unrelated domain")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDomainList_ComplexDomains(t *testing.T) {
|
||||
domains := []string{
|
||||
"very.long.domain.name.with.many.labels.com",
|
||||
"example.co.uk",
|
||||
"sub.domain.example.com",
|
||||
"a.b.c.d.e.f.com",
|
||||
}
|
||||
|
||||
list := NewDomainList(domains...)
|
||||
|
||||
tests := []struct {
|
||||
hostname string
|
||||
expected bool
|
||||
}{
|
||||
{"very.long.domain.name.with.many.labels.com", true},
|
||||
{"sub.very.long.domain.name.with.many.labels.com", true},
|
||||
{"example.co.uk", true},
|
||||
{"www.example.co.uk", true},
|
||||
{"sub.domain.example.com", true},
|
||||
{"another.sub.domain.example.com", true},
|
||||
{"a.b.c.d.e.f.com", true},
|
||||
{"x.a.b.c.d.e.f.com", true},
|
||||
{"not.matching.com", false},
|
||||
{"com", false},
|
||||
{"uk", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.hostname, func(t *testing.T) {
|
||||
result := list.Contains(tt.hostname)
|
||||
if result != tt.expected {
|
||||
t.Errorf("Contains(%q) = %v, expected %v", tt.hostname, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDomainList_SpecialCases(t *testing.T) {
|
||||
t.Run("domain with asterisk treated literally", func(t *testing.T) {
|
||||
list := NewDomainList("*.example.com")
|
||||
|
||||
// The asterisk should be treated as a literal label, not a wildcard
|
||||
if !list.Contains("*.example.com") {
|
||||
t.Error("Asterisk should be treated literally, not as wildcard")
|
||||
}
|
||||
if list.Contains("test.example.com") {
|
||||
t.Error("Should not match subdomain with literal asterisk domain")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("domains with hyphens and numbers", func(t *testing.T) {
|
||||
list := NewDomainList("test-123.example.com", "123abc.org")
|
||||
|
||||
if !list.Contains("test-123.example.com") {
|
||||
t.Error("Should match domain with hyphens and numbers")
|
||||
}
|
||||
if !list.Contains("sub.test-123.example.com") {
|
||||
t.Error("Should match subdomain of hyphenated domain")
|
||||
}
|
||||
if !list.Contains("123abc.org") {
|
||||
t.Error("Should match domain starting with numbers")
|
||||
}
|
||||
if !list.Contains("www.123abc.org") {
|
||||
t.Error("Should match subdomain of numeric domain")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkDomainList(b *testing.B) {
|
||||
// Benchmark with realistic domain list
|
||||
domains := []string{
|
||||
"google.com",
|
||||
"github.com",
|
||||
"example.org",
|
||||
"sub.domain.com",
|
||||
"api.service.co.uk",
|
||||
"very.long.domain.name.example.com",
|
||||
}
|
||||
|
||||
list := NewDomainList(domains...)
|
||||
|
||||
b.ResetTimer()
|
||||
for b.Loop() {
|
||||
// Mix of matches and non-matches
|
||||
list.Contains("sub.example.org")
|
||||
list.Contains("api.github.com")
|
||||
list.Contains("nonexistent.com")
|
||||
list.Contains("deep.nested.sub.domain.com")
|
||||
list.Contains("service.co.uk")
|
||||
}
|
||||
}
|
70
internal/sliceutil/filter.go
Normal file
70
internal/sliceutil/filter.go
Normal file
@@ -0,0 +1,70 @@
|
||||
package sliceutil
|
||||
|
||||
// AppendFilter takes two slice, 's' as source and 'd' as destination
|
||||
// and a predicate function, then applies it to each element of 's',
|
||||
// when 'p' returns true it appends the element into d, otherwise omit it.
|
||||
func AppendFilter[T any](s []T, d *[]T, p func(T) bool) {
|
||||
for _, e := range s {
|
||||
if p(e) {
|
||||
*d = append(*d, e)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// AssignFilter takes two slice, 's' as source and 'd' as destination
|
||||
// and a predicate function, then applies it to each element of 's'.
|
||||
// 'd' slice will have the same capacity of 's' but starts with 0 length.
|
||||
// When 'p' returns true it appends the element into d, otherwise omit it.
|
||||
func AssignFilter[T any](s []T, d *[]T, p func(T) bool) {
|
||||
if cap(*d) == 0 {
|
||||
*d = make([]T, 0, len(s))
|
||||
}
|
||||
for _, e := range s {
|
||||
if p(e) {
|
||||
*d = append(*d, e)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// DisposeFilter takes two slice, 's' as source and 'd' as destination
|
||||
// and a predicate function, then applies it to each element of 's'.
|
||||
// 'd' slice will share the exact same memory address of 's'.
|
||||
// When 'p' returns true it appends the element into d, otherwise omit it.
|
||||
// Then disposes the 's'. IMPORTANT: 's' cannot be used again.
|
||||
func DisposeFilter[T any](s []T, d *[]T, p func(T) bool) {
|
||||
*d = s[:0]
|
||||
for _, e := range s {
|
||||
if p(e) {
|
||||
*d = append(*d, e)
|
||||
}
|
||||
}
|
||||
var NIL T
|
||||
for i := len(*d); i < len(s); i++ {
|
||||
s[i] = NIL
|
||||
}
|
||||
}
|
||||
|
||||
// InPlaceFilter takes a slice and a predicate function, then applies it to each element of 's'.
|
||||
// When 'p' returns true it assign the value to the last index plus one where p was true, otherwise omit it.
|
||||
func InPlaceFilter[T any](s *[]T, p func(T) bool) {
|
||||
i := 0
|
||||
for _, e := range *s {
|
||||
if p(e) {
|
||||
(*s)[i] = e
|
||||
i++
|
||||
}
|
||||
}
|
||||
*s = (*s)[:i]
|
||||
}
|
||||
|
||||
// Filter takes a slice and a predcate function, then applies it to each element of 's'.
|
||||
// When 'p' returns true it appends the value to the output slice.
|
||||
func Filter[T any](s []T, p func(T) bool) []T {
|
||||
o := make([]T, 0, len(s))
|
||||
for _, e := range s {
|
||||
if p(e) {
|
||||
o = append(o, e)
|
||||
}
|
||||
}
|
||||
return o
|
||||
}
|
Reference in New Issue
Block a user