Skip to content

Commit

Permalink
Make HpkeDecrypter not async
Browse files Browse the repository at this point in the history
  • Loading branch information
mendess committed Oct 9, 2024
1 parent 5441b0e commit c1cb9c1
Show file tree
Hide file tree
Showing 8 changed files with 140 additions and 102 deletions.
60 changes: 27 additions & 33 deletions crates/daphne-server/src/roles/aggregator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ use daphne::{
audit_log::AuditLog,
error::DapAbort,
fatal_error,
hpke::{HpkeConfig, HpkeDecrypter, HpkeProvider},
messages::{self, BatchId, BatchSelector, HpkeCiphertext, TaskId, Time, TransitionFailure},
hpke::{self, HpkeConfig, HpkeProvider, HpkeReceiverConfig},
messages::{self, BatchId, BatchSelector, HpkeCiphertext, TaskId, Time},
metrics::DaphneMetrics,
roles::{aggregator::MergeAggShareError, DapAggregator, DapReportInitializer},
taskprov, DapAggregateShare, DapAggregateSpan, DapAggregationParam, DapError, DapGlobalConfig,
Expand Down Expand Up @@ -393,9 +393,23 @@ impl DapReportInitializer for crate::App {
}
}

pub struct HpkeDecrypter(Marc<Vec<HpkeReceiverConfig>>);

impl hpke::HpkeDecrypter for HpkeDecrypter {
fn hpke_decrypt(
&self,
info: &[u8],
aad: &[u8],
ciphertext: &HpkeCiphertext,
) -> Result<Vec<u8>, DapError> {
self.0.hpke_decrypt(info, aad, ciphertext)
}
}

#[async_trait]
impl HpkeProvider for crate::App {
type WrappedHpkeConfig<'s> = Marc<HpkeConfig>;
type ReceiverConfigs<'s> = HpkeDecrypter;

async fn get_hpke_config_for<'s>(
&'s self,
Expand Down Expand Up @@ -439,37 +453,17 @@ impl HpkeProvider for crate::App {
.map_err(|e| fatal_error!(err = ?e, "failed to get at the hpke config"))?
.unwrap_or(false))
}
}

#[async_trait]
impl HpkeDecrypter for crate::App {
async fn hpke_decrypt(
&self,
task_id: &TaskId,
info: &[u8],
aad: &[u8],
ciphertext: &HpkeCiphertext,
) -> Result<Vec<u8>, DapError> {
let version = self
.get_task_config_for(task_id)
.await?
.as_ref()
.ok_or(DapAbort::UnrecognizedTask { task_id: *task_id })?
.version;
self.kv()
.peek::<kv::prefix::HpkeReceiverConfigSet, _, _>(
&version,
&KvGetOptions::default(),
|config_list| {
config_list
.iter()
.find(|receiver| receiver.config.id == ciphertext.config_id)
.map(|receiver| receiver.decrypt(info, aad, ciphertext))
},
)
.await
.map_err(|e| fatal_error!(err = ?e, "failed to get the hpke config"))?
.flatten()
.ok_or(DapError::Transition(TransitionFailure::HpkeUnknownConfigId))?
async fn get_receiver_configs<'s>(
&'s self,
version: DapVersion,
) -> Result<Self::ReceiverConfigs<'s>, DapError> {
Ok(HpkeDecrypter(
self.kv()
.get::<kv::prefix::HpkeReceiverConfigSet>(&version, &KvGetOptions::default())
.await
.map_err(|e| fatal_error!(err= ?e,"failed to get the hpke config"))?
.ok_or_else(|| fatal_error!(err="there are no hpke configs in kv!!", %version))?,
))
}
}
108 changes: 82 additions & 26 deletions crates/daphne/src/hpke.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use crate::{
};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::ops::Deref;
use std::{borrow::Borrow, ops::Deref};

// Various algorithm constants
const KEM_ID_X25519_HKDF_SHA256: u16 = 0x0020;
Expand Down Expand Up @@ -171,12 +171,6 @@ impl deepsize::DeepSizeOf for HpkeConfig {
}
}

impl AsRef<HpkeConfig> for HpkeConfig {
fn as_ref(&self) -> &Self {
self
}
}

