diff --git a/evm/p2p/kademlia.py b/evm/p2p/kademlia.py index aff132499b..2a9cf9e137 100644 --- a/evm/p2p/kademlia.py +++ b/evm/p2p/kademlia.py @@ -1,6 +1,7 @@ import asyncio import ipaddress import logging +import bisect import operator import random import struct @@ -106,6 +107,7 @@ def __hash__(self): return hash(self.pubkey) +@total_ordering class KBucket: """A bucket of nodes whose IDs fall between the bucket's start and end. @@ -189,6 +191,11 @@ def __contains__(self, node): def __len__(self): return len(self.nodes) + def __lt__(self, other): + if not isinstance(other, self.__class__): + raise TypeError("Cannot compare KBucket with type {}.".format(other.__class__)) + return self.end < other.start + class RoutingTable: @@ -212,12 +219,12 @@ def not_full_buckets(self): return [b for b in self.buckets if not b.is_full] def remove_node(self, node): - self.get_bucket_for_node(node).remove_node(node) + binary_get_bucket_for_node(self.buckets, node) def add_node(self, node): if node == self.this_node: raise ValueError("Cannot add this_node to routing table") - bucket = self.get_bucket_for_node(node) + bucket = binary_get_bucket_for_node(self.buckets, node) eviction_candidate = bucket.add(node) if eviction_candidate is not None: # bucket is full # Split if the bucket has the local node in its range or if the depth is not congruent @@ -231,10 +238,7 @@ def add_node(self, node): return None # successfully added to not full bucket def get_bucket_for_node(self, node): - for bucket in self.buckets: - if node.id < bucket.end: - return bucket - raise ValueError("No bucket found for node with id {}".format(node.id)) + return binary_get_bucket_for_node(self.buckets, node) def buckets_by_distance_to(self, id): return sorted(self.buckets, key=operator.methodcaller('distance_to', id)) @@ -264,6 +268,19 @@ def neighbours(self, node_id, k=k_bucket_size): return sort_by_distance(nodes, node_id)[:k] +def binary_get_bucket_for_node(buckets, node): + """Given a list of ordered buckets, returns the bucket for a given node.""" + bucket_ends = [bucket.end for bucket in buckets] + bucket_position = bisect.bisect_left(bucket_ends, node.id) + # Prevents edge cases where bisect_left returns an out of range index + try: + bucket = buckets[bucket_position] + assert bucket.start <= node.id <= bucket.end + return bucket + except (IndexError, AssertionError): + raise ValueError("No bucket found for node with id {}".format(node.id)) + + class KademliaProtocol: logger = logging.getLogger("evm.p2p.discovery.KademliaProtocol") diff --git a/evm/p2p/test_kademlia.py b/evm/p2p/test_kademlia.py index 370de5a7d4..b622c0ad1b 100644 --- a/evm/p2p/test_kademlia.py +++ b/evm/p2p/test_kademlia.py @@ -263,6 +263,59 @@ def test_kbucket_split(): assert len(bucket2) == bucket.k / 2 +def test_bucket_ordering(): + first = kademlia.KBucket(0, 50) + second = kademlia.KBucket(51, 100) + third = random_node() + assert first < second + with pytest.raises(TypeError): + first > third + + +@pytest.mark.parametrize( + "bucket_list, node_id", + ( + (list([]), 5), + # test for node.id < bucket.end + (list([kademlia.KBucket(0, 4)]), 5), + # test for node.id > bucket.start + (list([kademlia.KBucket(6, 10)]), 5), + # test multiple buckets that don't contain node.id + (list( + [ + kademlia.KBucket(1, 5), + kademlia.KBucket(6, 49), + kademlia.KBucket(50, 100), + ] + ), 0), + ) +) +def test_binary_get_bucket_for_node_error(bucket_list, node_id): + node = random_node(nodeid=node_id) + with pytest.raises(ValueError): + kademlia.binary_get_bucket_for_node(bucket_list, node) + + +@pytest.mark.parametrize( + "bucket_list, node_id, correct_position", + ( + (list([kademlia.KBucket(0, 100)]), 5, 0), + (list([kademlia.KBucket(0, 49), kademlia.KBucket(50, 100)]), 5, 0), + (list( + [ + kademlia.KBucket(0, 1), + kademlia.KBucket(2, 5), + kademlia.KBucket(6, 49), + kademlia.KBucket(50, 100) + ] + ), 5, 1), + ) +) +def test_binary_get_bucket_for_node(bucket_list, node_id, correct_position): + node = random_node(nodeid=node_id) + assert kademlia.binary_get_bucket_for_node(bucket_list, node) == bucket_list[correct_position] + + def test_compute_shared_prefix_bits(): # When we have less than 2 nodes, the depth is k_id_size. nodes = [random_node()]