diff --git a/fastcrypto-tbls/Cargo.toml b/fastcrypto-tbls/Cargo.toml index 6b46edbcfd..b2a106683d 100644 --- a/fastcrypto-tbls/Cargo.toml +++ b/fastcrypto-tbls/Cargo.toml @@ -41,6 +41,11 @@ name = "nidkg" harness = false required-features = ["experimental"] +[[bench]] +name = "tbls" +harness = false +required-features = ["experimental"] + [features] default = [] experimental = [] diff --git a/fastcrypto-tbls/benches/dkg.rs b/fastcrypto-tbls/benches/dkg.rs index 8c030967a3..a50f5a86b2 100644 --- a/fastcrypto-tbls/benches/dkg.rs +++ b/fastcrypto-tbls/benches/dkg.rs @@ -51,39 +51,51 @@ mod dkg_benches { use super::*; fn dkg(c: &mut Criterion) { - const SIZES: [u16; 1] = [100]; - const WEIGHTS: [u16; 3] = [10, 20, 33]; + const SIZES: [u16; 2] = [100, 200]; + const TOTAL_WEIGHTS: [u16; 3] = [2000, 3000, 5000]; { let mut create: BenchmarkGroup<_> = c.benchmark_group("DKG create"); - for (n, w) in iproduct!(SIZES.iter(), WEIGHTS.iter()) { - let t = (n * w / 2) as u32; + for (n, total_w) in iproduct!(SIZES.iter(), TOTAL_WEIGHTS.iter()) { + let w = total_w / n; + let t = (total_w / 3) as u32; let keys = gen_ecies_keys(*n); - let d0 = setup_party(0, t, *w, &keys); + let d0 = setup_party(0, t, w, &keys); - create.bench_function(format!("n={}, w={}, t={}", n, w, t).as_str(), |b| { - b.iter(|| d0.create_message(&mut thread_rng())) - }); + create.bench_function( + format!("n={}, total_weight={}, t={}, w={}", n, total_w, t, w).as_str(), + |b| b.iter(|| d0.create_message(&mut thread_rng())), + ); + + let message = d0.create_message(&mut thread_rng()); + println!( + "Message size for n={}, t={}: {}", + n, + t, + bcs::to_bytes(&message).unwrap().len(), + ); } } { let mut verify: BenchmarkGroup<_> = c.benchmark_group("DKG message processing"); - for (n, w) in iproduct!(SIZES.iter(), WEIGHTS.iter()) { - let t = (n * w / 2) as u32; + for (n, total_w) in iproduct!(SIZES.iter(), TOTAL_WEIGHTS.iter()) { + let w = total_w / n; + let t = (total_w / 3) as u32; let keys = gen_ecies_keys(*n); - let d0 = setup_party(0, t, *w, &keys); - let d1 = setup_party(1, t, *w, &keys); + let d0 = setup_party(0, t, w, &keys); + let d1 = setup_party(1, t, w, &keys); let message = d0.create_message(&mut thread_rng()); - println!("Message size: {}", bcs::to_bytes(&message).unwrap().len()); - - verify.bench_function(format!("n={}, w={}, t={}", n, w, t).as_str(), |b| { - b.iter(|| { - d1.process_message(message.clone(), &mut thread_rng()) - .unwrap() - }) - }); + verify.bench_function( + format!("n={}, total_weight={}, t={}, w={}", n, total_w, t, w).as_str(), + |b| { + b.iter(|| { + d1.process_message(message.clone(), &mut thread_rng()) + .unwrap() + }) + }, + ); } } } diff --git a/fastcrypto-tbls/benches/tbls.rs b/fastcrypto-tbls/benches/tbls.rs new file mode 100644 index 0000000000..a808cd0d2f --- /dev/null +++ b/fastcrypto-tbls/benches/tbls.rs @@ -0,0 +1,60 @@ +// Copyright (c) 2022, Mysten Labs, Inc. +// SPDX-License-Identifier: Apache-2.0 + +use criterion::{criterion_group, criterion_main, BenchmarkGroup, Criterion}; +use fastcrypto::groups::bls12381; +use rand::thread_rng; +use std::num::NonZeroU32; + +mod tbls_benches { + use super::*; + use fastcrypto_tbls::polynomial::Poly; + use fastcrypto_tbls::tbls::ThresholdBls; + use fastcrypto_tbls::types::ThresholdBls12381MinSig; + + fn tbls(c: &mut Criterion) { + let msg = b"test"; + + { + let mut create: BenchmarkGroup<_> = c.benchmark_group("Batch signing"); + let private_poly = Poly::::rand(500, &mut thread_rng()); + const WEIGHTS: [usize; 5] = [10, 20, 30, 40, 50]; + for w in WEIGHTS { + let shares = (1..=w) + .into_iter() + .map(|i| private_poly.eval(NonZeroU32::new(i as u32).unwrap())) + .collect::>(); + + create.bench_function(format!("w={}", w).as_str(), |b| { + b.iter(|| ThresholdBls12381MinSig::partial_sign_batch(&shares, msg)) + }); + } + } + + { + let mut create: BenchmarkGroup<_> = c.benchmark_group("Recover full signature"); + const TOTAL_WEIGHTS: [usize; 4] = [666, 833, 1111, 1666]; + for w in TOTAL_WEIGHTS { + let private_poly = Poly::::rand(w as u32, &mut thread_rng()); + let shares = (1..=w) + .into_iter() + .map(|i| private_poly.eval(NonZeroU32::new(i as u32).unwrap())) + .collect::>(); + + let sigs = ThresholdBls12381MinSig::partial_sign_batch(&shares, msg); + + create.bench_function(format!("w={}", w).as_str(), |b| { + b.iter(|| ThresholdBls12381MinSig::aggregate(w as u32, &sigs).unwrap()) + }); + } + } + } + + criterion_group! { + name = tbls_benches; + config = Criterion::default(); + targets = tbls, + } +} + +criterion_main!(tbls_benches::tbls_benches); diff --git a/fastcrypto-tbls/src/polynomial.rs b/fastcrypto-tbls/src/polynomial.rs index be8b61c299..710bbc815c 100644 --- a/fastcrypto-tbls/src/polynomial.rs +++ b/fastcrypto-tbls/src/polynomial.rs @@ -6,8 +6,8 @@ // use crate::types::{IndexedValue, ShareIndex}; -use fastcrypto::error::FastCryptoError; -use fastcrypto::groups::{GroupElement, Scalar}; +use fastcrypto::error::{FastCryptoError, FastCryptoResult}; +use fastcrypto::groups::{GroupElement, MultiScalarMul, Scalar}; use fastcrypto::traits::AllowedRng; use serde::{Deserialize, Serialize}; use std::collections::HashSet; @@ -82,49 +82,52 @@ impl Poly { } } - /// Given at least `t` polynomial evaluations, it will recover the polynomial's - /// constant term - pub fn recover_c0(t: u32, shares: &[Eval]) -> Result { + fn get_lagrange_coefficients( + t: u32, + shares: &[Eval], + ) -> FastCryptoResult> { if shares.len() < t.try_into().unwrap() { return Err(FastCryptoError::InvalidInput); } - // Check for duplicates. let mut ids_set = HashSet::new(); shares.iter().map(|s| &s.index).for_each(|id| { ids_set.insert(id); }); - if ids_set.len() != t as usize { + if ids_set.len() != shares.len() { return Err(FastCryptoError::InvalidInput); } - // Iterate over all indices and for each multiply the lagrange basis - // with the value of the share. - let mut acc = C::zero(); - for IndexedValue { - index: i, - value: share_i, - } in shares - { - let mut num = C::ScalarType::generator(); - let mut den = C::ScalarType::generator(); - - for IndexedValue { index: j, value: _ } in shares { - if i == j { - continue; - }; - // j - 0 - num = num * C::ScalarType::from(j.get() as u64); - // 1 / (j - i) - den = den - * (C::ScalarType::from(j.get() as u64) - C::ScalarType::from(i.get() as u64)); - } - // Next line is safe since i != j. - let inv = (C::ScalarType::generator() / den).unwrap(); - acc += *share_i * num * inv; + let indices = shares + .iter() + .map(|s| C::ScalarType::from(s.index.get() as u64)) + .collect::>(); + + let full_numerator = indices + .iter() + .fold(C::ScalarType::generator(), |acc, i| acc * i); + let mut coeffs = Vec::new(); + for i in &indices { + let denominator = indices + .iter() + .filter(|j| *j != i) + .fold(*i, |acc, j| acc * (*j - i)); + let coeff = full_numerator / denominator; + coeffs.push(coeff.expect("safe since i != j")); } + Ok(coeffs) + } - Ok(acc) + /// Given at least `t` polynomial evaluations, it will recover the polynomial's + /// constant term + pub fn recover_c0(t: u32, shares: &[Eval]) -> Result { + let coeffs = Self::get_lagrange_coefficients(t, shares)?; + let plain_shares = shares.iter().map(|s| s.value).collect::>(); + let res = coeffs + .iter() + .zip(plain_shares.iter()) + .fold(C::zero(), |acc, (c, s)| acc + (*s * *c)); + Ok(res) } /// Checks if a given share is valid. @@ -172,3 +175,14 @@ impl Poly { Poly::

::from(commits) } } + +impl Poly { + /// Given at least `t` polynomial evaluations, it will recover the polynomial's + /// constant term + pub fn recover_c0_msm(t: u32, shares: &[Eval]) -> Result { + let coeffs = Self::get_lagrange_coefficients(t, shares)?; + let plain_shares = shares.iter().map(|s| s.value).collect::>(); + let res = C::multi_scalar_mul(&coeffs, &plain_shares).expect("sizes match"); + Ok(res) + } +} diff --git a/fastcrypto-tbls/src/tbls.rs b/fastcrypto-tbls/src/tbls.rs index 8955ab8b98..307e0ab89a 100644 --- a/fastcrypto-tbls/src/tbls.rs +++ b/fastcrypto-tbls/src/tbls.rs @@ -83,6 +83,7 @@ pub trait ThresholdBls { .iter() .map(|e| Self::Private::from(e.index.get().into())) .collect::>(); + // TODO: should we cache it instead? let coeffs = batch_coefficients(&rs, &evals_as_scalars, vss_pk.degree()); let pk = Self::Public::multi_scalar_mul(&coeffs, vss_pk.as_vec()).expect("sizes match"); let aggregated_sig = Self::Signature::multi_scalar_mul( @@ -101,6 +102,6 @@ pub trait ThresholdBls { ) -> Result { // No conversion is required since PartialSignature and Eval are different aliases to // IndexedValue. - Poly::::recover_c0(threshold, partials) + Poly::::recover_c0_msm(threshold, partials) } } diff --git a/fastcrypto-tbls/src/tests/polynomial_tests.rs b/fastcrypto-tbls/src/tests/polynomial_tests.rs index f056466fc7..7ca7424365 100644 --- a/fastcrypto-tbls/src/tests/polynomial_tests.rs +++ b/fastcrypto-tbls/src/tests/polynomial_tests.rs @@ -7,6 +7,7 @@ use crate::polynomial::*; use crate::types::ShareIndex; +use fastcrypto::groups::bls12381::G1Element; use fastcrypto::groups::ristretto255::{RistrettoPoint, RistrettoScalar}; use fastcrypto::groups::*; use rand::prelude::*; @@ -79,3 +80,16 @@ fn interpolation_insufficient_shares() { Poly::::recover_c0(threshold, &shares).unwrap_err(); } + +#[test] +fn eval_regression_msm() { + let one = G1Element::generator(); + let coeff = vec![one, one, one]; + let p = Poly::::from(coeff); + assert_eq!(p.degree(), 2); + let s1 = p.eval(NonZeroU32::new(10).unwrap()); + let s2 = p.eval(NonZeroU32::new(20).unwrap()); + let s3 = p.eval(NonZeroU32::new(30).unwrap()); + let shares = vec![s1, s2, s3]; + assert_eq!(Poly::::recover_c0_msm(3, &shares).unwrap(), one); +}