From 7a3d2a90d3032ee0da10d33cdfbcf4b077e00615 Mon Sep 17 00:00:00 2001 From: ashWhiteHat Date: Tue, 5 Dec 2023 15:54:09 +0900 Subject: [PATCH 1/3] feat: base field should be fft field --- zkstd/src/circuit.rs | 7 +- zkstd/src/circuit/gadget/binary.rs | 4 +- zkstd/src/circuit/gadget/curve.rs | 37 +++++------ zkstd/src/circuit/gadget/field.rs | 27 ++++---- zkstd/src/r1cs.rs | 65 +++++++++---------- zkstd/tests/gadget.rs | 57 +++++++--------- zkstd/tests/grumpkin.rs | 101 +++++++++++++++++++++++++++-- 7 files changed, 188 insertions(+), 110 deletions(-) diff --git a/zkstd/src/circuit.rs b/zkstd/src/circuit.rs index c891083d..46f75fca 100644 --- a/zkstd/src/circuit.rs +++ b/zkstd/src/circuit.rs @@ -1,7 +1,7 @@ mod gadget; pub mod prelude; -use crate::common::{BNAffine, Deserialize, PrimeField, Serialize}; +use crate::common::{BNAffine, Deserialize, FftField, PrimeField, Serialize}; pub trait CircuitDriver: Clone { const NUM_BITS: u16; @@ -9,10 +9,11 @@ pub trait CircuitDriver: Clone { type Affine: BNAffine; // curve base field - type Base: PrimeField + From + Serialize + for<'de> Deserialize<'de>; + type Base: FftField + From + Serialize + for<'de> Deserialize<'de>; // curve scalar field type Scalar: PrimeField + From + Serialize + for<'de> Deserialize<'de>; + // bn curve 3b param - fn b3() -> Self::Scalar; + fn b3() -> Self::Base; } diff --git a/zkstd/src/circuit/gadget/binary.rs b/zkstd/src/circuit/gadget/binary.rs index e035b32c..a6262f80 100644 --- a/zkstd/src/circuit/gadget/binary.rs +++ b/zkstd/src/circuit/gadget/binary.rs @@ -8,14 +8,14 @@ pub struct BinaryAssignment(Wire, PhantomData); impl BinaryAssignment { pub fn instance(cs: &mut R1cs, bit: u8) -> Self { let wire = cs.public_wire(); - cs.x.push(C::Scalar::from(bit as u64)); + cs.x.push(C::Base::from(bit as u64)); Self(wire, PhantomData::default()) } pub fn witness(cs: &mut R1cs, bit: u8) -> Self { let wire = cs.private_wire(); - cs.w.push(C::Scalar::from(bit as u64)); + cs.w.push(C::Base::from(bit as u64)); Self(wire, PhantomData::default()) } diff --git a/zkstd/src/circuit/gadget/curve.rs b/zkstd/src/circuit/gadget/curve.rs index ac3ea55c..3a920f8f 100644 --- a/zkstd/src/circuit/gadget/curve.rs +++ b/zkstd/src/circuit/gadget/curve.rs @@ -4,6 +4,7 @@ use super::field::FieldAssignment; use crate::circuit::CircuitDriver; use crate::common::{BNProjective, CurveGroup, Group, IntGroup, Ring}; use crate::r1cs::R1cs; +use crate::traits::BNAffine; #[derive(Clone)] pub struct PointAssignment { @@ -13,30 +14,30 @@ pub struct PointAssignment { } impl PointAssignment { - pub fn instance(cs: &mut R1cs, x: C::Scalar, y: C::Scalar, is_infinity: bool) -> Self { - let x = FieldAssignment::instance(cs, x); - let y = FieldAssignment::instance(cs, y); + pub fn instance(cs: &mut R1cs, point: C::Affine) -> Self { + let x = FieldAssignment::instance(cs, point.get_x().into()); + let y = FieldAssignment::instance(cs, point.get_y().into()); let z = FieldAssignment::instance( cs, - if is_infinity { - C::Scalar::zero() + if point.is_identity() { + C::Base::zero() } else { - C::Scalar::one() + C::Base::one() }, ); Self { x, y, z } } - pub fn witness(cs: &mut R1cs, x: C::Scalar, y: C::Scalar, is_infinity: bool) -> Self { + pub fn witness(cs: &mut R1cs, x: C::Base, y: C::Base, is_infinity: bool) -> Self { let x = FieldAssignment::witness(cs, x); let y = FieldAssignment::witness(cs, y); let z = FieldAssignment::witness( cs, if is_infinity { - C::Scalar::zero() + C::Base::zero() } else { - C::Scalar::one() + C::Base::one() }, ); @@ -46,11 +47,11 @@ impl PointAssignment { pub fn assert_equal_public_point( &self, cs: &mut R1cs, - point: impl BNProjective, + point: ::Extended, ) { - let point_x = FieldAssignment::constant(&point.get_x()); - let point_y = FieldAssignment::constant(&point.get_y()); - let point_z = FieldAssignment::constant(&point.get_z()); + let point_x = FieldAssignment::constant(&C::Base::from(point.get_x())); + let point_y = FieldAssignment::constant(&C::Base::from(point.get_y())); + let point_z = FieldAssignment::constant(&C::Base::from(point.get_z())); let xz1 = FieldAssignment::mul(cs, &self.x, &point_z); let xz2 = FieldAssignment::mul(cs, &point_x, &self.z); @@ -64,7 +65,7 @@ impl PointAssignment { } pub fn add(&self, rhs: &Self, cs: &mut R1cs) -> Self { - let b3 = FieldAssignment::::constant(&C::b3()); + let b3 = FieldAssignment::::constant(&C::Base::from(C::b3())); let t0 = FieldAssignment::mul(cs, &self.x, &rhs.x); let t1 = FieldAssignment::mul(cs, &self.y, &rhs.y); let t2 = FieldAssignment::mul(cs, &self.z, &rhs.z); @@ -107,7 +108,7 @@ impl PointAssignment { } pub fn double(&self, cs: &mut R1cs) -> Self { - let b3 = FieldAssignment::::constant(&C::b3()); + let b3 = FieldAssignment::::constant(&C::Base::from(C::b3())); let t0 = FieldAssignment::mul(cs, &self.y, &self.y); let z3 = &t0 + &t0; let z3 = &z3 + &z3; @@ -137,8 +138,7 @@ impl PointAssignment { /// coordinate scalar pub fn scalar_point(&self, cs: &mut R1cs, scalar: &FieldAssignment) -> Self { let i = C::Affine::ADDITIVE_IDENTITY; - let mut res = - PointAssignment::instance(cs, i.get_x().into(), i.get_y().into(), i.is_identity()); + let mut res = PointAssignment::instance(cs, i); for bit in FieldAssignment::to_bits(cs, scalar).iter() { res = res.double(cs); let point_to_add = self.select_identity(cs, bit); @@ -153,8 +153,7 @@ impl PointAssignment { let bit = FieldAssignment::from(bit); Self { x: FieldAssignment::mul(cs, &x, &bit), - y: &(&FieldAssignment::mul(cs, &y, &bit) - + &FieldAssignment::constant(&C::Scalar::one())) + y: &(&FieldAssignment::mul(cs, &y, &bit) + &FieldAssignment::constant(&C::Base::one())) - &bit, z: FieldAssignment::mul(cs, &z, &bit), } diff --git a/zkstd/src/circuit/gadget/field.rs b/zkstd/src/circuit/gadget/field.rs index 22504819..08c30108 100644 --- a/zkstd/src/circuit/gadget/field.rs +++ b/zkstd/src/circuit/gadget/field.rs @@ -5,27 +5,28 @@ use crate::matrix::SparseRow; use crate::r1cs::{R1cs, Wire}; #[derive(Clone)] -pub struct FieldAssignment(SparseRow); +pub struct FieldAssignment(SparseRow); impl FieldAssignment { - pub fn inner(&self) -> &SparseRow { + pub fn inner(&self) -> &SparseRow { &self.0 } - pub fn instance(cs: &mut R1cs, instance: C::Scalar) -> Self { + + pub fn instance(cs: &mut R1cs, instance: C::Base) -> Self { let wire = cs.public_wire(); cs.x.push(instance); Self(SparseRow::from(wire)) } - pub fn witness(cs: &mut R1cs, witness: C::Scalar) -> Self { + pub fn witness(cs: &mut R1cs, witness: C::Base) -> Self { let wire = cs.private_wire(); cs.w.push(witness); Self(SparseRow::from(wire)) } - pub fn constant(constant: &C::Scalar) -> Self { + pub fn constant(constant: &C::Base) -> Self { Self(SparseRow(vec![(Wire::ONE, *constant)])) } @@ -63,7 +64,7 @@ impl FieldAssignment { z } - pub fn range_check(cs: &mut R1cs, a_bits: &[BinaryAssignment], c: C::Scalar) { + pub fn range_check(cs: &mut R1cs, a_bits: &[BinaryAssignment], c: C::Base) { let c_bits = c .to_bits() .into_iter() @@ -74,7 +75,7 @@ impl FieldAssignment { assert!(a_bits .iter() .take(a_bits.len() - c_bits.len()) - .all(|b| cs[*b.inner()] == C::Scalar::zero())); + .all(|b| cs[*b.inner()] == C::Base::zero())); let a_bits = a_bits .iter() @@ -104,24 +105,24 @@ impl FieldAssignment { if c == 1 { let bool_constr = FieldAssignment::mul( cs, - &(&bit_field - &FieldAssignment::constant(&C::Scalar::one())), + &(&bit_field - &FieldAssignment::constant(&C::Base::one())), &bit_field, ); FieldAssignment::eq( cs, &bool_constr, - &FieldAssignment::constant(&C::Scalar::zero()), + &FieldAssignment::constant(&C::Base::zero()), ); } else if c == 0 { let bool_constr = FieldAssignment::mul( cs, - &(&(&FieldAssignment::constant(&C::Scalar::one()) - &bit_field) - &p[i - 1]), + &(&(&FieldAssignment::constant(&C::Base::one()) - &bit_field) - &p[i - 1]), &bit_field, ); FieldAssignment::eq( cs, &bool_constr, - &FieldAssignment::constant(&C::Scalar::zero()), + &FieldAssignment::constant(&C::Base::zero()), ); } } @@ -129,7 +130,7 @@ impl FieldAssignment { /// To bit representation in Big-endian pub fn to_bits(cs: &mut R1cs, x: &Self) -> Vec> { - let bound = C::Scalar::MODULUS - C::Scalar::one(); + let bound = C::Base::MODULUS - C::Base::one(); let bit_repr: Vec> = x .inner() @@ -146,7 +147,7 @@ impl FieldAssignment { cs.mul_gate(&x.0, &SparseRow::one(), &y.0) } - pub fn eq_constant(cs: &mut R1cs, x: &Self, c: &C::Scalar) { + pub fn eq_constant(cs: &mut R1cs, x: &Self, c: &C::Base) { cs.mul_gate( &x.0, &SparseRow::one(), diff --git a/zkstd/src/r1cs.rs b/zkstd/src/r1cs.rs index ebbcc885..3b54586a 100644 --- a/zkstd/src/r1cs.rs +++ b/zkstd/src/r1cs.rs @@ -13,17 +13,17 @@ pub struct R1cs { // 1. Structure S // a, b and c matrices and matrix size m: usize, - a: SparseMatrix, - b: SparseMatrix, - c: SparseMatrix, + a: SparseMatrix, + b: SparseMatrix, + c: SparseMatrix, // 2. Instance // r1cs instance includes one constant and public inputs and outputs - pub(crate) x: DenseVectors, + pub(crate) x: DenseVectors, // 3. Witness // r1cs witness includes private inputs and intermediate value - pub(crate) w: DenseVectors, + pub(crate) w: DenseVectors, } impl R1cs { @@ -39,11 +39,11 @@ impl R1cs { self.w.len() } - pub fn x(&self) -> Vec { + pub fn x(&self) -> Vec { self.x.get() } - pub fn w(&self) -> Vec { + pub fn w(&self) -> Vec { self.w.get() } @@ -51,9 +51,9 @@ impl R1cs { pub fn matrices( &self, ) -> ( - SparseMatrix, - SparseMatrix, - SparseMatrix, + SparseMatrix, + SparseMatrix, + SparseMatrix, ) { (self.a.clone(), self.b.clone(), self.c.clone()) } @@ -75,12 +75,7 @@ impl R1cs { .all(|(left, right)| left == right) } - fn append( - &mut self, - a: SparseRow, - b: SparseRow, - c: SparseRow, - ) { + fn append(&mut self, a: SparseRow, b: SparseRow, c: SparseRow) { self.a.0.push(a); self.b.0.push(b); self.c.0.push(c); @@ -100,9 +95,9 @@ impl R1cs { /// constrain x * y = z pub fn mul_gate( &mut self, - x: &SparseRow, - y: &SparseRow, - z: &SparseRow, + x: &SparseRow, + y: &SparseRow, + z: &SparseRow, ) { self.append(x.clone(), y.clone(), z.clone()); } @@ -110,9 +105,9 @@ impl R1cs { /// constrain x + y = z pub fn add_gate( &mut self, - x: &SparseRow, - y: &SparseRow, - z: &SparseRow, + x: &SparseRow, + y: &SparseRow, + z: &SparseRow, ) { self.append(x + y, SparseRow::from(Wire::ONE), z.clone()); } @@ -120,20 +115,20 @@ impl R1cs { /// constrain x - y = z pub fn sub_gate( &mut self, - x: &SparseRow, - y: &SparseRow, - z: &SparseRow, + x: &SparseRow, + y: &SparseRow, + z: &SparseRow, ) { self.append(x - y, SparseRow::from(Wire::ONE), z.clone()); } /// constrain x == y - pub fn equal_gate(&mut self, x: &SparseRow, y: &SparseRow) { + pub fn equal_gate(&mut self, x: &SparseRow, y: &SparseRow) { self.mul_gate(x, &SparseRow::one(), y); } #[allow(clippy::type_complexity)] - pub fn evaluate(&self) -> (Vec, Vec, Vec) { + pub fn evaluate(&self) -> (Vec, Vec, Vec) { let a_evals = self.a.evaluate_with_z(&self.x, &self.w); let b_evals = self.b.evaluate_with_z(&self.x, &self.w); let c_evals = self.c.evaluate_with_z(&self.x, &self.w); @@ -147,14 +142,14 @@ impl R1cs { m_l_1: usize, ) -> ( ( - Vec>, - Vec>, - Vec>, + Vec>, + Vec>, + Vec>, ), ( - Vec>, - Vec>, - Vec>, + Vec>, + Vec>, + Vec>, ), ) { let (a_x, a_w) = self.a.x_and_w(l, m_l_1); @@ -172,14 +167,14 @@ impl Default for R1cs { a: SparseMatrix::default(), b: SparseMatrix::default(), c: SparseMatrix::default(), - x: DenseVectors::new(vec![C::Scalar::one()]), + x: DenseVectors::new(vec![C::Base::one()]), w: DenseVectors::default(), } } } impl Index for R1cs { - type Output = C::Scalar; + type Output = C::Base; fn index(&self, w: Wire) -> &Self::Output { match w { diff --git a/zkstd/tests/gadget.rs b/zkstd/tests/gadget.rs index 6db758f9..1d925a49 100644 --- a/zkstd/tests/gadget.rs +++ b/zkstd/tests/gadget.rs @@ -2,25 +2,25 @@ mod grumpkin; #[cfg(test)] mod grumpkin_gadget_tests { - use crate::grumpkin::{Affine, Fq as Base, Fr as Scalar, GrumpkinDriver}; + use crate::grumpkin::{Affine, Fq as Scalar, Fr as Base, GrumpkinDriver}; use rand_core::OsRng; use zkstd::circuit::prelude::{FieldAssignment, PointAssignment, R1cs}; - use zkstd::common::{BNAffine, BNProjective, CurveGroup, Group, PrimeField}; + use zkstd::common::{BNAffine, BNProjective, Group, PrimeField}; #[test] fn range_proof_test() { for _ in 0..100 { let mut cs: R1cs = R1cs::default(); let mut ncs = cs.clone(); - let bound = Scalar::from(10); + let bound = Base::from(10); let x_ass = FieldAssignment::instance(&mut cs, bound); let x_bits = FieldAssignment::to_bits(&mut cs, &x_ass); FieldAssignment::range_check(&mut cs, &x_bits, bound); assert!(cs.is_sat()); - let x_ass = FieldAssignment::instance(&mut ncs, bound + Scalar::one()); + let x_ass = FieldAssignment::instance(&mut ncs, bound + Base::one()); let x_bits = FieldAssignment::to_bits(&mut ncs, &x_ass); FieldAssignment::range_check(&mut ncs, &x_bits, bound); assert!(!ncs.is_sat()); @@ -31,8 +31,8 @@ mod grumpkin_gadget_tests { fn field_add_test() { let mut cs: R1cs = R1cs::default(); let mut ncs = cs.clone(); - let a = Scalar::random(OsRng); - let b = Scalar::random(OsRng); + let a = Base::random(OsRng); + let b = Base::random(OsRng); let mut c = a + b; // a + b == c @@ -45,7 +45,7 @@ mod grumpkin_gadget_tests { assert!(cs.is_sat()); // a + b != c - c += Scalar::one(); + c += Base::one(); let x = FieldAssignment::instance(&mut ncs, a); let y = FieldAssignment::witness(&mut ncs, b); let z = FieldAssignment::instance(&mut ncs, c); @@ -59,8 +59,8 @@ mod grumpkin_gadget_tests { fn field_mul_test() { let mut cs: R1cs = R1cs::default(); let mut ncs = cs.clone(); - let a = Scalar::random(OsRng); - let b = Scalar::random(OsRng); + let a = Base::random(OsRng); + let b = Base::random(OsRng); let mut c = a * b; // a * b == c @@ -73,7 +73,7 @@ mod grumpkin_gadget_tests { assert!(cs.is_sat()); // a * b != c - c += Scalar::one(); + c += Base::one(); let x = FieldAssignment::instance(&mut ncs, a); let y = FieldAssignment::witness(&mut ncs, b); let z = FieldAssignment::instance(&mut ncs, c); @@ -87,9 +87,9 @@ mod grumpkin_gadget_tests { fn field_ops_test() { let mut cs: R1cs = R1cs::default(); let mut ncs = cs.clone(); - let input = Scalar::from(3); - let c = Scalar::from(5); - let out = Scalar::from(35); + let input = Base::from(3); + let c = Base::from(5); + let out = Base::from(35); // x^3 + x + 5 == 35 let x = FieldAssignment::witness(&mut cs, input); @@ -103,8 +103,8 @@ mod grumpkin_gadget_tests { assert!(cs.is_sat()); // x^3 + x + 5 != 36 - let c = Scalar::from(5); - let out = Scalar::from(36); + let c = Base::from(5); + let out = Base::from(36); let x = FieldAssignment::witness(&mut ncs, input); let c = FieldAssignment::constant(&c); let z = FieldAssignment::instance(&mut ncs, out); @@ -122,13 +122,7 @@ mod grumpkin_gadget_tests { let mut cs: R1cs = R1cs::default(); let point = Affine::random(OsRng); - let circuit_double = PointAssignment::instance( - &mut cs, - point.get_x(), - point.get_y(), - point.is_identity(), - ) - .double(&mut cs); + let circuit_double = PointAssignment::instance(&mut cs, point).double(&mut cs); let expected = point.to_extended().double(); @@ -146,10 +140,8 @@ mod grumpkin_gadget_tests { let a = Affine::random(OsRng); let b = Affine::ADDITIVE_IDENTITY; - let a_assignment = - PointAssignment::instance(&mut cs, a.get_x(), a.get_y(), a.is_identity()); - let b_assignment = - PointAssignment::instance(&mut cs, b.get_x(), b.get_y(), b.is_identity()); + let a_assignment = PointAssignment::instance(&mut cs, a); + let b_assignment = PointAssignment::instance(&mut cs, b); let expected = a + b; @@ -165,10 +157,8 @@ mod grumpkin_gadget_tests { let a = Affine::random(OsRng); let b = Affine::random(OsRng); - let a_assignment = - PointAssignment::instance(&mut cs, a.get_x(), a.get_y(), a.is_identity()); - let b_assignment = - PointAssignment::instance(&mut cs, b.get_x(), b.get_y(), b.is_identity()); + let a_assignment = PointAssignment::instance(&mut cs, a); + let b_assignment = PointAssignment::instance(&mut cs, b); let expected = a.to_extended() + b.to_extended(); @@ -184,13 +174,12 @@ mod grumpkin_gadget_tests { fn curve_scalar_mul_test() { for _ in 0..100 { let mut cs: R1cs = R1cs::default(); - let x = Scalar::random(OsRng); + let x = Base::random(OsRng); let p = Affine::random(OsRng); let x_assignment = FieldAssignment::instance(&mut cs, x); // Fr - let p_assignment = - PointAssignment::instance(&mut cs, p.get_x(), p.get_y(), p.is_identity()); - let expected = p * Base::from(x); + let p_assignment = PointAssignment::instance(&mut cs, p); + let expected = p * Scalar::from(x); assert_eq!(x.to_bits(), Base::from(x).to_bits()); diff --git a/zkstd/tests/grumpkin.rs b/zkstd/tests/grumpkin.rs index 9ad7c038..2ce9fc31 100644 --- a/zkstd/tests/grumpkin.rs +++ b/zkstd/tests/grumpkin.rs @@ -147,6 +147,99 @@ macro_rules! cycle_pair_field { cycle_pair_field!(Fr, FR_GENERATOR, FR_MODULUS, FR_R, FR_R2, FR_R3, FR_INV); cycle_pair_field!(Fq, FQ_GENERATOR, FQ_MODULUS, FQ_R, FQ_R2, FQ_R3, FQ_INV); +impl FftField for Fr { + const S: usize = 28; + + const ROOT_OF_UNITY: Self = Fr::to_mont_form([ + 0xd34f1ed960c37c9c, + 0x3215cf6dd39329c8, + 0x98865ea93dd31f74, + 0x03ddb9f5166d18b7, + ]); + + const MULTIPLICATIVE_GENERATOR: Self = Fr::to_mont_form(FR_GENERATOR); + + fn pow(self, val: u64) -> Self { + Self(pow(self.0, [val, 0, 0, 0], FR_R, FR_MODULUS, FR_INV)) + } + + fn divn(&mut self, mut n: u32) { + if n >= 256 { + *self = Self::from(0u64); + return; + } + + while n >= 64 { + let mut t = 0; + for i in self.0.iter_mut().rev() { + core::mem::swap(&mut t, i); + } + n -= 64; + } + + if n > 0 { + let mut t = 0; + for i in self.0.iter_mut().rev() { + let t2 = *i << (64 - n); + *i >>= n; + *i |= t; + t = t2; + } + } + } + + fn from_hash(hash: &[u8; 64]) -> Self { + let d0 = Self([ + u64::from_le_bytes(hash[0..8].try_into().unwrap()), + u64::from_le_bytes(hash[8..16].try_into().unwrap()), + u64::from_le_bytes(hash[16..24].try_into().unwrap()), + u64::from_le_bytes(hash[24..32].try_into().unwrap()), + ]); + let d1 = Self([ + u64::from_le_bytes(hash[32..40].try_into().unwrap()), + u64::from_le_bytes(hash[40..48].try_into().unwrap()), + u64::from_le_bytes(hash[48..56].try_into().unwrap()), + u64::from_le_bytes(hash[56..64].try_into().unwrap()), + ]); + d0 * Self(FR_R2) + d1 * Self(FR_R3) + } + + fn reduce(&self) -> Self { + Self(mont( + [self.0[0], self.0[1], self.0[2], self.0[3], 0, 0, 0, 0], + FR_MODULUS, + FR_INV, + )) + } + + fn is_even(&self) -> bool { + self.0[0] % 2 == 0 + } + + fn mod_2_pow_k(&self, k: u8) -> u8 { + (self.0[0] & ((1 << k) - 1)) as u8 + } + + fn mods_2_pow_k(&self, w: u8) -> i8 { + assert!(w < 32u8); + let modulus = self.mod_2_pow_k(w) as i8; + let two_pow_w_minus_one = 1i8 << (w - 1); + + match modulus >= two_pow_w_minus_one { + false => modulus, + true => modulus - ((1u8 << w) as i8), + } + } +} + +impl From<[u64; 4]> for Fr { + fn from(val: [u64; 4]) -> Fr { + Fr(val) + } +} + +impl ParallelCmp for Fr {} + pub(crate) const FR_PARAM_B: Fr = Fr::new_unchecked([ 0xdd7056026000005a, 0x223fa97acb319311, @@ -429,13 +522,13 @@ pub struct GrumpkinDriver; impl CircuitDriver for GrumpkinDriver { const NUM_BITS: u16 = 254; - type Affine = G1Affine; + type Affine = Affine; - type Base = Fq; + type Base = Fr; - type Scalar = Fr; + type Scalar = Fq; - fn b3() -> Self::Scalar { + fn b3() -> Self::Base { FR_PARAM_B3 } } From 1e9a270bf8983883ae882fdc2c4e39aa56407bdf Mon Sep 17 00:00:00 2001 From: ashWhiteHat Date: Tue, 5 Dec 2023 15:54:45 +0900 Subject: [PATCH 2/3] feat: grumpkin driver modification --- groth16/src/lib.rs | 14 +++++++------- grumpkin/src/driver.rs | 12 +++++++----- 2 files changed, 14 insertions(+), 12 deletions(-) diff --git a/groth16/src/lib.rs b/groth16/src/lib.rs index 78a92e65..b4aca25f 100644 --- a/groth16/src/lib.rs +++ b/groth16/src/lib.rs @@ -22,7 +22,7 @@ mod tests { use crate::error::Error; use crate::zksnark::ZkSnark; - use bn_254::Fr as BnScalar; + use bn_254::Fr as GrumpkinBase; use grumpkin::driver::GrumpkinDriver; use zkstd::circuit::prelude::{FieldAssignment, R1cs}; use zkstd::common::OsRng; @@ -31,12 +31,12 @@ mod tests { fn arithmetic_test() { #[derive(Debug)] pub struct DummyCircuit { - x: BnScalar, - o: BnScalar, + x: GrumpkinBase, + o: GrumpkinBase, } impl DummyCircuit { - pub fn new(x: BnScalar, o: BnScalar) -> Self { + pub fn new(x: GrumpkinBase, o: GrumpkinBase) -> Self { Self { x, o } } } @@ -51,7 +51,7 @@ mod tests { fn synthesize(&self, composer: &mut R1cs) -> Result<(), Error> { let x = FieldAssignment::instance(composer, self.x); let o = FieldAssignment::instance(composer, self.o); - let c = FieldAssignment::constant(&BnScalar::from(5)); + let c = FieldAssignment::constant(&GrumpkinBase::from(5)); let sym1 = FieldAssignment::mul(composer, &x, &x); let y = FieldAssignment::mul(composer, &sym1, &x); @@ -64,8 +64,8 @@ mod tests { } } - let x = BnScalar::from(3); - let o = BnScalar::from(35); + let x = GrumpkinBase::from(3); + let o = GrumpkinBase::from(35); let circuit = DummyCircuit::new(x, o); let (mut prover, verifier) = diff --git a/grumpkin/src/driver.rs b/grumpkin/src/driver.rs index 5296eb13..8d40fde4 100644 --- a/grumpkin/src/driver.rs +++ b/grumpkin/src/driver.rs @@ -1,5 +1,6 @@ +use crate::curve::Affine; use crate::params::PARAM_B3; -use bn_254::{Fq, Fr, G1Affine}; +use bn_254::{Fq, Fr}; use zkstd::circuit::CircuitDriver; #[derive(Clone, Debug, Default, PartialEq, Eq)] @@ -7,13 +8,14 @@ pub struct GrumpkinDriver; impl CircuitDriver for GrumpkinDriver { const NUM_BITS: u16 = 254; - type Affine = G1Affine; - type Base = Fq; + type Affine = Affine; - type Scalar = Fr; + type Base = Fr; - fn b3() -> Self::Scalar { + type Scalar = Fq; + + fn b3() -> Self::Base { PARAM_B3 } } From ededf49f5dfe2924939f27f5d2489bbb29ba6fb7 Mon Sep 17 00:00:00 2001 From: ashWhiteHat Date: Thu, 7 Dec 2023 16:42:54 +0900 Subject: [PATCH 3/3] feat: circuit driver wip --- nova/src/circuit/transcript.rs | 9 ++++---- nova/src/driver.rs | 37 +++++++++++++++++++++++++++++++ nova/src/gadget/mimc.rs | 4 ++-- nova/src/ivc.rs | 3 ++- nova/src/lib.rs | 1 + nova/src/prover.rs | 10 ++++----- nova/src/relaxed_r1cs/instance.rs | 2 +- nova/src/relaxed_r1cs/witness.rs | 2 +- zkstd/src/circuit.rs | 4 ++-- 9 files changed, 55 insertions(+), 17 deletions(-) create mode 100644 nova/src/driver.rs diff --git a/nova/src/circuit/transcript.rs b/nova/src/circuit/transcript.rs index 5c55cc70..c9b62b16 100644 --- a/nova/src/circuit/transcript.rs +++ b/nova/src/circuit/transcript.rs @@ -14,7 +14,7 @@ impl Default for MimcROCircuit { Self { hasher: MimcAssignment::default(), state: Vec::default(), - key: FieldAssignment::constant(&C::Scalar::zero()), + key: FieldAssignment::constant(&C::Base::zero()), } } } @@ -47,7 +47,7 @@ mod tests { use grumpkin::{driver::GrumpkinDriver, Affine}; use rand_core::OsRng; use zkstd::circuit::prelude::{FieldAssignment, PointAssignment, R1cs}; - use zkstd::common::{CurveGroup, Group}; + use zkstd::common::Group; #[test] fn mimc_circuit() { @@ -57,9 +57,8 @@ mod tests { let point = Affine::random(OsRng); let scalar = Fr::random(OsRng); - let point_assignment = - PointAssignment::instance(&mut cs, point.get_x(), point.get_y(), point.is_identity()); - let scalar_assignment = FieldAssignment::instance(&mut cs, scalar); + let point_assignment = PointAssignment::instance(&mut cs, point); + let scalar_assignment = FieldAssignment::instance(&mut cs, scalar.into()); mimc.append(scalar); mimc.append_point(point); mimc_circuit.append(scalar_assignment); diff --git a/nova/src/driver.rs b/nova/src/driver.rs new file mode 100644 index 00000000..898940c2 --- /dev/null +++ b/nova/src/driver.rs @@ -0,0 +1,37 @@ +use bn_254::{Fq, Fr, G1Affine as BN254Affine, params::PARAM_B3 as BN254_B3}; +use grumpkin::{Affine as GrumpkinAffine, params::PARAM_B3 as Grumpkin_B3}; +use zkstd::circuit::CircuitDriver; + +#[derive(Clone, Debug, Default, PartialEq, Eq)] +pub struct GrumpkinDriver; + +impl CircuitDriver for GrumpkinDriver { + const NUM_BITS: u16 = 254; + + type Affine = GrumpkinAffine; + + type Base = Fr; + + type Scalar = Fq; + + fn b3() -> Self::Base { + Grumpkin_B3 + } +} + +#[derive(Clone, Debug, Default, PartialEq, Eq)] +pub struct BN254Driver; + +impl CircuitDriver for BN254Driver { + const NUM_BITS: u16 = 254; + + type Affine = BN254Affine; + + type Base = Fq; + + type Scalar = Fr; + + fn b3() -> Self::Base { + BN254_B3 + } +} diff --git a/nova/src/gadget/mimc.rs b/nova/src/gadget/mimc.rs index 9f9aef0e..d62ede56 100644 --- a/nova/src/gadget/mimc.rs +++ b/nova/src/gadget/mimc.rs @@ -3,13 +3,13 @@ use crate::hash::Mimc; use zkstd::circuit::prelude::{CircuitDriver, FieldAssignment, R1cs}; pub(crate) struct MimcAssignment { - constants: [C::Scalar; ROUND], + constants: [C::Base; ROUND], } impl Default for MimcAssignment { fn default() -> Self { Self { - constants: Mimc::::default().constants, + constants: Mimc::::default().constants, } } } diff --git a/nova/src/ivc.rs b/nova/src/ivc.rs index 4ffb863b..e6e6c09e 100644 --- a/nova/src/ivc.rs +++ b/nova/src/ivc.rs @@ -74,6 +74,7 @@ mod tests { use crate::test::ExampleFunction; use grumpkin::driver::GrumpkinDriver; + use bn_254::Fq; use rand_core::OsRng; use zkstd::circuit::prelude::R1cs; use zkstd::matrix::DenseVectors; @@ -82,7 +83,7 @@ mod tests { #[test] fn ivc_test() { let r1cs: R1cs = example_r1cs(1); - let z0 = DenseVectors::new(r1cs.x()); + let z0 = DenseVectors::new(r1cs.x().iter().map(|x| Fq::from(*x)).collect()); let mut ivc = Ivc::new(r1cs, OsRng, z0); ivc.recurse::>(); let proof = ivc.prove(); diff --git a/nova/src/lib.rs b/nova/src/lib.rs index 8513174f..3bb4f8e4 100644 --- a/nova/src/lib.rs +++ b/nova/src/lib.rs @@ -2,6 +2,7 @@ #![allow(unused_variables, dead_code)] mod circuit; +mod driver; mod function; mod gadget; mod hash; diff --git a/nova/src/prover.rs b/nova/src/prover.rs index bc225d3b..43bc8ad1 100644 --- a/nova/src/prover.rs +++ b/nova/src/prover.rs @@ -8,15 +8,15 @@ use zkstd::circuit::prelude::{CircuitDriver, R1cs}; use zkstd::common::{Ring, RngCore}; use zkstd::matrix::DenseVectors; -pub struct Prover { +pub struct Prover { // public parameters pp: PedersenCommitment, // r1cs structure - f: R1cs, + f: R1cs, } -impl Prover { +impl Prover { pub fn new(f: R1cs, rng: impl RngCore) -> Self { let m = f.m(); let n = m.next_power_of_two() as u64; @@ -61,8 +61,8 @@ impl Prover { let u2 = relaxed_r1cs.u(); let m = self.f.m(); let (a, b, c) = self.f.matrices(); - let (w0, w1) = (DenseVectors::new(r1cs.w()), relaxed_r1cs.w()); - let (x0, x1) = (DenseVectors::new(r1cs.x()), relaxed_r1cs.x()); + let (w0, w1) = (DenseVectors::new(r1cs.w().iter().map(|c| C::Scalar::from(*c)).collect()), relaxed_r1cs.w()); + let (x0, x1) = (DenseVectors::new(r1cs.x().iter().map(|c| C::Scalar::from(*c)).collect()), relaxed_r1cs.x()); // matrices and z vector matrix multiplication let az2 = a.prod(&m, &x1, &w1); diff --git a/nova/src/relaxed_r1cs/instance.rs b/nova/src/relaxed_r1cs/instance.rs index aa36dca9..1c5aa3c4 100644 --- a/nova/src/relaxed_r1cs/instance.rs +++ b/nova/src/relaxed_r1cs/instance.rs @@ -31,7 +31,7 @@ impl RelaxedR1csInstance { C::Affine::ADDITIVE_IDENTITY, C::Scalar::one(), C::Affine::ADDITIVE_IDENTITY, - DenseVectors::new(r1cs.x()), + DenseVectors::new(r1cs.x().iter().map(|x| C::Scalar::from(*x)).collect()), ); let (e2, u2, w2, x2) = (self.commit_e, self.u, self.commit_w, self.x.clone()); diff --git a/nova/src/relaxed_r1cs/witness.rs b/nova/src/relaxed_r1cs/witness.rs index 8b746eb0..b60f6d2d 100644 --- a/nova/src/relaxed_r1cs/witness.rs +++ b/nova/src/relaxed_r1cs/witness.rs @@ -21,7 +21,7 @@ impl RelaxedR1csWitness { pub(crate) fn fold(&self, r1cs: &R1cs, r: C::Scalar, t: DenseVectors) -> Self { let r2 = r.square(); let e2 = self.e.clone(); - let w1 = DenseVectors::new(r1cs.w()); + let w1 = DenseVectors::new(r1cs.w().iter().map(|x| C::Scalar::from(*x)).collect()); let w2 = self.w.clone(); let e = t * r + e2 * r2; diff --git a/zkstd/src/circuit.rs b/zkstd/src/circuit.rs index 46f75fca..4bbdd7da 100644 --- a/zkstd/src/circuit.rs +++ b/zkstd/src/circuit.rs @@ -1,7 +1,7 @@ mod gadget; pub mod prelude; -use crate::common::{BNAffine, Deserialize, FftField, PrimeField, Serialize}; +use crate::common::{BNAffine, Deserialize, PrimeField, Serialize}; pub trait CircuitDriver: Clone { const NUM_BITS: u16; @@ -9,7 +9,7 @@ pub trait CircuitDriver: Clone { type Affine: BNAffine; // curve base field - type Base: FftField + From + Serialize + for<'de> Deserialize<'de>; + type Base: PrimeField + From + Serialize + for<'de> Deserialize<'de>; // curve scalar field type Scalar: PrimeField + From + Serialize + for<'de> Deserialize<'de>;