diff --git a/changes/1383.feature.md b/changes/1383.feature.md new file mode 100644 index 0000000000..a62fa1fb6d --- /dev/null +++ b/changes/1383.feature.md @@ -0,0 +1 @@ +Support yielding in interaction listeners. diff --git a/hikari/api/interaction_server.py b/hikari/api/interaction_server.py index 4bf328e420..88c97cfa2d 100644 --- a/hikari/api/interaction_server.py +++ b/hikari/api/interaction_server.py @@ -48,7 +48,10 @@ ] -ListenerT = typing.Callable[["_InteractionT_co"], typing.Awaitable["_ResponseT_co"]] +ListenerT = typing.Union[ + typing.Callable[["_InteractionT_co"], typing.Awaitable["_ResponseT_co"]], + typing.Callable[["_InteractionT_co"], typing.AsyncGenerator["_ResponseT_co", None]], +] """Type hint of a Interaction server's listener callback. This should be an async callback which takes in one positional argument which @@ -255,8 +258,12 @@ def set_listener( interaction_type : typing.Type[hikari.interactions.base_interactions.PartialInteraction] The type of interaction this listener should be registered for. listener : typing.Optional[ListenerT[hikari.interactions.base_interactions.PartialInteraction, hikari.api.special_endpoints.InteractionResponseBuilder]] - The asynchronous listener callback to set or `None` to - unset the previous listener. + The asynchronous listener callback to set or `None` to unset the previous listener. + + An asynchronous listener can be either a normal coroutine or an + async generator which should yield exactly once. This allows + sending an initial response to the request, while still + later executing further logic. Other Parameters ---------------- diff --git a/hikari/impl/interaction_server.py b/hikari/impl/interaction_server.py index 439a876294..364aab04b2 100644 --- a/hikari/impl/interaction_server.py +++ b/hikari/impl/interaction_server.py @@ -26,6 +26,7 @@ __all__: typing.Sequence[str] = ("InteractionServer",) import asyncio +import inspect import logging import typing @@ -166,6 +167,22 @@ async def write(self, writer: aiohttp.abc.AbstractStreamWriter) -> None: await writer.write(chunk) +async def _consume_generator_listener(generator: typing.AsyncGenerator[typing.Any, None]) -> None: + try: + await generator.__anext__() + + # We expect only one yield! + await generator.athrow(RuntimeError("Generator listener yielded more than once, expected only one yield")) + + except StopAsyncIteration: + pass + + except Exception as exc: + asyncio.get_running_loop().call_exception_handler( + {"message": "Exception occurred during interaction post dispatch", "exception": exc} + ) + + class InteractionServer(interaction_server.InteractionServer): """Standard implementation of `hikari.api.interaction_server.InteractionServer`. @@ -201,6 +218,7 @@ class InteractionServer(interaction_server.InteractionServer): "_public_key", "_rest_client", "_server", + "_running_generator_listeners", ) def __init__( @@ -237,6 +255,7 @@ def __init__( self._rest_client = rest_client self._server: typing.Optional[aiohttp.web_runner.AppRunner] = None self._public_key = nacl.signing.VerifyKey(public_key) if public_key is not None else None + self._running_generator_listeners: typing.List[asyncio.Task[None]] = [] @property def is_alive(self) -> bool: @@ -365,6 +384,11 @@ async def close(self) -> None: await self._server.cleanup() self._server = None self._application_fetch_lock = None + + # Wait for handlers to complete + await asyncio.gather(*self._running_generator_listeners) + self._running_generator_listeners = [] + self._close_event.set() self._close_event = None self._is_closing = False @@ -440,7 +464,17 @@ async def on_interaction(self, body: bytes, signature: bytes, timestamp: bytes) if listener := self._listeners.get(type(interaction)): _LOGGER.debug("Dispatching interaction %s", interaction.id) try: - result = await listener(interaction) + call = listener(interaction) + + if inspect.isasyncgen(call): + result = await call.__anext__() + task = asyncio.create_task(_consume_generator_listener(call)) + task.add_done_callback(self._running_generator_listeners.remove) + self._running_generator_listeners.append(task) + + else: + result = await call + raw_payload, files = result.build(self._entity_factory) payload = self._dumps(raw_payload) diff --git a/hikari/interactions/command_interactions.py b/hikari/interactions/command_interactions.py index a75d3613b5..398a678eb5 100644 --- a/hikari/interactions/command_interactions.py +++ b/hikari/interactions/command_interactions.py @@ -422,6 +422,15 @@ def build_deferred_response(self) -> special_endpoints.InteractionDeferredBuilde the result of this call can be returned as is without any modifications being made to it. + Examples + -------- + .. code-block:: python + + async def handle_command_interaction(interaction: CommandInteraction) -> InteractionMessageBuilder: + yield interaction.build_deferred_response() + + await interaction.edit_initial_response("Pong!") + Returns ------- hikari.api.special_endpoints.InteractionMessageBuilder diff --git a/tests/hikari/impl/test_interaction_server.py b/tests/hikari/impl/test_interaction_server.py index 8f995fbf91..febf3fae5b 100644 --- a/tests/hikari/impl/test_interaction_server.py +++ b/tests/hikari/impl/test_interaction_server.py @@ -74,6 +74,75 @@ async def write_headers(self, status_line: str, headers: "multidict.CIMultiDict[ pass +@pytest.mark.asyncio() +class TestConsumeGeneratorListener: + async def test_normal_behaviour(self): + async def mock_generator_listener(): + nonlocal g_continued + + yield + + g_continued = True + + g_continued = False + generator = mock_generator_listener() + # The function expects the generator to have already yielded once + await generator.__anext__() + + await interaction_server_impl._consume_generator_listener(generator) + + assert g_continued is True + + async def test_when_more_than_one_yield(self): + async def mock_generator_listener(): + nonlocal g_continued + + yield + + g_continued = True + + yield + + g_continued = False + generator = mock_generator_listener() + # The function expects the generator to have already yielded once + await generator.__anext__() + + loop = mock.Mock() + with mock.patch.object(asyncio, "get_running_loop", return_value=loop): + await interaction_server_impl._consume_generator_listener(generator) + + assert g_continued is True + args, _ = loop.call_exception_handler.call_args_list[0] + exception = args[0]["exception"] + assert isinstance(exception, RuntimeError) + assert exception.args == ("Generator listener yielded more than once, expected only one yield",) + + async def test_when_exception(self): + async def mock_generator_listener(): + nonlocal g_continued, exception + + yield + + g_continued = True + + raise exception + + g_continued = False + exception = ValueError("Some random exception") + generator = mock_generator_listener() + # The function expects the generator to have already yielded once + await generator.__anext__() + + loop = mock.Mock() + with mock.patch.object(asyncio, "get_running_loop", return_value=loop): + await interaction_server_impl._consume_generator_listener(generator) + + assert g_continued is True + args, _ = loop.call_exception_handler.call_args_list[0] + assert args[0]["exception"] is exception + + @pytest.fixture() def valid_edd25519(): body = ( @@ -521,13 +590,31 @@ async def test_close(self, mock_interaction_server: interaction_server_impl.Inte mock_interaction_server._is_closing = False mock_interaction_server._server = mock_runner mock_interaction_server._close_event = mock_event - - await mock_interaction_server.close() + generator_listener_1 = mock.Mock() + generator_listener_2 = mock.Mock() + generator_listener_3 = mock.Mock() + generator_listener_4 = mock.Mock() + mock_interaction_server._running_generator_listeners = [ + generator_listener_1, + generator_listener_2, + generator_listener_3, + generator_listener_4, + ] + + with mock.patch.object(asyncio, "gather", new=mock.AsyncMock()) as gather: + await mock_interaction_server.close() mock_runner.shutdown.assert_awaited_once() mock_runner.cleanup.assert_awaited_once() mock_event.set.assert_called_once() assert mock_interaction_server._is_closing is False + assert mock_interaction_server._running_generator_listeners == [] + gather.assert_awaited_once_with( + generator_listener_1, + generator_listener_2, + generator_listener_3, + generator_listener_4, + ) @pytest.mark.asyncio() async def test_close_when_closing(self, mock_interaction_server: interaction_server_impl.InteractionServer): @@ -537,6 +624,8 @@ async def test_close_when_closing(self, mock_interaction_server: interaction_ser mock_interaction_server._close_event = mock_event mock_interaction_server._is_closing = True mock_interaction_server.join = mock.AsyncMock() + mock_listener = object() + mock_interaction_server._running_generator_listeners = [mock_listener] await mock_interaction_server.close() @@ -544,6 +633,7 @@ async def test_close_when_closing(self, mock_interaction_server: interaction_ser mock_runner.cleanup.assert_not_called() mock_event.set.assert_not_called() mock_interaction_server.join.assert_awaited_once() + assert mock_interaction_server._running_generator_listeners == [mock_listener] @pytest.mark.asyncio() async def test_close_when_not_running(self, mock_interaction_server: interaction_server_impl.InteractionServer): @@ -596,6 +686,56 @@ async def test_on_interaction( assert result.payload == b'{"ok": "No boomer"}' assert result.status_code == 200 + @pytest.mark.asyncio() + async def test_on_interaction_with_generator_listener( + self, + mock_interaction_server: interaction_server_impl.InteractionServer, + mock_entity_factory: entity_factory_impl.EntityFactoryImpl, + public_key: bytes, + valid_edd25519: bytes, + valid_payload: bytes, + ): + async def mock_generator_listener(event): + nonlocal g_called, g_complete + + g_called = True + assert event is mock_entity_factory.deserialize_interaction.return_value + + yield mock_builder + + g_complete = True + + mock_interaction_server._public_key = nacl.signing.VerifyKey(public_key) + mock_file_1 = mock.Mock() + mock_file_2 = mock.Mock() + mock_entity_factory.deserialize_interaction.return_value = base_interactions.PartialInteraction( + app=None, id=123, application_id=541324, type=2, token="ok", version=1 + ) + mock_builder = mock.Mock(build=mock.Mock(return_value=({"ok": "No boomer"}, [mock_file_1, mock_file_2]))) + g_called = False + g_complete = False + mock_interaction_server.set_listener(base_interactions.PartialInteraction, mock_generator_listener) + + result = await mock_interaction_server.on_interaction(*valid_edd25519) + + mock_builder.build.assert_called_once_with(mock_entity_factory) + mock_entity_factory.deserialize_interaction.assert_called_once_with(valid_payload) + assert result.content_type == "application/json" + assert result.charset == "UTF-8" + assert result.files == [mock_file_1, mock_file_2] + assert result.headers is None + assert result.payload == b'{"ok": "No boomer"}' + assert result.status_code == 200 + + assert g_called is True + assert g_complete is False + assert len(mock_interaction_server._running_generator_listeners) != 0 + # Give some time for the task to complete + await asyncio.sleep(hikari_test_helpers.REASONABLE_QUICK_RESPONSE_TIME) + + assert g_complete is True + assert len(mock_interaction_server._running_generator_listeners) == 0 + @pytest.mark.asyncio() async def test_on_interaction_calls__fetch_public_key( self, mock_interaction_server: interaction_server_impl.InteractionServer