Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 23 additions & 6 deletions evm/p2p/kademlia.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import ipaddress
import logging
import bisect
import operator
import random
import struct
Expand Down Expand Up @@ -106,6 +107,7 @@ def __hash__(self):
return hash(self.pubkey)


@total_ordering
class KBucket:
"""A bucket of nodes whose IDs fall between the bucket's start and end.

Expand Down Expand Up @@ -189,6 +191,11 @@ def __contains__(self, node):
def __len__(self):
return len(self.nodes)

def __lt__(self, other):
if not isinstance(other, self.__class__):
raise TypeError("Cannot compare KBucket with type {}.".format(other.__class__))
return self.end < other.start


class RoutingTable:

Expand All @@ -212,12 +219,12 @@ def not_full_buckets(self):
return [b for b in self.buckets if not b.is_full]

def remove_node(self, node):
self.get_bucket_for_node(node).remove_node(node)
binary_get_bucket_for_node(self.buckets, node)

def add_node(self, node):
if node == self.this_node:
raise ValueError("Cannot add this_node to routing table")
bucket = self.get_bucket_for_node(node)
bucket = binary_get_bucket_for_node(self.buckets, node)
eviction_candidate = bucket.add(node)
if eviction_candidate is not None: # bucket is full
# Split if the bucket has the local node in its range or if the depth is not congruent
Expand All @@ -231,10 +238,7 @@ def add_node(self, node):
return None # successfully added to not full bucket

def get_bucket_for_node(self, node):
for bucket in self.buckets:
if node.id < bucket.end:
return bucket
raise ValueError("No bucket found for node with id {}".format(node.id))
return binary_get_bucket_for_node(self.buckets, node)

def buckets_by_distance_to(self, id):
return sorted(self.buckets, key=operator.methodcaller('distance_to', id))
Expand Down Expand Up @@ -264,6 +268,19 @@ def neighbours(self, node_id, k=k_bucket_size):
return sort_by_distance(nodes, node_id)[:k]


def binary_get_bucket_for_node(buckets, node):
"""Given a list of ordered buckets, returns the bucket for a given node."""
bucket_ends = [bucket.end for bucket in buckets]
bucket_position = bisect.bisect_left(bucket_ends, node.id)
# Prevents edge cases where bisect_left returns an out of range index
try:
bucket = buckets[bucket_position]
assert bucket.start <= node.id <= bucket.end
return bucket
except (IndexError, AssertionError):
raise ValueError("No bucket found for node with id {}".format(node.id))


class KademliaProtocol:
logger = logging.getLogger("evm.p2p.discovery.KademliaProtocol")

Expand Down
53 changes: 53 additions & 0 deletions evm/p2p/test_kademlia.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,59 @@ def test_kbucket_split():
assert len(bucket2) == bucket.k / 2


def test_bucket_ordering():
first = kademlia.KBucket(0, 50)
second = kademlia.KBucket(51, 100)
third = random_node()
assert first < second
with pytest.raises(TypeError):
first > third


@pytest.mark.parametrize(
"bucket_list, node_id",
(
(list([]), 5),
# test for node.id < bucket.end
(list([kademlia.KBucket(0, 4)]), 5),
# test for node.id > bucket.start
(list([kademlia.KBucket(6, 10)]), 5),
# test multiple buckets that don't contain node.id
(list(
[
kademlia.KBucket(1, 5),
kademlia.KBucket(6, 49),
kademlia.KBucket(50, 100),
]
), 0),
)
)
def test_binary_get_bucket_for_node_error(bucket_list, node_id):
node = random_node(nodeid=node_id)
with pytest.raises(ValueError):
kademlia.binary_get_bucket_for_node(bucket_list, node)


@pytest.mark.parametrize(
"bucket_list, node_id, correct_position",
(
(list([kademlia.KBucket(0, 100)]), 5, 0),
(list([kademlia.KBucket(0, 49), kademlia.KBucket(50, 100)]), 5, 0),
(list(
[
kademlia.KBucket(0, 1),
kademlia.KBucket(2, 5),
kademlia.KBucket(6, 49),
kademlia.KBucket(50, 100)
]
), 5, 1),
)
)
def test_binary_get_bucket_for_node(bucket_list, node_id, correct_position):
node = random_node(nodeid=node_id)
assert kademlia.binary_get_bucket_for_node(bucket_list, node) == bucket_list[correct_position]


def test_compute_shared_prefix_bits():
# When we have less than 2 nodes, the depth is k_id_size.
nodes = [random_node()]
Expand Down