diff --git a/tuta-sdk/rust/sdk/src/entity_client.rs b/tuta-sdk/rust/sdk/src/entity_client.rs index e6a9dd09dd26..51e7969026a2 100644 --- a/tuta-sdk/rust/sdk/src/entity_client.rs +++ b/tuta-sdk/rust/sdk/src/entity_client.rs @@ -6,7 +6,7 @@ use crate::generated_id::GeneratedId; use crate::{ApiCallError, AuthHeadersProvider, IdTuple, ListLoadDirection, SdkState, TypeRef}; use crate::json_serializer::JsonSerializer; use crate::json_element::RawEntity; -use crate::metamodel::TypeModel; +use crate::metamodel::{ElementType, TypeModel}; use crate::rest_client::{HttpMethod, RestClient, RestClientOptions}; use crate::rest_error::HttpError; use crate::type_model_provider::TypeModelProvider; @@ -48,28 +48,8 @@ impl EntityClient { type_ref: &TypeRef, id: &Id, ) -> Result { - let type_model = self.get_type_model(type_ref)?; let url = format!("{}/rest/{}/{}/{}", self.base_url, type_ref.app, type_ref.type_, id); - let model_version: u32 = type_model.version.parse().map_err(|_| { - let message = format!("Tried to parse invalid model_version {}", type_model.version); - ApiCallError::InternalSdkError { error_message: message } - })?; - let options = RestClientOptions { - body: None, - headers: self.sdk_state.create_auth_headers(model_version), - }; - let response = self - .rest_client - .request_binary(url, HttpMethod::GET, options) - .await?; - let precondition = response.headers.get("precondition"); - match response.status { - 200..=299 => { - // Ok - } - _ => return Err(ApiCallError::ServerResponseError { source: HttpError::from_http_response(response.status, precondition)? }) - } - let response_bytes = response.body.expect("no body"); + let response_bytes = self.prepare_and_fire(type_ref, url).await?.expect("no body"); let response_entity = serde_json::from_slice::(response_bytes.as_slice()).unwrap(); let parsed_entity = self.json_serializer.parse(type_ref, response_entity)?; Ok(parsed_entity) @@ -100,13 +80,23 @@ impl EntityClient { #[allow(clippy::unused_async)] pub async fn load_range( &self, - _type_ref: &TypeRef, - _list_id: &GeneratedId, - _start_id: &GeneratedId, - _count: usize, - _list_load_direction: ListLoadDirection, + type_ref: &TypeRef, + list_id: &GeneratedId, + start_id: &GeneratedId, + count: usize, + direction: ListLoadDirection, ) -> Result, ApiCallError> { - todo!("entity client load_range") + let type_model = self.get_type_model(type_ref)?; + assert_eq!(type_model.element_type, ElementType::ListElement); + // FIXME: validate parameters + let reverse = direction == ListLoadDirection::DESC; + // TODO: this is not the best way to build URL, we assume that everything is URL safe + // TODO: custom ids? are they fine? + let url = format!("{}/rest/{}/{}/{}?start_id={start_id}&count={count}&reverse={reverse}", self.base_url, type_ref.app, type_ref.type_, list_id); + let response_bytes = self.prepare_and_fire(type_ref, url).await?.expect("no body"); + let response_entities = serde_json::from_slice::>(response_bytes.as_slice()).expect("invalid response"); + let parsed_entities = response_entities.into_iter().map(|e| self.json_serializer.parse(type_ref, e)).collect::, _>>()?; + Ok(parsed_entities) } /// Stores a newly created entity/instance as a single element on the backend @@ -161,6 +151,27 @@ impl EntityClient { pub async fn erase_list_element(&self, _type_ref: &TypeRef, _id: IdTuple) -> Result<(), ApiCallError> { todo!("entity client erase_list_element") } + + async fn prepare_and_fire(&self, type_ref: &TypeRef, url: String) -> Result>, ApiCallError> { + let type_model = self.get_type_model(type_ref)?; + let model_version: u32 = type_model.version.parse().expect("invalid model_version"); + let options = RestClientOptions { + body: None, + headers: self.sdk_state.create_auth_headers(model_version), + }; + let response = self + .rest_client + .request_binary(url, HttpMethod::GET, options) + .await?; + let precondition = response.headers.get("precondition"); + match response.status { + 200..=299 => { + // Ok + } + _ => return Err(ApiCallError::ServerResponseError { source: HttpError::from_http_response(response.status, precondition)? }) + } + Ok(response.body) + } } #[cfg(test)] @@ -179,13 +190,13 @@ mockall::mock! { type_ref: &TypeRef, id: &Id, ) -> Result; - async fn load_all( + pub async fn load_all( &self, type_ref: &TypeRef, list_id: &IdTuple, start: Option, ) -> Result, ApiCallError>; - async fn load_range( + pub async fn load_range( &self, type_ref: &TypeRef, list_id: &GeneratedId, @@ -193,16 +204,16 @@ mockall::mock! { count: usize, list_load_direction: ListLoadDirection, ) -> Result, ApiCallError>; - async fn setup_element(&self, type_ref: &TypeRef, entity: RawEntity) -> Vec; - async fn setup_list_element( + pub async fn setup_element(&self, type_ref: &TypeRef, entity: RawEntity) -> Vec; + pub async fn setup_list_element( &self, type_ref: &TypeRef, list_id: &IdTuple, entity: RawEntity, ) -> Vec; - async fn update(&self, type_ref: &TypeRef, entity: ParsedEntity, model_version: u32) + pub async fn update(&self, type_ref: &TypeRef, entity: ParsedEntity, model_version: u32) -> Result<(), ApiCallError>; - async fn erase_element(&self, type_ref: &TypeRef, id: &GeneratedId) -> Result<(), ApiCallError>; - async fn erase_list_element(&self, type_ref: &TypeRef, id: IdTuple) -> Result<(), ApiCallError>; + pub async fn erase_element(&self, type_ref: &TypeRef, id: &GeneratedId) -> Result<(), ApiCallError>; + pub async fn erase_list_element(&self, type_ref: &TypeRef, id: IdTuple) -> Result<(), ApiCallError>; } } diff --git a/tuta-sdk/rust/sdk/src/key_loader_facade.rs b/tuta-sdk/rust/sdk/src/key_loader_facade.rs index bae59646c89d..893edee4829d 100644 --- a/tuta-sdk/rust/sdk/src/key_loader_facade.rs +++ b/tuta-sdk/rust/sdk/src/key_loader_facade.rs @@ -1,17 +1,17 @@ -use std::cmp::Ordering; -use std::sync::Arc; -use base64::Engine; -use futures::future::BoxFuture; use crate::crypto::key::{AsymmetricKeyPair, GenericAesKey, KeyLoadError}; use crate::crypto::key_encryption::decrypt_key_pair; use crate::entities::sys::{Group, GroupKey}; use crate::generated_id::GeneratedId; -use crate::ListLoadDirection; #[mockall_double::double] use crate::typed_entity_client::TypedEntityClient; #[mockall_double::double] use crate::user_facade::UserFacade; use crate::util::Versioned; +use crate::ListLoadDirection; +use base64::Engine; +use futures::future::BoxFuture; +use std::cmp::Ordering; +use std::sync::Arc; pub struct KeyLoaderFacade { user_facade: Arc, @@ -20,48 +20,76 @@ pub struct KeyLoaderFacade { #[cfg_attr(test, mockall::automock)] impl KeyLoaderFacade { - pub fn new( - user_facade: Arc, - entity_client: Arc, - ) -> Self { + pub fn new(user_facade: Arc, entity_client: Arc) -> Self { KeyLoaderFacade { user_facade, entity_client, } } - pub async fn load_sym_group_key(&self, group_id: &GeneratedId, version: i64, current_group_key: Option) -> Result { + /// Load the symmetric group key for the groupId with the provided requestedVersion. + /// `currentGroupKey` needs to be set if the user is not a member of the group (e.g. an admin) + pub async fn load_sym_group_key( + &self, + group_id: &GeneratedId, + version: i64, + current_group_key: Option, + ) -> Result { let group_key = match current_group_key { Some(n) => { let group_key_version = n.version; if group_key_version < version { + // we might not have the membership for this group. so the caller needs to handle it by refreshing the cache return Err(KeyLoadError { reason: format!("Provided current group key is too old (${group_key_version}) to load the requested version ${version} for group ${group_id}") }); } n } - None => self.get_current_sym_group_key(group_id).await? + None => self.get_current_sym_group_key(group_id).await?, }; if group_key.version == version { Ok(group_key.object) } else { + // FIXME: refresh if group_key.version < version let group: Group = self.entity_client.load(&group_id.to_owned()).await?; - let FormerGroupKey { symmetric_group_key, .. } = self.find_former_group_key(&group, &group_key, version).await?; + let FormerGroupKey { + symmetric_group_key, + .. + } = self + .find_former_group_key(&group, &group_key, version) + .await?; Ok(symmetric_group_key) } } - async fn find_former_group_key(&self, group: &Group, current_group_key: &VersionedAesKey, target_key_version: i64) -> Result { + async fn find_former_group_key( + &self, + group: &Group, + // FIXME why do we take it by ref if we are cloning it anyway + current_group_key: &VersionedAesKey, + target_key_version: i64, + ) -> Result { let list_id = group.formerGroupKeys.clone().unwrap().list; - let start_id = GeneratedId(base64::prelude::BASE64_URL_SAFE_NO_PAD.encode(current_group_key.version.to_string())); - let amount_of_keys_including_target = (current_group_key.version - target_key_version) as usize; - - let former_keys: Vec = self.entity_client.load_range(&list_id, &start_id, amount_of_keys_including_target, ListLoadDirection::DESC).await?; + let start_id = GeneratedId( + base64::prelude::BASE64_URL_SAFE_NO_PAD.encode(current_group_key.version.to_string()), + ); + let amount_of_keys_including_target = + (current_group_key.version - target_key_version) as usize; + + let former_keys: Vec = self + .entity_client + .load_range( + &list_id, + &start_id, + amount_of_keys_including_target, + ListLoadDirection::DESC, + ) + .await?; let VersionedAesKey { version: mut last_version, - object: mut last_group_key + object: mut last_group_key, } = current_group_key.to_owned(); let mut last_group_key_instance: Option = None; @@ -72,13 +100,21 @@ impl KeyLoaderFacade { let next_version = version + 1; match next_version.cmp(&last_version) { - Ordering::Less => return Err(KeyLoadError { reason: format!("Unexpected group key version {version}; expected {last_version}") }), + Ordering::Less => { + return Err(KeyLoadError { + reason: format!( + "Unexpected group key version {version}; expected {last_version}" + ), + }) + } Ordering::Greater => continue, Ordering::Equal => { last_version = version; - last_group_key = last_group_key.decrypt_aes_key(&former_key.ownerEncGKey).map_err(|e| { - KeyLoadError { reason: e.to_string() } - })?; + last_group_key = last_group_key + .decrypt_aes_key(&former_key.ownerEncGKey) + .map_err(|e| KeyLoadError { + reason: e.to_string(), + })?; last_group_key_instance = Some(former_key); if last_version <= target_key_version { break; @@ -91,20 +127,28 @@ impl KeyLoaderFacade { return Err(KeyLoadError { reason: format!("Could not get last version (last version is {last_version} of {retrieved_keys_count} key(s) loaded from list {list_id}") }); } - Ok(FormerGroupKey { symmetric_group_key: last_group_key, group_key_instance: last_group_key_instance.unwrap() }) + Ok(FormerGroupKey { + symmetric_group_key: last_group_key, + group_key_instance: last_group_key_instance.unwrap(), + }) } fn decode_group_key_version(&self, element_id: &GeneratedId) -> Result { - element_id.as_str().parse().map_err(|_| - KeyLoadError { - reason: format!("Failed to decode group key version: {}", element_id) - } - ) + element_id.as_str().parse().map_err(|_| KeyLoadError { + reason: format!("Failed to decode group key version: {}", element_id), + }) } - pub async fn get_current_sym_group_key(&self, group_id: &GeneratedId) -> Result { + pub async fn get_current_sym_group_key( + &self, + group_id: &GeneratedId, + ) -> Result { if *group_id == self.user_facade.get_user_group_id() { - return self.get_current_sym_user_group_key().ok_or_else(|| KeyLoadError { reason: "no current group key".to_owned() }); + return self + .get_current_sym_user_group_key() + .ok_or_else(|| KeyLoadError { + reason: "no current group key".to_owned(), + }); } if let Some(key) = self.user_facade.key_cache().get_current_group_key(group_id) { @@ -112,7 +156,10 @@ impl KeyLoaderFacade { } // The call leads to recursive calls down the chain, so BoxFuture is used to wrap the recursive async calls - fn get_key_for_version<'a>(facade: &'a KeyLoaderFacade, group_id: &'a GeneratedId) -> BoxFuture<'a, Result> { + fn get_key_for_version<'a>( + facade: &'a KeyLoaderFacade, + group_id: &'a GeneratedId, + ) -> BoxFuture<'a, Result> { Box::pin(facade.load_and_decrypt_current_sym_group_key(group_id)) } @@ -121,46 +168,84 @@ impl KeyLoaderFacade { Ok(key) } - async fn load_and_decrypt_current_sym_group_key(&self, group_id: &GeneratedId) -> Result { + /// `group_id` MUST NOT be the user group id + async fn load_and_decrypt_current_sym_group_key( + &self, + group_id: &GeneratedId, + ) -> Result { + assert_ne!(&self.user_facade.get_user_group_id(), group_id, "Must not add the user group to the regular group key cache"); let group_membership = self.user_facade.get_membership(group_id)?; - let required_user_group_key = self.load_sym_user_group_key(group_membership.symKeyVersion).await?; + let required_user_group_key = self + .load_sym_user_group_key(group_membership.symKeyVersion) + .await?; let version = group_membership.groupKeyVersion; - let object = required_user_group_key.decrypt_aes_key(&group_membership.symEncGKey).map_err(|e| { - KeyLoadError { reason: e.to_string() } - })?; + let object = required_user_group_key + .decrypt_aes_key(&group_membership.symEncGKey) + .map_err(|e| KeyLoadError { + reason: e.to_string(), + })?; Ok(VersionedAesKey { version, object }) } - async fn load_sym_user_group_key(&self, user_group_key_version: i64) -> Result { + async fn load_sym_user_group_key( + &self, + user_group_key_version: i64, + ) -> Result { + // FIXME: check for the version and refresh cache if needed self.load_sym_group_key( &self.user_facade.get_user_group_id(), user_group_key_version, - Some(self.user_facade.get_current_user_group_key().ok_or_else(|| KeyLoadError { reason: "No use group key loaded".to_string() })?), - ).await + Some( + self.user_facade + .get_current_user_group_key() + .ok_or_else(|| KeyLoadError { + reason: "No use group key loaded".to_string(), + })?, + ), + ) + .await } fn get_current_sym_user_group_key(&self) -> Option { self.user_facade.get_current_user_group_key() } - pub async fn load_key_pair(&self, key_pair_group_id: &GeneratedId, group_key_version: i64) -> Result { + pub async fn load_key_pair( + &self, + key_pair_group_id: &GeneratedId, + group_key_version: i64, + ) -> Result { let group: Group = self.entity_client.load(key_pair_group_id).await?; let group_key = self.get_current_sym_group_key(&group._id).await?; if group_key.version == group_key_version { return self.get_and_decrypt_key_pair(&group, &group_key.object); } - let FormerGroupKey { symmetric_group_key, group_key_instance: GroupKey { keyPair: key_pair, .. }, .. } = self.find_former_group_key(&group, &group_key, group_key_version).await?; + let FormerGroupKey { + symmetric_group_key, + group_key_instance: GroupKey { + keyPair: key_pair, .. + }, + .. + } = self + .find_former_group_key(&group, &group_key, group_key_version) + .await?; if let Some(key) = key_pair { decrypt_key_pair(&symmetric_group_key, &key) } else { Err(KeyLoadError { reason: format!("key pair not found for group {key_pair_group_id} and version {group_key_version}") }) } } - fn get_and_decrypt_key_pair(&self, group: &Group, group_key: &GenericAesKey) -> Result { + fn get_and_decrypt_key_pair( + &self, + group: &Group, + group_key: &GenericAesKey, + ) -> Result { match &group.currentKeys { Some(keys) => decrypt_key_pair(group_key, keys), - None => Err(KeyLoadError { reason: format!("no key pair on group {}", group._id) }) + None => Err(KeyLoadError { + reason: format!("no key pair on group {}", group._id), + }), } } } @@ -174,69 +259,86 @@ struct FormerGroupKey { #[cfg(test)] mod tests { - use std::array::from_fn; - use crate::IdTuple; - use crate::crypto::{Aes256Key, Iv, PQKeyPairs}; + use super::*; use crate::crypto::randomizer_facade::test_util::make_thread_rng_facade; + use crate::crypto::randomizer_facade::RandomizerFacade; + use crate::crypto::{Aes256Key, Iv, PQKeyPairs}; + use crate::custom_id::CustomId; use crate::entities::sys::{GroupKeysRef, GroupMembership, KeyPair}; use crate::key_cache::MockKeyCache; use crate::typed_entity_client::MockTypedEntityClient; use crate::user_facade::MockUserFacade; - use super::*; - use crate::util::test_utils::{generate_random_group, random_aes256_key}; - use mockall::{predicate}; - use crate::crypto::randomizer_facade::RandomizerFacade; - use crate::custom_id::CustomId; use crate::util::get_vec_reversed; + use crate::util::test_utils::{generate_random_group, random_aes256_key}; + use crate::IdTuple; + use mockall::predicate; + use std::array::from_fn; fn generate_group_key(version: i64) -> VersionedAesKey { - VersionedAesKey { object: random_aes256_key().into(), version } + VersionedAesKey { + object: random_aes256_key().into(), + version, + } } fn generate_group_data() -> (Group, VersionedAesKey) { - ( - generate_random_group(None, None), - generate_group_key(1) - ) + (generate_random_group(None, None), generate_group_key(1)) } - fn generate_group_with_keys(current_key_pair: &PQKeyPairs, current_group_key: &VersionedAesKey, randomizer_facade: &RandomizerFacade) -> Group { - let PQKeyPairs { ecc_keys, kyber_keys } = current_key_pair; + fn generate_group_with_keys( + current_key_pair: &PQKeyPairs, + current_group_key: &VersionedAesKey, + randomizer_facade: &RandomizerFacade, + ) -> Group { + let PQKeyPairs { + ecc_keys, + kyber_keys, + } = current_key_pair; let group_key = ¤t_group_key.object; - let sym_enc_priv_ecc_key = group_key.encrypt_data(ecc_keys.private_key.as_bytes(), Iv::generate(randomizer_facade)).unwrap(); - let sync_enc_priv_kyber_key = group_key.encrypt_data(&kyber_keys.private_key.serialize(), Iv::generate(randomizer_facade)).unwrap(); + let sym_enc_priv_ecc_key = group_key + .encrypt_data( + ecc_keys.private_key.as_bytes(), + Iv::generate(randomizer_facade), + ) + .unwrap(); + let sync_enc_priv_kyber_key = group_key + .encrypt_data( + &kyber_keys.private_key.serialize(), + Iv::generate(randomizer_facade), + ) + .unwrap(); generate_random_group( - Some( - KeyPair { - _id: Default::default(), - pubEccKey: Some(ecc_keys.public_key.as_bytes().to_vec()), - pubKyberKey: Some(kyber_keys.public_key.serialize()), - pubRsaKey: None, - symEncPrivEccKey: Some(sym_enc_priv_ecc_key), - symEncPrivKyberKey: Some(sync_enc_priv_kyber_key), - symEncPrivRsaKey: None, - } - ), - Some( - GroupKeysRef { - _id: Default::default(), - list: GeneratedId("list".to_owned()), // Refers to `former_keys` - } - ), + Some(KeyPair { + _id: Default::default(), + pubEccKey: Some(ecc_keys.public_key.as_bytes().to_vec()), + pubKyberKey: Some(kyber_keys.public_key.serialize()), + pubRsaKey: None, + symEncPrivEccKey: Some(sym_enc_priv_ecc_key), + symEncPrivKyberKey: Some(sync_enc_priv_kyber_key), + symEncPrivRsaKey: None, + }), + Some(GroupKeysRef { + _id: Default::default(), + list: GeneratedId("list".to_owned()), // Refers to `former_keys` + }), ) } const FORMER_KEYS: usize = 2; /// Returns `(former_keys, former_key_pairs_decrypted, former_keys_decrypted)` - fn generate_former_keys(current_group_key: &VersionedAesKey, randomizer_facade: &RandomizerFacade) -> ([GroupKey; FORMER_KEYS], [PQKeyPairs; FORMER_KEYS], [Aes256Key; FORMER_KEYS]) { + fn generate_former_keys( + current_group_key: &VersionedAesKey, + randomizer_facade: &RandomizerFacade, + ) -> ( + [GroupKey; FORMER_KEYS], + [PQKeyPairs; FORMER_KEYS], + [Aes256Key; FORMER_KEYS], + ) { // Using `from_fn` has the same performance as using mutable vecs but less memory usage - let former_keys_decrypted: [Aes256Key; FORMER_KEYS] = from_fn(|_| { - random_aes256_key() - }); - let former_key_pairs_decrypted: [PQKeyPairs; FORMER_KEYS] = from_fn(|_| { - PQKeyPairs::generate(&make_thread_rng_facade()) - }); + let former_keys_decrypted: [Aes256Key; FORMER_KEYS] = from_fn(|_| random_aes256_key()); + let former_key_pairs_decrypted: [PQKeyPairs; FORMER_KEYS] = + from_fn(|_| PQKeyPairs::generate(&make_thread_rng_facade())); let mut former_keys = Vec::with_capacity(FORMER_KEYS); let mut last_key = current_group_key.object.clone(); @@ -246,57 +348,81 @@ mod tests { // Get the previous key to use as the owner key let current_key: &GenericAesKey = ¤t_key.clone().into(); - let owner_enc_g_key = last_key.encrypt_key(current_key, Iv::generate(randomizer_facade)).as_slice().to_vec(); - let sym_enc_priv_ecc_key = current_key.encrypt_data( - pq_key_pair.ecc_keys.private_key.clone().as_bytes(), - Iv::generate(randomizer_facade)).unwrap(); - - former_keys.insert(0, GroupKey { - _format: 0, - _id: IdTuple { - list_id: GeneratedId("list".to_owned()), - element_id: GeneratedId(i.to_string()), + let owner_enc_g_key = last_key + .encrypt_key(current_key, Iv::generate(randomizer_facade)) + .as_slice() + .to_vec(); + let sym_enc_priv_ecc_key = current_key + .encrypt_data( + pq_key_pair.ecc_keys.private_key.clone().as_bytes(), + Iv::generate(randomizer_facade), + ) + .unwrap(); + + former_keys.insert( + 0, + GroupKey { + _format: 0, + _id: IdTuple { + list_id: GeneratedId("list".to_owned()), + element_id: GeneratedId(i.to_string()), + }, + _ownerGroup: None, + _permissions: Default::default(), + adminGroupEncGKey: None, + adminGroupKeyVersion: None, + ownerEncGKey: owner_enc_g_key, + ownerKeyVersion: 0, + pubAdminGroupEncGKey: None, + keyPair: Some(KeyPair { + _id: Default::default(), + pubEccKey: Some(pq_key_pair.ecc_keys.public_key.as_bytes().to_vec()), + pubKyberKey: Some(pq_key_pair.kyber_keys.public_key.serialize()), + pubRsaKey: None, + symEncPrivEccKey: Some(sym_enc_priv_ecc_key), + symEncPrivKyberKey: Some( + current_key + .encrypt_data( + pq_key_pair.kyber_keys.private_key.serialize().as_slice(), + Iv::generate(randomizer_facade), + ) + .unwrap(), + ), + symEncPrivRsaKey: None, + }), }, - _ownerGroup: None, - _permissions: Default::default(), - adminGroupEncGKey: None, - adminGroupKeyVersion: None, - ownerEncGKey: owner_enc_g_key, - ownerKeyVersion: 0, - pubAdminGroupEncGKey: None, - keyPair: Some(KeyPair { - _id: Default::default(), - pubEccKey: Some(pq_key_pair.ecc_keys.public_key.as_bytes().to_vec()), - pubKyberKey: Some(pq_key_pair.kyber_keys.public_key.serialize()), - pubRsaKey: None, - symEncPrivEccKey: Some(sym_enc_priv_ecc_key - ), - symEncPrivKyberKey: Some(current_key.encrypt_data( - pq_key_pair.kyber_keys.private_key.serialize().as_slice(), - Iv::generate(randomizer_facade)).unwrap() - ), - symEncPrivRsaKey: None, - }), - }); + ); last_key = current_key.clone(); } - (former_keys.try_into().unwrap_or_else(|_| panic!()), former_key_pairs_decrypted, former_keys_decrypted) + ( + former_keys.try_into().unwrap_or_else(|_| panic!()), + former_key_pairs_decrypted, + former_keys_decrypted, + ) } - fn make_mocks_with_former_keys(group: &Group, current_group_key: &VersionedAesKey, randomizer: &RandomizerFacade, former_keys: &[GroupKey; FORMER_KEYS]) -> KeyLoaderFacade { - let (user_facade_mock, mut typed_entity_client_mock) = make_mocks(group, current_group_key, randomizer); + fn make_mocks_with_former_keys( + group: &Group, + current_group_key: &VersionedAesKey, + randomizer: &RandomizerFacade, + former_keys: &[GroupKey; FORMER_KEYS], + ) -> KeyLoaderFacade { + let (user_facade_mock, mut typed_entity_client_mock) = + make_mocks(group, current_group_key, randomizer); { for i in 0..FORMER_KEYS { let group = group.clone(); let former_keys = former_keys.clone(); let returned_keys = get_vec_reversed(former_keys[i..].to_vec()); - typed_entity_client_mock.expect_load_range::() + typed_entity_client_mock + .expect_load_range::() .with( predicate::eq(group.formerGroupKeys.unwrap().list), predicate::eq(GeneratedId( - base64::prelude::BASE64_URL_SAFE_NO_PAD.encode(current_group_key.version.to_string()) + base64::prelude::BASE64_URL_SAFE_NO_PAD + .encode(current_group_key.version.to_string()), )), predicate::eq(FORMER_KEYS - i), predicate::eq(ListLoadDirection::DESC), @@ -311,56 +437,74 @@ mod tests { ) } - fn make_mocks(group: &Group, current_group_key: &VersionedAesKey, randomizer: &RandomizerFacade) -> (MockUserFacade, MockTypedEntityClient) { + fn make_mocks( + group: &Group, + current_group_key: &VersionedAesKey, + randomizer: &RandomizerFacade, + ) -> (MockUserFacade, MockTypedEntityClient) { let user_group_key = generate_group_key(0); let user_group = generate_random_group(None, None); let mut user_facade_mock = MockUserFacade::default(); { let user_group_key = user_group_key.clone(); - user_facade_mock.expect_get_current_user_group_key() + user_facade_mock + .expect_get_current_user_group_key() .returning(move || Some(user_group_key.clone())); } { let current_group_key = current_group_key.clone(); let mut key_cache_mock = MockKeyCache::default(); - key_cache_mock.expect_get_current_group_key() + key_cache_mock + .expect_get_current_group_key() .returning(move |_| Some(current_group_key.clone())); let key_cache = Arc::new(key_cache_mock); - user_facade_mock.expect_key_cache() + user_facade_mock + .expect_key_cache() .returning(move || key_cache.clone()); } { let user_group_id = user_group._id.clone(); - let sym_enc_g_key = user_group_key.object.encrypt_key( - ¤t_group_key.object, - Iv::generate(randomizer), - ); + let sym_enc_g_key = user_group_key + .object + .encrypt_key(¤t_group_key.object, Iv::generate(randomizer)); let current_group_key = current_group_key.clone(); - user_facade_mock.expect_get_membership() + user_facade_mock + .expect_get_membership() .with(predicate::eq(user_group_id.clone())) - .returning(move |_| Ok(GroupMembership { - _id: CustomId(user_group_id.clone().to_string()), - admin: false, - capability: None, - groupKeyVersion: current_group_key.clone().version, - groupType: None, - symEncGKey: sym_enc_g_key.clone(), - symKeyVersion: user_group_key.version, - group: user_group_id.clone(), - groupInfo: IdTuple { list_id: Default::default(), element_id: Default::default() }, - groupMember: IdTuple { list_id: Default::default(), element_id: Default::default() }, - })); + .returning(move |_| { + Ok(GroupMembership { + _id: CustomId(user_group_id.clone().to_string()), + admin: false, + capability: None, + groupKeyVersion: current_group_key.clone().version, + groupType: None, + symEncGKey: sym_enc_g_key.clone(), + symKeyVersion: user_group_key.version, + group: user_group_id.clone(), + groupInfo: IdTuple { + list_id: Default::default(), + element_id: Default::default(), + }, + groupMember: IdTuple { + list_id: Default::default(), + element_id: Default::default(), + }, + }) + }); } { let user_group_id = user_group._id.clone(); - user_facade_mock.expect_get_user_group_id().returning(move || user_group_id.clone()); + user_facade_mock + .expect_get_user_group_id() + .returning(move || user_group_id.clone()); } let mut typed_entity_client_mock = MockTypedEntityClient::default(); { let group = group.clone(); - typed_entity_client_mock.expect_load::() + typed_entity_client_mock + .expect_load::() .with(predicate::eq(group._id.clone())) .returning(move |_| Ok(group.clone())); } @@ -374,24 +518,35 @@ mod tests { let mut user_facade_mock = MockUserFacade::default(); { let user_group = user_group.clone(); - user_facade_mock.expect_get_user_group_id().returning(move || user_group._id.clone()); + user_facade_mock + .expect_get_user_group_id() + .returning(move || user_group._id.clone()); } { let user_group_key = user_group_key.clone(); - user_facade_mock.expect_get_current_user_group_key() + user_facade_mock + .expect_get_current_user_group_key() .returning(move || Some(user_group_key.clone())) .times(2); } let typed_entity_client_mock = MockTypedEntityClient::default(); - let key_loader_facade = KeyLoaderFacade::new(Arc::new(user_facade_mock), Arc::new(typed_entity_client_mock)); + let key_loader_facade = KeyLoaderFacade::new( + Arc::new(user_facade_mock), + Arc::new(typed_entity_client_mock), + ); - let current_user_group_key = key_loader_facade.get_current_sym_group_key(&user_group._id).await.unwrap(); + let current_user_group_key = key_loader_facade + .get_current_sym_group_key(&user_group._id) + .await + .unwrap(); assert_eq!(current_user_group_key.version, user_group.groupKeyVersion); assert_eq!(current_user_group_key.object, user_group_key.object); - let _ = key_loader_facade.get_current_sym_group_key(&user_group._id).await;// should not be cached + let _ = key_loader_facade + .get_current_sym_group_key(&user_group._id) + .await; // should not be cached } #[tokio::test] @@ -401,19 +556,28 @@ mod tests { let mut user_facade_mock = MockUserFacade::default(); { let user_group = group.clone(); - user_facade_mock.expect_get_user_group_id().returning(move || user_group._id.clone()); + user_facade_mock + .expect_get_user_group_id() + .returning(move || user_group._id.clone()); } { let user_group_key = current_group_key.clone(); - user_facade_mock.expect_get_current_user_group_key() + user_facade_mock + .expect_get_current_user_group_key() .returning(move || Some(user_group_key.clone())); } let typed_entity_client_mock = MockTypedEntityClient::default(); - let key_loader_facade = KeyLoaderFacade::new(Arc::new(user_facade_mock), Arc::new(typed_entity_client_mock)); + let key_loader_facade = KeyLoaderFacade::new( + Arc::new(user_facade_mock), + Arc::new(typed_entity_client_mock), + ); - let group_key = key_loader_facade.get_current_sym_group_key(&group._id).await.unwrap(); + let group_key = key_loader_facade + .get_current_sym_group_key(&group._id) + .await + .unwrap(); assert_eq!(group_key.version, group.groupKeyVersion); assert_eq!(group_key.object, current_group_key.object) } @@ -429,12 +593,17 @@ mod tests { let group = generate_group_with_keys(¤t_key_pair, ¤t_group_key, &randomizer); - let (former_keys, former_key_pairs_decrypted, _) = generate_former_keys(¤t_group_key, &randomizer); + let (former_keys, former_key_pairs_decrypted, _) = + generate_former_keys(¤t_group_key, &randomizer); - let key_loader_facade = make_mocks_with_former_keys(&group, ¤t_group_key, &randomizer, &former_keys); + let key_loader_facade = + make_mocks_with_former_keys(&group, ¤t_group_key, &randomizer, &former_keys); for i in 0..FORMER_KEYS { - let keypair = key_loader_facade.load_key_pair(&group._id, i as i64).await.unwrap(); + let keypair = key_loader_facade + .load_key_pair(&group._id, i as i64) + .await + .unwrap(); match keypair { AsymmetricKeyPair::RSAKeyPair(_) => panic!("key_loader_facade.load_key_pair() returned an RSAKeyPair! Expected PQKeyPairs."), AsymmetricKeyPair::RSAEccKeyPair(_) => panic!("key_loader_facade.load_key_pair() returned an RSAEccKeyPair! Expected PQKeyPairs."), @@ -450,20 +619,19 @@ mod tests { let user_group_key = generate_group_key(1); let randomizer = make_thread_rng_facade(); let current_key_pair = PQKeyPairs::generate(&randomizer); - let user_group = generate_group_with_keys( - ¤t_key_pair, - &user_group_key, - &randomizer, - ); + let user_group = generate_group_with_keys(¤t_key_pair, &user_group_key, &randomizer); let mut user_facade_mock = MockUserFacade::default(); { let user_group_id = user_group._id.clone(); - user_facade_mock.expect_get_user_group_id().returning(move || user_group_id.clone()); + user_facade_mock + .expect_get_user_group_id() + .returning(move || user_group_id.clone()); } { let user_group_key = user_group_key.clone(); - user_facade_mock.expect_get_current_user_group_key() + user_facade_mock + .expect_get_current_user_group_key() .returning(move || Some(user_group_key.clone())); } @@ -471,14 +639,19 @@ mod tests { { let user_group = user_group.clone(); let group_id = user_group._id.clone(); - typed_entity_client_mock.expect_load::() + typed_entity_client_mock + .expect_load::() .withf(move |id| *id == group_id.clone()) .returning(move |_| Ok(user_group.clone())); } - let key_loader_facade = KeyLoaderFacade::new(Arc::new(user_facade_mock), Arc::new(typed_entity_client_mock)); + let key_loader_facade = KeyLoaderFacade::new( + Arc::new(user_facade_mock), + Arc::new(typed_entity_client_mock), + ); - let loaded_current_key_pair = key_loader_facade.load_key_pair(&user_group._id, user_group.groupKeyVersion) + let loaded_current_key_pair = key_loader_facade + .load_key_pair(&user_group._id, user_group.groupKeyVersion) .await .unwrap(); @@ -498,14 +671,19 @@ mod tests { // Same as the length of former_keys_deprecated let current_group_key_version = FORMER_KEYS as i64; let current_group_key = generate_group_key(current_group_key_version); - let (former_keys, _, former_keys_decrypted) = generate_former_keys(¤t_group_key, &randomizer); + let (former_keys, _, former_keys_decrypted) = + generate_former_keys(¤t_group_key, &randomizer); let current_key_pair = PQKeyPairs::generate(&randomizer); let group = generate_group_with_keys(¤t_key_pair, ¤t_group_key, &randomizer); - let key_loader_facade = make_mocks_with_former_keys(&group, ¤t_group_key, &randomizer, &former_keys); + let key_loader_facade = + make_mocks_with_former_keys(&group, ¤t_group_key, &randomizer, &former_keys); for i in 0..FORMER_KEYS { - let keypair = key_loader_facade.load_sym_group_key(&group._id, i as i64, None).await.unwrap(); + let keypair = key_loader_facade + .load_sym_group_key(&group._id, i as i64, None) + .await + .unwrap(); match keypair { GenericAesKey::Aes128(_) => panic!("key_loader_facade.load_sym_group_key() returned an AES128 key! Expected an AES256 key."), GenericAesKey::Aes256(returned_group_key) => { @@ -526,14 +704,17 @@ mod tests { let current_key_pair = PQKeyPairs::generate(&randomizer); let group = generate_group_with_keys(¤t_key_pair, ¤t_group_key, &randomizer); - - let (user_facade_mock, typed_entity_client_mock) = make_mocks(&group, ¤t_group_key, &randomizer); + let (user_facade_mock, typed_entity_client_mock) = + make_mocks(&group, ¤t_group_key, &randomizer); let key_loader_facade = KeyLoaderFacade::new( Arc::new(user_facade_mock), Arc::new(typed_entity_client_mock), ); - let returned_key = key_loader_facade.load_sym_group_key(&group._id, current_group_key_version, None).await.unwrap(); + let returned_key = key_loader_facade + .load_sym_group_key(&group._id, current_group_key_version, None) + .await + .unwrap(); assert_eq!(returned_key, current_group_key.object) } @@ -553,7 +734,9 @@ mod tests { let outdated_current_group_key_version = current_group_key_version - 1; let outdated_current_group_key = VersionedAesKey { - object: former_keys_decrypted[outdated_current_group_key_version as usize].clone().into(), + object: former_keys_decrypted[outdated_current_group_key_version as usize] + .clone() + .into(), version: outdated_current_group_key_version, }; @@ -565,10 +748,13 @@ mod tests { Arc::new(typed_entity_client_mock), ); - key_loader_facade.load_sym_group_key( - &group._id, - current_group_key_version, - Some(outdated_current_group_key), - ).await.expect_err("Did not error with outdated group key"); + key_loader_facade + .load_sym_group_key( + &group._id, + current_group_key_version, + Some(outdated_current_group_key), + ) + .await + .expect_err("Did not error with outdated group key"); } } diff --git a/tuta-sdk/rust/sdk/src/metamodel.rs b/tuta-sdk/rust/sdk/src/metamodel.rs index cc78ac40b8b1..d29e0e2fb8f1 100644 --- a/tuta-sdk/rust/sdk/src/metamodel.rs +++ b/tuta-sdk/rust/sdk/src/metamodel.rs @@ -3,7 +3,7 @@ use std::collections::HashMap; use serde::Deserialize; /// A kind of element that can appear in the model -#[derive(Deserialize, PartialEq, Clone)] +#[derive(Deserialize, PartialEq, Clone, Debug)] pub enum ElementType { /// Entity referenced by a single id #[serde(rename = "ELEMENT_TYPE")] diff --git a/tuta-sdk/rust/sdk/src/typed_entity_client.rs b/tuta-sdk/rust/sdk/src/typed_entity_client.rs index f20e014a6dc9..f39b56cd2819 100644 --- a/tuta-sdk/rust/sdk/src/typed_entity_client.rs +++ b/tuta-sdk/rust/sdk/src/typed_entity_client.rs @@ -7,6 +7,7 @@ use crate::entity_client::EntityClient; use crate::entity_client::IdType; use crate::generated_id::GeneratedId; use crate::instance_mapper::InstanceMapper; +use crate::metamodel::{ElementType, TypeModel}; pub struct TypedEntityClient { entity_client: Arc, @@ -32,14 +33,10 @@ impl TypedEntityClient { id: &Id, ) -> Result { let type_model = self.entity_client.get_type_model(&T::type_ref())?; - if type_model.is_encrypted() { - return Err(ApiCallError::InternalSdkError { - error_message: format!("This client shall not handle encrypted fields! Entity: app: {}, name: {}", &T::type_ref().app, &T::type_ref().type_) - }); - } + Self::check_if_encrypted(type_model)?; let parsed_entity = self.entity_client.load::(&T::type_ref(), id).await?; let typed_entity = self.instance_mapper.parse_entity::(parsed_entity).map_err(|e| { - let message = format!("Failed to parse entity into proper types: {e}"); + let message = format!("Failed to parse entity: {e}"); ApiCallError::InternalSdkError { error_message: message } })?; Ok(typed_entity) @@ -51,10 +48,34 @@ impl TypedEntityClient { todo!("typed entity client load_all") } - #[allow(dead_code)] #[allow(clippy::unused_async)] - pub async fn load_range>(&self, _list_id: &GeneratedId, _start_id: &GeneratedId, _count: usize, _list_load_direction: ListLoadDirection) -> Result, ApiCallError> { - todo!("typed entity client load_range") + pub async fn load_range>(&self, list_id: &GeneratedId, start_id: &GeneratedId, count: usize, direction: ListLoadDirection) -> Result, ApiCallError> { + let type_model = self.entity_client.get_type_model(&T::type_ref())?; + Self::check_if_encrypted(type_model)?; + // TODO: enforce statically? + if type_model.element_type != ElementType::ListElement { + panic!("load_range for non-list type {}/{}", type_model.app, type_model.name) + } + let entities = self.entity_client.load_range(&T::type_ref(), list_id, start_id, count, direction).await?; + let typed_entities = entities + .into_iter() + .map(|e| self.instance_mapper.parse_entity(e)) + .collect::, _>>() + .map_err(|e| { + let message = format!("Failed to parse entity: {e}"); + ApiCallError::InternalSdkError { error_message: message } + })?; + Ok(typed_entities) + } + + // TODO: enforce statically? + fn check_if_encrypted(type_model: &TypeModel) -> Result<(), ApiCallError> { + if type_model.is_encrypted() { + return Err(ApiCallError::InternalSdkError { + error_message: format!("This client shall not handle encrypted fields! Entity: app: {}, name: {}", type_model.app, type_model.name) + }); + } + Ok(()) } }