diff --git a/src/cache/event.rs b/src/cache/event.rs index 9c55feda22e..0bb48f7b8c6 100644 --- a/src/cache/event.rs +++ b/src/cache/event.rs @@ -188,7 +188,7 @@ impl CacheUpdate for GuildCreateEvent { for (user_id, member) in &mut guild.members { cache.update_user_entry(&member.user); if let Some(u) = cache.user(user_id) { - member.user = u; + member.user = u.clone(); } } @@ -266,7 +266,7 @@ impl CacheUpdate for GuildMemberAddEvent { let user_id = self.member.user.id; cache.update_user_entry(&self.member.user); if let Some(u) = cache.user(user_id) { - self.member.user = u; + self.member.user = u.clone(); } if let Some(mut guild) = cache.guilds.get_mut(&self.member.guild_id) { @@ -521,7 +521,7 @@ impl CacheUpdate for PresenceUpdateEvent { } if let Some(user) = cache.user(self.presence.user.id) { - self.presence.user.update_with_user(user); + self.presence.user.update_with_user(&user); } if let Some(guild_id) = self.presence.guild_id { @@ -614,7 +614,7 @@ impl CacheUpdate for ReadyEvent { cache.update_user_entry(&user); } if let Some(user) = cache.user(user_id) { - presence.user.update_with_user(user); + presence.user.update_with_user(&user); } cache.presences.insert(*user_id, presence.clone()); diff --git a/src/cache/mod.rs b/src/cache/mod.rs index eb014751955..f96eada412a 100644 --- a/src/cache/mod.rs +++ b/src/cache/mod.rs @@ -34,6 +34,8 @@ use std::collections::{HashMap, VecDeque}; use std::hash::Hash; use std::str::FromStr; #[cfg(feature = "temp_cache")] +use std::sync::Arc; +#[cfg(feature = "temp_cache")] use std::time::Duration; use dashmap::mapref::entry::Entry; @@ -57,6 +59,8 @@ type MessageCache = DashMap>; struct NotSend; enum CacheRefInner<'a, K, V> { + #[cfg(feature = "temp_cache")] + Arc(Arc), DashRef(Ref<'a, K, V>), ReadGuard(parking_lot::RwLockReadGuard<'a, V>), } @@ -74,6 +78,11 @@ impl<'a, K, V> CacheRef<'a, K, V> { } } + #[cfg(feature = "temp_cache")] + fn from_arc(inner: Arc) -> Self { + Self::new(CacheRefInner::Arc(inner)) + } + fn from_ref(inner: Ref<'a, K, V>) -> Self { Self::new(CacheRefInner::DashRef(inner)) } @@ -88,12 +97,15 @@ impl std::ops::Deref for CacheRef<'_, K, V> { fn deref(&self) -> &Self::Target { match &self.inner { + #[cfg(feature = "temp_cache")] + CacheRefInner::Arc(inner) => &*inner, CacheRefInner::DashRef(inner) => inner.value(), CacheRefInner::ReadGuard(inner) => &*inner, } } } +pub type UserRef<'a> = CacheRef<'a, UserId, User>; pub type GuildRef<'a> = CacheRef<'a, GuildId, Guild>; pub type CurrentUserRef<'a> = CacheRef<'a, (), CurrentUser>; pub type GuildChannelRef<'a> = CacheRef<'a, ChannelId, GuildChannel>; @@ -215,7 +227,7 @@ pub struct Cache { /// /// Each value has a max TTL of 1 hour. #[cfg(feature = "temp_cache")] - pub(crate) temp_users: DashCache, + pub(crate) temp_users: DashCache>, /// The settings for the cache. settings: RwLock, } @@ -836,22 +848,22 @@ impl Cache { /// # } /// ``` #[inline] - pub fn user>(&self, user_id: U) -> Option { + pub fn user>(&self, user_id: U) -> Option> { self._user(user_id.into()) } #[cfg(feature = "temp_cache")] - fn _user(&self, user_id: UserId) -> Option { + fn _user(&self, user_id: UserId) -> Option> { if let Some(user) = self.users.get(&user_id) { - Some(user.clone()) + Some(CacheRef::from_ref(user)) } else { - self.temp_users.get(&user_id) + self.temp_users.get(&user_id).map(CacheRef::from_arc) } } #[cfg(not(feature = "temp_cache"))] - fn _user(&self, user_id: UserId) -> Option { - self.users.get(&user_id).map(|u| u.clone()) + fn _user(&self, user_id: UserId) -> Option> { + self.users.get(&user_id).map(CacheRef::from_ref) } /// Clones all users and returns them. diff --git a/src/model/gateway.rs b/src/model/gateway.rs index f1371ca34a4..c7c9bfd2105 100644 --- a/src/model/gateway.rs +++ b/src/model/gateway.rs @@ -466,14 +466,14 @@ impl PresenceUser { } #[cfg(feature = "cache")] // method is only used with the cache feature enabled - pub(crate) fn update_with_user(&mut self, user: User) { + pub(crate) fn update_with_user(&mut self, user: &User) { self.id = user.id; - if let Some(avatar) = user.avatar { - self.avatar = Some(avatar); + if let Some(avatar) = &user.avatar { + self.avatar = Some(avatar.clone()); } self.bot = Some(user.bot); self.discriminator = Some(user.discriminator); - self.name = Some(user.name); + self.name = Some(user.name.clone()); if let Some(public_flags) = user.public_flags { self.public_flags = Some(public_flags); } diff --git a/src/model/user.rs b/src/model/user.rs index 795f0449659..e3fb3052415 100644 --- a/src/model/user.rs +++ b/src/model/user.rs @@ -3,6 +3,8 @@ use std::fmt; #[cfg(feature = "model")] use std::fmt::Write; +#[cfg(feature = "temp_cache")] +use std::sync::Arc; use serde::{Deserialize, Serialize}; @@ -10,7 +12,7 @@ use super::prelude::*; #[cfg(feature = "model")] use crate::builder::{CreateBotAuthParameters, CreateMessage, EditProfile}; #[cfg(all(feature = "cache", feature = "model"))] -use crate::cache::Cache; +use crate::cache::{Cache, UserRef}; #[cfg(feature = "collector")] use crate::client::bridge::gateway::ShardMessenger; #[cfg(feature = "collector")] @@ -1106,7 +1108,7 @@ impl UserId { #[cfg(feature = "cache")] #[allow(clippy::unused_async)] #[inline] - pub async fn to_user_cached(self, cache: impl AsRef) -> Option { + pub async fn to_user_cached(self, cache: &impl AsRef) -> Option> { cache.as_ref().user(self) } @@ -1131,7 +1133,7 @@ impl UserId { { if let Some(cache) = cache_http.cache() { if let Some(user) = cache.user(self) { - return Ok(user); + return Ok(user.clone()); } } } @@ -1141,7 +1143,7 @@ impl UserId { #[cfg(all(feature = "cache", feature = "temp_cache"))] { if let Some(cache) = cache_http.cache() { - cache.temp_users.insert(user.id, user.clone()); + cache.temp_users.insert(user.id, Arc::new(user.clone())); } } diff --git a/src/utils/content_safe.rs b/src/utils/content_safe.rs index 79ddb3e9233..b1a227cc1f5 100644 --- a/src/utils/content_safe.rs +++ b/src/utils/content_safe.rs @@ -256,10 +256,10 @@ fn clean_mention( } .into() }; + cache .user(id) - .as_ref() - .map(get_username) + .map(|u| get_username(&u)) .or_else(|| users.iter().find(|u| u.id == id).map(get_username)) .unwrap_or(Cow::Borrowed("@invalid-user")) },