Skip to content

Commit

Permalink
Display Id types as base64 instead of hex
Browse files Browse the repository at this point in the history
  • Loading branch information
mendess committed Oct 10, 2024
1 parent 5441b0e commit 272fa7e
Show file tree
Hide file tree
Showing 7 changed files with 72 additions and 37 deletions.
2 changes: 1 addition & 1 deletion crates/dapf/src/acceptance/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -687,7 +687,7 @@ impl Test {
..
} = &test_task_config;

info!("task id: {}", task_id.to_hex());
info!("task id: {task_id}");

// Generate enough reports to complete a batch.
let measurement = match &opt.measurement {
Expand Down
14 changes: 5 additions & 9 deletions crates/daphne-server/src/roles/aggregator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ impl DapAggregator for crate::App {
task_config: &DapTaskConfig,
agg_share_span: DapAggregateSpan<DapAggregateShare>,
) -> DapAggregateSpan<Result<(), MergeAggShareError>> {
let task_id_hex = task_id.to_hex();
let durable = self.durable();

let replay_protection = fetch_replay_protection_override(self.kv()).await;
Expand All @@ -47,7 +46,7 @@ impl DapAggregator for crate::App {
let result = durable
.request(
bindings::AggregateStore::Merge,
(task_config.version, &task_id_hex, &bucket),
(task_config.version, task_id, &bucket),
)
.encode(&AggregateStoreMergeReq {
contained_reports: report_metadatas.iter().map(|(id, _)| *id).collect(),
Expand Down Expand Up @@ -96,7 +95,7 @@ impl DapAggregator for crate::App {
durable
.request(
bindings::AggregateStore::Get,
(task_config.as_ref().version, &task_id.to_hex(), &bucket),
(task_config.as_ref().version, task_id, &bucket),
)
.send(),
);
Expand Down Expand Up @@ -132,7 +131,7 @@ impl DapAggregator for crate::App {
durable
.request(
bindings::AggregateStore::MarkCollected,
(task_config.as_ref().version, &task_id.to_hex(), &bucket),
(task_config.as_ref().version, task_id, &bucket),
)
.send::<()>(),
);
Expand Down Expand Up @@ -289,7 +288,7 @@ impl DapAggregator for crate::App {
durable
.request(
bindings::AggregateStore::CheckCollected,
(task_config.as_ref().version, &task_id.to_hex(), &bucket),
(task_config.as_ref().version, task_id, &bucket),
)
.send()
})
Expand Down Expand Up @@ -320,10 +319,7 @@ impl DapAggregator for crate::App {
Ok::<bool, DapError>(
!self
.durable()
.request(
bindings::AggregateStore::Get,
(version, &task_id.to_hex(), &bucket),
)
.request(bindings::AggregateStore::Get, (version, task_id, &bucket))
.send::<DapAggregateShare>()
.await
.map_err(|e| fatal_error!(err = ?e, "failed to get an agg share"))?
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
use std::collections::HashSet;

use daphne::{
messages::ReportId, vdaf::VdafAggregateShare, DapAggregateShare, DapBatchBucket, DapVersion,
messages::{ReportId, TaskId},
vdaf::VdafAggregateShare,
DapAggregateShare, DapBatchBucket, DapVersion,
};
use serde::{Deserialize, Serialize};

Expand All @@ -25,20 +27,20 @@ super::define_do_binding! {
CheckCollected = "/internal/do/aggregate_store/check_collected",
}

fn name((version, task_id_hex, bucket): (DapVersion, &'n str, &'n DapBatchBucket)) -> ObjectIdFrom {
fn name((version, task_id, bucket): (DapVersion, &'n TaskId, &'n DapBatchBucket)) -> ObjectIdFrom {
fn durable_name_bucket(bucket: &DapBatchBucket) -> String {
format!("{bucket}")
}
ObjectIdFrom::Name(format!(
"{}/{}",
durable_name_task(version, task_id_hex),
durable_name_task(version, task_id),
durable_name_bucket(bucket),
))
}
}

fn durable_name_task(version: DapVersion, task_id_hex: &str) -> String {
format!("{}/task/{}", version.as_ref(), task_id_hex)
fn durable_name_task(version: DapVersion, task_id: &TaskId) -> String {
format!("{}/task/{task_id}", version.as_ref())
}

#[derive(Debug, PartialEq, Eq)]
Expand Down
7 changes: 4 additions & 3 deletions crates/daphne-service-utils/src/durable_requests/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,8 @@ where

#[cfg(test)]
mod test {
use daphne::{DapBatchBucket, DapVersion};
use daphne::{messages::TaskId, DapBatchBucket, DapVersion};
use rand::{thread_rng, Rng};

use crate::durable_requests::bindings::AggregateStore;

Expand All @@ -307,7 +308,7 @@ mod test {
AggregateStore::Merge,
(
DapVersion::Draft09,
"some-task-id-hex",
&TaskId(thread_rng().gen()),
&DapBatchBucket::TimeInterval {
batch_window: 0,
shard: 17,
Expand All @@ -327,7 +328,7 @@ mod test {
bindings::AggregateStore::Merge,
(
DapVersion::Draft09,
"some-task-id-hex",
&TaskId(thread_rng().gen()),
&DapBatchBucket::TimeInterval {
batch_window: 0,
shard: 16,
Expand Down
8 changes: 0 additions & 8 deletions crates/daphne/src/error/aborts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ use crate::{
messages::{AggregationJobId, ReportId, TaskId, TransitionFailure},
DapError, DapRequestMeta, DapVersion,
};
use hex::FromHexError;
use prio::codec::CodecError;
use serde::{Deserialize, Serialize};

Expand Down Expand Up @@ -319,13 +318,6 @@ impl DapAbort {
task_id,
}
}

pub fn from_hex_error(e: FromHexError, task_id: TaskId) -> Self {
Self::InvalidMessage {
detail: format!("invalid hexadecimal string {e:?}"),
task_id,
}
}
}

/// A problem details document compatible with RFC 7807.
Expand Down
2 changes: 1 addition & 1 deletion crates/daphne/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,7 @@ impl std::fmt::Display for DapBatchBucket {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::TimeInterval { batch_window, .. } => write!(f, "window/{batch_window}")?,
Self::FixedSize { batch_id, .. } => write!(f, "batch/{}", batch_id.to_hex())?,
Self::FixedSize { batch_id, .. } => write!(f, "batch/{batch_id}")?,
};

let shard = self.shard();
Expand Down
64 changes: 54 additions & 10 deletions crates/daphne/src/messages/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,14 +52,7 @@ macro_rules! id_struct {
Copy, Clone, Default, Deserialize, Hash, PartialEq, Eq, Serialize, PartialOrd, Ord,
)]
#[cfg_attr(any(test, feature = "test-utils"), derive(deepsize::DeepSizeOf))]
pub struct $sname(#[serde(with = "hex")] pub [u8; $len]);

impl $sname {
/// Return the ID encoded as a hex string.
pub fn to_hex(&self) -> String {
hex::encode(self.0)
}
}
pub struct $sname(#[serde(with = "base64url_bytes")] pub [u8; $len]);

impl $crate::messages::Base64Encode for $sname {
/// Return the URL-safe, base64 encoding of the ID.
Expand Down Expand Up @@ -102,13 +95,13 @@ macro_rules! id_struct {

impl fmt::Display for $sname {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", self.to_hex())
write!(f, "{}", self.to_base64url())
}
}

impl fmt::Debug for $sname {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}({})", ::std::stringify!($sname), self.to_hex())
write!(f, "{}({self})", ::std::stringify!($sname))
}
}
};
Expand All @@ -120,6 +113,57 @@ id_struct!(CollectionJobId, 16, "Collection Job ID");
id_struct!(ReportId, 16, "Report ID");
id_struct!(TaskId, 32, "Task ID");

/// module to serialize and deserialize types ids into base64
mod base64url_bytes {
use serde::{de, ser};

use crate::messages::decode_base64url;

use super::encode_base64url;

pub fn serialize<I, S>(id: &I, serializer: S) -> Result<S::Ok, S::Error>
where
I: AsRef<[u8]>,
S: ser::Serializer,
{
serializer.serialize_str(&encode_base64url(id))
}

pub fn deserialize<'de, const N: usize, O, D>(deserializer: D) -> Result<O, D::Error>
where
D: de::Deserializer<'de>,
D::Error: de::Error,
O: From<[u8; N]>,
{
struct Visitor<const N: usize, O>(std::marker::PhantomData<[O; N]>);
impl<'de, const N: usize, O> de::Visitor<'de> for Visitor<N, O>
where
O: From<[u8; N]>,
{
type Value = O;
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
formatter.write_str("a base64 encoded value")
}
fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
where
E: de::Error,
{
decode_base64url(v)
.map(|v| O::from(v))
.ok_or_else(|| E::custom("invalid base64"))
}

fn visit_string<E>(self, v: String) -> Result<Self::Value, E>
where
E: de::Error,
{
self.visit_str(&v)
}
}
deserializer.deserialize_str(Visitor::<N, O>(std::marker::PhantomData))
}
}

/// serde module for base64url-encoded serialization of ids
pub mod base64url {
use serde::{de, ser};
Expand Down

0 comments on commit 272fa7e

Please sign in to comment.