Toy DNS over HTTPS server. https://dns.maze.network/dns-query
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

277 lines
7.9 KiB

package doh
import (
"fmt"
"log"
"net"
"strconv"
"strings"
"time"
"github.com/miekg/dns"
)
// Response is the JSON-serializable form of a DNS response message, as
// produced by Marshal and consumed by Unmarshal.
type Response struct {
	// Status holds the full response code. Marshal folds the EDNS OPT
	// extended-RCODE bits into it above the 4-bit header RCODE, and Unmarshal
	// splits it back into those two parts.
	Status uint32 `json:"Status"` // Standard DNS response code (32 bit integer)
	TC bool `json:"TC"` // Whether the response is truncated
	RD bool `json:"RD"` // Recursion desired
	RA bool `json:"RA"` // Recursion available
	AD bool `json:"AD"` // Whether all response data was validated with DNSSEC
	CD bool `json:"CD"` // Whether the client asked to disable DNSSEC
	Question []Question `json:"Question"`
	Answer []RR `json:"Answer,omitempty"`
	Authority []RR `json:"Authority,omitempty"`
	Additional []RR `json:"Additional,omitempty"`
	Comment string `json:"Comment,omitempty"`
	// EdnsClientSubnet is an EDNS Client Subnet option rendered as
	// "address/N". Marshal writes the option's SourceScope as N, and
	// Unmarshal parses N back into SourceScope.
	EdnsClientSubnet string `json:"edns_client_subnet,omitempty"`
	// The remaining fields are never serialized. Marshal fills them from the
	// record with the smallest TTL across the answer, authority and
	// additional sections (the OPT pseudo-record is skipped).
	HaveTTL bool `json:"-"` // Whether LeastTTL and EarliestExpires have been set
	LeastTTL uint32 `json:"-"` // Least time-to-live
	EarliestExpires time.Time `json:"-"` // Expiry instant of the record with LeastTTL
}
// Question is the JSON form of one entry in a DNS question section.
type Question struct {
	Name string `json:"name"` // FQDN with trailing dot
	Type uint16 `json:"type"` // Standard DNS RR type
}
// RR is the JSON form of a single resource record. The embedded Question
// supplies the owner name and record type; its fields appear as the
// top-level "name" and "type" keys of the JSON object.
type RR struct {
	Question
	TTL uint32 `json:"TTL"` // Record's time-to-live in seconds
	Expires time.Time `json:"-"` // TTL in absolute time
	// ExpiresStr is Expires formatted with time.RFC1123. When it is non-empty,
	// unmarshalRR parses it and recomputes TTL from it.
	ExpiresStr string `json:"Expires"`
	Data string `json:"data"` // Record data as zone-file text (the part fed to dns.NewRR after name/TTL/class/type)
}
// Marshal converts a DNS message into its JSON-serializable Response form.
//
// All records are converted with marshalRR using one shared timestamp, so
// every Expires value is computed from the same instant. While converting the
// answer, authority and additional sections, Marshal tracks the record with
// the smallest TTL. It stores that TTL in LeastTTL and its expiry in
// EarliestExpires, and sets HaveTTL once any record has been seen.
//
// The OPT pseudo-record in the additional section is not emitted as a record.
// Instead, its extended-RCODE bits are folded into Status, and an EDNS Client
// Subnet option, if present, is rendered into EdnsClientSubnet as
// "address/scope".
func Marshal(msg *dns.Msg) *Response {
	now := time.Now().UTC()

	resp := new(Response)
	resp.Status = uint32(msg.Rcode)
	resp.TC = msg.Truncated
	resp.RD = msg.RecursionDesired
	resp.RA = msg.RecursionAvailable
	resp.AD = msg.AuthenticatedData
	resp.CD = msg.CheckingDisabled

	resp.Question = make([]Question, 0, len(msg.Question))
	for _, q := range msg.Question {
		resp.Question = append(resp.Question, Question{Name: q.Name, Type: q.Qtype})
	}

	// track converts rr and updates the least-TTL bookkeeping. Only a
	// strictly smaller TTL replaces the current minimum, so ties keep the
	// earliest record seen.
	track := func(rr dns.RR) RR {
		out := marshalRR(rr, now)
		if !resp.HaveTTL || out.TTL < resp.LeastTTL {
			resp.HaveTTL = true
			resp.LeastTTL = out.TTL
			resp.EarliestExpires = out.Expires
		}
		return out
	}

	resp.Answer = make([]RR, 0, len(msg.Answer))
	for _, rr := range msg.Answer {
		resp.Answer = append(resp.Answer, track(rr))
	}

	resp.Authority = make([]RR, 0, len(msg.Ns))
	for _, rr := range msg.Ns {
		resp.Authority = append(resp.Authority, track(rr))
	}

	resp.Additional = make([]RR, 0, len(msg.Extra))
	for _, rr := range msg.Extra {
		if rr.Header().Rrtype != dns.TypeOPT {
			resp.Additional = append(resp.Additional, track(rr))
			continue
		}

		// OPT carries EDNS metadata rather than data, so it is folded into
		// the response fields instead of being listed. The checked assertion
		// avoids a panic on hand-built messages whose OPT-typed header sits
		// on a different concrete RR type.
		opt, ok := rr.(*dns.OPT)
		if !ok {
			continue
		}

		// The top byte of the OPT TTL is EXTENDED-RCODE: the upper 8 bits of
		// the 12-bit response code (RFC 6891 §6.1.3). Shift it into bits 4-11.
		resp.Status = ((opt.Hdr.Ttl & 0xff000000) >> 20) | (resp.Status & 0xff)

		for _, option := range opt.Option {
			subnet, ok := option.(*dns.EDNS0_SUBNET)
			if !ok {
				continue
			}
			clientAddress := subnet.Address
			if clientAddress == nil {
				clientAddress = net.IP{0, 0, 0, 0}
			} else if ipv4 := clientAddress.To4(); ipv4 != nil {
				// Prefer the dotted-quad rendering for IPv4 addresses.
				clientAddress = ipv4
			}
			resp.EdnsClientSubnet = clientAddress.String() + "/" + strconv.FormatUint(uint64(subnet.SourceScope), 10)
		}
	}
	return resp
}
// marshalRR converts one resource record into its JSON form.
//
// The owner name, type and TTL are taken from the RR header. Expires is
// now + TTL, and ExpiresStr is the same instant formatted as time.RFC1123.
// Data is the fifth tab-separated field of rr.String(), i.e. everything after
// the fourth tab. It is left empty when the string has fewer than five fields.
func marshalRR(rr dns.RR, now time.Time) (out RR) {
	hdr := rr.Header()
	expires := now.Add(time.Second * time.Duration(hdr.Ttl))

	out.Name = hdr.Name
	out.Type = hdr.Rrtype
	out.TTL = hdr.Ttl
	out.Expires = expires
	out.ExpiresStr = expires.Format(time.RFC1123)

	// Presumably rr.String() is miekg/dns presentation format:
	// name<TAB>ttl<TAB>class<TAB>type<TAB>rdata. Keep only the rdata part.
	if fields := strings.SplitN(rr.String(), "\t", 5); len(fields) == 5 {
		out.Data = fields[4]
	}
	return out
}
// Unmarshal builds a DNS reply from a JSON Response.
//
// msg is copied (msg.Copy) as the template, so its header fields and question
// section carry over. The TC, AD and CD flags, Rcode, and all three record
// sections are then replaced from resp. Records that fail to parse are logged
// and skipped rather than failing the whole reply.
//
// An OPT record is always placed first in the additional section:
//   - It advertises a UDP size of udpSize, raised to at least 512.
//   - Its DO bit is cleared.
//   - It carries the upper 8 bits of resp.Status as the extended RCODE.
//
// If resp.EdnsClientSubnet parses as "address/scope", an EDNS Client Subnet
// option is attached with ednsClientNetmask as its source netmask. A netmask
// of 255 selects a per-family default (/24 for IPv4, /56 for IPv6). Any other
// value is clamped to the address length so the option remains packable.
func Unmarshal(msg *dns.Msg, resp *Response, udpSize uint16, ednsClientNetmask uint8) *dns.Msg {
	now := time.Now().UTC()

	reply := msg.Copy()
	reply.Truncated = resp.TC
	reply.AuthenticatedData = resp.AD
	reply.CheckingDisabled = resp.CD
	reply.Rcode = dns.RcodeServerFailure

	// convert appends every record in rrs that parses successfully to dst.
	// Malformed records are logged and dropped so one bad entry does not
	// discard the rest of the response.
	convert := func(dst []dns.RR, rrs []RR) []dns.RR {
		for _, rr := range rrs {
			dnsRR, err := unmarshalRR(rr, now)
			if err != nil {
				log.Println(err)
				continue
			}
			dst = append(dst, dnsRR)
		}
		return dst
	}

	reply.Answer = convert(make([]dns.RR, 0, len(resp.Answer)), resp.Answer)
	reply.Ns = convert(make([]dns.RR, 0, len(resp.Authority)), resp.Authority)

	opt := new(dns.OPT)
	opt.Hdr.Name = "."
	opt.Hdr.Rrtype = dns.TypeOPT
	// Never advertise less than the 512-byte classic DNS message size.
	if udpSize >= 512 {
		opt.SetUDPSize(udpSize)
	} else {
		opt.SetUDPSize(512)
	}
	opt.SetDo(false)

	ednsClientFamily := uint16(0)
	ednsClientAddress := net.IP(nil)
	ednsClientScope := uint8(255)
	if ednsClientSubnet := resp.EdnsClientSubnet; ednsClientSubnet != "" {
		slash := strings.IndexByte(ednsClientSubnet, '/')
		if slash < 0 {
			log.Println(UnmarshalError{"Invalid client subnet"})
		} else {
			ednsClientAddress = net.ParseIP(ednsClientSubnet[:slash])
			if ednsClientAddress == nil {
				log.Println(UnmarshalError{"Invalid client subnet address"})
			} else if ipv4 := ednsClientAddress.To4(); ipv4 != nil {
				ednsClientFamily = 1 // IPv4
				ednsClientAddress = ipv4
			} else {
				ednsClientFamily = 2 // IPv6
			}
			scope, err := strconv.ParseUint(ednsClientSubnet[slash+1:], 10, 8)
			if err != nil {
				log.Println(UnmarshalError{"Invalid client subnet scope"})
			} else {
				ednsClientScope = uint8(scope)
			}
		}
	}

	// ednsClientAddress is non-nil only when ParseIP succeeded, so the
	// family is 1 or 2 here.
	if ednsClientAddress != nil {
		maxBits := uint8(8 * net.IPv6len)
		if ednsClientFamily == 1 {
			maxBits = 8 * net.IPv4len
		}
		if ednsClientNetmask == 255 {
			if ednsClientFamily == 1 {
				ednsClientNetmask = 24
			} else {
				ednsClientNetmask = 56
			}
		} else if ednsClientNetmask > maxBits {
			// A netmask longer than the address cannot be packed into an
			// EDNS0_SUBNET option, so cap it at the full address length.
			ednsClientNetmask = maxBits
		}
		edns0Subnet := new(dns.EDNS0_SUBNET)
		edns0Subnet.Code = dns.EDNS0SUBNET
		edns0Subnet.Family = ednsClientFamily
		edns0Subnet.SourceNetmask = ednsClientNetmask
		edns0Subnet.SourceScope = ednsClientScope
		edns0Subnet.Address = ednsClientAddress
		opt.Option = append(opt.Option, edns0Subnet)
	}

	// The OPT record goes first, followed by the converted additional records.
	reply.Extra = make([]dns.RR, 0, len(resp.Additional)+1)
	reply.Extra = append(reply.Extra, opt)
	reply.Extra = convert(reply.Extra, resp.Additional)

	// Split the 12-bit status. The low 4 bits go in the header RCODE. Bits
	// 4-11 go in the top byte of the OPT TTL, the EXTENDED-RCODE field of
	// RFC 6891 §6.1.3. opt is a pointer, so this also updates reply.Extra[0].
	reply.Rcode = int(resp.Status & 0xf)
	opt.Hdr.Ttl = (opt.Hdr.Ttl & 0x00ffffff) | ((resp.Status & 0xff0) << 20)
	return reply
}
// unmarshalRR converts one JSON record back into a dns.RR. It renders the
// record as a single line of zone-file text and parses that line with
// dns.NewRR.
//
// When ExpiresStr is set, it is parsed as time.RFC1123, and the TTL becomes
// the whole seconds remaining from now until that instant. If the remaining
// time is negative or exceeds 32 bits, the record's own TTL is kept.
//
// The following are rejected with an UnmarshalError, checked in this order:
//   - a name containing whitespace, quotes, parentheses, semicolons or
//     backslashes;
//   - an unparseable expiry;
//   - a type missing from dns.TypeToString;
//   - data containing a line break.
//
// Presumably these checks guard the generated zone line against
// zone-syntax characters.
func unmarshalRR(rr RR, now time.Time) (dnsRR dns.RR, err error) {
	if strings.ContainsAny(rr.Name, "\t\r\n \"();\\") {
		return nil, UnmarshalError{fmt.Sprintf("Record name contains space: %q", rr.Name)}
	}

	if rr.ExpiresStr != "" {
		expires, parseErr := time.Parse(time.RFC1123, rr.ExpiresStr)
		if parseErr != nil {
			return nil, UnmarshalError{fmt.Sprintf("Invalid expire time: %q", rr.ExpiresStr)}
		}
		// Duration / Duration yields a plain count of whole seconds.
		if remaining := expires.Sub(now) / time.Second; remaining >= 0 && remaining <= 0xffffffff {
			rr.TTL = uint32(remaining)
		}
	}

	typeName, known := dns.TypeToString[rr.Type]
	if !known {
		return nil, UnmarshalError{fmt.Sprintf("Unknown record type: %d", rr.Type)}
	}
	if strings.ContainsAny(rr.Data, "\r\n") {
		return nil, UnmarshalError{fmt.Sprintf("Record data contains newline: %q", rr.Data)}
	}

	line := fmt.Sprintf("%s %d IN %s %s", rr.Name, rr.TTL, typeName, rr.Data)
	return dns.NewRR(line)
}
// UnmarshalError describes a JSON record or field that could not be converted
// back into DNS data. unmarshalRR returns these errors and Unmarshal logs them.
type UnmarshalError struct {
	err string // description, without the "json-dns: " prefix
}

// Error implements the error interface by prefixing the description with
// "json-dns: ".
func (e UnmarshalError) Error() string {
	const prefix = "json-dns: "
	return prefix + e.err
}
// PrepareReply returns a fresh response skeleton for req.
//
// The reply copies req's ID, opcode, RD and CD flags, and gets its own copy
// of the question section. RA mirrors req's RD flag, and name compression is
// enabled. Rcode starts as SERVFAIL, so the reply reports failure unless a
// caller overwrites it.
func PrepareReply(req *dns.Msg) *dns.Msg {
	// Copy the questions so later edits to the reply cannot alias req.
	questions := make([]dns.Question, len(req.Question))
	copy(questions, req.Question)

	return &dns.Msg{
		MsgHdr: dns.MsgHdr{
			Id:                 req.Id,
			Response:           true,
			Opcode:             req.Opcode,
			RecursionDesired:   req.RecursionDesired,
			RecursionAvailable: req.RecursionDesired,
			CheckingDisabled:   req.CheckingDisabled,
			Rcode:              dns.RcodeServerFailure,
		},
		Compress: true,
		Question: questions,
	}
}