From befe549acb0b76b6c6a76d1adac8362a0614bc64 Mon Sep 17 00:00:00 2001 From: guipublic <47281315+guipublic@users.noreply.github.com> Date: Tue, 25 Jul 2023 08:31:28 +0200 Subject: [PATCH] chore: use euclidian division for truncate and comparison (#1757) * use euclidian division for truncate and comparison * chore: fix merge issue * code review --------- Co-authored-by: TomAFrench --- .../acir_gen/acir_ir/acir_variable.rs | 14 +- .../acir_gen/acir_ir/generated_acir.rs | 148 ++++-------------- 2 files changed, 38 insertions(+), 124 deletions(-) diff --git a/crates/noirc_evaluator/src/ssa_refactor/acir_gen/acir_ir/acir_variable.rs b/crates/noirc_evaluator/src/ssa_refactor/acir_gen/acir_ir/acir_variable.rs index 3580639893f..5ef4707f9a9 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/acir_gen/acir_ir/acir_variable.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/acir_gen/acir_ir/acir_variable.rs @@ -588,9 +588,17 @@ impl AcirContext { let lhs_data = &self.vars[&lhs]; let lhs_expr = lhs_data.to_expression(); - let result_expr = self.acir_ir.truncate(&lhs_expr, rhs, max_bit_size)?; - - Ok(self.add_data(AcirVarData::Expr(result_expr))) + // 2^{rhs} + let divisor = FieldElement::from(2_i128).pow(&FieldElement::from(rhs as i128)); + // Computes lhs = 2^{rhs} * q + r + let (_, remainder) = self.acir_ir.euclidean_division( + &lhs_expr, + &Expression::from_field(divisor), + max_bit_size, + &Expression::one(), + )?; + + Ok(self.add_data(AcirVarData::Expr(Expression::from(remainder)))) } /// Returns an `AcirVar` which will be `1` if lhs >= rhs diff --git a/crates/noirc_evaluator/src/ssa_refactor/acir_gen/acir_ir/generated_acir.rs b/crates/noirc_evaluator/src/ssa_refactor/acir_gen/acir_ir/generated_acir.rs index ff1fc680191..18c7216a6fa 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/acir_gen/acir_ir/generated_acir.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/acir_gen/acir_ir/generated_acir.rs @@ -112,92 +112,6 @@ impl GeneratedAcir { } impl GeneratedAcir { - /// Computes lhs = 2^{rhs_bit_size} * q + r - /// - /// For example, if we had a u32: - /// - `rhs` would be `32` - /// - `max_bits` would be the size of `lhs` - /// - /// Take the following code: - /// `` - /// fn main(x : u32) -> u32 { - /// let a = x + x; (L1) - /// let b = a * a; (L2) - /// b + b (L3) - /// } - /// `` - /// - /// Call truncate only on L1: - /// - `rhs` would be `32` - /// - `max_bits` would be `33` due to the addition of two u32s - /// Call truncate only on L2: - /// - `rhs` would be `32` - /// - `max_bits` would be `66` due to the multiplication of two u33s `a` - /// Call truncate only on L3: - /// - `rhs` would be `32` - /// - `max_bits` would be `67` due to the addition of two u66s `b` - /// - /// Truncation is done via the euclidean division formula: - /// - /// a = b * q + r - /// - /// where: - /// - a = `lhs` - /// - b = 2^{max_bits} - /// The prover will supply the quotient and the remainder, where the remainder - /// is the truncated value that we will return since it is enforced to be - /// in the range: 0 <= r < 2^{rhs_bit_size} - pub(crate) fn truncate( - &mut self, - lhs: &Expression, - rhs_bit_size: u32, - max_bits: u32, - ) -> Result { - assert!(max_bits > rhs_bit_size, "max_bits = {max_bits}, rhs = {rhs_bit_size} -- The caller should ensure that truncation is only called when the value needs to be truncated"); - let exp_big = BigUint::from(2_u32).pow(rhs_bit_size); - - // 0. Check for constant expression. - if let Some(a_c) = lhs.to_const() { - let mut a_big = BigUint::from_bytes_be(&a_c.to_be_bytes()); - a_big %= exp_big; - return Ok(Expression::from(FieldElement::from_be_bytes_reduce(&a_big.to_bytes_be()))); - } - // Note: This is doing a reduction. However, since the compiler will call - // `max_bits` before it overflows the modulus, this line should never do a reduction. - // - // For example, if the modulus is a 254 bit number. - // `max_bits` will never be 255 since `exp` will be 2^255, which will cause a reduction in the following line. - // TODO: We should change this from `from_be_bytes_reduce` to `from_be_bytes` - // TODO: the latter will return an option that we can unwrap in the compiler - let exp = FieldElement::from_be_bytes_reduce(&exp_big.to_bytes_be()); - - // 1. Generate witnesses a,b,c - - // According to the division theorem, the remainder needs to be 0 <= r < 2^{rhs_bit_size} - let r_max_bits = rhs_bit_size; - // According to the formula above, the quotient should be within the range 0 <= q < 2^{max_bits - rhs} - let q_max_bits = max_bits - rhs_bit_size; - - let (quotient_witness, remainder_witness) = - self.quotient_directive(lhs.clone(), exp.into(), None, q_max_bits, r_max_bits)?; - - // 2. Add the constraint a == r + (q * 2^{rhs}) - // - // 2^{rhs} - let mut two_pow_rhs_bits = FieldElement::from(2_i128); - two_pow_rhs_bits = two_pow_rhs_bits.pow(&FieldElement::from(rhs_bit_size as i128)); - - let remainder_expr = Expression::from(remainder_witness); - let quotient_expr = Expression::from(quotient_witness); - - let res = &remainder_expr + &(two_pow_rhs_bits * "ient_expr); - let euclidean_division = &res - lhs; - - self.push_opcode(AcirOpcode::Arithmetic(euclidean_division)); - - Ok(Expression::from(remainder_witness)) - } - /// Calls a black box function and returns the output /// of said blackbox function. pub(crate) fn call_black_box( @@ -457,16 +371,28 @@ impl GeneratedAcir { // lhs = rhs * q + r // // If predicate is zero, `q_witness` and `r_witness` will be 0 + + // maximum bit size for q and for [r and rhs] + let mut max_q_bits = max_bit_size; + let mut max_rhs_bits = max_bit_size; + // when rhs is constant, we can better estimate the maximum bit sizes + if let Some(rhs_const) = rhs.to_const() { + max_rhs_bits = rhs_const.num_bits(); + if max_rhs_bits != 0 { + max_q_bits = max_bit_size - max_rhs_bits + 1; + } + } + let (q_witness, r_witness) = self.quotient_directive( lhs.clone(), rhs.clone(), Some(predicate.clone()), - max_bit_size, - max_bit_size, + max_q_bits, + max_rhs_bits, )?; // Constrain r < rhs - self.bound_constraint_with_offset(&r_witness.into(), rhs, predicate, max_bit_size)?; + self.bound_constraint_with_offset(&r_witness.into(), rhs, predicate, max_rhs_bits)?; // a * predicate == (b * q + r) * predicate // => predicate * (a - b * q - r) == 0 @@ -766,26 +692,12 @@ impl GeneratedAcir { let two_max_bits: FieldElement = two.pow(&FieldElement::from(max_bits as i128)); let comparison_evaluation = (a - b) + two_max_bits; - // We want to enforce that `q` is a boolean value. - // In particular it should be the `n` bit of the `comparison_evaluation` - // which will indicate whether a >= b. + // Euclidian division by 2^{max_bits} : 2^{max_bits} + a - b = q * 2^{max_bits} + r // - // In the document linked above, they mention negating the value of `q` - // which would tell us whether a < b. Since we do not negate `q` - // what we get is a boolean indicating whether a >= b. - let q_max_bits = 1; - // `r` can take any value up to `two_max_bits`. - let r_max_bits = max_bits; - - let (q_witness, r_witness) = self.quotient_directive( - comparison_evaluation.clone(), - two_max_bits.into(), - Some(predicate.clone()), - q_max_bits, - r_max_bits, - )?; - - // Add constraint : 2^{max_bits} + a - b = q * 2^{max_bits} + r + // 2^{max_bits} is of max_bits+1 bit size + // If a>b, then a-b is less than 2^{max_bits} - 1, so 2^{max_bits} + a - b is less than 2^{max_bits} + 2^{max_bits} - 1 = 2^{max_bits+1} - 1 + // If a <= b, then 2^{max_bits} + a - b is less than 2^{max_bits} <= 2^{max_bits+1} - 1 + // This means that both operands of the division have at most max_bits+1 bit size. // // case: a == b // @@ -805,19 +717,13 @@ impl GeneratedAcir { // - 2^{max_bits} - k == q * 2^{max_bits} + r // - This is only the case when q == 0 and r == 2^{max_bits} - k // - // case: predicate is zero - // The values for q and r will be zero for a honest prover and - // can be garbage for a dishonest prover. The below constraint will - // will be switched off. - let mut expr = Expression::default(); - expr.push_addition_term(two_max_bits, q_witness); - expr.push_addition_term(FieldElement::one(), r_witness); - - let equation = &comparison_evaluation - &expr; - let predicated_equation = self.mul_with_witness(&equation, &predicate); - self.push_opcode(AcirOpcode::Arithmetic(predicated_equation)); - - Ok(q_witness) + let (q, _) = self.euclidean_division( + &comparison_evaluation, + &Expression::from(two_max_bits), + max_bits + 1, + &predicate, + )?; + Ok(q) } pub(crate) fn brillig(