Skip to content

Implemented optional duration parameter in slowmode command #3331

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 99 additions & 15 deletions bot/exts/moderation/slowmode.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
from datetime import UTC, datetime, timedelta
from typing import Literal

from async_rediscache import RedisCache
from dateutil.relativedelta import relativedelta
from discord import TextChannel, Thread
from discord.ext.commands import Cog, Context, group, has_any_role
from pydis_core.utils.scheduling import Scheduler

from bot.bot import Bot
from bot.constants import Channels, Emojis, MODERATION_ROLES
from bot.converters import DurationDelta
from bot.log import get_logger
from bot.utils import time
from bot.utils.time import TimestampFormats, discord_timestamp

log = get_logger(__name__)

Expand All @@ -26,8 +30,15 @@
class Slowmode(Cog):
"""Commands for getting and setting slowmode delays of text channels."""

# Stores the expiration timestamp in POSIX format for active slowmodes, keyed by channel ID.
slowmode_expiration_cache = RedisCache()

# Stores the original slowmode interval by channel ID, allowing its restoration after temporary slowmode expires.
original_slowmode_cache = RedisCache()

def __init__(self, bot: Bot) -> None:
self.bot = bot
self.scheduler = Scheduler(self.__class__.__name__)

@group(name="slowmode", aliases=["sm"], invoke_without_command=True)
async def slowmode_group(self, ctx: Context) -> None:
Expand All @@ -42,17 +53,29 @@ async def get_slowmode(self, ctx: Context, channel: MessageHolder) -> None:
channel = ctx.channel

humanized_delay = time.humanize_delta(seconds=channel.slowmode_delay)

await ctx.send(f"The slowmode delay for {channel.mention} is {humanized_delay}.")
if await self.slowmode_expiration_cache.contains(channel.id):
expiration_time = await self.slowmode_expiration_cache.get(channel.id)
expiration_timestamp = discord_timestamp(expiration_time, TimestampFormats.RELATIVE)
await ctx.send(
f"The slowmode delay for {channel.mention} is {humanized_delay} and expires in {expiration_timestamp}."
)
else:
await ctx.send(f"The slowmode delay for {channel.mention} is {humanized_delay}.")

