Checkpoint
This commit is contained in:
1
dataset/base.go
Normal file
1
dataset/base.go
Normal file
@@ -0,0 +1 @@
|
||||
package dataset
|
25
dataset/error.go
Normal file
25
dataset/error.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package dataset
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/mjl-/bstore"
|
||||
)
|
||||
|
||||
// ErrNotExist is the error returned when a stored object cannot be found.
type ErrNotExist struct {
	// Object names the kind of object that was looked up (for example "client").
	Object string
	// ID is presumably the identifier that was looked up; it is not part of
	// the error text.
	ID int64
}

// Error implements the error interface.
func (err ErrNotExist) Error() string {
	const format = "storage: %s not found"
	return fmt.Sprintf(format, err.Object)
}
|
||||
|
||||
func IsNotExist(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
return os.IsNotExist(err) || errors.Is(err, ErrNotExist{}) || errors.Is(err, bstore.ErrAbsent)
|
||||
}
|
53
dataset/parser/adblock.go
Normal file
53
dataset/parser/adblock.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package parser
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"io"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func init() {
|
||||
RegisterDomainsParser(adblockDomainsParser{})
|
||||
}
|
||||
|
||||
// adblockDomainsParser extracts domains from AdBlock-style filter lists.
type adblockDomainsParser struct{}

// CanHandle reports whether line looks like part of an AdBlock filter list:
// an "[Adblock ...]" header (case-insensitive), an "@@" exception rule, a
// "||" domain anchor, or a rule starting with a "*" wildcard.
//
// An empty line is never handled; prefix matching is used throughout so that
// it cannot cause an out-of-range index.
func (adblockDomainsParser) CanHandle(line string) bool {
	return strings.HasPrefix(strings.ToLower(line), `[adblock`) ||
		strings.HasPrefix(line, "@@") || // exception
		strings.HasPrefix(line, "||") || // domain anchor
		strings.HasPrefix(line, "*") // wildcard
}
|
||||
|
||||
func (adblockDomainsParser) ParseDomains(r io.Reader) (domains []string, ignored int, err error) {
|
||||
scanner := bufio.NewScanner(r)
|
||||
for scanner.Scan() {
|
||||
line := strings.TrimSpace(scanner.Text())
|
||||
if isComment(line) {
|
||||
continue
|
||||
}
|
||||
|
||||
// Common AdBlock patterns:
|
||||
// ||domain.com^
|
||||
// |http://domain.com|
|
||||
// domain.com/path
|
||||
// *domain.com*
|
||||
switch {
|
||||
case strings.HasPrefix(line, `||`): // domain anchor
|
||||
if i := strings.IndexByte(line, '^'); i != -1 {
|
||||
domains = append(domains, line[2:i])
|
||||
continue
|
||||
}
|
||||
case strings.HasPrefix(line, `|`) && strings.HasSuffix(line, `|`):
|
||||
domains = append(domains, line[1:len(line)-2])
|
||||
continue
|
||||
case strings.HasPrefix(line, `[`):
|
||||
continue
|
||||
}
|
||||
ignored++
|
||||
}
|
||||
if err = scanner.Err(); err != nil {
|
||||
return
|
||||
}
|
||||
return unique(domains), ignored, nil
|
||||
}
|
41
dataset/parser/adblock_test.go
Normal file
41
dataset/parser/adblock_test.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package parser
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"sort"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestAdBlockParser(t *testing.T) {
|
||||
test := `[Adblock Plus 2.0]
|
||||
! Title: AdRules DNS List
|
||||
! Homepage: https://github.com/Cats-Team/AdRules
|
||||
! Powerd by Cats-Team
|
||||
! Expires: 1 (update frequency)
|
||||
! Description: The DNS Filters
|
||||
! Total count: 145270
|
||||
! Update: 2025-10-07 02:05:08(GMT+8)
|
||||
/^.+stat\.kugou\.com/
|
||||
/^admarvel\./
|
||||
||*-ad-sign.byteimg.com^
|
||||
||*-ad.a.yximgs.com^
|
||||
||*-applog.fqnovel.com^
|
||||
||*-datareceiver.aki-game.net^
|
||||
||*.exaapi.com^`
|
||||
want := []string{"*-ad-sign.byteimg.com", "*-ad.a.yximgs.com", "*-applog.fqnovel.com", "*-datareceiver.aki-game.net", "*.exaapi.com"}
|
||||
|
||||
parsed, ignored, err := ParseDomains(strings.NewReader(test))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
sort.Strings(parsed)
|
||||
if !reflect.DeepEqual(parsed, want) {
|
||||
t.Errorf("expected ParseDomains(domains) to return %v, got %v", want, parsed)
|
||||
}
|
||||
if ignored != 2 {
|
||||
t.Errorf("expected 2 ignored, got %d", ignored)
|
||||
}
|
||||
}
|
139
dataset/parser/dns.go
Normal file
139
dataset/parser/dns.go
Normal file
@@ -0,0 +1,139 @@
|
||||
package parser
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"io"
|
||||
"strings"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
func init() {
|
||||
RegisterDomainsParser(dnsmasqDomainsParser{})
|
||||
RegisterDomainsParser(mosDNSDomainsParser{})
|
||||
RegisterDomainsParser(smartDNSDomainsParser{})
|
||||
RegisterDomainsParser(unboundDomainsParser{})
|
||||
}
|
||||
|
||||
// dnsmasqDomainsParser parses "address=/domain/ip" configuration lines.
type dnsmasqDomainsParser struct{}

