diff --git a/crates/dapf/src/acceptance/mod.rs b/crates/dapf/src/acceptance/mod.rs index a2ecd0ad..f54d339c 100644 --- a/crates/dapf/src/acceptance/mod.rs +++ b/crates/dapf/src/acceptance/mod.rs @@ -687,7 +687,7 @@ impl Test { .. } = &test_task_config; - info!("task id: {}", task_id.to_hex()); + info!("task id: {task_id}"); // Generate enough reports to complete a batch. let measurement = match &opt.measurement { diff --git a/crates/daphne-server/src/roles/aggregator.rs b/crates/daphne-server/src/roles/aggregator.rs index 78a53e00..a71ceb1d 100644 --- a/crates/daphne-server/src/roles/aggregator.rs +++ b/crates/daphne-server/src/roles/aggregator.rs @@ -37,7 +37,6 @@ impl DapAggregator for crate::App { task_config: &DapTaskConfig, agg_share_span: DapAggregateSpan, ) -> DapAggregateSpan> { - let task_id_hex = task_id.to_hex(); let durable = self.durable(); let replay_protection = fetch_replay_protection_override(self.kv()).await; @@ -47,7 +46,7 @@ impl DapAggregator for crate::App { let result = durable .request( bindings::AggregateStore::Merge, - (task_config.version, &task_id_hex, &bucket), + (task_config.version, task_id, &bucket), ) .encode(&AggregateStoreMergeReq { contained_reports: report_metadatas.iter().map(|(id, _)| *id).collect(), @@ -96,7 +95,7 @@ impl DapAggregator for crate::App { durable .request( bindings::AggregateStore::Get, - (task_config.as_ref().version, &task_id.to_hex(), &bucket), + (task_config.as_ref().version, task_id, &bucket), ) .send(), ); @@ -132,7 +131,7 @@ impl DapAggregator for crate::App { durable .request( bindings::AggregateStore::MarkCollected, - (task_config.as_ref().version, &task_id.to_hex(), &bucket), + (task_config.as_ref().version, task_id, &bucket), ) .send::<()>(), ); @@ -289,7 +288,7 @@ impl DapAggregator for crate::App { durable .request( bindings::AggregateStore::CheckCollected, - (task_config.as_ref().version, &task_id.to_hex(), &bucket), + (task_config.as_ref().version, task_id, &bucket), ) .send() }) @@ -320,10 +319,7 @@ impl DapAggregator for crate::App { Ok::( !self .durable() - .request( - bindings::AggregateStore::Get, - (version, &task_id.to_hex(), &bucket), - ) + .request(bindings::AggregateStore::Get, (version, task_id, &bucket)) .send::() .await .map_err(|e| fatal_error!(err = ?e, "failed to get an agg share"))? diff --git a/crates/daphne-service-utils/src/durable_requests/bindings/aggregate_store.rs b/crates/daphne-service-utils/src/durable_requests/bindings/aggregate_store.rs index 7cccf9f3..2a27b97a 100644 --- a/crates/daphne-service-utils/src/durable_requests/bindings/aggregate_store.rs +++ b/crates/daphne-service-utils/src/durable_requests/bindings/aggregate_store.rs @@ -4,7 +4,9 @@ use std::collections::HashSet; use daphne::{ - messages::ReportId, vdaf::VdafAggregateShare, DapAggregateShare, DapBatchBucket, DapVersion, + messages::{ReportId, TaskId}, + vdaf::VdafAggregateShare, + DapAggregateShare, DapBatchBucket, DapVersion, }; use serde::{Deserialize, Serialize}; @@ -25,20 +27,20 @@ super::define_do_binding! { CheckCollected = "/internal/do/aggregate_store/check_collected", } - fn name((version, task_id_hex, bucket): (DapVersion, &'n str, &'n DapBatchBucket)) -> ObjectIdFrom { + fn name((version, task_id, bucket): (DapVersion, &'n TaskId, &'n DapBatchBucket)) -> ObjectIdFrom { fn durable_name_bucket(bucket: &DapBatchBucket) -> String { format!("{bucket}") } ObjectIdFrom::Name(format!( "{}/{}", - durable_name_task(version, task_id_hex), + durable_name_task(version, task_id), durable_name_bucket(bucket), )) } } -fn durable_name_task(version: DapVersion, task_id_hex: &str) -> String { - format!("{}/task/{}", version.as_ref(), task_id_hex) +fn durable_name_task(version: DapVersion, task_id: &TaskId) -> String { + format!("{}/task/{task_id}", version.as_ref()) } #[derive(Debug, PartialEq, Eq)] diff --git a/crates/daphne-service-utils/src/durable_requests/mod.rs b/crates/daphne-service-utils/src/durable_requests/mod.rs index a4f220fc..a5a007ff 100644 --- a/crates/daphne-service-utils/src/durable_requests/mod.rs +++ b/crates/daphne-service-utils/src/durable_requests/mod.rs @@ -295,7 +295,8 @@ where #[cfg(test)] mod test { - use daphne::{DapBatchBucket, DapVersion}; + use daphne::{messages::TaskId, DapBatchBucket, DapVersion}; + use rand::{thread_rng, Rng}; use crate::durable_requests::bindings::AggregateStore; @@ -307,7 +308,7 @@ mod test { AggregateStore::Merge, ( DapVersion::Draft09, - "some-task-id-hex", + &TaskId(thread_rng().gen()), &DapBatchBucket::TimeInterval { batch_window: 0, shard: 17, @@ -327,7 +328,7 @@ mod test { bindings::AggregateStore::Merge, ( DapVersion::Draft09, - "some-task-id-hex", + &TaskId(thread_rng().gen()), &DapBatchBucket::TimeInterval { batch_window: 0, shard: 16, diff --git a/crates/daphne/src/error/aborts.rs b/crates/daphne/src/error/aborts.rs index 99ca7913..cb3a9e6c 100644 --- a/crates/daphne/src/error/aborts.rs +++ b/crates/daphne/src/error/aborts.rs @@ -9,7 +9,6 @@ use crate::{ messages::{AggregationJobId, ReportId, TaskId, TransitionFailure}, DapError, DapRequestMeta, DapVersion, }; -use hex::FromHexError; use prio::codec::CodecError; use serde::{Deserialize, Serialize}; @@ -319,13 +318,6 @@ impl DapAbort { task_id, } } - - pub fn from_hex_error(e: FromHexError, task_id: TaskId) -> Self { - Self::InvalidMessage { - detail: format!("invalid hexadecimal string {e:?}"), - task_id, - } - } } /// A problem details document compatible with RFC 7807. diff --git a/crates/daphne/src/lib.rs b/crates/daphne/src/lib.rs index 33b38e99..a3c7f42e 100644 --- a/crates/daphne/src/lib.rs +++ b/crates/daphne/src/lib.rs @@ -314,7 +314,7 @@ impl std::fmt::Display for DapBatchBucket { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { Self::TimeInterval { batch_window, .. } => write!(f, "window/{batch_window}")?, - Self::FixedSize { batch_id, .. } => write!(f, "batch/{}", batch_id.to_hex())?, + Self::FixedSize { batch_id, .. } => write!(f, "batch/{batch_id}")?, }; let shard = self.shard(); diff --git a/crates/daphne/src/messages/mod.rs b/crates/daphne/src/messages/mod.rs index 15328f41..1a051351 100644 --- a/crates/daphne/src/messages/mod.rs +++ b/crates/daphne/src/messages/mod.rs @@ -52,14 +52,7 @@ macro_rules! id_struct { Copy, Clone, Default, Deserialize, Hash, PartialEq, Eq, Serialize, PartialOrd, Ord, )] #[cfg_attr(any(test, feature = "test-utils"), derive(deepsize::DeepSizeOf))] - pub struct $sname(#[serde(with = "hex")] pub [u8; $len]); - - impl $sname { - /// Return the ID encoded as a hex string. - pub fn to_hex(&self) -> String { - hex::encode(self.0) - } - } + pub struct $sname(#[serde(with = "base64url_bytes")] pub [u8; $len]); impl $crate::messages::Base64Encode for $sname { /// Return the URL-safe, base64 encoding of the ID. @@ -102,13 +95,13 @@ macro_rules! id_struct { impl fmt::Display for $sname { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{}", self.to_hex()) + write!(f, "{}", self.to_base64url()) } } impl fmt::Debug for $sname { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{}({})", ::std::stringify!($sname), self.to_hex()) + write!(f, "{}({self})", ::std::stringify!($sname)) } } }; @@ -120,6 +113,57 @@ id_struct!(CollectionJobId, 16, "Collection Job ID"); id_struct!(ReportId, 16, "Report ID"); id_struct!(TaskId, 32, "Task ID"); +/// module to serialize and deserialize types ids into base64 +mod base64url_bytes { + use serde::{de, ser}; + + use crate::messages::decode_base64url; + + use super::encode_base64url; + + pub fn serialize(id: &I, serializer: S) -> Result + where + I: AsRef<[u8]>, + S: ser::Serializer, + { + serializer.serialize_str(&encode_base64url(id)) + } + + pub fn deserialize<'de, const N: usize, O, D>(deserializer: D) -> Result + where + D: de::Deserializer<'de>, + D::Error: de::Error, + O: From<[u8; N]>, + { + struct Visitor(std::marker::PhantomData<[O; N]>); + impl<'de, const N: usize, O> de::Visitor<'de> for Visitor + where + O: From<[u8; N]>, + { + type Value = O; + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("a base64 encoded value") + } + fn visit_str(self, v: &str) -> Result + where + E: de::Error, + { + decode_base64url(v) + .map(|v| O::from(v)) + .ok_or_else(|| E::custom("invalid base64")) + } + + fn visit_string(self, v: String) -> Result + where + E: de::Error, + { + self.visit_str(&v) + } + } + deserializer.deserialize_str(Visitor::(std::marker::PhantomData)) + } +} + /// serde module for base64url-encoded serialization of ids pub mod base64url { use serde::{de, ser};