Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Message and slash checks for components #154

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 58 additions & 2 deletions tanjun/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,10 @@ class Component(abc.Component):
----------
checks : typing.Optional[collections.abc.Iterable[abc.CheckSig]]
Iterable of check callbacks to set for this component, if provided.
slash_checks : typing.Optional[collections.abc.Iterable[tanjun.abc.CheckSig]]
The slash check callbacks to set for the component, if provided.
message_checks : typing.Optional[collections.abc.Iterable[tanjun.abc.CheckSig]]
The message check callbacks to set for the component, if provided.
hooks : typing.Optional[tanjun.abc.AnyHooks]
The hooks this component should add to the execution of all its
commands (message and slash).
Expand Down Expand Up @@ -162,18 +166,22 @@ class Component(abc.Component):
"_is_strict",
"_listeners",
"_message_commands",
"_message_checks",
"_message_hooks",
"_metadata",
"_name",
"_names_to_commands",
"_slash_commands",
"_slash_checks",
"_slash_hooks",
)

def __init__(
self,
*,
checks: typing.Optional[collections.Iterable[abc.CheckSig]] = None,
slash_checks: typing.Optional[collections.Iterable[abc.CheckSig]] = None,
message_checks: typing.Optional[collections.Iterable[abc.CheckSig]] = None,
hooks: typing.Optional[abc.AnyHooks] = None,
slash_hooks: typing.Optional[abc.SlashHooks] = None,
message_hooks: typing.Optional[abc.MessageHooks] = None,
Expand All @@ -191,11 +199,17 @@ def __init__(
self._is_strict = strict
self._listeners: dict[type[base_events.Event], list[abc.ListenerCallbackSig]] = {}
self._message_commands: list[abc.MessageCommand] = []
self._message_checks: list[checks_.InjectableCheck] = (
[checks_.InjectableCheck(check) for check in dict.fromkeys(slash_checks)] if slash_checks else []
)
self._message_hooks = message_hooks
self._metadata: dict[typing.Any, typing.Any] = {}
self._name = name or base64.b64encode(random.randbytes(32)).decode()
self._names_to_commands: dict[str, abc.MessageCommand] = {}
self._slash_commands: dict[str, abc.BaseSlashCommand] = {}
self._slash_checks: list[checks_.InjectableCheck] = (
[checks_.InjectableCheck(check) for check in dict.fromkeys(message_checks)] if message_checks else []
)
self._slash_hooks = slash_hooks

if load_from_attributes and type(self) is not Component: # No need to run this on the base class.
Expand Down Expand Up @@ -228,6 +242,10 @@ def name(self) -> str:
def slash_commands(self) -> collections.Collection[abc.BaseSlashCommand]:
return self._slash_commands.copy().values()

@property
def slash_checks(self) -> collections.Collection[abc.CheckSig]:
return tuple(check.callback for check in self._slash_checks)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we get a quick 1 line docstring here explaining it


@property
def slash_hooks(self) -> typing.Optional[abc.SlashHooks]:
return self._slash_hooks
Expand All @@ -236,13 +254,17 @@ def slash_hooks(self) -> typing.Optional[abc.SlashHooks]:
def message_commands(self) -> collections.Collection[abc.MessageCommand]:
return self._message_commands.copy()

@property
def message_checks(self) -> collections.Collection[abc.CheckSig]:
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we get a quick 1 line docstring here explaining it

return tuple(check.callback for check in self._message_checks)

@property
def message_hooks(self) -> typing.Optional[abc.MessageHooks]:
return self._message_hooks

@property
def needs_injector(self) -> bool:
return any(check.needs_injector for check in self._checks)
return any(check.needs_injector for check in self._checks + self._message_checks + self._slash_checks)

@property
def listeners(
Expand All @@ -258,12 +280,14 @@ def copy(self: _ComponentT, *, _new: bool = True) -> _ComponentT:
if not _new:
self._checks = [check.copy() for check in self._checks]
self._slash_commands = {name: command.copy() for name, command in self._slash_commands.items()}
self._slash_checks = [check.copy() for check in self._slash_checks]
self._hooks = self._hooks.copy() if self._hooks else None
self._listeners = {
event: [copy.copy(listener) for listener in listeners] for event, listeners in self._listeners.items()
}
commands = {command: command.copy() for command in self._message_commands}
self._message_commands = list(commands.values())
self._message_checks = [check.copy() for check in self._message_checks]
self._metadata = self._metadata.copy()
self._names_to_commands = {name: commands[command] for name, command in self._names_to_commands.items()}
return self
Expand Down Expand Up @@ -317,6 +341,34 @@ def with_check(self, check: abc.CheckSigT, /) -> abc.CheckSigT:
self.add_check(check)
return check

def add_slash_check(self: _ComponentT, check: abc.CheckSig, /) -> _ComponentT:
if check not in self._slash_checks:
self._slash_checks.append(checks_.InjectableCheck(check))

return self

def remove_slash_check(self: _ComponentT, check: abc.CheckSig, /) -> _ComponentT:
self._slash_checks.remove(typing.cast("checks_.InjectableCheck", check))
return self

def with_slash_check(self, check: abc.CheckSigT, /) -> abc.CheckSigT:
self.add_slash_check(check)
return check

def add_message_check(self: _ComponentT, check: abc.CheckSig, /) -> _ComponentT:
if check not in self._message_checks:
self._message_checks.append(checks_.InjectableCheck(check))

return self

def remove_message_check(self: _ComponentT, check: abc.CheckSig, /) -> _ComponentT:
self._message_checks.remove(typing.cast("checks_.InjectableCheck", check))
return self

def with_message_check(self, check: abc.CheckSigT, /) -> abc.CheckSigT:
self.add_message_check(check)
return check

def add_client_callback(self: _ComponentT, event_name: str, callback: abc.MetaEventSig, /) -> _ComponentT:
event_name = event_name.lower()
try:
Expand Down Expand Up @@ -576,7 +628,11 @@ def unbind_client(self, client: abc.Client, /) -> None:
self._client = None

async def _check_context(self, ctx: abc.Context, /) -> bool:
return await utilities.gather_checks(ctx, self._checks)
if abc.SlashContext in type(ctx).mro():
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rather than instance check here, we could just specifically call utilitiys.gather_checks(ctx, self._slash_checks/self._message_checks) around line 645 and 656

additional_checks = self._slash_checks
else:
additional_checks = self._message_checks
return await utilities.gather_checks(ctx, self._checks + additional_checks)

async def _check_message_context(
self, ctx: abc.MessageContext, /
Expand Down
80 changes: 80 additions & 0 deletions tests/test_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,86 @@ def test_with_check(self):

assert result is mock_check

def test_add_slash_check(self):
mock_check = mock.Mock()
component = tanjun.Component()

result = component.add_slash_check(mock_check)

assert result is component

def test_add_slash_check_when_already_present(self):
mock_check = mock.Mock()
component = tanjun.Component().add_slash_check(mock_check)

with mock.patch.object(tanjun.checks, "InjectableCheck") as InjectableCheck:
result = component.add_slash_check(mock_check)

InjectableCheck.assert_not_called()

assert list(component.slash_checks).count(mock_check) == 1
assert result is component

def test_remove_slash_check(self):
component = tanjun.Component().add_slash_check(mock.Mock())

result = component.remove_slash_check(next(iter(component.slash_checks)))

assert result is component
assert not component.slash_checks

def test_remove_slash_check_when_not_present(self):
with pytest.raises(ValueError, match=".+"):
tanjun.Component().remove_slash_check(mock.Mock())

def test_with_slash_check(self):
mock_check = mock.Mock()
component = tanjun.Component()

result = component.with_slash_check(mock_check)

assert result is mock_check

def test_add_message_check(self):
mock_check = mock.Mock()
component = tanjun.Component()

result = component.add_message_check(mock_check)

assert result is component

def test_add_message_check_when_already_present(self):
mock_check = mock.Mock()
component = tanjun.Component().add_message_check(mock_check)

with mock.patch.object(tanjun.checks, "InjectableCheck") as InjectableCheck:
result = component.add_message_check(mock_check)

InjectableCheck.assert_not_called()

assert list(component.message_checks).count(mock_check) == 1
assert result is component

def test_remove_message_check(self):
component = tanjun.Component().add_message_check(mock.Mock())

result = component.remove_message_check(next(iter(component.message_checks)))

assert result is component
assert not component.message_checks

def test_remove_message_check_when_not_present(self):
with pytest.raises(ValueError, match=".+"):
tanjun.Component().remove_message_check(mock.Mock())

def test_with_message_check(self):
mock_check = mock.Mock()
component = tanjun.Component()

result = component.with_message_check(mock_check)

assert result is mock_check

def test_add_client_callback(self):
mock_callback = mock.Mock()
mock_other_callback = mock.Mock()
Expand Down