Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 49 additions & 44 deletions discord/bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,50 +28,42 @@
import asyncio
import collections
import inspect
import sys
import traceback
from .commands.errors import CheckFailure

from typing import (
Any,
Callable,
Coroutine,
List,
Optional,
Type,
TypeVar,
Union,
)
from typing import Any, Callable, Coroutine, List, Optional, Type, TypeVar, Union

import sys
import discord

from .client import Client
from .shard import AutoShardedClient
from .utils import MISSING, get, find, async_all
from .cog import CogMixin
from .commands import (
SlashCommand,
SlashCommandGroup,
MessageCommand,
UserCommand,
ApplicationCommand,
ApplicationContext,
AutocompleteContext,
MessageCommand,
SlashCommand,
SlashCommandGroup,
UserCommand,
command,
)
from .cog import CogMixin

from .errors import Forbidden, DiscordException
from .interactions import Interaction
from .commands.errors import CheckFailure
from .enums import InteractionType
from .errors import DiscordException, Forbidden
from .interactions import Interaction
from .shard import AutoShardedClient
from .state import ConnectionState
from .utils import MISSING, async_all, find, get

CoroFunc = Callable[..., Coroutine[Any, Any, Any]]
CFT = TypeVar('CFT', bound=CoroFunc)
CFT = TypeVar("CFT", bound=CoroFunc)

__all__ = (
'ApplicationCommandMixin',
'Bot',
'AutoShardedBot',
"ApplicationCommandMixin",
"Bot",
"AutoShardedBot",
)


class ApplicationCommandMixin:
"""A mixin that implements common functionality for classes that need
application command compatibility.
Expand Down Expand Up @@ -149,10 +141,10 @@ def remove_application_command(
@property
def get_command(self):
"""Shortcut for :meth:`.get_application_command`.

.. note::
Overridden in :class:`ext.commands.Bot`.

.. versionadded:: 2.0
"""
# TODO: Do something like we did in self.commands for this
Expand Down Expand Up @@ -185,10 +177,7 @@ def get_application_command(
"""

for command in self._application_commands.values():
if (
command.name == name
and isinstance(command, type)
):
if command.name == name and isinstance(command, type):
if guild_ids is not None and command.guild_ids != guild_ids:
return
return command
Expand Down Expand Up @@ -287,7 +276,12 @@ async def register_commands(self) -> None:
raise
else:
for i in cmds:
cmd = find(lambda cmd: cmd.name == i["name"] and cmd.type == i["type"] and int(i["guild_id"]) in cmd.guild_ids, self.pending_application_commands)
cmd = find(
lambda cmd: cmd.name == i["name"]
and cmd.type == i["type"]
and int(i["guild_id"]) in cmd.guild_ids,
self.pending_application_commands,
)
cmd.id = i["id"]
self._application_commands[cmd.id] = cmd

Expand Down Expand Up @@ -380,7 +374,9 @@ async def register_commands(self) -> None:
if len(new_cmd_perm["permissions"]) > 10:
print(
"Command '{name}' has more than 10 permission overrides in guild ({guild_id}).\nwill only use the first 10 permission overrides.".format(
name=self._application_commands[new_cmd_perm["id"]].name,
name=self._application_commands[
new_cmd_perm["id"]
].name,
guild_id=guild_id,
)
)
Expand Down Expand Up @@ -424,8 +420,8 @@ async def process_application_commands(self, interaction: Interaction) -> None:
The interaction to process
"""
if interaction.type not in (
InteractionType.application_command,
InteractionType.auto_complete
InteractionType.application_command,
InteractionType.auto_complete,
):
return

Expand All @@ -438,7 +434,7 @@ async def process_application_commands(self, interaction: Interaction) -> None:
ctx = await self.get_autocomplete_context(interaction)
ctx.command = command
return await command.invoke_autocomplete_callback(ctx)

ctx = await self.get_application_context(interaction)
ctx.command = command
self.dispatch("application_command", ctx)
Expand Down Expand Up @@ -591,17 +587,20 @@ def group(
Callable[[Type[SlashCommandGroup]], SlashCommandGroup]
The slash command group that was created.
"""

def inner(cls: Type[SlashCommandGroup]) -> SlashCommandGroup:
group = cls(
name,
(
description or inspect.cleandoc(cls.__doc__).splitlines()[0]
if cls.__doc__ is not None else "No description provided"
if cls.__doc__ is not None
else "No description provided"
),
guild_ids=guild_ids
guild_ids=guild_ids,
)
self.add_application_command(group)
return group

return inner

slash_group = group
Expand Down Expand Up @@ -667,7 +666,6 @@ class be provided, it must be similar enough to
return cls(self, interaction)



class BotBase(ApplicationCommandMixin, CogMixin):
_supports_prefixed_commands = False
# TODO I think
Expand Down Expand Up @@ -717,6 +715,13 @@ async def on_connect(self):

async def on_interaction(self, interaction):
await self.process_application_commands(interaction)
if interaction.type == discord.InteractionType.modal_submit:
state: ConnectionState = self._connection # type: ignore
user_id, custom_id = (
interaction.user.id,
interaction.data["custom_id"],
)
await state._modal_store.dispatch(user_id, custom_id, interaction)

async def on_application_command_error(
self, context: ApplicationContext, exception: DiscordException
Expand All @@ -730,7 +735,7 @@ async def on_application_command_error(

This only fires if you do not specify any listeners for command error.
"""
if self.extra_events.get('on_application_command_error', None):
if self.extra_events.get("on_application_command_error", None):
return

command = context.command
Expand Down Expand Up @@ -887,7 +892,7 @@ async def my_message(message): pass
name = func.__name__ if name is MISSING else name

if not asyncio.iscoroutinefunction(func):
raise TypeError('Listeners must be coroutines')
raise TypeError("Listeners must be coroutines")

if name in self.extra_events:
self.extra_events[name].append(func)
Expand Down Expand Up @@ -953,7 +958,7 @@ def decorator(func: CFT) -> CFT:
def dispatch(self, event_name: str, *args: Any, **kwargs: Any) -> None:
# super() will resolve to Client
super().dispatch(event_name, *args, **kwargs) # type: ignore
ev = 'on_' + event_name
ev = f"on_{event_name}"
for event in self.extra_events.get(ev, []):
self._schedule_event(event, ev, *args, **kwargs) # type: ignore

Expand Down
Loading