From 641afacd45e7e2f836e7edbe6dce033dcca9f5bc Mon Sep 17 00:00:00 2001 From: Ivan Enderlin Date: Fri, 26 Jan 2024 17:10:01 +0100 Subject: [PATCH] feat: Add `get_inbound_group_sessions_stream` to the `CryptoStore` trait. This patch implements the `get_inbound_group_sessions_stream` method onto the `CryptoStore` trait. In order for the trait to continue being object-safe, this patch implements a `StreamOf` struct that simply wraps any `T` where `T: Stream`. `StreamOf` also implements `Stream` and forwards everything to its inner stream. This patch finally updates the test suite of `matrix-sdk-crypto` to test this new `get_inbound_group_sessions_stream` method. --- Cargo.lock | 3 + crates/matrix-sdk-crypto/Cargo.toml | 1 + .../src/store/integration_tests.rs | 72 ++++++++++++++- .../src/store/memorystore.rs | 17 ++++ crates/matrix-sdk-crypto/src/store/mod.rs | 2 +- crates/matrix-sdk-crypto/src/store/traits.rs | 47 +++++++++- .../src/crypto_store/mod.rs | 89 +++++++++++++------ crates/matrix-sdk-indexeddb/src/stream.rs | 88 +++++++++++------- crates/matrix-sdk-sqlite/Cargo.toml | 2 + crates/matrix-sdk-sqlite/src/crypto_store.rs | 20 ++++- 10 files changed, 274 insertions(+), 67 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index ccf0223ff9e..67477e14b2b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3098,6 +3098,7 @@ dependencies = [ "as_variant", "assert_matches", "assert_matches2", + "async-stream", "async-trait", "bs58", "byteorder", @@ -3279,8 +3280,10 @@ name = "matrix-sdk-sqlite" version = "0.7.0" dependencies = [ "assert_matches", + "async-stream", "async-trait", "deadpool-sqlite", + "futures-util", "glob", "itertools 0.12.0", "matrix-sdk-base", diff --git a/crates/matrix-sdk-crypto/Cargo.toml b/crates/matrix-sdk-crypto/Cargo.toml index 1257eda9278..7cbd34571d2 100644 --- a/crates/matrix-sdk-crypto/Cargo.toml +++ b/crates/matrix-sdk-crypto/Cargo.toml @@ -28,6 +28,7 @@ testing = ["dep:http"] [dependencies] aes = "0.8.1" as_variant = { workspace = true } +async-stream = { workspace = true } async-trait = { workspace = true } bs58 = { version = "0.5.0" } byteorder = { workspace = true } diff --git a/crates/matrix-sdk-crypto/src/store/integration_tests.rs b/crates/matrix-sdk-crypto/src/store/integration_tests.rs index 9373300efea..de44d0d5f70 100644 --- a/crates/matrix-sdk-crypto/src/store/integration_tests.rs +++ b/crates/matrix-sdk-crypto/src/store/integration_tests.rs @@ -3,10 +3,14 @@ macro_rules! cryptostore_integration_tests { () => { mod cryptostore_integration_tests { - use std::time::Duration; - use std::collections::{BTreeMap, HashMap}; + use std::{ + collections::{BTreeMap, HashMap}, + future::ready, + time::Duration, + }; use assert_matches::assert_matches; + use futures_util::{StreamExt, TryStreamExt}; use matrix_sdk_test::async_test; use ruma::{ device_id, @@ -266,6 +270,65 @@ macro_rules! cryptostore_integration_tests { store.save_changes(changes).await.expect("Can't save group session"); } + #[async_test] + async fn save_many_inbound_group_sessions() { + let (account, store) = get_loaded_store("save_many_inbound_group_sessions").await; + + const NUMBER_OF_SESSIONS_FOR_ROOM1: usize = 400; + const NUMBER_OF_SESSIONS_FOR_ROOM2: usize = 600; + const NUMBER_OF_SESSIONS: usize = NUMBER_OF_SESSIONS_FOR_ROOM1 + NUMBER_OF_SESSIONS_FOR_ROOM2; + + let room_id1 = room_id!("!a:localhost"); + let room_id2 = room_id!("!b:localhost"); + let mut sessions = Vec::with_capacity(NUMBER_OF_SESSIONS); + + + for i in 0..(NUMBER_OF_SESSIONS) { + let (_, session) = account.create_group_session_pair_with_defaults( + if i < NUMBER_OF_SESSIONS_FOR_ROOM1 { + &room_id1 + } else { + &room_id2 + } + ).await; + sessions.push(session); + } + + let changes = + Changes { inbound_group_sessions: sessions, ..Default::default() }; + + store.save_changes(changes).await.expect("Can't save group session"); + + assert_eq!(store.get_inbound_group_sessions().await.unwrap().len(), NUMBER_OF_SESSIONS); + assert_eq!(store.get_inbound_group_sessions_stream().await.unwrap().collect::>().await.len(), NUMBER_OF_SESSIONS); + + // Inefficient because all sessions are collected once in a `Vec`, then filtered, then collected again in another `Vec`. + assert_eq!( + store + .get_inbound_group_sessions() + .await + .unwrap() + .into_iter() + .filter(|session: &InboundGroupSession| session.room_id() == room_id2) + .collect::>() + .len(), + NUMBER_OF_SESSIONS_FOR_ROOM2 + ); + + // Efficient because sessions are filtered, then collected only once. + assert_eq!( + store + .get_inbound_group_sessions_stream() + .await + .unwrap() + .try_filter(|session: &InboundGroupSession| ready(session.room_id() == room_id2)) + .collect::>() + .await + .len(), + NUMBER_OF_SESSIONS_FOR_ROOM2 + ); + } + #[async_test] async fn save_inbound_group_session_for_backup() { let (account, store) = @@ -286,6 +349,7 @@ macro_rules! cryptostore_integration_tests { .unwrap(); assert_eq!(session, loaded_session); assert_eq!(store.get_inbound_group_sessions().await.unwrap().len(), 1); + assert_eq!(store.get_inbound_group_sessions_stream().await.unwrap().collect::>().await.len(), 1); assert_eq!(store.inbound_group_session_counts().await.unwrap().total, 1); assert_eq!(store.inbound_group_session_counts().await.unwrap().backed_up, 0); @@ -373,7 +437,8 @@ macro_rules! cryptostore_integration_tests { async fn load_inbound_group_session() { let dir = "load_inbound_group_session"; let (account, store) = get_loaded_store(dir).await; - assert_eq!(store.get_inbound_group_sessions().await.unwrap().len(), 0); + assert!(store.get_inbound_group_sessions().await.unwrap().is_empty()); + assert!(store.get_inbound_group_sessions_stream().await.unwrap().collect::>().await.is_empty()); let room_id = &room_id!("!test:localhost"); let (_, session) = account.create_group_session_pair_with_defaults(room_id).await; @@ -402,6 +467,7 @@ macro_rules! cryptostore_integration_tests { let export = loaded_session.export().await; assert_eq!(store.get_inbound_group_sessions().await.unwrap().len(), 1); + assert_eq!(store.get_inbound_group_sessions_stream().await.unwrap().collect::>().await.len(), 1); assert_eq!(store.inbound_group_session_counts().await.unwrap().total, 1); } diff --git a/crates/matrix-sdk-crypto/src/store/memorystore.rs b/crates/matrix-sdk-crypto/src/store/memorystore.rs index f75b2155f5e..fd0c1ecc32e 100644 --- a/crates/matrix-sdk-crypto/src/store/memorystore.rs +++ b/crates/matrix-sdk-crypto/src/store/memorystore.rs @@ -19,6 +19,7 @@ use std::{ time::{Duration, Instant}, }; +use async_stream::stream; use async_trait::async_trait; use ruma::{ events::secret::request::SecretName, DeviceId, OwnedDeviceId, OwnedRoomId, OwnedTransactionId, @@ -28,6 +29,7 @@ use tokio::sync::{Mutex, RwLock}; use super::{ caches::{DeviceStore, GroupSessionStore, SessionStore}, + traits::StreamOf, Account, BackupKeys, Changes, CryptoStore, InboundGroupSession, PendingChanges, RoomKeyCounts, RoomSettings, Session, }; @@ -251,6 +253,21 @@ impl CryptoStore for MemoryStore { Ok(self.inbound_group_sessions.get_all()) } + async fn get_inbound_group_sessions_stream( + &self, + ) -> Result>> { + // There is no stream API for this `MemoryStore`. Let's simply consume the `Vec` + // from `get_inbound_group_sessions` as a stream. It's not ideal, but it's + // OK-ish for now. + let inbound_group_sessions = self.inbound_group_sessions.get_all(); + + Ok(StreamOf::new(Box::pin(stream! { + for item in inbound_group_sessions { + yield Ok(item); + } + }))) + } + async fn inbound_group_session_counts(&self) -> Result { let backed_up = self.get_inbound_group_sessions().await?.into_iter().filter(|s| s.backed_up()).count(); diff --git a/crates/matrix-sdk-crypto/src/store/mod.rs b/crates/matrix-sdk-crypto/src/store/mod.rs index 21aa405060c..cfaf05c4892 100644 --- a/crates/matrix-sdk-crypto/src/store/mod.rs +++ b/crates/matrix-sdk-crypto/src/store/mod.rs @@ -90,7 +90,7 @@ pub(crate) use crypto_store_wrapper::CryptoStoreWrapper; pub use error::{CryptoStoreError, Result}; use matrix_sdk_common::{store_locks::CrossProcessStoreLock, timeout::timeout}; pub use memorystore::MemoryStore; -pub use traits::{CryptoStore, DynCryptoStore, IntoCryptoStore}; +pub use traits::{CryptoStore, DynCryptoStore, IntoCryptoStore, StreamOf}; pub use crate::gossiping::{GossipRequest, SecretInfo}; diff --git a/crates/matrix-sdk-crypto/src/store/traits.rs b/crates/matrix-sdk-crypto/src/store/traits.rs index bebcdd03bc4..44d7f37f2e5 100644 --- a/crates/matrix-sdk-crypto/src/store/traits.rs +++ b/crates/matrix-sdk-crypto/src/store/traits.rs @@ -12,9 +12,16 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::{collections::HashMap, fmt, sync::Arc}; +use std::{ + collections::HashMap, + fmt, + pin::Pin, + sync::Arc, + task::{Context, Poll}, +}; use async_trait::async_trait; +use futures_core::Stream; use matrix_sdk_common::AsyncTraitDeps; use ruma::{ events::secret::request::SecretName, DeviceId, OwnedDeviceId, RoomId, TransactionId, UserId, @@ -103,6 +110,11 @@ pub trait CryptoStore: AsyncTraitDeps { /// Get all the inbound group sessions we have stored. async fn get_inbound_group_sessions(&self) -> Result, Self::Error>; + /// Get all the inbound group session we have stored, as a `Stream`. + async fn get_inbound_group_sessions_stream( + &self, + ) -> Result>, Self::Error>; + /// Get the number inbound group sessions we have and how many of them are /// backed up. async fn inbound_group_session_counts(&self) -> Result; @@ -328,6 +340,12 @@ impl CryptoStore for EraseCryptoStoreError { self.0.get_inbound_group_sessions().await.map_err(Into::into) } + async fn get_inbound_group_sessions_stream( + &self, + ) -> Result>> { + self.0.get_inbound_group_sessions_stream().await.map_err(Into::into) + } + async fn inbound_group_session_counts(&self) -> Result { self.0.inbound_group_session_counts().await.map_err(Into::into) } @@ -508,3 +526,30 @@ impl IntoCryptoStore for Arc { self } } + +/// A concrete type wrapping a `Pin>>`. +/// +/// It is used only to make the [`CryptoStore`] trait object-safe. Please don't +/// use it for other things. +pub struct StreamOf(Pin>>); + +impl StreamOf { + /// Create a new `Self`. + pub fn new(stream: Pin>>) -> Self { + Self(stream) + } +} + +impl fmt::Debug for StreamOf { + fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { + formatter.debug_tuple("StreamOf").finish() + } +} + +impl Stream for StreamOf { + type Item = T; + + fn poll_next(mut self: Pin<&mut Self>, context: &mut Context<'_>) -> Poll> { + self.0.as_mut().poll_next(context) + } +} diff --git a/crates/matrix-sdk-indexeddb/src/crypto_store/mod.rs b/crates/matrix-sdk-indexeddb/src/crypto_store/mod.rs index 33b1b7ef625..d27635c316e 100644 --- a/crates/matrix-sdk-indexeddb/src/crypto_store/mod.rs +++ b/crates/matrix-sdk-indexeddb/src/crypto_store/mod.rs @@ -14,10 +14,12 @@ use std::{ collections::HashMap, + num::NonZeroUsize, sync::{Arc, RwLock}, }; use async_trait::async_trait; +use futures_util::StreamExt; use gloo_utils::format::JsValueSerdeExt; use indexed_db_futures::prelude::*; use matrix_sdk_crypto::{ @@ -26,8 +28,8 @@ use matrix_sdk_crypto::{ Session, StaticAccountData, }, store::{ - caches::SessionStore, BackupKeys, Changes, CryptoStore, CryptoStoreError, PendingChanges, - RoomKeyCounts, RoomSettings, + self, caches::SessionStore, BackupKeys, Changes, CryptoStore, CryptoStoreError, + PendingChanges, RoomKeyCounts, RoomSettings, StreamOf, }, types::events::room_key_withheld::RoomKeyWithheldEvent, Account, GossipRequest, GossippedSecret, ReadOnlyDevice, ReadOnlyUserIdentities, SecretInfo, @@ -43,8 +45,10 @@ use tracing::{debug, warn}; use wasm_bindgen::JsValue; use web_sys::{DomException, IdbKeyRange}; -use crate::crypto_store::{ - indexeddb_serializer::IndexeddbSerializer, migrations::open_and_upgrade_db, +use crate::{ + crypto_store::{indexeddb_serializer::IndexeddbSerializer, migrations::open_and_upgrade_db}, + stream::StreamByRenewedCursor, + IndexeddbStateStoreError, }; mod indexeddb_serializer; @@ -94,16 +98,16 @@ mod keys { pub const RECOVERY_KEY_V1: &str = "recovery_key_v1"; } -/// An implementation of [CryptoStore] that uses [IndexedDB] for persistent +/// An implementation of [`CryptoStore`] that uses [`IndexedDB`] for persistent /// storage. /// -/// [IndexedDB]: https://developer.mozilla.org/en-US/docs/Web/API/IndexedDB_API +/// [`IndexedDB`]: https://developer.mozilla.org/en-US/docs/Web/API/IndexedDB_API pub struct IndexeddbCryptoStore { static_account: RwLock>, name: String, - pub(crate) inner: IdbDatabase, + pub(crate) inner: Arc, - serializer: IndexeddbSerializer, + serializer: Arc, session_cache: SessionStore, save_changes_lock: Arc>, } @@ -175,8 +179,8 @@ impl IndexeddbCryptoStore { Ok(Self { name, session_cache, - inner: db, - serializer, + inner: Arc::new(db), + serializer: Arc::new(serializer), static_account: RwLock::new(None), save_changes_lock: Default::default(), }) @@ -280,23 +284,7 @@ impl IndexeddbCryptoStore { &self, stored_value: JsValue, ) -> Result { - let idb_object: InboundGroupSessionIndexedDbObject = - serde_wasm_bindgen::from_value(stored_value)?; - let pickled_session = - self.serializer.deserialize_value_from_bytes(&idb_object.pickled_session)?; - let session = InboundGroupSession::from_pickle(pickled_session) - .map_err(|e| IndexeddbCryptoStoreError::CryptoStoreError(e.into()))?; - - // Although a "backed up" flag is stored inside `idb_object.pickled_session`, it - // is not maintained when backups are reset. Overwrite the flag with the - // needs_backup value from the IDB object. - if idb_object.needs_backup { - session.reset_backup_state(); - } else { - session.mark_as_backed_up(); - } - - Ok(session) + deserialize_inbound_group_session(stored_value, &self.serializer) } /// Transform a [`GossipRequest`] into a `JsValue` holding a @@ -326,6 +314,28 @@ impl IndexeddbCryptoStore { } } +fn deserialize_inbound_group_session( + stored_value: JsValue, + serializer: &IndexeddbSerializer, +) -> Result { + let idb_object: InboundGroupSessionIndexedDbObject = + serde_wasm_bindgen::from_value(stored_value)?; + let pickled_session = serializer.deserialize_value_from_bytes(&idb_object.pickled_session)?; + let session = InboundGroupSession::from_pickle(pickled_session) + .map_err(|e| IndexeddbCryptoStoreError::CryptoStoreError(e.into()))?; + + // Although a "backed up" flag is stored inside `idb_object.pickled_session`, it + // is not maintained when backups are reset. Overwrite the flag with the + // needs_backup value from the IDB object. + if idb_object.needs_backup { + session.reset_backup_state(); + } else { + session.mark_as_backed_up(); + } + + Ok(session) +} + // Small hack to have the following macro invocation act as the appropriate // trait impl block on wasm, but still be compiled on non-wasm as a regular // impl block otherwise. @@ -800,6 +810,31 @@ impl_crypto_store! { ).await } + async fn get_inbound_group_sessions_stream(&self) -> Result>> { + let db = self.inner.clone(); + let serializer = self.serializer.clone(); + + let stream = StreamByRenewedCursor::new( + db, + |db| db.transaction_on_one_with_mode( + keys::INBOUND_GROUP_SESSIONS_V2, + IdbTransactionMode::Readonly, + ), + keys::INBOUND_GROUP_SESSIONS_V2.to_owned(), + // SAFETY: `unwrap` is safe because 100 isn't zero. + NonZeroUsize::new(100).unwrap(), + ) + .await? + .map(move |item: Result<(JsValue, JsValue), DomException>| -> store::Result { + let item: (JsValue, JsValue) = item.map_err(IndexeddbStateStoreError::from).map_err(store::CryptoStoreError::backend)?; + let (_key, value) = item; + + Ok(deserialize_inbound_group_session(value, &serializer)?) + }); + + Ok(StreamOf::new(Box::pin(stream))) + } + async fn inbound_group_session_counts(&self) -> Result { let tx = self .inner diff --git a/crates/matrix-sdk-indexeddb/src/stream.rs b/crates/matrix-sdk-indexeddb/src/stream.rs index 9de67e65988..b954fe14aa5 100644 --- a/crates/matrix-sdk-indexeddb/src/stream.rs +++ b/crates/matrix-sdk-indexeddb/src/stream.rs @@ -7,6 +7,7 @@ use std::{ num::NonZeroUsize, pin::Pin, ptr::NonNull, + sync::Arc, task::{ready, Context, Poll}, }; @@ -169,8 +170,8 @@ pin_project! { /// /// Then. What is renewed? The [`IdbTransaction`]. How? With a _transaction /// builder_: a closure than generates an [`IdbTransaction`]. Why a range? - /// Because when a transaction is renewed, the cursor must be re-positioned to the - /// previous position. + /// Because when a transaction is renewed, the cursor must be re-positioned to + /// the previous position. /// /// Such async iterator is helpful when reading a lot of data from an IndexedDB /// object store. With [`StreamExt`], one can easily map the results from the @@ -243,22 +244,39 @@ pin_project! { // when an [`IdbTransaction`] is renwed. latest_key: JsValue, - // Explanations regarding the next 4 fields + // The inner stream, aka [`StreamByCursor`]. + #[pin] + inner_stream: Option>, + + // When asking for a new cursor, the [`IdbObjectStore::open_cursor_with_range`] + // method has to be called. It's an async method. Thus, in the `Stream` + // implementation of `Self`, this future must be stored to be polled manually. + cursor_future: Option>>>>, + + // Explanations regarding the next 6 fields // // The types of [`indexedb_db_futures`] are difficult to store because there is a global - // lifetime across all the types. To get an `IdbObjectStore`, we need an `IdbTransaction`. If - // we store `IdbObjectStore` alone, the `IdbTransaction` will be dropped, thus not having a - // long enough lifetime. On the opposite, if we store `IdbTransaction`, the reference is moved - // inside the `IdbObjectStore`, which forbids to move the owned `IdbTransaction`. To solve - // that, we use self-referential fields, one for the owned value, one for the reference for the - // owned value. It's not ideal, but at least it works! + // lifetime across all the types to the database. To get an `IdbObjectStore`, we need an + // `IdbTransaction`. If we store `IdbObjectStore` alone, the `IdbTransaction` will be + // dropped, thus not having a long enough lifetime. On the opposite, if we store + // `IdbTransaction`, the reference is moved inside the `IdbObjectStore`, which forbids to + // move the owned `IdbTransaction`. To solve that, we use self-referential fields, one for + // the owned value, one for the reference for the owned value. It's not ideal, but at least + // it works! + + // The database that will be passed to the `transaction_builder`. + #[pin] + database: Arc, + + // The self-reference to `Self::database`. + database_ptr: NonNull>, // The `latest_transaction`. #[pin] - latest_transaction: IdbTransaction<'a>, + latest_transaction: Option>, // The self-reference to `Self::latest_transaction`. - latest_transaction_ptr: NonNull>, + latest_transaction_ptr: NonNull>>, // The `latest_object_store`. #[pin] @@ -267,15 +285,6 @@ pin_project! { // The self-reference to `Self::latest_object_store`. latest_object_store_ptr: NonNull>>, - // The inner stream, aka [`StreamByCursor`]. - #[pin] - inner_stream: Option>, - - // When asking for a new cursor, the [`IdbObjectStore::open_cursor_with_range`] - // method has to be called. It's an async method. Thus, in the `Stream` - // implementation of `Self`, this future must be stored to be polled manually. - cursor_future: Option>>>>, - // The entire struct must be unmovable. Let's use a `PhantomPinned` so that it cannot implement // `Unpin`. _pin: PhantomPinned, @@ -284,18 +293,18 @@ pin_project! { impl<'a, F> StreamByRenewedCursor<'a, F> where - F: FnMut() -> Result, DomException>, + F: FnMut(&'a IdbDatabase) -> Result, DomException>, { /// Build a new `StreamByRenewdCursor`. /// /// It takes a `transaction_builder`, an `object_store_name`, and a /// `renew_every`. See the documentation of [`Self`] to learn more. pub async fn new( - mut transaction_builder: F, + database: Arc, + transaction_builder: F, object_store_name: String, renew_every: NonZeroUsize, ) -> Result>, DomException> { - let transaction = transaction_builder()?; let latest_key = JsValue::from_str(""); let after_latest_key = IdbKeyRange::lower_bound_with_open(&latest_key, true)?; @@ -305,12 +314,14 @@ where renew_every: renew_every.into(), renew_in: renew_every.into(), latest_key, - latest_transaction: transaction, + inner_stream: None, + cursor_future: None, + database, + database_ptr: NonNull::dangling(), + latest_transaction: None, latest_transaction_ptr: NonNull::dangling(), latest_object_store: None, latest_object_store_ptr: NonNull::dangling(), - inner_stream: None, - cursor_future: None, _pin: PhantomPinned, }; let mut this = Box::pin(this); @@ -318,10 +329,17 @@ where unsafe { let this = Pin::get_unchecked_mut(Pin::as_mut(&mut this)); + this.database_ptr = NonNull::from(&this.database); + + let transaction = (this.transaction_builder)(this.database_ptr.as_ref())?; + this.latest_transaction = Some(transaction); + this.latest_transaction_ptr = NonNull::from(&this.latest_transaction); this.latest_object_store = Some( this.latest_transaction_ptr .as_ref() + .as_ref() + .unwrap() .object_store(this.object_store_name.as_str())?, ); this.latest_object_store_ptr = NonNull::from(&this.latest_object_store); @@ -344,7 +362,7 @@ where impl<'a, F> Stream for StreamByRenewedCursor<'a, F> where - F: FnMut() -> Result, DomException>, + F: FnMut(&'a IdbDatabase) -> Result, DomException>, { type Item = Result<(JsValue, JsValue), DomException>; @@ -357,11 +375,13 @@ where // If it's not defined, let's build one. if this.cursor_future.is_none() { // Get and save the new `transaction`. - let transaction = (this.transaction_builder)()?; - this.latest_transaction.set(transaction); + let transaction = + (this.transaction_builder)(unsafe { this.database_ptr.as_ref() })?; + this.latest_transaction.set(Some(transaction)); // Get and asve the new `object_store`. - let object_store = unsafe { this.latest_transaction_ptr.as_ref() } + let object_store = unsafe { this.latest_transaction_ptr.as_ref().as_ref() } + .unwrap() .object_store(this.object_store_name.as_str())?; this.latest_object_store.set(Some(object_store)); @@ -525,7 +545,7 @@ mod tests { Ok(()) })); - let db = db.await?; + let db = Arc::new(db.await?); { let transaction = @@ -555,7 +575,8 @@ mod tests { { let mut number_of_renews = 0; let mut stream = StreamByRenewedCursor::new( - || { + db.clone(), + |db| { number_of_renews += 1; Ok(db.transaction_on_one("baz")?) }, @@ -594,7 +615,8 @@ mod tests { { let mut number_of_renews = 0; let stream = StreamByRenewedCursor::new( - || { + db.clone(), + |db| { number_of_renews += 1; Ok(db.transaction_on_one("baz")?) }, diff --git a/crates/matrix-sdk-sqlite/Cargo.toml b/crates/matrix-sdk-sqlite/Cargo.toml index eede3850ee4..6c8eb1d7ca7 100644 --- a/crates/matrix-sdk-sqlite/Cargo.toml +++ b/crates/matrix-sdk-sqlite/Cargo.toml @@ -16,6 +16,7 @@ crypto-store = ["dep:matrix-sdk-crypto"] state-store = ["dep:matrix-sdk-base"] [dependencies] +async-stream = { workspace = true } async-trait = { workspace = true } deadpool-sqlite = "0.7.0" itertools = { workspace = true } @@ -34,6 +35,7 @@ vodozemac = { workspace = true } [dev-dependencies] assert_matches = { workspace = true } +futures-util = { workspace = true } glob = "0.3.0" matrix-sdk-base = { workspace = true, features = ["testing"] } matrix-sdk-crypto = { workspace = true, features = ["testing"] } diff --git a/crates/matrix-sdk-sqlite/src/crypto_store.rs b/crates/matrix-sdk-sqlite/src/crypto_store.rs index 522f2acbb00..2a91c41269d 100644 --- a/crates/matrix-sdk-sqlite/src/crypto_store.rs +++ b/crates/matrix-sdk-sqlite/src/crypto_store.rs @@ -20,6 +20,7 @@ use std::{ sync::{Arc, RwLock}, }; +use async_stream::stream; use async_trait::async_trait; use deadpool_sqlite::{Object as SqliteConn, Pool as SqlitePool, Runtime}; use matrix_sdk_crypto::{ @@ -28,8 +29,8 @@ use matrix_sdk_crypto::{ PrivateCrossSigningIdentity, Session, StaticAccountData, }, store::{ - caches::SessionStore, BackupKeys, Changes, CryptoStore, PendingChanges, RoomKeyCounts, - RoomSettings, + self, caches::SessionStore, BackupKeys, Changes, CryptoStore, PendingChanges, + RoomKeyCounts, RoomSettings, StreamOf, }, types::events::room_key_withheld::RoomKeyWithheldEvent, Account, GossipRequest, GossippedSecret, ReadOnlyDevice, ReadOnlyUserIdentities, SecretInfo, @@ -940,6 +941,21 @@ impl CryptoStore for SqliteCryptoStore { .collect() } + // For now, this method transforms `Self::get_inbound_group_sessions`'s result + // into a `Stream`. That's really not useful. Ideally, we want to iterate the + // SQLite table progressively, with a cursor or something like this. + async fn get_inbound_group_sessions_stream( + &self, + ) -> Result>> { + let inbound_group_sessions = self.get_inbound_group_sessions().await?; + + Ok(StreamOf::new(Box::pin(stream! { + for item in inbound_group_sessions { + yield Ok(item); + } + }))) + } + async fn inbound_group_session_counts(&self) -> Result { Ok(self.acquire().await?.get_inbound_group_session_counts().await?) }