diff --git a/discord/bot.py b/discord/bot.py index a7f7bf564d..37ff681b38 100644 --- a/discord/bot.py +++ b/discord/bot.py @@ -28,50 +28,42 @@ import asyncio import collections import inspect +import sys import traceback -from .commands.errors import CheckFailure - -from typing import ( - Any, - Callable, - Coroutine, - List, - Optional, - Type, - TypeVar, - Union, -) +from typing import Any, Callable, Coroutine, List, Optional, Type, TypeVar, Union -import sys +import discord from .client import Client -from .shard import AutoShardedClient -from .utils import MISSING, get, find, async_all +from .cog import CogMixin from .commands import ( - SlashCommand, - SlashCommandGroup, - MessageCommand, - UserCommand, ApplicationCommand, ApplicationContext, AutocompleteContext, + MessageCommand, + SlashCommand, + SlashCommandGroup, + UserCommand, command, ) -from .cog import CogMixin - -from .errors import Forbidden, DiscordException -from .interactions import Interaction +from .commands.errors import CheckFailure from .enums import InteractionType +from .errors import DiscordException, Forbidden +from .interactions import Interaction +from .shard import AutoShardedClient +from .state import ConnectionState +from .utils import MISSING, async_all, find, get CoroFunc = Callable[..., Coroutine[Any, Any, Any]] -CFT = TypeVar('CFT', bound=CoroFunc) +CFT = TypeVar("CFT", bound=CoroFunc) __all__ = ( - 'ApplicationCommandMixin', - 'Bot', - 'AutoShardedBot', + "ApplicationCommandMixin", + "Bot", + "AutoShardedBot", ) + class ApplicationCommandMixin: """A mixin that implements common functionality for classes that need application command compatibility. @@ -149,10 +141,10 @@ def remove_application_command( @property def get_command(self): """Shortcut for :meth:`.get_application_command`. - + .. note:: Overridden in :class:`ext.commands.Bot`. - + .. versionadded:: 2.0 """ # TODO: Do something like we did in self.commands for this @@ -185,10 +177,7 @@ def get_application_command( """ for command in self._application_commands.values(): - if ( - command.name == name - and isinstance(command, type) - ): + if command.name == name and isinstance(command, type): if guild_ids is not None and command.guild_ids != guild_ids: return return command @@ -287,7 +276,12 @@ async def register_commands(self) -> None: raise else: for i in cmds: - cmd = find(lambda cmd: cmd.name == i["name"] and cmd.type == i["type"] and int(i["guild_id"]) in cmd.guild_ids, self.pending_application_commands) + cmd = find( + lambda cmd: cmd.name == i["name"] + and cmd.type == i["type"] + and int(i["guild_id"]) in cmd.guild_ids, + self.pending_application_commands, + ) cmd.id = i["id"] self._application_commands[cmd.id] = cmd @@ -380,7 +374,9 @@ async def register_commands(self) -> None: if len(new_cmd_perm["permissions"]) > 10: print( "Command '{name}' has more than 10 permission overrides in guild ({guild_id}).\nwill only use the first 10 permission overrides.".format( - name=self._application_commands[new_cmd_perm["id"]].name, + name=self._application_commands[ + new_cmd_perm["id"] + ].name, guild_id=guild_id, ) ) @@ -424,8 +420,8 @@ async def process_application_commands(self, interaction: Interaction) -> None: The interaction to process """ if interaction.type not in ( - InteractionType.application_command, - InteractionType.auto_complete + InteractionType.application_command, + InteractionType.auto_complete, ): return @@ -438,7 +434,7 @@ async def process_application_commands(self, interaction: Interaction) -> None: ctx = await self.get_autocomplete_context(interaction) ctx.command = command return await command.invoke_autocomplete_callback(ctx) - + ctx = await self.get_application_context(interaction) ctx.command = command self.dispatch("application_command", ctx) @@ -591,17 +587,20 @@ def group( Callable[[Type[SlashCommandGroup]], SlashCommandGroup] The slash command group that was created. """ + def inner(cls: Type[SlashCommandGroup]) -> SlashCommandGroup: group = cls( name, ( description or inspect.cleandoc(cls.__doc__).splitlines()[0] - if cls.__doc__ is not None else "No description provided" + if cls.__doc__ is not None + else "No description provided" ), - guild_ids=guild_ids + guild_ids=guild_ids, ) self.add_application_command(group) return group + return inner slash_group = group @@ -667,7 +666,6 @@ class be provided, it must be similar enough to return cls(self, interaction) - class BotBase(ApplicationCommandMixin, CogMixin): _supports_prefixed_commands = False # TODO I think @@ -717,6 +715,13 @@ async def on_connect(self): async def on_interaction(self, interaction): await self.process_application_commands(interaction) + if interaction.type == discord.InteractionType.modal_submit: + state: ConnectionState = self._connection # type: ignore + user_id, custom_id = ( + interaction.user.id, + interaction.data["custom_id"], + ) + await state._modal_store.dispatch(user_id, custom_id, interaction) async def on_application_command_error( self, context: ApplicationContext, exception: DiscordException @@ -730,7 +735,7 @@ async def on_application_command_error( This only fires if you do not specify any listeners for command error. """ - if self.extra_events.get('on_application_command_error', None): + if self.extra_events.get("on_application_command_error", None): return command = context.command @@ -887,7 +892,7 @@ async def my_message(message): pass name = func.__name__ if name is MISSING else name if not asyncio.iscoroutinefunction(func): - raise TypeError('Listeners must be coroutines') + raise TypeError("Listeners must be coroutines") if name in self.extra_events: self.extra_events[name].append(func) @@ -953,7 +958,7 @@ def decorator(func: CFT) -> CFT: def dispatch(self, event_name: str, *args: Any, **kwargs: Any) -> None: # super() will resolve to Client super().dispatch(event_name, *args, **kwargs) # type: ignore - ev = 'on_' + event_name + ev = f"on_{event_name}" for event in self.extra_events.get(ev, []): self._schedule_event(event, ev, *args, **kwargs) # type: ignore diff --git a/discord/components.py b/discord/components.py index b6719e7c1b..bda34abb2c 100644 --- a/discord/components.py +++ b/discord/components.py @@ -25,31 +25,42 @@ from __future__ import annotations -from typing import Any, ClassVar, Dict, List, Optional, TYPE_CHECKING, Tuple, Type, TypeVar, Union -from .enums import try_enum, ComponentType, ButtonStyle -from .utils import get_slots, MISSING +from typing import ( + TYPE_CHECKING, + Any, + ClassVar, + Dict, + List, + Optional, + Tuple, + Type, + TypeVar, + Union, +) + +from .enums import ButtonStyle, ComponentType, InputTextStyle, try_enum from .partial_emoji import PartialEmoji, _EmojiTag +from .utils import MISSING, get_slots if TYPE_CHECKING: - from .types.components import ( - Component as ComponentPayload, - ButtonComponent as ButtonComponentPayload, - SelectMenu as SelectMenuPayload, - SelectOption as SelectOptionPayload, - ActionRow as ActionRowPayload, - ) from .emoji import Emoji + from .types.components import ActionRow as ActionRowPayload + from .types.components import ButtonComponent as ButtonComponentPayload + from .types.components import Component as ComponentPayload + from .types.components import InputText as InputTextComponentPayload + from .types.components import SelectMenu as SelectMenuPayload + from .types.components import SelectOption as SelectOptionPayload __all__ = ( - 'Component', - 'ActionRow', - 'Button', - 'SelectMenu', - 'SelectOption', + "Component", + "ActionRow", + "Button", + "SelectMenu", + "SelectOption", ) -C = TypeVar('C', bound='Component') +C = TypeVar("C", bound="Component") class Component: @@ -71,14 +82,14 @@ class Component: The type of component. """ - __slots__: Tuple[str, ...] = ('type',) + __slots__: Tuple[str, ...] = ("type",) __repr_info__: ClassVar[Tuple[str, ...]] type: ComponentType def __repr__(self) -> str: - attrs = ' '.join(f'{key}={getattr(self, key)!r}' for key in self.__repr_info__) - return f'<{self.__class__.__name__} {attrs}>' + attrs = " ".join(f"{key}={getattr(self, key)!r}" for key in self.__repr_info__) + return f"<{self.__class__.__name__} {attrs}>" @classmethod def _raw_construct(cls: Type[C], **kwargs) -> C: @@ -113,21 +124,101 @@ class ActionRow(Component): The children components that this holds, if any. """ - __slots__: Tuple[str, ...] = ('children',) + __slots__: Tuple[str, ...] = ("children",) __repr_info__: ClassVar[Tuple[str, ...]] = __slots__ def __init__(self, data: ComponentPayload): - self.type: ComponentType = try_enum(ComponentType, data['type']) - self.children: List[Component] = [_component_factory(d) for d in data.get('components', [])] + self.type: ComponentType = try_enum(ComponentType, data["type"]) + self.children: List[Component] = [ + _component_factory(d) for d in data.get("components", []) + ] def to_dict(self) -> ActionRowPayload: return { - 'type': int(self.type), - 'components': [child.to_dict() for child in self.children], + "type": int(self.type), + "components": [child.to_dict() for child in self.children], } # type: ignore +class InputText(Component): + """Represents an Input Text field from the Discord Bot UI Kit. + + This inherits from :class:`Component`. + + Attributes + ---------- + style: :class:`.InputTextStyle` + The style of the input text field. + custom_id: Optional[:class:`str`] + The ID of the input text field that gets received during an interaction. + label: Optional[:class:`str`] + The label for the input text field, if any. + placeholder: Optional[:class:`str`] + The placeholder text that is shown if nothing is selected, if any. + min_length: Optional[:class:`int`] + The minimum number of characters that must be entered + Defaults to 0 + max_length: Optional[:class:`int`] + The maximum number of characters that can be entered + required: Optional[:class:`bool`] + Whether the input text field is required or not. Defaults to `True`. + value: Optional[:class:`str`] + The value that has been entered in the input text field. + """ + + __slots__: Tuple[str, ...] = ( + "type", + "style", + "custom_id", + "label", + "placeholder", + "min_length", + "max_length", + "required", + "value", + ) + + __repr_info__: ClassVar[Tuple[str, ...]] = __slots__ + + def __init__(self, data: InputTextComponentPayload): + self.type = ComponentType.input_text + self.style: InputTextStyle = try_enum(InputTextStyle, data["style"]) + self.custom_id = data["custom_id"] + self.label: Optional[str] = data.get("label", None) + self.placeholder: Optional[str] = data.get("placeholder", None) + self.min_length: Optional[int] = data.get("min_length", None) + self.max_length: Optional[int] = data.get("max_length", None) + self.required: bool = data.get("required", True) + self.value: Optional[str] = data.get("value", None) + + def to_dict(self) -> InputTextComponentPayload: + payload = { + "type": 4, + "style": self.style.value, + "label": self.label, + } + if self.custom_id: + payload["custom_id"] = self.custom_id + + if self.placeholder: + payload["placeholder"] = self.placeholder + + if self.min_length: + payload["min_length"] = self.min_length + + if self.max_length: + payload["max_length"] = self.max_length + + if not self.required: + payload["required"] = self.required + + if self.value: + payload["value"] = self.value + + return payload # type: ignore + + class Button(Component): """Represents a button from the Discord Bot UI Kit. @@ -158,44 +249,44 @@ class Button(Component): """ __slots__: Tuple[str, ...] = ( - 'style', - 'custom_id', - 'url', - 'disabled', - 'label', - 'emoji', + "style", + "custom_id", + "url", + "disabled", + "label", + "emoji", ) __repr_info__: ClassVar[Tuple[str, ...]] = __slots__ def __init__(self, data: ButtonComponentPayload): - self.type: ComponentType = try_enum(ComponentType, data['type']) - self.style: ButtonStyle = try_enum(ButtonStyle, data['style']) - self.custom_id: Optional[str] = data.get('custom_id') - self.url: Optional[str] = data.get('url') - self.disabled: bool = data.get('disabled', False) - self.label: Optional[str] = data.get('label') + self.type: ComponentType = try_enum(ComponentType, data["type"]) + self.style: ButtonStyle = try_enum(ButtonStyle, data["style"]) + self.custom_id: Optional[str] = data.get("custom_id") + self.url: Optional[str] = data.get("url") + self.disabled: bool = data.get("disabled", False) + self.label: Optional[str] = data.get("label") self.emoji: Optional[PartialEmoji] try: - self.emoji = PartialEmoji.from_dict(data['emoji']) + self.emoji = PartialEmoji.from_dict(data["emoji"]) except KeyError: self.emoji = None def to_dict(self) -> ButtonComponentPayload: payload = { - 'type': 2, - 'style': int(self.style), - 'label': self.label, - 'disabled': self.disabled, + "type": 2, + "style": int(self.style), + "label": self.label, + "disabled": self.disabled, } if self.custom_id: - payload['custom_id'] = self.custom_id + payload["custom_id"] = self.custom_id if self.url: - payload['url'] = self.url + payload["url"] = self.url if self.emoji: - payload['emoji'] = self.emoji.to_dict() + payload["emoji"] = self.emoji.to_dict() return payload # type: ignore @@ -232,37 +323,39 @@ class SelectMenu(Component): """ __slots__: Tuple[str, ...] = ( - 'custom_id', - 'placeholder', - 'min_values', - 'max_values', - 'options', - 'disabled', + "custom_id", + "placeholder", + "min_values", + "max_values", + "options", + "disabled", ) __repr_info__: ClassVar[Tuple[str, ...]] = __slots__ def __init__(self, data: SelectMenuPayload): self.type = ComponentType.select - self.custom_id: str = data['custom_id'] - self.placeholder: Optional[str] = data.get('placeholder') - self.min_values: int = data.get('min_values', 1) - self.max_values: int = data.get('max_values', 1) - self.options: List[SelectOption] = [SelectOption.from_dict(option) for option in data.get('options', [])] - self.disabled: bool = data.get('disabled', False) + self.custom_id: str = data["custom_id"] + self.placeholder: Optional[str] = data.get("placeholder") + self.min_values: int = data.get("min_values", 1) + self.max_values: int = data.get("max_values", 1) + self.options: List[SelectOption] = [ + SelectOption.from_dict(option) for option in data.get("options", []) + ] + self.disabled: bool = data.get("disabled", False) def to_dict(self) -> SelectMenuPayload: payload: SelectMenuPayload = { - 'type': self.type.value, - 'custom_id': self.custom_id, - 'min_values': self.min_values, - 'max_values': self.max_values, - 'options': [op.to_dict() for op in self.options], - 'disabled': self.disabled, + "type": self.type.value, + "custom_id": self.custom_id, + "min_values": self.min_values, + "max_values": self.max_values, + "options": [op.to_dict() for op in self.options], + "disabled": self.disabled, } if self.placeholder: - payload['placeholder'] = self.placeholder + payload["placeholder"] = self.placeholder return payload @@ -293,11 +386,11 @@ class SelectOption: """ __slots__: Tuple[str, ...] = ( - 'label', - 'value', - 'description', - 'emoji', - 'default', + "label", + "value", + "description", + "emoji", + "default", ) def __init__( @@ -319,60 +412,62 @@ def __init__( elif isinstance(emoji, _EmojiTag): emoji = emoji._to_partial() else: - raise TypeError(f'expected emoji to be str, Emoji, or PartialEmoji not {emoji.__class__}') + raise TypeError( + f"expected emoji to be str, Emoji, or PartialEmoji not {emoji.__class__}" + ) self.emoji = emoji self.default = default def __repr__(self) -> str: return ( - f'' + f"" ) def __str__(self) -> str: if self.emoji: - base = f'{self.emoji} {self.label}' + base = f"{self.emoji} {self.label}" else: base = self.label if self.description: - return f'{base}\n{self.description}' + return f"{base}\n{self.description}" return base @classmethod def from_dict(cls, data: SelectOptionPayload) -> SelectOption: try: - emoji = PartialEmoji.from_dict(data['emoji']) + emoji = PartialEmoji.from_dict(data["emoji"]) except KeyError: emoji = None return cls( - label=data['label'], - value=data['value'], - description=data.get('description'), + label=data["label"], + value=data["value"], + description=data.get("description"), emoji=emoji, - default=data.get('default', False), + default=data.get("default", False), ) def to_dict(self) -> SelectOptionPayload: payload: SelectOptionPayload = { - 'label': self.label, - 'value': self.value, - 'default': self.default, + "label": self.label, + "value": self.value, + "default": self.default, } if self.emoji: - payload['emoji'] = self.emoji.to_dict() # type: ignore + payload["emoji"] = self.emoji.to_dict() # type: ignore if self.description: - payload['description'] = self.description + payload["description"] = self.description return payload def _component_factory(data: ComponentPayload) -> Component: - component_type = data['type'] + component_type = data["type"] if component_type == 1: return ActionRow(data) elif component_type == 2: diff --git a/discord/enums.py b/discord/enums.py index 71414d81a3..f44c4455c6 100644 --- a/discord/enums.py +++ b/discord/enums.py @@ -25,45 +25,46 @@ import types from collections import namedtuple -from typing import Any, ClassVar, Dict, List, Optional, TYPE_CHECKING, Type, TypeVar +from typing import TYPE_CHECKING, Any, ClassVar, Dict, List, Optional, Type, TypeVar __all__ = ( - 'Enum', - 'ChannelType', - 'MessageType', - 'VoiceRegion', - 'SpeakingState', - 'VerificationLevel', - 'ContentFilter', - 'Status', - 'DefaultAvatar', - 'AuditLogAction', - 'AuditLogActionCategory', - 'UserFlags', - 'ActivityType', - 'NotificationLevel', - 'TeamMembershipState', - 'WebhookType', - 'ExpireBehaviour', - 'ExpireBehavior', - 'StickerType', - 'StickerFormatType', - 'InviteTarget', - 'VideoQualityMode', - 'ComponentType', - 'ButtonStyle', - 'StagePrivacyLevel', - 'InteractionType', - 'InteractionResponseType', - 'NSFWLevel', - 'EmbeddedActivity', + "Enum", + "ChannelType", + "MessageType", + "VoiceRegion", + "SpeakingState", + "VerificationLevel", + "ContentFilter", + "Status", + "DefaultAvatar", + "AuditLogAction", + "AuditLogActionCategory", + "UserFlags", + "ActivityType", + "NotificationLevel", + "TeamMembershipState", + "WebhookType", + "ExpireBehaviour", + "ExpireBehavior", + "StickerType", + "StickerFormatType", + "InviteTarget", + "VideoQualityMode", + "ComponentType", + "ButtonStyle", + "StagePrivacyLevel", + "InteractionType", + "InteractionResponseType", + "NSFWLevel", + "EmbeddedActivity", + "InputTextStyle", ) def _create_value_cls(name, comparable): - cls = namedtuple('_EnumValue_' + name, 'name value') - cls.__repr__ = lambda self: f'<{name}.{self.name}: {self.value!r}>' - cls.__str__ = lambda self: f'{name}.{self.name}' + cls = namedtuple(f"_EnumValue_{name}", "name value") + cls.__repr__ = lambda self: f"<{name}.{self.name}: {self.value!r}>" + cls.__str__ = lambda self: f"{name}.{self.name}" if comparable: cls.__le__ = lambda self, other: isinstance(other, self.__class__) and self.value <= other.value cls.__ge__ = lambda self, other: isinstance(other, self.__class__) and self.value >= other.value @@ -71,8 +72,9 @@ def _create_value_cls(name, comparable): cls.__gt__ = lambda self, other: isinstance(other, self.__class__) and self.value > other.value return cls + def _is_descriptor(obj): - return hasattr(obj, '__get__') or hasattr(obj, '__set__') or hasattr(obj, '__delete__') + return hasattr(obj, "__get__") or hasattr(obj, "__set__") or hasattr(obj, "__delete__") class EnumMeta(type): @@ -90,7 +92,7 @@ def __new__(cls, name, bases, attrs, *, comparable: bool = False): value_cls = _create_value_cls(name, comparable) for key, value in list(attrs.items()): is_descriptor = _is_descriptor(value) - if key[0] == '_' and not is_descriptor: + if key[0] == "_" and not is_descriptor: continue # Special case classmethod to just pass through @@ -112,10 +114,10 @@ def __new__(cls, name, bases, attrs, *, comparable: bool = False): member_mapping[key] = new_value attrs[key] = new_value - attrs['_enum_value_map_'] = value_mapping - attrs['_enum_member_map_'] = member_mapping - attrs['_enum_member_names_'] = member_names - attrs['_enum_value_cls_'] = value_cls + attrs["_enum_value_map_"] = value_mapping + attrs["_enum_member_map_"] = member_mapping + attrs["_enum_member_names_"] = member_names + attrs["_enum_value_cls_"] = value_cls actual_cls = super().__new__(cls, name, bases, attrs) value_cls._actual_enum_cls_ = actual_cls # type: ignore return actual_cls @@ -130,7 +132,7 @@ def __len__(cls): return len(cls._enum_member_names_) def __repr__(cls): - return f'' + return f"" @property def __members__(cls): @@ -146,10 +148,10 @@ def __getitem__(cls, key): return cls._enum_member_map_[key] def __setattr__(cls, name, value): - raise TypeError('Enums are immutable.') + raise TypeError("Enums are immutable.") def __delattr__(cls, attr): - raise TypeError('Enums are immutable') + raise TypeError("Enums are immutable") def __instancecheck__(self, instance): # isinstance(x, Y) @@ -220,29 +222,29 @@ class MessageType(Enum): class VoiceRegion(Enum): - us_west = 'us-west' - us_east = 'us-east' - us_south = 'us-south' - us_central = 'us-central' - eu_west = 'eu-west' - eu_central = 'eu-central' - singapore = 'singapore' - london = 'london' - sydney = 'sydney' - amsterdam = 'amsterdam' - frankfurt = 'frankfurt' - brazil = 'brazil' - hongkong = 'hongkong' - russia = 'russia' - japan = 'japan' - southafrica = 'southafrica' - south_korea = 'south-korea' - india = 'india' - europe = 'europe' - dubai = 'dubai' - vip_us_east = 'vip-us-east' - vip_us_west = 'vip-us-west' - vip_amsterdam = 'vip-amsterdam' + us_west = "us-west" + us_east = "us-east" + us_south = "us-south" + us_central = "us-central" + eu_west = "eu-west" + eu_central = "eu-central" + singapore = "singapore" + london = "london" + sydney = "sydney" + amsterdam = "amsterdam" + frankfurt = "frankfurt" + brazil = "brazil" + hongkong = "hongkong" + russia = "russia" + japan = "japan" + southafrica = "southafrica" + south_korea = "south-korea" + india = "india" + europe = "europe" + dubai = "dubai" + vip_us_east = "vip-us-east" + vip_us_west = "vip-us-west" + vip_amsterdam = "vip-amsterdam" def __str__(self): return self.value @@ -282,13 +284,13 @@ def __str__(self): class Status(Enum): - online = 'online' - offline = 'offline' - idle = 'idle' - dnd = 'dnd' - do_not_disturb = 'dnd' - invisible = 'invisible' - streaming = 'streaming' + online = "online" + offline = "offline" + idle = "idle" + dnd = "dnd" + do_not_disturb = "dnd" + invisible = "invisible" + streaming = "streaming" def __str__(self): return self.value @@ -427,35 +429,35 @@ def category(self) -> Optional[AuditLogActionCategory]: def target_type(self) -> Optional[str]: v = self.value if v == -1: - return 'all' + return "all" elif v < 10: - return 'guild' + return "guild" elif v < 20: - return 'channel' + return "channel" elif v < 30: - return 'user' + return "user" elif v < 40: - return 'role' + return "role" elif v < 50: - return 'invite' + return "invite" elif v < 60: - return 'webhook' + return "webhook" elif v < 70: - return 'emoji' + return "emoji" elif v == 73: - return 'channel' + return "channel" elif v < 80: - return 'message' + return "message" elif v < 83: - return 'integration' + return "integration" elif v < 90: - return 'stage_instance' + return "stage_instance" elif v < 93: - return 'sticker' + return "sticker" elif v < 103: - return 'scheduled_event' + return "scheduled_event" elif v < 113: - return 'thread' + return "thread" class UserFlags(Enum): @@ -547,6 +549,7 @@ class InteractionType(Enum): application_command = 2 component = 3 auto_complete = 4 + modal_submit = 5 class InteractionResponseType(Enum): @@ -557,8 +560,8 @@ class InteractionResponseType(Enum): deferred_channel_message = 5 # (with source) deferred_message_update = 6 # for components message_update = 7 # for components - auto_complete_result = 8 # for autocomplete interactions - + auto_complete_result = 8 # for autocomplete interactions + modal = 9 # for modal dialogs class VideoQualityMode(Enum): auto = 1 @@ -572,6 +575,7 @@ class ComponentType(Enum): action_row = 1 button = 2 select = 3 + input_text = 4 def __int__(self): return self.value @@ -596,6 +600,14 @@ def __int__(self): return self.value +class InputTextStyle(Enum): + short = 1 + singleline = 1 + paragraph = 2 + multiline = 2 + long = 2 + + class ApplicationType(Enum): game = 1 music = 2 @@ -631,14 +643,14 @@ class SlashCommandOptionType(Enum): @classmethod def from_datatype(cls, datatype): - if isinstance(datatype, tuple): # typing.Union has been used + if isinstance(datatype, tuple): # typing.Union has been used datatypes = [cls.from_datatype(op) for op in datatype] if all([x == cls.channel for x in datatypes]): return cls.channel elif set(datatypes) <= {cls.role, cls.user}: return cls.mentionable else: - raise TypeError('Invalid usage of typing.Union') + raise TypeError("Invalid usage of typing.Union") if issubclass(datatype, str): return cls.string @@ -651,11 +663,7 @@ def from_datatype(cls, datatype): if datatype.__name__ in ["Member", "User"]: return cls.user - if datatype.__name__ in [ - "GuildChannel", "TextChannel", - "VoiceChannel", "StageChannel", - "CategoryChannel" - ]: + if datatype.__name__ in ["GuildChannel", "TextChannel", "VoiceChannel", "StageChannel", "CategoryChannel"]: return cls.channel if datatype.__name__ == "Role": return cls.role @@ -663,7 +671,7 @@ def from_datatype(cls, datatype): return cls.mentionable # TODO: Improve the error message - raise TypeError(f'Invalid class {datatype} used as an input type for an Option') + raise TypeError(f"Invalid class {datatype} used as an input type for an Option") class EmbeddedActivity(Enum): @@ -697,13 +705,15 @@ class EmbeddedActivity(Enum): watch_together_dev = 880218832743055411 word_snacks = 879863976006127627 word_snacks_dev = 879864010126786570 - youtube_together = 755600276941176913 + youtube_together = 755600276941176913 + + +T = TypeVar("T") -T = TypeVar('T') def create_unknown_value(cls: Type[T], val: Any) -> T: value_cls = cls._enum_value_cls_ # type: ignore - name = f'unknown_{val}' + name = f"unknown_{val}" return value_cls(name=name, value=val) diff --git a/discord/interactions.py b/discord/interactions.py index fd1c7cbdee..b97287a74d 100644 --- a/discord/interactions.py +++ b/discord/interactions.py @@ -26,45 +26,64 @@ """ from __future__ import annotations -from typing import Any, Dict, List, Optional, TYPE_CHECKING, Tuple, Union + import asyncio +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union from . import utils -from .enums import try_enum, InteractionType, InteractionResponseType -from .errors import InteractionResponded, HTTPException, ClientException, InvalidArgument -from .channel import PartialMessageable, ChannelType +from .channel import ChannelType, PartialMessageable +from .enums import InteractionResponseType, InteractionType, try_enum +from .errors import ( + ClientException, + HTTPException, + InteractionResponded, + InvalidArgument, +) from .file import File -from .user import User from .member import Member -from .message import Message, Attachment from .mentions import AllowedMentions +from .message import Attachment, Message from .object import Object from .permissions import Permissions -from .webhook.async_ import async_context, Webhook, handle_message_parameters +from .user import User +from .webhook.async_ import Webhook, async_context, handle_message_parameters __all__ = ( - 'Interaction', - 'InteractionMessage', - 'InteractionResponse', + "Interaction", + "InteractionMessage", + "InteractionResponse", ) if TYPE_CHECKING: - from .types.interactions import ( - Interaction as InteractionPayload, - InteractionData, + from aiohttp import ClientSession + + from .channel import ( + CategoryChannel, + PartialMessageable, + StageChannel, + StoreChannel, + TextChannel, + VoiceChannel, ) + from .commands import OptionChoice + from .embeds import Embed from .guild import Guild - from .state import ConnectionState from .mentions import AllowedMentions - from aiohttp import ClientSession - from .embeds import Embed - from .ui.view import View - from .channel import VoiceChannel, StageChannel, TextChannel, CategoryChannel, StoreChannel, PartialMessageable + from .state import ConnectionState from .threads import Thread - from .commands import OptionChoice + from .types.interactions import Interaction as InteractionPayload + from .types.interactions import InteractionData + from .ui.view import View + from .ui.modal import Modal InteractionChannel = Union[ - VoiceChannel, StageChannel, TextChannel, CategoryChannel, StoreChannel, Thread, PartialMessageable + VoiceChannel, + StageChannel, + TextChannel, + CategoryChannel, + StoreChannel, + Thread, + PartialMessageable, ] @@ -103,23 +122,23 @@ class Interaction: """ __slots__: Tuple[str, ...] = ( - 'id', - 'type', - 'guild_id', - 'channel_id', - 'data', - 'application_id', - 'message', - 'user', - 'token', - 'version', - '_permissions', - '_state', - '_session', - '_original_message', - '_cs_response', - '_cs_followup', - '_cs_channel', + "id", + "type", + "guild_id", + "channel_id", + "data", + "application_id", + "message", + "user", + "token", + "version", + "_permissions", + "_state", + "_session", + "_original_message", + "_cs_response", + "_cs_followup", + "_cs_channel", ) def __init__(self, *, data: InteractionPayload, state: ConnectionState): @@ -129,18 +148,18 @@ def __init__(self, *, data: InteractionPayload, state: ConnectionState): self._from_data(data) def _from_data(self, data: InteractionPayload): - self.id: int = int(data['id']) - self.type: InteractionType = try_enum(InteractionType, data['type']) - self.data: Optional[InteractionData] = data.get('data') - self.token: str = data['token'] - self.version: int = data['version'] - self.channel_id: Optional[int] = utils._get_as_snowflake(data, 'channel_id') - self.guild_id: Optional[int] = utils._get_as_snowflake(data, 'guild_id') - self.application_id: int = int(data['application_id']) + self.id: int = int(data["id"]) + self.type: InteractionType = try_enum(InteractionType, data["type"]) + self.data: Optional[InteractionData] = data.get("data") + self.token: str = data["token"] + self.version: int = data["version"] + self.channel_id: Optional[int] = utils._get_as_snowflake(data, "channel_id") + self.guild_id: Optional[int] = utils._get_as_snowflake(data, "guild_id") + self.application_id: int = int(data["application_id"]) self.message: Optional[Message] try: - self.message = Message(state=self._state, channel=self.channel, data=data['message']) # type: ignore + self.message = Message(state=self._state, channel=self.channel, data=data["message"]) # type: ignore except KeyError: self.message = None @@ -151,15 +170,15 @@ def _from_data(self, data: InteractionPayload): if self.guild_id: guild = self.guild or Object(id=self.guild_id) try: - member = data['member'] # type: ignore + member = data["member"] # type: ignore except KeyError: pass else: self.user = Member(state=self._state, guild=guild, data=member) # type: ignore - self._permissions = int(member.get('permissions', 0)) + self._permissions = int(member.get("permissions", 0)) else: try: - self.user = User(state=self._state, data=data['user']) + self.user = User(state=self._state, data=data["user"]) except KeyError: pass @@ -176,7 +195,7 @@ def is_component(self) -> bool: """:class:`bool`: Indicates whether the interaction is a message component.""" return self.type == InteractionType.component - @utils.cached_slot_property('_cs_channel') + @utils.cached_slot_property("_cs_channel") def channel(self) -> Optional[InteractionChannel]: """Optional[Union[:class:`abc.GuildChannel`, :class:`PartialMessageable`, :class:`Thread`]]: The channel the interaction was sent from. @@ -187,8 +206,14 @@ def channel(self) -> Optional[InteractionChannel]: channel = guild and guild._resolve_channel(self.channel_id) if channel is None: if self.channel_id is not None: - type = ChannelType.text if self.guild_id is not None else ChannelType.private - return PartialMessageable(state=self._state, id=self.channel_id, type=type) + type = ( + ChannelType.text + if self.guild_id is not None + else ChannelType.private + ) + return PartialMessageable( + state=self._state, id=self.channel_id, type=type + ) return None return channel @@ -200,7 +225,7 @@ def permissions(self) -> Permissions: """ return Permissions(self._permissions) - @utils.cached_slot_property('_cs_response') + @utils.cached_slot_property("_cs_response") def response(self) -> InteractionResponse: """:class:`InteractionResponse`: Returns an object responsible for handling responding to the interaction. @@ -209,13 +234,13 @@ def response(self) -> InteractionResponse: """ return InteractionResponse(self) - @utils.cached_slot_property('_cs_followup') + @utils.cached_slot_property("_cs_followup") def followup(self) -> Webhook: """:class:`Webhook`: Returns the follow up webhook for follow up interactions.""" payload = { - 'id': self.application_id, - 'type': 3, - 'token': self.token, + "id": self.application_id, + "type": 3, + "token": self.token, } return Webhook.from_state(data=payload, state=self._state) @@ -249,7 +274,7 @@ async def original_message(self) -> InteractionMessage: # TODO: fix later to not raise? channel = self.channel if channel is None: - raise ClientException('Channel for message could not be resolved') + raise ClientException("Channel for message could not be resolved") adapter = async_context.get() data = await adapter.get_original_interaction_response( @@ -380,8 +405,8 @@ class InteractionResponse: """ __slots__: Tuple[str, ...] = ( - '_responded', - '_parent', + "_responded", + "_parent", ) def __init__(self, parent: Interaction): @@ -394,7 +419,6 @@ def is_done(self) -> bool: An interaction can only be responded to once. """ return self._responded - async def defer(self, *, ephemeral: bool = False) -> None: """|coro| @@ -425,19 +449,23 @@ async def defer(self, *, ephemeral: bool = False) -> None: parent = self._parent if parent.type is InteractionType.component: if ephemeral: - data = {'flags': 64} + data = {"flags": 64} defer_type = InteractionResponseType.deferred_channel_message.value else: defer_type = InteractionResponseType.deferred_message_update.value elif parent.type is InteractionType.application_command: defer_type = InteractionResponseType.deferred_channel_message.value if ephemeral: - data = {'flags': 64} + data = {"flags": 64} if defer_type: adapter = async_context.get() await adapter.create_interaction_response( - parent.id, parent.token, session=parent._session, type=defer_type, data=data + parent.id, + parent.token, + session=parent._session, + type=defer_type, + data=data, ) self._responded = True @@ -462,7 +490,10 @@ async def pong(self) -> None: if parent.type is InteractionType.ping: adapter = async_context.get() await adapter.create_interaction_response( - parent.id, parent.token, session=parent._session, type=InteractionResponseType.pong.value + parent.id, + parent.token, + session=parent._session, + type=InteractionResponseType.pong.value, ) self._responded = True @@ -478,7 +509,7 @@ async def send_message( allowed_mentions: AllowedMentions = None, file: File = None, files: List[File] = None, - delete_after: float = None + delete_after: float = None, ) -> Interaction: """|coro| @@ -512,7 +543,7 @@ async def send_message( The file to upload. files: :class:`List[File]` A list of files to upload. Must be a maximum of 10. - + Raises ------- HTTPException @@ -528,53 +559,59 @@ async def send_message( raise InteractionResponded(self._parent) payload: Dict[str, Any] = { - 'tts': tts, + "tts": tts, } if embed is not MISSING and embeds is not MISSING: - raise TypeError('cannot mix embed and embeds keyword arguments') + raise TypeError("cannot mix embed and embeds keyword arguments") if embed is not MISSING: embeds = [embed] if embeds: if len(embeds) > 10: - raise ValueError('embeds cannot exceed maximum of 10 elements') - payload['embeds'] = [e.to_dict() for e in embeds] + raise ValueError("embeds cannot exceed maximum of 10 elements") + payload["embeds"] = [e.to_dict() for e in embeds] if content is not None: - payload['content'] = str(content) + payload["content"] = str(content) if ephemeral: - payload['flags'] = 64 + payload["flags"] = 64 if view is not MISSING: - payload['components'] = view.to_components() + payload["components"] = view.to_components() state = self._parent._state if allowed_mentions is not None: if state.allowed_mentions is not None: - payload['allowed_mentions'] = state.allowed_mentions.merge(allowed_mentions).to_dict() + payload["allowed_mentions"] = state.allowed_mentions.merge( + allowed_mentions + ).to_dict() else: - payload['allowed_mentions'] = allowed_mentions.to_dict() + payload["allowed_mentions"] = allowed_mentions.to_dict() else: - payload['allowed_mentions'] = state.allowed_mentions and state.allowed_mentions.to_dict() + payload["allowed_mentions"] = ( + state.allowed_mentions and state.allowed_mentions.to_dict() + ) if file is not None and files is not None: - raise InvalidArgument('cannot pass both file and files parameter to send()') - + raise InvalidArgument("cannot pass both file and files parameter to send()") + if file is not None: if not isinstance(file, File): - raise InvalidArgument('file parameter must be File') + raise InvalidArgument("file parameter must be File") else: files = [file] if files is not None: if len(files) > 10: - raise InvalidArgument('files parameter must be a list of up to 10 elements') + raise InvalidArgument( + "files parameter must be a list of up to 10 elements" + ) elif not all(isinstance(file, File) for file in files): - raise InvalidArgument('files parameter must be a list of File') + raise InvalidArgument("files parameter must be a list of File") parent = self._parent adapter = async_context.get() @@ -585,7 +622,7 @@ async def send_message( session=parent._session, type=InteractionResponseType.channel_message.value, data=payload, - files=files + files=files, ) finally: if files: @@ -600,9 +637,11 @@ async def send_message( self._responded = True if delete_after is not None: + async def delete(): await asyncio.sleep(delete_after) await self._parent.delete_original_message() + asyncio.ensure_future(delete(), loop=self._parent._state.loop) return self._parent @@ -658,12 +697,12 @@ async def edit_message( payload = {} if content is not MISSING: if content is None: - payload['content'] = None + payload["content"] = None else: - payload['content'] = str(content) + payload["content"] = str(content) if embed is not MISSING and embeds is not MISSING: - raise TypeError('cannot mix both embed and embeds keyword arguments') + raise TypeError("cannot mix both embed and embeds keyword arguments") if embed is not MISSING: if embed is None: @@ -672,17 +711,17 @@ async def edit_message( embeds = [embed] if embeds is not MISSING: - payload['embeds'] = [e.to_dict() for e in embeds] + payload["embeds"] = [e.to_dict() for e in embeds] if attachments is not MISSING: - payload['attachments'] = [a.to_dict() for a in attachments] + payload["attachments"] = [a.to_dict() for a in attachments] if view is not MISSING: state.prevent_view_updates_for(message_id) if view is None: - payload['components'] = [] + payload["components"] = [] else: - payload['components'] = view.to_components() + payload["components"] = view.to_components() adapter = async_context.get() await adapter.create_interaction_response( @@ -709,7 +748,7 @@ async def send_autocomplete_result( Parameters ----------- choices: List[:class:`OptionChoice`] - A list of choices. + A list of choices. Raises ------- @@ -726,9 +765,7 @@ async def send_autocomplete_result( if parent.type is not InteractionType.auto_complete: return - payload = { - "choices": [c.to_dict() for c in choices] - } + payload = {"choices": [c.to_dict() for c in choices]} adapter = async_context.get() await adapter.create_interaction_response( @@ -740,9 +777,26 @@ async def send_autocomplete_result( ) self._responded = True - + + async def send_modal(self, modal: Modal): + if self._responded: + raise InteractionResponded(self._parent) + + payload = modal.to_dict() + adapter = async_context.get() + await adapter.create_interaction_response( + self._parent.id, + self._parent.token, + session=self._parent._session, + type=InteractionResponseType.modal.value, + data=payload, + ) + self._responded = True + self._parent._state.store_modal(modal, self._parent.user.id) + + class _InteractionMessageState: - __slots__ = ('_parent', '_interaction') + __slots__ = ("_parent", "_interaction") def __init__(self, interaction: Interaction, parent: ConnectionState): self._interaction: Interaction = interaction diff --git a/discord/state.py b/discord/state.py index b2f25f302d..8e5e8bdcec 100644 --- a/discord/state.py +++ b/discord/state.py @@ -26,59 +26,71 @@ from __future__ import annotations import asyncio -from collections import deque, OrderedDict import copy import datetime +import inspect import itertools import logging -from typing import Dict, Optional, TYPE_CHECKING, Union, Callable, Any, List, TypeVar, Coroutine, Sequence, Tuple, Deque -import inspect - import os +from collections import OrderedDict, deque +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Coroutine, + Deque, + Dict, + List, + Optional, + Sequence, + Tuple, + TypeVar, + Union, +) -from .guild import Guild +from . import utils from .activity import BaseActivity -from .user import User, ClientUser -from .emoji import Emoji -from .mentions import AllowedMentions -from .partial_emoji import PartialEmoji -from .message import Message from .channel import * from .channel import _channel_factory -from .raw_models import * -from .member import Member -from .role import Role -from .enums import ChannelType, try_enum, Status -from . import utils +from .emoji import Emoji +from .enums import ChannelType, Status, try_enum from .flags import ApplicationFlags, Intents, MemberCacheFlags -from .object import Object -from .invite import Invite +from .guild import Guild from .integrations import _integration_factory from .interactions import Interaction -from .ui.view import ViewStore, View +from .invite import Invite +from .member import Member +from .mentions import AllowedMentions +from .message import Message +from .object import Object +from .partial_emoji import PartialEmoji +from .raw_models import * +from .role import Role from .stage_instance import StageInstance -from .threads import Thread, ThreadMember from .sticker import GuildSticker +from .threads import Thread, ThreadMember +from .ui.modal import Modal, ModalStore +from .ui.view import View, ViewStore +from .user import ClientUser, User if TYPE_CHECKING: from .abc import PrivateChannel - from .message import MessageableChannel - from .guild import GuildChannel, VocalGuildChannel - from .http import HTTPClient - from .voice_client import VoiceProtocol from .client import Client from .gateway import DiscordWebSocket - + from .guild import GuildChannel, VocalGuildChannel + from .http import HTTPClient + from .message import MessageableChannel from .types.activity import Activity as ActivityPayload from .types.channel import DMChannel as DMChannelPayload - from .types.user import User as UserPayload from .types.emoji import Emoji as EmojiPayload - from .types.sticker import GuildSticker as GuildStickerPayload from .types.guild import Guild as GuildPayload from .types.message import Message as MessagePayload + from .types.sticker import GuildSticker as GuildStickerPayload + from .types.user import User as UserPayload + from .voice_client import VoiceProtocol - T = TypeVar('T') - CS = TypeVar('CS', bound='ConnectionState') + T = TypeVar("T") + CS = TypeVar("CS", bound="ConnectionState") Channel = Union[GuildChannel, VocalGuildChannel, PrivateChannel, PartialMessageable] @@ -133,11 +145,13 @@ def done(self) -> None: _log = logging.getLogger(__name__) -async def logging_coroutine(coroutine: Coroutine[Any, Any, T], *, info: str) -> Optional[T]: +async def logging_coroutine( + coroutine: Coroutine[Any, Any, T], *, info: str +) -> Optional[T]: try: await coroutine except Exception: - _log.exception('Exception occurred during %s', info) + _log.exception("Exception occurred during %s", info) class ConnectionState: @@ -158,7 +172,7 @@ def __init__( ) -> None: self.loop: asyncio.AbstractEventLoop = loop self.http: HTTPClient = http - self.max_messages: Optional[int] = options.get('max_messages', 1000) + self.max_messages: Optional[int] = options.get("max_messages", 1000) if self.max_messages is not None and self.max_messages <= 0: self.max_messages = 1000 @@ -167,56 +181,70 @@ def __init__( self.hooks: Dict[str, Callable] = hooks self.shard_count: Optional[int] = None self._ready_task: Optional[asyncio.Task] = None - self.application_id: Optional[int] = utils._get_as_snowflake(options, 'application_id') - self.heartbeat_timeout: float = options.get('heartbeat_timeout', 60.0) - self.guild_ready_timeout: float = options.get('guild_ready_timeout', 2.0) + self.application_id: Optional[int] = utils._get_as_snowflake( + options, "application_id" + ) + self.heartbeat_timeout: float = options.get("heartbeat_timeout", 60.0) + self.guild_ready_timeout: float = options.get("guild_ready_timeout", 2.0) if self.guild_ready_timeout < 0: - raise ValueError('guild_ready_timeout cannot be negative') + raise ValueError("guild_ready_timeout cannot be negative") - allowed_mentions = options.get('allowed_mentions') + allowed_mentions = options.get("allowed_mentions") - if allowed_mentions is not None and not isinstance(allowed_mentions, AllowedMentions): - raise TypeError('allowed_mentions parameter must be AllowedMentions') + if allowed_mentions is not None and not isinstance( + allowed_mentions, AllowedMentions + ): + raise TypeError("allowed_mentions parameter must be AllowedMentions") self.allowed_mentions: Optional[AllowedMentions] = allowed_mentions self._chunk_requests: Dict[Union[int, str], ChunkRequest] = {} - activity = options.get('activity', None) + activity = options.get("activity", None) if activity: if not isinstance(activity, BaseActivity): - raise TypeError('activity parameter must derive from BaseActivity.') + raise TypeError("activity parameter must derive from BaseActivity.") activity = activity.to_dict() - status = options.get('status', None) + status = options.get("status", None) if status: if status is Status.offline: - status = 'invisible' + status = "invisible" else: status = str(status) - intents = options.get('intents', None) + intents = options.get("intents", None) if intents is not None: if not isinstance(intents, Intents): - raise TypeError(f'intents parameter must be Intent not {type(intents)!r}') + raise TypeError( + f"intents parameter must be Intent not {type(intents)!r}" + ) else: intents = Intents.default() if not intents.guilds: - _log.warning('Guilds intent seems to be disabled. This may cause state related issues.') + _log.warning( + "Guilds intent seems to be disabled. This may cause state related issues." + ) - self._chunk_guilds: bool = options.get('chunk_guilds_at_startup', intents.members) + self._chunk_guilds: bool = options.get( + "chunk_guilds_at_startup", intents.members + ) # Ensure these two are set properly if not intents.members and self._chunk_guilds: - raise ValueError('Intents.members must be enabled to chunk guilds at startup.') + raise ValueError( + "Intents.members must be enabled to chunk guilds at startup." + ) - cache_flags = options.get('member_cache_flags', None) + cache_flags = options.get("member_cache_flags", None) if cache_flags is None: cache_flags = MemberCacheFlags.from_intents(intents) else: if not isinstance(cache_flags, MemberCacheFlags): - raise TypeError(f'member_cache_flags parameter must be MemberCacheFlags not {type(cache_flags)!r}') + raise TypeError( + f"member_cache_flags parameter must be MemberCacheFlags not {type(cache_flags)!r}" + ) cache_flags._verify_intents(intents) @@ -231,7 +259,7 @@ def __init__( self.parsers = parsers = {} for attr, func in inspect.getmembers(self): - if attr.startswith('parse_'): + if attr.startswith("parse_"): parsers[attr[6:].upper()] = func self.clear() @@ -256,6 +284,7 @@ def clear(self, *, views: bool = True) -> None: self._guilds: Dict[int, Guild] = {} if views: self._view_store: ViewStore = ViewStore(self) + self._modal_store: ModalStore = ModalStore(self) self._voice_clients: Dict[int, VoiceProtocol] = {} @@ -268,7 +297,9 @@ def clear(self, *, views: bool = True) -> None: else: self._messages: Optional[Deque[Message]] = None - def process_chunk_requests(self, guild_id: int, nonce: Optional[str], members: List[Member], complete: bool) -> None: + def process_chunk_requests( + self, guild_id: int, nonce: Optional[str], members: List[Member], complete: bool + ) -> None: removed = [] for key, request in self._chunk_requests.items(): if request.guild_id == guild_id and request.nonce == nonce: @@ -326,12 +357,12 @@ def _update_references(self, ws: DiscordWebSocket) -> None: vc.main_ws = ws # type: ignore def store_user(self, data: UserPayload) -> User: - user_id = int(data['id']) + user_id = int(data["id"]) try: return self._users[user_id] except KeyError: user = User(state=self, data=data) - if user.discriminator != '0000': + if user.discriminator != "0000": self._users[user_id] = user user._stored = True return user @@ -351,18 +382,21 @@ def get_user(self, id: Optional[int]) -> Optional[User]: def store_emoji(self, guild: Guild, data: EmojiPayload) -> Emoji: # the id will be present here - emoji_id = int(data['id']) # type: ignore + emoji_id = int(data["id"]) # type: ignore self._emojis[emoji_id] = emoji = Emoji(guild=guild, state=self, data=data) return emoji def store_sticker(self, guild: Guild, data: GuildStickerPayload) -> GuildSticker: - sticker_id = int(data['id']) + sticker_id = int(data["id"]) self._stickers[sticker_id] = sticker = GuildSticker(state=self, data=data) return sticker def store_view(self, view: View, message_id: Optional[int] = None) -> None: self._view_store.add_view(view, message_id) + def store_modal(self, modal: Modal, message_id: int) -> None: + self._modal_store.add_modal(modal, message_id) + def prevent_view_updates_for(self, message_id: int) -> Optional[View]: return self._view_store.remove_message_tracking(message_id) @@ -412,7 +446,9 @@ def get_sticker(self, sticker_id: Optional[int]) -> Optional[GuildSticker]: def private_channels(self) -> List[PrivateChannel]: return list(self._private_channels.values()) - def _get_private_channel(self, channel_id: Optional[int]) -> Optional[PrivateChannel]: + def _get_private_channel( + self, channel_id: Optional[int] + ) -> Optional[PrivateChannel]: try: # the keys of self._private_channels are ints value = self._private_channels[channel_id] # type: ignore @@ -422,7 +458,9 @@ def _get_private_channel(self, channel_id: Optional[int]) -> Optional[PrivateCha self._private_channels.move_to_end(channel_id) # type: ignore return value - def _get_private_channel_by_user(self, user_id: Optional[int]) -> Optional[DMChannel]: + def _get_private_channel_by_user( + self, user_id: Optional[int] + ) -> Optional[DMChannel]: # the keys of self._private_channels are ints return self._private_channels_by_user.get(user_id) # type: ignore @@ -452,7 +490,11 @@ def _remove_private_channel(self, channel: PrivateChannel) -> None: self._private_channels_by_user.pop(recipient.id, None) def _get_message(self, msg_id: Optional[int]) -> Optional[Message]: - return utils.find(lambda m: m.id == msg_id, reversed(self._messages)) if self._messages else None + return ( + utils.find(lambda m: m.id == msg_id, reversed(self._messages)) + if self._messages + else None + ) def _add_guild_from_data(self, data: GuildPayload) -> Guild: guild = Guild(data=data, state=self) @@ -461,12 +503,18 @@ def _add_guild_from_data(self, data: GuildPayload) -> Guild: def _guild_needs_chunking(self, guild: Guild) -> bool: # If presences are enabled then we get back the old guild.large behaviour - return self._chunk_guilds and not guild.chunked and not (self._intents.presences and not guild.large) - - def _get_guild_channel(self, data: MessagePayload) -> Tuple[Union[Channel, Thread], Optional[Guild]]: - channel_id = int(data['channel_id']) + return ( + self._chunk_guilds + and not guild.chunked + and not (self._intents.presences and not guild.large) + ) + + def _get_guild_channel( + self, data: MessagePayload + ) -> Tuple[Union[Channel, Thread], Optional[Guild]]: + channel_id = int(data["channel_id"]) try: - guild = self._get_guild(int(data['guild_id'])) + guild = self._get_guild(int(data["guild_id"])) except KeyError: channel = DMChannel._from_message(self, channel_id) guild = None @@ -476,16 +524,32 @@ def _get_guild_channel(self, data: MessagePayload) -> Tuple[Union[Channel, Threa return channel or PartialMessageable(state=self, id=channel_id), guild async def chunker( - self, guild_id: int, query: str = '', limit: int = 0, presences: bool = False, *, nonce: Optional[str] = None + self, + guild_id: int, + query: str = "", + limit: int = 0, + presences: bool = False, + *, + nonce: Optional[str] = None, ) -> None: ws = self._get_websocket(guild_id) # This is ignored upstream - await ws.request_chunks(guild_id, query=query, limit=limit, presences=presences, nonce=nonce) + await ws.request_chunks( + guild_id, query=query, limit=limit, presences=presences, nonce=nonce + ) - async def query_members(self, guild: Guild, query: str, limit: int, user_ids: List[int], cache: bool, presences: bool): + async def query_members( + self, + guild: Guild, + query: str, + limit: int, + user_ids: List[int], + cache: bool, + presences: bool, + ): guild_id = guild.id ws = self._get_websocket(guild_id) if ws is None: - raise RuntimeError('Somehow do not have a websocket for this guild_id') + raise RuntimeError("Somehow do not have a websocket for this guild_id") request = ChunkRequest(guild.id, self.loop, self._get_guild, cache=cache) self._chunk_requests[request.nonce] = request @@ -493,11 +557,21 @@ async def query_members(self, guild: Guild, query: str, limit: int, user_ids: Li try: # start the query operation await ws.request_chunks( - guild_id, query=query, limit=limit, user_ids=user_ids, presences=presences, nonce=request.nonce + guild_id, + query=query, + limit=limit, + user_ids=user_ids, + presences=presences, + nonce=request.nonce, ) return await asyncio.wait_for(request.wait(), timeout=30.0) except asyncio.TimeoutError: - _log.warning('Timed out waiting for chunks with query %r and limit %d for guild_id %d', query, limit, guild_id) + _log.warning( + "Timed out waiting for chunks with query %r and limit %d for guild_id %d", + query, + limit, + guild_id, + ) raise async def _delay_ready(self) -> None: @@ -507,7 +581,9 @@ async def _delay_ready(self) -> None: # this snippet of code is basically waiting N seconds # until the last GUILD_CREATE was sent try: - guild = await asyncio.wait_for(self._ready_state.get(), timeout=self.guild_ready_timeout) + guild = await asyncio.wait_for( + self._ready_state.get(), timeout=self.guild_ready_timeout + ) except asyncio.TimeoutError: break else: @@ -516,20 +592,24 @@ async def _delay_ready(self) -> None: states.append((guild, future)) else: if guild.unavailable is False: - self.dispatch('guild_available', guild) + self.dispatch("guild_available", guild) else: - self.dispatch('guild_join', guild) + self.dispatch("guild_join", guild) for guild, future in states: try: await asyncio.wait_for(future, timeout=5.0) except asyncio.TimeoutError: - _log.warning('Shard ID %s timed out waiting for chunks for guild_id %s.', guild.shard_id, guild.id) + _log.warning( + "Shard ID %s timed out waiting for chunks for guild_id %s.", + guild.shard_id, + guild.id, + ) if guild.unavailable is False: - self.dispatch('guild_available', guild) + self.dispatch("guild_available", guild) else: - self.dispatch('guild_join', guild) + self.dispatch("guild_join", guild) # remove the state try: @@ -541,8 +621,8 @@ async def _delay_ready(self) -> None: pass else: # dispatch the event - self.call_handlers('ready') - self.dispatch('ready') + self.call_handlers("ready") + self.dispatch("ready") finally: self._ready_task = None @@ -552,33 +632,33 @@ def parse_ready(self, data) -> None: self._ready_state = asyncio.Queue() self.clear(views=False) - self.user = ClientUser(state=self, data=data['user']) - self.store_user(data['user']) + self.user = ClientUser(state=self, data=data["user"]) + self.store_user(data["user"]) if self.application_id is None: try: - application = data['application'] + application = data["application"] except KeyError: pass else: - self.application_id = utils._get_as_snowflake(application, 'id') + self.application_id = utils._get_as_snowflake(application, "id") # flags will always be present here - self.application_flags = ApplicationFlags._from_value(application['flags']) # type: ignore + self.application_flags = ApplicationFlags._from_value(application["flags"]) # type: ignore - for guild_data in data['guilds']: + for guild_data in data["guilds"]: self._add_guild_from_data(guild_data) - self.dispatch('connect') + self.dispatch("connect") self._ready_task = asyncio.create_task(self._delay_ready()) def parse_resumed(self, data) -> None: - self.dispatch('resumed') + self.dispatch("resumed") def parse_message_create(self, data) -> None: channel, _ = self._get_guild_channel(data) # channel would be the correct type here message = Message(channel=channel, data=data, state=self) # type: ignore - self.dispatch('message', message) + self.dispatch("message", message) if self._messages is not None: self._messages.append(message) # we ensure that the channel is either a TextChannel or Thread @@ -589,21 +669,23 @@ def parse_message_delete(self, data) -> None: raw = RawMessageDeleteEvent(data) found = self._get_message(raw.message_id) raw.cached_message = found - self.dispatch('raw_message_delete', raw) + self.dispatch("raw_message_delete", raw) if self._messages is not None and found is not None: - self.dispatch('message_delete', found) + self.dispatch("message_delete", found) self._messages.remove(found) def parse_message_delete_bulk(self, data) -> None: raw = RawBulkMessageDeleteEvent(data) if self._messages: - found_messages = [message for message in self._messages if message.id in raw.message_ids] + found_messages = [ + message for message in self._messages if message.id in raw.message_ids + ] else: found_messages = [] raw.cached_messages = found_messages - self.dispatch('raw_bulk_message_delete', raw) + self.dispatch("raw_bulk_message_delete", raw) if found_messages: - self.dispatch('bulk_message_delete', found_messages) + self.dispatch("bulk_message_delete", found_messages) for msg in found_messages: # self._messages won't be None here self._messages.remove(msg) # type: ignore @@ -614,25 +696,27 @@ def parse_message_update(self, data) -> None: if message is not None: older_message = copy.copy(message) raw.cached_message = older_message - self.dispatch('raw_message_edit', raw) + self.dispatch("raw_message_edit", raw) message._update(data) # Coerce the `after` parameter to take the new updated Member # ref: #5999 older_message.author = message.author - self.dispatch('message_edit', older_message, message) + self.dispatch("message_edit", older_message, message) else: - self.dispatch('raw_message_edit', raw) + self.dispatch("raw_message_edit", raw) - if 'components' in data and self._view_store.is_message_tracked(raw.message_id): - self._view_store.update_from_message(raw.message_id, data['components']) + if "components" in data and self._view_store.is_message_tracked(raw.message_id): + self._view_store.update_from_message(raw.message_id, data["components"]) def parse_message_reaction_add(self, data) -> None: - emoji = data['emoji'] - emoji_id = utils._get_as_snowflake(emoji, 'id') - emoji = PartialEmoji.with_state(self, id=emoji_id, animated=emoji.get('animated', False), name=emoji['name']) - raw = RawReactionActionEvent(data, emoji, 'REACTION_ADD') - - member_data = data.get('member') + emoji = data["emoji"] + emoji_id = utils._get_as_snowflake(emoji, "id") + emoji = PartialEmoji.with_state( + self, id=emoji_id, animated=emoji.get("animated", False), name=emoji["name"] + ) + raw = RawReactionActionEvent(data, emoji, "REACTION_ADD") + + member_data = data.get("member") if member_data: guild = self._get_guild(raw.guild_id) if guild is not None: @@ -641,7 +725,7 @@ def parse_message_reaction_add(self, data) -> None: raw.member = None else: raw.member = None - self.dispatch('raw_reaction_add', raw) + self.dispatch("raw_reaction_add", raw) # rich interface here message = self._get_message(raw.message_id) @@ -651,24 +735,24 @@ def parse_message_reaction_add(self, data) -> None: user = raw.member or self._get_reaction_user(message.channel, raw.user_id) if user: - self.dispatch('reaction_add', reaction, user) + self.dispatch("reaction_add", reaction, user) def parse_message_reaction_remove_all(self, data) -> None: raw = RawReactionClearEvent(data) - self.dispatch('raw_reaction_clear', raw) + self.dispatch("raw_reaction_clear", raw) message = self._get_message(raw.message_id) if message is not None: old_reactions = message.reactions.copy() message.reactions.clear() - self.dispatch('reaction_clear', message, old_reactions) + self.dispatch("reaction_clear", message, old_reactions) def parse_message_reaction_remove(self, data) -> None: - emoji = data['emoji'] - emoji_id = utils._get_as_snowflake(emoji, 'id') - emoji = PartialEmoji.with_state(self, id=emoji_id, name=emoji['name']) - raw = RawReactionActionEvent(data, emoji, 'REACTION_REMOVE') - self.dispatch('raw_reaction_remove', raw) + emoji = data["emoji"] + emoji_id = utils._get_as_snowflake(emoji, "id") + emoji = PartialEmoji.with_state(self, id=emoji_id, name=emoji["name"]) + raw = RawReactionActionEvent(data, emoji, "REACTION_REMOVE") + self.dispatch("raw_reaction_remove", raw) message = self._get_message(raw.message_id) if message is not None: @@ -680,14 +764,14 @@ def parse_message_reaction_remove(self, data) -> None: else: user = self._get_reaction_user(message.channel, raw.user_id) if user: - self.dispatch('reaction_remove', reaction, user) + self.dispatch("reaction_remove", reaction, user) def parse_message_reaction_remove_emoji(self, data) -> None: - emoji = data['emoji'] - emoji_id = utils._get_as_snowflake(emoji, 'id') - emoji = PartialEmoji.with_state(self, id=emoji_id, name=emoji['name']) + emoji = data["emoji"] + emoji_id = utils._get_as_snowflake(emoji, "id") + emoji = PartialEmoji.with_state(self, id=emoji_id, name=emoji["name"]) raw = RawReactionClearEmojiEvent(data, emoji) - self.dispatch('raw_reaction_clear_emoji', raw) + self.dispatch("raw_reaction_clear_emoji", raw) message = self._get_message(raw.message_id) if message is not None: @@ -697,38 +781,44 @@ def parse_message_reaction_remove_emoji(self, data) -> None: pass else: if reaction: - self.dispatch('reaction_clear_emoji', reaction) + self.dispatch("reaction_clear_emoji", reaction) def parse_interaction_create(self, data) -> None: interaction = Interaction(data=data, state=self) - if data['type'] == 3: # interaction component - custom_id = interaction.data['custom_id'] # type: ignore - component_type = interaction.data['component_type'] # type: ignore + if data["type"] == 3: # interaction component + custom_id = interaction.data["custom_id"] # type: ignore + component_type = interaction.data["component_type"] # type: ignore self._view_store.dispatch(component_type, custom_id, interaction) - self.dispatch('interaction', interaction) + self.dispatch("interaction", interaction) def parse_presence_update(self, data) -> None: - guild_id = utils._get_as_snowflake(data, 'guild_id') + guild_id = utils._get_as_snowflake(data, "guild_id") # guild_id won't be None here guild = self._get_guild(guild_id) if guild is None: - _log.debug('PRESENCE_UPDATE referencing an unknown guild ID: %s. Discarding.', guild_id) + _log.debug( + "PRESENCE_UPDATE referencing an unknown guild ID: %s. Discarding.", + guild_id, + ) return - user = data['user'] - member_id = int(user['id']) + user = data["user"] + member_id = int(user["id"]) member = guild.get_member(member_id) if member is None: - _log.debug('PRESENCE_UPDATE referencing an unknown member ID: %s. Discarding', member_id) + _log.debug( + "PRESENCE_UPDATE referencing an unknown member ID: %s. Discarding", + member_id, + ) return old_member = Member._copy(member) user_update = member._presence_update(data=data, user=user) if user_update: - self.dispatch('user_update', user_update[0], user_update[1]) + self.dispatch("user_update", user_update[0], user_update[1]) - self.dispatch('presence_update', old_member, member) + self.dispatch("presence_update", old_member, member) def parse_user_update(self, data) -> None: # self.user is *always* cached when this is called @@ -740,66 +830,78 @@ def parse_user_update(self, data) -> None: def parse_invite_create(self, data) -> None: invite = Invite.from_gateway(state=self, data=data) - self.dispatch('invite_create', invite) + self.dispatch("invite_create", invite) def parse_invite_delete(self, data) -> None: invite = Invite.from_gateway(state=self, data=data) - self.dispatch('invite_delete', invite) + self.dispatch("invite_delete", invite) def parse_channel_delete(self, data) -> None: - guild = self._get_guild(utils._get_as_snowflake(data, 'guild_id')) - channel_id = int(data['id']) + guild = self._get_guild(utils._get_as_snowflake(data, "guild_id")) + channel_id = int(data["id"]) if guild is not None: channel = guild.get_channel(channel_id) if channel is not None: guild._remove_channel(channel) - self.dispatch('guild_channel_delete', channel) + self.dispatch("guild_channel_delete", channel) def parse_channel_update(self, data) -> None: - channel_type = try_enum(ChannelType, data.get('type')) - channel_id = int(data['id']) + channel_type = try_enum(ChannelType, data.get("type")) + channel_id = int(data["id"]) if channel_type is ChannelType.group: channel = self._get_private_channel(channel_id) old_channel = copy.copy(channel) # the channel is a GroupChannel channel._update_group(data) # type: ignore - self.dispatch('private_channel_update', old_channel, channel) + self.dispatch("private_channel_update", old_channel, channel) return - guild_id = utils._get_as_snowflake(data, 'guild_id') + guild_id = utils._get_as_snowflake(data, "guild_id") guild = self._get_guild(guild_id) if guild is not None: channel = guild.get_channel(channel_id) if channel is not None: old_channel = copy.copy(channel) channel._update(guild, data) - self.dispatch('guild_channel_update', old_channel, channel) + self.dispatch("guild_channel_update", old_channel, channel) else: - _log.debug('CHANNEL_UPDATE referencing an unknown channel ID: %s. Discarding.', channel_id) + _log.debug( + "CHANNEL_UPDATE referencing an unknown channel ID: %s. Discarding.", + channel_id, + ) else: - _log.debug('CHANNEL_UPDATE referencing an unknown guild ID: %s. Discarding.', guild_id) + _log.debug( + "CHANNEL_UPDATE referencing an unknown guild ID: %s. Discarding.", + guild_id, + ) def parse_channel_create(self, data) -> None: - factory, ch_type = _channel_factory(data['type']) + factory, ch_type = _channel_factory(data["type"]) if factory is None: - _log.debug('CHANNEL_CREATE referencing an unknown channel type %s. Discarding.', data['type']) + _log.debug( + "CHANNEL_CREATE referencing an unknown channel type %s. Discarding.", + data["type"], + ) return - guild_id = utils._get_as_snowflake(data, 'guild_id') + guild_id = utils._get_as_snowflake(data, "guild_id") guild = self._get_guild(guild_id) if guild is not None: # the factory can't be a DMChannel or GroupChannel here channel = factory(guild=guild, state=self, data=data) # type: ignore guild._add_channel(channel) # type: ignore - self.dispatch('guild_channel_create', channel) + self.dispatch("guild_channel_create", channel) else: - _log.debug('CHANNEL_CREATE referencing an unknown guild ID: %s. Discarding.', guild_id) + _log.debug( + "CHANNEL_CREATE referencing an unknown guild ID: %s. Discarding.", + guild_id, + ) return def parse_channel_pins_update(self, data) -> None: - channel_id = int(data['channel_id']) + channel_id = int(data["channel_id"]) try: - guild = self._get_guild(int(data['guild_id'])) + guild = self._get_guild(int(data["guild_id"])) except KeyError: guild = None channel = self._get_private_channel(channel_id) @@ -807,75 +909,93 @@ def parse_channel_pins_update(self, data) -> None: channel = guild and guild._resolve_channel(channel_id) if channel is None: - _log.debug('CHANNEL_PINS_UPDATE referencing an unknown channel ID: %s. Discarding.', channel_id) + _log.debug( + "CHANNEL_PINS_UPDATE referencing an unknown channel ID: %s. Discarding.", + channel_id, + ) return - last_pin = utils.parse_time(data['last_pin_timestamp']) if data['last_pin_timestamp'] else None + last_pin = ( + utils.parse_time(data["last_pin_timestamp"]) + if data["last_pin_timestamp"] + else None + ) if guild is None: - self.dispatch('private_channel_pins_update', channel, last_pin) + self.dispatch("private_channel_pins_update", channel, last_pin) else: - self.dispatch('guild_channel_pins_update', channel, last_pin) + self.dispatch("guild_channel_pins_update", channel, last_pin) def parse_thread_create(self, data) -> None: - guild_id = int(data['guild_id']) + guild_id = int(data["guild_id"]) guild: Optional[Guild] = self._get_guild(guild_id) if guild is None: - _log.debug('THREAD_CREATE referencing an unknown guild ID: %s. Discarding', guild_id) + _log.debug( + "THREAD_CREATE referencing an unknown guild ID: %s. Discarding", + guild_id, + ) return thread = Thread(guild=guild, state=guild._state, data=data) has_thread = guild.get_thread(thread.id) guild._add_thread(thread) if not has_thread: - self.dispatch('thread_join', thread) + self.dispatch("thread_join", thread) def parse_thread_update(self, data) -> None: - guild_id = int(data['guild_id']) + guild_id = int(data["guild_id"]) guild = self._get_guild(guild_id) if guild is None: - _log.debug('THREAD_UPDATE referencing an unknown guild ID: %s. Discarding', guild_id) + _log.debug( + "THREAD_UPDATE referencing an unknown guild ID: %s. Discarding", + guild_id, + ) return - thread_id = int(data['id']) + thread_id = int(data["id"]) thread = guild.get_thread(thread_id) if thread is not None: old = copy.copy(thread) thread._update(data) - self.dispatch('thread_update', old, thread) + self.dispatch("thread_update", old, thread) else: thread = Thread(guild=guild, state=guild._state, data=data) guild._add_thread(thread) - self.dispatch('thread_join', thread) + self.dispatch("thread_join", thread) def parse_thread_delete(self, data) -> None: - guild_id = int(data['guild_id']) + guild_id = int(data["guild_id"]) guild = self._get_guild(guild_id) if guild is None: - _log.debug('THREAD_DELETE referencing an unknown guild ID: %s. Discarding', guild_id) + _log.debug( + "THREAD_DELETE referencing an unknown guild ID: %s. Discarding", + guild_id, + ) return raw = RawThreadDeleteEvent(data) thread = guild.get_thread(raw.thread_id) raw.thread = thread - self.dispatch('raw_thread_delete', raw) - + self.dispatch("raw_thread_delete", raw) if thread is not None: guild._remove_thread(thread) # type: ignore - self.dispatch('thread_delete', thread) + self.dispatch("thread_delete", thread) def parse_thread_list_sync(self, data) -> None: - guild_id = int(data['guild_id']) + guild_id = int(data["guild_id"]) guild: Optional[Guild] = self._get_guild(guild_id) if guild is None: - _log.debug('THREAD_LIST_SYNC referencing an unknown guild ID: %s. Discarding', guild_id) + _log.debug( + "THREAD_LIST_SYNC referencing an unknown guild ID: %s. Discarding", + guild_id, + ) return try: - channel_ids = set(data['channel_ids']) + channel_ids = set(data["channel_ids"]) except KeyError: # If not provided, then the entire guild is being synced # So all previous thread data should be overwritten @@ -884,12 +1004,12 @@ def parse_thread_list_sync(self, data) -> None: else: previous_threads = guild._filter_threads(channel_ids) - threads = {d['id']: guild._store_thread(d) for d in data.get('threads', [])} + threads = {d["id"]: guild._store_thread(d) for d in data.get("threads", [])} - for member in data.get('members', []): + for member in data.get("members", []): try: # note: member['id'] is the thread_id - thread = threads[member['id']] + thread = threads[member["id"]] except KeyError: continue else: @@ -898,63 +1018,78 @@ def parse_thread_list_sync(self, data) -> None: for thread in threads.values(): old = previous_threads.pop(thread.id, None) if old is None: - self.dispatch('thread_join', thread) + self.dispatch("thread_join", thread) for thread in previous_threads.values(): - self.dispatch('thread_remove', thread) + self.dispatch("thread_remove", thread) def parse_thread_member_update(self, data) -> None: - guild_id = int(data['guild_id']) + guild_id = int(data["guild_id"]) guild: Optional[Guild] = self._get_guild(guild_id) if guild is None: - _log.debug('THREAD_MEMBER_UPDATE referencing an unknown guild ID: %s. Discarding', guild_id) + _log.debug( + "THREAD_MEMBER_UPDATE referencing an unknown guild ID: %s. Discarding", + guild_id, + ) return - thread_id = int(data['id']) + thread_id = int(data["id"]) thread: Optional[Thread] = guild.get_thread(thread_id) if thread is None: - _log.debug('THREAD_MEMBER_UPDATE referencing an unknown thread ID: %s. Discarding', thread_id) + _log.debug( + "THREAD_MEMBER_UPDATE referencing an unknown thread ID: %s. Discarding", + thread_id, + ) return member = ThreadMember(thread, data) thread.me = member def parse_thread_members_update(self, data) -> None: - guild_id = int(data['guild_id']) + guild_id = int(data["guild_id"]) guild: Optional[Guild] = self._get_guild(guild_id) if guild is None: - _log.debug('THREAD_MEMBERS_UPDATE referencing an unknown guild ID: %s. Discarding', guild_id) + _log.debug( + "THREAD_MEMBERS_UPDATE referencing an unknown guild ID: %s. Discarding", + guild_id, + ) return - thread_id = int(data['id']) + thread_id = int(data["id"]) thread: Optional[Thread] = guild.get_thread(thread_id) if thread is None: - _log.debug('THREAD_MEMBERS_UPDATE referencing an unknown thread ID: %s. Discarding', thread_id) + _log.debug( + "THREAD_MEMBERS_UPDATE referencing an unknown thread ID: %s. Discarding", + thread_id, + ) return - added_members = [ThreadMember(thread, d) for d in data.get('added_members', [])] - removed_member_ids = [int(x) for x in data.get('removed_member_ids', [])] + added_members = [ThreadMember(thread, d) for d in data.get("added_members", [])] + removed_member_ids = [int(x) for x in data.get("removed_member_ids", [])] self_id = self.self_id for member in added_members: if member.id != self_id: thread._add_member(member) - self.dispatch('thread_member_join', member) + self.dispatch("thread_member_join", member) else: thread.me = member - self.dispatch('thread_join', thread) + self.dispatch("thread_join", thread) for member_id in removed_member_ids: if member_id != self_id: member = thread._pop_member(member_id) if member is not None: - self.dispatch('thread_member_remove', member) + self.dispatch("thread_member_remove", member) else: - self.dispatch('thread_remove', thread) + self.dispatch("thread_remove", thread) def parse_guild_member_add(self, data) -> None: - guild = self._get_guild(int(data['guild_id'])) + guild = self._get_guild(int(data["guild_id"])) if guild is None: - _log.debug('GUILD_MEMBER_ADD referencing an unknown guild ID: %s. Discarding.', data['guild_id']) + _log.debug( + "GUILD_MEMBER_ADD referencing an unknown guild ID: %s. Discarding.", + data["guild_id"], + ) return member = Member(guild=guild, data=data, state=self) @@ -966,30 +1101,36 @@ def parse_guild_member_add(self, data) -> None: except AttributeError: pass - self.dispatch('member_join', member) + self.dispatch("member_join", member) def parse_guild_member_remove(self, data) -> None: - guild = self._get_guild(int(data['guild_id'])) + guild = self._get_guild(int(data["guild_id"])) if guild is not None: try: guild._member_count -= 1 except AttributeError: pass - user_id = int(data['user']['id']) + user_id = int(data["user"]["id"]) member = guild.get_member(user_id) if member is not None: guild._remove_member(member) # type: ignore - self.dispatch('member_remove', member) + self.dispatch("member_remove", member) else: - _log.debug('GUILD_MEMBER_REMOVE referencing an unknown guild ID: %s. Discarding.', data['guild_id']) + _log.debug( + "GUILD_MEMBER_REMOVE referencing an unknown guild ID: %s. Discarding.", + data["guild_id"], + ) def parse_guild_member_update(self, data) -> None: - guild = self._get_guild(int(data['guild_id'])) - user = data['user'] - user_id = int(user['id']) + guild = self._get_guild(int(data["guild_id"])) + user = data["user"] + user_id = int(user["id"]) if guild is None: - _log.debug('GUILD_MEMBER_UPDATE referencing an unknown guild ID: %s. Discarding.', data['guild_id']) + _log.debug( + "GUILD_MEMBER_UPDATE referencing an unknown guild ID: %s. Discarding.", + data["guild_id"], + ) return member = guild.get_member(user_id) @@ -998,9 +1139,9 @@ def parse_guild_member_update(self, data) -> None: member._update(data) user_update = member._update_inner_user(user) if user_update: - self.dispatch('user_update', user_update[0], user_update[1]) + self.dispatch("user_update", user_update[0], user_update[1]) - self.dispatch('member_update', old_member, member) + self.dispatch("member_update", old_member, member) else: if self.member_cache_flags.joined: member = Member(data=data, guild=guild, state=self) @@ -1008,43 +1149,52 @@ def parse_guild_member_update(self, data) -> None: # Force an update on the inner user if necessary user_update = member._update_inner_user(user) if user_update: - self.dispatch('user_update', user_update[0], user_update[1]) + self.dispatch("user_update", user_update[0], user_update[1]) guild._add_member(member) - _log.debug('GUILD_MEMBER_UPDATE referencing an unknown member ID: %s. Discarding.', user_id) + _log.debug( + "GUILD_MEMBER_UPDATE referencing an unknown member ID: %s. Discarding.", + user_id, + ) def parse_guild_emojis_update(self, data) -> None: - guild = self._get_guild(int(data['guild_id'])) + guild = self._get_guild(int(data["guild_id"])) if guild is None: - _log.debug('GUILD_EMOJIS_UPDATE referencing an unknown guild ID: %s. Discarding.', data['guild_id']) + _log.debug( + "GUILD_EMOJIS_UPDATE referencing an unknown guild ID: %s. Discarding.", + data["guild_id"], + ) return before_emojis = guild.emojis for emoji in before_emojis: self._emojis.pop(emoji.id, None) # guild won't be None here - guild.emojis = tuple(map(lambda d: self.store_emoji(guild, d), data['emojis'])) # type: ignore - self.dispatch('guild_emojis_update', guild, before_emojis, guild.emojis) + guild.emojis = tuple(map(lambda d: self.store_emoji(guild, d), data["emojis"])) # type: ignore + self.dispatch("guild_emojis_update", guild, before_emojis, guild.emojis) def parse_guild_stickers_update(self, data) -> None: - guild = self._get_guild(int(data['guild_id'])) + guild = self._get_guild(int(data["guild_id"])) if guild is None: - _log.debug('GUILD_STICKERS_UPDATE referencing an unknown guild ID: %s. Discarding.', data['guild_id']) + _log.debug( + "GUILD_STICKERS_UPDATE referencing an unknown guild ID: %s. Discarding.", + data["guild_id"], + ) return before_stickers = guild.stickers for emoji in before_stickers: self._stickers.pop(emoji.id, None) # guild won't be None here - guild.stickers = tuple(map(lambda d: self.store_sticker(guild, d), data['stickers'])) # type: ignore - self.dispatch('guild_stickers_update', guild, before_stickers, guild.stickers) + guild.stickers = tuple(map(lambda d: self.store_sticker(guild, d), data["stickers"])) # type: ignore + self.dispatch("guild_stickers_update", guild, before_stickers, guild.stickers) def _get_create_guild(self, data): - if data.get('unavailable') is False: + if data.get("unavailable") is False: # GUILD_CREATE with unavailable in the response # usually means that the guild has become available # and is therefore in the cache - guild = self._get_guild(int(data['id'])) + guild = self._get_guild(int(data["id"])) if guild is not None: guild.unavailable = False guild._from_data(data) @@ -1059,7 +1209,9 @@ async def chunk_guild(self, guild, *, wait=True, cache=None): cache = cache or self.member_cache_flags.joined request = self._chunk_requests.get(guild.id) if request is None: - self._chunk_requests[guild.id] = request = ChunkRequest(guild.id, self.loop, self._get_guild, cache=cache) + self._chunk_requests[guild.id] = request = ChunkRequest( + guild.id, self.loop, self._get_guild, cache=cache + ) await self.chunker(guild.id, nonce=request.nonce) if wait: @@ -1070,15 +1222,15 @@ async def _chunk_and_dispatch(self, guild, unavailable): try: await asyncio.wait_for(self.chunk_guild(guild), timeout=60.0) except asyncio.TimeoutError: - _log.info('Somehow timed out waiting for chunks.') + _log.info("Somehow timed out waiting for chunks.") if unavailable is False: - self.dispatch('guild_available', guild) + self.dispatch("guild_available", guild) else: - self.dispatch('guild_join', guild) + self.dispatch("guild_join", guild) def parse_guild_create(self, data) -> None: - unavailable = data.get('unavailable') + unavailable = data.get("unavailable") if unavailable is True: # joined a guild with unavailable == True so.. return @@ -1101,40 +1253,47 @@ def parse_guild_create(self, data) -> None: # Dispatch available if newly available if unavailable is False: - self.dispatch('guild_available', guild) + self.dispatch("guild_available", guild) else: - self.dispatch('guild_join', guild) + self.dispatch("guild_join", guild) def parse_guild_update(self, data) -> None: - guild = self._get_guild(int(data['id'])) + guild = self._get_guild(int(data["id"])) if guild is not None: old_guild = copy.copy(guild) guild._from_data(data) - self.dispatch('guild_update', old_guild, guild) + self.dispatch("guild_update", old_guild, guild) else: - _log.debug('GUILD_UPDATE referencing an unknown guild ID: %s. Discarding.', data['id']) + _log.debug( + "GUILD_UPDATE referencing an unknown guild ID: %s. Discarding.", + data["id"], + ) def parse_guild_delete(self, data) -> None: - guild = self._get_guild(int(data['id'])) + guild = self._get_guild(int(data["id"])) if guild is None: - _log.debug('GUILD_DELETE referencing an unknown guild ID: %s. Discarding.', data['id']) + _log.debug( + "GUILD_DELETE referencing an unknown guild ID: %s. Discarding.", + data["id"], + ) return - if data.get('unavailable', False): + if data.get("unavailable", False): # GUILD_DELETE with unavailable being True means that the # guild that was available is now currently unavailable guild.unavailable = True - self.dispatch('guild_unavailable', guild) + self.dispatch("guild_unavailable", guild) return # do a cleanup of the messages cache if self._messages is not None: self._messages: Optional[Deque[Message]] = deque( - (msg for msg in self._messages if msg.guild != guild), maxlen=self.max_messages + (msg for msg in self._messages if msg.guild != guild), + maxlen=self.max_messages, ) self._remove_guild(guild) - self.dispatch('guild_remove', guild) + self.dispatch("guild_remove", guild) def parse_guild_ban_add(self, data) -> None: # we make the assumption that GUILD_BAN_ADD is done @@ -1142,204 +1301,262 @@ def parse_guild_ban_add(self, data) -> None: # hence we don't remove it from cache or do anything # strange with it, the main purpose of this event # is mainly to dispatch to another event worth listening to for logging - guild = self._get_guild(int(data['guild_id'])) + guild = self._get_guild(int(data["guild_id"])) if guild is not None: try: - user = User(data=data['user'], state=self) + user = User(data=data["user"], state=self) except KeyError: pass else: member = guild.get_member(user.id) or user - self.dispatch('member_ban', guild, member) + self.dispatch("member_ban", guild, member) def parse_guild_ban_remove(self, data) -> None: - guild = self._get_guild(int(data['guild_id'])) - if guild is not None and 'user' in data: - user = self.store_user(data['user']) - self.dispatch('member_unban', guild, user) + guild = self._get_guild(int(data["guild_id"])) + if guild is not None and "user" in data: + user = self.store_user(data["user"]) + self.dispatch("member_unban", guild, user) def parse_guild_role_create(self, data) -> None: - guild = self._get_guild(int(data['guild_id'])) + guild = self._get_guild(int(data["guild_id"])) if guild is None: - _log.debug('GUILD_ROLE_CREATE referencing an unknown guild ID: %s. Discarding.', data['guild_id']) + _log.debug( + "GUILD_ROLE_CREATE referencing an unknown guild ID: %s. Discarding.", + data["guild_id"], + ) return - role_data = data['role'] + role_data = data["role"] role = Role(guild=guild, data=role_data, state=self) guild._add_role(role) - self.dispatch('guild_role_create', role) + self.dispatch("guild_role_create", role) def parse_guild_role_delete(self, data) -> None: - guild = self._get_guild(int(data['guild_id'])) + guild = self._get_guild(int(data["guild_id"])) if guild is not None: - role_id = int(data['role_id']) + role_id = int(data["role_id"]) try: role = guild._remove_role(role_id) except KeyError: return else: - self.dispatch('guild_role_delete', role) + self.dispatch("guild_role_delete", role) else: - _log.debug('GUILD_ROLE_DELETE referencing an unknown guild ID: %s. Discarding.', data['guild_id']) + _log.debug( + "GUILD_ROLE_DELETE referencing an unknown guild ID: %s. Discarding.", + data["guild_id"], + ) def parse_guild_role_update(self, data) -> None: - guild = self._get_guild(int(data['guild_id'])) + guild = self._get_guild(int(data["guild_id"])) if guild is not None: - role_data = data['role'] - role_id = int(role_data['id']) + role_data = data["role"] + role_id = int(role_data["id"]) role = guild.get_role(role_id) if role is not None: old_role = copy.copy(role) role._update(role_data) - self.dispatch('guild_role_update', old_role, role) + self.dispatch("guild_role_update", old_role, role) else: - _log.debug('GUILD_ROLE_UPDATE referencing an unknown guild ID: %s. Discarding.', data['guild_id']) + _log.debug( + "GUILD_ROLE_UPDATE referencing an unknown guild ID: %s. Discarding.", + data["guild_id"], + ) def parse_guild_members_chunk(self, data) -> None: - guild_id = int(data['guild_id']) + guild_id = int(data["guild_id"]) guild = self._get_guild(guild_id) - presences = data.get('presences', []) + presences = data.get("presences", []) # the guild won't be None here - members = [Member(guild=guild, data=member, state=self) for member in data.get('members', [])] # type: ignore - _log.debug('Processed a chunk for %s members in guild ID %s.', len(members), guild_id) + members = [Member(guild=guild, data=member, state=self) for member in data.get("members", [])] # type: ignore + _log.debug( + "Processed a chunk for %s members in guild ID %s.", len(members), guild_id + ) if presences: member_dict = {str(member.id): member for member in members} for presence in presences: - user = presence['user'] - member_id = user['id'] + user = presence["user"] + member_id = user["id"] member = member_dict.get(member_id) if member is not None: member._presence_update(presence, user) - complete = data.get('chunk_index', 0) + 1 == data.get('chunk_count') - self.process_chunk_requests(guild_id, data.get('nonce'), members, complete) + complete = data.get("chunk_index", 0) + 1 == data.get("chunk_count") + self.process_chunk_requests(guild_id, data.get("nonce"), members, complete) def parse_guild_integrations_update(self, data) -> None: - guild = self._get_guild(int(data['guild_id'])) + guild = self._get_guild(int(data["guild_id"])) if guild is not None: - self.dispatch('guild_integrations_update', guild) + self.dispatch("guild_integrations_update", guild) else: - _log.debug('GUILD_INTEGRATIONS_UPDATE referencing an unknown guild ID: %s. Discarding.', data['guild_id']) + _log.debug( + "GUILD_INTEGRATIONS_UPDATE referencing an unknown guild ID: %s. Discarding.", + data["guild_id"], + ) def parse_integration_create(self, data) -> None: - guild_id = int(data.pop('guild_id')) + guild_id = int(data.pop("guild_id")) guild = self._get_guild(guild_id) if guild is not None: - cls, _ = _integration_factory(data['type']) + cls, _ = _integration_factory(data["type"]) integration = cls(data=data, guild=guild) - self.dispatch('integration_create', integration) + self.dispatch("integration_create", integration) else: - _log.debug('INTEGRATION_CREATE referencing an unknown guild ID: %s. Discarding.', guild_id) + _log.debug( + "INTEGRATION_CREATE referencing an unknown guild ID: %s. Discarding.", + guild_id, + ) def parse_integration_update(self, data) -> None: - guild_id = int(data.pop('guild_id')) + guild_id = int(data.pop("guild_id")) guild = self._get_guild(guild_id) if guild is not None: - cls, _ = _integration_factory(data['type']) + cls, _ = _integration_factory(data["type"]) integration = cls(data=data, guild=guild) - self.dispatch('integration_update', integration) + self.dispatch("integration_update", integration) else: - _log.debug('INTEGRATION_UPDATE referencing an unknown guild ID: %s. Discarding.', guild_id) + _log.debug( + "INTEGRATION_UPDATE referencing an unknown guild ID: %s. Discarding.", + guild_id, + ) def parse_integration_delete(self, data) -> None: - guild_id = int(data['guild_id']) + guild_id = int(data["guild_id"]) guild = self._get_guild(guild_id) if guild is not None: raw = RawIntegrationDeleteEvent(data) - self.dispatch('raw_integration_delete', raw) + self.dispatch("raw_integration_delete", raw) else: - _log.debug('INTEGRATION_DELETE referencing an unknown guild ID: %s. Discarding.', guild_id) + _log.debug( + "INTEGRATION_DELETE referencing an unknown guild ID: %s. Discarding.", + guild_id, + ) def parse_webhooks_update(self, data) -> None: - guild = self._get_guild(int(data['guild_id'])) + guild = self._get_guild(int(data["guild_id"])) if guild is None: - _log.debug('WEBHOOKS_UPDATE referencing an unknown guild ID: %s. Discarding', data['guild_id']) + _log.debug( + "WEBHOOKS_UPDATE referencing an unknown guild ID: %s. Discarding", + data["guild_id"], + ) return - channel = guild.get_channel(int(data['channel_id'])) + channel = guild.get_channel(int(data["channel_id"])) if channel is not None: - self.dispatch('webhooks_update', channel) + self.dispatch("webhooks_update", channel) else: - _log.debug('WEBHOOKS_UPDATE referencing an unknown channel ID: %s. Discarding.', data['channel_id']) + _log.debug( + "WEBHOOKS_UPDATE referencing an unknown channel ID: %s. Discarding.", + data["channel_id"], + ) def parse_stage_instance_create(self, data) -> None: - guild = self._get_guild(int(data['guild_id'])) + guild = self._get_guild(int(data["guild_id"])) if guild is not None: stage_instance = StageInstance(guild=guild, state=self, data=data) guild._stage_instances[stage_instance.id] = stage_instance - self.dispatch('stage_instance_create', stage_instance) + self.dispatch("stage_instance_create", stage_instance) else: - _log.debug('STAGE_INSTANCE_CREATE referencing unknown guild ID: %s. Discarding.', data['guild_id']) + _log.debug( + "STAGE_INSTANCE_CREATE referencing unknown guild ID: %s. Discarding.", + data["guild_id"], + ) def parse_stage_instance_update(self, data) -> None: - guild = self._get_guild(int(data['guild_id'])) + guild = self._get_guild(int(data["guild_id"])) if guild is not None: - stage_instance = guild._stage_instances.get(int(data['id'])) + stage_instance = guild._stage_instances.get(int(data["id"])) if stage_instance is not None: old_stage_instance = copy.copy(stage_instance) stage_instance._update(data) - self.dispatch('stage_instance_update', old_stage_instance, stage_instance) + self.dispatch( + "stage_instance_update", old_stage_instance, stage_instance + ) else: - _log.debug('STAGE_INSTANCE_UPDATE referencing unknown stage instance ID: %s. Discarding.', data['id']) + _log.debug( + "STAGE_INSTANCE_UPDATE referencing unknown stage instance ID: %s. Discarding.", + data["id"], + ) else: - _log.debug('STAGE_INSTANCE_UPDATE referencing unknown guild ID: %s. Discarding.', data['guild_id']) + _log.debug( + "STAGE_INSTANCE_UPDATE referencing unknown guild ID: %s. Discarding.", + data["guild_id"], + ) def parse_stage_instance_delete(self, data) -> None: - guild = self._get_guild(int(data['guild_id'])) + guild = self._get_guild(int(data["guild_id"])) if guild is not None: try: - stage_instance = guild._stage_instances.pop(int(data['id'])) + stage_instance = guild._stage_instances.pop(int(data["id"])) except KeyError: pass else: - self.dispatch('stage_instance_delete', stage_instance) + self.dispatch("stage_instance_delete", stage_instance) else: - _log.debug('STAGE_INSTANCE_DELETE referencing unknown guild ID: %s. Discarding.', data['guild_id']) + _log.debug( + "STAGE_INSTANCE_DELETE referencing unknown guild ID: %s. Discarding.", + data["guild_id"], + ) def parse_voice_state_update(self, data) -> None: - guild = self._get_guild(utils._get_as_snowflake(data, 'guild_id')) - channel_id = utils._get_as_snowflake(data, 'channel_id') + guild = self._get_guild(utils._get_as_snowflake(data, "guild_id")) + channel_id = utils._get_as_snowflake(data, "channel_id") flags = self.member_cache_flags # self.user is *always* cached when this is called self_id = self.user.id # type: ignore if guild is not None: - if int(data['user_id']) == self_id: + if int(data["user_id"]) == self_id: voice = self._get_voice_client(guild.id) if voice is not None: coro = voice.on_voice_state_update(data) - asyncio.create_task(logging_coroutine(coro, info='Voice Protocol voice state update handler')) + asyncio.create_task( + logging_coroutine( + coro, info="Voice Protocol voice state update handler" + ) + ) member, before, after = guild._update_voice_state(data, channel_id) # type: ignore if member is not None: if flags.voice: - if channel_id is None and flags._voice_only and member.id != self_id: + if ( + channel_id is None + and flags._voice_only + and member.id != self_id + ): # Only remove from cache if we only have the voice flag enabled # Member doesn't meet the Snowflake protocol currently guild._remove_member(member) # type: ignore elif channel_id is not None: guild._add_member(member) - self.dispatch('voice_state_update', member, before, after) + self.dispatch("voice_state_update", member, before, after) else: - _log.debug('VOICE_STATE_UPDATE referencing an unknown member ID: %s. Discarding.', data['user_id']) + _log.debug( + "VOICE_STATE_UPDATE referencing an unknown member ID: %s. Discarding.", + data["user_id"], + ) def parse_voice_server_update(self, data) -> None: try: - key_id = int(data['guild_id']) + key_id = int(data["guild_id"]) except KeyError: - key_id = int(data['channel_id']) + key_id = int(data["channel_id"]) vc = self._get_voice_client(key_id) if vc is not None: coro = vc.on_voice_server_update(data) - asyncio.create_task(logging_coroutine(coro, info='Voice Protocol voice server update handler')) + asyncio.create_task( + logging_coroutine( + coro, info="Voice Protocol voice server update handler" + ) + ) def parse_typing_start(self, data) -> None: raw = RawTypingEvent(data) - member_data = data.get('member') + member_data = data.get("member") if member_data: guild = self._get_guild(raw.guild_id) if guild is not None: @@ -1348,16 +1565,18 @@ def parse_typing_start(self, data) -> None: raw.member = None else: raw.member = None - self.dispatch('raw_typing', raw) + self.dispatch("raw_typing", raw) channel, guild = self._get_guild_channel(data) if channel is not None: user = raw.member or self._get_typing_user(channel, raw.user_id) if user is not None: - self.dispatch('typing', channel, user, raw.when) + self.dispatch("typing", channel, user, raw.when) - def _get_typing_user(self, channel: Optional[MessageableChannel], user_id: int) -> Optional[Union[User, Member]]: + def _get_typing_user( + self, channel: Optional[MessageableChannel], user_id: int + ) -> Optional[Union[User, Member]]: if isinstance(channel, DMChannel): return channel.recipient or self.get_user(user_id) @@ -1369,23 +1588,32 @@ def _get_typing_user(self, channel: Optional[MessageableChannel], user_id: int) return self.get_user(user_id) - def _get_reaction_user(self, channel: MessageableChannel, user_id: int) -> Optional[Union[User, Member]]: + def _get_reaction_user( + self, channel: MessageableChannel, user_id: int + ) -> Optional[Union[User, Member]]: if isinstance(channel, TextChannel): return channel.guild.get_member(user_id) return self.get_user(user_id) def get_reaction_emoji(self, data) -> Union[Emoji, PartialEmoji]: - emoji_id = utils._get_as_snowflake(data, 'id') + emoji_id = utils._get_as_snowflake(data, "id") if not emoji_id: - return data['name'] + return data["name"] try: return self._emojis[emoji_id] except KeyError: - return PartialEmoji.with_state(self, animated=data.get('animated', False), id=emoji_id, name=data['name']) + return PartialEmoji.with_state( + self, + animated=data.get("animated", False), + id=emoji_id, + name=data["name"], + ) - def _upgrade_partial_emoji(self, emoji: PartialEmoji) -> Union[Emoji, PartialEmoji, str]: + def _upgrade_partial_emoji( + self, emoji: PartialEmoji + ) -> Union[Emoji, PartialEmoji, str]: emoji_id = emoji.id if not emoji_id: return emoji.name @@ -1408,7 +1636,12 @@ def get_channel(self, id: Optional[int]) -> Optional[Union[Channel, Thread]]: return channel def create_message( - self, *, channel: Union[TextChannel, Thread, DMChannel, GroupChannel, PartialMessageable], data: MessagePayload + self, + *, + channel: Union[ + TextChannel, Thread, DMChannel, GroupChannel, PartialMessageable + ], + data: MessagePayload, ) -> Message: return Message(state=self, channel=channel, data=data) @@ -1428,14 +1661,16 @@ def _update_message_references(self) -> None: new_guild = self._get_guild(msg.guild.id) if new_guild is not None and new_guild is not msg.guild: channel_id = msg.channel.id - channel = new_guild._resolve_channel(channel_id) or Object(id=channel_id) + channel = new_guild._resolve_channel(channel_id) or Object( + id=channel_id + ) # channel will either be a TextChannel, Thread or Object msg._rebind_cached_references(new_guild, channel) # type: ignore async def chunker( self, guild_id: int, - query: str = '', + query: str = "", limit: int = 0, presences: bool = False, *, @@ -1443,7 +1678,9 @@ async def chunker( nonce: Optional[str] = None, ) -> None: ws = self._get_websocket(guild_id, shard_id=shard_id) - await ws.request_chunks(guild_id, query=query, limit=limit, presences=presences, nonce=nonce) + await ws.request_chunks( + guild_id, query=query, limit=limit, presences=presences, nonce=nonce + ) async def _delay_ready(self) -> None: await self.shards_launched.wait() @@ -1454,17 +1691,24 @@ async def _delay_ready(self) -> None: # this snippet of code is basically waiting N seconds # until the last GUILD_CREATE was sent try: - guild = await asyncio.wait_for(self._ready_state.get(), timeout=self.guild_ready_timeout) + guild = await asyncio.wait_for( + self._ready_state.get(), timeout=self.guild_ready_timeout + ) except asyncio.TimeoutError: break else: if self._guild_needs_chunking(guild): - _log.debug('Guild ID %d requires chunking, will be done in the background.', guild.id) + _log.debug( + "Guild ID %d requires chunking, will be done in the background.", + guild.id, + ) if len(current_bucket) >= max_concurrency: try: - await utils.sane_wait_for(current_bucket, timeout=max_concurrency * 70.0) + await utils.sane_wait_for( + current_bucket, timeout=max_concurrency * 70.0 + ) except asyncio.TimeoutError: - fmt = 'Shard ID %s failed to wait for chunks from a sub-bucket with length %d' + fmt = "Shard ID %s failed to wait for chunks from a sub-bucket with length %d" _log.warning(fmt, guild.shard_id, len(current_bucket)) finally: current_bucket = [] @@ -1487,15 +1731,18 @@ async def _delay_ready(self) -> None: await utils.sane_wait_for(futures, timeout=timeout) except asyncio.TimeoutError: _log.warning( - 'Shard ID %s failed to wait for chunks (timeout=%.2f) for %d guilds', shard_id, timeout, len(guilds) + "Shard ID %s failed to wait for chunks (timeout=%.2f) for %d guilds", + shard_id, + timeout, + len(guilds), ) for guild in children: if guild.unavailable is False: - self.dispatch('guild_available', guild) + self.dispatch("guild_available", guild) else: - self.dispatch('guild_join', guild) + self.dispatch("guild_join", guild) - self.dispatch('shard_ready', shard_id) + self.dispatch("shard_ready", shard_id) # remove the state try: @@ -1509,38 +1756,40 @@ async def _delay_ready(self) -> None: self._ready_task = None # dispatch the event - self.call_handlers('ready') - self.dispatch('ready') + self.call_handlers("ready") + self.dispatch("ready") def parse_ready(self, data) -> None: - if not hasattr(self, '_ready_state'): + if not hasattr(self, "_ready_state"): self._ready_state = asyncio.Queue() - self.user = user = ClientUser(state=self, data=data['user']) + self.user = user = ClientUser(state=self, data=data["user"]) # self._users is a list of Users, we're setting a ClientUser self._users[user.id] = user # type: ignore if self.application_id is None: try: - application = data['application'] + application = data["application"] except KeyError: pass else: - self.application_id = utils._get_as_snowflake(application, 'id') - self.application_flags = ApplicationFlags._from_value(application['flags']) + self.application_id = utils._get_as_snowflake(application, "id") + self.application_flags = ApplicationFlags._from_value( + application["flags"] + ) - for guild_data in data['guilds']: + for guild_data in data["guilds"]: self._add_guild_from_data(guild_data) if self._messages: self._update_message_references() - self.dispatch('connect') - self.dispatch('shard_connect', data['__shard_id__']) + self.dispatch("connect") + self.dispatch("shard_connect", data["__shard_id__"]) if self._ready_task is None: self._ready_task = asyncio.create_task(self._delay_ready()) def parse_resumed(self, data) -> None: - self.dispatch('resumed') - self.dispatch('shard_resumed', data['__shard_id__']) \ No newline at end of file + self.dispatch("resumed") + self.dispatch("shard_resumed", data["__shard_id__"]) diff --git a/discord/types/components.py b/discord/types/components.py index 9bf243e19a..d12e96aa02 100644 --- a/discord/types/components.py +++ b/discord/types/components.py @@ -26,10 +26,12 @@ from __future__ import annotations from typing import List, Literal, TypedDict, Union + from .emoji import PartialEmoji -ComponentType = Literal[1, 2, 3] +ComponentType = Literal[1, 2, 3, 4] ButtonStyle = Literal[1, 2, 3, 4, 5] +InputTextStyle = Literal[1, 2] class ActionRow(TypedDict): @@ -50,6 +52,21 @@ class ButtonComponent(_ButtonComponentOptional): style: ButtonStyle +class _InputTextComponentOptional(TypedDict, total=False): + min_length: int + max_length: int + required: bool + placeholder: str + value: str + + +class InputText(_InputTextComponentOptional): + type: Literal[4] + style: InputTextStyle + custom_id: str + label: str + + class _SelectMenuOptional(TypedDict, total=False): placeholder: str min_values: int @@ -74,4 +91,4 @@ class SelectMenu(_SelectMenuOptional): options: List[SelectOption] -Component = Union[ActionRow, ButtonComponent, SelectMenu] +Component = Union[ActionRow, ButtonComponent, SelectMenu, InputText] diff --git a/discord/ui/__init__.py b/discord/ui/__init__.py index 91e7ecf739..74a8ac5b95 100644 --- a/discord/ui/__init__.py +++ b/discord/ui/__init__.py @@ -9,7 +9,9 @@ """ -from .view import * -from .item import * from .button import * +from .input_text import * +from .item import * +from .modal import * from .select import * +from .view import * diff --git a/discord/ui/input_text.py b/discord/ui/input_text.py new file mode 100644 index 0000000000..df85bfcb3c --- /dev/null +++ b/discord/ui/input_text.py @@ -0,0 +1,170 @@ +from __future__ import annotations + +import os +from typing import TYPE_CHECKING, Optional + +from ..components import InputText as InputTextComponent +from ..enums import InputTextStyle +from ..utils import MISSING +from .item import Item + +__all__ = ("InputText",) + +if TYPE_CHECKING: + from ..types.components import InputText as InputTextComponentPayload + + +class InputText(Item): + """Represents a UI text input field. + + Parameters + ---------- + style: :class:`discord.InputTextStyle` + The style of the input text field. + custom_id: Optional[:class:`str`] + The ID of the input text field that gets received during an interaction. + label: Optional[:class:`str`] + The label for the input text field, if any. + placeholder: Optional[:class:`str`] + The placeholder text that is shown if nothing is selected, if any. + min_length: Optional[:class:`int`] + The minimum number of characters that must be entered + Defaults to 0 + max_length: Optional[:class:`int`] + The maximum number of characters that can be entered + required: Optional[:class:`bool`] + Whether the input text field is required or not. Defaults to `True`. + value: Optional[:class:`str`] + Pre-fills the input text field with this value + row: Optional[:class:`int`] + The relative row this button belongs to. A Discord component can only have 5 + rows. By default, items are arranged automatically into those 5 rows. If you'd + like to control the relative positioning of the row then passing an index is advised. + For example, row=1 will show up before row=2. Defaults to ``None``, which is automatic + ordering. The row number must be between 0 and 4 (i.e. zero indexed). + """ + + def __init__( + self, + style: InputTextStyle = InputTextStyle.short, + custom_id: str = MISSING, + label: Optional[str] = None, + placeholder: Optional[str] = None, + min_length: Optional[int] = None, + max_length: Optional[int] = None, + required: Optional[bool] = True, + value: Optional[str] = None, + row: Optional[int] = None, + ): + super().__init__() + custom_id = os.urandom(16).hex() if custom_id is MISSING else custom_id + self._underlying = InputTextComponent._raw_construct( + style=style, + custom_id=custom_id, + label=label, + placeholder=placeholder, + min_length=min_length, + max_length=max_length, + required=required, + value=value, + ) + self._input_value = None + self.row = row + + @property + def style(self) -> InputTextStyle: + """:class:`discord.InputTextStyle`: The style of the input text field.""" + return self._underlying.style + + @style.setter + def style(self, value: InputTextStyle): + if not isinstance(value, InputTextStyle): + raise TypeError( + f"style must be of type InputTextStyle not {value.__class__}" + ) + self._underlying.style = value + + @property + def custom_id(self) -> str: + """:class:`str`: The ID of the input text field that gets received during an interaction.""" + return self._underlying.custom_id + + @custom_id.setter + def custom_id(self, value: str): + if not isinstance(value, str): + raise TypeError(f"custom_id must be None or str not {value.__class__}") + self._underlying.custom_id = value + + @property + def label(self) -> str: + """:class:`str`: The label of the input text field.""" + return self._underlying.label + + @label.setter + def label(self, value: str): + if not isinstance(value, str): + raise TypeError(f"label should be None or str not {value.__class__}") + + @property + def placeholder(self) -> Optional[str]: + """Optional[:class:`str`]: The placeholder text that is shown before anything is entered, if any.""" + return self._underlying.placeholder + + @placeholder.setter + def placeholder(self, value: Optional[str]): + if value and not isinstance(value, str): + raise TypeError(f"placeholder must be None or str not {value.__class__}") # type: ignore + self._underlying.placeholder = value + + @property + def min_length(self) -> Optional[int]: + """Optional[:class:`int`]: The minimum number of characters that must be entered. Defaults to `0`.""" + return self._underlying.min_length + + @min_length.setter + def min_length(self, value: Optional[int]): + if value and not isinstance(value, int): + raise TypeError(f"min_length must be None or int not {value.__class__}") # type: ignore + self._underlying.min_length = value + + @property + def max_length(self) -> Optional[int]: + """Optional[:class:`int`]: The maximum number of characters that can be entered.""" + return self._underlying.max_length + + @max_length.setter + def max_length(self, value: Optional[int]): + if value and not isinstance(value, int): + raise TypeError(f"min_length must be None or int not {value.__class__}") # type: ignore + self._underlying.max_length = value + + @property + def required(self) -> Optional[bool]: + """Optional[:class:`bool`]: Whether the input text field is required or not. Defaults to `True`.""" + return self._underlying.required + + @required.setter + def required(self, value: Optional[bool]): + if not isinstance(value, bool): + raise TypeError(f"required must be bool not {value.__class__}") # type: ignore + + @property + def value(self) -> Optional[str]: + """Optional[:class:`str`]: The value entered in the text field.""" + return self._input_value or self._underlying.value + + @value.setter + def value(self, value: Optional[str]): + if value and not isinstance(value, str): + raise TypeError(f"value must be None or str not {value.__class__}") # type: ignore + self._underlying.value = value + + @property + def width(self) -> int: + return 5 + + def to_component_dict(self) -> InputTextComponentPayload: + return self._underlying.to_dict() + + def refresh_state(self, data) -> None: + self._input_value = data["value"] diff --git a/discord/ui/modal.py b/discord/ui/modal.py new file mode 100644 index 0000000000..6f2c93e4ed --- /dev/null +++ b/discord/ui/modal.py @@ -0,0 +1,125 @@ +from __future__ import annotations +import os +from itertools import groupby +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple + +from .item import Item +from .view import _ViewWeights + +__all__ = ("Modal",) + + +if TYPE_CHECKING: + from ..interactions import Interaction + from ..state import ConnectionState + + +class Modal: + """Represents a UI Modal dialog. + + This object must be inherited to create a UI within Discord. + """ + + def __init__(self, title: str, custom_id: Optional[str] = None) -> None: + self.custom_id = custom_id or os.urandom(16).hex() + self.title = title + self.children: List[Item] = [] + self.__weights = _ViewWeights(self.children) + + async def callback(self, interaction: Interaction): + """|coro| + + The coroutine that is called when the modal dialog is submitted. + Should be overridden to handle the values submitted by the user. + + Parameters + ----------- + interaction: :class:`~discord.Interaction` + The interaction that submitted the modal dialog. + """ + pass + + def to_components(self) -> List[Dict[str, Any]]: + def key(item: Item) -> int: + return item._rendered_row or 0 + + children = sorted(self.children, key=key) + components: List[Dict[str, Any]] = [] + for _, group in groupby(children, key=key): + children = [item.to_component_dict() for item in group] + if not children: + continue + + components.append( + { + "type": 1, + "components": children, + } + ) + + return components + + def add_item(self, item: Item): + """Adds an item to the modal dialog. + + Parameters + ---------- + + item: :class:`Item` + The item to add to the modal dialog + """ + + if len(self.children) > 5: + raise ValueError("You can only have up to 5 items in a modal dialog.") + + if not isinstance(item, Item): + raise TypeError(f"expected Item not {item.__class__!r}") + + self.__weights.add_item(item) + self.children.append(item) + + def remove_item(self, item: Item): + """Removes an item from the modal dialog. + + Parameters + ---------- + item: :class:`Item` + The item to remove from the modal dialog. + """ + try: + self.children.remove(item) + except ValueError: + pass + else: + self.__weights.remove_item(item) + + def to_dict(self): + return {"title": self.title, "custom_id": self.custom_id, "components": self.to_components()} + + +class ModalStore: + def __init__(self, state: ConnectionState) -> None: + # (user_id, custom_id) : Modal + self._modals: Dict[Tuple[int, str], Modal] = {} + self._state: ConnectionState = state + + def add_modal(self, modal: Modal, user_id: int): + self._modals[(user_id, modal.custom_id)] = modal + + def remove_modal(self, modal: Modal, user_id): + self._modals.pop((user_id, modal.custom_id)) + + async def dispatch(self, user_id: int, custom_id: str, interaction: Interaction): + key = (user_id, custom_id) + value = self._modals.get(key) + if value is None: + return + + components = [component for parent_component in interaction.data["components"] for component in parent_component["components"]] + for component in components: + for child in value.children: + if child.custom_id == component["custom_id"]: # type: ignore + child.refresh_state(component) + break + await value.callback(interaction) + self.remove_modal(value, user_id) diff --git a/examples/modal_dialogs.py b/examples/modal_dialogs.py new file mode 100644 index 0000000000..82a95fced6 --- /dev/null +++ b/examples/modal_dialogs.py @@ -0,0 +1,78 @@ +import discord +from discord.ext import commands +from discord.ui import InputText, Modal + + +class Bot(commands.Bot): + def __init__(self): + super().__init__(command_prefix=">") + + +bot = Bot() + + +class MyModal(Modal): + def __init__(self) -> None: + super().__init__("Test Modal Dialog") + self.add_item(InputText(label="Short Input", placeholder="Placeholder Test")) + + self.add_item( + InputText( + label="Longer Input", + value="Longer Value\nSuper Long Value", + style=discord.InputTextStyle.long, + ) + ) + + async def callback(self, interaction: discord.Interaction): + embed = discord.Embed(title="Your Modal Results", color=discord.Color.random()) + embed.add_field(name="First Input", value=self.children[0].value, inline=False) + embed.add_field(name="Second Input", value=self.children[1].value, inline=False) + await interaction.response.send_message(embeds=[embed]) + + +@bot.slash_command(name="modaltest", guild_ids=[907533384081879040]) +async def modal_slash(ctx): + """Shows an example of a modal dialog being invoked from a slash command.""" + modal = MyModal() + await ctx.interaction.response.send_modal(modal) + + +@bot.message_command(name="messagemodal", guild_ids=[907533384081879040]) +async def modal_message(ctx, message): + """Shows an example of a modal dialog being invoked from a message command.""" + modal = MyModal() + modal.title = f"Modal for Message ID: {message.id}" + await ctx.interaction.response.send_modal(modal) + + +@bot.user_command(name="usermodal", guild_ids=[907533384081879040]) +async def modal_user(ctx, member): + """Shows an example of a modal dialog being invoked from a user command.""" + modal = MyModal() + modal.title = f"Modal for User: {member.display_name}" + await ctx.interaction.response.send_modal(modal) + + +@bot.command() +async def modaltest(ctx): + """Shows an example of modals being invoked from an interaction component (e.g. a button or select menu)""" + class MyView(discord.ui.View): + @discord.ui.button(label="Modal Test", style=discord.ButtonStyle.primary) + async def button_callback(self, button, interaction): + modal = MyModal() + await interaction.response.send_modal(modal) + + @discord.ui.select(placeholder='Pick Your Modal', min_values=1, max_values=1, options=[ + discord.SelectOption(label='First Modal', description='Shows the first modal'), + discord.SelectOption(label='Second Modal', description='Shows the second modal'), + ]) + async def select_callback(self, select, interaction): + modal = MyModal() + modal.title = select.values[0] + await interaction.response.send_modal(modal) + + view = MyView() + await ctx.send("Click Button, Receive Modal", view=view) + +bot.run("your token")