Skip to content

Commit

Permalink
Support yielding in interaction handler (#1383)
Browse files Browse the repository at this point in the history
  • Loading branch information
davfsa authored Jan 1, 2023
1 parent fe83246 commit 0c49566
Show file tree
Hide file tree
Showing 5 changed files with 197 additions and 6 deletions.
1 change: 1 addition & 0 deletions changes/1383.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Support yielding in interaction listeners.
13 changes: 10 additions & 3 deletions hikari/api/interaction_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,10 @@
]


ListenerT = typing.Callable[["_InteractionT_co"], typing.Awaitable["_ResponseT_co"]]
ListenerT = typing.Union[
typing.Callable[["_InteractionT_co"], typing.Awaitable["_ResponseT_co"]],
typing.Callable[["_InteractionT_co"], typing.AsyncGenerator["_ResponseT_co", None]],
]
"""Type hint of a Interaction server's listener callback.
This should be an async callback which takes in one positional argument which
Expand Down Expand Up @@ -255,8 +258,12 @@ def set_listener(
interaction_type : typing.Type[hikari.interactions.base_interactions.PartialInteraction]
The type of interaction this listener should be registered for.
listener : typing.Optional[ListenerT[hikari.interactions.base_interactions.PartialInteraction, hikari.api.special_endpoints.InteractionResponseBuilder]]
The asynchronous listener callback to set or `None` to
unset the previous listener.
The asynchronous listener callback to set or `None` to unset the previous listener.
An asynchronous listener can be either a normal coroutine or an
async generator which should yield exactly once. This allows
sending an initial response to the request, while still
later executing further logic.
Other Parameters
----------------
Expand Down
36 changes: 35 additions & 1 deletion hikari/impl/interaction_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
__all__: typing.Sequence[str] = ("InteractionServer",)

import asyncio
import inspect
import logging
import typing

Expand Down Expand Up @@ -166,6 +167,22 @@ async def write(self, writer: aiohttp.abc.AbstractStreamWriter) -> None:
await writer.write(chunk)


async def _consume_generator_listener(generator: typing.AsyncGenerator[typing.Any, None]) -> None:
try:
await generator.__anext__()

# We expect only one yield!
await generator.athrow(RuntimeError("Generator listener yielded more than once, expected only one yield"))

except StopAsyncIteration:
pass

except Exception as exc:
asyncio.get_running_loop().call_exception_handler(
{"message": "Exception occurred during interaction post dispatch", "exception": exc}
)


class InteractionServer(interaction_server.InteractionServer):
"""Standard implementation of `hikari.api.interaction_server.InteractionServer`.
Expand Down Expand Up @@ -201,6 +218,7 @@ class InteractionServer(interaction_server.InteractionServer):
"_public_key",
"_rest_client",
"_server",
"_running_generator_listeners",
)

