From 42cdf62e36e4827e1c3e9f0d297cfe64828bb0d8 Mon Sep 17 00:00:00 2001 From: Faster Speeding Date: Fri, 19 Aug 2022 19:36:34 +0100 Subject: [PATCH] Add app commands declared client callback --- tanjun/abc.py | 28 ++++++++++++++++++++++++++++ tanjun/clients.py | 46 +++++++++++++++++++++++++++++++++++++++------- 2 files changed, 67 insertions(+), 7 deletions(-) diff --git a/tanjun/abc.py b/tanjun/abc.py index a72fb49a7..680251aaa 100644 --- a/tanjun/abc.py +++ b/tanjun/abc.py @@ -3631,12 +3631,40 @@ async def open(self) -> None: """ +class DeclaredCommands(abc.ABC): + __slots__ = () + + @property + @abc.abstractmethod + def builders(self) -> collections.Sequence[hikari.api.CommandBuilder]: + """The declared command builders.""" + + @property + @abc.abstractmethod + def commands(self) -> collections.Sequence[hikari.PartialCommand]: + """The declared command objects.""" + + @property + @abc.abstractmethod + def guild_id(self) -> typing.Optional[hikari.Snowflake]: + """Id of the guild these commands were declared for. + + This will be [None][] if they were declared globally. + """ + + class ClientCallbackNames(str, enum.Enum): """Enum of the standard client callback names. These should be dispatched by all [tanjun.abc.Client][] implementations. """ + APP_COMMANDS_DECLARED = "app_commands_delcared" + """Called when the application commands are declared through the client. + + One positional argument of type [DeclaredCommands][]. + """ + CLOSED = "closed" """Called when the client has finished closing. diff --git a/tanjun/clients.py b/tanjun/clients.py index 792fd9783..32664ac92 100644 --- a/tanjun/clients.py +++ b/tanjun/clients.py @@ -491,6 +491,33 @@ async def __call__(self) -> None: self.client.remove_client_callback(ClientCallbackNames.STARTING, self) +class _DeclaredCommands(tanjun.DeclaredCommands): + __slots__ = ("_builders", "_commands", "_guild_id") + + def __init__( + self, + builders: collections.Sequence[hikari.api.CommandBuilder], + commands: collections.Sequence[hikari.PartialCommand], + guild_id: typing.Optional[hikari.Snowflake], + /, + ) -> None: + self._builders = builders + self._commands = commands + self._guild_id = guild_id + + @property + def builders(self) -> collections.Sequence[hikari.api.CommandBuilder]: + return self._builders + + @property + def commands(self) -> collections.Sequence[hikari.PartialCommand]: + return self._commands + + @property + def guild_id(self) -> typing.Optional[hikari.Snowflake]: + return self._guild_id + + class Client(tanjun.Client): """Tanjun's standard [tanjun.abc.Client][] implementation. @@ -1312,7 +1339,7 @@ async def declare_application_commands( user_ids = user_ids or {} names_to_commands: dict[tuple[hikari.CommandType, str], tanjun.AppCommand[typing.Any]] = {} conflicts: set[tuple[hikari.CommandType, str]] = set() - builders: dict[tuple[hikari.CommandType, str], hikari.api.CommandBuilder] = {} + builders_dict: dict[tuple[hikari.CommandType, str], hikari.api.CommandBuilder] = {} message_count = 0 slash_count = 0 user_count = 0 @@ -1320,7 +1347,7 @@ async def declare_application_commands( for command in commands: key = (command.type, command.name) names_to_commands[key] = command - if key in builders: + if key in builders_dict: conflicts.add(key) builder = command.build() @@ -1345,7 +1372,7 @@ async def declare_application_commands( if builder.is_dm_enabled is hikari.UNDEFINED: builder.set_is_dm_enabled(self.dms_enabled_for_app_cmds) - builders[key] = builder + builders_dict[key] = builder if conflicts: raise ValueError( @@ -1367,16 +1394,17 @@ async def declare_application_commands( if not force: registered_commands = await self._rest.fetch_application_commands(application, guild=guild) - if len(registered_commands) == len(builders) and all( - _cmp_command(builders.get((c.type, c.name)), c) for c in registered_commands + if len(registered_commands) == len(builders_dict) and all( + _cmp_command(builders_dict.get((c.type, c.name)), c) for c in registered_commands ): _LOGGER.info( "Skipping bulk declare for %s application commands since they're already declared", target_type ) return registered_commands - _LOGGER.info("Bulk declaring %s %s application commands", len(builders), target_type) - responses = await self._rest.set_application_commands(application, list(builders.values()), guild=guild) + _LOGGER.info("Bulk declaring %s %s application commands", len(builders_dict), target_type) + builders = list(builders_dict.values()) + responses = await self._rest.set_application_commands(application, builders, guild=guild) for response in responses: if not guild: @@ -1390,6 +1418,10 @@ async def declare_application_commands( ", ".join(f"{response.type}-{response.name}: {response.id}" for response in responses), ) + await self.dispatch_client_callback( + tanjun.ClientCallbackNames.APP_COMMANDS_DECLARED, + _DeclaredCommands(builders, responses, None if guild is hikari.UNDEFINED else hikari.Snowflake(guild)), + ) return responses def set_auto_defer_after(self: _ClientT, time: typing.Optional[float], /) -> _ClientT: