diff --git a/ddht/v5_1/abc.py b/ddht/v5_1/abc.py index d7ffbf3c..9d113ba6 100644 --- a/ddht/v5_1/abc.py +++ b/ddht/v5_1/abc.py @@ -220,6 +220,7 @@ def subscribe( ) -> AsyncContextManager[trio.abc.ReceiveChannel[InboundMessage[TMessage]]]: ... + @abstractmethod def subscribe_request( self, request: AnyOutboundMessage, response_payload_type: Type[TMessage], ) -> AsyncContextManager[ @@ -253,11 +254,13 @@ async def send_ping( # # Message Sending API # + @abstractmethod async def send_pong( self, endpoint: Endpoint, node_id: NodeID, *, request_id: int, ) -> None: ... + @abstractmethod async def send_find_nodes( self, endpoint: Endpoint, @@ -268,6 +271,7 @@ async def send_find_nodes( ) -> int: ... + @abstractmethod async def send_found_nodes( self, endpoint: Endpoint, @@ -278,22 +282,25 @@ async def send_found_nodes( ) -> int: ... + @abstractmethod async def send_talk_request( self, endpoint: Endpoint, node_id: NodeID, *, protocol: bytes, - request: bytes, + payload: bytes, request_id: Optional[int] = None, ) -> int: ... + @abstractmethod async def send_talk_response( - self, endpoint: Endpoint, node_id: NodeID, *, response: bytes, request_id: int, + self, endpoint: Endpoint, node_id: NodeID, *, payload: bytes, request_id: int, ) -> None: ... + @abstractmethod async def send_register_topic( self, endpoint: Endpoint, @@ -306,6 +313,7 @@ async def send_register_topic( ) -> int: ... + @abstractmethod async def send_ticket( self, endpoint: Endpoint, @@ -317,11 +325,13 @@ async def send_ticket( ) -> None: ... + @abstractmethod async def send_registration_confirmation( self, endpoint: Endpoint, node_id: NodeID, *, topic: bytes, request_id: int, ) -> None: ... + @abstractmethod async def send_topic_query( self, endpoint: Endpoint, @@ -335,21 +345,25 @@ async def send_topic_query( # # Request/Response API # + @abstractmethod async def ping( self, endpoint: Endpoint, node_id: NodeID ) -> InboundMessage[PongMessage]: ... + @abstractmethod async def find_nodes( self, endpoint: Endpoint, node_id: NodeID, distances: Collection[int] ) -> Tuple[InboundMessage[FoundNodesMessage], ...]: ... - async def talk_request( - self, endpoint: Endpoint, node_id: NodeID, protocol: bytes, request: bytes + @abstractmethod + async def talk( + self, endpoint: Endpoint, node_id: NodeID, protocol: bytes, payload: bytes ) -> InboundMessage[TalkResponseMessage]: ... + @abstractmethod async def register_topic( self, endpoint: Endpoint, @@ -362,6 +376,7 @@ async def register_topic( ]: ... + @abstractmethod async def topic_query( self, endpoint: Endpoint, node_id: NodeID, topic: bytes ) -> InboundMessage[FoundNodesMessage]: @@ -408,25 +423,41 @@ def enr_db(self) -> ENRDatabaseAPI: # # High Level API # + @abstractmethod async def bond( self, node_id: NodeID, *, endpoint: Optional[Endpoint] = None ) -> bool: ... + @abstractmethod async def ping( self, node_id: NodeID, *, endpoint: Optional[Endpoint] = None ) -> PongMessage: ... + @abstractmethod async def find_nodes( self, node_id: NodeID, *distances: int, endpoint: Optional[Endpoint] = None, ) -> Tuple[ENRAPI, ...]: ... + @abstractmethod + async def talk( + self, + node_id: NodeID, + *, + protocol: bytes, + payload: bytes, + endpoint: Optional[Endpoint] = None, + ) -> bytes: + ... + + @abstractmethod async def get_enr( self, node_id: NodeID, *, endpoint: Optional[Endpoint] = None ) -> ENRAPI: ... + @abstractmethod async def recursive_find_nodes(self, target: NodeID) -> Tuple[ENRAPI, ...]: ... diff --git a/ddht/v5_1/client.py b/ddht/v5_1/client.py index 694fdd81..16f59b45 100644 --- a/ddht/v5_1/client.py +++ b/ddht/v5_1/client.py @@ -22,7 +22,7 @@ from ddht.kademlia import compute_log_distance from ddht.message_registry import MessageTypeRegistry from ddht.v5_1.abc import ClientAPI, EventsAPI -from ddht.v5_1.constants import FOUND_NODES_MAX_PAYLOAD_SIZE, REQUEST_RESPONSE_TIMEOUT +from ddht.v5_1.constants import FOUND_NODES_MAX_PAYLOAD_SIZE from ddht.v5_1.dispatcher import Dispatcher from ddht.v5_1.envelope import ( EnvelopeDecoder, @@ -258,12 +258,12 @@ async def send_talk_request( node_id: NodeID, *, protocol: bytes, - request: bytes, + payload: bytes, request_id: Optional[int] = None, ) -> int: with self._get_request_id(node_id, request_id) as message_request_id: message = AnyOutboundMessage( - TalkRequestMessage(message_request_id, protocol, request), + TalkRequestMessage(message_request_id, protocol, payload), endpoint, node_id, ) @@ -272,10 +272,10 @@ async def send_talk_request( return message_request_id async def send_talk_response( - self, endpoint: Endpoint, node_id: NodeID, *, response: bytes, request_id: int, + self, endpoint: Endpoint, node_id: NodeID, *, payload: bytes, request_id: int, ) -> None: message = AnyOutboundMessage( - TalkResponseMessage(request_id, response,), endpoint, node_id, + TalkResponseMessage(request_id, payload), endpoint, node_id, ) await self.dispatcher.send_message(message) @@ -351,8 +351,7 @@ async def ping( async with self.dispatcher.subscribe_request( request, PongMessage ) as subscription: - with trio.fail_after(REQUEST_RESPONSE_TIMEOUT): - return await subscription.receive() + return await subscription.receive() async def find_nodes( self, endpoint: Endpoint, node_id: NodeID, distances: Collection[int], @@ -364,22 +363,21 @@ async def find_nodes( async with self.dispatcher.subscribe_request( request, FoundNodesMessage ) as subscription: - with trio.fail_after(REQUEST_RESPONSE_TIMEOUT): - head_response = await subscription.receive() - total = head_response.message.total - responses: Tuple[InboundMessage[FoundNodesMessage], ...] - if total == 1: - responses = (head_response,) - elif total > 1: - tail_responses: List[InboundMessage[FoundNodesMessage]] = [] - for _ in range(total - 1): - tail_responses.append(await subscription.receive()) - responses = (head_response,) + tuple(tail_responses) - else: - # TODO: this code path needs to be excercised and - # probably replaced with some sort of - # `SessionTerminated` exception. - raise Exception("Invalid `total` counter in response") + head_response = await subscription.receive() + total = head_response.message.total + responses: Tuple[InboundMessage[FoundNodesMessage], ...] + if total == 1: + responses = (head_response,) + elif total > 1: + tail_responses: List[InboundMessage[FoundNodesMessage]] = [] + for _ in range(total - 1): + tail_responses.append(await subscription.receive()) + responses = (head_response,) + tuple(tail_responses) + else: + # TODO: this code path needs to be excercised and + # probably replaced with some sort of + # `SessionTerminated` exception. + raise Exception("Invalid `total` counter in response") # Validate that all responses are indeed at one of the # specified distances. @@ -400,9 +398,16 @@ async def find_nodes( return responses async def talk( - self, endpoint: Endpoint, node_id: NodeID, protocol: bytes, request: bytes + self, endpoint: Endpoint, node_id: NodeID, protocol: bytes, payload: bytes ) -> InboundMessage[TalkResponseMessage]: - raise NotImplementedError + with self._get_request_id(node_id) as request_id: + request = AnyOutboundMessage( + TalkRequestMessage(request_id, protocol, payload), endpoint, node_id, + ) + async with self.dispatcher.subscribe_request( + request, TalkResponseMessage + ) as subscription: + return await subscription.receive() async def register_topic( self, diff --git a/ddht/v5_1/messages.py b/ddht/v5_1/messages.py index 2c111190..b3f76d74 100644 --- a/ddht/v5_1/messages.py +++ b/ddht/v5_1/messages.py @@ -59,14 +59,16 @@ class FoundNodesMessage(BaseMessage): class TalkRequestMessage(BaseMessage): message_type = 5 - fields = (("request_id", big_endian_int), ("protocol", binary), ("request", binary)) + fields = (("request_id", big_endian_int), ("protocol", binary), ("payload", binary)) @v51_registry.register class TalkResponseMessage(BaseMessage): message_type = 6 - fields = (("request_id", big_endian_int), ("response", binary)) + payload: bytes + + fields = (("request_id", big_endian_int), ("payload", binary)) @v51_registry.register diff --git a/ddht/v5_1/network.py b/ddht/v5_1/network.py index b5046ef3..d2130863 100644 --- a/ddht/v5_1/network.py +++ b/ddht/v5_1/network.py @@ -77,7 +77,7 @@ async def bond( ) -> bool: try: pong = await self.ping(node_id, endpoint=endpoint) - except trio.TooSlowError: + except trio.EndOfChannel: self.logger.debug( "Bonding with %s timed out during ping", humanize_node_id(node_id) ) @@ -88,7 +88,7 @@ async def bond( except KeyError: try: enr = await self.get_enr(node_id, endpoint=endpoint) - except trio.TooSlowError: + except trio.EndOfChannel: self.logger.debug( "Bonding with %s timed out during ENR retrieval", humanize_node_id(node_id), @@ -98,7 +98,7 @@ async def bond( if pong.enr_seq > enr.sequence_number: try: enr = await self.get_enr(node_id, endpoint=endpoint) - except trio.TooSlowError: + except trio.EndOfChannel: self.logger.debug( "Bonding with %s timed out during ENR retrieval", humanize_node_id(node_id), @@ -133,6 +133,19 @@ async def find_nodes( responses = await self.client.find_nodes(endpoint, node_id, distances=distances) return tuple(enr for response in responses for enr in response.message.enrs) + async def talk( + self, + node_id: NodeID, + *, + protocol: bytes, + payload: bytes, + endpoint: Optional[Endpoint] = None, + ) -> bytes: + if endpoint is None: + endpoint = self._endpoint_for_node_id(node_id) + response = await self.client.talk(endpoint, node_id, protocol, payload) + return response.message.payload + async def get_enr( self, node_id: NodeID, *, endpoint: Optional[Endpoint] = None ) -> ENRAPI: @@ -156,7 +169,7 @@ async def do_lookup(node_id: NodeID) -> None: distance = compute_log_distance(node_id, target) try: enrs = await self.find_nodes(node_id, distance) - except trio.TooSlowError: + except trio.EndOfChannel: unresponsive_node_ids.add(node_id) return diff --git a/tests/core/v5_1/test_client.py b/tests/core/v5_1/test_client.py index 8415f8ad..56e7c39a 100644 --- a/tests/core/v5_1/test_client.py +++ b/tests/core/v5_1/test_client.py @@ -1,3 +1,4 @@ +from contextlib import AsyncExitStack import itertools from eth_enr.tools.factories import ENRFactory @@ -8,6 +9,7 @@ from ddht.datagram import OutboundDatagram from ddht.kademlia import KademliaRoutingTable +from ddht.v5_1.messages import TalkRequestMessage @pytest.fixture @@ -92,7 +94,7 @@ async def test_client_send_talk_request(alice, bob, alice_client, bob_client): bob.endpoint, bob.node_id, protocol=b"test", - request=b"test-request", + payload=b"test-request", ) @@ -104,7 +106,7 @@ async def test_client_send_talk_response(alice, bob, alice_client, bob_client): await alice_client.send_talk_response( bob.endpoint, bob.node_id, - response=b"test-response", + payload=b"test-response", request_id=1234, ) @@ -189,6 +191,13 @@ async def _send_response(): assert pong.message.packet_port == alice.endpoint.port +@pytest.mark.trio +async def test_client_ping_timeout(alice, bob_client, autojump_clock): + with trio.fail_after(60): + with pytest.raises(trio.EndOfChannel): + await bob_client.ping(alice.endpoint, alice.node_id) + + @pytest.mark.trio async def test_client_request_response_find_nodes_found_nodes( alice, bob, alice_client, bob_client @@ -237,6 +246,55 @@ async def _send_response(): assert len(checked_bucket_indexes) > 4 +@pytest.mark.trio +async def test_client_talk_request_response(alice, bob, alice_client, bob_client): + async def _do_talk_response(client): + async with client.dispatcher.subscribe(TalkRequestMessage) as subscription: + request = await subscription.receive() + await client.send_talk_response( + request.sender_endpoint, + request.sender_node_id, + payload=b"talk-response", + request_id=request.message.request_id, + ) + + with trio.fail_after(2): + async with AsyncExitStack() as stack: + await stack.enter_async_context( + alice.events.talk_request_sent.subscribe_and_wait() + ) + await stack.enter_async_context( + alice.events.talk_response_received.subscribe_and_wait() + ) + await stack.enter_async_context( + bob.events.talk_request_received.subscribe_and_wait() + ) + await stack.enter_async_context( + bob.events.talk_response_sent.subscribe_and_wait() + ) + async with trio.open_nursery() as nursery: + nursery.start_soon(_do_talk_response, bob_client) + response = await alice_client.talk( + bob.endpoint, + bob.node_id, + protocol=b"test-talk-proto", + payload=b"test-request", + ) + assert response.message.payload == b"talk-response" + + +@pytest.mark.trio +async def test_client_talk_request_response_timeout(alice, bob_client, autojump_clock): + with trio.fail_after(60): + with pytest.raises(trio.EndOfChannel): + await bob_client.talk( + alice.endpoint, + alice.node_id, + protocol=b"test", + payload=b"test-request", + ) + + @given(datagram_bytes=st.binary(max_size=1024)) @pytest.mark.trio async def test_client_handles_malformed_datagrams(tester, datagram_bytes): diff --git a/tests/core/v5_1/test_network.py b/tests/core/v5_1/test_network.py index 39370ace..75ad9b72 100644 --- a/tests/core/v5_1/test_network.py +++ b/tests/core/v5_1/test_network.py @@ -8,7 +8,7 @@ from ddht.kademlia import compute_log_distance from ddht.v5_1.constants import ROUTING_TABLE_KEEP_ALIVE -from ddht.v5_1.messages import FoundNodesMessage +from ddht.v5_1.messages import FoundNodesMessage, TalkRequestMessage @pytest.mark.trio @@ -207,3 +207,28 @@ async def test_network_pings_oldest_routing_table(tester, alice, bob, autojump_c assert not alice_network.routing_table._contains(bob.node_id, False) assert not alice_network.routing_table._contains(carol.node_id, False) assert not alice_network.routing_table._contains(dylan.node_id, False) + + +@pytest.mark.trio +async def test_network_talk_api(alice, bob): + async def _do_talk_response(network): + async with network.dispatcher.subscribe(TalkRequestMessage) as subscription: + request = await subscription.receive() + await network.client.send_talk_response( + request.sender_endpoint, + request.sender_node_id, + payload=b"test-response-payload", + request_id=request.message.request_id, + ) + + async with alice.network() as alice_network: + async with bob.network() as bob_network: + async with trio.open_nursery() as nursery: + nursery.start_soon(_do_talk_response, bob_network) + + with trio.fail_after(2): + response = await alice_network.talk( + bob.node_id, protocol=b"test", payload=b"test-payload", + ) + + assert response == b"test-response-payload"