diff --git a/daphne/src/messages/mod.rs b/daphne/src/messages/mod.rs index f36924e5d..db977ef4c 100644 --- a/daphne/src/messages/mod.rs +++ b/daphne/src/messages/mod.rs @@ -1634,82 +1634,6 @@ mod test { ); } - // NOTE: these test vectors are no longer valid, TaskProv doesn't match the VDAF-06 spec. - // Tracking the issue here: https://github.com/wangshan/draft-wang-ppm-dap-taskprov/issues/33. - // #[test] - // fn read_vdaf_config() { - // let data = [ - // 0x01, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x18, 0x01, 0x02, 0x03, 0x04, 0x04, 0x03, 0x02, - // 0x01, 0x02, 0x02, 0x03, 0x04, 0x04, 0x03, 0x02, 0x02, 0x03, 0x02, 0x03, 0x04, 0x04, 0x03, - // 0x02, 0x03, - // ]; - - // // let buckets = vec![0x0102030404030201, 0x0202030404030202, 0x0302030404030203]; - // let len_length: u8 = 1; - // let vdaf_config = VdafConfig::get_decoded(&data).unwrap(); - // assert_eq!( - // vdaf_config, - // VdafConfig { - // dp_config: DpConfig::None, - // // var: VdafTypeVar::Prio3Aes128Histogram { buckets }, - // var: VdafTypeVar::Prio3Aes128Histogram { len_length }, - // } - // ); - // } - - // #[test] - // fn read_task_config_taskprov_draft02() { - // let data = [ - // 0x02, 0x48, 0x69, 0x00, 0x0e, 0x00, 0x0c, 0x68, 0x74, 0x74, 0x70, 0x73, 0x3a, 0x2f, 0x2f, - // 0x74, 0x65, 0x73, 0x74, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x80, - // 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x63, 0x52, 0xf9, - // 0xa5, 0x01, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x18, 0x01, 0x02, 0x03, 0x04, 0x04, 0x03, - // 0x02, 0x01, 0x02, 0x02, 0x03, 0x04, 0x04, 0x03, 0x02, 0x02, 0x03, 0x02, 0x03, 0x04, 0x04, - // 0x03, 0x02, 0x03, - // ]; - - // // let buckets = vec![0x0102030404030201, 0x0202030404030202, 0x0302030404030203]; - // let len_length: u8 = 1; - // let task_config = TaskConfig::get_decoded_with_param(&TaskprovVersion::Draft02, &data).unwrap(); - // assert_eq!( - // task_config, - // TaskConfig { - // task_info: "Hi".as_bytes().to_vec(), - // aggregator_endpoints: vec![UrlBytes { - // bytes: "https://test".as_bytes().to_vec() - // }], - // query_config: QueryConfig { - // time_precision: 0x01, - // max_batch_query_count: 128, - // min_batch_size: 1024, - // var: QueryConfigVar::FixedSize { - // max_batch_size: 2048 - // }, - // }, - // task_expiration: 0x6352f9a5, - // vdaf_config: VdafConfig { - // dp_config: DpConfig::None, - // var: VdafTypeVar::Prio3Aes128Histogram { len_length }, - // }, - // } - // ); - - // assert_eq!( - // compute_task_id( - // TaskprovVersion::Draft02, - // &task_config.get_encoded_with_param(&TaskprovVersion::Draft02) - // ) - // .unwrap() - // .to_hex(), - // "2b585fcbb48c21fb5f05221a241fdd8cb9ebe99bd183d66326fcecd85fe06fd5", - // ); - - // assert_eq!( - // task_config.get_encoded_with_param(&TaskprovVersion::Draft02), - // &data - // ); - // } - #[test] fn test_base64url() { let mut rng = thread_rng(); @@ -1747,4 +1671,81 @@ mod test { let id = TaskId([7; 32]); assert_eq!(TaskId::try_from_base64url(id.to_base64url()).unwrap(), id); } + + fn roundtrip_taskprov_query_config(version: DapVersion) { + let query_config = taskprov::QueryConfig { + time_precision: 12_345_678, + max_batch_query_count: 1337, + min_batch_size: 12_345_678, + var: taskprov::QueryConfigVar::TimeInterval, + }; + assert_eq!( + taskprov::QueryConfig::get_decoded_with_param( + &version, + &query_config.get_encoded_with_param(&version) + ) + .unwrap(), + query_config + ); + + let query_config = taskprov::QueryConfig { + time_precision: 12_345_678, + max_batch_query_count: 1337, + min_batch_size: 12_345_678, + var: taskprov::QueryConfigVar::FixedSize { + max_batch_size: 12_345_678, + }, + }; + assert_eq!( + taskprov::QueryConfig::get_decoded_with_param( + &version, + &query_config.get_encoded_with_param(&version) + ) + .unwrap(), + query_config + ); + } + + test_versions! { roundtrip_taskprov_query_config } + + #[test] + fn roundtrip_taskprov_query_config_not_implemented_draft07() { + let query_config = taskprov::QueryConfig { + time_precision: 12_345_678, + max_batch_query_count: 1337, + min_batch_size: 12_345_678, + var: taskprov::QueryConfigVar::NotImplemented { + typ: 0, + param: b"query config param".to_vec(), + }, + }; + assert_eq!( + taskprov::QueryConfig::get_decoded_with_param( + &DapVersion::Draft07, + &query_config.get_encoded_with_param(&DapVersion::Draft07) + ) + .unwrap(), + query_config + ); + } + + #[test] + fn roundtrip_taskprov_query_config_not_implemented_draft02() { + let query_config = taskprov::QueryConfig { + time_precision: 12_345_678, + max_batch_query_count: 1337, + min_batch_size: 12_345_678, + var: taskprov::QueryConfigVar::NotImplemented { + typ: 0, + param: b"query config param".to_vec(), + }, + }; + + // Expect error because unimplemented query types aren't decodable. + assert!(taskprov::QueryConfig::get_decoded_with_param( + &DapVersion::Draft02, + &query_config.get_encoded_with_param(&DapVersion::Draft02) + ) + .is_err()); + } } diff --git a/daphne/src/messages/taskprov.rs b/daphne/src/messages/taskprov.rs index c033bbab8..d4f24b7b8 100644 --- a/daphne/src/messages/taskprov.rs +++ b/daphne/src/messages/taskprov.rs @@ -16,7 +16,7 @@ use prio::codec::{ }; use ring::hkdf::KeyType; use serde::{Deserialize, Serialize}; -use std::io::Cursor; +use std::io::{Cursor, Read}; // VDAF type codes. const VDAF_TYPE_PRIO2: u32 = 0xFFFF_0000; @@ -163,13 +163,9 @@ impl Decode for UrlBytes { pub enum QueryConfigVar { TimeInterval, FixedSize { max_batch_size: u32 }, + NotImplemented { typ: u8, param: Vec }, } -// There is no Encode or Decode for QueryConfigVar as we have to split the query type and -// the associated configuration data in the message format, so we must do all of the work -// in QueryConfig's Encode and Decode. If the spec is revised to allow these fields -// to be encoded and decoded contiguously, then we will revise this code. - /// A query configuration. #[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)] pub struct QueryConfig { @@ -188,47 +184,101 @@ impl QueryConfig { QueryConfigVar::FixedSize { .. } => { QUERY_TYPE_FIXED_SIZE.encode(bytes); } + QueryConfigVar::NotImplemented { typ, .. } => { + typ.encode(bytes); + } } } } impl ParameterizedEncode for QueryConfig { - fn encode_with_param(&self, _version: &DapVersion, bytes: &mut Vec) { - self.encode_query_type(bytes); + fn encode_with_param(&self, version: &DapVersion, bytes: &mut Vec) { + if *version == DapVersion::Draft02 { + self.encode_query_type(bytes); + } self.time_precision.encode(bytes); self.max_batch_query_count.encode(bytes); self.min_batch_size.encode(bytes); match &self.var { - QueryConfigVar::TimeInterval => (), + QueryConfigVar::TimeInterval => { + if *version == DapVersion::Draft07 { + QUERY_TYPE_TIME_INTERVAL.encode(bytes); + 0_u16.encode(bytes); + } + } QueryConfigVar::FixedSize { max_batch_size } => { + if *version == DapVersion::Draft07 { + QUERY_TYPE_FIXED_SIZE.encode(bytes); + 4_u16.encode(bytes); + } max_batch_size.encode(bytes); } + QueryConfigVar::NotImplemented { typ, param } => { + if *version == DapVersion::Draft07 { + typ.encode(bytes); + u16::try_from(param.len()).unwrap().encode(bytes); + } + bytes.extend_from_slice(param); + } } } } impl ParameterizedDecode for QueryConfig { fn decode_with_param( - _version: &DapVersion, + version: &DapVersion, bytes: &mut Cursor<&[u8]>, ) -> Result { - let query_type = u8::decode(bytes)?; - let time_precision = Duration::decode(bytes)?; - let max_batch_query_count = u16::decode(bytes)?; - let min_batch_size = u32::decode(bytes)?; - let var = match query_type { - QUERY_TYPE_TIME_INTERVAL => Ok(QueryConfigVar::TimeInterval), - QUERY_TYPE_FIXED_SIZE => Ok(QueryConfigVar::FixedSize { - max_batch_size: u32::decode(bytes)?, - }), - _ => Err(CodecError::UnexpectedValue), - }?; - Ok(Self { - time_precision, - max_batch_query_count, - min_batch_size, - var, - }) + match version { + DapVersion::Draft07 => { + let time_precision = Duration::decode(bytes)?; + let max_batch_query_count = u16::decode(bytes)?; + let min_batch_size = u32::decode(bytes)?; + let query_type = u8::decode(bytes)?; + let query_type_param_len = u16::decode(bytes)?.try_into().unwrap(); + let var = match query_type { + QUERY_TYPE_TIME_INTERVAL => QueryConfigVar::TimeInterval, + QUERY_TYPE_FIXED_SIZE => QueryConfigVar::FixedSize { + max_batch_size: u32::decode(bytes)?, + }, + _ => { + let mut query_type_param = vec![0; query_type_param_len]; + bytes.read_exact(&mut query_type_param)?; + QueryConfigVar::NotImplemented { + typ: query_type, + param: query_type_param, + } + } + }; + Ok(Self { + time_precision, + max_batch_query_count, + min_batch_size, + var, + }) + } + DapVersion::Draft02 => { + let query_type = u8::decode(bytes)?; + let time_precision = Duration::decode(bytes)?; + let max_batch_query_count = u16::decode(bytes)?; + let min_batch_size = u32::decode(bytes)?; + let var = match query_type { + QUERY_TYPE_TIME_INTERVAL => QueryConfigVar::TimeInterval, + QUERY_TYPE_FIXED_SIZE => QueryConfigVar::FixedSize { + max_batch_size: u32::decode(bytes)?, + }, + // draft02 compatibility: Unrecognized query types are not decodable, so we're + // forced to abort at this point. + _ => return Err(CodecError::UnexpectedValue), + }; + Ok(Self { + time_precision, + max_batch_query_count, + min_batch_size, + var, + }) + } + } } } diff --git a/daphne/src/taskprov.rs b/daphne/src/taskprov.rs index 700320f67..3d40b4636 100644 --- a/daphne/src/taskprov.rs +++ b/daphne/src/taskprov.rs @@ -214,13 +214,17 @@ fn url_from_bytes(task_id: &TaskId, url_bytes: &[u8]) -> Result { }) } -impl From for DapQueryConfig { - fn from(var: QueryConfigVar) -> Self { +impl DapQueryConfig { + fn try_from_taskprov(task_id: &TaskId, var: QueryConfigVar) -> Result { match var { - QueryConfigVar::FixedSize { max_batch_size } => DapQueryConfig::FixedSize { + QueryConfigVar::FixedSize { max_batch_size } => Ok(DapQueryConfig::FixedSize { max_batch_size: max_batch_size.into(), - }, - QueryConfigVar::TimeInterval => DapQueryConfig::TimeInterval, + }), + QueryConfigVar::TimeInterval => Ok(DapQueryConfig::TimeInterval), + QueryConfigVar::NotImplemented { typ, .. } => Err(DapAbort::InvalidTask { + detail: format!("unimplemented query type ({typ})"), + task_id: *task_id, + }), } } } @@ -277,7 +281,7 @@ impl DapTaskConfig { time_precision: task_config.query_config.time_precision, expiration: task_config.task_expiration, min_batch_size: task_config.query_config.min_batch_size.into(), - query: DapQueryConfig::from(task_config.query_config.var), + query: DapQueryConfig::try_from_taskprov(task_id, task_config.query_config.var)?, vdaf: VdafConfig::from(task_config.vdaf_config.var), vdaf_verify_key: compute_vdaf_verify_key( version,