From 4d8041fdf5ddd946b8495c5c9742d590e8c595e2 Mon Sep 17 00:00:00 2001 From: Andrew Schran Date: Wed, 13 Dec 2023 13:59:16 -0500 Subject: [PATCH] ThresholdBls: accept Iterator directly where possible (#709) Instead of requiring a slice that we immediately and only call `iter()` on, accept the Iterator. This can enable clients to avoid extra copies. --- fastcrypto-tbls/benches/tbls.rs | 6 +-- fastcrypto-tbls/src/nidkg.rs | 5 +- fastcrypto-tbls/src/polynomial.rs | 50 ++++++++++-------- fastcrypto-tbls/src/tbls.rs | 51 +++++++++---------- fastcrypto-tbls/src/tests/dkg_tests.rs | 2 +- fastcrypto-tbls/src/tests/polynomial_tests.rs | 37 +++++++------- fastcrypto-tbls/src/tests/tbls_tests.rs | 32 ++++++------ 7 files changed, 94 insertions(+), 89 deletions(-) diff --git a/fastcrypto-tbls/benches/tbls.rs b/fastcrypto-tbls/benches/tbls.rs index 5a0d356028..170f272fea 100644 --- a/fastcrypto-tbls/benches/tbls.rs +++ b/fastcrypto-tbls/benches/tbls.rs @@ -25,7 +25,7 @@ mod tbls_benches { .collect::>(); create.bench_function(format!("w={}", w).as_str(), |b| { - b.iter(|| ThresholdBls12381MinSig::partial_sign_batch(&shares, msg)) + b.iter(|| ThresholdBls12381MinSig::partial_sign_batch(shares.iter(), msg)) }); } } @@ -39,10 +39,10 @@ mod tbls_benches { .map(|i| private_poly.eval(NonZeroU32::new(i as u32).unwrap())) .collect::>(); - let sigs = ThresholdBls12381MinSig::partial_sign_batch(&shares, msg); + let sigs = ThresholdBls12381MinSig::partial_sign_batch(shares.iter(), msg); create.bench_function(format!("w={}", w).as_str(), |b| { - b.iter(|| ThresholdBls12381MinSig::aggregate(w as u32, &sigs).unwrap()) + b.iter(|| ThresholdBls12381MinSig::aggregate(w as u32, sigs.iter()).unwrap()) }); } } diff --git a/fastcrypto-tbls/src/nidkg.rs b/fastcrypto-tbls/src/nidkg.rs index 2239e15b83..e6cdfb878d 100644 --- a/fastcrypto-tbls/src/nidkg.rs +++ b/fastcrypto-tbls/src/nidkg.rs @@ -332,9 +332,8 @@ where .map(|(i, pk)| Eval { index: NonZeroU32::new((i + 1) as u32).expect("non zero"), value: *pk, - }) - .collect::>>(); - let pk = Poly::::recover_c0(self.t, &evals).expect("enough shares"); + }); + let pk = Poly::::recover_c0(self.t, evals).expect("enough shares"); (pk, partial_pks) } diff --git a/fastcrypto-tbls/src/polynomial.rs b/fastcrypto-tbls/src/polynomial.rs index bb0818daec..4093796cb7 100644 --- a/fastcrypto-tbls/src/polynomial.rs +++ b/fastcrypto-tbls/src/polynomial.rs @@ -10,6 +10,7 @@ use fastcrypto::error::{FastCryptoError, FastCryptoResult}; use fastcrypto::groups::{GroupElement, MultiScalarMul, Scalar}; use fastcrypto::traits::AllowedRng; use serde::{Deserialize, Serialize}; +use std::borrow::Borrow; use std::collections::HashSet; /// Types @@ -81,22 +82,25 @@ impl Poly { // Expects exactly t unique shares. fn get_lagrange_coefficients_for_c0( t: u32, - shares: &[Eval], + mut shares: impl Iterator>>, ) -> FastCryptoResult> { - if shares.len() != t as usize { - return Err(FastCryptoError::InvalidInput); - } - // Check for duplicates. let mut ids_set = HashSet::new(); - if !shares.iter().map(|s| &s.index).all(|id| ids_set.insert(id)) { - return Err(FastCryptoError::InvalidInput); // expected unique ids + let (shares_size_lower, shares_size_upper) = shares.size_hint(); + let indices = shares.try_fold( + Vec::with_capacity(shares_size_upper.unwrap_or(shares_size_lower)), + |mut vec, s| { + // Check for duplicates. + if !ids_set.insert(s.borrow().index) { + return Err(FastCryptoError::InvalidInput); // expected unique ids + } + vec.push(C::ScalarType::from(s.borrow().index.get() as u64)); + Ok(vec) + }, + )?; + if indices.len() != t as usize { + return Err(FastCryptoError::InvalidInput); } - let indices = shares - .iter() - .map(|s| C::ScalarType::from(s.index.get() as u64)) - .collect::>(); - let full_numerator = indices .iter() .fold(C::ScalarType::generator(), |acc, i| acc * i); @@ -113,13 +117,16 @@ impl Poly { } /// Given exactly `t` polynomial evaluations, it will recover the polynomial's constant term. - pub fn recover_c0(t: u32, shares: &[Eval]) -> Result { - let coeffs = Self::get_lagrange_coefficients_for_c0(t, shares)?; - let plain_shares = shares.iter().map(|s| s.value).collect::>(); + pub fn recover_c0( + t: u32, + shares: impl Iterator>> + Clone, + ) -> Result { + let coeffs = Self::get_lagrange_coefficients_for_c0(t, shares.clone())?; + let plain_shares = shares.map(|s| s.borrow().value); let res = coeffs .iter() - .zip(plain_shares.iter()) - .fold(C::zero(), |acc, (c, s)| acc + (*s * *c)); + .zip(plain_shares) + .fold(C::zero(), |acc, (c, s)| acc + (s * *c)); Ok(res) } @@ -172,9 +179,12 @@ impl Poly { impl Poly { /// Given exactly `t` polynomial evaluations, it will recover the polynomial's /// constant term. - pub fn recover_c0_msm(t: u32, shares: &[Eval]) -> Result { - let coeffs = Self::get_lagrange_coefficients_for_c0(t, shares)?; - let plain_shares = shares.iter().map(|s| s.value).collect::>(); + pub fn recover_c0_msm( + t: u32, + shares: impl Iterator>> + Clone, + ) -> Result { + let coeffs = Self::get_lagrange_coefficients_for_c0(t, shares.clone())?; + let plain_shares = shares.map(|s| s.borrow().value).collect::>(); let res = C::multi_scalar_mul(&coeffs, &plain_shares).expect("sizes match"); Ok(res) } diff --git a/fastcrypto-tbls/src/tbls.rs b/fastcrypto-tbls/src/tbls.rs index 93e96b6251..4df61e3d89 100644 --- a/fastcrypto-tbls/src/tbls.rs +++ b/fastcrypto-tbls/src/tbls.rs @@ -4,6 +4,8 @@ // Some of the code below is based on code from https://github.com/celo-org/celo-threshold-bls-rs, // modified for our needs. +use std::borrow::Borrow; + use crate::dl_verification::{batch_coefficients, get_random_scalars}; use crate::polynomial::Poly; use crate::types::IndexedValue; @@ -29,20 +31,22 @@ pub trait ThresholdBls { /// Sign a message using the private share/partial key. fn partial_sign(share: &Share, msg: &[u8]) -> PartialSignature { - Self::partial_sign_batch(&[share.clone()], msg)[0].clone() + Self::partial_sign_batch(std::iter::once(share), msg)[0].clone() } /// Sign a message using one of more private share/partial keys. fn partial_sign_batch( - shares: &[Share], + shares: impl Iterator>>, msg: &[u8], ) -> Vec> { let h = Self::Signature::hash_to_group_element(msg); shares - .iter() - .map(|share| PartialSignature { - index: share.index, - value: h * share.value, + .map(|share| { + let share = share.borrow(); + PartialSignature { + index: share.index, + value: h * share.value, + } }) .collect() } @@ -63,26 +67,24 @@ pub trait ThresholdBls { fn partial_verify_batch( vss_pk: &Poly, msg: &[u8], - partial_sigs: &[PartialSignature], + partial_sigs: impl Iterator>>, rng: &mut R, ) -> FastCryptoResult<()> { assert!(vss_pk.degree() > 0 || !msg.is_empty()); - if partial_sigs.is_empty() { + let (evals_as_scalars, points): (Vec<_>, Vec<_>) = partial_sigs + .map(|sig| { + let sig = sig.borrow(); + (Self::Private::from(sig.index.get().into()), sig.value) + }) + .unzip(); + if points.is_empty() { return Ok(()); } - let rs = get_random_scalars::(partial_sigs.len() as u32, rng); - let evals_as_scalars = partial_sigs - .iter() - .map(|e| Self::Private::from(e.index.get().into())) - .collect::>(); + let rs = get_random_scalars::(points.len() as u32, rng); // TODO: should we cache it instead? that would replace t-wide msm with w-wide msm. let coeffs = batch_coefficients(&rs, &evals_as_scalars, vss_pk.degree()); let pk = Self::Public::multi_scalar_mul(&coeffs, vss_pk.as_vec()).expect("sizes match"); - let aggregated_sig = Self::Signature::multi_scalar_mul( - &rs, - &partial_sigs.iter().map(|s| s.value).collect::>(), - ) - .expect("sizes match"); + let aggregated_sig = Self::Signature::multi_scalar_mul(&rs, &points).expect("sizes match"); Self::verify(&pk, msg, &aggregated_sig) } @@ -90,19 +92,16 @@ pub trait ThresholdBls { /// Interpolate partial signatures to recover the full signature. fn aggregate( threshold: u32, - partials: &[PartialSignature], + partials: impl Iterator>> + Clone, ) -> FastCryptoResult { let unique_partials = partials - .iter() - .unique_by(|p| p.index) - .take(threshold as usize) - .cloned() - .collect::>(); - if unique_partials.len() != threshold as usize { + .unique_by(|p| p.borrow().index) + .take(threshold as usize); + if unique_partials.clone().count() != threshold as usize { return Err(FastCryptoError::NotEnoughInputs); } // No conversion is required since PartialSignature and Eval are different aliases to // IndexedValue. - Poly::::recover_c0_msm(threshold, &unique_partials) + Poly::::recover_c0_msm(threshold, unique_partials) } } diff --git a/fastcrypto-tbls/src/tests/dkg_tests.rs b/fastcrypto-tbls/src/tests/dkg_tests.rs index 92f4efc7f4..d79a619fcc 100644 --- a/fastcrypto-tbls/src/tests/dkg_tests.rs +++ b/fastcrypto-tbls/src/tests/dkg_tests.rs @@ -261,7 +261,7 @@ fn test_dkg_e2e_5_parties_min_weight_2_threshold_4() { S::partial_verify(&o3.vss_pk, &MSG, &sig31).unwrap(); let sigs = vec![sig00, sig30, sig31]; - let sig = S::aggregate(d0.t(), &sigs).unwrap(); + let sig = S::aggregate(d0.t(), sigs.iter()).unwrap(); S::verify(o0.vss_pk.c0(), &MSG, &sig).unwrap(); } diff --git a/fastcrypto-tbls/src/tests/polynomial_tests.rs b/fastcrypto-tbls/src/tests/polynomial_tests.rs index 7258bd9d42..3768c8bda3 100644 --- a/fastcrypto-tbls/src/tests/polynomial_tests.rs +++ b/fastcrypto-tbls/src/tests/polynomial_tests.rs @@ -53,16 +53,13 @@ mod scalar_tests { let threshold = degree + 1; let poly = Poly::::rand(4, &mut thread_rng()); // insufficient shares gathered - let shares = (1..threshold) - .map(|i| poly.eval(ShareIndex::new(i).unwrap())) - .collect::>(); - Poly::::recover_c0(threshold, &shares).unwrap_err(); + let shares = (1..threshold).map(|i| poly.eval(ShareIndex::new(i).unwrap())); + Poly::::recover_c0(threshold, shares).unwrap_err(); // duplications - let mut shares = (1..=threshold) + let shares = (1..=threshold) .map(|i| poly.eval(ShareIndex::new(i).unwrap())) - .collect::>(); - shares.push(shares[0].clone()); - Poly::::recover_c0(threshold, &shares).unwrap_err(); + .chain(std::iter::once(poly.eval(ShareIndex::new(1).unwrap()))); // duplicate value 1 + Poly::::recover_c0(threshold, shares).unwrap_err(); } #[test] @@ -75,7 +72,7 @@ mod scalar_tests { let c0 = poly.c0(); for _ in 0..10 { shares.shuffle(&mut thread_rng()); - let used_shares = &shares[..124]; + let used_shares = shares.iter().take(124); assert_eq!(c0, &Poly::::recover_c0(124, used_shares).unwrap()); } } @@ -112,7 +109,10 @@ mod points_tests { let s2 = p.eval(NonZeroU32::new(20).unwrap()); let s3 = p.eval(NonZeroU32::new(30).unwrap()); let shares = vec![s1, s2, s3]; - assert_eq!(Poly::::recover_c0(3, &shares).unwrap(), one); + assert_eq!( + Poly::::recover_c0(3, shares.iter()).unwrap(), + one + ); } #[test] @@ -122,16 +122,13 @@ mod points_tests { let poly = Poly::::rand(4, &mut thread_rng()); let poly_g = poly.commit(); // insufficient shares gathered - let shares = (1..threshold) - .map(|i| poly_g.eval(ShareIndex::new(i).unwrap())) - .collect::>(); - Poly::::recover_c0_msm(threshold, &shares).unwrap_err(); + let shares = (1..threshold).map(|i| poly_g.eval(ShareIndex::new(i).unwrap())); + Poly::::recover_c0_msm(threshold, shares).unwrap_err(); // duplications - let mut shares = (1..threshold) + let shares = (1..threshold) .map(|i| poly_g.eval(ShareIndex::new(i).unwrap())) - .collect::>(); - shares.push(shares[0].clone()); - Poly::::recover_c0_msm(threshold, &shares).unwrap_err(); + .chain(std::iter::once(poly_g.eval(ShareIndex::new(1).unwrap()))); // duplicate value 1 + Poly::::recover_c0_msm(threshold, shares).unwrap_err(); } #[test] @@ -144,7 +141,7 @@ mod points_tests { let s2 = p.eval(NonZeroU32::new(20).unwrap()); let s3 = p.eval(NonZeroU32::new(30).unwrap()); let shares = vec![s1, s2, s3]; - assert_eq!(Poly::::recover_c0_msm(3, &shares).unwrap(), one); + assert_eq!(Poly::::recover_c0_msm(3, shares.iter()).unwrap(), one); // and random tests let poly = Poly::::rand(123, &mut thread_rng()); @@ -156,7 +153,7 @@ mod points_tests { let c0 = poly_g.c0(); for _ in 0..10 { shares.shuffle(&mut thread_rng()); - let used_shares = &shares[..124]; + let used_shares = shares.iter().take(124); assert_eq!(c0, &Poly::::recover_c0_msm(124, used_shares).unwrap()); } } diff --git a/fastcrypto-tbls/src/tests/tbls_tests.rs b/fastcrypto-tbls/src/tests/tbls_tests.rs index 2f057e4205..08ff1a0393 100644 --- a/fastcrypto-tbls/src/tests/tbls_tests.rs +++ b/fastcrypto-tbls/src/tests/tbls_tests.rs @@ -34,11 +34,11 @@ fn test_tbls_e2e() { ThresholdBls12381MinSig::partial_verify(&public_poly, b"other message", &sig1).is_err() ); // Aggregate should fail if we don't have enough signatures. - assert!(ThresholdBls12381MinSig::aggregate(t, &[sig1.clone(), sig2.clone()]).is_err()); + assert!(ThresholdBls12381MinSig::aggregate(t, [sig1.clone(), sig2.clone()].iter()).is_err()); // Signatures should be the same no matter if calculated with the private key or from a // threshold of partial signatures. - let full_sig = ThresholdBls12381MinSig::aggregate(t, &[sig1, sig2, sig3]).unwrap(); + let full_sig = ThresholdBls12381MinSig::aggregate(t, [sig1, sig2, sig3].iter()).unwrap(); assert!(ThresholdBls12381MinSig::verify(public_poly.c0(), msg, &full_sig).is_ok()); assert_eq!( full_sig, @@ -69,26 +69,26 @@ fn test_partial_verify_batch() { assert!(ThresholdBls12381MinSig::partial_verify_batch( &public_poly, msg, - &[], + [].iter(), &mut thread_rng() ) .is_ok()); // standard sigs should pass - let sigs = ThresholdBls12381MinSig::partial_sign_batch(&shares, msg); + let sigs = ThresholdBls12381MinSig::partial_sign_batch(shares.iter(), msg); assert!(ThresholdBls12381MinSig::partial_verify_batch( &public_poly, msg, - &sigs, + sigs.iter(), &mut thread_rng() ) .is_ok()); // even if repeated - let mut sigs = ThresholdBls12381MinSig::partial_sign_batch(&shares, msg); + let mut sigs = ThresholdBls12381MinSig::partial_sign_batch(shares.iter(), msg); sigs[0] = sigs[2].clone(); assert!(ThresholdBls12381MinSig::partial_verify_batch( &public_poly, msg, - &sigs, + sigs.iter(), &mut thread_rng() ) .is_ok()); @@ -96,48 +96,48 @@ fn test_partial_verify_batch() { assert!(ThresholdBls12381MinSig::partial_verify_batch( &public_poly, b"other message", - &sigs, + sigs.iter(), &mut thread_rng() ) .is_err()); // invalid signatures according to the polynomial should fail - let mut sigs = ThresholdBls12381MinSig::partial_sign_batch(&shares, msg); + let mut sigs = ThresholdBls12381MinSig::partial_sign_batch(shares.iter(), msg); (sigs[0].index, sigs[1].index) = (sigs[1].index, sigs[0].index); assert!(ThresholdBls12381MinSig::partial_verify_batch( &public_poly, msg, - &sigs, + sigs.iter(), &mut thread_rng() ) .is_err()); // identity as the signature should fail - let mut sigs = ThresholdBls12381MinSig::partial_sign_batch(&shares, msg); + let mut sigs = ThresholdBls12381MinSig::partial_sign_batch(shares.iter(), msg); sigs[1].value = G1Element::zero(); assert!(ThresholdBls12381MinSig::partial_verify_batch( &public_poly, msg, - &sigs, + sigs.iter(), &mut thread_rng() ) .is_err()); // generator as the signature should fail - let mut sigs = ThresholdBls12381MinSig::partial_sign_batch(&shares, msg); + let mut sigs = ThresholdBls12381MinSig::partial_sign_batch(shares.iter(), msg); sigs[1].value = G1Element::generator(); assert!(ThresholdBls12381MinSig::partial_verify_batch( &public_poly, msg, - &sigs, + sigs.iter(), &mut thread_rng() ) .is_err()); // even if the sum of sigs is ok, should fail since not consistent with the polynomial - let mut sigs = ThresholdBls12381MinSig::partial_sign_batch(&shares, msg); + let mut sigs = ThresholdBls12381MinSig::partial_sign_batch(shares.iter(), msg); sigs[0].value -= G1Element::generator(); sigs[1].value += G1Element::generator(); assert!(ThresholdBls12381MinSig::partial_verify_batch( &public_poly, msg, - &sigs, + sigs.iter(), &mut thread_rng() ) .is_err());