Skip to content

Commit

Permalink
chore: separate connect control (AppFlowy-IO#413)
Browse files Browse the repository at this point in the history
* chore: separate connect control

* chore: add tests

* chore: add tests
  • Loading branch information
appflowy authored Mar 23, 2024
1 parent acc1341 commit 4878d51
Show file tree
Hide file tree
Showing 5 changed files with 204 additions and 64 deletions.
2 changes: 1 addition & 1 deletion libs/collab-rt/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ serde_json.workspace = true
thiserror = "1.0.56"
anyhow = "1"

collab = { version = "0.1.0"}
collab = { version = "0.1.0", features = ["async-plugin"]}
collab-entity = { version = "0.1.0" }
collab-folder = { version = "0.1.0" }
collab-document = { version = "0.1.0" }
Expand Down
174 changes: 174 additions & 0 deletions libs/collab-rt/src/connect_state.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
use crate::CollabClientStream;
use collab_rt_entity::message::{RealtimeMessage, SystemMessage};
use collab_rt_entity::user::{RealtimeUser, UserDevice};
use dashmap::DashMap;

use std::sync::Arc;
use tracing::{info, trace};

#[derive(Clone, Default)]
pub struct ConnectState {
pub(crate) user_by_device: Arc<DashMap<UserDevice, RealtimeUser>>,
/// Maintains a record of all client streams. A client stream associated with a user may be terminated for the following reasons:
/// 1. User disconnection.
/// 2. Server closes the connection due to a ping/pong timeout.
pub(crate) client_stream_by_user: Arc<DashMap<RealtimeUser, CollabClientStream>>,
}

impl ConnectState {
pub fn new() -> Self {
Self::default()
}
pub fn handle_user_connect(
&self,
new_user: RealtimeUser,
client_stream: CollabClientStream,
) -> Option<RealtimeUser> {
let old_user = self
.user_by_device
.insert(UserDevice::from(&new_user), new_user.clone());

trace!(
"[realtime]: new connection => {}, removing old: {:?}",
new_user,
old_user
);

if let Some(old_user) = &old_user {
// Remove and retrieve the old client stream if it exists.
if let Some((_, client_stream)) = self.client_stream_by_user.remove(old_user) {
info!("Removing old stream for same user and device: {}", old_user);
// Notify the old stream of the duplicate connection.
client_stream
.sink
.do_send(RealtimeMessage::System(SystemMessage::DuplicateConnection));
}
// Remove the old user from all collaboration groups.
}
self.client_stream_by_user.insert(new_user, client_stream);

old_user
}

pub fn handle_user_disconnect(
&self,
disconnect_user: &RealtimeUser,
) -> Option<(UserDevice, RealtimeUser)> {
let user_device = UserDevice::from(disconnect_user);
let was_removed = self
.user_by_device
.remove_if(&user_device, |_, existing_user| {
existing_user.session_id == disconnect_user.session_id
});

if was_removed.is_some() && self.client_stream_by_user.remove(disconnect_user).is_some() {
info!("remove client stream: {}", &disconnect_user);
}

was_removed
}

#[allow(dead_code)]
fn num_connected_users(&self) -> usize {
self.user_by_device.len()
}

#[allow(dead_code)]
fn get_user_by_device(&self, user_device: &UserDevice) -> Option<RealtimeUser> {
self.user_by_device.get(user_device).map(|v| v.clone())
}
}

#[cfg(test)]
mod tests {
use crate::connect_state::ConnectState;
use crate::{CollabClientStream, RealtimeClientWebsocketSink};
use collab_rt_entity::message::RealtimeMessage;
use collab_rt_entity::user::{RealtimeUser, UserDevice};

struct MockSink;

impl RealtimeClientWebsocketSink for MockSink {
fn do_send(&self, _message: RealtimeMessage) {}
}

fn mock_user(uid: i64, device_id: &str) -> RealtimeUser {
RealtimeUser::new(
uid,
device_id.to_string(),
uuid::Uuid::new_v4().to_string(),
chrono::Utc::now().timestamp(),
)
}

fn mock_stream() -> CollabClientStream {
CollabClientStream::new(MockSink)
}

#[tokio::test]
async fn same_user_different_device_connect_test() {
let connect_state = ConnectState::new();
let user_device_a = mock_user(1, "device_a");
let user_device_b = mock_user(1, "device_b");
connect_state.handle_user_connect(user_device_a, mock_stream());
connect_state.handle_user_connect(user_device_b, mock_stream());

assert_eq!(connect_state.num_connected_users(), 2);
}

#[tokio::test]
async fn same_user_same_device_connect_test() {
let connect_state = ConnectState::new();
let user_device_a = mock_user(1, "device_a");
let user_device_b = mock_user(1, "device_a");
connect_state.handle_user_connect(user_device_a, mock_stream());
connect_state.handle_user_connect(user_device_b.clone(), mock_stream());

assert_eq!(connect_state.num_connected_users(), 1);
let user = connect_state
.get_user_by_device(&UserDevice::from(&user_device_b))
.unwrap();
assert_eq!(user, user_device_b);
}

#[tokio::test]
async fn multiple_devices_connect_test() {
let user_a = vec![
mock_user(1, "device_a"),
mock_user(1, "device_b"),
mock_user(1, "device_c"),
mock_user(1, "device_d"),
];

let user_b = vec![
mock_user(2, "device_a"),
mock_user(2, "device_b"),
mock_user(2, "device_b"),
mock_user(2, "device_a"),
];

let connect_state = ConnectState::new();

let (tx, rx_1) = tokio::sync::oneshot::channel();
let cloned_connect_state = connect_state.clone();
tokio::spawn(async move {
for user in user_a {
cloned_connect_state.handle_user_connect(user, mock_stream());
}
tx.send(()).unwrap();
});

let (tx, rx_2) = tokio::sync::oneshot::channel();
let cloned_connect_state = connect_state.clone();
tokio::spawn(async move {
for user in user_b {
cloned_connect_state.handle_user_connect(user, mock_stream());
}
tx.send(()).unwrap();
});

let _ = futures::future::join(rx_1, rx_2).await;

assert_eq!(connect_state.num_connected_users(), 6);
}
}
2 changes: 2 additions & 0 deletions libs/collab-rt/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
mod collaborate;
pub mod command;
pub mod connect_state;
pub mod error;
mod metrics;
mod permission;
mod rt_server;
mod util;

pub use metrics::*;
pub use permission::*;
pub use rt_server::*;
85 changes: 26 additions & 59 deletions libs/collab-rt/src/rt_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,38 +7,32 @@ use crate::{spawn_metrics, CollabRealtimeMetrics, RealtimeAccessControl};

use anyhow::Result;
use collab_rt_entity::collab_msg::{ClientCollabMessage, CollabSinkMessage};
use collab_rt_entity::message::{MessageByObjectId, RealtimeMessage, SystemMessage};
use collab_rt_entity::message::{MessageByObjectId, RealtimeMessage};
use collab_rt_entity::user::{Editing, RealtimeUser, UserDevice};
use dashmap::mapref::entry::Entry;
use dashmap::DashMap;
use database::collab::CollabStorage;
use std::collections::HashSet;
use std::future::Future;

use crate::connect_state::ConnectState;
use async_trait::async_trait;
use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;
use tokio_stream::wrappers::{BroadcastStream, ReceiverStream};
use tokio_stream::StreamExt;
use tracing::{error, event, info, trace, warn};
use tracing::{error, event, trace, warn};

#[derive(Clone)]
pub struct CollabRealtimeServer<S, AC> {
#[allow(dead_code)]
storage: Arc<S>,
/// Keep track of all collab groups
groups: Arc<AllGroup<S, AC>>,
//
pub user_by_device: Arc<DashMap<UserDevice, RealtimeUser>>,
/// Keep track of all object ids that a user is subscribed to
editing_collab_by_user: Arc<DashMap<RealtimeUser, HashSet<Editing>>>,
/// Maintains a record of all client streams. A client stream associated with a user may be terminated for the following reasons:
/// 1. User disconnection.
/// 2. Server closes the connection due to a ping/pong timeout.
client_stream_by_user: Arc<DashMap<RealtimeUser, CollabClientStream>>,
connect_state: ConnectState,
group_sender_by_object_id: Arc<DashMap<String, GroupCommandSender>>,
access_control: Arc<AC>,
/// Keep track of all object ids that a user is subscribed to
editing_collab_by_user: Arc<DashMap<RealtimeUser, HashSet<Editing>>>,
#[allow(dead_code)]
metrics: Arc<CollabRealtimeMetrics>,
}
Expand All @@ -54,10 +48,9 @@ where
metrics: Arc<CollabRealtimeMetrics>,
command_recv: RTCommandReceiver,
) -> Result<Self, RealtimeError> {
let connect_state = ConnectState::new();
let access_control = Arc::new(access_control);
let groups = Arc::new(AllGroup::new(storage.clone(), access_control.clone()));
let client_stream_by_user: Arc<DashMap<RealtimeUser, CollabClientStream>> = Default::default();
let editing_collab_by_user = Default::default();
let group_sender_by_object_id: Arc<DashMap<String, GroupCommandSender>> =
Arc::new(Default::default());

Expand All @@ -67,18 +60,16 @@ where
&group_sender_by_object_id,
Arc::downgrade(&groups),
&metrics,
&client_stream_by_user,
&connect_state.client_stream_by_user,
&storage,
);

