Skip to content

Commit

Permalink
vdaf: Isolate prio_draft09 usage as much as possible
Browse files Browse the repository at this point in the history
Move some methods from the base of the module to where they're used.
This reduces the number of trait collisions there.
  • Loading branch information
cjpatton committed Dec 20, 2024
1 parent 658902e commit b2741f4
Show file tree
Hide file tree
Showing 4 changed files with 111 additions and 105 deletions.
96 changes: 1 addition & 95 deletions crates/daphne/src/vdaf/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,7 @@ use prio_draft09::{
Prio3PrepareShare as Prio3Draft09PrepareShare,
Prio3PrepareState as Prio3Draft09PrepareState,
},
AggregateShare as AggregateShareDraft09, Aggregator as AggregatorDraft09,
Client as ClientDraft09, Collector as CollectorDraft09,
PrepareTransition as PrepareTransitionDraft09, Vdaf as VdafDraft09,
AggregateShare as AggregateShareDraft09,
},
};
use rand::prelude::*;
Expand Down Expand Up @@ -802,95 +800,3 @@ where
}
Ok(vdaf.unshard(&(), agg_shares_vec, num_measurements)?)
}

pub(crate) fn shard_then_encode_draft09<V: VdafDraft09 + ClientDraft09<16>>(
vdaf: &V,
measurement: &V::Measurement,
nonce: &[u8; 16],
) -> Result<(Vec<u8>, [Vec<u8>; 2]), VdafError> {
let (public_share, input_shares) = vdaf.shard(measurement, nonce)?;

Ok((
public_share.get_encoded()?,
input_shares
.iter()
.map(|input_share| input_share.get_encoded())
.collect::<Result<Vec<_>, _>>()?
.try_into()
.map_err(|e: Vec<_>| {
VdafError::Dap(fatal_error!(
err = format!("expected 2 input shares got {}", e.len())
))
})?,
))
}

fn prep_finish_from_shares_draft09<V, const VERIFY_KEY_SIZE: usize, const NONCE_SIZE: usize>(
vdaf: &V,
agg_id: usize,
host_state: V::PrepareState,
host_share: V::PrepareShare,
peer_share_data: &[u8],
) -> Result<(V::OutputShare, Vec<u8>), VdafError>
where
V: VdafDraft09<AggregationParam = ()> + AggregatorDraft09<VERIFY_KEY_SIZE, NONCE_SIZE>,
{
// Decode the Helper's inbound message.
let peer_share = V::PrepareShare::get_decoded_with_param(&host_state, peer_share_data)?;

// Preprocess the inbound messages.
let message = vdaf.prepare_shares_to_prepare_message(
&(),
if agg_id == 0 {
[host_share, peer_share]
} else {
[peer_share, host_share]
},
)?;
let message_data = message.get_encoded()?;

// Compute the host's output share.
match vdaf.prepare_next(host_state, message)? {
PrepareTransitionDraft09::Continue(..) => Err(VdafError::Dap(fatal_error!(
err = format!("prep_finish_from_shares: unexpected transition")
))),
PrepareTransitionDraft09::Finish(out_share) => Ok((out_share, message_data)),
}
}

fn prep_finish_draft09<V, const VERIFY_KEY_SIZE: usize, const NONCE_SIZE: usize>(
vdaf: &V,
host_state: V::PrepareState,
peer_message_data: &[u8],
) -> Result<V::OutputShare, VdafError>
where
V: VdafDraft09 + AggregatorDraft09<VERIFY_KEY_SIZE, NONCE_SIZE>,
{
// Decode the inbound message from the peer, which contains the preprocessed prepare message.
let peer_message = V::PrepareMessage::get_decoded_with_param(&host_state, peer_message_data)?;

// Compute the host's output share.
match vdaf.prepare_next(host_state, peer_message)? {
PrepareTransitionDraft09::Continue(..) => Err(VdafError::Dap(fatal_error!(
err = format!("prep_finish: unexpected transition"),
))),
PrepareTransitionDraft09::Finish(out_share) => Ok(out_share),
}
}

