diff --git a/p2p/kademlia.py b/p2p/kademlia.py index dcf6fa5f13..c445bec2a9 100644 --- a/p2p/kademlia.py +++ b/p2p/kademlia.py @@ -1,19 +1,21 @@ +from functools import total_ordering +from urllib import parse as urlparse import asyncio +import bisect +import contextlib import ipaddress import logging -import bisect import operator import random import struct import time -from urllib import parse as urlparse -from functools import total_ordering from typing import ( # noqa: F401 Any, Callable, cast, Dict, Generator, + Iterator, List, Set, Sized, @@ -401,6 +403,15 @@ def update_routing_table(self, node: Node) -> None: # replacement cache. asyncio.ensure_future(self.bond(eviction_candidate, self.wire.cancel_token)) + @contextlib.contextmanager + def _ping_callback_event(self, remote: Node) -> Iterator[asyncio.Event]: + event = asyncio.Event() + self.ping_callbacks[remote] = event.set + try: + yield event + finally: + del self.ping_callbacks[remote] + async def wait_ping(self, remote: Node, cancel_token: CancelToken) -> bool: """Wait for a ping from the given remote. @@ -412,17 +423,15 @@ async def wait_ping(self, remote: Node, cancel_token: CancelToken) -> bool: raise AlreadyWaiting( "There's another coroutine waiting for a ping packet from {}".format(remote)) - event = asyncio.Event() - self.ping_callbacks[remote] = event.set - got_ping = False - try: - got_ping = await wait_with_token( - event.wait(), token=cancel_token, timeout=k_request_timeout) - self.logger.debug('got expected ping from %s', remote) - except TimeoutError: - self.logger.debug('timed out waiting for ping from %s', remote) - # TODO: Use a contextmanager to ensure we always delete the callback from the list. - del self.ping_callbacks[remote] + with self._ping_callback_event(remote) as event: + try: + got_ping = await wait_with_token( + event.wait(), token=cancel_token, timeout=k_request_timeout) + self.logger.debug('got expected ping from %s', remote) + except TimeoutError: + got_ping = False + self.logger.debug('timed out waiting for ping from %s', remote) + return got_ping async def wait_pong(self, remote: Node, token: bytes, cancel_token: CancelToken) -> bool: @@ -486,6 +495,7 @@ def process(response): def ping(self, node: Node) -> bytes: if node == self.this_node: raise ValueError("Cannot ping self") + return self.wire.send_ping(node) async def bond(self, node: Node, cancel_token: CancelToken) -> bool: @@ -496,6 +506,9 @@ async def bond(self, node: Node, cancel_token: CancelToken) -> bool: """ if node in self.routing: return True + elif node in self.ping_callbacks: + self.logger.debug("bonding failed, already waiting for pong from: %s", node) + return False token = self.ping(node) @@ -511,7 +524,11 @@ async def bond(self, node: Node, cancel_token: CancelToken) -> bool: # Give the remote node a chance to ping us before we move on and start sending find_node # requests. It is ok for wait_ping() to timeout and return false here as that just means # the remote remembers us. - await self.wait_ping(node, cancel_token) + try: + await self.wait_ping(node, cancel_token) + except AlreadyWaiting: + self.logger.debug("bonding failed, already waiting for pong from %s", node) + return False self.logger.debug("bonding completed successfully with %s", node) self.update_routing_table(node)