diff --git a/starknet_programs/cairo1/square_root_recursive.cairo b/starknet_programs/cairo1/square_root_recursive.cairo new file mode 100644 index 000000000..a38b7e1fe --- /dev/null +++ b/starknet_programs/cairo1/square_root_recursive.cairo @@ -0,0 +1,24 @@ +#[abi] +trait Math { + #[external] + fn square_root(n: felt252) -> felt252; +} + +#[contract] +mod SquareRoot { + use super::MathDispatcherTrait; + use super::MathLibraryDispatcher; + use starknet::ClassHash; + + #[external] + fn square_root_recursive(n: felt252, math_class_hash: ClassHash, n_iterations: u32) -> felt252 { + square_root_recursive_inner(n, math_class_hash, n_iterations) + } + + fn square_root_recursive_inner(n: felt252, math_class_hash: ClassHash, n_iterations: u32) -> felt252 { + if n_iterations == 0 { + return n; + } + square_root_recursive_inner(MathLibraryDispatcher {class_hash: math_class_hash}.square_root(n), math_class_hash, n_iterations - 1) + } +} diff --git a/starknet_programs/cairo1/wallet_wrapper.cairo b/starknet_programs/cairo1/wallet_wrapper.cairo index d3557a309..ca65cef90 100644 --- a/starknet_programs/cairo1/wallet_wrapper.cairo +++ b/starknet_programs/cairo1/wallet_wrapper.cairo @@ -22,4 +22,13 @@ mod WalletWrapper { fn increase_balance(amount: felt252, simple_wallet_contract_address: ContractAddress) { SimpleWalletDispatcher {contract_address: simple_wallet_contract_address}.increase_balance(amount) } + + #[external] + fn increase_balance_recursive(amount: felt252, simple_wallet_contract_address: ContractAddress) { + if amount == 0 { + return(); + } + SimpleWalletDispatcher {contract_address: simple_wallet_contract_address}.increase_balance(1); + increase_balance_recursive(amount - 1, simple_wallet_contract_address) + } } diff --git a/starknet_programs/cairo2/square_root_recursive.cairo b/starknet_programs/cairo2/square_root_recursive.cairo new file mode 100644 index 000000000..a960e572e --- /dev/null +++ b/starknet_programs/cairo2/square_root_recursive.cairo @@ -0,0 +1,37 @@ +use starknet::ClassHash; + +#[starknet::interface] +trait Math { + fn square_root(self: @TContractState, n: felt252) -> felt252; +} + +#[starknet::interface] +trait ISquareRoot { + fn square_root(self: @TContractState, n: felt252, math_class_hash: ClassHash, n_iterations: u32) -> felt252; +} + + +#[starknet::contract] +mod SquareRoot { + use super::MathDispatcherTrait; + use super::MathLibraryDispatcher; + use starknet::ClassHash; + + #[storage] + struct Storage{ + } + + #[external(v0)] + impl SquareRoot of super::ISquareRoot { + fn square_root(self: @ContractState, n: felt252, math_class_hash: ClassHash, n_iterations: u32) -> felt252 { + square_root_recursive_inner(n, math_class_hash, n_iterations) + } + } + + fn square_root_recursive_inner(n: felt252, math_class_hash: ClassHash, n_iterations: u32) -> felt252 { + if n_iterations == 0 { + return n; + } + square_root_recursive_inner(MathLibraryDispatcher {class_hash: math_class_hash}.square_root(n), math_class_hash, n_iterations - 1) + } +} diff --git a/starknet_programs/cairo2/wallet_wrapper.cairo b/starknet_programs/cairo2/wallet_wrapper.cairo index 5208f8f73..d4a4a241b 100644 --- a/starknet_programs/cairo2/wallet_wrapper.cairo +++ b/starknet_programs/cairo2/wallet_wrapper.cairo @@ -7,7 +7,8 @@ trait SimpleWallet { #[starknet::interface] trait IWalletWrapper { fn get_balance(self: @TContractState, simple_wallet_contract_address: starknet::ContractAddress) -> felt252; - fn increase_balance(ref self: TContractState, amount: felt252, simple_wallet_contract_address: starknet::ContractAddress); + fn increase_balance(ref self: TContractState, amount: felt252, simple_wallet_contract_address: starknet::ContractAddress); + fn increase_balance_recursive(ref self: TContractState, amount: felt252, simple_wallet_contract_address: starknet::ContractAddress); } #[starknet::contract] @@ -28,5 +29,16 @@ mod WalletWrapper { fn increase_balance(ref self: ContractState, amount: felt252, simple_wallet_contract_address: ContractAddress) { SimpleWalletDispatcher {contract_address: simple_wallet_contract_address}.increase_balance(amount) } + fn increase_balance_recursive(ref self: ContractState, amount: felt252, simple_wallet_contract_address: ContractAddress) { + increase_balance_recursive_inner(amount, simple_wallet_contract_address) + } + } + + fn increase_balance_recursive_inner(amount: felt252, simple_wallet_contract_address: ContractAddress) { + if amount == 0 { + return(); + } + SimpleWalletDispatcher {contract_address: simple_wallet_contract_address}.increase_balance(1); + increase_balance_recursive_inner(amount - 1, simple_wallet_contract_address) } } diff --git a/tests/cairo_1_syscalls.rs b/tests/cairo_1_syscalls.rs index 4cbe0db4e..7a7888335 100644 --- a/tests/cairo_1_syscalls.rs +++ b/tests/cairo_1_syscalls.rs @@ -11,7 +11,6 @@ use cairo_vm::{ use num_bigint::BigUint; use num_traits::{Num, One, Zero}; use pretty_assertions_sorted::{assert_eq, assert_eq_sorted}; -use starknet_in_rust::EntryPointType; use starknet_in_rust::{ definitions::{block_context::BlockContext, constants::TRANSACTION_VERSION}, execution::{ @@ -25,6 +24,7 @@ use starknet_in_rust::{ state::{in_memory_state_reader::InMemoryStateReader, ExecutionResourcesManager}, utils::{Address, ClassHash}, }; +use starknet_in_rust::{utils::calculate_sn_keccak, EntryPointType}; fn create_execute_extrypoint( address: Address, @@ -354,9 +354,10 @@ fn call_contract_storage_write_read() { let program_data = include_bytes!("../starknet_programs/cairo1/wallet_wrapper.casm"); let contract_class: CasmContractClass = serde_json::from_slice(program_data).unwrap(); - let entrypoints = contract_class.clone().entry_points_by_type; - let get_balance_entrypoint_selector = &entrypoints.external.get(1).unwrap().selector; - let increase_balance_entrypoint_selector = &entrypoints.external.get(0).unwrap().selector; + let get_balance_entrypoint_selector = + &BigUint::from_bytes_be(&calculate_sn_keccak("get_balance".as_bytes())); + let increase_balance_entrypoint_selector = + &BigUint::from_bytes_be(&calculate_sn_keccak("increase_balance".as_bytes())); // Create state reader with class hash data let mut contract_class_cache = HashMap::new(); @@ -3044,3 +3045,536 @@ fn keccak_syscall() { assert_eq!(retdata[0], Felt252::one()); } + +#[test] +fn library_call_recursive_50_calls() { + // Create program and entry point types for contract class + #[cfg(not(feature = "cairo_1_tests"))] + let program_data = include_bytes!("../starknet_programs/cairo2/square_root_recursive.casm"); + #[cfg(feature = "cairo_1_tests")] + let program_data = include_bytes!("../starknet_programs/cairo1/square_root_recursive.casm"); + + let contract_class: CasmContractClass = serde_json::from_slice(program_data).unwrap(); + let entrypoints = contract_class.clone().entry_points_by_type; + let entrypoint_selector = &entrypoints.external.get(0).unwrap().selector; + + // Create state reader with class hash data + let mut contract_class_cache = HashMap::new(); + + let address = Address(1111.into()); + let class_hash: ClassHash = [1; 32]; + let nonce = Felt252::zero(); + + contract_class_cache.insert(class_hash, CompiledClass::Casm(Arc::new(contract_class))); + let mut state_reader = InMemoryStateReader::default(); + state_reader + .address_to_class_hash_mut() + .insert(address.clone(), class_hash); + state_reader + .address_to_nonce_mut() + .insert(address.clone(), nonce); + + // Add lib contract to the state + + #[cfg(not(feature = "cairo_1_tests"))] + let lib_program_data = include_bytes!("../starknet_programs/cairo2/math_lib.casm"); + #[cfg(feature = "cairo_1_tests")] + let lib_program_data = include_bytes!("../starknet_programs/cairo1/math_lib.casm"); + + let lib_contract_class: CasmContractClass = serde_json::from_slice(lib_program_data).unwrap(); + + let lib_address = Address(1112.into()); + let lib_class_hash: ClassHash = [2; 32]; + let lib_nonce = Felt252::zero(); + + contract_class_cache.insert( + lib_class_hash, + CompiledClass::Casm(Arc::new(lib_contract_class)), + ); + state_reader + .address_to_class_hash_mut() + .insert(lib_address.clone(), lib_class_hash); + state_reader + .address_to_nonce_mut() + .insert(lib_address, lib_nonce); + + // Create state from the state_reader and contract cache. + let mut state = CachedState::new(Arc::new(state_reader), contract_class_cache); + + // Create an execution entry point + let calldata = [ + felt_str!("1125899906842624"), + Felt252::from_bytes_be(&lib_class_hash), + Felt252::from(50), + ] + .to_vec(); + let caller_address = Address(0000.into()); + let entry_point_type = EntryPointType::External; + + let exec_entry_point = ExecutionEntryPoint::new( + address, + calldata, + Felt252::new(entrypoint_selector.clone()), + caller_address, + entry_point_type, + Some(CallType::Delegate), + Some(class_hash), + u128::MAX, + ); + + // Execute the entrypoint + let block_context = BlockContext::default(); + let mut tx_execution_context = TransactionExecutionContext::new( + Address(0.into()), + Felt252::zero(), + Vec::new(), + 0, + 10.into(), + block_context.invoke_tx_max_n_steps(), + TRANSACTION_VERSION.clone(), + ); + let mut resources_manager = ExecutionResourcesManager::default(); + let expected_execution_resources_internal_call = ExecutionResources { + #[cfg(not(feature = "cairo_1_tests"))] + n_steps: 80, + #[cfg(feature = "cairo_1_tests")] + n_steps: 85, + n_memory_holes: 5, + builtin_instance_counter: HashMap::from([(RANGE_CHECK_BUILTIN_NAME.to_string(), 7)]), + }; + + let call_info = exec_entry_point + .execute( + &mut state, + &block_context, + &mut resources_manager, + &mut tx_execution_context, + false, + block_context.invoke_tx_max_n_steps(), + ) + .unwrap() + .call_info + .unwrap(); + + assert_eq!(call_info.internal_calls.len(), 50); + assert_eq!( + call_info.internal_calls[0], + CallInfo { + caller_address: Address(0.into()), + call_type: Some(CallType::Delegate), + contract_address: Address(1111.into()), + entry_point_selector: Some( + Felt252::from_str_radix( + "544923964202674311881044083303061611121949089655923191939299897061511784662", + 10, + ) + .unwrap(), + ), + entry_point_type: Some(EntryPointType::External), + calldata: vec![felt_str!("1125899906842624")], + retdata: [felt_str!("33554432")].to_vec(), + execution_resources: Some(expected_execution_resources_internal_call), + class_hash: Some(lib_class_hash), + gas_consumed: 0, + ..Default::default() + } + ); + assert_eq!(call_info.retdata, [1.into()].to_vec()); + assert!(!call_info.failure_flag); +} + +#[test] +fn call_contract_storage_write_read_recursive_50_calls() { + // Create program and entry point types for contract class + #[cfg(not(feature = "cairo_1_tests"))] + let program_data = include_bytes!("../starknet_programs/cairo2/wallet_wrapper.casm"); + #[cfg(feature = "cairo_1_tests")] + let program_data = include_bytes!("../starknet_programs/cairo1/wallet_wrapper.casm"); + + let contract_class: CasmContractClass = serde_json::from_slice(program_data).unwrap(); + let get_balance_entrypoint_selector = + &BigUint::from_bytes_be(&calculate_sn_keccak("get_balance".as_bytes())); + let increase_balance_entrypoint_selector = &BigUint::from_bytes_be(&calculate_sn_keccak( + "increase_balance_recursive".as_bytes(), + )); + + // Create state reader with class hash data + let mut contract_class_cache = HashMap::new(); + + let address = Address(1111.into()); + let class_hash: ClassHash = [1; 32]; + let nonce = Felt252::zero(); + + contract_class_cache.insert(class_hash, CompiledClass::Casm(Arc::new(contract_class))); + let mut state_reader = InMemoryStateReader::default(); + state_reader + .address_to_class_hash_mut() + .insert(address.clone(), class_hash); + state_reader + .address_to_nonce_mut() + .insert(address.clone(), nonce); + + // Add simple_wallet contract to the state + #[cfg(not(feature = "cairo_1_tests"))] + let simple_wallet_program_data = + include_bytes!("../starknet_programs/cairo2/simple_wallet.casm"); + #[cfg(feature = "cairo_1_tests")] + let simple_wallet_program_data = + include_bytes!("../starknet_programs/cairo1/simple_wallet.casm"); + + let simple_wallet_contract_class: CasmContractClass = + serde_json::from_slice(simple_wallet_program_data).unwrap(); + let simple_wallet_constructor_entrypoint_selector = simple_wallet_contract_class + .entry_points_by_type + .constructor + .get(0) + .unwrap() + .selector + .clone(); + + let simple_wallet_address = Address(1112.into()); + let simple_wallet_class_hash: ClassHash = [2; 32]; + let simple_wallet_nonce = Felt252::zero(); + + contract_class_cache.insert( + simple_wallet_class_hash, + CompiledClass::Casm(Arc::new(simple_wallet_contract_class)), + ); + state_reader + .address_to_class_hash_mut() + .insert(simple_wallet_address.clone(), simple_wallet_class_hash); + state_reader + .address_to_nonce_mut() + .insert(simple_wallet_address.clone(), simple_wallet_nonce); + + // Create state from the state_reader and contract cache. + let mut state = CachedState::new(Arc::new(state_reader), contract_class_cache); + + let block_context = BlockContext::default(); + let mut tx_execution_context = TransactionExecutionContext::new( + Address(0.into()), + Felt252::zero(), + Vec::new(), + 0, + 10.into(), + block_context.invoke_tx_max_n_steps(), + TRANSACTION_VERSION.clone(), + ); + + let mut resources_manager = ExecutionResourcesManager::default(); + + let create_execute_extrypoint = |selector: &BigUint, + calldata: Vec, + entry_point_type: EntryPointType, + class_hash: [u8; 32], + address: Address| + -> ExecutionEntryPoint { + ExecutionEntryPoint::new( + address, + calldata, + Felt252::new(selector.clone()), + Address(0000.into()), + entry_point_type, + Some(CallType::Delegate), + Some(class_hash), + u64::MAX.into(), + ) + }; + + // RUN SIMPLE_WALLET CONSTRUCTOR + // Create an execution entry point + let calldata = [25.into()].to_vec(); + let constructor_exec_entry_point = create_execute_extrypoint( + &simple_wallet_constructor_entrypoint_selector, + calldata, + EntryPointType::Constructor, + simple_wallet_class_hash, + simple_wallet_address.clone(), + ); + + // Run constructor entrypoint + constructor_exec_entry_point + .execute( + &mut state, + &block_context, + &mut resources_manager, + &mut tx_execution_context, + false, + block_context.invoke_tx_max_n_steps(), + ) + .unwrap(); + + // RUN GET_BALANCE + // Create an execution entry point + let calldata = [simple_wallet_address.0.clone()].to_vec(); + let get_balance_exec_entry_point = create_execute_extrypoint( + get_balance_entrypoint_selector, + calldata, + EntryPointType::External, + class_hash, + address.clone(), + ); + + // Run get_balance entrypoint + let call_info = get_balance_exec_entry_point + .execute( + &mut state, + &block_context, + &mut resources_manager, + &mut tx_execution_context, + false, + block_context.invoke_tx_max_n_steps(), + ) + .unwrap(); + assert_eq!(call_info.call_info.unwrap().retdata, [25.into()]); + + // RUN INCREASE_BALANCE + // Create an execution entry point + let calldata = [50.into(), simple_wallet_address.0.clone()].to_vec(); + let increase_balance_entry_point = create_execute_extrypoint( + increase_balance_entrypoint_selector, + calldata, + EntryPointType::External, + class_hash, + address.clone(), + ); + + // Run increase_balance entrypoint + let call_info = increase_balance_entry_point + .execute( + &mut state, + &block_context, + &mut resources_manager, + &mut tx_execution_context, + false, + block_context.invoke_tx_max_n_steps(), + ) + .unwrap() + .call_info + .unwrap(); + // Check that the recursive function did in fact call the simple_wallet contract 50 times + assert_eq!(call_info.internal_calls.len(), 50); + assert!(!call_info.failure_flag); + + // RUN GET_BALANCE + // Create an execution entry point + let calldata = [simple_wallet_address.0].to_vec(); + let get_balance_exec_entry_point = create_execute_extrypoint( + get_balance_entrypoint_selector, + calldata, + EntryPointType::External, + class_hash, + address, + ); + + // Run get_balance entrypoint + let call_info = get_balance_exec_entry_point + .execute( + &mut state, + &block_context, + &mut resources_manager, + &mut tx_execution_context, + false, + block_context.invoke_tx_max_n_steps(), + ) + .unwrap(); + assert_eq!(call_info.call_info.unwrap().retdata, [75.into()]) +} + +#[test] +fn call_contract_storage_write_read_recursive_100_calls() { + // Create program and entry point types for contract class + #[cfg(not(feature = "cairo_1_tests"))] + let program_data = include_bytes!("../starknet_programs/cairo2/wallet_wrapper.casm"); + #[cfg(feature = "cairo_1_tests")] + let program_data = include_bytes!("../starknet_programs/cairo1/wallet_wrapper.casm"); + + let contract_class: CasmContractClass = serde_json::from_slice(program_data).unwrap(); + let get_balance_entrypoint_selector = + &BigUint::from_bytes_be(&calculate_sn_keccak("get_balance".as_bytes())); + let increase_balance_entrypoint_selector = &BigUint::from_bytes_be(&calculate_sn_keccak( + "increase_balance_recursive".as_bytes(), + )); + + // Create state reader with class hash data + let mut contract_class_cache = HashMap::new(); + + let address = Address(1111.into()); + let class_hash: ClassHash = [1; 32]; + let nonce = Felt252::zero(); + + contract_class_cache.insert(class_hash, CompiledClass::Casm(Arc::new(contract_class))); + let mut state_reader = InMemoryStateReader::default(); + state_reader + .address_to_class_hash_mut() + .insert(address.clone(), class_hash); + state_reader + .address_to_nonce_mut() + .insert(address.clone(), nonce); + + // Add simple_wallet contract to the state + #[cfg(not(feature = "cairo_1_tests"))] + let simple_wallet_program_data = + include_bytes!("../starknet_programs/cairo2/simple_wallet.casm"); + #[cfg(feature = "cairo_1_tests")] + let simple_wallet_program_data = + include_bytes!("../starknet_programs/cairo1/simple_wallet.casm"); + + let simple_wallet_contract_class: CasmContractClass = + serde_json::from_slice(simple_wallet_program_data).unwrap(); + let simple_wallet_constructor_entrypoint_selector = simple_wallet_contract_class + .entry_points_by_type + .constructor + .get(0) + .unwrap() + .selector + .clone(); + + let simple_wallet_address = Address(1112.into()); + let simple_wallet_class_hash: ClassHash = [2; 32]; + let simple_wallet_nonce = Felt252::zero(); + + contract_class_cache.insert( + simple_wallet_class_hash, + CompiledClass::Casm(Arc::new(simple_wallet_contract_class)), + ); + state_reader + .address_to_class_hash_mut() + .insert(simple_wallet_address.clone(), simple_wallet_class_hash); + state_reader + .address_to_nonce_mut() + .insert(simple_wallet_address.clone(), simple_wallet_nonce); + + // Create state from the state_reader and contract cache. + let mut state = CachedState::new(Arc::new(state_reader), contract_class_cache); + + let block_context = BlockContext::default(); + let mut tx_execution_context = TransactionExecutionContext::new( + Address(0.into()), + Felt252::zero(), + Vec::new(), + 0, + 10.into(), + block_context.invoke_tx_max_n_steps(), + TRANSACTION_VERSION.clone(), + ); + + let mut resources_manager = ExecutionResourcesManager::default(); + + let create_execute_extrypoint = |selector: &BigUint, + calldata: Vec, + entry_point_type: EntryPointType, + class_hash: [u8; 32], + address: Address| + -> ExecutionEntryPoint { + ExecutionEntryPoint::new( + address, + calldata, + Felt252::new(selector.clone()), + Address(0000.into()), + entry_point_type, + Some(CallType::Delegate), + Some(class_hash), + u64::MAX.into(), + ) + }; + + // RUN SIMPLE_WALLET CONSTRUCTOR + // Create an execution entry point + let calldata = [25.into()].to_vec(); + let constructor_exec_entry_point = create_execute_extrypoint( + &simple_wallet_constructor_entrypoint_selector, + calldata, + EntryPointType::Constructor, + simple_wallet_class_hash, + simple_wallet_address.clone(), + ); + + // Run constructor entrypoint + constructor_exec_entry_point + .execute( + &mut state, + &block_context, + &mut resources_manager, + &mut tx_execution_context, + false, + block_context.invoke_tx_max_n_steps(), + ) + .unwrap(); + + // RUN GET_BALANCE + // Create an execution entry point + let calldata = [simple_wallet_address.0.clone()].to_vec(); + let get_balance_exec_entry_point = create_execute_extrypoint( + get_balance_entrypoint_selector, + calldata, + EntryPointType::External, + class_hash, + address.clone(), + ); + + // Run get_balance entrypoint + let call_info = get_balance_exec_entry_point + .execute( + &mut state, + &block_context, + &mut resources_manager, + &mut tx_execution_context, + false, + block_context.invoke_tx_max_n_steps(), + ) + .unwrap(); + assert_eq!(call_info.call_info.unwrap().retdata, [25.into()]); + + // RUN INCREASE_BALANCE + // Create an execution entry point + let calldata = [100.into(), simple_wallet_address.0.clone()].to_vec(); + let increase_balance_entry_point = create_execute_extrypoint( + increase_balance_entrypoint_selector, + calldata, + EntryPointType::External, + class_hash, + address.clone(), + ); + + // Run increase_balance entrypoint + let call_info = increase_balance_entry_point + .execute( + &mut state, + &block_context, + &mut resources_manager, + &mut tx_execution_context, + false, + block_context.invoke_tx_max_n_steps(), + ) + .unwrap() + .call_info + .unwrap(); + // Check that the recursive function did in fact call the simple_wallet contract 50 times + assert_eq!(call_info.internal_calls.len(), 100); + assert!(!call_info.failure_flag); + + // RUN GET_BALANCE + // Create an execution entry point + let calldata = [simple_wallet_address.0].to_vec(); + let get_balance_exec_entry_point = create_execute_extrypoint( + get_balance_entrypoint_selector, + calldata, + EntryPointType::External, + class_hash, + address, + ); + + // Run get_balance entrypoint + let call_info = get_balance_exec_entry_point + .execute( + &mut state, + &block_context, + &mut resources_manager, + &mut tx_execution_context, + false, + block_context.invoke_tx_max_n_steps(), + ) + .unwrap(); + assert_eq!(call_info.call_info.unwrap().retdata, [125.into()]) +}