diff --git a/daphne/src/messages/mod.rs b/daphne/src/messages/mod.rs index ba4a671ce..157350243 100644 --- a/daphne/src/messages/mod.rs +++ b/daphne/src/messages/mod.rs @@ -125,8 +125,14 @@ pub type Time = u64; #[serde(rename_all = "snake_case")] #[cfg_attr(any(test, feature = "test-utils"), derive(deepsize::DeepSizeOf))] pub enum Extension { - Taskprov { payload: Vec }, // Not a TaskConfig to make computing the expected task id more efficient - Unhandled { typ: u16, payload: Vec }, + Taskprov { + // Not a TaskConfig to make computing the expected task id more efficient + draft02_payload: Option>, + }, + NotImplemented { + typ: u16, + payload: Vec, + }, } impl Extension { @@ -134,19 +140,23 @@ impl Extension { pub(crate) fn type_code(&self) -> u16 { match self { Self::Taskprov { .. } => EXTENSION_TASKPROV, - Self::Unhandled { typ, .. } => *typ, + Self::NotImplemented { typ, .. } => *typ, } } } -impl Encode for Extension { - fn encode(&self, bytes: &mut Vec) { +impl ParameterizedEncode for Extension { + fn encode_with_param(&self, version: &DapVersion, bytes: &mut Vec) { match self { - Self::Taskprov { payload } => { + Self::Taskprov { draft02_payload } => { EXTENSION_TASKPROV.encode(bytes); - encode_u16_bytes(bytes, payload); + match (version, draft02_payload) { + (DapVersion::Draft07, None) => encode_u16_item(bytes, *version, &()), + (DapVersion::Draft02, Some(payload)) => encode_u16_bytes(bytes, payload), + _ => unreachable!("unhandled version {version:?}"), + } } - Self::Unhandled { typ, payload } => { + Self::NotImplemented { typ, payload } => { typ.encode(bytes); encode_u16_bytes(bytes, payload); } @@ -154,13 +164,26 @@ impl Encode for Extension { } } -impl Decode for Extension { - fn decode(bytes: &mut Cursor<&[u8]>) -> Result { +impl ParameterizedDecode for Extension { + fn decode_with_param( + version: &DapVersion, + bytes: &mut Cursor<&[u8]>, + ) -> Result { let typ = u16::decode(bytes)?; - let payload = decode_u16_bytes(bytes)?; - match typ { - EXTENSION_TASKPROV => Ok(Self::Taskprov { payload }), - _ => Ok(Self::Unhandled { typ, payload }), + match (version, typ) { + (DapVersion::Draft07, EXTENSION_TASKPROV) => { + decode_u16_item::<()>(*version, bytes)?; + Ok(Self::Taskprov { + draft02_payload: None, + }) + } + (DapVersion::Draft02, EXTENSION_TASKPROV) => Ok(Self::Taskprov { + draft02_payload: Some(decode_u16_bytes(bytes)?), + }), + _ => Ok(Self::NotImplemented { + typ, + payload: decode_u16_bytes(bytes)?, + }), } } } @@ -182,7 +205,7 @@ impl ParameterizedEncode for ReportMetadata { self.time.encode(bytes); match (version, &self.draft02_extensions) { (DapVersion::Draft07, None) => (), - (DapVersion::Draft02, Some(extensions)) => encode_u16_items(bytes, &(), extensions), + (DapVersion::Draft02, Some(extensions)) => encode_u16_items(bytes, version, extensions), _ => unreachable!("extensions should be set in (and only in) draft02"), } } @@ -197,7 +220,7 @@ impl ParameterizedDecode for ReportMetadata { id: ReportId::decode(bytes)?, time: Time::decode(bytes)?, draft02_extensions: match version { - DapVersion::Draft02 => Some(decode_u16_items(&(), bytes)?), + DapVersion::Draft02 => Some(decode_u16_items(version, bytes)?), DapVersion::Draft07 => None, }, }; @@ -1130,17 +1153,20 @@ pub struct PlaintextInputShare { pub payload: Vec, } -impl Encode for PlaintextInputShare { - fn encode(&self, bytes: &mut Vec) { - encode_u16_items(bytes, &(), &self.extensions); +impl ParameterizedEncode for PlaintextInputShare { + fn encode_with_param(&self, version: &DapVersion, bytes: &mut Vec) { + encode_u16_items(bytes, version, &self.extensions); encode_u32_bytes(bytes, &self.payload); } } -impl Decode for PlaintextInputShare { - fn decode(bytes: &mut Cursor<&[u8]>) -> Result { +impl ParameterizedDecode for PlaintextInputShare { + fn decode_with_param( + version: &DapVersion, + bytes: &mut Cursor<&[u8]>, + ) -> Result { Ok(Self { - extensions: decode_u16_items(&(), bytes)?, + extensions: decode_u16_items(version, bytes)?, payload: decode_u32_bytes(bytes)?, }) } @@ -1214,58 +1240,64 @@ pub fn decode_base64url_vec>(input: T) -> Option> { URL_SAFE_NO_PAD.decode(input).ok() } -fn encode_u16_item_for_version>( +// Cribbed from `decode_u16_items()` from libprio. +fn encode_u16_item>( bytes: &mut Vec, version: DapVersion, item: &E, ) { - match version { - DapVersion::Draft07 => { - // Cribbed from `decode_u16_items()` from libprio. - // - // Reserve space for the length prefix. - let len_offset = bytes.len(); - 0_u16.encode(bytes); - - item.encode_with_param(&version, bytes); - let len = bytes.len() - len_offset - 2; - for (offset, byte) in u16::to_be_bytes(len.try_into().unwrap()).iter().enumerate() { - bytes[len_offset + offset] = *byte; - } - } + // Reserve space for the length prefix. + let len_offset = bytes.len(); + 0_u16.encode(bytes); - DapVersion::Draft02 => item.encode_with_param(&version, bytes), + item.encode_with_param(&version, bytes); + let len = bytes.len() - len_offset - 2; + for (offset, byte) in u16::to_be_bytes(len.try_into().unwrap()).iter().enumerate() { + bytes[len_offset + offset] = *byte; } } -pub fn decode_u16_item_for_version>( - version: &DapVersion, +// Cribbed from `decode_u16_items()` from libprio. +fn decode_u16_item>( + version: DapVersion, bytes: &mut Cursor<&[u8]>, ) -> Result { - match version { - DapVersion::Draft07 => { - // Cribbed from `decode_u16_items()` from libprio. - // - // Read the length prefix. - let len = usize::from(u16::decode(bytes)?); + // Read the length prefix. + let len = usize::from(u16::decode(bytes)?); - let item_start = usize::try_from(bytes.position()).unwrap(); + let item_start = usize::try_from(bytes.position()).unwrap(); - // Make sure encoded length doesn't overflow usize or go past the end of provided byte buffer. - let item_end = item_start - .checked_add(len) - .ok_or_else(|| CodecError::LengthPrefixTooBig(len))?; + // Make sure encoded length doesn't overflow usize or go past the end of provided byte buffer. + let item_end = item_start + .checked_add(len) + .ok_or_else(|| CodecError::LengthPrefixTooBig(len))?; - let decoded = - D::get_decoded_with_param(version, &bytes.get_ref()[item_start..item_end])?; + let decoded = D::get_decoded_with_param(&version, &bytes.get_ref()[item_start..item_end])?; - // Advance outer cursor by the amount read in the inner cursor. - bytes.set_position(item_end.try_into().unwrap()); + // Advance outer cursor by the amount read in the inner cursor. + bytes.set_position(item_end.try_into().unwrap()); - Ok(decoded) - } + Ok(decoded) +} - DapVersion::Draft02 => D::decode_with_param(version, bytes), +fn encode_u16_item_for_version>( + bytes: &mut Vec, + version: DapVersion, + item: &E, +) { + match version { + DapVersion::Draft07 => encode_u16_item(bytes, version, item), + DapVersion::Draft02 => item.encode_with_param(&version, bytes), + } +} + +fn decode_u16_item_for_version>( + version: DapVersion, + bytes: &mut Cursor<&[u8]>, +) -> Result { + match version { + DapVersion::Draft07 => decode_u16_item(version, bytes), + DapVersion::Draft02 => D::decode_with_param(&version, bytes), } } diff --git a/daphne/src/messages/taskprov.rs b/daphne/src/messages/taskprov.rs index 188c01b81..19f264c88 100644 --- a/daphne/src/messages/taskprov.rs +++ b/daphne/src/messages/taskprov.rs @@ -55,7 +55,7 @@ impl ParameterizedDecode for VdafTypeVar { let vdaf_type = u32::decode(bytes)?; match (version, vdaf_type) { (.., VDAF_TYPE_PRIO2) => Ok(Self::Prio2 { - dimension: decode_u16_item_for_version(version, bytes)?, + dimension: decode_u16_item_for_version(*version, bytes)?, }), (DapVersion::Draft07, ..) => Ok(Self::NotImplemented { typ: vdaf_type, @@ -102,7 +102,7 @@ impl ParameterizedDecode for DpConfig { let dp_mechanism = u8::decode(bytes)?; match (version, dp_mechanism) { (.., DP_MECHANISM_NONE) => { - decode_u16_item_for_version::<()>(version, bytes)?; + decode_u16_item_for_version::<()>(*version, bytes)?; Ok(Self::None) } (DapVersion::Draft07, ..) => Ok(Self::NotImplemented { @@ -240,11 +240,11 @@ impl ParameterizedDecode for QueryConfig { let query_type = query_type.unwrap_or(u8::decode(bytes)?); let var = match (version, query_type) { (.., QUERY_TYPE_TIME_INTERVAL) => { - decode_u16_item_for_version::<()>(version, bytes)?; + decode_u16_item_for_version::<()>(*version, bytes)?; QueryConfigVar::TimeInterval } (.., QUERY_TYPE_FIXED_SIZE) => QueryConfigVar::FixedSize { - max_batch_size: decode_u16_item_for_version(version, bytes)?, + max_batch_size: decode_u16_item_for_version(*version, bytes)?, }, (DapVersion::Draft07, ..) => QueryConfigVar::NotImplemented { typ: query_type, diff --git a/daphne/src/roles/helper.rs b/daphne/src/roles/helper.rs index 797e6740a..d0a26c581 100644 --- a/daphne/src/roles/helper.rs +++ b/daphne/src/roles/helper.rs @@ -57,40 +57,45 @@ pub trait DapHelper: DapAggregator { metrics.agg_job_observe_batch_size(agg_job_init_req.prep_inits.len()); - // taskprov: Resolve the task config to use for the request. We also need to ensure - // that all of the reports include the task config in the report extensions. (See - // section 6 of draft-wang-ppm-dap-taskprov-02.) + // taskprov: Resolve the task config to use for the request. if self.get_global_config().allow_taskprov { - let using_taskprov = agg_job_init_req - .prep_inits - .iter() - .filter(|prep_init| { - prep_init - .report_share - .report_metadata - .is_taskprov(req.version, task_id) - }) - .count(); - - let first_metadata = match using_taskprov { - 0 => None, - c if c == agg_job_init_req.prep_inits.len() => { - // All the extensions use taskprov and look ok, so compute first_metadata. - // Note this will always be Some(). - agg_job_init_req - .prep_inits - .first() - .map(|prep_init| &prep_init.report_share.report_metadata) - } - _ => { - // It's not all taskprov or no taskprov, so it's an error. - return Err(DapAbort::InvalidMessage { - detail: "some reports include the taskprov extensions and some do not" - .to_string(), - task_id: Some(*task_id), - }); + // draft02 compatibility: We also need to ensure that all of the reports include the task + // config in the report extensions. (See section 6 of draft-wang-ppm-dap-taskprov-02.) + let first_metadata = if req.version == DapVersion::default() { + let using_taskprov = agg_job_init_req + .prep_inits + .iter() + .filter(|prep_init| { + prep_init + .report_share + .report_metadata + .is_taskprov(req.version, task_id) + }) + .count(); + + match using_taskprov { + 0 => None, + c if c == agg_job_init_req.prep_inits.len() => { + // All the extensions use taskprov and look ok, so compute first_metadata. + // Note this will always be Some(). + agg_job_init_req + .prep_inits + .first() + .map(|prep_init| &prep_init.report_share.report_metadata) + } + _ => { + // It's not all taskprov or no taskprov, so it's an error. + return Err(DapAbort::InvalidMessage { + detail: "some reports include the taskprov extensions and some do not" + .to_string(), + task_id: Some(*task_id), + }); + } } + } else { + None }; + resolve_taskprov(self, task_id, req, first_metadata).await?; } diff --git a/daphne/src/roles/mod.rs b/daphne/src/roles/mod.rs index dcad0cb3d..1f3d2728b 100644 --- a/daphne/src/roles/mod.rs +++ b/daphne/src/roles/mod.rs @@ -1795,7 +1795,10 @@ mod test { &task_id, DapMeasurement::U32Vec(vec![1; 10]), vec![Extension::Taskprov { - payload: taskprov_report_extension_payload.clone(), + draft02_payload: match version { + DapVersion::Draft07 => None, + DapVersion::Draft02 => Some(taskprov_report_extension_payload.clone()), + }, }], version, ) diff --git a/daphne/src/taskprov.rs b/daphne/src/taskprov.rs index e7df87e1e..d00e606ad 100644 --- a/daphne/src/taskprov.rs +++ b/daphne/src/taskprov.rs @@ -164,8 +164,10 @@ fn get_taskprov_task_config( match taskprovs.len() { 0 => return Ok(None), 1 => match &taskprovs[0] { - Extension::Taskprov { payload } => Cow::Borrowed(payload), - Extension::Unhandled { .. } => panic!("cannot happen"), + Extension::Taskprov { + draft02_payload: Some(payload), + } => Cow::Borrowed(payload), + _ => panic!("cannot happen"), }, _ => { // The decoder already returns an error if an extension of a give type occurs more @@ -359,11 +361,13 @@ impl TryFrom<&DapTaskConfig> for messages::taskprov::TaskConfig { impl ReportMetadata { /// Does this metatdata have a taskprov extension and does it match the specified id? - pub fn is_taskprov(&self, version: DapVersion, task_id: &TaskId) -> bool { + pub(crate) fn is_taskprov(&self, version: DapVersion, task_id: &TaskId) -> bool { return self.draft02_extensions.as_ref().is_some_and(|extensions| { extensions.iter().any(|x| match x { - Extension::Taskprov { payload } => *task_id == compute_task_id(version, payload), - Extension::Unhandled { .. } => false, + Extension::Taskprov { + draft02_payload: Some(payload), + } => *task_id == compute_task_id(version, payload), + _ => false, }) }); } @@ -521,7 +525,7 @@ mod test { id: ReportId([0; 16]), time: 0, draft02_extensions: Some(vec![Extension::Taskprov { - payload: taskprov_task_config_data, + draft02_payload: Some(taskprov_task_config_data), }]), }), ) diff --git a/daphne/src/vdaf/mod.rs b/daphne/src/vdaf/mod.rs index 4c5ce290d..a0386e625 100644 --- a/daphne/src/vdaf/mod.rs +++ b/daphne/src/vdaf/mod.rs @@ -220,15 +220,20 @@ impl<'req> EarlyReportStateConsumed<'req> { // draft, the plaintext also encodes the report extensions. let (input_share, draft07_extensions) = match task_config.version { DapVersion::Draft02 => (encoded_input_share, None), - DapVersion::Draft07 => match PlaintextInputShare::get_decoded(&encoded_input_share) { - Ok(input_share) => (input_share.payload, Some(input_share.extensions)), - Err(..) => { - return Ok(Self::Rejected { - metadata, - failure: TransitionFailure::InvalidMessage, - }) + DapVersion::Draft07 => { + match PlaintextInputShare::get_decoded_with_param( + &task_config.version, + &encoded_input_share, + ) { + Ok(input_share) => (input_share.payload, Some(input_share.extensions)), + Err(..) => { + return Ok(Self::Rejected { + metadata, + failure: TransitionFailure::InvalidMessage, + }) + } } - }, + } }; // Handle report extensions. @@ -238,26 +243,43 @@ impl<'req> EarlyReportStateConsumed<'req> { DapVersion::Draft02 => metadata.as_ref().draft02_extensions.as_ref().unwrap(), }; + let mut taskprov_indicated = false; let mut seen: HashSet = HashSet::with_capacity(extensions.len()); for extension in extensions { + // Reject reports with duplicated extensions. if !seen.insert(extension.type_code()) { return Ok(Self::Rejected { metadata, failure: TransitionFailure::InvalidMessage, }); } - // draft02 compatibility: In the latest version, reports with unrecognized - // extensions are rejected; in draft02, the Aggregator is just supposed to ignore - // unrecognized extensions. - if task_config.version != DapVersion::Draft02 - && matches!(extension, Extension::Unhandled { .. }) - { - return Ok(Self::Rejected { - metadata, - failure: TransitionFailure::InvalidMessage, - }); + + match (task_config.version, extension) { + (.., Extension::Taskprov { .. }) if task_config.method_is_taskprov() => { + taskprov_indicated = true; + } + + // Reject reports with unrecognized extensions. + (DapVersion::Draft07, ..) => { + return Ok(Self::Rejected { + metadata, + failure: TransitionFailure::InvalidMessage, + }) + } + + // draft02 compatibility: Ignore unrecognized extensions. + (DapVersion::Draft02, ..) => (), } } + + if task_config.method_is_taskprov() && !taskprov_indicated { + // taskprov: If the task configuration method is taskprov, then we expect each + // report to indicate support. + return Ok(Self::Rejected { + metadata, + failure: TransitionFailure::InvalidMessage, + }); + } } Ok(Self::Ready { @@ -691,7 +713,7 @@ impl VdafConfig { let encoded_input_shares = input_shares.into_iter().map(|input_share| { if let Some(ref mut plaintext_input_share) = draft07_plaintext_input_share { plaintext_input_share.payload = input_share; - plaintext_input_share.get_encoded() + plaintext_input_share.get_encoded_with_param(&version) } else { input_share } @@ -2592,7 +2614,7 @@ mod test { assert!(DapAggregationJobState::get_decoded(TEST_VDAF, b"invalid helper state").is_err()); } - async fn heandle_unrecognized_report_extensions(version: DapVersion) { + async fn handle_unrecognized_report_extensions(version: DapVersion) { let t = AggregationJobTest::new(TEST_VDAF, HpkeKemId::X25519HkdfSha256, version); let report = t .task_config @@ -2602,7 +2624,7 @@ mod test { t.now, &t.task_id, DapMeasurement::U64(1), - vec![Extension::Unhandled { + vec![Extension::NotImplemented { typ: 0xffff, payload: b"some extension data".to_vec(), }], @@ -2634,9 +2656,9 @@ mod test { assert_eq!(consumed_report.is_ready(), expect_ready); } - async_test_versions! { heandle_unrecognized_report_extensions } + async_test_versions! { handle_unrecognized_report_extensions } - async fn heandle_repeated_report_extensions(version: DapVersion) { + async fn handle_repeated_report_extensions(version: DapVersion) { let t = AggregationJobTest::new(TEST_VDAF, HpkeKemId::X25519HkdfSha256, version); let report = t .task_config @@ -2647,10 +2669,12 @@ mod test { &t.task_id, DapMeasurement::U64(1), vec![ - Extension::Taskprov { + Extension::NotImplemented { + typ: 23, payload: b"this payload shouldn't be interpretd yet".to_vec(), }, - Extension::Taskprov { + Extension::NotImplemented { + typ: 23, payload: b"nor should this payload".to_vec(), }, ], @@ -2674,7 +2698,7 @@ mod test { assert!(!consumed_report.is_ready()); } - async_test_versions! { heandle_repeated_report_extensions } + async_test_versions! { handle_repeated_report_extensions } impl AggregationJobTest { // Tweak the Helper's share so that decoding succeeds but preparation fails. diff --git a/daphne_worker_test/tests/e2e/e2e.rs b/daphne_worker_test/tests/e2e/e2e.rs index 42797ae9c..940ad096a 100644 --- a/daphne_worker_test/tests/e2e/e2e.rs +++ b/daphne_worker_test/tests/e2e/e2e.rs @@ -406,7 +406,7 @@ async fn leader_upload_taskprov() { &task_id, DapMeasurement::U32Vec(vec![1; 10]), vec![Extension::Taskprov { - payload: taskprov_task_config.get_encoded_with_param(&version), + draft02_payload: Some(taskprov_task_config.get_encoded_with_param(&version)), }], version, ) @@ -424,7 +424,9 @@ async fn leader_upload_taskprov() { let mut bad_payload = payload.clone(); bad_payload[0] = u8::wrapping_add(bad_payload[0], 1); let task_id = compute_task_id(DapVersion::Draft02, &bad_payload); - let extensions = vec![Extension::Taskprov { payload }]; + let extensions = vec![Extension::Taskprov { + draft02_payload: Some(payload), + }]; let report = task_config .vdaf .produce_report_with_extensions( @@ -1310,7 +1312,10 @@ async fn leader_collect_taskprov_ok(version: DapVersion) { let mut rng = thread_rng(); for _ in 0..t.task_config.min_batch_size { let extensions = vec![Extension::Taskprov { - payload: taskprov_report_extension_payload.clone(), + draft02_payload: match version { + DapVersion::Draft07 => None, + DapVersion::Draft02 => Some(taskprov_report_extension_payload.clone()), + }, }]; let now = rng.gen_range(t.report_interval(&batch_interval)); t.leader_put_expect_ok(