Skip to content

Commit

Permalink
Factor out collector methods on models into traits (serenity-rs#3055)
Browse files Browse the repository at this point in the history
Moves the feature-gated collector methods on `UserId`, `MessageId`, etc.
into four traits:
- `CollectMessages` 
- `CollectReactions`
- `CollectModalInteractions` 
- `CollectComponentInteractions`

This also moves the quick modal machinery into the collector module and defines
a `QuickModal` trait. 

This fully removes any collector feature gates from the model types.
  • Loading branch information
mkrasnitski committed Dec 8, 2024
1 parent 6286c18 commit 2ba1184
Show file tree
Hide file tree
Showing 14 changed files with 76 additions and 230 deletions.
2 changes: 1 addition & 1 deletion examples/e05_sample_bot_structure/src/commands/modal.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use serenity::builder::*;
use serenity::collector::{CreateQuickModal, QuickModal};
use serenity::model::prelude::*;
use serenity::prelude::*;
use serenity::utils::CreateQuickModal;

pub async fn run(ctx: &Context, interaction: &CommandInteraction) -> Result<(), serenity::Error> {
let modal = CreateQuickModal::new("About you")
Expand Down
6 changes: 3 additions & 3 deletions examples/e09_collectors/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use std::collections::HashSet;
use std::time::Duration;

use serenity::async_trait;
use serenity::collector::MessageCollector;
use serenity::collector::{CollectMessages, CollectReactions, MessageCollector};
// Collectors are streams, that means we can use `StreamExt` and `TryStreamExt`.
use serenity::futures::stream::StreamExt;
use serenity::model::prelude::*;
Expand All @@ -27,7 +27,7 @@ impl EventHandler for Handler {
// return a builder that can be turned into a Stream, or here, where we can await a
// single reply
let collector =
msg.author.id.await_reply(ctx.shard.clone()).timeout(Duration::from_secs(10));
msg.author.id.collect_messages(ctx.shard.clone()).timeout(Duration::from_secs(10));
if let Some(answer) = collector.await {
if answer.content.to_lowercase() == "ferris" {
let _ = answer.reply(&ctx.http, "That's correct!").await;
Expand All @@ -47,7 +47,7 @@ impl EventHandler for Handler {
// The message model can also be turned into a Collector to collect reactions on it.
let collector = react_msg
.id
.await_reaction(ctx.shard.clone())
.collect_reactions(ctx.shard.clone())
.timeout(Duration::from_secs(10))
.author_id(msg.author.id);

Expand Down
5 changes: 3 additions & 2 deletions examples/e14_message_components/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use serenity::builder::{
CreateSelectMenuKind,
CreateSelectMenuOption,
};
use serenity::collector::CollectComponentInteractions;
use serenity::futures::StreamExt;
use serenity::model::prelude::*;
use serenity::prelude::*;
Expand Down Expand Up @@ -59,7 +60,7 @@ impl EventHandler for Handler {
// manually in the EventHandler.
let interaction = match m
.id
.await_component_interaction(ctx.shard.clone())
.collect_component_interactions(ctx.shard.clone())
.timeout(Duration::from_secs(60 * 3))
.await
{
Expand Down Expand Up @@ -107,7 +108,7 @@ impl EventHandler for Handler {

// Wait for multiple interactions
let mut interaction_stream =
m.id.await_component_interaction(ctx.shard.clone())
m.id.collect_component_interactions(ctx.shard.clone())
.timeout(Duration::from_secs(60 * 3))
.stream();

Expand Down
3 changes: 2 additions & 1 deletion examples/testing/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::borrow::Cow;

use serenity::builder::*;
use serenity::collector::CollectComponentInteractions;
use serenity::model::prelude::*;
use serenity::prelude::*;

Expand Down Expand Up @@ -121,7 +122,7 @@ async fn message(ctx: &Context, msg: Message) -> Result<(), serenity::Error> {
.await?;
let button_press = msg
.id
.await_component_interaction(ctx.shard.clone())
.collect_component_interactions(ctx.shard.clone())
.timeout(std::time::Duration::from_secs(10))
.await;
match button_press {
Expand Down
44 changes: 14 additions & 30 deletions src/builder/edit_message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,38 +107,22 @@ impl<'a> EditMessage<'a> {
/// Suppress or unsuppress embeds in the message, this includes those generated by Discord
/// themselves.
///
/// If this is sent directly after posting the message, there is a small chance Discord hasn't
/// yet fully parsed the contained links and generated the embeds, so this embed suppression
/// request has no effect. To mitigate this, you can defer the embed suppression until the
/// embeds have loaded:
///
/// ```rust,no_run
/// # use serenity::all::*;
/// # #[cfg(feature = "collector")]
/// # async fn test(ctx: &Context, channel_id: ChannelId) -> Result<(), Error> {
/// use std::time::Duration;
///
/// use futures::StreamExt;
///
/// let mut msg = channel_id.say(&ctx.http, "<link that spawns an embed>").await?;
///
/// // When the embed appears, a MessageUpdate event is sent and we suppress the embed.
/// // No MessageUpdate event is sent if the message contains no embeddable link or if the link
/// // has been posted before and is still cached in Discord's servers (in which case the
/// // embed appears immediately), no MessageUpdate event is sent. To not wait forever in those
/// // cases, a timeout of 2000ms was added.
/// let msg_id = msg.id;
/// let mut message_updates = serenity::collector::collect(&ctx.shard, move |ev| match ev {
/// Event::MessageUpdate(x) if x.id == msg_id => Some(()),
/// _ => None,
/// });
/// let _ = tokio::time::timeout(Duration::from_millis(2000), message_updates.next()).await;
/// msg.edit(&ctx, EditMessage::new().suppress_embeds(true)).await?;
/// # Ok(()) }
/// ```
/// If this is sent directly after the message has been posted, there is a small chance Discord
/// hasn't yet fully parsed the contained links and generated the embeds, so this embed
/// suppression request has no effect. Note that this is less likely for messages you have not
/// created yourself. There are two ways to mitigate this:
/// 1. If you are editing a message you created, simply set the
/// [`MessageFlags::SUPPRESS_EMBEDS`] flag on creation using [`CreateMessage::flags`] to
/// avoid having to edit the message in the first place.
/// 2. Defer the embed suppression until the embed has loaded. When the embed appears, a
/// `MessageUpdate` event is sent over the gateway. Note that this will not occur if a link
/// has been previously posted and is still cached by Discord, in which case the embed will
/// immediately appear.
///
/// [`CreateMessage::flags`]: super::CreateMessage::flags
pub fn suppress_embeds(mut self, suppress: bool) -> Self {
// At time of writing, only `SUPPRESS_EMBEDS` can be set/unset when editing messages. See
// for details: https://discord.com/developers/docs/resources/channel#edit-message-jsonform-params
// for details: https://discord.com/developers/docs/resources/message#edit-message-jsonform-params
let flags =
suppress.then_some(MessageFlags::SUPPRESS_EMBEDS).unwrap_or_else(MessageFlags::empty);

Expand Down
24 changes: 23 additions & 1 deletion src/collector.rs → src/collector/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
mod quick_modal;

use std::sync::Arc;

use futures::future::pending;
use futures::{Stream, StreamExt as _};
pub use quick_modal::*;

use crate::gateway::{CollectorCallback, ShardMessenger};
use crate::internal::prelude::*;
Expand Down Expand Up @@ -51,6 +54,7 @@ macro_rules! make_specific_collector {
(
$( #[ $($meta:tt)* ] )*
$collector_type:ident, $item_type:ident,
$collector_trait:ident, $method_name:ident,
$extractor:pat => $extracted_item:ident,
$( $filter_name:ident: $filter_type:ty => $filter_passes:expr, )*
) => {
Expand Down Expand Up @@ -100,7 +104,7 @@ macro_rules! make_specific_collector {
let filters_pass = move |$extracted_item: &$item_type| {
// Check each of the built-in filters (author_id, channel_id, etc.)
$( if let Some($filter_name) = &self.$filter_name {
if !$filter_passes {
if !($filter_passes) {
return false;
}
} )*
Expand Down Expand Up @@ -142,12 +146,27 @@ macro_rules! make_specific_collector {
Box::pin(self.next())
}
}

pub trait $collector_trait {
fn $method_name(self, shard_messenger: ShardMessenger) -> $collector_type;
}

$(
impl $collector_trait for $filter_type {
fn $method_name(self, shard_messenger: ShardMessenger) -> $collector_type {
$collector_type::new(shard_messenger).$filter_name(self)
}
}
)*
};
}

make_specific_collector!(
// First line has name of the collector type, and the type of the collected items.
ComponentInteractionCollector, ComponentInteraction,
// Second line has name of the specific trait and method name that will be
// implemented on the filter argument types listed below.
CollectComponentInteractions, collect_component_interactions,
// This defines the extractor pattern, which extracts the data we want to collect from an Event.
Event::InteractionCreate(InteractionCreateEvent {
interaction: Interaction::Component(interaction),
Expand All @@ -165,6 +184,7 @@ make_specific_collector!(
);
make_specific_collector!(
ModalInteractionCollector, ModalInteraction,
CollectModalInteractions, collect_modal_interactions,
Event::InteractionCreate(InteractionCreateEvent {
interaction: Interaction::Modal(interaction),
}) => interaction,
Expand All @@ -176,6 +196,7 @@ make_specific_collector!(
);
make_specific_collector!(
ReactionCollector, Reaction,
CollectReactions, collect_reactions,
Event::ReactionAdd(ReactionAddEvent { reaction }) => reaction,
author_id: UserId => reaction.user_id.map_or(true, |a| a == *author_id),
channel_id: ChannelId => reaction.channel_id == *channel_id,
Expand All @@ -184,6 +205,7 @@ make_specific_collector!(
);
make_specific_collector!(
MessageCollector, Message,
CollectMessages, collect_messages,
Event::MessageCreate(MessageCreateEvent { message }) => message,
author_id: UserId => message.author.id == *author_id,
channel_id: ChannelId => message.channel_id == *channel_id,
Expand Down
34 changes: 30 additions & 4 deletions src/utils/quick_modal.rs → src/collector/quick_modal.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
use std::borrow::Cow;
use std::future::Future;

use crate::builder::{CreateActionRow, CreateInputText, CreateInteractionResponse, CreateModal};
use crate::collector::ModalInteractionCollector;
use crate::gateway::client::Context;
use crate::internal::prelude::*;
use crate::model::prelude::*;

#[cfg(feature = "collector")]
pub struct QuickModalResponse {
pub interaction: ModalInteraction,
pub inputs: FixedArray<FixedString<u16>>,
Expand All @@ -15,7 +15,7 @@ pub struct QuickModalResponse {
/// Convenience builder to create a modal, wait for the user to submit and parse the response.
///
/// ```rust
/// # use serenity::{builder::*, model::prelude::*, prelude::*, utils::CreateQuickModal, Result};
/// # use serenity::{builder::*, model::prelude::*, prelude::*, collector::*, Result};
/// # async fn foo_(ctx: &Context, interaction: &CommandInteraction) -> Result<()> {
/// let modal = CreateQuickModal::new("About you")
/// .timeout(std::time::Duration::from_secs(600))
Expand All @@ -28,15 +28,13 @@ pub struct QuickModalResponse {
/// # Ok(())
/// # }
/// ```
#[cfg(feature = "collector")]
#[must_use]
pub struct CreateQuickModal<'a> {
title: Cow<'a, str>,
timeout: Option<std::time::Duration>,
input_texts: Vec<CreateInputText<'a>>,
}

#[cfg(feature = "collector")]
impl<'a> CreateQuickModal<'a> {
pub fn new(title: impl Into<Cow<'a, str>>) -> Self {
Self {
Expand Down Expand Up @@ -143,3 +141,31 @@ impl<'a> CreateQuickModal<'a> {
}))
}
}

pub trait QuickModal {
fn quick_modal(
&self,
ctx: &Context,
builder: CreateQuickModal<'_>,
) -> impl Future<Output = Result<Option<QuickModalResponse>>>;
}

impl QuickModal for CommandInteraction {
async fn quick_modal(
&self,
ctx: &Context,
builder: CreateQuickModal<'_>,
) -> Result<Option<QuickModalResponse>> {
builder.execute(ctx, self.id, &self.token).await
}
}

impl QuickModal for ComponentInteraction {
async fn quick_modal(
&self,
ctx: &Context,
builder: CreateQuickModal<'_>,
) -> Result<Option<QuickModalResponse>> {
builder.execute(ctx, self.id, &self.token).await
}
}
18 changes: 0 additions & 18 deletions src/model/application/command_interaction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,10 @@ use crate::builder::{
CreateInteractionResponseMessage,
EditInteractionResponse,
};
#[cfg(feature = "collector")]
use crate::gateway::client::Context;
#[cfg(feature = "model")]
use crate::http::Http;
use crate::internal::utils::lending_for_each;
use crate::model::prelude::*;
#[cfg(all(feature = "collector", feature = "utils"))]
use crate::utils::{CreateQuickModal, QuickModalResponse};

/// An interaction when a user invokes a slash command.
///
Expand Down Expand Up @@ -204,20 +200,6 @@ impl CommandInteraction {
);
self.create_response(http, builder).await
}

/// See [`CreateQuickModal`].
///
/// # Errors
///
/// See [`CreateQuickModal::execute()`].
#[cfg(all(feature = "collector", feature = "utils"))]
pub async fn quick_modal(
&self,
ctx: &Context,
builder: CreateQuickModal<'_>,
) -> Result<Option<QuickModalResponse>> {
builder.execute(ctx, self.id, &self.token).await
}
}

// Manual impl needed to insert guild_id into resolved Role's
Expand Down
18 changes: 0 additions & 18 deletions src/model/application/component_interaction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,9 @@ use crate::builder::{
CreateInteractionResponseMessage,
EditInteractionResponse,
};
#[cfg(feature = "collector")]
use crate::gateway::client::Context;
#[cfg(feature = "model")]
use crate::http::Http;
use crate::model::prelude::*;
#[cfg(all(feature = "collector", feature = "utils"))]
use crate::utils::{CreateQuickModal, QuickModalResponse};

/// An interaction triggered by a message component.
///
Expand Down Expand Up @@ -201,20 +197,6 @@ impl ComponentInteraction {
);
self.create_response(http, builder).await
}

/// See [`CreateQuickModal`].
///
/// # Errors
///
/// See [`CreateQuickModal::execute()`].
#[cfg(all(feature = "collector", feature = "utils"))]
pub async fn quick_modal(
&self,
ctx: &Context,
builder: CreateQuickModal<'_>,
) -> Result<Option<QuickModalResponse>> {
builder.execute(ctx, self.id, &self.token).await
}
}

// Manual impl needed to insert guild_id into model data
Expand Down
30 changes: 0 additions & 30 deletions src/model/channel/channel_id.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,6 @@ use crate::builder::{
};
#[cfg(all(feature = "cache", feature = "model"))]
use crate::cache::Cache;
#[cfg(feature = "collector")]
use crate::collector::{MessageCollector, ReactionCollector};
#[cfg(feature = "collector")]
use crate::gateway::ShardMessenger;
#[cfg(feature = "model")]
use crate::http::{CacheHttp, Http, Typing};
use crate::model::prelude::*;
Expand Down Expand Up @@ -802,32 +798,6 @@ impl ChannelId {
builder.execute(http, self).await
}

/// Returns a builder which can be awaited to obtain a message or stream of messages in this
/// channel.
#[cfg(feature = "collector")]
pub fn await_reply(self, shard_messenger: ShardMessenger) -> MessageCollector {
MessageCollector::new(shard_messenger).channel_id(self)
}

/// Same as [`Self::await_reply`].
#[cfg(feature = "collector")]
pub fn await_replies(self, shard_messenger: ShardMessenger) -> MessageCollector {
self.await_reply(shard_messenger)
}

/// Returns a builder which can be awaited to obtain a reaction or stream of reactions sent in
/// this channel.
#[cfg(feature = "collector")]
pub fn await_reaction(self, shard_messenger: ShardMessenger) -> ReactionCollector {
ReactionCollector::new(shard_messenger).channel_id(self)
}

/// Same as [`Self::await_reaction`].
#[cfg(feature = "collector")]
pub fn await_reactions(self, shard_messenger: ShardMessenger) -> ReactionCollector {
self.await_reaction(shard_messenger)
}

/// Gets a stage instance.
///
/// # Errors
Expand Down
Loading

0 comments on commit 2ba1184

Please sign in to comment.