Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Clean up ShardManager/Queuer/Runner #2653

Merged
merged 1 commit into from
Dec 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions src/cache/event.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::collections::HashSet;
use std::num::NonZeroU16;

use super::{Cache, CacheUpdate};
use crate::model::channel::{GuildChannel, Message};
Expand Down Expand Up @@ -452,12 +453,13 @@ impl CacheUpdate for ReadyEvent {
let mut guilds_to_remove = vec![];
let ready_guilds_hashset =
self.ready.guilds.iter().map(|status| status.id).collect::<HashSet<_>>();
let shard_data = self.ready.shard.unwrap_or_else(|| ShardInfo::new(ShardId(1), 1));
let shard_data =
self.ready.shard.unwrap_or_else(|| ShardInfo::new(ShardId(1), NonZeroU16::MIN));

for guild_entry in cache.guilds.iter() {
let guild = guild_entry.key();
// Only handle data for our shard.
if crate::utils::shard_id(*guild, shard_data.total) == shard_data.id.0
if crate::utils::shard_id(*guild, shard_data.total.get()) == shard_data.id.0
&& !ready_guilds_hashset.contains(guild)
{
guilds_to_remove.push(*guild);
Expand Down
7 changes: 4 additions & 3 deletions src/cache/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

use std::collections::{HashMap, HashSet, VecDeque};
use std::hash::Hash;
use std::num::NonZeroU16;
#[cfg(feature = "temp_cache")]
use std::sync::Arc;
#[cfg(feature = "temp_cache")]
Expand Down Expand Up @@ -125,7 +126,7 @@ pub type MessageRef<'a> = CacheRef<'a, ChannelId, Message, HashMap<MessageId, Me
#[cfg_attr(feature = "typesize", derive(typesize::derive::TypeSize))]
#[derive(Debug)]
pub(crate) struct CachedShardData {
pub total: u32,
pub total: NonZeroU16,
pub connected: HashSet<ShardId>,
pub has_sent_shards_ready: bool,
}
Expand Down Expand Up @@ -281,7 +282,7 @@ impl Cache {
message_queue: DashMap::default(),

shard_data: RwLock::new(CachedShardData {
total: 1,
total: NonZeroU16::MIN,
connected: HashSet::new(),
has_sent_shards_ready: false,
}),
Expand Down Expand Up @@ -539,7 +540,7 @@ impl Cache {

/// Returns the number of shards.
#[inline]
pub fn shard_count(&self) -> u32 {
pub fn shard_count(&self) -> NonZeroU16 {
self.shard_data.read().total
}

Expand Down
9 changes: 4 additions & 5 deletions src/client/dispatch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -348,13 +348,12 @@ fn update_cache_with_event(
#[cfg(feature = "cache")]
{
let mut shards = cache.shard_data.write();
if shards.connected.len() as u32 == shards.total && !shards.has_sent_shards_ready {
if shards.connected.len() == shards.total.get() as usize
&& !shards.has_sent_shards_ready
{
shards.has_sent_shards_ready = true;
let total = shards.total;
drop(shards);

extra_event = Some(FullEvent::ShardsReady {
total_shards: total,
total_shards: shards.total,
});
}
}
Expand Down
3 changes: 2 additions & 1 deletion src/client/event_handler.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::collections::HashMap;
use std::num::NonZeroU16;

use async_trait::async_trait;

Expand Down Expand Up @@ -117,7 +118,7 @@ event_handler! {

/// Dispatched when every shard has received a Ready event
#[cfg(feature = "cache")]
ShardsReady { total_shards: u32 } => async fn shards_ready(&self, ctx: Context);
ShardsReady { total_shards: NonZeroU16 } => async fn shards_ready(&self, ctx: Context);

/// Dispatched when a channel is created.
///
Expand Down
61 changes: 22 additions & 39 deletions src/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ mod error;
mod event_handler;

use std::future::IntoFuture;
use std::num::NonZeroU16;
use std::ops::Range;
use std::sync::Arc;
#[cfg(feature = "framework")]
Expand All @@ -30,7 +31,7 @@ use std::sync::OnceLock;
use futures::channel::mpsc::UnboundedReceiver as Receiver;
use futures::future::BoxFuture;
use futures::StreamExt as _;
use tokio::sync::{Mutex, RwLock};
use tokio::sync::RwLock;
use tracing::{debug, error, info, instrument};
use typemap_rev::{TypeMap, TypeMapKey};

Expand All @@ -57,6 +58,7 @@ use crate::internal::prelude::*;
use crate::model::gateway::GatewayIntents;
use crate::model::id::ApplicationId;
use crate::model::user::OnlineStatus;
use crate::utils::check_shard_total;

/// A builder implementing [`IntoFuture`] building a [`Client`] to interact with Discord.
#[cfg(feature = "gateway")]
Expand Down Expand Up @@ -333,13 +335,13 @@ impl IntoFuture for ClientBuilder {
let cache = Arc::new(Cache::new_with_settings(self.cache_settings));

Box::pin(async move {
let ws_url = Arc::new(Mutex::new(match http.get_gateway().await {
Ok(response) => response.url,
let (ws_url, shard_total) = match http.get_bot_gateway().await {
Ok(response) => (Arc::from(response.url), response.shards),
Err(err) => {
tracing::warn!("HTTP request to get gateway URL failed: {}", err);
"wss://gateway.discord.gg".to_string()
tracing::warn!("HTTP request to get gateway URL failed: {err}");
(Arc::from("wss://gateway.discord.gg"), NonZeroU16::MIN)
},
}));
};

#[cfg(feature = "framework")]
let framework_cell = Arc::new(OnceLock::new());
Expand All @@ -349,12 +351,10 @@ impl IntoFuture for ClientBuilder {
raw_event_handlers,
#[cfg(feature = "framework")]
framework: Arc::clone(&framework_cell),
shard_index: 0,
shard_init: 0,
shard_total: 0,
#[cfg(feature = "voice")]
voice_manager: voice_manager.as_ref().map(Arc::clone),
ws_url: Arc::clone(&ws_url),
shard_total,
#[cfg(feature = "cache")]
cache: Arc::clone(&cache),
http: Arc::clone(&http),
Expand Down Expand Up @@ -586,11 +586,7 @@ pub struct Client {
#[cfg(feature = "voice")]
pub voice_manager: Option<Arc<dyn VoiceGatewayManager + 'static>>,
/// URL that the client's shards will use to connect to the gateway.
///
/// This is likely not important for production usage and is, at best, used for debugging.
///
/// This is wrapped in an `Arc<Mutex<T>>` so all shards will have an updated value available.
pub ws_url: Arc<Mutex<String>>,
pub ws_url: Arc<str>,
/// The cache for the client.
#[cfg(feature = "cache")]
pub cache: Arc<Cache>,
Expand Down Expand Up @@ -638,7 +634,7 @@ impl Client {
/// [gateway docs]: crate::gateway#sharding
#[instrument(skip(self))]
pub async fn start(&mut self) -> Result<()> {
self.start_connection(0, 0, 1).await
self.start_connection(0, 0, NonZeroU16::MIN).await
}

/// Establish the connection(s) and start listening for events.
Expand Down Expand Up @@ -681,8 +677,7 @@ impl Client {
pub async fn start_autosharded(&mut self) -> Result<()> {
let (end, total) = {
let res = self.http.get_bot_gateway().await?;

(res.shards - 1, res.shards)
(res.shards.get() - 1, res.shards)
};

self.start_connection(0, end, total).await
Expand Down Expand Up @@ -743,8 +738,8 @@ impl Client {
///
/// [gateway docs]: crate::gateway#sharding
#[instrument(skip(self))]
pub async fn start_shard(&mut self, shard: u32, shards: u32) -> Result<()> {
self.start_connection(shard, shard, shards).await
pub async fn start_shard(&mut self, shard: u16, shards: u16) -> Result<()> {
self.start_connection(shard, shard, check_shard_total(shards)).await
}

/// Establish sharded connections and start listening for events.
Expand Down Expand Up @@ -784,8 +779,8 @@ impl Client {
///
/// [Gateway docs]: crate::gateway#sharding
#[instrument(skip(self))]
pub async fn start_shards(&mut self, total_shards: u32) -> Result<()> {
self.start_connection(0, total_shards - 1, total_shards).await
pub async fn start_shards(&mut self, total_shards: u16) -> Result<()> {
self.start_connection(0, total_shards - 1, check_shard_total(total_shards)).await
}

/// Establish a range of sharded connections and start listening for events.
Expand Down Expand Up @@ -825,26 +820,16 @@ impl Client {
///
/// [Gateway docs]: crate::gateway#sharding
#[instrument(skip(self))]
pub async fn start_shard_range(&mut self, range: Range<u32>, total_shards: u32) -> Result<()> {
self.start_connection(range.start, range.end, total_shards).await
pub async fn start_shard_range(&mut self, range: Range<u16>, total_shards: u16) -> Result<()> {
self.start_connection(range.start, range.end, check_shard_total(total_shards)).await
}

/// Shard data layout is:
/// 0: first shard number to initialize
/// 1: shard number to initialize up to and including
/// 2: total number of shards the bot is sharding for
///
/// Not all shards need to be initialized in this process.
///
/// # Errors
///
/// Returns a [`ClientError::Shutdown`] when all shards have shutdown due to an error.
#[instrument(skip(self))]
async fn start_connection(
&mut self,
start_shard: u32,
end_shard: u32,
total_shards: u32,
start_shard: u16,
end_shard: u16,
total_shards: NonZeroU16,
) -> Result<()> {
#[cfg(feature = "voice")]
if let Some(voice_manager) = &self.voice_manager {
Expand All @@ -855,11 +840,9 @@ impl Client {

let init = end_shard - start_shard + 1;

self.shard_manager.set_shards(start_shard, init, total_shards).await;

debug!("Initializing shard info: {} - {}/{}", start_shard, init, total_shards);

if let Err(why) = self.shard_manager.initialize() {
if let Err(why) = self.shard_manager.initialize(start_shard, init, total_shards) {
error!("Failed to boot a shard: {:?}", why);
info!("Shutting down all shards");

Expand Down
8 changes: 5 additions & 3 deletions src/gateway/bridge/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ mod shard_runner_message;
mod voice;

use std::fmt;
use std::num::NonZeroU16;
use std::time::Duration as StdDuration;

pub use self::event::ShardStageUpdateEvent;
Expand All @@ -68,9 +69,10 @@ use crate::model::id::ShardId;
/// A message to be sent to the [`ShardQueuer`].
#[derive(Clone, Debug)]
pub enum ShardQueuerMessage {
/// Message to start a shard, where the 0-index element is the ID of the Shard to start and the
/// 1-index element is the total shards in use.
Start(ShardId, ShardId),
/// Message to set the shard total.
SetShardTotal(NonZeroU16),
/// Message to start a shard.
Start(ShardId),
/// Message to shutdown the shard queuer.
Shutdown,
/// Message to dequeue/shutdown a shard.
Expand Down
Loading