From 5fde76bb20f818f052fe89dc90c2b3ea790da4d2 Mon Sep 17 00:00:00 2001 From: Stanislav Chzhen Date: Fri, 21 Jun 2024 16:58:17 +0300 Subject: [PATCH] all: imp code --- internal/client/persistent.go | 5 +- internal/client/persistent_internal_test.go | 2 +- internal/client/storage.go | 96 +++++++++++-- internal/client/storage_test.go | 13 +- internal/home/clients.go | 147 +++++++++++++++++--- internal/home/clients_internal_test.go | 55 ++++++-- internal/home/clientshttp.go | 32 +++-- internal/home/clientshttp_internal_test.go | 10 +- internal/home/dns_internal_test.go | 23 +-- 9 files changed, 296 insertions(+), 87 deletions(-) diff --git a/internal/client/persistent.go b/internal/client/persistent.go index 1c2ce848223..229efe61038 100644 --- a/internal/client/persistent.go +++ b/internal/client/persistent.go @@ -70,6 +70,7 @@ type Persistent struct { // BlockedServices is the configuration of blocked services of a client. BlockedServices *filtering.BlockedServices + // Name of the persistent client. Must not be empty. Name string Tags []string @@ -99,8 +100,8 @@ type Persistent struct { SafeSearchConf filtering.SafeSearchConfig } -// Validate returns an error if persistent client information contains errors. -func (c *Persistent) Validate(allTags *container.MapSet[string]) (err error) { +// validate returns an error if persistent client information contains errors. +func (c *Persistent) validate(allTags *container.MapSet[string]) (err error) { switch { case c.Name == "": return errors.Error("empty name") diff --git a/internal/client/persistent_internal_test.go b/internal/client/persistent_internal_test.go index 75fad1277f6..89190285184 100644 --- a/internal/client/persistent_internal_test.go +++ b/internal/client/persistent_internal_test.go @@ -166,7 +166,7 @@ func TestPersistent_Validate(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - err := tc.cli.Validate(nil) + err := tc.cli.validate(nil) testutil.AssertErrorMsg(t, tc.wantErrMsg, err) }) } diff --git a/internal/client/storage.go b/internal/client/storage.go index 59be5c3661b..d9abc529596 100644 --- a/internal/client/storage.go +++ b/internal/client/storage.go @@ -6,32 +6,46 @@ import ( "net/netip" "sync" + "github.com/AdguardTeam/golibs/container" "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" ) // Storage contains information about persistent and runtime clients. type Storage struct { - // mu protects index of persistent clients. + // allowedTags is a set of all allowed tags. + allowedTags *container.MapSet[string] + + // mu protects indexes of persistent and runtime clients. mu *sync.Mutex // index contains information about persistent clients. index *Index + + // runtimeIndex contains information about runtime clients. + runtimeIndex *RuntimeIndex } // NewStorage returns initialized client storage. -func NewStorage() (s *Storage) { +func NewStorage(allowedTags *container.MapSet[string]) (s *Storage) { return &Storage{ - mu: &sync.Mutex{}, - index: NewIndex(), + allowedTags: allowedTags, + mu: &sync.Mutex{}, + index: NewIndex(), + runtimeIndex: NewRuntimeIndex(), } } -// Add stores persistent client information or returns an error. p must be -// valid persistent client. See [Persistent.Validate]. +// Add stores persistent client information or returns an error. func (s *Storage) Add(p *Persistent) (err error) { defer func() { err = errors.Annotate(err, "adding client: %w") }() + err = p.validate(s.allowedTags) + if err != nil { + // Don't wrap the error since there is already an annotation deferred. + return err + } + s.mu.Lock() defer s.mu.Unlock() @@ -129,11 +143,16 @@ func (s *Storage) RemoveByName(name string) (ok bool) { } // Update finds the stored persistent client by its name and updates its -// information from n. n must be valid persistent client. See -// [Persistent.Validate]. -func (s *Storage) Update(name string, n *Persistent) (err error) { +// information from p. +func (s *Storage) Update(name string, p *Persistent) (err error) { defer func() { err = errors.Annotate(err, "updating client: %w") }() + err = p.validate(s.allowedTags) + if err != nil { + // Don't wrap the error since there is already an annotation deferred. + return err + } + s.mu.Lock() defer s.mu.Unlock() @@ -142,19 +161,19 @@ func (s *Storage) Update(name string, n *Persistent) (err error) { return fmt.Errorf("client %q is not found", name) } - // Client n has a newly generated UID, so replace it with the stored one. + // Client p has a newly generated UID, so replace it with the stored one. // // TODO(s.chzhen): Remove when frontend starts handling UIDs. - n.UID = stored.UID + p.UID = stored.UID - err = s.index.Clashes(n) + err = s.index.Clashes(p) if err != nil { // Don't wrap the error since there is already an annotation deferred. return err } s.index.Delete(stored) - s.index.Add(n) + s.index.Add(p) return nil } @@ -183,3 +202,54 @@ func (s *Storage) CloseUpstreams() (err error) { return s.index.CloseUpstreams() } + +// ClientRuntime returns a copy of the saved runtime client by ip. If no such +// client exists, returns nil. +func (s *Storage) ClientRuntime(ip netip.Addr) (rc *Runtime) { + s.mu.Lock() + defer s.mu.Unlock() + + return s.runtimeIndex.Client(ip) +} + +// AddRuntime saves the runtime client information in the storage. IP address +// of a client must be unique. rc must not be nil. +func (s *Storage) AddRuntime(rc *Runtime) { + s.mu.Lock() + defer s.mu.Unlock() + + s.runtimeIndex.Add(rc) +} + +// SizeRuntime returns the number of the runtime clients. +func (s *Storage) SizeRuntime() (n int) { + s.mu.Lock() + defer s.mu.Unlock() + + return s.runtimeIndex.Size() +} + +// RangeRuntime calls f for each runtime client in an undefined order. +func (s *Storage) RangeRuntime(f func(rc *Runtime) (cont bool)) { + s.mu.Lock() + defer s.mu.Unlock() + + s.runtimeIndex.Range(f) +} + +// DeleteRuntime removes the runtime client by ip. +func (s *Storage) DeleteRuntime(ip netip.Addr) { + s.mu.Lock() + defer s.mu.Unlock() + + s.runtimeIndex.Delete(ip) +} + +// DeleteBySource removes all runtime clients that have information only from +// the specified source and returns the number of removed clients. +func (s *Storage) DeleteBySource(src Source) (n int) { + s.mu.Lock() + defer s.mu.Unlock() + + return s.runtimeIndex.DeleteBySource(src) +} diff --git a/internal/client/storage_test.go b/internal/client/storage_test.go index b7b03bd8741..fef021085a8 100644 --- a/internal/client/storage_test.go +++ b/internal/client/storage_test.go @@ -16,7 +16,7 @@ import ( func newStorage(tb testing.TB, m []*client.Persistent) (s *client.Storage) { tb.Helper() - s = client.NewStorage() + s = client.NewStorage(nil) for _, c := range m { c.UID = client.MustNewUID() @@ -57,7 +57,7 @@ func TestStorage_Add(t *testing.T) { UID: existingClientUID, } - s := client.NewStorage() + s := client.NewStorage(nil) err := s.Add(existingClient) require.NoError(t, err) @@ -137,7 +137,7 @@ func TestStorage_RemoveByName(t *testing.T) { UID: client.MustNewUID(), } - s := client.NewStorage() + s := client.NewStorage(nil) err := s.Add(existingClient) require.NoError(t, err) @@ -162,7 +162,7 @@ func TestStorage_RemoveByName(t *testing.T) { } t.Run("duplicate_remove", func(t *testing.T) { - s = client.NewStorage() + s = client.NewStorage(nil) err = s.Add(existingClient) require.NoError(t, err) @@ -366,6 +366,7 @@ func TestStorage_Update(t *testing.T) { cli: &client.Persistent{ Name: "basic", IPs: []netip.Addr{netip.MustParseAddr("1.1.1.1")}, + UID: client.MustNewUID(), }, wantErrMsg: "", }, { @@ -373,6 +374,7 @@ func TestStorage_Update(t *testing.T) { cli: &client.Persistent{ Name: obstructingName, IPs: []netip.Addr{netip.MustParseAddr("3.3.3.3")}, + UID: client.MustNewUID(), }, wantErrMsg: `updating client: another client uses the same name "obstructing_name"`, }, { @@ -380,6 +382,7 @@ func TestStorage_Update(t *testing.T) { cli: &client.Persistent{ Name: "duplicate_ip", IPs: []netip.Addr{obstructingIP}, + UID: client.MustNewUID(), }, wantErrMsg: `updating client: another client "obstructing_name" uses the same IP "1.2.3.4"`, }, { @@ -387,6 +390,7 @@ func TestStorage_Update(t *testing.T) { cli: &client.Persistent{ Name: "duplicate_subnet", Subnets: []netip.Prefix{obstructingSubnet}, + UID: client.MustNewUID(), }, wantErrMsg: `updating client: another client "obstructing_name" ` + `uses the same subnet "1.2.3.0/24"`, @@ -395,6 +399,7 @@ func TestStorage_Update(t *testing.T) { cli: &client.Persistent{ Name: "duplicate_client_id", ClientIDs: []string{obstructingClientID}, + UID: client.MustNewUID(), }, wantErrMsg: `updating client: another client "obstructing_name" ` + `uses the same ClientID "obstructing_client_id"`, diff --git a/internal/home/clients.go b/internal/home/clients.go index e924ce43a0f..72c14178390 100644 --- a/internal/home/clients.go +++ b/internal/home/clients.go @@ -44,8 +44,8 @@ type DHCP interface { // clientsContainer is the storage of all runtime and persistent clients. type clientsContainer struct { - // storage stores information about persistent clients. - storage *client.Storage + // clientIndex stores information about persistent clients. + clientIndex *client.Index // runtimeIndex stores information about runtime clients. runtimeIndex *client.RuntimeIndex @@ -103,13 +103,13 @@ func (clients *clientsContainer) Init( filteringConf *filtering.Config, ) (err error) { // TODO(s.chzhen): Refactor it. - if clients.storage != nil { + if clients.clientIndex != nil { return errors.Error("clients container already initialized") } clients.runtimeIndex = client.NewRuntimeIndex() - clients.storage = client.NewStorage() + clients.clientIndex = client.NewIndex() clients.allTags = container.NewMapSet(clientTags...) @@ -285,14 +285,17 @@ func (clients *clientsContainer) addFromConfig( return fmt.Errorf("clients: init persistent client at index %d: %w", i, err) } - err = cli.Validate(clients.allTags) + // TODO(s.chzhen): Consider moving to the client index constructor. + err = clients.clientIndex.ClashesUID(cli) if err != nil { - return fmt.Errorf("validating client %s at index %d: %w", cli.Name, i, err) + return fmt.Errorf("adding client %s at index %d: %w", cli.Name, i, err) } - err = clients.storage.Add(cli) + err = clients.add(cli) if err != nil { - return fmt.Errorf("adding client %s at index %d: %w", cli.Name, i, err) + // TODO(s.chzhen): Return an error instead of logging if more + // stringent requirements are implemented. + log.Error("clients: adding client %s at index %d: %s", cli.Name, i, err) } } @@ -305,8 +308,8 @@ func (clients *clientsContainer) forConfig() (objs []*clientObject) { clients.lock.Lock() defer clients.lock.Unlock() - objs = []*clientObject{} - clients.storage.RangeByName(func(cli *client.Persistent) (cont bool) { + objs = make([]*clientObject, 0, clients.clientIndex.Size()) + clients.clientIndex.RangeByName(func(cli *client.Persistent) (cont bool) { objs = append(objs, &clientObject{ Name: cli.Name, @@ -333,7 +336,7 @@ func (clients *clientsContainer) forConfig() (objs []*clientObject) { return true }) - return slices.Clip(objs) + return objs } // arpClientsUpdatePeriod defines how often ARP clients are updated. @@ -409,8 +412,12 @@ func (clients *clientsContainer) clientOrArtificial( } }() - cli, ok := clients.storage.FindLoose(ip, id) - if ok { + cli, ok := clients.find(id) + if !ok { + cli = clients.clientIndex.FindByIPWithoutZone(ip) + } + + if cli != nil { return &querylog.Client{ Name: cli.Name, IgnoreQueryLog: cli.IgnoreQueryLog, @@ -516,7 +523,7 @@ func (clients *clientsContainer) UpstreamConfigByID( // findLocked searches for a client by its ID. clients.lock is expected to be // locked. func (clients *clientsContainer) findLocked(id string) (c *client.Persistent, ok bool) { - c, ok = clients.storage.Find(id) + c, ok = clients.clientIndex.Find(id) if ok { return c, true } @@ -538,7 +545,7 @@ func (clients *clientsContainer) findDHCP(ip netip.Addr) (c *client.Persistent, return nil, false } - return clients.storage.FindByMAC(foundMAC) + return clients.clientIndex.FindByMAC(foundMAC) } // runtimeClient returns a runtime client from internal index. Note that it @@ -572,6 +579,114 @@ func (clients *clientsContainer) findRuntimeClient(ip netip.Addr) (rc *client.Ru return rc } +// check validates the client. It also sorts the client tags. +func (clients *clientsContainer) check(c *client.Persistent) (err error) { + switch { + case c == nil: + return errors.Error("client is nil") + case c.Name == "": + return errors.Error("invalid name") + case c.IDsLen() == 0: + return errors.Error("id required") + default: + // Go on. + } + + for _, t := range c.Tags { + if !clients.allTags.Has(t) { + return fmt.Errorf("invalid tag: %q", t) + } + } + + // TODO(s.chzhen): Move to the constructor. + slices.Sort(c.Tags) + + _, err = proxy.ParseUpstreamsConfig(c.Upstreams, &upstream.Options{}) + if err != nil { + return fmt.Errorf("invalid upstream servers: %w", err) + } + + return nil +} + +// add adds a persistent client or returns an error. +func (clients *clientsContainer) add(c *client.Persistent) (err error) { + err = clients.check(c) + if err != nil { + // Don't wrap the error since it's informative enough as is. + return err + } + + clients.lock.Lock() + defer clients.lock.Unlock() + + err = clients.clientIndex.Clashes(c) + if err != nil { + // Don't wrap the error since it's informative enough as is. + return err + } + + clients.addLocked(c) + + log.Debug("clients: added %q: ID:%q [%d]", c.Name, c.IDs(), clients.clientIndex.Size()) + + return nil +} + +// addLocked c to the indexes. clients.lock is expected to be locked. +func (clients *clientsContainer) addLocked(c *client.Persistent) { + clients.clientIndex.Add(c) +} + +// remove removes a client. ok is false if there is no such client. +func (clients *clientsContainer) remove(name string) (ok bool) { + clients.lock.Lock() + defer clients.lock.Unlock() + + c, ok := clients.clientIndex.FindByName(name) + if !ok { + return false + } + + clients.removeLocked(c) + + return true +} + +// removeLocked removes c from the indexes. clients.lock is expected to be +// locked. +func (clients *clientsContainer) removeLocked(c *client.Persistent) { + if err := c.CloseUpstreams(); err != nil { + log.Error("client container: removing client %s: %s", c.Name, err) + } + + // Update the ID index. + clients.clientIndex.Delete(c) +} + +// update updates a client by its name. +func (clients *clientsContainer) update(prev, c *client.Persistent) (err error) { + err = clients.check(c) + if err != nil { + // Don't wrap the error since it's informative enough as is. + return err + } + + clients.lock.Lock() + defer clients.lock.Unlock() + + err = clients.clientIndex.Clashes(c) + if err != nil { + // Don't wrap the error since it's informative enough as is. + return err + } + + clients.removeLocked(prev) + clients.addLocked(c) + + return nil +} + // setWHOISInfo sets the WHOIS information for a client. clients.lock is // expected to be locked. func (clients *clientsContainer) setWHOISInfo(ip netip.Addr, wi *whois.Info) { @@ -733,5 +848,5 @@ func (clients *clientsContainer) addFromSystemARP() { // close gracefully closes all the client-specific upstream configurations of // the persistent clients. func (clients *clientsContainer) close() (err error) { - return clients.storage.CloseUpstreams() + return clients.clientIndex.CloseUpstreams() } diff --git a/internal/home/clients_internal_test.go b/internal/home/clients_internal_test.go index 5a0ba47041c..d371df7b6ff 100644 --- a/internal/home/clients_internal_test.go +++ b/internal/home/clients_internal_test.go @@ -72,7 +72,7 @@ func TestClients(t *testing.T) { IPs: []netip.Addr{cli1IP, cliIPv6}, } - err := clients.storage.Add(c) + err := clients.add(c) require.NoError(t, err) c = &client.Persistent{ @@ -81,7 +81,7 @@ func TestClients(t *testing.T) { IPs: []netip.Addr{cli2IP}, } - err = clients.storage.Add(c) + err = clients.add(c) require.NoError(t, err) c, ok := clients.find(cli1) @@ -106,6 +106,31 @@ func TestClients(t *testing.T) { assert.Equal(t, clients.clientSource(cli2IP), client.SourcePersistent) }) + t.Run("add_fail_name", func(t *testing.T) { + err := clients.add(&client.Persistent{ + Name: "client1", + UID: client.MustNewUID(), + IPs: []netip.Addr{netip.MustParseAddr("1.2.3.5")}, + }) + require.Error(t, err) + }) + + t.Run("add_fail_ip", func(t *testing.T) { + err := clients.add(&client.Persistent{ + Name: "client3", + UID: client.MustNewUID(), + }) + require.Error(t, err) + }) + + t.Run("update_fail_ip", func(t *testing.T) { + err := clients.update(&client.Persistent{Name: "client1"}, &client.Persistent{ + Name: "client1", + UID: client.MustNewUID(), + }) + assert.Error(t, err) + }) + t.Run("update_success", func(t *testing.T) { var ( cliOld = "1.1.1.1" @@ -114,11 +139,11 @@ func TestClients(t *testing.T) { cliNewIP = netip.MustParseAddr(cliNew) ) - prev, ok := clients.storage.FindByName("client1") + prev, ok := clients.clientIndex.FindByName("client1") require.True(t, ok) require.NotNil(t, prev) - err := clients.storage.Update("client1", &client.Persistent{ + err := clients.update(prev, &client.Persistent{ Name: "client1", UID: prev.UID, IPs: []netip.Addr{cliNewIP}, @@ -130,11 +155,11 @@ func TestClients(t *testing.T) { assert.Equal(t, clients.clientSource(cliNewIP), client.SourcePersistent) - prev, ok = clients.storage.FindByName("client1") + prev, ok = clients.clientIndex.FindByName("client1") require.True(t, ok) require.NotNil(t, prev) - err = clients.storage.Update("client1", &client.Persistent{ + err = clients.update(prev, &client.Persistent{ Name: "client1-renamed", UID: prev.UID, IPs: []netip.Addr{cliNewIP}, @@ -148,7 +173,7 @@ func TestClients(t *testing.T) { assert.Equal(t, "client1-renamed", c.Name) assert.True(t, c.UseOwnSettings) - nilCli, ok := clients.storage.FindByName("client1") + nilCli, ok := clients.clientIndex.FindByName("client1") require.False(t, ok) assert.Nil(t, nilCli) @@ -159,7 +184,7 @@ func TestClients(t *testing.T) { }) t.Run("del_success", func(t *testing.T) { - ok := clients.storage.RemoveByName("client1-renamed") + ok := clients.remove("client1-renamed") require.True(t, ok) _, ok = clients.find("1.1.1.2") @@ -167,7 +192,7 @@ func TestClients(t *testing.T) { }) t.Run("del_fail", func(t *testing.T) { - ok := clients.storage.RemoveByName("client3") + ok := clients.remove("client3") assert.False(t, ok) }) @@ -236,7 +261,7 @@ func TestClientsWHOIS(t *testing.T) { t.Run("can't_set_manually-added", func(t *testing.T) { ip := netip.MustParseAddr("1.1.1.2") - err := clients.storage.Add(&client.Persistent{ + err := clients.add(&client.Persistent{ Name: "client1", UID: client.MustNewUID(), IPs: []netip.Addr{netip.MustParseAddr("1.1.1.2")}, @@ -247,7 +272,7 @@ func TestClientsWHOIS(t *testing.T) { rc := clients.runtimeIndex.Client(ip) require.Nil(t, rc) - assert.True(t, clients.storage.RemoveByName("client1")) + assert.True(t, clients.remove("client1")) }) } @@ -258,7 +283,7 @@ func TestClientsAddExisting(t *testing.T) { ip := netip.MustParseAddr("1.1.1.1") // Add a client. - err := clients.storage.Add(&client.Persistent{ + err := clients.add(&client.Persistent{ Name: "client1", UID: client.MustNewUID(), IPs: []netip.Addr{ip, netip.MustParseAddr("1:2:3::4")}, @@ -308,7 +333,7 @@ func TestClientsAddExisting(t *testing.T) { require.NoError(t, err) // Add a new client with the same IP as for a client with MAC. - err = clients.storage.Add(&client.Persistent{ + err = clients.add(&client.Persistent{ Name: "client2", UID: client.MustNewUID(), IPs: []netip.Addr{ip}, @@ -316,7 +341,7 @@ func TestClientsAddExisting(t *testing.T) { require.NoError(t, err) // Add a new client with the IP from the first client's IP range. - err = clients.storage.Add(&client.Persistent{ + err = clients.add(&client.Persistent{ Name: "client3", UID: client.MustNewUID(), IPs: []netip.Addr{netip.MustParseAddr("2.2.2.2")}, @@ -329,7 +354,7 @@ func TestClientsCustomUpstream(t *testing.T) { clients := newClientsContainer(t) // Add client with upstreams. - err := clients.storage.Add(&client.Persistent{ + err := clients.add(&client.Persistent{ Name: "client1", UID: client.MustNewUID(), IPs: []netip.Addr{netip.MustParseAddr("1.1.1.1"), netip.MustParseAddr("1:2:3::4")}, diff --git a/internal/home/clientshttp.go b/internal/home/clientshttp.go index aea5fceaf28..40a91f862ab 100644 --- a/internal/home/clientshttp.go +++ b/internal/home/clientshttp.go @@ -96,7 +96,7 @@ func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, r *http clients.lock.Lock() defer clients.lock.Unlock() - clients.storage.RangeByName(func(c *client.Persistent) (cont bool) { + clients.clientIndex.Range(func(c *client.Persistent) (cont bool) { cj := clientToJSON(c) data.Clients = append(data.Clients, cj) @@ -336,14 +336,7 @@ func (clients *clientsContainer) handleAddClient(w http.ResponseWriter, r *http. return } - err = c.Validate(clients.allTags) - if err != nil { - aghhttp.Error(r, w, http.StatusBadRequest, "%s", err) - - return - } - - err = clients.storage.Add(c) + err = clients.add(c) if err != nil { aghhttp.Error(r, w, http.StatusBadRequest, "%s", err) @@ -371,7 +364,7 @@ func (clients *clientsContainer) handleDelClient(w http.ResponseWriter, r *http. return } - if !clients.storage.RemoveByName(cj.Name) { + if !clients.remove(cj.Name) { aghhttp.Error(r, w, http.StatusBadRequest, "Client not found") return @@ -406,21 +399,30 @@ func (clients *clientsContainer) handleUpdateClient(w http.ResponseWriter, r *ht return } - c, err := clients.jsonToClient(dj.Data, nil) - if err != nil { - aghhttp.Error(r, w, http.StatusBadRequest, "%s", err) + var prev *client.Persistent + var ok bool + + func() { + clients.lock.Lock() + defer clients.lock.Unlock() + + prev, ok = clients.clientIndex.FindByName(dj.Name) + }() + + if !ok { + aghhttp.Error(r, w, http.StatusBadRequest, "client not found") return } - err = c.Validate(clients.allTags) + c, err := clients.jsonToClient(dj.Data, prev) if err != nil { aghhttp.Error(r, w, http.StatusBadRequest, "%s", err) return } - err = clients.storage.Update(dj.Name, c) + err = clients.update(prev, c) if err != nil { aghhttp.Error(r, w, http.StatusBadRequest, "%s", err) diff --git a/internal/home/clientshttp_internal_test.go b/internal/home/clientshttp_internal_test.go index f72a346fc13..dc1aa87d073 100644 --- a/internal/home/clientshttp_internal_test.go +++ b/internal/home/clientshttp_internal_test.go @@ -198,11 +198,11 @@ func TestClientsContainer_HandleDelClient(t *testing.T) { clients := newClientsContainer(t) clientOne := newPersistentClientWithIDs(t, "client1", []string{testClientIP1}) - err := clients.storage.Add(clientOne) + err := clients.add(clientOne) require.NoError(t, err) clientTwo := newPersistentClientWithIDs(t, "client2", []string{testClientIP2}) - err = clients.storage.Add(clientTwo) + err = clients.add(clientTwo) require.NoError(t, err) assertPersistentClients(t, clients, []*client.Persistent{clientOne, clientTwo}) @@ -260,7 +260,7 @@ func TestClientsContainer_HandleUpdateClient(t *testing.T) { clients := newClientsContainer(t) clientOne := newPersistentClientWithIDs(t, "client1", []string{testClientIP1}) - err := clients.storage.Add(clientOne) + err := clients.add(clientOne) require.NoError(t, err) assertPersistentClients(t, clients, []*client.Persistent{clientOne}) @@ -342,11 +342,11 @@ func TestClientsContainer_HandleFindClient(t *testing.T) { } clientOne := newPersistentClientWithIDs(t, "client1", []string{testClientIP1}) - err := clients.storage.Add(clientOne) + err := clients.add(clientOne) require.NoError(t, err) clientTwo := newPersistentClientWithIDs(t, "client2", []string{testClientIP2}) - err = clients.storage.Add(clientTwo) + err = clients.add(clientTwo) require.NoError(t, err) assertPersistentClients(t, clients, []*client.Persistent{clientOne, clientTwo}) diff --git a/internal/home/dns_internal_test.go b/internal/home/dns_internal_test.go index b50146724c0..8413e2a33fa 100644 --- a/internal/home/dns_internal_test.go +++ b/internal/home/dns_internal_test.go @@ -13,18 +13,17 @@ import ( var testIPv4 = netip.AddrFrom4([4]byte{1, 2, 3, 4}) -// newStorage is a helper function that returns a client index filled with +// newIDIndex is a helper function that returns a client index filled with // persistent clients from the m. It also generates a UID for each client. -func newStorage(tb testing.TB, m []*client.Persistent) (s *client.Storage) { - tb.Helper() +func newIDIndex(m []*client.Persistent) (ci *client.Index) { + ci = client.NewIndex() - s = client.NewStorage() for _, c := range m { c.UID = client.MustNewUID() - require.NoError(tb, s.Add(c)) + ci.Add(c) } - return s + return ci } func TestApplyAdditionalFiltering(t *testing.T) { @@ -37,8 +36,7 @@ func TestApplyAdditionalFiltering(t *testing.T) { }, nil) require.NoError(t, err) - Context.clients.storage = newStorage(t, []*client.Persistent{{ - Name: "default", + Context.clients.clientIndex = newIDIndex([]*client.Persistent{{ ClientIDs: []string{"default"}, UseOwnSettings: false, SafeSearchConf: filtering.SafeSearchConfig{Enabled: false}, @@ -46,7 +44,6 @@ func TestApplyAdditionalFiltering(t *testing.T) { SafeBrowsingEnabled: false, ParentalEnabled: false, }, { - Name: "custom_filtering", ClientIDs: []string{"custom_filtering"}, UseOwnSettings: true, SafeSearchConf: filtering.SafeSearchConfig{Enabled: true}, @@ -54,7 +51,6 @@ func TestApplyAdditionalFiltering(t *testing.T) { SafeBrowsingEnabled: true, ParentalEnabled: true, }, { - Name: "partial_custom_filtering", ClientIDs: []string{"partial_custom_filtering"}, UseOwnSettings: true, SafeSearchConf: filtering.SafeSearchConfig{Enabled: true}, @@ -125,19 +121,16 @@ func TestApplyAdditionalFiltering_blockedServices(t *testing.T) { }, nil) require.NoError(t, err) - Context.clients.storage = newStorage(t, []*client.Persistent{{ - Name: "default", + Context.clients.clientIndex = newIDIndex([]*client.Persistent{{ ClientIDs: []string{"default"}, UseOwnBlockedServices: false, }, { - Name: "no_services", ClientIDs: []string{"no_services"}, BlockedServices: &filtering.BlockedServices{ Schedule: schedule.EmptyWeekly(), }, UseOwnBlockedServices: true, }, { - Name: "services", ClientIDs: []string{"services"}, BlockedServices: &filtering.BlockedServices{ Schedule: schedule.EmptyWeekly(), @@ -145,7 +138,6 @@ func TestApplyAdditionalFiltering_blockedServices(t *testing.T) { }, UseOwnBlockedServices: true, }, { - Name: "invalid_services", ClientIDs: []string{"invalid_services"}, BlockedServices: &filtering.BlockedServices{ Schedule: schedule.EmptyWeekly(), @@ -153,7 +145,6 @@ func TestApplyAdditionalFiltering_blockedServices(t *testing.T) { }, UseOwnBlockedServices: true, }, { - Name: "allow_all", ClientIDs: []string{"allow_all"}, BlockedServices: &filtering.BlockedServices{ Schedule: schedule.FullWeekly(),