@slowmode_group.command(name="set", aliases=["s"])
async def set_slowmode(
self,
ctx: Context,
channel: MessageHolder,
delay: DurationDelta | Literal["0s", "0seconds"],
duration: DurationDelta | None = None
) -> None:
"""Set the slowmode delay for a text channel."""
"""
Set the slowmode delay for a text channel.

Supports temporary slowmodes with the `duration` argument that automatically
revert to the original delay after expiration.
"""
# Use the channel this command was invoked in if one was not given
if channel is None:
channel = ctx.channel
Expand All @@ -66,37 +89,98 @@ async def set_slowmode(
humanized_delay = time.humanize_delta(delay)

# Ensure the delay is within discord's limits
if slowmode_delay <= SLOWMODE_MAX_DELAY:
log.info(f"{ctx.author} set the slowmode delay for #{channel} to {humanized_delay}.")

await channel.edit(slowmode_delay=slowmode_delay)
if channel.id in COMMONLY_SLOWMODED_CHANNELS:
log.info(f"Recording slowmode change in stats for {channel.name}.")
self.bot.stats.gauge(f"slowmode.{COMMONLY_SLOWMODED_CHANNELS[channel.id]}", slowmode_delay)
if not slowmode_delay <= SLOWMODE_MAX_DELAY:
log.info(
f"{ctx.author} tried to set the slowmode delay of #{channel} to {humanized_delay}, "
"which is not between 0 and 6 hours."
)

await ctx.send(
f"{Emojis.check_mark} The slowmode delay for {channel.mention} is now {humanized_delay}."
f"{Emojis.cross_mark} The slowmode delay must be between 0 and 6 hours."
)
return

else:
if duration is not None:
slowmode_duration = time.relativedelta_to_timedelta(duration).total_seconds()
humanized_duration = time.humanize_delta(duration)

expiration_time = datetime.now(tz=UTC) + timedelta(seconds=slowmode_duration)
expiration_timestamp = discord_timestamp(expiration_time, TimestampFormats.RELATIVE)

# Only update original_slowmode_cache if the last slowmode was not temporary.
if not await self.slowmode_expiration_cache.contains(channel.id):
await self.original_slowmode_cache.set(channel.id, channel.slowmode_delay)
await self.slowmode_expiration_cache.set(channel.id, expiration_time.timestamp())

self.scheduler.schedule_at(expiration_time, channel.id, self._revert_slowmode(channel.id))
log.info(
f"{ctx.author} tried to set the slowmode delay of #{channel} to {humanized_delay}, "
"which is not between 0 and 6 hours."
f"{ctx.author} set the slowmode delay for #{channel} to"
f"{humanized_delay} which expires in {humanized_duration}."
)
await channel.edit(slowmode_delay=slowmode_delay)
await ctx.send(
f"{Emojis.check_mark} The slowmode delay for {channel.mention}"
f" is now {humanized_delay} and expires in {expiration_timestamp}."
)
else:
if await self.slowmode_expiration_cache.contains(channel.id):
await self.slowmode_expiration_cache.delete(channel.id)
await self.original_slowmode_cache.delete(channel.id)
self.scheduler.cancel(channel.id)

log.info(f"{ctx.author} set the slowmode delay for #{channel} to {humanized_delay}.")
await channel.edit(slowmode_delay=slowmode_delay)
await ctx.send(
f"{Emojis.cross_mark} The slowmode delay must be between 0 and 6 hours."
f"{Emojis.check_mark} The slowmode delay for {channel.mention} is now {humanized_delay}."
)
if channel.id in COMMONLY_SLOWMODED_CHANNELS:
log.info(f"Recording slowmode change in stats for {channel.name}.")
self.bot.stats.gauge(f"slowmode.{COMMONLY_SLOWMODED_CHANNELS[channel.id]}", slowmode_delay)

async def _reschedule(self) -> None:
log.trace("Rescheduling the expiration of temporary slowmodes from cache.")
for channel_id, expiration in await self.slowmode_expiration_cache.items():
expiration_datetime = datetime.fromtimestamp(expiration, tz=UTC)
channel = self.bot.get_channel(channel_id)
log.info(f"Rescheduling slowmode expiration for #{channel} ({channel_id}).")
self.scheduler.schedule_at(expiration_datetime, channel_id, self._revert_slowmode(channel_id))

async def _revert_slowmode(self, channel_id: int) -> None:
original_slowmode = await self.original_slowmode_cache.get(channel_id)
slowmode_delay = time.humanize_delta(seconds=original_slowmode)
channel = self.bot.get_channel(channel_id)
log.info(f"Slowmode in #{channel} ({channel.id}) has expired and has reverted to {slowmode_delay}.")
await channel.edit(slowmode_delay=original_slowmode)
await channel.send(
f"{Emojis.check_mark} A previously applied slowmode has expired and has been reverted to {slowmode_delay}."
)
await self.slowmode_expiration_cache.delete(channel.id)
await self.original_slowmode_cache.delete(channel.id)

@slowmode_group.command(name="reset", aliases=["r"])
async def reset_slowmode(self, ctx: Context, channel: MessageHolder) -> None:
"""Reset the slowmode delay for a text channel to 0 seconds."""
await self.set_slowmode(ctx, channel, relativedelta(seconds=0))
if channel is None:
channel = ctx.channel
if await self.slowmode_expiration_cache.contains(channel.id):
await self.slowmode_expiration_cache.delete(channel.id)
await self.original_slowmode_cache.delete(channel.id)
self.scheduler.cancel(channel.id)

async def cog_check(self, ctx: Context) -> bool:
"""Only allow moderators to invoke the commands in this cog."""
return await has_any_role(*MODERATION_ROLES).predicate(ctx)

async def cog_load(self) -> None:
"""Wait for guild to become available and reschedule slowmodes which should expire."""
await self.bot.wait_until_guild_available()
await self._reschedule()

async def cog_unload(self) -> None:
"""Cancel all scheduled tasks."""
self.scheduler.cancel_all()


async def setup(bot: Bot) -> None:
"""Load the Slowmode cog."""
Expand Down
114 changes: 112 additions & 2 deletions tests/bot/exts/moderation/test_slowmode.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
import unittest
import asyncio
import datetime
from unittest import mock

from dateutil.relativedelta import relativedelta

from bot.constants import Emojis
from bot.exts.moderation.slowmode import Slowmode
from tests.base import RedisTestCase
from tests.helpers import MockBot, MockContext, MockTextChannel


class SlowmodeTests(unittest.IsolatedAsyncioTestCase):
class SlowmodeTests(RedisTestCase):

def setUp(self) -> None:
self.bot = MockBot()
Expand Down Expand Up @@ -95,6 +97,114 @@ async def test_reset_slowmode_sets_delay_to_zero(self) -> None:
self.ctx, text_channel, relativedelta(seconds=0)
)

