Skip to content

Commit

Permalink
device: remove nodes by peer in O(1) instead of O(n)
Browse files Browse the repository at this point in the history
Now that we have parent pointers hooked up, we can simply go right to
the node and remove it in place, rather than having to recursively walk
the entire trie.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
  • Loading branch information
zx2c4 committed Jun 3, 2021
1 parent b41f4cc commit c382222
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 72 deletions.
58 changes: 32 additions & 26 deletions device/allowedips.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,30 +85,6 @@ func (node *trieEntry) removeFromPeerEntries() {
}
}

func (node *trieEntry) removeByPeer(p *Peer) *trieEntry {
if node == nil {
return node
}

// walk recursively

node.child[0] = node.child[0].removeByPeer(p)
node.child[1] = node.child[1].removeByPeer(p)

if node.peer != p {
return node
}

// remove peer & merge

node.removeFromPeerEntries()
node.peer = nil
if node.child[0] == nil {
return node.child[1]
}
return node.child[0]
}

func (node *trieEntry) choose(ip net.IP) byte {
return (ip[node.bitAtByte] >> node.bitAtShift) & 1
}
Expand Down Expand Up @@ -261,8 +237,38 @@ func (table *AllowedIPs) RemoveByPeer(peer *Peer) {
table.mutex.Lock()
defer table.mutex.Unlock()

table.IPv4 = table.IPv4.removeByPeer(peer)
table.IPv6 = table.IPv6.removeByPeer(peer)
var next *list.Element
for elem := peer.trieEntries.Front(); elem != nil; elem = next {
next = elem.Next()
node := elem.Value.(*trieEntry)

node.removeFromPeerEntries()
node.peer = nil
if node.child[0] != nil && node.child[1] != nil {
continue
}
bit := 0
if node.child[0] == nil {
bit = 1
}
child := node.child[bit]
if child != nil {
child.parent = node.parent
}
*node.parent.parentBit = child
if node.child[0] != nil || node.child[1] != nil || node.parent.parentBitType > 1 {
continue
}
parent := (*trieEntry)(unsafe.Pointer(uintptr(unsafe.Pointer(node.parent.parentBit)) - unsafe.Offsetof(node.child) - unsafe.Sizeof(node.child[0])*uintptr(node.parent.parentBitType)))
if parent.peer != nil {
continue
}
child = parent.child[node.parent.parentBitType^1]
if child != nil {
child.parent = parent.parent
}
*parent.parent.parentBit = child
}
}

func (table *AllowedIPs) Insert(ip net.IP, cidr uint8, peer *Peer) {
Expand Down
96 changes: 50 additions & 46 deletions device/allowedips_rand_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ package device

import (
"math/rand"
"net"
"sort"
"testing"
)
Expand Down Expand Up @@ -64,68 +65,71 @@ func (r SlowRouter) Lookup(addr []byte) *Peer {
return nil
}

func TestTrieRandomIPv4(t *testing.T) {
var slow SlowRouter
var peers []*Peer
var allowedIPs AllowedIPs

rand.Seed(1)

const AddressLength = 4

for n := 0; n < NumberOfPeers; n++ {
peers = append(peers, &Peer{})
}

for n := 0; n < NumberOfAddresses; n++ {
var addr [AddressLength]byte
rand.Read(addr[:])
cidr := uint8(rand.Uint32() % (AddressLength * 8))
index := rand.Int() % NumberOfPeers
allowedIPs.Insert(addr[:], cidr, peers[index])
slow = slow.Insert(addr[:], cidr, peers[index])
}

for n := 0; n < NumberOfTests; n++ {
var addr [AddressLength]byte
rand.Read(addr[:])
peer1 := slow.Lookup(addr[:])
peer2 := allowedIPs.LookupIPv4(addr[:])
if peer1 != peer2 {
t.Error("Trie did not match naive implementation, for:", addr)
func (r SlowRouter) RemoveByPeer(peer *Peer) SlowRouter {
n := 0
for _, x := range r {
if x.peer != peer {
r[n] = x
n++
}
}
return r[:n]
}

func TestTrieRandomIPv6(t *testing.T) {
var slow SlowRouter
func TestTrieRandom(t *testing.T) {
var slow4, slow6 SlowRouter
var peers []*Peer
var allowedIPs AllowedIPs

rand.Seed(1)

const AddressLength = 16

for n := 0; n < NumberOfPeers; n++ {
peers = append(peers, &Peer{})
}

for n := 0; n < NumberOfAddresses; n++ {
var addr [AddressLength]byte
rand.Read(addr[:])
cidr := uint8(rand.Uint32() % (AddressLength * 8))
index := rand.Int() % NumberOfPeers
allowedIPs.Insert(addr[:], cidr, peers[index])
slow = slow.Insert(addr[:], cidr, peers[index])
var addr4 [4]byte
rand.Read(addr4[:])
cidr := uint8(rand.Intn(32) + 1)
index := rand.Intn(NumberOfPeers)
allowedIPs.Insert(addr4[:], cidr, peers[index])
slow4 = slow4.Insert(addr4[:], cidr, peers[index])

var addr6 [16]byte
rand.Read(addr6[:])
cidr = uint8(rand.Intn(128) + 1)
index = rand.Intn(NumberOfPeers)
allowedIPs.Insert(addr6[:], cidr, peers[index])
slow6 = slow6.Insert(addr6[:], cidr, peers[index])
}

for n := 0; n < NumberOfTests; n++ {
var addr [AddressLength]byte
rand.Read(addr[:])
peer1 := slow.Lookup(addr[:])
peer2 := allowedIPs.LookupIPv6(addr[:])
if peer1 != peer2 {
t.Error("Trie did not match naive implementation, for:", addr)
for p := 0; ; p++ {
for n := 0; n < NumberOfTests; n++ {
var addr4 [4]byte
rand.Read(addr4[:])
peer1 := slow4.Lookup(addr4[:])
peer2 := allowedIPs.LookupIPv4(addr4[:])
if peer1 != peer2 {
t.Errorf("Trie did not match naive implementation, for %v: want %p, got %p", net.IP(addr4[:]), peer1, peer2)
}

var addr6 [16]byte
rand.Read(addr6[:])
peer1 = slow6.Lookup(addr6[:])
peer2 = allowedIPs.LookupIPv6(addr6[:])
if peer1 != peer2 {
t.Errorf("Trie did not match naive implementation, for %v: want %p, got %p", net.IP(addr6[:]), peer1, peer2)
}
}
if p >= len(peers) {
break
}
allowedIPs.RemoveByPeer(peers[p])
slow4 = slow4.RemoveByPeer(peers[p])
slow6 = slow6.RemoveByPeer(peers[p])
}

if allowedIPs.IPv4 != nil || allowedIPs.IPv6 != nil {
t.Error("Failed to remove all nodes from trie by peer")
}
}

0 comments on commit c382222

Please sign in to comment.