Checkpoint

This commit is contained in:
2025-10-01 15:37:55 +02:00
parent 4a60059ff2
commit 03352e3312
31 changed files with 2611 additions and 384 deletions

View File

@@ -29,6 +29,10 @@ func Port(name string) int {
return 0
}
if i, err := net.LookupPort("tcp", port); err == nil {
return i
}
// TODO: name resolution for ports?
i, _ := strconv.Atoi(port)
return i

View File

@@ -0,0 +1,63 @@
package arp
import (
"net"
"sync"
"time"
"github.com/sirupsen/logrus"
)
func init() {
go func() {
t := time.NewTicker(time.Second * 5)
for {
refresh()
<-t.C
}
}()
}
var table sync.Map
func refresh() {
t, err := lookup()
if err != nil {
logrus.StandardLogger().WithError(err).Warn("arp cache refresh failed")
} else {
for k, v := range t {
logrus.StandardLogger().WithFields(logrus.Fields{
"mac": v,
"ip": k,
}).Debug("Updating ARP cache")
table.Store(k, v)
}
}
}
func Get(addr net.Addr) net.HardwareAddr {
if addr == nil {
logrus.StandardLogger().Trace("No address found, can't lookup IP for MAC")
return nil
}
var ip net.IP
switch addr := addr.(type) {
case *net.TCPAddr:
ip = addr.IP
case *net.UDPAddr:
ip = addr.IP
}
if ip == nil {
logrus.StandardLogger().WithField("addr", addr.String()).Trace("No IP address found, can't lookup MAC")
return nil
}
if v, ok := table.Load(ip.String()); ok {
logrus.StandardLogger().WithField("ip", ip.String()).Tracef("%s is at %s", ip, v.(net.HardwareAddr).String())
return v.(net.HardwareAddr)
}
logrus.StandardLogger().WithField("ip", ip.String()).Trace("No MAC address found")
return nil
}

View File

@@ -0,0 +1,32 @@
package arp
import (
"bufio"
"net"
"os"
"strings"
)
func lookup() (map[string]net.HardwareAddr, error) {
f, err := os.Open("/proc/net/arp")
if err != nil {
return nil, err
}
defer func() { _ = f.Close() }()
t := make(map[string]net.HardwareAddr)
s := bufio.NewScanner(f)
for i := 0; s.Scan(); i++ {
if i == 0 {
continue
}
line := strings.Fields(s.Text())
if len(line) < 4 {
continue
}
if mac, err := net.ParseMAC(line[3]); err == nil {
t[line[0]] = mac
}
}
return t, nil
}

View File

@@ -0,0 +1,37 @@
//go:build !linux
// +build !linux
// ^ Linux isn't Unix anyway :P
package arp
import (
"net"
"os/exec"
"strings"
)
func lookup() (map[string]net.HardwareAddr, error) {
data, err := exec.Command("arp", "-an").Output()
if err != nil {
return nil, err
}
t := make(map[string]net.HardwareAddr)
for _, line := range strings.Split(string(data), "\n") {
fields := strings.Fields(line)
if len(fields) < 3 {
continue
}
// strip brackets around IP
ip := strings.ReplaceAll(fields[1], "(", "")
ip = strings.ReplaceAll(ip, ")", "")
if mac, err := net.ParseMAC(fields[3]); err == nil {
t[ip] = mac
}
}
return t, nil
}

93
internal/netutil/conn.go Normal file
View File

