Skip to content

Commit

Permalink
Decode request reply (#161)
Browse files Browse the repository at this point in the history
  • Loading branch information
edgarriba authored Sep 26, 2023
1 parent 67a73dd commit 819c51f
Show file tree
Hide file tree
Showing 3 changed files with 125 additions and 5 deletions.
15 changes: 12 additions & 3 deletions py/farm_ng/core/event_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,16 +274,19 @@ async def request_reply(
path: str,
message: Message,
timestamps: list[Timestamp] | None = None,
) -> RequestReplyReply:
decode: bool = False,
) -> RequestReplyReply | Message:
"""Sends a request and waits for a reply.
Args:
path (str): the path of the request.
message (Message): the message to send.
timestamps (list[Timestamp], optional): the timestamps to add to the event.
Defaults to None.
decode (bool, optional): if True, the payload will be decoded. Defaults to False.
Returns:
ReqRepReply: the reply.
ReqRepReply: the request reply with the event and the payload or the decoded message.
"""
# try to connect to the server, if it fails return an emmpty response
if not await self._try_connect():
Expand Down Expand Up @@ -322,7 +325,13 @@ async def request_reply(
reply.event.timestamps.append(
get_monotonic_now(semantics=StampSemantics.CLIENT_RECEIVE),
)
return reply

# decode the payload if requested
reply_or_message: RequestReplyReply | Message = reply
if decode:
reply_or_message = payload_to_protobuf(reply.event, reply.payload)

return reply_or_message


async def test_subscribe(client: EventClient, uri: Uri):
Expand Down
34 changes: 33 additions & 1 deletion py/farm_ng/core/event_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import argparse
import asyncio
import inspect
import logging
import time
from dataclasses import dataclass
Expand Down Expand Up @@ -112,6 +113,9 @@ def __init__(
# the request/reply handler
self._request_reply_handler: Callable | None = None

# NOTE: experimental
self._decode_request_reply_handler_message: bool = False

# add the service to the asyncio server
event_service_pb2_grpc.add_EventServiceServicer_to_server(self, server)

Expand Down Expand Up @@ -145,6 +149,23 @@ def request_reply_handler(self, handler: Callable) -> None:
"""Sets the request/reply handler."""
self._request_reply_handler = handler

def add_request_reply_handler(self, handler: Callable) -> None:
"""Sets the request/reply handler."""

params = inspect.signature(handler).parameters

if len(params) not in (1, 2):
msg = "Request/reply handler must have one or two parameters"
raise ValueError(
msg,
)

if len(params) == 2: # noqa: PLR2004
self._decode_request_reply_handler_message = True

# is safe to set the handler
self._request_reply_handler = handler

@property
def uris(self) -> dict[str, Uri]:
"""Returns the URIs of the service."""
Expand Down Expand Up @@ -283,13 +304,24 @@ async def requestReply(
# adds the timestamps to the event as it passes through the service
recv_stamp = get_monotonic_now(StampSemantics.SERVICE_RECEIVE)
request.event.timestamps.append(recv_stamp)

# metadata to return with the reply
event = Event()
event.CopyFrom(request.event)
event.uri.path = "/request" + event.uri.path

reply_message: Message
if self._request_reply_handler is not None:
reply_message = await self._request_reply_handler(request)
# decode the requested message to satisfy the handler signature
if self._decode_request_reply_handler_message:
message = payload_to_protobuf(request.event, request.payload)
reply_message = await self._request_reply_handler(
request.event,
message,
)
else:
reply_message = await self._request_reply_handler(request)

if reply_message is None:
self.logger.error(
"Request invalid, please check your request channel and packet %s",
Expand Down
81 changes: 80 additions & 1 deletion py/tests/_asyncio/test_event_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,14 @@

import pytest
from farm_ng.core.event_client import EventClient
from farm_ng.core.event_pb2 import Event
from farm_ng.core.event_service import EventServiceGrpc
from google.protobuf.wrappers_pb2 import Int32Value
from farm_ng.core.event_service_pb2 import (
RequestReplyRequest,
)
from google.protobuf.empty_pb2 import Empty
from google.protobuf.message import Message
from google.protobuf.wrappers_pb2 import Int32Value, StringValue


class TestEventClient:
Expand Down Expand Up @@ -59,3 +65,76 @@ async def subscribe_callback(client: EventClient, queue: asyncio.Queue):
except asyncio.CancelledError:
pass
assert task.done()

@pytest.mark.anyio()
async def test_request_reply(
self,
event_service: EventServiceGrpc,
event_client: EventClient,
) -> None:
async def request_reply_handler(
request: RequestReplyRequest,
) -> Message:
if request.event.uri.path == "/get_foo":
return StringValue(value="foo")
if request.event.uri.path == "/get_bar":
return StringValue(value="bar")
if request.event.uri.path == "/await":
await asyncio.sleep(0.1)
return Empty()
return None

# reset the counts
event_service.reset()
event_service.request_reply_handler = request_reply_handler

# get decoded response
res = await event_client.request_reply("/get_foo", Empty(), decode=True)
assert res.value == "foo"

res = await event_client.request_reply("/get_bar", Empty(), decode=True)
assert res.value == "bar"

res = await event_client.request_reply("/await", Empty(), decode=True)
assert isinstance(res, Empty)

# get raw response
res = await event_client.request_reply("/get_foo", Empty(), decode=False)
assert res.event.uri.path == "/reply/request/get_foo"
assert "StringValue" in res.event.uri.query
assert res.payload == b"\n\x03foo"

@pytest.mark.anyio()
async def test_request_reply_callback(
self,
event_service: EventServiceGrpc,
event_client: EventClient,
) -> None:
async def request_reply_handler(
event: Event,
message: StringValue,
) -> StringValue:
if event.uri.path == "/foo":
return StringValue(value=f"{message.value} world !")
return Empty()

# reset the counts
event_service.reset()
event_service.add_request_reply_handler(request_reply_handler)

# get decoded response
res = await event_client.request_reply(
"/foo",
StringValue(value="hello"),
decode=True,
)
assert isinstance(res, StringValue)
assert res.value == "hello world !"

# empty response
res = await event_client.request_reply(
"/bar",
StringValue(value="hello"),
decode=True,
)
assert isinstance(res, Empty)

0 comments on commit 819c51f

Please sign in to comment.