diff --git a/mls-rs-crypto-awslc/Cargo.toml b/mls-rs-crypto-awslc/Cargo.toml index 92ac1b4c..669b4b43 100644 --- a/mls-rs-crypto-awslc/Cargo.toml +++ b/mls-rs-crypto-awslc/Cargo.toml @@ -10,8 +10,8 @@ license = "Apache-2.0 OR MIT" [dependencies] aws-lc-rs = "1.7.0" -aws-lc-sys = { version = "0.16.0" } -mls-rs-core = { path = "../mls-rs-core", version = "0.18.0" } +aws-lc-sys = { version = "0.17.0" } +mls-rs-core = { path = "../mls-rs-core", version = "=0.18.0" } mls-rs-crypto-hpke = { path = "../mls-rs-crypto-hpke", version = "0.9.0" } mls-rs-crypto-traits = { path = "../mls-rs-crypto-traits", version = "0.10.0" } mls-rs-identity-x509 = { path = "../mls-rs-identity-x509", version = "0.11.0" } diff --git a/mls-rs/src/group/mod.rs b/mls-rs/src/group/mod.rs index 08490338..59cfc17b 100644 --- a/mls-rs/src/group/mod.rs +++ b/mls-rs/src/group/mod.rs @@ -51,9 +51,6 @@ use crate::psk::{ ResumptionPSKUsage, ResumptionPsk, }; -#[cfg(all(feature = "std", feature = "by_ref_proposal"))] -use std::collections::HashMap; - #[cfg(feature = "private_message")] use ciphertext_processor::*; @@ -265,10 +262,9 @@ where epoch_secrets: EpochSecrets, private_tree: TreeKemPrivate, key_schedule: KeySchedule, - #[cfg(all(feature = "std", feature = "by_ref_proposal"))] - pending_updates: HashMap)>, // Hash of leaf node hpke public key to secret key - #[cfg(all(not(feature = "std"), feature = "by_ref_proposal"))] - pending_updates: Vec<(HpkePublicKey, (HpkeSecretKey, Option))>, + #[cfg(feature = "by_ref_proposal")] + pending_updates: + crate::map::SmallMap)>, // Hash of leaf node hpke public key to secret key pending_commit: Option, #[cfg(feature = "psk")] previous_psk: Option, diff --git a/mls-rs/src/group/proposal_cache.rs b/mls-rs/src/group/proposal_cache.rs index 5a4fa83f..c0512047 100644 --- a/mls-rs/src/group/proposal_cache.rs +++ b/mls-rs/src/group/proposal_cache.rs @@ -23,9 +23,6 @@ use crate::group::{proposal_filter::FilterStrategy, ProposalRef, ProtocolVersion use crate::tree_kem::leaf_node::LeafNode; -#[cfg(all(feature = "std", feature = "by_ref_proposal"))] -use std::collections::HashMap; - #[cfg(feature = "by_ref_proposal")] use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize}; @@ -50,10 +47,7 @@ pub struct CachedProposal { pub(crate) struct ProposalCache { protocol_version: ProtocolVersion, group_id: Vec, - #[cfg(feature = "std")] - pub(crate) proposals: HashMap, - #[cfg(not(feature = "std"))] - pub(crate) proposals: Vec<(ProposalRef, CachedProposal)>, + pub(crate) proposals: crate::map::SmallMap, } #[cfg(feature = "by_ref_proposal")] @@ -83,8 +77,7 @@ impl ProposalCache { pub fn import( protocol_version: ProtocolVersion, group_id: Vec, - #[cfg(feature = "std")] proposals: HashMap, - #[cfg(not(feature = "std"))] proposals: Vec<(ProposalRef, CachedProposal)>, + proposals: crate::map::SmallMap, ) -> Self { Self { protocol_version, diff --git a/mls-rs/src/group/secret_tree.rs b/mls-rs/src/group/secret_tree.rs index df0c30fa..1422f4cb 100644 --- a/mls-rs/src/group/secret_tree.rs +++ b/mls-rs/src/group/secret_tree.rs @@ -10,17 +10,11 @@ use core::{ use zeroize::Zeroizing; -use crate::{client::MlsError, tree_kem::math::TreeIndex, CipherSuiteProvider}; +use crate::{client::MlsError, map::LargeMap, tree_kem::math::TreeIndex, CipherSuiteProvider}; use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize}; use mls_rs_core::error::IntoAnyError; -#[cfg(feature = "std")] -use std::collections::HashMap; - -#[cfg(not(feature = "std"))] -use alloc::collections::BTreeMap; - use super::key_schedule::kdf_expand_with_label; pub(crate) const MAX_RATCHET_BACK_HISTORY: u32 = 1024; @@ -94,13 +88,9 @@ impl From>> for TreeSecret { #[derive(Clone, Debug, PartialEq, MlsEncode, MlsDecode, MlsSize, Default)] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] struct TreeSecretsVec { - #[cfg(feature = "std")] - inner: HashMap, - #[cfg(not(feature = "std"))] - inner: Vec<(T, SecretTreeNode)>, + inner: LargeMap, } -#[cfg(feature = "std")] impl TreeSecretsVec { fn set_node(&mut self, index: T, value: SecretTreeNode) { self.inner.insert(index, value); @@ -111,30 +101,6 @@ impl TreeSecretsVec { } } -#[cfg(not(feature = "std"))] -impl TreeSecretsVec { - fn set_node(&mut self, index: T, value: SecretTreeNode) { - if let Some(i) = self.find_node(&index) { - self.inner[i] = (index, value) - } else { - self.inner.push((index, value)) - } - } - - fn take_node(&mut self, index: &T) -> Option { - self.find_node(index).map(|i| self.inner.remove(i).1) - } - - fn find_node(&self, index: &T) -> Option { - use itertools::Itertools; - - self.inner - .iter() - .find_position(|(i, _)| i == index) - .map(|(i, _)| i) - } -} - #[derive(Clone, Debug, PartialEq, MlsEncode, MlsDecode, MlsSize)] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] pub struct SecretTree { @@ -364,10 +330,8 @@ impl MessageKeyData { pub struct SecretKeyRatchet { secret: TreeSecret, generation: u32, - #[cfg(all(feature = "out_of_order", feature = "std"))] - history: HashMap, - #[cfg(all(feature = "out_of_order", not(feature = "std")))] - history: BTreeMap, + #[cfg(feature = "out_of_order")] + history: LargeMap, } impl MlsSize for SecretKeyRatchet { @@ -404,20 +368,9 @@ impl MlsDecode for SecretKeyRatchet { Ok(Self { secret: mls_rs_codec::byte_vec::mls_decode(reader)?, generation: u32::mls_decode(reader)?, - #[cfg(all(feature = "std", feature = "out_of_order"))] - history: mls_rs_codec::iter::mls_decode_collection(reader, |data| { - let mut items = HashMap::default(); - - while !data.is_empty() { - let item = MessageKeyData::mls_decode(data)?; - items.insert(item.generation, item); - } - - Ok(items) - })?, - #[cfg(all(not(feature = "std"), feature = "out_of_order"))] + #[cfg(feature = "out_of_order")] history: mls_rs_codec::iter::mls_decode_collection(reader, |data| { - let mut items = alloc::collections::BTreeMap::default(); + let mut items = LargeMap::default(); while !data.is_empty() { let item = MessageKeyData::mls_decode(data)?; diff --git a/mls-rs/src/group/snapshot.rs b/mls-rs/src/group/snapshot.rs index dca64f8f..b070d71b 100644 --- a/mls-rs/src/group/snapshot.rs +++ b/mls-rs/src/group/snapshot.rs @@ -16,6 +16,7 @@ use crate::{ use crate::{ crypto::{HpkePublicKey, HpkeSecretKey}, group::ProposalRef, + map::SmallMap, }; #[cfg(feature = "by_ref_proposal")] @@ -27,12 +28,6 @@ use mls_rs_core::crypto::SignatureSecretKey; #[cfg(feature = "tree_index")] use mls_rs_core::identity::IdentityProvider; -#[cfg(all(feature = "std", feature = "by_ref_proposal"))] -use std::collections::HashMap; - -#[cfg(all(feature = "by_ref_proposal", not(feature = "std")))] -use alloc::vec::Vec; - use super::{cipher_suite_provider, epoch::EpochSecrets, state_repo::GroupStateRepository}; #[derive(Debug, PartialEq, Clone, MlsEncode, MlsDecode, MlsSize)] @@ -43,10 +38,8 @@ pub(crate) struct Snapshot { private_tree: TreeKemPrivate, epoch_secrets: EpochSecrets, key_schedule: KeySchedule, - #[cfg(all(feature = "std", feature = "by_ref_proposal"))] - pending_updates: HashMap)>, - #[cfg(all(not(feature = "std"), feature = "by_ref_proposal"))] - pending_updates: Vec<(HpkePublicKey, (HpkeSecretKey, Option))>, + #[cfg(feature = "by_ref_proposal")] + pending_updates: SmallMap)>, pending_commit: Option, signer: SignatureSecretKey, } @@ -55,10 +48,8 @@ pub(crate) struct Snapshot { #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] pub(crate) struct RawGroupState { pub(crate) context: GroupContext, - #[cfg(all(feature = "std", feature = "by_ref_proposal"))] - pub(crate) proposals: HashMap, - #[cfg(all(not(feature = "std"), feature = "by_ref_proposal"))] - pub(crate) proposals: Vec<(ProposalRef, CachedProposal)>, + #[cfg(feature = "by_ref_proposal")] + pub(crate) proposals: SmallMap, pub(crate) public_tree: TreeKemPublic, pub(crate) interim_transcript_hash: InterimTranscriptHash, pub(crate) pending_reinit: Option, diff --git a/mls-rs/src/lib.rs b/mls-rs/src/lib.rs index 115b3f87..080d4c51 100644 --- a/mls-rs/src/lib.rs +++ b/mls-rs/src/lib.rs @@ -148,6 +148,7 @@ mod hash_reference; pub mod identity; mod iter; mod key_package; +pub(crate) mod map; /// Pre-shared key support. pub mod psk; mod signer; diff --git a/mls-rs/src/map.rs b/mls-rs/src/map.rs new file mode 100644 index 00000000..067072a5 --- /dev/null +++ b/mls-rs/src/map.rs @@ -0,0 +1,118 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// Copyright by contributors to this project. +// SPDX-License-Identifier: (Apache-2.0 OR MIT) + +use alloc::vec::Vec; +use core::{ + hash::Hash, + ops::{Deref, DerefMut}, +}; + +use map_impl::SmallMapInner; +pub use map_impl::{LargeMap, LargeMapEntry, SmallMap}; +use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize}; + +#[cfg(feature = "std")] +mod map_impl { + use core::hash::Hash; + use std::collections::{hash_map::Entry, HashMap}; + + #[derive(Clone, Debug, PartialEq, Eq)] + #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] + pub struct SmallMap(pub(super) HashMap); + + pub type LargeMap = SmallMap; + pub(super) type SmallMapInner = HashMap; + pub type LargeMapEntry<'a, K, V> = Entry<'a, K, V>; +} + +#[cfg(not(feature = "std"))] +mod map_impl { + use core::hash::Hash; + + use alloc::{ + collections::{btree_map::Entry, BTreeMap}, + vec::Vec, + }; + #[cfg(feature = "by_ref_proposal")] + use itertools::Itertools; + + #[derive(Clone, Debug, PartialEq, Eq)] + pub struct SmallMap(pub(super) Vec<(K, V)>); + + pub type LargeMap = BTreeMap; + pub(super) type SmallMapInner = Vec<(K, V)>; + pub type LargeMapEntry<'a, K, V> = Entry<'a, K, V>; + + #[cfg(feature = "by_ref_proposal")] + impl SmallMap { + pub fn get(&self, key: &K) -> Option<&V> { + self.find(key).map(|i| &self.0[i].1) + } + + pub fn insert(&mut self, key: K, value: V) { + match self.0.iter_mut().find(|(k, _)| (k == &key)) { + Some((_, v)) => *v = value, + None => self.0.push((key, value)), + } + } + + pub fn remove(&mut self, key: &K) -> Option { + self.find(key).map(|i| self.0.remove(i).1) + } + + fn find(&self, key: &K) -> Option { + self.0.iter().position(|(k, _)| k == key) + } + } +} + +impl Default for SmallMap { + fn default() -> Self { + Self(SmallMapInner::new()) + } +} + +impl Deref for SmallMap { + type Target = SmallMapInner; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl DerefMut for SmallMap { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +impl MlsDecode for SmallMap +where + K: Hash + Eq + MlsEncode + MlsDecode + MlsSize, + V: MlsEncode + MlsDecode + MlsSize, +{ + fn mls_decode(reader: &mut &[u8]) -> Result { + SmallMapInner::mls_decode(reader).map(Self) + } +} + +impl MlsSize for SmallMap +where + K: Hash + Eq + MlsEncode + MlsDecode + MlsSize, + V: MlsEncode + MlsDecode + MlsSize, +{ + fn mls_encoded_len(&self) -> usize { + self.0.mls_encoded_len() + } +} + +impl MlsEncode for SmallMap +where + K: Hash + Eq + MlsEncode + MlsDecode + MlsSize, + V: MlsEncode + MlsDecode + MlsSize, +{ + fn mls_encode(&self, writer: &mut Vec) -> Result<(), mls_rs_codec::Error> { + self.0.mls_encode(writer) + } +} diff --git a/mls-rs/src/storage_provider/in_memory/group_state_storage.rs b/mls-rs/src/storage_provider/in_memory/group_state_storage.rs index 5999ed03..3b3e3e36 100644 --- a/mls-rs/src/storage_provider/in_memory/group_state_storage.rs +++ b/mls-rs/src/storage_provider/in_memory/group_state_storage.rs @@ -18,19 +18,16 @@ use mls_rs_core::group::{EpochRecord, GroupState, GroupStateStorage}; #[cfg(not(target_has_atomic = "ptr"))] use portable_atomic_util::Arc; -use crate::client::MlsError; - -#[cfg(feature = "std")] -use std::collections::{hash_map::Entry, HashMap}; - -#[cfg(not(feature = "std"))] -use alloc::collections::{btree_map::Entry, BTreeMap}; +use crate::{ + client::MlsError, + map::{LargeMap, LargeMapEntry}, +}; #[cfg(feature = "std")] -use std::sync::Mutex; +use std::sync::{Mutex, MutexGuard}; #[cfg(not(feature = "std"))] -use spin::Mutex; +use spin::{Mutex, MutexGuard}; pub(crate) const DEFAULT_EPOCH_RETENTION_LIMIT: usize = 3; @@ -101,10 +98,7 @@ impl InMemoryGroupData { /// /// All clones of an instance of this type share the same underlying HashMap. pub struct InMemoryGroupStateStorage { - #[cfg(feature = "std")] - pub(crate) inner: Arc, InMemoryGroupData>>>, - #[cfg(not(feature = "std"))] - pub(crate) inner: Arc, InMemoryGroupData>>>, + pub(crate) inner: Arc, InMemoryGroupData>>>, pub(crate) max_epoch_retention: usize, } @@ -158,14 +152,12 @@ impl InMemoryGroupStateStorage { self.lock().remove(group_id); } - #[cfg(feature = "std")] - fn lock(&self) -> std::sync::MutexGuard<'_, HashMap, InMemoryGroupData>> { - self.inner.lock().unwrap() - } + fn lock(&self) -> MutexGuard<'_, LargeMap, InMemoryGroupData>> { + #[cfg(feature = "std")] + return self.inner.lock().unwrap(); - #[cfg(not(feature = "std"))] - fn lock(&self) -> spin::mutex::MutexGuard<'_, BTreeMap, InMemoryGroupData>> { - self.inner.lock() + #[cfg(not(feature = "std"))] + return self.inner.lock(); } } @@ -210,12 +202,12 @@ impl GroupStateStorage for InMemoryGroupStateStorage { let mut group_map = self.lock(); let group_data = match group_map.entry(state.id) { - Entry::Occupied(entry) => { + LargeMapEntry::Occupied(entry) => { let data = entry.into_mut(); data.state_data = state.data; data } - Entry::Vacant(entry) => entry.insert(InMemoryGroupData::new(state.data)), + LargeMapEntry::Vacant(entry) => entry.insert(InMemoryGroupData::new(state.data)), }; epoch_inserts diff --git a/mls-rs/src/storage_provider/in_memory/key_package_storage.rs b/mls-rs/src/storage_provider/in_memory/key_package_storage.rs index 427a8a49..12ea8a89 100644 --- a/mls-rs/src/storage_provider/in_memory/key_package_storage.rs +++ b/mls-rs/src/storage_provider/in_memory/key_package_storage.rs @@ -13,31 +13,25 @@ use core::{ fmt::{self, Debug}, }; -#[cfg(feature = "std")] -use std::collections::HashMap; - -#[cfg(not(feature = "std"))] -use alloc::collections::BTreeMap; use alloc::vec::Vec; use mls_rs_core::key_package::{KeyPackageData, KeyPackageStorage}; #[cfg(feature = "std")] -use std::sync::Mutex; +use std::sync::{Mutex, MutexGuard}; #[cfg(mls_build_async)] use alloc::boxed::Box; #[cfg(not(feature = "std"))] -use spin::Mutex; +use spin::{Mutex, MutexGuard}; + +use crate::map::LargeMap; #[derive(Clone, Default)] /// In memory key package storage backed by a HashMap. /// /// All clones of an instance of this type share the same underlying HashMap. pub struct InMemoryKeyPackageStorage { - #[cfg(feature = "std")] - inner: Arc, KeyPackageData>>>, - #[cfg(not(feature = "std"))] - inner: Arc, KeyPackageData>>>, + inner: Arc, KeyPackageData>>>, } impl Debug for InMemoryKeyPackageStorage { @@ -88,14 +82,12 @@ impl InMemoryKeyPackageStorage { .collect() } - #[cfg(feature = "std")] - fn lock(&self) -> std::sync::MutexGuard<'_, HashMap, KeyPackageData>> { - self.inner.lock().unwrap() - } + fn lock(&self) -> MutexGuard<'_, LargeMap, KeyPackageData>> { + #[cfg(feature = "std")] + return self.inner.lock().unwrap(); - #[cfg(not(feature = "std"))] - fn lock(&self) -> spin::mutex::MutexGuard<'_, BTreeMap, KeyPackageData>> { - self.inner.lock() + #[cfg(not(feature = "std"))] + return self.inner.lock(); } } diff --git a/mls-rs/src/storage_provider/in_memory/psk_storage.rs b/mls-rs/src/storage_provider/in_memory/psk_storage.rs index e1b0b757..4ebad299 100644 --- a/mls-rs/src/storage_provider/in_memory/psk_storage.rs +++ b/mls-rs/src/storage_provider/in_memory/psk_storage.rs @@ -10,12 +10,6 @@ use portable_atomic_util::Arc; use core::convert::Infallible; -#[cfg(feature = "std")] -use std::collections::HashMap; - -#[cfg(not(feature = "std"))] -use alloc::collections::BTreeMap; - use mls_rs_core::psk::{ExternalPskId, PreSharedKey, PreSharedKeyStorage}; #[cfg(mls_build_async)] @@ -26,15 +20,14 @@ use std::sync::Mutex; #[cfg(not(feature = "std"))] use spin::Mutex; +use crate::map::LargeMap; + #[derive(Clone, Debug, Default)] /// In memory pre-shared key storage backed by a HashMap. /// /// All clones of an instance of this type share the same underlying HashMap. pub struct InMemoryPreSharedKeyStorage { - #[cfg(feature = "std")] - inner: Arc>>, - #[cfg(not(feature = "std"))] - inner: Arc>>, + inner: Arc>>, } impl InMemoryPreSharedKeyStorage { diff --git a/mls-rs/src/tree_kem/tree_index.rs b/mls-rs/src/tree_kem/tree_index.rs index 4e6731ad..013c941e 100644 --- a/mls-rs/src/tree_kem/tree_index.rs +++ b/mls-rs/src/tree_kem/tree_index.rs @@ -10,7 +10,10 @@ use core::fmt::{self, Debug}; use crate::group::proposal::ProposalType; #[cfg(feature = "tree_index")] -use crate::identity::CredentialType; +use crate::{ + identity::CredentialType, + map::{LargeMap, LargeMapEntry}, +}; #[cfg(feature = "tree_index")] use mls_rs_core::crypto::SignaturePublicKey; @@ -18,12 +21,6 @@ use mls_rs_core::crypto::SignaturePublicKey; #[cfg(all(feature = "tree_index", feature = "std"))] use itertools::Itertools; -#[cfg(all(feature = "tree_index", not(feature = "std")))] -use alloc::collections::{btree_map::Entry, BTreeMap}; - -#[cfg(all(feature = "tree_index", feature = "std"))] -use std::collections::{hash_map::Entry, HashMap}; - #[cfg(all(feature = "tree_index", not(feature = "std")))] use alloc::collections::BTreeSet; @@ -43,26 +40,15 @@ impl Debug for Identifier { } } -#[cfg(all(feature = "tree_index", feature = "std"))] -#[derive(Clone, Debug, Default, PartialEq, MlsSize, MlsEncode, MlsDecode)] -pub struct TreeIndex { - credential_signature_key: HashMap, - hpke_key: HashMap, - identities: HashMap, - credential_type_counters: HashMap, - #[cfg(feature = "custom_proposal")] - proposal_type_counter: HashMap, -} - -#[cfg(all(feature = "tree_index", not(feature = "std")))] +#[cfg(feature = "tree_index")] #[derive(Clone, Debug, Default, PartialEq, MlsSize, MlsEncode, MlsDecode)] pub struct TreeIndex { - credential_signature_key: BTreeMap, - hpke_key: BTreeMap, - identities: BTreeMap, - credential_type_counters: BTreeMap, + credential_signature_key: LargeMap, + hpke_key: LargeMap, + identities: LargeMap, + credential_type_counters: LargeMap, #[cfg(feature = "custom_proposal")] - proposal_type_counter: BTreeMap, + proposal_type_counter: LargeMap, } #[cfg(feature = "tree_index")] @@ -156,18 +142,18 @@ impl TreeIndex { let pub_key = leaf_node.signing_identity.signature_key.clone(); let credential_entry = self.credential_signature_key.entry(pub_key); - if let Entry::Occupied(entry) = credential_entry { + if let LargeMapEntry::Occupied(entry) = credential_entry { return Err(MlsError::DuplicateLeafData(**entry.get())); } let hpke_entry = self.hpke_key.entry(leaf_node.public_key.clone()); - if let Entry::Occupied(entry) = hpke_entry { + if let LargeMapEntry::Occupied(entry) = hpke_entry { return Err(MlsError::DuplicateLeafData(**entry.get())); } let identity_entry = self.identities.entry(Identifier(identity)); - if let Entry::Occupied(entry) = identity_entry { + if let LargeMapEntry::Occupied(entry) = identity_entry { return Err(MlsError::DuplicateLeafData(**entry.get())); }