Checkpoint
This commit is contained in:
231
dataset/storage.go
Normal file
231
dataset/storage.go
Normal file
@@ -0,0 +1,231 @@
|
||||
package dataset
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3" // SQLite3 driver
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
type Storage interface {
|
||||
Groups() ([]Group, error)
|
||||
GroupByID(int64) (Group, error)
|
||||
GroupByName(name string) (Group, error)
|
||||
SaveGroup(*Group) error
|
||||
DeleteGroup(Group) error
|
||||
|
||||
Clients() (Clients, error)
|
||||
ClientByID(int64) (Client, error)
|
||||
ClientByIP(net.IP) (Client, error)
|
||||
SaveClient(*Client) error
|
||||
DeleteClient(Client) error
|
||||
|
||||
Lists() ([]List, error)
|
||||
ListByID(int64) (List, error)
|
||||
SaveList(*List) error
|
||||
DeleteList(List) error
|
||||
}
|
||||
|
||||
type Group struct {
|
||||
ID int64 `json:"id"`
|
||||
Name string `json:"name" bstore:"nonzero,unique"`
|
||||
IsEnabled bool `json:"is_enabled" bstore:"nonzero"`
|
||||
Description string `json:"description"`
|
||||
CreatedAt time.Time `json:"created_at" bstore:"nonzero"`
|
||||
UpdatedAt time.Time `json:"updated_at" bstore:"nonzero"`
|
||||
Storage Storage `json:"-" bstore:"-"`
|
||||
}
|
||||
|
||||
type Client struct {
|
||||
ID int64 `json:"id"`
|
||||
Network string `json:"network" bstore:"nonzero,index"`
|
||||
IP string `json:"ip" bstore:"nonzero,unique IP+Mask"`
|
||||
Mask int `json:"mask"`
|
||||
Description string `json:"description"`
|
||||
Groups []Group `json:"groups,omitempty" bstore:"-"`
|
||||
CreatedAt time.Time `json:"created_at" bstore:"nonzero"`
|
||||
UpdatedAt time.Time `json:"updated_at" bstore:"nonzero"`
|
||||
Storage Storage `json:"-" bstore:"-"`
|
||||
}
|
||||
|
||||
type WithClient interface {
|
||||
Client() (Client, error)
|
||||
}
|
||||
|
||||
type ClientGroup struct {
|
||||
ID int64 `json:"id"`
|
||||
ClientID int64 `json:"client_id" bstore:"ref Client,index"`
|
||||
GroupID int64 `json:"group_id" bstore:"ref Group,index"`
|
||||
}
|
||||
|
||||
func (c *Client) ContainsIP(ip net.IP) bool {
|
||||
ipnet := &net.IPNet{
|
||||
IP: net.ParseIP(c.IP),
|
||||
Mask: net.CIDRMask(int(c.Mask), 32),
|
||||
}
|
||||
if ipnet.IP == nil {
|
||||
return false
|
||||
}
|
||||
return ipnet.Contains(ip)
|
||||
}
|
||||
|
||||
func (c *Client) String() string {
|
||||
ipnet := &net.IPNet{
|
||||
IP: net.ParseIP(c.IP),
|
||||
Mask: net.CIDRMask(int(c.Mask), 32),
|
||||
}
|
||||
return ipnet.String()
|
||||
}
|
||||
|
||||
type Clients []Client
|
||||
|
||||
func (cs Clients) ByIP(ip net.IP) *Client {
|
||||
var candidates []*Client
|
||||
for _, c := range cs {
|
||||
if c.ContainsIP(ip) {
|
||||
candidates = append(candidates, &c)
|
||||
}
|
||||
}
|
||||
switch len(candidates) {
|
||||
case 0:
|
||||
return nil
|
||||
case 1:
|
||||
return candidates[0]
|
||||
default:
|
||||
slices.SortStableFunc(candidates, func(a, b *Client) int {
|
||||
return int(b.Mask) - int(a.Mask)
|
||||
})
|
||||
return candidates[0]
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
ListTypeDomain = "domain"
|
||||
ListTypeNetwork = "network"
|
||||
)
|
||||
|
||||
const (
|
||||
MinListRefresh = 1 * time.Minute
|
||||
DefaultListRefresh = 30 * time.Minute
|
||||
)
|
||||
|
||||
type List struct {
|
||||
ID int64 `json:"id"`
|
||||
Type string `json:"type"`
|
||||
Source string `json:"source"`
|
||||
IsEnabled bool `json:"is_enabled"`
|
||||
Permit bool `json:"permit"`
|
||||
Groups []Group `json:"groups,omitempty" bstore:"-"`
|
||||
Status int `json:"status"`
|
||||
Comment string `json:"comment"`
|
||||
Cache []byte `json:"cache"`
|
||||
Refresh time.Duration `json:"refresh"`
|
||||
LastModified time.Time `json:"last_modified"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
func (list *List) Domains() (*DomainTree, error) {
|
||||
if list.Type != ListTypeDomain {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var (
|
||||
tree = NewDomainList()
|
||||
scan = bufio.NewScanner(bytes.NewReader(list.Cache))
|
||||
)
|
||||
for scan.Scan() {
|
||||
line := strings.TrimSpace(scan.Text())
|
||||
if line == "" || line[0] == '#' {
|
||||
continue
|
||||
}
|
||||
if labels, ok := dns.IsDomainName(line); ok && labels >= 2 {
|
||||
tree.Add(line)
|
||||
}
|
||||
}
|
||||
if err := scan.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return tree, nil
|
||||
}
|
||||
|
||||
func (list *List) Update() (updated bool, err error) {
|
||||
u, err := url.Parse(list.Source)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
switch u.Scheme {
|
||||
case "", "file":
|
||||
return list.updateFile(u.Path)
|
||||
case "http", "https":
|
||||
return list.updateHTTP(u.String())
|
||||
default:
|
||||
return false, fmt.Errorf("dataset: don't know how to update %s sources", u.Scheme)
|
||||
}
|
||||
}
|
||||
|
||||
func (list *List) updateFile(name string) (updated bool, err error) {
|
||||
var info fs.FileInfo
|
||||
if info, err = os.Stat(name); err != nil {
|
||||
return
|
||||
} else if info.IsDir() {
|
||||
return false, fmt.Errorf("dataset: list %d: %q is a directory", list.ID, name)
|
||||
}
|
||||
if updated = info.ModTime().After(list.UpdatedAt); !updated {
|
||||
return
|
||||
}
|
||||
list.Cache, _ = os.ReadFile(name)
|
||||
return
|
||||
}
|
||||
|
||||
func (list *List) updateHTTP(url string) (updated bool, err error) {
|
||||
if updated, err = list.shouldUpdateHTTP(url); err != nil || !updated {
|
||||
return
|
||||
}
|
||||
|
||||
var response *http.Response
|
||||
if response, err = http.DefaultClient.Get(url); err != nil {
|
||||
return
|
||||
}
|
||||
defer response.Body.Close()
|
||||
if list.Cache, err = io.ReadAll(response.Body); err != nil {
|
||||
return
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (list *List) shouldUpdateHTTP(url string) (updated bool, err error) {
|
||||
var response *http.Response
|
||||
if response, err = http.DefaultClient.Head(url); err != nil {
|
||||
return
|
||||
}
|
||||
defer response.Body.Close()
|
||||
|
||||
if value := response.Header.Get("Last-Modified"); value != "" {
|
||||
var lastModified time.Time
|
||||
if lastModified, err = time.Parse(http.TimeFormat, value); err == nil {
|
||||
return lastModified.After(list.LastModified), nil
|
||||
}
|
||||
}
|
||||
|
||||
// There are no headers that would indicate last-modified time, so assume we have to update:
|
||||
return true, nil
|
||||
}
|
||||
|
||||
type ListGroup struct {
|
||||
ID int64 `json:"id"`
|
||||
ListID int64 `json:"list_id" bstore:"ref List,index"`
|
||||
GroupID int64 `json:"group_id" bstore:"ref Group,index"`
|
||||
}
|
Reference in New Issue
Block a user