Coalesce report initialization operations into a trait

Define a trait `DapReportInitializer` that takes a sequence of reports
in the "consumed" state and transitions each report to the "initialized"
state. Implementers are expected to perform the early validation steps
(i.e., the protocol logic previously implemented by
`DapAggregator::check_early_reject()`) and initialize VDAF preparation.
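
For orientation, an implementation of the new trait has roughly the
following shape. This is only a sketch, not code from this commit:
`MyAggregator`, the `lookup_early_rejects()` helper, and the
`EarlyReportStateInitialized::initialize()` constructor are placeholders
for whatever storage and VDAF plumbing the implementer already has; the
`metadata()` accessor comes from daphne's `EarlyReportState` trait.

    #[async_trait(?Send)]
    impl DapReportInitializer for MyAggregator {
        async fn initialize_reports<'req>(
            &self,
            is_leader: bool,
            task_id: &TaskId,
            task_config: &DapTaskConfig,
            part_batch_sel: &PartialBatchSelector,
            consumed_reports: Vec<EarlyReportStateConsumed<'req>>,
        ) -> Result<Vec<EarlyReportStateInitialized<'req>>, DapError> {
            // Early validation: decide, per report, whether it must be
            // rejected (replayed, or its batch was already collected).
            // `lookup_early_rejects()` is a hypothetical helper.
            let early_rejects: HashMap<ReportId, TransitionFailure> = self
                .lookup_early_rejects(task_id, part_batch_sel, &consumed_reports)
                .await?;

            // Transition each report from "consumed" to "initialized":
            // early rejects become failures, and the rest start the
            // (expensive) VDAF preparation. The constructor is hypothetical.
            consumed_reports
                .into_iter()
                .map(|consumed| {
                    let early_fail =
                        early_rejects.get(&consumed.metadata().id).copied();
                    EarlyReportStateInitialized::initialize(
                        is_leader,
                        task_config,
                        early_fail,
                        consumed,
                    )
                })
                .collect()
        }
    }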

The primary motivation for this change is to improve the performance of
`DaphneWorker`. VDAF preparation initialization is expensive, so we'd
like to parallelize the computation. Because the Workers runtime is
single-threaded, the best option is to offload VDAF preparation to
durable objects. In particular, `DaphneWorker`'s implementation of
`DapReportInitializer` piggy-backs this computation on the requests to
`ReportsProcessed`.
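
The sketch below illustrates the fan-out pattern. It is purely
illustrative: the shard handle, its `initialize_batch()` request, and
the `shard_for()` routing function are hypothetical, not APIs
introduced by this commit, and it assumes the `EarlyReportState` trait
is in scope for `metadata()`.

    use futures::future::try_join_all;

    /// Hypothetical handle for one `ReportsProcessed` durable object.
    struct ReportsProcessedShard {
        // e.g., a durable object stub
    }

    impl ReportsProcessedShard {
        /// One round trip that both records report IDs for replay
        /// protection and runs VDAF preparation remotely. (Hypothetical.)
        async fn initialize_batch<'req>(
            &self,
            reports: Vec<EarlyReportStateConsumed<'req>>,
        ) -> Result<Vec<EarlyReportStateInitialized<'req>>, DapError> {
            unimplemented!("send the batch to the durable object")
        }
    }

    /// Hypothetical routing: decide which shard owns a report.
    fn shard_for(metadata: &ReportMetadata) -> usize {
        unimplemented!("e.g., a hash of the report ID")
    }

    /// Group the reports by shard and initialize all batches concurrently.
    /// The runtime is single-threaded, but the outbound requests overlap,
    /// so the expensive VDAF computation proceeds in parallel across
    /// durable objects.
    async fn initialize_in_parallel<'req>(
        shards: &[ReportsProcessedShard],
        consumed_reports: Vec<EarlyReportStateConsumed<'req>>,
    ) -> Result<Vec<EarlyReportStateInitialized<'req>>, DapError> {
        let mut batches: Vec<Vec<_>> =
            (0..shards.len()).map(|_| Vec::new()).collect();
        for report in consumed_reports {
            batches[shard_for(report.metadata()) % shards.len()].push(report);
        }
        let initialized = try_join_all(
            shards
                .iter()
                .zip(batches)
                .map(|(shard, batch)| shard.initialize_batch(batch)),
        )
        .await?;
        Ok(initialized.into_iter().flatten().collect())
    }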

Co-authored-by: mendess <pmendes@cloudflare.com>
cjpatton and mendess committed Jul 10, 2023
1 parent 787eda4 commit e4cf5ed
Showing 8 changed files with 709 additions and 369 deletions.
daphne/benches/aggregation.rs (6 changes: 2 additions & 4 deletions)
@@ -1,7 +1,7 @@
 // Copyright (c) 2023 Cloudflare, Inc. All rights reserved.
 // SPDX-License-Identifier: BSD-3-Clause
 
-use criterion::{criterion_group, criterion_main, Criterion};
+use criterion::{black_box, criterion_group, criterion_main, Criterion};
 use daphne::{
     hpke::HpkeKemId, testing::AggregationJobTest, DapLeaderTransition, DapMeasurement, DapVersion,
     Prio3Config, VdafConfig,
@@ -46,9 +46,7 @@ fn handle_agg_job_init_req(c: &mut Criterion) {

     c.bench_function(&format!("handle_agg_job_init_req {vdaf:?}"), |b| {
         b.to_async(&rt).iter(|| async {
-            agg_job_test
-                .handle_agg_job_init_req(&agg_job_init_req)
-                .await
+            black_box(agg_job_test.handle_agg_job_init_req(&agg_job_init_req)).await
         })
     });
 }
daphne/src/lib.rs (22 changes: 12 additions & 10 deletions)
@@ -42,8 +42,7 @@ use crate::{
     hpke::HpkeReceiverConfig,
     messages::{
         AggregationJobId, BatchId, BatchSelector, Collection, CollectionJobId,
-        Draft02AggregationJobId, Duration, Interval, PartialBatchSelector, ReportId,
-        ReportMetadata, TaskId, Time,
+        Draft02AggregationJobId, Duration, Interval, PartialBatchSelector, ReportId, TaskId, Time,
     },
     taskprov::TaskprovVersion,
     vdaf::{VdafAggregateShare, VdafPrepMessage, VdafPrepState, VdafVerifyKey},
@@ -66,6 +65,7 @@ use std::{
     fmt::{Debug, Display},
 };
 use url::Url;
+use vdaf::{EarlyReportState, EarlyReportStateConsumed};
 
 /// DAP version used for a task.
 #[derive(Clone, Copy, Debug, Default, Deserialize, Eq, Hash, PartialEq, Serialize)]
@@ -364,31 +364,33 @@ impl DapTaskConfig {
         }
     }
 
-    /// Return the batch span of a set of reports with the given metadata.
+    /// Return the batch span of a set of reports.
     pub fn batch_span_for_meta<'sel, 'rep>(
         &self,
         part_batch_sel: &'sel PartialBatchSelector,
-        report_meta: impl Iterator<Item = &'rep ReportMetadata>,
-    ) -> Result<HashMap<DapBatchBucket<'sel>, Vec<&'rep ReportMetadata>>, DapError> {
+        consumed_reports: impl Iterator<Item = &'rep EarlyReportStateConsumed<'rep>>,
+    ) -> Result<HashMap<DapBatchBucket<'sel>, Vec<&'rep EarlyReportStateConsumed<'rep>>>, DapError>
+    {
         if !self.query.is_valid_part_batch_sel(part_batch_sel) {
             return Err(fatal_error!(
                 err = "partial batch selector not compatible with task",
             ));
         }
 
         let mut span: HashMap<_, Vec<_>> = HashMap::new();
-        for metadata in report_meta {
+        for consumed_report in consumed_reports.filter(|consumed_report| consumed_report.is_ready())
+        {
             let bucket = match part_batch_sel {
                 PartialBatchSelector::TimeInterval => DapBatchBucket::TimeInterval {
-                    batch_window: self.quantized_time_lower_bound(metadata.time),
+                    batch_window: self.quantized_time_lower_bound(consumed_report.metadata().time),
                 },
                 PartialBatchSelector::FixedSizeByBatchId { batch_id } => {
                     DapBatchBucket::FixedSize { batch_id }
                 }
             };
 
-            let report_ids = span.entry(bucket).or_default();
-            report_ids.push(metadata);
+            let consumed_reports_per_bucket = span.entry(bucket).or_default();
+            consumed_reports_per_bucket.push(consumed_report);
         }
 
         Ok(span)
@@ -481,7 +483,7 @@ impl DapHelperState {
             .map_err(|e| DapAbort::from_codec_error(e, None))?;
         let mut seq = vec![];
         while (r.position() as usize) < data.len() {
-            let state = VdafPrepState::decode_with_param(vdaf_config, &mut r)
+            let state = VdafPrepState::decode_with_param(&(vdaf_config, false), &mut r)
                 .map_err(|e| DapAbort::from_codec_error(e, None))?;
             let time = Time::decode(&mut r).map_err(|e| DapAbort::from_codec_error(e, None))?;
             let report_id =
daphne/src/roles.rs (123 changes: 31 additions & 92 deletions)
@@ -12,14 +12,15 @@ use crate::{
         constant_time_eq, decode_base64url, AggregateShare, AggregateShareReq,
         AggregationJobContinueReq, AggregationJobInitReq, AggregationJobResp, BatchId,
         BatchSelector, Collection, CollectionJobId, CollectionReq, HpkeConfigList, Interval,
-        PartialBatchSelector, Query, Report, ReportId, ReportMetadata, TaskId, Time,
-        TransitionFailure, TransitionVar,
+        PartialBatchSelector, Query, Report, ReportMetadata, TaskId, Time, TransitionFailure,
     },
     metrics::{DaphneMetrics, DaphneRequestType},
-    taskprov, DapAbort, DapAggregateShare, DapCollectJob, DapError, DapGlobalConfig,
-    DapHelperState, DapHelperTransition, DapLeaderProcessTelemetry, DapLeaderTransition,
-    DapOutputShare, DapQueryConfig, DapRequest, DapResource, DapResponse, DapTaskConfig,
-    DapVersion, MetaAggregationJobId,
+    taskprov,
+    vdaf::{EarlyReportStateConsumed, EarlyReportStateInitialized},
+    DapAbort, DapAggregateShare, DapCollectJob, DapError, DapGlobalConfig, DapHelperState,
+    DapHelperTransition, DapLeaderProcessTelemetry, DapLeaderTransition, DapOutputShare,
+    DapQueryConfig, DapRequest, DapResource, DapResponse, DapTaskConfig, DapVersion,
+    MetaAggregationJobId,
 };
 use async_trait::async_trait;
 use futures::TryFutureExt;
@@ -41,9 +42,26 @@ pub trait DapAuthorizedSender<S> {
     ) -> Result<S, DapError>;
 }
 
+/// Report initializer. Used by a DAP Aggregator [`DapAggregator`] when initializing an
+/// aggregation job.
+#[async_trait(?Send)]
+pub trait DapReportInitializer {
+    /// Initialize a sequence of reports that are in the "consumed" state by performing the early
+    /// validation steps (checking whether the report was replayed or belongs to a batch that has
+    /// already been collected) and initializing VDAF preparation.
+    async fn initialize_reports<'req>(
+        &self,
+        is_leader: bool,
+        task_id: &TaskId,
+        task_config: &DapTaskConfig,
+        part_batch_sel: &PartialBatchSelector,
+        consumed_reports: Vec<EarlyReportStateConsumed<'req>>,
+    ) -> Result<Vec<EarlyReportStateInitialized<'req>>, DapError>;
+}
+
 /// DAP Aggregator functionality.
 #[async_trait(?Send)]
-pub trait DapAggregator<S>: HpkeDecrypter + Sized {
+pub trait DapAggregator<S>: HpkeDecrypter + DapReportInitializer + Sized {
     /// A reference to a task configuration stored by the Aggregator.
     type WrappedDapTaskConfig<'a>: AsRef<DapTaskConfig>;
@@ -120,16 +138,6 @@ pub trait DapAggregator<S>: HpkeDecrypter + Sized {
         batch_sel: &BatchSelector,
     ) -> Result<DapAggregateShare, DapError>;
 
-    /// Ensure a set of reports can be aggregated. Return a transition failure for each report
-    /// that must be rejected early, due to the report being replayed, the bucket that contains
-    /// the report being collected, etc.
-    async fn check_early_reject(
-        &self,
-        task_id: &TaskId,
-        part_batch_sel: &PartialBatchSelector,
-        report_meta: impl Iterator<Item = &ReportMetadata>,
-    ) -> Result<HashMap<ReportId, TransitionFailure>, DapError>;
-
     /// Mark a batch as collected.
     async fn mark_collected(
         &self,
@@ -477,34 +485,12 @@ pub trait DapLeader<S>: DapAuthorizedSender<S> + DapAggregator<S> {
     ) -> Result<u64, DapAbort> {
         let metrics = self.metrics().with_host(host);
 
-        // Filter out early rejected reports.
-        //
-        // TODO Add a test similar to http_post_aggregate_init_expired_task() in roles_test.rs that
-        //      verifies that the Leader properly checks for expiration. This will require extending
-        //      the test framework to run run_agg_job() directly.
-        let early_rejects = self
-            .check_early_reject(
-                task_id,
-                part_batch_sel,
-                reports.iter().map(|report| &report.report_metadata),
-            )
-            .await?;
-        let reports = reports
-            .into_iter()
-            .filter(|report| {
-                if let Some(failure) = early_rejects.get(&report.report_metadata.id) {
-                    metrics.report_inc_by(&format!("rejected_{failure}"), 1);
-                    return false;
-                }
-                true
-            })
-            .collect();
-
         // Prepare AggregationJobInitReq.
         let agg_job_id = MetaAggregationJobId::gen_for_version(&task_config.version);
         let transition = task_config
             .vdaf
             .produce_agg_job_init_req(
+                self,
                 self,
                 task_id,
                 task_config,
@@ -883,68 +869,21 @@ pub trait DapHelper<S>: DapAggregator<S> {
             &agg_job_init_req.agg_param,
         )?;
 
-        let early_rejects_future = self.check_early_reject(
-            task_id,
-            &agg_job_init_req.part_batch_sel,
-            agg_job_init_req
-                .report_shares
-                .iter()
-                .map(|report_share| &report_share.report_metadata),
-        );
-
-        let transition_future = task_config
+        let transition = task_config
             .vdaf
             .handle_agg_job_init_req(
                 self,
+                self,
                 task_id,
                 task_config,
                 &agg_job_init_req,
                 &metrics,
             )
-            .map_err(DapError::Abort);
-
-        let (early_rejects, transition) =
-            futures::try_join!(early_rejects_future, transition_future)?;
+            .map_err(DapError::Abort)
+            .await?;
 
         let agg_job_resp = match transition {
-            DapHelperTransition::Continue(mut state, mut agg_job_resp) => {
-                // Filter out early rejected reports.
-                let mut state_index = 0;
-                for transition in agg_job_resp.transitions.iter_mut() {
-                    let early_failure = early_rejects.get(&transition.report_id);
-                    if !matches!(transition.var, TransitionVar::Failed(..))
-                        && early_failure.is_some()
-                    {
-                        // NOTE(cjpatton) Clippy wants us to use an `if let` statement to
-                        // unwrap `early_failure`. I don't think this works because we
-                        // only want to enter this branch if `early_failure.is_some()` and
-                        // the current `transition` is not a failure. As far as I know, `if
-                        // let` statements can't yet be combined with other conditions.
-                        #[allow(clippy::unnecessary_unwrap)]
-                        let failure = early_failure.unwrap();
-                        transition.var = TransitionVar::Failed(*failure);
-
-                        // Remove VDAF preparation state of reports that were rejected early.
-                        if transition.report_id == state.seq[state_index].2 {
-                            let _val = state.seq.remove(state_index);
-                        } else {
-                            // The report ID in the Helper state and Aggregate response
-                            // must be aligned. If not, handle as an internal error.
-                            return Err(fatal_error!(err = "report IDs not aligned").into());
-                        }
-
-                        // NOTE(cjpatton) Unlike the Leader, the Helper filters out early
-                        // rejects after processing all of the reports. (This is an
-                        // optimization intended to reduce latency.) To avoid overcounting
-                        // rejection metrics, the latter rejections take precedence. The
-                        // Leader has the opposite behavior: Early rejections are resolved
-                        // first, so they take precedence.
-                        metrics.report_inc_by(&format!("rejected_{failure}"), 1);
-                    } else {
-                        state_index += 1;
-                    }
-                }
-
+            DapHelperTransition::Continue(state, agg_job_resp) => {
                 if !self
                     .put_helper_state_if_not_exists(task_id, &agg_job_id, &state)
                     .await?
(The remaining five changed files are not shown.)
