diff --git a/internal/home/clientindex_internal_test.go b/internal/home/clientindex_internal_test.go index 97ba2faf335..b89703db0bb 100644 --- a/internal/home/clientindex_internal_test.go +++ b/internal/home/clientindex_internal_test.go @@ -5,7 +5,6 @@ import ( "net/netip" "testing" - "github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -25,116 +24,183 @@ func TestClientIndex(t *testing.T) { cliMAC = "11:11:11:11:11:11" ) - objs := []clientObject{{ - Name: "client1", - IDs: []string{cliIP1, cliIPv6}, - BlockedServices: &filtering.BlockedServices{}, + clients := []*persistentClient{{ + Name: "client1", + IPs: []netip.Addr{ + netip.MustParseAddr(cliIP1), + netip.MustParseAddr(cliIPv6), + }, }, { - Name: "client2", - IDs: []string{cliIP2, cliSubnet}, - BlockedServices: &filtering.BlockedServices{}, + Name: "client2", + IPs: []netip.Addr{netip.MustParseAddr(cliIP2)}, + Subnets: []netip.Prefix{netip.MustParsePrefix(cliSubnet)}, }, { - Name: "client_with_mac", - IDs: []string{cliMAC}, - BlockedServices: &filtering.BlockedServices{}, + Name: "client_with_mac", + MACs: []net.HardwareAddr{mustParseMAC(cliMAC)}, }, { - Name: "client_with_id", - IDs: []string{cliID}, - BlockedServices: &filtering.BlockedServices{}, + Name: "client_with_id", + ClientIDs: []string{cliID}, }} - clients := []*persistentClient{} - for _, o := range objs { - cli, err := o.toPersistent(&filtering.Config{}, nil) - require.NoError(t, err) + ci := newIDIndex(clients) - clients = append(clients, cli) - } - - client1 := clients[0] - client2 := clients[1] - clientWithMAC := clients[2] - clientWithID := clients[3] - - ci := NewClientIndex() - - t.Run("add_find", func(t *testing.T) { - ci.add(client1) - ci.add(client2) - ci.add(clientWithMAC) - ci.add(clientWithID) - - c, ok := ci.find(cliIP1) - require.True(t, ok) - - assert.Equal(t, client1.Name, c.Name) + testCases := []struct { + name string + ids []string + want *persistentClient + }{{ + name: "ipv4_ipv6", + ids: []string{cliIP1, cliIPv6}, + want: clients[0], + }, { + name: "ipv4_subnet", + ids: []string{cliIP2, cliSubnetIP}, + want: clients[1], + }, { + name: "mac", + ids: []string{cliMAC}, + want: clients[2], + }, { + name: "client_id", + ids: []string{cliID}, + want: clients[3], + }} - c, ok = ci.find(cliIPv6) - require.True(t, ok) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + for _, id := range tc.ids { + c, ok := ci.find(id) + require.True(t, ok) - assert.Equal(t, client1.Name, c.Name) + assert.Equal(t, tc.want, c) + } + }) + } - c, ok = ci.find(cliIP2) - require.True(t, ok) + t.Run("not_found", func(t *testing.T) { + _, ok := ci.find(cliIPNone) + assert.False(t, ok) + }) +} - assert.Equal(t, client2.Name, c.Name) +func TestClientIndex_Clashes(t *testing.T) { + const ( + cliIP1 = "1.1.1.1" + cliSubnet = "2.2.2.0/24" + cliSubnetIP = "2.2.2.222" + cliID = "client-id" + cliMAC = "11:11:11:11:11:11" + ) - c, ok = ci.find(cliSubnetIP) - require.True(t, ok) + clients := []*persistentClient{{ + Name: "client_with_ip", + IPs: []netip.Addr{netip.MustParseAddr(cliIP1)}, + }, { + Name: "client_with_subnet", + Subnets: []netip.Prefix{netip.MustParsePrefix(cliSubnet)}, + }, { + Name: "client_with_mac", + MACs: []net.HardwareAddr{mustParseMAC(cliMAC)}, + }, { + Name: "client_with_id", + ClientIDs: []string{cliID}, + }} - assert.Equal(t, client2.Name, c.Name) + ci := newIDIndex(clients) - c, ok = ci.find(cliMAC) - require.True(t, ok) + testCases := []struct { + name string + client *persistentClient + }{{ + name: "ipv4", + client: clients[0], + }, { + name: "subnet", + client: clients[1], + }, { + name: "mac", + client: clients[2], + }, { + name: "client_id", + client: clients[3], + }} - assert.Equal(t, clientWithMAC.Name, c.Name) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + clone := tc.client.shallowClone() + clone.UID = MustNewUID() - c, ok = ci.find(cliID) - require.True(t, ok) + err := ci.clashes(clone) + require.Error(t, err) - assert.Equal(t, clientWithID.Name, c.Name) + ci.del(tc.client) + err = ci.clashes(clone) + require.NoError(t, err) + }) + } +} - _, ok = ci.find(cliIPNone) - assert.False(t, ok) - }) +// mustParseMAC is wrapper around [net.ParseMAC] that panics if there is an +// error. +func mustParseMAC(s string) (mac net.HardwareAddr) { + mac, err := net.ParseMAC(s) + if err != nil { + panic(err) + } - t.Run("contains_delete", func(t *testing.T) { - err := ci.clashes(client1) - require.NoError(t, err) - - dup := &persistentClient{ - Name: "client_with_the_same_ip_as_client1", - IPs: []netip.Addr{netip.MustParseAddr(cliIP1)}, - UID: MustNewUID(), - } - err = ci.clashes(dup) - require.Error(t, err) - - ci.del(client1) - err = ci.clashes(dup) - require.NoError(t, err) - }) + return mac } func TestMACToKey(t *testing.T) { - macs := []string{ - "00:00:5e:00:53:01", - "02:00:5e:10:00:00:00:01", - "00:00:00:00:fe:80:00:00:00:00:00:00:02:00:5e:10:00:00:00:01", - "00-00-5e-00-53-01", - "02-00-5e-10-00-00-00-01", - "00-00-00-00-fe-80-00-00-00-00-00-00-02-00-5e-10-00-00-00-01", - "0000.5e00.5301", - "0200.5e10.0000.0001", - "0000.0000.fe80.0000.0000.0000.0200.5e10.0000.0001", - } + testCases := []struct { + name string + in string + want any + }{{ + name: "column6", + in: "00:00:5e:00:53:01", + want: [6]byte(mustParseMAC("00:00:5e:00:53:01")), + }, { + name: "column8", + in: "02:00:5e:10:00:00:00:01", + want: [8]byte(mustParseMAC("02:00:5e:10:00:00:00:01")), + }, { + name: "column20", + in: "00:00:00:00:fe:80:00:00:00:00:00:00:02:00:5e:10:00:00:00:01", + want: [20]byte(mustParseMAC("00:00:00:00:fe:80:00:00:00:00:00:00:02:00:5e:10:00:00:00:01")), + }, { + name: "hyphen6", + in: "00-00-5e-00-53-01", + want: [6]byte(mustParseMAC("00-00-5e-00-53-01")), + }, { + name: "hyphen8", + in: "02-00-5e-10-00-00-00-01", + want: [8]byte(mustParseMAC("02-00-5e-10-00-00-00-01")), + }, { + name: "hyphen20", + in: "00-00-00-00-fe-80-00-00-00-00-00-00-02-00-5e-10-00-00-00-01", + want: [20]byte(mustParseMAC("00-00-00-00-fe-80-00-00-00-00-00-00-02-00-5e-10-00-00-00-01")), + }, { + name: "dot6", + in: "0000.5e00.5301", + want: [6]byte(mustParseMAC("0000.5e00.5301")), + }, { + name: "dot8", + in: "0200.5e10.0000.0001", + want: [8]byte(mustParseMAC("0200.5e10.0000.0001")), + }, { + name: "dot20", + in: "0000.0000.fe80.0000.0000.0000.0200.5e10.0000.0001", + want: [20]byte(mustParseMAC("0000.0000.fe80.0000.0000.0000.0200.5e10.0000.0001")), + }} - for _, m := range macs { - mac, err := net.ParseMAC(m) - require.NoError(t, err) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + mac := mustParseMAC(tc.in) - key := macToKey(mac) - assert.Len(t, key, len(mac)) + key := macToKey(mac) + assert.Equal(t, tc.want, key) + }) } assert.Panics(t, func() { diff --git a/internal/home/dns_internal_test.go b/internal/home/dns_internal_test.go index 89b29be4384..ff279752ee4 100644 --- a/internal/home/dns_internal_test.go +++ b/internal/home/dns_internal_test.go @@ -13,14 +13,12 @@ import ( var testIPv4 = netip.AddrFrom4([4]byte{1, 2, 3, 4}) // newIDIndex is a helper function that returns a client index filled with -// persistent clients from the m. -func newIDIndex(m map[string]*persistentClient) (ci *clientIndex) { +// persistent clients from the m. It also generates a UID for each client. +func newIDIndex(m []*persistentClient) (ci *clientIndex) { ci = NewClientIndex() - for id, c := range m { - c.ClientIDs = []string{id} + for _, c := range m { c.UID = MustNewUID() - ci.add(c) } @@ -37,29 +35,28 @@ func TestApplyAdditionalFiltering(t *testing.T) { }, nil) require.NoError(t, err) - Context.clients.clientIndex = newIDIndex(map[string]*persistentClient{ - "default": { - UseOwnSettings: false, - safeSearchConf: filtering.SafeSearchConfig{Enabled: false}, - FilteringEnabled: false, - SafeBrowsingEnabled: false, - ParentalEnabled: false, - }, - "custom_filtering": { - UseOwnSettings: true, - safeSearchConf: filtering.SafeSearchConfig{Enabled: true}, - FilteringEnabled: true, - SafeBrowsingEnabled: true, - ParentalEnabled: true, - }, - "partial_custom_filtering": { - UseOwnSettings: true, - safeSearchConf: filtering.SafeSearchConfig{Enabled: true}, - FilteringEnabled: true, - SafeBrowsingEnabled: false, - ParentalEnabled: false, - }, - }) + Context.clients.clientIndex = newIDIndex([]*persistentClient{{ + ClientIDs: []string{"default"}, + UseOwnSettings: false, + safeSearchConf: filtering.SafeSearchConfig{Enabled: false}, + FilteringEnabled: false, + SafeBrowsingEnabled: false, + ParentalEnabled: false, + }, { + ClientIDs: []string{"custom_filtering"}, + UseOwnSettings: true, + safeSearchConf: filtering.SafeSearchConfig{Enabled: true}, + FilteringEnabled: true, + SafeBrowsingEnabled: true, + ParentalEnabled: true, + }, { + ClientIDs: []string{"partial_custom_filtering"}, + UseOwnSettings: true, + safeSearchConf: filtering.SafeSearchConfig{Enabled: true}, + FilteringEnabled: true, + SafeBrowsingEnabled: false, + ParentalEnabled: false, + }}) testCases := []struct { name string @@ -123,38 +120,37 @@ func TestApplyAdditionalFiltering_blockedServices(t *testing.T) { }, nil) require.NoError(t, err) - Context.clients.clientIndex = newIDIndex(map[string]*persistentClient{ - "default": { - UseOwnBlockedServices: false, - }, - "no_services": { - BlockedServices: &filtering.BlockedServices{ - Schedule: schedule.EmptyWeekly(), - }, - UseOwnBlockedServices: true, + Context.clients.clientIndex = newIDIndex([]*persistentClient{{ + ClientIDs: []string{"default"}, + UseOwnBlockedServices: false, + }, { + ClientIDs: []string{"no_services"}, + BlockedServices: &filtering.BlockedServices{ + Schedule: schedule.EmptyWeekly(), }, - "services": { - BlockedServices: &filtering.BlockedServices{ - Schedule: schedule.EmptyWeekly(), - IDs: clientBlockedServices, - }, - UseOwnBlockedServices: true, + UseOwnBlockedServices: true, + }, { + ClientIDs: []string{"services"}, + BlockedServices: &filtering.BlockedServices{ + Schedule: schedule.EmptyWeekly(), + IDs: clientBlockedServices, }, - "invalid_services": { - BlockedServices: &filtering.BlockedServices{ - Schedule: schedule.EmptyWeekly(), - IDs: invalidBlockedServices, - }, - UseOwnBlockedServices: true, + UseOwnBlockedServices: true, + }, { + ClientIDs: []string{"invalid_services"}, + BlockedServices: &filtering.BlockedServices{ + Schedule: schedule.EmptyWeekly(), + IDs: invalidBlockedServices, }, - "allow_all": { - BlockedServices: &filtering.BlockedServices{ - Schedule: schedule.FullWeekly(), - IDs: clientBlockedServices, - }, - UseOwnBlockedServices: true, + UseOwnBlockedServices: true, + }, { + ClientIDs: []string{"allow_all"}, + BlockedServices: &filtering.BlockedServices{ + Schedule: schedule.FullWeekly(), + IDs: clientBlockedServices, }, - }) + UseOwnBlockedServices: true, + }}) testCases := []struct { name string