impl HpkeConfig {
/// Encrypt `plaintext` with info string `info` and associated data `aad` using this HPKE
/// configuration. The return values are the encapsulated key and the ciphertext.
Expand Down Expand Up @@ -215,35 +209,113 @@ impl HpkeConfig {
}

#[async_trait]
pub trait HpkeProvider: HpkeDecrypter {
pub trait HpkeProvider {
/// Return type of `get_hpke_config_for()`, wraps a reference to an HPKE config.
type WrappedHpkeConfig<'a>: Deref<Target = HpkeConfig> + Send
where
Self: 'a;

type ReceiverConfigs<'a>: HpkeDecrypter
where
Self: 'a;

/// Look up the HPKE configuration to use for the given task ID (if specified).
async fn get_hpke_config_for<'s>(
&'s self,
version: DapVersion,
task_id: Option<&TaskId>,
) -> Result<Self::WrappedHpkeConfig<'s>, DapError>;

async fn get_receiver_configs<'s>(
&'s self,
version: DapVersion,
) -> Result<Self::ReceiverConfigs<'s>, DapError>;

/// Returns `true` if a ciphertext with the HPKE config ID can be consumed in the current task.
async fn can_hpke_decrypt(&self, task_id: &TaskId, config_id: u8) -> Result<bool, DapError>;
}

