Skip to content

Commit

Permalink
Simplify the message cache (#2757)
Browse files Browse the repository at this point in the history
  • Loading branch information
GnomedDev authored Feb 9, 2024
1 parent f874ac8 commit 1775e6f
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 38 deletions.
33 changes: 16 additions & 17 deletions src/cache/event.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::collections::HashSet;
use std::collections::{HashSet, VecDeque};
use std::num::NonZeroU16;

use super::{Cache, CacheUpdate};
Expand Down Expand Up @@ -54,16 +54,16 @@ impl CacheUpdate for ChannelCreateEvent {
}

impl CacheUpdate for ChannelDeleteEvent {
type Output = Vec<Message>;
type Output = VecDeque<Message>;

fn update(&mut self, cache: &Cache) -> Option<Vec<Message>> {
fn update(&mut self, cache: &Cache) -> Option<VecDeque<Message>> {
let (channel_id, guild_id) = (self.channel.id, self.channel.guild_id);

cache.channels.remove(&channel_id);
cache.guilds.get_mut(&guild_id).map(|mut g| g.channels.remove(&channel_id));

// Remove the cached messages for the channel.
cache.messages.remove(&channel_id).map(|(_, messages)| messages.into_values().collect())
cache.messages.remove(&channel_id).map(|(_, messages)| messages)
}
}

Expand Down Expand Up @@ -337,18 +337,15 @@ impl CacheUpdate for MessageCreateEvent {
}

let mut messages = cache.messages.entry(self.message.channel_id).or_default();
let mut queue = cache.message_queue.entry(self.message.channel_id).or_default();

let mut removed_msg = None;

if messages.len() == max {
if let Some(id) = queue.pop_front() {
removed_msg = messages.remove(&id);
}
removed_msg = messages.pop_front();
}

queue.push_back(self.message.id);
messages.insert(self.message.id, self.message.clone());
if !messages.iter().any(|m| m.id == self.message.id) {
messages.push_back(self.message.clone());
}

removed_msg
}
Expand All @@ -358,13 +355,15 @@ impl CacheUpdate for MessageUpdateEvent {
type Output = Message;

fn update(&mut self, cache: &Cache) -> Option<Self::Output> {
let mut messages = cache.messages.get_mut(&self.channel_id)?;
let message = messages.get_mut(&self.id)?;
let old_message = message.clone();

self.apply_to_message(message);
for message in cache.messages.get_mut(&self.channel_id)?.iter_mut() {
if message.id == self.id {
let old_message = message.clone();
self.apply_to_message(message);
return Some(old_message);
}
}

Some(old_message)
None
}
}

Expand Down
30 changes: 11 additions & 19 deletions src/cache/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,6 @@ mod wrappers;
pub(crate) use wrappers::MaybeOwnedArc;
use wrappers::{BuildHasher, MaybeMap, ReadOnlyMapRef};

type MessageCache = DashMap<ChannelId, HashMap<MessageId, Message>, BuildHasher>;

struct NotSend;

enum CacheRefInner<'a, K, V, T> {
Expand Down Expand Up @@ -117,9 +115,9 @@ pub type SettingsRef<'a> = CacheRef<'a, (), Settings>;
pub type CurrentUserRef<'a> = CacheRef<'a, (), CurrentUser>;
pub type GuildChannelRef<'a> = MappedGuildRef<'a, GuildChannel>;
pub type GuildRolesRef<'a> = MappedGuildRef<'a, HashMap<RoleId, Role>>;
pub type ChannelMessagesRef<'a> = CacheRef<'a, ChannelId, VecDeque<Message>>;
pub type MessageRef<'a> = CacheRef<'a, ChannelId, Message, VecDeque<Message>>;
pub type GuildChannelsRef<'a> = MappedGuildRef<'a, HashMap<ChannelId, GuildChannel>>;
pub type ChannelMessagesRef<'a> = CacheRef<'a, ChannelId, HashMap<MessageId, Message>>;
pub type MessageRef<'a> = CacheRef<'a, ChannelId, Message, HashMap<MessageId, Message>>;

#[cfg_attr(feature = "typesize", derive(typesize::derive::TypeSize))]
#[derive(Debug)]
Expand Down Expand Up @@ -190,13 +188,7 @@ pub struct Cache {

// Messages cache:
// ---
pub(crate) messages: MessageCache,
/// Queue of message IDs for each channel.
///
/// This is simply a vecdeque so we can keep track of the order of messages inserted into the
/// cache. When a maximum number of messages are in a channel's cache, we can pop the front and
/// remove that ID from the cache.
pub(crate) message_queue: DashMap<ChannelId, VecDeque<MessageId>, BuildHasher>,
pub(crate) messages: DashMap<ChannelId, VecDeque<Message>, BuildHasher>,

// Miscellanous fixed-size data
// ---
Expand Down Expand Up @@ -259,7 +251,6 @@ impl Cache {
presences: MaybeMap(settings.cache_users.then(DashMap::default)),

messages: DashMap::default(),
message_queue: DashMap::default(),

shard_data: RwLock::new(CachedShardData {
total: NonZeroU16::MIN,
Expand Down Expand Up @@ -385,10 +376,10 @@ impl Cache {
/// # use serenity::model::id::ChannelId;
/// #
/// # let cache: serenity::cache::Cache = todo!();
/// let messages_in_channel = cache.channel_messages(ChannelId::new(7));
/// let messages_by_user = messages_in_channel
/// .as_ref()
/// .map(|msgs| msgs.values().filter(|m| m.author.id == 8).collect::<Vec<_>>());
/// if let Some(messages_in_channel) = cache.channel_messages(ChannelId::new(7)) {
/// let messages_by_user: Vec<_> =
/// messages_in_channel.iter().filter(|m| m.author.id == 8).collect();
/// }
/// ```
pub fn channel_messages(&self, channel_id: ChannelId) -> Option<ChannelMessagesRef<'_>> {
self.messages.get(&channel_id).map(CacheRef::from_ref)
Expand Down Expand Up @@ -534,8 +525,9 @@ impl Cache {
return Some(CacheRef::from_arc(message));
}

let channel_messages = self.messages.get(&channel_id)?;
let message = channel_messages.try_map(|messages| messages.get(&message_id)).ok()?;
let messages = self.messages.get(&channel_id)?;
let message =
messages.try_map(|messages| messages.iter().find(|m| m.id == message_id)).ok()?;
Some(CacheRef::from_mapped_ref(message))
}

Expand Down Expand Up @@ -693,7 +685,7 @@ mod test {

assert_eq!(channel.len(), 2);
// Check that the first message is now removed.
assert!(!channel.contains_key(&MessageId::new(3)));
assert!(!channel.iter().any(|m| m.id == MessageId::new(3)));
}

let channel = GuildChannel {
Expand Down
4 changes: 2 additions & 2 deletions src/client/event_handler.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::collections::HashMap;
use std::collections::{HashMap, VecDeque};
#[cfg(feature = "cache")]
use std::num::NonZeroU16;

Expand Down Expand Up @@ -139,7 +139,7 @@ event_handler! {
/// Dispatched when a channel is deleted.
///
/// Provides said channel's data.
ChannelDelete { channel: GuildChannel, messages: Option<Vec<Message>> } => async fn channel_delete(&self, ctx: Context);
ChannelDelete { channel: GuildChannel, messages: Option<VecDeque<Message>> } => async fn channel_delete(&self, ctx: Context);

/// Dispatched when a pin is added, deleted.
///
Expand Down

0 comments on commit 1775e6f

Please sign in to comment.