diff --git a/internal/home/clientindex.go b/internal/home/clientindex.go index c217f3efc57..87d7406a747 100644 --- a/internal/home/clientindex.go +++ b/internal/home/clientindex.go @@ -16,33 +16,32 @@ type macKey any func macToKey(mac net.HardwareAddr) (key macKey) { switch len(mac) { case 6: - arr := *(*[6]byte)(mac) - - return arr + return [6]byte(mac) case 8: - arr := *(*[8]byte)(mac) - - return arr + return [8]byte(mac) case 20: - arr := *(*[20]byte)(mac) - - return arr + return [20]byte(mac) default: - panic("invalid mac address") + panic(fmt.Errorf("invalid mac address %#v", mac)) } } // clientIndex stores all information about persistent clients. type clientIndex struct { + // clientIDToUID maps client ID to UID. clientIDToUID map[string]UID + // ipToUID maps IP address to UID. ipToUID map[netip.Addr]UID - subnetToUID aghalg.SortedMap[netip.Prefix, UID] - + // macToUID maps MAC address to UID. macToUID map[macKey]UID + // uidToClient maps UID to the persistent client. uidToClient map[UID]*persistentClient + + // subnetToUID maps subnet to UID. + subnetToUID aghalg.SortedMap[netip.Prefix, UID] } // NewClientIndex initializes the new instance of client index. @@ -56,9 +55,13 @@ func NewClientIndex() (ci *clientIndex) { } } -// add stores information about a persistent client in the index. c must -// contain UID. +// add stores information about a persistent client in the index. c must be +// non-nil and contain UID. func (ci *clientIndex) add(c *persistentClient) { + if (c.UID == UID{}) { + panic("client must contain uid") + } + for _, id := range c.ClientIDs { ci.clientIDToUID[id] = c.UID } @@ -80,7 +83,7 @@ func (ci *clientIndex) add(c *persistentClient) { } // clashes returns an error if the index contains a different persistent client -// with at least a single identifier contained by c. +// with at least a single identifier contained by c. c must be non-nil. func (ci *clientIndex) clashes(c *persistentClient) (err error) { for _, id := range c.ClientIDs { existing, ok := ci.clientIDToUID[id] @@ -109,7 +112,8 @@ func (ci *clientIndex) clashes(c *persistentClient) (err error) { return nil } -// clashesIP returns a previous client with the same IP address as c. +// clashesIP returns a previous client with the same IP address as c. c must be +// non-nil. func (ci *clientIndex) clashesIP(c *persistentClient) (p *persistentClient, ip netip.Addr) { for _, ip := range c.IPs { existing, ok := ci.ipToUID[ip] @@ -121,12 +125,13 @@ func (ci *clientIndex) clashesIP(c *persistentClient) (p *persistentClient, ip n return nil, netip.Addr{} } -// clashesSubnet returns a previous client with the same subnet as c. +// clashesSubnet returns a previous client with the same subnet as c. c must be +// non-nil. func (ci *clientIndex) clashesSubnet(c *persistentClient) (p *persistentClient, s netip.Prefix) { - var existing UID - var ok bool - for _, s = range c.Subnets { + var existing UID + var ok bool + ci.subnetToUID.Range(func(p netip.Prefix, uid UID) (cont bool) { if s == p { existing = uid @@ -146,7 +151,8 @@ func (ci *clientIndex) clashesSubnet(c *persistentClient) (p *persistentClient, return nil, netip.Prefix{} } -// clashesMAC returns a previous client with the same MAC address as c. +// clashesMAC returns a previous client with the same MAC address as c. c must +// be non-nil. func (ci *clientIndex) clashesMAC(c *persistentClient) (p *persistentClient, mac net.HardwareAddr) { for _, mac = range c.MACs { k := macToKey(mac) @@ -219,7 +225,8 @@ func (ci *clientIndex) findByMAC(mac net.HardwareAddr) (c *persistentClient, fou return nil, false } -// del removes information about persistent client from the index. +// del removes information about persistent client from the index. c must be +// non-nil. func (ci *clientIndex) del(c *persistentClient) { for _, id := range c.ClientIDs { delete(ci.clientIDToUID, id) diff --git a/internal/home/clientindex_internal_test.go b/internal/home/clientindex_internal_test.go index 868b2fc68ea..97ba2faf335 100644 --- a/internal/home/clientindex_internal_test.go +++ b/internal/home/clientindex_internal_test.go @@ -1,6 +1,7 @@ package home import ( + "net" "net/netip" "testing" @@ -114,3 +115,30 @@ func TestClientIndex(t *testing.T) { require.NoError(t, err) }) } + +func TestMACToKey(t *testing.T) { + macs := []string{ + "00:00:5e:00:53:01", + "02:00:5e:10:00:00:00:01", + "00:00:00:00:fe:80:00:00:00:00:00:00:02:00:5e:10:00:00:00:01", + "00-00-5e-00-53-01", + "02-00-5e-10-00-00-00-01", + "00-00-00-00-fe-80-00-00-00-00-00-00-02-00-5e-10-00-00-00-01", + "0000.5e00.5301", + "0200.5e10.0000.0001", + "0000.0000.fe80.0000.0000.0000.0200.5e10.0000.0001", + } + + for _, m := range macs { + mac, err := net.ParseMAC(m) + require.NoError(t, err) + + key := macToKey(mac) + assert.Len(t, key, len(mac)) + } + + assert.Panics(t, func() { + mac := net.HardwareAddr([]byte{1, 2, 3}) + _ = macToKey(mac) + }) +} diff --git a/internal/home/clients_internal_test.go b/internal/home/clients_internal_test.go index 5e3f6d507c6..6b384f6b792 100644 --- a/internal/home/clients_internal_test.go +++ b/internal/home/clients_internal_test.go @@ -113,6 +113,7 @@ func TestClients(t *testing.T) { t.Run("add_fail_name", func(t *testing.T) { ok, err := clients.add(&persistentClient{ Name: "client1", + UID: MustNewUID(), IPs: []netip.Addr{netip.MustParseAddr("1.2.3.5")}, }) require.NoError(t, err) @@ -122,6 +123,7 @@ func TestClients(t *testing.T) { t.Run("add_fail_ip", func(t *testing.T) { ok, err := clients.add(&persistentClient{ Name: "client3", + UID: MustNewUID(), }) require.Error(t, err) assert.False(t, ok) @@ -130,6 +132,7 @@ func TestClients(t *testing.T) { t.Run("update_fail_ip", func(t *testing.T) { err := clients.update(&persistentClient{Name: "client1"}, &persistentClient{ Name: "client1", + UID: MustNewUID(), }) assert.Error(t, err) }) @@ -147,6 +150,7 @@ func TestClients(t *testing.T) { err := clients.update(prev, &persistentClient{ Name: "client1", + UID: MustNewUID(), IPs: []netip.Addr{cliNewIP}, }) require.NoError(t, err) @@ -161,6 +165,7 @@ func TestClients(t *testing.T) { err = clients.update(prev, &persistentClient{ Name: "client1-renamed", + UID: MustNewUID(), IPs: []netip.Addr{cliNewIP}, UseOwnSettings: true, }) @@ -262,6 +267,7 @@ func TestClientsWHOIS(t *testing.T) { ok, err := clients.add(&persistentClient{ Name: "client1", + UID: MustNewUID(), IPs: []netip.Addr{netip.MustParseAddr("1.1.1.2")}, }) require.NoError(t, err) @@ -284,6 +290,7 @@ func TestClientsAddExisting(t *testing.T) { // Add a client. ok, err := clients.add(&persistentClient{ Name: "client1", + UID: MustNewUID(), IPs: []netip.Addr{ip, netip.MustParseAddr("1:2:3::4")}, Subnets: []netip.Prefix{netip.MustParsePrefix("2.2.2.0/24")}, MACs: []net.HardwareAddr{{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}}, @@ -334,6 +341,7 @@ func TestClientsAddExisting(t *testing.T) { // Add a new client with the same IP as for a client with MAC. ok, err := clients.add(&persistentClient{ Name: "client2", + UID: MustNewUID(), IPs: []netip.Addr{ip}, }) require.NoError(t, err) @@ -342,6 +350,7 @@ func TestClientsAddExisting(t *testing.T) { // Add a new client with the IP from the first client's IP range. ok, err = clients.add(&persistentClient{ Name: "client3", + UID: MustNewUID(), IPs: []netip.Addr{netip.MustParseAddr("2.2.2.2")}, }) require.NoError(t, err) @@ -355,6 +364,7 @@ func TestClientsCustomUpstream(t *testing.T) { // Add client with upstreams. ok, err := clients.add(&persistentClient{ Name: "client1", + UID: MustNewUID(), IPs: []netip.Addr{netip.MustParseAddr("1.1.1.1"), netip.MustParseAddr("1:2:3::4")}, Upstreams: []string{ "1.1.1.1", diff --git a/internal/home/dns_internal_test.go b/internal/home/dns_internal_test.go index 592971c8a76..89b29be4384 100644 --- a/internal/home/dns_internal_test.go +++ b/internal/home/dns_internal_test.go @@ -12,7 +12,9 @@ import ( var testIPv4 = netip.AddrFrom4([4]byte{1, 2, 3, 4}) -func idIndex(m map[string]*persistentClient) (ci *clientIndex) { +// newIDIndex is a helper function that returns a client index filled with +// persistent clients from the m. +func newIDIndex(m map[string]*persistentClient) (ci *clientIndex) { ci = NewClientIndex() for id, c := range m { @@ -35,7 +37,7 @@ func TestApplyAdditionalFiltering(t *testing.T) { }, nil) require.NoError(t, err) - Context.clients.clientIndex = idIndex(map[string]*persistentClient{ + Context.clients.clientIndex = newIDIndex(map[string]*persistentClient{ "default": { UseOwnSettings: false, safeSearchConf: filtering.SafeSearchConfig{Enabled: false}, @@ -121,7 +123,7 @@ func TestApplyAdditionalFiltering_blockedServices(t *testing.T) { }, nil) require.NoError(t, err) - Context.clients.clientIndex = idIndex(map[string]*persistentClient{ + Context.clients.clientIndex = newIDIndex(map[string]*persistentClient{ "default": { UseOwnBlockedServices: false, },