From b47a4a179a13ca9af6d7318a46ebecba98b5651e Mon Sep 17 00:00:00 2001 From: Ivan Enderlin Date: Fri, 26 Jan 2024 17:10:01 +0100 Subject: [PATCH] !foo --- crates/matrix-sdk-crypto/Cargo.toml | 1 + .../src/store/memorystore.rs | 17 ++++ crates/matrix-sdk-crypto/src/store/mod.rs | 2 +- crates/matrix-sdk-crypto/src/store/traits.rs | 47 +++++++++- .../src/crypto_store/mod.rs | 88 +++++++++++++------ crates/matrix-sdk-indexeddb/src/stream.rs | 45 +++++++--- crates/matrix-sdk-sqlite/Cargo.toml | 1 + crates/matrix-sdk-sqlite/src/crypto_store.rs | 17 +++- 8 files changed, 174 insertions(+), 44 deletions(-) diff --git a/crates/matrix-sdk-crypto/Cargo.toml b/crates/matrix-sdk-crypto/Cargo.toml index 1257eda9278..7cbd34571d2 100644 --- a/crates/matrix-sdk-crypto/Cargo.toml +++ b/crates/matrix-sdk-crypto/Cargo.toml @@ -28,6 +28,7 @@ testing = ["dep:http"] [dependencies] aes = "0.8.1" as_variant = { workspace = true } +async-stream = { workspace = true } async-trait = { workspace = true } bs58 = { version = "0.5.0" } byteorder = { workspace = true } diff --git a/crates/matrix-sdk-crypto/src/store/memorystore.rs b/crates/matrix-sdk-crypto/src/store/memorystore.rs index f75b2155f5e..fd0c1ecc32e 100644 --- a/crates/matrix-sdk-crypto/src/store/memorystore.rs +++ b/crates/matrix-sdk-crypto/src/store/memorystore.rs @@ -19,6 +19,7 @@ use std::{ time::{Duration, Instant}, }; +use async_stream::stream; use async_trait::async_trait; use ruma::{ events::secret::request::SecretName, DeviceId, OwnedDeviceId, OwnedRoomId, OwnedTransactionId, @@ -28,6 +29,7 @@ use tokio::sync::{Mutex, RwLock}; use super::{ caches::{DeviceStore, GroupSessionStore, SessionStore}, + traits::StreamOf, Account, BackupKeys, Changes, CryptoStore, InboundGroupSession, PendingChanges, RoomKeyCounts, RoomSettings, Session, }; @@ -251,6 +253,21 @@ impl CryptoStore for MemoryStore { Ok(self.inbound_group_sessions.get_all()) } + async fn get_inbound_group_sessions_stream( + &self, + ) -> Result>> { + // There is no stream API for this `MemoryStore`. Let's simply consume the `Vec` + // from `get_inbound_group_sessions` as a stream. It's not ideal, but it's + // OK-ish for now. + let inbound_group_sessions = self.inbound_group_sessions.get_all(); + + Ok(StreamOf::new(Box::pin(stream! { + for item in inbound_group_sessions { + yield Ok(item); + } + }))) + } + async fn inbound_group_session_counts(&self) -> Result { let backed_up = self.get_inbound_group_sessions().await?.into_iter().filter(|s| s.backed_up()).count(); diff --git a/crates/matrix-sdk-crypto/src/store/mod.rs b/crates/matrix-sdk-crypto/src/store/mod.rs index 21aa405060c..cfaf05c4892 100644 --- a/crates/matrix-sdk-crypto/src/store/mod.rs +++ b/crates/matrix-sdk-crypto/src/store/mod.rs @@ -90,7 +90,7 @@ pub(crate) use crypto_store_wrapper::CryptoStoreWrapper; pub use error::{CryptoStoreError, Result}; use matrix_sdk_common::{store_locks::CrossProcessStoreLock, timeout::timeout}; pub use memorystore::MemoryStore; -pub use traits::{CryptoStore, DynCryptoStore, IntoCryptoStore}; +pub use traits::{CryptoStore, DynCryptoStore, IntoCryptoStore, StreamOf}; pub use crate::gossiping::{GossipRequest, SecretInfo}; diff --git a/crates/matrix-sdk-crypto/src/store/traits.rs b/crates/matrix-sdk-crypto/src/store/traits.rs index bebcdd03bc4..44d7f37f2e5 100644 --- a/crates/matrix-sdk-crypto/src/store/traits.rs +++ b/crates/matrix-sdk-crypto/src/store/traits.rs @@ -12,9 +12,16 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::{collections::HashMap, fmt, sync::Arc}; +use std::{ + collections::HashMap, + fmt, + pin::Pin, + sync::Arc, + task::{Context, Poll}, +}; use async_trait::async_trait; +use futures_core::Stream; use matrix_sdk_common::AsyncTraitDeps; use ruma::{ events::secret::request::SecretName, DeviceId, OwnedDeviceId, RoomId, TransactionId, UserId, @@ -103,6 +110,11 @@ pub trait CryptoStore: AsyncTraitDeps { /// Get all the inbound group sessions we have stored. async fn get_inbound_group_sessions(&self) -> Result, Self::Error>; + /// Get all the inbound group session we have stored, as a `Stream`. + async fn get_inbound_group_sessions_stream( + &self, + ) -> Result>, Self::Error>; + /// Get the number inbound group sessions we have and how many of them are /// backed up. async fn inbound_group_session_counts(&self) -> Result; @@ -328,6 +340,12 @@ impl CryptoStore for EraseCryptoStoreError { self.0.get_inbound_group_sessions().await.map_err(Into::into) } + async fn get_inbound_group_sessions_stream( + &self, + ) -> Result>> { + self.0.get_inbound_group_sessions_stream().await.map_err(Into::into) + } + async fn inbound_group_session_counts(&self) -> Result { self.0.inbound_group_session_counts().await.map_err(Into::into) } @@ -508,3 +526,30 @@ impl IntoCryptoStore for Arc { self } } + +/// A concrete type wrapping a `Pin>>`. +/// +/// It is used only to make the [`CryptoStore`] trait object-safe. Please don't +/// use it for other things. +pub struct StreamOf(Pin>>); + +impl StreamOf { + /// Create a new `Self`. + pub fn new(stream: Pin>>) -> Self { + Self(stream) + } +} + +impl fmt::Debug for StreamOf { + fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { + formatter.debug_tuple("StreamOf").finish() + } +} + +impl Stream for StreamOf { + type Item = T; + + fn poll_next(mut self: Pin<&mut Self>, context: &mut Context<'_>) -> Poll> { + self.0.as_mut().poll_next(context) + } +} diff --git a/crates/matrix-sdk-indexeddb/src/crypto_store/mod.rs b/crates/matrix-sdk-indexeddb/src/crypto_store/mod.rs index 33b1b7ef625..34a1b47a97e 100644 --- a/crates/matrix-sdk-indexeddb/src/crypto_store/mod.rs +++ b/crates/matrix-sdk-indexeddb/src/crypto_store/mod.rs @@ -14,10 +14,12 @@ use std::{ collections::HashMap, + num::NonZeroUsize, sync::{Arc, RwLock}, }; use async_trait::async_trait; +use futures_util::StreamExt; use gloo_utils::format::JsValueSerdeExt; use indexed_db_futures::prelude::*; use matrix_sdk_crypto::{ @@ -26,8 +28,8 @@ use matrix_sdk_crypto::{ Session, StaticAccountData, }, store::{ - caches::SessionStore, BackupKeys, Changes, CryptoStore, CryptoStoreError, PendingChanges, - RoomKeyCounts, RoomSettings, + self, caches::SessionStore, BackupKeys, Changes, CryptoStore, CryptoStoreError, + PendingChanges, RoomKeyCounts, RoomSettings, StreamOf, }, types::events::room_key_withheld::RoomKeyWithheldEvent, Account, GossipRequest, GossippedSecret, ReadOnlyDevice, ReadOnlyUserIdentities, SecretInfo, @@ -43,8 +45,9 @@ use tracing::{debug, warn}; use wasm_bindgen::JsValue; use web_sys::{DomException, IdbKeyRange}; -use crate::crypto_store::{ - indexeddb_serializer::IndexeddbSerializer, migrations::open_and_upgrade_db, +use crate::{ + crypto_store::{indexeddb_serializer::IndexeddbSerializer, migrations::open_and_upgrade_db}, + stream::StreamByRenewedCursor, }; mod indexeddb_serializer; @@ -94,16 +97,16 @@ mod keys { pub const RECOVERY_KEY_V1: &str = "recovery_key_v1"; } -/// An implementation of [CryptoStore] that uses [IndexedDB] for persistent +/// An implementation of [`CryptoStore`] that uses [`IndexedDB`] for persistent /// storage. /// -/// [IndexedDB]: https://developer.mozilla.org/en-US/docs/Web/API/IndexedDB_API +/// [`IndexedDB`]: https://developer.mozilla.org/en-US/docs/Web/API/IndexedDB_API pub struct IndexeddbCryptoStore { static_account: RwLock>, name: String, - pub(crate) inner: IdbDatabase, + pub(crate) inner: Arc, - serializer: IndexeddbSerializer, + serializer: Arc, session_cache: SessionStore, save_changes_lock: Arc>, } @@ -175,8 +178,8 @@ impl IndexeddbCryptoStore { Ok(Self { name, session_cache, - inner: db, - serializer, + inner: Arc::new(db), + serializer: Arc::new(serializer), static_account: RwLock::new(None), save_changes_lock: Default::default(), }) @@ -280,23 +283,7 @@ impl IndexeddbCryptoStore { &self, stored_value: JsValue, ) -> Result { - let idb_object: InboundGroupSessionIndexedDbObject = - serde_wasm_bindgen::from_value(stored_value)?; - let pickled_session = - self.serializer.deserialize_value_from_bytes(&idb_object.pickled_session)?; - let session = InboundGroupSession::from_pickle(pickled_session) - .map_err(|e| IndexeddbCryptoStoreError::CryptoStoreError(e.into()))?; - - // Although a "backed up" flag is stored inside `idb_object.pickled_session`, it - // is not maintained when backups are reset. Overwrite the flag with the - // needs_backup value from the IDB object. - if idb_object.needs_backup { - session.reset_backup_state(); - } else { - session.mark_as_backed_up(); - } - - Ok(session) + deserialize_inbound_group_session(stored_value, &self.serializer) } /// Transform a [`GossipRequest`] into a `JsValue` holding a @@ -326,6 +313,28 @@ impl IndexeddbCryptoStore { } } +fn deserialize_inbound_group_session( + stored_value: JsValue, + serializer: &IndexeddbSerializer, +) -> Result { + let idb_object: InboundGroupSessionIndexedDbObject = + serde_wasm_bindgen::from_value(stored_value)?; + let pickled_session = serializer.deserialize_value_from_bytes(&idb_object.pickled_session)?; + let session = InboundGroupSession::from_pickle(pickled_session) + .map_err(|e| IndexeddbCryptoStoreError::CryptoStoreError(e.into()))?; + + // Although a "backed up" flag is stored inside `idb_object.pickled_session`, it + // is not maintained when backups are reset. Overwrite the flag with the + // needs_backup value from the IDB object. + if idb_object.needs_backup { + session.reset_backup_state(); + } else { + session.mark_as_backed_up(); + } + + Ok(session) +} + // Small hack to have the following macro invocation act as the appropriate // trait impl block on wasm, but still be compiled on non-wasm as a regular // impl block otherwise. @@ -800,6 +809,31 @@ impl_crypto_store! { ).await } + async fn get_inbound_group_sessions_stream(&self) -> Result>> { + let db = self.inner.clone(); + let serializer = self.serializer.clone(); + + let stream = StreamByRenewedCursor::new( + db, + |db| db.transaction_on_one_with_mode( + keys::INBOUND_GROUP_SESSIONS_V2, + IdbTransactionMode::Readonly, + ), + keys::INBOUND_GROUP_SESSIONS_V2.to_owned(), + // SAFETY: `unwrap` is safe because 100 isn't zero. + NonZeroUsize::new(100).unwrap(), + ) + .await? + .map(move |item: Result<(JsValue, JsValue), DomException>| -> store::Result { + let item: (JsValue, JsValue) = item.unwrap() /* FIX ME */; + let (_key, value) = item; + + Ok(deserialize_inbound_group_session(value, serializer.as_ref())?) + }); + + Ok(StreamOf::new(Box::pin(stream))) + } + async fn inbound_group_session_counts(&self) -> Result { let tx = self .inner diff --git a/crates/matrix-sdk-indexeddb/src/stream.rs b/crates/matrix-sdk-indexeddb/src/stream.rs index 9de67e65988..2b063e28fa8 100644 --- a/crates/matrix-sdk-indexeddb/src/stream.rs +++ b/crates/matrix-sdk-indexeddb/src/stream.rs @@ -7,6 +7,7 @@ use std::{ num::NonZeroUsize, pin::Pin, ptr::NonNull, + sync::Arc, task::{ready, Context, Poll}, }; @@ -224,6 +225,11 @@ pin_project! { /// } /// ``` pub struct StreamByRenewedCursor<'a, F> { + // The database that will be passed to the `transaction_builder`. + #[pin] + database: Arc, + database_ptr: NonNull>, + // The closure that is used to generate a new [`IdbTransaction`]. transaction_builder: F, @@ -255,10 +261,10 @@ pin_project! { // The `latest_transaction`. #[pin] - latest_transaction: IdbTransaction<'a>, + latest_transaction: Option>, // The self-reference to `Self::latest_transaction`. - latest_transaction_ptr: NonNull>, + latest_transaction_ptr: NonNull>>, // The `latest_object_store`. #[pin] @@ -284,28 +290,30 @@ pin_project! { impl<'a, F> StreamByRenewedCursor<'a, F> where - F: FnMut() -> Result, DomException>, + F: FnMut(&'a IdbDatabase) -> Result, DomException>, { /// Build a new `StreamByRenewdCursor`. /// /// It takes a `transaction_builder`, an `object_store_name`, and a /// `renew_every`. See the documentation of [`Self`] to learn more. pub async fn new( - mut transaction_builder: F, + database: Arc, + transaction_builder: F, object_store_name: String, renew_every: NonZeroUsize, ) -> Result>, DomException> { - let transaction = transaction_builder()?; let latest_key = JsValue::from_str(""); let after_latest_key = IdbKeyRange::lower_bound_with_open(&latest_key, true)?; let this = Self { + database, + database_ptr: NonNull::dangling(), transaction_builder, object_store_name, renew_every: renew_every.into(), renew_in: renew_every.into(), latest_key, - latest_transaction: transaction, + latest_transaction: None, latest_transaction_ptr: NonNull::dangling(), latest_object_store: None, latest_object_store_ptr: NonNull::dangling(), @@ -318,10 +326,17 @@ where unsafe { let this = Pin::get_unchecked_mut(Pin::as_mut(&mut this)); + this.database_ptr = NonNull::from(&this.database); + + let transaction = (this.transaction_builder)(this.database_ptr.as_ref())?; + this.latest_transaction = Some(transaction); + this.latest_transaction_ptr = NonNull::from(&this.latest_transaction); this.latest_object_store = Some( this.latest_transaction_ptr .as_ref() + .as_ref() + .unwrap() .object_store(this.object_store_name.as_str())?, ); this.latest_object_store_ptr = NonNull::from(&this.latest_object_store); @@ -344,7 +359,7 @@ where impl<'a, F> Stream for StreamByRenewedCursor<'a, F> where - F: FnMut() -> Result, DomException>, + F: FnMut(&'a IdbDatabase) -> Result, DomException>, { type Item = Result<(JsValue, JsValue), DomException>; @@ -357,11 +372,13 @@ where // If it's not defined, let's build one. if this.cursor_future.is_none() { // Get and save the new `transaction`. - let transaction = (this.transaction_builder)()?; - this.latest_transaction.set(transaction); + let transaction = + (this.transaction_builder)(unsafe { this.database_ptr.as_ref() })?; + this.latest_transaction.set(Some(transaction)); // Get and asve the new `object_store`. - let object_store = unsafe { this.latest_transaction_ptr.as_ref() } + let object_store = unsafe { this.latest_transaction_ptr.as_ref().as_ref() } + .unwrap() .object_store(this.object_store_name.as_str())?; this.latest_object_store.set(Some(object_store)); @@ -525,7 +542,7 @@ mod tests { Ok(()) })); - let db = db.await?; + let db = Arc::new(db.await?); { let transaction = @@ -555,7 +572,8 @@ mod tests { { let mut number_of_renews = 0; let mut stream = StreamByRenewedCursor::new( - || { + db.clone(), + |db| { number_of_renews += 1; Ok(db.transaction_on_one("baz")?) }, @@ -594,7 +612,8 @@ mod tests { { let mut number_of_renews = 0; let stream = StreamByRenewedCursor::new( - || { + db.clone(), + |db| { number_of_renews += 1; Ok(db.transaction_on_one("baz")?) }, diff --git a/crates/matrix-sdk-sqlite/Cargo.toml b/crates/matrix-sdk-sqlite/Cargo.toml index eede3850ee4..49d07e2afb5 100644 --- a/crates/matrix-sdk-sqlite/Cargo.toml +++ b/crates/matrix-sdk-sqlite/Cargo.toml @@ -16,6 +16,7 @@ crypto-store = ["dep:matrix-sdk-crypto"] state-store = ["dep:matrix-sdk-base"] [dependencies] +async-stream = { workspace = true } async-trait = { workspace = true } deadpool-sqlite = "0.7.0" itertools = { workspace = true } diff --git a/crates/matrix-sdk-sqlite/src/crypto_store.rs b/crates/matrix-sdk-sqlite/src/crypto_store.rs index 522f2acbb00..a382fcf77e1 100644 --- a/crates/matrix-sdk-sqlite/src/crypto_store.rs +++ b/crates/matrix-sdk-sqlite/src/crypto_store.rs @@ -20,6 +20,7 @@ use std::{ sync::{Arc, RwLock}, }; +use async_stream::stream; use async_trait::async_trait; use deadpool_sqlite::{Object as SqliteConn, Pool as SqlitePool, Runtime}; use matrix_sdk_crypto::{ @@ -28,8 +29,8 @@ use matrix_sdk_crypto::{ PrivateCrossSigningIdentity, Session, StaticAccountData, }, store::{ - caches::SessionStore, BackupKeys, Changes, CryptoStore, PendingChanges, RoomKeyCounts, - RoomSettings, + self, caches::SessionStore, BackupKeys, Changes, CryptoStore, PendingChanges, + RoomKeyCounts, RoomSettings, StreamOf, }, types::events::room_key_withheld::RoomKeyWithheldEvent, Account, GossipRequest, GossippedSecret, ReadOnlyDevice, ReadOnlyUserIdentities, SecretInfo, @@ -940,6 +941,18 @@ impl CryptoStore for SqliteCryptoStore { .collect() } + async fn get_inbound_group_sessions_stream( + &self, + ) -> Result>> { + let inbound_group_sessions = self.get_inbound_group_sessions().await?; + + Ok(StreamOf::new(Box::pin(stream! { + for item in inbound_group_sessions { + yield Ok(item); + } + }))) + } + async fn inbound_group_session_counts(&self) -> Result { Ok(self.acquire().await?.get_inbound_group_session_counts().await?) }