#[async_trait]
pub trait HpkeDecrypter {
/// Decrypt the given HPKE ciphertext using the given info and AAD string.
async fn hpke_decrypt(
fn hpke_decrypt(
&self,
task_id: &TaskId,
info: &[u8],
aad: &[u8],
ciphertext: &HpkeCiphertext,
) -> Result<Vec<u8>, DapError>;
}

impl<T> HpkeDecrypter for &T
where
T: HpkeDecrypter,
{
fn hpke_decrypt(
&self,
info: &[u8],
aad: &[u8],
ciphertext: &HpkeCiphertext,
) -> Result<Vec<u8>, DapError> {
<T as HpkeDecrypter>::hpke_decrypt(self, info, aad, ciphertext)
}
}

macro_rules! impl_decrypter_for_slice_types {
($($(const $n:ident : usize)? $ty:ty),*$(,)?) => {
$(
impl$(<const $n: usize>)* HpkeDecrypter for $ty {
fn hpke_decrypt(
&self,
info: &[u8],
aad: &[u8],
ciphertext: &HpkeCiphertext,
) -> Result<Vec<u8>, DapError> {
self.iter().hpke_decrypt(info, aad, ciphertext)
}
}
)*
}
}

impl_decrypter_for_slice_types!(
Vec<HpkeReceiverConfig>,
&'_ [HpkeReceiverConfig],
&'_ [&'_ HpkeReceiverConfig],
const N: usize [HpkeReceiverConfig; N],
const N: usize [&'_ HpkeReceiverConfig; N],
);

impl<R> HpkeDecrypter for std::slice::Iter<'_, R>
where
R: Borrow<HpkeReceiverConfig>,
{
fn hpke_decrypt(
&self,
info: &[u8],
aad: &[u8],
ciphertext: &HpkeCiphertext,
) -> Result<Vec<u8>, DapError> {
self.clone()
.map(|c| c.borrow())
.find(|c| c.config.id == ciphertext.config_id)
.ok_or(DapError::Transition(TransitionFailure::HpkeUnknownConfigId))?
.decrypt(info, aad, ciphertext)
}
}

// This let's us use a single config during tests to simplify test code.
#[cfg(any(test, feature = "test-utils"))]
impl HpkeDecrypter for HpkeReceiverConfig {
fn hpke_decrypt(
&self,
info: &[u8],
aad: &[u8],
ciphertext: &HpkeCiphertext,
) -> Result<Vec<u8>, DapError> {
[self].hpke_decrypt(info, aad, ciphertext)
}
}

/// Struct that combines `HpkeConfig` and `HpkeSecretKey`
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
pub struct HpkeReceiverConfig {
Expand Down Expand Up @@ -339,22 +411,6 @@ impl TryFrom<(HpkeConfig, HpkePrivateKey)> for HpkeReceiverConfig {
}
}

// This let's us use a single config during tests to simplify test code.
#[cfg(any(test, feature = "test-utils"))]
#[async_trait]
impl HpkeDecrypter for HpkeReceiverConfig {
async fn hpke_decrypt(
&self,
_task_id: &TaskId,
info: &[u8],
aad: &[u8],
ciphertext: &HpkeCiphertext,
) -> Result<Vec<u8>, DapError> {
self.config
.decrypt(&self.private_key, info, aad, ciphertext)
}
}

impl std::str::FromStr for HpkeReceiverConfig {
type Err = serde_json::Error;

Expand Down
32 changes: 15 additions & 17 deletions crates/daphne/src/protocol/aggregator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,19 +151,17 @@ impl EarlyReportStateConsumed {
.map_err(DapError::encoding)?;
encode_u32_bytes(&mut aad, &report_share.public_share).map_err(DapError::encoding)?;

let encoded_input_share = match decrypter
.hpke_decrypt(task_id, &info, &aad, &report_share.encrypted_input_share)
.await
{
Ok(encoded_input_share) => encoded_input_share,
Err(DapError::Transition(failure)) => {
return Ok(Self::Rejected {
metadata: report_share.report_metadata,
failure,
})
}
Err(e) => return Err(e),
};
let encoded_input_share =
match decrypter.hpke_decrypt(&info, &aad, &report_share.encrypted_input_share) {
Ok(encoded_input_share) => encoded_input_share,
Err(DapError::Transition(failure)) => {
return Ok(Self::Rejected {
metadata: report_share.report_metadata,
failure,
})
}
Err(e) => return Err(e),
};

let (input_share, extensions) = {
match PlaintextInputShare::get_decoded_with_param(
Expand Down Expand Up @@ -405,7 +403,7 @@ impl DapTaskConfig {
#[allow(clippy::too_many_arguments)]
pub async fn produce_agg_job_req<S>(
&self,
decrypter: &impl HpkeDecrypter,
decrypter: impl HpkeDecrypter,
initializer: &impl DapReportInitializer,
task_id: &TaskId,
part_batch_sel: &PartialBatchSelector,
Expand All @@ -432,7 +430,7 @@ impl DapTaskConfig {
#[allow(clippy::too_many_arguments)]
async fn produce_agg_job_req_impl<S>(
&self,
decrypter: &impl HpkeDecrypter,
decrypter: impl HpkeDecrypter,
initializer: &impl DapReportInitializer,
task_id: &TaskId,
part_batch_sel: &PartialBatchSelector,
Expand Down Expand Up @@ -467,7 +465,7 @@ impl DapTaskConfig {

consumed_reports.push(
EarlyReportStateConsumed::consume(
decrypter,
&decrypter,
initializer,
true,
task_id,
Expand Down Expand Up @@ -556,7 +554,7 @@ impl DapTaskConfig {
#[cfg(any(test, feature = "test-utils"))]
pub async fn test_produce_agg_job_req<S>(
&self,
decrypter: &impl HpkeDecrypter,
decrypter: impl HpkeDecrypter,
initializer: &impl DapReportInitializer,
task_id: &TaskId,
part_batch_sel: &PartialBatchSelector,
Expand Down
4 changes: 1 addition & 3 deletions crates/daphne/src/protocol/collector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,7 @@ impl VdafConfig {
CTX_ROLE_HELPER
};

let agg_share_data = decrypter
.hpke_decrypt(task_id, &info, &aad, agg_share_ciphertext)
.await?;
let agg_share_data = decrypter.hpke_decrypt(&info, &aad, agg_share_ciphertext)?;
agg_shares.push(agg_share_data);
}

Expand Down
2 changes: 1 addition & 1 deletion crates/daphne/src/roles/helper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ pub async fn handle_agg_job_init_req<A: DapHelper>(
let version = req.version;
let initialized_reports = task_config
.consume_agg_job_req(
aggregator,
&aggregator.get_receiver_configs(task_config.version).await?,
aggregator,
&task_id,
req.payload,
Expand Down
3 changes: 2 additions & 1 deletion crates/daphne/src/roles/leader/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -323,9 +323,10 @@ async fn run_agg_job<A: DapLeader>(

// Prepare AggregationJobInitReq.
let agg_job_id = AggregationJobId(thread_rng().gen());
let decrypter = aggregator.get_receiver_configs(task_config.version).await?;
let (agg_job_state, agg_job_init_req) = task_config
.produce_agg_job_req(
aggregator,
decrypter,
aggregator,
task_id,
part_batch_sel,
Expand Down
2 changes: 1 addition & 1 deletion crates/daphne/src/roles/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -460,7 +460,7 @@ mod test {

let (leader_state, agg_job_init_req) = task_config
.produce_agg_job_req(
&*self.leader,
&*self.leader.hpke_receiver_config_list,
&*self.leader,
task_id,
&part_batch_sel,
Expand Down
Loading

0 comments on commit c1cb9c1

Please sign in to comment.