diff --git a/tanjun/abc.py b/tanjun/abc.py index 76432fded..1176b9ac0 100644 --- a/tanjun/abc.py +++ b/tanjun/abc.py @@ -2208,7 +2208,7 @@ async def on_error(ctx: tanjun.abc.Context, error: Exception) -> bool: Parameters ---------- - callback : ErrorHookSig + callback : tanjun.abc.ErrorHookSig The callback to add to this hook. This callback should take two positional arguments (of type @@ -2284,7 +2284,7 @@ async def on_parser_error(ctx: tanjun.abc.Context, error: tanjun.ParserError) -> Parameters ---------- - callback : HookSig + callback : tanjun.abc.ParserHookSig The parser error callback to add to this hook. This callback should take two positional arguments (of type @@ -2352,7 +2352,7 @@ async def post_execution(ctx: tanjun.abc.Context) -> None: Parameters ---------- - callback : HookSig + callback : tanjun.abc.HookSig The post-execution callback to add to this hook. This callback should take one positional argument (of type @@ -2420,7 +2420,7 @@ async def pre_execution(ctx: tanjun.abc.Context) -> None: Parameters ---------- - callback : HookSig + callback : tanjun.abc.HookSig The pre-execution callback to add to this hook. This callback should take one positional argument (of type @@ -2488,7 +2488,7 @@ async def on_success(ctx: tanjun.abc.Context) -> None: Parameters ---------- - callback : HookSig + callback : tanjun.abc.HookSig The success callback to add to this hook. This callback should take one positional argument (of type diff --git a/tanjun/checks.py b/tanjun/checks.py index 4c67684ff..b47d2afa1 100644 --- a/tanjun/checks.py +++ b/tanjun/checks.py @@ -1046,9 +1046,51 @@ def all_checks( return _AllChecks[_ContextT]([check, *checks]) +@typing.overload def with_all_checks( check: tanjun.AnyCheckSig, /, *checks: tanjun.AnyCheckSig, follow_wrapped: bool = False -) -> collections.Callable[[_CommandT], _CommandT]: # TODO: specialise with overloading +) -> collections.Callable[[_CommandT], _CommandT]: + ... + + +@typing.overload +def with_all_checks( + check: tanjun.CheckSig[tanjun.MenuContext], + /, + *checks: tanjun.CheckSig[tanjun.MenuContext], + follow_wrapped: bool = False, +) -> collections.Callable[[_MenuCommandT], _MenuCommandT]: + ... + + +@typing.overload +def with_all_checks( + check: tanjun.CheckSig[tanjun.MessageContext], + /, + *checks: tanjun.CheckSig[tanjun.MessageContext], + follow_wrapped: bool = False, +) -> collections.Callable[[_MessageCommandT], _MessageCommandT]: + ... + + +@typing.overload +def with_all_checks( + check: tanjun.CheckSig[tanjun.SlashContext], + /, + *checks: tanjun.CheckSig[tanjun.SlashContext], + follow_wrapped: bool = False, +) -> collections.Callable[[_SlashCommandT], _SlashCommandT]: + ... + + +def with_all_checks( + check: tanjun.CheckSig[typing.Any], /, *checks: tanjun.CheckSig[typing.Any], follow_wrapped: bool = False +) -> ( + collections.Callable[[_CommandT], _CommandT] + | collections.Callable[[_MenuCommandT], _MenuCommandT] + | collections.Callable[[_MessageCommandT], _MessageCommandT] + | collections.Callable[[_SlashCommandT], _SlashCommandT] +): """Add a check which will pass if all the provided checks pass through a decorator call. This ensures that the callbacks are run in the order they were supplied in @@ -1149,7 +1191,7 @@ def any_checks( return _AnyChecks[_ContextT]([check, *checks], error, error_message, halt_execution, suppress) -# Specialising this would be too much boiler plate so for now this just isn't getting that feature. +@typing.overload def with_any_checks( check: tanjun.AnyCheckSig, /, @@ -1160,6 +1202,66 @@ def with_any_checks( halt_execution: bool = False, suppress: tuple[type[Exception], ...] = (errors.CommandError, errors.HaltExecution), ) -> collections.Callable[[_CommandT], _CommandT]: + ... + + +@typing.overload +def with_any_checks( + check: tanjun.CheckSig[tanjun.MenuContext], + /, + *checks: tanjun.CheckSig[tanjun.MenuContext], + error: collections.Callable[[], Exception] | None = None, + error_message: str | collections.Mapping[str, str] | None, + follow_wrapped: bool = False, + halt_execution: bool = False, + suppress: tuple[type[Exception], ...] = (errors.CommandError, errors.HaltExecution), +) -> collections.Callable[[_MenuCommandT], _MenuCommandT]: + ... + + +@typing.overload +def with_any_checks( + check: tanjun.CheckSig[tanjun.MessageContext], + /, + *checks: tanjun.CheckSig[tanjun.MessageContext], + error: collections.Callable[[], Exception] | None = None, + error_message: str | collections.Mapping[str, str] | None, + follow_wrapped: bool = False, + halt_execution: bool = False, + suppress: tuple[type[Exception], ...] = (errors.CommandError, errors.HaltExecution), +) -> collections.Callable[[_MessageCommandT], _MessageCommandT]: + ... + + +@typing.overload +def with_any_checks( + check: tanjun.CheckSig[tanjun.SlashContext], + /, + *checks: tanjun.CheckSig[tanjun.SlashContext], + error: collections.Callable[[], Exception] | None = None, + error_message: str | collections.Mapping[str, str] | None, + follow_wrapped: bool = False, + halt_execution: bool = False, + suppress: tuple[type[Exception], ...] = (errors.CommandError, errors.HaltExecution), +) -> collections.Callable[[_SlashCommandT], _SlashCommandT]: + ... + + +def with_any_checks( + check: tanjun.CheckSig[typing.Any], + /, + *checks: tanjun.CheckSig[typing.Any], + error: collections.Callable[[], Exception] | None = None, + error_message: str | collections.Mapping[str, str] | None, + follow_wrapped: bool = False, + halt_execution: bool = False, + suppress: tuple[type[Exception], ...] = (errors.CommandError, errors.HaltExecution), +) -> ( + collections.Callable[[_CommandT], _CommandT] + | collections.Callable[[_MenuCommandT], _MenuCommandT] + | collections.Callable[[_MessageCommandT], _MessageCommandT] + | collections.Callable[[_SlashCommandT], _SlashCommandT] +): """Add a check which'll pass if any of the provided checks pass through a decorator call. This ensures that the callbacks are run in the order they were supplied in diff --git a/tanjun/clients.py b/tanjun/clients.py index a2f71a10c..96c8efc02 100644 --- a/tanjun/clients.py +++ b/tanjun/clients.py @@ -1097,6 +1097,7 @@ def events(self) -> hikari.api.EventManager | None: def listeners( self, ) -> collections.Mapping[type[hikari.Event], collections.Collection[tanjun.ListenerCallbackSig[typing.Any]]]: + # <>. return _internal.CastedView(self._listeners, lambda x: [callback.callback for callback in x.values()]) @property diff --git a/tanjun/commands/menu.py b/tanjun/commands/menu.py index bc32e023a..14b50238b 100644 --- a/tanjun/commands/menu.py +++ b/tanjun/commands/menu.py @@ -431,7 +431,7 @@ def __init__( Parameters ---------- - callback : collections.abc.Callable[[tanjun.abc.MenuContext, ...], collections.abc.Coroutine[Any, Any, None]] + callback : tanjun.abc.MenuCallbackSig Callback to execute when the command is invoked. This should be an asynchronous callback which takes one positional diff --git a/tanjun/commands/message.py b/tanjun/commands/message.py index 371c3d1af..26d1bf9f6 100644 --- a/tanjun/commands/message.py +++ b/tanjun/commands/message.py @@ -214,7 +214,7 @@ def __init__( Parameters ---------- - callback : collections.abc.Callable[[tanjun.abc.MessageContext, ...], collections.abc.Coroutine[None]] + callback : tanjun.abc.MessageCallbackSig Callback to execute when the command is invoked. This should be an asynchronous callback which takes one positional @@ -414,7 +414,7 @@ def __init__( Parameters ---------- - callback : collections.abc.Callable[[tanjun.abc.MessageContext, ...], collections.abc.Coroutine[None]] + callback : tanjun.abc.MessageCallbackSig Callback to execute when the command is invoked. This should be an asynchronous callback which takes one positional diff --git a/tanjun/commands/slash.py b/tanjun/commands/slash.py index f6c65186c..827c4d8f0 100644 --- a/tanjun/commands/slash.py +++ b/tanjun/commands/slash.py @@ -1508,7 +1508,7 @@ def __init__( Parameters ---------- - callback : collections.abc.Callable[[tanjun.abc.SlashContext, ...], collections.abc.Coroutine[Any, Any, None]] + callback : tanjun.abc.SlashCallbackSig Callback to execute when the command is invoked. This should be an asynchronous callback which takes one positional diff --git a/tests/test_annotations_future_annotations.py b/tests/test_annotations_future_annotations.py index c2eff4df9..f4fa87d7d 100644 --- a/tests/test_annotations_future_annotations.py +++ b/tests/test_annotations_future_annotations.py @@ -662,7 +662,7 @@ async def callback( with pytest.raises( TypeError, match=f"Choice of type {mismatched_type.__name__} is not valid for a {type_repr.__name__} argument" ): - annotations.with_annotated_args(callback) + annotations.with_annotated_args(callback) # pyright: ignore[reportUnknownArgumentType] def test_with_generic_float_choices(): diff --git a/tests/test_clients.py b/tests/test_clients.py index c448c8946..0b725722a 100644 --- a/tests/test_clients.py +++ b/tests/test_clients.py @@ -1233,7 +1233,7 @@ class StubClient(tanjun.Client): client = StubClient(mock.Mock()) with pytest.raises(ValueError, match="Missing event argument annotation"): - client.with_listener()(callback) + client.with_listener()(callback) # pyright: ignore[reportUnknownArgumentType] add_listener_.assert_not_called() diff --git a/tests/test_clients_future_annotations.py b/tests/test_clients_future_annotations.py index 52d12ac5e..407dae84d 100644 --- a/tests/test_clients_future_annotations.py +++ b/tests/test_clients_future_annotations.py @@ -53,7 +53,7 @@ class StubClient(tanjun.Client): client = StubClient(mock.Mock()) with pytest.raises(ValueError, match="Missing event argument annotation"): - client.with_listener()(callback) + client.with_listener()(callback) # pyright: ignore[reportUnknownArgumentType] add_listener_.assert_not_called() diff --git a/tests/test_components.py b/tests/test_components.py index 1c7baf1bd..060b6778d 100644 --- a/tests/test_components.py +++ b/tests/test_components.py @@ -1200,7 +1200,7 @@ async def callback(foo) -> None: # type: ignore )() with pytest.raises(ValueError, match="Missing event argument annotation"): - component.with_listener()(callback) + component.with_listener()(callback) # pyright: ignore[reportUnknownArgumentType] add_listener.assert_not_called() diff --git a/tests/test_components_future_annotations.py b/tests/test_components_future_annotations.py index f10f6ad0f..e6e9562b9 100644 --- a/tests/test_components_future_annotations.py +++ b/tests/test_components_future_annotations.py @@ -52,7 +52,7 @@ async def callback(foo) -> None: # type: ignore )() with pytest.raises(ValueError, match="Missing event argument annotation"): - component.with_listener()(callback) + component.with_listener()(callback) # pyright: ignore[reportUnknownArgumentType] add_listener.assert_not_called()