// Package resolver implements a caching DNS resolver package resolver import ( "context" "math/rand/v2" "net" "strings" "time" "git.maze.io/maze/styx/internal/netutil" "github.com/hashicorp/golang-lru/v2/expirable" ) const ( DefaultSize = 1024 DefaultTTL = 5 * time.Minute DefaultTimeout = 10 * time.Second ) var ( // DefaultConfig are the defaults for the Default resolver. DefaultConfig = Config{ Size: DefaultSize, TTL: DefaultTTL.Seconds(), Timeout: DefaultTimeout.Seconds(), } // Default resolver. Default = New(DefaultConfig) ) type Resolver interface { // Lookup returns resolved IPs for given hostname/ips. Lookup(context.Context, string) ([]string, error) } type netResolver struct { resolver *net.Resolver timeout time.Duration noIPv6 bool cache *expirable.LRU[string, []string] } type Config struct { // Size is our cache size in number of entries. Size int `hcl:"size,optional"` // TTL is the cache time to live in seconds. TTL float64 `hcl:"ttl,optional"` // Timeout is the cache timeout in seconds. Timeout float64 `hcl:"timeout,optional"` // Server are alternative DNS servers. Server []string `hcl:"server,optional"` // NoIPv6 disables IPv6 DNS resolution. NoIPv6 bool `hcl:"noipv6,optional"` } func New(config Config) Resolver { var ( size = config.Size ttl = time.Duration(float64(time.Second) * config.TTL) timeout = time.Duration(float64(time.Second) * config.Timeout) ) if size <= 0 { size = DefaultSize } if ttl <= 0 { ttl = DefaultTTL } if timeout <= 0 { timeout = 0 } var resolver = new(net.Resolver) if len(config.Server) > 0 { var dialer net.Dialer resolver.Dial = func(ctx context.Context, network, address string) (net.Conn, error) { server := netutil.EnsurePort(config.Server[rand.IntN(len(config.Server))], "53") return dialer.DialContext(ctx, network, server) } } return &netResolver{ resolver: resolver, timeout: timeout, noIPv6: config.NoIPv6, cache: expirable.NewLRU[string, []string](size, nil, ttl), } } func (r *netResolver) Lookup(ctx context.Context, host string) ([]string, error) { host = strings.ToLower(strings.TrimSpace(host)) if hosts, ok := r.cache.Get(host); ok { rand.Shuffle(len(hosts), func(i, j int) { hosts[i], hosts[j] = hosts[j], hosts[i] }) return hosts, nil } hosts, err := r.lookup(ctx, host) if err != nil { return nil, err } r.cache.Add(host, hosts) return hosts, nil } func (r *netResolver) lookup(ctx context.Context, host string) ([]string, error) { if r.timeout > 0 { var cancel func() ctx, cancel = context.WithTimeout(ctx, r.timeout) defer cancel() } if net.ParseIP(host) == nil { addrs, err := r.resolver.LookupHost(ctx, host) if err != nil { return nil, err } if r.noIPv6 { var addrs4 []string for _, addr := range addrs { if net.ParseIP(addr).To4() != nil { addrs4 = append(addrs4, addr) } } return addrs4, nil } return addrs, nil } addrs, err := r.resolver.LookupIPAddr(ctx, host) if err != nil { return nil, err } hosts := make([]string, len(addrs)) for i, addr := range addrs { if !r.noIPv6 || addr.IP.To4() != nil { hosts[i] = addr.IP.String() } } return hosts, nil }