Implement max_concurrency support when starting shards (serenity-rs#2661)

When a session is first kicked off and we need to start many shards, we
can use the value of `max_concurrency` to start multiple shards in
parallel.
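
A rough, hedged illustration of why this matters (not serenity code; it only assumes the 5-second identify window described in the diff below): the gateway allows `max_concurrency` IDENTIFY calls per window, so batching shrinks startup time from roughly `total * 5` seconds to roughly `ceil(total / max_concurrency) * 5` seconds.

```rust
// Back-of-the-envelope sketch, not serenity API: compare serial vs batched
// startup time under a "max_concurrency IDENTIFYs per 5 seconds" limit.
fn startup_estimate_secs(total_shards: u32, max_concurrency: u32) -> (u32, u32) {
    let serial = total_shards * 5;
    let batched = total_shards.div_ceil(max_concurrency) * 5;
    (serial, batched)
}

fn main() {
    let (serial, batched) = startup_estimate_secs(16, 16);
    // Prints: 16 shards, max_concurrency 16: ~80s serially vs ~5s batched
    println!("16 shards, max_concurrency 16: ~{serial}s serially vs ~{batched}s batched");
}
```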
mkrasnitski authored and arqunis committed Mar 1, 2024
1 parent 614e16d commit 3b740d7
Showing 5 changed files with 156 additions and 50 deletions.
11 changes: 8 additions & 3 deletions src/client/mod.rs
@@ -333,11 +333,15 @@ impl IntoFuture for ClientBuilder {
let cache = Arc::new(Cache::new_with_settings(self.cache_settings));

Box::pin(async move {
let (ws_url, shard_total) = match http.get_bot_gateway().await {
Ok(response) => (Arc::from(response.url), response.shards),
let (ws_url, shard_total, max_concurrency) = match http.get_bot_gateway().await {
Ok(response) => (
Arc::from(response.url),
response.shards,
response.session_start_limit.max_concurrency,
),
Err(err) => {
tracing::warn!("HTTP request to get gateway URL failed: {err}");
(Arc::from("wss://gateway.discord.gg"), NonZeroU16::MIN)
(Arc::from("wss://gateway.discord.gg"), NonZeroU16::MIN, NonZeroU16::MIN)
},
};

@@ -358,6 +362,7 @@ impl IntoFuture for ClientBuilder {
http: Arc::clone(&http),
intents,
presence: Some(presence),
max_concurrency,
});

let client = Client {
6 changes: 3 additions & 3 deletions src/gateway/bridge/mod.rs
@@ -56,7 +56,7 @@ use std::time::Duration as StdDuration;
pub use self::event::ShardStageUpdateEvent;
pub use self::shard_manager::{ShardManager, ShardManagerOptions};
pub use self::shard_messenger::ShardMessenger;
pub use self::shard_queuer::ShardQueuer;
pub use self::shard_queuer::{ShardQueue, ShardQueuer};
pub use self::shard_runner::{ShardRunner, ShardRunnerOptions};
pub use self::shard_runner_message::ShardRunnerMessage;
#[cfg(feature = "voice")]
@@ -72,11 +72,11 @@ pub enum ShardQueuerMessage {
/// Message to set the shard total.
SetShardTotal(NonZeroU16),
/// Message to start a shard.
Start(ShardId),
Start { shard_id: ShardId, concurrent: bool },
/// Message to shutdown the shard queuer.
Shutdown,
/// Message to dequeue/shutdown a shard.
ShutdownShard(ShardId, u16),
ShutdownShard { shard_id: ShardId, code: u16 },
}

/// Information about a [`ShardRunner`].
28 changes: 17 additions & 11 deletions src/gateway/bridge/shard_manager.rs
@@ -1,4 +1,4 @@
use std::collections::{HashMap, VecDeque};
use std::collections::HashMap;
use std::num::NonZeroU16;
use std::sync::Arc;
#[cfg(feature = "framework")]
@@ -14,7 +14,7 @@ use typemap_rev::TypeMap;

#[cfg(feature = "voice")]
use super::VoiceGatewayManager;
use super::{ShardId, ShardQueuer, ShardQueuerMessage, ShardRunnerInfo};
use super::{ShardId, ShardQueue, ShardQueuer, ShardQueuerMessage, ShardRunnerInfo};
#[cfg(feature = "cache")]
use crate::cache::Cache;
use crate::client::{EventHandler, RawEventHandler};
@@ -71,6 +71,7 @@ use crate::model::gateway::GatewayIntents;
/// let ws_url = Arc::from(gateway_info.url);
/// let data = Arc::new(RwLock::new(TypeMap::new()));
/// let event_handler = Arc::new(Handler) as Arc<dyn EventHandler>;
/// let max_concurrency = std::num::NonZeroU16::MIN;
/// let framework = Arc::new(StandardFramework::new()) as Arc<dyn Framework + 'static>;
///
/// ShardManager::new(ShardManagerOptions {
@@ -87,6 +88,7 @@ use crate::model::gateway::GatewayIntents;
/// # http,
/// intents: GatewayIntents::non_privileged(),
/// presence: None,
/// max_concurrency,
/// });
/// # Ok(())
/// # }
@@ -137,7 +139,7 @@ impl ShardManager {
framework: opt.framework,
last_start: None,
manager: Arc::clone(&manager),
queue: VecDeque::new(),
queue: ShardQueue::new(opt.max_concurrency),
runners,
rx: shard_queue_rx,
#[cfg(feature = "voice")]
@@ -176,7 +178,7 @@ impl ShardManager {

self.set_shard_total(shard_total);
for shard_id in shard_index..shard_to {
self.boot(ShardId(shard_id));
self.boot(ShardId(shard_id), true);
}
}

@@ -204,7 +206,7 @@
pub async fn restart(&self, shard_id: ShardId) {
info!("Restarting shard {shard_id}");
self.shutdown(shard_id, 4000).await;
self.boot(shard_id);
self.boot(shard_id, false);
}

/// Returns the [`ShardId`]s of the shards that have been instantiated and currently have a
@@ -233,9 +235,10 @@
{
let mut shard_shutdown = self.shard_shutdown.lock().await;

drop(
self.shard_queuer.unbounded_send(ShardQueuerMessage::ShutdownShard(shard_id, code)),
);
drop(self.shard_queuer.unbounded_send(ShardQueuerMessage::ShutdownShard {
shard_id,
code,
}));
match timeout(TIMEOUT, shard_shutdown.next()).await {
Ok(Some(shutdown_shard_id)) => {
if shutdown_shard_id != shard_id {
@@ -299,11 +302,13 @@
}

#[cfg_attr(feature = "tracing_instrument", instrument(skip(self)))]
fn boot(&self, shard_id: ShardId) {
fn boot(&self, shard_id: ShardId, concurrent: bool) {
info!("Telling shard queuer to start shard {shard_id}");

let msg = ShardQueuerMessage::Start(shard_id);
drop(self.shard_queuer.unbounded_send(msg));
drop(self.shard_queuer.unbounded_send(ShardQueuerMessage::Start {
shard_id,
concurrent,
}));
}

/// Returns the gateway intents used for this gateway connection.
@@ -371,4 +376,5 @@ pub struct ShardManagerOptions {
pub http: Arc<Http>,
pub intents: GatewayIntents,
pub presence: Option<PresenceData>,
pub max_concurrency: NonZeroU16,
}
156 changes: 124 additions & 32 deletions src/gateway/bridge/shard_queuer.rs
@@ -62,9 +62,7 @@ pub struct ShardQueuer {
/// A copy of the [`ShardManager`] to communicate with it.
pub manager: Arc<ShardManager>,
/// The shards that are queued for booting.
///
/// This will typically be filled with previously failed boots.
pub queue: VecDeque<ShardId>,
pub queue: ShardQueue,
/// A copy of the map of shard runners.
pub runners: Arc<Mutex<HashMap<ShardId, ShardRunnerInfo>>>,
/// A receiver channel for the shard queuer to be told to start shards.
@@ -102,35 +100,60 @@ impl ShardQueuer {
/// **Note**: This should be run in its own thread due to the blocking nature of the loop.
#[cfg_attr(feature = "tracing_instrument", instrument(skip(self)))]
pub async fn run(&mut self) {
// The duration to timeout from reads over the Rx channel. This can be done in a loop, and
// if the read times out then a shard can be started if one is presently waiting in the
// queue.
// We read from the Rx channel in a loop, and use a timeout of 5 seconds so that we don't
// hang forever. When we receive a command to start a shard, we append it to our queue. The
// queue is popped in batches of shards, which are started in parallel. A batch is fired
// every 5 seconds at minimum in order to avoid being ratelimited.
const TIMEOUT: Duration = Duration::from_secs(WAIT_BETWEEN_BOOTS_IN_SECONDS);

loop {
match timeout(TIMEOUT, self.rx.next()).await {
Ok(Some(ShardQueuerMessage::Shutdown)) => {
debug!("[Shard Queuer] Received to shutdown.");
self.shutdown_runners().await;

break;
},
Ok(Some(ShardQueuerMessage::ShutdownShard(shard, code))) => {
debug!("[Shard Queuer] Received to shutdown shard {} with {}.", shard.0, code);
self.shutdown(shard, code).await;
},
Ok(Some(ShardQueuerMessage::Start(shard_id))) => {
self.checked_start(shard_id).await;
},
Ok(Some(ShardQueuerMessage::SetShardTotal(shard_total))) => {
self.shard_total = shard_total;
},
Ok(None) => break,
Err(_) => {
if let Some(shard) = self.queue.pop_front() {
self.checked_start(shard).await;
}
},
if let Ok(msg) = timeout(TIMEOUT, self.rx.next()).await {
match msg {
Some(ShardQueuerMessage::SetShardTotal(shard_total)) => {
self.shard_total = shard_total;
},
Some(ShardQueuerMessage::Start {
shard_id,
concurrent,
}) => {
if concurrent {
// If we're starting multiple shards, we can start them concurrently
// according to `max_concurrency`, and want our batches to be of
// maximal size.
self.queue.push_back(shard_id);
if self.queue.buckets_filled() {
let batch = self.queue.pop_batch();
self.checked_start_batch(batch).await;
}
} else {
// In cases where we're only starting a single shard (e.g. if we're
// restarting a shard), we assume the queue will never fill up and skip
// using it so that we don't incur a 5 second timeout.
self.checked_start(shard_id).await;
}
},
Some(ShardQueuerMessage::ShutdownShard {
shard_id,
code,
}) => {
debug!(
"[Shard Queuer] Received to shutdown shard {} with code {}",
shard_id.0, code
);
self.shutdown(shard_id, code).await;
},
Some(ShardQueuerMessage::Shutdown) => {
debug!("[Shard Queuer] Received to shutdown all shards");
self.shutdown_runners().await;
break;
},
None => break,
}
} else {
// Once we've stopped receiving `Start` commands, we no longer care about the size
// of our batches being maximal.
let batch = self.queue.pop_batch();
self.checked_start_batch(batch).await;
}
}
}
@@ -157,14 +180,35 @@ impl ShardQueuer {
debug!("[Shard Queuer] Checked start for shard {shard_id}");

self.check_last_start().await;
self.try_start(shard_id).await;

self.last_start = Some(Instant::now());
}

#[cfg_attr(feature = "tracing_instrument", instrument(skip(self)))]
async fn checked_start_batch(&mut self, shard_ids: Vec<ShardId>) {
if shard_ids.is_empty() {
return;
}

debug!("[Shard Queuer] Starting batch of {} shards", shard_ids.len());
self.check_last_start().await;
for shard_id in shard_ids {
debug!("[Shard Queuer] Starting shard {shard_id}");
self.try_start(shard_id).await;
}
self.last_start = Some(Instant::now());
}

#[cfg_attr(feature = "tracing_instrument", instrument(skip(self)))]
async fn try_start(&mut self, shard_id: ShardId) {
if let Err(why) = self.start(shard_id).await {
warn!("[Shard Queuer] Err starting shard {shard_id}: {why:?}");
info!("[Shard Queuer] Re-queueing start of shard {shard_id}");

self.queue.push_back(shard_id);
// Try again in the next batch.
self.queue.push_front(shard_id);
}

self.last_start = Some(Instant::now());
}

#[cfg_attr(feature = "tracing_instrument", instrument(skip(self)))]
@@ -253,3 +297,51 @@
}
}
}

/// A queue of [`ShardId`]s that is split up into multiple buckets according to the value of
/// [`max_concurrency`](crate::model::gateway::SessionStartLimit::max_concurrency).
#[must_use]
pub struct ShardQueue {
buckets: HashMap<u16, VecDeque<ShardId>>,
max_concurrency: NonZeroU16,
}

impl ShardQueue {
pub fn new(max_concurrency: NonZeroU16) -> Self {
Self {
buckets: HashMap::with_capacity(max_concurrency.get() as usize),
max_concurrency,
}
}

/// Calculates the corresponding bucket for the given `ShardId` and **appends** to it.
pub fn push_back(&mut self, shard_id: ShardId) {
let bucket = shard_id.0 % self.max_concurrency.get();
self.buckets.entry(bucket).or_default().push_back(shard_id);
}

/// Calculates the corresponding bucket for the given `ShardId` and **prepends** to it.
pub fn push_front(&mut self, shard_id: ShardId) {
let bucket = shard_id.0 % self.max_concurrency.get();
self.buckets.entry(bucket).or_default().push_front(shard_id);
}

/// Pops a `ShardId` from every bucket containing at least one and returns them all as a `Vec`.
pub fn pop_batch(&mut self) -> Vec<ShardId> {
(0..self.max_concurrency.get())
.filter_map(|i| self.buckets.get_mut(&i).and_then(|bucket| bucket.pop_front()))
.collect()
}

/// Returns `true` if every bucket contains at least one `ShardId`.
#[must_use]
pub fn buckets_filled(&self) -> bool {
for i in 0..self.max_concurrency.get() {
let Some(bucket) = self.buckets.get(&i) else { return false };
if bucket.is_empty() {
return false;
}
}
true
}
}
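
A hedged usage sketch of the new `ShardQueue`. The import paths and the `ShardId(u16)` tuple constructor are assumptions inferred from this diff and the `mod.rs` re-export above, not verified public paths; the bucketing behaviour shown follows the methods added here.

```rust
use std::num::NonZeroU16;

// Assumed import paths -- the re-export lives in `gateway::bridge` per the
// mod.rs change above, but the exact public path may differ.
use serenity::gateway::bridge::ShardQueue;
use serenity::gateway::ShardId;

fn main() {
    // Pretend the gateway reported max_concurrency = 2.
    let mut queue = ShardQueue::new(NonZeroU16::new(2).unwrap());

    // Shards fall into bucket `shard_id % 2`: 0 and 2 into bucket 0, 1 and 3 into bucket 1.
    for id in 0u16..4 {
        queue.push_back(ShardId(id));
    }

    // Every bucket holds at least one shard, so a maximal batch can be fired.
    assert!(queue.buckets_filled());

    // Each pop takes one shard per bucket: [0, 1] first, then [2, 3].
    println!("first batch: {:?}", queue.pop_batch());
    println!("second batch: {:?}", queue.pop_batch());
}
```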
5 changes: 4 additions & 1 deletion src/model/gateway.rs
@@ -357,7 +357,10 @@ pub struct SessionStartLimit {
/// The total number of session starts within the ratelimit period allowed.
pub total: u64,
/// The number of identify requests allowed per 5 seconds.
pub max_concurrency: u64,
///
/// This is almost always 1, but for large bots (in more than 150,000 servers) it can be
/// larger.
pub max_concurrency: NonZeroU16,
}

#[cfg_attr(feature = "typesize", derive(typesize::derive::TypeSize))]
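
For context, a standalone sketch of the payload this field comes from. The struct below is a simplified local mirror of Discord's documented `session_start_limit` object from `GET /gateway/bot`, not serenity's own type; `NonZeroU16` fits because the gateway never reports a `max_concurrency` of 0.

```rust
use std::num::NonZeroU16;

use serde::Deserialize;

// Simplified local mirror of the `session_start_limit` object -- illustration
// only, not serenity's `SessionStartLimit`.
#[derive(Debug, Deserialize)]
struct SessionStartLimit {
    total: u64,
    remaining: u64,
    reset_after: u64,
    max_concurrency: NonZeroU16,
}

fn main() -> Result<(), serde_json::Error> {
    // Typical shape of the payload; large bots may see 16 or higher here.
    let json = r#"{"total": 1000, "remaining": 997, "reset_after": 14400000, "max_concurrency": 1}"#;
    let limit: SessionStartLimit = serde_json::from_str(json)?;
    println!("allowed to IDENTIFY {} shard(s) per 5 seconds", limit.max_concurrency);
    Ok(())
}
```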
