Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ThresholdBls: accept Iterator directly where possible #709

Merged
merged 2 commits into from
Dec 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions fastcrypto-tbls/benches/tbls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ mod tbls_benches {
.collect::<Vec<_>>();

create.bench_function(format!("w={}", w).as_str(), |b| {
b.iter(|| ThresholdBls12381MinSig::partial_sign_batch(&shares, msg))
b.iter(|| ThresholdBls12381MinSig::partial_sign_batch(shares.iter(), msg))
});
}
}
Expand All @@ -39,10 +39,10 @@ mod tbls_benches {
.map(|i| private_poly.eval(NonZeroU32::new(i as u32).unwrap()))
.collect::<Vec<_>>();

let sigs = ThresholdBls12381MinSig::partial_sign_batch(&shares, msg);
let sigs = ThresholdBls12381MinSig::partial_sign_batch(shares.iter(), msg);

create.bench_function(format!("w={}", w).as_str(), |b| {
b.iter(|| ThresholdBls12381MinSig::aggregate(w as u32, &sigs).unwrap())
b.iter(|| ThresholdBls12381MinSig::aggregate(w as u32, sigs.iter()).unwrap())
});
}
}
Expand Down
5 changes: 2 additions & 3 deletions fastcrypto-tbls/src/nidkg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -332,9 +332,8 @@ where
.map(|(i, pk)| Eval {
index: NonZeroU32::new((i + 1) as u32).expect("non zero"),
value: *pk,
})
.collect::<Vec<Eval<G>>>();
let pk = Poly::<G>::recover_c0(self.t, &evals).expect("enough shares");
});
let pk = Poly::<G>::recover_c0(self.t, evals).expect("enough shares");

(pk, partial_pks)
}
Expand Down
50 changes: 30 additions & 20 deletions fastcrypto-tbls/src/polynomial.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use fastcrypto::error::{FastCryptoError, FastCryptoResult};
use fastcrypto::groups::{GroupElement, MultiScalarMul, Scalar};
use fastcrypto::traits::AllowedRng;
use serde::{Deserialize, Serialize};
use std::borrow::Borrow;
use std::collections::HashSet;

