Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[uniffi] Add support for custom GroupStateStorage interface #86

Merged
merged 5 commits into from
Feb 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions mls-rs-uniffi/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ crate-type = ["lib", "cdylib"]
name = "mls_rs_uniffi"

[dependencies]
async-trait = "0.1.77"
maybe-async = "0.2.10"
mls-rs = { path = "../mls-rs" }
mls-rs-core = { path = "../mls-rs-core" }
Expand Down
43 changes: 43 additions & 0 deletions mls-rs-uniffi/src/config.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
use std::{fmt::Debug, sync::Arc};

use mls_rs::{
client_builder::{self, WithGroupStateStorage},
identity::basic,
};
use mls_rs_core::error::IntoAnyError;
use mls_rs_crypto_openssl::OpensslCryptoProvider;

use self::group_state::GroupStateStorageWrapper;

mod group_state;

#[derive(Debug, thiserror::Error, uniffi::Error)]
#[uniffi(flat_error)]
#[non_exhaustive]
pub enum FFICallbackError {
tomleavy marked this conversation as resolved.
Show resolved Hide resolved
#[error("data preparation error")]
DataPreparationError {
#[from]
inner: mls_rs_core::mls_rs_codec::Error,
},
#[error("unexpected callback error")]
UnexpectedCallbackError {
#[from]
inner: uniffi::UnexpectedUniFFICallbackError,
},
}

impl IntoAnyError for FFICallbackError {}

pub type UniFFIConfig = client_builder::WithIdentityProvider<
basic::BasicIdentityProvider,
client_builder::WithCryptoProvider<
OpensslCryptoProvider,
WithGroupStateStorage<GroupStateStorageWrapper, client_builder::BaseConfig>,
>,
>;

#[derive(Debug, Clone, uniffi::Record)]
pub struct ClientConfig {
pub group_state_storage: Arc<dyn group_state::GroupStateStorage>,
}
130 changes: 130 additions & 0 deletions mls-rs-uniffi/src/config/group_state.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
use std::{fmt::Debug, sync::Arc};

use mls_rs_core::mls_rs_codec::{MlsDecode, MlsEncode};

use super::FFICallbackError;

#[derive(Clone, Debug, uniffi::Record)]
pub struct GroupState {
pub id: Vec<u8>,
pub data: Vec<u8>,
}

impl mls_rs_core::group::GroupState for GroupState {
fn id(&self) -> Vec<u8> {
self.id.clone()
}
}

#[derive(Clone, Debug, uniffi::Record)]
pub struct EpochRecord {
pub id: u64,
pub data: Vec<u8>,
}

impl mls_rs_core::group::EpochRecord for EpochRecord {
fn id(&self) -> u64 {
self.id
}
}

#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
#[cfg_attr(mls_build_async, maybe_async::must_be_async)]
#[uniffi::export(with_foreign)]
pub trait GroupStateStorage: Send + Sync + Debug {
async fn state(&self, group_id: Vec<u8>) -> Result<Option<Vec<u8>>, FFICallbackError>;
async fn epoch(
&self,
group_id: Vec<u8>,
epoch_id: u64,
) -> Result<Option<Vec<u8>>, FFICallbackError>;

async fn write(
&self,
state: GroupState,
epoch_inserts: Vec<EpochRecord>,
epoch_updates: Vec<EpochRecord>,
) -> Result<(), FFICallbackError>;

async fn max_epoch_id(&self, group_id: Vec<u8>) -> Result<Option<u64>, FFICallbackError>;
}

#[derive(Debug, Clone)]
pub(crate) struct GroupStateStorageWrapper(Arc<dyn GroupStateStorage>);

impl From<Arc<dyn GroupStateStorage>> for GroupStateStorageWrapper {
fn from(value: Arc<dyn GroupStateStorage>) -> Self {
Self(value)
}
}

