diff --git a/src/state/cached_state.rs b/src/state/cached_state.rs index 98b37deb6..2a63a88cf 100644 --- a/src/state/cached_state.rs +++ b/src/state/cached_state.rs @@ -105,28 +105,12 @@ impl StateReader for CachedState { } /// Returns storage data for a given storage entry. + /// Returns zero as default value if missing fn get_storage_at(&self, storage_entry: &StorageEntry) -> Result { - if self.cache.get_storage(storage_entry).is_none() { - match self.state_reader.get_storage_at(storage_entry) { - Ok(storage) => { - return Ok(storage); - } - Err( - StateError::EmptyKeyInStorage - | StateError::NoneStoragLeaf(_) - | StateError::NoneStorage(_) - | StateError::NoneContractState(_), - ) => return Ok(Felt252::zero()), - Err(e) => { - return Err(e); - } - } - } - self.cache .get_storage(storage_entry) - .ok_or_else(|| StateError::NoneStorage(storage_entry.clone())) - .cloned() + .map(|v| Ok(v.clone())) + .unwrap_or_else(|| self.state_reader.get_storage_at(storage_entry)) } // TODO: check if that the proper way to store it (converting hash to address) @@ -353,27 +337,20 @@ impl State for CachedState { .clone()) } + /// Returns storage data for a given storage entry. + /// Returns zero as default value if missing + /// Adds the value to the cache's inital_values if not present fn get_storage_at(&mut self, storage_entry: &StorageEntry) -> Result { - if self.cache.get_storage(storage_entry).is_none() { - let value = match self.state_reader.get_storage_at(storage_entry) { - Ok(value) => value, - Err( - StateError::EmptyKeyInStorage - | StateError::NoneStoragLeaf(_) - | StateError::NoneStorage(_) - | StateError::NoneContractState(_), - ) => Felt252::zero(), - Err(e) => return Err(e), - }; - self.cache - .storage_initial_values - .insert(storage_entry.clone(), value); + match self.cache.get_storage(storage_entry) { + Some(value) => Ok(value.clone()), + None => { + let value = self.state_reader.get_storage_at(storage_entry)?; + self.cache + .storage_initial_values + .insert(storage_entry.clone(), value.clone()); + Ok(value) + } } - - self.cache - .get_storage(storage_entry) - .ok_or_else(|| StateError::NoneStorage(storage_entry.clone())) - .cloned() } // TODO: check if that the proper way to store it (converting hash to address) diff --git a/src/state/in_memory_state_reader.rs b/src/state/in_memory_state_reader.rs index 2d6a001f7..82c0dc734 100644 --- a/src/state/in_memory_state_reader.rs +++ b/src/state/in_memory_state_reader.rs @@ -97,11 +97,11 @@ impl StateReader for InMemoryStateReader { } fn get_storage_at(&self, storage_entry: &StorageEntry) -> Result { - let storage = self + Ok(self .address_to_storage .get(storage_entry) - .ok_or_else(|| StateError::NoneStorage(storage_entry.clone())); - storage.cloned() + .cloned() + .unwrap_or_default()) } fn get_compiled_class_hash( @@ -132,10 +132,21 @@ impl StateReader for InMemoryStateReader { #[cfg(test)] mod tests { + use num_traits::One; + use super::*; use crate::services::api::contract_classes::deprecated_contract_class::ContractClass; use std::sync::Arc; + #[test] + fn get_storage_returns_zero_if_missing() { + let state_reader = InMemoryStateReader::default(); + assert!(state_reader + .get_storage_at(&(Address(Felt252::one()), Felt252::one().to_be_bytes())) + .unwrap() + .is_zero()) + } + #[test] fn get_contract_state_test() { let mut state_reader = InMemoryStateReader::new( diff --git a/src/state/state_api.rs b/src/state/state_api.rs index edeab04de..663ba5825 100644 --- a/src/state/state_api.rs +++ b/src/state/state_api.rs @@ -16,6 +16,7 @@ pub trait StateReader { /// Returns the nonce of the given contract instance. fn get_nonce_at(&self, contract_address: &Address) -> Result; /// Returns the storage value under the given key in the given contract instance. + /// Returns zero by default if the value is not present fn get_storage_at(&self, storage_entry: &StorageEntry) -> Result; /// Return the class hash of the given casm contract class fn get_compiled_class_hash( @@ -65,6 +66,8 @@ pub trait State { /// Default: 0 for an uninitialized contract address. fn get_nonce_at(&mut self, contract_address: &Address) -> Result; + /// Returns storage data for a given storage entry. + /// Returns zero as default value if missing fn get_storage_at(&mut self, storage_entry: &StorageEntry) -> Result; fn get_compiled_class_hash(&mut self, class_hash: &ClassHash) -> Result;