diff --git a/crates/dapf/src/acceptance/load_testing.rs b/crates/dapf/src/acceptance/load_testing.rs index e77fd5c8c..cf0d9966c 100644 --- a/crates/dapf/src/acceptance/load_testing.rs +++ b/crates/dapf/src/acceptance/load_testing.rs @@ -59,16 +59,16 @@ fn jobs_per_batch() -> impl Iterator<Item = usize> { fn vdaf_config_params() -> impl Iterator<Item = VdafConfig> { [ VdafConfig::Prio2 { dimension: 99_992 }, - VdafConfig::Prio3Draft09( - daphne::vdaf::Prio3Config::SumVecField64MultiproofHmacSha256Aes128 { + VdafConfig::Prio3( + daphne::vdaf::Prio3Config::Draft09SumVecField64MultiproofHmacSha256Aes128 { bits: 1, length: 100_000, chunk_length: 320, num_proofs: 2, }, ), - VdafConfig::Prio3Draft09( - daphne::vdaf::Prio3Config::SumVecField64MultiproofHmacSha256Aes128 { + VdafConfig::Prio3( + daphne::vdaf::Prio3Config::Draft09SumVecField64MultiproofHmacSha256Aes128 { bits: 1, length: 100_000, chunk_length: 320, diff --git a/crates/dapf/src/acceptance/mod.rs b/crates/dapf/src/acceptance/mod.rs index dc82028ab..d781da6fe 100644 --- a/crates/dapf/src/acceptance/mod.rs +++ b/crates/dapf/src/acceptance/mod.rs @@ -532,6 +532,7 @@ impl Test { agg_job_state, agg_job_resp, self.metrics(), + task_config.version, )?; let aggregated_report_count = agg_share_span diff --git a/crates/dapf/src/cli_parsers.rs b/crates/dapf/src/cli_parsers.rs index 80cbc7e82..4b3cbf2c6 100644 --- a/crates/dapf/src/cli_parsers.rs +++ b/crates/dapf/src/cli_parsers.rs @@ -58,22 +58,22 @@ impl DefaultVdafConfigs { fn into_vdaf_config(self) -> VdafConfig { match self { Self::Prio2Dimension99k => VdafConfig::Prio2 { dimension: 99_992 }, - Self::Prio3Draft09NumProofs2 => { - VdafConfig::Prio3Draft09(Prio3Config::SumVecField64MultiproofHmacSha256Aes128 { + Self::Prio3Draft09NumProofs2 => VdafConfig::Prio3( + Prio3Config::Draft09SumVecField64MultiproofHmacSha256Aes128 { bits: 1, length: 100_000, chunk_length: 320, num_proofs: 2, - }) - } - Self::Prio3Draft09NumProofs3 => { - VdafConfig::Prio3Draft09(Prio3Config::SumVecField64MultiproofHmacSha256Aes128 { + }, + ), + Self::Prio3Draft09NumProofs3 => VdafConfig::Prio3( + Prio3Config::Draft09SumVecField64MultiproofHmacSha256Aes128 { bits: 1, length: 100_000, chunk_length: 320, num_proofs: 3, - }) - } + }, + ), } } } diff --git a/crates/daphne-server/src/roles/mod.rs b/crates/daphne-server/src/roles/mod.rs index 62a5de962..ec96287a8 100644 --- a/crates/daphne-server/src/roles/mod.rs +++ b/crates/daphne-server/src/roles/mod.rs @@ -186,32 +186,36 @@ mod test_utils { cmd.vdaf.bits, cmd.vdaf.length, cmd.vdaf.chunk_length, + cmd.vdaf.dimension, ) { - ("Prio3Count", None, None, None) => VdafConfig::Prio3Draft09(Prio3Config::Count), - ("Prio3Sum", Some(bits), None, None) => VdafConfig::Prio3Draft09(Prio3Config::Sum { - bits: bits.parse().map_err(|e| fatal_error!(err = ?e, "failed to parse bits for Prio3Config::Sum"))?, + ("Prio3Count", None, None, None, None) => VdafConfig::Prio3(Prio3Config::Count), + ("Prio3Sum", Some(max_measurement), None, None, None) => VdafConfig::Prio3(Prio3Config::Sum { + max_measurement: max_measurement.parse().map_err(|e| fatal_error!(err = ?e, "failed to parse bits for Prio3Config::Sum"))?, }), - ("Prio3SumVec", Some(bits), Some(length), Some(chunk_length)) => { - VdafConfig::Prio3Draft09(Prio3Config::SumVec { + ("Prio3SumVec", Some(bits), Some(length), Some(chunk_length), None) => { + VdafConfig::Prio3(Prio3Config::SumVec { bits: bits.parse().map_err(|e| fatal_error!(err = ?e, "failed to parse bits for Prio3Config::SumVec"))?, length: length.parse().map_err(|e| fatal_error!(err = ?e, "failed to parse length for Prio3Config::SumVec"))?, chunk_length: chunk_length.parse().map_err(|e| fatal_error!(err = ?e, "failed to parse chunk_length for Prio3Config::SumVec"))?, }) } - ("Prio3Histogram", None, Some(length), Some(chunk_length)) => { - VdafConfig::Prio3Draft09(Prio3Config::Histogram { + ("Prio3Histogram", None, Some(length), Some(chunk_length), None) => { + VdafConfig::Prio3(Prio3Config::Histogram { length: length.parse().map_err(|e| fatal_error!(err = ?e, "failed to parse length for Prio3Config::Histogram"))?, chunk_length: chunk_length.parse().map_err(|e| fatal_error!(err = ?e, "failed to parse chunk_length for Prio3Config::Histogram"))?, }) } - ("Prio3SumVecField64MultiproofHmacSha256Aes128", Some(bits), Some(length), Some(chunk_length)) => { - VdafConfig::Prio3Draft09(Prio3Config::SumVecField64MultiproofHmacSha256Aes128 { + ("Prio3SumVecField64MultiproofHmacSha256Aes128", Some(bits), Some(length), Some(chunk_length), None) => { + VdafConfig::Prio3(Prio3Config::Draft09SumVecField64MultiproofHmacSha256Aes128 { bits: bits.parse().map_err(|e| fatal_error!(err = ?e, "failed to parse bits for Prio3Config::SumVecField64MultiproofHmacSha256Aes128"))?, length: length.parse().map_err(|e| fatal_error!(err = ?e, "failed to parse length for Prio3Config::SumVecField64MultiproofHmacSha256Aes128"))?, chunk_length: chunk_length.parse().map_err(|e| fatal_error!(err = ?e, "failed to parse chunk_length for Prio3Config::SumVecField64MultiproofHmacSha256Aes128"))?, num_proofs: 2, }) } + ("Prio2", None, None, None, Some(dimension)) => VdafConfig::Prio2 { + dimension: dimension.parse().map_err(|e| fatal_error!(err = ?e, "failed to parse dimension for Prio2"))?, + }, _ => return Err(fatal_error!(err = "command failed: unrecognized VDAF")), }; diff --git a/crates/daphne-server/tests/e2e/e2e.rs b/crates/daphne-server/tests/e2e/e2e.rs index 93a4ca009..309bc6d7f 100644 --- a/crates/daphne-server/tests/e2e/e2e.rs +++ b/crates/daphne-server/tests/e2e/e2e.rs @@ -151,7 +151,6 @@ async fn hpke_configs_are_cached(version: DapVersion) { async_test_versions! { hpke_configs_are_cached } -// TODO draft02 cleanup: In draft09, the client is meant to PUT its report, not POST it. async fn leader_upload(version: DapVersion) { let t = TestRunner::default_with_version(version).await; let mut rng = thread_rng(); @@ -171,7 +170,7 @@ async fn leader_upload(version: DapVersion) { &hpke_config_list, t.now, &t.task_id, - DapMeasurement::U64(23), + DapMeasurement::U32Vec(vec![1; 10]), version, ) .unwrap(); @@ -202,7 +201,7 @@ async fn leader_upload(version: DapVersion) { &hpke_config_list, t.now, &bad_id, - DapMeasurement::U64(999), + DapMeasurement::U32Vec(vec![1; 10]), version, ) .unwrap() @@ -222,7 +221,7 @@ async fn leader_upload(version: DapVersion) { &hpke_config_list, t.now, &t.task_id, - DapMeasurement::U64(999), + DapMeasurement::U32Vec(vec![1; 10]), version, ) .unwrap(); @@ -264,7 +263,7 @@ async fn leader_upload(version: DapVersion) { &hpke_config_list, t.task_config.not_after, // past the expiration date &t.task_id, - DapMeasurement::U64(23), + DapMeasurement::U32Vec(vec![1; 10]), version, ) .unwrap(); @@ -352,7 +351,7 @@ async fn leader_back_compat_upload(version: DapVersion) { &hpke_config_list, t.now, &t.task_id, - DapMeasurement::U64(23), + DapMeasurement::U32Vec(vec![0; 10]), version, ) .unwrap(); @@ -561,7 +560,7 @@ async fn internal_leader_process(version: DapVersion) { &hpke_config_list, now, &t.task_id, - DapMeasurement::U64(1), + DapMeasurement::U32Vec(vec![1; 10]), version, ) .unwrap() @@ -641,7 +640,7 @@ async fn leader_collect_ok(version: DapVersion) { &hpke_config_list, now, &t.task_id, - DapMeasurement::U64(1), + DapMeasurement::U32Vec(vec![1; 10]), version, ) .unwrap() @@ -707,7 +706,10 @@ async fn leader_collect_ok(version: DapVersion) { .unwrap(); assert_eq!( agg_res, - DapAggregateResult::U128(u128::from(t.task_config.min_batch_size)) + DapAggregateResult::U32Vec(vec![ + u32::try_from(t.task_config.min_batch_size).unwrap(); + 10 + ]) ); // Check that the time interval for the reports is correct. @@ -739,7 +741,7 @@ async fn leader_collect_ok(version: DapVersion) { // "upload", // DapMediaType::Report, // t.task_config.vdaf - // .produce_report(&hpke_config_list, now, &t.task_id, DapMeasurement::U64(1)) + // .produce_report(&hpke_config_list, now, &t.task_id, DapMeasurement::U32Vec(vec![1; 10])) // .unwrap() // .get_encoded(), // 400, @@ -779,7 +781,7 @@ async fn leader_collect_ok_interleaved(version: DapVersion) { &hpke_config_list, now, &t.task_id, - DapMeasurement::U64(1), + DapMeasurement::U32Vec(vec![1; 10]), version, ) .unwrap() @@ -843,7 +845,7 @@ async fn leader_collect_not_ready_min_batch_size(version: DapVersion) { &hpke_config_list, now, &t.task_id, - DapMeasurement::U64(1), + DapMeasurement::U32Vec(vec![1; 10]), version, ) .unwrap() @@ -936,7 +938,7 @@ async fn leader_collect_back_compat(version: DapVersion) { &hpke_config_list, now, &t.task_id, - DapMeasurement::U64(1), + DapMeasurement::U32Vec(vec![1; 10]), version, ) .unwrap() @@ -1091,7 +1093,7 @@ async fn leader_collect_abort_overlapping_batch_interval(version: DapVersion) { &hpke_config_list, now, &t.task_id, - DapMeasurement::U64(1), + DapMeasurement::U32Vec(vec![1; 10]), version, ) .unwrap() @@ -1187,7 +1189,7 @@ async fn leader_selected() { &hpke_config_list, t.now, &t.task_id, - DapMeasurement::U64(1), + DapMeasurement::U32Vec(vec![1; 10]), version, ) .unwrap() @@ -1256,7 +1258,10 @@ async fn leader_selected() { .unwrap(); assert_eq!( agg_res, - DapAggregateResult::U128(u128::from(t.task_config.min_batch_size)) + DapAggregateResult::U32Vec(vec![ + u32::try_from(t.task_config.min_batch_size).unwrap(); + 10 + ]) ); // Collector: Poll the collect URI once more. Expect the response to be the same as the first, @@ -1282,7 +1287,7 @@ async fn leader_selected() { &hpke_config_list, t.now, &t.task_id, - DapMeasurement::U64(1), + DapMeasurement::U32Vec(vec![1; 10]), version, ) .unwrap() diff --git a/crates/daphne-server/tests/e2e/test_runner.rs b/crates/daphne-server/tests/e2e/test_runner.rs index be798032d..e39c33a85 100644 --- a/crates/daphne-server/tests/e2e/test_runner.rs +++ b/crates/daphne-server/tests/e2e/test_runner.rs @@ -10,7 +10,7 @@ use daphne::{ encode_base64url, taskprov::TaskprovAdvertisement, Base64Encode, BatchId, CollectionJobId, Duration, HpkeConfigList, Interval, TaskId, }, - vdaf::{Prio3Config, VdafConfig}, + vdaf::VdafConfig, DapBatchMode, DapGlobalConfig, DapLeaderProcessTelemetry, DapTaskConfig, DapVersion, }; use daphne_service_utils::http_headers; @@ -29,7 +29,8 @@ use std::{ use tokio::time::timeout; use url::Url; -const VDAF_CONFIG: &VdafConfig = &VdafConfig::Prio3Draft09(Prio3Config::Sum { bits: 10 }); +// Use a VDAF that is supported in all versions of DAP. +const VDAF_CONFIG: &VdafConfig = &VdafConfig::Prio2 { dimension: 10 }; pub(crate) const MIN_BATCH_SIZE: u64 = 10; pub(crate) const MAX_BATCH_SIZE: u32 = 12; pub(crate) const TIME_PRECISION: Duration = 3600; // seconds @@ -167,10 +168,10 @@ impl TestRunner { encode_base64url(t.collector_hpke_receiver.config.get_encoded().unwrap()); let vdaf = json!({ - "type": "Prio3Sum", - "bits": assert_matches!( + "type": "Prio2", + "dimension": assert_matches!( t.task_config.vdaf, - VdafConfig::Prio3Draft09(Prio3Config::Sum{ bits }) => format!("{bits}") + VdafConfig::Prio2{ dimension } => format!("{dimension}") ), }); diff --git a/crates/daphne-service-utils/src/test_route_types.rs b/crates/daphne-service-utils/src/test_route_types.rs index 30c22f1b1..3fba74fde 100644 --- a/crates/daphne-service-utils/src/test_route_types.rs +++ b/crates/daphne-service-utils/src/test_route_types.rs @@ -30,34 +30,44 @@ pub struct InternalTestVdaf { pub length: Option<String>, #[serde(skip_serializing_if = "Option::is_none")] pub chunk_length: Option<String>, + #[serde(skip_serializing_if = "Option::is_none")] + pub dimension: Option<String>, } impl From<VdafConfig> for InternalTestVdaf { fn from(vdaf: VdafConfig) -> Self { - let (typ, bits, length, chunk_length) = match vdaf { - VdafConfig::Prio3Draft09(prio3) => match prio3 { - Prio3Config::Count => ("Prio3Draft09Count", None, None, None), - Prio3Config::Sum { bits } => ("Prio3Draft09Sum", Some(bits), None, None), + let (typ, bits, length, chunk_length, dimension) = match vdaf { + VdafConfig::Prio3(prio3) => match prio3 { + Prio3Config::Count => ("Prio3Count", None, None, None, None), + Prio3Config::Sum { max_measurement } => ( + "Prio3Sum", + Some(usize::try_from(max_measurement).unwrap()), + None, + None, + None, + ), Prio3Config::Histogram { length, chunk_length, } => ( - "Prio3Draft09Histogram", + "Prio3Histogram", None, Some(length), Some(chunk_length), + None, ), Prio3Config::SumVec { bits, length, chunk_length, } => ( - "Prio3Draft09SumVec", + "Prio3SumVec", Some(bits), Some(length), Some(chunk_length), + None, ), - Prio3Config::SumVecField64MultiproofHmacSha256Aes128 { + Prio3Config::Draft09SumVecField64MultiproofHmacSha256Aes128 { bits, length, chunk_length, @@ -67,34 +77,11 @@ impl From<VdafConfig> for InternalTestVdaf { Some(bits), Some(length), Some(chunk_length), + None, ), }, - VdafConfig::Prio3(prio3) => match prio3 { - Prio3Config::Count => ("Prio3Count", None, None, None), - Prio3Config::Sum { bits } => ("Prio3Sum", Some(bits), None, None), - Prio3Config::Histogram { - length, - chunk_length, - } => ("Prio3Histogram", None, Some(length), Some(chunk_length)), - Prio3Config::SumVec { - bits, - length, - chunk_length, - } => ("Prio3SumVec", Some(bits), Some(length), Some(chunk_length)), - Prio3Config::SumVecField64MultiproofHmacSha256Aes128 { - bits, - length, - chunk_length, - num_proofs: _unimplemented, - } => ( - "Prio3SumVecField64MultiproofHmacSha256Aes128", - Some(bits), - Some(length), - Some(chunk_length), - ), - }, - VdafConfig::Prio2 { .. } => ("Prio2", None, None, None), - VdafConfig::Pine(_) => ("Pine", None, None, None), + VdafConfig::Prio2 { dimension } => ("Prio2", None, None, None, Some(dimension)), + VdafConfig::Pine(_) => ("Pine", None, None, None, None), #[cfg(feature = "experimental")] VdafConfig::Mastic { .. } => todo!(), }; @@ -103,6 +90,7 @@ impl From<VdafConfig> for InternalTestVdaf { bits: bits.map(|a| a.to_string()), length: length.map(|a| a.to_string()), chunk_length: chunk_length.map(|a| a.to_string()), + dimension: dimension.map(|a| a.to_string()), } } } diff --git a/crates/daphne/benches/aggregation.rs b/crates/daphne/benches/aggregation.rs index 2ae70556d..318e15add 100644 --- a/crates/daphne/benches/aggregation.rs +++ b/crates/daphne/benches/aggregation.rs @@ -37,18 +37,20 @@ fn consume_reports_vary_vdaf_dimension(c: &mut Criterion) { let mut test = AggregationJobTest::new( &VdafConfig::Prio2 { dimension: 0 }, HpkeKemId::P256HkdfSha256, - DapVersion::Latest, + DapVersion::Draft09, ); test.disable_replay_protection(); let mut g = c.benchmark_group(function!()); for vdaf_length in vdaf_lengths { - let vdaf = VdafConfig::Prio3Draft09(Prio3Config::SumVecField64MultiproofHmacSha256Aes128 { - bits: 1, - length: vdaf_length, - chunk_length: 320, - num_proofs: 2, - }); + let vdaf = VdafConfig::Prio3( + Prio3Config::Draft09SumVecField64MultiproofHmacSha256Aes128 { + bits: 1, + length: vdaf_length, + chunk_length: 320, + num_proofs: 2, + }, + ); test.change_vdaf(vdaf); let reports = test .produce_repeated_reports(vdaf.gen_measurement().unwrap()) @@ -66,15 +68,16 @@ fn consume_reports_vary_vdaf_dimension(c: &mut Criterion) { } fn consume_reports_vary_num_reports(c: &mut Criterion) { - const VDAF: VdafConfig = - VdafConfig::Prio3Draft09(Prio3Config::SumVecField64MultiproofHmacSha256Aes128 { + const VDAF: VdafConfig = VdafConfig::Prio3( + Prio3Config::Draft09SumVecField64MultiproofHmacSha256Aes128 { bits: 1, length: 1000, chunk_length: 320, num_proofs: 2, - }); + }, + ); - let mut test = AggregationJobTest::new(&VDAF, HpkeKemId::P256HkdfSha256, DapVersion::Latest); + let mut test = AggregationJobTest::new(&VDAF, HpkeKemId::P256HkdfSha256, DapVersion::Draft09); test.disable_replay_protection(); let mut g = c.benchmark_group(function!()); diff --git a/crates/daphne/src/protocol/aggregator.rs b/crates/daphne/src/protocol/aggregator.rs index dc21fc315..124863ddf 100644 --- a/crates/daphne/src/protocol/aggregator.rs +++ b/crates/daphne/src/protocol/aggregator.rs @@ -21,8 +21,6 @@ use crate::{ protocol::{decode_ping_pong_framed, PingPongMessageType}, vdaf::{ prio2::{prio2_prep_finish, prio2_prep_finish_from_shares}, - prio3::{prio3_prep_finish, prio3_prep_finish_from_shares}, - prio3_draft09::{prio3_draft09_prep_finish, prio3_draft09_prep_finish_from_shares}, VdafError, }, AggregationJobReportState, DapAggregateShare, DapAggregateSpan, DapAggregationJobState, @@ -286,6 +284,7 @@ impl DapTaskConfig { report_status: &HashMap<ReportId, ReportProcessedStatus>, part_batch_sel: &PartialBatchSelector, initialized_reports: &[InitializedReport<WithPeerPrepShare>], + version: DapVersion, ) -> Result<(DapAggregateSpan<DapAggregateShare>, AggregationJobResp), DapError> { let num_reports = initialized_reports.len(); let mut agg_span = DapAggregateSpan::default(); @@ -304,23 +303,15 @@ impl DapTaskConfig { prep_state: helper_prep_state, } => { let res = match &self.vdaf { - VdafConfig::Prio3Draft09(prio3_config) => { - prio3_draft09_prep_finish_from_shares( - prio3_config, + VdafConfig::Prio3(prio3_config) => prio3_config + .prep_finish_from_shares( + version, 1, + task_id, helper_prep_state.clone(), helper_prep_share.clone(), leader_prep_share, - ) - } - VdafConfig::Prio3(prio3_config) => prio3_prep_finish_from_shares( - prio3_config, - 1, - task_id, - helper_prep_state.clone(), - helper_prep_share.clone(), - leader_prep_share, - ), + ), VdafConfig::Prio2 { dimension } => prio2_prep_finish_from_shares( *dimension, helper_prep_state.clone(), @@ -406,6 +397,7 @@ impl DapTaskConfig { state: DapAggregationJobState, agg_job_resp: AggregationJobResp, metrics: &dyn DaphneMetrics, + version: DapVersion, ) -> Result<DapAggregateSpan<DapAggregateShare>, DapError> { if agg_job_resp.transitions.len() != state.seq.len() { return Err(DapAbort::InvalidMessage { @@ -459,11 +451,8 @@ impl DapTaskConfig { }; let res = match &self.vdaf { - VdafConfig::Prio3Draft09(prio3_config) => { - prio3_draft09_prep_finish(prio3_config, leader.prep_state, prep_msg) - } VdafConfig::Prio3(prio3_config) => { - prio3_prep_finish(prio3_config, leader.prep_state, prep_msg, *task_id) + prio3_config.prep_finish(leader.prep_state, prep_msg, *task_id, version) } VdafConfig::Prio2 { dimension } => { prio2_prep_finish(*dimension, leader.prep_state, prep_msg) diff --git a/crates/daphne/src/protocol/client.rs b/crates/daphne/src/protocol/client.rs index 97a5b6604..349c380f1 100644 --- a/crates/daphne/src/protocol/client.rs +++ b/crates/daphne/src/protocol/client.rs @@ -7,7 +7,7 @@ use crate::{ constants::DapAggregatorRole, hpke::{info_and_aad, HpkeConfig}, messages::{Extension, PlaintextInputShare, Report, ReportId, ReportMetadata, TaskId, Time}, - vdaf::{prio2::prio2_shard, prio3::prio3_shard, prio3_draft09::prio3_draft09_shard, VdafError}, + vdaf::{prio2::prio2_shard, VdafError}, DapError, DapMeasurement, DapVersion, VdafConfig, }; use prio::codec::ParameterizedEncode; @@ -45,7 +45,7 @@ impl VdafConfig { let mut rng = thread_rng(); let report_id = ReportId(rng.gen()); let (public_share, input_shares) = self - .produce_input_shares(measurement, &report_id.0, task_id) + .produce_input_shares(measurement, &report_id.0, task_id, version) .map_err(DapError::from_vdaf)?; Self::produce_report_with_extensions_for_shares( public_share, @@ -122,13 +122,11 @@ impl VdafConfig { measurement: DapMeasurement, nonce: &[u8; 16], task_id: &TaskId, + version: DapVersion, ) -> Result<(Vec<u8>, [Vec<u8>; 2]), VdafError> { match self { - Self::Prio3Draft09(prio3_config) => { - Ok(prio3_draft09_shard(prio3_config, measurement, nonce)?) - } Self::Prio3(prio3_config) => { - Ok(prio3_shard(prio3_config, measurement, nonce, *task_id)?) + Ok(prio3_config.shard(version, measurement, nonce, *task_id)?) } Self::Prio2 { dimension } => Ok(prio2_shard(*dimension, measurement, nonce)?), #[cfg(feature = "experimental")] diff --git a/crates/daphne/src/protocol/collector.rs b/crates/daphne/src/protocol/collector.rs index 3dfca90c9..f6fc6a88f 100644 --- a/crates/daphne/src/protocol/collector.rs +++ b/crates/daphne/src/protocol/collector.rs @@ -8,7 +8,7 @@ use crate::{ fatal_error, hpke::{info_and_aad, HpkeDecrypter}, messages::{BatchSelector, HpkeCiphertext, TaskId}, - vdaf::{prio2::prio2_unshard, prio3::prio3_unshard, prio3_draft09::prio3_draft09_unshard}, + vdaf::prio2::prio2_unshard, DapAggregateResult, DapAggregationParam, DapError, DapVersion, VdafConfig, }; @@ -73,10 +73,9 @@ impl VdafConfig { let num_measurements = usize::try_from(report_count).unwrap(); match self { - Self::Prio3Draft09(prio3_config) => { - prio3_draft09_unshard(prio3_config, num_measurements, agg_shares) + Self::Prio3(prio3_config) => { + prio3_config.unshard(version, num_measurements, agg_shares) } - Self::Prio3(prio3_config) => prio3_unshard(prio3_config, num_measurements, agg_shares), Self::Prio2 { dimension } => prio2_unshard(*dimension, num_measurements, agg_shares), #[cfg(feature = "experimental")] Self::Mastic { diff --git a/crates/daphne/src/protocol/mod.rs b/crates/daphne/src/protocol/mod.rs index 64eadd55d..47627e7bc 100644 --- a/crates/daphne/src/protocol/mod.rs +++ b/crates/daphne/src/protocol/mod.rs @@ -72,7 +72,7 @@ mod test { }, test_versions, testing::AggregationJobTest, - vdaf::{Prio3Config, VdafConfig}, + vdaf::VdafConfig, DapAggregateResult, DapAggregateShare, DapAggregationParam, DapError, DapMeasurement, DapVersion, VdafAggregateShare, VdafPrepShare, VdafPrepState, }; @@ -80,24 +80,17 @@ mod test { use hpke_rs::HpkePublicKey; use prio::{ codec::encode_u32_items, + field::FieldPrio2, vdaf::{ - prio3::Prio3, Aggregator as VdafAggregator, Collector as VdafCollector, - PrepareTransition, - }, - }; - use prio_draft09::{ - field::Field64 as Field64Draft09, - vdaf::{ - prio3::Prio3 as Prio3Draft09, AggregateShare as AggregateShareDraft09, - Aggregator as VdafAggregatorDraft09, Collector as VdafCollectorDraft09, - OutputShare as OutputShareDraft09, PrepareTransition as PrepareTransitionDraft09, + prio2::Prio2, AggregateShare, Aggregator as VdafAggregator, Collector as VdafCollector, + OutputShare, PrepareTransition, }, }; use rand::prelude::*; use std::iter::zip; - const TEST_VDAF: &VdafConfig = &VdafConfig::Prio3(Prio3Config::Count); - const TEST_VDAF_DRAFT09: &VdafConfig = &VdafConfig::Prio3Draft09(Prio3Config::Count); + // Choose a VDAF for testing that is supported by all versions of DAP. + const TEST_VDAF: &VdafConfig = &VdafConfig::Prio2 { dimension: 10 }; fn roundtrip_report(version: DapVersion) { let t = AggregationJobTest::new(TEST_VDAF, HpkeKemId::X25519HkdfSha256, version); @@ -108,7 +101,7 @@ mod test { &t.client_hpke_config_list, t.now, &t.task_id, - DapMeasurement::U64(1), + DapMeasurement::U32Vec(vec![1; 10]), t.task_config.version, ) .unwrap(); @@ -170,26 +163,23 @@ mod test { helper_prep_share, ) { ( - VdafPrepState::Prio3Field64(leader_step), - VdafPrepState::Prio3Field64(helper_step), - VdafPrepShare::Prio3Field64(leader_share), - VdafPrepShare::Prio3Field64(helper_share), + VdafPrepState::Prio2(leader_step), + VdafPrepState::Prio2(helper_step), + VdafPrepShare::Prio2(leader_share), + VdafPrepShare::Prio2(helper_share), ) => { - let ctx = &["dap-13".as_bytes(), &t.task_id.0].concat(); - //let ctx = binding.as_slice(); - let vdaf = Prio3::new_count(2).unwrap(); - let message = vdaf - .prepare_shares_to_prepare_message(ctx, &(), [leader_share, helper_share]) + let vdaf = Prio2::new(10).unwrap(); + vdaf.prepare_shares_to_prepare_message(&[], &(), [leader_share, helper_share]) .unwrap(); let leader_out_share = assert_matches!( - vdaf.prepare_next(ctx, leader_step, message.clone()).unwrap(), + vdaf.prepare_next(&[], leader_step, ()).unwrap(), PrepareTransition::Finish(out_share) => out_share ); let leader_agg_share = vdaf.aggregate(&(), [leader_out_share]).unwrap(); let helper_out_share = assert_matches!( - vdaf.prepare_next(ctx, helper_step, message).unwrap(), + vdaf.prepare_next(&[], helper_step, ()).unwrap(), PrepareTransition::Finish(out_share) => out_share ); let helper_agg_share = vdaf.aggregate(&(), [helper_out_share]).unwrap(); @@ -197,7 +187,7 @@ mod test { assert_eq!( vdaf.unshard(&(), vec![leader_agg_share, helper_agg_share], 1) .unwrap(), - 1, + vec![1; 10], ); } _ => { @@ -208,113 +198,6 @@ mod test { test_versions! { roundtrip_report } - #[test] - fn roundtrip_report_vdaf_draft09() { - let version = DapVersion::Draft09; - let t = AggregationJobTest::new(TEST_VDAF_DRAFT09, HpkeKemId::X25519HkdfSha256, version); - let report = t - .task_config - .vdaf - .produce_report( - &t.client_hpke_config_list, - t.now, - &t.task_id, - DapMeasurement::U64(1), - t.task_config.version, - ) - .unwrap(); - - let [leader_share, helper_share] = report.encrypted_input_shares; - - let InitializedReport::Ready { - prep_share: leader_prep_share, - prep_state: leader_prep_state, - .. - } = InitializedReport::from_client( - &t.leader_hpke_receiver_config, - t.valid_report_time_range(), - &t.task_id, - &t.task_config, - ReportShare { - report_metadata: report.report_metadata.clone(), - public_share: report.public_share.clone(), - encrypted_input_share: leader_share, - }, - &DapAggregationParam::Empty, - ) - .unwrap() - else { - panic!("rejected unexpectedly"); - }; - - let InitializedReport::Ready { - prep_share: helper_prep_share, - prep_state: helper_prep_state, - .. - } = InitializedReport::from_leader( - &t.helper_hpke_receiver_config, - t.valid_report_time_range(), - &t.task_id, - &t.task_config, - ReportShare { - report_metadata: report.report_metadata, - public_share: report.public_share, - encrypted_input_share: helper_share, - }, - { - let mut outbound = Vec::new(); - outbound.push(PingPongMessageType::Initialize as u8); - encode_u32_items(&mut outbound, &version, &[leader_prep_share.clone()]).unwrap(); - outbound - }, - &DapAggregationParam::Empty, - ) - .unwrap() - else { - panic!("rejected unexpectedly"); - }; - - match ( - leader_prep_state, - helper_prep_state, - leader_prep_share, - helper_prep_share, - ) { - ( - VdafPrepState::Prio3Draft09Field64(leader_step), - VdafPrepState::Prio3Draft09Field64(helper_step), - VdafPrepShare::Prio3Draft09Field64(leader_share), - VdafPrepShare::Prio3Draft09Field64(helper_share), - ) => { - let vdaf = Prio3Draft09::new_count(2).unwrap(); - let message = vdaf - .prepare_shares_to_prepare_message(&(), [leader_share, helper_share]) - .unwrap(); - - let leader_out_share = assert_matches!( - vdaf.prepare_next(leader_step, message.clone()).unwrap(), - PrepareTransitionDraft09::Finish(out_share) => out_share - ); - let leader_agg_share = vdaf.aggregate(&(), [leader_out_share]).unwrap(); - - let helper_out_share = assert_matches!( - vdaf.prepare_next(helper_step, message).unwrap(), - PrepareTransitionDraft09::Finish(out_share) => out_share - ); - let helper_agg_share = vdaf.aggregate(&(), [helper_out_share]).unwrap(); - - assert_eq!( - vdaf.unshard(&(), vec![leader_agg_share, helper_agg_share], 1) - .unwrap(), - 1, - ); - } - _ => { - panic!("unexpected output from leader or helper"); - } - } - } - fn roundtrip_report_unsupported_hpke_suite(version: DapVersion) { let t = AggregationJobTest::new(TEST_VDAF, HpkeKemId::X25519HkdfSha256, version); @@ -334,7 +217,7 @@ mod test { &unsupported_hpke_config_list, t.now, &t.task_id, - DapMeasurement::U64(1), + DapMeasurement::U32Vec(vec![1; 10]), t.task_config.version, ); assert_matches!( @@ -348,9 +231,9 @@ mod test { fn produce_agg_job_req(version: DapVersion) { let t = AggregationJobTest::new(TEST_VDAF, HpkeKemId::X25519HkdfSha256, version); let reports = t.produce_reports(vec![ - DapMeasurement::U64(1), - DapMeasurement::U64(0), - DapMeasurement::U64(0), + DapMeasurement::U32Vec(vec![1; 10]), + DapMeasurement::U32Vec(vec![0; 10]), + DapMeasurement::U32Vec(vec![0; 10]), ]); let (agg_job_state, agg_job_init_req) = @@ -374,7 +257,7 @@ mod test { fn produce_agg_job_req_skip_hpke_decrypt_err(version: DapVersion) { let t = AggregationJobTest::new(TEST_VDAF, HpkeKemId::X25519HkdfSha256, version); - let mut reports = t.produce_reports(vec![DapMeasurement::U64(1)]); + let mut reports = t.produce_reports(vec![DapMeasurement::U32Vec(vec![0; 10])]); // Simulate HPKE decryption error of leader's report share. reports[0].encrypted_input_shares[0].payload[0] ^= 1; @@ -399,7 +282,7 @@ mod test { &t.client_hpke_config_list, t.valid_report_time_range().start - 1, &t.task_id, - DapMeasurement::U64(1), + DapMeasurement::U32Vec(vec![1; 10]), t.task_config.version, ) .unwrap()]; @@ -424,7 +307,7 @@ mod test { &t.client_hpke_config_list, t.valid_report_time_range().end + 1, &t.task_id, - DapMeasurement::U64(1), + DapMeasurement::U32Vec(vec![1; 10]), t.task_config.version, ) .unwrap()]; @@ -442,7 +325,7 @@ mod test { fn produce_agg_job_req_skip_hpke_unknown_config_id(version: DapVersion) { let t = AggregationJobTest::new(TEST_VDAF, HpkeKemId::X25519HkdfSha256, version); - let mut reports = t.produce_reports(vec![DapMeasurement::U64(1)]); + let mut reports = t.produce_reports(vec![DapMeasurement::U32Vec(vec![1; 10])]); // Client tries to send Leader encrypted input with incorrect config ID. reports[0].encrypted_input_shares[0].config_id ^= 1; @@ -461,8 +344,14 @@ mod test { fn produce_agg_job_req_skip_vdaf_prep_error(version: DapVersion) { let t = AggregationJobTest::new(TEST_VDAF, HpkeKemId::X25519HkdfSha256, version); let reports = vec![ - t.produce_invalid_report_public_share_decode_failure(DapMeasurement::U64(1), version), - t.produce_invalid_report_input_share_decode_failure(DapMeasurement::U64(1), version), + t.produce_invalid_report_public_share_decode_failure( + DapMeasurement::U32Vec(vec![1; 10]), + version, + ), + t.produce_invalid_report_input_share_decode_failure( + DapMeasurement::U32Vec(vec![1; 10]), + version, + ), ]; let (agg_job_state, _agg_job_init_req) = @@ -478,7 +367,7 @@ mod test { fn handle_agg_job_req_hpke_decrypt_err(version: DapVersion) { let t = AggregationJobTest::new(TEST_VDAF, HpkeKemId::X25519HkdfSha256, version); - let mut reports = t.produce_reports(vec![DapMeasurement::U64(1)]); + let mut reports = t.produce_reports(vec![DapMeasurement::U32Vec(vec![1; 10])]); // Simulate HPKE decryption error of helper's report share. reports[0].encrypted_input_shares[1].payload[0] ^= 1; @@ -505,7 +394,7 @@ mod test { &t.client_hpke_config_list, t.valid_report_time_range().start - 1, &t.task_id, - DapMeasurement::U64(1), + DapMeasurement::U32Vec(vec![1; 10]), t.task_config.version, ) .unwrap()]; @@ -540,7 +429,7 @@ mod test { &t.client_hpke_config_list, t.valid_report_time_range().end + 1, &t.task_id, - DapMeasurement::U64(1), + DapMeasurement::U32Vec(vec![1; 10]), t.task_config.version, ) .unwrap()]; @@ -568,7 +457,7 @@ mod test { fn handle_agg_job_req_hpke_unknown_config_id(version: DapVersion) { let t = AggregationJobTest::new(TEST_VDAF, HpkeKemId::X25519HkdfSha256, version); - let mut reports = t.produce_reports(vec![DapMeasurement::U64(1)]); + let mut reports = t.produce_reports(vec![DapMeasurement::U32Vec(vec![1; 10])]); // Client tries to send Helper encrypted input with incorrect config ID. reports[0].encrypted_input_shares[1].config_id ^= 1; @@ -588,10 +477,14 @@ mod test { fn handle_agg_job_req_vdaf_prep_error(version: DapVersion) { let t = AggregationJobTest::new(TEST_VDAF, HpkeKemId::X25519HkdfSha256, version); - let report0 = - t.produce_invalid_report_public_share_decode_failure(DapMeasurement::U64(1), version); - let report1 = - t.produce_invalid_report_input_share_decode_failure(DapMeasurement::U64(1), version); + let report0 = t.produce_invalid_report_public_share_decode_failure( + DapMeasurement::U32Vec(vec![1; 10]), + version, + ); + let report1 = t.produce_invalid_report_input_share_decode_failure( + DapMeasurement::U32Vec(vec![1; 10]), + version, + ); let agg_job_init_req = AggregationJobInitReq { agg_param: Vec::new(), @@ -633,7 +526,10 @@ mod test { fn agg_job_resp_abort_transition_out_of_order(version: DapVersion) { let t = AggregationJobTest::new(TEST_VDAF, HpkeKemId::X25519HkdfSha256, version); - let reports = t.produce_reports(vec![DapMeasurement::U64(1), DapMeasurement::U64(1)]); + let reports = t.produce_reports(vec![ + DapMeasurement::U32Vec(vec![1; 10]), + DapMeasurement::U32Vec(vec![1; 10]), + ]); let (leader_state, agg_job_init_req) = t.produce_agg_job_req(&DapAggregationParam::Empty, reports); let (_agg_span, mut agg_job_resp) = t.handle_agg_job_req(agg_job_init_req); @@ -653,7 +549,10 @@ mod test { fn agg_job_resp_abort_report_id_repeated(version: DapVersion) { let t = AggregationJobTest::new(TEST_VDAF, HpkeKemId::X25519HkdfSha256, version); - let reports = t.produce_reports(vec![DapMeasurement::U64(1), DapMeasurement::U64(1)]); + let reports = t.produce_reports(vec![ + DapMeasurement::U32Vec(vec![1; 10]), + DapMeasurement::U32Vec(vec![1; 10]), + ]); let (leader_state, agg_job_init_req) = t.produce_agg_job_req(&DapAggregationParam::Empty, reports); let (_agg_span, mut agg_job_resp) = t.handle_agg_job_req(agg_job_init_req); @@ -673,7 +572,10 @@ mod test { fn agg_job_resp_abort_unrecognized_report_id(version: DapVersion) { let mut rng = thread_rng(); let t = AggregationJobTest::new(TEST_VDAF, HpkeKemId::X25519HkdfSha256, version); - let reports = t.produce_reports(vec![DapMeasurement::U64(1), DapMeasurement::U64(1)]); + let reports = t.produce_reports(vec![ + DapMeasurement::U32Vec(vec![1; 10]), + DapMeasurement::U32Vec(vec![1; 10]), + ]); let (leader_state, agg_job_init_req) = t.produce_agg_job_req(&DapAggregationParam::Empty, reports); let (_agg_span, mut agg_job_resp) = t.handle_agg_job_req(agg_job_init_req); @@ -694,7 +596,7 @@ mod test { fn agg_job_resp_abort_invalid_transition(version: DapVersion) { let t = AggregationJobTest::new(TEST_VDAF, HpkeKemId::X25519HkdfSha256, version); - let reports = t.produce_reports(vec![DapMeasurement::U64(1)]); + let reports = t.produce_reports(vec![DapMeasurement::U32Vec(vec![1; 10])]); let (leader_state, agg_job_init_req) = t.produce_agg_job_req(&DapAggregationParam::Empty, reports); let (_helper_agg_span, mut agg_job_resp) = t.handle_agg_job_req(agg_job_init_req); @@ -711,59 +613,14 @@ mod test { test_versions! { agg_job_resp_abort_invalid_transition } - #[test] - fn finish_agg_job_vdaf_draft09() { - let version = DapVersion::Draft09; - let t = AggregationJobTest::new(TEST_VDAF_DRAFT09, HpkeKemId::X25519HkdfSha256, version); - let reports = t.produce_reports(vec![ - DapMeasurement::U64(1), - DapMeasurement::U64(1), - DapMeasurement::U64(0), - DapMeasurement::U64(0), - DapMeasurement::U64(1), - ]); - - let (leader_state, agg_job_init_req) = - t.produce_agg_job_req(&DapAggregationParam::Empty, reports); - - let (leader_agg_span, helper_agg_span) = { - let (helper_agg_span, agg_job_resp) = t.handle_agg_job_req(agg_job_init_req); - let leader_agg_span = t.consume_agg_job_resp(leader_state, agg_job_resp); - - (leader_agg_span, helper_agg_span) - }; - - assert_eq!(leader_agg_span.report_count(), 5); - let num_measurements = leader_agg_span.report_count(); - - let VdafAggregateShare::Field64Draft09(leader_agg_share) = - leader_agg_span.collapsed().data.unwrap() - else { - panic!("unexpected VdafAggregateShare variant") - }; - - let VdafAggregateShare::Field64Draft09(helper_agg_share) = - helper_agg_span.collapsed().data.unwrap() - else { - panic!("unexpected VdafAggregateShare variant") - }; - - let vdaf = Prio3Draft09::new_count(2).unwrap(); - assert_eq!( - vdaf.unshard(&(), [leader_agg_share, helper_agg_share], num_measurements,) - .unwrap(), - 3, - ); - } - fn finish_agg_job(version: DapVersion) { let t = AggregationJobTest::new(TEST_VDAF, HpkeKemId::X25519HkdfSha256, version); let reports = t.produce_reports(vec![ - DapMeasurement::U64(1), - DapMeasurement::U64(1), - DapMeasurement::U64(0), - DapMeasurement::U64(0), - DapMeasurement::U64(1), + DapMeasurement::U32Vec(vec![1; 10]), + DapMeasurement::U32Vec(vec![1; 10]), + DapMeasurement::U32Vec(vec![0; 10]), + DapMeasurement::U32Vec(vec![0; 10]), + DapMeasurement::U32Vec(vec![1; 10]), ]); let (leader_state, agg_job_init_req) = @@ -779,36 +636,40 @@ mod test { assert_eq!(leader_agg_span.report_count(), 5); let num_measurements = leader_agg_span.report_count(); - let VdafAggregateShare::Field64(leader_agg_share) = + let VdafAggregateShare::Field32(leader_agg_share) = leader_agg_span.collapsed().data.unwrap() else { panic!("unexpected VdafAggregateShare variant") }; - let VdafAggregateShare::Field64(helper_agg_share) = + let VdafAggregateShare::Field32(helper_agg_share) = helper_agg_span.collapsed().data.unwrap() else { panic!("unexpected VdafAggregateShare variant") }; - let vdaf = Prio3::new_count(2).unwrap(); + let vdaf = Prio2::new(10).unwrap(); assert_eq!( vdaf.unshard(&(), [leader_agg_share, helper_agg_share], num_measurements,) .unwrap(), - 3, + vec![3; 10], ); } test_versions! { finish_agg_job } - #[tokio::test] - async fn agg_job_init_req_skip_vdaf_prep_error_draft09() { - let t = - AggregationJobTest::new(TEST_VDAF, HpkeKemId::X25519HkdfSha256, DapVersion::Draft09); - let mut reports = t.produce_reports(vec![DapMeasurement::U64(1), DapMeasurement::U64(1)]); + fn agg_job_init_req_skip_vdaf_prep_error(version: DapVersion) { + let t = AggregationJobTest::new(TEST_VDAF, HpkeKemId::X25519HkdfSha256, version); + let mut reports = t.produce_reports(vec![ + DapMeasurement::U32Vec(vec![1; 10]), + DapMeasurement::U32Vec(vec![1; 10]), + ]); reports.insert( 1, - t.produce_invalid_report_vdaf_prep_failure(DapMeasurement::U64(1), DapVersion::Draft09), + t.produce_invalid_report_vdaf_prep_failure( + DapMeasurement::U32Vec(vec![1; 10]), + version, + ), ); let (leader_state, agg_job_init_req) = @@ -830,6 +691,8 @@ mod test { assert_eq!(leader_agg_span.report_count(), 2); } + test_versions! { agg_job_init_req_skip_vdaf_prep_error } + fn encrypted_agg_share(version: DapVersion) { let t = AggregationJobTest::new(TEST_VDAF, HpkeKemId::X25519HkdfSha256, version); let leader_agg_share = DapAggregateShare { @@ -837,22 +700,18 @@ mod test { min_time: 1_637_359_200, max_time: 1_637_359_200, checksum: [0; 32], - data: Some(VdafAggregateShare::Field64Draft09( - AggregateShareDraft09::from(OutputShareDraft09::from(vec![Field64Draft09::from( - 23, - )])), - )), + data: Some(VdafAggregateShare::Field32(AggregateShare::from( + OutputShare::from(vec![FieldPrio2::from(23); 10]), + ))), }; let helper_agg_share = DapAggregateShare { report_count: 50, min_time: 1_637_359_200, max_time: 1_637_359_200, checksum: [0; 32], - data: Some(VdafAggregateShare::Field64Draft09( - AggregateShareDraft09::from(OutputShareDraft09::from(vec![Field64Draft09::from( - 9, - )])), - )), + data: Some(VdafAggregateShare::Field32(AggregateShare::from( + OutputShare::from(vec![FieldPrio2::from(9); 10]), + ))), }; let batch_selector = BatchSelector::TimeInterval { @@ -878,7 +737,7 @@ mod test { vec![leader_encrypted_agg_share, helper_encrypted_agg_share], ); - assert_eq!(agg_res, DapAggregateResult::U64(32)); + assert_eq!(agg_res, DapAggregateResult::U32Vec(vec![32; 10])); } test_versions! { encrypted_agg_share } @@ -892,7 +751,7 @@ mod test { &t.client_hpke_config_list, t.now, &t.task_id, - DapMeasurement::U64(1), + DapMeasurement::U32Vec(vec![1; 10]), vec![Extension::NotImplemented { typ: 0xffff, payload: b"some extension data".to_vec(), @@ -934,7 +793,7 @@ mod test { &t.client_hpke_config_list, t.now, &t.task_id, - DapMeasurement::U64(1), + DapMeasurement::U32Vec(vec![1; 10]), vec![ Extension::NotImplemented { typ: 23, @@ -982,7 +841,12 @@ mod test { let (invalid_public_share, mut invalid_input_shares) = self .task_config .vdaf - .produce_input_shares(measurement, &report_id.0, &self.task_id) + .produce_input_shares( + measurement, + &report_id.0, + &self.task_id, + self.task_config.version, + ) .unwrap(); invalid_input_shares[1][0] ^= 1; // The first bit is incorrect! VdafConfig::produce_report_with_extensions_for_shares( @@ -1008,7 +872,12 @@ mod test { let (mut invalid_public_share, invalid_input_shares) = self .task_config .vdaf - .produce_input_shares(measurement, &report_id.0, &self.task_id) + .produce_input_shares( + measurement, + &report_id.0, + &self.task_id, + self.task_config.version, + ) .unwrap(); invalid_public_share.push(1); // Add spurious byte at the end VdafConfig::produce_report_with_extensions_for_shares( @@ -1034,7 +903,12 @@ mod test { let (invalid_public_share, mut invalid_input_shares) = self .task_config .vdaf - .produce_input_shares(measurement, &report_id.0, &self.task_id) + .produce_input_shares( + measurement, + &report_id.0, + &self.task_id, + self.task_config.version, + ) .unwrap(); invalid_input_shares[0].push(1); // Add a spurious byte to the Leader's share invalid_input_shares[1].push(1); // Add a spurious byte to the Helper's share diff --git a/crates/daphne/src/protocol/report_init.rs b/crates/daphne/src/protocol/report_init.rs index 2749840be..0d1b88001 100644 --- a/crates/daphne/src/protocol/report_init.rs +++ b/crates/daphne/src/protocol/report_init.rs @@ -10,10 +10,7 @@ use crate::{ self, Extension, PlaintextInputShare, ReportError, ReportMetadata, ReportShare, TaskId, }, protocol::{decode_ping_pong_framed, no_duplicates, PingPongMessageType}, - vdaf::{ - prio2::prio2_prep_init, prio3::prio3_prep_init, prio3_draft09::prio3_draft09_prep_init, - VdafConfig, VdafPrepShare, VdafPrepState, - }, + vdaf::{prio2::prio2_prep_init, VdafConfig, VdafPrepShare, VdafPrepState}, DapAggregationParam, DapError, DapTaskConfig, }; use prio::codec::{CodecError, ParameterizedDecode as _}; @@ -197,16 +194,8 @@ impl<P> InitializedReport<P> { DapAggregatorRole::Helper => 1, }; let res = match &task_config.vdaf { - VdafConfig::Prio3Draft09(ref prio3_config) => prio3_draft09_prep_init( - prio3_config, - &task_config.vdaf_verify_key, - agg_id, - &report_share.report_metadata.id.0, - &report_share.public_share, - &input_share, - ), - VdafConfig::Prio3(ref prio3_config) => prio3_prep_init( - prio3_config, + VdafConfig::Prio3(prio3_config) => prio3_config.prep_init( + task_config.version, &task_config.vdaf_verify_key, *task_id, agg_id, diff --git a/crates/daphne/src/roles/helper.rs b/crates/daphne/src/roles/helper.rs index 39f9998f1..2d4e3b13a 100644 --- a/crates/daphne/src/roles/helper.rs +++ b/crates/daphne/src/roles/helper.rs @@ -232,6 +232,7 @@ async fn finish_agg_job_and_aggregate( &report_status, part_batch_sel, initialized_reports, + task_config.version, )?; let put_shares_result = helper diff --git a/crates/daphne/src/roles/leader/mod.rs b/crates/daphne/src/roles/leader/mod.rs index 503461805..a1fb09378 100644 --- a/crates/daphne/src/roles/leader/mod.rs +++ b/crates/daphne/src/roles/leader/mod.rs @@ -341,8 +341,13 @@ async fn run_agg_job<A: DapLeader>( .map_err(|e| DapAbort::from_codec_error(e, *task_id))?; // Handle AggregationJobResp. - let agg_span = - task_config.consume_agg_job_resp(task_id, agg_job_state, agg_job_resp, metrics)?; + let agg_span = task_config.consume_agg_job_resp( + task_id, + agg_job_state, + agg_job_resp, + metrics, + task_config.version, + )?; let out_shares_count = agg_span.report_count() as u64; if out_shares_count == 0 { diff --git a/crates/daphne/src/roles/mod.rs b/crates/daphne/src/roles/mod.rs index 1b038be73..0754a9a61 100644 --- a/crates/daphne/src/roles/mod.rs +++ b/crates/daphne/src/roles/mod.rs @@ -198,7 +198,9 @@ mod test { }; // Task Parameters that the Leader and Helper must agree on. - let vdaf_config = VdafConfig::Prio3(Prio3Config::Count); + // + // We need to use a VDAF that is compatible with all versions of DAP. + let vdaf_config = VdafConfig::Prio2 { dimension: 10 }; let leader_url = Url::parse("https://leader.com/v02/").unwrap(); let helper_url = Url::parse("http://helper.org:8788/v02/").unwrap(); let collector_hpke_receiver_config = @@ -483,12 +485,12 @@ mod test { } pub async fn gen_test_report(&self, task_id: &TaskId) -> Report { - // Construct report. We expect the VDAF to be Prio3Count so that we know what type of - // measurement to generate. However, we could extend the code to support more VDAFs. + // Construct report. We expect the VDAF to be Prio2 because it's supported in all + // versions of DAP. In the future we might want to test multiple different VDAFs. let task_config = self.leader.unchecked_get_task_config(task_id).await; - assert_matches!(task_config.vdaf, VdafConfig::Prio3(Prio3Config::Count)); + assert_matches!(task_config.vdaf, VdafConfig::Prio2 { dimension: 10 }); - self.gen_test_report_for_measurement(task_id, DapMeasurement::U64(1)) + self.gen_test_report_for_measurement(task_id, DapMeasurement::U32Vec(vec![1; 10])) .await } @@ -1524,28 +1526,28 @@ mod test { async_test_versions! { e2e_taskprov_prio2 } - async fn e2e_taskprov_prio3_draft09_sum_vec_field64_multiproof_hmac_sha256_aes128( - version: DapVersion, - ) { + #[tokio::test] + async fn e2e_taskprov_prio3_draft09_sum_vec_field64_multiproof_hmac_sha256_aes128_draft09() { e2e_taskprov( - version, - VdafConfig::Prio3Draft09(Prio3Config::SumVecField64MultiproofHmacSha256Aes128 { - bits: 1, - length: 10, - chunk_length: 2, - num_proofs: 3, - }), + DapVersion::Draft09, + VdafConfig::Prio3( + Prio3Config::Draft09SumVecField64MultiproofHmacSha256Aes128 { + bits: 1, + length: 10, + chunk_length: 2, + num_proofs: 3, + }, + ), DapMeasurement::U64Vec(vec![1; 10]), ) .await; } - async_test_versions! { e2e_taskprov_prio3_draft09_sum_vec_field64_multiproof_hmac_sha256_aes128 } - - async fn e2e_taskprov_pine32_hmac_sha256_aes128(version: DapVersion) { + #[tokio::test] + async fn e2e_taskprov_pine32_hmac_sha256_aes128_draft09() { use crate::{pine::PineParam, vdaf::pine::PineConfig}; e2e_taskprov( - version, + DapVersion::Draft09, VdafConfig::Pine(PineConfig::Field32HmacSha256Aes128 { param: PineParam { norm_bound: 16, @@ -1564,12 +1566,11 @@ mod test { .await; } - async_test_versions! { e2e_taskprov_pine32_hmac_sha256_aes128 } - - async fn e2e_taskprov_pine64_hmac_sha256_aes128(version: DapVersion) { + #[tokio::test] + async fn e2e_taskprov_pine64_hmac_sha256_aes128_draft09() { use crate::{pine::PineParam, vdaf::pine::PineConfig}; e2e_taskprov( - version, + DapVersion::Draft09, VdafConfig::Pine(PineConfig::Field64HmacSha256Aes128 { param: PineParam { norm_bound: 16, @@ -1588,8 +1589,6 @@ mod test { .await; } - async_test_versions! { e2e_taskprov_pine64_hmac_sha256_aes128 } - // Test multiple tasks in flight at once. async fn multi_task(version: DapVersion) { let t = Test::new(version); diff --git a/crates/daphne/src/taskprov.rs b/crates/daphne/src/taskprov.rs index 666dab673..ab96273c5 100644 --- a/crates/daphne/src/taskprov.rs +++ b/crates/daphne/src/taskprov.rs @@ -181,7 +181,7 @@ impl VdafConfig { })?, }), ( - DapVersion::Draft09 | DapVersion::Latest, + DapVersion::Draft09, VdafTypeVar::Prio3SumVecField64MultiproofHmacSha256Aes128 { bits, length, @@ -197,8 +197,8 @@ impl VdafConfig { task_id: *task_id, }); } - Ok(VdafConfig::Prio3Draft09( - Prio3Config::SumVecField64MultiproofHmacSha256Aes128 { + Ok(VdafConfig::Prio3( + Prio3Config::Draft09SumVecField64MultiproofHmacSha256Aes128 { bits: bits.into(), length: length.try_into().map_err(|_| DapAbort::InvalidTask { detail: "length is larger than the system's word size".to_string(), @@ -215,7 +215,7 @@ impl VdafConfig { }, )) } - (_, VdafTypeVar::Pine32HmacSha256Aes128 { param }) => { + (DapVersion::Draft09, VdafTypeVar::Pine32HmacSha256Aes128 { param }) => { if let Err(e) = pine32_hmac_sha256_aes128(¶m) { Err(DapAbort::InvalidTask { detail: format!("invalid parameters for Pine32: {e}"), @@ -232,7 +232,7 @@ impl VdafConfig { })) } } - (_, VdafTypeVar::Pine64HmacSha256Aes128 { param }) => { + (DapVersion::Draft09, VdafTypeVar::Pine64HmacSha256Aes128 { param }) => { if let Err(e) = pine64_hmac_sha256_aes128(¶m) { Err(DapAbort::InvalidTask { detail: format!("invalid parameters for Pine64: {e}"), @@ -253,6 +253,10 @@ impl VdafConfig { detail: format!("unimplemented VDAF type ({typ})"), task_id: *task_id, }), + (_, _) => Err(DapAbort::InvalidTask { + detail: format!("VDAF not supported in {version}"), + task_id: *task_id, + }), } } } @@ -380,7 +384,7 @@ impl TryFrom<&VdafConfig> for messages::taskprov::VdafTypeVar { fatal_error!(err = "{vdaf_config}: dimension is too large for taskprov") })?, }), - VdafConfig::Prio3Draft09(Prio3Config::SumVecField64MultiproofHmacSha256Aes128 { + VdafConfig::Prio3(Prio3Config::Draft09SumVecField64MultiproofHmacSha256Aes128 { bits, length, chunk_length, @@ -398,9 +402,6 @@ impl TryFrom<&VdafConfig> for messages::taskprov::VdafTypeVar { })?, num_proofs: *num_proofs, }), - VdafConfig::Prio3Draft09(..) => Err(fatal_error!( - err = format!("{vdaf_config} is not currently supported for taskprov") - )), VdafConfig::Prio3(..) => Err(fatal_error!( err = format!("{vdaf_config} is not currently supported for taskprov") )), diff --git a/crates/daphne/src/testing/mod.rs b/crates/daphne/src/testing/mod.rs index 16e49e337..5b32c3044 100644 --- a/crates/daphne/src/testing/mod.rs +++ b/crates/daphne/src/testing/mod.rs @@ -223,6 +223,7 @@ impl AggregationJobTest { self.replay_protection, ) .unwrap(), + self.task_config.version, ) .unwrap() } @@ -241,6 +242,7 @@ impl AggregationJobTest { leader_state, agg_job_resp, &self.leader_metrics, + self.task_config.version, ) .unwrap() } @@ -253,7 +255,13 @@ impl AggregationJobTest { ) -> DapError { let metrics = &self.leader_metrics; self.task_config - .consume_agg_job_resp(&self.task_id, leader_state, agg_job_resp, metrics) + .consume_agg_job_resp( + &self.task_id, + leader_state, + agg_job_resp, + metrics, + self.task_config.version, + ) .expect_err("consume_agg_job_resp() succeeded; expected failure") } @@ -1113,9 +1121,10 @@ impl VdafConfig { pub fn gen_measurement(&self) -> Result<DapMeasurement, DapError> { match self { Self::Prio2 { dimension } => Ok(DapMeasurement::U32Vec(vec![1; *dimension])), - Self::Prio3Draft09( - crate::vdaf::Prio3Config::SumVecField64MultiproofHmacSha256Aes128 { - length, .. + Self::Prio3( + crate::vdaf::Prio3Config::Draft09SumVecField64MultiproofHmacSha256Aes128 { + length, + .. }, ) => Ok(DapMeasurement::U64Vec(vec![0; *length])), _ => Err(fatal_error!( diff --git a/crates/daphne/src/vdaf/mod.rs b/crates/daphne/src/vdaf/mod.rs index d3eabf1e5..9c95f53d7 100644 --- a/crates/daphne/src/vdaf/mod.rs +++ b/crates/daphne/src/vdaf/mod.rs @@ -74,7 +74,6 @@ pub(crate) enum VdafError { #[serde(rename_all = "snake_case")] #[cfg_attr(any(test, feature = "test-utils"), derive(deepsize::DeepSizeOf))] pub enum VdafConfig { - Prio3Draft09(Prio3Config), Prio3(Prio3Config), Prio2 { dimension: usize, @@ -112,7 +111,6 @@ pub(crate) fn from_codec_error(c: CodecErrorDraft09) -> CodecError { impl std::fmt::Display for VdafConfig { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - VdafConfig::Prio3Draft09(prio3_config) => write!(f, "Prio3Draft09({prio3_config})"), VdafConfig::Prio3(prio3_config) => write!(f, "Prio3({prio3_config})"), VdafConfig::Prio2 { dimension } => write!(f, "Prio2({dimension})"), #[cfg(feature = "experimental")] @@ -135,8 +133,8 @@ pub enum Prio3Config { Count, /// The sum of 64-bit, unsigned integers. Each measurement is an integer in range `[0, - /// 2^bits)`. - Sum { bits: usize }, + /// max_measurement]`. + Sum { max_measurement: u64 }, /// A histogram for estimating the distribution of 64-bit, unsigned integers where each /// measurement is a bucket index in range `[0, len)`. @@ -151,8 +149,11 @@ pub enum Prio3Config { }, /// A variant of `SumVec` that uses a smaller field (`Field64`), multiple proofs, and a custom - /// XOF (`XofHmacSha256Aes128`). - SumVecField64MultiproofHmacSha256Aes128 { + /// XOF (`XofHmacSha256Aes128`). This VDAF is only supported in DAP-09. + // + // Ensure the serialization of this type is backwards compatible. + #[serde(rename = "sum_vec_field64_multiproof_hmac_sha256_aes128")] + Draft09SumVecField64MultiproofHmacSha256Aes128 { bits: usize, length: usize, chunk_length: usize, @@ -168,18 +169,18 @@ impl std::fmt::Display for Prio3Config { length, chunk_length, } => write!(f, "Histogram({length},{chunk_length})"), - Prio3Config::Sum { bits } => write!(f, "Sum({bits})"), + Prio3Config::Sum { max_measurement } => write!(f, "Sum({max_measurement})"), Prio3Config::SumVec { bits, length, chunk_length, } => write!(f, "SumVec({bits},{length},{chunk_length})"), - Prio3Config::SumVecField64MultiproofHmacSha256Aes128 { + Prio3Config::Draft09SumVecField64MultiproofHmacSha256Aes128 { bits, length, chunk_length, num_proofs, - } => write!(f, "SumVecField64MultiproofHmacSha256Aes128({bits},{length},{chunk_length},{num_proofs})"), + } => write!(f, "Draft09SumVecField64MultiproofHmacSha256Aes128({bits},{length},{chunk_length},{num_proofs})"), } } } @@ -231,9 +232,7 @@ impl AsMut<[u8]> for VdafVerifyKey { #[cfg_attr(any(test, feature = "test-utils"), derive(Debug, Eq, PartialEq))] pub enum VdafPrepState { Prio2(Prio2PrepareState), - Prio3Draft09Field64(Prio3Draft09PrepareState<Field64Draft09, 16>), Prio3Draft09Field64HmacSha256Aes128(Prio3Draft09PrepareState<Field64Draft09, 32>), - Prio3Draft09Field128(Prio3Draft09PrepareState<Field128Draft09, 16>), Prio3Field64(Prio3PrepareState<Field64, 32>), Prio3Field64HmacSha256Aes128(Prio3PrepareState<Field64, 32>), Prio3Field128(Prio3PrepareState<Field128, 32>), @@ -254,9 +253,7 @@ impl deepsize::DeepSizeOf for VdafPrepState { // This happens to be correct for helpers but not for leaders match self { Self::Prio2(_) - | Self::Prio3Draft09Field64(_) | Self::Prio3Draft09Field64HmacSha256Aes128(_) - | Self::Prio3Draft09Field128(_) | Self::Prio3Field64(_) | Self::Prio3Field64HmacSha256Aes128(_) | Self::Prio3Field128(_) @@ -273,10 +270,7 @@ impl deepsize::DeepSizeOf for VdafPrepState { #[cfg_attr(any(test, feature = "test-utils"), derive(Debug))] pub enum VdafPrepShare { Prio2(Prio2PrepareShare), - Prio3Draft09Field64(Prio3Draft09PrepareShare<Field64Draft09, 16>), Prio3Draft09Field64HmacSha256Aes128(Prio3Draft09PrepareShare<Field64Draft09, 32>), - Prio3Draft09Field128(Prio3Draft09PrepareShare<Field128Draft09, 16>), - Prio3Field64(Prio3PrepareShare<Field64, 32>), Prio3Field64HmacSha256Aes128(Prio3PrepareShare<Field64, 32>), Prio3Field128(Prio3PrepareShare<Field128, 32>), @@ -297,9 +291,7 @@ impl deepsize::DeepSizeOf for VdafPrepShare { // share. The length of the verifier share depends on the Prio3 type, which we don't // know at this point. Likewise, whether the XOF seed is present depends on the Prio3 // type. - Self::Prio3Draft09Field64(..) - | Self::Prio3Draft09Field64HmacSha256Aes128(..) - | Self::Prio3Draft09Field128(..) + Self::Prio3Draft09Field64HmacSha256Aes128(..) | Self::Prio3Field64(..) | Self::Prio3Field64HmacSha256Aes128(..) | Self::Prio3Field128(..) @@ -314,11 +306,9 @@ impl deepsize::DeepSizeOf for VdafPrepShare { impl Encode for VdafPrepShare { fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> { match self { - Self::Prio3Draft09Field64(share) => share.encode(bytes).map_err(from_codec_error), Self::Prio3Draft09Field64HmacSha256Aes128(share) => { share.encode(bytes).map_err(from_codec_error) } - Self::Prio3Draft09Field128(share) => share.encode(bytes).map_err(from_codec_error), Self::Prio3Field64(share) => share.encode(bytes), Self::Prio3Field64HmacSha256Aes128(share) => share.encode(bytes), Self::Prio3Field128(share) => share.encode(bytes), @@ -337,20 +327,12 @@ impl ParameterizedDecode<VdafPrepState> for VdafPrepShare { bytes: &mut std::io::Cursor<&[u8]>, ) -> Result<Self, CodecError> { match state { - VdafPrepState::Prio3Draft09Field64(state) => Ok(VdafPrepShare::Prio3Draft09Field64( - Prio3Draft09PrepareShare::decode_with_param(state, bytes) - .map_err(from_codec_error)?, - )), VdafPrepState::Prio3Draft09Field64HmacSha256Aes128(state) => { Ok(VdafPrepShare::Prio3Draft09Field64HmacSha256Aes128( Prio3Draft09PrepareShare::decode_with_param(state, bytes) .map_err(from_codec_error)?, )) } - VdafPrepState::Prio3Draft09Field128(state) => Ok(VdafPrepShare::Prio3Draft09Field128( - Prio3Draft09PrepareShare::decode_with_param(state, bytes) - .map_err(from_codec_error)?, - )), VdafPrepState::Prio3Field64(state) => Ok(VdafPrepShare::Prio3Field64( Prio3PrepareShare::decode_with_param(state, bytes)?, )), @@ -432,10 +414,7 @@ impl Encode for VdafAggregateShare { impl VdafConfig { pub(crate) fn uninitialized_verify_key(&self) -> VdafVerifyKey { match self { - Self::Prio3Draft09(Prio3Config::SumVecField64MultiproofHmacSha256Aes128 { .. }) - | Self::Prio2 { .. } - | Self::Prio3(..) => VdafVerifyKey::L32([0; 32]), - Self::Prio3Draft09(..) => VdafVerifyKey::L16([0; 16]), + Self::Prio2 { .. } | Self::Prio3(..) => VdafVerifyKey::L32([0; 32]), #[cfg(feature = "experimental")] Self::Mastic { .. } => VdafVerifyKey::L16([0; 16]), Self::Pine(..) => VdafVerifyKey::L32([0; 32]), @@ -445,18 +424,11 @@ impl VdafConfig { /// Parse a verification key from raw bytes. pub fn get_decoded_verify_key(&self, bytes: &[u8]) -> Result<VdafVerifyKey, CodecError> { match self { - Self::Prio3Draft09(Prio3Config::SumVecField64MultiproofHmacSha256Aes128 { .. }) - | Self::Prio2 { .. } - | Self::Prio3(..) => Ok(VdafVerifyKey::L32( + Self::Prio2 { .. } | Self::Prio3(..) => Ok(VdafVerifyKey::L32( <[u8; 32]>::try_from(bytes) .map_err(|e| CodecErrorDraft09::Other(Box::new(e))) .map_err(from_codec_error)?, )), - Self::Prio3Draft09(..) => Ok(VdafVerifyKey::L16( - <[u8; 16]>::try_from(bytes) - .map_err(|e| CodecErrorDraft09::Other(Box::new(e))) - .map_err(from_codec_error)?, - )), #[cfg(feature = "experimental")] Self::Mastic { .. } => Ok(VdafVerifyKey::L16( <[u8; 16]>::try_from(bytes).map_err(|e| CodecError::Other(Box::new(e)))?, @@ -481,7 +453,7 @@ impl VdafConfig { /// executed. pub fn is_valid_agg_param(&self, agg_param: &[u8]) -> bool { match self { - Self::Prio3Draft09(..) | Self::Prio3(..) | Self::Prio2 { .. } => agg_param.is_empty(), + Self::Prio3(..) | Self::Prio2 { .. } => agg_param.is_empty(), #[cfg(feature = "experimental")] Self::Mastic { .. } => true, Self::Pine(..) => agg_param.is_empty(), @@ -612,7 +584,7 @@ where Ok(vdaf.unshard(&(), agg_shares_vec, num_measurements)?) } -fn shard_then_encode_draft09<V: VdafDraft09 + ClientDraft09<16>>( +pub(crate) fn shard_then_encode_draft09<V: VdafDraft09 + ClientDraft09<16>>( vdaf: &V, measurement: &V::Measurement, nonce: &[u8; 16], diff --git a/crates/daphne/src/vdaf/prio3.rs b/crates/daphne/src/vdaf/prio3.rs index 949ca61e3..35307d3a8 100644 --- a/crates/daphne/src/vdaf/prio3.rs +++ b/crates/daphne/src/vdaf/prio3.rs @@ -6,8 +6,8 @@ use crate::{ fatal_error, messages::TaskId, - vdaf::{VdafError, VdafVerifyKey}, - DapAggregateResult, DapMeasurement, Prio3Config, VdafAggregateShare, VdafPrepShare, + vdaf::{prio3_draft09, VdafError, VdafVerifyKey}, + DapAggregateResult, DapMeasurement, DapVersion, Prio3Config, VdafAggregateShare, VdafPrepShare, VdafPrepState, }; @@ -25,385 +25,542 @@ use prio::{ const CTX_STRING_PREFIX: &[u8] = b"dap-13"; -/// Split the given measurement into a sequence of encoded input shares. -pub(crate) fn prio3_shard( - config: &Prio3Config, - measurement: DapMeasurement, - nonce: &[u8; 16], - task_id: TaskId, -) -> Result<(Vec<u8>, [Vec<u8>; 2]), VdafError> { - match (config, measurement) { - (Prio3Config::Count, DapMeasurement::U64(measurement)) if measurement < 2 => { - let vdaf = Prio3::new_count(2).map_err(|e| { - VdafError::Dap( - fatal_error!(err = ?e, "failed to create prio3 count from num_aggregators(2)"), - ) - })?; - shard_then_encode(&vdaf, task_id, &(measurement != 0), nonce) - } - ( - Prio3Config::Histogram { - length, - chunk_length, - }, - DapMeasurement::U64(measurement), - ) => { - let vdaf = Prio3::new_histogram(2, *length, *chunk_length) - .map_err(|e| VdafError::Dap(fatal_error!(err = ?e, "failed to create prio3 Histogram from num_aggregators(2), length({length}), chunk_length({chunk_length})")))?; - let m: usize = measurement.try_into().unwrap(); - shard_then_encode(&vdaf, task_id, &m, nonce) - } - (Prio3Config::Sum { .. }, DapMeasurement::U64(_)) => { - Err(VdafError::Dap(fatal_error!(err = "Sum unimplemented"))) - } - ( - Prio3Config::SumVec { - bits, - length, - chunk_length, - }, - DapMeasurement::U128Vec(measurement), - ) => { - let vdaf = Prio3::new_sum_vec(2, *bits, *length, *chunk_length) - .map_err(|e| VdafError::Dap(fatal_error!(err = ?e, "failed to create prio3 sum vec from num_aggregators(2), bits({bits}), length({length}), chunk_length({chunk_length})")))?; - shard_then_encode(&vdaf, task_id, &measurement, nonce) +impl Prio3Config { + pub(crate) fn shard( + &self, + version: DapVersion, + measurement: DapMeasurement, + nonce: &[u8; 16], + task_id: TaskId, + ) -> Result<(Vec<u8>, [Vec<u8>; 2]), VdafError> { + match (version, self, measurement) { + (DapVersion::Latest, Prio3Config::Count, DapMeasurement::U64(measurement)) + if measurement < 2 => + { + let vdaf = Prio3::new_count(2).map_err(|e| { + VdafError::Dap(fatal_error!(err = ?e, "initializing {self:?} failed")) + })?; + shard_then_encode(&vdaf, task_id, &(measurement != 0), nonce) + } + ( + DapVersion::Latest, + Prio3Config::Sum { max_measurement }, + DapMeasurement::U64(measurement), + ) => { + let vdaf = Prio3::new_sum(2, *max_measurement).map_err(|e| { + VdafError::Dap(fatal_error!(err = ?e, "initializing {self:?} failed")) + })?; + shard_then_encode(&vdaf, task_id, &measurement, nonce) + } + ( + DapVersion::Latest, + Prio3Config::Histogram { + length, + chunk_length, + }, + DapMeasurement::U64(measurement), + ) => { + let vdaf = Prio3::new_histogram(2, *length, *chunk_length).map_err(|e| { + VdafError::Dap(fatal_error!(err = ?e, "initializing {self:?} failed")) + })?; + let m: usize = measurement.try_into().unwrap(); + shard_then_encode(&vdaf, task_id, &m, nonce) + } + ( + DapVersion::Latest, + Prio3Config::SumVec { + bits, + length, + chunk_length, + }, + DapMeasurement::U128Vec(measurement), + ) => { + let vdaf = Prio3::new_sum_vec(2, *bits, *length, *chunk_length).map_err(|e| { + VdafError::Dap(fatal_error!(err = ?e, "initializing {self:?} failed")) + })?; + shard_then_encode(&vdaf, task_id, &measurement, nonce) + } + ( + DapVersion::Draft09, + Prio3Config::Draft09SumVecField64MultiproofHmacSha256Aes128 { + bits, + length, + chunk_length, + num_proofs, + }, + DapMeasurement::U64Vec(measurement), + ) => { + let vdaf = prio3_draft09::new_prio3_sum_vec_field64_multiproof_hmac_sha256_aes128( + *bits, + *length, + *chunk_length, + *num_proofs, + )?; + super::shard_then_encode_draft09(&vdaf, &measurement, nonce) + } + _ => Err(VdafError::Dap(fatal_error!( + err = + format!("unexpected measurement or {self:?} is not supported in DAP {version}") + ))), } - ( - Prio3Config::SumVecField64MultiproofHmacSha256Aes128 { .. }, - DapMeasurement::U64Vec(_), - ) => Err(VdafError::Dap(fatal_error!( - err = format!( - "prio3_shard: SumVecField64MultiproofHmacSha256Aes128 is not defined for VDAF-13" - ) - ))), - _ => Err(VdafError::Dap(fatal_error!( - err = format!("prio3_shard: unexpected VDAF config {config:?}") - ))), } -} -/// Consume an input share and return the corresponding VDAF step and message. -pub(crate) fn prio3_prep_init( - config: &Prio3Config, - verify_key: &VdafVerifyKey, - task_id: TaskId, - agg_id: usize, - nonce: &[u8; 16], - public_share_data: &[u8], - input_share_data: &[u8], -) -> Result<(VdafPrepState, VdafPrepShare), VdafError> { - return match (&config, verify_key) { - (Prio3Config::Count, VdafVerifyKey::L32(verify_key)) => { - let vdaf = Prio3::new_count(2).map_err(|e| { - VdafError::Dap( - fatal_error!(err = ?e, "failed to create prio3 from num_aggregators(2)"), - ) - })?; - let (state, share) = prep_init( - vdaf, - task_id, + #[allow(clippy::too_many_arguments)] + pub(crate) fn prep_init( + &self, + version: DapVersion, + verify_key: &VdafVerifyKey, + task_id: TaskId, + agg_id: usize, + nonce: &[u8; 16], + public_share_data: &[u8], + input_share_data: &[u8], + ) -> Result<(VdafPrepState, VdafPrepShare), VdafError> { + return match (version, self, verify_key) { + (DapVersion::Latest, Prio3Config::Count, VdafVerifyKey::L32(verify_key)) => { + let vdaf = Prio3::new_count(2).map_err(|e| { + VdafError::Dap(fatal_error!(err = ?e, "initializing {self:?} failed")) + })?; + let (state, share) = prep_init( + vdaf, + task_id, + verify_key, + agg_id, + nonce, + public_share_data, + input_share_data, + )?; + Ok(( + VdafPrepState::Prio3Field64(state), + VdafPrepShare::Prio3Field64(share), + )) + } + ( + DapVersion::Latest, + Prio3Config::Sum { max_measurement }, + VdafVerifyKey::L32(verify_key), + ) => { + let vdaf = Prio3::new_sum(2, *max_measurement).map_err(|e| { + VdafError::Dap(fatal_error!(err = ?e, "initializing {self:?} failed")) + })?; + let (state, share) = prep_init( + vdaf, + task_id, + verify_key, + agg_id, + nonce, + public_share_data, + input_share_data, + )?; + Ok(( + VdafPrepState::Prio3Field64(state), + VdafPrepShare::Prio3Field64(share), + )) + } + ( + DapVersion::Latest, + Prio3Config::Histogram { + length, + chunk_length, + }, + VdafVerifyKey::L32(verify_key), + ) => { + let vdaf = Prio3::new_histogram(2, *length, *chunk_length).map_err(|e| { + VdafError::Dap(fatal_error!(err = ?e, "initializing {self:?} failed")) + })?; + let (state, share) = prep_init( + vdaf, + task_id, + verify_key, + agg_id, + nonce, + public_share_data, + input_share_data, + )?; + Ok(( + VdafPrepState::Prio3Field128(state), + VdafPrepShare::Prio3Field128(share), + )) + } + ( + DapVersion::Latest, + Prio3Config::SumVec { + bits, + length, + chunk_length, + }, + VdafVerifyKey::L32(verify_key), + ) => { + let vdaf = Prio3::new_sum_vec(2, *bits, *length, *chunk_length).map_err(|e| { + VdafError::Dap(fatal_error!(err = ?e, "initializing {self:?} failed")) + })?; + let (state, share) = prep_init( + vdaf, + task_id, + verify_key, + agg_id, + nonce, + public_share_data, + input_share_data, + )?; + Ok(( + VdafPrepState::Prio3Field128(state), + VdafPrepShare::Prio3Field128(share), + )) + } + ( + DapVersion::Draft09, + Prio3Config::Draft09SumVecField64MultiproofHmacSha256Aes128 { + bits, + length, + chunk_length, + num_proofs, + }, + VdafVerifyKey::L32(verify_key), + ) => { + let vdaf = prio3_draft09::new_prio3_sum_vec_field64_multiproof_hmac_sha256_aes128( + *bits, + *length, + *chunk_length, + *num_proofs, + )?; + let (state, share) = prio3_draft09::prep_init_draft09( + vdaf, + verify_key, + agg_id, + nonce, + public_share_data, + input_share_data, + )?; + Ok(( + VdafPrepState::Prio3Draft09Field64HmacSha256Aes128(state), + VdafPrepShare::Prio3Draft09Field64HmacSha256Aes128(share), + )) + } + _ => Err(VdafError::Dap(fatal_error!( + err = + format!("unexpected verify key or {self:?} is not supported in DAP {version}") + ))), + }; + + type Prio3Prepared<T, const SEED_SIZE: usize> = ( + Prio3PrepareState<<T as Type>::Field, SEED_SIZE>, + Prio3PrepareShare<<T as Type>::Field, SEED_SIZE>, + ); + + fn prep_init<T, P, const SEED_SIZE: usize>( + vdaf: Prio3<T, P, SEED_SIZE>, + task_id: TaskId, + verify_key: &[u8; SEED_SIZE], + agg_id: usize, + nonce: &[u8; 16], + public_share_data: &[u8], + input_share_data: &[u8], + ) -> Result<Prio3Prepared<T, SEED_SIZE>, VdafError> + where + T: Type, + P: Xof<SEED_SIZE>, + { + // Parse the public share. + let public_share = Prio3PublicShare::get_decoded_with_param(&vdaf, public_share_data)?; + + // Parse the input share. + let input_share = + Prio3InputShare::get_decoded_with_param(&(&vdaf, agg_id), input_share_data)?; + + let mut ctx = [0; CTX_STRING_PREFIX.len() + 32]; + ctx[..CTX_STRING_PREFIX.len()].copy_from_slice(CTX_STRING_PREFIX); + ctx[CTX_STRING_PREFIX.len()..].copy_from_slice(&task_id.0); + // Run the prepare-init algorithm, returning the initial state. + Ok(vdaf.prepare_init( verify_key, + &ctx, agg_id, + &(), nonce, - public_share_data, - input_share_data, - )?; - Ok(( + &public_share, + &input_share, + )?) + } + } + + pub(crate) fn prep_finish_from_shares( + &self, + version: DapVersion, + agg_id: usize, + task_id: TaskId, + host_state: VdafPrepState, + host_share: VdafPrepShare, + peer_share_data: &[u8], + ) -> Result<(VdafAggregateShare, Vec<u8>), VdafError> { + let (agg_share, outbound) = match (version, self, host_state, host_share) { + ( + DapVersion::Latest, + Prio3Config::Count, VdafPrepState::Prio3Field64(state), VdafPrepShare::Prio3Field64(share), - )) - } - ( - Prio3Config::Histogram { - length, - chunk_length, - }, - VdafVerifyKey::L32(verify_key), - ) => { - let vdaf = Prio3::new_histogram(2, *length, *chunk_length) - .map_err(|e| VdafError::Dap(fatal_error!(err = ?e, "failed to create prio3 histogram from num_aggregators(2), length({length}), chunk_length({chunk_length})")))?; - let (state, share) = prep_init( - vdaf, - task_id, - verify_key, - agg_id, - nonce, - public_share_data, - input_share_data, - )?; - Ok(( + ) => { + let vdaf = Prio3::new_count(2).map_err(|e| { + VdafError::Dap(fatal_error!(err = ?e, "initializing {self:?} failed")) + })?; + let (out_share, outbound) = + prep_finish_from_shares(&vdaf, task_id, agg_id, state, share, peer_share_data)?; + let agg_share = VdafAggregateShare::Field64(vdaf.aggregate(&(), [out_share])?); + (agg_share, outbound) + } + ( + DapVersion::Latest, + Prio3Config::Sum { max_measurement }, + VdafPrepState::Prio3Field64(state), + VdafPrepShare::Prio3Field64(share), + ) => { + let vdaf = Prio3::new_sum(2, *max_measurement).map_err(|e| { + VdafError::Dap(fatal_error!(err = ?e, "initializing {self:?} failed")) + })?; + let (out_share, outbound) = + prep_finish_from_shares(&vdaf, task_id, agg_id, state, share, peer_share_data)?; + let agg_share = VdafAggregateShare::Field64(vdaf.aggregate(&(), [out_share])?); + (agg_share, outbound) + } + ( + DapVersion::Latest, + Prio3Config::Histogram { + length, + chunk_length, + }, VdafPrepState::Prio3Field128(state), VdafPrepShare::Prio3Field128(share), - )) - } - (Prio3Config::Sum { .. }, VdafVerifyKey::L32(_)) => { - Err(VdafError::Dap(fatal_error!(err = "sum unimplemented"))) - } - ( - Prio3Config::SumVec { - bits, - length, - chunk_length, - }, - VdafVerifyKey::L32(verify_key), - ) => { - let vdaf = Prio3::new_sum_vec(2, *bits, *length, *chunk_length) - .map_err(|e| VdafError::Dap(fatal_error!(err = ?e, "failed to create prio3 sum vec from num_aggregators(2), bits({bits}), length({length}), chunk_length({chunk_length})")))?; - let (state, share) = prep_init( - vdaf, - task_id, - verify_key, - agg_id, - nonce, - public_share_data, - input_share_data, - )?; - Ok(( + ) => { + let vdaf = Prio3::new_histogram(2, *length, *chunk_length).map_err(|e| { + VdafError::Dap(fatal_error!(err = ?e, "initializing {self:?} failed")) + })?; + let (out_share, outbound) = + prep_finish_from_shares(&vdaf, task_id, agg_id, state, share, peer_share_data)?; + let agg_share = VdafAggregateShare::Field128(vdaf.aggregate(&(), [out_share])?); + (agg_share, outbound) + } + ( + DapVersion::Latest, + Prio3Config::SumVec { + bits, + length, + chunk_length, + }, VdafPrepState::Prio3Field128(state), VdafPrepShare::Prio3Field128(share), - )) - } - ( - Prio3Config::SumVecField64MultiproofHmacSha256Aes128 { - ..}, - VdafVerifyKey::L32(_), - ) => { - Err(VdafError::Dap(fatal_error!(err = format!("prio3_shard: SumVecField64MultiproofHmacSha256Aes128 is not defined for VDAF-13")))) - }, - _ => { - return Err(VdafError::Dap(fatal_error!( - err = "unhandled config and verify key combination", - ))) - } - - }; + ) => { + let vdaf = Prio3::new_sum_vec(2, *bits, *length, *chunk_length).map_err(|e| { + VdafError::Dap(fatal_error!(err = ?e, "initializing {self:?} failed")) + })?; + let (out_share, outbound) = + prep_finish_from_shares(&vdaf, task_id, agg_id, state, share, peer_share_data)?; + let agg_share = VdafAggregateShare::Field128(vdaf.aggregate(&(), [out_share])?); + (agg_share, outbound) + } + ( + DapVersion::Draft09, + Prio3Config::Draft09SumVecField64MultiproofHmacSha256Aes128 { + bits, + length, + chunk_length, + num_proofs, + }, + VdafPrepState::Prio3Draft09Field64HmacSha256Aes128(state), + VdafPrepShare::Prio3Draft09Field64HmacSha256Aes128(share), + ) => { + let vdaf = prio3_draft09::new_prio3_sum_vec_field64_multiproof_hmac_sha256_aes128( + *bits, + *length, + *chunk_length, + *num_proofs, + )?; + let (out_share, outbound) = super::prep_finish_from_shares_draft09( + &vdaf, + agg_id, + state, + share, + peer_share_data, + )?; + let agg_share = VdafAggregateShare::Field64Draft09( + prio_draft09::vdaf::Aggregator::aggregate(&vdaf, &(), [out_share])?, + ); + (agg_share, outbound) + } + _ => { + return Err(VdafError::Dap(fatal_error!( + err = format!( + "unexpected prep state or share or {self:?} is not supported in DAP {version}" + ) + ))) + } + }; - type Prio3Prepared<T, const SEED_SIZE: usize> = ( - Prio3PrepareState<<T as Type>::Field, SEED_SIZE>, - Prio3PrepareShare<<T as Type>::Field, SEED_SIZE>, - ); + Ok((agg_share, outbound)) + } - fn prep_init<T, P, const SEED_SIZE: usize>( - vdaf: Prio3<T, P, SEED_SIZE>, + pub(crate) fn prep_finish( + &self, + host_state: VdafPrepState, + peer_message_data: &[u8], task_id: TaskId, - verify_key: &[u8; SEED_SIZE], - agg_id: usize, - nonce: &[u8; 16], - public_share_data: &[u8], - input_share_data: &[u8], - ) -> Result<Prio3Prepared<T, SEED_SIZE>, VdafError> - where - T: Type, - P: Xof<SEED_SIZE>, - { - // Parse the public share. - let public_share = Prio3PublicShare::get_decoded_with_param(&vdaf, public_share_data)?; - - // Parse the input share. - let input_share = - Prio3InputShare::get_decoded_with_param(&(&vdaf, agg_id), input_share_data)?; + version: DapVersion, + ) -> Result<VdafAggregateShare, VdafError> { + let agg_share = match (version, self, host_state) { + (DapVersion::Latest, Prio3Config::Count, VdafPrepState::Prio3Field64(state)) => { + let vdaf = Prio3::new_count(2).map_err(|e| { + VdafError::Dap(fatal_error!(err = ?e, "initializing {self:?} failed")) + })?; + let out_share = prep_finish(&vdaf, task_id, state, peer_message_data)?; + VdafAggregateShare::Field64(vdaf.aggregate(&(), [out_share])?) + } + ( + DapVersion::Latest, + Prio3Config::Sum { max_measurement }, + VdafPrepState::Prio3Field64(state), + ) => { + let vdaf = Prio3::new_sum(2, *max_measurement).map_err(|e| { + VdafError::Dap(fatal_error!(err = ?e, "initializing {self:?} failed")) + })?; + let out_share = prep_finish(&vdaf, task_id, state, peer_message_data)?; + VdafAggregateShare::Field64(vdaf.aggregate(&(), [out_share])?) + } + ( + DapVersion::Latest, + Prio3Config::Histogram { + length, + chunk_length, + }, + VdafPrepState::Prio3Field128(state), + ) => { + let vdaf = Prio3::new_histogram(2, *length, *chunk_length).map_err(|e| { + VdafError::Dap(fatal_error!(err = ?e, "initializing {self:?} failed")) + })?; + let out_share = prep_finish(&vdaf, task_id, state, peer_message_data)?; + VdafAggregateShare::Field128(vdaf.aggregate(&(), [out_share])?) + } + ( + DapVersion::Latest, + Prio3Config::SumVec { + bits, + length, + chunk_length, + }, + VdafPrepState::Prio3Field128(state), + ) => { + let vdaf = Prio3::new_sum_vec(2, *bits, *length, *chunk_length).map_err(|e| { + VdafError::Dap(fatal_error!(err = ?e, "initializing {self:?} failed")) + })?; + let out_share = prep_finish(&vdaf, task_id, state, peer_message_data)?; + VdafAggregateShare::Field128(vdaf.aggregate(&(), [out_share])?) + } + ( + DapVersion::Draft09, + Prio3Config::Draft09SumVecField64MultiproofHmacSha256Aes128 { + bits, + length, + chunk_length, + num_proofs, + }, + VdafPrepState::Prio3Draft09Field64HmacSha256Aes128(state), + ) => { + let vdaf = prio3_draft09::new_prio3_sum_vec_field64_multiproof_hmac_sha256_aes128( + *bits, + *length, + *chunk_length, + *num_proofs, + )?; + let out_share = super::prep_finish_draft09(&vdaf, state, peer_message_data)?; + VdafAggregateShare::Field64Draft09(prio_draft09::vdaf::Aggregator::aggregate( + &vdaf, + &(), + [out_share], + )?) + } + _ => { + return Err(VdafError::Dap(fatal_error!( + err = format!( + "unexpected prep state or {self:?} is not supported in DAP {version}" + ) + ))) + } + }; - let mut ctx = [0; CTX_STRING_PREFIX.len() + 32]; - ctx[..CTX_STRING_PREFIX.len()].copy_from_slice(CTX_STRING_PREFIX); - ctx[CTX_STRING_PREFIX.len()..].copy_from_slice(&task_id.0); - // Run the prepare-init algorithm, returning the initial state. - Ok(vdaf.prepare_init( - verify_key, - &ctx, - agg_id, - &(), - nonce, - &public_share, - &input_share, - )?) + Ok(agg_share) } -} -/// Consume the prep shares and return our output share and the prep message. -pub(crate) fn prio3_prep_finish_from_shares( - config: &Prio3Config, - agg_id: usize, - task_id: TaskId, - host_state: VdafPrepState, - host_share: VdafPrepShare, - peer_share_data: &[u8], -) -> Result<(VdafAggregateShare, Vec<u8>), VdafError> { - let (agg_share, outbound) = match (&config, host_state, host_share) { - ( - Prio3Config::Count, - VdafPrepState::Prio3Field64(state), - VdafPrepShare::Prio3Field64(share), - ) => { - let vdaf = Prio3::new_count(2).map_err(|e| { - VdafError::Dap( - fatal_error!(err = ?e, "failed to create prio3 count num_aggregators(2)"), - ) - })?; - let (out_share, outbound) = - prep_finish_from_shares(&vdaf, task_id, agg_id, state, share, peer_share_data)?; - let agg_share = VdafAggregateShare::Field64(vdaf.aggregate(&(), [out_share])?); - (agg_share, outbound) - } - ( - Prio3Config::Histogram { - length, - chunk_length, - }, - VdafPrepState::Prio3Field128(state), - VdafPrepShare::Prio3Field128(share), - ) => { - let vdaf = Prio3::new_histogram(2, *length, *chunk_length) - .map_err(|e| VdafError::Dap(fatal_error!(err = ?e, "failed to create prio3 histogram from num_aggregators(2), length({length}), chunk_length({chunk_length})")))?; - let (out_share, outbound) = - prep_finish_from_shares(&vdaf, task_id, agg_id, state, share, peer_share_data)?; - let agg_share = VdafAggregateShare::Field128(vdaf.aggregate(&(), [out_share])?); - (agg_share, outbound) - } - ( - Prio3Config::Sum { .. }, - VdafPrepState::Prio3Field128(_), - VdafPrepShare::Prio3Field128(_), - ) => Err(VdafError::Dap(fatal_error!(err = "Prio3Sum is not supported in VDAF-13")))?, - ( - Prio3Config::SumVec { - bits, - length, - chunk_length, - }, - VdafPrepState::Prio3Field128(state), - VdafPrepShare::Prio3Field128(share), - ) => { - let vdaf = Prio3::new_sum_vec(2, *bits, *length, *chunk_length) - .map_err(|e| VdafError::Dap(fatal_error!(err = ?e, "failed to create prio3 sum vec from num_aggregators(2), bits({bits}), length({length}), chunk_length({chunk_length})")))?; - let (out_share, outbound) = - prep_finish_from_shares(&vdaf, task_id, agg_id, state, share, peer_share_data)?; - let agg_share = VdafAggregateShare::Field128(vdaf.aggregate(&(), [out_share])?); - (agg_share, outbound) - } - ( - Prio3Config::SumVecField64MultiproofHmacSha256Aes128 { - .. - }, - VdafPrepState::Prio3Field64HmacSha256Aes128(_), - VdafPrepShare::Prio3Field64HmacSha256Aes128(_), - ) => { - return Err(VdafError::Dap(fatal_error!(err = format!("prio3_prep_finish_from_shares: SumVecField64MultiproofHmacSha256Aes128 is not defined for VDAF-13")))) - } - _ => { - return Err(VdafError::Dap(fatal_error!( - err = format!("prio3_prep_finish_from_shares: unexpected field type for step or message") - ))) - } - }; + pub(crate) fn unshard<M: IntoIterator<Item = Vec<u8>>>( + &self, + version: DapVersion, + num_measurements: usize, + agg_shares: M, + ) -> Result<DapAggregateResult, VdafError> { + match (version, self) { + (DapVersion::Latest, Prio3Config::Count) => { + let vdaf = Prio3::new_count(2).map_err(|e| { + VdafError::Dap(fatal_error!(err = ?e, "initializing {self:?} failed")) + })?; - Ok((agg_share, outbound)) -} + let agg_res = unshard(&vdaf, num_measurements, agg_shares)?; + Ok(DapAggregateResult::U64(agg_res)) + } + (DapVersion::Latest, Prio3Config::Sum { max_measurement }) => { + let vdaf = Prio3::new_sum(2, *max_measurement).map_err(|e| { + VdafError::Dap(fatal_error!(err = ?e, "initializing {self:?} failed")) + })?; -/// Consume the prep message and output our output share. -pub(crate) fn prio3_prep_finish( - config: &Prio3Config, - host_state: VdafPrepState, - peer_message_data: &[u8], - task_id: TaskId, -) -> Result<VdafAggregateShare, VdafError> { - let agg_share = match (&config, host_state) { - (Prio3Config::Count, VdafPrepState::Prio3Field64(state)) => { - let vdaf = Prio3::new_count(2).map_err(|e| { - VdafError::Dap( - fatal_error!(err = ?e, "failed to create prio3 count from num_aggregators(2)"), - ) - })?; - let out_share = prep_finish(&vdaf, task_id, state, peer_message_data)?; - VdafAggregateShare::Field64(vdaf.aggregate(&(), [out_share])?) - } - ( - Prio3Config::Histogram { - length, - chunk_length, - }, - VdafPrepState::Prio3Field128(state), - ) => { - let vdaf = Prio3::new_histogram(2, *length, *chunk_length) - .map_err(|e| VdafError::Dap(fatal_error!(err = ?e, "failed to create prio3 histogram from num_aggregators(2), length({length}), chunk_length({chunk_length})")))?; - let out_share = prep_finish(&vdaf, task_id, state, peer_message_data)?; - VdafAggregateShare::Field128(vdaf.aggregate(&(), [out_share])?) - } - (Prio3Config::Sum { .. }, VdafPrepState::Prio3Field128(_)) => { - Err(VdafError::Dap(fatal_error!(err = "sum unimplemented")))? - } - ( - Prio3Config::SumVec { - bits, - length, - chunk_length, - }, - VdafPrepState::Prio3Field128(state), - ) => { - let vdaf = Prio3::new_sum_vec(2, *bits, *length, *chunk_length) - .map_err(|e| VdafError::Dap(fatal_error!(err = ?e, "failed to create prio3 sum vec from num_aggregators(2), bits({bits}), length({length}), chunk_length({chunk_length})")))?; - let out_share = prep_finish(&vdaf, task_id, state, peer_message_data)?; - VdafAggregateShare::Field128(vdaf.aggregate(&(), [out_share])?) - } - ( - Prio3Config::SumVecField64MultiproofHmacSha256Aes128 { - .. - }, - VdafPrepState::Prio3Field64HmacSha256Aes128(_), - ) => { + let agg_res = unshard(&vdaf, num_measurements, agg_shares)?; + Ok(DapAggregateResult::U64(agg_res)) + } + ( + DapVersion::Latest, + Prio3Config::Histogram { + length, + chunk_length, + }, + ) => { + let vdaf = Prio3::new_histogram(2, *length, *chunk_length).map_err(|e| { + VdafError::Dap(fatal_error!(err = ?e, "initializing {self:?} failed")) + })?; - return Err(VdafError::Dap(fatal_error!(err = format!("prio3_prep_finish: SumVecField64MultiproofHmacSha256Aes128 is not defined for VDAF-13")))) - } + let agg_res = unshard(&vdaf, num_measurements, agg_shares)?; + Ok(DapAggregateResult::U128Vec(agg_res)) + } + ( + DapVersion::Latest, + Prio3Config::SumVec { + bits, + length, + chunk_length, + }, + ) => { + let vdaf = Prio3::new_sum_vec(2, *bits, *length, *chunk_length).map_err(|e| { + VdafError::Dap(fatal_error!(err = ?e, "initializing {self:?} failed")) + })?; - _ => { - return Err(VdafError::Dap(fatal_error!( - err = format!("prio3_prep_finish: unexpected field type for step or message") - ))) - } - }; - - Ok(agg_share) -} - -/// Interpret `agg_shares` as a sequence of encoded aggregate shares and unshard them. -pub(crate) fn prio3_unshard<M: IntoIterator<Item = Vec<u8>>>( - config: &Prio3Config, - num_measurements: usize, - agg_shares: M, -) -> Result<DapAggregateResult, VdafError> { - match config { - Prio3Config::Count => { - let vdaf = Prio3::new_count(2).map_err(|e| { - VdafError::Dap( - fatal_error!(err = ?e, "failed to create prio3 count from num_aggregators(2)"), - ) - })?; - let agg_res = unshard(&vdaf, num_measurements, agg_shares)?; - Ok(DapAggregateResult::U64(agg_res)) - } - Prio3Config::Histogram { - length, - chunk_length, - } => { - let vdaf = Prio3::new_histogram(2, *length, *chunk_length) - .map_err(|e| VdafError::Dap(fatal_error!(err = ?e, "failed to create prio3 histogram from num_aggregators(2), length({length}), chunk_length({chunk_length})")))?; - let agg_res = unshard(&vdaf, num_measurements, agg_shares)?; - Ok(DapAggregateResult::U128Vec(agg_res)) - } - Prio3Config::Sum { .. } => Err(VdafError::Dap(fatal_error!(err = "sum unimplemented"))), - Prio3Config::SumVec { - bits, - length, - chunk_length, - } => { - let vdaf = Prio3::new_sum_vec(2, *bits, *length, *chunk_length) - .map_err(|e| VdafError::Dap(fatal_error!(err = ?e, "failed to create prio3 sum vec from num_aggregators(2), bits({bits}), length({length}), chunk_length({chunk_length})")))?; - let agg_res = unshard(&vdaf, num_measurements, agg_shares)?; - Ok(DapAggregateResult::U128Vec(agg_res)) - } - Prio3Config::SumVecField64MultiproofHmacSha256Aes128 { - .. - } => { - Err(VdafError::Dap(fatal_error!(err = format!("prio3_prep_finish: SumVecField64MultiproofHmacSha256Aes128 is not defined for VDAF-13")))) + let agg_res = unshard(&vdaf, num_measurements, agg_shares)?; + Ok(DapAggregateResult::U128Vec(agg_res)) + } + ( + DapVersion::Draft09, + Prio3Config::Draft09SumVecField64MultiproofHmacSha256Aes128 { + bits, + length, + chunk_length, + num_proofs, + }, + ) => { + let vdaf = prio3_draft09::new_prio3_sum_vec_field64_multiproof_hmac_sha256_aes128( + *bits, + *length, + *chunk_length, + *num_proofs, + )?; + let agg_res = super::unshard_draft09(&vdaf, num_measurements, agg_shares)?; + Ok(DapAggregateResult::U64Vec(agg_res)) + } + _ => Err(VdafError::Dap(fatal_error!( + err = format!("{version} does not support {self:?}") + ))), } } } @@ -413,17 +570,17 @@ mod test { use crate::{ hpke::HpkeKemId, - test_versions, testing::AggregationJobTest, vdaf::{Prio3Config, VdafConfig}, DapAggregateResult, DapAggregationParam, DapMeasurement, DapVersion, }; - fn roundtrip_count(version: DapVersion) { + #[test] + fn roundtrip_count() { let mut t = AggregationJobTest::new( &VdafConfig::Prio3(Prio3Config::Count), HpkeKemId::X25519HkdfSha256, - version, + DapVersion::Latest, ); let got = t.roundtrip( DapAggregationParam::Empty, @@ -438,9 +595,30 @@ mod test { assert_eq!(got, DapAggregateResult::U64(3)); } - test_versions! { roundtrip_count } + #[test] + fn roundtrip_sum() { + let mut t = AggregationJobTest::new( + &VdafConfig::Prio3(Prio3Config::Sum { + max_measurement: 1337, + }), + HpkeKemId::X25519HkdfSha256, + DapVersion::Latest, + ); + let got = t.roundtrip( + DapAggregationParam::Empty, + vec![ + DapMeasurement::U64(0), + DapMeasurement::U64(1), + DapMeasurement::U64(1337), + DapMeasurement::U64(4), + DapMeasurement::U64(0), + ], + ); + assert_eq!(got, DapAggregateResult::U64(1342)); + } - fn roundtrip_sum_vec(version: DapVersion) { + #[test] + fn roundtrip_sum_vec() { let mut t = AggregationJobTest::new( &VdafConfig::Prio3(Prio3Config::SumVec { bits: 23, @@ -448,7 +626,7 @@ mod test { chunk_length: 1, }), HpkeKemId::X25519HkdfSha256, - version, + DapVersion::Latest, ); let got = t.roundtrip( DapAggregationParam::Empty, @@ -461,16 +639,15 @@ mod test { assert_eq!(got, DapAggregateResult::U128Vec(vec![1338, 1338])); } - test_versions! { roundtrip_sum_vec } - - fn roundtrip_histogram(version: DapVersion) { + #[test] + fn roundtrip_histogram() { let mut t = AggregationJobTest::new( &VdafConfig::Prio3(Prio3Config::Histogram { length: 3, chunk_length: 1, }), HpkeKemId::X25519HkdfSha256, - version, + DapVersion::Latest, ); let got = t.roundtrip( DapAggregationParam::Empty, @@ -484,6 +661,4 @@ mod test { ); assert_eq!(got, DapAggregateResult::U128Vec(vec![1, 1, 3])); } - - test_versions! { roundtrip_histogram } } diff --git a/crates/daphne/src/vdaf/prio3_draft09.rs b/crates/daphne/src/vdaf/prio3_draft09.rs index bd13f30bc..93800e783 100644 --- a/crates/daphne/src/vdaf/prio3_draft09.rs +++ b/crates/daphne/src/vdaf/prio3_draft09.rs @@ -4,16 +4,8 @@ //! Parameters for the [Prio3 VDAF](https://datatracker.ietf.org/doc/draft-irtf-cfrg-vdaf/09/). use crate::{ - fatal_error, - messages::taskprov::VDAF_TYPE_PRIO3_SUM_VEC_FIELD64_MULTIPROOF_HMAC_SHA256_AES128, - vdaf::{VdafError, VdafVerifyKey}, - DapAggregateResult, DapMeasurement, Prio3Config, VdafAggregateShare, VdafPrepShare, - VdafPrepState, -}; - -use super::{ - prep_finish_draft09, prep_finish_from_shares_draft09, shard_then_encode_draft09, - unshard_draft09, + fatal_error, messages::taskprov::VDAF_TYPE_PRIO3_SUM_VEC_FIELD64_MULTIPROOF_HMAC_SHA256_AES128, + vdaf::VdafError, }; use prio_draft09::{ @@ -31,12 +23,10 @@ use prio_draft09::{ }, }; -const ERR_FIELD_TYPE: &str = "unexpected field type for step or message"; - type Prio3SumVecField64MultiproofHmacSha256Aes128 = Prio3<SumVec<Field64, ParallelSum<Field64, Mul<Field64>>>, XofHmacSha256Aes128, 32>; -fn new_prio3_sum_vec_field64_multiproof_hmac_sha256_aes128( +pub(crate) fn new_prio3_sum_vec_field64_multiproof_hmac_sha256_aes128( bits: usize, length: usize, chunk_length: usize, @@ -56,458 +46,31 @@ fn new_prio3_sum_vec_field64_multiproof_hmac_sha256_aes128( .map_err(|e| VdafError::Dap(fatal_error!(err = ?e, "failed to create prio3"))) } -/// Split the given measurement into a sequence of encoded input shares. -pub(crate) fn prio3_draft09_shard( - config: &Prio3Config, - measurement: DapMeasurement, - nonce: &[u8; 16], -) -> Result<(Vec<u8>, [Vec<u8>; 2]), VdafError> { - match (config, measurement) { - (Prio3Config::Count, DapMeasurement::U64(measurement)) if measurement < 2 => { - let vdaf = Prio3::new_count(2).map_err(|e| { - VdafError::Dap( - fatal_error!(err = ?e, "failed to create prio3 count from num_aggregators(2)"), - ) - })?; - shard_then_encode_draft09(&vdaf, &(measurement != 0), nonce) - } - ( - Prio3Config::Histogram { - length, - chunk_length, - }, - DapMeasurement::U64(measurement), - ) => { - let vdaf = Prio3::new_histogram(2, *length, *chunk_length) - .map_err(|e| VdafError::Dap(fatal_error!(err = ?e, "failed to create prio3 Histogram from num_aggregators(2), length({length}), chunk_length({chunk_length})")))?; - let m: usize = measurement.try_into().unwrap(); - shard_then_encode_draft09(&vdaf, &m, nonce) - } - (Prio3Config::Sum { bits }, DapMeasurement::U64(measurement)) => { - let vdaf = - Prio3::new_sum(2, *bits).map_err(|e| VdafError::Dap(fatal_error!(err = ?e, "failed to create prio3 sum from num_aggregators(2), bits({bits})")))?; - shard_then_encode_draft09(&vdaf, &u128::from(measurement), nonce) - } - ( - Prio3Config::SumVec { - bits, - length, - chunk_length, - }, - DapMeasurement::U128Vec(measurement), - ) => { - let vdaf = Prio3::new_sum_vec(2, *bits, *length, *chunk_length) - .map_err(|e| VdafError::Dap(fatal_error!(err = ?e, "failed to create prio3 sum vec from num_aggregators(2), bits({bits}), length({length}), chunk_length({chunk_length})")))?; - shard_then_encode_draft09(&vdaf, &measurement, nonce) - } - ( - Prio3Config::SumVecField64MultiproofHmacSha256Aes128 { - bits, - length, - chunk_length, - num_proofs, - }, - DapMeasurement::U64Vec(measurement), - ) => { - let vdaf = new_prio3_sum_vec_field64_multiproof_hmac_sha256_aes128( - *bits, - *length, - *chunk_length, - *num_proofs, - )?; - shard_then_encode_draft09(&vdaf, &measurement, nonce) - } - _ => Err(VdafError::Dap(fatal_error!( - err = format!("prio3_shard: unexpected VDAF config {config:?}") - ))), - } -} +type Prio3Draft09Prepared<T, const SEED_SIZE: usize> = ( + Prio3PrepareState<<T as Type>::Field, SEED_SIZE>, + Prio3PrepareShare<<T as Type>::Field, SEED_SIZE>, +); -/// Consume an input share and return the corresponding VDAF step and message. -pub(crate) fn prio3_draft09_prep_init( - config: &Prio3Config, - verify_key: &VdafVerifyKey, +pub(crate) fn prep_init_draft09<T, P, const SEED_SIZE: usize>( + vdaf: Prio3<T, P, SEED_SIZE>, + verify_key: &[u8; SEED_SIZE], agg_id: usize, nonce: &[u8; 16], public_share_data: &[u8], input_share_data: &[u8], -) -> Result<(VdafPrepState, VdafPrepShare), VdafError> { - return match (&config, verify_key) { - (Prio3Config::Count, VdafVerifyKey::L16(verify_key)) => { - let vdaf = Prio3::new_count(2).map_err(|e| { - VdafError::Dap( - fatal_error!(err = ?e, "failed to create prio3 from num_aggregators(2)"), - ) - })?; - let (state, share) = prep_init_draft09( - vdaf, - verify_key, - agg_id, - nonce, - public_share_data, - input_share_data, - )?; - Ok(( - VdafPrepState::Prio3Draft09Field64(state), - VdafPrepShare::Prio3Draft09Field64(share), - )) - } - ( - Prio3Config::Histogram { - length, - chunk_length, - }, - VdafVerifyKey::L16(verify_key), - ) => { - let vdaf = Prio3::new_histogram(2, *length, *chunk_length) - .map_err(|e| VdafError::Dap(fatal_error!(err = ?e, "failed to create prio3 histogram from num_aggregators(2), length({length}), chunk_length({chunk_length})")))?; - let (state, share) = prep_init_draft09( - vdaf, - verify_key, - agg_id, - nonce, - public_share_data, - input_share_data, - )?; - Ok(( - VdafPrepState::Prio3Draft09Field128(state), - VdafPrepShare::Prio3Draft09Field128(share), - )) - } - (Prio3Config::Sum { bits }, VdafVerifyKey::L16(verify_key)) => { - let vdaf = - Prio3::new_sum(2, *bits).map_err(|e| VdafError::Dap(fatal_error!(err = ?e, "failed to create prio3 sum from num_aggregators(2), bits({bits})")))?; - let (state, share) = prep_init_draft09( - vdaf, - verify_key, - agg_id, - nonce, - public_share_data, - input_share_data, - )?; - Ok(( - VdafPrepState::Prio3Draft09Field128(state), - VdafPrepShare::Prio3Draft09Field128(share), - )) - } - ( - Prio3Config::SumVec { - bits, - length, - chunk_length, - }, - VdafVerifyKey::L16(verify_key), - ) => { - let vdaf = Prio3::new_sum_vec(2, *bits, *length, *chunk_length) - .map_err(|e| VdafError::Dap(fatal_error!(err = ?e, "failed to create prio3 sum vec from num_aggregators(2), bits({bits}), length({length}), chunk_length({chunk_length})")))?; - let (state, share) = prep_init_draft09( - vdaf, - verify_key, - agg_id, - nonce, - public_share_data, - input_share_data, - )?; - Ok(( - VdafPrepState::Prio3Draft09Field128(state), - VdafPrepShare::Prio3Draft09Field128(share), - )) - } - ( - Prio3Config::SumVecField64MultiproofHmacSha256Aes128 { - bits, - length, - chunk_length, - num_proofs, - }, - VdafVerifyKey::L32(verify_key), - ) => { - let vdaf = new_prio3_sum_vec_field64_multiproof_hmac_sha256_aes128( - *bits, - *length, - *chunk_length, - *num_proofs, - )?; - let (state, share) = prep_init_draft09( - vdaf, - verify_key, - agg_id, - nonce, - public_share_data, - input_share_data, - )?; - Ok(( - VdafPrepState::Prio3Draft09Field64HmacSha256Aes128(state), - VdafPrepShare::Prio3Draft09Field64HmacSha256Aes128(share), - )) - } - _ => { - return Err(VdafError::Dap(fatal_error!( - err = "unhandled config and verify key combination", - ))) - } - }; - - type Prio3Draft09Prepared<T, const SEED_SIZE: usize> = ( - Prio3PrepareState<<T as Type>::Field, SEED_SIZE>, - Prio3PrepareShare<<T as Type>::Field, SEED_SIZE>, - ); - - fn prep_init_draft09<T, P, const SEED_SIZE: usize>( - vdaf: Prio3<T, P, SEED_SIZE>, - verify_key: &[u8; SEED_SIZE], - agg_id: usize, - nonce: &[u8; 16], - public_share_data: &[u8], - input_share_data: &[u8], - ) -> Result<Prio3Draft09Prepared<T, SEED_SIZE>, VdafError> - where - T: Type, - P: Xof<SEED_SIZE>, - { - // Parse the public share. - let public_share = Prio3PublicShare::get_decoded_with_param(&vdaf, public_share_data)?; - - // Parse the input share. - let input_share = - Prio3InputShare::get_decoded_with_param(&(&vdaf, agg_id), input_share_data)?; - - // Run the prepare-init algorithm, returning the initial state. - Ok(vdaf.prepare_init(verify_key, agg_id, &(), nonce, &public_share, &input_share)?) - } -} - -/// Consume the prep shares and return our output share and the prep message. -pub(crate) fn prio3_draft09_prep_finish_from_shares( - config: &Prio3Config, - agg_id: usize, - host_state: VdafPrepState, - host_share: VdafPrepShare, - peer_share_data: &[u8], -) -> Result<(VdafAggregateShare, Vec<u8>), VdafError> { - let (agg_share, outbound) = match (&config, host_state, host_share) { - ( - Prio3Config::Count, - VdafPrepState::Prio3Draft09Field64(state), - VdafPrepShare::Prio3Draft09Field64(share), - ) => { - let vdaf = Prio3::new_count(2).map_err(|e| { - VdafError::Dap( - fatal_error!(err = ?e, "failed to create prio3 count num_aggregators(2)"), - ) - })?; - let (out_share, outbound) = - prep_finish_from_shares_draft09(&vdaf, agg_id, state, share, peer_share_data)?; - let agg_share = VdafAggregateShare::Field64Draft09(vdaf.aggregate(&(), [out_share])?); - (agg_share, outbound) - } - ( - Prio3Config::Histogram { - length, - chunk_length, - }, - VdafPrepState::Prio3Draft09Field128(state), - VdafPrepShare::Prio3Draft09Field128(share), - ) => { - let vdaf = Prio3::new_histogram(2, *length, *chunk_length) - .map_err(|e| VdafError::Dap(fatal_error!(err = ?e, "failed to create prio3 histogram from num_aggregators(2), length({length}), chunk_length({chunk_length})")))?; - let (out_share, outbound) = - prep_finish_from_shares_draft09(&vdaf, agg_id, state, share, peer_share_data)?; - let agg_share = VdafAggregateShare::Field128Draft09(vdaf.aggregate(&(), [out_share])?); - (agg_share, outbound) - } - ( - Prio3Config::Sum { bits }, - VdafPrepState::Prio3Draft09Field128(state), - VdafPrepShare::Prio3Draft09Field128(share), - ) => { - let vdaf = - Prio3::new_sum(2, *bits).map_err(|e| VdafError::Dap(fatal_error!(err = ?e, "failed to create prio3 sum from num_aggregators(2), bits({bits})")))?; - let (out_share, outbound) = - prep_finish_from_shares_draft09(&vdaf, agg_id, state, share, peer_share_data)?; - let agg_share = VdafAggregateShare::Field128Draft09(vdaf.aggregate(&(), [out_share])?); - (agg_share, outbound) - } - ( - Prio3Config::SumVec { - bits, - length, - chunk_length, - }, - VdafPrepState::Prio3Draft09Field128(state), - VdafPrepShare::Prio3Draft09Field128(share), - ) => { - let vdaf = Prio3::new_sum_vec(2, *bits, *length, *chunk_length) - .map_err(|e| VdafError::Dap(fatal_error!(err = ?e, "failed to create prio3 sum vec from num_aggregators(2), bits({bits}), length({length}), chunk_length({chunk_length})")))?; - let (out_share, outbound) = - prep_finish_from_shares_draft09(&vdaf, agg_id, state, share, peer_share_data)?; - let agg_share = VdafAggregateShare::Field128Draft09(vdaf.aggregate(&(), [out_share])?); - (agg_share, outbound) - } - ( - Prio3Config::SumVecField64MultiproofHmacSha256Aes128 { - bits, - length, - chunk_length, - num_proofs, - }, - VdafPrepState::Prio3Draft09Field64HmacSha256Aes128(state), - VdafPrepShare::Prio3Draft09Field64HmacSha256Aes128(share), - ) => { - let vdaf = new_prio3_sum_vec_field64_multiproof_hmac_sha256_aes128( - *bits, - *length, - *chunk_length, - *num_proofs, - )?; - let (out_share, outbound) = - prep_finish_from_shares_draft09(&vdaf, agg_id, state, share, peer_share_data)?; - let agg_share = VdafAggregateShare::Field64Draft09(vdaf.aggregate(&(), [out_share])?); - (agg_share, outbound) - } - _ => { - return Err(VdafError::Dap(fatal_error!( - err = format!("prio3_prep_finish_from_shares: {ERR_FIELD_TYPE}") - ))) - } - }; - - Ok((agg_share, outbound)) -} - -/// Consume the prep message and output our output share. -pub(crate) fn prio3_draft09_prep_finish( - config: &Prio3Config, - host_state: VdafPrepState, - peer_message_data: &[u8], -) -> Result<VdafAggregateShare, VdafError> { - let agg_share = match (&config, host_state) { - (Prio3Config::Count, VdafPrepState::Prio3Draft09Field64(state)) => { - let vdaf = Prio3::new_count(2).map_err(|e| { - VdafError::Dap( - fatal_error!(err = ?e, "failed to create prio3 count from num_aggregators(2)"), - ) - })?; - let out_share = prep_finish_draft09(&vdaf, state, peer_message_data)?; - VdafAggregateShare::Field64Draft09(vdaf.aggregate(&(), [out_share])?) - } - ( - Prio3Config::Histogram { - length, - chunk_length, - }, - VdafPrepState::Prio3Draft09Field128(state), - ) => { - let vdaf = Prio3::new_histogram(2, *length, *chunk_length) - .map_err(|e| VdafError::Dap(fatal_error!(err = ?e, "failed to create prio3 histogram from num_aggregators(2), length({length}), chunk_length({chunk_length})")))?; - let out_share = prep_finish_draft09(&vdaf, state, peer_message_data)?; - VdafAggregateShare::Field128Draft09(vdaf.aggregate(&(), [out_share])?) - } - (Prio3Config::Sum { bits }, VdafPrepState::Prio3Draft09Field128(state)) => { - let vdaf = - Prio3::new_sum(2, *bits).map_err(|e| VdafError::Dap(fatal_error!(err = ?e, "failed to create prio3 sum from num_aggregators(2), bits({bits})")))?; - let out_share = prep_finish_draft09(&vdaf, state, peer_message_data)?; - VdafAggregateShare::Field128Draft09(vdaf.aggregate(&(), [out_share])?) - } - ( - Prio3Config::SumVec { - bits, - length, - chunk_length, - }, - VdafPrepState::Prio3Draft09Field128(state), - ) => { - let vdaf = Prio3::new_sum_vec(2, *bits, *length, *chunk_length) - .map_err(|e| VdafError::Dap(fatal_error!(err = ?e, "failed to create prio3 sum vec from num_aggregators(2), bits({bits}), length({length}), chunk_length({chunk_length})")))?; - let out_share = prep_finish_draft09(&vdaf, state, peer_message_data)?; - VdafAggregateShare::Field128Draft09(vdaf.aggregate(&(), [out_share])?) - } - ( - Prio3Config::SumVecField64MultiproofHmacSha256Aes128 { - bits, - length, - chunk_length, - num_proofs, - }, - VdafPrepState::Prio3Draft09Field64HmacSha256Aes128(state), - ) => { - let vdaf = new_prio3_sum_vec_field64_multiproof_hmac_sha256_aes128( - *bits, - *length, - *chunk_length, - *num_proofs, - )?; - let out_share = prep_finish_draft09(&vdaf, state, peer_message_data)?; - VdafAggregateShare::Field64Draft09(vdaf.aggregate(&(), [out_share])?) - } - - _ => { - return Err(VdafError::Dap(fatal_error!( - err = format!("prio3_prep_finish: {ERR_FIELD_TYPE}") - ))) - } - }; - - Ok(agg_share) -} - -/// Interpret `agg_shares` as a sequence of encoded aggregate shares and unshard them. -pub(crate) fn prio3_draft09_unshard<M: IntoIterator<Item = Vec<u8>>>( - config: &Prio3Config, - num_measurements: usize, - agg_shares: M, -) -> Result<DapAggregateResult, VdafError> { - match config { - Prio3Config::Count => { - let vdaf = Prio3::new_count(2).map_err(|e| { - VdafError::Dap( - fatal_error!(err = ?e, "failed to create prio3 count from num_aggregators(2)"), - ) - })?; - let agg_res = unshard_draft09(&vdaf, num_measurements, agg_shares)?; - Ok(DapAggregateResult::U64(agg_res)) - } - Prio3Config::Histogram { - length, - chunk_length, - } => { - let vdaf = Prio3::new_histogram(2, *length, *chunk_length) - .map_err(|e| VdafError::Dap(fatal_error!(err = ?e, "failed to create prio3 histogram from num_aggregators(2), length({length}), chunk_length({chunk_length})")))?; - let agg_res = unshard_draft09(&vdaf, num_measurements, agg_shares)?; - Ok(DapAggregateResult::U128Vec(agg_res)) - } - Prio3Config::Sum { bits } => { - let vdaf = - Prio3::new_sum(2, *bits).map_err(|e| VdafError::Dap(fatal_error!(err = ?e, "failed to create prio3 sum from num_aggregators(2), bits({bits})")))?; - let agg_res = unshard_draft09(&vdaf, num_measurements, agg_shares)?; - Ok(DapAggregateResult::U128(agg_res)) - } - Prio3Config::SumVec { - bits, - length, - chunk_length, - } => { - let vdaf = Prio3::new_sum_vec(2, *bits, *length, *chunk_length) - .map_err(|e| VdafError::Dap(fatal_error!(err = ?e, "failed to create prio3 sum vec from num_aggregators(2), bits({bits}), length({length}), chunk_length({chunk_length})")))?; - let agg_res = unshard_draft09(&vdaf, num_measurements, agg_shares)?; - Ok(DapAggregateResult::U128Vec(agg_res)) - } - Prio3Config::SumVecField64MultiproofHmacSha256Aes128 { - bits, - length, - chunk_length, - num_proofs, - } => { - let vdaf = new_prio3_sum_vec_field64_multiproof_hmac_sha256_aes128( - *bits, - *length, - *chunk_length, - *num_proofs, - )?; - let agg_res = unshard_draft09(&vdaf, num_measurements, agg_shares)?; - Ok(DapAggregateResult::U64Vec(agg_res)) - } - } +) -> Result<Prio3Draft09Prepared<T, SEED_SIZE>, VdafError> +where + T: Type, + P: Xof<SEED_SIZE>, +{ + // Parse the public share. + let public_share = Prio3PublicShare::get_decoded_with_param(&vdaf, public_share_data)?; + + // Parse the input share. + let input_share = Prio3InputShare::get_decoded_with_param(&(&vdaf, agg_id), input_share_data)?; + + // Run the prepare-init algorithm, returning the initial state. + Ok(vdaf.prepare_init(verify_key, agg_id, &(), nonce, &public_share, &input_share)?) } #[cfg(test)] @@ -517,7 +80,6 @@ mod test { use crate::{ hpke::HpkeKemId, - test_versions, testing::AggregationJobTest, vdaf::{ prio3_draft09::new_prio3_sum_vec_field64_multiproof_hmac_sha256_aes128, Prio3Config, @@ -526,105 +88,19 @@ mod test { DapAggregateResult, DapAggregationParam, DapMeasurement, DapVersion, }; - fn roundtrip_count(version: DapVersion) { - let mut t = AggregationJobTest::new( - &VdafConfig::Prio3Draft09(Prio3Config::Count), - HpkeKemId::X25519HkdfSha256, - version, - ); - let got = t.roundtrip( - DapAggregationParam::Empty, - vec![ - DapMeasurement::U64(0), - DapMeasurement::U64(1), - DapMeasurement::U64(1), - DapMeasurement::U64(1), - DapMeasurement::U64(0), - ], - ); - assert_eq!(got, DapAggregateResult::U64(3)); - } - - test_versions! { roundtrip_count } - - fn roundtrip_sum(version: DapVersion) { - let mut t = AggregationJobTest::new( - &VdafConfig::Prio3Draft09(Prio3Config::Sum { bits: 23 }), - HpkeKemId::X25519HkdfSha256, - version, - ); - let got = t.roundtrip( - DapAggregationParam::Empty, - vec![ - DapMeasurement::U64(0), - DapMeasurement::U64(1), - DapMeasurement::U64(1337), - DapMeasurement::U64(4), - DapMeasurement::U64(0), - ], - ); - assert_eq!(got, DapAggregateResult::U128(1342)); - } - - test_versions! { roundtrip_sum } - - fn roundtrip_sum_vec(version: DapVersion) { - let mut t = AggregationJobTest::new( - &VdafConfig::Prio3Draft09(Prio3Config::SumVec { - bits: 23, - length: 2, - chunk_length: 1, - }), - HpkeKemId::X25519HkdfSha256, - version, - ); - let got = t.roundtrip( - DapAggregationParam::Empty, - vec![ - DapMeasurement::U128Vec(vec![1337, 0]), - DapMeasurement::U128Vec(vec![0, 1337]), - DapMeasurement::U128Vec(vec![1, 1]), - ], - ); - assert_eq!(got, DapAggregateResult::U128Vec(vec![1338, 1338])); - } - - test_versions! { roundtrip_sum_vec } - - fn roundtrip_histogram(version: DapVersion) { - let mut t = AggregationJobTest::new( - &VdafConfig::Prio3Draft09(Prio3Config::Histogram { - length: 3, - chunk_length: 1, - }), - HpkeKemId::X25519HkdfSha256, - version, - ); - let got = t.roundtrip( - DapAggregationParam::Empty, - vec![ - DapMeasurement::U64(0), - DapMeasurement::U64(1), - DapMeasurement::U64(2), - DapMeasurement::U64(2), - DapMeasurement::U64(2), - ], - ); - assert_eq!(got, DapAggregateResult::U128Vec(vec![1, 1, 3])); - } - - test_versions! { roundtrip_histogram } - - fn roundtrip_sum_vec_field64_multiproof_hmac_sha256_aes128(version: DapVersion) { + #[test] + fn roundtrip_sum_vec_field64_multiproof_hmac_sha256_aes128_draft09() { let mut t = AggregationJobTest::new( - &VdafConfig::Prio3Draft09(Prio3Config::SumVecField64MultiproofHmacSha256Aes128 { - bits: 23, - length: 2, - chunk_length: 1, - num_proofs: 4, - }), + &VdafConfig::Prio3( + Prio3Config::Draft09SumVecField64MultiproofHmacSha256Aes128 { + bits: 23, + length: 2, + chunk_length: 1, + num_proofs: 4, + }, + ), HpkeKemId::X25519HkdfSha256, - version, + DapVersion::Draft09, ); let got = t.roundtrip( DapAggregationParam::Empty, @@ -637,8 +113,6 @@ mod test { assert_eq!(got, DapAggregateResult::U64Vec(vec![1338, 1338])); } - test_versions! { roundtrip_sum_vec_field64_multiproof_hmac_sha256_aes128 } - #[test] fn test_vec_sum_vec_field64_multiproof_hmac_sha256_aes128() { for test_vec_json_str in [