Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support synchronous function in callback #30

Merged
merged 2 commits into from
Oct 14, 2023
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
feat: support synchronous function in callback
  • Loading branch information
sushichan044 committed Oct 14, 2023

Unverified

The signing certificate or its chain could not be verified.
commit 7f0faa1b3b58d3a7e73a86fc6dc33930d942bc7c
20 changes: 20 additions & 0 deletions src/ductile/types/__init__.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,42 @@
from .callback import (
ChannelSelectCallback,
ChannelSelectSyncCallback,
#
InteractionCallback,
InteractionSyncCallback,
#
MentionableSelectCallback,
MentionableSelectSyncCallback,
#
ModalCallback,
ModalSyncCallback,
#
RoleSelectCallback,
RoleSelectSyncCallback,
#
SelectCallback,
SelectSyncCallback,
#
UserSelectCallback,
UserSelectSyncCallback,
)
from .view import ViewErrorHandler, ViewTimeoutHandler

__all__ = [
"InteractionCallback",
"InteractionSyncCallback",
"SelectCallback",
"SelectSyncCallback",
"ChannelSelectCallback",
"ChannelSelectSyncCallback",
"RoleSelectCallback",
"RoleSelectSyncCallback",
"MentionableSelectCallback",
"MentionableSelectSyncCallback",
"UserSelectCallback",
"UserSelectSyncCallback",
"ModalCallback",
"ModalSyncCallback",
"ViewErrorHandler",
"ViewTimeoutHandler",
]
23 changes: 22 additions & 1 deletion src/ductile/types/callback.py
Original file line number Diff line number Diff line change
@@ -14,24 +14,45 @@
"ModalCallback",
]


# InteractionCallback
InteractionCallback: TypeAlias = Callable[[discord.Interaction], Awaitable[None]]
InteractionSyncCallback: TypeAlias = Callable[[discord.Interaction], None]

# SelectCallback
SelectCallback: TypeAlias = Callable[[discord.Interaction, list[str]], Awaitable[None]]
SelectSyncCallback: TypeAlias = Callable[[discord.Interaction, list[str]], None]

ChannelSelectCallback: TypeAlias = Callable[
[discord.Interaction, list[AppCommandChannel | AppCommandThread]],
Awaitable[None],
]
ChannelSelectSyncCallback: TypeAlias = Callable[
[discord.Interaction, list[AppCommandChannel | AppCommandThread]],
None,
]

RoleSelectCallback: TypeAlias = Callable[
[discord.Interaction, list[discord.Role]],
Awaitable[None],
]
RoleSelectSyncCallback: TypeAlias = Callable[
[discord.Interaction, list[discord.Role]],
None,
]

MentionableSelectCallback: TypeAlias = Callable[
[discord.Interaction, list[discord.Role | discord.Member | discord.User]],
Awaitable[None],
]
MentionableSelectSyncCallback: TypeAlias = Callable[
[discord.Interaction, list[discord.Role | discord.Member | discord.User]],
None,
]

UserSelectCallback: TypeAlias = Callable[[discord.Interaction, list[discord.User | discord.Member]], Awaitable[None]]
UserSelectSyncCallback: TypeAlias = Callable[[discord.Interaction, list[discord.User | discord.Member]], None]


# ModalCallback
ModalCallback: TypeAlias = Callable[[discord.Interaction, dict[str, str]], Awaitable[None]]
ModalSyncCallback: TypeAlias = Callable[[discord.Interaction, dict[str, str]], None]
16 changes: 11 additions & 5 deletions src/ductile/ui/button.py
Original file line number Diff line number Diff line change
@@ -3,12 +3,12 @@
from discord import ButtonStyle as _ButtonStyle
from discord import ui

from ..utils import call_any_function # noqa: TID252
from ..utils import call_any_function, is_sync_func # noqa: TID252

if TYPE_CHECKING:
from discord import Emoji, Interaction, PartialEmoji

from ..types import InteractionCallback # noqa: TID252
from ..types import InteractionCallback, InteractionSyncCallback # noqa: TID252


