diff --git a/src/kakarot/instructions/stop_and_math_operations.cairo b/src/kakarot/instructions/stop_and_math_operations.cairo index 9f9e513b9..7156383ab 100644 --- a/src/kakarot/instructions/stop_and_math_operations.cairo +++ b/src/kakarot/instructions/stop_and_math_operations.cairo @@ -31,7 +31,7 @@ from kakarot.model import model from kakarot.execution_context import ExecutionContext from kakarot.stack import Stack from kakarot.errors import Errors -from utils.uint256 import uint256_exp, uint256_signextend +from utils.uint256 import uint256_fast_exp, uint256_signextend // @title Stop and Math operations opcodes. // @notice Math operations gathers Arithmetic and Comparison operations @@ -241,7 +241,7 @@ namespace StopAndMathOperations { let range_check_ptr = [ap - 2]; let popped = cast([ap - 1], Uint256*); - let result = uint256_exp(popped[0], popped[1]); + let result = uint256_fast_exp(popped[0], popped[1]); tempvar bitwise_ptr = cast([fp - 4], BitwiseBuiltin*); tempvar range_check_ptr = range_check_ptr; diff --git a/src/utils/uint256.cairo b/src/utils/uint256.cairo index a4b4b184c..8f33a1a6f 100644 --- a/src/utils/uint256.cairo +++ b/src/utils/uint256.cairo @@ -1,31 +1,31 @@ from starkware.cairo.common.uint256 import ( Uint256, uint256_eq, - uint256_le, uint256_sub, uint256_mul, - uint256_add, - uint256_pow2, uint256_unsigned_div_rem, + uint256_le, + uint256_pow2, + uint256_add, ) from starkware.cairo.common.bool import FALSE // @notice Internal exponentiation of two 256-bit integers. // @dev The result is modulo 2^256. -// @param a The base. -// @param b The exponent. +// @param value - The base. +// @param exponent - The exponent. // @return The result of the exponentiation. -func uint256_exp{range_check_ptr}(a: Uint256, b: Uint256) -> Uint256 { - let one_uint = Uint256(1, 0); - let zero_uint = Uint256(0, 0); +func uint256_exp{range_check_ptr}(value: Uint256, exponent: Uint256) -> Uint256 { + let one = Uint256(1, 0); + let zero = Uint256(0, 0); - let (is_b_zero) = uint256_eq(b, zero_uint); - if (is_b_zero != FALSE) { - return one_uint; + let (exponent_is_zero) = uint256_eq(exponent, zero); + if (exponent_is_zero != FALSE) { + return one; } - let (b_minus_one) = uint256_sub(b, one_uint); - let pow = uint256_exp(a, b_minus_one); - let (res, _) = uint256_mul(a, pow); + let (exponent_minus_one) = uint256_sub(exponent, one); + let pow = uint256_exp(value, exponent_minus_one); + let (res, _) = uint256_mul(value, pow); return res; } @@ -56,3 +56,38 @@ func uint256_signextend{range_check_ptr}(x: Uint256, byte_num: Uint256) -> Uint2 let (value, _) = uint256_add(value, padding); return value; } + +// @notice Internal fast exponentiation of two 256-bit integers. +// @dev The result is modulo 2^256. +// @param value - The base. +// @param exponent - The exponent. +// @return The result of the exponentiation. +func uint256_fast_exp{range_check_ptr}(value: Uint256, exponent: Uint256) -> Uint256 { + alloc_locals; + + let one = Uint256(1, 0); + let zero = Uint256(0, 0); + + let (exponent_is_zero) = uint256_eq(exponent, zero); + if (exponent_is_zero != FALSE) { + return one; + } + + let (exponent_is_one) = uint256_eq(exponent, one); + if (exponent_is_one != FALSE) { + return value; + } + + let (half_exponent, is_odd) = uint256_unsigned_div_rem(exponent, Uint256(2, 0)); + let pow = uint256_fast_exp(value, half_exponent); + + if (is_odd.low != FALSE) { + let (res, _) = uint256_mul(pow, pow); + let (res, _) = uint256_mul(res, value); + return res; + } + + let pow = uint256_fast_exp(value, half_exponent); + let (res, _) = uint256_mul(pow, pow); + return res; +} diff --git a/tests/src/kakarot/instructions/test_stop_and_math_operations.py b/tests/src/kakarot/instructions/test_stop_and_math_operations.py index 33f8b85ec..94c174aa9 100644 --- a/tests/src/kakarot/instructions/test_stop_and_math_operations.py +++ b/tests/src/kakarot/instructions/test_stop_and_math_operations.py @@ -37,6 +37,12 @@ class TestMathOperations: (Opcodes.MULMOD, [3, 2, 2], (3 * 2) % 2), (Opcodes.EXP, [3, 2], (3**2)), (Opcodes.EXP, [3, 1], (3**1)), + (Opcodes.EXP, [3, 0], (3**0)), + ( + Opcodes.EXP, + [0xFF, 0x11], + (0xFF**0x11), + ), ( Opcodes.SIGNEXTEND, [