diff --git a/CHANGELOG.md b/CHANGELOG.md index cdded309f1..12df582fb4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -380,6 +380,48 @@ ids.root = root ``` +* Add missing hint on vrf.json lib [#1045](https://github.com/lambdaclass/cairo-rs/pull/1045): + + `BuiltinHintProcessor` now supports the following hint: + + ```python + from starkware.python.math_utils import is_quad_residue, sqrt + + def split(a: int): + return (a & ((1 << 128) - 1), a >> 128) + + def pack(z) -> int: + return z.low + (z.high << 128) + + generator = pack(ids.generator) + x = pack(ids.x) + p = pack(ids.p) + + success_x = is_quad_residue(x, p) + root_x = sqrt(x, p) if success_x else None + success_gx = is_quad_residue(generator*x, p) + root_gx = sqrt(generator*x, p) if success_gx else None + + # Check that one is 0 and the other is 1 + if x != 0: + assert success_x + success_gx == 1 + + # `None` means that no root was found, but we need to transform these into a felt no matter what + if root_x == None: + root_x = 0 + if root_gx == None: + root_gx = 0 + ids.success_x = int(success_x) + ids.success_gx = int(success_gx) + split_root_x = split(root_x) + # print('split root x', split_root_x) + split_root_gx = split(root_gx) + ids.sqrt_x.low = split_root_x[0] + ids.sqrt_x.high = split_root_x[1] + ids.sqrt_gx.low = split_root_gx[0] + ids.sqrt_gx.high = split_root_gx[1] + ``` + * Add missing hint on uint256_improvements lib [#1024](https://github.com/lambdaclass/cairo-rs/pull/1024): `BuiltinHintProcessor` now supports the following hint: diff --git a/cairo_programs/field_arithmetic.cairo b/cairo_programs/field_arithmetic.cairo index dd2586cefd..a6174666bc 100644 --- a/cairo_programs/field_arithmetic.cairo +++ b/cairo_programs/field_arithmetic.cairo @@ -7,6 +7,7 @@ from starkware.cairo.common.math import assert_in_range, assert_le, assert_nn_le from starkware.cairo.common.math_cmp import is_le from starkware.cairo.common.pow import pow from starkware.cairo.common.registers import get_ap, get_fp_and_pc +from starkware.cairo.common.uint256 import Uint256 from cairo_programs.uint384 import u384, Uint384, Uint384_expand, SHIFT, HALF_SHIFT from cairo_programs.uint384_extension import u384_ext, Uint768 @@ -44,7 +45,6 @@ namespace field_arithmetic { ) -> (success: felt, res: Uint384) { alloc_locals; - // TODO: Create an equality function within field_arithmetic to avoid overflow bugs let (is_zero) = u384.eq(x, Uint384(0, 0, 0)); if (is_zero == 1) { return (1, Uint384(0, 0, 0)); @@ -108,7 +108,6 @@ namespace field_arithmetic { assert is_valid = 1; let (sqrt_x_squared: Uint384) = mul(sqrt_x, sqrt_x, p); // Note these checks may fail if the input x does not satisfy 0<= x < p - // TODO: Create a equality function within field_arithmetic to avoid overflow bugs let (check_x) = u384.eq(x, sqrt_x_squared); assert check_x = 1; return (1, sqrt_x); @@ -126,6 +125,103 @@ namespace field_arithmetic { } } + // Equivalent of get_square_root but for Uint256 + func u256_get_square_root{range_check_ptr, bitwise_ptr: BitwiseBuiltin*}( + x: Uint256, p: Uint256, generator: Uint256 + ) -> (success: felt, res: Uint256) { + alloc_locals; + + let (is_zero) = u384.eq(Uint384(x.low, x.high, 0), Uint384(0, 0, 0)); + if (is_zero == 1) { + return (1, Uint256(0, 0)); + } + + local success_x: felt; + local success_gx: felt; + local sqrt_x: Uint256; + local sqrt_gx: Uint256; + + // Compute square roots in a hint + %{ + from starkware.python.math_utils import is_quad_residue, sqrt + + def split(a: int): + return (a & ((1 << 128) - 1), a >> 128) + + def pack(z) -> int: + return z.low + (z.high << 128) + + generator = pack(ids.generator) + x = pack(ids.x) + p = pack(ids.p) + + success_x = is_quad_residue(x, p) + root_x = sqrt(x, p) if success_x else None + success_gx = is_quad_residue(generator*x, p) + root_gx = sqrt(generator*x, p) if success_gx else None + + # Check that one is 0 and the other is 1 + if x != 0: + assert success_x + success_gx == 1 + + # `None` means that no root was found, but we need to transform these into a felt no matter what + if root_x == None: + root_x = 0 + if root_gx == None: + root_gx = 0 + ids.success_x = int(success_x) + ids.success_gx = int(success_gx) + split_root_x = split(root_x) + # print('split root x', split_root_x) + split_root_gx = split(root_gx) + ids.sqrt_x.low = split_root_x[0] + ids.sqrt_x.high = split_root_x[1] + ids.sqrt_gx.low = split_root_gx[0] + ids.sqrt_gx.high = split_root_gx[1] + %} + + // Verify that the values computed in the hint are what they are supposed to be + let (gx_384: Uint384) = mul( + Uint384(generator.low, generator.high, 0), + Uint384(x.low, x.high, 0), + Uint384(p.low, p.high, 0), + ); + let gx: Uint256 = Uint256(gx_384.d0, gx_384.d1); + if (success_x == 1) { + // u384.check(sqrt_x); + let (is_valid) = u384.lt( + Uint384(sqrt_x.low, sqrt_x.high, 0), Uint384(p.low, p.high, 0) + ); + assert is_valid = 1; + let (sqrt_x_squared: Uint384) = mul( + Uint384(sqrt_x.low, sqrt_x.high, 0), + Uint384(sqrt_x.low, sqrt_x.high, 0), + Uint384(p.low, p.high, 0), + ); + // Note these checks may fail if the input x does not satisfy 0<= x < p + let (check_x) = u384.eq(Uint384(x.low, x.high, 0), sqrt_x_squared); + assert check_x = 1; + return (1, sqrt_x); + } else { + // In this case success_gx = 1 + // u384.check(sqrt_gx); + let (is_valid) = u384.lt( + Uint384(sqrt_gx.low, sqrt_gx.high, 0), Uint384(p.low, p.high, 0) + ); + assert is_valid = 1; + let (sqrt_gx_squared: Uint384) = mul( + Uint384(sqrt_gx.low, sqrt_gx.high, 0), + Uint384(sqrt_gx.low, sqrt_gx.high, 0), + Uint384(p.low, p.high, 0), + ); + let (check_gx) = u384.eq(Uint384(gx.low, gx.high, 0), sqrt_gx_squared); + assert check_gx = 1; + // No square roots were found + // Note that Uint384(0, 0, 0) is not a square root here, but something needs to be returned + return (0, Uint256(0, 0)); + } + } + // Computes a * b^{-1} modulo p // NOTE: The modular inverse of b modulo p is computed in a hint and verified outside the hint with a multiplicaiton func div{range_check_ptr}(a: Uint384, b: Uint384, p: Uint384) -> (res: Uint384) { @@ -228,7 +324,46 @@ func test_field_arithmetics_extension_operations{range_check_ptr, bitwise_ptr: B return (); } +func test_u256_get_square_root{range_check_ptr, bitwise_ptr: BitwiseBuiltin*}() { + alloc_locals; + // Test get_square + + // Small prime + let p_a = Uint256(7, 0); + let x_a = Uint256(2, 0); + let generator_a = Uint256(3, 0); + let (s_a, r_a) = field_arithmetic.u256_get_square_root(x_a, p_a, generator_a); + assert s_a = 1; + + assert r_a.low = 3; + assert r_a.high = 0; + + // Goldilocks Prime + let p_b = Uint256(18446744069414584321, 0); // Goldilocks Prime + let x_b = Uint256(25, 0); + let generator_b = Uint256(7, 0); + let (s_b, r_b) = field_arithmetic.u256_get_square_root(x_b, p_b, generator_b); + assert s_b = 1; + + assert r_b.low = 5; + assert r_b.high = 0; + + // Prime 2**101-99 + let p_c = Uint256(77371252455336267181195165, 32767); + let x_c = Uint256(96059601, 0); + let generator_c = Uint256(3, 0); + let (s_c, r_c) = field_arithmetic.u256_get_square_root(x_c, p_c, generator_c); + assert s_c = 1; + + assert r_c.low = 9801; + assert r_c.high = 0; + + return (); +} + func main{range_check_ptr: felt, bitwise_ptr: BitwiseBuiltin*}() { test_field_arithmetics_extension_operations(); + test_u256_get_square_root(); + return (); } diff --git a/src/hint_processor/builtin_hint_processor/builtin_hint_processor_definition.rs b/src/hint_processor/builtin_hint_processor/builtin_hint_processor_definition.rs index 3452036cd0..a23fb91925 100644 --- a/src/hint_processor/builtin_hint_processor/builtin_hint_processor_definition.rs +++ b/src/hint_processor/builtin_hint_processor/builtin_hint_processor_definition.rs @@ -3,7 +3,7 @@ use super::{ ec_recover_divmod_n_packed, ec_recover_product_div_m, ec_recover_product_mod, ec_recover_sub_a_b, }, - field_arithmetic::uint384_div, + field_arithmetic::{u256_get_square_root, u384_get_square_root, uint384_div}, secp::{ ec_utils::{ compute_slope_and_assing_secp_p, ec_double_assign_new_y, ec_negate_embedded_secp_p, @@ -34,7 +34,6 @@ use crate::{ dict_squash_update_ptr, dict_update, dict_write, }, ec_utils::{chained_ec_op_random_ec_point_hint, random_ec_point_hint, recover_y_hint}, - field_arithmetic::get_square_root, find_element_hint::{find_element, search_sorted_lower}, garaga::get_felt_bitlenght, hint_code, @@ -641,8 +640,11 @@ impl HintProcessor for BuiltinHintProcessor { | hint_code::UNSIGNED_DIV_REM_UINT768_BY_UINT384_STRIPPED => { unsigned_div_rem_uint768_by_uint384(vm, &hint_data.ids_data, &hint_data.ap_tracking) } - hint_code::GET_SQUARE_ROOT => { - get_square_root(vm, &hint_data.ids_data, &hint_data.ap_tracking) + hint_code::UINT384_GET_SQUARE_ROOT => { + u384_get_square_root(vm, &hint_data.ids_data, &hint_data.ap_tracking) + } + hint_code::UINT256_GET_SQUARE_ROOT => { + u256_get_square_root(vm, &hint_data.ids_data, &hint_data.ap_tracking) } hint_code::UINT384_SIGNED_NN => { uint384_signed_nn(vm, &hint_data.ids_data, &hint_data.ap_tracking) diff --git a/src/hint_processor/builtin_hint_processor/field_arithmetic.rs b/src/hint_processor/builtin_hint_processor/field_arithmetic.rs index c103724304..276812014d 100644 --- a/src/hint_processor/builtin_hint_processor/field_arithmetic.rs +++ b/src/hint_processor/builtin_hint_processor/field_arithmetic.rs @@ -5,6 +5,7 @@ use num_traits::Zero; use super::hint_utils::insert_value_from_var_name; use super::secp::bigint_utils::Uint384; +use super::uint256_utils::Uint256; use crate::math_utils::{is_quad_residue, mul_inv, sqrt_prime_power}; use crate::serde::deserialize_program::ApTracking; use crate::stdlib::{collections::HashMap, prelude::*}; @@ -15,52 +16,52 @@ use crate::{ }; /* Implements Hint: - %{ - from starkware.python.math_utils import is_quad_residue, sqrt - - def split(num: int, num_bits_shift: int = 128, length: int = 3): - a = [] - for _ in range(length): - a.append( num & ((1 << num_bits_shift) - 1) ) - num = num >> num_bits_shift - return tuple(a) - - def pack(z, num_bits_shift: int = 128) -> int: - limbs = (z.d0, z.d1, z.d2) - return sum(limb << (num_bits_shift * i) for i, limb in enumerate(limbs)) - - - generator = pack(ids.generator) - x = pack(ids.x) - p = pack(ids.p) - - success_x = is_quad_residue(x, p) - root_x = sqrt(x, p) if success_x else None - - success_gx = is_quad_residue(generator*x, p) - root_gx = sqrt(generator*x, p) if success_gx else None - - # Check that one is 0 and the other is 1 - if x != 0: - assert success_x + success_gx ==1 - - # `None` means that no root was found, but we need to transform these into a felt no matter what - if root_x == None: - root_x = 0 - if root_gx == None: - root_gx = 0 - ids.success_x = int(success_x) - split_root_x = split(root_x) - split_root_gx = split(root_gx) - ids.sqrt_x.d0 = split_root_x[0] - ids.sqrt_x.d1 = split_root_x[1] - ids.sqrt_x.d2 = split_root_x[2] - ids.sqrt_gx.d0 = split_root_gx[0] - ids.sqrt_gx.d1 = split_root_gx[1] - ids.sqrt_gx.d2 = split_root_gx[2] - %} +%{ + from starkware.python.math_utils import is_quad_residue, sqrt + + def split(num: int, num_bits_shift: int = 128, length: int = 3): + a = [] + for _ in range(length): + a.append( num & ((1 << num_bits_shift) - 1) ) + num = num >> num_bits_shift + return tuple(a) + + def pack(z, num_bits_shift: int = 128) -> int: + limbs = (z.d0, z.d1, z.d2) + return sum(limb << (num_bits_shift * i) for i, limb in enumerate(limbs)) + + + generator = pack(ids.generator) + x = pack(ids.x) + p = pack(ids.p) + + success_x = is_quad_residue(x, p) + root_x = sqrt(x, p) if success_x else None + + success_gx = is_quad_residue(generator*x, p) + root_gx = sqrt(generator*x, p) if success_gx else None + + # Check that one is 0 and the other is 1 + if x != 0: + assert success_x + success_gx ==1 + + # `None` means that no root was found, but we need to transform these into a felt no matter what + if root_x == None: + root_x = 0 + if root_gx == None: + root_gx = 0 + ids.success_x = int(success_x) + split_root_x = split(root_x) + split_root_gx = split(root_gx) + ids.sqrt_x.d0 = split_root_x[0] + ids.sqrt_x.d1 = split_root_x[1] + ids.sqrt_x.d2 = split_root_x[2] + ids.sqrt_gx.d0 = split_root_gx[0] + ids.sqrt_gx.d1 = split_root_gx[1] + ids.sqrt_gx.d2 = split_root_gx[2] +%} */ -pub fn get_square_root( +pub fn u384_get_square_root( vm: &mut VirtualMachine, ids_data: &HashMap, ap_tracking: &ApTracking, @@ -102,6 +103,96 @@ pub fn get_square_root( Ok(()) } +/* Implements Hint: +%{ + from starkware.python.math_utils import is_quad_residue, sqrt + + def split(a: int): + return (a & ((1 << 128) - 1), a >> 128) + + def pack(z) -> int: + return z.low + (z.high << 128) + + generator = pack(ids.generator) + x = pack(ids.x) + p = pack(ids.p) + + success_x = is_quad_residue(x, p) + root_x = sqrt(x, p) if success_x else None + success_gx = is_quad_residue(generator*x, p) + root_gx = sqrt(generator*x, p) if success_gx else None + + # Check that one is 0 and the other is 1 + if x != 0: + assert success_x + success_gx == 1 + + # `None` means that no root was found, but we need to transform these into a felt no matter what + if root_x == None: + root_x = 0 + if root_gx == None: + root_gx = 0 + ids.success_x = int(success_x) + ids.success_gx = int(success_gx) + split_root_x = split(root_x) + # print('split root x', split_root_x) + split_root_gx = split(root_gx) + ids.sqrt_x.low = split_root_x[0] + ids.sqrt_x.high = split_root_x[1] + ids.sqrt_gx.low = split_root_gx[0] + ids.sqrt_gx.high = split_root_gx[1] +%} +*/ +// TODO: extract UintNNN methods to a trait, and use generics +// to merge this with u384_get_square_root +pub fn u256_get_square_root( + vm: &mut VirtualMachine, + ids_data: &HashMap, + ap_tracking: &ApTracking, +) -> Result<(), HintError> { + let generator = Uint256::from_var_name("generator", vm, ids_data, ap_tracking)?.pack(); + let x = Uint256::from_var_name("x", vm, ids_data, ap_tracking)?.pack(); + let p = Uint256::from_var_name("p", vm, ids_data, ap_tracking)?.pack(); + let success_x = is_quad_residue(&x, &p)?; + + let root_x = if success_x { + sqrt_prime_power(&x, &p).unwrap_or_default() + } else { + BigUint::zero() + }; + + let gx = generator * &x; + let success_gx = is_quad_residue(&gx, &p)?; + + let root_gx = if success_gx { + sqrt_prime_power(&gx, &p).unwrap_or_default() + } else { + BigUint::zero() + }; + + if !&x.is_zero() && !(success_x ^ success_gx) { + return Err(HintError::AssertionFailed(String::from( + "assert success_x + success_gx ==1", + ))); + } + insert_value_from_var_name( + "success_x", + Felt252::from(success_x as u8), + vm, + ids_data, + ap_tracking, + )?; + insert_value_from_var_name( + "success_gx", + Felt252::from(success_gx as u8), + vm, + ids_data, + ap_tracking, + )?; + Uint256::split(&root_x).insert_from_var_name("sqrt_x", vm, ids_data, ap_tracking)?; + Uint256::split(&root_gx).insert_from_var_name("sqrt_gx", vm, ids_data, ap_tracking)?; + Ok(()) +} + /* Implements Hint: %{ from starkware.python.math_utils import div_mod @@ -182,7 +273,7 @@ mod tests { #[test] #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)] - fn run_get_square_ok_goldilocks_prime() { + fn run_u384_get_square_ok_goldilocks_prime() { let mut vm = vm_with_range_check!(); //Initialize fp vm.run_context.fp = 14; @@ -211,7 +302,10 @@ mod tests { ((1, 8), 0) ]; //Execute the hint - assert_matches!(run_hint!(vm, ids_data, hint_code::GET_SQUARE_ROOT), Ok(())); + assert_matches!( + run_hint!(vm, ids_data, hint_code::UINT384_GET_SQUARE_ROOT), + Ok(()) + ); //Check hint memory inserts check_memory![ vm.segments.memory, @@ -230,7 +324,7 @@ mod tests { #[test] #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)] - fn run_get_square_no_successes() { + fn run_u384_get_square_no_successes() { let mut vm = vm_with_range_check!(); //Initialize fp vm.run_context.fp = 14; @@ -259,14 +353,14 @@ mod tests { ((1, 8), 0) ]; //Execute the hint - assert_matches!(run_hint!(vm, ids_data, hint_code::GET_SQUARE_ROOT), + assert_matches!(run_hint!(vm, ids_data, hint_code::UINT384_GET_SQUARE_ROOT), Err(HintError::AssertionFailed(s)) if s == "assert success_x + success_gx ==1" ); } #[test] #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)] - fn run_get_square_ok_success_gx() { + fn run_u384_get_square_ok_success_gx() { let mut vm = vm_with_range_check!(); //Initialize fp vm.run_context.fp = 14; @@ -277,7 +371,7 @@ mod tests { ("generator", -8), ("sqrt_x", -5), ("sqrt_gx", -2), - ("success_x", 1) + ("success_x", 1), ]; //Insert ids into memory vm.segments = segments![ @@ -292,10 +386,13 @@ mod tests { //generator ((1, 6), 71), ((1, 7), 0), - ((1, 8), 0) + ((1, 8), 0), ]; //Execute the hint - assert_matches!(run_hint!(vm, ids_data, hint_code::GET_SQUARE_ROOT), Ok(())); + assert_matches!( + run_hint!(vm, ids_data, hint_code::UINT384_GET_SQUARE_ROOT), + Ok(()) + ); //Check hint memory inserts check_memory![ vm.segments.memory, @@ -308,7 +405,133 @@ mod tests { ((1, 13), 0), ((1, 14), 0), // success_x - ((1, 15), 0) + ((1, 15), 0), + ]; + } + + #[test] + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)] + fn run_u256_get_square_ok_goldilocks_prime() { + let mut vm = vm_with_range_check!(); + //Initialize fp + vm.run_context.fp = 14; + //Create hint_data + let ids_data = non_continuous_ids_data![ + ("p", -14), + ("x", -11), + ("generator", -8), + ("sqrt_x", -5), + ("sqrt_gx", -2), + ("success_x", 1), + ("success_gx", 2), + ]; + //Insert ids into memory + vm.segments = segments![ + //p + ((1, 0), 18446744069414584321), + ((1, 1), 0), + //x + ((1, 3), 25), + ((1, 4), 0), + //generator + ((1, 6), 7), + ((1, 7), 0), + ]; + //Execute the hint + assert!(run_hint!(vm, ids_data, hint_code::UINT256_GET_SQUARE_ROOT).is_ok()); + //Check hint memory inserts + check_memory![ + vm.segments.memory, + // sqrt_x + ((1, 9), 5), + ((1, 10), 0), + // sqrt_gx + ((1, 12), 0), + ((1, 13), 0), + // success_x + ((1, 15), 1), + // success_gx + ((1, 16), 0), + ]; + } + + #[test] + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)] + fn run_u256_get_square_no_successes() { + let mut vm = vm_with_range_check!(); + //Initialize fp + vm.run_context.fp = 14; + //Create hint_data + let ids_data = non_continuous_ids_data![ + ("p", -14), + ("x", -11), + ("generator", -8), + ("sqrt_x", -5), + ("sqrt_gx", -2), + ("success_x", 1), + ("success_gx", 2), + ]; + //Insert ids into memory + vm.segments = segments![ + //p + ((1, 0), 3), + ((1, 1), 0), + //x + ((1, 3), 17), + ((1, 4), 0), + //generator + ((1, 6), 1), + ((1, 7), 0), + ]; + //Execute the hint + assert_matches!(run_hint!(vm, ids_data, hint_code::UINT256_GET_SQUARE_ROOT), + Err(HintError::AssertionFailed(s)) if s == "assert success_x + success_gx ==1" + ); + } + + #[test] + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)] + fn run_u256_get_square_ok_success_gx() { + let mut vm = vm_with_range_check!(); + //Initialize fp + vm.run_context.fp = 14; + //Create hint_data + let ids_data = non_continuous_ids_data![ + ("p", -14), + ("x", -11), + ("generator", -8), + ("sqrt_x", -5), + ("sqrt_gx", -2), + ("success_x", 1), + ("success_gx", 2), + ]; + //Insert ids into memory + vm.segments = segments![ + //p + ((1, 0), 3), + ((1, 1), 0), + //x + ((1, 3), 17), + ((1, 4), 0), + //generator + ((1, 6), 71), + ((1, 7), 0), + ]; + //Execute the hint + assert!(run_hint!(vm, ids_data, hint_code::UINT256_GET_SQUARE_ROOT).is_ok()); + //Check hint memory inserts + check_memory![ + vm.segments.memory, + // sqrt_x + ((1, 9), 0), + ((1, 10), 0), + // sqrt_gx + ((1, 12), 1), + ((1, 13), 0), + // success_x + ((1, 15), 0), + // success_gx + ((1, 16), 1), ]; } diff --git a/src/hint_processor/builtin_hint_processor/hint_code.rs b/src/hint_processor/builtin_hint_processor/hint_code.rs index 93b52bc753..00ddb675b7 100644 --- a/src/hint_processor/builtin_hint_processor/hint_code.rs +++ b/src/hint_processor/builtin_hint_processor/hint_code.rs @@ -861,15 +861,14 @@ from starkware.python.math_utils import recover_y ids.p.x = ids.x # This raises an exception if `x` is not on the curve. ids.p.y = recover_y(ids.x, ALPHA, BETA, FIELD_PRIME)"; -pub(crate) const PACK_MODN_DIV_MODN: &str = - "from starkware.cairo.common.cairo_secp.secp_utils import pack +pub const PACK_MODN_DIV_MODN: &str = "from starkware.cairo.common.cairo_secp.secp_utils import pack from starkware.python.math_utils import div_mod, safe_div N = 0xfffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8cd0364141 x = pack(ids.x, PRIME) % N s = pack(ids.s, PRIME) % N value = res = div_mod(x, s, N)"; -pub(crate) const XS_SAFE_DIV: &str = "value = k = safe_div(res * s - x, N)"; +pub const XS_SAFE_DIV: &str = "value = k = safe_div(res * s - x, N)"; // The following hints support the lib https://github.com/NethermindEth/research-basic-Cairo-operations-big-integers/blob/main/lib pub const UINT384_UNSIGNED_DIV_REM: &str = "def split(num: int, num_bits_shift: int, length: int): @@ -1031,7 +1030,7 @@ ids.remainder.d2 = remainder_split[2]"#; pub const UINT384_SIGNED_NN: &str = "memory[ap] = 1 if 0 <= (ids.a.d2 % PRIME) < 2 ** 127 else 0"; -pub(crate) const GET_SQUARE_ROOT: &str = +pub const UINT384_GET_SQUARE_ROOT: &str = "from starkware.python.math_utils import is_quad_residue, sqrt def split(num: int, num_bits_shift: int = 128, length: int = 3): @@ -1075,6 +1074,42 @@ ids.sqrt_gx.d0 = split_root_gx[0] ids.sqrt_gx.d1 = split_root_gx[1] ids.sqrt_gx.d2 = split_root_gx[2]"; +pub const UINT256_GET_SQUARE_ROOT: &str = r#"from starkware.python.math_utils import is_quad_residue, sqrt + +def split(a: int): + return (a & ((1 << 128) - 1), a >> 128) + +def pack(z) -> int: + return z.low + (z.high << 128) + +generator = pack(ids.generator) +x = pack(ids.x) +p = pack(ids.p) + +success_x = is_quad_residue(x, p) +root_x = sqrt(x, p) if success_x else None +success_gx = is_quad_residue(generator*x, p) +root_gx = sqrt(generator*x, p) if success_gx else None + +# Check that one is 0 and the other is 1 +if x != 0: + assert success_x + success_gx == 1 + +# `None` means that no root was found, but we need to transform these into a felt no matter what +if root_x == None: + root_x = 0 +if root_gx == None: + root_gx = 0 +ids.success_x = int(success_x) +ids.success_gx = int(success_gx) +split_root_x = split(root_x) +# print('split root x', split_root_x) +split_root_gx = split(root_gx) +ids.sqrt_x.low = split_root_x[0] +ids.sqrt_x.high = split_root_x[1] +ids.sqrt_gx.low = split_root_gx[0] +ids.sqrt_gx.high = split_root_gx[1]"#; + pub const UINT384_DIV: &str = "from starkware.python.math_utils import div_mod def split(num: int, num_bits_shift: int, length: int): diff --git a/src/tests/cairo_run_test.rs b/src/tests/cairo_run_test.rs index 0a3f2e410d..65fae7c704 100644 --- a/src/tests/cairo_run_test.rs +++ b/src/tests/cairo_run_test.rs @@ -734,7 +734,7 @@ fn uint384_extension() { #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)] fn field_arithmetic() { let program_data = include_bytes!("../../cairo_programs/field_arithmetic.json"); - run_program_simple_with_memory_holes(program_data.as_slice(), 272); + run_program_simple_with_memory_holes(program_data.as_slice(), 464); } #[test] diff --git a/src/utils.rs b/src/utils.rs index 777add36f9..0501c463bf 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -342,7 +342,7 @@ pub mod test_utils { pub(crate) use ids_data; macro_rules! non_continuous_ids_data { - ( $( ($name: expr, $offset:expr) ),* ) => { + ( $( ($name: expr, $offset:expr) ),* $(,)? ) => { { let mut ids_data = crate::stdlib::collections::HashMap::::new(); $(