From ce1546742dba79d9146f378fd62061e87f449d8c Mon Sep 17 00:00:00 2001 From: David Nevado Date: Wed, 3 Jul 2024 22:27:20 +0200 Subject: [PATCH 1/8] feat: add edge case handling for batch_add --- src/msm.rs | 119 +++++++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 112 insertions(+), 7 deletions(-) diff --git a/src/msm.rs b/src/msm.rs index 97bc426d..e2297d51 100644 --- a/src/msm.rs +++ b/src/msm.rs @@ -14,8 +14,7 @@ fn get_booth_index(window_index: usize, window_size: usize, el: &[u8]) -> i32 { // Booth encoding: // * step by `window` size // * slice by size of `window + 1`` - // * each window overlap by 1 bit - // * append a zero bit to the least significant end + // * each window overlap by 1 bit * append a zero bit to the least significant end // Indexing rule for example window size 3 where we slice by 4 bits: // `[0, +1, +1, +2, +2, +3, +3, +4, -4, -3, -3 -2, -2, -1, -1, 0]`` // So we can reduce the bucket size without preprocessing scalars @@ -54,7 +53,9 @@ fn get_booth_index(window_index: usize, window_size: usize, el: &[u8]) -> i32 { } } -fn batch_add( +// Batch addition without edge case handling: +// Will panic if a point is the identity or if two points share the x coordinate. +fn batch_add_nonexceptional( size: usize, buckets: &mut [BucketAffine], points: &[SchedulePoint], @@ -85,7 +86,9 @@ fn batch_add( acc *= *z; } - acc = acc.invert().unwrap(); + acc = acc + .invert() + .expect("Attempted to invert 0 at batch_add_nmonexceptional"); for ( ( @@ -112,6 +115,94 @@ fn batch_add( } } +/// Batch addition with edge case handling. +fn batch_add_exceptional( + size: usize, + buckets: &mut [BucketAffine], + points: &[SchedulePoint], + bases: &[Affine], +) { + let mut t = vec![C::Base::ZERO; size]; // Stores x2 - x1 + let mut z = vec![C::Base::ZERO; size]; // Stores y2 - y1 + let mut acc = C::Base::ONE; + + for ( + ( + SchedulePoint { + base_idx, + buck_idx, + sign, + }, + t, + ), + z, + ) in points.iter().zip(t.iter_mut()).zip(z.iter_mut()) + { + if buckets[*buck_idx].is_inf() { + // We assume bases[*base_idx] != infinity always. + continue; + } + + if buckets[*buck_idx].x() == bases[*base_idx].x { + // y-coordinate matches: + // 1. y1 == y2 and sign = false or + // 2. y1 != y2 and sign = true + // => ( y1 == y2) xor !sign + // (This uses the fact that x1 == x2 and both points satisfy the curve eq.) + if (buckets[*buck_idx].y() == bases[*base_idx].y) ^ !*sign { + // Doubling + let x_squared = bases[*base_idx].x.square(); + *z = buckets[*buck_idx].y() + buckets[*buck_idx].y(); // 2y + *t = acc * (x_squared + x_squared + x_squared); // acc * 3x^2 + acc *= *z; + continue; + } + // P + (-P) + buckets[*buck_idx].set_inf(); + continue; + } + // Addition + *z = buckets[*buck_idx].x() - bases[*base_idx].x; // x2 - x1 + if *sign { + *t = acc * (buckets[*buck_idx].y() - bases[*base_idx].y); + } else { + *t = acc * (buckets[*buck_idx].y() + bases[*base_idx].y); + } // y2 - y1 + acc *= *z; + } + + acc = acc + .invert() + .expect("Some edge case has not been handled properly"); + + for ( + ( + SchedulePoint { + base_idx, + buck_idx, + sign, + }, + t, + ), + z, + ) in points.iter().zip(t.iter()).zip(z.iter()).rev() + { + if buckets[*buck_idx].is_inf() { + // We assume bases[*base_idx] != infinity always. + continue; + } + let lambda = acc * t; + acc *= z; // update acc + let x = lambda.square() - (buckets[*buck_idx].x() + bases[*base_idx].x); // x_result + if *sign { + buckets[*buck_idx].set_y(&((lambda * (bases[*base_idx].x - x)) - bases[*base_idx].y)); + } else { + buckets[*buck_idx].set_y(&((lambda * (bases[*base_idx].x - x)) + bases[*base_idx].y)); + } // y_result = lambda * (x1 - x_result) - y1 + buckets[*buck_idx].set_x(&x); + } +} + #[derive(Debug, Clone, Copy)] struct Affine { x: C::Base, @@ -207,6 +298,13 @@ impl BucketAffine { } } + fn is_inf(&self) -> bool { + match self { + Self::None => true, + Self::Point(_) => false, + } + } + fn set_x(&mut self, x: &C::Base) { match self { Self::None => panic!("::set_x None"), @@ -220,6 +318,13 @@ impl BucketAffine { Self::Point(ref mut a) => a.y = *y, } } + + fn set_inf(&mut self) { + match self { + Self::None => {} + Self::Point(_) => *self = Self::None, + } + } } struct Schedule { @@ -266,7 +371,7 @@ impl Schedule { fn execute(&mut self, bases: &[Affine]) { if self.ptr != 0 { - batch_add(self.ptr, &mut self.buckets, &self.set, bases); + batch_add_nonexceptional(self.ptr, &mut self.buckets, &self.set, bases); self.ptr = 0; self.set .iter_mut() @@ -491,7 +596,6 @@ pub fn best_multiexp_independent_points( #[cfg(test)] mod test { - use std::ops::Neg; use crate::bn256::{Fr, G1Affine, G1}; @@ -547,6 +651,7 @@ mod test { } } + #[cfg(test)] fn run_msm_cross(min_k: usize, max_k: usize) { let points = (0..1 << max_k) .map(|_| C::Curve::random(OsRng)) @@ -563,7 +668,7 @@ mod test { let points = &points[..1 << k]; let scalars = &scalars[..1 << k]; - let t0 = start_timer!(|| format!("cyclone k={}", k)); + let t0 = start_timer!(|| format!("cyclone indep k={}", k)); let e0 = super::best_multiexp_independent_points(scalars, points); end_timer!(t0); From 791680a6f8b4183509d3346cd7c4473cf1737aa1 Mon Sep 17 00:00:00 2001 From: David Nevado Date: Thu, 11 Jul 2024 20:20:19 +0200 Subject: [PATCH 2/8] feat: handle edge cases in msm + rename functions --- src/msm.rs | 21 +++++++++------------ 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/src/msm.rs b/src/msm.rs index e2297d51..d36ce46a 100644 --- a/src/msm.rs +++ b/src/msm.rs @@ -371,7 +371,7 @@ impl Schedule { fn execute(&mut self, bases: &[Affine]) { if self.ptr != 0 { - batch_add_nonexceptional(self.ptr, &mut self.buckets, &self.set, bases); + batch_add_exceptional(self.ptr, &mut self.buckets, &self.set, bases); self.ptr = 0; self.set .iter_mut() @@ -391,7 +391,7 @@ impl Schedule { } } -pub fn multiexp_serial(coeffs: &[C::Scalar], bases: &[C], acc: &mut C::Curve) { +pub fn serial_multiexp(coeffs: &[C::Scalar], bases: &[C], acc: &mut C::Curve) { let coeffs: Vec<_> = coeffs.iter().map(|a| a.to_repr()).collect(); let c = if bases.len() < 4 { @@ -487,7 +487,7 @@ pub fn multiexp_serial(coeffs: &[C::Scalar], bases: &[C], acc: & /// This function will panic if coeffs and bases have a different length. /// /// This will use multithreading if beneficial. -pub fn best_multiexp(coeffs: &[C::Scalar], bases: &[C]) -> C::Curve { +pub fn parallel_multiexp(coeffs: &[C::Scalar], bases: &[C]) -> C::Curve { assert_eq!(coeffs.len(), bases.len()); let num_threads = rayon::current_num_threads(); @@ -504,14 +504,14 @@ pub fn best_multiexp(coeffs: &[C::Scalar], bases: &[C]) -> C::Cu .zip(results.iter_mut()) { scope.spawn(move |_| { - multiexp_serial(coeffs, bases, acc); + serial_multiexp(coeffs, bases, acc); }); } }); results.iter().fold(C::Curve::identity(), |a, b| a + b) } else { let mut acc = C::Curve::identity(); - multiexp_serial(coeffs, bases, &mut acc); + serial_multiexp(coeffs, bases, &mut acc); acc } } @@ -519,10 +519,7 @@ pub fn best_multiexp(coeffs: &[C::Scalar], bases: &[C]) -> C::Cu /// This function will panic if coeffs and bases have a different length. /// /// This will use multithreading if beneficial. -pub fn best_multiexp_independent_points( - coeffs: &[C::Scalar], - bases: &[C], -) -> C::Curve { +pub fn best_multiexp(coeffs: &[C::Scalar], bases: &[C]) -> C::Curve { assert_eq!(coeffs.len(), bases.len()); // TODO: consider adjusting it with emprical data? @@ -535,7 +532,7 @@ pub fn best_multiexp_independent_points( }; if c < 10 { - return best_multiexp(coeffs, bases); + return parallel_multiexp(coeffs, bases); } // coeffs to byte representation @@ -669,11 +666,11 @@ mod test { let scalars = &scalars[..1 << k]; let t0 = start_timer!(|| format!("cyclone indep k={}", k)); - let e0 = super::best_multiexp_independent_points(scalars, points); + let e0 = super::best_multiexp(scalars, points); end_timer!(t0); let t1 = start_timer!(|| format!("older k={}", k)); - let e1 = super::best_multiexp(scalars, points); + let e1 = super::parallel_multiexp(scalars, points); end_timer!(t1); assert_eq!(e0, e1); } From e2aa6cfadd2f2197d3ca9688b874af48fa2a5aae Mon Sep 17 00:00:00 2001 From: David Nevado Date: Thu, 11 Jul 2024 20:21:02 +0200 Subject: [PATCH 3/8] chore: generate test points in parallel --- src/msm.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/msm.rs b/src/msm.rs index d36ce46a..882653ce 100644 --- a/src/msm.rs +++ b/src/msm.rs @@ -650,7 +650,10 @@ mod test { #[cfg(test)] fn run_msm_cross(min_k: usize, max_k: usize) { + use rayon::iter::{IntoParallelIterator, ParallelIterator}; + let points = (0..1 << max_k) + .into_par_iter() .map(|_| C::Curve::random(OsRng)) .collect::>(); let mut affine_points = vec![C::identity(); 1 << max_k]; @@ -658,6 +661,7 @@ mod test { let points = affine_points; let scalars = (0..1 << max_k) + .into_par_iter() .map(|_| C::Scalar::random(OsRng)) .collect::>(); From 549f32712b297cc02d13874e67b1fbf01132d548 Mon Sep 17 00:00:00 2001 From: David Nevado Date: Fri, 19 Jul 2024 17:34:47 +0200 Subject: [PATCH 4/8] chore: remove redundant cfg --- src/msm.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/src/msm.rs b/src/msm.rs index 882653ce..3e41d74c 100644 --- a/src/msm.rs +++ b/src/msm.rs @@ -648,7 +648,6 @@ mod test { } } - #[cfg(test)] fn run_msm_cross(min_k: usize, max_k: usize) { use rayon::iter::{IntoParallelIterator, ParallelIterator}; From 6efd1303844a371a0171b0d4ebd4ffc8e1f8bd8e Mon Sep 17 00:00:00 2001 From: David Nevado Date: Fri, 19 Jul 2024 17:39:16 +0200 Subject: [PATCH 5/8] refactor: rename msm functions --- benches/msm.rs | 11 ++++++++--- src/msm.rs | 23 +++++++++++++---------- 2 files changed, 21 insertions(+), 13 deletions(-) diff --git a/benches/msm.rs b/benches/msm.rs index b16afbab..9042d955 100644 --- a/benches/msm.rs +++ b/benches/msm.rs @@ -16,8 +16,13 @@ use criterion::{BenchmarkId, Criterion}; use ff::{Field, PrimeField}; use group::prime::PrimeCurveAffine; use halo2curves::bn256::{Fr as Scalar, G1Affine as Point}; -use halo2curves::msm::{best_multiexp, multiexp_serial}; +// <<<<<<< HEAD +// use halo2curves::msm::{best_multiexp, msm_serial}; use rand_core::{RngCore, SeedableRng}; +// ======= +use halo2curves::msm::{msm_best, msm_serial}; +// use rand_core::SeedableRng; +// >>>>>>> defcdc2 (refactor: rename msm functions) use rand_xorshift::XorShiftRng; use rayon::current_thread_index; use rayon::prelude::{IntoParallelIterator, ParallelIterator}; @@ -136,7 +141,7 @@ fn msm(c: &mut Criterion) { assert!(k < 64); let n: usize = 1 << k; let mut acc = Point::identity().into(); - b.iter(|| multiexp_serial(&coeffs[b_index][..n], &bases[..n], &mut acc)); + b.iter(|| msm_serial(&coeffs[b_index][..n], &bases[..n], &mut acc)); }) .sample_size(10); } @@ -147,7 +152,7 @@ fn msm(c: &mut Criterion) { assert!(k < 64); let n: usize = 1 << k; b.iter(|| { - best_multiexp(&coeffs[b_index][..n], &bases[..n]); + msm_best(&coeffs[b_index][..n], &bases[..n]); }) }) .sample_size(SAMPLE_SIZE); diff --git a/src/msm.rs b/src/msm.rs index 3e41d74c..090b4c58 100644 --- a/src/msm.rs +++ b/src/msm.rs @@ -391,7 +391,10 @@ impl Schedule { } } -pub fn serial_multiexp(coeffs: &[C::Scalar], bases: &[C], acc: &mut C::Curve) { +/// Performs a multi-scalar multiplication operation. +/// +/// This function will panic if coeffs and bases have a different length. +pub fn msm_serial(coeffs: &[C::Scalar], bases: &[C], acc: &mut C::Curve) { let coeffs: Vec<_> = coeffs.iter().map(|a| a.to_repr()).collect(); let c = if bases.len() < 4 { @@ -482,12 +485,12 @@ pub fn serial_multiexp(coeffs: &[C::Scalar], bases: &[C], acc: & } } -/// Performs a multi-exponentiation operation. +/// Performs a multi-scalar multiplication operation. /// /// This function will panic if coeffs and bases have a different length. /// /// This will use multithreading if beneficial. -pub fn parallel_multiexp(coeffs: &[C::Scalar], bases: &[C]) -> C::Curve { +pub fn msm_parallel(coeffs: &[C::Scalar], bases: &[C]) -> C::Curve { assert_eq!(coeffs.len(), bases.len()); let num_threads = rayon::current_num_threads(); @@ -504,22 +507,22 @@ pub fn parallel_multiexp(coeffs: &[C::Scalar], bases: &[C]) -> C .zip(results.iter_mut()) { scope.spawn(move |_| { - serial_multiexp(coeffs, bases, acc); + msm_serial(coeffs, bases, acc); }); } }); results.iter().fold(C::Curve::identity(), |a, b| a + b) } else { let mut acc = C::Curve::identity(); - serial_multiexp(coeffs, bases, &mut acc); + msm_serial(coeffs, bases, &mut acc); acc } } -/// + /// This function will panic if coeffs and bases have a different length. /// /// This will use multithreading if beneficial. -pub fn best_multiexp(coeffs: &[C::Scalar], bases: &[C]) -> C::Curve { +pub fn msm_best(coeffs: &[C::Scalar], bases: &[C]) -> C::Curve { assert_eq!(coeffs.len(), bases.len()); // TODO: consider adjusting it with emprical data? @@ -532,7 +535,7 @@ pub fn best_multiexp(coeffs: &[C::Scalar], bases: &[C]) -> C::Cu }; if c < 10 { - return parallel_multiexp(coeffs, bases); + return msm_parallel(coeffs, bases); } // coeffs to byte representation @@ -669,11 +672,11 @@ mod test { let scalars = &scalars[..1 << k]; let t0 = start_timer!(|| format!("cyclone indep k={}", k)); - let e0 = super::best_multiexp(scalars, points); + let e0 = super::msm_best(scalars, points); end_timer!(t0); let t1 = start_timer!(|| format!("older k={}", k)); - let e1 = super::parallel_multiexp(scalars, points); + let e1 = super::msm_parallel(scalars, points); end_timer!(t1); assert_eq!(e0, e1); } From af39bcd67dc164575ed27f4930663dc6f641e2f5 Mon Sep 17 00:00:00 2001 From: David Nevado Date: Fri, 19 Jul 2024 17:41:24 +0200 Subject: [PATCH 6/8] chore: remove batch_add w/o edge case handling --- src/msm.rs | 68 +++--------------------------------------------------- 1 file changed, 3 insertions(+), 65 deletions(-) diff --git a/src/msm.rs b/src/msm.rs index 090b4c58..2b0c59ea 100644 --- a/src/msm.rs +++ b/src/msm.rs @@ -53,70 +53,8 @@ fn get_booth_index(window_index: usize, window_size: usize, el: &[u8]) -> i32 { } } -// Batch addition without edge case handling: -// Will panic if a point is the identity or if two points share the x coordinate. -fn batch_add_nonexceptional( - size: usize, - buckets: &mut [BucketAffine], - points: &[SchedulePoint], - bases: &[Affine], -) { - let mut t = vec![C::Base::ZERO; size]; - let mut z = vec![C::Base::ZERO; size]; - let mut acc = C::Base::ONE; - - for ( - ( - SchedulePoint { - base_idx, - buck_idx, - sign, - }, - t, - ), - z, - ) in points.iter().zip(t.iter_mut()).zip(z.iter_mut()) - { - *z = buckets[*buck_idx].x() - bases[*base_idx].x; - if *sign { - *t = acc * (buckets[*buck_idx].y() - bases[*base_idx].y); - } else { - *t = acc * (buckets[*buck_idx].y() + bases[*base_idx].y); - } - acc *= *z; - } - - acc = acc - .invert() - .expect("Attempted to invert 0 at batch_add_nmonexceptional"); - - for ( - ( - SchedulePoint { - base_idx, - buck_idx, - sign, - }, - t, - ), - z, - ) in points.iter().zip(t.iter()).zip(z.iter()).rev() - { - let lambda = acc * t; - acc *= z; - - let x = lambda.square() - (buckets[*buck_idx].x() + bases[*base_idx].x); - if *sign { - buckets[*buck_idx].set_y(&((lambda * (bases[*base_idx].x - x)) - bases[*base_idx].y)); - } else { - buckets[*buck_idx].set_y(&((lambda * (bases[*base_idx].x - x)) + bases[*base_idx].y)); - } - buckets[*buck_idx].set_x(&x); - } -} - -/// Batch addition with edge case handling. -fn batch_add_exceptional( +/// Batch addition. +fn batch_add( size: usize, buckets: &mut [BucketAffine], points: &[SchedulePoint], @@ -371,7 +309,7 @@ impl Schedule { fn execute(&mut self, bases: &[Affine]) { if self.ptr != 0 { - batch_add_exceptional(self.ptr, &mut self.buckets, &self.set, bases); + batch_add(self.ptr, &mut self.buckets, &self.set, bases); self.ptr = 0; self.set .iter_mut() From 6ede410cd67696a27691fcd020d89e29a45a5d39 Mon Sep 17 00:00:00 2001 From: David Nevado Date: Tue, 23 Jul 2024 19:03:10 +0200 Subject: [PATCH 7/8] fix: clippy --- src/msm.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/msm.rs b/src/msm.rs index 2b0c59ea..615bcb5f 100644 --- a/src/msm.rs +++ b/src/msm.rs @@ -349,7 +349,7 @@ pub fn msm_serial(coeffs: &[C::Scalar], bases: &[C], acc: &mut C let mut acc_or = vec![0; field_byte_size]; for coeff in &coeffs { for (acc_limb, limb) in acc_or.iter_mut().zip(coeff.as_ref().iter()) { - *acc_limb = *acc_limb | *limb; + *acc_limb |= *limb; } } let max_byte_size = field_byte_size @@ -361,7 +361,7 @@ pub fn msm_serial(coeffs: &[C::Scalar], bases: &[C], acc: &mut C if max_byte_size == 0 { return; } - let number_of_windows = max_byte_size * 8 as usize / c + 1; + let number_of_windows = max_byte_size * 8_usize / c + 1; for current_window in (0..number_of_windows).rev() { for _ in 0..c { From 4e457f54ff14b8b1fe9d7ac6aadb0feec5594b42 Mon Sep 17 00:00:00 2001 From: David Nevado Date: Thu, 25 Jul 2024 11:36:06 +0200 Subject: [PATCH 8/8] fix: leftover comment --- benches/msm.rs | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/benches/msm.rs b/benches/msm.rs index 9042d955..cf806b76 100644 --- a/benches/msm.rs +++ b/benches/msm.rs @@ -16,13 +16,8 @@ use criterion::{BenchmarkId, Criterion}; use ff::{Field, PrimeField}; use group::prime::PrimeCurveAffine; use halo2curves::bn256::{Fr as Scalar, G1Affine as Point}; -// <<<<<<< HEAD -// use halo2curves::msm::{best_multiexp, msm_serial}; -use rand_core::{RngCore, SeedableRng}; -// ======= use halo2curves::msm::{msm_best, msm_serial}; -// use rand_core::SeedableRng; -// >>>>>>> defcdc2 (refactor: rename msm functions) +use rand_core::{RngCore, SeedableRng}; use rand_xorshift::XorShiftRng; use rayon::current_thread_index; use rayon::prelude::{IntoParallelIterator, ParallelIterator};