diff --git a/crates/dapf/src/acceptance/load_testing.rs b/crates/dapf/src/acceptance/load_testing.rs
index cf0d9966c..c5aea384e 100644
--- a/crates/dapf/src/acceptance/load_testing.rs
+++ b/crates/dapf/src/acceptance/load_testing.rs
@@ -448,6 +448,7 @@ pub async fn execute_single_combination_from_env(
         &measurment,
         VERSION,
         system_now.0,
+        Some(vec![]),
         vec![messages::Extension::Taskprov],
         t.replay_reports,
     )
diff --git a/crates/dapf/src/acceptance/mod.rs b/crates/dapf/src/acceptance/mod.rs
index dc82028ab..91ef43073 100644
--- a/crates/dapf/src/acceptance/mod.rs
+++ b/crates/dapf/src/acceptance/mod.rs
@@ -655,6 +655,10 @@ impl Test {
             measurement.as_ref(),
             version,
             now.0,
+            match version {
+                DapVersion::Draft09 => None,
+                DapVersion::Latest => Some(vec![]),
+            },
             vec![messages::Extension::Taskprov],
             self.replay_reports,
         )
diff --git a/crates/daphne-server/tests/e2e/e2e.rs b/crates/daphne-server/tests/e2e/e2e.rs
index 7fa867d2a..c7c782471 100644
--- a/crates/daphne-server/tests/e2e/e2e.rs
+++ b/crates/daphne-server/tests/e2e/e2e.rs
@@ -420,6 +420,10 @@ async fn leader_upload_taskprov() {
                 t.now,
                 &task_id,
                 DapMeasurement::U32Vec(vec![1; 10]),
+                match version {
+                    DapVersion::Draft09 => None,
+                    DapVersion::Latest => Some(vec![]),
+                },
                 vec![Extension::Taskprov],
                 version,
             )
@@ -447,6 +451,10 @@ async fn leader_upload_taskprov() {
                 t.now,
                 &task_id,
                 DapMeasurement::U32Vec(vec![1; 10]),
+                match version {
+                    DapVersion::Draft09 => None,
+                    DapVersion::Latest => Some(vec![]),
+                },
                 vec![Extension::Taskprov],
                 version,
             )
@@ -512,6 +520,10 @@ async fn leader_upload_taskprov_wrong_version(version: DapVersion) {
                 t.now,
                 &task_id,
                 DapMeasurement::U32Vec(vec![1; 10]),
+                match version {
+                    DapVersion::Draft09 => None,
+                    DapVersion::Latest => Some(vec![]),
+                },
                 vec![Extension::Taskprov],
                 version,
             )
@@ -1542,18 +1554,18 @@ async fn leader_collect_taskprov_repeated_abort() {
                 .unwrap(),
             ),
             {
-                let mut report = task_config
+                let report = task_config
                     .vdaf
                     .produce_report_with_extensions(
                         &hpke_config_list,
                         now,
                         &task_id,
                         DapMeasurement::U32Vec(vec![1; 10]),
+                        Some(vec![Extension::Taskprov]),
                        extensions,
                         version,
                     )
                     .unwrap();
-                report.report_metadata.public_extensions = Some(vec![Extension::Taskprov]);
                 report.get_encoded_with_param(&version).unwrap()
             },
         )
@@ -1660,6 +1672,10 @@ async fn leader_collect_taskprov_ok(version: DapVersion) {
                 now,
                 &task_id,
                 DapMeasurement::U32Vec(vec![1; 10]),
+                match version {
+                    DapVersion::Draft09 => None,
+                    DapVersion::Latest => Some(vec![]),
+                },
                 extensions,
                 version,
             )
