Files
styx/dataset/storage.go
2025-10-08 20:57:13 +02:00

237 lines
5.7 KiB
Go

package dataset
import (
"bytes"
"fmt"
"io"
"io/fs"
"net"
"net/http"
"net/netip"
"net/url"
"os"
"slices"
"time"
"git.maze.io/maze/styx/dataset/parser"
_ "github.com/mattn/go-sqlite3" // SQLite3 driver
)
// Storage is the persistence interface for the Group, Client and List
// records declared in this package.
type Storage interface {
	// Group records.
	Groups() ([]Group, error)
	GroupByID(id int64) (Group, error)
	GroupByName(name string) (Group, error)
	SaveGroup(group *Group) error
	DeleteGroup(group Group) error

	// Client records.
	Clients() (Clients, error)
	ClientByID(id int64) (Client, error)
	ClientByAddr(addr netip.Addr) (Client, error)
	// ClientByIP(net.IP) (Client, error) is commented out; ClientByAddr is the
	// address-based lookup.
	SaveClient(client *Client) error
	DeleteClient(client Client) error

	// List records.
	Lists() ([]List, error)
	ListsByGroup(group Group) ([]List, error)
	ListByID(id int64) (List, error)
	SaveList(list *List) error
	DeleteList(list List) error
}
// Group is a named record that clients and lists can be attached to
// (see ClientGroup and ListGroup, which reference it by ID).
//
// NOTE(review): in bstore, the `nonzero` tag on a bool rejects false. If these
// bstore tags are still enforced, a group with IsEnabled == false cannot be
// stored. Confirm whether that is intended.
type Group struct {
ID int64 `json:"id"`
// Name must be non-empty and unique according to its bstore tag.
Name string `json:"name" bstore:"nonzero,unique"`
IsEnabled bool `json:"is_enabled" bstore:"nonzero"`
Description string `json:"description"`
CreatedAt time.Time `json:"created_at" bstore:"nonzero"`
UpdatedAt time.Time `json:"updated_at" bstore:"nonzero"`
}
// Client describes an address range, stored as a textual IP plus a prefix
// length (Mask). Incoming addresses are matched against it by ContainsIP,
// ContainsAddr and Clients.ByIP.
type Client struct {
ID int64 `json:"id"`
// Network is indexed by bstore; its meaning is not visible here — TODO document.
Network string `json:"network" bstore:"nonzero,index"`
// IP is the textual address; the IP+Mask pair is unique according to its bstore tag.
IP string `json:"ip" bstore:"nonzero,unique IP+Mask"`
// Mask is the prefix length, in bits, applied to IP.
Mask int `json:"mask"`
Description string `json:"description"`
// Groups is excluded from bstore ("-"); presumably filled in by the storage layer.
Groups []Group `json:"groups,omitempty" bstore:"-"`
CreatedAt time.Time `json:"created_at" bstore:"nonzero"`
UpdatedAt time.Time `json:"updated_at" bstore:"nonzero"`
}
// ClientGroup associates a Client with a Group. Its bstore "ref" tags point
// ClientID at a Client record and GroupID at a Group record.
type ClientGroup struct {
ID int64 `json:"id"`
ClientID int64 `json:"client_id" bstore:"ref Client,index"`
GroupID int64 `json:"group_id" bstore:"ref Group,index"`
}
// ContainsIP reports whether ip falls inside the client's network, formed
// from c.IP and the prefix length c.Mask.
//
// The mask width follows the address family of c.IP: 32 bits for IPv4 and
// 128 bits for IPv6, so IPv6 client networks match as well. An unparseable
// c.IP, or a c.Mask outside the family's range, matches nothing.
func (c *Client) ContainsIP(ip net.IP) bool {
	network := net.ParseIP(c.IP)
	if network == nil {
		return false
	}

	bits := 8 * net.IPv6len
	if v4 := network.To4(); v4 != nil {
		network, bits = v4, 8*net.IPv4len
	}

	mask := net.CIDRMask(c.Mask, bits)
	if mask == nil {
		// CIDRMask returns nil when c.Mask is negative or wider than the family.
		return false
	}
	return (&net.IPNet{IP: network, Mask: mask}).Contains(ip)
}
// ContainsAddr reports whether ip falls inside the client's prefix (see Prefix).
//
// netip treats an IPv4-mapped IPv6 address (::ffff:a.b.c.d) as IPv6, so it
// would never match an IPv4 prefix. netip.AddrFromSlice returns such addresses
// for 16-byte net.IP values. When the direct check fails, the unmapped form is
// tried as well, so ContainsAddr agrees with ContainsIP.
func (c *Client) ContainsAddr(ip netip.Addr) bool {
	prefix := c.Prefix()
	if prefix.Contains(ip) {
		return true
	}
	return ip.Is4In6() && prefix.Contains(ip.Unmap())
}
// Prefix combines c.IP and c.Mask into a netip.Prefix. The host bits are not
// masked off.
//
// If c.IP does not parse, the zero (invalid) Prefix is returned. If c.Mask is
// out of range for the address family, the result is likewise invalid.
func (c Client) Prefix() netip.Prefix {
	if addr, err := netip.ParseAddr(c.IP); err == nil {
		return netip.PrefixFrom(addr, c.Mask)
	}
	return netip.Prefix{}
}
// String renders the client as "ip/bits", for example "10.0.0.0/8" or
// "2001:db8::/32". The host bits of c.IP are kept as stored.
//
// The mask width follows the address family of c.IP, so IPv6 clients render
// correctly. An unparseable c.IP or an out-of-range c.Mask renders as "<nil>",
// which is how net.IPNet.String reports an invalid network.
func (c *Client) String() string {
	ip := net.ParseIP(c.IP)
	bits := 8 * net.IPv6len
	if v4 := ip.To4(); v4 != nil {
		ip, bits = v4, 8*net.IPv4len
	}
	return (&net.IPNet{IP: ip, Mask: net.CIDRMask(c.Mask, bits)}).String()
}
// Clients is a collection of client networks that can be searched by address.
type Clients []Client

