Skip to content
This repository was archived by the owner on Jul 22, 2024. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions cli/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ use starknet_in_rust::{
serde_structs::read_abi,
services::api::contract_classes::deprecated_contract_class::ContractClass,
state::{cached_state::CachedState, state_api::State},
state::{in_memory_state_reader::InMemoryStateReader, ExecutionResourcesManager},
state::{in_memory_state_reader::InMemoryStateReader, ExecutionResourcesManager, StateDiff},
transaction::{error::TransactionError, InvokeFunction},
utils::{felt_to_hash, string_to_hash, Address},
};
Expand Down Expand Up @@ -195,7 +195,9 @@ fn invoke_parser(
Some(Felt252::zero()),
transaction_hash.unwrap(),
)?;
let _tx_info = internal_invoke.apply(cached_state, &BlockContext::default(), 0)?;
let mut transactional_state = cached_state.create_transactional();
let _tx_info = internal_invoke.apply(&mut transactional_state, &BlockContext::default(), 0)?;
cached_state.apply_state_update(&StateDiff::from_cached_state(transactional_state)?)?;

let tx_hash = calculate_transaction_hash_common(
TransactionHashPrefix::Invoke,
Expand Down
1 change: 1 addition & 0 deletions rpc_state_reader/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -530,6 +530,7 @@ mod tests {
rpc_state.get_transaction(tx_hash);
}

#[ignore]
#[test]
fn test_get_block_info() {
let rpc_state = RpcState::new(
Expand Down
24 changes: 6 additions & 18 deletions src/execution/execution_entry_point.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ use crate::services::api::contract_classes::deprecated_contract_class::{
ContractEntryPoint, EntryPointType,
};
use crate::state::cached_state::CachedState;
use crate::state::StateDiff;
use crate::{
definitions::{block_context::BlockContext, constants::DEFAULT_ENTRY_POINT_SELECTOR},
runner::StarknetRunner,
Expand Down Expand Up @@ -126,31 +125,20 @@ impl ExecutionEntryPoint {
})
}
CompiledClass::Casm(contract_class) => {
let mut tmp_state = CachedState::new(
state.state_reader.clone(),
state.contract_classes.clone(),
state.casm_contract_classes.clone(),
);
tmp_state.cache = state.cache.clone();

match self._execute(
&mut tmp_state,
state,
resources_manager,
block_context,
tx_execution_context,
contract_class,
class_hash,
support_reverted,
) {
Ok(call_info) => {
let state_diff = StateDiff::from_cached_state(tmp_state)?;
state.apply_state_update(&state_diff)?;
Ok(ExecutionResult {
call_info: Some(call_info),
revert_error: None,
n_reverted_steps: 0,
})
}
Ok(call_info) => Ok(ExecutionResult {
call_info: Some(call_info),
revert_error: None,
n_reverted_steps: 0,
}),
Err(e) => {
if !support_reverted {
return Err(e);
Expand Down
4 changes: 2 additions & 2 deletions src/execution/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -610,11 +610,11 @@ impl TransactionExecutionInfo {
Ok(sorted_messages)
}

pub fn to_revert_error(self, revert_error: String) -> Self {
pub fn to_revert_error(self, revert_error: &str) -> Self {
TransactionExecutionInfo {
validate_info: None,
call_info: None,
revert_error: Some(revert_error),
revert_error: Some(revert_error.to_string()),
fee_transfer_info: None,
..self
}
Expand Down
4 changes: 2 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ mod test {
let transaction = Transaction::InvokeFunction(invoke_function);

let estimated_fee = estimate_fee(&[transaction], state, &block_context).unwrap();
assert_eq!(estimated_fee[0], (2483, 2448));
assert_eq!(estimated_fee[0], (1259, 1224));
}

#[test]
Expand Down Expand Up @@ -1014,7 +1014,7 @@ mod test {

assert_eq!(
estimate_fee(&[deploy, invoke_tx], state, block_context,).unwrap(),
[(0, 3672), (0, 2448)]
[(0, 3672), (0, 1224)]
);
}

Expand Down
213 changes: 202 additions & 11 deletions src/state/cached_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,15 +82,16 @@ impl<T: StateReader> CachedState<T> {
.ok_or(StateError::MissingCasmClassCache)
}

pub fn create_copy(&self) -> Self {
let mut state = CachedState::new(
self.state_reader.clone(),
self.contract_classes.clone(),
self.casm_contract_classes.clone(),
);
state.cache = self.cache.clone();

state
/// Creates a copy of this state with an empty cache for saving changes and applying them
/// later.
pub fn create_transactional(&self) -> TransactionalCachedState<T> {
let state_reader = Arc::new(TransactionalCachedStateReader::new(self));
CachedState {
state_reader,
cache: Default::default(),
contract_classes: Default::default(),
casm_contract_classes: Default::default(),
}
}
}

Expand Down Expand Up @@ -471,10 +472,10 @@ impl<T: StateReader> State for CachedState<T> {
match contract {
CompiledClass::Casm(ref class) => {
// We call this method instead of state_reader's in order to update the cache's class_hash_initial_values map
let compiled_class_hash = self.get_compiled_class_hash(class_hash)?;
//let compiled_class_hash = self.get_compiled_class_hash(class_hash)?;
self.casm_contract_classes
.as_mut()
.and_then(|m| m.insert(compiled_class_hash, *class.clone()));
.and_then(|m| m.insert(*class_hash, *class.clone()));
}
CompiledClass::Deprecated(ref contract) => {
self.set_contract_class(class_hash, &contract.clone())?
Expand All @@ -484,6 +485,196 @@ impl<T: StateReader> State for CachedState<T> {
}
}

/// A CachedState which has access to another, "parent" state, used for executing transactions
/// without commiting changes to the parent.
pub type TransactionalCachedState<'a, T> = CachedState<TransactionalCachedStateReader<'a, T>>;

impl<'a, T: StateReader> TransactionalCachedState<'a, T> {
pub fn count_actual_storage_changes(&mut self) -> Result<(usize, usize), StateError> {
let storage_updates = subtract_mappings(
self.cache.storage_writes.clone(),
self.cache.storage_initial_values.clone(),
);

let n_modified_contracts = {
let storage_unique_updates = storage_updates.keys().map(|k| k.0.clone());

let class_hash_updates: Vec<_> = subtract_mappings(
self.cache.class_hash_writes.clone(),
self.cache.class_hash_initial_values.clone(),
)
.keys()
.cloned()
.collect();

let nonce_updates: Vec<_> = subtract_mappings(
self.cache.nonce_writes.clone(),
self.cache.nonce_initial_values.clone(),
)
.keys()
.cloned()
.collect();

let mut modified_contracts: HashSet<Address> = HashSet::new();
modified_contracts.extend(storage_unique_updates);
modified_contracts.extend(class_hash_updates);
modified_contracts.extend(nonce_updates);

modified_contracts.len()
};

Ok((n_modified_contracts, storage_updates.len()))
}
}

/// State reader used for transactional states which allows to check the parent state's cache and
/// state reader if a transactional cache miss happens.
///
/// In practice this will act as a way to access the parent state's cache and other fields,
/// without referencing the whole parent state, so there's no need to adapt state-modifying
/// functions in the case that a transactional state is needed.
#[derive(Debug, MutGetters, Getters, PartialEq, Clone)]
pub struct TransactionalCachedStateReader<'a, T: StateReader> {
/// The parent state's state_reader
#[get(get = "pub")]
pub(crate) state_reader: Arc<T>,
/// The parent state's cache
#[get(get = "pub")]
pub(crate) cache: &'a StateCache,
/// The parent state's contract_classes
#[get(get = "pub")]
pub(crate) contract_classes: Option<ContractClassCache>,
/// The parent state's casm_contract_classes
#[get(get = "pub")]
pub(crate) casm_contract_classes: Option<CasmClassCache>,
}

impl<'a, T: StateReader> TransactionalCachedStateReader<'a, T> {
fn new(state: &'a CachedState<T>) -> Self {
Self {
state_reader: state.state_reader.clone(),
cache: &state.cache,
contract_classes: state.contract_classes.clone(),
casm_contract_classes: state.casm_contract_classes.clone(),
}
}
}

impl<'a, T: StateReader> StateReader for TransactionalCachedStateReader<'a, T> {
fn get_class_hash_at(&self, contract_address: &Address) -> Result<ClassHash, StateError> {
if self.cache.get_class_hash(contract_address).is_none() {
match self.state_reader.get_class_hash_at(contract_address) {
Ok(class_hash) => {
return Ok(class_hash);
}
Err(StateError::NoneContractState(_)) => {
return Ok([0; 32]);
}
Err(e) => {
return Err(e);
}
}
}

self.cache
.get_class_hash(contract_address)
.ok_or_else(|| StateError::NoneClassHash(contract_address.clone()))
.cloned()
}

fn get_nonce_at(&self, contract_address: &Address) -> Result<Felt252, StateError> {
if self.cache.get_nonce(contract_address).is_none() {
return self.state_reader.get_nonce_at(contract_address);
}
self.cache
.get_nonce(contract_address)
.ok_or_else(|| StateError::NoneNonce(contract_address.clone()))
.cloned()
}

fn get_storage_at(&self, storage_entry: &StorageEntry) -> Result<Felt252, StateError> {
if self.cache.get_storage(storage_entry).is_none() {
match self.state_reader.get_storage_at(storage_entry) {
Ok(storage) => {
return Ok(storage);
}
Err(
StateError::EmptyKeyInStorage
| StateError::NoneStoragLeaf(_)
| StateError::NoneStorage(_)
| StateError::NoneContractState(_),
) => return Ok(Felt252::zero()),
Err(e) => {
return Err(e);
}
}
}

self.cache
.get_storage(storage_entry)
.ok_or_else(|| StateError::NoneStorage(storage_entry.clone()))
.cloned()
}

// TODO: check if that the proper way to store it (converting hash to address)
fn get_compiled_class_hash(&self, class_hash: &ClassHash) -> Result<ClassHash, StateError> {
if self
.cache
.class_hash_to_compiled_class_hash
.get(class_hash)
.is_none()
{
return self.state_reader.get_compiled_class_hash(class_hash);
}
self.cache
.class_hash_to_compiled_class_hash
.get(class_hash)
.ok_or_else(|| StateError::NoneCompiledClass(*class_hash))
.cloned()
}

fn get_contract_class(&self, class_hash: &ClassHash) -> Result<CompiledClass, StateError> {
// This method can receive both compiled_class_hash & class_hash and return both casm and deprecated contract classes
//, which can be on the cache or on the state_reader, different cases will be described below:
if class_hash == UNINITIALIZED_CLASS_HASH {
return Err(StateError::UninitiaizedClassHash);
}
// I: FETCHING FROM CACHE
// I: DEPRECATED CONTRACT CLASS
// deprecated contract classes dont have compiled class hashes, so we only have one case
if let Some(compiled_class) = self
.contract_classes
.as_ref()
.and_then(|x| x.get(class_hash))
{
return Ok(CompiledClass::Deprecated(Box::new(compiled_class.clone())));
}
// I: CASM CONTRACT CLASS : COMPILED_CLASS_HASH
if let Some(compiled_class) = self
.casm_contract_classes
.as_ref()
.and_then(|x| x.get(class_hash))
{
return Ok(CompiledClass::Casm(Box::new(compiled_class.clone())));
}
// I: CASM CONTRACT CLASS : CLASS_HASH
if let Some(compiled_class_hash) =
self.cache.class_hash_to_compiled_class_hash.get(class_hash)
{
if let Some(casm_class) = &mut self
.casm_contract_classes
.as_ref()
.and_then(|m| m.get(compiled_class_hash))
{
return Ok(CompiledClass::Casm(Box::new(casm_class.clone())));
}
}
// II: FETCHING FROM STATE_READER
let contract = self.state_reader.get_contract_class(class_hash)?;
Ok(contract)
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
2 changes: 1 addition & 1 deletion src/state/in_memory_state_reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ pub struct InMemoryStateReader {
#[getset(get_mut = "pub")]
pub(crate) casm_contract_classes: CasmClassCache,
#[getset(get_mut = "pub")]
pub(crate) class_hash_to_compiled_class_hash: HashMap<ClassHash, CompiledClassHash>,
pub class_hash_to_compiled_class_hash: HashMap<ClassHash, CompiledClassHash>,
}

impl InMemoryStateReader {
Expand Down
8 changes: 7 additions & 1 deletion src/syscalls/deprecated_syscall_handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,7 @@ mod tests {

use super::*;
use crate::services::api::contract_classes::deprecated_contract_class::EntryPointType;
use crate::state::StateDiff;
use crate::{
add_segments, allocate_selector, any_box,
definitions::{
Expand Down Expand Up @@ -1188,9 +1189,14 @@ mod tests {
)
.unwrap();

let mut transactional = state.create_transactional();
// Invoke result
let result = internal_invoke_function
.apply(&mut state, &BlockContext::default(), 0)
.apply(&mut transactional, &BlockContext::default(), 0)
.unwrap();

state
.apply_state_update(&StateDiff::from_cached_state(transactional).unwrap())
.unwrap();

let result_call_info = result.call_info.unwrap();
Expand Down
2 changes: 1 addition & 1 deletion src/testing/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -574,7 +574,7 @@ mod tests {
.unwrap();
let actual_resources = HashMap::from([
("n_steps".to_string(), 3457),
("l1_gas_usage".to_string(), 2448),
("l1_gas_usage".to_string(), 1224),
("range_check_builtin".to_string(), 80),
("pedersen_builtin".to_string(), 16),
]);
Expand Down
Loading