Skip to content

Commit

Permalink
Merge branch 'mod-builtin-fix' into dynamic-layout-2
Browse files Browse the repository at this point in the history
  • Loading branch information
JulianGCalderon committed Sep 19, 2024
2 parents 6e9e09a + dd6d6ff commit 63e606f
Show file tree
Hide file tree
Showing 4 changed files with 121 additions and 49 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
* feat: [#1838](https://github.com/lambdaclass/cairo-vm/pull/1824):
* Add support for missing dynamic layout features

* fix: [#1841](https://github.com/lambdaclass/cairo-vm/pull/1841):
* Fix modulo builtin to comply with prover constraints

* feat(BREAKING): [#1824](https://github.com/lambdaclass/cairo-vm/pull/1824):
* Add support for dynamic layout
* CLI change(BREAKING): The flag `cairo_layout_params_file` must be specified when using dynamic layout.
Expand Down
25 changes: 25 additions & 0 deletions vm/src/air_private_input.rs
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,8 @@ pub struct ModInputInstance {
pub values_ptr: usize,
pub offsets_ptr: usize,
pub n: usize,
#[serde(deserialize_with = "mod_input_instance_batch_serde::deserialize")]
#[serde(serialize_with = "mod_input_instance_batch_serde::serialize")]
pub batch: BTreeMap<usize, ModInputMemoryVars>,
}

Expand Down Expand Up @@ -205,6 +207,29 @@ impl AirPrivateInputSerializable {
}
}

mod mod_input_instance_batch_serde {
use super::*;

use serde::{Deserializer, Serializer};

pub(crate) fn serialize<S: Serializer>(
value: &BTreeMap<usize, ModInputMemoryVars>,
s: S,
) -> Result<S::Ok, S::Error> {
let value = value.iter().map(|v| v.1).collect::<Vec<_>>();

value.serialize(s)
}

pub(crate) fn deserialize<'de, D: Deserializer<'de>>(
d: D,
) -> Result<BTreeMap<usize, ModInputMemoryVars>, D::Error> {
let value = Vec::<ModInputMemoryVars>::deserialize(d)?;

Ok(value.into_iter().enumerate().collect())
}
}

#[cfg(test)]
mod tests {
use crate::types::layout_name::LayoutName;
Expand Down
3 changes: 3 additions & 0 deletions vm/src/tests/compare_outputs_dynamic_layouts.sh
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,9 @@ CASES=(
"cairo_programs/proof_programs/fibonacci.json;all_cairo"
"cairo_programs/proof_programs/factorial.json;double_all_cairo"
"cairo_programs/proof_programs/fibonacci.json;double_all_cairo"
"cairo_programs/mod_builtin_feature/proof/mod_builtin.json;all_cairo"
"cairo_programs/mod_builtin_feature/proof/mod_builtin_failure.json;all_cairo"
"cairo_programs/mod_builtin_feature/proof/apply_poly.json;all_cairo"
)

passed_tests=0
Expand Down
139 changes: 90 additions & 49 deletions vm/src/vm/runners/builtin_runner/modulo.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ pub struct ModBuiltinRunner {
// Precomputed powers used for reading and writing values that are represented as n_words words of word_bit_len bits each.
shift: BigUint,
shift_powers: [BigUint; N_WORDS],
k_bound: BigUint,
}

#[derive(Debug, Clone)]
Expand All @@ -60,7 +61,7 @@ pub enum Operation {
Mul,
Add,
Sub,
DivMod(BigUint),
DivMod,
}

impl Display for Operation {
Expand All @@ -69,7 +70,7 @@ impl Display for Operation {
Operation::Mul => "*".fmt(f),
Operation::Add => "+".fmt(f),
Operation::Sub => "-".fmt(f),
Operation::DivMod(_) => "/".fmt(f),
Operation::DivMod => "/".fmt(f),
}
}
}
Expand All @@ -85,17 +86,28 @@ struct Inputs {

impl ModBuiltinRunner {
pub(crate) fn new_add_mod(instance_def: &ModInstanceDef, included: bool) -> Self {
Self::new(instance_def.clone(), included, ModBuiltinType::Add)
Self::new(
instance_def.clone(),
included,
ModBuiltinType::Add,
Some(2u32.into()),
)
}

pub(crate) fn new_mul_mod(instance_def: &ModInstanceDef, included: bool) -> Self {
Self::new(instance_def.clone(), included, ModBuiltinType::Mul)
Self::new(instance_def.clone(), included, ModBuiltinType::Mul, None)
}

fn new(instance_def: ModInstanceDef, included: bool, builtin_type: ModBuiltinType) -> Self {
fn new(
instance_def: ModInstanceDef,
included: bool,
builtin_type: ModBuiltinType,
k_bound: Option<BigUint>,
) -> Self {
let shift = BigUint::one().shl(instance_def.word_bit_len);
let shift_powers = core::array::from_fn(|i| shift.pow(i as u32));
let zero_segment_size = core::cmp::max(N_WORDS, instance_def.batch_size * 3);
let int_lim = BigUint::from(2_u32).pow(N_WORDS as u32 * instance_def.word_bit_len);
Self {
builtin_type,
base: 0,
Expand All @@ -106,6 +118,7 @@ impl ModBuiltinRunner {
zero_segment_size,
shift,
shift_powers,
k_bound: k_bound.unwrap_or(int_lim),
}
}

Expand Down Expand Up @@ -462,19 +475,19 @@ impl ModBuiltinRunner {
match (a, b, c) {
// Deduce c from a and b and write it to memory.
(Some(a), Some(b), None) => {
let value = apply_op(a, b, op)?.mod_floor(&inputs.p);
let value = apply_op(op, a, b, &inputs.p, &self.k_bound)?;
self.write_n_words_value(memory, addresses[2], value)?;
Ok(true)
}
// Deduce b from a and c and write it to memory.
(Some(a), None, Some(c)) => {
let value = apply_op(c, a, inv_op)?.mod_floor(&inputs.p);
let value = apply_op(inv_op, c, a, &inputs.p, &self.k_bound)?;
self.write_n_words_value(memory, addresses[1], value)?;
Ok(true)
}
// Deduce a from b and c and write it to memory.
(None, Some(b), Some(c)) => {
let value = apply_op(c, b, inv_op)?.mod_floor(&inputs.p);
let value = apply_op(inv_op, c, b, &inputs.p, &self.k_bound)?;
self.write_n_words_value(memory, addresses[0], value)?;
Ok(true)
}
Expand Down Expand Up @@ -543,44 +556,45 @@ impl ModBuiltinRunner {
Default::default()
};

// Get one of the builtin runners - the rest of this function doesn't depend on batch_size.
let mod_runner = if let Some((_, add_mod, _)) = add_mod {
add_mod
} else {
mul_mod.unwrap().1
};
// Fill the values table.
let mut add_mod_index = 0;
let mut mul_mod_index = 0;
// Create operation here to avoid cloning p in the loop
let div_operation = Operation::DivMod(mul_mod_inputs.p.clone());

while add_mod_index < add_mod_n || mul_mod_index < mul_mod_n {
if add_mod_index < add_mod_n
&& mod_runner.fill_value(
memory,
&add_mod_inputs,
add_mod_index,
&Operation::Add,
&Operation::Sub,
)?
{
add_mod_index += 1;
} else if mul_mod_index < mul_mod_n
&& mod_runner.fill_value(
memory,
&mul_mod_inputs,
mul_mod_index,
&Operation::Mul,
&div_operation,
)?
{
mul_mod_index += 1;
} else {
return Err(RunnerError::FillMemoryCoudNotFillTable(
add_mod_index,
mul_mod_index,
));
if add_mod_index < add_mod_n {
if let Some((_, add_mod_runner, _)) = add_mod {
if add_mod_runner.fill_value(
memory,
&add_mod_inputs,
add_mod_index,
&Operation::Add,
&Operation::Sub,
)? {
add_mod_index += 1;
continue;
}
}
}

if mul_mod_index < mul_mod_n {
if let Some((_, mul_mod_runner, _)) = mul_mod {
if mul_mod_runner.fill_value(
memory,
&mul_mod_inputs,
mul_mod_index,
&Operation::Mul,
&Operation::DivMod,
)? {
mul_mod_index += 1;
}
continue;
}
}

return Err(RunnerError::FillMemoryCoudNotFillTable(
add_mod_index,
mul_mod_index,
));
}
Ok(())
}
Expand Down Expand Up @@ -633,7 +647,7 @@ impl ModBuiltinRunner {
ModBuiltinType::Add => Operation::Add,
ModBuiltinType::Mul => Operation::Mul,
};
let a_op_b = apply_op(&a, &b, &op)?.mod_floor(&inputs.p);
let a_op_b = apply_op(&op, &a, &b, &inputs.p, &self.k_bound)?;
if a_op_b != c.mod_floor(&inputs.p) {
// Build error string
let p = inputs.p;
Expand Down Expand Up @@ -673,13 +687,40 @@ impl ModBuiltinRunner {
}
}

fn apply_op(lhs: &BigUint, rhs: &BigUint, op: &Operation) -> Result<BigUint, MathError> {
Ok(match op {
Operation::Mul => lhs * rhs,
Operation::Add => lhs + rhs,
Operation::Sub => lhs - rhs,
Operation::DivMod(ref p) => div_mod_unsigned(lhs, rhs, p)?,
})
fn apply_op(
op: &Operation,
lhs: &BigUint,
rhs: &BigUint,
p: &BigUint,
k_bound: &BigUint,
) -> Result<BigUint, MathError> {
let value = match op {
Operation::Mul => {
let value = lhs * rhs;
if value < k_bound * p {
value.mod_floor(p)
} else {
value - (k_bound - 1u32) * p
}
}
Operation::Add => {
let value = lhs + rhs;
if value < k_bound * p {
value.mod_floor(p)
} else {
value - (k_bound - 1u32) * p
}
}
Operation::Sub => {
if rhs <= lhs {
lhs - rhs
} else {
lhs + p - rhs
}
}
Operation::DivMod => div_mod_unsigned(lhs, rhs, p)?,
};
Ok(value)
}

#[cfg(test)]
Expand Down

0 comments on commit 63e606f

Please sign in to comment.