fn unshard_draft09<V, M>(
vdaf: &V,
num_measurements: usize,
agg_shares: M,
) -> Result<V::AggregateResult, VdafError>
where
V: VdafDraft09<AggregationParam = ()> + CollectorDraft09,
M: IntoIterator<Item = Vec<u8>>,
{
let mut agg_shares_vec = Vec::with_capacity(vdaf.num_aggregators());
for data in agg_shares {
let agg_share = V::AggregateShare::get_decoded_with_param(&(vdaf, &()), data.as_ref())?;
agg_shares_vec.push(agg_share);
}
Ok(vdaf.unshard(&(), agg_shares_vec, num_measurements)?)
}
7 changes: 5 additions & 2 deletions crates/daphne/src/vdaf/pine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,11 @@ use crate::{
};

use super::{
prep_finish_draft09, prep_finish_from_shares_draft09, shard_then_encode_draft09,
unshard_draft09, VdafAggregateShare, VdafError, VdafPrepShare, VdafPrepState, VdafVerifyKey,
prio3_draft09::{
prep_finish_draft09, prep_finish_from_shares_draft09, shard_then_encode_draft09,
unshard_draft09,
},
VdafAggregateShare, VdafError, VdafPrepShare, VdafPrepState, VdafVerifyKey,
};
use prio_draft09::{
codec::ParameterizedDecode,
Expand Down
13 changes: 7 additions & 6 deletions crates/daphne/src/vdaf/prio3.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@ use crate::{
VdafPrepState,
};

use super::{prep_finish, prep_finish_from_shares, shard_then_encode, unshard};

use prio::{
codec::ParameterizedDecode,
flp::Type,
Expand All @@ -23,6 +21,8 @@ use prio::{
},
};

use super::{prep_finish, prep_finish_from_shares, shard_then_encode, unshard};

const CTX_STRING_PREFIX: &[u8] = b"dap-13";

impl Prio3Config {
Expand Down Expand Up @@ -96,7 +96,7 @@ impl Prio3Config {
*chunk_length,
*num_proofs,
)?;
super::shard_then_encode_draft09(&vdaf, &measurement, nonce)
prio3_draft09::shard_then_encode_draft09(&vdaf, &measurement, nonce)
}
_ => Err(VdafError::Dap(fatal_error!(
err =
Expand Down Expand Up @@ -374,7 +374,7 @@ impl Prio3Config {
*chunk_length,
*num_proofs,
)?;
let (out_share, outbound) = super::prep_finish_from_shares_draft09(
let (out_share, outbound) = prio3_draft09::prep_finish_from_shares_draft09(
&vdaf,
agg_id,
state,
Expand Down Expand Up @@ -469,7 +469,8 @@ impl Prio3Config {
*chunk_length,
*num_proofs,
)?;
let out_share = super::prep_finish_draft09(&vdaf, state, peer_message_data)?;
let out_share =
prio3_draft09::prep_finish_draft09(&vdaf, state, peer_message_data)?;
VdafAggregateShare::Field64Draft09(prio_draft09::vdaf::Aggregator::aggregate(
&vdaf,
&(),
Expand Down Expand Up @@ -555,7 +556,7 @@ impl Prio3Config {
*chunk_length,
*num_proofs,
)?;
let agg_res = super::unshard_draft09(&vdaf, num_measurements, agg_shares)?;
let agg_res = prio3_draft09::unshard_draft09(&vdaf, num_measurements, agg_shares)?;
Ok(DapAggregateResult::U64Vec(agg_res))
}
_ => Err(VdafError::Dap(fatal_error!(
Expand Down
100 changes: 98 additions & 2 deletions crates/daphne/src/vdaf/prio3_draft09.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use crate::{
};

use prio_draft09::{
codec::ParameterizedDecode,
codec::{Encode, ParameterizedDecode},
field::Field64,
flp::{
gadgets::{Mul, ParallelSum},
Expand All @@ -19,7 +19,7 @@ use prio_draft09::{
vdaf::{
prio3::{Prio3, Prio3InputShare, Prio3PrepareShare, Prio3PrepareState, Prio3PublicShare},
xof::{Xof, XofHmacSha256Aes128},
Aggregator,
Aggregator, Client, Collector, PrepareTransition, Vdaf,
},
};

Expand Down Expand Up @@ -51,6 +51,28 @@ type Prio3Draft09Prepared<T, const SEED_SIZE: usize> = (
Prio3PrepareShare<<T as Type>::Field, SEED_SIZE>,
);

pub(crate) fn shard_then_encode_draft09<V: Vdaf + Client<16>>(
vdaf: &V,
measurement: &V::Measurement,
nonce: &[u8; 16],
) -> Result<(Vec<u8>, [Vec<u8>; 2]), VdafError> {
let (public_share, input_shares) = vdaf.shard(measurement, nonce)?;

Ok((
public_share.get_encoded()?,
input_shares
.iter()
.map(|input_share| input_share.get_encoded())
.collect::<Result<Vec<_>, _>>()?
.try_into()
.map_err(|e: Vec<_>| {
VdafError::Dap(fatal_error!(
err = format!("expected 2 input shares got {}", e.len())
))
})?,
))
}

pub(crate) fn prep_init_draft09<T, P, const SEED_SIZE: usize>(
vdaf: Prio3<T, P, SEED_SIZE>,
verify_key: &[u8; SEED_SIZE],
Expand All @@ -73,6 +95,80 @@ where
Ok(vdaf.prepare_init(verify_key, agg_id, &(), nonce, &public_share, &input_share)?)
}

pub(crate) fn prep_finish_from_shares_draft09<
V,
const VERIFY_KEY_SIZE: usize,
const NONCE_SIZE: usize,
>(
vdaf: &V,
agg_id: usize,
host_state: V::PrepareState,
host_share: V::PrepareShare,
peer_share_data: &[u8],
) -> Result<(V::OutputShare, Vec<u8>), VdafError>
where
V: Vdaf<AggregationParam = ()> + Aggregator<VERIFY_KEY_SIZE, NONCE_SIZE>,
{
// Decode the Helper's inbound message.
let peer_share = V::PrepareShare::get_decoded_with_param(&host_state, peer_share_data)?;

// Preprocess the inbound messages.
let message = vdaf.prepare_shares_to_prepare_message(
&(),
if agg_id == 0 {
[host_share, peer_share]
} else {
[peer_share, host_share]
},
)?;
let message_data = message.get_encoded()?;

// Compute the host's output share.
match vdaf.prepare_next(host_state, message)? {
PrepareTransition::Continue(..) => Err(VdafError::Dap(fatal_error!(
err = format!("prep_finish_from_shares: unexpected transition")
))),
PrepareTransition::Finish(out_share) => Ok((out_share, message_data)),
}
}

pub(crate) fn prep_finish_draft09<V, const VERIFY_KEY_SIZE: usize, const NONCE_SIZE: usize>(
vdaf: &V,
host_state: V::PrepareState,
peer_message_data: &[u8],
) -> Result<V::OutputShare, VdafError>
where
V: Vdaf + Aggregator<VERIFY_KEY_SIZE, NONCE_SIZE>,
{
// Decode the inbound message from the peer, which contains the preprocessed prepare message.
let peer_message = V::PrepareMessage::get_decoded_with_param(&host_state, peer_message_data)?;

// Compute the host's output share.
match vdaf.prepare_next(host_state, peer_message)? {
PrepareTransition::Continue(..) => Err(VdafError::Dap(fatal_error!(
err = format!("prep_finish: unexpected transition"),
))),
PrepareTransition::Finish(out_share) => Ok(out_share),
}
}

pub(crate) fn unshard_draft09<V, M>(
vdaf: &V,
num_measurements: usize,
agg_shares: M,
) -> Result<V::AggregateResult, VdafError>
where
V: Vdaf<AggregationParam = ()> + Collector,
M: IntoIterator<Item = Vec<u8>>,
{
let mut agg_shares_vec = Vec::with_capacity(vdaf.num_aggregators());
for data in agg_shares {
let agg_share = V::AggregateShare::get_decoded_with_param(&(vdaf, &()), data.as_ref())?;
agg_shares_vec.push(agg_share);
}
Ok(vdaf.unshard(&(), agg_shares_vec, num_measurements)?)
}

#[cfg(test)]
mod test {

Expand Down

0 comments on commit b2741f4

Please sign in to comment.