Skip to content
This repository was archived by the owner on Jul 22, 2024. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cli/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ fn invoke_parser(
)?;
let mut transactional_state = cached_state.create_transactional();
let _tx_info = internal_invoke.apply(&mut transactional_state, &BlockContext::default(), 0)?;
cached_state.apply_state_update(&StateDiff::from_cached_state(transactional_state)?)?;
cached_state.apply_state_update(&StateDiff::from_cached_state(transactional_state.cache())?)?;

let tx_hash = calculate_transaction_hash_common(
TransactionHashPrefix::Invoke,
Expand Down
113 changes: 2 additions & 111 deletions src/state/cached_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,10 +102,9 @@ impl<T: StateReader> CachedState<T> {

/// Creates a copy of this state with an empty cache for saving changes and applying them
/// later.
pub fn create_transactional(&self) -> TransactionalCachedState<T> {
let state_reader = Arc::new(TransactionalCachedStateReader::new(self));
pub fn create_transactional(&self) -> CachedState<T> {
CachedState {
state_reader,
state_reader: self.state_reader.clone(),
cache: self.cache.clone(),
contract_classes: self.contract_classes.clone(),
cache_hits: 0,
Expand Down Expand Up @@ -445,114 +444,6 @@ impl<T: StateReader> State for CachedState<T> {
}
}

/// A CachedState which has access to another, "parent" state, used for executing transactions
/// without commiting changes to the parent.
pub type TransactionalCachedState<'a, T> = CachedState<TransactionalCachedStateReader<'a, T>>;

/// State reader used for transactional states which allows to check the parent state's cache and
/// state reader if a transactional cache miss happens.
///
/// In practice this will act as a way to access the parent state's cache and other fields,
/// without referencing the whole parent state, so there's no need to adapt state-modifying
/// functions in the case that a transactional state is needed.
#[derive(Debug, MutGetters, Getters, PartialEq, Clone)]
pub struct TransactionalCachedStateReader<'a, T: StateReader> {
/// The parent state's state_reader
#[get(get = "pub")]
pub(crate) state_reader: Arc<T>,
/// The parent state's cache
#[get(get = "pub")]
pub(crate) cache: &'a StateCache,
/// The parent state's contract_classes
#[get(get = "pub")]
pub(crate) contract_classes: ContractClassCache,
}

impl<'a, T: StateReader> TransactionalCachedStateReader<'a, T> {
fn new(state: &'a CachedState<T>) -> Self {
Self {
state_reader: state.state_reader.clone(),
cache: &state.cache,
contract_classes: state.contract_classes.clone(),
}
}
}

impl<'a, T: StateReader> StateReader for TransactionalCachedStateReader<'a, T> {
/// Returns the class hash for a given contract address.
/// Returns zero as default value if missing
fn get_class_hash_at(&self, contract_address: &Address) -> Result<ClassHash, StateError> {
self.cache
.get_class_hash(contract_address)
.map(|a| Ok(*a))
.unwrap_or_else(|| self.state_reader.get_class_hash_at(contract_address))
}

/// Returns the nonce for a given contract address.
fn get_nonce_at(&self, contract_address: &Address) -> Result<Felt252, StateError> {
if self.cache.get_nonce(contract_address).is_none() {
return self.state_reader.get_nonce_at(contract_address);
}
self.cache
.get_nonce(contract_address)
.ok_or_else(|| StateError::NoneNonce(contract_address.clone()))
.cloned()
}

/// Returns storage data for a given storage entry.
/// Returns zero as default value if missing
fn get_storage_at(&self, storage_entry: &StorageEntry) -> Result<Felt252, StateError> {
self.cache
.get_storage(storage_entry)
.map(|v| Ok(v.clone()))
.unwrap_or_else(|| self.state_reader.get_storage_at(storage_entry))
}

// TODO: check if that the proper way to store it (converting hash to address)
/// Returned the compiled class hash for a given class hash.
fn get_compiled_class_hash(&self, class_hash: &ClassHash) -> Result<ClassHash, StateError> {
if self
.cache
.class_hash_to_compiled_class_hash
.get(class_hash)
.is_none()
{
return self.state_reader.get_compiled_class_hash(class_hash);
}
self.cache
.class_hash_to_compiled_class_hash
.get(class_hash)
.ok_or_else(|| StateError::NoneCompiledClass(*class_hash))
.cloned()
}

/// Returns the contract class for a given class hash.
fn get_contract_class(&self, class_hash: &ClassHash) -> Result<CompiledClass, StateError> {
// This method can receive both compiled_class_hash & class_hash and return both casm and deprecated contract classes
//, which can be on the cache or on the state_reader, different cases will be described below:
if class_hash == UNINITIALIZED_CLASS_HASH {
return Err(StateError::UninitiaizedClassHash);
}

// I: FETCHING FROM CACHE
if let Some(compiled_class) = self.contract_classes.get(class_hash) {
return Ok(compiled_class.clone());
}

// I: CASM CONTRACT CLASS : CLASS_HASH
if let Some(compiled_class_hash) =
self.cache.class_hash_to_compiled_class_hash.get(class_hash)
{
if let Some(casm_class) = self.contract_classes.get(compiled_class_hash) {
return Ok(casm_class.clone());
}
}

// II: FETCHING FROM STATE_READER
self.state_reader.get_contract_class(class_hash)
}
}

impl<T: StateReader> CachedState<T> {
// Updates the cache's storage_initial_values according to those in storage_writes
// If a key is present in the storage_writes but not in storage_initial_values,
Expand Down
23 changes: 9 additions & 14 deletions src/state/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use crate::{
utils::{Address, ClassHash},
};

use self::{cached_state::CachedState, state_api::StateReader};
use self::{cached_state::CachedState, state_api::StateReader, state_cache::StateCache};

#[derive(Clone, Debug, PartialEq, Eq)]
pub struct BlockInfo {
Expand Down Expand Up @@ -125,18 +125,13 @@ impl StateDiff {
}
}

pub fn from_cached_state<T>(cached_state: CachedState<T>) -> Result<Self, StateError>
where
T: StateReader,
{
let state_cache = cached_state.cache().to_owned();

let substracted_maps = state_cache.storage_writes;
pub fn from_cached_state(state_cache: &StateCache) -> Result<Self, StateError> {
let substracted_maps = &state_cache.storage_writes;
let storage_updates = to_state_diff_storage_mapping(substracted_maps);

let address_to_nonce = state_cache.nonce_writes;
let class_hash_to_compiled_class = state_cache.compiled_class_hash_writes;
let address_to_class_hash = state_cache.class_hash_writes;
let address_to_nonce = state_cache.nonce_writes.clone();
let class_hash_to_compiled_class = state_cache.compiled_class_hash_writes.clone();
let address_to_class_hash = state_cache.class_hash_writes.clone();

Ok(StateDiff {
address_to_class_hash,
Expand Down Expand Up @@ -248,7 +243,7 @@ mod test {

let cached_state = CachedState::new(Arc::new(state_reader), HashMap::new());

let diff = StateDiff::from_cached_state(cached_state).unwrap();
let diff = StateDiff::from_cached_state(&cached_state.cache).unwrap();

assert_eq!(0, diff.storage_updates.len());
}
Expand Down Expand Up @@ -319,7 +314,7 @@ mod test {
let cached_state_original =
CachedState::new(Arc::new(state_reader.clone()), HashMap::new());

let diff = StateDiff::from_cached_state(cached_state_original.clone()).unwrap();
let diff = StateDiff::from_cached_state(cached_state_original.cache()).unwrap();

let cached_state = diff.to_cached_state(Arc::new(state_reader)).unwrap();

Expand Down Expand Up @@ -367,7 +362,7 @@ mod test {
let cached_state =
CachedState::new_for_testing(Arc::new(state_reader), cache, HashMap::new());

let mut diff = StateDiff::from_cached_state(cached_state).unwrap();
let mut diff = StateDiff::from_cached_state(cached_state.cache()).unwrap();

let diff_squashed = diff.squash(diff.clone());

Expand Down
2 changes: 1 addition & 1 deletion src/syscalls/deprecated_syscall_handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1203,7 +1203,7 @@ mod tests {
.unwrap();

state
.apply_state_update(&StateDiff::from_cached_state(transactional).unwrap())
.apply_state_update(&StateDiff::from_cached_state(transactional.cache()).unwrap())
.unwrap();

let result_call_info = result.call_info.unwrap();
Expand Down
3 changes: 2 additions & 1 deletion src/transaction/deploy_account.rs
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,8 @@ impl DeployAccount {
.as_str(),
);
} else {
state.apply_state_update(&StateDiff::from_cached_state(transactional_state)?)?;
state
.apply_state_update(&StateDiff::from_cached_state(transactional_state.cache())?)?;
}

let mut tx_execution_context =
Expand Down
11 changes: 6 additions & 5 deletions src/transaction/invoke_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use crate::{
},
services::api::contract_classes::deprecated_contract_class::EntryPointType,
state::{
cached_state::{CachedState, TransactionalCachedState},
cached_state::CachedState,
state_api::{State, StateReader},
ExecutionResourcesManager, StateDiff,
},
Expand Down Expand Up @@ -239,7 +239,7 @@ impl InvokeFunction {
/// - remaining_gas: The amount of gas that the transaction disposes.
pub fn apply<S: StateReader>(
&self,
state: &mut TransactionalCachedState<S>,
state: &mut CachedState<S>,
block_context: &BlockContext,
remaining_gas: u128,
) -> Result<TransactionExecutionInfo, TransactionError> {
Expand Down Expand Up @@ -334,7 +334,8 @@ impl InvokeFunction {
.as_str(),
);
} else {
state.apply_state_update(&StateDiff::from_cached_state(transactional_state)?)?;
state
.apply_state_update(&StateDiff::from_cached_state(transactional_state.cache())?)?;
}

let mut tx_execution_context =
Expand Down Expand Up @@ -674,7 +675,7 @@ mod tests {
.apply(&mut transactional, &BlockContext::default(), 0)
.unwrap();
state
.apply_state_update(&StateDiff::from_cached_state(transactional).unwrap())
.apply_state_update(&StateDiff::from_cached_state(transactional.cache()).unwrap())
.unwrap();

assert_eq!(result.tx_type, Some(TransactionType::InvokeFunction));
Expand Down Expand Up @@ -882,7 +883,7 @@ mod tests {
.apply(&mut transactional, &BlockContext::default(), 0)
.unwrap();
state
.apply_state_update(&StateDiff::from_cached_state(transactional).unwrap())
.apply_state_update(&StateDiff::from_cached_state(transactional.cache()).unwrap())
.unwrap();

assert_eq!(result.tx_type, Some(TransactionType::InvokeFunction));
Expand Down
14 changes: 7 additions & 7 deletions src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -135,17 +135,17 @@ pub fn string_to_hash(class_string: &String) -> ClassHash {

/// Converts CachedState storage mapping to StateDiff storage mapping.
pub fn to_state_diff_storage_mapping(
storage_writes: HashMap<StorageEntry, Felt252>,
storage_writes: &HashMap<StorageEntry, Felt252>,
) -> HashMap<Address, HashMap<Felt252, Felt252>> {
let mut storage_updates: HashMap<Address, HashMap<Felt252, Felt252>> = HashMap::new();
for ((address, key), value) in storage_writes.into_iter() {
for ((address, key), value) in storage_writes.iter() {
storage_updates
.entry(address)
.entry(address.clone())
.and_modify(|updates_for_address: &mut HashMap<Felt252, Felt252>| {
let key_fe = Felt252::from_bytes_be(&key);
let key_fe = Felt252::from_bytes_be(key);
updates_for_address.insert(key_fe, value.clone());
})
.or_insert_with(|| HashMap::from([(Felt252::from_bytes_be(&key), value)]));
.or_insert_with(|| HashMap::from([(Felt252::from_bytes_be(key), value.clone())]));
}
storage_updates
}
Expand Down Expand Up @@ -805,7 +805,7 @@ mod test {
storage.insert((address1.clone(), key1), value1.clone());
storage.insert((address2.clone(), key2), value2.clone());

let map = to_state_diff_storage_mapping(storage);
let map = to_state_diff_storage_mapping(&storage);

let key1_fe = Felt252::from_bytes_be(key1.as_slice());
let key2_fe = Felt252::from_bytes_be(key2.as_slice());
Expand Down Expand Up @@ -876,7 +876,7 @@ mod test {
storage.insert((address1.clone(), key1), value1.clone());
storage.insert((address2.clone(), key2), value2.clone());

let state_dff = to_state_diff_storage_mapping(storage);
let state_dff = to_state_diff_storage_mapping(&storage);
let cache_storage = to_cache_state_storage_mapping(&state_dff);

let mut expected_res = HashMap::new();
Expand Down