Skip to content

Commit

Permalink
feat: Add get_inbound_group_sessions_stream to the CryptoStore tr…
Browse files Browse the repository at this point in the history
…ait.

This patch implements the `get_inbound_group_sessions_stream` method onto
the `CryptoStore` trait.

In order for the trait to continue being object-safe, this patch
implements a `StreamOf` struct that simply wraps any `T` where `T:
Stream`. `StreamOf` also implements `Stream` and forwards everything to
its inner stream.

This patch finally updates the test suite of `matrix-sdk-crypto` to test
this new `get_inbound_group_sessions_stream` method.
  • Loading branch information
Hywan committed Jan 29, 2024
1 parent 790d405 commit 641afac
Show file tree
Hide file tree
Showing 10 changed files with 274 additions and 67 deletions.
3 changes: 3 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions crates/matrix-sdk-crypto/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ testing = ["dep:http"]
[dependencies]
aes = "0.8.1"
as_variant = { workspace = true }
async-stream = { workspace = true }
async-trait = { workspace = true }
bs58 = { version = "0.5.0" }
byteorder = { workspace = true }
Expand Down
72 changes: 69 additions & 3 deletions crates/matrix-sdk-crypto/src/store/integration_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,14 @@
macro_rules! cryptostore_integration_tests {
() => {
mod cryptostore_integration_tests {
use std::time::Duration;
use std::collections::{BTreeMap, HashMap};
use std::{
collections::{BTreeMap, HashMap},
future::ready,
time::Duration,
};

use assert_matches::assert_matches;
use futures_util::{StreamExt, TryStreamExt};
use matrix_sdk_test::async_test;
use ruma::{
device_id,
Expand Down Expand Up @@ -266,6 +270,65 @@ macro_rules! cryptostore_integration_tests {
store.save_changes(changes).await.expect("Can't save group session");
}

#[async_test]
async fn save_many_inbound_group_sessions() {
let (account, store) = get_loaded_store("save_many_inbound_group_sessions").await;

const NUMBER_OF_SESSIONS_FOR_ROOM1: usize = 400;
const NUMBER_OF_SESSIONS_FOR_ROOM2: usize = 600;
const NUMBER_OF_SESSIONS: usize = NUMBER_OF_SESSIONS_FOR_ROOM1 + NUMBER_OF_SESSIONS_FOR_ROOM2;

let room_id1 = room_id!("!a:localhost");
let room_id2 = room_id!("!b:localhost");
let mut sessions = Vec::with_capacity(NUMBER_OF_SESSIONS);


for i in 0..(NUMBER_OF_SESSIONS) {
let (_, session) = account.create_group_session_pair_with_defaults(
if i < NUMBER_OF_SESSIONS_FOR_ROOM1 {
&room_id1
} else {
&room_id2
}
).await;
sessions.push(session);
}

let changes =
Changes { inbound_group_sessions: sessions, ..Default::default() };

store.save_changes(changes).await.expect("Can't save group session");

assert_eq!(store.get_inbound_group_sessions().await.unwrap().len(), NUMBER_OF_SESSIONS);
assert_eq!(store.get_inbound_group_sessions_stream().await.unwrap().collect::<Vec<_>>().await.len(), NUMBER_OF_SESSIONS);

// Inefficient because all sessions are collected once in a `Vec`, then filtered, then collected again in another `Vec`.
assert_eq!(
store
.get_inbound_group_sessions()
.await
.unwrap()
.into_iter()
.filter(|session: &InboundGroupSession| session.room_id() == room_id2)
.collect::<Vec<_>>()
.len(),
NUMBER_OF_SESSIONS_FOR_ROOM2
);

// Efficient because sessions are filtered, then collected only once.
assert_eq!(
store
.get_inbound_group_sessions_stream()
.await
.unwrap()
.try_filter(|session: &InboundGroupSession| ready(session.room_id() == room_id2))
.collect::<Vec<_>>()
.await
.len(),
NUMBER_OF_SESSIONS_FOR_ROOM2
);
}

#[async_test]
async fn save_inbound_group_session_for_backup() {
let (account, store) =
Expand All @@ -286,6 +349,7 @@ macro_rules! cryptostore_integration_tests {
.unwrap();
assert_eq!(session, loaded_session);
assert_eq!(store.get_inbound_group_sessions().await.unwrap().len(), 1);
assert_eq!(store.get_inbound_group_sessions_stream().await.unwrap().collect::<Vec<_>>().await.len(), 1);
assert_eq!(store.inbound_group_session_counts().await.unwrap().total, 1);
assert_eq!(store.inbound_group_session_counts().await.unwrap().backed_up, 0);

Expand Down Expand Up @@ -373,7 +437,8 @@ macro_rules! cryptostore_integration_tests {
async fn load_inbound_group_session() {
let dir = "load_inbound_group_session";
let (account, store) = get_loaded_store(dir).await;
assert_eq!(store.get_inbound_group_sessions().await.unwrap().len(), 0);
assert!(store.get_inbound_group_sessions().await.unwrap().is_empty());
assert!(store.get_inbound_group_sessions_stream().await.unwrap().collect::<Vec<_>>().await.is_empty());

let room_id = &room_id!("!test:localhost");
let (_, session) = account.create_group_session_pair_with_defaults(room_id).await;
Expand Down Expand Up @@ -402,6 +467,7 @@ macro_rules! cryptostore_integration_tests {
let export = loaded_session.export().await;

assert_eq!(store.get_inbound_group_sessions().await.unwrap().len(), 1);
assert_eq!(store.get_inbound_group_sessions_stream().await.unwrap().collect::<Vec<_>>().await.len(), 1);
assert_eq!(store.inbound_group_session_counts().await.unwrap().total, 1);
}

Expand Down
17 changes: 17 additions & 0 deletions crates/matrix-sdk-crypto/src/store/memorystore.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ use std::{
time::{Duration, Instant},
};

use async_stream::stream;
use async_trait::async_trait;
use ruma::{
events::secret::request::SecretName, DeviceId, OwnedDeviceId, OwnedRoomId, OwnedTransactionId,
Expand All @@ -28,6 +29,7 @@ use tokio::sync::{Mutex, RwLock};

use super::{
caches::{DeviceStore, GroupSessionStore, SessionStore},
traits::StreamOf,
Account, BackupKeys, Changes, CryptoStore, InboundGroupSession, PendingChanges, RoomKeyCounts,
RoomSettings, Session,
};
Expand Down Expand Up @@ -251,6 +253,21 @@ impl CryptoStore for MemoryStore {
Ok(self.inbound_group_sessions.get_all())
}

async fn get_inbound_group_sessions_stream(
&self,
) -> Result<StreamOf<super::error::Result<InboundGroupSession>>> {
// There is no stream API for this `MemoryStore`. Let's simply consume the `Vec`
// from `get_inbound_group_sessions` as a stream. It's not ideal, but it's
// OK-ish for now.
let inbound_group_sessions = self.inbound_group_sessions.get_all();

Ok(StreamOf::new(Box::pin(stream! {
for item in inbound_group_sessions {
yield Ok(item);
}
})))
}

async fn inbound_group_session_counts(&self) -> Result<RoomKeyCounts> {
let backed_up =
self.get_inbound_group_sessions().await?.into_iter().filter(|s| s.backed_up()).count();
Expand Down
2 changes: 1 addition & 1 deletion crates/matrix-sdk-crypto/src/store/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ pub(crate) use crypto_store_wrapper::CryptoStoreWrapper;
pub use error::{CryptoStoreError, Result};
use matrix_sdk_common::{store_locks::CrossProcessStoreLock, timeout::timeout};
pub use memorystore::MemoryStore;
pub use traits::{CryptoStore, DynCryptoStore, IntoCryptoStore};
pub use traits::{CryptoStore, DynCryptoStore, IntoCryptoStore, StreamOf};

pub use crate::gossiping::{GossipRequest, SecretInfo};

Expand Down
47 changes: 46 additions & 1 deletion crates/matrix-sdk-crypto/src/store/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,16 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::{collections::HashMap, fmt, sync::Arc};
use std::{
collections::HashMap,
fmt,
pin::Pin,
sync::Arc,
task::{Context, Poll},
};

use async_trait::async_trait;
use futures_core::Stream;
use matrix_sdk_common::AsyncTraitDeps;
use ruma::{
events::secret::request::SecretName, DeviceId, OwnedDeviceId, RoomId, TransactionId, UserId,
Expand Down Expand Up @@ -103,6 +110,11 @@ pub trait CryptoStore: AsyncTraitDeps {
/// Get all the inbound group sessions we have stored.
async fn get_inbound_group_sessions(&self) -> Result<Vec<InboundGroupSession>, Self::Error>;

/// Get all the inbound group session we have stored, as a `Stream`.
async fn get_inbound_group_sessions_stream(
&self,
) -> Result<StreamOf<Result<InboundGroupSession>>, Self::Error>;

/// Get the number inbound group sessions we have and how many of them are
/// backed up.
async fn inbound_group_session_counts(&self) -> Result<RoomKeyCounts, Self::Error>;
Expand Down Expand Up @@ -328,6 +340,12 @@ impl<T: CryptoStore> CryptoStore for EraseCryptoStoreError<T> {
self.0.get_inbound_group_sessions().await.map_err(Into::into)
}

async fn get_inbound_group_sessions_stream(
&self,
) -> Result<StreamOf<Result<InboundGroupSession>>> {
self.0.get_inbound_group_sessions_stream().await.map_err(Into::into)
}

async fn inbound_group_session_counts(&self) -> Result<RoomKeyCounts> {
self.0.inbound_group_session_counts().await.map_err(Into::into)
}
Expand Down Expand Up @@ -508,3 +526,30 @@ impl IntoCryptoStore for Arc<DynCryptoStore> {
self
}
}

/// A concrete type wrapping a `Pin<Box<dyn Stream<Item = T>>>`.
///
/// It is used only to make the [`CryptoStore`] trait object-safe. Please don't
/// use it for other things.
pub struct StreamOf<T>(Pin<Box<dyn Stream<Item = T>>>);

impl<T> StreamOf<T> {
/// Create a new `Self`.
pub fn new(stream: Pin<Box<dyn Stream<Item = T>>>) -> Self {
Self(stream)
}
}

impl<T> fmt::Debug for StreamOf<T> {
fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
formatter.debug_tuple("StreamOf").finish()
}
}

impl<T> Stream for StreamOf<T> {
type Item = T;

fn poll_next(mut self: Pin<&mut Self>, context: &mut Context<'_>) -> Poll<Option<Self::Item>> {
self.0.as_mut().poll_next(context)
}
}
Loading

0 comments on commit 641afac

Please sign in to comment.