#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
#[cfg_attr(mls_build_async, maybe_async::must_be_async)]
impl mls_rs_core::group::GroupStateStorage for GroupStateStorageWrapper {
type Error = FFICallbackError;

async fn state<T>(&self, group_id: &[u8]) -> Result<Option<T>, Self::Error>
where
T: mls_rs_core::group::GroupState + MlsEncode + MlsDecode,
{
let state_data = self.0.state(group_id.to_vec())?;

state_data
.as_deref()
.map(|v| T::mls_decode(&mut &*v))
.transpose()
.map_err(Into::into)
}

async fn epoch<T>(&self, group_id: &[u8], epoch_id: u64) -> Result<Option<T>, Self::Error>
where
T: mls_rs_core::group::EpochRecord + MlsEncode + MlsDecode,
{
let epoch_data = self.0.epoch(group_id.to_vec(), epoch_id)?;

epoch_data
.as_deref()
.map(|v| T::mls_decode(&mut &*v))
.transpose()
.map_err(Into::into)
}

async fn write<ST, ET>(
&mut self,
state: ST,
epoch_inserts: Vec<ET>,
epoch_updates: Vec<ET>,
) -> Result<(), Self::Error>
where
ST: mls_rs_core::group::GroupState + MlsEncode + MlsDecode + Send + Sync,
ET: mls_rs_core::group::EpochRecord + MlsEncode + MlsDecode + Send + Sync,
{
let state = GroupState {
id: state.id(),
data: state.mls_encode_to_vec()?,
};

let epoch_to_record = |v: ET| -> Result<_, Self::Error> {
Ok(EpochRecord {
id: v.id(),
data: v.mls_encode_to_vec()?,
})
};

let inserts = epoch_inserts
.into_iter()
.map(epoch_to_record)
.collect::<Result<Vec<_>, _>>()?;

let updates = epoch_updates
.into_iter()
.map(epoch_to_record)
.collect::<Result<Vec<_>, _>>()?;

self.0.write(state, inserts, updates)
}

async fn max_epoch_id(&self, group_id: &[u8]) -> Result<Option<u64>, Self::Error> {
self.0.max_epoch_id(group_id.to_vec())
}
}
65 changes: 44 additions & 21 deletions mls-rs-uniffi/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,18 @@
//!
//! [UniFFI]: https://mozilla.github.io/uniffi-rs/

mod config;
#[cfg(test)]
pub mod test_utils;

use std::sync::Arc;

use config::{ClientConfig, UniFFIConfig};
#[cfg(not(mls_build_async))]
use std::sync::Mutex;
#[cfg(mls_build_async)]
use tokio::sync::Mutex;

