Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Multithreaded crypto #170

Merged
merged 18 commits into from
Mar 23, 2021
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions matrix_sdk_crypto/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ serde_json = "1.0.61"
zeroize = { version = "1.2.0", features = ["zeroize_derive"] }

# Misc dependencies
tokio = { version = "1.1.0", default-features = false, features = ["rt", "rt-multi-thread"] }
futures = "0.3.12"
sled = { version = "0.34.6", optional = true }
thiserror = "1.0.23"
tracing = "0.1.22"
Expand All @@ -44,14 +46,13 @@ byteorder = "1.4.2"

[dev-dependencies]
tokio = { version = "1.1.0", default-features = false, features = ["rt-multi-thread", "macros"] }
futures = "0.3.12"
proptest = "0.10.1"
serde_json = "1.0.61"
tempfile = "3.2.0"
http = "0.2.3"
matrix-sdk-test = { version = "0.2.0", path = "../matrix_sdk_test" }
indoc = "1.0.3"
criterion = { version = "0.3.4", features = ["async", "async_futures", "html_reports"] }
criterion = { version = "0.3.4", features = ["async", "async_tokio", "html_reports"] }

[target.'cfg(target_os = "linux")'.dev-dependencies]
pprof = { version = "0.4.2", features = ["flamegraph"] }
Expand Down
18 changes: 13 additions & 5 deletions matrix_sdk_crypto/benches/crypto_bench.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ mod perf;

use std::convert::TryFrom;

use criterion::{async_executor::FuturesExecutor, *};
use criterion::*;