/// Types
Expand Down Expand Up @@ -81,22 +82,25 @@ impl<C: GroupElement> Poly<C> {
// Expects exactly t unique shares.
fn get_lagrange_coefficients_for_c0(
t: u32,
shares: &[Eval<C>],
mut shares: impl Iterator<Item = impl Borrow<Eval<C>>>,
) -> FastCryptoResult<Vec<C::ScalarType>> {
if shares.len() != t as usize {
return Err(FastCryptoError::InvalidInput);
}
// Check for duplicates.
let mut ids_set = HashSet::new();
if !shares.iter().map(|s| &s.index).all(|id| ids_set.insert(id)) {
return Err(FastCryptoError::InvalidInput); // expected unique ids
let (shares_size_lower, shares_size_upper) = shares.size_hint();
let indices = shares.try_fold(
Vec::with_capacity(shares_size_upper.unwrap_or(shares_size_lower)),
|mut vec, s| {
// Check for duplicates.
if !ids_set.insert(s.borrow().index) {
return Err(FastCryptoError::InvalidInput); // expected unique ids
}
vec.push(C::ScalarType::from(s.borrow().index.get() as u64));
Ok(vec)
},
)?;
if indices.len() != t as usize {
return Err(FastCryptoError::InvalidInput);
}

let indices = shares
.iter()
.map(|s| C::ScalarType::from(s.index.get() as u64))
.collect::<Vec<_>>();

let full_numerator = indices
.iter()
.fold(C::ScalarType::generator(), |acc, i| acc * i);
Expand All @@ -113,13 +117,16 @@ impl<C: GroupElement> Poly<C> {
}

/// Given exactly `t` polynomial evaluations, it will recover the polynomial's constant term.
pub fn recover_c0(t: u32, shares: &[Eval<C>]) -> Result<C, FastCryptoError> {
let coeffs = Self::get_lagrange_coefficients_for_c0(t, shares)?;
let plain_shares = shares.iter().map(|s| s.value).collect::<Vec<_>>();
pub fn recover_c0(
t: u32,
shares: impl Iterator<Item = impl Borrow<Eval<C>>> + Clone,
) -> Result<C, FastCryptoError> {
let coeffs = Self::get_lagrange_coefficients_for_c0(t, shares.clone())?;
let plain_shares = shares.map(|s| s.borrow().value);
let res = coeffs
.iter()
.zip(plain_shares.iter())
.fold(C::zero(), |acc, (c, s)| acc + (*s * *c));
.zip(plain_shares)
.fold(C::zero(), |acc, (c, s)| acc + (s * *c));
Ok(res)
}

Expand Down Expand Up @@ -172,9 +179,12 @@ impl<C: Scalar> Poly<C> {
impl<C: GroupElement + MultiScalarMul> Poly<C> {
/// Given exactly `t` polynomial evaluations, it will recover the polynomial's
/// constant term.
pub fn recover_c0_msm(t: u32, shares: &[Eval<C>]) -> Result<C, FastCryptoError> {
let coeffs = Self::get_lagrange_coefficients_for_c0(t, shares)?;
let plain_shares = shares.iter().map(|s| s.value).collect::<Vec<_>>();
pub fn recover_c0_msm(
t: u32,
shares: impl Iterator<Item = impl Borrow<Eval<C>>> + Clone,
) -> Result<C, FastCryptoError> {
let coeffs = Self::get_lagrange_coefficients_for_c0(t, shares.clone())?;
let plain_shares = shares.map(|s| s.borrow().value).collect::<Vec<_>>();
let res = C::multi_scalar_mul(&coeffs, &plain_shares).expect("sizes match");
Ok(res)
}
Expand Down
51 changes: 25 additions & 26 deletions fastcrypto-tbls/src/tbls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
// Some of the code below is based on code from https://github.com/celo-org/celo-threshold-bls-rs,
// modified for our needs.

use std::borrow::Borrow;

use crate::dl_verification::{batch_coefficients, get_random_scalars};
use crate::polynomial::Poly;
use crate::types::IndexedValue;
Expand All @@ -29,20 +31,22 @@ pub trait ThresholdBls {

/// Sign a message using the private share/partial key.
fn partial_sign(share: &Share<Self::Private>, msg: &[u8]) -> PartialSignature<Self::Signature> {
Self::partial_sign_batch(&[share.clone()], msg)[0].clone()
Self::partial_sign_batch(std::iter::once(share), msg)[0].clone()
}

/// Sign a message using one of more private share/partial keys.
fn partial_sign_batch(
shares: &[Share<Self::Private>],
shares: impl Iterator<Item = impl Borrow<Share<Self::Private>>>,
msg: &[u8],
) -> Vec<PartialSignature<Self::Signature>> {
let h = Self::Signature::hash_to_group_element(msg);
shares
.iter()
.map(|share| PartialSignature {
index: share.index,
value: h * share.value,
.map(|share| {
let share = share.borrow();
PartialSignature {
index: share.index,
value: h * share.value,
}
})
.collect()
}
Expand All @@ -63,46 +67,41 @@ pub trait ThresholdBls {
fn partial_verify_batch<R: AllowedRng>(
vss_pk: &Poly<Self::Public>,
msg: &[u8],
partial_sigs: &[PartialSignature<Self::Signature>],
partial_sigs: impl Iterator<Item = impl Borrow<PartialSignature<Self::Signature>>>,
rng: &mut R,
) -> FastCryptoResult<()> {
assert!(vss_pk.degree() > 0 || !msg.is_empty());
if partial_sigs.is_empty() {
let (evals_as_scalars, points): (Vec<_>, Vec<_>) = partial_sigs
.map(|sig| {
let sig = sig.borrow();
(Self::Private::from(sig.index.get().into()), sig.value)
})
.unzip();
if points.is_empty() {
return Ok(());
}
let rs = get_random_scalars::<Self::Private, R>(partial_sigs.len() as u32, rng);
let evals_as_scalars = partial_sigs
.iter()
.map(|e| Self::Private::from(e.index.get().into()))
.collect::<Vec<_>>();
let rs = get_random_scalars::<Self::Private, R>(points.len() as u32, rng);
// TODO: should we cache it instead? that would replace t-wide msm with w-wide msm.
let coeffs = batch_coefficients(&rs, &evals_as_scalars, vss_pk.degree());
let pk = Self::Public::multi_scalar_mul(&coeffs, vss_pk.as_vec()).expect("sizes match");
let aggregated_sig = Self::Signature::multi_scalar_mul(
&rs,
&partial_sigs.iter().map(|s| s.value).collect::<Vec<_>>(),
)
.expect("sizes match");
let aggregated_sig = Self::Signature::multi_scalar_mul(&rs, &points).expect("sizes match");

Self::verify(&pk, msg, &aggregated_sig)
}

/// Interpolate partial signatures to recover the full signature.
fn aggregate(
threshold: u32,
partials: &[PartialSignature<Self::Signature>],
partials: impl Iterator<Item = impl Borrow<PartialSignature<Self::Signature>>> + Clone,
) -> FastCryptoResult<Self::Signature> {
let unique_partials = partials
.iter()
.unique_by(|p| p.index)
.take(threshold as usize)
.cloned()
.collect::<Vec<_>>();
if unique_partials.len() != threshold as usize {
.unique_by(|p| p.borrow().index)
.take(threshold as usize);
if unique_partials.clone().count() != threshold as usize {
return Err(FastCryptoError::NotEnoughInputs);
}
// No conversion is required since PartialSignature<S> and Eval<S> are different aliases to
// IndexedValue<S>.
Poly::<Self::Signature>::recover_c0_msm(threshold, &unique_partials)
Poly::<Self::Signature>::recover_c0_msm(threshold, unique_partials)
}
}
2 changes: 1 addition & 1 deletion fastcrypto-tbls/src/tests/dkg_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ fn test_dkg_e2e_5_parties_min_weight_2_threshold_4() {
S::partial_verify(&o3.vss_pk, &MSG, &sig31).unwrap();

let sigs = vec![sig00, sig30, sig31];
let sig = S::aggregate(d0.t(), &sigs).unwrap();
let sig = S::aggregate(d0.t(), sigs.iter()).unwrap();
S::verify(o0.vss_pk.c0(), &MSG, &sig).unwrap();
}

Expand Down
37 changes: 17 additions & 20 deletions fastcrypto-tbls/src/tests/polynomial_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,16 +53,13 @@ mod scalar_tests {
let threshold = degree + 1;
let poly = Poly::<S>::rand(4, &mut thread_rng());
// insufficient shares gathered
let shares = (1..threshold)
.map(|i| poly.eval(ShareIndex::new(i).unwrap()))
.collect::<Vec<_>>();
Poly::<S>::recover_c0(threshold, &shares).unwrap_err();
let shares = (1..threshold).map(|i| poly.eval(ShareIndex::new(i).unwrap()));
Poly::<S>::recover_c0(threshold, shares).unwrap_err();
// duplications
let mut shares = (1..=threshold)
let shares = (1..=threshold)
.map(|i| poly.eval(ShareIndex::new(i).unwrap()))
.collect::<Vec<_>>();
shares.push(shares[0].clone());
Poly::<S>::recover_c0(threshold, &shares).unwrap_err();
.chain(std::iter::once(poly.eval(ShareIndex::new(1).unwrap()))); // duplicate value 1
Poly::<S>::recover_c0(threshold, shares).unwrap_err();
}

#[test]
Expand All @@ -75,7 +72,7 @@ mod scalar_tests {
let c0 = poly.c0();
for _ in 0..10 {
shares.shuffle(&mut thread_rng());
let used_shares = &shares[..124];
let used_shares = shares.iter().take(124);
assert_eq!(c0, &Poly::<S>::recover_c0(124, used_shares).unwrap());
}
}
Expand Down Expand Up @@ -112,7 +109,10 @@ mod points_tests {
let s2 = p.eval(NonZeroU32::new(20).unwrap());
let s3 = p.eval(NonZeroU32::new(30).unwrap());
let shares = vec![s1, s2, s3];
assert_eq!(Poly::<G::ScalarType>::recover_c0(3, &shares).unwrap(), one);
assert_eq!(
Poly::<G::ScalarType>::recover_c0(3, shares.iter()).unwrap(),
one
);
}

#[test]
Expand All @@ -122,16 +122,13 @@ mod points_tests {
let poly = Poly::<G::ScalarType>::rand(4, &mut thread_rng());
let poly_g = poly.commit();
// insufficient shares gathered
let shares = (1..threshold)
.map(|i| poly_g.eval(ShareIndex::new(i).unwrap()))
.collect::<Vec<_>>();
Poly::<G>::recover_c0_msm(threshold, &shares).unwrap_err();
let shares = (1..threshold).map(|i| poly_g.eval(ShareIndex::new(i).unwrap()));
Poly::<G>::recover_c0_msm(threshold, shares).unwrap_err();
// duplications
let mut shares = (1..threshold)
let shares = (1..threshold)
.map(|i| poly_g.eval(ShareIndex::new(i).unwrap()))
.collect::<Vec<_>>();
shares.push(shares[0].clone());
Poly::<G>::recover_c0_msm(threshold, &shares).unwrap_err();
.chain(std::iter::once(poly_g.eval(ShareIndex::new(1).unwrap()))); // duplicate value 1
Poly::<G>::recover_c0_msm(threshold, shares).unwrap_err();
}

#[test]
Expand All @@ -144,7 +141,7 @@ mod points_tests {
let s2 = p.eval(NonZeroU32::new(20).unwrap());
let s3 = p.eval(NonZeroU32::new(30).unwrap());
let shares = vec![s1, s2, s3];
assert_eq!(Poly::<G>::recover_c0_msm(3, &shares).unwrap(), one);
assert_eq!(Poly::<G>::recover_c0_msm(3, shares.iter()).unwrap(), one);

// and random tests
let poly = Poly::<G::ScalarType>::rand(123, &mut thread_rng());
Expand All @@ -156,7 +153,7 @@ mod points_tests {
let c0 = poly_g.c0();
for _ in 0..10 {
shares.shuffle(&mut thread_rng());
let used_shares = &shares[..124];
let used_shares = shares.iter().take(124);
assert_eq!(c0, &Poly::<G>::recover_c0_msm(124, used_shares).unwrap());
}
}
Expand Down
Loading
Loading