Skip to content

Commit

Permalink
chore: use euclidian division for truncate and comparison (#1757)
Browse files Browse the repository at this point in the history
* use euclidian division for truncate and comparison

* chore: fix merge issue

* code review

---------

Co-authored-by: TomAFrench <tom@tomfren.ch>
  • Loading branch information
guipublic and TomAFrench authored Jul 25, 2023
1 parent 0444b52 commit befe549
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 124 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -588,9 +588,17 @@ impl AcirContext {
let lhs_data = &self.vars[&lhs];
let lhs_expr = lhs_data.to_expression();

let result_expr = self.acir_ir.truncate(&lhs_expr, rhs, max_bit_size)?;

Ok(self.add_data(AcirVarData::Expr(result_expr)))
// 2^{rhs}
let divisor = FieldElement::from(2_i128).pow(&FieldElement::from(rhs as i128));
// Computes lhs = 2^{rhs} * q + r
let (_, remainder) = self.acir_ir.euclidean_division(
&lhs_expr,
&Expression::from_field(divisor),
max_bit_size,
&Expression::one(),
)?;

Ok(self.add_data(AcirVarData::Expr(Expression::from(remainder))))
}

/// Returns an `AcirVar` which will be `1` if lhs >= rhs
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,92 +112,6 @@ impl GeneratedAcir {
}

impl GeneratedAcir {
/// Computes lhs = 2^{rhs_bit_size} * q + r
///
/// For example, if we had a u32:
/// - `rhs` would be `32`
/// - `max_bits` would be the size of `lhs`
///
/// Take the following code:
/// ``
/// fn main(x : u32) -> u32 {
/// let a = x + x; (L1)
/// let b = a * a; (L2)
/// b + b (L3)
/// }
/// ``
///
/// Call truncate only on L1:
/// - `rhs` would be `32`
/// - `max_bits` would be `33` due to the addition of two u32s
/// Call truncate only on L2:
/// - `rhs` would be `32`
/// - `max_bits` would be `66` due to the multiplication of two u33s `a`
/// Call truncate only on L3:
/// - `rhs` would be `32`
/// - `max_bits` would be `67` due to the addition of two u66s `b`
///
/// Truncation is done via the euclidean division formula:
///
/// a = b * q + r
///
/// where:
/// - a = `lhs`
/// - b = 2^{max_bits}
/// The prover will supply the quotient and the remainder, where the remainder
/// is the truncated value that we will return since it is enforced to be
/// in the range: 0 <= r < 2^{rhs_bit_size}
pub(crate) fn truncate(
&mut self,
lhs: &Expression,
rhs_bit_size: u32,
max_bits: u32,
) -> Result<Expression, AcirGenError> {
assert!(max_bits > rhs_bit_size, "max_bits = {max_bits}, rhs = {rhs_bit_size} -- The caller should ensure that truncation is only called when the value needs to be truncated");
let exp_big = BigUint::from(2_u32).pow(rhs_bit_size);

// 0. Check for constant expression.
if let Some(a_c) = lhs.to_const() {
let mut a_big = BigUint::from_bytes_be(&a_c.to_be_bytes());
a_big %= exp_big;
return Ok(Expression::from(FieldElement::from_be_bytes_reduce(&a_big.to_bytes_be())));
}
// Note: This is doing a reduction. However, since the compiler will call
// `max_bits` before it overflows the modulus, this line should never do a reduction.
//
// For example, if the modulus is a 254 bit number.
// `max_bits` will never be 255 since `exp` will be 2^255, which will cause a reduction in the following line.
// TODO: We should change this from `from_be_bytes_reduce` to `from_be_bytes`
// TODO: the latter will return an option that we can unwrap in the compiler
let exp = FieldElement::from_be_bytes_reduce(&exp_big.to_bytes_be());

// 1. Generate witnesses a,b,c

// According to the division theorem, the remainder needs to be 0 <= r < 2^{rhs_bit_size}
let r_max_bits = rhs_bit_size;
// According to the formula above, the quotient should be within the range 0 <= q < 2^{max_bits - rhs}
let q_max_bits = max_bits - rhs_bit_size;

let (quotient_witness, remainder_witness) =
self.quotient_directive(lhs.clone(), exp.into(), None, q_max_bits, r_max_bits)?;

// 2. Add the constraint a == r + (q * 2^{rhs})
//
// 2^{rhs}
let mut two_pow_rhs_bits = FieldElement::from(2_i128);
two_pow_rhs_bits = two_pow_rhs_bits.pow(&FieldElement::from(rhs_bit_size as i128));

let remainder_expr = Expression::from(remainder_witness);
let quotient_expr = Expression::from(quotient_witness);

let res = &remainder_expr + &(two_pow_rhs_bits * &quotient_expr);
let euclidean_division = &res - lhs;

self.push_opcode(AcirOpcode::Arithmetic(euclidean_division));

Ok(Expression::from(remainder_witness))
}

/// Calls a black box function and returns the output
/// of said blackbox function.
pub(crate) fn call_black_box(
Expand Down Expand Up @@ -457,16 +371,28 @@ impl GeneratedAcir {
// lhs = rhs * q + r
//
// If predicate is zero, `q_witness` and `r_witness` will be 0

// maximum bit size for q and for [r and rhs]
let mut max_q_bits = max_bit_size;
let mut max_rhs_bits = max_bit_size;
// when rhs is constant, we can better estimate the maximum bit sizes
if let Some(rhs_const) = rhs.to_const() {
max_rhs_bits = rhs_const.num_bits();
if max_rhs_bits != 0 {
max_q_bits = max_bit_size - max_rhs_bits + 1;
}
}

let (q_witness, r_witness) = self.quotient_directive(
lhs.clone(),
rhs.clone(),
Some(predicate.clone()),
max_bit_size,
max_bit_size,
max_q_bits,
max_rhs_bits,
)?;

// Constrain r < rhs
self.bound_constraint_with_offset(&r_witness.into(), rhs, predicate, max_bit_size)?;
self.bound_constraint_with_offset(&r_witness.into(), rhs, predicate, max_rhs_bits)?;

// a * predicate == (b * q + r) * predicate
// => predicate * (a - b * q - r) == 0
Expand Down Expand Up @@ -766,26 +692,12 @@ impl GeneratedAcir {
let two_max_bits: FieldElement = two.pow(&FieldElement::from(max_bits as i128));
let comparison_evaluation = (a - b) + two_max_bits;

// We want to enforce that `q` is a boolean value.
// In particular it should be the `n` bit of the `comparison_evaluation`
// which will indicate whether a >= b.
// Euclidian division by 2^{max_bits} : 2^{max_bits} + a - b = q * 2^{max_bits} + r
//
// In the document linked above, they mention negating the value of `q`
// which would tell us whether a < b. Since we do not negate `q`
// what we get is a boolean indicating whether a >= b.
let q_max_bits = 1;
// `r` can take any value up to `two_max_bits`.
let r_max_bits = max_bits;

let (q_witness, r_witness) = self.quotient_directive(
comparison_evaluation.clone(),
two_max_bits.into(),
Some(predicate.clone()),
q_max_bits,
r_max_bits,
)?;

// Add constraint : 2^{max_bits} + a - b = q * 2^{max_bits} + r
// 2^{max_bits} is of max_bits+1 bit size
// If a>b, then a-b is less than 2^{max_bits} - 1, so 2^{max_bits} + a - b is less than 2^{max_bits} + 2^{max_bits} - 1 = 2^{max_bits+1} - 1
// If a <= b, then 2^{max_bits} + a - b is less than 2^{max_bits} <= 2^{max_bits+1} - 1
// This means that both operands of the division have at most max_bits+1 bit size.
//
// case: a == b
//
Expand All @@ -805,19 +717,13 @@ impl GeneratedAcir {
// - 2^{max_bits} - k == q * 2^{max_bits} + r
// - This is only the case when q == 0 and r == 2^{max_bits} - k
//
// case: predicate is zero
// The values for q and r will be zero for a honest prover and
// can be garbage for a dishonest prover. The below constraint will
// will be switched off.
let mut expr = Expression::default();
expr.push_addition_term(two_max_bits, q_witness);
expr.push_addition_term(FieldElement::one(), r_witness);

let equation = &comparison_evaluation - &expr;
let predicated_equation = self.mul_with_witness(&equation, &predicate);
self.push_opcode(AcirOpcode::Arithmetic(predicated_equation));

Ok(q_witness)
let (q, _) = self.euclidean_division(
&comparison_evaluation,
&Expression::from(two_max_bits),
max_bits + 1,
&predicate,
)?;
Ok(q)
}

pub(crate) fn brillig(
Expand Down

0 comments on commit befe549

Please sign in to comment.