From 5cb81dcdeb80b7af3a8a580ca1a2b9231c85021a Mon Sep 17 00:00:00 2001 From: Faster Speeding Date: Fri, 3 Jun 2022 05:34:28 +0100 Subject: [PATCH] Initial restructuring --- pyproject.toml | 2 +- tanjun/abc.py | 62 +++++++++++++++++++++++++------------- tanjun/checks.py | 2 +- tanjun/clients.py | 10 +++--- tanjun/commands/base.py | 10 +++--- tanjun/commands/message.py | 51 ++++++++++++++++--------------- tanjun/components.py | 10 +++--- tanjun/utilities.py | 10 +++--- 8 files changed, 90 insertions(+), 67 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 8d9ef40eb..731063ac5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,7 +26,7 @@ classifiers = [ "Topic :: Utilities", "Typing :: Typed" ] -dependencies = ["alluka~=0.1", "hikari~=2.0.0.dev109"] +dependencies = ["alluka~=0.1", "hikari~=2.0.0.dev109", "typing-extensions>=4.2.0"] dynamic = ["description", "version"] [project.urls] diff --git a/tanjun/abc.py b/tanjun/abc.py index 1e385a2bb..ae768cb0b 100644 --- a/tanjun/abc.py +++ b/tanjun/abc.py @@ -73,6 +73,7 @@ from collections import abc as collections import hikari +import typing_extensions from alluka import abc as alluka if typing.TYPE_CHECKING: @@ -91,6 +92,7 @@ _MessageCommandT = typing.TypeVar("_MessageCommandT", bound="MessageCommand[typing.Any]") _MetaEventSigT = typing.TypeVar("_MetaEventSigT", bound="MetaEventSig") +_P = typing_extensions.ParamSpec("_P") _T = typing.TypeVar("_T") _AppCommandContextT = typing.TypeVar("_AppCommandContextT", bound="AppCommandContext") _CommandCallbackSigT = typing.TypeVar("_CommandCallbackSigT", bound="CommandCallbackSig") @@ -102,8 +104,10 @@ "_MenuTypeT", typing.Literal[hikari.CommandType.USER], typing.Literal[hikari.CommandType.MESSAGE] ) +_MaybeAwaitable = typing.Union[collections.Callable[_P, _CoroT[_T]], collections.Callable[_P, _T]] +_AutocompleteCallbackSig = collections.Callable[typing_extensions.Concatenate["AutocompleteContext", _P], _CoroT[None]] -AutocompleteCallbackSig = collections.Callable[..., _CoroT[None]] +AutocompleteCallbackSig = _AutocompleteCallbackSig[...] """Type hint of the callback an autocomplete callback should have. This will be called when handling autocomplete and should be an asynchronous @@ -112,8 +116,8 @@ autocomplete type), returns [None][] and may use dependency injection. """ - -CheckSig = typing.Union[collections.Callable[..., _CoroT[bool]], collections.Callable[..., bool]] +_CheckSig = _MaybeAwaitable[typing_extensions.Concatenate[_ContextT_contra, _P], bool] +CheckSig = _CheckSig[_ContextT_contra, ...] """Type hint of a general context check used with Tanjun [tanjun.abc.ExecutableCommand][] classes. This may be registered with a [tanjun.abc.ExecutableCommand][] to add a rule @@ -124,7 +128,29 @@ current context shouldn't lead to an execution. """ -CommandCallbackSig = collections.Callable[..., _CoroT[None]] +AnyCheckSig = _CheckSig["Context", ...] + +MessageCheckSig = _CheckSig["MessageContext", ...] + +SlashCheckSig = _CheckSig["SlashContext", ...] + + +_CommandCallbackSig = collections.Callable[typing_extensions.Concatenate[_ContextT_contra, _P], None] + +_MenuValueT = typing.TypeVar("_MenuValueT", hikari.User, hikari.InteractionMember) +_ManuCallbackSig = collections.Callable[typing_extensions.Concatenate[_ContextT_contra, _MenuValueT, _P], None] +MenuCallbackSig = _ManuCallbackSig["MenuContext", _MenuValueT, ...] +"""Type hint of a context menu command callback. + +This is guaranteed two positional; arguments of type [tanjun.abc.MenuContext][] +and either `hikari.User | hikari.InteractionMember` and/or +[hikari.messages.Message][] dependent on the type(s) of menu this is. +""" + +MessageCallbackSig = _CommandCallbackSig["MessageContext", ...] +SlashCallbackSig = _CommandCallbackSig["SlashContext", ...] + +CommandCallbackSig = _CommandCallbackSig["Context", ...] """Type hint of the callback a callable [tanjun.abc.ExecutableCommand][] instance will operate on. This will be called when executing a command and will need to take one @@ -136,10 +162,9 @@ This will have to be asynchronous. """ +_ErrorHookSig = _MaybeAwaitable[typing_extensions.Concatenate[_ContextT_contra, Exception, _P], typing.Optional[bool]] -ErrorHookSig = typing.Union[ - collections.Callable[..., typing.Optional[bool]], collections.Callable[..., _CoroT[typing.Optional[bool]]] -] +ErrorHookSig = _ErrorHookSig[_ContextT_contra, ...] """Type hint of the callback used as a unexpected command error hook. This will be called whenever an unexpected [Exception][] is raised during the @@ -153,8 +178,9 @@ [False][] is returned to indicate that the exception should be re-raised. """ +_HookSig = _MaybeAwaitable[typing_extensions.Concatenate[_ContextT_contra, _P], None] -HookSig = typing.Union[collections.Callable[..., None], collections.Callable[..., _CoroT[None]]] +HookSig = _HookSig[_ContextT_contra, ...] """Type hint of the callback used as a general command hook. !!! note @@ -163,22 +189,16 @@ are passed dependent on the type of hook this is being registered as. """ -ListenerCallbackSig = collections.Callable[..., _CoroT[None]] +_ListenerCallbackSig = collections.Callable[typing_extensions.Concatenate[Exception, _P], _CoroT[None]] + +ListenerCallbackSig = _ListenerCallbackSig[...] """Type hint of a hikari event manager callback. This is guaranteed one positional arg of type [hikari.events.base_events.Event][] regardless of implementation and must be a coruotine function which returns [None][]. """ -MenuCommandCallbackSig = collections.Callable[..., _CoroT[None]] -"""Type hint of a context menu command callback. - -This is guaranteed two positional; arguments of type [tanjun.abc.MenuContext][] -and either `hikari.User | hikari.InteractionMember` and/or -[hikari.messages.Message][] dependent on the type(s) of menu this is. -""" - -MetaEventSig = typing.Union[collections.Callable[..., _CoroT[None]], collections.Callable[..., None]] +MetaEventSig = _MaybeAwaitable[..., None] """Type hint of a client callback. The positional arguments this is guaranteed depend on the event name its being @@ -2383,7 +2403,7 @@ class ExecutableCommand(abc.ABC, typing.Generic[_ContextT_co]): @property @abc.abstractmethod - def checks(self) -> collections.Collection[CheckSig]: + def checks(self) -> collections.Collection[CheckSig[_ContextT_co]]: """Collection of checks that must be met before the command can be executed.""" @property @@ -2440,7 +2460,7 @@ def set_hooks(self: _T, hooks: typing.Optional[Hooks[_ContextT_co]], /) -> _T: """ @abc.abstractmethod - def add_check(self: _T, check: CheckSig, /) -> _T: # TODO: remove or add with_check? + def add_check(self: _T, check: CheckSig[_ContextT_co], /) -> _T: # TODO: remove or add with_check? """Add a check to the command. Parameters @@ -2455,7 +2475,7 @@ def add_check(self: _T, check: CheckSig, /) -> _T: # TODO: remove or add with_c """ @abc.abstractmethod - def remove_check(self: _T, check: CheckSig, /) -> _T: + def remove_check(self: _T, check: CheckSig[_ContextT_co], /) -> _T: """Remove a check from the command. Parameters diff --git a/tanjun/checks.py b/tanjun/checks.py index 855d916c0..245f9ef2e 100644 --- a/tanjun/checks.py +++ b/tanjun/checks.py @@ -72,7 +72,7 @@ def _optional_kwargs( - command: typing.Optional[_CommandT], check: tanjun.CheckSig, / + command: typing.Optional[_CommandT], check: tanjun.AnyCheckSig, / ) -> typing.Union[_CommandT, collections.Callable[[_CommandT], _CommandT]]: if command: return command.add_check(check) diff --git a/tanjun/clients.py b/tanjun/clients.py index b95d315f3..20356bbc3 100644 --- a/tanjun/clients.py +++ b/tanjun/clients.py @@ -69,7 +69,7 @@ if typing.TYPE_CHECKING: import types - _CheckSigT = typing.TypeVar("_CheckSigT", bound=tanjun.CheckSig) + _CheckSigT = typing.TypeVar("_CheckSigT", bound=tanjun.AnyCheckSig) _ClientT = typing.TypeVar("_ClientT", bound="Client") _ListenerCallbackSigT = typing.TypeVar("_ListenerCallbackSigT", bound=tanjun.ListenerCallbackSig) _MetaEventSigT = typing.TypeVar("_MetaEventSigT", bound=tanjun.MetaEventSig) @@ -670,7 +670,7 @@ def __init__( self._auto_defer_after: typing.Optional[float] = 2.0 self._cache = cache self._cached_application_id: typing.Optional[hikari.Snowflake] = None - self._checks: list[tanjun.CheckSig] = [] + self._checks: list[tanjun.AnyCheckSig] = [] self._client_callbacks: dict[str, list[tanjun.MetaEventSig]] = {} self._components: dict[str, tanjun.Component] = {} self._defaults_to_ephemeral: bool = False @@ -1026,7 +1026,7 @@ def cache(self) -> typing.Optional[hikari.api.Cache]: return self._cache @property - def checks(self) -> collections.Collection[tanjun.CheckSig]: + def checks(self) -> collections.Collection[tanjun.AnyCheckSig]: """Collection of the level [tanjun.abc.Context][] checks registered to this client. !!! note @@ -1642,7 +1642,7 @@ def set_human_only(self: _ClientT, value: bool = True) -> _ClientT: return self - def add_check(self: _ClientT, check: tanjun.CheckSig, /) -> _ClientT: + def add_check(self: _ClientT, check: tanjun.AnyCheckSig, /) -> _ClientT: """Add a generic check to this client. This will be applied to both message and slash command execution. @@ -1664,7 +1664,7 @@ def add_check(self: _ClientT, check: tanjun.CheckSig, /) -> _ClientT: return self - def remove_check(self: _ClientT, check: tanjun.CheckSig, /) -> _ClientT: + def remove_check(self: _ClientT, check: tanjun.AnyCheckSig, /) -> _ClientT: """Remove a check from the client. Parameters diff --git a/tanjun/commands/base.py b/tanjun/commands/base.py index cd0417b38..cd1180bde 100644 --- a/tanjun/commands/base.py +++ b/tanjun/commands/base.py @@ -42,7 +42,7 @@ from .. import components if typing.TYPE_CHECKING: - _CheckSigT = typing.TypeVar("_CheckSigT", bound=tanjun.CheckSig) + _CheckSigT = typing.TypeVar("_CheckSigT", bound=tanjun.AnyCheckSig) _PartialCommandT = typing.TypeVar("_PartialCommandT", bound="PartialCommand[typing.Any]") @@ -55,13 +55,13 @@ class PartialCommand(tanjun.ExecutableCommand[_ContextT], components.AbstractCom __slots__ = ("_checks", "_component", "_hooks", "_metadata") def __init__(self) -> None: - self._checks: list[tanjun.CheckSig] = [] + self._checks: list[tanjun.AnyCheckSig] = [] self._component: typing.Optional[tanjun.Component] = None self._hooks: typing.Optional[tanjun.Hooks[_ContextT]] = None self._metadata: dict[typing.Any, typing.Any] = {} @property - def checks(self) -> collections.Collection[tanjun.CheckSig]: + def checks(self) -> collections.Collection[tanjun.AnyCheckSig]: # <>. return self._checks.copy() @@ -98,14 +98,14 @@ def set_metadata(self: _PartialCommandT, key: typing.Any, value: typing.Any, /) self._metadata[key] = value return self - def add_check(self: _PartialCommandT, check: tanjun.CheckSig, /) -> _PartialCommandT: + def add_check(self: _PartialCommandT, check: tanjun.AnyCheckSig, /) -> _PartialCommandT: # <>. if check not in self._checks: self._checks.append(check) return self - def remove_check(self: _PartialCommandT, check: tanjun.CheckSig, /) -> _PartialCommandT: + def remove_check(self: _PartialCommandT, check: tanjun.AnyCheckSig, /) -> _PartialCommandT: # <>. self._checks.remove(check) return self diff --git a/tanjun/commands/message.py b/tanjun/commands/message.py index d3cd3d777..5a0675c4f 100644 --- a/tanjun/commands/message.py +++ b/tanjun/commands/message.py @@ -47,30 +47,31 @@ if typing.TYPE_CHECKING: _AnyMessageCommandT = typing.TypeVar("_AnyMessageCommandT", bound=tanjun.MessageCommand[typing.Any]) + _AnyCallbackSigT = typing.TypeVar("_AnyCallbackSigT", bound=tanjun.CommandCallbackSig) _CommandT = typing.Union[ - tanjun.MenuCommand["_CommandCallbackSigT", typing.Any], - tanjun.MessageCommand["_CommandCallbackSigT"], - tanjun.SlashCommand["_CommandCallbackSigT"], + tanjun.MenuCommand["_AnyCallbackSigT", typing.Any], + tanjun.MessageCommand["_AnyCallbackSigT"], + tanjun.SlashCommand["_AnyCallbackSigT"], ] - _CallbackishT = typing.Union[_CommandT["_CommandCallbackSigT"], "_CommandCallbackSigT"] + _CallbackishT = typing.Union[_CommandT["_CallbackSigT"], "_CallbackSigT"] _MessageCommandT = typing.TypeVar("_MessageCommandT", bound="MessageCommand[typing.Any]") _MessageCommandGroupT = typing.TypeVar("_MessageCommandGroupT", bound="MessageCommandGroup[typing.Any]") -_CommandCallbackSigT = typing.TypeVar("_CommandCallbackSigT", bound=tanjun.CommandCallbackSig) +_CallbackSigT = typing.TypeVar("_CallbackSigT", bound=tanjun.MessageCallbackSig) _EMPTY_DICT: typing.Final[dict[typing.Any, typing.Any]] = {} _EMPTY_HOOKS: typing.Final[hooks_.Hooks[typing.Any]] = hooks_.Hooks() class _ResultProto(typing.Protocol): @typing.overload - def __call__(self, _: _CommandT[_CommandCallbackSigT], /) -> MessageCommand[_CommandCallbackSigT]: + def __call__(self, _: _CommandT[_CallbackSigT], /) -> MessageCommand[_CallbackSigT]: ... @typing.overload - def __call__(self, _: _CommandCallbackSigT, /) -> MessageCommand[_CommandCallbackSigT]: + def __call__(self, _: _CallbackSigT, /) -> MessageCommand[_CallbackSigT]: ... - def __call__(self, _: _CallbackishT[_CommandCallbackSigT], /) -> MessageCommand[_CommandCallbackSigT]: + def __call__(self, _: _CallbackishT[_CallbackSigT], /) -> MessageCommand[_CallbackSigT]: raise NotImplementedError @@ -96,9 +97,9 @@ def as_message_command(name: str, /, *names: str) -> _ResultProto: """ def decorator( - callback: _CallbackishT[_CommandCallbackSigT], + callback: _CallbackishT[_CallbackSigT], /, - ) -> MessageCommand[_CommandCallbackSigT]: + ) -> MessageCommand[_CallbackSigT]: if isinstance(callback, (tanjun.MenuCommand, tanjun.MessageCommand, tanjun.SlashCommand)): return MessageCommand(callback.callback, name, *names, _wrapped_command=callback) @@ -109,14 +110,14 @@ def decorator( class _GroupResultProto(typing.Protocol): @typing.overload - def __call__(self, _: _CommandT[_CommandCallbackSigT], /) -> MessageCommandGroup[_CommandCallbackSigT]: + def __call__(self, _: _CommandT[_CallbackSigT], /) -> MessageCommandGroup[_CallbackSigT]: ... @typing.overload - def __call__(self, _: _CommandCallbackSigT, /) -> MessageCommandGroup[_CommandCallbackSigT]: + def __call__(self, _: _CallbackSigT, /) -> MessageCommandGroup[_CallbackSigT]: ... - def __call__(self, _: _CallbackishT[_CommandCallbackSigT], /) -> MessageCommandGroup[_CommandCallbackSigT]: + def __call__(self, _: _CallbackishT[_CallbackSigT], /) -> MessageCommandGroup[_CallbackSigT]: raise NotImplementedError @@ -146,7 +147,7 @@ def as_message_command_group(name: str, /, *names: str, strict: bool = False) -> [tanjun.Component.load_from_scope][]. """ - def decorator(callback: _CallbackishT[_CommandCallbackSigT], /) -> MessageCommandGroup[_CommandCallbackSigT]: + def decorator(callback: _CallbackishT[_CallbackSigT], /) -> MessageCommandGroup[_CallbackSigT]: if isinstance(callback, (tanjun.MenuCommand, tanjun.MessageCommand, tanjun.SlashCommand)): return MessageCommandGroup(callback.callback, name, *names, strict=strict, _wrapped_command=callback) @@ -155,7 +156,7 @@ def decorator(callback: _CallbackishT[_CommandCallbackSigT], /) -> MessageComman return decorator -class MessageCommand(base.PartialCommand[tanjun.MessageContext], tanjun.MessageCommand[_CommandCallbackSigT]): +class MessageCommand(base.PartialCommand[tanjun.MessageContext], tanjun.MessageCommand[_CallbackSigT]): """Standard implementation of a message command.""" __slots__ = ("_callback", "_names", "_parent", "_parser", "_wrapped_command") @@ -163,7 +164,7 @@ class MessageCommand(base.PartialCommand[tanjun.MessageContext], tanjun.MessageC @typing.overload def __init__( self, - callback: _CommandT[_CommandCallbackSigT], + callback: _CommandT[_CallbackSigT], name: str, /, *names: str, @@ -174,7 +175,7 @@ def __init__( @typing.overload def __init__( self, - callback: _CommandCallbackSigT, + callback: _CallbackSigT, name: str, /, *names: str, @@ -184,7 +185,7 @@ def __init__( def __init__( self, - callback: _CallbackishT[_CommandCallbackSigT], + callback: _CallbackishT[_CallbackSigT], name: str, /, *names: str, @@ -209,7 +210,7 @@ def __init__( if isinstance(callback, (tanjun.MenuCommand, tanjun.MessageCommand, tanjun.SlashCommand)): callback = callback.callback - self._callback: _CommandCallbackSigT = callback + self._callback: _CallbackSigT = callback self._names = list(dict.fromkeys((name, *names))) self._parent: typing.Optional[tanjun.MessageCommandGroup[typing.Any]] = None self._parser: typing.Optional[tanjun.MessageParser] = None @@ -219,7 +220,7 @@ def __repr__(self) -> str: return f"Command <{self._names}>" if typing.TYPE_CHECKING: - __call__: _CommandCallbackSigT + __call__: _CallbackSigT else: @@ -227,7 +228,7 @@ async def __call__(self, *args, **kwargs) -> None: await self._callback(*args, **kwargs) @property - def callback(self) -> _CommandCallbackSigT: + def callback(self) -> _CallbackSigT: # <>. return self._callback @@ -339,7 +340,7 @@ def load_into_component(self, component: tanjun.Component, /) -> None: self._wrapped_command.load_into_component(component) -class MessageCommandGroup(MessageCommand[_CommandCallbackSigT], tanjun.MessageCommandGroup[_CommandCallbackSigT]): +class MessageCommandGroup(MessageCommand[_CallbackSigT], tanjun.MessageCommandGroup[_CallbackSigT]): """Standard implementation of a message command group.""" __slots__ = ("_commands", "_is_strict", "_names_to_commands") @@ -347,7 +348,7 @@ class MessageCommandGroup(MessageCommand[_CommandCallbackSigT], tanjun.MessageCo @typing.overload def __init__( self, - callback: _CommandT[_CommandCallbackSigT], + callback: _CommandT[_CallbackSigT], name: str, /, *names: str, @@ -359,7 +360,7 @@ def __init__( @typing.overload def __init__( self, - callback: _CommandCallbackSigT, + callback: _CallbackSigT, name: str, /, *names: str, @@ -370,7 +371,7 @@ def __init__( def __init__( self, - callback: _CallbackishT[_CommandCallbackSigT], + callback: _CallbackishT[_CallbackSigT], name: str, /, *names: str, diff --git a/tanjun/components.py b/tanjun/components.py index a499aef61..ab160c0b0 100644 --- a/tanjun/components.py +++ b/tanjun/components.py @@ -55,7 +55,7 @@ _AppCommandContextT = typing.TypeVar("_AppCommandContextT", bound=tanjun.AppCommandContext) _BaseSlashCommandT = typing.TypeVar("_BaseSlashCommandT", bound=tanjun.BaseSlashCommand) - _CheckSigT = typing.TypeVar("_CheckSigT", bound=tanjun.CheckSig) + _CheckSigT = typing.TypeVar("_CheckSigT", bound=tanjun.AnyCheckSig) _ComponentT = typing.TypeVar("_ComponentT", bound="Component") _ListenerCallbackSigT = typing.TypeVar("_ListenerCallbackSigT", bound=tanjun.ListenerCallbackSig) _MenuCommandT = typing.TypeVar("_MenuCommandT", bound=tanjun.MenuCommand[typing.Any, typing.Any]) @@ -197,7 +197,7 @@ def __init__(self, *, name: typing.Optional[str] = None, strict: bool = False) - When this is [True][], message command names will not be allowed to contain spaces and will have to be unique to one command within the component. """ - self._checks: list[tanjun.CheckSig] = [] + self._checks: list[tanjun.AnyCheckSig] = [] self._client: typing.Optional[tanjun.Client] = None self._client_callbacks: dict[str, list[tanjun.MetaEventSig]] = {} self._defaults_to_ephemeral: typing.Optional[bool] = None @@ -223,7 +223,7 @@ def __repr__(self) -> str: return f"{type(self).__name__}({self.checks=}, {self.hooks=}, {self.slash_hooks=}, {self.message_hooks=})" @property - def checks(self) -> collections.Collection[tanjun.CheckSig]: + def checks(self) -> collections.Collection[tanjun.AnyCheckSig]: """Collection of the checks being run against every command execution in this component.""" return self._checks.copy() @@ -498,7 +498,7 @@ def set_slash_hooks(self: _ComponentT, hooks: typing.Optional[tanjun.SlashHooks] self._slash_hooks = hooks return self - def add_check(self: _ComponentT, check: tanjun.CheckSig, /) -> _ComponentT: + def add_check(self: _ComponentT, check: tanjun.AnyCheckSig, /) -> _ComponentT: """Add a command check to this component to be used for all its commands. Parameters @@ -516,7 +516,7 @@ def add_check(self: _ComponentT, check: tanjun.CheckSig, /) -> _ComponentT: return self - def remove_check(self: _ComponentT, check: tanjun.CheckSig, /) -> _ComponentT: + def remove_check(self: _ComponentT, check: tanjun.AnyCheckSig, /) -> _ComponentT: """Remove a command check from this component. Parameters diff --git a/tanjun/utilities.py b/tanjun/utilities.py index f1f8f0435..08b6c15fe 100644 --- a/tanjun/utilities.py +++ b/tanjun/utilities.py @@ -56,12 +56,14 @@ if typing.TYPE_CHECKING: from . import abc + _ContextT = typing.TypeVar("_ContextT", bound=abc.Context) + _KeyT = typing.TypeVar("_KeyT") _OtherValueT = typing.TypeVar("_OtherValueT") _ValueT = typing.TypeVar("_ValueT") -async def _execute_check(ctx: abc.Context, callback: abc.CheckSig, /) -> bool: +async def _execute_check(ctx: _ContextT, callback: abc.CheckSig[_ContextT], /) -> bool: foo = ctx.call_with_async_di(callback, ctx) if result := await foo: return result @@ -69,14 +71,14 @@ async def _execute_check(ctx: abc.Context, callback: abc.CheckSig, /) -> bool: raise errors.FailedCheck -async def gather_checks(ctx: abc.Context, checks: collections.Iterable[abc.CheckSig], /) -> bool: +async def gather_checks(ctx: _ContextT, checks: collections.Iterable[abc.CheckSig[_ContextT]], /) -> bool: """Gather a collection of checks. Parameters ---------- - ctx + ctx : tanjun.abc.Context The context to check. - checks + checks : tanjun.abc.CheckSig An iterable of injectable checks. Returns