Skip to content

Commit

Permalink
Simplify map conditional compilation (#158)
Browse files Browse the repository at this point in the history
* Simplify map conditional compilation

* Fix aws-lc-rs version

* Apply suggestions from code review

Co-authored-by: Stephane Raux <94983192+stefunctional@users.noreply.github.com>

* Run fmt

---------

Co-authored-by: Marta Mularczyk <mulmarta@amazon.com>
Co-authored-by: Stephane Raux <94983192+stefunctional@users.noreply.github.com>
  • Loading branch information
3 people authored Jun 12, 2024
1 parent 310df89 commit f4af668
Show file tree
Hide file tree
Showing 11 changed files with 177 additions and 162 deletions.
4 changes: 2 additions & 2 deletions mls-rs-crypto-awslc/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ license = "Apache-2.0 OR MIT"

[dependencies]
aws-lc-rs = "1.7.0"
aws-lc-sys = { version = "0.16.0" }
mls-rs-core = { path = "../mls-rs-core", version = "0.18.0" }
aws-lc-sys = { version = "0.17.0" }
mls-rs-core = { path = "../mls-rs-core", version = "=0.18.0" }
mls-rs-crypto-hpke = { path = "../mls-rs-crypto-hpke", version = "0.9.0" }
mls-rs-crypto-traits = { path = "../mls-rs-crypto-traits", version = "0.10.0" }
mls-rs-identity-x509 = { path = "../mls-rs-identity-x509", version = "0.11.0" }
Expand Down
10 changes: 3 additions & 7 deletions mls-rs/src/group/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,6 @@ use crate::psk::{
ResumptionPSKUsage, ResumptionPsk,
};

#[cfg(all(feature = "std", feature = "by_ref_proposal"))]
use std::collections::HashMap;

#[cfg(feature = "private_message")]
use ciphertext_processor::*;

Expand Down Expand Up @@ -265,10 +262,9 @@ where
epoch_secrets: EpochSecrets,
private_tree: TreeKemPrivate,
key_schedule: KeySchedule,
#[cfg(all(feature = "std", feature = "by_ref_proposal"))]
pending_updates: HashMap<HpkePublicKey, (HpkeSecretKey, Option<SignatureSecretKey>)>, // Hash of leaf node hpke public key to secret key
#[cfg(all(not(feature = "std"), feature = "by_ref_proposal"))]
pending_updates: Vec<(HpkePublicKey, (HpkeSecretKey, Option<SignatureSecretKey>))>,
#[cfg(feature = "by_ref_proposal")]
pending_updates:
crate::map::SmallMap<HpkePublicKey, (HpkeSecretKey, Option<SignatureSecretKey>)>, // Hash of leaf node hpke public key to secret key
pending_commit: Option<CommitGeneration>,
#[cfg(feature = "psk")]
previous_psk: Option<PskSecretInput>,
Expand Down
11 changes: 2 additions & 9 deletions mls-rs/src/group/proposal_cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,6 @@ use crate::group::{proposal_filter::FilterStrategy, ProposalRef, ProtocolVersion

use crate::tree_kem::leaf_node::LeafNode;

#[cfg(all(feature = "std", feature = "by_ref_proposal"))]
use std::collections::HashMap;

#[cfg(feature = "by_ref_proposal")]
use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};

Expand All @@ -50,10 +47,7 @@ pub struct CachedProposal {
pub(crate) struct ProposalCache {
protocol_version: ProtocolVersion,
group_id: Vec<u8>,
#[cfg(feature = "std")]
pub(crate) proposals: HashMap<ProposalRef, CachedProposal>,
#[cfg(not(feature = "std"))]
pub(crate) proposals: Vec<(ProposalRef, CachedProposal)>,
pub(crate) proposals: crate::map::SmallMap<ProposalRef, CachedProposal>,
}

#[cfg(feature = "by_ref_proposal")]
Expand Down Expand Up @@ -83,8 +77,7 @@ impl ProposalCache {
pub fn import(
protocol_version: ProtocolVersion,
group_id: Vec<u8>,
#[cfg(feature = "std")] proposals: HashMap<ProposalRef, CachedProposal>,
#[cfg(not(feature = "std"))] proposals: Vec<(ProposalRef, CachedProposal)>,
proposals: crate::map::SmallMap<ProposalRef, CachedProposal>,
) -> Self {
Self {
protocol_version,
Expand Down
59 changes: 6 additions & 53 deletions mls-rs/src/group/secret_tree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,11 @@ use core::{

use zeroize::Zeroizing;

use crate::{client::MlsError, tree_kem::math::TreeIndex, CipherSuiteProvider};
use crate::{client::MlsError, map::LargeMap, tree_kem::math::TreeIndex, CipherSuiteProvider};

use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
use mls_rs_core::error::IntoAnyError;

#[cfg(feature = "std")]
use std::collections::HashMap;

#[cfg(not(feature = "std"))]
use alloc::collections::BTreeMap;

use super::key_schedule::kdf_expand_with_label;

pub(crate) const MAX_RATCHET_BACK_HISTORY: u32 = 1024;
Expand Down Expand Up @@ -94,13 +88,9 @@ impl From<Zeroizing<Vec<u8>>> for TreeSecret {
#[derive(Clone, Debug, PartialEq, MlsEncode, MlsDecode, MlsSize, Default)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
struct TreeSecretsVec<T: TreeIndex> {
#[cfg(feature = "std")]
inner: HashMap<T, SecretTreeNode>,
#[cfg(not(feature = "std"))]
inner: Vec<(T, SecretTreeNode)>,
inner: LargeMap<T, SecretTreeNode>,
}

#[cfg(feature = "std")]
impl<T: TreeIndex> TreeSecretsVec<T> {
fn set_node(&mut self, index: T, value: SecretTreeNode) {
self.inner.insert(index, value);
Expand All @@ -111,30 +101,6 @@ impl<T: TreeIndex> TreeSecretsVec<T> {
}
}

#[cfg(not(feature = "std"))]
impl<T: TreeIndex> TreeSecretsVec<T> {
fn set_node(&mut self, index: T, value: SecretTreeNode) {
if let Some(i) = self.find_node(&index) {
self.inner[i] = (index, value)
} else {
self.inner.push((index, value))
}
}

fn take_node(&mut self, index: &T) -> Option<SecretTreeNode> {
self.find_node(index).map(|i| self.inner.remove(i).1)
}

fn find_node(&self, index: &T) -> Option<usize> {
use itertools::Itertools;

self.inner
.iter()
.find_position(|(i, _)| i == index)
.map(|(i, _)| i)
}
}

#[derive(Clone, Debug, PartialEq, MlsEncode, MlsDecode, MlsSize)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct SecretTree<T: TreeIndex> {
Expand Down Expand Up @@ -364,10 +330,8 @@ impl MessageKeyData {
pub struct SecretKeyRatchet {
secret: TreeSecret,
generation: u32,
#[cfg(all(feature = "out_of_order", feature = "std"))]
history: HashMap<u32, MessageKeyData>,
#[cfg(all(feature = "out_of_order", not(feature = "std")))]
history: BTreeMap<u32, MessageKeyData>,
#[cfg(feature = "out_of_order")]
history: LargeMap<u32, MessageKeyData>,
}

impl MlsSize for SecretKeyRatchet {
Expand Down Expand Up @@ -404,20 +368,9 @@ impl MlsDecode for SecretKeyRatchet {
Ok(Self {
secret: mls_rs_codec::byte_vec::mls_decode(reader)?,
generation: u32::mls_decode(reader)?,
#[cfg(all(feature = "std", feature = "out_of_order"))]
history: mls_rs_codec::iter::mls_decode_collection(reader, |data| {
let mut items = HashMap::default();

while !data.is_empty() {
let item = MessageKeyData::mls_decode(data)?;
items.insert(item.generation, item);
}

Ok(items)
})?,
#[cfg(all(not(feature = "std"), feature = "out_of_order"))]
#[cfg(feature = "out_of_order")]
history: mls_rs_codec::iter::mls_decode_collection(reader, |data| {
let mut items = alloc::collections::BTreeMap::default();
let mut items = LargeMap::default();

while !data.is_empty() {
let item = MessageKeyData::mls_decode(data)?;
Expand Down
19 changes: 5 additions & 14 deletions mls-rs/src/group/snapshot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ use crate::{
use crate::{
crypto::{HpkePublicKey, HpkeSecretKey},
group::ProposalRef,
map::SmallMap,
};

#[cfg(feature = "by_ref_proposal")]
Expand All @@ -27,12 +28,6 @@ use mls_rs_core::crypto::SignatureSecretKey;
#[cfg(feature = "tree_index")]
use mls_rs_core::identity::IdentityProvider;

#[cfg(all(feature = "std", feature = "by_ref_proposal"))]
use std::collections::HashMap;

#[cfg(all(feature = "by_ref_proposal", not(feature = "std")))]
use alloc::vec::Vec;

use super::{cipher_suite_provider, epoch::EpochSecrets, state_repo::GroupStateRepository};

#[derive(Debug, PartialEq, Clone, MlsEncode, MlsDecode, MlsSize)]
Expand All @@ -43,10 +38,8 @@ pub(crate) struct Snapshot {
private_tree: TreeKemPrivate,
epoch_secrets: EpochSecrets,
key_schedule: KeySchedule,
#[cfg(all(feature = "std", feature = "by_ref_proposal"))]
pending_updates: HashMap<HpkePublicKey, (HpkeSecretKey, Option<SignatureSecretKey>)>,
#[cfg(all(not(feature = "std"), feature = "by_ref_proposal"))]
pending_updates: Vec<(HpkePublicKey, (HpkeSecretKey, Option<SignatureSecretKey>))>,
#[cfg(feature = "by_ref_proposal")]
pending_updates: SmallMap<HpkePublicKey, (HpkeSecretKey, Option<SignatureSecretKey>)>,
pending_commit: Option<CommitGeneration>,
signer: SignatureSecretKey,
}
Expand All @@ -55,10 +48,8 @@ pub(crate) struct Snapshot {
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub(crate) struct RawGroupState {
pub(crate) context: GroupContext,
#[cfg(all(feature = "std", feature = "by_ref_proposal"))]
pub(crate) proposals: HashMap<ProposalRef, CachedProposal>,
#[cfg(all(not(feature = "std"), feature = "by_ref_proposal"))]
pub(crate) proposals: Vec<(ProposalRef, CachedProposal)>,
#[cfg(feature = "by_ref_proposal")]
pub(crate) proposals: SmallMap<ProposalRef, CachedProposal>,
pub(crate) public_tree: TreeKemPublic,
pub(crate) interim_transcript_hash: InterimTranscriptHash,
pub(crate) pending_reinit: Option<ReInitProposal>,
Expand Down
1 change: 1 addition & 0 deletions mls-rs/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ mod hash_reference;
pub mod identity;
mod iter;
mod key_package;
pub(crate) mod map;
/// Pre-shared key support.
pub mod psk;
mod signer;
Expand Down
118 changes: 118 additions & 0 deletions mls-rs/src/map.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// Copyright by contributors to this project.
// SPDX-License-Identifier: (Apache-2.0 OR MIT)

use alloc::vec::Vec;
use core::{
hash::Hash,
ops::{Deref, DerefMut},
};

use map_impl::SmallMapInner;
pub use map_impl::{LargeMap, LargeMapEntry, SmallMap};
use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};

#[cfg(feature = "std")]
mod map_impl {
use core::hash::Hash;
use std::collections::{hash_map::Entry, HashMap};

#[derive(Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct SmallMap<K: Hash + Eq, V>(pub(super) HashMap<K, V>);

pub type LargeMap<K, V> = SmallMap<K, V>;
pub(super) type SmallMapInner<K, V> = HashMap<K, V>;
pub type LargeMapEntry<'a, K, V> = Entry<'a, K, V>;
}

#[cfg(not(feature = "std"))]
mod map_impl {
use core::hash::Hash;

use alloc::{
collections::{btree_map::Entry, BTreeMap},
vec::Vec,
};
#[cfg(feature = "by_ref_proposal")]
use itertools::Itertools;

#[derive(Clone, Debug, PartialEq, Eq)]
pub struct SmallMap<K: Hash + Eq, V>(pub(super) Vec<(K, V)>);

pub type LargeMap<K, V> = BTreeMap<K, V>;
pub(super) type SmallMapInner<K, V> = Vec<(K, V)>;
pub type LargeMapEntry<'a, K, V> = Entry<'a, K, V>;

#[cfg(feature = "by_ref_proposal")]
impl<K: Hash + Eq, V> SmallMap<K, V> {
pub fn get(&self, key: &K) -> Option<&V> {
self.find(key).map(|i| &self.0[i].1)
}

pub fn insert(&mut self, key: K, value: V) {
match self.0.iter_mut().find(|(k, _)| (k == &key)) {
Some((_, v)) => *v = value,
None => self.0.push((key, value)),
}
}

pub fn remove(&mut self, key: &K) -> Option<V> {
self.find(key).map(|i| self.0.remove(i).1)
}

fn find(&self, key: &K) -> Option<usize> {
self.0.iter().position(|(k, _)| k == key)
}
}
}

impl<K: Hash + Eq, V> Default for SmallMap<K, V> {
fn default() -> Self {
Self(SmallMapInner::new())
}
}

impl<K: Hash + Eq, V> Deref for SmallMap<K, V> {
type Target = SmallMapInner<K, V>;

fn deref(&self) -> &Self::Target {
&self.0
}
}

impl<K: Hash + Eq, V> DerefMut for SmallMap<K, V> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}

impl<K, V> MlsDecode for SmallMap<K, V>
where
K: Hash + Eq + MlsEncode + MlsDecode + MlsSize,
V: MlsEncode + MlsDecode + MlsSize,
{
fn mls_decode(reader: &mut &[u8]) -> Result<Self, mls_rs_codec::Error> {
SmallMapInner::mls_decode(reader).map(Self)
}
}

impl<K, V> MlsSize for SmallMap<K, V>
where
K: Hash + Eq + MlsEncode + MlsDecode + MlsSize,
V: MlsEncode + MlsDecode + MlsSize,
{
fn mls_encoded_len(&self) -> usize {
self.0.mls_encoded_len()
}
}

impl<K, V> MlsEncode for SmallMap<K, V>
where
K: Hash + Eq + MlsEncode + MlsDecode + MlsSize,
V: MlsEncode + MlsDecode + MlsSize,
{
fn mls_encode(&self, writer: &mut Vec<u8>) -> Result<(), mls_rs_codec::Error> {
self.0.mls_encode(writer)
}
}
Loading

0 comments on commit f4af668

Please sign in to comment.