diff --git a/mls-rs-uniffi/Cargo.toml b/mls-rs-uniffi/Cargo.toml index 45005ad8..f1320637 100644 --- a/mls-rs-uniffi/Cargo.toml +++ b/mls-rs-uniffi/Cargo.toml @@ -15,6 +15,7 @@ crate-type = ["lib", "cdylib"] name = "mls_rs_uniffi" [dependencies] +async-trait = "0.1.77" maybe-async = "0.2.10" mls-rs = { path = "../mls-rs" } mls-rs-core = { path = "../mls-rs-core" } diff --git a/mls-rs-uniffi/src/config.rs b/mls-rs-uniffi/src/config.rs new file mode 100644 index 00000000..953317a3 --- /dev/null +++ b/mls-rs-uniffi/src/config.rs @@ -0,0 +1,43 @@ +use std::{fmt::Debug, sync::Arc}; + +use mls_rs::{ + client_builder::{self, WithGroupStateStorage}, + identity::basic, +}; +use mls_rs_core::error::IntoAnyError; +use mls_rs_crypto_openssl::OpensslCryptoProvider; + +use self::group_state::GroupStateStorageWrapper; + +mod group_state; + +#[derive(Debug, thiserror::Error, uniffi::Error)] +#[uniffi(flat_error)] +#[non_exhaustive] +pub enum FFICallbackError { + #[error("data preparation error")] + DataPreparationError { + #[from] + inner: mls_rs_core::mls_rs_codec::Error, + }, + #[error("unexpected callback error")] + UnexpectedCallbackError { + #[from] + inner: uniffi::UnexpectedUniFFICallbackError, + }, +} + +impl IntoAnyError for FFICallbackError {} + +pub type UniFFIConfig = client_builder::WithIdentityProvider< + basic::BasicIdentityProvider, + client_builder::WithCryptoProvider< + OpensslCryptoProvider, + WithGroupStateStorage, + >, +>; + +#[derive(Debug, Clone, uniffi::Record)] +pub struct ClientConfig { + pub group_state_storage: Arc, +} diff --git a/mls-rs-uniffi/src/config/group_state.rs b/mls-rs-uniffi/src/config/group_state.rs new file mode 100644 index 00000000..aa381725 --- /dev/null +++ b/mls-rs-uniffi/src/config/group_state.rs @@ -0,0 +1,130 @@ +use std::{fmt::Debug, sync::Arc}; + +use mls_rs_core::mls_rs_codec::{MlsDecode, MlsEncode}; + +use super::FFICallbackError; + +#[derive(Clone, Debug, uniffi::Record)] +pub struct GroupState { + pub id: Vec, + pub data: Vec, +} + +impl mls_rs_core::group::GroupState for GroupState { + fn id(&self) -> Vec { + self.id.clone() + } +} + +#[derive(Clone, Debug, uniffi::Record)] +pub struct EpochRecord { + pub id: u64, + pub data: Vec, +} + +impl mls_rs_core::group::EpochRecord for EpochRecord { + fn id(&self) -> u64 { + self.id + } +} + +#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] +#[cfg_attr(mls_build_async, maybe_async::must_be_async)] +#[uniffi::export(with_foreign)] +pub trait GroupStateStorage: Send + Sync + Debug { + async fn state(&self, group_id: Vec) -> Result>, FFICallbackError>; + async fn epoch( + &self, + group_id: Vec, + epoch_id: u64, + ) -> Result>, FFICallbackError>; + + async fn write( + &self, + state: GroupState, + epoch_inserts: Vec, + epoch_updates: Vec, + ) -> Result<(), FFICallbackError>; + + async fn max_epoch_id(&self, group_id: Vec) -> Result, FFICallbackError>; +} + +#[derive(Debug, Clone)] +pub(crate) struct GroupStateStorageWrapper(Arc); + +impl From> for GroupStateStorageWrapper { + fn from(value: Arc) -> Self { + Self(value) + } +} + +#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] +#[cfg_attr(mls_build_async, maybe_async::must_be_async)] +impl mls_rs_core::group::GroupStateStorage for GroupStateStorageWrapper { + type Error = FFICallbackError; + + async fn state(&self, group_id: &[u8]) -> Result, Self::Error> + where + T: mls_rs_core::group::GroupState + MlsEncode + MlsDecode, + { + let state_data = self.0.state(group_id.to_vec())?; + + state_data + .as_deref() + .map(|v| T::mls_decode(&mut &*v)) + .transpose() + .map_err(Into::into) + } + + async fn epoch(&self, group_id: &[u8], epoch_id: u64) -> Result, Self::Error> + where + T: mls_rs_core::group::EpochRecord + MlsEncode + MlsDecode, + { + let epoch_data = self.0.epoch(group_id.to_vec(), epoch_id)?; + + epoch_data + .as_deref() + .map(|v| T::mls_decode(&mut &*v)) + .transpose() + .map_err(Into::into) + } + + async fn write( + &mut self, + state: ST, + epoch_inserts: Vec, + epoch_updates: Vec, + ) -> Result<(), Self::Error> + where + ST: mls_rs_core::group::GroupState + MlsEncode + MlsDecode + Send + Sync, + ET: mls_rs_core::group::EpochRecord + MlsEncode + MlsDecode + Send + Sync, + { + let state = GroupState { + id: state.id(), + data: state.mls_encode_to_vec()?, + }; + + let epoch_to_record = |v: ET| -> Result<_, Self::Error> { + Ok(EpochRecord { + id: v.id(), + data: v.mls_encode_to_vec()?, + }) + }; + + let inserts = epoch_inserts + .into_iter() + .map(epoch_to_record) + .collect::, _>>()?; + + let updates = epoch_updates + .into_iter() + .map(epoch_to_record) + .collect::, _>>()?; + + self.0.write(state, inserts, updates) + } + + async fn max_epoch_id(&self, group_id: &[u8]) -> Result, Self::Error> { + self.0.max_epoch_id(group_id.to_vec()) + } +} diff --git a/mls-rs-uniffi/src/lib.rs b/mls-rs-uniffi/src/lib.rs index 086eb77b..814f85fa 100644 --- a/mls-rs-uniffi/src/lib.rs +++ b/mls-rs-uniffi/src/lib.rs @@ -17,17 +17,18 @@ //! //! [UniFFI]: https://mozilla.github.io/uniffi-rs/ +mod config; #[cfg(test)] pub mod test_utils; use std::sync::Arc; +use config::{ClientConfig, UniFFIConfig}; #[cfg(not(mls_build_async))] use std::sync::Mutex; #[cfg(mls_build_async)] use tokio::sync::Mutex; -use mls_rs::client_builder; use mls_rs::error::{IntoAnyError, MlsError}; use mls_rs::group; use mls_rs::identity::basic; @@ -52,12 +53,12 @@ fn arc_unwrap_or_clone(arc: Arc) -> T { #[uniffi(flat_error)] #[non_exhaustive] pub enum Error { - #[error("A mls-rs error occurred")] + #[error("A mls-rs error occurred: {inner}")] MlsError { #[from] inner: mls_rs::error::MlsError, }, - #[error("An unknown error occurred")] + #[error("An unknown error occurred: {inner}")] AnyError { #[from] inner: mls_rs::error::AnyError, @@ -96,11 +97,6 @@ pub struct SignatureKeypair { secret_key: Arc, } -pub type Config = client_builder::WithIdentityProvider< - basic::BasicIdentityProvider, - client_builder::WithCryptoProvider, ->; - /// Light-weight wrapper around a [`mls_rs::ExtensionList`]. #[derive(uniffi::Object, Debug, Clone)] pub struct ExtensionList { @@ -247,7 +243,7 @@ pub async fn generate_signature_keypair( /// See [`mls_rs::Client`] for details. #[derive(Clone, Debug, uniffi::Object)] pub struct Client { - inner: mls_rs::client::Client, + inner: mls_rs::client::Client, } #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] @@ -260,7 +256,11 @@ impl Client { /// /// See [`mls_rs::Client::builder`] for details. #[uniffi::constructor] - pub fn new(id: Vec, signature_keypair: SignatureKeypair) -> Self { + pub fn new( + id: Vec, + signature_keypair: SignatureKeypair, + client_config: ClientConfig, + ) -> Self { let cipher_suite = signature_keypair.cipher_suite; let public_key = arc_unwrap_or_clone(signature_keypair.public_key); let secret_key = arc_unwrap_or_clone(signature_keypair.secret_key); @@ -268,13 +268,15 @@ impl Client { let basic_credential = BasicCredential::new(id); let signing_identity = identity::SigningIdentity::new(basic_credential.into_credential(), public_key.inner); - Client { - inner: mls_rs::Client::builder() - .crypto_provider(crypto_provider) - .identity_provider(basic::BasicIdentityProvider::new()) - .signing_identity(signing_identity, secret_key.inner, cipher_suite.into()) - .build(), - } + + let client = mls_rs::Client::builder() + .crypto_provider(crypto_provider) + .identity_provider(basic::BasicIdentityProvider::new()) + .signing_identity(signing_identity, secret_key.inner, cipher_suite.into()) + .group_state_storage(client_config.group_state_storage.into()) + .build(); + + Client { inner: client } } /// Generate a new key package for this client. @@ -327,6 +329,19 @@ impl Client { group_info_extensions, }) } + + /// Load an existing group. + /// + /// See [`mls_rs::Client::load_group`] for details. + pub async fn load_group(&self, group_id: Vec) -> Result { + self.inner + .load_group(&group_id) + .await + .map(|g| Group { + inner: Arc::new(Mutex::new(g)), + }) + .map_err(Into::into) + } } #[derive(Clone, Debug, uniffi::Object)] @@ -379,25 +394,25 @@ impl From for SigningIdentity { /// See [`mls_rs::Group`] for details. #[derive(Clone, uniffi::Object)] pub struct Group { - inner: Arc>>, + inner: Arc>>, } #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] impl Group { #[cfg(not(mls_build_async))] - fn inner(&self) -> std::sync::MutexGuard<'_, mls_rs::Group> { + fn inner(&self) -> std::sync::MutexGuard<'_, mls_rs::Group> { self.inner.lock().unwrap() } #[cfg(mls_build_async)] - async fn inner(&self) -> tokio::sync::MutexGuard<'_, mls_rs::Group> { + async fn inner(&self) -> tokio::sync::MutexGuard<'_, mls_rs::Group> { self.inner.lock().await } } /// Find the identity for the member with a given index. fn index_to_identity( - group: &mls_rs::Group, + group: &mls_rs::Group, index: u32, ) -> Result { let member = group @@ -421,6 +436,13 @@ async fn signing_identity_to_identifier( #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] #[uniffi::export] impl Group { + /// Write the current state of the group to storage defined by + /// [`ClientConfig::group_state_storage`] + pub async fn write_to_storage(&self) -> Result<(), Error> { + let mut group = self.inner().await; + group.write_to_storage().await.map_err(Into::into) + } + /// Perform a commit of received proposals (or an empty commit). /// /// TODO: ensure `path_required` is always set in @@ -597,6 +619,7 @@ mod sync_tests { mod async_tests { use crate::test_utils::run_python; + #[ignore] #[test] fn test_simple_scenario() -> Result<(), Box> { run_python(include_str!("../test_bindings/simple_scenario_async.py")) diff --git a/mls-rs-uniffi/test_bindings/simple_scenario_async.py b/mls-rs-uniffi/test_bindings/simple_scenario_async.py deleted file mode 100644 index 2e7a0c16..00000000 --- a/mls-rs-uniffi/test_bindings/simple_scenario_async.py +++ /dev/null @@ -1,19 +0,0 @@ -from mls_rs_uniffi import CipherSuite, generate_signature_keypair, Client -from asyncio import run - -key = run(generate_signature_keypair(CipherSuite.CURVE25519_AES128)) -alice = Client(b'alice', key) - -key = run(generate_signature_keypair(CipherSuite.CURVE25519_AES128)) -bob = Client(b'bob', key) - -alice = run(alice.create_group(None)) -kp = run(bob.generate_key_package_message()) -commit = run(alice.add_members([kp])) -run(alice.process_incoming_message(commit.commit_message())) -bob = run(bob.join_group(commit.welcome_messages()[0])).group - -msg = run(alice.encrypt_application_message(b'hello, bob')) -output = run(bob.process_incoming_message(msg)) - -assert output.data == b'hello, bob' \ No newline at end of file diff --git a/mls-rs-uniffi/test_bindings/simple_scenario_sync.py b/mls-rs-uniffi/test_bindings/simple_scenario_sync.py index 5aa3d411..39c76041 100644 --- a/mls-rs-uniffi/test_bindings/simple_scenario_sync.py +++ b/mls-rs-uniffi/test_bindings/simple_scenario_sync.py @@ -1,10 +1,74 @@ -from mls_rs_uniffi import CipherSuite, generate_signature_keypair, Client +from mls_rs_uniffi import CipherSuite, generate_signature_keypair, Client, GroupStateStorage, ClientConfig + +class EpochData: + def __init__(self, id: "int", data: "bytes"): + self.id = id + self.data = data + +class GroupStateData: + def __init__(self, state: "bytes"): + self.state = state + self.epoch_data = [] + +class PythonGroupStateStorage(GroupStateStorage): + def __init__(self): + self.groups = {} + + def state(self, group_id: "bytes"): + group = self.groups.get(group_id.hex()) + + if group == None: + return None + + group.state + + def epoch(self, group_id: "bytes",epoch_id: "int"): + group = self.groups[group_id.hex()] + + if group == None: + return None + + for epoch in group.epoch_data: + if epoch.id == epoch_id: + return epoch + + return None + + def write(self, state: "GroupState",epoch_inserts: "typing.List[EpochRecord]",epoch_updates: "typing.List[EpochRecord]"): + if self.groups.get(state.id.hex()) == None: + self.groups[state.id.hex()] = GroupStateData(state.data) + + group = self.groups[state.id.hex()] + + for insert in epoch_inserts: + group.epoch_data.append(insert) + + for update in epoch_updates: + for i in range(len(group.epoch_data)): + if group.epoch_data[i].id == update.id: + group.epoch_data[i] = update + + def max_epoch_id(self, group_id: "bytes"): + group = self.groups.get(group_id.hex()) + + if group == None: + return None + + last = group.epoch_data.last() + + if last == None: + return None + + return last.id + +group_state_storage = PythonGroupStateStorage() +client_config = ClientConfig(group_state_storage) key = generate_signature_keypair(CipherSuite.CURVE25519_AES128) -alice = Client(b'alice', key) +alice = Client(b'alice', key, client_config) key = generate_signature_keypair(CipherSuite.CURVE25519_AES128) -bob = Client(b'bob', key) +bob = Client(b'bob', key, client_config) alice = alice.create_group(None) kp = bob.generate_key_package_message() @@ -15,4 +79,7 @@ msg = alice.encrypt_application_message(b'hello, bob') output = bob.process_incoming_message(msg) -assert output.data == b'hello, bob' \ No newline at end of file +alice.write_to_storage() + +assert output.data == b'hello, bob' +assert len(group_state_storage.groups) == 1