From 2a5a7adc9a882893b96f25fdde9d5c2ef1a93161 Mon Sep 17 00:00:00 2001 From: Tom Leavy Date: Sun, 25 Feb 2024 00:44:38 -0500 Subject: [PATCH 1/5] [uniffi] Add support for custom GroupStateStorage interface --- mls-rs-uniffi/src/config.rs | 43 +++++++++ mls-rs-uniffi/src/config/group_state.rs | 123 ++++++++++++++++++++++++ mls-rs-uniffi/src/lib.rs | 40 ++++---- 3 files changed, 187 insertions(+), 19 deletions(-) create mode 100644 mls-rs-uniffi/src/config.rs create mode 100644 mls-rs-uniffi/src/config/group_state.rs diff --git a/mls-rs-uniffi/src/config.rs b/mls-rs-uniffi/src/config.rs new file mode 100644 index 00000000..953317a3 --- /dev/null +++ b/mls-rs-uniffi/src/config.rs @@ -0,0 +1,43 @@ +use std::{fmt::Debug, sync::Arc}; + +use mls_rs::{ + client_builder::{self, WithGroupStateStorage}, + identity::basic, +}; +use mls_rs_core::error::IntoAnyError; +use mls_rs_crypto_openssl::OpensslCryptoProvider; + +use self::group_state::GroupStateStorageWrapper; + +mod group_state; + +#[derive(Debug, thiserror::Error, uniffi::Error)] +#[uniffi(flat_error)] +#[non_exhaustive] +pub enum FFICallbackError { + #[error("data preparation error")] + DataPreparationError { + #[from] + inner: mls_rs_core::mls_rs_codec::Error, + }, + #[error("unexpected callback error")] + UnexpectedCallbackError { + #[from] + inner: uniffi::UnexpectedUniFFICallbackError, + }, +} + +impl IntoAnyError for FFICallbackError {} + +pub type UniFFIConfig = client_builder::WithIdentityProvider< + basic::BasicIdentityProvider, + client_builder::WithCryptoProvider< + OpensslCryptoProvider, + WithGroupStateStorage, + >, +>; + +#[derive(Debug, Clone, uniffi::Record)] +pub struct ClientConfig { + pub group_state_storage: Arc, +} diff --git a/mls-rs-uniffi/src/config/group_state.rs b/mls-rs-uniffi/src/config/group_state.rs new file mode 100644 index 00000000..322610a8 --- /dev/null +++ b/mls-rs-uniffi/src/config/group_state.rs @@ -0,0 +1,123 @@ +use std::{fmt::Debug, sync::Arc}; + +use mls_rs_core::mls_rs_codec::{MlsDecode, MlsEncode}; + +use super::FFICallbackError; + +#[derive(Clone, Debug, uniffi::Object)] +pub struct GroupState { + pub id: Vec, + pub data: Vec, +} + +impl mls_rs_core::group::GroupState for GroupState { + fn id(&self) -> Vec { + self.id.clone() + } +} + +#[derive(Clone, Debug, uniffi::Object)] +pub struct EpochRecord { + pub id: u64, + pub data: Vec, +} + +impl mls_rs_core::group::EpochRecord for EpochRecord { + fn id(&self) -> u64 { + self.id + } +} + +#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] +#[cfg_attr(mls_build_async, maybe_async::must_be_async)] +#[uniffi::export] +pub trait GroupStateStorage: Send + Sync + Debug { + async fn state(&self, group_id: &[u8]) -> Result>, FFICallbackError>; + async fn epoch( + &self, + group_id: &[u8], + epoch_id: u64, + ) -> Result>, FFICallbackError>; + + async fn write( + &self, + state: Arc, + epoch_inserts: Vec>, + epoch_updates: Vec>, + ) -> Result<(), FFICallbackError>; + + async fn max_epoch_id(&self, group_id: &[u8]) -> Result, FFICallbackError>; +} + +#[derive(Debug, Clone)] +pub(crate) struct GroupStateStorageWrapper(Arc); + +impl From> for GroupStateStorageWrapper { + fn from(value: Arc) -> Self { + Self(value) + } +} + +#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] +#[cfg_attr(mls_build_async, maybe_async::must_be_async)] +impl mls_rs_core::group::GroupStateStorage for GroupStateStorageWrapper { + type Error = FFICallbackError; + + async fn state(&self, group_id: &[u8]) -> Result, Self::Error> + where + T: mls_rs_core::group::GroupState + MlsEncode + MlsDecode, + { + let state_data = self.0.state(group_id)?; + + state_data + .as_deref() + .map(|v| T::mls_decode(&mut &*v)) + .transpose() + .map_err(Into::into) + } + + async fn epoch(&self, group_id: &[u8], epoch_id: u64) -> Result, Self::Error> + where + T: mls_rs_core::group::EpochRecord + MlsEncode + MlsDecode, + { + let epoch_data = self.0.epoch(group_id, epoch_id)?; + + epoch_data + .as_deref() + .map(|v| T::mls_decode(&mut &*v)) + .transpose() + .map_err(Into::into) + } + + async fn write( + &mut self, + state: ST, + epoch_inserts: Vec, + epoch_updates: Vec, + ) -> Result<(), Self::Error> + where + ST: mls_rs_core::group::GroupState + MlsEncode + MlsDecode + Send + Sync, + ET: mls_rs_core::group::EpochRecord + MlsEncode + MlsDecode + Send + Sync, + { + let state = Arc::new(GroupState { + id: state.id(), + data: state.mls_encode_to_vec().unwrap(), + }); + + let epoch_to_record = |v: ET| { + Arc::new(EpochRecord { + id: v.id(), + data: v.mls_encode_to_vec().unwrap(), + }) + }; + + let inserts = epoch_inserts.into_iter().map(epoch_to_record).collect(); + let updates = epoch_updates.into_iter().map(epoch_to_record).collect(); + + self.0.write(state, inserts, updates) + } + + async fn max_epoch_id(&self, group_id: &[u8]) -> Result, Self::Error> { + self.0.max_epoch_id(group_id) + } +} diff --git a/mls-rs-uniffi/src/lib.rs b/mls-rs-uniffi/src/lib.rs index 086eb77b..1d364127 100644 --- a/mls-rs-uniffi/src/lib.rs +++ b/mls-rs-uniffi/src/lib.rs @@ -19,15 +19,16 @@ #[cfg(test)] pub mod test_utils; +mod config; use std::sync::Arc; +use config::{ClientConfig, UniFFIConfig}; #[cfg(not(mls_build_async))] use std::sync::Mutex; #[cfg(mls_build_async)] use tokio::sync::Mutex; -use mls_rs::client_builder; use mls_rs::error::{IntoAnyError, MlsError}; use mls_rs::group; use mls_rs::identity::basic; @@ -96,11 +97,6 @@ pub struct SignatureKeypair { secret_key: Arc, } -pub type Config = client_builder::WithIdentityProvider< - basic::BasicIdentityProvider, - client_builder::WithCryptoProvider, ->; - /// Light-weight wrapper around a [`mls_rs::ExtensionList`]. #[derive(uniffi::Object, Debug, Clone)] pub struct ExtensionList { @@ -247,7 +243,7 @@ pub async fn generate_signature_keypair( /// See [`mls_rs::Client`] for details. #[derive(Clone, Debug, uniffi::Object)] pub struct Client { - inner: mls_rs::client::Client, + inner: mls_rs::client::Client, } #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] @@ -260,7 +256,11 @@ impl Client { /// /// See [`mls_rs::Client::builder`] for details. #[uniffi::constructor] - pub fn new(id: Vec, signature_keypair: SignatureKeypair) -> Self { + pub fn new( + id: Vec, + signature_keypair: SignatureKeypair, + client_config: ClientConfig, + ) -> Self { let cipher_suite = signature_keypair.cipher_suite; let public_key = arc_unwrap_or_clone(signature_keypair.public_key); let secret_key = arc_unwrap_or_clone(signature_keypair.secret_key); @@ -268,13 +268,15 @@ impl Client { let basic_credential = BasicCredential::new(id); let signing_identity = identity::SigningIdentity::new(basic_credential.into_credential(), public_key.inner); - Client { - inner: mls_rs::Client::builder() - .crypto_provider(crypto_provider) - .identity_provider(basic::BasicIdentityProvider::new()) - .signing_identity(signing_identity, secret_key.inner, cipher_suite.into()) - .build(), - } + + let client = mls_rs::Client::builder() + .crypto_provider(crypto_provider) + .identity_provider(basic::BasicIdentityProvider::new()) + .signing_identity(signing_identity, secret_key.inner, cipher_suite.into()) + .group_state_storage(client_config.group_state_storage.into()) + .build(); + + Client { inner: client } } /// Generate a new key package for this client. @@ -379,25 +381,25 @@ impl From for SigningIdentity { /// See [`mls_rs::Group`] for details. #[derive(Clone, uniffi::Object)] pub struct Group { - inner: Arc>>, + inner: Arc>>, } #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] impl Group { #[cfg(not(mls_build_async))] - fn inner(&self) -> std::sync::MutexGuard<'_, mls_rs::Group> { + fn inner(&self) -> std::sync::MutexGuard<'_, mls_rs::Group> { self.inner.lock().unwrap() } #[cfg(mls_build_async)] - async fn inner(&self) -> tokio::sync::MutexGuard<'_, mls_rs::Group> { + async fn inner(&self) -> tokio::sync::MutexGuard<'_, mls_rs::Group> { self.inner.lock().await } } /// Find the identity for the member with a given index. fn index_to_identity( - group: &mls_rs::Group, + group: &mls_rs::Group, index: u32, ) -> Result { let member = group From b143d5a669494a400ced6c94208b274f5b34ad4e Mon Sep 17 00:00:00 2001 From: Tom Leavy Date: Mon, 26 Feb 2024 12:24:16 -0500 Subject: [PATCH 2/5] Allow external languages to implement callbacks --- mls-rs-uniffi/Cargo.toml | 1 + mls-rs-uniffi/src/config/group_state.rs | 14 +++++++------- mls-rs-uniffi/src/lib.rs | 2 +- 3 files changed, 9 insertions(+), 8 deletions(-) diff --git a/mls-rs-uniffi/Cargo.toml b/mls-rs-uniffi/Cargo.toml index 45005ad8..f1320637 100644 --- a/mls-rs-uniffi/Cargo.toml +++ b/mls-rs-uniffi/Cargo.toml @@ -15,6 +15,7 @@ crate-type = ["lib", "cdylib"] name = "mls_rs_uniffi" [dependencies] +async-trait = "0.1.77" maybe-async = "0.2.10" mls-rs = { path = "../mls-rs" } mls-rs-core = { path = "../mls-rs-core" } diff --git a/mls-rs-uniffi/src/config/group_state.rs b/mls-rs-uniffi/src/config/group_state.rs index 322610a8..8b69ecf2 100644 --- a/mls-rs-uniffi/src/config/group_state.rs +++ b/mls-rs-uniffi/src/config/group_state.rs @@ -30,12 +30,12 @@ impl mls_rs_core::group::EpochRecord for EpochRecord { #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] #[cfg_attr(mls_build_async, maybe_async::must_be_async)] -#[uniffi::export] +#[uniffi::export(with_foreign)] pub trait GroupStateStorage: Send + Sync + Debug { - async fn state(&self, group_id: &[u8]) -> Result>, FFICallbackError>; + async fn state(&self, group_id: Vec) -> Result>, FFICallbackError>; async fn epoch( &self, - group_id: &[u8], + group_id: Vec, epoch_id: u64, ) -> Result>, FFICallbackError>; @@ -46,7 +46,7 @@ pub trait GroupStateStorage: Send + Sync + Debug { epoch_updates: Vec>, ) -> Result<(), FFICallbackError>; - async fn max_epoch_id(&self, group_id: &[u8]) -> Result, FFICallbackError>; + async fn max_epoch_id(&self, group_id: Vec) -> Result, FFICallbackError>; } #[derive(Debug, Clone)] @@ -67,7 +67,7 @@ impl mls_rs_core::group::GroupStateStorage for GroupStateStorageWrapper { where T: mls_rs_core::group::GroupState + MlsEncode + MlsDecode, { - let state_data = self.0.state(group_id)?; + let state_data = self.0.state(group_id.to_vec())?; state_data .as_deref() @@ -80,7 +80,7 @@ impl mls_rs_core::group::GroupStateStorage for GroupStateStorageWrapper { where T: mls_rs_core::group::EpochRecord + MlsEncode + MlsDecode, { - let epoch_data = self.0.epoch(group_id, epoch_id)?; + let epoch_data = self.0.epoch(group_id.to_vec(), epoch_id)?; epoch_data .as_deref() @@ -118,6 +118,6 @@ impl mls_rs_core::group::GroupStateStorage for GroupStateStorageWrapper { } async fn max_epoch_id(&self, group_id: &[u8]) -> Result, Self::Error> { - self.0.max_epoch_id(group_id) + self.0.max_epoch_id(group_id.to_vec()) } } diff --git a/mls-rs-uniffi/src/lib.rs b/mls-rs-uniffi/src/lib.rs index 1d364127..ffeef99c 100644 --- a/mls-rs-uniffi/src/lib.rs +++ b/mls-rs-uniffi/src/lib.rs @@ -17,9 +17,9 @@ //! //! [UniFFI]: https://mozilla.github.io/uniffi-rs/ +mod config; #[cfg(test)] pub mod test_utils; -mod config; use std::sync::Arc; From cf32ac9992de2fee112a2580c732ef8a575a0fe5 Mon Sep 17 00:00:00 2001 From: Tom Leavy Date: Mon, 26 Feb 2024 14:22:00 -0500 Subject: [PATCH 3/5] Remove unwrap from group_state.rs --- mls-rs-uniffi/src/config/group_state.rs | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/mls-rs-uniffi/src/config/group_state.rs b/mls-rs-uniffi/src/config/group_state.rs index 8b69ecf2..44746a4c 100644 --- a/mls-rs-uniffi/src/config/group_state.rs +++ b/mls-rs-uniffi/src/config/group_state.rs @@ -101,18 +101,25 @@ impl mls_rs_core::group::GroupStateStorage for GroupStateStorageWrapper { { let state = Arc::new(GroupState { id: state.id(), - data: state.mls_encode_to_vec().unwrap(), + data: state.mls_encode_to_vec()?, }); - let epoch_to_record = |v: ET| { - Arc::new(EpochRecord { + let epoch_to_record = |v: ET| -> Result<_, Self::Error> { + Ok(Arc::new(EpochRecord { id: v.id(), - data: v.mls_encode_to_vec().unwrap(), - }) + data: v.mls_encode_to_vec()?, + })) }; - let inserts = epoch_inserts.into_iter().map(epoch_to_record).collect(); - let updates = epoch_updates.into_iter().map(epoch_to_record).collect(); + let inserts = epoch_inserts + .into_iter() + .map(epoch_to_record) + .collect::, _>>()?; + + let updates = epoch_updates + .into_iter() + .map(epoch_to_record) + .collect::, _>>()?; self.0.write(state, inserts, updates) } From 65a090fdfab9da5a0246a310edd0fcd06c7c3779 Mon Sep 17 00:00:00 2001 From: Tom Leavy Date: Mon, 26 Feb 2024 16:54:57 -0500 Subject: [PATCH 4/5] Fix simple_scenario_sync and remove async for now --- mls-rs-uniffi/src/config/group_state.rs | 18 ++--- mls-rs-uniffi/src/lib.rs | 24 +++++- .../test_bindings/simple_scenario_async.py | 19 ----- .../test_bindings/simple_scenario_sync.py | 75 ++++++++++++++++++- 4 files changed, 102 insertions(+), 34 deletions(-) delete mode 100644 mls-rs-uniffi/test_bindings/simple_scenario_async.py diff --git a/mls-rs-uniffi/src/config/group_state.rs b/mls-rs-uniffi/src/config/group_state.rs index 44746a4c..aa381725 100644 --- a/mls-rs-uniffi/src/config/group_state.rs +++ b/mls-rs-uniffi/src/config/group_state.rs @@ -4,7 +4,7 @@ use mls_rs_core::mls_rs_codec::{MlsDecode, MlsEncode}; use super::FFICallbackError; -#[derive(Clone, Debug, uniffi::Object)] +#[derive(Clone, Debug, uniffi::Record)] pub struct GroupState { pub id: Vec, pub data: Vec, @@ -16,7 +16,7 @@ impl mls_rs_core::group::GroupState for GroupState { } } -#[derive(Clone, Debug, uniffi::Object)] +#[derive(Clone, Debug, uniffi::Record)] pub struct EpochRecord { pub id: u64, pub data: Vec, @@ -41,9 +41,9 @@ pub trait GroupStateStorage: Send + Sync + Debug { async fn write( &self, - state: Arc, - epoch_inserts: Vec>, - epoch_updates: Vec>, + state: GroupState, + epoch_inserts: Vec, + epoch_updates: Vec, ) -> Result<(), FFICallbackError>; async fn max_epoch_id(&self, group_id: Vec) -> Result, FFICallbackError>; @@ -99,16 +99,16 @@ impl mls_rs_core::group::GroupStateStorage for GroupStateStorageWrapper { ST: mls_rs_core::group::GroupState + MlsEncode + MlsDecode + Send + Sync, ET: mls_rs_core::group::EpochRecord + MlsEncode + MlsDecode + Send + Sync, { - let state = Arc::new(GroupState { + let state = GroupState { id: state.id(), data: state.mls_encode_to_vec()?, - }); + }; let epoch_to_record = |v: ET| -> Result<_, Self::Error> { - Ok(Arc::new(EpochRecord { + Ok(EpochRecord { id: v.id(), data: v.mls_encode_to_vec()?, - })) + }) }; let inserts = epoch_inserts diff --git a/mls-rs-uniffi/src/lib.rs b/mls-rs-uniffi/src/lib.rs index ffeef99c..23ae2498 100644 --- a/mls-rs-uniffi/src/lib.rs +++ b/mls-rs-uniffi/src/lib.rs @@ -53,12 +53,12 @@ fn arc_unwrap_or_clone(arc: Arc) -> T { #[uniffi(flat_error)] #[non_exhaustive] pub enum Error { - #[error("A mls-rs error occurred")] + #[error("A mls-rs error occurred: {inner}")] MlsError { #[from] inner: mls_rs::error::MlsError, }, - #[error("An unknown error occurred")] + #[error("An unknown error occurred: {inner}")] AnyError { #[from] inner: mls_rs::error::AnyError, @@ -329,6 +329,19 @@ impl Client { group_info_extensions, }) } + + /// Load an existing group. + /// + /// See [`mls_rs::Client::load_group`] for details. + pub async fn load_group(&self, group_id: Vec) -> Result { + self.inner + .load_group(&group_id) + .await + .map(|g| Group { + inner: Arc::new(Mutex::new(g)), + }) + .map_err(Into::into) + } } #[derive(Clone, Debug, uniffi::Object)] @@ -423,6 +436,13 @@ async fn signing_identity_to_identifier( #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] #[uniffi::export] impl Group { + /// Write the current state of the group to storage defined by + /// [`ClientConfig::group_state_storage`] + pub async fn write_to_storage(&self) -> Result<(), Error> { + let mut group = self.inner().await; + group.write_to_storage().await.map_err(Into::into) + } + /// Perform a commit of received proposals (or an empty commit). /// /// TODO: ensure `path_required` is always set in diff --git a/mls-rs-uniffi/test_bindings/simple_scenario_async.py b/mls-rs-uniffi/test_bindings/simple_scenario_async.py deleted file mode 100644 index 2e7a0c16..00000000 --- a/mls-rs-uniffi/test_bindings/simple_scenario_async.py +++ /dev/null @@ -1,19 +0,0 @@ -from mls_rs_uniffi import CipherSuite, generate_signature_keypair, Client -from asyncio import run - -key = run(generate_signature_keypair(CipherSuite.CURVE25519_AES128)) -alice = Client(b'alice', key) - -key = run(generate_signature_keypair(CipherSuite.CURVE25519_AES128)) -bob = Client(b'bob', key) - -alice = run(alice.create_group(None)) -kp = run(bob.generate_key_package_message()) -commit = run(alice.add_members([kp])) -run(alice.process_incoming_message(commit.commit_message())) -bob = run(bob.join_group(commit.welcome_messages()[0])).group - -msg = run(alice.encrypt_application_message(b'hello, bob')) -output = run(bob.process_incoming_message(msg)) - -assert output.data == b'hello, bob' \ No newline at end of file diff --git a/mls-rs-uniffi/test_bindings/simple_scenario_sync.py b/mls-rs-uniffi/test_bindings/simple_scenario_sync.py index 5aa3d411..39c76041 100644 --- a/mls-rs-uniffi/test_bindings/simple_scenario_sync.py +++ b/mls-rs-uniffi/test_bindings/simple_scenario_sync.py @@ -1,10 +1,74 @@ -from mls_rs_uniffi import CipherSuite, generate_signature_keypair, Client +from mls_rs_uniffi import CipherSuite, generate_signature_keypair, Client, GroupStateStorage, ClientConfig + +class EpochData: + def __init__(self, id: "int", data: "bytes"): + self.id = id + self.data = data + +class GroupStateData: + def __init__(self, state: "bytes"): + self.state = state + self.epoch_data = [] + +class PythonGroupStateStorage(GroupStateStorage): + def __init__(self): + self.groups = {} + + def state(self, group_id: "bytes"): + group = self.groups.get(group_id.hex()) + + if group == None: + return None + + group.state + + def epoch(self, group_id: "bytes",epoch_id: "int"): + group = self.groups[group_id.hex()] + + if group == None: + return None + + for epoch in group.epoch_data: + if epoch.id == epoch_id: + return epoch + + return None + + def write(self, state: "GroupState",epoch_inserts: "typing.List[EpochRecord]",epoch_updates: "typing.List[EpochRecord]"): + if self.groups.get(state.id.hex()) == None: + self.groups[state.id.hex()] = GroupStateData(state.data) + + group = self.groups[state.id.hex()] + + for insert in epoch_inserts: + group.epoch_data.append(insert) + + for update in epoch_updates: + for i in range(len(group.epoch_data)): + if group.epoch_data[i].id == update.id: + group.epoch_data[i] = update + + def max_epoch_id(self, group_id: "bytes"): + group = self.groups.get(group_id.hex()) + + if group == None: + return None + + last = group.epoch_data.last() + + if last == None: + return None + + return last.id + +group_state_storage = PythonGroupStateStorage() +client_config = ClientConfig(group_state_storage) key = generate_signature_keypair(CipherSuite.CURVE25519_AES128) -alice = Client(b'alice', key) +alice = Client(b'alice', key, client_config) key = generate_signature_keypair(CipherSuite.CURVE25519_AES128) -bob = Client(b'bob', key) +bob = Client(b'bob', key, client_config) alice = alice.create_group(None) kp = bob.generate_key_package_message() @@ -15,4 +79,7 @@ msg = alice.encrypt_application_message(b'hello, bob') output = bob.process_incoming_message(msg) -assert output.data == b'hello, bob' \ No newline at end of file +alice.write_to_storage() + +assert output.data == b'hello, bob' +assert len(group_state_storage.groups) == 1 From 3f9bc18c41051e39e45be0f223d5ed7b948f06c8 Mon Sep 17 00:00:00 2001 From: Tom Leavy Date: Mon, 26 Feb 2024 17:00:02 -0500 Subject: [PATCH 5/5] Ignore async tests --- mls-rs-uniffi/src/lib.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/mls-rs-uniffi/src/lib.rs b/mls-rs-uniffi/src/lib.rs index 23ae2498..814f85fa 100644 --- a/mls-rs-uniffi/src/lib.rs +++ b/mls-rs-uniffi/src/lib.rs @@ -619,6 +619,7 @@ mod sync_tests { mod async_tests { use crate::test_utils::run_python; + #[ignore] #[test] fn test_simple_scenario() -> Result<(), Box> { run_python(include_str!("../test_bindings/simple_scenario_async.py"))