Skip to content

Commit

Permalink
taskprov: Improve extension payload handling
Browse files Browse the repository at this point in the history
* If taskprov is disabled for the task, then handle the extension as
  unrecognized.

* If taskprov is enabled, then assert that the extension payload is
  empty (in draft07).

* Fence the payload consistency check to draft02.

While at it, align unimplemented-extension serialization with other
unimplemented fields.
  • Loading branch information
cjpatton committed Nov 29, 2023
1 parent ef17992 commit 2bdfc8c
Show file tree
Hide file tree
Showing 7 changed files with 206 additions and 129 deletions.
152 changes: 94 additions & 58 deletions daphne/src/messages/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,42 +125,69 @@ pub type Time = u64;
#[serde(rename_all = "snake_case")]
#[cfg_attr(any(test, feature = "test-utils"), derive(deepsize::DeepSizeOf))]
pub enum Extension {
Taskprov { payload: Vec<u8> }, // Not a TaskConfig to make computing the expected task id more efficient
Unhandled { typ: u16, payload: Vec<u8> },
Taskprov {
// draft02 compatibility: The payload is the serialized `TaskConfig` advertised by each
// Client. We treat it as an opaque byte string here to save time during the aggregation
// sub-protocol. Before we deserialize it, we need to check (1) each Client has the same
// extension paylaod and (2) the task ID matches the hash of the extension payload. After
// we do this check, we need only to deserialize it once.
draft02_payload: Option<Vec<u8>>,
},
NotImplemented {
typ: u16,
payload: Vec<u8>,
},
}

impl Extension {
/// Return the type code associated with the extension
pub(crate) fn type_code(&self) -> u16 {
match self {
Self::Taskprov { .. } => EXTENSION_TASKPROV,
Self::Unhandled { typ, .. } => *typ,
Self::NotImplemented { typ, .. } => *typ,
}
}
}

