diff --git a/.gitignore b/.gitignore index 0c3b16c..73ef4f4 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,6 @@ -# SQLite3 database file +# Database file +*.bolt +*.boltdb *.db # Log files diff --git a/.regal.yaml b/.regal.yaml new file mode 100644 index 0000000..3fef402 --- /dev/null +++ b/.regal.yaml @@ -0,0 +1,14 @@ +rules: + idiomatic: + directory-package-mismatch: + level: ignore + + style: + function-arg-return: + level: error + except-functions: + - sprintf + +project: + roots: + - testdata/policy \ No newline at end of file diff --git a/admin/admin.go b/admin/admin.go new file mode 100644 index 0000000..8dfab7b --- /dev/null +++ b/admin/admin.go @@ -0,0 +1,146 @@ +package admin + +import ( + "bytes" + "encoding/json" + "errors" + "io" + "net/http" + "os" + "strconv" + "sync" + + "git.maze.io/maze/styx/dataset" + "git.maze.io/maze/styx/logger" + "git.maze.io/maze/styx/proxy" +) + +type Admin struct { + Storage dataset.Storage + setupOnce sync.Once + mux *http.ServeMux + api *http.ServeMux +} + +type apiError struct { + Code int + Err error +} + +func (err apiError) Error() string { + return err.Err.Error() +} + +func (a *Admin) setup() { + a.mux = http.NewServeMux() + + a.api = http.NewServeMux() + a.api.HandleFunc("GET /groups", a.apiGroups) + a.api.HandleFunc("POST /group", a.apiGroupCreate) + a.api.HandleFunc("GET /group/{id}", a.apiGroup) + a.api.HandleFunc("PATCH /group/{id}", a.apiGroupUpdate) + a.api.HandleFunc("DELETE /group/{id}", a.apiGroupDelete) + a.api.HandleFunc("GET /clients", a.apiClients) + a.api.HandleFunc("GET /client/{id}", a.apiClient) + a.api.HandleFunc("POST /client", a.apiClientCreate) + a.api.HandleFunc("PATCH /client/{id}", a.apiClientUpdate) + a.api.HandleFunc("DELETE /client/{id}", a.apiClientDelete) + a.api.HandleFunc("GET /lists", a.apiLists) + a.api.HandleFunc("POST /list", a.apiListCreate) + a.api.HandleFunc("GET /list/{id}", a.apiList) + a.api.HandleFunc("DELETE /list/{id}", a.apiListDelete) +} + +type Handler interface 
{
+	Handle(pattern string, handler http.Handler)
+}
+
+func (a *Admin) Install(handler Handler) {
+	a.setupOnce.Do(a.setup)
+	handler.Handle("/api/v1/", http.StripPrefix("/api/v1", a.api))
+}
+
+func (a *Admin) handleAPIError(w http.ResponseWriter, r *http.Request, err error) {
+	code := http.StatusBadRequest
+	switch {
+	case dataset.IsNotExist(err):
+		code = http.StatusNotFound
+	case os.IsPermission(err):
+		code = http.StatusForbidden
+	case errors.As(err, new(apiError)): // was errors.Is(err, apiError{}): zero-value compare, never matched a populated apiError
+		if c := err.(apiError).Code; c > 0 {
+			code = c
+		}
+	}
+
+	logger.StandardLog.Err(err).Values(logger.Values{
+		"code":   code,
+		"client": r.RemoteAddr,
+		"method": r.Method,
+		"path":   r.URL.Path,
+	}).Warn("Unexpected API error encountered")
+
+	var data []byte
+	if err, ok := err.(apiError); ok {
+		data, _ = json.Marshal(struct {
+			Code  int    `json:"code"`
+			Error string `json:"error"`
+		}{code, err.Error()})
+	} else {
+		data, _ = json.Marshal(struct {
+			Code  int    `json:"code"`
+			Error string `json:"error"`
+		}{code, http.StatusText(code)})
+	}
+
+	res := proxy.NewResponse(code, io.NopCloser(bytes.NewReader(data)), r)
+	res.Header.Set(proxy.HeaderContentType, "application/json")
+
+	for k, vv := range res.Header {
+		if len(vv) >= 1 {
+			w.Header().Set(k, vv[0])
+			for _, v := range vv[1:] {
+				w.Header().Add(k, v)
+			}
+		}
+	}
+	w.WriteHeader(code)
+	io.Copy(w, res.Body)
+}
+
+func (a *Admin) jsonResponse(w http.ResponseWriter, r *http.Request, value any, codes ...int) {
+	var (
+		code = http.StatusNoContent
+		body io.ReadCloser
+		size int64
+	)
+	if value != nil {
+		data, err := json.Marshal(value)
+		if err != nil {
+			a.handleAPIError(w, r, err)
+			return
+		}
+		code = http.StatusOK
+		body = io.NopCloser(bytes.NewReader(data))
+		size = int64(len(data))
+	}
+	if len(codes) > 0 {
+		code = codes[0]
+	}
+
+	res := proxy.NewResponse(code, body, r)
+	res.Close = true
+	res.Header.Set(proxy.HeaderContentLength, strconv.FormatInt(size, 10))
+	res.Header.Set(proxy.HeaderContentType, "application/json")
+
+	for 
k, vv := range res.Header { + if len(vv) >= 1 { + w.Header().Set(k, vv[0]) + for _, v := range vv[1:] { + w.Header().Add(k, v) + } + } + } + w.WriteHeader(code) + io.Copy(w, res.Body) +} diff --git a/admin/api_client.go b/admin/api_client.go new file mode 100644 index 0000000..8605199 --- /dev/null +++ b/admin/api_client.go @@ -0,0 +1,183 @@ +package admin + +import ( + "encoding/json" + "errors" + "fmt" + "log" + "net" + "net/http" + "strconv" + "time" + + "git.maze.io/maze/styx/dataset" +) + +func (a *Admin) apiClients(w http.ResponseWriter, r *http.Request) { + clients, err := a.Storage.Clients() + if err != nil { + a.handleAPIError(w, r, err) + return + } + a.jsonResponse(w, r, clients) +} + +func (a *Admin) apiClient(w http.ResponseWriter, r *http.Request) { + id, err := strconv.ParseInt(r.PathValue("id"), 10, 64) + if err != nil { + a.handleAPIError(w, r, err) + return + } + client, err := a.Storage.ClientByID(id) + if err != nil { + a.handleAPIError(w, r, err) + return + } + a.jsonResponse(w, r, client) +} + +func (a *Admin) apiClientCreate(w http.ResponseWriter, r *http.Request) { + var request struct { + dataset.Client + Groups []int64 `json:"groups"` + ID int64 `json:"id"` // mask, not used + CreatedAt time.Time `json:"created_at"` // mask, not used + UpdatedAt time.Time `json:"updated_at"` // mask, not used + } + if err := json.NewDecoder(r.Body).Decode(&request); err != nil { + a.handleAPIError(w, r, err) + return + } + + if err := a.verifyClient(&request.Client); err != nil { + a.handleAPIError(w, r, err) + return + } + + var groups []dataset.Group + for _, id := range request.Groups { + group, err := a.Storage.GroupByID(id) + if err != nil { + a.handleAPIError(w, r, err) + return + } + groups = append(groups, group) + } + + request.Client.Groups = groups + if err := a.Storage.SaveClient(&request.Client); err != nil { + a.handleAPIError(w, r, err) + return + } + + a.jsonResponse(w, r, request.Client) +} + +func (a *Admin) apiClientUpdate(w 
http.ResponseWriter, r *http.Request) { + id, err := strconv.ParseInt(r.PathValue("id"), 10, 64) + if err != nil { + a.handleAPIError(w, r, err) + return + } + + client, err := a.Storage.ClientByID(id) + if err != nil { + a.handleAPIError(w, r, err) + return + } + log.Printf("updating: %#+v", client) + + var request struct { + dataset.Client + Groups []int64 `json:"groups"` + } + if err := json.NewDecoder(r.Body).Decode(&request); err != nil { + a.handleAPIError(w, r, err) + return + } + + if err := a.verifyClient(&request.Client); err != nil { + a.handleAPIError(w, r, err) + return + } + + client.IP = request.Client.IP + client.Mask = request.Client.Mask + client.Description = request.Client.Description + client.Groups = client.Groups[:0] + for _, id := range request.Groups { + group, err := a.Storage.GroupByID(id) + if err != nil { + a.handleAPIError(w, r, err) + return + } + client.Groups = append(client.Groups, group) + } + if err := a.Storage.SaveClient(&client); err != nil { + a.handleAPIError(w, r, err) + return + } + + a.jsonResponse(w, r, client) +} + +func (a *Admin) apiClientDelete(w http.ResponseWriter, r *http.Request) { + id, err := strconv.ParseInt(r.PathValue("id"), 10, 64) + if err != nil { + a.handleAPIError(w, r, err) + return + } + client, err := a.Storage.ClientByID(id) + if err != nil { + a.handleAPIError(w, r, err) + return + } + if err = a.Storage.DeleteClient(client); err != nil { + a.handleAPIError(w, r, err) + return + } + a.jsonResponse(w, r, nil) +} + +func (a *Admin) verifyClient(c *dataset.Client) (err error) { + ip := net.ParseIP(c.IP) + switch c.Network { + case "ipv4": + if ip.To4() == nil { + return apiError{Err: errors.New("invalid IPv4 address")} + } + if c.Mask == 0 { + c.Mask = 32 // one IP + } + if c.Mask <= 0 || c.Mask > 32 { + return apiError{Err: errors.New("mask can't be zero")} + } + c.IP = ip.Mask(net.CIDRMask(int(c.Mask), 32)).String() + + case "ipv6": + if ip.To16() == nil { + return apiError{Err: errors.New("invalid 
IPv6 address")} + } + if c.Mask == 0 { + c.Mask = 128 // one IP + } + if c.Mask <= 0 || c.Mask > 128 { + return apiError{Err: errors.New("mask can't be zero")} + } + c.IP = ip.Mask(net.CIDRMask(int(c.Mask), 128)).String() + + case "": + if ip.To4() != nil { + c.Network = "ipv4" + } else if ip.To16() != nil { + c.Network = "ipv6" + } else { + return apiError{Err: errors.New("invalid IP address")} + } + return a.verifyClient(c) + + default: + return apiError{Err: fmt.Errorf("invalid network %q", c.Network)} + } + return +} diff --git a/admin/api_group.go b/admin/api_group.go new file mode 100644 index 0000000..e981e00 --- /dev/null +++ b/admin/api_group.go @@ -0,0 +1,72 @@ +package admin + +import ( + "encoding/json" + "net/http" + "strconv" + "time" + + "git.maze.io/maze/styx/dataset" +) + +func (a *Admin) apiGroups(w http.ResponseWriter, r *http.Request) { + groups, err := a.Storage.Groups() + if err != nil { + a.handleAPIError(w, r, err) + return + } + a.jsonResponse(w, r, groups) +} + +func (a *Admin) apiGroup(w http.ResponseWriter, r *http.Request) { + id, err := strconv.ParseInt(r.PathValue("id"), 10, 64) + if err != nil { + a.handleAPIError(w, r, err) + return + } + group, err := a.Storage.GroupByID(id) + if err != nil { + a.handleAPIError(w, r, err) + return + } + a.jsonResponse(w, r, group) +} + +func (a *Admin) apiGroupCreate(w http.ResponseWriter, r *http.Request) { + var request struct { + dataset.Group + ID int64 `json:"id"` // mask, not used + CreatedAt time.Time `json:"created_at"` // mask, not used + UpdatedAt time.Time `json:"updated_at"` // mask, not used + } + if err := json.NewDecoder(r.Body).Decode(&request); err != nil { + a.handleAPIError(w, r, err) + return + } + if err := a.Storage.SaveGroup(&request.Group); err != nil { + a.handleAPIError(w, r, err) + return + } + a.jsonResponse(w, r, request.Group, http.StatusCreated) +} + +func (a *Admin) apiGroupUpdate(w http.ResponseWriter, r *http.Request) { +} + +func (a *Admin) apiGroupDelete(w 
http.ResponseWriter, r *http.Request) { + id, err := strconv.ParseInt(r.PathValue("id"), 10, 64) + if err != nil { + a.handleAPIError(w, r, err) + return + } + group, err := a.Storage.GroupByID(id) + if err != nil { + a.handleAPIError(w, r, err) + return + } + if err = a.Storage.DeleteGroup(group); err != nil { + a.handleAPIError(w, r, err) + return + } + a.jsonResponse(w, r, nil) +} diff --git a/admin/api_list.go b/admin/api_list.go new file mode 100644 index 0000000..b31e9e1 --- /dev/null +++ b/admin/api_list.go @@ -0,0 +1,98 @@ +package admin + +import ( + "encoding/json" + "fmt" + "net/http" + "strconv" + "time" + + "git.maze.io/maze/styx/dataset" +) + +func (a *Admin) apiLists(w http.ResponseWriter, r *http.Request) { + lists, err := a.Storage.Lists() + if err != nil { + a.handleAPIError(w, r, err) + return + } + a.jsonResponse(w, r, lists) +} + +func (a *Admin) apiList(w http.ResponseWriter, r *http.Request) { + id, err := strconv.ParseInt(r.PathValue("id"), 10, 64) + if err != nil { + a.handleAPIError(w, r, err) + return + } + list, err := a.Storage.ListByID(id) + if err != nil { + a.handleAPIError(w, r, err) + return + } + a.jsonResponse(w, r, list) +} + +func (a *Admin) apiListCreate(w http.ResponseWriter, r *http.Request) { + var request struct { + dataset.List + Groups []int64 `json:"groups"` + ID int64 `json:"id"` // mask, not used + CreatedAt time.Time `json:"created_at"` // mask, not used + UpdatedAt time.Time `json:"updated_at"` // mask, not used + } + if err := json.NewDecoder(r.Body).Decode(&request); err != nil { + a.handleAPIError(w, r, err) + return + } + + if err := a.verifyList(&request.List); err != nil { + a.handleAPIError(w, r, err) + return + } + + request.List.Groups = request.List.Groups[:0] + for _, id := range request.Groups { + group, err := a.Storage.GroupByID(id) + if err != nil { + a.handleAPIError(w, r, err) + return + } + request.List.Groups = append(request.List.Groups, group) + } + + if err := a.Storage.SaveList(&request.List); 
err != nil { + a.handleAPIError(w, r, err) + return + } + + a.jsonResponse(w, r, request.List) +} + +func (a *Admin) apiListDelete(w http.ResponseWriter, r *http.Request) { + id, err := strconv.ParseInt(r.PathValue("id"), 10, 64) + if err != nil { + a.handleAPIError(w, r, err) + return + } + list, err := a.Storage.ListByID(id) + if err != nil { + a.handleAPIError(w, r, err) + return + } + if err = a.Storage.DeleteList(list); err != nil { + a.handleAPIError(w, r, err) + return + } + a.jsonResponse(w, r, nil) +} + +func (a *Admin) verifyList(list *dataset.List) error { + switch list.Type { + case dataset.ListTypeDomain, dataset.ListTypeNetwork: + default: + return apiError{Err: fmt.Errorf("unknown list type %q", list.Type)} + } + + return nil +} diff --git a/ca/authority.go b/ca/authority.go new file mode 100644 index 0000000..3763bcd --- /dev/null +++ b/ca/authority.go @@ -0,0 +1,119 @@ +package ca + +import ( + "crypto" + "crypto/rand" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "fmt" + "math/big" + "net" + "strings" + "sync" + "time" + + "git.maze.io/maze/styx/internal/cryptutil" + "git.maze.io/maze/styx/logger" + "github.com/miekg/dns" +) + +type CertificateAuthority interface { + GetCertificate(commonName string, dnsNames []string, ips []net.IP) (*tls.Certificate, error) +} + +type ca struct { + cert *x509.Certificate + key crypto.PrivateKey + cache sync.Map +} + +func Open(certData, keyData string) (CertificateAuthority, error) { + cert, key, err := cryptutil.LoadKeyPair(certData, keyData) + if err != nil { + return nil, err + } else if !cert.IsCA { + return nil, fmt.Errorf("ca: certificate for %s is not a certificate authority", cert.Subject.String()) + } + + return &ca{ + cert: cert, + key: key, + }, nil +} + +func (ca *ca) GetCertificate(cn string, names []string, ips []net.IP) (*tls.Certificate, error) { + var ( + log = logger.StandardLog.Values(logger.Values{ + "cn": cn, + "names": names, + "ips": ips, + }) + now = time.Now().UTC() + parent = 
parentDomain(cn) + ) + if cn == parent { + names = append(names, "*."+cn) + } else { + names = append(names, "*."+parent, cn) + cn = parent + log = log.Value("cn", cn) + } + if v, ok := ca.cache.Load(parent); ok { + if cert, ok := v.(*tls.Certificate); ok && now.After(cert.Leaf.NotBefore) && now.Before(cert.Leaf.NotAfter.Add(-time.Hour)) { + log.Value("valid", cert.Leaf.NotAfter.Sub(now)).Debug("Using cached certificate") + return cert, nil + } + log.Debug("Cached certificate invalid") + ca.cache.Delete(parent) + } + + serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) + serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) + if err != nil { + return nil, fmt.Errorf("ca: failed to generate serial number: %w", err) + } + + notBefore := now.Round(24 * time.Hour) + notAfter := notBefore.Add(48 * time.Hour) + + log.Values(logger.Values{ + "serial": serialNumber.String(), + "subject": pkix.Name{CommonName: cn}.String(), + }).Debug("Generating certificate") + template := &x509.Certificate{ + SerialNumber: serialNumber, + KeyUsage: x509.KeyUsageDataEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + Subject: pkix.Name{CommonName: cn}, + DNSNames: names, + IPAddresses: ips, + PublicKey: cryptutil.PublicKey(ca.key), + NotBefore: notBefore, + NotAfter: notAfter, + } + der, err := x509.CreateCertificate(rand.Reader, template, ca.cert, template.PublicKey, ca.key) + if err != nil { + return nil, err + } + cert, err := x509.ParseCertificate(der) + if err != nil { + return nil, err + } + + output := &tls.Certificate{ + Certificate: [][]byte{der}, + Leaf: cert, + PrivateKey: ca.key, + } + ca.cache.Store(parent, output) + return output, nil +} + +func parentDomain(name string) string { + part := dns.SplitDomainName(name) + if len(part) <= 2 { + return name + } + return strings.Join(part[1:], ".") +} diff --git a/cmd/styx/config.go b/cmd/styx/config.go index 25cd622..3bc970d 100644 --- a/cmd/styx/config.go +++ 
b/cmd/styx/config.go @@ -8,6 +8,7 @@ import ( "github.com/hashicorp/hcl/v2/gohcl" "github.com/hashicorp/hcl/v2/hclsimple" + "git.maze.io/maze/styx/ca" "git.maze.io/maze/styx/dataset" "git.maze.io/maze/styx/internal/cryptutil" "git.maze.io/maze/styx/logger" @@ -18,6 +19,7 @@ import ( type Config struct { Proxy ProxyConfig `hcl:"proxy,block"` Policy []PolicyConfig `hcl:"policy,block"` + CA *CAConfig `hcl:"ca,block"` Data DataConfig `hcl:"data,block"` } @@ -145,8 +147,18 @@ type PolicyConfig struct { Package string `hcl:"package,optional"` } +type CAConfig struct { + Cert string `hcl:"cert"` + Key string `hcl:"key,optional"` +} + +func (c CAConfig) CertificateAuthority() (ca.CertificateAuthority, error) { + return ca.Open(c.Cert, c.Key) +} + type DataConfig struct { Path string `hcl:"path,optional"` + Storage DataStorageConfig `hcl:"storage,block"` Domains []DomainDataConfig `hcl:"domain,block"` Networks []NetworkDataConfig `hcl:"network,block"` } @@ -165,6 +177,39 @@ func (c DataConfig) Configure() error { return nil } +func (c DataConfig) OpenStorage() (dataset.Storage, error) { + switch c.Storage.Type { + case "", "bolt", "boltdb": + var config struct { + Path string `hcl:"path"` + } + if diag := gohcl.DecodeBody(c.Storage.Body, nil, &config); diag.HasErrors() { + return nil, diag + } + //return dataset.OpenBolt(config.Path) + return dataset.OpenBStore(config.Path) + + /* + case "sqlite", "sqlite3": + var config struct { + Path string `hcl:"path"` + } + if diag := gohcl.DecodeBody(c.Storage.Body, nil, &config); diag.HasErrors() { + return nil, diag + } + return dataset.OpenSQLite(config.Path) + */ + + default: + return nil, fmt.Errorf("storage: no %q driver", c.Storage.Type) + } +} + +type DataStorageConfig struct { + Type string `hcl:"type"` + Body hcl.Body `hcl:",remain"` +} + type DomainDataConfig struct { Name string `hcl:"name,label"` Type string `hcl:"type"` diff --git a/cmd/styx/main.go b/cmd/styx/main.go index 18adf49..03d1d27 100644 --- a/cmd/styx/main.go 
+++ b/cmd/styx/main.go @@ -7,6 +7,9 @@ import ( "os/signal" "syscall" + "git.maze.io/maze/styx/admin" + "git.maze.io/maze/styx/ca" + "git.maze.io/maze/styx/dataset" "git.maze.io/maze/styx/logger" "git.maze.io/maze/styx/proxy" ) @@ -40,6 +43,22 @@ func main() { log.Err(err).Fatal("Invalid data configuration") } + var ca ca.CertificateAuthority + if config.CA != nil { + if ca, err = config.CA.CertificateAuthority(); err != nil { + log.Err(err).Fatal("Invalid ca configuration") + } + } + + var storage dataset.Storage + if storage, err = config.Data.OpenStorage(); err != nil { + log.Err(err).Fatal("Invalid data.storage configuration") + } + + admin := &admin.Admin{ + Storage: storage, + } + proxies, err := config.Proxies(log) if err != nil { log.Err(err).Fatal("Error configuring proxy ports") @@ -52,6 +71,9 @@ func main() { ) for i, p := range proxies { + p.CertificateAuthority = ca + p.Storage = storage + admin.Install(p) go run(config.Proxy.Port[i].Listen, p, errs) } @@ -64,12 +86,18 @@ func main() { case syscall.SIGHUP: log.Value("signal", sig.String()).Warn("Ignored reload signal ¯\\_(ツ)_/¯") default: - log.Value("signal", sig.String()).Info("Shutting down on signal") - return + log.Value("signal", sig.String()).Warn("Shutting down on signal") + close(done) } case <-done: - log.Info("Shutting down gracefully") + log.Warn("Shutting down gracefully") + for i, p := range proxies { + log.Value("port", config.Proxy.Port[i].Listen).Info("Proxy port closing") + if err := p.Close(); err != nil { + log.Err(err).Error("Error closing proxy") + } + } return case err = <-errs: diff --git a/dataset/base.go b/dataset/base.go new file mode 100644 index 0000000..8ff540f --- /dev/null +++ b/dataset/base.go @@ -0,0 +1 @@ +package dataset diff --git a/dataset/error.go b/dataset/error.go new file mode 100644 index 0000000..609ab80 --- /dev/null +++ b/dataset/error.go @@ -0,0 +1,25 @@ +package dataset + +import ( + "errors" + "fmt" + "os" + + "github.com/mjl-/bstore" +) + +type 
ErrNotExist struct {
+	Object string
+	ID     int64
+}
+
+func (err ErrNotExist) Error() string {
+	return fmt.Sprintf("storage: %s not found", err.Object)
+}
+
+func IsNotExist(err error) bool {
+	if err == nil {
+		return false
+	}
+	return os.IsNotExist(err) || errors.As(err, new(ErrNotExist)) || errors.Is(err, bstore.ErrAbsent) // errors.As: errors.Is compared against the zero ErrNotExist and never matched
+}
diff --git a/dataset/parser/adblock.go b/dataset/parser/adblock.go
new file mode 100644
index 0000000..1b35df6
--- /dev/null
+++ b/dataset/parser/adblock.go
@@ -0,0 +1,53 @@
+package parser
+
+import (
+	"bufio"
+	"io"
+	"strings"
+)
+
+func init() {
+	RegisterDomainsParser(adblockDomainsParser{})
+}
+
+type adblockDomainsParser struct{}
+
+func (adblockDomainsParser) CanHandle(line string) bool {
+	return strings.HasPrefix(strings.ToLower(line), `[adblock`) ||
+		strings.HasPrefix(line, "@@") || // exception
+		strings.HasPrefix(line, "||") || // domain anchor
+		(len(line) > 0 && line[0] == '*') // guard: line may be empty
+}
+
+func (adblockDomainsParser) ParseDomains(r io.Reader) (domains []string, ignored int, err error) {
+	scanner := bufio.NewScanner(r)
+	for scanner.Scan() {
+		line := strings.TrimSpace(scanner.Text())
+		if isComment(line) {
+			continue
+		}
+
+		// Common AdBlock patterns:
+		// ||domain.com^
+		// |http://domain.com|
+		// domain.com/path
+		// *domain.com*
+		switch {
+		case strings.HasPrefix(line, `||`): // domain anchor
+			if i := strings.IndexByte(line, '^'); i != -1 {
+				domains = append(domains, line[2:i])
+				continue
+			}
+		case strings.HasPrefix(line, `|`) && strings.HasSuffix(line, `|`):
+			domains = append(domains, line[1:len(line)-1]) // was len(line)-2: dropped the final character of the address
+			continue
+		case strings.HasPrefix(line, `[`):
+			continue
+		}
+		ignored++
+	}
+	if err = scanner.Err(); err != nil {
+		return
+	}
+	return unique(domains), ignored, nil
+}
diff --git a/dataset/parser/adblock_test.go b/dataset/parser/adblock_test.go
new file mode 100644
index 0000000..9874bb5
--- /dev/null
+++ b/dataset/parser/adblock_test.go
@@ -0,0 +1,41 @@
+package parser
+
+import (
+	"reflect"
+	"sort"
+	"strings"
+	"testing"
+)
+
+func 
TestAdBlockParser(t *testing.T) { + test := `[Adblock Plus 2.0] +! Title: AdRules DNS List +! Homepage: https://github.com/Cats-Team/AdRules +! Powerd by Cats-Team +! Expires: 1 (update frequency) +! Description: The DNS Filters +! Total count: 145270 +! Update: 2025-10-07 02:05:08(GMT+8) +/^.+stat\.kugou\.com/ +/^admarvel\./ +||*-ad-sign.byteimg.com^ +||*-ad.a.yximgs.com^ +||*-applog.fqnovel.com^ +||*-datareceiver.aki-game.net^ +||*.exaapi.com^` + want := []string{"*-ad-sign.byteimg.com", "*-ad.a.yximgs.com", "*-applog.fqnovel.com", "*-datareceiver.aki-game.net", "*.exaapi.com"} + + parsed, ignored, err := ParseDomains(strings.NewReader(test)) + if err != nil { + t.Fatal(err) + return + } + + sort.Strings(parsed) + if !reflect.DeepEqual(parsed, want) { + t.Errorf("expected ParseDomains(domains) to return %v, got %v", want, parsed) + } + if ignored != 2 { + t.Errorf("expected 2 ignored, got %d", ignored) + } +} diff --git a/dataset/parser/dns.go b/dataset/parser/dns.go new file mode 100644 index 0000000..8b331f7 --- /dev/null +++ b/dataset/parser/dns.go @@ -0,0 +1,139 @@ +package parser + +import ( + "bufio" + "io" + "strings" + + "github.com/miekg/dns" +) + +func init() { + RegisterDomainsParser(dnsmasqDomainsParser{}) + RegisterDomainsParser(mosDNSDomainsParser{}) + RegisterDomainsParser(smartDNSDomainsParser{}) + RegisterDomainsParser(unboundDomainsParser{}) +} + +type dnsmasqDomainsParser struct{} + +func (dnsmasqDomainsParser) CanHandle(line string) bool { + return strings.HasPrefix(line, "address=/") +} + +func (dnsmasqDomainsParser) ParseDomains(r io.Reader) (domains []string, ignored int, err error) { + scanner := bufio.NewScanner(r) + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if isComment(line) { + continue + } + switch { + case strings.HasPrefix(line, "address=/"): + part := strings.FieldsFunc(line, func(r rune) bool { return r == '/' }) + if len(part) >= 3 && isDomainName(part[1]) { + domains = append(domains, part[1]) + continue 
+ } + } + ignored++ + } + if err = scanner.Err(); err != nil { + return + } + return unique(domains), ignored, nil +} + +type mosDNSDomainsParser struct{} + +func (mosDNSDomainsParser) CanHandle(line string) bool { + if strings.HasPrefix(line, "domain:") { + return isDomainName(line[7:]) + } + return false +} + +func (mosDNSDomainsParser) ParseDomains(r io.Reader) (domains []string, ignored int, err error) { + scanner := bufio.NewScanner(r) + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if isComment(line) { + continue + } + if strings.HasPrefix(line, "domain:") { + domains = append(domains, line[7:]) + continue + } + ignored++ + } + if err = scanner.Err(); err != nil { + return + } + return unique(domains), ignored, nil +} + +type smartDNSDomainsParser struct{} + +func (smartDNSDomainsParser) CanHandle(line string) bool { + return strings.HasPrefix(line, "address /") +} + +func (smartDNSDomainsParser) ParseDomains(r io.Reader) (domains []string, ignored int, err error) { + scanner := bufio.NewScanner(r) + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if isComment(line) { + continue + } + if strings.HasPrefix(line, "address /") { + if i := strings.IndexByte(line[9:], '/'); i > -1 { + domains = append(domains, line[9:i+9]) + continue + } + } + ignored++ + } + if err = scanner.Err(); err != nil { + return + } + return unique(domains), ignored, nil +} + +type unboundDomainsParser struct{} + +func (unboundDomainsParser) CanHandle(line string) bool { + return strings.HasPrefix(line, "local-data:") || + strings.HasPrefix(line, "local-zone:") +} + +func (unboundDomainsParser) ParseDomains(r io.Reader) (domains []string, ignored int, err error) { + scanner := bufio.NewScanner(r) + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if isComment(line) { + continue + } + switch { + case strings.HasPrefix(line, "local-data:"): + record := strings.Trim(strings.TrimSpace(line[11:]), `"`) + if rr, err := dns.NewRR(record); 
err == nil { + switch rr.Header().Rrtype { + case dns.TypeA, dns.TypeAAAA, dns.TypeCNAME: + domains = append(domains, strings.Trim(rr.Header().Name, `.`)) + continue + } + } + case strings.HasPrefix(line, "local-zone:") && strings.HasSuffix(line, " reject"): + line = strings.Trim(strings.TrimSpace(line[11:]), `"`) + if i := strings.IndexByte(line, '"'); i > -1 { + domains = append(domains, line[:i]) + continue + } + } + ignored++ + } + if err = scanner.Err(); err != nil { + return + } + return unique(domains), ignored, nil +} diff --git a/dataset/parser/dns_test.go b/dataset/parser/dns_test.go new file mode 100644 index 0000000..7b06833 --- /dev/null +++ b/dataset/parser/dns_test.go @@ -0,0 +1,106 @@ +package parser + +import ( + "reflect" + "sort" + "strings" + "testing" +) + +func TestDNSMasqParser(t *testing.T) { + tests := []struct { + Name string + Test string + Want []string + WantIgnored int + }{ + { + "data", + ` +local-data: "junk1.doubleclick.net A 127.0.0.1" +local-data: "junk2.doubleclick.net A 127.0.0.1" +local-data: "junk2.doubleclick.net CNAME doubleclick.net." 
+local-data: "junk6.doubleclick.net AAAA ::1" +local-data: "doubleclick.net A 127.0.0.1" +local-data: "ad.junk1.doubleclick.net A 127.0.0.1" +local-data: "adjunk.google.com A 127.0.0.1"`, + []string{"ad.junk1.doubleclick.net", "adjunk.google.com", "doubleclick.net", "junk1.doubleclick.net", "junk2.doubleclick.net", "junk6.doubleclick.net"}, + 0, + }, + { + "zone", + ` +local-zone: "doubleclick.net" reject +local-zone: "adjunk.google.com" reject`, + []string{"adjunk.google.com", "doubleclick.net"}, + 0, + }, + { + "address", + ` +address=/ziyu.net/0.0.0.0 +address=/zlp6s.pw/0.0.0.0 +address=/zm232.com/0.0.0.0 + `, + []string{"ziyu.net", "zlp6s.pw", "zm232.com"}, + 0, + }, + } + for _, test := range tests { + t.Run(test.Name, func(it *testing.T) { + parsed, ignored, err := ParseDomains(strings.NewReader(test.Test)) + if err != nil { + t.Fatal(err) + return + } + + sort.Strings(parsed) + if !reflect.DeepEqual(parsed, test.Want) { + t.Errorf("expected ParseDomains(dnsmasq) to return\n\t%v, got\n\t%v", test.Want, parsed) + } + if ignored != test.WantIgnored { + t.Errorf("expected %d ignored, got %d", test.WantIgnored, ignored) + } + }) + } +} + +func TestMOSDNSParser(t *testing.T) { + test := `domain:0019x.com +domain:002777.xyz +domain:003store.com +domain:00404850.xyz` + want := []string{"0019x.com", "002777.xyz", "003store.com", "00404850.xyz"} + + parsed, _, err := ParseDomains(strings.NewReader(test)) + if err != nil { + t.Fatal(err) + return + } + + sort.Strings(parsed) + if !reflect.DeepEqual(parsed, want) { + t.Errorf("expected ParseDomains(domains) to return %v, got %v", want, parsed) + } +} + +func TestSmartDNSParser(t *testing.T) { + test := `# Title:AdRules SmartDNS List +# Update: 2025-10-07 02:05:08(GMT+8) +address /0.myikas.com/# +address /0.net.easyjet.com/# +address /0.nextyourcontent.com/# +address /0019x.com/#` + want := []string{"0.myikas.com", "0.net.easyjet.com", "0.nextyourcontent.com", "0019x.com"} + + parsed, _, err := 
ParseDomains(strings.NewReader(test)) + if err != nil { + t.Fatal(err) + return + } + + sort.Strings(parsed) + if !reflect.DeepEqual(parsed, want) { + t.Errorf("expected ParseDomains(domains) to return %v, got %v", want, parsed) + } +} diff --git a/dataset/parser/domains.go b/dataset/parser/domains.go new file mode 100644 index 0000000..31b38f9 --- /dev/null +++ b/dataset/parser/domains.go @@ -0,0 +1,40 @@ +package parser + +import ( + "bufio" + "io" + "net" + "strings" +) + +func init() { + domainsParsers = append(domainsParsers, domainsParser{}) +} + +type domainsParser struct{} + +func (domainsParser) CanHandle(line string) bool { + return isDomainName(line) && + !strings.ContainsRune(line, ' ') && + !strings.ContainsRune(line, ':') && + net.ParseIP(line) == nil +} + +func (domainsParser) ParseDomains(r io.Reader) (domains []string, ignored int, err error) { + scanner := bufio.NewScanner(r) + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if isComment(line) { + continue + } + if isDomainName(line) { + domains = append(domains, line) + continue + } + ignored++ + } + if err = scanner.Err(); err != nil { + return + } + return unique(domains), ignored, nil +} diff --git a/dataset/parser/domains_test.go b/dataset/parser/domains_test.go new file mode 100644 index 0000000..38c6bd7 --- /dev/null +++ b/dataset/parser/domains_test.go @@ -0,0 +1,31 @@ +package parser + +import ( + "reflect" + "sort" + "strings" + "testing" +) + +func TestParseDomains(t *testing.T) { + test := `# This is a comment +facebook.com +tiktok.com +bogus ignored +youtube.com` + want := []string{"facebook.com", "tiktok.com", "youtube.com"} + + parsed, ignored, err := ParseDomains(strings.NewReader(test)) + if err != nil { + t.Fatal(err) + return + } + + sort.Strings(parsed) + if !reflect.DeepEqual(parsed, want) { + t.Errorf("expected ParseDomains(domains) to return %v, got %v", want, parsed) + } + if ignored != 1 { + t.Errorf("expected 1 ignored, got %d", ignored) + } +} diff 
--git a/dataset/parser/hosts.go b/dataset/parser/hosts.go new file mode 100644 index 0000000..2512424 --- /dev/null +++ b/dataset/parser/hosts.go @@ -0,0 +1,41 @@ +package parser + +import ( + "bufio" + "io" + "net" + "strings" +) + +func init() { + RegisterDomainsParser(hostsParser{}) +} + +type hostsParser struct{} + +func (hostsParser) CanHandle(line string) bool { + part := strings.Fields(line) + return len(part) >= 2 && net.ParseIP(part[0]) != nil +} + +func (hostsParser) ParseDomains(r io.Reader) (domains []string, ignored int, err error) { + scanner := bufio.NewScanner(r) + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if isComment(line) { + continue + } + + part := strings.Fields(line) + if len(part) >= 2 && net.ParseIP(part[0]) != nil { + domains = append(domains, part[1:]...) + continue + } + + ignored++ + } + if err = scanner.Err(); err != nil { + return + } + return unique(domains), ignored, nil +} diff --git a/dataset/parser/hosts_test.go b/dataset/parser/hosts_test.go new file mode 100644 index 0000000..38a434d --- /dev/null +++ b/dataset/parser/hosts_test.go @@ -0,0 +1,38 @@ +package parser + +import ( + "reflect" + "sort" + "strings" + "testing" +) + +func TestParseHosts(t *testing.T) { + test := `## +# Host Database +# +# localhost is used to configure the loopback interface +# when the system is booting. Do not change this entry. 
+## +127.0.0.1 localhost dragon dragon.local dragon.maze.network +255.255.255.255 broadcasthost +::1 localhost +ff00::1 multicast +1.2.3.4 +` + want := []string{"broadcasthost", "dragon", "dragon.local", "dragon.maze.network", "localhost", "multicast"} + + parsed, ignored, err := ParseDomains(strings.NewReader(test)) + if err != nil { + t.Fatal(err) + return + } + + sort.Strings(parsed) + if !reflect.DeepEqual(parsed, want) { + t.Errorf("expected ParseDomains(hosts) to return %v, got %v", want, parsed) + } + if ignored != 1 { + t.Errorf("expected 1 ignored, got %d", ignored) + } +} diff --git a/dataset/parser/parser.go b/dataset/parser/parser.go new file mode 100644 index 0000000..e398f27 --- /dev/null +++ b/dataset/parser/parser.go @@ -0,0 +1,76 @@ +package parser + +import ( + "bufio" + "bytes" + "errors" + "io" + "log" + "strings" + + "github.com/miekg/dns" +) + +var ErrNoParser = errors.New("no suitable parser could be found") + +type Parser interface { + CanHandle(line string) bool +} + +type DomainsParser interface { + Parser + ParseDomains(io.Reader) (domains []string, ignored int, err error) +} + +var domainsParsers []DomainsParser + +func RegisterDomainsParser(parser DomainsParser) { + domainsParsers = append(domainsParsers, parser) +} + +func ParseDomains(r io.Reader) (domains []string, ignored int, err error) { + var ( + buffer = new(bytes.Buffer) + scanner = bufio.NewScanner(io.TeeReader(r, buffer)) + line string + parser DomainsParser + ) + for scanner.Scan() { + line = strings.TrimSpace(scanner.Text()) + if isComment(line) { + continue + } + for _, parser = range domainsParsers { + if parser.CanHandle(line) { + log.Printf("using parser %T", parser) + return parser.ParseDomains(io.MultiReader(buffer, r)) + } + } + break + } + return nil, 0, ErrNoParser +} + +func isComment(line string) bool { + return line == "" || line[0] == '#' || line[0] == '!' 
+} + +func isDomainName(name string) bool { + n, ok := dns.IsDomainName(name) + return n >= 2 && ok +} + +func unique(strings []string) []string { + if strings == nil { + return nil + } + v := make(map[string]struct{}) + for _, s := range strings { + v[s] = struct{}{} + } + o := make([]string, 0, len(v)) + for k := range v { + o = append(o, k) + } + return o +} diff --git a/dataset/parser/parser_test.go b/dataset/parser/parser_test.go new file mode 100644 index 0000000..fd469e2 --- /dev/null +++ b/dataset/parser/parser_test.go @@ -0,0 +1,31 @@ +package parser + +import ( + "reflect" + "sort" + "testing" +) + +func TestUnique(t *testing.T) { + tests := []struct { + Name string + Test []string + Want []string + }{ + {"nil", nil, nil}, + {"single", []string{"test"}, []string{"test"}}, + {"duplicate", []string{"test", "test"}, []string{"test"}}, + {"multiple", []string{"a", "a", "b", "b", "b", "c"}, []string{"a", "b", "c"}}, + } + for _, test := range tests { + t.Run(test.Name, func(it *testing.T) { + v := unique(test.Test) + if v != nil { + sort.Strings(v) + } + if !reflect.DeepEqual(v, test.Want) { + it.Errorf("expected unique(%v) to return %v, got %v", test.Test, test.Want, v) + } + }) + } +} diff --git a/dataset/storage.go b/dataset/storage.go new file mode 100644 index 0000000..a5366a8 --- /dev/null +++ b/dataset/storage.go @@ -0,0 +1,231 @@ +package dataset + +import ( + "bufio" + "bytes" + "fmt" + "io" + "io/fs" + "net" + "net/http" + "net/url" + "os" + "slices" + "strings" + "time" + + _ "github.com/mattn/go-sqlite3" // SQLite3 driver + "github.com/miekg/dns" +) + +type Storage interface { + Groups() ([]Group, error) + GroupByID(int64) (Group, error) + GroupByName(name string) (Group, error) + SaveGroup(*Group) error + DeleteGroup(Group) error + + Clients() (Clients, error) + ClientByID(int64) (Client, error) + ClientByIP(net.IP) (Client, error) + SaveClient(*Client) error + DeleteClient(Client) error + + Lists() ([]List, error) + ListByID(int64) (List, error) 
+ SaveList(*List) error + DeleteList(List) error +} + +type Group struct { + ID int64 `json:"id"` + Name string `json:"name" bstore:"nonzero,unique"` + IsEnabled bool `json:"is_enabled" bstore:"nonzero"` + Description string `json:"description"` + CreatedAt time.Time `json:"created_at" bstore:"nonzero"` + UpdatedAt time.Time `json:"updated_at" bstore:"nonzero"` + Storage Storage `json:"-" bstore:"-"` +} + +type Client struct { + ID int64 `json:"id"` + Network string `json:"network" bstore:"nonzero,index"` + IP string `json:"ip" bstore:"nonzero,unique IP+Mask"` + Mask int `json:"mask"` + Description string `json:"description"` + Groups []Group `json:"groups,omitempty" bstore:"-"` + CreatedAt time.Time `json:"created_at" bstore:"nonzero"` + UpdatedAt time.Time `json:"updated_at" bstore:"nonzero"` + Storage Storage `json:"-" bstore:"-"` +} + +type WithClient interface { + Client() (Client, error) +} + +type ClientGroup struct { + ID int64 `json:"id"` + ClientID int64 `json:"client_id" bstore:"ref Client,index"` + GroupID int64 `json:"group_id" bstore:"ref Group,index"` +} + +func (c *Client) ContainsIP(ip net.IP) bool { + ipnet := &net.IPNet{ + IP: net.ParseIP(c.IP), + Mask: net.CIDRMask(int(c.Mask), 32), + } + if ipnet.IP == nil { + return false + } + return ipnet.Contains(ip) +} + +func (c *Client) String() string { + ipnet := &net.IPNet{ + IP: net.ParseIP(c.IP), + Mask: net.CIDRMask(int(c.Mask), 32), + } + return ipnet.String() +} + +type Clients []Client + +func (cs Clients) ByIP(ip net.IP) *Client { + var candidates []*Client + for _, c := range cs { + if c.ContainsIP(ip) { + candidates = append(candidates, &c) + } + } + switch len(candidates) { + case 0: + return nil + case 1: + return candidates[0] + default: + slices.SortStableFunc(candidates, func(a, b *Client) int { + return int(b.Mask) - int(a.Mask) + }) + return candidates[0] + } +} + +const ( + ListTypeDomain = "domain" + ListTypeNetwork = "network" +) + +const ( + MinListRefresh = 1 * time.Minute + 
DefaultListRefresh = 30 * time.Minute +) + +type List struct { + ID int64 `json:"id"` + Type string `json:"type"` + Source string `json:"source"` + IsEnabled bool `json:"is_enabled"` + Permit bool `json:"permit"` + Groups []Group `json:"groups,omitempty" bstore:"-"` + Status int `json:"status"` + Comment string `json:"comment"` + Cache []byte `json:"cache"` + Refresh time.Duration `json:"refresh"` + LastModified time.Time `json:"last_modified"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + +func (list *List) Domains() (*DomainTree, error) { + if list.Type != ListTypeDomain { + return nil, nil + } + + var ( + tree = NewDomainList() + scan = bufio.NewScanner(bytes.NewReader(list.Cache)) + ) + for scan.Scan() { + line := strings.TrimSpace(scan.Text()) + if line == "" || line[0] == '#' { + continue + } + if labels, ok := dns.IsDomainName(line); ok && labels >= 2 { + tree.Add(line) + } + } + if err := scan.Err(); err != nil { + return nil, err + } + return tree, nil +} + +func (list *List) Update() (updated bool, err error) { + u, err := url.Parse(list.Source) + if err != nil { + return false, err + } + + switch u.Scheme { + case "", "file": + return list.updateFile(u.Path) + case "http", "https": + return list.updateHTTP(u.String()) + default: + return false, fmt.Errorf("dataset: don't know how to update %s sources", u.Scheme) + } +} + +func (list *List) updateFile(name string) (updated bool, err error) { + var info fs.FileInfo + if info, err = os.Stat(name); err != nil { + return + } else if info.IsDir() { + return false, fmt.Errorf("dataset: list %d: %q is a directory", list.ID, name) + } + if updated = info.ModTime().After(list.UpdatedAt); !updated { + return + } + list.Cache, _ = os.ReadFile(name) + return +} + +func (list *List) updateHTTP(url string) (updated bool, err error) { + if updated, err = list.shouldUpdateHTTP(url); err != nil || !updated { + return + } + + var response *http.Response + if response, err = 
http.DefaultClient.Get(url); err != nil { + return + } + defer response.Body.Close() + if list.Cache, err = io.ReadAll(response.Body); err != nil { + return + } + return true, nil +} + +func (list *List) shouldUpdateHTTP(url string) (updated bool, err error) { + var response *http.Response + if response, err = http.DefaultClient.Head(url); err != nil { + return + } + defer response.Body.Close() + + if value := response.Header.Get("Last-Modified"); value != "" { + var lastModified time.Time + if lastModified, err = time.Parse(http.TimeFormat, value); err == nil { + return lastModified.After(list.LastModified), nil + } + } + + // There are no headers that would indicate last-modified time, so assume we have to update: + return true, nil +} + +type ListGroup struct { + ID int64 `json:"id"` + ListID int64 `json:"list_id" bstore:"ref List,index"` + GroupID int64 `json:"group_id" bstore:"ref Group,index"` +} diff --git a/dataset/storage_bstore.go b/dataset/storage_bstore.go new file mode 100644 index 0000000..89168a2 --- /dev/null +++ b/dataset/storage_bstore.go @@ -0,0 +1,412 @@ +package dataset + +import ( + "context" + "errors" + "fmt" + "net" + "path/filepath" + "slices" + "strings" + "time" + + "git.maze.io/maze/styx/logger" + "github.com/mjl-/bstore" +) + +type bstoreStorage struct { + db *bstore.DB + path string +} + +func OpenBStore(name string) (Storage, error) { + if !filepath.IsAbs(name) { + var err error + if name, err = filepath.Abs(name); err != nil { + return nil, err + } + } + + ctx := context.Background() + db, err := bstore.Open(ctx, name, nil, + Group{}, + Client{}, + ClientGroup{}, + List{}, + ListGroup{}, + ) + if err != nil { + return nil, err + } + + var ( + s = &bstoreStorage{db: db, path: name} + defaultGroup Group + defaultClient4 Client + defaultClient6 Client + ) + + if defaultGroup, err = s.GroupByName("Default"); errors.Is(err, bstore.ErrAbsent) { + defaultGroup = Group{ + Name: "Default", + IsEnabled: true, + Description: "Default group", + 
} + if err = s.SaveGroup(&defaultGroup); err != nil { + return nil, err + } + } else if err != nil { + return nil, err + } + if defaultClient4, err = bstore.QueryDB[Client](ctx, db). + FilterEqual("Network", "ipv4"). + FilterFn(func(client Client) bool { + return net.ParseIP(client.IP).Equal(net.ParseIP("0.0.0.0")) && client.Mask == 0 + }).Get(); errors.Is(err, bstore.ErrAbsent) { + defaultClient4 = Client{ + Network: "ipv4", + IP: "0.0.0.0", + Mask: 0, + Description: "All IPv4 clients", + } + if err = s.SaveClient(&defaultClient4); err != nil { + return nil, err + } + if err = s.db.Insert(ctx, &ClientGroup{ClientID: defaultClient4.ID, GroupID: defaultGroup.ID}); err != nil { + return nil, err + } + } else if err != nil { + return nil, err + } + if defaultClient6, err = bstore.QueryDB[Client](ctx, db). + FilterEqual("Network", "ipv6"). + FilterFn(func(client Client) bool { + return net.ParseIP(client.IP).Equal(net.ParseIP("::")) && client.Mask == 0 + }).Get(); errors.Is(err, bstore.ErrAbsent) { + defaultClient6 = Client{ + Network: "ipv6", + IP: "::", + Mask: 0, + Description: "All IPv6 clients", + } + if err = s.SaveClient(&defaultClient6); err != nil { + return nil, err + } + if err = s.db.Insert(ctx, &ClientGroup{ClientID: defaultClient6.ID, GroupID: defaultGroup.ID}); err != nil { + return nil, err + } + } else if err != nil { + return nil, err + } + + // Start updater + NewUpdater(s) + + return s, nil +} + +func (s *bstoreStorage) log() logger.Structured { + return logger.StandardLog.Values(logger.Values{ + "storage": "bstore", + "storage_path": s.path, + }) +} + +func (s *bstoreStorage) Groups() ([]Group, error) { + var ( + ctx = context.Background() + query = bstore.QueryDB[Group](ctx, s.db) + groups = make([]Group, 0) + ) + for group := range query.All() { + groups = append(groups, group) + } + if err := query.Err(); err != nil && !errors.Is(err, bstore.ErrFinished) { + return nil, err + } + return groups, nil +} + +func (s *bstoreStorage) GroupByID(id 
int64) (Group, error) { + ctx := context.Background() + return bstore.QueryDB[Group](ctx, s.db).FilterID(id).Get() +} + +func (s *bstoreStorage) GroupByName(name string) (Group, error) { + ctx := context.Background() + return bstore.QueryDB[Group](ctx, s.db).FilterFn(func(group Group) bool { + return strings.EqualFold(group.Name, name) + }).Get() +} + +func (s *bstoreStorage) SaveGroup(group *Group) (err error) { + ctx := context.Background() + group.UpdatedAt = time.Now().UTC() + if group.CreatedAt.Equal(time.Time{}) { + group.CreatedAt = group.UpdatedAt + err = s.db.Insert(ctx, group) + } else { + err = s.db.Update(ctx, group) + } + if err != nil { + return fmt.Errorf("dataset: save group %s failed: %w", group.Name, err) + } + return nil +} + +func (s *bstoreStorage) DeleteGroup(group Group) (err error) { + ctx := context.Background() + tx, err := s.db.Begin(ctx, true) + if err != nil { + return err + } + if _, err = bstore.QueryTx[ClientGroup](tx).FilterEqual("GroupID", group.ID).Delete(); err != nil { + return + } + if _, err = bstore.QueryTx[ListGroup](tx).FilterEqual("GroupID", group.ID).Delete(); err != nil { + return + } + if err = tx.Delete(group); err != nil { + return + } + return tx.Commit() +} + +func (s *bstoreStorage) Clients() (Clients, error) { + var ( + ctx = context.Background() + query = bstore.QueryDB[Client](ctx, s.db) + clients = make(Clients, 0) + ) + for client := range query.All() { + clients = append(clients, client) + } + if err := query.Err(); err != nil && !errors.Is(err, bstore.ErrFinished) { + return nil, err + } + return clients, nil +} + +func (s *bstoreStorage) ClientByID(id int64) (Client, error) { + ctx := context.Background() + client, err := bstore.QueryDB[Client](ctx, s.db).FilterID(id).Get() + if err != nil { + return client, err + } + return s.clientResolveGroups(ctx, client) +} + +func (s *bstoreStorage) ClientByIP(ip net.IP) (Client, error) { + if ip == nil { + return Client{}, ErrNotExist{Object: "client"} + } + var ( + 
ctx = context.Background() + clients Clients + network string + ) + if ip4 := ip.To4(); ip4 != nil { + network = "ipv4" + } else if ip6 := ip.To16(); ip6 != nil { + network = "ipv6" + } + if network == "" { + return Client{}, ErrNotExist{Object: "client"} + } + for client, err := range bstore.QueryDB[Client](ctx, s.db). + FilterEqual("Network", network). + FilterFn(func(client Client) bool { + return client.ContainsIP(ip) + }).All() { + if err != nil { + return Client{}, err + } + clients = append(clients, client) + } + + var client Client + switch len(clients) { + case 0: + return Client{}, ErrNotExist{Object: "client"} + case 1: + client = clients[0] + default: + slices.SortStableFunc(clients, func(a, b Client) int { + return int(b.Mask) - int(a.Mask) + }) + client = clients[0] + } + return s.clientResolveGroups(ctx, client) +} + +func (s *bstoreStorage) clientResolveGroups(ctx context.Context, client Client) (Client, error) { + for clientGroup, err := range bstore.QueryDB[ClientGroup](ctx, s.db).FilterEqual("ClientID", client.ID).All() { + if err != nil { + return Client{}, err + } + if group, err := s.GroupByID(clientGroup.GroupID); err == nil { + client.Groups = append(client.Groups, group) + } + } + return client, nil +} + +func (s *bstoreStorage) SaveClient(client *Client) (err error) { + log := s.log() + ctx := context.Background() + client.UpdatedAt = time.Now().UTC() + + tx, err := s.db.Begin(ctx, true) + if err != nil { + return err + } + + log = log.Values(logger.Values{"ip": client.IP, "mask": client.Mask, "description": client.Description}) + if client.CreatedAt.Equal(time.Time{}) { + log.Debug("Create client") + client.CreatedAt = client.UpdatedAt + if err = tx.Insert(client); err != nil { + return fmt.Errorf("dataset: client insert failed: %w", err) + } + } else { + log.Debug("Update client") + if err = tx.Update(client); err != nil { + return fmt.Errorf("dataset: client update failed: %w", err) + } + } + + var deleted int + if deleted, err = 
bstore.QueryTx[ClientGroup](tx).FilterEqual("ClientID", client.ID).Delete(); err != nil { + return fmt.Errorf("dataset: client groups delete failed: %w", err) + } + log.Debugf("Deleted %d groups", deleted) + log.Debugf("Linking %d groups", len(client.Groups)) + for _, group := range client.Groups { + if err = tx.Insert(&ClientGroup{ClientID: client.ID, GroupID: group.ID}); err != nil { + return fmt.Errorf("dataset: client groups insert failed: %w", err) + } + } + + return tx.Commit() +} + +func (s *bstoreStorage) DeleteClient(client Client) (err error) { + ctx := context.Background() + tx, err := s.db.Begin(ctx, true) + if err != nil { + return err + } + if _, err = bstore.QueryTx[ClientGroup](tx).FilterEqual("ClientID", client.ID).Delete(); err != nil { + return + } + if err = tx.Delete(client); err != nil { + return + } + return tx.Commit() +} + +func (s *bstoreStorage) Lists() ([]List, error) { + var ( + ctx = context.Background() + query = bstore.QueryDB[List](ctx, s.db) + lists = make([]List, 0) + ) + for list := range query.All() { + lists = append(lists, list) + } + if err := query.Err(); err != nil && !errors.Is(err, bstore.ErrFinished) { + return nil, err + } + return lists, nil +} + +func (s *bstoreStorage) ListByID(id int64) (List, error) { + ctx := context.Background() + list, err := bstore.QueryDB[List](ctx, s.db).FilterID(id).Get() + if err != nil { + return list, err + } + return s.listResolveGroups(ctx, list) +} + +func (s *bstoreStorage) listResolveGroups(ctx context.Context, list List) (List, error) { + for listGroup, err := range bstore.QueryDB[ListGroup](ctx, s.db).FilterEqual("ListID", list.ID).All() { + if err != nil { + return List{}, err + } + if group, err := s.GroupByID(listGroup.GroupID); err == nil { + list.Groups = append(list.Groups, group) + } + } + return list, nil +} + +func (s *bstoreStorage) SaveList(list *List) (err error) { + if list.Type != ListTypeDomain && list.Type != ListTypeNetwork { + return fmt.Errorf("storage: unknown 
list type %q", list.Type) + } + if list.Refresh == 0 { + list.Refresh = DefaultListRefresh + } else if list.Refresh < MinListRefresh { + list.Refresh = MinListRefresh + } + list.UpdatedAt = time.Now().UTC() + + ctx := context.Background() + tx, err := s.db.Begin(ctx, true) + if err != nil { + return err + } + + log := s.log() + log = log.Values(logger.Values{ + "type": list.Type, + "source": list.Source, + "is_enabled": list.IsEnabled, + "status": list.Status, + "cache": len(list.Cache), + "refresh": list.Refresh, + }) + + if list.CreatedAt.Equal(time.Time{}) { + log.Debug("Creating list") + list.CreatedAt = list.UpdatedAt + if err = tx.Insert(list); err != nil { + return fmt.Errorf("dataset: list insert failed: %w", err) + } + } else { + log.Debug("Updating list") + if err = tx.Update(list); err != nil { + return fmt.Errorf("dataset: list update failed: %w", err) + } + } + + var deleted int + if deleted, err = bstore.QueryTx[ListGroup](tx).FilterEqual("ListID", list.ID).Delete(); err != nil { + return fmt.Errorf("dataset: list groups delete failed: %w", err) + } + log.Debugf("Deleted %d groups", deleted) + log.Debugf("Linking %d groups", len(list.Groups)) + for _, group := range list.Groups { + if err = tx.Insert(&ListGroup{ListID: list.ID, GroupID: group.ID}); err != nil { + return fmt.Errorf("dataset: list groups insert failed: %w", err) + } + } + + return tx.Commit() +} + +func (s *bstoreStorage) DeleteList(list List) (err error) { + ctx := context.Background() + tx, err := s.db.Begin(ctx, true) + if err != nil { + return err + } + if _, err = bstore.QueryTx[ListGroup](tx).FilterEqual("ListID", list.ID).Delete(); err != nil { + return + } + if err = tx.Delete(list); err != nil { + return + } + return tx.Commit() +} diff --git a/dataset/updater.go b/dataset/updater.go new file mode 100644 index 0000000..ee6667c --- /dev/null +++ b/dataset/updater.go @@ -0,0 +1,226 @@ +package dataset + +import ( + "bytes" + "io" + "net/http" + "net/url" + "os" + "sync" + "time" 
+ + "git.maze.io/maze/styx/logger" +) + +type Updater struct { + storage Storage + lists sync.Map // map[int64]List + updaters sync.Map // map[int64]*updaterJob + done chan struct{} +} + +func NewUpdater(storage Storage) *Updater { + u := &Updater{ + storage: storage, + done: make(chan struct{}, 1), + } + go u.refresh() + return u +} + +func (u *Updater) Close() error { + select { + case <-u.done: + return nil + default: + close(u.done) + return nil + } +} + +func (u *Updater) refresh() { + check := time.NewTicker(time.Second) + defer check.Stop() + + var ( + log = logger.StandardLog + ) + for { + select { + case <-u.done: + log.Debug("Updater closing, stopping updaters...") + u.updaters.Range(func(key, value any) bool { + if value != nil { + close(value.(*updaterJob).done) + } + return true + }) + return + + case now := <-check.C: + u.check(now, log) + } + } +} + +func (u *Updater) check(now time.Time, log logger.Structured) (wait time.Duration) { + log.Trace("Checking lists") + lists, err := u.storage.Lists() + if err != nil { + log.Err(err).Error("Updater can't retrieve lists") + return -1 + } + + var missing = make(map[int64]bool) + u.lists.Range(func(key, _ any) bool { + log.Tracef("List %d has updater running", key) + missing[key.(int64)] = true + return true + }) + for _, list := range lists { + log.Tracef("List %d is active: %t", list.ID, list.IsEnabled) + if !list.IsEnabled { + continue + } + delete(missing, list.ID) + if _, exists := u.lists.Load(list.ID); !exists { + u.lists.Store(list.ID, list) + updater := newUpdaterJob(u.storage, &list) + u.updaters.Store(list.ID, updater) + } + } + + for id := range missing { + log.Tracef("List %d has updater running, but is no longer active, reaping...", id) + if updater, ok := u.updaters.Load(id); ok { + close(updater.(*updaterJob).done) + u.updaters.Delete(id) + } + } + + return +} + +type updaterJob struct { + storage Storage + list *List + done chan struct{} +} + +func newUpdaterJob(storage Storage, list *List) 
*updaterJob { + job := &updaterJob{ + storage: storage, + list: list, + done: make(chan struct{}, 1), + } + go job.loop() + return job +} + +func (job *updaterJob) loop() { + var ( + ticker = time.NewTicker(job.list.Refresh) + first = time.After(0) + now time.Time + log = logger.StandardLog.Values(logger.Values{ + "list": job.list.ID, + "type": job.list.Type, + }) + ) + defer ticker.Stop() + for { + select { + case <-job.done: + log.Debug("List updater stopping") + return + + case now = <-ticker.C: + case now = <-first: + } + + log.Debug("List updater running") + if update, err := job.run(now); err != nil { + log.Err(err).Error("List updater failed") + } else if update { + if err = job.storage.SaveList(job.list); err != nil { + log.Err(err).Error("List updater save failed") + } + } + } +} + +// run this updater +func (job *updaterJob) run(now time.Time) (update bool, err error) { + u, err := url.Parse(job.list.Source) + if err != nil { + return false, err + } + + log := logger.StandardLog.Values(logger.Values{ + "list": job.list.ID, + "source": job.list.Source, + }) + if u.Scheme == "" || u.Scheme == "file" { + log.Debug("Updating list from file") + return job.updateFile(u.Path) + } + log.Debug("Updating list from URL") + return job.updateHTTP(u) +} + +func (job *updaterJob) updateFile(name string) (update bool, err error) { + var b []byte + if b, err = os.ReadFile(name); err != nil { + return + } + if update = !bytes.Equal(b, job.list.Cache); update { + job.list.Cache = b + } + return +} + +func (job *updaterJob) updateHTTP(location *url.URL) (update bool, err error) { + if update, err = job.shouldUpdateHTTP(location); err != nil || !update { + return + } + var ( + req *http.Request + res *http.Response + ) + if req, err = http.NewRequest(http.MethodGet, location.String(), nil); err != nil { + return + } + if res, err = http.DefaultClient.Do(req); err != nil { + return + } + defer res.Body.Close() + + if job.list.Cache, err = io.ReadAll(res.Body); err != nil { + 
return + } + return true, nil +} + +func (job *updaterJob) shouldUpdateHTTP(location *url.URL) (update bool, err error) { + if len(job.list.Cache) == 0 { + // Nothing cached, please update. + return true, nil + } + + var ( + req *http.Request + res *http.Response + ) + if req, err = http.NewRequest(http.MethodHead, location.String(), nil); err != nil { + return + } + if res, err = http.DefaultClient.Do(req); err != nil { + return + } + defer res.Body.Close() + + if lastModified, err := time.Parse(http.TimeFormat, res.Header.Get("Last-Modified")); err == nil { + return lastModified.After(job.list.UpdatedAt), nil + } + return true, nil // not sure, no Last-Modified, so let's update? +} diff --git a/go.mod b/go.mod index 9fbaccd..c941b30 100644 --- a/go.mod +++ b/go.mod @@ -7,6 +7,7 @@ require ( github.com/hashicorp/hcl/v2 v2.24.0 github.com/mattn/go-sqlite3 v1.14.32 github.com/miekg/dns v1.1.68 + github.com/mjl-/bstore v0.0.10 github.com/open-policy-agent/opa v1.9.0 github.com/rs/zerolog v1.34.0 github.com/sirupsen/logrus v1.9.4-0.20230606125235-dd1b4c2e81af @@ -54,6 +55,7 @@ require ( github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 // indirect github.com/yashtewari/glob-intersection v0.2.0 // indirect github.com/zclconf/go-cty v1.16.3 // indirect + go.etcd.io/bbolt v1.4.3 // indirect go.opentelemetry.io/auto/sdk v1.2.1 // indirect go.opentelemetry.io/otel v1.38.0 // indirect go.opentelemetry.io/otel/metric v1.38.0 // indirect diff --git a/go.sum b/go.sum index 558c1d0..a43c774 100644 --- a/go.sum +++ b/go.sum @@ -98,6 +98,8 @@ github.com/miekg/dns v1.1.68 h1:jsSRkNozw7G/mnmXULynzMNIsgY2dHC8LO6U6Ij2JEA= github.com/miekg/dns v1.1.68/go.mod h1:fujopn7TB3Pu3JM69XaawiU0wqjpL9/8xGop5UrTPps= github.com/mitchellh/go-wordwrap v1.0.1 h1:TLuKupo69TCn6TQSyGxwI1EblZZEsQ0vMlAFQflz0v0= github.com/mitchellh/go-wordwrap v1.0.1/go.mod h1:R62XHJLzvMFRBbcrT7m7WgmE1eOyTSsCt+hzestvNj0= +github.com/mjl-/bstore v0.0.10 h1:fYLQy3EdgXvRHoa8Q3sXMAjZf+uQLRbsh9rYjGep/t4= 
+github.com/mjl-/bstore v0.0.10/go.mod h1:QzqlAZAVRKwyojCRd9v25viFsMxK5UmIbdxgEyHdK6c= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= github.com/open-policy-agent/opa v1.9.0 h1:QWFNwbcc29IRy0xwD3hRrMc/RtSersLY1Z6TaID3vgI= @@ -152,6 +154,8 @@ github.com/zclconf/go-cty v1.16.3 h1:osr++gw2T61A8KVYHoQiFbFd1Lh3JOCXc/jFLJXKTxk github.com/zclconf/go-cty v1.16.3/go.mod h1:VvMs5i0vgZdhYawQNq5kePSpLAoz8u1xvZgrPIxfnZE= github.com/zclconf/go-cty-debug v0.0.0-20240509010212-0d6042c53940 h1:4r45xpDWB6ZMSMNJFMOjqrGHynW3DIBuR2H9j0ug+Mo= github.com/zclconf/go-cty-debug v0.0.0-20240509010212-0d6042c53940/go.mod h1:CmBdvvj3nqzfzJ6nTCIwDTPZ56aVGvDrmztiO5g3qrM= +go.etcd.io/bbolt v1.4.3 h1:dEadXpI6G79deX5prL3QRNP6JB8UxVkqo4UPnHaNXJo= +go.etcd.io/bbolt v1.4.3/go.mod h1:tKQlpPaYCVFctUIgFKFnAlvbmB3tpy1vkTnDWohtc0E= go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64= go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y= go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.63.0 h1:RbKq8BG0FI8OiXhBfcRtqqHcZcka+gU3cskNuf05R18= diff --git a/internal/netutil/conn.go b/internal/netutil/conn.go index 7f1379e..c9a40c9 100644 --- a/internal/netutil/conn.go +++ b/internal/netutil/conn.go @@ -8,6 +8,8 @@ import ( "sync/atomic" "syscall" "time" + + "git.maze.io/maze/styx/logger" ) // BufferedConn uses byte buffers for Read and Write operations on a [net.Conn]. 
@@ -123,10 +125,13 @@ type AcceptOnce struct { } func (listener *AcceptOnce) Accept() (net.Conn, error) { + log := logger.StandardLog.Value("client", listener.Conn.RemoteAddr().String()) if listener.once.Load() { + log.Trace("Accept already happened, responding EOF") return nil, io.EOF } listener.once.Store(true) + log.Trace("Accept client") return listener.Conn, nil } diff --git a/internal/timeutil/time.go b/internal/timeutil/time.go new file mode 100644 index 0000000..28c686b --- /dev/null +++ b/internal/timeutil/time.go @@ -0,0 +1,74 @@ +package timeutil + +import "time" + +var ( + validTimeLayouts = []string{ + "15:04:05.999999999", + "15:04:05", + "15:04", + "3:04:05PM", + "3:04PM", + "3PM", + } +) + +type Time struct { + Hour int + Minute int + Second int + Nanosecond int +} + +func ParseTime(value string) (Time, error) { + var t time.Time + for _, layout := range validTimeLayouts { + var err error + if t, err = time.Parse(layout, value); err == nil { + return Time{ + Hour: t.Hour(), + Minute: t.Minute(), + Second: t.Second(), + Nanosecond: t.Nanosecond(), + }, nil + } + } + return Time{}, &time.ParseError{ + Value: value, + Message: "invalid time", + } +} + +func Now() Time { + t := time.Now() + return Time{ + Hour: t.Hour(), + Minute: t.Minute(), + Second: t.Second(), + Nanosecond: t.Nanosecond(), + } +} + +func (t Time) After(other Time) bool { + return other.Before(t) +} + +func (t Time) Before(other Time) bool { + if t.Hour == other.Hour { + if t.Minute == other.Minute { + if t.Second == other.Second { + return t.Nanosecond < other.Nanosecond + } + return t.Second < other.Second + } + return t.Minute < other.Minute + } + return t.Hour < other.Hour +} + +func (t Time) Eq(other Time) bool { + return t.Hour == other.Hour && + t.Minute == other.Minute && + t.Second == other.Second && + t.Nanosecond == other.Nanosecond +} diff --git a/policy/func.go b/policy/func.go index c039a66..85024ac 100644 --- a/policy/func.go +++ b/policy/func.go @@ -15,17 +15,25 @@ 
import ( "github.com/open-policy-agent/opa/v1/types" "git.maze.io/maze/styx/dataset" + "git.maze.io/maze/styx/internal/timeutil" "git.maze.io/maze/styx/logger" ) -var netLookupIPAddrDecl = types.NewFunction( +var lookupIPAddrFunc = ®o.Function{ + Name: "styx.lookup_ip_addr", + Decl: lookupIPAddrDecl, + Memoize: true, + Nondeterministic: true, +} + +var lookupIPAddrDecl = types.NewFunction( types.Args( types.Named("name", types.S).Description("Host name to lookup"), ), types.Named("result", types.SetOfStr).Description("set(string) of IP address"), ) -func netLookupIPAddrImpl(bc rego.BuiltinContext, nameTerm *ast.Term) (*ast.Term, error) { +func lookupIPAddr(bc rego.BuiltinContext, nameTerm *ast.Term) (*ast.Term, error) { log := logger.StandardLog.Value("func", "styx.lookup_ip_addr") log.Trace("Call function") @@ -61,6 +69,57 @@ func netLookupIPAddrImpl(bc rego.BuiltinContext, nameTerm *ast.Term) (*ast.Term, return ast.SetTerm(terms...), nil } +var timebetweenFunc = ®o.Function{ + Name: "styx.time_between", + Decl: timeBetweenDecl, + Nondeterministic: false, +} + +var timeBetweenDecl = types.NewFunction( + types.Args( + types.Named("start", types.S).Description("Start time"), + types.Named("end", types.S).Description("End time"), + ), + types.Named("result", types.B).Description("`true` if the current local time is between `start` and `end`"), +) + +func timeBetween(bc rego.BuiltinContext, startTerm, endTerm *ast.Term) (*ast.Term, error) { + log := logger.StandardLog.Value("func", "styx.time_between") + log.Trace("Call function") + + start, err := parseTimeTerm(startTerm) + if err != nil { + log.Err(err).Debug("Invalid start time") + return nil, err + } + end, err := parseTimeTerm(endTerm) + if err != nil { + log.Err(err).Debug("Invalid end time") + return nil, err + } + + now := timeutil.Now() + if start.Before(end) { + return ast.BooleanTerm((now.Eq(start) || now.After(start)) && now.Before(end)), nil + } + return ast.BooleanTerm(now.Eq(end) || now.After(end) || 
now.Before(start)), nil +} + +func parseTimeTerm(term *ast.Term) (timeutil.Time, error) { + timeArg, ok := term.Value.(ast.String) + if !ok { + return timeutil.Time{}, errors.New("expected string argument") + } + return timeutil.ParseTime(strings.Trim(timeArg.String(), `"`)) +} + +var domainContainsFunc = ®o.Function{ + Name: "styx.domains_contain", + Decl: domainContainsDecl, + Memoize: true, + Nondeterministic: true, +} + var domainContainsDecl = types.NewFunction( types.Args( types.Named("list", types.S).Description("Domain list to check against"), @@ -69,8 +128,8 @@ var domainContainsDecl = types.NewFunction( types.Named("result", types.B).Description("`true` if `name` is contained within `list`"), ) -func domainContainsImpl(bc rego.BuiltinContext, listTerm, nameTerm *ast.Term) (*ast.Term, error) { - log := logger.StandardLog.Value("func", "styx.in_domains") +func domainContains(bc rego.BuiltinContext, listTerm, nameTerm *ast.Term) (*ast.Term, error) { + log := logger.StandardLog.Value("func", "styx.domains_contain") log.Trace("Call function") list, err := parseDomainListTerm(listTerm) @@ -91,6 +150,13 @@ func domainContainsImpl(bc rego.BuiltinContext, listTerm, nameTerm *ast.Term) (* return ast.BooleanTerm(list.Contains(name)), nil } +var networkContainsFunc = ®o.Function{ + Name: "styx.networks_contain", + Decl: networkContainsDecl, + Memoize: true, + Nondeterministic: true, +} + var networkContainsDecl = types.NewFunction( types.Args( types.Named("list", types.S).Description("Network list to check against"), @@ -99,8 +165,8 @@ var networkContainsDecl = types.NewFunction( types.Named("result", types.B).Description("`true` if `ip` is contained within `list`"), ) -func networkContainsImpl(bc rego.BuiltinContext, listTerm, ipTerm *ast.Term) (*ast.Term, error) { - log := logger.StandardLog.Value("func", "styx.in_networks") +func networkContains(bc rego.BuiltinContext, listTerm, ipTerm *ast.Term) (*ast.Term, error) { + log := logger.StandardLog.Value("func", 
"styx.networks_contain") list, err := parseNetworkListTerm(listTerm) if err != nil { diff --git a/policy/handler.go b/policy/handler.go index aaa3761..f32c729 100644 --- a/policy/handler.go +++ b/policy/handler.go @@ -1,9 +1,12 @@ package policy import ( + "bufio" + "crypto/tls" "net" "net/http" + "git.maze.io/maze/styx/ca" "git.maze.io/maze/styx/internal/netutil" "git.maze.io/maze/styx/logger" proxy "git.maze.io/maze/styx/proxy" @@ -24,6 +27,7 @@ func NewRequestHandler(p *Policy) proxy.RequestHandler { log.Err(err).Error("Error generating response") return nil, nil } + log.Debug("Replacing HTTP response from policy") return nil, r }) } @@ -47,21 +51,52 @@ func NewDialHandler(p *Policy) proxy.DialHandler { return nil, nil } - c := netutil.NewLoopback() + // Create a fake loopback connection + pipe := netutil.NewLoopback() go func(c net.Conn) { - s := &http.Server{ - Handler: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - r.Write(w) - }), + defer func() { _ = c.Close() }() + if req.URL.Scheme == "https" || req.URL.Scheme == "wss" || netutil.Port(req.URL.Host) == 443 { + c = maybeUpgradeToTLS(c, ctx, req, log) } - _ = s.Serve(&netutil.AcceptOnce{Conn: c}) - }(c.Server) - return c.Client, nil + br := bufio.NewReader(c) + if _, err := http.ReadRequest(br); err != nil { + log.Err(err).Warn("Malformed HTTP request in MITM connection") + } + _ = r.Write(c) + }(pipe.Server) + + return pipe.Client, nil }) } +func maybeUpgradeToTLS(c net.Conn, ctx proxy.Context, req *http.Request, log logger.Structured) net.Conn { + var ca ca.CertificateAuthority + if caCtx, ok := ctx.(proxy.WithCertificateAuthority); ok { + ca = caCtx.CertificateAuthority() + } + if ca == nil { + return c + } + + secure := tls.Server(c, &tls.Config{ + GetCertificate: func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) { + log.Values(logger.Values{ + "cn": req.URL.Host, + "names": hello.ServerName, + }).Debug("Requesting certificate from CA") + return 
ca.GetCertificate(netutil.Host(req.URL.Host), []string{hello.ServerName}, nil) + }, + NextProtos: []string{"http/1.1"}, + }) + if err := secure.Handshake(); err != nil { + log.Err(err).Warn("Failed to pretend secure HTTP") + return c + } + return secure +} + func NewForwardHandler(p *Policy) proxy.ForwardHandler { log := logger.StandardLog.Value("policy", p.name) return proxy.ForwardHandlerFunc(func(ctx proxy.Context, req *http.Request) (*http.Response, error) { @@ -72,7 +107,15 @@ func NewForwardHandler(p *Policy) proxy.ForwardHandler { log.Err(err).Error("Error evaulating policy") return nil, nil } - return result.Response(ctx) + r, err := result.Response(ctx) + if err != nil { + log.Err(err).Error("Error generating response") + return nil, err + } + if r != nil { + log.Debug("Replacing HTTP response from policy") + } + return r, nil }) } @@ -80,6 +123,7 @@ func NewResponseHandler(p *Policy) proxy.ResponseHandler { log := logger.StandardLog.Value("policy", p.name) return proxy.ResponseHandlerFunc(func(ctx proxy.Context) *http.Response { input := NewInputFromResponse(ctx, ctx.Response()) + input.logValues(log).Trace("Running response handler") result, err := p.Query(input) if err != nil { log.Err(err).Error("Error evaulating policy") @@ -90,6 +134,9 @@ func NewResponseHandler(p *Policy) proxy.ResponseHandler { log.Err(err).Error("Error generating response") return nil } + if r != nil { + log.Debug("Replacing HTTP response from policy") + } return r }) } diff --git a/policy/input.go b/policy/input.go index 9cf04f6..84327c6 100644 --- a/policy/input.go +++ b/policy/input.go @@ -10,19 +10,26 @@ import ( "net/url" "strconv" + "git.maze.io/maze/styx/dataset" "git.maze.io/maze/styx/internal/netutil" "git.maze.io/maze/styx/logger" + proxy "git.maze.io/maze/styx/proxy" ) // Input represents the input to the policy query. 
type Input struct { - Client *Client `json:"client"` - TLS *TLS `json:"tls"` - Request *Request `json:"request"` - Response *Response `json:"response"` + Context map[string]any `json:"context"` + Client *Client `json:"client"` + Groups []*Group `json:"groups"` + TLS *TLS `json:"tls"` + Request *Request `json:"request"` + Response *Response `json:"response"` } func (i *Input) logValues(log logger.Structured) logger.Structured { + if i.Context != nil { + log = log.Values(i.Context) + } log = i.Client.logValues(log) log = i.TLS.logValues(log) log = i.Request.logValues(log) @@ -34,10 +41,29 @@ func NewInputFromConn(c net.Conn) *Input { if c == nil { return new(Input) } - return &Input{ - Client: NewClientFromConn(c), - TLS: NewTLSFromConn(c), + + input := &Input{ + Context: make(map[string]any), + Client: NewClientFromConn(c), + TLS: NewTLSFromConn(c), } + + if wcl, ok := c.(dataset.WithClient); ok { + client, err := wcl.Client() + if err == nil { + input.Context["client_id"] = client.ID + input.Context["client_description"] = client.Description + input.Context["groups"] = client.Groups + } + } + + if ctx, ok := c.(proxy.Context); ok { + input.Context["local"] = NewClientFromAddr(ctx.LocalAddr()) + input.Context["bytes_rx"] = ctx.BytesRead() + input.Context["bytes_tx"] = ctx.BytesSent() + } + + return input } func NewInputFromRequest(c net.Conn, r *http.Request) *Input { @@ -131,6 +157,10 @@ func NewClientFromAddr(addr net.Addr) *Client { } } +type Group struct { + Name string `json:"name"` +} + type TLS struct { Version string `json:"version"` CipherSuite string `json:"cipher_suite"` diff --git a/policy/policy.go b/policy/policy.go index 3015e36..6db76c5 100644 --- a/policy/policy.go +++ b/policy/policy.go @@ -67,24 +67,10 @@ func newRego(option func(*rego.Rego), pkg string) []func(*rego.Rego) { rego.Query("data." 
+ pkg), rego.Strict(true), rego.Capabilities(capabilities), - rego.Function2(®o.Function{ - Name: "styx.in_domains", - Decl: domainContainsDecl, - Memoize: true, - Nondeterministic: true, - }, domainContainsImpl), - rego.Function2(®o.Function{ - Name: "styx.in_networks", - Decl: networkContainsDecl, - Memoize: true, - Nondeterministic: true, - }, networkContainsImpl), - rego.Function1(®o.Function{ - Name: "styx.lookup_ip_addr", // override builtin - Decl: netLookupIPAddrDecl, - Memoize: true, - Nondeterministic: true, - }, netLookupIPAddrImpl), + rego.Function2(domainContainsFunc, domainContains), + rego.Function2(networkContainsFunc, networkContains), + rego.Function1(lookupIPAddrFunc, lookupIPAddr), + rego.Function2(timebetweenFunc, timeBetween), rego.PrintHook(printHook{}), option, } @@ -128,16 +114,20 @@ func (r *Result) Response(ctx proxy.Context) (*http.Response, error) { switch { case r.Redirect != "": + log.Value("location", r.Redirect).Trace("Creating a HTTP redirect response") response := proxy.NewResponse(http.StatusFound, nil, ctx.Request()) response.Header.Set("Server", "styx") response.Header.Set(proxy.HeaderLocation, r.Redirect) return response, nil case r.Template != "": + log = log.Value("template", r.Template) + log.Trace("Creating a HTTP template response") + b := new(bytes.Buffer) t, err := template.New(filepath.Base(r.Template)).ParseFiles(r.Template) if err != nil { - log.Value("template", r.Template).Err(err).Warn("Error loading template in response") + log.Err(err).Warn("Error loading template in response") return nil, err } t = t.Funcs(template.FuncMap{ @@ -149,7 +139,7 @@ func (r *Result) Response(ctx proxy.Context) (*http.Response, error) { "Response": ctx.Response(), "Errors": r.Errors, }); err != nil { - log.Value("template", r.Template).Err(err).Warn("Error rendering template response") + log.Err(err).Warn("Error rendering template response") return nil, err } @@ -159,46 +149,34 @@ func (r *Result) Response(ctx proxy.Context) 
(*http.Response, error) { return response, nil case r.Reject > 0: + log.Value("code", r.Reject).Trace("Creating a HTTP reject response") body := io.NopCloser(bytes.NewBufferString(http.StatusText(r.Reject))) response := proxy.NewResponse(r.Reject, body, ctx.Request()) response.Header.Set(proxy.HeaderContentType, "text/plain") return response, nil case r.Permit != nil && !*r.Permit: + log.Trace("Creating a HTTP reject response due to explicit not permit") body := io.NopCloser(bytes.NewBufferString(http.StatusText(http.StatusForbidden))) response := proxy.NewResponse(http.StatusForbidden, body, ctx.Request()) response.Header.Set(proxy.HeaderContentType, "text/plain") return response, nil default: + log.Trace("Not creating a HTTP response") return nil, nil } } func (p *Policy) Query(input *Input) (*Result, error) { - /* - e := json.NewEncoder(os.Stdout) - e.SetIndent("", " ") - e.Encode(doc) - */ - log := logger.StandardLog.Value("policy", p.name) log.Trace("Evaluating policy") - r := rego.New(append(p.options, rego.Input(input))...) - - ctx := context.Background() - /* - query, err := p.rego.PrepareForEval(ctx) - if err != nil { - return nil, err - } - rs, err := query.Eval(ctx, rego.EvalInput(input)) - if err != nil { - return nil, err - } - */ - rs, err := r.Eval(ctx) + var ( + rego = rego.New(append(p.options, rego.Input(input))...) + ctx = context.Background() + rs, err = rego.Eval(ctx) + ) if err != nil { return nil, err } @@ -208,6 +186,12 @@ func (p *Policy) Query(input *Input) (*Result, error) { result := &Result{} for _, expr := range rs[0].Expressions { if m, ok := expr.Value.(map[string]any); ok { + // Remove private variables. 
+ for k := range m { + if len(k) > 0 && k[0] == '_' { + delete(m, k) + } + } log.Values(m).Trace("Policy result expression") if err = mapstructure.Decode(m, result); err != nil { return nil, err diff --git a/proxy/context.go b/proxy/context.go index b743acf..a984be8 100644 --- a/proxy/context.go +++ b/proxy/context.go @@ -14,6 +14,8 @@ import ( "sync/atomic" "time" + "git.maze.io/maze/styx/ca" + "git.maze.io/maze/styx/dataset" "git.maze.io/maze/styx/logger" ) @@ -42,6 +44,13 @@ type Context interface { // Response is the response that will be sent back to the client. Response() *http.Response + + // Client group. + Client() (dataset.Client, error) +} + +type WithCertificateAuthority interface { + CertificateAuthority() ca.CertificateAuthority } type countingReader struct { @@ -80,6 +89,9 @@ type proxyContext struct { req *http.Request res *http.Response idleTimeout time.Duration + ca ca.CertificateAuthority + storage dataset.Storage + client dataset.Client } // NewContext returns an initialized context for the provided [net.Conn]. 
@@ -218,4 +230,28 @@ func (c *proxyContext) WriteHeader(code int) { //return c.res.Header.Write(c) } +func (c *proxyContext) CertificateAuthority() ca.CertificateAuthority { + return c.ca +} + +func (c *proxyContext) Client() (dataset.Client, error) { + if c.storage == nil { + return dataset.Client{}, dataset.ErrNotExist{Object: "client"} + } + if !c.client.CreatedAt.Equal(time.Time{}) { + return c.client, nil + } + + var err error + switch addr := c.Conn.RemoteAddr().(type) { + case *net.TCPAddr: + c.client, err = c.storage.ClientByIP(addr.IP) + case *net.UDPAddr: + c.client, err = c.storage.ClientByIP(addr.IP) + default: + err = dataset.ErrNotExist{Object: "client"} + } + return c.client, err +} + var _ Context = (*proxyContext)(nil) diff --git a/proxy/proxy.go b/proxy/proxy.go index 66e64c1..7c8e365 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -15,9 +15,12 @@ import ( "slices" "strconv" "strings" + "sync" "syscall" "time" + "git.maze.io/maze/styx/ca" + "git.maze.io/maze/styx/dataset" "git.maze.io/maze/styx/internal/netutil" "git.maze.io/maze/styx/logger" "git.maze.io/maze/styx/stats" @@ -26,6 +29,7 @@ import ( // Common HTTP headers. const ( HeaderConnection = "Connection" + HeaderContentLength = "Content-Length" HeaderContentType = "Content-Type" HeaderDate = "Date" HeaderForwarded = "Forwarded" @@ -146,7 +150,17 @@ type Proxy struct { // WebSocketIdleTimeout is the timeout for idle WebSocket connections. WebSocketIdleTimeout time.Duration - mux *http.ServeMux + // CertificateAuthority can issue certificates for man-in-the-middle connections. + CertificateAuthority ca.CertificateAuthority + + // Storage for resolving clients/groups + Storage dataset.Storage + + mux *http.ServeMux + closed chan struct{} + closeOnce sync.Once + mu sync.RWMutex + listeners []net.Listener } // New [Proxy] with somewhat sane defaults. 
@@ -157,6 +171,7 @@ func New() *Proxy { IdleTimeout: DefaultIdleTimeout, WebSocketIdleTimeout: DefaultWebSocketIdleTimeout, mux: http.NewServeMux(), + closed: make(chan struct{}, 1), } // Make sure the roundtripper uses our dialers. @@ -181,6 +196,55 @@ func New() *Proxy { return p } +func (p *Proxy) Close() error { + var closeListeners bool + p.closeOnce.Do(func() { + close(p.closed) + closeListeners = true + }) + if closeListeners { + p.mu.RLock() + for _, l := range p.listeners { + _ = l.Close() + } + p.mu.RUnlock() + } + return nil +} + +func (p *Proxy) isClosed() bool { + select { + case <-p.closed: + return true + default: + return false + } +} + +func (p *Proxy) addListener(l net.Listener) { + if l == nil { + return + } + p.mu.Lock() + p.listeners = append(p.listeners, l) + p.mu.Unlock() +} + +func (p *Proxy) removeListener(l net.Listener) { + if l == nil { + return + } + p.mu.Lock() + listeners := make([]net.Listener, 0, len(p.listeners)-1) + for _, o := range p.listeners { + if o != l { + listeners = append(listeners, o) + } + } + p.listeners = listeners + p.mu.Unlock() +} + // Handle installs a [http.Handler] into the internal mux. func (p *Proxy) Handle(pattern string, handler http.Handler) { p.mux.Handle(pattern, handler) @@ -214,11 +278,23 @@ func (p *Proxy) dial(ctx context.Context, req *http.Request) (net.Conn, error) { // Serve proxied connections on the specified listener. 
func (p *Proxy) Serve(l net.Listener) error { + p.addListener(l) + defer p.removeListener(l) for { + if p.isClosed() { + return nil + } + c, err := l.Accept() if err != nil { return err } + + if p.isClosed() { + _ = c.Close() + return nil + } + go p.handle(c) } } @@ -229,6 +305,7 @@ func (p *Proxy) handle(nc net.Conn) { ctx = NewContext(nc).(*proxyContext) err error ) + defer func() { if r := recover(); r != nil { if err, ok := r.(error); ok { @@ -266,6 +343,8 @@ func (p *Proxy) handle(nc net.Conn) { // Propagate timeouts ctx.SetIdleTimeout(p.IdleTimeout) + ctx.ca = p.CertificateAuthority + ctx.storage = p.Storage for _, f := range p.OnConnect { fc, err := f.HandleConn(ctx) @@ -282,6 +361,15 @@ func (p *Proxy) handle(nc net.Conn) { } log := ctx.LogEntry() + if p.Storage != nil { + if client, err := p.Storage.ClientByIP(nc.RemoteAddr().(*net.TCPAddr).IP); err == nil { + log = log.Values(logger.Values{ + "client_id": client.ID, + "client_network": client.String(), + "client_description": client.Description, + }) + } + } for { if ctx.transparentTLS { ctx.req = &http.Request{ @@ -344,7 +432,7 @@ func (p *Proxy) handle(nc net.Conn) { } if err = p.handleRequest(ctx); err != nil { - p.handleError(ctx, err, true) + p.handleError(ctx, err, !netutil.IsClosing(err)) return } @@ -511,7 +599,8 @@ func (p *Proxy) serveForward(ctx *proxyContext) (err error) { _ = ctx.Close() return fmt.Errorf("proxy: forward %s error: %w", ctx.req.URL, err) } - } else { + } + if res != nil { ctx.res = res } @@ -571,28 +660,44 @@ func (p *Proxy) serveWebSocket(ctx *proxyContext) (err error) { return p.multiplex(ctx, srv) } -func (p *Proxy) multiplex(ctx, srv Context) (err error) { +func (p *Proxy) multiplex(ctx, srv *proxyContext) (err error) { var ( + log = ctx.LogEntry().Value("server", srv.RemoteAddr().String()) errs = make(chan error, 1) done = make(chan struct{}, 1) ) go func(errs chan<- error) { - defer close(done) - if _, err := io.Copy(srv, ctx); err != nil { + if _, err := io.Copy(ctx, 
srv); err != nil && !netutil.IsClosing(err) { + log.Err(err).Trace("Multiplexing closed in client->server") errs <- err + } else { + log.Trace("Multiplexing closed in client->server") } }(errs) + go func(errs chan<- error) { - if _, err := io.Copy(ctx, srv); err != nil { + defer close(done) + if _, err := io.Copy(srv, ctx); err != nil && !netutil.IsClosing(err) { + log.Err(err).Trace("Multiplexing closed in server->client") errs <- err + } else { + log.Trace("Multiplexing closed in server->client") } }(errs) + defer func() { + log.Trace("Multiplexing done, force-closing client and server connections") + _ = ctx.Close() + _ = srv.Close() + }() + select { case err = <-errs: return case <-done: - return + return io.EOF // multiplexing never recycles connection + case <-p.closed: + return io.EOF // server closed } } diff --git a/stats/handler.go b/stats/handler.go new file mode 100644 index 0000000..fce13eb --- /dev/null +++ b/stats/handler.go @@ -0,0 +1,213 @@ +package stats + +import ( + "encoding/json" + "expvar" + "fmt" + "net/http" + "sort" + "strings" + + "html/template" +) + +var ( + page = template.Must(template.New(""). + Funcs(template.FuncMap{"path": path, "duration": duration}). + Parse(` + + +
__ __ +.--------..-----.| |_ .----.|__|.----..-----. +| || -__|| _|| _|| || __||__ --| +|__|__|__||_____||____||__| |__||____||_____| + + +
count | ||
---|---|---|
{{ printf "%.2g" .count }} | ||
mean | min | max |
{{printf "%.2g" .mean}} | {{printf "%.2g" .min}} | {{printf "%.2g" .max}} |
P.50 | P.90 | P.99 |
{{printf "%.2g" .p50}} | {{printf "%.2g" .p90}} | {{printf "%.2g" .p99}} |