diff --git a/Cargo.lock b/Cargo.lock index 738d67151..4a0a42c76 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2825,6 +2825,15 @@ dependencies = [ "hashbrown 0.12.3", ] +[[package]] +name = "lru" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a4a83fb7698b3643a0e34f9ae6f2e8f0178c0fd42f8b59d493aa271ff3a5bf21" +dependencies = [ + "hashbrown 0.14.1", +] + [[package]] name = "matchers" version = "0.1.0" @@ -3049,7 +3058,7 @@ checksum = "5f4e3bc495f6e95bc15a6c0c55ac00421504a5a43d09e3cc455d1fea7015581d" dependencies = [ "bitvec", "either", - "lru", + "lru 0.7.8", "num-bigint", "num-integer", "num-modular", @@ -4455,6 +4464,7 @@ dependencies = [ "hex", "keccak", "lazy_static", + "lru 0.11.1", "mimalloc", "num-bigint", "num-integer", diff --git a/Cargo.toml b/Cargo.toml index 89ff5782d..d14f87c95 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -63,6 +63,7 @@ tracing = "0.1.37" [dev-dependencies] assert_matches = "1.5.0" coverage-helper = "0.2.0" +lru = "0.11.0" pretty_assertions_sorted = "1.2.3" tracing-subscriber = { version = "0.3.17", features = ["env-filter"] } diff --git a/README.md b/README.md index 874fa1e9c..f1d176bd7 100644 --- a/README.md +++ b/README.md @@ -117,6 +117,54 @@ You can find an example on how to use the CLI [here](/docs/CLI_USAGE_EXAMPLE.md) ### Customization +#### Contract class cache behavior + +`starknet_in_rust` supports caching contracts in memory. Caching the contracts is useful for +avoiding excessive RPC API usage and keeping the contract class deserialization overhead to the +minimum. The project provides two builtin cache policies: null and permanent. The null cache behaves +as if there was no cache at all. The permanent cache caches everything in memory forever. + +In addition to those two, an example is provided that implements and uses an LRU cache policy. +Long-running applications should ideally implement a cache algorithm suited to their needs or +alternatively use our example's implementation to avoid spamming the API when using the null cache +or blowing the memory usage when running with the permanent cache. + +Customized cache policies may be used by implementing the `ContractClassCache` trait. Check out our +[LRU cache example](examples/lru_cache/main.rs) for more details. Updating the cache requires +manually merging the local state cache into the shared cache manually. This can be done by calling +the `drain_private_contract_class_cache` on the `CachedState` instance. + +```rs +// To use the null cache (aka. no cache at all), create the state as follows: +let cache = Arc::new(NullContractClassCache::default()); +let state1 = CachedState::new(state_reader.clone(), cache.clone()); +let state2 = CachedState::new(state_reader.clone(), cache.clone()); // Cache is reused. + +// Insert state usage here. + +// The null cache doesn't have any method to extend it since it has no data. +``` + +```rs +// If the permanent cache is preferred, then use `PermanentContractClassCache` instead: +let cache = Arc::new(PermanentContractClassCache::default()); +let state1 = CachedState::new(state_reader.clone(), cache.clone()); +let state2 = CachedState::new(state_reader.clone(), cache.clone()); // Cache is reused. + +// Insert state usage here. + +// Extend the shared cache with the states' contracts after using them. +cache.extend(state1.state.drain_private_contract_class_cache()); +cache.extend(state2.state.drain_private_contract_class_cache()); +``` + +#### Logging configuration + +This project uses the [`tracing`](https://crates.io/crates/tracing) crate as a library. Check out +its documentation for more information. + +### Testing + #### Logging configuration This project uses the [`tracing`](https://crates.io/crates/tracing) crate as a library. Check out diff --git a/bench/internals.rs b/bench/internals.rs index 3d89107ae..40bcb8211 100644 --- a/bench/internals.rs +++ b/bench/internals.rs @@ -15,12 +15,15 @@ use starknet_in_rust::{ services::api::contract_classes::{ compiled_class::CompiledClass, deprecated_contract_class::ContractClass, }, - state::in_memory_state_reader::InMemoryStateReader, state::{cached_state::CachedState, state_api::State}, + state::{ + contract_class_cache::PermanentContractClassCache, + in_memory_state_reader::InMemoryStateReader, + }, transaction::{declare::Declare, Deploy, DeployAccount, InvokeFunction}, utils::{Address, ClassHash}, }; -use std::{collections::HashMap, hint::black_box, sync::Arc}; +use std::{hint::black_box, sync::Arc}; #[cfg(feature = "cairo-native")] use std::{cell::RefCell, rc::Rc}; @@ -74,7 +77,10 @@ fn deploy_account( const RUNS: usize = 500; let state_reader = Arc::new(InMemoryStateReader::default()); - let mut state = CachedState::new(state_reader, HashMap::new()); + let mut state = CachedState::new( + state_reader, + Arc::new(PermanentContractClassCache::default()), + ); state .set_contract_class( @@ -86,7 +92,7 @@ fn deploy_account( let block_context = &Default::default(); for _ in 0..RUNS { - let mut state_copy = state.clone(); + let mut state_copy = state.clone_for_testing(); let class_hash = *CLASS_HASH; let signature = SIGNATURE.clone(); scope(|| { @@ -118,12 +124,15 @@ fn declare(#[cfg(feature = "cairo-native")] program_cache: Rc) = match native { + let (erc20_address, mut state): ( + Address, + CachedState, + ) = match native { true => { let erc20_sierra_class = include_bytes!("../starknet_programs/cairo2/erc20.sierra"); let sierra_contract_class: cairo_lang_starknet::contract_class::ContractClass = @@ -259,11 +263,11 @@ fn bench_erc20(executions: usize, native: bool) { let deploy_entrypoint_selector = &entrypoints.external.get(0).unwrap().selector; // insert deployer and erc20 classes into the cache. - contract_class_cache.insert( + contract_class_cache.set_contract_class( *DEPLOYER_CLASS_HASH, CompiledClass::Casm(Arc::new(erc20_deployer_class)), ); - contract_class_cache.insert(*ERC20_CLASS_HASH, erc20_contract_class); + contract_class_cache.set_contract_class(*ERC20_CLASS_HASH, erc20_contract_class); let mut state_reader = InMemoryStateReader::default(); // setup deployer nonce and address into the state reader @@ -275,7 +279,8 @@ fn bench_erc20(executions: usize, native: bool) { .insert(DEPLOYER_ADDRESS.clone(), Felt252::zero()); // Create state from the state_reader and contract cache. - let mut state = CachedState::new(Arc::new(state_reader), contract_class_cache); + let mut state = + CachedState::new(Arc::new(state_reader), Arc::new(contract_class_cache)); // deploy the erc20 contract by calling the deployer contract. @@ -338,11 +343,11 @@ fn bench_erc20(executions: usize, native: bool) { let deploy_entrypoint_selector = &entrypoints.external.get(0).unwrap().selector; // insert deployer and erc20 classes into the cache. - contract_class_cache.insert( + contract_class_cache.set_contract_class( *DEPLOYER_CLASS_HASH, CompiledClass::Casm(Arc::new(erc20_deployer_class)), ); - contract_class_cache.insert(*ERC20_CLASS_HASH, erc20_contract_class); + contract_class_cache.set_contract_class(*ERC20_CLASS_HASH, erc20_contract_class); let mut state_reader = InMemoryStateReader::default(); // setup deployer nonce and address into the state reader @@ -354,7 +359,8 @@ fn bench_erc20(executions: usize, native: bool) { .insert(DEPLOYER_ADDRESS.clone(), Felt252::zero()); // Create state from the state_reader and contract cache. - let mut state = CachedState::new(Arc::new(state_reader), contract_class_cache); + let mut state = + CachedState::new(Arc::new(state_reader), Arc::new(contract_class_cache)); // deploy the erc20 contract by calling the deployer contract. @@ -493,7 +499,7 @@ fn bench_erc20(executions: usize, native: bool) { for _ in 0..executions { let result = execute( - &mut state.clone(), + &mut state.clone_for_testing(), &account1_address, &erc20_address, &transfer_entrypoint_selector.clone(), @@ -510,7 +516,7 @@ fn bench_erc20(executions: usize, native: bool) { #[inline(never)] #[allow(clippy::too_many_arguments)] fn execute( - state: &mut CachedState, + state: &mut CachedState, caller_address: &Address, callee_address: &Address, selector: &Felt252, diff --git a/cli/src/main.rs b/cli/src/main.rs index 8ba639082..5d25952b8 100644 --- a/cli/src/main.rs +++ b/cli/src/main.rs @@ -27,12 +27,14 @@ use starknet_in_rust::{ compiled_class::CompiledClass, deprecated_contract_class::ContractClass, }, state::{cached_state::CachedState, state_api::State}, - state::{in_memory_state_reader::InMemoryStateReader, ExecutionResourcesManager, StateDiff}, + state::{ + contract_class_cache::PermanentContractClassCache, + in_memory_state_reader::InMemoryStateReader, ExecutionResourcesManager, StateDiff, + }, transaction::{error::TransactionError, InvokeFunction}, utils::{felt_to_hash, string_to_hash, Address}, }; use std::{ - collections::HashMap, path::PathBuf, sync::{Arc, Mutex}, }; @@ -102,11 +104,11 @@ struct DevnetArgs { } struct AppState { - cached_state: Mutex>, + cached_state: Mutex>, } fn declare_parser( - cached_state: &mut CachedState, + cached_state: &mut CachedState, args: &DeclareArgs, ) -> Result<(Felt252, Felt252), ParserError> { let contract_class = @@ -129,7 +131,7 @@ fn declare_parser( } fn deploy_parser( - cached_state: &mut CachedState, + cached_state: &mut CachedState, args: &DeployArgs, ) -> Result<(Felt252, Felt252), ParserError> { let constructor_calldata = match &args.inputs { @@ -155,7 +157,7 @@ fn deploy_parser( } fn invoke_parser( - cached_state: &mut CachedState, + cached_state: &mut CachedState, args: &InvokeArgs, ) -> Result<(Felt252, Felt252), ParserError> { let contract_address = Address( @@ -200,7 +202,7 @@ fn invoke_parser( Some(Felt252::zero()), transaction_hash.unwrap(), )?; - let mut transactional_state = cached_state.create_transactional(); + let mut transactional_state = cached_state.create_transactional()?; let _tx_info = internal_invoke.apply( &mut transactional_state, &BlockContext::default(), @@ -225,7 +227,7 @@ fn invoke_parser( } fn call_parser( - cached_state: &mut CachedState, + cached_state: &mut CachedState, args: &CallArgs, ) -> Result, ParserError> { let contract_address = Address( @@ -326,9 +328,9 @@ async fn call_req(data: web::Data, args: web::Json) -> HttpR pub async fn start_devnet(port: u16) -> Result<(), std::io::Error> { let cached_state = web::Data::new(AppState { - cached_state: Mutex::new(CachedState::::new( + cached_state: Mutex::new(CachedState::new( Arc::new(InMemoryStateReader::default()), - HashMap::new(), + Arc::new(PermanentContractClassCache::default()), )), }); diff --git a/examples/contract_execution/src/main.rs b/examples/contract_execution/src/main.rs index f92ec1f94..2b238d133 100644 --- a/examples/contract_execution/src/main.rs +++ b/examples/contract_execution/src/main.rs @@ -17,13 +17,14 @@ use starknet_in_rust::{ compiled_class::CompiledClass, deprecated_contract_class::ContractClass, }, state::{ - cached_state::CachedState, in_memory_state_reader::InMemoryStateReader, state_api::State, + cached_state::CachedState, contract_class_cache::PermanentContractClassCache, + in_memory_state_reader::InMemoryStateReader, state_api::State, }, transaction::{DeclareV2, DeployAccount, InvokeFunction}, utils::{calculate_sn_keccak, felt_to_hash, Address}, CasmContractClass, SierraContractClass, }; -use std::{collections::HashMap, fs::File, io::BufReader, path::Path, str::FromStr, sync::Arc}; +use std::{fs::File, io::BufReader, path::Path, str::FromStr, sync::Arc}; fn main() { // replace this with the path to your compiled contract @@ -67,7 +68,10 @@ fn test_contract( //* Initialize state //* -------------------------------------------- let state_reader = Arc::new(InMemoryStateReader::default()); - let mut state = CachedState::new(state_reader, HashMap::new()); + let mut state = CachedState::new( + state_reader, + Arc::new(PermanentContractClassCache::default()), + ); //* -------------------------------------------- //* Deploy deployer contract diff --git a/examples/lru_cache/main.rs b/examples/lru_cache/main.rs new file mode 100644 index 000000000..68615a85d --- /dev/null +++ b/examples/lru_cache/main.rs @@ -0,0 +1,155 @@ +// #![deny(warnings)] + +use cairo_vm::felt::Felt252; +use lru::LruCache; +use starknet_in_rust::{ + definitions::{block_context::BlockContext, constants::TRANSACTION_VERSION}, + services::api::contract_classes::{ + compiled_class::CompiledClass, deprecated_contract_class::ContractClass, + }, + state::{ + cached_state::CachedState, + contract_class_cache::{ContractClassCache, PermanentContractClassCache}, + in_memory_state_reader::InMemoryStateReader, + }, + transaction::{Declare, Deploy, InvokeFunction}, + utils::{calculate_sn_keccak, Address, ClassHash}, +}; +use std::{ + num::NonZeroUsize, + path::Path, + sync::{Arc, Mutex}, +}; + +fn main() { + let shared_cache = Arc::new(LruContractCache::new(NonZeroUsize::new(64).unwrap())); + + let ret_data = run_contract( + "starknet_programs/factorial.json", + "factorial", + [10.into()], + shared_cache, + ); + + println!("{ret_data:?}"); +} + +fn run_contract( + contract_path: impl AsRef, + entry_point: impl AsRef, + calldata: impl Into>, + contract_cache: Arc, +) -> Vec { + let block_context = BlockContext::default(); + let chain_id = block_context.starknet_os_config().chain_id().clone(); + let sender_address = Address(1.into()); + let signature = vec![]; + + let state_reader = Arc::new(InMemoryStateReader::default()); + let mut state = CachedState::new( + state_reader, + Arc::new(PermanentContractClassCache::default()), + ); + + let contract_class = ContractClass::from_path(contract_path.as_ref()).unwrap(); + + let declare_tx = Declare::new( + contract_class.clone(), + chain_id.clone(), + sender_address, + 0, + 0.into(), + signature.clone(), + 0.into(), + ) + .unwrap(); + + declare_tx + .execute( + &mut state, + &block_context, + #[cfg(feature = "cairo-native")] + None, + ) + .unwrap(); + + let deploy_tx = Deploy::new( + Default::default(), + contract_class, + Vec::new(), + block_context.starknet_os_config().chain_id().clone(), + TRANSACTION_VERSION.clone(), + ) + .unwrap(); + + deploy_tx + .execute( + &mut state, + &block_context, + #[cfg(feature = "cairo-native")] + None, + ) + .unwrap(); + + let entry_point_selector = + Felt252::from_bytes_be(&calculate_sn_keccak(entry_point.as_ref().as_bytes())); + + let invoke_tx = InvokeFunction::new( + deploy_tx.contract_address.clone(), + entry_point_selector, + 0, + TRANSACTION_VERSION.clone(), + calldata.into(), + signature, + chain_id, + Some(0.into()), + ) + .unwrap(); + + let invoke_tx_execution_info = invoke_tx + .execute( + &mut state, + &block_context, + 0, + #[cfg(feature = "cairo-native")] + None, + ) + .unwrap(); + + // Store the local cache changes into the shared cache. This updates the shared cache with all + // the contracts used on this state. + contract_cache.extend(state.drain_private_contract_class_cache().unwrap()); + + invoke_tx_execution_info.call_info.unwrap().retdata +} + +pub struct LruContractCache { + storage: Mutex>, +} + +impl LruContractCache { + pub fn new(cap: NonZeroUsize) -> Self { + Self { + storage: Mutex::new(LruCache::new(cap)), + } + } + + pub fn extend(&self, other: I) + where + I: IntoIterator, + { + other.into_iter().for_each(|(k, v)| { + self.storage.lock().unwrap().put(k, v); + }); + } +} + +impl ContractClassCache for LruContractCache { + fn get_contract_class(&self, class_hash: ClassHash) -> Option { + self.storage.lock().unwrap().get(&class_hash).cloned() + } + + fn set_contract_class(&self, class_hash: ClassHash, compiled_class: CompiledClass) { + self.storage.lock().unwrap().put(class_hash, compiled_class); + } +} diff --git a/fuzzer/src/main.rs b/fuzzer/src/main.rs index 486fdc098..75e06042b 100644 --- a/fuzzer/src/main.rs +++ b/fuzzer/src/main.rs @@ -3,8 +3,7 @@ #[macro_use] extern crate honggfuzz; -use cairo_vm::felt::Felt252; -use cairo_vm::vm::runners::cairo_runner::ExecutionResources; +use cairo_vm::{felt::Felt252, vm::runners::cairo_runner::ExecutionResources}; use num_traits::Zero; use starknet_in_rust::execution::execution_entry_point::ExecutionResult; use starknet_in_rust::utils::ClassHash; @@ -14,24 +13,21 @@ use starknet_in_rust::{ execution::{ execution_entry_point::ExecutionEntryPoint, CallInfo, CallType, TransactionExecutionContext, }, - services::api::contract_classes::deprecated_contract_class::ContractClass, - state::cached_state::CachedState, - state::{in_memory_state_reader::InMemoryStateReader, ExecutionResourcesManager}, + services::api::contract_classes::{ + compiled_class::CompiledClass, deprecated_contract_class::ContractClass, + }, + state::{ + cached_state::CachedState, + contract_class_cache::{ContractClassCache, PermanentContractClassCache}, + in_memory_state_reader::InMemoryStateReader, + ExecutionResourcesManager, + }, utils::{calculate_sn_keccak, Address}, }; - -use std::sync::Arc; use std::{ - collections::{HashMap, HashSet}, - path::PathBuf, + collections::HashSet, fs, path::PathBuf, process::Command, sync::Arc, thread, time::Duration, }; -use starknet_in_rust::services::api::contract_classes::compiled_class::CompiledClass; -use std::fs; -use std::process::Command; -use std::thread; -use std::time::Duration; - fn main() { println!("Starting fuzzer"); let mut iteration = 0; @@ -111,14 +107,14 @@ fn main() { //* Create state reader with class hash data //* -------------------------------------------- - let mut contract_class_cache = HashMap::new(); + let contract_class_cache = PermanentContractClassCache::default(); // ------------ contract data -------------------- let address = Address(1111.into()); let class_hash: ClassHash = ClassHash([1; 32]); - contract_class_cache.insert( + contract_class_cache.set_contract_class( class_hash, CompiledClass::Deprecated(Arc::new(contract_class)), ); @@ -131,7 +127,8 @@ fn main() { //* Create state with previous data //* --------------------------------------- - let mut state = CachedState::new(Arc::new(state_reader), contract_class_cache); + let mut state = + CachedState::new(Arc::new(state_reader), Arc::new(contract_class_cache)); //* ------------------------------------ //* Create execution entry point diff --git a/rpc_state_reader/tests/sir_tests.rs b/rpc_state_reader/tests/sir_tests.rs index 2cf08f75d..b90eb2324 100644 --- a/rpc_state_reader/tests/sir_tests.rs +++ b/rpc_state_reader/tests/sir_tests.rs @@ -23,10 +23,8 @@ use starknet_in_rust::{ execution::{CallInfo, TransactionExecutionInfo}, services::api::contract_classes::compiled_class::CompiledClass, state::{ - cached_state::{CachedState, ContractClassCache}, - state_api::StateReader, - state_cache::StorageEntry, - BlockInfo, + cached_state::CachedState, contract_class_cache::PermanentContractClassCache, + state_api::StateReader, state_cache::StorageEntry, BlockInfo, }, transaction::{Declare, DeclareV2, DeployAccount, InvokeFunction, L1Handler}, utils::{Address, ClassHash}, @@ -222,7 +220,7 @@ pub fn execute_tx_configurable( let trace = rpc_reader.0.get_transaction_trace(&tx_hash).unwrap(); let receipt = rpc_reader.0.get_transaction_receipt(&tx_hash).unwrap(); - let class_cache = ContractClassCache::default(); + let class_cache = Arc::new(PermanentContractClassCache::default()); let mut state = CachedState::new(Arc::new(rpc_reader), class_cache); let block_context = BlockContext::new( diff --git a/src/bin/deploy.rs b/src/bin/deploy.rs index a8b70c69d..ef05c388c 100644 --- a/src/bin/deploy.rs +++ b/src/bin/deploy.rs @@ -1,16 +1,13 @@ -use std::{collections::HashMap, sync::Arc}; - use lazy_static::lazy_static; use starknet_in_rust::{ definitions::{block_context::BlockContext, constants::TRANSACTION_VERSION}, - services::api::contract_classes::{ - compiled_class::CompiledClass, deprecated_contract_class::ContractClass, - }, - state::{ - cached_state::CachedState, in_memory_state_reader::InMemoryStateReader, state_api::State, - }, + services::api::contract_classes::compiled_class::CompiledClass, + services::api::contract_classes::deprecated_contract_class::ContractClass, + state::{cached_state::CachedState, in_memory_state_reader::InMemoryStateReader}, + state::{contract_class_cache::PermanentContractClassCache, state_api::State}, transaction::{Deploy, Transaction}, }; +use std::sync::Arc; #[cfg(feature = "with_mimalloc")] use mimalloc::MiMalloc; @@ -32,7 +29,10 @@ fn main() { let block_context = BlockContext::default(); let state_reader = Arc::new(InMemoryStateReader::default()); - let mut state = CachedState::new(state_reader, HashMap::new()); + let mut state = CachedState::new( + state_reader, + Arc::new(PermanentContractClassCache::default()), + ); let call_data = vec![]; for n in 0..RUNS { diff --git a/src/bin/deploy_invoke.rs b/src/bin/deploy_invoke.rs index e84003f33..33724963b 100644 --- a/src/bin/deploy_invoke.rs +++ b/src/bin/deploy_invoke.rs @@ -1,21 +1,19 @@ -use std::{collections::HashMap, path::PathBuf, sync::Arc}; - use cairo_vm::felt::{felt_str, Felt252}; +use lazy_static::lazy_static; use num_traits::Zero; - use starknet_in_rust::{ definitions::{block_context::BlockContext, constants::TRANSACTION_VERSION}, services::api::contract_classes::{ compiled_class::CompiledClass, deprecated_contract_class::ContractClass, }, state::{ - cached_state::CachedState, in_memory_state_reader::InMemoryStateReader, state_api::State, + cached_state::CachedState, contract_class_cache::PermanentContractClassCache, + in_memory_state_reader::InMemoryStateReader, state_api::State, }, transaction::{Deploy, InvokeFunction, Transaction}, utils::Address, }; - -use lazy_static::lazy_static; +use std::{path::PathBuf, sync::Arc}; #[cfg(feature = "with_mimalloc")] use mimalloc::MiMalloc; @@ -46,7 +44,10 @@ fn main() { let block_context = BlockContext::default(); let state_reader = Arc::new(InMemoryStateReader::default()); - let mut state = CachedState::new(state_reader, HashMap::new()); + let mut state = CachedState::new( + state_reader, + Arc::new(PermanentContractClassCache::default()), + ); let call_data = vec![]; let contract_address_salt = 1.into(); diff --git a/src/bin/fibonacci.rs b/src/bin/fibonacci.rs index c4d3da456..670bee497 100644 --- a/src/bin/fibonacci.rs +++ b/src/bin/fibonacci.rs @@ -1,9 +1,7 @@ -use std::{collections::HashMap, path::PathBuf, sync::Arc}; - use cairo_vm::felt::{felt_str, Felt252}; -use num_traits::Zero; - use lazy_static::lazy_static; +use num_traits::Zero; +use starknet_in_rust::utils::ClassHash; use starknet_in_rust::{ definitions::{block_context::BlockContext, constants::TRANSACTION_VERSION}, execution::{ @@ -13,10 +11,14 @@ use starknet_in_rust::{ compiled_class::CompiledClass, deprecated_contract_class::ContractClass, }, state::cached_state::CachedState, - state::{in_memory_state_reader::InMemoryStateReader, ExecutionResourcesManager}, - utils::{Address, ClassHash}, + state::{ + contract_class_cache::PermanentContractClassCache, + in_memory_state_reader::InMemoryStateReader, ExecutionResourcesManager, + }, + utils::Address, EntryPointType, }; +use std::{path::PathBuf, sync::Arc}; #[cfg(feature = "with_mimalloc")] use mimalloc::MiMalloc; @@ -60,7 +62,7 @@ fn main() { //* Create state reader with class hash data //* -------------------------------------------- - let mut contract_class_cache = HashMap::new(); + let contract_class_cache = Arc::new(PermanentContractClassCache::default()); // ------------ contract data -------------------- @@ -68,10 +70,10 @@ fn main() { let class_hash = *CONTRACT_CLASS_HASH; let nonce = Felt252::zero(); - contract_class_cache.insert( + contract_class_cache.extend([( class_hash, CompiledClass::Deprecated(Arc::new(contract_class)), - ); + )]); let mut state_reader = InMemoryStateReader::default(); state_reader .address_to_class_hash_mut() diff --git a/src/bin/invoke.rs b/src/bin/invoke.rs index ff4efc4c2..a4b61c76c 100644 --- a/src/bin/invoke.rs +++ b/src/bin/invoke.rs @@ -1,20 +1,19 @@ -use std::{collections::HashMap, path::PathBuf, sync::Arc}; - use cairo_vm::felt::{felt_str, Felt252}; +use lazy_static::lazy_static; use num_traits::Zero; - use starknet_in_rust::{ definitions::{block_context::BlockContext, constants::TRANSACTION_VERSION}, services::api::contract_classes::{ compiled_class::CompiledClass, deprecated_contract_class::ContractClass, }, - state::cached_state::CachedState, - state::in_memory_state_reader::InMemoryStateReader, + state::{ + cached_state::CachedState, contract_class_cache::PermanentContractClassCache, + in_memory_state_reader::InMemoryStateReader, + }, transaction::{InvokeFunction, Transaction}, utils::{Address, ClassHash}, }; - -use lazy_static::lazy_static; +use std::{path::PathBuf, sync::Arc}; #[cfg(feature = "with_mimalloc")] use mimalloc::MiMalloc; @@ -65,7 +64,7 @@ fn main() { .insert((CONTRACT_ADDRESS.clone(), [0; 32]), Felt252::zero()); Arc::new(state_reader) }, - HashMap::new(), + Arc::new(PermanentContractClassCache::default()), ); let chain_id = block_context.starknet_os_config().chain_id().clone(); let signature = Vec::new(); diff --git a/src/bin/invoke_with_cachedstate.rs b/src/bin/invoke_with_cachedstate.rs index c5632f0f6..c70a10caa 100644 --- a/src/bin/invoke_with_cachedstate.rs +++ b/src/bin/invoke_with_cachedstate.rs @@ -1,8 +1,6 @@ -use std::{collections::HashMap, path::PathBuf, sync::Arc}; - use cairo_vm::felt::{felt_str, Felt252}; +use lazy_static::lazy_static; use num_traits::Zero; - use starknet_in_rust::{ definitions::{ block_context::{BlockContext, StarknetChainId, StarknetOsConfig}, @@ -12,12 +10,13 @@ use starknet_in_rust::{ compiled_class::CompiledClass, deprecated_contract_class::ContractClass, }, state::in_memory_state_reader::InMemoryStateReader, - state::{cached_state::CachedState, BlockInfo}, + state::{ + cached_state::CachedState, contract_class_cache::PermanentContractClassCache, BlockInfo, + }, transaction::InvokeFunction, utils::{Address, ClassHash}, }; - -use lazy_static::lazy_static; +use std::{collections::HashMap, path::PathBuf, sync::Arc}; #[cfg(feature = "with_mimalloc")] use mimalloc::MiMalloc; @@ -102,7 +101,7 @@ fn main() { } } -fn create_initial_state() -> CachedState { +fn create_initial_state() -> CachedState { let cached_state = CachedState::new( { let mut state_reader = InMemoryStateReader::default(); @@ -123,7 +122,7 @@ fn create_initial_state() -> CachedState { .insert((CONTRACT_ADDRESS.clone(), [0; 32]), Felt252::zero()); Arc::new(state_reader) }, - HashMap::new(), + Arc::new(PermanentContractClassCache::default()), ); cached_state diff --git a/src/core/errors/state_errors.rs b/src/core/errors/state_errors.rs index e5ab48713..ba09e3e51 100644 --- a/src/core/errors/state_errors.rs +++ b/src/core/errors/state_errors.rs @@ -8,21 +8,17 @@ use thiserror::Error; #[derive(Debug, Error)] pub enum StateError { - #[error("Missing ContractClassCache")] - MissingContractClassCache, - #[error("ContractClassCache must be None")] - AssignedContractClassCache, #[error("Missing key in StorageUpdate Map")] EmptyKeyInStorage, #[error("Try to create a CarriedState from a None parent")] ParentCarriedStateIsNone, #[error("Cache already initialized")] StateCacheAlreadyInitialized, - #[error("No contract state assigned for contact address: {0:?}")] + #[error("No contract state assigned for contract address: {0:?}")] NoneContractState(Address), - #[error("No class hash assigned for contact address: {0:?}")] + #[error("No class hash assigned for contract address: {0:?}")] NoneClassHash(Address), - #[error("No nonce assigned for contact address: {0:?}")] + #[error("No nonce assigned for contract address: {0:?}")] NoneNonce(Address), #[error("No storage value assigned for entry: {0:?}")] NoneStorage(StorageEntry), @@ -34,20 +30,16 @@ pub enum StateError { ContractAddressUnavailable(Address), #[error(transparent)] ContractClass(#[from] ContractClassError), - #[error("Missing CasmClassCache")] - MissingCasmClassCache, #[error("Constructor calldata is empty")] - ConstructorCalldataEmpty(), + ConstructorCalldataEmpty, #[error("Error in ExecutionEntryPoint")] - ExecutionEntryPoint(), + ExecutionEntryPoint, #[error("No compiled class found for compiled_class_hash {0:?}")] NoneCompiledClass(ClassHash), #[error("No compiled class hash found for class_hash {0:?}")] NoneCompiledHash(ClassHash), #[error("Missing casm class for hash {0:?}")] MissingCasmClass(ClassHash), - #[error("No class hash declared in class_hash_to_contract_class")] - MissingClassHash(), #[error("Uninitializes class_hash")] UninitiaizedClassHash, #[error(transparent)] @@ -56,4 +48,6 @@ pub enum StateError { CustomError(String), #[error(transparent)] ByteArray(#[from] FromByteArrayError), + #[error("Failed to read contract class cache")] + FailedToReadContractClassCache, } diff --git a/src/execution/execution_entry_point.rs b/src/execution/execution_entry_point.rs index 4f78ceffc..c98e2351d 100644 --- a/src/execution/execution_entry_point.rs +++ b/src/execution/execution_entry_point.rs @@ -12,6 +12,7 @@ use crate::{ }, state::{ cached_state::CachedState, + contract_class_cache::ContractClassCache, contract_storage_state::ContractStorageState, state_api::{State, StateReader}, ExecutionResourcesManager, @@ -46,8 +47,6 @@ use cairo_vm::{ }, }; use std::sync::Arc; -#[cfg(feature = "cairo-native")] -use std::{cell::RefCell, rc::Rc}; #[cfg(feature = "cairo-native")] use { @@ -56,7 +55,9 @@ use { execution_result::NativeExecutionResult, metadata::syscall_handler::SyscallHandlerMeta, utils::felt252_bigint, }, + core::cell::RefCell, serde_json::Value, + std::rc::Rc, }; #[derive(Debug, Default)] @@ -110,9 +111,9 @@ impl ExecutionEntryPoint { /// The information collected from this run (number of steps required, modifications to the /// contract storage, etc.) is saved on the resources manager. /// Returns a CallInfo object that represents the execution. - pub fn execute( + pub fn execute( &self, - state: &mut CachedState, + state: &mut CachedState, block_context: &BlockContext, resources_manager: &mut ExecutionResourcesManager, tx_execution_context: &mut TransactionExecutionContext, @@ -124,6 +125,7 @@ impl ExecutionEntryPoint { ) -> Result where T: StateReader, + C: ContractClassCache, { // lookup the compiled class from the state. let class_hash = self.get_class_hash(state)?; @@ -182,7 +184,7 @@ impl ExecutionEntryPoint { } #[cfg(feature = "cairo-native")] CompiledClass::Sierra(sierra_program_and_entrypoints) => { - let mut transactional_state = state.create_transactional(); + let mut transactional_state = state.create_transactional()?; let program_cache = program_cache.unwrap_or_else(|| { Rc::new(RefCell::new(ProgramCache::new( @@ -295,11 +297,11 @@ impl ExecutionEntryPoint { .ok_or(TransactionError::EntryPointNotFound) } - fn build_call_info_deprecated( + fn build_call_info_deprecated( &self, previous_cairo_usage: ExecutionResources, resources_manager: &ExecutionResourcesManager, - starknet_storage_state: ContractStorageState, + starknet_storage_state: ContractStorageState, events: Vec, l2_to_l1_messages: Vec, internal_calls: Vec, @@ -328,11 +330,11 @@ impl ExecutionEntryPoint { }) } - fn build_call_info( + fn build_call_info( &self, previous_cairo_usage: ExecutionResources, resources_manager: &ExecutionResourcesManager, - starknet_storage_state: ContractStorageState, + starknet_storage_state: ContractStorageState, events: Vec, l2_to_l1_messages: Vec, internal_calls: Vec, @@ -387,9 +389,9 @@ impl ExecutionEntryPoint { get_deployed_address_class_hash_at_address(state, code_address) } - fn _execute_version0_class( + fn _execute_version0_class( &self, - state: &mut CachedState, + state: &mut CachedState, resources_manager: &mut ExecutionResourcesManager, block_context: &BlockContext, tx_execution_context: &mut TransactionExecutionContext, @@ -410,7 +412,7 @@ impl ExecutionEntryPoint { // prepare OS context //let os_context = runner.prepare_os_context(); let os_context = - StarknetRunner::>::prepare_os_context_cairo0( + StarknetRunner::>::prepare_os_context_cairo0( &cairo_runner, &mut vm, ); @@ -481,7 +483,7 @@ impl ExecutionEntryPoint { let retdata = runner.get_return_values()?; - self.build_call_info_deprecated::( + self.build_call_info_deprecated::( previous_cairo_usage, resources_manager, runner.hint_processor.syscall_handler.starknet_storage_state, @@ -492,9 +494,9 @@ impl ExecutionEntryPoint { ) } - fn _execute( + fn _execute( &self, - state: &mut CachedState, + state: &mut CachedState, resources_manager: &mut ExecutionResourcesManager, block_context: &BlockContext, tx_execution_context: &mut TransactionExecutionContext, @@ -520,7 +522,7 @@ impl ExecutionEntryPoint { )?; validate_contract_deployed(state, &self.contract_address)?; // prepare OS context - let os_context = StarknetRunner::>::prepare_os_context_cairo1( + let os_context = StarknetRunner::>::prepare_os_context_cairo1( &cairo_runner, &mut vm, self.initial_gas.into(), @@ -632,7 +634,7 @@ impl ExecutionEntryPoint { resources_manager.cairo_usage += &runner.get_execution_resources()?; let call_result = runner.get_call_result(self.initial_gas)?; - self.build_call_info::( + self.build_call_info::( previous_cairo_usage, resources_manager, runner.hint_processor.syscall_handler.starknet_storage_state, @@ -646,9 +648,9 @@ impl ExecutionEntryPoint { #[cfg(not(feature = "cairo-native"))] #[inline(always)] #[allow(dead_code)] - fn native_execute( + fn native_execute( &self, - _state: &mut CachedState, + _state: &mut CachedState, _sierra_program_and_entrypoints: Arc<(SierraProgram, ContractEntryPoints)>, _tx_execution_context: &mut TransactionExecutionContext, _block_context: &BlockContext, @@ -660,9 +662,9 @@ impl ExecutionEntryPoint { #[cfg(feature = "cairo-native")] #[inline(always)] - fn native_execute( + fn native_execute( &self, - state: &mut CachedState, + state: &mut CachedState, sierra_program_and_entrypoints: Arc<(SierraProgram, ContractEntryPoints)>, tx_execution_context: &TransactionExecutionContext, block_context: &BlockContext, diff --git a/src/lib.rs b/src/lib.rs index 6e3500195..628c99969 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -10,6 +10,7 @@ use crate::{ }, state::{ cached_state::CachedState, + contract_class_cache::ContractClassCache, state_api::{State, StateReader}, ExecutionResourcesManager, }, @@ -17,7 +18,7 @@ use crate::{ utils::Address, }; use cairo_vm::felt::Felt252; -use std::{collections::HashMap, sync::Arc}; +use std::sync::Arc; #[cfg(test)] #[macro_use] @@ -27,9 +28,10 @@ extern crate assert_matches; pub use crate::services::api::contract_classes::deprecated_contract_class::{ ContractEntryPoint, EntryPointType, }; -pub use cairo_lang_starknet::casm_contract_class::CasmContractClass; -pub use cairo_lang_starknet::contract_class::ContractClass; -pub use cairo_lang_starknet::contract_class::ContractClass as SierraContractClass; +pub use cairo_lang_starknet::{ + casm_contract_class::CasmContractClass, contract_class::ContractClass, + contract_class::ContractClass as SierraContractClass, +}; pub use cairo_vm::felt; #[cfg(feature = "cairo-native")] @@ -53,9 +55,10 @@ pub mod transaction; pub mod utils; #[allow(clippy::too_many_arguments)] -pub fn simulate_transaction( +pub fn simulate_transaction( transactions: &[&Transaction], state: S, + contract_class_cache: Arc, block_context: &BlockContext, remaining_gas: u128, skip_validate: bool, @@ -67,7 +70,7 @@ pub fn simulate_transaction( Rc>>, >, ) -> Result, TransactionError> { - let mut cache_state = CachedState::new(Arc::new(state), HashMap::new()); + let mut cache_state = CachedState::new(Arc::new(state), contract_class_cache); let mut result = Vec::with_capacity(transactions.len()); for transaction in transactions { let tx_for_simulation = transaction.create_for_simulation( @@ -91,9 +94,9 @@ pub fn simulate_transaction( } /// Estimate the fee associated with transaction -pub fn estimate_fee( +pub fn estimate_fee( transactions: &[Transaction], - mut cached_state: CachedState, + mut cached_state: CachedState, block_context: &BlockContext, #[cfg(feature = "cairo-native")] program_cache: Option< Rc>>, @@ -101,6 +104,7 @@ pub fn estimate_fee( ) -> Result, TransactionError> where T: StateReader, + C: ContractClassCache, { let mut result = Vec::with_capacity(transactions.len()); for transaction in transactions { @@ -130,11 +134,11 @@ where Ok(result) } -pub fn call_contract( +pub fn call_contract( contract_address: Felt252, entrypoint_selector: Felt252, calldata: Vec, - state: &mut CachedState, + state: &mut CachedState, block_context: BlockContext, caller_address: Address, #[cfg(feature = "cairo-native")] program_cache: Option< @@ -189,9 +193,9 @@ pub fn call_contract( } /// Estimate the fee associated with L1Handler -pub fn estimate_message_fee( +pub fn estimate_message_fee( l1_handler: &L1Handler, - state: T, + mut cached_state: CachedState, block_context: &BlockContext, #[cfg(feature = "cairo-native")] program_cache: Option< Rc>>, @@ -199,10 +203,8 @@ pub fn estimate_message_fee( ) -> Result<(u128, usize), TransactionError> where T: StateReader, + C: ContractClassCache, { - // This is used as a copy of the original state, we can update this cached state freely. - let mut cached_state = CachedState::::new(Arc::new(state), HashMap::new()); - // Check if the contract is deployed. cached_state.get_class_hash_at(l1_handler.contract_address())?; @@ -226,9 +228,9 @@ where } } -pub fn execute_transaction( +pub fn execute_transaction( tx: Transaction, - state: &mut CachedState, + state: &mut CachedState, block_context: BlockContext, remaining_gas: u128, #[cfg(feature = "cairo-native")] program_cache: Option< @@ -264,8 +266,11 @@ mod test { }, simulate_transaction, state::{ - cached_state::CachedState, in_memory_state_reader::InMemoryStateReader, - state_api::State, ExecutionResourcesManager, + cached_state::CachedState, + contract_class_cache::{ContractClassCache, PermanentContractClassCache}, + in_memory_state_reader::InMemoryStateReader, + state_api::State, + ExecutionResourcesManager, }, transaction::{ Declare, DeclareV2, Deploy, DeployAccount, InvokeFunction, L1Handler, Transaction, @@ -287,7 +292,7 @@ mod test { use lazy_static::lazy_static; use num_traits::{Num, One, Zero}; use pretty_assertions_sorted::assert_eq; - use std::{collections::HashMap, path::PathBuf, sync::Arc}; + use std::{path::PathBuf, sync::Arc}; lazy_static! { // include_str! doesn't seem to work in CI @@ -355,13 +360,14 @@ mod test { let entrypoints = contract_class.clone().entry_points_by_type; let entrypoint_selector = &entrypoints.external.get(0).unwrap().selector; - let mut contract_class_cache = HashMap::new(); + let contract_class_cache = PermanentContractClassCache::default(); let address = Address(1111.into()); let class_hash: ClassHash = ClassHash([1; 32]); let nonce = Felt252::zero(); - contract_class_cache.insert(class_hash, CompiledClass::Casm(Arc::new(contract_class))); + contract_class_cache + .set_contract_class(class_hash, CompiledClass::Casm(Arc::new(contract_class))); let mut state_reader = InMemoryStateReader::default(); state_reader .address_to_class_hash_mut() @@ -370,7 +376,7 @@ mod test { .address_to_nonce_mut() .insert(address.clone(), nonce); - let mut state = CachedState::new(Arc::new(state_reader), contract_class_cache); + let mut state = CachedState::new(Arc::new(state_reader), Arc::new(contract_class_cache)); let calldata = [1.into(), 1.into(), 10.into()].to_vec(); let retdata = call_contract( @@ -413,7 +419,7 @@ mod test { // Set contract_class let class_hash = ClassHash([1; 32]); let contract_class = ContractClass::from_path("starknet_programs/l1l2.json").unwrap(); - // Set contact_state + // Set contract_state let contract_address = Address(0.into()); let nonce = Felt252::zero(); @@ -424,14 +430,16 @@ mod test { .address_to_nonce .insert(contract_address, nonce); - let mut state = CachedState::new(Arc::new(state_reader), HashMap::new()); + let state = CachedState::new( + Arc::new(state_reader), + Arc::new(PermanentContractClassCache::default()), + ); // Initialize state.contract_classes - let contract_classes = HashMap::from([( + state.contract_class_cache().set_contract_class( class_hash, CompiledClass::Deprecated(Arc::new(contract_class)), - )]); - state.set_contract_classes(contract_classes).unwrap(); + ); let mut block_context = BlockContext::default(); block_context.starknet_os_config.gas_price = 1; @@ -457,13 +465,14 @@ mod test { let entrypoints = contract_class.clone().entry_points_by_type; let entrypoint_selector = &entrypoints.external.get(0).unwrap().selector; - let mut contract_class_cache = HashMap::new(); + let contract_class_cache = PermanentContractClassCache::default(); let address = Address(1111.into()); let class_hash: ClassHash = ClassHash([1; 32]); let nonce = Felt252::zero(); - contract_class_cache.insert(class_hash, CompiledClass::Casm(Arc::new(contract_class))); + contract_class_cache + .set_contract_class(class_hash, CompiledClass::Casm(Arc::new(contract_class))); let mut state_reader = InMemoryStateReader::default(); state_reader @@ -473,7 +482,7 @@ mod test { .address_to_nonce_mut() .insert(address.clone(), nonce); - let mut state = CachedState::new(Arc::new(state_reader), contract_class_cache); + let mut state = CachedState::new(Arc::new(state_reader), Arc::new(contract_class_cache)); let calldata = [1.into(), 1.into(), 10.into()].to_vec(); let invoke = InvokeFunction::new( @@ -616,6 +625,7 @@ mod test { let context = simulate_transaction( &[&invoke_1, &invoke_2, &invoke_3], state_reader, + Arc::new(PermanentContractClassCache::default()), &block_context, 1000, false, @@ -720,6 +730,7 @@ mod test { let context = simulate_transaction( &[&invoke], state_reader, + Arc::new(PermanentContractClassCache::default()), &block_context, 1000, true, @@ -740,7 +751,11 @@ mod test { #[test] fn test_simulate_deploy() { let state_reader = Arc::new(InMemoryStateReader::default()); - let mut state = CachedState::new(state_reader, HashMap::new()); + let mut state = CachedState::new( + state_reader, + Arc::new(PermanentContractClassCache::default()), + ); + state .set_contract_class( &CLASS_HASH, @@ -766,7 +781,8 @@ mod test { simulate_transaction( &[&internal_deploy], - state, + state.clone_for_testing(), + state.contract_class_cache().clone(), block_context, 100_000_000, false, @@ -783,7 +799,10 @@ mod test { #[test] fn test_simulate_declare() { let state_reader = Arc::new(InMemoryStateReader::default()); - let state = CachedState::new(state_reader, HashMap::new()); + let state = CachedState::new( + state_reader, + Arc::new(PermanentContractClassCache::default()), + ); let block_context = &Default::default(); @@ -805,7 +824,8 @@ mod test { simulate_transaction( &[&declare_tx], - state, + state.clone_for_testing(), + state.contract_class_cache().clone(), block_context, 100_000_000, false, @@ -822,7 +842,11 @@ mod test { #[test] fn test_simulate_invoke() { let state_reader = Arc::new(InMemoryStateReader::default()); - let mut state = CachedState::new(state_reader, HashMap::new()); + let mut state = CachedState::new( + state_reader, + Arc::new(PermanentContractClassCache::default()), + ); + state .set_contract_class( &CLASS_HASH, @@ -877,7 +901,8 @@ mod test { simulate_transaction( &[&invoke_tx], - state, + state.clone_for_testing(), + state.contract_class_cache().clone(), &block_context, 100_000_000, false, @@ -894,7 +919,10 @@ mod test { #[test] fn test_simulate_deploy_account() { let state_reader = Arc::new(InMemoryStateReader::default()); - let mut state = CachedState::new(state_reader, HashMap::new()); + let mut state = CachedState::new( + state_reader, + Arc::new(PermanentContractClassCache::default()), + ); state .set_contract_class( @@ -922,7 +950,8 @@ mod test { simulate_transaction( &[&deploy_account_tx], - state, + state.clone_for_testing(), + state.contract_class_cache().clone(), block_context, 100_000_000, false, @@ -969,7 +998,8 @@ mod test { simulate_transaction( &[&declare_tx], - state, + state.clone_for_testing(), + state.contract_class_cache().clone(), &block_context, 100_000_000, false, @@ -1011,7 +1041,7 @@ mod test { // Set contract_class let class_hash = ClassHash([1; 32]); let contract_class = ContractClass::from_path("starknet_programs/l1l2.json").unwrap(); - // Set contact_state + // Set contract_state let contract_address = Address(0.into()); let nonce = Felt252::zero(); @@ -1022,10 +1052,10 @@ mod test { .address_to_nonce .insert(contract_address, nonce); - let mut state = CachedState::new(Arc::new(state_reader), HashMap::new()); - - // Initialize state.contract_classes - state.set_contract_classes(HashMap::new()).unwrap(); + let mut state = CachedState::new( + Arc::new(state_reader), + Arc::new(PermanentContractClassCache::default()), + ); state .set_contract_class( @@ -1039,7 +1069,8 @@ mod test { simulate_transaction( &[&l1_handler_tx], - state, + state.clone_for_testing(), + state.contract_class_cache().clone(), &block_context, 100_000_000, false, @@ -1056,7 +1087,10 @@ mod test { #[test] fn test_deploy_and_invoke_simulation() { let state_reader = Arc::new(InMemoryStateReader::default()); - let state = CachedState::new(state_reader, HashMap::new()); + let state = CachedState::new( + state_reader, + Arc::new(PermanentContractClassCache::default()), + ); let block_context = &Default::default(); @@ -1098,7 +1132,8 @@ mod test { simulate_transaction( &[&deploy, &invoke_tx], - state.clone(), + state.clone_for_testing(), + state.contract_class_cache().clone(), block_context, 100_000_000, false, @@ -1158,13 +1193,13 @@ mod test { let contract_class = ContractClass::from_path("starknet_programs/Account.json").unwrap(); // Instantiate CachedState - let mut contract_class_cache = HashMap::new(); + let contract_class_cache = PermanentContractClassCache::default(); // ------------ contract data -------------------- let class_hash_felt = compute_deprecated_class_hash(&contract_class).unwrap(); let class_hash = ClassHash::from(class_hash_felt); - contract_class_cache.insert( + contract_class_cache.set_contract_class( class_hash, CompiledClass::Deprecated(Arc::new(contract_class)), ); @@ -1183,7 +1218,7 @@ mod test { .address_to_nonce_mut() .insert(sender_address.clone(), Felt252::new(1)); - let mut state = CachedState::new(Arc::new(state_reader), contract_class_cache); + let mut state = CachedState::new(Arc::new(state_reader), Arc::new(contract_class_cache)); // Insert pubkey storage var to pass validation let storage_entry = &( sender_address, @@ -1236,7 +1271,8 @@ mod test { let without_validate_fee = simulate_transaction( &[&declare_tx], - state.clone(), + state.clone_for_testing(), + state.clone_for_testing().contract_class_cache().clone(), &block_context, 100_000_000, true, @@ -1252,7 +1288,8 @@ mod test { let with_validate_fee = simulate_transaction( &[&declare_tx], - state, + state.clone_for_testing(), + state.contract_class_cache().clone(), &block_context, 100_000_000, false, diff --git a/src/runner/mod.rs b/src/runner/mod.rs index c5e86702c..e0033331c 100644 --- a/src/runner/mod.rs +++ b/src/runner/mod.rs @@ -426,7 +426,10 @@ mod test { use super::StarknetRunner; use crate::{ state::cached_state::CachedState, - state::in_memory_state_reader::InMemoryStateReader, + state::{ + contract_class_cache::PermanentContractClassCache, + in_memory_state_reader::InMemoryStateReader, + }, syscalls::{ deprecated_business_logic_syscall_handler::DeprecatedBLSyscallHandler, deprecated_syscall_handler::DeprecatedSyscallHintProcessor, @@ -448,7 +451,12 @@ mod test { let cairo_runner = CairoRunner::new(&program, "starknet", false).unwrap(); let mut vm = VirtualMachine::new(true); - let os_context = StarknetRunner::>>::prepare_os_context_cairo0(&cairo_runner, &mut vm); + let os_context = StarknetRunner::< + SyscallHintProcessor< + CachedState, + PermanentContractClassCache, + >, + >::prepare_os_context_cairo0(&cairo_runner, &mut vm); // is expected to return a pointer to the first segment as there is nothing more in the vm let expected = Vec::from([MaybeRelocatable::from((0, 0))]); @@ -462,7 +470,7 @@ mod test { let cairo_runner = CairoRunner::new(&program, "starknet", false).unwrap(); let vm = VirtualMachine::new(true); - let mut state = CachedState::::default(); + let mut state = CachedState::::default(); let hint_processor = DeprecatedSyscallHintProcessor::new( DeprecatedBLSyscallHandler::default_with(&mut state), RunResources::default(), @@ -478,7 +486,7 @@ mod test { let cairo_runner = CairoRunner::new(&program, "starknet", false).unwrap(); let vm = VirtualMachine::new(true); - let mut state = CachedState::::default(); + let mut state = CachedState::::default(); let hint_processor = DeprecatedSyscallHintProcessor::new( DeprecatedBLSyscallHandler::default_with(&mut state), RunResources::default(), @@ -497,7 +505,7 @@ mod test { let cairo_runner = CairoRunner::new(&program, "starknet", false).unwrap(); let vm = VirtualMachine::new(true); - let mut state = CachedState::::default(); + let mut state = CachedState::::default(); let hint_processor = DeprecatedSyscallHintProcessor::new( DeprecatedBLSyscallHandler::default_with(&mut state), RunResources::default(), @@ -519,7 +527,7 @@ mod test { let cairo_runner = CairoRunner::new(&program, "starknet", false).unwrap(); let vm = VirtualMachine::new(true); - let mut state = CachedState::::default(); + let mut state = CachedState::::default(); let hint_processor = DeprecatedSyscallHintProcessor::new( DeprecatedBLSyscallHandler::default_with(&mut state), RunResources::default(), @@ -541,7 +549,7 @@ mod test { let cairo_runner = CairoRunner::new(&program, "starknet", false).unwrap(); let vm = VirtualMachine::new(true); - let mut state = CachedState::::default(); + let mut state = CachedState::::default(); let hint_processor = DeprecatedSyscallHintProcessor::new( DeprecatedBLSyscallHandler::default_with(&mut state), RunResources::default(), @@ -563,7 +571,7 @@ mod test { vm.add_memory_segment(); vm.compute_segments_effective_sizes(); - let mut state = CachedState::::default(); + let mut state = CachedState::::default(); let hint_processor = DeprecatedSyscallHintProcessor::new( DeprecatedBLSyscallHandler::default_with(&mut state), RunResources::default(), @@ -586,7 +594,7 @@ mod test { vm.add_memory_segment(); vm.compute_segments_effective_sizes(); - let mut state = CachedState::::default(); + let mut state = CachedState::::default(); let hint_processor = DeprecatedSyscallHintProcessor::new( DeprecatedBLSyscallHandler::default_with(&mut state), RunResources::default(), diff --git a/src/state/cached_state.rs b/src/state/cached_state.rs index f9911d14b..823f6362c 100644 --- a/src/state/cached_state.rs +++ b/src/state/cached_state.rs @@ -1,4 +1,5 @@ use super::{ + contract_class_cache::ContractClassCache, state_api::{State, StateChangesCount, StateReader}, state_cache::{StateCache, StorageEntry}, }; @@ -17,27 +18,30 @@ use getset::{Getters, MutGetters}; use num_traits::Zero; use std::{ collections::{HashMap, HashSet}, - sync::Arc, + sync::{Arc, RwLock}, }; -pub type ContractClassCache = HashMap; - pub const UNINITIALIZED_CLASS_HASH: &ClassHash = &ClassHash([0u8; 32]); /// Represents a cached state of contract classes with optional caches. -#[derive(Default, Clone, Debug, Eq, Getters, MutGetters, PartialEq)] -pub struct CachedState { +#[derive(Default, Debug, Getters, MutGetters)] +pub struct CachedState { pub state_reader: Arc, #[getset(get = "pub", get_mut = "pub")] pub(crate) cache: StateCache, - #[get = "pub"] - pub(crate) contract_classes: ContractClassCache, + + #[getset(get = "pub", get_mut = "pub")] + pub(crate) contract_class_cache: Arc, + pub(crate) contract_class_cache_private: Arc>>, + + #[cfg(feature = "metrics")] cache_hits: usize, + #[cfg(feature = "metrics")] cache_misses: usize, } #[cfg(feature = "metrics")] -impl CachedState { +impl CachedState { #[inline(always)] pub fn add_hit(&mut self) { self.cache_hits += 1; @@ -50,7 +54,7 @@ impl CachedState { } #[cfg(not(feature = "metrics"))] -impl CachedState { +impl CachedState { #[inline(always)] pub fn add_hit(&mut self) { // does nothing @@ -62,14 +66,18 @@ impl CachedState { } } -impl CachedState { +impl CachedState { /// Constructor, creates a new cached state. - pub fn new(state_reader: Arc, contract_classes: ContractClassCache) -> Self { + pub fn new(state_reader: Arc, contract_classes: Arc) -> Self { Self { cache: StateCache::default(), state_reader, - contract_classes, + contract_class_cache: contract_classes, + contract_class_cache_private: Arc::new(RwLock::new(HashMap::new())), + + #[cfg(feature = "metrics")] cache_hits: 0, + #[cfg(feature = "metrics")] cache_misses: 0, } } @@ -78,43 +86,63 @@ impl CachedState { pub fn new_for_testing( state_reader: Arc, cache: StateCache, - _contract_classes: ContractClassCache, + contract_classes: Arc, ) -> Self { Self { cache, - contract_classes: HashMap::new(), state_reader, + contract_class_cache: contract_classes, + contract_class_cache_private: Arc::new(RwLock::new(HashMap::new())), + + #[cfg(feature = "metrics")] cache_hits: 0, + #[cfg(feature = "metrics")] cache_misses: 0, } } - /// Sets the contract classes cache. - pub fn set_contract_classes( - &mut self, - contract_classes: ContractClassCache, - ) -> Result<(), StateError> { - if !self.contract_classes.is_empty() { - return Err(StateError::AssignedContractClassCache); + /// Clones a CachedState for testing purposes. + pub fn clone_for_testing(&self) -> Self { + Self { + state_reader: self.state_reader.clone(), + cache: self.cache.clone(), + contract_class_cache: self.contract_class_cache.clone(), + contract_class_cache_private: self.contract_class_cache_private.clone(), + #[cfg(feature = "metrics")] + cache_hits: self.cache_hits, + #[cfg(feature = "metrics")] + cache_misses: self.cache_misses, } - self.contract_classes = contract_classes; - Ok(()) + } + + pub fn drain_private_contract_class_cache( + &self, + ) -> Result, StateError> { + Ok(self + .contract_class_cache_private + .read() + .map_err(|_| StateError::FailedToReadContractClassCache)? + .clone() + .into_iter()) } /// Creates a copy of this state with an empty cache for saving changes and applying them /// later. - pub fn create_transactional(&self) -> CachedState { - CachedState { + pub fn create_transactional(&self) -> Result, StateError> { + Ok(CachedState { state_reader: self.state_reader.clone(), cache: self.cache.clone(), - contract_classes: self.contract_classes.clone(), + contract_class_cache: self.contract_class_cache.clone(), + contract_class_cache_private: self.contract_class_cache_private.clone(), + #[cfg(feature = "metrics")] cache_hits: 0, + #[cfg(feature = "metrics")] cache_misses: 0, - } + }) } } -impl StateReader for CachedState { +impl StateReader for CachedState { /// Returns the class hash for a given contract address. /// Returns zero as default value if missing fn get_class_hash_at(&self, contract_address: &Address) -> Result { @@ -165,32 +193,52 @@ impl StateReader for CachedState { } // I: FETCHING FROM CACHE - if let Some(compiled_class) = self.contract_classes.get(class_hash) { + let mut private_cache = self + .contract_class_cache_private + .write() + .map_err(|_| StateError::FailedToReadContractClassCache)?; + if let Some(compiled_class) = private_cache.get(class_hash) { return Ok(compiled_class.clone()); + } else if let Some(compiled_class) = + self.contract_class_cache().get_contract_class(*class_hash) + { + private_cache.insert(*class_hash, compiled_class.clone()); + return Ok(compiled_class); } // I: CASM CONTRACT CLASS : CLASS_HASH if let Some(compiled_class_hash) = self.cache.class_hash_to_compiled_class_hash.get(class_hash) { - if let Some(casm_class) = self.contract_classes.get(compiled_class_hash) { + if let Some(casm_class) = private_cache.get(compiled_class_hash) { return Ok(casm_class.clone()); + } else if let Some(casm_class) = self + .contract_class_cache() + .get_contract_class(*compiled_class_hash) + { + private_cache.insert(*class_hash, casm_class.clone()); + return Ok(casm_class); } } // II: FETCHING FROM STATE_READER - self.state_reader.get_contract_class(class_hash) + let contract_class = self.state_reader.get_contract_class(class_hash)?; + private_cache.insert(*class_hash, contract_class.clone()); + + Ok(contract_class) } } -impl State for CachedState { +impl State for CachedState { /// Stores a contract class in the cache. fn set_contract_class( &mut self, class_hash: &ClassHash, contract_class: &CompiledClass, ) -> Result<(), StateError> { - self.contract_classes + self.contract_class_cache_private + .write() + .map_err(|_| StateError::FailedToReadContractClassCache)? .insert(*class_hash, contract_class.clone()); Ok(()) @@ -423,17 +471,48 @@ impl State for CachedState { // I: FETCHING FROM CACHE // deprecated contract classes dont have compiled class hashes, so we only have one case - if let Some(compiled_class) = self.contract_classes.get(class_hash).cloned() { + let compiled_class_op = self + .contract_class_cache_private + .read() + .map_err(|_| StateError::FailedToReadContractClassCache)? + .get(class_hash) + .cloned(); + if let Some(compiled_class) = compiled_class_op { self.add_hit(); return Ok(compiled_class); + } else if let Some(compiled_class) = + self.contract_class_cache().get_contract_class(*class_hash) + { + self.add_hit(); + self.contract_class_cache_private + .write() + .map_err(|_| StateError::FailedToReadContractClassCache)? + .insert(*class_hash, compiled_class.clone()); + return Ok(compiled_class); } // I: CASM CONTRACT CLASS : CLASS_HASH if let Some(compiled_class_hash) = self.cache.class_hash_to_compiled_class_hash.get(class_hash) { - if let Some(casm_class) = self.contract_classes.get(compiled_class_hash).cloned() { + let casm_class_op = self + .contract_class_cache_private + .read() + .map_err(|_| StateError::FailedToReadContractClassCache)? + .get(compiled_class_hash) + .cloned(); + if let Some(casm_class) = casm_class_op { + self.add_hit(); + return Ok(casm_class); + } else if let Some(casm_class) = self + .contract_class_cache() + .get_contract_class(*compiled_class_hash) + { self.add_hit(); + self.contract_class_cache_private + .write() + .map_err(|_| StateError::FailedToReadContractClassCache)? + .insert(*class_hash, casm_class.clone()); return Ok(casm_class); } } @@ -492,7 +571,7 @@ impl State for CachedState { } } -impl CachedState { +impl CachedState { // Updates the cache's storage_initial_values according to those in storage_writes // If a key is present in the storage_writes but not in storage_initial_values, // the initial value for that key will be fetched from the state_reader and inserted into the cache's storage_initial_values @@ -541,13 +620,15 @@ impl CachedState { #[cfg(test)] mod tests { use super::*; - use crate::{ services::api::contract_classes::deprecated_contract_class::ContractClass, - state::in_memory_state_reader::InMemoryStateReader, + state::{ + contract_class_cache::PermanentContractClassCache, + in_memory_state_reader::InMemoryStateReader, + }, }; - use num_traits::One; + use std::collections::HashMap; /// Test checks if class hashes and nonces are correctly fetched from the state reader. /// It also tests the increment_nonce method. @@ -577,7 +658,10 @@ mod tests { .address_to_storage_mut() .insert(storage_entry, storage_value); - let mut cached_state = CachedState::new(Arc::new(state_reader), HashMap::new()); + let mut cached_state = CachedState::new( + Arc::new(state_reader), + Arc::new(PermanentContractClassCache::default()), + ); assert_eq!( cached_state.get_class_hash_at(&contract_address).unwrap(), @@ -610,9 +694,10 @@ mod tests { CompiledClass::Deprecated(Arc::new(contract_class)), ); - let mut cached_state = CachedState::new(Arc::new(state_reader), HashMap::new()); - - cached_state.set_contract_classes(HashMap::new()).unwrap(); + let cached_state = CachedState::new( + Arc::new(state_reader), + Arc::new(PermanentContractClassCache::default()), + ); assert_eq!( cached_state @@ -628,8 +713,10 @@ mod tests { /// This test verifies the correct handling of storage in the cached state. #[test] fn cached_state_storage_test() { - let mut cached_state = - CachedState::new(Arc::new(InMemoryStateReader::default()), HashMap::new()); + let mut cached_state = CachedState::new( + Arc::new(InMemoryStateReader::default()), + Arc::new(PermanentContractClassCache::default()), + ); let storage_entry: StorageEntry = (Address(31.into()), [0; 32]); let value = Felt252::new(10); @@ -651,7 +738,10 @@ mod tests { let contract_address = Address(32123.into()); - let mut cached_state = CachedState::new(state_reader, HashMap::new()); + let mut cached_state = CachedState::new( + state_reader, + Arc::new(PermanentContractClassCache::default()), + ); assert!(cached_state .deploy_contract(contract_address, ClassHash([10; 32])) @@ -667,7 +757,10 @@ mod tests { let storage_key = [18; 32]; let value = Felt252::new(912); - let mut cached_state = CachedState::new(state_reader, HashMap::new()); + let mut cached_state = CachedState::new( + state_reader, + Arc::new(PermanentContractClassCache::default()), + ); // set storage_key cached_state.set_storage_at(&(contract_address.clone(), storage_key), value.clone()); @@ -698,7 +791,10 @@ mod tests { let contract_address = Address(0.into()); - let mut cached_state = CachedState::new(Arc::new(state_reader), HashMap::new()); + let mut cached_state = CachedState::new( + Arc::new(state_reader), + Arc::new(PermanentContractClassCache::default()), + ); let result = cached_state .deploy_contract(contract_address.clone(), ClassHash([10; 32])) @@ -723,7 +819,10 @@ mod tests { let contract_address = Address(42.into()); - let mut cached_state = CachedState::new(Arc::new(state_reader), HashMap::new()); + let mut cached_state = CachedState::new( + Arc::new(state_reader), + Arc::new(PermanentContractClassCache::default()), + ); cached_state .deploy_contract(contract_address.clone(), ClassHash([10; 32])) @@ -751,7 +850,10 @@ mod tests { let contract_address = Address(32123.into()); - let mut cached_state = CachedState::new(Arc::new(state_reader), HashMap::new()); + let mut cached_state = CachedState::new( + Arc::new(state_reader), + Arc::new(PermanentContractClassCache::default()), + ); cached_state .deploy_contract(contract_address.clone(), ClassHash([10; 32])) @@ -780,7 +882,10 @@ mod tests { let address_one = Address(Felt252::one()); - let mut cached_state = CachedState::new(Arc::new(state_reader), HashMap::new()); + let mut cached_state = CachedState::new( + Arc::new(state_reader), + Arc::new(PermanentContractClassCache::default()), + ); let state_diff = StateDiff { address_to_class_hash: HashMap::from([( @@ -816,8 +921,10 @@ mod tests { #[test] fn count_actual_state_changes_test() { let state_reader = InMemoryStateReader::default(); - - let mut cached_state = CachedState::new(Arc::new(state_reader), HashMap::new()); + let mut cached_state = CachedState::new( + Arc::new(state_reader), + Arc::new(PermanentContractClassCache::default()), + ); let address_one = Address(1.into()); let address_two = Address(2.into()); @@ -865,12 +972,15 @@ mod tests { #[test] fn test_cache_hit_miss_counter() { let state_reader = Arc::new(InMemoryStateReader::default()); - let mut cached_state = CachedState::new(state_reader, HashMap::default()); + let mut cached_state = CachedState::new( + state_reader, + Arc::new(PermanentContractClassCache::default()), + ); let address = Address(1.into()); // Simulate a cache miss by querying an address not in the cache. - let _ = as State>::get_class_hash_at(&mut cached_state, &address); + let _ = as State>::get_class_hash_at(&mut cached_state, &address); assert_eq!(cached_state.cache_misses, 1); assert_eq!(cached_state.cache_hits, 0); @@ -879,18 +989,18 @@ mod tests { .cache .class_hash_writes .insert(address.clone(), ClassHash([0; 32])); - let _ = as State>::get_class_hash_at(&mut cached_state, &address); + let _ = as State>::get_class_hash_at(&mut cached_state, &address); assert_eq!(cached_state.cache_misses, 1); assert_eq!(cached_state.cache_hits, 1); // Simulate another cache hit. - let _ = as State>::get_class_hash_at(&mut cached_state, &address); + let _ = as State>::get_class_hash_at(&mut cached_state, &address); assert_eq!(cached_state.cache_misses, 1); assert_eq!(cached_state.cache_hits, 2); // Simulate another cache miss. let other_address = Address(2.into()); - let _ = as State>::get_class_hash_at(&mut cached_state, &other_address); + let _ = as State>::get_class_hash_at(&mut cached_state, &other_address); assert_eq!(cached_state.cache_misses, 2); assert_eq!(cached_state.cache_hits, 2); } diff --git a/src/state/contract_class_cache.rs b/src/state/contract_class_cache.rs new file mode 100644 index 000000000..d7269ce00 --- /dev/null +++ b/src/state/contract_class_cache.rs @@ -0,0 +1,90 @@ +//! # Contract cache system +//! +//! The contract caches allow the application to keep some contracts within itself, providing them +//! efficiently when they are needed. +//! +//! The trait `ContractClassCache` provides methods for retrieving and inserting elements into the +//! cache. It also contains a method to extend the shared cache from an iterator so that it can be +//! used with the private caches. + +use crate::{services::api::contract_classes::compiled_class::CompiledClass, utils::ClassHash}; +use std::{collections::HashMap, sync::RwLock}; + +/// The contract class cache trait, which must be implemented by all caches. +pub trait ContractClassCache { + /// Provides the stored contract class associated with a specific class hash, or `None` if not + /// present. + fn get_contract_class(&self, class_hash: ClassHash) -> Option; + /// Inserts or replaces a contract class associated with a specific class hash. + fn set_contract_class(&self, class_hash: ClassHash, compiled_class: CompiledClass); +} + +/// A contract class cache which stores nothing. In other words, using this as a cache means there's +/// effectively no cache. +#[derive(Clone, Copy, Debug, Default, Hash)] +pub struct NullContractClassCache; + +impl ContractClassCache for NullContractClassCache { + fn get_contract_class(&self, _class_hash: ClassHash) -> Option { + None + } + + fn set_contract_class(&self, _class_hash: ClassHash, _compiled_class: CompiledClass) { + // Nothing needs to be done here. + } +} + +/// A contract class cache which stores everything. This cache is useful for testing but will +/// probably end up taking all the memory available if the application is long running. +#[derive(Debug, Default)] +pub struct PermanentContractClassCache { + storage: RwLock>, +} + +impl PermanentContractClassCache { + pub fn extend(&self, other: I) + where + I: IntoIterator, + { + self.storage.write().unwrap().extend(other); + } +} + +impl ContractClassCache for PermanentContractClassCache { + fn get_contract_class(&self, class_hash: ClassHash) -> Option { + self.storage.read().unwrap().get(&class_hash).cloned() + } + + fn set_contract_class(&self, class_hash: ClassHash, compiled_class: CompiledClass) { + self.storage + .write() + .unwrap() + .insert(class_hash, compiled_class); + } +} + +impl Clone for PermanentContractClassCache { + fn clone(&self) -> Self { + Self { + storage: RwLock::new(self.storage.read().unwrap().clone()), + } + } +} + +impl IntoIterator for PermanentContractClassCache { + type Item = (ClassHash, CompiledClass); + type IntoIter = as IntoIterator>::IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.storage.into_inner().unwrap().into_iter() + } +} + +impl IntoIterator for &PermanentContractClassCache { + type Item = (ClassHash, CompiledClass); + type IntoIter = as IntoIterator>::IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.storage.read().unwrap().clone().into_iter() + } +} diff --git a/src/state/contract_storage_state.rs b/src/state/contract_storage_state.rs index fd23ecfa3..5064d2757 100644 --- a/src/state/contract_storage_state.rs +++ b/src/state/contract_storage_state.rs @@ -1,5 +1,6 @@ use super::{ cached_state::CachedState, + contract_class_cache::ContractClassCache, state_api::{State, StateReader}, }; use crate::{ @@ -10,16 +11,16 @@ use cairo_vm::felt::Felt252; use std::collections::HashSet; #[derive(Debug)] -pub(crate) struct ContractStorageState<'a, S: StateReader> { - pub(crate) state: &'a mut CachedState, +pub(crate) struct ContractStorageState<'a, S: StateReader, C: ContractClassCache> { + pub(crate) state: &'a mut CachedState, pub(crate) contract_address: Address, /// Maintain all read request values in chronological order pub(crate) read_values: Vec, pub(crate) accessed_keys: HashSet, } -impl<'a, S: StateReader> ContractStorageState<'a, S> { - pub(crate) fn new(state: &'a mut CachedState, contract_address: Address) -> Self { +impl<'a, S: StateReader, C: ContractClassCache> ContractStorageState<'a, S, C> { + pub(crate) fn new(state: &'a mut CachedState, contract_address: Address) -> Self { Self { state, contract_address, diff --git a/src/state/mod.rs b/src/state/mod.rs index c021862a6..63b88ecc7 100644 --- a/src/state/mod.rs +++ b/src/state/mod.rs @@ -1,25 +1,25 @@ -pub mod cached_state; -pub(crate) mod contract_storage_state; -pub mod in_memory_state_reader; -pub mod state_api; -pub mod state_cache; - +use self::{ + cached_state::CachedState, contract_class_cache::ContractClassCache, state_api::StateReader, + state_cache::StateCache, +}; use crate::{ core::errors::state_errors::StateError, + transaction::error::TransactionError, utils::{ - get_keys, to_cache_state_storage_mapping, to_state_diff_storage_mapping, CompiledClassHash, + get_keys, to_cache_state_storage_mapping, to_state_diff_storage_mapping, Address, + ClassHash, CompiledClassHash, }, }; use cairo_vm::{felt::Felt252, vm::runners::cairo_runner::ExecutionResources}; use getset::Getters; use std::{collections::HashMap, sync::Arc}; -use crate::{ - transaction::error::TransactionError, - utils::{Address, ClassHash}, -}; - -use self::{cached_state::CachedState, state_api::StateReader, state_cache::StateCache}; +pub mod cached_state; +pub mod contract_class_cache; +pub(crate) mod contract_storage_state; +pub mod in_memory_state_reader; +pub mod state_api; +pub mod state_cache; #[derive(Clone, Debug, PartialEq, Eq)] pub struct BlockInfo { @@ -142,11 +142,16 @@ impl StateDiff { }) } - pub fn to_cached_state(&self, state_reader: Arc) -> Result, StateError> + pub fn to_cached_state( + &self, + state_reader: Arc, + contract_class_cache: Arc, + ) -> Result, StateError> where T: StateReader + Clone, + C: ContractClassCache, { - let mut cache_state = CachedState::new(state_reader, HashMap::new()); + let mut cache_state = CachedState::new(state_reader, contract_class_cache); let cache_storage_mapping = to_cache_state_storage_mapping(&self.storage_updates); cache_state.cache_mut().set_initial_values( @@ -213,19 +218,19 @@ fn test_validate_legal_progress() { #[cfg(test)] mod test { - use std::{collections::HashMap, sync::Arc}; - use super::StateDiff; use crate::{ - state::in_memory_state_reader::InMemoryStateReader, state::{ cached_state::CachedState, + contract_class_cache::PermanentContractClassCache, + in_memory_state_reader::InMemoryStateReader, state_api::StateReader, state_cache::{StateCache, StorageEntry}, }, utils::{Address, ClassHash}, }; use cairo_vm::felt::Felt252; + use std::{collections::HashMap, sync::Arc}; #[test] fn test_from_cached_state_without_updates() { @@ -242,7 +247,10 @@ mod test { .address_to_nonce .insert(contract_address, nonce); - let cached_state = CachedState::new(Arc::new(state_reader), HashMap::new()); + let cached_state = CachedState::new( + Arc::new(state_reader), + Arc::new(PermanentContractClassCache::default()), + ); let diff = StateDiff::from_cached_state(&cached_state.cache).unwrap(); @@ -312,16 +320,27 @@ mod test { .address_to_nonce .insert(contract_address.clone(), nonce); - let cached_state_original = - CachedState::new(Arc::new(state_reader.clone()), HashMap::new()); + let cached_state_original = CachedState::new( + Arc::new(state_reader.clone()), + Arc::new(PermanentContractClassCache::default()), + ); let diff = StateDiff::from_cached_state(cached_state_original.cache()).unwrap(); - let cached_state = diff.to_cached_state(Arc::new(state_reader)).unwrap(); + let cached_state = diff + .to_cached_state::<_, PermanentContractClassCache>( + Arc::new(state_reader), + Arc::new(PermanentContractClassCache::default()), + ) + .unwrap(); assert_eq!( - cached_state_original.contract_classes(), - cached_state.contract_classes() + (&*cached_state_original.contract_class_cache().clone()) + .into_iter() + .collect::>(), + (&*cached_state.contract_class_cache().clone()) + .into_iter() + .collect::>() ); assert_eq!( cached_state_original @@ -360,8 +379,11 @@ mod test { storage_writes, HashMap::new(), ); - let cached_state = - CachedState::new_for_testing(Arc::new(state_reader), cache, HashMap::new()); + let cached_state = CachedState::new_for_testing( + Arc::new(state_reader), + cache, + Arc::new(PermanentContractClassCache::default()), + ); let mut diff = StateDiff::from_cached_state(cached_state.cache()).unwrap(); diff --git a/src/syscalls/business_logic_syscall_handler.rs b/src/syscalls/business_logic_syscall_handler.rs index dd45fe33a..33c54a7c9 100644 --- a/src/syscalls/business_logic_syscall_handler.rs +++ b/src/syscalls/business_logic_syscall_handler.rs @@ -55,6 +55,7 @@ use cairo_vm::{ use lazy_static::lazy_static; use crate::services::api::contract_classes::deprecated_contract_class::EntryPointType; +use crate::state::contract_class_cache::ContractClassCache; use num_traits::{One, ToPrimitive, Zero}; #[cfg(feature = "cairo-native")] @@ -127,7 +128,7 @@ lazy_static! { } #[derive(Debug)] -pub struct BusinessLogicSyscallHandler<'a, S: StateReader> { +pub struct BusinessLogicSyscallHandler<'a, S: StateReader, C: ContractClassCache> { pub(crate) events: Vec, pub(crate) expected_syscall_ptr: Relocatable, pub(crate) resources_manager: ExecutionResourcesManager, @@ -138,7 +139,7 @@ pub struct BusinessLogicSyscallHandler<'a, S: StateReader> { pub(crate) read_only_segments: Vec<(Relocatable, MaybeRelocatable)>, pub(crate) internal_calls: Vec, pub(crate) block_context: BlockContext, - pub(crate) starknet_storage_state: ContractStorageState<'a, S>, + pub(crate) starknet_storage_state: ContractStorageState<'a, S, C>, pub(crate) support_reverted: bool, pub(crate) entry_point_selector: Felt252, pub(crate) selector_to_syscall: &'a HashMap, @@ -147,11 +148,11 @@ pub struct BusinessLogicSyscallHandler<'a, S: StateReader> { // TODO: execution entry point may no be a parameter field, but there is no way to generate a default for now -impl<'a, S: StateReader> BusinessLogicSyscallHandler<'a, S> { +impl<'a, S: StateReader, C: ContractClassCache> BusinessLogicSyscallHandler<'a, S, C> { #[allow(clippy::too_many_arguments)] pub fn new( tx_execution_context: TransactionExecutionContext, - state: &'a mut CachedState, + state: &'a mut CachedState, resources_manager: ExecutionResourcesManager, caller_address: Address, contract_address: Address, @@ -184,7 +185,8 @@ impl<'a, S: StateReader> BusinessLogicSyscallHandler<'a, S> { execution_info_ptr: None, } } - pub fn default_with_state(state: &'a mut CachedState) -> Self { + + pub fn default_with_state(state: &'a mut CachedState) -> Self { BusinessLogicSyscallHandler::new_for_testing( BlockInfo::default(), Default::default(), @@ -195,7 +197,7 @@ impl<'a, S: StateReader> BusinessLogicSyscallHandler<'a, S> { pub fn new_for_testing( block_info: BlockInfo, _contract_address: Address, - state: &'a mut CachedState, + state: &'a mut CachedState, ) -> Self { let syscalls = Vec::from([ "emit_event".to_string(), @@ -359,7 +361,7 @@ impl<'a, S: StateReader> BusinessLogicSyscallHandler<'a, S> { if self.constructor_entry_points_empty(compiled_class)? { if !constructor_calldata.is_empty() { - return Err(StateError::ConstructorCalldataEmpty()); + return Err(StateError::ConstructorCalldataEmpty); } let call_info = CallInfo::empty_constructor_call( @@ -398,7 +400,7 @@ impl<'a, S: StateReader> BusinessLogicSyscallHandler<'a, S> { #[cfg(feature = "cairo-native")] program_cache, ) - .map_err(|_| StateError::ExecutionEntryPoint())?; + .map_err(|_| StateError::ExecutionEntryPoint)?; let call_info = call_info.ok_or(StateError::CustomError( revert_error.unwrap_or_else(|| "Execution error".to_string()), @@ -608,7 +610,7 @@ impl<'a, S: StateReader> BusinessLogicSyscallHandler<'a, S> { } } -impl<'a, S: StateReader> BusinessLogicSyscallHandler<'a, S> { +impl<'a, S: StateReader, C: ContractClassCache> BusinessLogicSyscallHandler<'a, S, C> { fn emit_event( &mut self, vm: &VirtualMachine, diff --git a/src/syscalls/deprecated_business_logic_syscall_handler.rs b/src/syscalls/deprecated_business_logic_syscall_handler.rs index 92d3ac8f4..9537e1890 100644 --- a/src/syscalls/deprecated_business_logic_syscall_handler.rs +++ b/src/syscalls/deprecated_business_logic_syscall_handler.rs @@ -26,6 +26,7 @@ use crate::{ }, state::ExecutionResourcesManager, state::{ + contract_class_cache::ContractClassCache, contract_storage_state::ContractStorageState, state_api::{State, StateReader}, BlockInfo, @@ -55,7 +56,7 @@ use { //* ----------------------------------- /// Deprecated version of BusinessLogicSyscallHandler. #[derive(Debug)] -pub struct DeprecatedBLSyscallHandler<'a, S: StateReader> { +pub struct DeprecatedBLSyscallHandler<'a, S: StateReader, C: ContractClassCache> { pub(crate) tx_execution_context: TransactionExecutionContext, /// Events emitted by the current contract call. pub(crate) events: Vec, @@ -67,15 +68,15 @@ pub struct DeprecatedBLSyscallHandler<'a, S: StateReader> { pub(crate) l2_to_l1_messages: Vec, pub(crate) block_context: BlockContext, pub(crate) tx_info_ptr: Option, - pub(crate) starknet_storage_state: ContractStorageState<'a, S>, + pub(crate) starknet_storage_state: ContractStorageState<'a, S, C>, pub(crate) internal_calls: Vec, pub(crate) expected_syscall_ptr: Relocatable, } -impl<'a, S: StateReader> DeprecatedBLSyscallHandler<'a, S> { +impl<'a, S: StateReader, C: ContractClassCache> DeprecatedBLSyscallHandler<'a, S, C> { pub fn new( tx_execution_context: TransactionExecutionContext, - state: &'a mut CachedState, + state: &'a mut CachedState, resources_manager: ExecutionResourcesManager, caller_address: Address, contract_address: Address, @@ -106,7 +107,7 @@ impl<'a, S: StateReader> DeprecatedBLSyscallHandler<'a, S> { } } - pub fn default_with(state: &'a mut CachedState) -> Self { + pub fn default_with(state: &'a mut CachedState) -> Self { DeprecatedBLSyscallHandler::new_for_testing(BlockInfo::default(), Default::default(), state) } @@ -119,7 +120,7 @@ impl<'a, S: StateReader> DeprecatedBLSyscallHandler<'a, S> { pub fn new_for_testing( block_info: BlockInfo, _contract_address: Address, - state: &'a mut CachedState, + state: &'a mut CachedState, ) -> Self { let syscalls = Vec::from([ "emit_event".to_string(), @@ -218,7 +219,7 @@ impl<'a, S: StateReader> DeprecatedBLSyscallHandler<'a, S> { if self.constructor_entry_points_empty(contract_class)? { if !constructor_calldata.is_empty() { - return Err(StateError::ConstructorCalldataEmpty()); + return Err(StateError::ConstructorCalldataEmpty); } let call_info = CallInfo::empty_constructor_call( @@ -255,7 +256,7 @@ impl<'a, S: StateReader> DeprecatedBLSyscallHandler<'a, S> { #[cfg(feature = "cairo-native")] program_cache, ) - .map_err(|_| StateError::ExecutionEntryPoint())?; + .map_err(|_| StateError::ExecutionEntryPoint)?; if let Some(call_info) = call_info.call_info { self.internal_calls.push(call_info); @@ -265,7 +266,7 @@ impl<'a, S: StateReader> DeprecatedBLSyscallHandler<'a, S> { } } -impl<'a, S: StateReader> DeprecatedBLSyscallHandler<'a, S> { +impl<'a, S: StateReader, C: ContractClassCache> DeprecatedBLSyscallHandler<'a, S, C> { pub(crate) fn emit_event( &mut self, vm: &VirtualMachine, @@ -1007,7 +1008,10 @@ impl<'a, S: StateReader> DeprecatedBLSyscallHandler<'a, S> { mod tests { use crate::{ state::cached_state::CachedState, - state::in_memory_state_reader::InMemoryStateReader, + state::{ + contract_class_cache::PermanentContractClassCache, + in_memory_state_reader::InMemoryStateReader, + }, syscalls::syscall_handler_errors::SyscallHandlerError, utils::{test_utils::*, Address}, }; @@ -1028,7 +1032,7 @@ mod tests { use std::{any::Any, borrow::Cow, collections::HashMap, sync::Arc}; type DeprecatedBLSyscallHandler<'a> = - super::DeprecatedBLSyscallHandler<'a, InMemoryStateReader>; + super::DeprecatedBLSyscallHandler<'a, InMemoryStateReader, PermanentContractClassCache>; #[test] fn run_alloc_hint_ap_is_not_empty() { @@ -1049,7 +1053,7 @@ mod tests { #[test] fn deploy_from_zero_error() { - let mut state = CachedState::::default(); + let mut state = CachedState::::default(); let mut syscall = DeprecatedBLSyscallHandler::default_with(&mut state); let mut vm = vm!(); @@ -1080,7 +1084,7 @@ mod tests { #[test] fn can_allocate_segment() { - let mut state = CachedState::::default(); + let mut state = CachedState::::default(); let mut syscall_handler = DeprecatedBLSyscallHandler::default_with(&mut state); let mut vm = vm!(); let data = vec![MaybeRelocatable::Int(7.into())]; @@ -1096,7 +1100,7 @@ mod tests { #[test] fn test_get_block_number() { - let mut state = CachedState::::default(); + let mut state = CachedState::::default(); let mut syscall = DeprecatedBLSyscallHandler::default_with(&mut state); let mut vm = vm!(); @@ -1116,7 +1120,7 @@ mod tests { #[test] fn test_get_contract_address_ok() { - let mut state = CachedState::::default(); + let mut state = CachedState::::default(); let mut syscall = DeprecatedBLSyscallHandler::default_with(&mut state); let mut vm = vm!(); @@ -1133,7 +1137,7 @@ mod tests { #[test] fn test_storage_read_empty() { - let mut state = CachedState::::default(); + let mut state = CachedState::::default(); let mut syscall_handler = DeprecatedBLSyscallHandler::default_with(&mut state); assert_matches!( @@ -1151,7 +1155,10 @@ mod tests { Felt252::zero(), ); // Create empty-cached state - let mut state = CachedState::new(Arc::new(state_reader), HashMap::new()); + let mut state = CachedState::new( + Arc::new(state_reader), + Arc::new(PermanentContractClassCache::default()), + ); let mut syscall_handler = DeprecatedBLSyscallHandler::default_with(&mut state); // Perform write assert!(syscall_handler diff --git a/src/syscalls/deprecated_syscall_handler.rs b/src/syscalls/deprecated_syscall_handler.rs index 5300b3332..9dc3c1bbc 100644 --- a/src/syscalls/deprecated_syscall_handler.rs +++ b/src/syscalls/deprecated_syscall_handler.rs @@ -2,38 +2,41 @@ use super::{ deprecated_business_logic_syscall_handler::DeprecatedBLSyscallHandler, hint_code::*, other_syscalls, syscall_handler::HintProcessorPostRun, }; -use crate::{state::state_api::StateReader, syscalls::syscall_handler_errors::SyscallHandlerError}; -use cairo_vm::{ - felt::Felt252, - hint_processor::hint_processor_definition::HintProcessorLogic, - vm::runners::cairo_runner::{ResourceTracker, RunResources}, +use crate::{ + state::{contract_class_cache::ContractClassCache, state_api::StateReader}, + syscalls::syscall_handler_errors::SyscallHandlerError, }; use cairo_vm::{ + felt::Felt252, hint_processor::{ builtin_hint_processor::{ builtin_hint_processor_definition::{BuiltinHintProcessor, HintProcessorData}, hint_utils::get_relocatable_from_var_name, }, - hint_processor_definition::HintReference, + hint_processor_definition::{HintProcessorLogic, HintReference}, }, serde::deserialize_program::ApTracking, types::{exec_scope::ExecutionScopes, relocatable::Relocatable}, - vm::{errors::hint_errors::HintError, vm_core::VirtualMachine}, + vm::{ + errors::hint_errors::HintError, + runners::cairo_runner::{ResourceTracker, RunResources}, + vm_core::VirtualMachine, + }, }; use std::{any::Any, collections::HashMap}; /// Definition of the deprecated syscall hint processor with associated structs -pub(crate) struct DeprecatedSyscallHintProcessor<'a, S: StateReader> { +pub(crate) struct DeprecatedSyscallHintProcessor<'a, S: StateReader, C: ContractClassCache> { pub(crate) builtin_hint_processor: BuiltinHintProcessor, - pub(crate) syscall_handler: DeprecatedBLSyscallHandler<'a, S>, + pub(crate) syscall_handler: DeprecatedBLSyscallHandler<'a, S, C>, run_resources: RunResources, } /// Implementations and methods for DeprecatedSyscallHintProcessor -impl<'a, S: StateReader> DeprecatedSyscallHintProcessor<'a, S> { +impl<'a, S: StateReader, C: ContractClassCache> DeprecatedSyscallHintProcessor<'a, S, C> { /// Constructor for DeprecatedSyscallHintProcessor pub fn new( - syscall_handler: DeprecatedBLSyscallHandler<'a, S>, + syscall_handler: DeprecatedBLSyscallHandler<'a, S, C>, run_resources: RunResources, ) -> Self { DeprecatedSyscallHintProcessor { @@ -191,7 +194,9 @@ impl<'a, S: StateReader> DeprecatedSyscallHintProcessor<'a, S> { } /// Implement the HintProcessorLogic trait for DeprecatedSyscallHintProcessor -impl<'a, S: StateReader> HintProcessorLogic for DeprecatedSyscallHintProcessor<'a, S> { +impl<'a, S: StateReader, C: ContractClassCache> HintProcessorLogic + for DeprecatedSyscallHintProcessor<'a, S, C> +{ /// Executes the received hint fn execute_hint( &mut self, @@ -215,7 +220,9 @@ impl<'a, S: StateReader> HintProcessorLogic for DeprecatedSyscallHintProcessor<' } /// Implement the ResourceTracker trait for DeprecatedSyscallHintProcessor -impl<'a, S: StateReader> ResourceTracker for DeprecatedSyscallHintProcessor<'a, S> { +impl<'a, S: StateReader, C: ContractClassCache> ResourceTracker + for DeprecatedSyscallHintProcessor<'a, S, C> +{ fn consumed(&self) -> bool { self.run_resources.consumed() } @@ -234,7 +241,9 @@ impl<'a, S: StateReader> ResourceTracker for DeprecatedSyscallHintProcessor<'a, } /// Implement the HintProcessorPostRun trait for DeprecatedSyscallHintProcessor -impl<'a, S: StateReader> HintProcessorPostRun for DeprecatedSyscallHintProcessor<'a, S> { +impl<'a, S: StateReader, C: ContractClassCache> HintProcessorPostRun + for DeprecatedSyscallHintProcessor<'a, S, C> +{ /// Validates the execution post run fn post_run( &self, @@ -259,8 +268,6 @@ fn get_syscall_ptr( /// Unit tests for this module #[cfg(test)] mod tests { - use std::sync::Arc; - use super::*; use crate::services::api::contract_classes::compiled_class::CompiledClass; use crate::services::api::contract_classes::deprecated_contract_class::EntryPointType; @@ -275,8 +282,10 @@ mod tests { execution::{OrderedEvent, OrderedL2ToL1Message, TransactionExecutionContext}, memory_insert, services::api::contract_classes::deprecated_contract_class::ContractClass, - state::in_memory_state_reader::InMemoryStateReader, - state::{cached_state::CachedState, state_api::State}, + state::{ + cached_state::CachedState, contract_class_cache::PermanentContractClassCache, + in_memory_state_reader::InMemoryStateReader, state_api::State, + }, syscalls::deprecated_syscall_request::{ DeprecatedDeployRequest, DeprecatedSendMessageToL1SysCallRequest, DeprecatedSyscallRequest, @@ -290,18 +299,20 @@ mod tests { }; use cairo_vm::relocatable; use num_traits::Num; + use std::sync::Arc; type DeprecatedBLSyscallHandler<'a> = crate::syscalls::deprecated_business_logic_syscall_handler::DeprecatedBLSyscallHandler< 'a, InMemoryStateReader, + PermanentContractClassCache, >; - type SyscallHintProcessor<'a, T> = super::DeprecatedSyscallHintProcessor<'a, T>; + type SyscallHintProcessor<'a, T, C> = super::DeprecatedSyscallHintProcessor<'a, T, C>; /// Test checks if the send_message_to_l1 syscall is read correctly. #[test] fn read_send_message_to_l1_request() { - let mut state = CachedState::::default(); + let mut state = CachedState::::default(); let syscall = DeprecatedBLSyscallHandler::default_with(&mut state); let mut vm = vm!(); add_segments!(vm, 3); @@ -324,7 +335,7 @@ mod tests { /// Test verifies if the read syscall can correctly read a deploy request. #[test] fn read_deploy_syscall_request() { - let mut state = CachedState::::default(); + let mut state = CachedState::::default(); let syscall = DeprecatedBLSyscallHandler::default_with(&mut state); let mut vm = vm!(); add_segments!(vm, 2); @@ -357,7 +368,7 @@ mod tests { /// Test checks the get block timestamp for business logic. #[test] fn get_block_timestamp_for_business_logic() { - let mut state = CachedState::::default(); + let mut state = CachedState::::default(); let syscall = DeprecatedBLSyscallHandler::default_with(&mut state); let mut vm = vm!(); add_segments!(vm, 2); @@ -375,7 +386,7 @@ mod tests { let hint_data = HintProcessorData::new_default(GET_BLOCK_TIMESTAMP.to_string(), ids_data); // invoke syscall - let mut state = CachedState::::default(); + let mut state = CachedState::::default(); let mut syscall_handler = SyscallHintProcessor::new( DeprecatedBLSyscallHandler::default_with(&mut state), RunResources::default(), @@ -409,7 +420,7 @@ mod tests { let hint_data = HintProcessorData::new_default(GET_SEQUENCER_ADDRESS.to_string(), ids_data); // invoke syscall - let mut state = CachedState::::default(); + let mut state = CachedState::::default(); let mut syscall_handler = SyscallHintProcessor::new( DeprecatedBLSyscallHandler::default_with(&mut state), RunResources::default(), @@ -463,7 +474,7 @@ mod tests { let hint_data = HintProcessorData::new_default(EMIT_EVENT_CODE.to_string(), ids_data); // invoke syscall - let mut state = CachedState::::default(); + let mut state = CachedState::::default(); let mut syscall_handler = SyscallHintProcessor::new( DeprecatedBLSyscallHandler::default_with(&mut state), RunResources::default(), @@ -522,7 +533,7 @@ mod tests { let hint_data = HintProcessorData::new_default(GET_TX_INFO.to_string(), ids_data); // invoke syscall - let mut state = CachedState::::default(); + let mut state = CachedState::::default(); let mut syscall_handler_hint_processor = SyscallHintProcessor::new( DeprecatedBLSyscallHandler::default_with(&mut state), RunResources::default(), @@ -630,7 +641,7 @@ mod tests { let hint_data = HintProcessorData::new_default(GET_TX_INFO.to_string(), ids_data); // invoke syscall - let mut state = CachedState::::default(); + let mut state = CachedState::::default(); let mut syscall_handler_hint_processor = SyscallHintProcessor::new( DeprecatedBLSyscallHandler::default_with(&mut state), RunResources::default(), @@ -671,7 +682,7 @@ mod tests { let hint_data = HintProcessorData::new_default(GET_CALLER_ADDRESS.to_string(), ids_data); // invoke syscall - let mut state = CachedState::::default(); + let mut state = CachedState::::default(); let mut hint_processor = SyscallHintProcessor::new( DeprecatedBLSyscallHandler::default_with(&mut state), RunResources::default(), @@ -719,7 +730,7 @@ mod tests { let hint_data = HintProcessorData::new_default(SEND_MESSAGE_TO_L1.to_string(), ids_data); // invoke syscall - let mut state = CachedState::::default(); + let mut state = CachedState::::default(); let mut hint_processor = SyscallHintProcessor::new( DeprecatedBLSyscallHandler::default_with(&mut state), RunResources::default(), @@ -768,7 +779,10 @@ mod tests { ] ); - let mut state = CachedState::new(Arc::new(InMemoryStateReader::default()), HashMap::new()); + let mut state = CachedState::new( + Arc::new(InMemoryStateReader::default()), + Arc::new(PermanentContractClassCache::default()), + ); let mut hint_processor = SyscallHintProcessor::new( DeprecatedBLSyscallHandler::default_with(&mut state), RunResources::default(), @@ -804,7 +818,10 @@ mod tests { let hint_data = HintProcessorData::new_default(GET_CONTRACT_ADDRESS.to_string(), ids_data); // invoke syscall - let mut state = CachedState::new(Arc::new(InMemoryStateReader::default()), HashMap::new()); + let mut state = CachedState::new( + Arc::new(InMemoryStateReader::default()), + Arc::new(PermanentContractClassCache::default()), + ); let mut hint_processor = SyscallHintProcessor::new( DeprecatedBLSyscallHandler::default_with(&mut state), RunResources::default(), @@ -846,7 +863,10 @@ mod tests { let hint_data = HintProcessorData::new_default(GET_TX_SIGNATURE.to_string(), ids_data); // invoke syscall - let mut state = CachedState::new(Arc::new(InMemoryStateReader::default()), HashMap::new()); + let mut state = CachedState::new( + Arc::new(InMemoryStateReader::default()), + Arc::new(PermanentContractClassCache::default()), + ); let mut syscall_handler_hint_processor = SyscallHintProcessor::new( DeprecatedBLSyscallHandler::default_with(&mut state), RunResources::default(), @@ -914,7 +934,10 @@ mod tests { let hint_data = HintProcessorData::new_default(STORAGE_READ.to_string(), ids_data); - let mut state = CachedState::new(Arc::new(InMemoryStateReader::default()), HashMap::new()); + let mut state = CachedState::new( + Arc::new(InMemoryStateReader::default()), + Arc::new(PermanentContractClassCache::default()), + ); let mut syscall_handler_hint_processor = SyscallHintProcessor::new( DeprecatedBLSyscallHandler::default_with(&mut state), RunResources::default(), @@ -979,7 +1002,10 @@ mod tests { let hint_data = HintProcessorData::new_default(STORAGE_WRITE.to_string(), ids_data); - let mut state = CachedState::new(Arc::new(InMemoryStateReader::default()), HashMap::new()); + let mut state = CachedState::new( + Arc::new(InMemoryStateReader::default()), + Arc::new(PermanentContractClassCache::default()), + ); let mut syscall_handler_hint_processor = SyscallHintProcessor::new( DeprecatedBLSyscallHandler::default_with(&mut state), RunResources::default(), @@ -1053,7 +1079,10 @@ mod tests { let hint_data = HintProcessorData::new_default(DEPLOY.to_string(), ids_data); // Create SyscallHintProcessor - let mut state = CachedState::new(Arc::new(InMemoryStateReader::default()), HashMap::new()); + let mut state = CachedState::new( + Arc::new(InMemoryStateReader::default()), + Arc::new(PermanentContractClassCache::default()), + ); let mut syscall_handler_hint_processor = SyscallHintProcessor::new( DeprecatedBLSyscallHandler::default_with(&mut state), RunResources::default(), @@ -1147,7 +1176,10 @@ mod tests { ); // Create SyscallHintProcessor - let mut state = CachedState::new(Arc::new(InMemoryStateReader::default()), HashMap::new()); + let mut state = CachedState::new( + Arc::new(InMemoryStateReader::default()), + Arc::new(PermanentContractClassCache::default()), + ); let mut syscall_handler_hint_processor = SyscallHintProcessor::new( DeprecatedBLSyscallHandler::default_with(&mut state), RunResources::default(), @@ -1218,7 +1250,7 @@ mod tests { ) .unwrap(); - let mut transactional = state.create_transactional(); + let mut transactional = state.create_transactional().unwrap(); // Invoke result let result = internal_invoke_function .apply( diff --git a/src/syscalls/deprecated_syscall_response.rs b/src/syscalls/deprecated_syscall_response.rs index 79fbe4be2..a6baf3767 100644 --- a/src/syscalls/deprecated_syscall_response.rs +++ b/src/syscalls/deprecated_syscall_response.rs @@ -309,28 +309,34 @@ impl DeprecatedWriteSyscallResponse for DeprecatedStorageReadResponse { #[cfg(test)] mod tests { - use std::{collections::HashMap, sync::Arc}; - use super::*; use crate::{ add_segments, state::cached_state::CachedState, - state::in_memory_state_reader::InMemoryStateReader, + state::{ + contract_class_cache::PermanentContractClassCache, + in_memory_state_reader::InMemoryStateReader, + }, utils::{get_integer, test_utils::vm}, }; use cairo_vm::relocatable; + use std::sync::Arc; type DeprecatedBLSyscallHandler<'a> = crate::syscalls::deprecated_business_logic_syscall_handler::DeprecatedBLSyscallHandler< 'a, InMemoryStateReader, + PermanentContractClassCache, >; /// Unit test to check the write_get_caller_address_response function #[test] fn write_get_caller_address_response() { // Initialize a VM and syscall handler - let mut state = CachedState::new(Arc::new(InMemoryStateReader::default()), HashMap::new()); + let mut state = CachedState::new( + Arc::new(InMemoryStateReader::default()), + Arc::new(PermanentContractClassCache::default()), + ); let syscall = DeprecatedBLSyscallHandler::default_with(&mut state); let mut vm = vm!(); diff --git a/src/syscalls/native_syscall_handler.rs b/src/syscalls/native_syscall_handler.rs index c4a24b7c1..d68028888 100644 --- a/src/syscalls/native_syscall_handler.rs +++ b/src/syscalls/native_syscall_handler.rs @@ -1,3 +1,4 @@ +use crate::ContractClassCache; use std::{cell::RefCell, rc::Rc}; use cairo_native::{ @@ -36,11 +37,12 @@ use crate::{ }; #[derive(Debug)] -pub struct NativeSyscallHandler<'a, 'cache, S> +pub struct NativeSyscallHandler<'a, 'cache, S, C> where S: StateReader, + C: ContractClassCache, { - pub(crate) starknet_storage_state: ContractStorageState<'a, S>, + pub(crate) starknet_storage_state: ContractStorageState<'a, S, C>, pub(crate) contract_address: Address, pub(crate) caller_address: Address, pub(crate) entry_point_selector: Felt252, @@ -53,7 +55,7 @@ where pub(crate) program_cache: Rc>>, } -impl<'a, 'cache, S: StateReader> NativeSyscallHandler<'a, 'cache, S> { +impl<'a, 'cache, S: StateReader, C: ContractClassCache> NativeSyscallHandler<'a, 'cache, S, C> { /// Generic code that needs to be run on all syscalls. fn handle_syscall_request(&mut self, gas: &mut u128, syscall_name: &str) -> SyscallResult<()> { let required_gas = SYSCALL_GAS_COST @@ -76,7 +78,9 @@ impl<'a, 'cache, S: StateReader> NativeSyscallHandler<'a, 'cache, S> { } } -impl<'a, 'cache, S: StateReader> StarkNetSyscallHandler for NativeSyscallHandler<'a, 'cache, S> { +impl<'a, 'cache, S: StateReader, C: ContractClassCache> StarkNetSyscallHandler + for NativeSyscallHandler<'a, 'cache, S, C> +{ fn get_block_hash( &mut self, block_number: u64, @@ -591,9 +595,10 @@ impl<'a, 'cache, S: StateReader> StarkNetSyscallHandler for NativeSyscallHandler } } -impl<'a, 'cache, S> NativeSyscallHandler<'a, 'cache, S> +impl<'a, 'cache, S, C> NativeSyscallHandler<'a, 'cache, S, C> where S: StateReader, + C: ContractClassCache, { fn execute_constructor_entry_point( &mut self, @@ -618,7 +623,7 @@ where if self.constructor_entry_points_empty(compiled_class)? { if !constructor_calldata.is_empty() { - return Err(StateError::ConstructorCalldataEmpty()); + return Err(StateError::ConstructorCalldataEmpty); } let call_info = CallInfo::empty_constructor_call( @@ -652,7 +657,7 @@ where u64::MAX, Some(self.program_cache.clone()), ) - .map_err(|_| StateError::ExecutionEntryPoint())?; + .map_err(|_| StateError::ExecutionEntryPoint)?; let call_info = call_info.ok_or(StateError::CustomError("Execution error".to_string()))?; diff --git a/src/syscalls/syscall_handler.rs b/src/syscalls/syscall_handler.rs index 7936f5a11..dfef8051d 100644 --- a/src/syscalls/syscall_handler.rs +++ b/src/syscalls/syscall_handler.rs @@ -1,5 +1,5 @@ use super::business_logic_syscall_handler::BusinessLogicSyscallHandler; -use crate::state::state_api::StateReader; +use crate::state::{contract_class_cache::ContractClassCache, state_api::StateReader}; use crate::transaction::error::TransactionError; use cairo_lang_casm::{ hints::{Hint, StarknetHint}, @@ -32,15 +32,15 @@ pub(crate) trait HintProcessorPostRun { } #[allow(unused)] -pub(crate) struct SyscallHintProcessor<'a, S: StateReader> { +pub(crate) struct SyscallHintProcessor<'a, S: StateReader, C: ContractClassCache> { pub(crate) cairo1_hint_processor: Cairo1HintProcessor, - pub(crate) syscall_handler: BusinessLogicSyscallHandler<'a, S>, + pub(crate) syscall_handler: BusinessLogicSyscallHandler<'a, S, C>, pub(crate) run_resources: RunResources, } -impl<'a, S: StateReader> SyscallHintProcessor<'a, S> { +impl<'a, S: StateReader, C: ContractClassCache> SyscallHintProcessor<'a, S, C> { pub fn new( - syscall_handler: BusinessLogicSyscallHandler<'a, S>, + syscall_handler: BusinessLogicSyscallHandler<'a, S, C>, hints: &[(usize, Vec)], run_resources: RunResources, ) -> Self { @@ -52,7 +52,9 @@ impl<'a, S: StateReader> SyscallHintProcessor<'a, S> { } } -impl<'a, S: StateReader> HintProcessorLogic for SyscallHintProcessor<'a, S> { +impl<'a, S: StateReader, C: ContractClassCache> HintProcessorLogic + for SyscallHintProcessor<'a, S, C> +{ fn execute_hint( &mut self, vm: &mut VirtualMachine, @@ -117,7 +119,7 @@ impl<'a, S: StateReader> HintProcessorLogic for SyscallHintProcessor<'a, S> { } } -impl<'a, S: StateReader> ResourceTracker for SyscallHintProcessor<'a, S> { +impl<'a, S: StateReader, C: ContractClassCache> ResourceTracker for SyscallHintProcessor<'a, S, C> { fn consumed(&self) -> bool { self.run_resources.consumed() } @@ -135,7 +137,9 @@ impl<'a, S: StateReader> ResourceTracker for SyscallHintProcessor<'a, S> { } } -impl<'a, S: StateReader> HintProcessorPostRun for SyscallHintProcessor<'a, S> { +impl<'a, S: StateReader, C: ContractClassCache> HintProcessorPostRun + for SyscallHintProcessor<'a, S, C> +{ fn post_run( &self, runner: &mut VirtualMachine, diff --git a/src/transaction/declare.rs b/src/transaction/declare.rs index 5f02bc4d1..51c2f61f4 100644 --- a/src/transaction/declare.rs +++ b/src/transaction/declare.rs @@ -1,25 +1,23 @@ -use crate::execution::execution_entry_point::ExecutionResult; +use crate::core::contract_address::compute_deprecated_class_hash; +use crate::core::transaction_hash::calculate_declare_transaction_hash; +use crate::definitions::block_context::BlockContext; +use crate::definitions::constants::VALIDATE_DECLARE_ENTRY_POINT_SELECTOR; +use crate::definitions::transaction_type::TransactionType; use crate::execution::gas_usage::get_onchain_data_segment_length; use crate::execution::os_usage::ESTIMATED_DECLARE_STEPS; use crate::services::api::contract_classes::deprecated_contract_class::EntryPointType; use crate::services::eth_definitions::eth_gas_constants::SHARP_GAS_PER_MEMORY_WORD; use crate::state::cached_state::CachedState; -use crate::state::state_api::StateChangesCount; +use crate::state::contract_class_cache::ContractClassCache; +use crate::state::state_api::{State, StateChangesCount, StateReader}; use crate::{ - core::{ - contract_address::compute_deprecated_class_hash, - transaction_hash::calculate_declare_transaction_hash, - }, - definitions::{ - block_context::BlockContext, constants::VALIDATE_DECLARE_ENTRY_POINT_SELECTOR, - transaction_type::TransactionType, - }, execution::{ - execution_entry_point::ExecutionEntryPoint, CallInfo, TransactionExecutionContext, - TransactionExecutionInfo, + execution_entry_point::{ExecutionEntryPoint, ExecutionResult}, + CallInfo, TransactionExecutionContext, TransactionExecutionInfo, + }, + services::api::contract_classes::{ + compiled_class::CompiledClass, deprecated_contract_class::ContractClass, }, - services::api::contract_classes::deprecated_contract_class::ContractClass, - state::state_api::{State, StateReader}, state::ExecutionResourcesManager, transaction::error::TransactionError, utils::{ @@ -32,7 +30,6 @@ use num_traits::{One, Zero}; use super::fee::{calculate_tx_fee, charge_fee}; use super::{get_tx_version, Transaction}; -use crate::services::api::contract_classes::compiled_class::CompiledClass; use std::collections::HashMap; use std::fmt::Debug; use std::sync::Arc; @@ -187,9 +184,9 @@ impl Declare { /// Executes a call to the cairo-vm using the accounts_validation.cairo contract to validate /// the contract that is being declared. Then it returns the transaction execution info of the run. - pub fn apply( + pub fn apply( &self, - state: &mut CachedState, + state: &mut CachedState, block_context: &BlockContext, #[cfg(feature = "cairo-native")] program_cache: Option< Rc>>, @@ -246,9 +243,9 @@ impl Declare { ) } - pub fn run_validate_entrypoint( + pub fn run_validate_entrypoint( &self, - state: &mut CachedState, + state: &mut CachedState, resources_manager: &mut ExecutionResourcesManager, block_context: &BlockContext, #[cfg(feature = "cairo-native")] program_cache: Option< @@ -369,9 +366,9 @@ impl Declare { self.sender_address = ?self.sender_address, self.nonce = ?self.nonce, ))] - pub fn execute( + pub fn execute( &self, - state: &mut CachedState, + state: &mut CachedState, block_context: &BlockContext, #[cfg(feature = "cairo-native")] program_cache: Option< Rc>>, @@ -453,13 +450,6 @@ impl Declare { #[cfg(test)] mod tests { use super::*; - use cairo_vm::{ - felt::{felt_str, Felt252}, - vm::runners::cairo_runner::ExecutionResources, - }; - use num_traits::{One, Zero}; - use std::{collections::HashMap, path::PathBuf, sync::Arc}; - use crate::{ definitions::{ block_context::{BlockContext, StarknetChainId}, @@ -470,12 +460,16 @@ mod tests { services::api::contract_classes::{ compiled_class::CompiledClass, deprecated_contract_class::ContractClass, }, - state::cached_state::CachedState, state::in_memory_state_reader::InMemoryStateReader, + state::{cached_state::CachedState, contract_class_cache::PermanentContractClassCache}, utils::{felt_to_hash, Address}, }; - - use super::Declare; + use cairo_vm::{ + felt::{felt_str, Felt252}, + vm::runners::cairo_runner::ExecutionResources, + }; + use num_traits::{One, Zero}; + use std::{collections::HashMap, path::PathBuf, sync::Arc}; #[test] fn declare_fibonacci() { @@ -484,13 +478,13 @@ mod tests { ContractClass::from_path("starknet_programs/account_without_validation.json").unwrap(); // Instantiate CachedState - let mut contract_class_cache = HashMap::new(); + let contract_class_cache = PermanentContractClassCache::default(); // ------------ contract data -------------------- let hash = compute_deprecated_class_hash(&contract_class).unwrap(); let class_hash = ClassHash::from(hash); - contract_class_cache.insert( + contract_class_cache.set_contract_class( class_hash, CompiledClass::Deprecated(Arc::new(contract_class.clone())), ); @@ -509,7 +503,7 @@ mod tests { .address_to_nonce_mut() .insert(sender_address, Felt252::new(1)); - let mut state = CachedState::new(Arc::new(state_reader), contract_class_cache); + let mut state = CachedState::new(Arc::new(state_reader), Arc::new(contract_class_cache)); //* --------------------------------------- //* Test declare with previous data @@ -602,13 +596,13 @@ mod tests { let contract_class = ContractClass::from_path(path).unwrap(); // Instantiate CachedState - let mut contract_class_cache = HashMap::new(); + let contract_class_cache = PermanentContractClassCache::default(); // ------------ contract data -------------------- let hash = compute_deprecated_class_hash(&contract_class).unwrap(); let class_hash = felt_to_hash(&hash); - contract_class_cache.insert( + contract_class_cache.set_contract_class( class_hash, CompiledClass::Deprecated(Arc::new(contract_class)), ); @@ -627,7 +621,7 @@ mod tests { .address_to_nonce_mut() .insert(sender_address, Felt252::zero()); - let mut state = CachedState::new(Arc::new(state_reader), contract_class_cache); + let mut state = CachedState::new(Arc::new(state_reader), Arc::new(contract_class_cache)); //* --------------------------------------- //* Test declare with previous data @@ -691,13 +685,13 @@ mod tests { let contract_class = ContractClass::from_path(path).unwrap(); // Instantiate CachedState - let mut contract_class_cache = HashMap::new(); + let contract_class_cache = PermanentContractClassCache::default(); // ------------ contract data -------------------- let hash = compute_deprecated_class_hash(&contract_class).unwrap(); let class_hash = felt_to_hash(&hash); - contract_class_cache.insert( + contract_class_cache.set_contract_class( class_hash, CompiledClass::Deprecated(Arc::new(contract_class)), ); @@ -716,7 +710,7 @@ mod tests { .address_to_nonce_mut() .insert(sender_address, Felt252::zero()); - let mut state = CachedState::new(Arc::new(state_reader), contract_class_cache); + let mut state = CachedState::new(Arc::new(state_reader), Arc::new(contract_class_cache)); //* --------------------------------------- //* Test declare with previous data @@ -769,11 +763,11 @@ mod tests { #[test] fn validate_transaction_should_fail() { // Instantiate CachedState - let contract_class_cache = HashMap::new(); + let contract_class_cache = PermanentContractClassCache::default(); let state_reader = Arc::new(InMemoryStateReader::default()); - let mut state = CachedState::new(state_reader, contract_class_cache); + let mut state = CachedState::new(state_reader, Arc::new(contract_class_cache)); // There are no account contracts in the state, so the transaction should fail let fib_contract_class = @@ -813,13 +807,13 @@ mod tests { let contract_class = ContractClass::from_path(path).unwrap(); // Instantiate CachedState - let mut contract_class_cache = HashMap::new(); + let contract_class_cache = PermanentContractClassCache::default(); // ------------ contract data -------------------- let hash = compute_deprecated_class_hash(&contract_class).unwrap(); let class_hash = felt_to_hash(&hash); - contract_class_cache.insert( + contract_class_cache.set_contract_class( class_hash, CompiledClass::Deprecated(Arc::new(contract_class)), ); @@ -838,7 +832,7 @@ mod tests { .address_to_nonce_mut() .insert(sender_address, Felt252::zero()); - let mut state = CachedState::new(Arc::new(state_reader), contract_class_cache); + let mut state = CachedState::new(Arc::new(state_reader), Arc::new(contract_class_cache)); //* --------------------------------------- //* Test declare with previous data @@ -879,13 +873,13 @@ mod tests { let contract_class = ContractClass::from_path("starknet_programs/Account.json").unwrap(); // Instantiate CachedState - let mut contract_class_cache = HashMap::new(); + let contract_class_cache = PermanentContractClassCache::default(); // ------------ contract data -------------------- let hash = compute_deprecated_class_hash(&contract_class).unwrap(); let class_hash = ClassHash::from(hash); - contract_class_cache.insert( + contract_class_cache.set_contract_class( class_hash, CompiledClass::Deprecated(Arc::new(contract_class)), ); @@ -904,7 +898,7 @@ mod tests { .address_to_nonce_mut() .insert(sender_address.clone(), Felt252::new(1)); - let mut state = CachedState::new(Arc::new(state_reader), contract_class_cache); + let mut state = CachedState::new(Arc::new(state_reader), Arc::new(contract_class_cache)); // Insert pubkey storage var to pass validation let storage_entry = &( sender_address, @@ -958,7 +952,7 @@ mod tests { // --------------------- // Comparison // --------------------- - let mut state_copy = state.clone(); + let mut state_copy = state.clone_for_testing(); let mut bock_context = BlockContext::default(); bock_context.starknet_os_config.gas_price = 12; assert!( @@ -1002,8 +996,8 @@ mod tests { Felt252::zero(), ) .unwrap(); - let result = internal_declare.execute::>( - &mut CachedState::default(), + let result = internal_declare.execute( + &mut CachedState::::default(), &BlockContext::default(), #[cfg(feature = "cairo-native")] None, diff --git a/src/transaction/declare_v2.rs b/src/transaction/declare_v2.rs index 30d9b7c84..109eaed5b 100644 --- a/src/transaction/declare_v2.rs +++ b/src/transaction/declare_v2.rs @@ -10,6 +10,7 @@ use crate::services::api::contract_classes::deprecated_contract_class::EntryPoin use crate::services::api::contract_classes::compiled_class::CompiledClass; use crate::services::eth_definitions::eth_gas_constants::SHARP_GAS_PER_MEMORY_WORD; use crate::state::cached_state::CachedState; +use crate::state::contract_class_cache::ContractClassCache; use crate::state::state_api::StateChangesCount; use crate::utils::ClassHash; use crate::{ @@ -357,9 +358,9 @@ impl DeclareV2 { self.sender_address = ?self.sender_address, self.nonce = ?self.nonce, ))] - pub fn execute( + pub fn execute( &self, - state: &mut CachedState, + state: &mut CachedState, block_context: &BlockContext, #[cfg(feature = "cairo-native")] program_cache: Option< Rc>>, @@ -470,10 +471,10 @@ impl DeclareV2 { Ok(()) } - fn run_validate_entrypoint( + fn run_validate_entrypoint( &self, mut remaining_gas: u128, - state: &mut CachedState, + state: &mut CachedState, resources_manager: &mut ExecutionResourcesManager, block_context: &BlockContext, #[cfg(feature = "cairo-native")] program_cache: Option< @@ -568,9 +569,6 @@ impl DeclareV2 { #[cfg(test)] mod tests { - use std::sync::Arc; - use std::{collections::HashMap, fs::File, io::BufReader, path::PathBuf}; - use super::DeclareV2; use crate::core::contract_address::{compute_casm_class_hash, compute_sierra_class_hash}; use crate::definitions::block_context::{BlockContext, StarknetChainId}; @@ -580,12 +578,16 @@ mod tests { use crate::transaction::error::TransactionError; use crate::utils::ClassHash; use crate::{ - state::cached_state::CachedState, state::in_memory_state_reader::InMemoryStateReader, + state::{ + cached_state::CachedState, contract_class_cache::PermanentContractClassCache, + in_memory_state_reader::InMemoryStateReader, + }, utils::Address, }; use cairo_lang_starknet::casm_contract_class::CasmContractClass; use cairo_vm::felt::Felt252; use num_traits::{One, Zero}; + use std::{fs::File, io::BufReader, path::PathBuf, sync::Arc}; #[test] fn create_declare_v2_without_casm_contract_class_test() { @@ -629,9 +631,9 @@ mod tests { .unwrap(); // crate state to store casm contract class - let casm_contract_class_cache = HashMap::new(); + let casm_contract_class_cache = PermanentContractClassCache::default(); let state_reader = Arc::new(InMemoryStateReader::default()); - let mut state = CachedState::new(state_reader, casm_contract_class_cache); + let mut state = CachedState::new(state_reader, Arc::new(casm_contract_class_cache)); // call compile and store assert!(internal_declare @@ -699,9 +701,9 @@ mod tests { .unwrap(); // crate state to store casm contract class - let casm_contract_class_cache = HashMap::new(); + let casm_contract_class_cache = PermanentContractClassCache::default(); let state_reader = Arc::new(InMemoryStateReader::default()); - let mut state = CachedState::new(state_reader, casm_contract_class_cache); + let mut state = CachedState::new(state_reader, Arc::new(casm_contract_class_cache)); // call compile and store assert!(internal_declare @@ -771,9 +773,9 @@ mod tests { .unwrap(); // crate state to store casm contract class - let casm_contract_class_cache = HashMap::new(); + let casm_contract_class_cache = PermanentContractClassCache::default(); let state_reader = Arc::new(InMemoryStateReader::default()); - let mut state = CachedState::new(state_reader, casm_contract_class_cache); + let mut state = CachedState::new(state_reader, Arc::new(casm_contract_class_cache)); // call compile and store assert!(internal_declare @@ -841,9 +843,9 @@ mod tests { .unwrap(); // crate state to store casm contract class - let casm_contract_class_cache = HashMap::new(); + let casm_contract_class_cache = PermanentContractClassCache::default(); let state_reader = Arc::new(InMemoryStateReader::default()); - let mut state = CachedState::new(state_reader, casm_contract_class_cache); + let mut state = CachedState::new(state_reader, Arc::new(casm_contract_class_cache)); // call compile and store assert!(internal_declare @@ -912,9 +914,9 @@ mod tests { .unwrap(); // crate state to store casm contract class - let casm_contract_class_cache = HashMap::new(); + let casm_contract_class_cache = PermanentContractClassCache::default(); let state_reader = Arc::new(InMemoryStateReader::default()); - let mut state = CachedState::new(state_reader, casm_contract_class_cache); + let mut state = CachedState::new(state_reader, Arc::new(casm_contract_class_cache)); let expected_err = format!( "Invalid compiled class, expected class hash: {}, but received: {}", @@ -962,8 +964,8 @@ mod tests { Felt252::zero(), ) .unwrap(); - let result = internal_declare.execute::>( - &mut CachedState::default(), + let result = internal_declare.execute( + &mut CachedState::::default(), &BlockContext::default(), #[cfg(feature = "cairo-native")] None, diff --git a/src/transaction/deploy.rs b/src/transaction/deploy.rs index 846f043fd..eb01ccd42 100644 --- a/src/transaction/deploy.rs +++ b/src/transaction/deploy.rs @@ -1,11 +1,4 @@ -use std::sync::Arc; - -use crate::execution::execution_entry_point::ExecutionResult; -use crate::services::api::contract_classes::deprecated_contract_class::{ - ContractClass, EntryPointType, -}; -use crate::state::cached_state::CachedState; -use crate::syscalls::syscall_handler_errors::SyscallHandlerError; +use super::Transaction; use crate::{ core::{ contract_address::compute_deprecated_class_hash, errors::hash_errors::HashError, @@ -16,22 +9,31 @@ use crate::{ transaction_type::TransactionType, }, execution::{ - execution_entry_point::ExecutionEntryPoint, CallInfo, TransactionExecutionContext, - TransactionExecutionInfo, + execution_entry_point::{ExecutionEntryPoint, ExecutionResult}, + CallInfo, TransactionExecutionContext, TransactionExecutionInfo, }, hash_utils::calculate_contract_address, services::api::{ - contract_class_errors::ContractClassError, contract_classes::compiled_class::CompiledClass, + contract_class_errors::ContractClassError, + contract_classes::{ + compiled_class::CompiledClass, + deprecated_contract_class::{ContractClass, EntryPointType}, + }, }, - state::state_api::{State, StateReader}, - state::ExecutionResourcesManager, + state::{ + cached_state::CachedState, + contract_class_cache::ContractClassCache, + state_api::{State, StateReader}, + ExecutionResourcesManager, + }, + syscalls::syscall_handler_errors::SyscallHandlerError, transaction::error::TransactionError, utils::{calculate_tx_resources, felt_to_hash, Address, ClassHash}, }; use cairo_vm::felt::Felt252; use num_traits::Zero; +use std::sync::Arc; -use super::Transaction; use std::fmt::Debug; #[cfg(feature = "cairo-native")] @@ -151,9 +153,9 @@ impl Deploy { /// ## Parameters /// - state: A state that implements the [`State`] and [`StateReader`] traits. /// - block_context: The block's execution context. - pub fn apply( + pub fn apply( &self, - state: &mut CachedState, + state: &mut CachedState, block_context: &BlockContext, #[cfg(feature = "cairo-native")] program_cache: Option< Rc>>, @@ -224,9 +226,9 @@ impl Deploy { /// ## Parameters /// - state: A state that implements the [`State`] and [`StateReader`] traits. /// - block_context: The block's execution context. - pub fn invoke_constructor( + pub fn invoke_constructor( &self, - state: &mut CachedState, + state: &mut CachedState, block_context: &BlockContext, #[cfg(feature = "cairo-native")] program_cache: Option< Rc>>, @@ -301,9 +303,9 @@ impl Deploy { self.contract_address = ?self.contract_address, self.contract_address_salt = ?self.contract_address_salt, ))] - pub fn execute( + pub fn execute( &self, - state: &mut CachedState, + state: &mut CachedState, block_context: &BlockContext, #[cfg(feature = "cairo-native")] program_cache: Option< Rc>>, @@ -345,19 +347,24 @@ impl Deploy { #[cfg(test)] mod tests { - use std::{collections::HashMap, sync::Arc}; - use super::*; use crate::{ - state::cached_state::CachedState, state::in_memory_state_reader::InMemoryStateReader, + state::{ + cached_state::CachedState, contract_class_cache::PermanentContractClassCache, + in_memory_state_reader::InMemoryStateReader, + }, utils::calculate_sn_keccak, }; + use std::{collections::HashMap, sync::Arc}; #[test] fn invoke_constructor_test() { // Instantiate CachedState let state_reader = Arc::new(InMemoryStateReader::default()); - let mut state = CachedState::new(state_reader, HashMap::new()); + let mut state = CachedState::new( + state_reader, + Arc::new(PermanentContractClassCache::default()), + ); // Set contract_class let contract_class = @@ -411,7 +418,10 @@ mod tests { fn invoke_constructor_no_calldata_should_fail() { // Instantiate CachedState let state_reader = Arc::new(InMemoryStateReader::default()); - let mut state = CachedState::new(state_reader, HashMap::new()); + let mut state = CachedState::new( + state_reader, + Arc::new(PermanentContractClassCache::default()), + ); let contract_class = ContractClass::from_path("starknet_programs/constructor.json").unwrap(); @@ -444,7 +454,10 @@ mod tests { fn deploy_contract_without_constructor_should_fail() { // Instantiate CachedState let state_reader = Arc::new(InMemoryStateReader::default()); - let mut state = CachedState::new(state_reader, HashMap::new()); + let mut state = CachedState::new( + state_reader, + Arc::new(PermanentContractClassCache::default()), + ); let contract_path = "starknet_programs/amm.json"; let contract_class = ContractClass::from_path(contract_path).unwrap(); diff --git a/src/transaction/deploy_account.rs b/src/transaction/deploy_account.rs index d9406a983..6ba421865 100644 --- a/src/transaction/deploy_account.rs +++ b/src/transaction/deploy_account.rs @@ -31,8 +31,11 @@ use crate::{ services::api::{ contract_class_errors::ContractClassError, contract_classes::compiled_class::CompiledClass, }, - state::state_api::{State, StateReader}, - state::ExecutionResourcesManager, + state::{ + contract_class_cache::ContractClassCache, + state_api::{State, StateReader}, + ExecutionResourcesManager, + }, syscalls::syscall_handler_errors::SyscallHandlerError, transaction::error::TransactionError, utils::{calculate_tx_resources, Address, ClassHash}, @@ -178,9 +181,9 @@ impl DeployAccount { self.contract_address_salt = ?self.contract_address_salt, self.nonce = ?self.nonce, ))] - pub fn execute( + pub fn execute( &self, - state: &mut CachedState, + state: &mut CachedState, block_context: &BlockContext, #[cfg(feature = "cairo-native")] program_cache: Option< Rc>>, @@ -200,7 +203,7 @@ impl DeployAccount { self.handle_nonce(state)?; - let mut transactional_state = state.create_transactional(); + let mut transactional_state = state.create_transactional()?; let mut tx_exec_info = self.apply( &mut transactional_state, block_context, @@ -266,9 +269,9 @@ impl DeployAccount { /// Execute a call to the cairo-vm using the accounts_validation.cairo contract to validate /// the contract that is being declared. Then it returns the transaction execution info of the run. - fn apply( + fn apply( &self, - state: &mut CachedState, + state: &mut CachedState, block_context: &BlockContext, #[cfg(feature = "cairo-native")] program_cache: Option< Rc>>, @@ -322,10 +325,10 @@ impl DeployAccount { )) } - pub fn handle_constructor( + pub fn handle_constructor( &self, contract_class: CompiledClass, - state: &mut CachedState, + state: &mut CachedState, block_context: &BlockContext, resources_manager: &mut ExecutionResourcesManager, #[cfg(feature = "cairo-native")] program_cache: Option< @@ -419,9 +422,9 @@ impl DeployAccount { ) } - pub fn run_constructor_entrypoint( + pub fn run_constructor_entrypoint( &self, - state: &mut CachedState, + state: &mut CachedState, block_context: &BlockContext, resources_manager: &mut ExecutionResourcesManager, #[cfg(feature = "cairo-native")] program_cache: Option< @@ -471,9 +474,9 @@ impl DeployAccount { ) } - pub fn run_validate_entrypoint( + pub fn run_validate_entrypoint( &self, - state: &mut CachedState, + state: &mut CachedState, resources_manager: &mut ExecutionResourcesManager, block_context: &BlockContext, #[cfg(feature = "cairo-native")] program_cache: Option< @@ -643,17 +646,16 @@ impl TryFrom for DeployAcco #[cfg(test)] mod tests { - use std::{collections::HashMap, path::PathBuf, sync::Arc}; - use super::*; use crate::{ core::{contract_address::compute_deprecated_class_hash, errors::state_errors::StateError}, definitions::block_context::StarknetChainId, services::api::contract_classes::deprecated_contract_class::ContractClass, - state::cached_state::CachedState, state::in_memory_state_reader::InMemoryStateReader, + state::{cached_state::CachedState, contract_class_cache::PermanentContractClassCache}, utils::felt_to_hash, }; + use std::{path::PathBuf, sync::Arc}; #[test] fn get_state_selector() { @@ -664,7 +666,10 @@ mod tests { let class_hash = felt_to_hash(&hash); let block_context = BlockContext::default(); - let mut _state = CachedState::new(Arc::new(InMemoryStateReader::default()), HashMap::new()); + let mut _state = CachedState::new( + Arc::new(InMemoryStateReader::default()), + Arc::new(PermanentContractClassCache::default()), + ); let internal_deploy = DeployAccount::new( class_hash, @@ -696,7 +701,10 @@ mod tests { let class_hash = felt_to_hash(&hash); let block_context = BlockContext::default(); - let mut state = CachedState::new(Arc::new(InMemoryStateReader::default()), HashMap::new()); + let mut state = CachedState::new( + Arc::new(InMemoryStateReader::default()), + Arc::new(PermanentContractClassCache::default()), + ); let internal_deploy = DeployAccount::new( class_hash, @@ -758,7 +766,10 @@ mod tests { let class_hash = felt_to_hash(&hash); let block_context = BlockContext::default(); - let mut state = CachedState::new(Arc::new(InMemoryStateReader::default()), HashMap::new()); + let mut state = CachedState::new( + Arc::new(InMemoryStateReader::default()), + Arc::new(PermanentContractClassCache::default()), + ); let internal_deploy = DeployAccount::new( class_hash, @@ -802,8 +813,8 @@ mod tests { chain_id, ) .unwrap(); - let result = internal_declare.execute::>( - &mut CachedState::default(), + let result = internal_declare.execute( + &mut CachedState::::default(), &BlockContext::default(), #[cfg(feature = "cairo-native")] None, diff --git a/src/transaction/fee.rs b/src/transaction/fee.rs index 8b2c44bec..b8d68b005 100644 --- a/src/transaction/fee.rs +++ b/src/transaction/fee.rs @@ -1,19 +1,18 @@ use super::error::TransactionError; -use crate::definitions::constants::FEE_FACTOR; -use crate::execution::execution_entry_point::ExecutionResult; -use crate::execution::CallType; -use crate::services::api::contract_classes::deprecated_contract_class::EntryPointType; -use crate::state::cached_state::CachedState; use crate::{ definitions::{ block_context::BlockContext, - constants::{INITIAL_GAS_COST, TRANSFER_ENTRY_POINT_SELECTOR}, + constants::{FEE_FACTOR, INITIAL_GAS_COST, TRANSFER_ENTRY_POINT_SELECTOR}, }, execution::{ - execution_entry_point::ExecutionEntryPoint, CallInfo, TransactionExecutionContext, + execution_entry_point::{ExecutionEntryPoint, ExecutionResult}, + CallInfo, CallType, TransactionExecutionContext, + }, + services::api::contract_classes::deprecated_contract_class::EntryPointType, + state::{ + cached_state::CachedState, contract_class_cache::ContractClassCache, + state_api::StateReader, ExecutionResourcesManager, }, - state::state_api::StateReader, - state::ExecutionResourcesManager, }; use cairo_vm::felt::Felt252; use num_traits::{ToPrimitive, Zero}; @@ -31,8 +30,8 @@ pub type FeeInfo = (Option, u128); /// Transfers the amount actual_fee from the caller account to the sequencer. /// Returns the resulting CallInfo of the transfer call. -pub(crate) fn execute_fee_transfer( - state: &mut CachedState, +pub(crate) fn execute_fee_transfer( + state: &mut CachedState, block_context: &BlockContext, tx_execution_context: &mut TransactionExecutionContext, actual_fee: u128, @@ -149,8 +148,8 @@ fn max_of_keys(cairo_rsc: &HashMap, weights: &HashMap( - state: &mut CachedState, +pub fn charge_fee( + state: &mut CachedState, resources: &HashMap, block_context: &BlockContext, max_fee: u128, @@ -201,17 +200,16 @@ pub fn charge_fee( #[cfg(test)] mod tests { - use std::{collections::HashMap, sync::Arc}; - use crate::{ definitions::block_context::BlockContext, execution::TransactionExecutionContext, state::{ - cached_state::{CachedState, ContractClassCache}, + cached_state::CachedState, contract_class_cache::PermanentContractClassCache, in_memory_state_reader::InMemoryStateReader, }, transaction::fee::charge_fee, }; + use std::{collections::HashMap, sync::Arc}; /// Tests the behavior of the charge_fee function when the actual fee exceeds the maximum fee /// for version 0. It expects to return an ActualFeeExceedsMaxFee error. @@ -219,7 +217,7 @@ mod tests { fn charge_fee_v0_max_fee_exceeded_should_charge_nothing() { let mut state = CachedState::new( Arc::new(InMemoryStateReader::default()), - ContractClassCache::default(), + Arc::new(PermanentContractClassCache::default()), ); let mut tx_execution_context = TransactionExecutionContext::default(); let mut block_context = BlockContext::default(); @@ -252,7 +250,7 @@ mod tests { fn charge_fee_v1_max_fee_exceeded_should_charge_max_fee() { let mut state = CachedState::new( Arc::new(InMemoryStateReader::default()), - ContractClassCache::default(), + Arc::new(PermanentContractClassCache::default()), ); let mut tx_execution_context = TransactionExecutionContext { version: 1.into(), diff --git a/src/transaction/invoke_function.rs b/src/transaction/invoke_function.rs index da289f41f..04a71f7f2 100644 --- a/src/transaction/invoke_function.rs +++ b/src/transaction/invoke_function.rs @@ -25,6 +25,7 @@ use crate::{ }, state::{ cached_state::CachedState, + contract_class_cache::ContractClassCache, state_api::{State, StateChangesCount, StateReader}, ExecutionResourcesManager, StateDiff, }, @@ -176,9 +177,9 @@ impl InvokeFunction { /// - state: A state that implements the [`State`] and [`StateReader`] traits. /// - resources_manager: the resources that are in use by the contract /// - block_context: The block's execution context - pub(crate) fn run_validate_entrypoint( + pub(crate) fn run_validate_entrypoint( &self, - state: &mut CachedState, + state: &mut CachedState, resources_manager: &mut ExecutionResourcesManager, block_context: &BlockContext, #[cfg(feature = "cairo-native")] program_cache: Option< @@ -239,9 +240,9 @@ impl InvokeFunction { /// Builds the transaction execution context and executes the entry point. /// Returns the CallInfo. - fn run_execute_entrypoint( + fn run_execute_entrypoint( &self, - state: &mut CachedState, + state: &mut CachedState, block_context: &BlockContext, resources_manager: &mut ExecutionResourcesManager, remaining_gas: u128, @@ -277,9 +278,9 @@ impl InvokeFunction { /// - state: A state that implements the [`State`] and [`StateReader`] traits. /// - block_context: The block's execution context. /// - remaining_gas: The amount of gas that the transaction disposes. - pub fn apply( + pub fn apply( &self, - state: &mut CachedState, + state: &mut CachedState, block_context: &BlockContext, remaining_gas: u128, #[cfg(feature = "cairo-native")] program_cache: Option< @@ -352,9 +353,9 @@ impl InvokeFunction { self.entry_point_selector = ?self.entry_point_selector, self.entry_point_type = ?self.entry_point_type, ))] - pub fn execute( + pub fn execute( &self, - state: &mut CachedState, + state: &mut CachedState, block_context: &BlockContext, remaining_gas: u128, #[cfg(feature = "cairo-native")] program_cache: Option< @@ -375,7 +376,7 @@ impl InvokeFunction { self.handle_nonce(state)?; - let mut transactional_state = state.create_transactional(); + let mut transactional_state = state.create_transactional()?; let mut tx_exec_info = self.apply( &mut transactional_state, block_context, @@ -654,11 +655,15 @@ fn convert_invoke_v1( mod tests { use super::*; use crate::{ + definitions::constants::QUERY_VERSION_1, services::api::contract_classes::{ compiled_class::CompiledClass, deprecated_contract_class::ContractClass, }, state::cached_state::CachedState, - state::in_memory_state_reader::InMemoryStateReader, + state::{ + contract_class_cache::PermanentContractClassCache, + in_memory_state_reader::InMemoryStateReader, + }, utils::{calculate_sn_keccak, ClassHash}, }; use cairo_lang_starknet::casm_contract_class::CasmContractClass; @@ -669,7 +674,7 @@ mod tests { hash::{StarkFelt, StarkHash}, transaction::{Fee, InvokeTransaction, InvokeTransactionV1, TransactionSignature}, }; - use std::{collections::HashMap, sync::Arc}; + use std::sync::Arc; #[test] fn test_from_invoke_transaction() { @@ -766,7 +771,7 @@ mod tests { // Set contract_class let class_hash = ClassHash([1; 32]); let contract_class = ContractClass::from_path("starknet_programs/fibonacci.json").unwrap(); - // Set contact_state + // Set contract_state let contract_address = Address(0.into()); let nonce = Felt252::zero(); @@ -777,10 +782,10 @@ mod tests { .address_to_nonce .insert(contract_address, nonce); - let mut state = CachedState::new(Arc::new(state_reader), HashMap::new()); - - // Initialize state.contract_classes - state.set_contract_classes(HashMap::new()).unwrap(); + let mut state = CachedState::new( + Arc::new(state_reader), + Arc::new(PermanentContractClassCache::default()), + ); state .set_contract_class( @@ -789,7 +794,7 @@ mod tests { ) .unwrap(); - let mut transactional = state.create_transactional(); + let mut transactional = state.create_transactional().unwrap(); // Invoke result let result = internal_invoke_function .apply( @@ -849,7 +854,7 @@ mod tests { // Set contract_class let class_hash: ClassHash = ClassHash([1; 32]); let contract_class = ContractClass::from_path("starknet_programs/fibonacci.json").unwrap(); - // Set contact_state + // Set contract_state let contract_address = Address(0.into()); let nonce = Felt252::zero(); @@ -860,10 +865,10 @@ mod tests { .address_to_nonce .insert(contract_address, nonce); - let mut state = CachedState::new(Arc::new(state_reader), HashMap::new()); - - // Initialize state.contract_classes - state.set_contract_classes(HashMap::new()).unwrap(); + let mut state = CachedState::new( + Arc::new(state_reader), + Arc::new(PermanentContractClassCache::default()), + ); state .set_contract_class( @@ -923,7 +928,7 @@ mod tests { // Set contract_class let class_hash: ClassHash = ClassHash([1; 32]); let contract_class = ContractClass::from_path("starknet_programs/amm.json").unwrap(); - // Set contact_state + // Set contract_state let contract_address = Address(0.into()); let nonce = Felt252::zero(); @@ -934,10 +939,10 @@ mod tests { .address_to_nonce .insert(contract_address, nonce); - let mut state = CachedState::new(Arc::new(state_reader), HashMap::new()); - - // Initialize state.contract_classes - state.set_contract_classes(HashMap::new()).unwrap(); + let mut state = CachedState::new( + Arc::new(state_reader), + Arc::new(PermanentContractClassCache::default()), + ); state .set_contract_class( @@ -946,7 +951,7 @@ mod tests { ) .unwrap(); - let mut transactional = state.create_transactional(); + let mut transactional = state.create_transactional().unwrap(); let expected_error = internal_invoke_function.apply( &mut transactional, &BlockContext::default(), @@ -991,7 +996,7 @@ mod tests { // Set contract_class let class_hash: ClassHash = ClassHash([1; 32]); let contract_class = ContractClass::from_path("starknet_programs/fibonacci.json").unwrap(); - // Set contact_state + // Set contract_state let contract_address = Address(0.into()); let nonce = Felt252::zero(); @@ -1002,10 +1007,10 @@ mod tests { .address_to_nonce .insert(contract_address, nonce); - let mut state = CachedState::new(Arc::new(state_reader), HashMap::new()); - - // Initialize state.contract_classes - state.set_contract_classes(HashMap::new()).unwrap(); + let mut state = CachedState::new( + Arc::new(state_reader), + Arc::new(PermanentContractClassCache::default()), + ); state .set_contract_class( @@ -1014,7 +1019,7 @@ mod tests { ) .unwrap(); - let mut transactional = state.create_transactional(); + let mut transactional = state.create_transactional().unwrap(); // Invoke result let result = internal_invoke_function .apply( @@ -1070,7 +1075,7 @@ mod tests { // Set contract_class let class_hash: ClassHash = ClassHash([1; 32]); let contract_class = ContractClass::from_path("starknet_programs/amm.json").unwrap(); - // Set contact_state + // Set contract_state let contract_address = Address(0.into()); let nonce = Felt252::zero(); @@ -1081,10 +1086,10 @@ mod tests { .address_to_nonce .insert(contract_address, nonce); - let mut state = CachedState::new(Arc::new(state_reader), HashMap::new()); - - // Initialize state.contract_classes - state.set_contract_classes(HashMap::new()).unwrap(); + let mut state = CachedState::new( + Arc::new(state_reader), + Arc::new(PermanentContractClassCache::default()), + ); state .set_contract_class( @@ -1093,7 +1098,7 @@ mod tests { ) .unwrap(); - let mut transactional = state.create_transactional(); + let mut transactional = state.create_transactional().unwrap(); // Invoke result let expected_error = internal_invoke_function.apply( &mut transactional, @@ -1118,7 +1123,7 @@ mod tests { // Set contract_class let class_hash: ClassHash = ClassHash([1; 32]); let contract_class = ContractClass::from_path("starknet_programs/fibonacci.json").unwrap(); - // Set contact_state + // Set contract_state let nonce = Felt252::zero(); state_reader @@ -1150,10 +1155,10 @@ mod tests { skip_nonce_check: false, }; - let mut state = CachedState::new(Arc::new(state_reader), HashMap::new()); - - // Initialize state.contract_classes - state.set_contract_classes(HashMap::new()).unwrap(); + let mut state = CachedState::new( + Arc::new(state_reader), + Arc::new(PermanentContractClassCache::default()), + ); state .set_contract_class( @@ -1208,7 +1213,7 @@ mod tests { // Set contract_class let class_hash: ClassHash = ClassHash([1; 32]); let contract_class = ContractClass::from_path("starknet_programs/fibonacci.json").unwrap(); - // Set contact_state + // Set contract_state let contract_address = Address(0.into()); let nonce = Felt252::zero(); @@ -1219,10 +1224,10 @@ mod tests { .address_to_nonce .insert(contract_address, nonce); - let mut state = CachedState::new(Arc::new(state_reader), HashMap::new()); - - // Initialize state.contract_classes - state.set_contract_classes(HashMap::new()).unwrap(); + let mut state = CachedState::new( + Arc::new(state_reader), + Arc::new(PermanentContractClassCache::default()), + ); state .set_contract_class( @@ -1284,7 +1289,7 @@ mod tests { // Set contract_class let class_hash: ClassHash = ClassHash([1; 32]); let contract_class = ContractClass::from_path("starknet_programs/fibonacci.json").unwrap(); - // Set contact_state + // Set contract_state let contract_address = Address(0.into()); let nonce = Felt252::zero(); @@ -1295,10 +1300,10 @@ mod tests { .address_to_nonce .insert(contract_address, nonce); - let mut state = CachedState::new(Arc::new(state_reader), HashMap::new()); - - // Initialize state.contract_classes - state.set_contract_classes(HashMap::new()).unwrap(); + let mut state = CachedState::new( + Arc::new(state_reader), + Arc::new(PermanentContractClassCache::default()), + ); state .set_contract_class( @@ -1361,7 +1366,7 @@ mod tests { // Set contract_class let class_hash: ClassHash = ClassHash([1; 32]); let contract_class = ContractClass::from_path("starknet_programs/fibonacci.json").unwrap(); - // Set contact_state + // Set contract_state let contract_address = Address(0.into()); let nonce = Felt252::zero(); @@ -1372,10 +1377,10 @@ mod tests { .address_to_nonce .insert(contract_address, nonce); - let mut state = CachedState::new(Arc::new(state_reader), HashMap::new()); - - // Initialize state.contract_classes - state.set_contract_classes(HashMap::new()).unwrap(); + let mut state = CachedState::new( + Arc::new(state_reader), + Arc::new(PermanentContractClassCache::default()), + ); state .set_contract_class( @@ -1480,7 +1485,7 @@ mod tests { ) .unwrap(), None, - Into::::into(1), + QUERY_VERSION_1.clone(), ); assert!(expected_error.is_err()); } @@ -1523,13 +1528,15 @@ mod tests { .insert(class_hash, class_hash); // last is necessary so the transactional state can cache the class - let mut casm_contract_class_cache = HashMap::new(); + let casm_contract_class_cache = PermanentContractClassCache::default(); - casm_contract_class_cache.insert(class_hash, CompiledClass::Casm(Arc::new(contract_class))); + casm_contract_class_cache + .set_contract_class(class_hash, CompiledClass::Casm(Arc::new(contract_class))); - let mut state = CachedState::new(Arc::new(state_reader), casm_contract_class_cache); + let mut state = + CachedState::new(Arc::new(state_reader), Arc::new(casm_contract_class_cache)); - let state_before_execution = state.clone(); + let state_before_execution = state.clone_for_testing(); let result = internal_invoke_function .execute( @@ -1586,8 +1593,8 @@ mod tests { Some(Felt252::zero()), ) .unwrap(); - let result = internal_declare.execute::>( - &mut CachedState::default(), + let result = internal_declare.execute( + &mut CachedState::::default(), &BlockContext::default(), u128::MAX, #[cfg(feature = "cairo-native")] diff --git a/src/transaction/l1_handler.rs b/src/transaction/l1_handler.rs index d582a6570..a082c2a85 100644 --- a/src/transaction/l1_handler.rs +++ b/src/transaction/l1_handler.rs @@ -12,6 +12,7 @@ use crate::{ services::api::contract_classes::deprecated_contract_class::EntryPointType, state::{ cached_state::CachedState, + contract_class_cache::ContractClassCache, state_api::{State, StateReader}, ExecutionResourcesManager, }, @@ -109,9 +110,9 @@ impl L1Handler { self.entry_point_selector = ?self.entry_point_selector, self.nonce = ?self.nonce, ))] - pub fn execute( + pub fn execute( &self, - state: &mut CachedState, + state: &mut CachedState, block_context: &BlockContext, remaining_gas: u128, #[cfg(feature = "cairo-native")] program_cache: Option< @@ -246,35 +247,31 @@ impl L1Handler { #[cfg(test)] mod test { + use crate::{ + definitions::{block_context::BlockContext, transaction_type::TransactionType}, + execution::{CallInfo, TransactionExecutionInfo}, + services::api::contract_classes::{ + compiled_class::CompiledClass, + deprecated_contract_class::{ContractClass, EntryPointType}, + }, + state::{ + cached_state::CachedState, contract_class_cache::PermanentContractClassCache, + in_memory_state_reader::InMemoryStateReader, state_api::State, + }, + transaction::l1_handler::L1Handler, + utils::{Address, ClassHash}, + }; use std::{ collections::{HashMap, HashSet}, sync::Arc, }; - use crate::{ - services::api::contract_classes::{ - compiled_class::CompiledClass, deprecated_contract_class::EntryPointType, - }, - utils::ClassHash, - }; use cairo_vm::{ felt::{felt_str, Felt252}, vm::runners::cairo_runner::ExecutionResources, }; use num_traits::{Num, Zero}; - use crate::{ - definitions::{block_context::BlockContext, transaction_type::TransactionType}, - execution::{CallInfo, TransactionExecutionInfo}, - services::api::contract_classes::deprecated_contract_class::ContractClass, - state::{ - cached_state::CachedState, in_memory_state_reader::InMemoryStateReader, - state_api::State, - }, - transaction::l1_handler::L1Handler, - utils::Address, - }; - /// Test the correct execution of the L1Handler. #[test] fn test_execute_l1_handler() { @@ -301,7 +298,7 @@ mod test { // Set contract_class let class_hash: ClassHash = ClassHash([1; 32]); let contract_class = ContractClass::from_path("starknet_programs/l1l2.json").unwrap(); - // Set contact_state + // Set contract_state let contract_address = Address(0.into()); let nonce = Felt252::zero(); @@ -312,10 +309,10 @@ mod test { .address_to_nonce .insert(contract_address, nonce); - let mut state = CachedState::new(Arc::new(state_reader), HashMap::new()); - - // Initialize state.contract_classes - state.set_contract_classes(HashMap::new()).unwrap(); + let mut state = CachedState::new( + Arc::new(state_reader), + Arc::new(PermanentContractClassCache::default()), + ); state .set_contract_class( diff --git a/src/transaction/mod.rs b/src/transaction/mod.rs index 4f8d212a2..81a4634ed 100644 --- a/src/transaction/mod.rs +++ b/src/transaction/mod.rs @@ -1,3 +1,20 @@ +use crate::{ + definitions::block_context::BlockContext, + definitions::constants::{QUERY_VERSION_0, QUERY_VERSION_1, QUERY_VERSION_2}, + execution::TransactionExecutionInfo, + state::{ + cached_state::CachedState, contract_class_cache::ContractClassCache, state_api::StateReader, + }, + utils::Address, +}; +pub use declare::Declare; +pub use declare_v2::DeclareV2; +pub use deploy::Deploy; +pub use deploy_account::DeployAccount; +use error::TransactionError; +pub use invoke_function::InvokeFunction; +pub use l1_handler::L1Handler; + pub mod declare; pub mod declare_v2; pub mod deploy; @@ -8,25 +25,8 @@ pub mod invoke_function; pub mod l1_handler; use cairo_vm::felt::Felt252; -pub use declare::Declare; -pub use declare_v2::DeclareV2; -pub use deploy::Deploy; -pub use deploy_account::DeployAccount; -pub use invoke_function::InvokeFunction; -pub use l1_handler::L1Handler; use num_traits::{One, Zero}; -use crate::{ - definitions::{ - block_context::BlockContext, - constants::{QUERY_VERSION_0, QUERY_VERSION_1, QUERY_VERSION_2}, - }, - execution::TransactionExecutionInfo, - state::{cached_state::CachedState, state_api::StateReader}, - utils::Address, -}; -use error::TransactionError; - #[cfg(feature = "cairo-native")] use { crate::utils::ClassHash, @@ -76,9 +76,9 @@ impl Transaction { ///- state: a structure that implements State and StateReader traits. ///- block_context: The block context of the transaction that is about to be executed. ///- remaining_gas: The gas supplied to execute the transaction. - pub fn execute( + pub fn execute( &self, - state: &mut CachedState, + state: &mut CachedState, block_context: &BlockContext, remaining_gas: u128, #[cfg(feature = "cairo-native")] program_cache: Option< diff --git a/src/utils.rs b/src/utils.rs index b9f0e474d..0d2e6c0d7 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -527,8 +527,8 @@ pub mod test_utils { compiled_class::CompiledClass, deprecated_contract_class::ContractClass, }, state::{ - cached_state::CachedState, in_memory_state_reader::InMemoryStateReader, - state_cache::StorageEntry, BlockInfo, + cached_state::CachedState, contract_class_cache::PermanentContractClassCache, + in_memory_state_reader::InMemoryStateReader, state_cache::StorageEntry, BlockInfo, }, utils::Address, }; @@ -781,8 +781,13 @@ pub mod test_utils { ) } - pub(crate) fn create_account_tx_test_state( - ) -> Result<(BlockContext, CachedState), Box> { + pub(crate) fn create_account_tx_test_state() -> Result< + ( + BlockContext, + CachedState, + ), + Box, + > { let block_context = new_starknet_block_context_for_testing(); let test_contract_class_hash = felt_to_hash(&TEST_CLASS_HASH.clone()); @@ -859,7 +864,7 @@ pub mod test_utils { } Arc::new(state_reader) }, - HashMap::new(), + Arc::new(PermanentContractClassCache::default()), ); Ok((block_context, cached_state)) diff --git a/starknet_programs/syscalls.cairo b/starknet_programs/syscalls.cairo index 010f8374c..930092240 100644 --- a/starknet_programs/syscalls.cairo +++ b/starknet_programs/syscalls.cairo @@ -72,10 +72,10 @@ func test_call_contract{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_ch let (value) = lib_state.read(); assert value = 10; - let (call_contact_address) = ISyscallsLib.stateful_get_contract_address( + let (call_contract_address) = ISyscallsLib.stateful_get_contract_address( contract_address=contract_address ); - assert call_contact_address = contract_address; + assert call_contract_address = contract_address; return (); } @@ -181,11 +181,11 @@ func test_library_call{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_che let (value) = lib_state.read(); assert value = 11; - let self_contact_address = get_contract_address(); - let call_contact_address = ISyscallsLib.library_call_stateful_get_contract_address( + let self_contract_address = get_contract_address(); + let call_contract_address = ISyscallsLib.library_call_stateful_get_contract_address( class_hash=0x0202020202020202020202020202020202020202020202020202020202020202 ); - assert self_contact_address = call_contact_address; + assert self_contract_address = call_contract_address; return (); } @@ -255,7 +255,7 @@ func test_deploy_with_constructor{syscall_ptr: felt*}( // Set constructor. let (ptr) = alloc(); assert [ptr] = constructor; - + let contract_address = deploy( class_hash, contract_address_salt, @@ -276,7 +276,7 @@ func test_deploy_and_call_contract{syscall_ptr: felt*, pedersen_ptr: HashBuiltin // Set constructor. let (ptr) = alloc(); assert [ptr] = constructor; - + // Deploy contract let (contract_address) = deploy( class_hash, diff --git a/tests/account_panic.rs b/tests/account_panic.rs index 5ee3807e6..fac02121b 100644 --- a/tests/account_panic.rs +++ b/tests/account_panic.rs @@ -1,11 +1,15 @@ -use std::{collections::HashMap, sync::Arc}; +use std::sync::Arc; use cairo_vm::felt::Felt252; use starknet_in_rust::{ core::contract_address::compute_casm_class_hash, definitions::{block_context::BlockContext, constants::TRANSACTION_VERSION}, services::api::contract_classes::compiled_class::CompiledClass, - state::{cached_state::CachedState, in_memory_state_reader::InMemoryStateReader}, + state::{ + cached_state::CachedState, + contract_class_cache::{ContractClassCache, PermanentContractClassCache}, + in_memory_state_reader::InMemoryStateReader, + }, transaction::{InvokeFunction, Transaction}, utils::{calculate_sn_keccak, Address, ClassHash}, CasmContractClass, @@ -33,13 +37,13 @@ fn account_panic() { let block_context = BlockContext::default(); - let mut contract_class_cache = HashMap::new(); + let contract_class_cache = PermanentContractClassCache::default(); - contract_class_cache.insert( + contract_class_cache.set_contract_class( account_class_hash, CompiledClass::Casm(Arc::new(account_contract_class)), ); - contract_class_cache.insert( + contract_class_cache.set_contract_class( contract_class_hash, CompiledClass::Casm(Arc::new(contract_class.clone())), ); @@ -57,7 +61,7 @@ fn account_panic() { state_reader .address_to_nonce_mut() .insert(contract_address, 1.into()); - let mut state = CachedState::new(Arc::new(state_reader), contract_class_cache); + let mut state = CachedState::new(Arc::new(state_reader), Arc::new(contract_class_cache)); let selector = Felt252::from_bytes_be(&calculate_sn_keccak(b"__execute__")); diff --git a/tests/cairo_1_syscalls.rs b/tests/cairo_1_syscalls.rs index 67954b8d9..af3ffa823 100644 --- a/tests/cairo_1_syscalls.rs +++ b/tests/cairo_1_syscalls.rs @@ -1,8 +1,3 @@ -use std::{ - collections::{HashMap, HashSet}, - sync::Arc, -}; - use cairo_lang_starknet::casm_contract_class::CasmContractClass; use cairo_vm::{ felt::{felt_str, Felt252}, @@ -11,6 +6,7 @@ use cairo_vm::{ use num_bigint::BigUint; use num_traits::{Num, One, Zero}; use pretty_assertions_sorted::{assert_eq, assert_eq_sorted}; +use starknet_in_rust::utils::calculate_sn_keccak; use starknet_in_rust::{ definitions::{block_context::BlockContext, constants::TRANSACTION_VERSION}, execution::{ @@ -20,11 +16,20 @@ use starknet_in_rust::{ services::api::contract_classes::{ compiled_class::CompiledClass, deprecated_contract_class::ContractClass, }, - state::{cached_state::CachedState, state_api::StateReader}, - state::{in_memory_state_reader::InMemoryStateReader, ExecutionResourcesManager}, + state::{ + cached_state::CachedState, + contract_class_cache::{ContractClassCache, PermanentContractClassCache}, + in_memory_state_reader::InMemoryStateReader, + state_api::StateReader, + ExecutionResourcesManager, + }, utils::{Address, ClassHash}, + EntryPointType, +}; +use std::{ + collections::{HashMap, HashSet}, + sync::Arc, }; -use starknet_in_rust::{utils::calculate_sn_keccak, EntryPointType}; fn create_execute_extrypoint( address: Address, @@ -60,13 +65,14 @@ fn storage_write_read() { let increase_balance_entrypoint_selector = &entrypoints.external.get(0).unwrap().selector; // Create state reader with class hash data - let mut contract_class_cache = HashMap::new(); + let contract_class_cache = PermanentContractClassCache::default(); let address = Address(1111.into()); let class_hash: ClassHash = ClassHash([1; 32]); let nonce = Felt252::zero(); - contract_class_cache.insert(class_hash, CompiledClass::Casm(Arc::new(contract_class))); + contract_class_cache + .set_contract_class(class_hash, CompiledClass::Casm(Arc::new(contract_class))); let mut state_reader = InMemoryStateReader::default(); state_reader .address_to_class_hash_mut() @@ -76,7 +82,7 @@ fn storage_write_read() { .insert(address.clone(), nonce); // Create state from the state_reader and contract cache. - let mut state = CachedState::new(Arc::new(state_reader), contract_class_cache); + let mut state = CachedState::new(Arc::new(state_reader), Arc::new(contract_class_cache)); let block_context = BlockContext::default(); let mut tx_execution_context = TransactionExecutionContext::new( @@ -207,13 +213,14 @@ fn library_call() { let entrypoint_selector = &entrypoints.external.get(0).unwrap().selector; // Create state reader with class hash data - let mut contract_class_cache = HashMap::new(); + let contract_class_cache = PermanentContractClassCache::default(); let address = Address(1111.into()); let class_hash: ClassHash = ClassHash([1; 32]); let nonce = Felt252::zero(); - contract_class_cache.insert(class_hash, CompiledClass::Casm(Arc::new(contract_class))); + contract_class_cache + .set_contract_class(class_hash, CompiledClass::Casm(Arc::new(contract_class))); let mut state_reader = InMemoryStateReader::default(); state_reader .address_to_class_hash_mut() @@ -235,7 +242,7 @@ fn library_call() { let lib_class_hash: ClassHash = ClassHash([2; 32]); let lib_nonce = Felt252::zero(); - contract_class_cache.insert( + contract_class_cache.set_contract_class( lib_class_hash, CompiledClass::Casm(Arc::new(lib_contract_class)), ); @@ -247,7 +254,7 @@ fn library_call() { .insert(lib_address, lib_nonce); // Create state from the state_reader and contract cache. - let mut state = CachedState::new(Arc::new(state_reader), contract_class_cache); + let mut state = CachedState::new(Arc::new(state_reader), Arc::new(contract_class_cache)); // Create an execution entry point let calldata = [ @@ -374,13 +381,14 @@ fn call_contract_storage_write_read() { &BigUint::from_bytes_be(&calculate_sn_keccak("increase_balance".as_bytes())); // Create state reader with class hash data - let mut contract_class_cache = HashMap::new(); + let contract_class_cache = PermanentContractClassCache::default(); let address = Address(1111.into()); let class_hash: ClassHash = ClassHash([1; 32]); let nonce = Felt252::zero(); - contract_class_cache.insert(class_hash, CompiledClass::Casm(Arc::new(contract_class))); + contract_class_cache + .set_contract_class(class_hash, CompiledClass::Casm(Arc::new(contract_class))); let mut state_reader = InMemoryStateReader::default(); state_reader .address_to_class_hash_mut() @@ -411,7 +419,7 @@ fn call_contract_storage_write_read() { let simple_wallet_class_hash: ClassHash = ClassHash([2; 32]); let simple_wallet_nonce = Felt252::zero(); - contract_class_cache.insert( + contract_class_cache.set_contract_class( simple_wallet_class_hash, CompiledClass::Casm(Arc::new(simple_wallet_contract_class)), ); @@ -423,7 +431,7 @@ fn call_contract_storage_write_read() { .insert(simple_wallet_address.clone(), simple_wallet_nonce); // Create state from the state_reader and contract cache. - let mut state = CachedState::new(Arc::new(state_reader), contract_class_cache); + let mut state = CachedState::new(Arc::new(state_reader), Arc::new(contract_class_cache)); let block_context = BlockContext::default(); let mut tx_execution_context = TransactionExecutionContext::new( @@ -571,13 +579,14 @@ fn emit_event() { let entrypoint_selector = &entrypoints.external.get(0).unwrap().selector; // Create state reader with class hash data - let mut contract_class_cache = HashMap::new(); + let contract_class_cache = PermanentContractClassCache::default(); let address = Address(1111.into()); let class_hash: ClassHash = ClassHash([1; 32]); let nonce = Felt252::zero(); - contract_class_cache.insert(class_hash, CompiledClass::Casm(Arc::new(contract_class))); + contract_class_cache + .set_contract_class(class_hash, CompiledClass::Casm(Arc::new(contract_class))); let mut state_reader = InMemoryStateReader::default(); state_reader .address_to_class_hash_mut() @@ -587,7 +596,7 @@ fn emit_event() { .insert(address.clone(), nonce); // Create state from the state_reader and contract cache. - let mut state = CachedState::new(Arc::new(state_reader), contract_class_cache); + let mut state = CachedState::new(Arc::new(state_reader), Arc::new(contract_class_cache)); // Create an execution entry point let calldata = [].to_vec(); @@ -686,14 +695,15 @@ fn deploy_cairo1_from_cairo1() { let entrypoint_selector = &entrypoints.external.get(0).unwrap().selector; // Create state reader with class hash data - let mut contract_class_cache = HashMap::new(); + let contract_class_cache = PermanentContractClassCache::default(); let address = Address(1111.into()); let class_hash: ClassHash = ClassHash([1; 32]); let nonce = Felt252::zero(); - contract_class_cache.insert(class_hash, CompiledClass::Casm(Arc::new(contract_class))); - contract_class_cache.insert( + contract_class_cache + .set_contract_class(class_hash, CompiledClass::Casm(Arc::new(contract_class))); + contract_class_cache.set_contract_class( test_class_hash, CompiledClass::Casm(Arc::new(test_contract_class.clone())), ); @@ -707,7 +717,7 @@ fn deploy_cairo1_from_cairo1() { .insert(address.clone(), nonce); // Create state from the state_reader and contract cache. - let mut state = CachedState::new(Arc::new(state_reader), contract_class_cache); + let mut state = CachedState::new(Arc::new(state_reader), Arc::new(contract_class_cache)); // arguments of deploy contract let calldata: Vec<_> = [test_felt_hash, salt].to_vec(); @@ -789,14 +799,15 @@ fn deploy_cairo0_from_cairo1_without_constructor() { let entrypoint_selector = &entrypoints.external.get(0).unwrap().selector; // Create state reader with class hash data - let mut contract_class_cache = HashMap::new(); + let contract_class_cache = PermanentContractClassCache::default(); let address = Address(1111.into()); let class_hash: ClassHash = ClassHash([1; 32]); let nonce = Felt252::zero(); - contract_class_cache.insert(class_hash, CompiledClass::Casm(Arc::new(contract_class))); - contract_class_cache.insert( + contract_class_cache + .set_contract_class(class_hash, CompiledClass::Casm(Arc::new(contract_class))); + contract_class_cache.set_contract_class( test_class_hash, CompiledClass::Deprecated(Arc::new(test_contract_class.clone())), ); @@ -810,7 +821,7 @@ fn deploy_cairo0_from_cairo1_without_constructor() { .insert(address.clone(), nonce); // Create state from the state_reader and contract cache. - let mut state = CachedState::new(Arc::new(state_reader), contract_class_cache); + let mut state = CachedState::new(Arc::new(state_reader), Arc::new(contract_class_cache)); // arguments of deploy contract let calldata: Vec<_> = [test_felt_hash, salt].to_vec(); @@ -890,15 +901,16 @@ fn deploy_cairo0_from_cairo1_with_constructor() { let entrypoint_selector = &entrypoints.external.get(0).unwrap().selector; // Create state reader with class hash data - let mut contract_class_cache = HashMap::new(); + let contract_class_cache = PermanentContractClassCache::default(); let address = Address(1111.into()); let class_hash: ClassHash = ClassHash([1; 32]); let nonce = Felt252::zero(); // simulate contract declare - contract_class_cache.insert(class_hash, CompiledClass::Casm(Arc::new(contract_class))); - contract_class_cache.insert( + contract_class_cache + .set_contract_class(class_hash, CompiledClass::Casm(Arc::new(contract_class))); + contract_class_cache.set_contract_class( test_class_hash, CompiledClass::Deprecated(Arc::new(test_contract_class.clone())), ); @@ -912,7 +924,7 @@ fn deploy_cairo0_from_cairo1_with_constructor() { .insert(address.clone(), nonce); // Create state from the state_reader and contract cache. - let mut state = CachedState::new(Arc::new(state_reader), contract_class_cache); + let mut state = CachedState::new(Arc::new(state_reader), Arc::new(contract_class_cache)); // arguments of deploy contract let calldata: Vec<_> = [test_felt_hash, salt, address.0.clone(), Felt252::zero()].to_vec(); @@ -994,14 +1006,15 @@ fn deploy_cairo0_and_invoke() { let entrypoint_selector = &entrypoints.external.get(0).unwrap().selector; // Create state reader with class hash data - let mut contract_class_cache = HashMap::new(); + let contract_class_cache = PermanentContractClassCache::default(); let address = Address(1111.into()); let class_hash: ClassHash = ClassHash([1; 32]); let nonce = Felt252::zero(); - contract_class_cache.insert(class_hash, CompiledClass::Casm(Arc::new(contract_class))); - contract_class_cache.insert( + contract_class_cache + .set_contract_class(class_hash, CompiledClass::Casm(Arc::new(contract_class))); + contract_class_cache.set_contract_class( test_class_hash, CompiledClass::Deprecated(Arc::new(test_contract_class.clone())), ); @@ -1015,7 +1028,8 @@ fn deploy_cairo0_and_invoke() { .insert(address.clone(), nonce); // Create state from the state_reader and contract cache. - let mut state: CachedState<_> = CachedState::new(Arc::new(state_reader), contract_class_cache); + let mut state: CachedState<_, _> = + CachedState::new(Arc::new(state_reader), Arc::new(contract_class_cache)); // arguments of deploy contract let calldata: Vec<_> = [test_felt_hash, salt].to_vec(); @@ -1123,13 +1137,14 @@ fn test_send_message_to_l1_syscall() { let external_entrypoint_selector = &entrypoints.external.get(0).unwrap().selector; // Create state reader with class hash data - let mut contract_class_cache = HashMap::new(); + let contract_class_cache = PermanentContractClassCache::default(); let address = Address(1111.into()); let class_hash: ClassHash = ClassHash([1; 32]); let nonce = Felt252::zero(); - contract_class_cache.insert(class_hash, CompiledClass::Casm(Arc::new(contract_class))); + contract_class_cache + .set_contract_class(class_hash, CompiledClass::Casm(Arc::new(contract_class))); let mut state_reader = InMemoryStateReader::default(); state_reader @@ -1140,7 +1155,7 @@ fn test_send_message_to_l1_syscall() { .insert(address.clone(), nonce); // Create state from the state_reader and contract cache. - let mut state = CachedState::new(Arc::new(state_reader), contract_class_cache); + let mut state = CachedState::new(Arc::new(state_reader), Arc::new(contract_class_cache)); // RUN SEND_MSG // Create an execution entry point @@ -1229,13 +1244,14 @@ fn test_get_execution_info() { let external_entrypoint_selector = &entrypoints.external.get(0).unwrap().selector; // Create state reader with class hash data - let mut contract_class_cache = HashMap::new(); + let contract_class_cache = PermanentContractClassCache::default(); let address = Address(1111.into()); let class_hash: ClassHash = ClassHash([1; 32]); let nonce = Felt252::zero(); - contract_class_cache.insert(class_hash, CompiledClass::Casm(Arc::new(contract_class))); + contract_class_cache + .set_contract_class(class_hash, CompiledClass::Casm(Arc::new(contract_class))); let mut state_reader = InMemoryStateReader::default(); state_reader .address_to_class_hash_mut() @@ -1245,7 +1261,7 @@ fn test_get_execution_info() { .insert(address.clone(), nonce); // Create state from the state_reader and contract cache. - let mut state = CachedState::new(Arc::new(state_reader), contract_class_cache); + let mut state = CachedState::new(Arc::new(state_reader), Arc::new(contract_class_cache)); let block_context = BlockContext::default(); let mut tx_execution_context = TransactionExecutionContext::new( @@ -1336,13 +1352,13 @@ fn replace_class_internal() { let upgrade_selector = &entrypoints_a.external.get(0).unwrap().selector; // Create state reader with class hash data - let mut contract_class_cache = HashMap::new(); + let contract_class_cache = PermanentContractClassCache::default(); let address = Address(1111.into()); let class_hash_a: ClassHash = ClassHash([1; 32]); let nonce = Felt252::zero(); - contract_class_cache.insert( + contract_class_cache.set_contract_class( class_hash_a, CompiledClass::Casm(Arc::new(contract_class_a)), ); @@ -1363,13 +1379,13 @@ fn replace_class_internal() { let class_hash_b: ClassHash = ClassHash([2; 32]); - contract_class_cache.insert( + contract_class_cache.set_contract_class( class_hash_b, CompiledClass::Casm(Arc::new(contract_class_b.clone())), ); // Create state from the state_reader and contract cache. - let mut state = CachedState::new(Arc::new(state_reader), contract_class_cache); + let mut state = CachedState::new(Arc::new(state_reader), Arc::new(contract_class_cache)); // Run upgrade entrypoint and check that the storage was updated with the new contract class // Create an execution entry point @@ -1440,13 +1456,13 @@ fn replace_class_contract_call() { let contract_class_a: CasmContractClass = serde_json::from_slice(program_data).unwrap(); // Create state reader with class hash data - let mut contract_class_cache = HashMap::new(); + let contract_class_cache = PermanentContractClassCache::default(); let address = Address(Felt252::one()); let class_hash_a: ClassHash = ClassHash([1; 32]); let nonce = Felt252::zero(); - contract_class_cache.insert( + contract_class_cache.set_contract_class( class_hash_a, CompiledClass::Casm(Arc::new(contract_class_a)), ); @@ -1470,7 +1486,7 @@ fn replace_class_contract_call() { let class_hash_b: ClassHash = ClassHash([2; 32]); - contract_class_cache.insert( + contract_class_cache.set_contract_class( class_hash_b, CompiledClass::Casm(Arc::new(contract_class_b)), ); @@ -1490,7 +1506,7 @@ fn replace_class_contract_call() { let wrapper_address = Address(Felt252::from(2)); let wrapper_class_hash: ClassHash = ClassHash([3; 32]); - contract_class_cache.insert( + contract_class_cache.set_contract_class( wrapper_class_hash, CompiledClass::Casm(Arc::new(wrapper_contract_class)), ); @@ -1502,7 +1518,7 @@ fn replace_class_contract_call() { .insert(wrapper_address, nonce); // Create state from the state_reader and contract cache. - let mut state = CachedState::new(Arc::new(state_reader), contract_class_cache); + let mut state = CachedState::new(Arc::new(state_reader), Arc::new(contract_class_cache)); // INITIALIZE STARKNET CONFIG let block_context = BlockContext::default(); @@ -1623,13 +1639,13 @@ fn replace_class_contract_call_same_transaction() { let contract_class_a: CasmContractClass = serde_json::from_slice(program_data).unwrap(); // Create state reader with class hash data - let mut contract_class_cache = HashMap::new(); + let contract_class_cache = PermanentContractClassCache::default(); let address = Address(Felt252::one()); let class_hash_a: ClassHash = ClassHash([1; 32]); let nonce = Felt252::zero(); - contract_class_cache.insert( + contract_class_cache.set_contract_class( class_hash_a, CompiledClass::Casm(Arc::new(contract_class_a)), ); @@ -1653,7 +1669,7 @@ fn replace_class_contract_call_same_transaction() { let class_hash_b: ClassHash = ClassHash([2; 32]); - contract_class_cache.insert( + contract_class_cache.set_contract_class( class_hash_b, CompiledClass::Casm(Arc::new(contract_class_b)), ); @@ -1672,7 +1688,7 @@ fn replace_class_contract_call_same_transaction() { let wrapper_address = Address(Felt252::from(2)); let wrapper_class_hash: ClassHash = ClassHash([3; 32]); - contract_class_cache.insert( + contract_class_cache.set_contract_class( wrapper_class_hash, CompiledClass::Casm(Arc::new(wrapper_contract_class)), ); @@ -1684,7 +1700,7 @@ fn replace_class_contract_call_same_transaction() { .insert(wrapper_address, nonce); // Create state from the state_reader and contract cache. - let mut state = CachedState::new(Arc::new(state_reader), contract_class_cache); + let mut state = CachedState::new(Arc::new(state_reader), Arc::new(contract_class_cache)); // INITIALIZE STARKNET CONFIG let block_context = BlockContext::default(); @@ -1748,13 +1764,13 @@ fn call_contract_upgrade_cairo_0_to_cairo_1_same_transaction() { let contract_class_c = ContractClass::from_path("starknet_programs/get_number_c.json").unwrap(); // Create state reader with class hash data - let mut contract_class_cache = HashMap::new(); + let contract_class_cache = PermanentContractClassCache::default(); let address = Address(Felt252::one()); let class_hash_c: ClassHash = ClassHash::from(Felt252::one()); let nonce = Felt252::zero(); - contract_class_cache.insert( + contract_class_cache.set_contract_class( class_hash_c, CompiledClass::Deprecated(Arc::new(contract_class_c)), ); @@ -1778,7 +1794,7 @@ fn call_contract_upgrade_cairo_0_to_cairo_1_same_transaction() { let class_hash_b: ClassHash = ClassHash::from(Felt252::from(2)); - contract_class_cache.insert( + contract_class_cache.set_contract_class( class_hash_b, CompiledClass::Casm(Arc::new(contract_class_b)), ); @@ -1797,7 +1813,7 @@ fn call_contract_upgrade_cairo_0_to_cairo_1_same_transaction() { let wrapper_address = Address(Felt252::from(2)); let wrapper_class_hash: ClassHash = ClassHash([3; 32]); - contract_class_cache.insert( + contract_class_cache.set_contract_class( wrapper_class_hash, CompiledClass::Casm(Arc::new(wrapper_contract_class)), ); @@ -1809,7 +1825,7 @@ fn call_contract_upgrade_cairo_0_to_cairo_1_same_transaction() { .insert(wrapper_address, nonce); // Create state from the state_reader and contract cache. - let mut state = CachedState::new(Arc::new(state_reader), contract_class_cache); + let mut state = CachedState::new(Arc::new(state_reader), Arc::new(contract_class_cache)); // INITIALIZE STARKNET CONFIG let block_context = BlockContext::default(); @@ -1871,13 +1887,13 @@ fn call_contract_downgrade_cairo_1_to_cairo_0_same_transaction() { let contract_class_c = ContractClass::from_path("starknet_programs/get_number_c.json").unwrap(); // Create state reader with class hash data - let mut contract_class_cache = HashMap::new(); + let contract_class_cache = PermanentContractClassCache::default(); let address = Address(Felt252::one()); let class_hash_c: ClassHash = ClassHash::from(Felt252::one()); let nonce = Felt252::zero(); - contract_class_cache.insert( + contract_class_cache.set_contract_class( class_hash_c, CompiledClass::Deprecated(Arc::new(contract_class_c)), ); @@ -1894,7 +1910,7 @@ fn call_contract_downgrade_cairo_1_to_cairo_0_same_transaction() { let class_hash_b: ClassHash = ClassHash::from(Felt252::from(2)); - contract_class_cache.insert( + contract_class_cache.set_contract_class( class_hash_b, CompiledClass::Casm(Arc::new(contract_class_b)), ); @@ -1920,7 +1936,7 @@ fn call_contract_downgrade_cairo_1_to_cairo_0_same_transaction() { let wrapper_address = Address(Felt252::from(2)); let wrapper_class_hash: ClassHash = ClassHash([3; 32]); - contract_class_cache.insert( + contract_class_cache.set_contract_class( wrapper_class_hash, CompiledClass::Casm(Arc::new(wrapper_contract_class)), ); @@ -1932,7 +1948,7 @@ fn call_contract_downgrade_cairo_1_to_cairo_0_same_transaction() { .insert(wrapper_address, nonce); // Create state from the state_reader and contract cache. - let mut state = CachedState::new(Arc::new(state_reader), contract_class_cache); + let mut state = CachedState::new(Arc::new(state_reader), Arc::new(contract_class_cache)); // INITIALIZE STARKNET CONFIG let block_context = BlockContext::default(); @@ -1994,13 +2010,13 @@ fn call_contract_replace_class_cairo_0() { let contract_class_c = ContractClass::from_path("starknet_programs/get_number_c.json").unwrap(); // Create state reader with class hash data - let mut contract_class_cache = HashMap::new(); + let contract_class_cache = PermanentContractClassCache::default(); let address = Address(Felt252::one()); let class_hash_c: ClassHash = ClassHash::from(Felt252::one()); let nonce = Felt252::zero(); - contract_class_cache.insert( + contract_class_cache.set_contract_class( class_hash_c, CompiledClass::Deprecated(Arc::new(contract_class_c)), ); @@ -2013,7 +2029,7 @@ fn call_contract_replace_class_cairo_0() { let class_hash_d: ClassHash = ClassHash::from(Felt252::from(2)); - contract_class_cache.insert( + contract_class_cache.set_contract_class( class_hash_d, CompiledClass::Deprecated(Arc::new(contract_class_d)), ); @@ -2039,7 +2055,7 @@ fn call_contract_replace_class_cairo_0() { let wrapper_address = Address(Felt252::from(2)); let wrapper_class_hash: ClassHash = ClassHash([3; 32]); - contract_class_cache.insert( + contract_class_cache.set_contract_class( wrapper_class_hash, CompiledClass::Casm(Arc::new(wrapper_contract_class)), ); @@ -2051,7 +2067,7 @@ fn call_contract_replace_class_cairo_0() { .insert(wrapper_address, nonce); // Create state from the state_reader and contract cache. - let mut state = CachedState::new(Arc::new(state_reader), contract_class_cache); + let mut state = CachedState::new(Arc::new(state_reader), Arc::new(contract_class_cache)); // INITIALIZE STARKNET CONFIG let block_context = BlockContext::default(); @@ -2113,13 +2129,14 @@ fn test_out_of_gas_failure() { let entrypoint_selector = &entrypoints.external.get(0).unwrap().selector; // Create state reader with class hash data - let mut contract_class_cache = HashMap::new(); + let contract_class_cache = PermanentContractClassCache::default(); let address = Address(1111.into()); let class_hash: ClassHash = ClassHash([1; 32]); let nonce = Felt252::zero(); - contract_class_cache.insert(class_hash, CompiledClass::Casm(Arc::new(contract_class))); + contract_class_cache + .set_contract_class(class_hash, CompiledClass::Casm(Arc::new(contract_class))); let mut state_reader = InMemoryStateReader::default(); state_reader .address_to_class_hash_mut() @@ -2129,7 +2146,7 @@ fn test_out_of_gas_failure() { .insert(address.clone(), nonce); // Create state from the state_reader and contract cache. - let mut state = CachedState::new(Arc::new(state_reader), contract_class_cache); + let mut state = CachedState::new(Arc::new(state_reader), Arc::new(contract_class_cache)); // Create an execution entry point let calldata = [].to_vec(); @@ -2192,13 +2209,14 @@ fn deploy_syscall_failure_uninitialized_class_hash() { let entrypoint_selector = &entrypoints.external.get(0).unwrap().selector; // Create state reader with class hash data - let mut contract_class_cache = HashMap::new(); + let contract_class_cache = PermanentContractClassCache::default(); let address = Address(1111.into()); let class_hash: ClassHash = ClassHash([1; 32]); let nonce = Felt252::zero(); - contract_class_cache.insert(class_hash, CompiledClass::Casm(Arc::new(contract_class))); + contract_class_cache + .set_contract_class(class_hash, CompiledClass::Casm(Arc::new(contract_class))); let mut state_reader = InMemoryStateReader::default(); state_reader .address_to_class_hash_mut() @@ -2208,7 +2226,7 @@ fn deploy_syscall_failure_uninitialized_class_hash() { .insert(address.clone(), nonce); // Create state from the state_reader and contract cache. - let mut state = CachedState::new(Arc::new(state_reader), contract_class_cache); + let mut state = CachedState::new(Arc::new(state_reader), Arc::new(contract_class_cache)); // Create an execution entry point let calldata = [Felt252::zero()].to_vec(); @@ -2270,13 +2288,14 @@ fn deploy_syscall_failure_in_constructor() { let entrypoint_selector = &entrypoints.external.get(0).unwrap().selector; // Create state reader with class hash data - let mut contract_class_cache = HashMap::new(); + let contract_class_cache = PermanentContractClassCache::default(); let address = Address(1111.into()); let class_hash: ClassHash = ClassHash([1; 32]); let nonce = Felt252::zero(); - contract_class_cache.insert(class_hash, CompiledClass::Casm(Arc::new(contract_class))); + contract_class_cache + .set_contract_class(class_hash, CompiledClass::Casm(Arc::new(contract_class))); let mut state_reader = InMemoryStateReader::default(); state_reader .address_to_class_hash_mut() @@ -2292,13 +2311,13 @@ fn deploy_syscall_failure_in_constructor() { let f_c_program_data = include_bytes!("../starknet_programs/cairo1/failing_constructor.casm"); let f_c_contract_class: CasmContractClass = serde_json::from_slice(f_c_program_data).unwrap(); let f_c_class_hash = Felt252::one(); - contract_class_cache.insert( + contract_class_cache.set_contract_class( ClassHash::from(f_c_class_hash.clone()), CompiledClass::Casm(Arc::new(f_c_contract_class)), ); // Create state from the state_reader and contract cache. - let mut state = CachedState::new(Arc::new(state_reader), contract_class_cache); + let mut state = CachedState::new(Arc::new(state_reader), Arc::new(contract_class_cache)); // Create an execution entry point let calldata = [f_c_class_hash].to_vec(); @@ -2362,13 +2381,14 @@ fn storage_read_no_value() { let get_balance_entrypoint_selector = &entrypoints.external.get(1).unwrap().selector; // Create state reader with class hash data - let mut contract_class_cache = HashMap::new(); + let contract_class_cache = PermanentContractClassCache::default(); let address = Address(1111.into()); let class_hash: ClassHash = ClassHash([1; 32]); let nonce = Felt252::zero(); - contract_class_cache.insert(class_hash, CompiledClass::Casm(Arc::new(contract_class))); + contract_class_cache + .set_contract_class(class_hash, CompiledClass::Casm(Arc::new(contract_class))); let mut state_reader = InMemoryStateReader::default(); state_reader .address_to_class_hash_mut() @@ -2378,7 +2398,7 @@ fn storage_read_no_value() { .insert(address.clone(), nonce); // Create state from the state_reader and contract cache. - let mut state = CachedState::new(Arc::new(state_reader), contract_class_cache); + let mut state = CachedState::new(Arc::new(state_reader), Arc::new(contract_class_cache)); let block_context = BlockContext::default(); let mut tx_execution_context = TransactionExecutionContext::new( @@ -2435,13 +2455,14 @@ fn storage_read_unavailable_address_domain() { let read_storage_entrypoint_selector = &entrypoints.external.get(0).unwrap().selector; // Create state reader with class hash data - let mut contract_class_cache = HashMap::new(); + let contract_class_cache = PermanentContractClassCache::default(); let address = Address(1111.into()); let class_hash: ClassHash = ClassHash([1; 32]); let nonce = Felt252::zero(); - contract_class_cache.insert(class_hash, CompiledClass::Casm(Arc::new(contract_class))); + contract_class_cache + .set_contract_class(class_hash, CompiledClass::Casm(Arc::new(contract_class))); let mut state_reader = InMemoryStateReader::default(); state_reader .address_to_class_hash_mut() @@ -2451,7 +2472,7 @@ fn storage_read_unavailable_address_domain() { .insert(address.clone(), nonce); // Create state from the state_reader and contract cache. - let mut state = CachedState::new(Arc::new(state_reader), contract_class_cache); + let mut state = CachedState::new(Arc::new(state_reader), Arc::new(contract_class_cache)); let block_context = BlockContext::default(); let mut tx_execution_context = TransactionExecutionContext::new( @@ -2511,13 +2532,14 @@ fn storage_write_unavailable_address_domain() { let read_storage_entrypoint_selector = &entrypoints.external.get(0).unwrap().selector; // Create state reader with class hash data - let mut contract_class_cache = HashMap::new(); + let contract_class_cache = PermanentContractClassCache::default(); let address = Address(1111.into()); let class_hash: ClassHash = ClassHash([1; 32]); let nonce = Felt252::zero(); - contract_class_cache.insert(class_hash, CompiledClass::Casm(Arc::new(contract_class))); + contract_class_cache + .set_contract_class(class_hash, CompiledClass::Casm(Arc::new(contract_class))); let mut state_reader = InMemoryStateReader::default(); state_reader .address_to_class_hash_mut() @@ -2527,7 +2549,7 @@ fn storage_write_unavailable_address_domain() { .insert(address.clone(), nonce); // Create state from the state_reader and contract cache. - let mut state = CachedState::new(Arc::new(state_reader), contract_class_cache); + let mut state = CachedState::new(Arc::new(state_reader), Arc::new(contract_class_cache)); let block_context = BlockContext::default(); let mut tx_execution_context = TransactionExecutionContext::new( @@ -2585,13 +2607,14 @@ fn library_call_failure() { let entrypoint_selector = &entrypoints.external.get(0).unwrap().selector; // Create state reader with class hash data - let mut contract_class_cache = HashMap::new(); + let contract_class_cache = PermanentContractClassCache::default(); let address = Address(1111.into()); let class_hash: ClassHash = ClassHash([1; 32]); let nonce = Felt252::zero(); - contract_class_cache.insert(class_hash, CompiledClass::Casm(Arc::new(contract_class))); + contract_class_cache + .set_contract_class(class_hash, CompiledClass::Casm(Arc::new(contract_class))); let mut state_reader = InMemoryStateReader::default(); state_reader .address_to_class_hash_mut() @@ -2612,7 +2635,7 @@ fn library_call_failure() { let lib_class_hash: ClassHash = ClassHash([2; 32]); let lib_nonce = Felt252::zero(); - contract_class_cache.insert( + contract_class_cache.set_contract_class( lib_class_hash, CompiledClass::Casm(Arc::new(lib_contract_class)), ); @@ -2624,7 +2647,7 @@ fn library_call_failure() { .insert(lib_address, lib_nonce); // Create state from the state_reader and contract cache. - let mut state = CachedState::new(Arc::new(state_reader), contract_class_cache); + let mut state = CachedState::new(Arc::new(state_reader), Arc::new(contract_class_cache)); // Create an execution entry point let calldata = [25.into(), Felt252::from_bytes_be(&lib_class_hash.0)].to_vec(); @@ -2699,13 +2722,14 @@ fn send_messages_to_l1_different_contract_calls() { .clone(); // Create state reader with class hash data - let mut contract_class_cache = HashMap::new(); + let contract_class_cache = PermanentContractClassCache::default(); let address = Address(1111.into()); let class_hash: ClassHash = ClassHash([1; 32]); let nonce = Felt252::zero(); - contract_class_cache.insert(class_hash, CompiledClass::Casm(Arc::new(contract_class))); + contract_class_cache + .set_contract_class(class_hash, CompiledClass::Casm(Arc::new(contract_class))); let mut state_reader = InMemoryStateReader::default(); state_reader .address_to_class_hash_mut() @@ -2726,7 +2750,7 @@ fn send_messages_to_l1_different_contract_calls() { let send_msg_class_hash: ClassHash = ClassHash([2; 32]); let send_msg_nonce = Felt252::zero(); - contract_class_cache.insert( + contract_class_cache.set_contract_class( send_msg_class_hash, CompiledClass::Casm(Arc::new(send_msg_contract_class)), ); @@ -2738,7 +2762,7 @@ fn send_messages_to_l1_different_contract_calls() { .insert(send_msg_address, send_msg_nonce); // Create state from the state_reader and contract cache. - let mut state = CachedState::new(Arc::new(state_reader), contract_class_cache); + let mut state = CachedState::new(Arc::new(state_reader), Arc::new(contract_class_cache)); // Create an execution entry point let calldata = [25.into(), 50.into(), 75.into()].to_vec(); @@ -2824,13 +2848,14 @@ fn send_messages_to_l1_different_contract_calls_cairo1_to_cairo0() { .clone(); // Create state reader with class hash data - let mut contract_class_cache = HashMap::new(); + let contract_class_cache = PermanentContractClassCache::default(); let address = Address(1111.into()); let class_hash: ClassHash = ClassHash([1; 32]); let nonce = Felt252::zero(); - contract_class_cache.insert(class_hash, CompiledClass::Casm(Arc::new(contract_class))); + contract_class_cache + .set_contract_class(class_hash, CompiledClass::Casm(Arc::new(contract_class))); let mut state_reader = InMemoryStateReader::default(); state_reader .address_to_class_hash_mut() @@ -2848,7 +2873,7 @@ fn send_messages_to_l1_different_contract_calls_cairo1_to_cairo0() { let send_msg_class_hash: ClassHash = ClassHash([2; 32]); let send_msg_nonce = Felt252::zero(); - contract_class_cache.insert( + contract_class_cache.set_contract_class( send_msg_class_hash, CompiledClass::Deprecated(Arc::new(send_msg_contract_class)), ); @@ -2860,7 +2885,7 @@ fn send_messages_to_l1_different_contract_calls_cairo1_to_cairo0() { .insert(send_msg_address, send_msg_nonce); // Create state from the state_reader and contract cache. - let mut state = CachedState::new(Arc::new(state_reader), contract_class_cache); + let mut state = CachedState::new(Arc::new(state_reader), Arc::new(contract_class_cache)); // Create an execution entry point let calldata = [25.into(), 50.into(), 75.into()].to_vec(); @@ -2941,13 +2966,13 @@ fn send_messages_to_l1_different_contract_calls_cairo0_to_cairo1() { .to_owned(); // Create state reader with class hash data - let mut contract_class_cache = HashMap::new(); + let contract_class_cache = PermanentContractClassCache::default(); let address = Address(1111.into()); let class_hash: ClassHash = ClassHash([1; 32]); let nonce = Felt252::zero(); - contract_class_cache.insert( + contract_class_cache.set_contract_class( class_hash, CompiledClass::Deprecated(Arc::new(contract_class)), ); @@ -2971,7 +2996,7 @@ fn send_messages_to_l1_different_contract_calls_cairo0_to_cairo1() { let send_msg_class_hash: ClassHash = ClassHash([2; 32]); let send_msg_nonce = Felt252::zero(); - contract_class_cache.insert( + contract_class_cache.set_contract_class( send_msg_class_hash, CompiledClass::Casm(Arc::new(send_msg_contract_class)), ); @@ -2983,7 +3008,7 @@ fn send_messages_to_l1_different_contract_calls_cairo0_to_cairo1() { .insert(send_msg_address, send_msg_nonce); // Create state from the state_reader and contract cache. - let mut state = CachedState::new(Arc::new(state_reader), contract_class_cache); + let mut state = CachedState::new(Arc::new(state_reader), Arc::new(contract_class_cache)); // Create an execution entry point let calldata = [25.into(), 50.into(), 75.into()].to_vec(); @@ -3063,13 +3088,14 @@ fn keccak_syscall() { let read_storage_entrypoint_selector = &entrypoints.external.get(0).unwrap().selector; // Create state reader with class hash data - let mut contract_class_cache = HashMap::new(); + let contract_class_cache = PermanentContractClassCache::default(); let address = Address(1111.into()); let class_hash: ClassHash = ClassHash([1; 32]); let nonce = Felt252::zero(); - contract_class_cache.insert(class_hash, CompiledClass::Casm(Arc::new(contract_class))); + contract_class_cache + .set_contract_class(class_hash, CompiledClass::Casm(Arc::new(contract_class))); let mut state_reader = InMemoryStateReader::default(); state_reader .address_to_class_hash_mut() @@ -3079,7 +3105,7 @@ fn keccak_syscall() { .insert(address.clone(), nonce); // Create state from the state_reader and contract cache. - let mut state = CachedState::new(Arc::new(state_reader), contract_class_cache); + let mut state = CachedState::new(Arc::new(state_reader), Arc::new(contract_class_cache)); let block_context = BlockContext::default(); let mut tx_execution_context = TransactionExecutionContext::new( @@ -3138,13 +3164,14 @@ fn library_call_recursive_50_calls() { let entrypoint_selector = &entrypoints.external.get(0).unwrap().selector; // Create state reader with class hash data - let mut contract_class_cache = HashMap::new(); + let contract_class_cache = PermanentContractClassCache::default(); let address = Address(1111.into()); let class_hash: ClassHash = ClassHash([1; 32]); let nonce = Felt252::zero(); - contract_class_cache.insert(class_hash, CompiledClass::Casm(Arc::new(contract_class))); + contract_class_cache + .set_contract_class(class_hash, CompiledClass::Casm(Arc::new(contract_class))); let mut state_reader = InMemoryStateReader::default(); state_reader .address_to_class_hash_mut() @@ -3166,7 +3193,7 @@ fn library_call_recursive_50_calls() { let lib_class_hash: ClassHash = ClassHash([2; 32]); let lib_nonce = Felt252::zero(); - contract_class_cache.insert( + contract_class_cache.set_contract_class( lib_class_hash, CompiledClass::Casm(Arc::new(lib_contract_class)), ); @@ -3178,7 +3205,7 @@ fn library_call_recursive_50_calls() { .insert(lib_address, lib_nonce); // Create state from the state_reader and contract cache. - let mut state = CachedState::new(Arc::new(state_reader), contract_class_cache); + let mut state = CachedState::new(Arc::new(state_reader), Arc::new(contract_class_cache)); // Create an execution entry point let calldata = [ @@ -3280,13 +3307,14 @@ fn call_contract_storage_write_read_recursive_50_calls() { )); // Create state reader with class hash data - let mut contract_class_cache = HashMap::new(); + let contract_class_cache = PermanentContractClassCache::default(); let address = Address(1111.into()); let class_hash: ClassHash = ClassHash([1; 32]); let nonce = Felt252::zero(); - contract_class_cache.insert(class_hash, CompiledClass::Casm(Arc::new(contract_class))); + contract_class_cache + .set_contract_class(class_hash, CompiledClass::Casm(Arc::new(contract_class))); let mut state_reader = InMemoryStateReader::default(); state_reader .address_to_class_hash_mut() @@ -3317,7 +3345,7 @@ fn call_contract_storage_write_read_recursive_50_calls() { let simple_wallet_class_hash: ClassHash = ClassHash([2; 32]); let simple_wallet_nonce = Felt252::zero(); - contract_class_cache.insert( + contract_class_cache.set_contract_class( simple_wallet_class_hash, CompiledClass::Casm(Arc::new(simple_wallet_contract_class)), ); @@ -3329,7 +3357,7 @@ fn call_contract_storage_write_read_recursive_50_calls() { .insert(simple_wallet_address.clone(), simple_wallet_nonce); // Create state from the state_reader and contract cache. - let mut state = CachedState::new(Arc::new(state_reader), contract_class_cache); + let mut state = CachedState::new(Arc::new(state_reader), Arc::new(contract_class_cache)); let block_context = BlockContext::default(); let mut tx_execution_context = TransactionExecutionContext::new( @@ -3486,13 +3514,14 @@ fn call_contract_storage_write_read_recursive_100_calls() { )); // Create state reader with class hash data - let mut contract_class_cache = HashMap::new(); + let contract_class_cache = PermanentContractClassCache::default(); let address = Address(1111.into()); let class_hash: ClassHash = ClassHash([1; 32]); let nonce = Felt252::zero(); - contract_class_cache.insert(class_hash, CompiledClass::Casm(Arc::new(contract_class))); + contract_class_cache + .set_contract_class(class_hash, CompiledClass::Casm(Arc::new(contract_class))); let mut state_reader = InMemoryStateReader::default(); state_reader .address_to_class_hash_mut() @@ -3523,7 +3552,7 @@ fn call_contract_storage_write_read_recursive_100_calls() { let simple_wallet_class_hash: ClassHash = ClassHash([2; 32]); let simple_wallet_nonce = Felt252::zero(); - contract_class_cache.insert( + contract_class_cache.set_contract_class( simple_wallet_class_hash, CompiledClass::Casm(Arc::new(simple_wallet_contract_class)), ); @@ -3535,7 +3564,7 @@ fn call_contract_storage_write_read_recursive_100_calls() { .insert(simple_wallet_address.clone(), simple_wallet_nonce); // Create state from the state_reader and contract cache. - let mut state = CachedState::new(Arc::new(state_reader), contract_class_cache); + let mut state = CachedState::new(Arc::new(state_reader), Arc::new(contract_class_cache)); let block_context = BlockContext::default(); let mut tx_execution_context = TransactionExecutionContext::new( diff --git a/tests/cairo_native.rs b/tests/cairo_native.rs index 2f7c0b809..504e15cc7 100644 --- a/tests/cairo_native.rs +++ b/tests/cairo_native.rs @@ -5,6 +5,7 @@ use cairo_lang_starknet::casm_contract_class::CasmContractEntryPoints; use cairo_lang_starknet::contract_class::ContractClass; use cairo_lang_starknet::contract_class::ContractEntryPoints; use cairo_vm::felt::Felt252; +use cairo_vm::with_std::collections::HashSet; use num_bigint::BigUint; use num_traits::{Num, One, Zero}; use pretty_assertions_sorted::{assert_eq, assert_eq_sorted}; @@ -14,6 +15,8 @@ use starknet_in_rust::definitions::block_context::BlockContext; use starknet_in_rust::execution::{Event, OrderedEvent}; use starknet_in_rust::hash_utils::calculate_contract_address; use starknet_in_rust::services::api::contract_classes::compiled_class::CompiledClass; +use starknet_in_rust::state::contract_class_cache::ContractClassCache; +use starknet_in_rust::state::contract_class_cache::PermanentContractClassCache; use starknet_in_rust::state::state_api::State; use starknet_in_rust::CasmContractClass; use starknet_in_rust::EntryPointType::{self, External}; @@ -26,18 +29,16 @@ use starknet_in_rust::{ state::{in_memory_state_reader::InMemoryStateReader, ExecutionResourcesManager}, utils::{Address, ClassHash}, }; -use std::collections::HashMap; -use std::collections::HashSet; use std::sync::Arc; fn insert_sierra_class_into_cache( - contract_class_cache: &mut HashMap, + contract_class_cache: &PermanentContractClassCache, class_hash: ClassHash, sierra_class: ContractClass, ) { let sierra_program = sierra_class.extract_sierra_program().unwrap(); let entry_points = sierra_class.entry_points_by_type; - contract_class_cache.insert( + contract_class_cache.set_contract_class( class_hash, CompiledClass::Sierra(Arc::new((sierra_program, entry_points))), ); @@ -46,7 +47,9 @@ fn insert_sierra_class_into_cache( #[test] #[cfg(feature = "cairo-native")] fn get_block_hash_test() { - use starknet_in_rust::utils::felt_to_hash; + use starknet_in_rust::{ + state::contract_class_cache::PermanentContractClassCache, utils::felt_to_hash, + }; let sierra_contract_class: cairo_lang_starknet::contract_class::ContractClass = serde_json::from_str( @@ -66,19 +69,19 @@ fn get_block_hash_test() { let casm_external_selector = &casm_entrypoints.external.get(0).unwrap().selector; // Create state reader with class hash data - let mut contract_class_cache = HashMap::new(); + let contract_class_cache = PermanentContractClassCache::default(); let native_class_hash: ClassHash = ClassHash([1; 32]); let casm_class_hash: ClassHash = ClassHash([2; 32]); let caller_address = Address(1.into()); insert_sierra_class_into_cache( - &mut contract_class_cache, + &contract_class_cache, native_class_hash, sierra_contract_class, ); - contract_class_cache.insert( + contract_class_cache.set_contract_class( casm_class_hash, CompiledClass::Casm(Arc::new(casm_contract_class)), ); @@ -95,13 +98,14 @@ fn get_block_hash_test() { // Create state from the state_reader and contract cache. let state_reader = Arc::new(state_reader); - let mut state_vm = CachedState::new(state_reader.clone(), contract_class_cache.clone()); + let mut state_vm = + CachedState::new(state_reader.clone(), Arc::new(contract_class_cache.clone())); state_vm.cache_mut().storage_initial_values_mut().insert( (Address(1.into()), felt_to_hash(&Felt252::from(10)).0), Felt252::from_bytes_be(StarkHash::new([5; 32]).unwrap().bytes()), ); - let mut state_native = CachedState::new(state_reader, contract_class_cache); + let mut state_native = CachedState::new(state_reader, Arc::new(contract_class_cache)); state_native .cache_mut() .storage_initial_values_mut() @@ -204,7 +208,7 @@ fn integration_test_erc20() { let casm_constructor_selector = &casm_entrypoints.constructor.get(0).unwrap().selector; // Create state reader with class hash data - let mut contract_class_cache = HashMap::new(); + let contract_class_cache = PermanentContractClassCache::default(); static NATIVE_CLASS_HASH: ClassHash = ClassHash([1; 32]); static CASM_CLASS_HASH: ClassHash = ClassHash([2; 32]); @@ -212,11 +216,11 @@ fn integration_test_erc20() { let caller_address = Address(123456789.into()); insert_sierra_class_into_cache( - &mut contract_class_cache, + &contract_class_cache, NATIVE_CLASS_HASH, sierra_contract_class, ); - contract_class_cache.insert( + contract_class_cache.set_contract_class( CASM_CLASS_HASH, CompiledClass::Casm(Arc::new(casm_contract_class)), ); @@ -232,9 +236,9 @@ fn integration_test_erc20() { // Create state from the state_reader and contract cache. let state_reader = Arc::new(state_reader); - let mut state_vm = CachedState::new(state_reader.clone(), contract_class_cache.clone()); - - let mut state_native = CachedState::new(state_reader, contract_class_cache); + let mut state_vm = + CachedState::new(state_reader.clone(), Arc::new(contract_class_cache.clone())); + let mut state_native = CachedState::new(state_reader, Arc::new(contract_class_cache)); /* 1 recipient @@ -316,8 +320,8 @@ fn integration_test_erc20() { #[allow(clippy::too_many_arguments)] fn compare_results( - state_vm: &mut CachedState, - state_native: &mut CachedState, + state_vm: &mut CachedState, + state_native: &mut CachedState, selector_idx: usize, native_entrypoints: &ContractEntryPoints, casm_entrypoints: &CasmContractEntryPoints, @@ -597,7 +601,7 @@ fn call_contract_test() { let fn_selector = &callee_entrypoints.external.get(0).unwrap().selector; // Create state reader with class hash data - let mut contract_class_cache = HashMap::new(); + let contract_class_cache = PermanentContractClassCache::default(); // Caller contract data let caller_address = Address(1111.into()); @@ -610,13 +614,13 @@ fn call_contract_test() { let callee_nonce = Felt252::zero(); insert_sierra_class_into_cache( - &mut contract_class_cache, + &contract_class_cache, caller_class_hash, caller_contract_class, ); insert_sierra_class_into_cache( - &mut contract_class_cache, + &contract_class_cache, callee_class_hash, callee_contract_class, ); @@ -640,10 +644,9 @@ fn call_contract_test() { .insert(callee_address.clone(), callee_nonce); // Create state from the state_reader and contract cache. - let mut state = CachedState::new(Arc::new(state_reader), contract_class_cache); + let mut state = CachedState::new(Arc::new(state_reader), Arc::new(contract_class_cache)); let calldata = [fn_selector.into()].to_vec(); - let result = execute( &mut state, &caller_address, @@ -686,7 +689,7 @@ fn call_echo_contract_test() { let fn_selector = &callee_entrypoints.external.get(0).unwrap().selector; // Create state reader with class hash data - let mut contract_class_cache = HashMap::new(); + let contract_class_cache = PermanentContractClassCache::default(); // Caller contract data let caller_address = Address(1111.into()); @@ -699,13 +702,13 @@ fn call_echo_contract_test() { let callee_nonce = Felt252::zero(); insert_sierra_class_into_cache( - &mut contract_class_cache, + &contract_class_cache, caller_class_hash, caller_contract_class, ); insert_sierra_class_into_cache( - &mut contract_class_cache, + &contract_class_cache, callee_class_hash, callee_contract_class, ); @@ -729,7 +732,7 @@ fn call_echo_contract_test() { .insert(callee_address.clone(), callee_nonce); // Create state from the state_reader and contract cache. - let mut state = CachedState::new(Arc::new(state_reader), contract_class_cache); + let mut state = CachedState::new(Arc::new(state_reader), Arc::new(contract_class_cache)); let calldata = [fn_selector.into(), 99999999.into()].to_vec(); let result = execute( @@ -776,7 +779,7 @@ fn call_events_contract_test() { let fn_selector = &callee_entrypoints.external.get(0).unwrap().selector; // Create state reader with class hash data - let mut contract_class_cache = HashMap::new(); + let contract_class_cache = PermanentContractClassCache::default(); // Caller contract data let caller_address = Address(1111.into()); @@ -789,13 +792,13 @@ fn call_events_contract_test() { let callee_nonce = Felt252::zero(); insert_sierra_class_into_cache( - &mut contract_class_cache, + &contract_class_cache, caller_class_hash, caller_contract_class, ); insert_sierra_class_into_cache( - &mut contract_class_cache, + &contract_class_cache, callee_class_hash, callee_contract_class, ); @@ -819,9 +822,9 @@ fn call_events_contract_test() { .insert(callee_address.clone(), callee_nonce); // Create state from the state_reader and contract cache. - let mut state = CachedState::new(Arc::new(state_reader), contract_class_cache); + let mut state = CachedState::new(Arc::new(state_reader), Arc::new(contract_class_cache)); - let calldata: Vec = [fn_selector.into()].to_vec(); + let calldata = [fn_selector.into()].to_vec(); let result = execute( &mut state, &caller_address, @@ -893,7 +896,7 @@ fn replace_class_test() { let casm_replace_selector = &casm_entrypoints.external.get(0).unwrap().selector; // Create state reader with class hash data - let mut contract_class_cache = HashMap::new(); + let mut contract_class_cache = PermanentContractClassCache::default(); let address = Address(1111.into()); let casm_address = Address(2222.into()); @@ -905,7 +908,7 @@ fn replace_class_test() { insert_sierra_class_into_cache(&mut contract_class_cache, CLASS_HASH_A, contract_class_a); - contract_class_cache.insert( + contract_class_cache.set_contract_class( CASM_CLASS_HASH_A, CompiledClass::Casm(Arc::new(casm_contract_class)), ); @@ -935,19 +938,22 @@ fn replace_class_test() { static CASM_CLASS_HASH_B: ClassHash = ClassHash([4; 32]); insert_sierra_class_into_cache( - &mut contract_class_cache, + &contract_class_cache, CLASS_HASH_B, contract_class_b.clone(), ); - contract_class_cache.insert( + contract_class_cache.set_contract_class( CASM_CLASS_HASH_B, CompiledClass::Casm(Arc::new(casm_contract_class_b.clone())), ); // Create state from the state_reader and contract cache. - let mut state = CachedState::new(Arc::new(state_reader.clone()), contract_class_cache.clone()); - let mut vm_state = CachedState::new(Arc::new(state_reader), contract_class_cache); + let mut state = CachedState::new( + Arc::new(state_reader.clone()), + Arc::new(contract_class_cache.clone()), + ); + let mut vm_state = CachedState::new(Arc::new(state_reader), Arc::new(contract_class_cache)); // Run upgrade entrypoint and check that the storage was updated with the new contract class // Create an execution entry point @@ -1059,14 +1065,14 @@ fn replace_class_contract_call() { .unwrap(); // Create state reader with class hash data - let mut contract_class_cache = HashMap::new(); - let mut native_contract_class_cache = HashMap::new(); + let contract_class_cache = PermanentContractClassCache::default(); + let mut native_contract_class_cache = PermanentContractClassCache::default(); let address = Address(Felt252::one()); let class_hash_a: ClassHash = ClassHash([1; 32]); let nonce = Felt252::zero(); - contract_class_cache.insert( + contract_class_cache.set_contract_class( class_hash_a, CompiledClass::Casm(Arc::new(casm_contract_class_a)), ); @@ -1104,7 +1110,7 @@ fn replace_class_contract_call() { .unwrap(); let class_hash_b: ClassHash = ClassHash([2; 32]); - contract_class_cache.insert( + contract_class_cache.set_contract_class( class_hash_b, CompiledClass::Casm(Arc::new(contract_class_b)), ); @@ -1140,7 +1146,7 @@ fn replace_class_contract_call() { let wrapper_address = Address(Felt252::from(2)); let wrapper_class_hash: ClassHash = ClassHash([3; 32]); - contract_class_cache.insert( + contract_class_cache.set_contract_class( wrapper_class_hash, CompiledClass::Casm(Arc::new(wrapper_contract_class)), ); @@ -1162,8 +1168,11 @@ fn replace_class_contract_call() { .insert(wrapper_address, wrapper_class_hash); // Create state from the state_reader and contract cache. - let mut state = CachedState::new(Arc::new(state_reader.clone()), contract_class_cache.clone()); - let mut native_state = CachedState::new(Arc::new(state_reader), contract_class_cache); + let mut state = CachedState::new( + Arc::new(state_reader.clone()), + Arc::new(contract_class_cache.clone()), + ); + let mut native_state = CachedState::new(Arc::new(state_reader), Arc::new(contract_class_cache)); // CALL GET_NUMBER BEFORE REPLACE_CLASS let calldata = [].to_vec(); @@ -1258,10 +1267,10 @@ fn keccak_syscall_test() { let native_class_hash: ClassHash = ClassHash([1; 32]); let caller_address = Address(123456789.into()); - let mut contract_class_cache = HashMap::new(); + let contract_class_cache = PermanentContractClassCache::default(); insert_sierra_class_into_cache( - &mut contract_class_cache, + &contract_class_cache, native_class_hash, sierra_contract_class, ); @@ -1274,7 +1283,7 @@ fn keccak_syscall_test() { .insert(caller_address.clone(), nonce); // Create state from the state_reader and contract cache. - let mut state = CachedState::new(Arc::new(state_reader), contract_class_cache); + let mut state = CachedState::new(Arc::new(state_reader), Arc::new(contract_class_cache)); let native_result = execute( &mut state, @@ -1292,7 +1301,7 @@ fn keccak_syscall_test() { #[allow(clippy::too_many_arguments)] fn execute( - state: &mut CachedState, + state: &mut CachedState, caller_address: &Address, callee_address: &Address, selector: &BigUint, @@ -1356,13 +1365,13 @@ fn library_call() { let entrypoint_selector = &entrypoints.external.get(0).unwrap().selector; // Create state reader with class hash data - let mut contract_class_cache = HashMap::new(); + let contract_class_cache = PermanentContractClassCache::default(); let address = Address(1111.into()); let class_hash: ClassHash = ClassHash([1; 32]); let nonce = Felt252::zero(); - contract_class_cache.insert( + contract_class_cache.set_contract_class( class_hash, CompiledClass::Sierra(Arc::new(( contract_class.extract_sierra_program().unwrap(), @@ -1387,11 +1396,7 @@ fn library_call() { let lib_class_hash: ClassHash = ClassHash([2; 32]); let lib_nonce = Felt252::zero(); - insert_sierra_class_into_cache( - &mut contract_class_cache, - lib_class_hash, - lib_contract_class, - ); + insert_sierra_class_into_cache(&contract_class_cache, lib_class_hash, lib_contract_class); state_reader .address_to_class_hash_mut() @@ -1401,7 +1406,7 @@ fn library_call() { .insert(lib_address, lib_nonce); // Create state from the state_reader and contract cache. - let mut state = CachedState::new(Arc::new(state_reader), contract_class_cache); + let mut state = CachedState::new(Arc::new(state_reader), Arc::new(contract_class_cache)); // Create an execution entry point let calldata = [ @@ -1495,7 +1500,7 @@ fn library_call() { } fn execute_deploy( - state: &mut CachedState, + state: &mut CachedState, caller_address: &Address, selector: &BigUint, calldata: &[Felt252], @@ -1573,7 +1578,7 @@ fn deploy_syscall_test() { let _fn_selector = &deployee_entrypoints.external.get(0).unwrap().selector; // Create state reader with class hash data - let mut contract_class_cache = HashMap::new(); + let contract_class_cache = PermanentContractClassCache::default(); // Deployer contract data let deployer_address = Address(1111.into()); @@ -1585,13 +1590,13 @@ fn deploy_syscall_test() { let _deployee_nonce = Felt252::zero(); insert_sierra_class_into_cache( - &mut contract_class_cache, + &contract_class_cache, deployer_class_hash, deployer_contract_class, ); insert_sierra_class_into_cache( - &mut contract_class_cache, + &contract_class_cache, deployee_class_hash, deployee_contract_class, ); @@ -1607,7 +1612,7 @@ fn deploy_syscall_test() { .insert(deployer_address.clone(), deployer_nonce); // Create state from the state_reader and contract cache. - let mut state = CachedState::new(Arc::new(state_reader), contract_class_cache); + let mut state = CachedState::new(Arc::new(state_reader), Arc::new(contract_class_cache)); let calldata = [ Felt252::from_bytes_be(deployee_class_hash.to_bytes_be()), @@ -1673,7 +1678,7 @@ fn deploy_syscall_address_unavailable_test() { let _fn_selector = &deployee_entrypoints.external.get(0).unwrap().selector; // Create state reader with class hash data - let mut contract_class_cache = HashMap::new(); + let contract_class_cache = PermanentContractClassCache::default(); // Deployer contract data let deployer_address = Address(1111.into()); @@ -1696,13 +1701,13 @@ fn deploy_syscall_address_unavailable_test() { let deployee_address = expected_deployed_contract_address; insert_sierra_class_into_cache( - &mut contract_class_cache, + &contract_class_cache, deployer_class_hash, deployer_contract_class, ); insert_sierra_class_into_cache( - &mut contract_class_cache, + &contract_class_cache, deployee_class_hash, deployee_contract_class, ); @@ -1726,7 +1731,7 @@ fn deploy_syscall_address_unavailable_test() { .insert(deployee_address.clone(), deployee_nonce); // Create state from the state_reader and contract cache. - let mut state = CachedState::new(Arc::new(state_reader), contract_class_cache); + let mut state = CachedState::new(Arc::new(state_reader), Arc::new(contract_class_cache)); let calldata = [ Felt252::from_bytes_be(deployee_class_hash.to_bytes_be()), @@ -1771,7 +1776,7 @@ fn get_execution_info_test() { let selector = &entrypoints.external.get(0).unwrap().selector; // Create state reader with class hash data - let mut contract_class_cache = HashMap::new(); + let mut contract_class_cache = PermanentContractClassCache::default(); // Contract data let address = Address(1111.into()); @@ -1791,7 +1796,7 @@ fn get_execution_info_test() { .insert(address.clone(), nonce); // Create state from the state_reader and contract cache. - let mut state = CachedState::new(Arc::new(state_reader), contract_class_cache); + let mut state = CachedState::new(Arc::new(state_reader), Arc::new(contract_class_cache)); let calldata = [].to_vec(); diff --git a/tests/complex_contracts/amm_contracts/amm.rs b/tests/complex_contracts/amm_contracts/amm.rs index 0d7979c08..98653cd58 100644 --- a/tests/complex_contracts/amm_contracts/amm.rs +++ b/tests/complex_contracts/amm_contracts/amm.rs @@ -1,22 +1,27 @@ -use std::collections::{HashMap, HashSet}; -use std::sync::Arc; - -use cairo_vm::vm::runners::builtin_runner::HASH_BUILTIN_NAME; -use cairo_vm::vm::runners::cairo_runner::ExecutionResources; -use cairo_vm::{felt::Felt252, vm::runners::builtin_runner::RANGE_CHECK_BUILTIN_NAME}; +use crate::complex_contracts::utils::*; +use cairo_vm::{ + felt::Felt252, + vm::runners::{ + builtin_runner::{HASH_BUILTIN_NAME, RANGE_CHECK_BUILTIN_NAME}, + cairo_runner::ExecutionResources, + }, +}; use num_traits::Zero; -use starknet_in_rust::definitions::block_context::BlockContext; -use starknet_in_rust::EntryPointType; use starknet_in_rust::{ + definitions::block_context::BlockContext, execution::{CallInfo, CallType}, services::api::contract_classes::deprecated_contract_class::ContractClass, - state::{cached_state::CachedState, state_api::StateReader}, - state::{in_memory_state_reader::InMemoryStateReader, ExecutionResourcesManager}, + state::{ + cached_state::CachedState, contract_class_cache::PermanentContractClassCache, + in_memory_state_reader::InMemoryStateReader, state_api::StateReader, + ExecutionResourcesManager, + }, transaction::error::TransactionError, utils::{calculate_sn_keccak, Address}, + EntryPointType, }; - -use crate::complex_contracts::utils::*; +use std::collections::{HashMap, HashSet}; +use std::sync::Arc; fn init_pool( calldata: &[Felt252], @@ -54,7 +59,10 @@ fn swap(calldata: &[Felt252], call_config: &mut CallConfig) -> Result { - pub state: &'a mut CachedState, + pub state: &'a mut CachedState, pub caller_address: &'a Address, pub address: &'a Address, pub class_hash: &'a ClassHash, @@ -141,7 +144,7 @@ pub fn execute_entry_point( } pub fn deploy( - state: &mut CachedState, + state: &mut CachedState, path: &str, calldata: &[Felt252], block_context: &BlockContext, diff --git a/tests/delegate_call.rs b/tests/delegate_call.rs index f72feea4b..976a5af74 100644 --- a/tests/delegate_call.rs +++ b/tests/delegate_call.rs @@ -11,12 +11,15 @@ use starknet_in_rust::{ execution_entry_point::ExecutionEntryPoint, CallType, TransactionExecutionContext, }, services::api::contract_classes::deprecated_contract_class::ContractClass, - state::cached_state::CachedState, - state::{in_memory_state_reader::InMemoryStateReader, ExecutionResourcesManager}, + state::{ + cached_state::CachedState, + contract_class_cache::{ContractClassCache, PermanentContractClassCache}, + in_memory_state_reader::InMemoryStateReader, + ExecutionResourcesManager, + }, utils::Address, }; -use std::sync::Arc; -use std::{collections::HashMap, path::PathBuf}; +use std::{path::PathBuf, sync::Arc}; #[test] fn delegate_call() { @@ -24,7 +27,7 @@ fn delegate_call() { //* Create state reader with class hash data //* -------------------------------------------- - let mut contract_class_cache = HashMap::new(); + let contract_class_cache = PermanentContractClassCache::default(); let nonce = Felt252::zero(); // Add get_number.cairo contract to the state @@ -35,7 +38,7 @@ fn delegate_call() { let address = Address(Felt252::one()); // const CONTRACT_ADDRESS = 1; let class_hash = ClassHash([2; 32]); - contract_class_cache.insert( + contract_class_cache.set_contract_class( class_hash, CompiledClass::Deprecated(Arc::new(contract_class)), ); @@ -69,7 +72,7 @@ fn delegate_call() { let address = Address(1111.into()); let class_hash = ClassHash([1; 32]); - contract_class_cache.insert( + contract_class_cache.set_contract_class( class_hash, CompiledClass::Deprecated(Arc::new(contract_class)), ); @@ -84,7 +87,7 @@ fn delegate_call() { //* Create state with previous data //* --------------------------------------- - let mut state = CachedState::new(Arc::new(state_reader), contract_class_cache); + let mut state = CachedState::new(Arc::new(state_reader), Arc::new(contract_class_cache)); //* ------------------------------------ //* Create execution entry point diff --git a/tests/delegate_l1_handler.rs b/tests/delegate_l1_handler.rs index 7508944e1..e4176aa8f 100644 --- a/tests/delegate_l1_handler.rs +++ b/tests/delegate_l1_handler.rs @@ -12,20 +12,21 @@ use starknet_in_rust::{ }, services::api::contract_classes::deprecated_contract_class::ContractClass, state::{ - cached_state::CachedState, in_memory_state_reader::InMemoryStateReader, + cached_state::CachedState, + contract_class_cache::{ContractClassCache, PermanentContractClassCache}, + in_memory_state_reader::InMemoryStateReader, ExecutionResourcesManager, }, utils::Address, }; -use std::sync::Arc; -use std::{collections::HashMap, path::PathBuf}; +use std::{path::PathBuf, sync::Arc}; #[test] fn delegate_l1_handler() { //* -------------------------------------------- //* Create state reader with class hash data //* -------------------------------------------- - let mut contract_class_cache = HashMap::new(); + let contract_class_cache = PermanentContractClassCache::default(); let nonce = Felt252::zero(); // Add get_number.cairo contract to the state @@ -36,7 +37,7 @@ fn delegate_l1_handler() { let address = Address(Felt252::one()); // const CONTRACT_ADDRESS = 1; let class_hash: ClassHash = ClassHash([2; 32]); - contract_class_cache.insert( + contract_class_cache.set_contract_class( class_hash, CompiledClass::Deprecated(Arc::new(contract_class)), ); @@ -64,7 +65,7 @@ fn delegate_l1_handler() { let address = Address(1111.into()); let class_hash = ClassHash([1; 32]); - contract_class_cache.insert( + contract_class_cache.set_contract_class( class_hash, CompiledClass::Deprecated(Arc::new(contract_class)), ); @@ -79,7 +80,7 @@ fn delegate_l1_handler() { //* Create state with previous data //* --------------------------------------- - let mut state = CachedState::new(Arc::new(state_reader), contract_class_cache); + let mut state = CachedState::new(Arc::new(state_reader), Arc::new(contract_class_cache)); //* ------------------------------------ //* Create execution entry point diff --git a/tests/deploy_account.rs b/tests/deploy_account.rs index d283db043..d9bb6fed7 100644 --- a/tests/deploy_account.rs +++ b/tests/deploy_account.rs @@ -4,6 +4,7 @@ use cairo_vm::{ }; use lazy_static::lazy_static; use num_traits::Zero; +use starknet_in_rust::EntryPointType; use starknet_in_rust::{ core::contract_address::compute_deprecated_class_hash, definitions::{ @@ -13,20 +14,18 @@ use starknet_in_rust::{ }, execution::{CallInfo, CallType, TransactionExecutionInfo}, hash_utils::calculate_contract_address, - services::api::contract_classes::deprecated_contract_class::ContractClass, - state::in_memory_state_reader::InMemoryStateReader, - state::{cached_state::CachedState, state_api::State}, + services::api::contract_classes::{ + compiled_class::CompiledClass, deprecated_contract_class::ContractClass, + }, + state::{ + cached_state::CachedState, contract_class_cache::PermanentContractClassCache, + in_memory_state_reader::InMemoryStateReader, state_api::State, + }, transaction::DeployAccount, utils::{Address, ClassHash}, CasmContractClass, }; -use starknet_in_rust::{ - services::api::contract_classes::compiled_class::CompiledClass, EntryPointType, -}; -use std::{ - collections::{HashMap, HashSet}, - sync::Arc, -}; +use std::{collections::HashSet, sync::Arc}; lazy_static! { static ref TEST_ACCOUNT_COMPILED_CONTRACT_CLASS_HASH: Felt252 = felt_str!("1"); @@ -35,9 +34,10 @@ lazy_static! { #[test] fn internal_deploy_account() { let state_reader = Arc::new(InMemoryStateReader::default()); - let mut state = CachedState::new(state_reader, HashMap::new()); - - state.set_contract_classes(Default::default()).unwrap(); + let mut state = CachedState::new( + state_reader, + Arc::new(PermanentContractClassCache::default()), + ); let contract_class = ContractClass::from_path("starknet_programs/account_without_validation.json").unwrap(); @@ -139,9 +139,10 @@ fn internal_deploy_account() { #[test] fn internal_deploy_account_cairo1() { let state_reader = Arc::new(InMemoryStateReader::default()); - let mut state = CachedState::new(state_reader, HashMap::default()); - - state.set_contract_classes(Default::default()).unwrap(); + let mut state = CachedState::new( + state_reader, + Arc::new(PermanentContractClassCache::default()), + ); #[cfg(not(feature = "cairo_1_tests"))] let program_data = include_bytes!("../starknet_programs/cairo2/hello_world_account.casm"); diff --git a/tests/fibonacci.rs b/tests/fibonacci.rs index 56d428c1c..796a5d65d 100644 --- a/tests/fibonacci.rs +++ b/tests/fibonacci.rs @@ -2,24 +2,29 @@ #![deny(warnings)] use cairo_lang_starknet::casm_contract_class::CasmContractClass; -use cairo_vm::vm::runners::cairo_runner::ExecutionResources; -use cairo_vm::{felt::Felt252, vm::runners::builtin_runner::RANGE_CHECK_BUILTIN_NAME}; +use cairo_vm::{ + felt::Felt252, + vm::runners::{builtin_runner::RANGE_CHECK_BUILTIN_NAME, cairo_runner::ExecutionResources}, +}; use num_traits::Zero; -use starknet_in_rust::definitions::block_context::BlockContext; -use starknet_in_rust::services::api::contract_classes::compiled_class::CompiledClass; -use starknet_in_rust::EntryPointType; use starknet_in_rust::{ - definitions::constants::TRANSACTION_VERSION, + definitions::{block_context::BlockContext, constants::TRANSACTION_VERSION}, execution::{ execution_entry_point::ExecutionEntryPoint, CallInfo, CallType, TransactionExecutionContext, }, - services::api::contract_classes::deprecated_contract_class::ContractClass, - state::cached_state::CachedState, - state::{in_memory_state_reader::InMemoryStateReader, ExecutionResourcesManager}, + services::api::contract_classes::{ + compiled_class::CompiledClass, deprecated_contract_class::ContractClass, + }, + state::{ + cached_state::CachedState, + contract_class_cache::{ContractClassCache, PermanentContractClassCache}, + in_memory_state_reader::InMemoryStateReader, + ExecutionResourcesManager, + }, utils::{Address, ClassHash}, + EntryPointType, }; -use std::sync::Arc; -use std::{collections::HashMap, path::PathBuf}; +use std::{collections::HashMap, path::PathBuf, sync::Arc}; #[test] fn integration_test() { @@ -43,7 +48,7 @@ fn integration_test() { //* Create state reader with class hash data //* -------------------------------------------- - let mut contract_class_cache = HashMap::new(); + let contract_class_cache = PermanentContractClassCache::default(); // ------------ contract data -------------------- @@ -51,7 +56,7 @@ fn integration_test() { let class_hash: ClassHash = ClassHash([1; 32]); let nonce = Felt252::zero(); - contract_class_cache.insert( + contract_class_cache.set_contract_class( class_hash, CompiledClass::Deprecated(Arc::new(contract_class)), ); @@ -67,7 +72,7 @@ fn integration_test() { //* Create state with previous data //* --------------------------------------- - let mut state = CachedState::new(Arc::new(state_reader), contract_class_cache); + let mut state = CachedState::new(Arc::new(state_reader), Arc::new(contract_class_cache)); //* ------------------------------------ //* Create execution entry point @@ -151,13 +156,14 @@ fn integration_test_cairo1() { let fib_entrypoint_selector = &entrypoints.external.get(0).unwrap().selector; // Create state reader with class hash data - let mut contract_class_cache = HashMap::new(); + let contract_class_cache = PermanentContractClassCache::default(); let address = Address(1111.into()); let class_hash: ClassHash = ClassHash([1; 32]); let nonce = Felt252::zero(); - contract_class_cache.insert(class_hash, CompiledClass::Casm(Arc::new(contract_class))); + contract_class_cache + .set_contract_class(class_hash, CompiledClass::Casm(Arc::new(contract_class))); let mut state_reader = InMemoryStateReader::default(); state_reader .address_to_class_hash_mut() @@ -167,7 +173,7 @@ fn integration_test_cairo1() { .insert(address.clone(), nonce); // Create state from the state_reader and contract cache. - let mut state = CachedState::new(Arc::new(state_reader), contract_class_cache); + let mut state = CachedState::new(Arc::new(state_reader), Arc::new(contract_class_cache)); // Create an execution entry point let calldata = [0.into(), 1.into(), 12.into()].to_vec(); diff --git a/tests/increase_balance.rs b/tests/increase_balance.rs index 22235036a..df55bfaa5 100644 --- a/tests/increase_balance.rs +++ b/tests/increase_balance.rs @@ -1,7 +1,6 @@ #![deny(warnings)] -use cairo_vm::felt::Felt252; -use cairo_vm::vm::runners::cairo_runner::ExecutionResources; +use cairo_vm::{felt::Felt252, vm::runners::cairo_runner::ExecutionResources}; use num_traits::Zero; use starknet_in_rust::services::api::contract_classes::compiled_class::CompiledClass; use starknet_in_rust::utils::ClassHash; @@ -12,15 +11,16 @@ use starknet_in_rust::{ execution_entry_point::ExecutionEntryPoint, CallInfo, CallType, TransactionExecutionContext, }, services::api::contract_classes::deprecated_contract_class::ContractClass, - state::{cached_state::CachedState, state_cache::StorageEntry}, - state::{in_memory_state_reader::InMemoryStateReader, ExecutionResourcesManager}, + state::{ + cached_state::CachedState, + contract_class_cache::{ContractClassCache, PermanentContractClassCache}, + in_memory_state_reader::InMemoryStateReader, + state_cache::StorageEntry, + ExecutionResourcesManager, + }, utils::{calculate_sn_keccak, Address}, }; -use std::sync::Arc; -use std::{ - collections::{HashMap, HashSet}, - path::PathBuf, -}; +use std::{collections::HashSet, path::PathBuf, sync::Arc}; #[test] fn hello_starknet_increase_balance() { @@ -45,7 +45,7 @@ fn hello_starknet_increase_balance() { //* Create state reader with class hash data //* -------------------------------------------- - let mut contract_class_cache = HashMap::new(); + let contract_class_cache = PermanentContractClassCache::default(); // ------------ contract data -------------------- @@ -55,7 +55,7 @@ fn hello_starknet_increase_balance() { let storage_entry: StorageEntry = (address.clone(), [1; 32]); let storage = Felt252::zero(); - contract_class_cache.insert( + contract_class_cache.set_contract_class( class_hash, CompiledClass::Deprecated(Arc::new(contract_class)), ); @@ -74,7 +74,7 @@ fn hello_starknet_increase_balance() { //* Create state with previous data //* --------------------------------------- - let mut state = CachedState::new(Arc::new(state_reader), contract_class_cache); + let mut state = CachedState::new(Arc::new(state_reader), Arc::new(contract_class_cache)); //* ------------------------------------ //* Create execution entry point diff --git a/tests/internal_calls.rs b/tests/internal_calls.rs index d8c22d918..e0ba38b8a 100644 --- a/tests/internal_calls.rs +++ b/tests/internal_calls.rs @@ -1,22 +1,26 @@ #![deny(warnings)] -use std::sync::Arc; - use cairo_vm::felt::Felt252; use num_traits::Zero; -use starknet_in_rust::services::api::contract_classes::compiled_class::CompiledClass; -use starknet_in_rust::EntryPointType; use starknet_in_rust::{ definitions::{block_context::BlockContext, constants::TRANSACTION_VERSION}, execution::{ execution_entry_point::ExecutionEntryPoint, CallType, TransactionExecutionContext, }, - services::api::contract_classes::deprecated_contract_class::ContractClass, - state::{cached_state::CachedState, state_cache::StorageEntry}, - state::{in_memory_state_reader::InMemoryStateReader, ExecutionResourcesManager}, + services::api::contract_classes::{ + compiled_class::CompiledClass, deprecated_contract_class::ContractClass, + }, + state::{ + cached_state::CachedState, + contract_class_cache::{ContractClassCache, PermanentContractClassCache}, + in_memory_state_reader::InMemoryStateReader, + state_cache::StorageEntry, + ExecutionResourcesManager, + }, utils::{calculate_sn_keccak, Address, ClassHash}, + EntryPointType, }; -use std::collections::HashMap; +use std::sync::Arc; #[test] fn test_internal_calls() { @@ -49,10 +53,14 @@ fn test_internal_calls() { let mut state = CachedState::new( Arc::new(state_reader), - HashMap::from([( - ClassHash([1; 32]), - CompiledClass::Deprecated(Arc::new(contract_class)), - )]), + Arc::new({ + let cache = PermanentContractClassCache::default(); + cache.set_contract_class( + ClassHash([0x01; 32]), + CompiledClass::Deprecated(Arc::new(contract_class)), + ); + cache + }), ); let entry_point_selector = Felt252::from_bytes_be(&calculate_sn_keccak(b"a")); diff --git a/tests/internals.rs b/tests/internals.rs index 7ffe8dfae..8dcb22897 100644 --- a/tests/internals.rs +++ b/tests/internals.rs @@ -1,21 +1,21 @@ // This module tests our code against the blockifier to ensure they work in the same way. use assert_matches::assert_matches; use cairo_lang_starknet::contract_class::ContractClass as SierraContractClass; -use cairo_vm::felt::{felt_str, Felt252}; -use cairo_vm::vm::runners::builtin_runner::{HASH_BUILTIN_NAME, RANGE_CHECK_BUILTIN_NAME}; -use cairo_vm::vm::{ - errors::{ - cairo_run_errors::CairoRunError, vm_errors::VirtualMachineError, vm_exception::VmException, +use cairo_vm::{ + felt::{felt_str, Felt252}, + vm::runners::builtin_runner::{HASH_BUILTIN_NAME, RANGE_CHECK_BUILTIN_NAME}, + vm::{ + errors::{ + cairo_run_errors::CairoRunError, vm_errors::VirtualMachineError, + vm_exception::VmException, + }, + runners::cairo_runner::ExecutionResources, }, - runners::cairo_runner::ExecutionResources, }; use lazy_static::lazy_static; use num_bigint::BigUint; use num_traits::{Num, One, Zero}; use pretty_assertions_sorted::{assert_eq, assert_eq_sorted}; -use starknet_in_rust::core::contract_address::{ - compute_casm_class_hash, compute_sierra_class_hash, -}; use starknet_in_rust::core::errors::state_errors::StateError; use starknet_in_rust::definitions::constants::{ DEFAULT_CAIRO_RESOURCE_FEE_WEIGHTS, VALIDATE_ENTRY_POINT_SELECTOR, @@ -30,34 +30,37 @@ use starknet_in_rust::transaction::{DeclareV2, Deploy}; use starknet_in_rust::utils::CompiledClassHash; use starknet_in_rust::CasmContractClass; use starknet_in_rust::EntryPointType; +use starknet_in_rust::{ + core::contract_address::{compute_casm_class_hash, compute_sierra_class_hash}, + definitions::constants::{ + CONSTRUCTOR_ENTRY_POINT_SELECTOR, EXECUTE_ENTRY_POINT_SELECTOR, TRANSACTION_VERSION, + TRANSFER_ENTRY_POINT_SELECTOR, TRANSFER_EVENT_SELECTOR, + VALIDATE_DECLARE_ENTRY_POINT_SELECTOR, VALIDATE_DEPLOY_ENTRY_POINT_SELECTOR, + }, +}; use starknet_in_rust::{ definitions::{ block_context::{BlockContext, StarknetChainId, StarknetOsConfig}, - constants::{ - CONSTRUCTOR_ENTRY_POINT_SELECTOR, EXECUTE_ENTRY_POINT_SELECTOR, TRANSACTION_VERSION, - TRANSFER_ENTRY_POINT_SELECTOR, TRANSFER_EVENT_SELECTOR, - VALIDATE_DECLARE_ENTRY_POINT_SELECTOR, VALIDATE_DEPLOY_ENTRY_POINT_SELECTOR, - }, transaction_type::TransactionType, }, execution::{CallInfo, CallType, OrderedEvent, TransactionExecutionInfo}, - state::in_memory_state_reader::InMemoryStateReader, state::{ cached_state::CachedState, + contract_class_cache::{ContractClassCache, PermanentContractClassCache}, + in_memory_state_reader::InMemoryStateReader, state_api::{State, StateReader}, - state_cache::StateCache, - state_cache::StorageEntry, + state_cache::{StateCache, StorageEntry}, BlockInfo, }, transaction::{ - error::TransactionError, - DeployAccount, - {invoke_function::InvokeFunction, Declare}, + error::TransactionError, invoke_function::InvokeFunction, Declare, DeployAccount, }, utils::{calculate_sn_keccak, felt_to_hash, Address, ClassHash}, }; -use std::collections::{HashMap, HashSet}; -use std::sync::Arc; +use std::{ + collections::{HashMap, HashSet}, + sync::Arc, +}; const ACCOUNT_CONTRACT_PATH: &str = "starknet_programs/account_without_validation.json"; const ERC20_CONTRACT_PATH: &str = "starknet_programs/ERC20.json"; @@ -120,8 +123,13 @@ pub fn new_starknet_block_context_for_testing() -> BlockContext { ) } -fn create_account_tx_test_state( -) -> Result<(BlockContext, CachedState), Box> { +fn create_account_tx_test_state() -> Result< + ( + BlockContext, + CachedState, + ), + Box, +> { let block_context = new_starknet_block_context_for_testing(); let test_contract_class_hash = *TEST_CLASS_HASH; @@ -195,14 +203,19 @@ fn create_account_tx_test_state( } Arc::new(state_reader) }, - HashMap::new(), + Arc::new(PermanentContractClassCache::default()), ); Ok((block_context, cached_state)) } -fn create_account_tx_test_state_revert_test( -) -> Result<(BlockContext, CachedState), Box> { +fn create_account_tx_test_state_revert_test() -> Result< + ( + BlockContext, + CachedState, + ), + Box, +> { let block_context = new_starknet_block_context_for_testing(); let test_contract_class_hash = *TEST_CLASS_HASH; @@ -280,46 +293,50 @@ fn create_account_tx_test_state_revert_test( } Arc::new(state_reader) }, - HashMap::new(), + Arc::new(PermanentContractClassCache::default()), ); Ok((block_context, cached_state)) } -fn expected_state_before_tx() -> CachedState { +fn expected_state_before_tx() -> CachedState { let in_memory_state_reader = initial_in_memory_state_reader(); - CachedState::new(Arc::new(in_memory_state_reader), HashMap::new()) + CachedState::new( + Arc::new(in_memory_state_reader), + Arc::new(PermanentContractClassCache::default()), + ) } -fn expected_state_after_tx(fee: u128) -> CachedState { +fn expected_state_after_tx( + fee: u128, +) -> CachedState { let in_memory_state_reader = initial_in_memory_state_reader(); - let contract_classes_cache = HashMap::from([ - ( - *TEST_CLASS_HASH, - CompiledClass::Deprecated(Arc::new( - ContractClass::from_path(TEST_CONTRACT_PATH).unwrap(), - )), - ), - ( - *TEST_ACCOUNT_CONTRACT_CLASS_HASH, - CompiledClass::Deprecated(Arc::new( - ContractClass::from_path(ACCOUNT_CONTRACT_PATH).unwrap(), - )), - ), - ( - *TEST_ERC20_CONTRACT_CLASS_HASH, - CompiledClass::Deprecated(Arc::new( - ContractClass::from_path(ERC20_CONTRACT_PATH).unwrap(), - )), - ), - ]); + let contract_classes_cache = PermanentContractClassCache::default(); + contract_classes_cache.set_contract_class( + *TEST_CLASS_HASH, + CompiledClass::Deprecated(Arc::new( + ContractClass::from_path(TEST_CONTRACT_PATH).unwrap(), + )), + ); + contract_classes_cache.set_contract_class( + *TEST_ACCOUNT_CONTRACT_CLASS_HASH, + CompiledClass::Deprecated(Arc::new( + ContractClass::from_path(ACCOUNT_CONTRACT_PATH).unwrap(), + )), + ); + contract_classes_cache.set_contract_class( + *TEST_ERC20_CONTRACT_CLASS_HASH, + CompiledClass::Deprecated(Arc::new( + ContractClass::from_path(ERC20_CONTRACT_PATH).unwrap(), + )), + ); CachedState::new_for_testing( Arc::new(in_memory_state_reader), state_cache_after_invoke_tx(fee), - contract_classes_cache, + Arc::new(contract_classes_cache), ) } @@ -601,8 +618,12 @@ fn test_create_account_tx_test_state() { let expected_initial_state = expected_state_before_tx(); assert_eq!(&state.cache(), &expected_initial_state.cache()); assert_eq!( - &state.contract_classes(), - &expected_initial_state.contract_classes() + (&*state.contract_class_cache().clone()) + .into_iter() + .collect::>(), + (&*expected_initial_state.contract_class_cache().clone()) + .into_iter() + .collect::>() ); assert_eq!( &state.state_reader.address_to_class_hash, @@ -973,8 +994,12 @@ fn test_declare_tx() { let expected_initial_state = expected_state_before_tx(); assert_eq!(&state.cache(), &expected_initial_state.cache()); assert_eq!( - &state.contract_classes(), - &expected_initial_state.contract_classes() + (&*state.contract_class_cache().clone()) + .into_iter() + .collect::>(), + (&*expected_initial_state.contract_class_cache().clone()) + .into_iter() + .collect::>() ); assert_eq!( &state.state_reader.address_to_class_hash, @@ -1066,8 +1091,12 @@ fn test_declarev2_tx() { let expected_initial_state = expected_state_before_tx(); assert_eq!(&state.cache(), &expected_initial_state.cache()); assert_eq!( - &state.contract_classes(), - &expected_initial_state.contract_classes() + (&*state.contract_class_cache().clone()) + .into_iter() + .collect::>(), + (&*expected_initial_state.contract_class_cache().clone()) + .into_iter() + .collect::>() ); assert_eq!( &state.state_reader.address_to_class_hash, @@ -1477,8 +1506,12 @@ fn test_invoke_tx_state() { let expected_initial_state = expected_state_before_tx(); assert_eq!(&state.cache(), &expected_initial_state.cache()); assert_eq!( - &state.contract_classes(), - &expected_initial_state.contract_classes() + (&*state.contract_class_cache().clone()) + .into_iter() + .collect::>(), + (&*expected_initial_state.contract_class_cache().clone()) + .into_iter() + .collect::>() ); assert_eq!( &state.state_reader.address_to_class_hash, @@ -1556,8 +1589,12 @@ fn test_invoke_with_declarev2_tx() { let expected_initial_state = expected_state_before_tx(); assert_eq!(&state.cache(), &expected_initial_state.cache()); assert_eq!( - &state.contract_classes(), - &expected_initial_state.contract_classes() + (&*state.contract_class_cache().clone()) + .into_iter() + .collect::>(), + (&*expected_initial_state.contract_class_cache().clone()) + .into_iter() + .collect::>() ); assert_eq!( &state.state_reader.address_to_class_hash, @@ -1675,7 +1712,14 @@ fn test_deploy_account() { let (state_before, state_after) = expected_deploy_account_states(); assert_eq!(&state.cache(), &state_before.cache()); - assert_eq!(&state.contract_classes(), &state_before.contract_classes()); + assert_eq!( + (&*state.contract_class_cache().clone()) + .into_iter() + .collect::>(), + (&*state_before.contract_class_cache().clone()) + .into_iter() + .collect::>() + ); let tx_info = deploy_account_tx .execute( @@ -1793,8 +1837,15 @@ fn test_deploy_account_revert() { let (state_before, mut state_after) = expected_deploy_account_states(); assert_eq_sorted!(&state.cache(), &state_before.cache()); - assert_eq_sorted!(&state.contract_classes(), &state_before.contract_classes()); - assert!(&state.contract_classes().is_empty()); + assert_eq!( + (&*state.contract_class_cache().clone()) + .into_iter() + .collect::>(), + (&*state_before.contract_class_cache().clone()) + .into_iter() + .collect::>() + ); + assert_eq!(state.contract_class_cache().as_ref().into_iter().count(), 0); let tx_info = deploy_account_tx .execute( @@ -1833,11 +1884,6 @@ fn test_deploy_account_revert() { .storage_initial_values_mut() .extend(state_after.cache_mut().storage_initial_values_mut().clone()); - // Set contract class cache - state_reverted - .set_contract_classes(state_after.contract_classes().clone()) - .unwrap(); - // Set storage writes related to the fee transfer state_reverted .cache_mut() @@ -1918,8 +1964,8 @@ fn test_deploy_account_revert() { } fn expected_deploy_account_states() -> ( - CachedState, - CachedState, + CachedState, + CachedState, ) { let fee = Felt252::from(3097); let mut state_before = CachedState::new( @@ -1962,7 +2008,7 @@ fn expected_deploy_account_states() -> ( ]), HashMap::new(), )), - HashMap::new(), + Arc::new(PermanentContractClassCache::default()), ); state_before.set_storage_at( &( @@ -1974,7 +2020,12 @@ fn expected_deploy_account_states() -> ( INITIAL_BALANCE.clone(), ); - let mut state_after = state_before.clone(); + let mut state_after = state_before.clone_for_testing(); + + // Make the contract cache independent (otherwise tests will fail because the initial state's + // cache will not be empty anymore). + *state_after.contract_class_cache_mut() = Arc::new(PermanentContractClassCache::default()); + state_after.cache_mut().nonce_initial_values_mut().insert( Address(felt_str!( "386181506763903095743576862849245034886954647214831045800703908858571591162" @@ -2445,6 +2496,7 @@ fn test_library_call_with_declare_v2() { { casm_contract_hash = *TEST_FIB_COMPILED_CONTRACT_CLASS_HASH_CAIRO1 } + // Create an execution entry point let calldata = vec![ Felt252::from_bytes_be(casm_contract_hash.to_bytes_be()), diff --git a/tests/multi_syscall_test.rs b/tests/multi_syscall_test.rs index bcea1728a..afafdfc9a 100644 --- a/tests/multi_syscall_test.rs +++ b/tests/multi_syscall_test.rs @@ -1,20 +1,23 @@ use cairo_lang_starknet::casm_contract_class::CasmContractClass; use cairo_vm::felt::Felt252; use num_traits::{Num, Zero}; -use starknet_in_rust::services::api::contract_classes::compiled_class::CompiledClass; -use starknet_in_rust::utils::calculate_sn_keccak; -use starknet_in_rust::EntryPointType; use starknet_in_rust::{ definitions::{block_context::BlockContext, constants::TRANSACTION_VERSION}, execution::{ execution_entry_point::ExecutionEntryPoint, CallInfo, CallType, OrderedEvent, OrderedL2ToL1Message, TransactionExecutionContext, }, - state::cached_state::CachedState, - state::{in_memory_state_reader::InMemoryStateReader, ExecutionResourcesManager}, - utils::{Address, ClassHash}, + services::api::contract_classes::compiled_class::CompiledClass, + state::{ + cached_state::CachedState, + contract_class_cache::{ContractClassCache, PermanentContractClassCache}, + in_memory_state_reader::InMemoryStateReader, + ExecutionResourcesManager, + }, + utils::{calculate_sn_keccak, Address, ClassHash}, + EntryPointType, }; -use std::{collections::HashMap, sync::Arc, vec}; +use std::{sync::Arc, vec}; #[test] fn test_multiple_syscall() { @@ -23,13 +26,14 @@ fn test_multiple_syscall() { let contract_class: CasmContractClass = serde_json::from_slice(program_data).unwrap(); // Create state reader with class hash data - let mut contract_class_cache: HashMap = HashMap::new(); + let contract_class_cache = PermanentContractClassCache::default(); let address = Address(1111.into()); let class_hash: ClassHash = ClassHash([1; 32]); let nonce = Felt252::zero(); - contract_class_cache.insert(class_hash, CompiledClass::Casm(Arc::new(contract_class))); + contract_class_cache + .set_contract_class(class_hash, CompiledClass::Casm(Arc::new(contract_class))); let mut state_reader = InMemoryStateReader::default(); state_reader .address_to_class_hash_mut() @@ -39,7 +43,7 @@ fn test_multiple_syscall() { .insert(address.clone(), nonce); // Create state from the state_reader and contract cache. - let mut state = CachedState::new(Arc::new(state_reader), contract_class_cache.clone()); + let mut state = CachedState::new(Arc::new(state_reader), Arc::new(contract_class_cache)); // Create an execution entry point let calldata = [].to_vec(); @@ -60,7 +64,7 @@ fn test_multiple_syscall() { assert_eq!(call_info.retdata, vec![caller_address.clone().0]) } - // Block for get_contact_address. + // Block for get_contract_address. { let call_info = test_syscall( "contract_address", @@ -217,7 +221,7 @@ fn test_syscall( caller_address: Address, entry_point_type: EntryPointType, class_hash: ClassHash, - state: &mut CachedState, + state: &mut CachedState, ) -> CallInfo { let entrypoint_selector = Felt252::from_bytes_be(&calculate_sn_keccak(entrypoint_selector.as_bytes())); diff --git a/tests/storage.rs b/tests/storage.rs index 186919256..f6940fa8e 100644 --- a/tests/storage.rs +++ b/tests/storage.rs @@ -10,15 +10,15 @@ use starknet_in_rust::{ execution_entry_point::ExecutionEntryPoint, CallInfo, CallType, TransactionExecutionContext, }, services::api::contract_classes::deprecated_contract_class::ContractClass, - state::cached_state::CachedState, - state::{in_memory_state_reader::InMemoryStateReader, ExecutionResourcesManager}, + state::{ + cached_state::CachedState, + contract_class_cache::{ContractClassCache, PermanentContractClassCache}, + in_memory_state_reader::InMemoryStateReader, + ExecutionResourcesManager, + }, utils::{calculate_sn_keccak, Address}, }; -use std::sync::Arc; -use std::{ - collections::{HashMap, HashSet}, - path::PathBuf, -}; +use std::{collections::HashSet, path::PathBuf, sync::Arc}; #[test] fn integration_storage_test() { @@ -42,7 +42,7 @@ fn integration_storage_test() { //* Create state reader with class hash data //* -------------------------------------------- - let mut contract_class_cache = HashMap::new(); + let contract_class_cache = PermanentContractClassCache::default(); // ------------ contract data -------------------- @@ -52,7 +52,7 @@ fn integration_storage_test() { let storage_entry = (address.clone(), [90; 32]); let storage_value = Felt252::new(10902); - contract_class_cache.insert( + contract_class_cache.set_contract_class( class_hash, CompiledClass::Deprecated(Arc::new(contract_class)), ); @@ -71,7 +71,7 @@ fn integration_storage_test() { //* Create state with previous data //* --------------------------------------- - let mut state = CachedState::new(Arc::new(state_reader), contract_class_cache); + let mut state = CachedState::new(Arc::new(state_reader), Arc::new(contract_class_cache)); //* ------------------------------------ //* Create execution entry point diff --git a/tests/syscalls.rs b/tests/syscalls.rs index 704ab757d..d8b9ec685 100644 --- a/tests/syscalls.rs +++ b/tests/syscalls.rs @@ -20,16 +20,18 @@ use starknet_in_rust::{ execution_entry_point::ExecutionEntryPoint, CallInfo, CallType, L2toL1MessageInfo, OrderedEvent, OrderedL2ToL1Message, TransactionExecutionContext, }, - services::api::contract_classes::deprecated_contract_class::ContractClass, + services::api::contract_classes::{ + compiled_class::CompiledClass, deprecated_contract_class::ContractClass, + }, state::{ cached_state::CachedState, - state_api::{State, StateReader}, + contract_class_cache::{ContractClassCache, PermanentContractClassCache}, + in_memory_state_reader::InMemoryStateReader, + state_api::State, + ExecutionResourcesManager, }, - state::{in_memory_state_reader::InMemoryStateReader, ExecutionResourcesManager}, utils::{calculate_sn_keccak, felt_to_hash, Address, ClassHash}, -}; -use starknet_in_rust::{ - services::api::contract_classes::compiled_class::CompiledClass, EntryPointType, + EntryPointType, }; use std::{ collections::{HashMap, HashSet}, @@ -93,13 +95,13 @@ fn test_contract<'a>( let mut storage_entries = Vec::new(); let contract_class_cache = { - let mut contract_class_cache = HashMap::new(); + let contract_class_cache = PermanentContractClassCache::default(); for (class_hash, contract_path, contract_address) in extra_contracts { let contract_class = ContractClass::from_path(contract_path) .expect("Could not load extra contract from JSON"); - contract_class_cache.insert( + contract_class_cache.set_contract_class( class_hash, CompiledClass::Deprecated(Arc::new(contract_class.clone())), ); @@ -122,7 +124,7 @@ fn test_contract<'a>( contract_class_cache }; - let mut state = CachedState::new(Arc::new(state_reader), contract_class_cache); + let mut state = CachedState::new(Arc::new(state_reader), Arc::new(contract_class_cache)); storage_entries .into_iter() .for_each(|(a, b, c)| state.set_storage_at(&(a, b), c)); @@ -1144,18 +1146,18 @@ fn deploy_cairo1_from_cairo0_with_constructor() { let test_contract_class: CasmContractClass = serde_json::from_slice(program_data).unwrap(); // Create state reader with class hash data - let mut contract_class_cache = HashMap::new(); + let contract_class_cache = PermanentContractClassCache::default(); let address = Address(1111.into()); let class_hash: ClassHash = ClassHash([1; 32]); let nonce = Felt252::zero(); // simulate contract declare - contract_class_cache.insert( + contract_class_cache.set_contract_class( test_class_hash, CompiledClass::Casm(Arc::new(test_contract_class.clone())), ); - contract_class_cache.insert( + contract_class_cache.set_contract_class( class_hash, CompiledClass::Deprecated(Arc::new(contract_class)), ); @@ -1169,7 +1171,7 @@ fn deploy_cairo1_from_cairo0_with_constructor() { .insert(address.clone(), nonce); // Create state from the state_reader and contract cache. - let mut state = CachedState::new(Arc::new(state_reader), contract_class_cache); + let mut state = CachedState::new(Arc::new(state_reader), Arc::new(contract_class_cache)); // arguments of deploy contract let calldata: Vec<_> = [test_felt_hash, salt, Felt252::one()].to_vec(); @@ -1249,18 +1251,18 @@ fn deploy_cairo1_from_cairo0_without_constructor() { let test_contract_class: CasmContractClass = serde_json::from_slice(program_data).unwrap(); // Create state reader with class hash data - let mut contract_class_cache = HashMap::new(); + let contract_class_cache = PermanentContractClassCache::default(); let address = Address(1111.into()); let class_hash: ClassHash = ClassHash([1; 32]); let nonce = Felt252::zero(); // simulate contract declare - contract_class_cache.insert( + contract_class_cache.set_contract_class( test_class_hash, CompiledClass::Casm(Arc::new(test_contract_class.clone())), ); - contract_class_cache.insert( + contract_class_cache.set_contract_class( class_hash, CompiledClass::Deprecated(Arc::new(contract_class)), ); @@ -1274,7 +1276,7 @@ fn deploy_cairo1_from_cairo0_without_constructor() { .insert(address.clone(), nonce); // Create state from the state_reader and contract cache. - let mut state = CachedState::new(Arc::new(state_reader), contract_class_cache); + let mut state = CachedState::new(Arc::new(state_reader), Arc::new(contract_class_cache)); // arguments of deploy contract let calldata: Vec<_> = [test_felt_hash, salt].to_vec(); @@ -1356,18 +1358,18 @@ fn deploy_cairo1_and_invoke() { let test_contract_class: CasmContractClass = serde_json::from_slice(program_data).unwrap(); // Create state reader with class hash data - let mut contract_class_cache = HashMap::new(); + let contract_class_cache = PermanentContractClassCache::default(); let address = Address(1111.into()); let class_hash: ClassHash = ClassHash([1; 32]); let nonce = Felt252::zero(); // simulate contract declare - contract_class_cache.insert( + contract_class_cache.set_contract_class( test_class_hash, CompiledClass::Casm(Arc::new(test_contract_class.clone())), ); - contract_class_cache.insert( + contract_class_cache.set_contract_class( class_hash, CompiledClass::Deprecated(Arc::new(contract_class)), ); @@ -1381,7 +1383,7 @@ fn deploy_cairo1_and_invoke() { .insert(address.clone(), nonce); // Create state from the state_reader and contract cache. - let mut state = CachedState::new(Arc::new(state_reader), contract_class_cache); + let mut state = CachedState::new(Arc::new(state_reader), Arc::new(contract_class_cache)); // arguments of deploy contract let calldata: Vec<_> = [test_felt_hash, salt].to_vec(); @@ -1487,13 +1489,13 @@ fn send_messages_to_l1_different_contract_calls() { .to_owned(); // Create state reader with class hash data - let mut contract_class_cache = HashMap::new(); + let contract_class_cache = PermanentContractClassCache::default(); let address = Address(1111.into()); let class_hash: ClassHash = ClassHash([1; 32]); let nonce = Felt252::zero(); - contract_class_cache.insert( + contract_class_cache.set_contract_class( class_hash, CompiledClass::Deprecated(Arc::new(contract_class)), ); @@ -1514,7 +1516,7 @@ fn send_messages_to_l1_different_contract_calls() { let send_msg_class_hash: ClassHash = ClassHash([2; 32]); let send_msg_nonce = Felt252::zero(); - contract_class_cache.insert( + contract_class_cache.set_contract_class( send_msg_class_hash, CompiledClass::Deprecated(Arc::new(send_msg_contract_class)), ); @@ -1526,7 +1528,7 @@ fn send_messages_to_l1_different_contract_calls() { .insert(send_msg_address, send_msg_nonce); // Create state from the state_reader and contract cache. - let mut state = CachedState::new(Arc::new(state_reader), contract_class_cache); + let mut state = CachedState::new(Arc::new(state_reader), Arc::new(contract_class_cache)); // Create an execution entry point let calldata = [25.into(), 50.into(), 75.into()].to_vec(); diff --git a/tests/syscalls_errors.rs b/tests/syscalls_errors.rs index 5f6937fe5..9c2175ea8 100644 --- a/tests/syscalls_errors.rs +++ b/tests/syscalls_errors.rs @@ -1,25 +1,27 @@ #![deny(warnings)] +use assert_matches::assert_matches; use cairo_vm::felt::Felt252; -use starknet_in_rust::utils::felt_to_hash; -use starknet_in_rust::EntryPointType; use starknet_in_rust::{ core::errors::state_errors::StateError, definitions::{block_context::BlockContext, constants::TRANSACTION_VERSION}, execution::{ execution_entry_point::ExecutionEntryPoint, CallType, TransactionExecutionContext, }, - services::api::contract_classes::deprecated_contract_class::ContractClass, - state::{cached_state::CachedState, state_api::State}, - state::{in_memory_state_reader::InMemoryStateReader, ExecutionResourcesManager}, - utils::{calculate_sn_keccak, Address, ClassHash}, + services::api::contract_classes::{ + compiled_class::CompiledClass, deprecated_contract_class::ContractClass, + }, + state::{ + cached_state::CachedState, + contract_class_cache::{ContractClassCache, PermanentContractClassCache}, + in_memory_state_reader::InMemoryStateReader, + state_api::State, + ExecutionResourcesManager, + }, + utils::{calculate_sn_keccak, felt_to_hash, Address, ClassHash}, + EntryPointType, }; -use std::path::Path; -use std::sync::Arc; - -use assert_matches::assert_matches; -use starknet_in_rust::services::api::contract_classes::compiled_class::CompiledClass; -use std::collections::HashMap; +use std::{path::Path, sync::Arc}; #[allow(clippy::too_many_arguments)] fn test_contract<'a>( @@ -69,13 +71,13 @@ fn test_contract<'a>( let mut storage_entries = Vec::new(); let contract_class_cache = { - let mut contract_class_cache = HashMap::new(); + let contract_class_cache = PermanentContractClassCache::default(); for (class_hash, contract_path, contract_address) in extra_contracts { let contract_class = ContractClass::from_path(contract_path) .expect("Could not load extra contract from JSON"); - contract_class_cache.insert( + contract_class_cache.set_contract_class( class_hash, CompiledClass::Deprecated(Arc::new(contract_class.clone())), ); @@ -101,7 +103,7 @@ fn test_contract<'a>( contract_class_cache }; - let mut state = CachedState::new(Arc::new(state_reader), contract_class_cache); + let mut state = CachedState::new(Arc::new(state_reader), Arc::new(contract_class_cache)); storage_entries .into_iter() .for_each(|(a, b, c)| state.set_storage_at(&(a, b), c));