Checkpoint
This commit is contained in:
		
							
								
								
									
										4
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										4
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							@@ -1,4 +1,6 @@
 | 
			
		||||
# SQLite3 database file
 | 
			
		||||
# Database file
 | 
			
		||||
*.bolt
 | 
			
		||||
*.boltdb
 | 
			
		||||
*.db
 | 
			
		||||
 | 
			
		||||
# Log files
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										14
									
								
								.regal.yaml
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										14
									
								
								.regal.yaml
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,14 @@
 | 
			
		||||
rules:
 | 
			
		||||
  idiomatic:
 | 
			
		||||
    directory-package-mismatch:
 | 
			
		||||
      level: ignore
 | 
			
		||||
 | 
			
		||||
  style:
 | 
			
		||||
    function-arg-return:
 | 
			
		||||
      level: error
 | 
			
		||||
      except-functions:
 | 
			
		||||
      - sprintf
 | 
			
		||||
 | 
			
		||||
project:
 | 
			
		||||
  roots:
 | 
			
		||||
  - testdata/policy
 | 
			
		||||
							
								
								
									
										146
									
								
								admin/admin.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										146
									
								
								admin/admin.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,146 @@
 | 
			
		||||
package admin
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"os"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"sync"
 | 
			
		||||
 | 
			
		||||
	"git.maze.io/maze/styx/dataset"
 | 
			
		||||
	"git.maze.io/maze/styx/logger"
 | 
			
		||||
	"git.maze.io/maze/styx/proxy"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type Admin struct {
 | 
			
		||||
	Storage   dataset.Storage
 | 
			
		||||
	setupOnce sync.Once
 | 
			
		||||
	mux       *http.ServeMux
 | 
			
		||||
	api       *http.ServeMux
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type apiError struct {
 | 
			
		||||
	Code int
 | 
			
		||||
	Err  error
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (err apiError) Error() string {
 | 
			
		||||
	return err.Err.Error()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Admin) setup() {
 | 
			
		||||
	a.mux = http.NewServeMux()
 | 
			
		||||
 | 
			
		||||
	a.api = http.NewServeMux()
 | 
			
		||||
	a.api.HandleFunc("GET /groups", a.apiGroups)
 | 
			
		||||
	a.api.HandleFunc("POST /group", a.apiGroupCreate)
 | 
			
		||||
	a.api.HandleFunc("GET /group/{id}", a.apiGroup)
 | 
			
		||||
	a.api.HandleFunc("PATCH /group/{id}", a.apiGroupUpdate)
 | 
			
		||||
	a.api.HandleFunc("DELETE /group/{id}", a.apiGroupDelete)
 | 
			
		||||
	a.api.HandleFunc("GET /clients", a.apiClients)
 | 
			
		||||
	a.api.HandleFunc("GET /client/{id}", a.apiClient)
 | 
			
		||||
	a.api.HandleFunc("POST /client", a.apiClientCreate)
 | 
			
		||||
	a.api.HandleFunc("PATCH /client/{id}", a.apiClientUpdate)
 | 
			
		||||
	a.api.HandleFunc("DELETE /client/{id}", a.apiClientDelete)
 | 
			
		||||
	a.api.HandleFunc("GET /lists", a.apiLists)
 | 
			
		||||
	a.api.HandleFunc("POST /list", a.apiListCreate)
 | 
			
		||||
	a.api.HandleFunc("GET /list/{id}", a.apiList)
 | 
			
		||||
	a.api.HandleFunc("DELETE /list/{id}", a.apiListDelete)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Handler interface {
 | 
			
		||||
	Handle(pattern string, handler http.Handler)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Admin) Install(handler Handler) {
 | 
			
		||||
	a.setupOnce.Do(a.setup)
 | 
			
		||||
	handler.Handle("/api/v1/", http.StripPrefix("/api/v1", a.api))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Admin) handleAPIError(w http.ResponseWriter, r *http.Request, err error) {
 | 
			
		||||
	code := http.StatusBadRequest
 | 
			
		||||
	switch {
 | 
			
		||||
	case dataset.IsNotExist(err):
 | 
			
		||||
		code = http.StatusNotFound
 | 
			
		||||
	case os.IsPermission(err):
 | 
			
		||||
		code = http.StatusForbidden
 | 
			
		||||
	case errors.Is(err, apiError{}):
 | 
			
		||||
		if c := err.(apiError).Code; c > 0 {
 | 
			
		||||
			code = c
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	logger.StandardLog.Err(err).Values(logger.Values{
 | 
			
		||||
		"code":   code,
 | 
			
		||||
		"client": r.RemoteAddr,
 | 
			
		||||
		"method": r.Method,
 | 
			
		||||
		"path":   r.URL.Path,
 | 
			
		||||
	}).Warn("Unexpected API error encountered")
 | 
			
		||||
 | 
			
		||||
	var data []byte
 | 
			
		||||
	if err, ok := err.(apiError); ok {
 | 
			
		||||
		data, _ = json.Marshal(struct {
 | 
			
		||||
			Code  int    `json:"code"`
 | 
			
		||||
			Error string `json:"error"`
 | 
			
		||||
		}{code, err.Error()})
 | 
			
		||||
	} else {
 | 
			
		||||
		data, _ = json.Marshal(struct {
 | 
			
		||||
			Code  int    `json:"code"`
 | 
			
		||||
			Error string `json:"error"`
 | 
			
		||||
		}{code, http.StatusText(code)})
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	res := proxy.NewResponse(code, io.NopCloser(bytes.NewReader(data)), r)
 | 
			
		||||
	res.Header.Set(proxy.HeaderContentType, "application/json")
 | 
			
		||||
 | 
			
		||||
	for k, vv := range res.Header {
 | 
			
		||||
		if len(vv) >= 1 {
 | 
			
		||||
			w.Header().Set(k, vv[0])
 | 
			
		||||
			for _, v := range vv[1:] {
 | 
			
		||||
				w.Header().Add(k, v)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	w.WriteHeader(code)
 | 
			
		||||
	io.Copy(w, res.Body)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Admin) jsonResponse(w http.ResponseWriter, r *http.Request, value any, codes ...int) {
 | 
			
		||||
	var (
 | 
			
		||||
		code = http.StatusNoContent
 | 
			
		||||
		body io.ReadCloser
 | 
			
		||||
		size int64
 | 
			
		||||
	)
 | 
			
		||||
	if value != nil {
 | 
			
		||||
		data, err := json.Marshal(value)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			a.handleAPIError(w, r, err)
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		code = http.StatusOK
 | 
			
		||||
		body = io.NopCloser(bytes.NewReader(data))
 | 
			
		||||
		size = int64(len(data))
 | 
			
		||||
	}
 | 
			
		||||
	if len(codes) > 0 {
 | 
			
		||||
		code = codes[0]
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	res := proxy.NewResponse(code, body, r)
 | 
			
		||||
	res.Close = true
 | 
			
		||||
	res.Header.Set(proxy.HeaderContentLength, strconv.FormatInt(size, 10))
 | 
			
		||||
	res.Header.Set(proxy.HeaderContentType, "application/json")
 | 
			
		||||
 | 
			
		||||
	for k, vv := range res.Header {
 | 
			
		||||
		if len(vv) >= 1 {
 | 
			
		||||
			w.Header().Set(k, vv[0])
 | 
			
		||||
			for _, v := range vv[1:] {
 | 
			
		||||
				w.Header().Add(k, v)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	w.WriteHeader(code)
 | 
			
		||||
	io.Copy(w, res.Body)
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										183
									
								
								admin/api_client.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										183
									
								
								admin/api_client.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,183 @@
 | 
			
		||||
package admin
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"log"
 | 
			
		||||
	"net"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"git.maze.io/maze/styx/dataset"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func (a *Admin) apiClients(w http.ResponseWriter, r *http.Request) {
 | 
			
		||||
	clients, err := a.Storage.Clients()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		a.handleAPIError(w, r, err)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	a.jsonResponse(w, r, clients)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Admin) apiClient(w http.ResponseWriter, r *http.Request) {
 | 
			
		||||
	id, err := strconv.ParseInt(r.PathValue("id"), 10, 64)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		a.handleAPIError(w, r, err)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	client, err := a.Storage.ClientByID(id)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		a.handleAPIError(w, r, err)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	a.jsonResponse(w, r, client)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Admin) apiClientCreate(w http.ResponseWriter, r *http.Request) {
 | 
			
		||||
	var request struct {
 | 
			
		||||
		dataset.Client
 | 
			
		||||
		Groups    []int64   `json:"groups"`
 | 
			
		||||
		ID        int64     `json:"id"`         // mask, not used
 | 
			
		||||
		CreatedAt time.Time `json:"created_at"` // mask, not used
 | 
			
		||||
		UpdatedAt time.Time `json:"updated_at"` // mask, not used
 | 
			
		||||
	}
 | 
			
		||||
	if err := json.NewDecoder(r.Body).Decode(&request); err != nil {
 | 
			
		||||
		a.handleAPIError(w, r, err)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err := a.verifyClient(&request.Client); err != nil {
 | 
			
		||||
		a.handleAPIError(w, r, err)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var groups []dataset.Group
 | 
			
		||||
	for _, id := range request.Groups {
 | 
			
		||||
		group, err := a.Storage.GroupByID(id)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			a.handleAPIError(w, r, err)
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		groups = append(groups, group)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	request.Client.Groups = groups
 | 
			
		||||
	if err := a.Storage.SaveClient(&request.Client); err != nil {
 | 
			
		||||
		a.handleAPIError(w, r, err)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	a.jsonResponse(w, r, request.Client)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Admin) apiClientUpdate(w http.ResponseWriter, r *http.Request) {
 | 
			
		||||
	id, err := strconv.ParseInt(r.PathValue("id"), 10, 64)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		a.handleAPIError(w, r, err)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	client, err := a.Storage.ClientByID(id)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		a.handleAPIError(w, r, err)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	log.Printf("updating: %#+v", client)
 | 
			
		||||
 | 
			
		||||
	var request struct {
 | 
			
		||||
		dataset.Client
 | 
			
		||||
		Groups []int64 `json:"groups"`
 | 
			
		||||
	}
 | 
			
		||||
	if err := json.NewDecoder(r.Body).Decode(&request); err != nil {
 | 
			
		||||
		a.handleAPIError(w, r, err)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err := a.verifyClient(&request.Client); err != nil {
 | 
			
		||||
		a.handleAPIError(w, r, err)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	client.IP = request.Client.IP
 | 
			
		||||
	client.Mask = request.Client.Mask
 | 
			
		||||
	client.Description = request.Client.Description
 | 
			
		||||
	client.Groups = client.Groups[:0]
 | 
			
		||||
	for _, id := range request.Groups {
 | 
			
		||||
		group, err := a.Storage.GroupByID(id)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			a.handleAPIError(w, r, err)
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		client.Groups = append(client.Groups, group)
 | 
			
		||||
	}
 | 
			
		||||
	if err := a.Storage.SaveClient(&client); err != nil {
 | 
			
		||||
		a.handleAPIError(w, r, err)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	a.jsonResponse(w, r, client)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Admin) apiClientDelete(w http.ResponseWriter, r *http.Request) {
 | 
			
		||||
	id, err := strconv.ParseInt(r.PathValue("id"), 10, 64)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		a.handleAPIError(w, r, err)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	client, err := a.Storage.ClientByID(id)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		a.handleAPIError(w, r, err)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	if err = a.Storage.DeleteClient(client); err != nil {
 | 
			
		||||
		a.handleAPIError(w, r, err)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	a.jsonResponse(w, r, nil)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Admin) verifyClient(c *dataset.Client) (err error) {
 | 
			
		||||
	ip := net.ParseIP(c.IP)
 | 
			
		||||
	switch c.Network {
 | 
			
		||||
	case "ipv4":
 | 
			
		||||
		if ip.To4() == nil {
 | 
			
		||||
			return apiError{Err: errors.New("invalid IPv4 address")}
 | 
			
		||||
		}
 | 
			
		||||
		if c.Mask == 0 {
 | 
			
		||||
			c.Mask = 32 // one IP
 | 
			
		||||
		}
 | 
			
		||||
		if c.Mask <= 0 || c.Mask > 32 {
 | 
			
		||||
			return apiError{Err: errors.New("mask can't be zero")}
 | 
			
		||||
		}
 | 
			
		||||
		c.IP = ip.Mask(net.CIDRMask(int(c.Mask), 32)).String()
 | 
			
		||||
 | 
			
		||||
	case "ipv6":
 | 
			
		||||
		if ip.To16() == nil {
 | 
			
		||||
			return apiError{Err: errors.New("invalid IPv6 address")}
 | 
			
		||||
		}
 | 
			
		||||
		if c.Mask == 0 {
 | 
			
		||||
			c.Mask = 128 // one IP
 | 
			
		||||
		}
 | 
			
		||||
		if c.Mask <= 0 || c.Mask > 128 {
 | 
			
		||||
			return apiError{Err: errors.New("mask can't be zero")}
 | 
			
		||||
		}
 | 
			
		||||
		c.IP = ip.Mask(net.CIDRMask(int(c.Mask), 128)).String()
 | 
			
		||||
 | 
			
		||||
	case "":
 | 
			
		||||
		if ip.To4() != nil {
 | 
			
		||||
			c.Network = "ipv4"
 | 
			
		||||
		} else if ip.To16() != nil {
 | 
			
		||||
			c.Network = "ipv6"
 | 
			
		||||
		} else {
 | 
			
		||||
			return apiError{Err: errors.New("invalid IP address")}
 | 
			
		||||
		}
 | 
			
		||||
		return a.verifyClient(c)
 | 
			
		||||
 | 
			
		||||
	default:
 | 
			
		||||
		return apiError{Err: fmt.Errorf("invalid network %q", c.Network)}
 | 
			
		||||
	}
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										72
									
								
								admin/api_group.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										72
									
								
								admin/api_group.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,72 @@
 | 
			
		||||
package admin
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"git.maze.io/maze/styx/dataset"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func (a *Admin) apiGroups(w http.ResponseWriter, r *http.Request) {
 | 
			
		||||
	groups, err := a.Storage.Groups()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		a.handleAPIError(w, r, err)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	a.jsonResponse(w, r, groups)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Admin) apiGroup(w http.ResponseWriter, r *http.Request) {
 | 
			
		||||
	id, err := strconv.ParseInt(r.PathValue("id"), 10, 64)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		a.handleAPIError(w, r, err)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	group, err := a.Storage.GroupByID(id)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		a.handleAPIError(w, r, err)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	a.jsonResponse(w, r, group)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Admin) apiGroupCreate(w http.ResponseWriter, r *http.Request) {
 | 
			
		||||
	var request struct {
 | 
			
		||||
		dataset.Group
 | 
			
		||||
		ID        int64     `json:"id"`         // mask, not used
 | 
			
		||||
		CreatedAt time.Time `json:"created_at"` // mask, not used
 | 
			
		||||
		UpdatedAt time.Time `json:"updated_at"` // mask, not used
 | 
			
		||||
	}
 | 
			
		||||
	if err := json.NewDecoder(r.Body).Decode(&request); err != nil {
 | 
			
		||||
		a.handleAPIError(w, r, err)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	if err := a.Storage.SaveGroup(&request.Group); err != nil {
 | 
			
		||||
		a.handleAPIError(w, r, err)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	a.jsonResponse(w, r, request.Group, http.StatusCreated)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Admin) apiGroupUpdate(w http.ResponseWriter, r *http.Request) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Admin) apiGroupDelete(w http.ResponseWriter, r *http.Request) {
 | 
			
		||||
	id, err := strconv.ParseInt(r.PathValue("id"), 10, 64)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		a.handleAPIError(w, r, err)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	group, err := a.Storage.GroupByID(id)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		a.handleAPIError(w, r, err)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	if err = a.Storage.DeleteGroup(group); err != nil {
 | 
			
		||||
		a.handleAPIError(w, r, err)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	a.jsonResponse(w, r, nil)
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										98
									
								
								admin/api_list.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										98
									
								
								admin/api_list.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,98 @@
 | 
			
		||||
package admin
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"git.maze.io/maze/styx/dataset"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func (a *Admin) apiLists(w http.ResponseWriter, r *http.Request) {
 | 
			
		||||
	lists, err := a.Storage.Lists()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		a.handleAPIError(w, r, err)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	a.jsonResponse(w, r, lists)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Admin) apiList(w http.ResponseWriter, r *http.Request) {
 | 
			
		||||
	id, err := strconv.ParseInt(r.PathValue("id"), 10, 64)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		a.handleAPIError(w, r, err)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	list, err := a.Storage.ListByID(id)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		a.handleAPIError(w, r, err)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	a.jsonResponse(w, r, list)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Admin) apiListCreate(w http.ResponseWriter, r *http.Request) {
 | 
			
		||||
	var request struct {
 | 
			
		||||
		dataset.List
 | 
			
		||||
		Groups    []int64   `json:"groups"`
 | 
			
		||||
		ID        int64     `json:"id"`         // mask, not used
 | 
			
		||||
		CreatedAt time.Time `json:"created_at"` // mask, not used
 | 
			
		||||
		UpdatedAt time.Time `json:"updated_at"` // mask, not used
 | 
			
		||||
	}
 | 
			
		||||
	if err := json.NewDecoder(r.Body).Decode(&request); err != nil {
 | 
			
		||||
		a.handleAPIError(w, r, err)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err := a.verifyList(&request.List); err != nil {
 | 
			
		||||
		a.handleAPIError(w, r, err)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	request.List.Groups = request.List.Groups[:0]
 | 
			
		||||
	for _, id := range request.Groups {
 | 
			
		||||
		group, err := a.Storage.GroupByID(id)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			a.handleAPIError(w, r, err)
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		request.List.Groups = append(request.List.Groups, group)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err := a.Storage.SaveList(&request.List); err != nil {
 | 
			
		||||
		a.handleAPIError(w, r, err)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	a.jsonResponse(w, r, request.List)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Admin) apiListDelete(w http.ResponseWriter, r *http.Request) {
 | 
			
		||||
	id, err := strconv.ParseInt(r.PathValue("id"), 10, 64)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		a.handleAPIError(w, r, err)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	list, err := a.Storage.ListByID(id)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		a.handleAPIError(w, r, err)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	if err = a.Storage.DeleteList(list); err != nil {
 | 
			
		||||
		a.handleAPIError(w, r, err)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	a.jsonResponse(w, r, nil)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Admin) verifyList(list *dataset.List) error {
 | 
			
		||||
	switch list.Type {
 | 
			
		||||
	case dataset.ListTypeDomain, dataset.ListTypeNetwork:
 | 
			
		||||
	default:
 | 
			
		||||
		return apiError{Err: fmt.Errorf("unknown list type %q", list.Type)}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										119
									
								
								ca/authority.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										119
									
								
								ca/authority.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,119 @@
 | 
			
		||||
package ca
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"crypto"
 | 
			
		||||
	"crypto/rand"
 | 
			
		||||
	"crypto/tls"
 | 
			
		||||
	"crypto/x509"
 | 
			
		||||
	"crypto/x509/pkix"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"math/big"
 | 
			
		||||
	"net"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"sync"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"git.maze.io/maze/styx/internal/cryptutil"
 | 
			
		||||
	"git.maze.io/maze/styx/logger"
 | 
			
		||||
	"github.com/miekg/dns"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type CertificateAuthority interface {
 | 
			
		||||
	GetCertificate(commonName string, dnsNames []string, ips []net.IP) (*tls.Certificate, error)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type ca struct {
 | 
			
		||||
	cert  *x509.Certificate
 | 
			
		||||
	key   crypto.PrivateKey
 | 
			
		||||
	cache sync.Map
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func Open(certData, keyData string) (CertificateAuthority, error) {
 | 
			
		||||
	cert, key, err := cryptutil.LoadKeyPair(certData, keyData)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	} else if !cert.IsCA {
 | 
			
		||||
		return nil, fmt.Errorf("ca: certificate for %s is not a certificate authority", cert.Subject.String())
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return &ca{
 | 
			
		||||
		cert: cert,
 | 
			
		||||
		key:  key,
 | 
			
		||||
	}, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (ca *ca) GetCertificate(cn string, names []string, ips []net.IP) (*tls.Certificate, error) {
 | 
			
		||||
	var (
 | 
			
		||||
		log = logger.StandardLog.Values(logger.Values{
 | 
			
		||||
			"cn":    cn,
 | 
			
		||||
			"names": names,
 | 
			
		||||
			"ips":   ips,
 | 
			
		||||
		})
 | 
			
		||||
		now    = time.Now().UTC()
 | 
			
		||||
		parent = parentDomain(cn)
 | 
			
		||||
	)
 | 
			
		||||
	if cn == parent {
 | 
			
		||||
		names = append(names, "*."+cn)
 | 
			
		||||
	} else {
 | 
			
		||||
		names = append(names, "*."+parent, cn)
 | 
			
		||||
		cn = parent
 | 
			
		||||
		log = log.Value("cn", cn)
 | 
			
		||||
	}
 | 
			
		||||
	if v, ok := ca.cache.Load(parent); ok {
 | 
			
		||||
		if cert, ok := v.(*tls.Certificate); ok && now.After(cert.Leaf.NotBefore) && now.Before(cert.Leaf.NotAfter.Add(-time.Hour)) {
 | 
			
		||||
			log.Value("valid", cert.Leaf.NotAfter.Sub(now)).Debug("Using cached certificate")
 | 
			
		||||
			return cert, nil
 | 
			
		||||
		}
 | 
			
		||||
		log.Debug("Cached certificate invalid")
 | 
			
		||||
		ca.cache.Delete(parent)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
 | 
			
		||||
	serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, fmt.Errorf("ca: failed to generate serial number: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	notBefore := now.Round(24 * time.Hour)
 | 
			
		||||
	notAfter := notBefore.Add(48 * time.Hour)
 | 
			
		||||
 | 
			
		||||
	log.Values(logger.Values{
 | 
			
		||||
		"serial":  serialNumber.String(),
 | 
			
		||||
		"subject": pkix.Name{CommonName: cn}.String(),
 | 
			
		||||
	}).Debug("Generating certificate")
 | 
			
		||||
	template := &x509.Certificate{
 | 
			
		||||
		SerialNumber: serialNumber,
 | 
			
		||||
		KeyUsage:     x509.KeyUsageDataEncipherment | x509.KeyUsageDigitalSignature,
 | 
			
		||||
		ExtKeyUsage:  []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
 | 
			
		||||
		Subject:      pkix.Name{CommonName: cn},
 | 
			
		||||
		DNSNames:     names,
 | 
			
		||||
		IPAddresses:  ips,
 | 
			
		||||
		PublicKey:    cryptutil.PublicKey(ca.key),
 | 
			
		||||
		NotBefore:    notBefore,
 | 
			
		||||
		NotAfter:     notAfter,
 | 
			
		||||
	}
 | 
			
		||||
	der, err := x509.CreateCertificate(rand.Reader, template, ca.cert, template.PublicKey, ca.key)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	cert, err := x509.ParseCertificate(der)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	output := &tls.Certificate{
 | 
			
		||||
		Certificate: [][]byte{der},
 | 
			
		||||
		Leaf:        cert,
 | 
			
		||||
		PrivateKey:  ca.key,
 | 
			
		||||
	}
 | 
			
		||||
	ca.cache.Store(parent, output)
 | 
			
		||||
	return output, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func parentDomain(name string) string {
 | 
			
		||||
	part := dns.SplitDomainName(name)
 | 
			
		||||
	if len(part) <= 2 {
 | 
			
		||||
		return name
 | 
			
		||||
	}
 | 
			
		||||
	return strings.Join(part[1:], ".")
 | 
			
		||||
}
 | 
			
		||||
@@ -8,6 +8,7 @@ import (
 | 
			
		||||
	"github.com/hashicorp/hcl/v2/gohcl"
 | 
			
		||||
	"github.com/hashicorp/hcl/v2/hclsimple"
 | 
			
		||||
 | 
			
		||||
	"git.maze.io/maze/styx/ca"
 | 
			
		||||
	"git.maze.io/maze/styx/dataset"
 | 
			
		||||
	"git.maze.io/maze/styx/internal/cryptutil"
 | 
			
		||||
	"git.maze.io/maze/styx/logger"
 | 
			
		||||
@@ -18,6 +19,7 @@ import (
 | 
			
		||||
type Config struct {
 | 
			
		||||
	Proxy  ProxyConfig    `hcl:"proxy,block"`
 | 
			
		||||
	Policy []PolicyConfig `hcl:"policy,block"`
 | 
			
		||||
	CA     *CAConfig      `hcl:"ca,block"`
 | 
			
		||||
	Data   DataConfig     `hcl:"data,block"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -145,8 +147,18 @@ type PolicyConfig struct {
 | 
			
		||||
	Package string `hcl:"package,optional"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type CAConfig struct {
 | 
			
		||||
	Cert string `hcl:"cert"`
 | 
			
		||||
	Key  string `hcl:"key,optional"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c CAConfig) CertificateAuthority() (ca.CertificateAuthority, error) {
 | 
			
		||||
	return ca.Open(c.Cert, c.Key)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type DataConfig struct {
 | 
			
		||||
	Path     string              `hcl:"path,optional"`
 | 
			
		||||
	Storage  DataStorageConfig   `hcl:"storage,block"`
 | 
			
		||||
	Domains  []DomainDataConfig  `hcl:"domain,block"`
 | 
			
		||||
	Networks []NetworkDataConfig `hcl:"network,block"`
 | 
			
		||||
}
 | 
			
		||||
@@ -165,6 +177,39 @@ func (c DataConfig) Configure() error {
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c DataConfig) OpenStorage() (dataset.Storage, error) {
 | 
			
		||||
	switch c.Storage.Type {
 | 
			
		||||
	case "", "bolt", "boltdb":
 | 
			
		||||
		var config struct {
 | 
			
		||||
			Path string `hcl:"path"`
 | 
			
		||||
		}
 | 
			
		||||
		if diag := gohcl.DecodeBody(c.Storage.Body, nil, &config); diag.HasErrors() {
 | 
			
		||||
			return nil, diag
 | 
			
		||||
		}
 | 
			
		||||
		//return dataset.OpenBolt(config.Path)
 | 
			
		||||
		return dataset.OpenBStore(config.Path)
 | 
			
		||||
 | 
			
		||||
	/*
 | 
			
		||||
		case "sqlite", "sqlite3":
 | 
			
		||||
			var config struct {
 | 
			
		||||
				Path string `hcl:"path"`
 | 
			
		||||
			}
 | 
			
		||||
			if diag := gohcl.DecodeBody(c.Storage.Body, nil, &config); diag.HasErrors() {
 | 
			
		||||
				return nil, diag
 | 
			
		||||
			}
 | 
			
		||||
			return dataset.OpenSQLite(config.Path)
 | 
			
		||||
	*/
 | 
			
		||||
 | 
			
		||||
	default:
 | 
			
		||||
		return nil, fmt.Errorf("storage: no %q driver", c.Storage.Type)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type DataStorageConfig struct {
 | 
			
		||||
	Type string   `hcl:"type"`
 | 
			
		||||
	Body hcl.Body `hcl:",remain"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type DomainDataConfig struct {
 | 
			
		||||
	Name string   `hcl:"name,label"`
 | 
			
		||||
	Type string   `hcl:"type"`
 | 
			
		||||
 
 | 
			
		||||
@@ -7,6 +7,9 @@ import (
 | 
			
		||||
	"os/signal"
 | 
			
		||||
	"syscall"
 | 
			
		||||
 | 
			
		||||
	"git.maze.io/maze/styx/admin"
 | 
			
		||||
	"git.maze.io/maze/styx/ca"
 | 
			
		||||
	"git.maze.io/maze/styx/dataset"
 | 
			
		||||
	"git.maze.io/maze/styx/logger"
 | 
			
		||||
	"git.maze.io/maze/styx/proxy"
 | 
			
		||||
)
 | 
			
		||||
@@ -40,6 +43,22 @@ func main() {
 | 
			
		||||
		log.Err(err).Fatal("Invalid data configuration")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var ca ca.CertificateAuthority
 | 
			
		||||
	if config.CA != nil {
 | 
			
		||||
		if ca, err = config.CA.CertificateAuthority(); err != nil {
 | 
			
		||||
			log.Err(err).Fatal("Invalid ca configuration")
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var storage dataset.Storage
 | 
			
		||||
	if storage, err = config.Data.OpenStorage(); err != nil {
 | 
			
		||||
		log.Err(err).Fatal("Invalid data.storage configuration")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	admin := &admin.Admin{
 | 
			
		||||
		Storage: storage,
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	proxies, err := config.Proxies(log)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		log.Err(err).Fatal("Error configuring proxy ports")
 | 
			
		||||
@@ -52,6 +71,9 @@ func main() {
 | 
			
		||||
	)
 | 
			
		||||
 | 
			
		||||
	for i, p := range proxies {
 | 
			
		||||
		p.CertificateAuthority = ca
 | 
			
		||||
		p.Storage = storage
 | 
			
		||||
		admin.Install(p)
 | 
			
		||||
		go run(config.Proxy.Port[i].Listen, p, errs)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
@@ -64,12 +86,18 @@ func main() {
 | 
			
		||||
			case syscall.SIGHUP:
 | 
			
		||||
				log.Value("signal", sig.String()).Warn("Ignored reload signal ¯\\_(ツ)_/¯")
 | 
			
		||||
			default:
 | 
			
		||||
				log.Value("signal", sig.String()).Info("Shutting down on signal")
 | 
			
		||||
				return
 | 
			
		||||
				log.Value("signal", sig.String()).Warn("Shutting down on signal")
 | 
			
		||||
				close(done)
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
		case <-done:
 | 
			
		||||
			log.Info("Shutting down gracefully")
 | 
			
		||||
			log.Warn("Shutting down gracefully")
 | 
			
		||||
			for i, p := range proxies {
 | 
			
		||||
				log.Value("port", config.Proxy.Port[i].Listen).Info("Proxy port closing")
 | 
			
		||||
				if err := p.Close(); err != nil {
 | 
			
		||||
					log.Err(err).Error("Error closing proxy")
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
			return
 | 
			
		||||
 | 
			
		||||
		case err = <-errs:
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										1
									
								
								dataset/base.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1
									
								
								dataset/base.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1 @@
 | 
			
		||||
package dataset
 | 
			
		||||
							
								
								
									
										25
									
								
								dataset/error.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										25
									
								
								dataset/error.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,25 @@
 | 
			
		||||
package dataset
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"os"
 | 
			
		||||
 | 
			
		||||
	"github.com/mjl-/bstore"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type ErrNotExist struct {
 | 
			
		||||
	Object string
 | 
			
		||||
	ID     int64
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (err ErrNotExist) Error() string {
 | 
			
		||||
	return fmt.Sprintf("storage: %s not found", err.Object)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func IsNotExist(err error) bool {
 | 
			
		||||
	if err == nil {
 | 
			
		||||
		return false
 | 
			
		||||
	}
 | 
			
		||||
	return os.IsNotExist(err) || errors.Is(err, ErrNotExist{}) || errors.Is(err, bstore.ErrAbsent)
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										53
									
								
								dataset/parser/adblock.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										53
									
								
								dataset/parser/adblock.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,53 @@
 | 
			
		||||
package parser
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bufio"
 | 
			
		||||
	"io"
 | 
			
		||||
	"strings"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func init() {
 | 
			
		||||
	RegisterDomainsParser(adblockDomainsParser{})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type adblockDomainsParser struct{}
 | 
			
		||||
 | 
			
		||||
func (adblockDomainsParser) CanHandle(line string) bool {
 | 
			
		||||
	return strings.HasPrefix(strings.ToLower(line), `[adblock`) ||
 | 
			
		||||
		strings.HasPrefix(line, "@@") || // exception
 | 
			
		||||
		strings.HasPrefix(line, "||") || // blah
 | 
			
		||||
		line[0] == '*'
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (adblockDomainsParser) ParseDomains(r io.Reader) (domains []string, ignored int, err error) {
 | 
			
		||||
	scanner := bufio.NewScanner(r)
 | 
			
		||||
	for scanner.Scan() {
 | 
			
		||||
		line := strings.TrimSpace(scanner.Text())
 | 
			
		||||
		if isComment(line) {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// Common AdBlock patterns:
 | 
			
		||||
		// ||domain.com^
 | 
			
		||||
		// |http://domain.com|
 | 
			
		||||
		// domain.com/path
 | 
			
		||||
		// *domain.com*
 | 
			
		||||
		switch {
 | 
			
		||||
		case strings.HasPrefix(line, `||`): // domain anchor
 | 
			
		||||
			if i := strings.IndexByte(line, '^'); i != -1 {
 | 
			
		||||
				domains = append(domains, line[2:i])
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
		case strings.HasPrefix(line, `|`) && strings.HasSuffix(line, `|`):
 | 
			
		||||
			domains = append(domains, line[1:len(line)-2])
 | 
			
		||||
			continue
 | 
			
		||||
		case strings.HasPrefix(line, `[`):
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		ignored++
 | 
			
		||||
	}
 | 
			
		||||
	if err = scanner.Err(); err != nil {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	return unique(domains), ignored, nil
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										41
									
								
								dataset/parser/adblock_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										41
									
								
								dataset/parser/adblock_test.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,41 @@
 | 
			
		||||
package parser
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"reflect"
 | 
			
		||||
	"sort"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"testing"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func TestAdBlockParser(t *testing.T) {
 | 
			
		||||
	test := `[Adblock Plus 2.0]
 | 
			
		||||
! Title: AdRules DNS List
 | 
			
		||||
! Homepage: https://github.com/Cats-Team/AdRules
 | 
			
		||||
! Powerd by Cats-Team
 | 
			
		||||
! Expires: 1 (update frequency)
 | 
			
		||||
! Description: The DNS Filters
 | 
			
		||||
! Total count: 145270
 | 
			
		||||
! Update: 2025-10-07 02:05:08(GMT+8)
 | 
			
		||||
/^.+stat\.kugou\.com/
 | 
			
		||||
/^admarvel\./
 | 
			
		||||
||*-ad-sign.byteimg.com^
 | 
			
		||||
||*-ad.a.yximgs.com^
 | 
			
		||||
||*-applog.fqnovel.com^
 | 
			
		||||
||*-datareceiver.aki-game.net^
 | 
			
		||||
||*.exaapi.com^`
 | 
			
		||||
	want := []string{"*-ad-sign.byteimg.com", "*-ad.a.yximgs.com", "*-applog.fqnovel.com", "*-datareceiver.aki-game.net", "*.exaapi.com"}
 | 
			
		||||
 | 
			
		||||
	parsed, ignored, err := ParseDomains(strings.NewReader(test))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	sort.Strings(parsed)
 | 
			
		||||
	if !reflect.DeepEqual(parsed, want) {
 | 
			
		||||
		t.Errorf("expected ParseDomains(domains) to return %v, got %v", want, parsed)
 | 
			
		||||
	}
 | 
			
		||||
	if ignored != 2 {
 | 
			
		||||
		t.Errorf("expected 2 ignored, got %d", ignored)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										139
									
								
								dataset/parser/dns.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										139
									
								
								dataset/parser/dns.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,139 @@
 | 
			
		||||
package parser
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bufio"
 | 
			
		||||
	"io"
 | 
			
		||||
	"strings"
 | 
			
		||||
 | 
			
		||||
	"github.com/miekg/dns"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func init() {
 | 
			
		||||
	RegisterDomainsParser(dnsmasqDomainsParser{})
 | 
			
		||||
	RegisterDomainsParser(mosDNSDomainsParser{})
 | 
			
		||||
	RegisterDomainsParser(smartDNSDomainsParser{})
 | 
			
		||||
	RegisterDomainsParser(unboundDomainsParser{})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type dnsmasqDomainsParser struct{}
 | 
			
		||||
 | 
			
		||||
func (dnsmasqDomainsParser) CanHandle(line string) bool {
 | 
			
		||||
	return strings.HasPrefix(line, "address=/")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (dnsmasqDomainsParser) ParseDomains(r io.Reader) (domains []string, ignored int, err error) {
 | 
			
		||||
	scanner := bufio.NewScanner(r)
 | 
			
		||||
	for scanner.Scan() {
 | 
			
		||||
		line := strings.TrimSpace(scanner.Text())
 | 
			
		||||
		if isComment(line) {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		switch {
 | 
			
		||||
		case strings.HasPrefix(line, "address=/"):
 | 
			
		||||
			part := strings.FieldsFunc(line, func(r rune) bool { return r == '/' })
 | 
			
		||||
			if len(part) >= 3 && isDomainName(part[1]) {
 | 
			
		||||
				domains = append(domains, part[1])
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		ignored++
 | 
			
		||||
	}
 | 
			
		||||
	if err = scanner.Err(); err != nil {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	return unique(domains), ignored, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type mosDNSDomainsParser struct{}
 | 
			
		||||
 | 
			
		||||
func (mosDNSDomainsParser) CanHandle(line string) bool {
 | 
			
		||||
	if strings.HasPrefix(line, "domain:") {
 | 
			
		||||
		return isDomainName(line[7:])
 | 
			
		||||
	}
 | 
			
		||||
	return false
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (mosDNSDomainsParser) ParseDomains(r io.Reader) (domains []string, ignored int, err error) {
 | 
			
		||||
	scanner := bufio.NewScanner(r)
 | 
			
		||||
	for scanner.Scan() {
 | 
			
		||||
		line := strings.TrimSpace(scanner.Text())
 | 
			
		||||
		if isComment(line) {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		if strings.HasPrefix(line, "domain:") {
 | 
			
		||||
			domains = append(domains, line[7:])
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		ignored++
 | 
			
		||||
	}
 | 
			
		||||
	if err = scanner.Err(); err != nil {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	return unique(domains), ignored, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type smartDNSDomainsParser struct{}
 | 
			
		||||
 | 
			
		||||
func (smartDNSDomainsParser) CanHandle(line string) bool {
 | 
			
		||||
	return strings.HasPrefix(line, "address /")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (smartDNSDomainsParser) ParseDomains(r io.Reader) (domains []string, ignored int, err error) {
 | 
			
		||||
	scanner := bufio.NewScanner(r)
 | 
			
		||||
	for scanner.Scan() {
 | 
			
		||||
		line := strings.TrimSpace(scanner.Text())
 | 
			
		||||
		if isComment(line) {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		if strings.HasPrefix(line, "address /") {
 | 
			
		||||
			if i := strings.IndexByte(line[9:], '/'); i > -1 {
 | 
			
		||||
				domains = append(domains, line[9:i+9])
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		ignored++
 | 
			
		||||
	}
 | 
			
		||||
	if err = scanner.Err(); err != nil {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	return unique(domains), ignored, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type unboundDomainsParser struct{}
 | 
			
		||||
 | 
			
		||||
func (unboundDomainsParser) CanHandle(line string) bool {
 | 
			
		||||
	return strings.HasPrefix(line, "local-data:") ||
 | 
			
		||||
		strings.HasPrefix(line, "local-zone:")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (unboundDomainsParser) ParseDomains(r io.Reader) (domains []string, ignored int, err error) {
 | 
			
		||||
	scanner := bufio.NewScanner(r)
 | 
			
		||||
	for scanner.Scan() {
 | 
			
		||||
		line := strings.TrimSpace(scanner.Text())
 | 
			
		||||
		if isComment(line) {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		switch {
 | 
			
		||||
		case strings.HasPrefix(line, "local-data:"):
 | 
			
		||||
			record := strings.Trim(strings.TrimSpace(line[11:]), `"`)
 | 
			
		||||
			if rr, err := dns.NewRR(record); err == nil {
 | 
			
		||||
				switch rr.Header().Rrtype {
 | 
			
		||||
				case dns.TypeA, dns.TypeAAAA, dns.TypeCNAME:
 | 
			
		||||
					domains = append(domains, strings.Trim(rr.Header().Name, `.`))
 | 
			
		||||
					continue
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
		case strings.HasPrefix(line, "local-zone:") && strings.HasSuffix(line, " reject"):
 | 
			
		||||
			line = strings.Trim(strings.TrimSpace(line[11:]), `"`)
 | 
			
		||||
			if i := strings.IndexByte(line, '"'); i > -1 {
 | 
			
		||||
				domains = append(domains, line[:i])
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		ignored++
 | 
			
		||||
	}
 | 
			
		||||
	if err = scanner.Err(); err != nil {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	return unique(domains), ignored, nil
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										106
									
								
								dataset/parser/dns_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										106
									
								
								dataset/parser/dns_test.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,106 @@
 | 
			
		||||
package parser
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"reflect"
 | 
			
		||||
	"sort"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"testing"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func TestDNSMasqParser(t *testing.T) {
 | 
			
		||||
	tests := []struct {
 | 
			
		||||
		Name        string
 | 
			
		||||
		Test        string
 | 
			
		||||
		Want        []string
 | 
			
		||||
		WantIgnored int
 | 
			
		||||
	}{
 | 
			
		||||
		{
 | 
			
		||||
			"data",
 | 
			
		||||
			`
 | 
			
		||||
local-data: "junk1.doubleclick.net A 127.0.0.1"
 | 
			
		||||
local-data: "junk2.doubleclick.net A 127.0.0.1"
 | 
			
		||||
local-data: "junk2.doubleclick.net CNAME doubleclick.net."
 | 
			
		||||
local-data: "junk6.doubleclick.net AAAA ::1"
 | 
			
		||||
local-data: "doubleclick.net A 127.0.0.1"
 | 
			
		||||
local-data: "ad.junk1.doubleclick.net A 127.0.0.1"
 | 
			
		||||
local-data: "adjunk.google.com A 127.0.0.1"`,
 | 
			
		||||
			[]string{"ad.junk1.doubleclick.net", "adjunk.google.com", "doubleclick.net", "junk1.doubleclick.net", "junk2.doubleclick.net", "junk6.doubleclick.net"},
 | 
			
		||||
			0,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			"zone",
 | 
			
		||||
			`
 | 
			
		||||
local-zone: "doubleclick.net" reject
 | 
			
		||||
local-zone: "adjunk.google.com" reject`,
 | 
			
		||||
			[]string{"adjunk.google.com", "doubleclick.net"},
 | 
			
		||||
			0,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			"address",
 | 
			
		||||
			`
 | 
			
		||||
address=/ziyu.net/0.0.0.0
 | 
			
		||||
address=/zlp6s.pw/0.0.0.0
 | 
			
		||||
address=/zm232.com/0.0.0.0
 | 
			
		||||
			`,
 | 
			
		||||
			[]string{"ziyu.net", "zlp6s.pw", "zm232.com"},
 | 
			
		||||
			0,
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
	for _, test := range tests {
 | 
			
		||||
		t.Run(test.Name, func(it *testing.T) {
 | 
			
		||||
			parsed, ignored, err := ParseDomains(strings.NewReader(test.Test))
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				t.Fatal(err)
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			sort.Strings(parsed)
 | 
			
		||||
			if !reflect.DeepEqual(parsed, test.Want) {
 | 
			
		||||
				t.Errorf("expected ParseDomains(dnsmasq) to return\n\t%v, got\n\t%v", test.Want, parsed)
 | 
			
		||||
			}
 | 
			
		||||
			if ignored != test.WantIgnored {
 | 
			
		||||
				t.Errorf("expected %d ignored, got %d", test.WantIgnored, ignored)
 | 
			
		||||
			}
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestMOSDNSParser(t *testing.T) {
 | 
			
		||||
	test := `domain:0019x.com
 | 
			
		||||
domain:002777.xyz
 | 
			
		||||
domain:003store.com
 | 
			
		||||
domain:00404850.xyz`
 | 
			
		||||
	want := []string{"0019x.com", "002777.xyz", "003store.com", "00404850.xyz"}
 | 
			
		||||
 | 
			
		||||
	parsed, _, err := ParseDomains(strings.NewReader(test))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	sort.Strings(parsed)
 | 
			
		||||
	if !reflect.DeepEqual(parsed, want) {
 | 
			
		||||
		t.Errorf("expected ParseDomains(domains) to return %v, got %v", want, parsed)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestSmartDNSParser(t *testing.T) {
 | 
			
		||||
	test := `# Title:AdRules SmartDNS List
 | 
			
		||||
# Update: 2025-10-07 02:05:08(GMT+8)
 | 
			
		||||
address /0.myikas.com/#
 | 
			
		||||
address /0.net.easyjet.com/#
 | 
			
		||||
address /0.nextyourcontent.com/#
 | 
			
		||||
address /0019x.com/#`
 | 
			
		||||
	want := []string{"0.myikas.com", "0.net.easyjet.com", "0.nextyourcontent.com", "0019x.com"}
 | 
			
		||||
 | 
			
		||||
	parsed, _, err := ParseDomains(strings.NewReader(test))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	sort.Strings(parsed)
 | 
			
		||||
	if !reflect.DeepEqual(parsed, want) {
 | 
			
		||||
		t.Errorf("expected ParseDomains(domains) to return %v, got %v", want, parsed)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										40
									
								
								dataset/parser/domains.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										40
									
								
								dataset/parser/domains.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,40 @@
 | 
			
		||||
package parser
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bufio"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net"
 | 
			
		||||
	"strings"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func init() {
 | 
			
		||||
	domainsParsers = append(domainsParsers, domainsParser{})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type domainsParser struct{}
 | 
			
		||||
 | 
			
		||||
func (domainsParser) CanHandle(line string) bool {
 | 
			
		||||
	return isDomainName(line) &&
 | 
			
		||||
		!strings.ContainsRune(line, ' ') &&
 | 
			
		||||
		!strings.ContainsRune(line, ':') &&
 | 
			
		||||
		net.ParseIP(line) == nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (domainsParser) ParseDomains(r io.Reader) (domains []string, ignored int, err error) {
 | 
			
		||||
	scanner := bufio.NewScanner(r)
 | 
			
		||||
	for scanner.Scan() {
 | 
			
		||||
		line := strings.TrimSpace(scanner.Text())
 | 
			
		||||
		if isComment(line) {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		if isDomainName(line) {
 | 
			
		||||
			domains = append(domains, line)
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		ignored++
 | 
			
		||||
	}
 | 
			
		||||
	if err = scanner.Err(); err != nil {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	return unique(domains), ignored, nil
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										31
									
								
								dataset/parser/domains_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										31
									
								
								dataset/parser/domains_test.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,31 @@
 | 
			
		||||
package parser
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"reflect"
 | 
			
		||||
	"sort"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"testing"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func TestParseDomains(t *testing.T) {
 | 
			
		||||
	test := `# This is a comment
 | 
			
		||||
facebook.com
 | 
			
		||||
tiktok.com
 | 
			
		||||
bogus ignored
 | 
			
		||||
youtube.com`
 | 
			
		||||
	want := []string{"facebook.com", "tiktok.com", "youtube.com"}
 | 
			
		||||
 | 
			
		||||
	parsed, ignored, err := ParseDomains(strings.NewReader(test))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	sort.Strings(parsed)
 | 
			
		||||
	if !reflect.DeepEqual(parsed, want) {
 | 
			
		||||
		t.Errorf("expected ParseDomains(domains) to return %v, got %v", want, parsed)
 | 
			
		||||
	}
 | 
			
		||||
	if ignored != 1 {
 | 
			
		||||
		t.Errorf("expected 1 ignored, got %d", ignored)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										41
									
								
								dataset/parser/hosts.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										41
									
								
								dataset/parser/hosts.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,41 @@
 | 
			
		||||
package parser
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bufio"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net"
 | 
			
		||||
	"strings"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func init() {
 | 
			
		||||
	RegisterDomainsParser(hostsParser{})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type hostsParser struct{}
 | 
			
		||||
 | 
			
		||||
func (hostsParser) CanHandle(line string) bool {
 | 
			
		||||
	part := strings.Fields(line)
 | 
			
		||||
	return len(part) >= 2 && net.ParseIP(part[0]) != nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (hostsParser) ParseDomains(r io.Reader) (domains []string, ignored int, err error) {
 | 
			
		||||
	scanner := bufio.NewScanner(r)
 | 
			
		||||
	for scanner.Scan() {
 | 
			
		||||
		line := strings.TrimSpace(scanner.Text())
 | 
			
		||||
		if isComment(line) {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		part := strings.Fields(line)
 | 
			
		||||
		if len(part) >= 2 && net.ParseIP(part[0]) != nil {
 | 
			
		||||
			domains = append(domains, part[1:]...)
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		ignored++
 | 
			
		||||
	}
 | 
			
		||||
	if err = scanner.Err(); err != nil {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	return unique(domains), ignored, nil
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										38
									
								
								dataset/parser/hosts_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										38
									
								
								dataset/parser/hosts_test.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,38 @@
 | 
			
		||||
package parser
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"reflect"
 | 
			
		||||
	"sort"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"testing"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func TestParseHosts(t *testing.T) {
 | 
			
		||||
	test := `##
 | 
			
		||||
# Host Database
 | 
			
		||||
#
 | 
			
		||||
# localhost is used to configure the loopback interface
 | 
			
		||||
# when the system is booting.  Do not change this entry.
 | 
			
		||||
##
 | 
			
		||||
127.0.0.1       localhost dragon dragon.local dragon.maze.network
 | 
			
		||||
255.255.255.255 broadcasthost
 | 
			
		||||
::1             localhost
 | 
			
		||||
ff00::1         multicast
 | 
			
		||||
1.2.3.4
 | 
			
		||||
`
 | 
			
		||||
	want := []string{"broadcasthost", "dragon", "dragon.local", "dragon.maze.network", "localhost", "multicast"}
 | 
			
		||||
 | 
			
		||||
	parsed, ignored, err := ParseDomains(strings.NewReader(test))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	sort.Strings(parsed)
 | 
			
		||||
	if !reflect.DeepEqual(parsed, want) {
 | 
			
		||||
		t.Errorf("expected ParseDomains(hosts) to return %v, got %v", want, parsed)
 | 
			
		||||
	}
 | 
			
		||||
	if ignored != 1 {
 | 
			
		||||
		t.Errorf("expected 1 ignored, got %d", ignored)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										76
									
								
								dataset/parser/parser.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										76
									
								
								dataset/parser/parser.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,76 @@
 | 
			
		||||
package parser
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bufio"
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"io"
 | 
			
		||||
	"log"
 | 
			
		||||
	"strings"
 | 
			
		||||
 | 
			
		||||
	"github.com/miekg/dns"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var ErrNoParser = errors.New("no suitable parser could be found")
 | 
			
		||||
 | 
			
		||||
type Parser interface {
 | 
			
		||||
	CanHandle(line string) bool
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type DomainsParser interface {
 | 
			
		||||
	Parser
 | 
			
		||||
	ParseDomains(io.Reader) (domains []string, ignored int, err error)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var domainsParsers []DomainsParser
 | 
			
		||||
 | 
			
		||||
func RegisterDomainsParser(parser DomainsParser) {
 | 
			
		||||
	domainsParsers = append(domainsParsers, parser)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func ParseDomains(r io.Reader) (domains []string, ignored int, err error) {
 | 
			
		||||
	var (
 | 
			
		||||
		buffer  = new(bytes.Buffer)
 | 
			
		||||
		scanner = bufio.NewScanner(io.TeeReader(r, buffer))
 | 
			
		||||
		line    string
 | 
			
		||||
		parser  DomainsParser
 | 
			
		||||
	)
 | 
			
		||||
	for scanner.Scan() {
 | 
			
		||||
		line = strings.TrimSpace(scanner.Text())
 | 
			
		||||
		if isComment(line) {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		for _, parser = range domainsParsers {
 | 
			
		||||
			if parser.CanHandle(line) {
 | 
			
		||||
				log.Printf("using parser %T", parser)
 | 
			
		||||
				return parser.ParseDomains(io.MultiReader(buffer, r))
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		break
 | 
			
		||||
	}
 | 
			
		||||
	return nil, 0, ErrNoParser
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func isComment(line string) bool {
 | 
			
		||||
	return line == "" || line[0] == '#' || line[0] == '!'
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func isDomainName(name string) bool {
 | 
			
		||||
	n, ok := dns.IsDomainName(name)
 | 
			
		||||
	return n >= 2 && ok
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func unique(strings []string) []string {
 | 
			
		||||
	if strings == nil {
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
	v := make(map[string]struct{})
 | 
			
		||||
	for _, s := range strings {
 | 
			
		||||
		v[s] = struct{}{}
 | 
			
		||||
	}
 | 
			
		||||
	o := make([]string, 0, len(v))
 | 
			
		||||
	for k := range v {
 | 
			
		||||
		o = append(o, k)
 | 
			
		||||
	}
 | 
			
		||||
	return o
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										31
									
								
								dataset/parser/parser_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										31
									
								
								dataset/parser/parser_test.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,31 @@
 | 
			
		||||
package parser
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"reflect"
 | 
			
		||||
	"sort"
 | 
			
		||||
	"testing"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func TestUnique(t *testing.T) {
 | 
			
		||||
	tests := []struct {
 | 
			
		||||
		Name string
 | 
			
		||||
		Test []string
 | 
			
		||||
		Want []string
 | 
			
		||||
	}{
 | 
			
		||||
		{"nil", nil, nil},
 | 
			
		||||
		{"single", []string{"test"}, []string{"test"}},
 | 
			
		||||
		{"duplicate", []string{"test", "test"}, []string{"test"}},
 | 
			
		||||
		{"multiple", []string{"a", "a", "b", "b", "b", "c"}, []string{"a", "b", "c"}},
 | 
			
		||||
	}
 | 
			
		||||
	for _, test := range tests {
 | 
			
		||||
		t.Run(test.Name, func(it *testing.T) {
 | 
			
		||||
			v := unique(test.Test)
 | 
			
		||||
			if v != nil {
 | 
			
		||||
				sort.Strings(v)
 | 
			
		||||
			}
 | 
			
		||||
			if !reflect.DeepEqual(v, test.Want) {
 | 
			
		||||
				it.Errorf("expected unique(%v) to return %v, got %v", test.Test, test.Want, v)
 | 
			
		||||
			}
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										231
									
								
								dataset/storage.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										231
									
								
								dataset/storage.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,231 @@
 | 
			
		||||
package dataset
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bufio"
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"io"
 | 
			
		||||
	"io/fs"
 | 
			
		||||
	"net"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"net/url"
 | 
			
		||||
	"os"
 | 
			
		||||
	"slices"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	_ "github.com/mattn/go-sqlite3" // SQLite3 driver
 | 
			
		||||
	"github.com/miekg/dns"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type Storage interface {
 | 
			
		||||
	Groups() ([]Group, error)
 | 
			
		||||
	GroupByID(int64) (Group, error)
 | 
			
		||||
	GroupByName(name string) (Group, error)
 | 
			
		||||
	SaveGroup(*Group) error
 | 
			
		||||
	DeleteGroup(Group) error
 | 
			
		||||
 | 
			
		||||
	Clients() (Clients, error)
 | 
			
		||||
	ClientByID(int64) (Client, error)
 | 
			
		||||
	ClientByIP(net.IP) (Client, error)
 | 
			
		||||
	SaveClient(*Client) error
 | 
			
		||||
	DeleteClient(Client) error
 | 
			
		||||
 | 
			
		||||
	Lists() ([]List, error)
 | 
			
		||||
	ListByID(int64) (List, error)
 | 
			
		||||
	SaveList(*List) error
 | 
			
		||||
	DeleteList(List) error
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Group struct {
 | 
			
		||||
	ID          int64     `json:"id"`
 | 
			
		||||
	Name        string    `json:"name" bstore:"nonzero,unique"`
 | 
			
		||||
	IsEnabled   bool      `json:"is_enabled" bstore:"nonzero"`
 | 
			
		||||
	Description string    `json:"description"`
 | 
			
		||||
	CreatedAt   time.Time `json:"created_at" bstore:"nonzero"`
 | 
			
		||||
	UpdatedAt   time.Time `json:"updated_at" bstore:"nonzero"`
 | 
			
		||||
	Storage     Storage   `json:"-" bstore:"-"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Client struct {
 | 
			
		||||
	ID          int64     `json:"id"`
 | 
			
		||||
	Network     string    `json:"network" bstore:"nonzero,index"`
 | 
			
		||||
	IP          string    `json:"ip" bstore:"nonzero,unique IP+Mask"`
 | 
			
		||||
	Mask        int       `json:"mask"`
 | 
			
		||||
	Description string    `json:"description"`
 | 
			
		||||
	Groups      []Group   `json:"groups,omitempty" bstore:"-"`
 | 
			
		||||
	CreatedAt   time.Time `json:"created_at" bstore:"nonzero"`
 | 
			
		||||
	UpdatedAt   time.Time `json:"updated_at" bstore:"nonzero"`
 | 
			
		||||
	Storage     Storage   `json:"-" bstore:"-"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type WithClient interface {
 | 
			
		||||
	Client() (Client, error)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type ClientGroup struct {
 | 
			
		||||
	ID       int64 `json:"id"`
 | 
			
		||||
	ClientID int64 `json:"client_id" bstore:"ref Client,index"`
 | 
			
		||||
	GroupID  int64 `json:"group_id" bstore:"ref Group,index"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *Client) ContainsIP(ip net.IP) bool {
 | 
			
		||||
	ipnet := &net.IPNet{
 | 
			
		||||
		IP:   net.ParseIP(c.IP),
 | 
			
		||||
		Mask: net.CIDRMask(int(c.Mask), 32),
 | 
			
		||||
	}
 | 
			
		||||
	if ipnet.IP == nil {
 | 
			
		||||
		return false
 | 
			
		||||
	}
 | 
			
		||||
	return ipnet.Contains(ip)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *Client) String() string {
 | 
			
		||||
	ipnet := &net.IPNet{
 | 
			
		||||
		IP:   net.ParseIP(c.IP),
 | 
			
		||||
		Mask: net.CIDRMask(int(c.Mask), 32),
 | 
			
		||||
	}
 | 
			
		||||
	return ipnet.String()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Clients []Client
 | 
			
		||||
 | 
			
		||||
func (cs Clients) ByIP(ip net.IP) *Client {
 | 
			
		||||
	var candidates []*Client
 | 
			
		||||
	for _, c := range cs {
 | 
			
		||||
		if c.ContainsIP(ip) {
 | 
			
		||||
			candidates = append(candidates, &c)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	switch len(candidates) {
 | 
			
		||||
	case 0:
 | 
			
		||||
		return nil
 | 
			
		||||
	case 1:
 | 
			
		||||
		return candidates[0]
 | 
			
		||||
	default:
 | 
			
		||||
		slices.SortStableFunc(candidates, func(a, b *Client) int {
 | 
			
		||||
			return int(b.Mask) - int(a.Mask)
 | 
			
		||||
		})
 | 
			
		||||
		return candidates[0]
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	ListTypeDomain  = "domain"
 | 
			
		||||
	ListTypeNetwork = "network"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	MinListRefresh     = 1 * time.Minute
 | 
			
		||||
	DefaultListRefresh = 30 * time.Minute
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type List struct {
 | 
			
		||||
	ID           int64         `json:"id"`
 | 
			
		||||
	Type         string        `json:"type"`
 | 
			
		||||
	Source       string        `json:"source"`
 | 
			
		||||
	IsEnabled    bool          `json:"is_enabled"`
 | 
			
		||||
	Permit       bool          `json:"permit"`
 | 
			
		||||
	Groups       []Group       `json:"groups,omitempty" bstore:"-"`
 | 
			
		||||
	Status       int           `json:"status"`
 | 
			
		||||
	Comment      string        `json:"comment"`
 | 
			
		||||
	Cache        []byte        `json:"cache"`
 | 
			
		||||
	Refresh      time.Duration `json:"refresh"`
 | 
			
		||||
	LastModified time.Time     `json:"last_modified"`
 | 
			
		||||
	CreatedAt    time.Time     `json:"created_at"`
 | 
			
		||||
	UpdatedAt    time.Time     `json:"updated_at"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (list *List) Domains() (*DomainTree, error) {
 | 
			
		||||
	if list.Type != ListTypeDomain {
 | 
			
		||||
		return nil, nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var (
 | 
			
		||||
		tree = NewDomainList()
 | 
			
		||||
		scan = bufio.NewScanner(bytes.NewReader(list.Cache))
 | 
			
		||||
	)
 | 
			
		||||
	for scan.Scan() {
 | 
			
		||||
		line := strings.TrimSpace(scan.Text())
 | 
			
		||||
		if line == "" || line[0] == '#' {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		if labels, ok := dns.IsDomainName(line); ok && labels >= 2 {
 | 
			
		||||
			tree.Add(line)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	if err := scan.Err(); err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	return tree, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (list *List) Update() (updated bool, err error) {
 | 
			
		||||
	u, err := url.Parse(list.Source)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return false, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	switch u.Scheme {
 | 
			
		||||
	case "", "file":
 | 
			
		||||
		return list.updateFile(u.Path)
 | 
			
		||||
	case "http", "https":
 | 
			
		||||
		return list.updateHTTP(u.String())
 | 
			
		||||
	default:
 | 
			
		||||
		return false, fmt.Errorf("dataset: don't know how to update %s sources", u.Scheme)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (list *List) updateFile(name string) (updated bool, err error) {
 | 
			
		||||
	var info fs.FileInfo
 | 
			
		||||
	if info, err = os.Stat(name); err != nil {
 | 
			
		||||
		return
 | 
			
		||||
	} else if info.IsDir() {
 | 
			
		||||
		return false, fmt.Errorf("dataset: list %d: %q is a directory", list.ID, name)
 | 
			
		||||
	}
 | 
			
		||||
	if updated = info.ModTime().After(list.UpdatedAt); !updated {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	list.Cache, _ = os.ReadFile(name)
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (list *List) updateHTTP(url string) (updated bool, err error) {
 | 
			
		||||
	if updated, err = list.shouldUpdateHTTP(url); err != nil || !updated {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var response *http.Response
 | 
			
		||||
	if response, err = http.DefaultClient.Get(url); err != nil {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	defer response.Body.Close()
 | 
			
		||||
	if list.Cache, err = io.ReadAll(response.Body); err != nil {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	return true, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (list *List) shouldUpdateHTTP(url string) (updated bool, err error) {
 | 
			
		||||
	var response *http.Response
 | 
			
		||||
	if response, err = http.DefaultClient.Head(url); err != nil {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	defer response.Body.Close()
 | 
			
		||||
 | 
			
		||||
	if value := response.Header.Get("Last-Modified"); value != "" {
 | 
			
		||||
		var lastModified time.Time
 | 
			
		||||
		if lastModified, err = time.Parse(http.TimeFormat, value); err == nil {
 | 
			
		||||
			return lastModified.After(list.LastModified), nil
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// There are no headers that would indicate last-modified time, so assume we have to update:
 | 
			
		||||
	return true, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type ListGroup struct {
 | 
			
		||||
	ID      int64 `json:"id"`
 | 
			
		||||
	ListID  int64 `json:"list_id" bstore:"ref List,index"`
 | 
			
		||||
	GroupID int64 `json:"group_id" bstore:"ref Group,index"`
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										412
									
								
								dataset/storage_bstore.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										412
									
								
								dataset/storage_bstore.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,412 @@
 | 
			
		||||
package dataset
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"net"
 | 
			
		||||
	"path/filepath"
 | 
			
		||||
	"slices"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"git.maze.io/maze/styx/logger"
 | 
			
		||||
	"github.com/mjl-/bstore"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type bstoreStorage struct {
 | 
			
		||||
	db   *bstore.DB
 | 
			
		||||
	path string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func OpenBStore(name string) (Storage, error) {
 | 
			
		||||
	if !filepath.IsAbs(name) {
 | 
			
		||||
		var err error
 | 
			
		||||
		if name, err = filepath.Abs(name); err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	ctx := context.Background()
 | 
			
		||||
	db, err := bstore.Open(ctx, name, nil,
 | 
			
		||||
		Group{},
 | 
			
		||||
		Client{},
 | 
			
		||||
		ClientGroup{},
 | 
			
		||||
		List{},
 | 
			
		||||
		ListGroup{},
 | 
			
		||||
	)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var (
 | 
			
		||||
		s              = &bstoreStorage{db: db, path: name}
 | 
			
		||||
		defaultGroup   Group
 | 
			
		||||
		defaultClient4 Client
 | 
			
		||||
		defaultClient6 Client
 | 
			
		||||
	)
 | 
			
		||||
 | 
			
		||||
	if defaultGroup, err = s.GroupByName("Default"); errors.Is(err, bstore.ErrAbsent) {
 | 
			
		||||
		defaultGroup = Group{
 | 
			
		||||
			Name:        "Default",
 | 
			
		||||
			IsEnabled:   true,
 | 
			
		||||
			Description: "Default group",
 | 
			
		||||
		}
 | 
			
		||||
		if err = s.SaveGroup(&defaultGroup); err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
	} else if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	if defaultClient4, err = bstore.QueryDB[Client](ctx, db).
 | 
			
		||||
		FilterEqual("Network", "ipv4").
 | 
			
		||||
		FilterFn(func(client Client) bool {
 | 
			
		||||
			return net.ParseIP(client.IP).Equal(net.ParseIP("0.0.0.0")) && client.Mask == 0
 | 
			
		||||
		}).Get(); errors.Is(err, bstore.ErrAbsent) {
 | 
			
		||||
		defaultClient4 = Client{
 | 
			
		||||
			Network:     "ipv4",
 | 
			
		||||
			IP:          "0.0.0.0",
 | 
			
		||||
			Mask:        0,
 | 
			
		||||
			Description: "All IPv4 clients",
 | 
			
		||||
		}
 | 
			
		||||
		if err = s.SaveClient(&defaultClient4); err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
		if err = s.db.Insert(ctx, &ClientGroup{ClientID: defaultClient4.ID, GroupID: defaultGroup.ID}); err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
	} else if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	if defaultClient6, err = bstore.QueryDB[Client](ctx, db).
 | 
			
		||||
		FilterEqual("Network", "ipv6").
 | 
			
		||||
		FilterFn(func(client Client) bool {
 | 
			
		||||
			return net.ParseIP(client.IP).Equal(net.ParseIP("::")) && client.Mask == 0
 | 
			
		||||
		}).Get(); errors.Is(err, bstore.ErrAbsent) {
 | 
			
		||||
		defaultClient6 = Client{
 | 
			
		||||
			Network:     "ipv6",
 | 
			
		||||
			IP:          "::",
 | 
			
		||||
			Mask:        0,
 | 
			
		||||
			Description: "All IPv6 clients",
 | 
			
		||||
		}
 | 
			
		||||
		if err = s.SaveClient(&defaultClient6); err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
		if err = s.db.Insert(ctx, &ClientGroup{ClientID: defaultClient6.ID, GroupID: defaultGroup.ID}); err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
	} else if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Start updater
 | 
			
		||||
	NewUpdater(s)
 | 
			
		||||
 | 
			
		||||
	return s, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *bstoreStorage) log() logger.Structured {
 | 
			
		||||
	return logger.StandardLog.Values(logger.Values{
 | 
			
		||||
		"storage":      "bstore",
 | 
			
		||||
		"storage_path": s.path,
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *bstoreStorage) Groups() ([]Group, error) {
 | 
			
		||||
	var (
 | 
			
		||||
		ctx    = context.Background()
 | 
			
		||||
		query  = bstore.QueryDB[Group](ctx, s.db)
 | 
			
		||||
		groups = make([]Group, 0)
 | 
			
		||||
	)
 | 
			
		||||
	for group := range query.All() {
 | 
			
		||||
		groups = append(groups, group)
 | 
			
		||||
	}
 | 
			
		||||
	if err := query.Err(); err != nil && !errors.Is(err, bstore.ErrFinished) {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	return groups, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *bstoreStorage) GroupByID(id int64) (Group, error) {
 | 
			
		||||
	ctx := context.Background()
 | 
			
		||||
	return bstore.QueryDB[Group](ctx, s.db).FilterID(id).Get()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *bstoreStorage) GroupByName(name string) (Group, error) {
 | 
			
		||||
	ctx := context.Background()
 | 
			
		||||
	return bstore.QueryDB[Group](ctx, s.db).FilterFn(func(group Group) bool {
 | 
			
		||||
		return strings.EqualFold(group.Name, name)
 | 
			
		||||
	}).Get()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *bstoreStorage) SaveGroup(group *Group) (err error) {
 | 
			
		||||
	ctx := context.Background()
 | 
			
		||||
	group.UpdatedAt = time.Now().UTC()
 | 
			
		||||
	if group.CreatedAt.Equal(time.Time{}) {
 | 
			
		||||
		group.CreatedAt = group.UpdatedAt
 | 
			
		||||
		err = s.db.Insert(ctx, group)
 | 
			
		||||
	} else {
 | 
			
		||||
		err = s.db.Update(ctx, group)
 | 
			
		||||
	}
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return fmt.Errorf("dataset: save group %s failed: %w", group.Name, err)
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *bstoreStorage) DeleteGroup(group Group) (err error) {
 | 
			
		||||
	ctx := context.Background()
 | 
			
		||||
	tx, err := s.db.Begin(ctx, true)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	if _, err = bstore.QueryTx[ClientGroup](tx).FilterEqual("GroupID", group.ID).Delete(); err != nil {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	if _, err = bstore.QueryTx[ListGroup](tx).FilterEqual("GroupID", group.ID).Delete(); err != nil {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	if err = tx.Delete(group); err != nil {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	return tx.Commit()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *bstoreStorage) Clients() (Clients, error) {
 | 
			
		||||
	var (
 | 
			
		||||
		ctx     = context.Background()
 | 
			
		||||
		query   = bstore.QueryDB[Client](ctx, s.db)
 | 
			
		||||
		clients = make(Clients, 0)
 | 
			
		||||
	)
 | 
			
		||||
	for client := range query.All() {
 | 
			
		||||
		clients = append(clients, client)
 | 
			
		||||
	}
 | 
			
		||||
	if err := query.Err(); err != nil && !errors.Is(err, bstore.ErrFinished) {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	return clients, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *bstoreStorage) ClientByID(id int64) (Client, error) {
 | 
			
		||||
	ctx := context.Background()
 | 
			
		||||
	client, err := bstore.QueryDB[Client](ctx, s.db).FilterID(id).Get()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return client, err
 | 
			
		||||
	}
 | 
			
		||||
	return s.clientResolveGroups(ctx, client)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *bstoreStorage) ClientByIP(ip net.IP) (Client, error) {
 | 
			
		||||
	if ip == nil {
 | 
			
		||||
		return Client{}, ErrNotExist{Object: "client"}
 | 
			
		||||
	}
 | 
			
		||||
	var (
 | 
			
		||||
		ctx     = context.Background()
 | 
			
		||||
		clients Clients
 | 
			
		||||
		network string
 | 
			
		||||
	)
 | 
			
		||||
	if ip4 := ip.To4(); ip4 != nil {
 | 
			
		||||
		network = "ipv4"
 | 
			
		||||
	} else if ip6 := ip.To16(); ip6 != nil {
 | 
			
		||||
		network = "ipv6"
 | 
			
		||||
	}
 | 
			
		||||
	if network == "" {
 | 
			
		||||
		return Client{}, ErrNotExist{Object: "client"}
 | 
			
		||||
	}
 | 
			
		||||
	for client, err := range bstore.QueryDB[Client](ctx, s.db).
 | 
			
		||||
		FilterEqual("Network", network).
 | 
			
		||||
		FilterFn(func(client Client) bool {
 | 
			
		||||
			return client.ContainsIP(ip)
 | 
			
		||||
		}).All() {
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return Client{}, err
 | 
			
		||||
		}
 | 
			
		||||
		clients = append(clients, client)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var client Client
 | 
			
		||||
	switch len(clients) {
 | 
			
		||||
	case 0:
 | 
			
		||||
		return Client{}, ErrNotExist{Object: "client"}
 | 
			
		||||
	case 1:
 | 
			
		||||
		client = clients[0]
 | 
			
		||||
	default:
 | 
			
		||||
		slices.SortStableFunc(clients, func(a, b Client) int {
 | 
			
		||||
			return int(b.Mask) - int(a.Mask)
 | 
			
		||||
		})
 | 
			
		||||
		client = clients[0]
 | 
			
		||||
	}
 | 
			
		||||
	return s.clientResolveGroups(ctx, client)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *bstoreStorage) clientResolveGroups(ctx context.Context, client Client) (Client, error) {
 | 
			
		||||
	for clientGroup, err := range bstore.QueryDB[ClientGroup](ctx, s.db).FilterEqual("ClientID", client.ID).All() {
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return Client{}, err
 | 
			
		||||
		}
 | 
			
		||||
		if group, err := s.GroupByID(clientGroup.GroupID); err == nil {
 | 
			
		||||
			client.Groups = append(client.Groups, group)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return client, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *bstoreStorage) SaveClient(client *Client) (err error) {
 | 
			
		||||
	log := s.log()
 | 
			
		||||
	ctx := context.Background()
 | 
			
		||||
	client.UpdatedAt = time.Now().UTC()
 | 
			
		||||
 | 
			
		||||
	tx, err := s.db.Begin(ctx, true)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	log = log.Values(logger.Values{"ip": client.IP, "mask": client.Mask, "description": client.Description})
 | 
			
		||||
	if client.CreatedAt.Equal(time.Time{}) {
 | 
			
		||||
		log.Debug("Create client")
 | 
			
		||||
		client.CreatedAt = client.UpdatedAt
 | 
			
		||||
		if err = tx.Insert(client); err != nil {
 | 
			
		||||
			return fmt.Errorf("dataset: client insert failed: %w", err)
 | 
			
		||||
		}
 | 
			
		||||
	} else {
 | 
			
		||||
		log.Debug("Update client")
 | 
			
		||||
		if err = tx.Update(client); err != nil {
 | 
			
		||||
			return fmt.Errorf("dataset: client update failed: %w", err)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var deleted int
 | 
			
		||||
	if deleted, err = bstore.QueryTx[ClientGroup](tx).FilterEqual("ClientID", client.ID).Delete(); err != nil {
 | 
			
		||||
		return fmt.Errorf("dataset: client groups delete failed: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
	log.Debugf("Deleted %d groups", deleted)
 | 
			
		||||
	log.Debugf("Linking %d groups", len(client.Groups))
 | 
			
		||||
	for _, group := range client.Groups {
 | 
			
		||||
		if err = tx.Insert(&ClientGroup{ClientID: client.ID, GroupID: group.ID}); err != nil {
 | 
			
		||||
			return fmt.Errorf("dataset: client groups insert failed: %w", err)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return tx.Commit()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *bstoreStorage) DeleteClient(client Client) (err error) {
 | 
			
		||||
	ctx := context.Background()
 | 
			
		||||
	tx, err := s.db.Begin(ctx, true)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	if _, err = bstore.QueryTx[ClientGroup](tx).FilterEqual("ClientID", client.ID).Delete(); err != nil {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	if err = tx.Delete(client); err != nil {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	return tx.Commit()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *bstoreStorage) Lists() ([]List, error) {
 | 
			
		||||
	var (
 | 
			
		||||
		ctx   = context.Background()
 | 
			
		||||
		query = bstore.QueryDB[List](ctx, s.db)
 | 
			
		||||
		lists = make([]List, 0)
 | 
			
		||||
	)
 | 
			
		||||
	for list := range query.All() {
 | 
			
		||||
		lists = append(lists, list)
 | 
			
		||||
	}
 | 
			
		||||
	if err := query.Err(); err != nil && !errors.Is(err, bstore.ErrFinished) {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	return lists, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *bstoreStorage) ListByID(id int64) (List, error) {
 | 
			
		||||
	ctx := context.Background()
 | 
			
		||||
	list, err := bstore.QueryDB[List](ctx, s.db).FilterID(id).Get()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return list, err
 | 
			
		||||
	}
 | 
			
		||||
	return s.listResolveGroups(ctx, list)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *bstoreStorage) listResolveGroups(ctx context.Context, list List) (List, error) {
 | 
			
		||||
	for listGroup, err := range bstore.QueryDB[ListGroup](ctx, s.db).FilterEqual("ListID", list.ID).All() {
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return List{}, err
 | 
			
		||||
		}
 | 
			
		||||
		if group, err := s.GroupByID(listGroup.GroupID); err == nil {
 | 
			
		||||
			list.Groups = append(list.Groups, group)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return list, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *bstoreStorage) SaveList(list *List) (err error) {
 | 
			
		||||
	if list.Type != ListTypeDomain && list.Type != ListTypeNetwork {
 | 
			
		||||
		return fmt.Errorf("storage: unknown list type %q", list.Type)
 | 
			
		||||
	}
 | 
			
		||||
	if list.Refresh == 0 {
 | 
			
		||||
		list.Refresh = DefaultListRefresh
 | 
			
		||||
	} else if list.Refresh < MinListRefresh {
 | 
			
		||||
		list.Refresh = MinListRefresh
 | 
			
		||||
	}
 | 
			
		||||
	list.UpdatedAt = time.Now().UTC()
 | 
			
		||||
 | 
			
		||||
	ctx := context.Background()
 | 
			
		||||
	tx, err := s.db.Begin(ctx, true)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	log := s.log()
 | 
			
		||||
	log = log.Values(logger.Values{
 | 
			
		||||
		"type":       list.Type,
 | 
			
		||||
		"source":     list.Source,
 | 
			
		||||
		"is_enabled": list.IsEnabled,
 | 
			
		||||
		"status":     list.Status,
 | 
			
		||||
		"cache":      len(list.Cache),
 | 
			
		||||
		"refresh":    list.Refresh,
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	if list.CreatedAt.Equal(time.Time{}) {
 | 
			
		||||
		log.Debug("Creating list")
 | 
			
		||||
		list.CreatedAt = list.UpdatedAt
 | 
			
		||||
		if err = tx.Insert(list); err != nil {
 | 
			
		||||
			return fmt.Errorf("dataset: list insert failed: %w", err)
 | 
			
		||||
		}
 | 
			
		||||
	} else {
 | 
			
		||||
		log.Debug("Updating list")
 | 
			
		||||
		if err = tx.Update(list); err != nil {
 | 
			
		||||
			return fmt.Errorf("dataset: list update failed: %w", err)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var deleted int
 | 
			
		||||
	if deleted, err = bstore.QueryTx[ListGroup](tx).FilterEqual("ListID", list.ID).Delete(); err != nil {
 | 
			
		||||
		return fmt.Errorf("dataset: list groups delete failed: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
	log.Debugf("Deleted %d groups", deleted)
 | 
			
		||||
	log.Debugf("Linking %d groups", len(list.Groups))
 | 
			
		||||
	for _, group := range list.Groups {
 | 
			
		||||
		if err = tx.Insert(&ListGroup{ListID: list.ID, GroupID: group.ID}); err != nil {
 | 
			
		||||
			return fmt.Errorf("dataset: list groups insert failed: %w", err)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return tx.Commit()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *bstoreStorage) DeleteList(list List) (err error) {
 | 
			
		||||
	ctx := context.Background()
 | 
			
		||||
	tx, err := s.db.Begin(ctx, true)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	if _, err = bstore.QueryTx[ListGroup](tx).FilterEqual("ListID", list.ID).Delete(); err != nil {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	if err = tx.Delete(list); err != nil {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	return tx.Commit()
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										226
									
								
								dataset/updater.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										226
									
								
								dataset/updater.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,226 @@
 | 
			
		||||
package dataset
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"net/url"
 | 
			
		||||
	"os"
 | 
			
		||||
	"sync"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"git.maze.io/maze/styx/logger"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type Updater struct {
 | 
			
		||||
	storage  Storage
 | 
			
		||||
	lists    sync.Map // map[int64]List
 | 
			
		||||
	updaters sync.Map // map[int64]*updaterJob
 | 
			
		||||
	done     chan struct{}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewUpdater(storage Storage) *Updater {
 | 
			
		||||
	u := &Updater{
 | 
			
		||||
		storage: storage,
 | 
			
		||||
		done:    make(chan struct{}, 1),
 | 
			
		||||
	}
 | 
			
		||||
	go u.refresh()
 | 
			
		||||
	return u
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (u *Updater) Close() error {
 | 
			
		||||
	select {
 | 
			
		||||
	case <-u.done:
 | 
			
		||||
		return nil
 | 
			
		||||
	default:
 | 
			
		||||
		close(u.done)
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (u *Updater) refresh() {
 | 
			
		||||
	check := time.NewTicker(time.Second)
 | 
			
		||||
	defer check.Stop()
 | 
			
		||||
 | 
			
		||||
	var (
 | 
			
		||||
		log = logger.StandardLog
 | 
			
		||||
	)
 | 
			
		||||
	for {
 | 
			
		||||
		select {
 | 
			
		||||
		case <-u.done:
 | 
			
		||||
			log.Debug("Updater closing, stopping updaters...")
 | 
			
		||||
			u.updaters.Range(func(key, value any) bool {
 | 
			
		||||
				if value != nil {
 | 
			
		||||
					close(value.(*updaterJob).done)
 | 
			
		||||
				}
 | 
			
		||||
				return true
 | 
			
		||||
			})
 | 
			
		||||
			return
 | 
			
		||||
 | 
			
		||||
		case now := <-check.C:
 | 
			
		||||
			u.check(now, log)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (u *Updater) check(now time.Time, log logger.Structured) (wait time.Duration) {
 | 
			
		||||
	log.Trace("Checking lists")
 | 
			
		||||
	lists, err := u.storage.Lists()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		log.Err(err).Error("Updater can't retrieve lists")
 | 
			
		||||
		return -1
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var missing = make(map[int64]bool)
 | 
			
		||||
	u.lists.Range(func(key, _ any) bool {
 | 
			
		||||
		log.Tracef("List %d has updater running", key)
 | 
			
		||||
		missing[key.(int64)] = true
 | 
			
		||||
		return true
 | 
			
		||||
	})
 | 
			
		||||
	for _, list := range lists {
 | 
			
		||||
		log.Tracef("List %d is active: %t", list.ID, list.IsEnabled)
 | 
			
		||||
		if !list.IsEnabled {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		delete(missing, list.ID)
 | 
			
		||||
		if _, exists := u.lists.Load(list.ID); !exists {
 | 
			
		||||
			u.lists.Store(list.ID, list)
 | 
			
		||||
			updater := newUpdaterJob(u.storage, &list)
 | 
			
		||||
			u.updaters.Store(list.ID, updater)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for id := range missing {
 | 
			
		||||
		log.Tracef("List %d has updater running, but is no longer active, reaping...", id)
 | 
			
		||||
		if updater, ok := u.updaters.Load(id); ok {
 | 
			
		||||
			close(updater.(*updaterJob).done)
 | 
			
		||||
			u.updaters.Delete(id)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type updaterJob struct {
 | 
			
		||||
	storage Storage
 | 
			
		||||
	list    *List
 | 
			
		||||
	done    chan struct{}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func newUpdaterJob(storage Storage, list *List) *updaterJob {
 | 
			
		||||
	job := &updaterJob{
 | 
			
		||||
		storage: storage,
 | 
			
		||||
		list:    list,
 | 
			
		||||
		done:    make(chan struct{}, 1),
 | 
			
		||||
	}
 | 
			
		||||
	go job.loop()
 | 
			
		||||
	return job
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (job *updaterJob) loop() {
 | 
			
		||||
	var (
 | 
			
		||||
		ticker = time.NewTicker(job.list.Refresh)
 | 
			
		||||
		first  = time.After(0)
 | 
			
		||||
		now    time.Time
 | 
			
		||||
		log    = logger.StandardLog.Values(logger.Values{
 | 
			
		||||
			"list": job.list.ID,
 | 
			
		||||
			"type": job.list.Type,
 | 
			
		||||
		})
 | 
			
		||||
	)
 | 
			
		||||
	defer ticker.Stop()
 | 
			
		||||
	for {
 | 
			
		||||
		select {
 | 
			
		||||
		case <-job.done:
 | 
			
		||||
			log.Debug("List updater stopping")
 | 
			
		||||
			return
 | 
			
		||||
 | 
			
		||||
		case now = <-ticker.C:
 | 
			
		||||
		case now = <-first:
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		log.Debug("List updater running")
 | 
			
		||||
		if update, err := job.run(now); err != nil {
 | 
			
		||||
			log.Err(err).Error("List updater failed")
 | 
			
		||||
		} else if update {
 | 
			
		||||
			if err = job.storage.SaveList(job.list); err != nil {
 | 
			
		||||
				log.Err(err).Error("List updater save failed")
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// run this updater
 | 
			
		||||
func (job *updaterJob) run(now time.Time) (update bool, err error) {
 | 
			
		||||
	u, err := url.Parse(job.list.Source)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return false, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	log := logger.StandardLog.Values(logger.Values{
 | 
			
		||||
		"list":   job.list.ID,
 | 
			
		||||
		"source": job.list.Source,
 | 
			
		||||
	})
 | 
			
		||||
	if u.Scheme == "" || u.Scheme == "file" {
 | 
			
		||||
		log.Debug("Updating list from file")
 | 
			
		||||
		return job.updateFile(u.Path)
 | 
			
		||||
	}
 | 
			
		||||
	log.Debug("Updating list from URL")
 | 
			
		||||
	return job.updateHTTP(u)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (job *updaterJob) updateFile(name string) (update bool, err error) {
 | 
			
		||||
	var b []byte
 | 
			
		||||
	if b, err = os.ReadFile(name); err != nil {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	if update = !bytes.Equal(b, job.list.Cache); update {
 | 
			
		||||
		job.list.Cache = b
 | 
			
		||||
	}
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (job *updaterJob) updateHTTP(location *url.URL) (update bool, err error) {
 | 
			
		||||
	if update, err = job.shouldUpdateHTTP(location); err != nil || !update {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	var (
 | 
			
		||||
		req *http.Request
 | 
			
		||||
		res *http.Response
 | 
			
		||||
	)
 | 
			
		||||
	if req, err = http.NewRequest(http.MethodGet, location.String(), nil); err != nil {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	if res, err = http.DefaultClient.Do(req); err != nil {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	defer res.Body.Close()
 | 
			
		||||
 | 
			
		||||
	if job.list.Cache, err = io.ReadAll(res.Body); err != nil {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	return true, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (job *updaterJob) shouldUpdateHTTP(location *url.URL) (update bool, err error) {
 | 
			
		||||
	if len(job.list.Cache) == 0 {
 | 
			
		||||
		// Nothing cached, please update.
 | 
			
		||||
		return true, nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var (
 | 
			
		||||
		req *http.Request
 | 
			
		||||
		res *http.Response
 | 
			
		||||
	)
 | 
			
		||||
	if req, err = http.NewRequest(http.MethodHead, location.String(), nil); err != nil {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	if res, err = http.DefaultClient.Do(req); err != nil {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	defer res.Body.Close()
 | 
			
		||||
 | 
			
		||||
	if lastModified, err := time.Parse(http.TimeFormat, res.Header.Get("Last-Modified")); err == nil {
 | 
			
		||||
		return lastModified.After(job.list.UpdatedAt), nil
 | 
			
		||||
	}
 | 
			
		||||
	return true, nil // not sure, no Last-Modified, so let's update?
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										2
									
								
								go.mod
									
									
									
									
									
								
							
							
						
						
									
										2
									
								
								go.mod
									
									
									
									
									
								
							@@ -7,6 +7,7 @@ require (
 | 
			
		||||
	github.com/hashicorp/hcl/v2 v2.24.0
 | 
			
		||||
	github.com/mattn/go-sqlite3 v1.14.32
 | 
			
		||||
	github.com/miekg/dns v1.1.68
 | 
			
		||||
	github.com/mjl-/bstore v0.0.10
 | 
			
		||||
	github.com/open-policy-agent/opa v1.9.0
 | 
			
		||||
	github.com/rs/zerolog v1.34.0
 | 
			
		||||
	github.com/sirupsen/logrus v1.9.4-0.20230606125235-dd1b4c2e81af
 | 
			
		||||
@@ -54,6 +55,7 @@ require (
 | 
			
		||||
	github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 // indirect
 | 
			
		||||
	github.com/yashtewari/glob-intersection v0.2.0 // indirect
 | 
			
		||||
	github.com/zclconf/go-cty v1.16.3 // indirect
 | 
			
		||||
	go.etcd.io/bbolt v1.4.3 // indirect
 | 
			
		||||
	go.opentelemetry.io/auto/sdk v1.2.1 // indirect
 | 
			
		||||
	go.opentelemetry.io/otel v1.38.0 // indirect
 | 
			
		||||
	go.opentelemetry.io/otel/metric v1.38.0 // indirect
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										4
									
								
								go.sum
									
									
									
									
									
								
							
							
						
						
									
										4
									
								
								go.sum
									
									
									
									
									
								
							@@ -98,6 +98,8 @@ github.com/miekg/dns v1.1.68 h1:jsSRkNozw7G/mnmXULynzMNIsgY2dHC8LO6U6Ij2JEA=
 | 
			
		||||
github.com/miekg/dns v1.1.68/go.mod h1:fujopn7TB3Pu3JM69XaawiU0wqjpL9/8xGop5UrTPps=
 | 
			
		||||
github.com/mitchellh/go-wordwrap v1.0.1 h1:TLuKupo69TCn6TQSyGxwI1EblZZEsQ0vMlAFQflz0v0=
 | 
			
		||||
github.com/mitchellh/go-wordwrap v1.0.1/go.mod h1:R62XHJLzvMFRBbcrT7m7WgmE1eOyTSsCt+hzestvNj0=
 | 
			
		||||
github.com/mjl-/bstore v0.0.10 h1:fYLQy3EdgXvRHoa8Q3sXMAjZf+uQLRbsh9rYjGep/t4=
 | 
			
		||||
github.com/mjl-/bstore v0.0.10/go.mod h1:QzqlAZAVRKwyojCRd9v25viFsMxK5UmIbdxgEyHdK6c=
 | 
			
		||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA=
 | 
			
		||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
 | 
			
		||||
github.com/open-policy-agent/opa v1.9.0 h1:QWFNwbcc29IRy0xwD3hRrMc/RtSersLY1Z6TaID3vgI=
 | 
			
		||||
@@ -152,6 +154,8 @@ github.com/zclconf/go-cty v1.16.3 h1:osr++gw2T61A8KVYHoQiFbFd1Lh3JOCXc/jFLJXKTxk
 | 
			
		||||
github.com/zclconf/go-cty v1.16.3/go.mod h1:VvMs5i0vgZdhYawQNq5kePSpLAoz8u1xvZgrPIxfnZE=
 | 
			
		||||
github.com/zclconf/go-cty-debug v0.0.0-20240509010212-0d6042c53940 h1:4r45xpDWB6ZMSMNJFMOjqrGHynW3DIBuR2H9j0ug+Mo=
 | 
			
		||||
github.com/zclconf/go-cty-debug v0.0.0-20240509010212-0d6042c53940/go.mod h1:CmBdvvj3nqzfzJ6nTCIwDTPZ56aVGvDrmztiO5g3qrM=
 | 
			
		||||
go.etcd.io/bbolt v1.4.3 h1:dEadXpI6G79deX5prL3QRNP6JB8UxVkqo4UPnHaNXJo=
 | 
			
		||||
go.etcd.io/bbolt v1.4.3/go.mod h1:tKQlpPaYCVFctUIgFKFnAlvbmB3tpy1vkTnDWohtc0E=
 | 
			
		||||
go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64=
 | 
			
		||||
go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y=
 | 
			
		||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.63.0 h1:RbKq8BG0FI8OiXhBfcRtqqHcZcka+gU3cskNuf05R18=
 | 
			
		||||
 
 | 
			
		||||
@@ -8,6 +8,8 @@ import (
 | 
			
		||||
	"sync/atomic"
 | 
			
		||||
	"syscall"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"git.maze.io/maze/styx/logger"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// BufferedConn uses byte buffers for Read and Write operations on a [net.Conn].
 | 
			
		||||
@@ -123,10 +125,13 @@ type AcceptOnce struct {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (listener *AcceptOnce) Accept() (net.Conn, error) {
 | 
			
		||||
	log := logger.StandardLog.Value("client", listener.Conn.RemoteAddr().String())
 | 
			
		||||
	if listener.once.Load() {
 | 
			
		||||
		log.Trace("Accept already happened, responding EOF")
 | 
			
		||||
		return nil, io.EOF
 | 
			
		||||
	}
 | 
			
		||||
	listener.once.Store(true)
 | 
			
		||||
	log.Trace("Accept client")
 | 
			
		||||
	return listener.Conn, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										74
									
								
								internal/timeutil/time.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										74
									
								
								internal/timeutil/time.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,74 @@
 | 
			
		||||
package timeutil
 | 
			
		||||
 | 
			
		||||
import "time"
 | 
			
		||||
 | 
			
		||||
var (
 | 
			
		||||
	validTimeLayouts = []string{
 | 
			
		||||
		"15:04:05.999999999",
 | 
			
		||||
		"15:04:05",
 | 
			
		||||
		"15:04",
 | 
			
		||||
		"3:04:05PM",
 | 
			
		||||
		"3:04PM",
 | 
			
		||||
		"3PM",
 | 
			
		||||
	}
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type Time struct {
 | 
			
		||||
	Hour       int
 | 
			
		||||
	Minute     int
 | 
			
		||||
	Second     int
 | 
			
		||||
	Nanosecond int
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func ParseTime(value string) (Time, error) {
 | 
			
		||||
	var t time.Time
 | 
			
		||||
	for _, layout := range validTimeLayouts {
 | 
			
		||||
		var err error
 | 
			
		||||
		if t, err = time.Parse(layout, value); err == nil {
 | 
			
		||||
			return Time{
 | 
			
		||||
				Hour:       t.Hour(),
 | 
			
		||||
				Minute:     t.Minute(),
 | 
			
		||||
				Second:     t.Second(),
 | 
			
		||||
				Nanosecond: t.Nanosecond(),
 | 
			
		||||
			}, nil
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return Time{}, &time.ParseError{
 | 
			
		||||
		Value:   value,
 | 
			
		||||
		Message: "invalid time",
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func Now() Time {
 | 
			
		||||
	t := time.Now()
 | 
			
		||||
	return Time{
 | 
			
		||||
		Hour:       t.Hour(),
 | 
			
		||||
		Minute:     t.Minute(),
 | 
			
		||||
		Second:     t.Second(),
 | 
			
		||||
		Nanosecond: t.Nanosecond(),
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (t Time) After(other Time) bool {
 | 
			
		||||
	return other.Before(t)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (t Time) Before(other Time) bool {
 | 
			
		||||
	if t.Hour == other.Hour {
 | 
			
		||||
		if t.Minute == other.Minute {
 | 
			
		||||
			if t.Second == other.Second {
 | 
			
		||||
				return t.Nanosecond < other.Nanosecond
 | 
			
		||||
			}
 | 
			
		||||
			return t.Second < other.Second
 | 
			
		||||
		}
 | 
			
		||||
		return t.Minute < other.Minute
 | 
			
		||||
	}
 | 
			
		||||
	return t.Hour < other.Hour
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (t Time) Eq(other Time) bool {
 | 
			
		||||
	return t.Hour == other.Hour &&
 | 
			
		||||
		t.Minute == other.Minute &&
 | 
			
		||||
		t.Second == other.Second &&
 | 
			
		||||
		t.Nanosecond == other.Nanosecond
 | 
			
		||||
}
 | 
			
		||||
@@ -15,17 +15,25 @@ import (
 | 
			
		||||
	"github.com/open-policy-agent/opa/v1/types"
 | 
			
		||||
 | 
			
		||||
	"git.maze.io/maze/styx/dataset"
 | 
			
		||||
	"git.maze.io/maze/styx/internal/timeutil"
 | 
			
		||||
	"git.maze.io/maze/styx/logger"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var netLookupIPAddrDecl = types.NewFunction(
 | 
			
		||||
var lookupIPAddrFunc = ®o.Function{
 | 
			
		||||
	Name:             "styx.lookup_ip_addr",
 | 
			
		||||
	Decl:             lookupIPAddrDecl,
 | 
			
		||||
	Memoize:          true,
 | 
			
		||||
	Nondeterministic: true,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var lookupIPAddrDecl = types.NewFunction(
 | 
			
		||||
	types.Args(
 | 
			
		||||
		types.Named("name", types.S).Description("Host name to lookup"),
 | 
			
		||||
	),
 | 
			
		||||
	types.Named("result", types.SetOfStr).Description("set(string) of IP address"),
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func netLookupIPAddrImpl(bc rego.BuiltinContext, nameTerm *ast.Term) (*ast.Term, error) {
 | 
			
		||||
func lookupIPAddr(bc rego.BuiltinContext, nameTerm *ast.Term) (*ast.Term, error) {
 | 
			
		||||
	log := logger.StandardLog.Value("func", "styx.lookup_ip_addr")
 | 
			
		||||
	log.Trace("Call function")
 | 
			
		||||
 | 
			
		||||
@@ -61,6 +69,57 @@ func netLookupIPAddrImpl(bc rego.BuiltinContext, nameTerm *ast.Term) (*ast.Term,
 | 
			
		||||
	return ast.SetTerm(terms...), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var timebetweenFunc = ®o.Function{
 | 
			
		||||
	Name:             "styx.time_between",
 | 
			
		||||
	Decl:             timeBetweenDecl,
 | 
			
		||||
	Nondeterministic: false,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var timeBetweenDecl = types.NewFunction(
 | 
			
		||||
	types.Args(
 | 
			
		||||
		types.Named("start", types.S).Description("Start time"),
 | 
			
		||||
		types.Named("end", types.S).Description("End time"),
 | 
			
		||||
	),
 | 
			
		||||
	types.Named("result", types.B).Description("`true` if the current local time is between `start` and `end`"),
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func timeBetween(bc rego.BuiltinContext, startTerm, endTerm *ast.Term) (*ast.Term, error) {
 | 
			
		||||
	log := logger.StandardLog.Value("func", "styx.time_between")
 | 
			
		||||
	log.Trace("Call function")
 | 
			
		||||
 | 
			
		||||
	start, err := parseTimeTerm(startTerm)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		log.Err(err).Debug("Invalid start time")
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	end, err := parseTimeTerm(endTerm)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		log.Err(err).Debug("Invalid end time")
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	now := timeutil.Now()
 | 
			
		||||
	if start.Before(end) {
 | 
			
		||||
		return ast.BooleanTerm((now.Eq(start) || now.After(start)) && now.Before(end)), nil
 | 
			
		||||
	}
 | 
			
		||||
	return ast.BooleanTerm(now.Eq(end) || now.After(end) || now.Before(start)), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func parseTimeTerm(term *ast.Term) (timeutil.Time, error) {
 | 
			
		||||
	timeArg, ok := term.Value.(ast.String)
 | 
			
		||||
	if !ok {
 | 
			
		||||
		return timeutil.Time{}, errors.New("expected string argument")
 | 
			
		||||
	}
 | 
			
		||||
	return timeutil.ParseTime(strings.Trim(timeArg.String(), `"`))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var domainContainsFunc = ®o.Function{
 | 
			
		||||
	Name:             "styx.domains_contain",
 | 
			
		||||
	Decl:             domainContainsDecl,
 | 
			
		||||
	Memoize:          true,
 | 
			
		||||
	Nondeterministic: true,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var domainContainsDecl = types.NewFunction(
 | 
			
		||||
	types.Args(
 | 
			
		||||
		types.Named("list", types.S).Description("Domain list to check against"),
 | 
			
		||||
@@ -69,8 +128,8 @@ var domainContainsDecl = types.NewFunction(
 | 
			
		||||
	types.Named("result", types.B).Description("`true` if `name` is contained within `list`"),
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func domainContainsImpl(bc rego.BuiltinContext, listTerm, nameTerm *ast.Term) (*ast.Term, error) {
 | 
			
		||||
	log := logger.StandardLog.Value("func", "styx.in_domains")
 | 
			
		||||
func domainContains(bc rego.BuiltinContext, listTerm, nameTerm *ast.Term) (*ast.Term, error) {
 | 
			
		||||
	log := logger.StandardLog.Value("func", "styx.domains_contain")
 | 
			
		||||
	log.Trace("Call function")
 | 
			
		||||
 | 
			
		||||
	list, err := parseDomainListTerm(listTerm)
 | 
			
		||||
@@ -91,6 +150,13 @@ func domainContainsImpl(bc rego.BuiltinContext, listTerm, nameTerm *ast.Term) (*
 | 
			
		||||
	return ast.BooleanTerm(list.Contains(name)), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var networkContainsFunc = ®o.Function{
 | 
			
		||||
	Name:             "styx.networks_contain",
 | 
			
		||||
	Decl:             networkContainsDecl,
 | 
			
		||||
	Memoize:          true,
 | 
			
		||||
	Nondeterministic: true,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var networkContainsDecl = types.NewFunction(
 | 
			
		||||
	types.Args(
 | 
			
		||||
		types.Named("list", types.S).Description("Network list to check against"),
 | 
			
		||||
@@ -99,8 +165,8 @@ var networkContainsDecl = types.NewFunction(
 | 
			
		||||
	types.Named("result", types.B).Description("`true` if `ip` is contained within `list`"),
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func networkContainsImpl(bc rego.BuiltinContext, listTerm, ipTerm *ast.Term) (*ast.Term, error) {
 | 
			
		||||
	log := logger.StandardLog.Value("func", "styx.in_networks")
 | 
			
		||||
func networkContains(bc rego.BuiltinContext, listTerm, ipTerm *ast.Term) (*ast.Term, error) {
 | 
			
		||||
	log := logger.StandardLog.Value("func", "styx.networks_contain")
 | 
			
		||||
 | 
			
		||||
	list, err := parseNetworkListTerm(listTerm)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
 
 | 
			
		||||
@@ -1,9 +1,12 @@
 | 
			
		||||
package policy
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bufio"
 | 
			
		||||
	"crypto/tls"
 | 
			
		||||
	"net"
 | 
			
		||||
	"net/http"
 | 
			
		||||
 | 
			
		||||
	"git.maze.io/maze/styx/ca"
 | 
			
		||||
	"git.maze.io/maze/styx/internal/netutil"
 | 
			
		||||
	"git.maze.io/maze/styx/logger"
 | 
			
		||||
	proxy "git.maze.io/maze/styx/proxy"
 | 
			
		||||
@@ -24,6 +27,7 @@ func NewRequestHandler(p *Policy) proxy.RequestHandler {
 | 
			
		||||
			log.Err(err).Error("Error generating response")
 | 
			
		||||
			return nil, nil
 | 
			
		||||
		}
 | 
			
		||||
		log.Debug("Replacing HTTP response from policy")
 | 
			
		||||
		return nil, r
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
@@ -47,21 +51,52 @@ func NewDialHandler(p *Policy) proxy.DialHandler {
 | 
			
		||||
			return nil, nil
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		c := netutil.NewLoopback()
 | 
			
		||||
		// Create a fake loopback connection
 | 
			
		||||
		pipe := netutil.NewLoopback()
 | 
			
		||||
 | 
			
		||||
		go func(c net.Conn) {
 | 
			
		||||
			s := &http.Server{
 | 
			
		||||
				Handler: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
 | 
			
		||||
					r.Write(w)
 | 
			
		||||
				}),
 | 
			
		||||
			defer func() { _ = c.Close() }()
 | 
			
		||||
			if req.URL.Scheme == "https" || req.URL.Scheme == "wss" || netutil.Port(req.URL.Host) == 443 {
 | 
			
		||||
				c = maybeUpgradeToTLS(c, ctx, req, log)
 | 
			
		||||
			}
 | 
			
		||||
			_ = s.Serve(&netutil.AcceptOnce{Conn: c})
 | 
			
		||||
		}(c.Server)
 | 
			
		||||
 | 
			
		||||
		return c.Client, nil
 | 
			
		||||
			br := bufio.NewReader(c)
 | 
			
		||||
			if _, err := http.ReadRequest(br); err != nil {
 | 
			
		||||
				log.Err(err).Warn("Malformed HTTP request in MITM connection")
 | 
			
		||||
			}
 | 
			
		||||
			_ = r.Write(c)
 | 
			
		||||
		}(pipe.Server)
 | 
			
		||||
 | 
			
		||||
		return pipe.Client, nil
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func maybeUpgradeToTLS(c net.Conn, ctx proxy.Context, req *http.Request, log logger.Structured) net.Conn {
 | 
			
		||||
	var ca ca.CertificateAuthority
 | 
			
		||||
	if caCtx, ok := ctx.(proxy.WithCertificateAuthority); ok {
 | 
			
		||||
		ca = caCtx.CertificateAuthority()
 | 
			
		||||
	}
 | 
			
		||||
	if ca == nil {
 | 
			
		||||
		return c
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	secure := tls.Server(c, &tls.Config{
 | 
			
		||||
		GetCertificate: func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
 | 
			
		||||
			log.Values(logger.Values{
 | 
			
		||||
				"cn":    req.URL.Host,
 | 
			
		||||
				"names": hello.ServerName,
 | 
			
		||||
			}).Debug("Requesting certificate from CA")
 | 
			
		||||
			return ca.GetCertificate(netutil.Host(req.URL.Host), []string{hello.ServerName}, nil)
 | 
			
		||||
		},
 | 
			
		||||
		NextProtos: []string{"http/1.1"},
 | 
			
		||||
	})
 | 
			
		||||
	if err := secure.Handshake(); err != nil {
 | 
			
		||||
		log.Err(err).Warn("Failed to pretend secure HTTP")
 | 
			
		||||
		return c
 | 
			
		||||
	}
 | 
			
		||||
	return secure
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewForwardHandler(p *Policy) proxy.ForwardHandler {
 | 
			
		||||
	log := logger.StandardLog.Value("policy", p.name)
 | 
			
		||||
	return proxy.ForwardHandlerFunc(func(ctx proxy.Context, req *http.Request) (*http.Response, error) {
 | 
			
		||||
@@ -72,7 +107,15 @@ func NewForwardHandler(p *Policy) proxy.ForwardHandler {
 | 
			
		||||
			log.Err(err).Error("Error evaulating policy")
 | 
			
		||||
			return nil, nil
 | 
			
		||||
		}
 | 
			
		||||
		return result.Response(ctx)
 | 
			
		||||
		r, err := result.Response(ctx)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			log.Err(err).Error("Error generating response")
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
		if r != nil {
 | 
			
		||||
			log.Debug("Replacing HTTP response from policy")
 | 
			
		||||
		}
 | 
			
		||||
		return r, nil
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -80,6 +123,7 @@ func NewResponseHandler(p *Policy) proxy.ResponseHandler {
 | 
			
		||||
	log := logger.StandardLog.Value("policy", p.name)
 | 
			
		||||
	return proxy.ResponseHandlerFunc(func(ctx proxy.Context) *http.Response {
 | 
			
		||||
		input := NewInputFromResponse(ctx, ctx.Response())
 | 
			
		||||
		input.logValues(log).Trace("Running response handler")
 | 
			
		||||
		result, err := p.Query(input)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			log.Err(err).Error("Error evaulating policy")
 | 
			
		||||
@@ -90,6 +134,9 @@ func NewResponseHandler(p *Policy) proxy.ResponseHandler {
 | 
			
		||||
			log.Err(err).Error("Error generating response")
 | 
			
		||||
			return nil
 | 
			
		||||
		}
 | 
			
		||||
		if r != nil {
 | 
			
		||||
			log.Debug("Replacing HTTP response from policy")
 | 
			
		||||
		}
 | 
			
		||||
		return r
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -10,19 +10,26 @@ import (
 | 
			
		||||
	"net/url"
 | 
			
		||||
	"strconv"
 | 
			
		||||
 | 
			
		||||
	"git.maze.io/maze/styx/dataset"
 | 
			
		||||
	"git.maze.io/maze/styx/internal/netutil"
 | 
			
		||||
	"git.maze.io/maze/styx/logger"
 | 
			
		||||
	proxy "git.maze.io/maze/styx/proxy"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// Input represents the input to the policy query.
 | 
			
		||||
type Input struct {
 | 
			
		||||
	Client   *Client   `json:"client"`
 | 
			
		||||
	TLS      *TLS      `json:"tls"`
 | 
			
		||||
	Request  *Request  `json:"request"`
 | 
			
		||||
	Response *Response `json:"response"`
 | 
			
		||||
	Context  map[string]any `json:"context"`
 | 
			
		||||
	Client   *Client        `json:"client"`
 | 
			
		||||
	Groups   []*Group       `json:"groups"`
 | 
			
		||||
	TLS      *TLS           `json:"tls"`
 | 
			
		||||
	Request  *Request       `json:"request"`
 | 
			
		||||
	Response *Response      `json:"response"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (i *Input) logValues(log logger.Structured) logger.Structured {
 | 
			
		||||
	if i.Context != nil {
 | 
			
		||||
		log = log.Values(i.Context)
 | 
			
		||||
	}
 | 
			
		||||
	log = i.Client.logValues(log)
 | 
			
		||||
	log = i.TLS.logValues(log)
 | 
			
		||||
	log = i.Request.logValues(log)
 | 
			
		||||
@@ -34,10 +41,29 @@ func NewInputFromConn(c net.Conn) *Input {
 | 
			
		||||
	if c == nil {
 | 
			
		||||
		return new(Input)
 | 
			
		||||
	}
 | 
			
		||||
	return &Input{
 | 
			
		||||
		Client: NewClientFromConn(c),
 | 
			
		||||
		TLS:    NewTLSFromConn(c),
 | 
			
		||||
 | 
			
		||||
	input := &Input{
 | 
			
		||||
		Context: make(map[string]any),
 | 
			
		||||
		Client:  NewClientFromConn(c),
 | 
			
		||||
		TLS:     NewTLSFromConn(c),
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if wcl, ok := c.(dataset.WithClient); ok {
 | 
			
		||||
		client, err := wcl.Client()
 | 
			
		||||
		if err == nil {
 | 
			
		||||
			input.Context["client_id"] = client.ID
 | 
			
		||||
			input.Context["client_description"] = client.Description
 | 
			
		||||
			input.Context["groups"] = client.Groups
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if ctx, ok := c.(proxy.Context); ok {
 | 
			
		||||
		input.Context["local"] = NewClientFromAddr(ctx.LocalAddr())
 | 
			
		||||
		input.Context["bytes_rx"] = ctx.BytesRead()
 | 
			
		||||
		input.Context["bytes_tx"] = ctx.BytesSent()
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return input
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewInputFromRequest(c net.Conn, r *http.Request) *Input {
 | 
			
		||||
@@ -131,6 +157,10 @@ func NewClientFromAddr(addr net.Addr) *Client {
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Group struct {
 | 
			
		||||
	Name string `json:"name"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type TLS struct {
 | 
			
		||||
	Version      string         `json:"version"`
 | 
			
		||||
	CipherSuite  string         `json:"cipher_suite"`
 | 
			
		||||
 
 | 
			
		||||
@@ -67,24 +67,10 @@ func newRego(option func(*rego.Rego), pkg string) []func(*rego.Rego) {
 | 
			
		||||
		rego.Query("data." + pkg),
 | 
			
		||||
		rego.Strict(true),
 | 
			
		||||
		rego.Capabilities(capabilities),
 | 
			
		||||
		rego.Function2(®o.Function{
 | 
			
		||||
			Name:             "styx.in_domains",
 | 
			
		||||
			Decl:             domainContainsDecl,
 | 
			
		||||
			Memoize:          true,
 | 
			
		||||
			Nondeterministic: true,
 | 
			
		||||
		}, domainContainsImpl),
 | 
			
		||||
		rego.Function2(®o.Function{
 | 
			
		||||
			Name:             "styx.in_networks",
 | 
			
		||||
			Decl:             networkContainsDecl,
 | 
			
		||||
			Memoize:          true,
 | 
			
		||||
			Nondeterministic: true,
 | 
			
		||||
		}, networkContainsImpl),
 | 
			
		||||
		rego.Function1(®o.Function{
 | 
			
		||||
			Name:             "styx.lookup_ip_addr", // override builtin
 | 
			
		||||
			Decl:             netLookupIPAddrDecl,
 | 
			
		||||
			Memoize:          true,
 | 
			
		||||
			Nondeterministic: true,
 | 
			
		||||
		}, netLookupIPAddrImpl),
 | 
			
		||||
		rego.Function2(domainContainsFunc, domainContains),
 | 
			
		||||
		rego.Function2(networkContainsFunc, networkContains),
 | 
			
		||||
		rego.Function1(lookupIPAddrFunc, lookupIPAddr),
 | 
			
		||||
		rego.Function2(timebetweenFunc, timeBetween),
 | 
			
		||||
		rego.PrintHook(printHook{}),
 | 
			
		||||
		option,
 | 
			
		||||
	}
 | 
			
		||||
@@ -128,16 +114,20 @@ func (r *Result) Response(ctx proxy.Context) (*http.Response, error) {
 | 
			
		||||
 | 
			
		||||
	switch {
 | 
			
		||||
	case r.Redirect != "":
 | 
			
		||||
		log.Value("location", r.Redirect).Trace("Creating a HTTP redirect response")
 | 
			
		||||
		response := proxy.NewResponse(http.StatusFound, nil, ctx.Request())
 | 
			
		||||
		response.Header.Set("Server", "styx")
 | 
			
		||||
		response.Header.Set(proxy.HeaderLocation, r.Redirect)
 | 
			
		||||
		return response, nil
 | 
			
		||||
 | 
			
		||||
	case r.Template != "":
 | 
			
		||||
		log = log.Value("template", r.Template)
 | 
			
		||||
		log.Trace("Creating a HTTP template response")
 | 
			
		||||
 | 
			
		||||
		b := new(bytes.Buffer)
 | 
			
		||||
		t, err := template.New(filepath.Base(r.Template)).ParseFiles(r.Template)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			log.Value("template", r.Template).Err(err).Warn("Error loading template in response")
 | 
			
		||||
			log.Err(err).Warn("Error loading template in response")
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
		t = t.Funcs(template.FuncMap{
 | 
			
		||||
@@ -149,7 +139,7 @@ func (r *Result) Response(ctx proxy.Context) (*http.Response, error) {
 | 
			
		||||
			"Response": ctx.Response(),
 | 
			
		||||
			"Errors":   r.Errors,
 | 
			
		||||
		}); err != nil {
 | 
			
		||||
			log.Value("template", r.Template).Err(err).Warn("Error rendering template response")
 | 
			
		||||
			log.Err(err).Warn("Error rendering template response")
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
@@ -159,46 +149,34 @@ func (r *Result) Response(ctx proxy.Context) (*http.Response, error) {
 | 
			
		||||
		return response, nil
 | 
			
		||||
 | 
			
		||||
	case r.Reject > 0:
 | 
			
		||||
		log.Value("code", r.Reject).Trace("Creating a HTTP reject response")
 | 
			
		||||
		body := io.NopCloser(bytes.NewBufferString(http.StatusText(r.Reject)))
 | 
			
		||||
		response := proxy.NewResponse(r.Reject, body, ctx.Request())
 | 
			
		||||
		response.Header.Set(proxy.HeaderContentType, "text/plain")
 | 
			
		||||
		return response, nil
 | 
			
		||||
 | 
			
		||||
	case r.Permit != nil && !*r.Permit:
 | 
			
		||||
		log.Trace("Creating a HTTP reject response due to explicit not permit")
 | 
			
		||||
		body := io.NopCloser(bytes.NewBufferString(http.StatusText(http.StatusForbidden)))
 | 
			
		||||
		response := proxy.NewResponse(http.StatusForbidden, body, ctx.Request())
 | 
			
		||||
		response.Header.Set(proxy.HeaderContentType, "text/plain")
 | 
			
		||||
		return response, nil
 | 
			
		||||
 | 
			
		||||
	default:
 | 
			
		||||
		log.Trace("Not creating a HTTP response")
 | 
			
		||||
		return nil, nil
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (p *Policy) Query(input *Input) (*Result, error) {
 | 
			
		||||
	/*
 | 
			
		||||
		e := json.NewEncoder(os.Stdout)
 | 
			
		||||
		e.SetIndent("", "  ")
 | 
			
		||||
		e.Encode(doc)
 | 
			
		||||
	*/
 | 
			
		||||
 | 
			
		||||
	log := logger.StandardLog.Value("policy", p.name)
 | 
			
		||||
	log.Trace("Evaluating policy")
 | 
			
		||||
 | 
			
		||||
	r := rego.New(append(p.options, rego.Input(input))...)
 | 
			
		||||
 | 
			
		||||
	ctx := context.Background()
 | 
			
		||||
	/*
 | 
			
		||||
		query, err := p.rego.PrepareForEval(ctx)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
		rs, err := query.Eval(ctx, rego.EvalInput(input))
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
	*/
 | 
			
		||||
	rs, err := r.Eval(ctx)
 | 
			
		||||
	var (
 | 
			
		||||
		rego    = rego.New(append(p.options, rego.Input(input))...)
 | 
			
		||||
		ctx     = context.Background()
 | 
			
		||||
		rs, err = rego.Eval(ctx)
 | 
			
		||||
	)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
@@ -208,6 +186,12 @@ func (p *Policy) Query(input *Input) (*Result, error) {
 | 
			
		||||
	result := &Result{}
 | 
			
		||||
	for _, expr := range rs[0].Expressions {
 | 
			
		||||
		if m, ok := expr.Value.(map[string]any); ok {
 | 
			
		||||
			// Remove private variables.
 | 
			
		||||
			for k := range m {
 | 
			
		||||
				if len(k) > 0 && k[0] == '_' {
 | 
			
		||||
					delete(m, k)
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
			log.Values(m).Trace("Policy result expression")
 | 
			
		||||
			if err = mapstructure.Decode(m, result); err != nil {
 | 
			
		||||
				return nil, err
 | 
			
		||||
 
 | 
			
		||||
@@ -14,6 +14,8 @@ import (
 | 
			
		||||
	"sync/atomic"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"git.maze.io/maze/styx/ca"
 | 
			
		||||
	"git.maze.io/maze/styx/dataset"
 | 
			
		||||
	"git.maze.io/maze/styx/logger"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
@@ -42,6 +44,13 @@ type Context interface {
 | 
			
		||||
 | 
			
		||||
	// Response is the response that will be sent back to the client.
 | 
			
		||||
	Response() *http.Response
 | 
			
		||||
 | 
			
		||||
	// Client group.
 | 
			
		||||
	Client() (dataset.Client, error)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type WithCertificateAuthority interface {
 | 
			
		||||
	CertificateAuthority() ca.CertificateAuthority
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type countingReader struct {
 | 
			
		||||
@@ -80,6 +89,9 @@ type proxyContext struct {
 | 
			
		||||
	req            *http.Request
 | 
			
		||||
	res            *http.Response
 | 
			
		||||
	idleTimeout    time.Duration
 | 
			
		||||
	ca             ca.CertificateAuthority
 | 
			
		||||
	storage        dataset.Storage
 | 
			
		||||
	client         dataset.Client
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// NewContext returns an initialized context for the provided [net.Conn].
 | 
			
		||||
@@ -218,4 +230,28 @@ func (c *proxyContext) WriteHeader(code int) {
 | 
			
		||||
	//return c.res.Header.Write(c)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *proxyContext) CertificateAuthority() ca.CertificateAuthority {
 | 
			
		||||
	return c.ca
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *proxyContext) Client() (dataset.Client, error) {
 | 
			
		||||
	if c.storage == nil {
 | 
			
		||||
		return dataset.Client{}, dataset.ErrNotExist{Object: "client"}
 | 
			
		||||
	}
 | 
			
		||||
	if !c.client.CreatedAt.Equal(time.Time{}) {
 | 
			
		||||
		return c.client, nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var err error
 | 
			
		||||
	switch addr := c.Conn.RemoteAddr().(type) {
 | 
			
		||||
	case *net.TCPAddr:
 | 
			
		||||
		c.client, err = c.storage.ClientByIP(addr.IP)
 | 
			
		||||
	case *net.UDPAddr:
 | 
			
		||||
		c.client, err = c.storage.ClientByIP(addr.IP)
 | 
			
		||||
	default:
 | 
			
		||||
		err = dataset.ErrNotExist{Object: "client"}
 | 
			
		||||
	}
 | 
			
		||||
	return c.client, err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var _ Context = (*proxyContext)(nil)
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										121
									
								
								proxy/proxy.go
									
									
									
									
									
								
							
							
						
						
									
										121
									
								
								proxy/proxy.go
									
									
									
									
									
								
							@@ -15,9 +15,12 @@ import (
 | 
			
		||||
	"slices"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"sync"
 | 
			
		||||
	"syscall"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"git.maze.io/maze/styx/ca"
 | 
			
		||||
	"git.maze.io/maze/styx/dataset"
 | 
			
		||||
	"git.maze.io/maze/styx/internal/netutil"
 | 
			
		||||
	"git.maze.io/maze/styx/logger"
 | 
			
		||||
	"git.maze.io/maze/styx/stats"
 | 
			
		||||
@@ -26,6 +29,7 @@ import (
 | 
			
		||||
// Common HTTP headers.
 | 
			
		||||
const (
 | 
			
		||||
	HeaderConnection     = "Connection"
 | 
			
		||||
	HeaderContentLength  = "Content-Length"
 | 
			
		||||
	HeaderContentType    = "Content-Type"
 | 
			
		||||
	HeaderDate           = "Date"
 | 
			
		||||
	HeaderForwarded      = "Forwarded"
 | 
			
		||||
@@ -146,7 +150,17 @@ type Proxy struct {
 | 
			
		||||
	// WebSocketIdleTimeout is the timeout for idle WebSocket connections.
 | 
			
		||||
	WebSocketIdleTimeout time.Duration
 | 
			
		||||
 | 
			
		||||
	mux *http.ServeMux
 | 
			
		||||
	// CertificateAuthority can issue certificates for man-in-the-middle connections.
 | 
			
		||||
	CertificateAuthority ca.CertificateAuthority
 | 
			
		||||
 | 
			
		||||
	// Storage for resolving clients/groups
 | 
			
		||||
	Storage dataset.Storage
 | 
			
		||||
 | 
			
		||||
	mux       *http.ServeMux
 | 
			
		||||
	closed    chan struct{}
 | 
			
		||||
	closeOnce sync.Once
 | 
			
		||||
	mu        sync.RWMutex
 | 
			
		||||
	listeners []net.Listener
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// New [Proxy] with somewhat sane defaults.
 | 
			
		||||
@@ -157,6 +171,7 @@ func New() *Proxy {
 | 
			
		||||
		IdleTimeout:          DefaultIdleTimeout,
 | 
			
		||||
		WebSocketIdleTimeout: DefaultWebSocketIdleTimeout,
 | 
			
		||||
		mux:                  http.NewServeMux(),
 | 
			
		||||
		closed:               make(chan struct{}, 1),
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Make sure the roundtripper uses our dialers.
 | 
			
		||||
@@ -181,6 +196,55 @@ func New() *Proxy {
 | 
			
		||||
	return p
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (p *Proxy) Close() error {
 | 
			
		||||
	var closeListeners bool
 | 
			
		||||
	p.closeOnce.Do(func() {
 | 
			
		||||
		close(p.closed)
 | 
			
		||||
		closeListeners = true
 | 
			
		||||
	})
 | 
			
		||||
	if closeListeners {
 | 
			
		||||
		p.mu.RLock()
 | 
			
		||||
		for _, l := range p.listeners {
 | 
			
		||||
			_ = l.Close()
 | 
			
		||||
		}
 | 
			
		||||
		p.mu.RUnlock()
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (p *Proxy) isClosed() bool {
 | 
			
		||||
	select {
 | 
			
		||||
	case <-p.closed:
 | 
			
		||||
		return true
 | 
			
		||||
	default:
 | 
			
		||||
		return false
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (p *Proxy) addListener(l net.Listener) {
 | 
			
		||||
	if l == nil {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	p.mu.Lock()
 | 
			
		||||
	p.listeners = append(p.listeners, l)
 | 
			
		||||
	p.mu.Unlock()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (p *Proxy) removeListener(l net.Listener) {
 | 
			
		||||
	if l == nil {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	p.mu.Lock()
 | 
			
		||||
	listeners := make([]net.Listener, 0, len(p.listeners)-1)
 | 
			
		||||
	for _, o := range p.listeners {
 | 
			
		||||
		if o != l {
 | 
			
		||||
			listeners = append(listeners, o)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	p.listeners = listeners
 | 
			
		||||
	p.mu.Unlock()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Handle installs a [http.Handler] into the internal mux.
 | 
			
		||||
func (p *Proxy) Handle(pattern string, handler http.Handler) {
 | 
			
		||||
	p.mux.Handle(pattern, handler)
 | 
			
		||||
@@ -214,11 +278,23 @@ func (p *Proxy) dial(ctx context.Context, req *http.Request) (net.Conn, error) {
 | 
			
		||||
 | 
			
		||||
// Serve proxied connections on the specified listener.
 | 
			
		||||
func (p *Proxy) Serve(l net.Listener) error {
 | 
			
		||||
	p.addListener(l)
 | 
			
		||||
	defer p.removeListener(l)
 | 
			
		||||
	for {
 | 
			
		||||
		if p.isClosed() {
 | 
			
		||||
			return nil
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		c, err := l.Accept()
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if p.isClosed() {
 | 
			
		||||
			_ = c.Close()
 | 
			
		||||
			return nil
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		go p.handle(c)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
@@ -229,6 +305,7 @@ func (p *Proxy) handle(nc net.Conn) {
 | 
			
		||||
		ctx   = NewContext(nc).(*proxyContext)
 | 
			
		||||
		err   error
 | 
			
		||||
	)
 | 
			
		||||
 | 
			
		||||
	defer func() {
 | 
			
		||||
		if r := recover(); r != nil {
 | 
			
		||||
			if err, ok := r.(error); ok {
 | 
			
		||||
@@ -266,6 +343,8 @@ func (p *Proxy) handle(nc net.Conn) {
 | 
			
		||||
 | 
			
		||||
	// Propagate timeouts
 | 
			
		||||
	ctx.SetIdleTimeout(p.IdleTimeout)
 | 
			
		||||
	ctx.ca = p.CertificateAuthority
 | 
			
		||||
	ctx.storage = p.Storage
 | 
			
		||||
 | 
			
		||||
	for _, f := range p.OnConnect {
 | 
			
		||||
		fc, err := f.HandleConn(ctx)
 | 
			
		||||
@@ -282,6 +361,15 @@ func (p *Proxy) handle(nc net.Conn) {
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	log := ctx.LogEntry()
 | 
			
		||||
	if p.Storage != nil {
 | 
			
		||||
		if client, err := p.Storage.ClientByIP(nc.RemoteAddr().(*net.TCPAddr).IP); err == nil {
 | 
			
		||||
			log = log.Values(logger.Values{
 | 
			
		||||
				"client_id":          client.ID,
 | 
			
		||||
				"client_network":     client.String(),
 | 
			
		||||
				"client_description": client.Description,
 | 
			
		||||
			})
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	for {
 | 
			
		||||
		if ctx.transparentTLS {
 | 
			
		||||
			ctx.req = &http.Request{
 | 
			
		||||
@@ -344,7 +432,7 @@ func (p *Proxy) handle(nc net.Conn) {
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if err = p.handleRequest(ctx); err != nil {
 | 
			
		||||
			p.handleError(ctx, err, true)
 | 
			
		||||
			p.handleError(ctx, err, !netutil.IsClosing(err))
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
@@ -511,7 +599,8 @@ func (p *Proxy) serveForward(ctx *proxyContext) (err error) {
 | 
			
		||||
			_ = ctx.Close()
 | 
			
		||||
			return fmt.Errorf("proxy: forward %s error: %w", ctx.req.URL, err)
 | 
			
		||||
		}
 | 
			
		||||
	} else {
 | 
			
		||||
	}
 | 
			
		||||
	if res != nil {
 | 
			
		||||
		ctx.res = res
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
@@ -571,28 +660,44 @@ func (p *Proxy) serveWebSocket(ctx *proxyContext) (err error) {
 | 
			
		||||
	return p.multiplex(ctx, srv)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (p *Proxy) multiplex(ctx, srv Context) (err error) {
 | 
			
		||||
func (p *Proxy) multiplex(ctx, srv *proxyContext) (err error) {
 | 
			
		||||
	var (
 | 
			
		||||
		log  = ctx.LogEntry().Value("server", srv.RemoteAddr().String())
 | 
			
		||||
		errs = make(chan error, 1)
 | 
			
		||||
		done = make(chan struct{}, 1)
 | 
			
		||||
	)
 | 
			
		||||
	go func(errs chan<- error) {
 | 
			
		||||
		defer close(done)
 | 
			
		||||
		if _, err := io.Copy(srv, ctx); err != nil {
 | 
			
		||||
		if _, err := io.Copy(ctx, srv); err != nil && !netutil.IsClosing(err) {
 | 
			
		||||
			log.Err(err).Trace("Multiplexing closed in client->server")
 | 
			
		||||
			errs <- err
 | 
			
		||||
		} else {
 | 
			
		||||
			log.Trace("Multiplexing closed in client->server")
 | 
			
		||||
		}
 | 
			
		||||
	}(errs)
 | 
			
		||||
 | 
			
		||||
	go func(errs chan<- error) {
 | 
			
		||||
		if _, err := io.Copy(ctx, srv); err != nil {
 | 
			
		||||
		defer close(done)
 | 
			
		||||
		if _, err := io.Copy(srv, ctx); err != nil && !netutil.IsClosing(err) {
 | 
			
		||||
			log.Err(err).Trace("Multiplexing closed in server->client")
 | 
			
		||||
			errs <- err
 | 
			
		||||
		} else {
 | 
			
		||||
			log.Trace("Multiplexing closed in server->client")
 | 
			
		||||
		}
 | 
			
		||||
	}(errs)
 | 
			
		||||
 | 
			
		||||
	defer func() {
 | 
			
		||||
		log.Trace("Multiplexing done, force-closing client and server connections")
 | 
			
		||||
		_ = ctx.Close()
 | 
			
		||||
		_ = srv.Close()
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	select {
 | 
			
		||||
	case err = <-errs:
 | 
			
		||||
		return
 | 
			
		||||
	case <-done:
 | 
			
		||||
		return
 | 
			
		||||
		return io.EOF // multiplexing never recycles connection
 | 
			
		||||
	case <-p.closed:
 | 
			
		||||
		return io.EOF // server closed
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										213
									
								
								stats/handler.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										213
									
								
								stats/handler.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,213 @@
 | 
			
		||||
package stats
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"expvar"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"sort"
 | 
			
		||||
	"strings"
 | 
			
		||||
 | 
			
		||||
	"html/template"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var (
 | 
			
		||||
	page = template.Must(template.New("").
 | 
			
		||||
		Funcs(template.FuncMap{"path": path, "duration": duration}).
 | 
			
		||||
		Parse(`<!DOCTYPE html>
 | 
			
		||||
<html lang="us">
 | 
			
		||||
<meta charset="utf-8">
 | 
			
		||||
<title>Metrics report</title>
 | 
			
		||||
<meta name="viewport" content="width=device-width">
 | 
			
		||||
<style>
 | 
			
		||||
* { margin: 0; padding: 0; box-sizing: border-box; font-family: monospace; font-size: 12px; }
 | 
			
		||||
.container {
 | 
			
		||||
	max-width: 640px;
 | 
			
		||||
	margin: 1em auto;
 | 
			
		||||
	display: flex;
 | 
			
		||||
	flex-direction: column;
 | 
			
		||||
	padding: 0 1em;
 | 
			
		||||
}
 | 
			
		||||
h1 { text-align: center; }
 | 
			
		||||
h2 {
 | 
			
		||||
	font-weight: normal;
 | 
			
		||||
	text-overflow: ellipsis;
 | 
			
		||||
	white-space: nowrap;
 | 
			
		||||
	overflow: hidden;
 | 
			
		||||
}
 | 
			
		||||
.metric {
 | 
			
		||||
	padding: 1em 0;
 | 
			
		||||
	border-top: 1px solid rgba(0,0,0,0.33);
 | 
			
		||||
}
 | 
			
		||||
.row {
 | 
			
		||||
	display: flex;
 | 
			
		||||
	flex-direction: row;
 | 
			
		||||
	align-items: center;
 | 
			
		||||
	margin: 0.25em 0;
 | 
			
		||||
}
 | 
			
		||||
.col-1 { flex: 1; }
 | 
			
		||||
.col-2 { flex: 2.5; }
 | 
			
		||||
.table { width: 100px; border-radius: 2px; border: 1px solid rgba(0,0,0,0.33); }
 | 
			
		||||
.table td, .table th { text-align: center; }
 | 
			
		||||
.timeline { padding: 0 0.5em; }
 | 
			
		||||
path { fill: none; stroke: rgba(0,0,0,0.33); stroke-width: 1; stroke-linecap: round; stroke-linejoin: round; }
 | 
			
		||||
path:last-child { stroke: black; }
 | 
			
		||||
</style>
 | 
			
		||||
<body>
 | 
			
		||||
<div class="container">
 | 
			
		||||
<div><h1><pre>    __          __
 | 
			
		||||
.--------..-----.|  |_ .----.|__|.----..-----.
 | 
			
		||||
|        ||  -__||   _||   _||  ||  __||__ --|
 | 
			
		||||
|__|__|__||_____||____||__|  |__||____||_____|
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
</pre></h1></div>
 | 
			
		||||
{{ range . }}
 | 
			
		||||
	<div class="row metric">
 | 
			
		||||
	  <h2 class="col-1">{{ .name }}</h2>
 | 
			
		||||
		<div class="col-2">
 | 
			
		||||
		{{ if .type }}
 | 
			
		||||
			<div class="row">
 | 
			
		||||
				{{ template "table" . }}
 | 
			
		||||
				<div class="col-1"></div>
 | 
			
		||||
			</div>
 | 
			
		||||
		{{ else if .interval }}
 | 
			
		||||
			<div class="row">{{ template "timeseries" . }}</div>
 | 
			
		||||
		{{ else if .metrics}}
 | 
			
		||||
			{{ range .metrics }}
 | 
			
		||||
				<div class="row">
 | 
			
		||||
				{{ template "timeseries" . }}
 | 
			
		||||
				</div>
 | 
			
		||||
			{{ end }}
 | 
			
		||||
		{{ end }}
 | 
			
		||||
		</div>
 | 
			
		||||
  </div>
 | 
			
		||||
{{ end }}
 | 
			
		||||
</div>
 | 
			
		||||
</body>
 | 
			
		||||
</html>
 | 
			
		||||
{{ define "table" }}
 | 
			
		||||
<table class="table col-1">
 | 
			
		||||
	{{ if eq .type "c" }}
 | 
			
		||||
		<thead><tr><th>count</th></tr></thead><tbody><tr><td>{{ printf "%.2g" .count }}</td></tr></tbody>
 | 
			
		||||
	{{ else if eq .type "g" }}
 | 
			
		||||
		<thead><tr><th>mean</th><th>min</th><th>max</th></tr></thead>
 | 
			
		||||
		<tbody><tr><td>{{printf "%.2g" .mean}}</td><td>{{printf "%.2g" .min}}</td><td>{{printf "%.2g" .max}}</td></th></tbody>
 | 
			
		||||
	{{ else if eq .type "h" }}
 | 
			
		||||
		<thead><tr><th>P.50</th><th>P.90</th><th>P.99</th></tr></thead>
 | 
			
		||||
		<tbody><tr><td>{{printf "%.2g" .p50}}</td><td>{{printf "%.2g" .p90}}</td><td>{{printf "%.2g" .p99}}</td></tr></tbody>
 | 
			
		||||
	{{ end }}
 | 
			
		||||
</table>
 | 
			
		||||
{{ end }}
 | 
			
		||||
{{ define "timeseries" }}
 | 
			
		||||
  {{ template "table" .total }}
 | 
			
		||||
	<div class="col-1">
 | 
			
		||||
		<div class="row">
 | 
			
		||||
			<div class="timeline">{{ duration .samples .interval }}</div>
 | 
			
		||||
			<svg class="col-1" version="1.1" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 100 20">
 | 
			
		||||
			{{ if eq (index (index .samples 0) "type") "c" }}
 | 
			
		||||
				{{ range (path .samples "count") }}<path d={{ . }} />{{end}}
 | 
			
		||||
			{{ else if eq (index (index .samples 0) "type") "g" }}
 | 
			
		||||
				{{ range (path .samples "min" "max" "mean" ) }}<path d={{ . }} />{{end}}
 | 
			
		||||
			{{ else if eq (index (index .samples 0) "type") "h" }}
 | 
			
		||||
				{{ range (path .samples "p50" "p90" "p99") }}<path d={{ . }} />{{end}}
 | 
			
		||||
			{{ end }}
 | 
			
		||||
			</svg>
 | 
			
		||||
		</div>
 | 
			
		||||
	</div>
 | 
			
		||||
{{ end }}
 | 
			
		||||
`))
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func path(samples []any, keys ...string) []string {
 | 
			
		||||
	var min, max float64
 | 
			
		||||
	paths := make([]string, len(keys))
 | 
			
		||||
	for i := range len(samples) {
 | 
			
		||||
		s := samples[i].(map[string]any)
 | 
			
		||||
		for _, k := range keys {
 | 
			
		||||
			x := s[k].(float64)
 | 
			
		||||
			if i == 0 || x < min {
 | 
			
		||||
				min = x
 | 
			
		||||
			}
 | 
			
		||||
			if i == 0 || x > max {
 | 
			
		||||
				max = x
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	for i := range len(samples) {
 | 
			
		||||
		s := samples[i].(map[string]any)
 | 
			
		||||
		for j, k := range keys {
 | 
			
		||||
			v := s[k].(float64)
 | 
			
		||||
			x := float64(i+1) / float64(len(samples))
 | 
			
		||||
			y := (v - min) / (max - min)
 | 
			
		||||
			if max == min {
 | 
			
		||||
				y = 0
 | 
			
		||||
			}
 | 
			
		||||
			if i == 0 {
 | 
			
		||||
				paths[j] = fmt.Sprintf("M%f %f", 0.0, (1-y)*18+1)
 | 
			
		||||
			}
 | 
			
		||||
			paths[j] += fmt.Sprintf(" L%f %f", x*100, (1-y)*18+1)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return paths
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func duration(samples []any, n float64) string {
 | 
			
		||||
	n = n * float64(len(samples))
 | 
			
		||||
	if n < 60 {
 | 
			
		||||
		return fmt.Sprintf("%d sec", int(n))
 | 
			
		||||
	} else if n < 60*60 {
 | 
			
		||||
		return fmt.Sprintf("%d min", int(n/60))
 | 
			
		||||
	} else if n < 24*60*60 {
 | 
			
		||||
		return fmt.Sprintf("%d hrs", int(n/60/60))
 | 
			
		||||
	}
 | 
			
		||||
	return fmt.Sprintf("%d days", int(n/24/60/60))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Handler returns an http.Handler that renders web UI for all provided metrics.
 | 
			
		||||
func Handler(snapshot func() map[string]Metric) http.Handler {
 | 
			
		||||
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 | 
			
		||||
		type h map[string]any
 | 
			
		||||
		metrics := []h{}
 | 
			
		||||
		for name, metric := range snapshot() {
 | 
			
		||||
			m := h{}
 | 
			
		||||
			b, _ := json.Marshal(metric)
 | 
			
		||||
			json.Unmarshal(b, &m)
 | 
			
		||||
			m["name"] = name
 | 
			
		||||
			metrics = append(metrics, m)
 | 
			
		||||
		}
 | 
			
		||||
		sort.Slice(metrics, func(i, j int) bool {
 | 
			
		||||
			n1 := metrics[i]["name"].(string)
 | 
			
		||||
			n2 := metrics[j]["name"].(string)
 | 
			
		||||
			return strings.Compare(n1, n2) < 0
 | 
			
		||||
		})
 | 
			
		||||
		page.Execute(w, metrics)
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// JSONHandler returns a [http.Handler] that renders the metrics as JSON.
 | 
			
		||||
func JSONHandler(snapshot func() map[string]Metric) http.Handler {
 | 
			
		||||
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 | 
			
		||||
		type h map[string]any
 | 
			
		||||
		metrics := map[string]h{}
 | 
			
		||||
		for name, metric := range snapshot() {
 | 
			
		||||
			m := h{}
 | 
			
		||||
			b, _ := json.Marshal(metric)
 | 
			
		||||
			json.Unmarshal(b, &m)
 | 
			
		||||
			metrics[name] = m
 | 
			
		||||
		}
 | 
			
		||||
		w.Header().Set("Content-Type", "application/json")
 | 
			
		||||
		json.NewEncoder(w).Encode(metrics)
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Exposed returns a map of exposed metrics (see expvar package).
 | 
			
		||||
func Exposed() map[string]Metric {
 | 
			
		||||
	m := map[string]Metric{}
 | 
			
		||||
	expvar.Do(func(kv expvar.KeyValue) {
 | 
			
		||||
		if metric, ok := kv.Value.(Metric); ok {
 | 
			
		||||
			m[kv.Key] = metric
 | 
			
		||||
		}
 | 
			
		||||
	})
 | 
			
		||||
	return m
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										104
									
								
								stats/stats.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										104
									
								
								stats/stats.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,104 @@
 | 
			
		||||
package stats
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"math"
 | 
			
		||||
	"sort"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"sync/atomic"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// Metric is a single meter (counter, gauge or histogram, optionally - with history)
 | 
			
		||||
type Metric interface {
 | 
			
		||||
	Add(n float64)
 | 
			
		||||
	String() string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// metric is an extended private interface with some additional internal
 | 
			
		||||
// methods used by timeseries. Counters, gauges and histograms implement it.
 | 
			
		||||
type metric interface {
 | 
			
		||||
	Metric
 | 
			
		||||
	Reset()
 | 
			
		||||
	Aggregate(roll int, samples []metric)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type multimetric []*timeseries
 | 
			
		||||
 | 
			
		||||
func (mm multimetric) Add(n float64) {
 | 
			
		||||
	for _, m := range mm {
 | 
			
		||||
		m.Add(n)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (mm multimetric) MarshalJSON() ([]byte, error) {
 | 
			
		||||
	b := []byte(`{"metrics":[`)
 | 
			
		||||
	for i, m := range mm {
 | 
			
		||||
		if i != 0 {
 | 
			
		||||
			b = append(b, ',')
 | 
			
		||||
		}
 | 
			
		||||
		x, _ := json.Marshal(m)
 | 
			
		||||
		b = append(b, x...)
 | 
			
		||||
	}
 | 
			
		||||
	b = append(b, ']', '}')
 | 
			
		||||
	return b, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (mm multimetric) String() string {
 | 
			
		||||
	return mm[len(mm)-1].String()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func newMetric(builder func() metric, frames ...string) Metric {
 | 
			
		||||
	if len(frames) == 0 {
 | 
			
		||||
		return builder()
 | 
			
		||||
	}
 | 
			
		||||
	if len(frames) == 1 {
 | 
			
		||||
		return newTimeseries(builder, frames[0])
 | 
			
		||||
	}
 | 
			
		||||
	mm := multimetric{}
 | 
			
		||||
	for _, frame := range frames {
 | 
			
		||||
		mm = append(mm, newTimeseries(builder, frame))
 | 
			
		||||
	}
 | 
			
		||||
	sort.Slice(mm, func(i, j int) bool {
 | 
			
		||||
		a, b := mm[i], mm[j]
 | 
			
		||||
		return a.interval.Seconds()*float64(len(a.samples)) < b.interval.Seconds()*float64(len(b.samples))
 | 
			
		||||
	})
 | 
			
		||||
	return mm
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// NewCounter returns a counter metric that increments the value with each
 | 
			
		||||
// incoming number.
 | 
			
		||||
func NewCounter(frames ...string) Metric {
 | 
			
		||||
	return newMetric(func() metric { return &counter{} }, frames...)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type counter struct {
 | 
			
		||||
	count uint64
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *counter) String() string { return strconv.FormatFloat(c.value(), 'g', -1, 64) }
 | 
			
		||||
func (c *counter) Reset()         { atomic.StoreUint64(&c.count, math.Float64bits(0)) }
 | 
			
		||||
func (c *counter) value() float64 { return math.Float64frombits(atomic.LoadUint64(&c.count)) }
 | 
			
		||||
 | 
			
		||||
func (c *counter) Add(n float64) {
 | 
			
		||||
	for {
 | 
			
		||||
		old := math.Float64frombits(atomic.LoadUint64(&c.count))
 | 
			
		||||
		new := old + n
 | 
			
		||||
		if atomic.CompareAndSwapUint64(&c.count, math.Float64bits(old), math.Float64bits(new)) {
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *counter) MarshalJSON() ([]byte, error) {
 | 
			
		||||
	return json.Marshal(struct {
 | 
			
		||||
		Type  string  `json:"type"`
 | 
			
		||||
		Count float64 `json:"count"`
 | 
			
		||||
	}{"c", c.value()})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *counter) Aggregate(roll int, samples []metric) {
 | 
			
		||||
	c.Reset()
 | 
			
		||||
	for _, s := range samples {
 | 
			
		||||
		c.Add(s.(*counter).value())
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										105
									
								
								stats/timeseries.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										105
									
								
								stats/timeseries.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,105 @@
 | 
			
		||||
package stats
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"sync"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type timeseries struct {
 | 
			
		||||
	sync.Mutex
 | 
			
		||||
	now      time.Time
 | 
			
		||||
	size     int
 | 
			
		||||
	interval time.Duration
 | 
			
		||||
	total    metric
 | 
			
		||||
	samples  []metric
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (ts *timeseries) Reset() {
 | 
			
		||||
	ts.total.Reset()
 | 
			
		||||
	for _, s := range ts.samples {
 | 
			
		||||
		s.Reset()
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (ts *timeseries) roll() {
 | 
			
		||||
	t := time.Now()
 | 
			
		||||
	roll := int((t.Round(ts.interval).Sub(ts.now.Round(ts.interval))) / ts.interval)
 | 
			
		||||
	ts.now = t
 | 
			
		||||
	n := len(ts.samples)
 | 
			
		||||
	if roll <= 0 {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	if roll >= len(ts.samples) {
 | 
			
		||||
		ts.Reset()
 | 
			
		||||
	} else {
 | 
			
		||||
		for i := 0; i < roll; i++ {
 | 
			
		||||
			tmp := ts.samples[n-1]
 | 
			
		||||
			for j := n - 1; j > 0; j-- {
 | 
			
		||||
				ts.samples[j] = ts.samples[j-1]
 | 
			
		||||
			}
 | 
			
		||||
			ts.samples[0] = tmp
 | 
			
		||||
			ts.samples[0].Reset()
 | 
			
		||||
		}
 | 
			
		||||
		ts.total.Aggregate(roll, ts.samples)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (ts *timeseries) Add(n float64) {
 | 
			
		||||
	ts.Lock()
 | 
			
		||||
	defer ts.Unlock()
 | 
			
		||||
	ts.roll()
 | 
			
		||||
	ts.total.Add(n)
 | 
			
		||||
	ts.samples[0].Add(n)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (ts *timeseries) MarshalJSON() ([]byte, error) {
 | 
			
		||||
	ts.Lock()
 | 
			
		||||
	defer ts.Unlock()
 | 
			
		||||
	ts.roll()
 | 
			
		||||
	return json.Marshal(struct {
 | 
			
		||||
		Interval float64  `json:"interval"`
 | 
			
		||||
		Total    Metric   `json:"total"`
 | 
			
		||||
		Samples  []metric `json:"samples"`
 | 
			
		||||
	}{float64(ts.interval) / float64(time.Second), ts.total, ts.samples})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (ts *timeseries) String() string {
 | 
			
		||||
	ts.Lock()
 | 
			
		||||
	defer ts.Unlock()
 | 
			
		||||
	ts.roll()
 | 
			
		||||
	return ts.total.String()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func newTimeseries(builder func() metric, frame string) *timeseries {
 | 
			
		||||
	var (
 | 
			
		||||
		totalNum, intervalNum   int
 | 
			
		||||
		totalUnit, intervalUnit rune
 | 
			
		||||
	)
 | 
			
		||||
	units := map[rune]time.Duration{
 | 
			
		||||
		's': time.Second,
 | 
			
		||||
		'm': time.Minute,
 | 
			
		||||
		'h': time.Hour,
 | 
			
		||||
		'd': time.Hour * 24,
 | 
			
		||||
		'w': time.Hour * 24 * 7,
 | 
			
		||||
		'M': time.Hour * 24 * 30,
 | 
			
		||||
		'y': time.Hour * 24 * 365,
 | 
			
		||||
	}
 | 
			
		||||
	fmt.Sscanf(frame, "%d%c%d%c", &totalNum, &totalUnit, &intervalNum, &intervalUnit)
 | 
			
		||||
	interval := units[intervalUnit] * time.Duration(intervalNum)
 | 
			
		||||
	if interval == 0 {
 | 
			
		||||
		interval = time.Minute
 | 
			
		||||
	}
 | 
			
		||||
	totalDuration := units[totalUnit] * time.Duration(totalNum)
 | 
			
		||||
	if totalDuration == 0 {
 | 
			
		||||
		totalDuration = interval * 15
 | 
			
		||||
	}
 | 
			
		||||
	n := int(totalDuration / interval)
 | 
			
		||||
	samples := make([]metric, n, n)
 | 
			
		||||
	for i := 0; i < n; i++ {
 | 
			
		||||
		samples[i] = builder()
 | 
			
		||||
	}
 | 
			
		||||
	totalMetric := builder()
 | 
			
		||||
	return ×eries{interval: interval, total: totalMetric, samples: samples}
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										19
									
								
								styx.hcl
									
									
									
									
									
								
							
							
						
						
									
										19
									
								
								styx.hcl
									
									
									
									
									
								
							@@ -37,22 +37,35 @@ proxy {
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
ca {
 | 
			
		||||
    cert = "testdata/ca.crt"
 | 
			
		||||
    key = "testdata/ca.key"
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
policy "intercept" {
 | 
			
		||||
    path = "testdata/policy/intercept.rego"
 | 
			
		||||
    path = "testdata/policy/styx/intercept.rego"
 | 
			
		||||
    package = "styx.intercept"
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
policy "bogons" {
 | 
			
		||||
    path = "testdata/policy/bogons.rego"
 | 
			
		||||
    path = "testdata/policy/styx/bogons.rego"
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
policy "childsafe" {
 | 
			
		||||
    path = "testdata/policy/childsafe.rego"
 | 
			
		||||
    path = "testdata/policy/custom/childsafe.rego"
 | 
			
		||||
    package = "custom"
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
data {
 | 
			
		||||
    path = "testdata/match"
 | 
			
		||||
 | 
			
		||||
    storage {
 | 
			
		||||
        type = "bolt"
 | 
			
		||||
        path = "testdata/styx.bolt"
 | 
			
		||||
        #type = "sqlite"
 | 
			
		||||
        #path = "testdata/styx.db"
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    network "reserved" {
 | 
			
		||||
        type = "list"
 | 
			
		||||
        list = [
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										
											BIN
										
									
								
								template/blocked-256.jpeg
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										
											BIN
										
									
								
								template/blocked-256.jpeg
									
									
									
									
									
										Normal file
									
								
							
										
											Binary file not shown.
										
									
								
							| 
		 After Width: | Height: | Size: 25 KiB  | 
							
								
								
									
										
											BIN
										
									
								
								template/blocked-512.jpeg
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										
											BIN
										
									
								
								template/blocked-512.jpeg
									
									
									
									
									
										Normal file
									
								
							
										
											Binary file not shown.
										
									
								
							| 
		 After Width: | Height: | Size: 88 KiB  | 
							
								
								
									
										83
									
								
								template/blocked.html
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										83
									
								
								template/blocked.html
									
									
									
									
									
										Normal file
									
								
							
										
											
												File diff suppressed because one or more lines are too long
											
										
									
								
							
							
								
								
									
										
											BIN
										
									
								
								template/blocked.jpeg
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										
											BIN
										
									
								
								template/blocked.jpeg
									
									
									
									
									
										Normal file
									
								
							
										
											Binary file not shown.
										
									
								
							| 
		 After Width: | Height: | Size: 189 KiB  | 
							
								
								
									
										58
									
								
								testdata/policy/bogons.rego
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										58
									
								
								testdata/policy/bogons.rego
									
									
									
									
										vendored
									
									
								
							@@ -1,58 +0,0 @@
 | 
			
		||||
package styx
 | 
			
		||||
 | 
			
		||||
import input.request as http_request
 | 
			
		||||
 | 
			
		||||
default permit := false
 | 
			
		||||
default reject := 0
 | 
			
		||||
default template := ""
 | 
			
		||||
 | 
			
		||||
# Bogon networks
 | 
			
		||||
bogons := [
 | 
			
		||||
    "0.0.0.0/8",          # "This" network
 | 
			
		||||
    "10.0.0.0/8",         # RFC1918 Private-use networks
 | 
			
		||||
    "100.64.0.0/10",      # Carrier-grade NAT
 | 
			
		||||
    "127.0.0.0/8",        # Loopback
 | 
			
		||||
    "169.254.0.0/16",     # Link local
 | 
			
		||||
    "172.16.0.0/12",      # RFC1918 Private-use networks
 | 
			
		||||
    "192.0.0.0/24",       # IETF protocol assignments
 | 
			
		||||
    "192.0.2.0/24",       # TEST-NET-1
 | 
			
		||||
    "192.168.0.0/16",     # RFC1918 Private-use networks
 | 
			
		||||
    "198.18.0.0/15",      # Network interconnect device benchmark testing
 | 
			
		||||
    "198.51.100.0/24",    # TEST-NET-2
 | 
			
		||||
    "203.0.113.0/24",     # TEST-NET-3
 | 
			
		||||
    "224.0.0.0/4",        # Multicast
 | 
			
		||||
    "240.0.0.0/4",        # Reserved for future use
 | 
			
		||||
    "255.255.255.255/32", # Limited broadcast
 | 
			
		||||
]
 | 
			
		||||
 | 
			
		||||
# Resolve HTTP host to IPs
 | 
			
		||||
addrs := styx.lookup_ip_addr(http_request.host)
 | 
			
		||||
 | 
			
		||||
template := "template/blocked.html" if {
 | 
			
		||||
    some cidr in bogons
 | 
			
		||||
    net.cidr_contains(cidr, http_request.host)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template := "template/blocked.html" if {
 | 
			
		||||
    some addr in addrs
 | 
			
		||||
    some cidr in bogons
 | 
			
		||||
    net.cidr_contains(cidr, addr)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
permit if {
 | 
			
		||||
    template == ""
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
errors contains "Bogon destination not allowed" if {
 | 
			
		||||
    template != ""
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
errors contains "Could not lookup host" if {
 | 
			
		||||
    count(addrs) == 0
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
errors contains addr if {
 | 
			
		||||
    some addr in addrs
 | 
			
		||||
    some cidr in bogons
 | 
			
		||||
    net.cidr_contains(cidr, addr)
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										56
									
								
								testdata/policy/childsafe.rego
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										56
									
								
								testdata/policy/childsafe.rego
									
									
									
									
										vendored
									
									
								
							@@ -1,56 +0,0 @@
 | 
			
		||||
package styx
 | 
			
		||||
 | 
			
		||||
import input.client as client
 | 
			
		||||
import input.request as http_request
 | 
			
		||||
 | 
			
		||||
# HTTP -> HTTPS redirects for allowed domains
 | 
			
		||||
redirect = concat("", ["https://", http_request.host, http_request.path]) if {
 | 
			
		||||
    _social
 | 
			
		||||
    http_request.scheme == "http"
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
reject = 403 if {
 | 
			
		||||
    _childsafe_network
 | 
			
		||||
    _social
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
reject = 403 if {
 | 
			
		||||
    _childsafe_network
 | 
			
		||||
    _toxic
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
# Sensitive domains are always allowed
 | 
			
		||||
permit if {
 | 
			
		||||
    _sensitive
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
permit if {
 | 
			
		||||
    reject != 0
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
_sensitive if {
 | 
			
		||||
    styx.in_domains("sensitive", http_request.host)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
_social if {
 | 
			
		||||
    styx.in_domains("social", http_request.host)
 | 
			
		||||
    print("Domain in social", http_request.host)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
errors contains "Social networking domain not allowed" if {
 | 
			
		||||
    reject != 0
 | 
			
		||||
    _social
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
_toxic if {
 | 
			
		||||
    styx.in_domains("toxic", http_request.host)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
errors contains "Toxic domain not allowed" if {
 | 
			
		||||
    reject != 0
 | 
			
		||||
    _toxic
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
_childsafe_network if {
 | 
			
		||||
    styx.in_networks("kids", client.ip)
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										102
									
								
								testdata/policy/custom/childsafe.rego
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										102
									
								
								testdata/policy/custom/childsafe.rego
									
									
									
									
										vendored
									
									
										Normal file
									
								
							@@ -0,0 +1,102 @@
 | 
			
		||||
package custom
 | 
			
		||||
 | 
			
		||||
_social_domains := [
 | 
			
		||||
	"reddit.com",
 | 
			
		||||
	"roblox.com",
 | 
			
		||||
	# X
 | 
			
		||||
	"twitter.com",
 | 
			
		||||
	"x.com",
 | 
			
		||||
	# YouTube
 | 
			
		||||
	"googlevideo.com",
 | 
			
		||||
	"youtube.com",
 | 
			
		||||
	"youtu.be",
 | 
			
		||||
	"ytimg.com",
 | 
			
		||||
]
 | 
			
		||||
 | 
			
		||||
_toxic_domains := [
 | 
			
		||||
	# Facebook
 | 
			
		||||
	"facebook.com",
 | 
			
		||||
	"facebook.net",
 | 
			
		||||
	"fbsbx.com",
 | 
			
		||||
	# Pinterest
 | 
			
		||||
	"pinterest.com",
 | 
			
		||||
	# TikTok
 | 
			
		||||
	"isnssdk.com",
 | 
			
		||||
	"musical.ly",
 | 
			
		||||
	"musically.app.link",
 | 
			
		||||
	"musically-alternate.app.link",
 | 
			
		||||
	"musemuse.cn",
 | 
			
		||||
	"sgsnssdk.com",
 | 
			
		||||
	"tiktok.com",
 | 
			
		||||
	"tiktok.org",
 | 
			
		||||
	"tiktokcdn.com",
 | 
			
		||||
	"tiktokcdn-eu.com",
 | 
			
		||||
	"tiktokv.com",
 | 
			
		||||
]
 | 
			
		||||
 | 
			
		||||
in_domains(list, name) if {
 | 
			
		||||
	some item in list
 | 
			
		||||
	lower(name) == lower(item)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
in_domains(list, name) if {
 | 
			
		||||
	some item in list
 | 
			
		||||
	endswith(lower(name), sprintf(".%s", [lower(item)]))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
# METADATA
 | 
			
		||||
# description: Apply childssfe rules to the request, reject if it's a social
 | 
			
		||||
#              site between off-hours, reject if it's toxic.
 | 
			
		||||
# entrypoint: true
 | 
			
		||||
default redirect := ""
 | 
			
		||||
 | 
			
		||||
# HTTP -> HTTPS redirects for allowed domains
 | 
			
		||||
redirect := location if {
 | 
			
		||||
	_social
 | 
			
		||||
	input.request.scheme == "http"
 | 
			
		||||
	location := sprintf("https://%s%s", [input.request.host, input.request.path])
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
default reject := 0
 | 
			
		||||
 | 
			
		||||
template := "template/blocked.html" if {
 | 
			
		||||
	_childsafe_network
 | 
			
		||||
	_social
 | 
			
		||||
	# styx.time_between("18:00", "16:00") # allowed between 16:00-18:00
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template := "template/blocked.html" if {
 | 
			
		||||
	_toxic
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
# Sensitive domains are always allowed
 | 
			
		||||
permit if {
 | 
			
		||||
	_sensitive
 | 
			
		||||
	reject != 0
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
_sensitive if {
 | 
			
		||||
	styx.domains_contain("sensitive", input.request.host)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
_social if {
 | 
			
		||||
	#styx.domains_contain("social", input.request.host)
 | 
			
		||||
	in_domains(_social_domains, input.request.host)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
_toxic if {
 | 
			
		||||
	in_domains(_toxic_domains, input.request.host)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
_childsafe_network if {
 | 
			
		||||
	styx.networks_contain("kids", input.client.ip)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
errors contains "Request to social networking site outside of allowed hours" if {
 | 
			
		||||
	_childsafe_network
 | 
			
		||||
	_social
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
errors contains "Request to toxic site" if {
 | 
			
		||||
	_toxic
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										21
									
								
								testdata/policy/intercept.rego
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										21
									
								
								testdata/policy/intercept.rego
									
									
									
									
										vendored
									
									
								
							@@ -1,21 +0,0 @@
 | 
			
		||||
package styx.intercept
 | 
			
		||||
 | 
			
		||||
reject := 403 if {
 | 
			
		||||
    _target_blocked
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template := "template/intercepted.html" if {
 | 
			
		||||
    _target_blocked
 | 
			
		||||
} 
 | 
			
		||||
 | 
			
		||||
errors contains "Intercepted" if {
 | 
			
		||||
    _target_blocked
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
_target_blocked if {
 | 
			
		||||
    styx.in_domains("bad", input.request.host)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
_target_blocked if {
 | 
			
		||||
    styx.in_networks("bogons", input.client.ip)
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										54
									
								
								testdata/policy/styx/bogons.rego
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										54
									
								
								testdata/policy/styx/bogons.rego
									
									
									
									
										vendored
									
									
										Normal file
									
								
							@@ -0,0 +1,54 @@
 | 
			
		||||
package styx
 | 
			
		||||
 | 
			
		||||
# Bogon networks
 | 
			
		||||
_bogons := [
 | 
			
		||||
	"0.0.0.0/8", # "This" network
 | 
			
		||||
	"10.0.0.0/8", # RFC1918 Private-use networks
 | 
			
		||||
	"100.64.0.0/10", # Carrier-grade NAT
 | 
			
		||||
	"127.0.0.0/8", # Loopback
 | 
			
		||||
	"169.254.0.0/16", # Link local
 | 
			
		||||
	"172.16.0.0/12", # RFC1918 Private-use networks
 | 
			
		||||
	"192.0.0.0/24", # IETF protocol assignments
 | 
			
		||||
	"192.0.2.0/24", # TEST-NET-1
 | 
			
		||||
	"192.168.0.0/16", # RFC1918 Private-use networks
 | 
			
		||||
	"198.18.0.0/15", # Network interconnect device benchmark testing
 | 
			
		||||
	"198.51.100.0/24", # TEST-NET-2
 | 
			
		||||
	"203.0.113.0/24", # TEST-NET-3
 | 
			
		||||
	"224.0.0.0/4", # Multicast
 | 
			
		||||
	"240.0.0.0/4", # Reserved for future use
 | 
			
		||||
	"255.255.255.255/32", # Limited broadcast
 | 
			
		||||
]
 | 
			
		||||
 | 
			
		||||
# METADATA
 | 
			
		||||
# description: Reject requests to bogon targets.
 | 
			
		||||
# entrypoint: true
 | 
			
		||||
default permit := false
 | 
			
		||||
 | 
			
		||||
permit if {
 | 
			
		||||
	template == ""
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
default template := ""
 | 
			
		||||
 | 
			
		||||
template := "template/blocked.html" if {
 | 
			
		||||
	_bogon
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
errors contains "Bogon destination not allowed" if {
 | 
			
		||||
	_bogon
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
errors contains _bogon if {
 | 
			
		||||
	_bogon
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
_bogon := addr if {
 | 
			
		||||
	some addr in styx.lookup_ip_addr(input.request.host)
 | 
			
		||||
	some cidr in _bogons
 | 
			
		||||
	net.cidr_contains(cidr, addr)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
_bogon := input.request.host if {
 | 
			
		||||
	some cidr in _bogons
 | 
			
		||||
	net.cidr_contains(cidr, input.request.host)
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										25
									
								
								testdata/policy/styx/intercept.rego
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										25
									
								
								testdata/policy/styx/intercept.rego
									
									
									
									
										vendored
									
									
										Normal file
									
								
							@@ -0,0 +1,25 @@
 | 
			
		||||
package styx.intercept
 | 
			
		||||
 | 
			
		||||
reject := 403 if {
 | 
			
		||||
	_bad
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template := "template/blocked.html" if {
 | 
			
		||||
	_bogon
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
errors contains "Bad domain" if {
 | 
			
		||||
	_bad
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
errors contains "Bogon target" if {
 | 
			
		||||
	_bogon
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
_bad if {
 | 
			
		||||
	styx.domains_contain("bad", input.request.host)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
_bogon if {
 | 
			
		||||
	styx.domains_contain("bogons", input.client.ip)
 | 
			
		||||
}
 | 
			
		||||
		Reference in New Issue
	
	Block a user