Initial import
This commit is contained in:
148
proxy/resolver/resolver.go
Normal file
148
proxy/resolver/resolver.go
Normal file
@@ -0,0 +1,148 @@
|
||||
// Package resolver implements a caching DNS resolver
|
||||
package resolver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"math/rand/v2"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"git.maze.io/maze/styx/internal/netutil"
|
||||
"github.com/hashicorp/golang-lru/v2/expirable"
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultSize = 1024
|
||||
DefaultTTL = 5 * time.Minute
|
||||
DefaultTimeout = 10 * time.Second
|
||||
)
|
||||
|
||||
var (
|
||||
// DefaultConfig are the defaults for the Default resolver.
|
||||
DefaultConfig = Config{
|
||||
Size: DefaultSize,
|
||||
TTL: DefaultTTL.Seconds(),
|
||||
Timeout: DefaultTimeout.Seconds(),
|
||||
}
|
||||
|
||||
// Default resolver.
|
||||
Default = New(DefaultConfig)
|
||||
)
|
||||
|
||||
type Resolver interface {
|
||||
// Lookup returns resolved IPs for given hostname/ips.
|
||||
Lookup(context.Context, string) ([]string, error)
|
||||
}
|
||||
|
||||
type netResolver struct {
|
||||
resolver *net.Resolver
|
||||
timeout time.Duration
|
||||
noIPv6 bool
|
||||
cache *expirable.LRU[string, []string]
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
// Size is our cache size in number of entries.
|
||||
Size int `hcl:"size,optional"`
|
||||
|
||||
// TTL is the cache time to live in seconds.
|
||||
TTL float64 `hcl:"ttl,optional"`
|
||||
|
||||
// Timeout is the cache timeout in seconds.
|
||||
Timeout float64 `hcl:"timeout,optional"`
|
||||
|
||||
// Server are alternative DNS servers.
|
||||
Server []string `hcl:"server,optional"`
|
||||
|
||||
// NoIPv6 disables IPv6 DNS resolution.
|
||||
NoIPv6 bool `hcl:"noipv6,optional"`
|
||||
}
|
||||
|
||||
func New(config Config) Resolver {
|
||||
var (
|
||||
size = config.Size
|
||||
ttl = time.Duration(float64(time.Second) * config.TTL)
|
||||
timeout = time.Duration(float64(time.Second) * config.Timeout)
|
||||
)
|
||||
if size <= 0 {
|
||||
size = DefaultSize
|
||||
}
|
||||
if ttl <= 0 {
|
||||
ttl = DefaultTTL
|
||||
}
|
||||
if timeout <= 0 {
|
||||
timeout = 0
|
||||
}
|
||||
|
||||
var resolver = new(net.Resolver)
|
||||
if len(config.Server) > 0 {
|
||||
var dialer net.Dialer
|
||||
resolver.Dial = func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
server := netutil.EnsurePort(config.Server[rand.IntN(len(config.Server))], "53")
|
||||
return dialer.DialContext(ctx, network, server)
|
||||
}
|
||||
}
|
||||
|
||||
return &netResolver{
|
||||
resolver: resolver,
|
||||
timeout: timeout,
|
||||
noIPv6: config.NoIPv6,
|
||||
cache: expirable.NewLRU[string, []string](size, nil, ttl),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *netResolver) Lookup(ctx context.Context, host string) ([]string, error) {
|
||||
host = strings.ToLower(strings.TrimSpace(host))
|
||||
if hosts, ok := r.cache.Get(host); ok {
|
||||
rand.Shuffle(len(hosts), func(i, j int) {
|
||||
hosts[i], hosts[j] = hosts[j], hosts[i]
|
||||
})
|
||||
return hosts, nil
|
||||
}
|
||||
|
||||
hosts, err := r.lookup(ctx, host)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
r.cache.Add(host, hosts)
|
||||
return hosts, nil
|
||||
}
|
||||
|
||||
func (r *netResolver) lookup(ctx context.Context, host string) ([]string, error) {
|
||||
if r.timeout > 0 {
|
||||
var cancel func()
|
||||
ctx, cancel = context.WithTimeout(ctx, r.timeout)
|
||||
defer cancel()
|
||||
}
|
||||
|
||||
if net.ParseIP(host) == nil {
|
||||
addrs, err := r.resolver.LookupHost(ctx, host)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if r.noIPv6 {
|
||||
var addrs4 []string
|
||||
for _, addr := range addrs {
|
||||
if net.ParseIP(addr).To4() != nil {
|
||||
addrs4 = append(addrs4, addr)
|
||||
}
|
||||
}
|
||||
return addrs4, nil
|
||||
}
|
||||
return addrs, nil
|
||||
}
|
||||
|
||||
addrs, err := r.resolver.LookupIPAddr(ctx, host)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
hosts := make([]string, len(addrs))
|
||||
for i, addr := range addrs {
|
||||
if !r.noIPv6 || addr.IP.To4() != nil {
|
||||
hosts[i] = addr.IP.String()
|
||||
}
|
||||
}
|
||||
return hosts, nil
|
||||
}
|
Reference in New Issue
Block a user