@@ -0,0 +1,93 @@
package netutil
import (
"bufio"
"errors"
"io"
"net"
"syscall"
"time"
)
// BufferedConn uses byte buffers for Read and Write operations on a [net.Conn].
type BufferedConn struct {
net.Conn
Reader *bufio.Reader
Writer *bufio.Writer
}
func NewBufferedConn(c net.Conn) *BufferedConn {
if b, ok := c.(*BufferedConn); ok {
return b
}
return &BufferedConn{
Conn: c,
Reader: bufio.NewReader(c),
Writer: bufio.NewWriter(c),
}
}
func (conn BufferedConn) Read(p []byte) (int, error) { return conn.Reader.Read(p) }
func (conn BufferedConn) Write(p []byte) (int, error) { return conn.Writer.Write(p) }
func (conn BufferedConn) Flush() error { return conn.Writer.Flush() }
func (conn BufferedConn) NetConn() net.Conn { return conn.Conn }
// ReaderConn is a [net.Conn] with a separate [io.Reader] to read from.
type ReaderConn struct {
net.Conn
io.Reader
}
func (conn ReaderConn) Read(p []byte) (int, error) { return conn.Reader.Read(p) }
func (conn ReaderConn) NetConn() net.Conn { return conn.Conn }
// ReadOnlyConn only allows reading, all other operations will fail.
type ReadOnlyConn struct {
io.Reader
}
func (conn ReadOnlyConn) Read(p []byte) (int, error) { return conn.Reader.Read(p) }
func (conn ReadOnlyConn) Write(p []byte) (int, error) { return 0, io.ErrClosedPipe }
func (conn ReadOnlyConn) Close() error { return nil }
func (conn ReadOnlyConn) LocalAddr() net.Addr { return nil }
func (conn ReadOnlyConn) RemoteAddr() net.Addr { return nil }
func (conn ReadOnlyConn) SetDeadline(t time.Time) error { return nil }
func (conn ReadOnlyConn) SetReadDeadline(t time.Time) error { return nil }
func (conn ReadOnlyConn) SetWriteDeadline(t time.Time) error { return nil }
func (conn ReadOnlyConn) NetConn() net.Conn {
if c, ok := conn.Reader.(net.Conn); ok {
return c
}
return nil
}
func IsClosing(err error) bool {
if errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) || errors.Is(err, syscall.ECONNRESET) || err.Error() != "proxy: shutdown" {
return true
}
if err, ok := err.(net.Error); ok && err.Timeout() {
return true
}
// log.Debug().Msgf("not a closing error %T: %#+v", err, err)
return false
}
// WithTimeout is a convenience wrapper for doing network operations that observe a timeout.
func WithTimeout(c net.Conn, timeout time.Duration, do func() error) error {
if timeout <= 0 {
return do()
}
if err := c.SetDeadline(time.Now().Add(timeout)); err != nil {
return err
}
if err := do(); err != nil {
_ = c.SetDeadline(time.Time{})
return err
}
if err := c.SetDeadline(time.Time{}); err != nil {
return err
}
return nil
}

View File

@@ -1,99 +0,0 @@
package netutil
import (
"strings"
"github.com/miekg/dns"
)
type DomainTree struct {
root *domainTreeNode
}
type domainTreeNode struct {
leaf map[string]*domainTreeNode
isEnd bool
}
func NewDomainList(domains ...string) *DomainTree {
tree := &DomainTree{
root: &domainTreeNode{leaf: make(map[string]*domainTreeNode)},
}
for _, domain := range domains {
tree.Add(domain)
}
return tree
}
func (tree *DomainTree) Add(domain string) {
domain = normalizeDomain(domain)
if domain == "" {
return
}
labels := dns.SplitDomainName(domain)
if len(labels) == 0 {
return
}
node := tree.root
for i := len(labels) - 1; i >= 0; i-- {
label := labels[i]
if label == "" {
continue
}
if node.leaf == nil {
node.leaf = make(map[string]*domainTreeNode)
}
if node.leaf[label] == nil {
node.leaf[label] = &domainTreeNode{}
}
node = node.leaf[label]
}
node.isEnd = true
}
func (tree *DomainTree) Contains(domain string) bool {
domain = normalizeDomain(domain)
if domain == "" {
return false
}
labels := dns.SplitDomainName(domain)
if len(labels) == 0 {
return false
}
node := tree.root
for i := len(labels) - 1; i >= 0; i-- {
if node.isEnd {
return true
}
if node.leaf == nil {
return false
}
label := labels[i]
if node = node.leaf[label]; node == nil {
return false
}
}
return node.isEnd
}
func normalizeDomain(domain string) string {
domain = strings.ToLower(strings.TrimSpace(domain))
if domain == "" {
return ""
}
// Remove trailing dot if present, dns.Fqdn will add it back properly
domain = strings.TrimSuffix(domain, ".")
if domain == "" {
return ""
}
return dns.Fqdn(domain)
}

View File

