Skip to content

Commit

Permalink
refactor: Refactor and test Fp6 multiplication for Zkvm OS
Browse files Browse the repository at this point in the history
- Refactored the `mul_interleaved` function in the `Fp6` struct, moving code for `target_os != "zkvm"` to a new `mul_interleaved_default` function,
- Added a test module for the `mul_interleaved` function, including a new function `mul_interleaved_zkvm_test` and a test case `fuzz_mul_interleaved` for function validation.
  • Loading branch information
huitseeker committed Aug 21, 2024
1 parent 652ebc1 commit fda530b
Showing 1 changed file with 152 additions and 75 deletions.
227 changes: 152 additions & 75 deletions src/fp6.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ impl Fp6 {
}
}

#[cfg(feature = "pairings")]
#[cfg(any(test, feature = "pairings"))]
pub fn random(mut rng: impl RngCore) -> Self {
Fp6 {
c0: Fp2::random(&mut rng),
Expand Down Expand Up @@ -264,12 +264,90 @@ impl Fp6 {
self.c0.is_zero() & self.c1.is_zero() & self.c2.is_zero()
}

#[inline]
fn mul_interleaved_default(&self, b: &Self) -> Self {
let a = self;
// Implements the full-tower interleaving strategy from
// [ePrint 2022-376](https://eprint.iacr.org/2022/367).
// The intuition for this algorithm is that we can look at F_p^6 as a direct
// extension of F_p^2, and express the overall operations down to the base field
// F_p instead of only over F_p^2. This enables us to interleave multiplications
// and reductions, ensuring that we don't require double-width intermediate
// representations (with around twice as many limbs as F_p elements).

// We want to express the multiplication c = a x b, where a = (a_0, a_1, a_2) is
// an element of F_p^6, and a_i = (a_i,0, a_i,1) is an element of F_p^2. The fully
// expanded multiplication is given by (2022-376 §5):
//
// c_0,0 = a_0,0 b_0,0 - a_0,1 b_0,1 + a_1,0 b_2,0 - a_1,1 b_2,1 + a_2,0 b_1,0 - a_2,1 b_1,1
// - a_1,0 b_2,1 - a_1,1 b_2,0 - a_2,0 b_1,1 - a_2,1 b_1,0.
// = a_0,0 b_0,0 - a_0,1 b_0,1 + a_1,0 (b_2,0 - b_2,1) - a_1,1 (b_2,0 + b_2,1)
// + a_2,0 (b_1,0 - b_1,1) - a_2,1 (b_1,0 + b_1,1).
//
// c_0,1 = a_0,0 b_0,1 + a_0,1 b_0,0 + a_1,0 b_2,1 + a_1,1 b_2,0 + a_2,0 b_1,1 + a_2,1 b_1,0
// + a_1,0 b_2,0 - a_1,1 b_2,1 + a_2,0 b_1,0 - a_2,1 b_1,1.
// = a_0,0 b_0,1 + a_0,1 b_0,0 + a_1,0(b_2,0 + b_2,1) + a_1,1(b_2,0 - b_2,1)
// + a_2,0(b_1,0 + b_1,1) + a_2,1(b_1,0 - b_1,1).
//
// c_1,0 = a_0,0 b_1,0 - a_0,1 b_1,1 + a_1,0 b_0,0 - a_1,1 b_0,1 + a_2,0 b_2,0 - a_2,1 b_2,1
// - a_2,0 b_2,1 - a_2,1 b_2,0.
// = a_0,0 b_1,0 - a_0,1 b_1,1 + a_1,0 b_0,0 - a_1,1 b_0,1 + a_2,0(b_2,0 - b_2,1)
// - a_2,1(b_2,0 + b_2,1).
//
// c_1,1 = a_0,0 b_1,1 + a_0,1 b_1,0 + a_1,0 b_0,1 + a_1,1 b_0,0 + a_2,0 b_2,1 + a_2,1 b_2,0
// + a_2,0 b_2,0 - a_2,1 b_2,1
// = a_0,0 b_1,1 + a_0,1 b_1,0 + a_1,0 b_0,1 + a_1,1 b_0,0 + a_2,0(b_2,0 + b_2,1)
// + a_2,1(b_2,0 - b_2,1).
//
// c_2,0 = a_0,0 b_2,0 - a_0,1 b_2,1 + a_1,0 b_1,0 - a_1,1 b_1,1 + a_2,0 b_0,0 - a_2,1 b_0,1.
// c_2,1 = a_0,0 b_2,1 + a_0,1 b_2,0 + a_1,0 b_1,1 + a_1,1 b_1,0 + a_2,0 b_0,1 + a_2,1 b_0,0.
//
// Each of these is a "sum of products", which we can compute efficiently.
let b10_p_b11 = b.c1.c0 + b.c1.c1;
let b10_m_b11 = b.c1.c0 - b.c1.c1;
let b20_p_b21 = b.c2.c0 + b.c2.c1;
let b20_m_b21 = b.c2.c0 - b.c2.c1;

Fp6 {
c0: Fp2 {
c0: Fp::sum_of_products(
[a.c0.c0, -a.c0.c1, a.c1.c0, -a.c1.c1, a.c2.c0, -a.c2.c1],
[b.c0.c0, b.c0.c1, b20_m_b21, b20_p_b21, b10_m_b11, b10_p_b11],
),
c1: Fp::sum_of_products(
[a.c0.c0, a.c0.c1, a.c1.c0, a.c1.c1, a.c2.c0, a.c2.c1],
[b.c0.c1, b.c0.c0, b20_p_b21, b20_m_b21, b10_p_b11, b10_m_b11],
),
},
c1: Fp2 {
c0: Fp::sum_of_products(
[a.c0.c0, -a.c0.c1, a.c1.c0, -a.c1.c1, a.c2.c0, -a.c2.c1],
[b.c1.c0, b.c1.c1, b.c0.c0, b.c0.c1, b20_m_b21, b20_p_b21],
),
c1: Fp::sum_of_products(
[a.c0.c0, a.c0.c1, a.c1.c0, a.c1.c1, a.c2.c0, a.c2.c1],
[b.c1.c1, b.c1.c0, b.c0.c1, b.c0.c0, b20_p_b21, b20_m_b21],
),
},
c2: Fp2 {
c0: Fp::sum_of_products(
[a.c0.c0, -a.c0.c1, a.c1.c0, -a.c1.c1, a.c2.c0, -a.c2.c1],
[b.c2.c0, b.c2.c1, b.c1.c0, b.c1.c1, b.c0.c0, b.c0.c1],
),
c1: Fp::sum_of_products(
[a.c0.c0, a.c0.c1, a.c1.c0, a.c1.c1, a.c2.c0, a.c2.c1],
[b.c2.c1, b.c2.c0, b.c1.c1, b.c1.c0, b.c0.c1, b.c0.c0],
),
},
}
}

/// Returns `c = self * b`.
#[inline]
fn mul_interleaved(&self, b: &Self) -> Self {
let a = self;
cfg_if::cfg_if! {
if #[cfg(target_os = "zkvm")] {
let a = self;
// Implements Algorithm 13 from https://eprint.iacr.org/2010/354.pdf
let mut t0 = self.c0;
t0.mul_inp(&b.c0);
Expand Down Expand Up @@ -308,79 +386,7 @@ impl Fp6 {
c2.add_inp(&t1);
Fp6 { c0, c1, c2 }
} else {
// Implements the full-tower interleaving strategy from
// [ePrint 2022-376](https://eprint.iacr.org/2022/367).
// The intuition for this algorithm is that we can look at F_p^6 as a direct
// extension of F_p^2, and express the overall operations down to the base field
// F_p instead of only over F_p^2. This enables us to interleave multiplications
// and reductions, ensuring that we don't require double-width intermediate
// representations (with around twice as many limbs as F_p elements).

// We want to express the multiplication c = a x b, where a = (a_0, a_1, a_2) is
// an element of F_p^6, and a_i = (a_i,0, a_i,1) is an element of F_p^2. The fully
// expanded multiplication is given by (2022-376 §5):
//
// c_0,0 = a_0,0 b_0,0 - a_0,1 b_0,1 + a_1,0 b_2,0 - a_1,1 b_2,1 + a_2,0 b_1,0 - a_2,1 b_1,1
// - a_1,0 b_2,1 - a_1,1 b_2,0 - a_2,0 b_1,1 - a_2,1 b_1,0.
// = a_0,0 b_0,0 - a_0,1 b_0,1 + a_1,0 (b_2,0 - b_2,1) - a_1,1 (b_2,0 + b_2,1)
// + a_2,0 (b_1,0 - b_1,1) - a_2,1 (b_1,0 + b_1,1).
//
// c_0,1 = a_0,0 b_0,1 + a_0,1 b_0,0 + a_1,0 b_2,1 + a_1,1 b_2,0 + a_2,0 b_1,1 + a_2,1 b_1,0
// + a_1,0 b_2,0 - a_1,1 b_2,1 + a_2,0 b_1,0 - a_2,1 b_1,1.
// = a_0,0 b_0,1 + a_0,1 b_0,0 + a_1,0(b_2,0 + b_2,1) + a_1,1(b_2,0 - b_2,1)
// + a_2,0(b_1,0 + b_1,1) + a_2,1(b_1,0 - b_1,1).
//
// c_1,0 = a_0,0 b_1,0 - a_0,1 b_1,1 + a_1,0 b_0,0 - a_1,1 b_0,1 + a_2,0 b_2,0 - a_2,1 b_2,1
// - a_2,0 b_2,1 - a_2,1 b_2,0.
// = a_0,0 b_1,0 - a_0,1 b_1,1 + a_1,0 b_0,0 - a_1,1 b_0,1 + a_2,0(b_2,0 - b_2,1)
// - a_2,1(b_2,0 + b_2,1).
//
// c_1,1 = a_0,0 b_1,1 + a_0,1 b_1,0 + a_1,0 b_0,1 + a_1,1 b_0,0 + a_2,0 b_2,1 + a_2,1 b_2,0
// + a_2,0 b_2,0 - a_2,1 b_2,1
// = a_0,0 b_1,1 + a_0,1 b_1,0 + a_1,0 b_0,1 + a_1,1 b_0,0 + a_2,0(b_2,0 + b_2,1)
// + a_2,1(b_2,0 - b_2,1).
//
// c_2,0 = a_0,0 b_2,0 - a_0,1 b_2,1 + a_1,0 b_1,0 - a_1,1 b_1,1 + a_2,0 b_0,0 - a_2,1 b_0,1.
// c_2,1 = a_0,0 b_2,1 + a_0,1 b_2,0 + a_1,0 b_1,1 + a_1,1 b_1,0 + a_2,0 b_0,1 + a_2,1 b_0,0.
//
// Each of these is a "sum of products", which we can compute efficiently.
let b10_p_b11 = b.c1.c0 + b.c1.c1;
let b10_m_b11 = b.c1.c0 - b.c1.c1;
let b20_p_b21 = b.c2.c0 + b.c2.c1;
let b20_m_b21 = b.c2.c0 - b.c2.c1;

Fp6 {
c0: Fp2 {
c0: Fp::sum_of_products(
[a.c0.c0, -a.c0.c1, a.c1.c0, -a.c1.c1, a.c2.c0, -a.c2.c1],
[b.c0.c0, b.c0.c1, b20_m_b21, b20_p_b21, b10_m_b11, b10_p_b11],
),
c1: Fp::sum_of_products(
[a.c0.c0, a.c0.c1, a.c1.c0, a.c1.c1, a.c2.c0, a.c2.c1],
[b.c0.c1, b.c0.c0, b20_p_b21, b20_m_b21, b10_p_b11, b10_m_b11],
),
},
c1: Fp2 {
c0: Fp::sum_of_products(
[a.c0.c0, -a.c0.c1, a.c1.c0, -a.c1.c1, a.c2.c0, -a.c2.c1],
[b.c1.c0, b.c1.c1, b.c0.c0, b.c0.c1, b20_m_b21, b20_p_b21],
),
c1: Fp::sum_of_products(
[a.c0.c0, a.c0.c1, a.c1.c0, a.c1.c1, a.c2.c0, a.c2.c1],
[b.c1.c1, b.c1.c0, b.c0.c1, b.c0.c0, b20_p_b21, b20_m_b21],
),
},
c2: Fp2 {
c0: Fp::sum_of_products(
[a.c0.c0, -a.c0.c1, a.c1.c0, -a.c1.c1, a.c2.c0, -a.c2.c1],
[b.c2.c0, b.c2.c1, b.c1.c0, b.c1.c1, b.c0.c0, b.c0.c1],
),
c1: Fp::sum_of_products(
[a.c0.c0, a.c0.c1, a.c1.c0, a.c1.c1, a.c2.c0, a.c2.c1],
[b.c2.c1, b.c2.c0, b.c1.c1, b.c1.c0, b.c0.c1, b.c0.c0],
),
},
}
self.mul_interleaved_default(b)
}
}
}
Expand Down Expand Up @@ -681,3 +687,74 @@ fn test_zeroize() {
a.zeroize();
assert!(bool::from(a.is_zero()));
}

#[cfg(test)]
mod tests {
use rand_core::SeedableRng;

use super::*;
const TEST_ITER: usize = 1000;
const SEED: [u8; 16] = [
0x59, 0x62, 0xbe, 0x5d, 0x76, 0x3d, 0x31, 0x8d, 0x17, 0xdb, 0x37, 0x32, 0x54, 0x06, 0xbc,
0xe5,
];

// this is meant to mirror the zkvm branch of mul_interleaved
fn mul_interleaved_zkvm_test(a: Fp6, b: &Fp6) -> Fp6 {
// Implements Algorithm 13 from https://eprint.iacr.org/2010/354.pdf
let mut t0 = a.c0;
t0 = t0.mul(&b.c0);
let mut t1 = a.c1;
t1 = t1.mul(&b.c1);
let mut t2 = a.c2;
t2 = t2.mul(&b.c2);
let mut c0 = a.c1;
c0 = c0.add(&a.c2);
let mut tmp = b.c1;
tmp = tmp.add(&b.c2);
c0 = c0.mul(&tmp);
tmp = t2;
tmp = tmp.add(&t1);
c0 = c0.sub(&tmp);
c0 = c0.mul_by_nonresidue();
c0 = c0.add(&t0);
let mut c1 = a.c0;
c1 = c1.add(&a.c1);
tmp = b.c0;
tmp = tmp.add(&b.c1);
c1 = c1.mul(&tmp);
tmp = t0;
tmp = tmp.add(&t1);
c1 = c1.sub(&tmp);
tmp = t2.mul_by_nonresidue();
c1 = c1.add(&tmp);
tmp = a.c0;
tmp = tmp.add(&a.c2);
let mut c2 = b.c0;
c2 = c2.add(&b.c2);
c2 = c2.mul(&tmp);
tmp = t0;
tmp = tmp.add(&t2);
c2 = c2.sub(&tmp);
c2 = c2.add(&t1);
Fp6 { c0, c1, c2 }
}

#[test]
fn fuzz_mul_interleaved() {
for _i in 0..TEST_ITER {
let mut rng = rand_xorshift::XorShiftRng::from_seed(SEED);
let a = Fp6::random(&mut rng);
let b = Fp6::random(&mut rng);

let result_zkvm = mul_interleaved_zkvm_test(a, &b);
let result_default = a.mul_interleaved_default(&b);

assert_eq!(
result_zkvm, result_default,
"Mismatch in mul_interleaved results for a={:?}, b={:?}",
a, b
);
}
}
}

0 comments on commit fda530b

Please sign in to comment.