// CanHandle reports whether line starts with "address=/".
func (dnsmasqDomainsParser) CanHandle(line string) bool {
	const prefix = "address=/"
	return strings.HasPrefix(line, prefix)
}
|
||||
|
||||
func (dnsmasqDomainsParser) ParseDomains(r io.Reader) (domains []string, ignored int, err error) {
|
||||
scanner := bufio.NewScanner(r)
|
||||
for scanner.Scan() {
|
||||
line := strings.TrimSpace(scanner.Text())
|
||||
if isComment(line) {
|
||||
continue
|
||||
}
|
||||
switch {
|
||||
case strings.HasPrefix(line, "address=/"):
|
||||
part := strings.FieldsFunc(line, func(r rune) bool { return r == '/' })
|
||||
if len(part) >= 3 && isDomainName(part[1]) {
|
||||
domains = append(domains, part[1])
|
||||
continue
|
||||
}
|
||||
}
|
||||
ignored++
|
||||
}
|
||||
if err = scanner.Err(); err != nil {
|
||||
return
|
||||
}
|
||||
return unique(domains), ignored, nil
|
||||
}
|
||||
|
||||
type mosDNSDomainsParser struct{}
|
||||
|
||||
func (mosDNSDomainsParser) CanHandle(line string) bool {
|
||||
if strings.HasPrefix(line, "domain:") {
|
||||
return isDomainName(line[7:])
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (mosDNSDomainsParser) ParseDomains(r io.Reader) (domains []string, ignored int, err error) {
|
||||
scanner := bufio.NewScanner(r)
|
||||
for scanner.Scan() {
|
||||
line := strings.TrimSpace(scanner.Text())
|
||||
if isComment(line) {
|
||||
continue
|
||||
}
|
||||
if strings.HasPrefix(line, "domain:") {
|
||||
domains = append(domains, line[7:])
|
||||
continue
|
||||
}
|
||||
ignored++
|
||||
}
|
||||
if err = scanner.Err(); err != nil {
|
||||
return
|
||||
}
|
||||
return unique(domains), ignored, nil
|
||||
}
|
||||
|
||||
// smartDNSDomainsParser parses "address /domain/target" configuration lines.
type smartDNSDomainsParser struct{}

// CanHandle reports whether line starts with "address /".
func (smartDNSDomainsParser) CanHandle(line string) bool {
	const prefix = "address /"
	return strings.HasPrefix(line, prefix)
}
|
||||
|
||||
func (smartDNSDomainsParser) ParseDomains(r io.Reader) (domains []string, ignored int, err error) {
|
||||
scanner := bufio.NewScanner(r)
|
||||
for scanner.Scan() {
|
||||
line := strings.TrimSpace(scanner.Text())
|
||||
if isComment(line) {
|
||||
continue
|
||||
}
|
||||
if strings.HasPrefix(line, "address /") {
|
||||
if i := strings.IndexByte(line[9:], '/'); i > -1 {
|
||||
domains = append(domains, line[9:i+9])
|
||||
continue
|
||||
}
|
||||
}
|
||||
ignored++
|
||||
}
|
||||
if err = scanner.Err(); err != nil {
|
||||
return
|
||||
}
|
||||
return unique(domains), ignored, nil
|
||||
}
|
||||
|
||||
// unboundDomainsParser parses "local-data:" and "local-zone:" lines.
type unboundDomainsParser struct{}

// CanHandle reports whether line starts with "local-data:" or "local-zone:".
func (unboundDomainsParser) CanHandle(line string) bool {
	for _, prefix := range [...]string{"local-data:", "local-zone:"} {
		if strings.HasPrefix(line, prefix) {
			return true
		}
	}
	return false
}
|
||||
|
||||
func (unboundDomainsParser) ParseDomains(r io.Reader) (domains []string, ignored int, err error) {
|
||||
scanner := bufio.NewScanner(r)
|
||||
for scanner.Scan() {
|
||||
line := strings.TrimSpace(scanner.Text())
|
||||
if isComment(line) {
|
||||
continue
|
||||
}
|
||||
switch {
|
||||
case strings.HasPrefix(line, "local-data:"):
|
||||
record := strings.Trim(strings.TrimSpace(line[11:]), `"`)
|
||||
if rr, err := dns.NewRR(record); err == nil {
|
||||
switch rr.Header().Rrtype {
|
||||
case dns.TypeA, dns.TypeAAAA, dns.TypeCNAME:
|
||||
domains = append(domains, strings.Trim(rr.Header().Name, `.`))
|
||||
continue
|
||||
}
|
||||
}
|
||||
case strings.HasPrefix(line, "local-zone:") && strings.HasSuffix(line, " reject"):
|
||||
line = strings.Trim(strings.TrimSpace(line[11:]), `"`)
|
||||
if i := strings.IndexByte(line, '"'); i > -1 {
|
||||
domains = append(domains, line[:i])
|
||||
continue
|
||||
}
|
||||
}
|
||||
ignored++
|
||||
}
|
||||
if err = scanner.Err(); err != nil {
|
||||
return
|
||||
}
|
||||
return unique(domains), ignored, nil
|
||||
}
|
106
dataset/parser/dns_test.go
Normal file
106
dataset/parser/dns_test.go
Normal file
@@ -0,0 +1,106 @@
|
||||
package parser
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"sort"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestDNSMasqParser(t *testing.T) {
|
||||
tests := []struct {
|
||||
Name string
|
||||
Test string
|
||||
Want []string
|
||||
WantIgnored int
|
||||
}{
|
||||
{
|
||||
"data",
|
||||
`
|
||||
local-data: "junk1.doubleclick.net A 127.0.0.1"
|
||||
local-data: "junk2.doubleclick.net A 127.0.0.1"
|
||||
local-data: "junk2.doubleclick.net CNAME doubleclick.net."
|
||||
local-data: "junk6.doubleclick.net AAAA ::1"
|
||||
local-data: "doubleclick.net A 127.0.0.1"
|
||||
local-data: "ad.junk1.doubleclick.net A 127.0.0.1"
|
||||
local-data: "adjunk.google.com A 127.0.0.1"`,
|
||||
[]string{"ad.junk1.doubleclick.net", "adjunk.google.com", "doubleclick.net", "junk1.doubleclick.net", "junk2.doubleclick.net", "junk6.doubleclick.net"},
|
||||
0,
|
||||
},
|
||||
{
|
||||
"zone",
|
||||
`
|
||||
local-zone: "doubleclick.net" reject
|
||||
local-zone: "adjunk.google.com" reject`,
|
||||
[]string{"adjunk.google.com", "doubleclick.net"},
|
||||
0,
|
||||
},
|
||||
{
|
||||
"address",
|
||||
`
|
||||
address=/ziyu.net/0.0.0.0
|
||||
address=/zlp6s.pw/0.0.0.0
|
||||
address=/zm232.com/0.0.0.0
|
||||
`,
|
||||
[]string{"ziyu.net", "zlp6s.pw", "zm232.com"},
|
||||
0,
|
||||
},
|
||||
}
|
||||
for _, test := range tests {
|
||||
t.Run(test.Name, func(it *testing.T) {
|
||||
parsed, ignored, err := ParseDomains(strings.NewReader(test.Test))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
sort.Strings(parsed)
|
||||
if !reflect.DeepEqual(parsed, test.Want) {
|
||||
t.Errorf("expected ParseDomains(dnsmasq) to return\n\t%v, got\n\t%v", test.Want, parsed)
|
||||
}
|
||||
if ignored != test.WantIgnored {
|
||||
t.Errorf("expected %d ignored, got %d", test.WantIgnored, ignored)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMOSDNSParser(t *testing.T) {
|
||||
test := `domain:0019x.com
|
||||
domain:002777.xyz
|
||||
domain:003store.com
|
||||
domain:00404850.xyz`
|
||||
want := []string{"0019x.com", "002777.xyz", "003store.com", "00404850.xyz"}
|
||||
|
||||
parsed, _, err := ParseDomains(strings.NewReader(test))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
sort.Strings(parsed)
|
||||
if !reflect.DeepEqual(parsed, want) {
|
||||
t.Errorf("expected ParseDomains(domains) to return %v, got %v", want, parsed)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSmartDNSParser(t *testing.T) {
|
||||
test := `# Title:AdRules SmartDNS List
|
||||
# Update: 2025-10-07 02:05:08(GMT+8)
|
||||
address /0.myikas.com/#
|
||||
address /0.net.easyjet.com/#
|
||||
address /0.nextyourcontent.com/#
|
||||
address /0019x.com/#`
|
||||
want := []string{"0.myikas.com", "0.net.easyjet.com", "0.nextyourcontent.com", "0019x.com"}
|
||||
|
||||
parsed, _, err := ParseDomains(strings.NewReader(test))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
sort.Strings(parsed)
|
||||
if !reflect.DeepEqual(parsed, want) {
|
||||
t.Errorf("expected ParseDomains(domains) to return %v, got %v", want, parsed)
|
||||
}
|
||||
}
|
40
dataset/parser/domains.go
Normal file
40
dataset/parser/domains.go
Normal file
@@ -0,0 +1,40 @@
|
||||
package parser
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"io"
|
||||
"net"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func init() {
|
||||
domainsParsers = append(domainsParsers, domainsParser{})
|
||||
}
|
||||
|
||||
type domainsParser struct{}
|
||||
|
||||
func (domainsParser) CanHandle(line string) bool {
|
||||
return isDomainName(line) &&
|
||||
!strings.ContainsRune(line, ' ') &&
|
||||
!strings.ContainsRune(line, ':') &&
|
||||
net.ParseIP(line) == nil
|
||||
}
|
||||
|
||||
func (domainsParser) ParseDomains(r io.Reader) (domains []string, ignored int, err error) {
|
||||
scanner := bufio.NewScanner(r)
|
||||
for scanner.Scan() {
|
||||
line := strings.TrimSpace(scanner.Text())
|
||||
if isComment(line) {
|
||||
continue
|
||||
}
|
||||
if isDomainName(line) {
|
||||
domains = append(domains, line)
|
||||
continue
|
||||
}
|
||||
ignored++
|
||||
}
|
||||
if err = scanner.Err(); err != nil {
|
||||
return
|
||||
}
|
||||
return unique(domains), ignored, nil
|
||||
}
|
31
dataset/parser/domains_test.go
Normal file
31
dataset/parser/domains_test.go
Normal file
@@ -0,0 +1,31 @@
|
||||
package parser
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"sort"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestParseDomains(t *testing.T) {
|
||||
test := `# This is a comment
|
||||
facebook.com
|
||||
tiktok.com
|
||||
bogus ignored
|
||||
youtube.com`
|
||||
want := []string{"facebook.com", "tiktok.com", "youtube.com"}
|
||||
|
||||
parsed, ignored, err := ParseDomains(strings.NewReader(test))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
sort.Strings(parsed)
|
||||
if !reflect.DeepEqual(parsed, want) {
|
||||
t.Errorf("expected ParseDomains(domains) to return %v, got %v", want, parsed)
|
||||
}
|
||||
if ignored != 1 {
|
||||
t.Errorf("expected 1 ignored, got %d", ignored)
|
||||
}
|
||||
}
|
41
dataset/parser/hosts.go
Normal file
41
dataset/parser/hosts.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package parser
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"io"
|
||||
"net"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func init() {
|
||||
RegisterDomainsParser(hostsParser{})
|
||||
}
|
||||
|
||||
// hostsParser parses hosts-file style lines: an IP address followed by one or
// more names.
type hostsParser struct{}

// CanHandle reports whether line has at least two whitespace-separated fields
// and the first one parses as an IP address.
func (hostsParser) CanHandle(line string) bool {
	fields := strings.Fields(line)
	if len(fields) < 2 {
		return false
	}
	return net.ParseIP(fields[0]) != nil
}
|
||||
|
||||
func (hostsParser) ParseDomains(r io.Reader) (domains []string, ignored int, err error) {
|
||||
scanner := bufio.NewScanner(r)
|
||||
for scanner.Scan() {
|
||||
line := strings.TrimSpace(scanner.Text())
|
||||
if isComment(line) {
|
||||
continue
|
||||
}
|
||||
|
||||
part := strings.Fields(line)
|
||||
if len(part) >= 2 && net.ParseIP(part[0]) != nil {
|
||||
domains = append(domains, part[1:]...)
|
||||
continue
|
||||
}
|
||||
|
||||
ignored++
|
||||
}
|
||||
if err = scanner.Err(); err != nil {
|
||||
return
|
||||
}
|
||||
return unique(domains), ignored, nil
|
||||
}
|
38
dataset/parser/hosts_test.go
Normal file
38
dataset/parser/hosts_test.go
Normal file
@@ -0,0 +1,38 @@
|
||||
package parser
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"sort"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestParseHosts(t *testing.T) {
|
||||
test := `##
|
||||
# Host Database
|
||||
#
|
||||
# localhost is used to configure the loopback interface
|
||||
# when the system is booting. Do not change this entry.
|
||||
##
|
||||
127.0.0.1 localhost dragon dragon.local dragon.maze.network
|
||||
255.255.255.255 broadcasthost
|
||||
::1 localhost
|
||||
ff00::1 multicast
|
||||
1.2.3.4
|
||||
`
|
||||
want := []string{"broadcasthost", "dragon", "dragon.local", "dragon.maze.network", "localhost", "multicast"}
|
||||
|
||||
parsed, ignored, err := ParseDomains(strings.NewReader(test))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
sort.Strings(parsed)
|
||||
if !reflect.DeepEqual(parsed, want) {
|
||||
t.Errorf("expected ParseDomains(hosts) to return %v, got %v", want, parsed)
|
||||
}
|
||||
if ignored != 1 {
|
||||
t.Errorf("expected 1 ignored, got %d", ignored)
|
||||
}
|
||||
}
|
76
dataset/parser/parser.go
Normal file
76
dataset/parser/parser.go
Normal file
@@ -0,0 +1,76 @@
|
||||
package parser
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"errors"
|
||||
"io"
|
||||
"log"
|
||||
"strings"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// ErrNoParser is returned by ParseDomains when no registered parser accepts
// the input.
var ErrNoParser = errors.New("no suitable parser could be found")

// Parser is the detection half of a list parser.
type Parser interface {
	// CanHandle reports whether the parser recognises line.
	CanHandle(line string) bool
}

// DomainsParser is a Parser that can extract domains from a reader.
type DomainsParser interface {
	Parser

	// ParseDomains returns the extracted domains, the number of lines that
	// were ignored, and any read error.
	ParseDomains(io.Reader) (domains []string, ignored int, err error)
}
|
||||
|
||||
var domainsParsers []DomainsParser
|
||||
|
||||
func RegisterDomainsParser(parser DomainsParser) {
|
||||
domainsParsers = append(domainsParsers, parser)
|
||||
}
|
||||
|
||||
func ParseDomains(r io.Reader) (domains []string, ignored int, err error) {
|
||||
var (
|
||||
buffer = new(bytes.Buffer)
|
||||
scanner = bufio.NewScanner(io.TeeReader(r, buffer))
|
||||
line string
|
||||
parser DomainsParser
|
||||
)
|
||||
for scanner.Scan() {
|
||||
line = strings.TrimSpace(scanner.Text())
|
||||
if isComment(line) {
|
||||
continue
|
||||
}
|
||||
for _, parser = range domainsParsers {
|
||||
if parser.CanHandle(line) {
|
||||
log.Printf("using parser %T", parser)
|
||||
return parser.ParseDomains(io.MultiReader(buffer, r))
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
return nil, 0, ErrNoParser
|
||||
}
|
||||
|
||||
// isComment reports whether line should be skipped: it is empty or begins
// with '#' or '!'.
func isComment(line string) bool {
	if len(line) == 0 {
		return true
	}
	switch line[0] {
	case '#', '!':
		return true
	}
	return false
}
|
||||
|
||||
func isDomainName(name string) bool {
|
||||
n, ok := dns.IsDomainName(name)
|
||||
return n >= 2 && ok
|
||||
}
|
||||
|
||||
// unique returns the distinct values of items in order of first appearance.
// A nil input yields nil; an empty non-nil input yields an empty non-nil
// slice. The input slice is not modified.
//
// The parameter is not called "strings" so that it does not shadow the
// strings package inside the function.
func unique(items []string) []string {
	if items == nil {
		return nil
	}
	seen := make(map[string]struct{}, len(items))
	out := make([]string, 0, len(items))
	for _, s := range items {
		if _, dup := seen[s]; dup {
			continue
		}
		seen[s] = struct{}{}
		out = append(out, s)
	}
	return out
}
|
31
dataset/parser/parser_test.go
Normal file
31
dataset/parser/parser_test.go
Normal file
@@ -0,0 +1,31 @@
|
||||
package parser
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"sort"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestUnique(t *testing.T) {
|
||||
tests := []struct {
|
||||
Name string
|
||||
Test []string
|
||||
Want []string
|
||||
}{
|
||||
{"nil", nil, nil},
|
||||
{"single", []string{"test"}, []string{"test"}},
|
||||
{"duplicate", []string{"test", "test"}, []string{"test"}},
|
||||
{"multiple", []string{"a", "a", "b", "b", "b", "c"}, []string{"a", "b", "c"}},
|
||||
}
|
||||
for _, test := range tests {
|
||||
t.Run(test.Name, func(it *testing.T) {
|
||||
v := unique(test.Test)
|
||||
if v != nil {
|
||||
sort.Strings(v)
|
||||
}
|
||||
if !reflect.DeepEqual(v, test.Want) {
|
||||
it.Errorf("expected unique(%v) to return %v, got %v", test.Test, test.Want, v)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
231
dataset/storage.go
Normal file
231
dataset/storage.go
Normal file
@@ -0,0 +1,231 @@
|
||||
package dataset
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3" // SQLite3 driver
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
type Storage interface {
|
||||
Groups() ([]Group, error)
|
||||
GroupByID(int64) (Group, error)
|
||||
GroupByName(name string) (Group, error)
|
||||
SaveGroup(*Group) error
|
||||
DeleteGroup(Group) error
|
||||
|
||||
Clients() (Clients, error)
|
||||
ClientByID(int64) (Client, error)
|
||||
ClientByIP(net.IP) (Client, error)
|
||||
SaveClient(*Client) error
|
||||
DeleteClient(Client) error
|
||||
|
||||
Lists() ([]List, error)
|
||||
ListByID(int64) (List, error)
|
||||
SaveList(*List) error
|
||||
DeleteList(List) error
|
||||
}
|
||||
|
||||
type Group struct {
|
||||
ID int64 `json:"id"`
|
||||
Name string `json:"name" bstore:"nonzero,unique"`
|
||||
IsEnabled bool `json:"is_enabled" bstore:"nonzero"`
|
||||
Description string `json:"description"`
|
||||
CreatedAt time.Time `json:"created_at" bstore:"nonzero"`
|
||||
UpdatedAt time.Time `json:"updated_at" bstore:"nonzero"`
|
||||
Storage Storage `json:"-" bstore:"-"`
|
||||
}
|
||||
|
||||
type Client struct {
|
||||
ID int64 `json:"id"`
|
||||
Network string `json:"network" bstore:"nonzero,index"`
|
||||
IP string `json:"ip" bstore:"nonzero,unique IP+Mask"`
|
||||
Mask int `json:"mask"`
|
||||
Description string `json:"description"`
|
||||
Groups []Group `json:"groups,omitempty" bstore:"-"`
|
||||
CreatedAt time.Time `json:"created_at" bstore:"nonzero"`
|
||||
UpdatedAt time.Time `json:"updated_at" bstore:"nonzero"`
|
||||
Storage Storage `json:"-" bstore:"-"`
|
||||
}
|
||||
|
||||
type WithClient interface {
|
||||
Client() (Client, error)
|
||||
}
|
||||
|
||||
// ClientGroup links one Client to one Group. Both foreign keys are indexed
// references.
type ClientGroup struct {
	ID       int64 `json:"id"`
	ClientID int64 `json:"client_id" bstore:"ref Client,index"`
	GroupID  int64 `json:"group_id" bstore:"ref Group,index"`
}
|
||||
|
||||
func (c *Client) ContainsIP(ip net.IP) bool {
|
||||
ipnet := &net.IPNet{
|
||||
IP: net.ParseIP(c.IP),
|
||||
Mask: net.CIDRMask(int(c.Mask), 32),
|
||||
}
|
||||
if ipnet.IP == nil {
|
||||
return false
|
||||
}
|
||||
return ipnet.Contains(ip)
|
||||
}
|
||||
|
||||
func (c *Client) String() string {
|
||||
ipnet := &net.IPNet{
|
||||
IP: net.ParseIP(c.IP),
|
||||
Mask: net.CIDRMask(int(c.Mask), 32),
|
||||
}
|
||||
return ipnet.String()
|
||||
}
|
||||
|
||||
type Clients []Client
|
||||
|
||||
func (cs Clients) ByIP(ip net.IP) *Client {
|
||||
var candidates []*Client
|
||||
for _, c := range cs {
|
||||
if c.ContainsIP(ip) {
|
||||
candidates = append(candidates, &c)
|
||||
}
|
||||
}
|
||||
switch len(candidates) {
|
||||
case 0:
|
||||
return nil
|
||||
case 1:
|
||||
return candidates[0]
|
||||
default:
|
||||
slices.SortStableFunc(candidates, func(a, b *Client) int {
|
||||
return int(b.Mask) - int(a.Mask)
|
||||
})
|
||||
return candidates[0]
|
||||
}
|
||||
}
|
||||
|
||||
// Values for List.Type.
const (
	ListTypeDomain  = "domain"
	ListTypeNetwork = "network"
)

// Refresh intervals for lists.
const (
	MinListRefresh     = time.Minute
	DefaultListRefresh = 30 * time.Minute
)
|
||||
|
||||
type List struct {
|
||||
ID int64 `json:"id"`
|
||||
Type string `json:"type"`
|
||||
Source string `json:"source"`
|
||||
IsEnabled bool `json:"is_enabled"`
|
||||
Permit bool `json:"permit"`
|
||||
Groups []Group `json:"groups,omitempty" bstore:"-"`
|
||||
Status int `json:"status"`
|
||||
Comment string `json:"comment"`
|
||||
Cache []byte `json:"cache"`
|
||||
Refresh time.Duration `json:"refresh"`
|
||||
LastModified time.Time `json:"last_modified"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
func (list *List) Domains() (*DomainTree, error) {
|
||||
if list.Type != ListTypeDomain {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var (
|
||||
tree = NewDomainList()
|
||||
scan = bufio.NewScanner(bytes.NewReader(list.Cache))
|
||||
)
|
||||
for scan.Scan() {
|
||||
line := strings.TrimSpace(scan.Text())
|
||||
if line == "" || line[0] == '#' {
|
||||
continue
|
||||
}
|
||||
if labels, ok := dns.IsDomainName(line); ok && labels >= 2 {
|
||||
tree.Add(line)
|
||||
}
|
||||
}
|
||||
if err := scan.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return tree, nil
|
||||
}
|
||||
|
||||
func (list *List) Update() (updated bool, err error) {
|
||||
u, err := url.Parse(list.Source)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
switch u.Scheme {
|
||||
case "", "file":
|
||||
return list.updateFile(u.Path)
|
||||
case "http", "https":
|
||||
return list.updateHTTP(u.String())
|
||||
default:
|
||||
return false, fmt.Errorf("dataset: don't know how to update %s sources", u.Scheme)
|
||||
}
|
||||
}
|
||||
|
||||
func (list *List) updateFile(name string) (updated bool, err error) {
|
||||
var info fs.FileInfo
|
||||
if info, err = os.Stat(name); err != nil {
|
||||
return
|
||||
} else if info.IsDir() {
|
||||
return false, fmt.Errorf("dataset: list %d: %q is a directory", list.ID, name)
|
||||
}
|
||||
if updated = info.ModTime().After(list.UpdatedAt); !updated {
|
||||
return
|
||||
}
|
||||
list.Cache, _ = os.ReadFile(name)
|
||||
return
|
||||
}
|
||||
|
||||
func (list *List) updateHTTP(url string) (updated bool, err error) {
|
||||
if updated, err = list.shouldUpdateHTTP(url); err != nil || !updated {
|
||||
return
|
||||
}
|
||||
|
||||
var response *http.Response
|
||||
if response, err = http.DefaultClient.Get(url); err != nil {
|
||||
return
|
||||
}
|
||||
defer response.Body.Close()
|
||||
if list.Cache, err = io.ReadAll(response.Body); err != nil {
|
||||
return
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (list *List) shouldUpdateHTTP(url string) (updated bool, err error) {
|
||||
var response *http.Response
|
||||
if response, err = http.DefaultClient.Head(url); err != nil {
|
||||
return
|
||||
}
|
||||
defer response.Body.Close()
|
||||
|
||||
if value := response.Header.Get("Last-Modified"); value != "" {
|
||||
var lastModified time.Time
|
||||
if lastModified, err = time.Parse(http.TimeFormat, value); err == nil {
|
||||
return lastModified.After(list.LastModified), nil
|
||||
}
|
||||
}
|
||||
|
||||
// There are no headers that would indicate last-modified time, so assume we have to update:
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// ListGroup links one List to one Group. Both foreign keys are indexed
// references.
type ListGroup struct {
	ID      int64 `json:"id"`
	ListID  int64 `json:"list_id" bstore:"ref List,index"`
	GroupID int64 `json:"group_id" bstore:"ref Group,index"`
}
|
412
dataset/storage_bstore.go
Normal file
412
dataset/storage_bstore.go
Normal file
@@ -0,0 +1,412 @@
|
||||
package dataset
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"git.maze.io/maze/styx/logger"
|
||||
"github.com/mjl-/bstore"
|
||||
)
|
||||
|
||||
type bstoreStorage struct {
|
||||
db *bstore.DB
|
||||
path string
|
||||
}
|
||||
|
||||
func OpenBStore(name string) (Storage, error) {
|
||||
if !filepath.IsAbs(name) {
|
||||
var err error
|
||||
if name, err = filepath.Abs(name); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
db, err := bstore.Open(ctx, name, nil,
|
||||
Group{},
|
||||
Client{},
|
||||
ClientGroup{},
|
||||
List{},
|
||||
ListGroup{},
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var (
|
||||
s = &bstoreStorage{db: db, path: name}
|
||||
defaultGroup Group
|
||||
defaultClient4 Client
|
||||
defaultClient6 Client
|
||||
)
|
||||
|
||||
if defaultGroup, err = s.GroupByName("Default"); errors.Is(err, bstore.ErrAbsent) {
|
||||
defaultGroup = Group{
|
||||
Name: "Default",
|
||||
IsEnabled: true,
|
||||
Description: "Default group",
|
||||
}
|
||||
if err = s.SaveGroup(&defaultGroup); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if defaultClient4, err = bstore.QueryDB[Client](ctx, db).
|
||||
FilterEqual("Network", "ipv4").
|
||||
FilterFn(func(client Client) bool {
|
||||
return net.ParseIP(client.IP).Equal(net.ParseIP("0.0.0.0")) && client.Mask == 0
|
||||
}).Get(); errors.Is(err, bstore.ErrAbsent) {
|
||||
defaultClient4 = Client{
|
||||
Network: "ipv4",
|
||||
IP: "0.0.0.0",
|
||||
Mask: 0,
|
||||
Description: "All IPv4 clients",
|
||||
}
|
||||
if err = s.SaveClient(&defaultClient4); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = s.db.Insert(ctx, &ClientGroup{ClientID: defaultClient4.ID, GroupID: defaultGroup.ID}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if defaultClient6, err = bstore.QueryDB[Client](ctx, db).
|
||||
FilterEqual("Network", "ipv6").
|
||||
FilterFn(func(client Client) bool {
|
||||
return net.ParseIP(client.IP).Equal(net.ParseIP("::")) && client.Mask == 0
|
||||
}).Get(); errors.Is(err, bstore.ErrAbsent) {
|
||||
defaultClient6 = Client{
|
||||
Network: "ipv6",
|
||||
IP: "::",
|
||||
Mask: 0,
|
||||
Description: "All IPv6 clients",
|
||||
}
|
||||
if err = s.SaveClient(&defaultClient6); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = s.db.Insert(ctx, &ClientGroup{ClientID: defaultClient6.ID, GroupID: defaultGroup.ID}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Start updater
|
||||
NewUpdater(s)
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (s *bstoreStorage) log() logger.Structured {
|
||||
return logger.StandardLog.Values(logger.Values{
|
||||
"storage": "bstore",
|
||||
"storage_path": s.path,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *bstoreStorage) Groups() ([]Group, error) {
|
||||
var (
|
||||
ctx = context.Background()
|
||||
query = bstore.QueryDB[Group](ctx, s.db)
|
||||
groups = make([]Group, 0)
|
||||
)
|
||||
for group := range query.All() {
|
||||
groups = append(groups, group)
|
||||
}
|
||||
if err := query.Err(); err != nil && !errors.Is(err, bstore.ErrFinished) {
|
||||
return nil, err
|
||||
}
|
||||
return groups, nil
|
||||
}
|
||||
|
||||
func (s *bstoreStorage) GroupByID(id int64) (Group, error) {
|
||||
ctx := context.Background()
|
||||
return bstore.QueryDB[Group](ctx, s.db).FilterID(id).Get()
|
||||
}
|
||||
|
||||
func (s *bstoreStorage) GroupByName(name string) (Group, error) {
|
||||
ctx := context.Background()
|
||||
return bstore.QueryDB[Group](ctx, s.db).FilterFn(func(group Group) bool {
|
||||
return strings.EqualFold(group.Name, name)
|
||||
}).Get()
|
||||
}
|
||||
|
||||
func (s *bstoreStorage) SaveGroup(group *Group) (err error) {
|
||||
ctx := context.Background()
|
||||
group.UpdatedAt = time.Now().UTC()
|
||||
if group.CreatedAt.Equal(time.Time{}) {
|
||||
group.CreatedAt = group.UpdatedAt
|
||||
err = s.db.Insert(ctx, group)
|
||||
} else {
|
||||
err = s.db.Update(ctx, group)
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("dataset: save group %s failed: %w", group.Name, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *bstoreStorage) DeleteGroup(group Group) (err error) {
|
||||
ctx := context.Background()
|
||||
tx, err := s.db.Begin(ctx, true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err = bstore.QueryTx[ClientGroup](tx).FilterEqual("GroupID", group.ID).Delete(); err != nil {
|
||||
return
|
||||
}
|
||||
if _, err = bstore.QueryTx[ListGroup](tx).FilterEqual("GroupID", group.ID).Delete(); err != nil {
|
||||
return
|
||||
}
|
||||
if err = tx.Delete(group); err != nil {
|
||||
return
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func (s *bstoreStorage) Clients() (Clients, error) {
|
||||
var (
|
||||
ctx = context.Background()
|
||||
query = bstore.QueryDB[Client](ctx, s.db)
|
||||
clients = make(Clients, 0)
|
||||
)
|
||||
for client := range query.All() {
|
||||
clients = append(clients, client)
|
||||
}
|
||||
if err := query.Err(); err != nil && !errors.Is(err, bstore.ErrFinished) {
|
||||
return nil, err
|
||||
}
|
||||
return clients, nil
|
||||
}
|
||||
|
||||
func (s *bstoreStorage) ClientByID(id int64) (Client, error) {
|
||||
ctx := context.Background()
|
||||
client, err := bstore.QueryDB[Client](ctx, s.db).FilterID(id).Get()
|
||||
if err != nil {
|
||||
return client, err
|
||||
}
|
||||
return s.clientResolveGroups(ctx, client)
|
||||
}
|
||||
|
||||
func (s *bstoreStorage) ClientByIP(ip net.IP) (Client, error) {
|
||||
if ip == nil {
|
||||
return Client{}, ErrNotExist{Object: "client"}
|
||||
}
|
||||
var (
|
||||
ctx = context.Background()
|
||||
clients Clients
|
||||
network string
|
||||
)
|
||||
if ip4 := ip.To4(); ip4 != nil {
|
||||
network = "ipv4"
|
||||
} else if ip6 := ip.To16(); ip6 != nil {
|
||||
network = "ipv6"
|
||||
}
|
||||
if network == "" {
|
||||
return Client{}, ErrNotExist{Object: "client"}
|
||||
}
|
||||
for client, err := range bstore.QueryDB[Client](ctx, s.db).
|
||||
FilterEqual("Network", network).
|
||||
FilterFn(func(client Client) bool {
|
||||
return client.ContainsIP(ip)
|
||||
}).All() {
|
||||
if err != nil {
|
||||
return Client{}, err
|
||||
}
|
||||
clients = append(clients, client)
|
||||
}
|
||||
|
||||
var client Client
|
||||
switch len(clients) {
|
||||
case 0:
|
||||
return Client{}, ErrNotExist{Object: "client"}
|
||||
case 1:
|
||||
client = clients[0]
|
||||
default:
|
||||
slices.SortStableFunc(clients, func(a, b Client) int {
|
||||
return int(b.Mask) - int(a.Mask)
|
||||
})
|
||||
client = clients[0]
|
||||
}
|
||||
return s.clientResolveGroups(ctx, client)
|
||||
}
|
||||
|
||||
func (s *bstoreStorage) clientResolveGroups(ctx context.Context, client Client) (Client, error) {
|
||||
for clientGroup, err := range bstore.QueryDB[ClientGroup](ctx, s.db).FilterEqual("ClientID", client.ID).All() {
|
||||
if err != nil {
|
||||
return Client{}, err
|
||||
}
|
||||
if group, err := s.GroupByID(clientGroup.GroupID); err == nil {
|
||||
client.Groups = append(client.Groups, group)
|
||||
}
|
||||
}
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func (s *bstoreStorage) SaveClient(client *Client) (err error) {
|
||||
log := s.log()
|
||||
ctx := context.Background()
|
||||
client.UpdatedAt = time.Now().UTC()
|
||||
|
||||
tx, err := s.db.Begin(ctx, true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log = log.Values(logger.Values{"ip": client.IP, "mask": client.Mask, "description": client.Description})
|
||||
if client.CreatedAt.Equal(time.Time{}) {
|
||||
log.Debug("Create client")
|
||||
client.CreatedAt = client.UpdatedAt
|
||||
if err = tx.Insert(client); err != nil {
|
||||
return fmt.Errorf("dataset: client insert failed: %w", err)
|
||||
}
|
||||
} else {
|
||||
log.Debug("Update client")
|
||||
if err = tx.Update(client); err != nil {
|
||||
return fmt.Errorf("dataset: client update failed: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
var deleted int
|
||||
if deleted, err = bstore.QueryTx[ClientGroup](tx).FilterEqual("ClientID", client.ID).Delete(); err != nil {
|
||||
return fmt.Errorf("dataset: client groups delete failed: %w", err)
|
||||
}
|
||||
log.Debugf("Deleted %d groups", deleted)
|
||||
log.Debugf("Linking %d groups", len(client.Groups))
|
||||
for _, group := range client.Groups {
|
||||
if err = tx.Insert(&ClientGroup{ClientID: client.ID, GroupID: group.ID}); err != nil {
|
||||
return fmt.Errorf("dataset: client groups insert failed: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func (s *bstoreStorage) DeleteClient(client Client) (err error) {
|
||||
ctx := context.Background()
|
||||
tx, err := s.db.Begin(ctx, true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err = bstore.QueryTx[ClientGroup](tx).FilterEqual("ClientID", client.ID).Delete(); err != nil {
|
||||
return
|
||||
}
|
||||
if err = tx.Delete(client); err != nil {
|
||||
return
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func (s *bstoreStorage) Lists() ([]List, error) {
|
||||
var (
|
||||
ctx = context.Background()
|
||||
query = bstore.QueryDB[List](ctx, s.db)
|
||||
lists = make([]List, 0)
|
||||
)
|
||||
for list := range query.All() {
|
||||
lists = append(lists, list)
|
||||
}
|
||||
if err := query.Err(); err != nil && !errors.Is(err, bstore.ErrFinished) {
|
||||
return nil, err
|
||||
}
|
||||
return lists, nil
|
||||
}
|
||||
|
||||
func (s *bstoreStorage) ListByID(id int64) (List, error) {
|
||||
ctx := context.Background()
|
||||
list, err := bstore.QueryDB[List](ctx, s.db).FilterID(id).Get()
|
||||
if err != nil {
|
||||
return list, err
|
||||
}
|
||||
return s.listResolveGroups(ctx, list)
|
||||
}
|
||||
|
||||
func (s *bstoreStorage) listResolveGroups(ctx context.Context, list List) (List, error) {
|
||||
for listGroup, err := range bstore.QueryDB[ListGroup](ctx, s.db).FilterEqual("ListID", list.ID).All() {
|
||||
if err != nil {
|
||||
return List{}, err
|
||||
}
|
||||
if group, err := s.GroupByID(listGroup.GroupID); err == nil {
|
||||
list.Groups = append(list.Groups, group)
|
||||
}
|
||||
}
|
||||
return list, nil
|
||||
}
|
||||
|
||||
func (s *bstoreStorage) SaveList(list *List) (err error) {
|
||||
if list.Type != ListTypeDomain && list.Type != ListTypeNetwork {
|
||||
return fmt.Errorf("storage: unknown list type %q", list.Type)
|
||||
}
|
||||
if list.Refresh == 0 {
|
||||
list.Refresh = DefaultListRefresh
|
||||
} else if list.Refresh < MinListRefresh {
|
||||
list.Refresh = MinListRefresh
|
||||
}
|
||||
list.UpdatedAt = time.Now().UTC()
|
||||
|
||||
ctx := context.Background()
|
||||
tx, err := s.db.Begin(ctx, true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log := s.log()
|
||||
log = log.Values(logger.Values{
|
||||
"type": list.Type,
|
||||
"source": list.Source,
|
||||
"is_enabled": list.IsEnabled,
|
||||
"status": list.Status,
|
||||
"cache": len(list.Cache),
|
||||
"refresh": list.Refresh,
|
||||
})
|
||||
|
||||
if list.CreatedAt.Equal(time.Time{}) {
|
||||
log.Debug("Creating list")
|
||||
list.CreatedAt = list.UpdatedAt
|
||||
if err = tx.Insert(list); err != nil {
|
||||
return fmt.Errorf("dataset: list insert failed: %w", err)
|
||||
}
|
||||
} else {
|
||||
log.Debug("Updating list")
|
||||
if err = tx.Update(list); err != nil {
|
||||
return fmt.Errorf("dataset: list update failed: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
var deleted int
|
||||
if deleted, err = bstore.QueryTx[ListGroup](tx).FilterEqual("ListID", list.ID).Delete(); err != nil {
|
||||
return fmt.Errorf("dataset: list groups delete failed: %w", err)
|
||||
}
|
||||
log.Debugf("Deleted %d groups", deleted)
|
||||
log.Debugf("Linking %d groups", len(list.Groups))
|
||||
for _, group := range list.Groups {
|
||||
if err = tx.Insert(&ListGroup{ListID: list.ID, GroupID: group.ID}); err != nil {
|
||||
return fmt.Errorf("dataset: list groups insert failed: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func (s *bstoreStorage) DeleteList(list List) (err error) {
|
||||
ctx := context.Background()
|
||||
tx, err := s.db.Begin(ctx, true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err = bstore.QueryTx[ListGroup](tx).FilterEqual("ListID", list.ID).Delete(); err != nil {
|
||||
return
|
||||
}
|
||||
if err = tx.Delete(list); err != nil {
|
||||
return
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
226
dataset/updater.go
Normal file
226
dataset/updater.go
Normal file
@@ -0,0 +1,226 @@
|
||||
package dataset
|
||||
|
||||
import (
	"bytes"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"os"
	"sync"
	"time"

	"git.maze.io/maze/styx/logger"
)
|
||||
|
||||
type Updater struct {
|
||||
storage Storage
|
||||
lists sync.Map // map[int64]List
|
||||
updaters sync.Map // map[int64]*updaterJob
|
||||
done chan struct{}
|
||||
}
|
||||
|
||||
func NewUpdater(storage Storage) *Updater {
|
||||
u := &Updater{
|
||||
storage: storage,
|
||||
done: make(chan struct{}, 1),
|
||||
}
|
||||
go u.refresh()
|
||||
return u
|
||||
}
|
||||
|
||||
func (u *Updater) Close() error {
|
||||
select {
|
||||
case <-u.done:
|
||||
return nil
|
||||
default:
|
||||
close(u.done)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (u *Updater) refresh() {
|
||||
check := time.NewTicker(time.Second)
|
||||
defer check.Stop()
|
||||
|
||||
var (
|
||||
log = logger.StandardLog
|
||||
)
|
||||
for {
|
||||
select {
|
||||
case <-u.done:
|
||||
log.Debug("Updater closing, stopping updaters...")
|
||||
u.updaters.Range(func(key, value any) bool {
|
||||
if value != nil {
|
||||
close(value.(*updaterJob).done)
|
||||
}
|
||||
return true
|
||||
})
|
||||
return
|
||||
|
||||
case now := <-check.C:
|
||||
u.check(now, log)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (u *Updater) check(now time.Time, log logger.Structured) (wait time.Duration) {
|
||||
log.Trace("Checking lists")
|
||||
lists, err := u.storage.Lists()
|
||||
if err != nil {
|
||||
log.Err(err).Error("Updater can't retrieve lists")
|
||||
return -1
|
||||
}
|
||||
|
||||
var missing = make(map[int64]bool)
|
||||
u.lists.Range(func(key, _ any) bool {
|
||||
log.Tracef("List %d has updater running", key)
|
||||
missing[key.(int64)] = true
|
||||
return true
|
||||
})
|
||||
for _, list := range lists {
|
||||
log.Tracef("List %d is active: %t", list.ID, list.IsEnabled)
|
||||
if !list.IsEnabled {
|
||||
continue
|
||||
}
|
||||
delete(missing, list.ID)
|
||||
if _, exists := u.lists.Load(list.ID); !exists {
|
||||
u.lists.Store(list.ID, list)
|
||||
updater := newUpdaterJob(u.storage, &list)
|
||||
u.updaters.Store(list.ID, updater)
|
||||
}
|
||||
}
|
||||
|
||||
for id := range missing {
|
||||
log.Tracef("List %d has updater running, but is no longer active, reaping...", id)
|
||||
if updater, ok := u.updaters.Load(id); ok {
|
||||
close(updater.(*updaterJob).done)
|
||||
u.updaters.Delete(id)
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
type updaterJob struct {
|
||||
storage Storage
|
||||
list *List
|
||||
done chan struct{}
|
||||
}
|
||||
|
||||
func newUpdaterJob(storage Storage, list *List) *updaterJob {
|
||||
job := &updaterJob{
|
||||
storage: storage,
|
||||
list: list,
|
||||
done: make(chan struct{}, 1),
|
||||
}
|
||||
go job.loop()
|
||||
return job
|
||||
}
|
||||
|
||||
func (job *updaterJob) loop() {
|
||||
var (
|
||||
ticker = time.NewTicker(job.list.Refresh)
|
||||
first = time.After(0)
|
||||
now time.Time
|
||||
log = logger.StandardLog.Values(logger.Values{
|
||||
"list": job.list.ID,
|
||||
"type": job.list.Type,
|
||||
})
|
||||
)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-job.done:
|
||||
log.Debug("List updater stopping")
|
||||
return
|
||||
|
||||
case now = <-ticker.C:
|
||||
case now = <-first:
|
||||
}
|
||||
|
||||
log.Debug("List updater running")
|
||||
if update, err := job.run(now); err != nil {
|
||||
log.Err(err).Error("List updater failed")
|
||||
} else if update {
|
||||
if err = job.storage.SaveList(job.list); err != nil {
|
||||
log.Err(err).Error("List updater save failed")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// run this updater
|
||||
func (job *updaterJob) run(now time.Time) (update bool, err error) {
|
||||
u, err := url.Parse(job.list.Source)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
log := logger.StandardLog.Values(logger.Values{
|
||||
"list": job.list.ID,
|
||||
"source": job.list.Source,
|
||||
})
|
||||
if u.Scheme == "" || u.Scheme == "file" {
|
||||
log.Debug("Updating list from file")
|
||||
return job.updateFile(u.Path)
|
||||
}
|
||||
log.Debug("Updating list from URL")
|
||||
return job.updateHTTP(u)
|
||||
}
|
||||
|
||||
func (job *updaterJob) updateFile(name string) (update bool, err error) {
|
||||
var b []byte
|
||||
if b, err = os.ReadFile(name); err != nil {
|
||||
return
|
||||
}
|
||||
if update = !bytes.Equal(b, job.list.Cache); update {
|
||||
job.list.Cache = b
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (job *updaterJob) updateHTTP(location *url.URL) (update bool, err error) {
|
||||
if update, err = job.shouldUpdateHTTP(location); err != nil || !update {
|
||||
return
|
||||
}
|
||||
var (
|
||||
req *http.Request
|
||||
res *http.Response
|
||||
)
|
||||
if req, err = http.NewRequest(http.MethodGet, location.String(), nil); err != nil {
|
||||
return
|
||||
}
|
||||
if res, err = http.DefaultClient.Do(req); err != nil {
|
||||
return
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
if job.list.Cache, err = io.ReadAll(res.Body); err != nil {
|
||||
return
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (job *updaterJob) shouldUpdateHTTP(location *url.URL) (update bool, err error) {
|
||||
if len(job.list.Cache) == 0 {
|
||||
// Nothing cached, please update.
|
||||
return true, nil
|
||||
}
|
||||
|
||||
var (
|
||||
req *http.Request
|
||||
res *http.Response
|
||||
)
|
||||
if req, err = http.NewRequest(http.MethodHead, location.String(), nil); err != nil {
|
||||
return
|
||||
}
|
||||
if res, err = http.DefaultClient.Do(req); err != nil {
|
||||
return
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
if lastModified, err := time.Parse(http.TimeFormat, res.Header.Get("Last-Modified")); err == nil {
|
||||
return lastModified.After(job.list.UpdatedAt), nil
|
||||
}
|
||||
return true, nil // not sure, no Last-Modified, so let's update?
|
||||
}
|
Reference in New Issue
Block a user