diff --git a/CHANGELOG.md b/CHANGELOG.md index def8da9ca9..e059ff9e60 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,108 @@ #### Upcoming Changes +* Implement hints on uint384 lib (Part 1) [#960](https://github.com/lambdaclass/cairo-rs/pull/960) + + `BuiltinHintProcessor` now supports the following hints: + + ```python + def split(num: int, num_bits_shift: int, length: int): + a = [] + for _ in range(length): + a.append( num & ((1 << num_bits_shift) - 1) ) + num = num >> num_bits_shift + return tuple(a) + + def pack(z, num_bits_shift: int) -> int: + limbs = (z.d0, z.d1, z.d2) + return sum(limb << (num_bits_shift * i) for i, limb in enumerate(limbs)) + + a = pack(ids.a, num_bits_shift = 128) + div = pack(ids.div, num_bits_shift = 128) + quotient, remainder = divmod(a, div) + + quotient_split = split(quotient, num_bits_shift=128, length=3) + assert len(quotient_split) == 3 + + ids.quotient.d0 = quotient_split[0] + ids.quotient.d1 = quotient_split[1] + ids.quotient.d2 = quotient_split[2] + + remainder_split = split(remainder, num_bits_shift=128, length=3) + ids.remainder.d0 = remainder_split[0] + ids.remainder.d1 = remainder_split[1] + ids.remainder.d2 = remainder_split[2] + ``` + + ```python + ids.low = ids.a & ((1<<128) - 1) + ids.high = ids.a >> 128 + ``` + + ```python + sum_d0 = ids.a.d0 + ids.b.d0 + ids.carry_d0 = 1 if sum_d0 >= ids.SHIFT else 0 + sum_d1 = ids.a.d1 + ids.b.d1 + ids.carry_d0 + ids.carry_d1 = 1 if sum_d1 >= ids.SHIFT else 0 + sum_d2 = ids.a.d2 + ids.b.d2 + ids.carry_d1 + ids.carry_d2 = 1 if sum_d2 >= ids.SHIFT else 0 + ``` + + ```python + def split(num: int, num_bits_shift: int, length: int): + a = [] + for _ in range(length): + a.append( num & ((1 << num_bits_shift) - 1) ) + num = num >> num_bits_shift + return tuple(a) + + def pack(z, num_bits_shift: int) -> int: + limbs = (z.d0, z.d1, z.d2) + return sum(limb << (num_bits_shift * i) for i, limb in enumerate(limbs)) + + def pack2(z, num_bits_shift: int) -> int: + limbs = (z.b01, z.b23, z.b45) + return sum(limb << (num_bits_shift * i) for i, limb in enumerate(limbs)) + + a = pack(ids.a, num_bits_shift = 128) + div = pack2(ids.div, num_bits_shift = 128) + quotient, remainder = divmod(a, div) + + quotient_split = split(quotient, num_bits_shift=128, length=3) + assert len(quotient_split) == 3 + + ids.quotient.d0 = quotient_split[0] + ids.quotient.d1 = quotient_split[1] + ids.quotient.d2 = quotient_split[2] + + remainder_split = split(remainder, num_bits_shift=128, length=3) + ids.remainder.d0 = remainder_split[0] + ids.remainder.d1 = remainder_split[1] + ids.remainder.d2 = remainder_split[2] + ``` + + ```python + from starkware.python.math_utils import isqrt + + def split(num: int, num_bits_shift: int, length: int): + a = [] + for _ in range(length): + a.append( num & ((1 << num_bits_shift) - 1) ) + num = num >> num_bits_shift + return tuple(a) + + def pack(z, num_bits_shift: int) -> int: + limbs = (z.d0, z.d1, z.d2) + return sum(limb << (num_bits_shift * i) for i, limb in enumerate(limbs)) + + a = pack(ids.a, num_bits_shift=128) + root = isqrt(a) + assert 0 <= root < 2 ** 192 + root_split = split(root, num_bits_shift=128, length=3) + ids.root.d0 = root_split[0] + ids.root.d1 = root_split[1] + ids.root.d2 = root_split[2] + ``` * Implement hint on `uint256_mul_div_mod`[#957](https://github.com/lambdaclass/cairo-rs/pull/957) `BuiltinHintProcessor` now supports the following hint: diff --git a/cairo_programs/is_quad_residue_test.cairo b/cairo_programs/is_quad_residue_test.cairo index 4b20e3a2e9..7b78ec6e65 100644 --- a/cairo_programs/is_quad_residue_test.cairo +++ b/cairo_programs/is_quad_residue_test.cairo @@ -39,5 +39,5 @@ func main{output_ptr: felt*}() { check_quad_res(inputs, expected, 0); - return(); + return (); } diff --git a/cairo_programs/uint384.cairo b/cairo_programs/uint384.cairo new file mode 100644 index 0000000000..ecb135dae8 --- /dev/null +++ b/cairo_programs/uint384.cairo @@ -0,0 +1,532 @@ +%builtins range_check +// Code taken from https://github.com/NethermindEth/research-basic-Cairo-operations-big-integers/blob/main/lib/uint384.cairo +from starkware.cairo.common.bitwise import bitwise_and, bitwise_or, bitwise_xor +from starkware.cairo.common.cairo_builtins import BitwiseBuiltin +from starkware.cairo.common.math import assert_in_range, assert_le, assert_nn_le, assert_not_zero +from starkware.cairo.common.math import unsigned_div_rem as frem +from starkware.cairo.common.math_cmp import is_le +from starkware.cairo.common.uint256 import Uint256, uint256_add, word_reverse_endian +from starkware.cairo.common.pow import pow +from starkware.cairo.common.registers import get_ap, get_fp_and_pc + +// This library is adapted from Cairo's common library Uint256 and it follows it as closely as possible. +// The library implements basic operations between 384-bit integers. +// Most operations use unsigned integers. Only a few operations are implemented for signed integers + +// Represents an integer in the range [0, 2^384). +struct Uint384 { + // The low 128 bits of the value. + d0: felt, + // The middle 128 bits of the value. + d1: felt, + // The # 128 bits of the value. + d2: felt, +} + +struct Uint384_expand { + B0: felt, + b01: felt, + b12: felt, + b23: felt, + b34: felt, + b45: felt, + b5: felt, +} + +const SHIFT = 2 ** 128; +const ALL_ONES = 2 ** 128 - 1; +const HALF_SHIFT = 2 ** 64; + +namespace uint384_lib { + // Verifies that the given integer is valid. + func check{range_check_ptr}(a: Uint384) { + [range_check_ptr] = a.d0; + [range_check_ptr + 1] = a.d1; + [range_check_ptr + 2] = a.d2; + let range_check_ptr = range_check_ptr + 3; + return (); + } + + // Adds two integers. Returns the result as a 384-bit integer and the (1-bit) carry. + func add{range_check_ptr}(a: Uint384, b: Uint384) -> (res: Uint384, carry: felt) { + alloc_locals; + local res: Uint384; + local carry_d0: felt; + local carry_d1: felt; + local carry_d2: felt; + %{ + sum_d0 = ids.a.d0 + ids.b.d0 + ids.carry_d0 = 1 if sum_d0 >= ids.SHIFT else 0 + sum_d1 = ids.a.d1 + ids.b.d1 + ids.carry_d0 + ids.carry_d1 = 1 if sum_d1 >= ids.SHIFT else 0 + sum_d2 = ids.a.d2 + ids.b.d2 + ids.carry_d1 + ids.carry_d2 = 1 if sum_d2 >= ids.SHIFT else 0 + %} + + // Either 0 or 1 + assert carry_d0 * carry_d0 = carry_d0; + assert carry_d1 * carry_d1 = carry_d1; + assert carry_d2 * carry_d2 = carry_d2; + + assert res.d0 = a.d0 + b.d0 - carry_d0 * SHIFT; + assert res.d1 = a.d1 + b.d1 + carry_d0 - carry_d1 * SHIFT; + assert res.d2 = a.d2 + b.d2 + carry_d1 - carry_d2 * SHIFT; + + check(res); + + return (res, carry_d2); + } + + // Subtracts two integers. Returns the result as a 384-bit integer. + func sub{range_check_ptr}(a: Uint384, b: Uint384) -> (res: Uint384) { + let (b_neg) = neg(b); + let (res, _) = add(a, b_neg); + return (res,); + } + + // Returns the bitwise NOT of an integer. + func not(a: Uint384) -> (res: Uint384) { + return (Uint384(d0=ALL_ONES - a.d0, d1=ALL_ONES - a.d1, d2=ALL_ONES - a.d2),); + } + + // Returns the negation of an integer. + // Note that the negation of -2**383 is -2**383. + func neg{range_check_ptr}(a: Uint384) -> (res: Uint384) { + let (not_num) = not(a); + let (res, _) = add(not_num, Uint384(d0=1, d1=0, d2=0)); + return (res,); + } + + // Adds two integers. Returns the result as a 384-bit integer and the (1-bit) carry. + // Doesn't verify that the result is a proper Uint384, that's now the responsibility of the calling function + func _add_no_uint384_check{range_check_ptr}(a: Uint384, b: Uint384) -> ( + res: Uint384, carry: felt + ) { + alloc_locals; + local res: Uint384; + local carry_d0: felt; + local carry_d1: felt; + local carry_d2: felt; + %{ + sum_d0 = ids.a.d0 + ids.b.d0 + ids.carry_d0 = 1 if sum_d0 >= ids.SHIFT else 0 + sum_d1 = ids.a.d1 + ids.b.d1 + ids.carry_d0 + ids.carry_d1 = 1 if sum_d1 >= ids.SHIFT else 0 + sum_d2 = ids.a.d2 + ids.b.d2 + ids.carry_d1 + ids.carry_d2 = 1 if sum_d2 >= ids.SHIFT else 0 + %} + + // Either 0 or 1 + assert carry_d0 * carry_d0 = carry_d0; + assert carry_d1 * carry_d1 = carry_d1; + assert carry_d2 * carry_d2 = carry_d2; + + assert res.d0 = a.d0 + b.d0 - carry_d0 * SHIFT; + assert res.d1 = a.d1 + b.d1 + carry_d0 - carry_d1 * SHIFT; + assert res.d2 = a.d2 + b.d2 + carry_d1 - carry_d2 * SHIFT; + + return (res, carry_d2); + } + + // Splits a field element in the range [0, 2^192) to its low 64-bit and high 128-bit parts. + func split_64{range_check_ptr}(a: felt) -> (low: felt, high: felt) { + alloc_locals; + local low: felt; + local high: felt; + + %{ + ids.low = ids.a & ((1<<64) - 1) + ids.high = ids.a >> 64 + %} + assert a = low + high * HALF_SHIFT; + assert [range_check_ptr + 0] = low; + assert [range_check_ptr + 1] = HALF_SHIFT - 1 - low; + assert [range_check_ptr + 2] = high; + let range_check_ptr = range_check_ptr + 3; + return (low, high); + } + + // Splits a field element in the range [0, 2^224) to its low 128-bit and high 96-bit parts. + func split_128{range_check_ptr}(a: felt) -> (low: felt, high: felt) { + alloc_locals; + const UPPER_BOUND = 2 ** 224; + const HIGH_BOUND = UPPER_BOUND / SHIFT; + local low: felt; + local high: felt; + + %{ + ids.low = ids.a & ((1<<128) - 1) + ids.high = ids.a >> 128 + %} + assert a = low + high * SHIFT; + assert [range_check_ptr + 0] = high; + assert [range_check_ptr + 1] = HIGH_BOUND - 1 - high; + assert [range_check_ptr + 2] = low; + let range_check_ptr = range_check_ptr + 3; + return (low, high); + } + + // Multiplies two integers. Returns the result as two 384-bit integers: the result has 2*384 bits, + // the returned integers represent the lower 384-bits and the higher 384-bits, respectively. + func mul{range_check_ptr}(a: Uint384, b: Uint384) -> (low: Uint384, high: Uint384) { + let (a0, a1) = split_64(a.d0); + let (a2, a3) = split_64(a.d1); + let (a4, a5) = split_64(a.d2); + let (b0, b1) = split_64(b.d0); + let (b2, b3) = split_64(b.d1); + let (b4, b5) = split_64(b.d2); + + let (res0, carry) = split_64(a0 * b0); + let (res1, carry) = split_64(a1 * b0 + a0 * b1 + carry); + let (res2, carry) = split_64(a2 * b0 + a1 * b1 + a0 * b2 + carry); + let (res3, carry) = split_64(a3 * b0 + a2 * b1 + a1 * b2 + a0 * b3 + carry); + let (res4, carry) = split_64(a4 * b0 + a3 * b1 + a2 * b2 + a1 * b3 + a0 * b4 + carry); + let (res5, carry) = split_64( + a5 * b0 + a4 * b1 + a3 * b2 + a2 * b3 + a1 * b4 + a0 * b5 + carry + ); + let (res6, carry) = split_64(a5 * b1 + a4 * b2 + a3 * b3 + a2 * b4 + a1 * b5 + carry); + let (res7, carry) = split_64(a5 * b2 + a4 * b3 + a3 * b4 + a2 * b5 + carry); + let (res8, carry) = split_64(a5 * b3 + a4 * b4 + a3 * b5 + carry); + let (res9, carry) = split_64(a5 * b4 + a4 * b5 + carry); + let (res10, carry) = split_64(a5 * b5 + carry); + + return ( + low=Uint384( + d0=res0 + HALF_SHIFT * res1, + d1=res2 + HALF_SHIFT * res3, + d2=res4 + HALF_SHIFT * res5, + ), + high=Uint384( + d0=res6 + HALF_SHIFT * res7, + d1=res8 + HALF_SHIFT * res9, + d2=res10 + HALF_SHIFT * carry, + ), + ); + } + func mul_expanded{range_check_ptr}(a: Uint384, b: Uint384_expand) -> ( + low: Uint384, high: Uint384 + ) { + let (a0, a1) = split_64(a.d0); + let (a2, a3) = split_64(a.d1); + let (a4, a5) = split_64(a.d2); + + let (res0, carry) = split_128(a1 * b.B0 + a0 * b.b01); + let (res2, carry) = split_128(a3 * b.B0 + a2 * b.b01 + a1 * b.b12 + a0 * b.b23 + carry); + let (res4, carry) = split_128( + a5 * b.B0 + a4 * b.b01 + a3 * b.b12 + a2 * b.b23 + a1 * b.b34 + a0 * b.b45 + carry + ); + let (res6, carry) = split_128( + a5 * b.b12 + a4 * b.b23 + a3 * b.b34 + a2 * b.b45 + a1 * b.b5 + carry + ); + let (res8, carry) = split_128(a5 * b.b34 + a4 * b.b45 + a3 * b.b5 + carry); + // let (res10, carry) = split_64(a5 * b.b5 + carry) + + return ( + low=Uint384(d0=res0, d1=res2, d2=res4), + high=Uint384(d0=res6, d1=res8, d2=a5 * b.b5 + carry), + ); + } + + func mul_d{range_check_ptr}(a: Uint384, b: Uint384) -> (low: Uint384, high: Uint384) { + alloc_locals; + let (a0, a1) = split_64(a.d0); + let (a2, a3) = split_64(a.d1); + let (a4, a5) = split_64(a.d2); + let (b0, b1) = split_64(b.d0); + let (b2, b3) = split_64(b.d1); + let (b4, b5) = split_64(b.d2); + + local B0 = b0 * HALF_SHIFT; + local b12 = b1 + b2 * HALF_SHIFT; + local b34 = b3 + b4 * HALF_SHIFT; + + let (res0, carry) = split_128(a1 * B0 + a0 * b.d0); + let (res2, carry) = split_128(a3 * B0 + a2 * b.d0 + a1 * b12 + a0 * b.d1 + carry); + let (res4, carry) = split_128( + a5 * B0 + a4 * b.d0 + a3 * b12 + a2 * b.d1 + a1 * b34 + a0 * b.d2 + carry + ); + let (res6, carry) = split_128( + a5 * b12 + a4 * b.d1 + a3 * b34 + a2 * b.d2 + a1 * b5 + carry + ); + let (res8, carry) = split_128(a5 * b34 + a4 * b.d2 + a3 * b5 + carry); + // let (res10, carry) = split_64(a5 * b5 + carry) + + return ( + low=Uint384(d0=res0, d1=res2, d2=res4), + high=Uint384(d0=res6, d1=res8, d2=a5 * b5 + carry), + ); + } + + func lt{range_check_ptr}(a: Uint384, b: Uint384) -> (res: felt) { + if (a.d2 == b.d2) { + if (a.d1 == b.d1) { + return (is_le(a.d0 + 1, b.d0),); + } + return (is_le(a.d1 + 1, b.d1),); + } + return (is_le(a.d2 + 1, b.d2),); + } + + // Returns 1 if the first unsigned integer is less than or equal to the second unsigned integer. + func le{range_check_ptr}(a: Uint384, b: Uint384) -> (res: felt) { + let (not_le) = lt(a=b, b=a); + return (1 - not_le,); + } + + // Unsigned integer division between two integers. Returns the quotient and the remainder. + // Conforms to EVM specifications: division by 0 yields 0. + func unsigned_div_rem{range_check_ptr}(a: Uint384, div: Uint384) -> ( + quotient: Uint384, remainder: Uint384 + ) { + alloc_locals; + local quotient: Uint384; + local remainder: Uint384; + + // If div == 0, return (0, 0, 0). + if (div.d0 + div.d1 + div.d2 == 0) { + return (quotient=Uint384(0, 0, 0), remainder=Uint384(0, 0, 0)); + } + + %{ + def split(num: int, num_bits_shift: int, length: int): + a = [] + for _ in range(length): + a.append( num & ((1 << num_bits_shift) - 1) ) + num = num >> num_bits_shift + return tuple(a) + + def pack(z, num_bits_shift: int) -> int: + limbs = (z.d0, z.d1, z.d2) + return sum(limb << (num_bits_shift * i) for i, limb in enumerate(limbs)) + + a = pack(ids.a, num_bits_shift = 128) + div = pack(ids.div, num_bits_shift = 128) + quotient, remainder = divmod(a, div) + + quotient_split = split(quotient, num_bits_shift=128, length=3) + assert len(quotient_split) == 3 + + ids.quotient.d0 = quotient_split[0] + ids.quotient.d1 = quotient_split[1] + ids.quotient.d2 = quotient_split[2] + + remainder_split = split(remainder, num_bits_shift=128, length=3) + ids.remainder.d0 = remainder_split[0] + ids.remainder.d1 = remainder_split[1] + ids.remainder.d2 = remainder_split[2] + %} + check(quotient); + check(remainder); + let (res_mul: Uint384, carry: Uint384) = mul_d(quotient, div); + assert carry = Uint384(0, 0, 0); + + let (check_val: Uint384, add_carry: felt) = _add_no_uint384_check(res_mul, remainder); + assert check_val = a; + assert add_carry = 0; + + let (is_valid) = lt(remainder, div); + assert is_valid = 1; + return (quotient=quotient, remainder=remainder); + } + + // Unsigned integer division between two integers. Returns the quotient and the remainder. + func unsigned_div_rem_expanded{range_check_ptr}(a: Uint384, div: Uint384_expand) -> ( + quotient: Uint384, remainder: Uint384 + ) { + alloc_locals; + local quotient: Uint384; + local remainder: Uint384; + + let div2 = Uint384(div.b01, div.b23, div.b45); + + %{ + def split(num: int, num_bits_shift: int, length: int): + a = [] + for _ in range(length): + a.append( num & ((1 << num_bits_shift) - 1) ) + num = num >> num_bits_shift + return tuple(a) + + def pack(z, num_bits_shift: int) -> int: + limbs = (z.d0, z.d1, z.d2) + return sum(limb << (num_bits_shift * i) for i, limb in enumerate(limbs)) + + def pack2(z, num_bits_shift: int) -> int: + limbs = (z.b01, z.b23, z.b45) + return sum(limb << (num_bits_shift * i) for i, limb in enumerate(limbs)) + + a = pack(ids.a, num_bits_shift = 128) + div = pack2(ids.div, num_bits_shift = 128) + quotient, remainder = divmod(a, div) + + quotient_split = split(quotient, num_bits_shift=128, length=3) + assert len(quotient_split) == 3 + + ids.quotient.d0 = quotient_split[0] + ids.quotient.d1 = quotient_split[1] + ids.quotient.d2 = quotient_split[2] + + remainder_split = split(remainder, num_bits_shift=128, length=3) + ids.remainder.d0 = remainder_split[0] + ids.remainder.d1 = remainder_split[1] + ids.remainder.d2 = remainder_split[2] + %} + check(quotient); + check(remainder); + let (res_mul: Uint384, carry: Uint384) = mul_expanded(quotient, div); + assert carry = Uint384(0, 0, 0); + + let (check_val: Uint384, add_carry: felt) = _add_no_uint384_check(res_mul, remainder); + assert check_val = a; + assert add_carry = 0; + + let (is_valid) = lt(remainder, div2); + assert is_valid = 1; + return (quotient=quotient, remainder=remainder); + } + + func square_e{range_check_ptr}(a: Uint384) -> (low: Uint384, high: Uint384) { + alloc_locals; + let (a0, a1) = split_64(a.d0); + let (a2, a3) = split_64(a.d1); + let (a4, a5) = split_64(a.d2); + + const HALF_SHIFT2 = 2 * HALF_SHIFT; + local a0_2 = a0 * 2; + local a34 = a3 + a4 * HALF_SHIFT2; + + let (res0, carry) = split_128(a0 * (a0 + a1 * HALF_SHIFT2)); + let (res2, carry) = split_128(a.d1 * a0_2 + a1 * (a1 + a2 * HALF_SHIFT2) + carry); + let (res4, carry) = split_128( + a.d2 * a0_2 + (a3 + a34) * a1 + a2 * (a2 + a3 * HALF_SHIFT2) + carry + ); + let (res6, carry) = split_128((a5 * a1 + a.d2 * a2) * 2 + a3 * a34 + carry); + let (res8, carry) = split_128(a5 * (a3 + a34) + a4 * a4 + carry); + // let (res10, carry) = split_64(a5*a5 + carry) + + return ( + low=Uint384(d0=res0, d1=res2, d2=res4), + high=Uint384(d0=res6, d1=res8, d2=a5 * a5 + carry), + ); + } + + // Returns the floor value of the square root of a Uint384 integer. + func sqrt{range_check_ptr}(a: Uint384) -> (res: Uint384) { + alloc_locals; + let (__fp__, _) = get_fp_and_pc(); + local root: Uint384; + + %{ + from starkware.python.math_utils import isqrt + + def split(num: int, num_bits_shift: int, length: int): + a = [] + for _ in range(length): + a.append( num & ((1 << num_bits_shift) - 1) ) + num = num >> num_bits_shift + return tuple(a) + + def pack(z, num_bits_shift: int) -> int: + limbs = (z.d0, z.d1, z.d2) + return sum(limb << (num_bits_shift * i) for i, limb in enumerate(limbs)) + + a = pack(ids.a, num_bits_shift=128) + root = isqrt(a) + assert 0 <= root < 2 ** 192 + root_split = split(root, num_bits_shift=128, length=3) + ids.root.d0 = root_split[0] + ids.root.d1 = root_split[1] + ids.root.d2 = root_split[2] + %} + + // Verify that 0 <= root < 2**192. + assert root.d2 = 0; + [range_check_ptr] = root.d0; + + // We don't need to check that 0 <= d1 < 2**64, since this gets checked + // when we check that carry==0 later + assert [range_check_ptr + 1] = root.d1; + let range_check_ptr = range_check_ptr + 2; + + // Verify that n >= root**2. + let (root_squared, carry) = square_e(root); + assert carry = Uint384(0, 0, 0); + let (check_lower_bound) = le(root_squared, a); + assert check_lower_bound = 1; + + // Verify that n <= (root+1)**2 - 1. + // In the case where root = 2**192 - 1, we will have next_root_squared=0, since + // (root+1)**2 = 2**384. Therefore next_root_squared - 1 = 2**384 - 1, as desired. + let (next_root, add_carry) = add(root, Uint384(1, 0, 0)); + assert add_carry = 0; + let (next_root_squared, _) = square_e(next_root); + let (next_root_squared_minus_one) = sub(next_root_squared, Uint384(1, 0, 0)); + let (check_upper_bound) = le(a, next_root_squared_minus_one); + assert check_upper_bound = 1; + + return (res=root); + } +} + +func test_uint384_operations{range_check_ptr}() { + // Test unsigned_div_rem + let a = Uint384(83434123481193248, 82349321849739284, 839243219401320423); + let div = Uint384(9283430921839492319493, 313248123482483248, 3790328402913840); + let (quotient: Uint384, remainder: Uint384) = uint384_lib.unsigned_div_rem{ + range_check_ptr=range_check_ptr + }(a, div); + assert quotient.d0 = 221; + assert quotient.d1 = 0; + assert quotient.d2 = 0; + + assert remainder.d0 = 340282366920936411825224315027446796751; + assert remainder.d1 = 340282366920938463394229121463989152931; + assert remainder.d2 = 1580642357361782; + + // Test split_128 + let b = 6805647338418769269267492148635364229100; + let (low, high) = uint384_lib.split_128{range_check_ptr=range_check_ptr}(b); + assert high = 19; + assert low = 340282366920938463463374607431768211436; + + // Test _add_no_uint384_test + + let c = Uint384(3789423292314891293, 21894, 340282366920938463463374607431768211455); + let d = Uint384(32838232, 17, 8); + let (sum_res, carry) = uint384_lib._add_no_uint384_check(c, d); + + assert sum_res.d0 = 3789423292347729525; + assert sum_res.d1 = 21911; + assert sum_res.d2 = 7; + assert carry = 1; + + // Test unsigned_div_rem_expanded + let e = Uint384(83434123481193248, 82349321849739284, 839243219401320423); + let div_expand = Uint384_expand( + 9283430921839492319493, 313248123482483248, 3790328402913840, 13, 78990, 109, 7 + ); + let (quotient: Uint384, remainder: Uint384) = uint384_lib.unsigned_div_rem_expanded{ + range_check_ptr=range_check_ptr + }(a, div_expand); + assert quotient.d0 = 7699479077076334; + assert quotient.d1 = 0; + assert quotient.d2 = 0; + + assert remainder.d0 = 340279955073565776659831804641277151872; + assert remainder.d1 = 340282366920938463463356863525615958397; + assert remainder.d2 = 16; + + // Test sqrt + let f = Uint384(83434123481193248, 82349321849739284, 839243219401320423); + let (root) = uint384_lib.sqrt(f); + assert root.d0 = 100835122758113432298839930225328621183; + assert root.d1 = 916102188; + assert root.d2 = 0; + + return (); +} + +func main{range_check_ptr: felt}() { + test_uint384_operations(); + return (); +} diff --git a/src/hint_processor/builtin_hint_processor/builtin_hint_processor_definition.rs b/src/hint_processor/builtin_hint_processor/builtin_hint_processor_definition.rs index f85fac3fe2..9c0e96eada 100644 --- a/src/hint_processor/builtin_hint_processor/builtin_hint_processor_definition.rs +++ b/src/hint_processor/builtin_hint_processor/builtin_hint_processor_definition.rs @@ -72,6 +72,10 @@ use felt::Felt252; use crate::hint_processor::builtin_hint_processor::skip_next_instruction::skip_next_instruction; use super::ec_utils::{chained_ec_op_random_ec_point_hint, random_ec_point_hint, recover_y_hint}; +use super::uint384::{ + add_no_uint384_check, uint384_split_128, uint384_sqrt, uint384_unsigned_div_rem, + uint384_unsigned_div_rem_expanded, +}; pub struct HintProcessorData { pub code: String, @@ -453,6 +457,21 @@ impl HintProcessor for BuiltinHintProcessor { chained_ec_op_random_ec_point_hint(vm, &hint_data.ids_data, &hint_data.ap_tracking) } hint_code::RECOVER_Y => recover_y_hint(vm, &hint_data.ids_data, &hint_data.ap_tracking), + hint_code::UINT384_UNSIGNED_DIV_REM => { + uint384_unsigned_div_rem(vm, &hint_data.ids_data, &hint_data.ap_tracking) + } + hint_code::UINT384_SPLIT_128 => { + uint384_split_128(vm, &hint_data.ids_data, &hint_data.ap_tracking) + } + hint_code::ADD_NO_UINT384_CHECK => { + add_no_uint384_check(vm, &hint_data.ids_data, &hint_data.ap_tracking, constants) + } + hint_code::UINT384_UNSIGNED_DIV_REM_EXPANDED => { + uint384_unsigned_div_rem_expanded(vm, &hint_data.ids_data, &hint_data.ap_tracking) + } + hint_code::UINT384_SQRT => { + uint384_sqrt(vm, &hint_data.ids_data, &hint_data.ap_tracking) + } hint_code::UINT256_MUL_DIV_MOD => { uint256_mul_div_mod(vm, &hint_data.ids_data, &hint_data.ap_tracking) } diff --git a/src/hint_processor/builtin_hint_processor/hint_code.rs b/src/hint_processor/builtin_hint_processor/hint_code.rs index 1574beedee..9d3dd75402 100644 --- a/src/hint_processor/builtin_hint_processor/hint_code.rs +++ b/src/hint_processor/builtin_hint_processor/hint_code.rs @@ -629,5 +629,94 @@ from starkware.python.math_utils import recover_y ids.p.x = ids.x # This raises an exception if `x` is not on the curve. ids.p.y = recover_y(ids.x, ALPHA, BETA, FIELD_PRIME)"; + +// The following hints support the lib https://github.com/NethermindEth/research-basic-Cairo-operations-big-integers/blob/main/lib/uint384.cairo +pub(crate) const UINT384_UNSIGNED_DIV_REM: &str = + "def split(num: int, num_bits_shift: int, length: int): + a = [] + for _ in range(length): + a.append( num & ((1 << num_bits_shift) - 1) ) + num = num >> num_bits_shift + return tuple(a) + +def pack(z, num_bits_shift: int) -> int: + limbs = (z.d0, z.d1, z.d2) + return sum(limb << (num_bits_shift * i) for i, limb in enumerate(limbs)) + +a = pack(ids.a, num_bits_shift = 128) +div = pack(ids.div, num_bits_shift = 128) +quotient, remainder = divmod(a, div) + +quotient_split = split(quotient, num_bits_shift=128, length=3) +assert len(quotient_split) == 3 + +ids.quotient.d0 = quotient_split[0] +ids.quotient.d1 = quotient_split[1] +ids.quotient.d2 = quotient_split[2] + +remainder_split = split(remainder, num_bits_shift=128, length=3) +ids.remainder.d0 = remainder_split[0] +ids.remainder.d1 = remainder_split[1] +ids.remainder.d2 = remainder_split[2]"; +pub(crate) const UINT384_SPLIT_128: &str = "ids.low = ids.a & ((1<<128) - 1) +ids.high = ids.a >> 128"; +pub(crate) const ADD_NO_UINT384_CHECK: &str = "sum_d0 = ids.a.d0 + ids.b.d0 +ids.carry_d0 = 1 if sum_d0 >= ids.SHIFT else 0 +sum_d1 = ids.a.d1 + ids.b.d1 + ids.carry_d0 +ids.carry_d1 = 1 if sum_d1 >= ids.SHIFT else 0 +sum_d2 = ids.a.d2 + ids.b.d2 + ids.carry_d1 +ids.carry_d2 = 1 if sum_d2 >= ids.SHIFT else 0"; +pub(crate) const UINT384_UNSIGNED_DIV_REM_EXPANDED: &str = + "def split(num: int, num_bits_shift: int, length: int): + a = [] + for _ in range(length): + a.append( num & ((1 << num_bits_shift) - 1) ) + num = num >> num_bits_shift + return tuple(a) + +def pack(z, num_bits_shift: int) -> int: + limbs = (z.d0, z.d1, z.d2) + return sum(limb << (num_bits_shift * i) for i, limb in enumerate(limbs)) + +def pack2(z, num_bits_shift: int) -> int: + limbs = (z.b01, z.b23, z.b45) + return sum(limb << (num_bits_shift * i) for i, limb in enumerate(limbs)) + +a = pack(ids.a, num_bits_shift = 128) +div = pack2(ids.div, num_bits_shift = 128) +quotient, remainder = divmod(a, div) + +quotient_split = split(quotient, num_bits_shift=128, length=3) +assert len(quotient_split) == 3 + +ids.quotient.d0 = quotient_split[0] +ids.quotient.d1 = quotient_split[1] +ids.quotient.d2 = quotient_split[2] + +remainder_split = split(remainder, num_bits_shift=128, length=3) +ids.remainder.d0 = remainder_split[0] +ids.remainder.d1 = remainder_split[1] +ids.remainder.d2 = remainder_split[2]"; +pub(crate) const UINT384_SQRT: &str = "from starkware.python.math_utils import isqrt + +def split(num: int, num_bits_shift: int, length: int): + a = [] + for _ in range(length): + a.append( num & ((1 << num_bits_shift) - 1) ) + num = num >> num_bits_shift + return tuple(a) + +def pack(z, num_bits_shift: int) -> int: + limbs = (z.d0, z.d1, z.d2) + return sum(limb << (num_bits_shift * i) for i, limb in enumerate(limbs)) + +a = pack(ids.a, num_bits_shift=128) +root = isqrt(a) +assert 0 <= root < 2 ** 192 +root_split = split(root, num_bits_shift=128, length=3) +ids.root.d0 = root_split[0] +ids.root.d1 = root_split[1] +ids.root.d2 = root_split[2]"; + #[cfg(feature = "skip_next_instruction_hint")] pub(crate) const SKIP_NEXT_INSTRUCTION: &str = "skip_next_instruction()"; diff --git a/src/hint_processor/builtin_hint_processor/mod.rs b/src/hint_processor/builtin_hint_processor/mod.rs index 75abd3a6e9..19348b601c 100644 --- a/src/hint_processor/builtin_hint_processor/mod.rs +++ b/src/hint_processor/builtin_hint_processor/mod.rs @@ -24,4 +24,5 @@ pub mod signature; pub mod skip_next_instruction; pub mod squash_dict_utils; pub mod uint256_utils; +pub mod uint384; pub mod usort; diff --git a/src/hint_processor/builtin_hint_processor/uint384.rs b/src/hint_processor/builtin_hint_processor/uint384.rs new file mode 100644 index 0000000000..055d907dd7 --- /dev/null +++ b/src/hint_processor/builtin_hint_processor/uint384.rs @@ -0,0 +1,708 @@ +use core::ops::Shl; +use felt::Felt252; +use num_bigint::BigUint; +use num_integer::Integer; +use num_traits::{One, Zero}; + +use crate::math_utils::isqrt; +use crate::stdlib::{borrow::Cow, collections::HashMap, prelude::*}; +use crate::types::relocatable::Relocatable; +use crate::{ + hint_processor::hint_processor_definition::HintReference, + serde::deserialize_program::ApTracking, + vm::{errors::hint_errors::HintError, vm_core::VirtualMachine}, +}; + +use super::hint_utils::{ + get_integer_from_var_name, get_relocatable_from_var_name, insert_value_from_var_name, +}; +use super::secp::bigint_utils::BigInt3; +// Notes: Hints in this lib use the type Uint384, which is equal to common lib's BigInt3 + +/* Reduced version of Uint384_expand +The full version has 7 limbs (B0, b01, b12, b23, b34, b45, b5), but only 3 are used by the pack2 fn (b01, b23, b45) +As there are no other uses of Uint384_expand outside of these in the lib, we can use a reduced version with just 3 limbs +*/ +#[derive(Debug, PartialEq)] +#[allow(non_snake_case)] +pub(crate) struct Uint384ExpandReduced<'a> { + pub b01: Cow<'a, Felt252>, + pub b23: Cow<'a, Felt252>, + pub b45: Cow<'a, Felt252>, +} + +impl Uint384ExpandReduced<'_> { + pub(crate) fn from_base_addr<'a>( + addr: Relocatable, + name: &str, + vm: &'a VirtualMachine, + ) -> Result, HintError> { + Ok(Uint384ExpandReduced { + b01: vm.get_integer((addr + 1)?).map_err(|_| { + HintError::IdentifierHasNoMember(name.to_string(), "b01".to_string()) + })?, + b23: vm.get_integer((addr + 3)?).map_err(|_| { + HintError::IdentifierHasNoMember(name.to_string(), "b23".to_string()) + })?, + b45: vm.get_integer((addr + 5)?).map_err(|_| { + HintError::IdentifierHasNoMember(name.to_string(), "b45".to_string()) + })?, + }) + } + pub(crate) fn from_var_name<'a>( + name: &str, + vm: &'a VirtualMachine, + ids_data: &HashMap, + ap_tracking: &ApTracking, + ) -> Result, HintError> { + let base_addr = get_relocatable_from_var_name(name, vm, ids_data, ap_tracking)?; + Uint384ExpandReduced::from_base_addr(base_addr, name, vm) + } +} + +fn split(num: &BigUint, num_bits_shift: u32) -> [BigUint; T] { + let mut num = num.clone(); + [0; T].map(|_| { + let a = &num & &((BigUint::one() << num_bits_shift) - 1_u32); + num = &num >> num_bits_shift; + a + }) +} + +fn pack(num: BigInt3, num_bits_shift: usize) -> BigUint { + let limbs = [num.d0, num.d1, num.d2]; + #[allow(deprecated)] + limbs + .into_iter() + .enumerate() + .map(|(idx, value)| value.to_biguint().shl(idx * num_bits_shift)) + .sum() +} + +fn pack2(num: Uint384ExpandReduced, num_bits_shift: usize) -> BigUint { + let limbs = [num.b01, num.b23, num.b45]; + #[allow(deprecated)] + limbs + .into_iter() + .enumerate() + .map(|(idx, value)| value.to_biguint().shl(idx * num_bits_shift)) + .sum() +} +/* Implements Hint: +%{ + def split(num: int, num_bits_shift: int, length: int): + a = [] + for _ in range(length): + a.append( num & ((1 << num_bits_shift) - 1) ) + num = num >> num_bits_shift + return tuple(a) + + def pack(z, num_bits_shift: int) -> int: + limbs = (z.d0, z.d1, z.d2) + return sum(limb << (num_bits_shift * i) for i, limb in enumerate(limbs)) + + a = pack(ids.a, num_bits_shift = 128) + div = pack(ids.div, num_bits_shift = 128) + quotient, remainder = divmod(a, div) + + quotient_split = split(quotient, num_bits_shift=128, length=3) + assert len(quotient_split) == 3 + + ids.quotient.d0 = quotient_split[0] + ids.quotient.d1 = quotient_split[1] + ids.quotient.d2 = quotient_split[2] + + remainder_split = split(remainder, num_bits_shift=128, length=3) + ids.remainder.d0 = remainder_split[0] + ids.remainder.d1 = remainder_split[1] + ids.remainder.d2 = remainder_split[2] +%} +*/ +pub fn uint384_unsigned_div_rem( + vm: &mut VirtualMachine, + ids_data: &HashMap, + ap_tracking: &ApTracking, +) -> Result<(), HintError> { + let a = pack(BigInt3::from_var_name("a", vm, ids_data, ap_tracking)?, 128); + let div = pack( + BigInt3::from_var_name("div", vm, ids_data, ap_tracking)?, + 128, + ); + let quotient_addr = get_relocatable_from_var_name("quotient", vm, ids_data, ap_tracking)?; + let remainder_addr = get_relocatable_from_var_name("remainder", vm, ids_data, ap_tracking)?; + let (quotient, remainder) = a.div_mod_floor(&div); + let quotient_split = split::<3>("ient, 128); + for (i, quotient_split) in quotient_split.iter().enumerate() { + vm.insert_value((quotient_addr + i)?, Felt252::from(quotient_split))?; + } + let remainder_split = split::<3>(&remainder, 128); + for (i, remainder_split) in remainder_split.iter().enumerate() { + vm.insert_value((remainder_addr + i)?, Felt252::from(remainder_split))?; + } + Ok(()) +} + +/* Implements Hint: + %{ + ids.low = ids.a & ((1<<128) - 1) + ids.high = ids.a >> 128 + %} +*/ +pub fn uint384_split_128( + vm: &mut VirtualMachine, + ids_data: &HashMap, + ap_tracking: &ApTracking, +) -> Result<(), HintError> { + let a = get_integer_from_var_name("a", vm, ids_data, ap_tracking)?.into_owned(); + insert_value_from_var_name( + "low", + &a & &Felt252::from(u128::MAX), + vm, + ids_data, + ap_tracking, + )?; + insert_value_from_var_name("high", a >> 128_u32, vm, ids_data, ap_tracking) +} + +/* Implements Hint: +%{ + sum_d0 = ids.a.d0 + ids.b.d0 + ids.carry_d0 = 1 if sum_d0 >= ids.SHIFT else 0 + sum_d1 = ids.a.d1 + ids.b.d1 + ids.carry_d0 + ids.carry_d1 = 1 if sum_d1 >= ids.SHIFT else 0 + sum_d2 = ids.a.d2 + ids.b.d2 + ids.carry_d1 + ids.carry_d2 = 1 if sum_d2 >= ids.SHIFT else 0 +%} + */ +pub fn add_no_uint384_check( + vm: &mut VirtualMachine, + ids_data: &HashMap, + ap_tracking: &ApTracking, + constants: &HashMap, +) -> Result<(), HintError> { + let a = BigInt3::from_var_name("a", vm, ids_data, ap_tracking)?; + let b = BigInt3::from_var_name("b", vm, ids_data, ap_tracking)?; + // This hint is not from the cairo commonlib, and its lib can be found under different paths, so we cant rely on a full path name + let shift = constants + .iter() + .find(|(k, _)| k.rsplit('.').next() == Some("SHIFT")) + .map(|(_, n)| n.to_biguint()) + .ok_or(HintError::MissingConstant("SHIFT"))?; + + let sum_d0 = a.d0.to_biguint() + b.d0.to_biguint(); + let carry_d0 = Felt252::from((sum_d0 >= shift) as usize); + let sum_d1 = a.d1.to_biguint() + b.d1.to_biguint(); + let carry_d1 = Felt252::from((sum_d1 >= shift) as usize); + let sum_d2 = a.d2.to_biguint() + b.d2.to_biguint(); + let carry_d2 = Felt252::from((sum_d2 >= shift) as usize); + + insert_value_from_var_name("carry_d0", carry_d0, vm, ids_data, ap_tracking)?; + insert_value_from_var_name("carry_d1", carry_d1, vm, ids_data, ap_tracking)?; + insert_value_from_var_name("carry_d2", carry_d2, vm, ids_data, ap_tracking) +} + +/* Implements Hint: +%{ + def split(num: int, num_bits_shift: int, length: int): + a = [] + for _ in range(length): + a.append( num & ((1 << num_bits_shift) - 1) ) + num = num >> num_bits_shift + return tuple(a) + + def pack(z, num_bits_shift: int) -> int: + limbs = (z.d0, z.d1, z.d2) + return sum(limb << (num_bits_shift * i) for i, limb in enumerate(limbs)) + + def pack2(z, num_bits_shift: int) -> int: + limbs = (z.b01, z.b23, z.b45) + return sum(limb << (num_bits_shift * i) for i, limb in enumerate(limbs)) + + a = pack(ids.a, num_bits_shift = 128) + div = pack2(ids.div, num_bits_shift = 128) + quotient, remainder = divmod(a, div) + + quotient_split = split(quotient, num_bits_shift=128, length=3) + assert len(quotient_split) == 3 + + ids.quotient.d0 = quotient_split[0] + ids.quotient.d1 = quotient_split[1] + ids.quotient.d2 = quotient_split[2] + + remainder_split = split(remainder, num_bits_shift=128, length=3) + ids.remainder.d0 = remainder_split[0] + ids.remainder.d1 = remainder_split[1] + ids.remainder.d2 = remainder_split[2] +%} +*/ +pub fn uint384_unsigned_div_rem_expanded( + vm: &mut VirtualMachine, + ids_data: &HashMap, + ap_tracking: &ApTracking, +) -> Result<(), HintError> { + let a = pack(BigInt3::from_var_name("a", vm, ids_data, ap_tracking)?, 128); + let div = pack2( + Uint384ExpandReduced::from_var_name("div", vm, ids_data, ap_tracking)?, + 128, + ); + let quotient_addr = get_relocatable_from_var_name("quotient", vm, ids_data, ap_tracking)?; + let remainder_addr = get_relocatable_from_var_name("remainder", vm, ids_data, ap_tracking)?; + let (quotient, remainder) = a.div_mod_floor(&div); + let quotient_split = split::<3>("ient, 128); + for (i, quotient_split) in quotient_split.iter().enumerate() { + vm.insert_value((quotient_addr + i)?, Felt252::from(quotient_split))?; + } + let remainder_split = split::<3>(&remainder, 128); + for (i, remainder_split) in remainder_split.iter().enumerate() { + vm.insert_value((remainder_addr + i)?, Felt252::from(remainder_split))?; + } + Ok(()) +} + +/* Implements Hint +%{ + from starkware.python.math_utils import isqrt + + def split(num: int, num_bits_shift: int, length: int): + a = [] + for _ in range(length): + a.append( num & ((1 << num_bits_shift) - 1) ) + num = num >> num_bits_shift + return tuple(a) + + def pack(z, num_bits_shift: int) -> int: + limbs = (z.d0, z.d1, z.d2) + return sum(limb << (num_bits_shift * i) for i, limb in enumerate(limbs)) + + a = pack(ids.a, num_bits_shift=128) + root = isqrt(a) + assert 0 <= root < 2 ** 192 + root_split = split(root, num_bits_shift=128, length=3) + ids.root.d0 = root_split[0] + ids.root.d1 = root_split[1] + ids.root.d2 = root_split[2] +%} + */ +pub fn uint384_sqrt( + vm: &mut VirtualMachine, + ids_data: &HashMap, + ap_tracking: &ApTracking, +) -> Result<(), HintError> { + let a = pack(BigInt3::from_var_name("a", vm, ids_data, ap_tracking)?, 128); + let root_addr = get_relocatable_from_var_name("root", vm, ids_data, ap_tracking)?; + let root = isqrt(&a)?; + if root.is_zero() || root.bits() > 192 { + return Err(HintError::AssertionFailed(String::from( + "assert 0 <= root < 2 ** 192", + ))); + } + let root_split = split::<3>(&root, 128); + for (i, root_split) in root_split.iter().enumerate() { + vm.insert_value((root_addr + i)?, Felt252::from(root_split))?; + } + Ok(()) +} +#[cfg(test)] +mod tests { + use super::*; + use crate::hint_processor::builtin_hint_processor::hint_code; + use crate::vm::vm_memory::memory_segments::MemorySegmentManager; + use crate::{ + any_box, + hint_processor::{ + builtin_hint_processor::builtin_hint_processor_definition::{ + BuiltinHintProcessor, HintProcessorData, + }, + hint_processor_definition::HintProcessor, + }, + types::{ + exec_scope::ExecutionScopes, + relocatable::{MaybeRelocatable, Relocatable}, + }, + utils::test_utils::*, + vm::{ + errors::memory_errors::MemoryError, runners::builtin_runner::RangeCheckBuiltinRunner, + vm_core::VirtualMachine, vm_memory::memory::Memory, + }, + }; + use assert_matches::assert_matches; + use felt::felt_str; + + #[cfg(target_arch = "wasm32")] + use wasm_bindgen_test::*; + + #[test] + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)] + fn run_unsigned_div_rem_ok() { + let mut vm = vm_with_range_check!(); + //Initialize fp + vm.run_context.fp = 10; + //Create hint_data + let ids_data = + non_continuous_ids_data![("a", -9), ("div", -6), ("quotient", -3), ("remainder", 0)]; + //Insert ids into memory + vm.segments = segments![ + //a + ((1, 1), 83434123481193248), + ((1, 2), 82349321849739284), + ((1, 3), 839243219401320423), + //div + ((1, 4), 9283430921839492319493), + ((1, 5), 313248123482483248), + ((1, 6), 3790328402913840) + ]; + //Execute the hint + assert_matches!( + run_hint!(vm, ids_data, hint_code::UINT384_UNSIGNED_DIV_REM), + Ok(()) + ); + //Check hint memory inserts + check_memory![ + vm.segments.memory, + // quotient + ((1, 7), 221), + ((1, 8), 0), + ((1, 9), 0), + // remainder + //((1, 10), 340282366920936411825224315027446796751), + //((1, 11), 340282366920938463394229121463989152931), + ((1, 12), 1580642357361782) + ]; + assert_eq!( + vm.segments + .memory + .get_integer((1, 10).into()) + .unwrap() + .as_ref(), + &felt_str!("340282366920936411825224315027446796751") + ); + assert_eq!( + vm.segments + .memory + .get_integer((1, 11).into()) + .unwrap() + .as_ref(), + &felt_str!("340282366920938463394229121463989152931") + ); + } + + #[test] + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)] + fn run_unsigned_div_rem_invalid_memory_insert() { + let mut vm = vm_with_range_check!(); + //Initialize fp + vm.run_context.fp = 10; + //Create hint_data + let ids_data = + non_continuous_ids_data![("a", -9), ("div", -6), ("quotient", -3), ("remainder", 0)]; + //Insert ids into memory + vm.segments = segments![ + //a + ((1, 1), 83434123481193248), + ((1, 2), 82349321849739284), + ((1, 3), 839243219401320423), + //div + ((1, 4), 9283430921839492319493), + ((1, 5), 313248123482483248), + ((1, 6), 3790328402913840), + //quotient + ((1, 7), 2) + ]; + //Execute the hint + assert_matches!( + run_hint!(vm, ids_data, hint_code::UINT384_UNSIGNED_DIV_REM), + Err(HintError::Memory( + MemoryError::InconsistentMemory( + x, + y, + z, + ) + )) if x == Relocatable::from((1, 7)) && + y == MaybeRelocatable::from(Felt252::new(2)) && + z == MaybeRelocatable::from(Felt252::new(221)) + ); + } + + #[test] + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)] + fn run_split_128_ok() { + let mut vm = vm_with_range_check!(); + //Initialize fp + vm.run_context.fp = 3; + //Create hint_data + let ids_data = ids_data!["a", "low", "high"]; + //Insert ids into memory + vm.segments = segments![((1, 0), 34895349583295832495320945304)]; + //Execute the hint + assert_matches!( + run_hint!(vm, ids_data, hint_code::UINT384_SPLIT_128), + Ok(()) + ); + //Check hint memory inserts + check_memory![ + vm.segments.memory, + // low + ((1, 1), 34895349583295832495320945304), + // high + ((1, 2), 0) + ]; + } + + #[test] + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)] + fn run_split_128_ok_big_number() { + let mut vm = vm_with_range_check!(); + //Initialize fp + vm.run_context.fp = 3; + //Create hint_data + let ids_data = ids_data!["a", "low", "high"]; + //Insert ids into memory + vm.segments.add(); + vm.segments.add(); + vm.segments + .memory + .insert( + (1, 0).into(), + Felt252::from(u128::MAX) * Felt252::from(20_u32), + ) + .unwrap(); + //Execute the hint + assert_matches!( + run_hint!(vm, ids_data, hint_code::UINT384_SPLIT_128), + Ok(()) + ); + //Check hint memory inserts + check_memory![ + vm.segments.memory, + // low + //((1, 1), 340282366920938463463374607431768211454) + // high + ((1, 2), 19) + ]; + assert_eq!( + vm.segments + .memory + .get_integer((1, 1).into()) + .unwrap() + .as_ref(), + &felt_str!("340282366920938463463374607431768211436") + ); + } + + #[test] + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)] + fn run_split_128_invalid_memory_insert() { + let mut vm = vm_with_range_check!(); + //Initialize fp + vm.run_context.fp = 3; + //Create hint_data + let ids_data = ids_data!["a", "low", "high"]; + //Insert ids into memory + vm.segments = segments![((1, 0), 34895349583295832495320945304), ((1, 1), 2)]; + //Execute the hint + assert_matches!( + run_hint!(vm, ids_data, hint_code::UINT384_SPLIT_128), + Err(HintError::Memory( + MemoryError::InconsistentMemory( + x, + y, + z, + ) + )) if x == Relocatable::from((1, 1)) && + y == MaybeRelocatable::from(Felt252::new(2)) && + z == MaybeRelocatable::from(Felt252::new(34895349583295832495320945304_i128)) + ); + } + + #[test] + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)] + fn run_add_no_check_ok() { + let mut vm = vm_with_range_check!(); + //Initialize fp + vm.run_context.fp = 10; + //Create hint_data + let ids_data = non_continuous_ids_data![ + ("a", -10), + ("b", -7), + ("carry_d0", -4), + ("carry_d1", -3), + ("carry_d2", -2) + ]; + //Insert ids into memory + vm.segments = segments![ + // a + ((1, 0), 3789423292314891293), + ((1, 1), 21894), + ((1, 2), 340282366920938463463374607431768211455_u128), + // b + ((1, 3), 32838232), + ((1, 4), 17), + ((1, 5), 8) + ]; + //Execute the hint + assert_matches!( + run_hint!( + vm, + ids_data, + hint_code::ADD_NO_UINT384_CHECK, + &mut exec_scopes_ref!(), + &[("path.path.path.SHIFT", Felt252::one().shl(128_u32))] + .into_iter() + .map(|(k, v)| (k.to_string(), v)) + .collect() + ), + Ok(()) + ); + //Check hint memory inserts + check_memory![ + vm.segments.memory, + // carry_d0 + ((1, 6), 0), + // carry_d1 + ((1, 7), 0), + // carry_d2 + ((1, 8), 1) + ]; + } + + #[test] + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)] + fn run_add_no_check_missing_constant() { + let mut vm = vm_with_range_check!(); + //Initialize fp + vm.run_context.fp = 10; + //Create hint_data + let ids_data = non_continuous_ids_data![ + ("a", -10), + ("b", -7), + ("carry_d0", -3), + ("carry_d1", -2), + ("carry_d2", -1) + ]; + //Insert ids into memory + vm.segments = segments![ + // a + ((1, 0), 3789423292314891293), + ((1, 1), 21894), + ((1, 2), 340282366920938463463374607431768211455_u128), + // b + ((1, 3), 32838232), + ((1, 4), 17), + ((1, 5), 8) + ]; + //Execute the hint + assert_matches!( + run_hint!(vm, ids_data, hint_code::ADD_NO_UINT384_CHECK), + Err(HintError::MissingConstant(s)) if s == "SHIFT" + ); + } + + #[test] + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)] + fn run_unsigned_div_rem_expand_ok() { + let mut vm = vm_with_range_check!(); + //Initialize fp + vm.run_context.fp = 13; + //Create hint_data + let ids_data = + non_continuous_ids_data![("a", -13), ("div", -10), ("quotient", -3), ("remainder", 0)]; + //Insert ids into memory + vm.segments = segments![ + //a + ((1, 0), 83434123481193248), + ((1, 1), 82349321849739284), + ((1, 2), 839243219401320423), + //div + ((1, 3), 9283430921839492319493), + ((1, 4), 313248123482483248), + ((1, 5), 3790328402913840), + ((1, 6), 13), + ((1, 7), 78990), + ((1, 8), 109), + ((1, 9), 7) + ]; + //Execute the hint + assert_matches!( + run_hint!(vm, ids_data, hint_code::UINT384_UNSIGNED_DIV_REM_EXPANDED), + Ok(()) + ); + //Check hint memory inserts + check_memory![ + vm.segments.memory, + // quotient + ((1, 10), 7699479077076334), + ((1, 11), 0), + ((1, 12), 0), + // remainder + //((1, 13), 340279955073565776659831804641277151872), + //((1, 14), 340282366920938463463356863525615958397), + ((1, 15), 16) + ]; + assert_eq!( + vm.segments + .memory + .get_integer((1, 13).into()) + .unwrap() + .as_ref(), + &felt_str!("340279955073565776659831804641277151872") + ); + assert_eq!( + vm.segments + .memory + .get_integer((1, 14).into()) + .unwrap() + .as_ref(), + &felt_str!("340282366920938463463356863525615958397") + ); + } + + #[test] + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)] + fn run_sqrt_ok() { + let mut vm = vm_with_range_check!(); + //Initialize fp + vm.run_context.fp = 5; + //Create hint_data + let ids_data = non_continuous_ids_data![("a", -5), ("root", -2)]; + //Insert ids into memory + vm.segments = segments![ + //a + ((1, 0), 83434123481193248), + ((1, 1), 82349321849739284), + ((1, 2), 839243219401320423) + ]; + //Execute the hint + assert_matches!(run_hint!(vm, ids_data, hint_code::UINT384_SQRT), Ok(())); + //Check hint memory inserts + check_memory![ + vm.segments.memory, + // root + ((1, 3), 100835122758113432298839930225328621183), + ((1, 4), 916102188), + ((1, 5), 0) + ]; + } + + #[test] + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)] + fn run_sqrt_assertion_fail() { + let mut vm = vm_with_range_check!(); + //Initialize fp + vm.run_context.fp = 5; + //Create hint_data + let ids_data = non_continuous_ids_data![("a", -5), ("root", -2)]; + //Insert ids into memory + //Insert ids into memory + vm.segments = segments![ + //a + ((1, 0), (-1)), + ((1, 1), (-1)), + ((1, 2), (-1)) + ]; + //Execute the hint + assert_matches!( + run_hint!(vm, ids_data, hint_code::UINT384_SQRT), + Err(HintError::AssertionFailed(s)) if s == "assert 0 <= root < 2 ** 192" + ); + } +} diff --git a/src/tests/cairo_run_test.rs b/src/tests/cairo_run_test.rs index 9c12f8b0e3..b0b7b131b4 100644 --- a/src/tests/cairo_run_test.rs +++ b/src/tests/cairo_run_test.rs @@ -1280,3 +1280,10 @@ fn cairo_run_is_quad_residue_test() { let program_data = include_bytes!("../../cairo_programs/is_quad_residue_test.json"); run_program_simple(program_data.as_slice()); } + +#[test] +#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)] +fn cairo_run_uint384() { + let program_data = include_bytes!("../../cairo_programs/uint384.json"); + run_program_simple(program_data.as_slice()); +}