Skip to content

Commit

Permalink
Merge pull request #26 from sushi-chaaaan/fix/#23-timeout_handling
Browse files Browse the repository at this point in the history
Fix/#23 timeout handling
  • Loading branch information
sushichan044 authored Oct 13, 2023
2 parents bad46ef + 8c8b524 commit e5928c1
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 9 deletions.
11 changes: 4 additions & 7 deletions examples/confirm.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
# This example requires the 'message_content' privileged intent to function.


import asyncio

import discord
from discord.ext import commands
from ductile import State, View, ViewObject
Expand All @@ -23,6 +21,7 @@ class ConfirmView(View):
def __init__(self) -> None:
super().__init__()
self.approved = State[bool | None](None, self)
self.description = State("Are you sure?", self)

def render(self) -> ViewObject:
async def handle_approve(interaction: discord.Interaction) -> None:
Expand All @@ -36,15 +35,15 @@ async def handle_deny(interaction: discord.Interaction) -> None:
self.stop()

return ViewObject(
embeds=[discord.Embed(title="Confirm", description="Are you sure?")],
embeds=[discord.Embed(title="Confirm", description=self.description())],
components=[
Button("yes", style={"color": "green", "disabled": self.approved() is not None}, on_click=handle_approve),
Button("no", style={"color": "red", "disabled": self.approved() is not None}, on_click=handle_deny),
],
)

async def on_timeout(self) -> None:
print("Timed out") # noqa: T201
self.description.set_state("Timed out.")
self.approved.set_state(False) # noqa: FBT003
self.stop()

Expand Down Expand Up @@ -74,8 +73,6 @@ async def wait_for_confirm(interaction: discord.Interaction) -> None:
controller = InteractionController(view, interaction=interaction, timeout=60, ephemeral=True)
await controller.send()
await controller.wait()
# This sleep is workaround. see details at https://github.com/sushi-chaaaan/ductile-ui/issues/23
await asyncio.sleep(3.0)
self.approved.set_state(view.approved())

return ViewObject(
Expand Down Expand Up @@ -105,4 +102,4 @@ async def send_counter(ctx: commands.Context) -> None:
await controller.wait()


bot.run("MY_COOL_BOT_TOKEN")
bot.run("MTE2MTUzOTQyMzg5NzM4NzAyOQ.GoAzw6.Okx6ZDcHIpPZhExWb54aiHZ-BYHZtJMf0hqSBY")
9 changes: 7 additions & 2 deletions src/ductile/controller/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from ..internal import _InternalView # noqa: TID252
from ..state import State # noqa: TID252
from ..utils import wait_tasks_by_name # noqa: TID252

if TYPE_CHECKING:
from collections.abc import Generator
Expand Down Expand Up @@ -89,12 +90,16 @@ async def wait(self) -> ViewResult:
`states` is a dictionary of all states in the view.
"""
timed_out = await self.__raw_view.wait()
is_timed_out = await self.__raw_view.wait()

# this is got from discord.ui.View._dispatch_timeout()
timeout_task_name = f"discord-ui-view-timeout-{self.__raw_view.id}"
await wait_tasks_by_name([timeout_task_name])

d = {}
for key, state in self._get_all_state_in_view():
d[key] = state.get_state()
return ViewResult(timed_out, d)
return ViewResult(is_timed_out, d)

def _get_all_state_in_view(self) -> "Generator[tuple[str, State[Any]], None, None]":
for k, v in self.__view.__dict__.items():
Expand Down
3 changes: 3 additions & 0 deletions src/ductile/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from .async_helper import get_all_tasks, wait_tasks_by_name
from .call import call_any_function
from .logger import get_logger

__all__ = [
"get_all_tasks",
"wait_tasks_by_name",
"call_any_function",
"get_logger",
]
58 changes: 58 additions & 0 deletions src/ductile/utils/async_helper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import asyncio
from typing import Any


def get_all_tasks(*, loop: asyncio.AbstractEventLoop | None = None, only_undone: bool = False) -> set[asyncio.Task]:
"""
Return a set of all tasks running on the given event loop.
Args:
loop `(asyncio.AbstractEventLoop | None)`: The event loop to get tasks from.
If None, the current running loop is used.
only_undone `(bool)`: If True, only tasks that are not done are returned.
Returns
-------
`set[asyncio.Task]`: A set of all tasks running on the given event loop.
"""
_loop: asyncio.AbstractEventLoop | None = loop or safe_get_running_loop()

if _loop is None:
return set()

tasks = asyncio.all_tasks(loop=_loop)

if only_undone:
return {t for t in tasks if not t.done()}

return tasks


async def wait_tasks_by_name(names: list[str]) -> list[asyncio.Future[Any]]:
"""
Wait for all tasks with the given names to complete.
Args:
names `(list[str])`: A list of task names to wait for.
Returns
-------
`list[asyncio.Future[Any]]`: A list of futures representing the results of the completed tasks.
"""
tasks = [t for t in get_all_tasks() if t.get_name() in names]
return await asyncio.gather(*tasks, return_exceptions=True)


def safe_get_running_loop() -> asyncio.AbstractEventLoop | None:
"""
Get the running loop safely.
Returns
-------
`asyncio.AbstractEventLoop | None`
The running loop. None if faced RuntimeError.
"""
try:
return asyncio.get_running_loop()
except RuntimeError:
return None

0 comments on commit e5928c1

Please sign in to comment.