use futures::executor::block_on;
use matrix_sdk_common::{
Expand All @@ -17,6 +17,7 @@ use matrix_sdk_common::{
use matrix_sdk_crypto::{EncryptionSettings, OlmMachine};
use matrix_sdk_test::response_from_file;
use serde_json::Value;
use tokio::runtime::Builder;

fn alice_id() -> UserId {
user_id!("@alice:example.org")
Expand All @@ -41,6 +42,9 @@ fn keys_claim_response() -> claim_keys::Response {
}

pub fn keys_query(c: &mut Criterion) {
let runtime = Builder::new_multi_thread()
.build()
.expect("Can't create runtime");
let machine = OlmMachine::new(&alice_id(), &alice_device_id());
let response = keys_query_response();
let uuid = Uuid::new_v4();
Expand All @@ -62,7 +66,7 @@ pub fn keys_query(c: &mut Criterion) {
BenchmarkId::new("memory store", &name),
&response,
|b, response| {
b.to_async(FuturesExecutor)
b.to_async(&runtime)
.iter(|| async { machine.mark_request_as_sent(&uuid, response).await.unwrap() })
},
);
Expand All @@ -80,7 +84,7 @@ pub fn keys_query(c: &mut Criterion) {
BenchmarkId::new("sled store", &name),
&response,
|b, response| {
b.to_async(FuturesExecutor)
b.to_async(&runtime)
.iter(|| async { machine.mark_request_as_sent(&uuid, response).await.unwrap() })
},
);
Expand Down Expand Up @@ -147,6 +151,10 @@ pub fn keys_claiming(c: &mut Criterion) {
}

pub fn room_key_sharing(c: &mut Criterion) {
let runtime = Builder::new_multi_thread()
.build()
.expect("Can't create runtime");

let keys_query_response = keys_query_response();
let uuid = Uuid::new_v4();
let response = keys_claim_response();
Expand All @@ -169,7 +177,7 @@ pub fn room_key_sharing(c: &mut Criterion) {
let name = format!("{} devices", count);

group.bench_function(BenchmarkId::new("memory store", &name), |b| {
b.to_async(FuturesExecutor).iter(|| async {
b.to_async(&runtime).iter(|| async {
let requests = machine
.share_group_session(&room_id, users.iter(), EncryptionSettings::default())
.await
Expand Down Expand Up @@ -200,7 +208,7 @@ pub fn room_key_sharing(c: &mut Criterion) {
block_on(machine.mark_request_as_sent(&uuid, &response)).unwrap();

group.bench_function(BenchmarkId::new("sled store", &name), |b| {
b.to_async(FuturesExecutor).iter(|| async {
b.to_async(&runtime).iter(|| async {
let requests = machine
.share_group_session(&room_id, users.iter(), EncryptionSettings::default())
.await
Expand Down
97 changes: 70 additions & 27 deletions matrix_sdk_crypto/src/session_manager/group_sessions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ use std::{
sync::Arc,
};

use futures::future::join_all;

use tokio;

use dashmap::DashMap;
use matrix_sdk_common::{
api::r0::to_device::DeviceIdOrAllDevices,
Expand Down Expand Up @@ -188,35 +192,52 @@ impl GroupSessionManager {
/// Encrypt the given content for the given devices and create a to-device
/// requests that sends the encrypted content to them.
async fn encrypt_session_for(
&self,
content: Value,
devices: &[Device],
devices: Vec<Device>,
) -> OlmResult<(Uuid, ToDeviceRequest, Vec<Session>)> {
let mut messages = BTreeMap::new();
let mut changed_sessions = Vec::new();

for device in devices {
let encrypt = |device: Device, content: Value| async move {
let mut message = BTreeMap::new();

let encrypted = device.encrypt(EventType::RoomKey, content.clone()).await;

let (used_session, encrypted) = match encrypted {
Ok(c) => c,
let used_session = match encrypted {
Ok((session, encrypted)) => {
message
.entry(device.user_id().clone())
.or_insert_with(BTreeMap::new)
.insert(
DeviceIdOrAllDevices::DeviceId(device.device_id().into()),
serde_json::value::to_raw_value(&encrypted)?,
);
Some(session)
}
// TODO we'll want to create m.room_key.withheld here.
Err(OlmError::MissingSession)
| Err(OlmError::EventError(EventError::MissingSenderKey)) => {
continue;
}
| Err(OlmError::EventError(EventError::MissingSenderKey)) => None,
Err(e) => return Err(e),
};

changed_sessions.push(used_session);
Ok((used_session, message))
};

messages
.entry(device.user_id().clone())
.or_insert_with(BTreeMap::new)
.insert(
DeviceIdOrAllDevices::DeviceId(device.device_id().into()),
serde_json::value::to_raw_value(&encrypted)?,
);
let tasks: Vec<_> = devices
.iter()
.map(|d| tokio::spawn(encrypt(d.clone(), content.clone())))
.collect();

let results = join_all(tasks).await;

for result in results {
let (used_session, message) = result.expect("Encryption task paniced")?;
poljar marked this conversation as resolved.
Show resolved Hide resolved

if let Some(session) = used_session {
changed_sessions.push(session);
}

messages.extend(message);
}

let id = Uuid::new_v4();
Expand Down Expand Up @@ -334,6 +355,24 @@ impl GroupSessionManager {
Ok((should_rotate, devices))
}

pub async fn encrypt_request(
chunk: Vec<Device>,
content: Value,
outbound: OutboundGroupSession,
message_index: u32,
being_shared: Arc<DashMap<Uuid, OutboundGroupSession>>,
) -> OlmResult<Vec<Session>> {
let (id, request, used_sessions) =
Self::encrypt_session_for(content.clone(), chunk).await?;

if !request.messages.is_empty() {
outbound.add_request(id, request.into(), message_index);
being_shared.insert(id, outbound.clone());
}

Ok(used_sessions)
}

/// Get to-device requests to share a group session with users in a room.
///
/// # Arguments
Expand Down Expand Up @@ -427,18 +466,22 @@ impl GroupSessionManager {
);
}

for device_map_chunk in devices.chunks(Self::MAX_TO_DEVICE_MESSAGES) {
let (id, request, used_sessions) = self
.encrypt_session_for(key_content.clone(), device_map_chunk)
.await?;

if !request.messages.is_empty() {
outbound.add_request(id, request.into(), message_index);
self.outbound_sessions_being_shared
.insert(id, outbound.clone());
}
let tasks: Vec<_> = devices
.chunks(Self::MAX_TO_DEVICE_MESSAGES)
.map(|chunk| {
tokio::spawn(Self::encrypt_request(
chunk.to_vec(),
key_content.clone(),
outbound.clone(),
message_index,
self.outbound_sessions_being_shared.clone(),
))
})
.collect();

changes.sessions.extend(used_sessions);
for result in join_all(tasks).await {
let used_sessions: OlmResult<Vec<Session>> = result.expect("Encryption task paniced");
changes.sessions.extend(used_sessions?);
}

let requests = outbound.pending_requests();
Expand Down