Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use ExtractMap instead of HashMap when possible #2797

Merged
merged 1 commit into from
Mar 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ bool_to_bitflags = { version = "0.1.0" }
nonmax = { version = "0.5.5", features = ["serde"] }
strum = { version = "0.26", features = ["derive"] }
to-arraystring = "0.1.0"
extract_map = { version = "0.1.0", features = ["serde", "iter_mut"] }
# Optional dependencies
fxhash = { version = "0.2.1", optional = true }
chrono = { version = "0.4.31", default-features = false, features = ["clock", "serde"], optional = true }
Expand All @@ -49,7 +50,7 @@ mime_guess = { version = "2.0.4", optional = true }
dashmap = { version = "5.5.3", features = ["serde"], optional = true }
parking_lot = { version = "0.12.1", optional = true }
ed25519-dalek = { version = "2.0.0", optional = true }
typesize = { version = "0.1.5", optional = true, features = ["url", "time", "serde_json", "secrecy", "dashmap", "parking_lot", "nonmax", "details"] }
typesize = { version = "0.1.6", optional = true, features = ["url", "time", "serde_json", "secrecy", "dashmap", "parking_lot", "nonmax", "extract_map_01", "details"] }
# serde feature only allows for serialisation,
# Serenity workspace crates
serenity-voice-model = { version = "0.2.0", path = "./voice-model", optional = true }
Expand Down Expand Up @@ -145,3 +146,4 @@ native_tls_backend = [
[package.metadata.docs.rs]
features = ["full"]
rustdoc-args = ["--cfg", "docsrs"]

77 changes: 41 additions & 36 deletions src/cache/event.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ impl CacheUpdate for ChannelCreateEvent {
let old_channel = cache
.guilds
.get_mut(&self.channel.guild_id)
.and_then(|mut g| g.channels.insert(self.channel.id, self.channel.clone()));
.and_then(|mut g| g.channels.insert(self.channel.clone()));

old_channel
}
Expand All @@ -72,7 +72,7 @@ impl CacheUpdate for ChannelUpdateEvent {
cache
.guilds
.get_mut(&self.channel.guild_id)
.and_then(|mut g| g.channels.insert(self.channel.id, self.channel.clone()))
.and_then(|mut g| g.channels.insert(self.channel.clone()))
}
}

