Skip to content

Commit

Permalink
Some doc and type-checking fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
FasterSpeeding committed Jan 25, 2023
1 parent b6363ca commit 8b85188
Show file tree
Hide file tree
Showing 11 changed files with 119 additions and 16 deletions.
10 changes: 5 additions & 5 deletions tanjun/abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -2208,7 +2208,7 @@ async def on_error(ctx: tanjun.abc.Context, error: Exception) -> bool:
Parameters
----------
callback : ErrorHookSig
callback : tanjun.abc.ErrorHookSig
The callback to add to this hook.
This callback should take two positional arguments (of type
Expand Down Expand Up @@ -2284,7 +2284,7 @@ async def on_parser_error(ctx: tanjun.abc.Context, error: tanjun.ParserError) ->
Parameters
----------
callback : HookSig
callback : tanjun.abc.ParserHookSig
The parser error callback to add to this hook.
This callback should take two positional arguments (of type
Expand Down Expand Up @@ -2352,7 +2352,7 @@ async def post_execution(ctx: tanjun.abc.Context) -> None:
Parameters
----------
callback : HookSig
callback : tanjun.abc.HookSig
The post-execution callback to add to this hook.
This callback should take one positional argument (of type
Expand Down Expand Up @@ -2420,7 +2420,7 @@ async def pre_execution(ctx: tanjun.abc.Context) -> None:
Parameters
----------
callback : HookSig
callback : tanjun.abc.HookSig
The pre-execution callback to add to this hook.
This callback should take one positional argument (of type
Expand Down Expand Up @@ -2488,7 +2488,7 @@ async def on_success(ctx: tanjun.abc.Context) -> None:
Parameters
----------
callback : HookSig
callback : tanjun.abc.HookSig
The success callback to add to this hook.
This callback should take one positional argument (of type
Expand Down
106 changes: 104 additions & 2 deletions tanjun/checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -1046,9 +1046,51 @@ def all_checks(
return _AllChecks[_ContextT]([check, *checks])


@typing.overload
def with_all_checks(
check: tanjun.AnyCheckSig, /, *checks: tanjun.AnyCheckSig, follow_wrapped: bool = False
) -> collections.Callable[[_CommandT], _CommandT]: # TODO: specialise with overloading
) -> collections.Callable[[_CommandT], _CommandT]:
...


@typing.overload
def with_all_checks(
check: tanjun.CheckSig[tanjun.MenuContext],
/,
*checks: tanjun.CheckSig[tanjun.MenuContext],
follow_wrapped: bool = False,
) -> collections.Callable[[_MenuCommandT], _MenuCommandT]:
...


@typing.overload
def with_all_checks(
check: tanjun.CheckSig[tanjun.MessageContext],
/,
*checks: tanjun.CheckSig[tanjun.MessageContext],
follow_wrapped: bool = False,
) -> collections.Callable[[_MessageCommandT], _MessageCommandT]:
...


@typing.overload
def with_all_checks(
check: tanjun.CheckSig[tanjun.SlashContext],
/,
*checks: tanjun.CheckSig[tanjun.SlashContext],
follow_wrapped: bool = False,
) -> collections.Callable[[_SlashCommandT], _SlashCommandT]:
...


def with_all_checks(
check: tanjun.CheckSig[typing.Any], /, *checks: tanjun.CheckSig[typing.Any], follow_wrapped: bool = False
) -> (
collections.Callable[[_CommandT], _CommandT]
| collections.Callable[[_MenuCommandT], _MenuCommandT]
| collections.Callable[[_MessageCommandT], _MessageCommandT]
| collections.Callable[[_SlashCommandT], _SlashCommandT]
):
"""Add a check which will pass if all the provided checks pass through a decorator call.
This ensures that the callbacks are run in the order they were supplied in
Expand Down Expand Up @@ -1149,7 +1191,7 @@ def any_checks(
return _AnyChecks[_ContextT]([check, *checks], error, error_message, halt_execution, suppress)


# Specialising this would be too much boiler plate so for now this just isn't getting that feature.
@typing.overload
def with_any_checks(
check: tanjun.AnyCheckSig,
/,
Expand All @@ -1160,6 +1202,66 @@ def with_any_checks(
halt_execution: bool = False,
suppress: tuple[type[Exception], ...] = (errors.CommandError, errors.HaltExecution),
) -> collections.Callable[[_CommandT], _CommandT]:
...


@typing.overload
def with_any_checks(
check: tanjun.CheckSig[tanjun.MenuContext],
/,
*checks: tanjun.CheckSig[tanjun.MenuContext],
error: collections.Callable[[], Exception] | None = None,
error_message: str | collections.Mapping[str, str] | None,
follow_wrapped: bool = False,
halt_execution: bool = False,
suppress: tuple[type[Exception], ...] = (errors.CommandError, errors.HaltExecution),
) -> collections.Callable[[_MenuCommandT], _MenuCommandT]:
...


@typing.overload
def with_any_checks(
check: tanjun.CheckSig[tanjun.MessageContext],
/,
*checks: tanjun.CheckSig[tanjun.MessageContext],
error: collections.Callable[[], Exception] | None = None,
error_message: str | collections.Mapping[str, str] | None,
follow_wrapped: bool = False,
halt_execution: bool = False,
suppress: tuple[type[Exception], ...] = (errors.CommandError, errors.HaltExecution),
) -> collections.Callable[[_MessageCommandT], _MessageCommandT]:
...


@typing.overload
def with_any_checks(
check: tanjun.CheckSig[tanjun.SlashContext],
/,
*checks: tanjun.CheckSig[tanjun.SlashContext],
error: collections.Callable[[], Exception] | None = None,
error_message: str | collections.Mapping[str, str] | None,
follow_wrapped: bool = False,
halt_execution: bool = False,
suppress: tuple[type[Exception], ...] = (errors.CommandError, errors.HaltExecution),
) -> collections.Callable[[_SlashCommandT], _SlashCommandT]:
...


def with_any_checks(
check: tanjun.CheckSig[typing.Any],
/,
*checks: tanjun.CheckSig[typing.Any],
error: collections.Callable[[], Exception] | None = None,
error_message: str | collections.Mapping[str, str] | None,
follow_wrapped: bool = False,
halt_execution: bool = False,
suppress: tuple[type[Exception], ...] = (errors.CommandError, errors.HaltExecution),
) -> (
collections.Callable[[_CommandT], _CommandT]
| collections.Callable[[_MenuCommandT], _MenuCommandT]
| collections.Callable[[_MessageCommandT], _MessageCommandT]
| collections.Callable[[_SlashCommandT], _SlashCommandT]
):
"""Add a check which'll pass if any of the provided checks pass through a decorator call.
This ensures that the callbacks are run in the order they were supplied in
Expand Down
1 change: 1 addition & 0 deletions tanjun/clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -1097,6 +1097,7 @@ def events(self) -> hikari.api.EventManager | None:
def listeners(
self,
) -> collections.Mapping[type[hikari.Event], collections.Collection[tanjun.ListenerCallbackSig[typing.Any]]]:
# <<inherited docstring from tanjun.abc.Client>>.
return _internal.CastedView(self._listeners, lambda x: [callback.callback for callback in x.values()])

@property
Expand Down
2 changes: 1 addition & 1 deletion tanjun/commands/menu.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,7 +431,7 @@ def __init__(
Parameters
----------
callback : collections.abc.Callable[[tanjun.abc.MenuContext, ...], collections.abc.Coroutine[Any, Any, None]]
callback : tanjun.abc.MenuCallbackSig
Callback to execute when the command is invoked.
This should be an asynchronous callback which takes one positional
Expand Down
4 changes: 2 additions & 2 deletions tanjun/commands/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ def __init__(
Parameters
----------
callback : collections.abc.Callable[[tanjun.abc.MessageContext, ...], collections.abc.Coroutine[None]]
callback : tanjun.abc.MessageCallbackSig
Callback to execute when the command is invoked.
This should be an asynchronous callback which takes one positional
Expand Down Expand Up @@ -414,7 +414,7 @@ def __init__(
Parameters
----------
callback : collections.abc.Callable[[tanjun.abc.MessageContext, ...], collections.abc.Coroutine[None]]
callback : tanjun.abc.MessageCallbackSig
Callback to execute when the command is invoked.
This should be an asynchronous callback which takes one positional
Expand Down
2 changes: 1 addition & 1 deletion tanjun/commands/slash.py
Original file line number Diff line number Diff line change
Expand Up @@ -1508,7 +1508,7 @@ def __init__(
Parameters
----------
callback : collections.abc.Callable[[tanjun.abc.SlashContext, ...], collections.abc.Coroutine[Any, Any, None]]
callback : tanjun.abc.SlashCallbackSig
Callback to execute when the command is invoked.
This should be an asynchronous callback which takes one positional
Expand Down
2 changes: 1 addition & 1 deletion tests/test_annotations_future_annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -662,7 +662,7 @@ async def callback(
with pytest.raises(
TypeError, match=f"Choice of type {mismatched_type.__name__} is not valid for a {type_repr.__name__} argument"
):
annotations.with_annotated_args(callback)
annotations.with_annotated_args(callback) # pyright: ignore[reportUnknownArgumentType]


def test_with_generic_float_choices():
Expand Down
2 changes: 1 addition & 1 deletion tests/test_clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -1233,7 +1233,7 @@ class StubClient(tanjun.Client):
client = StubClient(mock.Mock())

with pytest.raises(ValueError, match="Missing event argument annotation"):
client.with_listener()(callback)
client.with_listener()(callback) # pyright: ignore[reportUnknownArgumentType]

add_listener_.assert_not_called()

Expand Down
2 changes: 1 addition & 1 deletion tests/test_clients_future_annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ class StubClient(tanjun.Client):
client = StubClient(mock.Mock())

with pytest.raises(ValueError, match="Missing event argument annotation"):
client.with_listener()(callback)
client.with_listener()(callback) # pyright: ignore[reportUnknownArgumentType]

add_listener_.assert_not_called()

Expand Down
2 changes: 1 addition & 1 deletion tests/test_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -1200,7 +1200,7 @@ async def callback(foo) -> None: # type: ignore
)()

with pytest.raises(ValueError, match="Missing event argument annotation"):
component.with_listener()(callback)
component.with_listener()(callback) # pyright: ignore[reportUnknownArgumentType]

add_listener.assert_not_called()

Expand Down
2 changes: 1 addition & 1 deletion tests/test_components_future_annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ async def callback(foo) -> None: # type: ignore
)()

with pytest.raises(ValueError, match="Missing event argument annotation"):
component.with_listener()(callback)
component.with_listener()(callback) # pyright: ignore[reportUnknownArgumentType]

add_listener.assert_not_called()

Expand Down

0 comments on commit 8b85188

Please sign in to comment.