diff --git a/internal/client/persistent.go b/internal/client/persistent.go index 317dc72b522..1c2ce848223 100644 --- a/internal/client/persistent.go +++ b/internal/client/persistent.go @@ -12,6 +12,7 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch" "github.com/AdguardTeam/dnsproxy/proxy" + "github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/golibs/container" "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" @@ -98,6 +99,39 @@ type Persistent struct { SafeSearchConf filtering.SafeSearchConfig } +// Validate returns an error if persistent client information contains errors. +func (c *Persistent) Validate(allTags *container.MapSet[string]) (err error) { + switch { + case c.Name == "": + return errors.Error("empty name") + case c.IDsLen() == 0: + return errors.Error("id required") + case c.UID == UID{}: + return errors.Error("uid required") + } + + conf, err := proxy.ParseUpstreamsConfig(c.Upstreams, &upstream.Options{}) + if err != nil { + return fmt.Errorf("invalid upstream servers: %w", err) + } + + err = conf.Close() + if err != nil { + log.Error("client: closing upstream config: %s", err) + } + + for _, t := range c.Tags { + if !allTags.Has(t) { + return fmt.Errorf("invalid tag: %q", t) + } + } + + // TODO(s.chzhen): Move to the constructor. + slices.Sort(c.Tags) + + return nil +} + // SetTags sets the tags if they are known, otherwise logs an unknown tag. func (c *Persistent) SetTags(tags []string, known *container.MapSet[string]) { for _, t := range tags { diff --git a/internal/client/persistent_internal_test.go b/internal/client/persistent_internal_test.go index 76da1e4bbb8..75fad1277f6 100644 --- a/internal/client/persistent_internal_test.go +++ b/internal/client/persistent_internal_test.go @@ -1,13 +1,15 @@ package client import ( + "net/netip" "testing" + "github.com/AdguardTeam/golibs/testutil" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -func TestPersistentClient_EqualIDs(t *testing.T) { +func TestPersistent_EqualIDs(t *testing.T) { const ( ip = "0.0.0.0" ip1 = "1.1.1.1" @@ -122,3 +124,50 @@ func TestPersistentClient_EqualIDs(t *testing.T) { }) } } + +func TestPersistent_Validate(t *testing.T) { + // TODO(s.chzhen): Add test cases. + testCases := []struct { + name string + cli *Persistent + wantErrMsg string + }{{ + name: "basic", + cli: &Persistent{ + Name: "basic", + IPs: []netip.Addr{ + netip.MustParseAddr("1.2.3.4"), + }, + UID: MustNewUID(), + }, + wantErrMsg: "", + }, { + name: "empty_name", + cli: &Persistent{ + Name: "", + }, + wantErrMsg: "empty name", + }, { + name: "no_id", + cli: &Persistent{ + Name: "no_id", + }, + wantErrMsg: "id required", + }, { + name: "no_uid", + cli: &Persistent{ + Name: "no_uid", + IPs: []netip.Addr{ + netip.MustParseAddr("1.2.3.4"), + }, + }, + wantErrMsg: "uid required", + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err := tc.cli.Validate(nil) + testutil.AssertErrorMsg(t, tc.wantErrMsg, err) + }) + } +} diff --git a/internal/client/storage.go b/internal/client/storage.go index a336125c039..bc317d2b162 100644 --- a/internal/client/storage.go +++ b/internal/client/storage.go @@ -1,23 +1,14 @@ package client import ( - "fmt" "net/netip" - "slices" "sync" - "github.com/AdguardTeam/dnsproxy/proxy" - "github.com/AdguardTeam/dnsproxy/upstream" - "github.com/AdguardTeam/golibs/container" "github.com/AdguardTeam/golibs/errors" - "github.com/AdguardTeam/golibs/log" ) // Storage contains information about persistent and runtime clients. type Storage struct { - // allTags is a set of all client tags. - allTags *container.MapSet[string] - // mu protects index of persistent clients. mu *sync.Mutex @@ -29,11 +20,8 @@ type Storage struct { } // NewStorage returns initialized client storage. -func NewStorage(clientTags []string) (s *Storage) { - allTags := container.NewMapSet(clientTags...) - +func NewStorage() (s *Storage) { return &Storage{ - allTags: allTags, mu: &sync.Mutex{}, index: NewIndex(), runtimeIndex: map[netip.Addr]*Runtime{}, @@ -41,60 +29,26 @@ func NewStorage(clientTags []string) (s *Storage) { } // Add stores persistent client information or returns an error. p must be -// valid persistent client. +// valid persistent client. See [Persistent.Validate]. func (s *Storage) Add(p *Persistent) (err error) { + defer func() { err = errors.Annotate(err, "adding client: %w") }() + s.mu.Lock() defer s.mu.Unlock() - err = s.check(p) - if err != nil { - return fmt.Errorf("adding client: %w", err) - } - - s.index.Add(p) - - return nil -} - -// check returns an error if persistent client information contains errors. -// -// TODO(s.chzhen): Remove persistent client information validation. -func (s *Storage) check(p *Persistent) (err error) { - switch { - case p == nil: - return errors.Error("client is nil") - case p.Name == "": - return errors.Error("empty name") - case p.IDsLen() == 0: - return errors.Error("id required") - case p.UID == UID{}: - return errors.Error("uid required") - } - err = s.index.ClashesUID(p) if err != nil { // Don't wrap the error since there is already an annotation deferred. return err } - conf, err := proxy.ParseUpstreamsConfig(p.Upstreams, &upstream.Options{}) - if err != nil { - return fmt.Errorf("invalid upstream servers: %w", err) - } - - err = conf.Close() + err = s.index.Clashes(p) if err != nil { - log.Error("client: closing upstream config: %s", err) - } - - for _, t := range p.Tags { - if !s.allTags.Has(t) { - return fmt.Errorf("invalid tag: %q", t) - } + // Don't wrap the error since there is already an annotation deferred. + return err } - // TODO(s.chzhen): Move to the constructor. - slices.Sort(p.Tags) + s.index.Add(p) return nil } @@ -117,7 +71,6 @@ func (s *Storage) RemoveByName(name string) (ok bool) { func (s *Storage) Update(p, n *Persistent) (err error) { defer func() { err = errors.Annotate(err, "updating client: %w") }() - err = s.check(n) if err != nil { // Don't wrap the error since there is already an annotation deferred. return err diff --git a/internal/client/storage_test.go b/internal/client/storage_test.go index 66e23c4e44a..897472b96c5 100644 --- a/internal/client/storage_test.go +++ b/internal/client/storage_test.go @@ -6,10 +6,34 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/client" "github.com/AdguardTeam/golibs/testutil" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestStorage_Add(t *testing.T) { + const ( + existingName = "existing_name" + existingClientID = "existing_client_id" + ) + + var ( + existingClientUID = client.MustNewUID() + existingIP = netip.MustParseAddr("1.2.3.4") + existingSubnet = netip.MustParsePrefix("1.2.3.0/24") + ) + + existingClient := &client.Persistent{ + Name: existingName, + IPs: []netip.Addr{existingIP}, + Subnets: []netip.Prefix{existingSubnet}, + ClientIDs: []string{existingClientID}, + UID: existingClientUID, + } + + s := client.NewStorage() + err := s.Add(existingClient) + require.NoError(t, err) + testCases := []struct { name string cli *client.Persistent @@ -18,68 +42,104 @@ func TestStorage_Add(t *testing.T) { name: "basic", cli: &client.Persistent{ Name: "basic", - IPs: []netip.Addr{ - netip.MustParseAddr("1.2.3.4"), - }, - UID: client.MustNewUID(), + IPs: []netip.Addr{netip.MustParseAddr("1.1.1.1")}, + UID: client.MustNewUID(), }, wantErrMsg: "", }, { - name: "nil", - cli: nil, - wantErrMsg: "adding client: client is nil", + name: "duplicate_uid", + cli: &client.Persistent{ + Name: "no_uid", + IPs: []netip.Addr{netip.MustParseAddr("2.2.2.2")}, + UID: existingClientUID, + }, + wantErrMsg: `adding client: another client "existing_name" uses the same uid`, + }, { + name: "duplicate_name", + cli: &client.Persistent{ + Name: existingName, + IPs: []netip.Addr{netip.MustParseAddr("3.3.3.3")}, + UID: client.MustNewUID(), + }, + wantErrMsg: `adding client: another client uses the same name "existing_name"`, }, { - name: "empty_name", + name: "duplicate_ip", cli: &client.Persistent{ - Name: "", + Name: "duplicate_ip", + IPs: []netip.Addr{existingIP}, + UID: client.MustNewUID(), }, - wantErrMsg: "adding client: empty name", + wantErrMsg: `adding client: another client "existing_name" uses the same IP "1.2.3.4"`, }, { - name: "no_id", + name: "duplicate_subnet", cli: &client.Persistent{ - Name: "no_id", + Name: "duplicate_subnet", + Subnets: []netip.Prefix{existingSubnet}, + UID: client.MustNewUID(), }, - wantErrMsg: "adding client: id required", + wantErrMsg: `adding client: another client "existing_name" ` + + `uses the same subnet "1.2.3.0/24"`, }, { - name: "no_uid", + name: "duplicate_client_id", cli: &client.Persistent{ - Name: "no_uid", - IPs: []netip.Addr{ - netip.MustParseAddr("1.2.3.4"), - }, + Name: "duplicate_client_id", + ClientIDs: []string{existingClientID}, + UID: client.MustNewUID(), }, - wantErrMsg: "adding client: uid required", + wantErrMsg: `adding client: another client "existing_name" ` + + `uses the same ClientID "existing_client_id"`, }} for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - s := client.NewStorage(nil) - err := s.Add(tc.cli) + err = s.Add(tc.cli) testutil.AssertErrorMsg(t, tc.wantErrMsg, err) }) } +} + +func TestStorage_RemoveByName(t *testing.T) { + const ( + existingName = "existing_name" + ) + + existingClient := &client.Persistent{ + Name: existingName, + IPs: []netip.Addr{netip.MustParseAddr("1.2.3.4")}, + UID: client.MustNewUID(), + } - t.Run("duplicate_uid", func(t *testing.T) { - sameUID := client.MustNewUID() - s := client.NewStorage(nil) + s := client.NewStorage() + err := s.Add(existingClient) + require.NoError(t, err) - cli1 := &client.Persistent{ - Name: "cli1", - IPs: []netip.Addr{netip.MustParseAddr("1.2.3.4")}, - UID: sameUID, - } + testCases := []struct { + want assert.BoolAssertionFunc + name string + cliName string + }{{ + name: "existing_client", + cliName: existingName, + want: assert.True, + }, { + name: "non_existing_client", + cliName: "non_existing_client", + want: assert.False, + }} - cli2 := &client.Persistent{ - Name: "cli2", - IPs: []netip.Addr{netip.MustParseAddr("4.3.2.1")}, - UID: sameUID, - } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + tc.want(t, s.RemoveByName(tc.cliName)) + }) + } - err := s.Add(cli1) + t.Run("duplicate_remove", func(t *testing.T) { + s = client.NewStorage() + err = s.Add(existingClient) require.NoError(t, err) - err = s.Add(cli2) - testutil.AssertErrorMsg(t, `adding client: another client "cli1" uses the same uid`, err) + assert.True(t, s.RemoveByName(existingName)) + assert.False(t, s.RemoveByName(existingName)) }) }