From bf74e94cd7cc3402281283685c1404a5add531b2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Garillot?= <4142+huitseeker@users.noreply.github.com> Date: Mon, 30 Oct 2023 18:53:39 -0400 Subject: [PATCH] Eliminate unnecessary wraps & use alloc_infallible (#246) * feat: Refactor unnecessary wraps in error handling Revised operations unnecessarily wrapping their return in a Result. * refactor: Refactor to use infallible allocation across application - Switched the allocation method to the infallible version, `AllocatedNum::alloc_infallible`, in multiple units (`poseidon.rs`, `bellpepper/mod.rs`, `utils.rs`, `circuit.rs`, `lib.rs`, `nifs.rs`). This removed the need for multiple error checks and reduced usage of `Result` return types. --- src/bellpepper/mod.rs | 15 ++++++--------- src/circuit.rs | 2 +- src/gadgets/ecc.rs | 19 +++++++++---------- src/gadgets/nonnative/bignat.rs | 2 +- src/gadgets/nonnative/util.rs | 7 +------ src/gadgets/r1cs.rs | 2 +- src/gadgets/utils.rs | 16 ++++++---------- src/lib.rs | 2 +- src/nifs.rs | 6 +++--- src/provider/poseidon.rs | 3 +-- src/r1cs/mod.rs | 6 +++--- 11 files changed, 33 insertions(+), 47 deletions(-) diff --git a/src/bellpepper/mod.rs b/src/bellpepper/mod.rs index f889d12c4..6677f3b0a 100644 --- a/src/bellpepper/mod.rs +++ b/src/bellpepper/mod.rs @@ -17,14 +17,12 @@ mod tests { }, traits::Group, }; - use bellpepper_core::{num::AllocatedNum, ConstraintSystem, SynthesisError}; + use bellpepper_core::{num::AllocatedNum, ConstraintSystem}; use ff::PrimeField; - fn synthesize_alloc_bit>( - cs: &mut CS, - ) -> Result<(), SynthesisError> { + fn synthesize_alloc_bit>(cs: &mut CS) { // get two bits as input and check that they are indeed bits - let a = AllocatedNum::alloc(cs.namespace(|| "a"), || Ok(Fr::ONE))?; + let a = AllocatedNum::alloc_infallible(cs.namespace(|| "a"), || Fr::ONE); let _ = a.inputize(cs.namespace(|| "a is input")); cs.enforce( || "check a is 0 or 1", @@ -32,7 +30,7 @@ mod tests { |lc| lc + a.get_variable(), |lc| lc, ); - let b = AllocatedNum::alloc(cs.namespace(|| "b"), || Ok(Fr::ONE))?; + let b = AllocatedNum::alloc_infallible(cs.namespace(|| "b"), || Fr::ONE); let _ = b.inputize(cs.namespace(|| "b is input")); cs.enforce( || "check b is 0 or 1", @@ -40,7 +38,6 @@ mod tests { |lc| lc + b.get_variable(), |lc| lc, ); - Ok(()) } fn test_alloc_bit_with() @@ -49,12 +46,12 @@ mod tests { { // First create the shape let mut cs: ShapeCS = ShapeCS::new(); - let _ = synthesize_alloc_bit(&mut cs); + synthesize_alloc_bit(&mut cs); let (shape, ck) = cs.r1cs_shape(); // Now get the assignment let mut cs: SatisfyingAssignment = SatisfyingAssignment::new(); - let _ = synthesize_alloc_bit(&mut cs); + synthesize_alloc_bit(&mut cs); let (inst, witness) = cs.r1cs_instance_and_witness(&shape, &ck).unwrap(); // Make sure that this is satisfiable diff --git a/src/circuit.rs b/src/circuit.rs index 013210946..252955fc9 100644 --- a/src/circuit.rs +++ b/src/circuit.rs @@ -266,7 +266,7 @@ impl<'a, G: Group, SC: StepCircuit> NovaAugmentedCircuit<'a, G, SC> { self.alloc_witness(cs.namespace(|| "allocate the circuit witness"), arity)?; // Compute variable indicating if this is the base case - let zero = alloc_zero(cs.namespace(|| "zero"))?; + let zero = alloc_zero(cs.namespace(|| "zero")); let is_base_case = alloc_num_equals(cs.namespace(|| "Check if base case"), &i.clone(), &zero)?; // Synthesize the circuit for the base case and get the new running instance diff --git a/src/gadgets/ecc.rs b/src/gadgets/ecc.rs index 9e0535f81..ac117e99b 100644 --- a/src/gadgets/ecc.rs +++ b/src/gadgets/ecc.rs @@ -68,8 +68,8 @@ where where CS: ConstraintSystem, { - let zero = alloc_zero(cs.namespace(|| "zero"))?; - let one = alloc_one(cs.namespace(|| "one"))?; + let zero = alloc_zero(cs.namespace(|| "zero")); + let one = alloc_one(cs.namespace(|| "one")); Ok(AllocatedPoint { x: zero.clone(), @@ -884,13 +884,12 @@ mod tests { pub fn inputize_allocted_point>( p: &AllocatedPoint, mut cs: CS, - ) -> Result<(), SynthesisError> { + ) { let _ = p.x.inputize(cs.namespace(|| "Input point.x")); let _ = p.y.inputize(cs.namespace(|| "Input point.y")); let _ = p .is_infinity .inputize(cs.namespace(|| "Input point.is_infinity")); - Ok(()) } #[test] @@ -964,7 +963,7 @@ mod tests { CS: ConstraintSystem, { let a = alloc_random_point(cs.namespace(|| "a")).unwrap(); - inputize_allocted_point(&a, cs.namespace(|| "inputize a")).unwrap(); + inputize_allocted_point(&a, cs.namespace(|| "inputize a")); let s = G::Scalar::random(&mut OsRng); // Allocate bits for s @@ -976,7 +975,7 @@ mod tests { .collect::, SynthesisError>>() .unwrap(); let e = a.scalar_mul(cs.namespace(|| "Scalar Mul"), &bits).unwrap(); - inputize_allocted_point(&e, cs.namespace(|| "inputize e")).unwrap(); + inputize_allocted_point(&e, cs.namespace(|| "inputize e")); (a, e, s) } @@ -1030,9 +1029,9 @@ mod tests { CS: ConstraintSystem, { let a = alloc_random_point(cs.namespace(|| "a")).unwrap(); - inputize_allocted_point(&a, cs.namespace(|| "inputize a")).unwrap(); + inputize_allocted_point(&a, cs.namespace(|| "inputize a")); let e = a.add(cs.namespace(|| "add a to a"), &a).unwrap(); - inputize_allocted_point(&e, cs.namespace(|| "inputize e")).unwrap(); + inputize_allocted_point(&e, cs.namespace(|| "inputize e")); (a, e) } @@ -1085,13 +1084,13 @@ mod tests { CS: ConstraintSystem, { let a = alloc_random_point(cs.namespace(|| "a")).unwrap(); - inputize_allocted_point(&a, cs.namespace(|| "inputize a")).unwrap(); + inputize_allocted_point(&a, cs.namespace(|| "inputize a")); let b = &mut a.clone(); b.y = AllocatedNum::alloc(cs.namespace(|| "allocate negation of a"), || { Ok(G::Base::ZERO) }) .unwrap(); - inputize_allocted_point(b, cs.namespace(|| "inputize b")).unwrap(); + inputize_allocted_point(b, cs.namespace(|| "inputize b")); let e = a.add(cs.namespace(|| "add a to b"), b).unwrap(); e } diff --git a/src/gadgets/nonnative/bignat.rs b/src/gadgets/nonnative/bignat.rs index 348af1b37..66bfb7411 100644 --- a/src/gadgets/nonnative/bignat.rs +++ b/src/gadgets/nonnative/bignat.rs @@ -244,7 +244,7 @@ impl BigNat { // (1) decompose `bignat` into a bitvector `bv` let bv = bignat.decompose(cs.namespace(|| "bv"))?; // (2) recompose bits and check if it equals n - n.is_equal(cs.namespace(|| "n"), &bv)?; + n.is_equal(cs.namespace(|| "n"), &bv); Ok(bignat) } diff --git a/src/gadgets/nonnative/util.rs b/src/gadgets/nonnative/util.rs index 5307882a0..4e7860f1b 100644 --- a/src/gadgets/nonnative/util.rs +++ b/src/gadgets/nonnative/util.rs @@ -157,11 +157,7 @@ impl Num { /// Computes the natural number represented by an array of bits. /// Checks if the natural number equals `self` - pub fn is_equal>( - &self, - mut cs: CS, - other: &Bitvector, - ) -> Result<(), SynthesisError> { + pub fn is_equal>(&self, mut cs: CS, other: &Bitvector) { let allocations = other.allocations.clone(); let mut f = Scalar::ONE; let sum = allocations @@ -173,7 +169,6 @@ impl Num { }); let sum_lc = LinearCombination::zero() + &self.num - ∑ cs.enforce(|| "sum", |lc| lc + &sum_lc, |lc| lc + CS::one(), |lc| lc); - Ok(()) } /// Compute the natural number represented by an array of limbs. diff --git a/src/gadgets/r1cs.rs b/src/gadgets/r1cs.rs index 6f9d444d4..29278339a 100644 --- a/src/gadgets/r1cs.rs +++ b/src/gadgets/r1cs.rs @@ -143,7 +143,7 @@ impl AllocatedRelaxedR1CSInstance { ) -> Result { let E = AllocatedPoint::default(cs.namespace(|| "allocate W"))?; - let u = alloc_one(cs.namespace(|| "one"))?; + let u = alloc_one(cs.namespace(|| "one")); let X0 = BigNat::from_num( cs.namespace(|| "allocate X0 from relaxed r1cs"), diff --git a/src/gadgets/utils.rs b/src/gadgets/utils.rs index d641c27ad..7478a1378 100644 --- a/src/gadgets/utils.rs +++ b/src/gadgets/utils.rs @@ -44,24 +44,20 @@ where } /// Allocate a variable that is set to zero -pub fn alloc_zero>( - mut cs: CS, -) -> Result, SynthesisError> { - let zero = AllocatedNum::alloc(cs.namespace(|| "alloc"), || Ok(F::ZERO))?; +pub fn alloc_zero>(mut cs: CS) -> AllocatedNum { + let zero = AllocatedNum::alloc_infallible(cs.namespace(|| "alloc"), || F::ZERO); cs.enforce( || "check zero is valid", |lc| lc, |lc| lc, |lc| lc + zero.get_variable(), ); - Ok(zero) + zero } /// Allocate a variable that is set to one -pub fn alloc_one>( - mut cs: CS, -) -> Result, SynthesisError> { - let one = AllocatedNum::alloc(cs.namespace(|| "alloc"), || Ok(F::ONE))?; +pub fn alloc_one>(mut cs: CS) -> AllocatedNum { + let one = AllocatedNum::alloc_infallible(cs.namespace(|| "alloc"), || F::ONE); cs.enforce( || "check one is valid", |lc| lc + CS::one(), @@ -69,7 +65,7 @@ pub fn alloc_one>( |lc| lc + one.get_variable(), ); - Ok(one) + one } /// Allocate a scalar as a base. Only to be used is the scalar fits in base! diff --git a/src/lib.rs b/src/lib.rs index af4cc8dd5..ed27c74d4 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1324,7 +1324,7 @@ mod tests { let x = &z[0]; // we allocate a variable and set it to the provided non-deterministic advice. - let y = AllocatedNum::alloc(cs.namespace(|| "y"), || Ok(self.y))?; + let y = AllocatedNum::alloc_infallible(cs.namespace(|| "y"), || self.y); // We now check if y = x^{1/5} by checking if y^5 = x let y_sq = y.square(cs.namespace(|| "y_sq"))?; diff --git a/src/nifs.rs b/src/nifs.rs index 288055dea..442df98a0 100644 --- a/src/nifs.rs +++ b/src/nifs.rs @@ -59,7 +59,7 @@ impl NIFS { let r = ro.squeeze(NUM_CHALLENGE_BITS); // fold the instance using `r` and `comm_T` - let U = U1.fold(U2, &comm_T, &r)?; + let U = U1.fold(U2, &comm_T, &r); // fold the witness using `r` and `T` let W = W1.fold(W2, &T, &r)?; @@ -103,7 +103,7 @@ impl NIFS { let r = ro.squeeze(NUM_CHALLENGE_BITS); // fold the instance using `r` and `comm_T` - let U = U1.fold(U2, &comm_T, &r)?; + let U = U1.fold(U2, &comm_T, &r); // return the folded instance Ok(U) @@ -125,7 +125,7 @@ mod tests { x_val: Option, ) -> Result<(), SynthesisError> { // Consider a cubic equation: `x^3 + x + 5 = y`, where `x` and `y` are respectively the input and output. - let x = AllocatedNum::alloc(cs.namespace(|| "x"), || Ok(x_val.unwrap()))?; + let x = AllocatedNum::alloc_infallible(cs.namespace(|| "x"), || x_val.unwrap()); let _ = x.inputize(cs.namespace(|| "x is input")); let x_sq = x.square(cs.namespace(|| "x_sq"))?; diff --git a/src/provider/poseidon.rs b/src/provider/poseidon.rs index 917da91c3..76f20953e 100644 --- a/src/provider/poseidon.rs +++ b/src/provider/poseidon.rs @@ -224,8 +224,7 @@ mod tests { for i in 0..num_absorbs { let num = G::Scalar::random(&mut csprng); ro.absorb(num); - let num_gadget = - AllocatedNum::alloc(cs.namespace(|| format!("data {i}")), || Ok(num)).unwrap(); + let num_gadget = AllocatedNum::alloc_infallible(cs.namespace(|| format!("data {i}")), || num); num_gadget .inputize(&mut cs.namespace(|| format!("input {i}"))) .unwrap(); diff --git a/src/r1cs/mod.rs b/src/r1cs/mod.rs index 26adade81..5c20aa537 100644 --- a/src/r1cs/mod.rs +++ b/src/r1cs/mod.rs @@ -500,7 +500,7 @@ impl RelaxedR1CSInstance { U2: &R1CSInstance, comm_T: &Commitment, r: &G::Scalar, - ) -> Result, NovaError> { + ) -> RelaxedR1CSInstance { let (X1, u1, comm_W_1, comm_E_1) = (&self.X, &self.u, &self.comm_W.clone(), &self.comm_E.clone()); let (X2, comm_W_2) = (&U2.X, &U2.comm_W); @@ -515,12 +515,12 @@ impl RelaxedR1CSInstance { let comm_E = *comm_E_1 + *comm_T * *r; let u = *u1 + *r; - Ok(RelaxedR1CSInstance { + RelaxedR1CSInstance { comm_W, comm_E, X, u, - }) + } } }