Checkpoint

2025-10-06 22:25:23 +02:00
parent a23259cfdc
commit a254b306f2
48 changed files with 3327 additions and 212 deletions

4
.gitignore vendored

@@ -1,4 +1,6 @@
# Database file
*.bolt
*.boltdb
*.db
# Log files

14
.regal.yaml Normal file

@@ -0,0 +1,14 @@
rules:
idiomatic:
directory-package-mismatch:
level: ignore
style:
function-arg-return:
level: error
except-functions:
- sprintf
project:
roots:
- testdata/policy

146
admin/admin.go Normal file

@@ -0,0 +1,146 @@
package admin
import (
"bytes"
"encoding/json"
"errors"
"io"
"net/http"
"os"
"strconv"
"sync"
"git.maze.io/maze/styx/dataset"
"git.maze.io/maze/styx/logger"
"git.maze.io/maze/styx/proxy"
)
type Admin struct {
Storage dataset.Storage
setupOnce sync.Once
mux *http.ServeMux
api *http.ServeMux
}
type apiError struct {
Code int
Err error
}
func (err apiError) Error() string {
return err.Err.Error()
}
func (a *Admin) setup() {
a.mux = http.NewServeMux()
a.api = http.NewServeMux()
a.api.HandleFunc("GET /groups", a.apiGroups)
a.api.HandleFunc("POST /group", a.apiGroupCreate)
a.api.HandleFunc("GET /group/{id}", a.apiGroup)
a.api.HandleFunc("PATCH /group/{id}", a.apiGroupUpdate)
a.api.HandleFunc("DELETE /group/{id}", a.apiGroupDelete)
a.api.HandleFunc("GET /clients", a.apiClients)
a.api.HandleFunc("GET /client/{id}", a.apiClient)
a.api.HandleFunc("POST /client", a.apiClientCreate)
a.api.HandleFunc("PATCH /client/{id}", a.apiClientUpdate)
a.api.HandleFunc("DELETE /client/{id}", a.apiClientDelete)
a.api.HandleFunc("GET /lists", a.apiLists)
a.api.HandleFunc("POST /list", a.apiListCreate)
a.api.HandleFunc("GET /list/{id}", a.apiList)
a.api.HandleFunc("DELETE /list/{id}", a.apiListDelete)
}
type Handler interface {
Handle(pattern string, handler http.Handler)
}
func (a *Admin) Install(handler Handler) {
a.setupOnce.Do(a.setup)
handler.Handle("/api/v1/", http.StripPrefix("/api/v1", a.api))
}
func (a *Admin) handleAPIError(w http.ResponseWriter, r *http.Request, err error) {
code := http.StatusBadRequest
var aerr apiError
switch {
case dataset.IsNotExist(err):
code = http.StatusNotFound
case os.IsPermission(err):
code = http.StatusForbidden
case errors.As(err, &aerr):
// errors.Is against a zero apiError never matches; errors.As extracts the concrete value.
if aerr.Code > 0 {
code = aerr.Code
}
}
logger.StandardLog.Err(err).Values(logger.Values{
"code": code,
"client": r.RemoteAddr,
"method": r.Method,
"path": r.URL.Path,
}).Warn("Unexpected API error encountered")
var data []byte
if err, ok := err.(apiError); ok {
data, _ = json.Marshal(struct {
Code int `json:"code"`
Error string `json:"error"`
}{code, err.Error()})
} else {
data, _ = json.Marshal(struct {
Code int `json:"code"`
Error string `json:"error"`
}{code, http.StatusText(code)})
}
res := proxy.NewResponse(code, io.NopCloser(bytes.NewReader(data)), r)
res.Header.Set(proxy.HeaderContentType, "application/json")
for k, vv := range res.Header {
if len(vv) >= 1 {
w.Header().Set(k, vv[0])
for _, v := range vv[1:] {
w.Header().Add(k, v)
}
}
}
w.WriteHeader(code)
io.Copy(w, res.Body)
}
func (a *Admin) jsonResponse(w http.ResponseWriter, r *http.Request, value any, codes ...int) {
var (
code = http.StatusNoContent
body io.ReadCloser
size int64
)
if value != nil {
data, err := json.Marshal(value)
if err != nil {
a.handleAPIError(w, r, err)
return
}
code = http.StatusOK
body = io.NopCloser(bytes.NewReader(data))
size = int64(len(data))
}
if len(codes) > 0 {
code = codes[0]
}
res := proxy.NewResponse(code, body, r)
res.Close = true
res.Header.Set(proxy.HeaderContentLength, strconv.FormatInt(size, 10))
res.Header.Set(proxy.HeaderContentType, "application/json")
for k, vv := range res.Header {
if len(vv) >= 1 {
w.Header().Set(k, vv[0])
for _, v := range vv[1:] {
w.Header().Add(k, v)
}
}
}
w.WriteHeader(code)
io.Copy(w, res.Body)
}
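
As a usage sketch (not part of this commit; the storage path and listen address are assumptions), the admin API can be mounted on a plain *http.ServeMux, which satisfies the Handler interface:

package main

import (
	"net/http"

	"git.maze.io/maze/styx/admin"
	"git.maze.io/maze/styx/dataset"
)

func main() {
	// Open the bstore-backed storage; the path is a placeholder.
	storage, err := dataset.OpenBStore("styx.db")
	if err != nil {
		panic(err)
	}
	a := &admin.Admin{Storage: storage}

	mux := http.NewServeMux() // *http.ServeMux implements Handle(pattern, handler)
	a.Install(mux)            // mounts the REST API under /api/v1/
	_ = http.ListenAndServe("127.0.0.1:8080", mux)
}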

183
admin/api_client.go Normal file

@@ -0,0 +1,183 @@
package admin
import (
"encoding/json"
"errors"
"fmt"
"log"
"net"
"net/http"
"strconv"
"time"
"git.maze.io/maze/styx/dataset"
)
func (a *Admin) apiClients(w http.ResponseWriter, r *http.Request) {
clients, err := a.Storage.Clients()
if err != nil {
a.handleAPIError(w, r, err)
return
}
a.jsonResponse(w, r, clients)
}
func (a *Admin) apiClient(w http.ResponseWriter, r *http.Request) {
id, err := strconv.ParseInt(r.PathValue("id"), 10, 64)
if err != nil {
a.handleAPIError(w, r, err)
return
}
client, err := a.Storage.ClientByID(id)
if err != nil {
a.handleAPIError(w, r, err)
return
}
a.jsonResponse(w, r, client)
}
func (a *Admin) apiClientCreate(w http.ResponseWriter, r *http.Request) {
var request struct {
dataset.Client
Groups []int64 `json:"groups"`
ID int64 `json:"id"` // mask, not used
CreatedAt time.Time `json:"created_at"` // mask, not used
UpdatedAt time.Time `json:"updated_at"` // mask, not used
}
if err := json.NewDecoder(r.Body).Decode(&request); err != nil {
a.handleAPIError(w, r, err)
return
}
if err := a.verifyClient(&request.Client); err != nil {
a.handleAPIError(w, r, err)
return
}
var groups []dataset.Group
for _, id := range request.Groups {
group, err := a.Storage.GroupByID(id)
if err != nil {
a.handleAPIError(w, r, err)
return
}
groups = append(groups, group)
}
request.Client.Groups = groups
if err := a.Storage.SaveClient(&request.Client); err != nil {
a.handleAPIError(w, r, err)
return
}
a.jsonResponse(w, r, request.Client)
}
func (a *Admin) apiClientUpdate(w http.ResponseWriter, r *http.Request) {
id, err := strconv.ParseInt(r.PathValue("id"), 10, 64)
if err != nil {
a.handleAPIError(w, r, err)
return
}
client, err := a.Storage.ClientByID(id)
if err != nil {
a.handleAPIError(w, r, err)
return
}
log.Printf("updating: %#+v", client)
var request struct {
dataset.Client
Groups []int64 `json:"groups"`
}
if err := json.NewDecoder(r.Body).Decode(&request); err != nil {
a.handleAPIError(w, r, err)
return
}
if err := a.verifyClient(&request.Client); err != nil {
a.handleAPIError(w, r, err)
return
}
client.IP = request.Client.IP
client.Mask = request.Client.Mask
client.Description = request.Client.Description
client.Groups = client.Groups[:0]
for _, id := range request.Groups {
group, err := a.Storage.GroupByID(id)
if err != nil {
a.handleAPIError(w, r, err)
return
}
client.Groups = append(client.Groups, group)
}
if err := a.Storage.SaveClient(&client); err != nil {
a.handleAPIError(w, r, err)
return
}
a.jsonResponse(w, r, client)
}
func (a *Admin) apiClientDelete(w http.ResponseWriter, r *http.Request) {
id, err := strconv.ParseInt(r.PathValue("id"), 10, 64)
if err != nil {
a.handleAPIError(w, r, err)
return
}
client, err := a.Storage.ClientByID(id)
if err != nil {
a.handleAPIError(w, r, err)
return
}
if err = a.Storage.DeleteClient(client); err != nil {
a.handleAPIError(w, r, err)
return
}
a.jsonResponse(w, r, nil)
}
func (a *Admin) verifyClient(c *dataset.Client) (err error) {
ip := net.ParseIP(c.IP)
switch c.Network {
case "ipv4":
if ip.To4() == nil {
return apiError{Err: errors.New("invalid IPv4 address")}
}
if c.Mask == 0 {
c.Mask = 32 // one IP
}
if c.Mask <= 0 || c.Mask > 32 {
return apiError{Err: errors.New("IPv4 mask must be between 1 and 32")}
}
c.IP = ip.Mask(net.CIDRMask(int(c.Mask), 32)).String()
case "ipv6":
if ip.To16() == nil {
return apiError{Err: errors.New("invalid IPv6 address")}
}
if c.Mask == 0 {
c.Mask = 128 // one IP
}
if c.Mask <= 0 || c.Mask > 128 {
return apiError{Err: errors.New("IPv6 mask must be between 1 and 128")}
}
c.IP = ip.Mask(net.CIDRMask(int(c.Mask), 128)).String()
case "":
if ip.To4() != nil {
c.Network = "ipv4"
} else if ip.To16() != nil {
c.Network = "ipv6"
} else {
return apiError{Err: errors.New("invalid IP address")}
}
return a.verifyClient(c)
default:
return apiError{Err: fmt.Errorf("invalid network %q", c.Network)}
}
return
}
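
For illustration (the endpoint path comes from setup() above; host, port, and field values are assumptions), creating a client through this API could look like:

package main

import (
	"log"
	"net/http"
	"strings"
)

func main() {
	// Field names follow the create-client request struct; values are examples.
	body := strings.NewReader(`{"network":"ipv4","ip":"192.168.1.0","mask":24,"description":"LAN clients","groups":[1]}`)
	resp, err := http.Post("http://127.0.0.1:8080/api/v1/client", "application/json", body)
	if err != nil {
		log.Fatal(err)
	}
	defer resp.Body.Close()
	log.Println("status:", resp.Status)
}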

72
admin/api_group.go Normal file

@@ -0,0 +1,72 @@
package admin
import (
"encoding/json"
"net/http"
"strconv"
"time"
"git.maze.io/maze/styx/dataset"
)
func (a *Admin) apiGroups(w http.ResponseWriter, r *http.Request) {
groups, err := a.Storage.Groups()
if err != nil {
a.handleAPIError(w, r, err)
return
}
a.jsonResponse(w, r, groups)
}
func (a *Admin) apiGroup(w http.ResponseWriter, r *http.Request) {
id, err := strconv.ParseInt(r.PathValue("id"), 10, 64)
if err != nil {
a.handleAPIError(w, r, err)
return
}
group, err := a.Storage.GroupByID(id)
if err != nil {
a.handleAPIError(w, r, err)
return
}
a.jsonResponse(w, r, group)
}
func (a *Admin) apiGroupCreate(w http.ResponseWriter, r *http.Request) {
var request struct {
dataset.Group
ID int64 `json:"id"` // mask, not used
CreatedAt time.Time `json:"created_at"` // mask, not used
UpdatedAt time.Time `json:"updated_at"` // mask, not used
}
if err := json.NewDecoder(r.Body).Decode(&request); err != nil {
a.handleAPIError(w, r, err)
return
}
if err := a.Storage.SaveGroup(&request.Group); err != nil {
a.handleAPIError(w, r, err)
return
}
a.jsonResponse(w, r, request.Group, http.StatusCreated)
}
func (a *Admin) apiGroupUpdate(w http.ResponseWriter, r *http.Request) {
}
func (a *Admin) apiGroupDelete(w http.ResponseWriter, r *http.Request) {
id, err := strconv.ParseInt(r.PathValue("id"), 10, 64)
if err != nil {
a.handleAPIError(w, r, err)
return
}
group, err := a.Storage.GroupByID(id)
if err != nil {
a.handleAPIError(w, r, err)
return
}
if err = a.Storage.DeleteGroup(group); err != nil {
a.handleAPIError(w, r, err)
return
}
a.jsonResponse(w, r, nil)
}

98
admin/api_list.go Normal file

@@ -0,0 +1,98 @@
package admin
import (
"encoding/json"
"fmt"
"net/http"
"strconv"
"time"
"git.maze.io/maze/styx/dataset"
)
func (a *Admin) apiLists(w http.ResponseWriter, r *http.Request) {
lists, err := a.Storage.Lists()
if err != nil {
a.handleAPIError(w, r, err)
return
}
a.jsonResponse(w, r, lists)
}
func (a *Admin) apiList(w http.ResponseWriter, r *http.Request) {
id, err := strconv.ParseInt(r.PathValue("id"), 10, 64)
if err != nil {
a.handleAPIError(w, r, err)
return
}
list, err := a.Storage.ListByID(id)
if err != nil {
a.handleAPIError(w, r, err)
return
}
a.jsonResponse(w, r, list)
}
func (a *Admin) apiListCreate(w http.ResponseWriter, r *http.Request) {
var request struct {
dataset.List
Groups []int64 `json:"groups"`
ID int64 `json:"id"` // mask, not used
CreatedAt time.Time `json:"created_at"` // mask, not used
UpdatedAt time.Time `json:"updated_at"` // mask, not used
}
if err := json.NewDecoder(r.Body).Decode(&request); err != nil {
a.handleAPIError(w, r, err)
return
}
if err := a.verifyList(&request.List); err != nil {
a.handleAPIError(w, r, err)
return
}
request.List.Groups = request.List.Groups[:0]
for _, id := range request.Groups {
group, err := a.Storage.GroupByID(id)
if err != nil {
a.handleAPIError(w, r, err)
return
}
request.List.Groups = append(request.List.Groups, group)
}
if err := a.Storage.SaveList(&request.List); err != nil {
a.handleAPIError(w, r, err)
return
}
a.jsonResponse(w, r, request.List)
}
func (a *Admin) apiListDelete(w http.ResponseWriter, r *http.Request) {
id, err := strconv.ParseInt(r.PathValue("id"), 10, 64)
if err != nil {
a.handleAPIError(w, r, err)
return
}
list, err := a.Storage.ListByID(id)
if err != nil {
a.handleAPIError(w, r, err)
return
}
if err = a.Storage.DeleteList(list); err != nil {
a.handleAPIError(w, r, err)
return
}
a.jsonResponse(w, r, nil)
}
func (a *Admin) verifyList(list *dataset.List) error {
switch list.Type {
case dataset.ListTypeDomain, dataset.ListTypeNetwork:
default:
return apiError{Err: fmt.Errorf("unknown list type %q", list.Type)}
}
return nil
}

119
ca/authority.go Normal file

@@ -0,0 +1,119 @@
package ca
import (
"crypto"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"fmt"
"math/big"
"net"
"strings"
"sync"
"time"
"git.maze.io/maze/styx/internal/cryptutil"
"git.maze.io/maze/styx/logger"
"github.com/miekg/dns"
)
type CertificateAuthority interface {
GetCertificate(commonName string, dnsNames []string, ips []net.IP) (*tls.Certificate, error)
}
type ca struct {
cert *x509.Certificate
key crypto.PrivateKey
cache sync.Map
}
func Open(certData, keyData string) (CertificateAuthority, error) {
cert, key, err := cryptutil.LoadKeyPair(certData, keyData)
if err != nil {
return nil, err
} else if !cert.IsCA {
return nil, fmt.Errorf("ca: certificate for %s is not a certificate authority", cert.Subject.String())
}
return &ca{
cert: cert,
key: key,
}, nil
}
func (ca *ca) GetCertificate(cn string, names []string, ips []net.IP) (*tls.Certificate, error) {
var (
log = logger.StandardLog.Values(logger.Values{
"cn": cn,
"names": names,
"ips": ips,
})
now = time.Now().UTC()
parent = parentDomain(cn)
)
if cn == parent {
names = append(names, "*."+cn)
} else {
names = append(names, "*."+parent, cn)
cn = parent
log = log.Value("cn", cn)
}
if v, ok := ca.cache.Load(parent); ok {
if cert, ok := v.(*tls.Certificate); ok && now.After(cert.Leaf.NotBefore) && now.Before(cert.Leaf.NotAfter.Add(-time.Hour)) {
log.Value("valid", cert.Leaf.NotAfter.Sub(now)).Debug("Using cached certificate")
return cert, nil
}
log.Debug("Cached certificate invalid")
ca.cache.Delete(parent)
}
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
if err != nil {
return nil, fmt.Errorf("ca: failed to generate serial number: %w", err)
}
notBefore := now.Truncate(24 * time.Hour) // truncate (not round) so NotBefore is never in the future
notAfter := notBefore.Add(48 * time.Hour)
log.Values(logger.Values{
"serial": serialNumber.String(),
"subject": pkix.Name{CommonName: cn}.String(),
}).Debug("Generating certificate")
template := &x509.Certificate{
SerialNumber: serialNumber,
KeyUsage: x509.KeyUsageDataEncipherment | x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
Subject: pkix.Name{CommonName: cn},
DNSNames: names,
IPAddresses: ips,
PublicKey: cryptutil.PublicKey(ca.key),
NotBefore: notBefore,
NotAfter: notAfter,
}
der, err := x509.CreateCertificate(rand.Reader, template, ca.cert, template.PublicKey, ca.key)
if err != nil {
return nil, err
}
cert, err := x509.ParseCertificate(der)
if err != nil {
return nil, err
}
output := &tls.Certificate{
Certificate: [][]byte{der},
Leaf: cert,
PrivateKey: ca.key,
}
ca.cache.Store(parent, output)
return output, nil
}
func parentDomain(name string) string {
part := dns.SplitDomainName(name)
if len(part) <= 2 {
return name
}
return strings.Join(part[1:], ".")
}
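
A minimal sketch (not in this commit) of how the authority could back a TLS handshake, assuming the intercepting listener only has the SNI name to work with:

package interceptor // hypothetical package using the ca package

import (
	"crypto/tls"

	"git.maze.io/maze/styx/ca"
)

// tlsConfigFor returns a tls.Config that mints per-host certificates on demand.
func tlsConfigFor(authority ca.CertificateAuthority) *tls.Config {
	return &tls.Config{
		GetCertificate: func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
			// ServerName comes from SNI; extra DNS names and IPs are omitted in this sketch.
			return authority.GetCertificate(hello.ServerName, nil, nil)
		},
	}
}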


@@ -8,6 +8,7 @@ import (
"github.com/hashicorp/hcl/v2/gohcl"
"github.com/hashicorp/hcl/v2/hclsimple"
"git.maze.io/maze/styx/ca"
"git.maze.io/maze/styx/dataset"
"git.maze.io/maze/styx/internal/cryptutil"
"git.maze.io/maze/styx/logger"
@@ -18,6 +19,7 @@ import (
type Config struct {
Proxy ProxyConfig `hcl:"proxy,block"`
Policy []PolicyConfig `hcl:"policy,block"`
CA *CAConfig `hcl:"ca,block"`
Data DataConfig `hcl:"data,block"`
}
@@ -145,8 +147,18 @@ type PolicyConfig struct {
Package string `hcl:"package,optional"`
}
type CAConfig struct {
Cert string `hcl:"cert"`
Key string `hcl:"key,optional"`
}
func (c CAConfig) CertificateAuthority() (ca.CertificateAuthority, error) {
return ca.Open(c.Cert, c.Key)
}
type DataConfig struct {
Path string `hcl:"path,optional"`
Storage DataStorageConfig `hcl:"storage,block"`
Domains []DomainDataConfig `hcl:"domain,block"`
Networks []NetworkDataConfig `hcl:"network,block"`
}
@@ -165,6 +177,39 @@ func (c DataConfig) Configure() error {
return nil
}
func (c DataConfig) OpenStorage() (dataset.Storage, error) {
switch c.Storage.Type {
case "", "bolt", "boltdb":
var config struct {
Path string `hcl:"path"`
}
if diag := gohcl.DecodeBody(c.Storage.Body, nil, &config); diag.HasErrors() {
return nil, diag
}
//return dataset.OpenBolt(config.Path)
return dataset.OpenBStore(config.Path)
/*
case "sqlite", "sqlite3":
var config struct {
Path string `hcl:"path"`
}
if diag := gohcl.DecodeBody(c.Storage.Body, nil, &config); diag.HasErrors() {
return nil, diag
}
return dataset.OpenSQLite(config.Path)
*/
default:
return nil, fmt.Errorf("storage: no %q driver", c.Storage.Type)
}
}
type DataStorageConfig struct {
Type string `hcl:"type"`
Body hcl.Body `hcl:",remain"`
}
type DomainDataConfig struct {
Name string `hcl:"name,label"`
Type string `hcl:"type"`


@@ -7,6 +7,9 @@ import (
"os/signal"
"syscall"
"git.maze.io/maze/styx/admin"
"git.maze.io/maze/styx/ca"
"git.maze.io/maze/styx/dataset"
"git.maze.io/maze/styx/logger"
"git.maze.io/maze/styx/proxy"
)
@@ -40,6 +43,22 @@ func main() {
log.Err(err).Fatal("Invalid data configuration")
}
var ca ca.CertificateAuthority
if config.CA != nil {
if ca, err = config.CA.CertificateAuthority(); err != nil {
log.Err(err).Fatal("Invalid ca configuration")
}
}
var storage dataset.Storage
if storage, err = config.Data.OpenStorage(); err != nil {
log.Err(err).Fatal("Invalid data.storage configuration")
}
admin := &admin.Admin{
Storage: storage,
}
proxies, err := config.Proxies(log)
if err != nil {
log.Err(err).Fatal("Error configuring proxy ports")
@@ -52,6 +71,9 @@ func main() {
)
for i, p := range proxies {
p.CertificateAuthority = ca
p.Storage = storage
admin.Install(p)
go run(config.Proxy.Port[i].Listen, p, errs)
}
@@ -64,12 +86,18 @@ func main() {
case syscall.SIGHUP:
log.Value("signal", sig.String()).Warn("Ignored reload signal ¯\\_(ツ)_/¯")
default:
log.Value("signal", sig.String()).Warn("Shutting down on signal")
close(done)
}
case <-done:
log.Warn("Shutting down gracefully")
for i, p := range proxies {
log.Value("port", config.Proxy.Port[i].Listen).Info("Proxy port closing")
if err := p.Close(); err != nil {
log.Err(err).Error("Error closing proxy")
}
}
return
case err = <-errs:

1
dataset/base.go Normal file

@@ -0,0 +1 @@
package dataset

25
dataset/error.go Normal file

@@ -0,0 +1,25 @@
package dataset
import (
"errors"
"fmt"
"os"
"github.com/mjl-/bstore"
)
type ErrNotExist struct {
Object string
ID int64
}
func (err ErrNotExist) Error() string {
return fmt.Sprintf("storage: %s not found", err.Object)
}
func IsNotExist(err error) bool {
if err == nil {
return false
}
var notExist ErrNotExist
return os.IsNotExist(err) || errors.As(err, &notExist) || errors.Is(err, bstore.ErrAbsent)
}

53
dataset/parser/adblock.go Normal file

@@ -0,0 +1,53 @@
package parser
import (
"bufio"
"io"
"strings"
)
func init() {
RegisterDomainsParser(adblockDomainsParser{})
}
type adblockDomainsParser struct{}
func (adblockDomainsParser) CanHandle(line string) bool {
return strings.HasPrefix(strings.ToLower(line), `[adblock`) ||
strings.HasPrefix(line, "@@") || // exception
strings.HasPrefix(line, "||") || // domain anchor
line[0] == '*'
}
func (adblockDomainsParser) ParseDomains(r io.Reader) (domains []string, ignored int, err error) {
scanner := bufio.NewScanner(r)
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if isComment(line) {
continue
}
// Common AdBlock patterns:
// ||domain.com^
// |http://domain.com|
// domain.com/path
// *domain.com*
switch {
case strings.HasPrefix(line, `||`): // domain anchor
if i := strings.IndexByte(line, '^'); i != -1 {
domains = append(domains, line[2:i])
continue
}
case strings.HasPrefix(line, `|`) && strings.HasSuffix(line, `|`):
domains = append(domains, line[1:len(line)-1])
continue
case strings.HasPrefix(line, `[`):
continue
}
ignored++
}
if err = scanner.Err(); err != nil {
return
}
return unique(domains), ignored, nil
}


@@ -0,0 +1,41 @@
package parser
import (
"reflect"
"sort"
"strings"
"testing"
)
func TestAdBlockParser(t *testing.T) {
test := `[Adblock Plus 2.0]
! Title: AdRules DNS List
! Homepage: https://github.com/Cats-Team/AdRules
! Powerd by Cats-Team
! Expires: 1 (update frequency)
! Description: The DNS Filters
! Total count: 145270
! Update: 2025-10-07 02:05:08(GMT+8)
/^.+stat\.kugou\.com/
/^admarvel\./
||*-ad-sign.byteimg.com^
||*-ad.a.yximgs.com^
||*-applog.fqnovel.com^
||*-datareceiver.aki-game.net^
||*.exaapi.com^`
want := []string{"*-ad-sign.byteimg.com", "*-ad.a.yximgs.com", "*-applog.fqnovel.com", "*-datareceiver.aki-game.net", "*.exaapi.com"}
parsed, ignored, err := ParseDomains(strings.NewReader(test))
if err != nil {
t.Fatal(err)
return
}
sort.Strings(parsed)
if !reflect.DeepEqual(parsed, want) {
t.Errorf("expected ParseDomains(domains) to return %v, got %v", want, parsed)
}
if ignored != 2 {
t.Errorf("expected 2 ignored, got %d", ignored)
}
}

139
dataset/parser/dns.go Normal file

@@ -0,0 +1,139 @@
package parser
import (
"bufio"
"io"
"strings"
"github.com/miekg/dns"
)
func init() {
RegisterDomainsParser(dnsmasqDomainsParser{})
RegisterDomainsParser(mosDNSDomainsParser{})
RegisterDomainsParser(smartDNSDomainsParser{})
RegisterDomainsParser(unboundDomainsParser{})
}
type dnsmasqDomainsParser struct{}
func (dnsmasqDomainsParser) CanHandle(line string) bool {
return strings.HasPrefix(line, "address=/")
}
func (dnsmasqDomainsParser) ParseDomains(r io.Reader) (domains []string, ignored int, err error) {
scanner := bufio.NewScanner(r)
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if isComment(line) {
continue
}
switch {
case strings.HasPrefix(line, "address=/"):
part := strings.FieldsFunc(line, func(r rune) bool { return r == '/' })
if len(part) >= 3 && isDomainName(part[1]) {
domains = append(domains, part[1])
continue
}
}
ignored++
}
if err = scanner.Err(); err != nil {
return
}
return unique(domains), ignored, nil
}
type mosDNSDomainsParser struct{}
func (mosDNSDomainsParser) CanHandle(line string) bool {
if strings.HasPrefix(line, "domain:") {
return isDomainName(line[7:])
}
return false
}
func (mosDNSDomainsParser) ParseDomains(r io.Reader) (domains []string, ignored int, err error) {
scanner := bufio.NewScanner(r)
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if isComment(line) {
continue
}
if strings.HasPrefix(line, "domain:") {
domains = append(domains, line[7:])
continue
}
ignored++
}
if err = scanner.Err(); err != nil {
return
}
return unique(domains), ignored, nil
}
type smartDNSDomainsParser struct{}
func (smartDNSDomainsParser) CanHandle(line string) bool {
return strings.HasPrefix(line, "address /")
}
func (smartDNSDomainsParser) ParseDomains(r io.Reader) (domains []string, ignored int, err error) {
scanner := bufio.NewScanner(r)
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if isComment(line) {
continue
}
if strings.HasPrefix(line, "address /") {
if i := strings.IndexByte(line[9:], '/'); i > -1 {
domains = append(domains, line[9:i+9])
continue
}
}
ignored++
}
if err = scanner.Err(); err != nil {
return
}
return unique(domains), ignored, nil
}
type unboundDomainsParser struct{}
func (unboundDomainsParser) CanHandle(line string) bool {
return strings.HasPrefix(line, "local-data:") ||
strings.HasPrefix(line, "local-zone:")
}
func (unboundDomainsParser) ParseDomains(r io.Reader) (domains []string, ignored int, err error) {
scanner := bufio.NewScanner(r)
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if isComment(line) {
continue
}
switch {
case strings.HasPrefix(line, "local-data:"):
record := strings.Trim(strings.TrimSpace(line[11:]), `"`)
if rr, err := dns.NewRR(record); err == nil {
switch rr.Header().Rrtype {
case dns.TypeA, dns.TypeAAAA, dns.TypeCNAME:
domains = append(domains, strings.Trim(rr.Header().Name, `.`))
continue
}
}
case strings.HasPrefix(line, "local-zone:") && strings.HasSuffix(line, " reject"):
line = strings.Trim(strings.TrimSpace(line[11:]), `"`)
if i := strings.IndexByte(line, '"'); i > -1 {
domains = append(domains, line[:i])
continue
}
}
ignored++
}
if err = scanner.Err(); err != nil {
return
}
return unique(domains), ignored, nil
}

106
dataset/parser/dns_test.go Normal file

@@ -0,0 +1,106 @@
package parser
import (
"reflect"
"sort"
"strings"
"testing"
)
func TestDNSMasqParser(t *testing.T) {
tests := []struct {
Name string
Test string
Want []string
WantIgnored int
}{
{
"data",
`
local-data: "junk1.doubleclick.net A 127.0.0.1"
local-data: "junk2.doubleclick.net A 127.0.0.1"
local-data: "junk2.doubleclick.net CNAME doubleclick.net."
local-data: "junk6.doubleclick.net AAAA ::1"
local-data: "doubleclick.net A 127.0.0.1"
local-data: "ad.junk1.doubleclick.net A 127.0.0.1"
local-data: "adjunk.google.com A 127.0.0.1"`,
[]string{"ad.junk1.doubleclick.net", "adjunk.google.com", "doubleclick.net", "junk1.doubleclick.net", "junk2.doubleclick.net", "junk6.doubleclick.net"},
0,
},
{
"zone",
`
local-zone: "doubleclick.net" reject
local-zone: "adjunk.google.com" reject`,
[]string{"adjunk.google.com", "doubleclick.net"},
0,
},
{
"address",
`
address=/ziyu.net/0.0.0.0
address=/zlp6s.pw/0.0.0.0
address=/zm232.com/0.0.0.0
`,
[]string{"ziyu.net", "zlp6s.pw", "zm232.com"},
0,
},
}
for _, test := range tests {
t.Run(test.Name, func(it *testing.T) {
parsed, ignored, err := ParseDomains(strings.NewReader(test.Test))
if err != nil {
it.Fatal(err)
}
sort.Strings(parsed)
if !reflect.DeepEqual(parsed, test.Want) {
it.Errorf("expected ParseDomains(dnsmasq) to return\n\t%v, got\n\t%v", test.Want, parsed)
}
if ignored != test.WantIgnored {
it.Errorf("expected %d ignored, got %d", test.WantIgnored, ignored)
}
})
}
}
func TestMOSDNSParser(t *testing.T) {
test := `domain:0019x.com
domain:002777.xyz
domain:003store.com
domain:00404850.xyz`
want := []string{"0019x.com", "002777.xyz", "003store.com", "00404850.xyz"}
parsed, _, err := ParseDomains(strings.NewReader(test))
if err != nil {
t.Fatal(err)
return
}
sort.Strings(parsed)
if !reflect.DeepEqual(parsed, want) {
t.Errorf("expected ParseDomains(domains) to return %v, got %v", want, parsed)
}
}
func TestSmartDNSParser(t *testing.T) {
test := `# Title:AdRules SmartDNS List
# Update: 2025-10-07 02:05:08(GMT+8)
address /0.myikas.com/#
address /0.net.easyjet.com/#
address /0.nextyourcontent.com/#
address /0019x.com/#`
want := []string{"0.myikas.com", "0.net.easyjet.com", "0.nextyourcontent.com", "0019x.com"}
parsed, _, err := ParseDomains(strings.NewReader(test))
if err != nil {
t.Fatal(err)
return
}
sort.Strings(parsed)
if !reflect.DeepEqual(parsed, want) {
t.Errorf("expected ParseDomains(domains) to return %v, got %v", want, parsed)
}
}

40
dataset/parser/domains.go Normal file

@@ -0,0 +1,40 @@
package parser
import (
"bufio"
"io"
"net"
"strings"
)
func init() {
domainsParsers = append(domainsParsers, domainsParser{})
}
type domainsParser struct{}
func (domainsParser) CanHandle(line string) bool {
return isDomainName(line) &&
!strings.ContainsRune(line, ' ') &&
!strings.ContainsRune(line, ':') &&
net.ParseIP(line) == nil
}
func (domainsParser) ParseDomains(r io.Reader) (domains []string, ignored int, err error) {
scanner := bufio.NewScanner(r)
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if isComment(line) {
continue
}
if isDomainName(line) {
domains = append(domains, line)
continue
}
ignored++
}
if err = scanner.Err(); err != nil {
return
}
return unique(domains), ignored, nil
}


@@ -0,0 +1,31 @@
package parser
import (
"reflect"
"sort"
"strings"
"testing"
)
func TestParseDomains(t *testing.T) {
test := `# This is a comment
facebook.com
tiktok.com
bogus ignored
youtube.com`
want := []string{"facebook.com", "tiktok.com", "youtube.com"}
parsed, ignored, err := ParseDomains(strings.NewReader(test))
if err != nil {
t.Fatal(err)
return
}
sort.Strings(parsed)
if !reflect.DeepEqual(parsed, want) {
t.Errorf("expected ParseDomains(domains) to return %v, got %v", want, parsed)
}
if ignored != 1 {
t.Errorf("expected 1 ignored, got %d", ignored)
}
}

41
dataset/parser/hosts.go Normal file

@@ -0,0 +1,41 @@
package parser
import (
"bufio"
"io"
"net"
"strings"
)
func init() {
RegisterDomainsParser(hostsParser{})
}
type hostsParser struct{}
func (hostsParser) CanHandle(line string) bool {
part := strings.Fields(line)
return len(part) >= 2 && net.ParseIP(part[0]) != nil
}
func (hostsParser) ParseDomains(r io.Reader) (domains []string, ignored int, err error) {
scanner := bufio.NewScanner(r)
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if isComment(line) {
continue
}
part := strings.Fields(line)
if len(part) >= 2 && net.ParseIP(part[0]) != nil {
domains = append(domains, part[1:]...)
continue
}
ignored++
}
if err = scanner.Err(); err != nil {
return
}
return unique(domains), ignored, nil
}


@@ -0,0 +1,38 @@
package parser
import (
"reflect"
"sort"
"strings"
"testing"
)
func TestParseHosts(t *testing.T) {
test := `##
# Host Database
#
# localhost is used to configure the loopback interface
# when the system is booting. Do not change this entry.
##
127.0.0.1 localhost dragon dragon.local dragon.maze.network
255.255.255.255 broadcasthost
::1 localhost
ff00::1 multicast
1.2.3.4
`
want := []string{"broadcasthost", "dragon", "dragon.local", "dragon.maze.network", "localhost", "multicast"}
parsed, ignored, err := ParseDomains(strings.NewReader(test))
if err != nil {
t.Fatal(err)
return
}
sort.Strings(parsed)
if !reflect.DeepEqual(parsed, want) {
t.Errorf("expected ParseDomains(hosts) to return %v, got %v", want, parsed)
}
if ignored != 1 {
t.Errorf("expected 1 ignored, got %d", ignored)
}
}

76
dataset/parser/parser.go Normal file

@@ -0,0 +1,76 @@
package parser
import (
"bufio"
"bytes"
"errors"
"io"
"log"
"strings"
"github.com/miekg/dns"
)
var ErrNoParser = errors.New("no suitable parser could be found")
type Parser interface {
CanHandle(line string) bool
}
type DomainsParser interface {
Parser
ParseDomains(io.Reader) (domains []string, ignored int, err error)
}
var domainsParsers []DomainsParser
func RegisterDomainsParser(parser DomainsParser) {
domainsParsers = append(domainsParsers, parser)
}
func ParseDomains(r io.Reader) (domains []string, ignored int, err error) {
var (
buffer = new(bytes.Buffer)
scanner = bufio.NewScanner(io.TeeReader(r, buffer))
line string
parser DomainsParser
)
for scanner.Scan() {
line = strings.TrimSpace(scanner.Text())
if isComment(line) {
continue
}
for _, parser = range domainsParsers {
if parser.CanHandle(line) {
log.Printf("using parser %T", parser)
return parser.ParseDomains(io.MultiReader(buffer, r))
}
}
break
}
return nil, 0, ErrNoParser
}
func isComment(line string) bool {
return line == "" || line[0] == '#' || line[0] == '!'
}
func isDomainName(name string) bool {
n, ok := dns.IsDomainName(name)
return n >= 2 && ok
}
func unique(strings []string) []string {
if strings == nil {
return nil
}
v := make(map[string]struct{})
for _, s := range strings {
v[s] = struct{}{}
}
o := make([]string, 0, len(v))
for k := range v {
o = append(o, k)
}
return o
}
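
New list formats plug in through RegisterDomainsParser. As a sketch in package parser, a hypothetical CSV-style parser (the format itself is an assumption; the interface and helpers are the ones defined above) could look like:

type csvDomainsParser struct{}

func (csvDomainsParser) CanHandle(line string) bool {
	part := strings.Split(line, ",")
	return len(part) >= 2 && isDomainName(part[0])
}

func (csvDomainsParser) ParseDomains(r io.Reader) (domains []string, ignored int, err error) {
	scanner := bufio.NewScanner(r)
	for scanner.Scan() {
		line := strings.TrimSpace(scanner.Text())
		if isComment(line) {
			continue
		}
		// First field is the domain, remaining fields are treated as metadata.
		if part := strings.Split(line, ","); len(part) >= 2 && isDomainName(part[0]) {
			domains = append(domains, part[0])
			continue
		}
		ignored++
	}
	if err = scanner.Err(); err != nil {
		return
	}
	return unique(domains), ignored, nil
}

func init() { RegisterDomainsParser(csvDomainsParser{}) }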


@@ -0,0 +1,31 @@
package parser
import (
"reflect"
"sort"
"testing"
)
func TestUnique(t *testing.T) {
tests := []struct {
Name string
Test []string
Want []string
}{
{"nil", nil, nil},
{"single", []string{"test"}, []string{"test"}},
{"duplicate", []string{"test", "test"}, []string{"test"}},
{"multiple", []string{"a", "a", "b", "b", "b", "c"}, []string{"a", "b", "c"}},
}
for _, test := range tests {
t.Run(test.Name, func(it *testing.T) {
v := unique(test.Test)
if v != nil {
sort.Strings(v)
}
if !reflect.DeepEqual(v, test.Want) {
it.Errorf("expected unique(%v) to return %v, got %v", test.Test, test.Want, v)
}
})
}
}

231
dataset/storage.go Normal file

@@ -0,0 +1,231 @@
package dataset
import (
"bufio"
"bytes"
"fmt"
"io"
"io/fs"
"net"
"net/http"
"net/url"
"os"
"slices"
"strings"
"time"
_ "github.com/mattn/go-sqlite3" // SQLite3 driver
"github.com/miekg/dns"
)
type Storage interface {
Groups() ([]Group, error)
GroupByID(int64) (Group, error)
GroupByName(name string) (Group, error)
SaveGroup(*Group) error
DeleteGroup(Group) error
Clients() (Clients, error)
ClientByID(int64) (Client, error)
ClientByIP(net.IP) (Client, error)
SaveClient(*Client) error
DeleteClient(Client) error
Lists() ([]List, error)
ListByID(int64) (List, error)
SaveList(*List) error
DeleteList(List) error
}
type Group struct {
ID int64 `json:"id"`
Name string `json:"name" bstore:"nonzero,unique"`
IsEnabled bool `json:"is_enabled" bstore:"nonzero"`
Description string `json:"description"`
CreatedAt time.Time `json:"created_at" bstore:"nonzero"`
UpdatedAt time.Time `json:"updated_at" bstore:"nonzero"`
Storage Storage `json:"-" bstore:"-"`
}
type Client struct {
ID int64 `json:"id"`
Network string `json:"network" bstore:"nonzero,index"`
IP string `json:"ip" bstore:"nonzero,unique IP+Mask"`
Mask int `json:"mask"`
Description string `json:"description"`
Groups []Group `json:"groups,omitempty" bstore:"-"`
CreatedAt time.Time `json:"created_at" bstore:"nonzero"`
UpdatedAt time.Time `json:"updated_at" bstore:"nonzero"`
Storage Storage `json:"-" bstore:"-"`
}
type WithClient interface {
Client() (Client, error)
}
type ClientGroup struct {
ID int64 `json:"id"`
ClientID int64 `json:"client_id" bstore:"ref Client,index"`
GroupID int64 `json:"group_id" bstore:"ref Group,index"`
}
func (c *Client) ContainsIP(ip net.IP) bool {
bits := 32
if c.Network == "ipv6" {
bits = 128
}
ipnet := &net.IPNet{
IP: net.ParseIP(c.IP),
Mask: net.CIDRMask(int(c.Mask), bits),
}
if ipnet.IP == nil {
return false
}
return ipnet.Contains(ip)
}
func (c *Client) String() string {
bits := 32
if c.Network == "ipv6" {
bits = 128
}
ipnet := &net.IPNet{
IP: net.ParseIP(c.IP),
Mask: net.CIDRMask(int(c.Mask), bits),
}
return ipnet.String()
}
type Clients []Client
func (cs Clients) ByIP(ip net.IP) *Client {
var candidates []*Client
for _, c := range cs {
if c.ContainsIP(ip) {
candidates = append(candidates, &c)
}
}
switch len(candidates) {
case 0:
return nil
case 1:
return candidates[0]
default:
slices.SortStableFunc(candidates, func(a, b *Client) int {
return int(b.Mask) - int(a.Mask)
})
return candidates[0]
}
}
const (
ListTypeDomain = "domain"
ListTypeNetwork = "network"
)
const (
MinListRefresh = 1 * time.Minute
DefaultListRefresh = 30 * time.Minute
)
type List struct {
ID int64 `json:"id"`
Type string `json:"type"`
Source string `json:"source"`
IsEnabled bool `json:"is_enabled"`
Permit bool `json:"permit"`
Groups []Group `json:"groups,omitempty" bstore:"-"`
Status int `json:"status"`
Comment string `json:"comment"`
Cache []byte `json:"cache"`
Refresh time.Duration `json:"refresh"`
LastModified time.Time `json:"last_modified"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
func (list *List) Domains() (*DomainTree, error) {
if list.Type != ListTypeDomain {
return nil, nil
}
var (
tree = NewDomainList()
scan = bufio.NewScanner(bytes.NewReader(list.Cache))
)
for scan.Scan() {
line := strings.TrimSpace(scan.Text())
if line == "" || line[0] == '#' {
continue
}
if labels, ok := dns.IsDomainName(line); ok && labels >= 2 {
tree.Add(line)
}
}
if err := scan.Err(); err != nil {
return nil, err
}
return tree, nil
}
func (list *List) Update() (updated bool, err error) {
u, err := url.Parse(list.Source)
if err != nil {
return false, err
}
switch u.Scheme {
case "", "file":
return list.updateFile(u.Path)
case "http", "https":
return list.updateHTTP(u.String())
default:
return false, fmt.Errorf("dataset: don't know how to update %s sources", u.Scheme)
}
}
func (list *List) updateFile(name string) (updated bool, err error) {
var info fs.FileInfo
if info, err = os.Stat(name); err != nil {
return
} else if info.IsDir() {
return false, fmt.Errorf("dataset: list %d: %q is a directory", list.ID, name)
}
if updated = info.ModTime().After(list.UpdatedAt); !updated {
return
}
if list.Cache, err = os.ReadFile(name); err != nil {
return false, err
}
return
}
func (list *List) updateHTTP(url string) (updated bool, err error) {
if updated, err = list.shouldUpdateHTTP(url); err != nil || !updated {
return
}
var response *http.Response
if response, err = http.DefaultClient.Get(url); err != nil {
return
}
defer response.Body.Close()
if list.Cache, err = io.ReadAll(response.Body); err != nil {
return
}
return true, nil
}
func (list *List) shouldUpdateHTTP(url string) (updated bool, err error) {
var response *http.Response
if response, err = http.DefaultClient.Head(url); err != nil {
return
}
defer response.Body.Close()
if value := response.Header.Get("Last-Modified"); value != "" {
var lastModified time.Time
if lastModified, err = time.Parse(http.TimeFormat, value); err == nil {
return lastModified.After(list.LastModified), nil
}
}
// There are no headers that would indicate last-modified time, so assume we have to update:
return true, nil
}
type ListGroup struct {
ID int64 `json:"id"`
ListID int64 `json:"list_id" bstore:"ref List,index"`
GroupID int64 `json:"group_id" bstore:"ref Group,index"`
}
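
A short sketch of the intended flow (the source URL is a placeholder): Update fills Cache from the source, and Domains then materializes the lookup tree:

package main

import (
	"log"

	"git.maze.io/maze/styx/dataset"
)

func main() {
	list := dataset.List{
		Type:    dataset.ListTypeDomain,
		Source:  "https://example.com/blocklist.txt", // placeholder URL
		Refresh: dataset.DefaultListRefresh,
	}
	if updated, err := list.Update(); err != nil {
		log.Fatal(err)
	} else if updated {
		tree, err := list.Domains()
		if err != nil {
			log.Fatal(err)
		}
		_ = tree // use the tree for domain lookups
	}
}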

412
dataset/storage_bstore.go Normal file

@@ -0,0 +1,412 @@
package dataset
import (
"context"
"errors"
"fmt"
"net"
"path/filepath"
"slices"
"strings"
"time"
"git.maze.io/maze/styx/logger"
"github.com/mjl-/bstore"
)
type bstoreStorage struct {
db *bstore.DB
path string
}
func OpenBStore(name string) (Storage, error) {
if !filepath.IsAbs(name) {
var err error
if name, err = filepath.Abs(name); err != nil {
return nil, err
}
}
ctx := context.Background()
db, err := bstore.Open(ctx, name, nil,
Group{},
Client{},
ClientGroup{},
List{},
ListGroup{},
)
if err != nil {
return nil, err
}
var (
s = &bstoreStorage{db: db, path: name}
defaultGroup Group
defaultClient4 Client
defaultClient6 Client
)
if defaultGroup, err = s.GroupByName("Default"); errors.Is(err, bstore.ErrAbsent) {
defaultGroup = Group{
Name: "Default",
IsEnabled: true,
Description: "Default group",
}
if err = s.SaveGroup(&defaultGroup); err != nil {
return nil, err
}
} else if err != nil {
return nil, err
}
if defaultClient4, err = bstore.QueryDB[Client](ctx, db).
FilterEqual("Network", "ipv4").
FilterFn(func(client Client) bool {
return net.ParseIP(client.IP).Equal(net.ParseIP("0.0.0.0")) && client.Mask == 0
}).Get(); errors.Is(err, bstore.ErrAbsent) {
defaultClient4 = Client{
Network: "ipv4",
IP: "0.0.0.0",
Mask: 0,
Description: "All IPv4 clients",
}
if err = s.SaveClient(&defaultClient4); err != nil {
return nil, err
}
if err = s.db.Insert(ctx, &ClientGroup{ClientID: defaultClient4.ID, GroupID: defaultGroup.ID}); err != nil {
return nil, err
}
} else if err != nil {
return nil, err
}
if defaultClient6, err = bstore.QueryDB[Client](ctx, db).
FilterEqual("Network", "ipv6").
FilterFn(func(client Client) bool {
return net.ParseIP(client.IP).Equal(net.ParseIP("::")) && client.Mask == 0
}).Get(); errors.Is(err, bstore.ErrAbsent) {
defaultClient6 = Client{
Network: "ipv6",
IP: "::",
Mask: 0,
Description: "All IPv6 clients",
}
if err = s.SaveClient(&defaultClient6); err != nil {
return nil, err
}
if err = s.db.Insert(ctx, &ClientGroup{ClientID: defaultClient6.ID, GroupID: defaultGroup.ID}); err != nil {
return nil, err
}
} else if err != nil {
return nil, err
}
// Start updater
NewUpdater(s)
return s, nil
}
func (s *bstoreStorage) log() logger.Structured {
return logger.StandardLog.Values(logger.Values{
"storage": "bstore",
"storage_path": s.path,
})
}
func (s *bstoreStorage) Groups() ([]Group, error) {
var (
ctx = context.Background()
query = bstore.QueryDB[Group](ctx, s.db)
groups = make([]Group, 0)
)
for group := range query.All() {
groups = append(groups, group)
}
if err := query.Err(); err != nil && !errors.Is(err, bstore.ErrFinished) {
return nil, err
}
return groups, nil
}
func (s *bstoreStorage) GroupByID(id int64) (Group, error) {
ctx := context.Background()
return bstore.QueryDB[Group](ctx, s.db).FilterID(id).Get()
}
func (s *bstoreStorage) GroupByName(name string) (Group, error) {
ctx := context.Background()
return bstore.QueryDB[Group](ctx, s.db).FilterFn(func(group Group) bool {
return strings.EqualFold(group.Name, name)
}).Get()
}
func (s *bstoreStorage) SaveGroup(group *Group) (err error) {
ctx := context.Background()
group.UpdatedAt = time.Now().UTC()
if group.CreatedAt.Equal(time.Time{}) {
group.CreatedAt = group.UpdatedAt
err = s.db.Insert(ctx, group)
} else {
err = s.db.Update(ctx, group)
}
if err != nil {
return fmt.Errorf("dataset: save group %s failed: %w", group.Name, err)
}
return nil
}
func (s *bstoreStorage) DeleteGroup(group Group) (err error) {
ctx := context.Background()
tx, err := s.db.Begin(ctx, true)
if err != nil {
return err
}
if _, err = bstore.QueryTx[ClientGroup](tx).FilterEqual("GroupID", group.ID).Delete(); err != nil {
return
}
if _, err = bstore.QueryTx[ListGroup](tx).FilterEqual("GroupID", group.ID).Delete(); err != nil {
return
}
if err = tx.Delete(group); err != nil {
return
}
return tx.Commit()
}
func (s *bstoreStorage) Clients() (Clients, error) {
var (
ctx = context.Background()
query = bstore.QueryDB[Client](ctx, s.db)
clients = make(Clients, 0)
)
for client := range query.All() {
clients = append(clients, client)
}
if err := query.Err(); err != nil && !errors.Is(err, bstore.ErrFinished) {
return nil, err
}
return clients, nil
}
func (s *bstoreStorage) ClientByID(id int64) (Client, error) {
ctx := context.Background()
client, err := bstore.QueryDB[Client](ctx, s.db).FilterID(id).Get()
if err != nil {
return client, err
}
return s.clientResolveGroups(ctx, client)
}
func (s *bstoreStorage) ClientByIP(ip net.IP) (Client, error) {
if ip == nil {
return Client{}, ErrNotExist{Object: "client"}
}
var (
ctx = context.Background()
clients Clients
network string
)
if ip4 := ip.To4(); ip4 != nil {
network = "ipv4"
} else if ip6 := ip.To16(); ip6 != nil {
network = "ipv6"
}
if network == "" {
return Client{}, ErrNotExist{Object: "client"}
}
for client, err := range bstore.QueryDB[Client](ctx, s.db).
FilterEqual("Network", network).
FilterFn(func(client Client) bool {
return client.ContainsIP(ip)
}).All() {
if err != nil {
return Client{}, err
}
clients = append(clients, client)
}
var client Client
switch len(clients) {
case 0:
return Client{}, ErrNotExist{Object: "client"}
case 1:
client = clients[0]
default:
slices.SortStableFunc(clients, func(a, b Client) int {
return int(b.Mask) - int(a.Mask)
})
client = clients[0]
}
return s.clientResolveGroups(ctx, client)
}
func (s *bstoreStorage) clientResolveGroups(ctx context.Context, client Client) (Client, error) {
for clientGroup, err := range bstore.QueryDB[ClientGroup](ctx, s.db).FilterEqual("ClientID", client.ID).All() {
if err != nil {
return Client{}, err
}
if group, err := s.GroupByID(clientGroup.GroupID); err == nil {
client.Groups = append(client.Groups, group)
}
}
return client, nil
}
func (s *bstoreStorage) SaveClient(client *Client) (err error) {
log := s.log()
ctx := context.Background()
client.UpdatedAt = time.Now().UTC()
tx, err := s.db.Begin(ctx, true)
if err != nil {
return err
}
log = log.Values(logger.Values{"ip": client.IP, "mask": client.Mask, "description": client.Description})
if client.CreatedAt.Equal(time.Time{}) {
log.Debug("Create client")
client.CreatedAt = client.UpdatedAt
if err = tx.Insert(client); err != nil {
return fmt.Errorf("dataset: client insert failed: %w", err)
}
} else {
log.Debug("Update client")
if err = tx.Update(client); err != nil {
return fmt.Errorf("dataset: client update failed: %w", err)
}
}
var deleted int
if deleted, err = bstore.QueryTx[ClientGroup](tx).FilterEqual("ClientID", client.ID).Delete(); err != nil {
return fmt.Errorf("dataset: client groups delete failed: %w", err)
}
log.Debugf("Deleted %d groups", deleted)
log.Debugf("Linking %d groups", len(client.Groups))
for _, group := range client.Groups {
if err = tx.Insert(&ClientGroup{ClientID: client.ID, GroupID: group.ID}); err != nil {
return fmt.Errorf("dataset: client groups insert failed: %w", err)
}
}
return tx.Commit()
}
func (s *bstoreStorage) DeleteClient(client Client) (err error) {
ctx := context.Background()
tx, err := s.db.Begin(ctx, true)
if err != nil {
return err
}
if _, err = bstore.QueryTx[ClientGroup](tx).FilterEqual("ClientID", client.ID).Delete(); err != nil {
return
}
if err = tx.Delete(client); err != nil {
return
}
return tx.Commit()
}
func (s *bstoreStorage) Lists() ([]List, error) {
var (
ctx = context.Background()
query = bstore.QueryDB[List](ctx, s.db)
lists = make([]List, 0)
)
for list := range query.All() {
lists = append(lists, list)
}
if err := query.Err(); err != nil && !errors.Is(err, bstore.ErrFinished) {
return nil, err
}
return lists, nil
}
func (s *bstoreStorage) ListByID(id int64) (List, error) {
ctx := context.Background()
list, err := bstore.QueryDB[List](ctx, s.db).FilterID(id).Get()
if err != nil {
return list, err
}
return s.listResolveGroups(ctx, list)
}
func (s *bstoreStorage) listResolveGroups(ctx context.Context, list List) (List, error) {
for listGroup, err := range bstore.QueryDB[ListGroup](ctx, s.db).FilterEqual("ListID", list.ID).All() {
if err != nil {
return List{}, err
}
if group, err := s.GroupByID(listGroup.GroupID); err == nil {
list.Groups = append(list.Groups, group)
}
}
return list, nil
}
func (s *bstoreStorage) SaveList(list *List) (err error) {
if list.Type != ListTypeDomain && list.Type != ListTypeNetwork {
return fmt.Errorf("storage: unknown list type %q", list.Type)
}
if list.Refresh == 0 {
list.Refresh = DefaultListRefresh
} else if list.Refresh < MinListRefresh {
list.Refresh = MinListRefresh
}
list.UpdatedAt = time.Now().UTC()
ctx := context.Background()
tx, err := s.db.Begin(ctx, true)
if err != nil {
return err
}
log := s.log()
log = log.Values(logger.Values{
"type": list.Type,
"source": list.Source,
"is_enabled": list.IsEnabled,
"status": list.Status,
"cache": len(list.Cache),
"refresh": list.Refresh,
})
if list.CreatedAt.Equal(time.Time{}) {
log.Debug("Creating list")
list.CreatedAt = list.UpdatedAt
if err = tx.Insert(list); err != nil {
return fmt.Errorf("dataset: list insert failed: %w", err)
}
} else {
log.Debug("Updating list")
if err = tx.Update(list); err != nil {
return fmt.Errorf("dataset: list update failed: %w", err)
}
}
var deleted int
if deleted, err = bstore.QueryTx[ListGroup](tx).FilterEqual("ListID", list.ID).Delete(); err != nil {
return fmt.Errorf("dataset: list groups delete failed: %w", err)
}
log.Debugf("Deleted %d groups", deleted)
log.Debugf("Linking %d groups", len(list.Groups))
for _, group := range list.Groups {
if err = tx.Insert(&ListGroup{ListID: list.ID, GroupID: group.ID}); err != nil {
return fmt.Errorf("dataset: list groups insert failed: %w", err)
}
}
return tx.Commit()
}
func (s *bstoreStorage) DeleteList(list List) (err error) {
ctx := context.Background()
tx, err := s.db.Begin(ctx, true)
if err != nil {
return err
}
if _, err = bstore.QueryTx[ListGroup](tx).FilterEqual("ListID", list.ID).Delete(); err != nil {
return
}
if err = tx.Delete(list); err != nil {
return
}
return tx.Commit()
}
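
Tying it together (the database path and address are placeholders): open the bstore-backed storage and resolve the most specific client for a connecting IP:

package main

import (
	"log"
	"net"

	"git.maze.io/maze/styx/dataset"
)

func main() {
	storage, err := dataset.OpenBStore("/var/lib/styx/styx.db") // placeholder path
	if err != nil {
		log.Fatal(err)
	}
	client, err := storage.ClientByIP(net.ParseIP("192.168.1.23"))
	switch {
	case dataset.IsNotExist(err):
		log.Println("no matching client")
	case err != nil:
		log.Fatal(err)
	default:
		log.Printf("matched %s with %d group(s)", client.String(), len(client.Groups))
	}
}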

226
dataset/updater.go Normal file

@@ -0,0 +1,226 @@
package dataset
import (
"bytes"
"io"
"net/http"
"net/url"
"os"
"sync"
"time"
"git.maze.io/maze/styx/logger"
)
type Updater struct {
storage Storage
lists sync.Map // map[int64]List
updaters sync.Map // map[int64]*updaterJob
done chan struct{}
}
func NewUpdater(storage Storage) *Updater {
u := &Updater{
storage: storage,
done: make(chan struct{}, 1),
}
go u.refresh()
return u
}
func (u *Updater) Close() error {
select {
case <-u.done:
return nil
default:
close(u.done)
return nil
}
}
func (u *Updater) refresh() {
check := time.NewTicker(time.Second)
defer check.Stop()
var (
log = logger.StandardLog
)
for {
select {
case <-u.done:
log.Debug("Updater closing, stopping updaters...")
u.updaters.Range(func(key, value any) bool {
if value != nil {
close(value.(*updaterJob).done)
}
return true
})
return
case now := <-check.C:
u.check(now, log)
}
}
}
func (u *Updater) check(now time.Time, log logger.Structured) (wait time.Duration) {
log.Trace("Checking lists")
lists, err := u.storage.Lists()
if err != nil {
log.Err(err).Error("Updater can't retrieve lists")
return -1
}
var missing = make(map[int64]bool)
u.lists.Range(func(key, _ any) bool {
log.Tracef("List %d has updater running", key)
missing[key.(int64)] = true
return true
})
for _, list := range lists {
log.Tracef("List %d is active: %t", list.ID, list.IsEnabled)
if !list.IsEnabled {
continue
}
delete(missing, list.ID)
if _, exists := u.lists.Load(list.ID); !exists {
u.lists.Store(list.ID, list)
updater := newUpdaterJob(u.storage, &list)
u.updaters.Store(list.ID, updater)
}
}
for id := range missing {
log.Tracef("List %d has updater running, but is no longer active, reaping...", id)
if updater, ok := u.updaters.Load(id); ok {
close(updater.(*updaterJob).done)
u.updaters.Delete(id)
}
// Also forget the list itself so a re-enabled list gets a fresh updater.
u.lists.Delete(id)
}
return
}
type updaterJob struct {
storage Storage
list *List
done chan struct{}
}
func newUpdaterJob(storage Storage, list *List) *updaterJob {
job := &updaterJob{
storage: storage,
list: list,
done: make(chan struct{}, 1),
}
go job.loop()
return job
}
func (job *updaterJob) loop() {
var (
ticker = time.NewTicker(job.list.Refresh)
first = time.After(0)
now time.Time
log = logger.StandardLog.Values(logger.Values{
"list": job.list.ID,
"type": job.list.Type,
})
)
defer ticker.Stop()
for {
select {
case <-job.done:
log.Debug("List updater stopping")
return
case now = <-ticker.C:
case now = <-first:
}
log.Debug("List updater running")
if update, err := job.run(now); err != nil {
log.Err(err).Error("List updater failed")
} else if update {
if err = job.storage.SaveList(job.list); err != nil {
log.Err(err).Error("List updater save failed")
}
}
}
}
// run this updater
func (job *updaterJob) run(now time.Time) (update bool, err error) {
u, err := url.Parse(job.list.Source)
if err != nil {
return false, err
}
log := logger.StandardLog.Values(logger.Values{
"list": job.list.ID,
"source": job.list.Source,
})
if u.Scheme == "" || u.Scheme == "file" {
log.Debug("Updating list from file")
return job.updateFile(u.Path)
}
log.Debug("Updating list from URL")
return job.updateHTTP(u)
}
func (job *updaterJob) updateFile(name string) (update bool, err error) {
var b []byte
if b, err = os.ReadFile(name); err != nil {
return
}
if update = !bytes.Equal(b, job.list.Cache); update {
job.list.Cache = b
}
return
}
func (job *updaterJob) updateHTTP(location *url.URL) (update bool, err error) {
if update, err = job.shouldUpdateHTTP(location); err != nil || !update {
return
}
var (
req *http.Request
res *http.Response
)
if req, err = http.NewRequest(http.MethodGet, location.String(), nil); err != nil {
return
}
if res, err = http.DefaultClient.Do(req); err != nil {
return
}
defer res.Body.Close()
if job.list.Cache, err = io.ReadAll(res.Body); err != nil {
return
}
return true, nil
}
func (job *updaterJob) shouldUpdateHTTP(location *url.URL) (update bool, err error) {
if len(job.list.Cache) == 0 {
// Nothing cached, please update.
return true, nil
}
var (
req *http.Request
res *http.Response
)
if req, err = http.NewRequest(http.MethodHead, location.String(), nil); err != nil {
return
}
if res, err = http.DefaultClient.Do(req); err != nil {
return
}
defer res.Body.Close()
if lastModified, err := time.Parse(http.TimeFormat, res.Header.Get("Last-Modified")); err == nil {
return lastModified.After(job.list.UpdatedAt), nil
}
return true, nil // no usable Last-Modified header, so update to be safe
}

2
go.mod

@@ -7,6 +7,7 @@ require (
github.com/hashicorp/hcl/v2 v2.24.0
github.com/mattn/go-sqlite3 v1.14.32
github.com/miekg/dns v1.1.68
github.com/mjl-/bstore v0.0.10
github.com/open-policy-agent/opa v1.9.0
github.com/rs/zerolog v1.34.0
github.com/sirupsen/logrus v1.9.4-0.20230606125235-dd1b4c2e81af
@@ -54,6 +55,7 @@ require (
github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 // indirect
github.com/yashtewari/glob-intersection v0.2.0 // indirect
github.com/zclconf/go-cty v1.16.3 // indirect
go.etcd.io/bbolt v1.4.3 // indirect
go.opentelemetry.io/auto/sdk v1.2.1 // indirect
go.opentelemetry.io/otel v1.38.0 // indirect
go.opentelemetry.io/otel/metric v1.38.0 // indirect

4
go.sum

@@ -98,6 +98,8 @@ github.com/miekg/dns v1.1.68 h1:jsSRkNozw7G/mnmXULynzMNIsgY2dHC8LO6U6Ij2JEA=
github.com/miekg/dns v1.1.68/go.mod h1:fujopn7TB3Pu3JM69XaawiU0wqjpL9/8xGop5UrTPps=
github.com/mitchellh/go-wordwrap v1.0.1 h1:TLuKupo69TCn6TQSyGxwI1EblZZEsQ0vMlAFQflz0v0=
github.com/mitchellh/go-wordwrap v1.0.1/go.mod h1:R62XHJLzvMFRBbcrT7m7WgmE1eOyTSsCt+hzestvNj0=
github.com/mjl-/bstore v0.0.10 h1:fYLQy3EdgXvRHoa8Q3sXMAjZf+uQLRbsh9rYjGep/t4=
github.com/mjl-/bstore v0.0.10/go.mod h1:QzqlAZAVRKwyojCRd9v25viFsMxK5UmIbdxgEyHdK6c=
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA=
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
github.com/open-policy-agent/opa v1.9.0 h1:QWFNwbcc29IRy0xwD3hRrMc/RtSersLY1Z6TaID3vgI=
@@ -152,6 +154,8 @@ github.com/zclconf/go-cty v1.16.3 h1:osr++gw2T61A8KVYHoQiFbFd1Lh3JOCXc/jFLJXKTxk
github.com/zclconf/go-cty v1.16.3/go.mod h1:VvMs5i0vgZdhYawQNq5kePSpLAoz8u1xvZgrPIxfnZE=
github.com/zclconf/go-cty-debug v0.0.0-20240509010212-0d6042c53940 h1:4r45xpDWB6ZMSMNJFMOjqrGHynW3DIBuR2H9j0ug+Mo=
github.com/zclconf/go-cty-debug v0.0.0-20240509010212-0d6042c53940/go.mod h1:CmBdvvj3nqzfzJ6nTCIwDTPZ56aVGvDrmztiO5g3qrM=
go.etcd.io/bbolt v1.4.3 h1:dEadXpI6G79deX5prL3QRNP6JB8UxVkqo4UPnHaNXJo=
go.etcd.io/bbolt v1.4.3/go.mod h1:tKQlpPaYCVFctUIgFKFnAlvbmB3tpy1vkTnDWohtc0E=
go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64=
go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.63.0 h1:RbKq8BG0FI8OiXhBfcRtqqHcZcka+gU3cskNuf05R18=


@@ -8,6 +8,8 @@ import (
"sync/atomic"
"syscall"
"time"
"git.maze.io/maze/styx/logger"
)
// BufferedConn uses byte buffers for Read and Write operations on a [net.Conn].
@@ -123,10 +125,13 @@ type AcceptOnce struct {
}
func (listener *AcceptOnce) Accept() (net.Conn, error) {
log := logger.StandardLog.Value("client", listener.Conn.RemoteAddr().String())
if listener.once.Load() {
log.Trace("Accept already happened, responding EOF")
return nil, io.EOF
}
listener.once.Store(true)
log.Trace("Accept client")
return listener.Conn, nil
}

74
internal/timeutil/time.go Normal file

@@ -0,0 +1,74 @@
package timeutil
import "time"
var (
validTimeLayouts = []string{
"15:04:05.999999999",
"15:04:05",
"15:04",
"3:04:05PM",
"3:04PM",
"3PM",
}
)
type Time struct {
Hour int
Minute int
Second int
Nanosecond int
}
func ParseTime(value string) (Time, error) {
var t time.Time
for _, layout := range validTimeLayouts {
var err error
if t, err = time.Parse(layout, value); err == nil {
return Time{
Hour: t.Hour(),
Minute: t.Minute(),
Second: t.Second(),
Nanosecond: t.Nanosecond(),
}, nil
}
}
return Time{}, &time.ParseError{
Value: value,
Message: "invalid time",
}
}
func Now() Time {
t := time.Now()
return Time{
Hour: t.Hour(),
Minute: t.Minute(),
Second: t.Second(),
Nanosecond: t.Nanosecond(),
}
}
func (t Time) After(other Time) bool {
return other.Before(t)
}
func (t Time) Before(other Time) bool {
if t.Hour == other.Hour {
if t.Minute == other.Minute {
if t.Second == other.Second {
return t.Nanosecond < other.Nanosecond
}
return t.Second < other.Second
}
return t.Minute < other.Minute
}
return t.Hour < other.Hour
}
func (t Time) Eq(other Time) bool {
return t.Hour == other.Hour &&
t.Minute == other.Minute &&
t.Second == other.Second &&
t.Nanosecond == other.Nanosecond
}

View File

@@ -15,17 +15,25 @@ import (
"github.com/open-policy-agent/opa/v1/types" "github.com/open-policy-agent/opa/v1/types"
"git.maze.io/maze/styx/dataset" "git.maze.io/maze/styx/dataset"
"git.maze.io/maze/styx/internal/timeutil"
"git.maze.io/maze/styx/logger" "git.maze.io/maze/styx/logger"
) )
var netLookupIPAddrDecl = types.NewFunction( var lookupIPAddrFunc = &rego.Function{
Name: "styx.lookup_ip_addr",
Decl: lookupIPAddrDecl,
Memoize: true,
Nondeterministic: true,
}
var lookupIPAddrDecl = types.NewFunction(
types.Args( types.Args(
types.Named("name", types.S).Description("Host name to lookup"), types.Named("name", types.S).Description("Host name to lookup"),
), ),
types.Named("result", types.SetOfStr).Description("set(string) of IP address"), types.Named("result", types.SetOfStr).Description("set(string) of IP address"),
) )
func netLookupIPAddrImpl(bc rego.BuiltinContext, nameTerm *ast.Term) (*ast.Term, error) { func lookupIPAddr(bc rego.BuiltinContext, nameTerm *ast.Term) (*ast.Term, error) {
log := logger.StandardLog.Value("func", "styx.lookup_ip_addr") log := logger.StandardLog.Value("func", "styx.lookup_ip_addr")
log.Trace("Call function") log.Trace("Call function")
@@ -61,6 +69,57 @@ func netLookupIPAddrImpl(bc rego.BuiltinContext, nameTerm *ast.Term) (*ast.Term,
return ast.SetTerm(terms...), nil return ast.SetTerm(terms...), nil
} }
var timeBetweenFunc = &rego.Function{
Name: "styx.time_between",
Decl: timeBetweenDecl,
Nondeterministic: false,
}
var timeBetweenDecl = types.NewFunction(
types.Args(
types.Named("start", types.S).Description("Start time"),
types.Named("end", types.S).Description("End time"),
),
types.Named("result", types.B).Description("`true` if the current local time is between `start` and `end`"),
)
func timeBetween(bc rego.BuiltinContext, startTerm, endTerm *ast.Term) (*ast.Term, error) {
log := logger.StandardLog.Value("func", "styx.time_between")
log.Trace("Call function")
start, err := parseTimeTerm(startTerm)
if err != nil {
log.Err(err).Debug("Invalid start time")
return nil, err
}
end, err := parseTimeTerm(endTerm)
if err != nil {
log.Err(err).Debug("Invalid end time")
return nil, err
}
now := timeutil.Now()
if start.Before(end) {
return ast.BooleanTerm((now.Eq(start) || now.After(start)) && now.Before(end)), nil
}
return ast.BooleanTerm(now.Eq(start) || now.After(start) || now.Before(end)), nil
}
func parseTimeTerm(term *ast.Term) (timeutil.Time, error) {
timeArg, ok := term.Value.(ast.String)
if !ok {
return timeutil.Time{}, errors.New("expected string argument")
}
return timeutil.ParseTime(strings.Trim(timeArg.String(), `"`))
}
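
styx.time_between treats a start that is not before its end as a window wrapping past midnight, so time_between("18:00", "06:00") holds late in the evening and in the early morning but not at noon. A hypothetical test sketch of that comparison (the within helper simply mirrors the branch above; the test file itself is not part of the diff):

    package timeutil_test

    import (
        "testing"

        "git.maze.io/maze/styx/internal/timeutil"
    )

    // within mirrors the comparison performed by the styx.time_between builtin.
    func within(now, start, end timeutil.Time) bool {
        if start.Before(end) {
            return (now.Eq(start) || now.After(start)) && now.Before(end)
        }
        return now.Eq(start) || now.After(start) || now.Before(end)
    }

    func TestOvernightWindow(t *testing.T) {
        start := timeutil.Time{Hour: 18}
        end := timeutil.Time{Hour: 6}

        if within(timeutil.Time{Hour: 12}, start, end) {
            t.Error("12:00 should be outside the overnight window 18:00-06:00")
        }
        if !within(timeutil.Time{Hour: 20}, start, end) {
            t.Error("20:00 should be inside the overnight window 18:00-06:00")
        }
    }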
var domainContainsFunc = &rego.Function{
Name: "styx.domains_contain",
Decl: domainContainsDecl,
Memoize: true,
Nondeterministic: true,
}
var domainContainsDecl = types.NewFunction( var domainContainsDecl = types.NewFunction(
types.Args( types.Args(
types.Named("list", types.S).Description("Domain list to check against"), types.Named("list", types.S).Description("Domain list to check against"),
@@ -69,8 +128,8 @@ var domainContainsDecl = types.NewFunction(
types.Named("result", types.B).Description("`true` if `name` is contained within `list`"), types.Named("result", types.B).Description("`true` if `name` is contained within `list`"),
) )
func domainContainsImpl(bc rego.BuiltinContext, listTerm, nameTerm *ast.Term) (*ast.Term, error) { func domainContains(bc rego.BuiltinContext, listTerm, nameTerm *ast.Term) (*ast.Term, error) {
log := logger.StandardLog.Value("func", "styx.in_domains") log := logger.StandardLog.Value("func", "styx.domains_contain")
log.Trace("Call function") log.Trace("Call function")
list, err := parseDomainListTerm(listTerm) list, err := parseDomainListTerm(listTerm)
@@ -91,6 +150,13 @@ func domainContainsImpl(bc rego.BuiltinContext, listTerm, nameTerm *ast.Term) (*
return ast.BooleanTerm(list.Contains(name)), nil return ast.BooleanTerm(list.Contains(name)), nil
} }
var networkContainsFunc = &rego.Function{
Name: "styx.networks_contain",
Decl: networkContainsDecl,
Memoize: true,
Nondeterministic: true,
}
var networkContainsDecl = types.NewFunction( var networkContainsDecl = types.NewFunction(
types.Args( types.Args(
types.Named("list", types.S).Description("Network list to check against"), types.Named("list", types.S).Description("Network list to check against"),
@@ -99,8 +165,8 @@ var networkContainsDecl = types.NewFunction(
types.Named("result", types.B).Description("`true` if `ip` is contained within `list`"), types.Named("result", types.B).Description("`true` if `ip` is contained within `list`"),
) )
func networkContainsImpl(bc rego.BuiltinContext, listTerm, ipTerm *ast.Term) (*ast.Term, error) { func networkContains(bc rego.BuiltinContext, listTerm, ipTerm *ast.Term) (*ast.Term, error) {
log := logger.StandardLog.Value("func", "styx.in_networks") log := logger.StandardLog.Value("func", "styx.networks_contain")
list, err := parseNetworkListTerm(listTerm) list, err := parseNetworkListTerm(listTerm)
if err != nil { if err != nil {

View File

@@ -1,9 +1,12 @@
package policy package policy
import ( import (
"bufio"
"crypto/tls"
"net" "net"
"net/http" "net/http"
"git.maze.io/maze/styx/ca"
"git.maze.io/maze/styx/internal/netutil" "git.maze.io/maze/styx/internal/netutil"
"git.maze.io/maze/styx/logger" "git.maze.io/maze/styx/logger"
proxy "git.maze.io/maze/styx/proxy" proxy "git.maze.io/maze/styx/proxy"
@@ -24,6 +27,7 @@ func NewRequestHandler(p *Policy) proxy.RequestHandler {
log.Err(err).Error("Error generating response") log.Err(err).Error("Error generating response")
return nil, nil return nil, nil
} }
log.Debug("Replacing HTTP response from policy")
return nil, r return nil, r
}) })
} }
@@ -47,21 +51,52 @@ func NewDialHandler(p *Policy) proxy.DialHandler {
return nil, nil return nil, nil
} }
c := netutil.NewLoopback() // Create a fake loopback connection
pipe := netutil.NewLoopback()
go func(c net.Conn) { go func(c net.Conn) {
s := &http.Server{ defer func() { _ = c.Close() }()
Handler: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { if req.URL.Scheme == "https" || req.URL.Scheme == "wss" || netutil.Port(req.URL.Host) == 443 {
r.Write(w) c = maybeUpgradeToTLS(c, ctx, req, log)
}),
} }
_ = s.Serve(&netutil.AcceptOnce{Conn: c})
}(c.Server)
return c.Client, nil br := bufio.NewReader(c)
if _, err := http.ReadRequest(br); err != nil {
log.Err(err).Warn("Malformed HTTP request in MITM connection")
}
_ = r.Write(c)
}(pipe.Server)
return pipe.Client, nil
}) })
} }
func maybeUpgradeToTLS(c net.Conn, ctx proxy.Context, req *http.Request, log logger.Structured) net.Conn {
var authority ca.CertificateAuthority
if caCtx, ok := ctx.(proxy.WithCertificateAuthority); ok {
authority = caCtx.CertificateAuthority()
}
if authority == nil {
return c
}
secure := tls.Server(c, &tls.Config{
GetCertificate: func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
log.Values(logger.Values{
"cn": req.URL.Host,
"names": hello.ServerName,
}).Debug("Requesting certificate from CA")
return authority.GetCertificate(netutil.Host(req.URL.Host), []string{hello.ServerName}, nil)
},
NextProtos: []string{"http/1.1"},
})
if err := secure.Handshake(); err != nil {
log.Err(err).Warn("Failed to pretend secure HTTP")
return c
}
return secure
}
func NewForwardHandler(p *Policy) proxy.ForwardHandler { func NewForwardHandler(p *Policy) proxy.ForwardHandler {
log := logger.StandardLog.Value("policy", p.name) log := logger.StandardLog.Value("policy", p.name)
return proxy.ForwardHandlerFunc(func(ctx proxy.Context, req *http.Request) (*http.Response, error) { return proxy.ForwardHandlerFunc(func(ctx proxy.Context, req *http.Request) (*http.Response, error) {
@@ -72,7 +107,15 @@ func NewForwardHandler(p *Policy) proxy.ForwardHandler {
log.Err(err).Error("Error evaulating policy") log.Err(err).Error("Error evaulating policy")
return nil, nil return nil, nil
} }
return result.Response(ctx) r, err := result.Response(ctx)
if err != nil {
log.Err(err).Error("Error generating response")
return nil, err
}
if r != nil {
log.Debug("Replacing HTTP response from policy")
}
return r, nil
}) })
} }
@@ -80,6 +123,7 @@ func NewResponseHandler(p *Policy) proxy.ResponseHandler {
log := logger.StandardLog.Value("policy", p.name) log := logger.StandardLog.Value("policy", p.name)
return proxy.ResponseHandlerFunc(func(ctx proxy.Context) *http.Response { return proxy.ResponseHandlerFunc(func(ctx proxy.Context) *http.Response {
input := NewInputFromResponse(ctx, ctx.Response()) input := NewInputFromResponse(ctx, ctx.Response())
input.logValues(log).Trace("Running response handler")
result, err := p.Query(input) result, err := p.Query(input)
if err != nil { if err != nil {
log.Err(err).Error("Error evaulating policy") log.Err(err).Error("Error evaulating policy")
@@ -90,6 +134,9 @@ func NewResponseHandler(p *Policy) proxy.ResponseHandler {
log.Err(err).Error("Error generating response") log.Err(err).Error("Error generating response")
return nil return nil
} }
if r != nil {
log.Debug("Replacing HTTP response from policy")
}
return r return r
}) })
} }

View File

@@ -10,19 +10,26 @@ import (
"net/url" "net/url"
"strconv" "strconv"
"git.maze.io/maze/styx/dataset"
"git.maze.io/maze/styx/internal/netutil" "git.maze.io/maze/styx/internal/netutil"
"git.maze.io/maze/styx/logger" "git.maze.io/maze/styx/logger"
proxy "git.maze.io/maze/styx/proxy"
) )
// Input represents the input to the policy query. // Input represents the input to the policy query.
type Input struct { type Input struct {
Context map[string]any `json:"context"`
Client *Client `json:"client"` Client *Client `json:"client"`
Groups []*Group `json:"groups"`
TLS *TLS `json:"tls"` TLS *TLS `json:"tls"`
Request *Request `json:"request"` Request *Request `json:"request"`
Response *Response `json:"response"` Response *Response `json:"response"`
} }
func (i *Input) logValues(log logger.Structured) logger.Structured { func (i *Input) logValues(log logger.Structured) logger.Structured {
if i.Context != nil {
log = log.Values(i.Context)
}
log = i.Client.logValues(log) log = i.Client.logValues(log)
log = i.TLS.logValues(log) log = i.TLS.logValues(log)
log = i.Request.logValues(log) log = i.Request.logValues(log)
@@ -34,10 +41,29 @@ func NewInputFromConn(c net.Conn) *Input {
if c == nil { if c == nil {
return new(Input) return new(Input)
} }
return &Input{
input := &Input{
Context: make(map[string]any),
Client: NewClientFromConn(c), Client: NewClientFromConn(c),
TLS: NewTLSFromConn(c), TLS: NewTLSFromConn(c),
} }
if wcl, ok := c.(dataset.WithClient); ok {
client, err := wcl.Client()
if err == nil {
input.Context["client_id"] = client.ID
input.Context["client_description"] = client.Description
input.Context["groups"] = client.Groups
}
}
if ctx, ok := c.(proxy.Context); ok {
input.Context["local"] = NewClientFromAddr(ctx.LocalAddr())
input.Context["bytes_rx"] = ctx.BytesRead()
input.Context["bytes_tx"] = ctx.BytesSent()
}
return input
} }
func NewInputFromRequest(c net.Conn, r *http.Request) *Input { func NewInputFromRequest(c net.Conn, r *http.Request) *Input {
@@ -131,6 +157,10 @@ func NewClientFromAddr(addr net.Addr) *Client {
} }
} }
type Group struct {
Name string `json:"name"`
}
type TLS struct { type TLS struct {
Version string `json:"version"` Version string `json:"version"`
CipherSuite string `json:"cipher_suite"` CipherSuite string `json:"cipher_suite"`

View File

@@ -67,24 +67,10 @@ func newRego(option func(*rego.Rego), pkg string) []func(*rego.Rego) {
rego.Query("data." + pkg), rego.Query("data." + pkg),
rego.Strict(true), rego.Strict(true),
rego.Capabilities(capabilities), rego.Capabilities(capabilities),
rego.Function2(&rego.Function{ rego.Function2(domainContainsFunc, domainContains),
Name: "styx.in_domains", rego.Function2(networkContainsFunc, networkContains),
Decl: domainContainsDecl, rego.Function1(lookupIPAddrFunc, lookupIPAddr),
Memoize: true, rego.Function2(timeBetweenFunc, timeBetween),
Nondeterministic: true,
}, domainContainsImpl),
rego.Function2(&rego.Function{
Name: "styx.in_networks",
Decl: networkContainsDecl,
Memoize: true,
Nondeterministic: true,
}, networkContainsImpl),
rego.Function1(&rego.Function{
Name: "styx.lookup_ip_addr", // override builtin
Decl: netLookupIPAddrDecl,
Memoize: true,
Nondeterministic: true,
}, netLookupIPAddrImpl),
rego.PrintHook(printHook{}), rego.PrintHook(printHook{}),
option, option,
} }
@@ -128,16 +114,20 @@ func (r *Result) Response(ctx proxy.Context) (*http.Response, error) {
switch { switch {
case r.Redirect != "": case r.Redirect != "":
log.Value("location", r.Redirect).Trace("Creating a HTTP redirect response")
response := proxy.NewResponse(http.StatusFound, nil, ctx.Request()) response := proxy.NewResponse(http.StatusFound, nil, ctx.Request())
response.Header.Set("Server", "styx") response.Header.Set("Server", "styx")
response.Header.Set(proxy.HeaderLocation, r.Redirect) response.Header.Set(proxy.HeaderLocation, r.Redirect)
return response, nil return response, nil
case r.Template != "": case r.Template != "":
log = log.Value("template", r.Template)
log.Trace("Creating a HTTP template response")
b := new(bytes.Buffer) b := new(bytes.Buffer)
t, err := template.New(filepath.Base(r.Template)).ParseFiles(r.Template) t, err := template.New(filepath.Base(r.Template)).ParseFiles(r.Template)
if err != nil { if err != nil {
log.Value("template", r.Template).Err(err).Warn("Error loading template in response") log.Err(err).Warn("Error loading template in response")
return nil, err return nil, err
} }
t = t.Funcs(template.FuncMap{ t = t.Funcs(template.FuncMap{
@@ -149,7 +139,7 @@ func (r *Result) Response(ctx proxy.Context) (*http.Response, error) {
"Response": ctx.Response(), "Response": ctx.Response(),
"Errors": r.Errors, "Errors": r.Errors,
}); err != nil { }); err != nil {
log.Value("template", r.Template).Err(err).Warn("Error rendering template response") log.Err(err).Warn("Error rendering template response")
return nil, err return nil, err
} }
@@ -159,46 +149,34 @@ func (r *Result) Response(ctx proxy.Context) (*http.Response, error) {
return response, nil return response, nil
case r.Reject > 0: case r.Reject > 0:
log.Value("code", r.Reject).Trace("Creating a HTTP reject response")
body := io.NopCloser(bytes.NewBufferString(http.StatusText(r.Reject))) body := io.NopCloser(bytes.NewBufferString(http.StatusText(r.Reject)))
response := proxy.NewResponse(r.Reject, body, ctx.Request()) response := proxy.NewResponse(r.Reject, body, ctx.Request())
response.Header.Set(proxy.HeaderContentType, "text/plain") response.Header.Set(proxy.HeaderContentType, "text/plain")
return response, nil return response, nil
case r.Permit != nil && !*r.Permit: case r.Permit != nil && !*r.Permit:
log.Trace("Creating a HTTP reject response due to explicit not permit")
body := io.NopCloser(bytes.NewBufferString(http.StatusText(http.StatusForbidden))) body := io.NopCloser(bytes.NewBufferString(http.StatusText(http.StatusForbidden)))
response := proxy.NewResponse(http.StatusForbidden, body, ctx.Request()) response := proxy.NewResponse(http.StatusForbidden, body, ctx.Request())
response.Header.Set(proxy.HeaderContentType, "text/plain") response.Header.Set(proxy.HeaderContentType, "text/plain")
return response, nil return response, nil
default: default:
log.Trace("Not creating a HTTP response")
return nil, nil return nil, nil
} }
} }
func (p *Policy) Query(input *Input) (*Result, error) { func (p *Policy) Query(input *Input) (*Result, error) {
/*
e := json.NewEncoder(os.Stdout)
e.SetIndent("", " ")
e.Encode(doc)
*/
log := logger.StandardLog.Value("policy", p.name) log := logger.StandardLog.Value("policy", p.name)
log.Trace("Evaluating policy") log.Trace("Evaluating policy")
r := rego.New(append(p.options, rego.Input(input))...) var (
rego = rego.New(append(p.options, rego.Input(input))...)
ctx := context.Background() ctx = context.Background()
/* rs, err = rego.Eval(ctx)
query, err := p.rego.PrepareForEval(ctx) )
if err != nil {
return nil, err
}
rs, err := query.Eval(ctx, rego.EvalInput(input))
if err != nil {
return nil, err
}
*/
rs, err := r.Eval(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -208,6 +186,12 @@ func (p *Policy) Query(input *Input) (*Result, error) {
result := &Result{} result := &Result{}
for _, expr := range rs[0].Expressions { for _, expr := range rs[0].Expressions {
if m, ok := expr.Value.(map[string]any); ok { if m, ok := expr.Value.(map[string]any); ok {
// Remove private variables.
for k := range m {
if len(k) > 0 && k[0] == '_' {
delete(m, k)
}
}
log.Values(m).Trace("Policy result expression") log.Values(m).Trace("Policy result expression")
if err = mapstructure.Decode(m, result); err != nil { if err = mapstructure.Decode(m, result); err != nil {
return nil, err return nil, err

View File

@@ -14,6 +14,8 @@ import (
"sync/atomic" "sync/atomic"
"time" "time"
"git.maze.io/maze/styx/ca"
"git.maze.io/maze/styx/dataset"
"git.maze.io/maze/styx/logger" "git.maze.io/maze/styx/logger"
) )
@@ -42,6 +44,13 @@ type Context interface {
// Response is the response that will be sent back to the client. // Response is the response that will be sent back to the client.
Response() *http.Response Response() *http.Response
// Client returns the client record associated with this connection, if known.
Client() (dataset.Client, error)
}
type WithCertificateAuthority interface {
CertificateAuthority() ca.CertificateAuthority
} }
type countingReader struct { type countingReader struct {
@@ -80,6 +89,9 @@ type proxyContext struct {
req *http.Request req *http.Request
res *http.Response res *http.Response
idleTimeout time.Duration idleTimeout time.Duration
ca ca.CertificateAuthority
storage dataset.Storage
client dataset.Client
} }
// NewContext returns an initialized context for the provided [net.Conn]. // NewContext returns an initialized context for the provided [net.Conn].
@@ -218,4 +230,28 @@ func (c *proxyContext) WriteHeader(code int) {
//return c.res.Header.Write(c) //return c.res.Header.Write(c)
} }
func (c *proxyContext) CertificateAuthority() ca.CertificateAuthority {
return c.ca
}
func (c *proxyContext) Client() (dataset.Client, error) {
if c.storage == nil {
return dataset.Client{}, dataset.ErrNotExist{Object: "client"}
}
if !c.client.CreatedAt.IsZero() {
return c.client, nil
}
var err error
switch addr := c.Conn.RemoteAddr().(type) {
case *net.TCPAddr:
c.client, err = c.storage.ClientByIP(addr.IP)
case *net.UDPAddr:
c.client, err = c.storage.ClientByIP(addr.IP)
default:
err = dataset.ErrNotExist{Object: "client"}
}
return c.client, err
}
var _ Context = (*proxyContext)(nil) var _ Context = (*proxyContext)(nil)

View File

@@ -15,9 +15,12 @@ import (
"slices" "slices"
"strconv" "strconv"
"strings" "strings"
"sync"
"syscall" "syscall"
"time" "time"
"git.maze.io/maze/styx/ca"
"git.maze.io/maze/styx/dataset"
"git.maze.io/maze/styx/internal/netutil" "git.maze.io/maze/styx/internal/netutil"
"git.maze.io/maze/styx/logger" "git.maze.io/maze/styx/logger"
"git.maze.io/maze/styx/stats" "git.maze.io/maze/styx/stats"
@@ -26,6 +29,7 @@ import (
// Common HTTP headers. // Common HTTP headers.
const ( const (
HeaderConnection = "Connection" HeaderConnection = "Connection"
HeaderContentLength = "Content-Length"
HeaderContentType = "Content-Type" HeaderContentType = "Content-Type"
HeaderDate = "Date" HeaderDate = "Date"
HeaderForwarded = "Forwarded" HeaderForwarded = "Forwarded"
@@ -146,7 +150,17 @@ type Proxy struct {
// WebSocketIdleTimeout is the timeout for idle WebSocket connections. // WebSocketIdleTimeout is the timeout for idle WebSocket connections.
WebSocketIdleTimeout time.Duration WebSocketIdleTimeout time.Duration
// CertificateAuthority can issue certificates for man-in-the-middle connections.
CertificateAuthority ca.CertificateAuthority
// Storage for resolving clients/groups
Storage dataset.Storage
mux *http.ServeMux mux *http.ServeMux
closed chan struct{}
closeOnce sync.Once
mu sync.RWMutex
listeners []net.Listener
} }
// New [Proxy] with somewhat sane defaults. // New [Proxy] with somewhat sane defaults.
@@ -157,6 +171,7 @@ func New() *Proxy {
IdleTimeout: DefaultIdleTimeout, IdleTimeout: DefaultIdleTimeout,
WebSocketIdleTimeout: DefaultWebSocketIdleTimeout, WebSocketIdleTimeout: DefaultWebSocketIdleTimeout,
mux: http.NewServeMux(), mux: http.NewServeMux(),
closed: make(chan struct{}, 1),
} }
// Make sure the roundtripper uses our dialers. // Make sure the roundtripper uses our dialers.
@@ -181,6 +196,55 @@ func New() *Proxy {
return p return p
} }
func (p *Proxy) Close() error {
var closeListeners bool
p.closeOnce.Do(func() {
close(p.closed)
closeListeners = true
})
if closeListeners {
p.mu.RLock()
for _, l := range p.listeners {
_ = l.Close()
}
p.mu.RUnlock()
}
return nil
}
func (p *Proxy) isClosed() bool {
select {
case <-p.closed:
return true
default:
return false
}
}
func (p *Proxy) addListener(l net.Listener) {
if l == nil {
return
}
p.mu.Lock()
p.listeners = append(p.listeners, l)
p.mu.Unlock()
}
func (p *Proxy) removeListener(l net.Listener) {
if l == nil {
return
}
p.mu.Lock()
listeners := make([]net.Listener, 0, len(p.listeners))
for _, o := range p.listeners {
if o != l {
listeners = append(listeners, o)
}
}
p.listeners = listeners
p.mu.Unlock()
}
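
A hedged sketch of the listener lifecycle these helpers support: Serve registers its listener, and Close flips the closed flag and closes every registered listener, which unblocks Accept. The listen address and signal handling below are illustrative, not part of the diff:

    package main

    import (
        "log"
        "net"
        "os"
        "os/signal"
        "syscall"

        proxy "git.maze.io/maze/styx/proxy"
    )

    func main() {
        p := proxy.New()

        l, err := net.Listen("tcp", "127.0.0.1:3128") // illustrative address
        if err != nil {
            log.Fatal(err)
        }

        go func() {
            // Serve tracks l via addListener; after Close it returns either nil
            // or the Accept error caused by the now-closed listener.
            log.Println("proxy stopped:", p.Serve(l))
        }()

        sig := make(chan os.Signal, 1)
        signal.Notify(sig, os.Interrupt, syscall.SIGTERM)
        <-sig
        _ = p.Close()
    }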
// Handle installs a [http.Handler] into the internal mux. // Handle installs a [http.Handler] into the internal mux.
func (p *Proxy) Handle(pattern string, handler http.Handler) { func (p *Proxy) Handle(pattern string, handler http.Handler) {
p.mux.Handle(pattern, handler) p.mux.Handle(pattern, handler)
@@ -214,11 +278,23 @@ func (p *Proxy) dial(ctx context.Context, req *http.Request) (net.Conn, error) {
// Serve proxied connections on the specified listener. // Serve proxied connections on the specified listener.
func (p *Proxy) Serve(l net.Listener) error { func (p *Proxy) Serve(l net.Listener) error {
p.addListener(l)
defer p.removeListener(l)
for { for {
if p.isClosed() {
return nil
}
c, err := l.Accept() c, err := l.Accept()
if err != nil { if err != nil {
return err return err
} }
if p.isClosed() {
_ = c.Close()
return nil
}
go p.handle(c) go p.handle(c)
} }
} }
@@ -229,6 +305,7 @@ func (p *Proxy) handle(nc net.Conn) {
ctx = NewContext(nc).(*proxyContext) ctx = NewContext(nc).(*proxyContext)
err error err error
) )
defer func() { defer func() {
if r := recover(); r != nil { if r := recover(); r != nil {
if err, ok := r.(error); ok { if err, ok := r.(error); ok {
@@ -266,6 +343,8 @@ func (p *Proxy) handle(nc net.Conn) {
// Propagate timeouts // Propagate timeouts
ctx.SetIdleTimeout(p.IdleTimeout) ctx.SetIdleTimeout(p.IdleTimeout)
ctx.ca = p.CertificateAuthority
ctx.storage = p.Storage
for _, f := range p.OnConnect { for _, f := range p.OnConnect {
fc, err := f.HandleConn(ctx) fc, err := f.HandleConn(ctx)
@@ -282,6 +361,15 @@ func (p *Proxy) handle(nc net.Conn) {
} }
log := ctx.LogEntry() log := ctx.LogEntry()
if p.Storage != nil {
// Guard the type assertion so non-TCP listeners cannot panic here.
if addr, ok := nc.RemoteAddr().(*net.TCPAddr); ok {
if client, err := p.Storage.ClientByIP(addr.IP); err == nil {
log = log.Values(logger.Values{
"client_id": client.ID,
"client_network": client.String(),
"client_description": client.Description,
})
}
}
}
for { for {
if ctx.transparentTLS { if ctx.transparentTLS {
ctx.req = &http.Request{ ctx.req = &http.Request{
@@ -344,7 +432,7 @@ func (p *Proxy) handle(nc net.Conn) {
} }
if err = p.handleRequest(ctx); err != nil { if err = p.handleRequest(ctx); err != nil {
p.handleError(ctx, err, true) p.handleError(ctx, err, !netutil.IsClosing(err))
return return
} }
@@ -511,7 +599,8 @@ func (p *Proxy) serveForward(ctx *proxyContext) (err error) {
_ = ctx.Close() _ = ctx.Close()
return fmt.Errorf("proxy: forward %s error: %w", ctx.req.URL, err) return fmt.Errorf("proxy: forward %s error: %w", ctx.req.URL, err)
} }
} else { }
if res != nil {
ctx.res = res ctx.res = res
} }
@@ -571,28 +660,44 @@ func (p *Proxy) serveWebSocket(ctx *proxyContext) (err error) {
return p.multiplex(ctx, srv) return p.multiplex(ctx, srv)
} }
func (p *Proxy) multiplex(ctx, srv Context) (err error) { func (p *Proxy) multiplex(ctx, srv *proxyContext) (err error) {
var ( var (
log = ctx.LogEntry().Value("server", srv.RemoteAddr().String())
errs = make(chan error, 1) errs = make(chan error, 1)
done = make(chan struct{}, 1) done = make(chan struct{}, 1)
) )
go func(errs chan<- error) { go func(errs chan<- error) {
defer close(done) if _, err := io.Copy(ctx, srv); err != nil && !netutil.IsClosing(err) {
if _, err := io.Copy(srv, ctx); err != nil { log.Err(err).Trace("Multiplexing closed in server->client")
errs <- err errs <- err
} else {
log.Trace("Multiplexing closed in client->server")
} }
}(errs) }(errs)
go func(errs chan<- error) { go func(errs chan<- error) {
if _, err := io.Copy(ctx, srv); err != nil { defer close(done)
if _, err := io.Copy(srv, ctx); err != nil && !netutil.IsClosing(err) {
log.Err(err).Trace("Multiplexing closed in server->client")
errs <- err errs <- err
} else {
log.Trace("Multiplexing closed in server->client")
} }
}(errs) }(errs)
defer func() {
log.Trace("Multiplexing done, force-closing client and server connections")
_ = ctx.Close()
_ = srv.Close()
}()
select { select {
case err = <-errs: case err = <-errs:
return return
case <-done: case <-done:
return return io.EOF // multiplexing never recycles connection
case <-p.closed:
return io.EOF // server closed
} }
} }

213
stats/handler.go Normal file
View File

@@ -0,0 +1,213 @@
package stats
import (
"encoding/json"
"expvar"
"fmt"
"net/http"
"sort"
"strings"
"html/template"
)
var (
page = template.Must(template.New("").
Funcs(template.FuncMap{"path": path, "duration": duration}).
Parse(`<!DOCTYPE html>
<html lang="us">
<meta charset="utf-8">
<title>Metrics report</title>
<meta name="viewport" content="width=device-width">
<style>
* { margin: 0; padding: 0; box-sizing: border-box; font-family: monospace; font-size: 12px; }
.container {
max-width: 640px;
margin: 1em auto;
display: flex;
flex-direction: column;
padding: 0 1em;
}
h1 { text-align: center; }
h2 {
font-weight: normal;
text-overflow: ellipsis;
white-space: nowrap;
overflow: hidden;
}
.metric {
padding: 1em 0;
border-top: 1px solid rgba(0,0,0,0.33);
}
.row {
display: flex;
flex-direction: row;
align-items: center;
margin: 0.25em 0;
}
.col-1 { flex: 1; }
.col-2 { flex: 2.5; }
.table { width: 100px; border-radius: 2px; border: 1px solid rgba(0,0,0,0.33); }
.table td, .table th { text-align: center; }
.timeline { padding: 0 0.5em; }
path { fill: none; stroke: rgba(0,0,0,0.33); stroke-width: 1; stroke-linecap: round; stroke-linejoin: round; }
path:last-child { stroke: black; }
</style>
<body>
<div class="container">
<div><h1><pre> __ __
.--------..-----.| |_ .----.|__|.----..-----.
| || -__|| _|| _|| || __||__ --|
|__|__|__||_____||____||__| |__||____||_____|
</pre></h1></div>
{{ range . }}
<div class="row metric">
<h2 class="col-1">{{ .name }}</h2>
<div class="col-2">
{{ if .type }}
<div class="row">
{{ template "table" . }}
<div class="col-1"></div>
</div>
{{ else if .interval }}
<div class="row">{{ template "timeseries" . }}</div>
{{ else if .metrics}}
{{ range .metrics }}
<div class="row">
{{ template "timeseries" . }}
</div>
{{ end }}
{{ end }}
</div>
</div>
{{ end }}
</div>
</body>
</html>
{{ define "table" }}
<table class="table col-1">
{{ if eq .type "c" }}
<thead><tr><th>count</th></tr></thead><tbody><tr><td>{{ printf "%.2g" .count }}</td></tr></tbody>
{{ else if eq .type "g" }}
<thead><tr><th>mean</th><th>min</th><th>max</th></tr></thead>
<tbody><tr><td>{{printf "%.2g" .mean}}</td><td>{{printf "%.2g" .min}}</td><td>{{printf "%.2g" .max}}</td></th></tbody>
{{ else if eq .type "h" }}
<thead><tr><th>P.50</th><th>P.90</th><th>P.99</th></tr></thead>
<tbody><tr><td>{{printf "%.2g" .p50}}</td><td>{{printf "%.2g" .p90}}</td><td>{{printf "%.2g" .p99}}</td></tr></tbody>
{{ end }}
</table>
{{ end }}
{{ define "timeseries" }}
{{ template "table" .total }}
<div class="col-1">
<div class="row">
<div class="timeline">{{ duration .samples .interval }}</div>
<svg class="col-1" version="1.1" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 100 20">
{{ if eq (index (index .samples 0) "type") "c" }}
{{ range (path .samples "count") }}<path d={{ . }} />{{end}}
{{ else if eq (index (index .samples 0) "type") "g" }}
{{ range (path .samples "min" "max" "mean" ) }}<path d={{ . }} />{{end}}
{{ else if eq (index (index .samples 0) "type") "h" }}
{{ range (path .samples "p50" "p90" "p99") }}<path d={{ . }} />{{end}}
{{ end }}
</svg>
</div>
</div>
{{ end }}
`))
)
func path(samples []any, keys ...string) []string {
var min, max float64
paths := make([]string, len(keys))
for i := range len(samples) {
s := samples[i].(map[string]any)
for _, k := range keys {
x := s[k].(float64)
if i == 0 || x < min {
min = x
}
if i == 0 || x > max {
max = x
}
}
}
for i := range len(samples) {
s := samples[i].(map[string]any)
for j, k := range keys {
v := s[k].(float64)
x := float64(i+1) / float64(len(samples))
y := (v - min) / (max - min)
if max == min {
y = 0
}
if i == 0 {
paths[j] = fmt.Sprintf("M%f %f", 0.0, (1-y)*18+1)
}
paths[j] += fmt.Sprintf(" L%f %f", x*100, (1-y)*18+1)
}
}
return paths
}
func duration(samples []any, n float64) string {
n = n * float64(len(samples))
if n < 60 {
return fmt.Sprintf("%d sec", int(n))
} else if n < 60*60 {
return fmt.Sprintf("%d min", int(n/60))
} else if n < 24*60*60 {
return fmt.Sprintf("%d hrs", int(n/60/60))
}
return fmt.Sprintf("%d days", int(n/24/60/60))
}
// Handler returns an http.Handler that renders web UI for all provided metrics.
func Handler(snapshot func() map[string]Metric) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
type h map[string]any
metrics := []h{}
for name, metric := range snapshot() {
m := h{}
b, _ := json.Marshal(metric)
json.Unmarshal(b, &m)
m["name"] = name
metrics = append(metrics, m)
}
sort.Slice(metrics, func(i, j int) bool {
n1 := metrics[i]["name"].(string)
n2 := metrics[j]["name"].(string)
return strings.Compare(n1, n2) < 0
})
page.Execute(w, metrics)
})
}
// JSONHandler returns a [http.Handler] that renders the metrics as JSON.
func JSONHandler(snapshot func() map[string]Metric) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
type h map[string]any
metrics := map[string]h{}
for name, metric := range snapshot() {
m := h{}
b, _ := json.Marshal(metric)
json.Unmarshal(b, &m)
metrics[name] = m
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(metrics)
})
}
// Exposed returns a map of exposed metrics (see expvar package).
func Exposed() map[string]Metric {
m := map[string]Metric{}
expvar.Do(func(kv expvar.KeyValue) {
if metric, ok := kv.Value.(Metric); ok {
m[kv.Key] = metric
}
})
return m
}

104
stats/stats.go Normal file
View File

@@ -0,0 +1,104 @@
package stats
import (
"encoding/json"
"math"
"sort"
"strconv"
"sync/atomic"
)
// Metric is a single meter (counter, gauge or histogram, optionally - with history)
type Metric interface {
Add(n float64)
String() string
}
// metric is an extended private interface with some additional internal
// methods used by timeseries. Counters, gauges and histograms implement it.
type metric interface {
Metric
Reset()
Aggregate(roll int, samples []metric)
}
type multimetric []*timeseries
func (mm multimetric) Add(n float64) {
for _, m := range mm {
m.Add(n)
}
}
func (mm multimetric) MarshalJSON() ([]byte, error) {
b := []byte(`{"metrics":[`)
for i, m := range mm {
if i != 0 {
b = append(b, ',')
}
x, _ := json.Marshal(m)
b = append(b, x...)
}
b = append(b, ']', '}')
return b, nil
}
func (mm multimetric) String() string {
return mm[len(mm)-1].String()
}
func newMetric(builder func() metric, frames ...string) Metric {
if len(frames) == 0 {
return builder()
}
if len(frames) == 1 {
return newTimeseries(builder, frames[0])
}
mm := multimetric{}
for _, frame := range frames {
mm = append(mm, newTimeseries(builder, frame))
}
sort.Slice(mm, func(i, j int) bool {
a, b := mm[i], mm[j]
return a.interval.Seconds()*float64(len(a.samples)) < b.interval.Seconds()*float64(len(b.samples))
})
return mm
}
// NewCounter returns a counter metric that increments the value with each
// incoming number.
func NewCounter(frames ...string) Metric {
return newMetric(func() metric { return &counter{} }, frames...)
}
type counter struct {
count uint64
}
func (c *counter) String() string { return strconv.FormatFloat(c.value(), 'g', -1, 64) }
func (c *counter) Reset() { atomic.StoreUint64(&c.count, math.Float64bits(0)) }
func (c *counter) value() float64 { return math.Float64frombits(atomic.LoadUint64(&c.count)) }
func (c *counter) Add(n float64) {
for {
old := math.Float64frombits(atomic.LoadUint64(&c.count))
new := old + n
if atomic.CompareAndSwapUint64(&c.count, math.Float64bits(old), math.Float64bits(new)) {
return
}
}
}
func (c *counter) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string `json:"type"`
Count float64 `json:"count"`
}{"c", c.value()})
}
func (c *counter) Aggregate(roll int, samples []metric) {
c.Reset()
for _, s := range samples {
c.Add(s.(*counter).value())
}
}
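
A hedged sketch of how these metrics plausibly tie together with the handlers above: Metric implements expvar.Var through String, so a counter can be published with expvar and then picked up by Exposed for the HTML and JSON handlers. The metric name, frames and routes here are illustrative:

    package main

    import (
        "expvar"
        "net/http"

        "git.maze.io/maze/styx/stats"
    )

    func main() {
        // Two history frames: 15 minutes at 10-second resolution and one day
        // at 1-hour resolution. String() makes the metric a valid expvar.Var.
        requests := stats.NewCounter("15m10s", "24h1h")
        expvar.Publish("proxy.requests", requests)

        http.HandleFunc("/hit", func(w http.ResponseWriter, r *http.Request) {
            requests.Add(1)
            w.WriteHeader(http.StatusNoContent)
        })

        // Render all published metrics as HTML and as JSON.
        http.Handle("/debug/metrics", stats.Handler(stats.Exposed))
        http.Handle("/debug/metrics.json", stats.JSONHandler(stats.Exposed))

        _ = http.ListenAndServe("127.0.0.1:8080", nil)
    }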

105
stats/timeseries.go Normal file
View File

@@ -0,0 +1,105 @@
package stats
import (
"encoding/json"
"fmt"
"sync"
"time"
)
type timeseries struct {
sync.Mutex
now time.Time
size int
interval time.Duration
total metric
samples []metric
}
func (ts *timeseries) Reset() {
ts.total.Reset()
for _, s := range ts.samples {
s.Reset()
}
}
func (ts *timeseries) roll() {
t := time.Now()
roll := int((t.Round(ts.interval).Sub(ts.now.Round(ts.interval))) / ts.interval)
ts.now = t
n := len(ts.samples)
if roll <= 0 {
return
}
if roll >= len(ts.samples) {
ts.Reset()
} else {
for i := 0; i < roll; i++ {
tmp := ts.samples[n-1]
for j := n - 1; j > 0; j-- {
ts.samples[j] = ts.samples[j-1]
}
ts.samples[0] = tmp
ts.samples[0].Reset()
}
ts.total.Aggregate(roll, ts.samples)
}
}
func (ts *timeseries) Add(n float64) {
ts.Lock()
defer ts.Unlock()
ts.roll()
ts.total.Add(n)
ts.samples[0].Add(n)
}
func (ts *timeseries) MarshalJSON() ([]byte, error) {
ts.Lock()
defer ts.Unlock()
ts.roll()
return json.Marshal(struct {
Interval float64 `json:"interval"`
Total Metric `json:"total"`
Samples []metric `json:"samples"`
}{float64(ts.interval) / float64(time.Second), ts.total, ts.samples})
}
func (ts *timeseries) String() string {
ts.Lock()
defer ts.Unlock()
ts.roll()
return ts.total.String()
}
func newTimeseries(builder func() metric, frame string) *timeseries {
var (
totalNum, intervalNum int
totalUnit, intervalUnit rune
)
units := map[rune]time.Duration{
's': time.Second,
'm': time.Minute,
'h': time.Hour,
'd': time.Hour * 24,
'w': time.Hour * 24 * 7,
'M': time.Hour * 24 * 30,
'y': time.Hour * 24 * 365,
}
fmt.Sscanf(frame, "%d%c%d%c", &totalNum, &totalUnit, &intervalNum, &intervalUnit)
interval := units[intervalUnit] * time.Duration(intervalNum)
if interval == 0 {
interval = time.Minute
}
totalDuration := units[totalUnit] * time.Duration(totalNum)
if totalDuration == 0 {
totalDuration = interval * 15
}
n := int(totalDuration / interval)
samples := make([]metric, n)
for i := 0; i < n; i++ {
samples[i] = builder()
}
totalMetric := builder()
return &timeseries{interval: interval, total: totalMetric, samples: samples}
}
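
The frame string scanned here encodes <total><unit><interval><unit> with units s, m, h, d, w, M and y; an unparseable or zero interval falls back to one minute, and a missing total to fifteen intervals. A hypothetical within-package test sketch of that behaviour (not part of the diff):

    package stats

    import (
        "testing"
        "time"
    )

    func TestFrameParsing(t *testing.T) {
        // "24h1h" -> 24 samples of one hour each.
        ts := newTimeseries(func() metric { return &counter{} }, "24h1h")
        if len(ts.samples) != 24 || ts.interval != time.Hour {
            t.Errorf("want 24 one-hour samples, got %d x %s", len(ts.samples), ts.interval)
        }

        // Anything unparseable falls back to 15 one-minute samples.
        ts = newTimeseries(func() metric { return &counter{} }, "bogus")
        if len(ts.samples) != 15 || ts.interval != time.Minute {
            t.Errorf("want 15 one-minute samples, got %d x %s", len(ts.samples), ts.interval)
        }
    }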

View File

@@ -37,22 +37,35 @@ proxy {
} }
} }
ca {
cert = "testdata/ca.crt"
key = "testdata/ca.key"
}
policy "intercept" { policy "intercept" {
path = "testdata/policy/intercept.rego" path = "testdata/policy/styx/intercept.rego"
package = "styx.intercept" package = "styx.intercept"
} }
policy "bogons" { policy "bogons" {
path = "testdata/policy/bogons.rego" path = "testdata/policy/styx/bogons.rego"
} }
policy "childsafe" { policy "childsafe" {
path = "testdata/policy/childsafe.rego" path = "testdata/policy/custom/childsafe.rego"
package = "custom"
} }
data { data {
path = "testdata/match" path = "testdata/match"
storage {
type = "bolt"
path = "testdata/styx.bolt"
#type = "sqlite"
#path = "testdata/styx.db"
}
network "reserved" { network "reserved" {
type = "list" type = "list"
list = [ list = [

BIN
template/blocked-256.jpeg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 25 KiB

BIN
template/blocked-512.jpeg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 88 KiB

83
template/blocked.html Normal file

File diff suppressed because one or more lines are too long

BIN
template/blocked.jpeg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 189 KiB

View File

@@ -1,58 +0,0 @@
package styx
import input.request as http_request
default permit := false
default reject := 0
default template := ""
# Bogon networks
bogons := [
"0.0.0.0/8", # "This" network
"10.0.0.0/8", # RFC1918 Private-use networks
"100.64.0.0/10", # Carrier-grade NAT
"127.0.0.0/8", # Loopback
"169.254.0.0/16", # Link local
"172.16.0.0/12", # RFC1918 Private-use networks
"192.0.0.0/24", # IETF protocol assignments
"192.0.2.0/24", # TEST-NET-1
"192.168.0.0/16", # RFC1918 Private-use networks
"198.18.0.0/15", # Network interconnect device benchmark testing
"198.51.100.0/24", # TEST-NET-2
"203.0.113.0/24", # TEST-NET-3
"224.0.0.0/4", # Multicast
"240.0.0.0/4", # Reserved for future use
"255.255.255.255/32", # Limited broadcast
]
# Resolve HTTP host to IPs
addrs := styx.lookup_ip_addr(http_request.host)
template := "template/blocked.html" if {
some cidr in bogons
net.cidr_contains(cidr, http_request.host)
}
template := "template/blocked.html" if {
some addr in addrs
some cidr in bogons
net.cidr_contains(cidr, addr)
}
permit if {
template == ""
}
errors contains "Bogon destination not allowed" if {
template != ""
}
errors contains "Could not lookup host" if {
count(addrs) == 0
}
errors contains addr if {
some addr in addrs
some cidr in bogons
net.cidr_contains(cidr, addr)
}

View File

@@ -1,56 +0,0 @@
package styx
import input.client as client
import input.request as http_request
# HTTP -> HTTPS redirects for allowed domains
redirect = concat("", ["https://", http_request.host, http_request.path]) if {
_social
http_request.scheme == "http"
}
reject = 403 if {
_childsafe_network
_social
}
reject = 403 if {
_childsafe_network
_toxic
}
# Sensitive domains are always allowed
permit if {
_sensitive
}
permit if {
reject != 0
}
_sensitive if {
styx.in_domains("sensitive", http_request.host)
}
_social if {
styx.in_domains("social", http_request.host)
print("Domain in social", http_request.host)
}
errors contains "Social networking domain not allowed" if {
reject != 0
_social
}
_toxic if {
styx.in_domains("toxic", http_request.host)
}
errors contains "Toxic domain not allowed" if {
reject != 0
_toxic
}
_childsafe_network if {
styx.in_networks("kids", client.ip)
}

102
testdata/policy/custom/childsafe.rego vendored Normal file
View File

@@ -0,0 +1,102 @@
package custom
_social_domains := [
"reddit.com",
"roblox.com",
# X
"twitter.com",
"x.com",
# YouTube
"googlevideo.com",
"youtube.com",
"youtu.be",
"ytimg.com",
]
_toxic_domains := [
# Facebook
"facebook.com",
"facebook.net",
"fbsbx.com",
# Pinterest
"pinterest.com",
# TikTok
"isnssdk.com",
"musical.ly",
"musically.app.link",
"musically-alternate.app.link",
"musemuse.cn",
"sgsnssdk.com",
"tiktok.com",
"tiktok.org",
"tiktokcdn.com",
"tiktokcdn-eu.com",
"tiktokv.com",
]
in_domains(list, name) if {
some item in list
lower(name) == lower(item)
}
in_domains(list, name) if {
some item in list
endswith(lower(name), sprintf(".%s", [lower(item)]))
}
# METADATA
# description: Apply childsafe rules to the request: reject if it's a social
# site during off-hours, or if it's toxic.
# entrypoint: true
default redirect := ""
# HTTP -> HTTPS redirects for allowed domains
redirect := location if {
_social
input.request.scheme == "http"
location := sprintf("https://%s%s", [input.request.host, input.request.path])
}
default reject := 0
template := "template/blocked.html" if {
_childsafe_network
_social
# styx.time_between("18:00", "16:00") # allowed between 16:00-18:00
}
template := "template/blocked.html" if {
_toxic
}
# Sensitive domains are always allowed
permit if {
_sensitive
reject != 0
}
_sensitive if {
styx.domains_contain("sensitive", input.request.host)
}
_social if {
#styx.domains_contain("social", input.request.host)
in_domains(_social_domains, input.request.host)
}
_toxic if {
in_domains(_toxic_domains, input.request.host)
}
_childsafe_network if {
styx.networks_contain("kids", input.client.ip)
}
errors contains "Request to social networking site outside of allowed hours" if {
_childsafe_network
_social
}
errors contains "Request to toxic site" if {
_toxic
}

View File

@@ -1,21 +0,0 @@
package styx.intercept
reject := 403 if {
_target_blocked
}
template := "template/intercepted.html" if {
_target_blocked
}
errors contains "Intercepted" if {
_target_blocked
}
_target_blocked if {
styx.in_domains("bad", input.request.host)
}
_target_blocked if {
styx.in_networks("bogons", input.client.ip)
}

54
testdata/policy/styx/bogons.rego vendored Normal file
View File

@@ -0,0 +1,54 @@
package styx
# Bogon networks
_bogons := [
"0.0.0.0/8", # "This" network
"10.0.0.0/8", # RFC1918 Private-use networks
"100.64.0.0/10", # Carrier-grade NAT
"127.0.0.0/8", # Loopback
"169.254.0.0/16", # Link local
"172.16.0.0/12", # RFC1918 Private-use networks
"192.0.0.0/24", # IETF protocol assignments
"192.0.2.0/24", # TEST-NET-1
"192.168.0.0/16", # RFC1918 Private-use networks
"198.18.0.0/15", # Network interconnect device benchmark testing
"198.51.100.0/24", # TEST-NET-2
"203.0.113.0/24", # TEST-NET-3
"224.0.0.0/4", # Multicast
"240.0.0.0/4", # Reserved for future use
"255.255.255.255/32", # Limited broadcast
]
# METADATA
# description: Reject requests to bogon targets.
# entrypoint: true
default permit := false
permit if {
template == ""
}
default template := ""
template := "template/blocked.html" if {
_bogon
}
errors contains "Bogon destination not allowed" if {
_bogon
}
errors contains _bogon if {
_bogon
}
_bogon := addr if {
some addr in styx.lookup_ip_addr(input.request.host)
some cidr in _bogons
net.cidr_contains(cidr, addr)
}
_bogon := input.request.host if {
some cidr in _bogons
net.cidr_contains(cidr, input.request.host)
}

25
testdata/policy/styx/intercept.rego vendored Normal file
View File

@@ -0,0 +1,25 @@
package styx.intercept
reject := 403 if {
_bad
}
template := "template/blocked.html" if {
_bogon
}
errors contains "Bad domain" if {
_bad
}
errors contains "Bogon target" if {
_bogon
}
_bad if {
styx.domains_contain("bad", input.request.host)
}
_bogon if {
styx.domains_contain("bogons", input.client.ip)
}