Skip to content

Commit

Permalink
Eliminate unnecessary wraps & use alloc_infallible (argumentcomputer#246
Browse files Browse the repository at this point in the history
)

* feat: Refactor unnecessary wraps in error handling

Revised operations unnecessarily wrapping their return in a Result.

* refactor: Refactor to use infallible allocation across application

- Switched the allocation method to the infallible version, `AllocatedNum::alloc_infallible`, in multiple units (`poseidon.rs`, `bellpepper/mod.rs`, `utils.rs`, `circuit.rs`, `lib.rs`, `nifs.rs`).
  This removed the need for multiple error checks and reduced usage of `Result` return types.
  • Loading branch information
huitseeker committed Oct 30, 2023
1 parent 4222d83 commit bf74e94
Show file tree
Hide file tree
Showing 11 changed files with 33 additions and 47 deletions.
15 changes: 6 additions & 9 deletions src/bellpepper/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,30 +17,27 @@ mod tests {
},
traits::Group,
};
use bellpepper_core::{num::AllocatedNum, ConstraintSystem, SynthesisError};
use bellpepper_core::{num::AllocatedNum, ConstraintSystem};
use ff::PrimeField;

fn synthesize_alloc_bit<Fr: PrimeField, CS: ConstraintSystem<Fr>>(
cs: &mut CS,
) -> Result<(), SynthesisError> {
fn synthesize_alloc_bit<Fr: PrimeField, CS: ConstraintSystem<Fr>>(cs: &mut CS) {
// get two bits as input and check that they are indeed bits
let a = AllocatedNum::alloc(cs.namespace(|| "a"), || Ok(Fr::ONE))?;
let a = AllocatedNum::alloc_infallible(cs.namespace(|| "a"), || Fr::ONE);
let _ = a.inputize(cs.namespace(|| "a is input"));
cs.enforce(
|| "check a is 0 or 1",
|lc| lc + CS::one() - a.get_variable(),
|lc| lc + a.get_variable(),
|lc| lc,
);
let b = AllocatedNum::alloc(cs.namespace(|| "b"), || Ok(Fr::ONE))?;
let b = AllocatedNum::alloc_infallible(cs.namespace(|| "b"), || Fr::ONE);
let _ = b.inputize(cs.namespace(|| "b is input"));
cs.enforce(
|| "check b is 0 or 1",
|lc| lc + CS::one() - b.get_variable(),
|lc| lc + b.get_variable(),
|lc| lc,
);
Ok(())
}

fn test_alloc_bit_with<G>()
Expand All @@ -49,12 +46,12 @@ mod tests {
{
// First create the shape
let mut cs: ShapeCS<G> = ShapeCS::new();
let _ = synthesize_alloc_bit(&mut cs);
synthesize_alloc_bit(&mut cs);
let (shape, ck) = cs.r1cs_shape();

// Now get the assignment
let mut cs: SatisfyingAssignment<G> = SatisfyingAssignment::new();
let _ = synthesize_alloc_bit(&mut cs);
synthesize_alloc_bit(&mut cs);
let (inst, witness) = cs.r1cs_instance_and_witness(&shape, &ck).unwrap();

// Make sure that this is satisfiable
Expand Down
2 changes: 1 addition & 1 deletion src/circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ impl<'a, G: Group, SC: StepCircuit<G::Base>> NovaAugmentedCircuit<'a, G, SC> {
self.alloc_witness(cs.namespace(|| "allocate the circuit witness"), arity)?;

// Compute variable indicating if this is the base case
let zero = alloc_zero(cs.namespace(|| "zero"))?;
let zero = alloc_zero(cs.namespace(|| "zero"));
let is_base_case = alloc_num_equals(cs.namespace(|| "Check if base case"), &i.clone(), &zero)?;

// Synthesize the circuit for the base case and get the new running instance
Expand Down
19 changes: 9 additions & 10 deletions src/gadgets/ecc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,8 @@ where
where
CS: ConstraintSystem<G::Base>,
{
let zero = alloc_zero(cs.namespace(|| "zero"))?;
let one = alloc_one(cs.namespace(|| "one"))?;
let zero = alloc_zero(cs.namespace(|| "zero"));
let one = alloc_one(cs.namespace(|| "one"));

Ok(AllocatedPoint {
x: zero.clone(),
Expand Down Expand Up @@ -884,13 +884,12 @@ mod tests {
pub fn inputize_allocted_point<G: Group, CS: ConstraintSystem<G::Base>>(
p: &AllocatedPoint<G>,
mut cs: CS,
) -> Result<(), SynthesisError> {
) {
let _ = p.x.inputize(cs.namespace(|| "Input point.x"));
let _ = p.y.inputize(cs.namespace(|| "Input point.y"));
let _ = p
.is_infinity
.inputize(cs.namespace(|| "Input point.is_infinity"));
Ok(())
}

#[test]
Expand Down Expand Up @@ -964,7 +963,7 @@ mod tests {
CS: ConstraintSystem<G::Base>,
{
let a = alloc_random_point(cs.namespace(|| "a")).unwrap();
inputize_allocted_point(&a, cs.namespace(|| "inputize a")).unwrap();
inputize_allocted_point(&a, cs.namespace(|| "inputize a"));

let s = G::Scalar::random(&mut OsRng);
// Allocate bits for s
Expand All @@ -976,7 +975,7 @@ mod tests {
.collect::<Result<Vec<AllocatedBit>, SynthesisError>>()
.unwrap();
let e = a.scalar_mul(cs.namespace(|| "Scalar Mul"), &bits).unwrap();
inputize_allocted_point(&e, cs.namespace(|| "inputize e")).unwrap();
inputize_allocted_point(&e, cs.namespace(|| "inputize e"));
(a, e, s)
}

Expand Down Expand Up @@ -1030,9 +1029,9 @@ mod tests {
CS: ConstraintSystem<G::Base>,
{
let a = alloc_random_point(cs.namespace(|| "a")).unwrap();
inputize_allocted_point(&a, cs.namespace(|| "inputize a")).unwrap();
inputize_allocted_point(&a, cs.namespace(|| "inputize a"));
let e = a.add(cs.namespace(|| "add a to a"), &a).unwrap();
inputize_allocted_point(&e, cs.namespace(|| "inputize e")).unwrap();
inputize_allocted_point(&e, cs.namespace(|| "inputize e"));
(a, e)
}

Expand Down Expand Up @@ -1085,13 +1084,13 @@ mod tests {
CS: ConstraintSystem<G::Base>,
{
let a = alloc_random_point(cs.namespace(|| "a")).unwrap();
inputize_allocted_point(&a, cs.namespace(|| "inputize a")).unwrap();
inputize_allocted_point(&a, cs.namespace(|| "inputize a"));
let b = &mut a.clone();
b.y = AllocatedNum::alloc(cs.namespace(|| "allocate negation of a"), || {
Ok(G::Base::ZERO)
})
.unwrap();
inputize_allocted_point(b, cs.namespace(|| "inputize b")).unwrap();
inputize_allocted_point(b, cs.namespace(|| "inputize b"));
let e = a.add(cs.namespace(|| "add a to b"), b).unwrap();
e
}
Expand Down
2 changes: 1 addition & 1 deletion src/gadgets/nonnative/bignat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ impl<Scalar: PrimeField> BigNat<Scalar> {
// (1) decompose `bignat` into a bitvector `bv`
let bv = bignat.decompose(cs.namespace(|| "bv"))?;
// (2) recompose bits and check if it equals n
n.is_equal(cs.namespace(|| "n"), &bv)?;
n.is_equal(cs.namespace(|| "n"), &bv);

Ok(bignat)
}
Expand Down
7 changes: 1 addition & 6 deletions src/gadgets/nonnative/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -157,11 +157,7 @@ impl<Scalar: PrimeField> Num<Scalar> {

/// Computes the natural number represented by an array of bits.
/// Checks if the natural number equals `self`
pub fn is_equal<CS: ConstraintSystem<Scalar>>(
&self,
mut cs: CS,
other: &Bitvector<Scalar>,
) -> Result<(), SynthesisError> {
pub fn is_equal<CS: ConstraintSystem<Scalar>>(&self, mut cs: CS, other: &Bitvector<Scalar>) {
let allocations = other.allocations.clone();
let mut f = Scalar::ONE;
let sum = allocations
Expand All @@ -173,7 +169,6 @@ impl<Scalar: PrimeField> Num<Scalar> {
});
let sum_lc = LinearCombination::zero() + &self.num - &sum;
cs.enforce(|| "sum", |lc| lc + &sum_lc, |lc| lc + CS::one(), |lc| lc);
Ok(())
}

/// Compute the natural number represented by an array of limbs.
Expand Down
2 changes: 1 addition & 1 deletion src/gadgets/r1cs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ impl<G: Group> AllocatedRelaxedR1CSInstance<G> {
) -> Result<Self, SynthesisError> {
let E = AllocatedPoint::default(cs.namespace(|| "allocate W"))?;

let u = alloc_one(cs.namespace(|| "one"))?;
let u = alloc_one(cs.namespace(|| "one"));

let X0 = BigNat::from_num(
cs.namespace(|| "allocate X0 from relaxed r1cs"),
Expand Down
16 changes: 6 additions & 10 deletions src/gadgets/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,32 +44,28 @@ where
}

/// Allocate a variable that is set to zero
pub fn alloc_zero<F: PrimeField, CS: ConstraintSystem<F>>(
mut cs: CS,
) -> Result<AllocatedNum<F>, SynthesisError> {
let zero = AllocatedNum::alloc(cs.namespace(|| "alloc"), || Ok(F::ZERO))?;
pub fn alloc_zero<F: PrimeField, CS: ConstraintSystem<F>>(mut cs: CS) -> AllocatedNum<F> {
let zero = AllocatedNum::alloc_infallible(cs.namespace(|| "alloc"), || F::ZERO);
cs.enforce(
|| "check zero is valid",
|lc| lc,
|lc| lc,
|lc| lc + zero.get_variable(),
);
Ok(zero)
zero
}

/// Allocate a variable that is set to one
pub fn alloc_one<F: PrimeField, CS: ConstraintSystem<F>>(
mut cs: CS,
) -> Result<AllocatedNum<F>, SynthesisError> {
let one = AllocatedNum::alloc(cs.namespace(|| "alloc"), || Ok(F::ONE))?;
pub fn alloc_one<F: PrimeField, CS: ConstraintSystem<F>>(mut cs: CS) -> AllocatedNum<F> {
let one = AllocatedNum::alloc_infallible(cs.namespace(|| "alloc"), || F::ONE);
cs.enforce(
|| "check one is valid",
|lc| lc + CS::one(),
|lc| lc + CS::one(),
|lc| lc + one.get_variable(),
);

Ok(one)
one
}

/// Allocate a scalar as a base. Only to be used is the scalar fits in base!
Expand Down
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1324,7 +1324,7 @@ mod tests {
let x = &z[0];

// we allocate a variable and set it to the provided non-deterministic advice.
let y = AllocatedNum::alloc(cs.namespace(|| "y"), || Ok(self.y))?;
let y = AllocatedNum::alloc_infallible(cs.namespace(|| "y"), || self.y);

// We now check if y = x^{1/5} by checking if y^5 = x
let y_sq = y.square(cs.namespace(|| "y_sq"))?;
Expand Down
6 changes: 3 additions & 3 deletions src/nifs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ impl<G: Group> NIFS<G> {
let r = ro.squeeze(NUM_CHALLENGE_BITS);

// fold the instance using `r` and `comm_T`
let U = U1.fold(U2, &comm_T, &r)?;
let U = U1.fold(U2, &comm_T, &r);

// fold the witness using `r` and `T`
let W = W1.fold(W2, &T, &r)?;
Expand Down Expand Up @@ -103,7 +103,7 @@ impl<G: Group> NIFS<G> {
let r = ro.squeeze(NUM_CHALLENGE_BITS);

// fold the instance using `r` and `comm_T`
let U = U1.fold(U2, &comm_T, &r)?;
let U = U1.fold(U2, &comm_T, &r);

// return the folded instance
Ok(U)
Expand All @@ -125,7 +125,7 @@ mod tests {
x_val: Option<Scalar>,
) -> Result<(), SynthesisError> {
// Consider a cubic equation: `x^3 + x + 5 = y`, where `x` and `y` are respectively the input and output.
let x = AllocatedNum::alloc(cs.namespace(|| "x"), || Ok(x_val.unwrap()))?;
let x = AllocatedNum::alloc_infallible(cs.namespace(|| "x"), || x_val.unwrap());
let _ = x.inputize(cs.namespace(|| "x is input"));

let x_sq = x.square(cs.namespace(|| "x_sq"))?;
Expand Down
3 changes: 1 addition & 2 deletions src/provider/poseidon.rs
Original file line number Diff line number Diff line change
Expand Up @@ -224,8 +224,7 @@ mod tests {
for i in 0..num_absorbs {
let num = G::Scalar::random(&mut csprng);
ro.absorb(num);
let num_gadget =
AllocatedNum::alloc(cs.namespace(|| format!("data {i}")), || Ok(num)).unwrap();
let num_gadget = AllocatedNum::alloc_infallible(cs.namespace(|| format!("data {i}")), || num);
num_gadget
.inputize(&mut cs.namespace(|| format!("input {i}")))
.unwrap();
Expand Down
6 changes: 3 additions & 3 deletions src/r1cs/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -500,7 +500,7 @@ impl<G: Group> RelaxedR1CSInstance<G> {
U2: &R1CSInstance<G>,
comm_T: &Commitment<G>,
r: &G::Scalar,
) -> Result<RelaxedR1CSInstance<G>, NovaError> {
) -> RelaxedR1CSInstance<G> {
let (X1, u1, comm_W_1, comm_E_1) =
(&self.X, &self.u, &self.comm_W.clone(), &self.comm_E.clone());
let (X2, comm_W_2) = (&U2.X, &U2.comm_W);
Expand All @@ -515,12 +515,12 @@ impl<G: Group> RelaxedR1CSInstance<G> {
let comm_E = *comm_E_1 + *comm_T * *r;
let u = *u1 + *r;

Ok(RelaxedR1CSInstance {
RelaxedR1CSInstance {
comm_W,
comm_E,
X,
u,
})
}
}
}

Expand Down

0 comments on commit bf74e94

Please sign in to comment.