@@ -99,6 +99,19 @@ impl<T: StateReader> CachedState<T> {
99
99
self . contract_classes = contract_classes;
100
100
Ok ( ( ) )
101
101
}
102
+
103
+ /// Creates a copy of this state with an empty cache for saving changes and applying them
104
+ /// later.
105
+ pub fn create_transactional ( & self ) -> TransactionalCachedState < T > {
106
+ let state_reader = Arc :: new ( TransactionalCachedStateReader :: new ( self ) ) ;
107
+ CachedState {
108
+ state_reader,
109
+ cache : self . cache . clone ( ) ,
110
+ contract_classes : self . contract_classes . clone ( ) ,
111
+ cache_hits : 0 ,
112
+ cache_misses : 0 ,
113
+ }
114
+ }
102
115
}
103
116
104
117
impl < T : StateReader > StateReader for CachedState < T > {
@@ -134,19 +147,13 @@ impl<T: StateReader> StateReader for CachedState<T> {
134
147
// TODO: check if that the proper way to store it (converting hash to address)
135
148
/// Returned the compiled class hash for a given class hash.
136
149
fn get_compiled_class_hash ( & self , class_hash : & ClassHash ) -> Result < ClassHash , StateError > {
137
- if self
138
- . cache
139
- . class_hash_to_compiled_class_hash
140
- . get ( class_hash)
141
- . is_none ( )
150
+ if let Some ( compiled_class_hash) =
151
+ self . cache . class_hash_to_compiled_class_hash . get ( class_hash)
142
152
{
143
- return self . state_reader . get_compiled_class_hash ( class_hash) ;
153
+ Ok ( * compiled_class_hash)
154
+ } else {
155
+ self . state_reader . get_compiled_class_hash ( class_hash)
144
156
}
145
- self . cache
146
- . class_hash_to_compiled_class_hash
147
- . get ( class_hash)
148
- . ok_or_else ( || StateError :: NoneCompiledClass ( * class_hash) )
149
- . cloned ( )
150
157
}
151
158
152
159
/// Returns the contract class for a given class hash.
@@ -438,6 +445,114 @@ impl<T: StateReader> State for CachedState<T> {
438
445
}
439
446
}
440
447
448
+ /// A CachedState which has access to another, "parent" state, used for executing transactions
449
+ /// without commiting changes to the parent.
450
+ pub type TransactionalCachedState < ' a , T > = CachedState < TransactionalCachedStateReader < ' a , T > > ;
451
+
452
+ /// State reader used for transactional states which allows to check the parent state's cache and
453
+ /// state reader if a transactional cache miss happens.
454
+ ///
455
+ /// In practice this will act as a way to access the parent state's cache and other fields,
456
+ /// without referencing the whole parent state, so there's no need to adapt state-modifying
457
+ /// functions in the case that a transactional state is needed.
458
+ #[ derive( Debug , MutGetters , Getters , PartialEq , Clone ) ]
459
+ pub struct TransactionalCachedStateReader < ' a , T : StateReader > {
460
+ /// The parent state's state_reader
461
+ #[ get( get = "pub" ) ]
462
+ pub ( crate ) state_reader : Arc < T > ,
463
+ /// The parent state's cache
464
+ #[ get( get = "pub" ) ]
465
+ pub ( crate ) cache : & ' a StateCache ,
466
+ /// The parent state's contract_classes
467
+ #[ get( get = "pub" ) ]
468
+ pub ( crate ) contract_classes : ContractClassCache ,
469
+ }
470
+
471
+ impl < ' a , T : StateReader > TransactionalCachedStateReader < ' a , T > {
472
+ fn new ( state : & ' a CachedState < T > ) -> Self {
473
+ Self {
474
+ state_reader : state. state_reader . clone ( ) ,
475
+ cache : & state. cache ,
476
+ contract_classes : state. contract_classes . clone ( ) ,
477
+ }
478
+ }
479
+ }
480
+
481
+ impl < ' a , T : StateReader > StateReader for TransactionalCachedStateReader < ' a , T > {
482
+ /// Returns the class hash for a given contract address.
483
+ /// Returns zero as default value if missing
484
+ fn get_class_hash_at ( & self , contract_address : & Address ) -> Result < ClassHash , StateError > {
485
+ self . cache
486
+ . get_class_hash ( contract_address)
487
+ . map ( |a| Ok ( * a) )
488
+ . unwrap_or_else ( || self . state_reader . get_class_hash_at ( contract_address) )
489
+ }
490
+
491
+ /// Returns the nonce for a given contract address.
492
+ fn get_nonce_at ( & self , contract_address : & Address ) -> Result < Felt252 , StateError > {
493
+ if self . cache . get_nonce ( contract_address) . is_none ( ) {
494
+ return self . state_reader . get_nonce_at ( contract_address) ;
495
+ }
496
+ self . cache
497
+ . get_nonce ( contract_address)
498
+ . ok_or_else ( || StateError :: NoneNonce ( contract_address. clone ( ) ) )
499
+ . cloned ( )
500
+ }
501
+
502
+ /// Returns storage data for a given storage entry.
503
+ /// Returns zero as default value if missing
504
+ fn get_storage_at ( & self , storage_entry : & StorageEntry ) -> Result < Felt252 , StateError > {
505
+ self . cache
506
+ . get_storage ( storage_entry)
507
+ . map ( |v| Ok ( v. clone ( ) ) )
508
+ . unwrap_or_else ( || self . state_reader . get_storage_at ( storage_entry) )
509
+ }
510
+
511
+ // TODO: check if that the proper way to store it (converting hash to address)
512
+ /// Returned the compiled class hash for a given class hash.
513
+ fn get_compiled_class_hash ( & self , class_hash : & ClassHash ) -> Result < ClassHash , StateError > {
514
+ if self
515
+ . cache
516
+ . class_hash_to_compiled_class_hash
517
+ . get ( class_hash)
518
+ . is_none ( )
519
+ {
520
+ return self . state_reader . get_compiled_class_hash ( class_hash) ;
521
+ }
522
+ self . cache
523
+ . class_hash_to_compiled_class_hash
524
+ . get ( class_hash)
525
+ . ok_or_else ( || StateError :: NoneCompiledClass ( * class_hash) )
526
+ . cloned ( )
527
+ }
528
+
529
+ /// Returns the contract class for a given class hash.
530
+ fn get_contract_class ( & self , class_hash : & ClassHash ) -> Result < CompiledClass , StateError > {
531
+ // This method can receive both compiled_class_hash & class_hash and return both casm and deprecated contract classes
532
+ //, which can be on the cache or on the state_reader, different cases will be described below:
533
+ if class_hash == UNINITIALIZED_CLASS_HASH {
534
+ return Err ( StateError :: UninitiaizedClassHash ) ;
535
+ }
536
+
537
+ // I: FETCHING FROM CACHE
538
+ if let Some ( compiled_class) = self . contract_classes . get ( class_hash) {
539
+ return Ok ( compiled_class. clone ( ) ) ;
540
+ }
541
+
542
+ // I: CASM CONTRACT CLASS : CLASS_HASH
543
+ if let Some ( compiled_class_hash) =
544
+ self . cache . class_hash_to_compiled_class_hash . get ( class_hash)
545
+ {
546
+ if let Some ( casm_class) = self . contract_classes . get ( compiled_class_hash) {
547
+ return Ok ( casm_class. clone ( ) ) ;
548
+ }
549
+ }
550
+
551
+ // II: FETCHING FROM STATE_READER
552
+ self . state_reader . get_contract_class ( class_hash)
553
+ }
554
+ }
555
+
441
556
impl < T : StateReader > CachedState < T > {
442
557
// Updates the cache's storage_initial_values according to those in storage_writes
443
558
// If a key is present in the storage_writes but not in storage_initial_values,
0 commit comments