package doh
|
|
|
|
import (
|
|
"context"
|
|
"crypto/tls"
|
|
"encoding/json"
|
|
"fmt"
|
|
"log"
|
|
"net"
|
|
"net/http"
|
|
"net/url"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/labstack/echo/v4"
|
|
"github.com/labstack/echo/v4/middleware"
|
|
"github.com/miekg/dns"
|
|
"github.com/sirupsen/logrus"
|
|
)
|
|
|
|
const (
	// DefaultRedirect is the page that requests for the site root are
	// redirected to.
	DefaultRedirect = "https://maze.io"

	// DefaultQueryTimeout is the time budget for doing a DNS lookup.
	DefaultQueryTimeout = time.Second * 5
)

// userAgent identifies this software; it is sent in the Server and
// X-Powered-By response headers.
const userAgent = "maze.io/doh 1.0"
|
|
|
|
// Server is a DNS-over-HTTPS front end. It embeds an echo.Echo that serves the
// HTTP side (NewServer registers the routes) and forwards DNS questions to the
// Upstream recursors through its private udp/tcp/tcp-tls clients. Create it
// with NewServer: the zero value has no DNS clients.
type Server struct {
	*echo.Echo

	// Upstream DNS recursors, tried in order. Each entry is parsed as a URL
	// whose scheme selects the transport (udp, tcp or tcp-tls) and whose host
	// part is the address queried.
	Upstream []string

	// Insecure disables TLS verification for tcp-tls upstreams.
	Insecure bool

	// Verbose logging
	// NOTE(review): not referenced in the code visible here.
	Verbose bool

	// NOTE(review): presumably toggles logging of the client IP guessed by
	// findClientIP; not referenced in the code visible here — confirm.
	LogGuessedIP bool

	// Redirect URL: target of the 302 redirect served for GET "/".
	// NewServer initialises it to DefaultRedirect.
	Redirect string

	// DNS clients used to reach the upstreams, one per transport.
	udpClient    *dns.Client
	tcpClient    *dns.Client
	tcpClientTLS *dns.Client
}
|
|
|
|
func NewServer(queryTimeout time.Duration) *Server {
|
|
server := &Server{
|
|
Echo: echo.New(),
|
|
Redirect: DefaultRedirect,
|
|
udpClient: &dns.Client{
|
|
Net: "udp",
|
|
UDPSize: dns.DefaultMsgSize,
|
|
Timeout: queryTimeout,
|
|
},
|
|
tcpClient: &dns.Client{
|
|
Net: "tcp",
|
|
Timeout: queryTimeout,
|
|
},
|
|
tcpClientTLS: &dns.Client{
|
|
Net: "tcp-tls",
|
|
Timeout: queryTimeout,
|
|
TLSConfig: &tls.Config{},
|
|
},
|
|
}
|
|
|
|
// Tuning
|
|
server.HideBanner = true
|
|
|
|
// Middleware
|
|
server.Use(middleware.Recover())
|
|
server.Use(middleware.Logger())
|
|
|
|
// Routes
|
|
server.GET("/", func(c echo.Context) error {
|
|
return c.Redirect(http.StatusFound, server.Redirect)
|
|
})
|
|
server.GET("/dns-query", server.handleDNSQuery)
|
|
server.POST("/dns-query", server.handleDNSQuery)
|
|
|
|
return server
|
|
}
|
|
|
|
func (server *Server) handleDNSQuery(c echo.Context) error {
|
|
var (
|
|
r = c.Request()
|
|
ctx = r.Context()
|
|
response = c.Response()
|
|
w = response.Writer
|
|
header = response.Header()
|
|
)
|
|
|
|
header.Set("Access-Control-Allow-Headers", "Content-Type")
|
|
header.Set("Access-Control-Allow-Methods", "GET, HEAD, OPTIONS, POST")
|
|
header.Set("Access-Control-Allow-Origin", "*")
|
|
header.Set("Access-Control-Max-Age", "3600")
|
|
header.Set("Server", userAgent)
|
|
header.Set("X-Powered-By", userAgent)
|
|
|
|
if r.Method == http.MethodOptions {
|
|
header.Set("Content-Length", "0")
|
|
return nil
|
|
}
|
|
|
|
if r.Form == nil {
|
|
const maxMemory = 32 << 20 // 32 MB
|
|
r.ParseMultipartForm(maxMemory)
|
|
}
|
|
|
|
contentType := r.Header.Get("Content-Type")
|
|
if ct := r.FormValue("ct"); ct != "" {
|
|
contentType = ct
|
|
}
|
|
if contentType == "" {
|
|
// Guess request Content-Type based on other parameters
|
|
if r.FormValue("name") != "" {
|
|
contentType = "application/dns-json"
|
|
} else if r.FormValue("dns") != "" {
|
|
contentType = "application/dns-message"
|
|
}
|
|
}
|
|
|
|
var responseType string
|
|
for _, responseCandidate := range strings.Split(r.Header.Get("Accept"), ",") {
|
|
responseCandidate = strings.SplitN(responseCandidate, ";", 2)[0]
|
|
if responseCandidate == "application/json" {
|
|
responseType = "application/json"
|
|
break
|
|
} else if responseCandidate == "application/dns-udpwireformat" {
|
|
responseType = "application/dns-message"
|
|
break
|
|
} else if responseCandidate == "application/dns-message" {
|
|
responseType = "application/dns-message"
|
|
break
|
|
}
|
|
}
|
|
if responseType == "" {
|
|
// Guess response Content-Type based on request Content-Type
|
|
if contentType == "application/dns-json" {
|
|
responseType = "application/json"
|
|
} else if contentType == "application/dns-message" {
|
|
responseType = "application/dns-message"
|
|
} else if contentType == "application/dns-udpwireformat" {
|
|
responseType = "application/dns-message"
|
|
}
|
|
}
|
|
|
|
var req *DNSRequest
|
|
if contentType == "application/dns-json" {
|
|
req = server.parseRequestGoogle(ctx, w, r)
|
|
} else if contentType == "application/dns-message" {
|
|
req = server.parseRequestIETF(ctx, w, r)
|
|
} else if contentType == "application/dns-udpwireformat" {
|
|
req = server.parseRequestIETF(ctx, w, r)
|
|
} else {
|
|
formatError(w, fmt.Sprintf("Invalid argument value: \"ct\" = %q", contentType), 415)
|
|
return nil
|
|
}
|
|
|
|
if req.errcode == 444 {
|
|
return nil
|
|
}
|
|
|
|
req = server.patchRootRD(req)
|
|
var err error
|
|
req, err = server.query(ctx, req)
|
|
|
|
logger := logrus.WithFields(logrus.Fields{
|
|
"content_type": contentType,
|
|
"request": formatQuestions(req.request.Question),
|
|
"response": formatRRs(req.response.Answer),
|
|
"response_type": responseType,
|
|
})
|
|
if err != nil {
|
|
logger.WithError(err).Warn("query failure")
|
|
formatError(response, fmt.Sprintf("DNS query failure (%s)", err.Error()), 503)
|
|
return nil
|
|
}
|
|
logger.Info("query success")
|
|
|
|
if responseType == "application/json" {
|
|
server.generateResponseGoogle(ctx, w, r, req)
|
|
} else if responseType == "application/dns-message" {
|
|
server.generateResponseIETF(ctx, w, r, req)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (server *Server) query(ctx context.Context, req *DNSRequest) (res *DNSRequest, err error) {
|
|
for _, upstream := range server.Upstream {
|
|
var u *url.URL
|
|
if u, err = url.Parse(upstream); err != nil {
|
|
continue
|
|
}
|
|
|
|
switch u.Scheme {
|
|
case "tcp-tls":
|
|
server.tcpClientTLS.TLSConfig.InsecureSkipVerify = server.Insecure
|
|
req.response, _, err = server.tcpClientTLS.Exchange(req.request, u.Host)
|
|
case "tcp":
|
|
req.response, _, err = server.tcpClient.Exchange(req.request, u.Host)
|
|
case "udp":
|
|
if server.indexQuestionType(req.request, dns.TypeAXFR) > -1 {
|
|
req.response, _, err = server.tcpClient.Exchange(req.request, u.Host)
|
|
} else {
|
|
req.response, _, err = server.udpClient.Exchange(req.request, u.Host)
|
|
if err == nil && req.response != nil && req.response.Truncated {
|
|
log.Println(err)
|
|
req.response, _, err = server.tcpClient.Exchange(req.request, u.Host)
|
|
}
|
|
|
|
// Retry with TCP if this was an IXFR request and we only received an SOA
|
|
if (server.indexQuestionType(req.request, dns.TypeIXFR) > -1) &&
|
|
(len(req.response.Answer) == 1) &&
|
|
(req.response.Answer[0].Header().Rrtype == dns.TypeSOA) {
|
|
req.response, _, err = server.tcpClient.Exchange(req.request, u.Host)
|
|
}
|
|
}
|
|
}
|
|
|
|
if err == nil {
|
|
return req, nil
|
|
}
|
|
log.Printf("DNS error from upstream %s: %s\n", req.currentUpstream, err.Error())
|
|
}
|
|
return
|
|
}
|
|
|
|
func (server *Server) indexQuestionType(msg *dns.Msg, qtype uint16) int {
|
|
for i, question := range msg.Question {
|
|
if question.Qtype == qtype {
|
|
return i
|
|
}
|
|
}
|
|
return -1
|
|
}
|
|
|
|
func (server *Server) findClientIP(r *http.Request) net.IP {
|
|
noEcs := r.URL.Query().Get("no_ecs")
|
|
if strings.ToLower(noEcs) == "true" {
|
|
return nil
|
|
}
|
|
|
|
XForwardedFor := r.Header.Get("X-Forwarded-For")
|
|
if XForwardedFor != "" {
|
|
for _, addr := range strings.Split(XForwardedFor, ",") {
|
|
addr = strings.TrimSpace(addr)
|
|
ip := net.ParseIP(addr)
|
|
if isGlobalIP(ip) {
|
|
return ip
|
|
}
|
|
}
|
|
}
|
|
XRealIP := r.Header.Get("X-Real-IP")
|
|
if XRealIP != "" {
|
|
addr := strings.TrimSpace(XRealIP)
|
|
ip := net.ParseIP(addr)
|
|
if isGlobalIP(ip) {
|
|
return ip
|
|
}
|
|
}
|
|
remoteAddr, err := net.ResolveTCPAddr("tcp", r.RemoteAddr)
|
|
if err != nil {
|
|
return nil
|
|
}
|
|
if ip := remoteAddr.IP; isGlobalIP(ip) {
|
|
return ip
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (server *Server) patchRootRD(req *DNSRequest) *DNSRequest {
|
|
for _, question := range req.request.Question {
|
|
if question.Name == "." {
|
|
req.request.RecursionDesired = true
|
|
}
|
|
}
|
|
return req
|
|
}
|
|
|
|
// DNSRequest carries one DoH query through handleDNSQuery: the parsed DNS
// message, the upstream's reply, and status set while parsing.
type DNSRequest struct {
	// request is the DNS message forwarded to the upstreams.
	request *dns.Msg

	// response is the reply received from an upstream.
	response *dns.Msg

	// NOTE(review): not used in the code visible here; presumably the client's
	// original DNS message ID — confirm in parseRequest*/generateResponse*.
	transactionID uint16

	// currentUpstream is meant to name the upstream being queried, for
	// logging.
	currentUpstream string

	// NOTE(review): not used in the code visible here — purpose unconfirmed.
	isTailored bool

	// errcode is a status set while parsing the request; handleDNSQuery
	// returns without further output when it is 444 (presumably: the parser
	// has already answered the client — confirm).
	errcode int

	// NOTE(review): presumably the message accompanying errcode — confirm.
	errtext string
}
|
|
|
|
// dnsError is the JSON body written by formatError for failed requests.
// encoding/json emits fields in declaration order, so the field order here
// fixes the key order of the response.
type dnsError struct {
	// Status is the DNS response code; formatError sets dns.RcodeServerFailure.
	Status uint32 `json:"Status"`

	// Comment is a human-readable explanation, omitted when empty.
	Comment string `json:"Comment,omitempty"`
}
|
|
|
|
func formatError(w http.ResponseWriter, comment string, errcode int) {
|
|
w.Header().Set("Content-Type", "application/json; charset=UTF-8")
|
|
errJson := dnsError{
|
|
Status: dns.RcodeServerFailure,
|
|
Comment: comment,
|
|
}
|
|
errStr, err := json.Marshal(errJson)
|
|
if err != nil {
|
|
log.Fatalln(err)
|
|
}
|
|
w.WriteHeader(errcode)
|
|
w.Write(errStr)
|
|
}
|
|
|
|
func formatQuestions(questions []dns.Question) []string {
|
|
s := make([]string, len(questions))
|
|
for i, q := range questions {
|
|
s[i] = q.String()
|
|
}
|
|
return s
|
|
}
|
|
|
|
func formatRRs(rrs []dns.RR) []string {
|
|
s := make([]string, len(rrs))
|
|
for i, rr := range rrs {
|
|
s[i] = rr.String()
|
|
}
|
|
return s
|
|
}
|