diff --git a/cli/src/main.rs b/cli/src/main.rs index f37bdf544..675c9a41a 100644 --- a/cli/src/main.rs +++ b/cli/src/main.rs @@ -25,7 +25,7 @@ use starknet_in_rust::{ serde_structs::read_abi, services::api::contract_classes::deprecated_contract_class::ContractClass, state::{cached_state::CachedState, state_api::State}, - state::{in_memory_state_reader::InMemoryStateReader, ExecutionResourcesManager}, + state::{in_memory_state_reader::InMemoryStateReader, ExecutionResourcesManager, StateDiff}, transaction::{error::TransactionError, InvokeFunction}, utils::{felt_to_hash, string_to_hash, Address}, }; @@ -195,7 +195,9 @@ fn invoke_parser( Some(Felt252::zero()), transaction_hash.unwrap(), )?; - let _tx_info = internal_invoke.apply(cached_state, &BlockContext::default(), 0)?; + let mut transactional_state = cached_state.create_transactional(); + let _tx_info = internal_invoke.apply(&mut transactional_state, &BlockContext::default(), 0)?; + cached_state.apply_state_update(&StateDiff::from_cached_state(transactional_state)?)?; let tx_hash = calculate_transaction_hash_common( TransactionHashPrefix::Invoke, diff --git a/rpc_state_reader/src/lib.rs b/rpc_state_reader/src/lib.rs index 2281aae61..6f50716f9 100644 --- a/rpc_state_reader/src/lib.rs +++ b/rpc_state_reader/src/lib.rs @@ -530,6 +530,7 @@ mod tests { rpc_state.get_transaction(tx_hash); } + #[ignore] #[test] fn test_get_block_info() { let rpc_state = RpcState::new( diff --git a/src/execution/execution_entry_point.rs b/src/execution/execution_entry_point.rs index b7116d1f1..7559690ba 100644 --- a/src/execution/execution_entry_point.rs +++ b/src/execution/execution_entry_point.rs @@ -2,7 +2,6 @@ use crate::services::api::contract_classes::deprecated_contract_class::{ ContractEntryPoint, EntryPointType, }; use crate::state::cached_state::CachedState; -use crate::state::StateDiff; use crate::{ definitions::{block_context::BlockContext, constants::DEFAULT_ENTRY_POINT_SELECTOR}, runner::StarknetRunner, @@ -126,15 +125,8 @@ impl ExecutionEntryPoint { }) } CompiledClass::Casm(contract_class) => { - let mut tmp_state = CachedState::new( - state.state_reader.clone(), - state.contract_classes.clone(), - state.casm_contract_classes.clone(), - ); - tmp_state.cache = state.cache.clone(); - match self._execute( - &mut tmp_state, + state, resources_manager, block_context, tx_execution_context, @@ -142,15 +134,11 @@ impl ExecutionEntryPoint { class_hash, support_reverted, ) { - Ok(call_info) => { - let state_diff = StateDiff::from_cached_state(tmp_state)?; - state.apply_state_update(&state_diff)?; - Ok(ExecutionResult { - call_info: Some(call_info), - revert_error: None, - n_reverted_steps: 0, - }) - } + Ok(call_info) => Ok(ExecutionResult { + call_info: Some(call_info), + revert_error: None, + n_reverted_steps: 0, + }), Err(e) => { if !support_reverted { return Err(e); diff --git a/src/execution/mod.rs b/src/execution/mod.rs index 3ff535150..f8afed7af 100644 --- a/src/execution/mod.rs +++ b/src/execution/mod.rs @@ -610,11 +610,11 @@ impl TransactionExecutionInfo { Ok(sorted_messages) } - pub fn to_revert_error(self, revert_error: String) -> Self { + pub fn to_revert_error(self, revert_error: &str) -> Self { TransactionExecutionInfo { validate_info: None, call_info: None, - revert_error: Some(revert_error), + revert_error: Some(revert_error.to_string()), fee_transfer_info: None, ..self } diff --git a/src/lib.rs b/src/lib.rs index 6dc8c24fa..e32069441 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -297,7 +297,7 @@ mod test { let transaction = Transaction::InvokeFunction(invoke_function); let estimated_fee = estimate_fee(&[transaction], state, &block_context).unwrap(); - assert_eq!(estimated_fee[0], (2483, 2448)); + assert_eq!(estimated_fee[0], (1259, 1224)); } #[test] @@ -1014,7 +1014,7 @@ mod test { assert_eq!( estimate_fee(&[deploy, invoke_tx], state, block_context,).unwrap(), - [(0, 3672), (0, 2448)] + [(0, 3672), (0, 1224)] ); } diff --git a/src/state/cached_state.rs b/src/state/cached_state.rs index a17f99089..f9828b6c0 100644 --- a/src/state/cached_state.rs +++ b/src/state/cached_state.rs @@ -82,15 +82,16 @@ impl CachedState { .ok_or(StateError::MissingCasmClassCache) } - pub fn create_copy(&self) -> Self { - let mut state = CachedState::new( - self.state_reader.clone(), - self.contract_classes.clone(), - self.casm_contract_classes.clone(), - ); - state.cache = self.cache.clone(); - - state + /// Creates a copy of this state with an empty cache for saving changes and applying them + /// later. + pub fn create_transactional(&self) -> TransactionalCachedState { + let state_reader = Arc::new(TransactionalCachedStateReader::new(self)); + CachedState { + state_reader, + cache: Default::default(), + contract_classes: Default::default(), + casm_contract_classes: Default::default(), + } } } @@ -471,10 +472,10 @@ impl State for CachedState { match contract { CompiledClass::Casm(ref class) => { // We call this method instead of state_reader's in order to update the cache's class_hash_initial_values map - let compiled_class_hash = self.get_compiled_class_hash(class_hash)?; + //let compiled_class_hash = self.get_compiled_class_hash(class_hash)?; self.casm_contract_classes .as_mut() - .and_then(|m| m.insert(compiled_class_hash, *class.clone())); + .and_then(|m| m.insert(*class_hash, *class.clone())); } CompiledClass::Deprecated(ref contract) => { self.set_contract_class(class_hash, &contract.clone())? @@ -484,6 +485,196 @@ impl State for CachedState { } } +/// A CachedState which has access to another, "parent" state, used for executing transactions +/// without commiting changes to the parent. +pub type TransactionalCachedState<'a, T> = CachedState>; + +impl<'a, T: StateReader> TransactionalCachedState<'a, T> { + pub fn count_actual_storage_changes(&mut self) -> Result<(usize, usize), StateError> { + let storage_updates = subtract_mappings( + self.cache.storage_writes.clone(), + self.cache.storage_initial_values.clone(), + ); + + let n_modified_contracts = { + let storage_unique_updates = storage_updates.keys().map(|k| k.0.clone()); + + let class_hash_updates: Vec<_> = subtract_mappings( + self.cache.class_hash_writes.clone(), + self.cache.class_hash_initial_values.clone(), + ) + .keys() + .cloned() + .collect(); + + let nonce_updates: Vec<_> = subtract_mappings( + self.cache.nonce_writes.clone(), + self.cache.nonce_initial_values.clone(), + ) + .keys() + .cloned() + .collect(); + + let mut modified_contracts: HashSet
= HashSet::new(); + modified_contracts.extend(storage_unique_updates); + modified_contracts.extend(class_hash_updates); + modified_contracts.extend(nonce_updates); + + modified_contracts.len() + }; + + Ok((n_modified_contracts, storage_updates.len())) + } +} + +/// State reader used for transactional states which allows to check the parent state's cache and +/// state reader if a transactional cache miss happens. +/// +/// In practice this will act as a way to access the parent state's cache and other fields, +/// without referencing the whole parent state, so there's no need to adapt state-modifying +/// functions in the case that a transactional state is needed. +#[derive(Debug, MutGetters, Getters, PartialEq, Clone)] +pub struct TransactionalCachedStateReader<'a, T: StateReader> { + /// The parent state's state_reader + #[get(get = "pub")] + pub(crate) state_reader: Arc, + /// The parent state's cache + #[get(get = "pub")] + pub(crate) cache: &'a StateCache, + /// The parent state's contract_classes + #[get(get = "pub")] + pub(crate) contract_classes: Option, + /// The parent state's casm_contract_classes + #[get(get = "pub")] + pub(crate) casm_contract_classes: Option, +} + +impl<'a, T: StateReader> TransactionalCachedStateReader<'a, T> { + fn new(state: &'a CachedState) -> Self { + Self { + state_reader: state.state_reader.clone(), + cache: &state.cache, + contract_classes: state.contract_classes.clone(), + casm_contract_classes: state.casm_contract_classes.clone(), + } + } +} + +impl<'a, T: StateReader> StateReader for TransactionalCachedStateReader<'a, T> { + fn get_class_hash_at(&self, contract_address: &Address) -> Result { + if self.cache.get_class_hash(contract_address).is_none() { + match self.state_reader.get_class_hash_at(contract_address) { + Ok(class_hash) => { + return Ok(class_hash); + } + Err(StateError::NoneContractState(_)) => { + return Ok([0; 32]); + } + Err(e) => { + return Err(e); + } + } + } + + self.cache + .get_class_hash(contract_address) + .ok_or_else(|| StateError::NoneClassHash(contract_address.clone())) + .cloned() + } + + fn get_nonce_at(&self, contract_address: &Address) -> Result { + if self.cache.get_nonce(contract_address).is_none() { + return self.state_reader.get_nonce_at(contract_address); + } + self.cache + .get_nonce(contract_address) + .ok_or_else(|| StateError::NoneNonce(contract_address.clone())) + .cloned() + } + + fn get_storage_at(&self, storage_entry: &StorageEntry) -> Result { + if self.cache.get_storage(storage_entry).is_none() { + match self.state_reader.get_storage_at(storage_entry) { + Ok(storage) => { + return Ok(storage); + } + Err( + StateError::EmptyKeyInStorage + | StateError::NoneStoragLeaf(_) + | StateError::NoneStorage(_) + | StateError::NoneContractState(_), + ) => return Ok(Felt252::zero()), + Err(e) => { + return Err(e); + } + } + } + + self.cache + .get_storage(storage_entry) + .ok_or_else(|| StateError::NoneStorage(storage_entry.clone())) + .cloned() + } + + // TODO: check if that the proper way to store it (converting hash to address) + fn get_compiled_class_hash(&self, class_hash: &ClassHash) -> Result { + if self + .cache + .class_hash_to_compiled_class_hash + .get(class_hash) + .is_none() + { + return self.state_reader.get_compiled_class_hash(class_hash); + } + self.cache + .class_hash_to_compiled_class_hash + .get(class_hash) + .ok_or_else(|| StateError::NoneCompiledClass(*class_hash)) + .cloned() + } + + fn get_contract_class(&self, class_hash: &ClassHash) -> Result { + // This method can receive both compiled_class_hash & class_hash and return both casm and deprecated contract classes + //, which can be on the cache or on the state_reader, different cases will be described below: + if class_hash == UNINITIALIZED_CLASS_HASH { + return Err(StateError::UninitiaizedClassHash); + } + // I: FETCHING FROM CACHE + // I: DEPRECATED CONTRACT CLASS + // deprecated contract classes dont have compiled class hashes, so we only have one case + if let Some(compiled_class) = self + .contract_classes + .as_ref() + .and_then(|x| x.get(class_hash)) + { + return Ok(CompiledClass::Deprecated(Box::new(compiled_class.clone()))); + } + // I: CASM CONTRACT CLASS : COMPILED_CLASS_HASH + if let Some(compiled_class) = self + .casm_contract_classes + .as_ref() + .and_then(|x| x.get(class_hash)) + { + return Ok(CompiledClass::Casm(Box::new(compiled_class.clone()))); + } + // I: CASM CONTRACT CLASS : CLASS_HASH + if let Some(compiled_class_hash) = + self.cache.class_hash_to_compiled_class_hash.get(class_hash) + { + if let Some(casm_class) = &mut self + .casm_contract_classes + .as_ref() + .and_then(|m| m.get(compiled_class_hash)) + { + return Ok(CompiledClass::Casm(Box::new(casm_class.clone()))); + } + } + // II: FETCHING FROM STATE_READER + let contract = self.state_reader.get_contract_class(class_hash)?; + Ok(contract) + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/src/state/in_memory_state_reader.rs b/src/state/in_memory_state_reader.rs index bad842c38..e71445cbb 100644 --- a/src/state/in_memory_state_reader.rs +++ b/src/state/in_memory_state_reader.rs @@ -32,7 +32,7 @@ pub struct InMemoryStateReader { #[getset(get_mut = "pub")] pub(crate) casm_contract_classes: CasmClassCache, #[getset(get_mut = "pub")] - pub(crate) class_hash_to_compiled_class_hash: HashMap, + pub class_hash_to_compiled_class_hash: HashMap, } impl InMemoryStateReader { diff --git a/src/syscalls/deprecated_syscall_handler.rs b/src/syscalls/deprecated_syscall_handler.rs index d131d5ff1..e14f3a83b 100644 --- a/src/syscalls/deprecated_syscall_handler.rs +++ b/src/syscalls/deprecated_syscall_handler.rs @@ -228,6 +228,7 @@ mod tests { use super::*; use crate::services::api::contract_classes::deprecated_contract_class::EntryPointType; + use crate::state::StateDiff; use crate::{ add_segments, allocate_selector, any_box, definitions::{ @@ -1188,9 +1189,14 @@ mod tests { ) .unwrap(); + let mut transactional = state.create_transactional(); // Invoke result let result = internal_invoke_function - .apply(&mut state, &BlockContext::default(), 0) + .apply(&mut transactional, &BlockContext::default(), 0) + .unwrap(); + + state + .apply_state_update(&StateDiff::from_cached_state(transactional).unwrap()) .unwrap(); let result_call_info = result.call_info.unwrap(); diff --git a/src/testing/state.rs b/src/testing/state.rs index 28c1ef84a..b0415e74e 100644 --- a/src/testing/state.rs +++ b/src/testing/state.rs @@ -574,7 +574,7 @@ mod tests { .unwrap(); let actual_resources = HashMap::from([ ("n_steps".to_string(), 3457), - ("l1_gas_usage".to_string(), 2448), + ("l1_gas_usage".to_string(), 1224), ("range_check_builtin".to_string(), 80), ("pedersen_builtin".to_string(), 16), ]); diff --git a/src/transaction/deploy_account.rs b/src/transaction/deploy_account.rs index 2833004e1..bf7bf8b10 100644 --- a/src/transaction/deploy_account.rs +++ b/src/transaction/deploy_account.rs @@ -1,9 +1,10 @@ -use super::fee::charge_fee; +use super::fee::{calculate_tx_fee, charge_fee}; use super::{invoke_function::verify_no_calls_to_other_contracts, Transaction}; use crate::definitions::constants::QUERY_VERSION_BASE; use crate::execution::execution_entry_point::ExecutionResult; use crate::services::api::contract_classes::deprecated_contract_class::EntryPointType; use crate::state::cached_state::CachedState; +use crate::state::StateDiff; use crate::{ core::{ errors::state_errors::StateError, @@ -156,22 +157,46 @@ impl DeployAccount { block_context: &BlockContext, ) -> Result { self.handle_nonce(state)?; - let mut tx_info = self.apply(state, block_context)?; + + let mut transactional_state = state.create_transactional(); + let mut tx_exec_info = self.apply(&mut transactional_state, block_context)?; + + let actual_fee = calculate_tx_fee( + &tx_exec_info.actual_resources, + block_context.starknet_os_config.gas_price, + block_context, + )?; + + if let Some(revert_error) = tx_exec_info.revert_error.clone() { + // execution error + tx_exec_info = tx_exec_info.to_revert_error(&revert_error); + } else if actual_fee > self.max_fee { + // max_fee exceeded + tx_exec_info = tx_exec_info.to_revert_error( + format!( + "Calculated fee ({}) exceeds max fee ({})", + actual_fee, self.max_fee + ) + .as_str(), + ); + } else { + state.apply_state_update(&StateDiff::from_cached_state(transactional_state)?)?; + } let mut tx_execution_context = self.get_execution_context(block_context.invoke_tx_max_n_steps); let (fee_transfer_info, actual_fee) = charge_fee( state, - &tx_info.actual_resources, + &tx_exec_info.actual_resources, block_context, self.max_fee, &mut tx_execution_context, self.skip_fee_transfer, )?; - tx_info.set_fee_info(actual_fee, fee_transfer_info); + tx_exec_info.set_fee_info(actual_fee, fee_transfer_info); - Ok(tx_info) + Ok(tx_exec_info) } fn constructor_entry_points_empty( diff --git a/src/transaction/invoke_function.rs b/src/transaction/invoke_function.rs index 49fcc17ec..3f376648c 100644 --- a/src/transaction/invoke_function.rs +++ b/src/transaction/invoke_function.rs @@ -11,7 +11,10 @@ use crate::{ execution_entry_point::{ExecutionEntryPoint, ExecutionResult}, CallInfo, TransactionExecutionContext, TransactionExecutionInfo, }, - state::{cached_state::CachedState, ExecutionResourcesManager}, + state::{ + cached_state::{CachedState, TransactionalCachedState}, + ExecutionResourcesManager, + }, state::{ state_api::{State, StateReader}, StateDiff, @@ -229,7 +232,7 @@ impl InvokeFunction { /// - remaining_gas: The amount of gas that the transaction disposes. pub fn apply( &self, - state: &mut CachedState, + state: &mut TransactionalCachedState, block_context: &BlockContext, remaining_gas: u128, ) -> Result { @@ -251,7 +254,7 @@ impl InvokeFunction { remaining_gas, )? }; - let changes = state.count_actual_storage_changes(); + let changes = state.count_actual_storage_changes()?; let actual_resources = calculate_tx_resources( resources_manager, &vec![call_info.clone(), validate_info.clone()], @@ -286,7 +289,7 @@ impl InvokeFunction { self.handle_nonce(state)?; } - let mut transactional_state = state.create_copy(); + let mut transactional_state = state.create_transactional(); let mut tx_exec_info = self.apply(&mut transactional_state, block_context, remaining_gas)?; @@ -296,13 +299,20 @@ impl InvokeFunction { block_context, )?; - if actual_fee <= self.max_fee { - state.apply_state_update(&StateDiff::from_cached_state(transactional_state)?)?; + if let Some(revert_error) = tx_exec_info.revert_error.clone() { + // execution error + tx_exec_info = tx_exec_info.to_revert_error(&revert_error); + } else if actual_fee > self.max_fee { + // max_fee exceeded + tx_exec_info = tx_exec_info.to_revert_error( + format!( + "Calculated fee ({}) exceeds max fee ({})", + actual_fee, self.max_fee + ) + .as_str(), + ); } else { - tx_exec_info = tx_exec_info.to_revert_error(format!( - "Calculated fee ({}) exceeds max fee ({})", - actual_fee, self.max_fee - )); + state.apply_state_update(&StateDiff::from_cached_state(transactional_state)?)?; } let mut tx_execution_context = @@ -482,8 +492,13 @@ mod tests { .set_contract_class(&class_hash, &contract_class) .unwrap(); + let mut transactional = state.create_transactional(); + // Invoke result let result = internal_invoke_function - .apply(&mut state, &BlockContext::default(), 0) + .apply(&mut transactional, &BlockContext::default(), 0) + .unwrap(); + state + .apply_state_update(&StateDiff::from_cached_state(transactional).unwrap()) .unwrap(); assert_eq!(result.tx_type, Some(TransactionType::InvokeFunction)); @@ -616,8 +631,9 @@ mod tests { .set_contract_class(&class_hash, &contract_class) .unwrap(); + let mut transactional = state.create_transactional(); let expected_error = - internal_invoke_function.apply(&mut state, &BlockContext::default(), 0); + internal_invoke_function.apply(&mut transactional, &BlockContext::default(), 0); assert!(expected_error.is_err()); assert_matches!( @@ -675,8 +691,13 @@ mod tests { .set_contract_class(&class_hash, &contract_class) .unwrap(); + let mut transactional = state.create_transactional(); + // Invoke result let result = internal_invoke_function - .apply(&mut state, &BlockContext::default(), 0) + .apply(&mut transactional, &BlockContext::default(), 0) + .unwrap(); + state + .apply_state_update(&StateDiff::from_cached_state(transactional).unwrap()) .unwrap(); assert_eq!(result.tx_type, Some(TransactionType::InvokeFunction)); @@ -740,8 +761,10 @@ mod tests { .set_contract_class(&class_hash, &contract_class) .unwrap(); + let mut transactional = state.create_transactional(); + // Invoke result let expected_error = - internal_invoke_function.apply(&mut state, &BlockContext::default(), 0); + internal_invoke_function.apply(&mut transactional, &BlockContext::default(), 0); assert!(expected_error.is_err()); assert_matches!(expected_error.unwrap_err(), TransactionError::MissingNonce); @@ -862,11 +885,14 @@ mod tests { let tx_info = internal_invoke_function .execute(&mut state, &block_context, 0) .unwrap(); - let expected_actual_fee = 2483; - let expected_tx_info = tx_info.clone().to_revert_error(format!( - "Calculated fee ({}) exceeds max fee ({})", - expected_actual_fee, max_fee - )); + let expected_actual_fee = 1259; + let expected_tx_info = tx_info.clone().to_revert_error( + format!( + "Calculated fee ({}) exceeds max fee ({})", + expected_actual_fee, max_fee + ) + .as_str(), + ); assert_eq!(tx_info, expected_tx_info); } @@ -1114,6 +1140,10 @@ mod tests { state_reader .address_to_nonce .insert(contract_address, nonce); + state_reader + .class_hash_to_compiled_class_hash + .insert(class_hash, class_hash); + // last is necessary so the transactional state can cache the class let mut casm_contract_class_cache = HashMap::new(); diff --git a/tests/internals.rs b/tests/internals.rs index 4328faed7..8308282f8 100644 --- a/tests/internals.rs +++ b/tests/internals.rs @@ -1224,7 +1224,7 @@ fn expected_transaction_execution_info(block_context: &BlockContext) -> Transact let resources = HashMap::from([ ("n_steps".to_string(), 3445), ("pedersen_builtin".to_string(), 16), - ("l1_gas_usage".to_string(), 2448), + ("l1_gas_usage".to_string(), 1224), ("range_check_builtin".to_string(), 82), ]); let fee = calculate_tx_fee(&resources, *GAS_PRICE, block_context).unwrap(); @@ -1253,7 +1253,7 @@ fn expected_fib_transaction_execution_info( } let resources = HashMap::from([ ("n_steps".to_string(), n_steps), - ("l1_gas_usage".to_string(), 7344), + ("l1_gas_usage".to_string(), 1224), ("pedersen_builtin".to_string(), 16), ("range_check_builtin".to_string(), 85), ]); @@ -1300,17 +1300,19 @@ fn test_invoke_tx_exceeded_max_fee() { Felt252::from(2), // CONTRACT_CALLDATA ]; let max_fee = 3; - let actual_fee = 2483; + let actual_fee = 1259; let invoke_tx = invoke_tx(calldata, max_fee); // Extract invoke transaction fields for testing, as it is consumed when creating an account // transaction. let result = invoke_tx.execute(state, block_context, 0).unwrap(); - let mut expected_result = - expected_transaction_execution_info(block_context).to_revert_error(format!( + let mut expected_result = expected_transaction_execution_info(block_context).to_revert_error( + format!( "Calculated fee ({}) exceeds max fee ({})", actual_fee, max_fee - )); + ) + .as_str(), + ); expected_result.set_fee_info(max_fee, Some(expected_fee_transfer_info(max_fee))); assert_eq!(result, expected_result); @@ -1493,7 +1495,7 @@ fn test_invoke_with_declarev2_tx() { fn test_deploy_account() { let (block_context, mut state) = create_account_tx_test_state().unwrap(); - let expected_fee = 6157; + let expected_fee = 3709; let deploy_account_tx = DeployAccount::new( felt_to_hash(&TEST_ACCOUNT_CONTRACT_CLASS_HASH), @@ -1571,12 +1573,12 @@ fn test_deploy_account() { ("n_steps".to_string(), 3625), ("range_check_builtin".to_string(), 83), ("pedersen_builtin".to_string(), 23), - ("l1_gas_usage".to_string(), 6120), + ("l1_gas_usage".to_string(), 3672), ]); let fee = calculate_tx_fee(&resources, *GAS_PRICE, &block_context).unwrap(); - assert_eq!(fee, 6157); + assert_eq!(fee, 3709); let expected_execution_info = TransactionExecutionInfo::new( expected_validate_call_info.into(), @@ -1606,11 +1608,161 @@ fn test_deploy_account() { assert_eq!(class_hash_from_state, *deploy_account_tx.class_hash()); } +#[test] +fn test_deploy_account_revert() { + let (block_context, mut state) = create_account_tx_test_state().unwrap(); + + let expected_fee = 1; + + let deploy_account_tx = DeployAccount::new( + felt_to_hash(&TEST_ACCOUNT_CONTRACT_CLASS_HASH), + 1, + TRANSACTION_VERSION.clone(), + Default::default(), + Default::default(), + Default::default(), + Default::default(), + StarknetChainId::TestNet.to_felt(), + ) + .unwrap(); + + state.set_storage_at( + &( + block_context + .starknet_os_config() + .fee_token_address() + .clone(), + felt_to_hash(&TEST_ERC20_DEPLOYED_ACCOUNT_BALANCE_KEY), + ), + INITIAL_BALANCE.clone(), + ); + + let (mut state_before, mut state_after) = expected_deploy_account_states(); + + assert_eq!(&state.cache(), &state_before.cache()); + assert_eq!(&state.contract_classes(), &state_before.contract_classes()); + assert_eq!( + &state.casm_contract_classes(), + &state_before.casm_contract_classes() + ); + + let tx_info = deploy_account_tx + .execute(&mut state, &block_context) + .unwrap(); + + assert_eq!( + state.casm_contract_classes(), + state_before.casm_contract_classes() + ); + + let mut state_reverted = state_before.clone(); + + // Add initial writes (these 'bypass' the transactional state because it's a state reader and + // it will cache initial values when looking for them). + state_reverted + .cache_mut() + .class_hash_initial_values_mut() + .extend( + state_after + .cache_mut() + .class_hash_initial_values_mut() + .clone(), + ); + state_reverted + .cache_mut() + .nonce_initial_values_mut() + .extend(state_after.cache_mut().nonce_initial_values_mut().clone()); + state_reverted + .cache_mut() + .storage_initial_values_mut() + .extend(state_after.cache_mut().storage_initial_values_mut().clone()); + state_reverted + .cache_mut() + .storage_initial_values_mut() + .extend(state_after.cache_mut().storage_initial_values_mut().clone()); + + // Set storage writes related to the fee transfer + state_reverted + .cache_mut() + .storage_writes_mut() + .extend(state_after.cache_mut().storage_writes().clone()); + state_reverted.set_storage_at( + &( + Address(0x1001.into()), + felt_to_hash(&TEST_ERC20_DEPLOYED_ACCOUNT_BALANCE_KEY), + ), + INITIAL_BALANCE.clone() - Felt252::one(), // minus the max fee that will be transfered + ); + state_reverted.cache_mut().storage_writes_mut().insert( + ( + Address(0x1001.into()), + felt_to_hash(&TEST_ERC20_SEQUENCER_BALANCE_KEY), + ), + Felt252::one(), // the max fee received by the sequencer + ); + + // Set nonce + state_reverted + .cache_mut() + .nonce_writes_mut() + .extend(state_after.cache_mut().nonce_writes_mut().clone()); + + assert_eq!(state.cache(), state_reverted.cache()); + + let expected_fee_transfer_call_info = expected_fee_transfer_call_info( + &block_context, + deploy_account_tx.contract_address(), + expected_fee as u64, + ); + + let resources = HashMap::from([ + ("n_steps".to_string(), 3625), + ("range_check_builtin".to_string(), 83), + ("pedersen_builtin".to_string(), 23), + ("l1_gas_usage".to_string(), 3672), + ]); + + let fee = calculate_tx_fee(&resources, *GAS_PRICE, &block_context).unwrap(); + + assert_eq!(fee, 3709); + + let mut expected_execution_info = TransactionExecutionInfo::new( + None, + None, + None, + None, + expected_fee, + // Entry **not** in blockifier. + // Default::default(), + resources, + TransactionType::DeployAccount.into(), + ) + .to_revert_error(format!("Calculated fee ({}) exceeds max fee ({})", 3709, 1).as_str()); + + expected_execution_info.set_fee_info(expected_fee, expected_fee_transfer_call_info.into()); + + assert_eq!(tx_info, expected_execution_info); + + let nonce_from_state = state + .get_nonce_at(deploy_account_tx.contract_address()) + .unwrap(); + assert_eq!(nonce_from_state, Felt252::one()); + + let hash = TEST_ERC20_DEPLOYED_ACCOUNT_BALANCE_KEY.to_be_bytes(); + + validate_final_balances(&mut state, &block_context, &hash, expected_fee); + + let class_hash_from_state = state + .get_class_hash_at(deploy_account_tx.contract_address()) + .unwrap(); + assert_eq!(class_hash_from_state, [0; 32]); +} + fn expected_deploy_account_states() -> ( CachedState, CachedState, ) { - let fee = Felt252::from(6157); + let fee = Felt252::from(3709); let mut state_before = CachedState::new( Arc::new(InMemoryStateReader::new( HashMap::from([ @@ -1669,15 +1821,6 @@ fn expected_deploy_account_states() -> ( .cache_mut() .class_hash_initial_values_mut() .insert(Address(0x1001.into()), felt_to_hash(&0x1010.into())); - state_after - .cache_mut() - .class_hash_initial_values_mut() - .insert( - Address(felt_str!( - "386181506763903095743576862849245034886954647214831045800703908858571591162" - )), - [0; 32], - ); state_after.cache_mut().storage_initial_values_mut().insert( ( Address(0x1001.into()),