class _ButtonStyleRequired(TypedDict):
@@ -37,7 +37,7 @@ def __init__(
*,
style: ButtonStyle,
custom_id: str | None = None,
on_click: "InteractionCallback | None" = None,
on_click: "InteractionCallback | InteractionSyncCallback | None" = None,
) -> None:
__style = _ButtonStyle[style.get("color", "grey")]
__disabled = style.get("disabled", False)
@@ -54,8 +54,14 @@ def __init__(
)

async def callback(self, interaction: "Interaction") -> None: # noqa: D102
if self.__callback_fn:
await call_any_function(self.__callback_fn, interaction)
if self.__callback_fn is None:
await interaction.response.defer()
return

if is_sync_func(self.__callback_fn):
await interaction.response.defer()

await call_any_function(self.__callback_fn, interaction)


class LinkButton(ui.Button):
16 changes: 12 additions & 4 deletions src/ductile/ui/modal.py
Original file line number Diff line number Diff line change
@@ -2,10 +2,12 @@

from discord import TextStyle, ui

from ..utils import call_any_function, is_sync_func # noqa: TID252

if TYPE_CHECKING:
from discord import Interaction

from ..types import ModalCallback # noqa: TID252
from ..types import ModalCallback, ModalSyncCallback # noqa: TID252


