diff --git a/examples/confirm.py b/examples/confirm.py index e4718a7..117f0c1 100644 --- a/examples/confirm.py +++ b/examples/confirm.py @@ -1,8 +1,6 @@ # This example requires the 'message_content' privileged intent to function. -import asyncio - import discord from discord.ext import commands from ductile import State, View, ViewObject @@ -23,6 +21,7 @@ class ConfirmView(View): def __init__(self) -> None: super().__init__() self.approved = State[bool | None](None, self) + self.description = State("Are you sure?", self) def render(self) -> ViewObject: async def handle_approve(interaction: discord.Interaction) -> None: @@ -36,7 +35,7 @@ async def handle_deny(interaction: discord.Interaction) -> None: self.stop() return ViewObject( - embeds=[discord.Embed(title="Confirm", description="Are you sure?")], + embeds=[discord.Embed(title="Confirm", description=self.description())], components=[ Button("yes", style={"color": "green", "disabled": self.approved() is not None}, on_click=handle_approve), Button("no", style={"color": "red", "disabled": self.approved() is not None}, on_click=handle_deny), @@ -44,7 +43,7 @@ async def handle_deny(interaction: discord.Interaction) -> None: ) async def on_timeout(self) -> None: - print("Timed out") # noqa: T201 + self.description.set_state("Timed out.") self.approved.set_state(False) # noqa: FBT003 self.stop() @@ -74,8 +73,6 @@ async def wait_for_confirm(interaction: discord.Interaction) -> None: controller = InteractionController(view, interaction=interaction, timeout=60, ephemeral=True) await controller.send() await controller.wait() - # This sleep is workaround. see details at https://github.com/sushi-chaaaan/ductile-ui/issues/23 - await asyncio.sleep(3.0) self.approved.set_state(view.approved()) return ViewObject( @@ -105,4 +102,4 @@ async def send_counter(ctx: commands.Context) -> None: await controller.wait() -bot.run("MY_COOL_BOT_TOKEN") +bot.run("MTE2MTUzOTQyMzg5NzM4NzAyOQ.GoAzw6.Okx6ZDcHIpPZhExWb54aiHZ-BYHZtJMf0hqSBY") diff --git a/src/ductile/controller/controller.py b/src/ductile/controller/controller.py index ed6578b..434b81c 100644 --- a/src/ductile/controller/controller.py +++ b/src/ductile/controller/controller.py @@ -2,6 +2,7 @@ from ..internal import _InternalView # noqa: TID252 from ..state import State # noqa: TID252 +from ..utils import wait_tasks_by_name # noqa: TID252 if TYPE_CHECKING: from collections.abc import Generator @@ -89,12 +90,16 @@ async def wait(self) -> ViewResult: `states` is a dictionary of all states in the view. """ - timed_out = await self.__raw_view.wait() + is_timed_out = await self.__raw_view.wait() + + # this is got from discord.ui.View._dispatch_timeout() + timeout_task_name = f"discord-ui-view-timeout-{self.__raw_view.id}" + await wait_tasks_by_name([timeout_task_name]) d = {} for key, state in self._get_all_state_in_view(): d[key] = state.get_state() - return ViewResult(timed_out, d) + return ViewResult(is_timed_out, d) def _get_all_state_in_view(self) -> "Generator[tuple[str, State[Any]], None, None]": for k, v in self.__view.__dict__.items(): diff --git a/src/ductile/utils/__init__.py b/src/ductile/utils/__init__.py index 3b1ec31..55de695 100644 --- a/src/ductile/utils/__init__.py +++ b/src/ductile/utils/__init__.py @@ -1,7 +1,10 @@ +from .async_helper import get_all_tasks, wait_tasks_by_name from .call import call_any_function from .logger import get_logger __all__ = [ + "get_all_tasks", + "wait_tasks_by_name", "call_any_function", "get_logger", ] diff --git a/src/ductile/utils/async_helper.py b/src/ductile/utils/async_helper.py new file mode 100644 index 0000000..ba6a993 --- /dev/null +++ b/src/ductile/utils/async_helper.py @@ -0,0 +1,58 @@ +import asyncio +from typing import Any + + +def get_all_tasks(*, loop: asyncio.AbstractEventLoop | None = None, only_undone: bool = False) -> set[asyncio.Task]: + """ + Return a set of all tasks running on the given event loop. + + Args: + loop `(asyncio.AbstractEventLoop | None)`: The event loop to get tasks from. + If None, the current running loop is used. + only_undone `(bool)`: If True, only tasks that are not done are returned. + + Returns + ------- + `set[asyncio.Task]`: A set of all tasks running on the given event loop. + """ + _loop: asyncio.AbstractEventLoop | None = loop or safe_get_running_loop() + + if _loop is None: + return set() + + tasks = asyncio.all_tasks(loop=_loop) + + if only_undone: + return {t for t in tasks if not t.done()} + + return tasks + + +async def wait_tasks_by_name(names: list[str]) -> list[asyncio.Future[Any]]: + """ + Wait for all tasks with the given names to complete. + + Args: + names `(list[str])`: A list of task names to wait for. + + Returns + ------- + `list[asyncio.Future[Any]]`: A list of futures representing the results of the completed tasks. + """ + tasks = [t for t in get_all_tasks() if t.get_name() in names] + return await asyncio.gather(*tasks, return_exceptions=True) + + +def safe_get_running_loop() -> asyncio.AbstractEventLoop | None: + """ + Get the running loop safely. + + Returns + ------- + `asyncio.AbstractEventLoop | None` + The running loop. None if faced RuntimeError. + """ + try: + return asyncio.get_running_loop() + except RuntimeError: + return None