Expand All @@ -82,7 +82,7 @@ impl CacheUpdate for ChannelPinsUpdateEvent {
fn update(&mut self, cache: &Cache) -> Option<()> {
if let Some(guild_id) = self.guild_id {
if let Some(mut guild) = cache.guilds.get_mut(&guild_id) {
if let Some(channel) = guild.channels.get_mut(&self.channel_id) {
if let Some(mut channel) = guild.channels.get_mut(&self.channel_id) {
channel.last_pin_timestamp = self.last_pin_timestamp;
}
}
Expand Down Expand Up @@ -118,9 +118,9 @@ impl CacheUpdate for GuildDeleteEvent {

match cache.guilds.remove(&self.guild.id) {
Some(guild) => {
for channel_id in guild.1.channels.keys() {
for channel in &guild.1.channels {
// Remove the channel's cached messages.
cache.messages.remove(channel_id);
cache.messages.remove(&channel.id);
}

Some(guild.1)
Expand Down Expand Up @@ -148,7 +148,7 @@ impl CacheUpdate for GuildMemberAddEvent {
fn update(&mut self, cache: &Cache) -> Option<()> {
if let Some(mut guild) = cache.guilds.get_mut(&self.member.guild_id) {
guild.member_count += 1;
guild.members.insert(self.member.user.id, self.member.clone());
guild.members.insert(self.member.clone());
}

None
Expand All @@ -173,7 +173,7 @@ impl CacheUpdate for GuildMemberUpdateEvent {

fn update(&mut self, cache: &Cache) -> Option<Self::Output> {
if let Some(mut guild) = cache.guilds.get_mut(&self.guild_id) {
let item = if let Some(member) = guild.members.get_mut(&self.user.id) {
let item = if let Some(mut member) = guild.members.get_mut(&self.user.id) {
let item = Some(member.clone());

member.joined_at.clone_from(&Some(self.joined_at));
Expand Down Expand Up @@ -213,7 +213,7 @@ impl CacheUpdate for GuildMemberUpdateEvent {
new_member.set_deaf(self.deaf());
new_member.set_mute(self.mute());

guild.members.insert(self.user.id, new_member);
guild.members.insert(new_member);
}

item
Expand All @@ -239,11 +239,7 @@ impl CacheUpdate for GuildRoleCreateEvent {
type Output = ();

fn update(&mut self, cache: &Cache) -> Option<()> {
cache
.guilds
.get_mut(&self.role.guild_id)
.map(|mut g| g.roles.insert(self.role.id, self.role.clone()));

cache.guilds.get_mut(&self.role.guild_id).map(|mut g| g.roles.insert(self.role.clone()));
None
}
}
Expand All @@ -261,8 +257,8 @@ impl CacheUpdate for GuildRoleUpdateEvent {

fn update(&mut self, cache: &Cache) -> Option<Self::Output> {
if let Some(mut guild) = cache.guilds.get_mut(&self.role.guild_id) {
if let Some(role) = guild.roles.get_mut(&self.role.id) {
return Some(std::mem::replace(role, self.role.clone()));
if let Some(mut role) = guild.roles.get_mut(&self.role.id) {
return Some(std::mem::replace(&mut *role, self.role.clone()));
}
}

Expand Down Expand Up @@ -328,9 +324,14 @@ impl CacheUpdate for MessageCreateEvent {
let guild = self.message.guild_id.and_then(|g_id| cache.guilds.get_mut(&g_id));

if let Some(mut guild) = guild {
if let Some(channel) = guild.channels.get_mut(&self.message.channel_id) {
update_channel_last_message_id(&self.message, channel, cache);
} else {
let mut found_channel = false;
if let Some(mut channel) = guild.channels.get_mut(&self.message.channel_id) {
update_channel_last_message_id(&self.message, &mut channel, cache);
found_channel = true;
}

// found_channel is to avoid limitations of the NLL borrow checker.
if !found_channel {
// This may be a thread.
let thread =
guild.threads.iter_mut().find(|thread| thread.id == self.message.channel_id);
Expand Down Expand Up @@ -403,25 +404,27 @@ impl CacheUpdate for PresenceUpdateEvent {
if self.presence.status == OnlineStatus::Offline {
guild.presences.remove(&self.presence.user.id);
} else {
guild.presences.insert(self.presence.user.id, self.presence.clone());
guild.presences.insert(self.presence.clone());
}

// Create a partial member instance out of the presence update data.
if let Some(user) = self.presence.user.to_user() {
guild.members.entry(self.presence.user.id).or_insert_with(|| Member {
guild_id,
joined_at: None,
nick: None,
user,
roles: FixedArray::default(),
premium_since: None,
permissions: None,
avatar: None,
communication_disabled_until: None,
flags: GuildMemberFlags::default(),
unusual_dm_activity_until: None,
__generated_flags: MemberGeneratedFlags::empty(),
});
if !guild.members.contains_key(&self.presence.user.id) {
guild.members.insert(Member {
guild_id,
joined_at: None,
nick: None,
user,
roles: FixedArray::default(),
premium_since: None,
permissions: None,
avatar: None,
communication_disabled_until: None,
flags: GuildMemberFlags::default(),
unusual_dm_activity_until: None,
__generated_flags: MemberGeneratedFlags::empty(),
});
}
}
}
}
Expand Down Expand Up @@ -566,12 +569,14 @@ impl CacheUpdate for VoiceStateUpdateEvent {
if let Some(guild_id) = self.voice_state.guild_id {
if let Some(mut guild) = cache.guilds.get_mut(&guild_id) {
if let Some(member) = &self.voice_state.member {
guild.members.insert(member.user.id, member.clone());
guild.members.insert(member.clone());
}

if self.voice_state.channel_id.is_some() {
// Update or add to the voice state list
guild.voice_states.insert(self.voice_state.user_id, self.voice_state.clone())
let old_state = guild.voice_states.remove(&self.voice_state.user_id);
guild.voice_states.insert(self.voice_state.clone());
old_state
} else {
// Remove the user from the voice state list
guild.voice_states.remove(&self.voice_state.user_id)
Expand All @@ -590,7 +595,7 @@ impl CacheUpdate for VoiceChannelStatusUpdateEvent {

fn update(&mut self, cache: &Cache) -> Option<Self::Output> {
let mut guild = cache.guilds.get_mut(&self.guild_id)?;
let channel = guild.channels.get_mut(&self.id)?;
let mut channel = guild.channels.get_mut(&self.id)?;

let old = channel.status.clone();
channel.status.clone_from(&self.status);
Expand Down
18 changes: 8 additions & 10 deletions src/cache/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -445,16 +445,14 @@ impl Cache {
}

/// Clones all channel categories in the given guild and returns them.
pub fn guild_categories(&self, guild_id: GuildId) -> Option<HashMap<ChannelId, GuildChannel>> {
pub fn guild_categories(
&self,
guild_id: GuildId,
) -> Option<ExtractMap<ChannelId, GuildChannel>> {
let guild = self.guilds.get(&guild_id)?;
Some(
guild
.channels
.iter()
.filter(|(_id, channel)| channel.kind == ChannelType::Category)
.map(|(id, channel)| (*id, channel.clone()))
.collect(),
)

let filter = |channel: &&GuildChannel| channel.kind == ChannelType::Category;
Some(guild.channels.iter().filter(filter).cloned().collect())
}

/// Inserts new messages into the message cache for a channel manually.
Expand Down Expand Up @@ -571,7 +569,7 @@ mod test {
let mut guild_create = GuildCreateEvent {
guild: Guild {
id: GuildId::new(1),
channels: HashMap::from([(ChannelId::new(2), channel)]),
channels: ExtractMap::from_iter([channel]),
..Default::default()
},
};
Expand Down
4 changes: 2 additions & 2 deletions src/client/event_handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ event_handler! {
/// Dispatched when the emojis are updated.
///
/// Provides the guild's id and the new state of the emojis in the guild.
GuildEmojisUpdate { guild_id: GuildId, current_state: HashMap<EmojiId, Emoji> } => async fn guild_emojis_update(&self, ctx: Context);
GuildEmojisUpdate { guild_id: GuildId, current_state: ExtractMap<EmojiId, Emoji> } => async fn guild_emojis_update(&self, ctx: Context);

/// Dispatched when a guild's integration is added, updated or removed.
///
Expand Down Expand Up @@ -245,7 +245,7 @@ event_handler! {
/// Dispatched when the stickers are updated.
///
/// Provides the guild's id and the new state of the stickers in the guild.
GuildStickersUpdate { guild_id: GuildId, current_state: HashMap<StickerId, Sticker> } => async fn guild_stickers_update(&self, ctx: Context);
GuildStickersUpdate { guild_id: GuildId, current_state: ExtractMap<StickerId, Sticker> } => async fn guild_stickers_update(&self, ctx: Context);

/// Dispatched when the guild is updated.
///
Expand Down
7 changes: 5 additions & 2 deletions src/http/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3134,7 +3134,10 @@ impl Http {
}

/// Gets all channels in a guild.
pub async fn get_channels(&self, guild_id: GuildId) -> Result<Vec<GuildChannel>> {
pub async fn get_channels(
&self,
guild_id: GuildId,
) -> Result<ExtractMap<ChannelId, GuildChannel>> {
self.fire(Request {
body: None,
multipart: None,
Expand Down Expand Up @@ -3639,7 +3642,7 @@ impl Http {
}

/// Retrieves a list of roles in a [`Guild`].
pub async fn get_guild_roles(&self, guild_id: GuildId) -> Result<Vec<Role>> {
pub async fn get_guild_roles(&self, guild_id: GuildId) -> Result<ExtractMap<RoleId, Role>> {
let mut value: Value = self
.fire(Request {
body: None,
Expand Down
1 change: 1 addition & 0 deletions src/internal/prelude.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
pub use std::result::Result as StdResult;

pub use extract_map::{ExtractKey, ExtractMap, LendingIterator};
pub use serde_json::Value;
pub use small_fixed_array::{FixedArray, FixedString, TruncatingInto};

Expand Down
11 changes: 11 additions & 0 deletions src/internal/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,14 @@ pub(crate) fn join_to_string(
buf.truncate(buf.len() - 1);
buf
}

// Required because of https://github.com/Crazytieguy/gat-lending-iterator/issues/31
macro_rules! lending_for_each {
($iter:expr, |$item:ident| $body:expr ) => {
while let Some(mut $item) = $iter.next() {
$body
}
};
}

pub(crate) use lending_for_each;
46 changes: 35 additions & 11 deletions src/model/application/command_interaction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ use crate::client::Context;
#[cfg(feature = "model")]
use crate::http::Http;
use crate::internal::prelude::*;
use crate::internal::utils::lending_for_each;
use crate::model::application::{CommandOptionType, CommandType};
use crate::model::channel::{Attachment, Message, PartialChannel};
use crate::model::guild::{Member, PartialMember, Role};
Expand Down Expand Up @@ -243,7 +244,9 @@ impl<'de> Deserialize<'de> for CommandInteraction {
// If `member` is present, `user` wasn't sent and is still filled with default data
interaction.user = member.user.clone();
}
interaction.data.resolved.roles.values_mut().for_each(|r| r.guild_id = guild_id);

let mut role_iter = interaction.data.resolved.roles.iter_mut();
lending_for_each!(role_iter, |r| r.guild_id = guild_id);
}
Ok(interaction)
}
Expand Down Expand Up @@ -476,23 +479,44 @@ pub enum ResolvedTarget<'a> {
#[non_exhaustive]
pub struct CommandDataResolved {
/// The resolved users.
#[serde(default, skip_serializing_if = "HashMap::is_empty")]
pub users: HashMap<UserId, User>,
#[serde(
default,
skip_serializing_if = "ExtractMap::is_empty",
serialize_with = "extract_map::serialize_as_map"
)]
pub users: ExtractMap<UserId, User>,
/// The resolved partial members.
// Cannot use ExtractMap, as PartialMember does not always store an ID.
#[serde(default, skip_serializing_if = "HashMap::is_empty")]
pub members: HashMap<UserId, PartialMember>,
/// The resolved roles.
#[serde(default, skip_serializing_if = "HashMap::is_empty")]
pub roles: HashMap<RoleId, Role>,
#[serde(
default,
skip_serializing_if = "ExtractMap::is_empty",
serialize_with = "extract_map::serialize_as_map"
)]
pub roles: ExtractMap<RoleId, Role>,
/// The resolved partial channels.
#[serde(default, skip_serializing_if = "HashMap::is_empty")]
pub channels: HashMap<ChannelId, PartialChannel>,
#[serde(
default,
skip_serializing_if = "ExtractMap::is_empty",
serialize_with = "extract_map::serialize_as_map"
)]
pub channels: ExtractMap<ChannelId, PartialChannel>,
/// The resolved messages.
#[serde(default, skip_serializing_if = "HashMap::is_empty")]
pub messages: HashMap<MessageId, Message>,
#[serde(
default,
skip_serializing_if = "ExtractMap::is_empty",
serialize_with = "extract_map::serialize_as_map"
)]
pub messages: ExtractMap<MessageId, Message>,
/// The resolved attachments.
#[serde(default, skip_serializing_if = "HashMap::is_empty")]
pub attachments: HashMap<AttachmentId, Attachment>,
#[serde(
default,
skip_serializing_if = "ExtractMap::is_empty",
serialize_with = "extract_map::serialize_as_map"
)]
pub attachments: ExtractMap<AttachmentId, Attachment>,
}

/// A set of a parameter and a value from the user.
Expand Down
6 changes: 6 additions & 0 deletions src/model/channel/attachment.rs
Original file line number Diff line number Diff line change
Expand Up @@ -155,3 +155,9 @@ impl Attachment {
Ok(bytes.to_vec())
}
}

impl ExtractKey<AttachmentId> for Attachment {
fn extract_key(&self) -> &AttachmentId {
&self.id
}
}
Loading
Loading