diff --git a/hikari/api/event_dispatcher.py b/hikari/api/event_dispatcher.py index 45b8a65083..44c194f0a3 100644 --- a/hikari/api/event_dispatcher.py +++ b/hikari/api/event_dispatcher.py @@ -219,10 +219,10 @@ def get_listeners( The event type to look for. `T` must be a subclass of `hikari.events.base_events.Event`. polymorphic : builtins.bool - If `builtins.True`, this will return `builtins.True` if a subclass - of the given event type has a listener registered. If - `builtins.False`, then only listeners for this class specifically - are returned. The default is `builtins.True`. + If `builtins.True`, this will also return the listeners of the + subclasses of the given event type. If `builtins.False`, then + only listeners for this class specifically are returned. The + default is `builtins.True`. Returns ------- diff --git a/hikari/impl/event_manager_base.py b/hikari/impl/event_manager_base.py index 03a0207f50..0ed4d90c30 100644 --- a/hikari/impl/event_manager_base.py +++ b/hikari/impl/event_manager_base.py @@ -26,6 +26,7 @@ __all__: typing.Final[typing.List[str]] = ["EventManagerBase"] import asyncio +import inspect import logging import typing import warnings @@ -95,16 +96,16 @@ def subscribe( *, _nested: int = 0, ) -> event_dispatcher.AsyncCallbackT[event_dispatcher.EventT_co]: - if not issubclass(event_type, base_events.Event): + if not asyncio.iscoroutinefunction(callback): + raise TypeError("Event callbacks must be coroutine functions (`async def')") + + if not inspect.isclass(event_type) or not issubclass(event_type, base_events.Event): raise TypeError("Cannot subscribe to a non-Event type") # `_nested` is used to show the correct source code snippet if an intent # warning is triggered. self._check_intents(event_type, _nested) - if not asyncio.iscoroutinefunction(callback): - raise TypeError("Event callbacks must be coroutine functions (`async def')") - if event_type not in self._listeners: self._listeners[event_type] = [] @@ -144,6 +145,9 @@ def _check_intents(self, event_type: typing.Type[event_dispatcher.EventT_co], ne def get_listeners( self, event_type: typing.Type[event_dispatcher.EventT_co], *, polymorphic: bool = True, ) -> typing.Collection[event_dispatcher.AsyncCallbackT[event_dispatcher.EventT_co]]: + if not inspect.isclass(event_type) or not issubclass(event_type, base_events.Event): + raise TypeError(f"Can only get listeners for subclasses of {base_events.Event.__name__}") + if polymorphic: listeners: typing.List[event_dispatcher.AsyncCallbackT[event_dispatcher.EventT_co]] = [] for subscribed_event_type, subscribed_listeners in self._listeners.items(): @@ -250,9 +254,7 @@ async def _invoke_callback( self, callback: event_dispatcher.AsyncCallbackT[event_dispatcher.EventT_inv], event: event_dispatcher.EventT_inv ) -> None: try: - result = callback(event) - if asyncio.iscoroutine(result): - await result + await callback(event) except Exception as ex: # Skip the first frame in logs, we don't care for it. diff --git a/tests/hikari/impl/test_event_manager_base.py b/tests/hikari/impl/test_event_manager_base.py new file mode 100644 index 0000000000..e8ed021a08 --- /dev/null +++ b/tests/hikari/impl/test_event_manager_base.py @@ -0,0 +1,263 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2020 Nekokatt +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +import warnings + +import mock +import pytest + +from hikari import errors +from hikari import intents +from hikari.events import base_events +from hikari.events import member_events +from hikari.impl import event_manager_base + + +def test__default_predicate_returns_True(): + assert event_manager_base._default_predicate(None) is True + + +class TestEventManagerBase: + @pytest.fixture + def event_manager(self): + class EventManagerBaseImpl(event_manager_base.EventManagerBase): + on_existing_event = None + + return EventManagerBaseImpl(None, None) + + @pytest.mark.asyncio + async def test_consume_raw_event_when_AttributeError(self, event_manager): + with mock.patch.object(event_manager_base, "_LOGGER") as logger: + await event_manager.consume_raw_event(None, "UNEXISTING_EVENT", {}) + + logger.debug.assert_called_once_with("ignoring unknown event %s", "UNEXISTING_EVENT") + + @pytest.mark.asyncio + async def test_consume_raw_event_when_found(self, event_manager): + event_manager.on_existing_event = mock.AsyncMock() + shard = object() + + await event_manager.consume_raw_event(shard, "EXISTING_EVENT", {}) + + event_manager.on_existing_event.assert_awaited_once_with(shard, {}) + + def test_subscribe_when_callback_is_not_coroutine(self, event_manager): + def test(): + ... + + with pytest.raises(TypeError): + event_manager.subscribe(member_events.MemberCreateEvent, test) + + def test_subscribe_when_event_type_does_not_subclass_Event(self, event_manager): + async def test(): + ... + + with pytest.raises(TypeError): + event_manager.subscribe("test", test) + + def test_subscribe_when_event_type_not_in_listeners(self, event_manager): + async def test(): + ... + + with mock.patch.object(event_manager_base.EventManagerBase, "_check_intents") as check: + assert event_manager.subscribe(member_events.MemberCreateEvent, test, _nested=1) == test + + assert event_manager._listeners == {member_events.MemberCreateEvent: [test]} + check.assert_called_once_with(member_events.MemberCreateEvent, 1) + + def test_subscribe_when_event_type_in_listeners(self, event_manager): + async def test(): + ... + + async def test2(): + ... + + event_manager._listeners[member_events.MemberCreateEvent] = [test2] + + with mock.patch.object(event_manager_base.EventManagerBase, "_check_intents") as check: + assert event_manager.subscribe(member_events.MemberCreateEvent, test, _nested=2) == test + + assert event_manager._listeners == {member_events.MemberCreateEvent: [test2, test]} + check.assert_called_once_with(member_events.MemberCreateEvent, 2) + + def test__check_intents_when_intents_is_None(self, event_manager): + with mock.patch.object(base_events, "get_required_intents_for") as get_intents: + event_manager._check_intents(member_events.MemberCreateEvent, 0) + + get_intents.assert_not_called() + + def test__check_intents_when_no_intents_required(self, event_manager): + event_manager._intents = intents.Intents.ALL + + with mock.patch.object(base_events, "get_required_intents_for", return_value=None) as get_intents: + with mock.patch.object(warnings, "warn") as warn: + event_manager._check_intents(member_events.MemberCreateEvent, 0) + + get_intents.assert_called_once_with(member_events.MemberCreateEvent) + warn.assert_not_called() + + def test__check_intents_when_intents_correct(self, event_manager): + event_manager._intents = intents.Intents.GUILD_EMOJIS | intents.Intents.GUILD_MEMBERS + + with mock.patch.object( + base_events, "get_required_intents_for", return_value=intents.Intents.GUILD_MEMBERS + ) as get_intents: + with mock.patch.object(warnings, "warn") as warn: + event_manager._check_intents(member_events.MemberCreateEvent, 0) + + get_intents.assert_called_once_with(member_events.MemberCreateEvent) + warn.assert_not_called() + + def test__check_intents_when_intents_incorrect(self, event_manager): + event_manager._intents = intents.Intents.GUILD_EMOJIS + + with mock.patch.object( + base_events, "get_required_intents_for", return_value=intents.Intents.GUILD_MEMBERS + ) as get_intents: + with mock.patch.object(warnings, "warn") as warn: + event_manager._check_intents(member_events.MemberCreateEvent, 0) + + get_intents.assert_called_once_with(member_events.MemberCreateEvent) + warn.assert_called_once_with( + "You have tried to listen to MemberCreateEvent, but this will only ever be triggered if " + "you enable one of the following intents: GUILD_MEMBERS.", + category=errors.MissingIntentWarning, + stacklevel=3, + ) + + def test_get_listeners_when_not_event(self, event_manager): + with pytest.raises(TypeError): + event_manager.get_listeners("test") + + def test_get_listeners_polimorphic(self, event_manager): + event_manager._listeners = { + base_events.Event: ["this will never appear"], + member_events.MemberEvent: ["coroutine0"], + member_events.MemberCreateEvent: ["coroutine1", "coroutine2"], + member_events.MemberUpdateEvent: ["coroutine3"], + member_events.MemberDeleteEvent: ["coroutine4", "coroutine5"], + } + + assert event_manager.get_listeners(member_events.MemberEvent) == [ + "coroutine0", + "coroutine1", + "coroutine2", + "coroutine3", + "coroutine4", + "coroutine5", + ] + + def test_get_listeners_no_polimorphic_and_no_results(self, event_manager): + event_manager._listeners = { + member_events.MemberCreateEvent: ["coroutine1", "coroutine2"], + member_events.MemberUpdateEvent: ["coroutine3"], + member_events.MemberDeleteEvent: ["coroutine4", "coroutine5"], + member_events.MemberDeleteEvent: ["coroutine4", "coroutine5"], + } + + assert event_manager.get_listeners(member_events.MemberEvent, polymorphic=False) == [] + + def test_get_listeners_no_polimorphic_and_results(self, event_manager): + event_manager._listeners = { + member_events.MemberEvent: ["coroutine0"], + member_events.MemberCreateEvent: ["coroutine1", "coroutine2"], + member_events.MemberUpdateEvent: ["coroutine3"], + member_events.MemberDeleteEvent: ["coroutine4", "coroutine5"], + } + + assert event_manager.get_listeners(member_events.MemberEvent, polymorphic=False) == ["coroutine0"] + + def test_unsubscribe_when_event_type_not_in_listeners(self, event_manager): + async def test(): + ... + + event_manager._listeners = {} + + event_manager.unsubscribe(member_events.MemberCreateEvent, test) + + assert event_manager._listeners == {} + + def test_unsubscribe_when_event_type_when_list_not_empty_after_delete(self, event_manager): + async def test(): + ... + + async def test2(): + ... + + event_manager._listeners = { + member_events.MemberCreateEvent: [test, test2], + member_events.MemberDeleteEvent: [test], + } + + event_manager.unsubscribe(member_events.MemberCreateEvent, test) + + assert event_manager._listeners == { + member_events.MemberCreateEvent: [test2], + member_events.MemberDeleteEvent: [test], + } + + def test_unsubscribe_when_event_type_when_list_empty_after_delete(self, event_manager): + async def test(): + ... + + event_manager._listeners = {member_events.MemberCreateEvent: [test], member_events.MemberDeleteEvent: [test]} + + event_manager.unsubscribe(member_events.MemberCreateEvent, test) + + assert event_manager._listeners == {member_events.MemberDeleteEvent: [test]} + + def test_listen_when_no_params(self, event_manager): + with pytest.raises(TypeError): + + @event_manager.listen() + async def test(): + ... + + def test_listen_when_more_then_one_param(self, event_manager): + with pytest.raises(TypeError): + + @event_manager.listen() + async def test(a, b, c): + ... + + def test_listen_when_param_not_provided_in_decorator_nor_typehint(self, event_manager): + with pytest.raises(TypeError): + + @event_manager.listen() + async def test(event): + ... + + def test_listen_when_param_provided_in_decorator(self, event_manager): + with mock.patch.object(event_manager_base.EventManagerBase, "subscribe") as subscribe: + + @event_manager.listen(member_events.MemberCreateEvent) + async def test(event): + ... + + subscribe.assert_called_once_with(member_events.MemberCreateEvent, test, _nested=1) + + def test_listen_when_param_provided_in_typehint(self, event_manager): + with mock.patch.object(event_manager_base.EventManagerBase, "subscribe") as subscribe: + + @event_manager.listen() + async def test(event: member_events.MemberCreateEvent): + ... + + subscribe.assert_called_once_with(member_events.MemberCreateEvent, test, _nested=1)