Skip to content

Commit

Permalink
all: imp code
Browse files Browse the repository at this point in the history
  • Loading branch information
schzhn committed Feb 8, 2024
1 parent b45c7d8 commit 10637fd
Show file tree
Hide file tree
Showing 8 changed files with 128 additions and 68 deletions.
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package aghalg

import (
"golang.org/x/exp/maps"
"golang.org/x/exp/slices"
)

Expand All @@ -18,7 +19,7 @@ type SortedMap[K comparable, V any] struct {
// TODO(s.chzhen): Use cmp.Compare in Go 1.21.
func NewSortedMap[K comparable, V any](cmp func(a, b K) (res int)) SortedMap[K, V] {
return SortedMap[K, V]{
vals: make(map[K]V),
vals: map[K]V{},
cmp: cmp,
}
}
Expand Down Expand Up @@ -69,7 +70,7 @@ func (m *SortedMap[K, V]) Clear() {

// TODO(s.chzhen): Use built-in clear in Go 1.21.
m.keys = nil
m.vals = make(map[K]V)
maps.Clear(m.vals)
}

// Range calls cb for each element of the map, sorted by m.cmp. If cb returns
Expand Down
File renamed without changes.
10 changes: 10 additions & 0 deletions internal/home/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,16 @@ func NewUID() (uid UID, err error) {
return UID(uuidv7), err
}

// MustNewUID is a wrapper around [NewUID] that panics if there is an error.
func MustNewUID() (uid UID) {
uid, err := NewUID()
if err != nil {
panic(fmt.Errorf("unexpected uuidv7 error: %w", err))
}

return uid
}

// type check
var _ encoding.TextMarshaler = UID{}

Expand Down
93 changes: 66 additions & 27 deletions internal/home/clientindex.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package home

import (
"fmt"
"net"
"net/netip"

Expand All @@ -10,23 +11,24 @@ import (
// macKey contains MAC as byte array of 6, 8, or 20 bytes.
type macKey any

// macToKey converts mac into key of type macKey, which is used as the key of
// the [clientIndex.macToUID]. mac must be valid MAC address.
func macToKey(mac net.HardwareAddr) (key macKey) {
switch len(mac) {
case 6:
arr := [6]byte{}
copy(arr[:], mac[:])
arr := *(*[6]byte)(mac)

return arr
case 8:
arr := [8]byte{}
copy(arr[:], mac[:])
arr := *(*[8]byte)(mac)

return arr
default:
arr := [20]byte{}
copy(arr[:], mac[:])
case 20:
arr := *(*[20]byte)(mac)

return arr
default:
panic("invalid mac address")
}
}

Expand Down Expand Up @@ -54,7 +56,8 @@ func NewClientIndex() (ci *clientIndex) {
}
}

// add stores information about a persistent client in the index.
// add stores information about a persistent client in the index. c must
// contain UID.
func (ci *clientIndex) add(c *persistentClient) {
for _, id := range c.ClientIDs {
ci.clientIDToUID[id] = c.UID
Expand All @@ -76,26 +79,57 @@ func (ci *clientIndex) add(c *persistentClient) {
ci.uidToClient[c.UID] = c
}

// contains returns true if the index contains a persistent client with at least
// a single identifier contained by c.
func (ci *clientIndex) contains(c *persistentClient) (ok bool) {
// clashes returns an error if the index contains a different persistent client
// with at least a single identifier contained by c.
func (ci *clientIndex) clashes(c *persistentClient) (err error) {
for _, id := range c.ClientIDs {
_, ok = ci.clientIDToUID[id]
if ok {
return true
existing, ok := ci.clientIDToUID[id]
if ok && existing != c.UID {
p := ci.uidToClient[existing]

return fmt.Errorf("another client %q uses the same ID %q", p.Name, id)
}
}

p, ip := ci.clashesIP(c)
if p != nil {
return fmt.Errorf("another client %q uses the same IP %q", p.Name, ip)
}

p, s := ci.clashesSubnet(c)
if p != nil {
return fmt.Errorf("another client %q uses the same subnet %q", p.Name, s)
}

p, mac := ci.clashesMAC(c)
if p != nil {
return fmt.Errorf("another client %q uses the same MAC %q", p.Name, mac)
}

return nil
}

// clashesIP returns a previous client with the same IP address as c.
func (ci *clientIndex) clashesIP(c *persistentClient) (p *persistentClient, ip netip.Addr) {
for _, ip := range c.IPs {
_, ok = ci.ipToUID[ip]
if ok {
return true
existing, ok := ci.ipToUID[ip]
if ok && existing != c.UID {
return ci.uidToClient[existing], ip
}
}

for _, pref := range c.Subnets {
ci.subnetToUID.Range(func(p netip.Prefix, _ UID) (cont bool) {
if pref == p {
return nil, netip.Addr{}
}

// clashesSubnet returns a previous client with the same subnet as c.
func (ci *clientIndex) clashesSubnet(c *persistentClient) (p *persistentClient, s netip.Prefix) {
var existing UID
var ok bool

for _, s = range c.Subnets {
ci.subnetToUID.Range(func(p netip.Prefix, uid UID) (cont bool) {
if s == p {
existing = uid
ok = true

return false
Expand All @@ -104,20 +138,25 @@ func (ci *clientIndex) contains(c *persistentClient) (ok bool) {
return true
})

if ok {
return true
if ok && existing != c.UID {
return ci.uidToClient[existing], s
}
}

for _, mac := range c.MACs {
return nil, netip.Prefix{}
}

// clashesMAC returns a previous client with the same MAC address as c.
func (ci *clientIndex) clashesMAC(c *persistentClient) (p *persistentClient, mac net.HardwareAddr) {
for _, mac = range c.MACs {
k := macToKey(mac)
_, ok = ci.macToUID[k]
if ok {
return true
existing, ok := ci.macToUID[k]
if ok && existing != c.UID {
return ci.uidToClient[existing], mac
}
}

return false
return nil, nil
}

// find finds persistent client by string representation of the client ID, IP
Expand Down
17 changes: 13 additions & 4 deletions internal/home/clientindex_internal_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package home

import (
"net/netip"
"testing"

"github.com/AdguardTeam/AdGuardHome/internal/filtering"
Expand Down Expand Up @@ -97,11 +98,19 @@ func TestClientIndex(t *testing.T) {
})

t.Run("contains_delete", func(t *testing.T) {
ok := ci.contains(client1)
require.True(t, ok)
err := ci.clashes(client1)
require.NoError(t, err)

dup := &persistentClient{
Name: "client_with_the_same_ip_as_client1",
IPs: []netip.Addr{netip.MustParseAddr(cliIP1)},
UID: MustNewUID(),
}
err = ci.clashes(dup)
require.Error(t, err)

ci.del(client1)
ok = ci.contains(client1)
require.False(t, ok)
err = ci.clashes(dup)
require.NoError(t, err)
})
}
48 changes: 17 additions & 31 deletions internal/home/clients.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,9 @@ type DHCP interface {
type clientsContainer struct {
// TODO(a.garipov): Perhaps use a number of separate indices for different
// types (string, netip.Addr, and so on).
list map[string]*persistentClient // name -> client
idIndex map[string]*persistentClient // ID -> client
list map[string]*persistentClient // name -> client

clientIndex *clientIndex

// ipToRC maps IP addresses to runtime client information.
ipToRC map[netip.Addr]*client.Runtime
Expand Down Expand Up @@ -103,9 +104,10 @@ func (clients *clientsContainer) Init(
}

clients.list = map[string]*persistentClient{}
clients.idIndex = map[string]*persistentClient{}
clients.ipToRC = map[netip.Addr]*client.Runtime{}

clients.clientIndex = NewClientIndex()

clients.allTags = stringutil.NewSet(clientTags...)

// TODO(e.burkov): Use [dhcpsvc] implementation when it's ready.
Expand Down Expand Up @@ -518,7 +520,7 @@ func (clients *clientsContainer) UpstreamConfigByID(
// findLocked searches for a client by its ID. clients.lock is expected to be
// locked.
func (clients *clientsContainer) findLocked(id string) (c *persistentClient, ok bool) {
c, ok = clients.idIndex[id]
c, ok = clients.clientIndex.find(id)
if ok {
return c, true
}
Expand All @@ -528,14 +530,6 @@ func (clients *clientsContainer) findLocked(id string) (c *persistentClient, ok
return nil, false
}

for _, c = range clients.list {
for _, subnet := range c.Subnets {
if subnet.Contains(ip) {
return c, true
}
}
}

// TODO(e.burkov): Iterate through clients.list only once.
return clients.findDHCP(ip)
}
Expand Down Expand Up @@ -639,18 +633,15 @@ func (clients *clientsContainer) add(c *persistentClient) (ok bool, err error) {
}

// check ID index
ids := c.ids()
for _, id := range ids {
var c2 *persistentClient
c2, ok = clients.idIndex[id]
if ok {
return false, fmt.Errorf("another client uses the same ID (%q): %q", id, c2.Name)
}
err = clients.clientIndex.clashes(c)
if err != nil {
// Don't wrap the error since it's informative enough as is.
return false, err
}

clients.addLocked(c)

log.Debug("clients: added %q: ID:%q [%d]", c.Name, ids, len(clients.list))
log.Debug("clients: added %q: ID:%q [%d]", c.Name, c.ids(), len(clients.list))

return true, nil
}
Expand All @@ -661,9 +652,7 @@ func (clients *clientsContainer) addLocked(c *persistentClient) {
clients.list[c.Name] = c

// update ID index
for _, id := range c.ids() {
clients.idIndex[id] = c
}
clients.clientIndex.add(c)
}

// remove removes a client. ok is false if there is no such client.
Expand Down Expand Up @@ -693,9 +682,7 @@ func (clients *clientsContainer) removeLocked(c *persistentClient) {
delete(clients.list, c.Name)

// Update the ID index.
for _, id := range c.ids() {
delete(clients.idIndex, id)
}
clients.clientIndex.del(c)
}

// update updates a client by its name.
Expand Down Expand Up @@ -725,11 +712,10 @@ func (clients *clientsContainer) update(prev, c *persistentClient) (err error) {
}

// Check the ID index.
for _, id := range c.ids() {
existing, ok := clients.idIndex[id]
if ok && existing != prev {
return fmt.Errorf("id %q is used by client with name %q", id, existing.Name)
}
err = clients.clientIndex.clashes(c)
if err != nil {
// Don't wrap the error since it's informative enough as is.
return err
}

clients.removeLocked(prev)
Expand Down
2 changes: 2 additions & 0 deletions internal/home/clients_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ func TestClients(t *testing.T) {

c := &persistentClient{
Name: "client1",
UID: MustNewUID(),
IPs: []netip.Addr{cli1IP, cliIPv6},
}

Expand All @@ -78,6 +79,7 @@ func TestClients(t *testing.T) {

c = &persistentClient{
Name: "client2",
UID: MustNewUID(),
IPs: []netip.Addr{cli2IP},
}

Expand Down
Loading

0 comments on commit 10637fd

Please sign in to comment.