diff --git a/discord/bot.py b/discord/bot.py index 6e9f96f9c2..1817bda0eb 100644 --- a/discord/bot.py +++ b/discord/bot.py @@ -730,8 +730,8 @@ async def process_application_commands(self, interaction: Interaction, auto_sync if auto_sync is None: auto_sync = self.auto_sync_commands if interaction.type not in ( - InteractionType.application_command, - InteractionType.auto_complete + InteractionType.application_command, + InteractionType.auto_complete, ): return diff --git a/discord/components.py b/discord/components.py index b6719e7c1b..dfed21bad6 100644 --- a/discord/components.py +++ b/discord/components.py @@ -26,13 +26,14 @@ from __future__ import annotations from typing import Any, ClassVar, Dict, List, Optional, TYPE_CHECKING, Tuple, Type, TypeVar, Union -from .enums import try_enum, ComponentType, ButtonStyle +from .enums import try_enum, ComponentType, ButtonStyle, InputTextStyle from .utils import get_slots, MISSING from .partial_emoji import PartialEmoji, _EmojiTag if TYPE_CHECKING: from .types.components import ( Component as ComponentPayload, + InputText as InputTextComponentPayload, ButtonComponent as ButtonComponentPayload, SelectMenu as SelectMenuPayload, SelectOption as SelectOptionPayload, @@ -128,6 +129,82 @@ def to_dict(self) -> ActionRowPayload: } # type: ignore +class InputText(Component): + """Represents an Input Text field from the Discord Bot UI Kit. + This inherits from :class:`Component`. + Attributes + ---------- + style: :class:`.InputTextStyle` + The style of the input text field. + custom_id: Optional[:class:`str`] + The ID of the input text field that gets received during an interaction. + label: Optional[:class:`str`] + The label for the input text field, if any. + placeholder: Optional[:class:`str`] + The placeholder text that is shown if nothing is selected, if any. + min_length: Optional[:class:`int`] + The minimum number of characters that must be entered + Defaults to 0 + max_length: Optional[:class:`int`] + The maximum number of characters that can be entered + required: Optional[:class:`bool`] + Whether the input text field is required or not. Defaults to `True`. + value: Optional[:class:`str`] + The value that has been entered in the input text field. + """ + + __slots__: Tuple[str, ...] = ( + "type", + "style", + "custom_id", + "label", + "placeholder", + "min_length", + "max_length", + "required", + "value", + ) + + __repr_info__: ClassVar[Tuple[str, ...]] = __slots__ + + def __init__(self, data: InputTextComponentPayload): + self.type = ComponentType.input_text + self.style: InputTextStyle = try_enum(InputTextStyle, data["style"]) + self.custom_id = data["custom_id"] + self.label: Optional[str] = data.get("label", None) + self.placeholder: Optional[str] = data.get("placeholder", None) + self.min_length: Optional[int] = data.get("min_length", None) + self.max_length: Optional[int] = data.get("max_length", None) + self.required: bool = data.get("required", True) + self.value: Optional[str] = data.get("value", None) + + def to_dict(self) -> InputTextComponentPayload: + payload = { + "type": 4, + "style": self.style.value, + "label": self.label, + } + if self.custom_id: + payload["custom_id"] = self.custom_id + + if self.placeholder: + payload["placeholder"] = self.placeholder + + if self.min_length: + payload["min_length"] = self.min_length + + if self.max_length: + payload["max_length"] = self.max_length + + if not self.required: + payload["required"] = self.required + + if self.value: + payload["value"] = self.value + + return payload # type: ignore + + class Button(Component): """Represents a button from the Discord Bot UI Kit. diff --git a/discord/enums.py b/discord/enums.py index e14a68613f..1c90a8dcfb 100644 --- a/discord/enums.py +++ b/discord/enums.py @@ -60,6 +60,7 @@ 'ScheduledEventStatus', 'ScheduledEventPrivacyLevel', 'ScheduledEventLocationType', + 'InputTextStyle', ) @@ -550,6 +551,7 @@ class InteractionType(Enum): application_command = 2 component = 3 auto_complete = 4 + modal_submit = 5 class InteractionResponseType(Enum): @@ -561,7 +563,7 @@ class InteractionResponseType(Enum): deferred_message_update = 6 # for components message_update = 7 # for components auto_complete_result = 8 # for autocomplete interactions - + modal = 9 # for modal dialogs class VideoQualityMode(Enum): auto = 1 @@ -575,6 +577,7 @@ class ComponentType(Enum): action_row = 1 button = 2 select = 3 + input_text = 4 def __int__(self): return self.value @@ -599,6 +602,14 @@ def __int__(self): return self.value +class InputTextStyle(Enum): + short = 1 + singleline = 1 + paragraph = 2 + multiline = 2 + long = 2 + + class ApplicationType(Enum): game = 1 music = 2 diff --git a/discord/interactions.py b/discord/interactions.py index cb2273228f..706589f7bf 100644 --- a/discord/interactions.py +++ b/discord/interactions.py @@ -59,6 +59,7 @@ from aiohttp import ClientSession from .embeds import Embed from .ui.view import View + from .ui.modal import Modal from .channel import VoiceChannel, StageChannel, TextChannel, CategoryChannel, StoreChannel, PartialMessageable from .threads import Thread from .commands import OptionChoice @@ -775,7 +776,24 @@ async def send_autocomplete_result( ) self._responded = True - + + async def send_modal(self, modal: Modal): + if self._responded: + raise InteractionResponded(self._parent) + + payload = modal.to_dict() + adapter = async_context.get() + await adapter.create_interaction_response( + self._parent.id, + self._parent.token, + session=self._parent._session, + type=InteractionResponseType.modal.value, + data=payload, + ) + self._responded = True + self._parent._state.store_modal(modal, self._parent.user.id) + + class _InteractionMessageState: __slots__ = ('_parent', '_interaction') diff --git a/discord/state.py b/discord/state.py index 0a87033eb1..6822dbf529 100644 --- a/discord/state.py +++ b/discord/state.py @@ -47,7 +47,7 @@ from .raw_models import * from .member import Member from .role import Role -from .enums import ChannelType, try_enum, Status, ScheduledEventStatus +from .enums import ChannelType, try_enum, Status, ScheduledEventStatus, InteractionType from . import utils from .flags import ApplicationFlags, Intents, MemberCacheFlags from .object import Object @@ -55,6 +55,7 @@ from .integrations import _integration_factory from .interactions import Interaction from .ui.view import ViewStore, View +from .ui.modal import Modal, ModalStore from .stage_instance import StageInstance from .threads import Thread, ThreadMember from .sticker import GuildSticker @@ -256,7 +257,7 @@ def clear(self, *, views: bool = True) -> None: self._guilds: Dict[int, Guild] = {} if views: self._view_store: ViewStore = ViewStore(self) - + self._modal_store: ModalStore = ModalStore(self) self._voice_clients: Dict[int, VoiceProtocol] = {} # LRU of max size 128 @@ -363,6 +364,9 @@ def store_sticker(self, guild: Guild, data: GuildStickerPayload) -> GuildSticker def store_view(self, view: View, message_id: Optional[int] = None) -> None: self._view_store.add_view(view, message_id) + def store_modal(self, modal: Modal, message_id: int) -> None: + self._modal_store.add_modal(modal, message_id) + def prevent_view_updates_for(self, message_id: int) -> Optional[View]: return self._view_store.remove_message_tracking(message_id) @@ -705,6 +709,12 @@ def parse_interaction_create(self, data) -> None: custom_id = interaction.data['custom_id'] # type: ignore component_type = interaction.data['component_type'] # type: ignore self._view_store.dispatch(component_type, custom_id, interaction) + if interaction.type == InteractionType.modal_submit: + user_id, custom_id = ( + interaction.user.id, + interaction.data["custom_id"], + ) + asyncio.create_task(self._modal_store.dispatch(user_id, custom_id, interaction)) self.dispatch('interaction', interaction) diff --git a/discord/types/components.py b/discord/types/components.py index 9bf243e19a..836d2376b0 100644 --- a/discord/types/components.py +++ b/discord/types/components.py @@ -28,8 +28,9 @@ from typing import List, Literal, TypedDict, Union from .emoji import PartialEmoji -ComponentType = Literal[1, 2, 3] +ComponentType = Literal[1, 2, 3, 4] ButtonStyle = Literal[1, 2, 3, 4, 5] +InputTextStyle = Literal[1, 2] class ActionRow(TypedDict): @@ -50,6 +51,21 @@ class ButtonComponent(_ButtonComponentOptional): style: ButtonStyle +class _InputTextComponentOptional(TypedDict, total=False): + min_length: int + max_length: int + required: bool + placeholder: str + value: str + + +class InputText(_InputTextComponentOptional): + type: Literal[4] + style: InputTextStyle + custom_id: str + label: str + + class _SelectMenuOptional(TypedDict, total=False): placeholder: str min_values: int @@ -74,4 +90,4 @@ class SelectMenu(_SelectMenuOptional): options: List[SelectOption] -Component = Union[ActionRow, ButtonComponent, SelectMenu] +Component = Union[ActionRow, ButtonComponent, SelectMenu, InputText] diff --git a/discord/ui/__init__.py b/discord/ui/__init__.py index 2634fadf8e..30a2f9dec3 100644 --- a/discord/ui/__init__.py +++ b/discord/ui/__init__.py @@ -13,3 +13,5 @@ from .item import * from .button import * from .select import * +from .input_text import * +from .modal import * diff --git a/discord/ui/input_text.py b/discord/ui/input_text.py new file mode 100644 index 0000000000..ba6e801dab --- /dev/null +++ b/discord/ui/input_text.py @@ -0,0 +1,170 @@ +from __future__ import annotations + +import os +from typing import TYPE_CHECKING, Optional + +from ..components import InputText as InputTextComponent +from ..enums import InputTextStyle +from ..utils import MISSING +from .item import Item + +__all__ = ("InputText",) + +if TYPE_CHECKING: + from ..types.components import InputText as InputTextComponentPayload + + +class InputText(Item): + """Represents a UI text input field. + + Parameters + ---------- + style: :class:`discord.InputTextStyle` + The style of the input text field. + custom_id: Optional[:class:`str`] + The ID of the input text field that gets received during an interaction. + label: Optional[:class:`str`] + The label for the input text field, if any. + placeholder: Optional[:class:`str`] + The placeholder text that is shown if nothing is selected, if any. + min_length: Optional[:class:`int`] + The minimum number of characters that must be entered. + Defaults to 0. + max_length: Optional[:class:`int`] + The maximum number of characters that can be entered. + required: Optional[:class:`bool`] + Whether the input text field is required or not. Defaults to `True`. + value: Optional[:class:`str`] + Pre-fills the input text field with this value. + row: Optional[:class:`int`] + The relative row this button belongs to. A Discord component can only have 5 + rows. By default, items are arranged automatically into those 5 rows. If you'd + like to control the relative positioning of the row then passing an index is advised. + For example, row=1 will show up before row=2. Defaults to ``None``, which is automatic + ordering. The row number must be between 0 and 4 (i.e. zero indexed). + """ + + def __init__( + self, + style: InputTextStyle = InputTextStyle.short, + custom_id: str = MISSING, + label: Optional[str] = None, + placeholder: Optional[str] = None, + min_length: Optional[int] = None, + max_length: Optional[int] = None, + required: Optional[bool] = True, + value: Optional[str] = None, + row: Optional[int] = None, + ): + super().__init__() + custom_id = os.urandom(16).hex() if custom_id is MISSING else custom_id + self._underlying = InputTextComponent._raw_construct( + style=style, + custom_id=custom_id, + label=label, + placeholder=placeholder, + min_length=min_length, + max_length=max_length, + required=required, + value=value, + ) + self._input_value = None + self.row = row + + @property + def style(self) -> InputTextStyle: + """:class:`discord.InputTextStyle`: The style of the input text field.""" + return self._underlying.style + + @style.setter + def style(self, value: InputTextStyle): + if not isinstance(value, InputTextStyle): + raise TypeError( + f"style must be of type InputTextStyle not {value.__class__}" + ) + self._underlying.style = value + + @property + def custom_id(self) -> str: + """:class:`str`: The ID of the input text field that gets received during an interaction.""" + return self._underlying.custom_id + + @custom_id.setter + def custom_id(self, value: str): + if not isinstance(value, str): + raise TypeError(f"custom_id must be None or str not {value.__class__}") + self._underlying.custom_id = value + + @property + def label(self) -> str: + """:class:`str`: The label of the input text field.""" + return self._underlying.label + + @label.setter + def label(self, value: str): + if not isinstance(value, str): + raise TypeError(f"label should be None or str not {value.__class__}") + + @property + def placeholder(self) -> Optional[str]: + """Optional[:class:`str`]: The placeholder text that is shown before anything is entered, if any.""" + return self._underlying.placeholder + + @placeholder.setter + def placeholder(self, value: Optional[str]): + if value and not isinstance(value, str): + raise TypeError(f"placeholder must be None or str not {value.__class__}") # type: ignore + self._underlying.placeholder = value + + @property + def min_length(self) -> Optional[int]: + """Optional[:class:`int`]: The minimum number of characters that must be entered. Defaults to `0`.""" + return self._underlying.min_length + + @min_length.setter + def min_length(self, value: Optional[int]): + if value and not isinstance(value, int): + raise TypeError(f"min_length must be None or int not {value.__class__}") # type: ignore + self._underlying.min_length = value + + @property + def max_length(self) -> Optional[int]: + """Optional[:class:`int`]: The maximum number of characters that can be entered.""" + return self._underlying.max_length + + @max_length.setter + def max_length(self, value: Optional[int]): + if value and not isinstance(value, int): + raise TypeError(f"min_length must be None or int not {value.__class__}") # type: ignore + self._underlying.max_length = value + + @property + def required(self) -> Optional[bool]: + """Optional[:class:`bool`]: Whether the input text field is required or not. Defaults to `True`.""" + return self._underlying.required + + @required.setter + def required(self, value: Optional[bool]): + if not isinstance(value, bool): + raise TypeError(f"required must be bool not {value.__class__}") # type: ignore + + @property + def value(self) -> Optional[str]: + """Optional[:class:`str`]: The value entered in the text field.""" + return self._input_value or self._underlying.value + + @value.setter + def value(self, value: Optional[str]): + if value and not isinstance(value, str): + raise TypeError(f"value must be None or str not {value.__class__}") # type: ignore + self._underlying.value = value + + @property + def width(self) -> int: + return 5 + + def to_component_dict(self) -> InputTextComponentPayload: + return self._underlying.to_dict() + + def refresh_state(self, data) -> None: + self._input_value = data["value"] diff --git a/discord/ui/modal.py b/discord/ui/modal.py new file mode 100644 index 0000000000..f1eceb00dd --- /dev/null +++ b/discord/ui/modal.py @@ -0,0 +1,132 @@ +from __future__ import annotations + +import os +from itertools import groupby +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple + +from .item import Item +from .view import _ViewWeights + +__all__ = ("Modal",) + + +if TYPE_CHECKING: + from ..interactions import Interaction + from ..state import ConnectionState + + +class Modal: + """Represents a UI Modal dialog. + + This object must be inherited to create a UI within Discord. + """ + + def __init__(self, title: str, custom_id: Optional[str] = None) -> None: + self.custom_id = custom_id or os.urandom(16).hex() + self.title = title + self.children: List[Item] = [] + self.__weights = _ViewWeights(self.children) + + async def callback(self, interaction: Interaction): + """|coro| + The coroutine that is called when the modal dialog is submitted. + Should be overridden to handle the values submitted by the user. + + Parameters + ----------- + interaction: :class:`~discord.Interaction` + The interaction that submitted the modal dialog. + """ + pass + + def to_components(self) -> List[Dict[str, Any]]: + def key(item: Item) -> int: + return item._rendered_row or 0 + + children = sorted(self.children, key=key) + components: List[Dict[str, Any]] = [] + for _, group in groupby(children, key=key): + children = [item.to_component_dict() for item in group] + if not children: + continue + + components.append( + { + "type": 1, + "components": children, + } + ) + + return components + + def add_item(self, item: Item): + """Adds an item to the modal dialog. + + Parameters + ---------- + item: :class:`Item` + The item to add to the modal dialog + """ + + if len(self.children) > 5: + raise ValueError("You can only have up to 5 items in a modal dialog.") + + if not isinstance(item, Item): + raise TypeError(f"expected Item not {item.__class__!r}") + + self.__weights.add_item(item) + self.children.append(item) + + def remove_item(self, item: Item): + """Removes an item from the modal dialog. + + Parameters + ---------- + item: :class:`Item` + The item to remove from the modal dialog. + """ + try: + self.children.remove(item) + except ValueError: + pass + else: + self.__weights.remove_item(item) + + def to_dict(self): + return { + "title": self.title, + "custom_id": self.custom_id, + "components": self.to_components(), + } + + +class ModalStore: + def __init__(self, state: ConnectionState) -> None: + # (user_id, custom_id) : Modal + self._modals: Dict[Tuple[int, str], Modal] = {} + self._state: ConnectionState = state + + def add_modal(self, modal: Modal, user_id: int): + self._modals[(user_id, modal.custom_id)] = modal + + def remove_modal(self, modal: Modal, user_id): + self._modals.pop((user_id, modal.custom_id)) + + async def dispatch(self, user_id: int, custom_id: str, interaction: Interaction): + key = (user_id, custom_id) + value = self._modals.get(key) + if value is None: + return + + components = [ + component + for parent_component in interaction.data["components"] + for component in parent_component["components"] + ] + for component in components: + for child in value.children: + if child.custom_id == component["custom_id"]: # type: ignore + child.refresh_state(component) + break + await value.callback(interaction) + self.remove_modal(value, user_id) diff --git a/examples/modal_dialogs.py b/examples/modal_dialogs.py new file mode 100644 index 0000000000..3b2961db94 --- /dev/null +++ b/examples/modal_dialogs.py @@ -0,0 +1,89 @@ +import discord +from discord.ext import commands +from discord.ui import InputText, Modal + + +class Bot(commands.Bot): + def __init__(self): + super().__init__(command_prefix=">") + + +bot = Bot() + + +class MyModal(Modal): + def __init__(self) -> None: + super().__init__("Test Modal Dialog") + self.add_item(InputText(label="Short Input", placeholder="Placeholder Test")) + + self.add_item( + InputText( + label="Longer Input", + value="Longer Value\nSuper Long Value", + style=discord.InputTextStyle.long, + ) + ) + + async def callback(self, interaction: discord.Interaction): + embed = discord.Embed(title="Your Modal Results", color=discord.Color.random()) + embed.add_field(name="First Input", value=self.children[0].value, inline=False) + embed.add_field(name="Second Input", value=self.children[1].value, inline=False) + await interaction.response.send_message(embeds=[embed]) + + +@bot.slash_command(name="modaltest", guild_ids=[...]) +async def modal_slash(ctx): + """Shows an example of a modal dialog being invoked from a slash command.""" + modal = MyModal() + await ctx.interaction.response.send_modal(modal) + + +@bot.message_command(name="messagemodal", guild_ids=[...]) +async def modal_message(ctx, message): + """Shows an example of a modal dialog being invoked from a message command.""" + modal = MyModal() + modal.title = f"Modal for Message ID: {message.id}" + await ctx.interaction.response.send_modal(modal) + + +@bot.user_command(name="usermodal", guild_ids=[...]) +async def modal_user(ctx, member): + """Shows an example of a modal dialog being invoked from a user command.""" + modal = MyModal() + modal.title = f"Modal for User: {member.display_name}" + await ctx.interaction.response.send_modal(modal) + + +@bot.command() +async def modaltest(ctx): + """Shows an example of modals being invoked from an interaction component (e.g. a button or select menu)""" + + class MyView(discord.ui.View): + @discord.ui.button(label="Modal Test", style=discord.ButtonStyle.primary) + async def button_callback(self, button, interaction): + modal = MyModal() + await interaction.response.send_modal(modal) + + @discord.ui.select( + placeholder="Pick Your Modal", + min_values=1, + max_values=1, + options=[ + discord.SelectOption( + label="First Modal", description="Shows the first modal" + ), + discord.SelectOption( + label="Second Modal", description="Shows the second modal" + ), + ], + ) + async def select_callback(self, select, interaction): + modal = MyModal() + modal.title = select.values[0] + await interaction.response.send_modal(modal) + + view = MyView() + await ctx.send("Click Button, Receive Modal", view=view) + + +bot.run("your token")