-
-
Notifications
You must be signed in to change notification settings - Fork 13
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Message and slash checks for components #154
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -132,6 +132,10 @@ class Component(abc.Component): | |
---------- | ||
checks : typing.Optional[collections.abc.Iterable[abc.CheckSig]] | ||
Iterable of check callbacks to set for this component, if provided. | ||
slash_checks : typing.Optional[collections.abc.Iterable[tanjun.abc.CheckSig]] | ||
The slash check callbacks to set for the component, if provided. | ||
message_checks : typing.Optional[collections.abc.Iterable[tanjun.abc.CheckSig]] | ||
The message check callbacks to set for the component, if provided. | ||
hooks : typing.Optional[tanjun.abc.AnyHooks] | ||
The hooks this component should add to the execution of all its | ||
commands (message and slash). | ||
|
@@ -162,18 +166,22 @@ class Component(abc.Component): | |
"_is_strict", | ||
"_listeners", | ||
"_message_commands", | ||
"_message_checks", | ||
"_message_hooks", | ||
"_metadata", | ||
"_name", | ||
"_names_to_commands", | ||
"_slash_commands", | ||
"_slash_checks", | ||
"_slash_hooks", | ||
) | ||
|
||
def __init__( | ||
self, | ||
*, | ||
checks: typing.Optional[collections.Iterable[abc.CheckSig]] = None, | ||
slash_checks: typing.Optional[collections.Iterable[abc.CheckSig]] = None, | ||
message_checks: typing.Optional[collections.Iterable[abc.CheckSig]] = None, | ||
hooks: typing.Optional[abc.AnyHooks] = None, | ||
slash_hooks: typing.Optional[abc.SlashHooks] = None, | ||
message_hooks: typing.Optional[abc.MessageHooks] = None, | ||
|
@@ -191,11 +199,17 @@ def __init__( | |
self._is_strict = strict | ||
self._listeners: dict[type[base_events.Event], list[abc.ListenerCallbackSig]] = {} | ||
self._message_commands: list[abc.MessageCommand] = [] | ||
self._message_checks: list[checks_.InjectableCheck] = ( | ||
[checks_.InjectableCheck(check) for check in dict.fromkeys(slash_checks)] if slash_checks else [] | ||
) | ||
self._message_hooks = message_hooks | ||
self._metadata: dict[typing.Any, typing.Any] = {} | ||
self._name = name or base64.b64encode(random.randbytes(32)).decode() | ||
self._names_to_commands: dict[str, abc.MessageCommand] = {} | ||
self._slash_commands: dict[str, abc.BaseSlashCommand] = {} | ||
self._slash_checks: list[checks_.InjectableCheck] = ( | ||
[checks_.InjectableCheck(check) for check in dict.fromkeys(message_checks)] if message_checks else [] | ||
) | ||
self._slash_hooks = slash_hooks | ||
|
||
if load_from_attributes and type(self) is not Component: # No need to run this on the base class. | ||
|
@@ -228,6 +242,10 @@ def name(self) -> str: | |
def slash_commands(self) -> collections.Collection[abc.BaseSlashCommand]: | ||
return self._slash_commands.copy().values() | ||
|
||
@property | ||
def slash_checks(self) -> collections.Collection[abc.CheckSig]: | ||
return tuple(check.callback for check in self._slash_checks) | ||
|
||
@property | ||
def slash_hooks(self) -> typing.Optional[abc.SlashHooks]: | ||
return self._slash_hooks | ||
|
@@ -236,13 +254,17 @@ def slash_hooks(self) -> typing.Optional[abc.SlashHooks]: | |
def message_commands(self) -> collections.Collection[abc.MessageCommand]: | ||
return self._message_commands.copy() | ||
|
||
@property | ||
def message_checks(self) -> collections.Collection[abc.CheckSig]: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could we get a quick 1 line docstring here explaining it |
||
return tuple(check.callback for check in self._message_checks) | ||
|
||
@property | ||
def message_hooks(self) -> typing.Optional[abc.MessageHooks]: | ||
return self._message_hooks | ||
|
||
@property | ||
def needs_injector(self) -> bool: | ||
return any(check.needs_injector for check in self._checks) | ||
return any(check.needs_injector for check in self._checks + self._message_checks + self._slash_checks) | ||
|
||
@property | ||
def listeners( | ||
|
@@ -258,12 +280,14 @@ def copy(self: _ComponentT, *, _new: bool = True) -> _ComponentT: | |
if not _new: | ||
self._checks = [check.copy() for check in self._checks] | ||
self._slash_commands = {name: command.copy() for name, command in self._slash_commands.items()} | ||
self._slash_checks = [check.copy() for check in self._slash_checks] | ||
self._hooks = self._hooks.copy() if self._hooks else None | ||
self._listeners = { | ||
event: [copy.copy(listener) for listener in listeners] for event, listeners in self._listeners.items() | ||
} | ||
commands = {command: command.copy() for command in self._message_commands} | ||
self._message_commands = list(commands.values()) | ||
self._message_checks = [check.copy() for check in self._message_checks] | ||
self._metadata = self._metadata.copy() | ||
self._names_to_commands = {name: commands[command] for name, command in self._names_to_commands.items()} | ||
return self | ||
|
@@ -317,6 +341,34 @@ def with_check(self, check: abc.CheckSigT, /) -> abc.CheckSigT: | |
self.add_check(check) | ||
return check | ||
|
||
def add_slash_check(self: _ComponentT, check: abc.CheckSig, /) -> _ComponentT: | ||
if check not in self._slash_checks: | ||
self._slash_checks.append(checks_.InjectableCheck(check)) | ||
|
||
return self | ||
|
||
def remove_slash_check(self: _ComponentT, check: abc.CheckSig, /) -> _ComponentT: | ||
self._slash_checks.remove(typing.cast("checks_.InjectableCheck", check)) | ||
return self | ||
|
||
def with_slash_check(self, check: abc.CheckSigT, /) -> abc.CheckSigT: | ||
self.add_slash_check(check) | ||
return check | ||
|
||
def add_message_check(self: _ComponentT, check: abc.CheckSig, /) -> _ComponentT: | ||
if check not in self._message_checks: | ||
self._message_checks.append(checks_.InjectableCheck(check)) | ||
|
||
return self | ||
|
||
def remove_message_check(self: _ComponentT, check: abc.CheckSig, /) -> _ComponentT: | ||
self._message_checks.remove(typing.cast("checks_.InjectableCheck", check)) | ||
return self | ||
|
||
def with_message_check(self, check: abc.CheckSigT, /) -> abc.CheckSigT: | ||
self.add_message_check(check) | ||
return check | ||
|
||
def add_client_callback(self: _ComponentT, event_name: str, callback: abc.MetaEventSig, /) -> _ComponentT: | ||
event_name = event_name.lower() | ||
try: | ||
|
@@ -576,7 +628,11 @@ def unbind_client(self, client: abc.Client, /) -> None: | |
self._client = None | ||
|
||
async def _check_context(self, ctx: abc.Context, /) -> bool: | ||
return await utilities.gather_checks(ctx, self._checks) | ||
if abc.SlashContext in type(ctx).mro(): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Rather than instance check here, we could just specifically call |
||
additional_checks = self._slash_checks | ||
else: | ||
additional_checks = self._message_checks | ||
return await utilities.gather_checks(ctx, self._checks + additional_checks) | ||
|
||
async def _check_message_context( | ||
self, ctx: abc.MessageContext, / | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could we get a quick 1 line docstring here explaining it