From 4eaba591af5fa9ff7a55543a9696ebe17fec5c81 Mon Sep 17 00:00:00 2001 From: Lucina Date: Fri, 12 Nov 2021 16:35:19 +0000 Subject: [PATCH] Start filtering events before unmarshalling (#636) * Switch to iterating over mro during listener/waiter registration (rather than during event dispatch) * Add in a system to avoid unmarshalling data for events which aren't being used * Add settings property to cache interface to allow for introspection * Also add "ME" resource to cache config * Specialise guild create and update handlers to avoid unmarshalling data which isn't being cached when no listeners are registered for the event * For this to work gateway guild definition handling had to be refactored to switch to explicitly specifying which mappings it should include when calling it * Logic fixes around event checks * Register listeners by subclasses not parents (mro) (For this the subclasses need to be cached on the Event classes) * Add voodoo on new event cls callback to Event class * This is meant to be a mock way to handle the edge case of new subclassing Event types being added after the event manage has been initialised which might be unorthodox but probably has some wack use case * Switch over to mro based approach * Switch over to mro based approach * Cache whether a consumer can be dispatched or not * Slight logic cleanup * Prefer internal granularity on guild create and update methods * rename event_manager_base.as_listener to "filtered" and remove from on_guild_create and update * Also clear the dispatches for cache when waiters are depleted * Only deserialize guild object on guild create and update if necessary * Add check to shard payload dispatch and refactor consumer check logic * Internal refactors and naming scheme changes * Plus fix CacheImpl.update_me not copying the stored member entry before returning it * Add internal _FilteredMethod proto to event manager base * Move filtering to _handle_dispatch * Add internal _FilteredMethod proto to event manager base * Move filtering to _handle_dispatch * Add trace logging calls to on_guild_create and on_guild_update * Small logic fix + add code/logic comments and docs * As an artifact of this addition, on_guild_integrations_update acn raise NotImplementedError now since it should always be skipped * Some test fixes * cache_components shouldn't ever be undefined if event_types isn't * Try the builder pattern for GatewayGuildDefinition * Switch GatewayGuildDefinition to using getter style methods for delaying deserialization * test fixes and additions * bug fixes + tests * Post-rebase fixes * Have EventManagerBase take components rather than the cache settings * remove _dispatches_for_cache + add in missing filtered decorator calls * Post-rebase fix * post-rebase fixes * Change i forgot to commit * formatting fixes * Mypy and flake8 fixes --- hikari/api/cache.py | 6 + hikari/api/entity_factory.py | 76 +- hikari/config.py | 8 +- hikari/events/base_events.py | 22 + hikari/impl/bot.py | 4 +- hikari/impl/cache.py | 10 +- hikari/impl/entity_factory.py | 245 +++-- hikari/impl/event_factory.py | 38 +- hikari/impl/event_manager.py | 274 +++-- hikari/impl/event_manager_base.py | 212 +++- tests/hikari/impl/test_bot.py | 4 +- tests/hikari/impl/test_cache.py | 18 + tests/hikari/impl/test_entity_factory.py | 993 ++++++++++--------- tests/hikari/impl/test_event_factory.py | 33 +- tests/hikari/impl/test_event_manager.py | 282 +++++- tests/hikari/impl/test_event_manager_base.py | 219 +++- 16 files changed, 1696 insertions(+), 748 deletions(-) diff --git a/hikari/api/cache.py b/hikari/api/cache.py index 19e6c3337c..d200c8c683 100644 --- a/hikari/api/cache.py +++ b/hikari/api/cache.py @@ -32,6 +32,7 @@ if typing.TYPE_CHECKING: from hikari import channels + from hikari import config from hikari import emojis from hikari import guilds from hikari import invites @@ -88,6 +89,11 @@ class Cache(abc.ABC): __slots__: typing.Sequence[str] = () + @property + @abc.abstractmethod + def settings(self) -> config.CacheSettings: + """Get the configured settings for this cache.""" + @abc.abstractmethod def get_dm_channel_id( self, user: snowflakes.SnowflakeishOr[users.PartialUser], / diff --git a/hikari/api/entity_factory.py b/hikari/api/entity_factory.py index 91020f6373..02af7b508c 100644 --- a/hikari/api/entity_factory.py +++ b/hikari/api/entity_factory.py @@ -28,10 +28,7 @@ import abc import typing -import attr - from hikari import undefined -from hikari.internal import attr_extensions if typing.TYPE_CHECKING: from hikari import applications as application_models @@ -58,53 +55,56 @@ from hikari.internal import data_binding -@attr_extensions.with_copy -@attr.define(weakref_slot=False) -class GatewayGuildDefinition: - """A structure for handling entities within guild create and update events.""" - - guild: guild_models.GatewayGuild = attr.field() - """Object of the guild the definition is for.""" - - channels: typing.Optional[typing.Mapping[snowflakes.Snowflake, channel_models.GuildChannel]] = attr.field() - """Mapping of channel IDs to the channels that belong to the guild. +class GatewayGuildDefinition(abc.ABC): + """Structure for handling entities within guild create and update events. - Will be `builtins.None` when returned by guild update gateway events rather - than create. + !!! warning + The methods on this class may raise `builtins.LookupError` if called + when the relevant resource isn't available in the inner payload. """ - members: typing.Optional[typing.Mapping[snowflakes.Snowflake, guild_models.Member]] = attr.field() - """Mapping of user IDs to the members that belong to the guild. + __slots__: typing.Sequence[str] = () - Will be `builtins.None` when returned by guild update gateway events rather - than create. + @property + @abc.abstractmethod + def id(self) -> snowflakes.Snowflake: + """ID of the guild the definition is for.""" - !!! note - This may be a partial mapping of members in the guild. - """ + @abc.abstractmethod + def channels(self) -> typing.Mapping[snowflakes.Snowflake, channel_models.GuildChannel]: + """Get a mapping of channel IDs to the channels that belong to the guild.""" - presences: typing.Optional[typing.Mapping[snowflakes.Snowflake, presence_models.MemberPresence]] = attr.field() - """Mapping of user IDs to the presences that are active in the guild. + @abc.abstractmethod + def emojis(self) -> typing.Mapping[snowflakes.Snowflake, emoji_models.KnownCustomEmoji]: + """Get a mapping of emoji IDs to the emojis that belong to the guild.""" - Will be `builtins.None` when returned by guild update gateway events rather - than create. + @abc.abstractmethod + def guild(self) -> guild_models.GatewayGuild: + """Get the object of the guild this definition is for.""" - !!! note - This may be a partial mapping of presences active in the guild. - """ + @abc.abstractmethod + def members(self) -> typing.Mapping[snowflakes.Snowflake, guild_models.Member]: + """Get a mapping of user IDs to the members that belong to the guild. - roles: typing.Mapping[snowflakes.Snowflake, guild_models.Role] = attr.field() - """Mapping of role IDs to the roles that belong to the guild.""" + !!! note + This may be a partial mapping of members in the guild. + """ + + @abc.abstractmethod + def presences(self) -> typing.Mapping[snowflakes.Snowflake, presence_models.MemberPresence]: + """Get a mapping of user IDs to the presences that are active in the guild. - emojis: typing.Mapping[snowflakes.Snowflake, emoji_models.KnownCustomEmoji] = attr.field() - """Mapping of emoji IDs to the emojis that belong to the guild.""" + !!! note + This may be a partial mapping of presences active in the guild. + """ - voice_states: typing.Optional[typing.Mapping[snowflakes.Snowflake, voice_models.VoiceState]] = attr.field() - """Mapping of user IDs to the voice states that are active in the guild. + @abc.abstractmethod + def roles(self) -> typing.Mapping[snowflakes.Snowflake, guild_models.Role]: + """Get a mapping of role IDs to the roles that belong to the guild.""" - !!! note - This may be a partial mapping of voice states active in the guild. - """ + @abc.abstractmethod + def voice_states(self) -> typing.Mapping[snowflakes.Snowflake, voice_models.VoiceState]: + """Get a mapping of user IDs to the voice states that are active in the guild.""" class EntityFactory(abc.ABC): diff --git a/hikari/config.py b/hikari/config.py index f2440debf5..2cc2f58901 100644 --- a/hikari/config.py +++ b/hikari/config.py @@ -420,6 +420,9 @@ class CacheComponents(enums.Flag): MESSAGES = 1 << 8 """Enables the messages cache.""" + ME = 1 << 9 + """Enables the me cache.""" + DM_CHANNEL_IDS = 1 << 10 """Enables the DM channel IDs cache.""" @@ -433,6 +436,7 @@ class CacheComponents(enums.Flag): | PRESENCES | VOICE_STATES | MESSAGES + | ME | DM_CHANNEL_IDS ) """Fully enables the cache.""" @@ -443,13 +447,13 @@ class CacheComponents(enums.Flag): class CacheSettings: """Settings to control the cache.""" - components: CacheComponents = attr.field(default=CacheComponents.ALL) + components: CacheComponents = attr.field(converter=CacheComponents, default=CacheComponents.ALL) """The cache components to use. Defaults to `CacheComponents.ALL`. """ - max_messages: int = attr.field(default=300) + max_messages: int = attr.field(converter=int, default=300) """The maximum number of messages to store in the cache at once. This will have no effect if the messages cache is not enabled. diff --git a/hikari/events/base_events.py b/hikari/events/base_events.py index 1b3f7a3c67..a0fb9cb3ff 100644 --- a/hikari/events/base_events.py +++ b/hikari/events/base_events.py @@ -57,6 +57,23 @@ class Event(abc.ABC): __slots__: typing.Sequence[str] = () + __dispatches: typing.ClassVar[typing.Tuple[typing.Type[Event], ...]] + + def __init_subclass__(cls) -> None: + super().__init_subclass__() + # hasattr doesn't work with private variables in this case so we use a try except. + # We need to set Event's __dispatches when the first subclass is made as Event cannot + # be included in a tuple literal on itself due to not existing yet. + try: + Event.__dispatches + except AttributeError: + Event.__dispatches = (Event,) + + mro = cls.mro() + # We don't have to explicitly include Event here as issubclass(Event, Event) returns True. + # Non-event classes should be ignored. + cls.__dispatches = tuple(cls for cls in mro if issubclass(cls, Event)) + @property @abc.abstractmethod def app(self) -> traits.RESTAware: @@ -68,6 +85,11 @@ def app(self) -> traits.RESTAware: The REST-aware app trait. """ + @classmethod + def dispatches(cls) -> typing.Sequence[typing.Type[Event]]: + """Sequence of the event classes this event is dispatched as.""" + return cls.__dispatches + def get_required_intents_for(event_type: typing.Type[Event]) -> typing.Collection[intents.Intents]: """Retrieve the intents that are required to listen to an event type. diff --git a/hikari/impl/bot.py b/hikari/impl/bot.py index 9dc9541c7f..bdfa485b1d 100644 --- a/hikari/impl/bot.py +++ b/hikari/impl/bot.py @@ -274,7 +274,9 @@ def __init__( self._event_factory = event_factory_impl.EventFactoryImpl(self) # Event handling - self._event_manager = event_manager_impl.EventManagerImpl(self._event_factory, self._intents, cache=self._cache) + self._event_manager = event_manager_impl.EventManagerImpl( + self._entity_factory, self._event_factory, self._intents, cache=self._cache + ) # Voice subsystem self._voice = voice_impl.VoiceComponentImpl(self) diff --git a/hikari/impl/cache.py b/hikari/impl/cache.py index ae54b45e99..2ff7a94b67 100644 --- a/hikari/impl/cache.py +++ b/hikari/impl/cache.py @@ -770,15 +770,19 @@ def get_me(self) -> typing.Optional[users.OwnUser]: return copy.copy(self._me) def set_me(self, user: users.OwnUser, /) -> None: - self._me = copy.copy(user) + if self._is_cache_enabled_for(config.CacheComponents.ME): + _LOGGER.debug("setting my user to %s", user) + self._me = copy.copy(user) def update_me( self, user: users.OwnUser, / ) -> typing.Tuple[typing.Optional[users.OwnUser], typing.Optional[users.OwnUser]]: - _LOGGER.debug("setting my user to %s", user) + if not self._is_cache_enabled_for(config.CacheComponents.ME): + return None, None + cached_user = self.get_me() self.set_me(user) - return cached_user, self._me + return cached_user, self.get_me() def _build_member( self, diff --git a/hikari/impl/entity_factory.py b/hikari/impl/entity_factory.py index fce6589ab7..92fd60bfc7 100644 --- a/hikari/impl/entity_factory.py +++ b/hikari/impl/entity_factory.py @@ -179,6 +179,150 @@ class _UserFields: is_system: bool = attr.field() +@attr_extensions.with_copy +@attr.define(weakref_slot=False) +class _GatewayGuildDefinition(entity_factory.GatewayGuildDefinition): + """A structure for handling entities within guild create and update events.""" + + id: snowflakes.Snowflake = attr.field() + _payload: data_binding.JSONObject = attr.field() + _entity_factory: EntityFactoryImpl = attr.field() + _channels: typing.Optional[typing.Mapping[snowflakes.Snowflake, channel_models.GuildChannel]] = attr.field( + default=None, init=False + ) + _emojis: typing.Optional[typing.Mapping[snowflakes.Snowflake, emoji_models.KnownCustomEmoji]] = attr.field( + default=None, init=False + ) + _guild: typing.Optional[guild_models.GatewayGuild] = attr.field(default=None) + _members: typing.Optional[typing.Mapping[snowflakes.Snowflake, guild_models.Member]] = attr.field( + default=None, init=False + ) + _presences: typing.Optional[typing.Mapping[snowflakes.Snowflake, presence_models.MemberPresence]] = attr.field( + default=None, init=False + ) + _roles: typing.Optional[typing.Mapping[snowflakes.Snowflake, guild_models.Role]] = attr.field( + default=None, init=False + ) + _voice_states: typing.Optional[typing.Mapping[snowflakes.Snowflake, voice_models.VoiceState]] = attr.field( + default=None, init=False + ) + + def channels(self) -> typing.Mapping[snowflakes.Snowflake, channel_models.GuildChannel]: + if self._channels is None: + self._channels = {} + + for channel_payload in self._payload["channels"]: + try: + channel = self._entity_factory.deserialize_channel(channel_payload, guild_id=self.id) + except errors.UnrecognisedEntityError: + # Ignore the channel, this has already been logged + continue + + assert isinstance(channel, channel_models.GuildChannel) + self._channels[channel.id] = channel + + return self._channels + + def emojis(self) -> typing.Mapping[snowflakes.Snowflake, emoji_models.KnownCustomEmoji]: + if self._emojis is None: + deserialize = self._entity_factory.deserialize_known_custom_emoji + self._emojis = { + snowflakes.Snowflake(emoji["id"]): deserialize(emoji, guild_id=self.id) + for emoji in self._payload["emojis"] + } + + return self._emojis + + def guild(self) -> guild_models.GatewayGuild: + if self._guild is None: + payload = self._payload + guild_fields = self._entity_factory.set_guild_attributes(payload) + is_large = payload.get("large") + joined_at = ( + time.iso8601_datetime_string_to_datetime(payload["joined_at"]) if "joined_at" in payload else None + ) + member_count = int(payload["member_count"]) if "member_count" in payload else None + self._guild = guild_models.GatewayGuild( + app=self._entity_factory.app, + id=guild_fields.id, + name=guild_fields.name, + icon_hash=guild_fields.icon_hash, + features=guild_fields.features, + splash_hash=guild_fields.splash_hash, + discovery_splash_hash=guild_fields.discovery_splash_hash, + owner_id=guild_fields.owner_id, + afk_channel_id=guild_fields.afk_channel_id, + afk_timeout=guild_fields.afk_timeout, + verification_level=guild_fields.verification_level, + default_message_notifications=guild_fields.default_message_notifications, + explicit_content_filter=guild_fields.explicit_content_filter, + mfa_level=guild_fields.mfa_level, + application_id=guild_fields.application_id, + widget_channel_id=guild_fields.widget_channel_id, + system_channel_id=guild_fields.system_channel_id, + is_widget_enabled=guild_fields.is_widget_enabled, + system_channel_flags=guild_fields.system_channel_flags, + rules_channel_id=guild_fields.rules_channel_id, + max_video_channel_users=guild_fields.max_video_channel_users, + vanity_url_code=guild_fields.vanity_url_code, + description=guild_fields.description, + banner_hash=guild_fields.banner_hash, + premium_tier=guild_fields.premium_tier, + premium_subscription_count=guild_fields.premium_subscription_count, + preferred_locale=guild_fields.preferred_locale, + public_updates_channel_id=guild_fields.public_updates_channel_id, + nsfw_level=guild_fields.nsfw_level, + is_large=is_large, + joined_at=joined_at, + member_count=member_count, + ) + + return self._guild + + def members(self) -> typing.Mapping[snowflakes.Snowflake, guild_models.Member]: + if self._members is None: + self._members = {} + + for member_payload in self._payload["members"]: + member = self._entity_factory.deserialize_member(member_payload, guild_id=self.id) + self._members[member.user.id] = member + + return self._members + + def presences(self) -> typing.Mapping[snowflakes.Snowflake, presence_models.MemberPresence]: + if self._presences is None: + self._presences = {} + + for presence_payload in self._payload["presences"]: + presence = self._entity_factory.deserialize_member_presence(presence_payload, guild_id=self.id) + self._presences[presence.user_id] = presence + + return self._presences + + def roles(self) -> typing.Mapping[snowflakes.Snowflake, guild_models.Role]: + if self._roles is None: + self._roles = { + snowflakes.Snowflake(role["id"]): self._entity_factory.deserialize_role(role, guild_id=self.id) + for role in self._payload["roles"] + } + + return self._roles + + def voice_states(self) -> typing.Mapping[snowflakes.Snowflake, voice_models.VoiceState]: + if self._voice_states is None: + members = self.members() + self._voice_states = {} + + for voice_state_payload in self._payload["voice_states"]: + member = members[snowflakes.Snowflake(voice_state_payload["user_id"])] + voice_state = self._entity_factory.deserialize_voice_state( + voice_state_payload, guild_id=self.id, member=member + ) + self._voice_states[voice_state.user_id] = voice_state + + return self._voice_states + + class EntityFactoryImpl(entity_factory.EntityFactory): """Standard implementation for a serializer/deserializer. @@ -280,6 +424,11 @@ def __init__(self, app: traits.RESTAware) -> None: webhook_models.WebhookType.APPLICATION: self.deserialize_application_webhook, } + @property + def app(self) -> traits.RESTAware: + """Object of the application this entity factory is bound to.""" + return self._app + ###################### # APPLICATION MODELS # ###################### @@ -1334,7 +1483,7 @@ def deserialize_guild_preview(self, payload: data_binding.JSONObject) -> guild_m description=payload["description"], ) - def _set_guild_attributes(self, payload: data_binding.JSONObject) -> _GuildFields: + def set_guild_attributes(self, payload: data_binding.JSONObject) -> _GuildFields: afk_channel_id = payload["afk_channel_id"] default_message_notifications = guild_models.GuildMessageNotificationsLevel( payload["default_message_notifications"] @@ -1385,7 +1534,7 @@ def _set_guild_attributes(self, payload: data_binding.JSONObject) -> _GuildField ) def deserialize_rest_guild(self, payload: data_binding.JSONObject) -> guild_models.RESTGuild: - guild_fields = self._set_guild_attributes(payload) + guild_fields = self.set_guild_attributes(payload) approximate_member_count: typing.Optional[int] = None if "approximate_member_count" in payload: @@ -1447,96 +1596,8 @@ def deserialize_rest_guild(self, payload: data_binding.JSONObject) -> guild_mode ) def deserialize_gateway_guild(self, payload: data_binding.JSONObject) -> entity_factory.GatewayGuildDefinition: - guild_fields = self._set_guild_attributes(payload) - is_large = payload.get("large") - joined_at = time.iso8601_datetime_string_to_datetime(payload["joined_at"]) if "joined_at" in payload else None - member_count = int(payload["member_count"]) if "member_count" in payload else None - - guild = guild_models.GatewayGuild( - app=self._app, - id=guild_fields.id, - name=guild_fields.name, - icon_hash=guild_fields.icon_hash, - features=guild_fields.features, - splash_hash=guild_fields.splash_hash, - discovery_splash_hash=guild_fields.discovery_splash_hash, - owner_id=guild_fields.owner_id, - afk_channel_id=guild_fields.afk_channel_id, - afk_timeout=guild_fields.afk_timeout, - verification_level=guild_fields.verification_level, - default_message_notifications=guild_fields.default_message_notifications, - explicit_content_filter=guild_fields.explicit_content_filter, - mfa_level=guild_fields.mfa_level, - application_id=guild_fields.application_id, - widget_channel_id=guild_fields.widget_channel_id, - system_channel_id=guild_fields.system_channel_id, - is_widget_enabled=guild_fields.is_widget_enabled, - system_channel_flags=guild_fields.system_channel_flags, - rules_channel_id=guild_fields.rules_channel_id, - max_video_channel_users=guild_fields.max_video_channel_users, - vanity_url_code=guild_fields.vanity_url_code, - description=guild_fields.description, - banner_hash=guild_fields.banner_hash, - premium_tier=guild_fields.premium_tier, - premium_subscription_count=guild_fields.premium_subscription_count, - preferred_locale=guild_fields.preferred_locale, - public_updates_channel_id=guild_fields.public_updates_channel_id, - nsfw_level=guild_fields.nsfw_level, - is_large=is_large, - joined_at=joined_at, - member_count=member_count, - ) - - members: typing.Optional[typing.Dict[snowflakes.Snowflake, guild_models.Member]] = None - if "members" in payload: - members = {} - - for member_payload in payload["members"]: - member = self.deserialize_member(member_payload, guild_id=guild.id) - members[member.user.id] = member - - channels: typing.Optional[typing.Dict[snowflakes.Snowflake, channel_models.GuildChannel]] = None - if "channels" in payload: - channels = {} - - for channel_payload in payload["channels"]: - try: - channel = self.deserialize_channel(channel_payload, guild_id=guild.id) - except errors.UnrecognisedEntityError: - # Ignore the channel, this has already been logged - continue - - assert isinstance(channel, channel_models.GuildChannel) - channels[channel.id] = channel - - presences: typing.Optional[typing.Dict[snowflakes.Snowflake, presence_models.MemberPresence]] = None - if "presences" in payload: - presences = {} - - for presence_payload in payload["presences"]: - presence = self.deserialize_member_presence(presence_payload, guild_id=guild.id) - presences[presence.user_id] = presence - - voice_states: typing.Optional[typing.Dict[snowflakes.Snowflake, voice_models.VoiceState]] = None - if "voice_states" in payload: - voice_states = {} - assert members is not None - - for voice_state_payload in payload["voice_states"]: - member = members[snowflakes.Snowflake(voice_state_payload["user_id"])] - voice_state = self.deserialize_voice_state(voice_state_payload, guild_id=guild.id, member=member) - voice_states[voice_state.user_id] = voice_state - - roles = { - snowflakes.Snowflake(role["id"]): self.deserialize_role(role, guild_id=guild.id) - for role in payload["roles"] - } - emojis = { - snowflakes.Snowflake(emoji["id"]): self.deserialize_known_custom_emoji(emoji, guild_id=guild.id) - for emoji in payload["emojis"] - } - - return entity_factory.GatewayGuildDefinition(guild, channels, members, presences, roles, emojis, voice_states) + guild_id = snowflakes.Snowflake(payload["id"]) + return _GatewayGuildDefinition(id=guild_id, payload=payload, entity_factory=self) ################# # INVITE MODELS # diff --git a/hikari/impl/event_factory.py b/hikari/impl/event_factory.py index 92b6917c4a..80933e4b7b 100644 --- a/hikari/impl/event_factory.py +++ b/hikari/impl/event_factory.py @@ -194,19 +194,15 @@ def deserialize_guild_available_event( self, shard: gateway_shard.GatewayShard, payload: data_binding.JSONObject ) -> guild_events.GuildAvailableEvent: guild_information = self._app.entity_factory.deserialize_gateway_guild(payload) - assert guild_information.channels is not None - assert guild_information.members is not None - assert guild_information.presences is not None - assert guild_information.voice_states is not None return guild_events.GuildAvailableEvent( shard=shard, - guild=guild_information.guild, - emojis=guild_information.emojis, - roles=guild_information.roles, - channels=guild_information.channels, - members=guild_information.members, - presences=guild_information.presences, - voice_states=guild_information.voice_states, + guild=guild_information.guild(), + emojis=guild_information.emojis(), + roles=guild_information.roles(), + channels=guild_information.channels(), + members=guild_information.members(), + presences=guild_information.presences(), + voice_states=guild_information.voice_states(), ) def deserialize_guild_join_event( @@ -219,13 +215,13 @@ def deserialize_guild_join_event( assert guild_information.voice_states is not None return guild_events.GuildJoinEvent( shard=shard, - guild=guild_information.guild, - emojis=guild_information.emojis, - roles=guild_information.roles, - channels=guild_information.channels, - members=guild_information.members, - presences=guild_information.presences, - voice_states=guild_information.voice_states, + guild=guild_information.guild(), + emojis=guild_information.emojis(), + roles=guild_information.roles(), + channels=guild_information.channels(), + members=guild_information.members(), + presences=guild_information.presences(), + voice_states=guild_information.voice_states(), ) def deserialize_guild_update_event( @@ -238,9 +234,9 @@ def deserialize_guild_update_event( guild_information = self._app.entity_factory.deserialize_gateway_guild(payload) return guild_events.GuildUpdateEvent( shard=shard, - guild=guild_information.guild, - emojis=guild_information.emojis, - roles=guild_information.roles, + guild=guild_information.guild(), + emojis=guild_information.emojis(), + roles=guild_information.roles(), old_guild=old_guild, ) diff --git a/hikari/impl/event_manager.py b/hikari/impl/event_manager.py index bcd72d9eb4..e16a536935 100644 --- a/hikari/impl/event_manager.py +++ b/hikari/impl/event_manager.py @@ -28,27 +28,43 @@ import asyncio import base64 +import logging import random import typing +from hikari import config from hikari import errors from hikari import intents as intents_ -from hikari import presences +from hikari import presences as presences_ from hikari import snowflakes +from hikari.events import channel_events +from hikari.events import guild_events +from hikari.events import member_events +from hikari.events import message_events +from hikari.events import reaction_events +from hikari.events import role_events +from hikari.events import shard_events +from hikari.events import typing_events +from hikari.events import user_events +from hikari.events import voice_events from hikari.impl import event_manager_base from hikari.internal import time +from hikari.internal import ux if typing.TYPE_CHECKING: from hikari import guilds from hikari import invites from hikari import voices from hikari.api import cache as cache_ + from hikari.api import entity_factory as entity_factory_ from hikari.api import event_factory as event_factory_ from hikari.api import shard as gateway_shard - from hikari.events import guild_events as guild_events from hikari.internal import data_binding +_LOGGER: typing.Final[logging.Logger] = logging.getLogger("hikari.event_manager") + + def _fixed_size_nonce() -> str: # This generates nonces of length 28 for use in member chunking. head = time.monotonic_ns().to_bytes(8, "big") @@ -58,7 +74,7 @@ def _fixed_size_nonce() -> str: async def _request_guild_members( shard: gateway_shard.GatewayShard, - guild: guilds.PartialGuild, + guild: snowflakes.SnowflakeishOr[guilds.PartialGuild], *, include_presences: bool, nonce: str, @@ -74,10 +90,11 @@ async def _request_guild_members( class EventManagerImpl(event_manager_base.EventManagerBase): """Provides event handling logic for Discord events.""" - __slots__: typing.Sequence[str] = ("_cache",) + __slots__: typing.Sequence[str] = ("_cache", "_entity_factory") def __init__( self, + entity_factory: entity_factory_.EntityFactory, event_factory: event_factory_.EventFactory, intents: intents_.Intents, /, @@ -85,8 +102,14 @@ def __init__( cache: typing.Optional[cache_.MutableCache] = None, ) -> None: self._cache = cache - super().__init__(event_factory=event_factory, intents=intents) + self._entity_factory = entity_factory + components = cache.settings.components if cache else config.CacheComponents.NONE + super().__init__(event_factory=event_factory, intents=intents, cache_components=components) + def _cache_enabled_for(self, components: config.CacheComponents, /) -> bool: + return self._cache is not None and (self._cache.settings.components & components) == components + + @event_manager_base.filtered(shard_events.ShardReadyEvent, config.CacheComponents.ME) async def on_ready(self, shard: gateway_shard.GatewayShard, payload: data_binding.JSONObject) -> None: """See https://discord.com/developers/docs/topics/gateway#ready for more info.""" # TODO: cache unavailable guilds on startup, I didn't bother for the time being. @@ -97,10 +120,12 @@ async def on_ready(self, shard: gateway_shard.GatewayShard, payload: data_bindin await self.dispatch(event) + @event_manager_base.filtered(shard_events.ShardResumedEvent) async def on_resumed(self, shard: gateway_shard.GatewayShard, _: data_binding.JSONObject) -> None: """See https://discord.com/developers/docs/topics/gateway#resumed for more info.""" await self.dispatch(self._event_factory.deserialize_resumed_event(shard)) + @event_manager_base.filtered(channel_events.GuildChannelCreateEvent, config.CacheComponents.GUILD_CHANNELS) async def on_channel_create(self, shard: gateway_shard.GatewayShard, payload: data_binding.JSONObject) -> None: """See https://discord.com/developers/docs/topics/gateway#channel-create for more info.""" event = self._event_factory.deserialize_guild_channel_create_event(shard, payload) @@ -110,6 +135,7 @@ async def on_channel_create(self, shard: gateway_shard.GatewayShard, payload: da await self.dispatch(event) + @event_manager_base.filtered(channel_events.GuildChannelUpdateEvent, config.CacheComponents.GUILD_CHANNELS) async def on_channel_update(self, shard: gateway_shard.GatewayShard, payload: data_binding.JSONObject) -> None: """See https://discord.com/developers/docs/topics/gateway#channel-update for more info.""" old = self._cache.get_guild_channel(snowflakes.Snowflake(payload["id"])) if self._cache else None @@ -120,6 +146,7 @@ async def on_channel_update(self, shard: gateway_shard.GatewayShard, payload: da await self.dispatch(event) + @event_manager_base.filtered(channel_events.GuildChannelDeleteEvent, config.CacheComponents.GUILD_CHANNELS) async def on_channel_delete(self, shard: gateway_shard.GatewayShard, payload: data_binding.JSONObject) -> None: """See https://discord.com/developers/docs/topics/gateway#channel-delete for more info.""" event = self._event_factory.deserialize_guild_channel_delete_event(shard, payload) @@ -129,87 +156,172 @@ async def on_channel_delete(self, shard: gateway_shard.GatewayShard, payload: da await self.dispatch(event) + @event_manager_base.filtered((channel_events.GuildPinsUpdateEvent, channel_events.DMPinsUpdateEvent)) async def on_channel_pins_update(self, shard: gateway_shard.GatewayShard, payload: data_binding.JSONObject) -> None: """See https://discord.com/developers/docs/topics/gateway#channel-pins-update for more info.""" # TODO: we need a method for this specifically await self.dispatch(self._event_factory.deserialize_channel_pins_update_event(shard, payload)) + # Internal granularity is preferred for GUILD_CREATE over decorator based filtering due to its large cache scope. async def on_guild_create(self, shard: gateway_shard.GatewayShard, payload: data_binding.JSONObject) -> None: """See https://discord.com/developers/docs/topics/gateway#guild-create for more info.""" - event: typing.Union[guild_events.GuildAvailableEvent, guild_events.GuildJoinEvent] + enabled_for_event = self._enabled_for_event(guild_events.GuildAvailableEvent) + if not enabled_for_event and self._cache: + _LOGGER.log(ux.TRACE, "Skipping on_guild_create dispatch due to lack of any registered listeners") + event: typing.Union[guild_events.GuildAvailableEvent, guild_events.GuildJoinEvent, None] = None + gd = self._entity_factory.deserialize_gateway_guild(payload) + + channels = gd.channels() if self._cache_enabled_for(config.CacheComponents.GUILD_CHANNELS) else None + emojis = gd.emojis() if self._cache_enabled_for(config.CacheComponents.EMOJIS) else None + guild = gd.guild() if self._cache_enabled_for(config.CacheComponents.GUILDS) else None + guild_id = gd.id + members = gd.members() if self._cache_enabled_for(config.CacheComponents.MEMBERS) else None + presences = gd.presences() if self._cache_enabled_for(config.CacheComponents.PRESENCES) else None + roles = gd.roles() if self._cache_enabled_for(config.CacheComponents.ROLES) else None + voice_states = gd.voice_states() if self._cache_enabled_for(config.CacheComponents.VOICE_STATES) else None + + elif enabled_for_event: + if "unavailable" in payload: + event = self._event_factory.deserialize_guild_available_event(shard, payload) + else: + event = self._event_factory.deserialize_guild_join_event(shard, payload) + + channels = event.channels + emojis = event.emojis + guild = event.guild + guild_id = guild.id + members = event.members + presences = event.presences + roles = event.roles + voice_states = event.voice_states - if "unavailable" in payload: - event = self._event_factory.deserialize_guild_available_event(shard, payload) else: - event = self._event_factory.deserialize_guild_join_event(shard, payload) + event = None + channels = None + emojis = None + guild = None + guild_id = snowflakes.Snowflake(payload["id"]) + members = None + presences = None + roles = None + voice_states = None if self._cache: - self._cache.update_guild(event.guild) - - self._cache.clear_guild_channels_for_guild(event.guild.id) - for channel in event.channels.values(): - self._cache.set_guild_channel(channel) + if guild: + self._cache.update_guild(guild) - self._cache.clear_emojis_for_guild(event.guild.id) - for emoji in event.emojis.values(): - self._cache.set_emoji(emoji) - - self._cache.clear_roles_for_guild(event.guild.id) - for role in event.roles.values(): - self._cache.set_role(role) + if channels: + self._cache.clear_guild_channels_for_guild(guild_id) + for channel in channels.values(): + self._cache.set_guild_channel(channel) - # TODO: do we really want to invalidate these all after an outage. - self._cache.clear_members_for_guild(event.guild.id) - for member in event.members.values(): - self._cache.set_member(member) + if emojis: + self._cache.clear_emojis_for_guild(guild_id) + for emoji in emojis.values(): + self._cache.set_emoji(emoji) - self._cache.clear_presences_for_guild(event.guild.id) - for presence in event.presences.values(): - self._cache.set_presence(presence) + if roles: + self._cache.clear_roles_for_guild(guild_id) + for role in roles.values(): + self._cache.set_role(role) - self._cache.clear_voice_states_for_guild(event.guild.id) - for voice_state in event.voice_states.values(): - self._cache.set_voice_state(voice_state) + if members: + # TODO: do we really want to invalidate these all after an outage. + self._cache.clear_members_for_guild(guild_id) + for member in members.values(): + self._cache.set_member(member) - members_declared = self._intents & intents_.Intents.GUILD_MEMBERS - presences_declared = self._intents & intents_.Intents.GUILD_PRESENCES + if presences: + self._cache.clear_presences_for_guild(guild_id) + for presence in presences.values(): + self._cache.set_presence(presence) - # When intents are enabled discord will only send other member objects on the guild create - # payload if presence intents are also declared, so if this isn't the case then we also want - # to chunk small guilds. - if members_declared and (event.guild.is_large or not presences_declared): - # We create a task here instead of awaiting the result to avoid any rate-limits from delaying dispatch. - nonce = f"{shard.id}.{_fixed_size_nonce()}" + if voice_states: + self._cache.clear_voice_states_for_guild(guild_id) + for voice_state in voice_states.values(): + self._cache.set_voice_state(voice_state) + + recv_chunks = self._enabled_for_event(shard_events.MemberChunkEvent) or self._cache_enabled_for( + config.CacheComponents.MEMBERS + ) + members_declared = self._intents & intents_.Intents.GUILD_MEMBERS + presences_declared = self._intents & intents_.Intents.GUILD_PRESENCES + + # When intents are enabled discord will only send other member objects on the guild create + # payload if presence intents are also declared, so if this isn't the case then we also want + # to chunk small guilds. + if recv_chunks and members_declared and (payload.get("large") or not presences_declared): + # We create a task here instead of awaiting the result to avoid any rate-limits from delaying dispatch. + nonce = f"{shard.id}.{_fixed_size_nonce()}" + + if event: event.chunk_nonce = nonce - coroutine = _request_guild_members( - shard, event.guild, include_presences=bool(presences_declared), nonce=nonce - ) - asyncio.create_task(coroutine, name=f"{shard.id}:{event.guild.id} guild create members request") - await self.dispatch(event) + coroutine = _request_guild_members(shard, guild_id, include_presences=bool(presences_declared), nonce=nonce) + asyncio.create_task(coroutine, name=f"{shard.id}:{guild_id} guild create members request") + + if event: + await self.dispatch(event) + # Internal granularity is preferred for GUILD_UPDATE over decorator based filtering due to its large cache scope. async def on_guild_update(self, shard: gateway_shard.GatewayShard, payload: data_binding.JSONObject) -> None: """See https://discord.com/developers/docs/topics/gateway#guild-update for more info.""" - old = self._cache.get_guild(snowflakes.Snowflake(payload["id"])) if self._cache else None - event = self._event_factory.deserialize_guild_update_event(shard, payload, old_guild=old) + enabled_for_event = self._enabled_for_event(guild_events.GuildUpdateEvent) + + if not enabled_for_event and self._cache: + _LOGGER.log(ux.TRACE, "Skipping on_guild_update raw dispatch due to lack of any registered listeners") + event: typing.Optional[guild_events.GuildUpdateEvent] = None + gd = self._entity_factory.deserialize_gateway_guild(payload) + emojis = gd.emojis() if self._cache_enabled_for(config.CacheComponents.EMOJIS) else None + guild = gd.guild() if self._cache_enabled_for(config.CacheComponents.GUILDS) else None + guild_id = gd.id + roles = gd.roles() if self._cache_enabled_for(config.CacheComponents.ROLES) else None + + elif enabled_for_event: + guild_id = snowflakes.Snowflake(payload["id"]) + old = self._cache.get_guild(guild_id) if self._cache else None + event = self._event_factory.deserialize_guild_update_event(shard, payload, old_guild=old) + emojis = event.emojis + guild = event.guild + roles = event.roles - if self._cache: - self._cache.update_guild(event.guild) - - self._cache.clear_roles_for_guild(event.guild.id) - for role in event.roles.values(): # TODO: do we actually get this here? - self._cache.set_role(role) + else: + _LOGGER.log( + ux.TRACE, "Skipping on_guild_update raw dispatch due to lack of any registered listeners or cache need" + ) + return - self._cache.clear_emojis_for_guild(event.guild.id) # TODO: do we actually get this here? - for emoji in event.emojis.values(): - self._cache.set_emoji(emoji) + if self._cache: + if guild: + self._cache.update_guild(guild) - await self.dispatch(event) + if emojis: + self._cache.clear_emojis_for_guild(guild_id) + for emoji in emojis.values(): + self._cache.set_emoji(emoji) + if roles: + self._cache.clear_roles_for_guild(guild_id) + for role in roles.values(): + self._cache.set_role(role) + + if event: + await self.dispatch(event) + + @event_manager_base.filtered( + (guild_events.GuildLeaveEvent, guild_events.GuildUnavailableEvent), + config.CacheComponents.GUILDS + | config.CacheComponents.GUILD_CHANNELS + | config.CacheComponents.EMOJIS + | config.CacheComponents.ROLES + | config.CacheComponents.PRESENCES + | config.CacheComponents.VOICE_STATES + | config.CacheComponents.MEMBERS, + ) async def on_guild_delete(self, shard: gateway_shard.GatewayShard, payload: data_binding.JSONObject) -> None: """See https://discord.com/developers/docs/topics/gateway#guild-delete for more info.""" event: typing.Union[guild_events.GuildUnavailableEvent, guild_events.GuildLeaveEvent] - if payload.get("unavailable", False): + if payload.get("unavailable"): event = self._event_factory.deserialize_guild_unavailable_event(shard, payload) if self._cache: @@ -233,14 +345,17 @@ async def on_guild_delete(self, shard: gateway_shard.GatewayShard, payload: data await self.dispatch(event) + @event_manager_base.filtered(guild_events.BanCreateEvent) async def on_guild_ban_add(self, shard: gateway_shard.GatewayShard, payload: data_binding.JSONObject) -> None: """See https://discord.com/developers/docs/topics/gateway#guild-ban-add for more info.""" await self.dispatch(self._event_factory.deserialize_guild_ban_add_event(shard, payload)) + @event_manager_base.filtered(guild_events.BanDeleteEvent) async def on_guild_ban_remove(self, shard: gateway_shard.GatewayShard, payload: data_binding.JSONObject) -> None: """See https://discord.com/developers/docs/topics/gateway#guild-ban-remove for more info.""" await self.dispatch(self._event_factory.deserialize_guild_ban_remove_event(shard, payload)) + @event_manager_base.filtered(guild_events.EmojisUpdateEvent, config.CacheComponents.EMOJIS) async def on_guild_emojis_update(self, shard: gateway_shard.GatewayShard, payload: data_binding.JSONObject) -> None: """See https://discord.com/developers/docs/topics/gateway#guild-emojis-update for more info.""" guild_id = snowflakes.Snowflake(payload["guild_id"]) @@ -254,24 +369,29 @@ async def on_guild_emojis_update(self, shard: gateway_shard.GatewayShard, payloa await self.dispatch(event) + @event_manager_base.filtered(()) # An empty sequence here means that this method will always be skipped. async def on_guild_integrations_update(self, _: gateway_shard.GatewayShard, __: data_binding.JSONObject) -> None: """See https://discord.com/developers/docs/topics/gateway#guild-integrations-update for more info.""" # This is only here to stop this being logged or dispatched as an "unknown event". # This event is made redundant by INTEGRATION_CREATE/DELETE/UPDATE and is thus not parsed or dispatched. - return None + raise NotImplementedError + @event_manager_base.filtered(guild_events.IntegrationCreateEvent) async def on_integration_create(self, shard: gateway_shard.GatewayShard, payload: data_binding.JSONObject) -> None: event = self._event_factory.deserialize_integration_create_event(shard, payload) await self.dispatch(event) + @event_manager_base.filtered(guild_events.IntegrationDeleteEvent) async def on_integration_delete(self, shard: gateway_shard.GatewayShard, payload: data_binding.JSONObject) -> None: event = self._event_factory.deserialize_integration_delete_event(shard, payload) await self.dispatch(event) + @event_manager_base.filtered(guild_events.IntegrationUpdateEvent) async def on_integration_update(self, shard: gateway_shard.GatewayShard, payload: data_binding.JSONObject) -> None: event = self._event_factory.deserialize_integration_update_event(shard, payload) await self.dispatch(event) + @event_manager_base.filtered(member_events.MemberCreateEvent, config.CacheComponents.MEMBERS) async def on_guild_member_add(self, shard: gateway_shard.GatewayShard, payload: data_binding.JSONObject) -> None: """See https://discord.com/developers/docs/topics/gateway#guild-member-add for more info.""" event = self._event_factory.deserialize_guild_member_add_event(shard, payload) @@ -281,6 +401,7 @@ async def on_guild_member_add(self, shard: gateway_shard.GatewayShard, payload: await self.dispatch(event) + @event_manager_base.filtered(member_events.MemberDeleteEvent, config.CacheComponents.MEMBERS) async def on_guild_member_remove(self, shard: gateway_shard.GatewayShard, payload: data_binding.JSONObject) -> None: """See https://discord.com/developers/docs/topics/gateway#guild-member-remove for more info.""" old: typing.Optional[guilds.Member] = None @@ -292,6 +413,7 @@ async def on_guild_member_remove(self, shard: gateway_shard.GatewayShard, payloa event = self._event_factory.deserialize_guild_member_remove_event(shard, payload, old_member=old) await self.dispatch(event) + @event_manager_base.filtered(member_events.MemberUpdateEvent, config.CacheComponents.MEMBERS) async def on_guild_member_update(self, shard: gateway_shard.GatewayShard, payload: data_binding.JSONObject) -> None: """See https://discord.com/developers/docs/topics/gateway#guild-member-update for more info.""" old: typing.Optional[guilds.Member] = None @@ -307,6 +429,7 @@ async def on_guild_member_update(self, shard: gateway_shard.GatewayShard, payloa await self.dispatch(event) + @event_manager_base.filtered(shard_events.MemberChunkEvent, config.CacheComponents.MEMBERS) async def on_guild_members_chunk(self, shard: gateway_shard.GatewayShard, payload: data_binding.JSONObject) -> None: """See https://discord.com/developers/docs/topics/gateway#guild-members-chunk for more info.""" event = self._event_factory.deserialize_guild_member_chunk_event(shard, payload) @@ -320,6 +443,7 @@ async def on_guild_members_chunk(self, shard: gateway_shard.GatewayShard, payloa await self.dispatch(event) + @event_manager_base.filtered(role_events.RoleCreateEvent, config.CacheComponents.ROLES) async def on_guild_role_create(self, shard: gateway_shard.GatewayShard, payload: data_binding.JSONObject) -> None: """See https://discord.com/developers/docs/topics/gateway#guild-role-create for more info.""" event = self._event_factory.deserialize_guild_role_create_event(shard, payload) @@ -329,6 +453,7 @@ async def on_guild_role_create(self, shard: gateway_shard.GatewayShard, payload: await self.dispatch(event) + @event_manager_base.filtered(role_events.RoleUpdateEvent, config.CacheComponents.ROLES) async def on_guild_role_update(self, shard: gateway_shard.GatewayShard, payload: data_binding.JSONObject) -> None: """See https://discord.com/developers/docs/topics/gateway#guild-role-update for more info.""" old = self._cache.get_role(snowflakes.Snowflake(payload["role"]["id"])) if self._cache else None @@ -339,6 +464,7 @@ async def on_guild_role_update(self, shard: gateway_shard.GatewayShard, payload: await self.dispatch(event) + @event_manager_base.filtered(role_events.RoleDeleteEvent, config.CacheComponents.ROLES) async def on_guild_role_delete(self, shard: gateway_shard.GatewayShard, payload: data_binding.JSONObject) -> None: """See https://discord.com/developers/docs/topics/gateway#guild-role-delete for more info.""" old: typing.Optional[guilds.Role] = None @@ -349,6 +475,7 @@ async def on_guild_role_delete(self, shard: gateway_shard.GatewayShard, payload: await self.dispatch(event) + @event_manager_base.filtered(channel_events.InviteCreateEvent, config.CacheComponents.INVITES) async def on_invite_create(self, shard: gateway_shard.GatewayShard, payload: data_binding.JSONObject) -> None: """See https://discord.com/developers/docs/topics/gateway#invite-create for more info.""" event = self._event_factory.deserialize_invite_create_event(shard, payload) @@ -358,6 +485,7 @@ async def on_invite_create(self, shard: gateway_shard.GatewayShard, payload: dat await self.dispatch(event) + @event_manager_base.filtered(channel_events.InviteDeleteEvent, config.CacheComponents.INVITES) async def on_invite_delete(self, shard: gateway_shard.GatewayShard, payload: data_binding.JSONObject) -> None: """See https://discord.com/developers/docs/topics/gateway#invite-delete for more info.""" old: typing.Optional[invites.InviteWithMetadata] = None @@ -368,6 +496,9 @@ async def on_invite_delete(self, shard: gateway_shard.GatewayShard, payload: dat await self.dispatch(event) + @event_manager_base.filtered( + (message_events.GuildMessageCreateEvent, message_events.DMMessageCreateEvent), config.CacheComponents.MESSAGES + ) async def on_message_create(self, shard: gateway_shard.GatewayShard, payload: data_binding.JSONObject) -> None: """See https://discord.com/developers/docs/topics/gateway#message-create for more info.""" event = self._event_factory.deserialize_message_create_event(shard, payload) @@ -377,6 +508,9 @@ async def on_message_create(self, shard: gateway_shard.GatewayShard, payload: da await self.dispatch(event) + @event_manager_base.filtered( + (message_events.GuildMessageUpdateEvent, message_events.DMMessageUpdateEvent), config.CacheComponents.MESSAGES + ) async def on_message_update(self, shard: gateway_shard.GatewayShard, payload: data_binding.JSONObject) -> None: """See https://discord.com/developers/docs/topics/gateway#message-update for more info.""" old = self._cache.get_message(snowflakes.Snowflake(payload["id"])) if self._cache else None @@ -387,6 +521,9 @@ async def on_message_update(self, shard: gateway_shard.GatewayShard, payload: da await self.dispatch(event) + @event_manager_base.filtered( + (message_events.GuildMessageDeleteEvent, message_events.DMMessageDeleteEvent), config.CacheComponents.MESSAGES + ) async def on_message_delete(self, shard: gateway_shard.GatewayShard, payload: data_binding.JSONObject) -> None: """See https://discord.com/developers/docs/topics/gateway#message-delete for more info.""" event = self._event_factory.deserialize_message_delete_event(shard, payload) @@ -396,6 +533,9 @@ async def on_message_delete(self, shard: gateway_shard.GatewayShard, payload: da await self.dispatch(event) + @event_manager_base.filtered( + (message_events.GuildMessageDeleteEvent, message_events.DMMessageDeleteEvent), config.CacheComponents.MESSAGES + ) async def on_message_delete_bulk(self, shard: gateway_shard.GatewayShard, payload: data_binding.JSONObject) -> None: """See https://discord.com/developers/docs/topics/gateway#message-delete-bulk for more info.""" event = self._event_factory.deserialize_message_delete_bulk_event(shard, payload) @@ -406,6 +546,7 @@ async def on_message_delete_bulk(self, shard: gateway_shard.GatewayShard, payloa await self.dispatch(event) + @event_manager_base.filtered((reaction_events.GuildReactionAddEvent, reaction_events.DMReactionAddEvent)) async def on_message_reaction_add( self, shard: gateway_shard.GatewayShard, payload: data_binding.JSONObject ) -> None: @@ -414,27 +555,35 @@ async def on_message_reaction_add( # TODO: this is unlikely but reaction cache? + @event_manager_base.filtered((reaction_events.GuildReactionDeleteEvent, reaction_events.DMReactionDeleteEvent)) async def on_message_reaction_remove( self, shard: gateway_shard.GatewayShard, payload: data_binding.JSONObject ) -> None: """See https://discord.com/developers/docs/topics/gateway#message-reaction-remove for more info.""" await self.dispatch(self._event_factory.deserialize_message_reaction_remove_event(shard, payload)) + @event_manager_base.filtered( + (reaction_events.GuildReactionDeleteAllEvent, reaction_events.DMReactionDeleteAllEvent) + ) async def on_message_reaction_remove_all( self, shard: gateway_shard.GatewayShard, payload: data_binding.JSONObject ) -> None: """See https://discord.com/developers/docs/topics/gateway#message-reaction-remove-all for more info.""" await self.dispatch(self._event_factory.deserialize_message_reaction_remove_all_event(shard, payload)) + @event_manager_base.filtered( + (reaction_events.GuildReactionDeleteEmojiEvent, reaction_events.DMReactionDeleteEmojiEvent) + ) async def on_message_reaction_remove_emoji( self, shard: gateway_shard.GatewayShard, payload: data_binding.JSONObject ) -> None: """See https://discord.com/developers/docs/topics/gateway#message-reaction-remove-emoji for more info.""" await self.dispatch(self._event_factory.deserialize_message_reaction_remove_emoji_event(shard, payload)) + @event_manager_base.filtered(guild_events.PresenceUpdateEvent, config.CacheComponents.PRESENCES) async def on_presence_update(self, shard: gateway_shard.GatewayShard, payload: data_binding.JSONObject) -> None: """See https://discord.com/developers/docs/topics/gateway#presence-update for more info.""" - old: typing.Optional[presences.MemberPresence] = None + old: typing.Optional[presences_.MemberPresence] = None if self._cache: old = self._cache.get_presence( snowflakes.Snowflake(payload["guild_id"]), snowflakes.Snowflake(payload["user"]["id"]) @@ -442,7 +591,7 @@ async def on_presence_update(self, shard: gateway_shard.GatewayShard, payload: d event = self._event_factory.deserialize_presence_update_event(shard, payload, old_presence=old) - if self._cache and event.presence.visible_status is presences.Status.OFFLINE: + if self._cache and event.presence.visible_status is presences_.Status.OFFLINE: self._cache.delete_presence(event.presence.guild_id, event.presence.user_id) elif self._cache: self._cache.update_presence(event.presence) @@ -450,10 +599,12 @@ async def on_presence_update(self, shard: gateway_shard.GatewayShard, payload: d # TODO: update user here when partial_user is set self._cache.update_user(event.partial_user) await self.dispatch(event) + @event_manager_base.filtered((typing_events.GuildTypingEvent, typing_events.DMTypingEvent)) async def on_typing_start(self, shard: gateway_shard.GatewayShard, payload: data_binding.JSONObject) -> None: """See https://discord.com/developers/docs/topics/gateway#typing-start for more info.""" await self.dispatch(self._event_factory.deserialize_typing_start_event(shard, payload)) + @event_manager_base.filtered(user_events.OwnUserUpdateEvent, config.CacheComponents.ME) async def on_user_update(self, shard: gateway_shard.GatewayShard, payload: data_binding.JSONObject) -> None: """See https://discord.com/developers/docs/topics/gateway#user-update for more info.""" old = self._cache.get_me() if self._cache else None @@ -464,6 +615,7 @@ async def on_user_update(self, shard: gateway_shard.GatewayShard, payload: data_ await self.dispatch(event) + @event_manager_base.filtered(voice_events.VoiceStateUpdateEvent, config.CacheComponents.VOICE_STATES) async def on_voice_state_update(self, shard: gateway_shard.GatewayShard, payload: data_binding.JSONObject) -> None: """See https://discord.com/developers/docs/topics/gateway#voice-state-update for more info.""" old: typing.Optional[voices.VoiceState] = None @@ -481,10 +633,12 @@ async def on_voice_state_update(self, shard: gateway_shard.GatewayShard, payload await self.dispatch(event) + @event_manager_base.filtered(voice_events.VoiceServerUpdateEvent) async def on_voice_server_update(self, shard: gateway_shard.GatewayShard, payload: data_binding.JSONObject) -> None: """See https://discord.com/developers/docs/topics/gateway#voice-server-update for more info.""" await self.dispatch(self._event_factory.deserialize_voice_server_update_event(shard, payload)) + @event_manager_base.filtered(channel_events.WebhookUpdateEvent) async def on_webhooks_update(self, shard: gateway_shard.GatewayShard, payload: data_binding.JSONObject) -> None: """See https://discord.com/developers/docs/topics/gateway#webhooks-update for more info.""" await self.dispatch(self._event_factory.deserialize_webhook_update_event(shard, payload)) diff --git a/hikari/impl/event_manager_base.py b/hikari/impl/event_manager_base.py index 0c66ed6439..e26efd926c 100644 --- a/hikari/impl/event_manager_base.py +++ b/hikari/impl/event_manager_base.py @@ -24,21 +24,29 @@ from __future__ import annotations -__all__: typing.List[str] = ["EventManagerBase", "EventStream"] +__all__: typing.List[str] = ["filtered", "EventManagerBase", "EventStream"] import asyncio import inspect +import itertools import logging import typing import warnings import weakref +import attr + +from hikari import config from hikari import errors from hikari import iterators +from hikari import undefined from hikari.api import event_manager as event_manager_ from hikari.events import base_events +from hikari.events import shard_events from hikari.internal import aio +from hikari.internal import fast_protocol from hikari.internal import reflect +from hikari.internal import ux if typing.TYPE_CHECKING: import types @@ -54,18 +62,36 @@ ConsumerT = typing.Callable[ [gateway_shard.GatewayShard, data_binding.JSONObject], typing.Coroutine[typing.Any, typing.Any, None] ] - ListenerMapT = typing.MutableMapping[ + ListenerMapT = typing.Dict[ typing.Type[event_manager_.EventT_co], - typing.MutableSequence[event_manager_.CallbackT[event_manager_.EventT_co]], + typing.List[event_manager_.CallbackT[event_manager_.EventT_co]], ] WaiterT = typing.Tuple[ event_manager_.PredicateT[event_manager_.EventT_co], asyncio.Future[event_manager_.EventT_co] ] - WaiterMapT = typing.MutableMapping[ - typing.Type[event_manager_.EventT_co], typing.MutableSet[WaiterT[event_manager_.EventT_co]] + WaiterMapT = typing.Dict[typing.Type[event_manager_.EventT_co], typing.Set[WaiterT[event_manager_.EventT_co]]] + + EventManagerBaseT = typing.TypeVar("EventManagerBaseT", bound="EventManagerBase") + UnboundMethodT = typing.Callable[ + [EventManagerBaseT, gateway_shard.GatewayShard, data_binding.JSONObject], + typing.Coroutine[typing.Any, typing.Any, None], ] +@typing.runtime_checkable +class _FilteredMethodT(fast_protocol.FastProtocolChecking, typing.Protocol): + async def __call__(self, shard: gateway_shard.GatewayShard, payload: data_binding.JSONObject, /) -> None: + raise NotImplementedError + + @property + def __cache_components__(self) -> config.CacheComponents: + raise NotImplementedError + + @property + def __event_types__(self) -> typing.Sequence[typing.Type[base_events.Event]]: + raise NotImplementedError + + def _generate_weak_listener( reference: weakref.WeakMethod, ) -> typing.Callable[[event_manager_.EventT], typing.Coroutine[typing.Any, typing.Any, None]]: @@ -233,10 +259,57 @@ def _assert_is_listener(parameters: typing.Iterator[inspect.Parameter], /) -> No if next(parameters, None) is None: raise TypeError("Event listener must have one positional argument for the event object.") - if any(param.default is not inspect.Parameter.empty for param in parameters): + if any(param.default is inspect.Parameter.empty for param in parameters): raise TypeError("Only the first argument for a listener can be required, the event argument.") +def filtered( + event_types: typing.Union[typing.Type[base_events.Event], typing.Sequence[typing.Type[base_events.Event]]], + cache_components: config.CacheComponents = config.CacheComponents.NONE, + /, +) -> typing.Callable[[UnboundMethodT[EventManagerBaseT]], UnboundMethodT[EventManagerBaseT]]: + """Add metadata to a consumer method to indicate when it should be unmarshalled and dispatched. + + Parameters + ---------- + event_types + Types of the events this raw consumer method may dispatch. + This may either be a singular type of a sequence of types. + + Other Parameters + ---------------- + cache_components : hikari.config.CacheComponents + Bitfield of the cache components this event may make altering calls to. + This defaults to `hikari.config.CacheComponents.NONE`. + """ + if isinstance(event_types, typing.Sequence): + # dict.fromkeys is used to remove any duplicate entries here + event_types = tuple(dict.fromkeys(itertools.chain.from_iterable(e.dispatches() for e in event_types))) + + else: + event_types = event_types.dispatches() + + def decorator(method: UnboundMethodT[EventManagerBaseT], /) -> UnboundMethodT[EventManagerBaseT]: + method.__cache_components__ = cache_components # type: ignore[attr-defined] + method.__event_types__ = event_types # type: ignore[attr-defined] + assert isinstance(method, _FilteredMethodT), "Incorrect attribute(s) set for a filtered method" + return method # type: ignore[unreachable] + + return decorator + + +@attr.define(hash=True) +class _Consumer: + callback: ConsumerT = attr.ib(hash=True) + """The callback function for this consumer.""" + + event_types: undefined.UndefinedOr[typing.Sequence[typing.Type[base_events.Event]]] = attr.ib(hash=False) + """A sequence of the types of events this consumer dispatches to, if set.""" + + is_caching: bool = attr.ib(hash=False) + """Cached value of whether or not this consumer is making cache calls in the current env.""" + + class EventManagerBase(event_manager_.EventManager): """Provides functionality to consume and dispatch events. @@ -244,10 +317,24 @@ class EventManagerBase(event_manager_.EventManager): is the raw event name being dispatched in lower-case. """ - __slots__: typing.Sequence[str] = ("_event_factory", "_intents", "_listeners", "_consumers", "_waiters") + __slots__: typing.Sequence[str] = ( + "_consumers", + "_enabled_consumers_cache", + "_event_factory", + "_intents", + "_listeners", + "_waiters", + ) - def __init__(self, event_factory: event_factory_.EventFactory, intents: intents_.Intents) -> None: - self._consumers: typing.Dict[str, ConsumerT] = {} + def __init__( + self, + event_factory: event_factory_.EventFactory, + intents: intents_.Intents, + *, + cache_components: config.CacheComponents = config.CacheComponents.NONE, + ) -> None: + self._consumers: typing.Dict[str, _Consumer] = {} + self._enabled_consumers_cache: typing.Dict[_Consumer, bool] = {} self._event_factory = event_factory self._intents = intents self._listeners: ListenerMapT[base_events.Event] = {} @@ -255,15 +342,52 @@ def __init__(self, event_factory: event_factory_.EventFactory, intents: intents_ for name, member in inspect.getmembers(self): if name.startswith("on_"): - self._consumers[name[3:]] = member + event_name = name[3:] + if isinstance(member, _FilteredMethodT): + caching = (member.__cache_components__ & cache_components) != 0 + self._consumers[event_name] = _Consumer(member, member.__event_types__, caching) + + else: + self._consumers[event_name] = _Consumer( + member, undefined.UNDEFINED, cache_components != cache_components.NONE + ) + + def _clear_enabled_cache(self) -> None: + self._enabled_consumers_cache = {} + + def _enabled_for_event(self, event_type: typing.Type[base_events.Event], /) -> bool: + for cls in event_type.dispatches(): + if cls in self._listeners or cls in self._waiters: + return True + + return False + + def _enabled_for_consumer(self, consumer: _Consumer, /) -> bool: + # If undefined then we can only assume that this may link to registered listeners. + if consumer.event_types is undefined.UNDEFINED or consumer.is_caching: + return True + + if (cached_value := self._enabled_consumers_cache.get(consumer)) is not None: + return cached_value + + # The behaviour here where an empty sequence for event_types will lead to this always + # being skipped unless there's a relevant enabled cache resource is intended behaviour. + for event_type in consumer.event_types: + if event_type in self._listeners or event_type in self._waiters: + self._enabled_consumers_cache[consumer] = True + return True + + self._enabled_consumers_cache[consumer] = False + return False def consume_raw_event( self, event_name: str, shard: gateway_shard.GatewayShard, payload: data_binding.JSONObject ) -> None: - payload_event = self._event_factory.deserialize_shard_payload_event(shard, payload, name=event_name) - self.dispatch(payload_event) - callback = self._consumers[event_name.casefold()] - asyncio.create_task(self._handle_dispatch(callback, shard, payload), name=f"dispatch {event_name}") + if self._enabled_for_event(shard_events.ShardPayloadEvent): + payload_event = self._event_factory.deserialize_shard_payload_event(shard, payload, name=event_name) + self.dispatch(payload_event) + consumer = self._consumers[event_name.lower()] + asyncio.create_task(self._handle_dispatch(consumer, shard, payload), name=f"dispatch {event_name}") def subscribe( self, @@ -282,9 +406,6 @@ def subscribe( # warning is triggered. self._check_intents(event_type, _nested) - if event_type not in self._listeners: - self._listeners[event_type] = [] - _LOGGER.debug( "subscribing callback 'async def %s%s' to event-type %s.%s", getattr(callback, "__name__", ""), @@ -293,7 +414,11 @@ def subscribe( event_type.__qualname__, ) - self._listeners[event_type].append(callback) # type: ignore[arg-type] + try: + self._listeners[event_type].append(callback) # type: ignore[arg-type] + except KeyError: + self._listeners[event_type] = [callback] # type: ignore[list-item] + self._clear_enabled_cache() def _check_intents(self, event_type: typing.Type[event_manager_.EventT_co], nested: int) -> None: # Collection of combined bitfield combinations of intents that @@ -327,10 +452,10 @@ def get_listeners( if issubclass(subscribed_event_type, event_type): listeners += subscribed_listeners return listeners + else: - items = self._listeners.get(event_type) - if items is not None: - return items[:] + if items := self._listeners.get(event_type): + return items.copy() return [] @@ -339,7 +464,7 @@ def unsubscribe( event_type: typing.Type[event_manager_.EventT_co], callback: event_manager_.CallbackT[event_manager_.EventT_co], ) -> None: - if event_type in self._listeners: + if listeners := self._listeners.get(event_type): _LOGGER.debug( "unsubscribing callback %s%s from event-type %s.%s", getattr(callback, "__name__", ""), @@ -347,9 +472,10 @@ def unsubscribe( event_type.__module__, event_type.__qualname__, ) - self._listeners[event_type].remove(callback) # type: ignore[arg-type] - if not self._listeners[event_type]: + listeners.remove(callback) # type: ignore[arg-type] + if not listeners: del self._listeners[event_type] + self._clear_enabled_cache() def listen( self, @@ -387,16 +513,12 @@ def dispatch(self, event: event_manager_.EventT_inv) -> asyncio.Future[typing.An if not isinstance(event, base_events.Event): raise TypeError(f"Events must be subclasses of {base_events.Event.__name__}, not {type(event).__name__}") - # We only need to iterate through the MRO until we hit Event, as - # anything after that is random garbage we don't care about, as they do - # not describe event types. This improves efficiency as well. - mro = type(event).mro() - tasks: typing.List[typing.Coroutine[None, typing.Any, None]] = [] + clear_cache = False - for cls in mro[: mro.index(base_events.Event) + 1]: - if cls in self._listeners: - for callback in self._listeners[cls]: + for cls in event.dispatches(): + if listeners := self._listeners.get(cls): + for callback in listeners: tasks.append(self._invoke_callback(callback, event)) if cls not in self._waiters: @@ -417,6 +539,13 @@ def dispatch(self, event: event_manager_.EventT_inv) -> asyncio.Future[typing.An waiter_set.remove(waiter) + if not waiter_set: + del self._waiters[cls] + clear_cache = True + + if clear_cache: + self._clear_enabled_cache() + return asyncio.gather(*tasks) if tasks else aio.completed_future() def stream( @@ -436,7 +565,6 @@ async def wait_for( timeout: typing.Union[float, int, None], predicate: typing.Optional[event_manager_.PredicateT[event_manager_.EventT_co]] = None, ) -> event_manager_.EventT_co: - if predicate is None: predicate = _default_predicate @@ -447,6 +575,7 @@ async def wait_for( try: waiter_set = self._waiters[event_type] except KeyError: + self._clear_enabled_cache() waiter_set = set() self._waiters[event_type] = waiter_set @@ -457,16 +586,27 @@ async def wait_for( return await asyncio.wait_for(future, timeout=timeout) except asyncio.TimeoutError: waiter_set.remove(pair) # type: ignore[arg-type] + if not waiter_set: + del self._waiters[event_type] + self._clear_enabled_cache() + raise - @staticmethod async def _handle_dispatch( - callback: ConsumerT, + self, + consumer: _Consumer, shard: gateway_shard.GatewayShard, payload: data_binding.JSONObject, ) -> None: + if not self._enabled_for_consumer(consumer): + name = consumer.callback.__name__ + _LOGGER.log( + ux.TRACE, "Skipping raw dispatch for %s due to lack of any registered listeners or cache need", name + ) + return + try: - await callback(shard, payload) + await consumer.callback(shard, payload) except asyncio.CancelledError: # Skip cancelled errors, likely caused by the event loop being shut down. pass diff --git a/tests/hikari/impl/test_bot.py b/tests/hikari/impl/test_bot.py index f9b37dc0a1..a778de8f02 100644 --- a/tests/hikari/impl/test_bot.py +++ b/tests/hikari/impl/test_bot.py @@ -159,7 +159,9 @@ def test_init(self): assert bot._cache is cache.return_value cache.assert_called_once_with(bot, cache_settings) assert bot._event_manager is event_manager.return_value - event_manager.assert_called_once_with(event_factory.return_value, intents, cache=cache.return_value) + event_manager.assert_called_once_with( + entity_factory.return_value, event_factory.return_value, intents, cache=cache.return_value + ) assert bot._entity_factory is entity_factory.return_value entity_factory.assert_called_once_with(bot) assert bot._event_factory is event_factory.return_value diff --git a/tests/hikari/impl/test_cache.py b/tests/hikari/impl/test_cache.py index 715802ed20..4dc83155e3 100644 --- a/tests/hikari/impl/test_cache.py +++ b/tests/hikari/impl/test_cache.py @@ -1471,6 +1471,13 @@ def test_set_me(self, cache_impl): assert cache_impl._me == mock_own_user assert cache_impl._me is not mock_own_user + def test_set_me_when_not_enabled(self, cache_impl): + cache_impl._settings.components = 0 + + cache_impl.set_me(object()) + + assert cache_impl._me is None + def test_update_me_for_cached_me(self, cache_impl): mock_cached_own_user = mock.MagicMock(users.OwnUser) mock_own_user = mock.MagicMock(users.OwnUser) @@ -1489,6 +1496,17 @@ def test_update_me_for_uncached_me(self, cache_impl): assert result == (None, mock_own_user) assert cache_impl._me == mock_own_user + def test_update_me_for_when_not_enabled(self, cache_impl): + cache_impl._settings.components = 0 + cache_impl.get_me = mock.Mock() + cache_impl.set_me = mock.Mock() + + result = cache_impl.update_me(object()) + + assert result == (None, None) + cache_impl.get_me.assert_not_called() + cache_impl.set_me.assert_not_called() + def test__build_member(self, cache_impl): mock_user = mock.MagicMock(users.User) member_data = cache_utilities.MemberData( diff --git a/tests/hikari/impl/test_entity_factory.py b/tests/hikari/impl/test_entity_factory.py index 92ce8008e0..890b5892b7 100644 --- a/tests/hikari/impl/test_entity_factory.py +++ b/tests/hikari/impl/test_entity_factory.py @@ -50,6 +50,197 @@ from hikari.interactions import base_interactions from hikari.interactions import command_interactions from hikari.interactions import component_interactions +from tests.hikari import hikari_test_helpers + + +@pytest.fixture() +def mock_app() -> traits.RESTAware: + return mock.MagicMock(traits.RESTAware) + + +@pytest.fixture() +def permission_overwrite_payload(): + return {"id": "4242", "type": 1, "allow": 65, "deny": 49152, "allow_new": "65", "deny_new": "49152"} + + +@pytest.fixture() +def guild_text_channel_payload(permission_overwrite_payload): + return { + "id": "123", + "guild_id": "567", + "name": "general", + "type": 0, + "position": 6, + "permission_overwrites": [permission_overwrite_payload], + "rate_limit_per_user": 2, + "nsfw": True, + "topic": "¯\\_(ツ)_/¯", + "last_message_id": "123456", + "last_pin_timestamp": "2020-05-27T15:58:51.545252+00:00", + "parent_id": "987", + } + + +@pytest.fixture() +def guild_voice_channel_payload(permission_overwrite_payload): + return { + "id": "555", + "guild_id": "789", + "name": "Secret Developer Discussions", + "type": 2, + "nsfw": True, + "position": 4, + "permission_overwrites": [permission_overwrite_payload], + "bitrate": 64000, + "user_limit": 3, + "rtc_region": "europe", + "parent_id": "456", + "video_quality_mode": 1, + } + + +@pytest.fixture() +def guild_news_channel_payload(permission_overwrite_payload): + return { + "id": "7777", + "guild_id": "123", + "name": "Important Announcements", + "type": 5, + "position": 0, + "permission_overwrites": [permission_overwrite_payload], + "nsfw": True, + "topic": "Super Important Announcements", + "last_message_id": "456", + "parent_id": "654", + "last_pin_timestamp": "2020-05-27T15:58:51.545252+00:00", + } + + +@pytest.fixture() +def user_payload(): + return { + "id": "115590097100865541", + "username": "nyaa", + "avatar": "b3b24c6d7cbcdec129d5d537067061a8", + "banner": "a_221313e1e2edsncsncsmcndsc", + "accent_color": 231321, + "discriminator": "6127", + "bot": True, + "system": True, + "public_flags": int(user_models.UserFlag.EARLY_VERIFIED_DEVELOPER), + } + + +@pytest.fixture() +def custom_emoji_payload(): + return {"id": "691225175349395456", "name": "test", "animated": True} + + +@pytest.fixture() +def known_custom_emoji_payload(user_payload): + return { + "id": "12345", + "name": "testing", + "animated": False, + "available": True, + "roles": ["123", "456"], + "user": user_payload, + "require_colons": True, + "managed": False, + } + + +@pytest.fixture() +def member_payload(user_payload): + return { + "nick": "foobarbaz", + "roles": ["11111", "22222", "33333", "44444"], + "joined_at": "2015-04-26T06:26:56.936000+00:00", + "premium_since": "2019-05-17T06:26:56.936000+00:00", + "avatar": "estrogen", + "deaf": False, + "mute": True, + "pending": False, + "user": user_payload, + } + + +@pytest.fixture() +def presence_activity_payload(custom_emoji_payload): + return { + "name": "an activity", + "type": 1, + "url": "https://69.420.owouwunyaa", + "created_at": 1584996792798, + "timestamps": {"start": 1584996792798, "end": 1999999792798}, + "application_id": "40404040404040", + "details": "They are doing stuff", + "state": "STATED", + "emoji": custom_emoji_payload, + "party": {"id": "spotify:3234234234", "size": [2, 5]}, + "assets": { + "large_image": "34234234234243", + "large_text": "LARGE TEXT", + "small_image": "3939393", + "small_text": "small text", + }, + "secrets": {"join": "who's a good secret?", "spectate": "I'm a good secret", "match": "No."}, + "instance": True, + "flags": 3, + "buttons": ["owo", "no"], + } + + +@pytest.fixture() +def member_presence_payload(user_payload, presence_activity_payload): + return { + "user": user_payload, + "activity": presence_activity_payload, + "guild_id": "44004040", + "status": "dnd", + "activities": [presence_activity_payload], + "client_status": {"desktop": "online", "mobile": "idle", "web": "dnd"}, + } + + +@pytest.fixture() +def guild_role_payload(): + return { + "id": "41771983423143936", + "name": "WE DEM BOYZZ!!!!!!", + "color": 3_447_003, + "hoist": True, + "unicode_emoji": "\N{OK HAND SIGN}", + "icon": "abc123hash", + "position": 0, + "permissions": "66321471", + "managed": False, + "mentionable": False, + "tags": { + "bot_id": "123", + "integration_id": "456", + "premium_subscriber": None, + }, + } + + +@pytest.fixture() +def voice_state_payload(member_payload): + return { + "guild_id": "929292929292992", + "channel_id": "157733188964188161", + "user_id": "115590097100865541", + "member": member_payload, + "session_id": "90326bd25d71d39b9ef95b299e3872ff", + "deaf": True, + "mute": True, + "self_deaf": False, + "self_mute": True, + "self_stream": True, + "self_video": True, + "suppress": False, + "request_to_speak_timestamp": "2021-04-17T10:11:19.970105+00:00", + } def test__with_int_cast(): @@ -86,11 +277,364 @@ def test__deserialize_max_age_returns_null(): assert entity_factory._deserialize_max_age(0) is None -class TestEntityFactoryImpl: +class TestGatewayGuildDefinition: @pytest.fixture() - def mock_app(self) -> traits.RESTAware: - return mock.MagicMock(traits.RESTAware) + def entity_factory_impl(self, mock_app) -> entity_factory.EntityFactoryImpl: + return hikari_test_helpers.mock_class_namespace(entity_factory.EntityFactoryImpl, slots_=False)(mock_app) + + def test_id_property(self, entity_factory_impl): + guild_definition = entity_factory_impl.deserialize_gateway_guild({"id": "123123451234"}) + + assert guild_definition.id == 123123451234 + + def test_channels( + self, entity_factory_impl, guild_text_channel_payload, guild_voice_channel_payload, guild_news_channel_payload + ): + guild_definition = entity_factory_impl.deserialize_gateway_guild( + { + "id": "265828729970753537", + "channels": [guild_text_channel_payload, guild_voice_channel_payload, guild_news_channel_payload], + } + ) + + assert guild_definition.channels() == { + 123: entity_factory_impl.deserialize_guild_text_channel( + guild_text_channel_payload, guild_id=snowflakes.Snowflake(265828729970753537) + ), + 555: entity_factory_impl.deserialize_guild_voice_channel( + guild_voice_channel_payload, guild_id=snowflakes.Snowflake(265828729970753537) + ), + 7777: entity_factory_impl.deserialize_guild_news_channel( + guild_news_channel_payload, guild_id=snowflakes.Snowflake(265828729970753537) + ), + } + + def test_channels_returns_cached_values(self, entity_factory_impl): + guild_definition = entity_factory_impl.deserialize_gateway_guild({"id": "265828729970753537"}) + mock_channel = object() + guild_definition._channels = {"123321": mock_channel} + entity_factory_impl.deserialize_guild_text_channel = mock.Mock() + entity_factory_impl.deserialize_guild_voice_channel = mock.Mock() + entity_factory_impl.deserialize_guild_news_channel = mock.Mock() + + assert guild_definition.channels() == {"123321": mock_channel} + + entity_factory_impl.deserialize_guild_text_channel.assert_not_called() + entity_factory_impl.deserialize_guild_voice_channel.assert_not_called() + entity_factory_impl.deserialize_guild_news_channel.assert_not_called() + + def test_channels_ignores_unrecognised_channels(self, entity_factory_impl): + guild_definition = entity_factory_impl.deserialize_gateway_guild( + {"id": "9494949", "channels": [{"id": 123, "type": 1000}]} + ) + + assert guild_definition.channels() == {} + + def test_emojis(self, entity_factory_impl, known_custom_emoji_payload): + guild_definition = entity_factory_impl.deserialize_gateway_guild( + {"id": "265828729970753537", "emojis": [known_custom_emoji_payload]}, + ) + + assert guild_definition.emojis() == { + 12345: entity_factory_impl.deserialize_known_custom_emoji( + known_custom_emoji_payload, guild_id=snowflakes.Snowflake(265828729970753537) + ) + } + + def test_emojis_returns_cached_values(self, entity_factory_impl): + mock_emoji = object() + entity_factory_impl.deserialize_known_custom_emoji = mock.Mock() + guild_definition = entity_factory_impl.deserialize_gateway_guild({"id": "265828729970753537"}) + guild_definition._emojis = {"21323232": mock_emoji} + + assert guild_definition.emojis() == {"21323232": mock_emoji} + + entity_factory_impl.deserialize_known_custom_emoji.assert_not_called() + + def test_guild(self, entity_factory_impl, mock_app): + guild_definition = entity_factory_impl.deserialize_gateway_guild( + { + "afk_channel_id": "99998888777766", + "afk_timeout": 1200, + "application_id": "39494949", + "banner": "1a2b3c", + "default_message_notifications": 1, + "description": "This is a server I guess, its a bit crap though", + "discovery_splash": "famfamFAMFAMfam", + "embed_channel_id": "9439394949", + "embed_enabled": True, + "explicit_content_filter": 2, + "features": ["ANIMATED_ICON", "MORE_EMOJI", "NEWS", "SOME_UNDOCUMENTED_FEATURE"], + "icon": "1a2b3c4d", + "id": "265828729970753537", + "joined_at": "2019-05-17T06:26:56.936000+00:00", + "large": False, + "max_members": 25000, + "max_presences": 250, + "max_video_channel_users": 25, + "member_count": 14, + "mfa_level": 1, + "name": "L33t guild", + "owner_id": "6969696", + "preferred_locale": "en-GB", + "premium_subscription_count": 1, + "premium_tier": 2, + "public_updates_channel_id": "33333333", + "rules_channel_id": "42042069", + "splash": "0ff0ff0ff", + "system_channel_flags": 3, + "system_channel_id": "19216801", + "unavailable": False, + "vanity_url_code": "loool", + "verification_level": 4, + "widget_channel_id": "9439394949", + "widget_enabled": True, + "nsfw_level": 0, + } + ) + + guild = guild_definition.guild() + assert guild.app is mock_app + assert guild.id == 265828729970753537 + assert guild.name == "L33t guild" + assert guild.icon_hash == "1a2b3c4d" + assert guild.features == [ + guild_models.GuildFeature.ANIMATED_ICON, + guild_models.GuildFeature.MORE_EMOJI, + guild_models.GuildFeature.NEWS, + "SOME_UNDOCUMENTED_FEATURE", + ] + assert guild.splash_hash == "0ff0ff0ff" + assert guild.discovery_splash_hash == "famfamFAMFAMfam" + assert guild.owner_id == 6969696 + assert guild.afk_channel_id == 99998888777766 + assert guild.afk_timeout == datetime.timedelta(seconds=1200) + assert guild.verification_level == guild_models.GuildVerificationLevel.VERY_HIGH + assert guild.default_message_notifications == guild_models.GuildMessageNotificationsLevel.ONLY_MENTIONS + assert guild.explicit_content_filter == guild_models.GuildExplicitContentFilterLevel.ALL_MEMBERS + assert guild.mfa_level == guild_models.GuildMFALevel.ELEVATED + assert guild.application_id == 39494949 + assert guild.widget_channel_id == 9439394949 + assert guild.is_widget_enabled is True + assert guild.system_channel_id == 19216801 + assert guild.system_channel_flags == guild_models.GuildSystemChannelFlag(3) + assert guild.rules_channel_id == 42042069 + assert guild.joined_at == datetime.datetime(2019, 5, 17, 6, 26, 56, 936000, tzinfo=datetime.timezone.utc) + assert guild.is_large is False + assert guild.member_count == 14 + assert guild.max_video_channel_users == 25 + assert guild.vanity_url_code == "loool" + assert guild.description == "This is a server I guess, its a bit crap though" + assert guild.banner_hash == "1a2b3c" + assert guild.premium_tier == guild_models.GuildPremiumTier.TIER_2 + assert guild.premium_subscription_count == 1 + assert guild.preferred_locale == "en-GB" + assert guild.public_updates_channel_id == 33333333 + assert guild.nsfw_level == guild_models.GuildNSFWLevel.DEFAULT + def test_guild_with_unset_fields(self, entity_factory_impl): + guild_definition = entity_factory_impl.deserialize_gateway_guild( + { + "afk_channel_id": "99998888777766", + "afk_timeout": 1200, + "application_id": "39494949", + "banner": "1a2b3c", + "default_message_notifications": 1, + "description": "This is a server I guess, its a bit crap though", + "discovery_splash": "famfamFAMFAMfam", + "emojis": [], + "explicit_content_filter": 2, + "features": ["ANIMATED_ICON", "MORE_EMOJI", "NEWS", "SOME_UNDOCUMENTED_FEATURE"], + "icon": "1a2b3c4d", + "id": "265828729970753537", + "mfa_level": 1, + "name": "L33t guild", + "owner_id": "6969696", + "preferred_locale": "en-GB", + "premium_tier": 2, + "public_updates_channel_id": "33333333", + "roles": [], + "rules_channel_id": "42042069", + "splash": "0ff0ff0ff", + "system_channel_flags": 3, + "system_channel_id": "19216801", + "vanity_url_code": "loool", + "verification_level": 4, + "nsfw_level": 0, + }, + ) + guild = guild_definition.guild() + assert guild.joined_at is None + assert guild.is_large is None + assert guild.max_video_channel_users is None + assert guild.member_count is None + assert guild.premium_subscription_count is None + assert guild.widget_channel_id is None + assert guild.is_widget_enabled is None + + def test_guild_with_null_fields(self, entity_factory_impl): + guild_definition = entity_factory_impl.deserialize_gateway_guild( + { + "afk_channel_id": None, + "afk_timeout": 1200, + "application_id": None, + "banner": None, + "channels": [], + "default_message_notifications": 1, + "description": None, + "discovery_splash": None, + "embed_channel_id": None, + "embed_enabled": True, + "emojis": [], + "explicit_content_filter": 2, + "features": ["ANIMATED_ICON", "MORE_EMOJI", "NEWS", "SOME_UNDOCUMENTED_FEATURE"], + "icon": None, + "id": "265828729970753537", + "joined_at": "2019-05-17T06:26:56.936000+00:00", + "large": False, + "max_members": 25000, + "max_presences": None, + "max_video_channel_users": 25, + "member_count": 14, + "members": [], + "mfa_level": 1, + "name": "L33t guild", + "owner_id": "6969696", + "permissions": 66_321_471, + "preferred_locale": "en-GB", + "premium_subscription_count": None, + "premium_tier": 2, + "presences": [], + "public_updates_channel_id": None, + "roles": [], + "rules_channel_id": None, + "splash": None, + "system_channel_flags": 3, + "system_channel_id": None, + "unavailable": False, + "vanity_url_code": None, + "verification_level": 4, + "voice_states": [], + "widget_channel_id": None, + "widget_enabled": True, + "nsfw_level": 0, + }, + ) + guild = guild_definition.guild() + assert guild.icon_hash is None + assert guild.splash_hash is None + assert guild.discovery_splash_hash is None + assert guild.afk_channel_id is None + assert guild.application_id is None + assert guild.widget_channel_id is None + assert guild.system_channel_id is None + assert guild.rules_channel_id is None + assert guild.vanity_url_code is None + assert guild.description is None + assert guild.banner_hash is None + assert guild.premium_subscription_count is None + assert guild.public_updates_channel_id is None + + def test_guild_returns_cached_values(self, entity_factory_impl): + mock_guild = object() + entity_factory_impl.set_guild_attributes = mock.Mock() + guild_definition = entity_factory_impl.deserialize_gateway_guild({"id": "9393939"}) + guild_definition._guild = mock_guild + + assert guild_definition.guild() is mock_guild + + entity_factory_impl.set_guild_attributes.assert_not_called() + + def test_members(self, entity_factory_impl, member_payload): + guild_definition = entity_factory_impl.deserialize_gateway_guild( + {"id": "265828729970753537", "members": [member_payload]} + ) + + assert guild_definition.members() == { + 115590097100865541: entity_factory_impl.deserialize_member( + member_payload, guild_id=snowflakes.Snowflake(265828729970753537) + ) + } + + def test_members_returns_cached_values(self, entity_factory_impl): + mock_member = object() + entity_factory_impl.deserialize_member = mock.Mock() + guild_definition = entity_factory_impl.deserialize_gateway_guild({"id": "92929292"}) + guild_definition._members = {"93939393": mock_member} + + assert guild_definition.members() == {"93939393": mock_member} + + entity_factory_impl.deserialize_member.assert_not_called() + + def test_presences(self, entity_factory_impl, member_presence_payload): + guild_definition = entity_factory_impl.deserialize_gateway_guild( + {"id": "265828729970753537", "presences": [member_presence_payload]} + ) + + assert guild_definition.presences() == { + 115590097100865541: entity_factory_impl.deserialize_member_presence( + member_presence_payload, guild_id=snowflakes.Snowflake(265828729970753537) + ) + } + + def test_presences_returns_cached_values(self, entity_factory_impl): + mock_presence = object() + entity_factory_impl.deserialize_member_presence = mock.Mock() + guild_definition = entity_factory_impl.deserialize_gateway_guild({"id": "29292992"}) + guild_definition._presences = {"3939393993": mock_presence} + + assert guild_definition.presences() == {"3939393993": mock_presence} + + entity_factory_impl.deserialize_member_presence.assert_not_called() + + def test_roles(self, entity_factory_impl, guild_role_payload): + guild_definition = entity_factory_impl.deserialize_gateway_guild( + {"id": "265828729970753537", "roles": [guild_role_payload]} + ) + + assert guild_definition.roles() == { + 41771983423143936: entity_factory_impl.deserialize_role( + guild_role_payload, guild_id=snowflakes.Snowflake(265828729970753537) + ) + } + + def test_roles_returns_cached_values(self, entity_factory_impl): + mock_role = object() + entity_factory_impl.deserialize_role = mock.Mock() + guild_definition = entity_factory_impl.deserialize_gateway_guild({"id": "9292929"}) + guild_definition._roles = {"32132123123": mock_role} + + assert guild_definition.roles() == {"32132123123": mock_role} + + entity_factory_impl.deserialize_role.assert_not_called() + + def test_voice_states(self, entity_factory_impl, member_payload, voice_state_payload): + guild_definition = entity_factory_impl.deserialize_gateway_guild( + {"id": "265828729970753537", "voice_states": [voice_state_payload], "members": [member_payload]} + ) + assert guild_definition.voice_states() == { + 115590097100865541: entity_factory_impl.deserialize_voice_state( + voice_state_payload, + guild_id=snowflakes.Snowflake(265828729970753537), + member=entity_factory_impl.deserialize_member( + member_payload, + guild_id=snowflakes.Snowflake(265828729970753537), + ), + ) + } + + def test_voice_states_returns_cached_values(self, entity_factory_impl): + mock_voice_state = object() + entity_factory_impl.deserialize_voice_state = mock.Mock() + guild_definition = entity_factory_impl.deserialize_gateway_guild({"id": "292929"}) + guild_definition._voice_states = {"9393939393": mock_voice_state} + + assert guild_definition.voice_states() == {"9393939393": mock_voice_state} + + entity_factory_impl.deserialize_voice_state.assert_not_called() + + +class TestEntityFactoryImpl: @pytest.fixture() def entity_factory_impl(self, mock_app) -> entity_factory.EntityFactoryImpl: return entity_factory.EntityFactoryImpl(app=mock_app) @@ -763,10 +1307,6 @@ def test_deserialize_channel_follow(self, entity_factory_impl, mock_app): assert follow.channel_id == 41231 assert follow.webhook_id == 939393 - @pytest.fixture() - def permission_overwrite_payload(self): - return {"id": "4242", "type": 1, "allow": 65, "deny": 49152, "allow_new": "65", "deny_new": "49152"} - @pytest.mark.parametrize("type", [0, 1]) def test_deserialize_permission_overwrite(self, entity_factory_impl, type): permission_overwrite_payload = { @@ -937,23 +1477,6 @@ def test_deserialize_guild_category_with_null_fields(self, entity_factory_impl, ) assert guild_category.parent_id is None - @pytest.fixture() - def guild_text_channel_payload(self, permission_overwrite_payload): - return { - "id": "123", - "guild_id": "567", - "name": "general", - "type": 0, - "position": 6, - "permission_overwrites": [permission_overwrite_payload], - "rate_limit_per_user": 2, - "nsfw": True, - "topic": "¯\\_(ツ)_/¯", - "last_message_id": "123456", - "last_pin_timestamp": "2020-05-27T15:58:51.545252+00:00", - "parent_id": "987", - } - def test_deserialize_guild_text_channel( self, entity_factory_impl, mock_app, guild_text_channel_payload, permission_overwrite_payload ): @@ -1017,22 +1540,6 @@ def test_deserialize_guild_text_channel_with_null_fields(self, entity_factory_im assert guild_text_channel.last_pin_timestamp is None assert guild_text_channel.parent_id is None - @pytest.fixture() - def guild_news_channel_payload(self, permission_overwrite_payload): - return { - "id": "7777", - "guild_id": "123", - "name": "Important Announcements", - "type": 5, - "position": 0, - "permission_overwrites": [permission_overwrite_payload], - "nsfw": True, - "topic": "Super Important Announcements", - "last_message_id": "456", - "parent_id": "654", - "last_pin_timestamp": "2020-05-27T15:58:51.545252+00:00", - } - def test_deserialize_guild_news_channel( self, entity_factory_impl, mock_app, guild_news_channel_payload, permission_overwrite_payload ): @@ -1151,23 +1658,6 @@ def test_deserialize_guild_store_channel_with_null_fields(self, entity_factory_i ) assert store_chanel.parent_id is None - @pytest.fixture() - def guild_voice_channel_payload(self, permission_overwrite_payload): - return { - "id": "555", - "guild_id": "789", - "name": "Secret Developer Discussions", - "type": 2, - "nsfw": True, - "position": 4, - "permission_overwrites": [permission_overwrite_payload], - "bitrate": 64000, - "user_limit": 3, - "rtc_region": "europe", - "parent_id": "456", - "video_quality_mode": 1, - } - def test_deserialize_guild_voice_channel( self, entity_factory_impl, mock_app, guild_voice_channel_payload, permission_overwrite_payload ): @@ -1678,10 +2168,6 @@ def test_deserialize_unicode_emoji(self, entity_factory_impl): assert emoji.name == "🤷" assert isinstance(emoji, emoji_models.UnicodeEmoji) - @pytest.fixture() - def custom_emoji_payload(self): - return {"id": "691225175349395456", "name": "test", "animated": True} - def test_deserialize_custom_emoji(self, entity_factory_impl, mock_app, custom_emoji_payload): emoji = entity_factory_impl.deserialize_custom_emoji(custom_emoji_payload) assert emoji.id == snowflakes.Snowflake(691225175349395456) @@ -1696,19 +2182,6 @@ def test_deserialize_custom_emoji_with_unset_and_null_fields( assert emoji.is_animated is False assert emoji.name is None - @pytest.fixture() - def known_custom_emoji_payload(self, user_payload): - return { - "id": "12345", - "name": "testing", - "animated": False, - "available": True, - "roles": ["123", "456"], - "user": user_payload, - "require_colons": True, - "managed": False, - } - def test_deserialize_known_custom_emoji( self, entity_factory_impl, mock_app, user_payload, known_custom_emoji_payload ): @@ -1870,20 +2343,6 @@ def test_serialize_welcome_channel_with_no_emoji(self, entity_factory_impl, mock assert result == {"channel_id": "4312312", "description": "meow2"} - @pytest.fixture() - def member_payload(self, user_payload): - return { - "nick": "foobarbaz", - "roles": ["11111", "22222", "33333", "44444"], - "joined_at": "2015-04-26T06:26:56.936000+00:00", - "premium_since": "2019-05-17T06:26:56.936000+00:00", - "avatar": "estrogen", - "deaf": False, - "mute": True, - "pending": False, - "user": user_payload, - } - def test_deserialize_member(self, entity_factory_impl, mock_app, member_payload, user_payload): member_payload = {**member_payload, "guild_id": "76543325"} member = entity_factory_impl.deserialize_member(member_payload) @@ -1972,26 +2431,6 @@ def test_deserialize_member_with_passed_through_user_object_and_guild_id(self, e assert member.user is mock_user assert member.guild_id == 64234 - @pytest.fixture() - def guild_role_payload(self): - return { - "id": "41771983423143936", - "name": "WE DEM BOYZZ!!!!!!", - "color": 3_447_003, - "hoist": True, - "unicode_emoji": "\N{OK HAND SIGN}", - "icon": "abc123hash", - "position": 0, - "permissions": "66321471", - "managed": False, - "mentionable": False, - "tags": { - "bot_id": "123", - "integration_id": "456", - "premium_subscriber": None, - }, - } - def test_deserialize_role(self, entity_factory_impl, mock_app, guild_role_payload): guild_role = entity_factory_impl.deserialize_role(guild_role_payload, guild_id=snowflakes.Snowflake(76534453)) assert guild_role.app is mock_app @@ -2431,273 +2870,11 @@ def test_deserialize_rest_guild_with_null_fields(self, entity_factory_impl): assert guild.premium_subscription_count is None assert guild.public_updates_channel_id is None - @pytest.fixture() - def deserialize_gateway_guild_payload( - self, - guild_text_channel_payload, - guild_voice_channel_payload, - guild_news_channel_payload, - known_custom_emoji_payload, - member_payload, - member_presence_payload, - guild_role_payload, - voice_state_payload, - ): - return { - "afk_channel_id": "99998888777766", - "afk_timeout": 1200, - "application_id": "39494949", - "banner": "1a2b3c", - "channels": [guild_text_channel_payload, guild_voice_channel_payload, guild_news_channel_payload], - "default_message_notifications": 1, - "description": "This is a server I guess, its a bit crap though", - "discovery_splash": "famfamFAMFAMfam", - "embed_channel_id": "9439394949", - "embed_enabled": True, - "emojis": [known_custom_emoji_payload], - "explicit_content_filter": 2, - "features": ["ANIMATED_ICON", "MORE_EMOJI", "NEWS", "SOME_UNDOCUMENTED_FEATURE"], - "icon": "1a2b3c4d", - "id": "265828729970753537", - "joined_at": "2019-05-17T06:26:56.936000+00:00", - "large": False, - "max_members": 25000, - "max_presences": 250, - "max_video_channel_users": 25, - "member_count": 14, - "members": [member_payload], - "mfa_level": 1, - "name": "L33t guild", - "owner_id": "6969696", - "preferred_locale": "en-GB", - "premium_subscription_count": 1, - "premium_tier": 2, - "presences": [member_presence_payload], - "public_updates_channel_id": "33333333", - "roles": [guild_role_payload], - "rules_channel_id": "42042069", - "splash": "0ff0ff0ff", - "system_channel_flags": 3, - "system_channel_id": "19216801", - "unavailable": False, - "vanity_url_code": "loool", - "verification_level": 4, - "voice_states": [voice_state_payload], - "widget_channel_id": "9439394949", - "widget_enabled": True, - "nsfw_level": 0, - } - - def test_deserialize_gateway_guild( - self, - entity_factory_impl, - mock_app, - deserialize_gateway_guild_payload, - guild_text_channel_payload, - guild_voice_channel_payload, - guild_news_channel_payload, - known_custom_emoji_payload, - member_payload, - member_presence_payload, - guild_role_payload, - voice_state_payload, - ): - guild_definition = entity_factory_impl.deserialize_gateway_guild(deserialize_gateway_guild_payload) - guild = guild_definition.guild - assert guild.app is mock_app - assert guild.id == 265828729970753537 - assert guild.name == "L33t guild" - assert guild.icon_hash == "1a2b3c4d" - assert guild.features == [ - guild_models.GuildFeature.ANIMATED_ICON, - guild_models.GuildFeature.MORE_EMOJI, - guild_models.GuildFeature.NEWS, - "SOME_UNDOCUMENTED_FEATURE", - ] - assert guild.splash_hash == "0ff0ff0ff" - assert guild.discovery_splash_hash == "famfamFAMFAMfam" - assert guild.owner_id == 6969696 - assert guild.afk_channel_id == 99998888777766 - assert guild.afk_timeout == datetime.timedelta(seconds=1200) - assert guild.verification_level == guild_models.GuildVerificationLevel.VERY_HIGH - assert guild.default_message_notifications == guild_models.GuildMessageNotificationsLevel.ONLY_MENTIONS - assert guild.explicit_content_filter == guild_models.GuildExplicitContentFilterLevel.ALL_MEMBERS - assert guild.mfa_level == guild_models.GuildMFALevel.ELEVATED - assert guild.application_id == 39494949 - assert guild.widget_channel_id == 9439394949 - assert guild.is_widget_enabled is True - assert guild.system_channel_id == 19216801 - assert guild.system_channel_flags == guild_models.GuildSystemChannelFlag(3) - assert guild.rules_channel_id == 42042069 - assert guild.joined_at == datetime.datetime(2019, 5, 17, 6, 26, 56, 936000, tzinfo=datetime.timezone.utc) - assert guild.is_large is False - assert guild.member_count == 14 - assert guild.max_video_channel_users == 25 - assert guild.vanity_url_code == "loool" - assert guild.description == "This is a server I guess, its a bit crap though" - assert guild.banner_hash == "1a2b3c" - assert guild.premium_tier == guild_models.GuildPremiumTier.TIER_2 - assert guild.premium_subscription_count == 1 - assert guild.preferred_locale == "en-GB" - assert guild.public_updates_channel_id == 33333333 - assert guild.nsfw_level == guild_models.GuildNSFWLevel.DEFAULT - - assert guild_definition.roles == { - 41771983423143936: entity_factory_impl.deserialize_role( - guild_role_payload, guild_id=snowflakes.Snowflake(265828729970753537) - ) - } - assert guild_definition.emojis == { - 12345: entity_factory_impl.deserialize_known_custom_emoji( - known_custom_emoji_payload, guild_id=snowflakes.Snowflake(265828729970753537) - ) - } - assert guild_definition.members == { - 115590097100865541: entity_factory_impl.deserialize_member( - member_payload, guild_id=snowflakes.Snowflake(265828729970753537) - ) - } - assert guild_definition.channels == { - 123: entity_factory_impl.deserialize_guild_text_channel( - guild_text_channel_payload, guild_id=snowflakes.Snowflake(265828729970753537) - ), - 555: entity_factory_impl.deserialize_guild_voice_channel( - guild_voice_channel_payload, guild_id=snowflakes.Snowflake(265828729970753537) - ), - 7777: entity_factory_impl.deserialize_guild_news_channel( - guild_news_channel_payload, guild_id=snowflakes.Snowflake(265828729970753537) - ), - } - assert guild_definition.presences == { - 115590097100865541: entity_factory_impl.deserialize_member_presence( - member_presence_payload, guild_id=snowflakes.Snowflake(265828729970753537) - ) - } - assert guild_definition.voice_states == { - 115590097100865541: entity_factory_impl.deserialize_voice_state( - voice_state_payload, - guild_id=snowflakes.Snowflake(265828729970753537), - member=entity_factory_impl.deserialize_member( - member_payload, - guild_id=snowflakes.Snowflake(265828729970753537), - ), - ) - } - - def test_deserialize_gateway_guild_with_unset_fields(self, entity_factory_impl): - guild_definition = entity_factory_impl.deserialize_gateway_guild( - { - "afk_channel_id": "99998888777766", - "afk_timeout": 1200, - "application_id": "39494949", - "banner": "1a2b3c", - "default_message_notifications": 1, - "description": "This is a server I guess, its a bit crap though", - "discovery_splash": "famfamFAMFAMfam", - "emojis": [], - "explicit_content_filter": 2, - "features": ["ANIMATED_ICON", "MORE_EMOJI", "NEWS", "SOME_UNDOCUMENTED_FEATURE"], - "icon": "1a2b3c4d", - "id": "265828729970753537", - "mfa_level": 1, - "name": "L33t guild", - "owner_id": "6969696", - "preferred_locale": "en-GB", - "premium_tier": 2, - "public_updates_channel_id": "33333333", - "roles": [], - "rules_channel_id": "42042069", - "splash": "0ff0ff0ff", - "system_channel_flags": 3, - "system_channel_id": "19216801", - "vanity_url_code": "loool", - "verification_level": 4, - "nsfw_level": 0, - } - ) - guild = guild_definition.guild - assert guild.joined_at is None - assert guild.is_large is None - assert guild.max_video_channel_users is None - assert guild.member_count is None - assert guild.premium_subscription_count is None - assert guild.widget_channel_id is None - assert guild.is_widget_enabled is None - assert guild_definition.channels is None - assert guild_definition.members is None - assert guild_definition.presences is None - assert guild_definition.voice_states is None - - def test_deserialize_gateway_guild_with_null_fields(self, entity_factory_impl): - guild_definition = entity_factory_impl.deserialize_gateway_guild( - { - "afk_channel_id": None, - "afk_timeout": 1200, - "application_id": None, - "banner": None, - "channels": [], - "default_message_notifications": 1, - "description": None, - "discovery_splash": None, - "embed_channel_id": None, - "embed_enabled": True, - "emojis": [], - "explicit_content_filter": 2, - "features": ["ANIMATED_ICON", "MORE_EMOJI", "NEWS", "SOME_UNDOCUMENTED_FEATURE"], - "icon": None, - "id": "265828729970753537", - "joined_at": "2019-05-17T06:26:56.936000+00:00", - "large": False, - "max_members": 25000, - "max_presences": None, - "max_video_channel_users": 25, - "member_count": 14, - "members": [], - "mfa_level": 1, - "name": "L33t guild", - "owner_id": "6969696", - "permissions": 66_321_471, - "preferred_locale": "en-GB", - "premium_subscription_count": None, - "premium_tier": 2, - "presences": [], - "public_updates_channel_id": None, - "roles": [], - "rules_channel_id": None, - "splash": None, - "system_channel_flags": 3, - "system_channel_id": None, - "unavailable": False, - "vanity_url_code": None, - "verification_level": 4, - "voice_states": [], - "widget_channel_id": None, - "widget_enabled": True, - "nsfw_level": 0, - } - ) - guild = guild_definition.guild - assert guild.icon_hash is None - assert guild.splash_hash is None - assert guild.discovery_splash_hash is None - assert guild.afk_channel_id is None - assert guild.application_id is None - assert guild.widget_channel_id is None - assert guild.system_channel_id is None - assert guild.rules_channel_id is None - assert guild.vanity_url_code is None - assert guild.description is None - assert guild.banner_hash is None - assert guild.premium_subscription_count is None - assert guild.public_updates_channel_id is None + def test_deserialize_gateway_guild(self, entity_factory_impl): + guild_definition = entity_factory_impl.deserialize_gateway_guild({"id": "123123"}) - def test_deserialize_gateway_guild_ignores_unrecognised_channels( - self, entity_factory_impl, deserialize_gateway_guild_payload - ): - deserialize_gateway_guild_payload["channels"] = [{"id": 123, "type": 1000}] - guild_definition = entity_factory_impl.deserialize_gateway_guild(deserialize_gateway_guild_payload) - - assert guild_definition.channels == {} + assert guild_definition.id == 123123 + assert guild_definition._payload == {"id": "123123"} ###################### # INTERACTION MODELS # @@ -4308,42 +4485,6 @@ def test_deserialize_message_deserializes_old_stickers_field(self, entity_factor # PRESENCE MODELS # ################### - @pytest.fixture() - def presence_activity_payload(self, custom_emoji_payload): - return { - "name": "an activity", - "type": 1, - "url": "https://69.420.owouwunyaa", - "created_at": 1584996792798, - "timestamps": {"start": 1584996792798, "end": 1999999792798}, - "application_id": "40404040404040", - "details": "They are doing stuff", - "state": "STATED", - "emoji": custom_emoji_payload, - "party": {"id": "spotify:3234234234", "size": [2, 5]}, - "assets": { - "large_image": "34234234234243", - "large_text": "LARGE TEXT", - "small_image": "3939393", - "small_text": "small text", - }, - "secrets": {"join": "who's a good secret?", "spectate": "I'm a good secret", "match": "No."}, - "instance": True, - "flags": 3, - "buttons": ["owo", "no"], - } - - @pytest.fixture() - def member_presence_payload(self, user_payload, presence_activity_payload): - return { - "user": user_payload, - "activity": presence_activity_payload, - "guild_id": "44004040", - "status": "dnd", - "activities": [presence_activity_payload], - "client_status": {"desktop": "online", "mobile": "idle", "web": "dnd"}, - } - def test_deserialize_member_presence( self, entity_factory_impl, mock_app, member_presence_payload, custom_emoji_payload, user_payload ): @@ -4685,20 +4826,6 @@ def test_deserialize_template_with_null_fields(self, entity_factory_impl, templa # USER MODELS # ############### - @pytest.fixture() - def user_payload(self): - return { - "id": "115590097100865541", - "username": "nyaa", - "avatar": "b3b24c6d7cbcdec129d5d537067061a8", - "banner": "a_221313e1e2edsncsncsmcndsc", - "accent_color": 231321, - "discriminator": "6127", - "bot": True, - "system": True, - "public_flags": int(user_models.UserFlag.EARLY_VERIFIED_DEVELOPER), - } - def test_deserialize_user(self, entity_factory_impl, mock_app, user_payload): user = entity_factory_impl.deserialize_user(user_payload) assert user.app is mock_app @@ -4794,24 +4921,6 @@ def test_deserialize_my_user_with_unset_fields(self, entity_factory_impl, mock_a # VOICE MODELS # ################ - @pytest.fixture() - def voice_state_payload(self, member_payload): - return { - "guild_id": "929292929292992", - "channel_id": "157733188964188161", - "user_id": "115590097100865541", - "member": member_payload, - "session_id": "90326bd25d71d39b9ef95b299e3872ff", - "deaf": True, - "mute": True, - "self_deaf": False, - "self_mute": True, - "self_stream": True, - "self_video": True, - "suppress": False, - "request_to_speak_timestamp": "2021-04-17T10:11:19.970105+00:00", - } - def test_deserialize_voice_state_with_guild_id_in_payload( self, entity_factory_impl, mock_app, voice_state_payload, member_payload ): diff --git a/tests/hikari/impl/test_event_factory.py b/tests/hikari/impl/test_event_factory.py index a601fa0d01..49360f3c5a 100644 --- a/tests/hikari/impl/test_event_factory.py +++ b/tests/hikari/impl/test_event_factory.py @@ -209,17 +209,24 @@ def test_deserialize_guild_available_event(self, event_factory, mock_app, mock_s mock_payload = mock.Mock(app=mock_app) event = event_factory.deserialize_guild_available_event(mock_shard, mock_payload) - mock_app.entity_factory.deserialize_gateway_guild.assert_called_once_with(mock_payload) assert isinstance(event, guild_events.GuildAvailableEvent) assert event.shard is mock_shard - assert event.guild is mock_app.entity_factory.deserialize_gateway_guild.return_value.guild - assert event.emojis is mock_app.entity_factory.deserialize_gateway_guild.return_value.emojis - assert event.roles is mock_app.entity_factory.deserialize_gateway_guild.return_value.roles - assert event.channels is mock_app.entity_factory.deserialize_gateway_guild.return_value.channels - assert event.members is mock_app.entity_factory.deserialize_gateway_guild.return_value.members - assert event.presences is mock_app.entity_factory.deserialize_gateway_guild.return_value.presences - assert event.voice_states is mock_app.entity_factory.deserialize_gateway_guild.return_value.voice_states + guild_definition = mock_app.entity_factory.deserialize_gateway_guild.return_value + assert event.guild is guild_definition.guild.return_value + assert event.emojis is guild_definition.emojis.return_value + assert event.roles is guild_definition.roles.return_value + assert event.channels is guild_definition.channels.return_value + assert event.members is guild_definition.members.return_value + assert event.presences is guild_definition.presences.return_value + assert event.voice_states is guild_definition.voice_states.return_value + guild_definition.guild.assert_called_once_with() + guild_definition.emojis.assert_called_once_with() + guild_definition.roles.assert_called_once_with() + guild_definition.channels.assert_called_once_with() + guild_definition.members.assert_called_once_with() + guild_definition.presences.assert_called_once_with() + guild_definition.voice_states.assert_called_once_with() def test_deserialize_guild_join_event(self, event_factory, mock_app, mock_shard): mock_payload = mock.Mock(app=mock_app) @@ -246,10 +253,14 @@ def test_deserialize_guild_update_event(self, event_factory, mock_app, mock_shar mock_app.entity_factory.deserialize_gateway_guild.assert_called_once_with(mock_payload) assert isinstance(event, guild_events.GuildUpdateEvent) assert event.shard is mock_shard - assert event.guild is mock_app.entity_factory.deserialize_gateway_guild.return_value.guild - assert event.emojis is mock_app.entity_factory.deserialize_gateway_guild.return_value.emojis - assert event.roles is mock_app.entity_factory.deserialize_gateway_guild.return_value.roles + guild_definition = mock_app.entity_factory.deserialize_gateway_guild.return_value + assert event.guild is guild_definition.guild.return_value + assert event.emojis is guild_definition.emojis.return_value + assert event.roles is guild_definition.roles.return_value assert event.old_guild is mock_old_guild + guild_definition.guild.assert_called_once_with() + guild_definition.emojis.assert_called_once_with() + guild_definition.roles.assert_called_once_with() def test_deserialize_guild_leave_event(self, event_factory, mock_app, mock_shard): mock_payload = {"id": "43123123"} diff --git a/tests/hikari/impl/test_event_manager.py b/tests/hikari/impl/test_event_manager.py index 9840fc76dd..9d985de274 100644 --- a/tests/hikari/impl/test_event_manager.py +++ b/tests/hikari/impl/test_event_manager.py @@ -29,9 +29,12 @@ import pytest from hikari import channels +from hikari import config from hikari import errors from hikari import intents from hikari import presences +from hikari.events import guild_events +from hikari.events import shard_events from hikari.impl import event_manager from hikari.internal import time from tests.hikari import hikari_test_helpers @@ -85,23 +88,27 @@ async def test__request_guild_members_handles_state_conflict_error(shard): class TestEventManagerImpl: + @pytest.fixture() + def entity_factory(self): + return mock.Mock() + @pytest.fixture() def event_factory(self): return mock.Mock() @pytest.fixture() - def event_manager(self, event_factory): + def event_manager(self, entity_factory, event_factory): obj = hikari_test_helpers.mock_class_namespace(event_manager.EventManagerImpl, slots_=False)( - event_factory, intents.Intents.ALL, cache=mock.Mock() + entity_factory, event_factory, intents.Intents.ALL, cache=mock.Mock(settings=config.CacheSettings()) ) obj.dispatch = mock.AsyncMock() return obj @pytest.fixture() - def stateless_event_manager(self, event_factory): + def stateless_event_manager(self, event_factory, entity_factory): obj = hikari_test_helpers.mock_class_namespace(event_manager.EventManagerImpl, slots_=False)( - event_factory, intents.Intents.ALL, cache=None + entity_factory, event_factory, intents.Intents.ALL, cache=None ) obj.dispatch = mock.AsyncMock() @@ -272,8 +279,7 @@ async def test_on_guild_create_stateful_with_unavailable_field(self, event_manag event_factory.deserialize_guild_available_event.assert_called_once_with(shard, payload) event_manager.dispatch.assert_awaited_once_with(event) - @pytest.mark.asyncio() - async def test_on_guild_create_stateful_without_unavailable_field(self, event_manager, shard, event_factory): + async def test_on_guild_create_stateful_and_dispatching(self, event_manager, shard, event_factory): payload = {} event = mock.Mock( guild=mock.Mock(id=123, is_large=False), @@ -286,6 +292,8 @@ async def test_on_guild_create_stateful_without_unavailable_field(self, event_ma chunk_nonce=None, ) + event_factory.deserialize_guild_join_event.return_value = event + event_manager._enabled_for_event = mock.Mock(return_value=True) event_factory.deserialize_guild_join_event.return_value = event shard.request_guild_members = mock.AsyncMock() @@ -293,6 +301,9 @@ async def test_on_guild_create_stateful_without_unavailable_field(self, event_ma assert event.chunk_nonce is None shard.request_guild_members.assert_not_called() + event_manager._enabled_for_event.assert_has_calls( + [mock.call(guild_events.GuildAvailableEvent), mock.call(shard_events.MemberChunkEvent)] + ) event_manager._cache.update_guild.assert_called_once_with(event.guild) @@ -314,6 +325,8 @@ async def test_on_guild_create_stateful_without_unavailable_field(self, event_ma event_manager._cache.clear_voice_states_for_guild.assert_called_once_with(123) event_manager._cache.set_voice_state.assert_called_once_with(345) + event_factory.deserialize_guild_join_event.assert_called_once_with(shard, payload) + event_factory.deserialize_gateway_guild.assert_not_called() event_factory.deserialize_guild_join_event.assert_called_once_with(shard, payload) event_manager.dispatch.assert_awaited_once_with(event) @@ -378,10 +391,105 @@ async def test_on_guild_create_when_request_chunks_with_unavailable_field( event_manager.dispatch.assert_awaited_once_with(event) @pytest.mark.asyncio() - async def test_on_guild_create_when_request_chunks_without_unavailable_field( + async def test_on_guild_create_stateful_and_not_dispatching_with_all_cache_components( + self, event_manager, shard, entity_factory, event_factory + ): + payload = {"id": "123"} + mock_channel = object() + mock_emoji = object() + mock_role = object() + mock_member = object() + mock_presence = object() + mock_voice_state = object() + guild_definition = entity_factory.deserialize_gateway_guild.return_value + guild_definition.id = 123 + guild_definition.guild.return_value = mock.Mock(id=123, is_large=False) + guild_definition.channels.return_value = {456: mock_channel} + guild_definition.emojis.return_value = {789: mock_emoji} + guild_definition.roles.return_value = {1234: mock_role} + guild_definition.members.return_value = {5678: mock_member} + guild_definition.presences.return_value = {9012: mock_presence} + guild_definition.voice_states.return_value = {345: mock_voice_state} + + event_manager._enabled_for_event = mock.Mock(return_value=False) + shard.request_guild_members = mock.AsyncMock() + + await event_manager.on_guild_create(shard, payload) + + shard.request_guild_members.assert_not_called() + event_manager._enabled_for_event.assert_has_calls( + [mock.call(guild_events.GuildAvailableEvent), mock.call(shard_events.MemberChunkEvent)] + ) + + event_manager._cache.update_guild.assert_called_once_with(guild_definition.guild.return_value) + + event_manager._cache.clear_guild_channels_for_guild.assert_called_once_with(123) + event_manager._cache.set_guild_channel.assert_called_once_with(mock_channel) + + event_manager._cache.clear_emojis_for_guild.assert_called_once_with(123) + event_manager._cache.set_emoji.assert_called_once_with(mock_emoji) + + event_manager._cache.clear_roles_for_guild.assert_called_once_with(123) + event_manager._cache.set_role.assert_called_once_with(mock_role) + + event_manager._cache.clear_members_for_guild.assert_called_once_with(123) + event_manager._cache.set_member.assert_called_once_with(mock_member) + + event_manager._cache.clear_presences_for_guild.assert_called_once_with(123) + event_manager._cache.set_presence.assert_called_once_with(mock_presence) + + event_manager._cache.clear_voice_states_for_guild.assert_called_once_with(123) + event_manager._cache.set_voice_state.assert_called_once_with(mock_voice_state) + + entity_factory.deserialize_gateway_guild.assert_called_once_with(payload) + event_factory.deserialize_guild_create_event.assert_not_called() + event_manager.dispatch.assert_not_called() + + @pytest.mark.asyncio() + async def test_on_guild_create_stateful_and_not_dispatching_with_no_cache_components( + self, event_manager, shard, event_factory, entity_factory + ): + payload = {"id": "123"} + event_manager._cache_enabled_for = mock.Mock(return_value=False) + event_manager._enabled_for_event = mock.Mock(return_value=False) + shard.request_guild_members = mock.AsyncMock() + + await event_manager.on_guild_create(shard, payload) + + shard.request_guild_members.assert_not_called() + event_manager._enabled_for_event.assert_has_calls( + [mock.call(guild_events.GuildAvailableEvent), mock.call(shard_events.MemberChunkEvent)] + ) + + event_manager._cache.update_guild.assert_not_called() + + event_manager._cache.clear_guild_channels_for_guild.assert_not_called() + event_manager._cache.set_guild_channel.assert_not_called() + + event_manager._cache.clear_emojis_for_guild.assert_not_called() + event_manager._cache.set_emoji.assert_not_called() + + event_manager._cache.clear_roles_for_guild.assert_not_called() + event_manager._cache.set_role.assert_not_called() + + event_manager._cache.clear_members_for_guild.assert_not_called() + event_manager._cache.set_member.assert_not_called() + + event_manager._cache.clear_presences_for_guild.assert_not_called() + event_manager._cache.set_presence.assert_not_called() + + event_manager._cache.clear_voice_states_for_guild.assert_not_called() + event_manager._cache.set_voice_state.assert_not_called() + + entity_factory.deserialize_gateway_guild.assert_called_once_with(payload) + event_factory.deserialize_guild_create_event.assert_not_called() + event_manager.dispatch.assert_not_called() + + @pytest.mark.asyncio() + async def test_on_guild_create_when_request_chunks_when_dispatching_available_event( self, event_manager, shard, event_factory ): - payload = {} + payload = {"large": True, "id": 123} event = mock.Mock( guild=mock.Mock(id=123, is_large=True), channels={"TestChannel": 456}, @@ -393,7 +501,9 @@ async def test_on_guild_create_when_request_chunks_without_unavailable_field( chunk_nonce=None, ) + event_manager._enabled_for_event = mock.Mock(return_value=True) event_factory.deserialize_guild_join_event.return_value = event + event_manager._cache.settings.components = config.CacheComponents.MEMBERS shard.request_guild_members = mock.Mock() stack = contextlib.ExitStack() @@ -406,36 +516,54 @@ async def test_on_guild_create_when_request_chunks_without_unavailable_field( with stack: await event_manager.on_guild_create(shard, payload) + event_manager._enabled_for_event.assert_has_calls( + [mock.call(guild_events.GuildAvailableEvent), mock.call(shard_events.MemberChunkEvent)] + ) uuid.assert_called_once_with() nonce = "987.uuid" assert event.chunk_nonce == nonce - _request_guild_members.assert_called_once_with(shard, event.guild, include_presences=True, nonce=nonce) + _request_guild_members.assert_called_once_with(shard, 123, include_presences=True, nonce=nonce) create_task.assert_called_once_with( _request_guild_members.return_value, name="987:123 guild create members request" ) - event_manager._cache.update_guild.assert_called_once_with(event.guild) + event_factory.deserialize_guild_create_event.assert_called_once_with(shard, payload) - event_manager._cache.clear_guild_channels_for_guild.assert_called_once_with(123) - event_manager._cache.set_guild_channel.assert_called_once_with(456) + @pytest.mark.asyncio() + async def test_on_guild_create_when_request_chunks_when_not_dispatching_available_event( + self, stateless_event_manager, shard, entity_factory, event_factory, event_manager + ): + payload = {"large": True, "id": 123} - event_manager._cache.clear_emojis_for_guild.assert_called_once_with(123) - event_manager._cache.set_emoji.assert_called_once_with(789) + stateless_event_manager._enabled_for_event = mock.Mock(side_effect=[False, True]) + entity_factory.deserialize_gateway_guild.return_value.id = 123 - event_manager._cache.clear_roles_for_guild.assert_called_once_with(123) - event_manager._cache.set_role.assert_called_once_with(1234) + stack = contextlib.ExitStack() + _request_guild_members = stack.enter_context( + mock.patch("hikari.impl.event_manager._request_guild_members", new_callable=mock.Mock) + ) + create_task = stack.enter_context(mock.patch.object(asyncio, "create_task")) + uuid = stack.enter_context(mock.patch("hikari.impl.event_manager._fixed_size_nonce", return_value="uuid")) - event_manager._cache.clear_members_for_guild.assert_called_once_with(123) - event_manager._cache.set_member.assert_called_once_with(5678) + with stack: + await stateless_event_manager.on_guild_create(shard, payload) - event_manager._cache.clear_presences_for_guild.assert_called_once_with(123) - event_manager._cache.set_presence.assert_called_once_with(9012) + stateless_event_manager._enabled_for_event.assert_has_calls( + [mock.call(guild_events.GuildAvailableEvent), mock.call(shard_events.MemberChunkEvent)] + ) + uuid.assert_called_once_with() + _request_guild_members.assert_called_once_with(shard, 123, include_presences=True, nonce="987.uuid") + create_task.assert_called_once_with( + _request_guild_members.return_value, name="987:123 guild create members request" + ) + + event_factory.deserialize_guild_join_event.assert_not_called() + entity_factory.deserialize_gateway_guild.assert_not_called() event_manager._cache.clear_voice_states_for_guild.assert_called_once_with(123) event_manager._cache.set_voice_state.assert_called_once_with(345) event_factory.deserialize_guild_join_event.assert_called_once_with(shard, payload) - event_manager.dispatch.assert_awaited_once_with(event) @pytest.mark.asyncio() async def test_on_guild_create_stateless_with_unavailable_field( @@ -452,27 +580,52 @@ async def test_on_guild_create_stateless_with_unavailable_field( event_factory.deserialize_guild_available_event.return_value ) - @pytest.mark.asyncio() - async def test_on_guild_create_stateless_without_unavailable_field( - self, stateless_event_manager, shard, event_factory + async def test_on_guild_create_stateless_and_dispatching( + self, stateless_event_manager, shard, event_factory, entity_factory ): - payload = {} + payload = {"id": "123123"} + stateless_event_manager._enabled_for_event = mock.Mock(return_value=True) shard.request_guild_members = mock.AsyncMock() await stateless_event_manager.on_guild_create(shard, payload) + shard.request_guild_members.assert_not_called() + stateless_event_manager._enabled_for_event.assert_has_calls( + [mock.call(guild_events.GuildAvailableEvent), mock.call(shard_events.MemberChunkEvent)] + ) + entity_factory.deserialize_gateway_guild.assert_not_called() event_factory.deserialize_guild_join_event.assert_called_once_with(shard, payload) stateless_event_manager.dispatch.assert_awaited_once_with( event_factory.deserialize_guild_join_event.return_value ) @pytest.mark.asyncio() - async def test_on_guild_update_stateful(self, event_manager, shard, event_factory): + async def test_on_guild_create_stateless_and_not_dispatching( + self, stateless_event_manager, shard, event_factory, entity_factory + ): + payload = {"id": "123123"} + stateless_event_manager._enabled_for_event = mock.Mock(return_value=False) + + shard.request_guild_members = mock.AsyncMock() + + await stateless_event_manager.on_guild_create(shard, payload) + + shard.request_guild_members.assert_not_called() + stateless_event_manager._enabled_for_event.assert_has_calls( + [mock.call(guild_events.GuildAvailableEvent), mock.call(shard_events.MemberChunkEvent)] + ) + entity_factory.deserialize_gateway_guild.assert_not_called() + event_factory.deserialize_guild_create_event.assert_not_called() + stateless_event_manager.dispatch.assert_not_called() + + @pytest.mark.asyncio() + async def test_on_guild_update_stateful_and_dispatching(self, event_manager, shard, event_factory, entity_factory): payload = {"id": 123} old_guild = object() mock_role = object() mock_emoji = object() + event_manager._enabled_for_event = mock.Mock(return_value=True) event = mock.Mock(roles={555: mock_role}, emojis={333: mock_emoji}, guild=mock.Mock(id=123)) event_factory.deserialize_guild_update_event.return_value = event @@ -480,26 +633,98 @@ async def test_on_guild_update_stateful(self, event_manager, shard, event_factor await event_manager.on_guild_update(shard, payload) + event_manager._enabled_for_event.assert_called_once_with(guild_events.GuildUpdateEvent) event_manager._cache.get_guild.assert_called_once_with(123) event_manager._cache.update_guild.assert_called_once_with(event.guild) event_manager._cache.clear_roles_for_guild.assert_called_once_with(123) event_manager._cache.set_role.assert_called_once_with(mock_role) event_manager._cache.clear_emojis_for_guild.assert_called_once_with(123) event_manager._cache.set_emoji.assert_called_once_with(mock_emoji) + entity_factory.deserialize_gateway_guild.assert_not_called() event_factory.deserialize_guild_update_event.assert_called_once_with(shard, payload, old_guild=old_guild) event_manager.dispatch.assert_awaited_once_with(event) @pytest.mark.asyncio() - async def test_on_guild_update_stateless(self, stateless_event_manager, shard, event_factory): + async def test_on_guild_update_all_cache_components_and_not_dispatching( + self, event_manager, shard, event_factory, entity_factory + ): + payload = {"id": 123} + mock_role = object() + mock_emoji = object() + event_manager._enabled_for_event = mock.Mock(return_value=False) + guild_definition = entity_factory.deserialize_gateway_guild.return_value + guild_definition.id = 123 + guild_definition.emojis.return_value = {0: mock_emoji} + guild_definition.roles.return_value = {1: mock_role} + + await event_manager.on_guild_update(shard, payload) + + entity_factory.deserialize_gateway_guild.assert_called_once_with({"id": 123}) + event_manager._enabled_for_event.assert_called_once_with(guild_events.GuildUpdateEvent) + event_manager._cache.update_guild.assert_called_once_with(guild_definition.guild.return_value) + event_manager._cache.clear_emojis_for_guild.assert_called_once_with(123) + event_manager._cache.set_emoji.assert_called_once_with(mock_emoji) + event_manager._cache.clear_roles_for_guild.assert_called_once_with(123) + event_manager._cache.set_role.assert_called_once_with(mock_role) + event_factory.deserialize_guild_update_event.assert_not_called() + event_manager.dispatch.assert_not_called() + guild_definition.emojis.assert_called_once_with() + guild_definition.roles.assert_called_once_with() + guild_definition.guild.assert_called_once_with() + + @pytest.mark.asyncio() + async def test_on_guild_update_no_cache_components_and_not_dispatching( + self, event_manager, shard, event_factory, entity_factory + ): + payload = {"id": 123} + event_manager._cache_enabled_for = mock.Mock(return_value=False) + event_manager._enabled_for_event = mock.Mock(return_value=False) + guild_definition = entity_factory.deserialize_gateway_guild.return_value + + await event_manager.on_guild_update(shard, payload) + + entity_factory.deserialize_gateway_guild.assert_called_once_with({"id": 123}) + event_manager._enabled_for_event.assert_called_once_with(guild_events.GuildUpdateEvent) + event_manager._cache.update_guild.assert_not_called() + event_manager._cache.clear_emojis_for_guild.assert_not_called() + event_manager._cache.set_emoji.assert_not_called() + event_manager._cache.clear_roles_for_guild.assert_not_called() + event_manager._cache.set_role.assert_not_called() + event_factory.deserialize_guild_update_event.assert_not_called() + event_manager.dispatch.assert_not_called() + guild_definition.emojis.assert_not_called() + guild_definition.roles.assert_not_called() + guild_definition.guild.assert_not_called() + + @pytest.mark.asyncio() + async def test_on_guild_update_stateless_and_dispatching( + self, stateless_event_manager, shard, event_factory, entity_factory + ): payload = {"id": 123} + stateless_event_manager._enabled_for_event = mock.Mock(return_value=True) await stateless_event_manager.on_guild_update(shard, payload) + stateless_event_manager._enabled_for_event.assert_called_once_with(guild_events.GuildUpdateEvent) + entity_factory.deserialize_gateway_guild.assert_not_called() event_factory.deserialize_guild_update_event.assert_called_once_with(shard, payload, old_guild=None) stateless_event_manager.dispatch.assert_awaited_once_with( event_factory.deserialize_guild_update_event.return_value ) + @pytest.mark.asyncio() + async def test_on_guild_update_stateless_and_not_dispatching( + self, stateless_event_manager, shard, entity_factory, event_factory + ): + stateless_event_manager._enabled_for_event = mock.Mock(return_value=False) + + await stateless_event_manager.on_guild_update(shard, {"id": 123}) + + stateless_event_manager._enabled_for_event.assert_called_once_with(guild_events.GuildUpdateEvent) + entity_factory.deserialize_gateway_guild.assert_not_called() + event_factory.deserialize_guild_update_event.assert_not_called() + stateless_event_manager.dispatch.assert_not_called() + @pytest.mark.asyncio() async def test_on_guild_delete_stateful_when_available(self, event_manager, shard, event_factory): payload = {"unavailable": False, "id": "123"} @@ -611,7 +836,8 @@ async def test_on_guild_emojis_update_stateless(self, stateless_event_manager, s @pytest.mark.asyncio() async def test_on_guild_integrations_update(self, event_manager, shard): - assert await event_manager.on_guild_integrations_update(shard, {}) is None + with pytest.raises(NotImplementedError): + await event_manager.on_guild_integrations_update(shard, {}) event_manager.dispatch.assert_not_called() diff --git a/tests/hikari/impl/test_event_manager_base.py b/tests/hikari/impl/test_event_manager_base.py index 8ff38bebc6..aebbfb281b 100644 --- a/tests/hikari/impl/test_event_manager_base.py +++ b/tests/hikari/impl/test_event_manager_base.py @@ -29,11 +29,14 @@ import mock import pytest +from hikari import config from hikari import errors from hikari import intents from hikari import iterators +from hikari import undefined from hikari.events import base_events from hikari.events import member_events +from hikari.events import shard_events from hikari.impl import event_manager_base from hikari.internal import reflect from tests.hikari import hikari_test_helpers @@ -397,20 +400,160 @@ class EventManagerBaseImpl(event_manager_base.EventManagerBase): def test___init___loads_consumers(self): class StubManager(event_manager_base.EventManagerBase): + @event_manager_base.filtered(shard_events.ShardEvent, config.CacheComponents.MEMBERS) async def on_foo(self, event): raise NotImplementedError + @event_manager_base.filtered((shard_events.ShardStateEvent, shard_events.ShardPayloadEvent)) async def on_bar(self, event): raise NotImplementedError + @event_manager_base.filtered(shard_events.MemberChunkEvent, config.CacheComponents.MESSAGES) + async def on_bat(self, event): + raise NotImplementedError + + async def on_not_decorated(self, event): + raise NotImplementedError + async def not_a_listener(self): raise NotImplementedError - manager = StubManager(mock.Mock(), mock.Mock(intents=42)) - assert manager._consumers == {"foo": manager.on_foo, "bar": manager.on_bar} + expected_bar_events = ( + shard_events.ShardStateEvent, + shard_events.ShardEvent, + base_events.Event, + shard_events.ShardPayloadEvent, + ) + expected_bat_events = (shard_events.MemberChunkEvent, shard_events.ShardEvent, base_events.Event) + manager = StubManager( + mock.Mock(), + 0, + cache_components=config.CacheComponents.MEMBERS | config.CacheComponents.GUILD_CHANNELS, + ) + assert manager._consumers == { + "foo": event_manager_base._Consumer(manager.on_foo, (shard_events.ShardEvent, base_events.Event), True), + "bar": event_manager_base._Consumer(manager.on_bar, expected_bar_events, False), + "bat": event_manager_base._Consumer(manager.on_bat, expected_bat_events, False), + "not_decorated": event_manager_base._Consumer(manager.on_not_decorated, undefined.UNDEFINED, True), + } + + def test___init___loads_consumers_when_cacheless(self): + class StubManager(event_manager_base.EventManagerBase): + @event_manager_base.filtered(shard_events.ShardEvent, config.CacheComponents.MEMBERS) + async def on_foo(self, event): + raise NotImplementedError + + @event_manager_base.filtered((shard_events.ShardStateEvent, shard_events.ShardPayloadEvent)) + async def on_bar(self, event): + raise NotImplementedError + + @event_manager_base.filtered(shard_events.MemberChunkEvent, config.CacheComponents.MESSAGES) + async def on_bat(self, event): + raise NotImplementedError + + async def on_not_decorated(self, event): + raise NotImplementedError + + async def not_a_listener(self): + raise NotImplementedError + + expected_bar_events = ( + shard_events.ShardStateEvent, + shard_events.ShardEvent, + base_events.Event, + shard_events.ShardPayloadEvent, + ) + expected_bat_events = (shard_events.MemberChunkEvent, shard_events.ShardEvent, base_events.Event) + manager = StubManager(mock.Mock(), 0, cache_components=config.CacheComponents.NONE) + assert manager._consumers == { + "foo": event_manager_base._Consumer(manager.on_foo, (shard_events.ShardEvent, base_events.Event), False), + "bar": event_manager_base._Consumer(manager.on_bar, expected_bar_events, False), + "bat": event_manager_base._Consumer(manager.on_bat, expected_bat_events, False), + "not_decorated": event_manager_base._Consumer(manager.on_not_decorated, undefined.UNDEFINED, False), + } + + def test__clear_enabled_cache(self): + event_manager = hikari_test_helpers.mock_class_namespace(event_manager_base.EventManagerBase, init_=False)() + event_manager._enabled_consumers_cache = {object: object(), "ok": object()} + + event_manager._clear_enabled_cache() + + assert event_manager._enabled_consumers_cache == {} + + def test__enabled_for_event_when_listener_registered(self, event_manager): + event_manager._listeners = {} + + def test__enabled_for_event_when_waiter_registered(self, event_manager): + event_manager._listeners = {} + + def test__enabled_for_event_when_not_registered(self, event_manager): + event_manager._listeners = {shard_events.ShardPayloadEvent: [], shard_events.MemberChunkEvent: []} + + assert event_manager._enabled_for_event(shard_events.ShardStateEvent) is False + + def test__enabled_for_consumer_when_event_types_is_undefined(self, event_manager): + consumer = mock.Mock(event_types=undefined.UNDEFINED) + + assert event_manager._enabled_for_consumer(consumer) is True + + def test__enabled_for_consumer_when_caching(self, event_manager): + consumer = mock.Mock(event_types=(), is_caching=True) + + assert event_manager._enabled_for_consumer(consumer) is True + + @pytest.mark.parametrize("cached_state", [False, True]) + def test__enabled_for_consumer_when_consumer_state_cached(self, event_manager, cached_state): + consumer = mock.Mock(event_types=(), is_caching=False) + event_manager._enabled_consumers_cache[consumer] = cached_state + + assert event_manager._enabled_for_consumer(consumer) is cached_state + + def test__enabled_for_consumer_when_consumer_state_not_cached_and_listeners_present(self, event_manager): + event_manager._listeners[shard_events.MemberChunkEvent] = [] + consumer = mock.Mock( + event_types=(shard_events.ShardEvent, shard_events.ShardPayloadEvent, shard_events.MemberChunkEvent), + is_caching=False, + ) + + result = event_manager._enabled_for_consumer(consumer) + + assert result is True + assert event_manager._enabled_consumers_cache[consumer] is True + + def test__enabled_for_consumer_when_consumer_state_not_cached_and_waiters_present(self, event_manager): + event_manager._waiters[shard_events.MemberChunkEvent] = [] + consumer = mock.Mock( + event_types=(shard_events.ShardEvent, shard_events.ShardPayloadEvent, shard_events.MemberChunkEvent), + is_caching=False, + ) + + result = event_manager._enabled_for_consumer(consumer) + + assert result is True + assert event_manager._enabled_consumers_cache[consumer] is True + + def test__enabled_for_consumer_when_consumer_state_not_cached_and_not_enabled(self, event_manager): + consumer = mock.Mock( + event_types=(shard_events.ShardEvent, shard_events.ShardPayloadEvent, shard_events.MemberChunkEvent), + is_caching=False, + ) + + result = event_manager._enabled_for_consumer(consumer) + + assert result is False + assert event_manager._enabled_consumers_cache[consumer] is False + + def test__enabled_for_consumer_when_consumer_state_not_cached_and_no_event_types(self, event_manager): + consumer = mock.Mock(event_types=(), is_caching=False) + + result = event_manager._enabled_for_consumer(consumer) + + assert result is False + assert event_manager._enabled_consumers_cache[consumer] is False @pytest.mark.asyncio() async def test_consume_raw_event_when_KeyError(self, event_manager): + event_manager._enabled_for_event = mock.Mock(return_value=True) mock_payload = {"id": "3123123123"} mock_shard = mock.Mock(id=123) event_manager._handle_dispatch = mock.Mock() @@ -426,9 +569,11 @@ async def test_consume_raw_event_when_KeyError(self, event_manager): event_manager._event_factory.deserialize_shard_payload_event.assert_called_once_with( mock_shard, mock_payload, name="UNEXISTING_EVENT" ) + event_manager._enabled_for_event.assert_called_once_with(shard_events.ShardPayloadEvent) @pytest.mark.asyncio() async def test_consume_raw_event_when_found(self, event_manager): + event_manager._enabled_for_event = mock.Mock(return_value=True) event_manager._handle_dispatch = mock.Mock() event_manager.dispatch = mock.Mock() on_existing_event = object() @@ -450,43 +595,69 @@ async def test_consume_raw_event_when_found(self, event_manager): event_manager._event_factory.deserialize_shard_payload_event.assert_called_once_with( shard, payload, name="EXISTING_EVENT" ) + event_manager._enabled_for_event.assert_called_once_with(shard_events.ShardPayloadEvent) + + @pytest.mark.asyncio() + async def test_consume_raw_event_skips_raw_dispatch_when_not_enabled(self, event_manager): + event_manager._enabled_for_event = mock.Mock(return_value=False) + event_manager._handle_dispatch = mock.Mock() + event_manager.dispatch = mock.Mock() + on_existing_event = object() + event_manager._consumers = {"existing_event": on_existing_event} + shard = object() + payload = {"berp": "baz"} + + with mock.patch("asyncio.create_task") as create_task: + event_manager.consume_raw_event("EXISTING_EVENT", shard, payload) + + event_manager._handle_dispatch.assert_called_once_with(on_existing_event, shard, {"berp": "baz"}) + create_task.assert_called_once_with( + event_manager._handle_dispatch(on_existing_event, shard, {"berp": "baz"}), + name="dispatch EXISTING_EVENT", + ) + event_manager.dispatch.assert_not_called() + event_manager._event_factory.deserialize_shard_payload_event.vassert_not_called() + event_manager._enabled_for_event.assert_called_once_with(shard_events.ShardPayloadEvent) @pytest.mark.asyncio() async def test_handle_dispatch_invokes_callback(self, event_manager, event_loop): - callback = mock.AsyncMock() + event_manager._enabled_for_consumer = mock.Mock(return_value=True) + consumer = mock.AsyncMock() error_handler = mock.MagicMock() event_loop.set_exception_handler(error_handler) shard = object() pl = {"foo": "bar"} - await event_manager._handle_dispatch(callback, shard, pl) + await event_manager._handle_dispatch(consumer, shard, pl) - callback.assert_awaited_once_with(shard, pl) + consumer.callback.assert_awaited_once_with(shard, pl) error_handler.assert_not_called() @pytest.mark.asyncio() async def test_handle_dispatch_ignores_cancelled_errors(self, event_manager, event_loop): - callback = mock.AsyncMock(side_effect=asyncio.CancelledError) + event_manager._enabled_for_consumer = mock.Mock(return_value=True) + consumer = mock.AsyncMock(side_effect=asyncio.CancelledError) error_handler = mock.MagicMock() event_loop.set_exception_handler(error_handler) shard = object() pl = {"lorem": "ipsum"} - await event_manager._handle_dispatch(callback, shard, pl) + await event_manager._handle_dispatch(consumer, shard, pl) error_handler.assert_not_called() @pytest.mark.asyncio() async def test_handle_dispatch_handles_exceptions(self, event_manager, event_loop): + event_manager._enabled_for_consumer = mock.Mock(return_value=True) exc = Exception("aaaa!") - callback = mock.AsyncMock(side_effect=exc) + consumer = mock.Mock(callback=mock.AsyncMock(side_effect=exc)) error_handler = mock.MagicMock() event_loop.set_exception_handler(error_handler) shard = object() pl = {"i like": "cats"} with mock.patch.object(asyncio, "current_task") as current_task: - await event_manager._handle_dispatch(callback, shard, pl) + await event_manager._handle_dispatch(consumer, shard, pl) error_handler.assert_called_once_with( event_loop, @@ -497,6 +668,20 @@ async def test_handle_dispatch_handles_exceptions(self, event_manager, event_loo }, ) + @pytest.mark.asyncio() + async def test_handle_dispatch_invokes_when_consumer_not_enabled(self, event_manager, event_loop): + event_manager._enabled_for_consumer = mock.Mock(return_value=False) + consumer = mock.Mock(callback=mock.AsyncMock(__name__="ok")) + error_handler = mock.MagicMock() + event_loop.set_exception_handler(error_handler) + shard = object() + pl = {"foo": "bar"} + + await event_manager._handle_dispatch(consumer, shard, pl) + + consumer.callback.assert_not_called() + error_handler.assert_not_called() + def test_subscribe_when_callback_is_not_coroutine(self, event_manager): def test(): ... @@ -512,6 +697,8 @@ async def test(): event_manager.subscribe("test", test) def test_subscribe_when_event_type_not_in_listeners(self, event_manager): + event_manager._clear_enabled_cache = mock.Mock() + async def test(): ... @@ -520,6 +707,7 @@ async def test(): assert event_manager._listeners == {member_events.MemberCreateEvent: [test]} check.assert_called_once_with(member_events.MemberCreateEvent, 1) + event_manager._clear_enabled_cache.assert_called_once_with() def test_subscribe_when_event_type_in_listeners(self, event_manager): async def test(): @@ -528,6 +716,7 @@ async def test(): async def test2(): ... + event_manager._clear_enabled_cache = mock.Mock() event_manager._listeners[member_events.MemberCreateEvent] = [test2] with mock.patch.object(event_manager_base.EventManagerBase, "_check_intents") as check: @@ -535,6 +724,7 @@ async def test2(): assert event_manager._listeners == {member_events.MemberCreateEvent: [test2, test]} check.assert_called_once_with(member_events.MemberCreateEvent, 2) + event_manager._clear_enabled_cache.assert_not_called() def test__check_intents_when_no_intents_required(self, event_manager): event_manager._intents = intents.Intents.ALL @@ -578,7 +768,7 @@ def test__check_intents_when_intents_incorrect(self, event_manager): def test_get_listeners_when_not_event(self, event_manager): assert len(event_manager.get_listeners("test")) == 0 - def test_get_listeners_polimorphic(self, event_manager): + def test_get_listeners_polymorphic(self, event_manager): event_manager._listeners = { base_events.Event: ["this will never appear"], member_events.MemberEvent: ["coroutine0"], @@ -596,17 +786,16 @@ def test_get_listeners_polimorphic(self, event_manager): "coroutine5", ] - def test_get_listeners_no_polimorphic_and_no_results(self, event_manager): + def test_get_listeners_monomorphic_and_no_results(self, event_manager): event_manager._listeners = { member_events.MemberCreateEvent: ["coroutine1", "coroutine2"], member_events.MemberUpdateEvent: ["coroutine3"], member_events.MemberDeleteEvent: ["coroutine4", "coroutine5"], - member_events.MemberDeleteEvent: ["coroutine4", "coroutine5"], } assert event_manager.get_listeners(member_events.MemberEvent, polymorphic=False) == [] - def test_get_listeners_no_polimorphic_and_results(self, event_manager): + def test_get_listeners_monomorphic_and_results(self, event_manager): event_manager._listeners = { member_events.MemberEvent: ["coroutine0"], member_events.MemberCreateEvent: ["coroutine1", "coroutine2"], @@ -633,6 +822,7 @@ async def test(): async def test2(): ... + event_manager._clear_enabled_cache = mock.Mock() event_manager._listeners = { member_events.MemberCreateEvent: [test, test2], member_events.MemberDeleteEvent: [test], @@ -644,16 +834,19 @@ async def test2(): member_events.MemberCreateEvent: [test2], member_events.MemberDeleteEvent: [test], } + event_manager._clear_enabled_cache.assert_not_called() def test_unsubscribe_when_event_type_when_list_empty_after_delete(self, event_manager): async def test(): ... + event_manager._clear_enabled_cache = mock.Mock() event_manager._listeners = {member_events.MemberCreateEvent: [test], member_events.MemberDeleteEvent: [test]} event_manager.unsubscribe(member_events.MemberCreateEvent, test) assert event_manager._listeners == {member_events.MemberDeleteEvent: [test]} + event_manager._clear_enabled_cache.assert_called_once_with() def test_listen_when_no_params(self, event_manager): with pytest.raises(TypeError):