use mls_rs::client_builder;
use mls_rs::error::{IntoAnyError, MlsError};
use mls_rs::group;
use mls_rs::identity::basic;
Expand All @@ -52,12 +53,12 @@ fn arc_unwrap_or_clone<T: Clone>(arc: Arc<T>) -> T {
#[uniffi(flat_error)]
#[non_exhaustive]
pub enum Error {
#[error("A mls-rs error occurred")]
#[error("A mls-rs error occurred: {inner}")]
MlsError {
#[from]
inner: mls_rs::error::MlsError,
},
#[error("An unknown error occurred")]
#[error("An unknown error occurred: {inner}")]
AnyError {
#[from]
inner: mls_rs::error::AnyError,
Expand Down Expand Up @@ -96,11 +97,6 @@ pub struct SignatureKeypair {
secret_key: Arc<SignatureSecretKey>,
}

pub type Config = client_builder::WithIdentityProvider<
basic::BasicIdentityProvider,
client_builder::WithCryptoProvider<OpensslCryptoProvider, client_builder::BaseConfig>,
>;

/// Light-weight wrapper around a [`mls_rs::ExtensionList`].
#[derive(uniffi::Object, Debug, Clone)]
pub struct ExtensionList {
Expand Down Expand Up @@ -247,7 +243,7 @@ pub async fn generate_signature_keypair(
/// See [`mls_rs::Client`] for details.
#[derive(Clone, Debug, uniffi::Object)]
pub struct Client {
inner: mls_rs::client::Client<Config>,
inner: mls_rs::client::Client<UniFFIConfig>,
}

#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
Expand All @@ -260,21 +256,27 @@ impl Client {
///
/// See [`mls_rs::Client::builder`] for details.
#[uniffi::constructor]
pub fn new(id: Vec<u8>, signature_keypair: SignatureKeypair) -> Self {
pub fn new(
id: Vec<u8>,
signature_keypair: SignatureKeypair,
client_config: ClientConfig,
) -> Self {
let cipher_suite = signature_keypair.cipher_suite;
let public_key = arc_unwrap_or_clone(signature_keypair.public_key);
let secret_key = arc_unwrap_or_clone(signature_keypair.secret_key);
let crypto_provider = OpensslCryptoProvider::new();
let basic_credential = BasicCredential::new(id);
let signing_identity =
identity::SigningIdentity::new(basic_credential.into_credential(), public_key.inner);
Client {
inner: mls_rs::Client::builder()
.crypto_provider(crypto_provider)
.identity_provider(basic::BasicIdentityProvider::new())
.signing_identity(signing_identity, secret_key.inner, cipher_suite.into())
.build(),
}

let client = mls_rs::Client::builder()
.crypto_provider(crypto_provider)
.identity_provider(basic::BasicIdentityProvider::new())
.signing_identity(signing_identity, secret_key.inner, cipher_suite.into())
.group_state_storage(client_config.group_state_storage.into())
.build();

Client { inner: client }
}

/// Generate a new key package for this client.
Expand Down Expand Up @@ -327,6 +329,19 @@ impl Client {
group_info_extensions,
})
}

/// Load an existing group.
///
/// See [`mls_rs::Client::load_group`] for details.
pub async fn load_group(&self, group_id: Vec<u8>) -> Result<Group, Error> {
self.inner
.load_group(&group_id)
.await
.map(|g| Group {
inner: Arc::new(Mutex::new(g)),
})
.map_err(Into::into)
}
}

#[derive(Clone, Debug, uniffi::Object)]
Expand Down Expand Up @@ -379,25 +394,25 @@ impl From<identity::SigningIdentity> for SigningIdentity {
/// See [`mls_rs::Group`] for details.
#[derive(Clone, uniffi::Object)]
pub struct Group {
inner: Arc<Mutex<mls_rs::Group<Config>>>,
inner: Arc<Mutex<mls_rs::Group<UniFFIConfig>>>,
}

#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
impl Group {
#[cfg(not(mls_build_async))]
fn inner(&self) -> std::sync::MutexGuard<'_, mls_rs::Group<Config>> {
fn inner(&self) -> std::sync::MutexGuard<'_, mls_rs::Group<UniFFIConfig>> {
self.inner.lock().unwrap()
}

#[cfg(mls_build_async)]
async fn inner(&self) -> tokio::sync::MutexGuard<'_, mls_rs::Group<Config>> {
async fn inner(&self) -> tokio::sync::MutexGuard<'_, mls_rs::Group<UniFFIConfig>> {
self.inner.lock().await
}
}

/// Find the identity for the member with a given index.
fn index_to_identity(
group: &mls_rs::Group<Config>,
group: &mls_rs::Group<UniFFIConfig>,
index: u32,
) -> Result<identity::SigningIdentity, Error> {
let member = group
Expand All @@ -421,6 +436,13 @@ async fn signing_identity_to_identifier(
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
#[uniffi::export]
impl Group {
/// Write the current state of the group to storage defined by
/// [`ClientConfig::group_state_storage`]
pub async fn write_to_storage(&self) -> Result<(), Error> {
let mut group = self.inner().await;
group.write_to_storage().await.map_err(Into::into)
}

/// Perform a commit of received proposals (or an empty commit).
///
/// TODO: ensure `path_required` is always set in
Expand Down Expand Up @@ -597,6 +619,7 @@ mod sync_tests {
mod async_tests {
use crate::test_utils::run_python;

#[ignore]
#[test]
fn test_simple_scenario() -> Result<(), Box<dyn std::error::Error>> {
run_python(include_str!("../test_bindings/simple_scenario_async.py"))
Expand Down
19 changes: 0 additions & 19 deletions mls-rs-uniffi/test_bindings/simple_scenario_async.py

This file was deleted.

Loading
Loading