Skip to content

Commit

Permalink
feat(hints): add NewHint#44 (#1025)
Browse files Browse the repository at this point in the history
* Add NewHint#44

* Update changelog

* Use `bits` instead of `shl` for comparison

Co-authored-by: Mario Rugiero <mario.rugiero@lambdaclass.com>

---------

Co-authored-by: Mario Rugiero <mario.rugiero@lambdaclass.com>
  • Loading branch information
MegaRedHand and Oppen authored Apr 21, 2023
1 parent 375373d commit cdc28b0
Show file tree
Hide file tree
Showing 5 changed files with 81 additions and 20 deletions.
12 changes: 12 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,18 @@

#### Upcoming Changes

* Add missing hint on uint256_improvements lib [#1025](https://github.com/lambdaclass/cairo-rs/pull/1025):

`BuiltinHintProcessor` now supports the following hint:

```python
from starkware.python.math_utils import isqrt
n = (ids.n.high << 128) + ids.n.low
root = isqrt(n)
assert 0 <= root < 2 ** 128
ids.root = root
```

* Add missing hint on uint256_improvements lib [#1024](https://github.com/lambdaclass/cairo-rs/pull/1024):

`BuiltinHintProcessor` now supports the following hint:
Expand Down
19 changes: 19 additions & 0 deletions cairo_programs/uint256_improvements.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -352,10 +352,29 @@ func test_uint128_add{range_check_ptr}() {
return ();
}

func test_uint256_sqrt{range_check_ptr}() {
let n = Uint256(8, 0);

let (res) = uint256_sqrt(n);

assert res = Uint256(2, 0);

let n = Uint256(
340282366920938463463374607431768211455, 21267647932558653966460912964485513215
);

let (res) = uint256_sqrt(n);

assert res = Uint256(85070591730234615865843651857942052863, 0);

return ();
}

func main{range_check_ptr}() {
test_udiv_expanded();
test_uint256_sub();
test_uint128_add();
test_uint256_sqrt();

return ();
}
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,10 @@ impl HintProcessor for BuiltinHintProcessor {
hint_code::UINT256_SUB => uint256_sub(vm, &hint_data.ids_data, &hint_data.ap_tracking),
hint_code::SPLIT_64 => split_64(vm, &hint_data.ids_data, &hint_data.ap_tracking),
hint_code::UINT256_SQRT => {
uint256_sqrt(vm, &hint_data.ids_data, &hint_data.ap_tracking)
uint256_sqrt(vm, &hint_data.ids_data, &hint_data.ap_tracking, false)
}
hint_code::UINT256_SQRT_FELT => {
uint256_sqrt(vm, &hint_data.ids_data, &hint_data.ap_tracking, true)
}
hint_code::UINT256_SIGNED_NN => {
uint256_signed_nn(vm, &hint_data.ids_data, &hint_data.ap_tracking)
Expand Down
6 changes: 6 additions & 0 deletions src/hint_processor/builtin_hint_processor/hint_code.rs
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,12 @@ assert 0 <= root < 2 ** 128
ids.root.low = root
ids.root.high = 0"#;

pub const UINT256_SQRT_FELT: &str = r#"from starkware.python.math_utils import isqrt
n = (ids.n.high << 128) + ids.n.low
root = isqrt(n)
assert 0 <= root < 2 ** 128
ids.root = root"#;

pub const UINT256_SIGNED_NN: &str = "memory[ap] = 1 if 0 <= (ids.a.high % PRIME) < 2 ** 127 else 0";

pub const UINT256_UNSIGNED_DIV_REM: &str = r#"a = (ids.a.high << 128) + ids.a.low
Expand Down
59 changes: 40 additions & 19 deletions src/hint_processor/builtin_hint_processor/uint256_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -245,33 +245,37 @@ pub fn uint256_sqrt(
vm: &mut VirtualMachine,
ids_data: &HashMap<String, HintReference>,
ap_tracking: &ApTracking,
only_low: bool,
) -> Result<(), HintError> {
let n_addr = get_relocatable_from_var_name("n", vm, ids_data, ap_tracking)?;
let root_addr = get_relocatable_from_var_name("root", vm, ids_data, ap_tracking)?;
let n_low = vm.get_integer(n_addr)?;
let n_high = vm.get_integer((n_addr + 1_usize)?)?;
let n_low = n_low.as_ref();
let n_high = n_high.as_ref();
let n = Uint256::from_var_name("n", vm, ids_data, ap_tracking)?;
let n = pack(n);

//Main logic
//from starkware.python.math_utils import isqrt
//n = (ids.n.high << 128) + ids.n.low
//root = isqrt(n)
//assert 0 <= root < 2 ** 128
//ids.root.low = root
//ids.root.high = 0
// Main logic
// from starkware.python.math_utils import isqrt
// n = (ids.n.high << 128) + ids.n.low
// root = isqrt(n)
// assert 0 <= root < 2 ** 128
// ids.root.low = root
// ids.root.high = 0

let root = isqrt(&(&n_high.to_biguint().shl(128_u32) + n_low.to_biguint()))?;
let root = isqrt(&n)?;

if root >= num_bigint::BigUint::one().shl(128_u32) {
if root.bits() > 128 {
return Err(HintError::AssertionFailed(format!(
"assert 0 <= {} < 2 ** 128",
&root
)));
}
vm.insert_value(root_addr, Felt252::new(root))?;
vm.insert_value((root_addr + 1_i32)?, Felt252::zero())
.map_err(HintError::Memory)

let root = Felt252::new(root);

if only_low {
insert_value_from_var_name("root", root, vm, ids_data, ap_tracking)?;
} else {
let root_u256 = Uint256::from_values(root, Felt252::zero());
root_u256.insert_from_var_name("root", vm, ids_data, ap_tracking)?;
}
Ok(())
}

/*
Expand Down Expand Up @@ -706,7 +710,7 @@ mod tests {
#[test]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
fn run_uint256_sqrt_ok() {
let hint_code = "from starkware.python.math_utils import isqrt\nn = (ids.n.high << 128) + ids.n.low\nroot = isqrt(n)\nassert 0 <= root < 2 ** 128\nids.root.low = root\nids.root.high = 0";
let hint_code = hint_code::UINT256_SQRT;
let mut vm = vm_with_range_check!();
//Initialize fp
vm.run_context.fp = 5;
Expand All @@ -724,6 +728,23 @@ mod tests {
];
}

#[test]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
fn run_uint256_sqrt_felt_ok() {
let hint_code = "from starkware.python.math_utils import isqrt\nn = (ids.n.high << 128) + ids.n.low\nroot = isqrt(n)\nassert 0 <= root < 2 ** 128\nids.root = root";
let mut vm = vm_with_range_check!();
//Initialize fp
vm.run_context.fp = 0;
//Create hint_data
let ids_data = non_continuous_ids_data![("n", 0), ("root", 2)];
vm.segments = segments![((1, 0), 879232), ((1, 1), 135906)];
//Execute the hint
assert_matches!(run_hint!(vm, ids_data, hint_code), Ok(()));
//Check hint memory inserts
//ids.root
check_memory![vm.segments.memory, ((1, 2), 6800471701195223914689)];
}

#[test]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
fn run_uint256_sqrt_assert_error() {
Expand Down

1 comment on commit cdc28b0

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Performance Alert ⚠️

Possible performance regression was detected for benchmark.
Benchmark result of this commit is worse than the previous benchmark result exceeding threshold 1.30.

Benchmark suite Current: cdc28b0 Previous: 375373d Ratio
build runner 3421900 ns/iter (± 230757) 2551290 ns/iter (± 831) 1.34

This comment was automatically generated by workflow using github-action-benchmark.

CC: @unbalancedparentheses

Please sign in to comment.