diff --git a/src/lib.rs b/src/lib.rs index 222c125a7..ba1aa607c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,6 +1,8 @@ #![deny(warnings)] #![forbid(unsafe_code)] #![cfg_attr(coverage_nightly, feature(no_coverage))] +use std::collections::HashMap; + use crate::{ execution::{ execution_entry_point::ExecutionEntryPoint, CallType, TransactionExecutionContext, @@ -15,7 +17,6 @@ use crate::{ use definitions::block_context::BlockContext; use state::cached_state::CachedState; -use transaction::InvokeFunction; use transaction::L1Handler; use utils::Address; @@ -46,16 +47,18 @@ pub mod transaction; pub mod utils; pub fn simulate_transaction( - transaction: &InvokeFunction, + transaction: &Transaction, state: S, block_context: BlockContext, remaining_gas: u128, skip_validate: bool, skip_execute: bool, + skip_fee_transfer: bool, ) -> Result { + let mut cached_state = CachedState::new(state, None, Some(HashMap::new())); let tx_for_simulation = - transaction.create_for_simulation(transaction.clone(), skip_validate, skip_execute); - tx_for_simulation.simulate_transaction(state, block_context, remaining_gas) + transaction.create_for_simulation(skip_validate, skip_execute, skip_fee_transfer); + tx_for_simulation.execute(&mut cached_state, &block_context, remaining_gas) } /// Estimate the fee associated with transaction @@ -171,18 +174,34 @@ pub fn execute_transaction( #[cfg(test)] mod test { use std::collections::HashMap; + use std::fs::File; + use std::io::BufReader; use std::path::PathBuf; use crate::core::contract_address::compute_deprecated_class_hash; - use crate::definitions::block_context::StarknetChainId; - use crate::definitions::constants::EXECUTE_ENTRY_POINT_SELECTOR; + use crate::definitions::{ + block_context::StarknetChainId, + constants::{ + EXECUTE_ENTRY_POINT_SELECTOR, VALIDATE_DECLARE_ENTRY_POINT_SELECTOR, + VALIDATE_ENTRY_POINT_SELECTOR, + }, + transaction_type::TransactionType, + }; use crate::estimate_fee; use crate::estimate_message_fee; + use crate::hash_utils::calculate_contract_address; use crate::services::api::contract_classes::deprecated_contract_class::ContractClass; - use crate::testing::{create_account_tx_test_state, TEST_CONTRACT_ADDRESS, TEST_CONTRACT_PATH}; - use crate::transaction::{InvokeFunction, L1Handler, Transaction}; + use crate::state::state_api::State; + use crate::testing::{ + create_account_tx_test_state, TEST_ACCOUNT_CONTRACT_ADDRESS, TEST_CONTRACT_ADDRESS, + TEST_CONTRACT_PATH, TEST_FIB_COMPILED_CONTRACT_CLASS_HASH, + }; + use crate::transaction::{ + Declare, DeclareV2, Deploy, DeployAccount, InvokeFunction, L1Handler, Transaction, + }; use crate::utils::felt_to_hash; use cairo_lang_starknet::casm_contract_class::CasmContractClass; + use cairo_lang_starknet::contract_class::ContractClass as SierraContractClass; use cairo_vm::felt::{felt_str, Felt252}; use num_traits::{Num, One, Zero}; use starknet_contract_class::EntryPointType; @@ -198,6 +217,27 @@ mod test { utils::{Address, ClassHash}, }; + use lazy_static::lazy_static; + + lazy_static! { + // include_str! doesn't seem to work in CI + static ref CONTRACT_CLASS: ContractClass = ContractClass::try_from(BufReader::new(File::open( + "starknet_programs/account_without_validation.json", + ).unwrap())) + .unwrap(); + static ref CLASS_HASH: Felt252 = compute_deprecated_class_hash(&CONTRACT_CLASS).unwrap(); + static ref CLASS_HASH_BYTES: [u8; 32] = CLASS_HASH.clone().to_be_bytes(); + static ref SALT: Felt252 = felt_str!( + "2669425616857739096022668060305620640217901643963991674344872184515580705509" + ); + static ref CONTRACT_ADDRESS: Address = Address(calculate_contract_address(&SALT.clone(), &CLASS_HASH.clone(), &[], Address(0.into())).unwrap()); + static ref SIGNATURE: Vec = vec![ + felt_str!("3233776396904427614006684968846859029149676045084089832563834729503047027074"), + felt_str!("707039245213420890976709143988743108543645298941971188668773816813012281203"), + ]; + pub static ref TRANSACTION_VERSION: Felt252 = 1.into(); + } + #[test] fn estimate_fee_test() { let contract_class: ContractClass = @@ -365,7 +405,10 @@ mod test { .unwrap(); let block_context = BlockContext::default(); - let simul_invoke = invoke.create_for_simulation(invoke.clone(), true, false); + let Transaction::InvokeFunction(simul_invoke) = + invoke.create_for_simulation(true, false, false) else { + unreachable!() + }; let call_info = simul_invoke .run_validate_entrypoint( @@ -453,13 +496,22 @@ mod test { let block_context = BlockContext::default(); - let context = - simulate_transaction(&invoke, state_reader, block_context, 1000, false, true).unwrap(); + let context = simulate_transaction( + &Transaction::InvokeFunction(invoke), + state_reader, + block_context, + 1000, + false, + true, + true, + ) + .unwrap(); assert!(context.validate_info.is_some()); assert!(context.call_info.is_none()); assert!(context.fee_transfer_info.is_none()); } + #[test] fn test_skip_execute_and_validate_flags() { let path = PathBuf::from("starknet_programs/account_without_validation.json"); @@ -535,11 +587,302 @@ mod test { let block_context = BlockContext::default(); - let context = - simulate_transaction(&invoke, state_reader, block_context, 1000, true, true).unwrap(); + let context = simulate_transaction( + &Transaction::InvokeFunction(invoke), + state_reader, + block_context, + 1000, + true, + true, + true, + ) + .unwrap(); assert!(context.validate_info.is_none()); assert!(context.call_info.is_none()); assert!(context.fee_transfer_info.is_none()); } + + #[test] + fn test_simulate_deploy() { + let state_reader = InMemoryStateReader::default(); + let mut state = CachedState::new(state_reader, Some(Default::default()), None); + + state + .set_contract_class(&CLASS_HASH_BYTES, &CONTRACT_CLASS) + .unwrap(); + + let block_context = Default::default(); + let salt = felt_str!( + "2669425616857739096022668060305620640217901643963991674344872184515580705509" + ); + // new consumes more execution time than raw struct instantiation + let internal_deploy = Transaction::Deploy( + Deploy::new( + salt, + CONTRACT_CLASS.clone(), + vec![], + StarknetChainId::TestNet.to_felt(), + 0.into(), + None, + ) + .unwrap(), + ); + + simulate_transaction( + &internal_deploy, + state, + block_context, + 100_000_000, + false, + false, + false, + ) + .unwrap(); + } + + #[test] + fn test_simulate_declare() { + let state_reader = InMemoryStateReader::default(); + let state = CachedState::new(state_reader, Some(Default::default()), None); + + let block_context = Default::default(); + + let class = CONTRACT_CLASS.clone(); + let address = CONTRACT_ADDRESS.clone(); + // new consumes more execution time than raw struct instantiation + let declare_tx = Transaction::Declare( + Declare::new( + class, + StarknetChainId::TestNet.to_felt(), + address, + 0, + 0.into(), + vec![], + Felt252::zero(), + None, + ) + .expect("couldn't create transaction"), + ); + + simulate_transaction( + &declare_tx, + state, + block_context, + 100_000_000, + false, + false, + false, + ) + .unwrap(); + } + + #[test] + fn test_simulate_invoke() { + let state_reader = InMemoryStateReader::default(); + let mut state = CachedState::new(state_reader, Some(Default::default()), None); + + state + .set_contract_class(&CLASS_HASH_BYTES, &CONTRACT_CLASS) + .unwrap(); + + let block_context = Default::default(); + + let salt = felt_str!( + "2669425616857739096022668060305620640217901643963991674344872184515580705509" + ); + let class = CONTRACT_CLASS.clone(); + let deploy = Deploy::new( + salt, + class, + vec![], + StarknetChainId::TestNet.to_felt(), + 0.into(), + None, + ) + .unwrap(); + + let _deploy_exec_info = deploy.execute(&mut state, &block_context).unwrap(); + + let selector = VALIDATE_ENTRY_POINT_SELECTOR.clone(); + let calldata = vec![ + CONTRACT_ADDRESS.0.clone(), + selector.clone(), + Felt252::zero(), + ]; + // new consumes more execution time than raw struct instantiation + let invoke_tx = Transaction::InvokeFunction( + InvokeFunction::new( + CONTRACT_ADDRESS.clone(), + selector, + 0, + TRANSACTION_VERSION.clone(), + calldata, + SIGNATURE.clone(), + StarknetChainId::TestNet.to_felt(), + Some(Felt252::zero()), + None, + ) + .unwrap(), + ); + + simulate_transaction( + &invoke_tx, + state, + block_context, + 100_000_000, + false, + false, + false, + ) + .unwrap(); + } + + #[test] + fn test_simulate_deploy_account() { + let state_reader = InMemoryStateReader::default(); + let mut state = CachedState::new(state_reader, Some(Default::default()), None); + + state + .set_contract_class(&CLASS_HASH_BYTES, &CONTRACT_CLASS) + .unwrap(); + + let block_context = Default::default(); + + // new consumes more execution time than raw struct instantiation + let deploy_account_tx = &Transaction::DeployAccount( + DeployAccount::new( + *CLASS_HASH_BYTES, + 0, + 0.into(), + Felt252::zero(), + vec![], + SIGNATURE.clone(), + SALT.clone(), + StarknetChainId::TestNet.to_felt(), + None, + ) + .unwrap(), + ); + + simulate_transaction( + deploy_account_tx, + state, + block_context, + 100_000_000, + false, + false, + false, + ) + .unwrap(); + } + + fn declarev2_tx() -> DeclareV2 { + let program_data = include_bytes!("../starknet_programs/cairo1/fibonacci.sierra"); + let sierra_contract_class: SierraContractClass = + serde_json::from_slice(program_data).unwrap(); + + DeclareV2 { + sender_address: TEST_ACCOUNT_CONTRACT_ADDRESS.clone(), + tx_type: TransactionType::Declare, + validate_entry_point_selector: VALIDATE_DECLARE_ENTRY_POINT_SELECTOR.clone(), + version: 1.into(), + max_fee: 2, + signature: vec![], + nonce: 0.into(), + hash_value: 0.into(), + compiled_class_hash: TEST_FIB_COMPILED_CONTRACT_CLASS_HASH.clone(), + sierra_contract_class, + casm_class: Default::default(), + skip_execute: false, + skip_fee_transfer: false, + skip_validate: false, + } + } + + #[test] + fn test_simulate_declare_v2() { + let (block_context, state) = create_account_tx_test_state().unwrap(); + let declare_tx = Transaction::DeclareV2(Box::new(declarev2_tx())); + + simulate_transaction( + &declare_tx, + state, + block_context, + 100_000_000, + false, + false, + false, + ) + .unwrap(); + } + + #[test] + fn test_simulate_l1_handler() { + let l1_handler_tx = Transaction::L1Handler( + L1Handler::new( + Address(0.into()), + Felt252::from_str_radix( + "c73f681176fc7b3f9693986fd7b14581e8d540519e27400e88b8713932be01", + 16, + ) + .unwrap(), + vec![ + Felt252::from_str_radix("8359E4B0152ed5A731162D3c7B0D8D56edB165A0", 16) + .unwrap(), + 1.into(), + 10.into(), + ], + 0.into(), + 0.into(), + Some(10000.into()), + ) + .unwrap(), + ); + + // Instantiate CachedState + let mut state_reader = InMemoryStateReader::default(); + // Set contract_class + let class_hash = [1; 32]; + let contract_class = + ContractClass::try_from(PathBuf::from("starknet_programs/l1l2.json")).unwrap(); + // Set contact_state + let contract_address = Address(0.into()); + let nonce = Felt252::zero(); + + state_reader + .address_to_class_hash_mut() + .insert(contract_address.clone(), class_hash); + state_reader + .address_to_nonce + .insert(contract_address, nonce); + + let mut state = CachedState::new(state_reader.clone(), None, None); + + // Initialize state.contract_classes + state.set_contract_classes(HashMap::new()).unwrap(); + + state + .set_contract_class(&class_hash, &contract_class) + .unwrap(); + + let mut block_context = BlockContext::default(); + block_context.cairo_resource_fee_weights = HashMap::from([ + (String::from("l1_gas_usage"), 0.into()), + (String::from("pedersen_builtin"), 16.into()), + (String::from("range_check_builtin"), 70.into()), + ]); + block_context.starknet_os_config.gas_price = 1; + + simulate_transaction( + &l1_handler_tx, + state, + block_context, + 100_000_000, + false, + false, + false, + ) + .unwrap(); + } } diff --git a/src/transaction/declare.rs b/src/transaction/declare.rs index 4123161d1..363100104 100644 --- a/src/transaction/declare.rs +++ b/src/transaction/declare.rs @@ -28,10 +28,12 @@ use num_traits::Zero; use starknet_contract_class::EntryPointType; use std::collections::HashMap; +use super::Transaction; + // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ /// Represents an internal transaction in the StarkNet network that is a declaration of a Cairo /// contract class. -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct Declare { pub class_hash: ClassHash, pub sender_address: Address, @@ -43,6 +45,9 @@ pub struct Declare { pub nonce: Felt252, pub hash_value: Felt252, pub contract_class: ContractClass, + pub skip_validate: bool, + pub skip_execute: bool, + pub skip_fee_transfer: bool, } // ------------------------------------------------------------ @@ -88,6 +93,9 @@ impl Declare { nonce, hash_value, contract_class, + skip_execute: false, + skip_validate: false, + skip_fee_transfer: false, }; internal_declare.verify_version()?; @@ -128,9 +136,11 @@ impl Declare { // validate transaction let mut resources_manager = ExecutionResourcesManager::default(); - let validate_info = - self.run_validate_entrypoint(state, &mut resources_manager, block_context)?; - + let validate_info = if self.skip_validate { + None + } else { + self.run_validate_entrypoint(state, &mut resources_manager, block_context)? + }; let changes = state.count_actual_storage_changes(); let actual_resources = calculate_tx_resources( resources_manager, @@ -222,10 +232,17 @@ impl Declare { let mut tx_execution_context = self.get_execution_context(block_context.invoke_tx_max_n_steps); - let fee_transfer_info = - execute_fee_transfer(state, block_context, &mut tx_execution_context, actual_fee)?; - - Ok((Some(fee_transfer_info), actual_fee)) + let fee_transfer_info = if self.skip_fee_transfer { + None + } else { + Some(execute_fee_transfer( + state, + block_context, + &mut tx_execution_context, + actual_fee, + )?) + }; + Ok((fee_transfer_info, actual_fee)) } fn handle_nonce(&self, state: &mut S) -> Result<(), TransactionError> { @@ -280,6 +297,22 @@ impl Declare { ), ) } + + pub(crate) fn create_for_simulation( + &self, + skip_validate: bool, + skip_execute: bool, + skip_fee_transfer: bool, + ) -> Transaction { + let tx = Declare { + skip_validate, + skip_execute, + skip_fee_transfer, + ..self.clone() + }; + + Transaction::Declare(tx) + } } // --------------- diff --git a/src/transaction/declare_v2.rs b/src/transaction/declare_v2.rs index 55b515d85..9416dcca0 100644 --- a/src/transaction/declare_v2.rs +++ b/src/transaction/declare_v2.rs @@ -24,7 +24,9 @@ use cairo_vm::felt::Felt252; use num_traits::Zero; use starknet_contract_class::EntryPointType; use std::collections::HashMap; -#[derive(Debug)] + +use super::Transaction; +#[derive(Debug, Clone)] pub struct DeclareV2 { pub sender_address: Address, pub tx_type: TransactionType, @@ -37,6 +39,9 @@ pub struct DeclareV2 { pub sierra_contract_class: SierraContractClass, pub hash_value: Felt252, pub casm_class: once_cell::unsync::OnceCell, + pub skip_validate: bool, + pub skip_execute: bool, + pub skip_fee_transfer: bool, } impl DeclareV2 { @@ -79,6 +84,9 @@ impl DeclareV2 { compiled_class_hash, hash_value, casm_class: Default::default(), + skip_execute: false, + skip_validate: false, + skip_fee_transfer: false, }; internal_declare.verify_version()?; @@ -147,10 +155,18 @@ impl DeclareV2 { let mut tx_execution_context = self.get_execution_context(block_context.invoke_tx_max_n_steps); - let fee_transfer_info = - execute_fee_transfer(state, block_context, &mut tx_execution_context, actual_fee)?; + let fee_transfer_info = if self.skip_fee_transfer { + None + } else { + Some(execute_fee_transfer( + state, + block_context, + &mut tx_execution_context, + actual_fee, + )?) + }; - Ok((Some(fee_transfer_info), actual_fee)) + Ok((fee_transfer_info, actual_fee)) } // TODO: delete once used @@ -185,17 +201,22 @@ impl DeclareV2 { let mut resources_manager = ExecutionResourcesManager::default(); - let (validate_info, _remaining_gas) = self.run_validate_entrypoint( - initial_gas, - state, - &mut resources_manager, - block_context, - )?; + let (validate_info, _remaining_gas) = if self.skip_validate { + (None, 0) + } else { + let (info, gas) = self.run_validate_entrypoint( + initial_gas, + state, + &mut resources_manager, + block_context, + )?; + (Some(info), gas) + }; let storage_changes = state.count_actual_storage_changes(); let actual_resources = calculate_tx_resources( resources_manager, - &[Some(validate_info.clone())], + &[validate_info.clone()], self.tx_type, storage_changes, None, @@ -207,7 +228,7 @@ impl DeclareV2 { self.charge_fee(state, &actual_resources, block_context)?; let concurrent_exec_info = TransactionExecutionInfo::create_concurrent_stage_execution_info( - Some(validate_info), + validate_info, None, actual_resources, Some(self.tx_type), @@ -263,19 +284,41 @@ impl DeclareV2 { let mut tx_execution_context = self.get_execution_context(block_context.validate_max_n_steps); - let call_info = entry_point.execute( - state, - block_context, - resources_manager, - &mut tx_execution_context, - false, - )?; - + let call_info = if self.skip_execute { + None + } else { + Some(entry_point.execute( + state, + block_context, + resources_manager, + &mut tx_execution_context, + false, + )?) + }; + let call_info = verify_no_calls_to_other_contracts(&call_info)?; remaining_gas -= call_info.gas_consumed; - verify_no_calls_to_other_contracts(&call_info)?; Ok((call_info, remaining_gas)) } + + // --------------- + // Simulation + // --------------- + pub(crate) fn create_for_simulation( + &self, + skip_validate: bool, + skip_execute: bool, + skip_fee_transfer: bool, + ) -> Transaction { + let tx = DeclareV2 { + skip_validate, + skip_execute, + skip_fee_transfer, + ..self.clone() + }; + + Transaction::DeclareV2(Box::new(tx)) + } } #[cfg(test)] diff --git a/src/transaction/deploy.rs b/src/transaction/deploy.rs index 49ece6d72..dafb7bdeb 100644 --- a/src/transaction/deploy.rs +++ b/src/transaction/deploy.rs @@ -28,7 +28,9 @@ use cairo_vm::felt::Felt252; use num_traits::Zero; use starknet_contract_class::EntryPointType; -#[derive(Debug)] +use super::Transaction; + +#[derive(Debug, Clone)] pub struct Deploy { pub hash_value: Felt252, pub version: Felt252, @@ -37,6 +39,9 @@ pub struct Deploy { pub contract_hash: ClassHash, pub constructor_calldata: Vec, pub tx_type: TransactionType, + pub skip_validate: bool, + pub skip_execute: bool, + pub skip_fee_transfer: bool, } impl Deploy { @@ -77,6 +82,9 @@ impl Deploy { contract_hash, constructor_calldata, tx_type: TransactionType::Deploy, + skip_validate: false, + skip_execute: false, + skip_fee_transfer: false, }) } @@ -223,6 +231,25 @@ impl Deploy { ), ) } + + // --------------- + // Simulation + // --------------- + pub(crate) fn create_for_simulation( + &self, + skip_validate: bool, + skip_execute: bool, + skip_fee_transfer: bool, + ) -> Transaction { + let tx = Deploy { + skip_validate, + skip_execute, + skip_fee_transfer, + ..self.clone() + }; + + Transaction::Deploy(tx) + } } #[cfg(test)] diff --git a/src/transaction/deploy_account.rs b/src/transaction/deploy_account.rs index 1e9a4973e..6179e14f0 100644 --- a/src/transaction/deploy_account.rs +++ b/src/transaction/deploy_account.rs @@ -1,4 +1,4 @@ -use super::invoke_function::verify_no_calls_to_other_contracts; +use super::{invoke_function::verify_no_calls_to_other_contracts, Transaction}; use crate::{ core::{ errors::state_errors::StateError, @@ -58,6 +58,9 @@ pub struct DeployAccount { hash_value: Felt252, #[getset(get = "pub")] signature: Vec, + skip_validate: bool, + skip_execute: bool, + skip_fee_transfer: bool, } impl DeployAccount { @@ -104,6 +107,9 @@ impl DeployAccount { max_fee, hash_value, signature, + skip_execute: false, + skip_validate: false, + skip_fee_transfer: false, }) } @@ -169,8 +175,11 @@ impl DeployAccount { let constructor_call_info = self.handle_constructor(contract_class, state, block_context, &mut resources_manager)?; - let validate_info = - self.run_validate_entrypoint(state, &mut resources_manager, block_context)?; + let validate_info = if self.skip_validate { + None + } else { + self.run_validate_entrypoint(state, &mut resources_manager, block_context)? + }; let actual_resources = calculate_tx_resources( resources_manager, @@ -254,15 +263,19 @@ impl DeployAccount { INITIAL_GAS_COST, ); - let call_info = entry_point.execute( - state, - block_context, - resources_manager, - &mut self.get_execution_context(block_context.validate_max_n_steps), - false, - )?; + let call_info = if self.skip_execute { + None + } else { + Some(entry_point.execute( + state, + block_context, + resources_manager, + &mut self.get_execution_context(block_context.validate_max_n_steps), + false, + )?) + }; - verify_no_calls_to_other_contracts(&call_info) + let call_info = verify_no_calls_to_other_contracts(&call_info) .map_err(|_| TransactionError::InvalidContractCall)?; Ok(call_info) } @@ -309,18 +322,22 @@ impl DeployAccount { INITIAL_GAS_COST, ); - let call_info = call.execute( - state, - block_context, - resources_manager, - &mut self.get_execution_context(block_context.validate_max_n_steps), - false, - )?; + let call_info = if self.skip_execute { + None + } else { + Some(call.execute( + state, + block_context, + resources_manager, + &mut self.get_execution_context(block_context.validate_max_n_steps), + false, + )?) + }; verify_no_calls_to_other_contracts(&call_info) .map_err(|_| TransactionError::InvalidContractCall)?; - Ok(Some(call_info)) + Ok(call_info) } fn charge_fee( @@ -344,10 +361,34 @@ impl DeployAccount { let mut tx_execution_context = self.get_execution_context(block_context.invoke_tx_max_n_steps); - let fee_transfer_info = - execute_fee_transfer(state, block_context, &mut tx_execution_context, actual_fee)?; + let fee_transfer_info = if self.skip_fee_transfer { + None + } else { + Some(execute_fee_transfer( + state, + block_context, + &mut tx_execution_context, + actual_fee, + )?) + }; + + Ok((fee_transfer_info, actual_fee)) + } + + pub(crate) fn create_for_simulation( + &self, + skip_validate: bool, + skip_execute: bool, + skip_fee_transfer: bool, + ) -> Transaction { + let tx = DeployAccount { + skip_validate, + skip_execute, + skip_fee_transfer, + ..self.clone() + }; - Ok((Some(fee_transfer_info), actual_fee)) + Transaction::DeployAccount(tx) } } diff --git a/src/transaction/error.rs b/src/transaction/error.rs index 7243a730a..c330e3878 100644 --- a/src/transaction/error.rs +++ b/src/transaction/error.rs @@ -127,4 +127,6 @@ pub enum TransactionError { InvalidBlockTimestamp, #[error("{0:?}")] CustomError(String), + #[error("call info is None")] + CallInfoIsNone, } diff --git a/src/transaction/invoke_function.rs b/src/transaction/invoke_function.rs index bf9fec939..96ab2bf4d 100644 --- a/src/transaction/invoke_function.rs +++ b/src/transaction/invoke_function.rs @@ -11,11 +11,8 @@ use crate::{ execution_entry_point::ExecutionEntryPoint, CallInfo, TransactionExecutionContext, TransactionExecutionInfo, }, + state::state_api::{State, StateReader}, state::ExecutionResourcesManager, - state::{ - cached_state::CachedState, - state_api::{State, StateReader}, - }, transaction::{ error::TransactionError, fee::{calculate_tx_fee, execute_fee_transfer, FeeInfo}, @@ -28,6 +25,8 @@ use getset::Getters; use num_traits::Zero; use starknet_contract_class::EntryPointType; +use super::Transaction; + #[derive(Debug, Getters, Clone)] pub struct InvokeFunction { #[getset(get = "pub")] @@ -47,6 +46,7 @@ pub struct InvokeFunction { nonce: Option, skip_validation: bool, skip_execute: bool, + skip_fee_transfer: bool, } impl InvokeFunction { @@ -96,6 +96,7 @@ impl InvokeFunction { hash_value, skip_validation: false, skip_execute: false, + skip_fee_transfer: false, }) } @@ -148,15 +149,15 @@ impl InvokeFunction { 0, ); - let call_info = call.execute( + let call_info = Some(call.execute( state, block_context, resources_manager, &mut self.get_execution_context(block_context.validate_max_n_steps)?, false, - )?; + )?); - verify_no_calls_to_other_contracts(&call_info) + let call_info = verify_no_calls_to_other_contracts(&call_info) .map_err(|_| TransactionError::InvalidContractCall)?; Ok(Some(call_info)) @@ -256,7 +257,7 @@ impl InvokeFunction { let mut tx_execution_context = self.get_execution_context(block_context.invoke_tx_max_n_steps)?; - let fee_transfer_info = if self.skip_execute { + let fee_transfer_info = if self.skip_fee_transfer { None } else { Some(execute_fee_transfer( @@ -323,26 +324,18 @@ impl InvokeFunction { pub(crate) fn create_for_simulation( &self, - tx: InvokeFunction, skip_validation: bool, skip_execute: bool, - ) -> InvokeFunction { - InvokeFunction { + skip_fee_transfer: bool, + ) -> Transaction { + let tx = InvokeFunction { skip_validation, skip_execute, - ..tx - } - } + skip_fee_transfer, + ..self.clone() + }; - pub(crate) fn simulate_transaction( - &self, - state: S, - block_context: BlockContext, - remaining_gas: u128, - ) -> Result { - let mut cache_state = CachedState::new(state, None, None); - // init simulation - self.execute(&mut cache_state, &block_context, remaining_gas) + Transaction::InvokeFunction(tx) } } @@ -350,14 +343,17 @@ impl InvokeFunction { // Invoke internal functions utils // ------------------------------------ -pub fn verify_no_calls_to_other_contracts(call_info: &CallInfo) -> Result<(), TransactionError> { +pub fn verify_no_calls_to_other_contracts( + call_info: &Option, +) -> Result { + let call_info = call_info.clone().ok_or(TransactionError::CallInfoIsNone)?; let invoked_contract_address = call_info.contract_address.clone(); for internal_call in call_info.gen_call_topology() { if internal_call.contract_address != invoked_contract_address { return Err(TransactionError::UnauthorizedActionOnValidate); } } - Ok(()) + Ok(call_info) } // Performs validation on fields related to function invocation transaction. @@ -420,6 +416,7 @@ mod tests { nonce: Some(0.into()), skip_validation: false, skip_execute: false, + skip_fee_transfer: false, }; // Instantiate CachedState @@ -488,6 +485,7 @@ mod tests { nonce: Some(0.into()), skip_validation: false, skip_execute: false, + skip_fee_transfer: false, }; // Instantiate CachedState @@ -552,6 +550,7 @@ mod tests { nonce: Some(0.into()), skip_validation: false, skip_execute: false, + skip_fee_transfer: false, }; // Instantiate CachedState @@ -610,6 +609,7 @@ mod tests { nonce: None, skip_validation: false, skip_execute: false, + skip_fee_transfer: false, }; // Instantiate CachedState @@ -674,6 +674,7 @@ mod tests { nonce: None, skip_validation: false, skip_execute: false, + skip_fee_transfer: false, }; // Instantiate CachedState @@ -730,6 +731,7 @@ mod tests { nonce: Some(0.into()), skip_validation: false, skip_execute: false, + skip_fee_transfer: false, }; // Instantiate CachedState @@ -791,6 +793,7 @@ mod tests { nonce: Some(0.into()), skip_validation: false, skip_execute: false, + skip_fee_transfer: false, }; // Instantiate CachedState @@ -853,6 +856,7 @@ mod tests { nonce: Some(0.into()), skip_validation: false, skip_execute: false, + skip_fee_transfer: false, }; // Instantiate CachedState @@ -915,6 +919,7 @@ mod tests { nonce: None, skip_validation: false, skip_execute: false, + skip_fee_transfer: false, }; // Instantiate CachedState @@ -977,7 +982,7 @@ mod tests { internal_calls.push(internal_call); call_info.internal_calls = internal_calls; - let expected_error = verify_no_calls_to_other_contracts(&call_info); + let expected_error = verify_no_calls_to_other_contracts(&Some(call_info)); assert!(expected_error.is_err()); assert_matches!( diff --git a/src/transaction/l1_handler.rs b/src/transaction/l1_handler.rs index 71fa12702..fa2c37045 100644 --- a/src/transaction/l1_handler.rs +++ b/src/transaction/l1_handler.rs @@ -21,8 +21,10 @@ use crate::{ utils::{calculate_tx_resources, Address}, }; +use super::Transaction; + #[allow(dead_code)] -#[derive(Debug, Getters)] +#[derive(Debug, Getters, Clone)] pub struct L1Handler { #[getset(get = "pub")] hash_value: Felt252, @@ -32,6 +34,8 @@ pub struct L1Handler { calldata: Vec, nonce: Option, paid_fee_on_l1: Option, + skip_validate: bool, + skip_execute: bool, } impl L1Handler { @@ -61,6 +65,8 @@ impl L1Handler { calldata, nonce: Some(nonce), paid_fee_on_l1, + skip_execute: false, + skip_validate: false, }) } @@ -86,18 +92,22 @@ impl L1Handler { remaining_gas, ); - let call_info = entrypoint.execute( - state, - block_context, - &mut resources_manager, - &mut self.get_execution_context(block_context.invoke_tx_max_n_steps)?, - false, - )?; + let call_info = if self.skip_execute { + None + } else { + Some(entrypoint.execute( + state, + block_context, + &mut resources_manager, + &mut self.get_execution_context(block_context.invoke_tx_max_n_steps)?, + false, + )?) + }; let changes = state.count_actual_storage_changes(); let actual_resources = calculate_tx_resources( resources_manager, - &[Some(call_info.clone())], + &[call_info.clone()], TransactionType::L1Handler, changes, Some(self.get_payload_size()), @@ -126,7 +136,7 @@ impl L1Handler { Ok( TransactionExecutionInfo::create_concurrent_stage_execution_info( None, - Some(call_info), + call_info, actual_resources, Some(TransactionType::L1Handler), ), @@ -155,6 +165,19 @@ impl L1Handler { L1_HANDLER_VERSION.into(), )) } + pub(crate) fn create_for_simulation( + &self, + skip_validate: bool, + skip_execute: bool, + ) -> Transaction { + let tx = L1Handler { + skip_validate, + skip_execute, + ..self.clone() + }; + + Transaction::L1Handler(tx) + } } #[cfg(test)] diff --git a/src/transaction/mod.rs b/src/transaction/mod.rs index bdc850531..827b72621 100644 --- a/src/transaction/mod.rs +++ b/src/transaction/mod.rs @@ -64,4 +64,30 @@ impl Transaction { Transaction::L1Handler(tx) => tx.execute(state, block_context, remaining_gas), } } + + pub fn create_for_simulation( + &self, + skip_validate: bool, + skip_execute: bool, + skip_fee_transfer: bool, + ) -> Self { + match self { + Transaction::Declare(tx) => { + tx.create_for_simulation(skip_validate, skip_execute, skip_fee_transfer) + } + Transaction::DeclareV2(tx) => { + tx.create_for_simulation(skip_validate, skip_execute, skip_fee_transfer) + } + Transaction::Deploy(tx) => { + tx.create_for_simulation(skip_validate, skip_execute, skip_fee_transfer) + } + Transaction::DeployAccount(tx) => { + tx.create_for_simulation(skip_validate, skip_execute, skip_fee_transfer) + } + Transaction::InvokeFunction(tx) => { + tx.create_for_simulation(skip_validate, skip_execute, skip_fee_transfer) + } + Transaction::L1Handler(tx) => tx.create_for_simulation(skip_validate, skip_execute), + } + } } diff --git a/tests/internals.rs b/tests/internals.rs index af0262d8c..000ec53f8 100644 --- a/tests/internals.rs +++ b/tests/internals.rs @@ -691,6 +691,9 @@ fn declare_tx() -> Declare { signature: vec![], nonce: 0.into(), hash_value: 0.into(), + skip_execute: false, + skip_fee_transfer: false, + skip_validate: false, } } @@ -713,6 +716,9 @@ fn declarev2_tx() -> DeclareV2 { compiled_class_hash: TEST_FIB_COMPILED_CONTRACT_CLASS_HASH.clone(), sierra_contract_class, casm_class: Default::default(), + skip_execute: false, + skip_fee_transfer: false, + skip_validate: false, } } @@ -725,6 +731,9 @@ fn deploy_fib_syscall() -> Deploy { contract_hash: felt_to_hash(&TEST_FIB_COMPILED_CONTRACT_CLASS_HASH.clone()), constructor_calldata: Vec::new(), tx_type: TransactionType::Deploy, + skip_execute: false, + skip_fee_transfer: false, + skip_validate: false, } }