diff --git a/core/state_processor.go b/core/state_processor.go index a79443ebeec9..d9eb80ed40ea 100644 --- a/core/state_processor.go +++ b/core/state_processor.go @@ -102,11 +102,19 @@ func applyTransaction(msg types.Message, config *params.ChainConfig, gp *GasPool txContext := NewEVMTxContext(msg) evm.Reset(txContext, statedb) + snapshot := statedb.Snapshot() // Apply the transaction to the current state (included in the env). result, err := ApplyMessage(evm, msg, gp) if err != nil { return nil, err } + if preFinalizeHook != nil { + err = preFinalizeHook() + if err != nil { + statedb.RevertToSnapshot(snapshot) + return nil, err + } + } // Update the state with pending changes. var root []byte @@ -191,7 +199,7 @@ func applyTransactionWithResult(msg types.Message, config *params.ChainConfig, b // and uses the input parameters for its environment. It returns the receipt // for the transaction, gas used and an error if the transaction failed, // indicating the block was invalid. -func ApplyTransaction(config *params.ChainConfig, bc ChainContext, author *common.Address, gp *GasPool, statedb *state.StateDB, header *types.Header, tx *types.Transaction, usedGas *uint64, cfg vm.Config) (*types.Receipt, error) { +func ApplyTransaction(config *params.ChainConfig, bc ChainContext, author *common.Address, gp *GasPool, statedb *state.StateDB, header *types.Header, tx *types.Transaction, usedGas *uint64, cfg vm.Config, preFinalizeHook func() error) (*types.Receipt, error) { msg, err := tx.AsMessage(types.MakeSigner(config, header.Number), header.BaseFee) if err != nil { return nil, err diff --git a/miner/algo_common.go b/miner/algo_common.go index 90f4b2a79eb4..e1e8503d6a86 100644 --- a/miner/algo_common.go +++ b/miner/algo_common.go @@ -88,7 +88,7 @@ func applyTransactionWithBlacklist(signer types.Signer, config *params.ChainConf // short circuit if blacklist is empty if len(blacklist) == 0 { snap := statedb.Snapshot() - receipt, err := core.ApplyTransaction(config, bc, author, gp, statedb, header, tx, usedGas, cfg) + receipt, err := core.ApplyTransaction(config, bc, author, gp, statedb, header, tx, usedGas, cfg, nil) if err != nil { statedb.RevertToSnapshot(snap) } @@ -114,29 +114,28 @@ func applyTransactionWithBlacklist(signer types.Signer, config *params.ChainConf cfg.Tracer = touchTracer cfg.Debug = true + hook := func() error { + for _, address := range touchTracer.TouchedAddresses() { + if _, in := blacklist[address]; in { + return errors.New("blacklist violation, tx trace") + } + } + return nil + } + usedGasTmp := *usedGas gasPoolTmp := new(core.GasPool).AddGas(gp.Gas()) - stateCopy := statedb.Copy() - snap := stateCopy.Snapshot() + snap := statedb.Snapshot() - stateCopy.Prepare(tx.Hash(), statedb.TxIndex()) - receipt, err := core.ApplyTransaction(config, bc, author, gasPoolTmp, stateCopy, header, tx, &usedGasTmp, cfg) + receipt, err := core.ApplyTransaction(config, bc, author, gasPoolTmp, statedb, header, tx, &usedGasTmp, cfg, hook) if err != nil { - stateCopy.RevertToSnapshot(snap) - *usedGas = usedGasTmp - *gp = *gasPoolTmp - return receipt, stateCopy, err - } - - for _, address := range touchTracer.TouchedAddresses() { - if _, in := blacklist[address]; in { - return nil, statedb, errors.New("blacklist violation, tx trace") - } + statedb.RevertToSnapshot(snap) + return receipt, statedb, err } *usedGas = usedGasTmp *gp = *gasPoolTmp - return receipt, stateCopy, nil + return receipt, statedb, err } // commit tx to envDiff diff --git a/miner/algo_common_test.go b/miner/algo_common_test.go index 84fd4a935762..2dd43e62ef12 100644 --- a/miner/algo_common_test.go +++ b/miner/algo_common_test.go @@ -71,7 +71,7 @@ func simulateBundle(env *environment, bundle types.MevBundle, chData chainData, coinbaseBalanceBefore := stateDB.GetBalance(env.coinbase) var tempGasUsed uint64 - receipt, err := core.ApplyTransaction(chData.chainConfig, chData.chain, &env.coinbase, gasPool, stateDB, env.header, tx, &tempGasUsed, *chData.chain.GetVMConfig()) + receipt, err := core.ApplyTransaction(chData.chainConfig, chData.chain, &env.coinbase, gasPool, stateDB, env.header, tx, &tempGasUsed, *chData.chain.GetVMConfig(), nil) if err != nil { return types.SimulatedBundle{}, err } @@ -439,6 +439,8 @@ func TestBlacklist(t *testing.T) { env := newEnvironment(chData, statedb, signers.addresses[0], GasLimit, big.NewInt(1)) envDiff := newEnvironmentDiff(env) + beforeRoot := statedb.IntermediateRoot(true) + blacklist := map[common.Address]struct{}{ signers.addresses[3]: {}, } @@ -494,6 +496,11 @@ func TestBlacklist(t *testing.T) { if len(envDiff.newReceipts) != 0 { t.Fatal("newReceipts changed") } + + afterRoot := statedb.IntermediateRoot(true) + if beforeRoot != afterRoot { + t.Fatal("statedb root changed") + } } func TestGetSealingWorkAlgos(t *testing.T) { diff --git a/miner/worker.go b/miner/worker.go index 16ec5dd2035c..a93826e8932c 100644 --- a/miner/worker.go +++ b/miner/worker.go @@ -953,12 +953,7 @@ func (w *worker) updateSnapshot(env *environment) { func (w *worker) commitTransaction(env *environment, tx *types.Transaction) ([]*types.Log, error) { gasPool := *env.gasPool envGasUsed := env.header.GasUsed - var stateDB *state.StateDB - if len(w.blockList) != 0 { - stateDB = env.state.Copy() - } else { - stateDB = env.state - } + stateDB := env.state // It's important to copy then .Prepare() - don't reorder. stateDB.Prepare(tx.Hash(), env.tcount) @@ -971,25 +966,27 @@ func (w *worker) commitTransaction(env *environment, tx *types.Transaction) ([]* } var tracer *logger.AccountTouchTracer + var hook func() error config := *w.chain.GetVMConfig() if len(w.blockList) != 0 { tracer = logger.NewAccountTouchTracer() config.Tracer = tracer config.Debug = true + hook = func() error { + for _, address := range tracer.TouchedAddresses() { + if _, in := w.blockList[address]; in { + return errBlocklistViolation + } + } + return nil + } } - receipt, err := core.ApplyTransaction(w.chainConfig, w.chain, &env.coinbase, &gasPool, stateDB, env.header, tx, &envGasUsed, config) + receipt, err := core.ApplyTransaction(w.chainConfig, w.chain, &env.coinbase, &gasPool, stateDB, env.header, tx, &envGasUsed, config, hook) if err != nil { stateDB.RevertToSnapshot(snapshot) return nil, err } - if len(w.blockList) != 0 { - for _, address := range tracer.TouchedAddresses() { - if _, in := w.blockList[address]; in { - return nil, errBlocklistViolation - } - } - } *env.gasPool = gasPool env.header.GasUsed = envGasUsed @@ -1727,7 +1724,7 @@ func (w *worker) computeBundleGas(env *environment, bundle types.MevBundle, stat config.Tracer = tracer config.Debug = true } - receipt, err := core.ApplyTransaction(w.chainConfig, w.chain, &env.coinbase, gasPool, state, env.header, tx, &tempGasUsed, config) + receipt, err := core.ApplyTransaction(w.chainConfig, w.chain, &env.coinbase, gasPool, state, env.header, tx, &tempGasUsed, config, nil) if err != nil { return simulatedBundle{}, err }