diff --git a/crates/daphne/src/error/aborts.rs b/crates/daphne/src/error/aborts.rs
index d77aac4aa..53dda2522 100644
--- a/crates/daphne/src/error/aborts.rs
+++ b/crates/daphne/src/error/aborts.rs
@@ -268,14 +268,10 @@ impl DapAbort {
         task_id: &TaskId,
         unknown_extensions: &[u16],
     ) -> Result<Self, DapError> {
-        let detail = serde_json::to_string(&unknown_extensions);
-        match detail {
-            Ok(s) => Ok(Self::UnsupportedExtension {
-                detail: s,
-                task_id: *task_id,
-            }),
-            Err(x) => Err(fatal_error!(err = %x,)),
-        }
+        Ok(Self::UnsupportedExtension {
+            detail: format!("{unknown_extensions:?}"),
+            task_id: *task_id,
+        })
     }
 
     fn title_and_type(&self) -> (&'static str, Option<String>) {
diff --git a/crates/daphne/src/messages/mod.rs b/crates/daphne/src/messages/mod.rs
index 800dcba53..796bfbd9f 100644
--- a/crates/daphne/src/messages/mod.rs
+++ b/crates/daphne/src/messages/mod.rs
@@ -1418,7 +1418,7 @@ impl Decode for HpkeCiphertext {
 /// A plaintext input share.
 #[derive(Clone, Debug, PartialEq, Eq)]
 pub struct PlaintextInputShare {
-    pub extensions: Vec<Extension>,
+    pub private_extensions: Vec<Extension>,
     pub payload: Vec<u8>,
 }
 
@@ -1428,7 +1428,7 @@ impl ParameterizedEncode<DapVersion> for PlaintextInputShare {
         version: &DapVersion,
         bytes: &mut Vec<u8>,
     ) -> Result<(), CodecError> {
-        encode_u16_items(bytes, version, &self.extensions)?;
+        encode_u16_items(bytes, version, &self.private_extensions)?;
         encode_u32_bytes(bytes, &self.payload)?;
         Ok(())
     }
@@ -1440,7 +1440,7 @@ impl ParameterizedDecode<DapVersion> for PlaintextInputShare {
         bytes: &mut Cursor<&[u8]>,
     ) -> Result<Self, CodecError> {
         Ok(Self {
-            extensions: decode_u16_items(version, bytes)?,
+            private_extensions: decode_u16_items(version, bytes)?,
             payload: decode_u32_bytes(bytes)?,
         })
     }
@@ -1648,6 +1648,24 @@ mod test {
 
     test_versions! {report_metadata_encode_decode}
 
+    #[test]
+    fn report_metadata_encode_latest_decode_draft09() {
+        let ext_rm = ReportMetadata {
+            id: ReportId([15; 16]),
+            time: 123_456,
+            public_extensions: Some(vec![Extension::NotImplemented {
+                typ: 0x10,
+                payload: vec![0x11, 0x12],
+            }]),
+        };
+        let bytes = ext_rm.get_encoded_with_param(&DapVersion::Latest).unwrap();
+        assert!(matches!(
+            ReportMetadata::get_decoded_with_param(&DapVersion::Draft09, bytes.as_slice())
+                .unwrap_err(),
+            CodecError::BytesLeftOver(..)
+        ));
+    }
+
     fn partial_batch_selector_encode_decode(version: DapVersion) {
         const TEST_DATA_DRAFT09: &[u8] = &[1];
         const TEST_DATA_LATEST: &[u8] = &[1, 0, 0];
diff --git a/crates/daphne/src/protocol/aggregator.rs b/crates/daphne/src/protocol/aggregator.rs
index 479d42312..da19b6328 100644
--- a/crates/daphne/src/protocol/aggregator.rs
+++ b/crates/daphne/src/protocol/aggregator.rs
@@ -2,7 +2,7 @@
 // SPDX-License-Identifier: BSD-3-Clause
 
 use super::{
-    no_duplicates,
+    check_no_duplicates,
     report_init::{InitializedReport, WithPeerPrepShare},
 };
 use crate::{
@@ -241,7 +241,7 @@ impl DapTaskConfig {
             DapAggregationParam::get_decoded_with_param(&self.vdaf, &agg_job_init_req.agg_param)
                 .map_err(|e| DapAbort::from_codec_error(e, *task_id))?;
         if replay_protection.enabled() {
-            no_duplicates(
+            check_no_duplicates(
                 agg_job_init_req
                     .prep_inits
                     .iter()
diff --git a/crates/daphne/src/protocol/client.rs b/crates/daphne/src/protocol/client.rs
index cbfdc21f3..dd3216363 100644
--- a/crates/daphne/src/protocol/client.rs
+++ b/crates/daphne/src/protocol/client.rs
@@ -30,13 +30,15 @@ impl VdafConfig {
     /// * `extensions` are the extensions.
     ///
     /// * `version` is the `DapVersion` to use.
+    #[allow(clippy::too_many_arguments)]
     pub fn produce_report_with_extensions(
         &self,
         hpke_config_list: &[HpkeConfig; 2],
         time: Time,
         task_id: &TaskId,
         measurement: DapMeasurement,
-        extensions: Vec<Extension>,
+        public_extensions: Option<Vec<Extension>>,
+        private_extensions: Vec<Extension>,
         version: DapVersion,
     ) -> Result<Report, DapError> {
         let mut rng = thread_rng();
@@ -51,7 +53,8 @@ impl VdafConfig {
             time,
             task_id,
             &report_id,
-            extensions,
+            public_extensions,
+            private_extensions,
             version,
         )
     }
@@ -65,21 +68,28 @@ impl VdafConfig {
         time: Time,
         task_id: &TaskId,
         report_id: &ReportId,
-        extensions: Vec<Extension>,
+        public_extensions: Option<Vec<Extension>>,
+        private_extensions: Vec<Extension>,
         version: DapVersion,
     ) -> Result<Report, DapError> {
+        match (&public_extensions, version) {
+            (Some(_), DapVersion::Draft09) | (None, DapVersion::Latest) => {
+                return Err(DapError::ReportError(
+                    crate::messages::ReportError::InvalidMessage,
+                ))
+            }
+            _ => (),
+        }
+
         let mut plaintext_input_share = PlaintextInputShare {
-            extensions,
+            private_extensions,
             payload: Vec::default(),
         };
 
         let metadata = ReportMetadata {
             id: *report_id,
             time,
-            public_extensions: match version {
-                DapVersion::Draft09 => None,
-                DapVersion::Latest => Some(Vec::new()),
-            },
+            public_extensions,
         };
 
         let encoded_input_shares = input_shares.into_iter().map(|input_share| {
@@ -147,6 +157,10 @@ impl VdafConfig {
             time,
             task_id,
             measurement,
+            match version {
+                DapVersion::Draft09 => None,
+                DapVersion::Latest => Some(Vec::new()),
+            },
             Vec::new(),
             version,
         )
diff --git a/crates/daphne/src/protocol/mod.rs b/crates/daphne/src/protocol/mod.rs
index dac8837f1..2f39c3efa 100644
--- a/crates/daphne/src/protocol/mod.rs
+++ b/crates/daphne/src/protocol/mod.rs
@@ -11,7 +11,7 @@ pub(crate) mod report_init;
 
 /// checks if an iterator has no duplicate items, returns the ok if there are no dups or an error
 /// with the first offending item.
-pub(crate) fn no_duplicates<I>(iterator: I) -> Result<(), I::Item>
+pub(crate) fn check_no_duplicates<I>(iterator: I) -> Result<(), I::Item>
 where
     I: Iterator,
     I::Item: Eq + std::hash::Hash,
@@ -752,6 +752,10 @@ mod test {
             t.now,
             &t.task_id,
             DapMeasurement::U32Vec(vec![1; 10]),
+            match version {
+                DapVersion::Draft09 => None,
+                DapVersion::Latest => Some(vec![]),
+            },
             vec![Extension::NotImplemented {
                 typ: 0xffff,
                 payload: b"some extension data".to_vec(),
@@ -784,29 +788,32 @@ mod test {
 
     test_versions! { handle_unrecognized_report_extensions }
 
-    fn handle_unknown_public_extensions_in_report(version: DapVersion) {
+    #[test]
+    fn handle_unknown_public_extensions_in_report() {
+        let version = DapVersion::Latest;
         let t = AggregationJobTest::new(TEST_VDAF, HpkeKemId::X25519HkdfSha256, version);
-        let mut report = t
+        let report = t
            .task_config
            .vdaf
-            .produce_report(
+            .produce_report_with_extensions(
                 &t.client_hpke_config_list,
                 t.now,
                 &t.task_id,
                 DapMeasurement::U32Vec(vec![1; 10]),
+                Some(vec![
+                    Extension::NotImplemented {
+                        typ: 0x01,
+                        payload: b"This is ignored".to_vec(),
+                    },
+                    Extension::NotImplemented {
+                        typ: 0x02,
+                        payload: b"This is ignored too".to_vec(),
+                    },
+                ]),
+                vec![],
                 version,
             )
             .unwrap();
-        report.report_metadata.public_extensions = Some(vec![
-            Extension::NotImplemented {
-                typ: 0x01,
-                payload: b"This is ignored".to_vec(),
-            },
-            Extension::NotImplemented {
-                typ: 0x02,
-                payload: b"This is ignored too".to_vec(),
-            },
-        ]);
         let report_metadata = report.report_metadata.clone();
         let [leader_share, _] = report.encrypted_input_shares;
         let initialized_report = InitializedReport::from_client(
@@ -832,7 +839,6 @@ mod test {
             }
         );
     }
-    test_versions! {handle_unknown_public_extensions_in_report}
 
     fn handle_repeated_report_extensions(version: DapVersion) {
         let t = AggregationJobTest::new(TEST_VDAF, HpkeKemId::X25519HkdfSha256, version);
@@ -844,6 +850,10 @@ mod test {
             t.now,
             &t.task_id,
             DapMeasurement::U32Vec(vec![1; 10]),
+            match version {
+                DapVersion::Draft09 => None,
+                DapVersion::Latest => Some(vec![]),
+            },
             vec![
                 Extension::NotImplemented {
                     typ: 23,
@@ -906,6 +916,10 @@ mod test {
                 self.now,
                 &self.task_id,
                 &report_id,
+                match version {
+                    DapVersion::Draft09 => None,
+                    DapVersion::Latest => Some(vec![]),
+                },
                 Vec::new(), // extensions
                 version,
             )
@@ -937,6 +951,10 @@ mod test {
                 self.now,
                 &self.task_id,
                 &report_id,
+                match version {
+                    DapVersion::Draft09 => None,
+                    DapVersion::Latest => Some(vec![]),
+                },
                 Vec::new(), // extensions
                 version,
             )
@@ -969,6 +987,10 @@ mod test {
                 self.now,
                 &self.task_id,
                 &report_id,
+                match version {
+                    DapVersion::Draft09 => None,
+                    DapVersion::Latest => Some(vec![]),
+                },
                 Vec::new(), // extensions
                 version,
             )
diff --git a/crates/daphne/src/protocol/report_init.rs b/crates/daphne/src/protocol/report_init.rs
index abc1c1f4b..6d8ed075b 100644
--- a/crates/daphne/src/protocol/report_init.rs
+++ b/crates/daphne/src/protocol/report_init.rs
@@ -7,7 +7,7 @@ use crate::{
     messages::{
         self, Extension, PlaintextInputShare, ReportError, ReportMetadata, ReportShare, TaskId,
     },
-    protocol::{decode_ping_pong_framed, no_duplicates, PingPongMessageType},
+    protocol::{check_no_duplicates, decode_ping_pong_framed, PingPongMessageType},
     vdaf::{VdafConfig, VdafPrepShare, VdafPrepState, VdafVerifyKey},
     DapAggregationParam, DapError, DapTaskConfig, DapVersion,
 };
@@ -158,31 +158,16 @@ impl InitializedReport {
             _ => {}
         }
 
-        // We don't check for duplicates here, because we check for them later
-        // on, and the taskprov extension, the only one we support, has no
-        // side-effects if processed when the report should have been rejected.
-
-        let mut taskprov_indicated = false;
         match (
             &report_share.report_metadata.public_extensions,
             task_config.version,
         ) {
-            (Some(extensions), crate::DapVersion::Latest) => {
-                for extension in extensions {
-                    match extension {
-                        Extension::Taskprov { .. } => {
-                            taskprov_indicated |= task_config.method_is_taskprov;
-                        }
-                        Extension::NotImplemented { .. } => reject!(InvalidMessage),
-                    }
-                }
-            }
-            (None, crate::DapVersion::Draft09) => (),
+            (Some(..), crate::DapVersion::Latest) | (None, crate::DapVersion::Draft09) => (),
            (_, _) => reject!(InvalidMessage),
         }
         // decrypt input share
         let PlaintextInputShare {
-            extensions,
+            private_extensions,
            payload: input_share,
         } = {
            let info = info_and_aad::InputShare {
@@ -217,8 +202,8 @@ impl InitializedReport {
         // Handle report extensions.
         {
             // Check for duplicates in public and private extensions
-            if no_duplicates(
-                extensions
+            if check_no_duplicates(
+                private_extensions
                     .iter()
                     .chain(
                         report_share
@@ -234,10 +219,17 @@ impl InitializedReport {
                 reject!(InvalidMessage)
             }
 
-            for extension in extensions {
+            let mut taskprov_indicated = false;
+            for extension in private_extensions.iter().chain(
+                report_share
+                    .report_metadata
+                    .public_extensions
+                    .as_deref()
+                    .unwrap_or_default(),
+            ) {
                 match extension {
                     Extension::Taskprov { .. } => {
-                        taskprov_indicated |= task_config.method_is_taskprov;
+                        taskprov_indicated = task_config.method_is_taskprov;
                     }
                     // Reject reports with unrecognized extensions.
                     Extension::NotImplemented { .. } => reject!(InvalidMessage),
diff --git a/crates/daphne/src/roles/helper/handle_agg_job.rs b/crates/daphne/src/roles/helper/handle_agg_job.rs
index 58d852b25..63b9efcb9 100644
--- a/crates/daphne/src/roles/helper/handle_agg_job.rs
+++ b/crates/daphne/src/roles/helper/handle_agg_job.rs
@@ -155,7 +155,7 @@ impl HandleAggJob {
     > {
         let task_id = self.state.request.task_id;
         if replay_protection.enabled() {
-            crate::protocol::no_duplicates(
+            crate::protocol::check_no_duplicates(
                 self.state
                     .request
                     .payload
diff --git a/crates/daphne/src/roles/leader/mod.rs b/crates/daphne/src/roles/leader/mod.rs
index ad0bcd042..9a61413c4 100644
--- a/crates/daphne/src/roles/leader/mod.rs
+++ b/crates/daphne/src/roles/leader/mod.rs
@@ -23,6 +23,7 @@ use crate::{
         CollectionReq, Extension, Interval, PartialBatchSelector, Query, Report, TaskId,
     },
     metrics::{DaphneRequestType, ReportStatus},
+    protocol,
     roles::resolve_task_config,
     DapAggregationParam, DapCollectionJob, DapError, DapLeaderProcessTelemetry, DapRequest,
     DapRequestMeta, DapResponse, DapTaskConfig, DapVersion,
@@ -214,6 +215,8 @@ pub async fn handle_upload_req(
         }
         .into());
     }
+
+    // Check that the report was generated after the task's `not_before` time.
     if report.report_metadata.time
         < task_config.as_ref().not_before - task_config.as_ref().time_precision
     {
@@ -223,45 +226,32 @@ pub async fn handle_upload_req(
         .into());
     }
 
-    match (
-        &report.report_metadata.public_extensions,
-        task_config.version,
-    ) {
-        (Some(extensions), DapVersion::Latest) => {
-            let mut unknown_extensions = Vec::<u16>::new();
-            if crate::protocol::no_duplicates(extensions.iter()).is_err() {
-                return Err(DapError::Abort(DapAbort::InvalidMessage {
-                    detail: "Repeated public extension".into(),
-                    task_id,
-                }));
-            };
-            for extension in extensions {
-                match extension {
-                    Extension::Taskprov => (),
-                    Extension::NotImplemented { typ, .. } => unknown_extensions.push(*typ),
-                }
-            }
-
-            if !unknown_extensions.is_empty() {
-                return match DapAbort::unsupported_extension(&task_id, &unknown_extensions) {
-                    Ok(abort) => Err::<(), DapError>(abort.into()),
-                    Err(fatal) => Err(fatal),
-                };
+    if let Some(public_extensions) = &report.report_metadata.public_extensions {
+        // We can be sure at this point that the ReportMetadata is well formed (as
+        // the decoding / error checking happens in the extractor).
+        assert_eq!(DapVersion::Latest, task_config.version);
+        let mut unknown_extensions = Vec::<u16>::new();
+        if protocol::check_no_duplicates(public_extensions.iter()).is_err() {
+            return Err(DapError::Abort(DapAbort::InvalidMessage {
+                detail: "Repeated public extension".into(),
+                task_id,
+            }));
+        };
+        for extension in public_extensions {
+            match extension {
+                Extension::Taskprov => (),
+                Extension::NotImplemented { typ, .. } => unknown_extensions.push(*typ),
             }
         }
-        (None, DapVersion::Draft09) => (),
-        (Some(_), DapVersion::Draft09) => {
-            return Err(DapError::Abort(DapAbort::version_mismatch(
-                DapVersion::Draft09,
-                DapVersion::Latest,
-            )))
-        }
-        (None, DapVersion::Latest) => {
-            return Err(DapError::Abort(DapAbort::version_mismatch(
-                DapVersion::Latest,
-                DapVersion::Draft09,
-            )))
+
+        if !unknown_extensions.is_empty() {
+            return match DapAbort::unsupported_extension(&task_id, &unknown_extensions) {
+                Ok(abort) => Err::<(), DapError>(abort.into()),
+                Err(fatal) => Err(fatal),
+            };
         }
+    } else {
+        assert_eq!(DapVersion::Draft09, task_config.version);
     }
 
     // Store the report for future processing. At this point, the report may be rejected if
diff --git a/crates/daphne/src/roles/mod.rs b/crates/daphne/src/roles/mod.rs
index 8a982b59e..0e056e6b3 100644
--- a/crates/daphne/src/roles/mod.rs
+++ b/crates/daphne/src/roles/mod.rs
@@ -765,7 +765,9 @@ mod test {
 
     async_test_versions! { handle_agg_job_req_failure_hpke_decrypt_error }
 
-    async fn handle_unknown_public_extensions(version: DapVersion) {
+    #[tokio::test]
+    async fn handle_unknown_public_extensions() {
+        let version = DapVersion::Latest;
         let t = Test::new(version);
         let task_id = &t.time_interval_task_id;
         let task_config = t.leader.unchecked_get_task_config(task_id).await;
@@ -785,25 +787,37 @@ mod test {
             resource_id: (),
             payload: report,
         };
-        match version {
-            DapVersion::Draft09 => assert_eq!(
-                leader::handle_upload_req(&*t.leader, req).await,
-                Err(DapError::Abort(DapAbort::version_mismatch(
-                    DapVersion::Draft09,
-                    DapVersion::Latest
-                )))
-            ),
-            DapVersion::Latest => assert_eq!(
-                leader::handle_upload_req(&*t.leader, req).await,
-                Err(DapError::Abort(DapAbort::UnsupportedExtension {
-                    detail: "[1]".into(),
-                    task_id: *task_id
-                }))
-            ),
-        }
+        assert_eq!(
+            leader::handle_upload_req(&*t.leader, req).await,
+            Err(DapError::Abort(DapAbort::UnsupportedExtension {
+                detail: "[1]".into(),
+                task_id: *task_id
+            }))
+        );
     }
 
-    async_test_versions! { handle_unknown_public_extensions }
+    #[tokio::test]
+    #[should_panic(expected = "assertion `left == right` failed\n left: Latest\n right: Draft09")]
+    async fn handle_public_extensions_draft09() {
+        let version = DapVersion::Draft09;
+        let t = Test::new(version);
+        let task_id = &t.time_interval_task_id;
+        let task_config = t.leader.unchecked_get_task_config(task_id).await;
+        let mut report = t.gen_test_report(task_id).await;
+        report.report_metadata.public_extensions = Some(vec![]);
+
+        let req = DapRequest {
+            meta: DapRequestMeta {
+                version: task_config.version,
+                media_type: Some(DapMediaType::Report),
+                task_id: *task_id,
+                ..Default::default()
+            },
+            resource_id: (),
+            payload: report,
+        };
+        _ = leader::handle_upload_req(&*t.leader, req).await;
+    }
 
     async fn handle_agg_job_req_transition_continue(version: DapVersion) {
         let t = Test::new(version);
@@ -1513,6 +1527,10 @@ mod test {
                 t.now,
                 &task_id,
                 test_measurement.clone(),
+                match version {
+                    DapVersion::Draft09 => None,
+                    DapVersion::Latest => Some(vec![]),
+                },
                 vec![Extension::Taskprov],
                 task_config.version,
             )
@@ -1629,6 +1647,271 @@ mod test {
         .await;
     }
 
+    #[tokio::test]
+    async fn leader_upload_taskprov_public() {
+        let version = DapVersion::Latest;
+        let t = Test::new(DapVersion::Latest);
+
+        let (task_config, task_id, taskprov_advertisement) = DapTaskParameters {
+            version,
+            min_batch_size: 1,
+            query: DapBatchMode::LeaderSelected {
+                max_batch_size: Some(NonZeroU32::new(2).unwrap()),
+            },
+            ..Default::default()
+        }
+        .to_config_with_taskprov(
+            b"cool task".to_vec(),
+            t.now,
+            t.leader.get_taskprov_config().unwrap(),
+        )
+        .unwrap();
+
+        // Clients: Send upload request to Leader.
+        let hpke_config_list = [
+            t.leader
+                .get_hpke_config_for(version, Some(&task_id))
+                .await
+                .unwrap()
+                .clone(),
+            t.helper
+                .get_hpke_config_for(version, Some(&task_id))
+                .await
+                .unwrap()
+                .clone(),
+        ];
+
+        for _ in 0..task_config.min_batch_size {
+            let report = task_config
+                .vdaf
+                .produce_report_with_extensions(
+                    &hpke_config_list,
+                    t.now + 1,
+                    &task_id,
+                    DapMeasurement::U32Vec(vec![1; 10]),
+                    Some(vec![Extension::Taskprov]),
+                    vec![],
+                    version,
+                )
+                .unwrap();
+            let req = DapRequest {
+                meta: DapRequestMeta {
+                    version,
+                    media_type: Some(DapMediaType::Report),
+                    task_id,
+                    taskprov_advertisement: Some(taskprov_advertisement.clone()),
+                },
+                resource_id: (),
+                payload: report,
+            };
+            leader::handle_upload_req(&*t.leader, req).await.unwrap();
+        }
+        // Collector: Request result from the Leader.
+        let query = Query::LeaderSelectedCurrentBatch;
+        leader::handle_coll_job_req(&*t.leader, &t.gen_test_coll_job_req(query, &task_id).await)
+            .await
+            .unwrap();
+
+        leader::process(&*t.leader, "leader.com", 100)
+            .await
+            .unwrap();
+
+        assert_metrics_include!(t.helper_registry, {
+            r#"inbound_request_counter{env="test_helper",host="helper.org",type="aggregate"}"#: 1,
+            r#"inbound_request_counter{env="test_helper",host="helper.org",type="collect"}"#: 1,
+            r#"report_counter{env="test_helper",host="helper.org",status="aggregated"}"#: 1,
+            r#"report_counter{env="test_helper",host="helper.org",status="collected"}"#: 1,
+            r#"aggregation_job_counter{env="test_helper",host="helper.org",status="started"}"#: 1,
+            r#"aggregation_job_counter{env="test_helper",host="helper.org",status="completed"}"#: 1,
+        });
+        assert_metrics_include!(t.leader_registry, {
+            r#"report_counter{env="test_leader",host="leader.com",status="aggregated"}"#: 1,
+            r#"report_counter{env="test_leader",host="leader.com",status="collected"}"#: 1,
+        });
+    }
+
+    #[tokio::test]
+    async fn leader_upload_taskprov_public_extension_errors() {
+        let version = DapVersion::Latest;
+        let t = Test::new(DapVersion::Latest);
+
+        let (task_config, task_id, taskprov_advertisement) = DapTaskParameters {
+            version,
+            min_batch_size: 1,
+            query: DapBatchMode::LeaderSelected {
+                max_batch_size: Some(NonZeroU32::new(2).unwrap()),
+            },
+            ..Default::default()
+        }
+        .to_config_with_taskprov(
+            b"cool task".to_vec(),
+            t.now,
+            t.leader.get_taskprov_config().unwrap(),
+        )
+        .unwrap();
+
+        // Clients: Send upload request to Leader.
+        let hpke_config_list = [
+            t.leader
+                .get_hpke_config_for(version, Some(&task_id))
+                .await
+                .unwrap()
+                .clone(),
+            t.helper
+                .get_hpke_config_for(version, Some(&task_id))
+                .await
+                .unwrap()
+                .clone(),
+        ];
+
+        let report = task_config
+            .vdaf
+            .produce_report_with_extensions(
+                &hpke_config_list,
+                t.now + 1,
+                &task_id,
+                DapMeasurement::U32Vec(vec![1; 10]),
+                Some(vec![Extension::Taskprov, Extension::Taskprov]),
+                vec![],
+                version,
+            )
+            .unwrap();
+        let req = DapRequest {
+            meta: DapRequestMeta {
+                version,
+                media_type: Some(DapMediaType::Report),
+                task_id,
+                taskprov_advertisement: Some(taskprov_advertisement.clone()),
+            },
+            resource_id: (),
+            payload: report,
+        };
+        assert_eq!(
+            DapError::Abort(DapAbort::InvalidMessage {
+                detail: "Repeated public extension".into(),
+                task_id,
+            }),
+            leader::handle_upload_req(&*t.leader, req)
+                .await
+                .unwrap_err()
+        );
+
+        let report = task_config
+            .vdaf
+            .produce_report_with_extensions(
+                &hpke_config_list,
+                t.now + 1,
+                &task_id,
+                DapMeasurement::U32Vec(vec![1; 10]),
+                Some(vec![
+                    Extension::Taskprov,
+                    Extension::NotImplemented {
+                        typ: 14,
+                        payload: b"Ignore".into(),
+                    },
+                    Extension::NotImplemented {
+                        typ: 15,
+                        payload: b"Ignore".into(),
+                    },
+                ]),
+                vec![],
+                version,
+            )
+            .unwrap();
+        let req = DapRequest {
+            meta: DapRequestMeta {
+                version,
+                media_type: Some(DapMediaType::Report),
+                task_id,
+                taskprov_advertisement: Some(taskprov_advertisement.clone()),
+            },
+            resource_id: (),
+            payload: report,
+        };
+
+        assert_eq!(
+            DapError::Abort(DapAbort::unsupported_extension(&task_id, &[14, 15]).unwrap()),
+            leader::handle_upload_req(&*t.leader, req)
+                .await
+                .unwrap_err()
+        );
+    }
+
+    #[tokio::test]
+    async fn leader_upload_taskprov_in_public_and_private_extensions() {
+        let version = DapVersion::Latest;
+        let t = Test::new(DapVersion::Latest);
+
+        let (task_config, task_id, taskprov_advertisement) = DapTaskParameters {
+            version,
+            min_batch_size: 1,
+            query: DapBatchMode::LeaderSelected {
+                max_batch_size: Some(NonZeroU32::new(2).unwrap()),
+            },
+            ..Default::default()
+        }
+        .to_config_with_taskprov(
+            b"cool task".to_vec(),
+            t.now,
+            t.leader.get_taskprov_config().unwrap(),
+        )
+        .unwrap();
+
+        // Clients: Send upload request to Leader.
+        let hpke_config_list = [
+            t.leader
+                .get_hpke_config_for(version, Some(&task_id))
+                .await
+                .unwrap()
+                .clone(),
+            t.helper
+                .get_hpke_config_for(version, Some(&task_id))
+                .await
+                .unwrap()
+                .clone(),
+        ];
+
+        for _ in 0..task_config.min_batch_size {
+            let report = task_config
+                .vdaf
+                .produce_report_with_extensions(
+                    &hpke_config_list,
+                    t.now + 1,
+                    &task_id,
+                    DapMeasurement::U32Vec(vec![1; 10]),
+                    Some(vec![Extension::Taskprov]),
+                    vec![Extension::Taskprov],
+                    version,
+                )
+                .unwrap();
+            let req = DapRequest {
+                meta: DapRequestMeta {
+                    version,
+                    media_type: Some(DapMediaType::Report),
+                    task_id,
+                    taskprov_advertisement: Some(taskprov_advertisement.clone()),
+                },
+                resource_id: (),
+                payload: report,
+            };
+            leader::handle_upload_req(&*t.leader, req).await.unwrap();
+        }
+        // Collector: Request result from the Leader.
+        let query = Query::LeaderSelectedCurrentBatch;
+        leader::handle_coll_job_req(&*t.leader, &t.gen_test_coll_job_req(query, &task_id).await)
+            .await
+            .unwrap();
+
+        leader::process(&*t.leader, "leader.com", 100)
+            .await
+            .unwrap();
+
+        assert_metrics_include!(t.leader_registry, {
+            r#"report_counter{env="test_leader",host="leader.com",status="rejected_invalid_message"}"#: 1,
+            r#"inbound_request_counter{env="test_leader",host="leader.com",type="upload"}"#: 1,
+        });
+    }
+
     // Test multiple tasks in flight at once.
     async fn multi_task(version: DapVersion) {
         let t = Test::new(version);
diff --git a/crates/daphne/src/testing/report_generator.rs b/crates/daphne/src/testing/report_generator.rs
index 07332c41d..ce171e8bf 100644
--- a/crates/daphne/src/testing/report_generator.rs
+++ b/crates/daphne/src/testing/report_generator.rs
@@ -45,7 +45,8 @@ impl ReportGenerator {
         measurement: &DapMeasurement,
         version: DapVersion,
         now: Time,
-        extensions: Vec<Extension>,
+        public_extensions: Option<Vec<Extension>>,
+        private_extensions: Vec<Extension>,
         replay_reports: bool,
     ) -> Self {
         let (tx, rx) = mpsc::channel();
@@ -78,7 +79,8 @@ impl ReportGenerator {
                         report_time_dist.sample(&mut thread_rng()),
                         &task_id,
                         measurement.clone(),
-                        extensions.clone(),
+                        public_extensions.clone(),
+                        private_extensions.clone(),
                         version,
                     )
                     .expect("we have to panic here since we can't return the error")
@@ -90,7 +92,8 @@ impl ReportGenerator {
                         report_time_dist.sample(&mut thread_rng()),
                         &task_id,
                         measurement.clone(),
-                        extensions.clone(),
+                        public_extensions.clone(),
+                        private_extensions.clone(),
                         version,
                     )?
                 };