package dataset

import (
	"bytes"
	"fmt"
	"io"
	"io/fs"
	"net"
	"net/http"
	"net/netip"
	"net/url"
	"os"
	"slices"
	"time"

	"git.maze.io/maze/styx/dataset/parser"
	_ "github.com/mattn/go-sqlite3" // SQLite3 driver
)

// Storage is the persistence interface for groups, clients and lists.
type Storage interface {
	Groups() ([]Group, error)
	GroupByID(int64) (Group, error)
	GroupByName(name string) (Group, error)
	SaveGroup(*Group) error
	DeleteGroup(Group) error

	Clients() (Clients, error)
	ClientByID(int64) (Client, error)
	ClientByAddr(netip.Addr) (Client, error)
	// ClientByIP(net.IP) (Client, error)
	SaveClient(*Client) error
	DeleteClient(Client) error

	Lists() ([]List, error)
	ListsByGroup(Group) ([]List, error)
	ListByID(int64) (List, error)
	SaveList(*List) error
	DeleteList(List) error
}

// Group is a named group that clients (via ClientGroup) and lists (via
// ListGroup) can be linked to.
//
// NOTE(review): bstore's "nonzero" rejects a field's zero value, which for a
// bool is false; confirm that IsEnabled=false groups can actually be stored.
type Group struct {
	ID          int64     `json:"id"`
	Name        string    `json:"name" bstore:"nonzero,unique"`
	IsEnabled   bool      `json:"is_enabled" bstore:"nonzero"`
	Description string    `json:"description"`
	CreatedAt   time.Time `json:"created_at" bstore:"nonzero"`
	UpdatedAt   time.Time `json:"updated_at" bstore:"nonzero"`
}

// Client is an address range, IP plus a prefix length in bits (Mask), along
// with the groups it belongs to. IP may be an IPv4 or an IPv6 address.
type Client struct {
	ID          int64     `json:"id"`
	Network     string    `json:"network" bstore:"nonzero,index"`
	IP          string    `json:"ip" bstore:"nonzero,unique IP+Mask"`
	Mask        int       `json:"mask"`
	Description string    `json:"description"`
	Groups      []Group   `json:"groups,omitempty" bstore:"-"`
	CreatedAt   time.Time `json:"created_at" bstore:"nonzero"`
	UpdatedAt   time.Time `json:"updated_at" bstore:"nonzero"`
}

// ClientGroup links a Client to a Group.
type ClientGroup struct {
	ID       int64 `json:"id"`
	ClientID int64 `json:"client_id" bstore:"ref Client,index"`
	GroupID  int64 `json:"group_id" bstore:"ref Group,index"`
}

// ContainsIP reports whether ip lies within the client's prefix.
//
// Both IPv4 and IPv6 clients are supported; an IPv4 address given in its
// 16-byte (IPv4-mapped) form is treated as IPv4. It returns false for an
// invalid ip or when the client's IP/Mask do not form a valid prefix.
func (c *Client) ContainsIP(ip net.IP) bool {
	addr, ok := netip.AddrFromSlice(ip)
	if !ok {
		return false
	}
	return c.ContainsAddr(addr)
}

// ContainsAddr reports whether ip lies within the client's prefix.
// IPv4-mapped IPv6 addresses (as produced by dual-stack listeners) are
// unmapped first so they can match IPv4 clients.
func (c *Client) ContainsAddr(ip netip.Addr) bool {
	return c.Prefix().Contains(ip.Unmap())
}

// Prefix returns the client's range as a netip.Prefix. Host bits are not
// masked off. If IP does not parse or Mask is out of range for the address
// family, the returned prefix is invalid (IsValid reports false).
func (c Client) Prefix() netip.Prefix {
	ip, _ := netip.ParseAddr(c.IP) // an unparsable IP yields an invalid prefix
	return netip.PrefixFrom(ip, c.Mask)
}

// String formats the client as "ip/bits" (host bits are kept as stored), or
// "<nil>" when IP/Mask do not form a valid prefix.
func (c *Client) String() string {
	p := c.Prefix()
	if !p.IsValid() {
		return "<nil>"
	}
	return p.String()
}

// Clients is a list of clients that can be searched by address.
type Clients []Client

// ByIP returns the most specific client (longest Mask) whose range contains
// ip, or nil if none does. Among equally specific matches the first one in
// cs wins. The returned pointer refers to the element inside cs.
func (cs Clients) ByIP(ip net.IP) *Client {
	var candidates []*Client
	for i := range cs {
		// Index into cs rather than taking the address of a range variable,
		// so each candidate points at a distinct element of the slice.
		if cs[i].ContainsIP(ip) {
			candidates = append(candidates, &cs[i])
		}
	}
	if len(candidates) == 0 {
		return nil
	}
	// MaxFunc returns the first maximal element, which preserves input order
	// for ties.
	return slices.MaxFunc(candidates, func(a, b *Client) int {
		return a.Mask - b.Mask
	})
}

// List types, selecting how List.Cache is parsed.
const (
	ListTypeDomain  = "domain"
	ListTypeNetwork = "network"
)

// Refresh intervals for lists. They are not enforced in this file.
const (
	MinListRefresh     = 1 * time.Minute
	DefaultListRefresh = 30 * time.Minute
)

