Skip to content

Commit

Permalink
Reduce size of shard info and turn it into a struct (serenity-rs#1984)
Browse files Browse the repository at this point in the history
This replaces the `[u64; 2]` shard info with `ShardInfo { id: u32, total: u32 }`.

Co-authored-by: nickelc <constantin.nickel@gmail.com>
  • Loading branch information
2 people authored and arqunis committed Oct 24, 2023
1 parent 2ed4869 commit 6e5867e
Show file tree
Hide file tree
Showing 20 changed files with 177 additions and 133 deletions.
2 changes: 1 addition & 1 deletion examples/e08_shard_manager/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ impl EventHandler for Handler {
// Note that array index 0 is 0-indexed, while index 1 is 1-indexed.
//
// This may seem unintuitive, but it models Discord's behaviour.
println!("{} is connected on shard {}/{}!", ready.user.name, shard[0], shard[1]);
println!("{} is connected on shard {}/{}!", ready.user.name, shard.id, shard.total);
}
}
}
Expand Down
8 changes: 5 additions & 3 deletions src/cache/event.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ use crate::model::event::{
UserUpdateEvent,
VoiceStateUpdateEvent,
};
use crate::model::gateway::ShardInfo;
use crate::model::guild::{Guild, Member, Role};
use crate::model::user::{CurrentUser, OnlineStatus};
use crate::model::voice::VoiceState;
Expand Down Expand Up @@ -588,11 +589,12 @@ impl CacheUpdate for ReadyEvent {
let mut guilds_to_remove = vec![];
let ready_guilds_hashset =
self.ready.guilds.iter().map(|status| status.id).collect::<HashSet<_>>();
let shard_data = self.ready.shard.unwrap_or([1, 1]);
let shard_data = self.ready.shard.unwrap_or_else(|| ShardInfo::new(1, 1));

for guild_entry in cache.guilds.iter() {
let guild = guild_entry.key();
// Only handle data for our shard.
if crate::utils::shard_id(guild.0, shard_data[1]) == shard_data[0]
if crate::utils::shard_id(guild.0, shard_data.total) == shard_data.id
&& !ready_guilds_hashset.contains(guild)
{
guilds_to_remove.push(*guild);
Expand All @@ -618,7 +620,7 @@ impl CacheUpdate for ReadyEvent {
cache.presences.insert(*user_id, presence.clone());
}

*cache.shard_count.write() = ready.shard.map_or(1, |s| s[1]);
*cache.shard_count.write() = ready.shard.map_or(1, |s| s.total);
*cache.user.write() = ready.user;

None
Expand Down
4 changes: 2 additions & 2 deletions src/cache/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ pub struct Cache {
/// other users.
pub(crate) private_channels: DashMap<ChannelId, PrivateChannel>,
/// The total number of shards being used by the bot.
pub(crate) shard_count: RwLock<u64>,
pub(crate) shard_count: RwLock<u32>,
/// A list of guilds which are "unavailable".
///
/// Additionally, guilds are always unavailable for bot users when a Ready
Expand Down Expand Up @@ -681,7 +681,7 @@ impl Cache {

/// Returns the number of shards.
#[inline]
pub fn shard_count(&self) -> u64 {
pub fn shard_count(&self) -> u32 {
*self.shard_count.read()
}

Expand Down
4 changes: 2 additions & 2 deletions src/client/bridge/gateway/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,10 +123,10 @@ pub enum ShardQueuerMessage {
ShutdownShard(ShardId, u16),
}

/// A light tuplestruct wrapper around a u64 to verify type correctness when
/// A light tuplestruct wrapper around a u32 to verify type correctness when
/// working with the IDs of shards.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct ShardId(pub u64);
pub struct ShardId(pub u32);

impl fmt::Display for ShardId {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
Expand Down
14 changes: 7 additions & 7 deletions src/client/bridge/gateway/shard_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,11 +105,11 @@ pub struct ShardManager {
/// where possible.
pub runners: Arc<Mutex<HashMap<ShardId, ShardRunnerInfo>>>,
/// The index of the first shard to initialize, 0-indexed.
shard_index: u64,
shard_index: u32,
/// The number of shards to initialize.
shard_init: u64,
shard_init: u32,
/// The total shards in use, 1-indexed.
shard_total: u64,
shard_total: u32,
shard_queuer: Sender<ShardQueuerMessage>,
shard_shutdown: Receiver<ShardId>,
}
Expand Down Expand Up @@ -196,7 +196,7 @@ impl ShardManager {
///
/// This will _not_ instantiate the new shards.
#[instrument(skip(self))]
pub async fn set_shards(&mut self, index: u64, init: u64, total: u64) {
pub async fn set_shards(&mut self, index: u32, init: u32, total: u32) {
self.shutdown_all().await;

self.shard_index = index;
Expand Down Expand Up @@ -351,9 +351,9 @@ pub struct ShardManagerOptions {
pub raw_event_handler: Option<Arc<dyn RawEventHandler>>,
#[cfg(feature = "framework")]
pub framework: Option<Arc<dyn Framework + Send + Sync>>,
pub shard_index: u64,
pub shard_init: u64,
pub shard_total: u64,
pub shard_index: u32,
pub shard_init: u32,
pub shard_total: u32,
#[cfg(feature = "voice")]
pub voice_manager: Option<Arc<dyn VoiceGatewayManager + Send + Sync + 'static>>,
pub ws_url: Arc<Mutex<String>>,
Expand Down
45 changes: 32 additions & 13 deletions src/client/bridge/gateway/shard_messenger.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,16 +59,19 @@ impl ShardMessenger {
///
/// ```rust,no_run
/// # use tokio::sync::Mutex;
/// # use serenity::model::gateway::GatewayIntents;
/// # use serenity::model::gateway::{GatewayIntents, ShardInfo};
/// # use serenity::client::bridge::gateway::ChunkGuildFilter;
/// # use serenity::gateway::Shard;
/// # use std::sync::Arc;
/// #
/// # async fn run() -> Result<(), Box<dyn std::error::Error>> {
/// # let mutex = Arc::new(Mutex::new("".to_string()));
/// #
/// # let mut shard = Shard::new(mutex.clone(), "", [0u64, 1u64],
/// # GatewayIntents::all()).await?;
/// # let shard_info = ShardInfo {
/// # id: 0,
/// # total: 1,
/// # };
/// # let mut shard = Shard::new(mutex.clone(), "", shard_info, GatewayIntents::all()).await?;
/// #
/// use serenity::model::id::GuildId;
///
Expand All @@ -82,16 +85,20 @@ impl ShardMessenger {
///
/// ```rust,no_run
/// # use tokio::sync::Mutex;
/// # use serenity::model::gateway::GatewayIntents;
/// # use serenity::model::gateway::{GatewayIntents, ShardInfo};
/// # use serenity::client::bridge::gateway::ChunkGuildFilter;
/// # use serenity::gateway::Shard;
/// # use std::sync::Arc;
/// #
/// # async fn run() -> Result<(), Box<dyn std::error::Error>> {
/// # let mutex = Arc::new(Mutex::new("".to_string()));
/// #
/// # let mut shard = Shard::new(mutex.clone(), "", [0u64, 1u64],
/// # GatewayIntents::all()).await?;
/// # let shard_info = ShardInfo {
/// # id: 0,
/// # total: 1,
/// # };
/// #
/// # let mut shard = Shard::new(mutex.clone(), "", shard_info, GatewayIntents::all()).await?;;
/// #
/// use serenity::model::id::GuildId;
///
Expand Down Expand Up @@ -130,14 +137,18 @@ impl ShardMessenger {
/// ```rust,no_run
/// # use tokio::sync::Mutex;
/// # use serenity::gateway::Shard;
/// # use serenity::model::gateway::GatewayIntents;
/// # use serenity::model::gateway::{GatewayIntents, ShardInfo};
/// # use std::sync::Arc;
/// #
/// # async fn run() -> Result<(), Box<dyn std::error::Error>> {
/// # let mutex = Arc::new(Mutex::new("".to_string()));
/// #
/// # let mut shard = Shard::new(mutex.clone(), "", [0u64, 1u64],
/// # GatewayIntents::all()).await?;
/// # let shard_info = ShardInfo {
/// # id: 0,
/// # total: 1,
/// # };
/// #
/// # let mut shard = Shard::new(mutex.clone(), "", shard_info, GatewayIntents::all()).await?;
/// use serenity::model::gateway::Activity;
///
/// shard.set_activity(Some(Activity::playing("Heroes of the Storm")));
Expand Down Expand Up @@ -166,7 +177,12 @@ impl ShardMessenger {
/// # async fn run() -> Result<(), Box<dyn std::error::Error>> {
/// # let mutex = Arc::new(Mutex::new("".to_string()));
/// #
/// # let mut shard = Shard::new(mutex.clone(), "", [0u64, 1u64], None).await?;
/// # let shard_info = ShardInfo {
/// # id: 0,
/// # total: 1,
/// # };
/// #
/// # let mut shard = Shard::new(mutex.clone(), "", shard_info, None).await?;
/// #
/// use serenity::model::gateway::Activity;
/// use serenity::model::user::OnlineStatus;
Expand Down Expand Up @@ -198,14 +214,17 @@ impl ShardMessenger {
/// ```rust,no_run
/// # use tokio::sync::Mutex;
/// # use serenity::gateway::Shard;
/// # use serenity::model::gateway::GatewayIntents;
/// # use serenity::model::gateway::{GatewayIntents, ShardInfo};
/// # use std::sync::Arc;
/// #
/// # async fn run() -> Result<(), Box<dyn std::error::Error>> {
/// # let mutex = Arc::new(Mutex::new("".to_string()));
/// # let shard_info = ShardInfo {
/// # id: 0,
/// # total: 1,
/// # };
/// #
/// # let mut shard = Shard::new(mutex.clone(), "", [0u64, 1u64],
/// # GatewayIntents::all()).await?;
/// # let mut shard = Shard::new(mutex.clone(), "", shard_info, GatewayIntents::all()).await?;
/// #
/// use serenity::model::user::OnlineStatus;
///
Expand Down
18 changes: 9 additions & 9 deletions src/client/bridge/gateway/shard_queuer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ use crate::framework::Framework;
use crate::gateway::{ConnectionStage, InterMessage, Shard};
use crate::internal::prelude::*;
use crate::internal::tokio::spawn_named;
use crate::model::gateway::GatewayIntents;
use crate::model::gateway::{GatewayIntents, ShardInfo};
use crate::CacheAndHttp;

const WAIT_BETWEEN_BOOTS_IN_SECONDS: u64 = 5;
Expand Down Expand Up @@ -68,7 +68,7 @@ pub struct ShardQueuer {
/// The shards that are queued for booting.
///
/// This will typically be filled with previously failed boots.
pub queue: VecDeque<(u64, u64)>,
pub queue: VecDeque<ShardInfo>,
/// A copy of the map of shard runners.
pub runners: Arc<Mutex<HashMap<ShardId, ShardRunnerInfo>>>,
/// A receiver channel for the shard queuer to be told to start shards.
Expand Down Expand Up @@ -125,8 +125,8 @@ impl ShardQueuer {
},
Ok(None) => break,
Err(_) => {
if let Some((id, total)) = self.queue.pop_front() {
self.checked_start(id, total).await;
if let Some(shard) = self.queue.pop_front() {
self.checked_start(shard.id, shard.total).await;
}
},
}
Expand Down Expand Up @@ -155,23 +155,23 @@ impl ShardQueuer {
}

#[instrument(skip(self))]
async fn checked_start(&mut self, id: u64, total: u64) {
async fn checked_start(&mut self, id: u32, total: u32) {
debug!("[Shard Queuer] Checked start for shard {} out of {}", id, total);
self.check_last_start().await;

if let Err(why) = self.start(id, total).await {
warn!("[Shard Queuer] Err starting shard {}: {:?}", id, why);
info!("[Shard Queuer] Re-queueing start of shard {}", id);

self.queue.push_back((id, total));
self.queue.push_back(ShardInfo::new(id, total));
}

self.last_start = Some(Instant::now());
}

#[instrument(skip(self))]
async fn start(&mut self, shard_id: u64, shard_total: u64) -> Result<()> {
let shard_info = [shard_id, shard_total];
async fn start(&mut self, id: u32, total: u32) -> Result<()> {
let shard_info = ShardInfo::new(id, total);

let mut shard = Shard::new(
Arc::clone(&self.ws_url),
Expand Down Expand Up @@ -207,7 +207,7 @@ impl ShardQueuer {
debug!("[ShardRunner {:?}] Stopping", runner.shard.shard_info());
});

self.runners.lock().await.insert(ShardId(shard_id), runner_info);
self.runners.lock().await.insert(ShardId(id), runner_info);

Ok(())
}
Expand Down
12 changes: 6 additions & 6 deletions src/client/bridge/gateway/shard_runner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ impl ShardRunner {
let e = ClientEvent::ShardStageUpdate(ShardStageUpdateEvent {
new: post,
old: pre,
shard_id: ShardId(self.shard.shard_info()[0]),
shard_id: ShardId(self.shard.shard_info().id),
});

self.dispatch(DispatchEvent::Client(e)).await;
Expand Down Expand Up @@ -275,7 +275,7 @@ impl ShardRunner {
async fn checked_shutdown(&mut self, id: ShardId, close_code: u16) -> bool {
// First verify the ID so we know for certain this runner is
// to shutdown.
if id.0 != self.shard.shard_info()[0] {
if id.0 != self.shard.shard_info().id {
// Not meant for this runner for some reason, don't
// shutdown.
return true;
Expand Down Expand Up @@ -331,7 +331,7 @@ impl ShardRunner {
&self.event_handler,
&self.raw_event_handler,
&self.runner_tx,
self.shard.shard_info()[0],
self.shard.shard_info().id,
Arc::clone(&self.cache_and_http),
)
.await;
Expand Down Expand Up @@ -469,7 +469,7 @@ impl ShardRunner {
match *event {
Event::Ready(_) => {
voice_manager
.register_shard(self.shard.shard_info()[0], self.runner_tx.clone())
.register_shard(self.shard.shard_info().id, self.runner_tx.clone())
.await;
},
Event::VoiceServerUpdate(ref event) => {
Expand Down Expand Up @@ -627,7 +627,7 @@ impl ShardRunner {
self.update_manager();

debug!("[ShardRunner {:?}] Requesting restart", self.shard.shard_info(),);
let shard_id = ShardId(self.shard.shard_info()[0]);
let shard_id = ShardId(self.shard.shard_info().id);
let msg = ShardManagerMessage::Restart(shard_id);

if let Err(error) = self.manager_tx.unbounded_send(msg) {
Expand All @@ -645,7 +645,7 @@ impl ShardRunner {
#[instrument(skip(self))]
fn update_manager(&self) {
drop(self.manager_tx.unbounded_send(ShardManagerMessage::ShardUpdate {
id: ShardId(self.shard.shard_info()[0]),
id: ShardId(self.shard.shard_info().id),
latency: self.shard.latency(),
stage: self.shard.stage(),
}));
Expand Down
6 changes: 3 additions & 3 deletions src/client/bridge/voice/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,23 +14,23 @@ pub trait VoiceGatewayManager: Send + Sync {
/// Performs initial setup at the start of a connection to Discord.
///
/// This will only occur once, and provides the bot's ID and shard count.
async fn initialise(&self, shard_count: u64, user_id: UserId);
async fn initialise(&self, shard_count: u32, user_id: UserId);

/// Handler fired in response to a [`Ready`] event.
///
/// This provides the voice plugin with a channel to send gateway messages to Discord,
/// once per active shard.
///
/// [`Ready`]: crate::model::event::Event
async fn register_shard(&self, shard_id: u64, sender: Sender<InterMessage>);
async fn register_shard(&self, shard_id: u32, sender: Sender<InterMessage>);

/// Handler fired in response to a disconnect, reconnection, or rebalance.
///
/// This event invalidates the last sender associated with `shard_id`.
/// Unless the bot is fully disconnecting, this is often followed by a call
/// to [`Self::register_shard`]. Users may wish to buffer manually any gateway messages
/// sent between these calls.
async fn deregister_shard(&self, shard_id: u64);
async fn deregister_shard(&self, shard_id: u32);

/// Handler for VOICE_SERVER_UPDATE messages.
///
Expand Down
8 changes: 4 additions & 4 deletions src/client/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ pub struct Context {
/// The messenger to communicate with the shard runner.
pub shard: ShardMessenger,
/// The ID of the shard this context is related to.
pub shard_id: u64,
pub shard_id: u32,
pub http: Arc<Http>,
#[cfg(feature = "cache")]
pub cache: Arc<Cache>,
Expand All @@ -51,7 +51,7 @@ impl Context {
pub(crate) fn new(
data: Arc<RwLock<TypeMap>>,
runner_tx: Sender<InterMessage>,
shard_id: u64,
shard_id: u32,
http: Arc<Http>,
cache: Arc<Cache>,
) -> Context {
Expand All @@ -65,7 +65,7 @@ impl Context {
}

#[cfg(all(not(feature = "cache"), not(feature = "gateway")))]
pub fn easy(data: Arc<RwLock<TypeMap>>, shard_id: u64, http: Arc<Http>) -> Context {
pub fn easy(data: Arc<RwLock<TypeMap>>, shard_id: u32, http: Arc<Http>) -> Context {
Context {
shard_id,
data,
Expand All @@ -78,7 +78,7 @@ impl Context {
pub(crate) fn new(
data: Arc<RwLock<TypeMap>>,
runner_tx: Sender<InterMessage>,
shard_id: u64,
shard_id: u32,
http: Arc<Http>,
) -> Context {
Context {
Expand Down
Loading

0 comments on commit 6e5867e

Please sign in to comment.