class TextInputStyle(TypedDict, total=False):
@@ -71,7 +73,7 @@ def __init__( # noqa: PLR0913
inputs: list[TextInput],
timeout: float | None = None,
custom_id: str | None = None,
on_submit: "ModalCallback | None" = None,
on_submit: "ModalCallback | ModalSyncCallback | None" = None,
) -> None:
__d = {
"title": title,
@@ -86,5 +88,11 @@ def __init__( # noqa: PLR0913
self.add_item(_in)

async def on_submit(self, interaction: "Interaction") -> None: # noqa: D102
if self.__callback_fn:
await self.__callback_fn(interaction, {i.label: i.value for i in self.__inputs})
if self.__callback_fn is None:
await interaction.response.defer()
return

if is_sync_func(self.__callback_fn):
await interaction.response.defer()

await call_any_function(self.__callback_fn, interaction, {i.label: i.value for i in self.__inputs})
67 changes: 51 additions & 16 deletions src/ductile/ui/select.py
Original file line number Diff line number Diff line change
@@ -4,17 +4,22 @@
from discord import SelectOption as _SelectOption
from pydantic import BaseModel, Field

from ..utils import call_any_function # noqa: TID252
from ..utils import call_any_function, is_sync_func # noqa: TID252

if TYPE_CHECKING:
from discord import ChannelType, Interaction

from ..types import ( # noqa: TID252
ChannelSelectCallback,
ChannelSelectSyncCallback,
MentionableSelectCallback,
MentionableSelectSyncCallback,
RoleSelectCallback,
RoleSelectSyncCallback,
SelectCallback,
SelectSyncCallback,
UserSelectCallback,
UserSelectSyncCallback,
)


@@ -85,7 +90,7 @@ def __init__( # noqa: PLR0913
style: SelectStyle,
options: list[SelectOption],
custom_id: str | None = None,
on_select: "SelectCallback | None" = None,
on_select: "SelectCallback | SelectSyncCallback | None" = None,
) -> None:
__disabled = style.get("disabled", False)
__placeholder = style.get("placeholder", None)
@@ -113,8 +118,14 @@ def __init__( # noqa: PLR0913
super().__init__(**__d)

async def callback(self, interaction: "Interaction") -> None: # noqa: D102
if self.__callback_fn:
await call_any_function(self.__callback_fn, interaction, self.values)
if self.__callback_fn is None:
await interaction.response.defer()
return

if is_sync_func(self.__callback_fn):
await interaction.response.defer()

await call_any_function(self.__callback_fn, interaction, self.values)


class ChannelSelect(ui.ChannelSelect):
@@ -130,7 +141,7 @@ def __init__(
config: ChannelSelectConfig,
style: SelectStyle,
custom_id: str | None = None,
on_select: "ChannelSelectCallback | None" = None,
on_select: "ChannelSelectCallback| ChannelSelectSyncCallback | None" = None,
) -> None:
__disabled = style.get("disabled", False)
__placeholder = style.get("placeholder", None)
@@ -149,8 +160,14 @@ def __init__(
super().__init__(**__d)

async def callback(self, interaction: "Interaction") -> None: # noqa: D102
if self.__callback_fn:
await call_any_function(self.__callback_fn, interaction, self.values)
if self.__callback_fn is None:
await interaction.response.defer()
return

if is_sync_func(self.__callback_fn):
await interaction.response.defer()

await call_any_function(self.__callback_fn, interaction, self.values)


class RoleSelect(ui.RoleSelect):
@@ -166,7 +183,7 @@ def __init__(
config: RoleSelectConfig,
style: SelectStyle,
custom_id: str | None = None,
on_select: "RoleSelectCallback | None" = None,
on_select: "RoleSelectCallback | RoleSelectSyncCallback | None" = None,
) -> None:
__disabled = style.get("disabled", False)
__placeholder = style.get("placeholder", None)
@@ -184,8 +201,14 @@ def __init__(
super().__init__(**__d)

async def callback(self, interaction: "Interaction") -> None: # noqa: D102
if self.__callback_fn:
await call_any_function(self.__callback_fn, interaction, self.values)
if self.__callback_fn is None:
await interaction.response.defer()
return

if is_sync_func(self.__callback_fn):
await interaction.response.defer()

await call_any_function(self.__callback_fn, interaction, self.values)


class MentionableSelect(ui.MentionableSelect):
@@ -202,7 +225,7 @@ def __init__(
config: MentionableSelectConfig,
style: SelectStyle,
custom_id: str | None = None,
on_select: "MentionableSelectCallback | None" = None,
on_select: "MentionableSelectCallback | MentionableSelectSyncCallback | None" = None,
) -> None:
__disabled = style.get("disabled", False)
__placeholder = style.get("placeholder", None)
@@ -220,8 +243,14 @@ def __init__(
super().__init__(**__d)

async def callback(self, interaction: "Interaction") -> None: # noqa: D102
if self.__callback_fn:
await call_any_function(self.__callback_fn, interaction, self.values)
if self.__callback_fn is None:
await interaction.response.defer()
return

if is_sync_func(self.__callback_fn):
await interaction.response.defer()

await call_any_function(self.__callback_fn, interaction, self.values)


class UserSelect(ui.UserSelect):
@@ -237,7 +266,7 @@ def __init__(
config: UserSelectConfig,
style: SelectStyle,
custom_id: str | None = None,
on_select: "UserSelectCallback | None" = None,
on_select: "UserSelectCallback | UserSelectSyncCallback | None" = None,
) -> None:
__disabled = style.get("disabled", False)
__placeholder = style.get("placeholder", None)
@@ -255,5 +284,11 @@ def __init__(
super().__init__(**__d)

async def callback(self, interaction: "Interaction") -> None: # noqa: D102
if self.__callback_fn:
await call_any_function(self.__callback_fn, interaction, self.values)
if self.__callback_fn is None:
await interaction.response.defer()
return

if is_sync_func(self.__callback_fn):
await interaction.response.defer()

await call_any_function(self.__callback_fn, interaction, self.values)
3 changes: 3 additions & 0 deletions src/ductile/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from .async_helper import get_all_tasks, wait_tasks_by_name
from .call import call_any_function
from .logger import get_logger
from .type_helper import is_async_func, is_sync_func

__all__ = [
"get_all_tasks",
"wait_tasks_by_name",
"call_any_function",
"get_logger",
"is_async_func",
"is_sync_func",
]
36 changes: 36 additions & 0 deletions src/ductile/utils/type_helper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from collections.abc import Awaitable, Callable
from inspect import iscoroutinefunction
from typing import ParamSpec, TypeGuard, TypeVar

_P = ParamSpec("_P")
_R = TypeVar("_R")


def is_async_func(func: Callable[_P, _R | Awaitable[_R]]) -> TypeGuard[Callable[_P, Awaitable[_R]]]:
"""
Check if a function is an asynchronous function.

Args
----
func (`Callable`): The function to check.

Returns
-------
`True` if the function is an asynchronous, `False` otherwise.
"""
return callable(func) and iscoroutinefunction(func)


def is_sync_func(func: Callable[_P, _R | Awaitable[_R]]) -> TypeGuard[Callable[_P, _R]]:
"""
Check if a function is synchronous.

Args
----
func (`Callable`): The function to check.

Returns
-------
`True` if the function is synchronous, `False` otherwise.
"""
return callable(func) and not iscoroutinefunction(func)