@@ -6,7 +6,10 @@ use crate::{
66 core:: errors:: state_errors:: StateError ,
77 services:: api:: contract_classes:: compiled_class:: CompiledClass ,
88 state:: StateDiff ,
9- utils:: { subtract_mappings, to_cache_state_storage_mapping, Address , ClassHash } ,
9+ utils:: {
10+ get_erc20_balance_var_addresses, subtract_mappings, to_cache_state_storage_mapping,
11+ Address , ClassHash ,
12+ } ,
1013} ;
1114use cairo_vm:: felt:: Felt252 ;
1215use getset:: { Getters , MutGetters } ;
@@ -268,8 +271,11 @@ impl<T: StateReader> State for CachedState<T> {
268271 Ok ( ( ) )
269272 }
270273
271- fn count_actual_storage_changes ( & mut self ) -> ( usize , usize ) {
272- let storage_updates = subtract_mappings (
274+ fn count_actual_storage_changes (
275+ & mut self ,
276+ fee_token_and_sender_address : Option < ( & Address , & Address ) > ,
277+ ) -> ( usize , usize ) {
278+ let mut storage_updates = subtract_mappings (
273279 self . cache . storage_writes . clone ( ) ,
274280 self . cache . storage_initial_values . clone ( ) ,
275281 ) ;
@@ -301,6 +307,16 @@ impl<T: StateReader> State for CachedState<T> {
301307 modified_contracts. len ( )
302308 } ;
303309
310+ // Add fee transfer storage update before actually charging it, as it needs to be included in the
311+ // calculation of the final fee.
312+ if let Some ( ( fee_token_address, sender_address) ) = fee_token_and_sender_address {
313+ let ( sender_low_key, _) = get_erc20_balance_var_addresses ( sender_address) . unwrap ( ) ;
314+ storage_updates. insert (
315+ ( fee_token_address. clone ( ) , sender_low_key) ,
316+ Felt252 :: default ( ) ,
317+ ) ;
318+ }
319+
304320 ( n_modified_contracts, storage_updates. len ( ) )
305321 }
306322
@@ -705,13 +721,17 @@ mod tests {
705721 ( ( address_two, storage_key_two) , Felt252 :: from ( 1 ) ) ,
706722 ] ) ;
707723
724+ let fee_token_address = Address ( 123 . into ( ) ) ;
725+ let sender_address = Address ( 321 . into ( ) ) ;
726+
708727 let expected_changes = {
709- let n_storage_updates = 3 ;
728+ let n_storage_updates = 3 + 1 ; // + 1 fee transfer balance update
710729 let n_modified_contracts = 2 ;
711730
712731 ( n_modified_contracts, n_storage_updates)
713732 } ;
714- let changes = cached_state. count_actual_storage_changes ( ) ;
733+ let changes =
734+ cached_state. count_actual_storage_changes ( Some ( ( & fee_token_address, & sender_address) ) ) ;
715735
716736 assert_eq ! ( changes, expected_changes) ;
717737 }
0 commit comments