149 lines
3.2 KiB
Go
149 lines
3.2 KiB
Go
// Package resolver implements a caching DNS resolver.
|
|
package resolver
|
|
|
|
import (
|
|
"context"
|
|
"math/rand/v2"
|
|
"net"
|
|
"strings"
|
|
"time"
|
|
|
|
"git.maze.io/maze/styx/internal/netutil"
|
|
"github.com/hashicorp/golang-lru/v2/expirable"
|
|
)
|
|
|
|
// Defaults applied by New when the corresponding Config value is not set.
const (
	// DefaultSize is the default number of entries kept in the cache.
	DefaultSize = 1024

	// DefaultTTL is the default time a cached lookup result stays valid.
	DefaultTTL = 5 * time.Minute

	// DefaultTimeout is the lookup timeout used by DefaultConfig.
	DefaultTimeout = 10 * time.Second
)
|
|
|
|
var (
|
|
// DefaultConfig are the defaults for the Default resolver.
|
|
DefaultConfig = Config{
|
|
Size: DefaultSize,
|
|
TTL: DefaultTTL.Seconds(),
|
|
Timeout: DefaultTimeout.Seconds(),
|
|
}
|
|
|
|
// Default resolver.
|
|
Default = New(DefaultConfig)
|
|
)
|
|
|
|
// Resolver resolves host names (or IP literals) to IP address strings.
type Resolver interface {
	// Lookup returns the resolved IPs for the given hostname or IP.
	Lookup(ctx context.Context, host string) ([]string, error)
}
|
|
|
|
type netResolver struct {
|
|
resolver *net.Resolver
|
|
timeout time.Duration
|
|
noIPv6 bool
|
|
cache *expirable.LRU[string, []string]
|
|
}
|
|
|
|
// Config holds the settings accepted by New.
type Config struct {
	// Size is the cache size in number of entries. New uses DefaultSize
	// when it is zero or negative.
	Size int `hcl:"size,optional"`

	// TTL is the time to live of cached entries, in seconds. New uses
	// DefaultTTL when it is zero or negative.
	TTL float64 `hcl:"ttl,optional"`

	// Timeout is the maximum duration of a single DNS lookup, in seconds.
	// A zero or negative value means no lookup timeout is applied.
	Timeout float64 `hcl:"timeout,optional"`

	// Server are alternative DNS servers. When set, one of them is picked
	// at random for every DNS connection.
	Server []string `hcl:"server,optional"`

	// NoIPv6 disables IPv6 DNS resolution.
	NoIPv6 bool `hcl:"noipv6,optional"`
}
|
|
|
|
func New(config Config) Resolver {
|
|
var (
|
|
size = config.Size
|
|
ttl = time.Duration(float64(time.Second) * config.TTL)
|
|
timeout = time.Duration(float64(time.Second) * config.Timeout)
|
|
)
|
|
if size <= 0 {
|
|
size = DefaultSize
|
|
}
|
|
if ttl <= 0 {
|
|
ttl = DefaultTTL
|
|
}
|
|
if timeout <= 0 {
|
|
timeout = 0
|
|
}
|
|
|
|
var resolver = new(net.Resolver)
|
|
if len(config.Server) > 0 {
|
|
var dialer net.Dialer
|
|
resolver.Dial = func(ctx context.Context, network, address string) (net.Conn, error) {
|
|
server := netutil.EnsurePort(config.Server[rand.IntN(len(config.Server))], "53")
|
|
return dialer.DialContext(ctx, network, server)
|
|
}
|
|
}
|
|
|
|
return &netResolver{
|
|
resolver: resolver,
|
|
timeout: timeout,
|
|
noIPv6: config.NoIPv6,
|
|
cache: expirable.NewLRU[string, []string](size, nil, ttl),
|
|
}
|
|
}
|
|
|
|
func (r *netResolver) Lookup(ctx context.Context, host string) ([]string, error) {
|
|
host = strings.ToLower(strings.TrimSpace(host))
|
|
if hosts, ok := r.cache.Get(host); ok {
|
|
rand.Shuffle(len(hosts), func(i, j int) {
|
|
hosts[i], hosts[j] = hosts[j], hosts[i]
|
|
})
|
|
return hosts, nil
|
|
}
|
|
|
|
hosts, err := r.lookup(ctx, host)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
r.cache.Add(host, hosts)
|
|
return hosts, nil
|
|
}
|
|
|
|
func (r *netResolver) lookup(ctx context.Context, host string) ([]string, error) {
|
|
if r.timeout > 0 {
|
|
var cancel func()
|
|
ctx, cancel = context.WithTimeout(ctx, r.timeout)
|
|
defer cancel()
|
|
}
|
|
|
|
if net.ParseIP(host) == nil {
|
|
addrs, err := r.resolver.LookupHost(ctx, host)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if r.noIPv6 {
|
|
var addrs4 []string
|
|
for _, addr := range addrs {
|
|
if net.ParseIP(addr).To4() != nil {
|
|
addrs4 = append(addrs4, addr)
|
|
}
|
|
}
|
|
return addrs4, nil
|
|
}
|
|
return addrs, nil
|
|
}
|
|
|
|
addrs, err := r.resolver.LookupIPAddr(ctx, host)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
hosts := make([]string, len(addrs))
|
|
for i, addr := range addrs {
|
|
if !r.noIPv6 || addr.IP.To4() != nil {
|
|
hosts[i] = addr.IP.String()
|
|
}
|
|
}
|
|
return hosts, nil
|
|
}
|