Skip to content

Commit

Permalink
feat(zk): implement faster pke proof
Browse files Browse the repository at this point in the history
- original work by Sarah El kazdadi

co-authored-by: sarah el kazdadi <sarah.elkazdadi@zama.ai>
  • Loading branch information
IceTDrinker and sarah el kazdadi committed Sep 6, 2024
1 parent 32b45ac commit ce9da12
Show file tree
Hide file tree
Showing 9 changed files with 2,552 additions and 216 deletions.
1 change: 1 addition & 0 deletions tfhe-zk-pok/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ rayon = "1.8.0"
sha3 = "0.10.8"
serde = { version = "~1.0", features = ["derive"] }
zeroize = "1.7.0"
num-bigint = "0.4.5"

[dev-dependencies]
serde_json = "~1.0"
46 changes: 38 additions & 8 deletions tfhe-zk-pok/src/curve_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -210,9 +210,14 @@ impl CurveGroupOps<bls12_381::Zp> for bls12_381::G1 {
}

fn mul_scalar(self, scalar: bls12_381::Zp) -> Self {
self.mul_scalar(scalar)
if scalar.inner == MontFp!("2") {
self.double()
} else {
self.mul_scalar(scalar)
}
}

#[track_caller]
fn multi_mul_scalar(bases: &[Self::Affine], scalars: &[bls12_381::Zp]) -> Self {
Self::Affine::multi_mul_scalar(bases, scalars)
}
Expand Down Expand Up @@ -245,9 +250,14 @@ impl CurveGroupOps<bls12_381::Zp> for bls12_381::G2 {
}

fn mul_scalar(self, scalar: bls12_381::Zp) -> Self {
self.mul_scalar(scalar)
if scalar.inner == MontFp!("2") {
self.double()
} else {
self.mul_scalar(scalar)
}
}

#[track_caller]
fn multi_mul_scalar(bases: &[Self::Affine], scalars: &[bls12_381::Zp]) -> Self {
Self::Affine::multi_mul_scalar(bases, scalars)
}
Expand All @@ -273,6 +283,9 @@ impl PairingGroupOps<bls12_381::Zp, bls12_381::G1, bls12_381::G2> for bls12_381:
}

fn pairing(x: bls12_381::G1, y: bls12_381::G2) -> Self {
if x == bls12_381::G1::ZERO || y == bls12_381::G2::ZERO {
return Self::pairing(bls12_381::G1::ZERO, bls12_381::G2::GENERATOR);
}
Self::pairing(x, y)
}
}
Expand Down Expand Up @@ -329,12 +342,21 @@ impl CurveGroupOps<bls12_446::Zp> for bls12_446::G1 {
}

fn mul_scalar(self, scalar: bls12_446::Zp) -> Self {
self.mul_scalar(scalar)
if scalar.inner == MontFp!("2") {
self.double()
} else {
self.mul_scalar(scalar)
}
}

#[track_caller]
fn multi_mul_scalar(bases: &[Self::Affine], scalars: &[bls12_446::Zp]) -> Self {
msm::msm_wnaf_g1_446(bases, scalars)
// Self::Affine::multi_mul_scalar(bases, scalars)
// overhead seems to not be worth it outside of wasm
if cfg!(target_family = "wasm") {
msm::msm_wnaf_g1_446(bases, scalars)
} else {
Self::Affine::multi_mul_scalar(bases, scalars)
}
}

fn to_bytes(self) -> impl AsRef<[u8]> {
Expand Down Expand Up @@ -365,9 +387,14 @@ impl CurveGroupOps<bls12_446::Zp> for bls12_446::G2 {
}

fn mul_scalar(self, scalar: bls12_446::Zp) -> Self {
self.mul_scalar(scalar)
if scalar.inner == MontFp!("2") {
self.double()
} else {
self.mul_scalar(scalar)
}
}

#[track_caller]
fn multi_mul_scalar(bases: &[Self::Affine], scalars: &[bls12_446::Zp]) -> Self {
Self::Affine::multi_mul_scalar(bases, scalars)
}
Expand All @@ -393,13 +420,16 @@ impl PairingGroupOps<bls12_446::Zp, bls12_446::G1, bls12_446::G2> for bls12_446:
}

fn pairing(x: bls12_446::G1, y: bls12_446::G2) -> Self {
if x == bls12_446::G1::ZERO || y == bls12_446::G2::ZERO {
return Self::pairing(bls12_446::G1::ZERO, bls12_446::G2::GENERATOR);
}
Self::pairing(x, y)
}
}

