Skip to content

Commit

Permalink
Rework leaf node extensions to work via parameters rather than as a c… (
Browse files Browse the repository at this point in the history
#196)

* Rework leaf node extensions to work via parameters rather than as a client configuration

* Rework key package extensions to work via parameters rather than as a client configuration

* Address clippy issues

* Apply formatting changes

* Remove TODO on Renit Client as WONT DO is the conclusion

* Fix unit tests breaking due to grease
  • Loading branch information
CaioSym authored Oct 17, 2024
1 parent 1a1fa84 commit c710afa
Show file tree
Hide file tree
Showing 30 changed files with 311 additions and 196 deletions.
13 changes: 10 additions & 3 deletions mls-rs-uniffi/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -382,7 +382,10 @@ impl Client {
/// See [`mls_rs::Client::generate_key_package_message`] for
/// details.
pub async fn generate_key_package_message(&self) -> Result<Message, Error> {
let message = self.inner.generate_key_package_message().await?;
let message = self
.inner
.generate_key_package_message(Default::default(), Default::default())
.await?;
Ok(message.into())
}

Expand All @@ -403,10 +406,14 @@ impl Client {
let inner = match group_id {
Some(group_id) => {
self.inner
.create_group_with_id(group_id, extensions)
.create_group_with_id(group_id, extensions, Default::default())
.await?
}
None => {
self.inner
.create_group(extensions, Default::default())
.await?
}
None => self.inner.create_group(extensions).await?,
};
Ok(Group {
inner: Arc::new(Mutex::new(inner)),
Expand Down
4 changes: 2 additions & 2 deletions mls-rs/benches/group_add.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,15 @@ use mls_rs_crypto_openssl::OpensslCryptoProvider;

fn bench(c: &mut Criterion) {
let alice = make_client("alice")
.create_group(Default::default())
.create_group(Default::default(), Default::default())
.unwrap();

const MAX_ADD_COUNT: usize = 1000;

let key_packages = (0..MAX_ADD_COUNT)
.map(|i| {
make_client(&format!("bob-{i}"))
.generate_key_package_message()
.generate_key_package_message(Default::default(), Default::default())
.unwrap()
})
.collect::<Vec<_>>();
Expand Down
5 changes: 3 additions & 2 deletions mls-rs/examples/basic_server_usage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,8 +140,9 @@ fn main() -> Result<(), MlsError> {
let bob = make_client("bob")?;

// Alice creates a group with bob
let mut alice_group = alice.create_group(ExtensionList::default())?;
let bob_key_package = bob.generate_key_package_message()?;
let mut alice_group = alice.create_group(ExtensionList::default(), Default::default())?;
let bob_key_package =
bob.generate_key_package_message(Default::default(), Default::default())?;

let welcome = &alice_group
.commit_builder()
Expand Down
5 changes: 3 additions & 2 deletions mls-rs/examples/basic_usage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,11 @@ fn main() -> Result<(), MlsError> {
let bob = make_client(crypto_provider.clone(), "bob")?;

// Alice creates a new group.
let mut alice_group = alice.create_group(ExtensionList::default())?;
let mut alice_group = alice.create_group(ExtensionList::default(), Default::default())?;

// Bob generates a key package that Alice needs to add Bob to the group.
let bob_key_package = bob.generate_key_package_message()?;
let bob_key_package =
bob.generate_key_package_message(Default::default(), Default::default())?;

// Alice issues a commit that adds Bob to the group.
let alice_commit = alice_group
Expand Down
9 changes: 6 additions & 3 deletions mls-rs/examples/custom.rs
Original file line number Diff line number Diff line change
Expand Up @@ -369,11 +369,13 @@ fn main() -> Result<(), CustomError> {
let roster = vec![alice.credential];
context_extensions.set_from(RosterExtension { roster })?;

let mut alice_tablet_group = make_client(alice_tablet)?.create_group(context_extensions)?;
let mut alice_tablet_group =
make_client(alice_tablet)?.create_group(context_extensions, Default::default())?;

// Alice can add her other device
let alice_pc_client = make_client(alice_pc)?;
let key_package = alice_pc_client.generate_key_package_message()?;
let key_package =
alice_pc_client.generate_key_package_message(Default::default(), Default::default())?;

let welcome = alice_tablet_group
.commit_builder()
Expand All @@ -387,7 +389,8 @@ fn main() -> Result<(), CustomError> {

// Alice cannot add bob's devices yet
let bob_tablet_client = make_client(bob_tablet)?;
let key_package = bob_tablet_client.generate_key_package_message()?;
let key_package =
bob_tablet_client.generate_key_package_message(Default::default(), Default::default())?;

let res = alice_tablet_group
.commit_builder()
Expand Down
10 changes: 6 additions & 4 deletions mls-rs/examples/large_group.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,15 +58,16 @@ fn make_groups_best_case<P: CryptoProvider + Clone>(
) -> Result<Vec<Group<impl MlsConfig>>, MlsError> {
let bob_client = make_client(crypto_provider.clone(), &make_name(0))?;

let bob_group = bob_client.create_group(Default::default())?;
let bob_group = bob_client.create_group(Default::default(), Default::default())?;

let mut groups = vec![bob_group];

for i in 0..(num_groups - 1) {
let bob_client = make_client(crypto_provider.clone(), &make_name(i + 1))?;

// The new client generates a key package.
let bob_kpkg = bob_client.generate_key_package_message()?;
let bob_kpkg =
bob_client.generate_key_package_message(Default::default(), Default::default())?;

// Last group sends a commit adding the new client to the group.
let commit = groups
Expand Down Expand Up @@ -100,7 +101,7 @@ fn make_groups_worst_case<P: CryptoProvider + Clone>(
) -> Result<Vec<Group<impl MlsConfig>>, MlsError> {
let alice_client = make_client(crypto_provider.clone(), &make_name(0))?;

let mut alice_group = alice_client.create_group(Default::default())?;
let mut alice_group = alice_client.create_group(Default::default(), Default::default())?;

let bob_clients = (0..(num_groups - 1))
.map(|i| make_client(crypto_provider.clone(), &make_name(i + 1)))
Expand All @@ -110,7 +111,8 @@ fn make_groups_worst_case<P: CryptoProvider + Clone>(
let mut commit_builder = alice_group.commit_builder();

for bob_client in &bob_clients {
let bob_kpkg = bob_client.generate_key_package_message()?;
let bob_kpkg =
bob_client.generate_key_package_message(Default::default(), Default::default())?;
commit_builder = commit_builder.add_member(bob_kpkg)?;
}

Expand Down
4 changes: 3 additions & 1 deletion mls-rs/examples/x509.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@ fn main() {
.signing_identity(signing_identity, secret_key, CIPHERSUITE)
.build();

let mut alice_group = alice_client.create_group(Default::default()).unwrap();
let mut alice_group = alice_client
.create_group(Default::default(), Default::default())
.unwrap();

alice_group.commit(Vec::new()).unwrap();
alice_group.apply_pending_commit().unwrap();
Expand Down
66 changes: 51 additions & 15 deletions mls-rs/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -429,12 +429,23 @@ where
///
/// A key package message may only be used once.
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
pub async fn generate_key_package_message(&self) -> Result<MlsMessage, MlsError> {
Ok(self.generate_key_package().await?.key_package_message())
pub async fn generate_key_package_message(
&self,
key_package_extensions: ExtensionList,
leaf_node_extensions: ExtensionList,
) -> Result<MlsMessage, MlsError> {
Ok(self
.generate_key_package(key_package_extensions, leaf_node_extensions)
.await?
.key_package_message())
}

#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
async fn generate_key_package(&self) -> Result<KeyPackageGeneration, MlsError> {
async fn generate_key_package(
&self,
key_package_extensions: ExtensionList,
leaf_node_extensions: ExtensionList,
) -> Result<KeyPackageGeneration, MlsError> {
let (signing_identity, cipher_suite) = self.signing_identity()?;

let cipher_suite_provider = self
Expand All @@ -454,8 +465,8 @@ where
.generate(
self.config.lifetime(),
self.config.capabilities(),
self.config.key_package_extensions(),
self.config.leaf_node_extensions(),
key_package_extensions,
leaf_node_extensions,
)
.await?;

Expand Down Expand Up @@ -486,6 +497,7 @@ where
&self,
group_id: Vec<u8>,
group_context_extensions: ExtensionList,
leaf_node_extensions: ExtensionList,
) -> Result<Group<C>, MlsError> {
let (signing_identity, cipher_suite) = self.signing_identity()?;

Expand All @@ -496,6 +508,7 @@ where
self.version,
signing_identity.clone(),
group_context_extensions,
leaf_node_extensions,
self.signer()?.clone(),
)
.await
Expand All @@ -510,6 +523,7 @@ where
pub async fn create_group(
&self,
group_context_extensions: ExtensionList,
leaf_node_extensions: ExtensionList,
) -> Result<Group<C>, MlsError> {
let (signing_identity, cipher_suite) = self.signing_identity()?;

Expand All @@ -520,6 +534,7 @@ where
self.version,
signing_identity.clone(),
group_context_extensions,
leaf_node_extensions,
self.signer()?.clone(),
)
.await
Expand Down Expand Up @@ -674,6 +689,8 @@ where
group_info: &MlsMessage,
tree_data: Option<crate::group::ExportedTree<'_>>,
authenticated_data: Vec<u8>,
key_package_extensions: ExtensionList,
leaf_node_extensions: ExtensionList,
) -> Result<MlsMessage, MlsError> {
let protocol_version = group_info.version;

Expand Down Expand Up @@ -702,7 +719,10 @@ where
)
.await?;

let key_package = self.generate_key_package().await?.key_package;
let key_package = self
.generate_key_package(key_package_extensions, leaf_node_extensions)
.await?
.key_package;

(key_package.cipher_suite == cipher_suite)
.then_some(())
Expand Down Expand Up @@ -745,11 +765,6 @@ where
.ok_or(MlsError::SignerNotFound)
}

/// Returns key package extensions used by this client
pub fn key_package_extensions(&self) -> ExtensionList {
self.config.key_package_extensions()
}

/// The [KeyPackageStorage] that this client was configured to use.
#[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen_ignore)]
pub fn key_package_store(&self) -> <C as ClientConfig>::KeyPackageRepository {
Expand Down Expand Up @@ -793,14 +808,24 @@ pub(crate) mod test_utils {
cipher_suite: CipherSuite,
identity: &str,
) -> (Client<TestClientConfig>, MlsMessage) {
test_client_with_key_pkg_custom(protocol_version, cipher_suite, identity, |_| {}).await
test_client_with_key_pkg_custom(
protocol_version,
cipher_suite,
identity,
Default::default(),
Default::default(),
|_| {},
)
.await
}

#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
pub async fn test_client_with_key_pkg_custom<F>(
protocol_version: ProtocolVersion,
cipher_suite: CipherSuite,
identity: &str,
key_package_extensions: ExtensionList,
leaf_node_extensions: ExtensionList,
mut config: F,
) -> (Client<TestClientConfig>, MlsMessage)
where
Expand All @@ -816,7 +841,10 @@ pub(crate) mod test_utils {

config(&mut client.config);

let key_package = client.generate_key_package_message().await.unwrap();
let key_package = client
.generate_key_package_message(key_package_extensions, leaf_node_extensions)
.await
.unwrap();

(client, key_package)
}
Expand Down Expand Up @@ -863,7 +891,10 @@ mod tests {
.build();

// TODO: Tests around extensions
let key_package = client.generate_key_package_message().await.unwrap();
let key_package = client
.generate_key_package_message(Default::default(), Default::default())
.await
.unwrap();

assert_eq!(key_package.version, protocol_version);

Expand Down Expand Up @@ -902,6 +933,8 @@ mod tests {
&alice_group.group_info_message(true).await.unwrap(),
None,
vec![],
Default::default(),
Default::default(),
)
.await
.unwrap();
Expand Down Expand Up @@ -1047,7 +1080,10 @@ mod tests {
.signing_identity(alice_identity.clone(), secret_key, TEST_CIPHER_SUITE)
.build();

let msg = alice.generate_key_package_message().await.unwrap();
let msg = alice
.generate_key_package_message(Default::default(), Default::default())
.await
.unwrap();
let res = alice.commit_external(msg).await.map(|_| ());

assert_matches!(res, Err(MlsError::UnexpectedMessageType));
Expand Down
Loading

0 comments on commit c710afa

Please sign in to comment.