Ok(Self {
storage,
groups,
user_by_device: Default::default(),
editing_collab_by_user,
client_stream_by_user,
connect_state,
group_sender_by_object_id,
access_control,
editing_collab_by_user: Arc::new(Default::default()),
metrics,
})
}
Expand All @@ -97,38 +88,15 @@ where
) -> Pin<Box<dyn Future<Output = Result<(), RealtimeError>>>> {
let new_client_stream = CollabClientStream::new(conn_sink);
let groups = self.groups.clone();
let device_by_user = self.user_by_device.clone();
let client_stream_by_user = self.client_stream_by_user.clone();
let connect_control = self.connect_state.clone();
let editing_collab_by_user = self.editing_collab_by_user.clone();

Box::pin(async move {
let old_user =
device_by_user.insert(UserDevice::from(&connected_user), connected_user.clone());

trace!(
"[realtime]: new connection => {}, removing old: {:?}",
connected_user,
old_user
);

// If there was a previous connection for the same user, handle cleanup.
if let Some(old_user) = old_user {
// Remove and retrieve the old client stream if it exists.
if let Some((_, client_stream)) = client_stream_by_user.remove(&old_user) {
info!(
"Removing old stream for same user and device: {}",
&old_user
);
// Notify the old stream of the duplicate connection.
client_stream
.sink
.do_send(RealtimeMessage::System(SystemMessage::DuplicateConnection));
}
if let Some(old_user) = connect_control.handle_user_connect(connected_user, new_client_stream)
{
// Remove the old user from all collaboration groups.
remove_user_in_groups(&groups, &editing_collab_by_user, &old_user).await;
}

client_stream_by_user.insert(connected_user, new_client_stream);
Ok(())
})
}
Expand All @@ -145,23 +113,14 @@ where
disconnect_user: RealtimeUser,
) -> Pin<Box<dyn Future<Output = Result<(), RealtimeError>>>> {
let groups = self.groups.clone();
let client_stream_by_user = self.client_stream_by_user.clone();
let connect_control = self.connect_state.clone();
let editing_collab_by_user = self.editing_collab_by_user.clone();
let device_by_user = self.user_by_device.clone();

Box::pin(async move {
let user_device = UserDevice::from(&disconnect_user);
let was_removed = device_by_user.remove_if(&user_device, |_, existing_user| {
existing_user.session_id == disconnect_user.session_id
});

trace!("[realtime]: disconnect => {}", disconnect_user);
let was_removed = connect_control.handle_user_disconnect(&disconnect_user);
if was_removed.is_some() {
trace!("[realtime]: disconnect => {}", disconnect_user);

remove_user_in_groups(&groups, &editing_collab_by_user, &disconnect_user).await;
if client_stream_by_user.remove(&disconnect_user).is_some() {
info!("remove client stream: {}", &disconnect_user);
}
}

Ok(())
Expand All @@ -175,7 +134,7 @@ where
message_by_oid: MessageByObjectId,
) -> Pin<Box<dyn Future<Output = Result<(), RealtimeError>>>> {
let group_sender_by_object_id = self.group_sender_by_object_id.clone();
let client_stream_by_user = self.client_stream_by_user.clone();
let client_stream_by_user = self.connect_state.client_stream_by_user.clone();
let groups = self.groups.clone();
let edit_collab_by_user = self.editing_collab_by_user.clone();
let access_control = self.access_control.clone();
Expand Down Expand Up @@ -223,6 +182,14 @@ where
Ok(())
})
}

pub fn get_user_by_device(&self, user_device: &UserDevice) -> Option<RealtimeUser> {
self
.connect_state
.user_by_device
.get(user_device)
.map(|entry| entry.value().clone())
}
}

async fn remove_user_in_groups<S, AC>(
Expand Down Expand Up @@ -257,7 +224,7 @@ pub trait RealtimeClientWebsocketSink: Send + Sync + 'static {
}

pub struct CollabClientStream {
sink: Arc<dyn RealtimeClientWebsocketSink>,
pub(crate) sink: Arc<dyn RealtimeClientWebsocketSink>,
/// Used to receive messages from the collab server. The message will forward to the [CollabBroadcast] which
/// will broadcast the message to all connected clients.
///
Expand Down
5 changes: 1 addition & 4 deletions src/biz/actix_ws/server/rt_actor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,10 +99,7 @@ where
message,
} = client_msg;

let user = self
.user_by_device
.get(&UserDevice::new(&device_id, uid))
.map(|entry| entry.value().clone());
let user = self.get_user_by_device(&UserDevice::new(&device_id, uid));

match (user, message.transform()) {
(Some(user), Ok(messages)) => self.handle_client_message(user, messages),
Expand Down

0 comments on commit 4878d51

Please sign in to comment.