Skip to content

Commit

Permalink
vdaf: Isolate prio_draft09 usage as much as possible
Browse files Browse the repository at this point in the history
Move some methods from the base of the module to where they're used.
This reduces the number of trait collisions there.

Also, rename `prio3_draft09` to `draft09`, as the same code is shared by
`pine`.
  • Loading branch information
cjpatton committed Dec 20, 2024
1 parent 5ac2daf commit 9c79654
Show file tree
Hide file tree
Showing 4 changed files with 126 additions and 137 deletions.
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
// Copyright (c) 2022 Cloudflare, Inc. All rights reserved.
// Copyright (c) 2024 Cloudflare, Inc. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause

//! Parameters for the [Prio3 VDAF](https://datatracker.ietf.org/doc/draft-irtf-cfrg-vdaf/09/).
//! DAP-09 compatibility.
use crate::{
fatal_error, messages::taskprov::VDAF_TYPE_PRIO3_SUM_VEC_FIELD64_MULTIPROOF_HMAC_SHA256_AES128,
vdaf::VdafError,
};

use prio_draft09::{
codec::ParameterizedDecode,
codec::{Encode, ParameterizedDecode},
field::Field64,
flp::{
gadgets::{Mul, ParallelSum},
Expand All @@ -19,7 +19,7 @@ use prio_draft09::{
vdaf::{
prio3::{Prio3, Prio3InputShare, Prio3PrepareShare, Prio3PrepareState, Prio3PublicShare},
xof::{Xof, XofHmacSha256Aes128},
Aggregator,
Aggregator, Client, Collector, PrepareTransition, Vdaf,
},
};

Expand Down Expand Up @@ -51,7 +51,29 @@ type Prio3Draft09Prepared<T, const SEED_SIZE: usize> = (
Prio3PrepareShare<<T as Type>::Field, SEED_SIZE>,
);

pub(crate) fn prep_init_draft09<T, P, const SEED_SIZE: usize>(
pub(crate) fn shard_then_encode<V: Vdaf + Client<16>>(
vdaf: &V,
measurement: &V::Measurement,
nonce: &[u8; 16],
) -> Result<(Vec<u8>, [Vec<u8>; 2]), VdafError> {
let (public_share, input_shares) = vdaf.shard(measurement, nonce)?;

Ok((
public_share.get_encoded()?,
input_shares
.iter()
.map(|input_share| input_share.get_encoded())
.collect::<Result<Vec<_>, _>>()?
.try_into()
.map_err(|e: Vec<_>| {
VdafError::Dap(fatal_error!(
err = format!("expected 2 input shares got {}", e.len())
))
})?,
))
}

pub(crate) fn prep_init<T, P, const SEED_SIZE: usize>(
vdaf: Prio3<T, P, SEED_SIZE>,
verify_key: &[u8; SEED_SIZE],
agg_id: usize,
Expand All @@ -73,18 +95,87 @@ where
Ok(vdaf.prepare_init(verify_key, agg_id, &(), nonce, &public_share, &input_share)?)
}

pub(crate) fn prep_finish_from_shares<V, const VERIFY_KEY_SIZE: usize, const NONCE_SIZE: usize>(
vdaf: &V,
agg_id: usize,
host_state: V::PrepareState,
host_share: V::PrepareShare,
peer_share_data: &[u8],
) -> Result<(V::OutputShare, Vec<u8>), VdafError>
where
V: Vdaf<AggregationParam = ()> + Aggregator<VERIFY_KEY_SIZE, NONCE_SIZE>,
{
// Decode the Helper's inbound message.
let peer_share = V::PrepareShare::get_decoded_with_param(&host_state, peer_share_data)?;

// Preprocess the inbound messages.
let message = vdaf.prepare_shares_to_prepare_message(
&(),
if agg_id == 0 {
[host_share, peer_share]
} else {
[peer_share, host_share]
},
)?;
let message_data = message.get_encoded()?;

// Compute the host's output share.
match vdaf.prepare_next(host_state, message)? {
PrepareTransition::Continue(..) => Err(VdafError::Dap(fatal_error!(
err = format!("prep_finish_from_shares: unexpected transition")
))),
PrepareTransition::Finish(out_share) => Ok((out_share, message_data)),
}
}

pub(crate) fn prep_finish<V, const VERIFY_KEY_SIZE: usize, const NONCE_SIZE: usize>(
vdaf: &V,
host_state: V::PrepareState,
peer_message_data: &[u8],
) -> Result<V::OutputShare, VdafError>
where
V: Vdaf + Aggregator<VERIFY_KEY_SIZE, NONCE_SIZE>,
{
// Decode the inbound message from the peer, which contains the preprocessed prepare message.
let peer_message = V::PrepareMessage::get_decoded_with_param(&host_state, peer_message_data)?;

// Compute the host's output share.
match vdaf.prepare_next(host_state, peer_message)? {
PrepareTransition::Continue(..) => Err(VdafError::Dap(fatal_error!(
err = format!("prep_finish: unexpected transition"),
))),
PrepareTransition::Finish(out_share) => Ok(out_share),
}
}

pub(crate) fn unshard<V, M>(
vdaf: &V,
num_measurements: usize,
agg_shares: M,
) -> Result<V::AggregateResult, VdafError>
where
V: Vdaf<AggregationParam = ()> + Collector,
M: IntoIterator<Item = Vec<u8>>,
{
let mut agg_shares_vec = Vec::with_capacity(vdaf.num_aggregators());
for data in agg_shares {
let agg_share = V::AggregateShare::get_decoded_with_param(&(vdaf, &()), data.as_ref())?;
agg_shares_vec.push(agg_share);
}
Ok(vdaf.unshard(&(), agg_shares_vec, num_measurements)?)
}

#[cfg(test)]
mod test {

use prio_draft09::vdaf::prio3_test::check_test_vec;

use super::*;

use crate::{
hpke::HpkeKemId,
testing::AggregationJobTest,
vdaf::{
prio3_draft09::new_prio3_sum_vec_field64_multiproof_hmac_sha256_aes128, Prio3Config,
VdafConfig,
},
vdaf::{Prio3Config, VdafConfig},
DapAggregateResult, DapAggregationParam, DapMeasurement, DapVersion,
};

Expand Down Expand Up @@ -114,7 +205,7 @@ mod test {
}

#[test]
fn test_vec_sum_vec_field64_multiproof_hmac_sha256_aes128() {
fn test_vec_sum_vec_field64_multiproof_hmac_sha256_aes128_draft09() {
for test_vec_json_str in [
include_str!("test_vec/Prio3SumVecField64MultiproofHmacSha256Aes128_0.json"),
include_str!("test_vec/Prio3SumVecField64MultiproofHmacSha256Aes128_1.json"),
Expand Down
98 changes: 2 additions & 96 deletions crates/daphne/src/vdaf/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@
//! Verifiable, Distributed Aggregation Functions
//! ([VDAFs](https://datatracker.ietf.org/doc/draft-irtf-cfrg-vdaf/)).
pub(crate) mod draft09;
#[cfg(feature = "experimental")]
pub(crate) mod mastic;
pub(crate) mod pine;
pub(crate) mod prio2;
pub(crate) mod prio3;
pub(crate) mod prio3_draft09;

use crate::pine::vdaf::PinePrepState;
use crate::{fatal_error, messages::TaskId, DapError};
Expand Down Expand Up @@ -40,9 +40,7 @@ use prio_draft09::{
Prio3PrepareShare as Prio3Draft09PrepareShare,
Prio3PrepareState as Prio3Draft09PrepareState,
},
AggregateShare as AggregateShareDraft09, Aggregator as AggregatorDraft09,
Client as ClientDraft09, Collector as CollectorDraft09,
PrepareTransition as PrepareTransitionDraft09, Vdaf as VdafDraft09,
AggregateShare as AggregateShareDraft09,
},
};
use rand::prelude::*;
Expand Down Expand Up @@ -783,95 +781,3 @@ where
}
Ok(vdaf.unshard(&(), agg_shares_vec, num_measurements)?)
}

pub(crate) fn shard_then_encode_draft09<V: VdafDraft09 + ClientDraft09<16>>(
vdaf: &V,
measurement: &V::Measurement,
nonce: &[u8; 16],
) -> Result<(Vec<u8>, [Vec<u8>; 2]), VdafError> {
let (public_share, input_shares) = vdaf.shard(measurement, nonce)?;

Ok((
public_share.get_encoded()?,
input_shares
.iter()
.map(|input_share| input_share.get_encoded())
.collect::<Result<Vec<_>, _>>()?
.try_into()
.map_err(|e: Vec<_>| {
VdafError::Dap(fatal_error!(
err = format!("expected 2 input shares got {}", e.len())
))
})?,
))
}

fn prep_finish_from_shares_draft09<V, const VERIFY_KEY_SIZE: usize, const NONCE_SIZE: usize>(
vdaf: &V,
agg_id: usize,
host_state: V::PrepareState,
host_share: V::PrepareShare,
peer_share_data: &[u8],
) -> Result<(V::OutputShare, Vec<u8>), VdafError>
where
V: VdafDraft09<AggregationParam = ()> + AggregatorDraft09<VERIFY_KEY_SIZE, NONCE_SIZE>,
{
// Decode the Helper's inbound message.
let peer_share = V::PrepareShare::get_decoded_with_param(&host_state, peer_share_data)?;

// Preprocess the inbound messages.
let message = vdaf.prepare_shares_to_prepare_message(
&(),
if agg_id == 0 {
[host_share, peer_share]
} else {
[peer_share, host_share]
},
)?;
let message_data = message.get_encoded()?;

// Compute the host's output share.
match vdaf.prepare_next(host_state, message)? {
PrepareTransitionDraft09::Continue(..) => Err(VdafError::Dap(fatal_error!(
err = format!("prep_finish_from_shares: unexpected transition")
))),
PrepareTransitionDraft09::Finish(out_share) => Ok((out_share, message_data)),
}
}

fn prep_finish_draft09<V, const VERIFY_KEY_SIZE: usize, const NONCE_SIZE: usize>(
vdaf: &V,
host_state: V::PrepareState,
peer_message_data: &[u8],
) -> Result<V::OutputShare, VdafError>
where
V: VdafDraft09 + AggregatorDraft09<VERIFY_KEY_SIZE, NONCE_SIZE>,
{
// Decode the inbound message from the peer, which contains the preprocessed prepare message.
let peer_message = V::PrepareMessage::get_decoded_with_param(&host_state, peer_message_data)?;

// Compute the host's output share.
match vdaf.prepare_next(host_state, peer_message)? {
PrepareTransitionDraft09::Continue(..) => Err(VdafError::Dap(fatal_error!(
err = format!("prep_finish: unexpected transition"),
))),
PrepareTransitionDraft09::Finish(out_share) => Ok(out_share),
}
}

fn unshard_draft09<V, M>(
vdaf: &V,
num_measurements: usize,
agg_shares: M,
) -> Result<V::AggregateResult, VdafError>
where
V: VdafDraft09<AggregationParam = ()> + CollectorDraft09,
M: IntoIterator<Item = Vec<u8>>,
{
let mut agg_shares_vec = Vec::with_capacity(vdaf.num_aggregators());
for data in agg_shares {
let agg_share = V::AggregateShare::get_decoded_with_param(&(vdaf, &()), data.as_ref())?;
agg_shares_vec.push(agg_share);
}
Ok(vdaf.unshard(&(), agg_shares_vec, num_measurements)?)
}
21 changes: 9 additions & 12 deletions crates/daphne/src/vdaf/pine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,7 @@ use crate::{
DapAggregateResult, DapMeasurement,
};

use super::{
prep_finish_draft09, prep_finish_from_shares_draft09, shard_then_encode_draft09,
unshard_draft09, VdafAggregateShare, VdafError, VdafPrepShare, VdafPrepState, VdafVerifyKey,
};
use super::{draft09, VdafAggregateShare, VdafError, VdafPrepShare, VdafPrepState, VdafVerifyKey};
use prio_draft09::{
codec::ParameterizedDecode,
field::{FftFriendlyFieldElement, Field64, FieldPrio2},
Expand Down Expand Up @@ -71,11 +68,11 @@ impl PineConfig {
match self {
PineConfig::Field32HmacSha256Aes128 { param } => {
let vdaf = pine32_hmac_sha256_aes128(param)?;
shard_then_encode_draft09(&vdaf, gradient, nonce)
draft09::shard_then_encode(&vdaf, gradient, nonce)
}
PineConfig::Field64HmacSha256Aes128 { param } => {
let vdaf = pine64_hmac_sha256_aes128(param)?;
shard_then_encode_draft09(&vdaf, gradient, nonce)
draft09::shard_then_encode(&vdaf, gradient, nonce)
}
}
}
Expand Down Expand Up @@ -140,7 +137,7 @@ impl PineConfig {
) => {
let vdaf = pine32_hmac_sha256_aes128(param)?;
let (out_share, outbound) =
prep_finish_from_shares_draft09(&vdaf, agg_id, state, share, peer_share_data)?;
draft09::prep_finish_from_shares(&vdaf, agg_id, state, share, peer_share_data)?;
let agg_share = VdafAggregateShare::Field32Draft09(AggregateShare::from(
OutputShare::from(out_share.0),
));
Expand All @@ -153,7 +150,7 @@ impl PineConfig {
) => {
let vdaf = pine64_hmac_sha256_aes128(param)?;
let (out_share, outbound) =
prep_finish_from_shares_draft09(&vdaf, agg_id, state, share, peer_share_data)?;
draft09::prep_finish_from_shares(&vdaf, agg_id, state, share, peer_share_data)?;
let agg_share = VdafAggregateShare::Field64Draft09(AggregateShare::from(
OutputShare::from(out_share.0),
));
Expand All @@ -176,7 +173,7 @@ impl PineConfig {
VdafPrepState::Pine32HmacSha256Aes128(state),
) => {
let vdaf = pine32_hmac_sha256_aes128(param)?;
let out_share = prep_finish_draft09(&vdaf, state, peer_message_data)?;
let out_share = draft09::prep_finish(&vdaf, state, peer_message_data)?;
let agg_share = VdafAggregateShare::Field32Draft09(AggregateShare::from(
OutputShare::from(out_share.0),
));
Expand All @@ -187,7 +184,7 @@ impl PineConfig {
VdafPrepState::Pine64HmacSha256Aes128(state),
) => {
let vdaf = pine64_hmac_sha256_aes128(param)?;
let out_share = prep_finish_draft09(&vdaf, state, peer_message_data)?;
let out_share = draft09::prep_finish(&vdaf, state, peer_message_data)?;
let agg_share = VdafAggregateShare::Field64Draft09(AggregateShare::from(
OutputShare::from(out_share.0),
));
Expand All @@ -207,12 +204,12 @@ impl PineConfig {
match self {
PineConfig::Field32HmacSha256Aes128 { param } => {
let vdaf = pine32_hmac_sha256_aes128(param)?;
let agg_res = unshard_draft09(&vdaf, num_measurements, agg_shares)?;
let agg_res = draft09::unshard(&vdaf, num_measurements, agg_shares)?;
Ok(DapAggregateResult::F64Vec(agg_res))
}
PineConfig::Field64HmacSha256Aes128 { param } => {
let vdaf = pine64_hmac_sha256_aes128(param)?;
let agg_res = unshard_draft09(&vdaf, num_measurements, agg_shares)?;
let agg_res = draft09::unshard(&vdaf, num_measurements, agg_shares)?;
Ok(DapAggregateResult::F64Vec(agg_res))
}
}
Expand Down
Loading

0 comments on commit 9c79654

Please sign in to comment.