diff --git a/src/state/cached_state.rs b/src/state/cached_state.rs index c453e1a96..894dc7f9b 100644 --- a/src/state/cached_state.rs +++ b/src/state/cached_state.rs @@ -16,15 +16,14 @@ use cairo_vm::felt::Felt252; use getset::{Getters, MutGetters}; use num_traits::Zero; use std::{ - cell::RefCell, collections::{HashMap, HashSet}, - sync::Arc, + sync::{Arc, RwLock}, }; pub const UNINITIALIZED_CLASS_HASH: &ClassHash = &[0u8; 32]; /// Represents a cached state of contract classes with optional caches. -#[derive(Default, Clone, Debug, Getters, MutGetters)] +#[derive(Default, Debug, Getters, MutGetters)] pub struct CachedState { pub state_reader: Arc, #[getset(get = "pub", get_mut = "pub")] @@ -32,7 +31,7 @@ pub struct CachedState { #[getset(get = "pub", get_mut = "pub")] pub(crate) contract_class_cache: Arc, - pub(crate) contract_class_cache_private: RefCell>, + pub(crate) contract_class_cache_private: RwLock>, #[cfg(feature = "metrics")] cache_hits: usize, @@ -73,7 +72,7 @@ impl CachedState { cache: StateCache::default(), state_reader, contract_class_cache: contract_classes, - contract_class_cache_private: RefCell::new(HashMap::new()), + contract_class_cache_private: RwLock::new(HashMap::new()), #[cfg(feature = "metrics")] cache_hits: 0, @@ -92,7 +91,7 @@ impl CachedState { cache, state_reader, contract_class_cache: contract_classes, - contract_class_cache_private: RefCell::new(HashMap::new()), + contract_class_cache_private: RwLock::new(HashMap::new()), #[cfg(feature = "metrics")] cache_hits: 0, @@ -104,7 +103,11 @@ impl CachedState { pub fn drain_private_contract_class_cache( &self, ) -> impl Iterator { - self.contract_class_cache_private.take().into_iter() + self.contract_class_cache_private + .read() + .unwrap() + .clone() + .into_iter() } /// Creates a copy of this state with an empty cache for saving changes and applying them @@ -115,8 +118,8 @@ impl CachedState { state_reader, cache: self.cache.clone(), contract_class_cache: self.contract_class_cache.clone(), - contract_class_cache_private: RefCell::new( - self.contract_class_cache_private.borrow().clone(), + contract_class_cache_private: RwLock::new( + self.contract_class_cache_private.read().unwrap().clone(), ), #[cfg(feature = "metrics")] cache_hits: 0, @@ -177,7 +180,7 @@ impl StateReader for CachedState { } // I: FETCHING FROM CACHE - let mut private_cache = self.contract_class_cache_private.borrow_mut(); + let mut private_cache = self.contract_class_cache_private.write().unwrap(); if let Some(compiled_class) = private_cache.get(class_hash) { return Ok(compiled_class.clone()); } else if let Some(compiled_class) = @@ -221,6 +224,7 @@ impl State for CachedState { // have a mutable reference to the `RefCell` available. self.contract_class_cache_private .get_mut() + .unwrap() .insert(*class_hash, contract_class.clone()); Ok(()) @@ -446,6 +450,7 @@ impl State for CachedState { if let Some(compiled_class) = self .contract_class_cache_private .get_mut() + .unwrap() .get(class_hash) .cloned() { @@ -457,6 +462,7 @@ impl State for CachedState { self.add_hit(); self.contract_class_cache_private .get_mut() + .unwrap() .insert(*class_hash, compiled_class.clone()); return Ok(compiled_class); } @@ -465,14 +471,11 @@ impl State for CachedState { if let Some(compiled_class_hash) = self.cache.class_hash_to_compiled_class_hash.get(class_hash) { + let write_guard = self.contract_class_cache_private.get_mut().unwrap(); + // `RefCell::get_mut()` provides a mutable reference without the borrowing overhead when // we have a mutable reference to the `RefCell` available. - if let Some(casm_class) = self - .contract_class_cache_private - .get_mut() - .get(compiled_class_hash) - .cloned() - { + if let Some(casm_class) = write_guard.get(compiled_class_hash).cloned() { self.add_hit(); return Ok(casm_class); } else if let Some(casm_class) = self @@ -482,6 +485,7 @@ impl State for CachedState { self.add_hit(); self.contract_class_cache_private .get_mut() + .unwrap() .insert(*class_hash, casm_class.clone()); return Ok(casm_class); } @@ -517,7 +521,7 @@ pub type TransactionalCachedState<'a, T, C> = /// In practice this will act as a way to access the parent state's cache and other fields, /// without referencing the whole parent state, so there's no need to adapt state-modifying /// functions in the case that a transactional state is needed. -#[derive(Debug, MutGetters, Getters, PartialEq, Clone)] +#[derive(Debug, MutGetters, Getters)] pub struct TransactionalCachedStateReader<'a, T: StateReader, C: ContractClassCache> { /// The parent state's state_reader #[get(get = "pub")] @@ -529,7 +533,7 @@ pub struct TransactionalCachedStateReader<'a, T: StateReader, C: ContractClassCa /// The parent state's contract_classes #[get(get = "pub")] pub(crate) contract_class_cache: Arc, - pub(crate) contract_class_cache_private: &'a RefCell>, + pub(crate) contract_class_cache_private: &'a RwLock>, } impl<'a, T: StateReader, C: ContractClassCache> TransactionalCachedStateReader<'a, T, C> { @@ -602,7 +606,7 @@ impl<'a, T: StateReader, C: ContractClassCache> StateReader } // I: FETCHING FROM CACHE - let mut private_cache = self.contract_class_cache_private.borrow_mut(); + let mut private_cache = self.contract_class_cache_private.write().unwrap(); if let Some(compiled_class) = private_cache.get(class_hash) { return Ok(compiled_class.clone()); } else if let Some(compiled_class) =