// ByIP returns the client whose network contains ip, preferring the most
// specific match (the largest Mask). When several matches share that Mask, the
// one that appears first in cs wins. It returns nil when no client matches.
//
// The result points to a copy of the matching element, so modifying it does
// not alter cs.
func (cs Clients) ByIP(ip net.IP) *Client {
	// Collect indices instead of taking &c of the range variable. Under
	// pre-Go 1.22 loop semantics that variable is shared by all iterations,
	// so every stored pointer would alias the same (last) value.
	var matches []int
	for i := range cs {
		if cs[i].ContainsIP(ip) {
			matches = append(matches, i)
		}
	}
	if len(matches) == 0 {
		return nil
	}

	// Only the best match is needed, so a linear max replaces a full sort.
	// MaxFunc returns the first maximal element, which gives the first-in-cs
	// tie-break.
	best := slices.MaxFunc(matches, func(a, b int) int {
		return cs[a].Mask - cs[b].Mask
	})
	match := cs[best]
	return &match
}
// Supported values for List.Type. The type selects which parser reads a list's
// cached contents.
const (
	ListTypeDomain  = "domain"  // parsed by List.Domains
	ListTypeNetwork = "network" // parsed by List.Networks
)
// Refresh intervals related to List.Refresh.
//
// NOTE(review): nothing in this file applies these values. Presumably callers
// clamp Refresh to at least MinListRefresh and fall back to DefaultListRefresh
// when none is configured — confirm.
const (
	MinListRefresh     = time.Minute
	DefaultListRefresh = 30 * time.Minute
)
// List is a domain or network list whose contents are fetched from Source by
// Update and kept in Cache. Source may be a local path, a file:// URL, or an
// http:// or https:// URL.
type List struct {
ID int64 `json:"id"`
// Type is expected to be ListTypeDomain or ListTypeNetwork. It decides whether
// Domains or Networks parses Cache; the other method returns (nil, nil).
Type string `json:"type"`
// Source is the location Update fetches the list from.
Source string `json:"source"`
IsEnabled bool `json:"is_enabled"`
// Permit is passed to NewDomainTrie by Domains.
Permit bool `json:"permit"`
// Groups is excluded from bstore ("-"); presumably filled in by the storage layer.
Groups []Group `json:"groups,omitempty" bstore:"-"`
Status int `json:"status"`
Comment string `json:"comment"`
// Cache holds the raw list contents last fetched by Update. encoding/json
// encodes a []byte as base64.
Cache []byte `json:"cache"`
// Refresh is the refresh interval, presumably bounded by MinListRefresh and
// defaulting to DefaultListRefresh (TODO confirm). encoding/json encodes a
// time.Duration as an integer count of nanoseconds.
Refresh time.Duration `json:"refresh"`
// LastModified is compared with the HTTP Last-Modified header by
// shouldUpdateHTTP. No code shown here assigns it.
LastModified time.Time `json:"last_modified"`
CreatedAt time.Time `json:"created_at"`
// UpdatedAt is compared with the source file's modification time by updateFile.
UpdatedAt time.Time `json:"updated_at"`
}
// Networks passes the cached list contents to parser.ParseNetworks and builds
// a NetworkTrie from the parsed prefixes.
//
// It returns (nil, nil) when the list is not of type ListTypeNetwork. A parse
// error is returned unchanged.
func (list *List) Networks() (*NetworkTrie, error) {
	if list.Type == ListTypeNetwork {
		prefixes, _, err := parser.ParseNetworks(bytes.NewReader(list.Cache))
		if err == nil {
			return NewNetworkTrie(prefixes...), nil
		}
		return nil, err
	}
	return nil, nil
}
// Domains passes the cached list contents to parser.ParseDomains and builds a
// DomainTrie from the parsed domains, with list.Permit as the trie's first
// argument.
//
// It returns (nil, nil) when the list is not of type ListTypeDomain. Errors
// from parsing or from NewDomainTrie are returned to the caller.
func (list *List) Domains() (*DomainTrie, error) {
	if list.Type == ListTypeDomain {
		domains, _, err := parser.ParseDomains(bytes.NewReader(list.Cache))
		if err == nil {
			return NewDomainTrie(list.Permit, domains...)
		}
		return nil, err
	}
	return nil, nil
}
// Update refreshes the list from list.Source and reports whether new contents
// were loaded.
//
// Sources without a scheme, or with the "file" scheme, are read from the URL
// path by updateFile. "http" and "https" sources are fetched by updateHTTP.
// Any other scheme, or a Source that does not parse as a URL, is an error.
func (list *List) Update() (updated bool, err error) {
	source, err := url.Parse(list.Source)
	if err != nil {
		return false, err
	}

	scheme := source.Scheme
	if scheme == "" || scheme == "file" {
		return list.updateFile(source.Path)
	}
	if scheme == "http" || scheme == "https" {
		return list.updateHTTP(source.String())
	}
	return false, fmt.Errorf("dataset: don't know how to update %s sources", scheme)
}
// updateFile refreshes list.Cache from the local file name.
//
// The file is only re-read when its modification time is after
// list.UpdatedAt. It reports updated == true only after the new contents have
// been read successfully. On any error list.Cache is left untouched and
// updated is false.
func (list *List) updateFile(name string) (updated bool, err error) {
	var info fs.FileInfo
	if info, err = os.Stat(name); err != nil {
		return false, err
	}
	if info.IsDir() {
		return false, fmt.Errorf("dataset: list %d: %q is a directory", list.ID, name)
	}
	if !info.ModTime().After(list.UpdatedAt) {
		return false, nil
	}

	// Read into a local first, so that a failed read neither wipes the existing
	// cache nor is reported as a successful update.
	data, err := os.ReadFile(name)
	if err != nil {
		return false, fmt.Errorf("dataset: list %d: reading %q: %w", list.ID, name, err)
	}
	list.Cache = data
	return true, nil
}
// updateHTTP refreshes list.Cache by downloading rawURL with GET. It only does
// so when shouldUpdateHTTP reports that the remote copy may have changed.
//
// Only a 200 OK response replaces the cache. Any other status is returned as
// an error, so error pages are never stored as list contents. On any error
// updated is false and list.Cache is left untouched.
//
// NOTE(review): list.LastModified is not updated here. Unless a caller records
// it after a successful update, shouldUpdateHTTP keeps comparing against a
// stale value — confirm.
func (list *List) updateHTTP(rawURL string) (updated bool, err error) {
	if updated, err = list.shouldUpdateHTTP(rawURL); err != nil || !updated {
		return updated, err
	}

	// Keep any customisation of http.DefaultClient, but add a timeout when none
	// is set: the default client waits forever on a stalled server.
	client := *http.DefaultClient
	if client.Timeout == 0 {
		client.Timeout = 5 * time.Minute
	}

	response, err := client.Get(rawURL)
	if err != nil {
		return false, err
	}
	defer response.Body.Close()

	if response.StatusCode != http.StatusOK {
		return false, fmt.Errorf("dataset: list %d: GET %s: unexpected status %s", list.ID, rawURL, response.Status)
	}

	// Read into a local first, so that a failed read cannot clobber the existing cache.
	body, err := io.ReadAll(response.Body)
	if err != nil {
		return false, fmt.Errorf("dataset: list %d: reading %s: %w", list.ID, rawURL, err)
	}
	list.Cache = body
	return true, nil
}
// shouldUpdateHTTP sends a HEAD request for rawURL and reports whether the
// remote resource appears newer than list.LastModified.
//
// The Last-Modified header is only trusted on a 200 OK response. It is parsed
// with http.ParseTime, which accepts every date format HTTP/1.1 allows, not
// only http.TimeFormat. When no usable header is available it returns true,
// so the caller downloads the resource and judges the GET response itself.
func (list *List) shouldUpdateHTTP(rawURL string) (updated bool, err error) {
	// Keep any customisation of http.DefaultClient, but add a timeout when none is set.
	client := *http.DefaultClient
	if client.Timeout == 0 {
		client.Timeout = time.Minute
	}

	response, err := client.Head(rawURL)
	if err != nil {
		return false, err
	}
	defer response.Body.Close()

	if response.StatusCode == http.StatusOK {
		if value := response.Header.Get("Last-Modified"); value != "" {
			if lastModified, perr := http.ParseTime(value); perr == nil {
				return lastModified.After(list.LastModified), nil
			}
		}
	}

	// There are no headers that would indicate last-modified time, so assume we have to update:
	return true, nil
}
// ListGroup associates a List with a Group. Its bstore "ref" tags point ListID
// at a List record and GroupID at a Group record.
type ListGroup struct {
ID int64 `json:"id"`
ListID int64 `json:"list_id" bstore:"ref List,index"`
GroupID int64 `json:"group_id" bstore:"ref Group,index"`
}