diff --git a/Cargo.lock b/Cargo.lock index c4eb4c43f..491b7bf5c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1546,6 +1546,7 @@ dependencies = [ "janus_test_util", "lazy_static", "prio", + "rand", "ring", "testcontainers", "tokio", @@ -1936,9 +1937,9 @@ checksum = "eb9f9e6e233e5c4a35559a617bf40a4ec447db2e84c20b55a6f83167b7e57872" [[package]] name = "prio" -version = "0.7.1" +version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "34dec6b842c17bb4ebff74907295c102f8790d47b8a914eade914efc6e2b1aa8" +checksum = "6e58420a9eed730ff64632a63d8f01becd86367999562978c2ac891582ebb720" dependencies = [ "aes 0.8.1", "aes-gcm", diff --git a/db/schema.sql b/db/schema.sql index ab3fb9c26..87a11f156 100644 --- a/db/schema.sql +++ b/db/schema.sql @@ -43,14 +43,14 @@ CREATE TABLE task_hpke_keys( CONSTRAINT fk_task_id FOREIGN KEY(task_id) REFERENCES tasks(id) ); --- The VDAF verification parameters used by a given task. +-- The VDAF verification keys used by a given task. -- TODO(#229): support multiple verification parameters per task -CREATE TABLE task_vdaf_verify_params( +CREATE TABLE task_vdaf_verify_keys( id BIGINT GENERATED ALWAYS AS IDENTITY PRIMARY KEY, -- artificial ID, internal-only - task_id BIGINT NOT NULL, -- task ID the verification parameter is associated with - vdaf_verify_param BYTEA NOT NULL, -- the VDAF verification parameter (opaque VDAF message, encrypted) + task_id BIGINT NOT NULL, -- task ID the verification key is associated with + vdaf_verify_key BYTEA NOT NULL, -- the VDAF verification key (encrypted) - CONSTRAINT vdaf_verify_param_unique_task_id UNIQUE(task_id), + CONSTRAINT vdaf_verify_key_unique_task_id UNIQUE(task_id), CONSTRAINT fk_task_id FOREIGN KEY(task_id) REFERENCES tasks(id) ); diff --git a/janus/Cargo.toml b/janus/Cargo.toml index bdca24307..30d231e94 100644 --- a/janus/Cargo.toml +++ b/janus/Cargo.toml @@ -12,7 +12,7 @@ chrono = "0.4" hex = "0.4.3" hpke-dispatch = "0.3.0" num_enum = "0.5.6" -prio = "0.7.1" +prio = "0.8.0" rand = "0.8" ring = "0.16.20" serde = { version = "1.0.137", features = ["derive"] } diff --git a/janus_server/Cargo.toml b/janus_server/Cargo.toml index 36b4207d8..f16e95de0 100644 --- a/janus_server/Cargo.toml +++ b/janus_server/Cargo.toml @@ -34,7 +34,7 @@ opentelemetry-otlp = { version = "0.10.0", optional = true, features = ["metrics opentelemetry-prometheus = { version = "0.10.0", optional = true } opentelemetry-semantic-conventions = { version = "0.9.0", optional = true } postgres-types = { version = "0.2.3", features = ["derive", "array-impls"] } -prio = "0.7.1" +prio = "0.8.0" prometheus = { version = "0.13.1", optional = true } rand = "0.8" reqwest = { version = "0.11.4", default-features = false, features = ["rustls-tls", "json"] } diff --git a/janus_server/src/aggregator.rs b/janus_server/src/aggregator.rs index b939cc9e5..7bed0977d 100644 --- a/janus_server/src/aggregator.rs +++ b/janus_server/src/aggregator.rs @@ -21,7 +21,7 @@ use crate::{ AggregateInitializeResp, AggregateShareReq, AggregateShareResp, AggregationJobId, CollectReq, CollectResp, PrepareStep, PrepareStepResult, ReportShare, ReportShareError, }, - task::{Task, VdafInstance, DAP_AUTH_HEADER}, + task::{self, Task, VdafInstance, DAP_AUTH_HEADER, PRIO3_AES128_VERIFY_KEY_LENGTH}, }; use bytes::Bytes; use futures::try_join; @@ -42,13 +42,14 @@ use prio::{ codec::{Decode, Encode, ParameterizedDecode}, vdaf::{ self, - prio3::{Prio3Aes128Count, Prio3Aes128Histogram, Prio3Aes128Sum}, - PrepareTransition, Vdaf, + prio3::{Prio3, Prio3Aes128Count, Prio3Aes128Histogram, Prio3Aes128Sum}, + PrepareTransition, }, }; use std::{ collections::{HashMap, HashSet}, convert::Infallible, + fmt, future::Future, io::Cursor, net::SocketAddr, @@ -115,17 +116,17 @@ pub enum Error { #[error("VDAF error: {0}")] Vdaf(#[from] vdaf::VdafError), /// A collect or aggregate share request was rejected because the interval is valid, per §4.6 - #[error("Invalid batch interval: {0} {1:?}")] + #[error("invalid batch interval: {0} {1:?}")] InvalidBatchInterval(Interval, TaskId), /// There are not enough reports in the batch interval to meet the task's minimum batch size. - #[error("Insufficient number of reports ({0}) for task {1:?}")] + #[error("insufficient number of reports ({0}) for task {1:?}")] InsufficientBatchSize(u64, TaskId), #[error("URL parse error: {0}")] Url(#[from] url::ParseError), /// The checksum or report count in one aggregator's aggregate share does not match the other /// aggregator's aggregate share, suggesting different sets of reports were aggregated. #[error( - "Batch misalignment: own checksum: {own_checksum:?} own report count: {own_report_count} \ + "batch misalignment: own checksum: {own_checksum:?} own report count: {own_report_count} \ peer checksum: {peer_checksum:?} peer report count: {peer_report_count}" )] BatchMisalignment { @@ -136,13 +137,13 @@ peer checksum: {peer_checksum:?} peer report count: {peer_report_count}" peer_report_count: u64, }, /// Too many queries against a single batch. - #[error("Maxiumum batch lifetime for task {0:?} exceeded")] + #[error("maxiumum batch lifetime for task {0:?} exceeded")] BatchLifetimeExceeded(TaskId), /// HPKE failure. #[error("HPKE error: {0}")] Hpke(#[from] janus::hpke::Error), /// Error handling task parameters - #[error("Invalid task parameters: {0}")] + #[error("invalid task parameters: {0}")] TaskParameters(#[from] crate::task::Error), /// Error making an HTTP request #[error("HTTP client error: {0}")] @@ -394,34 +395,33 @@ impl TaskAggregator { /// Create a new aggregator. `report_recipient` is used to decrypt reports received by this /// aggregator. fn new(task: Task) -> Result { - let current_vdaf_verify_parameter = task.vdaf_verify_parameters.last().unwrap(); + let current_vdaf_verify_key = task.vdaf_verify_keys.last().unwrap(); let vdaf_ops = match &task.vdaf { VdafInstance::Real(janus::task::VdafInstance::Prio3Aes128Count) => { - let vdaf = Prio3Aes128Count::new(2)?; - let verify_param = ::VerifyParam::get_decoded_with_param( - &vdaf, - current_vdaf_verify_parameter, - )?; - VdafOps::Prio3Aes128Count(vdaf, verify_param) + let vdaf = Prio3::new_aes128_count(2)?; + let verify_key = current_vdaf_verify_key + .clone() + .try_into() + .map_err(|_| Error::TaskParameters(task::Error::AggregatorAuthKeySize))?; + VdafOps::Prio3Aes128Count(vdaf, verify_key) } VdafInstance::Real(janus::task::VdafInstance::Prio3Aes128Sum { bits }) => { - let vdaf = Prio3Aes128Sum::new(2, *bits)?; - let verify_param = ::VerifyParam::get_decoded_with_param( - &vdaf, - current_vdaf_verify_parameter, - )?; - VdafOps::Prio3Aes128Sum(vdaf, verify_param) + let vdaf = Prio3::new_aes128_sum(2, *bits)?; + let verify_key = current_vdaf_verify_key + .clone() + .try_into() + .map_err(|_| Error::TaskParameters(task::Error::AggregatorAuthKeySize))?; + VdafOps::Prio3Aes128Sum(vdaf, verify_key) } VdafInstance::Real(janus::task::VdafInstance::Prio3Aes128Histogram { buckets }) => { - let vdaf = Prio3Aes128Histogram::new(2, &*buckets)?; - let verify_param = - ::VerifyParam::get_decoded_with_param( - &vdaf, - current_vdaf_verify_parameter, - )?; - VdafOps::Prio3Aes128Histogram(vdaf, verify_param) + let vdaf = Prio3::new_aes128_histogram(2, &*buckets)?; + let verify_key = current_vdaf_verify_key + .clone() + .try_into() + .map_err(|_| Error::TaskParameters(task::Error::AggregatorAuthKeySize))?; + VdafOps::Prio3Aes128Histogram(vdaf, verify_key) } #[cfg(test)] @@ -438,9 +438,10 @@ impl TaskAggregator { #[cfg(test)] VdafInstance::FakeFailsPrepStep => { + const VERIFY_KEY_LENGTH: usize = dummy_vdaf::Vdaf::VERIFY_KEY_LENGTH; VdafOps::Fake(dummy_vdaf::Vdaf::new().with_prep_step_fn( - || -> PrepareTransition<(), (), dummy_vdaf::OutputShare> { - PrepareTransition::Fail(VdafError::Uncategorized( + || -> Result, VdafError> { + Err(VdafError::Uncategorized( "FakeFailsPrepStep failed at prep_step".to_string(), )) }, @@ -621,12 +622,10 @@ impl TaskAggregator { /// VdafOps stores VDAF-specific operations for a TaskAggregator in a non-generic way. #[allow(clippy::enum_variant_names)] enum VdafOps { - Prio3Aes128Count(Prio3Aes128Count, ::VerifyParam), - Prio3Aes128Sum(Prio3Aes128Sum, ::VerifyParam), - Prio3Aes128Histogram( - Prio3Aes128Histogram, - ::VerifyParam, - ), + // For the Prio3 VdafOps, the second parameter is the verify_key. + Prio3Aes128Count(Prio3Aes128Count, [u8; PRIO3_AES128_VERIFY_KEY_LENGTH]), + Prio3Aes128Sum(Prio3Aes128Sum, [u8; PRIO3_AES128_VERIFY_KEY_LENGTH]), + Prio3Aes128Histogram(Prio3Aes128Histogram, [u8; PRIO3_AES128_VERIFY_KEY_LENGTH]), #[cfg(test)] Fake(dummy_vdaf::Vdaf), @@ -642,44 +641,39 @@ impl VdafOps { req: AggregateInitializeReq, ) -> Result { match self { - VdafOps::Prio3Aes128Count(vdaf, verify_param) => { - Self::handle_aggregate_init_generic::( - datastore, - vdaf, - task, - verify_param, - req, - ) + VdafOps::Prio3Aes128Count(vdaf, verify_key) => { + Self::handle_aggregate_init_generic::< + PRIO3_AES128_VERIFY_KEY_LENGTH, + Prio3Aes128Count, + _, + >(datastore, vdaf, task, verify_key, req) .await } - VdafOps::Prio3Aes128Sum(vdaf, verify_param) => { - Self::handle_aggregate_init_generic::( - datastore, - vdaf, - task, - verify_param, - req, - ) + VdafOps::Prio3Aes128Sum(vdaf, verify_key) => { + Self::handle_aggregate_init_generic::< + PRIO3_AES128_VERIFY_KEY_LENGTH, + Prio3Aes128Sum, + _, + >(datastore, vdaf, task, verify_key, req) .await } - VdafOps::Prio3Aes128Histogram(vdaf, verify_param) => { - Self::handle_aggregate_init_generic::( - datastore, - vdaf, - task, - verify_param, - req, - ) + VdafOps::Prio3Aes128Histogram(vdaf, verify_key) => { + Self::handle_aggregate_init_generic::< + PRIO3_AES128_VERIFY_KEY_LENGTH, + Prio3Aes128Histogram, + _, + >(datastore, vdaf, task, verify_key, req) .await } #[cfg(test)] VdafOps::Fake(vdaf) => { - Self::handle_aggregate_init_generic::( + const VERIFY_KEY_LENGTH: usize = dummy_vdaf::Vdaf::VERIFY_KEY_LENGTH; + Self::handle_aggregate_init_generic::( datastore, vdaf, task, - &(), + &[], req, ) .await @@ -694,45 +688,36 @@ impl VdafOps { req: AggregateContinueReq, ) -> Result { match self { - VdafOps::Prio3Aes128Count(vdaf, verify_param) => { - Self::handle_aggregate_continue_generic::( - datastore, - vdaf, - task, - verify_param, - req, - ) + VdafOps::Prio3Aes128Count(vdaf, _) => { + Self::handle_aggregate_continue_generic::< + PRIO3_AES128_VERIFY_KEY_LENGTH, + Prio3Aes128Count, + _, + >(datastore, vdaf, task, req) .await } - VdafOps::Prio3Aes128Sum(vdaf, verify_param) => { - Self::handle_aggregate_continue_generic::( - datastore, - vdaf, - task, - verify_param, - req, - ) + VdafOps::Prio3Aes128Sum(vdaf, _) => { + Self::handle_aggregate_continue_generic::< + PRIO3_AES128_VERIFY_KEY_LENGTH, + Prio3Aes128Sum, + _, + >(datastore, vdaf, task, req) .await } - VdafOps::Prio3Aes128Histogram(vdaf, verify_param) => { - Self::handle_aggregate_continue_generic::( - datastore, - vdaf, - task, - verify_param, - req, - ) + VdafOps::Prio3Aes128Histogram(vdaf, _) => { + Self::handle_aggregate_continue_generic::< + PRIO3_AES128_VERIFY_KEY_LENGTH, + Prio3Aes128Histogram, + _, + >(datastore, vdaf, task, req) .await } #[cfg(test)] VdafOps::Fake(vdaf) => { - Self::handle_aggregate_continue_generic::( - datastore, - vdaf, - task, - &(), - req, + const VERIFY_KEY_LENGTH: usize = dummy_vdaf::Vdaf::VERIFY_KEY_LENGTH; + Self::handle_aggregate_continue_generic::( + datastore, vdaf, task, req, ) .await } @@ -741,21 +726,21 @@ impl VdafOps { /// Implements the aggregate initialization request portion of the `/aggregate` endpoint for the /// helper, described in §4.4.4.1 of draft-gpew-priv-ppm. - async fn handle_aggregate_init_generic( + async fn handle_aggregate_init_generic, C: Clock>( datastore: &Datastore, vdaf: &A, task: &Task, - verify_param: &A::VerifyParam, + verify_key: &[u8; L], req: AggregateInitializeReq, ) -> Result where A: 'static + Send + Sync, A::AggregationParam: Send + Sync, A::AggregateShare: Send + Sync, - for<'a> >::Error: std::fmt::Display, + for<'a> >::Error: fmt::Display, for<'a> &'a A::AggregateShare: Into>, A::PrepareMessage: Send + Sync, - A::PrepareStep: Send + Sync + Encode, + A::PrepareState: Send + Sync + Encode, A::OutputShare: Send + Sync, for<'a> &'a A::OutputShare: Into>, { @@ -777,16 +762,15 @@ impl VdafOps { // Decrypt shares & prepare initialization states. (§4.4.4.1) // TODO(#221): reject reports that are "too old" with `report-dropped`. // TODO(#221): reject reports in batches that have completed an aggregate-share request with `batch-collected`. - struct ReportShareData + struct ReportShareData> where for<'a> &'a A::AggregateShare: Into>, { report_share: ReportShare, prep_result: PrepareStepResult, - agg_state: ReportAggregationState, + agg_state: ReportAggregationState, } let mut saw_continue = false; - let mut saw_finish = false; let mut report_share_data = Vec::new(); let agg_param = A::AggregationParam::get_decoded(&req.agg_param)?; for report_share in req.report_shares { @@ -827,7 +811,7 @@ impl VdafOps { // TODO(https://github.com/ietf-wg-ppm/draft-ietf-ppm-dap/issues/255): agree on/standardize // an error code for "client report data can't be decoded" & use it here. let input_share = plaintext.and_then(|plaintext| { - A::InputShare::get_decoded_with_param(verify_param, &plaintext) + A::InputShare::get_decoded_with_param(&(vdaf, Role::Helper.index().unwrap()), &plaintext) .map_err(|err| { warn!(?task_id, nonce = %report_share.nonce, %err, "Couldn't decode input share from report share"); ReportShareError::VdafPrepError @@ -837,10 +821,11 @@ impl VdafOps { // Next, the aggregator runs the preparation-state initialization algorithm for the VDAF // associated with the task and computes the first state transition. [...] If either // step fails, then the aggregator MUST fail with error `vdaf-prep-error`. (§4.4.2.2) - let step = input_share.and_then(|input_share| { + let init_rslt = input_share.and_then(|input_share| { vdaf .prepare_init( - verify_param, + verify_key, + Role::Helper.index().unwrap(), &agg_param, &report_share.nonce.get_encoded(), &input_share, @@ -850,56 +835,32 @@ impl VdafOps { ReportShareError::VdafPrepError }) }); - let prep_trans = step.map(|step| vdaf.prepare_step(step, None)); - report_share_data.push(match prep_trans { - Ok(PrepareTransition::Continue(prep_step, prep_msg)) => { + report_share_data.push(match init_rslt { + Ok((prep_state, prep_share)) => { saw_continue = true; ReportShareData { report_share, - prep_result: PrepareStepResult::Continued(prep_msg.get_encoded()), - agg_state: ReportAggregationState::::Waiting(prep_step, None), + prep_result: PrepareStepResult::Continued(prep_share.get_encoded()), + agg_state: ReportAggregationState::::Waiting(prep_state, None), } } - Ok(PrepareTransition::Finish(output_share)) => { - saw_finish = true; - ReportShareData { - report_share, - prep_result: PrepareStepResult::Finished, - agg_state: ReportAggregationState::::Finished(output_share), - } - } - - Ok(PrepareTransition::Fail(err)) => { - warn!(?task_id, nonce = %report_share.nonce, %err, "Couldn't prepare_step report share"); - ReportShareData { - report_share, - prep_result: PrepareStepResult::Failed(ReportShareError::VdafPrepError), - agg_state: ReportAggregationState::::Failed(ReportShareError::VdafPrepError), - } - }, - Err(err) => ReportShareData { report_share, prep_result: PrepareStepResult::Failed(err), - agg_state: ReportAggregationState::::Failed(err), + agg_state: ReportAggregationState::::Failed(err), }, }); } // Store data to datastore. - let aggregation_job_state = match (saw_continue, saw_finish) { - (false, false) => AggregationJobState::Finished, // everything failed, or there were no reports - (true, false) => AggregationJobState::InProgress, - (false, true) => AggregationJobState::Finished, - (true, true) => { - return Err(Error::Internal( - "VDAF took an inconsistent number of rounds to reach Finish state".to_string(), - )) - } + let aggregation_job_state = if saw_continue { + AggregationJobState::InProgress + } else { + AggregationJobState::Finished }; - let aggregation_job = Arc::new(AggregationJob:: { + let aggregation_job = Arc::new(AggregationJob:: { aggregation_job_id: req.job_id, task_id, aggregation_param: agg_param, @@ -915,7 +876,7 @@ impl VdafOps { // Write aggregation job. tx.put_aggregation_job(&aggregation_job).await?; - let mut accumulator = Accumulator::::new( + let mut accumulator = Accumulator::::new( task_id, min_batch_duration, &aggregation_job.aggregation_param, @@ -925,7 +886,7 @@ impl VdafOps { // Write client report & report aggregation. tx.put_report_share(task_id, &share_data.report_share) .await?; - tx.put_report_aggregation(&ReportAggregation:: { + tx.put_report_aggregation(&ReportAggregation:: { aggregation_job_id: req.job_id, task_id, nonce: share_data.report_share.nonce, @@ -934,7 +895,7 @@ impl VdafOps { }) .await?; - if let ReportAggregationState::::Finished(ref output_share) = + if let ReportAggregationState::::Finished(ref output_share) = share_data.agg_state { accumulator.update(output_share, share_data.report_share.nonce)?; @@ -962,29 +923,27 @@ impl VdafOps { }) } - async fn handle_aggregate_continue_generic( + async fn handle_aggregate_continue_generic, C: Clock>( datastore: &Datastore, vdaf: &A, task: &Task, - verify_param: &A::VerifyParam, req: AggregateContinueReq, ) -> Result where A: 'static + Send + Sync, A::AggregationParam: Send + Sync, A::AggregateShare: Send + Sync, - for<'a> >::Error: std::fmt::Display, + for<'a> >::Error: fmt::Display, for<'a> &'a A::AggregateShare: Into>, - A::PrepareStep: Send + Sync + Encode + ParameterizedDecode, + for<'a> A::PrepareState: Send + Sync + Encode + ParameterizedDecode<(&'a A, usize)>, + A::PrepareShare: Send + Sync, A::PrepareMessage: Send + Sync, A::OutputShare: Send + Sync + for<'a> TryFrom<&'a [u8]>, for<'a> &'a A::OutputShare: Into>, - A::VerifyParam: Send + Sync, { let task_id = task.id; let min_batch_duration = task.min_batch_duration; let vdaf = Arc::new(vdaf.clone()); - let verify_param = Arc::new(verify_param.clone()); let prep_steps = Arc::new(req.prepare_steps); // TODO(#224): don't hold DB transaction open while computing VDAF updates? @@ -995,15 +954,15 @@ impl VdafOps { Ok(datastore .run_tx(|tx| { let vdaf = Arc::clone(&vdaf); - let verify_param = Arc::clone(&verify_param); let prep_steps = Arc::clone(&prep_steps); Box::pin(async move { // Read existing state. let (aggregation_job, report_aggregations) = try_join!( - tx.get_aggregation_job::(task_id, req.job_id), - tx.get_report_aggregations_for_aggregation_job::( - &verify_param, + tx.get_aggregation_job::(task_id, req.job_id), + tx.get_report_aggregations_for_aggregation_job( + vdaf.as_ref(), + Role::Helper, task_id, req.job_id, ), @@ -1014,7 +973,7 @@ impl VdafOps { let mut report_aggregations = report_aggregations.into_iter(); let (mut saw_continue, mut saw_finish) = (false, false); let mut response_prep_steps = Vec::new(); - let mut accumulator = Accumulator::::new(task_id, min_batch_duration, &aggregation_job.aggregation_param); + let mut accumulator = Accumulator::::new(task_id, min_batch_duration, &aggregation_job.aggregation_param); for prep_step in prep_steps.iter() { // Match preparation step received from leader to stored report aggregation, @@ -1073,19 +1032,18 @@ impl VdafOps { }; // Compute the next transition, prepare to respond & update DB. - let prep_trans = vdaf.prepare_step(prep_state, Some(prep_msg)); - match prep_trans { - PrepareTransition::Continue(prep_state, prep_msg) => { + match vdaf.prepare_step(prep_state, prep_msg) { + Ok(PrepareTransition::Continue(prep_state, prep_share))=> { saw_continue = true; report_aggregation.state = ReportAggregationState::Waiting(prep_state, None); response_prep_steps.push(PrepareStep { nonce: prep_step.nonce, - result: PrepareStepResult::Continued(prep_msg.get_encoded()), + result: PrepareStepResult::Continued(prep_share.get_encoded()), }) } - PrepareTransition::Finish(output_share) => { + Ok(PrepareTransition::Finish(output_share)) => { saw_finish = true; accumulator.update(&output_share, prep_step.nonce)?; report_aggregation.state = @@ -1096,7 +1054,7 @@ impl VdafOps { }); } - PrepareTransition::Fail(err) => { + Err(err) => { warn!(?task_id, job_id = ?req.job_id, nonce = %prep_step.nonce, %err, "Prepare step failed"); report_aggregation.state = ReportAggregationState::Failed(ReportShareError::VdafPrepError); @@ -1153,43 +1111,54 @@ impl VdafOps { ) -> Result { match self { VdafOps::Prio3Aes128Count(_, _) => { - Self::handle_collect_generic::(datastore, task, collect_req) - .await + Self::handle_collect_generic::( + datastore, + task, + collect_req, + ) + .await } VdafOps::Prio3Aes128Sum(_, _) => { - Self::handle_collect_generic::(datastore, task, collect_req) - .await - } - VdafOps::Prio3Aes128Histogram(_, _) => { - Self::handle_collect_generic::( + Self::handle_collect_generic::( datastore, task, collect_req, ) .await } + VdafOps::Prio3Aes128Histogram(_, _) => { + Self::handle_collect_generic::< + PRIO3_AES128_VERIFY_KEY_LENGTH, + Prio3Aes128Histogram, + _, + >(datastore, task, collect_req) + .await + } #[cfg(test)] VdafOps::Fake(_) => { - Self::handle_collect_generic::(datastore, task, collect_req) - .await + const VERIFY_KEY_LENGTH: usize = dummy_vdaf::Vdaf::VERIFY_KEY_LENGTH; + Self::handle_collect_generic::( + datastore, + task, + collect_req, + ) + .await } } } #[tracing::instrument(skip(datastore), err)] - async fn handle_collect_generic( + async fn handle_collect_generic, C: Clock>( datastore: &Datastore, task: &Task, req: &CollectReq, ) -> Result where - A: vdaf::Aggregator, A::AggregationParam: Send + Sync, A::AggregateShare: Send + Sync, Vec: for<'a> From<&'a A::AggregateShare>, for<'a> >::Error: std::fmt::Display, - C: Clock, { // §4.5: check that the batch interval meets the requirements from §4.6 if !task.validate_batch_interval(req.batch_interval) { @@ -1212,7 +1181,7 @@ impl VdafOps { debug!(collect_request = ?req, "Cache miss, creating new collect job UUID"); let aggregation_param = A::AggregationParam::get_decoded(&req.agg_param)?; let batch_unit_aggregations = tx - .get_batch_unit_aggregations_for_task_in_interval::( + .get_batch_unit_aggregations_for_task_in_interval::( task.id, req.batch_interval, &aggregation_param, @@ -1241,33 +1210,32 @@ impl VdafOps { ) -> Result, Error> { match self { VdafOps::Prio3Aes128Count(_, _) => { - Self::handle_collect_job_generic::( - datastore, - task, - collect_job_id, - ) - .await - } - VdafOps::Prio3Aes128Sum(_, _) => { - Self::handle_collect_job_generic::( - datastore, - task, - collect_job_id, - ) + Self::handle_collect_job_generic::< + PRIO3_AES128_VERIFY_KEY_LENGTH, + Prio3Aes128Count, + _, + >(datastore, task, collect_job_id) .await } + VdafOps::Prio3Aes128Sum(_, _) => Self::handle_collect_job_generic::< + PRIO3_AES128_VERIFY_KEY_LENGTH, + Prio3Aes128Sum, + _, + >(datastore, task, collect_job_id) + .await, VdafOps::Prio3Aes128Histogram(_, _) => { - Self::handle_collect_job_generic::( - datastore, - task, - collect_job_id, - ) + Self::handle_collect_job_generic::< + PRIO3_AES128_VERIFY_KEY_LENGTH, + Prio3Aes128Histogram, + _, + >(datastore, task, collect_job_id) .await } #[cfg(test)] VdafOps::Fake(_) => { - Self::handle_collect_job_generic::( + const VERIFY_KEY_LENGTH: usize = dummy_vdaf::Vdaf::VERIFY_KEY_LENGTH; + Self::handle_collect_job_generic::( datastore, task, collect_job_id, @@ -1277,24 +1245,22 @@ impl VdafOps { } } - async fn handle_collect_job_generic( + async fn handle_collect_job_generic, C: Clock>( datastore: &Datastore, task: &Task, collect_job_id: Uuid, ) -> Result, Error> where - A: vdaf::Aggregator, A::AggregationParam: Send + Sync, A::AggregateShare: Send + Sync, Vec: for<'a> From<&'a A::AggregateShare>, for<'a> >::Error: std::fmt::Display, - C: Clock, { let task_id = task.id; let collect_job = datastore .run_tx(move |tx| { Box::pin(async move { - tx.get_collect_job::(collect_job_id) + tx.get_collect_job::(collect_job_id) .await? .ok_or_else(|| { datastore::Error::User( @@ -1370,33 +1336,34 @@ impl VdafOps { ) -> Result { match self { VdafOps::Prio3Aes128Count(_, _) => { - Self::handle_aggregate_share_generic::( - datastore, - task, - aggregate_share_req, - ) + Self::handle_aggregate_share_generic::< + PRIO3_AES128_VERIFY_KEY_LENGTH, + Prio3Aes128Count, + _, + >(datastore, task, aggregate_share_req) .await } VdafOps::Prio3Aes128Sum(_, _) => { - Self::handle_aggregate_share_generic::( - datastore, - task, - aggregate_share_req, - ) + Self::handle_aggregate_share_generic::< + PRIO3_AES128_VERIFY_KEY_LENGTH, + Prio3Aes128Sum, + _, + >(datastore, task, aggregate_share_req) .await } VdafOps::Prio3Aes128Histogram(_, _) => { - Self::handle_aggregate_share_generic::( - datastore, - task, - aggregate_share_req, - ) + Self::handle_aggregate_share_generic::< + PRIO3_AES128_VERIFY_KEY_LENGTH, + Prio3Aes128Histogram, + _, + >(datastore, task, aggregate_share_req) .await } #[cfg(test)] VdafOps::Fake(_) => { - Self::handle_aggregate_share_generic::( + const VERIFY_KEY_LENGTH: usize = dummy_vdaf::Vdaf::VERIFY_KEY_LENGTH; + Self::handle_aggregate_share_generic::( datastore, task, aggregate_share_req, @@ -1406,18 +1373,16 @@ impl VdafOps { } } - async fn handle_aggregate_share_generic( + async fn handle_aggregate_share_generic, C: Clock>( datastore: &Datastore, task: &Task, aggregate_share_req: &AggregateShareReq, ) -> Result where - A: vdaf::Aggregator, A::AggregationParam: Send + Sync, A::AggregateShare: Send + Sync, Vec: for<'a> From<&'a A::AggregateShare>, for<'a> >::Error: std::fmt::Display, - C: Clock, { let aggregate_share_job = datastore .run_tx(move |tx| { @@ -1446,7 +1411,7 @@ impl VdafOps { &aggregate_share_req.aggregation_param, )?; let batch_unit_aggregations = tx - .get_batch_unit_aggregations_for_task_in_interval::( + .get_batch_unit_aggregations_for_task_in_interval::( task.id, aggregate_share_req.batch_interval, &aggregation_param, @@ -1461,13 +1426,13 @@ impl VdafOps { .await?; let (helper_aggregate_share, report_count, checksum) = - compute_aggregate_share::(&task, &batch_unit_aggregations) + compute_aggregate_share::(&task, &batch_unit_aggregations) .await .map_err(|e| datastore::Error::User(e.into()))?; // Now that we are satisfied that the request is serviceable, we consume batch lifetime by // recording the aggregate share request parameters and the result. - let aggregate_share_job = AggregateShareJob:: { + let aggregate_share_job = AggregateShareJob:: { task_id: task.id, batch_interval: aggregate_share_req.batch_interval, aggregation_param, @@ -2030,7 +1995,7 @@ mod tests { }, trace::test_util::install_test_trace_subscriber, }; - use ::janus_test_util::{dummy_vdaf, run_vdaf, MockClock, PrepareTransition}; + use ::janus_test_util::{dummy_vdaf, run_vdaf, MockClock}; use assert_matches::assert_matches; use http::Method; use janus::{ @@ -2853,7 +2818,7 @@ mod tests { install_test_trace_subscriber(); let task_id = TaskId::random(); - let mut task = new_dummy_task( + let task = new_dummy_task( task_id, janus::task::VdafInstance::Prio3Aes128Count.into(), Role::Helper, @@ -2861,9 +2826,14 @@ mod tests { let clock = MockClock::default(); let (datastore, _db_handle) = ephemeral_datastore(clock.clone()).await; - let vdaf = Prio3Aes128Count::new(2).unwrap(); - let (public_param, verify_params) = vdaf.setup().unwrap(); - task.vdaf_verify_parameters = vec![verify_params.iter().last().unwrap().get_encoded()]; + let vdaf = Prio3::new_aes128_count(2).unwrap(); + let verify_key = task + .vdaf_verify_keys + .get(0) + .unwrap() + .clone() + .try_into() + .unwrap(); let hpke_key = current_hpke_key(&task.hpke_keys); datastore @@ -2876,7 +2846,7 @@ mod tests { // report_share_0 is a "happy path" report. let nonce_0 = Nonce::generate(&clock); - let input_share = run_vdaf(&vdaf, &public_param, &verify_params, &(), nonce_0, &0) + let input_share = run_vdaf(&vdaf, &verify_key, &(), nonce_0, &0) .input_shares .remove(1); let report_share_0 = generate_helper_report_share::( @@ -3205,7 +3175,7 @@ mod tests { let task_id = TaskId::random(); let aggregation_job_id = AggregationJobId::random(); - let mut task = new_dummy_task( + let task = new_dummy_task( task_id, janus::task::VdafInstance::Prio3Aes128Count.into(), Role::Helper, @@ -3214,17 +3184,22 @@ mod tests { let (datastore, _db_handle) = ephemeral_datastore(clock.clone()).await; let datastore = Arc::new(datastore); - let vdaf = Prio3Aes128Count::new(2).unwrap(); - let (public_param, verify_params) = vdaf.setup().unwrap(); - task.vdaf_verify_parameters = vec![verify_params.iter().last().unwrap().get_encoded()]; + let vdaf = Arc::new(Prio3::new_aes128_count(2).unwrap()); + let verify_key = task + .vdaf_verify_keys + .get(0) + .unwrap() + .clone() + .try_into() + .unwrap(); let hpke_key = current_hpke_key(&task.hpke_keys); // report_share_0 is a "happy path" report. let nonce_0 = Nonce::generate(&clock); - let transcript_0 = run_vdaf(&vdaf, &public_param, &verify_params, &(), nonce_0, &0); - let prep_step_0 = assert_matches!(&transcript_0.transitions[1][0], PrepareTransition::::Continue(prep_step, _) => prep_step.clone()); - let out_share_0 = assert_matches!(&transcript_0.transitions[1][1], PrepareTransition::::Finish(out_share) => out_share.clone()); - let prep_msg_0 = transcript_0.combined_messages[0].clone(); + let transcript_0 = run_vdaf(vdaf.as_ref(), &verify_key, &(), nonce_0, &0); + let prep_state_0 = assert_matches!(&transcript_0.prepare_transitions[1][0], PrepareTransition::::Continue(prep_state, _) => prep_state.clone()); + let out_share_0 = assert_matches!(&transcript_0.prepare_transitions[1][1], PrepareTransition::::Finish(out_share) => out_share.clone()); + let prep_msg_0 = transcript_0.prepare_messages[0].clone(); let report_share_0 = generate_helper_report_share::( task_id, nonce_0, @@ -3234,8 +3209,8 @@ mod tests { // report_share_1 is omitted by the leader's request. let nonce_1 = Nonce::generate(&clock); - let transcript_1 = run_vdaf(&vdaf, &(), &verify_params, &(), nonce_1, &0); - let prep_step_1 = assert_matches!(&transcript_1.transitions[1][0], PrepareTransition::::Continue(prep_step, _) => prep_step.clone()); + let transcript_1 = run_vdaf(vdaf.as_ref(), &verify_key, &(), nonce_1, &0); + let prep_state_1 = assert_matches!(&transcript_1.prepare_transitions[1][0], PrepareTransition::::Continue(prep_state, _) => prep_state.clone()); let report_share_1 = generate_helper_report_share::( task_id, nonce_1, @@ -3248,7 +3223,7 @@ mod tests { let task = task.clone(); let (report_share_0, report_share_1) = (report_share_0.clone(), report_share_1.clone()); - let (prep_step_0, prep_step_1) = (prep_step_0.clone(), prep_step_1.clone()); + let (prep_state_0, prep_state_1) = (prep_state_0.clone(), prep_state_1.clone()); Box::pin(async move { tx.put_task(&task).await?; @@ -3256,7 +3231,10 @@ mod tests { tx.put_report_share(task_id, &report_share_0).await?; tx.put_report_share(task_id, &report_share_1).await?; - tx.put_aggregation_job(&AggregationJob:: { + tx.put_aggregation_job(&AggregationJob::< + PRIO3_AES128_VERIFY_KEY_LENGTH, + Prio3Aes128Count, + > { aggregation_job_id, task_id, aggregation_param: (), @@ -3264,20 +3242,26 @@ mod tests { }) .await?; - tx.put_report_aggregation(&ReportAggregation:: { + tx.put_report_aggregation(&ReportAggregation::< + PRIO3_AES128_VERIFY_KEY_LENGTH, + Prio3Aes128Count, + > { aggregation_job_id, task_id, nonce: nonce_0, ord: 0, - state: ReportAggregationState::Waiting(prep_step_0, None), + state: ReportAggregationState::Waiting(prep_state_0, None), }) .await?; - tx.put_report_aggregation(&ReportAggregation:: { + tx.put_report_aggregation(&ReportAggregation::< + PRIO3_AES128_VERIFY_KEY_LENGTH, + Prio3Aes128Count, + > { aggregation_job_id, task_id, nonce: nonce_1, ord: 1, - state: ReportAggregationState::Waiting(prep_step_1, None), + state: ReportAggregationState::Waiting(prep_state_1, None), }) .await?; @@ -3335,20 +3319,23 @@ mod tests { // Validate datastore. let (aggregation_job, report_aggregations) = datastore .run_tx(|tx| { - let verify_params = verify_params.clone(); + let vdaf = Arc::clone(&vdaf); Box::pin(async move { let aggregation_job = tx - .get_aggregation_job::(task_id, aggregation_job_id) + .get_aggregation_job::( + task_id, + aggregation_job_id, + ) .await?; let report_aggregations = tx - .get_report_aggregations_for_aggregation_job::( - &verify_params[1].clone(), + .get_report_aggregations_for_aggregation_job( + vdaf.as_ref(), + Role::Helper, task_id, aggregation_job_id, ) .await?; - Ok((aggregation_job, report_aggregations)) }) }) @@ -3392,7 +3379,7 @@ mod tests { let task_id = TaskId::random(); let aggregation_job_id_0 = AggregationJobId::random(); let aggregation_job_id_1 = AggregationJobId::random(); - let mut task = new_dummy_task( + let task = new_dummy_task( task_id, janus::task::VdafInstance::Prio3Aes128Count.into(), Role::Helper, @@ -3407,17 +3394,22 @@ mod tests { .unwrap(), ); - let vdaf = Prio3Aes128Count::new(2).unwrap(); - let (_, verify_params) = vdaf.setup().unwrap(); - task.vdaf_verify_parameters = vec![verify_params.iter().last().unwrap().get_encoded()]; + let vdaf = Prio3::new_aes128_count(2).unwrap(); + let verify_key = task + .vdaf_verify_keys + .get(0) + .unwrap() + .clone() + .try_into() + .unwrap(); let hpke_key = current_hpke_key(&task.hpke_keys); // report_share_0 is a "happy path" report. let nonce_0 = Nonce::generate(&first_batch_unit_interval_clock); - let transcript_0 = run_vdaf(&vdaf, &(), &verify_params, &(), nonce_0, &0); - let prep_step_0 = assert_matches!(&transcript_0.transitions[1][0], PrepareTransition::::Continue(prep_step, _) => prep_step.clone()); - let out_share_0 = assert_matches!(&transcript_0.transitions[1][1], PrepareTransition::::Finish(out_share) => out_share.clone()); - let prep_msg_0 = transcript_0.combined_messages[0].clone(); + let transcript_0 = run_vdaf(&vdaf, &verify_key, &(), nonce_0, &0); + let prep_state_0 = assert_matches!(&transcript_0.prepare_transitions[1][0], PrepareTransition::::Continue(prep_state, _) => prep_state.clone()); + let out_share_0 = assert_matches!(&transcript_0.prepare_transitions[1][1], PrepareTransition::::Finish(out_share) => out_share.clone()); + let prep_msg_0 = transcript_0.prepare_messages[0].clone(); let report_share_0 = generate_helper_report_share::( task_id, nonce_0, @@ -3428,10 +3420,10 @@ mod tests { // report_share_1 is another "happy path" report to exercise in-memory accumulation of // output shares let nonce_1 = Nonce::generate(&first_batch_unit_interval_clock); - let transcript_1 = run_vdaf(&vdaf, &(), &verify_params, &(), nonce_1, &0); - let prep_step_1 = assert_matches!(&transcript_1.transitions[1][0], PrepareTransition::::Continue(prep_step, _) => prep_step.clone()); - let out_share_1 = assert_matches!(&transcript_1.transitions[1][1], PrepareTransition::::Finish(out_share) => out_share.clone()); - let prep_msg_1 = transcript_1.combined_messages[0].clone(); + let transcript_1 = run_vdaf(&vdaf, &verify_key, &(), nonce_1, &0); + let prep_state_1 = assert_matches!(&transcript_1.prepare_transitions[1][0], PrepareTransition::::Continue(prep_state, _) => prep_state.clone()); + let out_share_1 = assert_matches!(&transcript_1.prepare_transitions[1][1], PrepareTransition::::Finish(out_share) => out_share.clone()); + let prep_msg_1 = transcript_1.prepare_messages[0].clone(); let report_share_1 = generate_helper_report_share::( task_id, nonce_1, @@ -3441,10 +3433,10 @@ mod tests { // report share 2 aggregates successfully, but into a distinct batch unit aggregation. let nonce_2 = Nonce::generate(&second_batch_unit_interval_clock); - let transcript_2 = run_vdaf(&vdaf, &(), &verify_params, &(), nonce_2, &0); - let prep_step_2 = assert_matches!(&transcript_2.transitions[1][0], PrepareTransition::::Continue(prep_step, _) => prep_step.clone()); - let out_share_2 = assert_matches!(&transcript_2.transitions[1][1], PrepareTransition::::Finish(out_share) => out_share.clone()); - let prep_msg_2 = transcript_2.combined_messages[0].clone(); + let transcript_2 = run_vdaf(&vdaf, &verify_key, &(), nonce_2, &0); + let prep_state_2 = assert_matches!(&transcript_2.prepare_transitions[1][0], PrepareTransition::::Continue(prep_state, _) => prep_state.clone()); + let out_share_2 = assert_matches!(&transcript_2.prepare_transitions[1][1], PrepareTransition::::Finish(out_share) => out_share.clone()); + let prep_msg_2 = transcript_2.prepare_messages[0].clone(); let report_share_2 = generate_helper_report_share::( task_id, nonce_2, @@ -3460,10 +3452,10 @@ mod tests { report_share_1.clone(), report_share_2.clone(), ); - let (prep_step_0, prep_step_1, prep_step_2) = ( - prep_step_0.clone(), - prep_step_1.clone(), - prep_step_2.clone(), + let (prep_state_0, prep_state_1, prep_state_2) = ( + prep_state_0.clone(), + prep_state_1.clone(), + prep_state_2.clone(), ); Box::pin(async move { @@ -3473,7 +3465,10 @@ mod tests { tx.put_report_share(task_id, &report_share_1).await?; tx.put_report_share(task_id, &report_share_2).await?; - tx.put_aggregation_job(&AggregationJob:: { + tx.put_aggregation_job(&AggregationJob::< + PRIO3_AES128_VERIFY_KEY_LENGTH, + Prio3Aes128Count, + > { aggregation_job_id: aggregation_job_id_0, task_id, aggregation_param: (), @@ -3481,28 +3476,37 @@ mod tests { }) .await?; - tx.put_report_aggregation(&ReportAggregation:: { + tx.put_report_aggregation(&ReportAggregation::< + PRIO3_AES128_VERIFY_KEY_LENGTH, + Prio3Aes128Count, + > { aggregation_job_id: aggregation_job_id_0, task_id, nonce: nonce_0, ord: 0, - state: ReportAggregationState::Waiting(prep_step_0, None), + state: ReportAggregationState::Waiting(prep_state_0, None), }) .await?; - tx.put_report_aggregation(&ReportAggregation:: { + tx.put_report_aggregation(&ReportAggregation::< + PRIO3_AES128_VERIFY_KEY_LENGTH, + Prio3Aes128Count, + > { aggregation_job_id: aggregation_job_id_0, task_id, nonce: nonce_1, ord: 1, - state: ReportAggregationState::Waiting(prep_step_1, None), + state: ReportAggregationState::Waiting(prep_state_1, None), }) .await?; - tx.put_report_aggregation(&ReportAggregation:: { + tx.put_report_aggregation(&ReportAggregation::< + PRIO3_AES128_VERIFY_KEY_LENGTH, + Prio3Aes128Count, + > { aggregation_job_id: aggregation_job_id_0, task_id, nonce: nonce_2, ord: 2, - state: ReportAggregationState::Waiting(prep_step_2, None), + state: ReportAggregationState::Waiting(prep_state_2, None), }) .await?; @@ -3557,7 +3561,7 @@ mod tests { let batch_unit_aggregations = datastore .run_tx(|tx| { Box::pin(async move { - tx.get_batch_unit_aggregations_for_task_in_interval::( + tx.get_batch_unit_aggregations_for_task_in_interval::( task_id, Interval::new( nonce_0 @@ -3585,7 +3589,7 @@ mod tests { assert_eq!( batch_unit_aggregations, vec![ - BatchUnitAggregation:: { + BatchUnitAggregation:: { task_id, unit_interval_start: nonce_0 .time() @@ -3596,7 +3600,7 @@ mod tests { report_count: 2, checksum, }, - BatchUnitAggregation:: { + BatchUnitAggregation:: { task_id, unit_interval_start: nonce_2 .time() @@ -3614,10 +3618,10 @@ mod tests { // batch_unit_aggregations rows created earlier. // report_share_3 gets aggreated into the first batch unit interval. let nonce_3 = Nonce::generate(&first_batch_unit_interval_clock); - let transcript_3 = run_vdaf(&vdaf, &(), &verify_params, &(), nonce_3, &0); - let prep_step_3 = assert_matches!(&transcript_3.transitions[1][0], PrepareTransition::::Continue(prep_step, _) => prep_step.clone()); - let out_share_3 = assert_matches!(&transcript_3.transitions[1][1], PrepareTransition::::Finish(out_share) => out_share.clone()); - let prep_msg_3 = transcript_3.combined_messages[0].clone(); + let transcript_3 = run_vdaf(&vdaf, &verify_key, &(), nonce_3, &0); + let prep_state_3 = assert_matches!(&transcript_3.prepare_transitions[1][0], PrepareTransition::::Continue(prep_state, _) => prep_state.clone()); + let out_share_3 = assert_matches!(&transcript_3.prepare_transitions[1][1], PrepareTransition::::Finish(out_share) => out_share.clone()); + let prep_msg_3 = transcript_3.prepare_messages[0].clone(); let report_share_3 = generate_helper_report_share::( task_id, nonce_3, @@ -3627,10 +3631,10 @@ mod tests { // report_share_4 gets aggregated into the second batch unit interval let nonce_4 = Nonce::generate(&second_batch_unit_interval_clock); - let transcript_4 = run_vdaf(&vdaf, &(), &verify_params, &(), nonce_4, &0); - let prep_step_4 = assert_matches!(&transcript_4.transitions[1][0], PrepareTransition::::Continue(prep_step, _) => prep_step.clone()); - let out_share_4 = assert_matches!(&transcript_4.transitions[1][1], PrepareTransition::::Finish(out_share) => out_share.clone()); - let prep_msg_4 = transcript_4.combined_messages[0].clone(); + let transcript_4 = run_vdaf(&vdaf, &verify_key, &(), nonce_4, &0); + let prep_state_4 = assert_matches!(&transcript_4.prepare_transitions[1][0], PrepareTransition::::Continue(prep_state, _) => prep_state.clone()); + let out_share_4 = assert_matches!(&transcript_4.prepare_transitions[1][1], PrepareTransition::::Finish(out_share) => out_share.clone()); + let prep_msg_4 = transcript_4.prepare_messages[0].clone(); let report_share_4 = generate_helper_report_share::( task_id, nonce_4, @@ -3640,10 +3644,10 @@ mod tests { // report share 5 also gets aggregated into the second batch unit interval let nonce_5 = Nonce::generate(&second_batch_unit_interval_clock); - let transcript_5 = run_vdaf(&vdaf, &(), &verify_params, &(), nonce_5, &0); - let prep_step_5 = assert_matches!(&transcript_5.transitions[1][0], PrepareTransition::::Continue(prep_step, _) => prep_step.clone()); - let out_share_5 = assert_matches!(&transcript_5.transitions[1][1], PrepareTransition::::Finish(out_share) => out_share.clone()); - let prep_msg_5 = transcript_5.combined_messages[0].clone(); + let transcript_5 = run_vdaf(&vdaf, &verify_key, &(), nonce_5, &0); + let prep_state_5 = assert_matches!(&transcript_5.prepare_transitions[1][0], PrepareTransition::::Continue(prep_state, _) => prep_state.clone()); + let out_share_5 = assert_matches!(&transcript_5.prepare_transitions[1][1], PrepareTransition::::Finish(out_share) => out_share.clone()); + let prep_msg_5 = transcript_5.prepare_messages[0].clone(); let report_share_5 = generate_helper_report_share::( task_id, nonce_5, @@ -3658,10 +3662,10 @@ mod tests { report_share_4.clone(), report_share_5.clone(), ); - let (prep_step_3, prep_step_4, prep_step_5) = ( - prep_step_3.clone(), - prep_step_4.clone(), - prep_step_5.clone(), + let (prep_state_3, prep_state_4, prep_state_5) = ( + prep_state_3.clone(), + prep_state_4.clone(), + prep_state_5.clone(), ); Box::pin(async move { @@ -3669,7 +3673,10 @@ mod tests { tx.put_report_share(task_id, &report_share_4).await?; tx.put_report_share(task_id, &report_share_5).await?; - tx.put_aggregation_job(&AggregationJob:: { + tx.put_aggregation_job(&AggregationJob::< + PRIO3_AES128_VERIFY_KEY_LENGTH, + Prio3Aes128Count, + > { aggregation_job_id: aggregation_job_id_1, task_id, aggregation_param: (), @@ -3677,28 +3684,37 @@ mod tests { }) .await?; - tx.put_report_aggregation(&ReportAggregation:: { + tx.put_report_aggregation(&ReportAggregation::< + PRIO3_AES128_VERIFY_KEY_LENGTH, + Prio3Aes128Count, + > { aggregation_job_id: aggregation_job_id_1, task_id, nonce: nonce_3, ord: 3, - state: ReportAggregationState::Waiting(prep_step_3, None), + state: ReportAggregationState::Waiting(prep_state_3, None), }) .await?; - tx.put_report_aggregation(&ReportAggregation:: { + tx.put_report_aggregation(&ReportAggregation::< + PRIO3_AES128_VERIFY_KEY_LENGTH, + Prio3Aes128Count, + > { aggregation_job_id: aggregation_job_id_1, task_id, nonce: nonce_4, ord: 4, - state: ReportAggregationState::Waiting(prep_step_4, None), + state: ReportAggregationState::Waiting(prep_state_4, None), }) .await?; - tx.put_report_aggregation(&ReportAggregation:: { + tx.put_report_aggregation(&ReportAggregation::< + PRIO3_AES128_VERIFY_KEY_LENGTH, + Prio3Aes128Count, + > { aggregation_job_id: aggregation_job_id_1, task_id, nonce: nonce_5, ord: 5, - state: ReportAggregationState::Waiting(prep_step_5, None), + state: ReportAggregationState::Waiting(prep_state_5, None), }) .await?; @@ -3752,7 +3768,7 @@ mod tests { let batch_unit_aggregations = datastore .run_tx(|tx| { Box::pin(async move { - tx.get_batch_unit_aggregations_for_task_in_interval::( + tx.get_batch_unit_aggregations_for_task_in_interval::( task_id, Interval::new( nonce_0 @@ -3788,7 +3804,7 @@ mod tests { assert_eq!( batch_unit_aggregations, vec![ - BatchUnitAggregation:: { + BatchUnitAggregation:: { task_id, unit_interval_start: nonce_0 .time() @@ -3799,7 +3815,7 @@ mod tests { report_count: 3, checksum: first_checksum, }, - BatchUnitAggregation:: { + BatchUnitAggregation:: { task_id, unit_interval_start: nonce_2 .time() @@ -3831,6 +3847,8 @@ mod tests { let (datastore, _db_handle) = ephemeral_datastore(clock.clone()).await; let datastore = Arc::new(datastore); + const VERIFY_KEY_LENGTH: usize = dummy_vdaf::Vdaf::VERIFY_KEY_LENGTH; + // Setup datastore. datastore .run_tx(|tx| { @@ -3851,14 +3869,19 @@ mod tests { }, ) .await?; - tx.put_aggregation_job(&AggregationJob:: { - aggregation_job_id, - task_id, - aggregation_param: (), - state: AggregationJobState::InProgress, - }) + tx.put_aggregation_job( + &AggregationJob:: { + aggregation_job_id, + task_id, + aggregation_param: (), + state: AggregationJobState::InProgress, + }, + ) .await?; - tx.put_report_aggregation(&ReportAggregation:: { + tx.put_report_aggregation(&ReportAggregation::< + VERIFY_KEY_LENGTH, + dummy_vdaf::Vdaf, + > { aggregation_job_id, task_id, nonce, @@ -3932,6 +3955,8 @@ mod tests { let (datastore, _db_handle) = ephemeral_datastore(clock.clone()).await; let datastore = Arc::new(datastore); + const VERIFY_KEY_LENGTH: usize = dummy_vdaf::Vdaf::VERIFY_KEY_LENGTH; + // Setup datastore. datastore .run_tx(|tx| { @@ -3952,14 +3977,19 @@ mod tests { }, ) .await?; - tx.put_aggregation_job(&AggregationJob:: { - aggregation_job_id, - task_id, - aggregation_param: (), - state: AggregationJobState::InProgress, - }) + tx.put_aggregation_job( + &AggregationJob:: { + aggregation_job_id, + task_id, + aggregation_param: (), + state: AggregationJobState::InProgress, + }, + ) .await?; - tx.put_report_aggregation(&ReportAggregation:: { + tx.put_report_aggregation(&ReportAggregation::< + VERIFY_KEY_LENGTH, + dummy_vdaf::Vdaf, + > { aggregation_job_id, task_id, nonce, @@ -4023,11 +4053,15 @@ mod tests { .run_tx(|tx| { Box::pin(async move { let aggregation_job = tx - .get_aggregation_job::(task_id, aggregation_job_id) + .get_aggregation_job::( + task_id, + aggregation_job_id, + ) .await?; let report_aggregation = tx - .get_report_aggregation::( - &(), + .get_report_aggregation( + &dummy_vdaf::Vdaf::default(), + Role::Helper, task_id, aggregation_job_id, nonce, @@ -4076,6 +4110,8 @@ mod tests { let clock = MockClock::default(); let (datastore, _db_handle) = ephemeral_datastore(clock.clone()).await; + const VERIFY_KEY_LENGTH: usize = dummy_vdaf::Vdaf::VERIFY_KEY_LENGTH; + // Setup datastore. datastore .run_tx(|tx| { @@ -4096,14 +4132,19 @@ mod tests { }, ) .await?; - tx.put_aggregation_job(&AggregationJob:: { - aggregation_job_id, - task_id, - aggregation_param: (), - state: AggregationJobState::InProgress, - }) + tx.put_aggregation_job( + &AggregationJob:: { + aggregation_job_id, + task_id, + aggregation_param: (), + state: AggregationJobState::InProgress, + }, + ) .await?; - tx.put_report_aggregation(&ReportAggregation:: { + tx.put_report_aggregation(&ReportAggregation::< + VERIFY_KEY_LENGTH, + dummy_vdaf::Vdaf, + > { aggregation_job_id, task_id, nonce, @@ -4184,6 +4225,8 @@ mod tests { let clock = MockClock::default(); let (datastore, _db_handle) = ephemeral_datastore(clock.clone()).await; + const VERIFY_KEY_LENGTH: usize = dummy_vdaf::Vdaf::VERIFY_KEY_LENGTH; + // Setup datastore. datastore .run_tx(|tx| { @@ -4219,15 +4262,20 @@ mod tests { ) .await?; - tx.put_aggregation_job(&AggregationJob:: { - aggregation_job_id, - task_id, - aggregation_param: (), - state: AggregationJobState::InProgress, - }) + tx.put_aggregation_job( + &AggregationJob:: { + aggregation_job_id, + task_id, + aggregation_param: (), + state: AggregationJobState::InProgress, + }, + ) .await?; - tx.put_report_aggregation(&ReportAggregation:: { + tx.put_report_aggregation(&ReportAggregation::< + VERIFY_KEY_LENGTH, + dummy_vdaf::Vdaf, + > { aggregation_job_id, task_id, nonce: nonce_0, @@ -4235,7 +4283,10 @@ mod tests { state: ReportAggregationState::Waiting((), None), }) .await?; - tx.put_report_aggregation(&ReportAggregation:: { + tx.put_report_aggregation(&ReportAggregation::< + VERIFY_KEY_LENGTH, + dummy_vdaf::Vdaf, + > { aggregation_job_id, task_id, nonce: nonce_1, @@ -4316,6 +4367,8 @@ mod tests { let clock = MockClock::default(); let (datastore, _db_handle) = ephemeral_datastore(clock.clone()).await; + const VERIFY_KEY_LENGTH: usize = dummy_vdaf::Vdaf::VERIFY_KEY_LENGTH; + // Setup datastore. datastore .run_tx(|tx| { @@ -4336,14 +4389,19 @@ mod tests { }, ) .await?; - tx.put_aggregation_job(&AggregationJob:: { - aggregation_job_id, - task_id, - aggregation_param: (), - state: AggregationJobState::InProgress, - }) + tx.put_aggregation_job( + &AggregationJob:: { + aggregation_job_id, + task_id, + aggregation_param: (), + state: AggregationJobState::InProgress, + }, + ) .await?; - tx.put_report_aggregation(&ReportAggregation:: { + tx.put_report_aggregation(&ReportAggregation::< + VERIFY_KEY_LENGTH, + dummy_vdaf::Vdaf, + > { aggregation_job_id, task_id, nonce, @@ -4616,7 +4674,7 @@ mod tests { ) .unwrap(); - tx.update_collect_job::( + tx.update_collect_job::( collect_job_id, &leader_aggregate_share, &encrypted_helper_aggregate_share, @@ -4703,6 +4761,8 @@ mod tests { let (datastore, _db_handle) = ephemeral_datastore(MockClock::default()).await; + const VERIFY_KEY_LENGTH: usize = dummy_vdaf::Vdaf::VERIFY_KEY_LENGTH; + datastore .run_tx(|tx| { let task = task.clone(); @@ -4710,7 +4770,10 @@ mod tests { Box::pin(async move { tx.put_task(&task).await?; - tx.put_batch_unit_aggregation(&BatchUnitAggregation:: { + tx.put_batch_unit_aggregation(&BatchUnitAggregation::< + VERIFY_KEY_LENGTH, + dummy_vdaf::Vdaf, + > { task_id: task.id, unit_interval_start: Time::from_seconds_since_epoch(0), aggregation_param: (), @@ -4995,7 +5058,10 @@ mod tests { datastore .run_tx(|tx| { Box::pin(async move { - tx.put_batch_unit_aggregation(&BatchUnitAggregation:: { + tx.put_batch_unit_aggregation(&BatchUnitAggregation::< + PRIO3_AES128_VERIFY_KEY_LENGTH, + Prio3Aes128Count, + > { task_id, unit_interval_start: Time::from_seconds_since_epoch(500), aggregation_param, @@ -5005,7 +5071,10 @@ mod tests { }) .await?; - tx.put_batch_unit_aggregation(&BatchUnitAggregation:: { + tx.put_batch_unit_aggregation(&BatchUnitAggregation::< + PRIO3_AES128_VERIFY_KEY_LENGTH, + Prio3Aes128Count, + > { task_id, unit_interval_start: Time::from_seconds_since_epoch(1500), aggregation_param, @@ -5015,7 +5084,10 @@ mod tests { }) .await?; - tx.put_batch_unit_aggregation(&BatchUnitAggregation:: { + tx.put_batch_unit_aggregation(&BatchUnitAggregation::< + PRIO3_AES128_VERIFY_KEY_LENGTH, + Prio3Aes128Count, + > { task_id, unit_interval_start: Time::from_seconds_since_epoch(2000), aggregation_param, @@ -5025,7 +5097,10 @@ mod tests { }) .await?; - tx.put_batch_unit_aggregation(&BatchUnitAggregation:: { + tx.put_batch_unit_aggregation(&BatchUnitAggregation::< + PRIO3_AES128_VERIFY_KEY_LENGTH, + Prio3Aes128Count, + > { task_id, unit_interval_start: Time::from_seconds_since_epoch(2500), aggregation_param, diff --git a/janus_server/src/aggregator/accumulator.rs b/janus_server/src/aggregator/accumulator.rs index 73639a454..2c10cf495 100644 --- a/janus_server/src/aggregator/accumulator.rs +++ b/janus_server/src/aggregator/accumulator.rs @@ -12,7 +12,7 @@ use std::collections::HashMap; use tracing::debug; #[derive(Debug)] -struct Accumulation +struct Accumulation> where for<'a> &'a A::AggregateShare: Into>, { @@ -21,7 +21,7 @@ where checksum: NonceChecksum, } -impl Accumulation +impl> Accumulation where for<'a> &'a A::AggregateShare: Into>, { @@ -40,7 +40,7 @@ where /// batch unit interval begins to the accumulated aggregate share, report count and checksum. #[derive(Derivative)] #[derivative(Debug)] -pub(super) struct Accumulator +pub(super) struct Accumulator> where for<'a> &'a A::AggregateShare: Into>, { @@ -48,10 +48,10 @@ where min_batch_duration: Duration, #[derivative(Debug = "ignore")] aggregation_param: A::AggregationParam, - accumulations: HashMap>, + accumulations: HashMap>, } -impl Accumulator +impl> Accumulator where for<'a> &'a A::AggregateShare: Into>, for<'a> >::Error: std::fmt::Display, @@ -110,7 +110,7 @@ where let unit_interval = Interval::new(unit_interval_start, self.min_batch_duration)?; let mut batch_unit_aggregations = tx - .get_batch_unit_aggregations_for_task_in_interval::( + .get_batch_unit_aggregations_for_task_in_interval::( self.task_id, unit_interval, &self.aggregation_param, @@ -145,7 +145,7 @@ where unit_interval_start = ?unit_interval.start(), "inserting new batch_unit_aggregation row", ); - tx.put_batch_unit_aggregation::(&BatchUnitAggregation { + tx.put_batch_unit_aggregation::(&BatchUnitAggregation { task_id: self.task_id, unit_interval_start: unit_interval.start(), aggregation_param: self.aggregation_param.clone(), diff --git a/janus_server/src/aggregator/aggregate_share.rs b/janus_server/src/aggregator/aggregate_share.rs index f3f2c1c4e..dbb268b56 100644 --- a/janus_server/src/aggregator/aggregate_share.rs +++ b/janus_server/src/aggregator/aggregate_share.rs @@ -9,7 +9,7 @@ use crate::{ Datastore, Transaction, }, message::{AggregateShareReq, AggregateShareResp}, - task::Task, + task::{Task, PRIO3_AES128_VERIFY_KEY_LENGTH}, task::{VdafInstance, DAP_AUTH_HEADER}, }; use futures::try_join; @@ -60,7 +60,7 @@ impl CollectJobDriver { ) -> Result<(), Error> { match lease.leased().vdaf { VdafInstance::Real(janus::task::VdafInstance::Prio3Aes128Count) => { - self.step_collect_job_generic::( + self.step_collect_job_generic::( datastore, lease ) @@ -68,7 +68,7 @@ impl CollectJobDriver { } VdafInstance::Real(janus::task::VdafInstance::Prio3Aes128Sum { .. }) => { - self.step_collect_job_generic::( + self.step_collect_job_generic::( datastore, lease ) @@ -76,7 +76,7 @@ impl CollectJobDriver { } VdafInstance::Real(janus::task::VdafInstance::Prio3Aes128Histogram { .. }) => { - self.step_collect_job_generic::( + self.step_collect_job_generic::( datastore, lease, ) @@ -85,7 +85,9 @@ impl CollectJobDriver { #[cfg(test)] VdafInstance::Fake => { - self.step_collect_job_generic::>( + type FakeVdaf = janus_test_util::dummy_vdaf::VdafWithAggregationParameter; + const VERIFY_KEY_LENGTH: usize = FakeVdaf::VERIFY_KEY_LENGTH; + self.step_collect_job_generic::( datastore, lease, ) @@ -97,14 +99,13 @@ impl CollectJobDriver { } #[tracing::instrument(skip(self, datastore), err)] - async fn step_collect_job_generic( + async fn step_collect_job_generic>( &self, datastore: Arc>, lease: Lease, ) -> Result<(), Error> where - C: Clock, - A: vdaf::Aggregator + 'static, + A: 'static, A::AggregationParam: Send + Sync, A::AggregateShare: 'static + Send + Sync, for<'a> &'a A::AggregateShare: Into>, @@ -124,17 +125,17 @@ impl CollectJobDriver { datastore::Error::User(Error::UnrecognizedTask(task_id).into()) })?; - let collect_job = - tx.get_collect_job::(collect_job_id) - .await? - .ok_or_else(|| { - datastore::Error::User( - Error::UnrecognizedCollectJob(collect_job_id).into(), - ) - })?; + let collect_job = tx + .get_collect_job::(collect_job_id) + .await? + .ok_or_else(|| { + datastore::Error::User( + Error::UnrecognizedCollectJob(collect_job_id).into(), + ) + })?; let batch_unit_aggregations = tx - .get_batch_unit_aggregations_for_task_in_interval::( + .get_batch_unit_aggregations_for_task_in_interval::( task.id, collect_job.batch_interval, &collect_job.aggregation_param, @@ -152,7 +153,7 @@ impl CollectJobDriver { } let (leader_aggregate_share, report_count, checksum) = - compute_aggregate_share::(&task, &batch_unit_aggregations) + compute_aggregate_share::(&task, &batch_unit_aggregations) .await .map_err(|e| datastore::Error::User(e.into()))?; @@ -194,7 +195,7 @@ impl CollectJobDriver { Arc::clone(&encrypted_helper_aggregate_share); let lease = Arc::clone(&lease); Box::pin(async move { - tx.update_collect_job::( + tx.update_collect_job::( collect_job_id, &leader_aggregate_share, &encrypted_helper_aggregate_share, @@ -237,12 +238,11 @@ impl CollectJobDriver { /// been driven to completion, and that the batch lifetime requirements have been validated for the /// included batch units. #[tracing::instrument(err)] -pub(crate) async fn compute_aggregate_share( +pub(crate) async fn compute_aggregate_share>( task: &Task, - batch_unit_aggregations: &[BatchUnitAggregation], + batch_unit_aggregations: &[BatchUnitAggregation], ) -> Result<(A::AggregateShare, u64, NonceChecksum), Error> where - A: vdaf::Aggregator, Vec: for<'a> From<&'a A::AggregateShare>, for<'a> >::Error: std::fmt::Display, { @@ -301,16 +301,18 @@ where /// Check whether any member of `batch_unit_aggregations` has been included in enough collect /// jobs (for `task.role` == [`Role::Leader`]) or aggregate share jobs (for `task.role` == /// [`Role::Helper`]) to violate the task's maximum batch lifetime. -pub(crate) async fn validate_batch_lifetime_for_unit_aggregations( +pub(crate) async fn validate_batch_lifetime_for_unit_aggregations< + const L: usize, + C: Clock, + A: vdaf::Aggregator, +>( tx: &Transaction<'_, C>, task: &Task, - batch_unit_aggregations: &[BatchUnitAggregation], + batch_unit_aggregations: &[BatchUnitAggregation], ) -> Result<(), datastore::Error> where - A: vdaf::Aggregator, Vec: for<'a> From<&'a A::AggregateShare>, for<'a> >::Error: std::fmt::Display, - C: Clock, { // Check how many rows in the relevant table have a batch interval that includes each batch // unit. Each such row consumes one unit of batch lifetime (§4.6). @@ -402,6 +404,9 @@ mod tests { let (ds, _db_handle) = ephemeral_datastore(clock.clone()).await; let ds = Arc::new(ds); + type FakeVdaf = VdafWithAggregationParameter; + const VERIFY_KEY_LENGTH: usize = FakeVdaf::VERIFY_KEY_LENGTH; + let task_id = TaskId::random(); let mut task = new_dummy_task(task_id, VdafInstance::Fake, Role::Leader); task.aggregator_endpoints = vec![ @@ -426,7 +431,7 @@ mod tests { .await?; let aggregation_job_id = AggregationJobId::random(); - tx.put_aggregation_job(&AggregationJob::> { + tx.put_aggregation_job(&AggregationJob:: { aggregation_job_id, task_id, aggregation_param, @@ -438,9 +443,7 @@ mod tests { tx.put_client_report(&Report::new(task_id, nonce, Vec::new(), Vec::new())) .await?; - tx.put_report_aggregation(&ReportAggregation::< - VdafWithAggregationParameter, - > { + tx.put_report_aggregation(&ReportAggregation:: { aggregation_job_id, task_id, nonce, @@ -478,24 +481,28 @@ mod tests { ds.run_tx(|tx| { let clock = clock.clone(); Box::pin(async move { - tx.put_batch_unit_aggregation(&BatchUnitAggregation:: { - task_id, - unit_interval_start: clock.now(), - aggregation_param, - aggregate_share: AggregateShare(), - report_count: 5, - checksum: NonceChecksum::get_decoded(&[3; 32]).unwrap(), - }) + tx.put_batch_unit_aggregation( + &BatchUnitAggregation:: { + task_id, + unit_interval_start: clock.now(), + aggregation_param, + aggregate_share: AggregateShare(), + report_count: 5, + checksum: NonceChecksum::get_decoded(&[3; 32]).unwrap(), + }, + ) .await?; - tx.put_batch_unit_aggregation(&BatchUnitAggregation:: { - task_id, - unit_interval_start: clock.now().add(Duration::from_seconds(1000)).unwrap(), - aggregation_param, - aggregate_share: AggregateShare(), - report_count: 5, - checksum: NonceChecksum::get_decoded(&[2; 32]).unwrap(), - }) + tx.put_batch_unit_aggregation( + &BatchUnitAggregation:: { + task_id, + unit_interval_start: clock.now().add(Duration::from_seconds(1000)).unwrap(), + aggregation_param, + aggregate_share: AggregateShare(), + report_count: 5, + checksum: NonceChecksum::get_decoded(&[2; 32]).unwrap(), + }, + ) .await?; Ok(()) @@ -534,7 +541,7 @@ mod tests { ds.run_tx(|tx| { Box::pin(async move { let collect_job = tx - .get_collect_job::(collect_job_id) + .get_collect_job::(collect_job_id) .await .unwrap() .unwrap(); @@ -574,7 +581,7 @@ mod tests { let helper_aggregate_share = helper_response.encrypted_aggregate_share.clone(); Box::pin(async move { let collect_job = tx - .get_collect_job::(collect_job_id) + .get_collect_job::(collect_job_id) .await .unwrap() .unwrap(); @@ -610,6 +617,7 @@ mod tests { task.min_batch_size = 10; let batch_interval = Interval::new(clock.now(), Duration::from_seconds(2000)).unwrap(); let aggregation_param = 0u8; + const VERIFY_KEY_LENGTH: usize = FakeVdaf::VERIFY_KEY_LENGTH; let (collect_job_id, lease) = ds .run_tx(|tx| { @@ -623,7 +631,7 @@ mod tests { .await?; let aggregation_job_id = AggregationJobId::random(); - tx.put_aggregation_job(&AggregationJob:: { + tx.put_aggregation_job(&AggregationJob:: { aggregation_job_id, task_id, aggregation_param, @@ -635,7 +643,7 @@ mod tests { tx.put_client_report(&Report::new(task_id, nonce, Vec::new(), Vec::new())) .await?; - tx.put_report_aggregation(&ReportAggregation:: { + tx.put_report_aggregation(&ReportAggregation:: { aggregation_job_id, task_id, nonce, @@ -644,7 +652,10 @@ mod tests { }) .await?; - tx.put_batch_unit_aggregation(&BatchUnitAggregation:: { + tx.put_batch_unit_aggregation(&BatchUnitAggregation::< + VERIFY_KEY_LENGTH, + FakeVdaf, + > { task_id, unit_interval_start: clock.now(), aggregation_param, @@ -653,7 +664,10 @@ mod tests { checksum: NonceChecksum::get_decoded(&[3; 32]).unwrap(), }) .await?; - tx.put_batch_unit_aggregation(&BatchUnitAggregation:: { + tx.put_batch_unit_aggregation(&BatchUnitAggregation::< + VERIFY_KEY_LENGTH, + FakeVdaf, + > { task_id, unit_interval_start: clock.now().add(Duration::from_seconds(1000)).unwrap(), aggregation_param, @@ -690,7 +704,7 @@ mod tests { .run_tx(|tx| { Box::pin(async move { let collect_job = tx - .get_collect_job::(collect_job_id) + .get_collect_job::(collect_job_id) .await? .unwrap(); diff --git a/janus_server/src/bin/aggregation_job_creator.rs b/janus_server/src/bin/aggregation_job_creator.rs index 9767d4626..356e920eb 100644 --- a/janus_server/src/bin/aggregation_job_creator.rs +++ b/janus_server/src/bin/aggregation_job_creator.rs @@ -11,7 +11,7 @@ use janus_server::datastore::models::{ }; use janus_server::datastore::{self, Datastore}; use janus_server::message::AggregationJobId; -use janus_server::task::Task; +use janus_server::task::{Task, PRIO3_AES128_VERIFY_KEY_LENGTH}; use prio::codec::Encode; use prio::vdaf; use prio::vdaf::prio3::{Prio3Aes128Count, Prio3Aes128Histogram, Prio3Aes128Sum}; @@ -183,21 +183,17 @@ impl AggregationJobCreator { async fn create_aggregation_jobs_for_task(&self, task: &Task) -> anyhow::Result<()> { match task.vdaf { janus_server::task::VdafInstance::Real(janus::task::VdafInstance::Prio3Aes128Count) => { - self.create_aggregation_jobs_for_task_no_param::(task) + self.create_aggregation_jobs_for_task_no_param::(task) .await } - janus_server::task::VdafInstance::Real(janus::task::VdafInstance::Prio3Aes128Sum { - .. - }) => { - self.create_aggregation_jobs_for_task_no_param::(task) + janus_server::task::VdafInstance::Real(janus::task::VdafInstance::Prio3Aes128Sum { .. }) => { + self.create_aggregation_jobs_for_task_no_param::(task) .await } - janus_server::task::VdafInstance::Real( - janus::task::VdafInstance::Prio3Aes128Histogram { .. }, - ) => { - self.create_aggregation_jobs_for_task_no_param::(task) + janus_server::task::VdafInstance::Real(janus::task::VdafInstance::Prio3Aes128Histogram { .. }) => { + self.create_aggregation_jobs_for_task_no_param::(task) .await } @@ -209,14 +205,17 @@ impl AggregationJobCreator { } #[tracing::instrument(skip(self), err)] - async fn create_aggregation_jobs_for_task_no_param>( + async fn create_aggregation_jobs_for_task_no_param< + const L: usize, + A: vdaf::Aggregator, + >( &self, task: &Task, ) -> anyhow::Result<()> where for<'a> &'a A::AggregateShare: Into>, A::PrepareMessage: Send + Sync, - A::PrepareStep: Send + Sync + Encode, + A::PrepareState: Send + Sync + Encode, A::OutputShare: Send + Sync, for<'a> &'a A::OutputShare: Into>, { @@ -268,7 +267,7 @@ impl AggregationJobCreator { report_count = agg_job_nonces.len(), "Creating aggregation job" ); - agg_jobs.push(AggregationJob:: { + agg_jobs.push(AggregationJob:: { aggregation_job_id, task_id, aggregation_param: (), @@ -276,7 +275,7 @@ impl AggregationJobCreator { }); for (ord, nonce) in agg_job_nonces.iter().enumerate() { - report_aggs.push(ReportAggregation:: { + report_aggs.push(ReportAggregation:: { aggregation_job_id, task_id, nonce: *nonce, @@ -320,11 +319,11 @@ mod tests { use janus_server::{ datastore::{Crypter, Datastore, Transaction}, message::{test_util::new_dummy_report, AggregationJobId}, - task::test_util::new_dummy_task, + task::{test_util::new_dummy_task, PRIO3_AES128_VERIFY_KEY_LENGTH}, trace::test_util::install_test_trace_subscriber, }; use janus_test_util::MockClock; - use prio::vdaf::{prio3::Prio3Aes128Count, Vdaf as _}; + use prio::vdaf::prio3::{Prio3, Prio3Aes128Count}; use std::{ collections::{HashMap, HashSet}, iter, @@ -645,23 +644,17 @@ mod tests { tx: &Transaction<'_, C>, task_id: TaskId, ) -> HashMap { - // For this test, all of the report aggregations will be in the Start state, so the verify - // parameter effectively does not matter. - let verify_param = Prio3Aes128Count::new(2) - .unwrap() - .setup() - .unwrap() - .1 - .remove(0); + let vdaf = Prio3::new_aes128_count(2).unwrap(); try_join_all( - tx.get_aggregation_jobs_for_task_id::(task_id) + tx.get_aggregation_jobs_for_task_id::(task_id) .await .unwrap() .into_iter() .map(|agg_job| { - tx.get_report_aggregations_for_aggregation_job::( - &verify_param, + tx.get_report_aggregations_for_aggregation_job( + &vdaf, + Role::Leader, task_id, agg_job.aggregation_job_id, ) diff --git a/janus_server/src/bin/aggregation_job_driver.rs b/janus_server/src/bin/aggregation_job_driver.rs index 03aa06919..77214527a 100644 --- a/janus_server/src/bin/aggregation_job_driver.rs +++ b/janus_server/src/bin/aggregation_job_driver.rs @@ -24,13 +24,13 @@ use janus_server::{ AggregateContinueReq, AggregateContinueResp, AggregateInitializeReq, AggregateInitializeResp, PrepareStep, PrepareStepResult, ReportShare, ReportShareError, }, - task::{Task, VdafInstance, DAP_AUTH_HEADER}, + task::{Task, VdafInstance, DAP_AUTH_HEADER, PRIO3_AES128_VERIFY_KEY_LENGTH}, }; use prio::{ codec::{Decode, Encode, ParameterizedDecode}, vdaf::{ self, - prio3::{Prio3Aes128Count, Prio3Aes128Histogram, Prio3Aes128Sum}, + prio3::{Prio3, Prio3Aes128Count, Prio3Aes128Histogram, Prio3Aes128Sum}, PrepareTransition, }, }; @@ -157,17 +157,17 @@ impl AggregationJobDriver { ) -> Result<()> { match lease.leased().vdaf { VdafInstance::Real(janus::task::VdafInstance::Prio3Aes128Count) => { - let vdaf = Prio3Aes128Count::new(2)?; + let vdaf = Prio3::new_aes128_count(2)?; self.step_aggregation_job_generic(datastore, vdaf, lease) .await } VdafInstance::Real(janus::task::VdafInstance::Prio3Aes128Sum { bits }) => { - let vdaf = Prio3Aes128Sum::new(2, bits)?; + let vdaf = Prio3::new_aes128_sum(2, bits)?; self.step_aggregation_job_generic(datastore, vdaf, lease) .await } VdafInstance::Real(janus::task::VdafInstance::Prio3Aes128Histogram { ref buckets }) => { - let vdaf = Prio3Aes128Histogram::new(2, buckets)?; + let vdaf = Prio3::new_aes128_histogram(2, buckets)?; self.step_aggregation_job_generic(datastore, vdaf, lease) .await } @@ -176,44 +176,51 @@ impl AggregationJobDriver { } } - async fn step_aggregation_job_generic( + async fn step_aggregation_job_generic>( &self, datastore: Arc>, vdaf: A, lease: Lease, ) -> Result<()> where - C: Clock, - A: vdaf::Aggregator + 'static + Send + Sync, + A: 'static + Send + Sync, A::AggregationParam: Send + Sync, for<'a> &'a A::AggregateShare: Into>, A::OutputShare: PartialEq + Eq + Send + Sync + for<'a> TryFrom<&'a [u8]>, for<'a> &'a A::OutputShare: Into>, - A::PrepareStep: PartialEq + Eq + Send + Sync + Encode + ParameterizedDecode, + for<'a> A::PrepareState: + PartialEq + Eq + Send + Sync + Encode + ParameterizedDecode<(&'a A, usize)>, A::PrepareMessage: PartialEq + Eq + Send + Sync, - A::VerifyParam: Send + Sync + ParameterizedDecode, { // Read all information about the aggregation job. let vdaf = Arc::new(vdaf); let task_id = lease.leased().task_id; let aggregation_job_id = lease.leased().aggregation_job_id; - let (task, aggregation_job, report_aggregations, client_reports, verify_param) = datastore + let (task, aggregation_job, report_aggregations, client_reports, verify_key) = datastore .run_tx(|tx| { let vdaf = Arc::clone(&vdaf); Box::pin(async move { let task = tx.get_task(task_id).await?.ok_or_else(|| { datastore::Error::User(anyhow!("couldn't find task {}", task_id).into()) })?; - let verify_param = A::VerifyParam::get_decoded_with_param( - &vdaf, - task.vdaf_verify_parameters.get(0).unwrap(), - )?; + let verify_key = task + .vdaf_verify_keys + .get(0) + .unwrap() + .clone() + .try_into() + .map_err(|_| { + datastore::Error::User( + anyhow!("VDAF verification key has wrong length").into(), + ) + })?; let aggregation_job_future = - tx.get_aggregation_job::(task_id, aggregation_job_id); + tx.get_aggregation_job::(task_id, aggregation_job_id); let report_aggregations_future = tx - .get_report_aggregations_for_aggregation_job::( - &verify_param, + .get_report_aggregations_for_aggregation_job( + vdaf.as_ref(), + Role::Leader, task_id, aggregation_job_id, ); @@ -268,7 +275,7 @@ impl AggregationJobDriver { aggregation_job, report_aggregations, client_reports, - verify_param, + verify_key, )) }) }) @@ -287,7 +294,7 @@ impl AggregationJobDriver { match (saw_start, saw_waiting, saw_finished) { // Only saw report aggregations in state "start" (or failed or invalid). (true, false, false) => self.step_aggregation_job_aggregate_init( - &datastore, vdaf.as_ref(), lease, task, aggregation_job, report_aggregations, client_reports, verify_param).await, + &datastore, vdaf.as_ref(), lease, task, aggregation_job, report_aggregations, client_reports, verify_key).await, // Only saw report aggregations in state "waiting" (or failed or invalid). (false, true, false) => self.step_aggregation_job_aggregate_continue( @@ -298,25 +305,24 @@ impl AggregationJobDriver { } #[allow(clippy::too_many_arguments)] - async fn step_aggregation_job_aggregate_init( + async fn step_aggregation_job_aggregate_init>( &self, datastore: &Datastore, vdaf: &A, lease: Lease, task: Task, - aggregation_job: AggregationJob, - report_aggregations: Vec>, + aggregation_job: AggregationJob, + report_aggregations: Vec>, client_reports: Vec, - verify_param: A::VerifyParam, + verify_key: [u8; L], ) -> Result<()> where - C: Clock, - A: vdaf::Aggregator + 'static, + A: 'static, A::AggregationParam: Send + Sync, for<'a> &'a A::AggregateShare: Into>, A::OutputShare: PartialEq + Eq + Send + Sync, for<'a> &'a A::OutputShare: Into>, - A::PrepareStep: PartialEq + Eq + Send + Sync + Encode, + A::PrepareState: PartialEq + Eq + Send + Sync + Encode, A::PrepareMessage: PartialEq + Eq + Send + Sync, { // Zip the report aggregations at start with the client reports, verifying that their nonces @@ -404,7 +410,7 @@ impl AggregationJobDriver { } }; let leader_input_share = match A::InputShare::get_decoded_with_param( - &verify_param, + &(vdaf, Role::Leader.index().unwrap()), &leader_input_share_bytes, ) { Ok(leader_input_share) => leader_input_share, @@ -418,13 +424,14 @@ impl AggregationJobDriver { }; // Initialize the leader's preparation state from the input share. - let prep_state = match vdaf.prepare_init( - &verify_param, + let (prep_state, prep_share) = match vdaf.prepare_init( + &verify_key, + Role::Leader.index().unwrap(), &aggregation_job.aggregation_param, &report.nonce().get_encoded(), &leader_input_share, ) { - Ok(prep_state) => prep_state, + Ok(prep_state_and_share) => prep_state_and_share, Err(err) => { error!(report_nonce = %report_aggregation.nonce, ?err, "Couldn't initialize leader's preparation state"); report_aggregation.state = @@ -433,14 +440,6 @@ impl AggregationJobDriver { continue; } }; - let leader_transition = vdaf.prepare_step(prep_state, None); - if let PrepareTransition::Fail(err) = leader_transition { - error!(report_nonce = %report_aggregation.nonce, ?err, "Couldn't step leader's initial preparation state"); - report_aggregation.state = - ReportAggregationState::Failed(ReportShareError::VdafPrepError); - report_aggregations_to_write.push(report_aggregation); - continue; - }; report_shares.push(ReportShare { nonce: report.nonce(), @@ -449,7 +448,7 @@ impl AggregationJobDriver { }); stepped_aggregations.push(SteppedAggregation { report_aggregation, - leader_transition, + leader_transition: PrepareTransition::Continue(prep_state, prep_share), }); } @@ -488,23 +487,26 @@ impl AggregationJobDriver { .await } - async fn step_aggregation_job_aggregate_continue( + async fn step_aggregation_job_aggregate_continue< + const L: usize, + C: Clock, + A: vdaf::Aggregator, + >( &self, datastore: &Datastore, vdaf: &A, lease: Lease, task: Task, - aggregation_job: AggregationJob, - report_aggregations: Vec>, + aggregation_job: AggregationJob, + report_aggregations: Vec>, ) -> Result<()> where - C: Clock, - A: vdaf::Aggregator + 'static, + A: 'static, A::AggregationParam: Send + Sync, for<'a> &'a A::AggregateShare: Into>, A::OutputShare: Send + Sync, for<'a> &'a A::OutputShare: Into>, - A::PrepareStep: Send + Sync + Encode, + A::PrepareState: Send + Sync + Encode, A::PrepareMessage: Send + Sync, { // Visit the report aggregations, ignoring any that have already failed; compute our own @@ -520,15 +522,18 @@ impl AggregationJobDriver { .ok_or_else(|| anyhow!("report aggregation missing prepare message"))?; // Step our own state. - let leader_transition = - vdaf.prepare_step(prep_state.clone(), Some(prep_msg.clone())); - if let PrepareTransition::Fail(err) = leader_transition { - error!(report_nonce = %report_aggregation.nonce, ?err, "Couldn't step report aggregation"); - report_aggregation.state = - ReportAggregationState::Failed(ReportShareError::VdafPrepError); - report_aggregations_to_write.push(report_aggregation); - continue; - } + let leader_transition = match vdaf + .prepare_step(prep_state.clone(), prep_msg.clone()) + { + Ok(leader_transition) => leader_transition, + Err(err) => { + error!(report_nonce = %report_aggregation.nonce, ?err, "Couldn't step report aggregation"); + report_aggregation.state = + ReportAggregationState::Failed(ReportShareError::VdafPrepError); + report_aggregations_to_write.push(report_aggregation); + continue; + } + }; prepare_steps.push(PrepareStep { nonce: report_aggregation.nonce, @@ -576,25 +581,24 @@ impl AggregationJobDriver { } #[allow(clippy::too_many_arguments)] - async fn process_response_from_helper( + async fn process_response_from_helper>( &self, datastore: &Datastore, vdaf: &A, lease: Lease, - aggregation_job: AggregationJob, - stepped_aggregations: Vec>, - mut report_aggregations_to_write: Vec>, + aggregation_job: AggregationJob, + stepped_aggregations: Vec>, + mut report_aggregations_to_write: Vec>, prep_steps: Vec, ) -> Result<()> where - C: Clock, - A: vdaf::Aggregator + 'static, + A: 'static, A::AggregationParam: Send + Sync, for<'a> &'a A::AggregateShare: Into>, A::OutputShare: Send + Sync, for<'a> &'a A::OutputShare: Into>, A::PrepareMessage: Send + Sync, - A::PrepareStep: Send + Sync + Encode, + A::PrepareState: Send + Sync + Encode, { // Handle response, computing the new report aggregations to be stored. if stepped_aggregations.len() != prep_steps.len() { @@ -616,26 +620,26 @@ impl AggregationJobDriver { } match helper_prep_step.result { PrepareStepResult::Continued(payload) => { - // If the leader continued too, combine the leader's message with the helper's - // and prepare to store the leader's new state & the combined message for the - // next round. If the leader didn't continue, transition to INVALID. - if let PrepareTransition::Continue(leader_prep_state, leader_prep_msg) = + // If the leader continued too, combine the leader's prepare share with the + // helper's to compute next round's prepare message. Prepare to store the + // leader's new state & the prepare message. If the leader didn't continue, + // transition to INVALID. + if let PrepareTransition::Continue(leader_prep_state, leader_prep_share) = leader_transition { - let helper_prep_msg = - A::PrepareMessage::get_decoded_with_param(&leader_prep_state, &payload) + let helper_prep_share = + A::PrepareShare::get_decoded_with_param(&leader_prep_state, &payload) .context("couldn't decode helper's prepare message"); - let combined_prep_msg = helper_prep_msg.and_then(|helper_prep_msg| { - vdaf.prepare_preprocess([leader_prep_msg, helper_prep_msg]) - .context("couldn't combine leader & helper prepare messages") + let prep_msg = helper_prep_share.and_then(|helper_prep_share| { + vdaf.prepare_preprocess([leader_prep_share, helper_prep_share]) + .context("couldn't preprocess leader & helper prepare shares into prepare message") }); - report_aggregation.state = match combined_prep_msg { - Ok(combined_prep_msg) => ReportAggregationState::Waiting( - leader_prep_state, - Some(combined_prep_msg), - ), + report_aggregation.state = match prep_msg { + Ok(prep_msg) => { + ReportAggregationState::Waiting(leader_prep_state, Some(prep_msg)) + } Err(err) => { - error!(report_nonce = %report_aggregation.nonce, ?err, "Couldn't compute combined prepare message"); + error!(report_nonce = %report_aggregation.nonce, ?err, "Couldn't compute prepare message"); ReportAggregationState::Failed(ReportShareError::VdafPrepError) } } @@ -719,15 +723,15 @@ impl AggregationJobDriver { ) -> Result<()> { match &lease.leased().vdaf { VdafInstance::Real(janus::task::VdafInstance::Prio3Aes128Count) => { - self.cancel_aggregation_job_generic::(datastore, lease) + self.cancel_aggregation_job_generic::(datastore, lease) .await } VdafInstance::Real(janus::task::VdafInstance::Prio3Aes128Sum { .. }) => { - self.cancel_aggregation_job_generic::(datastore, lease) + self.cancel_aggregation_job_generic::(datastore, lease) .await } VdafInstance::Real(janus::task::VdafInstance::Prio3Aes128Histogram { .. }) => { - self.cancel_aggregation_job_generic::(datastore, lease) + self.cancel_aggregation_job_generic::(datastore, lease) .await } @@ -735,16 +739,15 @@ impl AggregationJobDriver { } } - async fn cancel_aggregation_job_generic( + async fn cancel_aggregation_job_generic>( &self, datastore: Arc>, lease: Lease, ) -> Result<()> where - A: vdaf::Aggregator + Send + Sync + 'static, + A: Send + Sync + 'static, A::AggregationParam: Send + Sync, for<'a> &'a A::AggregateShare: Into>, - C: Clock, { let lease = Arc::new(lease); let (task_id, aggregation_job_id) = @@ -754,7 +757,7 @@ impl AggregationJobDriver { let lease = Arc::clone(&lease); Box::pin(async move { let mut aggregation_job = tx - .get_aggregation_job::(task_id, aggregation_job_id) + .get_aggregation_job::(task_id, aggregation_job_id) .await? .ok_or_else(|| { datastore::Error::User( @@ -784,12 +787,12 @@ impl AggregationJobDriver { /// SteppedAggregation represents a report aggregation along with the associated preparation-state /// transition representing the next step for the leader. -struct SteppedAggregation +struct SteppedAggregation> where for<'a> &'a A::AggregateShare: Into>, { - report_aggregation: ReportAggregation, - leader_transition: PrepareTransition, + report_aggregation: ReportAggregation, + leader_transition: PrepareTransition, } #[cfg(test)] @@ -818,14 +821,17 @@ mod tests { AggregateContinueReq, AggregateContinueResp, AggregateInitializeReq, AggregateInitializeResp, AggregationJobId, PrepareStep, PrepareStepResult, ReportShare, }, - task::test_util::new_dummy_task, + task::{test_util::new_dummy_task, PRIO3_AES128_VERIFY_KEY_LENGTH}, trace::test_util::install_test_trace_subscriber, }; use janus_test_util::{run_vdaf, MockClock}; use mockito::mock; use prio::{ codec::Encode, - vdaf::{prio3::Prio3Aes128Count, PrepareTransition, Vdaf}, + vdaf::{ + prio3::{Prio3, Prio3Aes128Count}, + PrepareTransition, + }, }; use reqwest::Url; use std::{str, sync::Arc}; @@ -846,11 +852,8 @@ mod tests { let clock = MockClock::default(); let (ds, _db_handle) = ephemeral_datastore(clock.clone()).await; let ds = Arc::new(ds); - let vdaf = Prio3Aes128Count::new(2).unwrap(); - let (public_param, verify_params) = vdaf.setup().unwrap(); - let leader_verify_param = verify_params.get(Role::Leader.index().unwrap()).unwrap(); + let vdaf = Arc::new(Prio3::new_aes128_count(2).unwrap()); let nonce = Nonce::generate(&clock); - let transcript = run_vdaf(&vdaf, &public_param, &verify_params, &(), nonce, &0); let task_id = TaskId::random(); let mut task = new_dummy_task(task_id, VdafInstance::Prio3Aes128Count.into(), Role::Leader); @@ -858,10 +861,15 @@ mod tests { Url::parse("http://irrelevant").unwrap(), // leader URL doesn't matter Url::parse(&mockito::server_url()).unwrap(), ]; - task.vdaf_verify_parameters = vec![verify_params - .get(Role::Leader.index().unwrap()) + let verify_key = task + .vdaf_verify_keys + .get(0) .unwrap() - .get_encoded()]; + .clone() + .try_into() + .unwrap(); + + let transcript = run_vdaf(vdaf.as_ref(), &verify_key, &(), nonce, &0); let agg_auth_token = task.primary_aggregator_auth_token().clone(); let (leader_hpke_config, _) = task.hpke_keys.iter().next().unwrap().1; @@ -880,14 +888,20 @@ mod tests { tx.put_task(&task).await?; tx.put_client_report(&report).await?; - tx.put_aggregation_job(&AggregationJob:: { + tx.put_aggregation_job(&AggregationJob::< + PRIO3_AES128_VERIFY_KEY_LENGTH, + Prio3Aes128Count, + > { aggregation_job_id, task_id, aggregation_param: (), state: AggregationJobState::InProgress, }) .await?; - tx.put_report_aggregation(&ReportAggregation:: { + tx.put_report_aggregation(&ReportAggregation::< + PRIO3_AES128_VERIFY_KEY_LENGTH, + Prio3Aes128Count, + > { aggregation_job_id, task_id, nonce: report.nonce(), @@ -901,7 +915,7 @@ mod tests { .unwrap(); // Setup: prepare mocked HTTP responses. - let helper_vdaf_msg = assert_matches!(&transcript.transitions[Role::Helper.index().unwrap()][0], PrepareTransition::Continue(_, prep_msg) => prep_msg); + let helper_vdaf_msg = assert_matches!(&transcript.prepare_transitions[Role::Helper.index().unwrap()][0], PrepareTransition::Continue(_, prep_share) => prep_share); let helper_responses = vec![ ( AggregateInitializeReq::MEDIA_TYPE, @@ -949,6 +963,7 @@ mod tests { .build() .unwrap(), }); + // Run. Give the aggregation job driver enough time to step aggregation jobs, then kill it. let aggregation_job_driver = Arc::new(JobDriver::new( clock, @@ -1007,32 +1022,38 @@ mod tests { mocked_aggregate.assert(); } - let want_aggregation_job = AggregationJob:: { - aggregation_job_id, - task_id, - aggregation_param: (), - state: AggregationJobState::Finished, - }; - let leader_output_share = assert_matches!(&transcript.transitions[Role::Leader.index().unwrap()][1], PrepareTransition::Finish(leader_output_share) => leader_output_share.clone()); - let want_report_aggregation = ReportAggregation:: { - aggregation_job_id, - task_id, - nonce, - ord: 0, - state: ReportAggregationState::Finished(leader_output_share), - }; + let want_aggregation_job = + AggregationJob:: { + aggregation_job_id, + task_id, + aggregation_param: (), + state: AggregationJobState::Finished, + }; + let leader_output_share = assert_matches!(&transcript.prepare_transitions[Role::Leader.index().unwrap()][1], PrepareTransition::Finish(leader_output_share) => leader_output_share.clone()); + let want_report_aggregation = + ReportAggregation:: { + aggregation_job_id, + task_id, + nonce, + ord: 0, + state: ReportAggregationState::Finished(leader_output_share), + }; let (got_aggregation_job, got_report_aggregation) = ds .run_tx(|tx| { - let leader_verify_param = leader_verify_param.clone(); + let vdaf = Arc::clone(&vdaf); Box::pin(async move { let aggregation_job = tx - .get_aggregation_job::(task_id, aggregation_job_id) + .get_aggregation_job::( + task_id, + aggregation_job_id, + ) .await? .unwrap(); let report_aggregation = tx - .get_report_aggregation::( - &leader_verify_param, + .get_report_aggregation( + vdaf.as_ref(), + Role::Leader, task_id, aggregation_job_id, nonce, @@ -1056,11 +1077,8 @@ mod tests { let clock = MockClock::default(); let (ds, _db_handle) = ephemeral_datastore(clock.clone()).await; let ds = Arc::new(ds); - let vdaf = Prio3Aes128Count::new(2).unwrap(); - let (public_param, verify_params) = vdaf.setup().unwrap(); - let leader_verify_param = verify_params.get(Role::Leader.index().unwrap()).unwrap(); + let vdaf = Arc::new(Prio3::new_aes128_count(2).unwrap()); let nonce = Nonce::generate(&clock); - let transcript = run_vdaf(&vdaf, &public_param, &verify_params, &(), nonce, &0); let task_id = TaskId::random(); let mut task = new_dummy_task(task_id, VdafInstance::Prio3Aes128Count.into(), Role::Leader); @@ -1068,7 +1086,15 @@ mod tests { Url::parse("http://irrelevant").unwrap(), // leader URL doesn't matter Url::parse(&mockito::server_url()).unwrap(), ]; - task.vdaf_verify_parameters = vec![leader_verify_param.get_encoded()]; + let verify_key = task + .vdaf_verify_keys + .get(0) + .unwrap() + .clone() + .try_into() + .unwrap(); + + let transcript = run_vdaf(vdaf.as_ref(), &verify_key, &(), nonce, &0); let agg_auth_token = task.primary_aggregator_auth_token(); let (leader_hpke_config, _) = task.hpke_keys.iter().next().unwrap().1; @@ -1088,14 +1114,20 @@ mod tests { tx.put_task(&task).await?; tx.put_client_report(&report).await?; - tx.put_aggregation_job(&AggregationJob:: { + tx.put_aggregation_job(&AggregationJob::< + PRIO3_AES128_VERIFY_KEY_LENGTH, + Prio3Aes128Count, + > { aggregation_job_id, task_id, aggregation_param: (), state: AggregationJobState::InProgress, }) .await?; - tx.put_report_aggregation(&ReportAggregation:: { + tx.put_report_aggregation(&ReportAggregation::< + PRIO3_AES128_VERIFY_KEY_LENGTH, + Prio3Aes128Count, + > { aggregation_job_id, task_id, nonce: report.nonce(), @@ -1133,7 +1165,7 @@ mod tests { .clone(), }], }; - let helper_vdaf_msg = assert_matches!(&transcript.transitions[Role::Helper.index().unwrap()][0], PrepareTransition::Continue(_, prep_msg) => prep_msg); + let helper_vdaf_msg = assert_matches!(&transcript.prepare_transitions[Role::Helper.index().unwrap()][0], PrepareTransition::Continue(_, prep_share) => prep_share); let helper_response = AggregateInitializeResp { job_id: aggregation_job_id, prepare_steps: vec![PrepareStep { @@ -1165,33 +1197,39 @@ mod tests { // Verify. mocked_aggregate.assert(); - let want_aggregation_job = AggregationJob:: { - aggregation_job_id, - task_id, - aggregation_param: (), - state: AggregationJobState::InProgress, - }; - let leader_prep_state = assert_matches!(&transcript.transitions[Role::Leader.index().unwrap()][0], PrepareTransition::Continue(prep_state, _) => prep_state.clone()); - let combined_prep_msg = transcript.combined_messages[0].clone(); - let want_report_aggregation = ReportAggregation:: { - aggregation_job_id, - task_id, - nonce, - ord: 0, - state: ReportAggregationState::Waiting(leader_prep_state, Some(combined_prep_msg)), - }; + let want_aggregation_job = + AggregationJob:: { + aggregation_job_id, + task_id, + aggregation_param: (), + state: AggregationJobState::InProgress, + }; + let leader_prep_state = assert_matches!(&transcript.prepare_transitions[Role::Leader.index().unwrap()][0], PrepareTransition::Continue(prep_state, _) => prep_state.clone()); + let prep_msg = transcript.prepare_messages[0].clone(); + let want_report_aggregation = + ReportAggregation:: { + aggregation_job_id, + task_id, + nonce, + ord: 0, + state: ReportAggregationState::Waiting(leader_prep_state, Some(prep_msg)), + }; let (got_aggregation_job, got_report_aggregation) = ds .run_tx(|tx| { - let leader_verify_param = leader_verify_param.clone(); + let vdaf = Arc::clone(&vdaf); Box::pin(async move { let aggregation_job = tx - .get_aggregation_job::(task_id, aggregation_job_id) + .get_aggregation_job::( + task_id, + aggregation_job_id, + ) .await? .unwrap(); let report_aggregation = tx - .get_report_aggregation::( - &leader_verify_param, + .get_report_aggregation( + vdaf.as_ref(), + Role::Leader, task_id, aggregation_job_id, nonce, @@ -1216,11 +1254,8 @@ mod tests { let clock = MockClock::default(); let (ds, _db_handle) = ephemeral_datastore(clock.clone()).await; let ds = Arc::new(ds); - let vdaf = Prio3Aes128Count::new(2).unwrap(); - let (public_param, verify_params) = vdaf.setup().unwrap(); - let leader_verify_param = verify_params.get(Role::Leader.index().unwrap()).unwrap(); + let vdaf = Arc::new(Prio3::new_aes128_count(2).unwrap()); let nonce = Nonce::generate(&clock); - let transcript = run_vdaf(&vdaf, &public_param, &verify_params, &(), nonce, &0); let task_id = TaskId::random(); let mut task = new_dummy_task(task_id, VdafInstance::Prio3Aes128Count.into(), Role::Leader); @@ -1228,7 +1263,15 @@ mod tests { Url::parse("http://irrelevant").unwrap(), // leader URL doesn't matter Url::parse(&mockito::server_url()).unwrap(), ]; - task.vdaf_verify_parameters = vec![leader_verify_param.get_encoded()]; + let verify_key = task + .vdaf_verify_keys + .get(0) + .unwrap() + .clone() + .try_into() + .unwrap(); + + let transcript = run_vdaf(vdaf.as_ref(), &verify_key, &(), nonce, &0); let agg_auth_token = task.primary_aggregator_auth_token(); let (leader_hpke_config, _) = task.hpke_keys.iter().next().unwrap().1; @@ -1241,37 +1284,40 @@ mod tests { ); let aggregation_job_id = AggregationJobId::random(); - let leader_prep_state = assert_matches!(&transcript.transitions[Role::Leader.index().unwrap()][0], PrepareTransition::Continue(prep_state, _) => prep_state); - let combined_msg = &transcript.combined_messages[0]; + let leader_prep_state = assert_matches!(&transcript.prepare_transitions[Role::Leader.index().unwrap()][0], PrepareTransition::Continue(prep_state, _) => prep_state); + let prep_msg = &transcript.prepare_messages[0]; let lease = ds .run_tx(|tx| { - let (task, report, leader_prep_state, combined_msg) = ( + let (task, report, leader_prep_state, prep_msg) = ( task.clone(), report.clone(), leader_prep_state.clone(), - combined_msg.clone(), + prep_msg.clone(), ); Box::pin(async move { tx.put_task(&task).await?; tx.put_client_report(&report).await?; - tx.put_aggregation_job(&AggregationJob:: { + tx.put_aggregation_job(&AggregationJob::< + PRIO3_AES128_VERIFY_KEY_LENGTH, + Prio3Aes128Count, + > { aggregation_job_id, task_id, aggregation_param: (), state: AggregationJobState::InProgress, }) .await?; - tx.put_report_aggregation(&ReportAggregation:: { + tx.put_report_aggregation(&ReportAggregation::< + PRIO3_AES128_VERIFY_KEY_LENGTH, + Prio3Aes128Count, + > { aggregation_job_id, task_id, nonce: report.nonce(), ord: 0, - state: ReportAggregationState::Waiting( - leader_prep_state, - Some(combined_msg), - ), + state: ReportAggregationState::Waiting(leader_prep_state, Some(prep_msg)), }) .await?; @@ -1295,7 +1341,7 @@ mod tests { job_id: aggregation_job_id, prepare_steps: vec![PrepareStep { nonce, - result: PrepareStepResult::Continued(combined_msg.get_encoded()), + result: PrepareStepResult::Continued(prep_msg.get_encoded()), }], }; let helper_response = AggregateContinueResp { @@ -1329,32 +1375,38 @@ mod tests { // Verify. mocked_aggregate.assert(); - let want_aggregation_job = AggregationJob:: { - aggregation_job_id, - task_id, - aggregation_param: (), - state: AggregationJobState::Finished, - }; - let leader_output_share = assert_matches!(&transcript.transitions[Role::Leader.index().unwrap()][1], PrepareTransition::Finish(leader_output_share) => leader_output_share.clone()); - let want_report_aggregation = ReportAggregation:: { - aggregation_job_id, - task_id, - nonce, - ord: 0, - state: ReportAggregationState::Finished(leader_output_share), - }; + let want_aggregation_job = + AggregationJob:: { + aggregation_job_id, + task_id, + aggregation_param: (), + state: AggregationJobState::Finished, + }; + let leader_output_share = assert_matches!(&transcript.prepare_transitions[Role::Leader.index().unwrap()][1], PrepareTransition::Finish(leader_output_share) => leader_output_share.clone()); + let want_report_aggregation = + ReportAggregation:: { + aggregation_job_id, + task_id, + nonce, + ord: 0, + state: ReportAggregationState::Finished(leader_output_share), + }; let (got_aggregation_job, got_report_aggregation) = ds .run_tx(|tx| { - let leader_verify_param = leader_verify_param.clone(); + let vdaf = Arc::clone(&vdaf); Box::pin(async move { let aggregation_job = tx - .get_aggregation_job::(task_id, aggregation_job_id) + .get_aggregation_job::( + task_id, + aggregation_job_id, + ) .await? .unwrap(); let report_aggregation = tx - .get_report_aggregation::( - &leader_verify_param, + .get_report_aggregation( + vdaf.as_ref(), + Role::Leader, task_id, aggregation_job_id, nonce, @@ -1378,12 +1430,8 @@ mod tests { let clock = MockClock::default(); let (ds, _db_handle) = ephemeral_datastore(clock.clone()).await; let ds = Arc::new(ds); - let vdaf = Prio3Aes128Count::new(2).unwrap(); - let (public_param, verify_params) = vdaf.setup().unwrap(); - let leader_verify_param = verify_params.get(Role::Leader.index().unwrap()).unwrap(); + let vdaf = Arc::new(Prio3::new_aes128_count(2).unwrap()); let nonce = Nonce::generate(&clock); - let input_shares = - run_vdaf(&vdaf, &public_param, &verify_params, &(), nonce, &0).input_shares; let task_id = TaskId::random(); let mut task = new_dummy_task(task_id, VdafInstance::Prio3Aes128Count.into(), Role::Leader); @@ -1391,7 +1439,15 @@ mod tests { Url::parse("http://irrelevant").unwrap(), // leader URL doesn't matter Url::parse(&mockito::server_url()).unwrap(), ]; - task.vdaf_verify_parameters = vec![leader_verify_param.get_encoded()]; + let verify_key = task + .vdaf_verify_keys + .get(0) + .unwrap() + .clone() + .try_into() + .unwrap(); + + let input_shares = run_vdaf(vdaf.as_ref(), &verify_key, &(), nonce, &0).input_shares; let (leader_hpke_config, _) = task.hpke_keys.iter().next().unwrap().1; let (helper_hpke_config, _) = generate_hpke_config_and_private_key(); @@ -1403,19 +1459,20 @@ mod tests { ); let aggregation_job_id = AggregationJobId::random(); - let aggregation_job = AggregationJob:: { + let aggregation_job = AggregationJob:: { aggregation_job_id, task_id, aggregation_param: (), state: AggregationJobState::InProgress, }; - let report_aggregation = ReportAggregation:: { - aggregation_job_id, - task_id, - nonce, - ord: 0, - state: ReportAggregationState::Start, - }; + let report_aggregation = + ReportAggregation:: { + aggregation_job_id, + task_id, + nonce, + ord: 0, + state: ReportAggregationState::Start, + }; let lease = ds .run_tx(|tx| { @@ -1462,15 +1519,19 @@ mod tests { let (got_aggregation_job, got_report_aggregation, got_leases) = ds .run_tx(|tx| { - let leader_verify_param = leader_verify_param.clone(); + let vdaf = Arc::clone(&vdaf); Box::pin(async move { let aggregation_job = tx - .get_aggregation_job::(task_id, aggregation_job_id) + .get_aggregation_job::( + task_id, + aggregation_job_id, + ) .await? .unwrap(); let report_aggregation = tx - .get_report_aggregation::( - &leader_verify_param, + .get_report_aggregation( + vdaf.as_ref(), + Role::Leader, task_id, aggregation_job_id, nonce, diff --git a/janus_server/src/client.rs b/janus_server/src/client.rs index 631bbca9c..4543e413e 100644 --- a/janus_server/src/client.rs +++ b/janus_server/src/client.rs @@ -16,11 +16,11 @@ use url::Url; #[derive(Debug, thiserror::Error)] pub enum Error { - #[error("Invalid parameter {0}")] + #[error("invalid parameter {0}")] InvalidParameter(&'static str), #[error("HTTP client error: {0}")] HttpClient(#[from] reqwest::Error), - #[error("Codec error: {0}")] + #[error("codec error: {0}")] Codec(#[from] prio::codec::CodecError), #[error("HTTP response status {0}")] Http(StatusCode), @@ -118,7 +118,6 @@ where { parameters: ClientParameters, vdaf_client: V, - vdaf_public_parameter: V::PublicParam, clock: C, http_client: reqwest::Client, leader_hpke_config: HpkeConfig, @@ -132,7 +131,6 @@ where pub fn new( parameters: ClientParameters, vdaf_client: V, - vdaf_public_parameter: V::PublicParam, clock: C, http_client: &reqwest::Client, leader_hpke_config: HpkeConfig, @@ -144,7 +142,6 @@ where Self { parameters, vdaf_client, - vdaf_public_parameter, clock, http_client: http_client.clone(), leader_hpke_config, @@ -157,9 +154,7 @@ where /// share plus one proof share for each aggregator and then uploaded to the /// leader. pub async fn upload(&self, measurement: &V::Measurement) -> Result<(), Error> { - let input_shares = self - .vdaf_client - .shard(&self.vdaf_public_parameter, measurement)?; + let input_shares = self.vdaf_client.shard(measurement)?; assert_eq!(input_shares.len(), 2); // PPM only supports VDAFs using two aggregators. let nonce = Nonce::generate(&self.clock); @@ -215,13 +210,10 @@ mod tests { use janus::{hpke::test_util::generate_hpke_config_and_private_key, message::TaskId}; use janus_test_util::MockClock; use mockito::mock; - use prio::vdaf::prio3::{Prio3Aes128Count, Prio3Aes128Sum}; + use prio::vdaf::prio3::Prio3; use url::Url; - fn setup_client( - vdaf_client: V, - public_parameter: V::PublicParam, - ) -> Client + fn setup_client(vdaf_client: V) -> Client where for<'a> &'a V::AggregateShare: Into>, { @@ -241,7 +233,6 @@ mod tests { Client::new( client_parameters, vdaf_client, - public_parameter, clock, &default_http_client().unwrap(), leader_hpke_config, @@ -258,8 +249,7 @@ mod tests { .expect(1) .create(); - let client = setup_client(Prio3Aes128Count::new(2).unwrap(), ()); - + let client = setup_client(Prio3::new_aes128_count(2).unwrap()); client.upload(&1).await.unwrap(); mocked_upload.assert(); @@ -268,9 +258,9 @@ mod tests { #[tokio::test] async fn upload_prio3_invalid_measurement() { install_test_trace_subscriber(); - let vdaf = Prio3Aes128Sum::new(2, 16).unwrap(); + let vdaf = Prio3::new_aes128_sum(2, 16).unwrap(); - let client = setup_client(vdaf, ()); + let client = setup_client(vdaf); // 65536 is too big for a 16 bit sum and will be rejected by the VDAF. // Make sure we get the right error variant but otherwise we aren't // picky about its contents. @@ -287,7 +277,7 @@ mod tests { .expect(1) .create(); - let client = setup_client(Prio3Aes128Count::new(2).unwrap(), ()); + let client = setup_client(Prio3::new_aes128_count(2).unwrap()); assert_matches!( client.upload(&1).await, Err(Error::Http(StatusCode::NOT_IMPLEMENTED)) diff --git a/janus_server/src/datastore.rs b/janus_server/src/datastore.rs index cebda5bd8..06e45cfa6 100644 --- a/janus_server/src/datastore.rs +++ b/janus_server/src/datastore.rs @@ -9,6 +9,7 @@ use crate::{ message::{AggregateShareReq, AggregationJobId, ReportShare}, task::{self, AggregatorAuthenticationToken, Task, VdafInstance}, }; +use anyhow::anyhow; use chrono::NaiveDateTime; use futures::try_join; use janus::{ @@ -248,34 +249,34 @@ impl Transaction<'_, C> { ]; let hpke_configs_future = self.tx.execute(&stmt, hpke_configs_params); - // VDAF verification parameters. - let mut vdaf_verify_params: Vec> = Vec::new(); - for vdaf_verify_param in task.vdaf_verify_parameters.iter() { - let encrypted_vdaf_verify_param = self.crypter.encrypt( - "task_vdaf_verify_params", + // VDAF verification keys. + let mut vdaf_verify_keys: Vec> = Vec::new(); + for vdaf_verify_key in task.vdaf_verify_keys.iter() { + let encrypted_vdaf_verify_key = self.crypter.encrypt( + "task_vdaf_verify_keys", task.id.as_bytes(), - "vdaf_verify_param", - vdaf_verify_param.as_ref(), + "vdaf_verify_key", + vdaf_verify_key.as_ref(), )?; - vdaf_verify_params.push(encrypted_vdaf_verify_param); + vdaf_verify_keys.push(encrypted_vdaf_verify_key); } let stmt = self .tx .prepare_cached( - "INSERT INTO task_vdaf_verify_params (task_id, vdaf_verify_param) + "INSERT INTO task_vdaf_verify_keys (task_id, vdaf_verify_key) SELECT (SELECT id FROM tasks WHERE task_id = $1), * FROM UNNEST($2::BYTEA[])", ) .await?; - let vdaf_verify_params_params: &[&(dyn ToSql + Sync)] = &[ + let vdaf_verify_keys_params: &[&(dyn ToSql + Sync)] = &[ /* task_id */ &task.id.as_bytes(), - /* vdaf_verify_params */ &vdaf_verify_params, + /* vdaf_verify_keys */ &vdaf_verify_keys, ]; - let vdaf_verify_params_future = self.tx.execute(&stmt, vdaf_verify_params_params); + let vdaf_verify_keys_future = self.tx.execute(&stmt, vdaf_verify_keys_params); try_join!( auth_tokens_future, hpke_configs_future, - vdaf_verify_params_future + vdaf_verify_keys_future )?; Ok(()) @@ -316,17 +317,17 @@ impl Transaction<'_, C> { let stmt = self .tx .prepare_cached( - "SELECT vdaf_verify_param FROM task_vdaf_verify_params + "SELECT vdaf_verify_key FROM task_vdaf_verify_keys WHERE task_id = (SELECT id FROM tasks WHERE task_id = $1)", ) .await?; - let vdaf_verify_param_rows = self.tx.query(&stmt, params); + let vdaf_verify_key_rows = self.tx.query(&stmt, params); - let (task_row, agg_auth_token_rows, hpke_key_rows, vdaf_verify_param_rows) = try_join!( + let (task_row, agg_auth_token_rows, hpke_key_rows, vdaf_verify_key_rows) = try_join!( task_row, agg_auth_token_rows, hpke_key_rows, - vdaf_verify_param_rows, + vdaf_verify_key_rows, )?; task_row .map(|task_row| { @@ -335,7 +336,7 @@ impl Transaction<'_, C> { task_row, agg_auth_token_rows, hpke_key_rows, - vdaf_verify_param_rows, + vdaf_verify_key_rows, ) }) .transpose() @@ -376,16 +377,16 @@ impl Transaction<'_, C> { let hpke_config_rows = self.tx.query(&stmt, &[]); let stmt = self.tx.prepare_cached( - "SELECT (SELECT tasks.task_id FROM tasks WHERE tasks.id = task_vdaf_verify_params.task_id), - vdaf_verify_param FROM task_vdaf_verify_params" + "SELECT (SELECT tasks.task_id FROM tasks WHERE tasks.id = task_vdaf_verify_keys.task_id), + vdaf_verify_key FROM task_vdaf_verify_keys" ).await?; - let vdaf_verify_param_rows = self.tx.query(&stmt, &[]); + let vdaf_verify_key_rows = self.tx.query(&stmt, &[]); - let (task_rows, agg_auth_token_rows, hpke_config_rows, vdaf_verify_param_rows) = try_join!( + let (task_rows, agg_auth_token_rows, hpke_config_rows, vdaf_verify_key_rows) = try_join!( task_rows, agg_auth_token_rows, hpke_config_rows, - vdaf_verify_param_rows + vdaf_verify_key_rows )?; let mut task_row_by_id = Vec::new(); @@ -412,10 +413,10 @@ impl Transaction<'_, C> { .push(row); } - let mut vdaf_verify_param_rows_by_task_id: HashMap> = HashMap::new(); - for row in vdaf_verify_param_rows { + let mut vdaf_verify_key_rows_by_task_id: HashMap> = HashMap::new(); + for row in vdaf_verify_key_rows { let task_id = TaskId::get_decoded(row.get("task_id"))?; - vdaf_verify_param_rows_by_task_id + vdaf_verify_key_rows_by_task_id .entry(task_id) .or_default() .push(row); @@ -433,7 +434,7 @@ impl Transaction<'_, C> { hpke_config_rows_by_task_id .remove(&task_id) .unwrap_or_default(), - vdaf_verify_param_rows_by_task_id + vdaf_verify_key_rows_by_task_id .remove(&task_id) .unwrap_or_default(), ) @@ -451,7 +452,7 @@ impl Transaction<'_, C> { row: Row, agg_auth_token_rows: Vec, hpke_key_rows: Vec, - vdaf_verify_param_rows: Vec, + vdaf_verify_key_rows: Vec, ) -> Result { // Scalar task parameters. let aggregator_role: AggregatorRole = row.get("aggregator_role"); @@ -508,15 +509,14 @@ impl Transaction<'_, C> { hpke_configs.push((config, private_key)); } - let mut vdaf_verify_params = Vec::new(); - for row in vdaf_verify_param_rows { - let encrypted_vdaf_verify_param: Vec = row.get("vdaf_verify_param"); - - vdaf_verify_params.push(self.crypter.decrypt( - "task_vdaf_verify_params", + let mut vdaf_verify_keys = Vec::new(); + for row in vdaf_verify_key_rows { + let encrypted_vdaf_verify_key: Vec = row.get("vdaf_verify_key"); + vdaf_verify_keys.push(self.crypter.decrypt( + "task_vdaf_verify_keys", task_id.as_bytes(), - "vdaf_verify_param", - &encrypted_vdaf_verify_param, + "vdaf_verify_key", + &encrypted_vdaf_verify_key, )?); } @@ -525,7 +525,7 @@ impl Transaction<'_, C> { endpoints, vdaf, aggregator_role.as_role(), - vdaf_verify_params, + vdaf_verify_keys, max_batch_lifetime, min_batch_size, min_batch_duration, @@ -688,11 +688,11 @@ impl Transaction<'_, C> { /// get_aggregation_job retrieves an aggregation job by ID. #[tracing::instrument(skip(self), err)] - pub async fn get_aggregation_job( + pub async fn get_aggregation_job>( &self, task_id: TaskId, aggregation_job_id: AggregationJobId, - ) -> Result>, Error> + ) -> Result>, Error> where for<'a> &'a A::AggregateShare: Into>, { @@ -721,10 +721,10 @@ impl Transaction<'_, C> { /// intended for use in tests. #[tracing::instrument(skip(self), err)] #[doc(hidden)] - pub async fn get_aggregation_jobs_for_task_id( + pub async fn get_aggregation_jobs_for_task_id>( &self, task_id: TaskId, - ) -> Result>, Error> + ) -> Result>, Error> where for<'a> &'a A::AggregateShare: Into>, { @@ -748,11 +748,11 @@ impl Transaction<'_, C> { .collect() } - fn aggregation_job_from_row( + fn aggregation_job_from_row>( task_id: TaskId, aggregation_job_id: AggregationJobId, row: Row, - ) -> Result, Error> + ) -> Result, Error> where for<'a> &'a A::AggregateShare: Into>, { @@ -873,9 +873,9 @@ impl Transaction<'_, C> { /// put_aggregation_job stores an aggregation job. #[tracing::instrument(skip(self), err)] - pub async fn put_aggregation_job( + pub async fn put_aggregation_job>( &self, - aggregation_job: &AggregationJob, + aggregation_job: &AggregationJob, ) -> Result<(), Error> where for<'a> &'a A::AggregateShare: Into>, @@ -900,9 +900,9 @@ impl Transaction<'_, C> { /// update_aggregation_job updates a stored aggregation job. #[tracing::instrument(skip(self), err)] - pub async fn update_aggregation_job( + pub async fn update_aggregation_job>( &self, - aggregation_job: &AggregationJob, + aggregation_job: &AggregationJob, ) -> Result<(), Error> where for<'a> &'a A::AggregateShare: Into>, @@ -933,15 +933,16 @@ impl Transaction<'_, C> { /// get_report_aggregation gets a report aggregation by ID. #[tracing::instrument(skip(self), err)] - pub async fn get_report_aggregation( + pub async fn get_report_aggregation>( &self, - verify_param: &A::VerifyParam, + vdaf: &A, + role: Role, task_id: TaskId, aggregation_job_id: AggregationJobId, nonce: Nonce, - ) -> Result>, Error> + ) -> Result>, Error> where - A::PrepareStep: ParameterizedDecode, + for<'a> A::PrepareState: ParameterizedDecode<(&'a A, usize)>, A::OutputShare: for<'a> TryFrom<&'a [u8]>, for<'a> &'a A::AggregateShare: Into>, { @@ -973,21 +974,25 @@ impl Transaction<'_, C> { ], ) .await? - .map(|row| report_aggregation_from_row(verify_param, task_id, aggregation_job_id, row)) + .map(|row| report_aggregation_from_row(vdaf, role, task_id, aggregation_job_id, row)) .transpose() } /// get_report_aggregations_for_aggregation_job retrieves all report aggregations associated /// with a given aggregation job, ordered by their natural ordering. #[tracing::instrument(skip(self), err)] - pub async fn get_report_aggregations_for_aggregation_job( + pub async fn get_report_aggregations_for_aggregation_job< + const L: usize, + A: vdaf::Aggregator, + >( &self, - verify_param: &A::VerifyParam, + vdaf: &A, + role: Role, task_id: TaskId, aggregation_job_id: AggregationJobId, - ) -> Result>, Error> + ) -> Result>, Error> where - A::PrepareStep: ParameterizedDecode, + for<'a> A::PrepareState: ParameterizedDecode<(&'a A, usize)>, A::OutputShare: for<'a> TryFrom<&'a [u8]>, for<'a> &'a A::AggregateShare: Into>, { @@ -1014,18 +1019,18 @@ impl Transaction<'_, C> { ) .await? .into_iter() - .map(|row| report_aggregation_from_row(verify_param, task_id, aggregation_job_id, row)) + .map(|row| report_aggregation_from_row(vdaf, role, task_id, aggregation_job_id, row)) .collect() } /// put_report_aggregation stores aggregation data for a single report. #[tracing::instrument(skip(self), err)] - pub async fn put_report_aggregation( + pub async fn put_report_aggregation>( &self, - report_aggregation: &ReportAggregation, + report_aggregation: &ReportAggregation, ) -> Result<(), Error> where - A::PrepareStep: Encode, + A::PrepareState: Encode, for<'a> &'a A::OutputShare: Into>, for<'a> &'a A::AggregateShare: Into>, { @@ -1065,12 +1070,12 @@ impl Transaction<'_, C> { } #[tracing::instrument(skip(self), err)] - pub async fn update_report_aggregation( + pub async fn update_report_aggregation>( &self, - report_aggregation: &ReportAggregation, + report_aggregation: &ReportAggregation, ) -> Result<(), Error> where - A::PrepareStep: Encode, + A::PrepareState: Encode, for<'a> &'a A::OutputShare: Into>, for<'a> &'a A::AggregateShare: Into>, { @@ -1134,12 +1139,11 @@ impl Transaction<'_, C> { /// Returns the collect job for the provided UUID, or `None` if no such collect job exists. #[tracing::instrument(skip(self), err)] - pub(crate) async fn get_collect_job( + pub(crate) async fn get_collect_job>( &self, collect_job_id: Uuid, - ) -> Result>, Error> + ) -> Result>, Error> where - A: vdaf::Aggregator, for<'a> >::Error: std::fmt::Display, for<'a> &'a A::AggregateShare: Into>, { @@ -1406,14 +1410,13 @@ ORDER BY id DESC /// Updates an existing collect job with the provided aggregate shares. // TODO(#242): update this function to take a CollectJob. #[tracing::instrument(skip(self), err)] - pub(crate) async fn update_collect_job( + pub(crate) async fn update_collect_job>( &self, collect_job_id: Uuid, leader_aggregate_share: &A::AggregateShare, helper_aggregate_share: &HpkeCiphertext, ) -> Result<(), Error> where - A: vdaf::Aggregator, for<'a> >::Error: std::fmt::Display, for<'a> &'a A::AggregateShare: Into>, { @@ -1465,12 +1468,11 @@ ORDER BY id DESC /// Store a new `batch_unit_aggregations` row in the datastore. #[tracing::instrument(skip(self), err)] - pub(crate) async fn put_batch_unit_aggregation( + pub(crate) async fn put_batch_unit_aggregation>( &self, - batch_unit_aggregation: &BatchUnitAggregation, + batch_unit_aggregation: &BatchUnitAggregation, ) -> Result<(), Error> where - A: vdaf::Aggregator, A::AggregationParam: Encode + std::fmt::Debug, A::AggregateShare: std::fmt::Debug, for<'a> &'a A::AggregateShare: Into>, @@ -1510,12 +1512,11 @@ ORDER BY id DESC /// Update an existing `batch_unit_aggregations` row with the `aggregate_share`, `checksum` and /// `report_count` values in `batch_unit_aggregation`. #[tracing::instrument(skip(self), err)] - pub(crate) async fn update_batch_unit_aggregation( + pub(crate) async fn update_batch_unit_aggregation>( &self, - batch_unit_aggregation: &BatchUnitAggregation, + batch_unit_aggregation: &BatchUnitAggregation, ) -> Result<(), Error> where - A: vdaf::Aggregator, A::AggregationParam: Encode + std::fmt::Debug, A::AggregateShare: std::fmt::Debug, for<'a> &'a A::AggregateShare: Into>, @@ -1561,14 +1562,16 @@ ORDER BY id DESC /// Fetch all the `batch_unit_aggregations` rows whose `unit_interval_start` describes an /// interval that falls within the provided `interval` and whose `aggregation_param` matches. #[tracing::instrument(skip(self, aggregation_param), err)] - pub(crate) async fn get_batch_unit_aggregations_for_task_in_interval( + pub(crate) async fn get_batch_unit_aggregations_for_task_in_interval< + const L: usize, + A: vdaf::Aggregator, + >( &self, task_id: TaskId, interval: Interval, aggregation_param: &A::AggregationParam, - ) -> Result>, Error> + ) -> Result>, Error> where - A: vdaf::Aggregator, A::AggregationParam: Encode + Clone, for<'a> >::Error: std::fmt::Display, for<'a> &'a A::AggregateShare: Into>, @@ -1627,12 +1630,11 @@ ORDER BY id DESC /// Fetch an `aggregate_share_jobs` row from the datastore corresponding to the provided /// [`AggregateShareRequest`], or `None` if no such job exists. #[tracing::instrument(skip(self), err)] - pub(crate) async fn get_aggregate_share_job_by_request( + pub(crate) async fn get_aggregate_share_job_by_request>( &self, request: &AggregateShareReq, - ) -> Result>, Error> + ) -> Result>, Error> where - A: vdaf::Aggregator, for<'a> >::Error: std::fmt::Display, for<'a> &'a A::AggregateShare: Into>, { @@ -1749,12 +1751,11 @@ ORDER BY id DESC /// Put an `aggregate_share_job` row into the datastore. #[tracing::instrument(skip(self), err)] - pub(crate) async fn put_aggregate_share_job( + pub(crate) async fn put_aggregate_share_job>( &self, - job: &AggregateShareJob, + job: &AggregateShareJob, ) -> Result<(), Error> where - A: vdaf::Aggregator, for<'a> &'a A::AggregateShare: Into>, for<'a> >::Error: std::fmt::Display, { @@ -1804,14 +1805,15 @@ fn check_update(row_count: u64) -> Result<(), Error> { } } -fn report_aggregation_from_row( - verify_param: &A::VerifyParam, +fn report_aggregation_from_row>( + vdaf: &A, + role: Role, task_id: TaskId, aggregation_job_id: AggregationJobId, row: Row, -) -> Result, Error> +) -> Result, Error> where - A::PrepareStep: ParameterizedDecode, + for<'a> A::PrepareState: ParameterizedDecode<(&'a A, usize)>, A::OutputShare: for<'a> TryFrom<&'a [u8]>, for<'a> &'a A::AggregateShare: Into>, { @@ -1844,8 +1846,11 @@ where let agg_state = match state { ReportAggregationStateCode::Start => ReportAggregationState::Start, ReportAggregationStateCode::Waiting => { - let prep_state = A::PrepareStep::get_decoded_with_param( - verify_param, + let agg_index = role + .index() + .ok_or_else(|| Error::User(anyhow!("unexpected role: {}", role.as_str()).into()))?; + let prep_state = A::PrepareState::get_decoded_with_param( + &(vdaf, agg_index), &prep_state_bytes.ok_or_else(|| { Error::DbState( "report aggregation in state WAITING but prep_state is NULL".to_string(), @@ -2177,7 +2182,7 @@ pub mod models { /// AggregationJob represents an aggregation job from the PPM specification. #[derive(Clone, Debug)] - pub struct AggregationJob + pub struct AggregationJob> where for<'a> &'a A::AggregateShare: Into>, { @@ -2187,7 +2192,7 @@ pub mod models { pub state: AggregationJobState, } - impl PartialEq for AggregationJob + impl> PartialEq for AggregationJob where A::AggregationParam: PartialEq, for<'a> &'a A::AggregateShare: Into>, @@ -2200,7 +2205,7 @@ pub mod models { } } - impl Eq for AggregationJob + impl> Eq for AggregationJob where A::AggregationParam: Eq, for<'a> &'a A::AggregateShare: Into>, @@ -2297,7 +2302,7 @@ pub mod models { /// ReportAggregation represents a the state of a single client report's ongoing aggregation. #[derive(Clone, Debug)] - pub struct ReportAggregation + pub struct ReportAggregation> where for<'a> &'a A::AggregateShare: Into>, { @@ -2305,12 +2310,12 @@ pub mod models { pub task_id: TaskId, pub nonce: Nonce, pub ord: i64, - pub state: ReportAggregationState, + pub state: ReportAggregationState, } - impl PartialEq for ReportAggregation + impl> PartialEq for ReportAggregation where - A::PrepareStep: PartialEq, + A::PrepareState: PartialEq, A::PrepareMessage: PartialEq, A::OutputShare: PartialEq, for<'a> &'a A::AggregateShare: Into>, @@ -2324,9 +2329,9 @@ pub mod models { } } - impl Eq for ReportAggregation + impl> Eq for ReportAggregation where - A::PrepareStep: Eq, + A::PrepareState: Eq, A::PrepareMessage: Eq, A::OutputShare: Eq, for<'a> &'a A::AggregateShare: Into>, @@ -2336,18 +2341,18 @@ pub mod models { /// ReportAggregationState represents the state of a single report aggregation. It corresponds /// to the REPORT_AGGREGATION_STATE enum in the schema, along with the state-specific data. #[derive(Clone, Debug)] - pub enum ReportAggregationState + pub enum ReportAggregationState> where for<'a> &'a A::AggregateShare: Into>, { Start, - Waiting(A::PrepareStep, Option), + Waiting(A::PrepareState, Option), Finished(A::OutputShare), Failed(ReportShareError), Invalid, } - impl ReportAggregationState + impl> ReportAggregationState where for<'a> &'a A::AggregateShare: Into>, { @@ -2366,14 +2371,14 @@ pub mod models { /// message, output share, transition error. pub(super) fn encoded_values_from_state(&self) -> EncodedReportAggregationStateValues where - A::PrepareStep: Encode, + A::PrepareState: Encode, for<'a> &'a A::OutputShare: Into>, for<'a> &'a A::AggregateShare: Into>, { let (prep_state, prep_msg, output_share, report_share_err) = match self { ReportAggregationState::Start => (None, None, None, None), - ReportAggregationState::Waiting(prep_step, prep_msg) => ( - Some(prep_step.get_encoded()), + ReportAggregationState::Waiting(prep_state, prep_msg) => ( + Some(prep_state.get_encoded()), prep_msg.as_ref().map(|msg| msg.get_encoded()), None, None, @@ -2421,9 +2426,9 @@ pub mod models { Invalid, } - impl PartialEq for ReportAggregationState + impl> PartialEq for ReportAggregationState where - A::PrepareStep: PartialEq, + A::PrepareState: PartialEq, A::PrepareMessage: PartialEq, A::OutputShare: PartialEq, for<'a> &'a A::AggregateShare: Into>, @@ -2431,9 +2436,9 @@ pub mod models { fn eq(&self, other: &Self) -> bool { match (self, other) { ( - Self::Waiting(lhs_prep_step, lhs_prep_msg), - Self::Waiting(rhs_prep_step, rhs_prep_msg), - ) => lhs_prep_step == rhs_prep_step && lhs_prep_msg == rhs_prep_msg, + Self::Waiting(lhs_prep_state, lhs_prep_msg), + Self::Waiting(rhs_prep_state, rhs_prep_msg), + ) => lhs_prep_state == rhs_prep_state && lhs_prep_msg == rhs_prep_msg, (Self::Finished(lhs_out_share), Self::Finished(rhs_out_share)) => { lhs_out_share == rhs_out_share } @@ -2445,9 +2450,9 @@ pub mod models { } } - impl Eq for ReportAggregationState + impl> Eq for ReportAggregationState where - A::PrepareStep: Eq, + A::PrepareState: Eq, A::PrepareMessage: Eq, A::OutputShare: Eq, for<'a> &'a A::AggregateShare: Into>, @@ -2462,7 +2467,7 @@ pub mod models { /// consists of one or more `BatchUnitAggregation`s merged together. #[derive(Clone, Derivative)] #[derivative(Debug)] - pub(crate) struct BatchUnitAggregation + pub(crate) struct BatchUnitAggregation> where for<'a> &'a A::AggregateShare: Into>, { @@ -2486,7 +2491,7 @@ pub mod models { pub(crate) checksum: NonceChecksum, } - impl PartialEq for BatchUnitAggregation + impl> PartialEq for BatchUnitAggregation where A::AggregationParam: PartialEq, A::AggregateShare: PartialEq, @@ -2502,7 +2507,7 @@ pub mod models { } } - impl Eq for BatchUnitAggregation + impl> Eq for BatchUnitAggregation where A::AggregationParam: Eq, A::AggregateShare: Eq, @@ -2514,7 +2519,7 @@ pub mod models { /// running collect jobs and store the results of completed ones. #[derive(Clone, Derivative)] #[derivative(Debug)] - pub(crate) struct CollectJob + pub(crate) struct CollectJob> where for<'a> &'a A::AggregateShare: Into>, { @@ -2528,14 +2533,14 @@ pub mod models { #[derivative(Debug = "ignore")] pub(crate) aggregation_param: A::AggregationParam, /// The current state of the collect job. - pub(crate) state: CollectJobState, + pub(crate) state: CollectJobState, } - impl PartialEq for CollectJob + impl> PartialEq for CollectJob where for<'a> &'a A::AggregateShare: Into>, A::AggregationParam: PartialEq, - CollectJobState: PartialEq, + CollectJobState: PartialEq, { fn eq(&self, other: &Self) -> bool { self.collect_job_id == other.collect_job_id @@ -2546,17 +2551,17 @@ pub mod models { } } - impl Eq for CollectJob + impl> Eq for CollectJob where for<'a> &'a A::AggregateShare: Into>, A::AggregationParam: Eq, - CollectJobState: Eq, + CollectJobState: Eq, { } #[derive(Clone, Derivative)] #[derivative(Debug)] - pub(crate) enum CollectJobState + pub(crate) enum CollectJobState> where for<'a> &'a A::AggregateShare: Into>, { @@ -2572,7 +2577,7 @@ pub mod models { Abandoned, } - impl PartialEq for CollectJobState + impl> PartialEq for CollectJobState where for<'a> &'a A::AggregateShare: Into>, A::AggregateShare: PartialEq, @@ -2597,7 +2602,7 @@ pub mod models { } } - impl Eq for CollectJobState + impl> Eq for CollectJobState where for<'a> &'a A::AggregateShare: Into>, A::AggregateShare: Eq, @@ -2619,7 +2624,7 @@ pub mod models { /// store the results of handling an AggregateShareReq from the leader. #[derive(Clone, Derivative)] #[derivative(Debug)] - pub(crate) struct AggregateShareJob + pub(crate) struct AggregateShareJob> where for<'a> &'a A::AggregateShare: Into>, { @@ -2640,7 +2645,7 @@ pub mod models { pub(crate) checksum: NonceChecksum, } - impl PartialEq for AggregateShareJob + impl> PartialEq for AggregateShareJob where A::AggregationParam: PartialEq, A::AggregateShare: PartialEq, @@ -2656,7 +2661,7 @@ pub mod models { } } - impl Eq for AggregateShareJob + impl> Eq for AggregateShareJob where A::AggregationParam: Eq, A::AggregateShare: Eq, @@ -2682,7 +2687,7 @@ mod tests { test_util::ephemeral_datastore, }, message::{test_util::new_dummy_report, ReportShareError}, - task::{test_util::new_dummy_task, VdafInstance}, + task::{test_util::new_dummy_task, VdafInstance, PRIO3_AES128_VERIFY_KEY_LENGTH}, trace::test_util::install_test_trace_subscriber, }; use ::janus_test_util::{dummy_vdaf, generate_aead_key, MockClock}; @@ -2697,13 +2702,14 @@ mod tests { vdaf::{ poplar1::{IdpfInput, Poplar1, ToyIdpf}, prg::PrgAes128, - prio3::Prio3Aes128Count, + prio3::{Prio3, Prio3Aes128Count}, AggregateShare, PrepareTransition, }, }; use std::{ collections::{BTreeSet, HashMap, HashSet}, iter, mem, + sync::Arc, }; #[tokio::test] @@ -2942,20 +2948,28 @@ mod tests { tx.put_client_report(&unrelated_report).await?; let aggregation_job_id = AggregationJobId::random(); - tx.put_aggregation_job(&AggregationJob:: { + tx.put_aggregation_job(&AggregationJob::< + PRIO3_AES128_VERIFY_KEY_LENGTH, + Prio3Aes128Count, + > { aggregation_job_id, task_id: unrelated_task_id, aggregation_param: (), state: AggregationJobState::InProgress, }) .await?; - tx.put_report_aggregation(&ReportAggregation { - aggregation_job_id, - task_id, - nonce: aggregated_report.nonce(), - ord: 0, - state: ReportAggregationState::::Start, - }) + tx.put_report_aggregation( + &ReportAggregation { + aggregation_job_id, + task_id, + nonce: aggregated_report.nonce(), + ord: 0, + state: ReportAggregationState::< + PRIO3_AES128_VERIFY_KEY_LENGTH, + Prio3Aes128Count, + >::Start, + }, + ) .await }) }) @@ -3058,8 +3072,9 @@ mod tests { // We use Poplar1 for this test as it has a non-trivial aggregation parameter, to allow // better exercising the serialization/deserialization roundtrip of the aggregation_param. - type ToyPoplar1 = Poplar1, PrgAes128, 16>; - let aggregation_job = AggregationJob:: { + const PRG_SEED_SIZE: usize = 16; + type ToyPoplar1 = Poplar1, PrgAes128, PRG_SEED_SIZE>; + let aggregation_job = AggregationJob:: { aggregation_job_id: AggregationJobId::random(), task_id: TaskId::random(), aggregation_param: BTreeSet::from([ @@ -3148,7 +3163,10 @@ mod tests { )) .await?; for aggregation_job_id in aggregation_job_ids { - tx.put_aggregation_job(&AggregationJob:: { + tx.put_aggregation_job(&AggregationJob::< + PRIO3_AES128_VERIFY_KEY_LENGTH, + Prio3Aes128Count, + > { aggregation_job_id, task_id, aggregation_param: (), @@ -3158,7 +3176,10 @@ mod tests { } // Write an aggregation job that is finished. We don't want to retrieve this one. - tx.put_aggregation_job(&AggregationJob:: { + tx.put_aggregation_job(&AggregationJob::< + PRIO3_AES128_VERIFY_KEY_LENGTH, + Prio3Aes128Count, + > { aggregation_job_id: AggregationJobId::random(), task_id, aggregation_param: (), @@ -3175,7 +3196,10 @@ mod tests { Role::Helper, )) .await?; - tx.put_aggregation_job(&AggregationJob:: { + tx.put_aggregation_job(&AggregationJob::< + PRIO3_AES128_VERIFY_KEY_LENGTH, + Prio3Aes128Count, + > { aggregation_job_id: AggregationJobId::random(), task_id: helper_task_id, aggregation_param: (), @@ -3385,7 +3409,7 @@ mod tests { let rslt = ds .run_tx(|tx| { Box::pin(async move { - tx.get_aggregation_job::( + tx.get_aggregation_job::( TaskId::random(), AggregationJobId::random(), ) @@ -3399,12 +3423,14 @@ mod tests { let rslt = ds .run_tx(|tx| { Box::pin(async move { - tx.update_aggregation_job::(&AggregationJob { - aggregation_job_id: AggregationJobId::random(), - task_id: TaskId::random(), - aggregation_param: (), - state: AggregationJobState::InProgress, - }) + tx.update_aggregation_job::( + &AggregationJob { + aggregation_job_id: AggregationJobId::random(), + task_id: TaskId::random(), + aggregation_param: (), + state: AggregationJobState::InProgress, + }, + ) .await }) }) @@ -3420,9 +3446,10 @@ mod tests { // We use Poplar1 for this test as it has a non-trivial aggregation parameter, to allow // better exercising the serialization/deserialization roundtrip of the aggregation_param. - type ToyPoplar1 = Poplar1, PrgAes128, 16>; + const PRG_SEED_SIZE: usize = 16; + type ToyPoplar1 = Poplar1, PrgAes128, PRG_SEED_SIZE>; let task_id = TaskId::random(); - let first_aggregation_job = AggregationJob:: { + let first_aggregation_job = AggregationJob:: { aggregation_job_id: AggregationJobId::random(), task_id, aggregation_param: BTreeSet::from([ @@ -3431,7 +3458,7 @@ mod tests { ]), state: AggregationJobState::InProgress, }; - let second_aggregation_job = AggregationJob:: { + let second_aggregation_job = AggregationJob:: { aggregation_job_id: AggregationJobId::random(), task_id, aggregation_param: BTreeSet::from([ @@ -3465,7 +3492,7 @@ mod tests { Role::Leader, )) .await?; - tx.put_aggregation_job(&AggregationJob:: { + tx.put_aggregation_job(&AggregationJob:: { aggregation_job_id: AggregationJobId::random(), task_id: unrelated_task_id, aggregation_param: BTreeSet::from([ @@ -3500,13 +3527,13 @@ mod tests { install_test_trace_subscriber(); let (ds, _db_handle) = ephemeral_datastore(MockClock::default()).await; - let vdaf = Prio3Aes128Count::new(2).unwrap(); - let (verify_param, prep_step, prep_msg, output_share) = generate_vdaf_values(vdaf, (), 0); + let vdaf = Arc::new(Prio3::new_aes128_count(2).unwrap()); + let (prep_state, prep_msg, output_share) = generate_vdaf_values(vdaf.as_ref(), (), 0); for (ord, state) in [ - ReportAggregationState::::Start, - ReportAggregationState::Waiting(prep_step.clone(), None), - ReportAggregationState::Waiting(prep_step, Some(prep_msg)), + ReportAggregationState::::Start, + ReportAggregationState::Waiting(prep_state.clone(), None), + ReportAggregationState::Waiting(prep_state, Some(prep_msg)), ReportAggregationState::Finished(output_share), ReportAggregationState::Failed(ReportShareError::VdafPrepError), ReportAggregationState::Invalid, @@ -3531,7 +3558,10 @@ mod tests { Role::Leader, )) .await?; - tx.put_aggregation_job(&AggregationJob:: { + tx.put_aggregation_job(&AggregationJob::< + PRIO3_AES128_VERIFY_KEY_LENGTH, + Prio3Aes128Count, + > { aggregation_job_id, task_id, aggregation_param: (), @@ -3568,10 +3598,11 @@ mod tests { let got_report_aggregation = ds .run_tx(|tx| { - let verify_param = verify_param.clone(); + let vdaf = Arc::clone(&vdaf); Box::pin(async move { - tx.get_report_aggregation::( - &verify_param, + tx.get_report_aggregation( + vdaf.as_ref(), + Role::Leader, task_id, aggregation_job_id, nonce, @@ -3594,10 +3625,11 @@ mod tests { let got_report_aggregation = ds .run_tx(|tx| { - let verify_param = verify_param.clone(); + let vdaf = Arc::clone(&vdaf); Box::pin(async move { - tx.get_report_aggregation::( - &verify_param, + tx.get_report_aggregation( + vdaf.as_ref(), + Role::Leader, task_id, aggregation_job_id, nonce, @@ -3616,11 +3648,16 @@ mod tests { install_test_trace_subscriber(); let (ds, _db_handle) = ephemeral_datastore(MockClock::default()).await; + const VERIFY_KEY_LENGTH: usize = FakeVdaf::VERIFY_KEY_LENGTH; + let vdaf = Arc::new(FakeVdaf::default()); + let rslt = ds .run_tx(|tx| { + let vdaf = Arc::clone(&vdaf); Box::pin(async move { - tx.get_report_aggregation::( - &(), + tx.get_report_aggregation( + vdaf.as_ref(), + Role::Leader, TaskId::random(), AggregationJobId::random(), Nonce::new( @@ -3638,16 +3675,18 @@ mod tests { let rslt = ds .run_tx(|tx| { Box::pin(async move { - tx.update_report_aggregation::(&ReportAggregation { - aggregation_job_id: AggregationJobId::random(), - task_id: TaskId::random(), - nonce: Nonce::new( - Time::from_seconds_since_epoch(12345), - [1, 2, 3, 4, 5, 6, 7, 8], - ), - ord: 0, - state: ReportAggregationState::Invalid, - }) + tx.update_report_aggregation::( + &ReportAggregation { + aggregation_job_id: AggregationJobId::random(), + task_id: TaskId::random(), + nonce: Nonce::new( + Time::from_seconds_since_epoch(12345), + [1, 2, 3, 4, 5, 6, 7, 8], + ), + ord: 0, + state: ReportAggregationState::Invalid, + }, + ) .await }) }) @@ -3660,8 +3699,8 @@ mod tests { install_test_trace_subscriber(); let (ds, _db_handle) = ephemeral_datastore(MockClock::default()).await; - let vdaf = Prio3Aes128Count::new(2).unwrap(); - let (verify_param, prep_step, prep_msg, output_share) = generate_vdaf_values(vdaf, (), 0); + let vdaf = Arc::new(Prio3::new_aes128_count(2).unwrap()); + let (prep_state, prep_msg, output_share) = generate_vdaf_values(vdaf.as_ref(), (), 0); let task_id = TaskId::random(); let aggregation_job_id = AggregationJobId::random(); @@ -3669,7 +3708,7 @@ mod tests { let report_aggregations = ds .run_tx(|tx| { let prep_msg = prep_msg.clone(); - let prep_step = prep_step.clone(); + let prep_state = prep_state.clone(); let output_share = output_share.clone(); Box::pin(async move { @@ -3679,7 +3718,10 @@ mod tests { Role::Leader, )) .await?; - tx.put_aggregation_job(&AggregationJob:: { + tx.put_aggregation_job(&AggregationJob::< + PRIO3_AES128_VERIFY_KEY_LENGTH, + Prio3Aes128Count, + > { aggregation_job_id, task_id, aggregation_param: (), @@ -3688,16 +3730,20 @@ mod tests { .await?; let mut report_aggregations = Vec::new(); - for (ord, state) in [ - ReportAggregationState::::Start, - ReportAggregationState::Waiting(prep_step.clone(), None), - ReportAggregationState::Waiting(prep_step, Some(prep_msg)), - ReportAggregationState::Finished(output_share), - ReportAggregationState::Failed(ReportShareError::VdafPrepError), - ReportAggregationState::Invalid, - ] - .iter() - .enumerate() + for (ord, state) in + [ + ReportAggregationState::< + PRIO3_AES128_VERIFY_KEY_LENGTH, + Prio3Aes128Count, + >::Start, + ReportAggregationState::Waiting(prep_state.clone(), None), + ReportAggregationState::Waiting(prep_state, Some(prep_msg)), + ReportAggregationState::Finished(output_share), + ReportAggregationState::Failed(ReportShareError::VdafPrepError), + ReportAggregationState::Invalid, + ] + .iter() + .enumerate() { let nonce = Nonce::new( Time::from_seconds_since_epoch(12345), @@ -3735,11 +3781,11 @@ mod tests { let got_report_aggregations = ds .run_tx(|tx| { - let verify_param = verify_param.clone(); - + let vdaf = Arc::clone(&vdaf); Box::pin(async move { tx.get_report_aggregations_for_aggregation_job( - &verify_param, + vdaf.as_ref(), + Role::Leader, task_id, aggregation_job_id, ) @@ -4006,7 +4052,9 @@ mod tests { .unwrap(); let first_collect_job = tx - .get_collect_job::(first_collect_job_id) + .get_collect_job::( + first_collect_job_id, + ) .await .unwrap() .unwrap(); @@ -4016,7 +4064,9 @@ mod tests { assert_eq!(first_collect_job.state, CollectJobState::Start); let second_collect_job = tx - .get_collect_job::(second_collect_job_id) + .get_collect_job::( + second_collect_job_id, + ) .await .unwrap() .unwrap(); @@ -4035,7 +4085,7 @@ mod tests { ) .unwrap(); - tx.update_collect_job::( + tx.update_collect_job::( first_collect_job_id, &leader_aggregate_share, &encrypted_helper_aggregate_share, @@ -4044,7 +4094,9 @@ mod tests { .unwrap(); let first_collect_job = tx - .get_collect_job::(first_collect_job_id) + .get_collect_job::( + first_collect_job_id, + ) .await .unwrap() .unwrap(); @@ -4093,7 +4145,9 @@ mod tests { let collect_job_id = tx.put_collect_job(task_id, batch_interval, &[]).await?; let collect_job = tx - .get_collect_job::(collect_job_id) + .get_collect_job::( + collect_job_id, + ) .await? .unwrap(); @@ -4121,7 +4175,9 @@ mod tests { Box::pin(async move { tx.cancel_collect_job(collect_job_id).await?; let collect_job = tx - .get_collect_job::(collect_job_id) + .get_collect_job::( + collect_job_id, + ) .await? .unwrap(); Ok(collect_job) @@ -4159,8 +4215,8 @@ mod tests { struct CollectJobAcquireTestCase { task_ids: Vec, reports: Vec, - aggregation_jobs: Vec>, - report_aggregations: Vec>, + aggregation_jobs: Vec>, + report_aggregations: Vec>, collect_job_test_cases: Vec, } @@ -4168,6 +4224,7 @@ mod tests { ds: &Datastore, test_case: CollectJobAcquireTestCase, ) -> CollectJobAcquireTestCase { + const VERIFY_KEY_LENGTH: usize = FakeVdaf::VERIFY_KEY_LENGTH; ds.run_tx(|tx| { let mut test_case = test_case.clone(); Box::pin(async move { @@ -4198,7 +4255,7 @@ mod tests { .await?; if test_case.set_aggregate_shares { - tx.update_collect_job::( + tx.update_collect_job::( collect_job_id, &dummy_vdaf::AggregateShare(), &HpkeCiphertext::new(HpkeConfigId::from(0), vec![], vec![]), @@ -4269,16 +4326,18 @@ mod tests { let clock = MockClock::default(); let (ds, _db_handle) = ephemeral_datastore(clock.clone()).await; + const VERIFY_KEY_LENGTH: usize = FakeVdaf::VERIFY_KEY_LENGTH; + let task_id = TaskId::random(); let reports = vec![new_dummy_report(task_id, Time::from_seconds_since_epoch(0))]; let aggregation_job_id = AggregationJobId::random(); - let aggregation_jobs = vec![AggregationJob:: { + let aggregation_jobs = vec![AggregationJob:: { aggregation_job_id, aggregation_param: 0u8, task_id, state: AggregationJobState::Finished, }]; - let report_aggregations = vec![ReportAggregation:: { + let report_aggregations = vec![ReportAggregation:: { aggregation_job_id, task_id, nonce: reports[0].nonce(), @@ -4388,10 +4447,12 @@ mod tests { let clock = MockClock::default(); let (ds, _db_handle) = ephemeral_datastore(clock.clone()).await; + const VERIFY_KEY_LENGTH: usize = FakeVdaf::VERIFY_KEY_LENGTH; + let task_id = TaskId::random(); let other_task_id = TaskId::random(); - let aggregation_jobs = vec![AggregationJob:: { + let aggregation_jobs = vec![AggregationJob:: { aggregation_job_id: AggregationJobId::random(), aggregation_param: 0u8, // Aggregation job task ID does not match collect job task ID @@ -4431,10 +4492,12 @@ mod tests { let clock = MockClock::default(); let (ds, _db_handle) = ephemeral_datastore(clock.clone()).await; + const VERIFY_KEY_LENGTH: usize = FakeVdaf::VERIFY_KEY_LENGTH; + let task_id = TaskId::random(); let reports = vec![new_dummy_report(task_id, Time::from_seconds_since_epoch(0))]; - let aggregation_jobs = vec![AggregationJob:: { + let aggregation_jobs = vec![AggregationJob:: { aggregation_job_id: AggregationJobId::random(), // Aggregation job agg param does not match collect job agg param aggregation_param: 1u8, @@ -4474,6 +4537,8 @@ mod tests { let clock = MockClock::default(); let (ds, _db_handle) = ephemeral_datastore(clock.clone()).await; + const VERIFY_KEY_LENGTH: usize = FakeVdaf::VERIFY_KEY_LENGTH; + let task_id = TaskId::random(); let reports = vec![new_dummy_report( task_id, @@ -4482,13 +4547,13 @@ mod tests { Time::from_seconds_since_epoch(200), )]; let aggregation_job_id = AggregationJobId::random(); - let aggregation_jobs = vec![AggregationJob:: { + let aggregation_jobs = vec![AggregationJob:: { aggregation_job_id, aggregation_param: 0u8, task_id, state: AggregationJobState::Finished, }]; - let report_aggregations = vec![ReportAggregation:: { + let report_aggregations = vec![ReportAggregation:: { aggregation_job_id, task_id, nonce: reports[0].nonce(), @@ -4529,17 +4594,19 @@ mod tests { let clock = MockClock::default(); let (ds, _db_handle) = ephemeral_datastore(clock.clone()).await; + const VERIFY_KEY_LENGTH: usize = FakeVdaf::VERIFY_KEY_LENGTH; + let task_id = TaskId::random(); let reports = vec![new_dummy_report(task_id, Time::from_seconds_since_epoch(0))]; let aggregation_job_id = AggregationJobId::random(); - let aggregation_jobs = vec![AggregationJob:: { + let aggregation_jobs = vec![AggregationJob:: { aggregation_job_id, aggregation_param: 0u8, task_id, state: AggregationJobState::Finished, }]; - let report_aggregations = vec![ReportAggregation:: { + let report_aggregations = vec![ReportAggregation:: { aggregation_job_id, task_id, nonce: reports[0].nonce(), @@ -4580,6 +4647,8 @@ mod tests { let clock = MockClock::default(); let (ds, _db_handle) = ephemeral_datastore(clock.clone()).await; + const VERIFY_KEY_LENGTH: usize = FakeVdaf::VERIFY_KEY_LENGTH; + let task_id = TaskId::random(); let reports = vec![ new_dummy_report(task_id, Time::from_seconds_since_epoch(0)), @@ -4588,13 +4657,13 @@ mod tests { let aggregation_job_ids = [AggregationJobId::random(), AggregationJobId::random()]; let aggregation_jobs = vec![ - AggregationJob:: { + AggregationJob:: { aggregation_job_id: aggregation_job_ids[0], aggregation_param: 0u8, task_id, state: AggregationJobState::Finished, }, - AggregationJob:: { + AggregationJob:: { aggregation_job_id: aggregation_job_ids[1], aggregation_param: 0u8, task_id, @@ -4604,14 +4673,14 @@ mod tests { ]; let report_aggregations = vec![ - ReportAggregation:: { + ReportAggregation:: { aggregation_job_id: aggregation_job_ids[0], task_id, nonce: reports[0].nonce(), ord: 0, state: ReportAggregationState::Start, }, - ReportAggregation:: { + ReportAggregation:: { aggregation_job_id: aggregation_job_ids[1], task_id, nonce: reports[1].nonce(), @@ -4652,17 +4721,19 @@ mod tests { let clock = MockClock::default(); let (ds, _db_handle) = ephemeral_datastore(clock.clone()).await; + const VERIFY_KEY_LENGTH: usize = FakeVdaf::VERIFY_KEY_LENGTH; + let task_id = TaskId::random(); let reports = vec![new_dummy_report(task_id, Time::from_seconds_since_epoch(0))]; let aggregation_job_ids = [AggregationJobId::random(), AggregationJobId::random()]; let aggregation_jobs = vec![ - AggregationJob:: { + AggregationJob:: { aggregation_job_id: aggregation_job_ids[0], aggregation_param: 0u8, task_id, state: AggregationJobState::Finished, }, - AggregationJob:: { + AggregationJob:: { aggregation_job_id: aggregation_job_ids[1], aggregation_param: 1u8, task_id, @@ -4670,14 +4741,14 @@ mod tests { }, ]; let report_aggregations = vec![ - ReportAggregation:: { + ReportAggregation:: { aggregation_job_id: aggregation_job_ids[0], task_id, nonce: reports[0].nonce(), ord: 0, state: ReportAggregationState::Start, }, - ReportAggregation:: { + ReportAggregation:: { aggregation_job_id: aggregation_job_ids[1], task_id, nonce: reports[0].nonce(), @@ -4779,7 +4850,8 @@ mod tests { async fn roundtrip_batch_unit_aggregation() { install_test_trace_subscriber(); - type ToyPoplar1 = Poplar1, PrgAes128, 16>; + const PRG_SEED_SIZE: usize = 16; + type ToyPoplar1 = Poplar1, PrgAes128, PRG_SEED_SIZE>; let task_id = TaskId::random(); let other_task_id = TaskId::random(); @@ -4810,35 +4882,38 @@ mod tests { )) .await?; - let first_batch_unit_aggregation = BatchUnitAggregation:: { - task_id, - unit_interval_start: Time::from_seconds_since_epoch(100), - aggregation_param: aggregation_param.clone(), - aggregate_share: aggregate_share.clone(), - report_count: 0, - checksum: NonceChecksum::default(), - }; - - let second_batch_unit_aggregation = BatchUnitAggregation:: { - task_id, - unit_interval_start: Time::from_seconds_since_epoch(150), - aggregation_param: aggregation_param.clone(), - aggregate_share: aggregate_share.clone(), - report_count: 0, - checksum: NonceChecksum::default(), - }; - - let third_batch_unit_aggregation = BatchUnitAggregation:: { - task_id, - unit_interval_start: Time::from_seconds_since_epoch(200), - aggregation_param: aggregation_param.clone(), - aggregate_share: aggregate_share.clone(), - report_count: 0, - checksum: NonceChecksum::default(), - }; + let first_batch_unit_aggregation = + BatchUnitAggregation:: { + task_id, + unit_interval_start: Time::from_seconds_since_epoch(100), + aggregation_param: aggregation_param.clone(), + aggregate_share: aggregate_share.clone(), + report_count: 0, + checksum: NonceChecksum::default(), + }; + + let second_batch_unit_aggregation = + BatchUnitAggregation:: { + task_id, + unit_interval_start: Time::from_seconds_since_epoch(150), + aggregation_param: aggregation_param.clone(), + aggregate_share: aggregate_share.clone(), + report_count: 0, + checksum: NonceChecksum::default(), + }; + + let third_batch_unit_aggregation = + BatchUnitAggregation:: { + task_id, + unit_interval_start: Time::from_seconds_since_epoch(200), + aggregation_param: aggregation_param.clone(), + aggregate_share: aggregate_share.clone(), + report_count: 0, + checksum: NonceChecksum::default(), + }; // Start of this aggregation's interval is before the interval queried below. - tx.put_batch_unit_aggregation(&BatchUnitAggregation:: { + tx.put_batch_unit_aggregation(&BatchUnitAggregation:: { task_id, unit_interval_start: Time::from_seconds_since_epoch(25), aggregation_param: aggregation_param.clone(), @@ -4858,7 +4933,7 @@ mod tests { tx.put_batch_unit_aggregation(&third_batch_unit_aggregation) .await?; // Aggregation parameter differs from the one queried below. - tx.put_batch_unit_aggregation(&BatchUnitAggregation:: { + tx.put_batch_unit_aggregation(&BatchUnitAggregation:: { task_id, unit_interval_start: Time::from_seconds_since_epoch(100), aggregation_param: BTreeSet::from([ @@ -4872,7 +4947,7 @@ mod tests { .await?; // End of this aggregation's interval is after the interval queried below. - tx.put_batch_unit_aggregation(&BatchUnitAggregation:: { + tx.put_batch_unit_aggregation(&BatchUnitAggregation:: { task_id, unit_interval_start: Time::from_seconds_since_epoch(250), aggregation_param: aggregation_param.clone(), @@ -4883,7 +4958,7 @@ mod tests { .await?; // Start of this aggregation's interval is after the interval queried below. - tx.put_batch_unit_aggregation(&BatchUnitAggregation:: { + tx.put_batch_unit_aggregation(&BatchUnitAggregation:: { task_id, unit_interval_start: Time::from_seconds_since_epoch(400), aggregation_param: aggregation_param.clone(), @@ -4894,7 +4969,7 @@ mod tests { .await?; // Task ID differs from that queried below. - tx.put_batch_unit_aggregation(&BatchUnitAggregation:: { + tx.put_batch_unit_aggregation(&BatchUnitAggregation:: { task_id: other_task_id, unit_interval_start: Time::from_seconds_since_epoch(200), aggregation_param: aggregation_param.clone(), @@ -4905,7 +4980,7 @@ mod tests { .await?; let batch_unit_aggregations = tx - .get_batch_unit_aggregations_for_task_in_interval::( + .get_batch_unit_aggregations_for_task_in_interval::( task_id, Interval::new( Time::from_seconds_since_epoch(50), @@ -4934,18 +5009,19 @@ mod tests { ); assert!(batch_unit_aggregations.contains(&third_batch_unit_aggregation)); - let updated_first_batch_unit_aggregation = BatchUnitAggregation:: { - aggregate_share: AggregateShare::from(vec![Field64::from(25)]), - report_count: 1, - checksum: NonceChecksum::get_decoded(&[1; 32]).unwrap(), - ..first_batch_unit_aggregation - }; + let updated_first_batch_unit_aggregation = + BatchUnitAggregation:: { + aggregate_share: AggregateShare::from(vec![Field64::from(25)]), + report_count: 1, + checksum: NonceChecksum::get_decoded(&[1; 32]).unwrap(), + ..first_batch_unit_aggregation + }; tx.update_batch_unit_aggregation(&updated_first_batch_unit_aggregation) .await?; let batch_unit_aggregations = tx - .get_batch_unit_aggregations_for_task_in_interval::( + .get_batch_unit_aggregations_for_task_in_interval::( task_id, Interval::new( Time::from_seconds_since_epoch(50), @@ -5020,18 +5096,22 @@ mod tests { checksum, }; - tx.put_aggregate_share_job::(&aggregate_share_job) - .await - .unwrap(); + tx.put_aggregate_share_job::( + &aggregate_share_job, + ) + .await + .unwrap(); let aggregate_share_job_again = tx - .get_aggregate_share_job_by_request::(&AggregateShareReq { - task_id, - batch_interval, - aggregation_param: ().get_encoded(), - report_count, - checksum, - }) + .get_aggregate_share_job_by_request::( + &AggregateShareReq { + task_id, + batch_interval, + aggregation_param: ().get_encoded(), + report_count, + checksum, + }, + ) .await .unwrap() .unwrap(); @@ -5039,13 +5119,15 @@ mod tests { assert_eq!(aggregate_share_job, aggregate_share_job_again); assert!(tx - .get_aggregate_share_job_by_request::(&AggregateShareReq { - task_id, - batch_interval: other_batch_interval, - aggregation_param: ().get_encoded(), - report_count, - checksum, - },) + .get_aggregate_share_job_by_request::( + &AggregateShareReq { + task_id, + batch_interval: other_batch_interval, + aggregation_param: ().get_encoded(), + report_count, + checksum, + }, + ) .await .unwrap() .is_none()); @@ -5118,14 +5200,16 @@ mod tests { ]; for (task_id, interval) in aggregate_share_jobs { - tx.put_aggregate_share_job::(&AggregateShareJob { - task_id, - batch_interval: interval, - aggregation_param: (), - helper_aggregate_share: aggregate_share.clone(), - report_count: 10, - checksum: NonceChecksum::get_decoded(&[1; 32]).unwrap(), - }) + tx.put_aggregate_share_job::( + &AggregateShareJob { + task_id, + batch_interval: interval, + aggregation_param: (), + helper_aggregate_share: aggregate_share.clone(), + report_count: 10, + checksum: NonceChecksum::get_decoded(&[1; 32]).unwrap(), + }, + ) .await .unwrap(); } @@ -5234,49 +5318,33 @@ mod tests { /// with the same aggregator. /// /// generate_vdaf_values assumes that the VDAF in use is one-round. - fn generate_vdaf_values( - vdaf: A, + fn generate_vdaf_values + vdaf::Client>( + vdaf: &A, agg_param: A::AggregationParam, measurement: A::Measurement, - ) -> ( - A::VerifyParam, - A::PrepareStep, - A::PrepareMessage, - A::OutputShare, - ) + ) -> (A::PrepareState, A::PrepareMessage, A::OutputShare) where for<'a> &'a A::AggregateShare: Into>, { - let (public_param, mut verify_params) = vdaf.setup().unwrap(); + let input_shares = vdaf.shard(&measurement).unwrap(); + let mut verify_key = [0u8; L]; + thread_rng().fill(&mut verify_key[..]); - let input_shares = vdaf.shard(&public_param, &measurement).unwrap(); - let prep_states: Vec = verify_params + let (mut prep_states, prep_shares): (Vec<_>, Vec<_>) = input_shares .iter() - .zip(input_shares) - .map(|(verify_param, input_share)| { - vdaf.prepare_init(verify_param, &agg_param, b"nonce", &input_share) + .enumerate() + .map(|(agg_id, input_share)| { + vdaf.prepare_init(&verify_key, agg_id, &agg_param, b"nonce", input_share) .unwrap() }) - .collect(); - let (mut prep_states, prep_msgs): (Vec, Vec) = - prep_states - .iter() - .map(|prep_state| { - if let PrepareTransition::Continue(prep_state, prep_msg) = - vdaf.prepare_step(prep_state.clone(), None) - { - (prep_state, prep_msg) - } else { - panic!("generate_vdaf_values: VDAF returned something other than Continue") - } - }) - .unzip(); - let prep_msg = vdaf.prepare_preprocess(prep_msgs).unwrap(); + .unzip(); + let prep_msg = vdaf.prepare_preprocess(prep_shares).unwrap(); let mut output_shares: Vec = prep_states .iter() .map(|prep_state| { - if let PrepareTransition::Finish(output_share) = - vdaf.prepare_step(prep_state.clone(), Some(prep_msg.clone())) + if let PrepareTransition::Finish(output_share) = vdaf + .prepare_step(prep_state.clone(), prep_msg.clone()) + .unwrap() { output_share } else { @@ -5285,11 +5353,6 @@ mod tests { }) .collect(); - ( - verify_params.remove(0), - prep_states.remove(0), - prep_msg, - output_shares.remove(0), - ) + (prep_states.remove(0), prep_msg, output_shares.remove(0)) } } diff --git a/janus_server/src/message.rs b/janus_server/src/message.rs index c5d0ef07c..dd98ac474 100644 --- a/janus_server/src/message.rs +++ b/janus_server/src/message.rs @@ -88,9 +88,9 @@ impl Encode for PrepareStepResult { // The encoding includes an implicit discriminator byte, called PrepareStepResult in the // DAP spec. match self { - Self::Continued(prep_msg) => { + Self::Continued(vdaf_msg) => { 0u8.encode(bytes); - encode_u16_items(bytes, &(), prep_msg); + encode_u16_items(bytes, &(), vdaf_msg); } Self::Finished => 1u8.encode(bytes), Self::Failed(error) => { @@ -522,7 +522,7 @@ mod tests { ), "00", // prepare_step_result concat!( - // prep_msg + // vdaf_msg "0006", // length "303132333435", // opaque data ), diff --git a/janus_server/src/task.rs b/janus_server/src/task.rs index 20a1eb945..982526009 100644 --- a/janus_server/src/task.rs +++ b/janus_server/src/task.rs @@ -20,6 +20,8 @@ pub enum Error { InvalidParameter(&'static str), #[error("URL parse error")] Url(#[from] url::ParseError), + #[error("aggregator auth key size out of range")] + AggregatorAuthKeySize, } /// Identifiers for VDAFs supported by this aggregator, corresponding to @@ -45,6 +47,25 @@ impl From for VdafInstance { } } +impl VdafInstance { + /// Returns the expected length of a VDAF verification key for a VDAF of this type. + fn verify_key_length(&self) -> usize { + match self { + // All "real" VDAFs use a verify key of length 16 currently. (Poplar1 may not, but it's + // not yet done being specified, so choosing 16 bytes is fine for testing.) + VdafInstance::Real(_) => PRIO3_AES128_VERIFY_KEY_LENGTH, + + #[cfg(test)] + VdafInstance::Fake + | VdafInstance::FakeFailsPrepInit + | VdafInstance::FakeFailsPrepStep => 0, + } + } +} + +/// The length of the verify key parameter for Prio3 AES-128 VDAF instantiations. +pub const PRIO3_AES128_VERIFY_KEY_LENGTH: usize = 16; + impl Serialize for VdafInstance { fn serialize(&self, serializer: S) -> Result where @@ -173,9 +194,9 @@ pub struct Task { pub vdaf: VdafInstance, /// The role performed by the aggregator. pub role: Role, - /// Secret verification parameters shared by the aggregators. + /// Secret verification keys shared by the aggregators. #[derivative(Debug = "ignore")] - pub vdaf_verify_parameters: Vec>, + pub vdaf_verify_keys: Vec>, /// The maximum number of times a given batch may be collected. pub(crate) max_batch_lifetime: u64, /// The minimum number of reports in a batch to allow it to be collected. @@ -202,7 +223,7 @@ impl Task { aggregator_endpoints: Vec, vdaf: VdafInstance, role: Role, - vdaf_verify_parameters: Vec>, + vdaf_verify_keys: Vec>, max_batch_lifetime: u64, min_batch_size: u64, min_batch_duration: Duration, @@ -221,8 +242,8 @@ impl Task { if agg_auth_tokens.is_empty() { return Err(Error::InvalidParameter("agg_auth_tokens")); } - if vdaf_verify_parameters.is_empty() { - return Err(Error::InvalidParameter("vdaf_verify_parameters")); + if vdaf_verify_keys.is_empty() { + return Err(Error::InvalidParameter("vdaf_verify_keys")); } // Compute hpke_configs mapping cfg.id -> (cfg, key). @@ -239,7 +260,7 @@ impl Task { aggregator_endpoints, vdaf, role, - vdaf_verify_parameters, + vdaf_verify_keys, max_batch_lifetime, min_batch_size, min_batch_duration, @@ -285,21 +306,13 @@ impl Task { // This is public to allow use in integration tests. #[doc(hidden)] pub mod test_util { + use std::iter; + use super::{AggregatorAuthenticationToken, Task, VdafInstance}; use janus::{ hpke::test_util::generate_hpke_config_and_private_key, message::{Duration, HpkeConfig, HpkeConfigId, Role, TaskId}, }; - use prio::{ - codec::Encode, - field::Field128, - vdaf::{ - self, - poplar1::{Poplar1, ToyIdpf}, - prg::PrgAes128, - prio3::{Prio3Aes128Count, Prio3Aes128Histogram, Prio3Aes128Sum}, - }, - }; use rand::{thread_rng, Rng}; /// Create a dummy [`Task`] from the provided [`TaskId`], with @@ -319,7 +332,9 @@ pub mod test_util { aggregator_config_1.public_key().clone(), ); - let vdaf_verify_parameter = verify_param_dispatch(&vdaf, role); + let vdaf_verify_key = iter::repeat_with(|| thread_rng().gen()) + .take(vdaf.verify_key_length()) + .collect(); Task::new( task_id, @@ -329,7 +344,7 @@ pub mod test_util { ], vdaf, role, - vec![vdaf_verify_parameter], + vec![vdaf_verify_key], 0, 0, Duration::from_hours(8).unwrap(), @@ -354,41 +369,6 @@ pub mod test_util { .into_bytes() .into() } - - fn verify_param_dispatch(vdaf: &VdafInstance, role: Role) -> Vec { - match &vdaf { - VdafInstance::Real(janus::task::VdafInstance::Prio3Aes128Count) => { - verify_param(Prio3Aes128Count::new(2).unwrap(), role) - } - VdafInstance::Real(janus::task::VdafInstance::Prio3Aes128Sum { bits }) => { - verify_param(Prio3Aes128Sum::new(2, *bits).unwrap(), role) - } - VdafInstance::Real(janus::task::VdafInstance::Prio3Aes128Histogram { buckets }) => { - verify_param(Prio3Aes128Histogram::new(2, &*buckets).unwrap(), role) - } - VdafInstance::Real(janus::task::VdafInstance::Poplar1 { bits }) => verify_param( - Poplar1::, PrgAes128, 16>::new(*bits), - role, - ), - - #[cfg(test)] - VdafInstance::Fake - | VdafInstance::FakeFailsPrepInit - | VdafInstance::FakeFailsPrepStep => Vec::new(), - } - } - - fn verify_param(vdaf: V, role: Role) -> Vec - where - for<'a> &'a V::AggregateShare: Into>, - V::VerifyParam: Encode, - { - let (_, verify_params) = vdaf.setup().unwrap(); - verify_params - .get(role.index().unwrap()) - .unwrap() - .get_encoded() - } } #[cfg(test)] diff --git a/monolithic_integration_test/Cargo.toml b/monolithic_integration_test/Cargo.toml index 85d0aacca..13726eea7 100644 --- a/monolithic_integration_test/Cargo.toml +++ b/monolithic_integration_test/Cargo.toml @@ -13,7 +13,8 @@ futures = "0.3.21" janus = { path = "../janus" } janus_server = { path = "../janus_server" } lazy_static = "1" -prio = "0.7.1" +prio = "0.8.0" +rand = "0.8" ring = "0.16.20" testcontainers = "0.14.0" janus_test_util = { path = "../test_util" } diff --git a/monolithic_integration_test/tests/integration_test.rs b/monolithic_integration_test/tests/integration_test.rs index ba7bbb8d5..5472140ad 100644 --- a/monolithic_integration_test/tests/integration_test.rs +++ b/monolithic_integration_test/tests/integration_test.rs @@ -9,13 +9,11 @@ use janus_server::{ aggregator::aggregator_server, client::{self, Client, ClientParameters}, datastore::{Crypter, Datastore}, - task::{test_util::generate_aggregator_auth_token, Task}, + task::{test_util::generate_aggregator_auth_token, Task, PRIO3_AES128_VERIFY_KEY_LENGTH}, trace::{install_trace_subscriber, TraceConfiguration}, }; -use prio::{ - codec::Encode, - vdaf::{prio3::Prio3Aes128Count, Vdaf as VdafTrait}, -}; +use prio::vdaf::prio3::{Prio3, Prio3Aes128Count}; +use rand::{thread_rng, Rng}; use std::{ net::{IpAddr, Ipv4Addr, SocketAddr}, sync::Arc, @@ -52,10 +50,8 @@ async fn setup_test() -> TestCase { let task_id = TaskId::random(); - let vdaf = Prio3Aes128Count::new(2).unwrap(); - let mut verify_params_iter = vdaf.setup().unwrap().1.into_iter(); - let leader_verify_param = verify_params_iter.next().unwrap(); - let helper_verify_param = verify_params_iter.next().unwrap(); + let mut verify_key = [0u8; PRIO3_AES128_VERIFY_KEY_LENGTH]; + thread_rng().fill(&mut verify_key[..]); let (collector_hpke_config, _) = generate_hpke_config_and_private_key(); let agg_auth_token = generate_aggregator_auth_token(); @@ -85,7 +81,7 @@ async fn setup_test() -> TestCase { ], VdafInstance::Prio3Aes128Count.into(), Role::Leader, - vec![leader_verify_param.get_encoded()], + vec![Vec::from(verify_key)], 1, 0, Duration::from_hours(8).unwrap(), @@ -118,7 +114,7 @@ async fn setup_test() -> TestCase { ], VdafInstance::Prio3Aes128Count.into(), Role::Helper, - vec![helper_verify_param.get_encoded()], + vec![Vec::from(verify_key)], 1, 0, Duration::from_hours(8).unwrap(), @@ -165,12 +161,11 @@ async fn setup_test() -> TestCase { .await .unwrap(); - let vdaf = Prio3Aes128Count::new(2).unwrap(); + let vdaf = Prio3::new_aes128_count(2).unwrap(); let client = Client::new( client_parameters, vdaf, - (), // no public parameter for prio3 RealClock::default(), &http_client, leader_report_config, diff --git a/test_util/Cargo.toml b/test_util/Cargo.toml index f822453b5..4aae2d461 100644 --- a/test_util/Cargo.toml +++ b/test_util/Cargo.toml @@ -8,7 +8,7 @@ publish = false [dependencies] assert_matches = "1" -prio = "0.7.1" +prio = "0.8.0" rand = "0.8" ring = "0.16.20" janus = { path = "../janus" } diff --git a/test_util/src/dummy_vdaf.rs b/test_util/src/dummy_vdaf.rs index 2ecf1e1e3..80d0265ae 100644 --- a/test_util/src/dummy_vdaf.rs +++ b/test_util/src/dummy_vdaf.rs @@ -11,12 +11,13 @@ use std::sync::Arc; pub type Vdaf = VdafWithAggregationParameter<()>; #[derive(Clone)] -pub struct VdafWithAggregationParameter { +pub struct VdafWithAggregationParameter { prep_init_fn: Arc Result<(), VdafError> + 'static + Send + Sync>, - prep_step_fn: Arc PrepareTransition<(), (), OutputShare> + 'static + Send + Sync>, + prep_step_fn: + Arc Result, VdafError> + 'static + Send + Sync>, } -impl Debug for VdafWithAggregationParameter { +impl Debug for VdafWithAggregationParameter { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("Vdaf") .field("prep_init_result", &"[omitted]") @@ -25,12 +26,15 @@ impl Debug for VdafWithAggregationParameter { } } -impl VdafWithAggregationParameter { +impl VdafWithAggregationParameter { + /// The length of the verify key parameter for fake VDAF instantiations. + pub const VERIFY_KEY_LENGTH: usize = 0; + pub fn new() -> Self { Self { prep_init_fn: Arc::new(|_| -> Result<(), VdafError> { Ok(()) }), - prep_step_fn: Arc::new(|| -> PrepareTransition<(), (), OutputShare> { - PrepareTransition::Finish(OutputShare()) + prep_step_fn: Arc::new(|| -> Result, VdafError> { + Ok(PrepareTransition::Finish(OutputShare())) }), } } @@ -43,7 +47,7 @@ impl VdafWithAggregationParameter { self } - pub fn with_prep_step_fn PrepareTransition<(), (), OutputShare>>( + pub fn with_prep_step_fn Result, VdafError>>( mut self, f: F, ) -> Self @@ -55,7 +59,7 @@ impl VdafWithAggregationParameter { } } -impl Default for VdafWithAggregationParameter { +impl Default for VdafWithAggregationParameter { fn default() -> Self { Self::new() } @@ -65,33 +69,30 @@ impl vdaf::Vdaf for VdafWithAggregationParam type Measurement = (); type AggregateResult = (); type AggregationParam = A; - type PublicParam = (); - type VerifyParam = (); type InputShare = (); type OutputShare = OutputShare; type AggregateShare = AggregateShare; - fn setup(&self) -> Result<(Self::PublicParam, Vec), VdafError> { - Ok(((), vec![(), ()])) - } - fn num_aggregators(&self) -> usize { 2 } } -impl vdaf::Aggregator for VdafWithAggregationParameter { - type PrepareStep = (); +impl vdaf::Aggregator<0> for VdafWithAggregationParameter { + type PrepareState = (); + type PrepareShare = (); type PrepareMessage = (); fn prepare_init( &self, - _: &Self::VerifyParam, + _: &[u8; 0], + _: usize, aggregation_param: &Self::AggregationParam, _: &[u8], _: &Self::InputShare, - ) -> Result { - (self.prep_init_fn)(aggregation_param) + ) -> Result<(Self::PrepareState, Self::PrepareShare), VdafError> { + (self.prep_init_fn)(aggregation_param)?; + Ok(((), ())) } fn prepare_preprocess>( @@ -103,9 +104,9 @@ impl vdaf::Aggregator for VdafWithAggregatio fn prepare_step( &self, - _: Self::PrepareStep, - _: Option, - ) -> PrepareTransition { + _: Self::PrepareState, + _: Self::PrepareMessage, + ) -> Result, VdafError> { (self.prep_step_fn)() } @@ -119,11 +120,7 @@ impl vdaf::Aggregator for VdafWithAggregatio } impl vdaf::Client for VdafWithAggregationParameter { - fn shard( - &self, - _: &Self::PublicParam, - _: &Self::Measurement, - ) -> Result, VdafError> { + fn shard(&self, _: &Self::Measurement) -> Result, VdafError> { Ok(vec![(), ()]) } } diff --git a/test_util/src/lib.rs b/test_util/src/lib.rs index 9ea5c6cdf..1c46da73a 100644 --- a/test_util/src/lib.rs +++ b/test_util/src/lib.rs @@ -5,7 +5,7 @@ use janus::{ }; use prio::{ codec::Encode, - vdaf::{self, VdafError}, + vdaf::{self, PrepareTransition, VdafError}, }; use rand::{thread_rng, Rng}; use ring::aead::{LessSafeKey, UnboundKey, AES_128_GCM}; @@ -144,84 +144,75 @@ impl Default for MockClock { } } -/// A type alias for [`prio::vdaf::PrepareTransition`] that derives the appropriate generic types -/// based on a single aggregator parameter. -// TODO(https://github.com/divviup/libprio-rs/issues/231): change libprio-rs' PrepareTransition to be generic only on a vdaf::Aggregator. -pub type PrepareTransition = vdaf::PrepareTransition< - ::PrepareStep, - ::PrepareMessage, - ::OutputShare, ->; - /// A transcript of a VDAF run. All fields are indexed by natural role index (i.e., index 0 = /// leader, index 1 = helper). -pub struct VdafTranscript +pub struct VdafTranscript> where for<'a> &'a V::AggregateShare: Into>, { pub input_shares: Vec, - pub transitions: Vec>>, - pub combined_messages: Vec, + pub prepare_transitions: Vec>>, + pub prepare_messages: Vec, } /// run_vdaf runs a VDAF state machine from sharding through to generating an output share, /// returning a "transcript" of all states & messages. -pub fn run_vdaf( +pub fn run_vdaf + vdaf::Client>( vdaf: &V, - public_param: &V::PublicParam, - verify_params: &[V::VerifyParam], + verify_key: &[u8; L], aggregation_param: &V::AggregationParam, nonce: Nonce, measurement: &V::Measurement, -) -> VdafTranscript +) -> VdafTranscript where for<'a> &'a V::AggregateShare: Into>, { - assert_eq!(vdaf.num_aggregators(), verify_params.len()); - // Shard inputs into input shares, and initialize the initial PrepareTransitions. - let input_shares = vdaf.shard(public_param, measurement).unwrap(); - let mut prep_trans: Vec>> = input_shares + let input_shares = vdaf.shard(measurement).unwrap(); + let encoded_nonce = nonce.get_encoded(); + let mut prep_trans: Vec>> = input_shares .iter() - .zip(verify_params) - .map(|(input_share, verify_param)| { - let prep_step = vdaf.prepare_init( - verify_param, + .enumerate() + .map(|(agg_id, input_share)| { + let (prep_state, prep_share) = vdaf.prepare_init( + verify_key, + agg_id, aggregation_param, - &nonce.get_encoded(), + &encoded_nonce, input_share, )?; - let prep_trans = vdaf.prepare_step(prep_step, None); - Ok(vec![prep_trans]) + Ok(vec![PrepareTransition::Continue(prep_state, prep_share)]) }) - .collect::>>, VdafError>>() + .collect::>>, VdafError>>() .unwrap(); - let mut combined_prep_msgs = Vec::new(); + let mut prep_msgs = Vec::new(); // Repeatedly step the VDAF until we reach a terminal state. loop { // Gather messages from last round & combine them into next round's message; if any // participants have reached a terminal state (Finish or Fail), we are done. - let mut prep_msgs = Vec::new(); + let mut prep_shares = Vec::new(); for pts in &prep_trans { match pts.last().unwrap() { - PrepareTransition::::Continue(_, prep_msg) => prep_msgs.push(prep_msg.clone()), + PrepareTransition::::Continue(_, prep_share) => { + prep_shares.push(prep_share.clone()) + } _ => { return VdafTranscript { input_shares, - transitions: prep_trans, - combined_messages: combined_prep_msgs, + prepare_transitions: prep_trans, + prepare_messages: prep_msgs, } } } } - let combined_prep_msg = vdaf.prepare_preprocess(prep_msgs).unwrap(); - combined_prep_msgs.push(combined_prep_msg.clone()); + let prep_msg = vdaf.prepare_preprocess(prep_shares).unwrap(); + prep_msgs.push(prep_msg.clone()); // Compute each participant's next transition. for pts in &mut prep_trans { - let prep_step = assert_matches!(pts.last().unwrap(), PrepareTransition::::Continue(prep_step, _) => prep_step).clone(); - pts.push(vdaf.prepare_step(prep_step, Some(combined_prep_msg.clone()))); + let prep_state = assert_matches!(pts.last().unwrap(), PrepareTransition::::Continue(prep_state, _) => prep_state).clone(); + pts.push(vdaf.prepare_step(prep_state, prep_msg.clone()).unwrap()); } } }