impl Encode for Extension {
fn encode(&self, bytes: &mut Vec<u8>) {
impl ParameterizedEncode<DapVersion> for Extension {
fn encode_with_param(&self, version: &DapVersion, bytes: &mut Vec<u8>) {
match self {
Self::Taskprov { payload } => {
Self::Taskprov { draft02_payload } => {
EXTENSION_TASKPROV.encode(bytes);
encode_u16_bytes(bytes, payload);
match (version, draft02_payload) {
(DapVersion::Draft07, None) => encode_u16_item(bytes, *version, &()),
(DapVersion::Draft02, Some(payload)) => encode_u16_bytes(bytes, payload),
_ => unreachable!("unhandled version {version:?}"),
}
}
Self::Unhandled { typ, payload } => {
Self::NotImplemented { typ, payload } => {
typ.encode(bytes);
encode_u16_bytes(bytes, payload);
}
}
}
}

impl Decode for Extension {
fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
impl ParameterizedDecode<DapVersion> for Extension {
fn decode_with_param(
version: &DapVersion,
bytes: &mut Cursor<&[u8]>,
) -> Result<Self, CodecError> {
let typ = u16::decode(bytes)?;
let payload = decode_u16_bytes(bytes)?;
match typ {
EXTENSION_TASKPROV => Ok(Self::Taskprov { payload }),
_ => Ok(Self::Unhandled { typ, payload }),
match (version, typ) {
(DapVersion::Draft07, EXTENSION_TASKPROV) => {
decode_u16_item::<()>(*version, bytes)?;
Ok(Self::Taskprov {
draft02_payload: None,
})
}
(DapVersion::Draft02, EXTENSION_TASKPROV) => Ok(Self::Taskprov {
draft02_payload: Some(decode_u16_bytes(bytes)?),
}),
_ => Ok(Self::NotImplemented {
typ,
payload: decode_u16_bytes(bytes)?,
}),
}
}
}
Expand All @@ -182,7 +209,7 @@ impl ParameterizedEncode<DapVersion> for ReportMetadata {
self.time.encode(bytes);
match (version, &self.draft02_extensions) {
(DapVersion::Draft07, None) => (),
(DapVersion::Draft02, Some(extensions)) => encode_u16_items(bytes, &(), extensions),
(DapVersion::Draft02, Some(extensions)) => encode_u16_items(bytes, version, extensions),
_ => unreachable!("extensions should be set in (and only in) draft02"),
}
}
Expand All @@ -197,7 +224,7 @@ impl ParameterizedDecode<DapVersion> for ReportMetadata {
id: ReportId::decode(bytes)?,
time: Time::decode(bytes)?,
draft02_extensions: match version {
DapVersion::Draft02 => Some(decode_u16_items(&(), bytes)?),
DapVersion::Draft02 => Some(decode_u16_items(version, bytes)?),
DapVersion::Draft07 => None,
},
};
Expand Down Expand Up @@ -1130,17 +1157,20 @@ pub struct PlaintextInputShare {
pub payload: Vec<u8>,
}

impl Encode for PlaintextInputShare {
fn encode(&self, bytes: &mut Vec<u8>) {
encode_u16_items(bytes, &(), &self.extensions);
impl ParameterizedEncode<DapVersion> for PlaintextInputShare {
fn encode_with_param(&self, version: &DapVersion, bytes: &mut Vec<u8>) {
encode_u16_items(bytes, version, &self.extensions);
encode_u32_bytes(bytes, &self.payload);
}
}

impl Decode for PlaintextInputShare {
fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
impl ParameterizedDecode<DapVersion> for PlaintextInputShare {
fn decode_with_param(
version: &DapVersion,
bytes: &mut Cursor<&[u8]>,
) -> Result<Self, CodecError> {
Ok(Self {
extensions: decode_u16_items(&(), bytes)?,
extensions: decode_u16_items(version, bytes)?,
payload: decode_u32_bytes(bytes)?,
})
}
Expand Down Expand Up @@ -1214,58 +1244,64 @@ pub fn decode_base64url_vec<T: AsRef<[u8]>>(input: T) -> Option<Vec<u8>> {
URL_SAFE_NO_PAD.decode(input).ok()
}

fn encode_u16_item_for_version<E: ParameterizedEncode<DapVersion>>(
// Cribbed from `decode_u16_items()` from libprio.
fn encode_u16_item<E: ParameterizedEncode<DapVersion>>(
bytes: &mut Vec<u8>,
version: DapVersion,
item: &E,
) {
match version {
DapVersion::Draft07 => {
// Cribbed from `decode_u16_items()` from libprio.
//
// Reserve space for the length prefix.
let len_offset = bytes.len();
0_u16.encode(bytes);

item.encode_with_param(&version, bytes);
let len = bytes.len() - len_offset - 2;
for (offset, byte) in u16::to_be_bytes(len.try_into().unwrap()).iter().enumerate() {
bytes[len_offset + offset] = *byte;
}
}
// Reserve space for the length prefix.
let len_offset = bytes.len();
0_u16.encode(bytes);

DapVersion::Draft02 => item.encode_with_param(&version, bytes),
item.encode_with_param(&version, bytes);
let len = bytes.len() - len_offset - 2;
for (offset, byte) in u16::to_be_bytes(len.try_into().unwrap()).iter().enumerate() {
bytes[len_offset + offset] = *byte;
}
}

pub fn decode_u16_item_for_version<D: ParameterizedDecode<DapVersion>>(
version: &DapVersion,
// Cribbed from `decode_u16_items()` from libprio.
fn decode_u16_item<D: ParameterizedDecode<DapVersion>>(
version: DapVersion,
bytes: &mut Cursor<&[u8]>,
) -> Result<D, CodecError> {
match version {
DapVersion::Draft07 => {
// Cribbed from `decode_u16_items()` from libprio.
//
// Read the length prefix.
let len = usize::from(u16::decode(bytes)?);
// Read the length prefix.
let len = usize::from(u16::decode(bytes)?);

let item_start = usize::try_from(bytes.position()).unwrap();
let item_start = usize::try_from(bytes.position()).unwrap();

// Make sure encoded length doesn't overflow usize or go past the end of provided byte buffer.
let item_end = item_start
.checked_add(len)
.ok_or_else(|| CodecError::LengthPrefixTooBig(len))?;
// Make sure encoded length doesn't overflow usize or go past the end of provided byte buffer.
let item_end = item_start
.checked_add(len)
.ok_or_else(|| CodecError::LengthPrefixTooBig(len))?;

let decoded =
D::get_decoded_with_param(version, &bytes.get_ref()[item_start..item_end])?;
let decoded = D::get_decoded_with_param(&version, &bytes.get_ref()[item_start..item_end])?;

// Advance outer cursor by the amount read in the inner cursor.
bytes.set_position(item_end.try_into().unwrap());
// Advance outer cursor by the amount read in the inner cursor.
bytes.set_position(item_end.try_into().unwrap());

Ok(decoded)
}
Ok(decoded)
}

DapVersion::Draft02 => D::decode_with_param(version, bytes),
fn encode_u16_item_for_version<E: ParameterizedEncode<DapVersion>>(
bytes: &mut Vec<u8>,
version: DapVersion,
item: &E,
) {
match version {
DapVersion::Draft07 => encode_u16_item(bytes, version, item),
DapVersion::Draft02 => item.encode_with_param(&version, bytes),
}
}

fn decode_u16_item_for_version<D: ParameterizedDecode<DapVersion>>(
version: DapVersion,
bytes: &mut Cursor<&[u8]>,
) -> Result<D, CodecError> {
match version {
DapVersion::Draft07 => decode_u16_item(version, bytes),
DapVersion::Draft02 => D::decode_with_param(&version, bytes),
}
}

Expand Down
8 changes: 4 additions & 4 deletions daphne/src/messages/taskprov.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ impl ParameterizedDecode<DapVersion> for VdafTypeVar {
let vdaf_type = u32::decode(bytes)?;
match (version, vdaf_type) {
(.., VDAF_TYPE_PRIO2) => Ok(Self::Prio2 {
dimension: decode_u16_item_for_version(version, bytes)?,
dimension: decode_u16_item_for_version(*version, bytes)?,
}),
(DapVersion::Draft07, ..) => Ok(Self::NotImplemented {
typ: vdaf_type,
Expand Down Expand Up @@ -102,7 +102,7 @@ impl ParameterizedDecode<DapVersion> for DpConfig {
let dp_mechanism = u8::decode(bytes)?;
match (version, dp_mechanism) {
(.., DP_MECHANISM_NONE) => {
decode_u16_item_for_version::<()>(version, bytes)?;
decode_u16_item_for_version::<()>(*version, bytes)?;
Ok(Self::None)
}
(DapVersion::Draft07, ..) => Ok(Self::NotImplemented {
Expand Down Expand Up @@ -240,11 +240,11 @@ impl ParameterizedDecode<DapVersion> for QueryConfig {
let query_type = query_type.unwrap_or(u8::decode(bytes)?);
let var = match (version, query_type) {
(.., QUERY_TYPE_TIME_INTERVAL) => {
decode_u16_item_for_version::<()>(version, bytes)?;
decode_u16_item_for_version::<()>(*version, bytes)?;
QueryConfigVar::TimeInterval
}
(.., QUERY_TYPE_FIXED_SIZE) => QueryConfigVar::FixedSize {
max_batch_size: decode_u16_item_for_version(version, bytes)?,
max_batch_size: decode_u16_item_for_version(*version, bytes)?,
},
(DapVersion::Draft07, ..) => QueryConfigVar::NotImplemented {
typ: query_type,
Expand Down
67 changes: 36 additions & 31 deletions daphne/src/roles/helper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,40 +57,45 @@ pub trait DapHelper<S>: DapAggregator<S> {

metrics.agg_job_observe_batch_size(agg_job_init_req.prep_inits.len());

// taskprov: Resolve the task config to use for the request. We also need to ensure
// that all of the reports include the task config in the report extensions. (See
// section 6 of draft-wang-ppm-dap-taskprov-02.)
// taskprov: Resolve the task config to use for the request.
if self.get_global_config().allow_taskprov {
let using_taskprov = agg_job_init_req
.prep_inits
.iter()
.filter(|prep_init| {
prep_init
.report_share
.report_metadata
.is_taskprov(req.version, task_id)
})
.count();

let first_metadata = match using_taskprov {
0 => None,
c if c == agg_job_init_req.prep_inits.len() => {
// All the extensions use taskprov and look ok, so compute first_metadata.
// Note this will always be Some().
agg_job_init_req
.prep_inits
.first()
.map(|prep_init| &prep_init.report_share.report_metadata)
}
_ => {
// It's not all taskprov or no taskprov, so it's an error.
return Err(DapAbort::InvalidMessage {
detail: "some reports include the taskprov extensions and some do not"
.to_string(),
task_id: Some(*task_id),
});
// draft02 compatibility: We also need to ensure that all of the reports include the task
// config in the report extensions. (See section 6 of draft-wang-ppm-dap-taskprov-02.)
let first_metadata = if req.version == DapVersion::default() {
let using_taskprov = agg_job_init_req
.prep_inits
.iter()
.filter(|prep_init| {
prep_init
.report_share
.report_metadata
.is_taskprov(req.version, task_id)
})
.count();

match using_taskprov {
0 => None,
c if c == agg_job_init_req.prep_inits.len() => {
// All the extensions use taskprov and look ok, so compute first_metadata.
// Note this will always be Some().
agg_job_init_req
.prep_inits
.first()
.map(|prep_init| &prep_init.report_share.report_metadata)
}
_ => {
// It's not all taskprov or no taskprov, so it's an error.
return Err(DapAbort::InvalidMessage {
detail: "some reports include the taskprov extensions and some do not"
.to_string(),
task_id: Some(*task_id),
});
}
}
} else {
None
};

resolve_taskprov(self, task_id, req, first_metadata).await?;
}

Expand Down
5 changes: 4 additions & 1 deletion daphne/src/roles/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1795,7 +1795,10 @@ mod test {
&task_id,
DapMeasurement::U32Vec(vec![1; 10]),
vec![Extension::Taskprov {
payload: taskprov_report_extension_payload.clone(),
draft02_payload: match version {
DapVersion::Draft07 => None,
DapVersion::Draft02 => Some(taskprov_report_extension_payload.clone()),
},
}],
version,
)
Expand Down
16 changes: 10 additions & 6 deletions daphne/src/taskprov.rs
Original file line number Diff line number Diff line change
Expand Up @@ -164,8 +164,10 @@ fn get_taskprov_task_config<S>(
match taskprovs.len() {
0 => return Ok(None),
1 => match &taskprovs[0] {
Extension::Taskprov { payload } => Cow::Borrowed(payload),
Extension::Unhandled { .. } => panic!("cannot happen"),
Extension::Taskprov {
draft02_payload: Some(payload),
} => Cow::Borrowed(payload),
_ => panic!("cannot happen"),
},
_ => {
// The decoder already returns an error if an extension of a give type occurs more
Expand Down Expand Up @@ -359,11 +361,13 @@ impl TryFrom<&DapTaskConfig> for messages::taskprov::TaskConfig {

impl ReportMetadata {
/// Does this metatdata have a taskprov extension and does it match the specified id?
pub fn is_taskprov(&self, version: DapVersion, task_id: &TaskId) -> bool {
pub(crate) fn is_taskprov(&self, version: DapVersion, task_id: &TaskId) -> bool {
return self.draft02_extensions.as_ref().is_some_and(|extensions| {
extensions.iter().any(|x| match x {
Extension::Taskprov { payload } => *task_id == compute_task_id(version, payload),
Extension::Unhandled { .. } => false,
Extension::Taskprov {
draft02_payload: Some(payload),
} => *task_id == compute_task_id(version, payload),
_ => false,
})
});
}
Expand Down Expand Up @@ -521,7 +525,7 @@ mod test {
id: ReportId([0; 16]),
time: 0,
draft02_extensions: Some(vec![Extension::Taskprov {
payload: taskprov_task_config_data,
draft02_payload: Some(taskprov_task_config_data),
}]),
}),
)
Expand Down
Loading

0 comments on commit 2bdfc8c

Please sign in to comment.