Skip to content

Commit

Permalink
bug fixes + tests
Browse files Browse the repository at this point in the history
  • Loading branch information
FasterSpeeding committed Jun 30, 2021
1 parent 4d42db9 commit 868b17e
Show file tree
Hide file tree
Showing 8 changed files with 333 additions and 124 deletions.
3 changes: 3 additions & 0 deletions hikari/api/entity_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,10 @@ class GatewayGuildDefinition(abc.ABC):
when the relevant resource isn't available in the inner payload.
"""

__slots__: typing.Sequence[str] = ()

@property
@abc.abstractmethod
def id(self) -> snowflakes.Snowflake:
"""ID of the guild the definition is for."""

Expand Down
4 changes: 3 additions & 1 deletion hikari/impl/bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,9 @@ def __init__(
self._event_factory = event_factory_impl.EventFactoryImpl(self)

# Event handling
self._event_manager = event_manager_impl.EventManagerImpl(self._event_factory, self._intents, cache=self._cache)
self._event_manager = event_manager_impl.EventManagerImpl(
self._entity_factory, self._event_factory, self._intents, cache=self._cache
)

# Voice subsystem
self._voice = voice_impl.VoiceComponentImpl(self)
Expand Down
13 changes: 8 additions & 5 deletions hikari/impl/event_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
from hikari import invites
from hikari import voices
from hikari.api import cache as cache_
from hikari.api import entity_factory as entity_factory_
from hikari.api import event_factory as event_factory_
from hikari.api import shard as gateway_shard
from hikari.internal import data_binding
Expand All @@ -74,18 +75,20 @@ def _fixed_size_nonce() -> str:
class EventManagerImpl(event_manager_base.EventManagerBase):
"""Provides event handling logic for Discord events."""

__slots__: typing.Sequence[str] = ("_cache",)
__slots__: typing.Sequence[str] = ("_cache", "_entity_factory")

def __init__(
self,
entity_factory: entity_factory_.EntityFactory,
event_factory: event_factory_.EventFactory,
intents: intents_.Intents,
/,
*,
cache: typing.Optional[cache_.MutableCache] = None,
) -> None:
self._cache = cache
super().__init__(event_factory=event_factory, intents=intents)
self._entity_factory = entity_factory
super().__init__(event_factory=event_factory, intents=intents, cache_settings=cache.settings if cache else None)

def _cache_enabled_for(self, components: config.CacheComponents, /) -> bool:
return self._cache is not None and (self._cache.settings.components & components) == components
Expand Down Expand Up @@ -156,7 +159,7 @@ async def on_guild_create(self, shard: gateway_shard.GatewayShard, payload: data
if not enabled_for_event and self._cache:
_LOGGER.log(ux.TRACE, "Skipping on_guild_create dispatch due to lack of any registered listeners")
event: typing.Optional[guild_events.GuildAvailableEvent] = None
gd = self._app.entity_factory.deserialize_gateway_guild(payload)
gd = self._entity_factory.deserialize_gateway_guild(payload)

channels = gd.channels() if self._cache_enabled_for(config.CacheComponents.GUILD_CHANNELS) else None
emojis = gd.emojis() if self._cache_enabled_for(config.CacheComponents.EMOJIS) else None
Expand Down Expand Up @@ -254,8 +257,8 @@ async def on_guild_update(self, shard: gateway_shard.GatewayShard, payload: data
if not enabled_for_event and self._cache:
_LOGGER.log(ux.TRACE, "Skipping on_guild_update raw dispatch due to lack of any registered listeners")
event: typing.Optional[guild_events.GuildUpdateEvent] = None
gd = self._app.entity_factory.deserialize_gateway_guild(payload)
emojis = gd.emojis() if self._cache_enabled_for(config.CacheComponents.GUILDS) else None
gd = self._entity_factory.deserialize_gateway_guild(payload)
emojis = gd.emojis() if self._cache_enabled_for(config.CacheComponents.EMOJIS) else None
guild = gd.guild() if self._cache_enabled_for(config.CacheComponents.GUILDS) else None
guild_id = gd.id
roles = gd.roles() if self._cache_enabled_for(config.CacheComponents.ROLES) else None
Expand Down
82 changes: 42 additions & 40 deletions hikari/impl/event_manager_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ def __cache_components__(self) -> config.CacheComponents:
def __event_types__(self) -> typing.Sequence[typing.Type[base_events.Event]]:
raise NotImplementedError


def _generate_weak_listener(
reference: weakref.WeakMethod,
) -> typing.Callable[[event_manager_.EventT], typing.Coroutine[typing.Any, typing.Any, None]]:
Expand Down Expand Up @@ -258,7 +259,7 @@ def _assert_is_listener(parameters: typing.Iterator[inspect.Parameter], /) -> No
if next(parameters, None) is None:
raise TypeError("Event listener must have one positional argument for the event object.")

if any(param.default is not inspect.Parameter.empty for param in parameters):
if not any(param.default is not inspect.Parameter.empty for param in parameters):
raise TypeError("Only the first argument for a listener can be required, the event argument.")


Expand Down Expand Up @@ -297,22 +298,16 @@ def decorator(method: UnboundMethodT[EventManagerBaseT], /) -> UnboundMethodT[Ev
return decorator


@attr.frozen()
@attr.define(hash=True)
class _Consumer:
callback: ConsumerT = attr.ib()
callback: ConsumerT = attr.ib(hash=True)
"""The callback function for this consumer."""

cache_components: undefined.UndefinedOr[config.CacheComponents] = attr.ib()
"""Bitfield of the cache components this consumer makes modifying calls to, if set."""

event_types: undefined.UndefinedOr[typing.Sequence[typing.Type[base_events.Event]]] = attr.ib()
event_types: undefined.UndefinedOr[typing.Sequence[typing.Type[base_events.Event]]] = attr.ib(hash=False)
"""A sequence of the types of events this consumer dispatches to, if set."""

def __attrs_post_init__(self) -> None:
# Letting only one be UNDEFINED just doesn't make sense as either being undefined leads to filtering being
# skipped all together making the other redundant.
if undefined.count(self.cache_components, self.event_types) == 1:
raise ValueError("cache_components and event_types must both either be undefined or defined")
is_caching: bool = attr.ib(hash=False)
"""Cached value of whether or not this consumer is making cache calls in the current env."""


class EventManagerBase(event_manager_.EventManager):
Expand All @@ -322,11 +317,25 @@ class EventManagerBase(event_manager_.EventManager):
is the raw event name being dispatched in lower-case.
"""