def __init__(
Expand Down Expand Up @@ -237,6 +255,7 @@ def __init__(
self._rest_client = rest_client
self._server: typing.Optional[aiohttp.web_runner.AppRunner] = None
self._public_key = nacl.signing.VerifyKey(public_key) if public_key is not None else None
self._running_generator_listeners: typing.List[asyncio.Task[None]] = []

@property
def is_alive(self) -> bool:
Expand Down Expand Up @@ -365,6 +384,11 @@ async def close(self) -> None:
await self._server.cleanup()
self._server = None
self._application_fetch_lock = None

# Wait for handlers to complete
await asyncio.gather(*self._running_generator_listeners)
self._running_generator_listeners = []

self._close_event.set()
self._close_event = None
self._is_closing = False
Expand Down Expand Up @@ -440,7 +464,17 @@ async def on_interaction(self, body: bytes, signature: bytes, timestamp: bytes)
if listener := self._listeners.get(type(interaction)):
_LOGGER.debug("Dispatching interaction %s", interaction.id)
try:
result = await listener(interaction)
call = listener(interaction)

if inspect.isasyncgen(call):
result = await call.__anext__()
task = asyncio.create_task(_consume_generator_listener(call))
task.add_done_callback(self._running_generator_listeners.remove)
self._running_generator_listeners.append(task)

else:
result = await call

raw_payload, files = result.build(self._entity_factory)
payload = self._dumps(raw_payload)

Expand Down
9 changes: 9 additions & 0 deletions hikari/interactions/command_interactions.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,15 @@ def build_deferred_response(self) -> special_endpoints.InteractionDeferredBuilde
the result of this call can be returned as is without any modifications
being made to it.
Examples
--------
.. code-block:: python
async def handle_command_interaction(interaction: CommandInteraction) -> InteractionMessageBuilder:
yield interaction.build_deferred_response()
await interaction.edit_initial_response("Pong!")
Returns
-------
hikari.api.special_endpoints.InteractionMessageBuilder
Expand Down
144 changes: 142 additions & 2 deletions tests/hikari/impl/test_interaction_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,75 @@ async def write_headers(self, status_line: str, headers: "multidict.CIMultiDict[
pass


@pytest.mark.asyncio()
class TestConsumeGeneratorListener:
async def test_normal_behaviour(self):
async def mock_generator_listener():
nonlocal g_continued

yield

g_continued = True

g_continued = False
generator = mock_generator_listener()
# The function expects the generator to have already yielded once
await generator.__anext__()

await interaction_server_impl._consume_generator_listener(generator)

assert g_continued is True

async def test_when_more_than_one_yield(self):
async def mock_generator_listener():
nonlocal g_continued

yield

g_continued = True

yield

g_continued = False
generator = mock_generator_listener()
# The function expects the generator to have already yielded once
await generator.__anext__()

loop = mock.Mock()
with mock.patch.object(asyncio, "get_running_loop", return_value=loop):
await interaction_server_impl._consume_generator_listener(generator)

assert g_continued is True
args, _ = loop.call_exception_handler.call_args_list[0]
exception = args[0]["exception"]
assert isinstance(exception, RuntimeError)
assert exception.args == ("Generator listener yielded more than once, expected only one yield",)

async def test_when_exception(self):
async def mock_generator_listener():
nonlocal g_continued, exception

yield

g_continued = True

raise exception

g_continued = False
exception = ValueError("Some random exception")
generator = mock_generator_listener()
# The function expects the generator to have already yielded once
await generator.__anext__()

loop = mock.Mock()
with mock.patch.object(asyncio, "get_running_loop", return_value=loop):
await interaction_server_impl._consume_generator_listener(generator)

assert g_continued is True
args, _ = loop.call_exception_handler.call_args_list[0]
assert args[0]["exception"] is exception


@pytest.fixture()
def valid_edd25519():
body = (
Expand Down Expand Up @@ -521,13 +590,31 @@ async def test_close(self, mock_interaction_server: interaction_server_impl.Inte
mock_interaction_server._is_closing = False
mock_interaction_server._server = mock_runner
mock_interaction_server._close_event = mock_event

await mock_interaction_server.close()
generator_listener_1 = mock.Mock()
generator_listener_2 = mock.Mock()
generator_listener_3 = mock.Mock()
generator_listener_4 = mock.Mock()
mock_interaction_server._running_generator_listeners = [
generator_listener_1,
generator_listener_2,
generator_listener_3,
generator_listener_4,
]

with mock.patch.object(asyncio, "gather", new=mock.AsyncMock()) as gather:
await mock_interaction_server.close()

mock_runner.shutdown.assert_awaited_once()
mock_runner.cleanup.assert_awaited_once()
mock_event.set.assert_called_once()
assert mock_interaction_server._is_closing is False
assert mock_interaction_server._running_generator_listeners == []
gather.assert_awaited_once_with(
generator_listener_1,
generator_listener_2,
generator_listener_3,
generator_listener_4,
)

@pytest.mark.asyncio()
async def test_close_when_closing(self, mock_interaction_server: interaction_server_impl.InteractionServer):
Expand All @@ -537,13 +624,16 @@ async def test_close_when_closing(self, mock_interaction_server: interaction_ser
mock_interaction_server._close_event = mock_event
mock_interaction_server._is_closing = True
mock_interaction_server.join = mock.AsyncMock()
mock_listener = object()
mock_interaction_server._running_generator_listeners = [mock_listener]

await mock_interaction_server.close()

mock_runner.shutdown.assert_not_called()
mock_runner.cleanup.assert_not_called()
mock_event.set.assert_not_called()
mock_interaction_server.join.assert_awaited_once()
assert mock_interaction_server._running_generator_listeners == [mock_listener]

@pytest.mark.asyncio()
async def test_close_when_not_running(self, mock_interaction_server: interaction_server_impl.InteractionServer):
Expand Down Expand Up @@ -596,6 +686,56 @@ async def test_on_interaction(
assert result.payload == b'{"ok": "No boomer"}'
assert result.status_code == 200

@pytest.mark.asyncio()
async def test_on_interaction_with_generator_listener(
self,
mock_interaction_server: interaction_server_impl.InteractionServer,
mock_entity_factory: entity_factory_impl.EntityFactoryImpl,
public_key: bytes,
valid_edd25519: bytes,
valid_payload: bytes,
):
async def mock_generator_listener(event):
nonlocal g_called, g_complete

g_called = True
assert event is mock_entity_factory.deserialize_interaction.return_value

yield mock_builder

g_complete = True

mock_interaction_server._public_key = nacl.signing.VerifyKey(public_key)
mock_file_1 = mock.Mock()
mock_file_2 = mock.Mock()
mock_entity_factory.deserialize_interaction.return_value = base_interactions.PartialInteraction(
app=None, id=123, application_id=541324, type=2, token="ok", version=1
)
mock_builder = mock.Mock(build=mock.Mock(return_value=({"ok": "No boomer"}, [mock_file_1, mock_file_2])))
g_called = False
g_complete = False
mock_interaction_server.set_listener(base_interactions.PartialInteraction, mock_generator_listener)

result = await mock_interaction_server.on_interaction(*valid_edd25519)

mock_builder.build.assert_called_once_with(mock_entity_factory)
mock_entity_factory.deserialize_interaction.assert_called_once_with(valid_payload)
assert result.content_type == "application/json"
assert result.charset == "UTF-8"
assert result.files == [mock_file_1, mock_file_2]
assert result.headers is None
assert result.payload == b'{"ok": "No boomer"}'
assert result.status_code == 200

assert g_called is True
assert g_complete is False
assert len(mock_interaction_server._running_generator_listeners) != 0
# Give some time for the task to complete
await asyncio.sleep(hikari_test_helpers.REASONABLE_QUICK_RESPONSE_TIME)

assert g_complete is True
assert len(mock_interaction_server._running_generator_listeners) == 0

@pytest.mark.asyncio()
async def test_on_interaction_calls__fetch_public_key(
self, mock_interaction_server: interaction_server_impl.InteractionServer
Expand Down

0 comments on commit 0c49566

Please sign in to comment.