Skip to content

Commit

Permalink
Merge pull request #96 from pipermerriam/piper/fill-out-talk-apis
Browse files Browse the repository at this point in the history
Fully implement TALK apis
  • Loading branch information
pipermerriam committed Sep 30, 2020
2 parents 64485c3 + f25c829 commit b181822
Show file tree
Hide file tree
Showing 6 changed files with 172 additions and 38 deletions.
39 changes: 35 additions & 4 deletions ddht/v5_1/abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,7 @@ def subscribe(
) -> AsyncContextManager[trio.abc.ReceiveChannel[InboundMessage[TMessage]]]:
...

@abstractmethod
def subscribe_request(
self, request: AnyOutboundMessage, response_payload_type: Type[TMessage],
) -> AsyncContextManager[
Expand Down Expand Up @@ -253,11 +254,13 @@ async def send_ping(
#
# Message Sending API
#
@abstractmethod
async def send_pong(
self, endpoint: Endpoint, node_id: NodeID, *, request_id: int,
) -> None:
...

@abstractmethod
async def send_find_nodes(
self,
endpoint: Endpoint,
Expand All @@ -268,6 +271,7 @@ async def send_find_nodes(
) -> int:
...

@abstractmethod
async def send_found_nodes(
self,
endpoint: Endpoint,
Expand All @@ -278,22 +282,25 @@ async def send_found_nodes(
) -> int:
...

@abstractmethod
async def send_talk_request(
self,
endpoint: Endpoint,
node_id: NodeID,
*,
protocol: bytes,
request: bytes,
payload: bytes,
request_id: Optional[int] = None,
) -> int:
...

@abstractmethod
async def send_talk_response(
self, endpoint: Endpoint, node_id: NodeID, *, response: bytes, request_id: int,
self, endpoint: Endpoint, node_id: NodeID, *, payload: bytes, request_id: int,
) -> None:
...

@abstractmethod
async def send_register_topic(
self,
endpoint: Endpoint,
Expand All @@ -306,6 +313,7 @@ async def send_register_topic(
) -> int:
...

@abstractmethod
async def send_ticket(
self,
endpoint: Endpoint,
Expand All @@ -317,11 +325,13 @@ async def send_ticket(
) -> None:
...

@abstractmethod
async def send_registration_confirmation(
self, endpoint: Endpoint, node_id: NodeID, *, topic: bytes, request_id: int,
) -> None:
...

@abstractmethod
async def send_topic_query(
self,
endpoint: Endpoint,
Expand All @@ -335,21 +345,25 @@ async def send_topic_query(
#
# Request/Response API
#
@abstractmethod
async def ping(
self, endpoint: Endpoint, node_id: NodeID
) -> InboundMessage[PongMessage]:
...

@abstractmethod
async def find_nodes(
self, endpoint: Endpoint, node_id: NodeID, distances: Collection[int]
) -> Tuple[InboundMessage[FoundNodesMessage], ...]:
...

async def talk_request(
self, endpoint: Endpoint, node_id: NodeID, protocol: bytes, request: bytes
@abstractmethod
async def talk(
self, endpoint: Endpoint, node_id: NodeID, protocol: bytes, payload: bytes
) -> InboundMessage[TalkResponseMessage]:
...

@abstractmethod
async def register_topic(
self,
endpoint: Endpoint,
Expand All @@ -362,6 +376,7 @@ async def register_topic(
]:
...

@abstractmethod
async def topic_query(
self, endpoint: Endpoint, node_id: NodeID, topic: bytes
) -> InboundMessage[FoundNodesMessage]:
Expand Down Expand Up @@ -408,25 +423,41 @@ def enr_db(self) -> ENRDatabaseAPI:
#
# High Level API
#
@abstractmethod
async def bond(
self, node_id: NodeID, *, endpoint: Optional[Endpoint] = None
) -> bool:
...

@abstractmethod
async def ping(
self, node_id: NodeID, *, endpoint: Optional[Endpoint] = None
) -> PongMessage:
...

@abstractmethod
async def find_nodes(
self, node_id: NodeID, *distances: int, endpoint: Optional[Endpoint] = None,
) -> Tuple[ENRAPI, ...]:
...

@abstractmethod
async def talk(
self,
node_id: NodeID,
*,
protocol: bytes,
payload: bytes,
endpoint: Optional[Endpoint] = None,
) -> bytes:
...

@abstractmethod
async def get_enr(
self, node_id: NodeID, *, endpoint: Optional[Endpoint] = None
) -> ENRAPI:
...

@abstractmethod
async def recursive_find_nodes(self, target: NodeID) -> Tuple[ENRAPI, ...]:
...
55 changes: 30 additions & 25 deletions ddht/v5_1/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from ddht.kademlia import compute_log_distance
from ddht.message_registry import MessageTypeRegistry
from ddht.v5_1.abc import ClientAPI, EventsAPI
from ddht.v5_1.constants import FOUND_NODES_MAX_PAYLOAD_SIZE, REQUEST_RESPONSE_TIMEOUT
from ddht.v5_1.constants import FOUND_NODES_MAX_PAYLOAD_SIZE
from ddht.v5_1.dispatcher import Dispatcher
from ddht.v5_1.envelope import (
EnvelopeDecoder,
Expand Down Expand Up @@ -258,12 +258,12 @@ async def send_talk_request(
node_id: NodeID,
*,
protocol: bytes,
request: bytes,
payload: bytes,
request_id: Optional[int] = None,
) -> int:
with self._get_request_id(node_id, request_id) as message_request_id:
message = AnyOutboundMessage(
TalkRequestMessage(message_request_id, protocol, request),
TalkRequestMessage(message_request_id, protocol, payload),
endpoint,
node_id,
)
Expand All @@ -272,10 +272,10 @@ async def send_talk_request(
return message_request_id

async def send_talk_response(
self, endpoint: Endpoint, node_id: NodeID, *, response: bytes, request_id: int,
self, endpoint: Endpoint, node_id: NodeID, *, payload: bytes, request_id: int,
) -> None:
message = AnyOutboundMessage(
TalkResponseMessage(request_id, response,), endpoint, node_id,
TalkResponseMessage(request_id, payload), endpoint, node_id,
)
await self.dispatcher.send_message(message)

Expand Down Expand Up @@ -351,8 +351,7 @@ async def ping(
async with self.dispatcher.subscribe_request(
request, PongMessage
) as subscription:
with trio.fail_after(REQUEST_RESPONSE_TIMEOUT):
return await subscription.receive()
return await subscription.receive()

async def find_nodes(
self, endpoint: Endpoint, node_id: NodeID, distances: Collection[int],
Expand All @@ -364,22 +363,21 @@ async def find_nodes(
async with self.dispatcher.subscribe_request(
request, FoundNodesMessage
) as subscription:
with trio.fail_after(REQUEST_RESPONSE_TIMEOUT):
head_response = await subscription.receive()
total = head_response.message.total
responses: Tuple[InboundMessage[FoundNodesMessage], ...]
if total == 1:
responses = (head_response,)
elif total > 1:
tail_responses: List[InboundMessage[FoundNodesMessage]] = []
for _ in range(total - 1):
tail_responses.append(await subscription.receive())
responses = (head_response,) + tuple(tail_responses)
else:
# TODO: this code path needs to be excercised and
# probably replaced with some sort of
# `SessionTerminated` exception.
raise Exception("Invalid `total` counter in response")
head_response = await subscription.receive()
total = head_response.message.total
responses: Tuple[InboundMessage[FoundNodesMessage], ...]
if total == 1:
responses = (head_response,)
elif total > 1:
tail_responses: List[InboundMessage[FoundNodesMessage]] = []
for _ in range(total - 1):
tail_responses.append(await subscription.receive())
responses = (head_response,) + tuple(tail_responses)
else:
# TODO: this code path needs to be excercised and
# probably replaced with some sort of
# `SessionTerminated` exception.
raise Exception("Invalid `total` counter in response")

# Validate that all responses are indeed at one of the
# specified distances.
Expand All @@ -400,9 +398,16 @@ async def find_nodes(
return responses

async def talk(
self, endpoint: Endpoint, node_id: NodeID, protocol: bytes, request: bytes
self, endpoint: Endpoint, node_id: NodeID, protocol: bytes, payload: bytes
) -> InboundMessage[TalkResponseMessage]:
raise NotImplementedError
with self._get_request_id(node_id) as request_id:
request = AnyOutboundMessage(
TalkRequestMessage(request_id, protocol, payload), endpoint, node_id,
)
async with self.dispatcher.subscribe_request(
request, TalkResponseMessage
) as subscription:
return await subscription.receive()

async def register_topic(
self,
Expand Down
6 changes: 4 additions & 2 deletions ddht/v5_1/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,14 +59,16 @@ class FoundNodesMessage(BaseMessage):
class TalkRequestMessage(BaseMessage):
message_type = 5

fields = (("request_id", big_endian_int), ("protocol", binary), ("request", binary))
fields = (("request_id", big_endian_int), ("protocol", binary), ("payload", binary))


@v51_registry.register
class TalkResponseMessage(BaseMessage):
message_type = 6

fields = (("request_id", big_endian_int), ("response", binary))
payload: bytes

fields = (("request_id", big_endian_int), ("payload", binary))


@v51_registry.register
Expand Down
21 changes: 17 additions & 4 deletions ddht/v5_1/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ async def bond(
) -> bool:
try:
pong = await self.ping(node_id, endpoint=endpoint)
except trio.TooSlowError:
except trio.EndOfChannel:
self.logger.debug(
"Bonding with %s timed out during ping", humanize_node_id(node_id)
)
Expand All @@ -88,7 +88,7 @@ async def bond(
except KeyError:
try:
enr = await self.get_enr(node_id, endpoint=endpoint)
except trio.TooSlowError:
except trio.EndOfChannel:
self.logger.debug(
"Bonding with %s timed out during ENR retrieval",
humanize_node_id(node_id),
Expand All @@ -98,7 +98,7 @@ async def bond(
if pong.enr_seq > enr.sequence_number:
try:
enr = await self.get_enr(node_id, endpoint=endpoint)
except trio.TooSlowError:
except trio.EndOfChannel:
self.logger.debug(
"Bonding with %s timed out during ENR retrieval",
humanize_node_id(node_id),
Expand Down Expand Up @@ -133,6 +133,19 @@ async def find_nodes(
responses = await self.client.find_nodes(endpoint, node_id, distances=distances)
return tuple(enr for response in responses for enr in response.message.enrs)

async def talk(
self,
node_id: NodeID,
*,
protocol: bytes,
payload: bytes,
endpoint: Optional[Endpoint] = None,
) -> bytes:
if endpoint is None:
endpoint = self._endpoint_for_node_id(node_id)
response = await self.client.talk(endpoint, node_id, protocol, payload)
return response.message.payload

async def get_enr(
self, node_id: NodeID, *, endpoint: Optional[Endpoint] = None
) -> ENRAPI:
Expand All @@ -156,7 +169,7 @@ async def do_lookup(node_id: NodeID) -> None:
distance = compute_log_distance(node_id, target)
try:
enrs = await self.find_nodes(node_id, distance)
except trio.TooSlowError:
except trio.EndOfChannel:
unresponsive_node_ids.add(node_id)
return

Expand Down
Loading

0 comments on commit b181822

Please sign in to comment.