diff --git a/libs/collab-rt/Cargo.toml b/libs/collab-rt/Cargo.toml index 0f7a69876b558..8e4d93dc0128d 100644 --- a/libs/collab-rt/Cargo.toml +++ b/libs/collab-rt/Cargo.toml @@ -21,7 +21,7 @@ serde_json.workspace = true thiserror = "1.0.56" anyhow = "1" -collab = { version = "0.1.0"} +collab = { version = "0.1.0", features = ["async-plugin"]} collab-entity = { version = "0.1.0" } collab-folder = { version = "0.1.0" } collab-document = { version = "0.1.0" } diff --git a/libs/collab-rt/src/connect_state.rs b/libs/collab-rt/src/connect_state.rs new file mode 100644 index 0000000000000..eef1fcf7e4409 --- /dev/null +++ b/libs/collab-rt/src/connect_state.rs @@ -0,0 +1,174 @@ +use crate::CollabClientStream; +use collab_rt_entity::message::{RealtimeMessage, SystemMessage}; +use collab_rt_entity::user::{RealtimeUser, UserDevice}; +use dashmap::DashMap; + +use std::sync::Arc; +use tracing::{info, trace}; + +#[derive(Clone, Default)] +pub struct ConnectState { + pub(crate) user_by_device: Arc>, + /// Maintains a record of all client streams. A client stream associated with a user may be terminated for the following reasons: + /// 1. User disconnection. + /// 2. Server closes the connection due to a ping/pong timeout. + pub(crate) client_stream_by_user: Arc>, +} + +impl ConnectState { + pub fn new() -> Self { + Self::default() + } + pub fn handle_user_connect( + &self, + new_user: RealtimeUser, + client_stream: CollabClientStream, + ) -> Option { + let old_user = self + .user_by_device + .insert(UserDevice::from(&new_user), new_user.clone()); + + trace!( + "[realtime]: new connection => {}, removing old: {:?}", + new_user, + old_user + ); + + if let Some(old_user) = &old_user { + // Remove and retrieve the old client stream if it exists. + if let Some((_, client_stream)) = self.client_stream_by_user.remove(old_user) { + info!("Removing old stream for same user and device: {}", old_user); + // Notify the old stream of the duplicate connection. + client_stream + .sink + .do_send(RealtimeMessage::System(SystemMessage::DuplicateConnection)); + } + // Remove the old user from all collaboration groups. + } + self.client_stream_by_user.insert(new_user, client_stream); + + old_user + } + + pub fn handle_user_disconnect( + &self, + disconnect_user: &RealtimeUser, + ) -> Option<(UserDevice, RealtimeUser)> { + let user_device = UserDevice::from(disconnect_user); + let was_removed = self + .user_by_device + .remove_if(&user_device, |_, existing_user| { + existing_user.session_id == disconnect_user.session_id + }); + + if was_removed.is_some() && self.client_stream_by_user.remove(disconnect_user).is_some() { + info!("remove client stream: {}", &disconnect_user); + } + + was_removed + } + + #[allow(dead_code)] + fn num_connected_users(&self) -> usize { + self.user_by_device.len() + } + + #[allow(dead_code)] + fn get_user_by_device(&self, user_device: &UserDevice) -> Option { + self.user_by_device.get(user_device).map(|v| v.clone()) + } +} + +#[cfg(test)] +mod tests { + use crate::connect_state::ConnectState; + use crate::{CollabClientStream, RealtimeClientWebsocketSink}; + use collab_rt_entity::message::RealtimeMessage; + use collab_rt_entity::user::{RealtimeUser, UserDevice}; + + struct MockSink; + + impl RealtimeClientWebsocketSink for MockSink { + fn do_send(&self, _message: RealtimeMessage) {} + } + + fn mock_user(uid: i64, device_id: &str) -> RealtimeUser { + RealtimeUser::new( + uid, + device_id.to_string(), + uuid::Uuid::new_v4().to_string(), + chrono::Utc::now().timestamp(), + ) + } + + fn mock_stream() -> CollabClientStream { + CollabClientStream::new(MockSink) + } + + #[tokio::test] + async fn same_user_different_device_connect_test() { + let connect_state = ConnectState::new(); + let user_device_a = mock_user(1, "device_a"); + let user_device_b = mock_user(1, "device_b"); + connect_state.handle_user_connect(user_device_a, mock_stream()); + connect_state.handle_user_connect(user_device_b, mock_stream()); + + assert_eq!(connect_state.num_connected_users(), 2); + } + + #[tokio::test] + async fn same_user_same_device_connect_test() { + let connect_state = ConnectState::new(); + let user_device_a = mock_user(1, "device_a"); + let user_device_b = mock_user(1, "device_a"); + connect_state.handle_user_connect(user_device_a, mock_stream()); + connect_state.handle_user_connect(user_device_b.clone(), mock_stream()); + + assert_eq!(connect_state.num_connected_users(), 1); + let user = connect_state + .get_user_by_device(&UserDevice::from(&user_device_b)) + .unwrap(); + assert_eq!(user, user_device_b); + } + + #[tokio::test] + async fn multiple_devices_connect_test() { + let user_a = vec![ + mock_user(1, "device_a"), + mock_user(1, "device_b"), + mock_user(1, "device_c"), + mock_user(1, "device_d"), + ]; + + let user_b = vec![ + mock_user(2, "device_a"), + mock_user(2, "device_b"), + mock_user(2, "device_b"), + mock_user(2, "device_a"), + ]; + + let connect_state = ConnectState::new(); + + let (tx, rx_1) = tokio::sync::oneshot::channel(); + let cloned_connect_state = connect_state.clone(); + tokio::spawn(async move { + for user in user_a { + cloned_connect_state.handle_user_connect(user, mock_stream()); + } + tx.send(()).unwrap(); + }); + + let (tx, rx_2) = tokio::sync::oneshot::channel(); + let cloned_connect_state = connect_state.clone(); + tokio::spawn(async move { + for user in user_b { + cloned_connect_state.handle_user_connect(user, mock_stream()); + } + tx.send(()).unwrap(); + }); + + let _ = futures::future::join(rx_1, rx_2).await; + + assert_eq!(connect_state.num_connected_users(), 6); + } +} diff --git a/libs/collab-rt/src/lib.rs b/libs/collab-rt/src/lib.rs index 7c962272a6474..605e85542d1b4 100644 --- a/libs/collab-rt/src/lib.rs +++ b/libs/collab-rt/src/lib.rs @@ -1,10 +1,12 @@ mod collaborate; pub mod command; +pub mod connect_state; pub mod error; mod metrics; mod permission; mod rt_server; mod util; + pub use metrics::*; pub use permission::*; pub use rt_server::*; diff --git a/libs/collab-rt/src/rt_server.rs b/libs/collab-rt/src/rt_server.rs index 3b3459d437ce8..2c8b3f1fbe904 100644 --- a/libs/collab-rt/src/rt_server.rs +++ b/libs/collab-rt/src/rt_server.rs @@ -7,7 +7,7 @@ use crate::{spawn_metrics, CollabRealtimeMetrics, RealtimeAccessControl}; use anyhow::Result; use collab_rt_entity::collab_msg::{ClientCollabMessage, CollabSinkMessage}; -use collab_rt_entity::message::{MessageByObjectId, RealtimeMessage, SystemMessage}; +use collab_rt_entity::message::{MessageByObjectId, RealtimeMessage}; use collab_rt_entity::user::{Editing, RealtimeUser, UserDevice}; use dashmap::mapref::entry::Entry; use dashmap::DashMap; @@ -15,30 +15,24 @@ use database::collab::CollabStorage; use std::collections::HashSet; use std::future::Future; +use crate::connect_state::ConnectState; use async_trait::async_trait; use std::pin::Pin; use std::sync::Arc; use std::time::Duration; use tokio_stream::wrappers::{BroadcastStream, ReceiverStream}; use tokio_stream::StreamExt; -use tracing::{error, event, info, trace, warn}; +use tracing::{error, event, trace, warn}; #[derive(Clone)] pub struct CollabRealtimeServer { - #[allow(dead_code)] - storage: Arc, /// Keep track of all collab groups groups: Arc>, - // - pub user_by_device: Arc>, - /// Keep track of all object ids that a user is subscribed to - editing_collab_by_user: Arc>>, - /// Maintains a record of all client streams. A client stream associated with a user may be terminated for the following reasons: - /// 1. User disconnection. - /// 2. Server closes the connection due to a ping/pong timeout. - client_stream_by_user: Arc>, + connect_state: ConnectState, group_sender_by_object_id: Arc>, access_control: Arc, + /// Keep track of all object ids that a user is subscribed to + editing_collab_by_user: Arc>>, #[allow(dead_code)] metrics: Arc, } @@ -54,10 +48,9 @@ where metrics: Arc, command_recv: RTCommandReceiver, ) -> Result { + let connect_state = ConnectState::new(); let access_control = Arc::new(access_control); let groups = Arc::new(AllGroup::new(storage.clone(), access_control.clone())); - let client_stream_by_user: Arc> = Default::default(); - let editing_collab_by_user = Default::default(); let group_sender_by_object_id: Arc> = Arc::new(Default::default()); @@ -67,18 +60,16 @@ where &group_sender_by_object_id, Arc::downgrade(&groups), &metrics, - &client_stream_by_user, + &connect_state.client_stream_by_user, &storage, ); Ok(Self { - storage, groups, - user_by_device: Default::default(), - editing_collab_by_user, - client_stream_by_user, + connect_state, group_sender_by_object_id, access_control, + editing_collab_by_user: Arc::new(Default::default()), metrics, }) } @@ -97,38 +88,15 @@ where ) -> Pin>>> { let new_client_stream = CollabClientStream::new(conn_sink); let groups = self.groups.clone(); - let device_by_user = self.user_by_device.clone(); - let client_stream_by_user = self.client_stream_by_user.clone(); + let connect_control = self.connect_state.clone(); let editing_collab_by_user = self.editing_collab_by_user.clone(); Box::pin(async move { - let old_user = - device_by_user.insert(UserDevice::from(&connected_user), connected_user.clone()); - - trace!( - "[realtime]: new connection => {}, removing old: {:?}", - connected_user, - old_user - ); - - // If there was a previous connection for the same user, handle cleanup. - if let Some(old_user) = old_user { - // Remove and retrieve the old client stream if it exists. - if let Some((_, client_stream)) = client_stream_by_user.remove(&old_user) { - info!( - "Removing old stream for same user and device: {}", - &old_user - ); - // Notify the old stream of the duplicate connection. - client_stream - .sink - .do_send(RealtimeMessage::System(SystemMessage::DuplicateConnection)); - } + if let Some(old_user) = connect_control.handle_user_connect(connected_user, new_client_stream) + { // Remove the old user from all collaboration groups. remove_user_in_groups(&groups, &editing_collab_by_user, &old_user).await; } - - client_stream_by_user.insert(connected_user, new_client_stream); Ok(()) }) } @@ -145,23 +113,14 @@ where disconnect_user: RealtimeUser, ) -> Pin>>> { let groups = self.groups.clone(); - let client_stream_by_user = self.client_stream_by_user.clone(); + let connect_control = self.connect_state.clone(); let editing_collab_by_user = self.editing_collab_by_user.clone(); - let device_by_user = self.user_by_device.clone(); Box::pin(async move { - let user_device = UserDevice::from(&disconnect_user); - let was_removed = device_by_user.remove_if(&user_device, |_, existing_user| { - existing_user.session_id == disconnect_user.session_id - }); - + trace!("[realtime]: disconnect => {}", disconnect_user); + let was_removed = connect_control.handle_user_disconnect(&disconnect_user); if was_removed.is_some() { - trace!("[realtime]: disconnect => {}", disconnect_user); - remove_user_in_groups(&groups, &editing_collab_by_user, &disconnect_user).await; - if client_stream_by_user.remove(&disconnect_user).is_some() { - info!("remove client stream: {}", &disconnect_user); - } } Ok(()) @@ -175,7 +134,7 @@ where message_by_oid: MessageByObjectId, ) -> Pin>>> { let group_sender_by_object_id = self.group_sender_by_object_id.clone(); - let client_stream_by_user = self.client_stream_by_user.clone(); + let client_stream_by_user = self.connect_state.client_stream_by_user.clone(); let groups = self.groups.clone(); let edit_collab_by_user = self.editing_collab_by_user.clone(); let access_control = self.access_control.clone(); @@ -223,6 +182,14 @@ where Ok(()) }) } + + pub fn get_user_by_device(&self, user_device: &UserDevice) -> Option { + self + .connect_state + .user_by_device + .get(user_device) + .map(|entry| entry.value().clone()) + } } async fn remove_user_in_groups( @@ -257,7 +224,7 @@ pub trait RealtimeClientWebsocketSink: Send + Sync + 'static { } pub struct CollabClientStream { - sink: Arc, + pub(crate) sink: Arc, /// Used to receive messages from the collab server. The message will forward to the [CollabBroadcast] which /// will broadcast the message to all connected clients. /// diff --git a/src/biz/actix_ws/server/rt_actor.rs b/src/biz/actix_ws/server/rt_actor.rs index a8673cf8a78ac..25ff6a05d5394 100644 --- a/src/biz/actix_ws/server/rt_actor.rs +++ b/src/biz/actix_ws/server/rt_actor.rs @@ -99,10 +99,7 @@ where message, } = client_msg; - let user = self - .user_by_device - .get(&UserDevice::new(&device_id, uid)) - .map(|entry| entry.value().clone()); + let user = self.get_user_by_device(&UserDevice::new(&device_id, uid)); match (user, message.transform()) { (Some(user), Ok(messages)) => self.handle_client_message(user, messages),