From 456f4147908ac00af72738639d4a4b7eead98064 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=C3=A1s?= <47506558+MegaRedHand@users.noreply.github.com> Date: Thu, 20 Apr 2023 19:38:20 -0300 Subject: [PATCH] feat(hints): add NewHint#45 (#1024) * Add NewHint#45 * Update changelog --- CHANGELOG.md | 9 ++ cairo_programs/uint256_improvements.cairo | 15 +++ .../builtin_hint_processor_definition.rs | 6 +- .../builtin_hint_processor/hint_code.rs | 3 + .../builtin_hint_processor/uint256_utils.rs | 92 +++++++++++++------ 5 files changed, 94 insertions(+), 31 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 79e66243f3..8f9d0d985f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,15 @@ #### Upcoming Changes +* Add missing hint on uint256_improvements lib [#1024](https://github.com/lambdaclass/cairo-rs/pull/1024): + + `BuiltinHintProcessor` now supports the following hint: + + ```python + res = ids.a + ids.b + ids.carry = 1 if res >= ids.SHIFT else 0 + ``` + * BREAKING CHANGE: move `Program::identifiers` to `SharedProgramData::identifiers` [#1023](https://github.com/lambdaclass/cairo-rs/pull/1023) * Optimizes `CairoRunner::new`, needed for sequencers and other workflows reusing the same `Program` instance across `CairoRunner`s * Breaking change: make all fields in `Program` and `SharedProgramData` `pub(crate)`, since we break by moving the field let's make it the last break for this struct diff --git a/cairo_programs/uint256_improvements.cairo b/cairo_programs/uint256_improvements.cairo index 79fc6d7261..a5f3223de7 100644 --- a/cairo_programs/uint256_improvements.cairo +++ b/cairo_programs/uint256_improvements.cairo @@ -338,9 +338,24 @@ func test_uint256_sub{range_check_ptr}() { return (); } +func test_uint128_add{range_check_ptr}() { + let (res) = uint128_add(5, 66); + + assert res = Uint256(71, 0); + + let (res) = uint128_add( + 340282366920938463463374607431768211455, 340282366920938463463374607431768211455 + ); + + assert res = Uint256(340282366920938463463374607431768211454, 1); + + return (); +} + func main{range_check_ptr}() { test_udiv_expanded(); test_uint256_sub(); + test_uint128_add(); return (); } diff --git a/src/hint_processor/builtin_hint_processor/builtin_hint_processor_definition.rs b/src/hint_processor/builtin_hint_processor/builtin_hint_processor_definition.rs index df555e2b98..4fba459770 100644 --- a/src/hint_processor/builtin_hint_processor/builtin_hint_processor_definition.rs +++ b/src/hint_processor/builtin_hint_processor/builtin_hint_processor_definition.rs @@ -57,8 +57,9 @@ use crate::{ squash_dict_inner_used_accesses_assert, }, uint256_utils::{ - split_64, uint256_add, uint256_expanded_unsigned_div_rem, uint256_mul_div_mod, - uint256_signed_nn, uint256_sqrt, uint256_sub, uint256_unsigned_div_rem, + split_64, uint128_add, uint256_add, uint256_expanded_unsigned_div_rem, + uint256_mul_div_mod, uint256_signed_nn, uint256_sqrt, uint256_sub, + uint256_unsigned_div_rem, }, uint384::{ add_no_uint384_check, uint384_signed_nn, uint384_split_128, uint384_sqrt, @@ -336,6 +337,7 @@ impl HintProcessor for BuiltinHintProcessor { dict_squash_update_ptr(vm, exec_scopes, &hint_data.ids_data, &hint_data.ap_tracking) } hint_code::UINT256_ADD => uint256_add(vm, &hint_data.ids_data, &hint_data.ap_tracking), + hint_code::UINT128_ADD => uint128_add(vm, &hint_data.ids_data, &hint_data.ap_tracking), hint_code::UINT256_SUB => uint256_sub(vm, &hint_data.ids_data, &hint_data.ap_tracking), hint_code::SPLIT_64 => split_64(vm, &hint_data.ids_data, &hint_data.ap_tracking), hint_code::UINT256_SQRT => { diff --git a/src/hint_processor/builtin_hint_processor/hint_code.rs b/src/hint_processor/builtin_hint_processor/hint_code.rs index fb5af02b06..95bfea0f7c 100644 --- a/src/hint_processor/builtin_hint_processor/hint_code.rs +++ b/src/hint_processor/builtin_hint_processor/hint_code.rs @@ -283,6 +283,9 @@ ids.carry_low = 1 if sum_low >= ids.SHIFT else 0 sum_high = ids.a.high + ids.b.high + ids.carry_low ids.carry_high = 1 if sum_high >= ids.SHIFT else 0"#; +pub const UINT128_ADD: &str = r#"res = ids.a + ids.b +ids.carry = 1 if res >= ids.SHIFT else 0"#; + pub const UINT256_SUB: &str = r#"def split(num: int, num_bits_shift: int = 128, length: int = 2): a = [] for _ in range(length): diff --git a/src/hint_processor/builtin_hint_processor/uint256_utils.rs b/src/hint_processor/builtin_hint_processor/uint256_utils.rs index da65a60f7c..a8d9b2d518 100644 --- a/src/hint_processor/builtin_hint_processor/uint256_utils.rs +++ b/src/hint_processor/builtin_hint_processor/uint256_utils.rs @@ -108,38 +108,53 @@ pub fn uint256_add( ap_tracking: &ApTracking, ) -> Result<(), HintError> { let shift = Felt252::new(1_u32) << 128_u32; - let a_relocatable = get_relocatable_from_var_name("a", vm, ids_data, ap_tracking)?; - let b_relocatable = get_relocatable_from_var_name("b", vm, ids_data, ap_tracking)?; - let a_low = vm.get_integer(a_relocatable)?; - let a_high = vm.get_integer((a_relocatable + 1_usize)?)?; - let b_low = vm.get_integer(b_relocatable)?; - let b_high = vm.get_integer((b_relocatable + 1_usize)?)?; - let a_low = a_low.as_ref(); - let a_high = a_high.as_ref(); - let b_low = b_low.as_ref(); - let b_high = b_high.as_ref(); - //Main logic - //sum_low = ids.a.low + ids.b.low - //ids.carry_low = 1 if sum_low >= ids.SHIFT else 0 - //sum_high = ids.a.high + ids.b.high + ids.carry_low - //ids.carry_high = 1 if sum_high >= ids.SHIFT else 0 + let a = Uint256::from_var_name("a", vm, ids_data, ap_tracking)?; + let b = Uint256::from_var_name("b", vm, ids_data, ap_tracking)?; + let a_low = a.low.as_ref(); + let a_high = a.high.as_ref(); + let b_low = b.low.as_ref(); + let b_high = b.high.as_ref(); - let carry_low = if a_low + b_low >= shift { - Felt252::one() - } else { - Felt252::zero() - }; + // Main logic + // sum_low = ids.a.low + ids.b.low + // ids.carry_low = 1 if sum_low >= ids.SHIFT else 0 + // sum_high = ids.a.high + ids.b.high + ids.carry_low + // ids.carry_high = 1 if sum_high >= ids.SHIFT else 0 + + let carry_low = Felt252::from((a_low + b_low >= shift) as u8); + let carry_high = Felt252::from((a_high + b_high + &carry_low >= shift) as u8); - let carry_high = if a_high + b_high + &carry_low >= shift { - Felt252::one() - } else { - Felt252::zero() - }; insert_value_from_var_name("carry_high", carry_high, vm, ids_data, ap_tracking)?; insert_value_from_var_name("carry_low", carry_low, vm, ids_data, ap_tracking) } +/* +Implements hint: +%{ + res = ids.a + ids.b + ids.carry = 1 if res >= ids.SHIFT else 0 +%} +*/ +pub fn uint128_add( + vm: &mut VirtualMachine, + ids_data: &HashMap, + ap_tracking: &ApTracking, +) -> Result<(), HintError> { + let shift = Felt252::new(1_u32) << 128_u32; + let a = get_integer_from_var_name("a", vm, ids_data, ap_tracking)?; + let b = get_integer_from_var_name("b", vm, ids_data, ap_tracking)?; + let a = a.as_ref(); + let b = b.as_ref(); + + // Main logic + // res = ids.a + ids.b + // ids.carry = 1 if res >= ids.SHIFT else 0 + let carry = Felt252::from((a + b >= shift) as u8); + + insert_value_from_var_name("carry", carry, vm, ids_data, ap_tracking) +} + /* Implements hint: %{ @@ -477,18 +492,18 @@ mod tests { #[test] #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)] fn run_uint256_add_ok() { - let hint_code = "sum_low = ids.a.low + ids.b.low\nids.carry_low = 1 if sum_low >= ids.SHIFT else 0\nsum_high = ids.a.high + ids.b.high + ids.carry_low\nids.carry_high = 1 if sum_high >= ids.SHIFT else 0"; + let hint_code = hint_code::UINT256_ADD; let mut vm = vm_with_range_check!(); //Initialize fp vm.run_context.fp = 10; //Create hint_data let ids_data = - non_continuous_ids_data![("a", -6), ("b", -4), ("carry_high", 3), ("carry_low", 2)]; + non_continuous_ids_data![("a", -6), ("b", -4), ("carry_low", 2), ("carry_high", 3)]; vm.segments = segments![ ((1, 4), 2), ((1, 5), 3), ((1, 6), 4), - ((1, 7), ("340282366920938463463374607431768211456", 10)) + ((1, 7), ("340282366920938463463374607431768211455", 10)) ]; //Execute the hint assert_matches!(run_hint!(vm, ids_data, hint_code), Ok(())); @@ -496,10 +511,29 @@ mod tests { check_memory![vm.segments.memory, ((1, 12), 0), ((1, 13), 1)]; } + #[test] + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)] + fn run_uint128_add_ok() { + let hint_code = hint_code::UINT128_ADD; + let mut vm = vm_with_range_check!(); + // Initialize fp + vm.run_context.fp = 0; + // Create hint_data + let ids_data = non_continuous_ids_data![("a", 0), ("b", 1), ("carry", 2)]; + vm.segments = segments![ + ((1, 0), 180141183460469231731687303715884105727_u128), + ((1, 1), 180141183460469231731687303715884105727_u128), + ]; + // Execute the hint + assert_matches!(run_hint!(vm, ids_data, hint_code), Ok(())); + // Check hint memory inserts + check_memory![vm.segments.memory, ((1, 2), 1)]; + } + #[test] #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)] fn run_uint256_add_fail_inserts() { - let hint_code = "sum_low = ids.a.low + ids.b.low\nids.carry_low = 1 if sum_low >= ids.SHIFT else 0\nsum_high = ids.a.high + ids.b.high + ids.carry_low\nids.carry_high = 1 if sum_high >= ids.SHIFT else 0"; + let hint_code = hint_code::UINT256_ADD; let mut vm = vm_with_range_check!(); //Initialize fp vm.run_context.fp = 10;