Skip to content
This repository was archived by the owner on Jul 22, 2024. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions src/execution/execution_entry_point.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::sync::Arc;

use crate::services::api::contract_classes::deprecated_contract_class::{
ContractEntryPoint, EntryPointType,
};
Expand Down Expand Up @@ -331,7 +333,7 @@ impl ExecutionEntryPoint {
resources_manager: &mut ExecutionResourcesManager,
block_context: &BlockContext,
tx_execution_context: &mut TransactionExecutionContext,
contract_class: Box<ContractClass>,
contract_class: Arc<ContractClass>,
class_hash: [u8; 32],
) -> Result<CallInfo, TransactionError> {
let previous_cairo_usage = resources_manager.cairo_usage.clone();
Expand Down Expand Up @@ -436,7 +438,7 @@ impl ExecutionEntryPoint {
resources_manager: &mut ExecutionResourcesManager,
block_context: &BlockContext,
tx_execution_context: &mut TransactionExecutionContext,
contract_class: Box<CasmContractClass>,
contract_class: Arc<CasmContractClass>,
class_hash: [u8; 32],
support_reverted: bool,
) -> Result<CallInfo, TransactionError> {
Expand Down
13 changes: 7 additions & 6 deletions src/services/api/contract_classes/compiled_class.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::collections::HashMap;
use std::io::{self, Read};
use std::sync::Arc;

use crate::core::contract_address::{compute_hinted_class_hash, CairoProgramToHash};
use crate::services::api::contract_class_errors::ContractClassError;
Expand All @@ -21,15 +22,15 @@ use starknet::core::types::ContractClass::{Legacy, Sierra};

#[derive(Clone, PartialEq, Eq, Debug)]
pub enum CompiledClass {
Deprecated(Box<ContractClass>),
Casm(Box<CasmContractClass>),
Deprecated(Arc<ContractClass>),
Casm(Arc<CasmContractClass>),
}

impl TryInto<CasmContractClass> for CompiledClass {
type Error = ContractClassError;
fn try_into(self) -> Result<CasmContractClass, ContractClassError> {
match self {
CompiledClass::Casm(boxed) => Ok(*boxed),
CompiledClass::Casm(arc) => Ok((*arc).clone()),
_ => Err(ContractClassError::NotACasmContractClass),
}
}
Expand All @@ -39,7 +40,7 @@ impl TryInto<ContractClass> for CompiledClass {
type Error = ContractClassError;
fn try_into(self) -> Result<ContractClass, ContractClassError> {
match self {
CompiledClass::Deprecated(boxed) => Ok(*boxed),
CompiledClass::Deprecated(arc) => Ok((*arc).clone()),
_ => Err(ContractClassError::NotADeprecatedContractClass),
}
}
Expand Down Expand Up @@ -77,7 +78,7 @@ impl From<StarknetRsContractClass> for CompiledClass {

let casm_cc = CasmContractClass::from_contract_class(sierra_cc, true).unwrap();

CompiledClass::Casm(Box::new(casm_cc))
CompiledClass::Casm(Arc::new(casm_cc))
}
Legacy(_deprecated_contract_class) => {
let as_str = decode_reader(_deprecated_contract_class.program).unwrap();
Expand Down Expand Up @@ -148,7 +149,7 @@ impl From<StarknetRsContractClass> for CompiledClass {
let v = serde_json::to_value(serialized_cc).unwrap();
let hinted_class_hash = compute_hinted_class_hash(&v).unwrap();

CompiledClass::Deprecated(Box::new(ContractClass {
CompiledClass::Deprecated(Arc::new(ContractClass {
program,
entry_points_by_type,
abi,
Expand Down
38 changes: 31 additions & 7 deletions src/state/cached_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ pub type CasmClassCache = HashMap<ClassHash, CasmContractClass>;

pub const UNINITIALIZED_CLASS_HASH: &ClassHash = b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00";

/// Represents a cached state of contract classes with optional caches.
#[derive(Default, Clone, Debug, Eq, Getters, MutGetters, PartialEq)]
pub struct CachedState<T: StateReader> {
pub state_reader: Arc<T>,
Expand All @@ -37,6 +38,7 @@ pub struct CachedState<T: StateReader> {
}

impl<T: StateReader> CachedState<T> {
/// Constructor, creates a new cached state.
pub fn new(
state_reader: Arc<T>,
contract_class_cache: Option<ContractClassCache>,
Expand All @@ -50,6 +52,7 @@ impl<T: StateReader> CachedState<T> {
}
}

/// Creates a CachedState for testing purposes.
pub fn new_for_testing(
state_reader: Arc<T>,
contract_classes: Option<ContractClassCache>,
Expand All @@ -64,6 +67,7 @@ impl<T: StateReader> CachedState<T> {
}
}

/// Sets the contract classes cache.
pub fn set_contract_classes(
&mut self,
contract_classes: ContractClassCache,
Expand All @@ -75,6 +79,7 @@ impl<T: StateReader> CachedState<T> {
Ok(())
}

/// Returns the casm classes.
#[allow(dead_code)]
pub(crate) fn get_casm_classes(&mut self) -> Result<&CasmClassCache, StateError> {
self.casm_contract_classes
Expand All @@ -84,6 +89,7 @@ impl<T: StateReader> CachedState<T> {
}

impl<T: StateReader> StateReader for CachedState<T> {
/// Returns the class hash for a given contract address.
fn get_class_hash_at(&self, contract_address: &Address) -> Result<ClassHash, StateError> {
if self.cache.get_class_hash(contract_address).is_none() {
match self.state_reader.get_class_hash_at(contract_address) {
Expand All @@ -105,6 +111,7 @@ impl<T: StateReader> StateReader for CachedState<T> {
.cloned()
}

/// Returns the nonce for a given contract address.
fn get_nonce_at(&self, contract_address: &Address) -> Result<Felt252, StateError> {
if self.cache.get_nonce(contract_address).is_none() {
return self.state_reader.get_nonce_at(contract_address);
Expand All @@ -115,6 +122,7 @@ impl<T: StateReader> StateReader for CachedState<T> {
.cloned()
}

/// Returns storage data for a given storage entry.
fn get_storage_at(&self, storage_entry: &StorageEntry) -> Result<Felt252, StateError> {
if self.cache.get_storage(storage_entry).is_none() {
match self.state_reader.get_storage_at(storage_entry) {
Expand All @@ -140,6 +148,7 @@ impl<T: StateReader> StateReader for CachedState<T> {
}

// TODO: check if that the proper way to store it (converting hash to address)
/// Returned the compiled class hash for a given class hash.
fn get_compiled_class_hash(&self, class_hash: &ClassHash) -> Result<ClassHash, StateError> {
if self
.cache
Expand All @@ -156,6 +165,7 @@ impl<T: StateReader> StateReader for CachedState<T> {
.cloned()
}

/// Returns the contract class for a given class hash.
fn get_contract_class(&self, class_hash: &ClassHash) -> Result<CompiledClass, StateError> {
// This method can receive both compiled_class_hash & class_hash and return both casm and deprecated contract classes
//, which can be on the cache or on the state_reader, different cases will be described below:
Expand All @@ -170,15 +180,15 @@ impl<T: StateReader> StateReader for CachedState<T> {
.as_ref()
.and_then(|x| x.get(class_hash))
{
return Ok(CompiledClass::Deprecated(Box::new(compiled_class.clone())));
return Ok(CompiledClass::Deprecated(Arc::new(compiled_class.clone())));
}
// I: CASM CONTRACT CLASS : COMPILED_CLASS_HASH
if let Some(compiled_class) = self
.casm_contract_classes
.as_ref()
.and_then(|x| x.get(class_hash))
{
return Ok(CompiledClass::Casm(Box::new(compiled_class.clone())));
return Ok(CompiledClass::Casm(Arc::new(compiled_class.clone())));
}
// I: CASM CONTRACT CLASS : CLASS_HASH
if let Some(compiled_class_hash) =
Expand All @@ -189,7 +199,7 @@ impl<T: StateReader> StateReader for CachedState<T> {
.as_ref()
.and_then(|m| m.get(compiled_class_hash))
{
return Ok(CompiledClass::Casm(Box::new(casm_class.clone())));
return Ok(CompiledClass::Casm(Arc::new(casm_class.clone())));
}
}
// II: FETCHING FROM STATE_READER
Expand All @@ -198,6 +208,7 @@ impl<T: StateReader> StateReader for CachedState<T> {
}

impl<T: StateReader> State for CachedState<T> {
/// Stores a contract class in the cache.
fn set_contract_class(
&mut self,
class_hash: &ClassHash,
Expand All @@ -215,6 +226,7 @@ impl<T: StateReader> State for CachedState<T> {
Ok(())
}

/// Deploys a new contract and updates the cache.
fn deploy_contract(
&mut self,
deploy_contract_address: Address,
Expand Down Expand Up @@ -433,15 +445,15 @@ impl<T: StateReader> State for CachedState<T> {
.as_ref()
.and_then(|x| x.get(class_hash))
{
return Ok(CompiledClass::Deprecated(Box::new(compiled_class.clone())));
return Ok(CompiledClass::Deprecated(Arc::new(compiled_class.clone())));
}
// I: CASM CONTRACT CLASS : COMPILED_CLASS_HASH
if let Some(compiled_class) = self
.casm_contract_classes
.as_ref()
.and_then(|x| x.get(class_hash))
{
return Ok(CompiledClass::Casm(Box::new(compiled_class.clone())));
return Ok(CompiledClass::Casm(Arc::new(compiled_class.clone())));
}
// I: CASM CONTRACT CLASS : CLASS_HASH
if let Some(compiled_class_hash) =
Expand All @@ -452,7 +464,7 @@ impl<T: StateReader> State for CachedState<T> {
.as_ref()
.and_then(|m| m.get(compiled_class_hash))
{
return Ok(CompiledClass::Casm(Box::new(casm_class.clone())));
return Ok(CompiledClass::Casm(Arc::new(casm_class.clone())));
}
}
// II: FETCHING FROM STATE_READER
Expand All @@ -463,7 +475,7 @@ impl<T: StateReader> State for CachedState<T> {
let compiled_class_hash = self.get_compiled_class_hash(class_hash)?;
self.casm_contract_classes
.as_mut()
.and_then(|m| m.insert(compiled_class_hash, *class.clone()));
.and_then(|m| m.insert(compiled_class_hash, class.as_ref().clone()));
}
CompiledClass::Deprecated(ref contract) => {
self.set_contract_class(class_hash, &contract.clone())?
Expand All @@ -481,6 +493,8 @@ mod tests {

use num_traits::One;

/// Test checks if class hashes and nonces are correctly fetched from the state reader.
/// It also tests the increment_nonce method.
#[test]
fn get_class_hash_and_nonce_from_state_reader() {
let mut state_reader = InMemoryStateReader::new(
Expand Down Expand Up @@ -522,6 +536,7 @@ mod tests {
);
}

/// This test checks if the contract class is correctly fetched from the state reader.
#[test]
fn get_contract_class_from_state_reader() {
let mut state_reader = InMemoryStateReader::new(
Expand Down Expand Up @@ -554,6 +569,7 @@ mod tests {
);
}

/// This test verifies the correct handling of storage in the cached state.
#[test]
fn cached_state_storage_test() {
let mut cached_state =
Expand All @@ -572,6 +588,7 @@ mod tests {
.is_zero());
}

/// This test checks if deploying a contract works as expected.
#[test]
fn cached_state_deploy_contract_test() {
let state_reader = Arc::new(InMemoryStateReader::default());
Expand All @@ -585,6 +602,7 @@ mod tests {
.is_ok());
}

/// This test verifies the set and get storage values in the cached state.
#[test]
fn get_and_set_storage() {
let state_reader = Arc::new(InMemoryStateReader::default());
Expand All @@ -611,6 +629,7 @@ mod tests {
assert_eq!(new_result.unwrap(), new_value);
}

/// This test ensures that an error is thrown when trying to set contract classes twice.
#[test]
fn set_contract_classes_twice_error_test() {
let state_reader = InMemoryStateReader::new(
Expand All @@ -631,6 +650,7 @@ mod tests {
assert_matches!(result, StateError::AssignedContractClassCache);
}

/// This test ensures that an error is thrown if a contract address is out of range.
#[test]
fn deploy_contract_address_out_of_range_error_test() {
let state_reader = InMemoryStateReader::new(
Expand All @@ -656,6 +676,7 @@ mod tests {
);
}

/// This test ensures that an error is thrown if a contract address is already in use.
#[test]
fn deploy_contract_address_in_use_error_test() {
let state_reader = InMemoryStateReader::new(
Expand Down Expand Up @@ -684,6 +705,7 @@ mod tests {
);
}

/// This test checks if replacing a contract in the cached state works correctly.
#[test]
fn cached_state_replace_contract_test() {
let state_reader = InMemoryStateReader::new(
Expand Down Expand Up @@ -713,6 +735,7 @@ mod tests {
);
}

/// This test verifies if the cached state's internal structures are correctly updated after applying a state update.
#[test]
fn cached_state_apply_state_update() {
let state_reader = InMemoryStateReader::new(
Expand Down Expand Up @@ -757,6 +780,7 @@ mod tests {
assert!(cached_state.cache.class_hash_initial_values.is_empty());
}

/// This test calculate the number of actual storage changes.
#[test]
fn count_actual_storage_changes_test() {
let state_reader = InMemoryStateReader::default();
Expand Down
8 changes: 4 additions & 4 deletions src/state/in_memory_state_reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use crate::{
use cairo_vm::felt::Felt252;
use getset::{Getters, MutGetters};
use num_traits::Zero;
use std::collections::HashMap;
use std::{collections::HashMap, sync::Arc};

/// A [StateReader] that holds all the data in memory.
///
Expand Down Expand Up @@ -80,10 +80,10 @@ impl InMemoryStateReader {
compiled_class_hash: &CompiledClassHash,
) -> Result<CompiledClass, StateError> {
if let Some(compiled_class) = self.casm_contract_classes.get(compiled_class_hash) {
return Ok(CompiledClass::Casm(Box::new(compiled_class.clone())));
return Ok(CompiledClass::Casm(Arc::new(compiled_class.clone())));
}
if let Some(compiled_class) = self.class_hash_to_contract_class.get(compiled_class_hash) {
return Ok(CompiledClass::Deprecated(Box::new(compiled_class.clone())));
return Ok(CompiledClass::Deprecated(Arc::new(compiled_class.clone())));
}
Err(StateError::NoneCompiledClass(*compiled_class_hash))
}
Expand Down Expand Up @@ -128,7 +128,7 @@ impl StateReader for InMemoryStateReader {
fn get_contract_class(&self, class_hash: &ClassHash) -> Result<CompiledClass, StateError> {
// Deprecated contract classes dont have a compiled_class_hash, we dont need to fetch it
if let Some(compiled_class) = self.class_hash_to_contract_class.get(class_hash) {
return Ok(CompiledClass::Deprecated(Box::new(compiled_class.clone())));
return Ok(CompiledClass::Deprecated(Arc::new(compiled_class.clone())));
}
let compiled_class_hash = self.get_compiled_class_hash(class_hash)?;
if compiled_class_hash != *UNINITIALIZED_CLASS_HASH {
Expand Down
4 changes: 3 additions & 1 deletion src/state/state_cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,8 @@ impl StateCache {
/// Unit tests for StateCache
#[cfg(test)]
mod tests {
use std::sync::Arc;

use crate::services::api::contract_classes::deprecated_contract_class::ContractClass;

use super::*;
Expand All @@ -230,7 +232,7 @@ mod tests {
let contract_class =
ContractClass::from_path("starknet_programs/raw_contract_classes/class_with_abi.json")
.unwrap();
let compiled_class = CompiledClass::Deprecated(Box::new(contract_class));
let compiled_class = CompiledClass::Deprecated(Arc::new(contract_class));
let class_hash_to_compiled_class_hash = HashMap::from([([8; 32], compiled_class)]);
let address_to_nonce = HashMap::from([(Address(9.into()), 12.into())]);
let storage_updates = HashMap::from([((Address(4.into()), [1; 32]), 18.into())]);
Expand Down
Loading