Checkpoint

2025-10-06 22:25:23 +02:00
parent a23259cfdc
commit a254b306f2
48 changed files with 3327 additions and 212 deletions

1
dataset/base.go Normal file
View File

@@ -0,0 +1 @@
package dataset

25
dataset/error.go Normal file
View File

@@ -0,0 +1,25 @@
package dataset
import (
"errors"
"fmt"
"os"
"github.com/mjl-/bstore"
)
type ErrNotExist struct {
Object string
ID int64
}
func (err ErrNotExist) Error() string {
return fmt.Sprintf("storage: %s not found", err.Object)
}
func IsNotExist(err error) bool {
if err == nil {
return false
}
// errors.As matches any ErrNotExist value, regardless of its Object/ID fields.
var notExist ErrNotExist
return os.IsNotExist(err) || errors.As(err, &notExist) || errors.Is(err, bstore.ErrAbsent)
}

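As a quick illustration (not part of this commit, and assuming git.maze.io/maze/styx/dataset as the import path, mirroring the logger import used elsewhere in the commit), IsNotExist lets callers treat a wrapped ErrNotExist, an os not-exist error, or bstore.ErrAbsent uniformly:

package main

import (
	"fmt"

	"git.maze.io/maze/styx/dataset" // assumed import path
)

func main() {
	err := fmt.Errorf("client lookup: %w", dataset.ErrNotExist{Object: "client"})
	fmt.Println(dataset.IsNotExist(err)) // true: unwraps to ErrNotExist
	fmt.Println(dataset.IsNotExist(nil)) // false
}
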
53
dataset/parser/adblock.go Normal file
View File

@@ -0,0 +1,53 @@
package parser
import (
"bufio"
"io"
"strings"
)
func init() {
RegisterDomainsParser(adblockDomainsParser{})
}
type adblockDomainsParser struct{}
func (adblockDomainsParser) CanHandle(line string) bool {
return strings.HasPrefix(strings.ToLower(line), `[adblock`) ||
strings.HasPrefix(line, "@@") || // exception
strings.HasPrefix(line, "||") || // domain anchor (e.g. ||example.com^)
line[0] == '*'
}
func (adblockDomainsParser) ParseDomains(r io.Reader) (domains []string, ignored int, err error) {
scanner := bufio.NewScanner(r)
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if isComment(line) {
continue
}
// Common AdBlock patterns:
// ||domain.com^
// |http://domain.com|
// domain.com/path
// *domain.com*
switch {
case strings.HasPrefix(line, `||`): // domain anchor
if i := strings.IndexByte(line, '^'); i != -1 {
domains = append(domains, line[2:i])
continue
}
case len(line) > 1 && strings.HasPrefix(line, `|`) && strings.HasSuffix(line, `|`):
domains = append(domains, line[1:len(line)-1]) // strip the leading and trailing anchor pipes
continue
case strings.HasPrefix(line, `[`):
continue
}
ignored++
}
if err = scanner.Err(); err != nil {
return
}
return unique(domains), ignored, nil
}

41
dataset/parser/adblock_test.go Normal file
View File

@@ -0,0 +1,41 @@
package parser
import (
"reflect"
"sort"
"strings"
"testing"
)
func TestAdBlockParser(t *testing.T) {
test := `[Adblock Plus 2.0]
! Title: AdRules DNS List
! Homepage: https://github.com/Cats-Team/AdRules
! Powerd by Cats-Team
! Expires: 1 (update frequency)
! Description: The DNS Filters
! Total count: 145270
! Update: 2025-10-07 02:05:08(GMT+8)
/^.+stat\.kugou\.com/
/^admarvel\./
||*-ad-sign.byteimg.com^
||*-ad.a.yximgs.com^
||*-applog.fqnovel.com^
||*-datareceiver.aki-game.net^
||*.exaapi.com^`
want := []string{"*-ad-sign.byteimg.com", "*-ad.a.yximgs.com", "*-applog.fqnovel.com", "*-datareceiver.aki-game.net", "*.exaapi.com"}
parsed, ignored, err := ParseDomains(strings.NewReader(test))
if err != nil {
t.Fatal(err)
return
}
sort.Strings(parsed)
if !reflect.DeepEqual(parsed, want) {
t.Errorf("expected ParseDomains(adblock) to return %v, got %v", want, parsed)
}
if ignored != 2 {
t.Errorf("expected 2 ignored, got %d", ignored)
}
}

139
dataset/parser/dns.go Normal file
View File

@@ -0,0 +1,139 @@
package parser
import (
"bufio"
"io"
"strings"
"github.com/miekg/dns"
)
func init() {
RegisterDomainsParser(dnsmasqDomainsParser{})
RegisterDomainsParser(mosDNSDomainsParser{})
RegisterDomainsParser(smartDNSDomainsParser{})
RegisterDomainsParser(unboundDomainsParser{})
}
type dnsmasqDomainsParser struct{}
func (dnsmasqDomainsParser) CanHandle(line string) bool {
return strings.HasPrefix(line, "address=/")
}
func (dnsmasqDomainsParser) ParseDomains(r io.Reader) (domains []string, ignored int, err error) {
scanner := bufio.NewScanner(r)
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if isComment(line) {
continue
}
switch {
case strings.HasPrefix(line, "address=/"):
part := strings.FieldsFunc(line, func(r rune) bool { return r == '/' })
if len(part) >= 3 && isDomainName(part[1]) {
domains = append(domains, part[1])
continue
}
}
ignored++
}
if err = scanner.Err(); err != nil {
return
}
return unique(domains), ignored, nil
}
type mosDNSDomainsParser struct{}
func (mosDNSDomainsParser) CanHandle(line string) bool {
if strings.HasPrefix(line, "domain:") {
return isDomainName(line[7:])
}
return false
}
func (mosDNSDomainsParser) ParseDomains(r io.Reader) (domains []string, ignored int, err error) {
scanner := bufio.NewScanner(r)
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if isComment(line) {
continue
}
if strings.HasPrefix(line, "domain:") {
domains = append(domains, line[7:])
continue
}
ignored++
}
if err = scanner.Err(); err != nil {
return
}
return unique(domains), ignored, nil
}
type smartDNSDomainsParser struct{}
func (smartDNSDomainsParser) CanHandle(line string) bool {
return strings.HasPrefix(line, "address /")
}
func (smartDNSDomainsParser) ParseDomains(r io.Reader) (domains []string, ignored int, err error) {
scanner := bufio.NewScanner(r)
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if isComment(line) {
continue
}
if strings.HasPrefix(line, "address /") {
if i := strings.IndexByte(line[9:], '/'); i > -1 {
domains = append(domains, line[9:i+9])
continue
}
}
ignored++
}
if err = scanner.Err(); err != nil {
return
}
return unique(domains), ignored, nil
}
type unboundDomainsParser struct{}
func (unboundDomainsParser) CanHandle(line string) bool {
return strings.HasPrefix(line, "local-data:") ||
strings.HasPrefix(line, "local-zone:")
}
func (unboundDomainsParser) ParseDomains(r io.Reader) (domains []string, ignored int, err error) {
scanner := bufio.NewScanner(r)
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if isComment(line) {
continue
}
switch {
case strings.HasPrefix(line, "local-data:"):
record := strings.Trim(strings.TrimSpace(line[11:]), `"`)
if rr, err := dns.NewRR(record); err == nil {
switch rr.Header().Rrtype {
case dns.TypeA, dns.TypeAAAA, dns.TypeCNAME:
domains = append(domains, strings.Trim(rr.Header().Name, `.`))
continue
}
}
case strings.HasPrefix(line, "local-zone:") && strings.HasSuffix(line, " reject"):
// Strip the leading quote, then cut at the closing quote: `local-zone: "example.com" reject`.
line = strings.Trim(strings.TrimSpace(line[11:]), `"`)
if i := strings.IndexByte(line, '"'); i > -1 {
domains = append(domains, line[:i])
continue
}
}
ignored++
}
if err = scanner.Err(); err != nil {
return
}
return unique(domains), ignored, nil
}

106
dataset/parser/dns_test.go Normal file
View File

@@ -0,0 +1,106 @@
package parser
import (
"reflect"
"sort"
"strings"
"testing"
)
func TestDNSParsers(t *testing.T) {
tests := []struct {
Name string
Test string
Want []string
WantIgnored int
}{
{
"data",
`
local-data: "junk1.doubleclick.net A 127.0.0.1"
local-data: "junk2.doubleclick.net A 127.0.0.1"
local-data: "junk2.doubleclick.net CNAME doubleclick.net."
local-data: "junk6.doubleclick.net AAAA ::1"
local-data: "doubleclick.net A 127.0.0.1"
local-data: "ad.junk1.doubleclick.net A 127.0.0.1"
local-data: "adjunk.google.com A 127.0.0.1"`,
[]string{"ad.junk1.doubleclick.net", "adjunk.google.com", "doubleclick.net", "junk1.doubleclick.net", "junk2.doubleclick.net", "junk6.doubleclick.net"},
0,
},
{
"zone",
`
local-zone: "doubleclick.net" reject
local-zone: "adjunk.google.com" reject`,
[]string{"adjunk.google.com", "doubleclick.net"},
0,
},
{
"address",
`
address=/ziyu.net/0.0.0.0
address=/zlp6s.pw/0.0.0.0
address=/zm232.com/0.0.0.0
`,
[]string{"ziyu.net", "zlp6s.pw", "zm232.com"},
0,
},
}
for _, test := range tests {
t.Run(test.Name, func(it *testing.T) {
parsed, ignored, err := ParseDomains(strings.NewReader(test.Test))
if err != nil {
it.Fatal(err)
return
}
sort.Strings(parsed)
if !reflect.DeepEqual(parsed, test.Want) {
it.Errorf("expected ParseDomains(dns) to return\n\t%v, got\n\t%v", test.Want, parsed)
}
if ignored != test.WantIgnored {
it.Errorf("expected %d ignored, got %d", test.WantIgnored, ignored)
}
})
}
}
func TestMOSDNSParser(t *testing.T) {
test := `domain:0019x.com
domain:002777.xyz
domain:003store.com
domain:00404850.xyz`
want := []string{"0019x.com", "002777.xyz", "003store.com", "00404850.xyz"}
parsed, _, err := ParseDomains(strings.NewReader(test))
if err != nil {
t.Fatal(err)
return
}
sort.Strings(parsed)
if !reflect.DeepEqual(parsed, want) {
t.Errorf("expected ParseDomains(mosdns) to return %v, got %v", want, parsed)
}
}
func TestSmartDNSParser(t *testing.T) {
test := `# Title:AdRules SmartDNS List
# Update: 2025-10-07 02:05:08(GMT+8)
address /0.myikas.com/#
address /0.net.easyjet.com/#
address /0.nextyourcontent.com/#
address /0019x.com/#`
want := []string{"0.myikas.com", "0.net.easyjet.com", "0.nextyourcontent.com", "0019x.com"}
parsed, _, err := ParseDomains(strings.NewReader(test))
if err != nil {
t.Fatal(err)
return
}
sort.Strings(parsed)
if !reflect.DeepEqual(parsed, want) {
t.Errorf("expected ParseDomains(smartdns) to return %v, got %v", want, parsed)
}
}

40
dataset/parser/domains.go Normal file
View File

@@ -0,0 +1,40 @@
package parser
import (
"bufio"
"io"
"net"
"strings"
)
func init() {
RegisterDomainsParser(domainsParser{})
}
type domainsParser struct{}
func (domainsParser) CanHandle(line string) bool {
return isDomainName(line) &&
!strings.ContainsRune(line, ' ') &&
!strings.ContainsRune(line, ':') &&
net.ParseIP(line) == nil
}
func (domainsParser) ParseDomains(r io.Reader) (domains []string, ignored int, err error) {
scanner := bufio.NewScanner(r)
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if isComment(line) {
continue
}
if isDomainName(line) {
domains = append(domains, line)
continue
}
ignored++
}
if err = scanner.Err(); err != nil {
return
}
return unique(domains), ignored, nil
}

31
dataset/parser/domains_test.go Normal file
View File

@@ -0,0 +1,31 @@
package parser
import (
"reflect"
"sort"
"strings"
"testing"
)
func TestParseDomains(t *testing.T) {
test := `# This is a comment
facebook.com
tiktok.com
bogus ignored
youtube.com`
want := []string{"facebook.com", "tiktok.com", "youtube.com"}
parsed, ignored, err := ParseDomains(strings.NewReader(test))
if err != nil {
t.Fatal(err)
return
}
sort.Strings(parsed)
if !reflect.DeepEqual(parsed, want) {
t.Errorf("expected ParseDomains(domains) to return %v, got %v", want, parsed)
}
if ignored != 1 {
t.Errorf("expected 1 ignored, got %d", ignored)
}
}

41
dataset/parser/hosts.go Normal file
View File

@@ -0,0 +1,41 @@
package parser
import (
"bufio"
"io"
"net"
"strings"
)
func init() {
RegisterDomainsParser(hostsParser{})
}
type hostsParser struct{}
func (hostsParser) CanHandle(line string) bool {
part := strings.Fields(line)
return len(part) >= 2 && net.ParseIP(part[0]) != nil
}
func (hostsParser) ParseDomains(r io.Reader) (domains []string, ignored int, err error) {
scanner := bufio.NewScanner(r)
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if isComment(line) {
continue
}
part := strings.Fields(line)
if len(part) >= 2 && net.ParseIP(part[0]) != nil {
domains = append(domains, part[1:]...)
continue
}
ignored++
}
if err = scanner.Err(); err != nil {
return
}
return unique(domains), ignored, nil
}

38
dataset/parser/hosts_test.go Normal file
View File

@@ -0,0 +1,38 @@
package parser
import (
"reflect"
"sort"
"strings"
"testing"
)
func TestParseHosts(t *testing.T) {
test := `##
# Host Database
#
# localhost is used to configure the loopback interface
# when the system is booting. Do not change this entry.
##
127.0.0.1 localhost dragon dragon.local dragon.maze.network
255.255.255.255 broadcasthost
::1 localhost
ff00::1 multicast
1.2.3.4
`
want := []string{"broadcasthost", "dragon", "dragon.local", "dragon.maze.network", "localhost", "multicast"}
parsed, ignored, err := ParseDomains(strings.NewReader(test))
if err != nil {
t.Fatal(err)
return
}
sort.Strings(parsed)
if !reflect.DeepEqual(parsed, want) {
t.Errorf("expected ParseDomains(hosts) to return %v, got %v", want, parsed)
}
if ignored != 1 {
t.Errorf("expected 1 ignored, got %d", ignored)
}
}

76
dataset/parser/parser.go Normal file
View File

@@ -0,0 +1,76 @@
package parser
import (
"bufio"
"bytes"
"errors"
"io"
"log"
"strings"
"github.com/miekg/dns"
)
var ErrNoParser = errors.New("no suitable parser could be found")
type Parser interface {
CanHandle(line string) bool
}
type DomainsParser interface {
Parser
ParseDomains(io.Reader) (domains []string, ignored int, err error)
}
var domainsParsers []DomainsParser
func RegisterDomainsParser(parser DomainsParser) {
domainsParsers = append(domainsParsers, parser)
}
// ParseDomains finds the first non-comment line, picks the first registered
// parser that recognises it, and hands that parser the full input. Everything
// read during detection is tee'd into a buffer and replayed via io.MultiReader,
// so the chosen parser still sees the stream from the beginning.
func ParseDomains(r io.Reader) (domains []string, ignored int, err error) {
var (
buffer = new(bytes.Buffer)
scanner = bufio.NewScanner(io.TeeReader(r, buffer))
line string
parser DomainsParser
)
for scanner.Scan() {
line = strings.TrimSpace(scanner.Text())
if isComment(line) {
continue
}
for _, parser = range domainsParsers {
if parser.CanHandle(line) {
log.Printf("using parser %T", parser)
return parser.ParseDomains(io.MultiReader(buffer, r))
}
}
break
}
return nil, 0, ErrNoParser
}
func isComment(line string) bool {
return line == "" || line[0] == '#' || line[0] == '!'
}
func isDomainName(name string) bool {
n, ok := dns.IsDomainName(name)
return n >= 2 && ok
}
func unique(values []string) []string {
if values == nil {
return nil
}
v := make(map[string]struct{}, len(values))
for _, s := range values {
v[s] = struct{}{}
}
o := make([]string, 0, len(v))
for k := range v {
o = append(o, k)
}
return o
}

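A minimal sketch of how a caller might drive the parser registry (not part of this commit; the import path and the blocklist.txt filename are assumptions):

package main

import (
	"fmt"
	"log"
	"os"

	"git.maze.io/maze/styx/dataset/parser" // assumed import path
)

func main() {
	f, err := os.Open("blocklist.txt") // hosts, dnsmasq, unbound, SmartDNS, MosDNS, AdBlock or plain domain lists
	if err != nil {
		log.Fatal(err)
	}
	defer f.Close()

	// ParseDomains sniffs the first content line, picks a registered parser,
	// and returns ErrNoParser when no parser recognises the input.
	domains, ignored, err := parser.ParseDomains(f)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Printf("parsed %d unique domains, ignored %d lines\n", len(domains), ignored)
}
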
31
dataset/parser/parser_test.go Normal file
View File

@@ -0,0 +1,31 @@
package parser
import (
"reflect"
"sort"
"testing"
)
func TestUnique(t *testing.T) {
tests := []struct {
Name string
Test []string
Want []string
}{
{"nil", nil, nil},
{"single", []string{"test"}, []string{"test"}},
{"duplicate", []string{"test", "test"}, []string{"test"}},
{"multiple", []string{"a", "a", "b", "b", "b", "c"}, []string{"a", "b", "c"}},
}
for _, test := range tests {
t.Run(test.Name, func(it *testing.T) {
v := unique(test.Test)
if v != nil {
sort.Strings(v)
}
if !reflect.DeepEqual(v, test.Want) {
it.Errorf("expected unique(%v) to return %v, got %v", test.Test, test.Want, v)
}
})
}
}

231
dataset/storage.go Normal file
View File

@@ -0,0 +1,231 @@
package dataset
import (
"bufio"
"bytes"
"fmt"
"io"
"io/fs"
"net"
"net/http"
"net/url"
"os"
"slices"
"strings"
"time"
_ "github.com/mattn/go-sqlite3" // SQLite3 driver
"github.com/miekg/dns"
)
type Storage interface {
Groups() ([]Group, error)
GroupByID(int64) (Group, error)
GroupByName(name string) (Group, error)
SaveGroup(*Group) error
DeleteGroup(Group) error
Clients() (Clients, error)
ClientByID(int64) (Client, error)
ClientByIP(net.IP) (Client, error)
SaveClient(*Client) error
DeleteClient(Client) error
Lists() ([]List, error)
ListByID(int64) (List, error)
SaveList(*List) error
DeleteList(List) error
}
type Group struct {
ID int64 `json:"id"`
Name string `json:"name" bstore:"nonzero,unique"`
IsEnabled bool `json:"is_enabled" bstore:"nonzero"`
Description string `json:"description"`
CreatedAt time.Time `json:"created_at" bstore:"nonzero"`
UpdatedAt time.Time `json:"updated_at" bstore:"nonzero"`
Storage Storage `json:"-" bstore:"-"`
}
type Client struct {
ID int64 `json:"id"`
Network string `json:"network" bstore:"nonzero,index"`
IP string `json:"ip" bstore:"nonzero,unique IP+Mask"`
Mask int `json:"mask"`
Description string `json:"description"`
Groups []Group `json:"groups,omitempty" bstore:"-"`
CreatedAt time.Time `json:"created_at" bstore:"nonzero"`
UpdatedAt time.Time `json:"updated_at" bstore:"nonzero"`
Storage Storage `json:"-" bstore:"-"`
}
type WithClient interface {
Client() (Client, error)
}
type ClientGroup struct {
ID int64 `json:"id"`
ClientID int64 `json:"client_id" bstore:"ref Client,index"`
GroupID int64 `json:"group_id" bstore:"ref Group,index"`
}
func (c *Client) ipNet() *net.IPNet {
addr := net.ParseIP(c.IP)
if addr == nil {
return nil
}
// Size the mask for the address family: 32 bits for IPv4, 128 bits for IPv6.
bits := 8 * net.IPv6len
if addr4 := addr.To4(); addr4 != nil {
addr = addr4
bits = 8 * net.IPv4len
}
return &net.IPNet{IP: addr, Mask: net.CIDRMask(c.Mask, bits)}
}
func (c *Client) ContainsIP(ip net.IP) bool {
ipnet := c.ipNet()
return ipnet != nil && ipnet.Contains(ip)
}
func (c *Client) String() string {
if ipnet := c.ipNet(); ipnet != nil {
return ipnet.String()
}
return c.IP
}
type Clients []Client
func (cs Clients) ByIP(ip net.IP) *Client {
var candidates []*Client
for _, c := range cs {
if c.ContainsIP(ip) {
candidates = append(candidates, &c)
}
}
switch len(candidates) {
case 0:
return nil
case 1:
return candidates[0]
default:
slices.SortStableFunc(candidates, func(a, b *Client) int {
return int(b.Mask) - int(a.Mask)
})
return candidates[0]
}
}
const (
ListTypeDomain = "domain"
ListTypeNetwork = "network"
)
const (
MinListRefresh = 1 * time.Minute
DefaultListRefresh = 30 * time.Minute
)
type List struct {
ID int64 `json:"id"`
Type string `json:"type"`
Source string `json:"source"`
IsEnabled bool `json:"is_enabled"`
Permit bool `json:"permit"`
Groups []Group `json:"groups,omitempty" bstore:"-"`
Status int `json:"status"`
Comment string `json:"comment"`
Cache []byte `json:"cache"`
Refresh time.Duration `json:"refresh"`
LastModified time.Time `json:"last_modified"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
func (list *List) Domains() (*DomainTree, error) {
if list.Type != ListTypeDomain {
return nil, nil
}
var (
tree = NewDomainList()
scan = bufio.NewScanner(bytes.NewReader(list.Cache))
)
for scan.Scan() {
line := strings.TrimSpace(scan.Text())
if line == "" || line[0] == '#' {
continue
}
if labels, ok := dns.IsDomainName(line); ok && labels >= 2 {
tree.Add(line)
}
}
if err := scan.Err(); err != nil {
return nil, err
}
return tree, nil
}
func (list *List) Update() (updated bool, err error) {
u, err := url.Parse(list.Source)
if err != nil {
return false, err
}
switch u.Scheme {
case "", "file":
return list.updateFile(u.Path)
case "http", "https":
return list.updateHTTP(u.String())
default:
return false, fmt.Errorf("dataset: don't know how to update %s sources", u.Scheme)
}
}
func (list *List) updateFile(name string) (updated bool, err error) {
var info fs.FileInfo
if info, err = os.Stat(name); err != nil {
return
} else if info.IsDir() {
return false, fmt.Errorf("dataset: list %d: %q is a directory", list.ID, name)
}
if updated = info.ModTime().After(list.UpdatedAt); !updated {
return
}
if list.Cache, err = os.ReadFile(name); err != nil {
return false, err
}
return
}
func (list *List) updateHTTP(url string) (updated bool, err error) {
if updated, err = list.shouldUpdateHTTP(url); err != nil || !updated {
return
}
var response *http.Response
if response, err = http.DefaultClient.Get(url); err != nil {
return
}
defer response.Body.Close()
if response.StatusCode != http.StatusOK {
return false, fmt.Errorf("dataset: list %d: unexpected status %q fetching %s", list.ID, response.Status, url)
}
if list.Cache, err = io.ReadAll(response.Body); err != nil {
return
}
return true, nil
}
func (list *List) shouldUpdateHTTP(url string) (updated bool, err error) {
var response *http.Response
if response, err = http.DefaultClient.Head(url); err != nil {
return
}
defer response.Body.Close()
if value := response.Header.Get("Last-Modified"); value != "" {
var lastModified time.Time
if lastModified, err = time.Parse(http.TimeFormat, value); err == nil {
return lastModified.After(list.LastModified), nil
}
}
// There are no headers that would indicate last-modified time, so assume we have to update:
return true, nil
}
type ListGroup struct {
ID int64 `json:"id"`
ListID int64 `json:"list_id" bstore:"ref List,index"`
GroupID int64 `json:"group_id" bstore:"ref Group,index"`
}

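To illustrate the most-specific-match rule in Clients.ByIP (a sketch, not part of this commit; the import path is assumed):

package main

import (
	"fmt"
	"net"

	"git.maze.io/maze/styx/dataset" // assumed import path
)

func main() {
	clients := dataset.Clients{
		{Network: "ipv4", IP: "0.0.0.0", Mask: 0},   // catch-all
		{Network: "ipv4", IP: "10.0.0.0", Mask: 8},  // site-wide
		{Network: "ipv4", IP: "10.1.2.0", Mask: 24}, // most specific
	}
	// Overlapping candidates are sorted by mask length, so the /24 wins.
	if c := clients.ByIP(net.ParseIP("10.1.2.3")); c != nil {
		fmt.Println(c) // 10.1.2.0/24
	}
}
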
412
dataset/storage_bstore.go Normal file
View File

@@ -0,0 +1,412 @@
package dataset
import (
"context"
"errors"
"fmt"
"net"
"path/filepath"
"slices"
"strings"
"time"
"git.maze.io/maze/styx/logger"
"github.com/mjl-/bstore"
)
type bstoreStorage struct {
db *bstore.DB
path string
}
func OpenBStore(name string) (Storage, error) {
if !filepath.IsAbs(name) {
var err error
if name, err = filepath.Abs(name); err != nil {
return nil, err
}
}
ctx := context.Background()
db, err := bstore.Open(ctx, name, nil,
Group{},
Client{},
ClientGroup{},
List{},
ListGroup{},
)
if err != nil {
return nil, err
}
var (
s = &bstoreStorage{db: db, path: name}
defaultGroup Group
defaultClient4 Client
defaultClient6 Client
)
if defaultGroup, err = s.GroupByName("Default"); errors.Is(err, bstore.ErrAbsent) {
defaultGroup = Group{
Name: "Default",
IsEnabled: true,
Description: "Default group",
}
if err = s.SaveGroup(&defaultGroup); err != nil {
return nil, err
}
} else if err != nil {
return nil, err
}
if defaultClient4, err = bstore.QueryDB[Client](ctx, db).
FilterEqual("Network", "ipv4").
FilterFn(func(client Client) bool {
return net.ParseIP(client.IP).Equal(net.ParseIP("0.0.0.0")) && client.Mask == 0
}).Get(); errors.Is(err, bstore.ErrAbsent) {
defaultClient4 = Client{
Network: "ipv4",
IP: "0.0.0.0",
Mask: 0,
Description: "All IPv4 clients",
}
if err = s.SaveClient(&defaultClient4); err != nil {
return nil, err
}
if err = s.db.Insert(ctx, &ClientGroup{ClientID: defaultClient4.ID, GroupID: defaultGroup.ID}); err != nil {
return nil, err
}
} else if err != nil {
return nil, err
}
if defaultClient6, err = bstore.QueryDB[Client](ctx, db).
FilterEqual("Network", "ipv6").
FilterFn(func(client Client) bool {
return net.ParseIP(client.IP).Equal(net.ParseIP("::")) && client.Mask == 0
}).Get(); errors.Is(err, bstore.ErrAbsent) {
defaultClient6 = Client{
Network: "ipv6",
IP: "::",
Mask: 0,
Description: "All IPv6 clients",
}
if err = s.SaveClient(&defaultClient6); err != nil {
return nil, err
}
if err = s.db.Insert(ctx, &ClientGroup{ClientID: defaultClient6.ID, GroupID: defaultGroup.ID}); err != nil {
return nil, err
}
} else if err != nil {
return nil, err
}
// Start updater
NewUpdater(s)
return s, nil
}
func (s *bstoreStorage) log() logger.Structured {
return logger.StandardLog.Values(logger.Values{
"storage": "bstore",
"storage_path": s.path,
})
}
func (s *bstoreStorage) Groups() ([]Group, error) {
var (
ctx = context.Background()
query = bstore.QueryDB[Group](ctx, s.db)
groups = make([]Group, 0)
)
for group := range query.All() {
groups = append(groups, group)
}
if err := query.Err(); err != nil && !errors.Is(err, bstore.ErrFinished) {
return nil, err
}
return groups, nil
}
func (s *bstoreStorage) GroupByID(id int64) (Group, error) {
ctx := context.Background()
return bstore.QueryDB[Group](ctx, s.db).FilterID(id).Get()
}
func (s *bstoreStorage) GroupByName(name string) (Group, error) {
ctx := context.Background()
return bstore.QueryDB[Group](ctx, s.db).FilterFn(func(group Group) bool {
return strings.EqualFold(group.Name, name)
}).Get()
}
func (s *bstoreStorage) SaveGroup(group *Group) (err error) {
ctx := context.Background()
group.UpdatedAt = time.Now().UTC()
if group.CreatedAt.Equal(time.Time{}) {
group.CreatedAt = group.UpdatedAt
err = s.db.Insert(ctx, group)
} else {
err = s.db.Update(ctx, group)
}
if err != nil {
return fmt.Errorf("dataset: save group %s failed: %w", group.Name, err)
}
return nil
}
func (s *bstoreStorage) DeleteGroup(group Group) (err error) {
ctx := context.Background()
tx, err := s.db.Begin(ctx, true)
if err != nil {
return err
}
// Roll back the write transaction if any step below fails.
defer func() {
if err != nil {
tx.Rollback()
}
}()
if _, err = bstore.QueryTx[ClientGroup](tx).FilterEqual("GroupID", group.ID).Delete(); err != nil {
return
}
if _, err = bstore.QueryTx[ListGroup](tx).FilterEqual("GroupID", group.ID).Delete(); err != nil {
return
}
if err = tx.Delete(group); err != nil {
return
}
return tx.Commit()
}
func (s *bstoreStorage) Clients() (Clients, error) {
var (
ctx = context.Background()
query = bstore.QueryDB[Client](ctx, s.db)
clients = make(Clients, 0)
)
for client := range query.All() {
clients = append(clients, client)
}
if err := query.Err(); err != nil && !errors.Is(err, bstore.ErrFinished) {
return nil, err
}
return clients, nil
}
func (s *bstoreStorage) ClientByID(id int64) (Client, error) {
ctx := context.Background()
client, err := bstore.QueryDB[Client](ctx, s.db).FilterID(id).Get()
if err != nil {
return client, err
}
return s.clientResolveGroups(ctx, client)
}
func (s *bstoreStorage) ClientByIP(ip net.IP) (Client, error) {
if ip == nil {
return Client{}, ErrNotExist{Object: "client"}
}
var (
ctx = context.Background()
clients Clients
network string
)
if ip4 := ip.To4(); ip4 != nil {
network = "ipv4"
} else if ip6 := ip.To16(); ip6 != nil {
network = "ipv6"
}
if network == "" {
return Client{}, ErrNotExist{Object: "client"}
}
for client, err := range bstore.QueryDB[Client](ctx, s.db).
FilterEqual("Network", network).
FilterFn(func(client Client) bool {
return client.ContainsIP(ip)
}).All() {
if err != nil {
return Client{}, err
}
clients = append(clients, client)
}
var client Client
switch len(clients) {
case 0:
return Client{}, ErrNotExist{Object: "client"}
case 1:
client = clients[0]
default:
slices.SortStableFunc(clients, func(a, b Client) int {
return int(b.Mask) - int(a.Mask)
})
client = clients[0]
}
return s.clientResolveGroups(ctx, client)
}
func (s *bstoreStorage) clientResolveGroups(ctx context.Context, client Client) (Client, error) {
for clientGroup, err := range bstore.QueryDB[ClientGroup](ctx, s.db).FilterEqual("ClientID", client.ID).All() {
if err != nil {
return Client{}, err
}
if group, err := s.GroupByID(clientGroup.GroupID); err == nil {
client.Groups = append(client.Groups, group)
}
}
return client, nil
}
func (s *bstoreStorage) SaveClient(client *Client) (err error) {
log := s.log()
ctx := context.Background()
client.UpdatedAt = time.Now().UTC()
tx, err := s.db.Begin(ctx, true)
if err != nil {
return err
}
// Roll back the write transaction if any step below fails.
defer func() {
if err != nil {
tx.Rollback()
}
}()
log = log.Values(logger.Values{"ip": client.IP, "mask": client.Mask, "description": client.Description})
if client.CreatedAt.Equal(time.Time{}) {
log.Debug("Create client")
client.CreatedAt = client.UpdatedAt
if err = tx.Insert(client); err != nil {
return fmt.Errorf("dataset: client insert failed: %w", err)
}
} else {
log.Debug("Update client")
if err = tx.Update(client); err != nil {
return fmt.Errorf("dataset: client update failed: %w", err)
}
}
var deleted int
if deleted, err = bstore.QueryTx[ClientGroup](tx).FilterEqual("ClientID", client.ID).Delete(); err != nil {
return fmt.Errorf("dataset: client groups delete failed: %w", err)
}
log.Debugf("Deleted %d groups", deleted)
log.Debugf("Linking %d groups", len(client.Groups))
for _, group := range client.Groups {
if err = tx.Insert(&ClientGroup{ClientID: client.ID, GroupID: group.ID}); err != nil {
return fmt.Errorf("dataset: client groups insert failed: %w", err)
}
}
return tx.Commit()
}
func (s *bstoreStorage) DeleteClient(client Client) (err error) {
ctx := context.Background()
tx, err := s.db.Begin(ctx, true)
if err != nil {
return err
}
// Roll back the write transaction if any step below fails.
defer func() {
if err != nil {
tx.Rollback()
}
}()
if _, err = bstore.QueryTx[ClientGroup](tx).FilterEqual("ClientID", client.ID).Delete(); err != nil {
return
}
if err = tx.Delete(client); err != nil {
return
}
return tx.Commit()
}
func (s *bstoreStorage) Lists() ([]List, error) {
var (
ctx = context.Background()
query = bstore.QueryDB[List](ctx, s.db)
lists = make([]List, 0)
)
for list := range query.All() {
lists = append(lists, list)
}
if err := query.Err(); err != nil && !errors.Is(err, bstore.ErrFinished) {
return nil, err
}
return lists, nil
}
func (s *bstoreStorage) ListByID(id int64) (List, error) {
ctx := context.Background()
list, err := bstore.QueryDB[List](ctx, s.db).FilterID(id).Get()
if err != nil {
return list, err
}
return s.listResolveGroups(ctx, list)
}
func (s *bstoreStorage) listResolveGroups(ctx context.Context, list List) (List, error) {
for listGroup, err := range bstore.QueryDB[ListGroup](ctx, s.db).FilterEqual("ListID", list.ID).All() {
if err != nil {
return List{}, err
}
if group, err := s.GroupByID(listGroup.GroupID); err == nil {
list.Groups = append(list.Groups, group)
}
}
return list, nil
}
func (s *bstoreStorage) SaveList(list *List) (err error) {
if list.Type != ListTypeDomain && list.Type != ListTypeNetwork {
return fmt.Errorf("storage: unknown list type %q", list.Type)
}
if list.Refresh == 0 {
list.Refresh = DefaultListRefresh
} else if list.Refresh < MinListRefresh {
list.Refresh = MinListRefresh
}
list.UpdatedAt = time.Now().UTC()
ctx := context.Background()
tx, err := s.db.Begin(ctx, true)
if err != nil {
return err
}
// Roll back the write transaction if any step below fails.
defer func() {
if err != nil {
tx.Rollback()
}
}()
log := s.log()
log = log.Values(logger.Values{
"type": list.Type,
"source": list.Source,
"is_enabled": list.IsEnabled,
"status": list.Status,
"cache": len(list.Cache),
"refresh": list.Refresh,
})
if list.CreatedAt.Equal(time.Time{}) {
log.Debug("Creating list")
list.CreatedAt = list.UpdatedAt
if err = tx.Insert(list); err != nil {
return fmt.Errorf("dataset: list insert failed: %w", err)
}
} else {
log.Debug("Updating list")
if err = tx.Update(list); err != nil {
return fmt.Errorf("dataset: list update failed: %w", err)
}
}
var deleted int
if deleted, err = bstore.QueryTx[ListGroup](tx).FilterEqual("ListID", list.ID).Delete(); err != nil {
return fmt.Errorf("dataset: list groups delete failed: %w", err)
}
log.Debugf("Deleted %d groups", deleted)
log.Debugf("Linking %d groups", len(list.Groups))
for _, group := range list.Groups {
if err = tx.Insert(&ListGroup{ListID: list.ID, GroupID: group.ID}); err != nil {
return fmt.Errorf("dataset: list groups insert failed: %w", err)
}
}
return tx.Commit()
}
func (s *bstoreStorage) DeleteList(list List) (err error) {
ctx := context.Background()
tx, err := s.db.Begin(ctx, true)
if err != nil {
return err
}
// Roll back the write transaction if any step below fails.
defer func() {
if err != nil {
tx.Rollback()
}
}()
if _, err = bstore.QueryTx[ListGroup](tx).FilterEqual("ListID", list.ID).Delete(); err != nil {
return
}
if err = tx.Delete(list); err != nil {
return
}
return tx.Commit()
}

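A minimal sketch of opening the bstore-backed storage (not part of this commit; the import path and database filename are assumptions):

package main

import (
	"log"

	"git.maze.io/maze/styx/dataset" // assumed import path
)

func main() {
	// First open creates the schema plus a "Default" group and the
	// 0.0.0.0/0 and ::/0 catch-all clients, and starts the list updater.
	store, err := dataset.OpenBStore("styx.db") // hypothetical filename
	if err != nil {
		log.Fatal(err)
	}

	clients, err := store.Clients()
	if err != nil {
		log.Fatal(err)
	}
	for _, c := range clients {
		log.Printf("client %s (%s)", c.String(), c.Description)
	}
}
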
226
dataset/updater.go Normal file
View File

@@ -0,0 +1,226 @@
package dataset
import (
"bytes"
"fmt"
"io"
"net/http"
"net/url"
"os"
"sync"
"time"
"git.maze.io/maze/styx/logger"
)
type Updater struct {
storage Storage
lists sync.Map // map[int64]List
updaters sync.Map // map[int64]*updaterJob
done chan struct{}
}
func NewUpdater(storage Storage) *Updater {
u := &Updater{
storage: storage,
done: make(chan struct{}, 1),
}
go u.refresh()
return u
}
func (u *Updater) Close() error {
select {
case <-u.done:
return nil
default:
close(u.done)
return nil
}
}
func (u *Updater) refresh() {
check := time.NewTicker(time.Second)
defer check.Stop()
var (
log = logger.StandardLog
)
for {
select {
case <-u.done:
log.Debug("Updater closing, stopping updaters...")
u.updaters.Range(func(key, value any) bool {
if value != nil {
close(value.(*updaterJob).done)
}
return true
})
return
case now := <-check.C:
u.check(now, log)
}
}
}
func (u *Updater) check(now time.Time, log logger.Structured) (wait time.Duration) {
log.Trace("Checking lists")
lists, err := u.storage.Lists()
if err != nil {
log.Err(err).Error("Updater can't retrieve lists")
return -1
}
var missing = make(map[int64]bool)
u.lists.Range(func(key, _ any) bool {
log.Tracef("List %d has updater running", key)
missing[key.(int64)] = true
return true
})
for _, list := range lists {
log.Tracef("List %d is active: %t", list.ID, list.IsEnabled)
if !list.IsEnabled {
continue
}
delete(missing, list.ID)
if _, exists := u.lists.Load(list.ID); !exists {
u.lists.Store(list.ID, list)
updater := newUpdaterJob(u.storage, &list)
u.updaters.Store(list.ID, updater)
}
}
for id := range missing {
log.Tracef("List %d has updater running, but is no longer active, reaping...", id)
if updater, ok := u.updaters.Load(id); ok {
close(updater.(*updaterJob).done)
u.updaters.Delete(id)
}
}
return
}
type updaterJob struct {
storage Storage
list *List
done chan struct{}
}
func newUpdaterJob(storage Storage, list *List) *updaterJob {
job := &updaterJob{
storage: storage,
list: list,
done: make(chan struct{}, 1),
}
go job.loop()
return job
}
func (job *updaterJob) loop() {
var (
ticker = time.NewTicker(job.list.Refresh)
first = time.After(0)
now time.Time
log = logger.StandardLog.Values(logger.Values{
"list": job.list.ID,
"type": job.list.Type,
})
)
defer ticker.Stop()
for {
select {
case <-job.done:
log.Debug("List updater stopping")
return
case now = <-ticker.C:
case now = <-first:
}
log.Debug("List updater running")
if update, err := job.run(now); err != nil {
log.Err(err).Error("List updater failed")
} else if update {
if err = job.storage.SaveList(job.list); err != nil {
log.Err(err).Error("List updater save failed")
}
}
}
}
// run this updater
func (job *updaterJob) run(now time.Time) (update bool, err error) {
u, err := url.Parse(job.list.Source)
if err != nil {
return false, err
}
log := logger.StandardLog.Values(logger.Values{
"list": job.list.ID,
"source": job.list.Source,
})
if u.Scheme == "" || u.Scheme == "file" {
log.Debug("Updating list from file")
return job.updateFile(u.Path)
}
log.Debug("Updating list from URL")
return job.updateHTTP(u)
}
func (job *updaterJob) updateFile(name string) (update bool, err error) {
var b []byte
if b, err = os.ReadFile(name); err != nil {
return
}
if update = !bytes.Equal(b, job.list.Cache); update {
job.list.Cache = b
}
return
}
func (job *updaterJob) updateHTTP(location *url.URL) (update bool, err error) {
if update, err = job.shouldUpdateHTTP(location); err != nil || !update {
return
}
var (
req *http.Request
res *http.Response
)
if req, err = http.NewRequest(http.MethodGet, location.String(), nil); err != nil {
return
}
if res, err = http.DefaultClient.Do(req); err != nil {
return
}
defer res.Body.Close()
if res.StatusCode != http.StatusOK {
return false, fmt.Errorf("dataset: list %d: unexpected status %q fetching %s", job.list.ID, res.Status, location)
}
if job.list.Cache, err = io.ReadAll(res.Body); err != nil {
return
}
return true, nil
}
func (job *updaterJob) shouldUpdateHTTP(location *url.URL) (update bool, err error) {
if len(job.list.Cache) == 0 {
// Nothing cached, please update.
return true, nil
}
var (
req *http.Request
res *http.Response
)
if req, err = http.NewRequest(http.MethodHead, location.String(), nil); err != nil {
return
}
if res, err = http.DefaultClient.Do(req); err != nil {
return
}
defer res.Body.Close()
if lastModified, err := time.Parse(http.TimeFormat, res.Header.Get("Last-Modified")); err == nil {
return lastModified.After(job.list.UpdatedAt), nil
}
return true, nil // no usable Last-Modified header; assume the source has changed
}
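Finally, a sketch of the updater lifecycle (not part of this commit; import path and database name assumed). Note that OpenBStore already starts its own Updater internally, so a separately constructed one, as shown here, is purely illustrative:

package main

import (
	"log"
	"os"
	"os/signal"

	"git.maze.io/maze/styx/dataset" // assumed import path
)

func main() {
	store, err := dataset.OpenBStore("styx.db") // hypothetical filename
	if err != nil {
		log.Fatal(err)
	}

	// The Updater polls storage every second, starts one refresh job per
	// enabled list, and reaps jobs for lists that are disabled or removed.
	updater := dataset.NewUpdater(store)
	defer updater.Close()

	// Keep running until interrupted; the jobs refresh list caches in the background.
	sig := make(chan os.Signal, 1)
	signal.Notify(sig, os.Interrupt)
	<-sig
}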