diff --git a/Cargo.toml b/Cargo.toml index 82a27d5c97f..027c8503c48 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -49,7 +49,7 @@ percent-encoding = { version = "2.3.0", optional = true } mini-moka = { version = "0.10.2", optional = true } mime_guess = { version = "2.0.4", optional = true } dashmap = { version = "5.5.3", features = ["serde"], optional = true } -parking_lot = { version = "0.12.1", optional = true } +parking_lot = { version = "0.12.1"} ed25519-dalek = { version = "2.0.0", optional = true } typesize = { version = "0.1.6", optional = true, features = ["url", "time", "serde_json", "secrecy", "dashmap", "parking_lot", "nonmax", "extract_map_01", "details"] } # serde feature only allows for serialisation, @@ -83,7 +83,7 @@ default_no_backend = [ builder = ["tokio/fs"] # Enables the cache, which stores the data received from Discord gateway to provide access to # complete guild data, channels, users and more without needing HTTP requests. -cache = ["fxhash", "dashmap", "parking_lot"] +cache = ["fxhash", "dashmap"] # Enables collectors, a utility feature that lets you await interaction events in code with # zero setup, without needing to setup an InteractionCreate event listener. collector = ["gateway", "model"] @@ -95,7 +95,7 @@ framework = ["client", "model", "utils"] # Enables gateway support, which allows bots to listen for Discord events. gateway = ["flate2"] # Enables HTTP, which enables bots to execute actions on Discord. -http = ["dashmap", "parking_lot", "mime_guess", "percent-encoding"] +http = ["dashmap", "mime_guess", "percent-encoding"] # Enables wrapper methods around HTTP requests on model types. # Requires "builder" to configure the requests and "http" to execute them. # Note: the model type definitions themselves are always active, regardless of this feature. @@ -116,7 +116,7 @@ chrono = ["dep:chrono", "typesize?/chrono"] # This enables all parts of the serenity codebase # (Note: all feature-gated APIs to be documented should have their features listed here!) -# +# # Unstable functionality should be gated under the `unstable` feature. full = ["default", "collector", "voice", "voice_model", "interactions_endpoint"] diff --git a/src/collector.rs b/src/collector.rs index 83b632d2b1c..459dcc496e1 100644 --- a/src/collector.rs +++ b/src/collector.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use futures::future::pending; use futures::{Stream, StreamExt as _}; @@ -35,7 +37,7 @@ pub fn collect( let (sender, mut receiver) = tokio::sync::mpsc::unbounded_channel(); // Register an event callback in the shard. It's kept alive as long as we return `true` - shard.add_collector(CollectorCallback(Box::new(move |event| match extractor(event) { + shard.add_collector(CollectorCallback(Arc::new(move |event| match extractor(event) { // If this event matches, we send it to the receiver stream Some(item) => sender.send(item).is_ok(), None => !sender.is_closed(), diff --git a/src/gateway/bridge/mod.rs b/src/gateway/bridge/mod.rs index 0741ef7f1ee..f5b2bbe845c 100644 --- a/src/gateway/bridge/mod.rs +++ b/src/gateway/bridge/mod.rs @@ -51,6 +51,7 @@ mod voice; use std::fmt; use std::num::NonZeroU16; +use std::sync::Arc; use std::time::Duration as StdDuration; pub use self::event::ShardStageUpdateEvent; @@ -97,9 +98,17 @@ pub struct ShardRunnerInfo { /// Newtype around a callback that will be called on every incoming request. As long as this /// collector should still receive events, it should return `true`. Once it returns `false`, it is /// removed. -pub struct CollectorCallback(pub Box bool + Send + Sync>); +#[derive(Clone)] +pub struct CollectorCallback(pub Arc bool + Send + Sync>); + impl std::fmt::Debug for CollectorCallback { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_tuple("CollectorCallback").finish() } } + +impl PartialEq for CollectorCallback { + fn eq(&self, other: &Self) -> bool { + Arc::ptr_eq(&self.0, &other.0) + } +} diff --git a/src/gateway/bridge/shard_messenger.rs b/src/gateway/bridge/shard_messenger.rs index c9371fd7424..39bb10fb82f 100644 --- a/src/gateway/bridge/shard_messenger.rs +++ b/src/gateway/bridge/shard_messenger.rs @@ -21,7 +21,7 @@ use crate::model::prelude::*; pub struct ShardMessenger { pub(crate) tx: Sender, #[cfg(feature = "collector")] - pub(crate) collectors: Arc>>, + pub(crate) collectors: Arc>>, } impl ShardMessenger { @@ -211,6 +211,6 @@ impl ShardMessenger { #[cfg(feature = "collector")] pub fn add_collector(&self, collector: CollectorCallback) { - self.collectors.lock().expect("poison").push(collector); + self.collectors.write().push(collector); } } diff --git a/src/gateway/bridge/shard_runner.rs b/src/gateway/bridge/shard_runner.rs index 2e64075eff4..4eb2db2bc8d 100644 --- a/src/gateway/bridge/shard_runner.rs +++ b/src/gateway/bridge/shard_runner.rs @@ -43,7 +43,7 @@ pub struct ShardRunner { pub cache: Arc, pub http: Arc, #[cfg(feature = "collector")] - pub(crate) collectors: Arc>>, + pub(crate) collectors: Arc>>, } impl ShardRunner { @@ -66,7 +66,7 @@ impl ShardRunner { cache: opt.cache, http: opt.http, #[cfg(feature = "collector")] - collectors: Arc::new(std::sync::Mutex::new(vec![])), + collectors: Arc::new(parking_lot::RwLock::new(vec![])), } } @@ -171,7 +171,18 @@ impl ShardRunner { if let Some(event) = event { #[cfg(feature = "collector")] - self.collectors.lock().expect("poison").retain_mut(|callback| (callback.0)(&event)); + { + let read_lock = self.collectors.read(); + // search all collectors to be removed and clone the Arcs + let to_remove: Vec<_> = + read_lock.iter().filter(|callback| !callback.0(&event)).cloned().collect(); + drop(read_lock); + // remove all found arcs from the collection + // this compares the inner pointer of the Arc + if !to_remove.is_empty() { + self.collectors.write().retain(|f| !to_remove.contains(f)); + } + } spawn_named( "shard_runner::dispatch", dispatch_model(