Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

More benchmarks, faster computation of lagrange coeffs, and msm #655

Merged
merged 6 commits into from
Sep 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions fastcrypto-tbls/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,11 @@ name = "nidkg"
harness = false
required-features = ["experimental"]

[[bench]]
name = "tbls"
harness = false
required-features = ["experimental"]

[features]
default = []
experimental = []
52 changes: 32 additions & 20 deletions fastcrypto-tbls/benches/dkg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,39 +51,51 @@ mod dkg_benches {
use super::*;

fn dkg(c: &mut Criterion) {
const SIZES: [u16; 1] = [100];
const WEIGHTS: [u16; 3] = [10, 20, 33];
const SIZES: [u16; 2] = [100, 200];
const TOTAL_WEIGHTS: [u16; 3] = [2000, 3000, 5000];

{
let mut create: BenchmarkGroup<_> = c.benchmark_group("DKG create");
for (n, w) in iproduct!(SIZES.iter(), WEIGHTS.iter()) {
let t = (n * w / 2) as u32;
for (n, total_w) in iproduct!(SIZES.iter(), TOTAL_WEIGHTS.iter()) {
let w = total_w / n;
let t = (total_w / 3) as u32;
let keys = gen_ecies_keys(*n);
let d0 = setup_party(0, t, *w, &keys);
let d0 = setup_party(0, t, w, &keys);

create.bench_function(format!("n={}, w={}, t={}", n, w, t).as_str(), |b| {
b.iter(|| d0.create_message(&mut thread_rng()))
});
create.bench_function(
format!("n={}, total_weight={}, t={}, w={}", n, total_w, t, w).as_str(),
|b| b.iter(|| d0.create_message(&mut thread_rng())),
);

let message = d0.create_message(&mut thread_rng());
println!(
"Message size for n={}, t={}: {}",
n,
t,
bcs::to_bytes(&message).unwrap().len(),
);
}
}

{
let mut verify: BenchmarkGroup<_> = c.benchmark_group("DKG message processing");
for (n, w) in iproduct!(SIZES.iter(), WEIGHTS.iter()) {
let t = (n * w / 2) as u32;
for (n, total_w) in iproduct!(SIZES.iter(), TOTAL_WEIGHTS.iter()) {
let w = total_w / n;
let t = (total_w / 3) as u32;
let keys = gen_ecies_keys(*n);
let d0 = setup_party(0, t, *w, &keys);
let d1 = setup_party(1, t, *w, &keys);
let d0 = setup_party(0, t, w, &keys);
let d1 = setup_party(1, t, w, &keys);
let message = d0.create_message(&mut thread_rng());

println!("Message size: {}", bcs::to_bytes(&message).unwrap().len());

verify.bench_function(format!("n={}, w={}, t={}", n, w, t).as_str(), |b| {
b.iter(|| {
d1.process_message(message.clone(), &mut thread_rng())
.unwrap()
})
});
verify.bench_function(
format!("n={}, total_weight={}, t={}, w={}", n, total_w, t, w).as_str(),
|b| {
b.iter(|| {
d1.process_message(message.clone(), &mut thread_rng())
.unwrap()
})
},
);
}
}
}
Expand Down
60 changes: 60 additions & 0 deletions fastcrypto-tbls/benches/tbls.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
// Copyright (c) 2022, Mysten Labs, Inc.
// SPDX-License-Identifier: Apache-2.0

use criterion::{criterion_group, criterion_main, BenchmarkGroup, Criterion};
use fastcrypto::groups::bls12381;
use rand::thread_rng;
use std::num::NonZeroU32;

mod tbls_benches {
use super::*;
use fastcrypto_tbls::polynomial::Poly;
use fastcrypto_tbls::tbls::ThresholdBls;
use fastcrypto_tbls::types::ThresholdBls12381MinSig;

fn tbls(c: &mut Criterion) {
let msg = b"test";

{
let mut create: BenchmarkGroup<_> = c.benchmark_group("Batch signing");
let private_poly = Poly::<bls12381::Scalar>::rand(500, &mut thread_rng());
const WEIGHTS: [usize; 5] = [10, 20, 30, 40, 50];
for w in WEIGHTS {
let shares = (1..=w)
.into_iter()
.map(|i| private_poly.eval(NonZeroU32::new(i as u32).unwrap()))
.collect::<Vec<_>>();

create.bench_function(format!("w={}", w).as_str(), |b| {
b.iter(|| ThresholdBls12381MinSig::partial_sign_batch(&shares, msg))
});
}
}

{
let mut create: BenchmarkGroup<_> = c.benchmark_group("Recover full signature");
const TOTAL_WEIGHTS: [usize; 4] = [666, 833, 1111, 1666];
for w in TOTAL_WEIGHTS {
let private_poly = Poly::<bls12381::Scalar>::rand(w as u32, &mut thread_rng());
let shares = (1..=w)
.into_iter()
.map(|i| private_poly.eval(NonZeroU32::new(i as u32).unwrap()))
.collect::<Vec<_>>();

let sigs = ThresholdBls12381MinSig::partial_sign_batch(&shares, msg);

create.bench_function(format!("w={}", w).as_str(), |b| {
b.iter(|| ThresholdBls12381MinSig::aggregate(w as u32, &sigs).unwrap())
});
}
}
}

criterion_group! {
name = tbls_benches;
config = Criterion::default();
targets = tbls,
}
}

criterion_main!(tbls_benches::tbls_benches);
78 changes: 46 additions & 32 deletions fastcrypto-tbls/src/polynomial.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
//

use crate::types::{IndexedValue, ShareIndex};
use fastcrypto::error::FastCryptoError;
use fastcrypto::groups::{GroupElement, Scalar};
use fastcrypto::error::{FastCryptoError, FastCryptoResult};
use fastcrypto::groups::{GroupElement, MultiScalarMul, Scalar};
use fastcrypto::traits::AllowedRng;
use serde::{Deserialize, Serialize};
use std::collections::HashSet;
Expand Down Expand Up @@ -82,49 +82,52 @@ impl<C: GroupElement> Poly<C> {
}
}

/// Given at least `t` polynomial evaluations, it will recover the polynomial's
/// constant term
pub fn recover_c0(t: u32, shares: &[Eval<C>]) -> Result<C, FastCryptoError> {
fn get_lagrange_coefficients(
t: u32,
shares: &[Eval<C>],
) -> FastCryptoResult<Vec<C::ScalarType>> {
if shares.len() < t.try_into().unwrap() {
return Err(FastCryptoError::InvalidInput);
}

// Check for duplicates.
let mut ids_set = HashSet::new();
shares.iter().map(|s| &s.index).for_each(|id| {
ids_set.insert(id);
});
if ids_set.len() != t as usize {
if ids_set.len() != shares.len() {
return Err(FastCryptoError::InvalidInput);
}

// Iterate over all indices and for each multiply the lagrange basis
// with the value of the share.
let mut acc = C::zero();
for IndexedValue {
index: i,
value: share_i,
} in shares
{
let mut num = C::ScalarType::generator();
let mut den = C::ScalarType::generator();

for IndexedValue { index: j, value: _ } in shares {
if i == j {
continue;
};
// j - 0
num = num * C::ScalarType::from(j.get() as u64);
// 1 / (j - i)
den = den
* (C::ScalarType::from(j.get() as u64) - C::ScalarType::from(i.get() as u64));
}
// Next line is safe since i != j.
let inv = (C::ScalarType::generator() / den).unwrap();
acc += *share_i * num * inv;
let indices = shares
.iter()
.map(|s| C::ScalarType::from(s.index.get() as u64))
.collect::<Vec<_>>();

let full_numerator = indices
.iter()
.fold(C::ScalarType::generator(), |acc, i| acc * i);
let mut coeffs = Vec::new();
for i in &indices {
let denominator = indices
.iter()
.filter(|j| *j != i)
.fold(*i, |acc, j| acc * (*j - i));
let coeff = full_numerator / denominator;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Small optimisation idea: In some cases, this will be implemented as full_numerator * invert(denominator), and if that is the case here, the multiplication by full_numerator may be done on the sum in recover_c0 instead.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

with msm that'd require another exponentiation, will test the difference

coeffs.push(coeff.expect("safe since i != j"));
}
Ok(coeffs)
}

Ok(acc)
/// Given at least `t` polynomial evaluations, it will recover the polynomial's
/// constant term
pub fn recover_c0(t: u32, shares: &[Eval<C>]) -> Result<C, FastCryptoError> {
let coeffs = Self::get_lagrange_coefficients(t, shares)?;
let plain_shares = shares.iter().map(|s| s.value).collect::<Vec<_>>();
let res = coeffs
.iter()
.zip(plain_shares.iter())
.fold(C::zero(), |acc, (c, s)| acc + (*s * *c));
Ok(res)
}

/// Checks if a given share is valid.
Expand Down Expand Up @@ -172,3 +175,14 @@ impl<C: Scalar> Poly<C> {
Poly::<P>::from(commits)
}
}

impl<C: GroupElement + MultiScalarMul> Poly<C> {
/// Given at least `t` polynomial evaluations, it will recover the polynomial's
/// constant term
pub fn recover_c0_msm(t: u32, shares: &[Eval<C>]) -> Result<C, FastCryptoError> {
let coeffs = Self::get_lagrange_coefficients(t, shares)?;
let plain_shares = shares.iter().map(|s| s.value).collect::<Vec<_>>();
let res = C::multi_scalar_mul(&coeffs, &plain_shares).expect("sizes match");
Ok(res)
}
}
3 changes: 2 additions & 1 deletion fastcrypto-tbls/src/tbls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ pub trait ThresholdBls {
.iter()
.map(|e| Self::Private::from(e.index.get().into()))
.collect::<Vec<_>>();
// TODO: should we cache it instead?
let coeffs = batch_coefficients(&rs, &evals_as_scalars, vss_pk.degree());
let pk = Self::Public::multi_scalar_mul(&coeffs, vss_pk.as_vec()).expect("sizes match");
let aggregated_sig = Self::Signature::multi_scalar_mul(
Expand All @@ -101,6 +102,6 @@ pub trait ThresholdBls {
) -> Result<Self::Signature, FastCryptoError> {
// No conversion is required since PartialSignature<S> and Eval<S> are different aliases to
// IndexedValue<S>.
Poly::<Self::Signature>::recover_c0(threshold, partials)
Poly::<Self::Signature>::recover_c0_msm(threshold, partials)
}
}
14 changes: 14 additions & 0 deletions fastcrypto-tbls/src/tests/polynomial_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

use crate::polynomial::*;
use crate::types::ShareIndex;
use fastcrypto::groups::bls12381::G1Element;
use fastcrypto::groups::ristretto255::{RistrettoPoint, RistrettoScalar};
use fastcrypto::groups::*;
use rand::prelude::*;
Expand Down Expand Up @@ -79,3 +80,16 @@ fn interpolation_insufficient_shares() {

Poly::<RistrettoScalar>::recover_c0(threshold, &shares).unwrap_err();
}

#[test]
fn eval_regression_msm() {
let one = G1Element::generator();
let coeff = vec![one, one, one];
let p = Poly::<G1Element>::from(coeff);
assert_eq!(p.degree(), 2);
let s1 = p.eval(NonZeroU32::new(10).unwrap());
let s2 = p.eval(NonZeroU32::new(20).unwrap());
let s3 = p.eval(NonZeroU32::new(30).unwrap());
let shares = vec![s1, s2, s3];
assert_eq!(Poly::<G1Element>::recover_c0_msm(3, &shares).unwrap(), one);
}