From 47bb7da7bf575c9e5bd0b9a41c7424c4632d40d0 Mon Sep 17 00:00:00 2001 From: Gary Rong Date: Thu, 18 Jul 2024 12:11:34 +0800 Subject: [PATCH] core, internal: fix storage override --- core/state/statedb.go | 28 ++++++++++++++++++---------- internal/ethapi/api_test.go | 35 +++++++++++++++++++++++++++++++++++ 2 files changed, 53 insertions(+), 10 deletions(-) diff --git a/core/state/statedb.go b/core/state/statedb.go index 641775b0bdfc..a731f9a8d643 100644 --- a/core/state/statedb.go +++ b/core/state/statedb.go @@ -471,20 +471,28 @@ func (s *StateDB) SetState(addr common.Address, key, value common.Hash) { // storage. This function should only be used for debugging and the mutations // must be discarded afterwards. func (s *StateDB) SetStorage(addr common.Address, storage map[common.Hash]common.Hash) { - // SetStorage needs to wipe existing storage. We achieve this by pretending - // that the account self-destructed earlier in this block, by flagging - // it in stateObjectsDestruct. The effect of doing so is that storage lookups - // will not hit disk, since it is assumed that the disk-data is belonging + // SetStorage needs to wipe the existing storage. We achieve this by marking + // the account as self-destructed in this block. The effect is that storage + // lookups will not hit the disk, as it is assumed that the disk data belongs // to a previous incarnation of the object. // - // TODO(rjl493456442) this function should only be supported by 'unwritable' - // state and all mutations made should all be discarded afterwards. - if _, ok := s.stateObjectsDestruct[addr]; !ok { - s.stateObjectsDestruct[addr] = nil + // TODO (rjl493456442): This function should only be supported by 'unwritable' + // state, and all mutations made should be discarded afterward. + obj := s.getStateObject(addr) + if obj != nil { + if _, ok := s.stateObjectsDestruct[addr]; !ok { + s.stateObjectsDestruct[addr] = obj + } } - stateObject := s.getOrNewStateObject(addr) + newObj := s.createObject(addr) for k, v := range storage { - stateObject.SetState(k, v) + newObj.SetState(k, v) + } + // Inherit the metadata of original object if it was existent + if obj != nil { + newObj.SetCode(common.BytesToHash(obj.CodeHash()), obj.code) + newObj.SetNonce(obj.Nonce()) + newObj.SetBalance(obj.Balance(), tracing.BalanceChangeUnspecified) } } diff --git a/internal/ethapi/api_test.go b/internal/ethapi/api_test.go index cf5160caf778..b4041fc84ce3 100644 --- a/internal/ethapi/api_test.go +++ b/internal/ethapi/api_test.go @@ -781,15 +781,24 @@ func TestEstimateGas(t *testing.T) { func TestCall(t *testing.T) { t.Parallel() + // Initialize test accounts var ( accounts = newAccounts(3) + dad = common.HexToAddress("0x0000000000000000000000000000000000000dad") genesis = &core.Genesis{ Config: params.MergedTestChainConfig, Alloc: types.GenesisAlloc{ accounts[0].addr: {Balance: big.NewInt(params.Ether)}, accounts[1].addr: {Balance: big.NewInt(params.Ether)}, accounts[2].addr: {Balance: big.NewInt(params.Ether)}, + dad: { + Balance: big.NewInt(params.Ether), + Nonce: 1, + Storage: map[common.Hash]common.Hash{ + common.Hash{}: common.HexToHash("0x0000000000000000000000000000000000000000000000000000000000000001"), + }, + }, }, } genBlocks = 10 @@ -949,6 +958,32 @@ func TestCall(t *testing.T) { }, want: "0x0122000000000000000000000000000000000000000000000000000000000000", }, + // Clear the entire storage set + { + blockNumber: rpc.LatestBlockNumber, + call: TransactionArgs{ + From: &accounts[1].addr, + // Yul: + // object "Test" { + // code { + // let dad := 0x0000000000000000000000000000000000000dad + // if eq(balance(dad), 0) { + // revert(0, 0) + // } + // let slot := sload(0) + // mstore(0, slot) + // return(0, 32) + // } + // } + Input: hex2Bytes("610dad6000813103600f57600080fd5b6000548060005260206000f3"), + }, + overrides: StateOverride{ + dad: OverrideAccount{ + State: &map[common.Hash]common.Hash{}, + }, + }, + want: "0x0000000000000000000000000000000000000000000000000000000000000000", + }, } for i, tc := range testSuite { result, err := api.Call(context.Background(), tc.call, &rpc.BlockNumberOrHash{BlockNumber: &tc.blockNumber}, &tc.overrides, &tc.blockOverrides)