Toy DNS over HTTPS server. https://dns.maze.network/dns-query
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

329 lines
7.9 KiB

package doh
import (
"context"
"crypto/tls"
"encoding/json"
"fmt"
"log"
"net"
"net/http"
"net/url"
"strings"
"time"
"github.com/labstack/echo/v4"
"github.com/labstack/echo/v4/middleware"
"github.com/miekg/dns"
"github.com/sirupsen/logrus"
)
const (
// DefaultQueryTimeout is the default timeout for doing DNS lookups.
DefaultQueryTimeout = 5 * time.Second
// DefaultRedirect is the redirect page for the root.
DefaultRedirect = "https://maze.io"
// userAgent is the value sent in the "Server" and "X-Powered-By" response
// headers of the DNS query endpoint.
userAgent = "maze.io/doh 1.0"
)
// Server is a DNS over HTTPS server. It embeds an echo HTTP server and
// forwards the DNS queries it receives to the configured upstreams.
type Server struct {
*echo.Echo
// Upstream DNS recursors, tried in order. Each entry is parsed as a URL: the
// scheme ("udp", "tcp" or "tcp-tls") selects the transport and the host part
// is the address that is queried.
Upstream []string
// Insecure disables TLS verification for "tcp-tls" upstreams.
Insecure bool
// Verbose logging
Verbose bool
// LogGuessedIP is not read in this part of the file.
// NOTE(review): presumably enables logging of the client IP guessed by
// findClientIP; confirm against the code that uses it.
LogGuessedIP bool
// Redirect URL that requests for "/" are redirected to.
Redirect string
// DNS clients, one per upstream transport.
udpClient *dns.Client
tcpClient *dns.Client
tcpClientTLS *dns.Client
}
// NewServer builds a Server whose UDP, TCP and TCP-over-TLS DNS clients all
// use queryTimeout. The embedded echo instance has its banner hidden, the
// Recover and Logger middleware installed, "/" redirecting to Server.Redirect
// and "/dns-query" answering both GET and POST.
func NewServer(queryTimeout time.Duration) *Server {
	s := &Server{
		Echo:     echo.New(),
		Redirect: DefaultRedirect,
	}

	// One DNS client per supported upstream transport.
	s.udpClient = &dns.Client{Net: "udp", UDPSize: dns.DefaultMsgSize, Timeout: queryTimeout}
	s.tcpClient = &dns.Client{Net: "tcp", Timeout: queryTimeout}
	s.tcpClientTLS = &dns.Client{Net: "tcp-tls", Timeout: queryTimeout, TLSConfig: &tls.Config{}}

	// Tuning
	s.HideBanner = true

	// Middleware
	s.Use(middleware.Recover(), middleware.Logger())

	// Routes
	redirectRoot := func(c echo.Context) error {
		// The field is read on every request, so changes made to
		// Server.Redirect after construction take effect.
		return c.Redirect(http.StatusFound, s.Redirect)
	}
	s.GET("/", redirectRoot)
	s.GET("/dns-query", s.handleDNSQuery)
	s.POST("/dns-query", s.handleDNSQuery)

	return s
}
// handleDNSQuery is the echo handler for /dns-query.
//
// The request format is taken from the Content-Type header, overridden by the
// "ct" form value, and otherwise guessed from the presence of the "name"
// (JSON API) or "dns" (wire format) parameter. The response format is taken
// from the first recognised entry of the Accept header and otherwise mirrors
// the request format. The parsed query is forwarded to the upstreams via
// Server.query and the answer is rendered in the selected format.
//
// All failures are reported to the client through formatError; the returned
// error is always nil.
func (server *Server) handleDNSQuery(c echo.Context) error {
	var (
		r        = c.Request()
		ctx      = r.Context()
		response = c.Response()
		w        = response.Writer
		header   = response.Header()
	)

	header.Set("Access-Control-Allow-Headers", "Content-Type")
	header.Set("Access-Control-Allow-Methods", "GET, HEAD, OPTIONS, POST")
	header.Set("Access-Control-Allow-Origin", "*")
	header.Set("Access-Control-Max-Age", "3600")
	header.Set("Server", userAgent)
	header.Set("X-Powered-By", userAgent)

	if r.Method == http.MethodOptions {
		header.Set("Content-Length", "0")
		return nil
	}

	if r.Form == nil {
		const maxMemory = 32 << 20 // 32 MB
		// Best effort: for requests that are not multipart this returns
		// http.ErrNotMultipart after the regular form has been parsed, so the
		// error is deliberately ignored. Missing parameters surface below as
		// an unknown content type.
		_ = r.ParseMultipartForm(maxMemory)
	}

	contentType := r.Header.Get("Content-Type")
	if ct := r.FormValue("ct"); ct != "" {
		contentType = ct
	}
	if contentType == "" {
		// Guess request Content-Type based on other parameters
		if r.FormValue("name") != "" {
			contentType = "application/dns-json"
		} else if r.FormValue("dns") != "" {
			contentType = "application/dns-message"
		}
	}

	var responseType string
	for _, candidate := range strings.Split(r.Header.Get("Accept"), ",") {
		// Drop media type parameters (";q=0.9") and the optional whitespace
		// that follows the comma separator, so that "a, b" matches "b".
		candidate = strings.TrimSpace(strings.SplitN(candidate, ";", 2)[0])
		switch candidate {
		case "application/json":
			responseType = "application/json"
		case "application/dns-udpwireformat", "application/dns-message":
			responseType = "application/dns-message"
		}
		if responseType != "" {
			break
		}
	}
	if responseType == "" {
		// Guess response Content-Type based on request Content-Type
		switch contentType {
		case "application/dns-json":
			responseType = "application/json"
		case "application/dns-message", "application/dns-udpwireformat":
			responseType = "application/dns-message"
		}
	}

	var req *DNSRequest
	switch contentType {
	case "application/dns-json":
		req = server.parseRequestGoogle(ctx, w, r)
	case "application/dns-message", "application/dns-udpwireformat":
		req = server.parseRequestIETF(ctx, w, r)
	default:
		formatError(w, fmt.Sprintf("Invalid argument value: \"ct\" = %q", contentType), 415)
		return nil
	}
	if req.errcode == 444 {
		// Nothing more to send for this request.
		return nil
	}
	if req.errcode != 0 {
		// NOTE(review): assumes the parsers report failures through
		// errcode/errtext without writing a response themselves; confirm
		// against parseRequestGoogle and parseRequestIETF.
		formatError(w, req.errtext, req.errcode)
		return nil
	}

	req = server.patchRootRD(req)

	// Keep req intact: query returns a nil request on failure, and the
	// questions are still needed for the log entry below.
	res, err := server.query(ctx, req)
	if err == nil && (res == nil || res.response == nil) {
		err = fmt.Errorf("no response from upstream")
	}

	logger := logrus.WithFields(logrus.Fields{
		"content_type":  contentType,
		"request":       formatQuestions(req.request.Question),
		"response_type": responseType,
	})
	if err != nil {
		logger.WithError(err).Warn("query failure")
		formatError(response, fmt.Sprintf("DNS query failure (%s)", err.Error()), 503)
		return nil
	}
	logger.WithField("response", formatRRs(res.response.Answer)).Info("query success")

	switch responseType {
	case "application/json":
		server.generateResponseGoogle(ctx, w, r, res)
	case "application/dns-message":
		server.generateResponseIETF(ctx, w, r, res)
	}
	return nil
}
// query sends req.request to the configured upstreams, in order, until one of
// them answers, and stores the answer in req.response.
//
// Each upstream is a URL whose scheme selects the transport:
//
//   - "tcp-tls": DNS over TLS; certificate verification follows Server.Insecure.
//   - "tcp": plain TCP.
//   - "udp": UDP, except that AXFR questions are sent over TCP, and the query
//     is repeated over TCP when the UDP answer is truncated or when an IXFR
//     question was answered with a single SOA record.
//
// It returns req on success. When no upstream is configured or every upstream
// fails it returns a nil request and the last error. ctx is currently unused.
func (server *Server) query(ctx context.Context, req *DNSRequest) (res *DNSRequest, err error) {
	if len(server.Upstream) == 0 {
		return nil, fmt.Errorf("no upstream DNS servers configured")
	}

	for _, upstream := range server.Upstream {
		// Record the upstream being tried so the failure log below names it.
		req.currentUpstream = upstream

		var u *url.URL
		if u, err = url.Parse(upstream); err != nil {
			log.Printf("DNS upstream %q is not a valid URL: %s\n", upstream, err.Error())
			continue
		}

		switch u.Scheme {
		case "tcp-tls":
			// NOTE(review): this writes to the shared TLS client on every
			// query; consider setting it once when Insecure is configured.
			server.tcpClientTLS.TLSConfig.InsecureSkipVerify = server.Insecure
			req.response, _, err = server.tcpClientTLS.Exchange(req.request, u.Host)

		case "tcp":
			req.response, _, err = server.tcpClient.Exchange(req.request, u.Host)

		case "udp":
			if server.indexQuestionType(req.request, dns.TypeAXFR) > -1 {
				req.response, _, err = server.tcpClient.Exchange(req.request, u.Host)
				break
			}
			req.response, _, err = server.udpClient.Exchange(req.request, u.Host)
			if err != nil || req.response == nil {
				// Nothing to inspect; fall through to the error handling below.
				break
			}
			retryTCP := req.response.Truncated
			if !retryTCP && server.indexQuestionType(req.request, dns.TypeIXFR) > -1 {
				// Retry with TCP if this was an IXFR request and we only received an SOA
				retryTCP = len(req.response.Answer) == 1 &&
					req.response.Answer[0].Header().Rrtype == dns.TypeSOA
			}
			if retryTCP {
				req.response, _, err = server.tcpClient.Exchange(req.request, u.Host)
			}

		default:
			// Without this an unknown scheme would be reported as a success
			// carrying no response.
			err = fmt.Errorf("unsupported upstream scheme %q", u.Scheme)
		}

		if err == nil {
			return req, nil
		}
		log.Printf("DNS error from upstream %s: %s\n", req.currentUpstream, err.Error())
	}
	return nil, err
}
// indexQuestionType returns the position of the first question in msg whose
// type equals qtype, or -1 when msg has no such question.
func (server *Server) indexQuestionType(msg *dns.Msg, qtype uint16) int {
	questions := msg.Question
	for idx := 0; idx < len(questions); idx++ {
		if questions[idx].Qtype != qtype {
			continue
		}
		return idx
	}
	return -1
}
// findClientIP returns the first address accepted by isGlobalIP, looking in
// turn at the X-Forwarded-For entries, the X-Real-IP header and the remote
// address of the connection. It returns nil when none qualifies, when the
// remote address cannot be resolved, or when the request carries the query
// parameter no_ecs=true (compared case-insensitively).
func (server *Server) findClientIP(r *http.Request) net.IP {
	if strings.ToLower(r.URL.Query().Get("no_ecs")) == "true" {
		return nil
	}

	if forwarded := r.Header.Get("X-Forwarded-For"); forwarded != "" {
		for _, field := range strings.Split(forwarded, ",") {
			if ip := net.ParseIP(strings.TrimSpace(field)); isGlobalIP(ip) {
				return ip
			}
		}
	}

	if realIP := r.Header.Get("X-Real-IP"); realIP != "" {
		if ip := net.ParseIP(strings.TrimSpace(realIP)); isGlobalIP(ip) {
			return ip
		}
	}

	tcpAddr, err := net.ResolveTCPAddr("tcp", r.RemoteAddr)
	if err != nil || !isGlobalIP(tcpAddr.IP) {
		return nil
	}
	return tcpAddr.IP
}
// patchRootRD sets the RecursionDesired flag on req.request when any of its
// questions asks for the root name ".", and returns req.
func (server *Server) patchRootRD(req *DNSRequest) *DNSRequest {
	msg := req.request
	for i := range msg.Question {
		if msg.Question[i].Name != "." {
			continue
		}
		// One match is enough; the flag applies to the whole message.
		msg.RecursionDesired = true
		break
	}
	return req
}
// DNSRequest carries one DNS query through the server: the message to send,
// the answer received from an upstream, and the error state that
// handleDNSQuery inspects after the request has been parsed.
type DNSRequest struct {
// request is the DNS message that Server.query sends to the upstreams.
request *dns.Msg
// response is the upstream answer, filled in by Server.query.
response *dns.Msg
// NOTE(review): not referenced in this part of the file; presumably the
// message ID of the client's original query. Confirm against the parsers.
transactionID uint16
// currentUpstream is printed by Server.query when an upstream fails.
currentUpstream string
// NOTE(review): not referenced in this part of the file; meaning to be
// confirmed against the request parsers.
isTailored bool
// errcode 444 makes handleDNSQuery return without doing anything further.
// NOTE(review): errcode and errtext look like an HTTP status code and its
// message, set by the request parsers. Confirm.
errcode int
errtext string
}
// dnsError is the JSON body that formatError writes for failed requests.
type dnsError struct {
// Status is the DNS response code; formatError always sets it to
// dns.RcodeServerFailure.
Status uint32 `json:"Status"`
// Comment describes the failure; it is left out of the JSON when empty.
Comment string `json:"Comment,omitempty"`
}
// formatError writes a JSON error document to w with HTTP status errcode.
// The document reports the DNS status SERVFAIL together with comment.
//
// Encoding or write failures are logged rather than treated as fatal, because
// this runs inside request handlers and must not take the server down. If the
// document cannot be encoded, a plain 500 response is sent instead.
func formatError(w http.ResponseWriter, comment string, errcode int) {
	body, err := json.Marshal(dnsError{
		Status:  dns.RcodeServerFailure,
		Comment: comment,
	})
	if err != nil {
		log.Printf("encoding error response: %v", err)
		http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
		return
	}

	w.Header().Set("Content-Type", "application/json; charset=UTF-8")
	w.WriteHeader(errcode)
	if _, err := w.Write(body); err != nil {
		log.Printf("writing error response: %v", err)
	}
}
// formatQuestions returns the String form of every question, in order, for
// use in log fields.
func formatQuestions(questions []dns.Question) []string {
	out := make([]string, 0, len(questions))
	for i := range questions {
		out = append(out, questions[i].String())
	}
	return out
}
// formatRRs returns the String form of every resource record, in order, for
// use in log fields.
func formatRRs(rrs []dns.RR) []string {
	out := make([]string, 0, len(rrs))
	for i := range rrs {
		out = append(out, rrs[i].String())
	}
	return out
}