@mock.patch("bot.exts.moderation.slowmode.datetime")
async def test_set_slowmode_with_duration(self, mock_datetime) -> None:
"""Set slowmode with a duration"""
mock_datetime.now.return_value = datetime.datetime(2025, 6, 2, 12, 0, 0, tzinfo=datetime.UTC)
test_cases = (
("python-general", 6, 6000, f"{Emojis.check_mark} The slowmode delay for #python-general is now 6 seconds"
" and expires in <t:1748871600:R>."),
("mod-spam", 5, 600, f"{Emojis.check_mark} The slowmode delay for #mod-spam is now 5 seconds and expires"
" in <t:1748866200:R>."),
("changelog", 12, 7200, f"{Emojis.check_mark} The slowmode delay for #changelog is now 12 seconds and"
" expires in <t:1748872800:R>.")
)
for channel_name, seconds, duration, result_msg in test_cases:
with self.subTest(
channel_mention=channel_name,
seconds=seconds,
duration=duration,
result_msg=result_msg
):
text_channel = MockTextChannel(name=channel_name, slowmode_delay=0)
await self.cog.set_slowmode(
self.cog,
self.ctx,
text_channel,
relativedelta(seconds=seconds),
duration=relativedelta(seconds=duration)
)
text_channel.edit.assert_awaited_once_with(slowmode_delay=float(seconds))
self.ctx.send.assert_called_once_with(result_msg)
self.ctx.reset_mock()

@mock.patch("bot.exts.moderation.slowmode.datetime", wraps=datetime.datetime)
async def test_callback_scheduled(self, mock_datetime, ):
"""Schedule slowmode to be reverted"""
mock_now = datetime.datetime(2025, 6, 2, 12, 0, 0, tzinfo=datetime.UTC)
mock_datetime.now.return_value = mock_now
self.cog.scheduler=mock.MagicMock(wraps=self.cog.scheduler)

text_channel = MockTextChannel(name="python-general", slowmode_delay=2, id=123)
await self.cog.set_slowmode(
self.cog,
self.ctx,
text_channel,
relativedelta(seconds=4),
relativedelta(seconds=10))

args = (mock_now+relativedelta(seconds=10), text_channel.id, mock.ANY)
self.cog.scheduler.schedule_at.assert_called_once_with(*args)

async def test_revert_slowmode_callback(self) -> None:
"""Check that the slowmode is reverted"""
text_channel = MockTextChannel(name="python-general", slowmode_delay=2, id=123)
self.bot.get_channel = mock.MagicMock(return_value=text_channel)
await self.cog.set_slowmode(
self.cog, self.ctx, text_channel, relativedelta(seconds=4), relativedelta(seconds=10)
)
await self.cog._revert_slowmode(text_channel.id)
text_channel.edit.assert_awaited_with(slowmode_delay=2)
text_channel.send.assert_called_once_with(
f"{Emojis.check_mark} A previously applied slowmode has expired and has been reverted to 2 seconds."
)

async def test_reschedule_slowmodes(self) -> None:
"""Does not reschedule if cache is empty"""
self.cog.scheduler.schedule_at = mock.MagicMock()
self.cog._reschedule = mock.AsyncMock()
await self.cog.cog_unload()
await self.cog.cog_load()

self.cog._reschedule.assert_called()
self.cog.scheduler.schedule_at.assert_not_called()

async def test_reschedule_upon_reload(self) -> None:
""" Check that method `_reschedule` is called upon cog reload"""
self.cog._reschedule = mock.AsyncMock(wraps=self.cog._reschedule)
await self.cog.cog_unload()
await self.cog.cog_load()

self.cog._reschedule.assert_called()

@mock.patch("bot.exts.moderation.slowmode.datetime", wraps=datetime.datetime)
async def test_reschedules_slowmodes(self, mock_datetime) -> None:
"""Slowmodes are loaded from cache at cog reload and scheduled to be reverted."""
mock_datetime.now.return_value = datetime.datetime(2025, 6, 2, 12, 0, 0, tzinfo=datetime.UTC)
mock_now = datetime.datetime(2025, 6, 2, 12, 0, 0, tzinfo=datetime.UTC)

channels = {}
slowmodes = (
(123, (mock_now - datetime.timedelta(10)).timestamp(), 2), # expiration in the past
(456, (mock_now + datetime.timedelta(20)).timestamp(), 4), # expiration in the future
)

for channel_id, expiration_datetime, delay in slowmodes:
channel = MockTextChannel(slowmode_delay=delay, id=channel_id)
channels[channel_id] = channel
await self.cog.slowmode_expiration_cache.set(channel_id, expiration_datetime)
await self.cog.original_slowmode_cache.set(channel_id, delay)

self.bot.get_channel = mock.MagicMock(side_effect=lambda channel_id: channels.get(channel_id))
await self.cog.cog_unload()
await self.cog.cog_load()
for channel_id in channels:
self.assertIn(channel_id, self.cog.scheduler)

await asyncio.sleep(1) # give scheduled task time to execute
channels[123].edit.assert_awaited_once_with(slowmode_delay=channels[123].slowmode_delay)
channels[456].edit.assert_not_called()

@mock.patch("bot.exts.moderation.slowmode.has_any_role")
@mock.patch("bot.exts.moderation.slowmode.MODERATION_ROLES", new=(1, 2, 3))
async def test_cog_check(self, role_check):
Expand Down