Skip to content

Commit

Permalink
Add app commands declared client callback
Browse files Browse the repository at this point in the history
  • Loading branch information
FasterSpeeding committed Aug 19, 2022
1 parent 02f23bf commit 42cdf62
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 7 deletions.
28 changes: 28 additions & 0 deletions tanjun/abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -3631,12 +3631,40 @@ async def open(self) -> None:
"""


class DeclaredCommands(abc.ABC):
__slots__ = ()

@property
@abc.abstractmethod
def builders(self) -> collections.Sequence[hikari.api.CommandBuilder]:
"""The declared command builders."""

@property
@abc.abstractmethod
def commands(self) -> collections.Sequence[hikari.PartialCommand]:
"""The declared command objects."""

@property
@abc.abstractmethod
def guild_id(self) -> typing.Optional[hikari.Snowflake]:
"""Id of the guild these commands were declared for.
This will be [None][] if they were declared globally.
"""


class ClientCallbackNames(str, enum.Enum):
"""Enum of the standard client callback names.
These should be dispatched by all [tanjun.abc.Client][] implementations.
"""

APP_COMMANDS_DECLARED = "app_commands_delcared"
"""Called when the application commands are declared through the client.
One positional argument of type [DeclaredCommands][].
"""

CLOSED = "closed"
"""Called when the client has finished closing.
Expand Down
46 changes: 39 additions & 7 deletions tanjun/clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,6 +491,33 @@ async def __call__(self) -> None:
self.client.remove_client_callback(ClientCallbackNames.STARTING, self)


class _DeclaredCommands(tanjun.DeclaredCommands):
__slots__ = ("_builders", "_commands", "_guild_id")

def __init__(
self,
builders: collections.Sequence[hikari.api.CommandBuilder],
commands: collections.Sequence[hikari.PartialCommand],
guild_id: typing.Optional[hikari.Snowflake],
/,
) -> None:
self._builders = builders
self._commands = commands
self._guild_id = guild_id

@property
def builders(self) -> collections.Sequence[hikari.api.CommandBuilder]:
return self._builders

@property
def commands(self) -> collections.Sequence[hikari.PartialCommand]:
return self._commands

@property
def guild_id(self) -> typing.Optional[hikari.Snowflake]:
return self._guild_id


class Client(tanjun.Client):
"""Tanjun's standard [tanjun.abc.Client][] implementation.
Expand Down Expand Up @@ -1312,15 +1339,15 @@ async def declare_application_commands(
user_ids = user_ids or {}
names_to_commands: dict[tuple[hikari.CommandType, str], tanjun.AppCommand[typing.Any]] = {}
conflicts: set[tuple[hikari.CommandType, str]] = set()
builders: dict[tuple[hikari.CommandType, str], hikari.api.CommandBuilder] = {}
builders_dict: dict[tuple[hikari.CommandType, str], hikari.api.CommandBuilder] = {}
message_count = 0
slash_count = 0
user_count = 0

for command in commands:
key = (command.type, command.name)
names_to_commands[key] = command
if key in builders:
if key in builders_dict:
conflicts.add(key)

builder = command.build()
Expand All @@ -1345,7 +1372,7 @@ async def declare_application_commands(
if builder.is_dm_enabled is hikari.UNDEFINED:
builder.set_is_dm_enabled(self.dms_enabled_for_app_cmds)

builders[key] = builder
builders_dict[key] = builder

if conflicts:
raise ValueError(
Expand All @@ -1367,16 +1394,17 @@ async def declare_application_commands(

if not force:
registered_commands = await self._rest.fetch_application_commands(application, guild=guild)
if len(registered_commands) == len(builders) and all(
_cmp_command(builders.get((c.type, c.name)), c) for c in registered_commands
if len(registered_commands) == len(builders_dict) and all(
_cmp_command(builders_dict.get((c.type, c.name)), c) for c in registered_commands
):
_LOGGER.info(
"Skipping bulk declare for %s application commands since they're already declared", target_type
)
return registered_commands

_LOGGER.info("Bulk declaring %s %s application commands", len(builders), target_type)
responses = await self._rest.set_application_commands(application, list(builders.values()), guild=guild)
_LOGGER.info("Bulk declaring %s %s application commands", len(builders_dict), target_type)
builders = list(builders_dict.values())
responses = await self._rest.set_application_commands(application, builders, guild=guild)

for response in responses:
if not guild:
Expand All @@ -1390,6 +1418,10 @@ async def declare_application_commands(
", ".join(f"{response.type}-{response.name}: {response.id}" for response in responses),
)

await self.dispatch_client_callback(
tanjun.ClientCallbackNames.APP_COMMANDS_DECLARED,
_DeclaredCommands(builders, responses, None if guild is hikari.UNDEFINED else hikari.Snowflake(guild)),
)
return responses

def set_auto_defer_after(self: _ClientT, time: typing.Optional[float], /) -> _ClientT:
Expand Down

0 comments on commit 42cdf62

Please sign in to comment.