From b26af2b5f392ae9af585facce5c69176ba7f9b85 Mon Sep 17 00:00:00 2001 From: David Cook Date: Thu, 19 Dec 2024 15:36:30 -0600 Subject: [PATCH 1/2] Add method to construct an empty aggregate share --- src/vdaf.rs | 3 +++ src/vdaf/dummy.rs | 8 ++++++-- src/vdaf/mastic.rs | 16 ++++++++++------ src/vdaf/poplar1.rs | 7 +++++++ src/vdaf/prio2.rs | 6 +++++- src/vdaf/prio3.rs | 6 +++++- 6 files changed, 36 insertions(+), 10 deletions(-) diff --git a/src/vdaf.rs b/src/vdaf.rs index 035ef425..26b667be 100644 --- a/src/vdaf.rs +++ b/src/vdaf.rs @@ -308,6 +308,9 @@ pub trait Aggregator: Vda output_shares: M, ) -> Result; + /// Create an empty aggregate share. + fn aggregate_init(&self, agg_param: &Self::AggregationParam) -> Self::AggregateShare; + /// Validates an aggregation parameter with respect to all previous aggregaiton parameters used /// for the same input share. `prev` MUST be sorted from least to most recently used. #[must_use] diff --git a/src/vdaf/dummy.rs b/src/vdaf/dummy.rs index 1a78e3ee..e45dac4b 100644 --- a/src/vdaf/dummy.rs +++ b/src/vdaf/dummy.rs @@ -160,16 +160,20 @@ impl vdaf::Aggregator<0, 16> for Vdaf { fn aggregate>( &self, - _aggregation_param: &Self::AggregationParam, + aggregation_param: &Self::AggregationParam, output_shares: M, ) -> Result { - let mut aggregate_share = AggregateShare(0); + let mut aggregate_share = self.aggregate_init(aggregation_param); for output_share in output_shares { aggregate_share.accumulate(&output_share)?; } Ok(aggregate_share) } + fn aggregate_init(&self, _agg_param: &Self::AggregationParam) -> Self::AggregateShare { + AggregateShare(0) + } + fn is_agg_param_valid(_cur: &Self::AggregationParam, _prev: &[Self::AggregationParam]) -> bool { true } diff --git a/src/vdaf/mastic.rs b/src/vdaf/mastic.rs index 7fb7ba48..f48c5370 100644 --- a/src/vdaf/mastic.rs +++ b/src/vdaf/mastic.rs @@ -646,18 +646,22 @@ where agg_param: &MasticAggregationParam, output_shares: M, ) -> Result, VdafError> { - let mut agg_share = MasticAggregateShare::::from(vec![ + let mut agg_share = self.aggregate_init(agg_param); + for output_share in output_shares.into_iter() { + agg_share.accumulate(&output_share)?; + } + Ok(agg_share) + } + + fn aggregate_init(&self, agg_param: &Self::AggregationParam) -> Self::AggregateShare { + MasticAggregateShare::::from(vec![ T::Field::zero(); self.vidpf.weight_parameter * agg_param .level_and_prefixes .prefixes() .len() - ]); - for output_share in output_shares.into_iter() { - agg_share.accumulate(&output_share)?; - } - Ok(agg_share) + ]) } } diff --git a/src/vdaf/poplar1.rs b/src/vdaf/poplar1.rs index b67c850c..cb9bae18 100644 --- a/src/vdaf/poplar1.rs +++ b/src/vdaf/poplar1.rs @@ -1263,6 +1263,13 @@ impl, const SEED_SIZE: usize> Aggregator ) } + fn aggregate_init(&self, agg_param: &Self::AggregationParam) -> Self::AggregateShare { + Poplar1FieldVec::zero( + usize::from(agg_param.level) == self.bits - 1, + agg_param.prefixes.len(), + ) + } + /// Validates that no aggregation parameter with the same level as `cur` has been used with the /// same input share before. `prev` contains the aggregation parameters used for the same input. /// `prev` MUST be sorted from least to most recently used. diff --git a/src/vdaf/prio2.rs b/src/vdaf/prio2.rs index 96a8f5a3..5b25643c 100644 --- a/src/vdaf/prio2.rs +++ b/src/vdaf/prio2.rs @@ -322,7 +322,7 @@ impl Aggregator<32, 16> for Prio2 { _agg_param: &Self::AggregationParam, out_shares: M, ) -> Result, VdafError> { - let mut agg_share = AggregateShare(vec![FieldPrio2::zero(); self.input_len]); + let mut agg_share = self.aggregate_init(&()); for out_share in out_shares.into_iter() { agg_share.accumulate(&out_share)?; } @@ -330,6 +330,10 @@ impl Aggregator<32, 16> for Prio2 { Ok(agg_share) } + fn aggregate_init(&self, _agg_param: &Self::AggregationParam) -> Self::AggregateShare { + AggregateShare(vec![FieldPrio2::zero(); self.input_len]) + } + /// Returns `true` iff `prev.is_empty()` fn is_agg_param_valid(_cur: &Self::AggregationParam, prev: &[Self::AggregationParam]) -> bool { prev.is_empty() diff --git a/src/vdaf/prio3.rs b/src/vdaf/prio3.rs index 9d0d65b8..4dd18351 100644 --- a/src/vdaf/prio3.rs +++ b/src/vdaf/prio3.rs @@ -1446,7 +1446,7 @@ where _agg_param: &(), output_shares: It, ) -> Result, VdafError> { - let mut agg_share = AggregateShare(vec![T::Field::zero(); self.typ.output_len()]); + let mut agg_share = self.aggregate_init(&()); for output_share in output_shares.into_iter() { agg_share.accumulate(&output_share)?; } @@ -1454,6 +1454,10 @@ where Ok(agg_share) } + fn aggregate_init(&self, _agg_param: &Self::AggregationParam) -> Self::AggregateShare { + AggregateShare(vec![T::Field::zero(); self.typ.output_len()]) + } + /// Returns `true` iff `prev.is_empty()` fn is_agg_param_valid(_cur: &Self::AggregationParam, prev: &[Self::AggregationParam]) -> bool { prev.is_empty() From 7b955611d9feb1f3a7d20bc48988ec40d6aa5d55 Mon Sep 17 00:00:00 2001 From: David Cook Date: Thu, 19 Dec 2024 15:44:24 -0600 Subject: [PATCH 2/2] Add default implementation of aggregate() --- src/vdaf.rs | 8 +++++++- src/vdaf/dummy.rs | 12 ------------ src/vdaf/mastic.rs | 12 ------------ src/vdaf/poplar1.rs | 12 ------------ src/vdaf/prio2.rs | 13 ------------- src/vdaf/prio3.rs | 14 -------------- 6 files changed, 7 insertions(+), 64 deletions(-) diff --git a/src/vdaf.rs b/src/vdaf.rs index 26b667be..2e18b130 100644 --- a/src/vdaf.rs +++ b/src/vdaf.rs @@ -306,7 +306,13 @@ pub trait Aggregator: Vda &self, agg_param: &Self::AggregationParam, output_shares: M, - ) -> Result; + ) -> Result { + let mut share = self.aggregate_init(agg_param); + for output_share in output_shares { + share.accumulate(&output_share)?; + } + Ok(share) + } /// Create an empty aggregate share. fn aggregate_init(&self, agg_param: &Self::AggregationParam) -> Self::AggregateShare; diff --git a/src/vdaf/dummy.rs b/src/vdaf/dummy.rs index e45dac4b..9fec3a68 100644 --- a/src/vdaf/dummy.rs +++ b/src/vdaf/dummy.rs @@ -158,18 +158,6 @@ impl vdaf::Aggregator<0, 16> for Vdaf { (self.prep_step_fn)(&state) } - fn aggregate>( - &self, - aggregation_param: &Self::AggregationParam, - output_shares: M, - ) -> Result { - let mut aggregate_share = self.aggregate_init(aggregation_param); - for output_share in output_shares { - aggregate_share.accumulate(&output_share)?; - } - Ok(aggregate_share) - } - fn aggregate_init(&self, _agg_param: &Self::AggregationParam) -> Self::AggregateShare { AggregateShare(0) } diff --git a/src/vdaf/mastic.rs b/src/vdaf/mastic.rs index f48c5370..9ee7bf31 100644 --- a/src/vdaf/mastic.rs +++ b/src/vdaf/mastic.rs @@ -641,18 +641,6 @@ where Ok(PrepareTransition::Finish(output_shares)) } - fn aggregate>>( - &self, - agg_param: &MasticAggregationParam, - output_shares: M, - ) -> Result, VdafError> { - let mut agg_share = self.aggregate_init(agg_param); - for output_share in output_shares.into_iter() { - agg_share.accumulate(&output_share)?; - } - Ok(agg_share) - } - fn aggregate_init(&self, agg_param: &Self::AggregationParam) -> Self::AggregateShare { MasticAggregateShare::::from(vec![ T::Field::zero(); diff --git a/src/vdaf/poplar1.rs b/src/vdaf/poplar1.rs index cb9bae18..27ed6aa2 100644 --- a/src/vdaf/poplar1.rs +++ b/src/vdaf/poplar1.rs @@ -1251,18 +1251,6 @@ impl, const SEED_SIZE: usize> Aggregator } } - fn aggregate>( - &self, - agg_param: &Poplar1AggregationParam, - output_shares: M, - ) -> Result { - aggregate( - usize::from(agg_param.level) == self.bits - 1, - agg_param.prefixes.len(), - output_shares, - ) - } - fn aggregate_init(&self, agg_param: &Self::AggregationParam) -> Self::AggregateShare { Poplar1FieldVec::zero( usize::from(agg_param.level) == self.bits - 1, diff --git a/src/vdaf/prio2.rs b/src/vdaf/prio2.rs index 5b25643c..dd35e1e3 100644 --- a/src/vdaf/prio2.rs +++ b/src/vdaf/prio2.rs @@ -317,19 +317,6 @@ impl Aggregator<32, 16> for Prio2 { Ok(PrepareTransition::Finish(OutputShare::from(data))) } - fn aggregate>>( - &self, - _agg_param: &Self::AggregationParam, - out_shares: M, - ) -> Result, VdafError> { - let mut agg_share = self.aggregate_init(&()); - for out_share in out_shares.into_iter() { - agg_share.accumulate(&out_share)?; - } - - Ok(agg_share) - } - fn aggregate_init(&self, _agg_param: &Self::AggregationParam) -> Self::AggregateShare { AggregateShare(vec![FieldPrio2::zero(); self.input_len]) } diff --git a/src/vdaf/prio3.rs b/src/vdaf/prio3.rs index 4dd18351..0e2d3c5e 100644 --- a/src/vdaf/prio3.rs +++ b/src/vdaf/prio3.rs @@ -1440,20 +1440,6 @@ where Ok(PrepareTransition::Finish(output_share)) } - /// Aggregates a sequence of output shares into an aggregate share. - fn aggregate>>( - &self, - _agg_param: &(), - output_shares: It, - ) -> Result, VdafError> { - let mut agg_share = self.aggregate_init(&()); - for output_share in output_shares.into_iter() { - agg_share.accumulate(&output_share)?; - } - - Ok(agg_share) - } - fn aggregate_init(&self, _agg_param: &Self::AggregationParam) -> Self::AggregateShare { AggregateShare(vec![T::Field::zero(); self.typ.output_len()]) }