Skip to content
This repository was archived by the owner on Jul 22, 2024. It is now read-only.

Commit 64f6e70

Browse files
fguthmannfannyguthmannjuanbono
committed
change Box to Arc for CompiledClass (#926)
Co-authored-by: fannyguthmann <fanny.guthmann@post.idc.ac.il> Co-authored-by: Juan Bono <juanbono94@gmail.com>
1 parent 2af401c commit 64f6e70

File tree

10 files changed

+68
-37
lines changed

10 files changed

+68
-37
lines changed

src/execution/execution_entry_point.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
use std::sync::Arc;
2+
13
use crate::services::api::contract_classes::deprecated_contract_class::{
24
ContractEntryPoint, EntryPointType,
35
};
@@ -331,7 +333,7 @@ impl ExecutionEntryPoint {
331333
resources_manager: &mut ExecutionResourcesManager,
332334
block_context: &BlockContext,
333335
tx_execution_context: &mut TransactionExecutionContext,
334-
contract_class: Box<ContractClass>,
336+
contract_class: Arc<ContractClass>,
335337
class_hash: [u8; 32],
336338
) -> Result<CallInfo, TransactionError> {
337339
let previous_cairo_usage = resources_manager.cairo_usage.clone();
@@ -436,7 +438,7 @@ impl ExecutionEntryPoint {
436438
resources_manager: &mut ExecutionResourcesManager,
437439
block_context: &BlockContext,
438440
tx_execution_context: &mut TransactionExecutionContext,
439-
contract_class: Box<CasmContractClass>,
441+
contract_class: Arc<CasmContractClass>,
440442
class_hash: [u8; 32],
441443
support_reverted: bool,
442444
) -> Result<CallInfo, TransactionError> {

src/services/api/contract_classes/compiled_class.rs

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
use std::collections::HashMap;
22
use std::io::{self, Read};
3+
use std::sync::Arc;
34

45
use crate::core::contract_address::{compute_hinted_class_hash, CairoProgramToHash};
56
use crate::services::api::contract_class_errors::ContractClassError;
@@ -21,15 +22,15 @@ use starknet::core::types::ContractClass::{Legacy, Sierra};
2122

2223
#[derive(Clone, PartialEq, Eq, Debug)]
2324
pub enum CompiledClass {
24-
Deprecated(Box<ContractClass>),
25-
Casm(Box<CasmContractClass>),
25+
Deprecated(Arc<ContractClass>),
26+
Casm(Arc<CasmContractClass>),
2627
}
2728

2829
impl TryInto<CasmContractClass> for CompiledClass {
2930
type Error = ContractClassError;
3031
fn try_into(self) -> Result<CasmContractClass, ContractClassError> {
3132
match self {
32-
CompiledClass::Casm(boxed) => Ok(*boxed),
33+
CompiledClass::Casm(arc) => Ok((*arc).clone()),
3334
_ => Err(ContractClassError::NotACasmContractClass),
3435
}
3536
}
@@ -39,7 +40,7 @@ impl TryInto<ContractClass> for CompiledClass {
3940
type Error = ContractClassError;
4041
fn try_into(self) -> Result<ContractClass, ContractClassError> {
4142
match self {
42-
CompiledClass::Deprecated(boxed) => Ok(*boxed),
43+
CompiledClass::Deprecated(arc) => Ok((*arc).clone()),
4344
_ => Err(ContractClassError::NotADeprecatedContractClass),
4445
}
4546
}
@@ -77,7 +78,7 @@ impl From<StarknetRsContractClass> for CompiledClass {
7778

7879
let casm_cc = CasmContractClass::from_contract_class(sierra_cc, true).unwrap();
7980

80-
CompiledClass::Casm(Box::new(casm_cc))
81+
CompiledClass::Casm(Arc::new(casm_cc))
8182
}
8283
Legacy(_deprecated_contract_class) => {
8384
let as_str = decode_reader(_deprecated_contract_class.program).unwrap();
@@ -148,7 +149,7 @@ impl From<StarknetRsContractClass> for CompiledClass {
148149
let v = serde_json::to_value(serialized_cc).unwrap();
149150
let hinted_class_hash = compute_hinted_class_hash(&v).unwrap();
150151

151-
CompiledClass::Deprecated(Box::new(ContractClass {
152+
CompiledClass::Deprecated(Arc::new(ContractClass {
152153
program,
153154
entry_points_by_type,
154155
abi,

src/state/cached_state.rs

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ pub type CasmClassCache = HashMap<ClassHash, CasmContractClass>;
2525

2626
pub const UNINITIALIZED_CLASS_HASH: &ClassHash = b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00";
2727

28+
/// Represents a cached state of contract classes with optional caches.
2829
#[derive(Default, Clone, Debug, Eq, Getters, MutGetters, PartialEq)]
2930
pub struct CachedState<T: StateReader> {
3031
pub state_reader: Arc<T>,
@@ -37,6 +38,7 @@ pub struct CachedState<T: StateReader> {
3738
}
3839

3940
impl<T: StateReader> CachedState<T> {
41+
/// Constructor, creates a new cached state.
4042
pub fn new(
4143
state_reader: Arc<T>,
4244
contract_class_cache: Option<ContractClassCache>,
@@ -50,6 +52,7 @@ impl<T: StateReader> CachedState<T> {
5052
}
5153
}
5254

55+
/// Creates a CachedState for testing purposes.
5356
pub fn new_for_testing(
5457
state_reader: Arc<T>,
5558
contract_classes: Option<ContractClassCache>,
@@ -64,6 +67,7 @@ impl<T: StateReader> CachedState<T> {
6467
}
6568
}
6669

70+
/// Sets the contract classes cache.
6771
pub fn set_contract_classes(
6872
&mut self,
6973
contract_classes: ContractClassCache,
@@ -75,6 +79,7 @@ impl<T: StateReader> CachedState<T> {
7579
Ok(())
7680
}
7781

82+
/// Returns the casm classes.
7883
#[allow(dead_code)]
7984
pub(crate) fn get_casm_classes(&mut self) -> Result<&CasmClassCache, StateError> {
8085
self.casm_contract_classes
@@ -84,6 +89,7 @@ impl<T: StateReader> CachedState<T> {
8489
}
8590

8691
impl<T: StateReader> StateReader for CachedState<T> {
92+
/// Returns the class hash for a given contract address.
8793
fn get_class_hash_at(&self, contract_address: &Address) -> Result<ClassHash, StateError> {
8894
if self.cache.get_class_hash(contract_address).is_none() {
8995
match self.state_reader.get_class_hash_at(contract_address) {
@@ -105,6 +111,7 @@ impl<T: StateReader> StateReader for CachedState<T> {
105111
.cloned()
106112
}
107113

114+
/// Returns the nonce for a given contract address.
108115
fn get_nonce_at(&self, contract_address: &Address) -> Result<Felt252, StateError> {
109116
if self.cache.get_nonce(contract_address).is_none() {
110117
return self.state_reader.get_nonce_at(contract_address);
@@ -115,6 +122,7 @@ impl<T: StateReader> StateReader for CachedState<T> {
115122
.cloned()
116123
}
117124

125+
/// Returns storage data for a given storage entry.
118126
fn get_storage_at(&self, storage_entry: &StorageEntry) -> Result<Felt252, StateError> {
119127
if self.cache.get_storage(storage_entry).is_none() {
120128
match self.state_reader.get_storage_at(storage_entry) {
@@ -140,6 +148,7 @@ impl<T: StateReader> StateReader for CachedState<T> {
140148
}
141149

142150
// TODO: check if that the proper way to store it (converting hash to address)
151+
/// Returned the compiled class hash for a given class hash.
143152
fn get_compiled_class_hash(&self, class_hash: &ClassHash) -> Result<ClassHash, StateError> {
144153
if self
145154
.cache
@@ -156,6 +165,7 @@ impl<T: StateReader> StateReader for CachedState<T> {
156165
.cloned()
157166
}
158167

168+
/// Returns the contract class for a given class hash.
159169
fn get_contract_class(&self, class_hash: &ClassHash) -> Result<CompiledClass, StateError> {
160170
// This method can receive both compiled_class_hash & class_hash and return both casm and deprecated contract classes
161171
//, which can be on the cache or on the state_reader, different cases will be described below:
@@ -170,15 +180,15 @@ impl<T: StateReader> StateReader for CachedState<T> {
170180
.as_ref()
171181
.and_then(|x| x.get(class_hash))
172182
{
173-
return Ok(CompiledClass::Deprecated(Box::new(compiled_class.clone())));
183+
return Ok(CompiledClass::Deprecated(Arc::new(compiled_class.clone())));
174184
}
175185
// I: CASM CONTRACT CLASS : COMPILED_CLASS_HASH
176186
if let Some(compiled_class) = self
177187
.casm_contract_classes
178188
.as_ref()
179189
.and_then(|x| x.get(class_hash))
180190
{
181-
return Ok(CompiledClass::Casm(Box::new(compiled_class.clone())));
191+
return Ok(CompiledClass::Casm(Arc::new(compiled_class.clone())));
182192
}
183193
// I: CASM CONTRACT CLASS : CLASS_HASH
184194
if let Some(compiled_class_hash) =
@@ -189,7 +199,7 @@ impl<T: StateReader> StateReader for CachedState<T> {
189199
.as_ref()
190200
.and_then(|m| m.get(compiled_class_hash))
191201
{
192-
return Ok(CompiledClass::Casm(Box::new(casm_class.clone())));
202+
return Ok(CompiledClass::Casm(Arc::new(casm_class.clone())));
193203
}
194204
}
195205
// II: FETCHING FROM STATE_READER
@@ -198,6 +208,7 @@ impl<T: StateReader> StateReader for CachedState<T> {
198208
}
199209

200210
impl<T: StateReader> State for CachedState<T> {
211+
/// Stores a contract class in the cache.
201212
fn set_contract_class(
202213
&mut self,
203214
class_hash: &ClassHash,
@@ -215,6 +226,7 @@ impl<T: StateReader> State for CachedState<T> {
215226
Ok(())
216227
}
217228

229+
/// Deploys a new contract and updates the cache.
218230
fn deploy_contract(
219231
&mut self,
220232
deploy_contract_address: Address,
@@ -433,15 +445,15 @@ impl<T: StateReader> State for CachedState<T> {
433445
.as_ref()
434446
.and_then(|x| x.get(class_hash))
435447
{
436-
return Ok(CompiledClass::Deprecated(Box::new(compiled_class.clone())));
448+
return Ok(CompiledClass::Deprecated(Arc::new(compiled_class.clone())));
437449
}
438450
// I: CASM CONTRACT CLASS : COMPILED_CLASS_HASH
439451
if let Some(compiled_class) = self
440452
.casm_contract_classes
441453
.as_ref()
442454
.and_then(|x| x.get(class_hash))
443455
{
444-
return Ok(CompiledClass::Casm(Box::new(compiled_class.clone())));
456+
return Ok(CompiledClass::Casm(Arc::new(compiled_class.clone())));
445457
}
446458
// I: CASM CONTRACT CLASS : CLASS_HASH
447459
if let Some(compiled_class_hash) =
@@ -452,7 +464,7 @@ impl<T: StateReader> State for CachedState<T> {
452464
.as_ref()
453465
.and_then(|m| m.get(compiled_class_hash))
454466
{
455-
return Ok(CompiledClass::Casm(Box::new(casm_class.clone())));
467+
return Ok(CompiledClass::Casm(Arc::new(casm_class.clone())));
456468
}
457469
}
458470
// II: FETCHING FROM STATE_READER
@@ -463,7 +475,7 @@ impl<T: StateReader> State for CachedState<T> {
463475
let compiled_class_hash = self.get_compiled_class_hash(class_hash)?;
464476
self.casm_contract_classes
465477
.as_mut()
466-
.and_then(|m| m.insert(compiled_class_hash, *class.clone()));
478+
.and_then(|m| m.insert(compiled_class_hash, class.as_ref().clone()));
467479
}
468480
CompiledClass::Deprecated(ref contract) => {
469481
self.set_contract_class(class_hash, &contract.clone())?
@@ -481,6 +493,8 @@ mod tests {
481493

482494
use num_traits::One;
483495

496+
/// Test checks if class hashes and nonces are correctly fetched from the state reader.
497+
/// It also tests the increment_nonce method.
484498
#[test]
485499
fn get_class_hash_and_nonce_from_state_reader() {
486500
let mut state_reader = InMemoryStateReader::new(
@@ -522,6 +536,7 @@ mod tests {
522536
);
523537
}
524538

539+
/// This test checks if the contract class is correctly fetched from the state reader.
525540
#[test]
526541
fn get_contract_class_from_state_reader() {
527542
let mut state_reader = InMemoryStateReader::new(
@@ -554,6 +569,7 @@ mod tests {
554569
);
555570
}
556571

572+
/// This test verifies the correct handling of storage in the cached state.
557573
#[test]
558574
fn cached_state_storage_test() {
559575
let mut cached_state =
@@ -572,6 +588,7 @@ mod tests {
572588
.is_zero());
573589
}
574590

591+
/// This test checks if deploying a contract works as expected.
575592
#[test]
576593
fn cached_state_deploy_contract_test() {
577594
let state_reader = Arc::new(InMemoryStateReader::default());
@@ -585,6 +602,7 @@ mod tests {
585602
.is_ok());
586603
}
587604

605+
/// This test verifies the set and get storage values in the cached state.
588606
#[test]
589607
fn get_and_set_storage() {
590608
let state_reader = Arc::new(InMemoryStateReader::default());
@@ -611,6 +629,7 @@ mod tests {
611629
assert_eq!(new_result.unwrap(), new_value);
612630
}
613631

632+
/// This test ensures that an error is thrown when trying to set contract classes twice.
614633
#[test]
615634
fn set_contract_classes_twice_error_test() {
616635
let state_reader = InMemoryStateReader::new(
@@ -631,6 +650,7 @@ mod tests {
631650
assert_matches!(result, StateError::AssignedContractClassCache);
632651
}
633652

653+
/// This test ensures that an error is thrown if a contract address is out of range.
634654
#[test]
635655
fn deploy_contract_address_out_of_range_error_test() {
636656
let state_reader = InMemoryStateReader::new(
@@ -656,6 +676,7 @@ mod tests {
656676
);
657677
}
658678

679+
/// This test ensures that an error is thrown if a contract address is already in use.
659680
#[test]
660681
fn deploy_contract_address_in_use_error_test() {
661682
let state_reader = InMemoryStateReader::new(
@@ -684,6 +705,7 @@ mod tests {
684705
);
685706
}
686707

708+
/// This test checks if replacing a contract in the cached state works correctly.
687709
#[test]
688710
fn cached_state_replace_contract_test() {
689711
let state_reader = InMemoryStateReader::new(
@@ -713,6 +735,7 @@ mod tests {
713735
);
714736
}
715737

738+
/// This test verifies if the cached state's internal structures are correctly updated after applying a state update.
716739
#[test]
717740
fn cached_state_apply_state_update() {
718741
let state_reader = InMemoryStateReader::new(
@@ -757,6 +780,7 @@ mod tests {
757780
assert!(cached_state.cache.class_hash_initial_values.is_empty());
758781
}
759782

783+
/// This test calculate the number of actual storage changes.
760784
#[test]
761785
fn count_actual_storage_changes_test() {
762786
let state_reader = InMemoryStateReader::default();

src/state/in_memory_state_reader.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ use crate::{
1313
use cairo_vm::felt::Felt252;
1414
use getset::{Getters, MutGetters};
1515
use num_traits::Zero;
16-
use std::collections::HashMap;
16+
use std::{collections::HashMap, sync::Arc};
1717

1818
/// A [StateReader] that holds all the data in memory.
1919
///
@@ -80,10 +80,10 @@ impl InMemoryStateReader {
8080
compiled_class_hash: &CompiledClassHash,
8181
) -> Result<CompiledClass, StateError> {
8282
if let Some(compiled_class) = self.casm_contract_classes.get(compiled_class_hash) {
83-
return Ok(CompiledClass::Casm(Box::new(compiled_class.clone())));
83+
return Ok(CompiledClass::Casm(Arc::new(compiled_class.clone())));
8484
}
8585
if let Some(compiled_class) = self.class_hash_to_contract_class.get(compiled_class_hash) {
86-
return Ok(CompiledClass::Deprecated(Box::new(compiled_class.clone())));
86+
return Ok(CompiledClass::Deprecated(Arc::new(compiled_class.clone())));
8787
}
8888
Err(StateError::NoneCompiledClass(*compiled_class_hash))
8989
}
@@ -128,7 +128,7 @@ impl StateReader for InMemoryStateReader {
128128
fn get_contract_class(&self, class_hash: &ClassHash) -> Result<CompiledClass, StateError> {
129129
// Deprecated contract classes dont have a compiled_class_hash, we dont need to fetch it
130130
if let Some(compiled_class) = self.class_hash_to_contract_class.get(class_hash) {
131-
return Ok(CompiledClass::Deprecated(Box::new(compiled_class.clone())));
131+
return Ok(CompiledClass::Deprecated(Arc::new(compiled_class.clone())));
132132
}
133133
let compiled_class_hash = self.get_compiled_class_hash(class_hash)?;
134134
if compiled_class_hash != *UNINITIALIZED_CLASS_HASH {

src/state/state_cache.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,8 @@ impl StateCache {
219219
/// Unit tests for StateCache
220220
#[cfg(test)]
221221
mod tests {
222+
use std::sync::Arc;
223+
222224
use crate::services::api::contract_classes::deprecated_contract_class::ContractClass;
223225

224226
use super::*;
@@ -230,7 +232,7 @@ mod tests {
230232
let contract_class =
231233
ContractClass::from_path("starknet_programs/raw_contract_classes/class_with_abi.json")
232234
.unwrap();
233-
let compiled_class = CompiledClass::Deprecated(Box::new(contract_class));
235+
let compiled_class = CompiledClass::Deprecated(Arc::new(contract_class));
234236
let class_hash_to_compiled_class_hash = HashMap::from([([8; 32], compiled_class)]);
235237
let address_to_nonce = HashMap::from([(Address(9.into()), 12.into())]);
236238
let storage_updates = HashMap::from([((Address(4.into()), [1; 32]), 18.into())]);

0 commit comments

Comments
 (0)