From 0c85eb3ba3671272e010992cbd6f504aa5cadd36 Mon Sep 17 00:00:00 2001 From: Christopher Patton Date: Tue, 28 Nov 2023 14:29:32 -0800 Subject: [PATCH] taskprov: Enforce param length prefixes Structs `VdafConfig`, `DpConfig`, and `QueryConfig` all include a length prefix for the type-specific parameters so that the struct can be decoded even if the type is not recognized. If the parameters don't fill the space indicated by the length prefix, then decoding should fail. This change ensures the length is enforced for recognized types. It unifies the decoding logic across draft02 and draft07 (the former does not have the length prefix). --- daphne/src/messages/mod.rs | 55 ++++++++++ daphne/src/messages/taskprov.rs | 177 ++++++++++++-------------------- 2 files changed, 123 insertions(+), 109 deletions(-) diff --git a/daphne/src/messages/mod.rs b/daphne/src/messages/mod.rs index 564e7f6ee..70d571cf8 100644 --- a/daphne/src/messages/mod.rs +++ b/daphne/src/messages/mod.rs @@ -1214,6 +1214,61 @@ pub fn decode_base64url_vec>(input: T) -> Option> { URL_SAFE_NO_PAD.decode(input).ok() } +fn encode_u16_item_for_version>( + bytes: &mut Vec, + version: DapVersion, + item: &E, +) { + match version { + DapVersion::Draft07 => { + // Cribbed from `decode_u16_items()` from libprio. + // + // Reserve space for the length prefix. + let len_offset = bytes.len(); + 0_u16.encode(bytes); + + item.encode_with_param(&version, bytes); + let len_bytes = std::mem::size_of::(); + let len = bytes.len() - len_offset - len_bytes; + bytes[len_offset..len_offset + len_bytes] + .copy_from_slice(&u16::to_be_bytes(len.try_into().unwrap())); + } + + DapVersion::Draft02 => item.encode_with_param(&version, bytes), + } +} + +pub fn decode_u16_item_for_version>( + version: &DapVersion, + bytes: &mut Cursor<&[u8]>, +) -> Result { + match version { + DapVersion::Draft07 => { + // Cribbed from `decode_u16_items()` from libprio. + // + // Read the length prefix. + let len = usize::from(u16::decode(bytes)?); + + let item_start = usize::try_from(bytes.position()).unwrap(); + + // Make sure encoded length doesn't overflow usize or go past the end of provided byte buffer. + let item_end = item_start + .checked_add(len) + .ok_or_else(|| CodecError::LengthPrefixTooBig(len))?; + + let decoded = + D::get_decoded_with_param(version, &bytes.get_ref()[item_start..item_end])?; + + // Advance outer cursor by the amount read in the inner cursor. + bytes.set_position(item_end.try_into().unwrap()); + + Ok(decoded) + } + + DapVersion::Draft02 => D::decode_with_param(version, bytes), + } +} + #[cfg(test)] mod test { use super::*; diff --git a/daphne/src/messages/taskprov.rs b/daphne/src/messages/taskprov.rs index c40f70688..188c01b81 100644 --- a/daphne/src/messages/taskprov.rs +++ b/daphne/src/messages/taskprov.rs @@ -5,8 +5,8 @@ //! defined in draft-wang-ppm-dap-taskprov-02. use crate::messages::{ - decode_u16_bytes, encode_u16_bytes, Duration, Time, QUERY_TYPE_FIXED_SIZE, - QUERY_TYPE_TIME_INTERVAL, + decode_u16_bytes, decode_u16_item_for_version, encode_u16_bytes, encode_u16_item_for_version, + Duration, Time, QUERY_TYPE_FIXED_SIZE, QUERY_TYPE_TIME_INTERVAL, }; use crate::DapVersion; use prio::codec::{ @@ -14,7 +14,7 @@ use prio::codec::{ Encode, ParameterizedDecode, ParameterizedEncode, }; use serde::{Deserialize, Serialize}; -use std::io::{Cursor, Read}; +use std::io::Cursor; // VDAF type codes. const VDAF_TYPE_PRIO2: u32 = 0xFFFF_0000; @@ -31,20 +31,17 @@ pub enum VdafTypeVar { impl ParameterizedEncode for VdafTypeVar { fn encode_with_param(&self, version: &DapVersion, bytes: &mut Vec) { - match &self { + match self { Self::Prio2 { dimension } => { VDAF_TYPE_PRIO2.encode(bytes); - if *version != DapVersion::Draft02 { - 4_u16.encode(bytes); - } - dimension.encode(bytes); + encode_u16_item_for_version(bytes, *version, dimension); } Self::NotImplemented { typ, param } => { typ.encode(bytes); - if *version != DapVersion::Draft02 { - u16::try_from(param.len()).unwrap().encode(bytes); + match version { + DapVersion::Draft07 => encode_u16_bytes(bytes, param), + DapVersion::Draft02 => bytes.extend_from_slice(param), } - bytes.extend_from_slice(param); } } } @@ -56,25 +53,17 @@ impl ParameterizedDecode for VdafTypeVar { bytes: &mut Cursor<&[u8]>, ) -> Result { let vdaf_type = u32::decode(bytes)?; - let vdaf_type_param_len = match version { - DapVersion::Draft07 => Some(u16::decode(bytes)?.try_into().unwrap()), - DapVersion::Draft02 => None, - }; - match (vdaf_type, vdaf_type_param_len) { - (VDAF_TYPE_PRIO2, _) => Ok(Self::Prio2 { - dimension: u32::decode(bytes)?, + match (version, vdaf_type) { + (.., VDAF_TYPE_PRIO2) => Ok(Self::Prio2 { + dimension: decode_u16_item_for_version(version, bytes)?, + }), + (DapVersion::Draft07, ..) => Ok(Self::NotImplemented { + typ: vdaf_type, + param: decode_u16_bytes(bytes)?, }), - (_, Some(len)) => { - let mut param = vec![0; len]; - bytes.read_exact(&mut param)?; - Ok(Self::NotImplemented { - typ: vdaf_type, - param, - }) - } // draft02 compatibility: We don't recognize the VDAF type, which means the rest of // this message is not decodable. We must abort. - _ => Err(CodecError::UnexpectedValue), + (DapVersion::Draft02, ..) => Err(CodecError::UnexpectedValue), } } } @@ -91,16 +80,15 @@ impl ParameterizedEncode for DpConfig { match self { Self::None => { DP_MECHANISM_NONE.encode(bytes); - if *version != DapVersion::Draft02 { - 0_u16.encode(bytes); - } + encode_u16_item_for_version(bytes, *version, &()); } + Self::NotImplemented { typ, param } => { typ.encode(bytes); - if *version != DapVersion::Draft02 { - u16::try_from(param.len()).unwrap().encode(bytes); + match version { + DapVersion::Draft07 => encode_u16_bytes(bytes, param), + DapVersion::Draft02 => bytes.extend_from_slice(param), } - bytes.extend_from_slice(param); } } } @@ -112,23 +100,18 @@ impl ParameterizedDecode for DpConfig { bytes: &mut Cursor<&[u8]>, ) -> Result { let dp_mechanism = u8::decode(bytes)?; - let dp_mechanism_param_len = match version { - DapVersion::Draft07 => Some(u16::decode(bytes)?.try_into().unwrap()), - DapVersion::Draft02 => None, - }; - match (dp_mechanism, dp_mechanism_param_len) { - (DP_MECHANISM_NONE, _) => Ok(Self::None), - (_, Some(len)) => { - let mut param = vec![0; len]; - bytes.read_exact(&mut param)?; - Ok(Self::NotImplemented { - typ: dp_mechanism, - param, - }) + match (version, dp_mechanism) { + (.., DP_MECHANISM_NONE) => { + decode_u16_item_for_version::<()>(version, bytes)?; + Ok(Self::None) } + (DapVersion::Draft07, ..) => Ok(Self::NotImplemented { + typ: dp_mechanism, + param: decode_u16_bytes(bytes)?, + }), // draft02 compatibility: We must abort because unimplemented DP mechansims can't be // decoded. - _ => Err(CodecError::UnexpectedValue), + (DapVersion::Draft02, ..) => Err(CodecError::UnexpectedValue), } } } @@ -224,24 +207,19 @@ impl ParameterizedEncode for QueryConfig { self.min_batch_size.encode(bytes); match &self.var { QueryConfigVar::TimeInterval => { - if *version == DapVersion::Draft07 { - QUERY_TYPE_TIME_INTERVAL.encode(bytes); - 0_u16.encode(bytes); - } + QUERY_TYPE_TIME_INTERVAL.encode(bytes); + encode_u16_item_for_version(bytes, *version, &()); } QueryConfigVar::FixedSize { max_batch_size } => { - if *version == DapVersion::Draft07 { - QUERY_TYPE_FIXED_SIZE.encode(bytes); - 4_u16.encode(bytes); - } - max_batch_size.encode(bytes); + QUERY_TYPE_FIXED_SIZE.encode(bytes); + encode_u16_item_for_version(bytes, *version, max_batch_size); } QueryConfigVar::NotImplemented { typ, param } => { - if *version == DapVersion::Draft07 { - typ.encode(bytes); - u16::try_from(param.len()).unwrap().encode(bytes); + typ.encode(bytes); + match version { + DapVersion::Draft07 => encode_u16_bytes(bytes, param), + DapVersion::Draft02 => bytes.extend_from_slice(param), } - bytes.extend_from_slice(param); } } } @@ -252,56 +230,37 @@ impl ParameterizedDecode for QueryConfig { version: &DapVersion, bytes: &mut Cursor<&[u8]>, ) -> Result { - match version { - DapVersion::Draft07 => { - let time_precision = Duration::decode(bytes)?; - let max_batch_query_count = u16::decode(bytes)?; - let min_batch_size = u32::decode(bytes)?; - let query_type = u8::decode(bytes)?; - let query_type_param_len = u16::decode(bytes)?.try_into().unwrap(); - let var = match query_type { - QUERY_TYPE_TIME_INTERVAL => QueryConfigVar::TimeInterval, - QUERY_TYPE_FIXED_SIZE => QueryConfigVar::FixedSize { - max_batch_size: u32::decode(bytes)?, - }, - _ => { - let mut query_type_param = vec![0; query_type_param_len]; - bytes.read_exact(&mut query_type_param)?; - QueryConfigVar::NotImplemented { - typ: query_type, - param: query_type_param, - } - } - }; - Ok(Self { - time_precision, - max_batch_query_count, - min_batch_size, - var, - }) - } - DapVersion::Draft02 => { - let query_type = u8::decode(bytes)?; - let time_precision = Duration::decode(bytes)?; - let max_batch_query_count = u16::decode(bytes)?; - let min_batch_size = u32::decode(bytes)?; - let var = match query_type { - QUERY_TYPE_TIME_INTERVAL => QueryConfigVar::TimeInterval, - QUERY_TYPE_FIXED_SIZE => QueryConfigVar::FixedSize { - max_batch_size: u32::decode(bytes)?, - }, - // draft02 compatibility: Unrecognized query types are not decodable, so we're - // forced to abort at this point. - _ => return Err(CodecError::UnexpectedValue), - }; - Ok(Self { - time_precision, - max_batch_query_count, - min_batch_size, - var, - }) + let query_type = match version { + DapVersion::Draft07 => None, + DapVersion::Draft02 => Some(u8::decode(bytes)?), + }; + let time_precision = Duration::decode(bytes)?; + let max_batch_query_count = u16::decode(bytes)?; + let min_batch_size = u32::decode(bytes)?; + let query_type = query_type.unwrap_or(u8::decode(bytes)?); + let var = match (version, query_type) { + (.., QUERY_TYPE_TIME_INTERVAL) => { + decode_u16_item_for_version::<()>(version, bytes)?; + QueryConfigVar::TimeInterval } - } + (.., QUERY_TYPE_FIXED_SIZE) => QueryConfigVar::FixedSize { + max_batch_size: decode_u16_item_for_version(version, bytes)?, + }, + (DapVersion::Draft07, ..) => QueryConfigVar::NotImplemented { + typ: query_type, + param: decode_u16_bytes(bytes)?, + }, + // draft02 compatibility: We must abort because unimplemented query configurations + // can't be decoded. + (DapVersion::Draft02, ..) => return Err(CodecError::UnexpectedValue), + }; + + Ok(Self { + time_precision, + max_batch_query_count, + min_batch_size, + var, + }) } }