package dataset import ( "context" "errors" "fmt" "net" "net/netip" "path/filepath" "slices" "strings" "time" "git.maze.io/maze/styx/logger" "github.com/mjl-/bstore" ) type bstoreStorage struct { db *bstore.DB path string } func OpenBStore(name string) (Storage, error) { log := logger.StandardLog.Value("database", name) if !filepath.IsAbs(name) { var err error if name, err = filepath.Abs(name); err != nil { log.Err(err).Error("Opening BoltDB storage failed; invalid path") return nil, err } log = log.Value("database", name) } log.Debug("Opening BoltDB storage") ctx := context.Background() db, err := bstore.Open(ctx, name, nil, Group{}, Client{}, ClientGroup{}, List{}, ListGroup{}, ) if err != nil { log.Err(err).Error("Opening BoltDB storage failed") return nil, err } var ( s = &bstoreStorage{db: db, path: name} defaultGroup Group defaultClient4 Client defaultClient6 Client ) if defaultGroup, err = s.GroupByName("Default"); errors.Is(err, bstore.ErrAbsent) { log.Debug("Creating default group") defaultGroup = Group{ Name: "Default", IsEnabled: true, Description: "Default group", } if err = s.SaveGroup(&defaultGroup); err != nil { return nil, err } } else if err != nil { return nil, err } if defaultClient4, err = bstore.QueryDB[Client](ctx, db). FilterEqual("Network", "ipv4"). FilterFn(func(client Client) bool { return net.ParseIP(client.IP).Equal(net.ParseIP("0.0.0.0")) && client.Mask == 0 }).Get(); errors.Is(err, bstore.ErrAbsent) { log.Debug("Creating default IPv4 clients") defaultClient4 = Client{ Network: "ipv4", IP: "0.0.0.0", Mask: 0, Description: "All IPv4 clients", } if err = s.SaveClient(&defaultClient4); err != nil { return nil, err } if err = s.db.Insert(ctx, &ClientGroup{ClientID: defaultClient4.ID, GroupID: defaultGroup.ID}); err != nil { return nil, err } } else if err != nil { return nil, err } if defaultClient6, err = bstore.QueryDB[Client](ctx, db). FilterEqual("Network", "ipv6"). FilterFn(func(client Client) bool { return net.ParseIP(client.IP).Equal(net.ParseIP("::")) && client.Mask == 0 }).Get(); errors.Is(err, bstore.ErrAbsent) { log.Debug("Creating default IPv6 clients") defaultClient6 = Client{ Network: "ipv6", IP: "::", Mask: 0, Description: "All IPv6 clients", } if err = s.SaveClient(&defaultClient6); err != nil { return nil, err } if err = s.db.Insert(ctx, &ClientGroup{ClientID: defaultClient6.ID, GroupID: defaultGroup.ID}); err != nil { return nil, err } } else if err != nil { return nil, err } // Start updater log.Trace("Starting list updater") NewUpdater(s) return s, nil } func (s *bstoreStorage) log() logger.Structured { return logger.StandardLog.Values(logger.Values{ "storage": "bstore", "storage_path": s.path, }) } func (s *bstoreStorage) Groups() ([]Group, error) { var ( ctx = context.Background() query = bstore.QueryDB[Group](ctx, s.db) groups = make([]Group, 0) ) for group := range query.All() { groups = append(groups, group) } if err := query.Err(); err != nil && !errors.Is(err, bstore.ErrFinished) { return nil, err } return groups, nil } func (s *bstoreStorage) GroupByID(id int64) (Group, error) { ctx := context.Background() return bstore.QueryDB[Group](ctx, s.db).FilterID(id).Get() } func (s *bstoreStorage) GroupByName(name string) (Group, error) { ctx := context.Background() return bstore.QueryDB[Group](ctx, s.db).FilterFn(func(group Group) bool { return strings.EqualFold(group.Name, name) }).Get() } func (s *bstoreStorage) SaveGroup(group *Group) (err error) { ctx := context.Background() group.UpdatedAt = time.Now().UTC() if group.CreatedAt.Equal(time.Time{}) { group.CreatedAt = group.UpdatedAt err = s.db.Insert(ctx, group) } else { err = s.db.Update(ctx, group) } if err != nil { return fmt.Errorf("dataset: save group %s failed: %w", group.Name, err) } return nil } func (s *bstoreStorage) DeleteGroup(group Group) (err error) { ctx := context.Background() tx, err := s.db.Begin(ctx, true) if err != nil { return err } if _, err = bstore.QueryTx[ClientGroup](tx).FilterEqual("GroupID", group.ID).Delete(); err != nil { return } if _, err = bstore.QueryTx[ListGroup](tx).FilterEqual("GroupID", group.ID).Delete(); err != nil { return } if err = tx.Delete(group); err != nil { return } return tx.Commit() } func (s *bstoreStorage) Clients() (Clients, error) { var ( ctx = context.Background() query = bstore.QueryDB[Client](ctx, s.db) clients = make(Clients, 0) ) for client := range query.All() { clients = append(clients, client) } if err := query.Err(); err != nil && !errors.Is(err, bstore.ErrFinished) { return nil, err } return clients, nil } func (s *bstoreStorage) ClientByID(id int64) (Client, error) { ctx := context.Background() client, err := bstore.QueryDB[Client](ctx, s.db).FilterID(id).Get() if err != nil { return client, err } return s.clientResolveGroups(ctx, client) } func (s *bstoreStorage) ClientByIP(ip net.IP) (Client, error) { addr, _ := netip.AddrFromSlice(ip) return s.ClientByAddr(addr) } func (s *bstoreStorage) ClientByAddr(addr netip.Addr) (Client, error) { if !addr.IsValid() { return Client{}, ErrNotExist{Object: "client"} } var ( ctx = context.Background() clients Clients network string ) if addr.Is4() { network = "ipv4" } else { network = "ipv6" } if network == "" { return Client{}, ErrNotExist{Object: "client"} } for client, err := range bstore.QueryDB[Client](ctx, s.db). FilterEqual("Network", network). FilterFn(func(client Client) bool { return client.ContainsAddr(addr) }).All() { if err != nil { return Client{}, err } clients = append(clients, client) } var client Client switch len(clients) { case 0: return Client{}, ErrNotExist{Object: "client"} case 1: client = clients[0] default: slices.SortStableFunc(clients, func(a, b Client) int { return int(b.Mask) - int(a.Mask) }) client = clients[0] } return s.clientResolveGroups(ctx, client) } func (s *bstoreStorage) clientResolveGroups(ctx context.Context, client Client) (Client, error) { for clientGroup, err := range bstore.QueryDB[ClientGroup](ctx, s.db).FilterEqual("ClientID", client.ID).All() { if err != nil { return Client{}, err } if group, err := s.GroupByID(clientGroup.GroupID); err == nil { client.Groups = append(client.Groups, group) } } return client, nil } func (s *bstoreStorage) SaveClient(client *Client) (err error) { log := s.log() ctx := context.Background() client.UpdatedAt = time.Now().UTC() tx, err := s.db.Begin(ctx, true) if err != nil { return err } log = log.Values(logger.Values{"ip": client.IP, "mask": client.Mask, "description": client.Description}) if client.CreatedAt.Equal(time.Time{}) { log.Debug("Create client") client.CreatedAt = client.UpdatedAt if err = tx.Insert(client); err != nil { return fmt.Errorf("dataset: client insert failed: %w", err) } } else { log.Debug("Update client") if err = tx.Update(client); err != nil { return fmt.Errorf("dataset: client update failed: %w", err) } } var deleted int if deleted, err = bstore.QueryTx[ClientGroup](tx).FilterEqual("ClientID", client.ID).Delete(); err != nil { return fmt.Errorf("dataset: client groups delete failed: %w", err) } log.Debugf("Deleted %d groups", deleted) log.Debugf("Linking %d groups", len(client.Groups)) for _, group := range client.Groups { if err = tx.Insert(&ClientGroup{ClientID: client.ID, GroupID: group.ID}); err != nil { return fmt.Errorf("dataset: client groups insert failed: %w", err) } } return tx.Commit() } func (s *bstoreStorage) DeleteClient(client Client) (err error) { ctx := context.Background() tx, err := s.db.Begin(ctx, true) if err != nil { return err } if _, err = bstore.QueryTx[ClientGroup](tx).FilterEqual("ClientID", client.ID).Delete(); err != nil { return } if err = tx.Delete(client); err != nil { return } return tx.Commit() } func (s *bstoreStorage) Lists() ([]List, error) { var ( ctx = context.Background() query = bstore.QueryDB[List](ctx, s.db) lists = make([]List, 0) ) for list := range query.All() { lists = append(lists, list) } if err := query.Err(); err != nil && !errors.Is(err, bstore.ErrFinished) { return nil, err } return lists, nil } func (s *bstoreStorage) ListsByGroup(group Group) ([]List, error) { ctx := context.Background() ids := make([]int64, 0) for item, err := range bstore.QueryDB[ListGroup](ctx, s.db).FilterEqual("GroupID", group.ID).All() { if err != nil { return nil, err } ids = append(ids, item.ListID) } var lists []List for list, err := range bstore.QueryDB[List](ctx, s.db).FilterIDs(ids).All() { if err != nil { return nil, err } lists = append(lists, list) } return lists, nil } func (s *bstoreStorage) ListByID(id int64) (List, error) { ctx := context.Background() list, err := bstore.QueryDB[List](ctx, s.db).FilterID(id).Get() if err != nil { return list, err } return s.listResolveGroups(ctx, list) } func (s *bstoreStorage) listResolveGroups(ctx context.Context, list List) (List, error) { for listGroup, err := range bstore.QueryDB[ListGroup](ctx, s.db).FilterEqual("ListID", list.ID).All() { if err != nil { return List{}, err } if group, err := s.GroupByID(listGroup.GroupID); err == nil { list.Groups = append(list.Groups, group) } } return list, nil } func (s *bstoreStorage) SaveList(list *List) (err error) { if list.Type != ListTypeDomain && list.Type != ListTypeNetwork { return fmt.Errorf("storage: unknown list type %q", list.Type) } if list.Refresh == 0 { list.Refresh = DefaultListRefresh } else if list.Refresh < MinListRefresh { list.Refresh = MinListRefresh } list.UpdatedAt = time.Now().UTC() ctx := context.Background() tx, err := s.db.Begin(ctx, true) if err != nil { return err } log := s.log() log = log.Values(logger.Values{ "type": list.Type, "source": list.Source, "is_enabled": list.IsEnabled, "status": list.Status, "cache": len(list.Cache), "refresh": list.Refresh, }) if list.CreatedAt.Equal(time.Time{}) { log.Debug("Creating list") list.CreatedAt = list.UpdatedAt if err = tx.Insert(list); err != nil { return fmt.Errorf("dataset: list insert failed: %w", err) } } else { log.Debug("Updating list") if err = tx.Update(list); err != nil { return fmt.Errorf("dataset: list update failed: %w", err) } } var deleted int if deleted, err = bstore.QueryTx[ListGroup](tx).FilterEqual("ListID", list.ID).Delete(); err != nil { return fmt.Errorf("dataset: list groups delete failed: %w", err) } log.Debugf("Deleted %d groups", deleted) log.Debugf("Linking %d groups", len(list.Groups)) for _, group := range list.Groups { if err = tx.Insert(&ListGroup{ListID: list.ID, GroupID: group.ID}); err != nil { return fmt.Errorf("dataset: list groups insert failed: %w", err) } } return tx.Commit() } func (s *bstoreStorage) DeleteList(list List) (err error) { ctx := context.Background() tx, err := s.db.Begin(ctx, true) if err != nil { return err } if _, err = bstore.QueryTx[ListGroup](tx).FilterEqual("ListID", list.ID).Delete(); err != nil { return } if err = tx.Delete(list); err != nil { return } return tx.Commit() }