From 0746543f8dd880703dbd9ff1d087bfe2ad32bf45 Mon Sep 17 00:00:00 2001 From: Michael Krasnitski <42564254+mkrasnitski@users.noreply.github.com> Date: Tue, 16 Jan 2024 10:32:33 -0500 Subject: [PATCH] Implement `max_concurrency` support when starting shards (#2661) When a session is first kicked off and we need to start many shards, we can use the value of `max_concurrency` to kick off multiple shards in parallel. --- src/client/mod.rs | 11 +- src/gateway/bridge/mod.rs | 6 +- src/gateway/bridge/shard_manager.rs | 28 +++-- src/gateway/bridge/shard_queuer.rs | 156 ++++++++++++++++++++++------ src/model/gateway.rs | 5 +- 5 files changed, 156 insertions(+), 50 deletions(-) diff --git a/src/client/mod.rs b/src/client/mod.rs index 2e9a5f263a2..ed3e7b94c34 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -333,11 +333,15 @@ impl IntoFuture for ClientBuilder { let cache = Arc::new(Cache::new_with_settings(self.cache_settings)); Box::pin(async move { - let (ws_url, shard_total) = match http.get_bot_gateway().await { - Ok(response) => (Arc::from(response.url), response.shards), + let (ws_url, shard_total, max_concurrency) = match http.get_bot_gateway().await { + Ok(response) => ( + Arc::from(response.url), + response.shards, + response.session_start_limit.max_concurrency, + ), Err(err) => { tracing::warn!("HTTP request to get gateway URL failed: {err}"); - (Arc::from("wss://gateway.discord.gg"), NonZeroU16::MIN) + (Arc::from("wss://gateway.discord.gg"), NonZeroU16::MIN, NonZeroU16::MIN) }, }; @@ -358,6 +362,7 @@ impl IntoFuture for ClientBuilder { http: Arc::clone(&http), intents, presence: Some(presence), + max_concurrency, }); let client = Client { diff --git a/src/gateway/bridge/mod.rs b/src/gateway/bridge/mod.rs index 4fc51c1cd87..ed0621933ce 100644 --- a/src/gateway/bridge/mod.rs +++ b/src/gateway/bridge/mod.rs @@ -56,7 +56,7 @@ use std::time::Duration as StdDuration; pub use self::event::ShardStageUpdateEvent; pub use self::shard_manager::{ShardManager, ShardManagerOptions}; pub use self::shard_messenger::ShardMessenger; -pub use self::shard_queuer::ShardQueuer; +pub use self::shard_queuer::{ShardQueue, ShardQueuer}; pub use self::shard_runner::{ShardRunner, ShardRunnerOptions}; pub use self::shard_runner_message::ShardRunnerMessage; #[cfg(feature = "voice")] @@ -72,11 +72,11 @@ pub enum ShardQueuerMessage { /// Message to set the shard total. SetShardTotal(NonZeroU16), /// Message to start a shard. - Start(ShardId), + Start { shard_id: ShardId, concurrent: bool }, /// Message to shutdown the shard queuer. Shutdown, /// Message to dequeue/shutdown a shard. - ShutdownShard(ShardId, u16), + ShutdownShard { shard_id: ShardId, code: u16 }, } /// Information about a [`ShardRunner`]. diff --git a/src/gateway/bridge/shard_manager.rs b/src/gateway/bridge/shard_manager.rs index 0ba18141452..310c7d0c20a 100644 --- a/src/gateway/bridge/shard_manager.rs +++ b/src/gateway/bridge/shard_manager.rs @@ -1,4 +1,4 @@ -use std::collections::{HashMap, VecDeque}; +use std::collections::HashMap; use std::num::NonZeroU16; use std::sync::Arc; #[cfg(feature = "framework")] @@ -14,7 +14,7 @@ use typemap_rev::TypeMap; #[cfg(feature = "voice")] use super::VoiceGatewayManager; -use super::{ShardId, ShardQueuer, ShardQueuerMessage, ShardRunnerInfo}; +use super::{ShardId, ShardQueue, ShardQueuer, ShardQueuerMessage, ShardRunnerInfo}; #[cfg(feature = "cache")] use crate::cache::Cache; use crate::client::{EventHandler, RawEventHandler}; @@ -71,6 +71,7 @@ use crate::model::gateway::GatewayIntents; /// let ws_url = Arc::from(gateway_info.url); /// let data = Arc::new(RwLock::new(TypeMap::new())); /// let event_handler = Arc::new(Handler) as Arc; +/// let max_concurrency = std::num::NonZeroU16::MIN; /// let framework = Arc::new(StandardFramework::new()) as Arc; /// /// ShardManager::new(ShardManagerOptions { @@ -87,6 +88,7 @@ use crate::model::gateway::GatewayIntents; /// # http, /// intents: GatewayIntents::non_privileged(), /// presence: None, +/// max_concurrency, /// }); /// # Ok(()) /// # } @@ -137,7 +139,7 @@ impl ShardManager { framework: opt.framework, last_start: None, manager: Arc::clone(&manager), - queue: VecDeque::new(), + queue: ShardQueue::new(opt.max_concurrency), runners, rx: shard_queue_rx, #[cfg(feature = "voice")] @@ -176,7 +178,7 @@ impl ShardManager { self.set_shard_total(shard_total); for shard_id in shard_index..shard_to { - self.boot(ShardId(shard_id)); + self.boot(ShardId(shard_id), true); } } @@ -204,7 +206,7 @@ impl ShardManager { pub async fn restart(&self, shard_id: ShardId) { info!("Restarting shard {shard_id}"); self.shutdown(shard_id, 4000).await; - self.boot(shard_id); + self.boot(shard_id, false); } /// Returns the [`ShardId`]s of the shards that have been instantiated and currently have a @@ -233,9 +235,10 @@ impl ShardManager { { let mut shard_shutdown = self.shard_shutdown.lock().await; - drop( - self.shard_queuer.unbounded_send(ShardQueuerMessage::ShutdownShard(shard_id, code)), - ); + drop(self.shard_queuer.unbounded_send(ShardQueuerMessage::ShutdownShard { + shard_id, + code, + })); match timeout(TIMEOUT, shard_shutdown.next()).await { Ok(Some(shutdown_shard_id)) => { if shutdown_shard_id != shard_id { @@ -299,11 +302,13 @@ impl ShardManager { } #[cfg_attr(feature = "tracing_instrument", instrument(skip(self)))] - fn boot(&self, shard_id: ShardId) { + fn boot(&self, shard_id: ShardId, concurrent: bool) { info!("Telling shard queuer to start shard {shard_id}"); - let msg = ShardQueuerMessage::Start(shard_id); - drop(self.shard_queuer.unbounded_send(msg)); + drop(self.shard_queuer.unbounded_send(ShardQueuerMessage::Start { + shard_id, + concurrent, + })); } /// Returns the gateway intents used for this gateway connection. @@ -371,4 +376,5 @@ pub struct ShardManagerOptions { pub http: Arc, pub intents: GatewayIntents, pub presence: Option, + pub max_concurrency: NonZeroU16, } diff --git a/src/gateway/bridge/shard_queuer.rs b/src/gateway/bridge/shard_queuer.rs index 7b60cae7580..7b64689e9a3 100644 --- a/src/gateway/bridge/shard_queuer.rs +++ b/src/gateway/bridge/shard_queuer.rs @@ -62,9 +62,7 @@ pub struct ShardQueuer { /// A copy of the [`ShardManager`] to communicate with it. pub manager: Arc, /// The shards that are queued for booting. - /// - /// This will typically be filled with previously failed boots. - pub queue: VecDeque, + pub queue: ShardQueue, /// A copy of the map of shard runners. pub runners: Arc>>, /// A receiver channel for the shard queuer to be told to start shards. @@ -102,35 +100,60 @@ impl ShardQueuer { /// **Note**: This should be run in its own thread due to the blocking nature of the loop. #[cfg_attr(feature = "tracing_instrument", instrument(skip(self)))] pub async fn run(&mut self) { - // The duration to timeout from reads over the Rx channel. This can be done in a loop, and - // if the read times out then a shard can be started if one is presently waiting in the - // queue. + // We read from the Rx channel in a loop, and use a timeout of 5 seconds so that we don't + // hang forever. When we receive a command to start a shard, we append it to our queue. The + // queue is popped in batches of shards, which are started in parallel. A batch is fired + // every 5 seconds at minimum in order to avoid being ratelimited. const TIMEOUT: Duration = Duration::from_secs(WAIT_BETWEEN_BOOTS_IN_SECONDS); loop { - match timeout(TIMEOUT, self.rx.next()).await { - Ok(Some(ShardQueuerMessage::Shutdown)) => { - debug!("[Shard Queuer] Received to shutdown."); - self.shutdown_runners().await; - - break; - }, - Ok(Some(ShardQueuerMessage::ShutdownShard(shard, code))) => { - debug!("[Shard Queuer] Received to shutdown shard {} with {}.", shard.0, code); - self.shutdown(shard, code).await; - }, - Ok(Some(ShardQueuerMessage::Start(shard_id))) => { - self.checked_start(shard_id).await; - }, - Ok(Some(ShardQueuerMessage::SetShardTotal(shard_total))) => { - self.shard_total = shard_total; - }, - Ok(None) => break, - Err(_) => { - if let Some(shard) = self.queue.pop_front() { - self.checked_start(shard).await; - } - }, + if let Ok(msg) = timeout(TIMEOUT, self.rx.next()).await { + match msg { + Some(ShardQueuerMessage::SetShardTotal(shard_total)) => { + self.shard_total = shard_total; + }, + Some(ShardQueuerMessage::Start { + shard_id, + concurrent, + }) => { + if concurrent { + // If we're starting multiple shards, we can start them concurrently + // according to `max_concurrency`, and want our batches to be of + // maximal size. + self.queue.push_back(shard_id); + if self.queue.buckets_filled() { + let batch = self.queue.pop_batch(); + self.checked_start_batch(batch).await; + } + } else { + // In cases where we're only starting a single shard (e.g. if we're + // restarting a shard), we assume the queue will never fill up and skip + // using it so that we don't incur a 5 second timeout. + self.checked_start(shard_id).await; + } + }, + Some(ShardQueuerMessage::ShutdownShard { + shard_id, + code, + }) => { + debug!( + "[Shard Queuer] Received to shutdown shard {} with code {}", + shard_id.0, code + ); + self.shutdown(shard_id, code).await; + }, + Some(ShardQueuerMessage::Shutdown) => { + debug!("[Shard Queuer] Received to shutdown all shards"); + self.shutdown_runners().await; + break; + }, + None => break, + } + } else { + // Once we've stopped receiving `Start` commands, we no longer care about the size + // of our batches being maximal. + let batch = self.queue.pop_batch(); + self.checked_start_batch(batch).await; } } } @@ -157,14 +180,35 @@ impl ShardQueuer { debug!("[Shard Queuer] Checked start for shard {shard_id}"); self.check_last_start().await; + self.try_start(shard_id).await; + + self.last_start = Some(Instant::now()); + } + + #[cfg_attr(feature = "tracing_instrument", instrument(skip(self)))] + async fn checked_start_batch(&mut self, shard_ids: Vec) { + if shard_ids.is_empty() { + return; + } + + debug!("[Shard Queuer] Starting batch of {} shards", shard_ids.len()); + self.check_last_start().await; + for shard_id in shard_ids { + debug!("[Shard Queuer] Starting shard {shard_id}"); + self.try_start(shard_id).await; + } + self.last_start = Some(Instant::now()); + } + + #[cfg_attr(feature = "tracing_instrument", instrument(skip(self)))] + async fn try_start(&mut self, shard_id: ShardId) { if let Err(why) = self.start(shard_id).await { warn!("[Shard Queuer] Err starting shard {shard_id}: {why:?}"); info!("[Shard Queuer] Re-queueing start of shard {shard_id}"); - self.queue.push_back(shard_id); + // Try again in the next batch. + self.queue.push_front(shard_id); } - - self.last_start = Some(Instant::now()); } #[cfg_attr(feature = "tracing_instrument", instrument(skip(self)))] @@ -253,3 +297,51 @@ impl ShardQueuer { } } } + +/// A queue of [`ShardId`]s that is split up into multiple buckets according to the value of +/// [`max_concurrency`](crate::model::gateway::SessionStartLimit::max_concurrency). +#[must_use] +pub struct ShardQueue { + buckets: HashMap>, + max_concurrency: NonZeroU16, +} + +impl ShardQueue { + pub fn new(max_concurrency: NonZeroU16) -> Self { + Self { + buckets: HashMap::with_capacity(max_concurrency.get() as usize), + max_concurrency, + } + } + + /// Calculates the corresponding bucket for the given `ShardId` and **appends** to it. + pub fn push_back(&mut self, shard_id: ShardId) { + let bucket = shard_id.0 % self.max_concurrency.get(); + self.buckets.entry(bucket).or_default().push_back(shard_id); + } + + /// Calculates the corresponding bucket for the given `ShardId` and **prepends** to it. + pub fn push_front(&mut self, shard_id: ShardId) { + let bucket = shard_id.0 % self.max_concurrency.get(); + self.buckets.entry(bucket).or_default().push_front(shard_id); + } + + /// Pops a `ShardId` from every bucket containing at least one and returns them all as a `Vec`. + pub fn pop_batch(&mut self) -> Vec { + (0..self.max_concurrency.get()) + .filter_map(|i| self.buckets.get_mut(&i).and_then(|bucket| bucket.pop_front())) + .collect() + } + + /// Returns `true` if every bucket contains at least one `ShardId`. + #[must_use] + pub fn buckets_filled(&self) -> bool { + for i in 0..self.max_concurrency.get() { + let Some(bucket) = self.buckets.get(&i) else { return false }; + if bucket.is_empty() { + return false; + } + } + true + } +} diff --git a/src/model/gateway.rs b/src/model/gateway.rs index 69d9ea662db..e974aceb532 100644 --- a/src/model/gateway.rs +++ b/src/model/gateway.rs @@ -357,7 +357,10 @@ pub struct SessionStartLimit { /// The total number of session starts within the ratelimit period allowed. pub total: u64, /// The number of identify requests allowed per 5 seconds. - pub max_concurrency: u64, + /// + /// This is almost always 1, but for large bots (in more than 150,000 servers) it can be + /// larger. + pub max_concurrency: NonZeroU16, } #[cfg_attr(feature = "typesize", derive(typesize::derive::TypeSize))]