From f8080891a6b737fbb2a2004822d582342268ff31 Mon Sep 17 00:00:00 2001 From: fannyguthmann Date: Thu, 17 Aug 2023 13:32:25 +0300 Subject: [PATCH] change Box to Arc for CompiledClass --- src/execution/execution_entry_point.rs | 6 ++- .../api/contract_classes/compiled_class.rs | 13 ++++--- src/state/cached_state.rs | 38 +++++++++++++++---- src/state/in_memory_state_reader.rs | 8 ++-- src/state/state_cache.rs | 4 +- src/transaction/declare_v2.rs | 8 ++-- src/transaction/deploy.rs | 10 +++-- tests/cairo_1_syscalls.rs | 10 ++--- tests/internals.rs | 2 +- tests/syscalls.rs | 6 +-- 10 files changed, 68 insertions(+), 37 deletions(-) diff --git a/src/execution/execution_entry_point.rs b/src/execution/execution_entry_point.rs index b7116d1f1..554463471 100644 --- a/src/execution/execution_entry_point.rs +++ b/src/execution/execution_entry_point.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use crate::services::api::contract_classes::deprecated_contract_class::{ ContractEntryPoint, EntryPointType, }; @@ -331,7 +333,7 @@ impl ExecutionEntryPoint { resources_manager: &mut ExecutionResourcesManager, block_context: &BlockContext, tx_execution_context: &mut TransactionExecutionContext, - contract_class: Box, + contract_class: Arc, class_hash: [u8; 32], ) -> Result { let previous_cairo_usage = resources_manager.cairo_usage.clone(); @@ -436,7 +438,7 @@ impl ExecutionEntryPoint { resources_manager: &mut ExecutionResourcesManager, block_context: &BlockContext, tx_execution_context: &mut TransactionExecutionContext, - contract_class: Box, + contract_class: Arc, class_hash: [u8; 32], support_reverted: bool, ) -> Result { diff --git a/src/services/api/contract_classes/compiled_class.rs b/src/services/api/contract_classes/compiled_class.rs index bc58c091e..035a96fbb 100644 --- a/src/services/api/contract_classes/compiled_class.rs +++ b/src/services/api/contract_classes/compiled_class.rs @@ -1,5 +1,6 @@ use std::collections::HashMap; use std::io::{self, Read}; +use std::sync::Arc; use crate::core::contract_address::{compute_hinted_class_hash, CairoProgramToHash}; use crate::services::api::contract_class_errors::ContractClassError; @@ -21,15 +22,15 @@ use starknet::core::types::ContractClass::{Legacy, Sierra}; #[derive(Clone, PartialEq, Eq, Debug)] pub enum CompiledClass { - Deprecated(Box), - Casm(Box), + Deprecated(Arc), + Casm(Arc), } impl TryInto for CompiledClass { type Error = ContractClassError; fn try_into(self) -> Result { match self { - CompiledClass::Casm(boxed) => Ok(*boxed), + CompiledClass::Casm(arc) => Ok((*arc).clone()), _ => Err(ContractClassError::NotACasmContractClass), } } @@ -39,7 +40,7 @@ impl TryInto for CompiledClass { type Error = ContractClassError; fn try_into(self) -> Result { match self { - CompiledClass::Deprecated(boxed) => Ok(*boxed), + CompiledClass::Deprecated(arc) => Ok((*arc).clone()), _ => Err(ContractClassError::NotADeprecatedContractClass), } } @@ -77,7 +78,7 @@ impl From for CompiledClass { let casm_cc = CasmContractClass::from_contract_class(sierra_cc, true).unwrap(); - CompiledClass::Casm(Box::new(casm_cc)) + CompiledClass::Casm(Arc::new(casm_cc)) } Legacy(_deprecated_contract_class) => { let as_str = decode_reader(_deprecated_contract_class.program).unwrap(); @@ -148,7 +149,7 @@ impl From for CompiledClass { let v = serde_json::to_value(serialized_cc).unwrap(); let hinted_class_hash = compute_hinted_class_hash(&v).unwrap(); - CompiledClass::Deprecated(Box::new(ContractClass { + CompiledClass::Deprecated(Arc::new(ContractClass { program, entry_points_by_type, abi, diff --git a/src/state/cached_state.rs b/src/state/cached_state.rs index 5366e983d..8b2298de8 100644 --- a/src/state/cached_state.rs +++ b/src/state/cached_state.rs @@ -25,6 +25,7 @@ pub type CasmClassCache = HashMap; pub const UNINITIALIZED_CLASS_HASH: &ClassHash = b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"; +/// Represents a cached state of contract classes with optional caches. #[derive(Default, Clone, Debug, Eq, Getters, MutGetters, PartialEq)] pub struct CachedState { pub state_reader: Arc, @@ -37,6 +38,7 @@ pub struct CachedState { } impl CachedState { + /// Constructor, creates a new cached state. pub fn new( state_reader: Arc, contract_class_cache: Option, @@ -50,6 +52,7 @@ impl CachedState { } } + /// Creates a CachedState for testing purposes. pub fn new_for_testing( state_reader: Arc, contract_classes: Option, @@ -64,6 +67,7 @@ impl CachedState { } } + /// Sets the contract classes cache. pub fn set_contract_classes( &mut self, contract_classes: ContractClassCache, @@ -75,6 +79,7 @@ impl CachedState { Ok(()) } + /// Returns the casm classes. #[allow(dead_code)] pub(crate) fn get_casm_classes(&mut self) -> Result<&CasmClassCache, StateError> { self.casm_contract_classes @@ -84,6 +89,7 @@ impl CachedState { } impl StateReader for CachedState { + /// Returns the class hash for a given contract address. fn get_class_hash_at(&self, contract_address: &Address) -> Result { if self.cache.get_class_hash(contract_address).is_none() { match self.state_reader.get_class_hash_at(contract_address) { @@ -105,6 +111,7 @@ impl StateReader for CachedState { .cloned() } + /// Returns the nonce for a given contract address. fn get_nonce_at(&self, contract_address: &Address) -> Result { if self.cache.get_nonce(contract_address).is_none() { return self.state_reader.get_nonce_at(contract_address); @@ -115,6 +122,7 @@ impl StateReader for CachedState { .cloned() } + /// Returns storage data for a given storage entry. fn get_storage_at(&self, storage_entry: &StorageEntry) -> Result { if self.cache.get_storage(storage_entry).is_none() { match self.state_reader.get_storage_at(storage_entry) { @@ -140,6 +148,7 @@ impl StateReader for CachedState { } // TODO: check if that the proper way to store it (converting hash to address) + /// Returned the compiled class hash for a given class hash. fn get_compiled_class_hash(&self, class_hash: &ClassHash) -> Result { if self .cache @@ -156,6 +165,7 @@ impl StateReader for CachedState { .cloned() } + /// Returns the contract class for a given class hash. fn get_contract_class(&self, class_hash: &ClassHash) -> Result { // This method can receive both compiled_class_hash & class_hash and return both casm and deprecated contract classes //, which can be on the cache or on the state_reader, different cases will be described below: @@ -170,7 +180,7 @@ impl StateReader for CachedState { .as_ref() .and_then(|x| x.get(class_hash)) { - return Ok(CompiledClass::Deprecated(Box::new(compiled_class.clone()))); + return Ok(CompiledClass::Deprecated(Arc::new(compiled_class.clone()))); } // I: CASM CONTRACT CLASS : COMPILED_CLASS_HASH if let Some(compiled_class) = self @@ -178,7 +188,7 @@ impl StateReader for CachedState { .as_ref() .and_then(|x| x.get(class_hash)) { - return Ok(CompiledClass::Casm(Box::new(compiled_class.clone()))); + return Ok(CompiledClass::Casm(Arc::new(compiled_class.clone()))); } // I: CASM CONTRACT CLASS : CLASS_HASH if let Some(compiled_class_hash) = @@ -189,7 +199,7 @@ impl StateReader for CachedState { .as_ref() .and_then(|m| m.get(compiled_class_hash)) { - return Ok(CompiledClass::Casm(Box::new(casm_class.clone()))); + return Ok(CompiledClass::Casm(Arc::new(casm_class.clone()))); } } // II: FETCHING FROM STATE_READER @@ -198,6 +208,7 @@ impl StateReader for CachedState { } impl State for CachedState { + /// Stores a contract class in the cache. fn set_contract_class( &mut self, class_hash: &ClassHash, @@ -215,6 +226,7 @@ impl State for CachedState { Ok(()) } + /// Deploys a new contract and updates the cache. fn deploy_contract( &mut self, deploy_contract_address: Address, @@ -433,7 +445,7 @@ impl State for CachedState { .as_ref() .and_then(|x| x.get(class_hash)) { - return Ok(CompiledClass::Deprecated(Box::new(compiled_class.clone()))); + return Ok(CompiledClass::Deprecated(Arc::new(compiled_class.clone()))); } // I: CASM CONTRACT CLASS : COMPILED_CLASS_HASH if let Some(compiled_class) = self @@ -441,7 +453,7 @@ impl State for CachedState { .as_ref() .and_then(|x| x.get(class_hash)) { - return Ok(CompiledClass::Casm(Box::new(compiled_class.clone()))); + return Ok(CompiledClass::Casm(Arc::new(compiled_class.clone()))); } // I: CASM CONTRACT CLASS : CLASS_HASH if let Some(compiled_class_hash) = @@ -452,7 +464,7 @@ impl State for CachedState { .as_ref() .and_then(|m| m.get(compiled_class_hash)) { - return Ok(CompiledClass::Casm(Box::new(casm_class.clone()))); + return Ok(CompiledClass::Casm(Arc::new(casm_class.clone()))); } } // II: FETCHING FROM STATE_READER @@ -463,7 +475,7 @@ impl State for CachedState { let compiled_class_hash = self.get_compiled_class_hash(class_hash)?; self.casm_contract_classes .as_mut() - .and_then(|m| m.insert(compiled_class_hash, *class.clone())); + .and_then(|m| m.insert(compiled_class_hash, class.as_ref().clone())); } CompiledClass::Deprecated(ref contract) => { self.set_contract_class(class_hash, &contract.clone())? @@ -481,6 +493,8 @@ mod tests { use num_traits::One; + /// Test checks if class hashes and nonces are correctly fetched from the state reader. + /// It also tests the increment_nonce method. #[test] fn get_class_hash_and_nonce_from_state_reader() { let mut state_reader = InMemoryStateReader::new( @@ -522,6 +536,7 @@ mod tests { ); } + /// This test checks if the contract class is correctly fetched from the state reader. #[test] fn get_contract_class_from_state_reader() { let mut state_reader = InMemoryStateReader::new( @@ -554,6 +569,7 @@ mod tests { ); } + /// This test verifies the correct handling of storage in the cached state. #[test] fn cached_state_storage_test() { let mut cached_state = @@ -572,6 +588,7 @@ mod tests { .is_zero()); } + /// This test checks if deploying a contract works as expected. #[test] fn cached_state_deploy_contract_test() { let state_reader = Arc::new(InMemoryStateReader::default()); @@ -585,6 +602,7 @@ mod tests { .is_ok()); } + /// This test verifies the set and get storage values in the cached state. #[test] fn get_and_set_storage() { let state_reader = Arc::new(InMemoryStateReader::default()); @@ -611,6 +629,7 @@ mod tests { assert_eq!(new_result.unwrap(), new_value); } + /// This test ensures that an error is thrown when trying to set contract classes twice. #[test] fn set_contract_classes_twice_error_test() { let state_reader = InMemoryStateReader::new( @@ -631,6 +650,7 @@ mod tests { assert_matches!(result, StateError::AssignedContractClassCache); } + /// This test ensures that an error is thrown if a contract address is out of range. #[test] fn deploy_contract_address_out_of_range_error_test() { let state_reader = InMemoryStateReader::new( @@ -656,6 +676,7 @@ mod tests { ); } + /// This test ensures that an error is thrown if a contract address is already in use. #[test] fn deploy_contract_address_in_use_error_test() { let state_reader = InMemoryStateReader::new( @@ -684,6 +705,7 @@ mod tests { ); } + /// This test checks if replacing a contract in the cached state works correctly. #[test] fn cached_state_replace_contract_test() { let state_reader = InMemoryStateReader::new( @@ -713,6 +735,7 @@ mod tests { ); } + /// This test verifies if the cached state's internal structures are correctly updated after applying a state update. #[test] fn cached_state_apply_state_update() { let state_reader = InMemoryStateReader::new( @@ -757,6 +780,7 @@ mod tests { assert!(cached_state.cache.class_hash_initial_values.is_empty()); } + /// This test calculate the number of actual storage changes. #[test] fn count_actual_storage_changes_test() { let state_reader = InMemoryStateReader::default(); diff --git a/src/state/in_memory_state_reader.rs b/src/state/in_memory_state_reader.rs index bad842c38..de9b3e827 100644 --- a/src/state/in_memory_state_reader.rs +++ b/src/state/in_memory_state_reader.rs @@ -13,7 +13,7 @@ use crate::{ use cairo_vm::felt::Felt252; use getset::{Getters, MutGetters}; use num_traits::Zero; -use std::collections::HashMap; +use std::{collections::HashMap, sync::Arc}; /// A [StateReader] that holds all the data in memory. /// @@ -80,10 +80,10 @@ impl InMemoryStateReader { compiled_class_hash: &CompiledClassHash, ) -> Result { if let Some(compiled_class) = self.casm_contract_classes.get(compiled_class_hash) { - return Ok(CompiledClass::Casm(Box::new(compiled_class.clone()))); + return Ok(CompiledClass::Casm(Arc::new(compiled_class.clone()))); } if let Some(compiled_class) = self.class_hash_to_contract_class.get(compiled_class_hash) { - return Ok(CompiledClass::Deprecated(Box::new(compiled_class.clone()))); + return Ok(CompiledClass::Deprecated(Arc::new(compiled_class.clone()))); } Err(StateError::NoneCompiledClass(*compiled_class_hash)) } @@ -128,7 +128,7 @@ impl StateReader for InMemoryStateReader { fn get_contract_class(&self, class_hash: &ClassHash) -> Result { // Deprecated contract classes dont have a compiled_class_hash, we dont need to fetch it if let Some(compiled_class) = self.class_hash_to_contract_class.get(class_hash) { - return Ok(CompiledClass::Deprecated(Box::new(compiled_class.clone()))); + return Ok(CompiledClass::Deprecated(Arc::new(compiled_class.clone()))); } let compiled_class_hash = self.get_compiled_class_hash(class_hash)?; if compiled_class_hash != *UNINITIALIZED_CLASS_HASH { diff --git a/src/state/state_cache.rs b/src/state/state_cache.rs index b17165c44..03efa1774 100644 --- a/src/state/state_cache.rs +++ b/src/state/state_cache.rs @@ -219,6 +219,8 @@ impl StateCache { /// Unit tests for StateCache #[cfg(test)] mod tests { + use std::sync::Arc; + use crate::services::api::contract_classes::deprecated_contract_class::ContractClass; use super::*; @@ -230,7 +232,7 @@ mod tests { let contract_class = ContractClass::from_path("starknet_programs/raw_contract_classes/class_with_abi.json") .unwrap(); - let compiled_class = CompiledClass::Deprecated(Box::new(contract_class)); + let compiled_class = CompiledClass::Deprecated(Arc::new(contract_class)); let class_hash_to_compiled_class_hash = HashMap::from([([8; 32], compiled_class)]); let address_to_nonce = HashMap::from([(Address(9.into()), 12.into())]); let storage_updates = HashMap::from([((Address(4.into()), [1; 32]), 18.into())]); diff --git a/src/transaction/declare_v2.rs b/src/transaction/declare_v2.rs index 688329a09..0b80174c5 100644 --- a/src/transaction/declare_v2.rs +++ b/src/transaction/declare_v2.rs @@ -528,7 +528,7 @@ mod tests { .get_contract_class(&internal_declare.compiled_class_hash.to_be_bytes()) .unwrap() { - CompiledClass::Casm(casm) => *casm, + CompiledClass::Casm(casm) => casm.as_ref().clone(), _ => unreachable!(), }; @@ -597,7 +597,7 @@ mod tests { .get_contract_class(&internal_declare.compiled_class_hash.to_be_bytes()) .unwrap() { - CompiledClass::Casm(casm) => *casm, + CompiledClass::Casm(casm) => casm.as_ref().clone(), _ => unreachable!(), }; @@ -668,7 +668,7 @@ mod tests { .get_contract_class(&internal_declare.compiled_class_hash.to_be_bytes()) .unwrap() { - CompiledClass::Casm(casm) => *casm, + CompiledClass::Casm(casm) => casm.as_ref().clone(), _ => unreachable!(), }; @@ -737,7 +737,7 @@ mod tests { .get_contract_class(&internal_declare.compiled_class_hash.to_be_bytes()) .unwrap() { - CompiledClass::Casm(casm) => *casm, + CompiledClass::Casm(casm) => casm.as_ref().clone(), _ => unreachable!(), }; diff --git a/src/transaction/deploy.rs b/src/transaction/deploy.rs index 52112ead4..23606ca65 100644 --- a/src/transaction/deploy.rs +++ b/src/transaction/deploy.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use crate::execution::execution_entry_point::ExecutionResult; use crate::services::api::contract_classes::deprecated_contract_class::{ ContractClass, EntryPointType, @@ -80,7 +82,7 @@ impl Deploy { contract_address, contract_address_salt, contract_hash, - contract_class: CompiledClass::Deprecated(Box::new(contract_class)), + contract_class: CompiledClass::Deprecated(Arc::new(contract_class)), constructor_calldata, tx_type: TransactionType::Deploy, skip_validate: false, @@ -114,7 +116,7 @@ impl Deploy { contract_address_salt, contract_hash, constructor_calldata, - contract_class: CompiledClass::Deprecated(Box::new(contract_class)), + contract_class: CompiledClass::Deprecated(Arc::new(contract_class)), tx_type: TransactionType::Deploy, skip_validate: false, skip_execute: false, @@ -153,7 +155,7 @@ impl Deploy { CompiledClass::Casm(contract_class) => { state.set_compiled_class( &Felt252::from_bytes_be(&self.contract_hash), - *contract_class, + contract_class.as_ref().clone(), )?; } CompiledClass::Deprecated(contract_class) => { @@ -350,7 +352,7 @@ mod tests { assert_eq!( state.get_contract_class(&class_hash_bytes).unwrap(), - CompiledClass::Deprecated(Box::new(contract_class)) + CompiledClass::Deprecated(Arc::new(contract_class)) ); assert_eq!( diff --git a/tests/cairo_1_syscalls.rs b/tests/cairo_1_syscalls.rs index ed75e392d..1f32b20e6 100644 --- a/tests/cairo_1_syscalls.rs +++ b/tests/cairo_1_syscalls.rs @@ -722,7 +722,7 @@ fn deploy_cairo1_from_cairo1() { let ret_class_hash = state.get_class_hash_at(&ret_address).unwrap(); let ret_casm_class = match state.get_contract_class(&ret_class_hash).unwrap() { - CompiledClass::Casm(class) => *class, + CompiledClass::Casm(class) => class.as_ref().clone(), CompiledClass::Deprecated(_) => unreachable!(), }; @@ -824,7 +824,7 @@ fn deploy_cairo0_from_cairo1_without_constructor() { let ret_class_hash = state.get_class_hash_at(&ret_address).unwrap(); let ret_casm_class = match state.get_contract_class(&ret_class_hash).unwrap() { - CompiledClass::Deprecated(class) => *class, + CompiledClass::Deprecated(class) => class.as_ref().clone(), CompiledClass::Casm(_) => unreachable!(), }; @@ -925,7 +925,7 @@ fn deploy_cairo0_from_cairo1_with_constructor() { let ret_class_hash = state.get_class_hash_at(&ret_address).unwrap(); let ret_casm_class = match state.get_contract_class(&ret_class_hash).unwrap() { - CompiledClass::Deprecated(class) => *class, + CompiledClass::Deprecated(class) => class.as_ref().clone(), CompiledClass::Casm(_) => unreachable!(), }; @@ -1027,7 +1027,7 @@ fn deploy_cairo0_and_invoke() { let ret_class_hash = state.get_class_hash_at(&ret_address).unwrap(); let ret_casm_class = match state.get_contract_class(&ret_class_hash).unwrap() { - CompiledClass::Deprecated(class) => *class, + CompiledClass::Deprecated(class) => class.as_ref().clone(), CompiledClass::Casm(_) => unreachable!(), }; @@ -1342,7 +1342,7 @@ fn replace_class_internal() { // Check that the class_hash_b leads to contract_class_b for soundness assert_eq!( state.get_contract_class(&class_hash_b).unwrap(), - CompiledClass::Casm(Box::new(contract_class_b)) + CompiledClass::Casm(Arc::new(contract_class_b)) ); } diff --git a/tests/internals.rs b/tests/internals.rs index fd77a0f16..26595c01b 100644 --- a/tests/internals.rs +++ b/tests/internals.rs @@ -777,7 +777,7 @@ fn deploy_fib_syscall() -> Deploy { let program_data = include_bytes!("../starknet_programs/cairo1/fibonacci.sierra"); let sierra_contract_class: SierraContractClass = serde_json::from_slice(program_data).unwrap(); let casm_class = CasmContractClass::from_contract_class(sierra_contract_class, true).unwrap(); - let contract_class = CompiledClass::Casm(Box::new(casm_class)); + let contract_class = CompiledClass::Casm(Arc::new(casm_class)); let contract_hash; #[cfg(not(feature = "cairo_1_tests"))] diff --git a/tests/syscalls.rs b/tests/syscalls.rs index 5eab7d63d..23ee0b2e0 100644 --- a/tests/syscalls.rs +++ b/tests/syscalls.rs @@ -1154,7 +1154,7 @@ fn deploy_cairo1_from_cairo0_with_constructor() { let ret_class_hash = state.get_class_hash_at(&ret_address).unwrap(); let ret_casm_class = match state.get_contract_class(&ret_class_hash).unwrap() { - CompiledClass::Casm(class) => *class, + CompiledClass::Casm(class) => class.as_ref().clone(), CompiledClass::Deprecated(_) => unreachable!(), }; @@ -1257,7 +1257,7 @@ fn deploy_cairo1_from_cairo0_without_constructor() { let ret_class_hash = state.get_class_hash_at(&ret_address).unwrap(); let ret_casm_class = match state.get_contract_class(&ret_class_hash).unwrap() { - CompiledClass::Casm(class) => *class, + CompiledClass::Casm(class) => class.as_ref().clone(), CompiledClass::Deprecated(_) => unreachable!(), }; @@ -1358,7 +1358,7 @@ fn deploy_cairo1_and_invoke() { let ret_class_hash = state.get_class_hash_at(&ret_address).unwrap(); let ret_casm_class = match state.get_contract_class(&ret_class_hash).unwrap() { - CompiledClass::Casm(class) => *class, + CompiledClass::Casm(class) => class.as_ref().clone(), CompiledClass::Deprecated(_) => unreachable!(), };