Skip to content

Commit

Permalink
Implement max_concurrency
Browse files Browse the repository at this point in the history
  • Loading branch information
GnomedDev committed Dec 5, 2023
1 parent c0a9436 commit f031098
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 43 deletions.
10 changes: 7 additions & 3 deletions src/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ mod error;
mod event_handler;

use std::future::IntoFuture;
use std::num::NonZeroU16;
use std::ops::Range;
use std::sync::Arc;
#[cfg(feature = "framework")]
Expand Down Expand Up @@ -359,11 +360,13 @@ impl IntoFuture for ClientBuilder {
let cache = Arc::new(Cache::new_with_settings(self.cache_settings));

Box::pin(async move {
let ws_url = match http.get_gateway().await {
Ok(response) => Arc::from(response.url),
let (ws_url, max_concurrency) = match http.get_bot_gateway().await {
Ok(response) => {
(Arc::from(response.url), response.session_start_limit.max_concurrency)
},
Err(err) => {
tracing::warn!("HTTP request to get gateway URL failed: {err}");
Arc::from("wss://gateway.discord.gg")
(Arc::from("wss://gateway.discord.gg"), NonZeroU16::new(1).expect("1 != 0"))
},
};

Expand All @@ -383,6 +386,7 @@ impl IntoFuture for ClientBuilder {
http: Arc::clone(&http),
intents,
presence: Some(presence),
max_concurrency,
});

let client = Client {
Expand Down
3 changes: 3 additions & 0 deletions src/gateway/bridge/shard_manager.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::collections::{HashMap, VecDeque};
use std::num::NonZeroU16;
use std::sync::Arc;
#[cfg(feature = "framework")]
use std::sync::OnceLock;
Expand Down Expand Up @@ -144,6 +145,7 @@ impl ShardManager {
http: opt.http,
intents: opt.intents,
presence: opt.presence,
max_concurrency: opt.max_concurrency,
};

spawn_named("shard_queuer::run", async move {
Expand Down Expand Up @@ -364,4 +366,5 @@ pub struct ShardManagerOptions {
pub http: Arc<Http>,
pub intents: GatewayIntents,
pub presence: Option<PresenceData>,
pub max_concurrency: NonZeroU16,
}
60 changes: 21 additions & 39 deletions src/gateway/bridge/shard_queuer.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
use std::collections::{HashMap, VecDeque};
use std::num::NonZeroU16;
use std::sync::Arc;
#[cfg(feature = "framework")]
use std::sync::OnceLock;

use futures::channel::mpsc::UnboundedReceiver as Receiver;
use futures::StreamExt;
use tokio::sync::{Mutex, RwLock};
use tokio::time::{sleep, timeout, Duration, Instant};
use tokio::time::{timeout, Duration, Instant};
use tracing::{debug, info, instrument, warn};
use typemap_rev::TypeMap;

Expand Down Expand Up @@ -78,6 +79,10 @@ pub struct ShardQueuer {
pub http: Arc<Http>,
pub intents: GatewayIntents,
pub presence: Option<PresenceData>,
/// The maximum amount of shards that can be started at once.
///
/// This is almost always 1, but for bots in more than 150,000 servers this can be more.
pub max_concurrency: NonZeroU16,
}

impl ShardQueuer {
Expand Down Expand Up @@ -116,62 +121,39 @@ impl ShardQueuer {
debug!("[Shard Queuer] Received to shutdown shard {} with {}.", shard.0, code);
self.shutdown(shard, code).await;
},
Ok(Some(ShardQueuerMessage::Start(shard_info))) => {
debug!(
"[Shard Queuer] Received to start shard {} of {}.",
shard_info.id, shard_info.total
);

self.check_last_start().await;
self.checked_start(shard_info).await;
Ok(Some(ShardQueuerMessage::Start(info))) => {
self.queue.push_back(info);
},
Ok(None) => break,
Err(_) => {
if let Some(shard) = self.queue.pop_front() {
self.checked_start(shard).await;
}
},
Err(_) => self.start_batch().await,
}
}
}

#[instrument(skip(self))]
async fn check_last_start(&mut self) {
let Some(instant) = self.last_start else { return };

// We must wait 5 seconds between IDENTIFYs to avoid session invalidations.
let duration = Duration::from_secs(WAIT_BETWEEN_BOOTS_IN_SECONDS);
let elapsed = instant.elapsed();

if elapsed >= duration {
async fn start_batch(&mut self) {
if self.queue.is_empty() {
return;
}

let to_sleep = duration - elapsed;
let batch_size = (self.max_concurrency.get() as usize).min(self.queue.len());
debug!("[Shard Queuer] Starting batch of {batch_size} shards.");

sleep(to_sleep).await;
}

#[instrument(skip(self))]
async fn checked_start(&mut self, shard_info: ShardInfo) {
debug!(
"[Shard Queuer] Checked start for shard {} out of {}",
shard_info.id, shard_info.total
);
for shard_info in self.queue.drain(..batch_size).collect::<Vec<_>>() {
if let Err(why) = self.start(shard_info).await {
warn!("[Shard Queuer] Err starting shard {}: {:?}", shard_info.id, why);
info!("[Shard Queuer] Re-queueing start of shard {}", shard_info.id);

self.check_last_start().await;
if let Err(why) = self.start(shard_info).await {
warn!("[Shard Queuer] Err starting shard {}: {:?}", shard_info.id, why);
info!("[Shard Queuer] Re-queueing start of shard {}", shard_info.id);

self.queue.push_back(shard_info);
// Try again, must be in same batch so in push_front.
self.queue.push_front(shard_info);
}
}

self.last_start = Some(Instant::now());
}

#[instrument(skip(self))]
async fn start(&mut self, shard_info: ShardInfo) -> Result<()> {
async fn start(&self, shard_info: ShardInfo) -> Result<()> {
let mut shard = Shard::new(
Arc::clone(&self.ws_url),
self.http.token(),
Expand Down
2 changes: 1 addition & 1 deletion src/model/gateway.rs
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,7 @@ pub struct SessionStartLimit {
/// The total number of session starts within the ratelimit period allowed.
pub total: u64,
/// The number of identify requests allowed per 5 seconds.
pub max_concurrency: u64,
pub max_concurrency: NonZeroU16,
}

#[cfg_attr(feature = "typesize", derive(typesize::derive::TypeSize))]
Expand Down

0 comments on commit f031098

Please sign in to comment.