Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore(ssa refactor): Add more documentation for truncation #1607

Merged
merged 3 commits into from
Jun 9, 2023
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -95,57 +95,97 @@ impl GeneratedAcir {
}

impl GeneratedAcir {
/// Computes lhs mod 2^rhs
/// Computes lhs = 2^{rhs_bit_size} * q + r
///
/// `max_bits` is the upper-bound on the bit_size of the object that `lhs` is representing.

/// An example; max_bits would be 32, if lhs was representing a u32 at a higher level.
/// For example, if we had a u32:
/// - `rhs` would be `32`
/// - `max_bits` would be the size of `lhs`
///
/// Take the following code:
/// ``
/// fn main(x : u32) -> u32 {
/// let a = x + x; (L1)
/// let b = a * a; (L2)
/// b + b (L3)
/// }
/// ``
///
/// Call truncate only on L1:
/// - `rhs` would be `32`
/// - `max_bits` would be `33` due to the addition of two u32s
/// Call truncate only on L2:
/// - `rhs` would be `32`
/// - `max_bits` would be `66` due to the multiplication of two u33s `a`
/// Call truncate only on L3:
/// - `rhs` would be `32`
/// - `max_bits` would be `67` due to the addition of two u66s `b`
kevaundray marked this conversation as resolved.
Show resolved Hide resolved
///
/// Truncation is done via the euclidean division formula:
///
/// a = b * q + r
///
/// where:
/// - a = `lhs`
/// - b = 2^{max_bits}
/// The prover will supply the quotient and the remainder, where the remainder
/// is the truncated value that we will return since it is enforced to be
/// in the range: 0 <= r < 2^{rhs}
kevaundray marked this conversation as resolved.
Show resolved Hide resolved
pub(crate) fn truncate(
&mut self,
lhs: &Expression,
rhs: u32,
rhs_bit_size: u32,
max_bits: u32,
) -> Result<Expression, AcirGenError> {
assert!(max_bits > rhs, "max_bits = {max_bits}, rhs = {rhs}");
let exp_big = BigUint::from(2_u32).pow(rhs);
assert!(max_bits > rhs_bit_size, "max_bits = {max_bits}, rhs = {rhs_bit_size} -- The caller should ensure that truncation is only called when the value needs to be truncated");
let exp_big = BigUint::from(2_u32).pow(rhs_bit_size);

// 0. Check for constant expression.
if let Some(a_c) = lhs.to_const() {
let mut a_big = BigUint::from_bytes_be(&a_c.to_be_bytes());
a_big %= exp_big;
return Ok(Expression::from(FieldElement::from_be_bytes_reduce(&a_big.to_bytes_be())));
}
// Note: This is doing a reduction. However, since the compiler will call
// `max_bits` before it overflows the modulus, this line should never do a reduction.
//
// For example, if the modulus is a 254 bit number.
// `max_bits` will never be 255 since `exp` will be 2^255, which will cause a reduction in the following line.
// TODO: We should change this from `from_be_bytes_reduce` to `from_be_bytes`
// TODO: the latter will return an option that we can unwrap in the compiler
let exp = FieldElement::from_be_bytes_reduce(&exp_big.to_bytes_be());

// 1. Generate witnesses a,b,c
let b_witness = self.next_witness_index();
let c_witness = self.next_witness_index();
let remainder_witness = self.next_witness_index();
let quotient_witness = self.next_witness_index();
self.push_opcode(AcirOpcode::Directive(Directive::Quotient(QuotientDirective {
a: lhs.clone(),
b: Expression::from_field(exp),
q: c_witness,
r: b_witness,
q: quotient_witness,
r: remainder_witness,
predicate: None,
})));

self.range_constraint(b_witness, rhs)?;
self.range_constraint(c_witness, max_bits - rhs)?;
// According to the division theorem, the remainder needs to be 0 <= r < 2^{rhs_bit_size}
self.range_constraint(remainder_witness, rhs_bit_size)?;

// According to the formula above, the quotient should be within the range 0 <= q < 2^{max_bits - rhs}
self.range_constraint(quotient_witness, max_bits - rhs_bit_size)?;

// 2. Add the constraint a = b + 2^{rhs} * c
// 2. Add the constraint a == r + (q * 2^{rhs})
//
// 2^{rhs}
let mut two_pow_rhs_bits = FieldElement::from(2_i128);
two_pow_rhs_bits = two_pow_rhs_bits.pow(&FieldElement::from(rhs as i128));
two_pow_rhs_bits = two_pow_rhs_bits.pow(&FieldElement::from(rhs_bit_size as i128));

let b_arith = Expression::from(b_witness);
let c_arith = Expression::from(c_witness);
let remainder_expr = Expression::from(remainder_witness);
let quotient_expr = Expression::from(quotient_witness);

let res = &b_arith + &(two_pow_rhs_bits * &c_arith);
let my_constraint = &res - lhs;
let res = &remainder_expr + &(two_pow_rhs_bits * &quotient_expr);
let euclidean_division = &res - lhs;

self.push_opcode(AcirOpcode::Arithmetic(my_constraint));
self.push_opcode(AcirOpcode::Arithmetic(euclidean_division));

Ok(Expression::from(b_witness))
Ok(Expression::from(remainder_witness))
}

/// Calls a black box function and returns the output
Expand Down Expand Up @@ -264,13 +304,15 @@ impl GeneratedAcir {
&mut self,
lhs: &Expression,
rhs: &Expression,
bit_size: u32,
max_bit_size: u32,
predicate: &Expression,
) -> Result<(Witness, Witness), AcirGenError> {
let q_witness = self.next_witness_index();
let r_witness = self.next_witness_index();

let pa = lhs * predicate;
// lhs = rhs * q + r
//
// If predicate is zero, `q_witness` and `r_witness` will be 0
self.push_opcode(AcirOpcode::Directive(Directive::Quotient(QuotientDirective {
a: lhs.clone(),
b: rhs.clone(),
Expand All @@ -279,17 +321,25 @@ impl GeneratedAcir {
predicate: Some(predicate.clone()),
})));

//r<b
// Constrain r to be 0 <= r < 2^{max_bit_size}
let r_expr = Expression::from(r_witness);
self.range_constraint(r_witness, bit_size)?;
self.bound_constraint_with_offset(&r_expr, rhs, predicate, bit_size)?;
//range check q<=a
self.range_constraint(q_witness, bit_size)?;
// a-b*q-r = 0
let mut d = rhs * &Expression::from(q_witness);
d = &d + r_witness;
d = &d * predicate;
let div_euclidean = &pa - &d;
self.range_constraint(r_witness, max_bit_size)?;
// Constrain r < rhs
self.bound_constraint_with_offset(&r_expr, rhs, predicate, max_bit_size)?;

// Constrain q to be 0 <= q < 2^{max_bit_size}
self.range_constraint(q_witness, max_bit_size)?;

// a * predicate == (b * q + r) * predicate
// => predicate * ( a - b * q - r) == 0
// When the predicate is 0, the equation always passes.
// When the predicate is 1, the euclidean division needs to be
// true.
let mut rhs_constraint = rhs * &Expression::from(q_witness);
rhs_constraint = &rhs_constraint + r_witness;
rhs_constraint = &rhs_constraint * predicate;
let lhs_constraint = lhs * predicate;
let div_euclidean = &lhs_constraint - &rhs_constraint;

self.push_opcode(AcirOpcode::Arithmetic(div_euclidean));

Expand Down Expand Up @@ -576,7 +626,7 @@ impl GeneratedAcir {
// TODO: perhaps this should be a user error, instead of an assert
assert!(max_bits + 1 < FieldElement::max_num_bits());

// Compute : 2^max_bits + a - b
// Compute : 2^{max_bits} + a - b
let mut comparison_evaluation = a - b;
let two = FieldElement::from(2_i128);
let two_max_bits = two.pow(&FieldElement::from(max_bits as i128));
Expand All @@ -586,6 +636,25 @@ impl GeneratedAcir {
let r_witness = self.next_witness_index();

// Add constraint : 2^{max_bits} + a - b = q * 2^{max_bits} + r
//
// case: a == b
//
// let k = 0;
// - 2^{max_bits} == q * 2^{max_bits} + r
// - This is only the case when q == 1 and r == 0 (assuming r is bounded to be less than 2^{max_bits})
//
// case: a > b
//
// let k = a - b;
// - k + 2^{max_bits} == q * 2^{max_bits} + r
// - This is the case when q == 1 and r = k
//
// case: a < b
//
// let k = b - a
// - 2^{max_bits} - k == q * 2^{max_bits} + r
// - This is only the case when q == 0 and r == 2^{max_bits} - k
//
let mut expr = Expression::default();
expr.push_addition_term(two_max_bits, q_witness);
expr.push_addition_term(FieldElement::one(), r_witness);
Expand Down