// List is a domain or network list loaded from Source. Its most recently
// fetched contents are kept in Cache.
type List struct {
	ID           int64         `json:"id"`
	Type         string        `json:"type"`
	Source       string        `json:"source"`
	IsEnabled    bool          `json:"is_enabled"`
	Permit       bool          `json:"permit"`
	Groups       []Group       `json:"groups,omitempty" bstore:"-"`
	Status       int           `json:"status"`
	Comment      string        `json:"comment"`
	Cache        []byte        `json:"cache"`
	Refresh      time.Duration `json:"refresh"`
	LastModified time.Time     `json:"last_modified"`
	CreatedAt    time.Time     `json:"created_at"`
	UpdatedAt    time.Time     `json:"updated_at"`
}

// Networks parses Cache into a NetworkTrie. It returns (nil, nil) when the
// list is not of type ListTypeNetwork.
func (list *List) Networks() (*NetworkTrie, error) {
	if list.Type != ListTypeNetwork {
		return nil, nil
	}
	prefixes, _, err := parser.ParseNetworks(bytes.NewReader(list.Cache))
	if err != nil {
		return nil, err
	}
	return NewNetworkTrie(prefixes...), nil
}

// Domains parses Cache into a DomainTrie, passing along the list's Permit
// flag. It returns (nil, nil) when the list is not of type ListTypeDomain.
func (list *List) Domains() (*DomainTrie, error) {
	if list.Type != ListTypeDomain {
		return nil, nil
	}
	domains, _, err := parser.ParseDomains(bytes.NewReader(list.Cache))
	if err != nil {
		return nil, err
	}
	return NewDomainTrie(list.Permit, domains...)
}

// Update refreshes Cache from Source if the source has changed and reports
// whether Cache was replaced.
//
// A Source without a scheme is used verbatim as a local file path. file://,
// http:// and https:// URLs are also supported. On success LastModified is set
// to the source's modification time when it is known.
func (list *List) Update() (updated bool, err error) {
	u, err := url.Parse(list.Source)
	if err != nil {
		return false, err
	}
	switch u.Scheme {
	case "":
		// Use the raw path: URL parsing would percent-decode it and strip
		// anything after '?' or '#', which are legal filename characters.
		return list.updateFile(list.Source)
	case "file":
		return list.updateFile(u.Path)
	case "http", "https":
		return list.updateHTTP(u.String())
	default:
		return false, fmt.Errorf("dataset: don't know how to update %s sources", u.Scheme)
	}
}

// updateFile reloads Cache from the local file name if its modification time
// is after UpdatedAt. Cache is left untouched on any error.
func (list *List) updateFile(name string) (updated bool, err error) {
	var info fs.FileInfo
	if info, err = os.Stat(name); err != nil {
		return false, err
	}
	if info.IsDir() {
		return false, fmt.Errorf("dataset: list %d: %q is a directory", list.ID, name)
	}
	if !info.ModTime().After(list.UpdatedAt) {
		return false, nil
	}
	data, err := os.ReadFile(name)
	if err != nil {
		return false, fmt.Errorf("dataset: list %d: %w", list.ID, err)
	}
	list.Cache = data
	list.LastModified = info.ModTime()
	return true, nil
}

// httpClient fetches remote list sources. Unlike http.DefaultClient it has a
// timeout, so a stalled server cannot block an update forever.
var httpClient = &http.Client{Timeout: 5 * time.Minute}

// updateHTTP downloads rawURL into Cache when shouldUpdateHTTP says the remote
// copy is newer. Cache is left untouched on any error or non-200 response.
func (list *List) updateHTTP(rawURL string) (updated bool, err error) {
	if updated, err = list.shouldUpdateHTTP(rawURL); err != nil || !updated {
		return updated, err
	}
	response, err := httpClient.Get(rawURL)
	if err != nil {
		return false, err
	}
	defer response.Body.Close()
	if response.StatusCode != http.StatusOK {
		return false, fmt.Errorf("dataset: list %d: fetching %s: unexpected status %s", list.ID, rawURL, response.Status)
	}
	data, err := io.ReadAll(response.Body)
	if err != nil {
		return false, fmt.Errorf("dataset: list %d: reading %s: %w", list.ID, rawURL, err)
	}
	list.Cache = data
	// Record the server's timestamp so the next shouldUpdateHTTP can skip an
	// unchanged source.
	if lastModified, perr := http.ParseTime(response.Header.Get("Last-Modified")); perr == nil {
		list.LastModified = lastModified
	}
	return true, nil
}

// shouldUpdateHTTP sends a HEAD request for rawURL. It reports false only when
// a 200 response has a Last-Modified header no later than LastModified. In
// every other case it assumes the list must be downloaded.
func (list *List) shouldUpdateHTTP(rawURL string) (updated bool, err error) {
	response, err := httpClient.Head(rawURL)
	if err != nil {
		return false, err
	}
	defer response.Body.Close()
	if response.StatusCode == http.StatusOK {
		// http.ParseTime accepts all three date formats allowed by HTTP/1.1.
		if lastModified, perr := http.ParseTime(response.Header.Get("Last-Modified")); perr == nil {
			return lastModified.After(list.LastModified), nil
		}
	}
	// There are no headers that would indicate last-modified time, so assume we have to update:
	return true, nil
}

// ListGroup links a List to a Group.
type ListGroup struct {
	ID      int64 `json:"id"`
	ListID  int64 `json:"list_id" bstore:"ref List,index"`
	GroupID int64 `json:"group_id" bstore:"ref Group,index"`
}