diff --git a/cmd/devp2p/discv4cmd.go b/cmd/devp2p/discv4cmd.go index 9d35880b128b..94e61c36f325 100644 --- a/cmd/devp2p/discv4cmd.go +++ b/cmd/devp2p/discv4cmd.go @@ -19,6 +19,7 @@ package main import ( "fmt" "net" + "strconv" "strings" "time" @@ -50,34 +51,34 @@ var ( Usage: "Sends ping to a node", Action: discv4Ping, ArgsUsage: "", - Flags: v4NodeFlags, + Flags: discoveryNodeFlags, } discv4RequestRecordCommand = &cli.Command{ Name: "requestenr", Usage: "Requests a node record using EIP-868 enrRequest", Action: discv4RequestRecord, ArgsUsage: "", - Flags: v4NodeFlags, + Flags: discoveryNodeFlags, } discv4ResolveCommand = &cli.Command{ Name: "resolve", Usage: "Finds a node in the DHT", Action: discv4Resolve, ArgsUsage: "", - Flags: v4NodeFlags, + Flags: discoveryNodeFlags, } discv4ResolveJSONCommand = &cli.Command{ Name: "resolve-json", Usage: "Re-resolves nodes in a nodes.json file", Action: discv4ResolveJSON, - Flags: v4NodeFlags, + Flags: discoveryNodeFlags, ArgsUsage: "", } discv4CrawlCommand = &cli.Command{ Name: "crawl", Usage: "Updates a nodes.json file with random nodes found in the DHT", Action: discv4Crawl, - Flags: flags.Merge(v4NodeFlags, []cli.Flag{crawlTimeoutFlag}), + Flags: flags.Merge(discoveryNodeFlags, []cli.Flag{crawlTimeoutFlag}), } discv4TestCommand = &cli.Command{ Name: "test", @@ -110,6 +111,10 @@ var ( Name: "addr", Usage: "Listening address", } + extAddrFlag = &cli.StringFlag{ + Name: "extaddr", + Usage: "UDP endpoint announced in ENR. You can provide a bare IP address or IP:port as the value of this flag.", + } crawlTimeoutFlag = &cli.DurationFlag{ Name: "timeout", Usage: "Time limit for the crawl.", @@ -122,11 +127,12 @@ var ( } ) -var v4NodeFlags = []cli.Flag{ +var discoveryNodeFlags = []cli.Flag{ bootnodesFlag, nodekeyFlag, nodedbFlag, listenAddrFlag, + extAddrFlag, } func discv4Ping(ctx *cli.Context) error { @@ -228,7 +234,7 @@ func discv4Test(ctx *cli.Context) error { // startV4 starts an ephemeral discovery V4 node. func startV4(ctx *cli.Context) *discover.UDPv4 { ln, config := makeDiscoveryConfig(ctx) - socket := listen(ln, ctx.String(listenAddrFlag.Name)) + socket := listen(ctx, ln) disc, err := discover.ListenV4(socket, ln, config) if err != nil { exit(err) @@ -266,7 +272,28 @@ func makeDiscoveryConfig(ctx *cli.Context) (*enode.LocalNode, discover.Config) { return ln, cfg } -func listen(ln *enode.LocalNode, addr string) *net.UDPConn { +func parseExtAddr(spec string) (ip net.IP, port int, ok bool) { + ip = net.ParseIP(spec) + if ip != nil { + return ip, 0, true + } + host, portstr, err := net.SplitHostPort(spec) + if err != nil { + return nil, 0, false + } + ip = net.ParseIP(host) + if ip == nil { + return nil, 0, false + } + port, err = strconv.Atoi(portstr) + if err != nil { + return nil, 0, false + } + return ip, port, true +} + +func listen(ctx *cli.Context, ln *enode.LocalNode) *net.UDPConn { + addr := ctx.String(listenAddrFlag.Name) if addr == "" { addr = "0.0.0.0:0" } @@ -274,6 +301,8 @@ func listen(ln *enode.LocalNode, addr string) *net.UDPConn { if err != nil { exit(err) } + + // Configure UDP endpoint in ENR from listener address. usocket := socket.(*net.UDPConn) uaddr := socket.LocalAddr().(*net.UDPAddr) if uaddr.IP.IsUnspecified() { @@ -282,6 +311,22 @@ func listen(ln *enode.LocalNode, addr string) *net.UDPConn { ln.SetFallbackIP(uaddr.IP) } ln.SetFallbackUDP(uaddr.Port) + + // If an ENR endpoint is set explicitly on the command-line, override + // the information from the listening address. Note this is careful not + // to set the UDP port if the external address doesn't have it. + extAddr := ctx.String(extAddrFlag.Name) + if extAddr != "" { + ip, port, ok := parseExtAddr(extAddr) + if !ok { + exit(fmt.Errorf("-%s: invalid external address %q", extAddrFlag.Name, extAddr)) + } + ln.SetStaticIP(ip) + if port != 0 { + ln.SetFallbackUDP(port) + } + } + return usocket } diff --git a/cmd/devp2p/discv5cmd.go b/cmd/devp2p/discv5cmd.go index 298196034b58..343e2a0d5d42 100644 --- a/cmd/devp2p/discv5cmd.go +++ b/cmd/devp2p/discv5cmd.go @@ -22,6 +22,7 @@ import ( "github.com/ethereum/go-ethereum/cmd/devp2p/internal/v5test" "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/internal/flags" "github.com/ethereum/go-ethereum/p2p/discover" "github.com/urfave/cli/v2" ) @@ -42,18 +43,21 @@ var ( Name: "ping", Usage: "Sends ping to a node", Action: discv5Ping, + Flags: discoveryNodeFlags, } discv5ResolveCommand = &cli.Command{ Name: "resolve", Usage: "Finds a node in the DHT", Action: discv5Resolve, - Flags: []cli.Flag{bootnodesFlag}, + Flags: discoveryNodeFlags, } discv5CrawlCommand = &cli.Command{ Name: "crawl", Usage: "Updates a nodes.json file with random nodes found in the DHT", Action: discv5Crawl, - Flags: []cli.Flag{bootnodesFlag, crawlTimeoutFlag}, + Flags: flags.Merge(discoveryNodeFlags, []cli.Flag{ + crawlTimeoutFlag, + }), } discv5TestCommand = &cli.Command{ Name: "test", @@ -70,12 +74,7 @@ var ( Name: "listen", Usage: "Runs a node", Action: discv5Listen, - Flags: []cli.Flag{ - bootnodesFlag, - nodekeyFlag, - nodedbFlag, - listenAddrFlag, - }, + Flags: discoveryNodeFlags, } ) @@ -137,7 +136,7 @@ func discv5Listen(ctx *cli.Context) error { // startV5 starts an ephemeral discovery v5 node. func startV5(ctx *cli.Context) *discover.UDPv5 { ln, config := makeDiscoveryConfig(ctx) - socket := listen(ln, ctx.String(listenAddrFlag.Name)) + socket := listen(ctx, ln) disc, err := discover.ListenV5(socket, ln, config) if err != nil { exit(err) diff --git a/cmd/devp2p/internal/v5test/framework.go b/cmd/devp2p/internal/v5test/framework.go index 6ccbbd075bf0..f31677e519e7 100644 --- a/cmd/devp2p/internal/v5test/framework.go +++ b/cmd/devp2p/internal/v5test/framework.go @@ -86,7 +86,7 @@ func newConn(dest *enode.Node, log logger) *conn { localNode: ln, remote: dest, remoteAddr: &net.UDPAddr{IP: dest.IP(), Port: dest.UDP()}, - codec: v5wire.NewCodec(ln, key, mclock.System{}), + codec: v5wire.NewCodec(ln, key, mclock.System{}, nil), log: log, } } diff --git a/cmd/evm/staterunner.go b/cmd/evm/staterunner.go index 36f4e19b0bea..5eba25c725a3 100644 --- a/cmd/evm/staterunner.go +++ b/cmd/evm/staterunner.go @@ -22,12 +22,12 @@ import ( "fmt" "os" + "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/state" "github.com/ethereum/go-ethereum/core/vm" "github.com/ethereum/go-ethereum/eth/tracers/logger" "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/tests" - "github.com/urfave/cli/v2" ) @@ -41,11 +41,12 @@ var stateTestCommand = &cli.Command{ // StatetestResult contains the execution status after running a state test, any // error that might have occurred and a dump of the final state if requested. type StatetestResult struct { - Name string `json:"name"` - Pass bool `json:"pass"` - Fork string `json:"fork"` - Error string `json:"error,omitempty"` - State *state.Dump `json:"state,omitempty"` + Name string `json:"name"` + Pass bool `json:"pass"` + Root *common.Hash `json:"stateRoot,omitempty"` + Fork string `json:"fork"` + Error string `json:"error,omitempty"` + State *state.Dump `json:"state,omitempty"` } func stateTestCmd(ctx *cli.Context) error { @@ -100,8 +101,12 @@ func stateTestCmd(ctx *cli.Context) error { result := &StatetestResult{Name: key, Fork: st.Fork, Pass: true} _, s, err := test.Run(st, cfg, false) // print state root for evmlab tracing - if ctx.Bool(MachineFlag.Name) && s != nil { - fmt.Fprintf(os.Stderr, "{\"stateRoot\": \"%x\"}\n", s.IntermediateRoot(false)) + if s != nil { + root := s.IntermediateRoot(false) + result.Root = &root + if ctx.Bool(MachineFlag.Name) { + fmt.Fprintf(os.Stderr, "{\"stateRoot\": \"%#x\"}\n", root) + } } if err != nil { // Test failed, mark as so and dump any state to aid debugging diff --git a/cmd/geth/chaincmd.go b/cmd/geth/chaincmd.go index 48b21ddbf7a5..10af6f32f49a 100644 --- a/cmd/geth/chaincmd.go +++ b/cmd/geth/chaincmd.go @@ -39,6 +39,7 @@ import ( "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/metrics" "github.com/ethereum/go-ethereum/node" + "github.com/ethereum/go-ethereum/trie" "github.com/urfave/cli/v2" ) @@ -48,7 +49,7 @@ var ( Name: "init", Usage: "Bootstrap and initialize a new genesis block", ArgsUsage: "", - Flags: utils.DatabasePathFlags, + Flags: flags.Merge([]cli.Flag{utils.CachePreimagesFlag}, utils.DatabasePathFlags), Description: ` The init command initializes a new genesis block and definition for the network. This is a destructive action and changes the network in which you will be @@ -188,12 +189,16 @@ func initGenesis(ctx *cli.Context) error { // Open and initialise both full and light databases stack, _ := makeConfigNode(ctx) defer stack.Close() + for _, name := range []string{"chaindata", "lightchaindata"} { chaindb, err := stack.OpenDatabaseWithFreezer(name, 0, 0, ctx.String(utils.AncientFlag.Name), "", false) if err != nil { utils.Fatalf("Failed to open database: %v", err) } - _, hash, err := core.SetupGenesisBlock(chaindb, genesis) + triedb := trie.NewDatabaseWithConfig(chaindb, &trie.Config{ + Preimages: ctx.Bool(utils.CachePreimagesFlag.Name), + }) + _, hash, err := core.SetupGenesisBlock(chaindb, triedb, genesis) if err != nil { utils.Fatalf("Failed to write genesis block: %v", err) } @@ -460,7 +465,10 @@ func dump(ctx *cli.Context) error { if err != nil { return err } - state, err := state.New(root, state.NewDatabase(db), nil) + config := &trie.Config{ + Preimages: true, // always enable preimage lookup + } + state, err := state.New(root, state.NewDatabaseWithConfig(db, config), nil) if err != nil { return err } diff --git a/cmd/geth/config.go b/cmd/geth/config.go index a8cee0d13a59..e15302544cc5 100644 --- a/cmd/geth/config.go +++ b/cmd/geth/config.go @@ -31,7 +31,6 @@ import ( "github.com/ethereum/go-ethereum/accounts/scwallet" "github.com/ethereum/go-ethereum/accounts/usbwallet" "github.com/ethereum/go-ethereum/cmd/utils" - "github.com/ethereum/go-ethereum/core/rawdb" "github.com/ethereum/go-ethereum/eth/ethconfig" "github.com/ethereum/go-ethereum/internal/ethapi" "github.com/ethereum/go-ethereum/internal/flags" @@ -166,26 +165,7 @@ func makeFullNode(ctx *cli.Context) (*node.Node, ethapi.Backend) { cfg.Eth.OverrideTerminalTotalDifficultyPassed = &override } - backend, eth := utils.RegisterEthService(stack, &cfg.Eth) - - // Warn users to migrate if they have a legacy freezer format. - if eth != nil && !ctx.IsSet(utils.IgnoreLegacyReceiptsFlag.Name) { - firstIdx := uint64(0) - // Hack to speed up check for mainnet because we know - // the first non-empty block. - ghash := rawdb.ReadCanonicalHash(eth.ChainDb(), 0) - if cfg.Eth.NetworkId == 1 && ghash == params.MainnetGenesisHash { - firstIdx = 46147 - } - isLegacy, firstLegacy, err := dbHasLegacyReceipts(eth.ChainDb(), firstIdx) - if err != nil { - log.Error("Failed to check db for legacy receipts", "err", err) - } else if isLegacy { - stack.Close() - log.Error("Database has receipts with a legacy format", "firstLegacy", firstLegacy) - utils.Fatalf("Aborting. Please run `geth db freezer-migrate`.") - } - } + backend, _ := utils.RegisterEthService(stack, &cfg.Eth) // Configure log filter RPC API. filterSystem := utils.RegisterFilterAPI(stack, backend, &cfg.Eth) diff --git a/cmd/geth/dbcmd.go b/cmd/geth/dbcmd.go index 9d834ee14b9d..5231ed116bc9 100644 --- a/cmd/geth/dbcmd.go +++ b/cmd/geth/dbcmd.go @@ -33,7 +33,6 @@ import ( "github.com/ethereum/go-ethereum/console/prompt" "github.com/ethereum/go-ethereum/core/rawdb" "github.com/ethereum/go-ethereum/core/state/snapshot" - "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/internal/flags" @@ -69,7 +68,6 @@ Remove blockchain and state databases`, dbImportCmd, dbExportCmd, dbMetadataCmd, - dbMigrateFreezerCmd, dbCheckStateContentCmd, }, } @@ -195,17 +193,6 @@ WARNING: This is a low-level operation which may cause database corruption!`, }, utils.NetworkFlags, utils.DatabasePathFlags), Description: "Shows metadata about the chain status.", } - dbMigrateFreezerCmd = &cli.Command{ - Action: freezerMigrate, - Name: "freezer-migrate", - Usage: "Migrate legacy parts of the freezer. (WARNING: may take a long time)", - ArgsUsage: "", - Flags: flags.Merge([]cli.Flag{ - utils.SyncModeFlag, - }, utils.NetworkFlags, utils.DatabasePathFlags), - Description: `The freezer-migrate command checks your database for receipts in a legacy format and updates those. -WARNING: please back-up the receipt files in your ancients before running this command.`, - } ) func removeDB(ctx *cli.Context) error { @@ -756,92 +743,3 @@ func showMetaData(ctx *cli.Context) error { table.Render() return nil } - -func freezerMigrate(ctx *cli.Context) error { - stack, _ := makeConfigNode(ctx) - defer stack.Close() - - db := utils.MakeChainDatabase(ctx, stack, false) - defer db.Close() - - // Check first block for legacy receipt format - numAncients, err := db.Ancients() - if err != nil { - return err - } - if numAncients < 1 { - log.Info("No receipts in freezer to migrate") - return nil - } - - isFirstLegacy, firstIdx, err := dbHasLegacyReceipts(db, 0) - if err != nil { - return err - } - if !isFirstLegacy { - log.Info("No legacy receipts to migrate") - return nil - } - - log.Info("Starting migration", "ancients", numAncients, "firstLegacy", firstIdx) - start := time.Now() - if err := db.MigrateTable("receipts", types.ConvertLegacyStoredReceipts); err != nil { - return err - } - if err := db.Close(); err != nil { - return err - } - log.Info("Migration finished", "duration", time.Since(start)) - - return nil -} - -// dbHasLegacyReceipts checks freezer entries for legacy receipts. It stops at the first -// non-empty receipt and checks its format. The index of this first non-empty element is -// the second return parameter. -func dbHasLegacyReceipts(db ethdb.Database, firstIdx uint64) (bool, uint64, error) { - // Check first block for legacy receipt format - numAncients, err := db.Ancients() - if err != nil { - return false, 0, err - } - if numAncients < 1 { - return false, 0, nil - } - if firstIdx >= numAncients { - return false, firstIdx, nil - } - var ( - legacy bool - blob []byte - emptyRLPList = []byte{192} - ) - // Find first block with non-empty receipt, only if - // the index is not already provided. - if firstIdx == 0 { - for i := uint64(0); i < numAncients; i++ { - blob, err = db.Ancient("receipts", i) - if err != nil { - return false, 0, err - } - if len(blob) == 0 { - continue - } - if !bytes.Equal(blob, emptyRLPList) { - firstIdx = i - break - } - } - } - first, err := db.Ancient("receipts", firstIdx) - if err != nil { - return false, 0, err - } - // We looped over all receipts and they were all empty - if bytes.Equal(first, emptyRLPList) { - return false, 0, nil - } - // Is first non-empty receipt legacy? - legacy, err = types.IsLegacyStoredReceipts(first) - return legacy, firstIdx, err -} diff --git a/cmd/geth/main.go b/cmd/geth/main.go index 4921376c669b..10b6c6df3bbb 100644 --- a/cmd/geth/main.go +++ b/cmd/geth/main.go @@ -149,7 +149,6 @@ var ( utils.GpoMaxGasPriceFlag, utils.GpoIgnoreGasPriceFlag, utils.MinerNotifyFullFlag, - utils.IgnoreLegacyReceiptsFlag, configFileFlag, }, utils.NetworkFlags, utils.DatabasePathFlags) @@ -342,7 +341,7 @@ func prepare(ctx *cli.Context) { go metrics.CollectProcessMetrics(3 * time.Second) } -// geth is the main entry point into the system if no special subcommand is ran. +// geth is the main entry point into the system if no special subcommand is run. // It creates a default node based on the command line arguments and runs it in // blocking mode, waiting for it to be shut down. func geth(ctx *cli.Context) error { diff --git a/cmd/utils/flags.go b/cmd/utils/flags.go index 9c70b411b67b..526c85481ab8 100644 --- a/cmd/utils/flags.go +++ b/cmd/utils/flags.go @@ -674,11 +674,6 @@ var ( } // MISC settings - IgnoreLegacyReceiptsFlag = &cli.BoolFlag{ - Name: "ignore-legacy-receipts", - Usage: "Geth will start up even if there are legacy receipts in freezer", - Category: flags.MiscCategory, - } SyncTargetFlag = &cli.PathFlag{ Name: "synctarget", Usage: `File for containing the hex-encoded block-rlp as sync target(dev feature)`, @@ -932,13 +927,13 @@ var ( // other profiling behavior or information. MetricsHTTPFlag = &cli.StringFlag{ Name: "metrics.addr", - Usage: "Enable stand-alone metrics HTTP server listening interface", - Value: metrics.DefaultConfig.HTTP, + Usage: `Enable stand-alone metrics HTTP server listening interface.`, Category: flags.MetricsCategory, } MetricsPortFlag = &cli.IntFlag{ - Name: "metrics.port", - Usage: "Metrics HTTP server listening port", + Name: "metrics.port", + Usage: `Metrics HTTP server listening port. +Please note that --` + MetricsHTTPFlag.Name + ` must be set to start the server.`, Value: metrics.DefaultConfig.Port, Category: flags.MetricsCategory, } @@ -2169,6 +2164,8 @@ func SetupMetrics(ctx *cli.Context) { address := fmt.Sprintf("%s:%d", ctx.String(MetricsHTTPFlag.Name), ctx.Int(MetricsPortFlag.Name)) log.Info("Enabling stand-alone metrics HTTP endpoint", "address", address) exp.Setup(address) + } else if ctx.IsSet(MetricsPortFlag.Name) { + log.Warn(fmt.Sprintf("--%s specified without --%s, metrics server will not start.", MetricsPortFlag.Name, MetricsHTTPFlag.Name)) } } } diff --git a/console/console.go b/console/console.go index 7b9ed27e15ec..fde673be8be9 100644 --- a/console/console.go +++ b/console/console.go @@ -34,6 +34,7 @@ import ( "github.com/ethereum/go-ethereum/internal/jsre" "github.com/ethereum/go-ethereum/internal/jsre/deps" "github.com/ethereum/go-ethereum/internal/web3ext" + "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/rpc" "github.com/mattn/go-colorable" "github.com/peterh/liner" @@ -198,13 +199,22 @@ func (c *Console) initWeb3(bridge *bridge) error { return err } +var defaultAPIs = map[string]string{"eth": "1.0", "net": "1.0", "debug": "1.0"} + // initExtensions loads and registers web3.js extensions. func (c *Console) initExtensions() error { - // Compute aliases from server-provided modules. + const methodNotFound = -32601 apis, err := c.client.SupportedModules() if err != nil { - return fmt.Errorf("api modules: %v", err) + if rpcErr, ok := err.(rpc.Error); ok && rpcErr.ErrorCode() == methodNotFound { + log.Warn("Server does not support method rpc_modules, using default API list.") + apis = defaultAPIs + } else { + return err + } } + + // Compute aliases from server-provided modules. aliases := map[string]struct{}{"eth": {}, "personal": {}} for api := range apis { if api == "web3" { diff --git a/core/blockchain.go b/core/blockchain.go index 910da91a3491..1d0a720c7aa3 100644 --- a/core/blockchain.go +++ b/core/blockchain.go @@ -169,10 +169,12 @@ type BlockChain struct { chainConfig *params.ChainConfig // Chain & network configuration cacheConfig *CacheConfig // Cache configuration for pruning - db ethdb.Database // Low level persistent database to store final content in - snaps *snapshot.Tree // Snapshot tree for fast trie leaf access - triegc *prque.Prque // Priority queue mapping block numbers to tries to gc - gcproc time.Duration // Accumulates canonical block processing for trie dumping + db ethdb.Database // Low level persistent database to store final content in + snaps *snapshot.Tree // Snapshot tree for fast trie leaf access + triegc *prque.Prque // Priority queue mapping block numbers to tries to gc + gcproc time.Duration // Accumulates canonical block processing for trie dumping + triedb *trie.Database // The database handler for maintaining trie nodes. + stateCache state.Database // State database to reuse between imports (contains state cache) // txLookupLimit is the maximum number of blocks from head whose tx indices // are reserved: @@ -200,7 +202,6 @@ type BlockChain struct { currentFinalizedBlock atomic.Value // Current finalized head currentSafeBlock atomic.Value // Current safe head - stateCache state.Database // State database to reuse between imports (contains state cache) bodyCache *lru.Cache[common.Hash, *types.Body] bodyRLPCache *lru.Cache[common.Hash, rlp.RawValue] receiptsCache *lru.Cache[common.Hash, []*types.Receipt] @@ -231,10 +232,16 @@ func NewBlockChain(db ethdb.Database, cacheConfig *CacheConfig, genesis *Genesis cacheConfig = defaultCacheConfig } + // Open trie database with provided config + triedb := trie.NewDatabaseWithConfig(db, &trie.Config{ + Cache: cacheConfig.TrieCleanLimit, + Journal: cacheConfig.TrieCleanJournal, + Preimages: cacheConfig.Preimages, + }) // Setup the genesis block, commit the provided genesis specification // to database if the genesis block is not present yet, or load the // stored one from database. - chainConfig, genesisHash, genesisErr := SetupGenesisBlockWithOverride(db, genesis, overrides) + chainConfig, genesisHash, genesisErr := SetupGenesisBlockWithOverride(db, triedb, genesis, overrides) if _, ok := genesisErr.(*params.ConfigCompatError); genesisErr != nil && !ok { return nil, genesisErr } @@ -247,15 +254,11 @@ func NewBlockChain(db ethdb.Database, cacheConfig *CacheConfig, genesis *Genesis log.Info("") bc := &BlockChain{ - chainConfig: chainConfig, - cacheConfig: cacheConfig, - db: db, - triegc: prque.New(nil), - stateCache: state.NewDatabaseWithConfig(db, &trie.Config{ - Cache: cacheConfig.TrieCleanLimit, - Journal: cacheConfig.TrieCleanJournal, - Preimages: cacheConfig.Preimages, - }), + chainConfig: chainConfig, + cacheConfig: cacheConfig, + db: db, + triedb: triedb, + triegc: prque.New(nil), quit: make(chan struct{}), chainmu: syncx.NewClosableMutex(), bodyCache: lru.NewCache[common.Hash, *types.Body](bodyCacheLimit), @@ -268,6 +271,7 @@ func NewBlockChain(db ethdb.Database, cacheConfig *CacheConfig, genesis *Genesis vmConfig: vmConfig, } bc.forker = NewForkChoice(bc, shouldPreserve) + bc.stateCache = state.NewDatabaseWithNodeDB(bc.db, bc.triedb) bc.validator = NewBlockValidator(chainConfig, bc, engine) bc.prefetcher = newStatePrefetcher(chainConfig, bc, engine) bc.processor = NewStateProcessor(chainConfig, bc, engine) @@ -300,7 +304,7 @@ func NewBlockChain(db ethdb.Database, cacheConfig *CacheConfig, genesis *Genesis } // Make sure the state associated with the block is available head := bc.CurrentBlock() - if _, err := state.New(head.Root(), bc.stateCache, bc.snaps); err != nil { + if !bc.HasState(head.Root()) { // Head state is missing, before the state recovery, find out the // disk layer point of snapshot(if it's enabled). Make sure the // rewound point is lower than disk layer. @@ -388,7 +392,7 @@ func NewBlockChain(db ethdb.Database, cacheConfig *CacheConfig, genesis *Genesis var recover bool head := bc.CurrentBlock() - if layer := rawdb.ReadSnapshotRecoveryNumber(bc.db); layer != nil && *layer > head.NumberU64() { + if layer := rawdb.ReadSnapshotRecoveryNumber(bc.db); layer != nil && *layer >= head.NumberU64() { log.Warn("Enabling snapshot recovery", "chainhead", head.NumberU64(), "diskbase", *layer) recover = true } @@ -398,7 +402,7 @@ func NewBlockChain(db ethdb.Database, cacheConfig *CacheConfig, genesis *Genesis NoBuild: bc.cacheConfig.SnapshotNoBuild, AsyncBuild: !bc.cacheConfig.SnapshotWait, } - bc.snaps, _ = snapshot.New(snapconfig, bc.db, bc.stateCache.TrieDB(), head.Root()) + bc.snaps, _ = snapshot.New(snapconfig, bc.db, bc.triedb, head.Root()) } // Start future block processor. @@ -411,11 +415,10 @@ func NewBlockChain(db ethdb.Database, cacheConfig *CacheConfig, genesis *Genesis log.Warn("Sanitizing invalid trie cache journal time", "provided", bc.cacheConfig.TrieCleanRejournal, "updated", time.Minute) bc.cacheConfig.TrieCleanRejournal = time.Minute } - triedb := bc.stateCache.TrieDB() bc.wg.Add(1) go func() { defer bc.wg.Done() - triedb.SaveCachePeriodically(bc.cacheConfig.TrieCleanJournal, bc.cacheConfig.TrieCleanRejournal, bc.quit) + bc.triedb.SaveCachePeriodically(bc.cacheConfig.TrieCleanJournal, bc.cacheConfig.TrieCleanRejournal, bc.quit) }() } // Rewind the chain in case of an incompatible config upgrade. @@ -594,7 +597,7 @@ func (bc *BlockChain) setHeadBeyondRoot(head uint64, root common.Hash, repair bo if root != (common.Hash{}) && !beyondRoot && newHeadBlock.Root() == root { beyondRoot, rootNumber = true, newHeadBlock.NumberU64() } - if _, err := state.New(newHeadBlock.Root(), bc.stateCache, bc.snaps); err != nil { + if !bc.HasState(newHeadBlock.Root()) { log.Trace("Block state missing, rewinding further", "number", newHeadBlock.NumberU64(), "hash", newHeadBlock.Hash()) if pivot == nil || newHeadBlock.NumberU64() > *pivot { parent := bc.GetBlock(newHeadBlock.ParentHash(), newHeadBlock.NumberU64()-1) @@ -617,7 +620,7 @@ func (bc *BlockChain) setHeadBeyondRoot(head uint64, root common.Hash, repair bo // if the historical chain pruning is enabled. In that case the logic // needs to be improved here. if !bc.HasState(bc.genesisBlock.Root()) { - if err := CommitGenesisState(bc.db, bc.genesisBlock.Hash()); err != nil { + if err := CommitGenesisState(bc.db, bc.triedb, bc.genesisBlock.Hash()); err != nil { log.Crit("Failed to commit genesis state", "err", err) } log.Debug("Recommitted genesis state to disk") @@ -900,7 +903,7 @@ func (bc *BlockChain) Stop() { // - HEAD-1: So we don't do large reorgs if our HEAD becomes an uncle // - HEAD-127: So we have a hard limit on the number of blocks reexecuted if !bc.cacheConfig.TrieDirtyDisabled { - triedb := bc.stateCache.TrieDB() + triedb := bc.triedb for _, offset := range []uint64{0, 1, TriesInMemory - 1} { if number := bc.CurrentBlock().NumberU64(); number > offset { @@ -932,8 +935,7 @@ func (bc *BlockChain) Stop() { // Ensure all live cached entries be saved into disk, so that we can skip // cache warmup when node restarts. if bc.cacheConfig.TrieCleanJournal != "" { - triedb := bc.stateCache.TrieDB() - triedb.SaveCache(bc.cacheConfig.TrieCleanJournal) + bc.triedb.SaveCache(bc.cacheConfig.TrieCleanJournal) } log.Info("Blockchain stopped") } @@ -1306,24 +1308,22 @@ func (bc *BlockChain) writeBlockWithState(block *types.Block, receipts []*types. if err != nil { return err } - triedb := bc.stateCache.TrieDB() - // If we're running an archive node, always flush if bc.cacheConfig.TrieDirtyDisabled { - return triedb.Commit(root, false, nil) + return bc.triedb.Commit(root, false, nil) } else { // Full but not archive node, do proper garbage collection - triedb.Reference(root, common.Hash{}) // metadata reference to keep trie alive + bc.triedb.Reference(root, common.Hash{}) // metadata reference to keep trie alive bc.triegc.Push(root, -int64(block.NumberU64())) if current := block.NumberU64(); current > TriesInMemory { // If we exceeded our memory allowance, flush matured singleton nodes to disk var ( - nodes, imgs = triedb.Size() + nodes, imgs = bc.triedb.Size() limit = common.StorageSize(bc.cacheConfig.TrieDirtyLimit) * 1024 * 1024 ) if nodes > limit || imgs > 4*1024*1024 { - triedb.Cap(limit - ethdb.IdealBatchSize) + bc.triedb.Cap(limit - ethdb.IdealBatchSize) } // Find the next state trie we need to commit chosen := current - TriesInMemory @@ -1342,7 +1342,7 @@ func (bc *BlockChain) writeBlockWithState(block *types.Block, receipts []*types. log.Info("State in memory for too long, committing", "time", bc.gcproc, "allowance", bc.cacheConfig.TrieTimeLimit, "optimum", float64(chosen-lastWrite)/TriesInMemory) } // Flush an entire trie and restart the counters - triedb.Commit(header.Root, true, nil) + bc.triedb.Commit(header.Root, true, nil) lastWrite = chosen bc.gcproc = 0 } @@ -1354,7 +1354,7 @@ func (bc *BlockChain) writeBlockWithState(block *types.Block, receipts []*types. bc.triegc.Push(root, number) break } - triedb.Dereference(root.(common.Hash)) + bc.triedb.Dereference(root.(common.Hash)) } } } @@ -1760,10 +1760,14 @@ func (bc *BlockChain) insertChain(chain types.Blocks, verifySeals, setHead bool) stats.processed++ stats.usedGas += usedGas - dirty, _ := bc.stateCache.TrieDB().Size() + dirty, _ := bc.triedb.Size() stats.report(chain, it.index, dirty, setHead) if !setHead { + // After merge we expect few side chains. Simply count + // all blocks the CL gives us for GC processing time + bc.gcproc += proctime + return it.index, nil // Direct block insertion of a single block } switch status { diff --git a/core/blockchain_reader.go b/core/blockchain_reader.go index da948029a13e..e8a5d952a240 100644 --- a/core/blockchain_reader.go +++ b/core/blockchain_reader.go @@ -29,6 +29,7 @@ import ( "github.com/ethereum/go-ethereum/event" "github.com/ethereum/go-ethereum/params" "github.com/ethereum/go-ethereum/rlp" + "github.com/ethereum/go-ethereum/trie" ) // CurrentHeader retrieves the current head header of the canonical chain. The @@ -375,6 +376,11 @@ func (bc *BlockChain) TxLookupLimit() uint64 { return bc.txLookupLimit } +// TrieDB retrieves the low level trie database used for data storage. +func (bc *BlockChain) TrieDB() *trie.Database { + return bc.triedb +} + // SubscribeRemovedLogsEvent registers a subscription of RemovedLogsEvent. func (bc *BlockChain) SubscribeRemovedLogsEvent(ch chan<- RemovedLogsEvent) event.Subscription { return bc.scope.Track(bc.rmLogsFeed.Subscribe(ch)) diff --git a/core/chain_makers.go b/core/chain_makers.go index 0af66c28046a..bb75ea86b0b2 100644 --- a/core/chain_makers.go +++ b/core/chain_makers.go @@ -29,6 +29,7 @@ import ( "github.com/ethereum/go-ethereum/core/vm" "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/params" + "github.com/ethereum/go-ethereum/trie" ) // BlockGen creates blocks for testing. @@ -334,7 +335,7 @@ func GenerateChain(config *params.ChainConfig, parent *types.Block, engine conse // then generate chain on top. func GenerateChainWithGenesis(genesis *Genesis, engine consensus.Engine, n int, gen func(int, *BlockGen)) (ethdb.Database, []*types.Block, []types.Receipts) { db := rawdb.NewMemoryDatabase() - _, err := genesis.Commit(db) + _, err := genesis.Commit(db, trie.NewDatabase(db)) if err != nil { panic(err) } diff --git a/core/forkchoice.go b/core/forkchoice.go index b0dbb200ecc7..b293c851bf27 100644 --- a/core/forkchoice.go +++ b/core/forkchoice.go @@ -74,10 +74,10 @@ func NewForkChoice(chainReader ChainReader, preserve func(header *types.Header) // In the td mode, the new head is chosen if the corresponding // total difficulty is higher. In the extern mode, the trusted // header is always selected as the head. -func (f *ForkChoice) ReorgNeeded(current *types.Header, header *types.Header) (bool, error) { +func (f *ForkChoice) ReorgNeeded(current *types.Header, extern *types.Header) (bool, error) { var ( localTD = f.chain.GetTd(current.Hash(), current.Number.Uint64()) - externTd = f.chain.GetTd(header.Hash(), header.Number.Uint64()) + externTd = f.chain.GetTd(extern.Hash(), extern.Number.Uint64()) ) if localTD == nil || externTd == nil { return false, errors.New("missing td") @@ -88,21 +88,26 @@ func (f *ForkChoice) ReorgNeeded(current *types.Header, header *types.Header) (b if ttd := f.chain.Config().TerminalTotalDifficulty; ttd != nil && ttd.Cmp(externTd) <= 0 { return true, nil } + // If the total difficulty is higher than our known, add it to the canonical chain + if diff := externTd.Cmp(localTD); diff > 0 { + return true, nil + } else if diff < 0 { + return false, nil + } + // Local and external difficulty is identical. // Second clause in the if statement reduces the vulnerability to selfish mining. // Please refer to http://www.cs.cornell.edu/~ie53/publications/btcProcFC.pdf - reorg := externTd.Cmp(localTD) > 0 - if !reorg && externTd.Cmp(localTD) == 0 { - number, headNumber := header.Number.Uint64(), current.Number.Uint64() - if number < headNumber { - reorg = true - } else if number == headNumber { - var currentPreserve, externPreserve bool - if f.preserve != nil { - currentPreserve, externPreserve = f.preserve(current), f.preserve(header) - } - reorg = !currentPreserve && (externPreserve || f.rand.Float64() < 0.5) + reorg := false + externNum, localNum := extern.Number.Uint64(), current.Number.Uint64() + if externNum < localNum { + reorg = true + } else if externNum == localNum { + var currentPreserve, externPreserve bool + if f.preserve != nil { + currentPreserve, externPreserve = f.preserve(current), f.preserve(extern) } + reorg = !currentPreserve && (externPreserve || f.rand.Float64() < 0.5) } return reorg, nil } diff --git a/core/genesis.go b/core/genesis.go index 175540b1abf6..7406dc4cef0b 100644 --- a/core/genesis.go +++ b/core/genesis.go @@ -139,8 +139,8 @@ func (ga *GenesisAlloc) deriveHash() (common.Hash, error) { // flush is very similar with deriveHash, but the main difference is // all the generated states will be persisted into the given database. // Also, the genesis state specification will be flushed as well. -func (ga *GenesisAlloc) flush(db ethdb.Database) error { - statedb, err := state.New(common.Hash{}, state.NewDatabaseWithConfig(db, &trie.Config{Preimages: true}), nil) +func (ga *GenesisAlloc) flush(db ethdb.Database, triedb *trie.Database) error { + statedb, err := state.New(common.Hash{}, state.NewDatabaseWithNodeDB(db, triedb), nil) if err != nil { return err } @@ -156,9 +156,11 @@ func (ga *GenesisAlloc) flush(db ethdb.Database) error { if err != nil { return err } - err = statedb.Database().TrieDB().Commit(root, true, nil) - if err != nil { - return err + // Commit newly generated states into disk if it's not empty. + if root != types.EmptyRootHash { + if err := triedb.Commit(root, true, nil); err != nil { + return err + } } // Marshal the genesis state specification and persist. blob, err := json.Marshal(ga) @@ -170,8 +172,8 @@ func (ga *GenesisAlloc) flush(db ethdb.Database) error { } // CommitGenesisState loads the stored genesis state with the given block -// hash and commits them into the given database handler. -func CommitGenesisState(db ethdb.Database, hash common.Hash) error { +// hash and commits it into the provided trie database. +func CommitGenesisState(db ethdb.Database, triedb *trie.Database, hash common.Hash) error { var alloc GenesisAlloc blob := rawdb.ReadGenesisStateSpec(db, hash) if len(blob) != 0 { @@ -203,7 +205,7 @@ func CommitGenesisState(db ethdb.Database, hash common.Hash) error { return errors.New("not found") } } - return alloc.flush(db) + return alloc.flush(db, triedb) } // GenesisAccount is an account in the state of the genesis block. @@ -286,15 +288,14 @@ type ChainOverrides struct { // error is a *params.ConfigCompatError and the new, unwritten config is returned. // // The returned chain configuration is never nil. -func SetupGenesisBlock(db ethdb.Database, genesis *Genesis) (*params.ChainConfig, common.Hash, error) { - return SetupGenesisBlockWithOverride(db, genesis, nil) +func SetupGenesisBlock(db ethdb.Database, triedb *trie.Database, genesis *Genesis) (*params.ChainConfig, common.Hash, error) { + return SetupGenesisBlockWithOverride(db, triedb, genesis, nil) } -func SetupGenesisBlockWithOverride(db ethdb.Database, genesis *Genesis, overrides *ChainOverrides) (*params.ChainConfig, common.Hash, error) { +func SetupGenesisBlockWithOverride(db ethdb.Database, triedb *trie.Database, genesis *Genesis, overrides *ChainOverrides) (*params.ChainConfig, common.Hash, error) { if genesis != nil && genesis.Config == nil { return params.AllEthashProtocolChanges, common.Hash{}, errGenesisNoConfig } - applyOverrides := func(config *params.ChainConfig) { if config != nil { if overrides != nil && overrides.OverrideTerminalTotalDifficulty != nil { @@ -315,7 +316,7 @@ func SetupGenesisBlockWithOverride(db ethdb.Database, genesis *Genesis, override } else { log.Info("Writing custom genesis block") } - block, err := genesis.Commit(db) + block, err := genesis.Commit(db, triedb) if err != nil { return genesis.Config, common.Hash{}, err } @@ -325,7 +326,7 @@ func SetupGenesisBlockWithOverride(db ethdb.Database, genesis *Genesis, override // We have the genesis block in database(perhaps in ancient database) // but the corresponding state is missing. header := rawdb.ReadHeader(db, stored, 0) - if _, err := state.New(header.Root, state.NewDatabaseWithConfig(db, nil), nil); err != nil { + if _, err := state.New(header.Root, state.NewDatabaseWithNodeDB(db, triedb), nil); err != nil { if genesis == nil { genesis = DefaultGenesisBlock() } @@ -334,7 +335,7 @@ func SetupGenesisBlockWithOverride(db ethdb.Database, genesis *Genesis, override if hash != stored { return genesis.Config, hash, &GenesisMismatchError{stored, hash} } - block, err := genesis.Commit(db) + block, err := genesis.Commit(db, triedb) if err != nil { return genesis.Config, hash, err } @@ -487,7 +488,7 @@ func (g *Genesis) ToBlock() *types.Block { // Commit writes the block and state of a genesis specification to the database. // The block is committed as the canonical head block. -func (g *Genesis) Commit(db ethdb.Database) (*types.Block, error) { +func (g *Genesis) Commit(db ethdb.Database, triedb *trie.Database) (*types.Block, error) { block := g.ToBlock() if block.Number().Sign() != 0 { return nil, errors.New("can't commit genesis block with number > 0") @@ -505,7 +506,7 @@ func (g *Genesis) Commit(db ethdb.Database) (*types.Block, error) { // All the checks has passed, flush the states derived from the genesis // specification as well as the specification itself into the provided // database. - if err := g.Alloc.flush(db); err != nil { + if err := g.Alloc.flush(db, triedb); err != nil { return nil, err } rawdb.WriteTd(db, block.Hash(), block.NumberU64(), block.Difficulty()) @@ -521,8 +522,10 @@ func (g *Genesis) Commit(db ethdb.Database) (*types.Block, error) { // MustCommit writes the genesis block and state to db, panicking on error. // The block is committed as the canonical head block. +// Note the state changes will be committed in hash-based scheme, use Commit +// if path-scheme is preferred. func (g *Genesis) MustCommit(db ethdb.Database) *types.Block { - block, err := g.Commit(db) + block, err := g.Commit(db, trie.NewDatabase(db)) if err != nil { panic(err) } diff --git a/core/genesis_test.go b/core/genesis_test.go index a7d04f53fe23..135ecb934c03 100644 --- a/core/genesis_test.go +++ b/core/genesis_test.go @@ -17,6 +17,7 @@ package core import ( + "encoding/json" "math/big" "reflect" "testing" @@ -28,12 +29,14 @@ import ( "github.com/ethereum/go-ethereum/core/vm" "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/params" + "github.com/ethereum/go-ethereum/trie" ) func TestInvalidCliqueConfig(t *testing.T) { block := DefaultGoerliGenesisBlock() block.ExtraData = []byte{} - if _, err := block.Commit(nil); err == nil { + db := rawdb.NewMemoryDatabase() + if _, err := block.Commit(db, trie.NewDatabase(db)); err == nil { t.Fatal("Expected error on invalid clique config") } } @@ -60,7 +63,7 @@ func TestSetupGenesis(t *testing.T) { { name: "genesis without ChainConfig", fn: func(db ethdb.Database) (*params.ChainConfig, common.Hash, error) { - return SetupGenesisBlock(db, new(Genesis)) + return SetupGenesisBlock(db, trie.NewDatabase(db), new(Genesis)) }, wantErr: errGenesisNoConfig, wantConfig: params.AllEthashProtocolChanges, @@ -68,7 +71,7 @@ func TestSetupGenesis(t *testing.T) { { name: "no block in DB, genesis == nil", fn: func(db ethdb.Database) (*params.ChainConfig, common.Hash, error) { - return SetupGenesisBlock(db, nil) + return SetupGenesisBlock(db, trie.NewDatabase(db), nil) }, wantHash: params.MainnetGenesisHash, wantConfig: params.MainnetChainConfig, @@ -77,7 +80,7 @@ func TestSetupGenesis(t *testing.T) { name: "mainnet block in DB, genesis == nil", fn: func(db ethdb.Database) (*params.ChainConfig, common.Hash, error) { DefaultGenesisBlock().MustCommit(db) - return SetupGenesisBlock(db, nil) + return SetupGenesisBlock(db, trie.NewDatabase(db), nil) }, wantHash: params.MainnetGenesisHash, wantConfig: params.MainnetChainConfig, @@ -86,7 +89,7 @@ func TestSetupGenesis(t *testing.T) { name: "custom block in DB, genesis == nil", fn: func(db ethdb.Database) (*params.ChainConfig, common.Hash, error) { customg.MustCommit(db) - return SetupGenesisBlock(db, nil) + return SetupGenesisBlock(db, trie.NewDatabase(db), nil) }, wantHash: customghash, wantConfig: customg.Config, @@ -95,7 +98,7 @@ func TestSetupGenesis(t *testing.T) { name: "custom block in DB, genesis == ropsten", fn: func(db ethdb.Database) (*params.ChainConfig, common.Hash, error) { customg.MustCommit(db) - return SetupGenesisBlock(db, DefaultRopstenGenesisBlock()) + return SetupGenesisBlock(db, trie.NewDatabase(db), DefaultRopstenGenesisBlock()) }, wantErr: &GenesisMismatchError{Stored: customghash, New: params.RopstenGenesisHash}, wantHash: params.RopstenGenesisHash, @@ -105,7 +108,7 @@ func TestSetupGenesis(t *testing.T) { name: "compatible config in DB", fn: func(db ethdb.Database) (*params.ChainConfig, common.Hash, error) { oldcustomg.MustCommit(db) - return SetupGenesisBlock(db, &customg) + return SetupGenesisBlock(db, trie.NewDatabase(db), &customg) }, wantHash: customghash, wantConfig: customg.Config, @@ -122,9 +125,9 @@ func TestSetupGenesis(t *testing.T) { blocks, _ := GenerateChain(oldcustomg.Config, genesis, ethash.NewFaker(), db, 4, nil) bc.InsertChain(blocks) - bc.CurrentBlock() + // This should return a compatibility error. - return SetupGenesisBlock(db, &customg) + return SetupGenesisBlock(db, trie.NewDatabase(db), &customg) }, wantHash: customghash, wantConfig: customg.Config, @@ -193,6 +196,7 @@ func TestGenesis_Commit(t *testing.T) { db := rawdb.NewMemoryDatabase() genesisBlock := genesis.MustCommit(db) + if genesis.Difficulty != nil { t.Fatalf("assumption wrong") } @@ -219,7 +223,8 @@ func TestReadWriteGenesisAlloc(t *testing.T) { } hash, _ = alloc.deriveHash() ) - alloc.flush(db) + blob, _ := json.Marshal(alloc) + rawdb.WriteGenesisStateSpec(db, hash, blob) var reload GenesisAlloc err := reload.UnmarshalJSON(rawdb.ReadGenesisStateSpec(db, hash)) diff --git a/core/headerchain_test.go b/core/headerchain_test.go index fe083b003145..08d19f695072 100644 --- a/core/headerchain_test.go +++ b/core/headerchain_test.go @@ -28,6 +28,7 @@ import ( "github.com/ethereum/go-ethereum/core/rawdb" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/params" + "github.com/ethereum/go-ethereum/trie" ) func verifyUnbrokenCanonchain(hc *HeaderChain) error { @@ -72,7 +73,7 @@ func TestHeaderInsertion(t *testing.T) { db = rawdb.NewMemoryDatabase() gspec = &Genesis{BaseFee: big.NewInt(params.InitialBaseFee), Config: params.AllEthashProtocolChanges} ) - gspec.Commit(db) + gspec.Commit(db, trie.NewDatabase(db)) hc, err := NewHeaderChain(db, gspec.Config, ethash.NewFaker(), func() bool { return false }) if err != nil { t.Fatal(err) diff --git a/core/rawdb/accessors_chain.go b/core/rawdb/accessors_chain.go index 0ca9292e4f4d..ae6a97877d3b 100644 --- a/core/rawdb/accessors_chain.go +++ b/core/rawdb/accessors_chain.go @@ -674,10 +674,11 @@ func DeleteReceipts(db ethdb.KeyValueWriter, hash common.Hash, number uint64) { // storedReceiptRLP is the storage encoding of a receipt. // Re-definition in core/types/receipt.go. +// TODO: Re-use the existing definition. type storedReceiptRLP struct { PostStateOrStatus []byte CumulativeGasUsed uint64 - Logs []*types.LogForStorage + Logs []*types.Log } // ReceiptLogs is a barebone version of ReceiptForStorage which only keeps @@ -693,10 +694,7 @@ func (r *receiptLogs) DecodeRLP(s *rlp.Stream) error { if err := s.Decode(&stored); err != nil { return err } - r.Logs = make([]*types.Log, len(stored.Logs)) - for i, log := range stored.Logs { - r.Logs[i] = (*types.Log)(log) - } + r.Logs = stored.Logs return nil } @@ -732,11 +730,6 @@ func ReadLogs(db ethdb.Reader, hash common.Hash, number uint64, config *params.C } receipts := []*receiptLogs{} if err := rlp.DecodeBytes(data, &receipts); err != nil { - // Receipts might be in the legacy format, try decoding that. - // TODO: to be removed after users migrated - if logs := readLegacyLogs(db, hash, number, config); logs != nil { - return logs - } log.Error("Invalid receipt array RLP", "hash", hash, "err", err) return nil } @@ -757,21 +750,6 @@ func ReadLogs(db ethdb.Reader, hash common.Hash, number uint64, config *params.C return logs } -// readLegacyLogs is a temporary workaround for when trying to read logs -// from a block which has its receipt stored in the legacy format. It'll -// be removed after users have migrated their freezer databases. -func readLegacyLogs(db ethdb.Reader, hash common.Hash, number uint64, config *params.ChainConfig) [][]*types.Log { - receipts := ReadReceipts(db, hash, number, config) - if receipts == nil { - return nil - } - logs := make([][]*types.Log, len(receipts)) - for i, receipt := range receipts { - logs[i] = receipt.Logs - } - return logs -} - // ReadBlock retrieves an entire block corresponding to the hash, assembling it // back from the stored header and body. If either the header or body could not // be retrieved nil is returned. diff --git a/core/rawdb/freezer.go b/core/rawdb/freezer.go index 53bd989a482d..7bae0a2ea0d1 100644 --- a/core/rawdb/freezer.go +++ b/core/rawdb/freezer.go @@ -318,30 +318,35 @@ func (f *Freezer) Sync() error { return nil } -// validate checks that every table has the same length. +// validate checks that every table has the same boundary. // Used instead of `repair` in readonly mode. func (f *Freezer) validate() error { if len(f.tables) == 0 { return nil } var ( - length uint64 - name string + head uint64 + tail uint64 + name string ) - // Hack to get length of any table + // Hack to get boundary of any table for kind, table := range f.tables { - length = atomic.LoadUint64(&table.items) + head = atomic.LoadUint64(&table.items) + tail = atomic.LoadUint64(&table.itemHidden) name = kind break } - // Now check every table against that length + // Now check every table against those boundaries. for kind, table := range f.tables { - items := atomic.LoadUint64(&table.items) - if length != items { - return fmt.Errorf("freezer tables %s and %s have differing lengths: %d != %d", kind, name, items, length) + if head != atomic.LoadUint64(&table.items) { + return fmt.Errorf("freezer tables %s and %s have differing head: %d != %d", kind, name, atomic.LoadUint64(&table.items), head) + } + if tail != atomic.LoadUint64(&table.itemHidden) { + return fmt.Errorf("freezer tables %s and %s have differing tail: %d != %d", kind, name, atomic.LoadUint64(&table.itemHidden), tail) } } - atomic.StoreUint64(&f.frozen, length) + atomic.StoreUint64(&f.frozen, head) + atomic.StoreUint64(&f.tail, tail) return nil } diff --git a/core/rawdb/freezer_table.go b/core/rawdb/freezer_table.go index 746f825e4038..7af937fd81ad 100644 --- a/core/rawdb/freezer_table.go +++ b/core/rawdb/freezer_table.go @@ -867,13 +867,20 @@ func (t *freezerTable) advanceHead() error { // Sync pushes any pending data from memory out to disk. This is an expensive // operation, so use it with care. func (t *freezerTable) Sync() error { - if err := t.index.Sync(); err != nil { - return err - } - if err := t.meta.Sync(); err != nil { - return err + t.lock.Lock() + defer t.lock.Unlock() + + var err error + trackError := func(e error) { + if e != nil && err == nil { + err = e + } } - return t.head.Sync() + + trackError(t.index.Sync()) + trackError(t.meta.Sync()) + trackError(t.head.Sync()) + return err } func (t *freezerTable) dumpIndexStdout(start, stop int64) { diff --git a/core/state/database.go b/core/state/database.go index 2de0650df892..fbd6d2883cc0 100644 --- a/core/state/database.go +++ b/core/state/database.go @@ -130,23 +130,33 @@ func NewDatabase(db ethdb.Database) Database { // large memory cache. func NewDatabaseWithConfig(db ethdb.Database, config *trie.Config) Database { return &cachingDB{ - db: trie.NewDatabaseWithConfig(db, config), disk: db, codeSizeCache: lru.NewCache[common.Hash, int](codeSizeCacheSize), codeCache: lru.NewSizeConstrainedCache[common.Hash, []byte](codeCacheSize), + triedb: trie.NewDatabaseWithConfig(db, config), + } +} + +// NewDatabaseWithNodeDB creates a state database with an already initialized node database. +func NewDatabaseWithNodeDB(db ethdb.Database, triedb *trie.Database) Database { + return &cachingDB{ + disk: db, + codeSizeCache: lru.NewCache[common.Hash, int](codeSizeCacheSize), + codeCache: lru.NewSizeConstrainedCache[common.Hash, []byte](codeCacheSize), + triedb: triedb, } } type cachingDB struct { - db *trie.Database disk ethdb.KeyValueStore codeSizeCache *lru.Cache[common.Hash, int] codeCache *lru.SizeConstrainedCache[common.Hash, []byte] + triedb *trie.Database } // OpenTrie opens the main account trie at a specific root hash. func (db *cachingDB) OpenTrie(root common.Hash) (Trie, error) { - tr, err := trie.NewStateTrie(trie.StateTrieID(root), db.db) + tr, err := trie.NewStateTrie(trie.StateTrieID(root), db.triedb) if err != nil { return nil, err } @@ -155,7 +165,7 @@ func (db *cachingDB) OpenTrie(root common.Hash) (Trie, error) { // OpenStorageTrie opens the storage trie of an account. func (db *cachingDB) OpenStorageTrie(stateRoot common.Hash, addrHash, root common.Hash) (Trie, error) { - tr, err := trie.NewStateTrie(trie.StorageTrieID(stateRoot, addrHash, root), db.db) + tr, err := trie.NewStateTrie(trie.StorageTrieID(stateRoot, addrHash, root), db.triedb) if err != nil { return nil, err } @@ -220,5 +230,5 @@ func (db *cachingDB) DiskDB() ethdb.KeyValueStore { // TrieDB retrieves any intermediate trie-node caching layer. func (db *cachingDB) TrieDB() *trie.Database { - return db.db + return db.triedb } diff --git a/core/state/iterator_test.go b/core/state/iterator_test.go index f9337512647a..7669ac97a215 100644 --- a/core/state/iterator_test.go +++ b/core/state/iterator_test.go @@ -26,10 +26,10 @@ import ( // Tests that the node iterator indeed walks over the entire database contents. func TestNodeIteratorCoverage(t *testing.T) { // Create some arbitrary test state to iterate - db, root, _ := makeTestState() - db.TrieDB().Commit(root, false, nil) + db, sdb, root, _ := makeTestState() + sdb.TrieDB().Commit(root, false, nil) - state, err := New(root, db, nil) + state, err := New(root, sdb, nil) if err != nil { t.Fatalf("failed to create state trie at %x: %v", root, err) } @@ -42,19 +42,19 @@ func TestNodeIteratorCoverage(t *testing.T) { } // Cross check the iterated hashes and the database/nodepool content for hash := range hashes { - if _, err = db.TrieDB().Node(hash); err != nil { - _, err = db.ContractCode(common.Hash{}, hash) + if _, err = sdb.TrieDB().Node(hash); err != nil { + _, err = sdb.ContractCode(common.Hash{}, hash) } if err != nil { t.Errorf("failed to retrieve reported node %x", hash) } } - for _, hash := range db.TrieDB().Nodes() { + for _, hash := range sdb.TrieDB().Nodes() { if _, ok := hashes[hash]; !ok { t.Errorf("state entry not reported %x", hash) } } - it := db.DiskDB().NewIterator(nil, nil) + it := db.NewIterator(nil, nil) for it.Next() { key := it.Key() if bytes.HasPrefix(key, []byte("secure-key-")) { diff --git a/core/state/snapshot/conversion.go b/core/state/snapshot/conversion.go index c15b17aa87e4..43fee456d8e9 100644 --- a/core/state/snapshot/conversion.go +++ b/core/state/snapshot/conversion.go @@ -43,7 +43,7 @@ type trieKV struct { type ( // trieGeneratorFn is the interface of trie generation which can // be implemented by different trie algorithm. - trieGeneratorFn func(db ethdb.KeyValueWriter, owner common.Hash, in chan (trieKV), out chan (common.Hash)) + trieGeneratorFn func(db ethdb.KeyValueWriter, scheme trie.NodeScheme, owner common.Hash, in chan (trieKV), out chan (common.Hash)) // leafCallbackFn is the callback invoked at the leaves of the trie, // returns the subtrie root with the specified subtrie identifier. @@ -52,12 +52,12 @@ type ( // GenerateAccountTrieRoot takes an account iterator and reproduces the root hash. func GenerateAccountTrieRoot(it AccountIterator) (common.Hash, error) { - return generateTrieRoot(nil, it, common.Hash{}, stackTrieGenerate, nil, newGenerateStats(), true) + return generateTrieRoot(nil, nil, it, common.Hash{}, stackTrieGenerate, nil, newGenerateStats(), true) } // GenerateStorageTrieRoot takes a storage iterator and reproduces the root hash. func GenerateStorageTrieRoot(account common.Hash, it StorageIterator) (common.Hash, error) { - return generateTrieRoot(nil, it, account, stackTrieGenerate, nil, newGenerateStats(), true) + return generateTrieRoot(nil, nil, it, account, stackTrieGenerate, nil, newGenerateStats(), true) } // GenerateTrie takes the whole snapshot tree as the input, traverses all the @@ -71,7 +71,8 @@ func GenerateTrie(snaptree *Tree, root common.Hash, src ethdb.Database, dst ethd } defer acctIt.Release() - got, err := generateTrieRoot(dst, acctIt, common.Hash{}, stackTrieGenerate, func(dst ethdb.KeyValueWriter, accountHash, codeHash common.Hash, stat *generateStats) (common.Hash, error) { + scheme := snaptree.triedb.Scheme() + got, err := generateTrieRoot(dst, scheme, acctIt, common.Hash{}, stackTrieGenerate, func(dst ethdb.KeyValueWriter, accountHash, codeHash common.Hash, stat *generateStats) (common.Hash, error) { // Migrate the code first, commit the contract code into the tmp db. if codeHash != emptyCode { code := rawdb.ReadCode(src, codeHash) @@ -87,7 +88,7 @@ func GenerateTrie(snaptree *Tree, root common.Hash, src ethdb.Database, dst ethd } defer storageIt.Release() - hash, err := generateTrieRoot(dst, storageIt, accountHash, stackTrieGenerate, nil, stat, false) + hash, err := generateTrieRoot(dst, scheme, storageIt, accountHash, stackTrieGenerate, nil, stat, false) if err != nil { return common.Hash{}, err } @@ -242,7 +243,7 @@ func runReport(stats *generateStats, stop chan bool) { // generateTrieRoot generates the trie hash based on the snapshot iterator. // It can be used for generating account trie, storage trie or even the // whole state which connects the accounts and the corresponding storages. -func generateTrieRoot(db ethdb.KeyValueWriter, it Iterator, account common.Hash, generatorFn trieGeneratorFn, leafCallback leafCallbackFn, stats *generateStats, report bool) (common.Hash, error) { +func generateTrieRoot(db ethdb.KeyValueWriter, scheme trie.NodeScheme, it Iterator, account common.Hash, generatorFn trieGeneratorFn, leafCallback leafCallbackFn, stats *generateStats, report bool) (common.Hash, error) { var ( in = make(chan trieKV) // chan to pass leaves out = make(chan common.Hash, 1) // chan to collect result @@ -253,7 +254,7 @@ func generateTrieRoot(db ethdb.KeyValueWriter, it Iterator, account common.Hash, wg.Add(1) go func() { defer wg.Done() - generatorFn(db, account, in, out) + generatorFn(db, scheme, account, in, out) }() // Spin up a go-routine for progress logging if report && stats != nil { @@ -360,8 +361,14 @@ func generateTrieRoot(db ethdb.KeyValueWriter, it Iterator, account common.Hash, return stop(nil) } -func stackTrieGenerate(db ethdb.KeyValueWriter, owner common.Hash, in chan trieKV, out chan common.Hash) { - t := trie.NewStackTrieWithOwner(db, owner) +func stackTrieGenerate(db ethdb.KeyValueWriter, scheme trie.NodeScheme, owner common.Hash, in chan trieKV, out chan common.Hash) { + var nodeWriter trie.NodeWriteFunc + if db != nil { + nodeWriter = func(owner common.Hash, path []byte, hash common.Hash, blob []byte) { + scheme.WriteTrieNode(db, owner, path, hash, blob) + } + } + t := trie.NewStackTrieWithOwner(nodeWriter, owner) for leaf := range in { t.TryUpdate(leaf.key[:], leaf.value) } diff --git a/core/state/snapshot/generate.go b/core/state/snapshot/generate.go index 8589aa784f67..3ed303cdfc75 100644 --- a/core/state/snapshot/generate.go +++ b/core/state/snapshot/generate.go @@ -29,7 +29,6 @@ import ( "github.com/ethereum/go-ethereum/core/rawdb" "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/ethdb" - "github.com/ethereum/go-ethereum/ethdb/memorydb" "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/trie" @@ -360,9 +359,9 @@ func (dl *diskLayer) generateRange(ctx *generatorContext, trieId *trie.ID, prefi } // We use the snap data to build up a cache which can be used by the // main account trie as a primary lookup when resolving hashes - var snapNodeCache ethdb.KeyValueStore + var snapNodeCache ethdb.Database if len(result.keys) > 0 { - snapNodeCache = memorydb.New() + snapNodeCache = rawdb.NewMemoryDatabase() snapTrieDb := trie.NewDatabase(snapNodeCache) snapTrie := trie.NewEmpty(snapTrieDb) for i, key := range result.keys { diff --git a/core/state/snapshot/generate_test.go b/core/state/snapshot/generate_test.go index 784d76859e44..3b44d4d481fd 100644 --- a/core/state/snapshot/generate_test.go +++ b/core/state/snapshot/generate_test.go @@ -117,12 +117,12 @@ func checkSnapRoot(t *testing.T, snap *diskLayer, trieRoot common.Hash) { accIt := snap.AccountIterator(common.Hash{}) defer accIt.Release() - snapRoot, err := generateTrieRoot(nil, accIt, common.Hash{}, stackTrieGenerate, + snapRoot, err := generateTrieRoot(nil, nil, accIt, common.Hash{}, stackTrieGenerate, func(db ethdb.KeyValueWriter, accountHash, codeHash common.Hash, stat *generateStats) (common.Hash, error) { storageIt, _ := snap.StorageIterator(accountHash, common.Hash{}) defer storageIt.Release() - hash, err := generateTrieRoot(nil, storageIt, accountHash, stackTrieGenerate, nil, stat, false) + hash, err := generateTrieRoot(nil, nil, storageIt, accountHash, stackTrieGenerate, nil, stat, false) if err != nil { return common.Hash{}, err } diff --git a/core/state/snapshot/snapshot.go b/core/state/snapshot/snapshot.go index f07f8d8e31ef..f8f52056dd7e 100644 --- a/core/state/snapshot/snapshot.go +++ b/core/state/snapshot/snapshot.go @@ -776,14 +776,14 @@ func (t *Tree) Verify(root common.Hash) error { } defer acctIt.Release() - got, err := generateTrieRoot(nil, acctIt, common.Hash{}, stackTrieGenerate, func(db ethdb.KeyValueWriter, accountHash, codeHash common.Hash, stat *generateStats) (common.Hash, error) { + got, err := generateTrieRoot(nil, nil, acctIt, common.Hash{}, stackTrieGenerate, func(db ethdb.KeyValueWriter, accountHash, codeHash common.Hash, stat *generateStats) (common.Hash, error) { storageIt, err := t.StorageIterator(root, accountHash, common.Hash{}) if err != nil { return common.Hash{}, err } defer storageIt.Release() - hash, err := generateTrieRoot(nil, storageIt, accountHash, stackTrieGenerate, nil, stat, false) + hash, err := generateTrieRoot(nil, nil, storageIt, accountHash, stackTrieGenerate, nil, stat, false) if err != nil { return common.Hash{}, err } diff --git a/core/state/sync.go b/core/state/sync.go index 00a4c67aa3cb..b40e75f487f6 100644 --- a/core/state/sync.go +++ b/core/state/sync.go @@ -27,7 +27,7 @@ import ( ) // NewStateSync create a new state trie download scheduler. -func NewStateSync(root common.Hash, database ethdb.KeyValueReader, onLeaf func(keys [][]byte, leaf []byte) error) *trie.Sync { +func NewStateSync(root common.Hash, database ethdb.KeyValueReader, onLeaf func(keys [][]byte, leaf []byte) error, scheme trie.NodeScheme) *trie.Sync { // Register the storage slot callback if the external callback is specified. var onSlot func(keys [][]byte, path []byte, leaf []byte, parent common.Hash, parentPath []byte) error if onLeaf != nil { @@ -52,6 +52,6 @@ func NewStateSync(root common.Hash, database ethdb.KeyValueReader, onLeaf func(k syncer.AddCodeEntry(common.BytesToHash(obj.CodeHash), path, parent, parentPath) return nil } - syncer = trie.NewSync(root, database, onAccount) + syncer = trie.NewSync(root, database, onAccount, scheme) return syncer } diff --git a/core/state/sync_test.go b/core/state/sync_test.go index dbcbb7c96344..62eba60fa01c 100644 --- a/core/state/sync_test.go +++ b/core/state/sync_test.go @@ -39,10 +39,11 @@ type testAccount struct { } // makeTestState create a sample test state to test node-wise reconstruction. -func makeTestState() (Database, common.Hash, []*testAccount) { +func makeTestState() (ethdb.Database, Database, common.Hash, []*testAccount) { // Create an empty state - db := NewDatabase(rawdb.NewMemoryDatabase()) - state, _ := New(common.Hash{}, db, nil) + db := rawdb.NewMemoryDatabase() + sdb := NewDatabase(db) + state, _ := New(common.Hash{}, sdb, nil) // Fill it with some arbitrary data var accounts []*testAccount @@ -63,7 +64,7 @@ func makeTestState() (Database, common.Hash, []*testAccount) { if i%5 == 0 { for j := byte(0); j < 5; j++ { hash := crypto.Keccak256Hash([]byte{i, i, i, i, i, j, j}) - obj.SetState(db, hash, hash) + obj.SetState(sdb, hash, hash) } } state.updateStateObject(obj) @@ -72,7 +73,7 @@ func makeTestState() (Database, common.Hash, []*testAccount) { root, _ := state.Commit(false) // Return the generated state - return db, root, accounts + return db, sdb, root, accounts } // checkStateAccounts cross references a reconstructed state with an expected @@ -100,7 +101,7 @@ func checkStateAccounts(t *testing.T, db ethdb.Database, root common.Hash, accou } // checkTrieConsistency checks that all nodes in a (sub-)trie are indeed present. -func checkTrieConsistency(db ethdb.KeyValueStore, root common.Hash) error { +func checkTrieConsistency(db ethdb.Database, root common.Hash) error { if v, _ := db.Get(root[:]); v == nil { return nil // Consider a non existent state consistent. } @@ -132,8 +133,9 @@ func checkStateConsistency(db ethdb.Database, root common.Hash) error { // Tests that an empty state is not scheduled for syncing. func TestEmptyStateSync(t *testing.T) { + db := trie.NewDatabase(rawdb.NewMemoryDatabase()) empty := common.HexToHash("56e81f171bcc55a6ff8345e692c0f86e5b48e01b996cadc001622fb5e363b421") - sync := NewStateSync(empty, rawdb.NewMemoryDatabase(), nil) + sync := NewStateSync(empty, rawdb.NewMemoryDatabase(), nil, db.Scheme()) if paths, nodes, codes := sync.Missing(1); len(paths) != 0 || len(nodes) != 0 || len(codes) != 0 { t.Errorf("content requested for empty state: %v, %v, %v", nodes, paths, codes) } @@ -170,7 +172,7 @@ type stateElement struct { func testIterativeStateSync(t *testing.T, count int, commit bool, bypath bool) { // Create a random state to copy - srcDb, srcRoot, srcAccounts := makeTestState() + _, srcDb, srcRoot, srcAccounts := makeTestState() if commit { srcDb.TrieDB().Commit(srcRoot, false, nil) } @@ -178,7 +180,7 @@ func testIterativeStateSync(t *testing.T, count int, commit bool, bypath bool) { // Create a destination state and sync with the scheduler dstDb := rawdb.NewMemoryDatabase() - sched := NewStateSync(srcRoot, dstDb, nil) + sched := NewStateSync(srcRoot, dstDb, nil, srcDb.TrieDB().Scheme()) var ( nodeElements []stateElement @@ -281,11 +283,11 @@ func testIterativeStateSync(t *testing.T, count int, commit bool, bypath bool) { // partial results are returned, and the others sent only later. func TestIterativeDelayedStateSync(t *testing.T) { // Create a random state to copy - srcDb, srcRoot, srcAccounts := makeTestState() + _, srcDb, srcRoot, srcAccounts := makeTestState() // Create a destination state and sync with the scheduler dstDb := rawdb.NewMemoryDatabase() - sched := NewStateSync(srcRoot, dstDb, nil) + sched := NewStateSync(srcRoot, dstDb, nil, srcDb.TrieDB().Scheme()) var ( nodeElements []stateElement @@ -374,11 +376,11 @@ func TestIterativeRandomStateSyncBatched(t *testing.T) { testIterativeRandomS func testIterativeRandomStateSync(t *testing.T, count int) { // Create a random state to copy - srcDb, srcRoot, srcAccounts := makeTestState() + _, srcDb, srcRoot, srcAccounts := makeTestState() // Create a destination state and sync with the scheduler dstDb := rawdb.NewMemoryDatabase() - sched := NewStateSync(srcRoot, dstDb, nil) + sched := NewStateSync(srcRoot, dstDb, nil, srcDb.TrieDB().Scheme()) nodeQueue := make(map[string]stateElement) codeQueue := make(map[common.Hash]struct{}) @@ -454,11 +456,11 @@ func testIterativeRandomStateSync(t *testing.T, count int) { // partial results are returned (Even those randomly), others sent only later. func TestIterativeRandomDelayedStateSync(t *testing.T) { // Create a random state to copy - srcDb, srcRoot, srcAccounts := makeTestState() + _, srcDb, srcRoot, srcAccounts := makeTestState() // Create a destination state and sync with the scheduler dstDb := rawdb.NewMemoryDatabase() - sched := NewStateSync(srcRoot, dstDb, nil) + sched := NewStateSync(srcRoot, dstDb, nil, srcDb.TrieDB().Scheme()) nodeQueue := make(map[string]stateElement) codeQueue := make(map[common.Hash]struct{}) @@ -544,7 +546,7 @@ func TestIterativeRandomDelayedStateSync(t *testing.T) { // the database. func TestIncompleteStateSync(t *testing.T) { // Create a random state to copy - srcDb, srcRoot, srcAccounts := makeTestState() + db, srcDb, srcRoot, srcAccounts := makeTestState() // isCodeLookup to save some hashing var isCode = make(map[common.Hash]struct{}) @@ -554,15 +556,16 @@ func TestIncompleteStateSync(t *testing.T) { } } isCode[common.BytesToHash(emptyCodeHash)] = struct{}{} - checkTrieConsistency(srcDb.DiskDB(), srcRoot) + checkTrieConsistency(db, srcRoot) // Create a destination state and sync with the scheduler dstDb := rawdb.NewMemoryDatabase() - sched := NewStateSync(srcRoot, dstDb, nil) + sched := NewStateSync(srcRoot, dstDb, nil, srcDb.TrieDB().Scheme()) var ( - addedCodes []common.Hash - addedNodes []common.Hash + addedCodes []common.Hash + addedPaths []string + addedHashes []common.Hash ) nodeQueue := make(map[string]stateElement) codeQueue := make(map[common.Hash]struct{}) @@ -599,15 +602,16 @@ func TestIncompleteStateSync(t *testing.T) { var nodehashes []common.Hash if len(nodeQueue) > 0 { results := make([]trie.NodeSyncResult, 0, len(nodeQueue)) - for key, element := range nodeQueue { + for path, element := range nodeQueue { data, err := srcDb.TrieDB().Node(element.hash) if err != nil { t.Fatalf("failed to retrieve node data for %x", element.hash) } - results = append(results, trie.NodeSyncResult{Path: key, Data: data}) + results = append(results, trie.NodeSyncResult{Path: path, Data: data}) if element.hash != srcRoot { - addedNodes = append(addedNodes, element.hash) + addedPaths = append(addedPaths, element.path) + addedHashes = append(addedHashes, element.hash) } nodehashes = append(nodehashes, element.hash) } @@ -655,12 +659,18 @@ func TestIncompleteStateSync(t *testing.T) { } rawdb.WriteCode(dstDb, node, val) } - for _, node := range addedNodes { - val := rawdb.ReadTrieNode(dstDb, node) - rawdb.DeleteTrieNode(dstDb, node) + scheme := srcDb.TrieDB().Scheme() + for i, path := range addedPaths { + owner, inner := trie.ResolvePath([]byte(path)) + hash := addedHashes[i] + val := scheme.ReadTrieNode(dstDb, owner, inner, hash) + if val == nil { + t.Error("missing trie node") + } + scheme.DeleteTrieNode(dstDb, owner, inner, hash) if err := checkStateConsistency(dstDb, srcRoot); err == nil { - t.Errorf("trie inconsistency not caught, missing: %v", node.Hex()) + t.Errorf("trie inconsistency not caught, missing: %v", path) } - rawdb.WriteTrieNode(dstDb, node, val) + scheme.WriteTrieNode(dstDb, owner, inner, hash, val) } } diff --git a/core/txpool/list.go b/core/txpool/list.go index eb0c753f21e9..062cbbf63e6a 100644 --- a/core/txpool/list.go +++ b/core/txpool/list.go @@ -45,6 +45,7 @@ func (h *nonceHeap) Pop() interface{} { old := *h n := len(old) x := old[n-1] + old[n-1] = 0 *h = old[0 : n-1] return x } diff --git a/core/types/legacy.go b/core/types/legacy.go deleted file mode 100644 index 14ed30d883d4..000000000000 --- a/core/types/legacy.go +++ /dev/null @@ -1,53 +0,0 @@ -// Copyright 2022 The go-ethereum Authors -// This file is part of the go-ethereum library. -// -// The go-ethereum library is free software: you can redistribute it and/or modify -// it under the terms of the GNU Lesser General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// The go-ethereum library is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Lesser General Public License for more details. -// -// You should have received a copy of the GNU Lesser General Public License -// along with the go-ethereum library. If not, see . - -package types - -import ( - "errors" - - "github.com/ethereum/go-ethereum/rlp" -) - -// IsLegacyStoredReceipts tries to parse the RLP-encoded blob -// first as an array of v3 stored receipt, then v4 stored receipt and -// returns true if successful. -func IsLegacyStoredReceipts(raw []byte) (bool, error) { - var v3 []v3StoredReceiptRLP - if err := rlp.DecodeBytes(raw, &v3); err == nil { - return true, nil - } - var v4 []v4StoredReceiptRLP - if err := rlp.DecodeBytes(raw, &v4); err == nil { - return true, nil - } - var v5 []storedReceiptRLP - // Check to see valid fresh stored receipt - if err := rlp.DecodeBytes(raw, &v5); err == nil { - return false, nil - } - return false, errors.New("value is not a valid receipt encoding") -} - -// ConvertLegacyStoredReceipts takes the RLP encoding of an array of legacy -// stored receipts and returns a fresh RLP-encoded stored receipt. -func ConvertLegacyStoredReceipts(raw []byte) ([]byte, error) { - var receipts []ReceiptForStorage - if err := rlp.DecodeBytes(raw, &receipts); err != nil { - return nil, err - } - return rlp.EncodeToBytes(&receipts) -} diff --git a/core/types/log.go b/core/types/log.go index eb30957b1278..e48919136889 100644 --- a/core/types/log.go +++ b/core/types/log.go @@ -64,24 +64,13 @@ type logMarshaling struct { //go:generate go run ../../rlp/rlpgen -type rlpLog -out gen_log_rlp.go +// rlpLog is used to RLP-encode both the consensus and storage formats. type rlpLog struct { Address common.Address Topics []common.Hash Data []byte } -// legacyRlpStorageLog is the previous storage encoding of a log including some redundant fields. -type legacyRlpStorageLog struct { - Address common.Address - Topics []common.Hash - Data []byte - BlockNumber uint64 - TxHash common.Hash - TxIndex uint - BlockHash common.Hash - Index uint -} - // EncodeRLP implements rlp.Encoder. func (l *Log) EncodeRLP(w io.Writer) error { rl := rlpLog{Address: l.Address, Topics: l.Topics, Data: l.Data} @@ -97,44 +86,3 @@ func (l *Log) DecodeRLP(s *rlp.Stream) error { } return err } - -// LogForStorage is a wrapper around a Log that handles -// backward compatibility with prior storage formats. -type LogForStorage Log - -// EncodeRLP implements rlp.Encoder. -func (l *LogForStorage) EncodeRLP(w io.Writer) error { - rl := rlpLog{Address: l.Address, Topics: l.Topics, Data: l.Data} - return rlp.Encode(w, &rl) -} - -// DecodeRLP implements rlp.Decoder. -// -// Note some redundant fields(e.g. block number, tx hash etc) will be assembled later. -func (l *LogForStorage) DecodeRLP(s *rlp.Stream) error { - blob, err := s.Raw() - if err != nil { - return err - } - var dec rlpLog - err = rlp.DecodeBytes(blob, &dec) - if err == nil { - *l = LogForStorage{ - Address: dec.Address, - Topics: dec.Topics, - Data: dec.Data, - } - } else { - // Try to decode log with previous definition. - var dec legacyRlpStorageLog - err = rlp.DecodeBytes(blob, &dec) - if err == nil { - *l = LogForStorage{ - Address: dec.Address, - Topics: dec.Topics, - Data: dec.Data, - } - } - } - return err -} diff --git a/core/types/receipt.go b/core/types/receipt.go index f5f76958d18d..41376ed6fa3f 100644 --- a/core/types/receipt.go +++ b/core/types/receipt.go @@ -93,28 +93,7 @@ type receiptRLP struct { type storedReceiptRLP struct { PostStateOrStatus []byte CumulativeGasUsed uint64 - Logs []*LogForStorage -} - -// v4StoredReceiptRLP is the storage encoding of a receipt used in database version 4. -type v4StoredReceiptRLP struct { - PostStateOrStatus []byte - CumulativeGasUsed uint64 - TxHash common.Hash - ContractAddress common.Address - Logs []*LogForStorage - GasUsed uint64 -} - -// v3StoredReceiptRLP is the original storage encoding of a receipt including some unnecessary fields. -type v3StoredReceiptRLP struct { - PostStateOrStatus []byte - CumulativeGasUsed uint64 - Bloom Bloom - TxHash common.Hash - ContractAddress common.Address - Logs []*LogForStorage - GasUsed uint64 + Logs []*Log } // NewReceipt creates a barebone transaction receipt, copying the init fields. @@ -294,82 +273,20 @@ func (r *ReceiptForStorage) EncodeRLP(_w io.Writer) error { // DecodeRLP implements rlp.Decoder, and loads both consensus and implementation // fields of a receipt from an RLP stream. func (r *ReceiptForStorage) DecodeRLP(s *rlp.Stream) error { - // Retrieve the entire receipt blob as we need to try multiple decoders - blob, err := s.Raw() - if err != nil { - return err - } - // Try decoding from the newest format for future proofness, then the older one - // for old nodes that just upgraded. V4 was an intermediate unreleased format so - // we do need to decode it, but it's not common (try last). - if err := decodeStoredReceiptRLP(r, blob); err == nil { - return nil - } - if err := decodeV3StoredReceiptRLP(r, blob); err == nil { - return nil - } - return decodeV4StoredReceiptRLP(r, blob) -} - -func decodeStoredReceiptRLP(r *ReceiptForStorage, blob []byte) error { var stored storedReceiptRLP - if err := rlp.DecodeBytes(blob, &stored); err != nil { + if err := s.Decode(&stored); err != nil { return err } if err := (*Receipt)(r).setStatus(stored.PostStateOrStatus); err != nil { return err } r.CumulativeGasUsed = stored.CumulativeGasUsed - r.Logs = make([]*Log, len(stored.Logs)) - for i, log := range stored.Logs { - r.Logs[i] = (*Log)(log) - } + r.Logs = stored.Logs r.Bloom = CreateBloom(Receipts{(*Receipt)(r)}) return nil } -func decodeV4StoredReceiptRLP(r *ReceiptForStorage, blob []byte) error { - var stored v4StoredReceiptRLP - if err := rlp.DecodeBytes(blob, &stored); err != nil { - return err - } - if err := (*Receipt)(r).setStatus(stored.PostStateOrStatus); err != nil { - return err - } - r.CumulativeGasUsed = stored.CumulativeGasUsed - r.TxHash = stored.TxHash - r.ContractAddress = stored.ContractAddress - r.GasUsed = stored.GasUsed - r.Logs = make([]*Log, len(stored.Logs)) - for i, log := range stored.Logs { - r.Logs[i] = (*Log)(log) - } - r.Bloom = CreateBloom(Receipts{(*Receipt)(r)}) - - return nil -} - -func decodeV3StoredReceiptRLP(r *ReceiptForStorage, blob []byte) error { - var stored v3StoredReceiptRLP - if err := rlp.DecodeBytes(blob, &stored); err != nil { - return err - } - if err := (*Receipt)(r).setStatus(stored.PostStateOrStatus); err != nil { - return err - } - r.CumulativeGasUsed = stored.CumulativeGasUsed - r.Bloom = stored.Bloom - r.TxHash = stored.TxHash - r.ContractAddress = stored.ContractAddress - r.GasUsed = stored.GasUsed - r.Logs = make([]*Log, len(stored.Logs)) - for i, log := range stored.Logs { - r.Logs[i] = (*Log)(log) - } - return nil -} - // Receipts implements DerivableList for receipts. type Receipts []*Receipt diff --git a/core/types/receipt_test.go b/core/types/receipt_test.go index f3b5fba4725a..2cb553088035 100644 --- a/core/types/receipt_test.go +++ b/core/types/receipt_test.go @@ -91,136 +91,6 @@ func TestDecodeEmptyTypedReceipt(t *testing.T) { } } -func TestLegacyReceiptDecoding(t *testing.T) { - tests := []struct { - name string - encode func(*Receipt) ([]byte, error) - }{ - { - "ReceiptForStorage", - encodeAsReceiptForStorage, - }, - { - "StoredReceiptRLP", - encodeAsStoredReceiptRLP, - }, - { - "V4StoredReceiptRLP", - encodeAsV4StoredReceiptRLP, - }, - { - "V3StoredReceiptRLP", - encodeAsV3StoredReceiptRLP, - }, - } - - tx := NewTransaction(1, common.HexToAddress("0x1"), big.NewInt(1), 1, big.NewInt(1), nil) - receipt := &Receipt{ - Status: ReceiptStatusFailed, - CumulativeGasUsed: 1, - Logs: []*Log{ - { - Address: common.BytesToAddress([]byte{0x11}), - Topics: []common.Hash{common.HexToHash("dead"), common.HexToHash("beef")}, - Data: []byte{0x01, 0x00, 0xff}, - }, - { - Address: common.BytesToAddress([]byte{0x01, 0x11}), - Topics: []common.Hash{common.HexToHash("dead"), common.HexToHash("beef")}, - Data: []byte{0x01, 0x00, 0xff}, - }, - }, - TxHash: tx.Hash(), - ContractAddress: common.BytesToAddress([]byte{0x01, 0x11, 0x11}), - GasUsed: 111111, - } - receipt.Bloom = CreateBloom(Receipts{receipt}) - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - enc, err := tc.encode(receipt) - if err != nil { - t.Fatalf("Error encoding receipt: %v", err) - } - var dec ReceiptForStorage - if err := rlp.DecodeBytes(enc, &dec); err != nil { - t.Fatalf("Error decoding RLP receipt: %v", err) - } - // Check whether all consensus fields are correct. - if dec.Status != receipt.Status { - t.Fatalf("Receipt status mismatch, want %v, have %v", receipt.Status, dec.Status) - } - if dec.CumulativeGasUsed != receipt.CumulativeGasUsed { - t.Fatalf("Receipt CumulativeGasUsed mismatch, want %v, have %v", receipt.CumulativeGasUsed, dec.CumulativeGasUsed) - } - if dec.Bloom != receipt.Bloom { - t.Fatalf("Bloom data mismatch, want %v, have %v", receipt.Bloom, dec.Bloom) - } - if len(dec.Logs) != len(receipt.Logs) { - t.Fatalf("Receipt log number mismatch, want %v, have %v", len(receipt.Logs), len(dec.Logs)) - } - for i := 0; i < len(dec.Logs); i++ { - if dec.Logs[i].Address != receipt.Logs[i].Address { - t.Fatalf("Receipt log %d address mismatch, want %v, have %v", i, receipt.Logs[i].Address, dec.Logs[i].Address) - } - if !reflect.DeepEqual(dec.Logs[i].Topics, receipt.Logs[i].Topics) { - t.Fatalf("Receipt log %d topics mismatch, want %v, have %v", i, receipt.Logs[i].Topics, dec.Logs[i].Topics) - } - if !bytes.Equal(dec.Logs[i].Data, receipt.Logs[i].Data) { - t.Fatalf("Receipt log %d data mismatch, want %v, have %v", i, receipt.Logs[i].Data, dec.Logs[i].Data) - } - } - }) - } -} - -func encodeAsReceiptForStorage(want *Receipt) ([]byte, error) { - return rlp.EncodeToBytes((*ReceiptForStorage)(want)) -} - -func encodeAsStoredReceiptRLP(want *Receipt) ([]byte, error) { - stored := &storedReceiptRLP{ - PostStateOrStatus: want.statusEncoding(), - CumulativeGasUsed: want.CumulativeGasUsed, - Logs: make([]*LogForStorage, len(want.Logs)), - } - for i, log := range want.Logs { - stored.Logs[i] = (*LogForStorage)(log) - } - return rlp.EncodeToBytes(stored) -} - -func encodeAsV4StoredReceiptRLP(want *Receipt) ([]byte, error) { - stored := &v4StoredReceiptRLP{ - PostStateOrStatus: want.statusEncoding(), - CumulativeGasUsed: want.CumulativeGasUsed, - TxHash: want.TxHash, - ContractAddress: want.ContractAddress, - Logs: make([]*LogForStorage, len(want.Logs)), - GasUsed: want.GasUsed, - } - for i, log := range want.Logs { - stored.Logs[i] = (*LogForStorage)(log) - } - return rlp.EncodeToBytes(stored) -} - -func encodeAsV3StoredReceiptRLP(want *Receipt) ([]byte, error) { - stored := &v3StoredReceiptRLP{ - PostStateOrStatus: want.statusEncoding(), - CumulativeGasUsed: want.CumulativeGasUsed, - Bloom: want.Bloom, - TxHash: want.TxHash, - ContractAddress: want.ContractAddress, - Logs: make([]*LogForStorage, len(want.Logs)), - GasUsed: want.GasUsed, - } - for i, log := range want.Logs { - stored.Logs[i] = (*LogForStorage)(log) - } - return rlp.EncodeToBytes(stored) -} - // Tests that receipt data can be correctly derived from the contextual infos func TestDeriveFields(t *testing.T) { // Create a few transactions to have receipts for diff --git a/core/types/transaction.go b/core/types/transaction.go index ca565e9410ca..9ef9990524e4 100644 --- a/core/types/transaction.go +++ b/core/types/transaction.go @@ -687,6 +687,7 @@ func (s *TxByPriceAndTime) Pop() interface{} { old := *s n := len(old) x := old[n-1] + old[n-1] = nil *s = old[0 : n-1] return x } diff --git a/core/vm/evm.go b/core/vm/evm.go index 67eafe62d97c..cd3b1c483006 100644 --- a/core/vm/evm.go +++ b/core/vm/evm.go @@ -19,7 +19,6 @@ package vm import ( "math/big" "sync/atomic" - "time" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/crypto" @@ -187,7 +186,7 @@ func (evm *EVM) Call(caller ContractRef, addr common.Address, input []byte, gas if evm.Config.Debug { if evm.depth == 0 { evm.Config.Tracer.CaptureStart(evm, caller.Address(), addr, false, input, gas, value) - evm.Config.Tracer.CaptureEnd(ret, 0, 0, nil) + evm.Config.Tracer.CaptureEnd(ret, 0, nil) } else { evm.Config.Tracer.CaptureEnter(CALL, caller.Address(), addr, input, gas, value) evm.Config.Tracer.CaptureExit(ret, 0, nil) @@ -203,9 +202,9 @@ func (evm *EVM) Call(caller ContractRef, addr common.Address, input []byte, gas if evm.Config.Debug { if evm.depth == 0 { evm.Config.Tracer.CaptureStart(evm, caller.Address(), addr, false, input, gas, value) - defer func(startGas uint64, startTime time.Time) { // Lazy evaluation of the parameters - evm.Config.Tracer.CaptureEnd(ret, startGas-gas, time.Since(startTime), err) - }(gas, time.Now()) + defer func(startGas uint64) { // Lazy evaluation of the parameters + evm.Config.Tracer.CaptureEnd(ret, startGas-gas, err) + }(gas) } else { // Handle tracer events for entering and exiting a call frame evm.Config.Tracer.CaptureEnter(CALL, caller.Address(), addr, input, gas, value) @@ -452,8 +451,6 @@ func (evm *EVM) create(caller ContractRef, codeAndHash *codeAndHash, gas uint64, } } - start := time.Now() - ret, err := evm.interpreter.Run(contract, nil, false) // Check whether the max code size has been exceeded, assign err if the case. @@ -491,7 +488,7 @@ func (evm *EVM) create(caller ContractRef, codeAndHash *codeAndHash, gas uint64, if evm.Config.Debug { if evm.depth == 0 { - evm.Config.Tracer.CaptureEnd(ret, gas-contract.Gas, time.Since(start), err) + evm.Config.Tracer.CaptureEnd(ret, gas-contract.Gas, err) } else { evm.Config.Tracer.CaptureExit(ret, gas-contract.Gas, err) } diff --git a/core/vm/logger.go b/core/vm/logger.go index 50fccafcf53e..2667908a84d1 100644 --- a/core/vm/logger.go +++ b/core/vm/logger.go @@ -18,7 +18,6 @@ package vm import ( "math/big" - "time" "github.com/ethereum/go-ethereum/common" ) @@ -34,7 +33,7 @@ type EVMLogger interface { CaptureTxEnd(restGas uint64) // Top call frame CaptureStart(env *EVM, from common.Address, to common.Address, create bool, input []byte, gas uint64, value *big.Int) - CaptureEnd(output []byte, gasUsed uint64, t time.Duration, err error) + CaptureEnd(output []byte, gasUsed uint64, err error) // Rest of call frames CaptureEnter(typ OpCode, from common.Address, to common.Address, input []byte, gas uint64, value *big.Int) CaptureExit(output []byte, gasUsed uint64, err error) diff --git a/eth/downloader/downloader.go b/eth/downloader/downloader.go index af28d9e82097..41c5d66edb38 100644 --- a/eth/downloader/downloader.go +++ b/eth/downloader/downloader.go @@ -35,6 +35,7 @@ import ( "github.com/ethereum/go-ethereum/event" "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/params" + "github.com/ethereum/go-ethereum/trie" ) var ( @@ -206,6 +207,10 @@ type BlockChain interface { // Snapshots returns the blockchain snapshot tree to paused it during sync. Snapshots() *snapshot.Tree + + // TrieDB retrieves the low level trie database used for interacting + // with trie nodes. + TrieDB() *trie.Database } // New creates a new downloader to fetch hashes and blocks from remote peers. @@ -224,7 +229,7 @@ func New(checkpoint uint64, stateDb ethdb.Database, mux *event.TypeMux, chain Bl dropPeer: dropPeer, headerProcCh: make(chan *headerTask, 1), quitCh: make(chan struct{}), - SnapSyncer: snap.NewSyncer(stateDb), + SnapSyncer: snap.NewSyncer(stateDb, chain.TrieDB().Scheme()), stateSyncStart: make(chan *stateSync), } dl.skeleton = newSkeleton(stateDb, dl.peers, dropPeer, newBeaconBackfiller(dl, success)) diff --git a/eth/protocols/snap/sync.go b/eth/protocols/snap/sync.go index 6e8c450f51c3..a9e35f971482 100644 --- a/eth/protocols/snap/sync.go +++ b/eth/protocols/snap/sync.go @@ -417,7 +417,8 @@ type SyncPeer interface { // - The peer delivers a stale response after a previous timeout // - The peer delivers a refusal to serve the requested state type Syncer struct { - db ethdb.KeyValueStore // Database to store the trie nodes into (and dedup) + db ethdb.KeyValueStore // Database to store the trie nodes into (and dedup) + scheme trie.NodeScheme // Node scheme used in node database root common.Hash // Current state trie root being synced tasks []*accountTask // Current account task set being synced @@ -485,9 +486,10 @@ type Syncer struct { // NewSyncer creates a new snapshot syncer to download the Ethereum state over the // snap protocol. -func NewSyncer(db ethdb.KeyValueStore) *Syncer { +func NewSyncer(db ethdb.KeyValueStore, scheme trie.NodeScheme) *Syncer { return &Syncer{ - db: db, + db: db, + scheme: scheme, peers: make(map[string]SyncPeer), peerJoin: new(event.Feed), @@ -581,7 +583,7 @@ func (s *Syncer) Sync(root common.Hash, cancel chan struct{}) error { s.lock.Lock() s.root = root s.healer = &healTask{ - scheduler: state.NewStateSync(root, s.db, s.onHealState), + scheduler: state.NewStateSync(root, s.db, s.onHealState, s.scheme), trieTasks: make(map[string]common.Hash), codeTasks: make(map[common.Hash]struct{}), } @@ -743,8 +745,9 @@ func (s *Syncer) loadSyncStatus() { s.accountBytes += common.StorageSize(len(key) + len(value)) }, } - task.genTrie = trie.NewStackTrie(task.genBatch) - + task.genTrie = trie.NewStackTrie(func(owner common.Hash, path []byte, hash common.Hash, val []byte) { + s.scheme.WriteTrieNode(task.genBatch, owner, path, hash, val) + }) for accountHash, subtasks := range task.SubTasks { for _, subtask := range subtasks { subtask.genBatch = ethdb.HookedBatch{ @@ -753,7 +756,9 @@ func (s *Syncer) loadSyncStatus() { s.storageBytes += common.StorageSize(len(key) + len(value)) }, } - subtask.genTrie = trie.NewStackTrieWithOwner(subtask.genBatch, accountHash) + subtask.genTrie = trie.NewStackTrieWithOwner(func(owner common.Hash, path []byte, hash common.Hash, val []byte) { + s.scheme.WriteTrieNode(subtask.genBatch, owner, path, hash, val) + }, accountHash) } } } @@ -810,7 +815,9 @@ func (s *Syncer) loadSyncStatus() { Last: last, SubTasks: make(map[common.Hash][]*storageTask), genBatch: batch, - genTrie: trie.NewStackTrie(batch), + genTrie: trie.NewStackTrie(func(owner common.Hash, path []byte, hash common.Hash, val []byte) { + s.scheme.WriteTrieNode(batch, owner, path, hash, val) + }), }) log.Debug("Created account sync task", "from", next, "last", last) next = common.BigToHash(new(big.Int).Add(last.Big(), common.Big1)) @@ -1835,7 +1842,7 @@ func (s *Syncer) processAccountResponse(res *accountResponse) { } // Check if the account is a contract with an unknown storage trie if account.Root != emptyRoot { - if ok, err := s.db.Has(account.Root[:]); err != nil || !ok { + if !s.scheme.HasTrieNode(s.db, res.hashes[i], nil, account.Root) { // If there was a previous large state retrieval in progress, // don't restart it from scratch. This happens if a sync cycle // is interrupted and resumed later. However, *do* update the @@ -2007,7 +2014,9 @@ func (s *Syncer) processStorageResponse(res *storageResponse) { Last: r.End(), root: acc.Root, genBatch: batch, - genTrie: trie.NewStackTrieWithOwner(batch, account), + genTrie: trie.NewStackTrieWithOwner(func(owner common.Hash, path []byte, hash common.Hash, val []byte) { + s.scheme.WriteTrieNode(batch, owner, path, hash, val) + }, account), }) for r.Next() { batch := ethdb.HookedBatch{ @@ -2021,7 +2030,9 @@ func (s *Syncer) processStorageResponse(res *storageResponse) { Last: r.End(), root: acc.Root, genBatch: batch, - genTrie: trie.NewStackTrieWithOwner(batch, account), + genTrie: trie.NewStackTrieWithOwner(func(owner common.Hash, path []byte, hash common.Hash, val []byte) { + s.scheme.WriteTrieNode(batch, owner, path, hash, val) + }, account), }) } for _, task := range tasks { @@ -2066,7 +2077,9 @@ func (s *Syncer) processStorageResponse(res *storageResponse) { slots += len(res.hashes[i]) if i < len(res.hashes)-1 || res.subTask == nil { - tr := trie.NewStackTrieWithOwner(batch, account) + tr := trie.NewStackTrieWithOwner(func(owner common.Hash, path []byte, hash common.Hash, val []byte) { + s.scheme.WriteTrieNode(batch, owner, path, hash, val) + }, account) for j := 0; j < len(res.hashes[i]); j++ { tr.Update(res.hashes[i][j][:], res.slots[i][j]) } diff --git a/eth/protocols/snap/sync_test.go b/eth/protocols/snap/sync_test.go index 1d1ce932e073..9b99d7e7a2d0 100644 --- a/eth/protocols/snap/sync_test.go +++ b/eth/protocols/snap/sync_test.go @@ -159,6 +159,13 @@ func newTestPeer(id string, t *testing.T, term func()) *testPeer { return peer } +func (t *testPeer) setStorageTries(tries map[common.Hash]*trie.Trie) { + t.storageTries = make(map[common.Hash]*trie.Trie) + for root, trie := range tries { + t.storageTries[root] = trie.Copy() + } +} + func (t *testPeer) ID() string { return t.id } func (t *testPeer) Log() log.Logger { return t.logger } @@ -562,9 +569,9 @@ func TestSyncBloatedProof(t *testing.T) { }) } ) - sourceAccountTrie, elems := makeAccountTrieNoStorage(100) + nodeScheme, sourceAccountTrie, elems := makeAccountTrieNoStorage(100) source := newTestPeer("source", t, term) - source.accountTrie = sourceAccountTrie + source.accountTrie = sourceAccountTrie.Copy() source.accountValues = elems source.accountRequestHandler = func(t *testPeer, requestId uint64, root common.Hash, origin common.Hash, limit common.Hash, cap uint64) error { @@ -610,15 +617,15 @@ func TestSyncBloatedProof(t *testing.T) { } return nil } - syncer := setupSyncer(source) + syncer := setupSyncer(nodeScheme, source) if err := syncer.Sync(sourceAccountTrie.Hash(), cancel); err == nil { t.Fatal("No error returned from incomplete/cancelled sync") } } -func setupSyncer(peers ...*testPeer) *Syncer { +func setupSyncer(scheme trie.NodeScheme, peers ...*testPeer) *Syncer { stateDb := rawdb.NewMemoryDatabase() - syncer := NewSyncer(stateDb) + syncer := NewSyncer(stateDb, scheme) for _, peer := range peers { syncer.Register(peer) peer.remote = syncer @@ -639,15 +646,15 @@ func TestSync(t *testing.T) { }) } ) - sourceAccountTrie, elems := makeAccountTrieNoStorage(100) + nodeScheme, sourceAccountTrie, elems := makeAccountTrieNoStorage(100) mkSource := func(name string) *testPeer { source := newTestPeer(name, t, term) - source.accountTrie = sourceAccountTrie + source.accountTrie = sourceAccountTrie.Copy() source.accountValues = elems return source } - syncer := setupSyncer(mkSource("source")) + syncer := setupSyncer(nodeScheme, mkSource("source")) if err := syncer.Sync(sourceAccountTrie.Hash(), cancel); err != nil { t.Fatalf("sync failed: %v", err) } @@ -668,15 +675,15 @@ func TestSyncTinyTriePanic(t *testing.T) { }) } ) - sourceAccountTrie, elems := makeAccountTrieNoStorage(1) + nodeScheme, sourceAccountTrie, elems := makeAccountTrieNoStorage(1) mkSource := func(name string) *testPeer { source := newTestPeer(name, t, term) - source.accountTrie = sourceAccountTrie + source.accountTrie = sourceAccountTrie.Copy() source.accountValues = elems return source } - syncer := setupSyncer(mkSource("source")) + syncer := setupSyncer(nodeScheme, mkSource("source")) done := checkStall(t, term) if err := syncer.Sync(sourceAccountTrie.Hash(), cancel); err != nil { t.Fatalf("sync failed: %v", err) @@ -698,15 +705,15 @@ func TestMultiSync(t *testing.T) { }) } ) - sourceAccountTrie, elems := makeAccountTrieNoStorage(100) + nodeScheme, sourceAccountTrie, elems := makeAccountTrieNoStorage(100) mkSource := func(name string) *testPeer { source := newTestPeer(name, t, term) - source.accountTrie = sourceAccountTrie + source.accountTrie = sourceAccountTrie.Copy() source.accountValues = elems return source } - syncer := setupSyncer(mkSource("sourceA"), mkSource("sourceB")) + syncer := setupSyncer(nodeScheme, mkSource("sourceA"), mkSource("sourceB")) done := checkStall(t, term) if err := syncer.Sync(sourceAccountTrie.Hash(), cancel); err != nil { t.Fatalf("sync failed: %v", err) @@ -728,17 +735,17 @@ func TestSyncWithStorage(t *testing.T) { }) } ) - sourceAccountTrie, elems, storageTries, storageElems := makeAccountTrieWithStorage(3, 3000, true, false) + nodeScheme, sourceAccountTrie, elems, storageTries, storageElems := makeAccountTrieWithStorage(3, 3000, true, false) mkSource := func(name string) *testPeer { source := newTestPeer(name, t, term) - source.accountTrie = sourceAccountTrie + source.accountTrie = sourceAccountTrie.Copy() source.accountValues = elems - source.storageTries = storageTries + source.setStorageTries(storageTries) source.storageValues = storageElems return source } - syncer := setupSyncer(mkSource("sourceA")) + syncer := setupSyncer(nodeScheme, mkSource("sourceA")) done := checkStall(t, term) if err := syncer.Sync(sourceAccountTrie.Hash(), cancel); err != nil { t.Fatalf("sync failed: %v", err) @@ -760,13 +767,13 @@ func TestMultiSyncManyUseless(t *testing.T) { }) } ) - sourceAccountTrie, elems, storageTries, storageElems := makeAccountTrieWithStorage(100, 3000, true, false) + nodeScheme, sourceAccountTrie, elems, storageTries, storageElems := makeAccountTrieWithStorage(100, 3000, true, false) mkSource := func(name string, noAccount, noStorage, noTrieNode bool) *testPeer { source := newTestPeer(name, t, term) - source.accountTrie = sourceAccountTrie + source.accountTrie = sourceAccountTrie.Copy() source.accountValues = elems - source.storageTries = storageTries + source.setStorageTries(storageTries) source.storageValues = storageElems if !noAccount { @@ -782,6 +789,7 @@ func TestMultiSyncManyUseless(t *testing.T) { } syncer := setupSyncer( + nodeScheme, mkSource("full", true, true, true), mkSource("noAccounts", false, true, true), mkSource("noStorage", true, false, true), @@ -806,13 +814,13 @@ func TestMultiSyncManyUselessWithLowTimeout(t *testing.T) { }) } ) - sourceAccountTrie, elems, storageTries, storageElems := makeAccountTrieWithStorage(100, 3000, true, false) + nodeScheme, sourceAccountTrie, elems, storageTries, storageElems := makeAccountTrieWithStorage(100, 3000, true, false) mkSource := func(name string, noAccount, noStorage, noTrieNode bool) *testPeer { source := newTestPeer(name, t, term) - source.accountTrie = sourceAccountTrie + source.accountTrie = sourceAccountTrie.Copy() source.accountValues = elems - source.storageTries = storageTries + source.setStorageTries(storageTries) source.storageValues = storageElems if !noAccount { @@ -828,6 +836,7 @@ func TestMultiSyncManyUselessWithLowTimeout(t *testing.T) { } syncer := setupSyncer( + nodeScheme, mkSource("full", true, true, true), mkSource("noAccounts", false, true, true), mkSource("noStorage", true, false, true), @@ -857,13 +866,13 @@ func TestMultiSyncManyUnresponsive(t *testing.T) { }) } ) - sourceAccountTrie, elems, storageTries, storageElems := makeAccountTrieWithStorage(100, 3000, true, false) + nodeScheme, sourceAccountTrie, elems, storageTries, storageElems := makeAccountTrieWithStorage(100, 3000, true, false) mkSource := func(name string, noAccount, noStorage, noTrieNode bool) *testPeer { source := newTestPeer(name, t, term) - source.accountTrie = sourceAccountTrie + source.accountTrie = sourceAccountTrie.Copy() source.accountValues = elems - source.storageTries = storageTries + source.setStorageTries(storageTries) source.storageValues = storageElems if !noAccount { @@ -879,6 +888,7 @@ func TestMultiSyncManyUnresponsive(t *testing.T) { } syncer := setupSyncer( + nodeScheme, mkSource("full", true, true, true), mkSource("noAccounts", false, true, true), mkSource("noStorage", true, false, true), @@ -923,15 +933,16 @@ func TestSyncBoundaryAccountTrie(t *testing.T) { }) } ) - sourceAccountTrie, elems := makeBoundaryAccountTrie(3000) + nodeScheme, sourceAccountTrie, elems := makeBoundaryAccountTrie(3000) mkSource := func(name string) *testPeer { source := newTestPeer(name, t, term) - source.accountTrie = sourceAccountTrie + source.accountTrie = sourceAccountTrie.Copy() source.accountValues = elems return source } syncer := setupSyncer( + nodeScheme, mkSource("peer-a"), mkSource("peer-b"), ) @@ -957,11 +968,11 @@ func TestSyncNoStorageAndOneCappedPeer(t *testing.T) { }) } ) - sourceAccountTrie, elems := makeAccountTrieNoStorage(3000) + nodeScheme, sourceAccountTrie, elems := makeAccountTrieNoStorage(3000) mkSource := func(name string, slow bool) *testPeer { source := newTestPeer(name, t, term) - source.accountTrie = sourceAccountTrie + source.accountTrie = sourceAccountTrie.Copy() source.accountValues = elems if slow { @@ -971,6 +982,7 @@ func TestSyncNoStorageAndOneCappedPeer(t *testing.T) { } syncer := setupSyncer( + nodeScheme, mkSource("nice-a", false), mkSource("nice-b", false), mkSource("nice-c", false), @@ -998,11 +1010,11 @@ func TestSyncNoStorageAndOneCodeCorruptPeer(t *testing.T) { }) } ) - sourceAccountTrie, elems := makeAccountTrieNoStorage(3000) + nodeScheme, sourceAccountTrie, elems := makeAccountTrieNoStorage(3000) mkSource := func(name string, codeFn codeHandlerFunc) *testPeer { source := newTestPeer(name, t, term) - source.accountTrie = sourceAccountTrie + source.accountTrie = sourceAccountTrie.Copy() source.accountValues = elems source.codeRequestHandler = codeFn return source @@ -1012,6 +1024,7 @@ func TestSyncNoStorageAndOneCodeCorruptPeer(t *testing.T) { // non-corrupt peer, which delivers everything in one go, and makes the // test moot syncer := setupSyncer( + nodeScheme, mkSource("capped", cappedCodeRequestHandler), mkSource("corrupt", corruptCodeRequestHandler), ) @@ -1035,11 +1048,11 @@ func TestSyncNoStorageAndOneAccountCorruptPeer(t *testing.T) { }) } ) - sourceAccountTrie, elems := makeAccountTrieNoStorage(3000) + nodeScheme, sourceAccountTrie, elems := makeAccountTrieNoStorage(3000) mkSource := func(name string, accFn accountHandlerFunc) *testPeer { source := newTestPeer(name, t, term) - source.accountTrie = sourceAccountTrie + source.accountTrie = sourceAccountTrie.Copy() source.accountValues = elems source.accountRequestHandler = accFn return source @@ -1049,6 +1062,7 @@ func TestSyncNoStorageAndOneAccountCorruptPeer(t *testing.T) { // non-corrupt peer, which delivers everything in one go, and makes the // test moot syncer := setupSyncer( + nodeScheme, mkSource("capped", defaultAccountRequestHandler), mkSource("corrupt", corruptAccountRequestHandler), ) @@ -1074,11 +1088,11 @@ func TestSyncNoStorageAndOneCodeCappedPeer(t *testing.T) { }) } ) - sourceAccountTrie, elems := makeAccountTrieNoStorage(3000) + nodeScheme, sourceAccountTrie, elems := makeAccountTrieNoStorage(3000) mkSource := func(name string, codeFn codeHandlerFunc) *testPeer { source := newTestPeer(name, t, term) - source.accountTrie = sourceAccountTrie + source.accountTrie = sourceAccountTrie.Copy() source.accountValues = elems source.codeRequestHandler = codeFn return source @@ -1087,6 +1101,7 @@ func TestSyncNoStorageAndOneCodeCappedPeer(t *testing.T) { // so it shouldn't be more than that var counter int syncer := setupSyncer( + nodeScheme, mkSource("capped", func(t *testPeer, id uint64, hashes []common.Hash, max uint64) error { counter++ return cappedCodeRequestHandler(t, id, hashes, max) @@ -1124,17 +1139,18 @@ func TestSyncBoundaryStorageTrie(t *testing.T) { }) } ) - sourceAccountTrie, elems, storageTries, storageElems := makeAccountTrieWithStorage(10, 1000, false, true) + nodeScheme, sourceAccountTrie, elems, storageTries, storageElems := makeAccountTrieWithStorage(10, 1000, false, true) mkSource := func(name string) *testPeer { source := newTestPeer(name, t, term) - source.accountTrie = sourceAccountTrie + source.accountTrie = sourceAccountTrie.Copy() source.accountValues = elems - source.storageTries = storageTries + source.setStorageTries(storageTries) source.storageValues = storageElems return source } syncer := setupSyncer( + nodeScheme, mkSource("peer-a"), mkSource("peer-b"), ) @@ -1160,13 +1176,13 @@ func TestSyncWithStorageAndOneCappedPeer(t *testing.T) { }) } ) - sourceAccountTrie, elems, storageTries, storageElems := makeAccountTrieWithStorage(300, 1000, false, false) + nodeScheme, sourceAccountTrie, elems, storageTries, storageElems := makeAccountTrieWithStorage(300, 1000, false, false) mkSource := func(name string, slow bool) *testPeer { source := newTestPeer(name, t, term) - source.accountTrie = sourceAccountTrie + source.accountTrie = sourceAccountTrie.Copy() source.accountValues = elems - source.storageTries = storageTries + source.setStorageTries(storageTries) source.storageValues = storageElems if slow { @@ -1176,6 +1192,7 @@ func TestSyncWithStorageAndOneCappedPeer(t *testing.T) { } syncer := setupSyncer( + nodeScheme, mkSource("nice-a", false), mkSource("slow", true), ) @@ -1201,19 +1218,20 @@ func TestSyncWithStorageAndCorruptPeer(t *testing.T) { }) } ) - sourceAccountTrie, elems, storageTries, storageElems := makeAccountTrieWithStorage(100, 3000, true, false) + nodeScheme, sourceAccountTrie, elems, storageTries, storageElems := makeAccountTrieWithStorage(100, 3000, true, false) mkSource := func(name string, handler storageHandlerFunc) *testPeer { source := newTestPeer(name, t, term) - source.accountTrie = sourceAccountTrie + source.accountTrie = sourceAccountTrie.Copy() source.accountValues = elems - source.storageTries = storageTries + source.setStorageTries(storageTries) source.storageValues = storageElems source.storageRequestHandler = handler return source } syncer := setupSyncer( + nodeScheme, mkSource("nice-a", defaultStorageRequestHandler), mkSource("nice-b", defaultStorageRequestHandler), mkSource("nice-c", defaultStorageRequestHandler), @@ -1239,18 +1257,19 @@ func TestSyncWithStorageAndNonProvingPeer(t *testing.T) { }) } ) - sourceAccountTrie, elems, storageTries, storageElems := makeAccountTrieWithStorage(100, 3000, true, false) + nodeScheme, sourceAccountTrie, elems, storageTries, storageElems := makeAccountTrieWithStorage(100, 3000, true, false) mkSource := func(name string, handler storageHandlerFunc) *testPeer { source := newTestPeer(name, t, term) - source.accountTrie = sourceAccountTrie + source.accountTrie = sourceAccountTrie.Copy() source.accountValues = elems - source.storageTries = storageTries + source.setStorageTries(storageTries) source.storageValues = storageElems source.storageRequestHandler = handler return source } syncer := setupSyncer( + nodeScheme, mkSource("nice-a", defaultStorageRequestHandler), mkSource("nice-b", defaultStorageRequestHandler), mkSource("nice-c", defaultStorageRequestHandler), @@ -1279,18 +1298,18 @@ func TestSyncWithStorageMisbehavingProve(t *testing.T) { }) } ) - sourceAccountTrie, elems, storageTries, storageElems := makeAccountTrieWithStorageWithUniqueStorage(10, 30, false) + nodeScheme, sourceAccountTrie, elems, storageTries, storageElems := makeAccountTrieWithStorageWithUniqueStorage(10, 30, false) mkSource := func(name string) *testPeer { source := newTestPeer(name, t, term) - source.accountTrie = sourceAccountTrie + source.accountTrie = sourceAccountTrie.Copy() source.accountValues = elems - source.storageTries = storageTries + source.setStorageTries(storageTries) source.storageValues = storageElems source.storageRequestHandler = proofHappyStorageRequestHandler return source } - syncer := setupSyncer(mkSource("sourceA")) + syncer := setupSyncer(nodeScheme, mkSource("sourceA")) if err := syncer.Sync(sourceAccountTrie.Hash(), cancel); err != nil { t.Fatalf("sync failed: %v", err) } @@ -1347,7 +1366,7 @@ func getCodeByHash(hash common.Hash) []byte { } // makeAccountTrieNoStorage spits out a trie, along with the leafs -func makeAccountTrieNoStorage(n int) (*trie.Trie, entrySlice) { +func makeAccountTrieNoStorage(n int) (trie.NodeScheme, *trie.Trie, entrySlice) { var ( db = trie.NewDatabase(rawdb.NewMemoryDatabase()) accTrie = trie.NewEmpty(db) @@ -1373,13 +1392,13 @@ func makeAccountTrieNoStorage(n int) (*trie.Trie, entrySlice) { db.Update(trie.NewWithNodeSet(nodes)) accTrie, _ = trie.New(trie.StateTrieID(root), db) - return accTrie, entries + return db.Scheme(), accTrie, entries } // makeBoundaryAccountTrie constructs an account trie. Instead of filling // accounts normally, this function will fill a few accounts which have // boundary hash. -func makeBoundaryAccountTrie(n int) (*trie.Trie, entrySlice) { +func makeBoundaryAccountTrie(n int) (trie.NodeScheme, *trie.Trie, entrySlice) { var ( entries entrySlice boundaries []common.Hash @@ -1435,12 +1454,12 @@ func makeBoundaryAccountTrie(n int) (*trie.Trie, entrySlice) { db.Update(trie.NewWithNodeSet(nodes)) accTrie, _ = trie.New(trie.StateTrieID(root), db) - return accTrie, entries + return db.Scheme(), accTrie, entries } // makeAccountTrieWithStorageWithUniqueStorage creates an account trie where each accounts // has a unique storage set. -func makeAccountTrieWithStorageWithUniqueStorage(accounts, slots int, code bool) (*trie.Trie, entrySlice, map[common.Hash]*trie.Trie, map[common.Hash]entrySlice) { +func makeAccountTrieWithStorageWithUniqueStorage(accounts, slots int, code bool) (trie.NodeScheme, *trie.Trie, entrySlice, map[common.Hash]*trie.Trie, map[common.Hash]entrySlice) { var ( db = trie.NewDatabase(rawdb.NewMemoryDatabase()) accTrie = trie.NewEmpty(db) @@ -1491,11 +1510,11 @@ func makeAccountTrieWithStorageWithUniqueStorage(accounts, slots int, code bool) trie, _ := trie.New(id, db) storageTries[common.BytesToHash(key)] = trie } - return accTrie, entries, storageTries, storageEntries + return db.Scheme(), accTrie, entries, storageTries, storageEntries } // makeAccountTrieWithStorage spits out a trie, along with the leafs -func makeAccountTrieWithStorage(accounts, slots int, code, boundary bool) (*trie.Trie, entrySlice, map[common.Hash]*trie.Trie, map[common.Hash]entrySlice) { +func makeAccountTrieWithStorage(accounts, slots int, code, boundary bool) (trie.NodeScheme, *trie.Trie, entrySlice, map[common.Hash]*trie.Trie, map[common.Hash]entrySlice) { var ( db = trie.NewDatabase(rawdb.NewMemoryDatabase()) accTrie = trie.NewEmpty(db) @@ -1562,7 +1581,7 @@ func makeAccountTrieWithStorage(accounts, slots int, code, boundary bool) (*trie } storageTries[common.BytesToHash(key)] = trie } - return accTrie, entries, storageTries, storageEntries + return db.Scheme(), accTrie, entries, storageTries, storageEntries } // makeStorageTrieWithSeed fills a storage trie with n items, returning the @@ -1641,7 +1660,7 @@ func makeBoundaryStorageTrie(owner common.Hash, n int, db *trie.Database) (commo func verifyTrie(db ethdb.KeyValueStore, root common.Hash, t *testing.T) { t.Helper() - triedb := trie.NewDatabase(db) + triedb := trie.NewDatabase(rawdb.NewDatabase(db)) accTrie, err := trie.New(trie.StateTrieID(root), triedb) if err != nil { t.Fatal(err) @@ -1697,16 +1716,16 @@ func TestSyncAccountPerformance(t *testing.T) { }) } ) - sourceAccountTrie, elems := makeAccountTrieNoStorage(100) + nodeScheme, sourceAccountTrie, elems := makeAccountTrieNoStorage(100) mkSource := func(name string) *testPeer { source := newTestPeer(name, t, term) - source.accountTrie = sourceAccountTrie + source.accountTrie = sourceAccountTrie.Copy() source.accountValues = elems return source } src := mkSource("source") - syncer := setupSyncer(src) + syncer := setupSyncer(nodeScheme, src) if err := syncer.Sync(sourceAccountTrie.Hash(), cancel); err != nil { t.Fatalf("sync failed: %v", err) } diff --git a/eth/tracers/js/goja.go b/eth/tracers/js/goja.go index dd01ff09c133..57839320b677 100644 --- a/eth/tracers/js/goja.go +++ b/eth/tracers/js/goja.go @@ -21,7 +21,6 @@ import ( "errors" "fmt" "math/big" - "time" "github.com/dop251/goja" @@ -285,9 +284,8 @@ func (t *jsTracer) CaptureFault(pc uint64, op vm.OpCode, gas, cost uint64, scope } // CaptureEnd is called after the call finishes to finalize the tracing. -func (t *jsTracer) CaptureEnd(output []byte, gasUsed uint64, duration time.Duration, err error) { +func (t *jsTracer) CaptureEnd(output []byte, gasUsed uint64, err error) { t.ctx["output"] = t.vm.ToValue(output) - t.ctx["time"] = t.vm.ToValue(duration.String()) if err != nil { t.ctx["error"] = t.vm.ToValue(err.Error()) } diff --git a/eth/tracers/js/internal/tracers/call_tracer_legacy.js b/eth/tracers/js/internal/tracers/call_tracer_legacy.js index b9e555df8746..451a644b917a 100644 --- a/eth/tracers/js/internal/tracers/call_tracer_legacy.js +++ b/eth/tracers/js/internal/tracers/call_tracer_legacy.js @@ -233,7 +233,6 @@ input: call.input, output: call.output, error: call.error, - time: call.time, calls: call.calls, } for (var key in sorted) { diff --git a/eth/tracers/js/tracer_test.go b/eth/tracers/js/tracer_test.go index 3b5a6b93f054..fe036dd7ef50 100644 --- a/eth/tracers/js/tracer_test.go +++ b/eth/tracers/js/tracer_test.go @@ -76,7 +76,7 @@ func runTrace(tracer tracers.Tracer, vmctx *vmContext, chaincfg *params.ChainCon tracer.CaptureTxStart(gasLimit) tracer.CaptureStart(env, contract.Caller(), contract.Address(), false, []byte{}, startGas, value) ret, err := env.Interpreter().Run(contract, []byte{}, false) - tracer.CaptureEnd(ret, startGas-contract.Gas, 1, err) + tracer.CaptureEnd(ret, startGas-contract.Gas, err) // Rest gas assumes no refund tracer.CaptureTxEnd(contract.Gas) if err != nil { @@ -206,7 +206,7 @@ func TestNoStepExec(t *testing.T) { } env := vm.NewEVM(vm.BlockContext{BlockNumber: big.NewInt(1), Time: big.NewInt(1)}, vm.TxContext{GasPrice: big.NewInt(100)}, &dummyStatedb{}, params.TestChainConfig, vm.Config{Debug: true, Tracer: tracer}) tracer.CaptureStart(env, common.Address{}, common.Address{}, false, []byte{}, 1000, big.NewInt(0)) - tracer.CaptureEnd(nil, 0, 1, nil) + tracer.CaptureEnd(nil, 0, nil) ret, err := tracer.GetResult() if err != nil { t.Fatal(err) diff --git a/eth/tracers/logger/access_list_tracer.go b/eth/tracers/logger/access_list_tracer.go index a8908094eb50..766ee4e4b95c 100644 --- a/eth/tracers/logger/access_list_tracer.go +++ b/eth/tracers/logger/access_list_tracer.go @@ -18,7 +18,6 @@ package logger import ( "math/big" - "time" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/types" @@ -162,7 +161,7 @@ func (a *AccessListTracer) CaptureState(pc uint64, op vm.OpCode, gas, cost uint6 func (*AccessListTracer) CaptureFault(pc uint64, op vm.OpCode, gas, cost uint64, scope *vm.ScopeContext, depth int, err error) { } -func (*AccessListTracer) CaptureEnd(output []byte, gasUsed uint64, t time.Duration, err error) {} +func (*AccessListTracer) CaptureEnd(output []byte, gasUsed uint64, err error) {} func (*AccessListTracer) CaptureEnter(typ vm.OpCode, from common.Address, to common.Address, input []byte, gas uint64, value *big.Int) { } diff --git a/eth/tracers/logger/logger.go b/eth/tracers/logger/logger.go index ce774270e127..5e75318b9a92 100644 --- a/eth/tracers/logger/logger.go +++ b/eth/tracers/logger/logger.go @@ -24,7 +24,6 @@ import ( "math/big" "strings" "sync/atomic" - "time" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common/hexutil" @@ -219,7 +218,7 @@ func (l *StructLogger) CaptureFault(pc uint64, op vm.OpCode, gas, cost uint64, s } // CaptureEnd is called after the call finishes to finalize the tracing. -func (l *StructLogger) CaptureEnd(output []byte, gasUsed uint64, t time.Duration, err error) { +func (l *StructLogger) CaptureEnd(output []byte, gasUsed uint64, err error) { l.output = output l.err = err if l.cfg.Debug { @@ -385,7 +384,7 @@ func (t *mdLogger) CaptureFault(pc uint64, op vm.OpCode, gas, cost uint64, scope fmt.Fprintf(t.out, "\nError: at pc=%d, op=%v: %v\n", pc, op, err) } -func (t *mdLogger) CaptureEnd(output []byte, gasUsed uint64, tm time.Duration, err error) { +func (t *mdLogger) CaptureEnd(output []byte, gasUsed uint64, err error) { fmt.Fprintf(t.out, "\nOutput: `%#x`\nConsumed gas: `%d`\nError: `%v`\n", output, gasUsed, err) } diff --git a/eth/tracers/logger/logger_json.go b/eth/tracers/logger/logger_json.go index 838d5017b863..a2cb4cd9fc59 100644 --- a/eth/tracers/logger/logger_json.go +++ b/eth/tracers/logger/logger_json.go @@ -20,7 +20,6 @@ import ( "encoding/json" "io" "math/big" - "time" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common/math" @@ -80,18 +79,17 @@ func (l *JSONLogger) CaptureState(pc uint64, op vm.OpCode, gas, cost uint64, sco } // CaptureEnd is triggered at end of execution. -func (l *JSONLogger) CaptureEnd(output []byte, gasUsed uint64, t time.Duration, err error) { +func (l *JSONLogger) CaptureEnd(output []byte, gasUsed uint64, err error) { type endLog struct { Output string `json:"output"` GasUsed math.HexOrDecimal64 `json:"gasUsed"` - Time time.Duration `json:"time"` Err string `json:"error,omitempty"` } var errMsg string if err != nil { errMsg = err.Error() } - l.encoder.Encode(endLog{common.Bytes2Hex(output), math.HexOrDecimal64(gasUsed), t, errMsg}) + l.encoder.Encode(endLog{common.Bytes2Hex(output), math.HexOrDecimal64(gasUsed), errMsg}) } func (l *JSONLogger) CaptureEnter(typ vm.OpCode, from common.Address, to common.Address, input []byte, gas uint64, value *big.Int) { diff --git a/eth/tracers/native/call.go b/eth/tracers/native/call.go index 4be242c8b43d..24fd406398bb 100644 --- a/eth/tracers/native/call.go +++ b/eth/tracers/native/call.go @@ -21,7 +21,6 @@ import ( "errors" "math/big" "sync/atomic" - "time" "github.com/ethereum/go-ethereum/accounts/abi" "github.com/ethereum/go-ethereum/common" @@ -142,7 +141,7 @@ func (t *callTracer) CaptureStart(env *vm.EVM, from common.Address, to common.Ad } // CaptureEnd is called after the call finishes to finalize the tracing. -func (t *callTracer) CaptureEnd(output []byte, gasUsed uint64, _ time.Duration, err error) { +func (t *callTracer) CaptureEnd(output []byte, gasUsed uint64, err error) { t.callstack[0].processOutput(output, err) } diff --git a/eth/tracers/native/mux.go b/eth/tracers/native/mux.go index 05b5e3d808b6..878e2dc9d6d7 100644 --- a/eth/tracers/native/mux.go +++ b/eth/tracers/native/mux.go @@ -19,7 +19,6 @@ package native import ( "encoding/json" "math/big" - "time" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/vm" @@ -67,9 +66,9 @@ func (t *muxTracer) CaptureStart(env *vm.EVM, from common.Address, to common.Add } // CaptureEnd is called after the call finishes to finalize the tracing. -func (t *muxTracer) CaptureEnd(output []byte, gasUsed uint64, elapsed time.Duration, err error) { +func (t *muxTracer) CaptureEnd(output []byte, gasUsed uint64, err error) { for _, t := range t.tracers { - t.CaptureEnd(output, gasUsed, elapsed, err) + t.CaptureEnd(output, gasUsed, err) } } diff --git a/eth/tracers/native/noop.go b/eth/tracers/native/noop.go index c252b2408fc9..c1035bd1b7c6 100644 --- a/eth/tracers/native/noop.go +++ b/eth/tracers/native/noop.go @@ -19,7 +19,6 @@ package native import ( "encoding/json" "math/big" - "time" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/vm" @@ -44,7 +43,7 @@ func (t *noopTracer) CaptureStart(env *vm.EVM, from common.Address, to common.Ad } // CaptureEnd is called after the call finishes to finalize the tracing. -func (t *noopTracer) CaptureEnd(output []byte, gasUsed uint64, _ time.Duration, err error) { +func (t *noopTracer) CaptureEnd(output []byte, gasUsed uint64, err error) { } // CaptureState implements the EVMLogger interface to trace a single step of VM execution. diff --git a/eth/tracers/native/prestate.go b/eth/tracers/native/prestate.go index b965c50df730..9313d0769071 100644 --- a/eth/tracers/native/prestate.go +++ b/eth/tracers/native/prestate.go @@ -21,7 +21,6 @@ import ( "encoding/json" "math/big" "sync/atomic" - "time" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common/hexutil" @@ -118,7 +117,7 @@ func (t *prestateTracer) CaptureStart(env *vm.EVM, from common.Address, to commo } // CaptureEnd is called after the call finishes to finalize the tracing. -func (t *prestateTracer) CaptureEnd(output []byte, gasUsed uint64, _ time.Duration, err error) { +func (t *prestateTracer) CaptureEnd(output []byte, gasUsed uint64, err error) { if t.config.DiffMode { return } diff --git a/graphql/graphql_test.go b/graphql/graphql_test.go index 491c73152113..46acd1529342 100644 --- a/graphql/graphql_test.go +++ b/graphql/graphql_test.go @@ -321,10 +321,11 @@ func TestGraphQLTransactionLogs(t *testing.T) { func createNode(t *testing.T) *node.Node { stack, err := node.New(&node.Config{ - HTTPHost: "127.0.0.1", - HTTPPort: 0, - WSHost: "127.0.0.1", - WSPort: 0, + HTTPHost: "127.0.0.1", + HTTPPort: 0, + WSHost: "127.0.0.1", + WSPort: 0, + HTTPTimeouts: node.DefaultConfig.HTTPTimeouts, }) if err != nil { t.Fatalf("could not create node: %v", err) diff --git a/graphql/service.go b/graphql/service.go index 684fdc71268d..4392dd83e688 100644 --- a/graphql/service.go +++ b/graphql/service.go @@ -20,12 +20,16 @@ import ( "context" "encoding/json" "net/http" + "strconv" + "sync" "time" "github.com/ethereum/go-ethereum/eth/filters" "github.com/ethereum/go-ethereum/internal/ethapi" "github.com/ethereum/go-ethereum/node" + "github.com/ethereum/go-ethereum/rpc" "github.com/graph-gophers/graphql-go" + gqlErrors "github.com/graph-gophers/graphql-go/errors" ) type handler struct { @@ -43,21 +47,60 @@ func (h handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } - ctx, cancel := context.WithTimeout(r.Context(), 60*time.Second) + var ( + ctx = r.Context() + responded sync.Once + timer *time.Timer + cancel context.CancelFunc + ) + ctx, cancel = context.WithCancel(ctx) defer cancel() - response := h.Schema.Exec(ctx, params.Query, params.OperationName, params.Variables) - responseJSON, err := json.Marshal(response) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - if len(response.Errors) > 0 { - w.WriteHeader(http.StatusBadRequest) + if timeout, ok := rpc.ContextRequestTimeout(ctx); ok { + timer = time.AfterFunc(timeout, func() { + responded.Do(func() { + // Cancel request handling. + cancel() + + // Create the timeout response. + response := &graphql.Response{ + Errors: []*gqlErrors.QueryError{{Message: "request timed out"}}, + } + responseJSON, err := json.Marshal(response) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + // Setting this disables gzip compression in package node. + w.Header().Set("transfer-encoding", "identity") + + // Flush the response. Since we are writing close to the response timeout, + // chunked transfer encoding must be disabled by setting content-length. + w.Header().Set("content-type", "application/json") + w.Header().Set("content-length", strconv.Itoa(len(responseJSON))) + w.Write(responseJSON) + if flush, ok := w.(http.Flusher); ok { + flush.Flush() + } + }) + }) } - w.Header().Set("Content-Type", "application/json") - w.Write(responseJSON) + response := h.Schema.Exec(ctx, params.Query, params.OperationName, params.Variables) + timer.Stop() + responded.Do(func() { + responseJSON, err := json.Marshal(response) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + if len(response.Errors) > 0 { + w.WriteHeader(http.StatusBadRequest) + } + w.Header().Set("Content-Type", "application/json") + w.Write(responseJSON) + }) } // New constructs a new GraphQL service instance. diff --git a/internal/web3ext/web3ext.go b/internal/web3ext/web3ext.go index 134562bde6fc..801afedaa02c 100644 --- a/internal/web3ext/web3ext.go +++ b/internal/web3ext/web3ext.go @@ -600,6 +600,12 @@ web3._extend({ call: 'eth_getLogs', params: 1, }), + new web3._extend.Method({ + name: 'call', + call: 'eth_call', + params: 3, + inputFormatter: [web3._extend.formatters.inputCallFormatter, web3._extend.formatters.inputDefaultBlockNumberFormatter, null], + }), ], properties: [ new web3._extend.Property({ diff --git a/les/client.go b/les/client.go index c304bf86f8a8..7aa4f9b8cc81 100644 --- a/les/client.go +++ b/les/client.go @@ -48,6 +48,7 @@ import ( "github.com/ethereum/go-ethereum/params" "github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/rpc" + "github.com/ethereum/go-ethereum/trie" ) type LightEthereum struct { @@ -99,7 +100,7 @@ func New(stack *node.Node, config *ethconfig.Config) (*LightEthereum, error) { if config.OverrideTerminalTotalDifficultyPassed != nil { overrides.OverrideTerminalTotalDifficultyPassed = config.OverrideTerminalTotalDifficultyPassed } - chainConfig, genesisHash, genesisErr := core.SetupGenesisBlockWithOverride(chainDb, config.Genesis, &overrides) + chainConfig, genesisHash, genesisErr := core.SetupGenesisBlockWithOverride(chainDb, trie.NewDatabase(chainDb), config.Genesis, &overrides) if _, isCompat := genesisErr.(*params.ConfigCompatError); genesisErr != nil && !isCompat { return nil, genesisErr } diff --git a/les/downloader/downloader.go b/les/downloader/downloader.go index 9eb7be715cdb..b005aa6a492f 100644 --- a/les/downloader/downloader.go +++ b/les/downloader/downloader.go @@ -226,7 +226,7 @@ func New(checkpoint uint64, stateDb ethdb.Database, mux *event.TypeMux, chain Bl headerProcCh: make(chan []*types.Header, 1), quitCh: make(chan struct{}), stateCh: make(chan dataPack), - SnapSyncer: snap.NewSyncer(stateDb), + SnapSyncer: snap.NewSyncer(stateDb, nil), stateSyncStart: make(chan *stateSync), //syncStatsState: stateSyncStats{ // processed: rawdb.ReadFastTrieProgress(stateDb), diff --git a/les/downloader/statesync.go b/les/downloader/statesync.go index 22f952155f11..8816d936f722 100644 --- a/les/downloader/statesync.go +++ b/les/downloader/statesync.go @@ -22,6 +22,7 @@ import ( "time" "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/rawdb" "github.com/ethereum/go-ethereum/core/state" "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/ethdb" @@ -295,10 +296,13 @@ type codeTask struct { // newStateSync creates a new state trie download scheduler. This method does not // yet start the sync. The user needs to call run to initiate. func newStateSync(d *Downloader, root common.Hash) *stateSync { + // Hack the node scheme here. It's a dead code is not used + // by light client at all. Just aim for passing tests. + scheme := trie.NewDatabase(rawdb.NewMemoryDatabase()).Scheme() return &stateSync{ d: d, root: root, - sched: state.NewStateSync(root, d.stateDB, nil), + sched: state.NewStateSync(root, d.stateDB, nil, scheme), keccak: sha3.NewLegacyKeccak256().(crypto.KeccakState), trieTasks: make(map[string]*trieTask), codeTasks: make(map[common.Hash]*codeTask), diff --git a/miner/miner_test.go b/miner/miner_test.go index 7c07b21dd82f..7bf091f375e5 100644 --- a/miner/miner_test.go +++ b/miner/miner_test.go @@ -31,7 +31,6 @@ import ( "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/core/vm" "github.com/ethereum/go-ethereum/eth/downloader" - "github.com/ethereum/go-ethereum/ethdb/memorydb" "github.com/ethereum/go-ethereum/event" "github.com/ethereum/go-ethereum/trie" ) @@ -247,10 +246,9 @@ func createMiner(t *testing.T) (*Miner, *event.TypeMux, func(skipMiner bool)) { Etherbase: common.HexToAddress("123456789"), } // Create chainConfig - memdb := memorydb.New() - chainDB := rawdb.NewDatabase(memdb) + chainDB := rawdb.NewMemoryDatabase() genesis := core.DeveloperGenesisBlock(15, 11_500_000, common.HexToAddress("12345")) - chainConfig, _, err := core.SetupGenesisBlock(chainDB, genesis) + chainConfig, _, err := core.SetupGenesisBlock(chainDB, trie.NewDatabase(chainDB), genesis) if err != nil { t.Fatalf("can't create new chain config: %v", err) } diff --git a/miner/worker_test.go b/miner/worker_test.go index 0730fc60ed5e..4b8b7c0518f4 100644 --- a/miner/worker_test.go +++ b/miner/worker_test.go @@ -369,7 +369,7 @@ func TestStreamUncleBlock(t *testing.T) { w, b := newTestWorker(t, ethashChainConfig, ethash, rawdb.NewMemoryDatabase(), 1) defer w.close() - var taskCh = make(chan struct{}) + var taskCh = make(chan struct{}, 3) taskIndex := 0 w.newTaskHook = func(task *task) { diff --git a/node/api_test.go b/node/api_test.go index d76cb943e4ee..8761c4883ef8 100644 --- a/node/api_test.go +++ b/node/api_test.go @@ -252,6 +252,9 @@ func TestStartRPC(t *testing.T) { config := test.cfg // config.Logger = testlog.Logger(t, log.LvlDebug) config.P2P.NoDiscovery = true + if config.HTTPTimeouts == (rpc.HTTPTimeouts{}) { + config.HTTPTimeouts = rpc.DefaultHTTPTimeouts + } // Create Node. stack, err := New(&config) diff --git a/node/node_test.go b/node/node_test.go index 7c76e21f6baf..560d487fa823 100644 --- a/node/node_test.go +++ b/node/node_test.go @@ -559,13 +559,13 @@ func (test rpcPrefixTest) check(t *testing.T, node *Node) { } for _, path := range test.wantHTTP { - resp := rpcRequest(t, httpBase+path) + resp := rpcRequest(t, httpBase+path, testMethod) if resp.StatusCode != 200 { t.Errorf("Error: %s: bad status code %d, want 200", path, resp.StatusCode) } } for _, path := range test.wantNoHTTP { - resp := rpcRequest(t, httpBase+path) + resp := rpcRequest(t, httpBase+path, testMethod) if resp.StatusCode != 404 { t.Errorf("Error: %s: bad status code %d, want 404", path, resp.StatusCode) } @@ -586,10 +586,11 @@ func (test rpcPrefixTest) check(t *testing.T, node *Node) { func createNode(t *testing.T, httpPort, wsPort int) *Node { conf := &Config{ - HTTPHost: "127.0.0.1", - HTTPPort: httpPort, - WSHost: "127.0.0.1", - WSPort: wsPort, + HTTPHost: "127.0.0.1", + HTTPPort: httpPort, + WSHost: "127.0.0.1", + WSPort: wsPort, + HTTPTimeouts: rpc.DefaultHTTPTimeouts, } node, err := New(conf) if err != nil { diff --git a/node/rpcstack.go b/node/rpcstack.go index 8244c892ff50..97d591642c09 100644 --- a/node/rpcstack.go +++ b/node/rpcstack.go @@ -24,6 +24,7 @@ import ( "net" "net/http" "sort" + "strconv" "strings" "sync" "sync/atomic" @@ -196,6 +197,7 @@ func (h *httpServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { } return } + // if http-rpc is enabled, try to serve request rpc := h.httpHandler.Load().(*rpcHandler) if rpc != nil { @@ -462,17 +464,94 @@ var gzPool = sync.Pool{ } type gzipResponseWriter struct { - io.Writer - http.ResponseWriter + resp http.ResponseWriter + + gz *gzip.Writer + contentLength uint64 // total length of the uncompressed response + written uint64 // amount of written bytes from the uncompressed response + hasLength bool // true if uncompressed response had Content-Length + inited bool // true after init was called for the first time +} + +// init runs just before response headers are written. Among other things, this function +// also decides whether compression will be applied at all. +func (w *gzipResponseWriter) init() { + if w.inited { + return + } + w.inited = true + + hdr := w.resp.Header() + length := hdr.Get("content-length") + if len(length) > 0 { + if n, err := strconv.ParseUint(length, 10, 64); err != nil { + w.hasLength = true + w.contentLength = n + } + } + + // Setting Transfer-Encoding to "identity" explicitly disables compression. net/http + // also recognizes this header value and uses it to disable "chunked" transfer + // encoding, trimming the header from the response. This means downstream handlers can + // set this without harm, even if they aren't wrapped by newGzipHandler. + // + // In go-ethereum, we use this signal to disable compression for certain error + // responses which are flushed out close to the write deadline of the response. For + // these cases, we want to avoid chunked transfer encoding and compression because + // they require additional output that may not get written in time. + passthrough := hdr.Get("transfer-encoding") == "identity" + if !passthrough { + w.gz = gzPool.Get().(*gzip.Writer) + w.gz.Reset(w.resp) + hdr.Del("content-length") + hdr.Set("content-encoding", "gzip") + } +} + +func (w *gzipResponseWriter) Header() http.Header { + return w.resp.Header() } func (w *gzipResponseWriter) WriteHeader(status int) { - w.Header().Del("Content-Length") - w.ResponseWriter.WriteHeader(status) + w.init() + w.resp.WriteHeader(status) } func (w *gzipResponseWriter) Write(b []byte) (int, error) { - return w.Writer.Write(b) + w.init() + + if w.gz == nil { + // Compression is disabled. + return w.resp.Write(b) + } + + n, err := w.gz.Write(b) + w.written += uint64(n) + if w.hasLength && w.written >= w.contentLength { + // The HTTP handler has finished writing the entire uncompressed response. Close + // the gzip stream to ensure the footer will be seen by the client in case the + // response is flushed after this call to write. + err = w.gz.Close() + } + return n, err +} + +func (w *gzipResponseWriter) Flush() { + if w.gz != nil { + w.gz.Flush() + } + if f, ok := w.resp.(http.Flusher); ok { + f.Flush() + } +} + +func (w *gzipResponseWriter) close() { + if w.gz == nil { + return + } + w.gz.Close() + gzPool.Put(w.gz) + w.gz = nil } func newGzipHandler(next http.Handler) http.Handler { @@ -482,15 +561,10 @@ func newGzipHandler(next http.Handler) http.Handler { return } - w.Header().Set("Content-Encoding", "gzip") - - gz := gzPool.Get().(*gzip.Writer) - defer gzPool.Put(gz) - - gz.Reset(w) - defer gz.Close() + wrapper := &gzipResponseWriter{resp: w} + defer wrapper.close() - next.ServeHTTP(&gzipResponseWriter{ResponseWriter: w, Writer: gz}, r) + next.ServeHTTP(wrapper, r) }) } diff --git a/node/rpcstack_test.go b/node/rpcstack_test.go index ebc253800623..795bc93c8386 100644 --- a/node/rpcstack_test.go +++ b/node/rpcstack_test.go @@ -19,7 +19,9 @@ package node import ( "bytes" "fmt" + "io" "net/http" + "net/http/httptest" "net/url" "strconv" "strings" @@ -34,29 +36,31 @@ import ( "github.com/stretchr/testify/assert" ) +const testMethod = "rpc_modules" + // TestCorsHandler makes sure CORS are properly handled on the http server. func TestCorsHandler(t *testing.T) { - srv := createAndStartServer(t, &httpConfig{CorsAllowedOrigins: []string{"test", "test.com"}}, false, &wsConfig{}) + srv := createAndStartServer(t, &httpConfig{CorsAllowedOrigins: []string{"test", "test.com"}}, false, &wsConfig{}, nil) defer srv.stop() url := "http://" + srv.listenAddr() - resp := rpcRequest(t, url, "origin", "test.com") + resp := rpcRequest(t, url, testMethod, "origin", "test.com") assert.Equal(t, "test.com", resp.Header.Get("Access-Control-Allow-Origin")) - resp2 := rpcRequest(t, url, "origin", "bad") + resp2 := rpcRequest(t, url, testMethod, "origin", "bad") assert.Equal(t, "", resp2.Header.Get("Access-Control-Allow-Origin")) } // TestVhosts makes sure vhosts are properly handled on the http server. func TestVhosts(t *testing.T) { - srv := createAndStartServer(t, &httpConfig{Vhosts: []string{"test"}}, false, &wsConfig{}) + srv := createAndStartServer(t, &httpConfig{Vhosts: []string{"test"}}, false, &wsConfig{}, nil) defer srv.stop() url := "http://" + srv.listenAddr() - resp := rpcRequest(t, url, "host", "test") + resp := rpcRequest(t, url, testMethod, "host", "test") assert.Equal(t, resp.StatusCode, http.StatusOK) - resp2 := rpcRequest(t, url, "host", "bad") + resp2 := rpcRequest(t, url, testMethod, "host", "bad") assert.Equal(t, resp2.StatusCode, http.StatusForbidden) } @@ -145,7 +149,7 @@ func TestWebsocketOrigins(t *testing.T) { }, } for _, tc := range tests { - srv := createAndStartServer(t, &httpConfig{}, true, &wsConfig{Origins: splitAndTrim(tc.spec)}) + srv := createAndStartServer(t, &httpConfig{}, true, &wsConfig{Origins: splitAndTrim(tc.spec)}, nil) url := fmt.Sprintf("ws://%v", srv.listenAddr()) for _, origin := range tc.expOk { if err := wsRequest(t, url, "Origin", origin); err != nil { @@ -231,11 +235,14 @@ func Test_checkPath(t *testing.T) { } } -func createAndStartServer(t *testing.T, conf *httpConfig, ws bool, wsConf *wsConfig) *httpServer { +func createAndStartServer(t *testing.T, conf *httpConfig, ws bool, wsConf *wsConfig, timeouts *rpc.HTTPTimeouts) *httpServer { t.Helper() - srv := newHTTPServer(testlog.Logger(t, log.LvlDebug), rpc.DefaultHTTPTimeouts) - assert.NoError(t, srv.enableRPC(nil, *conf)) + if timeouts == nil { + timeouts = &rpc.DefaultHTTPTimeouts + } + srv := newHTTPServer(testlog.Logger(t, log.LvlDebug), *timeouts) + assert.NoError(t, srv.enableRPC(apis(), *conf)) if ws { assert.NoError(t, srv.enableWS(nil, *wsConf)) } @@ -266,16 +273,33 @@ func wsRequest(t *testing.T, url string, extraHeaders ...string) error { } // rpcRequest performs a JSON-RPC request to the given URL. -func rpcRequest(t *testing.T, url string, extraHeaders ...string) *http.Response { +func rpcRequest(t *testing.T, url, method string, extraHeaders ...string) *http.Response { + t.Helper() + + body := fmt.Sprintf(`{"jsonrpc":"2.0","id":1,"method":"%s","params":[]}`, method) + return baseRpcRequest(t, url, body, extraHeaders...) +} + +func batchRpcRequest(t *testing.T, url string, methods []string, extraHeaders ...string) *http.Response { + reqs := make([]string, len(methods)) + for i, m := range methods { + reqs[i] = fmt.Sprintf(`{"jsonrpc":"2.0","id":1,"method":"%s","params":[]}`, m) + } + body := fmt.Sprintf(`[%s]`, strings.Join(reqs, ",")) + return baseRpcRequest(t, url, body, extraHeaders...) +} + +func baseRpcRequest(t *testing.T, url, bodyStr string, extraHeaders ...string) *http.Response { t.Helper() // Create the request. - body := bytes.NewReader([]byte(`{"jsonrpc":"2.0","id":1,"method":"rpc_modules","params":[]}`)) + body := bytes.NewReader([]byte(bodyStr)) req, err := http.NewRequest("POST", url, body) if err != nil { t.Fatal("could not create http request:", err) } req.Header.Set("content-type", "application/json") + req.Header.Set("accept-encoding", "identity") // Apply extra headers. if len(extraHeaders)%2 != 0 { @@ -315,7 +339,7 @@ func TestJWT(t *testing.T) { return ss } srv := createAndStartServer(t, &httpConfig{jwtSecret: []byte("secret")}, - true, &wsConfig{Origins: []string{"*"}, jwtSecret: []byte("secret")}) + true, &wsConfig{Origins: []string{"*"}, jwtSecret: []byte("secret")}, nil) wsUrl := fmt.Sprintf("ws://%v", srv.listenAddr()) htUrl := fmt.Sprintf("http://%v", srv.listenAddr()) @@ -348,7 +372,7 @@ func TestJWT(t *testing.T) { t.Errorf("test %d-ws, token '%v': expected ok, got %v", i, token, err) } token = tokenFn() - if resp := rpcRequest(t, htUrl, "Authorization", token); resp.StatusCode != 200 { + if resp := rpcRequest(t, htUrl, testMethod, "Authorization", token); resp.StatusCode != 200 { t.Errorf("test %d-http, token '%v': expected ok, got %v", i, token, resp.StatusCode) } } @@ -414,10 +438,176 @@ func TestJWT(t *testing.T) { } token = tokenFn() - resp := rpcRequest(t, htUrl, "Authorization", token) + resp := rpcRequest(t, htUrl, testMethod, "Authorization", token) if resp.StatusCode != http.StatusUnauthorized { t.Errorf("tc %d-http, token '%v': expected not to allow, got %v", i, token, resp.StatusCode) } } srv.stop() } + +func TestGzipHandler(t *testing.T) { + type gzipTest struct { + name string + handler http.HandlerFunc + status int + isGzip bool + header map[string]string + } + tests := []gzipTest{ + { + name: "Write", + handler: func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("response")) + }, + isGzip: true, + status: 200, + }, + { + name: "WriteHeader", + handler: func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("x-foo", "bar") + w.WriteHeader(205) + w.Write([]byte("response")) + }, + isGzip: true, + status: 205, + header: map[string]string{"x-foo": "bar"}, + }, + { + name: "WriteContentLength", + handler: func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("content-length", "8") + w.Write([]byte("response")) + }, + isGzip: true, + status: 200, + }, + { + name: "Flush", + handler: func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("res")) + w.(http.Flusher).Flush() + w.Write([]byte("ponse")) + }, + isGzip: true, + status: 200, + }, + { + name: "disable", + handler: func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("transfer-encoding", "identity") + w.Header().Set("x-foo", "bar") + w.Write([]byte("response")) + }, + isGzip: false, + status: 200, + header: map[string]string{"x-foo": "bar"}, + }, + { + name: "disable-WriteHeader", + handler: func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("transfer-encoding", "identity") + w.Header().Set("x-foo", "bar") + w.WriteHeader(205) + w.Write([]byte("response")) + }, + isGzip: false, + status: 205, + header: map[string]string{"x-foo": "bar"}, + }, + } + + for _, test := range tests { + test := test + t.Run(test.name, func(t *testing.T) { + srv := httptest.NewServer(newGzipHandler(test.handler)) + defer srv.Close() + + resp, err := http.Get(srv.URL) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + + content, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + wasGzip := resp.Uncompressed + + if string(content) != "response" { + t.Fatalf("wrong response content %q", content) + } + if wasGzip != test.isGzip { + t.Fatalf("response gzipped == %t, want %t", wasGzip, test.isGzip) + } + if resp.StatusCode != test.status { + t.Fatalf("response status == %d, want %d", resp.StatusCode, test.status) + } + for name, expectedValue := range test.header { + if v := resp.Header.Get(name); v != expectedValue { + t.Fatalf("response header %s == %s, want %s", name, v, expectedValue) + } + } + }) + } +} + +func TestHTTPWriteTimeout(t *testing.T) { + const ( + timeoutRes = `{"jsonrpc":"2.0","id":1,"error":{"code":-32002,"message":"request timed out"}}` + greetRes = `{"jsonrpc":"2.0","id":1,"result":"Hello"}` + ) + // Set-up server + timeouts := rpc.DefaultHTTPTimeouts + timeouts.WriteTimeout = time.Second + srv := createAndStartServer(t, &httpConfig{Modules: []string{"test"}}, false, &wsConfig{}, &timeouts) + url := fmt.Sprintf("http://%v", srv.listenAddr()) + + // Send normal request + t.Run("message", func(t *testing.T) { + resp := rpcRequest(t, url, "test_sleep") + defer resp.Body.Close() + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + if string(body) != timeoutRes { + t.Errorf("wrong response. have %s, want %s", string(body), timeoutRes) + } + }) + + // Batch request + t.Run("batch", func(t *testing.T) { + want := fmt.Sprintf("[%s,%s,%s]", greetRes, timeoutRes, timeoutRes) + resp := batchRpcRequest(t, url, []string{"test_greet", "test_sleep", "test_greet"}) + defer resp.Body.Close() + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + if string(body) != want { + t.Errorf("wrong response. have %s, want %s", string(body), want) + } + }) +} + +func apis() []rpc.API { + return []rpc.API{ + { + Namespace: "test", + Service: &testService{}, + }, + } +} + +type testService struct{} + +func (s *testService) Greet() string { + return "Hello" +} + +func (s *testService) Sleep() { + time.Sleep(1500 * time.Millisecond) +} diff --git a/p2p/discover/common.go b/p2p/discover/common.go index e389821fda8b..c36e8dcc3a71 100644 --- a/p2p/discover/common.go +++ b/p2p/discover/common.go @@ -35,16 +35,24 @@ type UDPConn interface { LocalAddr() net.Addr } +type V5Config struct { + ProtocolID *[6]byte +} + // Config holds settings for the discovery listener. type Config struct { // These settings are required and configure the UDP listener: PrivateKey *ecdsa.PrivateKey // These settings are optional: - NetRestrict *netutil.Netlist // list of allowed IP networks - Bootnodes []*enode.Node // list of bootstrap nodes - Unhandled chan<- ReadPacket // unhandled packets are sent on this channel - Log log.Logger // if set, log messages go here + NetRestrict *netutil.Netlist // list of allowed IP networks + Bootnodes []*enode.Node // list of bootstrap nodes + Unhandled chan<- ReadPacket // unhandled packets are sent on this channel + Log log.Logger // if set, log messages go here + + // V5ProtocolID configures the discv5 protocol identifier. + V5ProtocolID *[6]byte + ValidSchemes enr.IdentityScheme // allowed identity schemes Clock mclock.Clock } diff --git a/p2p/discover/table.go b/p2p/discover/table.go index d08f8a6c69cb..41d5ac6e34e7 100644 --- a/p2p/discover/table.go +++ b/p2p/discover/table.go @@ -672,15 +672,14 @@ func (h *nodesByDistance) push(n *node, maxElems int) { ix := sort.Search(len(h.entries), func(i int) bool { return enode.DistCmp(h.target, h.entries[i].ID(), n.ID()) > 0 }) + + end := len(h.entries) if len(h.entries) < maxElems { h.entries = append(h.entries, n) } - if ix == len(h.entries) { - // farther away than all nodes we already have. - // if there was room for it, the node is now the last element. - } else { - // slide existing entries down to make room - // this will overwrite the entry we just appended. + if ix < end { + // Slide existing entries down to make room. + // This will overwrite the entry we just appended. copy(h.entries[ix+1:], h.entries[ix:]) h.entries[ix] = n } diff --git a/p2p/discover/table_test.go b/p2p/discover/table_test.go index 5f40c967fd5b..1ef63fe01019 100644 --- a/p2p/discover/table_test.go +++ b/p2p/discover/table_test.go @@ -396,6 +396,59 @@ func TestTable_revalidateSyncRecord(t *testing.T) { } } +func TestNodesPush(t *testing.T) { + var target enode.ID + n1 := nodeAtDistance(target, 255, intIP(1)) + n2 := nodeAtDistance(target, 254, intIP(2)) + n3 := nodeAtDistance(target, 253, intIP(3)) + perm := [][]*node{ + {n3, n2, n1}, + {n3, n1, n2}, + {n2, n3, n1}, + {n2, n1, n3}, + {n1, n3, n2}, + {n1, n2, n3}, + } + + // Insert all permutations into lists with size limit 3. + for _, nodes := range perm { + list := nodesByDistance{target: target} + for _, n := range nodes { + list.push(n, 3) + } + if !slicesEqual(list.entries, perm[0], nodeIDEqual) { + t.Fatal("not equal") + } + } + + // Insert all permutations into lists with size limit 2. + for _, nodes := range perm { + list := nodesByDistance{target: target} + for _, n := range nodes { + list.push(n, 2) + } + if !slicesEqual(list.entries, perm[0][:2], nodeIDEqual) { + t.Fatal("not equal") + } + } +} + +func nodeIDEqual(n1, n2 *node) bool { + return n1.ID() == n2.ID() +} + +func slicesEqual[T any](s1, s2 []T, check func(e1, e2 T) bool) bool { + if len(s1) != len(s2) { + return false + } + for i := range s1 { + if !check(s1[i], s2[i]) { + return false + } + } + return true +} + // gen wraps quick.Value so it's easier to use. // it generates a random value of the given value's type. func gen(typ interface{}, rand *rand.Rand) interface{} { diff --git a/p2p/discover/v5_udp.go b/p2p/discover/v5_udp.go index 321c5bd2a818..57d624498ea1 100644 --- a/p2p/discover/v5_udp.go +++ b/p2p/discover/v5_udp.go @@ -154,7 +154,7 @@ func newUDPv5(conn UDPConn, ln *enode.LocalNode, cfg Config) (*UDPv5, error) { callDoneCh: make(chan *callV5), respTimeoutCh: make(chan *callTimeout), // state of dispatch - codec: v5wire.NewCodec(ln, cfg.PrivateKey, cfg.Clock), + codec: v5wire.NewCodec(ln, cfg.PrivateKey, cfg.Clock, cfg.V5ProtocolID), activeCallByNode: make(map[enode.ID]*callV5), activeCallByAuth: make(map[v5wire.Nonce]*callV5), callQueue: make(map[enode.ID][]*callV5), diff --git a/p2p/discover/v5wire/encoding.go b/p2p/discover/v5wire/encoding.go index e41d7f4c451e..d979ab0f9cd8 100644 --- a/p2p/discover/v5wire/encoding.go +++ b/p2p/discover/v5wire/encoding.go @@ -98,7 +98,7 @@ const ( randomPacketMsgSize = 20 ) -var protocolID = [6]byte{'d', 'i', 's', 'c', 'v', '5'} +var DefaultProtocolID = [6]byte{'d', 'i', 's', 'c', 'v', '5'} // Errors. var ( @@ -134,10 +134,11 @@ var ( // Codec encodes and decodes Discovery v5 packets. // This type is not safe for concurrent use. type Codec struct { - sha256 hash.Hash - localnode *enode.LocalNode - privkey *ecdsa.PrivateKey - sc *SessionCache + sha256 hash.Hash + localnode *enode.LocalNode + privkey *ecdsa.PrivateKey + sc *SessionCache + protocolID [6]byte // encoder buffers buf bytes.Buffer // whole packet @@ -150,12 +151,16 @@ type Codec struct { } // NewCodec creates a wire codec. -func NewCodec(ln *enode.LocalNode, key *ecdsa.PrivateKey, clock mclock.Clock) *Codec { +func NewCodec(ln *enode.LocalNode, key *ecdsa.PrivateKey, clock mclock.Clock, protocolID *[6]byte) *Codec { c := &Codec{ - sha256: sha256.New(), - localnode: ln, - privkey: key, - sc: NewSessionCache(1024, clock), + sha256: sha256.New(), + localnode: ln, + privkey: key, + sc: NewSessionCache(1024, clock), + protocolID: DefaultProtocolID, + } + if protocolID != nil { + c.protocolID = *protocolID } return c } @@ -255,7 +260,7 @@ func (c *Codec) makeHeader(toID enode.ID, flag byte, authsizeExtra int) Header { } return Header{ StaticHeader: StaticHeader{ - ProtocolID: protocolID, + ProtocolID: c.protocolID, Version: version, Flag: flag, AuthSize: uint16(authsize), @@ -434,7 +439,7 @@ func (c *Codec) Decode(input []byte, addr string) (src enode.ID, n *enode.Node, c.reader.Reset(staticHeader) binary.Read(&c.reader, binary.BigEndian, &head.StaticHeader) remainingInput := len(input) - sizeofStaticPacketData - if err := head.checkValid(remainingInput); err != nil { + if err := head.checkValid(remainingInput, c.protocolID); err != nil { return enode.ID{}, nil, nil, err } @@ -621,7 +626,7 @@ func (c *Codec) decryptMessage(input, nonce, headerData, readKey []byte) (Packet // checkValid performs some basic validity checks on the header. // The packetLen here is the length remaining after the static header. -func (h *StaticHeader) checkValid(packetLen int) error { +func (h *StaticHeader) checkValid(packetLen int, protocolID [6]byte) error { if h.ProtocolID != protocolID { return errInvalidHeader } diff --git a/p2p/discover/v5wire/encoding_test.go b/p2p/discover/v5wire/encoding_test.go index a08cffa2a576..25df732835dd 100644 --- a/p2p/discover/v5wire/encoding_test.go +++ b/p2p/discover/v5wire/encoding_test.go @@ -504,8 +504,8 @@ type handshakeTestNode struct { func newHandshakeTest() *handshakeTest { t := new(handshakeTest) - t.nodeA.init(testKeyA, net.IP{127, 0, 0, 1}, &t.clock) - t.nodeB.init(testKeyB, net.IP{127, 0, 0, 1}, &t.clock) + t.nodeA.init(testKeyA, net.IP{127, 0, 0, 1}, &t.clock, DefaultProtocolID) + t.nodeB.init(testKeyB, net.IP{127, 0, 0, 1}, &t.clock, DefaultProtocolID) return t } @@ -514,11 +514,11 @@ func (t *handshakeTest) close() { t.nodeB.ln.Database().Close() } -func (n *handshakeTestNode) init(key *ecdsa.PrivateKey, ip net.IP, clock mclock.Clock) { +func (n *handshakeTestNode) init(key *ecdsa.PrivateKey, ip net.IP, clock mclock.Clock, protocolID [6]byte) { db, _ := enode.OpenDB("") n.ln = enode.NewLocalNode(db, key) n.ln.SetStaticIP(ip) - n.c = NewCodec(n.ln, key, clock) + n.c = NewCodec(n.ln, key, clock, nil) } func (n *handshakeTestNode) encode(t testing.TB, to handshakeTestNode, p Packet) ([]byte, Nonce) { diff --git a/p2p/nat/natpmp.go b/p2p/nat/natpmp.go index 7f85543f8e29..40f2aff44e7a 100644 --- a/p2p/nat/natpmp.go +++ b/p2p/nat/natpmp.go @@ -50,8 +50,22 @@ func (n *pmp) AddMapping(protocol string, extport, intport int, name string, lif } // Note order of port arguments is switched between our // AddMapping and the client's AddPortMapping. - _, err := n.c.AddPortMapping(strings.ToLower(protocol), intport, extport, int(lifetime/time.Second)) - return err + res, err := n.c.AddPortMapping(strings.ToLower(protocol), intport, extport, int(lifetime/time.Second)) + if err != nil { + return err + } + + // NAT-PMP maps an alternative available port number if the requested + // port is already mapped to another address and returns success. In this + // case, we return an error because there is no way to return the new port + // to the caller. + if uint16(extport) != res.MappedExternalPort { + // Destroy the mapping in NAT device. + n.c.AddPortMapping(strings.ToLower(protocol), intport, 0, 0) + return fmt.Errorf("port %d already mapped to another address (%s)", extport, protocol) + } + + return nil } func (n *pmp) DeleteMapping(protocol string, extport, intport int) (err error) { @@ -95,13 +109,6 @@ func discoverPMP() Interface { return nil } -var ( - // LAN IP ranges - _, lan10, _ = net.ParseCIDR("10.0.0.0/8") - _, lan176, _ = net.ParseCIDR("172.16.0.0/12") - _, lan192, _ = net.ParseCIDR("192.168.0.0/16") -) - // TODO: improve this. We currently assume that (on most networks) // the router is X.X.X.1 in a local LAN range. func potentialGateways() (gws []net.IP) { @@ -116,7 +123,7 @@ func potentialGateways() (gws []net.IP) { } for _, addr := range ifaddrs { if x, ok := addr.(*net.IPNet); ok { - if lan10.Contains(x.IP) || lan176.Contains(x.IP) || lan192.Contains(x.IP) { + if x.IP.IsPrivate() { ip := x.IP.Mask(x.Mask).To4() if ip != nil { ip[3] = ip[3] | 0x01 diff --git a/p2p/simulations/adapters/inproc.go b/p2p/simulations/adapters/inproc.go index 1cb26a8ea05a..36b5286517ae 100644 --- a/p2p/simulations/adapters/inproc.go +++ b/p2p/simulations/adapters/inproc.go @@ -206,7 +206,7 @@ func (sn *SimNode) ServeRPC(conn *websocket.Conn) error { if err != nil { return err } - codec := rpc.NewFuncCodec(conn, conn.WriteJSON, conn.ReadJSON) + codec := rpc.NewFuncCodec(conn, func(v any, _ bool) error { return conn.WriteJSON(v) }, conn.ReadJSON) handler.ServeCodec(codec, 0) return nil } diff --git a/p2p/util.go b/p2p/util.go index 3c5f6b8508d5..2c8f322a66ac 100644 --- a/p2p/util.go +++ b/p2p/util.go @@ -70,6 +70,7 @@ func (h *expHeap) Pop() interface{} { old := *h n := len(old) x := old[n-1] + old[n-1] = expItem{} *h = old[0 : n-1] return x } diff --git a/rpc/client.go b/rpc/client.go index d89aa69277c7..a509cb2e0fa0 100644 --- a/rpc/client.go +++ b/rpc/client.go @@ -527,7 +527,7 @@ func (c *Client) write(ctx context.Context, msg interface{}, retry bool) error { return err } } - err := c.writeConn.writeJSON(ctx, msg) + err := c.writeConn.writeJSON(ctx, msg, false) if err != nil { c.writeConn = nil if !retry { @@ -660,7 +660,8 @@ func (c *Client) read(codec ServerCodec) { for { msgs, batch, err := codec.readBatch() if _, ok := err.(*json.SyntaxError); ok { - codec.writeJSON(context.Background(), errorMessage(&parseError{err.Error()})) + msg := errorMessage(&parseError{err.Error()}) + codec.writeJSON(context.Background(), msg, true) } if err != nil { c.readErr <- err diff --git a/rpc/errors.go b/rpc/errors.go index 9a19e9fe67f5..7188332d551e 100644 --- a/rpc/errors.go +++ b/rpc/errors.go @@ -60,10 +60,15 @@ var ( const ( errcodeDefault = -32000 errcodeNotificationsUnsupported = -32001 + errcodeTimeout = -32002 errcodePanic = -32603 errcodeMarshalError = -32603 ) +const ( + errMsgTimeout = "request timed out" +) + type methodNotFoundError struct{ method string } func (e *methodNotFoundError) ErrorCode() int { return -32601 } diff --git a/rpc/handler.go b/rpc/handler.go index f3052e7eb822..c2e7d7dc08c6 100644 --- a/rpc/handler.go +++ b/rpc/handler.go @@ -91,12 +91,83 @@ func newHandler(connCtx context.Context, conn jsonWriter, idgen func() ID, reg * return h } +// batchCallBuffer manages in progress call messages and their responses during a batch +// call. Calls need to be synchronized between the processing and timeout-triggering +// goroutines. +type batchCallBuffer struct { + mutex sync.Mutex + calls []*jsonrpcMessage + resp []*jsonrpcMessage + wrote bool +} + +// nextCall returns the next unprocessed message. +func (b *batchCallBuffer) nextCall() *jsonrpcMessage { + b.mutex.Lock() + defer b.mutex.Unlock() + + if len(b.calls) == 0 { + return nil + } + // The popping happens in `pushAnswer`. The in progress call is kept + // so we can return an error for it in case of timeout. + msg := b.calls[0] + return msg +} + +// pushResponse adds the response to last call returned by nextCall. +func (b *batchCallBuffer) pushResponse(answer *jsonrpcMessage) { + b.mutex.Lock() + defer b.mutex.Unlock() + + if answer != nil { + b.resp = append(b.resp, answer) + } + b.calls = b.calls[1:] +} + +// write sends the responses. +func (b *batchCallBuffer) write(ctx context.Context, conn jsonWriter) { + b.mutex.Lock() + defer b.mutex.Unlock() + + b.doWrite(ctx, conn, false) +} + +// timeout sends the responses added so far. For the remaining unanswered call +// messages, it sends a timeout error response. +func (b *batchCallBuffer) timeout(ctx context.Context, conn jsonWriter) { + b.mutex.Lock() + defer b.mutex.Unlock() + + for _, msg := range b.calls { + if !msg.isNotification() { + resp := msg.errorResponse(&internalServerError{errcodeTimeout, errMsgTimeout}) + b.resp = append(b.resp, resp) + } + } + b.doWrite(ctx, conn, true) +} + +// doWrite actually writes the response. +// This assumes b.mutex is held. +func (b *batchCallBuffer) doWrite(ctx context.Context, conn jsonWriter, isErrorResponse bool) { + if b.wrote { + return + } + b.wrote = true // can only write once + if len(b.resp) > 0 { + conn.writeJSON(ctx, b.resp, isErrorResponse) + } +} + // handleBatch executes all messages in a batch and returns the responses. func (h *handler) handleBatch(msgs []*jsonrpcMessage) { // Emit error response for empty batches: if len(msgs) == 0 { h.startCallProc(func(cp *callProc) { - h.conn.writeJSON(cp.ctx, errorMessage(&invalidRequestError{"empty batch"})) + resp := errorMessage(&invalidRequestError{"empty batch"}) + h.conn.writeJSON(cp.ctx, resp, true) }) return } @@ -113,16 +184,42 @@ func (h *handler) handleBatch(msgs []*jsonrpcMessage) { } // Process calls on a goroutine because they may block indefinitely: h.startCallProc(func(cp *callProc) { - answers := make([]*jsonrpcMessage, 0, len(msgs)) - for _, msg := range calls { - if answer := h.handleCallMsg(cp, msg); answer != nil { - answers = append(answers, answer) + var ( + timer *time.Timer + cancel context.CancelFunc + callBuffer = &batchCallBuffer{calls: calls, resp: make([]*jsonrpcMessage, 0, len(calls))} + ) + + cp.ctx, cancel = context.WithCancel(cp.ctx) + defer cancel() + + // Cancel the request context after timeout and send an error response. Since the + // currently-running method might not return immediately on timeout, we must wait + // for the timeout concurrently with processing the request. + if timeout, ok := ContextRequestTimeout(cp.ctx); ok { + timer = time.AfterFunc(timeout, func() { + cancel() + callBuffer.timeout(cp.ctx, h.conn) + }) + } + + for { + // No need to handle rest of calls if timed out. + if cp.ctx.Err() != nil { + break } + msg := callBuffer.nextCall() + if msg == nil { + break + } + resp := h.handleCallMsg(cp, msg) + callBuffer.pushResponse(resp) } - h.addSubscriptions(cp.notifiers) - if len(answers) > 0 { - h.conn.writeJSON(cp.ctx, answers) + if timer != nil { + timer.Stop() } + callBuffer.write(cp.ctx, h.conn) + h.addSubscriptions(cp.notifiers) for _, n := range cp.notifiers { n.activate() } @@ -135,10 +232,36 @@ func (h *handler) handleMsg(msg *jsonrpcMessage) { return } h.startCallProc(func(cp *callProc) { + var ( + responded sync.Once + timer *time.Timer + cancel context.CancelFunc + ) + cp.ctx, cancel = context.WithCancel(cp.ctx) + defer cancel() + + // Cancel the request context after timeout and send an error response. Since the + // running method might not return immediately on timeout, we must wait for the + // timeout concurrently with processing the request. + if timeout, ok := ContextRequestTimeout(cp.ctx); ok { + timer = time.AfterFunc(timeout, func() { + cancel() + responded.Do(func() { + resp := msg.errorResponse(&internalServerError{errcodeTimeout, errMsgTimeout}) + h.conn.writeJSON(cp.ctx, resp, true) + }) + }) + } + answer := h.handleCallMsg(cp, msg) + if timer != nil { + timer.Stop() + } h.addSubscriptions(cp.notifiers) if answer != nil { - h.conn.writeJSON(cp.ctx, answer) + responded.Do(func() { + h.conn.writeJSON(cp.ctx, answer, false) + }) } for _, n := range cp.notifiers { n.activate() @@ -334,7 +457,6 @@ func (h *handler) handleCall(cp *callProc, msg *jsonrpcMessage) *jsonrpcMessage } start := time.Now() answer := h.runMethod(cp.ctx, msg, callb, args) - // Collect the statistics for RPC calls if metrics is enabled. // We only care about pure rpc call. Filter out subscription. if callb != h.unsubscribeCb { diff --git a/rpc/http.go b/rpc/http.go index 0ba6588f9906..bbabe15bada3 100644 --- a/rpc/http.go +++ b/rpc/http.go @@ -23,9 +23,11 @@ import ( "errors" "fmt" "io" + "math" "mime" "net/http" "net/url" + "strconv" "sync" "time" ) @@ -52,7 +54,7 @@ type httpConn struct { // and some methods don't work. The panic() stubs here exist to ensure // this special treatment is correct. -func (hc *httpConn) writeJSON(context.Context, interface{}) error { +func (hc *httpConn) writeJSON(context.Context, interface{}, bool) error { panic("writeJSON called on httpConn") } @@ -256,7 +258,42 @@ type httpServerConn struct { func newHTTPServerConn(r *http.Request, w http.ResponseWriter) ServerCodec { body := io.LimitReader(r.Body, maxRequestContentLength) conn := &httpServerConn{Reader: body, Writer: w, r: r} - return NewCodec(conn) + + encoder := func(v any, isErrorResponse bool) error { + if !isErrorResponse { + return json.NewEncoder(conn).Encode(v) + } + + // It's an error response and requires special treatment. + // + // In case of a timeout error, the response must be written before the HTTP + // server's write timeout occurs. So we need to flush the response. The + // Content-Length header also needs to be set to ensure the client knows + // when it has the full response. + encdata, err := json.Marshal(v) + if err != nil { + return err + } + w.Header().Set("content-length", strconv.Itoa(len(encdata))) + + // If this request is wrapped in a handler that might remove Content-Length (such + // as the automatic gzip we do in package node), we need to ensure the HTTP server + // doesn't perform chunked encoding. In case WriteTimeout is reached, the chunked + // encoding might not be finished correctly, and some clients do not like it when + // the final chunk is missing. + w.Header().Set("transfer-encoding", "identity") + + _, err = w.Write(encdata) + if f, ok := w.(http.Flusher); ok { + f.Flush() + } + return err + } + + dec := json.NewDecoder(conn) + dec.UseNumber() + + return NewFuncCodec(conn, encoder, dec.Decode) } // Close does nothing and always returns nil. @@ -326,3 +363,35 @@ func validateRequest(r *http.Request) (int, error) { err := fmt.Errorf("invalid content type, only %s is supported", contentType) return http.StatusUnsupportedMediaType, err } + +// ContextRequestTimeout returns the request timeout derived from the given context. +func ContextRequestTimeout(ctx context.Context) (time.Duration, bool) { + timeout := time.Duration(math.MaxInt64) + hasTimeout := false + setTimeout := func(d time.Duration) { + if d < timeout { + timeout = d + hasTimeout = true + } + } + + if deadline, ok := ctx.Deadline(); ok { + setTimeout(time.Until(deadline)) + } + + // If the context is an HTTP request context, use the server's WriteTimeout. + httpSrv, ok := ctx.Value(http.ServerContextKey).(*http.Server) + if ok && httpSrv.WriteTimeout > 0 { + wt := httpSrv.WriteTimeout + // When a write timeout is configured, we need to send the response message before + // the HTTP server cuts connection. So our internal timeout must be earlier than + // the server's true timeout. + // + // Note: Timeouts are sanitized to be a minimum of 1 second. + // Also see issue: https://github.com/golang/go/issues/47229 + wt -= 100 * time.Millisecond + setTimeout(wt) + } + + return timeout, hasTimeout +} diff --git a/rpc/json.go b/rpc/json.go index 1064939ff8b6..8a3b162cabbc 100644 --- a/rpc/json.go +++ b/rpc/json.go @@ -168,18 +168,22 @@ type ConnRemoteAddr interface { // support for parsing arguments and serializing (result) objects. type jsonCodec struct { remote string - closer sync.Once // close closed channel once - closeCh chan interface{} // closed on Close - decode func(v interface{}) error // decoder to allow multiple transports - encMu sync.Mutex // guards the encoder - encode func(v interface{}) error // encoder to allow multiple transports + closer sync.Once // close closed channel once + closeCh chan interface{} // closed on Close + decode decodeFunc // decoder to allow multiple transports + encMu sync.Mutex // guards the encoder + encode encodeFunc // encoder to allow multiple transports conn deadlineCloser } +type encodeFunc = func(v interface{}, isErrorResponse bool) error + +type decodeFunc = func(v interface{}) error + // NewFuncCodec creates a codec which uses the given functions to read and write. If conn // implements ConnRemoteAddr, log messages will use it to include the remote address of // the connection. -func NewFuncCodec(conn deadlineCloser, encode, decode func(v interface{}) error) ServerCodec { +func NewFuncCodec(conn deadlineCloser, encode encodeFunc, decode decodeFunc) ServerCodec { codec := &jsonCodec{ closeCh: make(chan interface{}), encode: encode, @@ -198,7 +202,11 @@ func NewCodec(conn Conn) ServerCodec { enc := json.NewEncoder(conn) dec := json.NewDecoder(conn) dec.UseNumber() - return NewFuncCodec(conn, enc.Encode, dec.Decode) + + encode := func(v interface{}, isErrorResponse bool) error { + return enc.Encode(v) + } + return NewFuncCodec(conn, encode, dec.Decode) } func (c *jsonCodec) peerInfo() PeerInfo { @@ -228,7 +236,7 @@ func (c *jsonCodec) readBatch() (messages []*jsonrpcMessage, batch bool, err err return messages, batch, nil } -func (c *jsonCodec) writeJSON(ctx context.Context, v interface{}) error { +func (c *jsonCodec) writeJSON(ctx context.Context, v interface{}, isErrorResponse bool) error { c.encMu.Lock() defer c.encMu.Unlock() @@ -237,7 +245,7 @@ func (c *jsonCodec) writeJSON(ctx context.Context, v interface{}) error { deadline = time.Now().Add(defaultWriteTimeout) } c.conn.SetWriteDeadline(deadline) - return c.encode(v) + return c.encode(v, isErrorResponse) } func (c *jsonCodec) close() { diff --git a/rpc/server.go b/rpc/server.go index fe162d5a428e..9c72c26d7b94 100644 --- a/rpc/server.go +++ b/rpc/server.go @@ -125,7 +125,8 @@ func (s *Server) serveSingleRequest(ctx context.Context, codec ServerCodec) { reqs, batch, err := codec.readBatch() if err != nil { if err != io.EOF { - codec.writeJSON(ctx, errorMessage(&invalidMessageError{"parse error"})) + resp := errorMessage(&invalidMessageError{"parse error"}) + codec.writeJSON(ctx, resp, true) } return } diff --git a/rpc/subscription.go b/rpc/subscription.go index d7ba784fc532..334ead3ace4d 100644 --- a/rpc/subscription.go +++ b/rpc/subscription.go @@ -175,11 +175,13 @@ func (n *Notifier) activate() error { func (n *Notifier) send(sub *Subscription, data json.RawMessage) error { params, _ := json.Marshal(&subscriptionResult{ID: string(sub.ID), Result: data}) ctx := context.Background() - return n.h.conn.writeJSON(ctx, &jsonrpcMessage{ + + msg := &jsonrpcMessage{ Version: vsn, Method: n.namespace + notificationMethodSuffix, Params: params, - }) + } + return n.h.conn.writeJSON(ctx, msg, false) } // A Subscription is created by a notifier and tied to that notifier. The client can use diff --git a/rpc/types.go b/rpc/types.go index e7158796ead0..9dda067e7f2f 100644 --- a/rpc/types.go +++ b/rpc/types.go @@ -51,7 +51,9 @@ type ServerCodec interface { // jsonWriter can write JSON messages to its underlying connection. // Implementations must be safe for concurrent use. type jsonWriter interface { - writeJSON(context.Context, interface{}) error + // writeJSON writes a message to the connection. + writeJSON(ctx context.Context, msg interface{}, isError bool) error + // Closed returns a channel which is closed when the connection is closed. closed() <-chan interface{} // RemoteAddr returns the peer address of the connection. diff --git a/rpc/websocket.go b/rpc/websocket.go index f6d09288590c..0ac2a2792d5a 100644 --- a/rpc/websocket.go +++ b/rpc/websocket.go @@ -35,7 +35,7 @@ import ( const ( wsReadBuffer = 1024 wsWriteBuffer = 1024 - wsPingInterval = 60 * time.Second + wsPingInterval = 30 * time.Second wsPingWriteTimeout = 5 * time.Second wsPongTimeout = 30 * time.Second wsMessageSizeLimit = 15 * 1024 * 1024 @@ -287,8 +287,12 @@ func newWebsocketCodec(conn *websocket.Conn, host string, req http.Header) Serve conn.SetReadDeadline(time.Time{}) return nil }) + + encode := func(v interface{}, isErrorResponse bool) error { + return conn.WriteJSON(v) + } wc := &websocketCodec{ - jsonCodec: NewFuncCodec(conn, conn.WriteJSON, conn.ReadJSON).(*jsonCodec), + jsonCodec: NewFuncCodec(conn, encode, conn.ReadJSON).(*jsonCodec), conn: conn, pingReset: make(chan struct{}, 1), info: PeerInfo{ @@ -315,8 +319,8 @@ func (wc *websocketCodec) peerInfo() PeerInfo { return wc.info } -func (wc *websocketCodec) writeJSON(ctx context.Context, v interface{}) error { - err := wc.jsonCodec.writeJSON(ctx, v) +func (wc *websocketCodec) writeJSON(ctx context.Context, v interface{}, isError bool) error { + err := wc.jsonCodec.writeJSON(ctx, v, isError) if err == nil { // Notify pingLoop to delay the next idle ping. select { diff --git a/signer/core/apitypes/signed_data_internal_test.go b/signer/core/apitypes/signed_data_internal_test.go index 8379c0a7f075..af7fc93ed88f 100644 --- a/signer/core/apitypes/signed_data_internal_test.go +++ b/signer/core/apitypes/signed_data_internal_test.go @@ -23,6 +23,7 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common/hexutil" + "github.com/ethereum/go-ethereum/common/math" ) func TestBytesPadding(t *testing.T) { @@ -197,3 +198,38 @@ func TestParseInteger(t *testing.T) { } } } + +func TestConvertStringDataToSlice(t *testing.T) { + slice := []string{"a", "b", "c"} + var it interface{} = slice + _, err := convertDataToSlice(it) + if err != nil { + t.Fatal(err) + } +} + +func TestConvertUint256DataToSlice(t *testing.T) { + slice := []*math.HexOrDecimal256{ + math.NewHexOrDecimal256(1), + math.NewHexOrDecimal256(2), + math.NewHexOrDecimal256(3), + } + var it interface{} = slice + _, err := convertDataToSlice(it) + if err != nil { + t.Fatal(err) + } +} + +func TestConvertAddressDataToSlice(t *testing.T) { + slice := []common.Address{ + common.HexToAddress("0x0000000000000000000000000000000000000001"), + common.HexToAddress("0x0000000000000000000000000000000000000002"), + common.HexToAddress("0x0000000000000000000000000000000000000003"), + } + var it interface{} = slice + _, err := convertDataToSlice(it) + if err != nil { + t.Fatal(err) + } +} diff --git a/signer/core/apitypes/types.go b/signer/core/apitypes/types.go index 12b5a3f33d7d..16f7e6d77bc2 100644 --- a/signer/core/apitypes/types.go +++ b/signer/core/apitypes/types.go @@ -398,8 +398,8 @@ func (typedData *TypedData) EncodeData(primaryType string, data map[string]inter encType := field.Type encValue := data[field.Name] if encType[len(encType)-1:] == "]" { - arrayValue, ok := encValue.([]interface{}) - if !ok { + arrayValue, err := convertDataToSlice(encValue) + if err != nil { return nil, dataMismatchError(encType, encValue) } @@ -604,6 +604,19 @@ func dataMismatchError(encType string, encValue interface{}) error { return fmt.Errorf("provided data '%v' doesn't match type '%s'", encValue, encType) } +func convertDataToSlice(encValue interface{}) ([]interface{}, error) { + var outEncValue []interface{} + rv := reflect.ValueOf(encValue) + if rv.Kind() == reflect.Slice { + for i := 0; i < rv.Len(); i++ { + outEncValue = append(outEncValue, rv.Index(i).Interface()) + } + } else { + return outEncValue, fmt.Errorf("provided data '%v' is not slice", encValue) + } + return outEncValue, nil +} + // validate makes sure the types are sound func (typedData *TypedData) validate() error { if err := typedData.Types.validate(); err != nil { @@ -663,7 +676,7 @@ func (typedData *TypedData) formatData(primaryType string, data map[string]inter Typ: field.Type, } if field.isArray() { - arrayValue, _ := encValue.([]interface{}) + arrayValue, _ := convertDataToSlice(encValue) parsedType := field.typeName() for _, v := range arrayValue { if typedData.Types[parsedType] != nil { diff --git a/signer/core/signed_data.go b/signer/core/signed_data.go index c0da22e62662..8ee572f53ec8 100644 --- a/signer/core/signed_data.go +++ b/signer/core/signed_data.go @@ -18,6 +18,7 @@ package core import ( "context" + "encoding/json" "errors" "fmt" "mime" @@ -135,11 +136,7 @@ func (api *SignerAPI) determineSignatureFormat(ctx context.Context, contentType req = &SignDataRequest{ContentType: mediaType, Rawdata: []byte(msg), Messages: messages, Hash: sighash} case apitypes.ApplicationClique.Mime: // Clique is the Ethereum PoA standard - stringData, ok := data.(string) - if !ok { - return nil, useEthereumV, fmt.Errorf("input for %v must be an hex-encoded string", apitypes.ApplicationClique.Mime) - } - cliqueData, err := hexutil.Decode(stringData) + cliqueData, err := fromHex(data) if err != nil { return nil, useEthereumV, err } @@ -167,27 +164,30 @@ func (api *SignerAPI) determineSignatureFormat(ctx context.Context, contentType // Clique uses V on the form 0 or 1 useEthereumV = false req = &SignDataRequest{ContentType: mediaType, Rawdata: cliqueRlp, Messages: messages, Hash: sighash} + case apitypes.DataTyped.Mime: + // EIP-712 conformant typed data + var err error + req, err = typedDataRequest(data) + if err != nil { + return nil, useEthereumV, err + } default: // also case TextPlain.Mime: // Calculates an Ethereum ECDSA signature for: // hash = keccak256("\x19Ethereum Signed Message:\n${message length}${message}") - // We expect it to be a string - if stringData, ok := data.(string); !ok { - return nil, useEthereumV, fmt.Errorf("input for text/plain must be an hex-encoded string") - } else { - if textData, err := hexutil.Decode(stringData); err != nil { - return nil, useEthereumV, err - } else { - sighash, msg := accounts.TextAndHash(textData) - messages := []*apitypes.NameValueType{ - { - Name: "message", - Typ: accounts.MimetypeTextPlain, - Value: msg, - }, - } - req = &SignDataRequest{ContentType: mediaType, Rawdata: []byte(msg), Messages: messages, Hash: sighash} - } + // We expect input to be a hex-encoded string + textData, err := fromHex(data) + if err != nil { + return nil, useEthereumV, err } + sighash, msg := accounts.TextAndHash(textData) + messages := []*apitypes.NameValueType{ + { + Name: "message", + Typ: accounts.MimetypeTextPlain, + Value: msg, + }, + } + req = &SignDataRequest{ContentType: mediaType, Rawdata: []byte(msg), Messages: messages, Hash: sighash} } req.Address = addr req.Meta = MetadataFromContext(ctx) @@ -233,20 +233,12 @@ func (api *SignerAPI) SignTypedData(ctx context.Context, addr common.MixedcaseAd // - the signature preimage (hash) func (api *SignerAPI) signTypedData(ctx context.Context, addr common.MixedcaseAddress, typedData apitypes.TypedData, validationMessages *apitypes.ValidationMessages) (hexutil.Bytes, hexutil.Bytes, error) { - sighash, rawData, err := apitypes.TypedDataAndHash(typedData) + req, err := typedDataRequest(typedData) if err != nil { return nil, nil, err } - messages, err := typedData.Format() - if err != nil { - return nil, nil, err - } - req := &SignDataRequest{ - ContentType: apitypes.DataTyped.Mime, - Rawdata: []byte(rawData), - Messages: messages, - Hash: sighash, - Address: addr} + req.Address = addr + req.Meta = MetadataFromContext(ctx) if validationMessages != nil { req.Callinfo = validationMessages.Messages } @@ -255,7 +247,46 @@ func (api *SignerAPI) signTypedData(ctx context.Context, addr common.MixedcaseAd api.UI.ShowError(err.Error()) return nil, nil, err } - return signature, sighash, nil + return signature, req.Hash, nil +} + +// fromHex tries to interpret the data as type string, and convert from +// hexadecimal to []byte +func fromHex(data any) ([]byte, error) { + if stringData, ok := data.(string); ok { + binary, err := hexutil.Decode(stringData) + return binary, err + } + return nil, fmt.Errorf("wrong type %T", data) +} + +// typeDataRequest tries to convert the data into a SignDataRequest. +func typedDataRequest(data any) (*SignDataRequest, error) { + var typedData apitypes.TypedData + if td, ok := data.(apitypes.TypedData); ok { + typedData = td + } else { // Hex-encoded data + jsonData, err := fromHex(data) + if err != nil { + return nil, err + } + if err = json.Unmarshal(jsonData, &typedData); err != nil { + return nil, err + } + } + messages, err := typedData.Format() + if err != nil { + return nil, err + } + sighash, rawData, err := apitypes.TypedDataAndHash(typedData) + if err != nil { + return nil, err + } + return &SignDataRequest{ + ContentType: apitypes.DataTyped.Mime, + Rawdata: []byte(rawData), + Messages: messages, + Hash: sighash}, nil } // EcRecover recovers the address associated with the given sig. @@ -293,30 +324,20 @@ func UnmarshalValidatorData(data interface{}) (apitypes.ValidatorData, error) { if !ok { return apitypes.ValidatorData{}, errors.New("validator input is not a map[string]interface{}") } - addr, ok := raw["address"].(string) - if !ok { - return apitypes.ValidatorData{}, errors.New("validator address is not sent as a string") - } - addrBytes, err := hexutil.Decode(addr) + addrBytes, err := fromHex(raw["address"]) if err != nil { - return apitypes.ValidatorData{}, err + return apitypes.ValidatorData{}, fmt.Errorf("validator address error: %w", err) } - if !ok || len(addrBytes) == 0 { + if len(addrBytes) == 0 { return apitypes.ValidatorData{}, errors.New("validator address is undefined") } - - message, ok := raw["message"].(string) - if !ok { - return apitypes.ValidatorData{}, errors.New("message is not sent as a string") - } - messageBytes, err := hexutil.Decode(message) + messageBytes, err := fromHex(raw["message"]) if err != nil { - return apitypes.ValidatorData{}, err + return apitypes.ValidatorData{}, fmt.Errorf("message error: %w", err) } - if !ok || len(messageBytes) == 0 { + if len(messageBytes) == 0 { return apitypes.ValidatorData{}, errors.New("message is undefined") } - return apitypes.ValidatorData{ Address: common.BytesToAddress(addrBytes), Message: messageBytes, diff --git a/signer/core/signed_data_test.go b/signer/core/signed_data_test.go index 7d5661e7e6a8..8deff919cba1 100644 --- a/signer/core/signed_data_test.go +++ b/signer/core/signed_data_test.go @@ -220,15 +220,29 @@ func TestSignData(t *testing.T) { if signature == nil || len(signature) != 65 { t.Errorf("Expected 65 byte signature (got %d bytes)", len(signature)) } - // data/typed + // data/typed via SignTypeData control.approveCh <- "Y" control.inputCh <- "a_long_password" - signature, err = api.SignTypedData(context.Background(), a, typedData) - if err != nil { + var want []byte + if signature, err = api.SignTypedData(context.Background(), a, typedData); err != nil { t.Fatal(err) + } else if signature == nil || len(signature) != 65 { + t.Errorf("Expected 65 byte signature (got %d bytes)", len(signature)) + } else { + want = signature } - if signature == nil || len(signature) != 65 { + + // data/typed via SignData / mimetype typed data + control.approveCh <- "Y" + control.inputCh <- "a_long_password" + if typedDataJson, err := json.Marshal(typedData); err != nil { + t.Fatal(err) + } else if signature, err = api.SignData(context.Background(), apitypes.DataTyped.Mime, a, hexutil.Encode(typedDataJson)); err != nil { + t.Fatal(err) + } else if signature == nil || len(signature) != 65 { t.Errorf("Expected 65 byte signature (got %d bytes)", len(signature)) + } else if have := signature; !bytes.Equal(have, want) { + t.Fatalf("want %x, have %x", want, have) } } diff --git a/tests/block_test_util.go b/tests/block_test_util.go index 313a703fae8f..5b200a60727c 100644 --- a/tests/block_test_util.go +++ b/tests/block_test_util.go @@ -107,10 +107,7 @@ func (t *BlockTest) Run(snapshotter bool) error { // import pre accounts & construct test genesis block & state root db := rawdb.NewMemoryDatabase() gspec := t.genesis(config) - gblock, err := gspec.Commit(db) - if err != nil { - return err - } + gblock := gspec.MustCommit(db) if gblock.Hash() != t.json.Genesis.Hash { return fmt.Errorf("genesis block hash doesn't match test: computed=%x, test=%x", gblock.Hash().Bytes()[:6], t.json.Genesis.Hash[:6]) } diff --git a/tests/evm-benchmarks b/tests/evm-benchmarks index 849b3e239a28..d8b88f4046a8 160000 --- a/tests/evm-benchmarks +++ b/tests/evm-benchmarks @@ -1 +1 @@ -Subproject commit 849b3e239a28f236dc99574b2e10e0c720895105 +Subproject commit d8b88f4046a87d6b902378cef752591f95427b43 diff --git a/tests/fuzzers/stacktrie/trie_fuzzer.go b/tests/fuzzers/stacktrie/trie_fuzzer.go index 6a95a1804c81..3af16bf81df7 100644 --- a/tests/fuzzers/stacktrie/trie_fuzzer.go +++ b/tests/fuzzers/stacktrie/trie_fuzzer.go @@ -25,6 +25,9 @@ import ( "io" "sort" + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/rawdb" + "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/trie" "golang.org/x/crypto/sha3" @@ -143,11 +146,14 @@ func Debug(data []byte) int { func (f *fuzzer) fuzz() int { // This spongeDb is used to check the sequence of disk-db-writes var ( - spongeA = &spongeDb{sponge: sha3.NewLegacyKeccak256()} - dbA = trie.NewDatabase(spongeA) - trieA = trie.NewEmpty(dbA) - spongeB = &spongeDb{sponge: sha3.NewLegacyKeccak256()} - trieB = trie.NewStackTrie(spongeB) + spongeA = &spongeDb{sponge: sha3.NewLegacyKeccak256()} + dbA = trie.NewDatabase(rawdb.NewDatabase(spongeA)) + trieA = trie.NewEmpty(dbA) + spongeB = &spongeDb{sponge: sha3.NewLegacyKeccak256()} + dbB = trie.NewDatabase(rawdb.NewDatabase(spongeB)) + trieB = trie.NewStackTrie(func(owner common.Hash, path []byte, hash common.Hash, blob []byte) { + dbB.Scheme().WriteTrieNode(spongeB, owner, path, hash, blob) + }) vals kvs useful bool maxElements = 10000 @@ -206,5 +212,48 @@ func (f *fuzzer) fuzz() int { if !bytes.Equal(sumA, sumB) { panic(fmt.Sprintf("sequence differ: (trie) %x != %x (stacktrie)", sumA, sumB)) } + + // Ensure all the nodes are persisted correctly + var ( + nodeset = make(map[string][]byte) // path -> blob + trieC = trie.NewStackTrie(func(owner common.Hash, path []byte, hash common.Hash, blob []byte) { + if crypto.Keccak256Hash(blob) != hash { + panic("invalid node blob") + } + if owner != (common.Hash{}) { + panic("invalid node owner") + } + nodeset[string(path)] = common.CopyBytes(blob) + }) + checked int + ) + for _, kv := range vals { + trieC.Update(kv.k, kv.v) + } + rootC, _ := trieC.Commit() + if rootA != rootC { + panic(fmt.Sprintf("roots differ: (trie) %x != %x (stacktrie)", rootA, rootC)) + } + trieA, _ = trie.New(trie.TrieID(rootA), dbA) + iterA := trieA.NodeIterator(nil) + for iterA.Next(true) { + if iterA.Hash() == (common.Hash{}) { + if _, present := nodeset[string(iterA.Path())]; present { + panic("unexpected tiny node") + } + continue + } + nodeBlob, present := nodeset[string(iterA.Path())] + if !present { + panic("missing node") + } + if !bytes.Equal(nodeBlob, iterA.NodeBlob()) { + panic("node blob is not matched") + } + checked += 1 + } + if checked != len(nodeset) { + panic("node number is not matched") + } return 1 } diff --git a/tests/fuzzers/trie/trie-fuzzer.go b/tests/fuzzers/trie/trie-fuzzer.go index 3cb07dff98e9..85a73c675589 100644 --- a/tests/fuzzers/trie/trie-fuzzer.go +++ b/tests/fuzzers/trie/trie-fuzzer.go @@ -21,7 +21,7 @@ import ( "encoding/binary" "fmt" - "github.com/ethereum/go-ethereum/ethdb/memorydb" + "github.com/ethereum/go-ethereum/core/rawdb" "github.com/ethereum/go-ethereum/trie" ) @@ -139,7 +139,7 @@ func Fuzz(input []byte) int { } func runRandTest(rt randTest) error { - triedb := trie.NewDatabase(memorydb.New()) + triedb := trie.NewDatabase(rawdb.NewMemoryDatabase()) tr := trie.NewEmpty(triedb) values := make(map[string]string) // tracks content of the trie diff --git a/trie/database.go b/trie/database.go index 76ca188add9c..469c33fc84dd 100644 --- a/trie/database.go +++ b/trie/database.go @@ -68,7 +68,7 @@ var ( // behind this split design is to provide read access to RPC handlers and sync // servers even while the trie is executing expensive garbage collection. type Database struct { - diskdb ethdb.KeyValueStore // Persistent storage for matured trie nodes + diskdb ethdb.Database // Persistent storage for matured trie nodes cleans *fastcache.Cache // GC friendly memory cache of clean node RLPs dirties map[common.Hash]*cachedNode // Data and references relationships of dirty trie nodes @@ -273,14 +273,14 @@ type Config struct { // NewDatabase creates a new trie database to store ephemeral trie content before // its written out to disk or garbage collected. No read cache is created, so all // data retrievals will hit the underlying disk database. -func NewDatabase(diskdb ethdb.KeyValueStore) *Database { +func NewDatabase(diskdb ethdb.Database) *Database { return NewDatabaseWithConfig(diskdb, nil) } // NewDatabaseWithConfig creates a new trie database to store ephemeral trie content // before its written out to disk or garbage collected. It also acts as a read cache // for nodes loaded from disk. -func NewDatabaseWithConfig(diskdb ethdb.KeyValueStore, config *Config) *Database { +func NewDatabaseWithConfig(diskdb ethdb.Database, config *Config) *Database { var cleans *fastcache.Cache if config != nil && config.Cache > 0 { if config.Journal == "" { @@ -917,3 +917,8 @@ func (db *Database) CommitPreimages() error { } return db.preimages.commit(true) } + +// Scheme returns the node scheme used in the database. +func (db *Database) Scheme() NodeScheme { + return &hashScheme{} +} diff --git a/trie/database_test.go b/trie/database_test.go index 81c469500f98..54d752947672 100644 --- a/trie/database_test.go +++ b/trie/database_test.go @@ -20,13 +20,13 @@ import ( "testing" "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/ethdb/memorydb" + "github.com/ethereum/go-ethereum/core/rawdb" ) // Tests that the trie database returns a missing trie node error if attempting // to retrieve the meta root. func TestDatabaseMetarootFetch(t *testing.T) { - db := NewDatabase(memorydb.New()) + db := NewDatabase(rawdb.NewMemoryDatabase()) if _, err := db.Node(common.Hash{}); err == nil { t.Fatalf("metaroot retrieval succeeded") } diff --git a/trie/iterator_test.go b/trie/iterator_test.go index 74b87a25c233..2664dab2d265 100644 --- a/trie/iterator_test.go +++ b/trie/iterator_test.go @@ -327,7 +327,7 @@ func TestIteratorContinueAfterErrorDisk(t *testing.T) { testIteratorContinueA func TestIteratorContinueAfterErrorMemonly(t *testing.T) { testIteratorContinueAfterError(t, true) } func testIteratorContinueAfterError(t *testing.T, memonly bool) { - diskdb := memorydb.New() + diskdb := rawdb.NewMemoryDatabase() triedb := NewDatabase(diskdb) tr := NewEmpty(triedb) @@ -419,7 +419,7 @@ func TestIteratorContinueAfterSeekErrorMemonly(t *testing.T) { func testIteratorContinueAfterSeekError(t *testing.T, memonly bool) { // Commit test trie to db, then remove the node containing "bars". - diskdb := memorydb.New() + diskdb := rawdb.NewMemoryDatabase() triedb := NewDatabase(diskdb) ctr := NewEmpty(triedb) @@ -532,7 +532,7 @@ func (l *loggingDb) Close() error { func makeLargeTestTrie() (*Database, *StateTrie, *loggingDb) { // Create an empty trie logDb := &loggingDb{0, memorydb.New()} - triedb := NewDatabase(logDb) + triedb := NewDatabase(rawdb.NewDatabase(logDb)) trie, _ := NewStateTrie(TrieID(common.Hash{}), triedb) // Fill it with some arbitrary data @@ -567,7 +567,7 @@ func TestNodeIteratorLargeTrie(t *testing.T) { func TestIteratorNodeBlob(t *testing.T) { var ( - db = memorydb.New() + db = rawdb.NewMemoryDatabase() triedb = NewDatabase(db) trie = NewEmpty(triedb) ) diff --git a/trie/schema.go b/trie/schema.go new file mode 100644 index 000000000000..ed049faa5ce0 --- /dev/null +++ b/trie/schema.go @@ -0,0 +1,96 @@ +// Copyright 2021 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package trie + +import ( + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/rawdb" + "github.com/ethereum/go-ethereum/ethdb" +) + +const ( + HashScheme = "hashScheme" // Identifier of hash based node scheme + + // Path-based scheme will be introduced in the following PRs. + // PathScheme = "pathScheme" // Identifier of path based node scheme +) + +// NodeScheme describes the scheme for interacting nodes in disk. +type NodeScheme interface { + // Name returns the identifier of node scheme. + Name() string + + // HasTrieNode checks the trie node presence with the provided node info and + // the associated node hash. + HasTrieNode(db ethdb.KeyValueReader, owner common.Hash, path []byte, hash common.Hash) bool + + // ReadTrieNode retrieves the trie node from database with the provided node + // info and the associated node hash. + ReadTrieNode(db ethdb.KeyValueReader, owner common.Hash, path []byte, hash common.Hash) []byte + + // WriteTrieNode writes the trie node into database with the provided node + // info and associated node hash. + WriteTrieNode(db ethdb.KeyValueWriter, owner common.Hash, path []byte, hash common.Hash, node []byte) + + // DeleteTrieNode deletes the trie node from database with the provided node + // info and associated node hash. + DeleteTrieNode(db ethdb.KeyValueWriter, owner common.Hash, path []byte, hash common.Hash) + + // IsTrieNode returns an indicator if the given database key is the key of + // trie node according to the scheme. + IsTrieNode(key []byte) (bool, []byte) +} + +type hashScheme struct{} + +// Name returns the identifier of hash based scheme. +func (scheme *hashScheme) Name() string { + return HashScheme +} + +// HasTrieNode checks the trie node presence with the provided node info and +// the associated node hash. +func (scheme *hashScheme) HasTrieNode(db ethdb.KeyValueReader, owner common.Hash, path []byte, hash common.Hash) bool { + return rawdb.HasTrieNode(db, hash) +} + +// ReadTrieNode retrieves the trie node from database with the provided node info +// and associated node hash. +func (scheme *hashScheme) ReadTrieNode(db ethdb.KeyValueReader, owner common.Hash, path []byte, hash common.Hash) []byte { + return rawdb.ReadTrieNode(db, hash) +} + +// WriteTrieNode writes the trie node into database with the provided node info +// and associated node hash. +func (scheme *hashScheme) WriteTrieNode(db ethdb.KeyValueWriter, owner common.Hash, path []byte, hash common.Hash, node []byte) { + rawdb.WriteTrieNode(db, hash, node) +} + +// DeleteTrieNode deletes the trie node from database with the provided node info +// and associated node hash. +func (scheme *hashScheme) DeleteTrieNode(db ethdb.KeyValueWriter, owner common.Hash, path []byte, hash common.Hash) { + rawdb.DeleteTrieNode(db, hash) +} + +// IsTrieNode returns an indicator if the given database key is the key of trie +// node according to the scheme. +func (scheme *hashScheme) IsTrieNode(key []byte) (bool, []byte) { + if len(key) == common.HashLength { + return true, key + } + return false, nil +} diff --git a/trie/secure_trie_test.go b/trie/secure_trie_test.go index ab8462607d99..24b8c5f095e0 100644 --- a/trie/secure_trie_test.go +++ b/trie/secure_trie_test.go @@ -24,19 +24,19 @@ import ( "testing" "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/rawdb" "github.com/ethereum/go-ethereum/crypto" - "github.com/ethereum/go-ethereum/ethdb/memorydb" ) func newEmptySecure() *StateTrie { - trie, _ := NewStateTrie(TrieID(common.Hash{}), NewDatabase(memorydb.New())) + trie, _ := NewStateTrie(TrieID(common.Hash{}), NewDatabase(rawdb.NewMemoryDatabase())) return trie } // makeTestStateTrie creates a large enough secure trie for testing. func makeTestStateTrie() (*Database, *StateTrie, map[string][]byte) { // Create an empty trie - triedb := NewDatabase(memorydb.New()) + triedb := NewDatabase(rawdb.NewMemoryDatabase()) trie, _ := NewStateTrie(TrieID(common.Hash{}), triedb) // Fill it with some arbitrary data diff --git a/trie/stacktrie.go b/trie/stacktrie.go index 2df2cd6ed016..fb8cc0d763e6 100644 --- a/trie/stacktrie.go +++ b/trie/stacktrie.go @@ -25,7 +25,6 @@ import ( "sync" "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/log" ) @@ -37,10 +36,14 @@ var stPool = sync.Pool{ }, } -func stackTrieFromPool(db ethdb.KeyValueWriter, owner common.Hash) *StackTrie { +// NodeWriteFunc is used to provide all information of a dirty node for committing +// so that callers can flush nodes into database with desired scheme. +type NodeWriteFunc = func(owner common.Hash, path []byte, hash common.Hash, blob []byte) + +func stackTrieFromPool(writeFn NodeWriteFunc, owner common.Hash) *StackTrie { st := stPool.Get().(*StackTrie) - st.db = db st.owner = owner + st.writeFn = writeFn return st } @@ -53,41 +56,41 @@ func returnToPool(st *StackTrie) { // in order. Once it determines that a subtree will no longer be inserted // into, it will hash it and free up the memory it uses. type StackTrie struct { - owner common.Hash // the owner of the trie - nodeType uint8 // node type (as in branch, ext, leaf) - val []byte // value contained by this node if it's a leaf - key []byte // key chunk covered by this (leaf|ext) node - children [16]*StackTrie // list of children (for branch and exts) - db ethdb.KeyValueWriter // Pointer to the commit db, can be nil + owner common.Hash // the owner of the trie + nodeType uint8 // node type (as in branch, ext, leaf) + val []byte // value contained by this node if it's a leaf + key []byte // key chunk covered by this (leaf|ext) node + children [16]*StackTrie // list of children (for branch and exts) + writeFn NodeWriteFunc // function for committing nodes, can be nil } // NewStackTrie allocates and initializes an empty trie. -func NewStackTrie(db ethdb.KeyValueWriter) *StackTrie { +func NewStackTrie(writeFn NodeWriteFunc) *StackTrie { return &StackTrie{ nodeType: emptyNode, - db: db, + writeFn: writeFn, } } // NewStackTrieWithOwner allocates and initializes an empty trie, but with // the additional owner field. -func NewStackTrieWithOwner(db ethdb.KeyValueWriter, owner common.Hash) *StackTrie { +func NewStackTrieWithOwner(writeFn NodeWriteFunc, owner common.Hash) *StackTrie { return &StackTrie{ owner: owner, nodeType: emptyNode, - db: db, + writeFn: writeFn, } } // NewFromBinary initialises a serialized stacktrie with the given db. -func NewFromBinary(data []byte, db ethdb.KeyValueWriter) (*StackTrie, error) { +func NewFromBinary(data []byte, writeFn NodeWriteFunc) (*StackTrie, error) { var st StackTrie if err := st.UnmarshalBinary(data); err != nil { return nil, err } // If a database is used, we need to recursively add it to every child - if db != nil { - st.setDb(db) + if writeFn != nil { + st.setWriter(writeFn) } return &st, nil } @@ -160,25 +163,25 @@ func (st *StackTrie) unmarshalBinary(r io.Reader) error { return nil } -func (st *StackTrie) setDb(db ethdb.KeyValueWriter) { - st.db = db +func (st *StackTrie) setWriter(writeFn NodeWriteFunc) { + st.writeFn = writeFn for _, child := range st.children { if child != nil { - child.setDb(db) + child.setWriter(writeFn) } } } -func newLeaf(owner common.Hash, key, val []byte, db ethdb.KeyValueWriter) *StackTrie { - st := stackTrieFromPool(db, owner) +func newLeaf(owner common.Hash, key, val []byte, writeFn NodeWriteFunc) *StackTrie { + st := stackTrieFromPool(writeFn, owner) st.nodeType = leafNode st.key = append(st.key, key...) st.val = val return st } -func newExt(owner common.Hash, key []byte, child *StackTrie, db ethdb.KeyValueWriter) *StackTrie { - st := stackTrieFromPool(db, owner) +func newExt(owner common.Hash, key []byte, child *StackTrie, writeFn NodeWriteFunc) *StackTrie { + st := stackTrieFromPool(writeFn, owner) st.nodeType = extNode st.key = append(st.key, key...) st.children[0] = child @@ -200,7 +203,7 @@ func (st *StackTrie) TryUpdate(key, value []byte) error { if len(value) == 0 { panic("deletion not supported") } - st.insert(k[:len(k)-1], value) + st.insert(k[:len(k)-1], value, nil) return nil } @@ -212,7 +215,7 @@ func (st *StackTrie) Update(key, value []byte) { func (st *StackTrie) Reset() { st.owner = common.Hash{} - st.db = nil + st.writeFn = nil st.key = st.key[:0] st.val = nil for i := range st.children { @@ -235,7 +238,7 @@ func (st *StackTrie) getDiffIndex(key []byte) int { // Helper function to that inserts a (key, value) pair into // the trie. -func (st *StackTrie) insert(key, value []byte) { +func (st *StackTrie) insert(key, value []byte, prefix []byte) { switch st.nodeType { case branchNode: /* Branch */ idx := int(key[0]) @@ -244,7 +247,7 @@ func (st *StackTrie) insert(key, value []byte) { for i := idx - 1; i >= 0; i-- { if st.children[i] != nil { if st.children[i].nodeType != hashedNode { - st.children[i].hash() + st.children[i].hash(append(prefix, byte(i))) } break } @@ -252,9 +255,9 @@ func (st *StackTrie) insert(key, value []byte) { // Add new child if st.children[idx] == nil { - st.children[idx] = newLeaf(st.owner, key[1:], value, st.db) + st.children[idx] = newLeaf(st.owner, key[1:], value, st.writeFn) } else { - st.children[idx].insert(key[1:], value) + st.children[idx].insert(key[1:], value, append(prefix, key[0])) } case extNode: /* Ext */ @@ -269,7 +272,7 @@ func (st *StackTrie) insert(key, value []byte) { if diffidx == len(st.key) { // Ext key and key segment are identical, recurse into // the child node. - st.children[0].insert(key[diffidx:], value) + st.children[0].insert(key[diffidx:], value, append(prefix, key[:diffidx]...)) return } // Save the original part. Depending if the break is @@ -278,14 +281,19 @@ func (st *StackTrie) insert(key, value []byte) { // node directly. var n *StackTrie if diffidx < len(st.key)-1 { - n = newExt(st.owner, st.key[diffidx+1:], st.children[0], st.db) + // Break on the non-last byte, insert an intermediate + // extension. The path prefix of the newly-inserted + // extension should also contain the different byte. + n = newExt(st.owner, st.key[diffidx+1:], st.children[0], st.writeFn) + n.hash(append(prefix, st.key[:diffidx+1]...)) } else { // Break on the last byte, no need to insert - // an extension node: reuse the current node + // an extension node: reuse the current node. + // The path prefix of the original part should + // still be same. n = st.children[0] + n.hash(append(prefix, st.key...)) } - // Convert to hash - n.hash() var p *StackTrie if diffidx == 0 { // the break is on the first byte, so @@ -298,12 +306,12 @@ func (st *StackTrie) insert(key, value []byte) { // the common prefix is at least one byte // long, insert a new intermediate branch // node. - st.children[0] = stackTrieFromPool(st.db, st.owner) + st.children[0] = stackTrieFromPool(st.writeFn, st.owner) st.children[0].nodeType = branchNode p = st.children[0] } // Create a leaf for the inserted part - o := newLeaf(st.owner, key[diffidx+1:], value, st.db) + o := newLeaf(st.owner, key[diffidx+1:], value, st.writeFn) // Insert both child leaves where they belong: origIdx := st.key[diffidx] @@ -339,7 +347,7 @@ func (st *StackTrie) insert(key, value []byte) { // Convert current node into an ext, // and insert a child branch node. st.nodeType = extNode - st.children[0] = NewStackTrieWithOwner(st.db, st.owner) + st.children[0] = NewStackTrieWithOwner(st.writeFn, st.owner) st.children[0].nodeType = branchNode p = st.children[0] } @@ -348,11 +356,11 @@ func (st *StackTrie) insert(key, value []byte) { // value and another containing the new value. The child leaf // is hashed directly in order to free up some memory. origIdx := st.key[diffidx] - p.children[origIdx] = newLeaf(st.owner, st.key[diffidx+1:], st.val, st.db) - p.children[origIdx].hash() + p.children[origIdx] = newLeaf(st.owner, st.key[diffidx+1:], st.val, st.writeFn) + p.children[origIdx].hash(append(prefix, st.key[:diffidx+1]...)) newIdx := key[diffidx] - p.children[newIdx] = newLeaf(st.owner, key[diffidx+1:], value, st.db) + p.children[newIdx] = newLeaf(st.owner, key[diffidx+1:], value, st.writeFn) // Finally, cut off the key part that has been passed // over to the children. @@ -383,14 +391,14 @@ func (st *StackTrie) insert(key, value []byte) { // - And the 'st.type' will be 'hashedNode' AGAIN // // This method also sets 'st.type' to hashedNode, and clears 'st.key'. -func (st *StackTrie) hash() { +func (st *StackTrie) hash(path []byte) { h := newHasher(false) defer returnHasherToPool(h) - st.hashRec(h) + st.hashRec(h, path) } -func (st *StackTrie) hashRec(hasher *hasher) { +func (st *StackTrie) hashRec(hasher *hasher, path []byte) { // The switch below sets this to the RLP-encoding of this node. var encodedNode []byte @@ -411,8 +419,7 @@ func (st *StackTrie) hashRec(hasher *hasher) { nodes[i] = nilValueNode continue } - - child.hashRec(hasher) + child.hashRec(hasher, append(path, byte(i))) if len(child.val) < 32 { nodes[i] = rawNode(child.val) } else { @@ -428,10 +435,9 @@ func (st *StackTrie) hashRec(hasher *hasher) { encodedNode = hasher.encodedBytes() case extNode: - st.children[0].hashRec(hasher) + st.children[0].hashRec(hasher, append(path, st.key...)) - sz := hexToCompactInPlace(st.key) - n := rawShortNode{Key: st.key[:sz]} + n := rawShortNode{Key: hexToCompact(st.key)} if len(st.children[0].val) < 32 { n.Val = rawNode(st.children[0].val) } else { @@ -447,8 +453,7 @@ func (st *StackTrie) hashRec(hasher *hasher) { case leafNode: st.key = append(st.key, byte(16)) - sz := hexToCompactInPlace(st.key) - n := rawShortNode{Key: st.key[:sz], Val: valueNode(st.val)} + n := rawShortNode{Key: hexToCompact(st.key), Val: valueNode(st.val)} n.encode(hasher.encbuf) encodedNode = hasher.encodedBytes() @@ -467,10 +472,8 @@ func (st *StackTrie) hashRec(hasher *hasher) { // Write the hash to the 'val'. We allocate a new val here to not mutate // input values st.val = hasher.hashData(encodedNode) - if st.db != nil { - // TODO! Is it safe to Put the slice here? - // Do all db implementations copy the value provided? - st.db.Put(st.val, encodedNode) + if st.writeFn != nil { + st.writeFn(st.owner, path, common.BytesToHash(st.val), encodedNode) } } @@ -479,12 +482,11 @@ func (st *StackTrie) Hash() (h common.Hash) { hasher := newHasher(false) defer returnHasherToPool(hasher) - st.hashRec(hasher) + st.hashRec(hasher, nil) if len(st.val) == 32 { copy(h[:], st.val) return h } - // If the node's RLP isn't 32 bytes long, the node will not // be hashed, and instead contain the rlp-encoding of the // node. For the top level node, we need to force the hashing. @@ -502,25 +504,24 @@ func (st *StackTrie) Hash() (h common.Hash) { // The associated database is expected, otherwise the whole commit // functionality should be disabled. func (st *StackTrie) Commit() (h common.Hash, err error) { - if st.db == nil { + if st.writeFn == nil { return common.Hash{}, ErrCommitDisabled } - hasher := newHasher(false) defer returnHasherToPool(hasher) - st.hashRec(hasher) + st.hashRec(hasher, nil) if len(st.val) == 32 { copy(h[:], st.val) return h, nil } - // If the node's RLP isn't 32 bytes long, the node will not - // be hashed (and committed), and instead contain the rlp-encoding of the + // be hashed (and committed), and instead contain the rlp-encoding of the // node. For the top level node, we need to force the hashing+commit. hasher.sha.Reset() hasher.sha.Write(st.val) hasher.sha.Read(h[:]) - st.db.Put(h[:], st.val) + + st.writeFn(st.owner, nil, h, st.val) return h, nil } diff --git a/trie/stacktrie_test.go b/trie/stacktrie_test.go index 069e4981d71a..215c97cfcdf7 100644 --- a/trie/stacktrie_test.go +++ b/trie/stacktrie_test.go @@ -22,8 +22,8 @@ import ( "testing" "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/rawdb" "github.com/ethereum/go-ethereum/crypto" - "github.com/ethereum/go-ethereum/ethdb/memorydb" ) func TestStackTrieInsertAndHash(t *testing.T) { @@ -188,7 +188,7 @@ func TestStackTrieInsertAndHash(t *testing.T) { func TestSizeBug(t *testing.T) { st := NewStackTrie(nil) - nt := NewEmpty(NewDatabase(memorydb.New())) + nt := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase())) leaf := common.FromHex("290decd9548b62a8d60345a988386fc84ba6bc95484008f6362f93160ef3e563") value := common.FromHex("94cf40d0d2b44f2b66e07cace1372ca42b73cf21a3") @@ -203,7 +203,7 @@ func TestSizeBug(t *testing.T) { func TestEmptyBug(t *testing.T) { st := NewStackTrie(nil) - nt := NewEmpty(NewDatabase(memorydb.New())) + nt := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase())) //leaf := common.FromHex("290decd9548b62a8d60345a988386fc84ba6bc95484008f6362f93160ef3e563") //value := common.FromHex("94cf40d0d2b44f2b66e07cace1372ca42b73cf21a3") @@ -229,7 +229,7 @@ func TestEmptyBug(t *testing.T) { func TestValLength56(t *testing.T) { st := NewStackTrie(nil) - nt := NewEmpty(NewDatabase(memorydb.New())) + nt := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase())) //leaf := common.FromHex("290decd9548b62a8d60345a988386fc84ba6bc95484008f6362f93160ef3e563") //value := common.FromHex("94cf40d0d2b44f2b66e07cace1372ca42b73cf21a3") @@ -254,7 +254,7 @@ func TestValLength56(t *testing.T) { // which causes a lot of node-within-node. This case was found via fuzzing. func TestUpdateSmallNodes(t *testing.T) { st := NewStackTrie(nil) - nt := NewEmpty(NewDatabase(memorydb.New())) + nt := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase())) kvs := []struct { K string @@ -283,7 +283,7 @@ func TestUpdateSmallNodes(t *testing.T) { func TestUpdateVariableKeys(t *testing.T) { t.SkipNow() st := NewStackTrie(nil) - nt := NewEmpty(NewDatabase(memorydb.New())) + nt := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase())) kvs := []struct { K string @@ -353,7 +353,7 @@ func TestStacktrieNotModifyValues(t *testing.T) { func TestStacktrieSerialization(t *testing.T) { var ( st = NewStackTrie(nil) - nt = NewEmpty(NewDatabase(memorydb.New())) + nt = NewEmpty(NewDatabase(rawdb.NewMemoryDatabase())) keyB = big.NewInt(1) keyDelta = big.NewInt(1) vals [][]byte diff --git a/trie/sync.go b/trie/sync.go index 31d3cbe91b9e..199766983577 100644 --- a/trie/sync.go +++ b/trie/sync.go @@ -64,7 +64,7 @@ type SyncPath [][]byte // version that can be sent over the network. func NewSyncPath(path []byte) SyncPath { // If the hash is from the account trie, append a single item, if it - // is from the a storage trie, append a tuple. Note, the length 64 is + // is from a storage trie, append a tuple. Note, the length 64 is // clashing between account leaf and storage root. It's fine though // because having a trie node at 64 depth means a hash collision was // found and we're long dead. @@ -74,6 +74,22 @@ func NewSyncPath(path []byte) SyncPath { return SyncPath{hexToKeybytes(path[:64]), hexToCompact(path[64:])} } +// LeafCallback is a callback type invoked when a trie operation reaches a leaf +// node. +// +// The keys is a path tuple identifying a particular trie node either in a single +// trie (account) or a layered trie (account -> storage). Each key in the tuple +// is in the raw format(32 bytes). +// +// The path is a composite hexary path identifying the trie node. All the key +// bytes are converted to the hexary nibbles and composited with the parent path +// if the trie node is in a layered trie. +// +// It's used by state sync and commit to allow handling external references +// between account and storage tries. And also it's used in the state healing +// for extracting the raw states(leaf nodes) with corresponding paths. +type LeafCallback func(keys [][]byte, path []byte, leaf []byte, parent common.Hash, parentPath []byte) error + // nodeRequest represents a scheduled or already in-flight trie node retrieval request. type nodeRequest struct { hash common.Hash // Hash of the trie node to retrieve @@ -139,6 +155,7 @@ func (batch *syncMemBatch) hasCode(hash common.Hash) bool { // unknown trie hashes to retrieve, accepts node data associated with said hashes // and reconstructs the trie step by step until all is done. type Sync struct { + scheme NodeScheme // Node scheme descriptor used in database. database ethdb.KeyValueReader // Persistent database to check for existing entries membatch *syncMemBatch // Memory buffer to avoid frequent database writes nodeReqs map[string]*nodeRequest // Pending requests pertaining to a trie node path @@ -148,8 +165,9 @@ type Sync struct { } // NewSync creates a new trie data download scheduler. -func NewSync(root common.Hash, database ethdb.KeyValueReader, callback LeafCallback) *Sync { +func NewSync(root common.Hash, database ethdb.KeyValueReader, callback LeafCallback, scheme NodeScheme) *Sync { ts := &Sync{ + scheme: scheme, database: database, membatch: newSyncMemBatch(), nodeReqs: make(map[string]*nodeRequest), @@ -172,7 +190,8 @@ func (s *Sync) AddSubTrie(root common.Hash, path []byte, parent common.Hash, par if s.membatch.hasNode(path) { return } - if rawdb.HasTrieNode(s.database, root) { + owner, inner := ResolvePath(path) + if s.scheme.HasTrieNode(s.database, owner, inner, root) { return } // Assemble the new sub-trie sync request @@ -205,7 +224,7 @@ func (s *Sync) AddCodeEntry(hash common.Hash, path []byte, parent common.Hash, p return } // If database says duplicate, the blob is present for sure. - // Note we only check the existence with new code scheme, fast + // Note we only check the existence with new code scheme, snap // sync is expected to run with a fresh new node. Even there // exists the code with legacy format, fetch and store with // new scheme anyway. @@ -329,7 +348,8 @@ func (s *Sync) ProcessNode(result NodeSyncResult) error { func (s *Sync) Commit(dbw ethdb.Batch) error { // Dump the membatch into a database dbw for path, value := range s.membatch.nodes { - rawdb.WriteTrieNode(dbw, s.membatch.hashes[path], value) + owner, inner := ResolvePath([]byte(path)) + s.scheme.WriteTrieNode(dbw, owner, inner, s.membatch.hashes[path], value) } for hash, value := range s.membatch.codes { rawdb.WriteCode(dbw, hash, value) @@ -450,8 +470,11 @@ func (s *Sync) children(req *nodeRequest, object node) ([]*nodeRequest, error) { // If database says duplicate, then at least the trie node is present // and we hold the assumption that it's NOT legacy contract code. - chash := common.BytesToHash(node) - if rawdb.HasTrieNode(s.database, chash) { + var ( + chash = common.BytesToHash(node) + owner, inner = ResolvePath(child.path) + ) + if s.scheme.HasTrieNode(s.database, owner, inner, chash) { return } // Locally unknown node, schedule for retrieval @@ -525,3 +548,14 @@ func (s *Sync) commitCodeRequest(req *codeRequest) error { } return nil } + +// ResolvePath resolves the provided composite node path by separating the +// path in account trie if it's existent. +func ResolvePath(path []byte) (common.Hash, []byte) { + var owner common.Hash + if len(path) >= 2*common.HashLength { + owner = common.BytesToHash(hexToKeybytes(path[:2*common.HashLength])) + path = path[2*common.HashLength:] + } + return owner, path +} diff --git a/trie/sync_test.go b/trie/sync_test.go index a02527855300..821f7cdf4dc4 100644 --- a/trie/sync_test.go +++ b/trie/sync_test.go @@ -22,6 +22,7 @@ import ( "testing" "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/rawdb" "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/ethdb/memorydb" ) @@ -29,7 +30,7 @@ import ( // makeTestTrie create a sample test trie to test node-wise reconstruction. func makeTestTrie() (*Database, *StateTrie, map[string][]byte) { // Create an empty trie - triedb := NewDatabase(memorydb.New()) + triedb := NewDatabase(rawdb.NewMemoryDatabase()) trie, _ := NewStateTrie(TrieID(common.Hash{}), triedb) // Fill it with some arbitrary data @@ -103,13 +104,13 @@ type trieElement struct { // Tests that an empty trie is not scheduled for syncing. func TestEmptySync(t *testing.T) { - dbA := NewDatabase(memorydb.New()) - dbB := NewDatabase(memorydb.New()) + dbA := NewDatabase(rawdb.NewMemoryDatabase()) + dbB := NewDatabase(rawdb.NewMemoryDatabase()) emptyA, _ := New(TrieID(common.Hash{}), dbA) emptyB, _ := New(TrieID(emptyRoot), dbB) for i, trie := range []*Trie{emptyA, emptyB} { - sync := NewSync(trie.Hash(), memorydb.New(), nil) + sync := NewSync(trie.Hash(), memorydb.New(), nil, []*Database{dbA, dbB}[i].Scheme()) if paths, nodes, codes := sync.Missing(1); len(paths) != 0 || len(nodes) != 0 || len(codes) != 0 { t.Errorf("test %d: content requested for empty trie: %v, %v, %v", i, paths, nodes, codes) } @@ -128,9 +129,9 @@ func testIterativeSync(t *testing.T, count int, bypath bool) { srcDb, srcTrie, srcData := makeTestTrie() // Create a destination trie and sync with the scheduler - diskdb := memorydb.New() + diskdb := rawdb.NewMemoryDatabase() triedb := NewDatabase(diskdb) - sched := NewSync(srcTrie.Hash(), diskdb, nil) + sched := NewSync(srcTrie.Hash(), diskdb, nil, srcDb.Scheme()) // The code requests are ignored here since there is no code // at the testing trie. @@ -194,9 +195,9 @@ func TestIterativeDelayedSync(t *testing.T) { srcDb, srcTrie, srcData := makeTestTrie() // Create a destination trie and sync with the scheduler - diskdb := memorydb.New() + diskdb := rawdb.NewMemoryDatabase() triedb := NewDatabase(diskdb) - sched := NewSync(srcTrie.Hash(), diskdb, nil) + sched := NewSync(srcTrie.Hash(), diskdb, nil, srcDb.Scheme()) // The code requests are ignored here since there is no code // at the testing trie. @@ -255,9 +256,9 @@ func testIterativeRandomSync(t *testing.T, count int) { srcDb, srcTrie, srcData := makeTestTrie() // Create a destination trie and sync with the scheduler - diskdb := memorydb.New() + diskdb := rawdb.NewMemoryDatabase() triedb := NewDatabase(diskdb) - sched := NewSync(srcTrie.Hash(), diskdb, nil) + sched := NewSync(srcTrie.Hash(), diskdb, nil, srcDb.Scheme()) // The code requests are ignored here since there is no code // at the testing trie. @@ -313,9 +314,9 @@ func TestIterativeRandomDelayedSync(t *testing.T) { srcDb, srcTrie, srcData := makeTestTrie() // Create a destination trie and sync with the scheduler - diskdb := memorydb.New() + diskdb := rawdb.NewMemoryDatabase() triedb := NewDatabase(diskdb) - sched := NewSync(srcTrie.Hash(), diskdb, nil) + sched := NewSync(srcTrie.Hash(), diskdb, nil, srcDb.Scheme()) // The code requests are ignored here since there is no code // at the testing trie. @@ -376,9 +377,9 @@ func TestDuplicateAvoidanceSync(t *testing.T) { srcDb, srcTrie, srcData := makeTestTrie() // Create a destination trie and sync with the scheduler - diskdb := memorydb.New() + diskdb := rawdb.NewMemoryDatabase() triedb := NewDatabase(diskdb) - sched := NewSync(srcTrie.Hash(), diskdb, nil) + sched := NewSync(srcTrie.Hash(), diskdb, nil, srcDb.Scheme()) // The code requests are ignored here since there is no code // at the testing trie. @@ -439,9 +440,9 @@ func TestIncompleteSync(t *testing.T) { srcDb, srcTrie, _ := makeTestTrie() // Create a destination trie and sync with the scheduler - diskdb := memorydb.New() + diskdb := rawdb.NewMemoryDatabase() triedb := NewDatabase(diskdb) - sched := NewSync(srcTrie.Hash(), diskdb, nil) + sched := NewSync(srcTrie.Hash(), diskdb, nil, srcDb.Scheme()) // The code requests are ignored here since there is no code // at the testing trie. @@ -519,9 +520,9 @@ func TestSyncOrdering(t *testing.T) { srcDb, srcTrie, srcData := makeTestTrie() // Create a destination trie and sync with the scheduler, tracking the requests - diskdb := memorydb.New() + diskdb := rawdb.NewMemoryDatabase() triedb := NewDatabase(diskdb) - sched := NewSync(srcTrie.Hash(), diskdb, nil) + sched := NewSync(srcTrie.Hash(), diskdb, nil, srcDb.Scheme()) // The code requests are ignored here since there is no code // at the testing trie. diff --git a/trie/trie.go b/trie/trie.go index bec6a1cc7891..abc63f46749a 100644 --- a/trie/trie.go +++ b/trie/trie.go @@ -35,22 +35,6 @@ var ( emptyState = crypto.Keccak256Hash(nil) ) -// LeafCallback is a callback type invoked when a trie operation reaches a leaf -// node. -// -// The keys is a path tuple identifying a particular trie node either in a single -// trie (account) or a layered trie (account -> storage). Each key in the tuple -// is in the raw format(32 bytes). -// -// The path is a composite hexary path identifying the trie node. All the key -// bytes are converted to the hexary nibbles and composited with the parent path -// if the trie node is in a layered trie. -// -// It's used by state sync and commit to allow handling external references -// between account and storage tries. And also it's used in the state healing -// for extracting the raw states(leaf nodes) with corresponding paths. -type LeafCallback func(keys [][]byte, path []byte, leaf []byte, parent common.Hash, parentPath []byte) error - // Trie is a Merkle Patricia Trie. Use New to create a trie that sits on // top of a database. Whenever trie performs a commit operation, the generated // nodes will be gathered and returned in a set. Once the trie is committed, diff --git a/trie/trie_test.go b/trie/trie_test.go index 832546b1e344..76307ba78686 100644 --- a/trie/trie_test.go +++ b/trie/trie_test.go @@ -34,7 +34,6 @@ import ( "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/ethdb" - "github.com/ethereum/go-ethereum/ethdb/memorydb" "github.com/ethereum/go-ethereum/rlp" "golang.org/x/crypto/sha3" ) @@ -65,7 +64,7 @@ func TestNull(t *testing.T) { func TestMissingRoot(t *testing.T) { root := common.HexToHash("0beec7b5ea3f0fdbc95d0dd47f3c5bc275da8a33") - trie, err := New(TrieID(root), NewDatabase(memorydb.New())) + trie, err := New(TrieID(root), NewDatabase(rawdb.NewMemoryDatabase())) if trie != nil { t.Error("New returned non-nil trie for invalid root") } @@ -78,7 +77,7 @@ func TestMissingNodeDisk(t *testing.T) { testMissingNode(t, false) } func TestMissingNodeMemonly(t *testing.T) { testMissingNode(t, true) } func testMissingNode(t *testing.T, memonly bool) { - diskdb := memorydb.New() + diskdb := rawdb.NewMemoryDatabase() triedb := NewDatabase(diskdb) trie := NewEmpty(triedb) @@ -414,7 +413,7 @@ func (randTest) Generate(r *rand.Rand, size int) reflect.Value { func runRandTest(rt randTest) bool { var ( - triedb = NewDatabase(memorydb.New()) + triedb = NewDatabase(rawdb.NewMemoryDatabase()) tr = NewEmpty(triedb) values = make(map[string]string) // tracks content of the trie origTrie = NewEmpty(triedb) @@ -811,7 +810,7 @@ func TestCommitSequence(t *testing.T) { addresses, accounts := makeAccounts(tc.count) // This spongeDb is used to check the sequence of disk-db-writes s := &spongeDb{sponge: sha3.NewLegacyKeccak256()} - db := NewDatabase(s) + db := NewDatabase(rawdb.NewDatabase(s)) trie := NewEmpty(db) // Another sponge is used to check the callback-sequence callbackSponge := sha3.NewLegacyKeccak256() @@ -854,7 +853,7 @@ func TestCommitSequenceRandomBlobs(t *testing.T) { prng := rand.New(rand.NewSource(int64(i))) // This spongeDb is used to check the sequence of disk-db-writes s := &spongeDb{sponge: sha3.NewLegacyKeccak256()} - db := NewDatabase(s) + db := NewDatabase(rawdb.NewDatabase(s)) trie := NewEmpty(db) // Another sponge is used to check the callback-sequence callbackSponge := sha3.NewLegacyKeccak256() @@ -894,11 +893,13 @@ func TestCommitSequenceStackTrie(t *testing.T) { prng := rand.New(rand.NewSource(int64(count))) // This spongeDb is used to check the sequence of disk-db-writes s := &spongeDb{sponge: sha3.NewLegacyKeccak256(), id: "a"} - db := NewDatabase(s) + db := NewDatabase(rawdb.NewDatabase(s)) trie := NewEmpty(db) // Another sponge is used for the stacktrie commits stackTrieSponge := &spongeDb{sponge: sha3.NewLegacyKeccak256(), id: "b"} - stTrie := NewStackTrie(stackTrieSponge) + stTrie := NewStackTrie(func(owner common.Hash, path []byte, hash common.Hash, blob []byte) { + db.Scheme().WriteTrieNode(stackTrieSponge, owner, path, hash, blob) + }) // Fill the trie with elements for i := 0; i < count; i++ { // For the stack trie, we need to do inserts in proper order @@ -951,11 +952,13 @@ func TestCommitSequenceStackTrie(t *testing.T) { // not fit into 32 bytes, rlp-encoded. However, it's still the correct thing to do. func TestCommitSequenceSmallRoot(t *testing.T) { s := &spongeDb{sponge: sha3.NewLegacyKeccak256(), id: "a"} - db := NewDatabase(s) + db := NewDatabase(rawdb.NewDatabase(s)) trie := NewEmpty(db) // Another sponge is used for the stacktrie commits stackTrieSponge := &spongeDb{sponge: sha3.NewLegacyKeccak256(), id: "b"} - stTrie := NewStackTrie(stackTrieSponge) + stTrie := NewStackTrie(func(owner common.Hash, path []byte, hash common.Hash, blob []byte) { + db.Scheme().WriteTrieNode(stackTrieSponge, owner, path, hash, blob) + }) // Add a single small-element to the trie(s) key := make([]byte, 5) key[0] = 1