From 523d4f345d6081d01813d08dbfcf1c4e013b37ca Mon Sep 17 00:00:00 2001 From: Thiago Coimbra Lemos Date: Wed, 29 May 2024 09:57:55 -0300 Subject: [PATCH] remove db tx from RPC (#3648) --- jsonrpc/dbtxmanager.go | 41 -- jsonrpc/dbtxmanager_test.go | 99 --- jsonrpc/endpoints_debug.go | 249 ++++--- jsonrpc/endpoints_eth.go | 818 +++++++++++------------ jsonrpc/endpoints_eth_test.go | 1106 ++++--------------------------- jsonrpc/endpoints_zkevm.go | 713 ++++++++++---------- jsonrpc/endpoints_zkevm_test.go | 707 +++----------------- jsonrpc/mocks/mock_dbtx.go | 350 ---------- jsonrpc/server_test.go | 31 +- jsonrpc/types/codec_test.go | 66 +- test/Makefile | 1 - 11 files changed, 1133 insertions(+), 3048 deletions(-) delete mode 100644 jsonrpc/dbtxmanager.go delete mode 100644 jsonrpc/dbtxmanager_test.go delete mode 100644 jsonrpc/mocks/mock_dbtx.go diff --git a/jsonrpc/dbtxmanager.go b/jsonrpc/dbtxmanager.go deleted file mode 100644 index bb073d0369..0000000000 --- a/jsonrpc/dbtxmanager.go +++ /dev/null @@ -1,41 +0,0 @@ -package jsonrpc - -import ( - "context" - - "github.com/0xPolygonHermez/zkevm-node/jsonrpc/types" - "github.com/jackc/pgx/v4" -) - -// DBTxManager allows to do scopped DB txs -type DBTxManager struct{} - -// DBTxScopedFn function to do scopped DB txs -type DBTxScopedFn func(ctx context.Context, dbTx pgx.Tx) (interface{}, types.Error) - -// DBTxer interface to begin DB txs -type DBTxer interface { - BeginStateTransaction(ctx context.Context) (pgx.Tx, error) -} - -// NewDbTxScope function to initiate DB scopped txs -func (f *DBTxManager) NewDbTxScope(db DBTxer, scopedFn DBTxScopedFn) (interface{}, types.Error) { - ctx := context.Background() - dbTx, err := db.BeginStateTransaction(ctx) - if err != nil { - return RPCErrorResponse(types.DefaultErrorCode, "failed to connect to the state", err, true) - } - - v, rpcErr := scopedFn(ctx, dbTx) - if rpcErr != nil { - if txErr := dbTx.Rollback(context.Background()); txErr != nil { - return RPCErrorResponse(types.DefaultErrorCode, "failed to rollback db transaction", txErr, true) - } - return v, rpcErr - } - - if txErr := dbTx.Commit(context.Background()); txErr != nil { - return RPCErrorResponse(types.DefaultErrorCode, "failed to commit db transaction", txErr, true) - } - return v, rpcErr -} diff --git a/jsonrpc/dbtxmanager_test.go b/jsonrpc/dbtxmanager_test.go deleted file mode 100644 index b3dba72625..0000000000 --- a/jsonrpc/dbtxmanager_test.go +++ /dev/null @@ -1,99 +0,0 @@ -package jsonrpc - -import ( - "context" - "errors" - "testing" - - "github.com/0xPolygonHermez/zkevm-node/jsonrpc/mocks" - "github.com/0xPolygonHermez/zkevm-node/jsonrpc/types" - "github.com/jackc/pgx/v4" - "github.com/stretchr/testify/assert" -) - -func TestNewDbTxScope(t *testing.T) { - type testCase struct { - Name string - Fn DBTxScopedFn - ExpectedResult interface{} - ExpectedError types.Error - SetupMocks func(s *mocks.StateMock, d *mocks.DBTxMock) - } - - testCases := []testCase{ - { - Name: "Run scoped func commits DB tx", - Fn: func(ctx context.Context, dbTx pgx.Tx) (interface{}, types.Error) { - return 1, nil - }, - ExpectedResult: 1, - ExpectedError: nil, - SetupMocks: func(s *mocks.StateMock, d *mocks.DBTxMock) { - d.On("Commit", context.Background()).Return(nil).Once() - s.On("BeginStateTransaction", context.Background()).Return(d, nil).Once() - }, - }, - { - Name: "Run scoped func rollbacks DB tx", - Fn: func(ctx context.Context, dbTx pgx.Tx) (interface{}, types.Error) { - return nil, types.NewRPCError(types.DefaultErrorCode, "func returned an error") - }, - ExpectedResult: nil, - ExpectedError: types.NewRPCError(types.DefaultErrorCode, "func returned an error"), - SetupMocks: func(s *mocks.StateMock, d *mocks.DBTxMock) { - d.On("Rollback", context.Background()).Return(nil).Once() - s.On("BeginStateTransaction", context.Background()).Return(d, nil).Once() - }, - }, - { - Name: "Run scoped func but fails create a db tx", - Fn: func(ctx context.Context, dbTx pgx.Tx) (interface{}, types.Error) { - return nil, nil - }, - ExpectedResult: nil, - ExpectedError: types.NewRPCError(types.DefaultErrorCode, "failed to connect to the state"), - SetupMocks: func(s *mocks.StateMock, d *mocks.DBTxMock) { - s.On("BeginStateTransaction", context.Background()).Return(nil, errors.New("failed to create db tx")).Once() - }, - }, - { - Name: "Run scoped func but fails to commit DB tx", - Fn: func(ctx context.Context, dbTx pgx.Tx) (interface{}, types.Error) { - return 1, nil - }, - ExpectedResult: nil, - ExpectedError: types.NewRPCError(types.DefaultErrorCode, "failed to commit db transaction"), - SetupMocks: func(s *mocks.StateMock, d *mocks.DBTxMock) { - d.On("Commit", context.Background()).Return(errors.New("failed to commit db tx")).Once() - s.On("BeginStateTransaction", context.Background()).Return(d, nil).Once() - }, - }, - { - Name: "Run scoped func but fails to rollbacks DB tx", - Fn: func(ctx context.Context, dbTx pgx.Tx) (interface{}, types.Error) { - return nil, types.NewRPCError(types.DefaultErrorCode, "func returned an error") - }, - ExpectedResult: nil, - ExpectedError: types.NewRPCError(types.DefaultErrorCode, "failed to rollback db transaction"), - SetupMocks: func(s *mocks.StateMock, d *mocks.DBTxMock) { - d.On("Rollback", context.Background()).Return(errors.New("failed to rollback db tx")).Once() - s.On("BeginStateTransaction", context.Background()).Return(d, nil).Once() - }, - }, - } - - dbTxManager := DBTxManager{} - s := mocks.NewStateMock(t) - d := mocks.NewDBTxMock(t) - - for _, testCase := range testCases { - t.Run(testCase.Name, func(t *testing.T) { - tc := testCase - tc.SetupMocks(s, d) - - result, err := dbTxManager.NewDbTxScope(s, tc.Fn) - assert.Equal(t, tc.ExpectedResult, result) - assert.Equal(t, tc.ExpectedError, err) - }) - } -} diff --git a/jsonrpc/endpoints_debug.go b/jsonrpc/endpoints_debug.go index a91cd924da..2309db7c5c 100644 --- a/jsonrpc/endpoints_debug.go +++ b/jsonrpc/endpoints_debug.go @@ -31,7 +31,6 @@ type DebugEndpoints struct { cfg Config state types.StateInterface etherman types.EthermanInterface - txMan DBTxManager } // NewDebugEndpoints returns DebugEndpoints @@ -64,54 +63,51 @@ type traceBatchTransactionResponse struct { // TraceTransaction creates a response for debug_traceTransaction request. // See https://geth.ethereum.org/docs/interacting-with-geth/rpc/ns-debug#debugtracetransaction func (d *DebugEndpoints) TraceTransaction(hash types.ArgHash, cfg *traceConfig) (interface{}, types.Error) { - return d.txMan.NewDbTxScope(d.state, func(ctx context.Context, dbTx pgx.Tx) (interface{}, types.Error) { - return d.buildTraceTransaction(ctx, hash.Hash(), cfg, dbTx) - }) + ctx := context.Background() + return d.buildTraceTransaction(ctx, hash.Hash(), cfg, nil) } // TraceBlockByNumber creates a response for debug_traceBlockByNumber request. // See https://geth.ethereum.org/docs/interacting-with-geth/rpc/ns-debug#debugtraceblockbynumber func (d *DebugEndpoints) TraceBlockByNumber(number types.BlockNumber, cfg *traceConfig) (interface{}, types.Error) { - return d.txMan.NewDbTxScope(d.state, func(ctx context.Context, dbTx pgx.Tx) (interface{}, types.Error) { - blockNumber, rpcErr := number.GetNumericBlockNumber(ctx, d.state, d.etherman, dbTx) - if rpcErr != nil { - return nil, rpcErr - } + ctx := context.Background() + blockNumber, rpcErr := number.GetNumericBlockNumber(ctx, d.state, d.etherman, nil) + if rpcErr != nil { + return nil, rpcErr + } - block, err := d.state.GetL2BlockByNumber(ctx, blockNumber, dbTx) - if errors.Is(err, state.ErrNotFound) { - return nil, types.NewRPCError(types.DefaultErrorCode, fmt.Sprintf("block #%d not found", blockNumber)) - } else if err != nil { - return RPCErrorResponse(types.DefaultErrorCode, "failed to get block by number", err, true) - } + block, err := d.state.GetL2BlockByNumber(ctx, blockNumber, nil) + if errors.Is(err, state.ErrNotFound) { + return nil, types.NewRPCError(types.DefaultErrorCode, fmt.Sprintf("block #%d not found", blockNumber)) + } else if err != nil { + return RPCErrorResponse(types.DefaultErrorCode, "failed to get block by number", err, true) + } - traces, rpcErr := d.buildTraceBlock(ctx, block.Transactions(), cfg, dbTx) - if rpcErr != nil { - return nil, rpcErr - } + traces, rpcErr := d.buildTraceBlock(ctx, block.Transactions(), cfg, nil) + if rpcErr != nil { + return nil, rpcErr + } - return traces, nil - }) + return traces, nil } // TraceBlockByHash creates a response for debug_traceBlockByHash request. // See https://geth.ethereum.org/docs/interacting-with-geth/rpc/ns-debug#debugtraceblockbyhash func (d *DebugEndpoints) TraceBlockByHash(hash types.ArgHash, cfg *traceConfig) (interface{}, types.Error) { - return d.txMan.NewDbTxScope(d.state, func(ctx context.Context, dbTx pgx.Tx) (interface{}, types.Error) { - block, err := d.state.GetL2BlockByHash(ctx, hash.Hash(), dbTx) - if errors.Is(err, state.ErrNotFound) { - return nil, types.NewRPCError(types.DefaultErrorCode, fmt.Sprintf("block %s not found", hash.Hash().String())) - } else if err != nil { - return RPCErrorResponse(types.DefaultErrorCode, "failed to get block by hash", err, true) - } + ctx := context.Background() + block, err := d.state.GetL2BlockByHash(ctx, hash.Hash(), nil) + if errors.Is(err, state.ErrNotFound) { + return nil, types.NewRPCError(types.DefaultErrorCode, fmt.Sprintf("block %s not found", hash.Hash().String())) + } else if err != nil { + return RPCErrorResponse(types.DefaultErrorCode, "failed to get block by hash", err, true) + } - traces, rpcErr := d.buildTraceBlock(ctx, block.Transactions(), cfg, dbTx) - if rpcErr != nil { - return nil, rpcErr - } + traces, rpcErr := d.buildTraceBlock(ctx, block.Transactions(), cfg, nil) + if rpcErr != nil { + return nil, rpcErr + } - return traces, nil - }) + return traces, nil } // TraceBatchByNumber creates a response for debug_traceBatchByNumber request. @@ -144,113 +140,112 @@ func (d *DebugEndpoints) TraceBatchByNumber(httpRequest *http.Request, number ty // how many txs it will process in parallel. const bufferSize = 10 - return d.txMan.NewDbTxScope(d.state, func(ctx context.Context, dbTx pgx.Tx) (interface{}, types.Error) { - batchNumber, rpcErr := number.GetNumericBatchNumber(ctx, d.state, d.etherman, dbTx) - if rpcErr != nil { - return nil, rpcErr - } + ctx := context.Background() + batchNumber, rpcErr := number.GetNumericBatchNumber(ctx, d.state, d.etherman, nil) + if rpcErr != nil { + return nil, rpcErr + } - batch, err := d.state.GetBatchByNumber(ctx, batchNumber, dbTx) - if errors.Is(err, state.ErrNotFound) { - return nil, types.NewRPCError(types.DefaultErrorCode, fmt.Sprintf("batch #%d not found", batchNumber)) - } else if err != nil { - return RPCErrorResponse(types.DefaultErrorCode, "failed to get batch by number", err, true) - } + batch, err := d.state.GetBatchByNumber(ctx, batchNumber, nil) + if errors.Is(err, state.ErrNotFound) { + return nil, types.NewRPCError(types.DefaultErrorCode, fmt.Sprintf("batch #%d not found", batchNumber)) + } else if err != nil { + return RPCErrorResponse(types.DefaultErrorCode, "failed to get batch by number", err, true) + } - txs, _, err := d.state.GetTransactionsByBatchNumber(ctx, batch.BatchNumber, dbTx) - if !errors.Is(err, state.ErrNotFound) && err != nil { - return RPCErrorResponse(types.DefaultErrorCode, fmt.Sprintf("couldn't load batch txs from state by number %v to create the traces", batchNumber), err, true) - } + txs, _, err := d.state.GetTransactionsByBatchNumber(ctx, batch.BatchNumber, nil) + if !errors.Is(err, state.ErrNotFound) && err != nil { + return RPCErrorResponse(types.DefaultErrorCode, fmt.Sprintf("couldn't load batch txs from state by number %v to create the traces", batchNumber), err, true) + } - receipts := make([]ethTypes.Receipt, 0, len(txs)) - for _, tx := range txs { - receipt, err := d.state.GetTransactionReceipt(ctx, tx.Hash(), dbTx) - if err != nil { - return RPCErrorResponse(types.DefaultErrorCode, fmt.Sprintf("couldn't load receipt for tx %v to get trace", tx.Hash().String()), err, true) - } - receipts = append(receipts, *receipt) + receipts := make([]ethTypes.Receipt, 0, len(txs)) + for _, tx := range txs { + receipt, err := d.state.GetTransactionReceipt(ctx, tx.Hash(), nil) + if err != nil { + return RPCErrorResponse(types.DefaultErrorCode, fmt.Sprintf("couldn't load receipt for tx %v to get trace", tx.Hash().String()), err, true) } + receipts = append(receipts, *receipt) + } - requests := make(chan (ethTypes.Receipt), bufferSize) - - mu := &sync.Mutex{} - wg := sync.WaitGroup{} - wg.Add(len(receipts)) - responses := make([]traceResponse, 0, len(receipts)) - - // gets the trace from the jRPC and adds it to the responses - loadTraceByTxHash := func(d *DebugEndpoints, receipt ethTypes.Receipt, cfg *traceConfig) { - response := traceResponse{ - blockNumber: receipt.BlockNumber.Uint64(), - txIndex: uint64(receipt.TransactionIndex), - txHash: receipt.TxHash, - } - - defer wg.Done() - trace, err := d.TraceTransaction(types.ArgHash(receipt.TxHash), cfg) - if err != nil { - err := fmt.Errorf("failed to get tx trace for tx %v, err: %w", receipt.TxHash.String(), err) - log.Errorf(err.Error()) - response.err = err - } else { - response.trace = trace - } - - // add to the responses - mu.Lock() - defer mu.Unlock() - responses = append(responses, response) + requests := make(chan (ethTypes.Receipt), bufferSize) + + mu := &sync.Mutex{} + wg := sync.WaitGroup{} + wg.Add(len(receipts)) + responses := make([]traceResponse, 0, len(receipts)) + + // gets the trace from the jRPC and adds it to the responses + loadTraceByTxHash := func(d *DebugEndpoints, receipt ethTypes.Receipt, cfg *traceConfig) { + response := traceResponse{ + blockNumber: receipt.BlockNumber.Uint64(), + txIndex: uint64(receipt.TransactionIndex), + txHash: receipt.TxHash, } - // goes through the buffer and loads the trace - // by all the transactions added in the buffer - // then add the results to the responses map - go func() { - index := uint(0) - for req := range requests { - go loadTraceByTxHash(d, req, cfg) - index++ - } - }() - - // add receipts to the buffer - for _, receipt := range receipts { - requests <- receipt + defer wg.Done() + trace, err := d.TraceTransaction(types.ArgHash(receipt.TxHash), cfg) + if err != nil { + err := fmt.Errorf("failed to get tx trace for tx %v, err: %w", receipt.TxHash.String(), err) + log.Errorf(err.Error()) + response.err = err + } else { + response.trace = trace } - // wait the traces to be loaded - if waitTimeout(&wg, d.cfg.ReadTimeout.Duration) { - return RPCErrorResponse(types.DefaultErrorCode, fmt.Sprintf("failed to get traces for batch %v: timeout reached", batchNumber), nil, true) + // add to the responses + mu.Lock() + defer mu.Unlock() + responses = append(responses, response) + } + + // goes through the buffer and loads the trace + // by all the transactions added in the buffer + // then add the results to the responses map + go func() { + index := uint(0) + for req := range requests { + go loadTraceByTxHash(d, req, cfg) + index++ } + }() - close(requests) - - // since the txs are attached to a L2 Block and the L2 Block is - // the struct attached to the Batch, in order to always respond - // the traces in the same order, we need to order the transactions - // first by block number and then by tx index, so we can have something - // close to the txs being sorted by a tx index related to the batch - sort.Slice(responses, func(i, j int) bool { - if responses[i].txIndex != responses[j].txIndex { - return responses[i].txIndex < responses[j].txIndex - } - return responses[i].blockNumber < responses[j].blockNumber - }) + // add receipts to the buffer + for _, receipt := range receipts { + requests <- receipt + } + + // wait the traces to be loaded + if waitTimeout(&wg, d.cfg.ReadTimeout.Duration) { + return RPCErrorResponse(types.DefaultErrorCode, fmt.Sprintf("failed to get traces for batch %v: timeout reached", batchNumber), nil, true) + } + + close(requests) - // build the batch trace response array - traces := make([]traceBatchTransactionResponse, 0, len(receipts)) - for _, response := range responses { - if response.err != nil { - return RPCErrorResponse(types.DefaultErrorCode, fmt.Sprintf("failed to get traces for batch %v: failed to get trace for tx: %v, err: %v", batchNumber, response.txHash.String(), response.err.Error()), nil, true) - } - - traces = append(traces, traceBatchTransactionResponse{ - TxHash: response.txHash, - Result: response.trace, - }) + // since the txs are attached to a L2 Block and the L2 Block is + // the struct attached to the Batch, in order to always respond + // the traces in the same order, we need to order the transactions + // first by block number and then by tx index, so we can have something + // close to the txs being sorted by a tx index related to the batch + sort.Slice(responses, func(i, j int) bool { + if responses[i].txIndex != responses[j].txIndex { + return responses[i].txIndex < responses[j].txIndex } - return traces, nil + return responses[i].blockNumber < responses[j].blockNumber }) + + // build the batch trace response array + traces := make([]traceBatchTransactionResponse, 0, len(receipts)) + for _, response := range responses { + if response.err != nil { + return RPCErrorResponse(types.DefaultErrorCode, fmt.Sprintf("failed to get traces for batch %v: failed to get trace for tx: %v, err: %v", batchNumber, response.txHash.String(), response.err.Error()), nil, true) + } + + traces = append(traces, traceBatchTransactionResponse{ + TxHash: response.txHash, + Result: response.trace, + }) + } + return traces, nil } func (d *DebugEndpoints) buildTraceBlock(ctx context.Context, txs []*ethTypes.Transaction, cfg *traceConfig, dbTx pgx.Tx) (interface{}, types.Error) { diff --git a/jsonrpc/endpoints_eth.go b/jsonrpc/endpoints_eth.go index 8231cab78b..bcfb21e770 100644 --- a/jsonrpc/endpoints_eth.go +++ b/jsonrpc/endpoints_eth.go @@ -37,7 +37,6 @@ type EthEndpoints struct { state types.StateInterface etherman types.EthermanInterface storage storageInterface - txMan DBTxManager } // NewEthEndpoints creates an new instance of Eth @@ -50,14 +49,13 @@ func NewEthEndpoints(cfg Config, chainID uint64, p types.PoolInterface, s types. // BlockNumber returns current block number func (e *EthEndpoints) BlockNumber() (interface{}, types.Error) { - return e.txMan.NewDbTxScope(e.state, func(ctx context.Context, dbTx pgx.Tx) (interface{}, types.Error) { - lastBlockNumber, err := e.state.GetLastL2BlockNumber(ctx, dbTx) - if err != nil { - return "0x0", types.NewRPCError(types.DefaultErrorCode, "failed to get the last block number from state") - } + ctx := context.Background() + lastBlockNumber, err := e.state.GetLastL2BlockNumber(ctx, nil) + if err != nil { + return "0x0", types.NewRPCError(types.DefaultErrorCode, "failed to get the last block number from state") + } - return hex.EncodeUint64(lastBlockNumber), nil - }) + return hex.EncodeUint64(lastBlockNumber), nil } // Call executes a new message call immediately and returns the value of @@ -65,62 +63,61 @@ func (e *EthEndpoints) BlockNumber() (interface{}, types.Error) { // Note, this function doesn't make any changes in the state/blockchain and is // useful to execute view/pure methods and retrieve values. func (e *EthEndpoints) Call(arg *types.TxArgs, blockArg *types.BlockNumberOrHash) (interface{}, types.Error) { - return e.txMan.NewDbTxScope(e.state, func(ctx context.Context, dbTx pgx.Tx) (interface{}, types.Error) { - if arg == nil { - return RPCErrorResponse(types.InvalidParamsErrorCode, "missing value for required argument 0", nil, false) - } - block, respErr := e.getBlockByArg(ctx, blockArg, dbTx) - if respErr != nil { - return nil, respErr - } - var blockToProcess *uint64 - if blockArg != nil { - blockNumArg := blockArg.Number() - if blockNumArg != nil && (*blockArg.Number() == types.LatestBlockNumber || *blockArg.Number() == types.PendingBlockNumber) { - blockToProcess = nil - } else { - n := block.NumberU64() - blockToProcess = &n - } - } - - // If the caller didn't supply the gas limit in the message, then we set it to maximum possible => block gas limit - if arg.Gas == nil || uint64(*arg.Gas) <= 0 { - header, err := e.state.GetL2BlockHeaderByNumber(ctx, block.NumberU64(), dbTx) - if err != nil { - return RPCErrorResponse(types.DefaultErrorCode, "failed to get block header", err, true) - } - - gas := types.ArgUint64(header.GasLimit) - arg.Gas = &gas + ctx := context.Background() + if arg == nil { + return RPCErrorResponse(types.InvalidParamsErrorCode, "missing value for required argument 0", nil, false) + } + block, respErr := e.getBlockByArg(ctx, blockArg, nil) + if respErr != nil { + return nil, respErr + } + var blockToProcess *uint64 + if blockArg != nil { + blockNumArg := blockArg.Number() + if blockNumArg != nil && (*blockArg.Number() == types.LatestBlockNumber || *blockArg.Number() == types.PendingBlockNumber) { + blockToProcess = nil + } else { + n := block.NumberU64() + blockToProcess = &n } + } - defaultSenderAddress := common.HexToAddress(state.DefaultSenderAddress) - sender, tx, err := arg.ToTransaction(ctx, e.state, state.MaxTxGasLimit, block.Root(), defaultSenderAddress, dbTx) + // If the caller didn't supply the gas limit in the message, then we set it to maximum possible => block gas limit + if arg.Gas == nil || uint64(*arg.Gas) <= 0 { + header, err := e.state.GetL2BlockHeaderByNumber(ctx, block.NumberU64(), nil) if err != nil { - return RPCErrorResponse(types.DefaultErrorCode, "failed to convert arguments into an unsigned transaction", err, false) + return RPCErrorResponse(types.DefaultErrorCode, "failed to get block header", err, true) } - result, err := e.state.ProcessUnsignedTransaction(ctx, tx, sender, blockToProcess, true, dbTx) - if err != nil { - errMsg := fmt.Sprintf("failed to execute the unsigned transaction: %v", err.Error()) - logError := !executor.IsROMOutOfCountersError(executor.RomErrorCode(err)) && !errors.Is(err, runtime.ErrOutOfGas) - return RPCErrorResponse(types.DefaultErrorCode, errMsg, nil, logError) - } + gas := types.ArgUint64(header.GasLimit) + arg.Gas = &gas + } - if result.Reverted() { - data := make([]byte, len(result.ReturnValue)) - copy(data, result.ReturnValue) - if len(data) == 0 { - return nil, types.NewRPCError(types.DefaultErrorCode, result.Err.Error()) - } - return nil, types.NewRPCErrorWithData(types.RevertedErrorCode, result.Err.Error(), data) - } else if result.Failed() { + defaultSenderAddress := common.HexToAddress(state.DefaultSenderAddress) + sender, tx, err := arg.ToTransaction(ctx, e.state, state.MaxTxGasLimit, block.Root(), defaultSenderAddress, nil) + if err != nil { + return RPCErrorResponse(types.DefaultErrorCode, "failed to convert arguments into an unsigned transaction", err, false) + } + + result, err := e.state.ProcessUnsignedTransaction(ctx, tx, sender, blockToProcess, true, nil) + if err != nil { + errMsg := fmt.Sprintf("failed to execute the unsigned transaction: %v", err.Error()) + logError := !executor.IsROMOutOfCountersError(executor.RomErrorCode(err)) && !errors.Is(err, runtime.ErrOutOfGas) + return RPCErrorResponse(types.DefaultErrorCode, errMsg, nil, logError) + } + + if result.Reverted() { + data := make([]byte, len(result.ReturnValue)) + copy(data, result.ReturnValue) + if len(data) == 0 { return nil, types.NewRPCError(types.DefaultErrorCode, result.Err.Error()) } + return nil, types.NewRPCErrorWithData(types.RevertedErrorCode, result.Err.Error(), data) + } else if result.Failed() { + return nil, types.NewRPCError(types.DefaultErrorCode, result.Err.Error()) + } - return types.ArgBytesPtr(result.ReturnValue), nil - }) + return types.ArgBytesPtr(result.ReturnValue), nil } // ChainId returns the chain id of the client @@ -161,46 +158,45 @@ func (e *EthEndpoints) getCoinbaseFromSequencerNode() (interface{}, types.Error) // used by the transaction, for a variety of reasons including EVM mechanics and // node performance. func (e *EthEndpoints) EstimateGas(arg *types.TxArgs, blockArg *types.BlockNumberOrHash) (interface{}, types.Error) { - return e.txMan.NewDbTxScope(e.state, func(ctx context.Context, dbTx pgx.Tx) (interface{}, types.Error) { - if arg == nil { - return RPCErrorResponse(types.InvalidParamsErrorCode, "missing value for required argument 0", nil, false) - } + ctx := context.Background() + if arg == nil { + return RPCErrorResponse(types.InvalidParamsErrorCode, "missing value for required argument 0", nil, false) + } - block, respErr := e.getBlockByArg(ctx, blockArg, dbTx) - if respErr != nil { - return nil, respErr - } + block, respErr := e.getBlockByArg(ctx, blockArg, nil) + if respErr != nil { + return nil, respErr + } - var blockToProcess *uint64 - if blockArg != nil { - blockNumArg := blockArg.Number() - if blockNumArg != nil && (*blockArg.Number() == types.LatestBlockNumber || *blockArg.Number() == types.PendingBlockNumber) { - blockToProcess = nil - } else { - n := block.NumberU64() - blockToProcess = &n - } + var blockToProcess *uint64 + if blockArg != nil { + blockNumArg := blockArg.Number() + if blockNumArg != nil && (*blockArg.Number() == types.LatestBlockNumber || *blockArg.Number() == types.PendingBlockNumber) { + blockToProcess = nil + } else { + n := block.NumberU64() + blockToProcess = &n } + } - defaultSenderAddress := common.HexToAddress(state.DefaultSenderAddress) - sender, tx, err := arg.ToTransaction(ctx, e.state, state.MaxTxGasLimit, block.Root(), defaultSenderAddress, dbTx) - if err != nil { - return RPCErrorResponse(types.DefaultErrorCode, "failed to convert arguments into an unsigned transaction", err, false) - } + defaultSenderAddress := common.HexToAddress(state.DefaultSenderAddress) + sender, tx, err := arg.ToTransaction(ctx, e.state, state.MaxTxGasLimit, block.Root(), defaultSenderAddress, nil) + if err != nil { + return RPCErrorResponse(types.DefaultErrorCode, "failed to convert arguments into an unsigned transaction", err, false) + } - gasEstimation, returnValue, err := e.state.EstimateGas(tx, sender, blockToProcess, dbTx) - if errors.Is(err, runtime.ErrExecutionReverted) { - data := make([]byte, len(returnValue)) - copy(data, returnValue) - if len(data) == 0 { - return nil, types.NewRPCError(types.DefaultErrorCode, err.Error()) - } - return nil, types.NewRPCErrorWithData(types.RevertedErrorCode, err.Error(), data) - } else if err != nil { + gasEstimation, returnValue, err := e.state.EstimateGas(tx, sender, blockToProcess, nil) + if errors.Is(err, runtime.ErrExecutionReverted) { + data := make([]byte, len(returnValue)) + copy(data, returnValue) + if len(data) == 0 { return nil, types.NewRPCError(types.DefaultErrorCode, err.Error()) } - return hex.EncodeUint64(gasEstimation), nil - }) + return nil, types.NewRPCErrorWithData(types.RevertedErrorCode, err.Error(), data) + } else if err != nil { + return nil, types.NewRPCError(types.DefaultErrorCode, err.Error()) + } + return hex.EncodeUint64(gasEstimation), nil } // GasPrice returns the average gas price based on the last x blocks @@ -253,21 +249,20 @@ func (e *EthEndpoints) getHighestL2BlockFromTrustedNode() (interface{}, types.Er // GetBalance returns the account's balance at the referenced block func (e *EthEndpoints) GetBalance(address types.ArgAddress, blockArg *types.BlockNumberOrHash) (interface{}, types.Error) { - return e.txMan.NewDbTxScope(e.state, func(ctx context.Context, dbTx pgx.Tx) (interface{}, types.Error) { - block, rpcErr := e.getBlockByArg(ctx, blockArg, dbTx) - if rpcErr != nil { - return nil, rpcErr - } + ctx := context.Background() + block, rpcErr := e.getBlockByArg(ctx, blockArg, nil) + if rpcErr != nil { + return nil, rpcErr + } - balance, err := e.state.GetBalance(ctx, address.Address(), block.Root()) - if errors.Is(err, state.ErrNotFound) { - return hex.EncodeUint64(0), nil - } else if err != nil { - return RPCErrorResponse(types.DefaultErrorCode, "failed to get balance from state", err, true) - } + balance, err := e.state.GetBalance(ctx, address.Address(), block.Root()) + if errors.Is(err, state.ErrNotFound) { + return hex.EncodeUint64(0), nil + } else if err != nil { + return RPCErrorResponse(types.DefaultErrorCode, "failed to get balance from state", err, true) + } - return hex.EncodeBig(balance), nil - }) + return hex.EncodeBig(balance), nil } func (e *EthEndpoints) getBlockByArg(ctx context.Context, blockArg *types.BlockNumberOrHash, dbTx pgx.Tx) (*state.L2Block, types.Error) { @@ -308,111 +303,108 @@ func (e *EthEndpoints) getBlockByArg(ctx context.Context, blockArg *types.BlockN // GetBlockByHash returns information about a block by hash func (e *EthEndpoints) GetBlockByHash(hash types.ArgHash, fullTx bool, includeExtraInfo *bool) (interface{}, types.Error) { - return e.txMan.NewDbTxScope(e.state, func(ctx context.Context, dbTx pgx.Tx) (interface{}, types.Error) { - l2Block, err := e.state.GetL2BlockByHash(ctx, hash.Hash(), dbTx) - if errors.Is(err, state.ErrNotFound) { - return nil, nil - } else if err != nil { - return RPCErrorResponse(types.DefaultErrorCode, "failed to get block by hash from state", err, true) - } - - txs := l2Block.Transactions() - receipts := make([]ethTypes.Receipt, 0, len(txs)) - for _, tx := range txs { - receipt, err := e.state.GetTransactionReceipt(ctx, tx.Hash(), dbTx) - if err != nil { - return RPCErrorResponse(types.DefaultErrorCode, fmt.Sprintf("couldn't load receipt for tx %v", tx.Hash().String()), err, true) - } - receipts = append(receipts, *receipt) - } + ctx := context.Background() + l2Block, err := e.state.GetL2BlockByHash(ctx, hash.Hash(), nil) + if errors.Is(err, state.ErrNotFound) { + return nil, nil + } else if err != nil { + return RPCErrorResponse(types.DefaultErrorCode, "failed to get block by hash from state", err, true) + } - rpcBlock, err := types.NewBlock(ctx, e.state, state.Ptr(l2Block.Hash()), l2Block, receipts, fullTx, false, includeExtraInfo, dbTx) + txs := l2Block.Transactions() + receipts := make([]ethTypes.Receipt, 0, len(txs)) + for _, tx := range txs { + receipt, err := e.state.GetTransactionReceipt(ctx, tx.Hash(), nil) if err != nil { - return RPCErrorResponse(types.DefaultErrorCode, fmt.Sprintf("couldn't build block response for block by hash %v", hash.Hash()), err, true) + return RPCErrorResponse(types.DefaultErrorCode, fmt.Sprintf("couldn't load receipt for tx %v", tx.Hash().String()), err, true) } + receipts = append(receipts, *receipt) + } - return rpcBlock, nil - }) + rpcBlock, err := types.NewBlock(ctx, e.state, state.Ptr(l2Block.Hash()), l2Block, receipts, fullTx, false, includeExtraInfo, nil) + if err != nil { + return RPCErrorResponse(types.DefaultErrorCode, fmt.Sprintf("couldn't build block response for block by hash %v", hash.Hash()), err, true) + } + + return rpcBlock, nil } // GetBlockByNumber returns information about a block by block number func (e *EthEndpoints) GetBlockByNumber(number types.BlockNumber, fullTx bool, includeExtraInfo *bool) (interface{}, types.Error) { - return e.txMan.NewDbTxScope(e.state, func(ctx context.Context, dbTx pgx.Tx) (interface{}, types.Error) { - if number == types.PendingBlockNumber { - lastBlock, err := e.state.GetLastL2Block(ctx, dbTx) - if err != nil { - return RPCErrorResponse(types.DefaultErrorCode, "couldn't load last block from state to compute the pending block", err, true) - } - l2Header := state.NewL2Header(ðTypes.Header{ - ParentHash: lastBlock.Hash(), - Number: big.NewInt(0).SetUint64(lastBlock.Number().Uint64() + 1), - TxHash: ethTypes.EmptyRootHash, - UncleHash: ethTypes.EmptyUncleHash, - }) - l2Block := state.NewL2BlockWithHeader(l2Header) - rpcBlock, err := types.NewBlock(ctx, e.state, nil, l2Block, nil, fullTx, false, includeExtraInfo, dbTx) - if err != nil { - return RPCErrorResponse(types.DefaultErrorCode, "couldn't build the pending block response", err, true) - } - - // clean fields that are not available for pending block - rpcBlock.Hash = nil - rpcBlock.Miner = nil - rpcBlock.Nonce = nil - rpcBlock.TotalDifficulty = nil - - return rpcBlock, nil + ctx := context.Background() + if number == types.PendingBlockNumber { + lastBlock, err := e.state.GetLastL2Block(ctx, nil) + if err != nil { + return RPCErrorResponse(types.DefaultErrorCode, "couldn't load last block from state to compute the pending block", err, true) } - var err error - blockNumber, rpcErr := number.GetNumericBlockNumber(ctx, e.state, e.etherman, dbTx) - if rpcErr != nil { - return nil, rpcErr + l2Header := state.NewL2Header(ðTypes.Header{ + ParentHash: lastBlock.Hash(), + Number: big.NewInt(0).SetUint64(lastBlock.Number().Uint64() + 1), + TxHash: ethTypes.EmptyRootHash, + UncleHash: ethTypes.EmptyUncleHash, + }) + l2Block := state.NewL2BlockWithHeader(l2Header) + rpcBlock, err := types.NewBlock(ctx, e.state, nil, l2Block, nil, fullTx, false, includeExtraInfo, nil) + if err != nil { + return RPCErrorResponse(types.DefaultErrorCode, "couldn't build the pending block response", err, true) } - l2Block, err := e.state.GetL2BlockByNumber(ctx, blockNumber, dbTx) - if errors.Is(err, state.ErrNotFound) { - return nil, nil - } else if err != nil { - return RPCErrorResponse(types.DefaultErrorCode, fmt.Sprintf("couldn't load block from state by number %v", blockNumber), err, true) - } + // clean fields that are not available for pending block + rpcBlock.Hash = nil + rpcBlock.Miner = nil + rpcBlock.Nonce = nil + rpcBlock.TotalDifficulty = nil - txs := l2Block.Transactions() - receipts := make([]ethTypes.Receipt, 0, len(txs)) - for _, tx := range txs { - receipt, err := e.state.GetTransactionReceipt(ctx, tx.Hash(), dbTx) - if err != nil { - return RPCErrorResponse(types.DefaultErrorCode, fmt.Sprintf("couldn't load receipt for tx %v", tx.Hash().String()), err, true) - } - receipts = append(receipts, *receipt) - } + return rpcBlock, nil + } + var err error + blockNumber, rpcErr := number.GetNumericBlockNumber(ctx, e.state, e.etherman, nil) + if rpcErr != nil { + return nil, rpcErr + } - rpcBlock, err := types.NewBlock(ctx, e.state, state.Ptr(l2Block.Hash()), l2Block, receipts, fullTx, false, includeExtraInfo, dbTx) + l2Block, err := e.state.GetL2BlockByNumber(ctx, blockNumber, nil) + if errors.Is(err, state.ErrNotFound) { + return nil, nil + } else if err != nil { + return RPCErrorResponse(types.DefaultErrorCode, fmt.Sprintf("couldn't load block from state by number %v", blockNumber), err, true) + } + + txs := l2Block.Transactions() + receipts := make([]ethTypes.Receipt, 0, len(txs)) + for _, tx := range txs { + receipt, err := e.state.GetTransactionReceipt(ctx, tx.Hash(), nil) if err != nil { - return RPCErrorResponse(types.DefaultErrorCode, fmt.Sprintf("couldn't build block response for block by number %v", blockNumber), err, true) + return RPCErrorResponse(types.DefaultErrorCode, fmt.Sprintf("couldn't load receipt for tx %v", tx.Hash().String()), err, true) } + receipts = append(receipts, *receipt) + } - return rpcBlock, nil - }) + rpcBlock, err := types.NewBlock(ctx, e.state, state.Ptr(l2Block.Hash()), l2Block, receipts, fullTx, false, includeExtraInfo, nil) + if err != nil { + return RPCErrorResponse(types.DefaultErrorCode, fmt.Sprintf("couldn't build block response for block by number %v", blockNumber), err, true) + } + + return rpcBlock, nil } // GetCode returns account code at given block number func (e *EthEndpoints) GetCode(address types.ArgAddress, blockArg *types.BlockNumberOrHash) (interface{}, types.Error) { - return e.txMan.NewDbTxScope(e.state, func(ctx context.Context, dbTx pgx.Tx) (interface{}, types.Error) { - var err error - block, rpcErr := e.getBlockByArg(ctx, blockArg, dbTx) - if rpcErr != nil { - return nil, rpcErr - } + ctx := context.Background() + var err error + block, rpcErr := e.getBlockByArg(ctx, blockArg, nil) + if rpcErr != nil { + return nil, rpcErr + } - code, err := e.state.GetCode(ctx, address.Address(), block.Root()) - if errors.Is(err, state.ErrNotFound) { - return "0x", nil - } else if err != nil { - return RPCErrorResponse(types.DefaultErrorCode, "failed to get code", err, true) - } + code, err := e.state.GetCode(ctx, address.Address(), block.Root()) + if errors.Is(err, state.ErrNotFound) { + return "0x", nil + } else if err != nil { + return RPCErrorResponse(types.DefaultErrorCode, "failed to get code", err, true) + } - return types.ArgBytes(code), nil - }) + return types.ArgBytes(code), nil } // GetCompilers eth_getCompilers @@ -511,9 +503,8 @@ func (e *EthEndpoints) GetFilterLogs(filterID string) (interface{}, types.Error) // GetLogs returns a list of logs accordingly to the provided filter func (e *EthEndpoints) GetLogs(filter LogFilter) (interface{}, types.Error) { - return e.txMan.NewDbTxScope(e.state, func(ctx context.Context, dbTx pgx.Tx) (interface{}, types.Error) { - return e.internalGetLogs(ctx, dbTx, filter) - }) + ctx := context.Background() + return e.internalGetLogs(ctx, nil, filter) } func (e *EthEndpoints) internalGetLogs(ctx context.Context, dbTx pgx.Tx, filter LogFilter) (interface{}, types.Error) { @@ -549,92 +540,123 @@ func (e *EthEndpoints) internalGetLogs(ctx context.Context, dbTx pgx.Tx, filter // GetStorageAt gets the value stored for an specific address and position func (e *EthEndpoints) GetStorageAt(address types.ArgAddress, storageKeyStr string, blockArg *types.BlockNumberOrHash) (interface{}, types.Error) { + ctx := context.Background() storageKey := types.ArgHash{} err := storageKey.UnmarshalText([]byte(storageKeyStr)) if err != nil { return RPCErrorResponse(types.DefaultErrorCode, "unable to decode storage key: hex string invalid", nil, false) } - return e.txMan.NewDbTxScope(e.state, func(ctx context.Context, dbTx pgx.Tx) (interface{}, types.Error) { - block, respErr := e.getBlockByArg(ctx, blockArg, dbTx) - if respErr != nil { - return nil, respErr - } + block, respErr := e.getBlockByArg(ctx, blockArg, nil) + if respErr != nil { + return nil, respErr + } - value, err := e.state.GetStorageAt(ctx, address.Address(), storageKey.Hash().Big(), block.Root()) - if errors.Is(err, state.ErrNotFound) { - return types.ArgBytesPtr(common.Hash{}.Bytes()), nil - } else if err != nil { - return RPCErrorResponse(types.DefaultErrorCode, "failed to get storage value from state", err, true) - } + value, err := e.state.GetStorageAt(ctx, address.Address(), storageKey.Hash().Big(), block.Root()) + if errors.Is(err, state.ErrNotFound) { + return types.ArgBytesPtr(common.Hash{}.Bytes()), nil + } else if err != nil { + return RPCErrorResponse(types.DefaultErrorCode, "failed to get storage value from state", err, true) + } - return types.ArgBytesPtr(common.BigToHash(value).Bytes()), nil - }) + return types.ArgBytesPtr(common.BigToHash(value).Bytes()), nil } // GetTransactionByBlockHashAndIndex returns information about a transaction by // block hash and transaction index position. func (e *EthEndpoints) GetTransactionByBlockHashAndIndex(hash types.ArgHash, index types.Index, includeExtraInfo *bool) (interface{}, types.Error) { - return e.txMan.NewDbTxScope(e.state, func(ctx context.Context, dbTx pgx.Tx) (interface{}, types.Error) { - tx, err := e.state.GetTransactionByL2BlockHashAndIndex(ctx, hash.Hash(), uint64(index), dbTx) - if errors.Is(err, state.ErrNotFound) { - return nil, nil - } else if err != nil { - return RPCErrorResponse(types.DefaultErrorCode, "failed to get transaction", err, true) - } - - receipt, err := e.state.GetTransactionReceipt(ctx, tx.Hash(), dbTx) - if errors.Is(err, state.ErrNotFound) { - return nil, nil - } else if err != nil { - return RPCErrorResponse(types.DefaultErrorCode, "failed to get transaction receipt", err, true) - } + ctx := context.Background() + tx, err := e.state.GetTransactionByL2BlockHashAndIndex(ctx, hash.Hash(), uint64(index), nil) + if errors.Is(err, state.ErrNotFound) { + return nil, nil + } else if err != nil { + return RPCErrorResponse(types.DefaultErrorCode, "failed to get transaction", err, true) + } - var l2Hash *common.Hash - if includeExtraInfo != nil && *includeExtraInfo { - l2h, err := e.state.GetL2TxHashByTxHash(ctx, tx.Hash(), dbTx) - if err != nil { - return RPCErrorResponse(types.DefaultErrorCode, "failed to get l2 transaction hash", err, true) - } - l2Hash = l2h - } + receipt, err := e.state.GetTransactionReceipt(ctx, tx.Hash(), nil) + if errors.Is(err, state.ErrNotFound) { + return nil, nil + } else if err != nil { + return RPCErrorResponse(types.DefaultErrorCode, "failed to get transaction receipt", err, true) + } - res, err := types.NewTransaction(*tx, receipt, false, l2Hash) + var l2Hash *common.Hash + if includeExtraInfo != nil && *includeExtraInfo { + l2h, err := e.state.GetL2TxHashByTxHash(ctx, tx.Hash(), nil) if err != nil { - return RPCErrorResponse(types.DefaultErrorCode, "failed to build transaction response", err, true) + return RPCErrorResponse(types.DefaultErrorCode, "failed to get l2 transaction hash", err, true) } + l2Hash = l2h + } - return res, nil - }) + res, err := types.NewTransaction(*tx, receipt, false, l2Hash) + if err != nil { + return RPCErrorResponse(types.DefaultErrorCode, "failed to build transaction response", err, true) + } + + return res, nil } // GetTransactionByBlockNumberAndIndex returns information about a transaction by // block number and transaction index position. func (e *EthEndpoints) GetTransactionByBlockNumberAndIndex(number *types.BlockNumber, index types.Index, includeExtraInfo *bool) (interface{}, types.Error) { - return e.txMan.NewDbTxScope(e.state, func(ctx context.Context, dbTx pgx.Tx) (interface{}, types.Error) { - var err error - blockNumber, rpcErr := number.GetNumericBlockNumber(ctx, e.state, e.etherman, dbTx) - if rpcErr != nil { - return nil, rpcErr - } + ctx := context.Background() + var err error + blockNumber, rpcErr := number.GetNumericBlockNumber(ctx, e.state, e.etherman, nil) + if rpcErr != nil { + return nil, rpcErr + } - tx, err := e.state.GetTransactionByL2BlockNumberAndIndex(ctx, blockNumber, uint64(index), dbTx) - if errors.Is(err, state.ErrNotFound) { - return nil, nil - } else if err != nil { - return RPCErrorResponse(types.DefaultErrorCode, "failed to get transaction", err, true) + tx, err := e.state.GetTransactionByL2BlockNumberAndIndex(ctx, blockNumber, uint64(index), nil) + if errors.Is(err, state.ErrNotFound) { + return nil, nil + } else if err != nil { + return RPCErrorResponse(types.DefaultErrorCode, "failed to get transaction", err, true) + } + + receipt, err := e.state.GetTransactionReceipt(ctx, tx.Hash(), nil) + if errors.Is(err, state.ErrNotFound) { + return nil, nil + } else if err != nil { + return RPCErrorResponse(types.DefaultErrorCode, "failed to get transaction receipt", err, true) + } + + var l2Hash *common.Hash + if includeExtraInfo != nil && *includeExtraInfo { + l2h, err := e.state.GetL2TxHashByTxHash(ctx, tx.Hash(), nil) + if err != nil { + return RPCErrorResponse(types.DefaultErrorCode, "failed to get l2 transaction hash", err, true) } + l2Hash = l2h + } + + res, err := types.NewTransaction(*tx, receipt, false, l2Hash) + if err != nil { + return RPCErrorResponse(types.DefaultErrorCode, "failed to build transaction response", err, true) + } + + return res, nil +} - receipt, err := e.state.GetTransactionReceipt(ctx, tx.Hash(), dbTx) +// GetTransactionByHash returns a transaction by his hash +func (e *EthEndpoints) GetTransactionByHash(hash types.ArgHash, includeExtraInfo *bool) (interface{}, types.Error) { + ctx := context.Background() + // try to get tx from state + tx, err := e.state.GetTransactionByHash(ctx, hash.Hash(), nil) + if err != nil && !errors.Is(err, state.ErrNotFound) { + return RPCErrorResponse(types.DefaultErrorCode, "failed to load transaction by hash from state", err, true) + } + if tx != nil { + receipt, err := e.state.GetTransactionReceipt(ctx, hash.Hash(), nil) if errors.Is(err, state.ErrNotFound) { - return nil, nil + return RPCErrorResponse(types.DefaultErrorCode, "transaction receipt not found", err, false) } else if err != nil { - return RPCErrorResponse(types.DefaultErrorCode, "failed to get transaction receipt", err, true) + return RPCErrorResponse(types.DefaultErrorCode, "failed to load transaction receipt from state", err, true) } var l2Hash *common.Hash if includeExtraInfo != nil && *includeExtraInfo { - l2h, err := e.state.GetL2TxHashByTxHash(ctx, tx.Hash(), dbTx) + l2h, err := e.state.GetL2TxHashByTxHash(ctx, hash.Hash(), nil) if err != nil { return RPCErrorResponse(types.DefaultErrorCode, "failed to get l2 transaction hash", err, true) } @@ -647,62 +669,27 @@ func (e *EthEndpoints) GetTransactionByBlockNumberAndIndex(number *types.BlockNu } return res, nil - }) -} - -// GetTransactionByHash returns a transaction by his hash -func (e *EthEndpoints) GetTransactionByHash(hash types.ArgHash, includeExtraInfo *bool) (interface{}, types.Error) { - return e.txMan.NewDbTxScope(e.state, func(ctx context.Context, dbTx pgx.Tx) (interface{}, types.Error) { - // try to get tx from state - tx, err := e.state.GetTransactionByHash(ctx, hash.Hash(), dbTx) - if err != nil && !errors.Is(err, state.ErrNotFound) { - return RPCErrorResponse(types.DefaultErrorCode, "failed to load transaction by hash from state", err, true) - } - if tx != nil { - receipt, err := e.state.GetTransactionReceipt(ctx, hash.Hash(), dbTx) - if errors.Is(err, state.ErrNotFound) { - return RPCErrorResponse(types.DefaultErrorCode, "transaction receipt not found", err, false) - } else if err != nil { - return RPCErrorResponse(types.DefaultErrorCode, "failed to load transaction receipt from state", err, true) - } - - var l2Hash *common.Hash - if includeExtraInfo != nil && *includeExtraInfo { - l2h, err := e.state.GetL2TxHashByTxHash(ctx, hash.Hash(), dbTx) - if err != nil { - return RPCErrorResponse(types.DefaultErrorCode, "failed to get l2 transaction hash", err, true) - } - l2Hash = l2h - } - - res, err := types.NewTransaction(*tx, receipt, false, l2Hash) - if err != nil { - return RPCErrorResponse(types.DefaultErrorCode, "failed to build transaction response", err, true) - } - - return res, nil - } + } - // if the tx does not exist in the state, look for it in the pool - if e.cfg.SequencerNodeURI != "" { - return e.getTransactionByHashFromSequencerNode(hash.Hash(), includeExtraInfo) - } - poolTx, err := e.pool.GetTransactionByHash(ctx, hash.Hash()) - if errors.Is(err, pool.ErrNotFound) { - return nil, nil - } else if err != nil { - return RPCErrorResponse(types.DefaultErrorCode, "failed to load transaction by hash from pool", err, true) - } - if poolTx.Status == pool.TxStatusPending { - tx = &poolTx.Transaction - res, err := types.NewTransaction(*tx, nil, false, nil) - if err != nil { - return RPCErrorResponse(types.DefaultErrorCode, "failed to build transaction response", err, true) - } - return res, nil - } + // if the tx does not exist in the state, look for it in the pool + if e.cfg.SequencerNodeURI != "" { + return e.getTransactionByHashFromSequencerNode(hash.Hash(), includeExtraInfo) + } + poolTx, err := e.pool.GetTransactionByHash(ctx, hash.Hash()) + if errors.Is(err, pool.ErrNotFound) { return nil, nil - }) + } else if err != nil { + return RPCErrorResponse(types.DefaultErrorCode, "failed to load transaction by hash from pool", err, true) + } + if poolTx.Status == pool.TxStatusPending { + tx = &poolTx.Transaction + res, err := types.NewTransaction(*tx, nil, false, nil) + if err != nil { + return RPCErrorResponse(types.DefaultErrorCode, "failed to build transaction response", err, true) + } + return res, nil + } + return nil, nil } func (e *EthEndpoints) getTransactionByHashFromSequencerNode(hash common.Hash, includeExtraInfo *bool) (interface{}, types.Error) { @@ -729,45 +716,44 @@ func (e *EthEndpoints) getTransactionByHashFromSequencerNode(hash common.Hash, i // GetTransactionCount returns account nonce func (e *EthEndpoints) GetTransactionCount(address types.ArgAddress, blockArg *types.BlockNumberOrHash) (interface{}, types.Error) { - return e.txMan.NewDbTxScope(e.state, func(ctx context.Context, dbTx pgx.Tx) (interface{}, types.Error) { - var ( - pendingNonce uint64 - nonce uint64 - err error - ) - - block, respErr := e.getBlockByArg(ctx, blockArg, dbTx) - if respErr != nil { - return nil, respErr - } + ctx := context.Background() + var ( + pendingNonce uint64 + nonce uint64 + err error + ) + + block, respErr := e.getBlockByArg(ctx, blockArg, nil) + if respErr != nil { + return nil, respErr + } - if blockArg != nil { - blockNumArg := blockArg.Number() - if blockNumArg != nil && *blockNumArg == types.PendingBlockNumber { - if e.cfg.SequencerNodeURI != "" { - return e.getTransactionCountFromSequencerNode(address.Address(), blockArg.Number()) - } - pendingNonce, err = e.pool.GetNonce(ctx, address.Address()) - if err != nil { - return RPCErrorResponse(types.DefaultErrorCode, "failed to count pending transactions", err, true) - } + if blockArg != nil { + blockNumArg := blockArg.Number() + if blockNumArg != nil && *blockNumArg == types.PendingBlockNumber { + if e.cfg.SequencerNodeURI != "" { + return e.getTransactionCountFromSequencerNode(address.Address(), blockArg.Number()) + } + pendingNonce, err = e.pool.GetNonce(ctx, address.Address()) + if err != nil { + return RPCErrorResponse(types.DefaultErrorCode, "failed to count pending transactions", err, true) } } + } - nonce, err = e.state.GetNonce(ctx, address.Address(), block.Root()) + nonce, err = e.state.GetNonce(ctx, address.Address(), block.Root()) - if errors.Is(err, state.ErrNotFound) { - return hex.EncodeUint64(0), nil - } else if err != nil { - return RPCErrorResponse(types.DefaultErrorCode, "failed to count transactions", err, true) - } + if errors.Is(err, state.ErrNotFound) { + return hex.EncodeUint64(0), nil + } else if err != nil { + return RPCErrorResponse(types.DefaultErrorCode, "failed to count transactions", err, true) + } - if pendingNonce > nonce { - nonce = pendingNonce - } + if pendingNonce > nonce { + nonce = pendingNonce + } - return hex.EncodeUint64(nonce), nil - }) + return hex.EncodeUint64(nonce), nil } func (e *EthEndpoints) getTransactionCountFromSequencerNode(address common.Address, number *types.BlockNumber) (interface{}, types.Error) { @@ -791,44 +777,42 @@ func (e *EthEndpoints) getTransactionCountFromSequencerNode(address common.Addre // GetBlockTransactionCountByHash returns the number of transactions in a // block from a block matching the given block hash. func (e *EthEndpoints) GetBlockTransactionCountByHash(hash types.ArgHash) (interface{}, types.Error) { - return e.txMan.NewDbTxScope(e.state, func(ctx context.Context, dbTx pgx.Tx) (interface{}, types.Error) { - c, err := e.state.GetL2BlockTransactionCountByHash(ctx, hash.Hash(), dbTx) - if err != nil { - return RPCErrorResponse(types.DefaultErrorCode, "failed to count transactions", err, true) - } + ctx := context.Background() + c, err := e.state.GetL2BlockTransactionCountByHash(ctx, hash.Hash(), nil) + if err != nil { + return RPCErrorResponse(types.DefaultErrorCode, "failed to count transactions", err, true) + } - return types.ArgUint64(c), nil - }) + return types.ArgUint64(c), nil } // GetBlockTransactionCountByNumber returns the number of transactions in a // block from a block matching the given block number. func (e *EthEndpoints) GetBlockTransactionCountByNumber(number *types.BlockNumber) (interface{}, types.Error) { - return e.txMan.NewDbTxScope(e.state, func(ctx context.Context, dbTx pgx.Tx) (interface{}, types.Error) { - if number != nil && *number == types.PendingBlockNumber { - if e.cfg.SequencerNodeURI != "" { - return e.getBlockTransactionCountByNumberFromSequencerNode(number) - } - c, err := e.pool.CountPendingTransactions(ctx) - if err != nil { - return RPCErrorResponse(types.DefaultErrorCode, "failed to count pending transactions", err, true) - } - return types.ArgUint64(c), nil - } - - var err error - blockNumber, rpcErr := number.GetNumericBlockNumber(ctx, e.state, e.etherman, dbTx) - if rpcErr != nil { - return nil, rpcErr + ctx := context.Background() + if number != nil && *number == types.PendingBlockNumber { + if e.cfg.SequencerNodeURI != "" { + return e.getBlockTransactionCountByNumberFromSequencerNode(number) } - - c, err := e.state.GetL2BlockTransactionCountByNumber(ctx, blockNumber, dbTx) + c, err := e.pool.CountPendingTransactions(ctx) if err != nil { - return RPCErrorResponse(types.DefaultErrorCode, "failed to count transactions", err, true) + return RPCErrorResponse(types.DefaultErrorCode, "failed to count pending transactions", err, true) } - return types.ArgUint64(c), nil - }) + } + + var err error + blockNumber, rpcErr := number.GetNumericBlockNumber(ctx, e.state, e.etherman, nil) + if rpcErr != nil { + return nil, rpcErr + } + + c, err := e.state.GetL2BlockTransactionCountByNumber(ctx, blockNumber, nil) + if err != nil { + return RPCErrorResponse(types.DefaultErrorCode, "failed to count transactions", err, true) + } + + return types.ArgUint64(c), nil } func (e *EthEndpoints) getBlockTransactionCountByNumberFromSequencerNode(number *types.BlockNumber) (interface{}, types.Error) { @@ -851,28 +835,27 @@ func (e *EthEndpoints) getBlockTransactionCountByNumberFromSequencerNode(number // GetTransactionReceipt returns a transaction receipt by his hash func (e *EthEndpoints) GetTransactionReceipt(hash types.ArgHash) (interface{}, types.Error) { - return e.txMan.NewDbTxScope(e.state, func(ctx context.Context, dbTx pgx.Tx) (interface{}, types.Error) { - tx, err := e.state.GetTransactionByHash(ctx, hash.Hash(), dbTx) - if errors.Is(err, state.ErrNotFound) { - return nil, nil - } else if err != nil { - return RPCErrorResponse(types.DefaultErrorCode, "failed to get tx from state", err, true) - } + ctx := context.Background() + tx, err := e.state.GetTransactionByHash(ctx, hash.Hash(), nil) + if errors.Is(err, state.ErrNotFound) { + return nil, nil + } else if err != nil { + return RPCErrorResponse(types.DefaultErrorCode, "failed to get tx from state", err, true) + } - r, err := e.state.GetTransactionReceipt(ctx, hash.Hash(), dbTx) - if errors.Is(err, state.ErrNotFound) { - return nil, nil - } else if err != nil { - return RPCErrorResponse(types.DefaultErrorCode, "failed to get tx receipt from state", err, true) - } + r, err := e.state.GetTransactionReceipt(ctx, hash.Hash(), nil) + if errors.Is(err, state.ErrNotFound) { + return nil, nil + } else if err != nil { + return RPCErrorResponse(types.DefaultErrorCode, "failed to get tx receipt from state", err, true) + } - receipt, err := types.NewReceipt(*tx, r, nil) - if err != nil { - return RPCErrorResponse(types.DefaultErrorCode, "failed to build the receipt response", err, true) - } + receipt, err := types.NewReceipt(*tx, r, nil) + if err != nil { + return RPCErrorResponse(types.DefaultErrorCode, "failed to build the receipt response", err, true) + } - return receipt, nil - }) + return receipt, nil } // NewBlockFilter creates a filter in the node, to notify when @@ -896,9 +879,8 @@ func (e *EthEndpoints) newBlockFilter(wsConn *concurrentWsConn) (interface{}, ty // to notify when the state changes (logs). To check if the state // has changed, call eth_getFilterChanges. func (e *EthEndpoints) NewFilter(filter LogFilter) (interface{}, types.Error) { - return e.txMan.NewDbTxScope(e.state, func(ctx context.Context, dbTx pgx.Tx) (interface{}, types.Error) { - return e.newFilter(ctx, nil, filter, dbTx) - }) + ctx := context.Background() + return e.newFilter(ctx, nil, filter, nil) } // internal @@ -1006,46 +988,45 @@ func (e *EthEndpoints) UninstallFilter(filterID string) (interface{}, types.Erro // Syncing returns an object with data about the sync status or false. // https://eth.wiki/json-rpc/API#eth_syncing func (e *EthEndpoints) Syncing() (interface{}, types.Error) { - return e.txMan.NewDbTxScope(e.state, func(ctx context.Context, dbTx pgx.Tx) (interface{}, types.Error) { - _, err := e.state.GetLastL2BlockNumber(ctx, dbTx) - if errors.Is(err, state.ErrStateNotSynchronized) { - return nil, types.NewRPCError(types.DefaultErrorCode, state.ErrStateNotSynchronized.Error()) - } else if err != nil { - return RPCErrorResponse(types.DefaultErrorCode, "failed to get last block number from state", err, true) - } + ctx := context.Background() + _, err := e.state.GetLastL2BlockNumber(ctx, nil) + if errors.Is(err, state.ErrStateNotSynchronized) { + return nil, types.NewRPCError(types.DefaultErrorCode, state.ErrStateNotSynchronized.Error()) + } else if err != nil { + return RPCErrorResponse(types.DefaultErrorCode, "failed to get last block number from state", err, true) + } - syncInfo, err := e.state.GetSyncingInfo(ctx, dbTx) - if err != nil { - return RPCErrorResponse(types.DefaultErrorCode, "failed to get syncing info from state", err, true) - } + syncInfo, err := e.state.GetSyncingInfo(ctx, nil) + if err != nil { + return RPCErrorResponse(types.DefaultErrorCode, "failed to get syncing info from state", err, true) + } - if !syncInfo.IsSynchronizing { - return false, nil - } - if e.cfg.SequencerNodeURI != "" { - // If we have a trusted node we ask it for the highest l2 block - res, err := e.getHighestL2BlockFromTrustedNode() - if err != nil { - log.Warnf("failed to get highest l2 block from trusted node: %v", err) + if !syncInfo.IsSynchronizing { + return false, nil + } + if e.cfg.SequencerNodeURI != "" { + // If we have a trusted node we ask it for the highest l2 block + res, err := e.getHighestL2BlockFromTrustedNode() + if err != nil { + log.Warnf("failed to get highest l2 block from trusted node: %v", err) + } else { + highestL2BlockInTrusted := res.(uint64) + if highestL2BlockInTrusted > syncInfo.CurrentBlockNumber { + syncInfo.EstimatedHighestBlock = highestL2BlockInTrusted } else { - highestL2BlockInTrusted := res.(uint64) - if highestL2BlockInTrusted > syncInfo.CurrentBlockNumber { - syncInfo.EstimatedHighestBlock = highestL2BlockInTrusted - } else { - log.Warnf("highest l2 block in trusted node (%d) is lower than the current block number in the state (%d)", highestL2BlockInTrusted, syncInfo.CurrentBlockNumber) - } + log.Warnf("highest l2 block in trusted node (%d) is lower than the current block number in the state (%d)", highestL2BlockInTrusted, syncInfo.CurrentBlockNumber) } } - return struct { - S types.ArgUint64 `json:"startingBlock"` - C types.ArgUint64 `json:"currentBlock"` - H types.ArgUint64 `json:"highestBlock"` - }{ - S: types.ArgUint64(syncInfo.InitialSyncingBlock), - C: types.ArgUint64(syncInfo.CurrentBlockNumber), - H: types.ArgUint64(syncInfo.EstimatedHighestBlock), - }, nil - }) + } + return struct { + S types.ArgUint64 `json:"startingBlock"` + C types.ArgUint64 `json:"currentBlock"` + H types.ArgUint64 `json:"highestBlock"` + }{ + S: types.ArgUint64(syncInfo.InitialSyncingBlock), + C: types.ArgUint64(syncInfo.CurrentBlockNumber), + H: types.ArgUint64(syncInfo.EstimatedHighestBlock), + }, nil } // GetUncleByBlockHashAndIndex returns information about a uncle of a @@ -1109,13 +1090,12 @@ func (e *EthEndpoints) Subscribe(wsConn *concurrentWsConn, name string, logFilte case "newHeads": return e.newBlockFilter(wsConn) case "logs": - return e.txMan.NewDbTxScope(e.state, func(ctx context.Context, dbTx pgx.Tx) (interface{}, types.Error) { - var lf LogFilter - if logFilter != nil { - lf = *logFilter - } - return e.newFilter(ctx, wsConn, lf, dbTx) - }) + ctx := context.Background() + var lf LogFilter + if logFilter != nil { + lf = *logFilter + } + return e.newFilter(ctx, wsConn, lf, nil) case "pendingTransactions", "newPendingTransactions": return e.newPendingTransactionFilter(wsConn) case "syncing": diff --git a/jsonrpc/endpoints_eth_test.go b/jsonrpc/endpoints_eth_test.go index 69291429d0..a60ce9418d 100644 --- a/jsonrpc/endpoints_eth_test.go +++ b/jsonrpc/endpoints_eth_test.go @@ -63,18 +63,8 @@ func TestBlockNumber(t *testing.T) { ExpectedError: nil, ExpectedResult: blockNumTen.Uint64(), SetupMocks: func(m *mocksWrapper) { - m.DbTx. - On("Commit", context.Background()). - Return(nil). - Once() - - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - m.State. - On("GetLastL2BlockNumber", context.Background(), m.DbTx). + On("GetLastL2BlockNumber", context.Background(), nil). Return(blockNumTen.Uint64(), nil). Once() }, @@ -84,18 +74,8 @@ func TestBlockNumber(t *testing.T) { ExpectedError: types.NewRPCError(types.DefaultErrorCode, "failed to get the last block number from state"), ExpectedResult: 0, SetupMocks: func(m *mocksWrapper) { - m.DbTx. - On("Rollback", context.Background()). - Return(nil). - Once() - - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - m.State. - On("GetLastL2BlockNumber", context.Background(), m.DbTx). + On("GetLastL2BlockNumber", context.Background(), nil). Return(uint64(0), errors.New("failed to get last block number")). Once() }, @@ -155,8 +135,6 @@ func TestCall(t *testing.T) { expectedError: nil, setupMocks: func(c Config, m *mocksWrapper, testCase *testCase) { nonce := uint64(7) - m.DbTx.On("Commit", context.Background()).Return(nil).Once() - m.State.On("BeginStateTransaction", context.Background()).Return(m.DbTx, nil).Once() txArgs := testCase.params[0].(types.TxArgs) txMatchBy := mock.MatchedBy(func(tx *ethTypes.Transaction) bool { gasPrice := big.NewInt(0).SetBytes(*txArgs.GasPrice) @@ -171,10 +149,10 @@ func TestCall(t *testing.T) { return match }) block := state.NewL2BlockWithHeader(state.NewL2Header(ðTypes.Header{Number: blockNumOne, Root: blockRoot})) - m.State.On("GetL2BlockByNumber", context.Background(), blockNumOneUint64, m.DbTx).Return(block, nil).Once() + m.State.On("GetL2BlockByNumber", context.Background(), blockNumOneUint64, nil).Return(block, nil).Once() m.State.On("GetNonce", context.Background(), *txArgs.From, blockRoot).Return(nonce, nil).Once() m.State. - On("ProcessUnsignedTransaction", context.Background(), txMatchBy, *txArgs.From, &blockNumOneUint64, true, m.DbTx). + On("ProcessUnsignedTransaction", context.Background(), txMatchBy, *txArgs.From, &blockNumOneUint64, true, nil). Return(&runtime.ExecutionResult{ReturnValue: testCase.expectedResult}, nil). Once() }, @@ -198,11 +176,9 @@ func TestCall(t *testing.T) { expectedError: nil, setupMocks: func(c Config, m *mocksWrapper, testCase *testCase) { nonce := uint64(7) - m.DbTx.On("Commit", context.Background()).Return(nil).Once() - m.State.On("BeginStateTransaction", context.Background()).Return(m.DbTx, nil).Once() block := state.NewL2BlockWithHeader(state.NewL2Header(ðTypes.Header{Number: blockNumOne, Root: blockRoot})) m.State. - On("GetL2BlockByHash", context.Background(), blockHash, m.DbTx). + On("GetL2BlockByHash", context.Background(), blockHash, nil). Return(block, nil).Once() txArgs := testCase.params[0].(types.TxArgs) txMatchBy := mock.MatchedBy(func(tx *ethTypes.Transaction) bool { @@ -218,7 +194,7 @@ func TestCall(t *testing.T) { }) m.State.On("GetNonce", context.Background(), *txArgs.From, blockRoot).Return(nonce, nil).Once() m.State. - On("ProcessUnsignedTransaction", context.Background(), txMatchBy, *txArgs.From, &blockNumOneUint64, true, m.DbTx). + On("ProcessUnsignedTransaction", context.Background(), txMatchBy, *txArgs.From, &blockNumOneUint64, true, nil). Return(&runtime.ExecutionResult{ReturnValue: testCase.expectedResult}, nil). Once() }, @@ -240,9 +216,7 @@ func TestCall(t *testing.T) { expectedError: nil, setupMocks: func(c Config, m *mocksWrapper, testCase *testCase) { nonce := uint64(7) - m.DbTx.On("Commit", context.Background()).Return(nil).Once() - m.State.On("BeginStateTransaction", context.Background()).Return(m.DbTx, nil).Once() - m.State.On("GetLastL2BlockNumber", context.Background(), m.DbTx).Return(blockNumOne.Uint64(), nil).Once() + m.State.On("GetLastL2BlockNumber", context.Background(), nil).Return(blockNumOne.Uint64(), nil).Once() txArgs := testCase.params[0].(types.TxArgs) txMatchBy := mock.MatchedBy(func(tx *ethTypes.Transaction) bool { gasPrice := big.NewInt(0).SetBytes(*txArgs.GasPrice) @@ -257,10 +231,10 @@ func TestCall(t *testing.T) { return match }) block := state.NewL2BlockWithHeader(state.NewL2Header(ðTypes.Header{Number: blockNumOne, Root: blockRoot})) - m.State.On("GetL2BlockByNumber", context.Background(), blockNumOneUint64, m.DbTx).Return(block, nil).Once() + m.State.On("GetL2BlockByNumber", context.Background(), blockNumOneUint64, nil).Return(block, nil).Once() m.State.On("GetNonce", context.Background(), *txArgs.From, blockRoot).Return(nonce, nil).Once() m.State. - On("ProcessUnsignedTransaction", context.Background(), txMatchBy, *txArgs.From, nilUint64, true, m.DbTx). + On("ProcessUnsignedTransaction", context.Background(), txMatchBy, *txArgs.From, nilUint64, true, nil). Return(&runtime.ExecutionResult{ReturnValue: testCase.expectedResult}, nil). Once() }, @@ -282,11 +256,9 @@ func TestCall(t *testing.T) { expectedError: nil, setupMocks: func(c Config, m *mocksWrapper, testCase *testCase) { nonce := uint64(7) - m.DbTx.On("Commit", context.Background()).Return(nil).Once() - m.State.On("BeginStateTransaction", context.Background()).Return(m.DbTx, nil).Once() block := state.NewL2BlockWithHeader(state.NewL2Header(ðTypes.Header{Number: blockNumTen, Root: blockRoot})) m.State. - On("GetL2BlockByHash", context.Background(), blockHash, m.DbTx). + On("GetL2BlockByHash", context.Background(), blockHash, nil). Return(block, nil).Once() txArgs := testCase.params[0].(types.TxArgs) txMatchBy := mock.MatchedBy(func(tx *ethTypes.Transaction) bool { @@ -302,7 +274,7 @@ func TestCall(t *testing.T) { }) m.State.On("GetNonce", context.Background(), *txArgs.From, blockRoot).Return(nonce, nil).Once() m.State. - On("ProcessUnsignedTransaction", context.Background(), txMatchBy, *txArgs.From, &blockNumTenUint64, true, m.DbTx). + On("ProcessUnsignedTransaction", context.Background(), txMatchBy, *txArgs.From, &blockNumTenUint64, true, nil). Return(&runtime.ExecutionResult{ReturnValue: testCase.expectedResult}, nil). Once() }, @@ -324,8 +296,6 @@ func TestCall(t *testing.T) { expectedError: nil, setupMocks: func(c Config, m *mocksWrapper, testCase *testCase) { nonce := uint64(7) - m.DbTx.On("Commit", context.Background()).Return(nil).Once() - m.State.On("BeginStateTransaction", context.Background()).Return(m.DbTx, nil).Once() txArgs := testCase.params[0].(types.TxArgs) txMatchBy := mock.MatchedBy(func(tx *ethTypes.Transaction) bool { gasPrice := big.NewInt(0).SetBytes(*txArgs.GasPrice) @@ -339,10 +309,10 @@ func TestCall(t *testing.T) { tx.Nonce() == nonce }) block := state.NewL2BlockWithHeader(state.NewL2Header(ðTypes.Header{Number: blockNumTen, Root: blockRoot})) - m.State.On("GetL2BlockByNumber", context.Background(), blockNumTenUint64, m.DbTx).Return(block, nil).Once() + m.State.On("GetL2BlockByNumber", context.Background(), blockNumTenUint64, nil).Return(block, nil).Once() m.State.On("GetNonce", context.Background(), *txArgs.From, blockRoot).Return(nonce, nil).Once() m.State. - On("ProcessUnsignedTransaction", context.Background(), txMatchBy, *txArgs.From, &blockNumTenUint64, true, m.DbTx). + On("ProcessUnsignedTransaction", context.Background(), txMatchBy, *txArgs.From, &blockNumTenUint64, true, nil). Return(&runtime.ExecutionResult{ReturnValue: testCase.expectedResult}, nil). Once() }, @@ -362,10 +332,8 @@ func TestCall(t *testing.T) { expectedError: nil, setupMocks: func(c Config, m *mocksWrapper, testCase *testCase) { blockHeader := state.NewL2Header(ðTypes.Header{GasLimit: s.Config.MaxCumulativeGasUsed}) - m.DbTx.On("Commit", context.Background()).Return(nil).Once() - m.State.On("BeginStateTransaction", context.Background()).Return(m.DbTx, nil).Once() - m.State.On("GetLastL2BlockNumber", context.Background(), m.DbTx).Return(blockNumOne.Uint64(), nil).Once() - m.State.On("GetL2BlockHeaderByNumber", context.Background(), blockNumOne.Uint64(), m.DbTx).Return(blockHeader, nil).Once() + m.State.On("GetLastL2BlockNumber", context.Background(), nil).Return(blockNumOne.Uint64(), nil).Once() + m.State.On("GetL2BlockHeaderByNumber", context.Background(), blockNumOne.Uint64(), nil).Return(blockHeader, nil).Once() txArgs := testCase.params[0].(types.TxArgs) txMatchBy := mock.MatchedBy(func(tx *ethTypes.Transaction) bool { gasPrice := big.NewInt(0).SetBytes(*txArgs.GasPrice) @@ -379,9 +347,9 @@ func TestCall(t *testing.T) { return hasTx && gasMatch && toMatch && gasPriceMatch && valueMatch && dataMatch }) block := state.NewL2BlockWithHeader(state.NewL2Header(ðTypes.Header{Number: blockNumOne, Root: blockRoot})) - m.State.On("GetL2BlockByNumber", context.Background(), blockNumOneUint64, m.DbTx).Return(block, nil).Once() + m.State.On("GetL2BlockByNumber", context.Background(), blockNumOneUint64, nil).Return(block, nil).Once() m.State. - On("ProcessUnsignedTransaction", context.Background(), txMatchBy, common.HexToAddress(state.DefaultSenderAddress), nilUint64, true, m.DbTx). + On("ProcessUnsignedTransaction", context.Background(), txMatchBy, common.HexToAddress(state.DefaultSenderAddress), nilUint64, true, nil). Return(&runtime.ExecutionResult{ReturnValue: testCase.expectedResult}, nil). Once() }, @@ -401,10 +369,8 @@ func TestCall(t *testing.T) { expectedError: nil, setupMocks: func(c Config, m *mocksWrapper, testCase *testCase) { blockHeader := state.NewL2Header(ðTypes.Header{GasLimit: s.Config.MaxCumulativeGasUsed}) - m.DbTx.On("Commit", context.Background()).Return(nil).Once() - m.State.On("BeginStateTransaction", context.Background()).Return(m.DbTx, nil).Once() - m.State.On("GetLastL2BlockNumber", context.Background(), m.DbTx).Return(blockNumOne.Uint64(), nil).Once() - m.State.On("GetL2BlockHeaderByNumber", context.Background(), blockNumOne.Uint64(), m.DbTx).Return(blockHeader, nil).Once() + m.State.On("GetLastL2BlockNumber", context.Background(), nil).Return(blockNumOne.Uint64(), nil).Once() + m.State.On("GetL2BlockHeaderByNumber", context.Background(), blockNumOne.Uint64(), nil).Return(blockHeader, nil).Once() txArgs := testCase.params[0].(types.TxArgs) txMatchBy := mock.MatchedBy(func(tx *ethTypes.Transaction) bool { gasPrice := big.NewInt(0).SetBytes(*txArgs.GasPrice) @@ -418,9 +384,9 @@ func TestCall(t *testing.T) { return hasTx && gasMatch && toMatch && gasPriceMatch && valueMatch && dataMatch }) block := state.NewL2BlockWithHeader(state.NewL2Header(ðTypes.Header{Number: blockNumOne, Root: blockRoot})) - m.State.On("GetL2BlockByNumber", context.Background(), blockNumOneUint64, m.DbTx).Return(block, nil).Once() + m.State.On("GetL2BlockByNumber", context.Background(), blockNumOneUint64, nil).Return(block, nil).Once() m.State. - On("ProcessUnsignedTransaction", context.Background(), txMatchBy, common.HexToAddress(state.DefaultSenderAddress), nilUint64, true, m.DbTx). + On("ProcessUnsignedTransaction", context.Background(), txMatchBy, common.HexToAddress(state.DefaultSenderAddress), nilUint64, true, nil). Return(&runtime.ExecutionResult{ReturnValue: testCase.expectedResult}, nil). Once() }, @@ -439,12 +405,10 @@ func TestCall(t *testing.T) { expectedResult: nil, expectedError: types.NewRPCError(types.DefaultErrorCode, "failed to get block header"), setupMocks: func(c Config, m *mocksWrapper, testCase *testCase) { - m.DbTx.On("Rollback", context.Background()).Return(nil).Once() - m.State.On("BeginStateTransaction", context.Background()).Return(m.DbTx, nil).Once() - m.State.On("GetLastL2BlockNumber", context.Background(), m.DbTx).Return(blockNumOne.Uint64(), nil).Once() + m.State.On("GetLastL2BlockNumber", context.Background(), nil).Return(blockNumOne.Uint64(), nil).Once() block := state.NewL2BlockWithHeader(state.NewL2Header(ðTypes.Header{Number: blockNumOne, Root: blockRoot})) - m.State.On("GetL2BlockByNumber", context.Background(), blockNumOneUint64, m.DbTx).Return(block, nil).Once() - m.State.On("GetL2BlockHeaderByNumber", context.Background(), blockNumOne.Uint64(), m.DbTx).Return(nil, errors.New("failed to get block header")).Once() + m.State.On("GetL2BlockByNumber", context.Background(), blockNumOneUint64, nil).Return(block, nil).Once() + m.State.On("GetL2BlockHeaderByNumber", context.Background(), blockNumOne.Uint64(), nil).Return(nil, errors.New("failed to get block header")).Once() }, }, { @@ -464,9 +428,7 @@ func TestCall(t *testing.T) { expectedError: types.NewRPCError(types.DefaultErrorCode, "failed to process unsigned transaction"), setupMocks: func(c Config, m *mocksWrapper, testCase *testCase) { nonce := uint64(7) - m.DbTx.On("Rollback", context.Background()).Return(nil).Once() - m.State.On("BeginStateTransaction", context.Background()).Return(m.DbTx, nil).Once() - m.State.On("GetLastL2BlockNumber", context.Background(), m.DbTx).Return(blockNumOne.Uint64(), nil).Once() + m.State.On("GetLastL2BlockNumber", context.Background(), nil).Return(blockNumOne.Uint64(), nil).Once() txArgs := testCase.params[0].(types.TxArgs) txMatchBy := mock.MatchedBy(func(tx *ethTypes.Transaction) bool { gasPrice := big.NewInt(0).SetBytes(*txArgs.GasPrice) @@ -481,10 +443,10 @@ func TestCall(t *testing.T) { return hasTx && gasMatch && toMatch && gasPriceMatch && valueMatch && dataMatch && nonceMatch }) block := state.NewL2BlockWithHeader(state.NewL2Header(ðTypes.Header{Number: blockNumOne, Root: blockRoot})) - m.State.On("GetL2BlockByNumber", context.Background(), blockNumOneUint64, m.DbTx).Return(block, nil).Once() + m.State.On("GetL2BlockByNumber", context.Background(), blockNumOneUint64, nil).Return(block, nil).Once() m.State.On("GetNonce", context.Background(), *txArgs.From, blockRoot).Return(nonce, nil).Once() m.State. - On("ProcessUnsignedTransaction", context.Background(), txMatchBy, *txArgs.From, nilUint64, true, m.DbTx). + On("ProcessUnsignedTransaction", context.Background(), txMatchBy, *txArgs.From, nilUint64, true, nil). Return(&runtime.ExecutionResult{Err: errors.New("failed to process unsigned transaction")}, nil). Once() }, @@ -506,9 +468,7 @@ func TestCall(t *testing.T) { expectedError: types.NewRPCError(types.DefaultErrorCode, "execution reverted"), setupMocks: func(c Config, m *mocksWrapper, testCase *testCase) { nonce := uint64(7) - m.DbTx.On("Rollback", context.Background()).Return(nil).Once() - m.State.On("BeginStateTransaction", context.Background()).Return(m.DbTx, nil).Once() - m.State.On("GetLastL2BlockNumber", context.Background(), m.DbTx).Return(blockNumOne.Uint64(), nil).Once() + m.State.On("GetLastL2BlockNumber", context.Background(), nil).Return(blockNumOne.Uint64(), nil).Once() txArgs := testCase.params[0].(types.TxArgs) txMatchBy := mock.MatchedBy(func(tx *ethTypes.Transaction) bool { gasPrice := big.NewInt(0).SetBytes(*txArgs.GasPrice) @@ -523,10 +483,10 @@ func TestCall(t *testing.T) { return hasTx && gasMatch && toMatch && gasPriceMatch && valueMatch && dataMatch && nonceMatch }) block := state.NewL2BlockWithHeader(state.NewL2Header(ðTypes.Header{Number: blockNumOne, Root: blockRoot})) - m.State.On("GetL2BlockByNumber", context.Background(), blockNumOneUint64, m.DbTx).Return(block, nil).Once() + m.State.On("GetL2BlockByNumber", context.Background(), blockNumOneUint64, nil).Return(block, nil).Once() m.State.On("GetNonce", context.Background(), *txArgs.From, blockRoot).Return(nonce, nil).Once() m.State. - On("ProcessUnsignedTransaction", context.Background(), txMatchBy, *txArgs.From, nilUint64, true, m.DbTx). + On("ProcessUnsignedTransaction", context.Background(), txMatchBy, *txArgs.From, nilUint64, true, nil). Return(&runtime.ExecutionResult{Err: runtime.ErrExecutionReverted}, nil). Once() }, @@ -677,18 +637,15 @@ func TestEstimateGas(t *testing.T) { return matchTo && matchGasPrice && matchValue && matchData && matchNonce }) - m.DbTx.On("Commit", context.Background()).Return(nil).Once() - m.State.On("BeginStateTransaction", context.Background()).Return(m.DbTx, nil).Once() - block := state.NewL2BlockWithHeader(state.NewL2Header(ðTypes.Header{Number: blockNumTen, Root: blockRoot})) - m.State.On("GetLastL2Block", context.Background(), m.DbTx).Return(block, nil).Once() + m.State.On("GetLastL2Block", context.Background(), nil).Return(block, nil).Once() m.State. On("GetNonce", context.Background(), *txArgs.From, blockRoot). Return(nonce, nil). Once() m.State. - On("EstimateGas", txMatchBy, *txArgs.From, nilUint64, m.DbTx). + On("EstimateGas", txMatchBy, *txArgs.From, nilUint64, nil). Return(*testCase.expectedResult, nil, nil). Once() }, @@ -722,14 +679,11 @@ func TestEstimateGas(t *testing.T) { return matchTo && matchGasPrice && matchValue && matchData && matchNonce }) - m.DbTx.On("Commit", context.Background()).Return(nil).Once() - m.State.On("BeginStateTransaction", context.Background()).Return(m.DbTx, nil).Once() - block := state.NewL2BlockWithHeader(state.NewL2Header(ðTypes.Header{Number: blockNumTen, Root: blockRoot})) - m.State.On("GetLastL2Block", context.Background(), m.DbTx).Return(block, nil).Once() + m.State.On("GetLastL2Block", context.Background(), nil).Return(block, nil).Once() m.State. - On("EstimateGas", txMatchBy, common.HexToAddress(state.DefaultSenderAddress), nilUint64, m.DbTx). + On("EstimateGas", txMatchBy, common.HexToAddress(state.DefaultSenderAddress), nilUint64, nil). Return(*testCase.expectedResult, nil, nil). Once() }, @@ -822,18 +776,8 @@ func TestGetBalance(t *testing.T) { expectedBalance: 0, expectedError: types.NewRPCError(types.DefaultErrorCode, "failed to get the last block number from state"), setupMocks: func(m *mocksWrapper, t *testCase) { - m.DbTx. - On("Rollback", context.Background()). - Return(nil). - Once() - - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - m.State. - On("GetLastL2Block", context.Background(), m.DbTx). + On("GetLastL2Block", context.Background(), nil). Return(nil, errors.New("failed to get last block number")).Once() }, }, @@ -846,19 +790,9 @@ func TestGetBalance(t *testing.T) { expectedBalance: 1000, expectedError: nil, setupMocks: func(m *mocksWrapper, t *testCase) { - m.DbTx. - On("Commit", context.Background()). - Return(nil). - Once() - - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - block := state.NewL2BlockWithHeader(state.NewL2Header(ðTypes.Header{Number: blockNumTen, Root: blockRoot})) m.State. - On("GetLastL2Block", context.Background(), m.DbTx). + On("GetLastL2Block", context.Background(), nil). Return(block, nil).Once() m.State. @@ -879,19 +813,9 @@ func TestGetBalance(t *testing.T) { expectedBalance: 1000, expectedError: nil, setupMocks: func(m *mocksWrapper, t *testCase) { - m.DbTx. - On("Commit", context.Background()). - Return(nil). - Once() - - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - block := state.NewL2BlockWithHeader(state.NewL2Header(ðTypes.Header{Number: blockNumTen, Root: blockRoot})) m.State. - On("GetL2BlockByHash", context.Background(), blockHash, m.DbTx). + On("GetL2BlockByHash", context.Background(), blockHash, nil). Return(block, nil). Once() @@ -910,18 +834,8 @@ func TestGetBalance(t *testing.T) { expectedBalance: 0, expectedError: nil, setupMocks: func(m *mocksWrapper, t *testCase) { - m.DbTx. - On("Commit", context.Background()). - Return(nil). - Once() - - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - block := state.NewL2BlockWithHeader(state.NewL2Header(ðTypes.Header{Number: blockNumTen, Root: blockRoot})) - m.State.On("GetLastL2Block", context.Background(), m.DbTx).Return(block, nil).Once() + m.State.On("GetLastL2Block", context.Background(), nil).Return(block, nil).Once() m.State. On("GetBalance", context.Background(), addressArg, blockRoot). @@ -938,18 +852,8 @@ func TestGetBalance(t *testing.T) { expectedBalance: 0, expectedError: types.NewRPCError(types.DefaultErrorCode, "failed to get balance from state"), setupMocks: func(m *mocksWrapper, t *testCase) { - m.DbTx. - On("Rollback", context.Background()). - Return(nil). - Once() - - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - block := state.NewL2BlockWithHeader(state.NewL2Header(ðTypes.Header{Number: blockNumTen, Root: blockRoot})) - m.State.On("GetLastL2Block", context.Background(), m.DbTx).Return(block, nil).Once() + m.State.On("GetLastL2Block", context.Background(), nil).Return(block, nil).Once() m.State. On("GetBalance", context.Background(), addressArg, blockRoot). @@ -1004,18 +908,8 @@ func TestGetL2BlockByHash(t *testing.T) { ExpectedResult: nil, ExpectedError: ethereum.NotFound, SetupMocks: func(m *mocksWrapper, tc *testCase) { - m.DbTx. - On("Commit", context.Background()). - Return(nil). - Once() - - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - m.State. - On("GetL2BlockByHash", context.Background(), tc.Hash, m.DbTx). + On("GetL2BlockByHash", context.Background(), tc.Hash, nil). Return(nil, state.ErrNotFound) }, }, @@ -1025,18 +919,8 @@ func TestGetL2BlockByHash(t *testing.T) { ExpectedResult: nil, ExpectedError: types.NewRPCError(types.DefaultErrorCode, "failed to get block by hash from state"), SetupMocks: func(m *mocksWrapper, tc *testCase) { - m.DbTx. - On("Rollback", context.Background()). - Return(nil). - Once() - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - - m.State. - On("GetL2BlockByHash", context.Background(), tc.Hash, m.DbTx). + On("GetL2BlockByHash", context.Background(), tc.Hash, nil). Return(nil, errors.New("failed to get block from state")). Once() }, @@ -1059,24 +943,14 @@ func TestGetL2BlockByHash(t *testing.T) { } block := state.NewL2Block(state.NewL2Header(tc.ExpectedResult.Header()), tc.ExpectedResult.Transactions(), uncles, []*ethTypes.Receipt{ethTypes.NewReceipt([]byte{}, false, uint64(0))}, st) - m.DbTx. - On("Commit", context.Background()). - Return(nil). - Once() - - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - m.State. - On("GetL2BlockByHash", context.Background(), tc.Hash, m.DbTx). + On("GetL2BlockByHash", context.Background(), tc.Hash, nil). Return(block, nil). Once() for _, tx := range tc.ExpectedResult.Transactions() { m.State. - On("GetTransactionReceipt", context.Background(), tx.Hash(), m.DbTx). + On("GetTransactionReceipt", context.Background(), tx.Hash(), nil). Return(ethTypes.NewReceipt([]byte{}, false, uint64(0)), nil). Once() } @@ -1264,18 +1138,8 @@ func TestGetL2BlockByNumber(t *testing.T) { ExpectedResult: nil, ExpectedError: nil, SetupMocks: func(m *mocksWrapper, tc *testCase) { - m.DbTx. - On("Commit", context.Background()). - Return(nil). - Once() - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - - m.State. - On("GetL2BlockByNumber", context.Background(), tc.Number.Uint64(), m.DbTx). + On("GetL2BlockByNumber", context.Background(), tc.Number.Uint64(), nil). Return(nil, state.ErrNotFound) }, }, @@ -1285,30 +1149,20 @@ func TestGetL2BlockByNumber(t *testing.T) { ExpectedResult: rpcBlock, ExpectedError: nil, SetupMocks: func(m *mocksWrapper, tc *testCase) { - m.DbTx. - On("Commit", context.Background()). - Return(nil). - Once() - - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - m.State. - On("GetL2BlockByNumber", context.Background(), tc.Number.Uint64(), m.DbTx). + On("GetL2BlockByNumber", context.Background(), tc.Number.Uint64(), nil). Return(l2Block, nil). Once() for _, receipt := range receipts { m.State. - On("GetTransactionReceipt", context.Background(), receipt.TxHash, m.DbTx). + On("GetTransactionReceipt", context.Background(), receipt.TxHash, nil). Return(receipt, nil). Once() } for _, signedTx := range signedTransactions { m.State. - On("GetL2TxHashByTxHash", context.Background(), signedTx.Hash(), m.DbTx). + On("GetL2TxHashByTxHash", context.Background(), signedTx.Hash(), nil). Return(state.Ptr(signedTx.Hash()), nil). Once() } @@ -1320,35 +1174,25 @@ func TestGetL2BlockByNumber(t *testing.T) { ExpectedResult: rpcBlock, ExpectedError: nil, SetupMocks: func(m *mocksWrapper, tc *testCase) { - m.DbTx. - On("Commit", context.Background()). - Return(nil). - Once() - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - - m.State. - On("GetLastL2BlockNumber", context.Background(), m.DbTx). + On("GetLastL2BlockNumber", context.Background(), nil). Return(uint64(tc.ExpectedResult.Number), nil). Once() m.State. - On("GetL2BlockByNumber", context.Background(), uint64(tc.ExpectedResult.Number), m.DbTx). + On("GetL2BlockByNumber", context.Background(), uint64(tc.ExpectedResult.Number), nil). Return(l2Block, nil). Once() for _, receipt := range receipts { m.State. - On("GetTransactionReceipt", context.Background(), receipt.TxHash, m.DbTx). + On("GetTransactionReceipt", context.Background(), receipt.TxHash, nil). Return(receipt, nil). Once() } for _, signedTx := range signedTransactions { m.State. - On("GetL2TxHashByTxHash", context.Background(), signedTx.Hash(), m.DbTx). + On("GetL2TxHashByTxHash", context.Background(), signedTx.Hash(), nil). Return(state.Ptr(signedTx.Hash()), nil). Once() } @@ -1360,18 +1204,8 @@ func TestGetL2BlockByNumber(t *testing.T) { ExpectedResult: nil, ExpectedError: types.NewRPCError(types.DefaultErrorCode, "failed to get the last block number from state"), SetupMocks: func(m *mocksWrapper, tc *testCase) { - m.DbTx. - On("Rollback", context.Background()). - Return(nil). - Once() - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - - m.State. - On("GetLastL2BlockNumber", context.Background(), m.DbTx). + On("GetLastL2BlockNumber", context.Background(), nil). Return(uint64(0), errors.New("failed to get last block number")). Once() }, @@ -1382,23 +1216,13 @@ func TestGetL2BlockByNumber(t *testing.T) { ExpectedResult: nil, ExpectedError: types.NewRPCError(types.DefaultErrorCode, "couldn't load block from state by number 1"), SetupMocks: func(m *mocksWrapper, tc *testCase) { - m.DbTx. - On("Rollback", context.Background()). - Return(nil). - Once() - - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - m.State. - On("GetLastL2BlockNumber", context.Background(), m.DbTx). + On("GetLastL2BlockNumber", context.Background(), nil). Return(uint64(1), nil). Once() m.State. - On("GetL2BlockByNumber", context.Background(), uint64(1), m.DbTx). + On("GetL2BlockByNumber", context.Background(), uint64(1), nil). Return(nil, errors.New("failed to load block by number")). Once() }, @@ -1427,18 +1251,8 @@ func TestGetL2BlockByNumber(t *testing.T) { tc.ExpectedResult.Nonce = nil tc.ExpectedResult.TotalDifficulty = nil - m.DbTx. - On("Commit", context.Background()). - Return(nil). - Once() - - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - m.State. - On("GetLastL2Block", context.Background(), m.DbTx). + On("GetLastL2Block", context.Background(), nil). Return(lastBlock, nil). Once() }, @@ -1449,18 +1263,8 @@ func TestGetL2BlockByNumber(t *testing.T) { ExpectedResult: nil, ExpectedError: types.NewRPCError(types.DefaultErrorCode, "couldn't load last block from state to compute the pending block"), SetupMocks: func(m *mocksWrapper, tc *testCase) { - m.DbTx. - On("Rollback", context.Background()). - Return(nil). - Once() - - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - m.State. - On("GetLastL2Block", context.Background(), m.DbTx). + On("GetLastL2Block", context.Background(), nil). Return(nil, errors.New("failed to load last block")). Once() }, @@ -1626,18 +1430,8 @@ func TestGetCode(t *testing.T) { ExpectedError: types.NewRPCError(types.DefaultErrorCode, "failed to get the last block number from state"), SetupMocks: func(m *mocksWrapper, tc *testCase) { - m.DbTx. - On("Rollback", context.Background()). - Return(nil). - Once() - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - - m.State. - On("GetLastL2Block", context.Background(), m.DbTx). + On("GetLastL2Block", context.Background(), nil). Return(nil, errors.New("failed to get last block number")). Once() }, @@ -1652,18 +1446,8 @@ func TestGetCode(t *testing.T) { ExpectedError: types.NewRPCError(types.DefaultErrorCode, "failed to get code"), SetupMocks: func(m *mocksWrapper, tc *testCase) { - m.DbTx. - On("Rollback", context.Background()). - Return(nil). - Once() - - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - block := state.NewL2BlockWithHeader(state.NewL2Header(ðTypes.Header{Number: blockNumOne, Root: blockRoot})) - m.State.On("GetL2BlockByNumber", context.Background(), blockNumOne.Uint64(), m.DbTx).Return(block, nil).Once() + m.State.On("GetL2BlockByNumber", context.Background(), blockNumOne.Uint64(), nil).Return(block, nil).Once() m.State. On("GetCode", context.Background(), addressArg, blockRoot). @@ -1681,18 +1465,8 @@ func TestGetCode(t *testing.T) { ExpectedError: nil, SetupMocks: func(m *mocksWrapper, tc *testCase) { - m.DbTx. - On("Commit", context.Background()). - Return(nil). - Once() - - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - block := state.NewL2BlockWithHeader(state.NewL2Header(ðTypes.Header{Number: blockNumOne, Root: blockRoot})) - m.State.On("GetL2BlockByNumber", context.Background(), blockNumOne.Uint64(), m.DbTx).Return(block, nil).Once() + m.State.On("GetL2BlockByNumber", context.Background(), blockNumOne.Uint64(), nil).Return(block, nil).Once() m.State. On("GetCode", context.Background(), addressArg, blockRoot). @@ -1710,18 +1484,8 @@ func TestGetCode(t *testing.T) { ExpectedError: nil, SetupMocks: func(m *mocksWrapper, tc *testCase) { - m.DbTx. - On("Commit", context.Background()). - Return(nil). - Once() - - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - block := state.NewL2BlockWithHeader(state.NewL2Header(ðTypes.Header{Number: blockNumOne, Root: blockRoot})) - m.State.On("GetL2BlockByNumber", context.Background(), blockNumOne.Uint64(), m.DbTx).Return(block, nil).Once() + m.State.On("GetL2BlockByNumber", context.Background(), blockNumOne.Uint64(), nil).Return(block, nil).Once() m.State. On("GetCode", context.Background(), addressArg, blockRoot). @@ -1738,19 +1502,9 @@ func TestGetCode(t *testing.T) { ExpectedResult: []byte{1, 2, 3}, ExpectedError: nil, SetupMocks: func(m *mocksWrapper, tc *testCase) { - m.DbTx. - On("Commit", context.Background()). - Return(nil). - Once() - - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - block := state.NewL2BlockWithHeader(state.NewL2Header(ðTypes.Header{Number: blockNumTen, Root: blockRoot})) m.State. - On("GetL2BlockByHash", context.Background(), blockHash, m.DbTx). + On("GetL2BlockByHash", context.Background(), blockHash, nil). Return(block, nil). Once() @@ -1819,18 +1573,8 @@ func TestGetStorageAt(t *testing.T) { ExpectedError: types.NewRPCError(types.DefaultErrorCode, "failed to get the last block number from state"), SetupMocks: func(m *mocksWrapper, tc *testCase) { - m.DbTx. - On("Rollback", context.Background()). - Return(nil). - Once() - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - - m.State. - On("GetLastL2Block", context.Background(), m.DbTx). + On("GetLastL2Block", context.Background(), nil). Return(nil, errors.New("failed to get last block number")). Once() }, @@ -1848,19 +1592,9 @@ func TestGetStorageAt(t *testing.T) { ExpectedError: types.NewRPCError(types.DefaultErrorCode, "failed to get storage value from state"), SetupMocks: func(m *mocksWrapper, tc *testCase) { - m.DbTx. - On("Rollback", context.Background()). - Return(nil). - Once() - - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - blockNumber := big.NewInt(1) block := state.NewL2BlockWithHeader(state.NewL2Header(ðTypes.Header{Number: blockNumber, Root: blockRoot})) - m.State.On("GetL2BlockByNumber", context.Background(), blockNumber.Uint64(), m.DbTx).Return(block, nil).Once() + m.State.On("GetL2BlockByNumber", context.Background(), blockNumber.Uint64(), nil).Return(block, nil).Once() m.State. On("GetStorageAt", context.Background(), addressArg, keyArg.Big(), blockRoot). @@ -1881,19 +1615,9 @@ func TestGetStorageAt(t *testing.T) { ExpectedError: nil, SetupMocks: func(m *mocksWrapper, tc *testCase) { - m.DbTx. - On("Commit", context.Background()). - Return(nil). - Once() - - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - blockNumber := big.NewInt(1) block := state.NewL2BlockWithHeader(state.NewL2Header(ðTypes.Header{Number: blockNumber, Root: blockRoot})) - m.State.On("GetL2BlockByNumber", context.Background(), blockNumber.Uint64(), m.DbTx).Return(block, nil).Once() + m.State.On("GetL2BlockByNumber", context.Background(), blockNumber.Uint64(), nil).Return(block, nil).Once() m.State. On("GetStorageAt", context.Background(), addressArg, keyArg.Big(), blockRoot). @@ -1914,19 +1638,9 @@ func TestGetStorageAt(t *testing.T) { ExpectedError: nil, SetupMocks: func(m *mocksWrapper, tc *testCase) { - m.DbTx. - On("Commit", context.Background()). - Return(nil). - Once() - - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - blockNumber := big.NewInt(1) block := state.NewL2BlockWithHeader(state.NewL2Header(ðTypes.Header{Number: blockNumber, Root: blockRoot})) - m.State.On("GetL2BlockByNumber", context.Background(), blockNumber.Uint64(), m.DbTx).Return(block, nil).Once() + m.State.On("GetL2BlockByNumber", context.Background(), blockNumber.Uint64(), nil).Return(block, nil).Once() m.State. On("GetStorageAt", context.Background(), addressArg, keyArg.Big(), blockRoot). @@ -1947,19 +1661,9 @@ func TestGetStorageAt(t *testing.T) { ExpectedError: nil, SetupMocks: func(m *mocksWrapper, tc *testCase) { - m.DbTx. - On("Commit", context.Background()). - Return(nil). - Once() - - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - block := state.NewL2BlockWithHeader(state.NewL2Header(ðTypes.Header{Number: blockNumTen, Root: blockRoot})) m.State. - On("GetL2BlockByHash", context.Background(), blockHash, m.DbTx). + On("GetL2BlockByHash", context.Background(), blockHash, nil). Return(block, nil). Once() @@ -2030,18 +1734,8 @@ func TestSyncing(t *testing.T) { ExpectedResult: nil, ExpectedError: types.NewRPCError(types.DefaultErrorCode, "failed to get last block number from state"), SetupMocks: func(m *mocksWrapper, tc testCase) { - m.DbTx. - On("Rollback", context.Background()). - Return(nil). - Once() - - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - m.State. - On("GetLastL2BlockNumber", context.Background(), m.DbTx). + On("GetLastL2BlockNumber", context.Background(), nil). Return(uint64(0), errors.New("failed to get last l2 block number from state")). Once() }, @@ -2051,23 +1745,13 @@ func TestSyncing(t *testing.T) { ExpectedResult: nil, ExpectedError: types.NewRPCError(types.DefaultErrorCode, "failed to get syncing info from state"), SetupMocks: func(m *mocksWrapper, tc testCase) { - m.DbTx. - On("Rollback", context.Background()). - Return(nil). - Once() - - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - m.State. - On("GetLastL2BlockNumber", context.Background(), m.DbTx). + On("GetLastL2BlockNumber", context.Background(), nil). Return(uint64(10), nil). Once() m.State. - On("GetSyncingInfo", context.Background(), m.DbTx). + On("GetSyncingInfo", context.Background(), nil). Return(state.SyncingInfo{}, errors.New("failed to get syncing info from state")). Once() }, @@ -2077,23 +1761,13 @@ func TestSyncing(t *testing.T) { ExpectedResult: ðereum.SyncProgress{StartingBlock: 1, CurrentBlock: 2, HighestBlock: 3}, ExpectedError: nil, SetupMocks: func(m *mocksWrapper, tc testCase) { - m.DbTx. - On("Commit", context.Background()). - Return(nil). - Once() - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - - m.State. - On("GetLastL2BlockNumber", context.Background(), m.DbTx). + On("GetLastL2BlockNumber", context.Background(), nil). Return(uint64(10), nil). Once() m.State. - On("GetSyncingInfo", context.Background(), m.DbTx). + On("GetSyncingInfo", context.Background(), nil). Return(state.SyncingInfo{InitialSyncingBlock: 1, CurrentBlockNumber: 2, EstimatedHighestBlock: 3, IsSynchronizing: true}, nil). Once() }, @@ -2103,23 +1777,13 @@ func TestSyncing(t *testing.T) { ExpectedResult: nil, ExpectedError: nil, SetupMocks: func(m *mocksWrapper, tc testCase) { - m.DbTx. - On("Commit", context.Background()). - Return(nil). - Once() - - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - m.State. - On("GetLastL2BlockNumber", context.Background(), m.DbTx). + On("GetLastL2BlockNumber", context.Background(), nil). Return(uint64(10), nil). Once() m.State. - On("GetSyncingInfo", context.Background(), m.DbTx). + On("GetSyncingInfo", context.Background(), nil). Return(state.SyncingInfo{InitialSyncingBlock: 1, CurrentBlockNumber: 1, EstimatedHighestBlock: 3, IsSynchronizing: false}, nil). Once() }, @@ -2129,23 +1793,13 @@ func TestSyncing(t *testing.T) { ExpectedResult: nil, ExpectedError: nil, SetupMocks: func(m *mocksWrapper, tc testCase) { - m.DbTx. - On("Commit", context.Background()). - Return(nil). - Once() - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - - m.State. - On("GetLastL2BlockNumber", context.Background(), m.DbTx). + On("GetLastL2BlockNumber", context.Background(), nil). Return(uint64(10), nil). Once() m.State. - On("GetSyncingInfo", context.Background(), m.DbTx). + On("GetSyncingInfo", context.Background(), nil). Return(state.SyncingInfo{InitialSyncingBlock: 1, CurrentBlockNumber: 2, EstimatedHighestBlock: 3, IsSynchronizing: false}, nil). Once() }, @@ -2207,18 +1861,9 @@ func TestGetTransactionL2onByBlockHashAndIndex(t *testing.T) { ExpectedError: nil, SetupMocks: func(m *mocksWrapper, tc testCase) { tx := tc.ExpectedResult - m.DbTx. - On("Commit", context.Background()). - Return(nil). - Once() - - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() m.State. - On("GetTransactionByL2BlockHashAndIndex", context.Background(), tc.Hash, uint64(tc.Index), m.DbTx). + On("GetTransactionByL2BlockHashAndIndex", context.Background(), tc.Hash, uint64(tc.Index), nil). Return(tx, nil). Once() @@ -2228,7 +1873,7 @@ func TestGetTransactionL2onByBlockHashAndIndex(t *testing.T) { receipt.TransactionIndex = tc.Index m.State. - On("GetTransactionReceipt", context.Background(), tx.Hash(), m.DbTx). + On("GetTransactionReceipt", context.Background(), tx.Hash(), nil). Return(receipt, nil). Once() }, @@ -2240,18 +1885,8 @@ func TestGetTransactionL2onByBlockHashAndIndex(t *testing.T) { ExpectedResult: nil, ExpectedError: ethereum.NotFound, SetupMocks: func(m *mocksWrapper, tc testCase) { - m.DbTx. - On("Commit", context.Background()). - Return(nil). - Once() - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - - m.State. - On("GetTransactionByL2BlockHashAndIndex", context.Background(), tc.Hash, uint64(tc.Index), m.DbTx). + On("GetTransactionByL2BlockHashAndIndex", context.Background(), tc.Hash, uint64(tc.Index), nil). Return(nil, state.ErrNotFound). Once() }, @@ -2263,18 +1898,8 @@ func TestGetTransactionL2onByBlockHashAndIndex(t *testing.T) { ExpectedResult: nil, ExpectedError: types.NewRPCError(types.DefaultErrorCode, "failed to get transaction"), SetupMocks: func(m *mocksWrapper, tc testCase) { - m.DbTx. - On("Rollback", context.Background()). - Return(nil). - Once() - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - - m.State. - On("GetTransactionByL2BlockHashAndIndex", context.Background(), tc.Hash, uint64(tc.Index), m.DbTx). + On("GetTransactionByL2BlockHashAndIndex", context.Background(), tc.Hash, uint64(tc.Index), nil). Return(nil, errors.New("failed to get transaction by block and index from state")). Once() }, @@ -2287,23 +1912,14 @@ func TestGetTransactionL2onByBlockHashAndIndex(t *testing.T) { ExpectedError: ethereum.NotFound, SetupMocks: func(m *mocksWrapper, tc testCase) { tx := ethTypes.NewTransaction(0, common.Address{}, big.NewInt(0), 0, big.NewInt(0), []byte{}) - m.DbTx. - On("Commit", context.Background()). - Return(nil). - Once() m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - - m.State. - On("GetTransactionByL2BlockHashAndIndex", context.Background(), tc.Hash, uint64(tc.Index), m.DbTx). + On("GetTransactionByL2BlockHashAndIndex", context.Background(), tc.Hash, uint64(tc.Index), nil). Return(tx, nil). Once() m.State. - On("GetTransactionReceipt", context.Background(), tx.Hash(), m.DbTx). + On("GetTransactionReceipt", context.Background(), tx.Hash(), nil). Return(nil, state.ErrNotFound). Once() }, @@ -2316,23 +1932,14 @@ func TestGetTransactionL2onByBlockHashAndIndex(t *testing.T) { ExpectedError: types.NewRPCError(types.DefaultErrorCode, "failed to get transaction receipt"), SetupMocks: func(m *mocksWrapper, tc testCase) { tx := ethTypes.NewTransaction(0, common.Address{}, big.NewInt(0), 0, big.NewInt(0), []byte{}) - m.DbTx. - On("Rollback", context.Background()). - Return(nil). - Once() - - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() m.State. - On("GetTransactionByL2BlockHashAndIndex", context.Background(), tc.Hash, uint64(tc.Index), m.DbTx). + On("GetTransactionByL2BlockHashAndIndex", context.Background(), tc.Hash, uint64(tc.Index), nil). Return(tx, nil). Once() m.State. - On("GetTransactionReceipt", context.Background(), tx.Hash(), m.DbTx). + On("GetTransactionReceipt", context.Background(), tx.Hash(), nil). Return(nil, errors.New("failed to get transaction receipt from state")). Once() }, @@ -2395,18 +2002,9 @@ func TestGetTransactionByBlockNumberAndIndex(t *testing.T) { SetupMocks: func(m *mocksWrapper, tc testCase) { tx := tc.ExpectedResult blockNumber, _ := encoding.DecodeUint64orHex(&tc.BlockNumber) - m.DbTx. - On("Commit", context.Background()). - Return(nil). - Once() m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - - m.State. - On("GetTransactionByL2BlockNumberAndIndex", context.Background(), blockNumber, uint64(tc.Index), m.DbTx). + On("GetTransactionByL2BlockNumberAndIndex", context.Background(), blockNumber, uint64(tc.Index), nil). Return(tx, nil). Once() @@ -2415,7 +2013,7 @@ func TestGetTransactionByBlockNumberAndIndex(t *testing.T) { receipt.BlockNumber = big.NewInt(1) receipt.TransactionIndex = tc.Index m.State. - On("GetTransactionReceipt", context.Background(), tx.Hash(), m.DbTx). + On("GetTransactionReceipt", context.Background(), tx.Hash(), nil). Return(receipt, nil). Once() }, @@ -2427,18 +2025,8 @@ func TestGetTransactionByBlockNumberAndIndex(t *testing.T) { ExpectedResult: nil, ExpectedError: types.NewRPCError(types.DefaultErrorCode, "failed to get the last block number from state"), SetupMocks: func(m *mocksWrapper, tc testCase) { - m.DbTx. - On("Rollback", context.Background()). - Return(nil). - Once() - - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - m.State. - On("GetLastL2BlockNumber", context.Background(), m.DbTx). + On("GetLastL2BlockNumber", context.Background(), nil). Return(uint64(0), errors.New("failed to get last block number")). Once() }, @@ -2451,18 +2039,8 @@ func TestGetTransactionByBlockNumberAndIndex(t *testing.T) { ExpectedError: nil, SetupMocks: func(m *mocksWrapper, tc testCase) { blockNumber, _ := encoding.DecodeUint64orHex(&tc.BlockNumber) - m.DbTx. - On("Commit", context.Background()). - Return(nil). - Once() - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - - m.State. - On("GetTransactionByL2BlockNumberAndIndex", context.Background(), blockNumber, uint64(tc.Index), m.DbTx). + On("GetTransactionByL2BlockNumberAndIndex", context.Background(), blockNumber, uint64(tc.Index), nil). Return(nil, state.ErrNotFound). Once() }, @@ -2475,18 +2053,8 @@ func TestGetTransactionByBlockNumberAndIndex(t *testing.T) { ExpectedError: types.NewRPCError(types.DefaultErrorCode, "failed to get transaction"), SetupMocks: func(m *mocksWrapper, tc testCase) { blockNumber, _ := encoding.DecodeUint64orHex(&tc.BlockNumber) - m.DbTx. - On("Rollback", context.Background()). - Return(nil). - Once() - - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - m.State. - On("GetTransactionByL2BlockNumberAndIndex", context.Background(), blockNumber, uint64(tc.Index), m.DbTx). + On("GetTransactionByL2BlockNumberAndIndex", context.Background(), blockNumber, uint64(tc.Index), nil). Return(nil, errors.New("failed to get transaction by block and index from state")). Once() }, @@ -2501,23 +2069,13 @@ func TestGetTransactionByBlockNumberAndIndex(t *testing.T) { tx := ethTypes.NewTransaction(0, common.Address{}, big.NewInt(0), 0, big.NewInt(0), []byte{}) blockNumber, _ := encoding.DecodeUint64orHex(&tc.BlockNumber) - m.DbTx. - On("Commit", context.Background()). - Return(nil). - Once() - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - - m.State. - On("GetTransactionByL2BlockNumberAndIndex", context.Background(), blockNumber, uint64(tc.Index), m.DbTx). + On("GetTransactionByL2BlockNumberAndIndex", context.Background(), blockNumber, uint64(tc.Index), nil). Return(tx, nil). Once() m.State. - On("GetTransactionReceipt", context.Background(), tx.Hash(), m.DbTx). + On("GetTransactionReceipt", context.Background(), tx.Hash(), nil). Return(nil, state.ErrNotFound). Once() }, @@ -2532,23 +2090,13 @@ func TestGetTransactionByBlockNumberAndIndex(t *testing.T) { tx := ethTypes.NewTransaction(0, common.Address{}, big.NewInt(0), 0, big.NewInt(0), []byte{}) blockNumber, _ := encoding.DecodeUint64orHex(&tc.BlockNumber) - m.DbTx. - On("Rollback", context.Background()). - Return(nil). - Once() - - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - m.State. - On("GetTransactionByL2BlockNumberAndIndex", context.Background(), blockNumber, uint64(tc.Index), m.DbTx). + On("GetTransactionByL2BlockNumberAndIndex", context.Background(), blockNumber, uint64(tc.Index), nil). Return(tx, nil). Once() m.State. - On("GetTransactionReceipt", context.Background(), tx.Hash(), m.DbTx). + On("GetTransactionReceipt", context.Background(), tx.Hash(), nil). Return(nil, errors.New("failed to get transaction receipt from state")). Once() }, @@ -2615,18 +2163,8 @@ func TestGetTransactionByHash(t *testing.T) { ExpectedResult: signedTx, ExpectedError: nil, SetupMocks: func(m *mocksWrapper, tc testCase) { - m.DbTx. - On("Commit", context.Background()). - Return(nil). - Once() - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - - m.State. - On("GetTransactionByHash", context.Background(), tc.Hash, m.DbTx). + On("GetTransactionByHash", context.Background(), tc.Hash, nil). Return(tc.ExpectedResult, nil). Once() @@ -2635,7 +2173,7 @@ func TestGetTransactionByHash(t *testing.T) { receipt.BlockNumber = big.NewInt(1) m.State. - On("GetTransactionReceipt", context.Background(), tc.Hash, m.DbTx). + On("GetTransactionReceipt", context.Background(), tc.Hash, nil). Return(receipt, nil). Once() }, @@ -2647,18 +2185,8 @@ func TestGetTransactionByHash(t *testing.T) { ExpectedResult: ethTypes.NewTransaction(1, common.Address{}, big.NewInt(1), 1, big.NewInt(1), []byte{}), ExpectedError: nil, SetupMocks: func(m *mocksWrapper, tc testCase) { - m.DbTx. - On("Commit", context.Background()). - Return(nil). - Once() - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - - m.State. - On("GetTransactionByHash", context.Background(), tc.Hash, m.DbTx). + On("GetTransactionByHash", context.Background(), tc.Hash, nil). Return(nil, state.ErrNotFound). Once() @@ -2675,18 +2203,8 @@ func TestGetTransactionByHash(t *testing.T) { ExpectedResult: nil, ExpectedError: ethereum.NotFound, SetupMocks: func(m *mocksWrapper, tc testCase) { - m.DbTx. - On("Commit", context.Background()). - Return(nil). - Once() - - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - m.State. - On("GetTransactionByHash", context.Background(), tc.Hash, m.DbTx). + On("GetTransactionByHash", context.Background(), tc.Hash, nil). Return(nil, state.ErrNotFound). Once() @@ -2703,18 +2221,8 @@ func TestGetTransactionByHash(t *testing.T) { ExpectedResult: nil, ExpectedError: types.NewRPCError(types.DefaultErrorCode, "failed to load transaction by hash from state"), SetupMocks: func(m *mocksWrapper, tc testCase) { - m.DbTx. - On("Rollback", context.Background()). - Return(nil). - Once() - - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - m.State. - On("GetTransactionByHash", context.Background(), tc.Hash, m.DbTx). + On("GetTransactionByHash", context.Background(), tc.Hash, nil). Return(nil, errors.New("failed to load transaction by hash from state")). Once() }, @@ -2726,18 +2234,8 @@ func TestGetTransactionByHash(t *testing.T) { ExpectedResult: nil, ExpectedError: types.NewRPCError(types.DefaultErrorCode, "failed to load transaction by hash from pool"), SetupMocks: func(m *mocksWrapper, tc testCase) { - m.DbTx. - On("Rollback", context.Background()). - Return(nil). - Once() - - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - m.State. - On("GetTransactionByHash", context.Background(), tc.Hash, m.DbTx). + On("GetTransactionByHash", context.Background(), tc.Hash, nil). Return(nil, state.ErrNotFound). Once() @@ -2755,23 +2253,14 @@ func TestGetTransactionByHash(t *testing.T) { ExpectedError: types.NewRPCError(types.DefaultErrorCode, "transaction receipt not found"), SetupMocks: func(m *mocksWrapper, tc testCase) { tx := ðTypes.Transaction{} - m.DbTx. - On("Rollback", context.Background()). - Return(nil). - Once() - - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() m.State. - On("GetTransactionByHash", context.Background(), tc.Hash, m.DbTx). + On("GetTransactionByHash", context.Background(), tc.Hash, nil). Return(tx, nil). Once() m.State. - On("GetTransactionReceipt", context.Background(), tc.Hash, m.DbTx). + On("GetTransactionReceipt", context.Background(), tc.Hash, nil). Return(nil, state.ErrNotFound). Once() }, @@ -2784,23 +2273,14 @@ func TestGetTransactionByHash(t *testing.T) { ExpectedError: types.NewRPCError(types.DefaultErrorCode, "failed to load transaction receipt from state"), SetupMocks: func(m *mocksWrapper, tc testCase) { tx := ðTypes.Transaction{} - m.DbTx. - On("Rollback", context.Background()). - Return(nil). - Once() m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - - m.State. - On("GetTransactionByHash", context.Background(), tc.Hash, m.DbTx). + On("GetTransactionByHash", context.Background(), tc.Hash, nil). Return(tx, nil). Once() m.State. - On("GetTransactionReceipt", context.Background(), tc.Hash, m.DbTx). + On("GetTransactionReceipt", context.Background(), tc.Hash, nil). Return(nil, errors.New("failed to load transaction receipt from state")). Once() }, @@ -2851,18 +2331,8 @@ func TestGetBlockTransactionCountByHash(t *testing.T) { ExpectedResult: uint(10), ExpectedError: nil, SetupMocks: func(m *mocksWrapper, tc testCase) { - m.DbTx. - On("Commit", context.Background()). - Return(nil). - Once() - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - - m.State. - On("GetL2BlockTransactionCountByHash", context.Background(), tc.BlockHash, m.DbTx). + On("GetL2BlockTransactionCountByHash", context.Background(), tc.BlockHash, nil). Return(uint64(10), nil). Once() }, @@ -2873,18 +2343,8 @@ func TestGetBlockTransactionCountByHash(t *testing.T) { ExpectedResult: 0, ExpectedError: types.NewRPCError(types.DefaultErrorCode, "failed to count transactions"), SetupMocks: func(m *mocksWrapper, tc testCase) { - m.DbTx. - On("Rollback", context.Background()). - Return(nil). - Once() - - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - m.State. - On("GetL2BlockTransactionCountByHash", context.Background(), tc.BlockHash, m.DbTx). + On("GetL2BlockTransactionCountByHash", context.Background(), tc.BlockHash, nil). Return(uint64(0), errors.New("failed to count txs")). Once() }, @@ -2932,23 +2392,14 @@ func TestGetBlockTransactionCountByNumber(t *testing.T) { ExpectedError: nil, SetupMocks: func(m *mocksWrapper, tc testCase) { blockNumber := uint64(10) - m.DbTx. - On("Commit", context.Background()). - Return(nil). - Once() - - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() m.State. - On("GetLastL2BlockNumber", context.Background(), m.DbTx). + On("GetLastL2BlockNumber", context.Background(), nil). Return(blockNumber, nil). Once() m.State. - On("GetL2BlockTransactionCountByNumber", context.Background(), blockNumber, m.DbTx). + On("GetL2BlockTransactionCountByNumber", context.Background(), blockNumber, nil). Return(uint64(10), nil). Once() }, @@ -2959,16 +2410,6 @@ func TestGetBlockTransactionCountByNumber(t *testing.T) { ExpectedResult: uint(10), ExpectedError: nil, SetupMocks: func(m *mocksWrapper, tc testCase) { - m.DbTx. - On("Commit", context.Background()). - Return(nil). - Once() - - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - m.Pool. On("CountPendingTransactions", context.Background()). Return(uint64(10), nil). @@ -2981,18 +2422,8 @@ func TestGetBlockTransactionCountByNumber(t *testing.T) { ExpectedResult: 0, ExpectedError: types.NewRPCError(types.DefaultErrorCode, "failed to get the last block number from state"), SetupMocks: func(m *mocksWrapper, tc testCase) { - m.DbTx. - On("Rollback", context.Background()). - Return(nil). - Once() - - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - m.State. - On("GetLastL2BlockNumber", context.Background(), m.DbTx). + On("GetLastL2BlockNumber", context.Background(), nil). Return(uint64(0), errors.New("failed to get last block number")). Once() }, @@ -3004,23 +2435,14 @@ func TestGetBlockTransactionCountByNumber(t *testing.T) { ExpectedError: types.NewRPCError(types.DefaultErrorCode, "failed to count transactions"), SetupMocks: func(m *mocksWrapper, tc testCase) { blockNumber := uint64(10) - m.DbTx. - On("Rollback", context.Background()). - Return(nil). - Once() m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - - m.State. - On("GetLastL2BlockNumber", context.Background(), m.DbTx). + On("GetLastL2BlockNumber", context.Background(), nil). Return(blockNumber, nil). Once() m.State. - On("GetL2BlockTransactionCountByNumber", context.Background(), blockNumber, m.DbTx). + On("GetL2BlockTransactionCountByNumber", context.Background(), blockNumber, nil). Return(uint64(0), errors.New("failed to count")). Once() }, @@ -3031,16 +2453,6 @@ func TestGetBlockTransactionCountByNumber(t *testing.T) { ExpectedResult: 0, ExpectedError: types.NewRPCError(types.DefaultErrorCode, "failed to count pending transactions"), SetupMocks: func(m *mocksWrapper, tc testCase) { - m.DbTx. - On("Rollback", context.Background()). - Return(nil). - Once() - - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - m.Pool. On("CountPendingTransactions", context.Background()). Return(uint64(0), errors.New("failed to count")). @@ -3093,18 +2505,8 @@ func TestGetTransactionCount(t *testing.T) { ExpectedResult: uint(10), ExpectedError: nil, SetupMocks: func(m *mocksWrapper, tc testCase) { - m.DbTx. - On("Commit", context.Background()). - Return(nil). - Once() - - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - block := state.NewL2BlockWithHeader(state.NewL2Header(ðTypes.Header{Number: blockNumTen, Root: blockRoot})) - m.State.On("GetLastL2Block", context.Background(), m.DbTx).Return(block, nil).Once() + m.State.On("GetLastL2Block", context.Background(), nil).Return(block, nil).Once() m.State. On("GetNonce", context.Background(), addressArg, blockRoot). @@ -3121,19 +2523,9 @@ func TestGetTransactionCount(t *testing.T) { ExpectedResult: uint(10), ExpectedError: nil, SetupMocks: func(m *mocksWrapper, tc testCase) { - m.DbTx. - On("Commit", context.Background()). - Return(nil). - Once() - - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - block := state.NewL2BlockWithHeader(state.NewL2Header(ðTypes.Header{Number: blockNumTen, Root: blockRoot})) m.State. - On("GetL2BlockByHash", context.Background(), blockHash, m.DbTx). + On("GetL2BlockByHash", context.Background(), blockHash, nil). Return(block, nil). Once() @@ -3152,23 +2544,13 @@ func TestGetTransactionCount(t *testing.T) { ExpectedResult: 0, ExpectedError: nil, SetupMocks: func(m *mocksWrapper, tc testCase) { - m.DbTx. - On("Commit", context.Background()). - Return(nil). - Once() - - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - m.State. - On("GetLastL2BlockNumber", context.Background(), m.DbTx). + On("GetLastL2BlockNumber", context.Background(), nil). Return(blockNumTen.Uint64(), nil). Once() block := state.NewL2BlockWithHeader(state.NewL2Header(ðTypes.Header{Number: blockNumTen, Root: blockRoot})) - m.State.On("GetL2BlockByNumber", context.Background(), blockNumTenUint64, m.DbTx).Return(block, nil).Once() + m.State.On("GetL2BlockByNumber", context.Background(), blockNumTenUint64, nil).Return(block, nil).Once() m.State. On("GetNonce", context.Background(), addressArg, blockRoot). @@ -3185,18 +2567,8 @@ func TestGetTransactionCount(t *testing.T) { ExpectedResult: 0, ExpectedError: types.NewRPCError(types.DefaultErrorCode, "failed to get the last block number from state"), SetupMocks: func(m *mocksWrapper, tc testCase) { - m.DbTx. - On("Rollback", context.Background()). - Return(nil). - Once() - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - - m.State. - On("GetLastL2BlockNumber", context.Background(), m.DbTx). + On("GetLastL2BlockNumber", context.Background(), nil). Return(uint64(0), errors.New("failed to get last block number")). Once() }, @@ -3210,23 +2582,13 @@ func TestGetTransactionCount(t *testing.T) { ExpectedResult: 0, ExpectedError: types.NewRPCError(types.DefaultErrorCode, "failed to count transactions"), SetupMocks: func(m *mocksWrapper, tc testCase) { - m.DbTx. - On("Rollback", context.Background()). - Return(nil). - Once() - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - - m.State. - On("GetLastL2BlockNumber", context.Background(), m.DbTx). + On("GetLastL2BlockNumber", context.Background(), nil). Return(blockNumTen.Uint64(), nil). Once() block := state.NewL2BlockWithHeader(state.NewL2Header(ðTypes.Header{Number: blockNumTen, Root: blockRoot})) - m.State.On("GetL2BlockByNumber", context.Background(), blockNumTenUint64, m.DbTx).Return(block, nil).Once() + m.State.On("GetL2BlockByNumber", context.Background(), blockNumTenUint64, nil).Return(block, nil).Once() m.State. On("GetNonce", context.Background(), addressArg, blockRoot). @@ -3336,23 +2698,13 @@ func TestGetTransactionReceipt(t *testing.T) { ExpectedResult: &rpcReceipt, ExpectedError: nil, SetupMocks: func(m *mocksWrapper, tc testCase) { - m.DbTx. - On("Commit", context.Background()). - Return(nil). - Once() - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - - m.State. - On("GetTransactionByHash", context.Background(), tc.Hash, m.DbTx). + On("GetTransactionByHash", context.Background(), tc.Hash, nil). Return(signedTx, nil). Once() m.State. - On("GetTransactionReceipt", context.Background(), tc.Hash, m.DbTx). + On("GetTransactionReceipt", context.Background(), tc.Hash, nil). Return(receipt, nil). Once() }, @@ -3363,18 +2715,8 @@ func TestGetTransactionReceipt(t *testing.T) { ExpectedResult: nil, ExpectedError: nil, SetupMocks: func(m *mocksWrapper, tc testCase) { - m.DbTx. - On("Commit", context.Background()). - Return(nil). - Once() - - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - m.State. - On("GetTransactionByHash", context.Background(), tc.Hash, m.DbTx). + On("GetTransactionByHash", context.Background(), tc.Hash, nil). Return(nil, state.ErrNotFound). Once() }, @@ -3385,18 +2727,8 @@ func TestGetTransactionReceipt(t *testing.T) { ExpectedResult: nil, ExpectedError: types.NewRPCError(types.DefaultErrorCode, "failed to get tx from state"), SetupMocks: func(m *mocksWrapper, tc testCase) { - m.DbTx. - On("Rollback", context.Background()). - Return(nil). - Once() - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - - m.State. - On("GetTransactionByHash", context.Background(), tc.Hash, m.DbTx). + On("GetTransactionByHash", context.Background(), tc.Hash, nil). Return(nil, errors.New("failed to get tx")). Once() }, @@ -3407,23 +2739,13 @@ func TestGetTransactionReceipt(t *testing.T) { ExpectedResult: nil, ExpectedError: nil, SetupMocks: func(m *mocksWrapper, tc testCase) { - m.DbTx. - On("Commit", context.Background()). - Return(nil). - Once() - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - - m.State. - On("GetTransactionByHash", context.Background(), tc.Hash, m.DbTx). + On("GetTransactionByHash", context.Background(), tc.Hash, nil). Return(signedTx, nil). Once() m.State. - On("GetTransactionReceipt", context.Background(), tc.Hash, m.DbTx). + On("GetTransactionReceipt", context.Background(), tc.Hash, nil). Return(nil, state.ErrNotFound). Once() }, @@ -3434,23 +2756,13 @@ func TestGetTransactionReceipt(t *testing.T) { ExpectedResult: nil, ExpectedError: types.NewRPCError(types.DefaultErrorCode, "failed to get tx receipt from state"), SetupMocks: func(m *mocksWrapper, tc testCase) { - m.DbTx. - On("Rollback", context.Background()). - Return(nil). - Once() - - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - m.State. - On("GetTransactionByHash", context.Background(), tc.Hash, m.DbTx). + On("GetTransactionByHash", context.Background(), tc.Hash, nil). Return(signedTx, nil). Once() m.State. - On("GetTransactionReceipt", context.Background(), tc.Hash, m.DbTx). + On("GetTransactionReceipt", context.Background(), tc.Hash, nil). Return(nil, errors.New("failed to get tx receipt from state")). Once() }, @@ -3461,23 +2773,13 @@ func TestGetTransactionReceipt(t *testing.T) { ExpectedResult: nil, ExpectedError: types.NewRPCError(types.DefaultErrorCode, "failed to build the receipt response"), SetupMocks: func(m *mocksWrapper, tc testCase) { - m.DbTx. - On("Rollback", context.Background()). - Return(nil). - Once() - - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - m.State. - On("GetTransactionByHash", context.Background(), tc.Hash, m.DbTx). + On("GetTransactionByHash", context.Background(), tc.Hash, nil). Return(tx, nil). Once() m.State. - On("GetTransactionReceipt", context.Background(), tc.Hash, m.DbTx). + On("GetTransactionReceipt", context.Background(), tc.Hash, nil). Return(receipt, nil). Once() }, @@ -3852,16 +3154,6 @@ func TestNewFilter(t *testing.T) { ExpectedResult: "1", ExpectedError: nil, SetupMocks: func(m *mocksWrapper, tc testCase) { - m.DbTx. - On("Commit", context.Background()). - Return(nil). - Once() - - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - m.Storage. On("NewLogFilter", mock.IsType(&concurrentWsConn{}), mock.IsType(LogFilter{})). Return("1", nil). @@ -3876,16 +3168,6 @@ func TestNewFilter(t *testing.T) { ExpectedResult: "1", ExpectedError: nil, SetupMocks: func(m *mocksWrapper, tc testCase) { - m.DbTx. - On("Commit", context.Background()). - Return(nil). - Once() - - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - m.Storage. On("NewLogFilter", mock.IsType(&concurrentWsConn{}), mock.IsType(LogFilter{})). Return("1", nil). @@ -3901,15 +3183,6 @@ func TestNewFilter(t *testing.T) { ExpectedResult: "", ExpectedError: types.NewRPCError(types.InvalidParamsErrorCode, "invalid block range"), SetupMocks: func(m *mocksWrapper, tc testCase) { - m.DbTx. - On("Rollback", context.Background()). - Return(nil). - Once() - - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() }, }, { @@ -3921,15 +3194,6 @@ func TestNewFilter(t *testing.T) { ExpectedResult: "", ExpectedError: types.NewRPCError(types.InvalidParamsErrorCode, "logs are limited to a 10000 block range"), SetupMocks: func(m *mocksWrapper, tc testCase) { - m.DbTx. - On("Rollback", context.Background()). - Return(nil). - Once() - - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() }, }, { @@ -3940,15 +3204,6 @@ func TestNewFilter(t *testing.T) { ExpectedResult: "", ExpectedError: types.NewRPCError(types.DefaultErrorCode, "failed to create new log filter"), SetupMocks: func(m *mocksWrapper, tc testCase) { - m.DbTx. - On("Rollback", context.Background()). - Return(nil). - Once() - - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() m.Storage. On("NewLogFilter", mock.IsType(&concurrentWsConn{}), mock.IsType(LogFilter{})). Return("", errors.New("failed to add new filter")). @@ -4215,18 +3470,8 @@ func TestGetLogs(t *testing.T) { logs = append(logs, &l) } - m.DbTx. - On("Commit", context.Background()). - Return(nil). - Once() - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - - m.State. - On("GetLogs", context.Background(), tc.Filter.FromBlock.Uint64(), tc.Filter.ToBlock.Uint64(), tc.Filter.Addresses, tc.Filter.Topics, tc.Filter.BlockHash, since, m.DbTx). + On("GetLogs", context.Background(), tc.Filter.FromBlock.Uint64(), tc.Filter.ToBlock.Uint64(), tc.Filter.Addresses, tc.Filter.Topics, tc.Filter.BlockHash, since, nil). Return(logs, nil). Once() }, @@ -4244,18 +3489,8 @@ func TestGetLogs(t *testing.T) { }, SetupMocks: func(m *mocksWrapper, tc testCase) { var since *time.Time - m.DbTx. - On("Rollback", context.Background()). - Return(nil). - Once() - - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - m.State. - On("GetLogs", context.Background(), tc.Filter.FromBlock.Uint64(), tc.Filter.ToBlock.Uint64(), tc.Filter.Addresses, tc.Filter.Topics, tc.Filter.BlockHash, since, m.DbTx). + On("GetLogs", context.Background(), tc.Filter.FromBlock.Uint64(), tc.Filter.ToBlock.Uint64(), tc.Filter.Addresses, tc.Filter.Topics, tc.Filter.BlockHash, since, nil). Return(nil, errors.New("failed to get logs from state")). Once() }, @@ -4272,18 +3507,8 @@ func TestGetLogs(t *testing.T) { tc.ExpectedError = types.NewRPCError(types.DefaultErrorCode, "failed to get the last block number from state") }, SetupMocks: func(m *mocksWrapper, tc testCase) { - m.DbTx. - On("Rollback", context.Background()). - Return(nil). - Once() - - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - m.State. - On("GetLastL2BlockNumber", context.Background(), m.DbTx). + On("GetLastL2BlockNumber", context.Background(), nil). Return(uint64(0), errors.New("failed to get last block number from state")). Once() }, @@ -4300,18 +3525,8 @@ func TestGetLogs(t *testing.T) { tc.ExpectedError = types.NewRPCError(types.DefaultErrorCode, "failed to get the last block number from state") }, SetupMocks: func(m *mocksWrapper, tc testCase) { - m.DbTx. - On("Rollback", context.Background()). - Return(nil). - Once() - - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - m.State. - On("GetLastL2BlockNumber", context.Background(), m.DbTx). + On("GetLastL2BlockNumber", context.Background(), nil). Return(uint64(0), errors.New("failed to get last block number from state")). Once() }, @@ -4328,15 +3543,6 @@ func TestGetLogs(t *testing.T) { tc.ExpectedError = types.NewRPCError(types.InvalidParamsErrorCode, "logs are limited to a 10000 block range") }, SetupMocks: func(m *mocksWrapper, tc testCase) { - m.DbTx. - On("Rollback", context.Background()). - Return(nil). - Once() - - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() }, }, { @@ -4352,18 +3558,9 @@ func TestGetLogs(t *testing.T) { }, SetupMocks: func(m *mocksWrapper, tc testCase) { var since *time.Time - m.DbTx. - On("Rollback", context.Background()). - Return(nil). - Once() m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - - m.State. - On("GetLogs", context.Background(), tc.Filter.FromBlock.Uint64(), tc.Filter.ToBlock.Uint64(), tc.Filter.Addresses, tc.Filter.Topics, tc.Filter.BlockHash, since, m.DbTx). + On("GetLogs", context.Background(), tc.Filter.FromBlock.Uint64(), tc.Filter.ToBlock.Uint64(), tc.Filter.Addresses, tc.Filter.Topics, tc.Filter.BlockHash, since, nil). Return(nil, state.ErrMaxLogsCountLimitExceeded). Once() }, @@ -4444,23 +3641,13 @@ func TestGetFilterLogs(t *testing.T) { Parameters: logFilter, } - m.DbTx. - On("Commit", context.Background()). - Return(nil). - Once() - - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - m.Storage. On("GetFilter", tc.FilterID). Return(filter, nil). Once() m.State. - On("GetLogs", context.Background(), uint64(*logFilter.FromBlock), uint64(*logFilter.ToBlock), logFilter.Addresses, logFilter.Topics, logFilter.BlockHash, since, m.DbTx). + On("GetLogs", context.Background(), uint64(*logFilter.FromBlock), uint64(*logFilter.ToBlock), logFilter.Addresses, logFilter.Topics, logFilter.BlockHash, since, nil). Return(logs, nil). Once() }, @@ -5198,16 +4385,6 @@ func TestSubscribeNewLogs(t *testing.T) { } }, SetupMocks: func(m *mocksWrapper, tc testCase) { - m.DbTx. - On("Commit", context.Background()). - Return(nil). - Once() - - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - m.Storage. On("NewLogFilter", mock.IsType(&concurrentWsConn{}), mock.IsType(LogFilter{})). Return("0x1", nil). @@ -5223,16 +4400,6 @@ func TestSubscribeNewLogs(t *testing.T) { } }, SetupMocks: func(m *mocksWrapper, tc testCase) { - m.DbTx. - On("Rollback", context.Background()). - Return(nil). - Once() - - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - m.Storage. On("NewLogFilter", mock.IsType(&concurrentWsConn{}), mock.IsType(LogFilter{})). Return("", fmt.Errorf("failed to add filter to storage")). @@ -5248,15 +4415,6 @@ func TestSubscribeNewLogs(t *testing.T) { } }, SetupMocks: func(m *mocksWrapper, tc testCase) { - m.DbTx. - On("Rollback", context.Background()). - Return(nil). - Once() - - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() }, }, } diff --git a/jsonrpc/endpoints_zkevm.go b/jsonrpc/endpoints_zkevm.go index cb56dade8c..f159885cb3 100644 --- a/jsonrpc/endpoints_zkevm.go +++ b/jsonrpc/endpoints_zkevm.go @@ -27,7 +27,6 @@ type ZKEVMEndpoints struct { pool types.PoolInterface state types.StateInterface etherman types.EthermanInterface - txMan DBTxManager } // NewZKEVMEndpoints returns ZKEVMEndpoints @@ -42,359 +41,346 @@ func NewZKEVMEndpoints(cfg Config, pool types.PoolInterface, state types.StateIn // ConsolidatedBlockNumber returns last block number related to the last verified batch func (z *ZKEVMEndpoints) ConsolidatedBlockNumber() (interface{}, types.Error) { - return z.txMan.NewDbTxScope(z.state, func(ctx context.Context, dbTx pgx.Tx) (interface{}, types.Error) { - lastBlockNumber, err := z.state.GetLastConsolidatedL2BlockNumber(ctx, dbTx) - if err != nil { - const errorMessage = "failed to get last consolidated block number from state" - log.Errorf("%v:%v", errorMessage, err) - return nil, types.NewRPCError(types.DefaultErrorCode, errorMessage) - } + ctx := context.Background() + lastBlockNumber, err := z.state.GetLastConsolidatedL2BlockNumber(ctx, nil) + if err != nil { + const errorMessage = "failed to get last consolidated block number from state" + log.Errorf("%v:%v", errorMessage, err) + return nil, types.NewRPCError(types.DefaultErrorCode, errorMessage) + } - return hex.EncodeUint64(lastBlockNumber), nil - }) + return hex.EncodeUint64(lastBlockNumber), nil } // IsBlockConsolidated returns the consolidation status of a provided block number func (z *ZKEVMEndpoints) IsBlockConsolidated(blockNumber types.ArgUint64) (interface{}, types.Error) { - return z.txMan.NewDbTxScope(z.state, func(ctx context.Context, dbTx pgx.Tx) (interface{}, types.Error) { - IsL2BlockConsolidated, err := z.state.IsL2BlockConsolidated(ctx, uint64(blockNumber), dbTx) - if err != nil { - const errorMessage = "failed to check if the block is consolidated" - log.Errorf("%v: %v", errorMessage, err) - return nil, types.NewRPCError(types.DefaultErrorCode, errorMessage) - } + ctx := context.Background() + IsL2BlockConsolidated, err := z.state.IsL2BlockConsolidated(ctx, uint64(blockNumber), nil) + if err != nil { + const errorMessage = "failed to check if the block is consolidated" + log.Errorf("%v: %v", errorMessage, err) + return nil, types.NewRPCError(types.DefaultErrorCode, errorMessage) + } - return IsL2BlockConsolidated, nil - }) + return IsL2BlockConsolidated, nil } // IsBlockVirtualized returns the virtualization status of a provided block number func (z *ZKEVMEndpoints) IsBlockVirtualized(blockNumber types.ArgUint64) (interface{}, types.Error) { - return z.txMan.NewDbTxScope(z.state, func(ctx context.Context, dbTx pgx.Tx) (interface{}, types.Error) { - IsL2BlockVirtualized, err := z.state.IsL2BlockVirtualized(ctx, uint64(blockNumber), dbTx) - if err != nil { - const errorMessage = "failed to check if the block is virtualized" - log.Errorf("%v: %v", errorMessage, err) - return nil, types.NewRPCError(types.DefaultErrorCode, errorMessage) - } + ctx := context.Background() + IsL2BlockVirtualized, err := z.state.IsL2BlockVirtualized(ctx, uint64(blockNumber), nil) + if err != nil { + const errorMessage = "failed to check if the block is virtualized" + log.Errorf("%v: %v", errorMessage, err) + return nil, types.NewRPCError(types.DefaultErrorCode, errorMessage) + } - return IsL2BlockVirtualized, nil - }) + return IsL2BlockVirtualized, nil } // BatchNumberByBlockNumber returns the batch number from which the passed block number is created func (z *ZKEVMEndpoints) BatchNumberByBlockNumber(blockNumber types.ArgUint64) (interface{}, types.Error) { - return z.txMan.NewDbTxScope(z.state, func(ctx context.Context, dbTx pgx.Tx) (interface{}, types.Error) { - batchNum, err := z.state.BatchNumberByL2BlockNumber(ctx, uint64(blockNumber), dbTx) - if errors.Is(err, state.ErrNotFound) { - return nil, nil - } else if err != nil { - const errorMessage = "failed to get batch number from block number" - log.Errorf("%v: %v", errorMessage, err.Error()) - return nil, types.NewRPCError(types.DefaultErrorCode, errorMessage) - } + ctx := context.Background() + batchNum, err := z.state.BatchNumberByL2BlockNumber(ctx, uint64(blockNumber), nil) + if errors.Is(err, state.ErrNotFound) { + return nil, nil + } else if err != nil { + const errorMessage = "failed to get batch number from block number" + log.Errorf("%v: %v", errorMessage, err.Error()) + return nil, types.NewRPCError(types.DefaultErrorCode, errorMessage) + } - return hex.EncodeUint64(batchNum), nil - }) + return hex.EncodeUint64(batchNum), nil } // BatchNumber returns the latest trusted batch number func (z *ZKEVMEndpoints) BatchNumber() (interface{}, types.Error) { - return z.txMan.NewDbTxScope(z.state, func(ctx context.Context, dbTx pgx.Tx) (interface{}, types.Error) { - lastBatchNumber, err := z.state.GetLastBatchNumber(ctx, dbTx) - if err != nil { - return "0x0", types.NewRPCError(types.DefaultErrorCode, "failed to get the last batch number from state") - } + ctx := context.Background() + lastBatchNumber, err := z.state.GetLastBatchNumber(ctx, nil) + if err != nil { + return "0x0", types.NewRPCError(types.DefaultErrorCode, "failed to get the last batch number from state") + } - return hex.EncodeUint64(lastBatchNumber), nil - }) + return hex.EncodeUint64(lastBatchNumber), nil } // VirtualBatchNumber returns the latest virtualized batch number func (z *ZKEVMEndpoints) VirtualBatchNumber() (interface{}, types.Error) { - return z.txMan.NewDbTxScope(z.state, func(ctx context.Context, dbTx pgx.Tx) (interface{}, types.Error) { - lastBatchNumber, err := z.state.GetLastVirtualBatchNum(ctx, dbTx) - if err != nil { - return "0x0", types.NewRPCError(types.DefaultErrorCode, "failed to get the last virtual batch number from state") - } + ctx := context.Background() + lastBatchNumber, err := z.state.GetLastVirtualBatchNum(ctx, nil) + if err != nil { + return "0x0", types.NewRPCError(types.DefaultErrorCode, "failed to get the last virtual batch number from state") + } - return hex.EncodeUint64(lastBatchNumber), nil - }) + return hex.EncodeUint64(lastBatchNumber), nil } // VerifiedBatchNumber returns the latest verified batch number func (z *ZKEVMEndpoints) VerifiedBatchNumber() (interface{}, types.Error) { - return z.txMan.NewDbTxScope(z.state, func(ctx context.Context, dbTx pgx.Tx) (interface{}, types.Error) { - lastBatch, err := z.state.GetLastVerifiedBatch(ctx, dbTx) - if err != nil { - return "0x0", types.NewRPCError(types.DefaultErrorCode, "failed to get the last verified batch number from state") - } - return hex.EncodeUint64(lastBatch.BatchNumber), nil - }) + ctx := context.Background() + lastBatch, err := z.state.GetLastVerifiedBatch(ctx, nil) + if err != nil { + return "0x0", types.NewRPCError(types.DefaultErrorCode, "failed to get the last verified batch number from state") + } + return hex.EncodeUint64(lastBatch.BatchNumber), nil } // GetBatchByNumber returns information about a batch by batch number func (z *ZKEVMEndpoints) GetBatchByNumber(batchNumber types.BatchNumber, fullTx bool) (interface{}, types.Error) { - return z.txMan.NewDbTxScope(z.state, func(ctx context.Context, dbTx pgx.Tx) (interface{}, types.Error) { - var err error - batchNumber, rpcErr := batchNumber.GetNumericBatchNumber(ctx, z.state, z.etherman, dbTx) - if rpcErr != nil { - return nil, rpcErr - } + ctx := context.Background() + var err error + numericBatchNumber, rpcErr := batchNumber.GetNumericBatchNumber(ctx, z.state, z.etherman, nil) + if rpcErr != nil { + return nil, rpcErr + } - batch, err := z.state.GetBatchByNumber(ctx, batchNumber, dbTx) - if errors.Is(err, state.ErrNotFound) { - return nil, nil - } else if err != nil { - return RPCErrorResponse(types.DefaultErrorCode, fmt.Sprintf("couldn't load batch from state by number %v", batchNumber), err, true) - } - batchTimestamp, err := z.state.GetBatchTimestamp(ctx, batchNumber, nil, dbTx) - if err != nil { - return RPCErrorResponse(types.DefaultErrorCode, fmt.Sprintf("couldn't load batch timestamp from state by number %v", batchNumber), err, true) - } + batch, err := z.state.GetBatchByNumber(ctx, numericBatchNumber, nil) + if errors.Is(err, state.ErrNotFound) { + return nil, nil + } else if err != nil { + return RPCErrorResponse(types.DefaultErrorCode, fmt.Sprintf("couldn't load batch from state by number %v", numericBatchNumber), err, true) + } + batchTimestamp, err := z.state.GetBatchTimestamp(ctx, numericBatchNumber, nil, nil) + if err != nil { + return RPCErrorResponse(types.DefaultErrorCode, fmt.Sprintf("couldn't load batch timestamp from state by number %v", numericBatchNumber), err, true) + } - if batchTimestamp == nil { - batch.Timestamp = time.Time{} - } else { - batch.Timestamp = *batchTimestamp - } + if batchTimestamp == nil { + batch.Timestamp = time.Time{} + } else { + batch.Timestamp = *batchTimestamp + } - txs, _, err := z.state.GetTransactionsByBatchNumber(ctx, batchNumber, dbTx) - if !errors.Is(err, state.ErrNotFound) && err != nil { - return RPCErrorResponse(types.DefaultErrorCode, fmt.Sprintf("couldn't load batch txs from state by number %v", batchNumber), err, true) - } + txs, _, err := z.state.GetTransactionsByBatchNumber(ctx, numericBatchNumber, nil) + if !errors.Is(err, state.ErrNotFound) && err != nil { + return RPCErrorResponse(types.DefaultErrorCode, fmt.Sprintf("couldn't load batch txs from state by number %v", numericBatchNumber), err, true) + } - receipts := make([]ethTypes.Receipt, 0, len(txs)) - for _, tx := range txs { - receipt, err := z.state.GetTransactionReceipt(ctx, tx.Hash(), dbTx) - if err != nil { - return RPCErrorResponse(types.DefaultErrorCode, fmt.Sprintf("couldn't load receipt for tx %v", tx.Hash().String()), err, true) - } - receipts = append(receipts, *receipt) + receipts := make([]ethTypes.Receipt, 0, len(txs)) + for _, tx := range txs { + receipt, err := z.state.GetTransactionReceipt(ctx, tx.Hash(), nil) + if err != nil { + return RPCErrorResponse(types.DefaultErrorCode, fmt.Sprintf("couldn't load receipt for tx %v", tx.Hash().String()), err, true) } + receipts = append(receipts, *receipt) + } - virtualBatch, err := z.state.GetVirtualBatch(ctx, batchNumber, dbTx) - if err != nil && !errors.Is(err, state.ErrNotFound) { - return RPCErrorResponse(types.DefaultErrorCode, fmt.Sprintf("couldn't load virtual batch from state by number %v", batchNumber), err, true) - } + virtualBatch, err := z.state.GetVirtualBatch(ctx, numericBatchNumber, nil) + if err != nil && !errors.Is(err, state.ErrNotFound) { + return RPCErrorResponse(types.DefaultErrorCode, fmt.Sprintf("couldn't load virtual batch from state by number %v", numericBatchNumber), err, true) + } - verifiedBatch, err := z.state.GetVerifiedBatch(ctx, batchNumber, dbTx) - if err != nil && !errors.Is(err, state.ErrNotFound) { - return RPCErrorResponse(types.DefaultErrorCode, fmt.Sprintf("couldn't load virtual batch from state by number %v", batchNumber), err, true) - } + verifiedBatch, err := z.state.GetVerifiedBatch(ctx, numericBatchNumber, nil) + if err != nil && !errors.Is(err, state.ErrNotFound) { + return RPCErrorResponse(types.DefaultErrorCode, fmt.Sprintf("couldn't load virtual batch from state by number %v", numericBatchNumber), err, true) + } - ger, err := z.state.GetExitRootByGlobalExitRoot(ctx, batch.GlobalExitRoot, dbTx) - if err != nil && !errors.Is(err, state.ErrNotFound) { - return RPCErrorResponse(types.DefaultErrorCode, fmt.Sprintf("couldn't load full GER from state by number %v", batchNumber), err, true) - } else if errors.Is(err, state.ErrNotFound) { - ger = &state.GlobalExitRoot{} - } + ger, err := z.state.GetExitRootByGlobalExitRoot(ctx, batch.GlobalExitRoot, nil) + if err != nil && !errors.Is(err, state.ErrNotFound) { + return RPCErrorResponse(types.DefaultErrorCode, fmt.Sprintf("couldn't load full GER from state by number %v", numericBatchNumber), err, true) + } else if errors.Is(err, state.ErrNotFound) { + ger = &state.GlobalExitRoot{} + } - blocks, err := z.state.GetL2BlocksByBatchNumber(ctx, batchNumber, dbTx) - if err != nil { - return RPCErrorResponse(types.DefaultErrorCode, fmt.Sprintf("couldn't load blocks associated to the batch %v", batchNumber), err, true) - } + blocks, err := z.state.GetL2BlocksByBatchNumber(ctx, numericBatchNumber, nil) + if err != nil { + return RPCErrorResponse(types.DefaultErrorCode, fmt.Sprintf("couldn't load blocks associated to the batch %v", numericBatchNumber), err, true) + } - batch.Transactions = txs - rpcBatch, err := types.NewBatch(ctx, z.state, batch, virtualBatch, verifiedBatch, blocks, receipts, fullTx, true, ger, dbTx) - if err != nil { - return RPCErrorResponse(types.DefaultErrorCode, fmt.Sprintf("couldn't build the batch %v response", batchNumber), err, true) - } - return rpcBatch, nil - }) + batch.Transactions = txs + rpcBatch, err := types.NewBatch(ctx, z.state, batch, virtualBatch, verifiedBatch, blocks, receipts, fullTx, true, ger, nil) + if err != nil { + return RPCErrorResponse(types.DefaultErrorCode, fmt.Sprintf("couldn't build the batch %v response", numericBatchNumber), err, true) + } + return rpcBatch, nil } // GetFullBlockByNumber returns information about a block by block number func (z *ZKEVMEndpoints) GetFullBlockByNumber(number types.BlockNumber, fullTx bool) (interface{}, types.Error) { - return z.txMan.NewDbTxScope(z.state, func(ctx context.Context, dbTx pgx.Tx) (interface{}, types.Error) { - if number == types.PendingBlockNumber { - lastBlock, err := z.state.GetLastL2Block(ctx, dbTx) - if err != nil { - return RPCErrorResponse(types.DefaultErrorCode, "couldn't load last block from state to compute the pending block", err, true) - } - l2Header := state.NewL2Header(ðTypes.Header{ - ParentHash: lastBlock.Hash(), - Number: big.NewInt(0).SetUint64(lastBlock.Number().Uint64() + 1), - TxHash: ethTypes.EmptyRootHash, - UncleHash: ethTypes.EmptyUncleHash, - }) - l2Block := state.NewL2BlockWithHeader(l2Header) - rpcBlock, err := types.NewBlock(ctx, z.state, nil, l2Block, nil, fullTx, false, state.Ptr(true), dbTx) - if err != nil { - return RPCErrorResponse(types.DefaultErrorCode, "couldn't build the pending block response", err, true) - } - - // clean fields that are not available for pending block - rpcBlock.Hash = nil - rpcBlock.Miner = nil - rpcBlock.Nonce = nil - rpcBlock.TotalDifficulty = nil - - return rpcBlock, nil - } - var err error - blockNumber, rpcErr := number.GetNumericBlockNumber(ctx, z.state, z.etherman, dbTx) - if rpcErr != nil { - return nil, rpcErr + ctx := context.Background() + if number == types.PendingBlockNumber { + lastBlock, err := z.state.GetLastL2Block(ctx, nil) + if err != nil { + return RPCErrorResponse(types.DefaultErrorCode, "couldn't load last block from state to compute the pending block", err, true) + } + l2Header := state.NewL2Header(ðTypes.Header{ + ParentHash: lastBlock.Hash(), + Number: big.NewInt(0).SetUint64(lastBlock.Number().Uint64() + 1), + TxHash: ethTypes.EmptyRootHash, + UncleHash: ethTypes.EmptyUncleHash, + }) + l2Block := state.NewL2BlockWithHeader(l2Header) + rpcBlock, err := types.NewBlock(ctx, z.state, nil, l2Block, nil, fullTx, false, state.Ptr(true), nil) + if err != nil { + return RPCErrorResponse(types.DefaultErrorCode, "couldn't build the pending block response", err, true) } - l2Block, err := z.state.GetL2BlockByNumber(ctx, blockNumber, dbTx) - if errors.Is(err, state.ErrNotFound) { - return nil, nil - } else if err != nil { - return RPCErrorResponse(types.DefaultErrorCode, fmt.Sprintf("couldn't load block from state by number %v", blockNumber), err, true) - } + // clean fields that are not available for pending block + rpcBlock.Hash = nil + rpcBlock.Miner = nil + rpcBlock.Nonce = nil + rpcBlock.TotalDifficulty = nil - txs := l2Block.Transactions() - receipts := make([]ethTypes.Receipt, 0, len(txs)) - for _, tx := range txs { - receipt, err := z.state.GetTransactionReceipt(ctx, tx.Hash(), dbTx) - if err != nil { - return RPCErrorResponse(types.DefaultErrorCode, fmt.Sprintf("couldn't load receipt for tx %v", tx.Hash().String()), err, true) - } - receipts = append(receipts, *receipt) - } + return rpcBlock, nil + } + var err error + blockNumber, rpcErr := number.GetNumericBlockNumber(ctx, z.state, z.etherman, nil) + if rpcErr != nil { + return nil, rpcErr + } + + l2Block, err := z.state.GetL2BlockByNumber(ctx, blockNumber, nil) + if errors.Is(err, state.ErrNotFound) { + return nil, nil + } else if err != nil { + return RPCErrorResponse(types.DefaultErrorCode, fmt.Sprintf("couldn't load block from state by number %v", blockNumber), err, true) + } - rpcBlock, err := types.NewBlock(ctx, z.state, state.Ptr(l2Block.Hash()), l2Block, receipts, fullTx, true, state.Ptr(true), dbTx) + txs := l2Block.Transactions() + receipts := make([]ethTypes.Receipt, 0, len(txs)) + for _, tx := range txs { + receipt, err := z.state.GetTransactionReceipt(ctx, tx.Hash(), nil) if err != nil { - return RPCErrorResponse(types.DefaultErrorCode, fmt.Sprintf("couldn't build block response for block by number %v", blockNumber), err, true) + return RPCErrorResponse(types.DefaultErrorCode, fmt.Sprintf("couldn't load receipt for tx %v", tx.Hash().String()), err, true) } + receipts = append(receipts, *receipt) + } - return rpcBlock, nil - }) + rpcBlock, err := types.NewBlock(ctx, z.state, state.Ptr(l2Block.Hash()), l2Block, receipts, fullTx, true, state.Ptr(true), nil) + if err != nil { + return RPCErrorResponse(types.DefaultErrorCode, fmt.Sprintf("couldn't build block response for block by number %v", blockNumber), err, true) + } + + return rpcBlock, nil } // GetFullBlockByHash returns information about a block by hash func (z *ZKEVMEndpoints) GetFullBlockByHash(hash types.ArgHash, fullTx bool) (interface{}, types.Error) { - return z.txMan.NewDbTxScope(z.state, func(ctx context.Context, dbTx pgx.Tx) (interface{}, types.Error) { - l2Block, err := z.state.GetL2BlockByHash(ctx, hash.Hash(), dbTx) - if errors.Is(err, state.ErrNotFound) { - return nil, nil - } else if err != nil { - return RPCErrorResponse(types.DefaultErrorCode, "failed to get block by hash from state", err, true) - } - - txs := l2Block.Transactions() - receipts := make([]ethTypes.Receipt, 0, len(txs)) - for _, tx := range txs { - receipt, err := z.state.GetTransactionReceipt(ctx, tx.Hash(), dbTx) - if err != nil { - return RPCErrorResponse(types.DefaultErrorCode, fmt.Sprintf("couldn't load receipt for tx %v", tx.Hash().String()), err, true) - } - receipts = append(receipts, *receipt) - } + ctx := context.Background() + l2Block, err := z.state.GetL2BlockByHash(ctx, hash.Hash(), nil) + if errors.Is(err, state.ErrNotFound) { + return nil, nil + } else if err != nil { + return RPCErrorResponse(types.DefaultErrorCode, "failed to get block by hash from state", err, true) + } - rpcBlock, err := types.NewBlock(ctx, z.state, state.Ptr(l2Block.Hash()), l2Block, receipts, fullTx, true, state.Ptr(true), dbTx) + txs := l2Block.Transactions() + receipts := make([]ethTypes.Receipt, 0, len(txs)) + for _, tx := range txs { + receipt, err := z.state.GetTransactionReceipt(ctx, tx.Hash(), nil) if err != nil { - return RPCErrorResponse(types.DefaultErrorCode, fmt.Sprintf("couldn't build block response for block by hash %v", hash.Hash()), err, true) + return RPCErrorResponse(types.DefaultErrorCode, fmt.Sprintf("couldn't load receipt for tx %v", tx.Hash().String()), err, true) } + receipts = append(receipts, *receipt) + } - return rpcBlock, nil - }) + rpcBlock, err := types.NewBlock(ctx, z.state, state.Ptr(l2Block.Hash()), l2Block, receipts, fullTx, true, state.Ptr(true), nil) + if err != nil { + return RPCErrorResponse(types.DefaultErrorCode, fmt.Sprintf("couldn't build block response for block by hash %v", hash.Hash()), err, true) + } + + return rpcBlock, nil } // GetNativeBlockHashesInRange return the state root for the blocks in range func (z *ZKEVMEndpoints) GetNativeBlockHashesInRange(filter NativeBlockHashBlockRangeFilter) (interface{}, types.Error) { - return z.txMan.NewDbTxScope(z.state, func(ctx context.Context, dbTx pgx.Tx) (interface{}, types.Error) { - fromBlockNumber, toBlockNumber, rpcErr := filter.GetNumericBlockNumbers(ctx, z.cfg, z.state, z.etherman, dbTx) - if rpcErr != nil { - return nil, rpcErr - } + ctx := context.Background() + fromBlockNumber, toBlockNumber, rpcErr := filter.GetNumericBlockNumbers(ctx, z.cfg, z.state, z.etherman, nil) + if rpcErr != nil { + return nil, rpcErr + } - nativeBlockHashes, err := z.state.GetNativeBlockHashesInRange(ctx, fromBlockNumber, toBlockNumber, dbTx) - if errors.Is(err, state.ErrNotFound) { - return nil, nil - } else if errors.Is(err, state.ErrMaxNativeBlockHashBlockRangeLimitExceeded) { - errMsg := fmt.Sprintf(state.ErrMaxNativeBlockHashBlockRangeLimitExceeded.Error(), z.cfg.MaxNativeBlockHashBlockRange) - return RPCErrorResponse(types.InvalidParamsErrorCode, errMsg, nil, false) - } else if err != nil { - return RPCErrorResponse(types.DefaultErrorCode, "failed to get block by hash from state", err, true) - } + nativeBlockHashes, err := z.state.GetNativeBlockHashesInRange(ctx, fromBlockNumber, toBlockNumber, nil) + if errors.Is(err, state.ErrNotFound) { + return nil, nil + } else if errors.Is(err, state.ErrMaxNativeBlockHashBlockRangeLimitExceeded) { + errMsg := fmt.Sprintf(state.ErrMaxNativeBlockHashBlockRangeLimitExceeded.Error(), z.cfg.MaxNativeBlockHashBlockRange) + return RPCErrorResponse(types.InvalidParamsErrorCode, errMsg, nil, false) + } else if err != nil { + return RPCErrorResponse(types.DefaultErrorCode, "failed to get block by hash from state", err, true) + } - return nativeBlockHashes, nil - }) + return nativeBlockHashes, nil } // GetTransactionByL2Hash returns a transaction by his l2 hash func (z *ZKEVMEndpoints) GetTransactionByL2Hash(hash types.ArgHash) (interface{}, types.Error) { - return z.txMan.NewDbTxScope(z.state, func(ctx context.Context, dbTx pgx.Tx) (interface{}, types.Error) { - // try to get tx from state - tx, err := z.state.GetTransactionByL2Hash(ctx, hash.Hash(), dbTx) - if err != nil && !errors.Is(err, state.ErrNotFound) { - return RPCErrorResponse(types.DefaultErrorCode, "failed to load transaction by l2 hash from state", err, true) + ctx := context.Background() + // try to get tx from state + tx, err := z.state.GetTransactionByL2Hash(ctx, hash.Hash(), nil) + if err != nil && !errors.Is(err, state.ErrNotFound) { + return RPCErrorResponse(types.DefaultErrorCode, "failed to load transaction by l2 hash from state", err, true) + } + if tx != nil { + receipt, err := z.state.GetTransactionReceipt(ctx, hash.Hash(), nil) + if errors.Is(err, state.ErrNotFound) { + return RPCErrorResponse(types.DefaultErrorCode, "transaction receipt not found", err, false) + } else if err != nil { + return RPCErrorResponse(types.DefaultErrorCode, "failed to load transaction receipt from state", err, true) } - if tx != nil { - receipt, err := z.state.GetTransactionReceipt(ctx, hash.Hash(), dbTx) - if errors.Is(err, state.ErrNotFound) { - return RPCErrorResponse(types.DefaultErrorCode, "transaction receipt not found", err, false) - } else if err != nil { - return RPCErrorResponse(types.DefaultErrorCode, "failed to load transaction receipt from state", err, true) - } - l2Hash, err := z.state.GetL2TxHashByTxHash(ctx, tx.Hash(), dbTx) - if err != nil { - return RPCErrorResponse(types.DefaultErrorCode, "failed to get l2 transaction hash", err, true) - } - - res, err := types.NewTransaction(*tx, receipt, false, l2Hash) - if err != nil { - return RPCErrorResponse(types.DefaultErrorCode, "failed to build transaction response", err, true) - } - - return res, nil + l2Hash, err := z.state.GetL2TxHashByTxHash(ctx, tx.Hash(), nil) + if err != nil { + return RPCErrorResponse(types.DefaultErrorCode, "failed to get l2 transaction hash", err, true) } - // if the tx does not exist in the state, look for it in the pool - if z.cfg.SequencerNodeURI != "" { - return z.getTransactionByL2HashFromSequencerNode(hash.Hash()) - } - poolTx, err := z.pool.GetTransactionByL2Hash(ctx, hash.Hash()) - if errors.Is(err, pool.ErrNotFound) { - return nil, nil - } else if err != nil { - return RPCErrorResponse(types.DefaultErrorCode, "failed to load transaction by l2 hash from pool", err, true) - } - if poolTx.Status == pool.TxStatusPending { - tx = &poolTx.Transaction - res, err := types.NewTransaction(*tx, nil, false, nil) - if err != nil { - return RPCErrorResponse(types.DefaultErrorCode, "failed to build transaction response", err, true) - } - return res, nil + res, err := types.NewTransaction(*tx, receipt, false, l2Hash) + if err != nil { + return RPCErrorResponse(types.DefaultErrorCode, "failed to build transaction response", err, true) } + + return res, nil + } + + // if the tx does not exist in the state, look for it in the pool + if z.cfg.SequencerNodeURI != "" { + return z.getTransactionByL2HashFromSequencerNode(hash.Hash()) + } + poolTx, err := z.pool.GetTransactionByL2Hash(ctx, hash.Hash()) + if errors.Is(err, pool.ErrNotFound) { return nil, nil - }) + } else if err != nil { + return RPCErrorResponse(types.DefaultErrorCode, "failed to load transaction by l2 hash from pool", err, true) + } + if poolTx.Status == pool.TxStatusPending { + tx = &poolTx.Transaction + res, err := types.NewTransaction(*tx, nil, false, nil) + if err != nil { + return RPCErrorResponse(types.DefaultErrorCode, "failed to build transaction response", err, true) + } + return res, nil + } + return nil, nil } // GetTransactionReceiptByL2Hash returns a transaction receipt by his hash func (z *ZKEVMEndpoints) GetTransactionReceiptByL2Hash(hash types.ArgHash) (interface{}, types.Error) { - return z.txMan.NewDbTxScope(z.state, func(ctx context.Context, dbTx pgx.Tx) (interface{}, types.Error) { - tx, err := z.state.GetTransactionByL2Hash(ctx, hash.Hash(), dbTx) - if errors.Is(err, state.ErrNotFound) { - return nil, nil - } else if err != nil { - return RPCErrorResponse(types.DefaultErrorCode, "failed to get tx from state", err, true) - } + ctx := context.Background() + tx, err := z.state.GetTransactionByL2Hash(ctx, hash.Hash(), nil) + if errors.Is(err, state.ErrNotFound) { + return nil, nil + } else if err != nil { + return RPCErrorResponse(types.DefaultErrorCode, "failed to get tx from state", err, true) + } - r, err := z.state.GetTransactionReceipt(ctx, hash.Hash(), dbTx) - if errors.Is(err, state.ErrNotFound) { - return nil, nil - } else if err != nil { - return RPCErrorResponse(types.DefaultErrorCode, "failed to get tx receipt from state", err, true) - } + r, err := z.state.GetTransactionReceipt(ctx, hash.Hash(), nil) + if errors.Is(err, state.ErrNotFound) { + return nil, nil + } else if err != nil { + return RPCErrorResponse(types.DefaultErrorCode, "failed to get tx receipt from state", err, true) + } - l2Hash, err := z.state.GetL2TxHashByTxHash(ctx, tx.Hash(), dbTx) - if err != nil { - return RPCErrorResponse(types.DefaultErrorCode, "failed to get l2 transaction hash", err, true) - } + l2Hash, err := z.state.GetL2TxHashByTxHash(ctx, tx.Hash(), nil) + if err != nil { + return RPCErrorResponse(types.DefaultErrorCode, "failed to get l2 transaction hash", err, true) + } - receipt, err := types.NewReceipt(*tx, r, l2Hash) - if err != nil { - return RPCErrorResponse(types.DefaultErrorCode, "failed to build the receipt response", err, true) - } + receipt, err := types.NewReceipt(*tx, r, l2Hash) + if err != nil { + return RPCErrorResponse(types.DefaultErrorCode, "failed to build the receipt response", err, true) + } - return receipt, nil - }) + return receipt, nil } func (z *ZKEVMEndpoints) getTransactionByL2HashFromSequencerNode(hash common.Hash) (interface{}, types.Error) { @@ -417,43 +403,40 @@ func (z *ZKEVMEndpoints) getTransactionByL2HashFromSequencerNode(hash common.Has // GetExitRootsByGER returns the exit roots accordingly to the provided Global Exit Root func (z *ZKEVMEndpoints) GetExitRootsByGER(globalExitRoot common.Hash) (interface{}, types.Error) { - return z.txMan.NewDbTxScope(z.state, func(ctx context.Context, dbTx pgx.Tx) (interface{}, types.Error) { - exitRoots, err := z.state.GetExitRootByGlobalExitRoot(ctx, globalExitRoot, dbTx) - if errors.Is(err, state.ErrNotFound) { - return nil, nil - } else if err != nil { - return RPCErrorResponse(types.DefaultErrorCode, "failed to get exit roots by global exit root from state", err, true) - } + ctx := context.Background() + exitRoots, err := z.state.GetExitRootByGlobalExitRoot(ctx, globalExitRoot, nil) + if errors.Is(err, state.ErrNotFound) { + return nil, nil + } else if err != nil { + return RPCErrorResponse(types.DefaultErrorCode, "failed to get exit roots by global exit root from state", err, true) + } - return types.ExitRoots{ - BlockNumber: types.ArgUint64(exitRoots.BlockNumber), - Timestamp: types.ArgUint64(exitRoots.Timestamp.Unix()), - MainnetExitRoot: exitRoots.MainnetExitRoot, - RollupExitRoot: exitRoots.RollupExitRoot, - }, nil - }) + return types.ExitRoots{ + BlockNumber: types.ArgUint64(exitRoots.BlockNumber), + Timestamp: types.ArgUint64(exitRoots.Timestamp.Unix()), + MainnetExitRoot: exitRoots.MainnetExitRoot, + RollupExitRoot: exitRoots.RollupExitRoot, + }, nil } // EstimateGasPrice returns an estimate gas price for the transaction. func (z *ZKEVMEndpoints) EstimateGasPrice(arg *types.TxArgs, blockArg *types.BlockNumberOrHash) (interface{}, types.Error) { - return z.txMan.NewDbTxScope(z.state, func(ctx context.Context, dbTx pgx.Tx) (interface{}, types.Error) { - gasPrice, _, err := z.internalEstimateGasPriceAndFee(ctx, arg, blockArg, dbTx) - if err != nil { - return nil, err - } - return hex.EncodeBig(gasPrice), nil - }) + ctx := context.Background() + gasPrice, _, err := z.internalEstimateGasPriceAndFee(ctx, arg, blockArg, nil) + if err != nil { + return nil, err + } + return hex.EncodeBig(gasPrice), nil } // EstimateFee returns an estimate fee for the transaction. func (z *ZKEVMEndpoints) EstimateFee(arg *types.TxArgs, blockArg *types.BlockNumberOrHash) (interface{}, types.Error) { - return z.txMan.NewDbTxScope(z.state, func(ctx context.Context, dbTx pgx.Tx) (interface{}, types.Error) { - _, fee, err := z.internalEstimateGasPriceAndFee(ctx, arg, blockArg, dbTx) - if err != nil { - return nil, err - } - return hex.EncodeBig(fee), nil - }) + ctx := context.Background() + _, fee, err := z.internalEstimateGasPriceAndFee(ctx, arg, blockArg, nil) + if err != nil { + return nil, err + } + return hex.EncodeBig(fee), nil } // internalEstimateGasPriceAndFee computes the estimated gas price and the estimated fee for the transaction @@ -462,7 +445,7 @@ func (z *ZKEVMEndpoints) internalEstimateGasPriceAndFee(ctx context.Context, arg return nil, nil, types.NewRPCError(types.InvalidParamsErrorCode, "missing value for required argument 0") } - block, respErr := z.getBlockByArg(ctx, blockArg, dbTx) + block, respErr := z.getBlockByArg(ctx, blockArg, nil) if respErr != nil { return nil, nil, respErr } @@ -479,12 +462,12 @@ func (z *ZKEVMEndpoints) internalEstimateGasPriceAndFee(ctx context.Context, arg } defaultSenderAddress := common.HexToAddress(state.DefaultSenderAddress) - sender, tx, err := arg.ToTransaction(ctx, z.state, z.cfg.MaxCumulativeGasUsed, block.Root(), defaultSenderAddress, dbTx) + sender, tx, err := arg.ToTransaction(ctx, z.state, z.cfg.MaxCumulativeGasUsed, block.Root(), defaultSenderAddress, nil) if err != nil { return nil, nil, types.NewRPCError(types.DefaultErrorCode, "failed to convert arguments into an unsigned transaction") } - gasEstimation, returnValue, err := z.state.EstimateGas(tx, sender, blockToProcess, dbTx) + gasEstimation, returnValue, err := z.state.EstimateGas(tx, sender, blockToProcess, nil) if errors.Is(err, runtime.ErrExecutionReverted) { data := make([]byte, len(returnValue)) copy(data, returnValue) @@ -540,78 +523,77 @@ func (z *ZKEVMEndpoints) internalEstimateGasPriceAndFee(ctx context.Context, arg // EstimateCounters returns an estimation of the counters that are going to be used while executing // this transaction. func (z *ZKEVMEndpoints) EstimateCounters(arg *types.TxArgs, blockArg *types.BlockNumberOrHash) (interface{}, types.Error) { - return z.txMan.NewDbTxScope(z.state, func(ctx context.Context, dbTx pgx.Tx) (interface{}, types.Error) { - if arg == nil { - return RPCErrorResponse(types.InvalidParamsErrorCode, "missing value for required argument 0", nil, false) - } + ctx := context.Background() + if arg == nil { + return RPCErrorResponse(types.InvalidParamsErrorCode, "missing value for required argument 0", nil, false) + } - block, respErr := z.getBlockByArg(ctx, blockArg, dbTx) - if respErr != nil { - return nil, respErr - } + block, respErr := z.getBlockByArg(ctx, blockArg, nil) + if respErr != nil { + return nil, respErr + } - var blockToProcess *uint64 - if blockArg != nil { - blockNumArg := blockArg.Number() - if blockNumArg != nil && (*blockArg.Number() == types.LatestBlockNumber || *blockArg.Number() == types.PendingBlockNumber) { - blockToProcess = nil - } else { - n := block.NumberU64() - blockToProcess = &n - } + var blockToProcess *uint64 + if blockArg != nil { + blockNumArg := blockArg.Number() + if blockNumArg != nil && (*blockArg.Number() == types.LatestBlockNumber || *blockArg.Number() == types.PendingBlockNumber) { + blockToProcess = nil + } else { + n := block.NumberU64() + blockToProcess = &n } + } - defaultSenderAddress := common.HexToAddress(state.DefaultSenderAddress) - sender, tx, err := arg.ToTransaction(ctx, z.state, z.cfg.MaxCumulativeGasUsed, block.Root(), defaultSenderAddress, dbTx) - if err != nil { - return RPCErrorResponse(types.DefaultErrorCode, "failed to convert arguments into an unsigned transaction", err, false) - } + defaultSenderAddress := common.HexToAddress(state.DefaultSenderAddress) + sender, tx, err := arg.ToTransaction(ctx, z.state, z.cfg.MaxCumulativeGasUsed, block.Root(), defaultSenderAddress, nil) + if err != nil { + return RPCErrorResponse(types.DefaultErrorCode, "failed to convert arguments into an unsigned transaction", err, false) + } - var oocErr error - processBatchResponse, err := z.state.PreProcessUnsignedTransaction(ctx, tx, sender, blockToProcess, dbTx) - if err != nil { - if executor.IsROMOutOfCountersError(executor.RomErrorCode(err)) { - oocErr = err - } else { - errMsg := fmt.Sprintf("failed to estimate counters: %v", err.Error()) - return nil, types.NewRPCError(types.DefaultErrorCode, errMsg) - } + var oocErr error + processBatchResponse, err := z.state.PreProcessUnsignedTransaction(ctx, tx, sender, blockToProcess, nil) + if err != nil { + if executor.IsROMOutOfCountersError(executor.RomErrorCode(err)) { + oocErr = err + } else { + errMsg := fmt.Sprintf("failed to estimate counters: %v", err.Error()) + return nil, types.NewRPCError(types.DefaultErrorCode, errMsg) } + } - var revert *types.RevertInfo - if len(processBatchResponse.BlockResponses) > 0 && len(processBatchResponse.BlockResponses[0].TransactionResponses) > 0 { - txResponse := processBatchResponse.BlockResponses[0].TransactionResponses[0] - err = txResponse.RomError - if errors.Is(err, runtime.ErrExecutionReverted) { - returnValue := make([]byte, len(txResponse.ReturnValue)) - copy(returnValue, txResponse.ReturnValue) - err := state.ConstructErrorFromRevert(err, returnValue) - revert = &types.RevertInfo{ - Message: err.Error(), - Data: state.Ptr(types.ArgBytes(returnValue)), - } + var revert *types.RevertInfo + if len(processBatchResponse.BlockResponses) > 0 && len(processBatchResponse.BlockResponses[0].TransactionResponses) > 0 { + txResponse := processBatchResponse.BlockResponses[0].TransactionResponses[0] + err = txResponse.RomError + if errors.Is(err, runtime.ErrExecutionReverted) { + returnValue := make([]byte, len(txResponse.ReturnValue)) + copy(returnValue, txResponse.ReturnValue) + err := state.ConstructErrorFromRevert(err, returnValue) + revert = &types.RevertInfo{ + Message: err.Error(), + Data: state.Ptr(types.ArgBytes(returnValue)), } } + } - limits := types.ZKCountersLimits{ - MaxGasUsed: types.ArgUint64(state.MaxTxGasLimit), - MaxKeccakHashes: types.ArgUint64(z.cfg.ZKCountersLimits.MaxKeccakHashes), - MaxPoseidonHashes: types.ArgUint64(z.cfg.ZKCountersLimits.MaxPoseidonHashes), - MaxPoseidonPaddings: types.ArgUint64(z.cfg.ZKCountersLimits.MaxPoseidonPaddings), - MaxMemAligns: types.ArgUint64(z.cfg.ZKCountersLimits.MaxMemAligns), - MaxArithmetics: types.ArgUint64(z.cfg.ZKCountersLimits.MaxArithmetics), - MaxBinaries: types.ArgUint64(z.cfg.ZKCountersLimits.MaxBinaries), - MaxSteps: types.ArgUint64(z.cfg.ZKCountersLimits.MaxSteps), - MaxSHA256Hashes: types.ArgUint64(z.cfg.ZKCountersLimits.MaxSHA256Hashes), - } - return types.NewZKCountersResponse(processBatchResponse.UsedZkCounters, limits, revert, oocErr), nil - }) + limits := types.ZKCountersLimits{ + MaxGasUsed: types.ArgUint64(state.MaxTxGasLimit), + MaxKeccakHashes: types.ArgUint64(z.cfg.ZKCountersLimits.MaxKeccakHashes), + MaxPoseidonHashes: types.ArgUint64(z.cfg.ZKCountersLimits.MaxPoseidonHashes), + MaxPoseidonPaddings: types.ArgUint64(z.cfg.ZKCountersLimits.MaxPoseidonPaddings), + MaxMemAligns: types.ArgUint64(z.cfg.ZKCountersLimits.MaxMemAligns), + MaxArithmetics: types.ArgUint64(z.cfg.ZKCountersLimits.MaxArithmetics), + MaxBinaries: types.ArgUint64(z.cfg.ZKCountersLimits.MaxBinaries), + MaxSteps: types.ArgUint64(z.cfg.ZKCountersLimits.MaxSteps), + MaxSHA256Hashes: types.ArgUint64(z.cfg.ZKCountersLimits.MaxSHA256Hashes), + } + return types.NewZKCountersResponse(processBatchResponse.UsedZkCounters, limits, revert, oocErr), nil } func (z *ZKEVMEndpoints) getBlockByArg(ctx context.Context, blockArg *types.BlockNumberOrHash, dbTx pgx.Tx) (*state.L2Block, types.Error) { // If no block argument is provided, return the latest block if blockArg == nil { - block, err := z.state.GetLastL2Block(ctx, dbTx) + block, err := z.state.GetLastL2Block(ctx, nil) if err != nil { return nil, types.NewRPCError(types.DefaultErrorCode, "failed to get the last block number from state") } @@ -620,7 +602,7 @@ func (z *ZKEVMEndpoints) getBlockByArg(ctx context.Context, blockArg *types.Bloc // If we have a block hash, try to get the block by hash if blockArg.IsHash() { - block, err := z.state.GetL2BlockByHash(ctx, blockArg.Hash().Hash(), dbTx) + block, err := z.state.GetL2BlockByHash(ctx, blockArg.Hash().Hash(), nil) if errors.Is(err, state.ErrNotFound) { return nil, types.NewRPCError(types.DefaultErrorCode, "header for hash not found") } else if err != nil { @@ -630,11 +612,11 @@ func (z *ZKEVMEndpoints) getBlockByArg(ctx context.Context, blockArg *types.Bloc } // Otherwise, try to get the block by number - blockNum, rpcErr := blockArg.Number().GetNumericBlockNumber(ctx, z.state, z.etherman, dbTx) + blockNum, rpcErr := blockArg.Number().GetNumericBlockNumber(ctx, z.state, z.etherman, nil) if rpcErr != nil { return nil, rpcErr } - block, err := z.state.GetL2BlockByNumber(context.Background(), blockNum, dbTx) + block, err := z.state.GetL2BlockByNumber(context.Background(), blockNum, nil) if errors.Is(err, state.ErrNotFound) || block == nil { return nil, types.NewRPCError(types.DefaultErrorCode, "header not found") } else if err != nil { @@ -646,14 +628,13 @@ func (z *ZKEVMEndpoints) getBlockByArg(ctx context.Context, blockArg *types.Bloc // GetLatestGlobalExitRoot returns the last global exit root used by l2 func (z *ZKEVMEndpoints) GetLatestGlobalExitRoot() (interface{}, types.Error) { - return z.txMan.NewDbTxScope(z.state, func(ctx context.Context, dbTx pgx.Tx) (interface{}, types.Error) { - var err error + ctx := context.Background() + var err error - ger, err := z.state.GetLatestBatchGlobalExitRoot(ctx, dbTx) - if err != nil { - return RPCErrorResponse(types.DefaultErrorCode, "couldn't load the last global exit root", err, true) - } + ger, err := z.state.GetLatestBatchGlobalExitRoot(ctx, nil) + if err != nil { + return RPCErrorResponse(types.DefaultErrorCode, "couldn't load the last global exit root", err, true) + } - return ger.String(), nil - }) + return ger.String(), nil } diff --git a/jsonrpc/endpoints_zkevm_test.go b/jsonrpc/endpoints_zkevm_test.go index 8c0090d1e2..ff2761158f 100644 --- a/jsonrpc/endpoints_zkevm_test.go +++ b/jsonrpc/endpoints_zkevm_test.go @@ -46,18 +46,8 @@ func TestConsolidatedBlockNumber(t *testing.T) { Name: "Get consolidated block number successfully", ExpectedResult: state.Ptr(uint64(10)), SetupMocks: func(m *mocksWrapper) { - m.DbTx. - On("Commit", context.Background()). - Return(nil). - Once() - - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - m.State. - On("GetLastConsolidatedL2BlockNumber", context.Background(), m.DbTx). + On("GetLastConsolidatedL2BlockNumber", context.Background(), nil). Return(uint64(10), nil). Once() }, @@ -67,18 +57,8 @@ func TestConsolidatedBlockNumber(t *testing.T) { ExpectedResult: nil, ExpectedError: types.NewRPCError(types.DefaultErrorCode, "failed to get last consolidated block number from state"), SetupMocks: func(m *mocksWrapper) { - m.DbTx. - On("Rollback", context.Background()). - Return(nil). - Once() - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - - m.State. - On("GetLastConsolidatedL2BlockNumber", context.Background(), m.DbTx). + On("GetLastConsolidatedL2BlockNumber", context.Background(), nil). Return(uint64(0), errors.New("failed to get last consolidated block number")). Once() }, @@ -124,18 +104,8 @@ func TestIsBlockConsolidated(t *testing.T) { Name: "Query status of block number successfully", ExpectedResult: true, SetupMocks: func(m *mocksWrapper) { - m.DbTx. - On("Commit", context.Background()). - Return(nil). - Once() - - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - m.State. - On("IsL2BlockConsolidated", context.Background(), uint64(1), m.DbTx). + On("IsL2BlockConsolidated", context.Background(), uint64(1), nil). Return(true, nil). Once() }, @@ -145,18 +115,8 @@ func TestIsBlockConsolidated(t *testing.T) { ExpectedResult: false, ExpectedError: types.NewRPCError(types.DefaultErrorCode, "failed to check if the block is consolidated"), SetupMocks: func(m *mocksWrapper) { - m.DbTx. - On("Rollback", context.Background()). - Return(nil). - Once() - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - - m.State. - On("IsL2BlockConsolidated", context.Background(), uint64(1), m.DbTx). + On("IsL2BlockConsolidated", context.Background(), uint64(1), nil). Return(false, errors.New("failed to check if the block is consolidated")). Once() }, @@ -202,18 +162,8 @@ func TestIsBlockVirtualized(t *testing.T) { Name: "Query status of block number successfully", ExpectedResult: true, SetupMocks: func(m *mocksWrapper) { - m.DbTx. - On("Commit", context.Background()). - Return(nil). - Once() - - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - m.State. - On("IsL2BlockVirtualized", context.Background(), uint64(1), m.DbTx). + On("IsL2BlockVirtualized", context.Background(), uint64(1), nil). Return(true, nil). Once() }, @@ -223,18 +173,8 @@ func TestIsBlockVirtualized(t *testing.T) { ExpectedResult: false, ExpectedError: types.NewRPCError(types.DefaultErrorCode, "failed to check if the block is virtualized"), SetupMocks: func(m *mocksWrapper) { - m.DbTx. - On("Rollback", context.Background()). - Return(nil). - Once() - - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - m.State. - On("IsL2BlockVirtualized", context.Background(), uint64(1), m.DbTx). + On("IsL2BlockVirtualized", context.Background(), uint64(1), nil). Return(false, errors.New("failed to check if the block is virtualized")). Once() }, @@ -282,18 +222,8 @@ func TestBatchNumberByBlockNumber(t *testing.T) { Name: "get batch number by block number successfully", ExpectedResult: &batchNumber, SetupMocks: func(m *mocksWrapper) { - m.DbTx. - On("Commit", context.Background()). - Return(nil). - Once() - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - - m.State. - On("BatchNumberByL2BlockNumber", context.Background(), blockNumber, m.DbTx). + On("BatchNumberByL2BlockNumber", context.Background(), blockNumber, nil). Return(batchNumber, nil). Once() }, @@ -303,18 +233,8 @@ func TestBatchNumberByBlockNumber(t *testing.T) { ExpectedResult: nil, ExpectedError: types.NewRPCError(types.DefaultErrorCode, "failed to get batch number from block number"), SetupMocks: func(m *mocksWrapper) { - m.DbTx. - On("Rollback", context.Background()). - Return(nil). - Once() - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - - m.State. - On("BatchNumberByL2BlockNumber", context.Background(), blockNumber, m.DbTx). + On("BatchNumberByL2BlockNumber", context.Background(), blockNumber, nil). Return(uint64(0), errors.New("failed to get batch number of l2 batchNum")). Once() }, @@ -324,18 +244,8 @@ func TestBatchNumberByBlockNumber(t *testing.T) { ExpectedResult: nil, ExpectedError: nil, SetupMocks: func(m *mocksWrapper) { - m.DbTx. - On("Commit", context.Background()). - Return(nil). - Once() - - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - m.State. - On("BatchNumberByL2BlockNumber", context.Background(), blockNumber, m.DbTx). + On("BatchNumberByL2BlockNumber", context.Background(), blockNumber, nil). Return(uint64(0), state.ErrNotFound). Once() }, @@ -393,18 +303,8 @@ func TestBatchNumber(t *testing.T) { ExpectedError: nil, ExpectedResult: 10, SetupMocks: func(m *mocksWrapper) { - m.DbTx. - On("Commit", context.Background()). - Return(nil). - Once() - - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - m.State. - On("GetLastBatchNumber", context.Background(), m.DbTx). + On("GetLastBatchNumber", context.Background(), nil). Return(uint64(10), nil). Once() }, @@ -414,18 +314,8 @@ func TestBatchNumber(t *testing.T) { ExpectedError: types.NewRPCError(types.DefaultErrorCode, "failed to get the last batch number from state"), ExpectedResult: 0, SetupMocks: func(m *mocksWrapper) { - m.DbTx. - On("Rollback", context.Background()). - Return(nil). - Once() - - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - m.State. - On("GetLastBatchNumber", context.Background(), m.DbTx). + On("GetLastBatchNumber", context.Background(), nil). Return(uint64(0), errors.New("failed to get last batch number")). Once() }, @@ -472,18 +362,8 @@ func TestVirtualBatchNumber(t *testing.T) { ExpectedError: nil, ExpectedResult: 10, SetupMocks: func(m *mocksWrapper) { - m.DbTx. - On("Commit", context.Background()). - Return(nil). - Once() - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - - m.State. - On("GetLastVirtualBatchNum", context.Background(), m.DbTx). + On("GetLastVirtualBatchNum", context.Background(), nil). Return(uint64(10), nil). Once() }, @@ -493,18 +373,8 @@ func TestVirtualBatchNumber(t *testing.T) { ExpectedError: types.NewRPCError(types.DefaultErrorCode, "failed to get the last virtual batch number from state"), ExpectedResult: 0, SetupMocks: func(m *mocksWrapper) { - m.DbTx. - On("Rollback", context.Background()). - Return(nil). - Once() - - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - m.State. - On("GetLastVirtualBatchNum", context.Background(), m.DbTx). + On("GetLastVirtualBatchNum", context.Background(), nil). Return(uint64(0), errors.New("failed to get last batch number")). Once() }, @@ -551,18 +421,8 @@ func TestVerifiedBatchNumber(t *testing.T) { ExpectedError: nil, ExpectedResult: 10, SetupMocks: func(m *mocksWrapper) { - m.DbTx. - On("Commit", context.Background()). - Return(nil). - Once() - - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - m.State. - On("GetLastVerifiedBatch", context.Background(), m.DbTx). + On("GetLastVerifiedBatch", context.Background(), nil). Return(&state.VerifiedBatch{BatchNumber: uint64(10)}, nil). Once() }, @@ -572,18 +432,8 @@ func TestVerifiedBatchNumber(t *testing.T) { ExpectedError: types.NewRPCError(types.DefaultErrorCode, "failed to get the last verified batch number from state"), ExpectedResult: 0, SetupMocks: func(m *mocksWrapper) { - m.DbTx. - On("Rollback", context.Background()). - Return(nil). - Once() - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - - m.State. - On("GetLastVerifiedBatch", context.Background(), m.DbTx). + On("GetLastVerifiedBatch", context.Background(), nil). Return(nil, errors.New("failed to get last batch number")). Once() }, @@ -630,18 +480,8 @@ func TestGetBatchByNumber(t *testing.T) { ExpectedResult: nil, ExpectedError: nil, SetupMocks: func(s *mockedServer, m *mocksWrapper, tc *testCase) { - m.DbTx. - On("Commit", context.Background()). - Return(nil). - Once() - - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - m.State. - On("GetBatchByNumber", context.Background(), hex.DecodeBig(tc.Number).Uint64(), m.DbTx). + On("GetBatchByNumber", context.Background(), hex.DecodeBig(tc.Number).Uint64(), nil). Return(nil, state.ErrNotFound) }, }, @@ -661,16 +501,6 @@ func TestGetBatchByNumber(t *testing.T) { }, ExpectedError: nil, SetupMocks: func(s *mockedServer, m *mocksWrapper, tc *testCase) { - m.DbTx. - On("Commit", context.Background()). - Return(nil). - Once() - - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - txs := []*ethTypes.Transaction{ signTx(ethTypes.NewTransaction(1001, common.HexToAddress("0x1000"), big.NewInt(1000), 1001, big.NewInt(1002), []byte("1003")), s.ChainID()), signTx(ethTypes.NewTransaction(1002, common.HexToAddress("0x1000"), big.NewInt(1000), 1001, big.NewInt(1002), []byte("1003")), s.ChainID()), @@ -733,12 +563,12 @@ func TestGetBatchByNumber(t *testing.T) { } m.State. - On("GetBatchByNumber", context.Background(), hex.DecodeBig(tc.Number).Uint64(), m.DbTx). + On("GetBatchByNumber", context.Background(), hex.DecodeBig(tc.Number).Uint64(), nil). Return(batch, nil). Once() m.State. - On("GetBatchTimestamp", mock.Anything, mock.Anything, (*uint64)(nil), m.DbTx). + On("GetBatchTimestamp", mock.Anything, mock.Anything, (*uint64)(nil), nil). Return(&batch.Timestamp, nil). Once() @@ -747,7 +577,7 @@ func TestGetBatchByNumber(t *testing.T) { } m.State. - On("GetVirtualBatch", context.Background(), hex.DecodeBig(tc.Number).Uint64(), m.DbTx). + On("GetVirtualBatch", context.Background(), hex.DecodeBig(tc.Number).Uint64(), nil). Return(virtualBatch, nil). Once() @@ -756,7 +586,7 @@ func TestGetBatchByNumber(t *testing.T) { } m.State. - On("GetVerifiedBatch", context.Background(), hex.DecodeBig(tc.Number).Uint64(), m.DbTx). + On("GetVerifiedBatch", context.Background(), hex.DecodeBig(tc.Number).Uint64(), nil). Return(verifiedBatch, nil). Once() @@ -766,27 +596,27 @@ func TestGetBatchByNumber(t *testing.T) { GlobalExitRoot: common.HexToHash("0x4"), } m.State. - On("GetExitRootByGlobalExitRoot", context.Background(), batch.GlobalExitRoot, m.DbTx). + On("GetExitRootByGlobalExitRoot", context.Background(), batch.GlobalExitRoot, nil). Return(&ger, nil). Once() for i, tx := range txs { m.State. - On("GetTransactionReceipt", context.Background(), tx.Hash(), m.DbTx). + On("GetTransactionReceipt", context.Background(), tx.Hash(), nil). Return(receipts[i], nil). Once() m.State. - On("GetL2TxHashByTxHash", context.Background(), tx.Hash(), m.DbTx). + On("GetL2TxHashByTxHash", context.Background(), tx.Hash(), nil). Return(state.Ptr(tx.Hash()), nil). Once() } m.State. - On("GetTransactionsByBatchNumber", context.Background(), hex.DecodeBig(tc.Number).Uint64(), m.DbTx). + On("GetTransactionsByBatchNumber", context.Background(), hex.DecodeBig(tc.Number).Uint64(), nil). Return(batchTxs, effectivePercentages, nil). Once() m.State. - On("GetL2BlocksByBatchNumber", context.Background(), hex.DecodeBig(tc.Number).Uint64(), m.DbTx). + On("GetL2BlocksByBatchNumber", context.Background(), hex.DecodeBig(tc.Number).Uint64(), nil). Return(blocks, nil). Once() }, @@ -807,16 +637,6 @@ func TestGetBatchByNumber(t *testing.T) { }, ExpectedError: nil, SetupMocks: func(s *mockedServer, m *mocksWrapper, tc *testCase) { - m.DbTx. - On("Commit", context.Background()). - Return(nil). - Once() - - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - txs := []*ethTypes.Transaction{ signTx(ethTypes.NewTransaction(1001, common.HexToAddress("0x1000"), big.NewInt(1000), 1001, big.NewInt(1002), []byte("1003")), s.ChainID()), signTx(ethTypes.NewTransaction(1002, common.HexToAddress("0x1000"), big.NewInt(1000), 1001, big.NewInt(1002), []byte("1003")), s.ChainID()), @@ -861,12 +681,12 @@ func TestGetBatchByNumber(t *testing.T) { } m.State. - On("GetBatchByNumber", context.Background(), hex.DecodeBig(tc.Number).Uint64(), m.DbTx). + On("GetBatchByNumber", context.Background(), hex.DecodeBig(tc.Number).Uint64(), nil). Return(batch, nil). Once() m.State. - On("GetBatchTimestamp", mock.Anything, mock.Anything, (*uint64)(nil), m.DbTx). + On("GetBatchTimestamp", mock.Anything, mock.Anything, (*uint64)(nil), nil). Return(&batch.Timestamp, nil). Once() @@ -875,7 +695,7 @@ func TestGetBatchByNumber(t *testing.T) { } m.State. - On("GetVirtualBatch", context.Background(), hex.DecodeBig(tc.Number).Uint64(), m.DbTx). + On("GetVirtualBatch", context.Background(), hex.DecodeBig(tc.Number).Uint64(), nil). Return(virtualBatch, nil). Once() @@ -884,7 +704,7 @@ func TestGetBatchByNumber(t *testing.T) { } m.State. - On("GetVerifiedBatch", context.Background(), hex.DecodeBig(tc.Number).Uint64(), m.DbTx). + On("GetVerifiedBatch", context.Background(), hex.DecodeBig(tc.Number).Uint64(), nil). Return(verifiedBatch, nil). Once() @@ -894,22 +714,22 @@ func TestGetBatchByNumber(t *testing.T) { GlobalExitRoot: common.HexToHash("0x4"), } m.State. - On("GetExitRootByGlobalExitRoot", context.Background(), batch.GlobalExitRoot, m.DbTx). + On("GetExitRootByGlobalExitRoot", context.Background(), batch.GlobalExitRoot, nil). Return(&ger, nil). Once() for i, tx := range txs { m.State. - On("GetTransactionReceipt", context.Background(), tx.Hash(), m.DbTx). + On("GetTransactionReceipt", context.Background(), tx.Hash(), nil). Return(receipts[i], nil). Once() } m.State. - On("GetTransactionsByBatchNumber", context.Background(), hex.DecodeBig(tc.Number).Uint64(), m.DbTx). + On("GetTransactionsByBatchNumber", context.Background(), hex.DecodeBig(tc.Number).Uint64(), nil). Return(batchTxs, effectivePercentages, nil). Once() m.State. - On("GetL2BlocksByBatchNumber", context.Background(), hex.DecodeBig(tc.Number).Uint64(), m.DbTx). + On("GetL2BlocksByBatchNumber", context.Background(), hex.DecodeBig(tc.Number).Uint64(), nil). Return(blocks, nil). Once() @@ -933,18 +753,8 @@ func TestGetBatchByNumber(t *testing.T) { }, ExpectedError: nil, SetupMocks: func(s *mockedServer, m *mocksWrapper, tc *testCase) { - m.DbTx. - On("Commit", context.Background()). - Return(nil). - Once() - - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - m.State. - On("GetLastClosedBatchNumber", context.Background(), m.DbTx). + On("GetLastClosedBatchNumber", context.Background(), nil). Return(uint64(tc.ExpectedResult.Number), nil). Once() @@ -1018,12 +828,12 @@ func TestGetBatchByNumber(t *testing.T) { } m.State. - On("GetBatchByNumber", context.Background(), uint64(tc.ExpectedResult.Number), m.DbTx). + On("GetBatchByNumber", context.Background(), uint64(tc.ExpectedResult.Number), nil). Return(batch, nil). Once() m.State. - On("GetBatchTimestamp", mock.Anything, mock.Anything, (*uint64)(nil), m.DbTx). + On("GetBatchTimestamp", mock.Anything, mock.Anything, (*uint64)(nil), nil). Return(&batch.Timestamp, nil). Once() @@ -1032,7 +842,7 @@ func TestGetBatchByNumber(t *testing.T) { } m.State. - On("GetVirtualBatch", context.Background(), uint64(tc.ExpectedResult.Number), m.DbTx). + On("GetVirtualBatch", context.Background(), uint64(tc.ExpectedResult.Number), nil). Return(virtualBatch, nil). Once() @@ -1041,7 +851,7 @@ func TestGetBatchByNumber(t *testing.T) { } m.State. - On("GetVerifiedBatch", context.Background(), uint64(tc.ExpectedResult.Number), m.DbTx). + On("GetVerifiedBatch", context.Background(), uint64(tc.ExpectedResult.Number), nil). Return(verifiedBatch, nil). Once() @@ -1051,28 +861,28 @@ func TestGetBatchByNumber(t *testing.T) { GlobalExitRoot: common.HexToHash("0x4"), } m.State. - On("GetExitRootByGlobalExitRoot", context.Background(), batch.GlobalExitRoot, m.DbTx). + On("GetExitRootByGlobalExitRoot", context.Background(), batch.GlobalExitRoot, nil). Return(&ger, nil). Once() for i, tx := range txs { m.State. - On("GetTransactionReceipt", context.Background(), tx.Hash(), m.DbTx). + On("GetTransactionReceipt", context.Background(), tx.Hash(), nil). Return(receipts[i], nil). Once() m.State. - On("GetL2TxHashByTxHash", context.Background(), tx.Hash(), m.DbTx). + On("GetL2TxHashByTxHash", context.Background(), tx.Hash(), nil). Return(state.Ptr(tx.Hash()), nil). Once() } m.State. - On("GetTransactionsByBatchNumber", context.Background(), uint64(tc.ExpectedResult.Number), m.DbTx). + On("GetTransactionsByBatchNumber", context.Background(), uint64(tc.ExpectedResult.Number), nil). Return(batchTxs, effectivePercentages, nil). Once() m.State. - On("GetL2BlocksByBatchNumber", context.Background(), uint64(tc.ExpectedResult.Number), m.DbTx). + On("GetL2BlocksByBatchNumber", context.Background(), uint64(tc.ExpectedResult.Number), nil). Return(blocks, nil). Once() tc.ExpectedResult.BatchL2Data = batchL2Data @@ -1084,18 +894,8 @@ func TestGetBatchByNumber(t *testing.T) { ExpectedResult: nil, ExpectedError: types.NewRPCError(types.DefaultErrorCode, "failed to get the last batch number from state"), SetupMocks: func(s *mockedServer, m *mocksWrapper, tc *testCase) { - m.DbTx. - On("Rollback", context.Background()). - Return(nil). - Once() - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - - m.State. - On("GetLastClosedBatchNumber", context.Background(), m.DbTx). + On("GetLastClosedBatchNumber", context.Background(), nil). Return(uint64(0), errors.New("failed to get last batch number")). Once() }, @@ -1106,23 +906,13 @@ func TestGetBatchByNumber(t *testing.T) { ExpectedResult: nil, ExpectedError: types.NewRPCError(types.DefaultErrorCode, "couldn't load batch from state by number 1"), SetupMocks: func(s *mockedServer, m *mocksWrapper, tc *testCase) { - m.DbTx. - On("Rollback", context.Background()). - Return(nil). - Once() - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - - m.State. - On("GetLastClosedBatchNumber", context.Background(), m.DbTx). + On("GetLastClosedBatchNumber", context.Background(), nil). Return(uint64(1), nil). Once() m.State. - On("GetBatchByNumber", context.Background(), uint64(1), m.DbTx). + On("GetBatchByNumber", context.Background(), uint64(1), nil). Return(nil, errors.New("failed to load batch by number")). Once() }, @@ -1218,18 +1008,8 @@ func TestGetL2FullBlockByHash(t *testing.T) { ExpectedResult: nil, ExpectedError: nil, SetupMocks: func(m *mocksWrapper, tc *testCase) { - m.DbTx. - On("Commit", context.Background()). - Return(nil). - Once() - - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - m.State. - On("GetL2BlockByHash", context.Background(), tc.Hash, m.DbTx). + On("GetL2BlockByHash", context.Background(), tc.Hash, nil). Return(nil, state.ErrNotFound) }, }, @@ -1239,18 +1019,8 @@ func TestGetL2FullBlockByHash(t *testing.T) { ExpectedResult: nil, ExpectedError: types.NewRPCError(types.DefaultErrorCode, "failed to get block by hash from state"), SetupMocks: func(m *mocksWrapper, tc *testCase) { - m.DbTx. - On("Rollback", context.Background()). - Return(nil). - Once() - - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - m.State. - On("GetL2BlockByHash", context.Background(), tc.Hash, m.DbTx). + On("GetL2BlockByHash", context.Background(), tc.Hash, nil). Return(nil, errors.New("failed to get block from state")). Once() }, @@ -1274,24 +1044,14 @@ func TestGetL2FullBlockByHash(t *testing.T) { st := trie.NewStackTrie(nil) block := state.NewL2Block(state.NewL2Header(tc.ExpectedResult.Header()), tc.ExpectedResult.Transactions(), uncles, []*ethTypes.Receipt{ethTypes.NewReceipt([]byte{}, false, uint64(0))}, st) - m.DbTx. - On("Commit", context.Background()). - Return(nil). - Once() - - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - m.State. - On("GetL2BlockByHash", context.Background(), tc.Hash, m.DbTx). + On("GetL2BlockByHash", context.Background(), tc.Hash, nil). Return(block, nil). Once() for _, tx := range tc.ExpectedResult.Transactions() { m.State. - On("GetTransactionReceipt", context.Background(), tx.Hash(), m.DbTx). + On("GetTransactionReceipt", context.Background(), tx.Hash(), nil). Return(ethTypes.NewReceipt([]byte{}, false, uint64(0)), nil). Once() } @@ -1486,18 +1246,8 @@ func TestGetL2FullBlockByNumber(t *testing.T) { ExpectedResult: nil, ExpectedError: nil, SetupMocks: func(m *mocksWrapper, tc *testCase) { - m.DbTx. - On("Commit", context.Background()). - Return(nil). - Once() - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - - m.State. - On("GetL2BlockByNumber", context.Background(), hex.DecodeUint64(tc.Number), m.DbTx). + On("GetL2BlockByNumber", context.Background(), hex.DecodeUint64(tc.Number), nil). Return(nil, state.ErrNotFound). Once() }, @@ -1508,24 +1258,14 @@ func TestGetL2FullBlockByNumber(t *testing.T) { ExpectedResult: rpcBlock, ExpectedError: nil, SetupMocks: func(m *mocksWrapper, tc *testCase) { - m.DbTx. - On("Commit", context.Background()). - Return(nil). - Once() - - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - m.State. - On("GetL2BlockByNumber", context.Background(), hex.DecodeUint64(tc.Number), m.DbTx). + On("GetL2BlockByNumber", context.Background(), hex.DecodeUint64(tc.Number), nil). Return(l2Block, nil). Once() for _, receipt := range receipts { m.State. - On("GetTransactionReceipt", context.Background(), receipt.TxHash, m.DbTx). + On("GetTransactionReceipt", context.Background(), receipt.TxHash, nil). Return(receipt, nil). Once() } @@ -1537,31 +1277,21 @@ func TestGetL2FullBlockByNumber(t *testing.T) { ExpectedResult: rpcBlock, ExpectedError: nil, SetupMocks: func(m *mocksWrapper, tc *testCase) { - m.DbTx. - On("Commit", context.Background()). - Return(nil). - Once() - - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - blockNumber := uint64(1) m.State. - On("GetLastL2BlockNumber", context.Background(), m.DbTx). + On("GetLastL2BlockNumber", context.Background(), nil). Return(blockNumber, nil). Once() m.State. - On("GetL2BlockByNumber", context.Background(), blockNumber, m.DbTx). + On("GetL2BlockByNumber", context.Background(), blockNumber, nil). Return(l2Block, nil). Once() for _, receipt := range receipts { m.State. - On("GetTransactionReceipt", context.Background(), receipt.TxHash, m.DbTx). + On("GetTransactionReceipt", context.Background(), receipt.TxHash, nil). Return(receipt, nil). Once() } @@ -1573,18 +1303,8 @@ func TestGetL2FullBlockByNumber(t *testing.T) { ExpectedResult: nil, ExpectedError: types.NewRPCError(types.DefaultErrorCode, "failed to get the last block number from state"), SetupMocks: func(m *mocksWrapper, tc *testCase) { - m.DbTx. - On("Rollback", context.Background()). - Return(nil). - Once() - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - - m.State. - On("GetLastL2BlockNumber", context.Background(), m.DbTx). + On("GetLastL2BlockNumber", context.Background(), nil). Return(uint64(0), errors.New("failed to get last block number")). Once() }, @@ -1595,23 +1315,13 @@ func TestGetL2FullBlockByNumber(t *testing.T) { ExpectedResult: nil, ExpectedError: types.NewRPCError(types.DefaultErrorCode, "couldn't load block from state by number 1"), SetupMocks: func(m *mocksWrapper, tc *testCase) { - m.DbTx. - On("Rollback", context.Background()). - Return(nil). - Once() - - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - m.State. - On("GetLastL2BlockNumber", context.Background(), m.DbTx). + On("GetLastL2BlockNumber", context.Background(), nil). Return(uint64(1), nil). Once() m.State. - On("GetL2BlockByNumber", context.Background(), uint64(1), m.DbTx). + On("GetL2BlockByNumber", context.Background(), uint64(1), nil). Return(nil, errors.New("failed to load block by number")). Once() }, @@ -1640,18 +1350,8 @@ func TestGetL2FullBlockByNumber(t *testing.T) { tc.ExpectedResult.Nonce = nil tc.ExpectedResult.TotalDifficulty = nil - m.DbTx. - On("Commit", context.Background()). - Return(nil). - Once() - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - - m.State. - On("GetLastL2Block", context.Background(), m.DbTx). + On("GetLastL2Block", context.Background(), nil). Return(lastBlock, nil). Once() }, @@ -1662,18 +1362,8 @@ func TestGetL2FullBlockByNumber(t *testing.T) { ExpectedResult: nil, ExpectedError: types.NewRPCError(types.DefaultErrorCode, "couldn't load last block from state to compute the pending block"), SetupMocks: func(m *mocksWrapper, tc *testCase) { - m.DbTx. - On("Rollback", context.Background()). - Return(nil). - Once() - - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - m.State. - On("GetLastL2Block", context.Background(), m.DbTx). + On("GetLastL2Block", context.Background(), nil). Return(nil, errors.New("failed to load last block")). Once() }, @@ -1769,21 +1459,11 @@ func TestGetNativeBlockHashesInRange(t *testing.T) { ExpectedResult: state.Ptr([]string{}), ExpectedError: nil, SetupMocks: func(m *mocksWrapper, tc *testCase) { - m.DbTx. - On("Commit", context.Background()). - Return(nil). - Once() - - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - fromBlock, _ := tc.Filter.FromBlock.GetNumericBlockNumber(context.Background(), nil, nil, nil) toBlock, _ := tc.Filter.ToBlock.GetNumericBlockNumber(context.Background(), nil, nil, nil) m.State. - On("GetNativeBlockHashesInRange", context.Background(), fromBlock, toBlock, m.DbTx). + On("GetNativeBlockHashesInRange", context.Background(), fromBlock, toBlock, nil). Return([]common.Hash{}, nil). Once() }, @@ -1797,16 +1477,6 @@ func TestGetNativeBlockHashesInRange(t *testing.T) { ExpectedResult: state.Ptr([]string{}), ExpectedError: nil, SetupMocks: func(m *mocksWrapper, tc *testCase) { - m.DbTx. - On("Commit", context.Background()). - Return(nil). - Once() - - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - fromBlock, _ := tc.Filter.FromBlock.GetNumericBlockNumber(context.Background(), nil, nil, nil) toBlock, _ := tc.Filter.ToBlock.GetNumericBlockNumber(context.Background(), nil, nil, nil) hashes := []common.Hash{} @@ -1820,7 +1490,7 @@ func TestGetNativeBlockHashesInRange(t *testing.T) { tc.ExpectedResult = &expectedResult m.State. - On("GetNativeBlockHashesInRange", context.Background(), fromBlock, toBlock, m.DbTx). + On("GetNativeBlockHashesInRange", context.Background(), fromBlock, toBlock, nil). Return(hashes, nil). Once() }, @@ -1834,15 +1504,7 @@ func TestGetNativeBlockHashesInRange(t *testing.T) { ExpectedResult: nil, ExpectedError: types.NewRPCError(types.InvalidParamsErrorCode, "invalid block range"), SetupMocks: func(m *mocksWrapper, tc *testCase) { - m.DbTx. - On("Rollback", context.Background()). - Return(nil). - Once() - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() }, }, { @@ -1854,15 +1516,7 @@ func TestGetNativeBlockHashesInRange(t *testing.T) { ExpectedResult: nil, ExpectedError: types.NewRPCError(types.InvalidParamsErrorCode, "native block hashes are limited to a 60000 block range"), SetupMocks: func(m *mocksWrapper, tc *testCase) { - m.DbTx. - On("Rollback", context.Background()). - Return(nil). - Once() - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() }, }, } @@ -1969,28 +1623,18 @@ func TestGetTransactionByL2Hash(t *testing.T) { ExpectedResult: &rpcTransaction, ExpectedError: nil, SetupMocks: func(m *mocksWrapper, tc testCase) { - m.DbTx. - On("Commit", context.Background()). - Return(nil). - Once() - - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - m.State. - On("GetTransactionByL2Hash", context.Background(), tc.Hash, m.DbTx). + On("GetTransactionByL2Hash", context.Background(), tc.Hash, nil). Return(signedTx, nil). Once() m.State. - On("GetTransactionReceipt", context.Background(), tc.Hash, m.DbTx). + On("GetTransactionReceipt", context.Background(), tc.Hash, nil). Return(receipt, nil). Once() m.State. - On("GetL2TxHashByTxHash", context.Background(), signedTx.Hash(), m.DbTx). + On("GetL2TxHashByTxHash", context.Background(), signedTx.Hash(), nil). Return(&l2Hash, nil). Once() }, @@ -2007,18 +1651,8 @@ func TestGetTransactionByL2Hash(t *testing.T) { tc.ExpectedResult.TxIndex = nil tc.ExpectedResult.L2Hash = nil - m.DbTx. - On("Commit", context.Background()). - Return(nil). - Once() - - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - m.State. - On("GetTransactionByL2Hash", context.Background(), tc.Hash, m.DbTx). + On("GetTransactionByL2Hash", context.Background(), tc.Hash, nil). Return(nil, state.ErrNotFound). Once() @@ -2035,18 +1669,8 @@ func TestGetTransactionByL2Hash(t *testing.T) { ExpectedResult: nil, ExpectedError: nil, SetupMocks: func(m *mocksWrapper, tc testCase) { - m.DbTx. - On("Commit", context.Background()). - Return(nil). - Once() - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - - m.State. - On("GetTransactionByL2Hash", context.Background(), tc.Hash, m.DbTx). + On("GetTransactionByL2Hash", context.Background(), tc.Hash, nil). Return(nil, state.ErrNotFound). Once() @@ -2063,18 +1687,8 @@ func TestGetTransactionByL2Hash(t *testing.T) { ExpectedResult: nil, ExpectedError: types.NewRPCError(types.DefaultErrorCode, "failed to load transaction by l2 hash from state"), SetupMocks: func(m *mocksWrapper, tc testCase) { - m.DbTx. - On("Rollback", context.Background()). - Return(nil). - Once() - - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - m.State. - On("GetTransactionByL2Hash", context.Background(), tc.Hash, m.DbTx). + On("GetTransactionByL2Hash", context.Background(), tc.Hash, nil). Return(nil, errors.New("failed to load transaction by l2 hash from state")). Once() }, @@ -2086,18 +1700,8 @@ func TestGetTransactionByL2Hash(t *testing.T) { ExpectedResult: nil, ExpectedError: types.NewRPCError(types.DefaultErrorCode, "failed to load transaction by l2 hash from pool"), SetupMocks: func(m *mocksWrapper, tc testCase) { - m.DbTx. - On("Rollback", context.Background()). - Return(nil). - Once() - - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - m.State. - On("GetTransactionByL2Hash", context.Background(), tc.Hash, m.DbTx). + On("GetTransactionByL2Hash", context.Background(), tc.Hash, nil). Return(nil, state.ErrNotFound). Once() @@ -2114,23 +1718,13 @@ func TestGetTransactionByL2Hash(t *testing.T) { ExpectedResult: nil, ExpectedError: types.NewRPCError(types.DefaultErrorCode, "transaction receipt not found"), SetupMocks: func(m *mocksWrapper, tc testCase) { - m.DbTx. - On("Rollback", context.Background()). - Return(nil). - Once() - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - - m.State. - On("GetTransactionByL2Hash", context.Background(), tc.Hash, m.DbTx). + On("GetTransactionByL2Hash", context.Background(), tc.Hash, nil). Return(signedTx, nil). Once() m.State. - On("GetTransactionReceipt", context.Background(), tc.Hash, m.DbTx). + On("GetTransactionReceipt", context.Background(), tc.Hash, nil). Return(nil, state.ErrNotFound). Once() }, @@ -2142,23 +1736,13 @@ func TestGetTransactionByL2Hash(t *testing.T) { ExpectedResult: nil, ExpectedError: types.NewRPCError(types.DefaultErrorCode, "failed to load transaction receipt from state"), SetupMocks: func(m *mocksWrapper, tc testCase) { - m.DbTx. - On("Rollback", context.Background()). - Return(nil). - Once() - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - - m.State. - On("GetTransactionByL2Hash", context.Background(), tc.Hash, m.DbTx). + On("GetTransactionByL2Hash", context.Background(), tc.Hash, nil). Return(signedTx, nil). Once() m.State. - On("GetTransactionReceipt", context.Background(), tc.Hash, m.DbTx). + On("GetTransactionReceipt", context.Background(), tc.Hash, nil). Return(nil, errors.New("failed to load transaction receipt from state")). Once() }, @@ -2285,28 +1869,18 @@ func TestGetTransactionReceiptByL2Hash(t *testing.T) { ExpectedResult: &rpcReceipt, ExpectedError: nil, SetupMocks: func(m *mocksWrapper, tc testCase) { - m.DbTx. - On("Commit", context.Background()). - Return(nil). - Once() - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - - m.State. - On("GetTransactionByL2Hash", context.Background(), tc.Hash, m.DbTx). + On("GetTransactionByL2Hash", context.Background(), tc.Hash, nil). Return(signedTx, nil). Once() m.State. - On("GetTransactionReceipt", context.Background(), tc.Hash, m.DbTx). + On("GetTransactionReceipt", context.Background(), tc.Hash, nil). Return(receipt, nil). Once() m.State. - On("GetL2TxHashByTxHash", context.Background(), signedTx.Hash(), m.DbTx). + On("GetL2TxHashByTxHash", context.Background(), signedTx.Hash(), nil). Return(&l2Hash, nil). Once() }, @@ -2317,18 +1891,8 @@ func TestGetTransactionReceiptByL2Hash(t *testing.T) { ExpectedResult: nil, ExpectedError: nil, SetupMocks: func(m *mocksWrapper, tc testCase) { - m.DbTx. - On("Commit", context.Background()). - Return(nil). - Once() - - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - m.State. - On("GetTransactionByL2Hash", context.Background(), tc.Hash, m.DbTx). + On("GetTransactionByL2Hash", context.Background(), tc.Hash, nil). Return(nil, state.ErrNotFound). Once() }, @@ -2339,18 +1903,8 @@ func TestGetTransactionReceiptByL2Hash(t *testing.T) { ExpectedResult: nil, ExpectedError: types.NewRPCError(types.DefaultErrorCode, "failed to get tx from state"), SetupMocks: func(m *mocksWrapper, tc testCase) { - m.DbTx. - On("Rollback", context.Background()). - Return(nil). - Once() - - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - m.State. - On("GetTransactionByL2Hash", context.Background(), tc.Hash, m.DbTx). + On("GetTransactionByL2Hash", context.Background(), tc.Hash, nil). Return(nil, errors.New("failed to get tx")). Once() }, @@ -2361,23 +1915,13 @@ func TestGetTransactionReceiptByL2Hash(t *testing.T) { ExpectedResult: nil, ExpectedError: nil, SetupMocks: func(m *mocksWrapper, tc testCase) { - m.DbTx. - On("Commit", context.Background()). - Return(nil). - Once() - - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - m.State. - On("GetTransactionByL2Hash", context.Background(), tc.Hash, m.DbTx). + On("GetTransactionByL2Hash", context.Background(), tc.Hash, nil). Return(signedTx, nil). Once() m.State. - On("GetTransactionReceipt", context.Background(), tc.Hash, m.DbTx). + On("GetTransactionReceipt", context.Background(), tc.Hash, nil). Return(nil, state.ErrNotFound). Once() }, @@ -2388,23 +1932,13 @@ func TestGetTransactionReceiptByL2Hash(t *testing.T) { ExpectedResult: nil, ExpectedError: types.NewRPCError(types.DefaultErrorCode, "failed to get tx receipt from state"), SetupMocks: func(m *mocksWrapper, tc testCase) { - m.DbTx. - On("Rollback", context.Background()). - Return(nil). - Once() - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - - m.State. - On("GetTransactionByL2Hash", context.Background(), tc.Hash, m.DbTx). + On("GetTransactionByL2Hash", context.Background(), tc.Hash, nil). Return(signedTx, nil). Once() m.State. - On("GetTransactionReceipt", context.Background(), tc.Hash, m.DbTx). + On("GetTransactionReceipt", context.Background(), tc.Hash, nil). Return(nil, errors.New("failed to get tx receipt from state")). Once() }, @@ -2415,28 +1949,18 @@ func TestGetTransactionReceiptByL2Hash(t *testing.T) { ExpectedResult: nil, ExpectedError: types.NewRPCError(types.DefaultErrorCode, "failed to build the receipt response"), SetupMocks: func(m *mocksWrapper, tc testCase) { - m.DbTx. - On("Rollback", context.Background()). - Return(nil). - Once() - - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - m.State. - On("GetTransactionByL2Hash", context.Background(), tc.Hash, m.DbTx). + On("GetTransactionByL2Hash", context.Background(), tc.Hash, nil). Return(tx, nil). Once() m.State. - On("GetTransactionReceipt", context.Background(), tc.Hash, m.DbTx). + On("GetTransactionReceipt", context.Background(), tc.Hash, nil). Return(ethTypes.NewReceipt([]byte{}, false, 0), nil). Once() m.State. - On("GetL2TxHashByTxHash", context.Background(), tx.Hash(), m.DbTx). + On("GetL2TxHashByTxHash", context.Background(), tx.Hash(), nil). Return(&l2Hash, nil). Once() }, @@ -2530,18 +2054,8 @@ func TestGetExitRootsByGER(t *testing.T) { ExpectedResult: nil, ExpectedError: nil, SetupMocks: func(s *mockedServer, m *mocksWrapper, tc *testCase) { - m.DbTx. - On("Commit", context.Background()). - Return(nil). - Once() - - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - m.State. - On("GetExitRootByGlobalExitRoot", context.Background(), tc.GER, m.DbTx). + On("GetExitRootByGlobalExitRoot", context.Background(), tc.GER, nil). Return(nil, state.ErrNotFound) }, }, @@ -2551,18 +2065,8 @@ func TestGetExitRootsByGER(t *testing.T) { ExpectedResult: nil, ExpectedError: nil, SetupMocks: func(s *mockedServer, m *mocksWrapper, tc *testCase) { - m.DbTx. - On("Commit", context.Background()). - Return(nil). - Once() - - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - m.State. - On("GetExitRootByGlobalExitRoot", context.Background(), tc.GER, m.DbTx). + On("GetExitRootByGlobalExitRoot", context.Background(), tc.GER, nil). Return(nil, fmt.Errorf("failed to load exit roots from state")) }, }, @@ -2577,15 +2081,6 @@ func TestGetExitRootsByGER(t *testing.T) { }, ExpectedError: nil, SetupMocks: func(s *mockedServer, m *mocksWrapper, tc *testCase) { - m.DbTx. - On("Commit", context.Background()). - Return(nil). - Once() - - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() er := &state.GlobalExitRoot{ BlockNumber: uint64(tc.ExpectedResult.BlockNumber), Timestamp: time.Unix(int64(tc.ExpectedResult.Timestamp), 0), @@ -2594,7 +2089,7 @@ func TestGetExitRootsByGER(t *testing.T) { } m.State. - On("GetExitRootByGlobalExitRoot", context.Background(), tc.GER, m.DbTx). + On("GetExitRootByGlobalExitRoot", context.Background(), tc.GER, nil). Return(er, nil) }, }, @@ -2642,18 +2137,8 @@ func TestGetLatestGlobalExitRoot(t *testing.T) { ExpectedResult: nil, ExpectedError: types.NewRPCError(types.DefaultErrorCode, "couldn't load the last global exit root"), SetupMocks: func(m *mocksWrapper, tc *testCase) { - m.DbTx. - On("Rollback", context.Background()). - Return(nil). - Once() - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - - m.State. - On("GetLatestBatchGlobalExitRoot", context.Background(), m.DbTx). + On("GetLatestBatchGlobalExitRoot", context.Background(), nil). Return(nil, fmt.Errorf("failed to load GER from state")). Once() }, @@ -2663,18 +2148,8 @@ func TestGetLatestGlobalExitRoot(t *testing.T) { ExpectedResult: state.Ptr(common.HexToHash("0x1")), ExpectedError: nil, SetupMocks: func(m *mocksWrapper, tc *testCase) { - m.DbTx. - On("Commit", context.Background()). - Return(nil). - Once() - - m.State. - On("BeginStateTransaction", context.Background()). - Return(m.DbTx, nil). - Once() - m.State. - On("GetLatestBatchGlobalExitRoot", context.Background(), m.DbTx). + On("GetLatestBatchGlobalExitRoot", context.Background(), nil). Return(common.HexToHash("0x1"), nil). Once() }, diff --git a/jsonrpc/mocks/mock_dbtx.go b/jsonrpc/mocks/mock_dbtx.go deleted file mode 100644 index c17e30a024..0000000000 --- a/jsonrpc/mocks/mock_dbtx.go +++ /dev/null @@ -1,350 +0,0 @@ -// Code generated by mockery v2.39.0. DO NOT EDIT. - -package mocks - -import ( - context "context" - - pgconn "github.com/jackc/pgconn" - mock "github.com/stretchr/testify/mock" - - pgx "github.com/jackc/pgx/v4" -) - -// DBTxMock is an autogenerated mock type for the Tx type -type DBTxMock struct { - mock.Mock -} - -// Begin provides a mock function with given fields: ctx -func (_m *DBTxMock) Begin(ctx context.Context) (pgx.Tx, error) { - ret := _m.Called(ctx) - - if len(ret) == 0 { - panic("no return value specified for Begin") - } - - var r0 pgx.Tx - var r1 error - if rf, ok := ret.Get(0).(func(context.Context) (pgx.Tx, error)); ok { - return rf(ctx) - } - if rf, ok := ret.Get(0).(func(context.Context) pgx.Tx); ok { - r0 = rf(ctx) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).(pgx.Tx) - } - } - - if rf, ok := ret.Get(1).(func(context.Context) error); ok { - r1 = rf(ctx) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// BeginFunc provides a mock function with given fields: ctx, f -func (_m *DBTxMock) BeginFunc(ctx context.Context, f func(pgx.Tx) error) error { - ret := _m.Called(ctx, f) - - if len(ret) == 0 { - panic("no return value specified for BeginFunc") - } - - var r0 error - if rf, ok := ret.Get(0).(func(context.Context, func(pgx.Tx) error) error); ok { - r0 = rf(ctx, f) - } else { - r0 = ret.Error(0) - } - - return r0 -} - -// Commit provides a mock function with given fields: ctx -func (_m *DBTxMock) Commit(ctx context.Context) error { - ret := _m.Called(ctx) - - if len(ret) == 0 { - panic("no return value specified for Commit") - } - - var r0 error - if rf, ok := ret.Get(0).(func(context.Context) error); ok { - r0 = rf(ctx) - } else { - r0 = ret.Error(0) - } - - return r0 -} - -// Conn provides a mock function with given fields: -func (_m *DBTxMock) Conn() *pgx.Conn { - ret := _m.Called() - - if len(ret) == 0 { - panic("no return value specified for Conn") - } - - var r0 *pgx.Conn - if rf, ok := ret.Get(0).(func() *pgx.Conn); ok { - r0 = rf() - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).(*pgx.Conn) - } - } - - return r0 -} - -// CopyFrom provides a mock function with given fields: ctx, tableName, columnNames, rowSrc -func (_m *DBTxMock) CopyFrom(ctx context.Context, tableName pgx.Identifier, columnNames []string, rowSrc pgx.CopyFromSource) (int64, error) { - ret := _m.Called(ctx, tableName, columnNames, rowSrc) - - if len(ret) == 0 { - panic("no return value specified for CopyFrom") - } - - var r0 int64 - var r1 error - if rf, ok := ret.Get(0).(func(context.Context, pgx.Identifier, []string, pgx.CopyFromSource) (int64, error)); ok { - return rf(ctx, tableName, columnNames, rowSrc) - } - if rf, ok := ret.Get(0).(func(context.Context, pgx.Identifier, []string, pgx.CopyFromSource) int64); ok { - r0 = rf(ctx, tableName, columnNames, rowSrc) - } else { - r0 = ret.Get(0).(int64) - } - - if rf, ok := ret.Get(1).(func(context.Context, pgx.Identifier, []string, pgx.CopyFromSource) error); ok { - r1 = rf(ctx, tableName, columnNames, rowSrc) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// Exec provides a mock function with given fields: ctx, sql, arguments -func (_m *DBTxMock) Exec(ctx context.Context, sql string, arguments ...interface{}) (pgconn.CommandTag, error) { - var _ca []interface{} - _ca = append(_ca, ctx, sql) - _ca = append(_ca, arguments...) - ret := _m.Called(_ca...) - - if len(ret) == 0 { - panic("no return value specified for Exec") - } - - var r0 pgconn.CommandTag - var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string, ...interface{}) (pgconn.CommandTag, error)); ok { - return rf(ctx, sql, arguments...) - } - if rf, ok := ret.Get(0).(func(context.Context, string, ...interface{}) pgconn.CommandTag); ok { - r0 = rf(ctx, sql, arguments...) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).(pgconn.CommandTag) - } - } - - if rf, ok := ret.Get(1).(func(context.Context, string, ...interface{}) error); ok { - r1 = rf(ctx, sql, arguments...) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// LargeObjects provides a mock function with given fields: -func (_m *DBTxMock) LargeObjects() pgx.LargeObjects { - ret := _m.Called() - - if len(ret) == 0 { - panic("no return value specified for LargeObjects") - } - - var r0 pgx.LargeObjects - if rf, ok := ret.Get(0).(func() pgx.LargeObjects); ok { - r0 = rf() - } else { - r0 = ret.Get(0).(pgx.LargeObjects) - } - - return r0 -} - -// Prepare provides a mock function with given fields: ctx, name, sql -func (_m *DBTxMock) Prepare(ctx context.Context, name string, sql string) (*pgconn.StatementDescription, error) { - ret := _m.Called(ctx, name, sql) - - if len(ret) == 0 { - panic("no return value specified for Prepare") - } - - var r0 *pgconn.StatementDescription - var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string, string) (*pgconn.StatementDescription, error)); ok { - return rf(ctx, name, sql) - } - if rf, ok := ret.Get(0).(func(context.Context, string, string) *pgconn.StatementDescription); ok { - r0 = rf(ctx, name, sql) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).(*pgconn.StatementDescription) - } - } - - if rf, ok := ret.Get(1).(func(context.Context, string, string) error); ok { - r1 = rf(ctx, name, sql) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// Query provides a mock function with given fields: ctx, sql, args -func (_m *DBTxMock) Query(ctx context.Context, sql string, args ...interface{}) (pgx.Rows, error) { - var _ca []interface{} - _ca = append(_ca, ctx, sql) - _ca = append(_ca, args...) - ret := _m.Called(_ca...) - - if len(ret) == 0 { - panic("no return value specified for Query") - } - - var r0 pgx.Rows - var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string, ...interface{}) (pgx.Rows, error)); ok { - return rf(ctx, sql, args...) - } - if rf, ok := ret.Get(0).(func(context.Context, string, ...interface{}) pgx.Rows); ok { - r0 = rf(ctx, sql, args...) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).(pgx.Rows) - } - } - - if rf, ok := ret.Get(1).(func(context.Context, string, ...interface{}) error); ok { - r1 = rf(ctx, sql, args...) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// QueryFunc provides a mock function with given fields: ctx, sql, args, scans, f -func (_m *DBTxMock) QueryFunc(ctx context.Context, sql string, args []interface{}, scans []interface{}, f func(pgx.QueryFuncRow) error) (pgconn.CommandTag, error) { - ret := _m.Called(ctx, sql, args, scans, f) - - if len(ret) == 0 { - panic("no return value specified for QueryFunc") - } - - var r0 pgconn.CommandTag - var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string, []interface{}, []interface{}, func(pgx.QueryFuncRow) error) (pgconn.CommandTag, error)); ok { - return rf(ctx, sql, args, scans, f) - } - if rf, ok := ret.Get(0).(func(context.Context, string, []interface{}, []interface{}, func(pgx.QueryFuncRow) error) pgconn.CommandTag); ok { - r0 = rf(ctx, sql, args, scans, f) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).(pgconn.CommandTag) - } - } - - if rf, ok := ret.Get(1).(func(context.Context, string, []interface{}, []interface{}, func(pgx.QueryFuncRow) error) error); ok { - r1 = rf(ctx, sql, args, scans, f) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// QueryRow provides a mock function with given fields: ctx, sql, args -func (_m *DBTxMock) QueryRow(ctx context.Context, sql string, args ...interface{}) pgx.Row { - var _ca []interface{} - _ca = append(_ca, ctx, sql) - _ca = append(_ca, args...) - ret := _m.Called(_ca...) - - if len(ret) == 0 { - panic("no return value specified for QueryRow") - } - - var r0 pgx.Row - if rf, ok := ret.Get(0).(func(context.Context, string, ...interface{}) pgx.Row); ok { - r0 = rf(ctx, sql, args...) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).(pgx.Row) - } - } - - return r0 -} - -// Rollback provides a mock function with given fields: ctx -func (_m *DBTxMock) Rollback(ctx context.Context) error { - ret := _m.Called(ctx) - - if len(ret) == 0 { - panic("no return value specified for Rollback") - } - - var r0 error - if rf, ok := ret.Get(0).(func(context.Context) error); ok { - r0 = rf(ctx) - } else { - r0 = ret.Error(0) - } - - return r0 -} - -// SendBatch provides a mock function with given fields: ctx, b -func (_m *DBTxMock) SendBatch(ctx context.Context, b *pgx.Batch) pgx.BatchResults { - ret := _m.Called(ctx, b) - - if len(ret) == 0 { - panic("no return value specified for SendBatch") - } - - var r0 pgx.BatchResults - if rf, ok := ret.Get(0).(func(context.Context, *pgx.Batch) pgx.BatchResults); ok { - r0 = rf(ctx, b) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).(pgx.BatchResults) - } - } - - return r0 -} - -// NewDBTxMock creates a new instance of DBTxMock. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. -// The first argument is typically a *testing.T value. -func NewDBTxMock(t interface { - mock.TestingT - Cleanup(func()) -}) *DBTxMock { - mock := &DBTxMock{} - mock.Mock.Test(t) - - t.Cleanup(func() { mock.AssertExpectations(t) }) - - return mock -} diff --git a/jsonrpc/server_test.go b/jsonrpc/server_test.go index ffb60f8717..ae30947e13 100644 --- a/jsonrpc/server_test.go +++ b/jsonrpc/server_test.go @@ -42,7 +42,6 @@ type mocksWrapper struct { State *mocks.StateMock Etherman *mocks.EthermanMock Storage *storageMock - DbTx *mocks.DBTxMock } func newMockedServer(t *testing.T, cfg Config) (*mockedServer, *mocksWrapper, *ethclient.Client) { @@ -50,7 +49,6 @@ func newMockedServer(t *testing.T, cfg Config) (*mockedServer, *mocksWrapper, *e st := mocks.NewStateMock(t) etherman := mocks.NewEthermanMock(t) storage := newStorageMock(t) - dbTx := mocks.NewDBTxMock(t) apis := map[string]bool{ APIEth: true, APINet: true, @@ -143,7 +141,6 @@ func newMockedServer(t *testing.T, cfg Config) (*mockedServer, *mocksWrapper, *e State: st, Etherman: etherman, Storage: storage, - DbTx: dbTx, } return msv, mks, ethClient @@ -262,11 +259,9 @@ func TestBatchRequests(t *testing.T) { NumberOfRequests: 100, ExpectedError: nil, SetupMocks: func(m *mocksWrapper, tc testCase) { - m.DbTx.On("Commit", context.Background()).Return(nil).Times(tc.NumberOfRequests) - m.State.On("BeginStateTransaction", context.Background()).Return(m.DbTx, nil).Times(tc.NumberOfRequests) - m.State.On("GetLastL2BlockNumber", context.Background(), m.DbTx).Return(block.Number().Uint64(), nil).Times(tc.NumberOfRequests) - m.State.On("GetL2BlockByNumber", context.Background(), block.Number().Uint64(), m.DbTx).Return(block, nil).Times(tc.NumberOfRequests) - m.State.On("GetTransactionReceipt", context.Background(), mock.Anything, m.DbTx).Return(ethTypes.NewReceipt([]byte{}, false, uint64(0)), nil) + m.State.On("GetLastL2BlockNumber", context.Background(), nil).Return(block.Number().Uint64(), nil).Times(tc.NumberOfRequests) + m.State.On("GetL2BlockByNumber", context.Background(), block.Number().Uint64(), nil).Return(block, nil).Times(tc.NumberOfRequests) + m.State.On("GetTransactionReceipt", context.Background(), mock.Anything, nil).Return(ethTypes.NewReceipt([]byte{}, false, uint64(0)), nil) }, }, { @@ -276,11 +271,9 @@ func TestBatchRequests(t *testing.T) { NumberOfRequests: 5, ExpectedError: nil, SetupMocks: func(m *mocksWrapper, tc testCase) { - m.DbTx.On("Commit", context.Background()).Return(nil).Times(tc.NumberOfRequests) - m.State.On("BeginStateTransaction", context.Background()).Return(m.DbTx, nil).Times(tc.NumberOfRequests) - m.State.On("GetLastL2BlockNumber", context.Background(), m.DbTx).Return(block.Number().Uint64(), nil).Times(tc.NumberOfRequests) - m.State.On("GetL2BlockByNumber", context.Background(), block.Number().Uint64(), m.DbTx).Return(block, nil).Times(tc.NumberOfRequests) - m.State.On("GetTransactionReceipt", context.Background(), mock.Anything, m.DbTx).Return(ethTypes.NewReceipt([]byte{}, false, uint64(0)), nil) + m.State.On("GetLastL2BlockNumber", context.Background(), nil).Return(block.Number().Uint64(), nil).Times(tc.NumberOfRequests) + m.State.On("GetL2BlockByNumber", context.Background(), block.Number().Uint64(), nil).Return(block, nil).Times(tc.NumberOfRequests) + m.State.On("GetTransactionReceipt", context.Background(), mock.Anything, nil).Return(ethTypes.NewReceipt([]byte{}, false, uint64(0)), nil) }, }, { @@ -290,11 +283,9 @@ func TestBatchRequests(t *testing.T) { NumberOfRequests: 4, ExpectedError: nil, SetupMocks: func(m *mocksWrapper, tc testCase) { - m.DbTx.On("Commit", context.Background()).Return(nil).Times(tc.NumberOfRequests) - m.State.On("BeginStateTransaction", context.Background()).Return(m.DbTx, nil).Times(tc.NumberOfRequests) - m.State.On("GetLastL2BlockNumber", context.Background(), m.DbTx).Return(block.Number().Uint64(), nil).Times(tc.NumberOfRequests) - m.State.On("GetL2BlockByNumber", context.Background(), block.Number().Uint64(), m.DbTx).Return(block, nil).Times(tc.NumberOfRequests) - m.State.On("GetTransactionReceipt", context.Background(), mock.Anything, m.DbTx).Return(ethTypes.NewReceipt([]byte{}, false, uint64(0)), nil) + m.State.On("GetLastL2BlockNumber", context.Background(), nil).Return(block.Number().Uint64(), nil).Times(tc.NumberOfRequests) + m.State.On("GetL2BlockByNumber", context.Background(), block.Number().Uint64(), nil).Return(block, nil).Times(tc.NumberOfRequests) + m.State.On("GetTransactionReceipt", context.Background(), mock.Anything, nil).Return(ethTypes.NewReceipt([]byte{}, false, uint64(0)), nil) }, }, } @@ -589,9 +580,7 @@ func TestMaxRequestPerIPPerSec(t *testing.T) { // this makes us sure the code is calling these methods only for // allowed requests times := int(cfg.MaxRequestsPerIPAndSecond) - m.DbTx.On("Commit", context.Background()).Return(nil).Times(times) - m.State.On("BeginStateTransaction", context.Background()).Return(m.DbTx, nil).Times(times) - m.State.On("GetLastL2BlockNumber", context.Background(), m.DbTx).Return(uint64(1), nil).Times(times) + m.State.On("GetLastL2BlockNumber", context.Background(), nil).Return(uint64(1), nil).Times(times) // prepare the workers to process the requests as long as a job is available requestsLimitedCount := uint64(0) diff --git a/jsonrpc/types/codec_test.go b/jsonrpc/types/codec_test.go index d08dbd8ed6..33da973651 100644 --- a/jsonrpc/types/codec_test.go +++ b/jsonrpc/types/codec_test.go @@ -54,7 +54,7 @@ func TestGetNumericBlockNumber(t *testing.T) { bn *BlockNumber expectedBlockNumber uint64 expectedError Error - setupMocks func(s *mocks.StateMock, d *mocks.DBTxMock, t *testCase) + setupMocks func(s *mocks.StateMock, t *testCase) } testCases := []testCase{ @@ -63,9 +63,9 @@ func TestGetNumericBlockNumber(t *testing.T) { bn: nil, expectedBlockNumber: 40, expectedError: nil, - setupMocks: func(s *mocks.StateMock, d *mocks.DBTxMock, t *testCase) { + setupMocks: func(s *mocks.StateMock, t *testCase) { s. - On("GetLastL2BlockNumber", context.Background(), d). + On("GetLastL2BlockNumber", context.Background(), nil). Return(uint64(40), nil). Once() }, @@ -75,9 +75,9 @@ func TestGetNumericBlockNumber(t *testing.T) { bn: state.Ptr(LatestBlockNumber), expectedBlockNumber: 50, expectedError: nil, - setupMocks: func(s *mocks.StateMock, d *mocks.DBTxMock, t *testCase) { + setupMocks: func(s *mocks.StateMock, t *testCase) { s. - On("GetLastL2BlockNumber", context.Background(), d). + On("GetLastL2BlockNumber", context.Background(), nil). Return(uint64(50), nil). Once() }, @@ -87,9 +87,9 @@ func TestGetNumericBlockNumber(t *testing.T) { bn: state.Ptr(PendingBlockNumber), expectedBlockNumber: 30, expectedError: nil, - setupMocks: func(s *mocks.StateMock, d *mocks.DBTxMock, t *testCase) { + setupMocks: func(s *mocks.StateMock, t *testCase) { s. - On("GetLastL2BlockNumber", context.Background(), d). + On("GetLastL2BlockNumber", context.Background(), nil). Return(uint64(30), nil). Once() }, @@ -99,14 +99,14 @@ func TestGetNumericBlockNumber(t *testing.T) { bn: state.Ptr(EarliestBlockNumber), expectedBlockNumber: 0, expectedError: nil, - setupMocks: func(s *mocks.StateMock, d *mocks.DBTxMock, t *testCase) {}, + setupMocks: func(s *mocks.StateMock, t *testCase) {}, }, { name: "BlockNumber SafeBlockNumber", bn: state.Ptr(SafeBlockNumber), expectedBlockNumber: 40, expectedError: nil, - setupMocks: func(s *mocks.StateMock, d *mocks.DBTxMock, t *testCase) { + setupMocks: func(s *mocks.StateMock, t *testCase) { safeBlockNumber := uint64(30) e. On("GetSafeBlockNumber", context.Background()). @@ -114,7 +114,7 @@ func TestGetNumericBlockNumber(t *testing.T) { Once() s. - On("GetLastVerifiedL2BlockNumberUntilL1Block", context.Background(), safeBlockNumber, d). + On("GetLastVerifiedL2BlockNumberUntilL1Block", context.Background(), safeBlockNumber, nil). Return(uint64(40), nil). Once() }, @@ -124,7 +124,7 @@ func TestGetNumericBlockNumber(t *testing.T) { bn: state.Ptr(FinalizedBlockNumber), expectedBlockNumber: 60, expectedError: nil, - setupMocks: func(s *mocks.StateMock, d *mocks.DBTxMock, t *testCase) { + setupMocks: func(s *mocks.StateMock, t *testCase) { finalizedBlockNumber := uint64(50) e. On("GetFinalizedBlockNumber", context.Background()). @@ -132,7 +132,7 @@ func TestGetNumericBlockNumber(t *testing.T) { Once() s. - On("GetLastVerifiedL2BlockNumberUntilL1Block", context.Background(), finalizedBlockNumber, d). + On("GetLastVerifiedL2BlockNumberUntilL1Block", context.Background(), finalizedBlockNumber, nil). Return(uint64(60), nil). Once() }, @@ -142,23 +142,22 @@ func TestGetNumericBlockNumber(t *testing.T) { bn: state.Ptr(BlockNumber(int64(10))), expectedBlockNumber: 10, expectedError: nil, - setupMocks: func(s *mocks.StateMock, d *mocks.DBTxMock, t *testCase) {}, + setupMocks: func(s *mocks.StateMock, t *testCase) {}, }, { name: "BlockNumber Negative Number <= -6", bn: state.Ptr(BlockNumber(int64(-6))), expectedBlockNumber: 0, expectedError: NewRPCError(InvalidParamsErrorCode, "invalid block number: -6"), - setupMocks: func(s *mocks.StateMock, d *mocks.DBTxMock, t *testCase) {}, + setupMocks: func(s *mocks.StateMock, t *testCase) {}, }, } for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { tc := testCase - dbTx := mocks.NewDBTxMock(t) - testCase.setupMocks(s, dbTx, &tc) - result, rpcErr := testCase.bn.GetNumericBlockNumber(context.Background(), s, e, dbTx) + testCase.setupMocks(s, &tc) + result, rpcErr := testCase.bn.GetNumericBlockNumber(context.Background(), s, e, nil) assert.Equal(t, testCase.expectedBlockNumber, result) if rpcErr != nil || testCase.expectedError != nil { assert.Equal(t, testCase.expectedError.ErrorCode(), rpcErr.ErrorCode()) @@ -177,7 +176,7 @@ func TestGetNumericBatchNumber(t *testing.T) { bn *BatchNumber expectedBatchNumber uint64 expectedError Error - setupMocks func(s *mocks.StateMock, d *mocks.DBTxMock, t *testCase) + setupMocks func(s *mocks.StateMock, t *testCase) } testCases := []testCase{ @@ -186,9 +185,9 @@ func TestGetNumericBatchNumber(t *testing.T) { bn: nil, expectedBatchNumber: 40, expectedError: nil, - setupMocks: func(s *mocks.StateMock, d *mocks.DBTxMock, t *testCase) { + setupMocks: func(s *mocks.StateMock, t *testCase) { s. - On("GetLastClosedBatchNumber", context.Background(), d). + On("GetLastClosedBatchNumber", context.Background(), nil). Return(uint64(40), nil). Once() }, @@ -198,9 +197,9 @@ func TestGetNumericBatchNumber(t *testing.T) { bn: state.Ptr(LatestBatchNumber), expectedBatchNumber: 50, expectedError: nil, - setupMocks: func(s *mocks.StateMock, d *mocks.DBTxMock, t *testCase) { + setupMocks: func(s *mocks.StateMock, t *testCase) { s. - On("GetLastClosedBatchNumber", context.Background(), d). + On("GetLastClosedBatchNumber", context.Background(), nil). Return(uint64(50), nil). Once() }, @@ -210,9 +209,9 @@ func TestGetNumericBatchNumber(t *testing.T) { bn: state.Ptr(PendingBatchNumber), expectedBatchNumber: 90, expectedError: nil, - setupMocks: func(s *mocks.StateMock, d *mocks.DBTxMock, t *testCase) { + setupMocks: func(s *mocks.StateMock, t *testCase) { s. - On("GetLastBatchNumber", context.Background(), d). + On("GetLastBatchNumber", context.Background(), nil). Return(uint64(90), nil). Once() }, @@ -222,14 +221,14 @@ func TestGetNumericBatchNumber(t *testing.T) { bn: state.Ptr(EarliestBatchNumber), expectedBatchNumber: 0, expectedError: nil, - setupMocks: func(s *mocks.StateMock, d *mocks.DBTxMock, t *testCase) {}, + setupMocks: func(s *mocks.StateMock, t *testCase) {}, }, { name: "BatchNumber SafeBatchNumber", bn: state.Ptr(SafeBatchNumber), expectedBatchNumber: 40, expectedError: nil, - setupMocks: func(s *mocks.StateMock, d *mocks.DBTxMock, t *testCase) { + setupMocks: func(s *mocks.StateMock, t *testCase) { safeBlockNumber := uint64(30) e. On("GetSafeBlockNumber", context.Background()). @@ -237,7 +236,7 @@ func TestGetNumericBatchNumber(t *testing.T) { Once() s. - On("GetLastVerifiedBatchNumberUntilL1Block", context.Background(), safeBlockNumber, d). + On("GetLastVerifiedBatchNumberUntilL1Block", context.Background(), safeBlockNumber, nil). Return(uint64(40), nil). Once() }, @@ -247,7 +246,7 @@ func TestGetNumericBatchNumber(t *testing.T) { bn: state.Ptr(FinalizedBatchNumber), expectedBatchNumber: 60, expectedError: nil, - setupMocks: func(s *mocks.StateMock, d *mocks.DBTxMock, t *testCase) { + setupMocks: func(s *mocks.StateMock, t *testCase) { finalizedBlockNumber := uint64(50) e. On("GetFinalizedBlockNumber", context.Background()). @@ -255,7 +254,7 @@ func TestGetNumericBatchNumber(t *testing.T) { Once() s. - On("GetLastVerifiedBatchNumberUntilL1Block", context.Background(), finalizedBlockNumber, d). + On("GetLastVerifiedBatchNumberUntilL1Block", context.Background(), finalizedBlockNumber, nil). Return(uint64(60), nil). Once() }, @@ -265,23 +264,22 @@ func TestGetNumericBatchNumber(t *testing.T) { bn: state.Ptr(BatchNumber(int64(10))), expectedBatchNumber: 10, expectedError: nil, - setupMocks: func(s *mocks.StateMock, d *mocks.DBTxMock, t *testCase) {}, + setupMocks: func(s *mocks.StateMock, t *testCase) {}, }, { name: "BatchNumber Negative Number <= -6", bn: state.Ptr(BatchNumber(int64(-6))), expectedBatchNumber: 0, expectedError: NewRPCError(InvalidParamsErrorCode, "invalid batch number: -6"), - setupMocks: func(s *mocks.StateMock, d *mocks.DBTxMock, t *testCase) {}, + setupMocks: func(s *mocks.StateMock, t *testCase) {}, }, } for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { tc := testCase - dbTx := mocks.NewDBTxMock(t) - testCase.setupMocks(s, dbTx, &tc) - result, rpcErr := testCase.bn.GetNumericBatchNumber(context.Background(), s, e, dbTx) + testCase.setupMocks(s, &tc) + result, rpcErr := testCase.bn.GetNumericBatchNumber(context.Background(), s, e, nil) assert.Equal(t, testCase.expectedBatchNumber, result) if rpcErr != nil || testCase.expectedError != nil { assert.Equal(t, testCase.expectedError.ErrorCode(), rpcErr.ErrorCode()) diff --git a/test/Makefile b/test/Makefile index 831b2cad10..ca7daa3b88 100644 --- a/test/Makefile +++ b/test/Makefile @@ -661,7 +661,6 @@ generate-mocks-jsonrpc: ## Generates mocks for jsonrpc , using mockery tool export "GOROOT=$$(go env GOROOT)" && $$(go env GOPATH)/bin/mockery --name=PoolInterface --dir=../jsonrpc/types --output=../jsonrpc/mocks --outpkg=mocks --structname=PoolMock --filename=mock_pool.go export "GOROOT=$$(go env GOROOT)" && $$(go env GOPATH)/bin/mockery --name=StateInterface --dir=../jsonrpc/types --output=../jsonrpc/mocks --outpkg=mocks --structname=StateMock --filename=mock_state.go export "GOROOT=$$(go env GOROOT)" && $$(go env GOPATH)/bin/mockery --name=EthermanInterface --dir=../jsonrpc/types --output=../jsonrpc/mocks --outpkg=mocks --structname=EthermanMock --filename=mock_etherman.go - export "GOROOT=$$(go env GOROOT)" && $$(go env GOPATH)/bin/mockery --name=Tx --srcpkg=github.com/jackc/pgx/v4 --output=../jsonrpc/mocks --outpkg=mocks --structname=DBTxMock --filename=mock_dbtx.go .PHONY: generate-mocks-sequencer generate-mocks-sequencer: ## Generates mocks for sequencer , using mockery tool