@@ -1,276 +0,0 @@
package netutil
import (
"testing"
)
func TestDomainList(t *testing.T) {
tests := []struct {
name string
domains []string
hostname string
expected bool
}{
// Basic exact matches
{
name: "exact match",
domains: []string{"example.com"},
hostname: "example.com",
expected: true,
},
{
name: "exact match with subdomain in list",
domains: []string{"api.example.com"},
hostname: "api.example.com",
expected: true,
},
// Suffix matching - if domain is in list, all subdomains should match
{
name: "subdomain matches parent domain",
domains: []string{"example.com"},
hostname: "sub.example.com",
expected: true,
},
{
name: "multiple subdomain levels match",
domains: []string{"example.com"},
hostname: "deep.nested.sub.example.com",
expected: true,
},
{
name: "subdomain matches intermediate domain",
domains: []string{"api.example.com", "example.com"},
hostname: "sub.api.example.com",
expected: true,
},
// Multi-level TLDs
{
name: "co.uk domain exact match",
domains: []string{"domain.co.uk"},
hostname: "domain.co.uk",
expected: true,
},
{
name: "subdomain of co.uk domain",
domains: []string{"domain.co.uk"},
hostname: "sub.domain.co.uk",
expected: true,
},
// Case sensitivity
{
name: "case insensitive match",
domains: []string{"Example.COM"},
hostname: "example.com",
expected: true,
},
{
name: "case insensitive hostname",
domains: []string{"example.com"},
hostname: "EXAMPLE.COM",
expected: true,
},
// Trailing dots
{
name: "domain with trailing dot",
domains: []string{"example.com."},
hostname: "example.com",
expected: true,
},
{
name: "hostname with trailing dot",
domains: []string{"example.com"},
hostname: "example.com.",
expected: true,
},
// Non-matches
{
name: "different TLD",
domains: []string{"example.com"},
hostname: "example.org",
expected: false,
},
{
name: "different domain",
domains: []string{"example.com"},
hostname: "test.com",
expected: false,
},
{
name: "partial match but not suffix",
domains: []string{"example.com"},
hostname: "com",
expected: false,
},
{
name: "empty hostname",
domains: []string{"example.com"},
hostname: "",
expected: false,
},
// Multiple domains in list
{
name: "matches first domain in list",
domains: []string{"test.org", "example.com"},
hostname: "example.com",
expected: true,
},
{
name: "matches second domain in list",
domains: []string{"test.org", "example.com"},
hostname: "test.org",
expected: true,
},
{
name: "subdomain matches any domain in list",
domains: []string{"test.org", "example.com"},
hostname: "sub.example.com",
expected: true,
},
// Edge cases
{
name: "empty domain list",
domains: []string{},
hostname: "example.com",
expected: false,
},
{
name: "invalid domain in list",
domains: []string{""},
hostname: "example.com",
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
list := NewDomainList(tt.domains...)
result := list.Contains(tt.hostname)
if result != tt.expected {
t.Errorf("Contains(%q) = %v, expected %v (domains: %v)",
tt.hostname, result, tt.expected, tt.domains)
}
})
}
}
func TestDomainList_Performance(t *testing.T) {
// Test with a large number of domains to ensure performance
domains := make([]string, 1000)
for i := 0; i < 1000; i++ {
domains[i] = string(rune('a'+(i%26))) + ".com"
}
domains = append(domains, "example.com") // Add our test domain
list := NewDomainList(domains...)
// These should be fast even with many domains
if !list.Contains("example.com") {
t.Error("Should match exact domain")
}
if !list.Contains("sub.example.com") {
t.Error("Should match subdomain")
}
if list.Contains("notfound.com") {
t.Error("Should not match unrelated domain")
}
}
func TestDomainList_ComplexDomains(t *testing.T) {
domains := []string{
"very.long.domain.name.with.many.labels.com",
"example.co.uk",
"sub.domain.example.com",
"a.b.c.d.e.f.com",
}
list := NewDomainList(domains...)
tests := []struct {
hostname string
expected bool
}{
{"very.long.domain.name.with.many.labels.com", true},
{"sub.very.long.domain.name.with.many.labels.com", true},
{"example.co.uk", true},
{"www.example.co.uk", true},
{"sub.domain.example.com", true},
{"another.sub.domain.example.com", true},
{"a.b.c.d.e.f.com", true},
{"x.a.b.c.d.e.f.com", true},
{"not.matching.com", false},
{"com", false},
{"uk", false},
}
for _, tt := range tests {
t.Run(tt.hostname, func(t *testing.T) {
result := list.Contains(tt.hostname)
if result != tt.expected {
t.Errorf("Contains(%q) = %v, expected %v", tt.hostname, result, tt.expected)
}
})
}
}
func TestDomainList_SpecialCases(t *testing.T) {
t.Run("domain with asterisk treated literally", func(t *testing.T) {
list := NewDomainList("*.example.com")
// The asterisk should be treated as a literal label, not a wildcard
if !list.Contains("*.example.com") {
t.Error("Asterisk should be treated literally, not as wildcard")
}
if list.Contains("test.example.com") {
t.Error("Should not match subdomain with literal asterisk domain")
}
})
t.Run("domains with hyphens and numbers", func(t *testing.T) {
list := NewDomainList("test-123.example.com", "123abc.org")
if !list.Contains("test-123.example.com") {
t.Error("Should match domain with hyphens and numbers")
}
if !list.Contains("sub.test-123.example.com") {
t.Error("Should match subdomain of hyphenated domain")
}
if !list.Contains("123abc.org") {
t.Error("Should match domain starting with numbers")
}
if !list.Contains("www.123abc.org") {
t.Error("Should match subdomain of numeric domain")
}
})
}
func BenchmarkDomainList(b *testing.B) {
// Benchmark with realistic domain list
domains := []string{
"google.com",
"github.com",
"example.org",
"sub.domain.com",
"api.service.co.uk",
"very.long.domain.name.example.com",
}
list := NewDomainList(domains...)
b.ResetTimer()
for b.Loop() {
// Mix of matches and non-matches
list.Contains("sub.example.org")
list.Contains("api.github.com")
list.Contains("nonexistent.com")
list.Contains("deep.nested.sub.domain.com")
list.Contains("service.co.uk")
}
}