Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(recursion): num2bits fixes #732

Merged
merged 5 commits into from
May 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 42 additions & 7 deletions recursion/compiler/src/ir/bits.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
use p3_field::AbstractField;
use p3_field::{AbstractField, Field};
use sp1_recursion_core::runtime::NUM_BITS;

use super::{Array, Builder, Config, DslIr, Felt, Usize, Var};

impl<C: Config> Builder<C> {
/// Converts a variable to bits.
/// Converts a variable to LE bits.
pub fn num2bits_v(&mut self, num: Var<C::N>) -> Array<C, Var<C::N>> {
// This function is only used when the native field is Babybear.
assert!(C::N::bits() == NUM_BITS);

let output = self.dyn_array::<Var<_>>(NUM_BITS);
self.push(DslIr::HintBitsV(output.clone(), num));

Expand All @@ -16,10 +19,10 @@ impl<C: Config> Builder<C> {
self.assign(sum, sum + bit * C::N::from_canonical_u32(1 << i));
}

// TODO: There is an edge case where the witnessed bits may slightly overflow and cause
// the output to be incorrect. This is a known issue and will be fixed in the future.
self.assert_var_eq(sum, num);

self.less_than_bb_modulus(output.clone());

output
}

Expand Down Expand Up @@ -49,22 +52,25 @@ impl<C: Config> Builder<C> {
});
}

// TODO: There is an edge case where the witnessed bits may slightly overflow and cause
// the output to be incorrect. This is a known issue and will be fixed in the future.
self.assert_felt_eq(sum, num);

self.less_than_bb_modulus(output.clone());

output
}

/// Converts a felt to bits inside a circuit.
pub fn num2bits_f_circuit(&mut self, num: Felt<C::F>) -> Vec<Var<C::N>> {
let mut output = Vec::new();
for _ in 0..32 {
for _ in 0..NUM_BITS {
output.push(self.uninit());
}

self.push(DslIr::CircuitNum2BitsF(num, output.clone()));

let output_array = self.vec(output.clone());
self.less_than_bb_modulus(output_array);

output
}

Expand Down Expand Up @@ -149,4 +155,33 @@ impl<C: Config> Builder<C> {
}
result_bits
}

/// Checks that the LE bit decomposition of a number is less than the babybear modulus.
///
/// SAFETY: This function assumes that the num_bits values are already verified to be boolean.
///
/// The babybear modulus in LE bits is: 100_000_000_000_000_000_000_000_000_111_1.
/// To check that the num_bits array is less than that value, we first check if the most significant
/// bits are all 1. If it is, then we assert that the other bits are all 0.
fn less_than_bb_modulus(&mut self, num_bits: Array<C, Var<C::N>>) {
let one: Var<_> = self.eval(C::N::one());
let zero: Var<_> = self.eval(C::N::zero());

let mut most_sig_4_bits = one;
for i in (NUM_BITS - 4)..NUM_BITS {
let bit = self.get(&num_bits, i);
most_sig_4_bits = self.eval(bit * most_sig_4_bits);
}

let mut sum_least_sig_bits = zero;
for i in 0..(NUM_BITS - 4) {
let bit = self.get(&num_bits, i);
sum_least_sig_bits = self.eval(bit + sum_least_sig_bits);
}

// If the most significant 4 bits are all 1, then check the sum of the least significant bits, else return zero.
let check: Var<_> =
self.eval(most_sig_4_bits * sum_least_sig_bits + (one - most_sig_4_bits) * zero);
self.assert_var_eq(check, zero);
}
}
2 changes: 1 addition & 1 deletion recursion/core/src/runtime/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -670,7 +670,7 @@ where
// Get the src value.
let num = b_val[0].as_canonical_u32();

// Decompose the num into bits.
// Decompose the num into LE bits.
let bits = (0..NUM_BITS).map(|i| (num >> i) & 1).collect::<Vec<_>>();
// Write the bits to the array at dst.
for (i, bit) in bits.iter().enumerate() {
Expand Down
Loading