From 8fb604cbaa9b22bd05b58867a90e36d73cc5b9b7 Mon Sep 17 00:00:00 2001
From: mendess
Date: Thu, 12 Oct 2023 14:17:18 +0100
Subject: [PATCH] Shard aggregate store to avoid reaching durable object value
 size limits

---
 daphne/src/lib.rs                            |   6 +-
 daphne/src/vdaf/mod.rs                       |   2 +-
 daphne_worker/src/durable/aggregate_store.rs | 290 ++++++++++++++++++-
 3 files changed, 283 insertions(+), 15 deletions(-)

diff --git a/daphne/src/lib.rs b/daphne/src/lib.rs
index f12d6b001..a7d81e0b4 100644
--- a/daphne/src/lib.rs
+++ b/daphne/src/lib.rs
@@ -591,11 +591,11 @@ pub struct DapOutputShare {
 pub struct DapAggregateShare {
     /// Number of reports in the batch.
     pub report_count: u64,
-    pub(crate) min_time: Time,
-    pub(crate) max_time: Time,
+    pub min_time: Time,
+    pub max_time: Time,
     /// Batch checksum.
     pub checksum: [u8; 32],
-    pub(crate) data: Option<VdafAggregateShare>,
+    pub data: Option<VdafAggregateShare>,
 }
 
 impl DapAggregateShare {
diff --git a/daphne/src/vdaf/mod.rs b/daphne/src/vdaf/mod.rs
index 952d7cf81..7f45e6372 100644
--- a/daphne/src/vdaf/mod.rs
+++ b/daphne/src/vdaf/mod.rs
@@ -412,7 +412,7 @@ impl ParameterizedDecode<VdafPrepState> for VdafPrepMessage {
 
 #[derive(Clone, Debug, Deserialize, Serialize)]
 #[serde(rename_all = "snake_case")]
-pub(crate) enum VdafAggregateShare {
+pub enum VdafAggregateShare {
     Field64(prio::vdaf::AggregateShare<Field64>),
     Field128(prio::vdaf::AggregateShare<Field128>),
     FieldPrio2(prio::vdaf::AggregateShare<FieldPrio2>),
diff --git a/daphne_worker/src/durable/aggregate_store.rs b/daphne_worker/src/durable/aggregate_store.rs
index f3c6467f6..886cebed1 100644
--- a/daphne_worker/src/durable/aggregate_store.rs
+++ b/daphne_worker/src/durable/aggregate_store.rs
@@ -1,6 +1,7 @@
 // Copyright (c) 2022 Cloudflare, Inc. All rights reserved.
 // SPDX-License-Identifier: BSD-3-Clause
 
+use core::slice;
 use std::ops::ControlFlow;
 
@@ -8,9 +9,17 @@ use crate::{
     durable::{create_span_from_request, state_get_or_default, BINDING_DAP_AGGREGATE_STORE},
     initialize_tracing, int_err,
 };
-use daphne::DapAggregateShare;
+use daphne::{messages::Time, vdaf::VdafAggregateShare, DapAggregateShare};
+use prio::{
+    codec::Encode,
+    field::FieldElement,
+    vdaf::{AggregateShare, OutputShare},
+};
+use serde::de::DeserializeOwned;
+use serde::{Deserialize, Serialize};
 use tracing::Instrument;
-use worker::*;
+use tracing::{debug_span, instrument};
+use worker::{wasm_bindgen::JsValue, *};
 
 use super::{req_parse, DapDurableObject, GarbageCollectable};
 
@@ -47,6 +56,195 @@ pub struct AggregateStore {
     collected: Option<bool>,
 }
 
+const MAX_CHUNK_KEY_COUNT: usize = 21;
+const METADATA_KEY: &str = "meta";
+
+#[derive(Debug, Serialize, Deserialize, Clone, Copy)]
+#[serde(rename_all = "snake_case")]
+enum VdafKind {
+    Field64,
+    Field128,
+    FieldPrio2,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+struct DapAggregateShareMetadata {
+    report_count: u64,
+    min_time: Time,
+    max_time: Time,
+    checksum: [u8; 32],
+    kind: Option<VdafKind>,
+}
+
+impl DapAggregateShareMetadata {
+    fn from_agg_share(
+        share: DapAggregateShare,
+    ) -> (Self, Option<VdafAggregateShare>) {
+        let this = Self {
+            report_count: share.report_count,
+            min_time: share.min_time,
+            max_time: share.max_time,
+            checksum: share.checksum,
+            kind: share.data.as_ref().map(|data| match data {
+                daphne::vdaf::VdafAggregateShare::Field64(_) => VdafKind::Field64,
+                daphne::vdaf::VdafAggregateShare::Field128(_) => VdafKind::Field128,
+                daphne::vdaf::VdafAggregateShare::FieldPrio2(_) => VdafKind::FieldPrio2,
+            }),
+        };
+
+        (this, share.data)
+    }
+
+    fn into_agg_share_with_data(self, data: daphne::vdaf::VdafAggregateShare) -> DapAggregateShare {
+        DapAggregateShare {
+            data: Some(data),
+            ..self.into_agg_share_without_data()
+        }
+    }
+
+    fn into_agg_share_without_data(self) -> DapAggregateShare {
+        DapAggregateShare {
+            report_count: self.report_count,
+            min_time: self.min_time,
+            max_time: self.max_time,
+            checksum: self.checksum,
+            data: None,
+        }
+    }
+}
+
+fn js_map_to_chunks<T: DeserializeOwned>(keys: &[String], map: js_sys::Map) -> Vec<T> {
+    keys.iter()
+        .map(|k| JsValue::from_str(k))
+        .filter(|k| map.has(k))
+        .map(|k| map.get(&k))
+        .flat_map(|js_v| {
+            serde_wasm_bindgen::from_value::<Vec<T>>(js_v).expect("expect an array of bytes")
+        })
+        .collect()
+}
+
+impl AggregateStore {
+    fn agg_share_shard_keys(&self) -> Vec<String> {
+        (0..MAX_CHUNK_KEY_COUNT)
+            .map(|n| format!("chunk_v2_{n:03}"))
+            .collect()
+    }
+
+    // The legacy format always uses 21 chunks and bincode to serialize the entire
+    // structure; it was used during the resolution of incident-4002.
+    //
+    // This method implements the old deserialization logic so that old data can still be
+    // read. Serialization in this format is no longer done, so this method can eventually
+    // be removed.
+    #[instrument(skip(self, keys))]
+    async fn legacy_get_agg_share(
+        &self,
+        keys: &[String],
+        meta: DapAggregateShareMetadata,
+    ) -> Result<DapAggregateShare> {
+        let keys = keys
+            .iter()
+            .map(|k| k.replace("_v2", ""))
+            .collect::<Vec<_>>();
+
+        let values = self.state.storage().get_multiple(keys.clone()).await?;
+        let chunks = js_map_to_chunks(&keys, values);
+        Ok(if chunks.is_empty() {
+            meta.into_agg_share_without_data()
+        } else {
+            let kind = meta.kind.expect("if there is data there should be a type");
+
+            fn from_slice<T: FieldElement>(chunks: &[u32]) -> Result<AggregateShare<T>> {
+                let chunks = unsafe {
+                    // SAFETY: this conversion can be done because the alignment of u8
+                    // is 1, which means a reference to a u8 can never be misaligned.
+                    //
+                    // We also know the pointer is valid (it came from a reference) and
+                    // we know we're not introducing mutable aliasing because we are not
+                    // creating mutable references.
+                    slice::from_raw_parts(
+                        chunks.as_ptr() as *const u8,
+                        std::mem::size_of_val(chunks),
+                    )
+                };
+                let share = T::byte_slice_into_vec(chunks).map_err(|e| {
+                    worker::Error::Internal(
+                        serde_wasm_bindgen::to_value(&e.to_string())
+                            .expect("string never fails to convert to JsValue"),
+                    )
+                })?;
+                // TODO: this is an abuse of this API; this type should not be
+                // constructed this way.
+                Ok(AggregateShare::from(OutputShare::from(share)))
+            }
+
+            let data = match kind {
+                VdafKind::Field64 => VdafAggregateShare::Field64(from_slice(&chunks)?),
+                VdafKind::Field128 => VdafAggregateShare::Field128(from_slice(&chunks)?),
+                VdafKind::FieldPrio2 => VdafAggregateShare::FieldPrio2(from_slice(&chunks)?),
+            };
+
+            meta.into_agg_share_with_data(data)
+        })
+    }
+
+    #[instrument(skip(self, keys))]
+    async fn get_agg_share(&self, keys: &[String]) -> Result<DapAggregateShare> {
+        let all_keys = keys
+            .iter()
+            .map(String::as_str)
+            .chain([METADATA_KEY])
+            .collect::<Vec<_>>();
+        let values = self.state.storage().get_multiple(all_keys).await?;
+        tracing::debug!(len = values.size(), "FOUND VALUES");
+
+        if values.size() == 0 {
+            return Ok(DapAggregateShare::default());
+        }
+
+        let meta_key = JsValue::from_str(METADATA_KEY);
+        let meta =
+            serde_wasm_bindgen::from_value::<DapAggregateShareMetadata>(values.get(&meta_key))
+                .unwrap_or_else(|e| {
+                    tracing::error!("failed to deserialize DapAggregateShareMetadata: {e:?}");
+                    panic!("{e}")
+                });
+
+        if values.size() < 2 {
+            // This means there were no chunks, only a metadata key.
+            tracing::warn!("meta key found but chunks are under legacy keys");
+            return self.legacy_get_agg_share(keys, meta).await;
+        }
+        let chunks = js_map_to_chunks(keys, values);
+
+        Ok(if chunks.is_empty() {
+            meta.into_agg_share_without_data()
+        } else {
+            let kind = meta.kind.expect("if there is data there should be a type");
+
+            fn from_slice<T: FieldElement>(chunks: &[u8]) -> Result<AggregateShare<T>> {
+                let share = T::byte_slice_into_vec(chunks).map_err(|e| {
+                    worker::Error::Internal(
+                        serde_wasm_bindgen::to_value(&e.to_string())
+                            .expect("string never fails to convert to JsValue"),
+                    )
+                })?;
+                // TODO: this is an abuse of this API; this type should not be
+                // constructed this way.
+                Ok(AggregateShare::from(OutputShare::from(share)))
+            }
+
+            let data = match kind {
+                VdafKind::Field64 => VdafAggregateShare::Field64(from_slice(&chunks)?),
+                VdafKind::Field128 => VdafAggregateShare::Field128(from_slice(&chunks)?),
+                VdafKind::FieldPrio2 => VdafAggregateShare::FieldPrio2(from_slice(&chunks)?),
+            };
+
+            meta.into_agg_share_with_data(data)
+        })
+    }
+}
+
 #[durable_object]
 impl DurableObject for AggregateStore {
     fn new(state: State, env: Env) -> Self {
@@ -86,17 +284,88 @@ impl AggregateStore {
             // Input: `agg_share_delta: DapAggregateShare`
             // Output: `()`
             (DURABLE_AGGREGATE_STORE_MERGE, Method::Post) => {
+                let span = debug_span!(
+                    "DURABLE_AGGREGATE_STORE_MERGE",
+                    correlation = rand::random::<u64>()
+                );
+                let _guard = span.enter();
+                tracing::debug!("STARTING MERGE");
                 let agg_share_delta = req_parse(&mut req).await?;
 
-                // To keep this pair of get and put operations atomic, there should be no await
-                // points between them. See the note below `transaction()` on
-                // https://developers.cloudflare.com/workers/runtime-apis/durable-objects/#transactional-storage-api.
-                // See issue #109.
-                let mut agg_share: DapAggregateShare =
-                    state_get_or_default(&self.state, "agg_share").await?;
+                let keys = self.agg_share_shard_keys();
+                let mut agg_share = self.get_agg_share(&keys).await?;
                 agg_share.merge(agg_share_delta).map_err(int_err)?;
-                self.state.storage().put("agg_share", agg_share).await?;
+                let (meta, data) = DapAggregateShareMetadata::from_agg_share(agg_share);
+
+                let (num_chunks, chunks_map) = data
+                    .as_ref()
+                    .map(|data| {
+                        const CHUNK_SIZE: usize = 128_000;
+                        // Adapted from
+                        // https://doc.rust-lang.org/std/primitive.usize.html#method.div_ceil
+                        // because `usize::div_ceil` is nightly-only.
+                        fn div_ceil(lhs: usize, rhs: usize) -> usize {
+                            let d = lhs / rhs;
+                            let r = lhs % rhs;
+                            if r > 0 && rhs > 0 {
+                                d + 1
+                            } else {
+                                d
+                            }
+                        }
+
+                        let data = data.get_encoded();
+                        let num_chunks = div_ceil(data.len(), CHUNK_SIZE);
+                        assert!(
+                            num_chunks <= keys.len(),
+                            "too many chunks {num_chunks}. max is {}",
+                            keys.len()
+                        );
+
+                        // This is effectively a map from chunk_v2_XXX keys to byte slices.
+                        let chunks_map = js_sys::Object::new();
+
+                        let mut base_idx = 0;
+                        for key in &keys[..num_chunks] {
+                            let end = usize::min(base_idx + CHUNK_SIZE, data.len());
+                            let chunk = &data[base_idx..end];
+
+                            let value = js_sys::Uint8Array::new_with_length(chunk.len() as _);
+                            value.copy_from(chunk);
+
+                            js_sys::Reflect::set(
+                                &chunks_map,
+                                &JsValue::from_str(key.as_str()),
+                                &value.into(),
+                            )?;
+
+                            base_idx = end;
+                        }
+                        assert_eq!(
+                            base_idx,
+                            data.len(),
+                            "len: {} chunk_size: {} rem: {}",
+                            data.len(),
+                            CHUNK_SIZE,
+                            data.len() % CHUNK_SIZE,
+                        );
+                        Result::Ok((num_chunks, chunks_map))
+                    })
+                    .transpose()?
+                    .unwrap_or_default();
+
+                tracing::debug!(chunk_count = num_chunks, "PUTTING NOW");
+
+                js_sys::Reflect::set(
+                    &chunks_map,
+                    &JsValue::from_str(METADATA_KEY),
+                    &serde_wasm_bindgen::to_value(&meta)?,
+                )?;
+
+                self.state.storage().put_multiple_raw(chunks_map).await?;
+
+                tracing::debug!("LEAVING MERGE");
                 Response::from_json(&())
             }
 
@@ -105,8 +374,7 @@ impl AggregateStore {
             // Idempotent
             // Output: `DapAggregateShare`
            (DURABLE_AGGREGATE_STORE_GET, Method::Get) => {
-                let agg_share: DapAggregateShare =
-                    state_get_or_default(&self.state, "agg_share").await?;
+                let agg_share = self.get_agg_share(&self.agg_share_shard_keys()).await?;
                 Response::from_json(&agg_share)
             }
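
The sharding scheme in the merge arm above boils down to: encode the aggregate share, split the encoding into chunks of at most CHUNK_SIZE bytes (safely under the Durable Object per-value size limit), and store each chunk under a zero-padded key (`chunk_v2_000` through `chunk_v2_020`) alongside a `meta` entry; the read path fetches every key and concatenates the chunks in key order. Below is a minimal standalone sketch of that round trip, not code from this patch: `split_into_chunks` and `reassemble` are hypothetical helper names, and the constants mirror the ones defined in the patch.

```rust
// Illustrative sketch of the chunking round trip; not part of the patch.
const CHUNK_SIZE: usize = 128_000; // stays under the DO per-value size limit
const MAX_CHUNK_KEY_COUNT: usize = 21;

/// Split encoded aggregate-share bytes into (key, chunk) pairs.
fn split_into_chunks(data: &[u8]) -> Vec<(String, Vec<u8>)> {
    data.chunks(CHUNK_SIZE)
        .enumerate()
        .map(|(i, chunk)| (format!("chunk_v2_{i:03}"), chunk.to_vec()))
        .collect()
}

/// Rebuild the encoded bytes from (key, chunk) pairs fetched in any order.
fn reassemble(mut chunks: Vec<(String, Vec<u8>)>) -> Vec<u8> {
    // Zero-padded indices make lexicographic key order match numeric order.
    chunks.sort_by(|(a, _), (b, _)| a.cmp(b));
    chunks.into_iter().flat_map(|(_, bytes)| bytes).collect()
}

fn main() {
    let data = vec![7u8; 300_000]; // ~300 KB of encoded share forces 3 chunks
    let chunks = split_into_chunks(&data);
    assert_eq!(chunks.len(), 3);
    assert!(chunks.len() <= MAX_CHUNK_KEY_COUNT);
    assert_eq!(reassemble(chunks), data);
}
```

With 21 keys of 128 000 bytes each, the scheme caps an encoded share at roughly 2.7 MB, which is the bound the `num_chunks <= keys.len()` assertion in the merge arm enforces.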