449 lines
11 KiB
Go
449 lines
11 KiB
Go
package dataset
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"net"
|
|
"net/netip"
|
|
"path/filepath"
|
|
"slices"
|
|
"strings"
|
|
"time"
|
|
|
|
"git.maze.io/maze/styx/logger"
|
|
"github.com/mjl-/bstore"
|
|
)
|
|
|
|
type bstoreStorage struct {
|
|
db *bstore.DB
|
|
path string
|
|
}
|
|
|
|
func OpenBStore(name string) (Storage, error) {
|
|
log := logger.StandardLog.Value("database", name)
|
|
|
|
if !filepath.IsAbs(name) {
|
|
var err error
|
|
if name, err = filepath.Abs(name); err != nil {
|
|
log.Err(err).Error("Opening BoltDB storage failed; invalid path")
|
|
return nil, err
|
|
}
|
|
log = log.Value("database", name)
|
|
}
|
|
|
|
log.Debug("Opening BoltDB storage")
|
|
ctx := context.Background()
|
|
db, err := bstore.Open(ctx, name, nil,
|
|
Group{},
|
|
Client{},
|
|
ClientGroup{},
|
|
List{},
|
|
ListGroup{},
|
|
)
|
|
if err != nil {
|
|
log.Err(err).Error("Opening BoltDB storage failed")
|
|
return nil, err
|
|
}
|
|
|
|
var (
|
|
s = &bstoreStorage{db: db, path: name}
|
|
defaultGroup Group
|
|
defaultClient4 Client
|
|
defaultClient6 Client
|
|
)
|
|
|
|
if defaultGroup, err = s.GroupByName("Default"); errors.Is(err, bstore.ErrAbsent) {
|
|
log.Debug("Creating default group")
|
|
defaultGroup = Group{
|
|
Name: "Default",
|
|
IsEnabled: true,
|
|
Description: "Default group",
|
|
}
|
|
if err = s.SaveGroup(&defaultGroup); err != nil {
|
|
return nil, err
|
|
}
|
|
} else if err != nil {
|
|
return nil, err
|
|
}
|
|
if defaultClient4, err = bstore.QueryDB[Client](ctx, db).
|
|
FilterEqual("Network", "ipv4").
|
|
FilterFn(func(client Client) bool {
|
|
return net.ParseIP(client.IP).Equal(net.ParseIP("0.0.0.0")) && client.Mask == 0
|
|
}).Get(); errors.Is(err, bstore.ErrAbsent) {
|
|
log.Debug("Creating default IPv4 clients")
|
|
defaultClient4 = Client{
|
|
Network: "ipv4",
|
|
IP: "0.0.0.0",
|
|
Mask: 0,
|
|
Description: "All IPv4 clients",
|
|
}
|
|
if err = s.SaveClient(&defaultClient4); err != nil {
|
|
return nil, err
|
|
}
|
|
if err = s.db.Insert(ctx, &ClientGroup{ClientID: defaultClient4.ID, GroupID: defaultGroup.ID}); err != nil {
|
|
return nil, err
|
|
}
|
|
} else if err != nil {
|
|
return nil, err
|
|
}
|
|
if defaultClient6, err = bstore.QueryDB[Client](ctx, db).
|
|
FilterEqual("Network", "ipv6").
|
|
FilterFn(func(client Client) bool {
|
|
return net.ParseIP(client.IP).Equal(net.ParseIP("::")) && client.Mask == 0
|
|
}).Get(); errors.Is(err, bstore.ErrAbsent) {
|
|
log.Debug("Creating default IPv6 clients")
|
|
defaultClient6 = Client{
|
|
Network: "ipv6",
|
|
IP: "::",
|
|
Mask: 0,
|
|
Description: "All IPv6 clients",
|
|
}
|
|
if err = s.SaveClient(&defaultClient6); err != nil {
|
|
return nil, err
|
|
}
|
|
if err = s.db.Insert(ctx, &ClientGroup{ClientID: defaultClient6.ID, GroupID: defaultGroup.ID}); err != nil {
|
|
return nil, err
|
|
}
|
|
} else if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Start updater
|
|
log.Trace("Starting list updater")
|
|
NewUpdater(s)
|
|
|
|
return s, nil
|
|
}
|
|
|
|
func (s *bstoreStorage) log() logger.Structured {
|
|
return logger.StandardLog.Values(logger.Values{
|
|
"storage": "bstore",
|
|
"storage_path": s.path,
|
|
})
|
|
}
|
|
|
|
func (s *bstoreStorage) Groups() ([]Group, error) {
|
|
var (
|
|
ctx = context.Background()
|
|
query = bstore.QueryDB[Group](ctx, s.db)
|
|
groups = make([]Group, 0)
|
|
)
|
|
for group := range query.All() {
|
|
groups = append(groups, group)
|
|
}
|
|
if err := query.Err(); err != nil && !errors.Is(err, bstore.ErrFinished) {
|
|
return nil, err
|
|
}
|
|
return groups, nil
|
|
}
|
|
|
|
func (s *bstoreStorage) GroupByID(id int64) (Group, error) {
|
|
ctx := context.Background()
|
|
return bstore.QueryDB[Group](ctx, s.db).FilterID(id).Get()
|
|
}
|
|
|
|
func (s *bstoreStorage) GroupByName(name string) (Group, error) {
|
|
ctx := context.Background()
|
|
return bstore.QueryDB[Group](ctx, s.db).FilterFn(func(group Group) bool {
|
|
return strings.EqualFold(group.Name, name)
|
|
}).Get()
|
|
}
|
|
|
|
func (s *bstoreStorage) SaveGroup(group *Group) (err error) {
|
|
ctx := context.Background()
|
|
group.UpdatedAt = time.Now().UTC()
|
|
if group.CreatedAt.Equal(time.Time{}) {
|
|
group.CreatedAt = group.UpdatedAt
|
|
err = s.db.Insert(ctx, group)
|
|
} else {
|
|
err = s.db.Update(ctx, group)
|
|
}
|
|
if err != nil {
|
|
return fmt.Errorf("dataset: save group %s failed: %w", group.Name, err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *bstoreStorage) DeleteGroup(group Group) (err error) {
|
|
ctx := context.Background()
|
|
tx, err := s.db.Begin(ctx, true)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if _, err = bstore.QueryTx[ClientGroup](tx).FilterEqual("GroupID", group.ID).Delete(); err != nil {
|
|
return
|
|
}
|
|
if _, err = bstore.QueryTx[ListGroup](tx).FilterEqual("GroupID", group.ID).Delete(); err != nil {
|
|
return
|
|
}
|
|
if err = tx.Delete(group); err != nil {
|
|
return
|
|
}
|
|
return tx.Commit()
|
|
}
|
|
|
|
func (s *bstoreStorage) Clients() (Clients, error) {
|
|
var (
|
|
ctx = context.Background()
|
|
query = bstore.QueryDB[Client](ctx, s.db)
|
|
clients = make(Clients, 0)
|
|
)
|
|
for client := range query.All() {
|
|
clients = append(clients, client)
|
|
}
|
|
if err := query.Err(); err != nil && !errors.Is(err, bstore.ErrFinished) {
|
|
return nil, err
|
|
}
|
|
return clients, nil
|
|
}
|
|
|
|
func (s *bstoreStorage) ClientByID(id int64) (Client, error) {
|
|
ctx := context.Background()
|
|
client, err := bstore.QueryDB[Client](ctx, s.db).FilterID(id).Get()
|
|
if err != nil {
|
|
return client, err
|
|
}
|
|
return s.clientResolveGroups(ctx, client)
|
|
}
|
|
|
|
func (s *bstoreStorage) ClientByIP(ip net.IP) (Client, error) {
|
|
addr, _ := netip.AddrFromSlice(ip)
|
|
return s.ClientByAddr(addr)
|
|
}
|
|
|
|
func (s *bstoreStorage) ClientByAddr(addr netip.Addr) (Client, error) {
|
|
if !addr.IsValid() {
|
|
return Client{}, ErrNotExist{Object: "client"}
|
|
}
|
|
var (
|
|
ctx = context.Background()
|
|
clients Clients
|
|
network string
|
|
)
|
|
if addr.Is4() {
|
|
network = "ipv4"
|
|
} else {
|
|
network = "ipv6"
|
|
}
|
|
if network == "" {
|
|
return Client{}, ErrNotExist{Object: "client"}
|
|
}
|
|
for client, err := range bstore.QueryDB[Client](ctx, s.db).
|
|
FilterEqual("Network", network).
|
|
FilterFn(func(client Client) bool {
|
|
return client.ContainsAddr(addr)
|
|
}).All() {
|
|
if err != nil {
|
|
return Client{}, err
|
|
}
|
|
clients = append(clients, client)
|
|
}
|
|
|
|
var client Client
|
|
switch len(clients) {
|
|
case 0:
|
|
return Client{}, ErrNotExist{Object: "client"}
|
|
case 1:
|
|
client = clients[0]
|
|
default:
|
|
slices.SortStableFunc(clients, func(a, b Client) int {
|
|
return int(b.Mask) - int(a.Mask)
|
|
})
|
|
client = clients[0]
|
|
}
|
|
return s.clientResolveGroups(ctx, client)
|
|
}
|
|
|
|
func (s *bstoreStorage) clientResolveGroups(ctx context.Context, client Client) (Client, error) {
|
|
for clientGroup, err := range bstore.QueryDB[ClientGroup](ctx, s.db).FilterEqual("ClientID", client.ID).All() {
|
|
if err != nil {
|
|
return Client{}, err
|
|
}
|
|
if group, err := s.GroupByID(clientGroup.GroupID); err == nil {
|
|
client.Groups = append(client.Groups, group)
|
|
}
|
|
}
|
|
return client, nil
|
|
}
|
|
|
|
func (s *bstoreStorage) SaveClient(client *Client) (err error) {
|
|
log := s.log()
|
|
ctx := context.Background()
|
|
client.UpdatedAt = time.Now().UTC()
|
|
|
|
tx, err := s.db.Begin(ctx, true)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
log = log.Values(logger.Values{"ip": client.IP, "mask": client.Mask, "description": client.Description})
|
|
if client.CreatedAt.Equal(time.Time{}) {
|
|
log.Debug("Create client")
|
|
client.CreatedAt = client.UpdatedAt
|
|
if err = tx.Insert(client); err != nil {
|
|
return fmt.Errorf("dataset: client insert failed: %w", err)
|
|
}
|
|
} else {
|
|
log.Debug("Update client")
|
|
if err = tx.Update(client); err != nil {
|
|
return fmt.Errorf("dataset: client update failed: %w", err)
|
|
}
|
|
}
|
|
|
|
var deleted int
|
|
if deleted, err = bstore.QueryTx[ClientGroup](tx).FilterEqual("ClientID", client.ID).Delete(); err != nil {
|
|
return fmt.Errorf("dataset: client groups delete failed: %w", err)
|
|
}
|
|
log.Debugf("Deleted %d groups", deleted)
|
|
log.Debugf("Linking %d groups", len(client.Groups))
|
|
for _, group := range client.Groups {
|
|
if err = tx.Insert(&ClientGroup{ClientID: client.ID, GroupID: group.ID}); err != nil {
|
|
return fmt.Errorf("dataset: client groups insert failed: %w", err)
|
|
}
|
|
}
|
|
|
|
return tx.Commit()
|
|
}
|
|
|
|
func (s *bstoreStorage) DeleteClient(client Client) (err error) {
|
|
ctx := context.Background()
|
|
tx, err := s.db.Begin(ctx, true)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if _, err = bstore.QueryTx[ClientGroup](tx).FilterEqual("ClientID", client.ID).Delete(); err != nil {
|
|
return
|
|
}
|
|
if err = tx.Delete(client); err != nil {
|
|
return
|
|
}
|
|
return tx.Commit()
|
|
}
|
|
|
|
func (s *bstoreStorage) Lists() ([]List, error) {
|
|
var (
|
|
ctx = context.Background()
|
|
query = bstore.QueryDB[List](ctx, s.db)
|
|
lists = make([]List, 0)
|
|
)
|
|
for list := range query.All() {
|
|
lists = append(lists, list)
|
|
}
|
|
if err := query.Err(); err != nil && !errors.Is(err, bstore.ErrFinished) {
|
|
return nil, err
|
|
}
|
|
return lists, nil
|
|
}
|
|
|
|
func (s *bstoreStorage) ListsByGroup(group Group) ([]List, error) {
|
|
ctx := context.Background()
|
|
ids := make([]int64, 0)
|
|
for item, err := range bstore.QueryDB[ListGroup](ctx, s.db).FilterEqual("GroupID", group.ID).All() {
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
ids = append(ids, item.ListID)
|
|
}
|
|
|
|
var lists []List
|
|
for list, err := range bstore.QueryDB[List](ctx, s.db).FilterIDs(ids).All() {
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
lists = append(lists, list)
|
|
}
|
|
return lists, nil
|
|
}
|
|
|
|
func (s *bstoreStorage) ListByID(id int64) (List, error) {
|
|
ctx := context.Background()
|
|
list, err := bstore.QueryDB[List](ctx, s.db).FilterID(id).Get()
|
|
if err != nil {
|
|
return list, err
|
|
}
|
|
return s.listResolveGroups(ctx, list)
|
|
}
|
|
|
|
func (s *bstoreStorage) listResolveGroups(ctx context.Context, list List) (List, error) {
|
|
for listGroup, err := range bstore.QueryDB[ListGroup](ctx, s.db).FilterEqual("ListID", list.ID).All() {
|
|
if err != nil {
|
|
return List{}, err
|
|
}
|
|
if group, err := s.GroupByID(listGroup.GroupID); err == nil {
|
|
list.Groups = append(list.Groups, group)
|
|
}
|
|
}
|
|
return list, nil
|
|
}
|
|
|
|
func (s *bstoreStorage) SaveList(list *List) (err error) {
|
|
if list.Type != ListTypeDomain && list.Type != ListTypeNetwork {
|
|
return fmt.Errorf("storage: unknown list type %q", list.Type)
|
|
}
|
|
if list.Refresh == 0 {
|
|
list.Refresh = DefaultListRefresh
|
|
} else if list.Refresh < MinListRefresh {
|
|
list.Refresh = MinListRefresh
|
|
}
|
|
list.UpdatedAt = time.Now().UTC()
|
|
|
|
ctx := context.Background()
|
|
tx, err := s.db.Begin(ctx, true)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
log := s.log()
|
|
log = log.Values(logger.Values{
|
|
"type": list.Type,
|
|
"source": list.Source,
|
|
"is_enabled": list.IsEnabled,
|
|
"status": list.Status,
|
|
"cache": len(list.Cache),
|
|
"refresh": list.Refresh,
|
|
})
|
|
|
|
if list.CreatedAt.Equal(time.Time{}) {
|
|
log.Debug("Creating list")
|
|
list.CreatedAt = list.UpdatedAt
|
|
if err = tx.Insert(list); err != nil {
|
|
return fmt.Errorf("dataset: list insert failed: %w", err)
|
|
}
|
|
} else {
|
|
log.Debug("Updating list")
|
|
if err = tx.Update(list); err != nil {
|
|
return fmt.Errorf("dataset: list update failed: %w", err)
|
|
}
|
|
}
|
|
|
|
var deleted int
|
|
if deleted, err = bstore.QueryTx[ListGroup](tx).FilterEqual("ListID", list.ID).Delete(); err != nil {
|
|
return fmt.Errorf("dataset: list groups delete failed: %w", err)
|
|
}
|
|
log.Debugf("Deleted %d groups", deleted)
|
|
log.Debugf("Linking %d groups", len(list.Groups))
|
|
for _, group := range list.Groups {
|
|
if err = tx.Insert(&ListGroup{ListID: list.ID, GroupID: group.ID}); err != nil {
|
|
return fmt.Errorf("dataset: list groups insert failed: %w", err)
|
|
}
|
|
}
|
|
|
|
return tx.Commit()
|
|
}
|
|
|
|
func (s *bstoreStorage) DeleteList(list List) (err error) {
|
|
ctx := context.Background()
|
|
tx, err := s.db.Begin(ctx, true)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if _, err = bstore.QueryTx[ListGroup](tx).FilterEqual("ListID", list.ID).Delete(); err != nil {
|
|
return
|
|
}
|
|
if err = tx.Delete(list); err != nil {
|
|
return
|
|
}
|
|
return tx.Commit()
|
|
}
|