@@ -99,6 +99,19 @@ impl<T: StateReader> CachedState<T> {
9999 self . contract_classes = contract_classes;
100100 Ok ( ( ) )
101101 }
102+
103+ /// Creates a copy of this state with an empty cache for saving changes and applying them
104+ /// later.
105+ pub fn create_transactional ( & self ) -> TransactionalCachedState < T > {
106+ let state_reader = Arc :: new ( TransactionalCachedStateReader :: new ( self ) ) ;
107+ CachedState {
108+ state_reader,
109+ cache : self . cache . clone ( ) ,
110+ contract_classes : self . contract_classes . clone ( ) ,
111+ cache_hits : 0 ,
112+ cache_misses : 0 ,
113+ }
114+ }
102115}
103116
104117impl < T : StateReader > StateReader for CachedState < T > {
@@ -134,19 +147,13 @@ impl<T: StateReader> StateReader for CachedState<T> {
134147 // TODO: check if that the proper way to store it (converting hash to address)
135148 /// Returned the compiled class hash for a given class hash.
136149 fn get_compiled_class_hash ( & self , class_hash : & ClassHash ) -> Result < ClassHash , StateError > {
137- if self
138- . cache
139- . class_hash_to_compiled_class_hash
140- . get ( class_hash)
141- . is_none ( )
150+ if let Some ( compiled_class_hash) =
151+ self . cache . class_hash_to_compiled_class_hash . get ( class_hash)
142152 {
143- return self . state_reader . get_compiled_class_hash ( class_hash) ;
153+ Ok ( * compiled_class_hash)
154+ } else {
155+ self . state_reader . get_compiled_class_hash ( class_hash)
144156 }
145- self . cache
146- . class_hash_to_compiled_class_hash
147- . get ( class_hash)
148- . ok_or_else ( || StateError :: NoneCompiledClass ( * class_hash) )
149- . cloned ( )
150157 }
151158
152159 /// Returns the contract class for a given class hash.
@@ -438,6 +445,114 @@ impl<T: StateReader> State for CachedState<T> {
438445 }
439446}
440447
448+ /// A CachedState which has access to another, "parent" state, used for executing transactions
449+ /// without commiting changes to the parent.
450+ pub type TransactionalCachedState < ' a , T > = CachedState < TransactionalCachedStateReader < ' a , T > > ;
451+
452+ /// State reader used for transactional states which allows to check the parent state's cache and
453+ /// state reader if a transactional cache miss happens.
454+ ///
455+ /// In practice this will act as a way to access the parent state's cache and other fields,
456+ /// without referencing the whole parent state, so there's no need to adapt state-modifying
457+ /// functions in the case that a transactional state is needed.
458+ #[ derive( Debug , MutGetters , Getters , PartialEq , Clone ) ]
459+ pub struct TransactionalCachedStateReader < ' a , T : StateReader > {
460+ /// The parent state's state_reader
461+ #[ get( get = "pub" ) ]
462+ pub ( crate ) state_reader : Arc < T > ,
463+ /// The parent state's cache
464+ #[ get( get = "pub" ) ]
465+ pub ( crate ) cache : & ' a StateCache ,
466+ /// The parent state's contract_classes
467+ #[ get( get = "pub" ) ]
468+ pub ( crate ) contract_classes : ContractClassCache ,
469+ }
470+
471+ impl < ' a , T : StateReader > TransactionalCachedStateReader < ' a , T > {
472+ fn new ( state : & ' a CachedState < T > ) -> Self {
473+ Self {
474+ state_reader : state. state_reader . clone ( ) ,
475+ cache : & state. cache ,
476+ contract_classes : state. contract_classes . clone ( ) ,
477+ }
478+ }
479+ }
480+
481+ impl < ' a , T : StateReader > StateReader for TransactionalCachedStateReader < ' a , T > {
482+ /// Returns the class hash for a given contract address.
483+ /// Returns zero as default value if missing
484+ fn get_class_hash_at ( & self , contract_address : & Address ) -> Result < ClassHash , StateError > {
485+ self . cache
486+ . get_class_hash ( contract_address)
487+ . map ( |a| Ok ( * a) )
488+ . unwrap_or_else ( || self . state_reader . get_class_hash_at ( contract_address) )
489+ }
490+
491+ /// Returns the nonce for a given contract address.
492+ fn get_nonce_at ( & self , contract_address : & Address ) -> Result < Felt252 , StateError > {
493+ if self . cache . get_nonce ( contract_address) . is_none ( ) {
494+ return self . state_reader . get_nonce_at ( contract_address) ;
495+ }
496+ self . cache
497+ . get_nonce ( contract_address)
498+ . ok_or_else ( || StateError :: NoneNonce ( contract_address. clone ( ) ) )
499+ . cloned ( )
500+ }
501+
502+ /// Returns storage data for a given storage entry.
503+ /// Returns zero as default value if missing
504+ fn get_storage_at ( & self , storage_entry : & StorageEntry ) -> Result < Felt252 , StateError > {
505+ self . cache
506+ . get_storage ( storage_entry)
507+ . map ( |v| Ok ( v. clone ( ) ) )
508+ . unwrap_or_else ( || self . state_reader . get_storage_at ( storage_entry) )
509+ }
510+
511+ // TODO: check if that the proper way to store it (converting hash to address)
512+ /// Returned the compiled class hash for a given class hash.
513+ fn get_compiled_class_hash ( & self , class_hash : & ClassHash ) -> Result < ClassHash , StateError > {
514+ if self
515+ . cache
516+ . class_hash_to_compiled_class_hash
517+ . get ( class_hash)
518+ . is_none ( )
519+ {
520+ return self . state_reader . get_compiled_class_hash ( class_hash) ;
521+ }
522+ self . cache
523+ . class_hash_to_compiled_class_hash
524+ . get ( class_hash)
525+ . ok_or_else ( || StateError :: NoneCompiledClass ( * class_hash) )
526+ . cloned ( )
527+ }
528+
529+ /// Returns the contract class for a given class hash.
530+ fn get_contract_class ( & self , class_hash : & ClassHash ) -> Result < CompiledClass , StateError > {
531+ // This method can receive both compiled_class_hash & class_hash and return both casm and deprecated contract classes
532+ //, which can be on the cache or on the state_reader, different cases will be described below:
533+ if class_hash == UNINITIALIZED_CLASS_HASH {
534+ return Err ( StateError :: UninitiaizedClassHash ) ;
535+ }
536+
537+ // I: FETCHING FROM CACHE
538+ if let Some ( compiled_class) = self . contract_classes . get ( class_hash) {
539+ return Ok ( compiled_class. clone ( ) ) ;
540+ }
541+
542+ // I: CASM CONTRACT CLASS : CLASS_HASH
543+ if let Some ( compiled_class_hash) =
544+ self . cache . class_hash_to_compiled_class_hash . get ( class_hash)
545+ {
546+ if let Some ( casm_class) = self . contract_classes . get ( compiled_class_hash) {
547+ return Ok ( casm_class. clone ( ) ) ;
548+ }
549+ }
550+
551+ // II: FETCHING FROM STATE_READER
552+ self . state_reader . get_contract_class ( class_hash)
553+ }
554+ }
555+
441556impl < T : StateReader > CachedState < T > {
442557 // Updates the cache's storage_initial_values according to those in storage_writes
443558 // If a key is present in the storage_writes but not in storage_initial_values,
0 commit comments