__slots__: typing.Sequence[str] = ("_event_factory", "_intents", "_consumers", "_enabled_consumers_cache", "_listeners", "_waiters")
__slots__: typing.Sequence[str] = (
"_consumers",
"_dispatches_for_cache",
"_enabled_consumers_cache",
"_event_factory",
"_intents",
"_listeners",
"_waiters",
)

def __init__(self, event_factory: event_factory_.EventFactory, intents: intents_.Intents) -> None:
self._dispatches_for_cache: typing.Dict[_Consumer, bool] = {}
def __init__(
self,
event_factory: event_factory_.EventFactory,
intents: intents_.Intents,
*,
cache_settings: typing.Optional[config.CacheSettings] = None,
) -> None:
self._consumers: typing.Dict[str, _Consumer] = {}
self._dispatches_for_cache: typing.Dict[_Consumer, bool] = {}
self._enabled_consumers_cache: typing.Dict[_Consumer, bool] = {}
self._event_factory = event_factory
self._intents = intents
Expand All @@ -337,10 +346,11 @@ def __init__(self, event_factory: event_factory_.EventFactory, intents: intents_
if name.startswith("on_"):
event_name = name[3:]
if isinstance(member, _FilteredMethodT):
self._consumers[event_name] = _Consumer(member, member.__cache_components__, member.__event_types__)
caching = bool(cache_settings and (member.__cache_components__ & cache_settings.components))
self._consumers[event_name] = _Consumer(member, member.__event_types__, caching)

else:
self._consumers[event_name] = _Consumer(member, undefined.UNDEFINED, undefined.UNDEFINED)
self._consumers[event_name] = _Consumer(member, undefined.UNDEFINED, bool(cache_settings))

def _clear_enabled_cache(self) -> None:
self._enabled_consumers_cache = {}
Expand All @@ -354,29 +364,21 @@ def _enabled_for_event(self, event_type: typing.Type[base_events.Event], /) -> b

def _enabled_for_consumer(self, consumer: _Consumer, /) -> bool:
# If undefined then we can only assume that this may link to registered listeners.
if consumer.event_types is undefined.UNDEFINED:
return True

# If event_types is not UNDEFINED then cache_components shouldn't ever be undefined.
assert consumer.cache_components is not undefined.UNDEFINED
if (cached_value := self._enabled_consumers_cache.get(consumer)) is True:
if consumer.event_types is undefined.UNDEFINED or consumer.is_caching:
return True

if cached_value is None:
# The behaviour here where an empty sequence for event_types will lead to this always
# being skipped unless there's a relevant enabled cache resource is intended behaviour.
for event_type in consumer.event_types:
if event_type in self._listeners or event_type in self._waiters:
self._enabled_consumers_cache[consumer] = True
return True
if (cached_value := self._enabled_consumers_cache.get(consumer)) is not None:
return cached_value

self._enabled_consumers_cache[consumer] = False
# The behaviour here where an empty sequence for event_types will lead to this always
# being skipped unless there's a relevant enabled cache resource is intended behaviour.
for event_type in consumer.event_types:
if event_type in self._listeners or event_type in self._waiters:
self._enabled_consumers_cache[consumer] = True
return True

# If cache_components is NONE then it doesn't make any altering state calls.
return (
consumer.cache_components != config.CacheComponents.NONE
and (consumer.cache_components & self._app.cache.settings.components) != 0
)
self._enabled_consumers_cache[consumer] = False
return False

def consume_raw_event(
self, event_name: str, shard: gateway_shard.GatewayShard, payload: data_binding.JSONObject
Expand Down Expand Up @@ -445,10 +447,10 @@ def get_listeners(
polymorphic: bool = True,
) -> typing.Collection[event_manager_.CallbackT[event_manager_.EventT_co]]:
if polymorphic:
listeners: typing.List[event_manager.CallbackT[event_manager.EventT_co]] = []
for cls in event_type.dispatches():
listeners.extend(self._listeners[cls])

listeners: typing.List[event_manager_.CallbackT[event_manager_.EventT_co]] = []
for subscribed_event_type, subscribed_listeners in self._listeners.items():
if issubclass(subscribed_event_type, event_type):
listeners += subscribed_listeners
return listeners

else:
Expand Down
4 changes: 3 additions & 1 deletion tests/hikari/impl/test_bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,9 @@ def test_init(self):
assert bot._cache is cache.return_value
cache.assert_called_once_with(bot, cache_settings)
assert bot._event_manager is event_manager.return_value
event_manager.assert_called_once_with(event_factory.return_value, intents, cache=cache.return_value)
event_manager.assert_called_once_with(
entity_factory.return_value, event_factory.return_value, intents, cache=cache.return_value
)
assert bot._entity_factory is entity_factory.return_value
entity_factory.assert_called_once_with(bot)
assert bot._event_factory is event_factory.return_value
Expand Down
18 changes: 18 additions & 0 deletions tests/hikari/impl/test_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -1478,6 +1478,13 @@ def test_set_me(self, cache_impl):
assert cache_impl._me == mock_own_user
assert cache_impl._me is not mock_own_user

def test_set_me_when_not_enabled(self, cache_impl):
cache_impl._settings.components = 0

cache_impl.set_me(object())

assert cache_impl._me is None

def test_update_me_for_cached_me(self, cache_impl):
mock_cached_own_user = mock.MagicMock(users.OwnUser)
mock_own_user = mock.MagicMock(users.OwnUser)
Expand All @@ -1496,6 +1503,17 @@ def test_update_me_for_uncached_me(self, cache_impl):
assert result == (None, mock_own_user)
assert cache_impl._me == mock_own_user

def test_update_me_for_when_not_enabled(self, cache_impl):
cache_impl._settings.components = 0
cache_impl.get_me = mock.Mock()
cache_impl.set_me = mock.Mock()

result = cache_impl.update_me(object())

assert result == (None, None)
cache_impl.get_me.assert_not_called()
cache_impl.set_me.assert_not_called()

def test__build_member(self, cache_impl):
mock_user = mock.MagicMock(users.User)
member_data = cache_utilities.MemberData(
Expand Down
Loading

0 comments on commit 868b17e

Please sign in to comment.