diff --git a/hikari/api/cache.py b/hikari/api/cache.py index 80cc459e90..0ca7980637 100644 --- a/hikari/api/cache.py +++ b/hikari/api/cache.py @@ -32,6 +32,7 @@ if typing.TYPE_CHECKING: from hikari import channels + from hikari import config from hikari import emojis from hikari import guilds from hikari import invites @@ -88,6 +89,11 @@ class Cache(abc.ABC): __slots__: typing.Sequence[str] = () + @property + @abc.abstractmethod + def settings(self) -> config.CacheSettings: + """Get the configured settings for this cache.""" + @abc.abstractmethod def get_dm_channel_id( self, user: snowflakes.SnowflakeishOr[users.PartialUser], / diff --git a/hikari/api/entity_factory.py b/hikari/api/entity_factory.py index 8d127524d9..3f2155e18c 100644 --- a/hikari/api/entity_factory.py +++ b/hikari/api/entity_factory.py @@ -28,10 +28,7 @@ import abc import typing -import attr - from hikari import undefined -from hikari.internal import attr_extensions if typing.TYPE_CHECKING: from hikari import applications as application_models @@ -59,53 +56,56 @@ from hikari.internal import data_binding -@attr_extensions.with_copy -@attr.define(weakref_slot=False) -class GatewayGuildDefinition: - """A structure for handling entities within guild create and update events.""" - - guild: guild_models.GatewayGuild = attr.field() - """Object of the guild the definition is for.""" - - channels: typing.Optional[typing.Mapping[snowflakes.Snowflake, channel_models.GuildChannel]] = attr.field() - """Mapping of channel IDs to the channels that belong to the guild. +class GatewayGuildDefinition(abc.ABC): + """Structure for handling entities within guild create and update events. - Will be `builtins.None` when returned by guild update gateway events rather - than create. + !!! warning + The methods on this class may raise `builtins.LookupError` if called + when the relevant resource isn't available in the inner payload. """ - members: typing.Optional[typing.Mapping[snowflakes.Snowflake, guild_models.Member]] = attr.field() - """Mapping of user IDs to the members that belong to the guild. + __slots__: typing.Sequence[str] = () - Will be `builtins.None` when returned by guild update gateway events rather - than create. + @property + @abc.abstractmethod + def id(self) -> snowflakes.Snowflake: + """ID of the guild the definition is for.""" - !!! note - This may be a partial mapping of members in the guild. - """ + @abc.abstractmethod + def channels(self) -> typing.Mapping[snowflakes.Snowflake, channel_models.GuildChannel]: + """Get a mapping of channel IDs to the channels that belong to the guild.""" - presences: typing.Optional[typing.Mapping[snowflakes.Snowflake, presence_models.MemberPresence]] = attr.field() - """Mapping of user IDs to the presences that are active in the guild. + @abc.abstractmethod + def emojis(self) -> typing.Mapping[snowflakes.Snowflake, emoji_models.KnownCustomEmoji]: + """Get a mapping of emoji IDs to the emojis that belong to the guild.""" - Will be `builtins.None` when returned by guild update gateway events rather - than create. + @abc.abstractmethod + def guild(self) -> guild_models.GatewayGuild: + """Get the object of the guild this definition is for.""" - !!! note - This may be a partial mapping of presences active in the guild. - """ + @abc.abstractmethod + def members(self) -> typing.Mapping[snowflakes.Snowflake, guild_models.Member]: + """Get a mapping of user IDs to the members that belong to the guild. - roles: typing.Mapping[snowflakes.Snowflake, guild_models.Role] = attr.field() - """Mapping of role IDs to the roles that belong to the guild.""" + !!! note + This may be a partial mapping of members in the guild. + """ + + @abc.abstractmethod + def presences(self) -> typing.Mapping[snowflakes.Snowflake, presence_models.MemberPresence]: + """Get a mapping of user IDs to the presences that are active in the guild. - emojis: typing.Mapping[snowflakes.Snowflake, emoji_models.KnownCustomEmoji] = attr.field() - """Mapping of emoji IDs to the emojis that belong to the guild.""" + !!! note + This may be a partial mapping of presences active in the guild. + """ - voice_states: typing.Optional[typing.Mapping[snowflakes.Snowflake, voice_models.VoiceState]] = attr.field() - """Mapping of user IDs to the voice states that are active in the guild. + @abc.abstractmethod + def roles(self) -> typing.Mapping[snowflakes.Snowflake, guild_models.Role]: + """Get a mapping of role IDs to the roles that belong to the guild.""" - !!! note - This may be a partial mapping of voice states active in the guild. - """ + @abc.abstractmethod + def voice_states(self) -> typing.Mapping[snowflakes.Snowflake, voice_models.VoiceState]: + """Get a mapping of user IDs to the voice states that are active in the guild.""" class EntityFactory(abc.ABC): diff --git a/hikari/events/base_events.py b/hikari/events/base_events.py index 0a1f3b2af6..62aeed7af3 100644 --- a/hikari/events/base_events.py +++ b/hikari/events/base_events.py @@ -57,6 +57,23 @@ class Event(abc.ABC): __slots__: typing.Sequence[str] = () + __dispatches: typing.ClassVar[typing.Tuple[typing.Type[Event], ...]] + + def __init_subclass__(cls) -> None: + super().__init_subclass__() + # hasattr doesn't work with private variables in this case so we use a try except. + # We need to set Event's __dispatches when the first subclass is made as Event cannot + # be included in a tuple literal on itself due to not existing yet. + try: + Event.__dispatches + except AttributeError: + Event.__dispatches = (Event,) + + mro = cls.mro() + # We don't have to explicitly include Event here as issubclass(Event, Event) returns True. + # Non-event classes should be ignored. + cls.__dispatches = tuple(cls for cls in mro if issubclass(cls, Event)) + @property @abc.abstractmethod def app(self) -> traits.RESTAware: @@ -68,6 +85,11 @@ def app(self) -> traits.RESTAware: The REST-aware app trait. """ + @classmethod + def dispatches(cls) -> typing.Sequence[typing.Type[Event]]: + """Sequence of the event classes this event is dispatched as.""" + return cls.__dispatches + def get_required_intents_for(event_type: typing.Type[Event]) -> typing.Collection[intents.Intents]: """Retrieve the intents that are required to listen to an event type. diff --git a/hikari/impl/bot.py b/hikari/impl/bot.py index f707ad48d6..fd3b8a1578 100644 --- a/hikari/impl/bot.py +++ b/hikari/impl/bot.py @@ -350,7 +350,9 @@ def __init__( self._event_factory = event_factory_impl.EventFactoryImpl(self) # Event handling - self._event_manager = event_manager_impl.EventManagerImpl(self._event_factory, self._intents, cache=self._cache) + self._event_manager = event_manager_impl.EventManagerImpl( + self._entity_factory, self._event_factory, self._intents, cache=self._cache + ) # Voice subsystem self._voice = voice_impl.VoiceComponentImpl(self) diff --git a/hikari/impl/cache.py b/hikari/impl/cache.py index 3b32fa7b68..3616269b45 100644 --- a/hikari/impl/cache.py +++ b/hikari/impl/cache.py @@ -769,15 +769,19 @@ def get_me(self) -> typing.Optional[users.OwnUser]: return copy.copy(self._me) def set_me(self, user: users.OwnUser, /) -> None: - self._me = copy.copy(user) + if self._is_cache_enabled_for(config.CacheComponents.ME): + _LOGGER.debug("setting my user to %s", user) + self._me = copy.copy(user) def update_me( self, user: users.OwnUser, / ) -> typing.Tuple[typing.Optional[users.OwnUser], typing.Optional[users.OwnUser]]: - _LOGGER.debug("setting my user to %s", user) + if not self._is_cache_enabled_for(config.CacheComponents.ME): + return None, None + cached_user = self.get_me() self.set_me(user) - return cached_user, self._me + return cached_user, self.get_me() def _build_member( self, diff --git a/hikari/impl/entity_factory.py b/hikari/impl/entity_factory.py index 858ffe5f6c..f336e1af0f 100644 --- a/hikari/impl/entity_factory.py +++ b/hikari/impl/entity_factory.py @@ -182,6 +182,150 @@ class _UserFields: is_system: bool = attr.field() +@attr_extensions.with_copy +@attr.define(weakref_slot=False) +class _GatewayGuildDefinition(entity_factory.GatewayGuildDefinition): + """A structure for handling entities within guild create and update events.""" + + id: snowflakes.Snowflake = attr.field() + _payload: data_binding.JSONObject = attr.field() + _entity_factory: EntityFactoryImpl = attr.field() + _channels: typing.Optional[typing.Mapping[snowflakes.Snowflake, channel_models.GuildChannel]] = attr.field( + default=None, init=False + ) + _emojis: typing.Optional[typing.Mapping[snowflakes.Snowflake, emoji_models.KnownCustomEmoji]] = attr.field( + default=None, init=False + ) + _guild: typing.Optional[guild_models.GatewayGuild] = attr.field(default=None) + _members: typing.Optional[typing.Mapping[snowflakes.Snowflake, guild_models.Member]] = attr.field( + default=None, init=False + ) + _presences: typing.Optional[typing.Mapping[snowflakes.Snowflake, presence_models.MemberPresence]] = attr.field( + default=None, init=False + ) + _roles: typing.Optional[typing.Mapping[snowflakes.Snowflake, guild_models.Role]] = attr.field( + default=None, init=False + ) + _voice_states: typing.Optional[typing.Mapping[snowflakes.Snowflake, voice_models.VoiceState]] = attr.field( + default=None, init=False + ) + + def channels(self) -> typing.Mapping[snowflakes.Snowflake, channel_models.GuildChannel]: + if self._channels is None: + self._channels = {} + + for channel_payload in self._payload["channels"]: + try: + channel = self._entity_factory.deserialize_channel(channel_payload, guild_id=self.id) + except errors.UnrecognisedEntityError: + # Ignore the channel, this has already been logged + continue + + assert isinstance(channel, channel_models.GuildChannel) + self._channels[channel.id] = channel + + return self._channels + + def emojis(self) -> typing.Mapping[snowflakes.Snowflake, emoji_models.KnownCustomEmoji]: + if self._emojis is None: + deserialize = self._entity_factory.deserialize_known_custom_emoji + self._emojis = { + snowflakes.Snowflake(emoji["id"]): deserialize(emoji, guild_id=self.id) + for emoji in self._payload["emojis"] + } + + return self._emojis + + def guild(self) -> guild_models.GatewayGuild: + if self._guild is None: + payload = self._payload + guild_fields = self._entity_factory.set_guild_attributes(payload) + is_large = payload.get("large") + joined_at = ( + time.iso8601_datetime_string_to_datetime(payload["joined_at"]) if "joined_at" in payload else None + ) + member_count = int(payload["member_count"]) if "member_count" in payload else None + self._guild = guild_models.GatewayGuild( + app=self._entity_factory.app, + id=guild_fields.id, + name=guild_fields.name, + icon_hash=guild_fields.icon_hash, + features=guild_fields.features, + splash_hash=guild_fields.splash_hash, + discovery_splash_hash=guild_fields.discovery_splash_hash, + owner_id=guild_fields.owner_id, + afk_channel_id=guild_fields.afk_channel_id, + afk_timeout=guild_fields.afk_timeout, + verification_level=guild_fields.verification_level, + default_message_notifications=guild_fields.default_message_notifications, + explicit_content_filter=guild_fields.explicit_content_filter, + mfa_level=guild_fields.mfa_level, + application_id=guild_fields.application_id, + widget_channel_id=guild_fields.widget_channel_id, + system_channel_id=guild_fields.system_channel_id, + is_widget_enabled=guild_fields.is_widget_enabled, + system_channel_flags=guild_fields.system_channel_flags, + rules_channel_id=guild_fields.rules_channel_id, + max_video_channel_users=guild_fields.max_video_channel_users, + vanity_url_code=guild_fields.vanity_url_code, + description=guild_fields.description, + banner_hash=guild_fields.banner_hash, + premium_tier=guild_fields.premium_tier, + premium_subscription_count=guild_fields.premium_subscription_count, + preferred_locale=guild_fields.preferred_locale, + public_updates_channel_id=guild_fields.public_updates_channel_id, + nsfw_level=guild_fields.nsfw_level, + is_large=is_large, + joined_at=joined_at, + member_count=member_count, + ) + + return self._guild + + def members(self) -> typing.Mapping[snowflakes.Snowflake, guild_models.Member]: + if self._members is None: + self._members = {} + + for member_payload in self._payload["members"]: + member = self._entity_factory.deserialize_member(member_payload, guild_id=self.id) + self._members[member.user.id] = member + + return self._members + + def presences(self) -> typing.Mapping[snowflakes.Snowflake, presence_models.MemberPresence]: + if self._presences is None: + self._presences = {} + + for presence_payload in self._payload["presences"]: + presence = self._entity_factory.deserialize_member_presence(presence_payload, guild_id=self.id) + self._presences[presence.user_id] = presence + + return self._presences + + def roles(self) -> typing.Mapping[snowflakes.Snowflake, guild_models.Role]: + if self._roles is None: + self._roles = { + snowflakes.Snowflake(role["id"]): self._entity_factory.deserialize_role(role, guild_id=self.id) + for role in self._payload["roles"] + } + + return self._roles + + def voice_states(self) -> typing.Mapping[snowflakes.Snowflake, voice_models.VoiceState]: + if self._voice_states is None: + members = self.members() + self._voice_states = {} + + for voice_state_payload in self._payload["voice_states"]: + member = members[snowflakes.Snowflake(voice_state_payload["user_id"])] + voice_state = self._entity_factory.deserialize_voice_state( + voice_state_payload, guild_id=self.id, member=member + ) + self._voice_states[voice_state.user_id] = voice_state + + return self._voice_states + + class EntityFactoryImpl(entity_factory.EntityFactory): """Standard implementation for a serializer/deserializer. @@ -296,6 +440,11 @@ def __init__(self, app: traits.RESTAware) -> None: webhook_models.WebhookType.APPLICATION: self.deserialize_application_webhook, } + @property + def app(self) -> traits.RESTAware: + """Object of the application this entity factory is bound to.""" + return self._app + ###################### # APPLICATION MODELS # ###################### @@ -1357,7 +1506,7 @@ def deserialize_guild_preview(self, payload: data_binding.JSONObject) -> guild_m description=payload["description"], ) - def _set_guild_attributes(self, payload: data_binding.JSONObject) -> _GuildFields: + def set_guild_attributes(self, payload: data_binding.JSONObject) -> _GuildFields: afk_channel_id = payload["afk_channel_id"] default_message_notifications = guild_models.GuildMessageNotificationsLevel( payload["default_message_notifications"] @@ -1408,7 +1557,7 @@ def _set_guild_attributes(self, payload: data_binding.JSONObject) -> _GuildField ) def deserialize_rest_guild(self, payload: data_binding.JSONObject) -> guild_models.RESTGuild: - guild_fields = self._set_guild_attributes(payload) + guild_fields = self.set_guild_attributes(payload) approximate_member_count: typing.Optional[int] = None if "approximate_member_count" in payload: @@ -1470,96 +1619,8 @@ def deserialize_rest_guild(self, payload: data_binding.JSONObject) -> guild_mode ) def deserialize_gateway_guild(self, payload: data_binding.JSONObject) -> entity_factory.GatewayGuildDefinition: - guild_fields = self._set_guild_attributes(payload) - is_large = payload.get("large") - joined_at = time.iso8601_datetime_string_to_datetime(payload["joined_at"]) if "joined_at" in payload else None - member_count = int(payload["member_count"]) if "member_count" in payload else None - - guild = guild_models.GatewayGuild( - app=self._app, - id=guild_fields.id, - name=guild_fields.name, - icon_hash=guild_fields.icon_hash, - features=guild_fields.features, - splash_hash=guild_fields.splash_hash, - discovery_splash_hash=guild_fields.discovery_splash_hash, - owner_id=guild_fields.owner_id, - afk_channel_id=guild_fields.afk_channel_id, - afk_timeout=guild_fields.afk_timeout, - verification_level=guild_fields.verification_level, - default_message_notifications=guild_fields.default_message_notifications, - explicit_content_filter=guild_fields.explicit_content_filter, - mfa_level=guild_fields.mfa_level, - application_id=guild_fields.application_id, - widget_channel_id=guild_fields.widget_channel_id, - system_channel_id=guild_fields.system_channel_id, - is_widget_enabled=guild_fields.is_widget_enabled, - system_channel_flags=guild_fields.system_channel_flags, - rules_channel_id=guild_fields.rules_channel_id, - max_video_channel_users=guild_fields.max_video_channel_users, - vanity_url_code=guild_fields.vanity_url_code, - description=guild_fields.description, - banner_hash=guild_fields.banner_hash, - premium_tier=guild_fields.premium_tier, - premium_subscription_count=guild_fields.premium_subscription_count, - preferred_locale=guild_fields.preferred_locale, - public_updates_channel_id=guild_fields.public_updates_channel_id, - nsfw_level=guild_fields.nsfw_level, - is_large=is_large, - joined_at=joined_at, - member_count=member_count, - ) - - members: typing.Optional[typing.Dict[snowflakes.Snowflake, guild_models.Member]] = None - if "members" in payload: - members = {} - - for member_payload in payload["members"]: - member = self.deserialize_member(member_payload, guild_id=guild.id) - members[member.user.id] = member - - channels: typing.Optional[typing.Dict[snowflakes.Snowflake, channel_models.GuildChannel]] = None - if "channels" in payload: - channels = {} - - for channel_payload in payload["channels"]: - try: - channel = self.deserialize_channel(channel_payload, guild_id=guild.id) - except errors.UnrecognisedEntityError: - # Ignore the channel, this has already been logged - continue - - assert isinstance(channel, channel_models.GuildChannel) - channels[channel.id] = channel - - presences: typing.Optional[typing.Dict[snowflakes.Snowflake, presence_models.MemberPresence]] = None - if "presences" in payload: - presences = {} - - for presence_payload in payload["presences"]: - presence = self.deserialize_member_presence(presence_payload, guild_id=guild.id) - presences[presence.user_id] = presence - - voice_states: typing.Optional[typing.Dict[snowflakes.Snowflake, voice_models.VoiceState]] = None - if "voice_states" in payload: - voice_states = {} - assert members is not None - - for voice_state_payload in payload["voice_states"]: - member = members[snowflakes.Snowflake(voice_state_payload["user_id"])] - voice_state = self.deserialize_voice_state(voice_state_payload, guild_id=guild.id, member=member) - voice_states[voice_state.user_id] = voice_state - - roles = { - snowflakes.Snowflake(role["id"]): self.deserialize_role(role, guild_id=guild.id) - for role in payload["roles"] - } - emojis = { - snowflakes.Snowflake(emoji["id"]): self.deserialize_known_custom_emoji(emoji, guild_id=guild.id) - for emoji in payload["emojis"] - } - - return entity_factory.GatewayGuildDefinition(guild, channels, members, presences, roles, emojis, voice_states) + guild_id = snowflakes.Snowflake(payload["id"]) + return _GatewayGuildDefinition(id=guild_id, payload=payload, entity_factory=self) ################# # INVITE MODELS # diff --git a/hikari/impl/event_factory.py b/hikari/impl/event_factory.py index 79e071326f..e6bdea3962 100644 --- a/hikari/impl/event_factory.py +++ b/hikari/impl/event_factory.py @@ -195,19 +195,15 @@ def deserialize_guild_available_event( self, shard: gateway_shard.GatewayShard, payload: data_binding.JSONObject ) -> guild_events.GuildAvailableEvent: guild_information = self._app.entity_factory.deserialize_gateway_guild(payload) - assert guild_information.channels is not None - assert guild_information.members is not None - assert guild_information.presences is not None - assert guild_information.voice_states is not None return guild_events.GuildAvailableEvent( shard=shard, - guild=guild_information.guild, - emojis=guild_information.emojis, - roles=guild_information.roles, - channels=guild_information.channels, - members=guild_information.members, - presences=guild_information.presences, - voice_states=guild_information.voice_states, + guild=guild_information.guild(), + emojis=guild_information.emojis(), + roles=guild_information.roles(), + channels=guild_information.channels(), + members=guild_information.members(), + presences=guild_information.presences(), + voice_states=guild_information.voice_states(), ) def deserialize_guild_join_event( @@ -220,13 +216,13 @@ def deserialize_guild_join_event( assert guild_information.voice_states is not None return guild_events.GuildJoinEvent( shard=shard, - guild=guild_information.guild, - emojis=guild_information.emojis, - roles=guild_information.roles, - channels=guild_information.channels, - members=guild_information.members, - presences=guild_information.presences, - voice_states=guild_information.voice_states, + guild=guild_information.guild(), + emojis=guild_information.emojis(), + roles=guild_information.roles(), + channels=guild_information.channels(), + members=guild_information.members(), + presences=guild_information.presences(), + voice_states=guild_information.voice_states(), ) def deserialize_guild_update_event( @@ -239,9 +235,9 @@ def deserialize_guild_update_event( guild_information = self._app.entity_factory.deserialize_gateway_guild(payload) return guild_events.GuildUpdateEvent( shard=shard, - guild=guild_information.guild, - emojis=guild_information.emojis, - roles=guild_information.roles, + guild=guild_information.guild(), + emojis=guild_information.emojis(), + roles=guild_information.roles(), old_guild=old_guild, ) diff --git a/hikari/impl/event_manager.py b/hikari/impl/event_manager.py index fb2161791d..023ca6ecff 100644 --- a/hikari/impl/event_manager.py +++ b/hikari/impl/event_manager.py @@ -28,27 +28,43 @@ import asyncio import base64 +import logging import random import typing +from hikari import config from hikari import errors from hikari import intents as intents_ -from hikari import presences +from hikari import presences as presences_ from hikari import snowflakes +from hikari.events import channel_events +from hikari.events import guild_events +from hikari.events import member_events +from hikari.events import message_events +from hikari.events import reaction_events +from hikari.events import role_events +from hikari.events import shard_events +from hikari.events import typing_events +from hikari.events import user_events +from hikari.events import voice_events from hikari.impl import event_manager_base from hikari.internal import time +from hikari.internal import ux if typing.TYPE_CHECKING: from hikari import guilds from hikari import invites from hikari import voices from hikari.api import cache as cache_ + from hikari.api import entity_factory as entity_factory_ from hikari.api import event_factory as event_factory_ from hikari.api import shard as gateway_shard - from hikari.events import guild_events as guild_events from hikari.internal import data_binding +_LOGGER: typing.Final[logging.Logger] = logging.getLogger("hikari.event_manager") + + def _fixed_size_nonce() -> str: # This generates nonces of length 28 for use in member chunking. head = time.monotonic_ns().to_bytes(8, "big") @@ -58,7 +74,7 @@ def _fixed_size_nonce() -> str: async def _request_guild_members( shard: gateway_shard.GatewayShard, - guild: guilds.PartialGuild, + guild: snowflakes.SnowflakeishOr[guilds.PartialGuild], *, include_presences: bool, nonce: str, @@ -74,10 +90,11 @@ async def _request_guild_members( class EventManagerImpl(event_manager_base.EventManagerBase): """Provides event handling logic for Discord events.""" - __slots__: typing.Sequence[str] = ("_cache",) + __slots__: typing.Sequence[str] = ("_cache", "_entity_factory") def __init__( self, + entity_factory: entity_factory_.EntityFactory, event_factory: event_factory_.EventFactory, intents: intents_.Intents, /, @@ -85,8 +102,14 @@ def __init__( cache: typing.Optional[cache_.MutableCache] = None, ) -> None: self._cache = cache - super().__init__(event_factory=event_factory, intents=intents) + self._entity_factory = entity_factory + components = cache.settings.components if cache else config.CacheComponents.NONE + super().__init__(event_factory=event_factory, intents=intents, cache_components=components) + def _cache_enabled_for(self, components: config.CacheComponents, /) -> bool: + return self._cache is not None and (self._cache.settings.components & components) == components + + @event_manager_base.filtered(shard_events.ShardReadyEvent, config.CacheComponents.ME) async def on_ready(self, shard: gateway_shard.GatewayShard, payload: data_binding.JSONObject) -> None: """See https://discord.com/developers/docs/topics/gateway#ready for more info.""" # TODO: cache unavailable guilds on startup, I didn't bother for the time being. @@ -97,10 +120,12 @@ async def on_ready(self, shard: gateway_shard.GatewayShard, payload: data_bindin await self.dispatch(event) + @event_manager_base.filtered(shard_events.ShardResumedEvent) async def on_resumed(self, shard: gateway_shard.GatewayShard, _: data_binding.JSONObject) -> None: """See https://discord.com/developers/docs/topics/gateway#resumed for more info.""" await self.dispatch(self._event_factory.deserialize_resumed_event(shard)) + @event_manager_base.filtered(channel_events.GuildChannelCreateEvent, config.CacheComponents.GUILD_CHANNELS) async def on_channel_create(self, shard: gateway_shard.GatewayShard, payload: data_binding.JSONObject) -> None: """See https://discord.com/developers/docs/topics/gateway#channel-create for more info.""" event = self._event_factory.deserialize_guild_channel_create_event(shard, payload) @@ -110,6 +135,7 @@ async def on_channel_create(self, shard: gateway_shard.GatewayShard, payload: da await self.dispatch(event) + @event_manager_base.filtered(channel_events.GuildChannelUpdateEvent, config.CacheComponents.GUILD_CHANNELS) async def on_channel_update(self, shard: gateway_shard.GatewayShard, payload: data_binding.JSONObject) -> None: """See https://discord.com/developers/docs/topics/gateway#channel-update for more info.""" old = self._cache.get_guild_channel(snowflakes.Snowflake(payload["id"])) if self._cache else None @@ -120,6 +146,7 @@ async def on_channel_update(self, shard: gateway_shard.GatewayShard, payload: da await self.dispatch(event) + @event_manager_base.filtered(channel_events.GuildChannelDeleteEvent, config.CacheComponents.GUILD_CHANNELS) async def on_channel_delete(self, shard: gateway_shard.GatewayShard, payload: data_binding.JSONObject) -> None: """See https://discord.com/developers/docs/topics/gateway#channel-delete for more info.""" event = self._event_factory.deserialize_guild_channel_delete_event(shard, payload) @@ -129,87 +156,172 @@ async def on_channel_delete(self, shard: gateway_shard.GatewayShard, payload: da await self.dispatch(event) + @event_manager_base.filtered((channel_events.GuildPinsUpdateEvent, channel_events.DMPinsUpdateEvent)) async def on_channel_pins_update(self, shard: gateway_shard.GatewayShard, payload: data_binding.JSONObject) -> None: """See https://discord.com/developers/docs/topics/gateway#channel-pins-update for more info.""" # TODO: we need a method for this specifically await self.dispatch(self._event_factory.deserialize_channel_pins_update_event(shard, payload)) + # Internal granularity is preferred for GUILD_CREATE over decorator based filtering due to its large cache scope. async def on_guild_create(self, shard: gateway_shard.GatewayShard, payload: data_binding.JSONObject) -> None: """See https://discord.com/developers/docs/topics/gateway#guild-create for more info.""" - event: typing.Union[guild_events.GuildAvailableEvent, guild_events.GuildJoinEvent] + enabled_for_event = self._enabled_for_event(guild_events.GuildAvailableEvent) + if not enabled_for_event and self._cache: + _LOGGER.log(ux.TRACE, "Skipping on_guild_create dispatch due to lack of any registered listeners") + event: typing.Union[guild_events.GuildAvailableEvent, guild_events.GuildJoinEvent, None] = None + gd = self._entity_factory.deserialize_gateway_guild(payload) + + channels = gd.channels() if self._cache_enabled_for(config.CacheComponents.GUILD_CHANNELS) else None + emojis = gd.emojis() if self._cache_enabled_for(config.CacheComponents.EMOJIS) else None + guild = gd.guild() if self._cache_enabled_for(config.CacheComponents.GUILDS) else None + guild_id = gd.id + members = gd.members() if self._cache_enabled_for(config.CacheComponents.MEMBERS) else None + presences = gd.presences() if self._cache_enabled_for(config.CacheComponents.PRESENCES) else None + roles = gd.roles() if self._cache_enabled_for(config.CacheComponents.ROLES) else None + voice_states = gd.voice_states() if self._cache_enabled_for(config.CacheComponents.VOICE_STATES) else None + + elif enabled_for_event: + if "unavailable" in payload: + event = self._event_factory.deserialize_guild_available_event(shard, payload) + else: + event = self._event_factory.deserialize_guild_join_event(shard, payload) + + channels = event.channels + emojis = event.emojis + guild = event.guild + guild_id = guild.id + members = event.members + presences = event.presences + roles = event.roles + voice_states = event.voice_states - if "unavailable" in payload: - event = self._event_factory.deserialize_guild_available_event(shard, payload) else: - event = self._event_factory.deserialize_guild_join_event(shard, payload) + event = None + channels = None + emojis = None + guild = None + guild_id = snowflakes.Snowflake(payload["id"]) + members = None + presences = None + roles = None + voice_states = None if self._cache: - self._cache.update_guild(event.guild) + if guild: + self._cache.update_guild(guild) - self._cache.clear_guild_channels_for_guild(event.guild.id) - for channel in event.channels.values(): - self._cache.set_guild_channel(channel) + if channels: + self._cache.clear_guild_channels_for_guild(guild_id) + for channel in channels.values(): + self._cache.set_guild_channel(channel) - self._cache.clear_emojis_for_guild(event.guild.id) - for emoji in event.emojis.values(): - self._cache.set_emoji(emoji) + if emojis: + self._cache.clear_emojis_for_guild(guild_id) + for emoji in emojis.values(): + self._cache.set_emoji(emoji) - self._cache.clear_roles_for_guild(event.guild.id) - for role in event.roles.values(): - self._cache.set_role(role) + if roles: + self._cache.clear_roles_for_guild(guild_id) + for role in roles.values(): + self._cache.set_role(role) - # TODO: do we really want to invalidate these all after an outage. - self._cache.clear_members_for_guild(event.guild.id) - for member in event.members.values(): - self._cache.set_member(member) + if members: + # TODO: do we really want to invalidate these all after an outage. + self._cache.clear_members_for_guild(guild_id) + for member in members.values(): + self._cache.set_member(member) - self._cache.clear_presences_for_guild(event.guild.id) - for presence in event.presences.values(): - self._cache.set_presence(presence) + if presences: + self._cache.clear_presences_for_guild(guild_id) + for presence in presences.values(): + self._cache.set_presence(presence) - self._cache.clear_voice_states_for_guild(event.guild.id) - for voice_state in event.voice_states.values(): - self._cache.set_voice_state(voice_state) + if voice_states: + self._cache.clear_voice_states_for_guild(guild_id) + for voice_state in voice_states.values(): + self._cache.set_voice_state(voice_state) - members_declared = self._intents & intents_.Intents.GUILD_MEMBERS - presences_declared = self._intents & intents_.Intents.GUILD_PRESENCES + recv_chunks = self._enabled_for_event(shard_events.MemberChunkEvent) or self._cache_enabled_for( + config.CacheComponents.MEMBERS + ) + members_declared = self._intents & intents_.Intents.GUILD_MEMBERS + presences_declared = self._intents & intents_.Intents.GUILD_PRESENCES + + # When intents are enabled discord will only send other member objects on the guild create + # payload if presence intents are also declared, so if this isn't the case then we also want + # to chunk small guilds. + if recv_chunks and members_declared and (payload.get("large") or not presences_declared): + # We create a task here instead of awaiting the result to avoid any rate-limits from delaying dispatch. + nonce = f"{shard.id}.{_fixed_size_nonce()}" - # When intents are enabled discord will only send other member objects on the guild create - # payload if presence intents are also declared, so if this isn't the case then we also want - # to chunk small guilds. - if members_declared and (event.guild.is_large or not presences_declared): - # We create a task here instead of awaiting the result to avoid any rate-limits from delaying dispatch. - nonce = f"{shard.id}.{_fixed_size_nonce()}" + if event: event.chunk_nonce = nonce - coroutine = _request_guild_members( - shard, event.guild, include_presences=bool(presences_declared), nonce=nonce - ) - asyncio.create_task(coroutine, name=f"{shard.id}:{event.guild.id} guild create members request") - await self.dispatch(event) + coroutine = _request_guild_members(shard, guild_id, include_presences=bool(presences_declared), nonce=nonce) + asyncio.create_task(coroutine, name=f"{shard.id}:{guild_id} guild create members request") + + if event: + await self.dispatch(event) + # Internal granularity is preferred for GUILD_UPDATE over decorator based filtering due to its large cache scope. async def on_guild_update(self, shard: gateway_shard.GatewayShard, payload: data_binding.JSONObject) -> None: """See https://discord.com/developers/docs/topics/gateway#guild-update for more info.""" - old = self._cache.get_guild(snowflakes.Snowflake(payload["id"])) if self._cache else None - event = self._event_factory.deserialize_guild_update_event(shard, payload, old_guild=old) + enabled_for_event = self._enabled_for_event(guild_events.GuildUpdateEvent) + + if not enabled_for_event and self._cache: + _LOGGER.log(ux.TRACE, "Skipping on_guild_update raw dispatch due to lack of any registered listeners") + event: typing.Optional[guild_events.GuildUpdateEvent] = None + gd = self._entity_factory.deserialize_gateway_guild(payload) + emojis = gd.emojis() if self._cache_enabled_for(config.CacheComponents.EMOJIS) else None + guild = gd.guild() if self._cache_enabled_for(config.CacheComponents.GUILDS) else None + guild_id = gd.id + roles = gd.roles() if self._cache_enabled_for(config.CacheComponents.ROLES) else None + + elif enabled_for_event: + guild_id = snowflakes.Snowflake(payload["id"]) + old = self._cache.get_guild(guild_id) if self._cache else None + event = self._event_factory.deserialize_guild_update_event(shard, payload, old_guild=old) + emojis = event.emojis + guild = event.guild + roles = event.roles - if self._cache: - self._cache.update_guild(event.guild) - - self._cache.clear_roles_for_guild(event.guild.id) - for role in event.roles.values(): # TODO: do we actually get this here? - self._cache.set_role(role) + else: + _LOGGER.log( + ux.TRACE, "Skipping on_guild_update raw dispatch due to lack of any registered listeners or cache need" + ) + return - self._cache.clear_emojis_for_guild(event.guild.id) # TODO: do we actually get this here? - for emoji in event.emojis.values(): - self._cache.set_emoji(emoji) + if self._cache: + if guild: + self._cache.update_guild(guild) - await self.dispatch(event) + if emojis: + self._cache.clear_emojis_for_guild(guild_id) + for emoji in emojis.values(): + self._cache.set_emoji(emoji) + if roles: + self._cache.clear_roles_for_guild(guild_id) + for role in roles.values(): + self._cache.set_role(role) + + if event: + await self.dispatch(event) + + @event_manager_base.filtered( + (guild_events.GuildLeaveEvent, guild_events.GuildUnavailableEvent), + config.CacheComponents.GUILDS + | config.CacheComponents.GUILD_CHANNELS + | config.CacheComponents.EMOJIS + | config.CacheComponents.ROLES + | config.CacheComponents.PRESENCES + | config.CacheComponents.VOICE_STATES + | config.CacheComponents.MEMBERS, + ) async def on_guild_delete(self, shard: gateway_shard.GatewayShard, payload: data_binding.JSONObject) -> None: """See https://discord.com/developers/docs/topics/gateway#guild-delete for more info.""" event: typing.Union[guild_events.GuildUnavailableEvent, guild_events.GuildLeaveEvent] - if payload.get("unavailable", False): + if payload.get("unavailable"): event = self._event_factory.deserialize_guild_unavailable_event(shard, payload) if self._cache: @@ -233,14 +345,17 @@ async def on_guild_delete(self, shard: gateway_shard.GatewayShard, payload: data await self.dispatch(event) + @event_manager_base.filtered(guild_events.BanCreateEvent) async def on_guild_ban_add(self, shard: gateway_shard.GatewayShard, payload: data_binding.JSONObject) -> None: """See https://discord.com/developers/docs/topics/gateway#guild-ban-add for more info.""" await self.dispatch(self._event_factory.deserialize_guild_ban_add_event(shard, payload)) + @event_manager_base.filtered(guild_events.BanDeleteEvent) async def on_guild_ban_remove(self, shard: gateway_shard.GatewayShard, payload: data_binding.JSONObject) -> None: """See https://discord.com/developers/docs/topics/gateway#guild-ban-remove for more info.""" await self.dispatch(self._event_factory.deserialize_guild_ban_remove_event(shard, payload)) + @event_manager_base.filtered(guild_events.EmojisUpdateEvent, config.CacheComponents.EMOJIS) async def on_guild_emojis_update(self, shard: gateway_shard.GatewayShard, payload: data_binding.JSONObject) -> None: """See https://discord.com/developers/docs/topics/gateway#guild-emojis-update for more info.""" guild_id = snowflakes.Snowflake(payload["guild_id"]) @@ -254,24 +369,29 @@ async def on_guild_emojis_update(self, shard: gateway_shard.GatewayShard, payloa await self.dispatch(event) + @event_manager_base.filtered(()) # An empty sequence here means that this method will always be skipped. async def on_guild_integrations_update(self, _: gateway_shard.GatewayShard, __: data_binding.JSONObject) -> None: """See https://discord.com/developers/docs/topics/gateway#guild-integrations-update for more info.""" # This is only here to stop this being logged or dispatched as an "unknown event". # This event is made redundant by INTEGRATION_CREATE/DELETE/UPDATE and is thus not parsed or dispatched. - return None + raise NotImplementedError + @event_manager_base.filtered(guild_events.IntegrationCreateEvent) async def on_integration_create(self, shard: gateway_shard.GatewayShard, payload: data_binding.JSONObject) -> None: event = self._event_factory.deserialize_integration_create_event(shard, payload) await self.dispatch(event) + @event_manager_base.filtered(guild_events.IntegrationDeleteEvent) async def on_integration_delete(self, shard: gateway_shard.GatewayShard, payload: data_binding.JSONObject) -> None: event = self._event_factory.deserialize_integration_delete_event(shard, payload) await self.dispatch(event) + @event_manager_base.filtered(guild_events.IntegrationUpdateEvent) async def on_integration_update(self, shard: gateway_shard.GatewayShard, payload: data_binding.JSONObject) -> None: event = self._event_factory.deserialize_integration_update_event(shard, payload) await self.dispatch(event) + @event_manager_base.filtered(member_events.MemberCreateEvent, config.CacheComponents.MEMBERS) async def on_guild_member_add(self, shard: gateway_shard.GatewayShard, payload: data_binding.JSONObject) -> None: """See https://discord.com/developers/docs/topics/gateway#guild-member-add for more info.""" event = self._event_factory.deserialize_guild_member_add_event(shard, payload) @@ -281,6 +401,7 @@ async def on_guild_member_add(self, shard: gateway_shard.GatewayShard, payload: await self.dispatch(event) + @event_manager_base.filtered(member_events.MemberDeleteEvent, config.CacheComponents.MEMBERS) async def on_guild_member_remove(self, shard: gateway_shard.GatewayShard, payload: data_binding.JSONObject) -> None: """See https://discord.com/developers/docs/topics/gateway#guild-member-remove for more info.""" old: typing.Optional[guilds.Member] = None @@ -292,6 +413,7 @@ async def on_guild_member_remove(self, shard: gateway_shard.GatewayShard, payloa event = self._event_factory.deserialize_guild_member_remove_event(shard, payload, old_member=old) await self.dispatch(event) + @event_manager_base.filtered(member_events.MemberUpdateEvent, config.CacheComponents.MEMBERS) async def on_guild_member_update(self, shard: gateway_shard.GatewayShard, payload: data_binding.JSONObject) -> None: """See https://discord.com/developers/docs/topics/gateway#guild-member-update for more info.""" old: typing.Optional[guilds.Member] = None @@ -307,6 +429,7 @@ async def on_guild_member_update(self, shard: gateway_shard.GatewayShard, payloa await self.dispatch(event) + @event_manager_base.filtered(shard_events.MemberChunkEvent, config.CacheComponents.MEMBERS) async def on_guild_members_chunk(self, shard: gateway_shard.GatewayShard, payload: data_binding.JSONObject) -> None: """See https://discord.com/developers/docs/topics/gateway#guild-members-chunk for more info.""" event = self._event_factory.deserialize_guild_member_chunk_event(shard, payload) @@ -320,6 +443,7 @@ async def on_guild_members_chunk(self, shard: gateway_shard.GatewayShard, payloa await self.dispatch(event) + @event_manager_base.filtered(role_events.RoleCreateEvent, config.CacheComponents.ROLES) async def on_guild_role_create(self, shard: gateway_shard.GatewayShard, payload: data_binding.JSONObject) -> None: """See https://discord.com/developers/docs/topics/gateway#guild-role-create for more info.""" event = self._event_factory.deserialize_guild_role_create_event(shard, payload) @@ -329,6 +453,7 @@ async def on_guild_role_create(self, shard: gateway_shard.GatewayShard, payload: await self.dispatch(event) + @event_manager_base.filtered(role_events.RoleUpdateEvent, config.CacheComponents.ROLES) async def on_guild_role_update(self, shard: gateway_shard.GatewayShard, payload: data_binding.JSONObject) -> None: """See https://discord.com/developers/docs/topics/gateway#guild-role-update for more info.""" old = self._cache.get_role(snowflakes.Snowflake(payload["role"]["id"])) if self._cache else None @@ -339,6 +464,7 @@ async def on_guild_role_update(self, shard: gateway_shard.GatewayShard, payload: await self.dispatch(event) + @event_manager_base.filtered(role_events.RoleDeleteEvent, config.CacheComponents.ROLES) async def on_guild_role_delete(self, shard: gateway_shard.GatewayShard, payload: data_binding.JSONObject) -> None: """See https://discord.com/developers/docs/topics/gateway#guild-role-delete for more info.""" old: typing.Optional[guilds.Role] = None @@ -349,6 +475,7 @@ async def on_guild_role_delete(self, shard: gateway_shard.GatewayShard, payload: await self.dispatch(event) + @event_manager_base.filtered(channel_events.InviteCreateEvent, config.CacheComponents.INVITES) async def on_invite_create(self, shard: gateway_shard.GatewayShard, payload: data_binding.JSONObject) -> None: """See https://discord.com/developers/docs/topics/gateway#invite-create for more info.""" event = self._event_factory.deserialize_invite_create_event(shard, payload) @@ -358,6 +485,7 @@ async def on_invite_create(self, shard: gateway_shard.GatewayShard, payload: dat await self.dispatch(event) + @event_manager_base.filtered(channel_events.InviteDeleteEvent, config.CacheComponents.INVITES) async def on_invite_delete(self, shard: gateway_shard.GatewayShard, payload: data_binding.JSONObject) -> None: """See https://discord.com/developers/docs/topics/gateway#invite-delete for more info.""" old: typing.Optional[invites.InviteWithMetadata] = None @@ -368,6 +496,9 @@ async def on_invite_delete(self, shard: gateway_shard.GatewayShard, payload: dat await self.dispatch(event) + @event_manager_base.filtered( + (message_events.GuildMessageCreateEvent, message_events.DMMessageCreateEvent), config.CacheComponents.MESSAGES + ) async def on_message_create(self, shard: gateway_shard.GatewayShard, payload: data_binding.JSONObject) -> None: """See https://discord.com/developers/docs/topics/gateway#message-create for more info.""" event = self._event_factory.deserialize_message_create_event(shard, payload) @@ -377,6 +508,9 @@ async def on_message_create(self, shard: gateway_shard.GatewayShard, payload: da await self.dispatch(event) + @event_manager_base.filtered( + (message_events.GuildMessageUpdateEvent, message_events.DMMessageUpdateEvent), config.CacheComponents.MESSAGES + ) async def on_message_update(self, shard: gateway_shard.GatewayShard, payload: data_binding.JSONObject) -> None: """See https://discord.com/developers/docs/topics/gateway#message-update for more info.""" old = self._cache.get_message(snowflakes.Snowflake(payload["id"])) if self._cache else None @@ -387,6 +521,9 @@ async def on_message_update(self, shard: gateway_shard.GatewayShard, payload: da await self.dispatch(event) + @event_manager_base.filtered( + (message_events.GuildMessageDeleteEvent, message_events.DMMessageDeleteEvent), config.CacheComponents.MESSAGES + ) async def on_message_delete(self, shard: gateway_shard.GatewayShard, payload: data_binding.JSONObject) -> None: """See https://discord.com/developers/docs/topics/gateway#message-delete for more info.""" if self._cache: @@ -399,6 +536,9 @@ async def on_message_delete(self, shard: gateway_shard.GatewayShard, payload: da await self.dispatch(event) + @event_manager_base.filtered( + (message_events.GuildMessageDeleteEvent, message_events.DMMessageDeleteEvent), config.CacheComponents.MESSAGES + ) async def on_message_delete_bulk(self, shard: gateway_shard.GatewayShard, payload: data_binding.JSONObject) -> None: """See https://discord.com/developers/docs/topics/gateway#message-delete-bulk for more info.""" old_messages = {} @@ -414,6 +554,7 @@ async def on_message_delete_bulk(self, shard: gateway_shard.GatewayShard, payloa self._event_factory.deserialize_guild_message_delete_bulk_event(shard, payload, old_messages=old_messages) ) + @event_manager_base.filtered((reaction_events.GuildReactionAddEvent, reaction_events.DMReactionAddEvent)) async def on_message_reaction_add( self, shard: gateway_shard.GatewayShard, payload: data_binding.JSONObject ) -> None: @@ -422,27 +563,35 @@ async def on_message_reaction_add( # TODO: this is unlikely but reaction cache? + @event_manager_base.filtered((reaction_events.GuildReactionDeleteEvent, reaction_events.DMReactionDeleteEvent)) async def on_message_reaction_remove( self, shard: gateway_shard.GatewayShard, payload: data_binding.JSONObject ) -> None: """See https://discord.com/developers/docs/topics/gateway#message-reaction-remove for more info.""" await self.dispatch(self._event_factory.deserialize_message_reaction_remove_event(shard, payload)) + @event_manager_base.filtered( + (reaction_events.GuildReactionDeleteAllEvent, reaction_events.DMReactionDeleteAllEvent) + ) async def on_message_reaction_remove_all( self, shard: gateway_shard.GatewayShard, payload: data_binding.JSONObject ) -> None: """See https://discord.com/developers/docs/topics/gateway#message-reaction-remove-all for more info.""" await self.dispatch(self._event_factory.deserialize_message_reaction_remove_all_event(shard, payload)) + @event_manager_base.filtered( + (reaction_events.GuildReactionDeleteEmojiEvent, reaction_events.DMReactionDeleteEmojiEvent) + ) async def on_message_reaction_remove_emoji( self, shard: gateway_shard.GatewayShard, payload: data_binding.JSONObject ) -> None: """See https://discord.com/developers/docs/topics/gateway#message-reaction-remove-emoji for more info.""" await self.dispatch(self._event_factory.deserialize_message_reaction_remove_emoji_event(shard, payload)) + @event_manager_base.filtered(guild_events.PresenceUpdateEvent, config.CacheComponents.PRESENCES) async def on_presence_update(self, shard: gateway_shard.GatewayShard, payload: data_binding.JSONObject) -> None: """See https://discord.com/developers/docs/topics/gateway#presence-update for more info.""" - old: typing.Optional[presences.MemberPresence] = None + old: typing.Optional[presences_.MemberPresence] = None if self._cache: old = self._cache.get_presence( snowflakes.Snowflake(payload["guild_id"]), snowflakes.Snowflake(payload["user"]["id"]) @@ -450,7 +599,7 @@ async def on_presence_update(self, shard: gateway_shard.GatewayShard, payload: d event = self._event_factory.deserialize_presence_update_event(shard, payload, old_presence=old) - if self._cache and event.presence.visible_status is presences.Status.OFFLINE: + if self._cache and event.presence.visible_status is presences_.Status.OFFLINE: self._cache.delete_presence(event.presence.guild_id, event.presence.user_id) elif self._cache: self._cache.update_presence(event.presence) @@ -458,10 +607,12 @@ async def on_presence_update(self, shard: gateway_shard.GatewayShard, payload: d # TODO: update user here when partial_user is set self._cache.update_user(event.partial_user) await self.dispatch(event) + @event_manager_base.filtered((typing_events.GuildTypingEvent, typing_events.DMTypingEvent)) async def on_typing_start(self, shard: gateway_shard.GatewayShard, payload: data_binding.JSONObject) -> None: """See https://discord.com/developers/docs/topics/gateway#typing-start for more info.""" await self.dispatch(self._event_factory.deserialize_typing_start_event(shard, payload)) + @event_manager_base.filtered(user_events.OwnUserUpdateEvent, config.CacheComponents.ME) async def on_user_update(self, shard: gateway_shard.GatewayShard, payload: data_binding.JSONObject) -> None: """See https://discord.com/developers/docs/topics/gateway#user-update for more info.""" old = self._cache.get_me() if self._cache else None @@ -472,6 +623,7 @@ async def on_user_update(self, shard: gateway_shard.GatewayShard, payload: data_ await self.dispatch(event) + @event_manager_base.filtered(voice_events.VoiceStateUpdateEvent, config.CacheComponents.VOICE_STATES) async def on_voice_state_update(self, shard: gateway_shard.GatewayShard, payload: data_binding.JSONObject) -> None: """See https://discord.com/developers/docs/topics/gateway#voice-state-update for more info.""" old: typing.Optional[voices.VoiceState] = None @@ -489,10 +641,12 @@ async def on_voice_state_update(self, shard: gateway_shard.GatewayShard, payload await self.dispatch(event) + @event_manager_base.filtered(voice_events.VoiceServerUpdateEvent) async def on_voice_server_update(self, shard: gateway_shard.GatewayShard, payload: data_binding.JSONObject) -> None: """See https://discord.com/developers/docs/topics/gateway#voice-server-update for more info.""" await self.dispatch(self._event_factory.deserialize_voice_server_update_event(shard, payload)) + @event_manager_base.filtered(channel_events.WebhookUpdateEvent) async def on_webhooks_update(self, shard: gateway_shard.GatewayShard, payload: data_binding.JSONObject) -> None: """See https://discord.com/developers/docs/topics/gateway#webhooks-update for more info.""" await self.dispatch(self._event_factory.deserialize_webhook_update_event(shard, payload)) diff --git a/hikari/impl/event_manager_base.py b/hikari/impl/event_manager_base.py index c70d04a771..f1c275b92b 100644 --- a/hikari/impl/event_manager_base.py +++ b/hikari/impl/event_manager_base.py @@ -24,21 +24,29 @@ from __future__ import annotations -__all__: typing.List[str] = ["EventManagerBase", "EventStream"] +__all__: typing.List[str] = ["filtered", "EventManagerBase", "EventStream"] import asyncio import inspect +import itertools import logging import typing import warnings import weakref +import attr + +from hikari import config from hikari import errors from hikari import iterators +from hikari import undefined from hikari.api import event_manager as event_manager_ from hikari.events import base_events +from hikari.events import shard_events from hikari.internal import aio +from hikari.internal import fast_protocol from hikari.internal import reflect +from hikari.internal import ux if typing.TYPE_CHECKING: import types @@ -54,19 +62,37 @@ ConsumerT = typing.Callable[ [gateway_shard.GatewayShard, data_binding.JSONObject], typing.Coroutine[typing.Any, typing.Any, None] ] - ListenerMapT = typing.MutableMapping[ + ListenerMapT = typing.Dict[ typing.Type[event_manager_.EventT_co], - typing.MutableSequence[event_manager_.CallbackT[event_manager_.EventT_co]], + typing.List[event_manager_.CallbackT[event_manager_.EventT_co]], ] WaiterT = typing.Tuple[ event_manager_.PredicateT[event_manager_.EventT_co], asyncio.Future[event_manager_.EventT_co] ] - WaiterMapT = typing.MutableMapping[ - typing.Type[event_manager_.EventT_co], typing.MutableSet[WaiterT[event_manager_.EventT_co]] + WaiterMapT = typing.Dict[typing.Type[event_manager_.EventT_co], typing.Set[WaiterT[event_manager_.EventT_co]]] + + EventManagerBaseT = typing.TypeVar("EventManagerBaseT", bound="EventManagerBase") + UnboundMethodT = typing.Callable[ + [EventManagerBaseT, gateway_shard.GatewayShard, data_binding.JSONObject], + typing.Coroutine[typing.Any, typing.Any, None], ] _EventStreamT = typing.TypeVar("_EventStreamT", bound="EventStream[typing.Any]") +@typing.runtime_checkable +class _FilteredMethodT(fast_protocol.FastProtocolChecking, typing.Protocol): + async def __call__(self, shard: gateway_shard.GatewayShard, payload: data_binding.JSONObject, /) -> None: + raise NotImplementedError + + @property + def __cache_components__(self) -> config.CacheComponents: + raise NotImplementedError + + @property + def __event_types__(self) -> typing.Sequence[typing.Type[base_events.Event]]: + raise NotImplementedError + + def _generate_weak_listener( reference: weakref.WeakMethod[typing.Any], ) -> typing.Callable[[event_manager_.EventT], typing.Coroutine[typing.Any, typing.Any, None]]: @@ -254,6 +280,53 @@ def _assert_is_listener(parameters: typing.Iterator[inspect.Parameter], /) -> No raise TypeError("Only the first argument for a listener can be required, the event argument.") +def filtered( + event_types: typing.Union[typing.Type[base_events.Event], typing.Sequence[typing.Type[base_events.Event]]], + cache_components: config.CacheComponents = config.CacheComponents.NONE, + /, +) -> typing.Callable[[UnboundMethodT[EventManagerBaseT]], UnboundMethodT[EventManagerBaseT]]: + """Add metadata to a consumer method to indicate when it should be unmarshalled and dispatched. + + Parameters + ---------- + event_types + Types of the events this raw consumer method may dispatch. + This may either be a singular type of a sequence of types. + + Other Parameters + ---------------- + cache_components : hikari.config.CacheComponents + Bitfield of the cache components this event may make altering calls to. + This defaults to `hikari.config.CacheComponents.NONE`. + """ + if isinstance(event_types, typing.Sequence): + # dict.fromkeys is used to remove any duplicate entries here + event_types = tuple(dict.fromkeys(itertools.chain.from_iterable(e.dispatches() for e in event_types))) + + else: + event_types = event_types.dispatches() + + def decorator(method: UnboundMethodT[EventManagerBaseT], /) -> UnboundMethodT[EventManagerBaseT]: + method.__cache_components__ = cache_components # type: ignore[attr-defined] + method.__event_types__ = event_types # type: ignore[attr-defined] + assert isinstance(method, _FilteredMethodT), "Incorrect attribute(s) set for a filtered method" + return method # type: ignore[unreachable] + + return decorator + + +@attr.define(hash=True) +class _Consumer: + callback: ConsumerT = attr.ib(hash=True) + """The callback function for this consumer.""" + + event_types: undefined.UndefinedOr[typing.Sequence[typing.Type[base_events.Event]]] = attr.ib(hash=False) + """A sequence of the types of events this consumer dispatches to, if set.""" + + is_caching: bool = attr.ib(hash=False) + """Cached value of whether or not this consumer is making cache calls in the current env.""" + + class EventManagerBase(event_manager_.EventManager): """Provides functionality to consume and dispatch events. @@ -261,10 +334,24 @@ class EventManagerBase(event_manager_.EventManager): is the raw event name being dispatched in lower-case. """ - __slots__: typing.Sequence[str] = ("_event_factory", "_intents", "_listeners", "_consumers", "_waiters") + __slots__: typing.Sequence[str] = ( + "_consumers", + "_enabled_consumers_cache", + "_event_factory", + "_intents", + "_listeners", + "_waiters", + ) - def __init__(self, event_factory: event_factory_.EventFactory, intents: intents_.Intents) -> None: - self._consumers: typing.Dict[str, ConsumerT] = {} + def __init__( + self, + event_factory: event_factory_.EventFactory, + intents: intents_.Intents, + *, + cache_components: config.CacheComponents = config.CacheComponents.NONE, + ) -> None: + self._consumers: typing.Dict[str, _Consumer] = {} + self._enabled_consumers_cache: typing.Dict[_Consumer, bool] = {} self._event_factory = event_factory self._intents = intents self._listeners: ListenerMapT[base_events.Event] = {} @@ -272,15 +359,52 @@ def __init__(self, event_factory: event_factory_.EventFactory, intents: intents_ for name, member in inspect.getmembers(self): if name.startswith("on_"): - self._consumers[name[3:]] = member + event_name = name[3:] + if isinstance(member, _FilteredMethodT): + caching = (member.__cache_components__ & cache_components) != 0 + self._consumers[event_name] = _Consumer(member, member.__event_types__, caching) + + else: + self._consumers[event_name] = _Consumer( + member, undefined.UNDEFINED, cache_components != cache_components.NONE + ) + + def _clear_enabled_cache(self) -> None: + self._enabled_consumers_cache = {} + + def _enabled_for_event(self, event_type: typing.Type[base_events.Event], /) -> bool: + for cls in event_type.dispatches(): + if cls in self._listeners or cls in self._waiters: + return True + + return False + + def _enabled_for_consumer(self, consumer: _Consumer, /) -> bool: + # If undefined then we can only assume that this may link to registered listeners. + if consumer.event_types is undefined.UNDEFINED or consumer.is_caching: + return True + + if (cached_value := self._enabled_consumers_cache.get(consumer)) is not None: + return cached_value + + # The behaviour here where an empty sequence for event_types will lead to this always + # being skipped unless there's a relevant enabled cache resource is intended behaviour. + for event_type in consumer.event_types: + if event_type in self._listeners or event_type in self._waiters: + self._enabled_consumers_cache[consumer] = True + return True + + self._enabled_consumers_cache[consumer] = False + return False def consume_raw_event( self, event_name: str, shard: gateway_shard.GatewayShard, payload: data_binding.JSONObject ) -> None: - payload_event = self._event_factory.deserialize_shard_payload_event(shard, payload, name=event_name) - self.dispatch(payload_event) - callback = self._consumers[event_name.casefold()] - asyncio.create_task(self._handle_dispatch(callback, shard, payload), name=f"dispatch {event_name}") + if self._enabled_for_event(shard_events.ShardPayloadEvent): + payload_event = self._event_factory.deserialize_shard_payload_event(shard, payload, name=event_name) + self.dispatch(payload_event) + consumer = self._consumers[event_name.lower()] + asyncio.create_task(self._handle_dispatch(consumer, shard, payload), name=f"dispatch {event_name}") def subscribe( self, @@ -299,9 +423,6 @@ def subscribe( # warning is triggered. self._check_intents(event_type, _nested) - if event_type not in self._listeners: - self._listeners[event_type] = [] - _LOGGER.debug( "subscribing callback 'async def %s%s' to event-type %s.%s", getattr(callback, "__name__", ""), @@ -310,7 +431,11 @@ def subscribe( event_type.__qualname__, ) - self._listeners[event_type].append(callback) # type: ignore[arg-type] + try: + self._listeners[event_type].append(callback) # type: ignore[arg-type] + except KeyError: + self._listeners[event_type] = [callback] # type: ignore[list-item] + self._clear_enabled_cache() def _check_intents(self, event_type: typing.Type[event_manager_.EventT_co], nested: int) -> None: # Collection of combined bitfield combinations of intents that @@ -344,10 +469,10 @@ def get_listeners( if issubclass(subscribed_event_type, event_type): listeners += subscribed_listeners return listeners + else: - items = self._listeners.get(event_type) - if items is not None: - return items[:] + if items := self._listeners.get(event_type): + return items.copy() return [] @@ -356,7 +481,7 @@ def unsubscribe( event_type: typing.Type[event_manager_.EventT_co], callback: event_manager_.CallbackT[event_manager_.EventT_co], ) -> None: - if event_type in self._listeners: + if listeners := self._listeners.get(event_type): _LOGGER.debug( "unsubscribing callback %s%s from event-type %s.%s", getattr(callback, "__name__", ""), @@ -364,9 +489,10 @@ def unsubscribe( event_type.__module__, event_type.__qualname__, ) - self._listeners[event_type].remove(callback) # type: ignore[arg-type] - if not self._listeners[event_type]: + listeners.remove(callback) # type: ignore[arg-type] + if not listeners: del self._listeners[event_type] + self._clear_enabled_cache() def listen( self, @@ -404,16 +530,12 @@ def dispatch(self, event: event_manager_.EventT_inv) -> asyncio.Future[typing.An if not isinstance(event, base_events.Event): raise TypeError(f"Events must be subclasses of {base_events.Event.__name__}, not {type(event).__name__}") - # We only need to iterate through the MRO until we hit Event, as - # anything after that is random garbage we don't care about, as they do - # not describe event types. This improves efficiency as well. - mro = type(event).mro() - tasks: typing.List[typing.Coroutine[None, typing.Any, None]] = [] + clear_cache = False - for cls in mro[: mro.index(base_events.Event) + 1]: - if cls in self._listeners: - for callback in self._listeners[cls]: + for cls in event.dispatches(): + if listeners := self._listeners.get(cls): + for callback in listeners: tasks.append(self._invoke_callback(callback, event)) if cls not in self._waiters: @@ -434,6 +556,13 @@ def dispatch(self, event: event_manager_.EventT_inv) -> asyncio.Future[typing.An waiter_set.remove(waiter) + if not waiter_set: + del self._waiters[cls] + clear_cache = True + + if clear_cache: + self._clear_enabled_cache() + return asyncio.gather(*tasks) if tasks else aio.completed_future() def stream( @@ -453,7 +582,6 @@ async def wait_for( timeout: typing.Union[float, int, None], predicate: typing.Optional[event_manager_.PredicateT[event_manager_.EventT_co]] = None, ) -> event_manager_.EventT_co: - if predicate is None: predicate = _default_predicate @@ -464,6 +592,7 @@ async def wait_for( try: waiter_set = self._waiters[event_type] except KeyError: + self._clear_enabled_cache() waiter_set = set() self._waiters[event_type] = waiter_set @@ -474,16 +603,27 @@ async def wait_for( return await asyncio.wait_for(future, timeout=timeout) except asyncio.TimeoutError: waiter_set.remove(pair) # type: ignore[arg-type] + if not waiter_set: + del self._waiters[event_type] + self._clear_enabled_cache() + raise - @staticmethod async def _handle_dispatch( - callback: ConsumerT, + self, + consumer: _Consumer, shard: gateway_shard.GatewayShard, payload: data_binding.JSONObject, ) -> None: + if not self._enabled_for_consumer(consumer): + name = consumer.callback.__name__ + _LOGGER.log( + ux.TRACE, "Skipping raw dispatch for %s due to lack of any registered listeners or cache need", name + ) + return + try: - await callback(shard, payload) + await consumer.callback(shard, payload) except asyncio.CancelledError: # Skip cancelled errors, likely caused by the event loop being shut down. pass diff --git a/tests/hikari/impl/test_bot.py b/tests/hikari/impl/test_bot.py index dfacbf19a2..29cce434d2 100644 --- a/tests/hikari/impl/test_bot.py +++ b/tests/hikari/impl/test_bot.py @@ -204,7 +204,9 @@ def test_init(self): assert bot._cache is cache.return_value cache.assert_called_once_with(bot, cache_settings) assert bot._event_manager is event_manager.return_value - event_manager.assert_called_once_with(event_factory.return_value, intents, cache=cache.return_value) + event_manager.assert_called_once_with( + entity_factory.return_value, event_factory.return_value, intents, cache=cache.return_value + ) assert bot._entity_factory is entity_factory.return_value entity_factory.assert_called_once_with(bot) assert bot._event_factory is event_factory.return_value diff --git a/tests/hikari/impl/test_cache.py b/tests/hikari/impl/test_cache.py index ed6b829fe1..b9c69cea6e 100644 --- a/tests/hikari/impl/test_cache.py +++ b/tests/hikari/impl/test_cache.py @@ -1472,6 +1472,13 @@ def test_set_me(self, cache_impl): assert cache_impl._me == mock_own_user assert cache_impl._me is not mock_own_user + def test_set_me_when_not_enabled(self, cache_impl): + cache_impl._settings.components = 0 + + cache_impl.set_me(object()) + + assert cache_impl._me is None + def test_update_me_for_cached_me(self, cache_impl): mock_cached_own_user = mock.MagicMock(users.OwnUser) mock_own_user = mock.MagicMock(users.OwnUser) @@ -1490,6 +1497,17 @@ def test_update_me_for_uncached_me(self, cache_impl): assert result == (None, mock_own_user) assert cache_impl._me == mock_own_user + def test_update_me_for_when_not_enabled(self, cache_impl): + cache_impl._settings.components = 0 + cache_impl.get_me = mock.Mock() + cache_impl.set_me = mock.Mock() + + result = cache_impl.update_me(object()) + + assert result == (None, None) + cache_impl.get_me.assert_not_called() + cache_impl.set_me.assert_not_called() + def test__build_member(self, cache_impl): mock_user = mock.MagicMock(users.User) member_data = cache_utilities.MemberData( diff --git a/tests/hikari/impl/test_entity_factory.py b/tests/hikari/impl/test_entity_factory.py index 85c9a3c065..47be95836f 100644 --- a/tests/hikari/impl/test_entity_factory.py +++ b/tests/hikari/impl/test_entity_factory.py @@ -53,6 +53,198 @@ from hikari.interactions import base_interactions from hikari.interactions import command_interactions from hikari.interactions import component_interactions +from tests.hikari import hikari_test_helpers + + +@pytest.fixture() +def mock_app() -> traits.RESTAware: + return mock.MagicMock(traits.RESTAware) + + +@pytest.fixture() +def permission_overwrite_payload(): + return {"id": "4242", "type": 1, "allow": 65, "deny": 49152, "allow_new": "65", "deny_new": "49152"} + + +@pytest.fixture() +def guild_text_channel_payload(permission_overwrite_payload): + return { + "id": "123", + "guild_id": "567", + "name": "general", + "type": 0, + "position": 6, + "permission_overwrites": [permission_overwrite_payload], + "rate_limit_per_user": 2, + "nsfw": True, + "topic": "¯\\_(ツ)_/¯", + "last_message_id": "123456", + "last_pin_timestamp": "2020-05-27T15:58:51.545252+00:00", + "parent_id": "987", + } + + +@pytest.fixture() +def guild_voice_channel_payload(permission_overwrite_payload): + return { + "id": "555", + "guild_id": "789", + "name": "Secret Developer Discussions", + "type": 2, + "nsfw": True, + "position": 4, + "permission_overwrites": [permission_overwrite_payload], + "bitrate": 64000, + "user_limit": 3, + "rtc_region": "europe", + "parent_id": "456", + "video_quality_mode": 1, + } + + +@pytest.fixture() +def guild_news_channel_payload(permission_overwrite_payload): + return { + "id": "7777", + "guild_id": "123", + "name": "Important Announcements", + "type": 5, + "position": 0, + "permission_overwrites": [permission_overwrite_payload], + "nsfw": True, + "topic": "Super Important Announcements", + "last_message_id": "456", + "parent_id": "654", + "last_pin_timestamp": "2020-05-27T15:58:51.545252+00:00", + } + + +@pytest.fixture() +def user_payload(): + return { + "id": "115590097100865541", + "username": "nyaa", + "avatar": "b3b24c6d7cbcdec129d5d537067061a8", + "banner": "a_221313e1e2edsncsncsmcndsc", + "accent_color": 231321, + "discriminator": "6127", + "bot": True, + "system": True, + "public_flags": int(user_models.UserFlag.EARLY_VERIFIED_DEVELOPER), + } + + +@pytest.fixture() +def custom_emoji_payload(): + return {"id": "691225175349395456", "name": "test", "animated": True} + + +@pytest.fixture() +def known_custom_emoji_payload(user_payload): + return { + "id": "12345", + "name": "testing", + "animated": False, + "available": True, + "roles": ["123", "456"], + "user": user_payload, + "require_colons": True, + "managed": False, + } + + +@pytest.fixture() +def member_payload(user_payload): + return { + "nick": "foobarbaz", + "roles": ["11111", "22222", "33333", "44444"], + "joined_at": "2015-04-26T06:26:56.936000+00:00", + "premium_since": "2019-05-17T06:26:56.936000+00:00", + "avatar": "estrogen", + "deaf": False, + "mute": True, + "pending": False, + "user": user_payload, + "communication_disabled_until": "2021-10-18T06:26:56.936000+00:00", + } + + +@pytest.fixture() +def presence_activity_payload(custom_emoji_payload): + return { + "name": "an activity", + "type": 1, + "url": "https://69.420.owouwunyaa", + "created_at": 1584996792798, + "timestamps": {"start": 1584996792798, "end": 1999999792798}, + "application_id": "40404040404040", + "details": "They are doing stuff", + "state": "STATED", + "emoji": custom_emoji_payload, + "party": {"id": "spotify:3234234234", "size": [2, 5]}, + "assets": { + "large_image": "34234234234243", + "large_text": "LARGE TEXT", + "small_image": "3939393", + "small_text": "small text", + }, + "secrets": {"join": "who's a good secret?", "spectate": "I'm a good secret", "match": "No."}, + "instance": True, + "flags": 3, + "buttons": ["owo", "no"], + } + + +@pytest.fixture() +def member_presence_payload(user_payload, presence_activity_payload): + return { + "user": user_payload, + "activity": presence_activity_payload, + "guild_id": "44004040", + "status": "dnd", + "activities": [presence_activity_payload], + "client_status": {"desktop": "online", "mobile": "idle", "web": "dnd"}, + } + + +@pytest.fixture() +def guild_role_payload(): + return { + "id": "41771983423143936", + "name": "WE DEM BOYZZ!!!!!!", + "color": 3_447_003, + "hoist": True, + "unicode_emoji": "\N{OK HAND SIGN}", + "icon": "abc123hash", + "position": 0, + "permissions": "66321471", + "managed": False, + "mentionable": False, + "tags": { + "bot_id": "123", + "integration_id": "456", + "premium_subscriber": None, + }, + } + + +@pytest.fixture() +def voice_state_payload(member_payload): + return { + "guild_id": "929292929292992", + "channel_id": "157733188964188161", + "user_id": "115590097100865541", + "member": member_payload, + "session_id": "90326bd25d71d39b9ef95b299e3872ff", + "deaf": True, + "mute": True, + "self_deaf": False, + "self_mute": True, + "self_stream": True, + "self_video": True, + "suppress": False, + "request_to_speak_timestamp": "2021-04-17T10:11:19.970105+00:00", + } def test__with_int_cast(): @@ -89,11 +281,364 @@ def test__deserialize_max_age_returns_null(): assert entity_factory._deserialize_max_age(0) is None -class TestEntityFactoryImpl: +class TestGatewayGuildDefinition: @pytest.fixture() - def mock_app(self) -> traits.RESTAware: - return mock.MagicMock(traits.RESTAware) + def entity_factory_impl(self, mock_app) -> entity_factory.EntityFactoryImpl: + return hikari_test_helpers.mock_class_namespace(entity_factory.EntityFactoryImpl, slots_=False)(mock_app) + + def test_id_property(self, entity_factory_impl): + guild_definition = entity_factory_impl.deserialize_gateway_guild({"id": "123123451234"}) + + assert guild_definition.id == 123123451234 + + def test_channels( + self, entity_factory_impl, guild_text_channel_payload, guild_voice_channel_payload, guild_news_channel_payload + ): + guild_definition = entity_factory_impl.deserialize_gateway_guild( + { + "id": "265828729970753537", + "channels": [guild_text_channel_payload, guild_voice_channel_payload, guild_news_channel_payload], + } + ) + + assert guild_definition.channels() == { + 123: entity_factory_impl.deserialize_guild_text_channel( + guild_text_channel_payload, guild_id=snowflakes.Snowflake(265828729970753537) + ), + 555: entity_factory_impl.deserialize_guild_voice_channel( + guild_voice_channel_payload, guild_id=snowflakes.Snowflake(265828729970753537) + ), + 7777: entity_factory_impl.deserialize_guild_news_channel( + guild_news_channel_payload, guild_id=snowflakes.Snowflake(265828729970753537) + ), + } + + def test_channels_returns_cached_values(self, entity_factory_impl): + guild_definition = entity_factory_impl.deserialize_gateway_guild({"id": "265828729970753537"}) + mock_channel = object() + guild_definition._channels = {"123321": mock_channel} + entity_factory_impl.deserialize_guild_text_channel = mock.Mock() + entity_factory_impl.deserialize_guild_voice_channel = mock.Mock() + entity_factory_impl.deserialize_guild_news_channel = mock.Mock() + + assert guild_definition.channels() == {"123321": mock_channel} + + entity_factory_impl.deserialize_guild_text_channel.assert_not_called() + entity_factory_impl.deserialize_guild_voice_channel.assert_not_called() + entity_factory_impl.deserialize_guild_news_channel.assert_not_called() + + def test_channels_ignores_unrecognised_channels(self, entity_factory_impl): + guild_definition = entity_factory_impl.deserialize_gateway_guild( + {"id": "9494949", "channels": [{"id": 123, "type": 1000}]} + ) + + assert guild_definition.channels() == {} + + def test_emojis(self, entity_factory_impl, known_custom_emoji_payload): + guild_definition = entity_factory_impl.deserialize_gateway_guild( + {"id": "265828729970753537", "emojis": [known_custom_emoji_payload]}, + ) + + assert guild_definition.emojis() == { + 12345: entity_factory_impl.deserialize_known_custom_emoji( + known_custom_emoji_payload, guild_id=snowflakes.Snowflake(265828729970753537) + ) + } + + def test_emojis_returns_cached_values(self, entity_factory_impl): + mock_emoji = object() + entity_factory_impl.deserialize_known_custom_emoji = mock.Mock() + guild_definition = entity_factory_impl.deserialize_gateway_guild({"id": "265828729970753537"}) + guild_definition._emojis = {"21323232": mock_emoji} + + assert guild_definition.emojis() == {"21323232": mock_emoji} + + entity_factory_impl.deserialize_known_custom_emoji.assert_not_called() + + def test_guild(self, entity_factory_impl, mock_app): + guild_definition = entity_factory_impl.deserialize_gateway_guild( + { + "afk_channel_id": "99998888777766", + "afk_timeout": 1200, + "application_id": "39494949", + "banner": "1a2b3c", + "default_message_notifications": 1, + "description": "This is a server I guess, its a bit crap though", + "discovery_splash": "famfamFAMFAMfam", + "embed_channel_id": "9439394949", + "embed_enabled": True, + "explicit_content_filter": 2, + "features": ["ANIMATED_ICON", "MORE_EMOJI", "NEWS", "SOME_UNDOCUMENTED_FEATURE"], + "icon": "1a2b3c4d", + "id": "265828729970753537", + "joined_at": "2019-05-17T06:26:56.936000+00:00", + "large": False, + "max_members": 25000, + "max_presences": 250, + "max_video_channel_users": 25, + "member_count": 14, + "mfa_level": 1, + "name": "L33t guild", + "owner_id": "6969696", + "preferred_locale": "en-GB", + "premium_subscription_count": 1, + "premium_tier": 2, + "public_updates_channel_id": "33333333", + "rules_channel_id": "42042069", + "splash": "0ff0ff0ff", + "system_channel_flags": 3, + "system_channel_id": "19216801", + "unavailable": False, + "vanity_url_code": "loool", + "verification_level": 4, + "widget_channel_id": "9439394949", + "widget_enabled": True, + "nsfw_level": 0, + } + ) + guild = guild_definition.guild() + assert guild.app is mock_app + assert guild.id == 265828729970753537 + assert guild.name == "L33t guild" + assert guild.icon_hash == "1a2b3c4d" + assert guild.features == [ + guild_models.GuildFeature.ANIMATED_ICON, + guild_models.GuildFeature.MORE_EMOJI, + guild_models.GuildFeature.NEWS, + "SOME_UNDOCUMENTED_FEATURE", + ] + assert guild.splash_hash == "0ff0ff0ff" + assert guild.discovery_splash_hash == "famfamFAMFAMfam" + assert guild.owner_id == 6969696 + assert guild.afk_channel_id == 99998888777766 + assert guild.afk_timeout == datetime.timedelta(seconds=1200) + assert guild.verification_level == guild_models.GuildVerificationLevel.VERY_HIGH + assert guild.default_message_notifications == guild_models.GuildMessageNotificationsLevel.ONLY_MENTIONS + assert guild.explicit_content_filter == guild_models.GuildExplicitContentFilterLevel.ALL_MEMBERS + assert guild.mfa_level == guild_models.GuildMFALevel.ELEVATED + assert guild.application_id == 39494949 + assert guild.widget_channel_id == 9439394949 + assert guild.is_widget_enabled is True + assert guild.system_channel_id == 19216801 + assert guild.system_channel_flags == guild_models.GuildSystemChannelFlag(3) + assert guild.rules_channel_id == 42042069 + assert guild.joined_at == datetime.datetime(2019, 5, 17, 6, 26, 56, 936000, tzinfo=datetime.timezone.utc) + assert guild.is_large is False + assert guild.member_count == 14 + assert guild.max_video_channel_users == 25 + assert guild.vanity_url_code == "loool" + assert guild.description == "This is a server I guess, its a bit crap though" + assert guild.banner_hash == "1a2b3c" + assert guild.premium_tier == guild_models.GuildPremiumTier.TIER_2 + assert guild.premium_subscription_count == 1 + assert guild.preferred_locale == "en-GB" + assert guild.public_updates_channel_id == 33333333 + assert guild.nsfw_level == guild_models.GuildNSFWLevel.DEFAULT + + def test_guild_with_unset_fields(self, entity_factory_impl): + guild_definition = entity_factory_impl.deserialize_gateway_guild( + { + "afk_channel_id": "99998888777766", + "afk_timeout": 1200, + "application_id": "39494949", + "banner": "1a2b3c", + "default_message_notifications": 1, + "description": "This is a server I guess, its a bit crap though", + "discovery_splash": "famfamFAMFAMfam", + "emojis": [], + "explicit_content_filter": 2, + "features": ["ANIMATED_ICON", "MORE_EMOJI", "NEWS", "SOME_UNDOCUMENTED_FEATURE"], + "icon": "1a2b3c4d", + "id": "265828729970753537", + "mfa_level": 1, + "name": "L33t guild", + "owner_id": "6969696", + "preferred_locale": "en-GB", + "premium_tier": 2, + "public_updates_channel_id": "33333333", + "roles": [], + "rules_channel_id": "42042069", + "splash": "0ff0ff0ff", + "system_channel_flags": 3, + "system_channel_id": "19216801", + "vanity_url_code": "loool", + "verification_level": 4, + "nsfw_level": 0, + }, + ) + guild = guild_definition.guild() + assert guild.joined_at is None + assert guild.is_large is None + assert guild.max_video_channel_users is None + assert guild.member_count is None + assert guild.premium_subscription_count is None + assert guild.widget_channel_id is None + assert guild.is_widget_enabled is None + + def test_guild_with_null_fields(self, entity_factory_impl): + guild_definition = entity_factory_impl.deserialize_gateway_guild( + { + "afk_channel_id": None, + "afk_timeout": 1200, + "application_id": None, + "banner": None, + "channels": [], + "default_message_notifications": 1, + "description": None, + "discovery_splash": None, + "embed_channel_id": None, + "embed_enabled": True, + "emojis": [], + "explicit_content_filter": 2, + "features": ["ANIMATED_ICON", "MORE_EMOJI", "NEWS", "SOME_UNDOCUMENTED_FEATURE"], + "icon": None, + "id": "265828729970753537", + "joined_at": "2019-05-17T06:26:56.936000+00:00", + "large": False, + "max_members": 25000, + "max_presences": None, + "max_video_channel_users": 25, + "member_count": 14, + "members": [], + "mfa_level": 1, + "name": "L33t guild", + "owner_id": "6969696", + "permissions": 66_321_471, + "preferred_locale": "en-GB", + "premium_subscription_count": None, + "premium_tier": 2, + "presences": [], + "public_updates_channel_id": None, + "roles": [], + "rules_channel_id": None, + "splash": None, + "system_channel_flags": 3, + "system_channel_id": None, + "unavailable": False, + "vanity_url_code": None, + "verification_level": 4, + "voice_states": [], + "widget_channel_id": None, + "widget_enabled": True, + "nsfw_level": 0, + }, + ) + guild = guild_definition.guild() + assert guild.icon_hash is None + assert guild.splash_hash is None + assert guild.discovery_splash_hash is None + assert guild.afk_channel_id is None + assert guild.application_id is None + assert guild.widget_channel_id is None + assert guild.system_channel_id is None + assert guild.rules_channel_id is None + assert guild.vanity_url_code is None + assert guild.description is None + assert guild.banner_hash is None + assert guild.premium_subscription_count is None + assert guild.public_updates_channel_id is None + + def test_guild_returns_cached_values(self, entity_factory_impl): + mock_guild = object() + entity_factory_impl.set_guild_attributes = mock.Mock() + guild_definition = entity_factory_impl.deserialize_gateway_guild({"id": "9393939"}) + guild_definition._guild = mock_guild + + assert guild_definition.guild() is mock_guild + + entity_factory_impl.set_guild_attributes.assert_not_called() + + def test_members(self, entity_factory_impl, member_payload): + guild_definition = entity_factory_impl.deserialize_gateway_guild( + {"id": "265828729970753537", "members": [member_payload]} + ) + + assert guild_definition.members() == { + 115590097100865541: entity_factory_impl.deserialize_member( + member_payload, guild_id=snowflakes.Snowflake(265828729970753537) + ) + } + + def test_members_returns_cached_values(self, entity_factory_impl): + mock_member = object() + entity_factory_impl.deserialize_member = mock.Mock() + guild_definition = entity_factory_impl.deserialize_gateway_guild({"id": "92929292"}) + guild_definition._members = {"93939393": mock_member} + + assert guild_definition.members() == {"93939393": mock_member} + + entity_factory_impl.deserialize_member.assert_not_called() + + def test_presences(self, entity_factory_impl, member_presence_payload): + guild_definition = entity_factory_impl.deserialize_gateway_guild( + {"id": "265828729970753537", "presences": [member_presence_payload]} + ) + + assert guild_definition.presences() == { + 115590097100865541: entity_factory_impl.deserialize_member_presence( + member_presence_payload, guild_id=snowflakes.Snowflake(265828729970753537) + ) + } + + def test_presences_returns_cached_values(self, entity_factory_impl): + mock_presence = object() + entity_factory_impl.deserialize_member_presence = mock.Mock() + guild_definition = entity_factory_impl.deserialize_gateway_guild({"id": "29292992"}) + guild_definition._presences = {"3939393993": mock_presence} + + assert guild_definition.presences() == {"3939393993": mock_presence} + + entity_factory_impl.deserialize_member_presence.assert_not_called() + + def test_roles(self, entity_factory_impl, guild_role_payload): + guild_definition = entity_factory_impl.deserialize_gateway_guild( + {"id": "265828729970753537", "roles": [guild_role_payload]} + ) + + assert guild_definition.roles() == { + 41771983423143936: entity_factory_impl.deserialize_role( + guild_role_payload, guild_id=snowflakes.Snowflake(265828729970753537) + ) + } + + def test_roles_returns_cached_values(self, entity_factory_impl): + mock_role = object() + entity_factory_impl.deserialize_role = mock.Mock() + guild_definition = entity_factory_impl.deserialize_gateway_guild({"id": "9292929"}) + guild_definition._roles = {"32132123123": mock_role} + + assert guild_definition.roles() == {"32132123123": mock_role} + + entity_factory_impl.deserialize_role.assert_not_called() + + def test_voice_states(self, entity_factory_impl, member_payload, voice_state_payload): + guild_definition = entity_factory_impl.deserialize_gateway_guild( + {"id": "265828729970753537", "voice_states": [voice_state_payload], "members": [member_payload]} + ) + assert guild_definition.voice_states() == { + 115590097100865541: entity_factory_impl.deserialize_voice_state( + voice_state_payload, + guild_id=snowflakes.Snowflake(265828729970753537), + member=entity_factory_impl.deserialize_member( + member_payload, + guild_id=snowflakes.Snowflake(265828729970753537), + ), + ) + } + + def test_voice_states_returns_cached_values(self, entity_factory_impl): + mock_voice_state = object() + entity_factory_impl.deserialize_voice_state = mock.Mock() + guild_definition = entity_factory_impl.deserialize_gateway_guild({"id": "292929"}) + guild_definition._voice_states = {"9393939393": mock_voice_state} + + assert guild_definition.voice_states() == {"9393939393": mock_voice_state} + + entity_factory_impl.deserialize_voice_state.assert_not_called() + + +class TestEntityFactoryImpl: @pytest.fixture() def entity_factory_impl(self, mock_app) -> entity_factory.EntityFactoryImpl: return entity_factory.EntityFactoryImpl(app=mock_app) @@ -770,10 +1315,6 @@ def test_deserialize_channel_follow(self, entity_factory_impl, mock_app): assert follow.channel_id == 41231 assert follow.webhook_id == 939393 - @pytest.fixture() - def permission_overwrite_payload(self): - return {"id": "4242", "type": 1, "allow": 65, "deny": 49152, "allow_new": "65", "deny_new": "49152"} - @pytest.mark.parametrize("type", [0, 1]) def test_deserialize_permission_overwrite(self, entity_factory_impl, type): permission_overwrite_payload = { @@ -944,23 +1485,6 @@ def test_deserialize_guild_category_with_null_fields(self, entity_factory_impl, ) assert guild_category.parent_id is None - @pytest.fixture() - def guild_text_channel_payload(self, permission_overwrite_payload): - return { - "id": "123", - "guild_id": "567", - "name": "general", - "type": 0, - "position": 6, - "permission_overwrites": [permission_overwrite_payload], - "rate_limit_per_user": 2, - "nsfw": True, - "topic": "¯\\_(ツ)_/¯", - "last_message_id": "123456", - "last_pin_timestamp": "2020-05-27T15:58:51.545252+00:00", - "parent_id": "987", - } - def test_deserialize_guild_text_channel( self, entity_factory_impl, mock_app, guild_text_channel_payload, permission_overwrite_payload ): @@ -1024,22 +1548,6 @@ def test_deserialize_guild_text_channel_with_null_fields(self, entity_factory_im assert guild_text_channel.last_pin_timestamp is None assert guild_text_channel.parent_id is None - @pytest.fixture() - def guild_news_channel_payload(self, permission_overwrite_payload): - return { - "id": "7777", - "guild_id": "123", - "name": "Important Announcements", - "type": 5, - "position": 0, - "permission_overwrites": [permission_overwrite_payload], - "nsfw": True, - "topic": "Super Important Announcements", - "last_message_id": "456", - "parent_id": "654", - "last_pin_timestamp": "2020-05-27T15:58:51.545252+00:00", - } - def test_deserialize_guild_news_channel( self, entity_factory_impl, mock_app, guild_news_channel_payload, permission_overwrite_payload ): @@ -1158,23 +1666,6 @@ def test_deserialize_guild_store_channel_with_null_fields(self, entity_factory_i ) assert store_chanel.parent_id is None - @pytest.fixture() - def guild_voice_channel_payload(self, permission_overwrite_payload): - return { - "id": "555", - "guild_id": "789", - "name": "Secret Developer Discussions", - "type": 2, - "nsfw": True, - "position": 4, - "permission_overwrites": [permission_overwrite_payload], - "bitrate": 64000, - "user_limit": 3, - "rtc_region": "europe", - "parent_id": "456", - "video_quality_mode": 1, - } - def test_deserialize_guild_voice_channel( self, entity_factory_impl, mock_app, guild_voice_channel_payload, permission_overwrite_payload ): @@ -1685,10 +2176,6 @@ def test_deserialize_unicode_emoji(self, entity_factory_impl): assert emoji.name == "🤷" assert isinstance(emoji, emoji_models.UnicodeEmoji) - @pytest.fixture() - def custom_emoji_payload(self): - return {"id": "691225175349395456", "name": "test", "animated": True} - def test_deserialize_custom_emoji(self, entity_factory_impl, mock_app, custom_emoji_payload): emoji = entity_factory_impl.deserialize_custom_emoji(custom_emoji_payload) assert emoji.id == snowflakes.Snowflake(691225175349395456) @@ -1703,19 +2190,6 @@ def test_deserialize_custom_emoji_with_unset_and_null_fields( assert emoji.is_animated is False assert emoji.name is None - @pytest.fixture() - def known_custom_emoji_payload(self, user_payload): - return { - "id": "12345", - "name": "testing", - "animated": False, - "available": True, - "roles": ["123", "456"], - "user": user_payload, - "require_colons": True, - "managed": False, - } - def test_deserialize_known_custom_emoji( self, entity_factory_impl, mock_app, user_payload, known_custom_emoji_payload ): @@ -1877,21 +2351,6 @@ def test_serialize_welcome_channel_with_no_emoji(self, entity_factory_impl, mock assert result == {"channel_id": "4312312", "description": "meow2"} - @pytest.fixture() - def member_payload(self, user_payload): - return { - "nick": "foobarbaz", - "roles": ["11111", "22222", "33333", "44444"], - "joined_at": "2015-04-26T06:26:56.936000+00:00", - "premium_since": "2019-05-17T06:26:56.936000+00:00", - "avatar": "estrogen", - "deaf": False, - "mute": True, - "pending": False, - "user": user_payload, - "communication_disabled_until": "2021-10-18T06:26:56.936000+00:00", - } - def test_deserialize_member(self, entity_factory_impl, mock_app, member_payload, user_payload): member_payload = {**member_payload, "guild_id": "76543325"} member = entity_factory_impl.deserialize_member(member_payload) @@ -1984,26 +2443,6 @@ def test_deserialize_member_with_passed_through_user_object_and_guild_id(self, e assert member.user is mock_user assert member.guild_id == 64234 - @pytest.fixture() - def guild_role_payload(self): - return { - "id": "41771983423143936", - "name": "WE DEM BOYZZ!!!!!!", - "color": 3_447_003, - "hoist": True, - "unicode_emoji": "\N{OK HAND SIGN}", - "icon": "abc123hash", - "position": 0, - "permissions": "66321471", - "managed": False, - "mentionable": False, - "tags": { - "bot_id": "123", - "integration_id": "456", - "premium_subscriber": None, - }, - } - def test_deserialize_role(self, entity_factory_impl, mock_app, guild_role_payload): guild_role = entity_factory_impl.deserialize_role(guild_role_payload, guild_id=snowflakes.Snowflake(76534453)) assert guild_role.app is mock_app @@ -4606,42 +5045,6 @@ def test_deserialize_message_deserializes_old_stickers_field(self, entity_factor # PRESENCE MODELS # ################### - @pytest.fixture() - def presence_activity_payload(self, custom_emoji_payload): - return { - "name": "an activity", - "type": 1, - "url": "https://69.420.owouwunyaa", - "created_at": 1584996792798, - "timestamps": {"start": 1584996792798, "end": 1999999792798}, - "application_id": "40404040404040", - "details": "They are doing stuff", - "state": "STATED", - "emoji": custom_emoji_payload, - "party": {"id": "spotify:3234234234", "size": [2, 5]}, - "assets": { - "large_image": "34234234234243", - "large_text": "LARGE TEXT", - "small_image": "3939393", - "small_text": "small text", - }, - "secrets": {"join": "who's a good secret?", "spectate": "I'm a good secret", "match": "No."}, - "instance": True, - "flags": 3, - "buttons": ["owo", "no"], - } - - @pytest.fixture() - def member_presence_payload(self, user_payload, presence_activity_payload): - return { - "user": user_payload, - "activity": presence_activity_payload, - "guild_id": "44004040", - "status": "dnd", - "activities": [presence_activity_payload], - "client_status": {"desktop": "online", "mobile": "idle", "web": "dnd"}, - } - def test_deserialize_member_presence( self, entity_factory_impl, mock_app, member_presence_payload, custom_emoji_payload, user_payload ): @@ -5275,20 +5678,6 @@ def test_deserialize_template_with_null_fields(self, entity_factory_impl, templa # USER MODELS # ############### - @pytest.fixture() - def user_payload(self): - return { - "id": "115590097100865541", - "username": "nyaa", - "avatar": "b3b24c6d7cbcdec129d5d537067061a8", - "banner": "a_221313e1e2edsncsncsmcndsc", - "accent_color": 231321, - "discriminator": "6127", - "bot": True, - "system": True, - "public_flags": int(user_models.UserFlag.EARLY_VERIFIED_DEVELOPER), - } - def test_deserialize_user(self, entity_factory_impl, mock_app, user_payload): user = entity_factory_impl.deserialize_user(user_payload) assert user.app is mock_app @@ -5385,24 +5774,6 @@ def test_deserialize_my_user_with_unset_fields(self, entity_factory_impl, mock_a # VOICE MODELS # ################ - @pytest.fixture() - def voice_state_payload(self, member_payload): - return { - "guild_id": "929292929292992", - "channel_id": "157733188964188161", - "user_id": "115590097100865541", - "member": member_payload, - "session_id": "90326bd25d71d39b9ef95b299e3872ff", - "deaf": True, - "mute": True, - "self_deaf": False, - "self_mute": True, - "self_stream": True, - "self_video": True, - "suppress": False, - "request_to_speak_timestamp": "2021-04-17T10:11:19.970105+00:00", - } - def test_deserialize_voice_state_with_guild_id_in_payload( self, entity_factory_impl, mock_app, voice_state_payload, member_payload ): diff --git a/tests/hikari/impl/test_event_factory.py b/tests/hikari/impl/test_event_factory.py index 0fdf09a7c9..10d7e4a1ef 100644 --- a/tests/hikari/impl/test_event_factory.py +++ b/tests/hikari/impl/test_event_factory.py @@ -210,17 +210,24 @@ def test_deserialize_guild_available_event(self, event_factory, mock_app, mock_s mock_payload = mock.Mock(app=mock_app) event = event_factory.deserialize_guild_available_event(mock_shard, mock_payload) - mock_app.entity_factory.deserialize_gateway_guild.assert_called_once_with(mock_payload) assert isinstance(event, guild_events.GuildAvailableEvent) assert event.shard is mock_shard - assert event.guild is mock_app.entity_factory.deserialize_gateway_guild.return_value.guild - assert event.emojis is mock_app.entity_factory.deserialize_gateway_guild.return_value.emojis - assert event.roles is mock_app.entity_factory.deserialize_gateway_guild.return_value.roles - assert event.channels is mock_app.entity_factory.deserialize_gateway_guild.return_value.channels - assert event.members is mock_app.entity_factory.deserialize_gateway_guild.return_value.members - assert event.presences is mock_app.entity_factory.deserialize_gateway_guild.return_value.presences - assert event.voice_states is mock_app.entity_factory.deserialize_gateway_guild.return_value.voice_states + guild_definition = mock_app.entity_factory.deserialize_gateway_guild.return_value + assert event.guild is guild_definition.guild.return_value + assert event.emojis is guild_definition.emojis.return_value + assert event.roles is guild_definition.roles.return_value + assert event.channels is guild_definition.channels.return_value + assert event.members is guild_definition.members.return_value + assert event.presences is guild_definition.presences.return_value + assert event.voice_states is guild_definition.voice_states.return_value + guild_definition.guild.assert_called_once_with() + guild_definition.emojis.assert_called_once_with() + guild_definition.roles.assert_called_once_with() + guild_definition.channels.assert_called_once_with() + guild_definition.members.assert_called_once_with() + guild_definition.presences.assert_called_once_with() + guild_definition.voice_states.assert_called_once_with() def test_deserialize_guild_join_event(self, event_factory, mock_app, mock_shard): mock_payload = mock.Mock(app=mock_app) @@ -247,10 +254,14 @@ def test_deserialize_guild_update_event(self, event_factory, mock_app, mock_shar mock_app.entity_factory.deserialize_gateway_guild.assert_called_once_with(mock_payload) assert isinstance(event, guild_events.GuildUpdateEvent) assert event.shard is mock_shard - assert event.guild is mock_app.entity_factory.deserialize_gateway_guild.return_value.guild - assert event.emojis is mock_app.entity_factory.deserialize_gateway_guild.return_value.emojis - assert event.roles is mock_app.entity_factory.deserialize_gateway_guild.return_value.roles + guild_definition = mock_app.entity_factory.deserialize_gateway_guild.return_value + assert event.guild is guild_definition.guild.return_value + assert event.emojis is guild_definition.emojis.return_value + assert event.roles is guild_definition.roles.return_value assert event.old_guild is mock_old_guild + guild_definition.guild.assert_called_once_with() + guild_definition.emojis.assert_called_once_with() + guild_definition.roles.assert_called_once_with() def test_deserialize_guild_leave_event(self, event_factory, mock_app, mock_shard): mock_payload = {"id": "43123123"} diff --git a/tests/hikari/impl/test_event_manager.py b/tests/hikari/impl/test_event_manager.py index 46ca5a5b49..db06762050 100644 --- a/tests/hikari/impl/test_event_manager.py +++ b/tests/hikari/impl/test_event_manager.py @@ -29,10 +29,13 @@ import pytest from hikari import channels +from hikari import config from hikari import errors from hikari import intents from hikari import presences from hikari.api import event_factory as event_factory_ +from hikari.events import guild_events +from hikari.events import shard_events from hikari.impl import event_manager as event_manager_ from hikari.internal import time from tests.hikari import hikari_test_helpers @@ -86,23 +89,27 @@ async def test__request_guild_members_handles_state_conflict_error(shard): class TestEventManagerImpl: + @pytest.fixture() + def entity_factory(self): + return mock.Mock() + @pytest.fixture() def event_factory(self): return mock.Mock() @pytest.fixture() - def event_manager(self, event_factory): + def event_manager(self, entity_factory, event_factory): obj = hikari_test_helpers.mock_class_namespace(event_manager_.EventManagerImpl, slots_=False)( - event_factory, intents.Intents.ALL, cache=mock.Mock() + entity_factory, event_factory, intents.Intents.ALL, cache=mock.Mock(settings=config.CacheSettings()) ) obj.dispatch = mock.AsyncMock() return obj @pytest.fixture() - def stateless_event_manager(self, event_factory): + def stateless_event_manager(self, event_factory, entity_factory): obj = hikari_test_helpers.mock_class_namespace(event_manager_.EventManagerImpl, slots_=False)( - event_factory, intents.Intents.ALL, cache=None + entity_factory, event_factory, intents.Intents.ALL, cache=None ) obj.dispatch = mock.AsyncMock() @@ -273,8 +280,7 @@ async def test_on_guild_create_stateful_with_unavailable_field(self, event_manag event_factory.deserialize_guild_available_event.assert_called_once_with(shard, payload) event_manager.dispatch.assert_awaited_once_with(event) - @pytest.mark.asyncio() - async def test_on_guild_create_stateful_without_unavailable_field(self, event_manager, shard, event_factory): + async def test_on_guild_create_stateful_and_dispatching(self, event_manager, shard, event_factory): payload = {} event = mock.Mock( guild=mock.Mock(id=123, is_large=False), @@ -287,6 +293,8 @@ async def test_on_guild_create_stateful_without_unavailable_field(self, event_ma chunk_nonce=None, ) + event_factory.deserialize_guild_join_event.return_value = event + event_manager._enabled_for_event = mock.Mock(return_value=True) event_factory.deserialize_guild_join_event.return_value = event shard.request_guild_members = mock.AsyncMock() @@ -294,6 +302,9 @@ async def test_on_guild_create_stateful_without_unavailable_field(self, event_ma assert event.chunk_nonce is None shard.request_guild_members.assert_not_called() + event_manager._enabled_for_event.assert_has_calls( + [mock.call(guild_events.GuildAvailableEvent), mock.call(shard_events.MemberChunkEvent)] + ) event_manager._cache.update_guild.assert_called_once_with(event.guild) @@ -315,6 +326,8 @@ async def test_on_guild_create_stateful_without_unavailable_field(self, event_ma event_manager._cache.clear_voice_states_for_guild.assert_called_once_with(123) event_manager._cache.set_voice_state.assert_called_once_with(345) + event_factory.deserialize_guild_join_event.assert_called_once_with(shard, payload) + event_factory.deserialize_gateway_guild.assert_not_called() event_factory.deserialize_guild_join_event.assert_called_once_with(shard, payload) event_manager.dispatch.assert_awaited_once_with(event) @@ -379,10 +392,105 @@ async def test_on_guild_create_when_request_chunks_with_unavailable_field( event_manager.dispatch.assert_awaited_once_with(event) @pytest.mark.asyncio() - async def test_on_guild_create_when_request_chunks_without_unavailable_field( + async def test_on_guild_create_stateful_and_not_dispatching_with_all_cache_components( + self, event_manager, shard, entity_factory, event_factory + ): + payload = {"id": "123"} + mock_channel = object() + mock_emoji = object() + mock_role = object() + mock_member = object() + mock_presence = object() + mock_voice_state = object() + guild_definition = entity_factory.deserialize_gateway_guild.return_value + guild_definition.id = 123 + guild_definition.guild.return_value = mock.Mock(id=123, is_large=False) + guild_definition.channels.return_value = {456: mock_channel} + guild_definition.emojis.return_value = {789: mock_emoji} + guild_definition.roles.return_value = {1234: mock_role} + guild_definition.members.return_value = {5678: mock_member} + guild_definition.presences.return_value = {9012: mock_presence} + guild_definition.voice_states.return_value = {345: mock_voice_state} + + event_manager._enabled_for_event = mock.Mock(return_value=False) + shard.request_guild_members = mock.AsyncMock() + + await event_manager.on_guild_create(shard, payload) + + shard.request_guild_members.assert_not_called() + event_manager._enabled_for_event.assert_has_calls( + [mock.call(guild_events.GuildAvailableEvent), mock.call(shard_events.MemberChunkEvent)] + ) + + event_manager._cache.update_guild.assert_called_once_with(guild_definition.guild.return_value) + + event_manager._cache.clear_guild_channels_for_guild.assert_called_once_with(123) + event_manager._cache.set_guild_channel.assert_called_once_with(mock_channel) + + event_manager._cache.clear_emojis_for_guild.assert_called_once_with(123) + event_manager._cache.set_emoji.assert_called_once_with(mock_emoji) + + event_manager._cache.clear_roles_for_guild.assert_called_once_with(123) + event_manager._cache.set_role.assert_called_once_with(mock_role) + + event_manager._cache.clear_members_for_guild.assert_called_once_with(123) + event_manager._cache.set_member.assert_called_once_with(mock_member) + + event_manager._cache.clear_presences_for_guild.assert_called_once_with(123) + event_manager._cache.set_presence.assert_called_once_with(mock_presence) + + event_manager._cache.clear_voice_states_for_guild.assert_called_once_with(123) + event_manager._cache.set_voice_state.assert_called_once_with(mock_voice_state) + + entity_factory.deserialize_gateway_guild.assert_called_once_with(payload) + event_factory.deserialize_guild_create_event.assert_not_called() + event_manager.dispatch.assert_not_called() + + @pytest.mark.asyncio() + async def test_on_guild_create_stateful_and_not_dispatching_with_no_cache_components( + self, event_manager, shard, event_factory, entity_factory + ): + payload = {"id": "123"} + event_manager._cache_enabled_for = mock.Mock(return_value=False) + event_manager._enabled_for_event = mock.Mock(return_value=False) + shard.request_guild_members = mock.AsyncMock() + + await event_manager.on_guild_create(shard, payload) + + shard.request_guild_members.assert_not_called() + event_manager._enabled_for_event.assert_has_calls( + [mock.call(guild_events.GuildAvailableEvent), mock.call(shard_events.MemberChunkEvent)] + ) + + event_manager._cache.update_guild.assert_not_called() + + event_manager._cache.clear_guild_channels_for_guild.assert_not_called() + event_manager._cache.set_guild_channel.assert_not_called() + + event_manager._cache.clear_emojis_for_guild.assert_not_called() + event_manager._cache.set_emoji.assert_not_called() + + event_manager._cache.clear_roles_for_guild.assert_not_called() + event_manager._cache.set_role.assert_not_called() + + event_manager._cache.clear_members_for_guild.assert_not_called() + event_manager._cache.set_member.assert_not_called() + + event_manager._cache.clear_presences_for_guild.assert_not_called() + event_manager._cache.set_presence.assert_not_called() + + event_manager._cache.clear_voice_states_for_guild.assert_not_called() + event_manager._cache.set_voice_state.assert_not_called() + + entity_factory.deserialize_gateway_guild.assert_called_once_with(payload) + event_factory.deserialize_guild_create_event.assert_not_called() + event_manager.dispatch.assert_not_called() + + @pytest.mark.asyncio() + async def test_on_guild_create_when_request_chunks_when_dispatching_available_event( self, event_manager, shard, event_factory ): - payload = {} + payload = {"large": True, "id": 123} event = mock.Mock( guild=mock.Mock(id=123, is_large=True), channels={"TestChannel": 456}, @@ -394,7 +502,9 @@ async def test_on_guild_create_when_request_chunks_without_unavailable_field( chunk_nonce=None, ) + event_manager._enabled_for_event = mock.Mock(return_value=True) event_factory.deserialize_guild_join_event.return_value = event + event_manager._cache.settings.components = config.CacheComponents.MEMBERS shard.request_guild_members = mock.Mock() stack = contextlib.ExitStack() @@ -407,36 +517,54 @@ async def test_on_guild_create_when_request_chunks_without_unavailable_field( with stack: await event_manager.on_guild_create(shard, payload) + event_manager._enabled_for_event.assert_has_calls( + [mock.call(guild_events.GuildAvailableEvent), mock.call(shard_events.MemberChunkEvent)] + ) uuid.assert_called_once_with() nonce = "987.uuid" assert event.chunk_nonce == nonce - _request_guild_members.assert_called_once_with(shard, event.guild, include_presences=True, nonce=nonce) + _request_guild_members.assert_called_once_with(shard, 123, include_presences=True, nonce=nonce) create_task.assert_called_once_with( _request_guild_members.return_value, name="987:123 guild create members request" ) - event_manager._cache.update_guild.assert_called_once_with(event.guild) + event_factory.deserialize_guild_create_event.assert_called_once_with(shard, payload) - event_manager._cache.clear_guild_channels_for_guild.assert_called_once_with(123) - event_manager._cache.set_guild_channel.assert_called_once_with(456) + @pytest.mark.asyncio() + async def test_on_guild_create_when_request_chunks_when_not_dispatching_available_event( + self, stateless_event_manager, shard, entity_factory, event_factory, event_manager + ): + payload = {"large": True, "id": 123} - event_manager._cache.clear_emojis_for_guild.assert_called_once_with(123) - event_manager._cache.set_emoji.assert_called_once_with(789) + stateless_event_manager._enabled_for_event = mock.Mock(side_effect=[False, True]) + entity_factory.deserialize_gateway_guild.return_value.id = 123 - event_manager._cache.clear_roles_for_guild.assert_called_once_with(123) - event_manager._cache.set_role.assert_called_once_with(1234) + stack = contextlib.ExitStack() + _request_guild_members = stack.enter_context( + mock.patch("hikari.impl.event_manager._request_guild_members", new_callable=mock.Mock) + ) + create_task = stack.enter_context(mock.patch.object(asyncio, "create_task")) + uuid = stack.enter_context(mock.patch("hikari.impl.event_manager._fixed_size_nonce", return_value="uuid")) - event_manager._cache.clear_members_for_guild.assert_called_once_with(123) - event_manager._cache.set_member.assert_called_once_with(5678) + with stack: + await stateless_event_manager.on_guild_create(shard, payload) - event_manager._cache.clear_presences_for_guild.assert_called_once_with(123) - event_manager._cache.set_presence.assert_called_once_with(9012) + stateless_event_manager._enabled_for_event.assert_has_calls( + [mock.call(guild_events.GuildAvailableEvent), mock.call(shard_events.MemberChunkEvent)] + ) + uuid.assert_called_once_with() + _request_guild_members.assert_called_once_with(shard, 123, include_presences=True, nonce="987.uuid") + create_task.assert_called_once_with( + _request_guild_members.return_value, name="987:123 guild create members request" + ) + + event_factory.deserialize_guild_join_event.assert_not_called() + entity_factory.deserialize_gateway_guild.assert_not_called() event_manager._cache.clear_voice_states_for_guild.assert_called_once_with(123) event_manager._cache.set_voice_state.assert_called_once_with(345) event_factory.deserialize_guild_join_event.assert_called_once_with(shard, payload) - event_manager.dispatch.assert_awaited_once_with(event) @pytest.mark.asyncio() async def test_on_guild_create_stateless_with_unavailable_field( @@ -453,27 +581,52 @@ async def test_on_guild_create_stateless_with_unavailable_field( event_factory.deserialize_guild_available_event.return_value ) - @pytest.mark.asyncio() - async def test_on_guild_create_stateless_without_unavailable_field( - self, stateless_event_manager, shard, event_factory + async def test_on_guild_create_stateless_and_dispatching( + self, stateless_event_manager, shard, event_factory, entity_factory ): - payload = {} + payload = {"id": "123123"} + stateless_event_manager._enabled_for_event = mock.Mock(return_value=True) shard.request_guild_members = mock.AsyncMock() await stateless_event_manager.on_guild_create(shard, payload) + shard.request_guild_members.assert_not_called() + stateless_event_manager._enabled_for_event.assert_has_calls( + [mock.call(guild_events.GuildAvailableEvent), mock.call(shard_events.MemberChunkEvent)] + ) + entity_factory.deserialize_gateway_guild.assert_not_called() event_factory.deserialize_guild_join_event.assert_called_once_with(shard, payload) stateless_event_manager.dispatch.assert_awaited_once_with( event_factory.deserialize_guild_join_event.return_value ) @pytest.mark.asyncio() - async def test_on_guild_update_stateful(self, event_manager, shard, event_factory): + async def test_on_guild_create_stateless_and_not_dispatching( + self, stateless_event_manager, shard, event_factory, entity_factory + ): + payload = {"id": "123123"} + stateless_event_manager._enabled_for_event = mock.Mock(return_value=False) + + shard.request_guild_members = mock.AsyncMock() + + await stateless_event_manager.on_guild_create(shard, payload) + + shard.request_guild_members.assert_not_called() + stateless_event_manager._enabled_for_event.assert_has_calls( + [mock.call(guild_events.GuildAvailableEvent), mock.call(shard_events.MemberChunkEvent)] + ) + entity_factory.deserialize_gateway_guild.assert_not_called() + event_factory.deserialize_guild_create_event.assert_not_called() + stateless_event_manager.dispatch.assert_not_called() + + @pytest.mark.asyncio() + async def test_on_guild_update_stateful_and_dispatching(self, event_manager, shard, event_factory, entity_factory): payload = {"id": 123} old_guild = object() mock_role = object() mock_emoji = object() + event_manager._enabled_for_event = mock.Mock(return_value=True) event = mock.Mock(roles={555: mock_role}, emojis={333: mock_emoji}, guild=mock.Mock(id=123)) event_factory.deserialize_guild_update_event.return_value = event @@ -481,26 +634,98 @@ async def test_on_guild_update_stateful(self, event_manager, shard, event_factor await event_manager.on_guild_update(shard, payload) + event_manager._enabled_for_event.assert_called_once_with(guild_events.GuildUpdateEvent) event_manager._cache.get_guild.assert_called_once_with(123) event_manager._cache.update_guild.assert_called_once_with(event.guild) event_manager._cache.clear_roles_for_guild.assert_called_once_with(123) event_manager._cache.set_role.assert_called_once_with(mock_role) event_manager._cache.clear_emojis_for_guild.assert_called_once_with(123) event_manager._cache.set_emoji.assert_called_once_with(mock_emoji) + entity_factory.deserialize_gateway_guild.assert_not_called() event_factory.deserialize_guild_update_event.assert_called_once_with(shard, payload, old_guild=old_guild) event_manager.dispatch.assert_awaited_once_with(event) @pytest.mark.asyncio() - async def test_on_guild_update_stateless(self, stateless_event_manager, shard, event_factory): + async def test_on_guild_update_all_cache_components_and_not_dispatching( + self, event_manager, shard, event_factory, entity_factory + ): + payload = {"id": 123} + mock_role = object() + mock_emoji = object() + event_manager._enabled_for_event = mock.Mock(return_value=False) + guild_definition = entity_factory.deserialize_gateway_guild.return_value + guild_definition.id = 123 + guild_definition.emojis.return_value = {0: mock_emoji} + guild_definition.roles.return_value = {1: mock_role} + + await event_manager.on_guild_update(shard, payload) + + entity_factory.deserialize_gateway_guild.assert_called_once_with({"id": 123}) + event_manager._enabled_for_event.assert_called_once_with(guild_events.GuildUpdateEvent) + event_manager._cache.update_guild.assert_called_once_with(guild_definition.guild.return_value) + event_manager._cache.clear_emojis_for_guild.assert_called_once_with(123) + event_manager._cache.set_emoji.assert_called_once_with(mock_emoji) + event_manager._cache.clear_roles_for_guild.assert_called_once_with(123) + event_manager._cache.set_role.assert_called_once_with(mock_role) + event_factory.deserialize_guild_update_event.assert_not_called() + event_manager.dispatch.assert_not_called() + guild_definition.emojis.assert_called_once_with() + guild_definition.roles.assert_called_once_with() + guild_definition.guild.assert_called_once_with() + + @pytest.mark.asyncio() + async def test_on_guild_update_no_cache_components_and_not_dispatching( + self, event_manager, shard, event_factory, entity_factory + ): + payload = {"id": 123} + event_manager._cache_enabled_for = mock.Mock(return_value=False) + event_manager._enabled_for_event = mock.Mock(return_value=False) + guild_definition = entity_factory.deserialize_gateway_guild.return_value + + await event_manager.on_guild_update(shard, payload) + + entity_factory.deserialize_gateway_guild.assert_called_once_with({"id": 123}) + event_manager._enabled_for_event.assert_called_once_with(guild_events.GuildUpdateEvent) + event_manager._cache.update_guild.assert_not_called() + event_manager._cache.clear_emojis_for_guild.assert_not_called() + event_manager._cache.set_emoji.assert_not_called() + event_manager._cache.clear_roles_for_guild.assert_not_called() + event_manager._cache.set_role.assert_not_called() + event_factory.deserialize_guild_update_event.assert_not_called() + event_manager.dispatch.assert_not_called() + guild_definition.emojis.assert_not_called() + guild_definition.roles.assert_not_called() + guild_definition.guild.assert_not_called() + + @pytest.mark.asyncio() + async def test_on_guild_update_stateless_and_dispatching( + self, stateless_event_manager, shard, event_factory, entity_factory + ): payload = {"id": 123} + stateless_event_manager._enabled_for_event = mock.Mock(return_value=True) await stateless_event_manager.on_guild_update(shard, payload) + stateless_event_manager._enabled_for_event.assert_called_once_with(guild_events.GuildUpdateEvent) + entity_factory.deserialize_gateway_guild.assert_not_called() event_factory.deserialize_guild_update_event.assert_called_once_with(shard, payload, old_guild=None) stateless_event_manager.dispatch.assert_awaited_once_with( event_factory.deserialize_guild_update_event.return_value ) + @pytest.mark.asyncio() + async def test_on_guild_update_stateless_and_not_dispatching( + self, stateless_event_manager, shard, entity_factory, event_factory + ): + stateless_event_manager._enabled_for_event = mock.Mock(return_value=False) + + await stateless_event_manager.on_guild_update(shard, {"id": 123}) + + stateless_event_manager._enabled_for_event.assert_called_once_with(guild_events.GuildUpdateEvent) + entity_factory.deserialize_gateway_guild.assert_not_called() + event_factory.deserialize_guild_update_event.assert_not_called() + stateless_event_manager.dispatch.assert_not_called() + @pytest.mark.asyncio() async def test_on_guild_delete_stateful_when_available(self, event_manager, shard, event_factory): payload = {"unavailable": False, "id": "123"} @@ -612,7 +837,8 @@ async def test_on_guild_emojis_update_stateless(self, stateless_event_manager, s @pytest.mark.asyncio() async def test_on_guild_integrations_update(self, event_manager, shard): - assert await event_manager.on_guild_integrations_update(shard, {}) is None + with pytest.raises(NotImplementedError): + await event_manager.on_guild_integrations_update(shard, {}) event_manager.dispatch.assert_not_called() diff --git a/tests/hikari/impl/test_event_manager_base.py b/tests/hikari/impl/test_event_manager_base.py index f44d19d94a..361b8aef99 100644 --- a/tests/hikari/impl/test_event_manager_base.py +++ b/tests/hikari/impl/test_event_manager_base.py @@ -29,11 +29,14 @@ import mock import pytest +from hikari import config from hikari import errors from hikari import intents from hikari import iterators +from hikari import undefined from hikari.events import base_events from hikari.events import member_events +from hikari.events import shard_events from hikari.impl import event_manager_base from hikari.internal import reflect from tests.hikari import hikari_test_helpers @@ -398,20 +401,160 @@ class EventManagerBaseImpl(event_manager_base.EventManagerBase): def test___init___loads_consumers(self): class StubManager(event_manager_base.EventManagerBase): + @event_manager_base.filtered(shard_events.ShardEvent, config.CacheComponents.MEMBERS) async def on_foo(self, event): raise NotImplementedError + @event_manager_base.filtered((shard_events.ShardStateEvent, shard_events.ShardPayloadEvent)) async def on_bar(self, event): raise NotImplementedError + @event_manager_base.filtered(shard_events.MemberChunkEvent, config.CacheComponents.MESSAGES) + async def on_bat(self, event): + raise NotImplementedError + + async def on_not_decorated(self, event): + raise NotImplementedError + async def not_a_listener(self): raise NotImplementedError - manager = StubManager(mock.Mock(), mock.Mock(intents=42)) - assert manager._consumers == {"foo": manager.on_foo, "bar": manager.on_bar} + expected_bar_events = ( + shard_events.ShardStateEvent, + shard_events.ShardEvent, + base_events.Event, + shard_events.ShardPayloadEvent, + ) + expected_bat_events = (shard_events.MemberChunkEvent, shard_events.ShardEvent, base_events.Event) + manager = StubManager( + mock.Mock(), + 0, + cache_components=config.CacheComponents.MEMBERS | config.CacheComponents.GUILD_CHANNELS, + ) + assert manager._consumers == { + "foo": event_manager_base._Consumer(manager.on_foo, (shard_events.ShardEvent, base_events.Event), True), + "bar": event_manager_base._Consumer(manager.on_bar, expected_bar_events, False), + "bat": event_manager_base._Consumer(manager.on_bat, expected_bat_events, False), + "not_decorated": event_manager_base._Consumer(manager.on_not_decorated, undefined.UNDEFINED, True), + } + + def test___init___loads_consumers_when_cacheless(self): + class StubManager(event_manager_base.EventManagerBase): + @event_manager_base.filtered(shard_events.ShardEvent, config.CacheComponents.MEMBERS) + async def on_foo(self, event): + raise NotImplementedError + + @event_manager_base.filtered((shard_events.ShardStateEvent, shard_events.ShardPayloadEvent)) + async def on_bar(self, event): + raise NotImplementedError + + @event_manager_base.filtered(shard_events.MemberChunkEvent, config.CacheComponents.MESSAGES) + async def on_bat(self, event): + raise NotImplementedError + + async def on_not_decorated(self, event): + raise NotImplementedError + + async def not_a_listener(self): + raise NotImplementedError + + expected_bar_events = ( + shard_events.ShardStateEvent, + shard_events.ShardEvent, + base_events.Event, + shard_events.ShardPayloadEvent, + ) + expected_bat_events = (shard_events.MemberChunkEvent, shard_events.ShardEvent, base_events.Event) + manager = StubManager(mock.Mock(), 0, cache_components=config.CacheComponents.NONE) + assert manager._consumers == { + "foo": event_manager_base._Consumer(manager.on_foo, (shard_events.ShardEvent, base_events.Event), False), + "bar": event_manager_base._Consumer(manager.on_bar, expected_bar_events, False), + "bat": event_manager_base._Consumer(manager.on_bat, expected_bat_events, False), + "not_decorated": event_manager_base._Consumer(manager.on_not_decorated, undefined.UNDEFINED, False), + } + + def test__clear_enabled_cache(self): + event_manager = hikari_test_helpers.mock_class_namespace(event_manager_base.EventManagerBase, init_=False)() + event_manager._enabled_consumers_cache = {object: object(), "ok": object()} + + event_manager._clear_enabled_cache() + + assert event_manager._enabled_consumers_cache == {} + + def test__enabled_for_event_when_listener_registered(self, event_manager): + event_manager._listeners = {} + + def test__enabled_for_event_when_waiter_registered(self, event_manager): + event_manager._listeners = {} + + def test__enabled_for_event_when_not_registered(self, event_manager): + event_manager._listeners = {shard_events.ShardPayloadEvent: [], shard_events.MemberChunkEvent: []} + + assert event_manager._enabled_for_event(shard_events.ShardStateEvent) is False + + def test__enabled_for_consumer_when_event_types_is_undefined(self, event_manager): + consumer = mock.Mock(event_types=undefined.UNDEFINED) + + assert event_manager._enabled_for_consumer(consumer) is True + + def test__enabled_for_consumer_when_caching(self, event_manager): + consumer = mock.Mock(event_types=(), is_caching=True) + + assert event_manager._enabled_for_consumer(consumer) is True + + @pytest.mark.parametrize("cached_state", [False, True]) + def test__enabled_for_consumer_when_consumer_state_cached(self, event_manager, cached_state): + consumer = mock.Mock(event_types=(), is_caching=False) + event_manager._enabled_consumers_cache[consumer] = cached_state + + assert event_manager._enabled_for_consumer(consumer) is cached_state + + def test__enabled_for_consumer_when_consumer_state_not_cached_and_listeners_present(self, event_manager): + event_manager._listeners[shard_events.MemberChunkEvent] = [] + consumer = mock.Mock( + event_types=(shard_events.ShardEvent, shard_events.ShardPayloadEvent, shard_events.MemberChunkEvent), + is_caching=False, + ) + + result = event_manager._enabled_for_consumer(consumer) + + assert result is True + assert event_manager._enabled_consumers_cache[consumer] is True + + def test__enabled_for_consumer_when_consumer_state_not_cached_and_waiters_present(self, event_manager): + event_manager._waiters[shard_events.MemberChunkEvent] = [] + consumer = mock.Mock( + event_types=(shard_events.ShardEvent, shard_events.ShardPayloadEvent, shard_events.MemberChunkEvent), + is_caching=False, + ) + + result = event_manager._enabled_for_consumer(consumer) + + assert result is True + assert event_manager._enabled_consumers_cache[consumer] is True + + def test__enabled_for_consumer_when_consumer_state_not_cached_and_not_enabled(self, event_manager): + consumer = mock.Mock( + event_types=(shard_events.ShardEvent, shard_events.ShardPayloadEvent, shard_events.MemberChunkEvent), + is_caching=False, + ) + + result = event_manager._enabled_for_consumer(consumer) + + assert result is False + assert event_manager._enabled_consumers_cache[consumer] is False + + def test__enabled_for_consumer_when_consumer_state_not_cached_and_no_event_types(self, event_manager): + consumer = mock.Mock(event_types=(), is_caching=False) + + result = event_manager._enabled_for_consumer(consumer) + + assert result is False + assert event_manager._enabled_consumers_cache[consumer] is False @pytest.mark.asyncio() async def test_consume_raw_event_when_KeyError(self, event_manager): + event_manager._enabled_for_event = mock.Mock(return_value=True) mock_payload = {"id": "3123123123"} mock_shard = mock.Mock(id=123) event_manager._handle_dispatch = mock.Mock() @@ -427,9 +570,11 @@ async def test_consume_raw_event_when_KeyError(self, event_manager): event_manager._event_factory.deserialize_shard_payload_event.assert_called_once_with( mock_shard, mock_payload, name="UNEXISTING_EVENT" ) + event_manager._enabled_for_event.assert_called_once_with(shard_events.ShardPayloadEvent) @pytest.mark.asyncio() async def test_consume_raw_event_when_found(self, event_manager): + event_manager._enabled_for_event = mock.Mock(return_value=True) event_manager._handle_dispatch = mock.Mock() event_manager.dispatch = mock.Mock() on_existing_event = object() @@ -451,43 +596,69 @@ async def test_consume_raw_event_when_found(self, event_manager): event_manager._event_factory.deserialize_shard_payload_event.assert_called_once_with( shard, payload, name="EXISTING_EVENT" ) + event_manager._enabled_for_event.assert_called_once_with(shard_events.ShardPayloadEvent) + + @pytest.mark.asyncio() + async def test_consume_raw_event_skips_raw_dispatch_when_not_enabled(self, event_manager): + event_manager._enabled_for_event = mock.Mock(return_value=False) + event_manager._handle_dispatch = mock.Mock() + event_manager.dispatch = mock.Mock() + on_existing_event = object() + event_manager._consumers = {"existing_event": on_existing_event} + shard = object() + payload = {"berp": "baz"} + + with mock.patch("asyncio.create_task") as create_task: + event_manager.consume_raw_event("EXISTING_EVENT", shard, payload) + + event_manager._handle_dispatch.assert_called_once_with(on_existing_event, shard, {"berp": "baz"}) + create_task.assert_called_once_with( + event_manager._handle_dispatch(on_existing_event, shard, {"berp": "baz"}), + name="dispatch EXISTING_EVENT", + ) + event_manager.dispatch.assert_not_called() + event_manager._event_factory.deserialize_shard_payload_event.vassert_not_called() + event_manager._enabled_for_event.assert_called_once_with(shard_events.ShardPayloadEvent) @pytest.mark.asyncio() async def test_handle_dispatch_invokes_callback(self, event_manager, event_loop): - callback = mock.AsyncMock() + event_manager._enabled_for_consumer = mock.Mock(return_value=True) + consumer = mock.AsyncMock() error_handler = mock.MagicMock() event_loop.set_exception_handler(error_handler) shard = object() pl = {"foo": "bar"} - await event_manager._handle_dispatch(callback, shard, pl) + await event_manager._handle_dispatch(consumer, shard, pl) - callback.assert_awaited_once_with(shard, pl) + consumer.callback.assert_awaited_once_with(shard, pl) error_handler.assert_not_called() @pytest.mark.asyncio() async def test_handle_dispatch_ignores_cancelled_errors(self, event_manager, event_loop): - callback = mock.AsyncMock(side_effect=asyncio.CancelledError) + event_manager._enabled_for_consumer = mock.Mock(return_value=True) + consumer = mock.AsyncMock(side_effect=asyncio.CancelledError) error_handler = mock.MagicMock() event_loop.set_exception_handler(error_handler) shard = object() pl = {"lorem": "ipsum"} - await event_manager._handle_dispatch(callback, shard, pl) + await event_manager._handle_dispatch(consumer, shard, pl) error_handler.assert_not_called() @pytest.mark.asyncio() async def test_handle_dispatch_handles_exceptions(self, event_manager, event_loop): + event_manager._enabled_for_consumer = mock.Mock(return_value=True) exc = Exception("aaaa!") - callback = mock.AsyncMock(side_effect=exc) + consumer = mock.Mock(callback=mock.AsyncMock(side_effect=exc)) error_handler = mock.MagicMock() event_loop.set_exception_handler(error_handler) shard = object() pl = {"i like": "cats"} with mock.patch.object(asyncio, "current_task") as current_task: - await event_manager._handle_dispatch(callback, shard, pl) + await event_manager._handle_dispatch(consumer, shard, pl) error_handler.assert_called_once_with( event_loop, @@ -498,6 +669,20 @@ async def test_handle_dispatch_handles_exceptions(self, event_manager, event_loo }, ) + @pytest.mark.asyncio() + async def test_handle_dispatch_invokes_when_consumer_not_enabled(self, event_manager, event_loop): + event_manager._enabled_for_consumer = mock.Mock(return_value=False) + consumer = mock.Mock(callback=mock.AsyncMock(__name__="ok")) + error_handler = mock.MagicMock() + event_loop.set_exception_handler(error_handler) + shard = object() + pl = {"foo": "bar"} + + await event_manager._handle_dispatch(consumer, shard, pl) + + consumer.callback.assert_not_called() + error_handler.assert_not_called() + def test_subscribe_when_callback_is_not_coroutine(self, event_manager): def test(): ... @@ -513,6 +698,8 @@ async def test(): event_manager.subscribe("test", test) def test_subscribe_when_event_type_not_in_listeners(self, event_manager): + event_manager._clear_enabled_cache = mock.Mock() + async def test(): ... @@ -521,6 +708,7 @@ async def test(): assert event_manager._listeners == {member_events.MemberCreateEvent: [test]} check.assert_called_once_with(member_events.MemberCreateEvent, 1) + event_manager._clear_enabled_cache.assert_called_once_with() def test_subscribe_when_event_type_in_listeners(self, event_manager): async def test(): @@ -529,6 +717,7 @@ async def test(): async def test2(): ... + event_manager._clear_enabled_cache = mock.Mock() event_manager._listeners[member_events.MemberCreateEvent] = [test2] with mock.patch.object(event_manager_base.EventManagerBase, "_check_intents") as check: @@ -536,6 +725,7 @@ async def test2(): assert event_manager._listeners == {member_events.MemberCreateEvent: [test2, test]} check.assert_called_once_with(member_events.MemberCreateEvent, 2) + event_manager._clear_enabled_cache.assert_not_called() def test__check_intents_when_no_intents_required(self, event_manager): event_manager._intents = intents.Intents.ALL @@ -579,7 +769,7 @@ def test__check_intents_when_intents_incorrect(self, event_manager): def test_get_listeners_when_not_event(self, event_manager): assert len(event_manager.get_listeners("test")) == 0 - def test_get_listeners_polimorphic(self, event_manager): + def test_get_listeners_polymorphic(self, event_manager): event_manager._listeners = { base_events.Event: ["this will never appear"], member_events.MemberEvent: ["coroutine0"], @@ -597,17 +787,16 @@ def test_get_listeners_polimorphic(self, event_manager): "coroutine5", ] - def test_get_listeners_no_polimorphic_and_no_results(self, event_manager): + def test_get_listeners_monomorphic_and_no_results(self, event_manager): event_manager._listeners = { member_events.MemberCreateEvent: ["coroutine1", "coroutine2"], member_events.MemberUpdateEvent: ["coroutine3"], member_events.MemberDeleteEvent: ["coroutine4", "coroutine5"], - member_events.MemberDeleteEvent: ["coroutine4", "coroutine5"], } assert event_manager.get_listeners(member_events.MemberEvent, polymorphic=False) == [] - def test_get_listeners_no_polimorphic_and_results(self, event_manager): + def test_get_listeners_monomorphic_and_results(self, event_manager): event_manager._listeners = { member_events.MemberEvent: ["coroutine0"], member_events.MemberCreateEvent: ["coroutine1", "coroutine2"], @@ -634,6 +823,7 @@ async def test(): async def test2(): ... + event_manager._clear_enabled_cache = mock.Mock() event_manager._listeners = { member_events.MemberCreateEvent: [test, test2], member_events.MemberDeleteEvent: [test], @@ -645,16 +835,19 @@ async def test2(): member_events.MemberCreateEvent: [test2], member_events.MemberDeleteEvent: [test], } + event_manager._clear_enabled_cache.assert_not_called() def test_unsubscribe_when_event_type_when_list_empty_after_delete(self, event_manager): async def test(): ... + event_manager._clear_enabled_cache = mock.Mock() event_manager._listeners = {member_events.MemberCreateEvent: [test], member_events.MemberDeleteEvent: [test]} event_manager.unsubscribe(member_events.MemberCreateEvent, test) assert event_manager._listeners == {member_events.MemberDeleteEvent: [test]} + event_manager._clear_enabled_cache.assert_called_once_with() def test_listen_when_no_params(self, event_manager): with pytest.raises(TypeError):