diff --git a/examples/e08_shard_manager/src/main.rs b/examples/e08_shard_manager/src/main.rs index 259c956ecc2..980a9af791f 100644 --- a/examples/e08_shard_manager/src/main.rs +++ b/examples/e08_shard_manager/src/main.rs @@ -39,7 +39,7 @@ impl EventHandler for Handler { // Note that array index 0 is 0-indexed, while index 1 is 1-indexed. // // This may seem unintuitive, but it models Discord's behaviour. - println!("{} is connected on shard {}/{}!", ready.user.name, shard[0], shard[1]); + println!("{} is connected on shard {}/{}!", ready.user.name, shard.id, shard.total); } } } diff --git a/src/cache/event.rs b/src/cache/event.rs index c8a71963e6b..9c55feda22e 100644 --- a/src/cache/event.rs +++ b/src/cache/event.rs @@ -30,6 +30,7 @@ use crate::model::event::{ UserUpdateEvent, VoiceStateUpdateEvent, }; +use crate::model::gateway::ShardInfo; use crate::model::guild::{Guild, Member, Role}; use crate::model::user::{CurrentUser, OnlineStatus}; use crate::model::voice::VoiceState; @@ -588,11 +589,12 @@ impl CacheUpdate for ReadyEvent { let mut guilds_to_remove = vec![]; let ready_guilds_hashset = self.ready.guilds.iter().map(|status| status.id).collect::>(); - let shard_data = self.ready.shard.unwrap_or([1, 1]); + let shard_data = self.ready.shard.unwrap_or_else(|| ShardInfo::new(1, 1)); + for guild_entry in cache.guilds.iter() { let guild = guild_entry.key(); // Only handle data for our shard. - if crate::utils::shard_id(guild.0, shard_data[1]) == shard_data[0] + if crate::utils::shard_id(guild.0, shard_data.total) == shard_data.id && !ready_guilds_hashset.contains(guild) { guilds_to_remove.push(*guild); @@ -618,7 +620,7 @@ impl CacheUpdate for ReadyEvent { cache.presences.insert(*user_id, presence.clone()); } - *cache.shard_count.write() = ready.shard.map_or(1, |s| s[1]); + *cache.shard_count.write() = ready.shard.map_or(1, |s| s.total); *cache.user.write() = ready.user; None diff --git a/src/cache/mod.rs b/src/cache/mod.rs index 078777f123f..eb014751955 100644 --- a/src/cache/mod.rs +++ b/src/cache/mod.rs @@ -175,7 +175,7 @@ pub struct Cache { /// other users. pub(crate) private_channels: DashMap, /// The total number of shards being used by the bot. - pub(crate) shard_count: RwLock, + pub(crate) shard_count: RwLock, /// A list of guilds which are "unavailable". /// /// Additionally, guilds are always unavailable for bot users when a Ready @@ -681,7 +681,7 @@ impl Cache { /// Returns the number of shards. #[inline] - pub fn shard_count(&self) -> u64 { + pub fn shard_count(&self) -> u32 { *self.shard_count.read() } diff --git a/src/client/bridge/gateway/mod.rs b/src/client/bridge/gateway/mod.rs index 59cb6409a2c..71055cf5307 100644 --- a/src/client/bridge/gateway/mod.rs +++ b/src/client/bridge/gateway/mod.rs @@ -123,10 +123,10 @@ pub enum ShardQueuerMessage { ShutdownShard(ShardId, u16), } -/// A light tuplestruct wrapper around a u64 to verify type correctness when +/// A light tuplestruct wrapper around a u32 to verify type correctness when /// working with the IDs of shards. #[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] -pub struct ShardId(pub u64); +pub struct ShardId(pub u32); impl fmt::Display for ShardId { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { diff --git a/src/client/bridge/gateway/shard_manager.rs b/src/client/bridge/gateway/shard_manager.rs index 0fced9310ff..2caf73d1e87 100644 --- a/src/client/bridge/gateway/shard_manager.rs +++ b/src/client/bridge/gateway/shard_manager.rs @@ -105,11 +105,11 @@ pub struct ShardManager { /// where possible. pub runners: Arc>>, /// The index of the first shard to initialize, 0-indexed. - shard_index: u64, + shard_index: u32, /// The number of shards to initialize. - shard_init: u64, + shard_init: u32, /// The total shards in use, 1-indexed. - shard_total: u64, + shard_total: u32, shard_queuer: Sender, shard_shutdown: Receiver, } @@ -196,7 +196,7 @@ impl ShardManager { /// /// This will _not_ instantiate the new shards. #[instrument(skip(self))] - pub async fn set_shards(&mut self, index: u64, init: u64, total: u64) { + pub async fn set_shards(&mut self, index: u32, init: u32, total: u32) { self.shutdown_all().await; self.shard_index = index; @@ -351,9 +351,9 @@ pub struct ShardManagerOptions { pub raw_event_handler: Option>, #[cfg(feature = "framework")] pub framework: Option>, - pub shard_index: u64, - pub shard_init: u64, - pub shard_total: u64, + pub shard_index: u32, + pub shard_init: u32, + pub shard_total: u32, #[cfg(feature = "voice")] pub voice_manager: Option>, pub ws_url: Arc>, diff --git a/src/client/bridge/gateway/shard_messenger.rs b/src/client/bridge/gateway/shard_messenger.rs index 2288ab7ea6a..2343f701c2d 100644 --- a/src/client/bridge/gateway/shard_messenger.rs +++ b/src/client/bridge/gateway/shard_messenger.rs @@ -59,7 +59,7 @@ impl ShardMessenger { /// /// ```rust,no_run /// # use tokio::sync::Mutex; - /// # use serenity::model::gateway::GatewayIntents; + /// # use serenity::model::gateway::{GatewayIntents, ShardInfo}; /// # use serenity::client::bridge::gateway::ChunkGuildFilter; /// # use serenity::gateway::Shard; /// # use std::sync::Arc; @@ -67,8 +67,11 @@ impl ShardMessenger { /// # async fn run() -> Result<(), Box> { /// # let mutex = Arc::new(Mutex::new("".to_string())); /// # - /// # let mut shard = Shard::new(mutex.clone(), "", [0u64, 1u64], - /// # GatewayIntents::all()).await?; + /// # let shard_info = ShardInfo { + /// # id: 0, + /// # total: 1, + /// # }; + /// # let mut shard = Shard::new(mutex.clone(), "", shard_info, GatewayIntents::all()).await?; /// # /// use serenity::model::id::GuildId; /// @@ -82,7 +85,7 @@ impl ShardMessenger { /// /// ```rust,no_run /// # use tokio::sync::Mutex; - /// # use serenity::model::gateway::GatewayIntents; + /// # use serenity::model::gateway::{GatewayIntents, ShardInfo}; /// # use serenity::client::bridge::gateway::ChunkGuildFilter; /// # use serenity::gateway::Shard; /// # use std::sync::Arc; @@ -90,8 +93,12 @@ impl ShardMessenger { /// # async fn run() -> Result<(), Box> { /// # let mutex = Arc::new(Mutex::new("".to_string())); /// # - /// # let mut shard = Shard::new(mutex.clone(), "", [0u64, 1u64], - /// # GatewayIntents::all()).await?; + /// # let shard_info = ShardInfo { + /// # id: 0, + /// # total: 1, + /// # }; + /// # + /// # let mut shard = Shard::new(mutex.clone(), "", shard_info, GatewayIntents::all()).await?;; /// # /// use serenity::model::id::GuildId; /// @@ -130,14 +137,18 @@ impl ShardMessenger { /// ```rust,no_run /// # use tokio::sync::Mutex; /// # use serenity::gateway::Shard; - /// # use serenity::model::gateway::GatewayIntents; + /// # use serenity::model::gateway::{GatewayIntents, ShardInfo}; /// # use std::sync::Arc; /// # /// # async fn run() -> Result<(), Box> { /// # let mutex = Arc::new(Mutex::new("".to_string())); /// # - /// # let mut shard = Shard::new(mutex.clone(), "", [0u64, 1u64], - /// # GatewayIntents::all()).await?; + /// # let shard_info = ShardInfo { + /// # id: 0, + /// # total: 1, + /// # }; + /// # + /// # let mut shard = Shard::new(mutex.clone(), "", shard_info, GatewayIntents::all()).await?; /// use serenity::model::gateway::Activity; /// /// shard.set_activity(Some(Activity::playing("Heroes of the Storm"))); @@ -166,7 +177,12 @@ impl ShardMessenger { /// # async fn run() -> Result<(), Box> { /// # let mutex = Arc::new(Mutex::new("".to_string())); /// # - /// # let mut shard = Shard::new(mutex.clone(), "", [0u64, 1u64], None).await?; + /// # let shard_info = ShardInfo { + /// # id: 0, + /// # total: 1, + /// # }; + /// # + /// # let mut shard = Shard::new(mutex.clone(), "", shard_info, None).await?; /// # /// use serenity::model::gateway::Activity; /// use serenity::model::user::OnlineStatus; @@ -198,14 +214,17 @@ impl ShardMessenger { /// ```rust,no_run /// # use tokio::sync::Mutex; /// # use serenity::gateway::Shard; - /// # use serenity::model::gateway::GatewayIntents; + /// # use serenity::model::gateway::{GatewayIntents, ShardInfo}; /// # use std::sync::Arc; /// # /// # async fn run() -> Result<(), Box> { /// # let mutex = Arc::new(Mutex::new("".to_string())); + /// # let shard_info = ShardInfo { + /// # id: 0, + /// # total: 1, + /// # }; /// # - /// # let mut shard = Shard::new(mutex.clone(), "", [0u64, 1u64], - /// # GatewayIntents::all()).await?; + /// # let mut shard = Shard::new(mutex.clone(), "", shard_info, GatewayIntents::all()).await?; /// # /// use serenity::model::user::OnlineStatus; /// diff --git a/src/client/bridge/gateway/shard_queuer.rs b/src/client/bridge/gateway/shard_queuer.rs index e6633eb01fa..73f37254d12 100644 --- a/src/client/bridge/gateway/shard_queuer.rs +++ b/src/client/bridge/gateway/shard_queuer.rs @@ -26,7 +26,7 @@ use crate::framework::Framework; use crate::gateway::{ConnectionStage, InterMessage, Shard}; use crate::internal::prelude::*; use crate::internal::tokio::spawn_named; -use crate::model::gateway::GatewayIntents; +use crate::model::gateway::{GatewayIntents, ShardInfo}; use crate::CacheAndHttp; const WAIT_BETWEEN_BOOTS_IN_SECONDS: u64 = 5; @@ -68,7 +68,7 @@ pub struct ShardQueuer { /// The shards that are queued for booting. /// /// This will typically be filled with previously failed boots. - pub queue: VecDeque<(u64, u64)>, + pub queue: VecDeque, /// A copy of the map of shard runners. pub runners: Arc>>, /// A receiver channel for the shard queuer to be told to start shards. @@ -125,8 +125,8 @@ impl ShardQueuer { }, Ok(None) => break, Err(_) => { - if let Some((id, total)) = self.queue.pop_front() { - self.checked_start(id, total).await; + if let Some(shard) = self.queue.pop_front() { + self.checked_start(shard.id, shard.total).await; } }, } @@ -155,7 +155,7 @@ impl ShardQueuer { } #[instrument(skip(self))] - async fn checked_start(&mut self, id: u64, total: u64) { + async fn checked_start(&mut self, id: u32, total: u32) { debug!("[Shard Queuer] Checked start for shard {} out of {}", id, total); self.check_last_start().await; @@ -163,15 +163,15 @@ impl ShardQueuer { warn!("[Shard Queuer] Err starting shard {}: {:?}", id, why); info!("[Shard Queuer] Re-queueing start of shard {}", id); - self.queue.push_back((id, total)); + self.queue.push_back(ShardInfo::new(id, total)); } self.last_start = Some(Instant::now()); } #[instrument(skip(self))] - async fn start(&mut self, shard_id: u64, shard_total: u64) -> Result<()> { - let shard_info = [shard_id, shard_total]; + async fn start(&mut self, id: u32, total: u32) -> Result<()> { + let shard_info = ShardInfo::new(id, total); let mut shard = Shard::new( Arc::clone(&self.ws_url), @@ -207,7 +207,7 @@ impl ShardQueuer { debug!("[ShardRunner {:?}] Stopping", runner.shard.shard_info()); }); - self.runners.lock().await.insert(ShardId(shard_id), runner_info); + self.runners.lock().await.insert(ShardId(id), runner_info); Ok(()) } diff --git a/src/client/bridge/gateway/shard_runner.rs b/src/client/bridge/gateway/shard_runner.rs index 32a2ac8de07..c7a2e6ef9a6 100644 --- a/src/client/bridge/gateway/shard_runner.rs +++ b/src/client/bridge/gateway/shard_runner.rs @@ -143,7 +143,7 @@ impl ShardRunner { let e = ClientEvent::ShardStageUpdate(ShardStageUpdateEvent { new: post, old: pre, - shard_id: ShardId(self.shard.shard_info()[0]), + shard_id: ShardId(self.shard.shard_info().id), }); self.dispatch(DispatchEvent::Client(e)).await; @@ -275,7 +275,7 @@ impl ShardRunner { async fn checked_shutdown(&mut self, id: ShardId, close_code: u16) -> bool { // First verify the ID so we know for certain this runner is // to shutdown. - if id.0 != self.shard.shard_info()[0] { + if id.0 != self.shard.shard_info().id { // Not meant for this runner for some reason, don't // shutdown. return true; @@ -331,7 +331,7 @@ impl ShardRunner { &self.event_handler, &self.raw_event_handler, &self.runner_tx, - self.shard.shard_info()[0], + self.shard.shard_info().id, Arc::clone(&self.cache_and_http), ) .await; @@ -469,7 +469,7 @@ impl ShardRunner { match *event { Event::Ready(_) => { voice_manager - .register_shard(self.shard.shard_info()[0], self.runner_tx.clone()) + .register_shard(self.shard.shard_info().id, self.runner_tx.clone()) .await; }, Event::VoiceServerUpdate(ref event) => { @@ -627,7 +627,7 @@ impl ShardRunner { self.update_manager(); debug!("[ShardRunner {:?}] Requesting restart", self.shard.shard_info(),); - let shard_id = ShardId(self.shard.shard_info()[0]); + let shard_id = ShardId(self.shard.shard_info().id); let msg = ShardManagerMessage::Restart(shard_id); if let Err(error) = self.manager_tx.unbounded_send(msg) { @@ -645,7 +645,7 @@ impl ShardRunner { #[instrument(skip(self))] fn update_manager(&self) { drop(self.manager_tx.unbounded_send(ShardManagerMessage::ShardUpdate { - id: ShardId(self.shard.shard_info()[0]), + id: ShardId(self.shard.shard_info().id), latency: self.shard.latency(), stage: self.shard.stage(), })); diff --git a/src/client/bridge/voice/mod.rs b/src/client/bridge/voice/mod.rs index dc82a0ba190..e81a58b7e2a 100644 --- a/src/client/bridge/voice/mod.rs +++ b/src/client/bridge/voice/mod.rs @@ -14,7 +14,7 @@ pub trait VoiceGatewayManager: Send + Sync { /// Performs initial setup at the start of a connection to Discord. /// /// This will only occur once, and provides the bot's ID and shard count. - async fn initialise(&self, shard_count: u64, user_id: UserId); + async fn initialise(&self, shard_count: u32, user_id: UserId); /// Handler fired in response to a [`Ready`] event. /// @@ -22,7 +22,7 @@ pub trait VoiceGatewayManager: Send + Sync { /// once per active shard. /// /// [`Ready`]: crate::model::event::Event - async fn register_shard(&self, shard_id: u64, sender: Sender); + async fn register_shard(&self, shard_id: u32, sender: Sender); /// Handler fired in response to a disconnect, reconnection, or rebalance. /// @@ -30,7 +30,7 @@ pub trait VoiceGatewayManager: Send + Sync { /// Unless the bot is fully disconnecting, this is often followed by a call /// to [`Self::register_shard`]. Users may wish to buffer manually any gateway messages /// sent between these calls. - async fn deregister_shard(&self, shard_id: u64); + async fn deregister_shard(&self, shard_id: u32); /// Handler for VOICE_SERVER_UPDATE messages. /// diff --git a/src/client/context.rs b/src/client/context.rs index 3fb550e0ddd..e2bce3186cc 100644 --- a/src/client/context.rs +++ b/src/client/context.rs @@ -39,7 +39,7 @@ pub struct Context { /// The messenger to communicate with the shard runner. pub shard: ShardMessenger, /// The ID of the shard this context is related to. - pub shard_id: u64, + pub shard_id: u32, pub http: Arc, #[cfg(feature = "cache")] pub cache: Arc, @@ -51,7 +51,7 @@ impl Context { pub(crate) fn new( data: Arc>, runner_tx: Sender, - shard_id: u64, + shard_id: u32, http: Arc, cache: Arc, ) -> Context { @@ -65,7 +65,7 @@ impl Context { } #[cfg(all(not(feature = "cache"), not(feature = "gateway")))] - pub fn easy(data: Arc>, shard_id: u64, http: Arc) -> Context { + pub fn easy(data: Arc>, shard_id: u32, http: Arc) -> Context { Context { shard_id, data, @@ -78,7 +78,7 @@ impl Context { pub(crate) fn new( data: Arc>, runner_tx: Sender, - shard_id: u64, + shard_id: u32, http: Arc, ) -> Context { Context { diff --git a/src/client/dispatch.rs b/src/client/dispatch.rs index 25b95c1b66e..0aca33f10b8 100644 --- a/src/client/dispatch.rs +++ b/src/client/dispatch.rs @@ -46,7 +46,7 @@ fn update(_cache_and_http: &Arc, _event: &mut E) -> Option<()> fn context( data: &Arc>, runner_tx: &Sender, - shard_id: u64, + shard_id: u32, http: &Arc, cache: &Arc, ) -> Context { @@ -57,7 +57,7 @@ fn context( fn context( data: &Arc>, runner_tx: &Sender, - shard_id: u64, + shard_id: u32, http: &Arc, ) -> Context { Context::new(Arc::clone(data), runner_tx.clone(), shard_id, Arc::clone(http)) @@ -163,7 +163,7 @@ pub(crate) fn dispatch<'rec>( event_handler: &'rec Option>, raw_event_handler: &'rec Option>, runner_tx: &'rec Sender, - shard_id: u64, + shard_id: u32, cache_and_http: Arc, ) -> BoxFuture<'rec, ()> { async move { @@ -340,7 +340,7 @@ async fn handle_event( data: &Arc>, event_handler: &Arc, runner_tx: &Sender, - shard_id: u64, + shard_id: u32, cache_and_http: Arc, ) { #[cfg(not(feature = "cache"))] diff --git a/src/client/mod.rs b/src/client/mod.rs index b41ba1833ea..940baf04e49 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -26,6 +26,7 @@ mod error; mod event_handler; use std::future::Future; +use std::ops::Range; use std::pin::Pin; use std::sync::Arc; use std::task::{Context as FutContext, Poll}; @@ -707,7 +708,7 @@ impl Client { /// [gateway docs]: crate::gateway#sharding #[instrument(skip(self))] pub async fn start(&mut self) -> Result<()> { - self.start_connection([0, 0, 1]).await + self.start_connection(0, 0, 1).await } /// Establish the connection(s) and start listening for events. @@ -755,13 +756,13 @@ impl Client { /// [gateway docs]: crate::gateway#sharding #[instrument(skip(self))] pub async fn start_autosharded(&mut self) -> Result<()> { - let (x, y) = { + let (end, total) = { let res = self.cache_and_http.http.get_bot_gateway().await?; (res.shards - 1, res.shards) }; - self.start_connection([0, x, y]).await + self.start_connection(0, end, total).await } /// Establish a sharded connection and start listening for events. @@ -832,8 +833,8 @@ impl Client { /// /// [gateway docs]: crate::gateway#sharding #[instrument(skip(self))] - pub async fn start_shard(&mut self, shard: u64, shards: u64) -> Result<()> { - self.start_connection([shard, shard, shards]).await + pub async fn start_shard(&mut self, shard: u32, shards: u32) -> Result<()> { + self.start_connection(shard, shard, shards).await } /// Establish sharded connections and start listening for events. @@ -880,8 +881,8 @@ impl Client { /// /// [Gateway docs]: crate::gateway#sharding #[instrument(skip(self))] - pub async fn start_shards(&mut self, total_shards: u64) -> Result<()> { - self.start_connection([0, total_shards - 1, total_shards]).await + pub async fn start_shards(&mut self, total_shards: u32) -> Result<()> { + self.start_connection(0, total_shards - 1, total_shards).await } /// Establish a range of sharded connections and start listening for events. @@ -915,7 +916,7 @@ impl Client { /// let mut client = /// Client::builder(&token, GatewayIntents::default()).event_handler(Handler).await?; /// - /// if let Err(why) = client.start_shard_range([4, 7], 10).await { + /// if let Err(why) = client.start_shard_range(4..7, 10).await { /// println!("Err with client: {:?}", why); /// } /// # Ok(()) @@ -929,8 +930,8 @@ impl Client { /// /// [Gateway docs]: crate::gateway#sharding #[instrument(skip(self))] - pub async fn start_shard_range(&mut self, range: [u64; 2], total_shards: u64) -> Result<()> { - self.start_connection([range[0], range[1], total_shards]).await + pub async fn start_shard_range(&mut self, range: Range, total_shards: u32) -> Result<()> { + self.start_connection(range.start, range.end, total_shards).await } /// Shard data layout is: @@ -945,22 +946,27 @@ impl Client { /// Returns a [`ClientError::Shutdown`] when all shards have shutdown due to /// an error. #[instrument(skip(self))] - async fn start_connection(&mut self, shard_data: [u64; 3]) -> Result<()> { + async fn start_connection( + &mut self, + start_shard: u32, + end_shard: u32, + total_shards: u32, + ) -> Result<()> { #[cfg(feature = "voice")] if let Some(voice_manager) = &self.voice_manager { let user = self.cache_and_http.http.get_current_user().await?; - voice_manager.initialise(shard_data[2], user.id).await; + voice_manager.initialise(total_shards, user.id).await; } { let mut manager = self.shard_manager.lock().await; - let init = shard_data[1] - shard_data[0] + 1; + let init = end_shard - start_shard + 1; - manager.set_shards(shard_data[0], init, shard_data[2]).await; + manager.set_shards(start_shard, init, total_shards).await; - debug!("Initializing shard info: {} - {}/{}", shard_data[0], init, shard_data[2],); + debug!("Initializing shard info: {} - {}/{}", start_shard, init, total_shards); if let Err(why) = manager.initialize() { error!("Failed to boot a shard: {:?}", why); diff --git a/src/gateway/shard.rs b/src/gateway/shard.rs index 3b5093f1117..a580e2021dc 100644 --- a/src/gateway/shard.rs +++ b/src/gateway/shard.rs @@ -13,7 +13,7 @@ use crate::constants::{self, close_codes}; use crate::http::Http; use crate::internal::prelude::*; use crate::model::event::{Event, GatewayEvent}; -use crate::model::gateway::{Activity, GatewayIntents}; +use crate::model::gateway::{Activity, GatewayIntents, ShardInfo}; use crate::model::id::GuildId; use crate::model::user::OnlineStatus; @@ -70,7 +70,7 @@ pub struct Shard { last_heartbeat_acknowledged: bool, seq: u64, session_id: Option, - shard_info: [u64; 2], + shard_info: ShardInfo, stage: ConnectionStage, /// Instant of when the shard was started. // This acts as a timeout to determine if the shard has - for some reason - @@ -98,14 +98,19 @@ impl Shard { /// use tokio::sync::Mutex; /// # /// # use serenity::http::Http; - /// # use serenity::model::gateway::GatewayIntents; + /// # use serenity::model::gateway::{GatewayIntents, ShardInfo}; /// # /// # async fn run() -> Result<(), Box> { /// # let http = Arc::new(Http::new("token")); /// let token = std::env::var("DISCORD_BOT_TOKEN")?; + /// let shard_info = ShardInfo { + /// id: 0, + /// total: 1, + /// }; + /// /// // retrieve the gateway response, which contains the URL to connect to /// let gateway = Arc::new(Mutex::new(http.get_gateway().await?.url)); - /// let shard = Shard::new(gateway, &token, [0u64, 1u64], GatewayIntents::all()).await?; + /// let shard = Shard::new(gateway, &token, shard_info, GatewayIntents::all()).await?; /// /// // at this point, you can create a `loop`, and receive events and match /// // their variants @@ -120,7 +125,7 @@ impl Shard { pub async fn new( ws_url: Arc>, token: &str, - shard_info: [u64; 2], + shard_info: ShardInfo, intents: GatewayIntents, ) -> Result { let url = ws_url.lock().await.clone(); @@ -266,41 +271,9 @@ impl Shard { /// Retrieves a copy of the current shard information. /// - /// The first element is the _current_ shard - 0-indexed - while the second - /// element is the _total number_ of shards -- 1-indexed. - /// /// For example, if using 3 shards in total, and if this is shard 1, then it /// can be read as "the second of three shards". - /// - /// # Examples - /// - /// Retrieving the shard info for the second shard, out of two shards total: - /// - /// For example, if using 3 shards in total, and if this is shard 1, then it - /// can be read as "the second of three shards". - /// - /// # Examples - /// - /// Retrieving the shard info for the second shard, out of two shards total: - /// - /// ```rust,no_run - /// # use serenity::gateway::Shard; - /// # use serenity::prelude::Mutex; - /// # use serenity::model::gateway::GatewayIntents; - /// # use std::sync::Arc; - /// # - /// # #[cfg(feature = "model")] - /// # async fn run() { - /// # - /// # let mutex = Arc::new(Mutex::new("".to_string())); - /// # - /// # let mut shard = Shard::new(mutex.clone(), "", [0u64, 1u64], - /// # GatewayIntents::all()).await.unwrap(); - /// # - /// assert_eq!(shard.shard_info(), [1, 2]); - /// # } - /// ``` - pub fn shard_info(&self) -> [u64; 2] { + pub fn shard_info(&self) -> ShardInfo { self.shard_info } @@ -652,15 +625,19 @@ impl Shard { /// /// ```rust,no_run /// # use tokio::sync::Mutex; - /// # use serenity::model::gateway::GatewayIntents; + /// # use serenity::model::gateway::{GatewayIntents, ShardInfo}; /// # use serenity::client::bridge::gateway::ChunkGuildFilter; /// # use serenity::gateway::Shard; /// # use std::sync::Arc; /// # /// # async fn run() -> Result<(), Box> { /// # let mutex = Arc::new(Mutex::new("".to_string())); + /// # let shard_info = ShardInfo { + /// # id: 0, + /// # total: 1, + /// # }; /// # - /// # let mut shard = Shard::new(mutex.clone(), "", [0u64, 1u64], GatewayIntents::all()).await?; + /// # let mut shard = Shard::new(mutex.clone(), "", shard_info, GatewayIntents::all()).await?; /// # /// use serenity::model::id::GuildId; /// @@ -675,7 +652,7 @@ impl Shard { /// ```rust,no_run /// # use tokio::sync::Mutex; /// # use serenity::gateway::Shard; - /// # use serenity::model::gateway::GatewayIntents; + /// # use serenity::model::gateway::{GatewayIntents, ShardInfo}; /// # use serenity::client::bridge::gateway::ChunkGuildFilter; /// # use std::error::Error; /// # use std::sync::Arc; @@ -683,8 +660,11 @@ impl Shard { /// # async fn run() -> Result<(), Box> { /// # let mutex = Arc::new(Mutex::new("".to_string())); /// # - /// # let mut shard = Shard::new(mutex.clone(), "", [0u64, 1u64], - /// # GatewayIntents::all()).await?; + /// # let shard_info = ShardInfo { + /// # id: 0, + /// # total: 1, + /// # }; + /// # let mut shard = Shard::new(mutex.clone(), "", shard_info, GatewayIntents::all()).await?; /// # /// use serenity::model::id::GuildId; /// diff --git a/src/gateway/ws.rs b/src/gateway/ws.rs index dbec352d3a6..0384e89fd29 100644 --- a/src/gateway/ws.rs +++ b/src/gateway/ws.rs @@ -17,7 +17,7 @@ use crate::constants::{self, Opcode}; use crate::gateway::{CurrentPresence, GatewayError}; use crate::json::{from_str, to_string}; use crate::model::event::GatewayEvent; -use crate::model::gateway::{ActivityType, GatewayIntents}; +use crate::model::gateway::{ActivityType, GatewayIntents, ShardInfo}; use crate::model::id::{GuildId, UserId}; use crate::{Error, Result}; @@ -54,9 +54,9 @@ enum WebSocketMessageData<'a> { ChunkGuild(ChunkGuildMessage<'a>), Identify { compress: bool, - shard: &'a [u64; 2], token: &'a str, large_threshold: u8, + shard: &'a ShardInfo, intents: GatewayIntents, properties: IdentifyProperties, }, @@ -163,7 +163,7 @@ impl WsClient { pub async fn send_chunk_guild( &mut self, guild_id: GuildId, - shard_info: &[u64; 2], + shard_info: &ShardInfo, limit: Option, filter: ChunkGuildFilter, nonce: Option<&str>, @@ -190,7 +190,7 @@ impl WsClient { } #[instrument(skip(self))] - pub async fn send_heartbeat(&mut self, shard_info: &[u64; 2], seq: Option) -> Result<()> { + pub async fn send_heartbeat(&mut self, shard_info: &ShardInfo, seq: Option) -> Result<()> { trace!("[Shard {:?}] Sending heartbeat d: {:?}", shard_info, seq); self.send_json(&WebSocketMessage { @@ -203,7 +203,7 @@ impl WsClient { #[instrument(skip(self, token))] pub async fn send_identify( &mut self, - shard: &[u64; 2], + shard: &ShardInfo, token: &str, intents: GatewayIntents, ) -> Result<()> { @@ -231,7 +231,7 @@ impl WsClient { #[instrument(skip(self))] pub async fn send_presence_update( &mut self, - shard_info: &[u64; 2], + shard_info: &ShardInfo, current_presence: &CurrentPresence, ) -> Result<()> { let (activity, status) = current_presence; @@ -258,7 +258,7 @@ impl WsClient { #[instrument(skip(self, token))] pub async fn send_resume( &mut self, - shard_info: &[u64; 2], + shard_info: &ShardInfo, session_id: &str, seq: u64, token: &str, diff --git a/src/model/gateway.rs b/src/model/gateway.rs index 7df67cb225e..f1371ca34a4 100644 --- a/src/model/gateway.rs +++ b/src/model/gateway.rs @@ -1,5 +1,6 @@ //! Models pertaining to the gateway. +use serde::ser::SerializeSeq; use url::Url; use super::prelude::*; @@ -21,7 +22,7 @@ pub struct BotGateway { pub session_start_limit: SessionStartLimit, /// The number of shards that is recommended to be used by the current bot /// user. - pub shards: u64, + pub shards: u32, /// The gateway to connect to. pub url: String, } @@ -512,7 +513,7 @@ pub struct Ready { #[serde(default, with = "private_channels")] pub private_channels: HashMap, pub session_id: String, - pub shard: Option<[u64; 2]>, + pub shard: Option, #[serde(default, rename = "_trace")] pub trace: Vec, pub user: CurrentUser, @@ -537,6 +538,42 @@ pub struct SessionStartLimit { /// The number of identify requests allowed per 5 seconds. pub max_concurrency: u64, } + +#[derive(Clone, Copy, Debug)] +pub struct ShardInfo { + pub id: u32, + pub total: u32, +} + +impl ShardInfo { + #[cfg(feature = "client")] + #[must_use] + pub(crate) fn new(id: u32, total: u32) -> Self { + Self { + id, + total, + } + } +} + +impl<'de> serde::Deserialize<'de> for ShardInfo { + fn deserialize>(deserializer: D) -> StdResult { + <(u32, u32)>::deserialize(deserializer).map(|(id, total)| ShardInfo { + id, + total, + }) + } +} + +impl serde::Serialize for ShardInfo { + fn serialize(&self, serializer: S) -> StdResult { + let mut seq = serializer.serialize_seq(Some(2))?; + seq.serialize_element(&self.id)?; + seq.serialize_element(&self.total)?; + seq.end() + } +} + /// Timestamps of when a user started and/or is ending their activity. /// /// [Discord docs](https://discord.com/developers/docs/game-sdk/activities#data-models-activitytimestamps-struct). diff --git a/src/model/guild/guild_id.rs b/src/model/guild/guild_id.rs index 364edf5443c..1f2e1ea9b65 100644 --- a/src/model/guild/guild_id.rs +++ b/src/model/guild/guild_id.rs @@ -1331,8 +1331,8 @@ impl GuildId { #[cfg(all(feature = "cache", feature = "utils"))] #[inline] #[must_use] - pub fn shard_id(self, cache: impl AsRef) -> u64 { - crate::utils::shard_id(self.get(), cache.as_ref().shard_count()) + pub fn shard_id(self, cache: impl AsRef) -> u32 { + crate::utils::shard_id(self, cache.as_ref().shard_count()) } /// Returns the Id of the shard associated with the guild. @@ -1359,8 +1359,8 @@ impl GuildId { #[cfg(all(feature = "utils", not(feature = "cache")))] #[inline] #[must_use] - pub fn shard_id(self, shard_count: u64) -> u64 { - crate::utils::shard_id(self.get(), shard_count) + pub fn shard_id(self, shard_count: u32) -> u32 { + crate::utils::shard_id(self, shard_count) } /// Starts an integration sync for the given integration Id. diff --git a/src/model/guild/mod.rs b/src/model/guild/mod.rs index d13b0e392ab..ac72d128448 100644 --- a/src/model/guild/mod.rs +++ b/src/model/guild/mod.rs @@ -2464,7 +2464,7 @@ impl Guild { /// [`utils::shard_id`]: crate::utils::shard_id #[cfg(all(feature = "cache", feature = "utils"))] #[inline] - pub fn shard_id(&self, cache: impl AsRef) -> u64 { + pub fn shard_id(&self, cache: impl AsRef) -> u32 { self.id.shard_id(&cache) } @@ -2491,7 +2491,7 @@ impl Guild { #[cfg(all(feature = "utils", not(feature = "cache")))] #[inline] #[must_use] - pub fn shard_id(&self, shard_count: u64) -> u64 { + pub fn shard_id(&self, shard_count: u32) -> u32 { self.id.shard_id(shard_count) } diff --git a/src/model/guild/partial_guild.rs b/src/model/guild/partial_guild.rs index e1bd3df1624..e498d24a79a 100644 --- a/src/model/guild/partial_guild.rs +++ b/src/model/guild/partial_guild.rs @@ -1468,7 +1468,7 @@ impl PartialGuild { #[cfg(all(feature = "cache", feature = "utils"))] #[inline] #[must_use] - pub fn shard_id(&self, cache: impl AsRef) -> u64 { + pub fn shard_id(&self, cache: impl AsRef) -> u32 { self.id.shard_id(cache) } @@ -1495,7 +1495,7 @@ impl PartialGuild { #[cfg(all(feature = "utils", not(feature = "cache")))] #[inline] #[must_use] - pub fn shard_id(&self, shard_count: u64) -> u64 { + pub fn shard_id(&self, shard_count: u32) -> u32 { self.id.shard_id(shard_count) } diff --git a/src/model/invite.rs b/src/model/invite.rs index 0d5a1b212cd..00f10baa7a4 100644 --- a/src/model/invite.rs +++ b/src/model/invite.rs @@ -259,7 +259,7 @@ impl InviteGuild { #[cfg(all(feature = "cache", feature = "utils"))] #[inline] #[must_use] - pub fn shard_id(&self, cache: impl AsRef) -> u64 { + pub fn shard_id(&self, cache: impl AsRef) -> u32 { self.id.shard_id(&cache) } @@ -286,7 +286,7 @@ impl InviteGuild { #[cfg(all(feature = "utils", not(feature = "cache")))] #[inline] #[must_use] - pub fn shard_id(&self, shard_count: u64) -> u64 { + pub fn shard_id(&self, shard_count: u32) -> u32 { self.id.shard_id(shard_count) } } diff --git a/src/utils/mod.rs b/src/utils/mod.rs index 7d38a195081..ac8f6e33ed8 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -471,8 +471,8 @@ pub fn parse_webhook(url: &Url) -> Option<(u64, &str)> { /// assert_eq!(utils::shard_id(81384788765712384 as u64, 17), 7); /// ``` #[inline] -pub fn shard_id(guild_id: impl Into, shard_count: u64) -> u64 { - (guild_id.into() >> 22) % shard_count +pub fn shard_id(guild_id: impl Into, shard_count: u32) -> u32 { + ((guild_id.into() >> 22) % (shard_count as u64)) as u32 } #[cfg(test)]