#[derive(Copy, Clone, serde::Serialize, serde::Deserialize)]
#[derive(Debug, Copy, Clone, serde::Serialize, serde::Deserialize)]
pub struct Bls12_381;
#[derive(Copy, Clone, serde::Serialize, serde::Deserialize)]
#[derive(Debug, Copy, Clone, serde::Serialize, serde::Deserialize)]
pub struct Bls12_446;

impl Curve for Bls12_381 {
Expand Down
7 changes: 5 additions & 2 deletions tfhe-zk-pok/src/curve_api/bls12_446.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ mod g1 {
}

impl G1Affine {
#[track_caller]
pub fn multi_mul_scalar(bases: &[Self], scalars: &[Zp]) -> G1 {
// SAFETY: interpreting a `repr(transparent)` pointer as its contents.
G1 {
Expand Down Expand Up @@ -124,6 +125,7 @@ mod g1 {
}
}

#[track_caller]
pub fn multi_mul_scalar(bases: &[Self], scalars: &[Zp]) -> Self {
use rayon::prelude::*;
let bases = bases
Expand Down Expand Up @@ -230,6 +232,7 @@ mod g2 {
}

impl G2Affine {
#[track_caller]
pub fn multi_mul_scalar(bases: &[Self], scalars: &[Zp]) -> G2 {
// SAFETY: interpreting a `repr(transparent)` pointer as its contents.
G2 {
Expand All @@ -247,10 +250,10 @@ mod g2 {
// functions. we cache it since it requires a Zp division
// https://hackmd.io/@tazAymRSQCGXTUKkbh1BAg/Sk27liTW9#Math-Formula-for-Point-Addition
pub(crate) fn compute_m(self, other: G2Affine) -> Option<crate::curve_446::Fq2> {
let zero = crate::curve_446::Fq2::ZERO;

// in the context of elliptic curves, the point at infinity is the zero element of the
// group
let zero = crate::curve_446::Fq2::ZERO;

if self.inner.infinity || other.inner.infinity {
return None;
}
Expand Down
207 changes: 2 additions & 205 deletions tfhe-zk-pok/src/curve_api/msm.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use ark_ec::short_weierstrass::Affine;
use ark_ec::AffineRepr;
use ark_ff::{AdditiveGroup, BigInt, BigInteger, Field, Fp, PrimeField};
use ark_ff::{AdditiveGroup, BigInteger, Field, Fp, PrimeField};
use rayon::prelude::*;

fn make_digits(a: &impl BigInteger, w: usize, num_bits: usize) -> impl Iterator<Item = i64> + '_ {
Expand Down Expand Up @@ -46,6 +46,7 @@ fn make_digits(a: &impl BigInteger, w: usize, num_bits: usize) -> impl Iterator<
}

// Compute msm using windowed non-adjacent form
#[track_caller]
pub fn msm_wnaf_g1_446(
bases: &[super::bls12_446::G1Affine],
scalars: &[super::bls12_446::Zp],
Expand Down Expand Up @@ -236,207 +237,3 @@ pub fn msm_wnaf_g1_446(
total
})
}

// Compute msm using windowed non-adjacent form
pub fn msm_wnaf_g1_446_extended(
bases: &[super::bls12_446::G1Affine],
scalars: &[super::bls12_446::Zp],
) -> super::bls12_446::G1 {
use super::bls12_446::*;
type BaseField = Fp<ark_ff::MontBackend<crate::curve_446::FqConfig, 7>, 7>;

// let num_bits = 75usize;
// let mask = BigInt([!0, (1 << 11) - 1, 0, 0, 0]);
// let scalars = &*scalars
// .par_iter()
// .map(|x| x.inner.into_bigint())
// .flat_map_iter(|x| (0..4).map(move |i| (x >> (75 * i)) & mask))
// .collect::<Vec<_>>();

let num_bits = 150usize;
let mask = BigInt([!0, !0, (1 << 22) - 1, 0, 0]);
let scalars = &*scalars
.par_iter()
.map(|x| x.inner.into_bigint())
.flat_map_iter(|x| (0..2).map(move |i| (x >> (150 * i)) & mask))
.collect::<Vec<_>>();

assert_eq!(bases.len(), scalars.len());

let size = bases.len();

let c = if size < 32 {
3
} else {
// natural log approx
(size.ilog2() as usize * 69 / 100) + 2
};
let c = c - 3;

let digits_count = (num_bits + c - 1) / c;
let scalar_digits = scalars
.into_par_iter()
.flat_map_iter(|s| make_digits(s, c, num_bits))
.collect::<Vec<_>>();

let zero = G1Affine {
inner: Affine::zero(),
};

let window_sums: Vec<_> = (0..digits_count)
.into_par_iter()
.map(|i| {
let n = 1 << c;
let mut indices = vec![vec![]; n];
let mut d = vec![BaseField::ZERO; n + 1];
let mut e = vec![BaseField::ZERO; n + 1];

for (idx, digits) in scalar_digits.chunks(digits_count).enumerate() {
use core::cmp::Ordering;
// digits is the digits thing of the first scalar?
let scalar = digits[i];
match 0.cmp(&scalar) {
Ordering::Less => indices[(scalar - 1) as usize].push(idx),
Ordering::Greater => indices[(-scalar - 1) as usize].push(!idx),
Ordering::Equal => (),
}
}

let mut buckets = vec![zero; 1 << c];

loop {
d[0] = BaseField::ONE;
for (k, (bucket, idx)) in core::iter::zip(&mut buckets, &mut indices).enumerate() {
if let Some(idx) = idx.last().copied() {
let value = if idx >> (usize::BITS - 1) == 1 {
let mut val = bases[!idx];
val.inner.y = -val.inner.y;
val
} else {
bases[idx]
};

if !bucket.inner.infinity {
let a = value.inner.x - bucket.inner.x;
if a != BaseField::ZERO {
d[k + 1] = d[k] * a;
} else if value.inner.y == bucket.inner.y {
d[k + 1] = d[k] * value.inner.y.double();
} else {
d[k + 1] = d[k];
}
continue;
}
}
d[k + 1] = d[k];
}
e[n] = d[n].inverse().unwrap();

for (k, (bucket, idx)) in core::iter::zip(&mut buckets, &mut indices)
.enumerate()
.rev()
{
if let Some(idx) = idx.last().copied() {
let value = if idx >> (usize::BITS - 1) == 1 {
let mut val = bases[!idx];
val.inner.y = -val.inner.y;
val
} else {
bases[idx]
};

if !bucket.inner.infinity {
let a = value.inner.x - bucket.inner.x;
if a != BaseField::ZERO {
e[k] = e[k + 1] * a;
} else if value.inner.y == bucket.inner.y {
e[k] = e[k + 1] * value.inner.y.double();
} else {
e[k] = e[k + 1];
}
continue;
}
}
e[k] = e[k + 1];
}

let d = &d[..n];
let e = &e[1..];

let mut empty = true;
for ((&d, &e), (bucket, idx)) in core::iter::zip(
core::iter::zip(d, e),
core::iter::zip(&mut buckets, &mut indices),
) {
empty &= idx.len() <= 1;
if let Some(idx) = idx.pop() {
let value = if idx >> (usize::BITS - 1) == 1 {
let mut val = bases[!idx];
val.inner.y = -val.inner.y;
val
} else {
bases[idx]
};

if !bucket.inner.infinity {
let x1 = bucket.inner.x;
let x2 = value.inner.x;
let y1 = bucket.inner.y;
let y2 = value.inner.y;

let eq_x = x1 == x2;

if eq_x && y1 != y2 {
bucket.inner.infinity = true;
} else {
let r = d * e;
let m = if eq_x {
let x1 = x1.square();
x1 + x1.double()
} else {
y2 - y1
};
let m = m * r;

let x3 = m.square() - x1 - x2;
let y3 = m * (x1 - x3) - y1;
bucket.inner.x = x3;
bucket.inner.y = y3;
}
} else {
*bucket = value;
}
}
}

if empty {
break;
}
}

let mut running_sum = G1::ZERO;
let mut res = G1::ZERO;
buckets.into_iter().rev().for_each(|b| {
running_sum.inner += b.inner;
res += running_sum;
});
res
})
.collect();

// We store the sum for the lowest window.
let lowest = *window_sums.first().unwrap();

// We're traversing windows from high to low.
lowest
+ window_sums[1..]
.iter()
.rev()
.fold(G1::ZERO, |mut total, &sum_i| {
total += sum_i;
for _ in 0..c {
total = total.double();
}
total
})
}
Loading

0 comments on commit ce9da12

Please sign in to comment.