diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml
index 61a28bcb777b..60341fcb805b 100644
--- a/.github/workflows/lint.yml
+++ b/.github/workflows/lint.yml
@@ -32,6 +32,29 @@ jobs:
       env:
         RUSTFLAGS: -D warnings

+  msrv:
+    name: MSRV / ${{ matrix.network }}
+    runs-on: ubuntu-latest
+    timeout-minutes: 30
+    strategy:
+      matrix:
+        include:
+          - binary: reth
+            network: ethereum
+          - binary: op-reth
+            network: optimism
+    steps:
+      - uses: actions/checkout@v4
+      - uses: dtolnay/rust-toolchain@master
+        with:
+          toolchain: "1.70" # MSRV
+      - uses: Swatinem/rust-cache@v2
+        with:
+          cache-on-failure: true
+      - run: cargo build --bin "${{ matrix.binary }}" --workspace --features "${{ matrix.network }}"
+        env:
+          RUSTFLAGS: -D warnings
+
   docs:
     name: docs
     runs-on: ubuntu-latest
diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml
index 7dd8e1fbdbda..945feefd7d01 100644
--- a/.github/workflows/release.yml
+++ b/.github/workflows/release.yml
@@ -185,7 +185,7 @@ jobs:
           )
           assets=()
           for asset in ./reth-*.tar.gz*; do
-              assets+=("-a" "$asset/$asset")
+              assets+=("$asset/$asset")
           done
          tag_name="${{ env.VERSION }}"
-          echo "$body" | gh release create --draft "${assets[@]}" -F "-" "$tag_name"
+          echo "$body" | gh release create --draft -F "-" "$tag_name" "${assets[@]}"
"ee27f32b5c5292967d2d4a9d7f1e0b0aed2c15daded5a60300e4abb9d8020bca" dependencies = [ "der", "digest 0.10.7", @@ -2133,15 +2119,16 @@ dependencies = [ [[package]] name = "ed25519-dalek" -version = "2.0.0" +version = "2.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7277392b266383ef8396db7fdeb1e77b6c52fed775f5df15bb24f35b72156980" +checksum = "1f628eaec48bfd21b865dc2950cfa014450c01d2fa2b69a86c2fd5844ec523c0" dependencies = [ "curve25519-dalek", "ed25519", "rand_core 0.6.4", "serde", "sha2", + "subtle", "zeroize", ] @@ -2159,7 +2146,7 @@ dependencies = [ [[package]] name = "ef-tests" -version = "0.1.0-alpha.10" +version = "0.1.0-alpha.11" dependencies = [ "alloy-rlp", "reth-db", @@ -2183,9 +2170,9 @@ checksum = "a26ae43d7bcc3b814de94796a5e736d4029efb0ee900c12e2d54c993ad1a1e07" [[package]] name = "elliptic-curve" -version = "0.13.6" +version = "0.13.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d97ca172ae9dc9f9b779a6e3a65d308f2af74e5b8c921299075bdb4a0370e914" +checksum = "e9775b22bc152ad86a0cf23f0f348b884b26add12bf741e7ffc4d4ab2ab4d205" dependencies = [ "base16ct", "crypto-bigint", @@ -2303,9 +2290,9 @@ checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" [[package]] name = "errno" -version = "0.3.6" +version = "0.3.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7c18ee0ed65a5f1f81cac6b1d213b69c35fa47d4252ad41f1486dbd8226fe36e" +checksum = "f258a7194e7f7c2a7837a8913aeab7fd8c383457034fa20ce4dd3dcb813e8eb8" dependencies = [ "libc", "windows-sys 0.48.0", @@ -2383,9 +2370,9 @@ dependencies = [ [[package]] name = "ethers-contract" -version = "2.0.10" +version = "2.0.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d79269278125006bb0552349c03593ffa9702112ca88bc7046cc669f148fb47c" +checksum = "0111ead599d17a7bff6985fd5756f39ca7033edc79a31b23026a8d5d64fa95cd" dependencies = [ "const-hex", "ethers-contract-abigen", @@ -2402,9 +2389,9 @@ dependencies = [ [[package]] name = "ethers-contract-abigen" -version = "2.0.10" +version = "2.0.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ce95a43c939b2e4e2f3191c5ad4a1f279780b8a39139c9905b43a7433531e2ab" +checksum = "51258120c6b47ea9d9bec0d90f9e8af71c977fbefbef8213c91bfed385fe45eb" dependencies = [ "Inflector", "const-hex", @@ -2418,15 +2405,15 @@ dependencies = [ "serde", "serde_json", "syn 2.0.39", - "toml 0.7.8", + "toml 0.8.8", "walkdir", ] [[package]] name = "ethers-contract-derive" -version = "2.0.10" +version = "2.0.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8e9ce44906fc871b3ee8c69a695ca7ec7f70e50cb379c9b9cb5e532269e492f6" +checksum = "936e7a0f1197cee2b62dc89f63eff3201dbf87c283ff7e18d86d38f83b845483" dependencies = [ "Inflector", "const-hex", @@ -2440,13 +2427,13 @@ dependencies = [ [[package]] name = "ethers-core" -version = "2.0.10" +version = "2.0.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c0a17f0708692024db9956b31d7a20163607d2745953f5ae8125ab368ba280ad" +checksum = "2f03e0bdc216eeb9e355b90cf610ef6c5bb8aca631f97b5ae9980ce34ea7878d" dependencies = [ "arrayvec", "bytes", - "cargo_metadata 0.17.0", + "cargo_metadata", "chrono", "const-hex", "elliptic-curve", @@ -2468,32 +2455,16 @@ dependencies = [ "unicode-xid", ] -[[package]] -name = "ethers-etherscan" -version = "2.0.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"0e53451ea4a8128fbce33966da71132cf9e1040dcfd2a2084fd7733ada7b2045" -dependencies = [ - "ethers-core", - "reqwest", - "semver 1.0.20", - "serde", - "serde_json", - "thiserror", - "tracing", -] - [[package]] name = "ethers-middleware" -version = "2.0.10" +version = "2.0.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "473f1ccd0c793871bbc248729fa8df7e6d2981d6226e4343e3bbaa9281074d5d" +checksum = "681ece6eb1d10f7cf4f873059a77c04ff1de4f35c63dd7bccde8f438374fcb93" dependencies = [ "async-trait", "auto_impl", "ethers-contract", "ethers-core", - "ethers-etherscan", "ethers-providers", "ethers-signers", "futures-channel", @@ -2512,9 +2483,9 @@ dependencies = [ [[package]] name = "ethers-providers" -version = "2.0.10" +version = "2.0.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6838fa110e57d572336178b7c79e94ff88ef976306852d8cb87d9e5b1fc7c0b5" +checksum = "25d6c0c9455d93d4990c06e049abf9b30daf148cf461ee939c11d88907c60816" dependencies = [ "async-trait", "auto_impl", @@ -2550,9 +2521,9 @@ dependencies = [ [[package]] name = "ethers-signers" -version = "2.0.10" +version = "2.0.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5ea44bec930f12292866166f9ddbea6aa76304850e4d8dcd66dc492b43d00ff1" +checksum = "0cb1b714e227bbd2d8c53528adb580b203009728b17d0d0e4119353aa9bc5532" dependencies = [ "async-trait", "coins-bip32", @@ -2598,9 +2569,9 @@ dependencies = [ [[package]] name = "eyre" -version = "0.6.8" +version = "0.6.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4c2b6b5a29c02cdc822728b7d7b8ae1bab3e3b05d44522770ddd49722eeac7eb" +checksum = "80f656be11ddf91bd709454d15d5bd896fbaf4cc3314e69349e4d1569f5b46cd" dependencies = [ "indenter", "once_cell", @@ -2659,9 +2630,9 @@ dependencies = [ [[package]] name = "fiat-crypto" -version = "0.2.3" +version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f69037fe1b785e84986b4f2cbcf647381876a00671d25ceef715d7812dd7e1dd" +checksum = "27573eac26f4dd11e2b1916c3fe1baa56407c83c71a773a8ba17ec0bca03b6b7" [[package]] name = "findshlibs" @@ -2995,9 +2966,9 @@ dependencies = [ [[package]] name = "h2" -version = "0.3.21" +version = "0.3.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "91fc23aa11be92976ef4729127f1a74adf36d8436f7816b185d18df956790833" +checksum = "4d6250322ef6e60f93f9a2162799302cd6f68f79f6e5d85c8c16f14d1d958178" dependencies = [ "bytes", "fnv", @@ -3005,7 +2976,7 @@ dependencies = [ "futures-sink", "futures-util", "http", - "indexmap 1.9.3", + "indexmap 2.1.0", "slab", "tokio", "tokio-util", @@ -3079,9 +3050,9 @@ dependencies = [ [[package]] name = "hdrhistogram" -version = "7.5.3" +version = "7.5.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a5b38e5c02b7c7be48c8dc5217c4f1634af2ea221caae2e024bffc7a7651c691" +checksum = "765c9198f173dd59ce26ff9f95ef0aafd0a0fe01fb9d72841bc5066a4c06511d" dependencies = [ "byteorder", "num-traits", @@ -3168,9 +3139,9 @@ dependencies = [ [[package]] name = "http" -version = "0.2.10" +version = "0.2.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f95b9abcae896730d42b78e09c155ed4ddf82c07b4de772c64aee5b2d8b7c150" +checksum = "8947b1a6fad4393052c7ba1f4cd97bed3e953a95c79c92ad9b051a04611d9fbb" dependencies = [ "bytes", "fnv", @@ -3927,9 +3898,9 @@ dependencies = [ [[package]] name = "k256" -version = "0.13.1" +version = "0.13.2" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "cadb76004ed8e97623117f3df85b17aaa6626ab0b0831e6573f104df16cd1bcc" +checksum = "3f01b677d82ef7a676aa37e099defd83a28e15687112cafdd112d60236b6115b" dependencies = [ "cfg-if", "ecdsa", @@ -4007,15 +3978,6 @@ dependencies = [ "redox_syscall 0.4.1", ] -[[package]] -name = "lifetimed-bytes" -version = "0.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4c970c8ea4c7b023a41cfa4af4c785a16694604c2f2a3b0d1f20a9bcb73fa550" -dependencies = [ - "bytes", -] - [[package]] name = "linked-hash-map" version = "0.5.6" @@ -4039,9 +4001,9 @@ checksum = "969488b55f8ac402214f3f5fd243ebb7206cf82de60d3172994707a4bcc2b829" [[package]] name = "litemap" -version = "0.7.1" +version = "0.7.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "77a1a2647d5b7134127971a6de0d533c49de2159167e7f259c427195f87168a1" +checksum = "f9d642685b028806386b2b6e75685faadd3eb65a85fff7df711ce18446a422da" [[package]] name = "lock_api" @@ -5575,7 +5537,7 @@ dependencies = [ [[package]] name = "reth" -version = "0.1.0-alpha.10" +version = "0.1.0-alpha.11" dependencies = [ "alloy-rlp", "aquamarine", @@ -5652,7 +5614,7 @@ dependencies = [ [[package]] name = "reth-auto-seal-consensus" -version = "0.1.0-alpha.10" +version = "0.1.0-alpha.11" dependencies = [ "clap", "eyre", @@ -5675,7 +5637,7 @@ dependencies = [ [[package]] name = "reth-basic-payload-builder" -version = "0.1.0-alpha.10" +version = "0.1.0-alpha.11" dependencies = [ "alloy-rlp", "futures-core", @@ -5696,7 +5658,7 @@ dependencies = [ [[package]] name = "reth-beacon-consensus" -version = "0.1.0-alpha.10" +version = "0.1.0-alpha.11" dependencies = [ "assert_matches", "cfg-if", @@ -5729,7 +5691,7 @@ dependencies = [ [[package]] name = "reth-blockchain-tree" -version = "0.1.0-alpha.10" +version = "0.1.0-alpha.11" dependencies = [ "aquamarine", "assert_matches", @@ -5749,7 +5711,7 @@ dependencies = [ [[package]] name = "reth-codecs" -version = "0.1.0-alpha.10" +version = "0.1.0-alpha.11" dependencies = [ "arbitrary", "bytes", @@ -5764,7 +5726,7 @@ dependencies = [ [[package]] name = "reth-config" -version = "0.1.0-alpha.10" +version = "0.1.0-alpha.11" dependencies = [ "confy", "reth-discv4", @@ -5781,7 +5743,7 @@ dependencies = [ [[package]] name = "reth-consensus-common" -version = "0.1.0-alpha.10" +version = "0.1.0-alpha.11" dependencies = [ "assert_matches", "cfg-if", @@ -5793,7 +5755,7 @@ dependencies = [ [[package]] name = "reth-db" -version = "0.1.0-alpha.10" +version = "0.1.0-alpha.11" dependencies = [ "arbitrary", "assert_matches", @@ -5838,7 +5800,7 @@ dependencies = [ [[package]] name = "reth-discv4" -version = "0.1.0-alpha.10" +version = "0.1.0-alpha.11" dependencies = [ "alloy-rlp", "discv5", @@ -5861,7 +5823,7 @@ dependencies = [ [[package]] name = "reth-dns-discovery" -version = "0.1.0-alpha.10" +version = "0.1.0-alpha.11" dependencies = [ "alloy-rlp", "async-trait", @@ -5885,7 +5847,7 @@ dependencies = [ [[package]] name = "reth-downloaders" -version = "0.1.0-alpha.10" +version = "0.1.0-alpha.11" dependencies = [ "alloy-rlp", "assert_matches", @@ -5899,6 +5861,7 @@ dependencies = [ "reth-interfaces", "reth-metrics", "reth-primitives", + "reth-provider", "reth-tasks", "reth-tracing", "tempfile", @@ -5911,7 +5874,7 @@ dependencies = [ [[package]] name = "reth-ecies" -version = "0.1.0-alpha.10" +version = "0.1.0-alpha.11" dependencies = [ "aes 0.8.3", "alloy-rlp", @@ -5941,7 +5904,7 @@ dependencies = [ [[package]] name = "reth-eth-wire" -version = 
"0.1.0-alpha.10" +version = "0.1.0-alpha.11" dependencies = [ "alloy-rlp", "arbitrary", @@ -5973,7 +5936,7 @@ dependencies = [ [[package]] name = "reth-interfaces" -version = "0.1.0-alpha.10" +version = "0.1.0-alpha.11" dependencies = [ "arbitrary", "async-trait", @@ -6000,7 +5963,7 @@ dependencies = [ [[package]] name = "reth-ipc" -version = "0.1.0-alpha.10" +version = "0.1.0-alpha.11" dependencies = [ "async-trait", "bytes", @@ -6019,7 +5982,7 @@ dependencies = [ [[package]] name = "reth-libmdbx" -version = "0.1.0-alpha.10" +version = "0.1.0-alpha.11" dependencies = [ "bitflags 2.4.1", "byteorder", @@ -6027,7 +5990,6 @@ dependencies = [ "derive_more", "indexmap 2.1.0", "libc", - "lifetimed-bytes", "parking_lot 0.12.1", "pprof", "rand 0.8.5", @@ -6039,7 +6001,7 @@ dependencies = [ [[package]] name = "reth-mdbx-sys" -version = "0.1.0-alpha.10" +version = "0.1.0-alpha.11" dependencies = [ "bindgen 0.68.1", "cc", @@ -6048,7 +6010,7 @@ dependencies = [ [[package]] name = "reth-metrics" -version = "0.1.0-alpha.10" +version = "0.1.0-alpha.11" dependencies = [ "futures", "metrics", @@ -6059,7 +6021,7 @@ dependencies = [ [[package]] name = "reth-metrics-derive" -version = "0.1.0-alpha.10" +version = "0.1.0-alpha.11" dependencies = [ "metrics", "once_cell", @@ -6073,7 +6035,7 @@ dependencies = [ [[package]] name = "reth-net-common" -version = "0.1.0-alpha.10" +version = "0.1.0-alpha.11" dependencies = [ "pin-project", "reth-primitives", @@ -6082,7 +6044,7 @@ dependencies = [ [[package]] name = "reth-net-nat" -version = "0.1.0-alpha.10" +version = "0.1.0-alpha.11" dependencies = [ "igd", "pin-project-lite", @@ -6096,7 +6058,7 @@ dependencies = [ [[package]] name = "reth-network" -version = "0.1.0-alpha.10" +version = "0.1.0-alpha.11" dependencies = [ "alloy-rlp", "aquamarine", @@ -6146,7 +6108,7 @@ dependencies = [ [[package]] name = "reth-network-api" -version = "0.1.0-alpha.10" +version = "0.1.0-alpha.11" dependencies = [ "async-trait", "reth-discv4", @@ -6160,7 +6122,7 @@ dependencies = [ [[package]] name = "reth-nippy-jar" -version = "0.1.0-alpha.10" +version = "0.1.0-alpha.11" dependencies = [ "anyhow", "bincode", @@ -6180,7 +6142,7 @@ dependencies = [ [[package]] name = "reth-payload-builder" -version = "0.1.0-alpha.10" +version = "0.1.0-alpha.11" dependencies = [ "alloy-rlp", "futures-util", @@ -6202,7 +6164,7 @@ dependencies = [ [[package]] name = "reth-primitives" -version = "0.1.0-alpha.10" +version = "0.1.0-alpha.11" dependencies = [ "alloy-primitives", "alloy-rlp", @@ -6250,7 +6212,7 @@ dependencies = [ [[package]] name = "reth-provider" -version = "0.1.0-alpha.10" +version = "0.1.0-alpha.11" dependencies = [ "alloy-rlp", "assert_matches", @@ -6277,7 +6239,7 @@ dependencies = [ [[package]] name = "reth-prune" -version = "0.1.0-alpha.10" +version = "0.1.0-alpha.11" dependencies = [ "assert_matches", "itertools 0.11.0", @@ -6299,25 +6261,26 @@ dependencies = [ [[package]] name = "reth-revm" -version = "0.1.0-alpha.10" +version = "0.1.0-alpha.11" dependencies = [ "reth-consensus-common", "reth-interfaces", "reth-primitives", "reth-provider", "reth-revm-inspectors", + "reth-trie", "revm", "tracing", ] [[package]] name = "reth-revm-inspectors" -version = "0.1.0-alpha.10" +version = "0.1.0-alpha.11" dependencies = [ + "alloy-primitives", "alloy-sol-types", "boa_engine", "boa_gc", - "reth-primitives", "reth-rpc-types", "revm", "serde", @@ -6328,7 +6291,7 @@ dependencies = [ [[package]] name = "reth-rpc" -version = "0.1.0-alpha.10" +version = "0.1.0-alpha.11" dependencies = [ "alloy-dyn-abi", 
"alloy-primitives", @@ -6381,7 +6344,7 @@ dependencies = [ [[package]] name = "reth-rpc-api" -version = "0.1.0-alpha.10" +version = "0.1.0-alpha.11" dependencies = [ "jsonrpsee", "reth-primitives", @@ -6391,7 +6354,7 @@ dependencies = [ [[package]] name = "reth-rpc-api-testing-util" -version = "0.1.0-alpha.10" +version = "0.1.0-alpha.11" dependencies = [ "async-trait", "futures", @@ -6405,7 +6368,7 @@ dependencies = [ [[package]] name = "reth-rpc-builder" -version = "0.1.0-alpha.10" +version = "0.1.0-alpha.11" dependencies = [ "hyper", "jsonrpsee", @@ -6438,7 +6401,7 @@ dependencies = [ [[package]] name = "reth-rpc-engine-api" -version = "0.1.0-alpha.10" +version = "0.1.0-alpha.11" dependencies = [ "alloy-rlp", "assert_matches", @@ -6464,7 +6427,7 @@ dependencies = [ [[package]] name = "reth-rpc-types" -version = "0.1.0-alpha.10" +version = "0.1.0-alpha.11" dependencies = [ "alloy-primitives", "alloy-rlp", @@ -6486,7 +6449,7 @@ dependencies = [ [[package]] name = "reth-rpc-types-compat" -version = "0.1.0-alpha.10" +version = "0.1.0-alpha.11" dependencies = [ "alloy-rlp", "reth-primitives", @@ -6495,7 +6458,7 @@ dependencies = [ [[package]] name = "reth-snapshot" -version = "0.1.0-alpha.10" +version = "0.1.0-alpha.11" dependencies = [ "assert_matches", "clap", @@ -6513,12 +6476,13 @@ dependencies = [ [[package]] name = "reth-stages" -version = "0.1.0-alpha.10" +version = "0.1.0-alpha.11" dependencies = [ "alloy-rlp", "aquamarine", "assert_matches", "async-trait", + "auto_impl", "criterion", "futures-util", "itertools 0.11.0", @@ -6552,7 +6516,7 @@ dependencies = [ [[package]] name = "reth-tasks" -version = "0.1.0-alpha.10" +version = "0.1.0-alpha.11" dependencies = [ "dyn-clone", "futures-util", @@ -6566,7 +6530,7 @@ dependencies = [ [[package]] name = "reth-tokio-util" -version = "0.1.0-alpha.10" +version = "0.1.0-alpha.11" dependencies = [ "tokio", "tokio-stream", @@ -6574,7 +6538,7 @@ dependencies = [ [[package]] name = "reth-tracing" -version = "0.1.0-alpha.10" +version = "0.1.0-alpha.11" dependencies = [ "rolling-file", "tracing", @@ -6585,7 +6549,7 @@ dependencies = [ [[package]] name = "reth-transaction-pool" -version = "0.1.0-alpha.10" +version = "0.1.0-alpha.11" dependencies = [ "alloy-rlp", "aquamarine", @@ -6619,7 +6583,7 @@ dependencies = [ [[package]] name = "reth-trie" -version = "0.1.0-alpha.10" +version = "0.1.0-alpha.11" dependencies = [ "alloy-rlp", "auto_impl", @@ -6643,7 +6607,7 @@ dependencies = [ [[package]] name = "revm" version = "3.5.0" -source = "git+https://github.com/bluealloy/revm?rev=1609e07c68048909ad1682c98cf2b9baa76310b5#1609e07c68048909ad1682c98cf2b9baa76310b5" +source = "git+https://github.com/bluealloy/revm?branch=reth_freeze#74643d37fc6231d558868ccc8b97400506e10906" dependencies = [ "auto_impl", "revm-interpreter", @@ -6653,7 +6617,7 @@ dependencies = [ [[package]] name = "revm-interpreter" version = "1.3.0" -source = "git+https://github.com/bluealloy/revm?rev=1609e07c68048909ad1682c98cf2b9baa76310b5#1609e07c68048909ad1682c98cf2b9baa76310b5" +source = "git+https://github.com/bluealloy/revm?branch=reth_freeze#74643d37fc6231d558868ccc8b97400506e10906" dependencies = [ "revm-primitives", ] @@ -6661,7 +6625,7 @@ dependencies = [ [[package]] name = "revm-precompile" version = "2.2.0" -source = "git+https://github.com/bluealloy/revm?rev=1609e07c68048909ad1682c98cf2b9baa76310b5#1609e07c68048909ad1682c98cf2b9baa76310b5" +source = "git+https://github.com/bluealloy/revm?branch=reth_freeze#74643d37fc6231d558868ccc8b97400506e10906" dependencies = [ 
"aurora-engine-modexp", "c-kzg", @@ -6677,7 +6641,7 @@ dependencies = [ [[package]] name = "revm-primitives" version = "1.3.0" -source = "git+https://github.com/bluealloy/revm?rev=1609e07c68048909ad1682c98cf2b9baa76310b5#1609e07c68048909ad1682c98cf2b9baa76310b5" +source = "git+https://github.com/bluealloy/revm?branch=reth_freeze#74643d37fc6231d558868ccc8b97400506e10906" dependencies = [ "alloy-primitives", "alloy-rlp", @@ -6876,9 +6840,9 @@ dependencies = [ [[package]] name = "rustix" -version = "0.38.21" +version = "0.38.24" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2b426b0506e5d50a7d8dafcf2e81471400deb602392c7dd110815afb4eaf02a3" +checksum = "9ad981d6c340a49cdc40a1028d9c6084ec7e9fa33fcb839cab656a267071e234" dependencies = [ "bitflags 2.4.1", "errno", @@ -6889,9 +6853,9 @@ dependencies = [ [[package]] name = "rustls" -version = "0.21.8" +version = "0.21.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "446e14c5cda4f3f30fe71863c34ec70f5ac79d6087097ad0bb433e1be5edf04c" +checksum = "629648aced5775d558af50b2b4c7b02983a04b312126d45eeead26e7caa498b9" dependencies = [ "log", "ring 0.17.5", @@ -7402,9 +7366,9 @@ dependencies = [ [[package]] name = "signature" -version = "2.1.0" +version = "2.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5e1788eed21689f9cf370582dfc467ef36ed9c707f073528ddafa8d83e3b8500" +checksum = "77549399552de45a898a580c1b41d445bf730df867cc44e6c0233bbc4b8329de" dependencies = [ "digest 0.10.7", "rand_core 0.6.4", @@ -7653,9 +7617,9 @@ dependencies = [ [[package]] name = "symbolic-common" -version = "12.5.0" +version = "12.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6d3aa424281de488c1ddbaffb55a421ad87d04b0fdd5106e7e71d748c0c71ea6" +checksum = "39eac77836da383d35edbd9ff4585b4fc1109929ff641232f2e9a1aefdfc9e91" dependencies = [ "debugid", "memmap2 0.8.0", @@ -7665,9 +7629,9 @@ dependencies = [ [[package]] name = "symbolic-demangle" -version = "12.5.0" +version = "12.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9bdcf77effe2908a21c1011b4d49a7122e0f44487a6ad89db67c55a1687e2572" +checksum = "4ee1608a1d13061fb0e307a316de29f6c6e737b05459fe6bbf5dd8d7837c4fb7" dependencies = [ "cpp_demangle", "rustc-demangle", @@ -7762,9 +7726,9 @@ dependencies = [ [[package]] name = "termcolor" -version = "1.3.0" +version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6093bad37da69aab9d123a8091e4be0aa4a03e4d601ec641c327398315f62b64" +checksum = "ff1bc3d3f05aff0403e8ac0d92ced918ec05b666a43f83297ccef5bea8a3d449" dependencies = [ "winapi-util", ] @@ -7777,9 +7741,9 @@ checksum = "3369f5ac52d5eb6ab48c6b4ffdc8efbcad6b89c765749064ba298f2c68a16a76" [[package]] name = "test-fuzz" -version = "4.0.3" +version = "4.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "59bdd14ea6ac9fd993d966b0133da233f534bac0c1a44a2200cec1eb244c733c" +checksum = "de8cb3597f1463b9c98b21c08d11033166a57942e60e8044e7e3bb4a8ca5416b" dependencies = [ "serde", "test-fuzz-internal", @@ -7789,20 +7753,20 @@ dependencies = [ [[package]] name = "test-fuzz-internal" -version = "4.0.3" +version = "4.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7eb212edbf2406eed119bd5e1b89bf3201f3f9d9961b5ae39324873f2a0805ed" +checksum = "3dd8da182ee4e8b195da3aa38f72b84d267bda3874cd6ef8dd29c03a71f866f2" dependencies = [ "bincode", - "cargo_metadata 0.18.1", + "cargo_metadata", 
"serde", ] [[package]] name = "test-fuzz-macro" -version = "4.0.3" +version = "4.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b2f42720e86f42661bd88d7aaa9d041056530f79c1f0bc6ac90dfb681905e86" +checksum = "86cb030b9e51def5bd7bf98b3ee6e81aae7f021ebf2e05e70029b768508c376f" dependencies = [ "darling 0.20.3", "if_chain", @@ -7817,9 +7781,9 @@ dependencies = [ [[package]] name = "test-fuzz-runtime" -version = "4.0.3" +version = "4.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6e0aae6ea22e9e0730b79eac5cb7426dc257503d07ecedf7bd799598070908d1" +checksum = "dd6e7a964e6c5b20df8b03572f7fa43aa28d80fa4871b3083e597ed32664f614" dependencies = [ "hex", "num-traits", @@ -8048,18 +8012,6 @@ dependencies = [ "serde", ] -[[package]] -name = "toml" -version = "0.7.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dd79e69d3b627db300ff956027cc6c3798cef26d22526befdfcd12feeb6d2257" -dependencies = [ - "serde", - "serde_spanned", - "toml_datetime", - "toml_edit 0.19.15", -] - [[package]] name = "toml" version = "0.8.8" @@ -8088,8 +8040,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1b5bb770da30e5cbfde35a2d7b9b8a2c4b8ef89548a7a6aeab5c9a576e3e7421" dependencies = [ "indexmap 2.1.0", - "serde", - "serde_spanned", "toml_datetime", "winnow", ] @@ -8220,11 +8170,12 @@ dependencies = [ [[package]] name = "tracing-appender" -version = "0.2.2" +version = "0.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "09d48f71a791638519505cefafe162606f706c25592e4bde4d97600c0195312e" +checksum = "3566e8ce28cc0a3fe42519fc80e6b4c943cc4c8cef275620eb8dac2d3d4e06cf" dependencies = [ "crossbeam-channel", + "thiserror", "time", "tracing-subscriber", ] @@ -8275,9 +8226,9 @@ dependencies = [ [[package]] name = "tracing-log" -version = "0.1.4" +version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f751112709b4e791d8ce53e32c4ed2d353565a795ce84da2285393f41557bdf2" +checksum = "ee855f1f400bd0e5c02d150ae5de3840039a3f54b025156404e34c23c03f47c3" dependencies = [ "log", "once_cell", @@ -8286,9 +8237,9 @@ dependencies = [ [[package]] name = "tracing-subscriber" -version = "0.3.17" +version = "0.3.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "30a651bc37f915e81f087d86e62a18eec5f79550c7faff886f7090b4ea757c77" +checksum = "ad0f048c97dbd9faa9b7df56362b8ebcaa52adb06b498c050d2f4e32f90a7a8b" dependencies = [ "matchers", "nu-ansi-term", @@ -9014,9 +8965,9 @@ checksum = "d1890f4022759daae28ed4fe62859b1236caebfc61ede2f63ed4e695f3f6d936" [[package]] name = "writeable" -version = "0.5.3" +version = "0.5.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c0af0c3d13faebf8dda0b5256fa7096a2d5ccb662f7b9f54a40fe201077ab1c2" +checksum = "dad7bb64b8ef9c0aa27b6da38b452b0ee9fd82beaf276a87dd796fb55cbae14e" [[package]] name = "ws_stream_wasm" @@ -9078,9 +9029,9 @@ checksum = "09041cd90cf85f7f8b2df60c646f853b7f535ce68f85244eb6731cf89fa498ec" [[package]] name = "yoke" -version = "0.7.2" +version = "0.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "61e38c508604d6bbbd292dadb3c02559aa7fff6b654a078a36217cad871636e4" +checksum = "65e71b2e4f287f467794c671e2b8f8a5f3716b3c829079a1c44740148eff07e4" dependencies = [ "serde", "stable_deref_trait", @@ -9090,9 +9041,9 @@ dependencies = [ [[package]] name = "yoke-derive" -version = "0.7.2" +version = "0.7.3" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "d5e19fb6ed40002bab5403ffa37e53e0e56f914a4450c8765f533018db1db35f" +checksum = "9e6936f0cce458098a201c245a11bef556c6a0181129c7034d10d76d1ec3a2b8" dependencies = [ "proc-macro2", "quote", @@ -9102,18 +9053,18 @@ dependencies = [ [[package]] name = "zerocopy" -version = "0.7.25" +version = "0.7.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8cd369a67c0edfef15010f980c3cbe45d7f651deac2cd67ce097cd801de16557" +checksum = "e97e415490559a91254a2979b4829267a57d2fcd741a98eee8b722fb57289aa0" dependencies = [ "zerocopy-derive", ] [[package]] name = "zerocopy-derive" -version = "0.7.25" +version = "0.7.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c2f140bda219a26ccc0cdb03dba58af72590c53b22642577d88a927bc5c87d6b" +checksum = "dd7e48ccf166952882ca8bd778a43502c64f33bf94c12ebe2a7f08e5a0f6689f" dependencies = [ "proc-macro2", "quote", @@ -9143,9 +9094,9 @@ dependencies = [ [[package]] name = "zeroize" -version = "1.6.0" +version = "1.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2a0956f1ba7c7909bfb66c2e9e4124ab6f6482560f6628b5aaeba39207c9aad9" +checksum = "525b4ec142c6b68a2d10f01f7bbf6755599ca3f81ea53b8431b7dd348f5fdb2d" dependencies = [ "zeroize_derive", ] diff --git a/Cargo.toml b/Cargo.toml index d0bbf106e879..009997b8478a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -62,9 +62,9 @@ default-members = ["bin/reth"] resolver = "2" [workspace.package] -version = "0.1.0-alpha.10" +version = "0.1.0-alpha.11" edition = "2021" -rust-version = "1.70" # Remember to update clippy.toml and README.md +rust-version = "1.70" license = "MIT OR Apache-2.0" homepage = "https://paradigmxyz.github.io/reth" repository = "https://github.com/paradigmxyz/reth" @@ -135,8 +135,8 @@ reth-transaction-pool = { path = "crates/transaction-pool" } reth-trie = { path = "crates/trie" } # revm -revm = { git = "https://github.com/bluealloy/revm", rev = "1609e07c68048909ad1682c98cf2b9baa76310b5" } -revm-primitives = { git = "https://github.com/bluealloy/revm", rev = "1609e07c68048909ad1682c98cf2b9baa76310b5" } +revm = { git = "https://github.com/bluealloy/revm", branch = "reth_freeze", features = ["std", "secp256k1"], default-features = false } +revm-primitives = { git = "https://github.com/bluealloy/revm", branch = "reth_freeze", features = ["std"], default-features = false } # eth alloy-primitives = "0.4" diff --git a/README.md b/README.md index 8dd1c987d7a4..910a7cf401d2 100644 --- a/README.md +++ b/README.md @@ -73,6 +73,13 @@ If you want to contribute, or follow along with contributor discussion, you can ### Building and testing + + The Minimum Supported Rust Version (MSRV) of this project is [1.70.0](https://blog.rust-lang.org/2023/06/01/Rust-1.70.0.html). See the book for detailed instructions on how to [build from source](https://paradigmxyz.github.io/reth/installation/source.html). diff --git a/bin/reth/src/args/database_args.rs b/bin/reth/src/args/database_args.rs index b94c36d7eeb7..954390cdc211 100644 --- a/bin/reth/src/args/database_args.rs +++ b/bin/reth/src/args/database_args.rs @@ -5,7 +5,7 @@ use reth_interfaces::db::LogLevel; /// Parameters for database configuration #[derive(Debug, Args, PartialEq, Default, Clone, Copy)] -#[command(next_help_heading = "Database")] +#[clap(next_help_heading = "Database")] pub struct DatabaseArgs { /// Database logging level. Levels higher than "notice" require a debug build. 
#[arg(long = "db.log-level", value_enum)] diff --git a/bin/reth/src/args/debug_args.rs b/bin/reth/src/args/debug_args.rs index ffc5e36a27a5..3772aa52cb2f 100644 --- a/bin/reth/src/args/debug_args.rs +++ b/bin/reth/src/args/debug_args.rs @@ -5,7 +5,7 @@ use reth_primitives::{TxHash, B256}; /// Parameters for debugging purposes #[derive(Debug, Args, PartialEq, Default)] -#[command(next_help_heading = "Debug")] +#[clap(next_help_heading = "Debug")] pub struct DebugArgs { /// Prompt the downloader to download blocks one at a time. /// diff --git a/bin/reth/src/args/dev_args.rs b/bin/reth/src/args/dev_args.rs index ec951e795912..da046225a89e 100644 --- a/bin/reth/src/args/dev_args.rs +++ b/bin/reth/src/args/dev_args.rs @@ -6,7 +6,7 @@ use humantime::parse_duration; /// Parameters for Dev testnet configuration #[derive(Debug, Args, PartialEq, Default, Clone, Copy)] -#[command(next_help_heading = "Dev testnet")] +#[clap(next_help_heading = "Dev testnet")] pub struct DevArgs { /// Start the node in dev mode /// diff --git a/bin/reth/src/args/gas_price_oracle_args.rs b/bin/reth/src/args/gas_price_oracle_args.rs index 9d417f348159..42237f91276c 100644 --- a/bin/reth/src/args/gas_price_oracle_args.rs +++ b/bin/reth/src/args/gas_price_oracle_args.rs @@ -2,7 +2,7 @@ use clap::Args; /// Parameters to configure Gas Price Oracle #[derive(Debug, Clone, Args, PartialEq, Eq, Default)] -#[command(next_help_heading = "Gas Price Oracle")] +#[clap(next_help_heading = "Gas Price Oracle")] pub struct GasPriceOracleArgs { /// Number of recent blocks to check for gas price #[arg(long = "gpo.blocks", default_value = "20")] diff --git a/bin/reth/src/args/network_args.rs b/bin/reth/src/args/network_args.rs index 70053d4ad405..44fd62282599 100644 --- a/bin/reth/src/args/network_args.rs +++ b/bin/reth/src/args/network_args.rs @@ -12,7 +12,7 @@ use std::{net::Ipv4Addr, path::PathBuf, sync::Arc}; /// Parameters for configuring the network more granularity via CLI #[derive(Debug, Args)] -#[command(next_help_heading = "Networking")] +#[clap(next_help_heading = "Networking")] pub struct NetworkArgs { /// Disable the discovery service. #[command(flatten)] diff --git a/bin/reth/src/args/payload_builder_args.rs b/bin/reth/src/args/payload_builder_args.rs index 93f9ad2571d4..7de104987c79 100644 --- a/bin/reth/src/args/payload_builder_args.rs +++ b/bin/reth/src/args/payload_builder_args.rs @@ -11,30 +11,26 @@ use std::{borrow::Cow, ffi::OsStr, time::Duration}; /// Parameters for configuring the Payload Builder #[derive(Debug, Args, PartialEq, Default)] +#[clap(next_help_heading = "Builder")] pub struct PayloadBuilderArgs { /// Block extra data set by the payload builder. - #[arg(long = "builder.extradata", help_heading = "Builder", value_parser=ExtradataValueParser::default(), default_value_t = default_extradata())] + #[arg(long = "builder.extradata", value_parser=ExtradataValueParser::default(), default_value_t = default_extradata())] pub extradata: String, /// Target gas ceiling for built blocks. - #[arg( - long = "builder.gaslimit", - help_heading = "Builder", - default_value = "30000000", - value_name = "GAS_LIMIT" - )] + #[arg(long = "builder.gaslimit", default_value = "30000000", value_name = "GAS_LIMIT")] pub max_gas_limit: u64, /// The interval at which the job should build a new payload after the last (in seconds). 
- #[arg(long = "builder.interval", help_heading = "Builder", value_parser = parse_duration_from_secs, default_value = "1", value_name = "SECONDS")] + #[arg(long = "builder.interval", value_parser = parse_duration_from_secs, default_value = "1", value_name = "SECONDS")] pub interval: Duration, /// The deadline for when the payload builder job should resolve. - #[arg(long = "builder.deadline", help_heading = "Builder", value_parser = parse_duration_from_secs, default_value = "12", value_name = "SECONDS")] + #[arg(long = "builder.deadline", value_parser = parse_duration_from_secs, default_value = "12", value_name = "SECONDS")] pub deadline: Duration, /// Maximum number of tasks to spawn for building a payload. - #[arg(long = "builder.max-tasks", help_heading = "Builder", default_value = "3", value_parser = RangedU64ValueParser::::new().range(1..))] + #[arg(long = "builder.max-tasks", default_value = "3", value_parser = RangedU64ValueParser::::new().range(1..))] pub max_payload_tasks: usize, /// By default the pending block equals the latest block diff --git a/bin/reth/src/args/pruning_args.rs b/bin/reth/src/args/pruning_args.rs index f233c810efa4..251e1a918669 100644 --- a/bin/reth/src/args/pruning_args.rs +++ b/bin/reth/src/args/pruning_args.rs @@ -9,7 +9,7 @@ use std::sync::Arc; /// Parameters for pruning and full node #[derive(Debug, Args, PartialEq, Default)] -#[command(next_help_heading = "Pruning")] +#[clap(next_help_heading = "Pruning")] pub struct PruningArgs { /// Run full node. Only the most recent [`MINIMUM_PRUNING_DISTANCE`] block states are stored. /// This flag takes priority over pruning configuration in reth.toml. diff --git a/bin/reth/src/args/rollup_args.rs b/bin/reth/src/args/rollup_args.rs index abd38f6db7a3..c97fe19147e5 100644 --- a/bin/reth/src/args/rollup_args.rs +++ b/bin/reth/src/args/rollup_args.rs @@ -2,7 +2,7 @@ /// Parameters for rollup configuration #[derive(Debug, clap::Args)] -#[command(next_help_heading = "Rollup")] +#[clap(next_help_heading = "Rollup")] pub struct RollupArgs { /// HTTP endpoint for the sequencer mempool #[arg(long = "rollup.sequencer-http", value_name = "HTTP_URL")] diff --git a/bin/reth/src/args/rpc_server_args.rs b/bin/reth/src/args/rpc_server_args.rs index 32005b1638f2..fb66122456c8 100644 --- a/bin/reth/src/args/rpc_server_args.rs +++ b/bin/reth/src/args/rpc_server_args.rs @@ -62,7 +62,7 @@ pub(crate) const RPC_DEFAULT_MAX_CONNECTIONS: u32 = 500; /// Parameters for configuring the rpc more granularity via CLI #[derive(Debug, Clone, Args)] -#[command(next_help_heading = "RPC")] +#[clap(next_help_heading = "RPC")] pub struct RpcServerArgs { /// Enable the HTTP-RPC server #[arg(long, default_value_if("dev", "true", "true"))] @@ -532,6 +532,23 @@ mod tests { assert_eq!(apis, expected); } + #[test] + fn test_rpc_server_eth_call_bundle_args() { + let args = CommandParser::::parse_from([ + "reth", + "--http.api", + "eth,admin,debug,eth-call-bundle", + ]) + .args; + + let apis = args.http_api.unwrap(); + let expected = + RpcModuleSelection::try_from_selection(["eth", "admin", "debug", "eth-call-bundle"]) + .unwrap(); + + assert_eq!(apis, expected); + } + #[test] fn test_rpc_server_args_parser_none() { let args = CommandParser::::parse_from(["reth", "--http.api", "none"]).args; diff --git a/bin/reth/src/args/txpool_args.rs b/bin/reth/src/args/txpool_args.rs index 0f01f5c42284..f6c3a37bbf75 100644 --- a/bin/reth/src/args/txpool_args.rs +++ b/bin/reth/src/args/txpool_args.rs @@ -9,38 +9,39 @@ use reth_transaction_pool::{ /// Parameters for 
debugging purposes #[derive(Debug, Args, PartialEq, Default)] +#[clap(next_help_heading = "TxPool")] pub struct TxPoolArgs { /// Max number of transaction in the pending sub-pool. - #[arg(long = "txpool.pending_max_count", help_heading = "TxPool", default_value_t = TXPOOL_SUBPOOL_MAX_TXS_DEFAULT)] + #[arg(long = "txpool.pending_max_count", default_value_t = TXPOOL_SUBPOOL_MAX_TXS_DEFAULT)] pub pending_max_count: usize, /// Max size of the pending sub-pool in megabytes. - #[arg(long = "txpool.pending_max_size", help_heading = "TxPool", default_value_t = TXPOOL_SUBPOOL_MAX_SIZE_MB_DEFAULT)] + #[arg(long = "txpool.pending_max_size", default_value_t = TXPOOL_SUBPOOL_MAX_SIZE_MB_DEFAULT)] pub pending_max_size: usize, /// Max number of transaction in the basefee sub-pool - #[arg(long = "txpool.basefee_max_count", help_heading = "TxPool", default_value_t = TXPOOL_SUBPOOL_MAX_TXS_DEFAULT)] + #[arg(long = "txpool.basefee_max_count", default_value_t = TXPOOL_SUBPOOL_MAX_TXS_DEFAULT)] pub basefee_max_count: usize, /// Max size of the basefee sub-pool in megabytes. - #[arg(long = "txpool.basefee_max_size", help_heading = "TxPool", default_value_t = TXPOOL_SUBPOOL_MAX_SIZE_MB_DEFAULT)] + #[arg(long = "txpool.basefee_max_size", default_value_t = TXPOOL_SUBPOOL_MAX_SIZE_MB_DEFAULT)] pub basefee_max_size: usize, /// Max number of transaction in the queued sub-pool - #[arg(long = "txpool.queued_max_count", help_heading = "TxPool", default_value_t = TXPOOL_SUBPOOL_MAX_TXS_DEFAULT)] + #[arg(long = "txpool.queued_max_count", default_value_t = TXPOOL_SUBPOOL_MAX_TXS_DEFAULT)] pub queued_max_count: usize, /// Max size of the queued sub-pool in megabytes. - #[arg(long = "txpool.queued_max_size", help_heading = "TxPool", default_value_t = TXPOOL_SUBPOOL_MAX_SIZE_MB_DEFAULT)] + #[arg(long = "txpool.queued_max_size", default_value_t = TXPOOL_SUBPOOL_MAX_SIZE_MB_DEFAULT)] pub queued_max_size: usize, /// Max number of executable transaction slots guaranteed per account - #[arg(long = "txpool.max_account_slots", help_heading = "TxPool", default_value_t = TXPOOL_MAX_ACCOUNT_SLOTS_PER_SENDER)] + #[arg(long = "txpool.max_account_slots", default_value_t = TXPOOL_MAX_ACCOUNT_SLOTS_PER_SENDER)] pub max_account_slots: usize, /// Price bump (in %) for the transaction pool underpriced check. 
- #[arg(long = "txpool.pricebump", help_heading = "TxPool", default_value_t = DEFAULT_PRICE_BUMP)] + #[arg(long = "txpool.pricebump", default_value_t = DEFAULT_PRICE_BUMP)] pub price_bump: u128, /// Price bump percentage to replace an already existing blob transaction - #[arg(long = "blobpool.pricebump", help_heading = "TxPool", default_value_t = REPLACE_BLOB_PRICE_BUMP)] + #[arg(long = "blobpool.pricebump", default_value_t = REPLACE_BLOB_PRICE_BUMP)] pub blob_transaction_price_bump: u128, } diff --git a/bin/reth/src/chain/import.rs b/bin/reth/src/chain/import.rs index 984a34f8cf44..fffd5abed7e9 100644 --- a/bin/reth/src/chain/import.rs +++ b/bin/reth/src/chain/import.rs @@ -1,4 +1,8 @@ use crate::{ + args::{ + utils::{chain_help, genesis_value_parser, SUPPORTED_CHAINS}, + DatabaseArgs, + }, dirs::{DataDirPath, MaybePlatformPath}, init::init_genesis, node::events::{handle_events, NodeEvent}, @@ -8,12 +12,6 @@ use clap::Parser; use eyre::Context; use futures::{Stream, StreamExt}; use reth_beacon_consensus::BeaconConsensus; -use reth_provider::{ProviderFactory, StageCheckpointReader}; - -use crate::args::{ - utils::{chain_help, genesis_value_parser, SUPPORTED_CHAINS}, - DatabaseArgs, -}; use reth_config::Config; use reth_db::{database::Database, init_db}; use reth_downloaders::{ @@ -22,12 +20,10 @@ use reth_downloaders::{ }; use reth_interfaces::consensus::Consensus; use reth_primitives::{stage::StageId, ChainSpec, B256}; +use reth_provider::{HeaderSyncMode, ProviderFactory, StageCheckpointReader}; use reth_stages::{ prelude::*, - stages::{ - ExecutionStage, ExecutionStageThresholds, HeaderSyncMode, SenderRecoveryStage, - TotalDifficultyStage, - }, + stages::{ExecutionStage, ExecutionStageThresholds, SenderRecoveryStage, TotalDifficultyStage}, }; use std::{path::PathBuf, sync::Arc}; use tokio::sync::watch; @@ -90,6 +86,7 @@ impl ImportCommand { info!(target: "reth::cli", path = ?db_path, "Opening database"); let db = Arc::new(init_db(db_path, self.db.log_level)?); info!(target: "reth::cli", "Database opened"); + let provider_factory = ProviderFactory::new(db.clone(), self.chain.clone()); debug!(target: "reth::cli", chain=%self.chain.chain, genesis=?self.chain.genesis_hash(), "Initializing genesis"); @@ -106,19 +103,19 @@ impl ImportCommand { let tip = file_client.tip().expect("file client has no tip"); info!(target: "reth::cli", "Chain file imported"); - let (mut pipeline, events) = - self.build_import_pipeline(config, Arc::clone(&db), &consensus, file_client).await?; + let (mut pipeline, events) = self + .build_import_pipeline(config, provider_factory.clone(), &consensus, file_client) + .await?; // override the tip pipeline.set_tip(tip); debug!(target: "reth::cli", ?tip, "Tip manually set"); - let factory = ProviderFactory::new(&db, self.chain.clone()); - let provider = factory.provider()?; + let provider = provider_factory.provider()?; let latest_block_number = provider.get_stage_checkpoint(StageId::Finish)?.map(|ch| ch.block_number); - tokio::spawn(handle_events(None, latest_block_number, events)); + tokio::spawn(handle_events(None, latest_block_number, events, db.clone())); // Run pipeline info!(target: "reth::cli", "Starting sync pipeline"); @@ -134,7 +131,7 @@ impl ImportCommand { async fn build_import_pipeline( &self, config: Config, - db: DB, + provider_factory: ProviderFactory, consensus: &Arc, file_client: Arc, ) -> eyre::Result<(Pipeline, impl Stream)> @@ -151,11 +148,11 @@ impl ImportCommand { .into_task(); let body_downloader = BodiesDownloaderBuilder::from(config.stages.bodies) 
diff --git a/bin/reth/src/chain/import.rs b/bin/reth/src/chain/import.rs
index 984a34f8cf44..fffd5abed7e9 100644
--- a/bin/reth/src/chain/import.rs
+++ b/bin/reth/src/chain/import.rs
@@ -1,4 +1,8 @@
 use crate::{
+    args::{
+        utils::{chain_help, genesis_value_parser, SUPPORTED_CHAINS},
+        DatabaseArgs,
+    },
     dirs::{DataDirPath, MaybePlatformPath},
     init::init_genesis,
     node::events::{handle_events, NodeEvent},
@@ -8,12 +12,6 @@ use clap::Parser;
 use eyre::Context;
 use futures::{Stream, StreamExt};
 use reth_beacon_consensus::BeaconConsensus;
-use reth_provider::{ProviderFactory, StageCheckpointReader};
-
-use crate::args::{
-    utils::{chain_help, genesis_value_parser, SUPPORTED_CHAINS},
-    DatabaseArgs,
-};
 use reth_config::Config;
 use reth_db::{database::Database, init_db};
 use reth_downloaders::{
@@ -22,12 +20,10 @@ use reth_downloaders::{
 };
 use reth_interfaces::consensus::Consensus;
 use reth_primitives::{stage::StageId, ChainSpec, B256};
+use reth_provider::{HeaderSyncMode, ProviderFactory, StageCheckpointReader};
 use reth_stages::{
     prelude::*,
-    stages::{
-        ExecutionStage, ExecutionStageThresholds, HeaderSyncMode, SenderRecoveryStage,
-        TotalDifficultyStage,
-    },
+    stages::{ExecutionStage, ExecutionStageThresholds, SenderRecoveryStage, TotalDifficultyStage},
 };
 use std::{path::PathBuf, sync::Arc};
 use tokio::sync::watch;
@@ -90,6 +86,7 @@ impl ImportCommand {
         info!(target: "reth::cli", path = ?db_path, "Opening database");
         let db = Arc::new(init_db(db_path, self.db.log_level)?);
         info!(target: "reth::cli", "Database opened");
+        let provider_factory = ProviderFactory::new(db.clone(), self.chain.clone());

         debug!(target: "reth::cli", chain=%self.chain.chain, genesis=?self.chain.genesis_hash(), "Initializing genesis");

@@ -106,19 +103,19 @@ impl ImportCommand {
         let tip = file_client.tip().expect("file client has no tip");
         info!(target: "reth::cli", "Chain file imported");

-        let (mut pipeline, events) =
-            self.build_import_pipeline(config, Arc::clone(&db), &consensus, file_client).await?;
+        let (mut pipeline, events) = self
+            .build_import_pipeline(config, provider_factory.clone(), &consensus, file_client)
+            .await?;

         // override the tip
         pipeline.set_tip(tip);
         debug!(target: "reth::cli", ?tip, "Tip manually set");

-        let factory = ProviderFactory::new(&db, self.chain.clone());
-        let provider = factory.provider()?;
+        let provider = provider_factory.provider()?;

         let latest_block_number =
             provider.get_stage_checkpoint(StageId::Finish)?.map(|ch| ch.block_number);
-        tokio::spawn(handle_events(None, latest_block_number, events));
+        tokio::spawn(handle_events(None, latest_block_number, events, db.clone()));

         // Run pipeline
         info!(target: "reth::cli", "Starting sync pipeline");
@@ -134,7 +131,7 @@ impl ImportCommand {
     async fn build_import_pipeline<DB, C>(
         &self,
         config: Config,
-        db: DB,
+        provider_factory: ProviderFactory<DB>,
         consensus: &Arc<C>,
         file_client: Arc<FileClient>,
     ) -> eyre::Result<(Pipeline<DB>, impl Stream<Item = NodeEvent>)>
     where
             .into_task();

         let body_downloader = BodiesDownloaderBuilder::from(config.stages.bodies)
-            .build(file_client.clone(), consensus.clone(), db.clone())
+            .build(file_client.clone(), consensus.clone(), provider_factory.clone())
             .into_task();

         let (tip_tx, tip_rx) = watch::channel(B256::ZERO);
-        let factory = reth_revm::Factory::new(self.chain.clone());
+        let factory = reth_revm::EvmProcessorFactory::new(self.chain.clone());

         let max_block = file_client.max_block().unwrap_or(0);
         let mut pipeline = Pipeline::builder()
@@ -164,6 +161,7 @@ impl ImportCommand {
             .with_max_block(max_block)
             .add_stages(
                 DefaultStages::new(
+                    provider_factory.clone(),
                     HeaderSyncMode::Tip(tip_rx),
                     consensus.clone(),
                     header_downloader,
@@ -193,7 +191,7 @@ impl ImportCommand {
                     config.prune.map(|prune| prune.segments).unwrap_or_default(),
                 )),
             )
-            .build(db, self.chain.clone());
+            .build(provider_factory);

         let events = pipeline.events().map(Into::into);
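The recurring change in this command (and again in `debug_cmd/execution.rs` below) is that pipeline construction now consumes a `ProviderFactory` created once from the database handle, rather than passing the raw `DB` around. A condensed sketch of the resulting wiring, assuming `db`, `chain`, the downloaders, and the executor factory are set up as in the surrounding code (remaining stage arguments elided):

```rust
// Sketch only: names mirror the diff above, not a complete invocation.
let provider_factory = ProviderFactory::new(db.clone(), chain.clone());

let mut pipeline = Pipeline::builder()
    .add_stages(DefaultStages::new(
        provider_factory.clone(), // stages now receive the factory first
        HeaderSyncMode::Tip(tip_rx),
        consensus.clone(),
        header_downloader,
        body_downloader,
        executor_factory.clone(),
    ))
    // `.build(db, chain)` becomes `.build(provider_factory)`
    .build(provider_factory);
```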
diff --git a/bin/reth/src/cli/components.rs b/bin/reth/src/cli/components.rs
index 8f45774688f7..18ef804f4a11 100644
--- a/bin/reth/src/cli/components.rs
+++ b/bin/reth/src/cli/components.rs
@@ -1,6 +1,6 @@
 //! Components that are used by the node command.

-use reth_network::NetworkEvents;
+use reth_network::{NetworkEvents, NetworkProtocols};
 use reth_network_api::{NetworkInfo, Peers};
 use reth_primitives::ChainSpec;
 use reth_provider::{
@@ -48,7 +48,7 @@ pub trait RethNodeComponents: Clone + Send + Sync + 'static {
     /// The transaction pool type
     type Pool: TransactionPool + Clone + Unpin + 'static;
     /// The network type used to communicate with p2p.
-    type Network: NetworkInfo + Peers + NetworkEvents + Clone + 'static;
+    type Network: NetworkInfo + Peers + NetworkProtocols + NetworkEvents + Clone + 'static;
     /// The events type used to create subscriptions.
     type Events: CanonStateSubscriptions + Clone + 'static;
     /// The type that is used to spawn tasks.
@@ -117,7 +117,7 @@ where
     Provider: FullProvider + Clone + 'static,
     Tasks: TaskSpawner + Clone + Unpin + 'static,
     Pool: TransactionPool + Clone + Unpin + 'static,
-    Network: NetworkInfo + Peers + NetworkEvents + Clone + 'static,
+    Network: NetworkInfo + Peers + NetworkProtocols + NetworkEvents + Clone + 'static,
     Events: CanonStateSubscriptions + Clone + 'static,
 {
     type Provider = Provider;
diff --git a/bin/reth/src/cli/config.rs b/bin/reth/src/cli/config.rs
index 8700edf04b70..48c1e2bd5fed 100644
--- a/bin/reth/src/cli/config.rs
+++ b/bin/reth/src/cli/config.rs
@@ -1,6 +1,7 @@
 //! Config traits for various node components.

 use alloy_rlp::Encodable;
+use reth_network::protocol::IntoRlpxSubProtocol;
 use reth_primitives::{Bytes, BytesMut};
 use reth_rpc::{eth::gas_oracle::GasPriceOracleConfig, JwtError, JwtSecret};
 use reth_rpc_builder::{
@@ -102,3 +103,21 @@ pub trait PayloadBuilderConfig {
     #[cfg(feature = "optimism")]
     fn compute_pending_block(&self) -> bool;
 }
+
+/// A trait that can be used to apply additional configuration to the network.
+pub trait RethNetworkConfig {
+    /// Adds a new additional protocol to the RLPx sub-protocol list.
+    ///
+    /// These additional protocols are negotiated during the RLPx handshake.
+    /// If both peers share the same protocol, the corresponding handler will be included alongside
+    /// the `eth` protocol.
+    ///
+    /// See also [ProtocolHandler](reth_network::protocol::ProtocolHandler)
+    fn add_rlpx_sub_protocol(&mut self, protocol: impl IntoRlpxSubProtocol);
+}
+
+impl<C> RethNetworkConfig for reth_network::NetworkManager<C> {
+    fn add_rlpx_sub_protocol(&mut self, protocol: impl IntoRlpxSubProtocol) {
+        reth_network::NetworkManager::add_rlpx_sub_protocol(self, protocol);
+    }
+}
diff --git a/bin/reth/src/cli/ext.rs b/bin/reth/src/cli/ext.rs
index c48778f2c3af..352997527396 100644
--- a/bin/reth/src/cli/ext.rs
+++ b/bin/reth/src/cli/ext.rs
@@ -10,7 +10,7 @@ use reth_payload_builder::{PayloadBuilderHandle, PayloadBuilderService};
 use reth_tasks::TaskSpawner;
 use std::{fmt, marker::PhantomData};

-use crate::cli::components::RethRpcServerHandles;
+use crate::cli::{components::RethRpcServerHandles, config::RethNetworkConfig};

 /// A trait that allows for extending parts of the CLI with additional functionality.
 ///
@@ -35,12 +35,30 @@ impl RethCliExt for () {
 ///
 /// The functions are invoked during the initialization of the node command in the following order:
 ///
-/// 1. [on_components_initialized](RethNodeCommandConfig::on_components_initialized)
-/// 2. [spawn_payload_builder_service](RethNodeCommandConfig::spawn_payload_builder_service)
-/// 3. [extend_rpc_modules](RethNodeCommandConfig::extend_rpc_modules)
-/// 4. [on_rpc_server_started](RethNodeCommandConfig::on_rpc_server_started)
-/// 5. [on_node_started](RethNodeCommandConfig::on_node_started)
+/// 1. [configure_network](RethNodeCommandConfig::configure_network)
+/// 2. [on_components_initialized](RethNodeCommandConfig::on_components_initialized)
+/// 3. [spawn_payload_builder_service](RethNodeCommandConfig::spawn_payload_builder_service)
+/// 4. [extend_rpc_modules](RethNodeCommandConfig::extend_rpc_modules)
+/// 5. [on_rpc_server_started](RethNodeCommandConfig::on_rpc_server_started)
+/// 6. [on_node_started](RethNodeCommandConfig::on_node_started)
 pub trait RethNodeCommandConfig: fmt::Debug {
+    /// Invoked with the network configuration before the network is configured.
+    ///
+    /// This allows additional configuration of the network before it is launched.
+    fn configure_network<Conf, Reth>(
+        &mut self,
+        config: &mut Conf,
+        components: &Reth,
+    ) -> eyre::Result<()>
+    where
+        Conf: RethNetworkConfig,
+        Reth: RethNodeComponents,
+    {
+        let _ = config;
+        let _ = components;
+        Ok(())
+    }
+
     /// Event hook called once all components have been initialized.
     ///
     /// This is called as soon as the node components have been initialized.
@@ -224,6 +242,22 @@ impl NoArgs {
 }

 impl<T: RethNodeCommandConfig> RethNodeCommandConfig for NoArgs<T> {
+    fn configure_network<Conf, Reth>(
+        &mut self,
+        config: &mut Conf,
+        components: &Reth,
+    ) -> eyre::Result<()>
+    where
+        Conf: RethNetworkConfig,
+        Reth: RethNodeComponents,
+    {
+        if let Some(conf) = self.inner_mut() {
+            conf.configure_network(config, components)
+        } else {
+            Ok(())
+        }
+    }
+
     fn on_components_initialized<Reth: RethNodeComponents>(
         &mut self,
         components: &Reth,
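For downstream users, the new hook slots into an existing `RethNodeCommandConfig` implementation. A minimal sketch, assuming a user-defined extension type `MyExt` and some `my_protocol()` value implementing `IntoRlpxSubProtocol` (both hypothetical, only the trait signatures come from the diff):

```rust
#[derive(Debug)]
struct MyExt;

impl RethNodeCommandConfig for MyExt {
    fn configure_network<Conf, Reth>(
        &mut self,
        config: &mut Conf,
        _components: &Reth,
    ) -> eyre::Result<()>
    where
        Conf: RethNetworkConfig,
        Reth: RethNodeComponents,
    {
        // Register an extra RLPx sub-protocol before the network launches;
        // `my_protocol()` stands in for any `IntoRlpxSubProtocol` value.
        config.add_rlpx_sub_protocol(my_protocol());
        Ok(())
    }
}
```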
diff --git a/bin/reth/src/db/snapshots/bench.rs b/bin/reth/src/db/snapshots/bench.rs
index 2505b23d4015..928898205f07 100644
--- a/bin/reth/src/db/snapshots/bench.rs
+++ b/bin/reth/src/db/snapshots/bench.rs
@@ -25,7 +25,7 @@ pub(crate) fn bench<F1, F2, R>(
 ) -> eyre::Result<()>
 where
     F1: FnMut() -> eyre::Result<R>,
-    F2: Fn(DatabaseProviderRO<'_, DatabaseEnv>) -> eyre::Result<R>,
+    F2: Fn(DatabaseProviderRO<DatabaseEnv>) -> eyre::Result<R>,
     R: Debug + PartialEq,
 {
     let (db, chain) = db;
diff --git a/bin/reth/src/db/snapshots/headers.rs b/bin/reth/src/db/snapshots/headers.rs
index e4537cd6c3da..d05ff80c8c9d 100644
--- a/bin/reth/src/db/snapshots/headers.rs
+++ b/bin/reth/src/db/snapshots/headers.rs
@@ -22,7 +22,7 @@ use std::{
 impl Command {
     pub(crate) fn generate_headers_snapshot<DB: Database>(
         &self,
-        provider: &DatabaseProviderRO<'_, DB>,
+        provider: &DatabaseProviderRO<DB>,
         compression: Compression,
         inclusion_filter: InclusionFilter,
         phf: PerfectHashingFunction,
diff --git a/bin/reth/src/db/snapshots/receipts.rs b/bin/reth/src/db/snapshots/receipts.rs
index dc8708ac0403..b24eccda51d8 100644
--- a/bin/reth/src/db/snapshots/receipts.rs
+++ b/bin/reth/src/db/snapshots/receipts.rs
@@ -22,7 +22,7 @@ use std::{
 impl Command {
     pub(crate) fn generate_receipts_snapshot<DB: Database>(
         &self,
-        provider: &DatabaseProviderRO<'_, DB>,
+        provider: &DatabaseProviderRO<DB>,
         compression: Compression,
         inclusion_filter: InclusionFilter,
         phf: PerfectHashingFunction,
diff --git a/bin/reth/src/db/snapshots/transactions.rs b/bin/reth/src/db/snapshots/transactions.rs
index 00c06102e8d7..94a61d262a8e 100644
--- a/bin/reth/src/db/snapshots/transactions.rs
+++ b/bin/reth/src/db/snapshots/transactions.rs
@@ -22,7 +22,7 @@ use std::{
 impl Command {
     pub(crate) fn generate_transactions_snapshot<DB: Database>(
         &self,
-        provider: &DatabaseProviderRO<'_, DB>,
+        provider: &DatabaseProviderRO<DB>,
         compression: Compression,
         inclusion_filter: InclusionFilter,
         phf: PerfectHashingFunction,
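The four snapshot changes above are the same mechanical edit: `DatabaseProviderRO` lost its lifetime parameter, so signatures shrink from `DatabaseProviderRO<'_, DB>` to `DatabaseProviderRO<DB>`. A sketch of what a generic caller looks like after the change (the `read_only_op` helper is illustrative, not part of this diff):

```rust
use reth_db::database::Database;

// Illustrative helper: generic code can now name the read-only provider
// without threading a lifetime through every signature.
fn read_only_op<DB: Database>(provider: &DatabaseProviderRO<DB>) -> eyre::Result<()> {
    // ... perform read-only queries through `provider` ...
    let _ = provider;
    Ok(())
}
```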
BlockchainTree::new(tree_externals, BlockchainTreeConfig::default(), None)?; let blockchain_tree = ShareableBlockchainTree::new(tree); @@ -159,8 +159,8 @@ impl Command { let best_block = self.lookup_best_block(Arc::clone(&db)).wrap_err("the head block is missing")?; - let factory = ProviderFactory::new(Arc::clone(&db), Arc::clone(&self.chain)); - let blockchain_db = BlockchainProvider::new(factory.clone(), blockchain_tree.clone())?; + let blockchain_db = + BlockchainProvider::new(provider_factory.clone(), blockchain_tree.clone())?; let blob_store = InMemoryBlobStore::default(); let validator = TransactionValidationTaskExecutor::eth_builder(Arc::clone(&self.chain)) @@ -267,7 +267,7 @@ impl Command { let block_with_senders = SealedBlockWithSenders::new(block.clone(), senders).unwrap(); - let executor_factory = Factory::new(self.chain.clone()); + let executor_factory = EvmProcessorFactory::new(self.chain.clone()); let mut executor = executor_factory.with_state(blockchain_db.latest()?); executor.execute_and_verify_receipt( &block_with_senders.block.clone().unseal(), @@ -278,7 +278,7 @@ impl Command { debug!(target: "reth::cli", ?state, "Executed block"); // Attempt to insert new block without committing - let provider_rw = factory.provider_rw()?; + let provider_rw = provider_factory.provider_rw()?; provider_rw.append_blocks_with_bundle_state( Vec::from([block_with_senders]), state, diff --git a/bin/reth/src/debug_cmd/execution.rs b/bin/reth/src/debug_cmd/execution.rs index fee6390d2f9e..aa1745191172 100644 --- a/bin/reth/src/debug_cmd/execution.rs +++ b/bin/reth/src/debug_cmd/execution.rs @@ -27,13 +27,10 @@ use reth_interfaces::{ use reth_network::{NetworkEvents, NetworkHandle}; use reth_network_api::NetworkInfo; use reth_primitives::{fs, stage::StageId, BlockHashOrNumber, BlockNumber, ChainSpec, B256}; -use reth_provider::{BlockExecutionWriter, ProviderFactory, StageCheckpointReader}; +use reth_provider::{BlockExecutionWriter, HeaderSyncMode, ProviderFactory, StageCheckpointReader}; use reth_stages::{ sets::DefaultStages, - stages::{ - ExecutionStage, ExecutionStageThresholds, HeaderSyncMode, SenderRecoveryStage, - TotalDifficultyStage, - }, + stages::{ExecutionStage, ExecutionStageThresholds, SenderRecoveryStage, TotalDifficultyStage}, Pipeline, StageSet, }; use reth_tasks::TaskExecutor; @@ -92,7 +89,7 @@ impl Command { config: &Config, client: Client, consensus: Arc, - db: DB, + provider_factory: ProviderFactory, task_executor: &TaskExecutor, ) -> eyre::Result> where @@ -105,19 +102,20 @@ impl Command { .into_task_with(task_executor); let body_downloader = BodiesDownloaderBuilder::from(config.stages.bodies) - .build(client, Arc::clone(&consensus), db.clone()) + .build(client, Arc::clone(&consensus), provider_factory.clone()) .into_task_with(task_executor); let stage_conf = &config.stages; let (tip_tx, tip_rx) = watch::channel(B256::ZERO); - let factory = reth_revm::Factory::new(self.chain.clone()); + let factory = reth_revm::EvmProcessorFactory::new(self.chain.clone()); let header_mode = HeaderSyncMode::Tip(tip_rx); let pipeline = Pipeline::builder() .with_tip_sender(tip_tx) .add_stages( DefaultStages::new( + provider_factory.clone(), header_mode, Arc::clone(&consensus), header_downloader, @@ -146,7 +144,7 @@ impl Command { config.prune.as_ref().map(|prune| prune.segments.clone()).unwrap_or_default(), )), ) - .build(db, self.chain.clone()); + .build(provider_factory); Ok(pipeline) } @@ -204,6 +202,7 @@ impl Command { let db_path = data_dir.db_path(); fs::create_dir_all(&db_path)?; let db = 
Arc::new(init_db(db_path, self.db.log_level)?); + let provider_factory = ProviderFactory::new(db.clone(), self.chain.clone()); debug!(target: "reth::cli", chain=%self.chain.chain, genesis=?self.chain.genesis_hash(), "Initializing genesis"); init_genesis(db.clone(), self.chain.clone())?; @@ -229,12 +228,11 @@ impl Command { &config, fetch_client.clone(), Arc::clone(&consensus), - db.clone(), + provider_factory.clone(), &ctx.task_executor, )?; - let factory = ProviderFactory::new(&db, self.chain.clone()); - let provider = factory.provider()?; + let provider = provider_factory.provider()?; let latest_block_number = provider.get_stage_checkpoint(StageId::Finish)?.map(|ch| ch.block_number); @@ -250,7 +248,7 @@ impl Command { ); ctx.task_executor.spawn_critical( "events task", - events::handle_events(Some(network.clone()), latest_block_number, events), + events::handle_events(Some(network.clone()), latest_block_number, events, db.clone()), ); let mut current_max_block = latest_block_number.unwrap_or_default(); @@ -268,7 +266,7 @@ impl Command { // Unwind the pipeline without committing. { - factory + provider_factory .provider_rw()? .take_block_and_execution_range(&self.chain, next_block..=target_block)?; } diff --git a/bin/reth/src/debug_cmd/in_memory_merkle.rs b/bin/reth/src/debug_cmd/in_memory_merkle.rs index 81db51c5ce63..223a104efabf 100644 --- a/bin/reth/src/debug_cmd/in_memory_merkle.rs +++ b/bin/reth/src/debug_cmd/in_memory_merkle.rs @@ -159,7 +159,7 @@ impl Command { ) .await?; - let executor_factory = reth_revm::Factory::new(self.chain.clone()); + let executor_factory = reth_revm::EvmProcessorFactory::new(self.chain.clone()); let mut executor = executor_factory.with_state(LatestStateProviderRef::new(provider.tx_ref())); diff --git a/bin/reth/src/debug_cmd/merkle.rs b/bin/reth/src/debug_cmd/merkle.rs index dc5f98e59eb3..fd6a9c0c5109 100644 --- a/bin/reth/src/debug_cmd/merkle.rs +++ b/bin/reth/src/debug_cmd/merkle.rs @@ -197,7 +197,7 @@ impl Command { checkpoint.stage_checkpoint.is_some() }); - let factory = reth_revm::Factory::new(self.chain.clone()); + let factory = reth_revm::EvmProcessorFactory::new(self.chain.clone()); let mut execution_stage = ExecutionStage::new( factory, ExecutionStageThresholds { @@ -222,53 +222,42 @@ impl Command { None }; - execution_stage - .execute( + execution_stage.execute( + &provider_rw, + ExecInput { + target: Some(block), + checkpoint: block.checked_sub(1).map(StageCheckpoint::new), + }, + )?; + + let mut account_hashing_done = false; + while !account_hashing_done { + let output = account_hashing_stage.execute( &provider_rw, ExecInput { target: Some(block), - checkpoint: block.checked_sub(1).map(StageCheckpoint::new), + checkpoint: progress.map(StageCheckpoint::new), }, - ) - .await?; - - let mut account_hashing_done = false; - while !account_hashing_done { - let output = account_hashing_stage - .execute( - &provider_rw, - ExecInput { - target: Some(block), - checkpoint: progress.map(StageCheckpoint::new), - }, - ) - .await?; + )?; account_hashing_done = output.done; } let mut storage_hashing_done = false; while !storage_hashing_done { - let output = storage_hashing_stage - .execute( - &provider_rw, - ExecInput { - target: Some(block), - checkpoint: progress.map(StageCheckpoint::new), - }, - ) - .await?; - storage_hashing_done = output.done; - } - - let incremental_result = merkle_stage - .execute( + let output = storage_hashing_stage.execute( &provider_rw, ExecInput { target: Some(block), checkpoint: progress.map(StageCheckpoint::new), }, - ) - 
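The recurring pattern in the debug-command hunks above, condensed into straight-line Rust for readability. This is a minimal sketch assuming the `reth_provider` API shapes visible in this diff (`ProviderFactory::new`, `provider()`, `provider_rw()`, `get_stage_checkpoint`); a single factory is built once from the database handle and chain spec, and cheap clones of it replace the raw `db` handle wherever a component needs read or read-write access:

    use reth_db::DatabaseEnv;
    use reth_primitives::{stage::StageId, ChainSpec};
    use reth_provider::{ProviderFactory, StageCheckpointReader};
    use std::sync::Arc;

    fn open_providers(db: Arc<DatabaseEnv>, chain: Arc<ChainSpec>) -> eyre::Result<Option<u64>> {
        // one factory, created once; components receive cheap clones of it
        let provider_factory = ProviderFactory::new(db, chain);

        // read-only view, e.g. for the Finish-stage checkpoint
        let provider = provider_factory.provider()?;
        let latest = provider.get_stage_checkpoint(StageId::Finish)?.map(|ch| ch.block_number);

        // read-write view; mutations only land on an explicit commit
        let provider_rw = provider_factory.provider_rw()?;
        provider_rw.commit()?;

        Ok(latest)
    }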
.await; + )?; + storage_hashing_done = output.done; + } + + let incremental_result = merkle_stage.execute( + &provider_rw, + ExecInput { target: Some(block), checkpoint: progress.map(StageCheckpoint::new) }, + ); if incremental_result.is_err() { tracing::warn!(target: "reth::cli", block, "Incremental calculation failed, retrying from scratch"); @@ -285,7 +274,7 @@ impl Command { let clean_input = ExecInput { target: Some(block), checkpoint: None }; loop { - let clean_result = merkle_stage.execute(&provider_rw, clean_input).await; + let clean_result = merkle_stage.execute(&provider_rw, clean_input); assert!(clean_result.is_ok(), "Clean state root calculation failed"); if clean_result.unwrap().done { break diff --git a/bin/reth/src/init.rs b/bin/reth/src/init.rs index 6b3d638040ed..04b29036226f 100644 --- a/bin/reth/src/init.rs +++ b/bin/reth/src/init.rs @@ -1,7 +1,7 @@ //! Reth genesis initialization utility functions. use reth_db::{ cursor::DbCursorRO, - database::{Database, DatabaseGAT}, + database::Database, tables, transaction::{DbTx, DbTxMut}, }; @@ -94,7 +94,7 @@ pub fn init_genesis( /// Inserts the genesis state into the database. pub fn insert_genesis_state( - tx: &>::TXMut, + tx: &::TXMut, genesis: &reth_primitives::Genesis, ) -> ProviderResult<()> { let mut state_init: BundleStateInit = HashMap::new(); @@ -160,7 +160,7 @@ pub fn insert_genesis_state( /// Inserts hashes for the genesis state. pub fn insert_genesis_hashes( - provider: &DatabaseProviderRW<'_, &DB>, + provider: &DatabaseProviderRW<&DB>, genesis: &reth_primitives::Genesis, ) -> ProviderResult<()> { // insert and hash accounts to hashing table @@ -184,7 +184,7 @@ pub fn insert_genesis_hashes( /// Inserts history indices for genesis accounts and storage. pub fn insert_genesis_history( - provider: &DatabaseProviderRW<'_, &DB>, + provider: &DatabaseProviderRW<&DB>, genesis: &reth_primitives::Genesis, ) -> ProviderResult<()> { let account_transitions = @@ -204,7 +204,7 @@ pub fn insert_genesis_history( /// Inserts header for the genesis state. pub fn insert_genesis_header( - tx: &>::TXMut, + tx: &::TXMut, chain: Arc, ) -> ProviderResult<()> { let header = chain.sealed_genesis_header(); @@ -236,7 +236,7 @@ mod tests { #[allow(clippy::type_complexity)] fn collect_table_entries( - tx: &>::TX, + tx: &::TX, ) -> Result>, InitDatabaseError> where DB: Database, diff --git a/bin/reth/src/node/events.rs b/bin/reth/src/node/events.rs index 8b5d7c76ad6a..0fb17f43b33c 100644 --- a/bin/reth/src/node/events.rs +++ b/bin/reth/src/node/events.rs @@ -3,6 +3,7 @@ use crate::node::cl_events::ConsensusLayerHealthEvent; use futures::Stream; use reth_beacon_consensus::BeaconConsensusEngineEvent; +use reth_db::DatabaseEnv; use reth_interfaces::consensus::ForkchoiceState; use reth_network::{NetworkEvent, NetworkHandle}; use reth_network_api::PeersInfo; @@ -13,8 +14,10 @@ use reth_primitives::{ use reth_prune::PrunerEvent; use reth_stages::{ExecOutput, PipelineEvent}; use std::{ + fmt::{Display, Formatter}, future::Future, pin::Pin, + sync::Arc, task::{Context, Poll}, time::{Duration, Instant}, }; @@ -26,27 +29,25 @@ const INFO_MESSAGE_INTERVAL: Duration = Duration::from_secs(25); /// The current high-level state of the node. struct NodeState { + /// Database environment. + /// Used for freelist calculation reported in the "Status" log message. + /// See [EventHandler::poll]. + db: Arc, /// Connection to the network. network: Option, /// The stage currently being executed. - current_stage: Option, - /// The ETA for the current stage. 
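The shape of the merkle debug-command change above, isolated as a sketch. `Stage::execute` and `Stage::unwind` are synchronous after this diff, so the `.await`s disappear and the retry becomes a plain loop; the bindings (`merkle_stage`, `provider_rw`, `block`) are assumed in scope, and the by-value reuse of `clean_input` suggests `ExecInput` is `Copy`:

    // retry the state-root calculation from scratch until the stage reports done
    let clean_input = ExecInput { target: Some(block), checkpoint: None };
    loop {
        let clean_result = merkle_stage.execute(&provider_rw, clean_input);
        assert!(clean_result.is_ok(), "Clean state root calculation failed");
        if clean_result.unwrap().done {
            break
        }
    }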
- eta: Eta, - /// The current checkpoint of the executing stage. - current_checkpoint: StageCheckpoint, + current_stage: Option, /// The latest block reached by either pipeline or consensus engine. latest_block: Option, } impl NodeState { - fn new(network: Option, latest_block: Option) -> Self { - Self { - network, - current_stage: None, - eta: Eta::default(), - current_checkpoint: StageCheckpoint::new(0), - latest_block, - } + fn new( + db: Arc, + network: Option, + latest_block: Option, + ) -> Self { + Self { db, network, current_stage: None, latest_block } } fn num_connected_peers(&self) -> usize { @@ -56,70 +57,80 @@ impl NodeState { /// Processes an event emitted by the pipeline fn handle_pipeline_event(&mut self, event: PipelineEvent) { match event { - PipelineEvent::Running { pipeline_stages_progress, stage_id, checkpoint } => { - let notable = self.current_stage.is_none(); - self.current_stage = Some(stage_id); - self.current_checkpoint = checkpoint.unwrap_or_default(); + PipelineEvent::Run { pipeline_stages_progress, stage_id, checkpoint, target } => { + let checkpoint = checkpoint.unwrap_or_default(); + let current_stage = CurrentStage { + stage_id, + eta: match &self.current_stage { + Some(current_stage) if current_stage.stage_id == stage_id => { + current_stage.eta + } + _ => Eta::default(), + }, + checkpoint, + target, + }; + + let progress = OptionalField( + checkpoint.entities().and_then(|entities| entities.fmt_percentage()), + ); + let eta = current_stage.eta.fmt_for_stage(stage_id); - if notable { - if let Some(progress) = self.current_checkpoint.entities() { - info!( - pipeline_stages = %pipeline_stages_progress, - stage = %stage_id, - from = self.current_checkpoint.block_number, - checkpoint = %self.current_checkpoint.block_number, - %progress, - eta = %self.eta.fmt_for_stage(stage_id), - "Executing stage", - ); - } else { - info!( - pipeline_stages = %pipeline_stages_progress, - stage = %stage_id, - from = self.current_checkpoint.block_number, - checkpoint = %self.current_checkpoint.block_number, - eta = %self.eta.fmt_for_stage(stage_id), - "Executing stage", - ); - } - } + info!( + pipeline_stages = %pipeline_stages_progress, + stage = %stage_id, + checkpoint = %checkpoint.block_number, + target = %OptionalField(target), + %progress, + %eta, + "Executing stage", + ); + + self.current_stage = Some(current_stage); } PipelineEvent::Ran { pipeline_stages_progress, stage_id, result: ExecOutput { checkpoint, done }, } => { - self.current_checkpoint = checkpoint; if stage_id.is_finish() { self.latest_block = Some(checkpoint.block_number); } - self.eta.update(self.current_checkpoint); - - let message = - if done { "Stage finished executing" } else { "Stage committed progress" }; - - if let Some(progress) = checkpoint.entities() { - info!( - pipeline_stages = %pipeline_stages_progress, - stage = %stage_id, - checkpoint = %checkpoint.block_number, - %progress, - eta = %self.eta.fmt_for_stage(stage_id), - "{message}", - ); - } else { - info!( - pipeline_stages = %pipeline_stages_progress, - stage = %stage_id, - checkpoint = %checkpoint.block_number, - eta = %self.eta.fmt_for_stage(stage_id), - "{message}", + + if let Some(current_stage) = self.current_stage.as_mut() { + current_stage.checkpoint = checkpoint; + current_stage.eta.update(checkpoint); + + let target = OptionalField(current_stage.target); + let progress = OptionalField( + checkpoint.entities().and_then(|entities| entities.fmt_percentage()), ); + + if done { + info!( + pipeline_stages = %pipeline_stages_progress, + 
stage = %stage_id, + checkpoint = %checkpoint.block_number, + %target, + %progress, + "Stage finished executing", + ) + } else { + let eta = current_stage.eta.fmt_for_stage(stage_id); + info!( + pipeline_stages = %pipeline_stages_progress, + stage = %stage_id, + checkpoint = %checkpoint.block_number, + %target, + %progress, + %eta, + "Stage committed progress", + ) + } } if done { self.current_stage = None; - self.eta = Eta::default(); } } _ => (), @@ -189,6 +200,29 @@ impl NodeState { } } +/// Helper type for formatting of optional fields: +/// - If [Some(x)], then `x` is written +/// - If [None], then `None` is written +struct OptionalField(Option); + +impl Display for OptionalField { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + if let Some(field) = &self.0 { + write!(f, "{field}") + } else { + write!(f, "None") + } + } +} + +/// The stage currently being executed. +struct CurrentStage { + stage_id: StageId, + eta: Eta, + checkpoint: StageCheckpoint, + target: Option, +} + /// A node event. #[derive(Debug)] pub enum NodeEvent { @@ -240,10 +274,11 @@ pub async fn handle_events( network: Option, latest_block_number: Option, events: E, + db: Arc, ) where E: Stream + Unpin, { - let state = NodeState::new(network, latest_block_number); + let state = NodeState::new(db, network, latest_block_number); let start = tokio::time::Instant::now() + Duration::from_secs(3); let mut info_interval = tokio::time::interval_at(start, INFO_MESSAGE_INTERVAL); @@ -273,32 +308,40 @@ where let mut this = self.project(); while this.info_interval.poll_tick(cx).is_ready() { - if let Some(stage) = this.state.current_stage { - if let Some(progress) = this.state.current_checkpoint.entities() { - info!( - target: "reth::cli", - connected_peers = this.state.num_connected_peers(), - %stage, - checkpoint = %this.state.current_checkpoint.block_number, - %progress, - eta = %this.state.eta.fmt_for_stage(stage), - "Status" - ); - } else { - info!( - target: "reth::cli", - connected_peers = this.state.num_connected_peers(), - %stage, - checkpoint = %this.state.current_checkpoint.block_number, - eta = %this.state.eta.fmt_for_stage(stage), - "Status" - ); - } + let freelist = OptionalField(this.state.db.freelist().ok()); + + if let Some(CurrentStage { stage_id, eta, checkpoint, target }) = + &this.state.current_stage + { + let progress = OptionalField( + checkpoint.entities().and_then(|entities| entities.fmt_percentage()), + ); + let eta = eta.fmt_for_stage(*stage_id); + + info!( + target: "reth::cli", + connected_peers = this.state.num_connected_peers(), + %freelist, + stage = %stage_id, + checkpoint = checkpoint.block_number, + target = %OptionalField(*target), + %progress, + %eta, + "Status" + ); + } else if let Some(latest_block) = this.state.latest_block { + info!( + target: "reth::cli", + connected_peers = this.state.num_connected_peers(), + %freelist, + %latest_block, + "Status" + ); } else { info!( target: "reth::cli", connected_peers = this.state.num_connected_peers(), - latest_block = this.state.latest_block.unwrap_or(this.state.current_checkpoint.block_number), + %freelist, "Status" ); } @@ -332,7 +375,7 @@ where /// checkpoints reported by the pipeline. /// /// One `Eta` is only valid for a single stage. 
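A tiny usage sketch of the `OptionalField` wrapper defined above (the flattening here has eaten its generic parameter, which is presumably `T: Display`): a wrapped `Some(x)` renders as `x` and a wrapped `None` renders as the literal string "None", so tracing fields like `target` and `progress` always carry a printable value:

    let target: Option<u64> = Some(42);
    assert_eq!(OptionalField(target).to_string(), "42");
    assert_eq!(OptionalField(None::<u64>).to_string(), "None");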
-#[derive(Default)] +#[derive(Default, Copy, Clone)] struct Eta { /// The last stage checkpoint last_checkpoint: EntitiesCheckpoint, @@ -375,8 +418,8 @@ impl Eta { } } -impl std::fmt::Display for Eta { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { +impl Display for Eta { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { if let Some((eta, last_checkpoint_time)) = self.eta.zip(self.last_checkpoint_time) { let remaining = eta.checked_sub(last_checkpoint_time.elapsed()); diff --git a/bin/reth/src/node/mod.rs b/bin/reth/src/node/mod.rs index 3db510564a2a..42d3e0136f59 100644 --- a/bin/reth/src/node/mod.rs +++ b/bin/reth/src/node/mod.rs @@ -49,9 +49,7 @@ use reth_interfaces::{ }, RethResult, }; -use reth_network::{ - error::NetworkError, NetworkConfig, NetworkEvents, NetworkHandle, NetworkManager, -}; +use reth_network::{NetworkBuilder, NetworkConfig, NetworkEvents, NetworkHandle, NetworkManager}; use reth_network_api::{NetworkInfo, PeersInfo}; use reth_primitives::{ constants::eip4844::{LoadKzgSettingsError, MAINNET_KZG_TRUSTED_SETUP}, @@ -61,19 +59,19 @@ use reth_primitives::{ }; use reth_provider::{ providers::BlockchainProvider, BlockHashReader, BlockReader, CanonStateSubscriptions, - HeaderProvider, ProviderFactory, StageCheckpointReader, + HeaderProvider, HeaderSyncMode, ProviderFactory, StageCheckpointReader, }; use reth_prune::{segments::SegmentSet, Pruner}; -use reth_revm::Factory; +use reth_revm::EvmProcessorFactory; use reth_revm_inspectors::stack::Hook; use reth_rpc_engine_api::EngineApi; use reth_snapshot::HighestSnapshotsTracker; use reth_stages::{ prelude::*, stages::{ - AccountHashingStage, ExecutionStage, ExecutionStageThresholds, HeaderSyncMode, - IndexAccountHistoryStage, IndexStorageHistoryStage, MerkleStage, SenderRecoveryStage, - StorageHashingStage, TotalDifficultyStage, TransactionLookupStage, + AccountHashingStage, ExecutionStage, ExecutionStageThresholds, IndexAccountHistoryStage, + IndexStorageHistoryStage, MerkleStage, SenderRecoveryStage, StorageHashingStage, + TotalDifficultyStage, TransactionLookupStage, }, }; use reth_tasks::TaskExecutor; @@ -188,6 +186,7 @@ pub struct NodeCommand { /// Additional cli arguments #[clap(flatten)] + #[clap(next_help_heading = "Extension")] pub ext: Ext::Node, } @@ -258,6 +257,18 @@ impl NodeCommand { let db = Arc::new(init_db(&db_path, self.db.log_level)?.with_metrics()); info!(target: "reth::cli", "Database opened"); + let mut provider_factory = ProviderFactory::new(Arc::clone(&db), Arc::clone(&self.chain)); + + // configure snapshotter + let snapshotter = reth_snapshot::Snapshotter::new( + provider_factory.clone(), + data_dir.snapshots_path(), + self.chain.snapshot_block_interval, + )?; + + provider_factory = provider_factory + .with_snapshots(data_dir.snapshots_path(), snapshotter.highest_snapshot_receiver()); + self.start_metrics_endpoint(prometheus_handle, Arc::clone(&db)).await?; debug!(target: "reth::cli", chain=%self.chain.chain, genesis=?self.chain.genesis_hash(), "Initializing genesis"); @@ -280,10 +291,9 @@ impl NodeCommand { // configure blockchain tree let tree_externals = TreeExternals::new( - Arc::clone(&db), + provider_factory.clone(), Arc::clone(&consensus), - Factory::new(self.chain.clone()), - Arc::clone(&self.chain), + EvmProcessorFactory::new(self.chain.clone()), ); let tree = BlockchainTree::new( tree_externals, @@ -298,18 +308,9 @@ impl NodeCommand { // fetch the head block from the database let head = self.lookup_head(Arc::clone(&db)).wrap_err("the head block is missing")?; 
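Pulled together from the node/mod.rs hunks above, as one readable sequence rather than new API: the factory is created first, the snapshotter borrows a clone of it, the factory is then re-bound with the snapshots path and the highest-snapshot receiver, and both the tree externals and the blockchain provider are fed clones of the same factory:

    let mut provider_factory = ProviderFactory::new(Arc::clone(&db), Arc::clone(&self.chain));

    let snapshotter = reth_snapshot::Snapshotter::new(
        provider_factory.clone(),
        data_dir.snapshots_path(),
        self.chain.snapshot_block_interval,
    )?;
    provider_factory = provider_factory
        .with_snapshots(data_dir.snapshots_path(), snapshotter.highest_snapshot_receiver());

    let tree_externals = TreeExternals::new(
        provider_factory.clone(),
        Arc::clone(&consensus),
        EvmProcessorFactory::new(self.chain.clone()),
    );
    let blockchain_db =
        BlockchainProvider::new(provider_factory.clone(), blockchain_tree.clone())?;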
- // configure snapshotter - let snapshotter = reth_snapshot::Snapshotter::new( - db.clone(), - data_dir.snapshots_path(), - self.chain.clone(), - self.chain.snapshot_block_interval, - )?; - // setup the blockchain provider - let factory = ProviderFactory::new(Arc::clone(&db), Arc::clone(&self.chain)) - .with_snapshots(data_dir.snapshots_path(), snapshotter.highest_snapshot_receiver()); - let blockchain_db = BlockchainProvider::new(factory, blockchain_tree.clone())?; + let blockchain_db = + BlockchainProvider::new(provider_factory.clone(), blockchain_tree.clone())?; let blob_store = InMemoryBlobStore::default(); let validator = TransactionValidationTaskExecutor::eth_builder(Arc::clone(&self.chain)) .with_head_timestamp(head.timestamp) @@ -353,25 +354,34 @@ impl NodeCommand { secret_key, default_peers_path.clone(), ); - let network = self - .start_network( - network_config, - &ctx.task_executor, - transaction_pool.clone(), - default_peers_path, - ) - .await?; - info!(target: "reth::cli", peer_id = %network.peer_id(), local_addr = %network.local_addr(), enode = %network.local_node_record(), "Connected to P2P network"); - debug!(target: "reth::cli", peer_id = ?network.peer_id(), "Full peer ID"); - let network_client = network.fetch_client().await?; + + let network_client = network_config.client.clone(); + let mut network_builder = NetworkManager::builder(network_config).await?; let components = RethNodeComponentsImpl { provider: blockchain_db.clone(), pool: transaction_pool.clone(), - network: network.clone(), + network: network_builder.handle(), task_executor: ctx.task_executor.clone(), events: blockchain_db.clone(), }; + + // allow network modifications + self.ext.configure_network(network_builder.network_mut(), &components)?; + + // launch network + let network = self.start_network( + network_builder, + &ctx.task_executor, + transaction_pool.clone(), + network_client, + default_peers_path, + ); + + info!(target: "reth::cli", peer_id = %network.peer_id(), local_addr = %network.local_addr(), enode = %network.local_node_record(), "Connected to P2P network"); + debug!(target: "reth::cli", peer_id = ?network.peer_id(), "Full peer ID"); + let network_client = network.fetch_client().await?; + self.ext.on_components_initialized(&components)?; debug!(target: "reth::cli", "Spawning payload builder service"); @@ -417,7 +427,7 @@ impl NodeCommand { &config, client.clone(), Arc::clone(&consensus), - db.clone(), + provider_factory, &ctx.task_executor, sync_metrics_tx, prune_config.clone(), @@ -437,7 +447,7 @@ impl NodeCommand { &config, network_client.clone(), Arc::clone(&consensus), - db.clone(), + provider_factory, &ctx.task_executor, sync_metrics_tx, prune_config.clone(), @@ -515,7 +525,7 @@ impl NodeCommand { ); ctx.task_executor.spawn_critical( "events task", - events::handle_events(Some(network.clone()), Some(head.number), events), + events::handle_events(Some(network.clone()), Some(head.number), events, db.clone()), ); let engine_api = EngineApi::new( @@ -601,7 +611,7 @@ impl NodeCommand { config: &Config, client: Client, consensus: Arc, - db: DB, + provider_factory: ProviderFactory, task_executor: &TaskExecutor, metrics_tx: reth_stages::MetricEventsSender, prune_config: Option, @@ -617,12 +627,12 @@ impl NodeCommand { .into_task_with(task_executor); let body_downloader = BodiesDownloaderBuilder::from(config.stages.bodies) - .build(client, Arc::clone(&consensus), db.clone()) + .build(client, Arc::clone(&consensus), provider_factory.clone()) .into_task_with(task_executor); let pipeline = self 
.build_pipeline( - db, + provider_factory, config, header_downloader, body_downloader, @@ -691,33 +701,30 @@ impl NodeCommand { /// Spawns the configured network and associated tasks and returns the [NetworkHandle] connected /// to that network. - async fn start_network( + fn start_network( &self, - config: NetworkConfig, + builder: NetworkBuilder, task_executor: &TaskExecutor, pool: Pool, + client: C, default_peers_path: PathBuf, - ) -> Result + ) -> NetworkHandle where C: BlockReader + HeaderProvider + Clone + Unpin + 'static, Pool: TransactionPool + Unpin + 'static, { - let client = config.client.clone(); - let (handle, network, txpool, eth) = NetworkManager::builder(config) - .await? - .transactions(pool) - .request_handler(client) - .split_with_handle(); + let (handle, network, txpool, eth) = + builder.transactions(pool).request_handler(client).split_with_handle(); task_executor.spawn_critical("p2p txpool", txpool); task_executor.spawn_critical("p2p eth request handler", eth); let known_peers_file = self.network.persistent_peers_file(default_peers_path); - task_executor.spawn_critical_with_signal("p2p network task", |shutdown| { + task_executor.spawn_critical_with_shutdown_signal("p2p network task", |shutdown| { run_network_until_shutdown(shutdown, network, known_peers_file) }); - Ok(handle) + handle } /// Fetches the head block from the database. @@ -844,7 +851,7 @@ impl NodeCommand { #[allow(clippy::too_many_arguments)] async fn build_pipeline( &self, - db: DB, + provider_factory: ProviderFactory, config: &Config, header_downloader: H, body_downloader: B, @@ -870,7 +877,7 @@ impl NodeCommand { let (tip_tx, tip_rx) = watch::channel(B256::ZERO); use reth_revm_inspectors::stack::InspectorStackConfig; - let factory = reth_revm::Factory::new(self.chain.clone()); + let factory = reth_revm::EvmProcessorFactory::new(self.chain.clone()); let stack_config = InspectorStackConfig { use_printer_tracer: self.debug.print_inspector, @@ -896,6 +903,7 @@ impl NodeCommand { .with_metrics_tx(metrics_tx.clone()) .add_stages( DefaultStages::new( + provider_factory.clone(), header_mode, Arc::clone(&consensus), header_downloader, @@ -948,7 +956,7 @@ impl NodeCommand { prune_modes.storage_history, )), ) - .build(db, self.chain.clone()); + .build(provider_factory); Ok(pipeline) } diff --git a/bin/reth/src/stage/dump/execution.rs b/bin/reth/src/stage/dump/execution.rs index 67eda8033cc7..d0ce96ce5dfc 100644 --- a/bin/reth/src/stage/dump/execution.rs +++ b/bin/reth/src/stage/dump/execution.rs @@ -7,7 +7,7 @@ use reth_db::{ }; use reth_primitives::{stage::StageCheckpoint, ChainSpec}; use reth_provider::ProviderFactory; -use reth_revm::Factory; +use reth_revm::EvmProcessorFactory; use reth_stages::{stages::ExecutionStage, Stage, UnwindInput}; use std::{path::PathBuf, sync::Arc}; use tracing::info; @@ -98,18 +98,17 @@ async fn unwind_and_copy( let factory = ProviderFactory::new(db_tool.db, db_tool.chain.clone()); let provider = factory.provider_rw()?; - let mut exec_stage = ExecutionStage::new_with_factory(Factory::new(db_tool.chain.clone())); + let mut exec_stage = + ExecutionStage::new_with_factory(EvmProcessorFactory::new(db_tool.chain.clone())); - exec_stage - .unwind( - &provider, - UnwindInput { - unwind_to: from, - checkpoint: StageCheckpoint::new(tip_block_number), - bad_block: None, - }, - ) - .await?; + exec_stage.unwind( + &provider, + UnwindInput { + unwind_to: from, + checkpoint: StageCheckpoint::new(tip_block_number), + bad_block: None, + }, + )?; let unwind_inner_tx = provider.into_tx(); @@ -131,20 
+130,13 @@ async fn dry_run( info!(target: "reth::cli", "Executing stage. [dry-run]"); let factory = ProviderFactory::new(&output_db, chain.clone()); - let provider = factory.provider_rw()?; - let mut exec_stage = ExecutionStage::new_with_factory(Factory::new(chain.clone())); - - exec_stage - .execute( - &provider, - reth_stages::ExecInput { - target: Some(to), - checkpoint: Some(StageCheckpoint::new(from)), - }, - ) - .await?; + let mut exec_stage = ExecutionStage::new_with_factory(EvmProcessorFactory::new(chain.clone())); + + let input = + reth_stages::ExecInput { target: Some(to), checkpoint: Some(StageCheckpoint::new(from)) }; + exec_stage.execute(&factory.provider_rw()?, input)?; - info!(target: "reth::cli", "Success."); + info!(target: "reth::cli", "Success"); Ok(()) } diff --git a/bin/reth/src/stage/dump/hashing_account.rs b/bin/reth/src/stage/dump/hashing_account.rs index 2a947d013e63..7fe723257f69 100644 --- a/bin/reth/src/stage/dump/hashing_account.rs +++ b/bin/reth/src/stage/dump/hashing_account.rs @@ -22,7 +22,7 @@ pub(crate) async fn dump_hashing_account_stage( tx.import_table_with_range::(&db_tool.db.tx()?, Some(from), to) })??; - unwind_and_copy(db_tool, from, tip_block_number, &output_db).await?; + unwind_and_copy(db_tool, from, tip_block_number, &output_db)?; if should_run { dry_run(db_tool.chain.clone(), output_db, to, from).await?; @@ -32,7 +32,7 @@ pub(crate) async fn dump_hashing_account_stage( } /// Dry-run an unwind to FROM block and copy the necessary table data to the new database. -async fn unwind_and_copy( +fn unwind_and_copy( db_tool: &DbTool<'_, DB>, from: u64, tip_block_number: u64, @@ -42,16 +42,14 @@ async fn unwind_and_copy( let provider = factory.provider_rw()?; let mut exec_stage = AccountHashingStage::default(); - exec_stage - .unwind( - &provider, - UnwindInput { - unwind_to: from, - checkpoint: StageCheckpoint::new(tip_block_number), - bad_block: None, - }, - ) - .await?; + exec_stage.unwind( + &provider, + UnwindInput { + unwind_to: from, + checkpoint: StageCheckpoint::new(tip_block_number), + bad_block: None, + }, + )?; let unwind_inner_tx = provider.into_tx(); output_db.update(|tx| tx.import_table::(&unwind_inner_tx))??; @@ -70,23 +68,19 @@ async fn dry_run( let factory = ProviderFactory::new(&output_db, chain); let provider = factory.provider_rw()?; - let mut exec_stage = AccountHashingStage { + let mut stage = AccountHashingStage { clean_threshold: 1, // Forces hashing from scratch ..Default::default() }; - let mut exec_output = false; - while !exec_output { - exec_output = exec_stage - .execute( - &provider, - reth_stages::ExecInput { - target: Some(to), - checkpoint: Some(StageCheckpoint::new(from)), - }, - ) - .await? 
- .done; + loop { + let input = reth_stages::ExecInput { + target: Some(to), + checkpoint: Some(StageCheckpoint::new(from)), + }; + if stage.execute(&provider, input)?.done { + break + } } info!(target: "reth::cli", "Success."); diff --git a/bin/reth/src/stage/dump/hashing_storage.rs b/bin/reth/src/stage/dump/hashing_storage.rs index 0a8df0a6e44a..373818072529 100644 --- a/bin/reth/src/stage/dump/hashing_storage.rs +++ b/bin/reth/src/stage/dump/hashing_storage.rs @@ -17,7 +17,7 @@ pub(crate) async fn dump_hashing_storage_stage( ) -> Result<()> { let (output_db, tip_block_number) = setup(from, to, output_db, db_tool)?; - unwind_and_copy(db_tool, from, tip_block_number, &output_db).await?; + unwind_and_copy(db_tool, from, tip_block_number, &output_db)?; if should_run { dry_run(db_tool.chain.clone(), output_db, to, from).await?; @@ -27,7 +27,7 @@ pub(crate) async fn dump_hashing_storage_stage( } /// Dry-run an unwind to FROM block and copy the necessary table data to the new database. -async fn unwind_and_copy( +fn unwind_and_copy( db_tool: &DbTool<'_, DB>, from: u64, tip_block_number: u64, @@ -38,16 +38,14 @@ async fn unwind_and_copy( let mut exec_stage = StorageHashingStage::default(); - exec_stage - .unwind( - &provider, - UnwindInput { - unwind_to: from, - checkpoint: StageCheckpoint::new(tip_block_number), - bad_block: None, - }, - ) - .await?; + exec_stage.unwind( + &provider, + UnwindInput { + unwind_to: from, + checkpoint: StageCheckpoint::new(tip_block_number), + bad_block: None, + }, + )?; let unwind_inner_tx = provider.into_tx(); // TODO optimize we can actually just get the entries we need for both these tables @@ -69,23 +67,19 @@ async fn dry_run( let factory = ProviderFactory::new(&output_db, chain); let provider = factory.provider_rw()?; - let mut exec_stage = StorageHashingStage { + let mut stage = StorageHashingStage { clean_threshold: 1, // Forces hashing from scratch ..Default::default() }; - let mut exec_output = false; - while !exec_output { - exec_output = exec_stage - .execute( - &provider, - reth_stages::ExecInput { - target: Some(to), - checkpoint: Some(StageCheckpoint::new(from)), - }, - ) - .await? 
- .done; + loop { + let input = reth_stages::ExecInput { + target: Some(to), + checkpoint: Some(StageCheckpoint::new(from)), + }; + if stage.execute(&provider, input)?.done { + break + } } info!(target: "reth::cli", "Success."); diff --git a/bin/reth/src/stage/dump/merkle.rs b/bin/reth/src/stage/dump/merkle.rs index 55eef819f1c5..c57a85c597b9 100644 --- a/bin/reth/src/stage/dump/merkle.rs +++ b/bin/reth/src/stage/dump/merkle.rs @@ -61,14 +61,14 @@ async fn unwind_and_copy( // Unwind hashes all the way to FROM - StorageHashingStage::default().unwind(&provider, unwind).await.unwrap(); - AccountHashingStage::default().unwind(&provider, unwind).await.unwrap(); + StorageHashingStage::default().unwind(&provider, unwind).unwrap(); + AccountHashingStage::default().unwind(&provider, unwind).unwrap(); - MerkleStage::default_unwind().unwind(&provider, unwind).await?; + MerkleStage::default_unwind().unwind(&provider, unwind)?; // Bring Plainstate to TO (hashing stage execution requires it) let mut exec_stage = ExecutionStage::new( - reth_revm::Factory::new(db_tool.chain.clone()), + reth_revm::EvmProcessorFactory::new(db_tool.chain.clone()), ExecutionStageThresholds { max_blocks: Some(u64::MAX), max_changes: None, @@ -78,26 +78,21 @@ async fn unwind_and_copy( PruneModes::all(), ); - exec_stage - .unwind( - &provider, - UnwindInput { - unwind_to: to, - checkpoint: StageCheckpoint::new(tip_block_number), - bad_block: None, - }, - ) - .await?; + exec_stage.unwind( + &provider, + UnwindInput { + unwind_to: to, + checkpoint: StageCheckpoint::new(tip_block_number), + bad_block: None, + }, + )?; // Bring hashes to TO - AccountHashingStage { clean_threshold: u64::MAX, commit_threshold: u64::MAX } .execute(&provider, execute_input) - .await .unwrap(); StorageHashingStage { clean_threshold: u64::MAX, commit_threshold: u64::MAX } .execute(&provider, execute_input) - .await .unwrap(); let unwind_inner_tx = provider.into_tx(); @@ -123,25 +118,23 @@ async fn dry_run( info!(target: "reth::cli", "Executing stage."); let factory = ProviderFactory::new(&output_db, chain); let provider = factory.provider_rw()?; - let mut exec_output = false; - while !exec_output { - exec_output = MerkleStage::Execution { - clean_threshold: u64::MAX, /* Forces updating the root instead of calculating - * from - * scratch */ + + let mut stage = MerkleStage::Execution { + // Forces updating the root instead of calculating from scratch + clean_threshold: u64::MAX, + }; + + loop { + let input = reth_stages::ExecInput { + target: Some(to), + checkpoint: Some(StageCheckpoint::new(from)), + }; + if stage.execute(&provider, input)?.done { + break } - .execute( - &provider, - reth_stages::ExecInput { - target: Some(to), - checkpoint: Some(StageCheckpoint::new(from)), - }, - ) - .await? 
- .done; } - info!(target: "reth::cli", "Success."); + info!(target: "reth::cli", "Success"); Ok(()) } diff --git a/bin/reth/src/stage/run.rs b/bin/reth/src/stage/run.rs index c66792668371..d1b0cbf670cd 100644 --- a/bin/reth/src/stage/run.rs +++ b/bin/reth/src/stage/run.rs @@ -24,7 +24,7 @@ use reth_stages::{ IndexAccountHistoryStage, IndexStorageHistoryStage, MerkleStage, SenderRecoveryStage, StorageHashingStage, TransactionLookupStage, }, - ExecInput, ExecOutput, Stage, UnwindInput, + ExecInput, Stage, StageExt, UnwindInput, }; use std::{any::Any, net::SocketAddr, path::PathBuf, sync::Arc}; use tracing::*; @@ -124,7 +124,7 @@ impl Command { let db = Arc::new(init_db(db_path, self.db.log_level)?); info!(target: "reth::cli", "Database opened"); - let factory = ProviderFactory::new(&db, self.chain.clone()); + let factory = ProviderFactory::new(Arc::clone(&db), self.chain.clone()); let mut provider_rw = factory.provider_rw()?; if let Some(listen_addr) = self.metrics { @@ -162,6 +162,9 @@ impl Command { let default_peers_path = data_dir.known_peers_path(); + let provider_factory = + Arc::new(ProviderFactory::new(db.clone(), self.chain.clone())); + let network = self .network .network_config( @@ -170,13 +173,13 @@ impl Command { p2p_secret_key, default_peers_path, ) - .build(Arc::new(ProviderFactory::new(db.clone(), self.chain.clone()))) + .build(provider_factory.clone()) .start_network() .await?; let fetch_client = Arc::new(network.fetch_client().await?); - let stage = BodyStage { - downloader: BodiesDownloaderBuilder::default() + let stage = BodyStage::new( + BodiesDownloaderBuilder::default() .with_stream_batch_size(batch_size as usize) .with_request_limit(config.stages.bodies.downloader_request_limit) .with_max_buffered_blocks_size_bytes( @@ -186,15 +189,13 @@ impl Command { config.stages.bodies.downloader_min_concurrent_requests..= config.stages.bodies.downloader_max_concurrent_requests, ) - .build(fetch_client, consensus.clone(), db.clone()), - consensus: consensus.clone(), - }; - + .build(fetch_client, consensus.clone(), provider_factory), + ); (Box::new(stage), None) } StageEnum::Senders => (Box::new(SenderRecoveryStage::new(batch_size)), None), StageEnum::Execution => { - let factory = reth_revm::Factory::new(self.chain.clone()); + let factory = reth_revm::EvmProcessorFactory::new(self.chain.clone()); ( Box::new(ExecutionStage::new( factory, @@ -242,7 +243,7 @@ impl Command { if !self.skip_unwind { while unwind.checkpoint.block_number > self.from { - let unwind_output = unwind_stage.unwind(&provider_rw, unwind).await?; + let unwind_output = unwind_stage.unwind(&provider_rw, unwind)?; unwind.checkpoint = unwind_output.checkpoint; if self.commit { @@ -257,19 +258,20 @@ impl Command { checkpoint: Some(checkpoint.with_block_number(self.from)), }; - while let ExecOutput { checkpoint: stage_progress, done: false } = - exec_stage.execute(&provider_rw, input).await? 
- { - input.checkpoint = Some(stage_progress); + loop { + exec_stage.execute_ready(input).await?; + let output = exec_stage.execute(&provider_rw, input)?; + + input.checkpoint = Some(output.checkpoint); if self.commit { provider_rw.commit()?; provider_rw = factory.provider_rw()?; } - } - if self.commit { - provider_rw.commit()?; + if output.done { + break + } } Ok(()) diff --git a/crates/blockchain-tree/src/blockchain_tree.rs b/crates/blockchain-tree/src/blockchain_tree.rs index 565d6801aebd..8d399dda2599 100644 --- a/crates/blockchain-tree/src/blockchain_tree.rs +++ b/crates/blockchain-tree/src/blockchain_tree.rs @@ -24,7 +24,7 @@ use reth_provider::{ chain::{ChainSplit, ChainSplitTarget}, BlockExecutionWriter, BlockNumReader, BlockWriter, BundleStateWithReceipts, CanonStateNotification, CanonStateNotificationSender, CanonStateNotifications, Chain, - DatabaseProvider, DisplayBlocksChain, ExecutorFactory, HeaderProvider, + ChainSpecProvider, DisplayBlocksChain, ExecutorFactory, HeaderProvider, }; use reth_stages::{MetricEvent, MetricEventsSender}; use std::{collections::BTreeMap, sync::Arc}; @@ -104,8 +104,8 @@ impl BlockchainTree { externals.fetch_latest_canonical_hashes(config.num_of_canonical_hashes() as usize)?; // TODO(rakita) save last finalized block inside database but for now just take - // tip-max_reorg_depth - // task: https://github.com/paradigmxyz/reth/issues/1712 + // `tip - max_reorg_depth` + // https://github.com/paradigmxyz/reth/issues/1712 let last_finalized_block_number = if last_canonical_hashes.len() > max_reorg_depth { // we pick `Highest - max_reorg_depth` block as last finalized block. last_canonical_hashes.keys().nth_back(max_reorg_depth) @@ -161,7 +161,7 @@ impl BlockchainTree { } // check if block is inside database - if self.externals.database().provider()?.block_number(block.hash)?.is_some() { + if self.externals.provider_factory.provider()?.block_number(block.hash)?.is_some() { return Ok(Some(BlockStatus::Valid)) } @@ -209,10 +209,22 @@ impl BlockchainTree { /// Returns the block with matching hash from any side-chain. /// /// Caution: This will not return blocks from the canonical chain. + #[inline] pub fn block_by_hash(&self, block_hash: BlockHash) -> Option<&SealedBlock> { self.state.block_by_hash(block_hash) } + /// Returns the block with matching hash from any side-chain. + /// + /// Caution: This will not return blocks from the canonical chain. + #[inline] + pub fn block_with_senders_by_hash( + &self, + block_hash: BlockHash, + ) -> Option<&SealedBlockWithSenders> { + self.state.block_with_senders_by_hash(block_hash) + } + /// Returns the block's receipts with matching hash from any side-chain. /// /// Caution: This will not return blocks from the canonical chain. @@ -368,8 +380,9 @@ impl BlockchainTree { // https://github.com/paradigmxyz/reth/issues/1713 let (block_status, chain) = { - let factory = self.externals.database(); - let provider = factory + let provider = self + .externals + .provider_factory .provider() .map_err(|err| InsertBlockError::new(block.block.clone(), err.into()))?; @@ -385,7 +398,12 @@ impl BlockchainTree { })?; // Pass the parent total difficulty to short-circuit unnecessary calculations. 
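The new two-phase stage driver from the stage/run.rs hunk above, isolated as a sketch (bindings assumed in scope, with `provider_rw` mutable as in the surrounding code): `execute_ready` is the only async step, awaiting any preparation such as downloads, while the execution itself is now synchronous against a read-write provider:

    loop {
        // await any async preparation (e.g. downloads) for this input
        exec_stage.execute_ready(input).await?;
        // the execution itself is synchronous against a RW provider
        let output = exec_stage.execute(&provider_rw, input)?;

        input.checkpoint = Some(output.checkpoint);
        if self.commit {
            provider_rw.commit()?;
            provider_rw = factory.provider_rw()?;
        }
        if output.done {
            break
        }
    }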
- if !self.externals.chain_spec.fork(Hardfork::Paris).active_at_ttd(parent_td, U256::ZERO) + if !self + .externals + .provider_factory + .chain_spec() + .fork(Hardfork::Paris) + .active_at_ttd(parent_td, U256::ZERO) { return Err(InsertBlockError::execution_error( BlockValidationError::BlockPreMerge { hash: block.hash }.into(), @@ -546,8 +564,9 @@ impl BlockchainTree { let Some(chain) = self.state.chains.get(&chain_id) else { return hashes }; hashes.extend(chain.blocks().values().map(|b| (b.number, b.hash()))); - let fork_block = chain.fork_block_hash(); - if let Some(next_chain_id) = self.block_indices().get_blocks_chain_id(&fork_block) { + let fork_block = chain.fork_block(); + if let Some(next_chain_id) = self.block_indices().get_blocks_chain_id(&fork_block.hash) + { chain_id = next_chain_id; } else { // if there is no fork block that points to other chains, break the loop. @@ -794,7 +813,7 @@ impl BlockchainTree { // check unconnected block buffer for children of the chains let mut all_chain_blocks = Vec::new(); for (_, chain) in self.state.chains.iter() { - for (&number, blocks) in chain.blocks.iter() { + for (&number, blocks) in chain.blocks().iter() { all_chain_blocks.push(BlockNumHash { number, hash: blocks.hash }) } } @@ -868,8 +887,7 @@ impl BlockchainTree { // canonical, but in the db. If it is in a sidechain, it is not canonical. If it is not in // the db, then it is not canonical. - let factory = self.externals.database(); - let provider = factory.provider()?; + let provider = self.externals.provider_factory.provider()?; let mut header = None; if let Some(num) = self.block_indices().get_canonical_block_number(hash) { @@ -917,12 +935,18 @@ impl BlockchainTree { if let Some(header) = canonical_header { info!(target: "blockchain_tree", ?block_hash, "Block is already canonical, ignoring."); // TODO: this could be fetched from the chainspec first - let td = self.externals.database().provider()?.header_td(block_hash)?.ok_or( + let td = self.externals.provider_factory.provider()?.header_td(block_hash)?.ok_or( CanonicalError::from(BlockValidationError::MissingTotalDifficulty { hash: *block_hash, }), )?; - if !self.externals.chain_spec.fork(Hardfork::Paris).active_at_ttd(td, U256::ZERO) { + if !self + .externals + .provider_factory + .chain_spec() + .fork(Hardfork::Paris) + .active_at_ttd(td, U256::ZERO) + { return Err(CanonicalError::from(BlockValidationError::BlockPreMerge { hash: *block_hash, }) @@ -946,18 +970,16 @@ impl BlockchainTree { let canonical = self.split_chain(chain_id, chain, ChainSplitTarget::Hash(*block_hash)); durations_recorder.record_relative(MakeCanonicalAction::SplitChain); - let mut block_fork = canonical.fork_block(); - let mut block_fork_number = canonical.fork_block_number(); + let mut fork_block = canonical.fork_block(); let mut chains_to_promote = vec![canonical]; // loop while fork blocks are found in Tree. - while let Some(chain_id) = self.block_indices().get_blocks_chain_id(&block_fork.hash) { - let chain = self.state.chains.remove(&chain_id).expect("To fork to be present"); - block_fork = chain.fork_block(); + while let Some(chain_id) = self.block_indices().get_blocks_chain_id(&fork_block.hash) { + let chain = self.state.chains.remove(&chain_id).expect("fork is present"); // canonical chain is lower part of the chain.
let canonical = - self.split_chain(chain_id, chain, ChainSplitTarget::Number(block_fork_number)); - block_fork_number = canonical.fork_block_number(); + self.split_chain(chain_id, chain, ChainSplitTarget::Number(fork_block.number)); + fork_block = canonical.fork_block(); chains_to_promote.push(canonical); } durations_recorder.record_relative(MakeCanonicalAction::SplitChainForks); @@ -989,7 +1011,7 @@ impl BlockchainTree { ); // if joins to the tip; - if new_canon_chain.fork_block_hash() == old_tip.hash { + if new_canon_chain.fork_block().hash == old_tip.hash { chain_notification = CanonStateNotification::Commit { new: Arc::new(new_canon_chain.clone()) }; // append to database @@ -1082,14 +1104,11 @@ impl BlockchainTree { /// Write the given chain to the database as canonical. fn commit_canonical_to_database(&self, chain: Chain) -> RethResult<()> { - let provider = DatabaseProvider::new_rw( - self.externals.db.tx_mut()?, - self.externals.chain_spec.clone(), - ); + let provider_rw = self.externals.provider_factory.provider_rw()?; let (blocks, state) = chain.into_inner(); - provider + provider_rw .append_blocks_with_bundle_state( blocks.into_blocks().collect(), state, @@ -1097,7 +1116,7 @@ impl BlockchainTree { ) .map_err(|e| BlockExecutionError::CanonicalCommit { inner: e.to_string() })?; - provider.commit()?; + provider_rw.commit()?; Ok(()) } @@ -1130,21 +1149,20 @@ impl BlockchainTree { revert_until: BlockNumber, ) -> RethResult> { // read data that is needed for new sidechain + let provider_rw = self.externals.provider_factory.provider_rw()?; - let provider = DatabaseProvider::new_rw( - self.externals.db.tx_mut()?, - self.externals.chain_spec.clone(), - ); - - let tip = provider.last_block_number()?; + let tip = provider_rw.last_block_number()?; let revert_range = (revert_until + 1)..=tip; info!(target: "blockchain_tree", "Unwinding canonical chain blocks: {:?}", revert_range); // read block and execution result from database. and remove traces of block from tables. 
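Condensing the make_canonical hunks above into one readable sequence (a sketch of the after-state, with surrounding bindings assumed in scope): the separate `block_fork` hash and `block_fork_number` collapse into a single `fork_block` num-hash, and the walk promotes each ancestor chain in turn:

    let mut fork_block = canonical.fork_block();
    let mut chains_to_promote = vec![canonical];
    // walk fork blocks while they are still found in the tree
    while let Some(chain_id) = self.block_indices().get_blocks_chain_id(&fork_block.hash) {
        let chain = self.state.chains.remove(&chain_id).expect("fork is present");
        // the canonical chain is the lower part of the split
        let canonical =
            self.split_chain(chain_id, chain, ChainSplitTarget::Number(fork_block.number));
        fork_block = canonical.fork_block();
        chains_to_promote.push(canonical);
    }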
- let blocks_and_execution = provider - .take_block_and_execution_range(self.externals.chain_spec.as_ref(), revert_range) + let blocks_and_execution = provider_rw + .take_block_and_execution_range( + self.externals.provider_factory.chain_spec().as_ref(), + revert_range, + ) .map_err(|e| BlockExecutionError::CanonicalRevert { inner: e.to_string() })?; - provider.commit()?; + provider_rw.commit()?; if blocks_and_execution.is_empty() { Ok(None) @@ -1186,18 +1204,16 @@ mod tests { use crate::block_buffer::BufferedBlocks; use assert_matches::assert_matches; use linked_hash_set::LinkedHashSet; - use reth_db::{ - tables, - test_utils::{create_test_rw_db, TempDatabase}, - transaction::DbTxMut, - DatabaseEnv, - }; + use reth_db::{tables, test_utils::TempDatabase, transaction::DbTxMut, DatabaseEnv}; use reth_interfaces::test_utils::TestConsensus; use reth_primitives::{ constants::EMPTY_ROOT_HASH, stage::StageCheckpoint, ChainSpecBuilder, B256, MAINNET, }; use reth_provider::{ - test_utils::{blocks::BlockChainTestData, TestExecutorFactory}, + test_utils::{ + blocks::BlockChainTestData, create_test_provider_factory_with_chain_spec, + TestExecutorFactory, + }, BlockWriter, BundleStateWithReceipts, ProviderFactory, }; use std::{ @@ -1208,8 +1224,6 @@ mod tests { fn setup_externals( exec_res: Vec, ) -> TreeExternals>, TestExecutorFactory> { - let db = create_test_rw_db(); - let consensus = Arc::new(TestConsensus::default()); let chain_spec = Arc::new( ChainSpecBuilder::default() .chain(MAINNET.chain) @@ -1217,18 +1231,19 @@ mod tests { .shanghai_activated() .build(), ); + let provider_factory = create_test_provider_factory_with_chain_spec(chain_spec.clone()); + let consensus = Arc::new(TestConsensus::default()); let executor_factory = TestExecutorFactory::new(chain_spec.clone()); executor_factory.extend(exec_res); - TreeExternals::new(db, consensus, executor_factory, chain_spec) + TreeExternals::new(provider_factory, consensus, executor_factory) } - fn setup_genesis(db: DB, mut genesis: SealedBlock) { + fn setup_genesis(factory: &ProviderFactory, mut genesis: SealedBlock) { // insert genesis to db. genesis.header.header.number = 10; genesis.header.header.state_root = EMPTY_ROOT_HASH; - let factory = ProviderFactory::new(&db, MAINNET.clone()); let provider = factory.provider_rw().unwrap(); provider.insert_block(genesis, None, None).unwrap(); @@ -1328,7 +1343,7 @@ mod tests { let externals = setup_externals(vec![exec2.clone(), exec1.clone(), exec2, exec1]); // last finalized block would be number 9. - setup_genesis(externals.db.clone(), genesis); + setup_genesis(&externals.provider_factory, genesis); // make tree let config = BlockchainTreeConfig::new(1, 2, 3, 2); diff --git a/crates/blockchain-tree/src/chain.rs b/crates/blockchain-tree/src/chain.rs index 36227f003b72..1b01b2787931 100644 --- a/crates/blockchain-tree/src/chain.rs +++ b/crates/blockchain-tree/src/chain.rs @@ -152,28 +152,29 @@ impl AppendableChain { ) })?; - let mut state = self.state.clone(); + let mut state = self.state().clone(); // Revert state to the state after execution of the parent block state.revert_to(parent.number); // Revert changesets to get the state of the parent that we need to apply the change. 
- let post_state_data = BundleStateDataRef { + let bundle_state_data = BundleStateDataRef { state: &state, sidechain_block_hashes: &side_chain_block_hashes, canonical_block_hashes, canonical_fork, }; - let block_state = - Self::validate_and_execute_sidechain(block.clone(), parent, post_state_data, externals) - .map_err(|err| InsertBlockError::new(block.block.clone(), err.into()))?; + let block_state = Self::validate_and_execute_sidechain( + block.clone(), + parent, + bundle_state_data, + externals, + ) + .map_err(|err| InsertBlockError::new(block.block.clone(), err.into()))?; state.extend(block_state); - let chain = - Self { chain: Chain { state, blocks: BTreeMap::from([(block.number, block)]) } }; - // If all is okay, return new chain back. Present chain is not modified. - Ok(chain) + Ok(Self { chain: Chain::from_block(block, state) }) } /// Validate and execute the given block that _extends the canonical chain_, validating its @@ -188,7 +189,7 @@ impl AppendableChain { fn validate_and_execute( block: SealedBlockWithSenders, parent_block: &SealedHeader, - post_state_data_provider: BSDP, + bundle_state_data_provider: BSDP, externals: &TreeExternals, block_kind: BlockKind, block_validation_kind: BlockValidationKind, @@ -205,11 +206,11 @@ impl AppendableChain { let block = block.unseal(); // get the state provider. - let db = externals.database(); - let canonical_fork = post_state_data_provider.canonical_fork(); - let state_provider = db.history_by_block_number(canonical_fork.number)?; + let canonical_fork = bundle_state_data_provider.canonical_fork(); + let state_provider = + externals.provider_factory.history_by_block_number(canonical_fork.number)?; - let provider = BundleStateProvider::new(state_provider, post_state_data_provider); + let provider = BundleStateProvider::new(state_provider, bundle_state_data_provider); let mut executor = externals.executor_factory.with_state(&provider); executor.execute_and_verify_receipt(&block, U256::MAX, Some(senders))?; @@ -235,7 +236,7 @@ impl AppendableChain { fn validate_and_execute_sidechain( block: SealedBlockWithSenders, parent_block: &SealedHeader, - post_state_data_provider: BSDP, + bundle_state_data_provider: BSDP, externals: &TreeExternals, ) -> RethResult where @@ -246,7 +247,7 @@ impl AppendableChain { Self::validate_and_execute( block, parent_block, - post_state_data_provider, + bundle_state_data_provider, externals, BlockKind::ForksHistoricalBlock, BlockValidationKind::SkipStateRootValidation, @@ -280,10 +281,10 @@ impl AppendableChain { DB: Database, EF: ExecutorFactory, { - let (_, parent_block) = self.blocks.last_key_value().expect("Chain has at least one block"); + let parent_block = self.chain.tip(); - let post_state_data = BundleStateDataRef { - state: &self.state, + let bundle_state_data = BundleStateDataRef { + state: self.state(), sidechain_block_hashes: &side_chain_block_hashes, canonical_block_hashes, canonical_fork, @@ -292,15 +293,14 @@ impl AppendableChain { let block_state = Self::validate_and_execute( block.clone(), parent_block, - post_state_data, + bundle_state_data, externals, block_kind, block_validation_kind, ) .map_err(|err| InsertBlockError::new(block.block.clone(), err.into()))?; // extend the state. 
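The sidechain execution path after the renames above, pieced together as a sketch (the generic bounds elided by the flattening are assumed; `externals`, `block`, and `senders` are in scope): historical state is resolved at the canonical fork point, the in-memory bundle state is layered on top of it, and the result is handed to the executor:

    // resolve historical state at the canonical fork point
    let canonical_fork = bundle_state_data.canonical_fork();
    let state_provider =
        externals.provider_factory.history_by_block_number(canonical_fork.number)?;

    // layer the in-memory bundle state on top of the historical provider
    let provider = BundleStateProvider::new(state_provider, bundle_state_data);

    let mut executor = externals.executor_factory.with_state(&provider);
    executor.execute_and_verify_receipt(&block, U256::MAX, Some(senders))?;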
- self.state.extend(block_state); - self.blocks.insert(block.number, block); + self.chain.append_block(block, block_state); Ok(()) } } diff --git a/crates/blockchain-tree/src/externals.rs b/crates/blockchain-tree/src/externals.rs index 06cd694d3b0d..9bd12195aa44 100644 --- a/crates/blockchain-tree/src/externals.rs +++ b/crates/blockchain-tree/src/externals.rs @@ -2,7 +2,7 @@ use reth_db::{cursor::DbCursorRO, database::Database, tables, transaction::DbTx}; use reth_interfaces::{consensus::Consensus, RethResult}; -use reth_primitives::{BlockHash, BlockNumber, ChainSpec}; +use reth_primitives::{BlockHash, BlockNumber}; use reth_provider::ProviderFactory; use std::{collections::BTreeMap, sync::Arc}; @@ -17,34 +17,26 @@ use std::{collections::BTreeMap, sync::Arc}; /// - The chain spec #[derive(Debug)] pub struct TreeExternals { - /// The database, used to commit the canonical chain, or unwind it. - pub(crate) db: DB, + /// The provider factory, used to commit the canonical chain, or unwind it. + pub(crate) provider_factory: ProviderFactory, /// The consensus engine. pub(crate) consensus: Arc, /// The executor factory to execute blocks with. pub(crate) executor_factory: EF, - /// The chain spec. - pub(crate) chain_spec: Arc, } impl TreeExternals { /// Create new tree externals. pub fn new( - db: DB, + provider_factory: ProviderFactory, consensus: Arc, executor_factory: EF, - chain_spec: Arc, ) -> Self { - Self { db, consensus, executor_factory, chain_spec } + Self { provider_factory, consensus, executor_factory } } } impl TreeExternals { - /// Return shareable database helper structure. - pub fn database(&self) -> ProviderFactory<&DB> { - ProviderFactory::new(&self.db, self.chain_spec.clone()) - } - /// Fetches the latest canonical block hashes by walking backwards from the head. /// /// Returns the hashes sorted by increasing block numbers @@ -53,8 +45,9 @@ impl TreeExternals { num_hashes: usize, ) -> RethResult> { Ok(self - .db - .tx()? + .provider_factory + .provider()? + .tx_ref() .cursor_read::()? .walk_back(None)? .take(num_hashes) diff --git a/crates/blockchain-tree/src/noop.rs b/crates/blockchain-tree/src/noop.rs index 95709dc7de81..732a7d1a09c9 100644 --- a/crates/blockchain-tree/src/noop.rs +++ b/crates/blockchain-tree/src/noop.rs @@ -74,6 +74,10 @@ impl BlockchainTreeViewer for NoopBlockchainTree { None } + fn block_with_senders_by_hash(&self, _hash: BlockHash) -> Option { + None + } + fn buffered_block_by_hash(&self, _block_hash: BlockHash) -> Option { None } diff --git a/crates/blockchain-tree/src/shareable.rs b/crates/blockchain-tree/src/shareable.rs index ebb57ca1c783..d4776a67ce50 100644 --- a/crates/blockchain-tree/src/shareable.rs +++ b/crates/blockchain-tree/src/shareable.rs @@ -117,6 +117,11 @@ impl BlockchainTreeViewer for ShareableBlockc self.tree.read().block_by_hash(block_hash).cloned() } + fn block_with_senders_by_hash(&self, block_hash: BlockHash) -> Option { + trace!(target: "blockchain_tree", ?block_hash, "Returning block by hash"); + self.tree.read().block_with_senders_by_hash(block_hash).cloned() + } + fn buffered_block_by_hash(&self, block_hash: BlockHash) -> Option { self.tree.read().get_buffered_block(&block_hash).map(|b| b.block.clone()) } diff --git a/crates/blockchain-tree/src/state.rs b/crates/blockchain-tree/src/state.rs index bca7ddf40957..8c4c58229414 100644 --- a/crates/blockchain-tree/src/state.rs +++ b/crates/blockchain-tree/src/state.rs @@ -56,10 +56,21 @@ impl TreeState { /// Returns the block with matching hash from any side-chain. 
/// /// Caution: This will not return blocks from the canonical chain. + #[inline] pub(crate) fn block_by_hash(&self, block_hash: BlockHash) -> Option<&SealedBlock> { + self.block_with_senders_by_hash(block_hash).map(|block| &block.block) + } + /// Returns the block with matching hash from any side-chain. + /// + /// Caution: This will not return blocks from the canonical chain. + #[inline] + pub(crate) fn block_with_senders_by_hash( + &self, + block_hash: BlockHash, + ) -> Option<&SealedBlockWithSenders> { let id = self.block_indices.get_blocks_chain_id(&block_hash)?; let chain = self.chains.get(&id)?; - chain.block(block_hash) + chain.block_with_senders(block_hash) } /// Returns the block's receipts with matching hash from any side-chain. diff --git a/crates/consensus/beacon/src/engine/sync.rs b/crates/consensus/beacon/src/engine/sync.rs index 07780d4330e1..10c18e742089 100644 --- a/crates/consensus/beacon/src/engine/sync.rs +++ b/crates/consensus/beacon/src/engine/sync.rs @@ -394,16 +394,16 @@ mod tests { use super::*; use assert_matches::assert_matches; use futures::poll; - use reth_db::{ - mdbx::DatabaseEnv, - test_utils::{create_test_rw_db, TempDatabase}, - }; + use reth_db::{mdbx::DatabaseEnv, test_utils::TempDatabase}; use reth_interfaces::{p2p::either::EitherDownloader, test_utils::TestFullBlockClient}; use reth_primitives::{ constants::ETHEREUM_BLOCK_GAS_LIMIT, stage::StageCheckpoint, BlockBody, ChainSpec, ChainSpecBuilder, Header, SealedHeader, MAINNET, }; - use reth_provider::{test_utils::TestExecutorFactory, BundleStateWithReceipts}; + use reth_provider::{ + test_utils::{create_test_provider_factory_with_chain_spec, TestExecutorFactory}, + BundleStateWithReceipts, + }; use reth_stages::{test_utils::TestStages, ExecOutput, StageError}; use reth_tasks::TokioTaskExecutor; use std::{collections::VecDeque, future::poll_fn, sync::Arc}; @@ -451,7 +451,6 @@ mod tests { /// Builds the pipeline. 
fn build(self, chain_spec: Arc) -> Pipeline>> { reth_tracing::init_test_tracing(); - let db = create_test_rw_db(); let executor_factory = TestExecutorFactory::new(chain_spec.clone()); executor_factory.extend(self.executor_results); @@ -466,7 +465,7 @@ mod tests { pipeline = pipeline.with_max_block(max_block); } - pipeline.build(db, chain_spec) + pipeline.build(create_test_provider_factory_with_chain_spec(chain_spec)) } } diff --git a/crates/consensus/beacon/src/engine/test_utils.rs b/crates/consensus/beacon/src/engine/test_utils.rs index 092ce9f5e218..781161bb3a02 100644 --- a/crates/consensus/beacon/src/engine/test_utils.rs +++ b/crates/consensus/beacon/src/engine/test_utils.rs @@ -26,17 +26,15 @@ use reth_payload_builder::test_utils::spawn_test_payload_service; use reth_primitives::{BlockNumber, ChainSpec, PruneModes, Receipt, B256, U256}; use reth_provider::{ providers::BlockchainProvider, test_utils::TestExecutorFactory, BlockExecutor, - BundleStateWithReceipts, ExecutorFactory, ProviderFactory, PrunableBlockExecutor, + BundleStateWithReceipts, ExecutorFactory, HeaderSyncMode, ProviderFactory, + PrunableBlockExecutor, }; use reth_prune::Pruner; -use reth_revm::Factory; +use reth_revm::EvmProcessorFactory; use reth_rpc_types::engine::{ CancunPayloadFields, ExecutionPayload, ForkchoiceState, ForkchoiceUpdated, PayloadStatus, }; -use reth_stages::{ - sets::DefaultStages, stages::HeaderSyncMode, test_utils::TestStages, ExecOutput, Pipeline, - StageError, -}; +use reth_stages::{sets::DefaultStages, test_utils::TestStages, ExecOutput, Pipeline, StageError}; use reth_tasks::TokioTaskExecutor; use std::{collections::VecDeque, sync::Arc}; use tokio::sync::{oneshot, watch}; @@ -47,7 +45,7 @@ type TestBeaconConsensusEngine = BeaconConsensusEngine< Arc, ShareableBlockchainTree< Arc, - EitherExecutorFactory, + EitherExecutorFactory, >, >, Arc>, @@ -457,6 +455,8 @@ where pub fn build(self) -> (TestBeaconConsensusEngine, TestEnv>) { reth_tracing::init_test_tracing(); let db = create_test_rw_db(); + let provider_factory = + ProviderFactory::new(db.clone(), self.base_config.chain_spec.clone()); let consensus: Arc = match self.base_config.consensus { TestConsensusConfig::Real => { @@ -481,9 +481,9 @@ where executor_factory.extend(results); EitherExecutorFactory::Left(executor_factory) } - TestExecutorConfig::Real => { - EitherExecutorFactory::Right(Factory::new(self.base_config.chain_spec.clone())) - } + TestExecutorConfig::Real => EitherExecutorFactory::Right(EvmProcessorFactory::new( + self.base_config.chain_spec.clone(), + )), }; // Setup pipeline @@ -498,10 +498,11 @@ where .into_task(); let body_downloader = BodiesDownloaderBuilder::default() - .build(client.clone(), consensus.clone(), db.clone()) + .build(client.clone(), consensus.clone(), provider_factory.clone()) .into_task(); Pipeline::builder().add_stages(DefaultStages::new( + ProviderFactory::new(db.clone(), self.base_config.chain_spec.clone()), HeaderSyncMode::Tip(tip_rx.clone()), Arc::clone(&consensus), header_downloader, @@ -515,22 +516,16 @@ where pipeline = pipeline.with_max_block(max_block); } - let pipeline = pipeline.build(db.clone(), self.base_config.chain_spec.clone()); + let pipeline = pipeline.build(provider_factory.clone()); // Setup blockchain tree - let externals = TreeExternals::new( - db.clone(), - consensus, - executor_factory, - self.base_config.chain_spec.clone(), - ); + let externals = TreeExternals::new(provider_factory.clone(), consensus, executor_factory); let config = BlockchainTreeConfig::new(1, 2, 3, 2); let tree = 
ShareableBlockchainTree::new( BlockchainTree::new(externals, config, None).expect("failed to create tree"), ); - let shareable_db = ProviderFactory::new(db.clone(), self.base_config.chain_spec.clone()); let latest = self.base_config.chain_spec.genesis_header().seal_slow(); - let blockchain_provider = BlockchainProvider::with_latest(shareable_db, tree, latest); + let blockchain_provider = BlockchainProvider::with_latest(provider_factory, tree, latest); let pruner = Pruner::new( db.clone(), diff --git a/crates/interfaces/src/blockchain_tree/mod.rs b/crates/interfaces/src/blockchain_tree/mod.rs index 8a365361bd88..3d77fef03cf9 100644 --- a/crates/interfaces/src/blockchain_tree/mod.rs +++ b/crates/interfaces/src/blockchain_tree/mod.rs @@ -237,6 +237,12 @@ pub trait BlockchainTreeViewer: Send + Sync { /// disconnected from the canonical chain. fn block_by_hash(&self, hash: BlockHash) -> Option; + /// Returns the block with matching hash from the tree, if it exists. + /// + /// Caution: This will not return blocks from the canonical chain or buffered blocks that are + /// disconnected from the canonical chain. + fn block_with_senders_by_hash(&self, hash: BlockHash) -> Option; + /// Returns the _buffered_ (disconnected) block with matching hash from the internal buffer if /// it exists. /// @@ -295,6 +301,11 @@ pub trait BlockchainTreeViewer: Send + Sync { self.block_by_hash(self.pending_block_num_hash()?.hash) } + /// Returns the pending block if there is one. + fn pending_block_with_senders(&self) -> Option { + self.block_with_senders_by_hash(self.pending_block_num_hash()?.hash) + } + /// Returns the pending block and its receipts in one call. /// /// This exists to prevent a potential data race if the pending block changes in between diff --git a/crates/interfaces/src/error.rs b/crates/interfaces/src/error.rs index ef14d1211eb4..58bef122b424 100644 --- a/crates/interfaces/src/error.rs +++ b/crates/interfaces/src/error.rs @@ -1,3 +1,13 @@ +use crate::{ + blockchain_tree::error::{BlockchainTreeError, CanonicalError}, + consensus::ConsensusError, + db::DatabaseError, + executor::BlockExecutionError, + provider::ProviderError, +}; +use reth_network_api::NetworkError; +use reth_primitives::fs::FsPathError; + /// Result alias for [`RethError`]. 
pub type RethResult = Result; @@ -6,47 +16,55 @@ pub type RethResult = Result; #[allow(missing_docs)] pub enum RethError { #[error(transparent)] - Execution(#[from] crate::executor::BlockExecutionError), + Execution(#[from] BlockExecutionError), #[error(transparent)] - Consensus(#[from] crate::consensus::ConsensusError), + Consensus(#[from] ConsensusError), #[error(transparent)] - Database(#[from] crate::db::DatabaseError), + Database(#[from] DatabaseError), #[error(transparent)] - Provider(#[from] crate::provider::ProviderError), + Provider(#[from] ProviderError), #[error(transparent)] - Network(#[from] reth_network_api::NetworkError), + Network(#[from] NetworkError), #[error(transparent)] - Canonical(#[from] crate::blockchain_tree::error::CanonicalError), + Canonical(#[from] CanonicalError), #[error("{0}")] Custom(String), } -impl From for RethError { - fn from(error: crate::blockchain_tree::error::BlockchainTreeError) -> Self { - RethError::Canonical(error.into()) +impl From for RethError { + fn from(error: BlockchainTreeError) -> Self { + RethError::Canonical(CanonicalError::BlockchainTree(error)) } } -impl From for RethError { - fn from(err: reth_primitives::fs::FsPathError) -> Self { +impl From for RethError { + fn from(err: FsPathError) -> Self { RethError::Custom(err.to_string()) } } -// We don't want these types to be too large because they're used in a lot of places. -const _SIZE_ASSERTIONS: () = { - // Main error. - let _: [(); 64] = [(); std::mem::size_of::()]; +// Some types are used a lot. Make sure they don't unintentionally get bigger. +#[cfg(all(target_arch = "x86_64", target_pointer_width = "64"))] +mod size_asserts { + use super::*; - // Biggest variant. - let _: [(); 64] = [(); std::mem::size_of::()]; + macro_rules! static_assert_size { + ($t:ty, $sz:expr) => { + const _: [(); $sz] = [(); std::mem::size_of::<$t>()]; + }; + } - // Other common types. - let _: [(); 16] = [(); std::mem::size_of::()]; -}; + static_assert_size!(RethError, 56); + static_assert_size!(BlockExecutionError, 48); + static_assert_size!(ConsensusError, 48); + static_assert_size!(DatabaseError, 16); + static_assert_size!(ProviderError, 48); + static_assert_size!(NetworkError, 0); + static_assert_size!(CanonicalError, 48); +} diff --git a/crates/interfaces/src/p2p/error.rs b/crates/interfaces/src/p2p/error.rs index 53cf68e4b4de..29c238e5f529 100644 --- a/crates/interfaces/src/p2p/error.rs +++ b/crates/interfaces/src/p2p/error.rs @@ -1,5 +1,5 @@ use super::headers::client::HeadersRequest; -use crate::{consensus::ConsensusError, db}; +use crate::{consensus::ConsensusError, db::DatabaseError, provider::ProviderError}; use reth_network_api::ReputationChangeKind; use reth_primitives::{ BlockHashOrNumber, BlockNumber, GotExpected, GotExpectedBoxed, Header, WithPeerId, B256, @@ -177,9 +177,15 @@ pub enum DownloadError { /// Error while executing the request. #[error(transparent)] RequestError(#[from] RequestError), - /// Error while reading data from database. + /// Provider error. 
#[error(transparent)] - DatabaseError(#[from] db::DatabaseError), + Provider(#[from] ProviderError), +} + +impl From for DownloadError { + fn from(error: DatabaseError) -> Self { + Self::Provider(ProviderError::Database(error)) + } } #[cfg(test)] diff --git a/crates/interfaces/src/provider.rs b/crates/interfaces/src/provider.rs index f5f0a7fccf6b..9fad40efd8b3 100644 --- a/crates/interfaces/src/provider.rs +++ b/crates/interfaces/src/provider.rs @@ -20,25 +20,30 @@ pub enum ProviderError { /// Error when recovering the sender for a transaction #[error("failed to recover sender for transaction")] SenderRecoveryError, + /// Inconsistent header gap. + #[error("inconsistent header gap in the database")] + InconsistentHeaderGap, /// The header number was not found for the given block hash. #[error("block hash {0} does not exist in Headers table")] BlockHashNotFound(BlockHash), /// A block body is missing. #[error("block meta not found for block #{0}")] BlockBodyIndicesNotFound(BlockNumber), - /// The transition id was found for the given address and storage key, but the changeset was + /// The transition ID was found for the given address and storage key, but the changeset was /// not found. - #[error("storage ChangeSet address: ({address} key: {storage_key:?}) for block #{block_number} does not exist")] + #[error("storage change set for address {address} and key {storage_key} at block #{block_number} does not exist")] StorageChangesetNotFound { /// The block number found for the address and storage key. block_number: BlockNumber, /// The account address. address: Address, /// The storage key. - storage_key: B256, + // NOTE: This is a Box only because otherwise this variant is 16 bytes larger than the + // second largest (which uses `BlockHashOrNumber`). + storage_key: Box, }, /// The block number was found for the given address, but the changeset was not found. - #[error("account {address} ChangeSet for block #{block_number} does not exist")] + #[error("account change set for address {address} at block #{block_number} does not exist")] AccountChangesetNotFound { /// Block number found for the address. 
block_number: BlockNumber, diff --git a/crates/net/downloaders/Cargo.toml b/crates/net/downloaders/Cargo.toml index cdb3317dbe84..3a50908b13d2 100644 --- a/crates/net/downloaders/Cargo.toml +++ b/crates/net/downloaders/Cargo.toml @@ -12,8 +12,8 @@ description = "Implementations of various block downloaders" # reth reth-interfaces.workspace = true reth-primitives.workspace = true -reth-db.workspace = true reth-tasks.workspace = true +reth-provider.workspace = true # async futures.workspace = true @@ -33,6 +33,7 @@ rayon.workspace = true thiserror.workspace = true # optional deps for the test-utils feature +reth-db = { workspace = true, optional = true } alloy-rlp = { workspace = true, optional = true } tempfile = { workspace = true, optional = true } itertools = { workspace = true, optional = true } @@ -50,4 +51,4 @@ itertools.workspace = true tempfile.workspace = true [features] -test-utils = ["dep:alloy-rlp", "dep:tempfile", "dep:itertools", "reth-interfaces/test-utils"] +test-utils = ["dep:alloy-rlp", "dep:tempfile", "dep:itertools", "reth-db/test-utils", "reth-interfaces/test-utils"] diff --git a/crates/net/downloaders/src/bodies/bodies.rs b/crates/net/downloaders/src/bodies/bodies.rs index a1451bb5b159..b601865d3a6c 100644 --- a/crates/net/downloaders/src/bodies/bodies.rs +++ b/crates/net/downloaders/src/bodies/bodies.rs @@ -2,7 +2,6 @@ use super::queue::BodiesRequestQueue; use crate::{bodies::task::TaskDownloader, metrics::BodyDownloaderMetrics}; use futures::Stream; use futures_util::StreamExt; -use reth_db::{cursor::DbCursorRO, database::Database, tables, transaction::DbTx}; use reth_interfaces::{ consensus::Consensus, p2p::{ @@ -15,6 +14,7 @@ use reth_interfaces::{ }, }; use reth_primitives::{BlockNumber, SealedHeader}; +use reth_provider::HeaderProvider; use reth_tasks::{TaskSpawner, TokioTaskExecutor}; use std::{ cmp::Ordering, @@ -27,22 +27,18 @@ use std::{ }; use tracing::info; -/// The scope for headers downloader metrics. -pub const BODIES_DOWNLOADER_SCOPE: &str = "downloaders.bodies"; - /// Downloads bodies in batches. /// /// All blocks in a batch are fetched at the same time. #[must_use = "Stream does nothing unless polled"] #[derive(Debug)] -pub struct BodiesDownloader { +pub struct BodiesDownloader { /// The bodies client client: Arc, /// The consensus client consensus: Arc, - // TODO: make this a [HeaderProvider] /// The database handle - db: DB, + provider: Provider, /// The maximum number of non-empty blocks per one request request_limit: u64, /// The maximum number of block bodies returned at once from the stream @@ -67,10 +63,10 @@ pub struct BodiesDownloader { metrics: BodyDownloaderMetrics, } -impl BodiesDownloader +impl BodiesDownloader where B: BodiesClient + 'static, - DB: Database + Unpin + 'static, + Provider: HeaderProvider + Unpin + 'static, { /// Returns the next contiguous request. fn next_headers_request(&mut self) -> DownloadResult>> { @@ -103,47 +99,29 @@ where return Ok(None) } - // Collection of results - let mut headers = Vec::new(); - - // Non empty headers count - let mut non_empty_headers = 0; - let mut current_block_num = *range.start(); - - // Acquire cursors over canonical and header tables - let tx = self.db.tx()?; - let mut canonical_cursor = tx.cursor_read::()?; - let mut header_cursor = tx.cursor_read::()?; - // Collect headers while // 1. Current block number is in range // 2. The number of non empty headers is less than maximum // 3. 
The total number of headers is less than the stream batch size (this is only - // relevant if the range consists entirely of empty headers) - while range.contains(¤t_block_num) && - non_empty_headers < max_non_empty && - headers.len() < self.stream_batch_size - { - // Find the block hash. - let (number, hash) = canonical_cursor - .seek_exact(current_block_num)? - .ok_or(DownloadError::MissingHeader { block_number: current_block_num })?; - // Find the block header. - let (_, header) = header_cursor - .seek_exact(number)? - .ok_or(DownloadError::MissingHeader { block_number: number })?; - - // If the header is not empty, increment the counter - if !header.is_empty() { - non_empty_headers += 1; + // relevant if the range consists entirely of empty headers) + let mut collected = 0; + let mut non_empty_headers = 0; + let headers = self.provider.sealed_headers_while(range.clone(), |header| { + let should_take = range.contains(&header.number) && + non_empty_headers < max_non_empty && + collected < self.stream_batch_size; + + if should_take { + collected += 1; + if !header.is_empty() { + non_empty_headers += 1; + } + true + } else { + false } + })?; - // Add header to the result collection - headers.push(header.seal(hash)); - - // Increment current block number - current_block_num += 1; - } Ok(Some(headers).filter(|h| !h.is_empty())) } @@ -286,10 +264,10 @@ where } } -impl BodiesDownloader +impl BodiesDownloader where B: BodiesClient + 'static, - DB: Database + Unpin + 'static, + Provider: HeaderProvider + Unpin + 'static, Self: BodyDownloader + 'static, { /// Spawns the downloader task via [tokio::task::spawn] @@ -306,10 +284,10 @@ where } } -impl BodyDownloader for BodiesDownloader +impl BodyDownloader for BodiesDownloader where B: BodiesClient + 'static, - DB: Database + Unpin + 'static, + Provider: HeaderProvider + Unpin + 'static, { /// Set a new download range (exclusive). /// @@ -354,10 +332,10 @@ where } } -impl Stream for BodiesDownloader +impl Stream for BodiesDownloader where B: BodiesClient + 'static, - DB: Database + Unpin + 'static, + Provider: HeaderProvider + Unpin + 'static, { type Item = BodyDownloaderResult; @@ -557,15 +535,15 @@ impl BodiesDownloaderBuilder { } /// Consume self and return the concurrent downloader. 
- pub fn build( + pub fn build( self, client: B, consensus: Arc, - db: DB, - ) -> BodiesDownloader + provider: Provider, + ) -> BodiesDownloader where B: BodiesClient + 'static, - DB: Database, + Provider: HeaderProvider, { let Self { request_limit, @@ -578,7 +556,7 @@ impl BodiesDownloaderBuilder { BodiesDownloader { client: Arc::new(client), consensus, - db, + provider, request_limit, stream_batch_size, max_buffered_blocks_size_bytes, @@ -605,7 +583,8 @@ mod tests { use futures_util::stream::StreamExt; use reth_db::test_utils::create_test_rw_db; use reth_interfaces::test_utils::{generators, generators::random_block_range, TestConsensus}; - use reth_primitives::{BlockBody, B256}; + use reth_primitives::{BlockBody, B256, MAINNET}; + use reth_provider::ProviderFactory; use std::{collections::HashMap, sync::Arc}; // Check that the blocks are emitted in order of block number, not in order of @@ -624,7 +603,7 @@ mod tests { let mut downloader = BodiesDownloaderBuilder::default().build( client.clone(), Arc::new(TestConsensus::default()), - db, + ProviderFactory::new(db, MAINNET.clone()), ); downloader.set_download_range(0..=19).expect("failed to set download range"); @@ -659,9 +638,12 @@ mod tests { let request_limit = 10; let client = Arc::new(TestBodiesClient::default().with_bodies(bodies.clone())); - let mut downloader = BodiesDownloaderBuilder::default() - .with_request_limit(request_limit) - .build(client.clone(), Arc::new(TestConsensus::default()), db); + let mut downloader = + BodiesDownloaderBuilder::default().with_request_limit(request_limit).build( + client.clone(), + Arc::new(TestConsensus::default()), + ProviderFactory::new(db, MAINNET.clone()), + ); downloader.set_download_range(0..=199).expect("failed to set download range"); let _ = downloader.collect::>().await; @@ -686,7 +668,11 @@ mod tests { let mut downloader = BodiesDownloaderBuilder::default() .with_stream_batch_size(stream_batch_size) .with_request_limit(request_limit) - .build(client.clone(), Arc::new(TestConsensus::default()), db); + .build( + client.clone(), + Arc::new(TestConsensus::default()), + ProviderFactory::new(db, MAINNET.clone()), + ); let mut range_start = 0; while range_start < 100 { @@ -715,7 +701,7 @@ mod tests { let mut downloader = BodiesDownloaderBuilder::default().with_stream_batch_size(100).build( client.clone(), Arc::new(TestConsensus::default()), - db, + ProviderFactory::new(db, MAINNET.clone()), ); // Set and download the first range @@ -752,7 +738,11 @@ mod tests { .with_stream_batch_size(10) .with_request_limit(1) .with_max_buffered_blocks_size_bytes(1) - .build(client.clone(), Arc::new(TestConsensus::default()), db); + .build( + client.clone(), + Arc::new(TestConsensus::default()), + ProviderFactory::new(db, MAINNET.clone()), + ); // Set and download the entire range downloader.set_download_range(0..=199).expect("failed to set download range"); @@ -779,7 +769,11 @@ mod tests { let mut downloader = BodiesDownloaderBuilder::default() .with_request_limit(3) .with_stream_batch_size(100) - .build(client.clone(), Arc::new(TestConsensus::default()), db); + .build( + client.clone(), + Arc::new(TestConsensus::default()), + ProviderFactory::new(db, MAINNET.clone()), + ); // Download the requested range downloader.set_download_range(0..=99).expect("failed to set download range"); diff --git a/crates/net/downloaders/src/bodies/task.rs b/crates/net/downloaders/src/bodies/task.rs index 97748f54f7e1..9a713a8539bc 100644 --- a/crates/net/downloaders/src/bodies/task.rs +++ 
b/crates/net/downloaders/src/bodies/task.rs @@ -42,16 +42,17 @@ impl TaskDownloader { /// # Example /// /// ``` - /// use reth_db::database::Database; /// use reth_downloaders::bodies::{bodies::BodiesDownloaderBuilder, task::TaskDownloader}; /// use reth_interfaces::{consensus::Consensus, p2p::bodies::client::BodiesClient}; + /// use reth_provider::HeaderProvider; /// use std::sync::Arc; - /// fn t<B: BodiesClient + 'static, DB: Database + Unpin + 'static>( + /// + /// fn t<B: BodiesClient + 'static, Provider: HeaderProvider + Unpin + 'static>( /// client: Arc<B>, /// consensus: Arc<dyn Consensus>, - /// db: Arc<DB>, + /// provider: Provider, /// ) { - /// let downloader = BodiesDownloaderBuilder::default().build(client, consensus, db); + /// let downloader = BodiesDownloaderBuilder::default().build(client, consensus, provider); /// let downloader = TaskDownloader::spawn(downloader); /// } /// ``` @@ -170,6 +171,8 @@ mod tests { use assert_matches::assert_matches; use reth_db::test_utils::create_test_rw_db; use reth_interfaces::{p2p::error::DownloadError, test_utils::TestConsensus}; + use reth_primitives::MAINNET; + use reth_provider::ProviderFactory; use std::sync::Arc; #[tokio::test(flavor = "multi_thread")] @@ -187,7 +190,7 @@ mod tests { let downloader = BodiesDownloaderBuilder::default().build( client.clone(), Arc::new(TestConsensus::default()), - db, + ProviderFactory::new(db, MAINNET.clone()), ); let mut downloader = TaskDownloader::spawn(downloader); @@ -209,7 +212,7 @@ mod tests { let downloader = BodiesDownloaderBuilder::default().build( Arc::new(TestBodiesClient::default()), Arc::new(TestConsensus::default()), - db, + ProviderFactory::new(db, MAINNET.clone()), ); let mut downloader = TaskDownloader::spawn(downloader); diff --git a/crates/net/downloaders/src/test_utils/file_client.rs b/crates/net/downloaders/src/test_utils/file_client.rs index 45df474e1e90..69320fe4b171 100644 --- a/crates/net/downloaders/src/test_utils/file_client.rs +++ b/crates/net/downloaders/src/test_utils/file_client.rs @@ -267,7 +267,8 @@ mod tests { }, test_utils::TestConsensus, }; - use reth_primitives::SealedHeader; + use reth_primitives::{SealedHeader, MAINNET}; + use reth_provider::ProviderFactory; use std::{ io::{Read, Seek, SeekFrom, Write}, sync::Arc, @@ -291,7 +292,7 @@ mod tests { let mut downloader = BodiesDownloaderBuilder::default().build( client.clone(), Arc::new(TestConsensus::default()), - db, + ProviderFactory::new(db, MAINNET.clone()), ); downloader.set_download_range(0..=19).expect("failed to set download range"); @@ -373,7 +374,7 @@ mod tests { let mut downloader = BodiesDownloaderBuilder::default().build( client.clone(), Arc::new(TestConsensus::default()), - db, + ProviderFactory::new(db, MAINNET.clone()), ); downloader.set_download_range(0..=19).expect("failed to set download range"); diff --git a/crates/net/eth-wire/src/capability.rs b/crates/net/eth-wire/src/capability.rs index f33993a52a42..3a72504e93b5 100644 --- a/crates/net/eth-wire/src/capability.rs +++ b/crates/net/eth-wire/src/capability.rs @@ -236,10 +236,8 @@ pub enum SharedCapability { }, /// Any other unknown capability. UnknownCapability { - /// Name of the capability. - name: Cow<'static, str>, - /// (Highest) negotiated version of the eth capability. - version: u8, + /// Shared capability. + cap: Capability, /// The message ID offset for this capability.
/// /// This represents the message ID offset for the first message of the eth capability in @@ -259,7 +257,10 @@ impl SharedCapability { match name { "eth" => Ok(Self::eth(EthVersion::try_from(version)?, offset)), - _ => Ok(Self::UnknownCapability { name: name.to_string().into(), version, offset }), + _ => Ok(Self::UnknownCapability { + cap: Capability::new(name.to_string(), version as usize), + offset, + }), } } @@ -268,12 +269,20 @@ impl SharedCapability { Self::Eth { version, offset } } + /// Returns the capability. + pub fn capability(&self) -> Cow<'_, Capability> { + match self { + SharedCapability::Eth { version, .. } => Cow::Owned(Capability::eth(*version)), + SharedCapability::UnknownCapability { cap, .. } => Cow::Borrowed(cap), + } + } + /// Returns the name of the capability. #[inline] pub fn name(&self) -> &str { match self { SharedCapability::Eth { .. } => "eth", - SharedCapability::UnknownCapability { name, .. } => name, + SharedCapability::UnknownCapability { cap, .. } => cap.name.as_ref(), } } @@ -287,7 +296,7 @@ impl SharedCapability { pub fn version(&self) -> u8 { match self { SharedCapability::Eth { version, .. } => *version as u8, - SharedCapability::UnknownCapability { version, .. } => *version, + SharedCapability::UnknownCapability { cap, .. } => cap.version as u8, } } @@ -348,9 +357,63 @@ impl SharedCapabilities { } /// Returns the negotiated eth version if it is shared. + #[inline] pub fn eth_version(&self) -> Result { self.eth().map(|cap| cap.version()) } + + /// Returns true if the shared capabilities contain the given capability. + #[inline] + pub fn contains(&self, cap: &Capability) -> bool { + self.find(cap).is_some() + } + + /// Returns the shared capability for the given capability. + #[inline] + pub fn find(&self, cap: &Capability) -> Option<&SharedCapability> { + self.0.iter().find(|c| c.version() == cap.version as u8 && c.name() == cap.name) + } + + /// Returns the matching shared capability for the given capability offset. + /// + /// `offset` is the multiplexed message id offset of the capability relative to + /// [`MAX_RESERVED_MESSAGE_ID`]. + #[inline] + pub fn find_by_relative_offset(&self, offset: u8) -> Option<&SharedCapability> { + self.find_by_offset(offset.saturating_add(MAX_RESERVED_MESSAGE_ID)) + } + + /// Returns the matching shared capability for the given capability offset. + /// + /// `offset` is the multiplexed message id offset of the capability that includes the reserved + /// message id space. + #[inline] + pub fn find_by_offset(&self, offset: u8) -> Option<&SharedCapability> { + let mut iter = self.0.iter(); + let mut cap = iter.next()?; + if offset < cap.message_id_offset() { + // reserved message id space + return None + } + + for next in iter { + if offset < next.message_id_offset() { + return Some(cap) + } + cap = next + } + + Some(cap) + } + + /// Returns the shared capability for the given capability or an error if it's not compatible. + #[inline] + pub fn ensure_matching_capability( + &self, + cap: &Capability, + ) -> Result<&SharedCapability, UnsupportedCapabilityError> { + self.find(cap).ok_or_else(|| UnsupportedCapabilityError { capability: cap.clone() }) + } } /// Determines the offsets for each shared capability between the input list of peer @@ -452,6 +515,13 @@ pub enum SharedCapabilityError { ReservedMessageIdOffset(u8), } +/// An error thrown when capabilities mismatch. 
+#[derive(Debug, thiserror::Error)] +#[error("unsupported capability {capability}")] +pub struct UnsupportedCapabilityError { + capability: Capability, +} + #[cfg(test)] mod tests { use super::*; @@ -559,4 +629,46 @@ mod tests { Err(P2PStreamError::HandshakeError(P2PHandshakeError::NoSharedCapabilities)) )) } + + #[test] + fn test_find_by_offset() { + let local_capabilities = vec![EthVersion::Eth66.into()]; + let peer_capabilities = vec![EthVersion::Eth66.into()]; + + let shared = SharedCapabilities::try_new(local_capabilities, peer_capabilities).unwrap(); + + assert!(shared.find_by_relative_offset(0).is_none()); + let shared_eth = shared.find_by_relative_offset(1).unwrap(); + assert_eq!(shared_eth.name(), "eth"); + + let shared_eth = shared.find_by_offset(MAX_RESERVED_MESSAGE_ID + 1).unwrap(); + assert_eq!(shared_eth.name(), "eth"); + } + + #[test] + fn test_find_by_offset_many() { + let cap = Capability::new_static("aaa", 1); + let proto = Protocol::new(cap.clone(), 5); + let local_capabilities = vec![proto.clone(), EthVersion::Eth66.into()]; + let peer_capabilities = vec![cap, EthVersion::Eth66.into()]; + + let shared = SharedCapabilities::try_new(local_capabilities, peer_capabilities).unwrap(); + + assert!(shared.find_by_relative_offset(0).is_none()); + let shared_aaa = shared.find_by_relative_offset(1).unwrap(); + assert_eq!(shared_aaa.name(), proto.cap.name); + + let shared_aaa = shared.find_by_offset(MAX_RESERVED_MESSAGE_ID + 1).unwrap(); + assert_eq!(shared_aaa.name(), proto.cap.name); + + // the 5th shared message is the last message of the aaa capability + let shared_aaa = shared.find_by_relative_offset(5).unwrap(); + assert_eq!(shared_aaa.name(), proto.cap.name); + let shared_aaa = shared.find_by_offset(MAX_RESERVED_MESSAGE_ID + 5).unwrap(); + assert_eq!(shared_aaa.name(), proto.cap.name); + + // the 6th shared message is the first message of the eth capability + let shared_eth = shared.find_by_relative_offset(1 + proto.messages).unwrap(); + assert_eq!(shared_eth.name(), "eth"); + } } diff --git a/crates/net/eth-wire/src/lib.rs b/crates/net/eth-wire/src/lib.rs index f379090aa4e5..a1c7c70bdf0f 100644 --- a/crates/net/eth-wire/src/lib.rs +++ b/crates/net/eth-wire/src/lib.rs @@ -20,6 +20,7 @@ mod disconnect; pub mod errors; mod ethstream; mod hello; +pub mod multiplex; mod p2pstream; mod pinger; pub mod protocol; @@ -27,6 +28,9 @@ pub use builder::*; pub mod types; pub use types::*; +#[cfg(test)] +pub mod test_utils; + #[cfg(test)] pub use tokio_util::codec::{ LengthDelimitedCodec as PassthroughCodec, LengthDelimitedCodecError as PassthroughCodecError, diff --git a/crates/net/eth-wire/src/multiplex.rs b/crates/net/eth-wire/src/multiplex.rs new file mode 100644 index 000000000000..d0dcf467e59c --- /dev/null +++ b/crates/net/eth-wire/src/multiplex.rs @@ -0,0 +1,459 @@ +//! Rlpx protocol multiplexer and satellite stream +//! +//! A Satellite is a Stream that primarily drives a single RLPx subprotocol but can also handle +//! additional subprotocols. +//! +//! Most other subprotocols are "dependent satellite" protocols of "eth" rather than fully standalone protocols; "snap" is one example. See also the [snap protocol](https://github.com/ethereum/devp2p/blob/298d7a77c3bf833641579ecbbb5b13f0311eeeea/caps/snap.md?plain=1#L71). +//! Hence the primary protocol is expected to be "eth", and the additional protocols are +//! treated as "dependent satellite" protocols.
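The lookup logic introduced above is easiest to see with concrete numbers: devp2p reserves message IDs `0x00..=0x0f` for the base p2p protocol, and each shared capability then claims a contiguous ID range, in order (alphabetical, as the `aaa`/`eth` test above suggests). Below is a self-contained sketch of the same search that `find_by_offset` performs; it uses toy tuples instead of the crate's `SharedCapability` type, and the capability names and counts are illustrative only:

```rust
/// devp2p keeps message IDs 0x00..=0x0f for itself (hello, disconnect, ping, pong, ...).
const MAX_RESERVED_MESSAGE_ID: u8 = 0x0f;

/// Finds the capability owning `offset`, mirroring `SharedCapabilities::find_by_offset`:
/// take the last entry whose first-message ID is at or below the queried offset.
/// `caps` holds `(name, absolute ID of the capability's first message)`, sorted ascending.
fn find_by_offset<'a>(caps: &[(&'a str, u8)], offset: u8) -> Option<&'a str> {
    let mut found = None;
    for (name, start) in caps {
        if offset >= *start {
            found = Some(*name); // candidate: latest capability starting at or below `offset`
        } else {
            break; // sorted, so no later capability can match
        }
    }
    found
}

fn main() {
    // "aaa" has 5 messages and sorts first, so it occupies 0x10..=0x14
    // and "eth" starts at 0x15 (the layout from `test_find_by_offset_many`).
    let caps = [("aaa", MAX_RESERVED_MESSAGE_ID + 1), ("eth", MAX_RESERVED_MESSAGE_ID + 6)];
    assert_eq!(find_by_offset(&caps, 0x0f), None); // still inside the reserved p2p space
    assert_eq!(find_by_offset(&caps, 0x10), Some("aaa"));
    assert_eq!(find_by_offset(&caps, 0x14), Some("aaa"));
    assert_eq!(find_by_offset(&caps, 0x15), Some("eth"));
}
```

`find_by_relative_offset` is the same search after adding back `MAX_RESERVED_MESSAGE_ID`, and the masking in the multiplexer below is the inverse bookkeeping: add the capability's relative offset to the first byte on send, subtract it on receive.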
+ +use std::{ + collections::VecDeque, + fmt, + future::Future, + io, + pin::Pin, + task::{ready, Context, Poll}, +}; + +use bytes::{Bytes, BytesMut}; +use futures::{pin_mut, Sink, SinkExt, Stream, StreamExt, TryStream, TryStreamExt}; +use tokio::sync::{mpsc, mpsc::UnboundedSender}; +use tokio_stream::wrappers::UnboundedReceiverStream; + +use crate::{ + capability::{Capability, SharedCapabilities, SharedCapability, UnsupportedCapabilityError}, + errors::P2PStreamError, + CanDisconnect, DisconnectReason, P2PStream, +}; + +/// A Stream and Sink type that wraps a raw rlpx stream [P2PStream] and handles message ID +/// multiplexing. +#[derive(Debug)] +pub struct RlpxProtocolMultiplexer { + /// The raw p2p stream + conn: P2PStream, + /// All the subprotocols that are multiplexed on top of the raw p2p stream + protocols: Vec, +} + +impl RlpxProtocolMultiplexer { + /// Wraps the raw p2p stream + pub fn new(conn: P2PStream) -> Self { + Self { conn, protocols: Default::default() } + } + + /// Installs a new protocol on top of the raw p2p stream + pub fn install_protocol( + &mut self, + _cap: Capability, + _st: S, + ) -> Result<(), UnsupportedCapabilityError> { + todo!() + } + + /// Returns the [SharedCapabilities] of the underlying raw p2p stream + pub fn shared_capabilities(&self) -> &SharedCapabilities { + self.conn.shared_capabilities() + } + + /// Converts this multiplexer into a [RlpxSatelliteStream] with the given primary protocol. + /// + /// Returns an error if the primary protocol is not supported by the remote or the handshake + /// failed. + pub async fn into_satellite_stream_with_handshake( + mut self, + cap: &Capability, + handshake: F, + ) -> Result, Self> + where + F: FnOnce(ProtocolProxy) -> Fut, + Fut: Future>, + St: Stream> + Sink + Unpin, + { + let Ok(shared_cap) = self.shared_capabilities().ensure_matching_capability(cap).cloned() + else { + return Err(self) + }; + + let (to_primary, from_wire) = mpsc::unbounded_channel(); + let (to_wire, mut from_primary) = mpsc::unbounded_channel(); + let proxy = ProtocolProxy { + cap: shared_cap.clone(), + from_wire: UnboundedReceiverStream::new(from_wire), + to_wire, + }; + + let f = handshake(proxy); + pin_mut!(f); + + // handle messages until the handshake is complete + loop { + // TODO error handling + tokio::select! { + Some(Ok(msg)) = self.conn.next() => { + // TODO handle multiplex + let _ = to_primary.send(msg); + } + Some(msg) = from_primary.recv() => { + // TODO error handling + self.conn.send(msg).await.unwrap(); + } + res = &mut f => { + let Ok(primary) = res else { return Err(self) }; + return Ok(RlpxSatelliteStream { + conn: self.conn, + to_primary, + from_primary: UnboundedReceiverStream::new(from_primary), + primary, + primary_capability: shared_cap, + satellites: self.protocols, + out_buffer: Default::default(), + }) + } + } + } + } +} + +/// A Stream and Sink type that acts as a wrapper around a primary RLPx subprotocol (e.g. 
"eth") +#[derive(Debug)] +pub struct ProtocolProxy { + cap: SharedCapability, + from_wire: UnboundedReceiverStream, + to_wire: UnboundedSender, +} + +impl ProtocolProxy { + fn mask_msg_id(&self, msg: Bytes) -> Bytes { + // TODO handle empty messages + let mut masked_bytes = BytesMut::zeroed(msg.len()); + masked_bytes[0] = msg[0] + self.cap.relative_message_id_offset(); + masked_bytes[1..].copy_from_slice(&msg[1..]); + masked_bytes.freeze() + } + + fn unmask_id(&self, mut msg: BytesMut) -> BytesMut { + // TODO handle empty messages + msg[0] -= self.cap.relative_message_id_offset(); + msg + } +} + +impl Stream for ProtocolProxy { + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let msg = ready!(self.from_wire.poll_next_unpin(cx)); + Poll::Ready(msg.map(|msg| Ok(self.get_mut().unmask_id(msg)))) + } +} + +impl Sink for ProtocolProxy { + type Error = io::Error; + + fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn start_send(self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> { + let msg = self.mask_msg_id(item); + self.to_wire.send(msg).map_err(|_| io::ErrorKind::BrokenPipe.into()) + } + + fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } +} + +#[async_trait::async_trait] +impl CanDisconnect for ProtocolProxy { + async fn disconnect( + &mut self, + _reason: DisconnectReason, + ) -> Result<(), >::Error> { + // TODO handle disconnects + Ok(()) + } +} + +/// A connection channel to receive messages for the negotiated protocol. +/// +/// This is a [Stream] that returns raw bytes of the received messages for this protocol. +#[derive(Debug)] +pub struct ProtocolConnection { + from_wire: UnboundedReceiverStream, +} + +impl Stream for ProtocolConnection { + type Item = BytesMut; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.from_wire.poll_next_unpin(cx) + } +} + +/// A Stream and Sink type that acts as a wrapper around a primary RLPx subprotocol (e.g. "eth") +/// [EthStream](crate::EthStream) and can also handle additional subprotocols. 
+#[derive(Debug)] +pub struct RlpxSatelliteStream { + /// The raw p2p stream + conn: P2PStream, + to_primary: UnboundedSender, + from_primary: UnboundedReceiverStream, + primary: Primary, + primary_capability: SharedCapability, + satellites: Vec, + out_buffer: VecDeque, +} + +impl RlpxSatelliteStream {} + +impl Stream for RlpxSatelliteStream +where + St: Stream> + Sink + Unpin, + Primary: TryStream + Unpin, + P2PStreamError: Into, +{ + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.get_mut(); + + loop { + // first drain the primary stream + if let Poll::Ready(Some(msg)) = this.primary.try_poll_next_unpin(cx) { + return Poll::Ready(Some(msg)) + } + + let mut out_ready = true; + loop { + match this.conn.poll_ready_unpin(cx) { + Poll::Ready(_) => { + if let Some(msg) = this.out_buffer.pop_front() { + if let Err(err) = this.conn.start_send_unpin(msg) { + return Poll::Ready(Some(Err(err.into()))) + } + } else { + break; + } + } + Poll::Pending => { + out_ready = false; + break + } + } + } + + // advance primary out + loop { + match this.from_primary.poll_next_unpin(cx) { + Poll::Ready(Some(msg)) => { + this.out_buffer.push_back(msg); + } + Poll::Ready(None) => { + // primary closed + return Poll::Ready(None) + } + Poll::Pending => break, + } + } + + // advance all satellites + for idx in (0..this.satellites.len()).rev() { + let mut proto = this.satellites.swap_remove(idx); + loop { + match proto.poll_next_unpin(cx) { + Poll::Ready(Some(msg)) => { + this.out_buffer.push_back(msg); + } + Poll::Ready(None) => return Poll::Ready(None), + Poll::Pending => { + this.satellites.push(proto); + break + } + } + } + } + + let mut delegated = false; + loop { + // pull messages from connection + match this.conn.poll_next_unpin(cx) { + Poll::Ready(Some(Ok(msg))) => { + delegated = true; + let offset = msg[0]; + // find the protocol that matches the offset + // TODO optimize this by keeping a better index + let mut lowest_satellite = None; + // find the protocol with the lowest offset that is greater than the message + // offset + for (i, proto) in this.satellites.iter().enumerate() { + let proto_offset = proto.cap.relative_message_id_offset(); + if proto_offset >= offset { + if let Some((_, lowest_offset)) = lowest_satellite { + if proto_offset < lowest_offset { + lowest_satellite = Some((i, proto_offset)); + } + } else { + lowest_satellite = Some((i, proto_offset)); + } + } + } + + if let Some((idx, lowest_offset)) = lowest_satellite { + if lowest_offset < this.primary_capability.relative_message_id_offset() + { + // delegate to satellite + this.satellites[idx].send_raw(msg); + continue + } + } + // delegate to primary + let _ = this.to_primary.send(msg); + } + Poll::Ready(Some(Err(err))) => return Poll::Ready(Some(Err(err.into()))), + Poll::Ready(None) => { + // connection closed + return Poll::Ready(None) + } + Poll::Pending => break, + } + } + + if !delegated || !out_ready || this.out_buffer.is_empty() { + return Poll::Pending + } + } + } +} + +impl Sink for RlpxSatelliteStream +where + St: Stream> + Sink + Unpin, + Primary: Sink + Unpin, + P2PStreamError: Into<>::Error>, +{ + type Error = >::Error; + + fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.get_mut(); + if let Err(err) = ready!(this.conn.poll_ready_unpin(cx)) { + return Poll::Ready(Err(err.into())) + } + if let Err(err) = ready!(this.primary.poll_ready_unpin(cx)) { + return Poll::Ready(Err(err)) + } + Poll::Ready(Ok(())) + } + + fn 
start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> { + self.get_mut().primary.start_send_unpin(item) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.get_mut().conn.poll_flush_unpin(cx).map_err(Into::into) + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.get_mut().conn.poll_close_unpin(cx).map_err(Into::into) + } +} + +/// Wraps a RLPx subprotocol and handles message ID multiplexing. +struct ProtocolStream { + cap: SharedCapability, + /// the channel shared with the satellite stream + to_satellite: UnboundedSender, + satellite_st: Pin>>, +} + +impl ProtocolStream { + fn mask_msg_id(&self, mut msg: BytesMut) -> Bytes { + // TODO handle empty messages + msg[0] += self.cap.relative_message_id_offset(); + msg.freeze() + } + + fn unmask_id(&self, mut msg: BytesMut) -> BytesMut { + // TODO handle empty messages + msg[0] -= self.cap.relative_message_id_offset(); + msg + } + + fn send_raw(&self, msg: BytesMut) { + let _ = self.to_satellite.send(self.unmask_id(msg)); + } +} + +impl Stream for ProtocolStream { + type Item = Bytes; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.get_mut(); + let msg = ready!(this.satellite_st.as_mut().poll_next(cx)); + Poll::Ready(msg.map(|msg| this.mask_msg_id(msg))) + } +} + +impl fmt::Debug for ProtocolStream { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("ProtocolStream").field("cap", &self.cap).finish_non_exhaustive() + } +} + +#[cfg(test)] +mod tests { + use tokio::net::TcpListener; + use tokio_util::codec::Decoder; + + use crate::{ + test_utils::{connect_passthrough, eth_handshake, eth_hello}, + UnauthedEthStream, UnauthedP2PStream, + }; + + use super::*; + + #[tokio::test] + async fn eth_satellite() { + reth_tracing::init_test_tracing(); + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let local_addr = listener.local_addr().unwrap(); + let (status, fork_filter) = eth_handshake(); + let other_status = status; + let other_fork_filter = fork_filter.clone(); + let _handle = tokio::spawn(async move { + let (incoming, _) = listener.accept().await.unwrap(); + let stream = crate::PassthroughCodec::default().framed(incoming); + let (server_hello, _) = eth_hello(); + let (p2p_stream, _) = + UnauthedP2PStream::new(stream).handshake(server_hello).await.unwrap(); + + let (_eth_stream, _) = UnauthedEthStream::new(p2p_stream) + .handshake(other_status, other_fork_filter) + .await + .unwrap(); + + tokio::time::sleep(std::time::Duration::from_millis(100)).await; + }); + + let conn = connect_passthrough(local_addr, eth_hello().0).await; + let eth = conn.shared_capabilities().eth().unwrap().clone(); + + let multiplexer = RlpxProtocolMultiplexer::new(conn); + + let _satellite = multiplexer + .into_satellite_stream_with_handshake( + eth.capability().as_ref(), + move |proxy| async move { + UnauthedEthStream::new(proxy).handshake(status, fork_filter).await + }, + ) + .await + .unwrap(); + } +} diff --git a/crates/net/eth-wire/src/p2pstream.rs b/crates/net/eth-wire/src/p2pstream.rs index a0f5c9f48d51..ed6001fb934a 100644 --- a/crates/net/eth-wire/src/p2pstream.rs +++ b/crates/net/eth-wire/src/p2pstream.rs @@ -1,19 +1,5 @@ #![allow(dead_code, unreachable_pub, missing_docs, unused_variables)] -use crate::{ - disconnect::CanDisconnect, - errors::{P2PHandshakeError, P2PStreamError}, - pinger::{Pinger, PingerEvent}, - DisconnectReason, HelloMessage, HelloMessageWithProtocols, -}; -use alloy_rlp::{Decodable, 
Encodable, Error as RlpError, EMPTY_LIST_CODE}; -use futures::{Sink, SinkExt, StreamExt}; -use pin_project::pin_project; -use reth_codecs::derive_arbitrary; -use reth_metrics::metrics::counter; -use reth_primitives::{ - bytes::{Buf, BufMut, Bytes, BytesMut}, - hex, GotExpected, -}; + use std::{ collections::VecDeque, fmt, io, @@ -21,13 +7,30 @@ use std::{ task::{ready, Context, Poll}, time::Duration, }; -use tokio_stream::Stream; -use crate::capability::SharedCapabilities; +use alloy_rlp::{Decodable, Encodable, Error as RlpError, EMPTY_LIST_CODE}; +use futures::{Sink, SinkExt, StreamExt}; +use pin_project::pin_project; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; +use tokio_stream::Stream; use tracing::{debug, trace}; +use reth_codecs::derive_arbitrary; +use reth_metrics::metrics::counter; +use reth_primitives::{ + bytes::{Buf, BufMut, Bytes, BytesMut}, + hex, GotExpected, +}; + +use crate::{ + capability::SharedCapabilities, + disconnect::CanDisconnect, + errors::{P2PHandshakeError, P2PStreamError}, + pinger::{Pinger, PingerEvent}, + DisconnectReason, HelloMessage, HelloMessageWithProtocols, +}; + /// [`MAX_PAYLOAD_SIZE`] is the maximum size of an uncompressed message payload. /// This is defined in [EIP-706](https://eips.ethereum.org/EIPS/eip-706). const MAX_PAYLOAD_SIZE: usize = 16 * 1024 * 1024; @@ -785,27 +788,12 @@ impl Decodable for ProtocolVersion { #[cfg(test)] mod tests { use super::*; - use crate::{capability::SharedCapability, DisconnectReason, EthVersion}; - use reth_discv4::DEFAULT_DISCOVERY_PORT; - use reth_ecies::util::pk2id; - use secp256k1::{SecretKey, SECP256K1}; + use crate::{ + capability::SharedCapability, test_utils::eth_hello, DisconnectReason, EthVersion, + }; use tokio::net::{TcpListener, TcpStream}; use tokio_util::codec::Decoder; - /// Returns a testing `HelloMessage` and new secretkey - fn eth_hello() -> (HelloMessageWithProtocols, SecretKey) { - let server_key = SecretKey::new(&mut rand::thread_rng()); - let protocols = vec![EthVersion::Eth67.into()]; - let hello = HelloMessageWithProtocols { - protocol_version: ProtocolVersion::V5, - client_version: "bitcoind/1.0.0".to_string(), - protocols, - port: DEFAULT_DISCOVERY_PORT, - id: pk2id(&server_key.public_key(SECP256K1)), - }; - (hello, server_key) - } - #[tokio::test] async fn test_can_disconnect() { reth_tracing::init_test_tracing(); diff --git a/crates/net/eth-wire/src/test_utils.rs b/crates/net/eth-wire/src/test_utils.rs new file mode 100644 index 000000000000..01bd9a048dc3 --- /dev/null +++ b/crates/net/eth-wire/src/test_utils.rs @@ -0,0 +1,57 @@ +//! Utilities for testing p2p protocol. 
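Since `test_utils` is gated behind `#[cfg(test)]` (see the lib.rs hunk earlier in this diff), these helpers are consumed from the crate's own tests. A hedged sketch of the intended flow, modeled on the `eth_satellite` test above; the test name is made up for illustration:

```rust
use crate::test_utils::{connect_passthrough, eth_hello};
use crate::{PassthroughCodec, UnauthedP2PStream};
use tokio::net::TcpListener;
use tokio_util::codec::Decoder;

/// Hypothetical in-crate test: run the p2p hello handshake over a plain
/// length-delimited TCP stream (no ECIES) and inspect what was negotiated.
#[tokio::test]
async fn passthrough_handshake_negotiates_eth() {
    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
    let local_addr = listener.local_addr().unwrap();

    tokio::spawn(async move {
        let (incoming, _) = listener.accept().await.unwrap();
        let stream = PassthroughCodec::default().framed(incoming);
        let (server_hello, _) = eth_hello();
        // Server half of the hello exchange; keep the socket alive briefly so
        // the dialer can finish its side.
        let _p2p = UnauthedP2PStream::new(stream).handshake(server_hello).await.unwrap();
        tokio::time::sleep(std::time::Duration::from_millis(100)).await;
    });

    let conn = connect_passthrough(local_addr, eth_hello().0).await;
    assert!(conn.shared_capabilities().eth_version().is_ok());
}
```

Note that `connect_passthrough` unwraps internally, which keeps test call sites terse at the cost of being unusable outside tests.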
+ +use crate::{ + EthVersion, HelloMessageWithProtocols, P2PStream, ProtocolVersion, Status, UnauthedP2PStream, +}; +use reth_discv4::DEFAULT_DISCOVERY_PORT; +use reth_ecies::util::pk2id; +use reth_primitives::{Chain, ForkFilter, Head, B256, U256}; +use secp256k1::{SecretKey, SECP256K1}; +use std::net::SocketAddr; +use tokio::net::TcpStream; +use tokio_util::codec::{Decoder, Framed, LengthDelimitedCodec}; + +pub type P2pPassthroughTcpStream = P2PStream<Framed<TcpStream, LengthDelimitedCodec>>; + +/// Returns a new testing `HelloMessage` and a new secret key. +pub fn eth_hello() -> (HelloMessageWithProtocols, SecretKey) { + let server_key = SecretKey::new(&mut rand::thread_rng()); + let protocols = vec![EthVersion::Eth67.into()]; + let hello = HelloMessageWithProtocols { + protocol_version: ProtocolVersion::V5, + client_version: "eth/1.0.0".to_string(), + protocols, + port: DEFAULT_DISCOVERY_PORT, + id: pk2id(&server_key.public_key(SECP256K1)), + }; + (hello, server_key) +} + +/// Returns testing eth handshake status and fork filter. +pub fn eth_handshake() -> (Status, ForkFilter) { + let genesis = B256::random(); + let fork_filter = ForkFilter::new(Head::default(), genesis, 0, Vec::new()); + + let status = Status { + version: EthVersion::Eth67 as u8, + chain: Chain::mainnet(), + total_difficulty: U256::ZERO, + blockhash: B256::random(), + genesis, + // Pass the current fork id. + forkid: fork_filter.current(), + }; + (status, fork_filter) +} + +/// Connects to a remote node and returns an authenticated `P2PStream` established with it. +pub async fn connect_passthrough( + addr: SocketAddr, + client_hello: HelloMessageWithProtocols, +) -> P2pPassthroughTcpStream { + let outgoing = TcpStream::connect(addr).await.unwrap(); + let sink = crate::PassthroughCodec::default().framed(outgoing); + let (p2p_stream, _) = UnauthedP2PStream::new(sink).handshake(client_hello).await.unwrap(); + + p2p_stream +} diff --git a/crates/net/network/src/builder.rs b/crates/net/network/src/builder.rs index efe17ec5a4c8..05c84b7da9b1 100644 --- a/crates/net/network/src/builder.rs +++ b/crates/net/network/src/builder.rs @@ -28,6 +28,21 @@ impl<C, Tx, Eth> NetworkBuilder<C, Tx, Eth> { (network, transactions, request_handler) } + /// Returns the network manager. + pub fn network(&self) -> &NetworkManager<C> { + &self.network + } + + /// Returns the mutable network manager. + pub fn network_mut(&mut self) -> &mut NetworkManager<C> { + &mut self.network + } + + /// Returns the handle to the network. + pub fn handle(&self) -> NetworkHandle { + self.network.handle().clone() + } + /// Consumes the type and returns all fields and also returns a [`NetworkHandle`].
pub fn split_with_handle(self) -> (NetworkHandle, NetworkManager, Tx, Eth) { let NetworkBuilder { network, transactions, request_handler } = self; diff --git a/crates/net/network/src/lib.rs b/crates/net/network/src/lib.rs index 3975918e999d..eec19ee1c127 100644 --- a/crates/net/network/src/lib.rs +++ b/crates/net/network/src/lib.rs @@ -141,7 +141,7 @@ pub use discovery::{Discovery, DiscoveryEvent}; pub use fetch::FetchClient; pub use manager::{NetworkEvent, NetworkManager}; pub use message::PeerRequest; -pub use network::{NetworkEvents, NetworkHandle}; +pub use network::{NetworkEvents, NetworkHandle, NetworkProtocols}; pub use peers::PeersConfig; pub use session::{ ActiveSessionHandle, ActiveSessionMessage, Direction, PeerInfo, PendingSessionEvent, diff --git a/crates/net/network/src/manager.rs b/crates/net/network/src/manager.rs index 7df8addb2101..bd96ba28072c 100644 --- a/crates/net/network/src/manager.rs +++ b/crates/net/network/src/manager.rs @@ -26,6 +26,7 @@ use crate::{ metrics::{DisconnectMetrics, NetworkMetrics, NETWORK_POOL_TRANSACTIONS_SCOPE}, network::{NetworkHandle, NetworkHandleMessage}, peers::{PeersHandle, PeersManager}, + protocol::IntoRlpxSubProtocol, session::SessionManager, state::NetworkState, swarm::{NetworkConnectionState, Swarm, SwarmEvent}, @@ -142,6 +143,11 @@ impl NetworkManager { self.to_eth_request_handler = Some(tx); } + /// Adds an additional protocol handler to the RLPx sub-protocol list. + pub fn add_rlpx_sub_protocol(&mut self, protocol: impl IntoRlpxSubProtocol) { + self.swarm.add_rlpx_sub_protocol(protocol) + } + /// Returns the [`NetworkHandle`] that can be cloned and shared. /// /// The [`NetworkHandle`] can be used to interact with this [`NetworkManager`] @@ -598,6 +604,7 @@ where let peers = self.swarm.state().peers().peers_by_kind(kind); let _ = tx.send(self.swarm.sessions().get_peer_infos_by_ids(peers)); } + NetworkHandleMessage::AddRlpxSubProtocol(proto) => self.add_rlpx_sub_protocol(proto), } } } diff --git a/crates/net/network/src/network.rs b/crates/net/network/src/network.rs index 59b05312c9ca..3448a788cbcd 100644 --- a/crates/net/network/src/network.rs +++ b/crates/net/network/src/network.rs @@ -1,6 +1,6 @@ use crate::{ config::NetworkMode, discovery::DiscoveryEvent, manager::NetworkEvent, message::PeerRequest, - peers::PeersHandle, FetchClient, + peers::PeersHandle, protocol::RlpxSubProtocol, FetchClient, }; use async_trait::async_trait; use parking_lot::Mutex; @@ -155,6 +155,8 @@ impl NetworkHandle { } } +// === API Implementations === + impl NetworkEvents for NetworkHandle { fn event_listener(&self) -> UnboundedReceiverStream { let (tx, rx) = mpsc::unbounded_channel(); @@ -169,7 +171,11 @@ impl NetworkEvents for NetworkHandle { } } -// === API Implementations === +impl NetworkProtocols for NetworkHandle { + fn add_rlpx_sub_protocol(&self, protocol: RlpxSubProtocol) { + self.send_message(NetworkHandleMessage::AddRlpxSubProtocol(protocol)) + } +} impl PeersInfo for NetworkHandle { fn num_connected_peers(&self) -> usize { @@ -353,6 +359,12 @@ pub trait NetworkEvents: Send + Sync { fn discovery_listener(&self) -> UnboundedReceiverStream; } +/// Provides access to modify the network's additional protocol handlers. +pub trait NetworkProtocols: Send + Sync { + /// Adds an additional protocol handler to the RLPx sub-protocol list. + fn add_rlpx_sub_protocol(&self, protocol: RlpxSubProtocol); +} + /// Internal messages that can be passed to the [`NetworkManager`](crate::NetworkManager). 
#[allow(missing_docs)] #[derive(Debug)] @@ -400,4 +412,6 @@ pub(crate) enum NetworkHandleMessage { Shutdown(oneshot::Sender<()>), /// Add a new listener for `DiscoveryEvent`. DiscoveryListener(UnboundedSender), + /// Add an additional [RlpxSubProtocol]. + AddRlpxSubProtocol(RlpxSubProtocol), } diff --git a/crates/net/network/src/protocol.rs b/crates/net/network/src/protocol.rs index 24dd68690422..adcfb75f22db 100644 --- a/crates/net/network/src/protocol.rs +++ b/crates/net/network/src/protocol.rs @@ -2,19 +2,14 @@ //! //! See also -use futures::{Stream, StreamExt}; -use reth_eth_wire::{capability::SharedCapabilities, protocol::Protocol}; +use futures::Stream; +use reth_eth_wire::{ + capability::SharedCapabilities, multiplex::ProtocolConnection, protocol::Protocol, +}; use reth_network_api::Direction; use reth_primitives::BytesMut; use reth_rpc_types::PeerId; -use std::{ - fmt, - net::SocketAddr, - pin::Pin, - task::{Context, Poll}, -}; - -use tokio_stream::wrappers::UnboundedReceiverStream; +use std::{fmt, net::SocketAddr, pin::Pin}; /// A trait that allows to offer additional RLPx-based application-level protocols when establishing /// a peer-to-peer connection. @@ -81,22 +76,6 @@ pub enum OnNotSupported { Disconnect, } -/// A connection channel to receive messages for the negotiated protocol. -/// -/// This is a [Stream] that returns raw bytes of the received messages for this protocol. -#[derive(Debug)] -pub struct ProtocolConnection { - from_wire: UnboundedReceiverStream, -} - -impl Stream for ProtocolConnection { - type Item = BytesMut; - - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.from_wire.poll_next_unpin(cx) - } -} - /// A wrapper type for a RLPx sub-protocol. #[derive(Debug)] pub struct RlpxSubProtocol(Box); @@ -116,6 +95,12 @@ where } } +impl IntoRlpxSubProtocol for RlpxSubProtocol { + fn into_rlpx_sub_protocol(self) -> RlpxSubProtocol { + self + } +} + /// Additional RLPx-based sub-protocols. #[derive(Debug, Default)] pub struct RlpxSubProtocols { diff --git a/crates/net/network/src/session/mod.rs b/crates/net/network/src/session/mod.rs index 0e6fdcd7fbe4..863964ac97ed 100644 --- a/crates/net/network/src/session/mod.rs +++ b/crates/net/network/src/session/mod.rs @@ -48,7 +48,7 @@ pub use handle::{ SessionCommand, }; -use crate::protocol::RlpxSubProtocols; +use crate::protocol::{IntoRlpxSubProtocol, RlpxSubProtocols}; pub use reth_network_api::{Direction, PeerInfo}; /// Internal identifier for active sessions. @@ -103,7 +103,6 @@ pub struct SessionManager { /// Receiver half that listens for [`ActiveSessionMessage`] produced by pending sessions. active_session_rx: ReceiverStream, /// Additional RLPx sub-protocols to be used by the session manager. - #[allow(unused)] extra_protocols: RlpxSubProtocols, /// Used to measure inbound & outbound bandwidth across all managed streams bandwidth_meter: BandwidthMeter, @@ -176,6 +175,11 @@ impl SessionManager { self.hello_message.clone() } + /// Adds an additional protocol handler to the RLPx sub-protocol list. + pub(crate) fn add_rlpx_sub_protocol(&mut self, protocol: impl IntoRlpxSubProtocol) { + self.extra_protocols.push(protocol) + } + /// Spawns the given future onto a new task that is tracked in the `spawned_tasks` /// [`JoinSet`](tokio::task::JoinSet). 
fn spawn(&self, f: F) diff --git a/crates/net/network/src/swarm.rs b/crates/net/network/src/swarm.rs index 9f32efd16852..ce647fe181e8 100644 --- a/crates/net/network/src/swarm.rs +++ b/crates/net/network/src/swarm.rs @@ -2,6 +2,7 @@ use crate::{ listener::{ConnectionListener, ListenerEvent}, message::{PeerMessage, PeerRequestSender}, peers::InboundConnectionError, + protocol::IntoRlpxSubProtocol, session::{Direction, PendingSessionHandshakeError, SessionEvent, SessionId, SessionManager}, state::{NetworkState, StateAction}, }; @@ -76,10 +77,7 @@ pub(crate) struct Swarm { // === impl Swarm === -impl Swarm -where - C: BlockNumReader, -{ +impl Swarm { /// Configures a new swarm instance. pub(crate) fn new( incoming: ConnectionListener, @@ -90,6 +88,11 @@ where Self { incoming, sessions, state, net_connection_state } } + /// Adds an additional protocol handler to the RLPx sub-protocol list. + pub(crate) fn add_rlpx_sub_protocol(&mut self, protocol: impl IntoRlpxSubProtocol) { + self.sessions_mut().add_rlpx_sub_protocol(protocol); + } + /// Access to the state. pub(crate) fn state(&self) -> &NetworkState { &self.state @@ -114,7 +117,12 @@ where pub(crate) fn sessions_mut(&mut self) -> &mut SessionManager { &mut self.sessions } +} +impl Swarm +where + C: BlockNumReader, +{ /// Triggers a new outgoing connection to the given node pub(crate) fn dial_outbound(&mut self, remote_addr: SocketAddr, remote_id: PeerId) { self.sessions.dial_outbound(remote_addr, remote_id) diff --git a/crates/net/network/src/transactions.rs b/crates/net/network/src/transactions.rs index 24f6f8ffb24a..f9ffc138deae 100644 --- a/crates/net/network/src/transactions.rs +++ b/crates/net/network/src/transactions.rs @@ -760,7 +760,7 @@ where } fn report_peer(&self, peer_id: PeerId, kind: ReputationChangeKind) { - trace!(target: "net::tx", ?peer_id, ?kind); + trace!(target: "net::tx", ?peer_id, ?kind, "reporting reputation change"); self.network.reputation_change(peer_id, kind); self.metrics.reported_bad_transactions.increment(1); } @@ -831,11 +831,10 @@ where while let Poll::Ready(fetch_event) = this.transaction_fetcher.poll(cx) { match fetch_event { FetchEvent::TransactionsFetched { peer_id, transactions } => { - if let Some(txns) = transactions { - this.import_transactions(peer_id, txns, TransactionSource::Response); - } + this.import_transactions(peer_id, transactions, TransactionSource::Response); } FetchEvent::FetchError { peer_id, error } => { + trace!(target: "net::tx", ?peer_id, ?error, "requesting transactions from peer failed"); this.on_request_error(peer_id, error); } } @@ -857,7 +856,7 @@ where // known that this transaction is bad. (e.g. 
consensus // rules) if err.is_bad_transaction() && !this.network.is_syncing() { - trace!(target: "net::tx", ?err, "Bad transaction import"); + trace!(target: "net::tx", ?err, "bad pool transaction import"); this.on_bad_import(err.hash); continue } @@ -1008,6 +1007,8 @@ impl TransactionSource { /// An inflight request for `PooledTransactions` from a peer struct GetPooledTxRequest { peer_id: PeerId, + /// Transaction hashes that were requested, for cleanup purposes + requested_hashes: Vec, response: oneshot::Receiver>, } @@ -1026,11 +1027,13 @@ struct GetPooledTxRequestFut { } impl GetPooledTxRequestFut { + #[inline] fn new( peer_id: PeerId, + requested_hashes: Vec, response: oneshot::Receiver>, ) -> Self { - Self { inner: Some(GetPooledTxRequest { peer_id, response }) } + Self { inner: Some(GetPooledTxRequest { peer_id, requested_hashes, response }) } } } @@ -1040,20 +1043,11 @@ impl Future for GetPooledTxRequestFut { fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let mut req = self.as_mut().project().inner.take().expect("polled after completion"); match req.response.poll_unpin(cx) { - Poll::Ready(result) => { - let request_hashes: Vec = match &result { - Ok(Ok(pooled_txs)) => { - pooled_txs.0.iter().map(|tx_elem| *tx_elem.hash()).collect() - } - _ => Vec::new(), - }; - - Poll::Ready(GetPooledTxResponse { - peer_id: req.peer_id, - requested_hashes: request_hashes, - result, - }) - } + Poll::Ready(result) => Poll::Ready(GetPooledTxResponse { + peer_id: req.peer_id, + requested_hashes: req.requested_hashes, + result, + }), Poll::Pending => { self.project().inner.set(Some(req)); Poll::Pending @@ -1108,16 +1102,16 @@ impl TransactionFetcher { self.inflight_requests.poll_next_unpin(cx) { return match result { - Ok(Ok(txs)) => { + Ok(Ok(transactions)) => { // clear received hashes - self.remove_inflight_hashes(txs.hashes()); + self.remove_inflight_hashes(transactions.hashes()); // TODO: re-request missing hashes, for now clear all of them self.remove_inflight_hashes(requested_hashes.iter()); Poll::Ready(FetchEvent::TransactionsFetched { peer_id, - transactions: Some(txs.0), + transactions: transactions.0, }) } Ok(Err(req_err)) => { @@ -1189,7 +1183,7 @@ impl TransactionFetcher { let (response, rx) = oneshot::channel(); let req: PeerRequest = PeerRequest::GetPooledTransactions { - request: GetPooledTransactions(announced_hashes), + request: GetPooledTransactions(announced_hashes.clone()), response, }; @@ -1210,7 +1204,7 @@ impl TransactionFetcher { return false } else { //create a new request for it, from that peer - self.inflight_requests.push(GetPooledTxRequestFut::new(peer_id, rx)) + self.inflight_requests.push(GetPooledTxRequestFut::new(peer_id, announced_hashes, rx)) } true @@ -1225,7 +1219,7 @@ enum FetchEvent { /// The ID of the peer from which transactions were fetched. peer_id: PeerId, /// The transactions that were fetched, if available. - transactions: Option>, + transactions: Vec, }, /// Triggered when there is an error in fetching transactions. FetchError { diff --git a/crates/primitives/Cargo.toml b/crates/primitives/Cargo.toml index 4a9cfcedc4ab..8204636cec5b 100644 --- a/crates/primitives/Cargo.toml +++ b/crates/primitives/Cargo.toml @@ -12,7 +12,8 @@ description = "Commonly used types in reth." 
diff --git a/crates/primitives/Cargo.toml b/crates/primitives/Cargo.toml index 4a9cfcedc4ab..8204636cec5b 100644 --- a/crates/primitives/Cargo.toml +++ b/crates/primitives/Cargo.toml @@ -12,7 +12,8 @@ description = "Commonly used types in reth." # reth reth-codecs.workspace = true reth-rpc-types.workspace = true -revm-primitives = { workspace = true, features = ["serde"] } +revm-primitives.workspace = true +revm.workspace = true # ethereum alloy-primitives = { workspace = true, features = ["rand", "rlp"] } @@ -61,8 +62,6 @@ proptest = { workspace = true, optional = true } proptest-derive = { workspace = true, optional = true } strum = { workspace = true, features = ["derive"] } -revm.workspace = true - [dev-dependencies] serde_json.workspace = true test-fuzz = "4" @@ -87,10 +86,10 @@ pprof = { workspace = true, features = ["flamegraph", "frame-pointer", "criterion"] } [features] default = ["c-kzg"] arbitrary = ["revm-primitives/arbitrary", "reth-rpc-types/arbitrary", "dep:arbitrary", "dep:proptest", "dep:proptest-derive"] -c-kzg = ["revm-primitives/c-kzg", "dep:c-kzg"] -test-utils = ["dep:plain_hasher", "dep:hash-db", "dep:ethers-core"] +c-kzg = ["dep:c-kzg", "revm/c-kzg", "revm-primitives/c-kzg"] clap = ["dep:clap"] optimism = ["reth-codecs/optimism", "revm-primitives/optimism", "revm/optimism"] +test-utils = ["dep:plain_hasher", "dep:hash-db", "dep:ethers-core"] [[bench]] name = "recover_ecdsa_crit" diff --git a/crates/primitives/src/account.rs b/crates/primitives/src/account.rs index ab6761e41534..94d245828697 100644 --- a/crates/primitives/src/account.rs +++ b/crates/primitives/src/account.rs @@ -98,10 +98,7 @@ impl Compact for Bytecode { len + self.0.bytecode.len() + 4 } - fn from_compact(mut buf: &[u8], _: usize) -> (Self, &[u8]) - where - Self: Sized, - { + fn from_compact(mut buf: &[u8], _: usize) -> (Self, &[u8]) { let len = buf.read_u32::<BigEndian>().expect("could not read bytecode length"); let bytes = Bytes::from(buf.copy_to_bytes(len as usize)); let variant = buf.read_u8().expect("could not read bytecode variant");
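The `where Self: Sized` bounds removed here (and in the trie and storage impls later in this diff) are vacuous on concrete impls, since `from_compact` already returns `Self` by value. For orientation, the round-trip contract these impls uphold, as a hedged sketch (buffer handling simplified; `bytecode` is any `Bytecode` value):

    let mut buf = Vec::new(); // `Vec<u8>` implements `bytes::BufMut`
    let len = bytecode.clone().to_compact(&mut buf); // returns the encoded length
    let (decoded, _rest) = Bytecode::from_compact(&buf, len);
    assert_eq!(decoded, bytecode);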
diff --git a/crates/primitives/src/block.rs b/crates/primitives/src/block.rs index a432cdbc2089..54e0085f531b 100644 --- a/crates/primitives/src/block.rs +++ b/crates/primitives/src/block.rs @@ -98,7 +98,17 @@ impl BlockWithSenders { (block.body.len() == senders.len()).then_some(Self { block, senders }) } + /// Seal the block with a known hash. + /// + /// WARNING: This method does not perform validation whether the hash is correct. + #[inline] + pub fn seal(self, hash: B256) -> SealedBlockWithSenders { + let Self { block, senders } = self; + SealedBlockWithSenders { block: block.seal(hash), senders } + } + /// Split Structure to its components + #[inline] pub fn into_components(self) -> (Block, Vec<Address>) { (self.block, self.senders) } @@ -288,6 +298,13 @@ impl SealedBlockWithSenders { (self.block, self.senders) } + /// Returns the unsealed [BlockWithSenders] + #[inline] + pub fn unseal(self) -> BlockWithSenders { + let Self { block, senders } = self; + BlockWithSenders { block: block.unseal(), senders } + } + /// Returns an iterator over all transactions in the block. #[inline] pub fn transactions(&self) -> impl Iterator<Item = &TransactionSigned> + '_ { diff --git a/crates/primitives/src/genesis.rs b/crates/primitives/src/genesis.rs index dd1f4a0ffd64..02b4700fe8d9 100644 --- a/crates/primitives/src/genesis.rs +++ b/crates/primitives/src/genesis.rs @@ -2,8 +2,9 @@ use crate::{ constants::EMPTY_ROOT_HASH, keccak256, serde_helper::{ - deserialize_json_ttd_opt, deserialize_json_u256, deserialize_storage_map, + json_u256::{deserialize_json_ttd_opt, deserialize_json_u256}, num::{u64_hex_or_decimal, u64_hex_or_decimal_opt}, + storage::deserialize_storage_map, }, trie::{HashBuilder, Nibbles}, Account, Address, Bytes, B256, KECCAK_EMPTY, U256, diff --git a/crates/primitives/src/revm/env.rs b/crates/primitives/src/revm/env.rs index 2f3d8005c572..26b87a159b40 100644 --- a/crates/primitives/src/revm/env.rs +++ b/crates/primitives/src/revm/env.rs @@ -309,6 +309,7 @@ pub fn fill_tx_env( } #[cfg(feature = "optimism")] Transaction::Deposit(tx) => { + tx_env.access_list.clear(); tx_env.gas_limit = tx.gas_limit; tx_env.gas_price = U256::ZERO; tx_env.gas_priority_fee = None;
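The single added line in the deposit arm is a reuse-safety fix: one `TxEnv` is typically reused across the transactions of a block, so any field the deposit branch does not set must be reset explicitly. Sketched as a standalone function (illustrative only; the real logic stays inline in the `Transaction::Deposit` arm of `fill_tx_env`):

    fn fill_deposit_env(tx_env: &mut TxEnv, tx: &TxDeposit) {
        // without this, access-list entries from the previously executed
        // transaction would leak into the deposit's environment
        tx_env.access_list.clear();
        tx_env.gas_limit = tx.gas_limit;
        tx_env.gas_price = U256::ZERO;
        tx_env.gas_priority_fee = None;
        // ... remaining deposit fields as in the hunk above
    }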
diff --git a/crates/primitives/src/serde_helper/mod.rs b/crates/primitives/src/serde_helper/mod.rs index 6e78c5a7984f..2e897ebcdab1 100644 --- a/crates/primitives/src/serde_helper/mod.rs +++ b/crates/primitives/src/serde_helper/mod.rs @@ -1,73 +1,6 @@ //! [serde] utilities. -use crate::{B256, U64}; -use serde::{Deserialize, Deserializer, Serialize, Serializer}; - -mod storage; -pub use storage::*; - -pub use reth_rpc_types::json_u256::*; - -pub mod num; +pub use reth_rpc_types::serde_helpers::*; mod prune; pub use prune::deserialize_opt_prune_mode_with_min_blocks; - -/// serde functions for handling primitive `u64` as [`U64`]. -pub mod u64_hex { - use super::*; - - /// Deserializes an `u64` from [U64] accepting a hex quantity string with optional 0x prefix - pub fn deserialize<'de, D>(deserializer: D) -> Result<u64, D::Error> - where - D: Deserializer<'de>, - { - U64::deserialize(deserializer).map(|val| val.to()) - } - - /// Serializes u64 as hex string - pub fn serialize<S: Serializer>(value: &u64, s: S) -> Result<S::Ok, S::Error> { - U64::from(*value).serialize(s) - } -} - -/// Serialize a byte vec as a hex string _without_ the "0x" prefix. -/// -/// This behaves the same as [`hex::encode`](crate::hex::encode). -pub fn serialize_hex_string_no_prefix<S, T>(x: T, s: S) -> Result<S::Ok, S::Error> -where - S: Serializer, - T: AsRef<[u8]>, -{ - s.serialize_str(&crate::hex::encode(x.as_ref())) -} - -/// Serialize a [B256] as a hex string _without_ the "0x" prefix. -pub fn serialize_b256_hex_string_no_prefix<S>(x: &B256, s: S) -> Result<S::Ok, S::Error> -where - S: Serializer, -{ - s.serialize_str(&format!("{x:x}")) -} - -#[cfg(test)] -mod tests { - use super::*; - use serde::{Deserialize, Serialize}; - - #[test] - fn test_hex_u64() { - #[derive(Debug, Serialize, Deserialize, PartialEq, Eq)] - struct Value { - #[serde(with = "u64_hex")] - inner: u64, - } - - let val = Value { inner: 1000 }; - let s = serde_json::to_string(&val).unwrap(); - assert_eq!(s, "{\"inner\":\"0x3e8\"}"); - - let deserialized: Value = serde_json::from_str(&s).unwrap(); - assert_eq!(val, deserialized); - } -} diff --git a/crates/primitives/src/serde_helper/num.rs b/crates/primitives/src/serde_helper/num.rs deleted file mode 100644 index e2262ccca78b..000000000000 --- a/crates/primitives/src/serde_helper/num.rs +++ /dev/null @@ -1,216 +0,0 @@ -//! Numeric helpers - -use crate::{U256, U64}; -use serde::{de, Deserialize, Deserializer, Serialize}; -use std::str::FromStr; - -/// A `u64` wrapper type that deserializes from hex or a u64 and serializes as hex. -/// -/// -/// ```rust -/// use reth_primitives::serde_helper::num::U64HexOrNumber; -/// let number_json = "100"; -/// let hex_json = "\"0x64\""; -/// -/// let number: U64HexOrNumber = serde_json::from_str(number_json).unwrap(); -/// let hex: U64HexOrNumber = serde_json::from_str(hex_json).unwrap(); -/// assert_eq!(number, hex); -/// assert_eq!(hex.to(), 100); -/// ``` -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize)] -pub struct U64HexOrNumber(U64); - -impl U64HexOrNumber { - /// Returns the wrapped u64 - pub fn to(self) -> u64 { - self.0.to() - } -} - -impl From<u64> for U64HexOrNumber { - fn from(value: u64) -> Self { - Self(U64::from(value)) - } -} - -impl From<U64> for U64HexOrNumber { - fn from(value: U64) -> Self { - Self(value) - } -} - -impl From<U64HexOrNumber> for u64 { - fn from(value: U64HexOrNumber) -> Self { - value.to() - } -} - -impl From<U64HexOrNumber> for U64 { - fn from(value: U64HexOrNumber) -> Self { - value.0 - } -} - -impl<'de> Deserialize<'de> for U64HexOrNumber { - fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> - where - D: Deserializer<'de>, - { - #[derive(Deserialize)] - #[serde(untagged)] - enum NumberOrHexU64 { - Hex(U64), - Int(u64), - } - match NumberOrHexU64::deserialize(deserializer)? { - NumberOrHexU64::Int(val) => Ok(val.into()), - NumberOrHexU64::Hex(val) => Ok(val.into()), - } - } -} - -/// serde functions for handling primitive `u64` as [U64] -pub mod u64_hex_or_decimal { - use crate::serde_helper::num::U64HexOrNumber; - use serde::{Deserialize, Deserializer, Serialize, Serializer}; - - /// Deserializes an `u64` accepting a hex quantity string with optional 0x prefix or - /// a number - pub fn deserialize<'de, D>(deserializer: D) -> Result<u64, D::Error> - where - D: Deserializer<'de>, - { - U64HexOrNumber::deserialize(deserializer).map(Into::into) - } - - /// Serializes u64 as hex string - pub fn serialize<S: Serializer>(value: &u64, s: S) -> Result<S::Ok, S::Error> { - U64HexOrNumber::from(*value).serialize(s) - } -}
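These modules are consumed through serde's `with` attribute, a pattern the re-export to `reth_rpc_types::serde_helpers` preserves unchanged. A small illustrative struct (name and field are made up for the example):

    #[derive(Serialize, Deserialize)]
    struct Example {
        // accepts 1000 or "0x3e8" on input; always serializes as "0x3e8"
        #[serde(with = "u64_hex_or_decimal")]
        gas_used: u64,
    }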
-/// serde functions for handling primitive optional `u64` as [U64] -pub mod u64_hex_or_decimal_opt { - use crate::serde_helper::num::U64HexOrNumber; - use serde::{Deserialize, Deserializer, Serialize, Serializer}; - - /// Deserializes an `u64` accepting a hex quantity string with optional 0x prefix or - /// a number - pub fn deserialize<'de, D>(deserializer: D) -> Result<Option<u64>, D::Error> - where - D: Deserializer<'de>, - { - match Option::<U64HexOrNumber>::deserialize(deserializer)? { - Some(val) => Ok(Some(val.into())), - None => Ok(None), - } - } - - /// Serializes u64 as hex string - pub fn serialize<S: Serializer>(value: &Option<u64>, s: S) -> Result<S::Ok, S::Error> { - match value { - Some(val) => U64HexOrNumber::from(*val).serialize(s), - None => s.serialize_none(), - } - } -} - -/// Deserializes the input into an `Option<U256>`, using [`from_int_or_hex`] to deserialize the -/// inner value. -pub fn from_int_or_hex_opt<'de, D>(deserializer: D) -> Result<Option<U256>, D::Error> -where - D: Deserializer<'de>, -{ - match Option::<NumberOrHexU256>::deserialize(deserializer)? { - Some(val) => val.try_into_u256().map(Some), - None => Ok(None), - } -} - -/// Deserializes the input into a U256, accepting both 0x-prefixed hex and decimal strings with -/// arbitrary precision, defined by serde_json's [`Number`](serde_json::Number). -pub fn from_int_or_hex<'de, D>(deserializer: D) -> Result<U256, D::Error> -where - D: Deserializer<'de>, -{ - NumberOrHexU256::deserialize(deserializer)?.try_into_u256() -} - -#[derive(Deserialize)] -#[serde(untagged)] -enum NumberOrHexU256 { - Int(serde_json::Number), - Hex(U256), -} - -impl NumberOrHexU256 { - fn try_into_u256<E: de::Error>(self) -> Result<U256, E> { - match self { - NumberOrHexU256::Int(num) => { - U256::from_str(num.to_string().as_str()).map_err(E::custom) - } - NumberOrHexU256::Hex(val) => Ok(val), - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_u256_int_or_hex() { - #[derive(Debug, Deserialize, PartialEq, Eq)] - struct V(#[serde(deserialize_with = "from_int_or_hex")] U256); - - proptest::proptest!(|(value: u64)| { - let u256_val = U256::from(value); - - let num_obj = serde_json::to_string(&value).unwrap(); - let hex_obj = serde_json::to_string(&u256_val).unwrap(); - - let int_val:V = serde_json::from_str(&num_obj).unwrap(); - let hex_val = serde_json::from_str(&hex_obj).unwrap(); - assert_eq!(int_val, hex_val); - }); - } - - #[test] - fn test_u256_int_or_hex_opt() { - #[derive(Debug, Deserialize, PartialEq, Eq)] - struct V(#[serde(deserialize_with = "from_int_or_hex_opt")] Option<U256>); - - let null = serde_json::to_string(&None::<U256>).unwrap(); - let val: V = serde_json::from_str(&null).unwrap(); - assert!(val.0.is_none()); - - proptest::proptest!(|(value: u64)| { - let u256_val = U256::from(value); - - let num_obj = serde_json::to_string(&value).unwrap(); - let hex_obj = serde_json::to_string(&u256_val).unwrap(); - - let int_val:V = serde_json::from_str(&num_obj).unwrap(); - let hex_val = serde_json::from_str(&hex_obj).unwrap(); - assert_eq!(int_val, hex_val); - assert_eq!(int_val.0, Some(u256_val)); - }); - } - - #[test] - fn serde_hex_or_number_u64() { - #[derive(Debug, Deserialize, PartialEq, Eq)] - struct V(U64HexOrNumber); - - proptest::proptest!(|(value: u64)| { - let val = U64::from(value); - - let num_obj = serde_json::to_string(&value).unwrap(); - let hex_obj = serde_json::to_string(&val).unwrap(); - - let int_val:V = serde_json::from_str(&num_obj).unwrap(); - let hex_val = serde_json::from_str(&hex_obj).unwrap(); - assert_eq!(int_val, hex_val); - }); - } -}
diff --git a/crates/primitives/src/serde_helper/storage.rs b/crates/primitives/src/serde_helper/storage.rs deleted file mode 100644 index 7d0b5045f453..000000000000 --- a/crates/primitives/src/serde_helper/storage.rs +++ /dev/null @@ -1,102 +0,0 @@ -use crate::{Bytes, B256, U256}; -use serde::{Deserialize, Deserializer, Serialize}; -use std::{collections::HashMap, fmt::Write}; - -/// A storage key type that can be serialized to and from a hex string up to 32 bytes. Used for -/// `eth_getStorageAt` and `eth_getProof` RPCs. -/// -/// This is a wrapper type meant to mirror geth's serialization and deserialization behavior for -/// storage keys. -/// -/// In `eth_getStorageAt`, this is used for deserialization of the `index` field. Internally, the -/// index is a [B256], but in `eth_getStorageAt` requests, its serialization can be _up to_ 32 -/// bytes. To support this, the storage key is deserialized first as a U256, and converted to a -/// B256 for use internally. -/// -/// `eth_getProof` also takes storage keys up to 32 bytes as input, so the `keys` field is -/// similarly deserialized. However, geth populates the storage proof `key` fields in the response -/// by mirroring the `key` field used in the input. -/// * See how `storageKey`s (the input) are populated in the `StorageResult` (the output): -/// -/// -/// The contained [B256] and From implementation for String are used to preserve the input and -/// implement this behavior from geth. -#[derive(Clone, Debug, Default, Eq, PartialEq, Serialize, Deserialize)] -#[serde(from = "U256", into = "String")] -pub struct JsonStorageKey(pub B256); - -impl From<U256> for JsonStorageKey { - fn from(value: U256) -> Self { - // SAFETY: Address (B256) and U256 have the same number of bytes - JsonStorageKey(B256::from(value.to_be_bytes())) - } -} - -impl From<JsonStorageKey> for String { - fn from(value: JsonStorageKey) -> Self { - // SAFETY: Address (B256) and U256 have the same number of bytes - let uint = U256::from_be_bytes(value.0 .0); - - // serialize byte by byte - // - // this is mainly so we can return an output that hive testing expects, because the - // `eth_getProof` implementation in geth simply mirrors the input - // - // see the use of `hexKey` in the `eth_getProof` response: - // - let bytes = uint.to_be_bytes_trimmed_vec(); - let mut hex = String::with_capacity(2 + bytes.len() * 2); - hex.push_str("0x"); - for byte in bytes { - write!(hex, "{:02x}", byte).unwrap(); - } - hex - } -}
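Concretely, the geth-mirroring behavior the comments describe: a short input key round-trips as trimmed hex instead of being zero-padded to 32 bytes. A hedged round-trip sketch:

    let key: JsonStorageKey = serde_json::from_str("\"0x1\"").unwrap();
    // parsed via U256 and held internally as a full 32-byte B256, but
    // serialized back in the trimmed, byte-wise form geth would echo:
    assert_eq!(String::from(key), "0x01");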
-/// Converts a Bytes value into a B256, accepting inputs that are less than 32 bytes long. These -/// inputs will be left padded with zeros. -pub fn from_bytes_to_b256<'de, D>(bytes: Bytes) -> Result<B256, D::Error> -where - D: Deserializer<'de>, -{ - if bytes.0.len() > 32 { - return Err(serde::de::Error::custom("input too long to be a B256")) - } - - // left pad with zeros to 32 bytes - let mut padded = [0u8; 32]; - padded[32 - bytes.0.len()..].copy_from_slice(&bytes.0); - - // then convert to B256 without a panic - Ok(B256::from_slice(&padded)) -} - -/// Deserializes the input into an Option<HashMap<B256, B256>>, using [from_bytes_to_b256] which -/// allows cropped values: -/// -/// ```json -/// { -/// "0x0000000000000000000000000000000000000000000000000000000000000001": "0x22" -/// } -/// ``` -pub fn deserialize_storage_map<'de, D>( - deserializer: D, -) -> Result<Option<HashMap<B256, B256>>, D::Error> -where - D: Deserializer<'de>, -{ - let map = Option::<HashMap<Bytes, Bytes>>::deserialize(deserializer)?; - match map { - Some(mut map) => { - let mut res_map = HashMap::with_capacity(map.len()); - for (k, v) in map.drain() { - let k_deserialized = from_bytes_to_b256::<'de, D>(k)?; - let v_deserialized = from_bytes_to_b256::<'de, D>(v)?; - res_map.insert(k_deserialized, v_deserialized); - } - Ok(Some(res_map)) - } - None => Ok(None), - } -}
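For reference, how the storage-map helper is attached in practice (the genesis.rs hunk earlier in this diff shows the real call site; the struct and field below are illustrative):

    #[derive(Deserialize)]
    struct StorageHolder {
        // keys and values shorter than 32 bytes are left-padded by
        // `from_bytes_to_b256`, so {"0x01": "0x22"} is accepted
        #[serde(deserialize_with = "deserialize_storage_map")]
        storage: Option<HashMap<B256, B256>>,
    }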
diff --git a/crates/primitives/src/stage/checkpoints.rs b/crates/primitives/src/stage/checkpoints.rs index 6df66123f5c6..c0cace519cbc 100644 --- a/crates/primitives/src/stage/checkpoints.rs +++ b/crates/primitives/src/stage/checkpoints.rs @@ -4,10 +4,7 @@ use crate::{ }; use bytes::Buf; use reth_codecs::{main_codec, Compact}; -use std::{ - fmt::{Display, Formatter}, - ops::RangeInclusive, -}; +use std::ops::RangeInclusive; /// Saves the progress of Merkle stage. #[derive(Default, Debug, Clone, PartialEq)] @@ -57,10 +54,7 @@ impl Compact for MerkleCheckpoint { len } - fn from_compact(mut buf: &[u8], _len: usize) -> (Self, &[u8]) - where - Self: Sized, - { + fn from_compact(mut buf: &[u8], _len: usize) -> (Self, &[u8]) { let target_block = buf.get_u64(); let last_account_key = B256::from_slice(&buf[..32]); @@ -145,9 +139,16 @@ pub struct EntitiesCheckpoint { pub total: u64, } -impl Display for EntitiesCheckpoint { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "{:.2}%", 100.0 * self.processed as f64 / self.total as f64) +impl EntitiesCheckpoint { + /// Formats entities checkpoint as percentage, i.e. `processed / total`. + /// + /// Return [None] if `total == 0`. + pub fn fmt_percentage(&self) -> Option<String> { + if self.total == 0 { + return None + } + + Some(format!("{:.2}%", 100.0 * self.processed as f64 / self.total as f64)) } } diff --git a/crates/primitives/src/storage.rs b/crates/primitives/src/storage.rs index 91bdce470477..1c9157fbd888 100644 --- a/crates/primitives/src/storage.rs +++ b/crates/primitives/src/storage.rs @@ -40,10 +40,7 @@ impl Compact for StorageEntry { self.value.to_compact(buf) + 32 } - fn from_compact(buf: &[u8], len: usize) -> (Self, &[u8]) - where - Self: Sized, - { + fn from_compact(buf: &[u8], len: usize) -> (Self, &[u8]) { let key = B256::from_slice(&buf[..32]); let (value, out) = U256::from_compact(&buf[32..], len - 32); (Self { key, value }, out) diff --git a/crates/primitives/src/transaction/access_list.rs b/crates/primitives/src/transaction/access_list.rs index 64b0551af89a..7032c1634b42 100644 --- a/crates/primitives/src/transaction/access_list.rs +++ b/crates/primitives/src/transaction/access_list.rs @@ -82,3 +82,31 @@ impl AccessList { self.0.capacity() * mem::size_of::<AccessListItem>() } } + +impl From<reth_rpc_types::AccessList> for AccessList { + #[inline] + fn from(value: reth_rpc_types::AccessList) -> Self { + let converted_list = value + .0 + .into_iter() + .map(|item| AccessListItem { address: item.address, storage_keys: item.storage_keys }) + .collect(); + + AccessList(converted_list) + } +} + +impl From<AccessList> for reth_rpc_types::AccessList { + #[inline] + fn from(value: AccessList) -> Self { + let list = value + .0 + .into_iter() + .map(|item| reth_rpc_types::AccessListItem { + address: item.address, + storage_keys: item.storage_keys, + }) + .collect(); + reth_rpc_types::AccessList(list) + } +} diff --git a/crates/primitives/src/transaction/mod.rs b/crates/primitives/src/transaction/mod.rs index 47a262705d4b..9702ae19eb4d 100644 --- a/crates/primitives/src/transaction/mod.rs +++ b/crates/primitives/src/transaction/mod.rs @@ -819,7 +819,8 @@ impl TransactionSignedNoHash { /// Calculates the transaction hash. If used more than once, it's better to convert it to /// [`TransactionSigned`] first. pub fn hash(&self) -> B256 { - let mut buf = Vec::new(); + // pre-allocate buffer for the transaction + let mut buf = Vec::with_capacity(128 + self.transaction.input().len()); self.transaction.encode_with_signature(&self.signature, &mut buf, false); keccak256(&buf) } diff --git a/crates/primitives/src/transaction/variant.rs b/crates/primitives/src/transaction/variant.rs index 2ab2667222c7..b89aa6aa10e3 100644 --- a/crates/primitives/src/transaction/variant.rs +++ b/crates/primitives/src/transaction/variant.rs @@ -1,8 +1,10 @@ //! Helper enum functions for `Transaction`, `TransactionSigned` and //! `TransactionSignedEcRecovered` use crate::{ - Transaction, TransactionSigned, TransactionSignedEcRecovered, TransactionSignedNoHash, + Address, Transaction, TransactionSigned, TransactionSignedEcRecovered, TransactionSignedNoHash, + B256, }; +use std::ops::Deref; /// Represents various different transaction formats used in reth. /// @@ -29,6 +31,26 @@ impl TransactionSignedVariant { } } + /// Returns the hash of the transaction + pub fn hash(&self) -> B256 { + match self { + TransactionSignedVariant::SignedNoHash(tx) => tx.hash(), + TransactionSignedVariant::Signed(tx) => tx.hash, + TransactionSignedVariant::SignedEcRecovered(tx) => tx.hash, + } + } + + /// Returns the signer of the transaction. + /// + /// If the transaction is not of [TransactionSignedEcRecovered] it will be recovered. + pub fn signer(&self) -> Option<Address>
{ + match self { + TransactionSignedVariant::SignedNoHash(tx) => tx.recover_signer(), + TransactionSignedVariant::Signed(tx) => tx.recover_signer(), + TransactionSignedVariant::SignedEcRecovered(tx) => Some(tx.signer), + } + } + /// Returns [TransactionSigned] type /// else None pub fn as_signed(&self) -> Option<&TransactionSigned> { @@ -130,3 +152,11 @@ impl AsRef for TransactionSignedVariant { self.as_raw() } } + +impl Deref for TransactionSignedVariant { + type Target = Transaction; + + fn deref(&self) -> &Self::Target { + self.as_raw() + } +} diff --git a/crates/primitives/src/trie/hash_builder/state.rs b/crates/primitives/src/trie/hash_builder/state.rs index fef54726bbae..c714dedd0bb4 100644 --- a/crates/primitives/src/trie/hash_builder/state.rs +++ b/crates/primitives/src/trie/hash_builder/state.rs @@ -68,10 +68,7 @@ impl Compact for HashBuilderState { len } - fn from_compact(buf: &[u8], _len: usize) -> (Self, &[u8]) - where - Self: Sized, - { + fn from_compact(buf: &[u8], _len: usize) -> (Self, &[u8]) { let (key, mut buf) = Vec::from_compact(buf, 0); let stack_len = buf.get_u16() as usize; diff --git a/crates/primitives/src/trie/hash_builder/value.rs b/crates/primitives/src/trie/hash_builder/value.rs index 45d4c0ce1c9f..fed85e680cf9 100644 --- a/crates/primitives/src/trie/hash_builder/value.rs +++ b/crates/primitives/src/trie/hash_builder/value.rs @@ -29,10 +29,7 @@ impl Compact for HashBuilderValue { } } - fn from_compact(buf: &[u8], _len: usize) -> (Self, &[u8]) - where - Self: Sized, - { + fn from_compact(buf: &[u8], _len: usize) -> (Self, &[u8]) { match buf[0] { 0 => { let (hash, buf) = B256::from_compact(&buf[1..], 32); diff --git a/crates/primitives/src/trie/mask.rs b/crates/primitives/src/trie/mask.rs index d54f239ad0a6..152be03c936d 100644 --- a/crates/primitives/src/trie/mask.rs +++ b/crates/primitives/src/trie/mask.rs @@ -72,10 +72,7 @@ impl Compact for TrieMask { 2 } - fn from_compact(mut buf: &[u8], _len: usize) -> (Self, &[u8]) - where - Self: Sized, - { + fn from_compact(mut buf: &[u8], _len: usize) -> (Self, &[u8]) { let mask = buf.get_u16(); (Self(mask), buf) } diff --git a/crates/primitives/src/trie/nibbles.rs b/crates/primitives/src/trie/nibbles.rs index a6ad01b241ab..876b9ff9a652 100644 --- a/crates/primitives/src/trie/nibbles.rs +++ b/crates/primitives/src/trie/nibbles.rs @@ -41,10 +41,7 @@ impl Compact for StoredNibblesSubKey { 64 + 1 } - fn from_compact(buf: &[u8], _len: usize) -> (Self, &[u8]) - where - Self: Sized, - { + fn from_compact(buf: &[u8], _len: usize) -> (Self, &[u8]) { let len = buf[64] as usize; let inner = Vec::from(&buf[..len]).into(); (Self(StoredNibbles { inner }), &buf[65..]) diff --git a/crates/primitives/src/trie/nodes/branch.rs b/crates/primitives/src/trie/nodes/branch.rs index 073c2e125f69..2771adfa40c8 100644 --- a/crates/primitives/src/trie/nodes/branch.rs +++ b/crates/primitives/src/trie/nodes/branch.rs @@ -163,10 +163,7 @@ impl Compact for BranchNodeCompact { buf_size } - fn from_compact(buf: &[u8], _len: usize) -> (Self, &[u8]) - where - Self: Sized, - { + fn from_compact(buf: &[u8], _len: usize) -> (Self, &[u8]) { let hash_len = B256::len_bytes(); // Assert the buffer is long enough to contain the masks and the hashes. 
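The new `Deref<Target = Transaction>` impl above makes every `Transaction` accessor available directly on the variant enum. A small sketch, assuming `Transaction::gas_limit()` as the example accessor (not a method introduced by this diff):

    fn gas_of(tx: &TransactionSignedVariant) -> u64 {
        // deref forwards through `as_raw()` to the inner `Transaction`
        tx.gas_limit()
    }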
diff --git a/crates/primitives/src/trie/storage.rs b/crates/primitives/src/trie/storage.rs index bbb5c5bc4206..33f68fc05d1b 100644 --- a/crates/primitives/src/trie/storage.rs +++ b/crates/primitives/src/trie/storage.rs @@ -24,10 +24,7 @@ impl Compact for StorageTrieEntry { nibbles_len + node_len } - fn from_compact(buf: &[u8], len: usize) -> (Self, &[u8]) - where - Self: Sized, - { + fn from_compact(buf: &[u8], len: usize) -> (Self, &[u8]) { let (nibbles, buf) = StoredNibblesSubKey::from_compact(buf, 33); let (node, buf) = BranchNodeCompact::from_compact(buf, len - 33); let this = Self { nibbles, node }; diff --git a/crates/primitives/src/trie/subnode.rs b/crates/primitives/src/trie/subnode.rs index 232a67279220..e6976cf13a2b 100644 --- a/crates/primitives/src/trie/subnode.rs +++ b/crates/primitives/src/trie/subnode.rs @@ -46,10 +46,7 @@ impl Compact for StoredSubNode { len } - fn from_compact(mut buf: &[u8], _len: usize) -> (Self, &[u8]) - where - Self: Sized, - { + fn from_compact(mut buf: &[u8], _len: usize) -> (Self, &[u8]) { let key_len = buf.get_u16() as usize; let key = Vec::from(&buf[..key_len]); buf.advance(key_len); diff --git a/crates/prune/src/segments/account_history.rs b/crates/prune/src/segments/account_history.rs index c1b5ae682cfc..bfebad1a95c0 100644 --- a/crates/prune/src/segments/account_history.rs +++ b/crates/prune/src/segments/account_history.rs @@ -32,7 +32,7 @@ impl Segment for AccountHistory { #[instrument(level = "trace", target = "pruner", skip(self, provider), ret)] fn prune( &self, - provider: &DatabaseProviderRW<'_, DB>, + provider: &DatabaseProviderRW, input: PruneInput, ) -> Result { let range = match input.get_next_block_range() { @@ -90,16 +90,16 @@ mod tests { }; use reth_primitives::{BlockNumber, PruneCheckpoint, PruneMode, PruneSegment, B256}; use reth_provider::PruneCheckpointReader; - use reth_stages::test_utils::TestTransaction; + use reth_stages::test_utils::TestStageDB; use std::{collections::BTreeMap, ops::AddAssign}; #[test] fn prune() { - let tx = TestTransaction::default(); + let db = TestStageDB::default(); let mut rng = generators::rng(); let blocks = random_block_range(&mut rng, 1..=5000, B256::ZERO, 0..1); - tx.insert_blocks(blocks.iter(), None).expect("insert blocks"); + db.insert_blocks(blocks.iter(), None).expect("insert blocks"); let accounts = random_eoa_account_range(&mut rng, 0..2).into_iter().collect::>(); @@ -111,10 +111,10 @@ mod tests { 0..0, 0..0, ); - tx.insert_changesets(changesets.clone(), None).expect("insert changesets"); - tx.insert_history(changesets.clone(), None).expect("insert history"); + db.insert_changesets(changesets.clone(), None).expect("insert changesets"); + db.insert_history(changesets.clone(), None).expect("insert history"); - let account_occurrences = tx.table::().unwrap().into_iter().fold( + let account_occurrences = db.table::().unwrap().into_iter().fold( BTreeMap::<_, usize>::new(), |mut map, (key, _)| { map.entry(key.key).or_default().add_assign(1); @@ -124,17 +124,19 @@ mod tests { assert!(account_occurrences.into_iter().any(|(_, occurrences)| occurrences > 1)); assert_eq!( - tx.table::().unwrap().len(), + db.table::().unwrap().len(), changesets.iter().flatten().count() ); - let original_shards = tx.table::().unwrap(); + let original_shards = db.table::().unwrap(); let test_prune = |to_block: BlockNumber, run: usize, expected_result: (bool, usize)| { let prune_mode = PruneMode::Before(to_block); let input = PruneInput { - previous_checkpoint: tx - .inner() + previous_checkpoint: db + .factory + 
.provider() + .unwrap() .get_prune_checkpoint(PruneSegment::AccountHistory) .unwrap(), to_block, @@ -142,7 +144,7 @@ mod tests { }; let segment = AccountHistory::new(prune_mode); - let provider = tx.inner_rw(); + let provider = db.factory.provider_rw().unwrap(); let result = segment.prune(&provider, input).unwrap(); assert_matches!( result, @@ -200,11 +202,11 @@ mod tests { ); assert_eq!( - tx.table::().unwrap().len(), + db.table::().unwrap().len(), pruned_changesets.values().flatten().count() ); - let actual_shards = tx.table::().unwrap(); + let actual_shards = db.table::().unwrap(); let expected_shards = original_shards .iter() @@ -221,7 +223,11 @@ mod tests { assert_eq!(actual_shards, expected_shards); assert_eq!( - tx.inner().get_prune_checkpoint(PruneSegment::AccountHistory).unwrap(), + db.factory + .provider() + .unwrap() + .get_prune_checkpoint(PruneSegment::AccountHistory) + .unwrap(), Some(PruneCheckpoint { block_number: Some(last_pruned_block_number), tx_number: None, diff --git a/crates/prune/src/segments/headers.rs b/crates/prune/src/segments/headers.rs index 12ba19416af3..f6913fe15d2a 100644 --- a/crates/prune/src/segments/headers.rs +++ b/crates/prune/src/segments/headers.rs @@ -33,7 +33,7 @@ impl Segment for Headers { #[instrument(level = "trace", target = "pruner", skip(self, provider), ret)] fn prune( &self, - provider: &DatabaseProviderRW<'_, DB>, + provider: &DatabaseProviderRW, input: PruneInput, ) -> Result { let block_range = match input.get_next_block_range() { @@ -91,7 +91,7 @@ impl Headers { /// Returns `done`, number of pruned rows and last pruned block number. fn prune_table>( &self, - provider: &DatabaseProviderRW<'_, DB>, + provider: &DatabaseProviderRW, range: RangeInclusive, delete_limit: usize, ) -> RethResult<(bool, usize, BlockNumber)> { @@ -116,25 +116,27 @@ mod tests { use reth_interfaces::test_utils::{generators, generators::random_header_range}; use reth_primitives::{BlockNumber, PruneCheckpoint, PruneMode, PruneSegment, B256}; use reth_provider::PruneCheckpointReader; - use reth_stages::test_utils::TestTransaction; + use reth_stages::test_utils::TestStageDB; #[test] fn prune() { - let tx = TestTransaction::default(); + let db = TestStageDB::default(); let mut rng = generators::rng(); let headers = random_header_range(&mut rng, 0..100, B256::ZERO); - tx.insert_headers_with_td(headers.iter()).expect("insert headers"); + db.insert_headers_with_td(headers.iter()).expect("insert headers"); - assert_eq!(tx.table::().unwrap().len(), headers.len()); - assert_eq!(tx.table::().unwrap().len(), headers.len()); - assert_eq!(tx.table::().unwrap().len(), headers.len()); + assert_eq!(db.table::().unwrap().len(), headers.len()); + assert_eq!(db.table::().unwrap().len(), headers.len()); + assert_eq!(db.table::().unwrap().len(), headers.len()); let test_prune = |to_block: BlockNumber, expected_result: (bool, usize)| { let prune_mode = PruneMode::Before(to_block); let input = PruneInput { - previous_checkpoint: tx - .inner() + previous_checkpoint: db + .factory + .provider() + .unwrap() .get_prune_checkpoint(PruneSegment::Headers) .unwrap(), to_block, @@ -142,15 +144,17 @@ mod tests { }; let segment = Headers::new(prune_mode); - let next_block_number_to_prune = tx - .inner() + let next_block_number_to_prune = db + .factory + .provider() + .unwrap() .get_prune_checkpoint(PruneSegment::Headers) .unwrap() .and_then(|checkpoint| checkpoint.block_number) .map(|block_number| block_number + 1) .unwrap_or_default(); - let provider = tx.inner_rw(); + let provider = 
db.factory.provider_rw().unwrap(); let result = segment.prune(&provider, input).unwrap(); assert_matches!( result, @@ -169,19 +173,19 @@ mod tests { .min(next_block_number_to_prune + input.delete_limit as BlockNumber / 3 - 1); assert_eq!( - tx.table::().unwrap().len(), + db.table::().unwrap().len(), headers.len() - (last_pruned_block_number + 1) as usize ); assert_eq!( - tx.table::().unwrap().len(), + db.table::().unwrap().len(), headers.len() - (last_pruned_block_number + 1) as usize ); assert_eq!( - tx.table::().unwrap().len(), + db.table::().unwrap().len(), headers.len() - (last_pruned_block_number + 1) as usize ); assert_eq!( - tx.inner().get_prune_checkpoint(PruneSegment::Headers).unwrap(), + db.factory.provider().unwrap().get_prune_checkpoint(PruneSegment::Headers).unwrap(), Some(PruneCheckpoint { block_number: Some(last_pruned_block_number), tx_number: None, @@ -196,7 +200,7 @@ mod tests { #[test] fn prune_cannot_be_done() { - let tx = TestTransaction::default(); + let db = TestStageDB::default(); let input = PruneInput { previous_checkpoint: None, @@ -206,7 +210,7 @@ mod tests { }; let segment = Headers::new(PruneMode::Full); - let provider = tx.inner_rw(); + let provider = db.factory.provider_rw().unwrap(); let result = segment.prune(&provider, input).unwrap(); assert_eq!(result, PruneOutput::not_done()); } diff --git a/crates/prune/src/segments/history.rs b/crates/prune/src/segments/history.rs index bb3352a396a0..4836eeb84154 100644 --- a/crates/prune/src/segments/history.rs +++ b/crates/prune/src/segments/history.rs @@ -14,7 +14,7 @@ use reth_provider::DatabaseProviderRW; /// /// Returns total number of processed (walked) and deleted entities. pub(crate) fn prune_history_indices( - provider: &DatabaseProviderRW<'_, DB>, + provider: &DatabaseProviderRW, to_block: BlockNumber, key_matches: impl Fn(&T::Key, &T::Key) -> bool, last_key: impl Fn(&T::Key) -> T::Key, diff --git a/crates/prune/src/segments/mod.rs b/crates/prune/src/segments/mod.rs index 62fda6195864..339c4e013745 100644 --- a/crates/prune/src/segments/mod.rs +++ b/crates/prune/src/segments/mod.rs @@ -45,14 +45,14 @@ pub trait Segment: Debug + Send + Sync { /// Prune data for [Self::segment] using the provided input. fn prune( &self, - provider: &DatabaseProviderRW<'_, DB>, + provider: &DatabaseProviderRW, input: PruneInput, ) -> Result; /// Save checkpoint for [Self::segment] to the database. fn save_checkpoint( &self, - provider: &DatabaseProviderRW<'_, DB>, + provider: &DatabaseProviderRW, checkpoint: PruneCheckpoint, ) -> ProviderResult<()> { provider.save_prune_checkpoint(self.segment(), checkpoint) @@ -80,7 +80,7 @@ impl PruneInput { /// To get the range end: get last tx number for `to_block`. pub(crate) fn get_next_tx_num_range( &self, - provider: &DatabaseProviderRW<'_, DB>, + provider: &DatabaseProviderRW, ) -> RethResult>> { let from_tx_number = self.previous_checkpoint // Checkpoint exists, prune from the next transaction after the highest pruned one diff --git a/crates/prune/src/segments/receipts.rs b/crates/prune/src/segments/receipts.rs index fb97897e0cd4..fdd4d0402e40 100644 --- a/crates/prune/src/segments/receipts.rs +++ b/crates/prune/src/segments/receipts.rs @@ -31,7 +31,7 @@ impl Segment for Receipts { #[instrument(level = "trace", target = "pruner", skip(self, provider), ret)] fn prune( &self, - provider: &DatabaseProviderRW<'_, DB>, + provider: &DatabaseProviderRW, input: PruneInput, ) -> Result { let tx_range = match input.get_next_tx_num_range(provider)? 
{ @@ -71,7 +71,7 @@ impl Segment for Receipts { fn save_checkpoint( &self, - provider: &DatabaseProviderRW<'_, DB>, + provider: &DatabaseProviderRW, checkpoint: PruneCheckpoint, ) -> ProviderResult<()> { provider.save_prune_checkpoint(PruneSegment::Receipts, checkpoint)?; @@ -99,16 +99,16 @@ mod tests { }; use reth_primitives::{BlockNumber, PruneCheckpoint, PruneMode, PruneSegment, TxNumber, B256}; use reth_provider::PruneCheckpointReader; - use reth_stages::test_utils::TestTransaction; + use reth_stages::test_utils::TestStageDB; use std::ops::Sub; #[test] fn prune() { - let tx = TestTransaction::default(); + let db = TestStageDB::default(); let mut rng = generators::rng(); let blocks = random_block_range(&mut rng, 1..=10, B256::ZERO, 2..3); - tx.insert_blocks(blocks.iter(), None).expect("insert blocks"); + db.insert_blocks(blocks.iter(), None).expect("insert blocks"); let mut receipts = Vec::new(); for block in &blocks { @@ -117,22 +117,24 @@ mod tests { .push((receipts.len() as u64, random_receipt(&mut rng, transaction, Some(0)))); } } - tx.insert_receipts(receipts.clone()).expect("insert receipts"); + db.insert_receipts(receipts.clone()).expect("insert receipts"); assert_eq!( - tx.table::().unwrap().len(), + db.table::().unwrap().len(), blocks.iter().map(|block| block.body.len()).sum::() ); assert_eq!( - tx.table::().unwrap().len(), - tx.table::().unwrap().len() + db.table::().unwrap().len(), + db.table::().unwrap().len() ); let test_prune = |to_block: BlockNumber, expected_result: (bool, usize)| { let prune_mode = PruneMode::Before(to_block); let input = PruneInput { - previous_checkpoint: tx - .inner() + previous_checkpoint: db + .factory + .provider() + .unwrap() .get_prune_checkpoint(PruneSegment::Receipts) .unwrap(), to_block, @@ -140,8 +142,10 @@ mod tests { }; let segment = Receipts::new(prune_mode); - let next_tx_number_to_prune = tx - .inner() + let next_tx_number_to_prune = db + .factory + .provider() + .unwrap() .get_prune_checkpoint(PruneSegment::Receipts) .unwrap() .and_then(|checkpoint| checkpoint.tx_number) @@ -156,7 +160,7 @@ mod tests { .min(next_tx_number_to_prune as usize + input.delete_limit) .sub(1); - let provider = tx.inner_rw(); + let provider = db.factory.provider_rw().unwrap(); let result = segment.prune(&provider, input).unwrap(); assert_matches!( result, @@ -187,11 +191,15 @@ mod tests { .checked_sub(if result.done { 0 } else { 1 }); assert_eq!( - tx.table::().unwrap().len(), + db.table::().unwrap().len(), receipts.len() - (last_pruned_tx_number + 1) ); assert_eq!( - tx.inner().get_prune_checkpoint(PruneSegment::Receipts).unwrap(), + db.factory + .provider() + .unwrap() + .get_prune_checkpoint(PruneSegment::Receipts) + .unwrap(), Some(PruneCheckpoint { block_number: last_pruned_block_number, tx_number: Some(last_pruned_tx_number as TxNumber), diff --git a/crates/prune/src/segments/receipts_by_logs.rs b/crates/prune/src/segments/receipts_by_logs.rs index 8f9faa4fd7b0..e05c87533812 100644 --- a/crates/prune/src/segments/receipts_by_logs.rs +++ b/crates/prune/src/segments/receipts_by_logs.rs @@ -32,7 +32,7 @@ impl Segment for ReceiptsByLogs { #[instrument(level = "trace", target = "pruner", skip(self, provider), ret)] fn prune( &self, - provider: &DatabaseProviderRW<'_, DB>, + provider: &DatabaseProviderRW, input: PruneInput, ) -> Result { // Contract log filtering removes every receipt possible except the ones in the list. 
So, @@ -216,12 +216,12 @@ mod tests { }; use reth_primitives::{PruneMode, PruneSegment, ReceiptsLogPruneConfig, B256}; use reth_provider::{PruneCheckpointReader, TransactionsProvider}; - use reth_stages::test_utils::TestTransaction; + use reth_stages::test_utils::TestStageDB; use std::collections::BTreeMap; #[test] fn prune_receipts_by_logs() { - let tx = TestTransaction::default(); + let db = TestStageDB::default(); let mut rng = generators::rng(); let tip = 20000; @@ -231,7 +231,7 @@ mod tests { random_block_range(&mut rng, (tip - 100 + 1)..=tip, B256::ZERO, 1..5), ] .concat(); - tx.insert_blocks(blocks.iter(), None).expect("insert blocks"); + db.insert_blocks(blocks.iter(), None).expect("insert blocks"); let mut receipts = Vec::new(); @@ -247,19 +247,19 @@ mod tests { receipts.push((receipts.len() as u64, receipt)); } } - tx.insert_receipts(receipts).expect("insert receipts"); + db.insert_receipts(receipts).expect("insert receipts"); assert_eq!( - tx.table::().unwrap().len(), + db.table::().unwrap().len(), blocks.iter().map(|block| block.body.len()).sum::() ); assert_eq!( - tx.table::().unwrap().len(), - tx.table::().unwrap().len() + db.table::().unwrap().len(), + db.table::().unwrap().len() ); let run_prune = || { - let provider = tx.inner_rw(); + let provider = db.factory.provider_rw().unwrap(); let prune_before_block: usize = 20; let prune_mode = PruneMode::Before(prune_before_block as u64); @@ -269,8 +269,10 @@ mod tests { let result = ReceiptsByLogs::new(receipts_log_filter).prune( &provider, PruneInput { - previous_checkpoint: tx - .inner() + previous_checkpoint: db + .factory + .provider() + .unwrap() .get_prune_checkpoint(PruneSegment::ContractLogs) .unwrap(), to_block: tip, @@ -282,8 +284,10 @@ mod tests { assert_matches!(result, Ok(_)); let output = result.unwrap(); - let (pruned_block, pruned_tx) = tx - .inner() + let (pruned_block, pruned_tx) = db + .factory + .provider() + .unwrap() .get_prune_checkpoint(PruneSegment::ContractLogs) .unwrap() .map(|checkpoint| (checkpoint.block_number.unwrap(), checkpoint.tx_number.unwrap())) @@ -293,7 +297,7 @@ mod tests { let unprunable = pruned_block.saturating_sub(prune_before_block as u64 - 1); assert_eq!( - tx.table::().unwrap().len(), + db.table::().unwrap().len(), blocks.iter().map(|block| block.body.len()).sum::() - ((pruned_tx + 1) - unprunable) as usize ); @@ -303,7 +307,7 @@ mod tests { while !run_prune() {} - let provider = tx.inner(); + let provider = db.factory.provider().unwrap(); let mut cursor = provider.tx_ref().cursor_read::().unwrap(); let walker = cursor.walk(None).unwrap(); for receipt in walker { diff --git a/crates/prune/src/segments/sender_recovery.rs b/crates/prune/src/segments/sender_recovery.rs index aa8d48a2ea4b..ec2d189f55cc 100644 --- a/crates/prune/src/segments/sender_recovery.rs +++ b/crates/prune/src/segments/sender_recovery.rs @@ -30,7 +30,7 @@ impl Segment for SenderRecovery { #[instrument(level = "trace", target = "pruner", skip(self, provider), ret)] fn prune( &self, - provider: &DatabaseProviderRW<'_, DB>, + provider: &DatabaseProviderRW, input: PruneInput, ) -> Result { let tx_range = match input.get_next_tx_num_range(provider)? 
{ @@ -81,16 +81,16 @@ mod tests { use reth_interfaces::test_utils::{generators, generators::random_block_range}; use reth_primitives::{BlockNumber, PruneCheckpoint, PruneMode, PruneSegment, TxNumber, B256}; use reth_provider::PruneCheckpointReader; - use reth_stages::test_utils::TestTransaction; + use reth_stages::test_utils::TestStageDB; use std::ops::Sub; #[test] fn prune() { - let tx = TestTransaction::default(); + let db = TestStageDB::default(); let mut rng = generators::rng(); let blocks = random_block_range(&mut rng, 1..=10, B256::ZERO, 2..3); - tx.insert_blocks(blocks.iter(), None).expect("insert blocks"); + db.insert_blocks(blocks.iter(), None).expect("insert blocks"); let mut transaction_senders = Vec::new(); for block in &blocks { @@ -101,23 +101,25 @@ mod tests { )); } } - tx.insert_transaction_senders(transaction_senders.clone()) + db.insert_transaction_senders(transaction_senders.clone()) .expect("insert transaction senders"); assert_eq!( - tx.table::().unwrap().len(), + db.table::().unwrap().len(), blocks.iter().map(|block| block.body.len()).sum::() ); assert_eq!( - tx.table::().unwrap().len(), - tx.table::().unwrap().len() + db.table::().unwrap().len(), + db.table::().unwrap().len() ); let test_prune = |to_block: BlockNumber, expected_result: (bool, usize)| { let prune_mode = PruneMode::Before(to_block); let input = PruneInput { - previous_checkpoint: tx - .inner() + previous_checkpoint: db + .factory + .provider() + .unwrap() .get_prune_checkpoint(PruneSegment::SenderRecovery) .unwrap(), to_block, @@ -125,8 +127,10 @@ mod tests { }; let segment = SenderRecovery::new(prune_mode); - let next_tx_number_to_prune = tx - .inner() + let next_tx_number_to_prune = db + .factory + .provider() + .unwrap() .get_prune_checkpoint(PruneSegment::SenderRecovery) .unwrap() .and_then(|checkpoint| checkpoint.tx_number) @@ -155,7 +159,7 @@ mod tests { .into_inner() .0; - let provider = tx.inner_rw(); + let provider = db.factory.provider_rw().unwrap(); let result = segment.prune(&provider, input).unwrap(); assert_matches!( result, @@ -174,11 +178,15 @@ mod tests { last_pruned_block_number.checked_sub(if result.done { 0 } else { 1 }); assert_eq!( - tx.table::().unwrap().len(), + db.table::().unwrap().len(), transaction_senders.len() - (last_pruned_tx_number + 1) ); assert_eq!( - tx.inner().get_prune_checkpoint(PruneSegment::SenderRecovery).unwrap(), + db.factory + .provider() + .unwrap() + .get_prune_checkpoint(PruneSegment::SenderRecovery) + .unwrap(), Some(PruneCheckpoint { block_number: last_pruned_block_number, tx_number: Some(last_pruned_tx_number as TxNumber), diff --git a/crates/prune/src/segments/storage_history.rs b/crates/prune/src/segments/storage_history.rs index aa68eb714794..45713760c7da 100644 --- a/crates/prune/src/segments/storage_history.rs +++ b/crates/prune/src/segments/storage_history.rs @@ -36,7 +36,7 @@ impl Segment for StorageHistory { #[instrument(level = "trace", target = "pruner", skip(self, provider), ret)] fn prune( &self, - provider: &DatabaseProviderRW<'_, DB>, + provider: &DatabaseProviderRW, input: PruneInput, ) -> Result { let range = match input.get_next_block_range() { @@ -94,16 +94,16 @@ mod tests { }; use reth_primitives::{BlockNumber, PruneCheckpoint, PruneMode, PruneSegment, B256}; use reth_provider::PruneCheckpointReader; - use reth_stages::test_utils::TestTransaction; + use reth_stages::test_utils::TestStageDB; use std::{collections::BTreeMap, ops::AddAssign}; #[test] fn prune() { - let tx = TestTransaction::default(); + let db = TestStageDB::default(); 
let mut rng = generators::rng(); let blocks = random_block_range(&mut rng, 0..=5000, B256::ZERO, 0..1); - tx.insert_blocks(blocks.iter(), None).expect("insert blocks"); + db.insert_blocks(blocks.iter(), None).expect("insert blocks"); let accounts = random_eoa_account_range(&mut rng, 0..2).into_iter().collect::>(); @@ -115,10 +115,10 @@ mod tests { 2..3, 1..2, ); - tx.insert_changesets(changesets.clone(), None).expect("insert changesets"); - tx.insert_history(changesets.clone(), None).expect("insert history"); + db.insert_changesets(changesets.clone(), None).expect("insert changesets"); + db.insert_history(changesets.clone(), None).expect("insert history"); - let storage_occurrences = tx.table::().unwrap().into_iter().fold( + let storage_occurrences = db.table::().unwrap().into_iter().fold( BTreeMap::<_, usize>::new(), |mut map, (key, _)| { map.entry((key.address, key.sharded_key.key)).or_default().add_assign(1); @@ -128,17 +128,19 @@ mod tests { assert!(storage_occurrences.into_iter().any(|(_, occurrences)| occurrences > 1)); assert_eq!( - tx.table::().unwrap().len(), + db.table::().unwrap().len(), changesets.iter().flatten().flat_map(|(_, _, entries)| entries).count() ); - let original_shards = tx.table::().unwrap(); + let original_shards = db.table::().unwrap(); let test_prune = |to_block: BlockNumber, run: usize, expected_result: (bool, usize)| { let prune_mode = PruneMode::Before(to_block); let input = PruneInput { - previous_checkpoint: tx - .inner() + previous_checkpoint: db + .factory + .provider() + .unwrap() .get_prune_checkpoint(PruneSegment::StorageHistory) .unwrap(), to_block, @@ -146,7 +148,7 @@ mod tests { }; let segment = StorageHistory::new(prune_mode); - let provider = tx.inner_rw(); + let provider = db.factory.provider_rw().unwrap(); let result = segment.prune(&provider, input).unwrap(); assert_matches!( result, @@ -206,11 +208,11 @@ mod tests { ); assert_eq!( - tx.table::().unwrap().len(), + db.table::().unwrap().len(), pruned_changesets.values().flatten().count() ); - let actual_shards = tx.table::().unwrap(); + let actual_shards = db.table::().unwrap(); let expected_shards = original_shards .iter() @@ -227,7 +229,11 @@ mod tests { assert_eq!(actual_shards, expected_shards); assert_eq!( - tx.inner().get_prune_checkpoint(PruneSegment::StorageHistory).unwrap(), + db.factory + .provider() + .unwrap() + .get_prune_checkpoint(PruneSegment::StorageHistory) + .unwrap(), Some(PruneCheckpoint { block_number: Some(last_pruned_block_number), tx_number: None, diff --git a/crates/prune/src/segments/transaction_lookup.rs b/crates/prune/src/segments/transaction_lookup.rs index 6785a22fc2f4..342a764a68a6 100644 --- a/crates/prune/src/segments/transaction_lookup.rs +++ b/crates/prune/src/segments/transaction_lookup.rs @@ -31,7 +31,7 @@ impl Segment for TransactionLookup { #[instrument(level = "trace", target = "pruner", skip(self, provider), ret)] fn prune( &self, - provider: &DatabaseProviderRW<'_, DB>, + provider: &DatabaseProviderRW, input: PruneInput, ) -> Result { let (start, end) = match input.get_next_tx_num_range(provider)? 
{ @@ -104,16 +104,16 @@ mod tests { use reth_interfaces::test_utils::{generators, generators::random_block_range}; use reth_primitives::{BlockNumber, PruneCheckpoint, PruneMode, PruneSegment, TxNumber, B256}; use reth_provider::PruneCheckpointReader; - use reth_stages::test_utils::TestTransaction; + use reth_stages::test_utils::TestStageDB; use std::ops::Sub; #[test] fn prune() { - let tx = TestTransaction::default(); + let db = TestStageDB::default(); let mut rng = generators::rng(); let blocks = random_block_range(&mut rng, 1..=10, B256::ZERO, 2..3); - tx.insert_blocks(blocks.iter(), None).expect("insert blocks"); + db.insert_blocks(blocks.iter(), None).expect("insert blocks"); let mut tx_hash_numbers = Vec::new(); for block in &blocks { @@ -121,22 +121,24 @@ mod tests { tx_hash_numbers.push((transaction.hash, tx_hash_numbers.len() as u64)); } } - tx.insert_tx_hash_numbers(tx_hash_numbers.clone()).expect("insert tx hash numbers"); + db.insert_tx_hash_numbers(tx_hash_numbers.clone()).expect("insert tx hash numbers"); assert_eq!( - tx.table::().unwrap().len(), + db.table::().unwrap().len(), blocks.iter().map(|block| block.body.len()).sum::() ); assert_eq!( - tx.table::().unwrap().len(), - tx.table::().unwrap().len() + db.table::().unwrap().len(), + db.table::().unwrap().len() ); let test_prune = |to_block: BlockNumber, expected_result: (bool, usize)| { let prune_mode = PruneMode::Before(to_block); let input = PruneInput { - previous_checkpoint: tx - .inner() + previous_checkpoint: db + .factory + .provider() + .unwrap() .get_prune_checkpoint(PruneSegment::TransactionLookup) .unwrap(), to_block, @@ -144,8 +146,10 @@ mod tests { }; let segment = TransactionLookup::new(prune_mode); - let next_tx_number_to_prune = tx - .inner() + let next_tx_number_to_prune = db + .factory + .provider() + .unwrap() .get_prune_checkpoint(PruneSegment::TransactionLookup) .unwrap() .and_then(|checkpoint| checkpoint.tx_number) @@ -174,7 +178,7 @@ mod tests { .into_inner() .0; - let provider = tx.inner_rw(); + let provider = db.factory.provider_rw().unwrap(); let result = segment.prune(&provider, input).unwrap(); assert_matches!( result, @@ -193,11 +197,15 @@ mod tests { last_pruned_block_number.checked_sub(if result.done { 0 } else { 1 }); assert_eq!( - tx.table::().unwrap().len(), + db.table::().unwrap().len(), tx_hash_numbers.len() - (last_pruned_tx_number + 1) ); assert_eq!( - tx.inner().get_prune_checkpoint(PruneSegment::TransactionLookup).unwrap(), + db.factory + .provider() + .unwrap() + .get_prune_checkpoint(PruneSegment::TransactionLookup) + .unwrap(), Some(PruneCheckpoint { block_number: last_pruned_block_number, tx_number: Some(last_pruned_tx_number as TxNumber), diff --git a/crates/prune/src/segments/transactions.rs b/crates/prune/src/segments/transactions.rs index d06e97b65955..7155cd8888ad 100644 --- a/crates/prune/src/segments/transactions.rs +++ b/crates/prune/src/segments/transactions.rs @@ -30,7 +30,7 @@ impl Segment for Transactions { #[instrument(level = "trace", target = "pruner", skip(self, provider), ret)] fn prune( &self, - provider: &DatabaseProviderRW<'_, DB>, + provider: &DatabaseProviderRW, input: PruneInput, ) -> Result { let tx_range = match input.get_next_tx_num_range(provider)? 
{ @@ -80,26 +80,28 @@ mod tests { use reth_interfaces::test_utils::{generators, generators::random_block_range}; use reth_primitives::{BlockNumber, PruneCheckpoint, PruneMode, PruneSegment, TxNumber, B256}; use reth_provider::PruneCheckpointReader; - use reth_stages::test_utils::TestTransaction; + use reth_stages::test_utils::TestStageDB; use std::ops::Sub; #[test] fn prune() { - let tx = TestTransaction::default(); + let db = TestStageDB::default(); let mut rng = generators::rng(); let blocks = random_block_range(&mut rng, 1..=100, B256::ZERO, 2..3); - tx.insert_blocks(blocks.iter(), None).expect("insert blocks"); + db.insert_blocks(blocks.iter(), None).expect("insert blocks"); let transactions = blocks.iter().flat_map(|block| &block.body).collect::>(); - assert_eq!(tx.table::().unwrap().len(), transactions.len()); + assert_eq!(db.table::().unwrap().len(), transactions.len()); let test_prune = |to_block: BlockNumber, expected_result: (bool, usize)| { let prune_mode = PruneMode::Before(to_block); let input = PruneInput { - previous_checkpoint: tx - .inner() + previous_checkpoint: db + .factory + .provider() + .unwrap() .get_prune_checkpoint(PruneSegment::Transactions) .unwrap(), to_block, @@ -107,15 +109,17 @@ mod tests { }; let segment = Transactions::new(prune_mode); - let next_tx_number_to_prune = tx - .inner() + let next_tx_number_to_prune = db + .factory + .provider() + .unwrap() .get_prune_checkpoint(PruneSegment::Transactions) .unwrap() .and_then(|checkpoint| checkpoint.tx_number) .map(|tx_number| tx_number + 1) .unwrap_or_default(); - let provider = tx.inner_rw(); + let provider = db.factory.provider_rw().unwrap(); let result = segment.prune(&provider, input).unwrap(); assert_matches!( result, @@ -154,11 +158,15 @@ mod tests { .checked_sub(if result.done { 0 } else { 1 }); assert_eq!( - tx.table::().unwrap().len(), + db.table::().unwrap().len(), transactions.len() - (last_pruned_tx_number + 1) ); assert_eq!( - tx.inner().get_prune_checkpoint(PruneSegment::Transactions).unwrap(), + db.factory + .provider() + .unwrap() + .get_prune_checkpoint(PruneSegment::Transactions) + .unwrap(), Some(PruneCheckpoint { block_number: last_pruned_block_number, tx_number: Some(last_pruned_tx_number as TxNumber), diff --git a/crates/revm/Cargo.toml b/crates/revm/Cargo.toml index 30c6c351c570..18f74ecf222f 100644 --- a/crates/revm/Cargo.toml +++ b/crates/revm/Cargo.toml @@ -22,6 +22,9 @@ revm.workspace = true # common tracing.workspace = true +[dev-dependencies] +reth-trie.workspace = true + [features] optimism = [ "revm/optimism", diff --git a/crates/revm/revm-inspectors/Cargo.toml b/crates/revm/revm-inspectors/Cargo.toml index 1edcd0d34e82..bdfbe51f77ab 100644 --- a/crates/revm/revm-inspectors/Cargo.toml +++ b/crates/revm/revm-inspectors/Cargo.toml @@ -10,11 +10,11 @@ description = "revm inspector implementations used by reth" [dependencies] # reth -reth-primitives.workspace = true reth-rpc-types.workspace = true # eth alloy-sol-types.workspace = true +alloy-primitives.workspace = true revm.workspace = true diff --git a/crates/revm/revm-inspectors/src/access_list.rs b/crates/revm/revm-inspectors/src/access_list.rs index 6eb0c0b3fa3e..52eb08c0d104 100644 --- a/crates/revm/revm-inspectors/src/access_list.rs +++ b/crates/revm/revm-inspectors/src/access_list.rs @@ -1,4 +1,5 @@ -use reth_primitives::{AccessList, AccessListItem, Address, B256}; +use alloy_primitives::{Address, B256}; +use reth_rpc_types::{AccessList, AccessListItem}; use revm::{ interpreter::{opcode, Interpreter}, Database, EVMData, 
Inspector, diff --git a/crates/revm/revm-inspectors/src/stack/maybe_owned.rs b/crates/revm/revm-inspectors/src/stack/maybe_owned.rs index f29b44090a0d..c02ce6e4b9da 100644 --- a/crates/revm/revm-inspectors/src/stack/maybe_owned.rs +++ b/crates/revm/revm-inspectors/src/stack/maybe_owned.rs @@ -1,4 +1,4 @@ -use reth_primitives::U256; +use alloy_primitives::U256; use revm::{ interpreter::{CallInputs, CreateInputs, Gas, InstructionResult, Interpreter}, primitives::{db::Database, Address, Bytes, B256}, diff --git a/crates/revm/revm-inspectors/src/stack/mod.rs b/crates/revm/revm-inspectors/src/stack/mod.rs index 603fe1d692c0..f8ea91794ed0 100644 --- a/crates/revm/revm-inspectors/src/stack/mod.rs +++ b/crates/revm/revm-inspectors/src/stack/mod.rs @@ -1,4 +1,4 @@ -use reth_primitives::{Address, Bytes, TxHash, B256, U256}; +use alloy_primitives::{Address, Bytes, B256, U256}; use revm::{ inspectors::CustomPrintTracer, interpreter::{CallInputs, CreateInputs, Gas, InstructionResult, Interpreter}, @@ -23,7 +23,7 @@ pub enum Hook { /// Hook on a specific block. Block(u64), /// Hook on a specific transaction hash. - Transaction(TxHash), + Transaction(B256), /// Hooks on every transaction in a block. All, } @@ -62,7 +62,7 @@ impl InspectorStack { } /// Check if the inspector should be used. - pub fn should_inspect(&self, env: &Env, tx_hash: TxHash) -> bool { + pub fn should_inspect(&self, env: &Env, tx_hash: B256) -> bool { match self.hook { Hook::None => false, Hook::Block(block) => env.block.number.to::() == block, diff --git a/crates/revm/revm-inspectors/src/tracing/arena.rs b/crates/revm/revm-inspectors/src/tracing/arena.rs index d157b705f2ff..cb7c6b5187d7 100644 --- a/crates/revm/revm-inspectors/src/tracing/arena.rs +++ b/crates/revm/revm-inspectors/src/tracing/arena.rs @@ -56,6 +56,16 @@ impl CallTraceArena { } } } + + /// Returns the nodes in the arena + pub fn nodes(&self) -> &[CallTraceNode] { + &self.arena + } + + /// Consumes the arena and returns the nodes + pub fn into_nodes(self) -> Vec { + self.arena + } } /// How to push a trace into the arena diff --git a/crates/revm/revm-inspectors/src/tracing/builder/geth.rs b/crates/revm/revm-inspectors/src/tracing/builder/geth.rs index 0f583e6c8471..4e456c42b7ff 100644 --- a/crates/revm/revm-inspectors/src/tracing/builder/geth.rs +++ b/crates/revm/revm-inspectors/src/tracing/builder/geth.rs @@ -5,7 +5,7 @@ use crate::tracing::{ utils::load_account_code, TracingInspectorConfig, }; -use reth_primitives::{Address, Bytes, B256, U256}; +use alloy_primitives::{Address, Bytes, B256, U256}; use reth_rpc_types::trace::geth::{ AccountChangeKind, AccountState, CallConfig, CallFrame, DefaultFrame, DiffMode, GethDefaultTracingOptions, PreStateConfig, PreStateFrame, PreStateMode, StructLog, diff --git a/crates/revm/revm-inspectors/src/tracing/builder/parity.rs b/crates/revm/revm-inspectors/src/tracing/builder/parity.rs index d17ecc53118a..943a85ee2b1c 100644 --- a/crates/revm/revm-inspectors/src/tracing/builder/parity.rs +++ b/crates/revm/revm-inspectors/src/tracing/builder/parity.rs @@ -4,7 +4,7 @@ use crate::tracing::{ utils::load_account_code, TracingInspectorConfig, }; -use reth_primitives::{Address, U64}; +use alloy_primitives::{Address, U64}; use reth_rpc_types::{trace::parity::*, TransactionInfo}; use revm::{ db::DatabaseRef, @@ -453,9 +453,11 @@ impl ParityTraceBuilder { } }; let mut push_stack = step.push_stack.clone().unwrap_or_default(); - for idx in (0..show_stack).rev() { - if step.stack.len() > idx { - 
push_stack.push(step.stack.peek(idx).unwrap_or_default()) + if let Some(stack) = step.stack.as_ref() { + for idx in (0..show_stack).rev() { + if stack.len() > idx { + push_stack.push(stack[stack.len() - idx - 1]) + } } } push_stack @@ -487,10 +489,6 @@ impl ParityTraceBuilder { } /// An iterator for [TransactionTrace]s -/// -/// This iterator handles additional selfdestruct actions based on the last emitted -/// [TransactionTrace], since selfdestructs are not recorded as individual call traces but are -/// derived from recorded call struct TransactionTraceIter<Iter> { iter: Iter, next_selfdestruct: Option<TransactionTrace>, diff --git a/crates/revm/revm-inspectors/src/tracing/config.rs b/crates/revm/revm-inspectors/src/tracing/config.rs index 382cecb990a9..689fa16de949 100644 --- a/crates/revm/revm-inspectors/src/tracing/config.rs +++ b/crates/revm/revm-inspectors/src/tracing/config.rs @@ -1,4 +1,5 @@ -use reth_rpc_types::trace::geth::GethDefaultTracingOptions; +use reth_rpc_types::trace::{geth::GethDefaultTracingOptions, parity::TraceType}; +use std::collections::HashSet; /// What kind of tracing style this is. /// @@ -85,7 +86,21 @@ impl TracingInspectorConfig { } } + /// Returns the [TracingInspectorConfig] depending on the enabled [TraceType]s + /// + /// Note: the parity statediffs can be populated entirely via the execution result, so we don't + /// need statediff recording + #[inline] + pub fn from_parity_config(trace_types: &HashSet<TraceType>) -> Self { + let needs_vm_trace = trace_types.contains(&TraceType::VmTrace); + TracingInspectorConfig::default_parity() + .set_steps(needs_vm_trace) + .set_stack_snapshots(needs_vm_trace) + .set_memory_snapshots(needs_vm_trace) + } + /// Returns a config for geth style traces based on the given [GethDefaultTracingOptions]. + #[inline] pub fn from_geth_config(config: &GethDefaultTracingOptions) -> Self { Self { record_memory_snapshots: config.enable_memory.unwrap_or_default(), @@ -148,3 +163,32 @@ impl TracingInspectorConfig { self } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_parity_config() { + let mut s = HashSet::new(); + s.insert(TraceType::StateDiff); + let config = TracingInspectorConfig::from_parity_config(&s); + // not required + assert!(!config.record_steps); + assert!(!config.record_state_diff); + + let mut s = HashSet::new(); + s.insert(TraceType::VmTrace); + let config = TracingInspectorConfig::from_parity_config(&s); + assert!(config.record_steps); + assert!(!config.record_state_diff); + + let mut s = HashSet::new(); + s.insert(TraceType::VmTrace); + s.insert(TraceType::StateDiff); + let config = TracingInspectorConfig::from_parity_config(&s); + assert!(config.record_steps); + // not required for StateDiff + assert!(!config.record_state_diff); + } +} diff --git a/crates/revm/revm-inspectors/src/tracing/fourbyte.rs b/crates/revm/revm-inspectors/src/tracing/fourbyte.rs index 8d116b8d4fb9..7ec8b312944b 100644 --- a/crates/revm/revm-inspectors/src/tracing/fourbyte.rs +++ b/crates/revm/revm-inspectors/src/tracing/fourbyte.rs @@ -21,7 +21,7 @@ //! //!
See also -use reth_primitives::{hex, Bytes, Selector}; +use alloy_primitives::{hex, Bytes, Selector}; use reth_rpc_types::trace::geth::FourByteFrame; use revm::{ interpreter::{CallInputs, Gas, InstructionResult}, diff --git a/crates/revm/revm-inspectors/src/tracing/js/bindings.rs b/crates/revm/revm-inspectors/src/tracing/js/bindings.rs index abe56e397ee8..438571a5be60 100644 --- a/crates/revm/revm-inspectors/src/tracing/js/bindings.rs +++ b/crates/revm/revm-inspectors/src/tracing/js/bindings.rs @@ -10,19 +10,19 @@ use crate::tracing::{ }, types::CallKind, }; +use alloy_primitives::{Address, Bytes, B256, U256}; use boa_engine::{ native_function::NativeFunction, object::{builtins::JsArrayBuffer, FunctionObjectBuilder}, Context, JsArgs, JsError, JsNativeError, JsObject, JsResult, JsValue, }; use boa_gc::{empty_trace, Finalize, Trace}; -use reth_primitives::{Account, Address, Bytes, B256, KECCAK_EMPTY, U256}; use revm::{ interpreter::{ opcode::{PUSH0, PUSH32}, OpCode, SharedMemory, Stack, }, - primitives::State, + primitives::{AccountInfo, State, KECCAK_EMPTY}, }; use std::{cell::RefCell, rc::Rc, sync::mpsc::channel}; use tokio::sync::mpsc; @@ -297,14 +297,8 @@ impl StateRef { (StateRef(inner), guard) } - fn get_account(&self, address: &Address) -> Option<Account> { - self.0.with_inner(|state| { - state.get(address).map(|acc| Account { - nonce: acc.info.nonce, - balance: acc.info.balance, - bytecode_hash: Some(acc.info.code_hash), - }) - })? + fn get_account(&self, address: &Address) -> Option<AccountInfo> { + self.0.with_inner(|state| state.get(address).map(|acc| acc.info.clone()))? } } @@ -707,7 +701,7 @@ impl EvmDbRef { (this, guard) } - fn read_basic(&self, address: JsValue, ctx: &mut Context<'_>) -> JsResult<Option<Account>> { + fn read_basic(&self, address: JsValue, ctx: &mut Context<'_>) -> JsResult<Option<AccountInfo>> { let buf = from_buf(address, ctx)?; let address = bytes_to_address(buf); if let acc @ Some(_) = self.state.get_account(&address) { @@ -732,7 +726,7 @@ fn read_code(&self, address: JsValue, ctx: &mut Context<'_>) -> JsResult<JsArrayBuffer> { let acc = self.read_basic(address, ctx)?; - let code_hash = acc.and_then(|acc| acc.bytecode_hash).unwrap_or(KECCAK_EMPTY); + let code_hash = acc.map(|acc| acc.code_hash).unwrap_or(KECCAK_EMPTY); if code_hash == KECCAK_EMPTY { return JsArrayBuffer::new(0, ctx) } diff --git a/crates/revm/revm-inspectors/src/tracing/js/builtins.rs b/crates/revm/revm-inspectors/src/tracing/js/builtins.rs index 5ae1ff7af1b0..91a7a2672b5e 100644 --- a/crates/revm/revm-inspectors/src/tracing/js/builtins.rs +++ b/crates/revm/revm-inspectors/src/tracing/js/builtins.rs @@ -1,12 +1,12 @@ //! Builtin functions +use alloy_primitives::{hex, Address, B256, U256}; use boa_engine::{ object::builtins::{JsArray, JsArrayBuffer}, property::Attribute, Context, JsArgs, JsError, JsNativeError, JsResult, JsString, JsValue, NativeFunction, Source, }; use boa_gc::{empty_trace, Finalize, Trace}; -use reth_primitives::{hex, Address, B256, U256}; use std::collections::HashSet; /// bigIntegerJS is the minified version of .
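The reth_primitives imports swapped out above (and throughout this patch) were re-exports of the same alloy types, so the migration is behavior-preserving rather than semantic. A minimal standalone sketch, not part of the patch, exercising the alloy_primitives API the inspectors now depend on directly (the example values are arbitrary):

// Illustrative only: the types the inspector crates now import from
// alloy_primitives instead of reth_primitives.
use alloy_primitives::{Address, B256, U256};

fn main() {
    // Address still parses from 0x-prefixed hex, as the JS tracer bindings expect.
    let addr: Address = "0x00000000219ab540356cbb839cbe05303d7705fa".parse().unwrap();
    assert_eq!(addr.len(), 20);
    // B256::ZERO matches its use as a placeholder hash in the tests above.
    assert_eq!(B256::ZERO, B256::default());
    // U256 arithmetic is unchanged for stack and storage values.
    assert_eq!(U256::from(2u64) + U256::from(3u64), U256::from(5u64));
}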
diff --git a/crates/revm/revm-inspectors/src/tracing/js/mod.rs b/crates/revm/revm-inspectors/src/tracing/js/mod.rs index 94adf1d44162..792fd363b8c6 100644 --- a/crates/revm/revm-inspectors/src/tracing/js/mod.rs +++ b/crates/revm/revm-inspectors/src/tracing/js/mod.rs @@ -10,14 +10,14 @@ use crate::tracing::{ types::CallKind, utils::get_create_address, }; +use alloy_primitives::{Address, Bytes, B256, U256}; use boa_engine::{Context, JsError, JsObject, JsResult, JsValue, Source}; -use reth_primitives::{Account, Address, Bytes, B256, U256}; use revm::{ interpreter::{ return_revert, CallInputs, CallScheme, CreateInputs, Gas, InstructionResult, Interpreter, }, precompile::Precompiles, - primitives::{Env, ExecutionResult, Output, ResultAndState, TransactTo}, + primitives::{AccountInfo, Env, ExecutionResult, Output, ResultAndState, TransactTo}, Database, EVMData, Inspector, }; use tokio::sync::mpsc; @@ -483,7 +483,7 @@ pub enum JsDbRequest { /// The address of the account to be loaded address: Address, /// The response channel - resp: std::sync::mpsc::Sender<Result<Option<Account>, String>>, + resp: std::sync::mpsc::Sender<Result<Option<AccountInfo>, String>>, }, /// Bindings for [Database::code_by_hash] Code { diff --git a/crates/revm/revm-inspectors/src/tracing/mod.rs b/crates/revm/revm-inspectors/src/tracing/mod.rs index a947c42538c4..31d7c29c2e8f 100644 --- a/crates/revm/revm-inspectors/src/tracing/mod.rs +++ b/crates/revm/revm-inspectors/src/tracing/mod.rs @@ -1,9 +1,9 @@ use crate::tracing::{ - types::{CallKind, LogCallOrder, RawLog}, + types::{CallKind, LogCallOrder}, utils::get_create_address, }; +use alloy_primitives::{Address, Bytes, Log, B256, U256}; pub use arena::CallTraceArena; -use reth_primitives::{Address, Bytes, B256, U256}; use revm::{ inspectors::GasInspector, interpreter::{ @@ -20,7 +20,7 @@ mod builder; mod config; mod fourbyte; mod opcount; -mod types; +pub mod types; mod utils; use crate::tracing::{ arena::PushTraceKind, @@ -282,8 +282,7 @@ impl TracingInspector { .record_memory_snapshots .then(|| RecordedMemory::new(interp.shared_memory.context_memory().to_vec())) .unwrap_or_default(); - let stack = - self.config.record_stack_snapshots.then(|| interp.stack.clone()).unwrap_or_default(); + let stack = self.config.record_stack_snapshots.then(|| interp.stack.data().clone()); let op = OpCode::new(interp.current_opcode()) .or_else(|| { @@ -326,9 +325,12 @@ impl TracingInspector { self.step_stack.pop().expect("can't fill step without starting a step first"); let step = &mut self.traces.arena[trace_idx].trace.steps[step_idx]; - if interp.stack.len() > step.stack.len() { - // if the stack grew, we need to record the new values - step.push_stack = Some(interp.stack.data()[step.stack.len()..].to_vec()); + if let Some(stack) = step.stack.as_ref() { + // only check stack changes if record stack snapshots is enabled: if stack is Some + if interp.stack.len() > stack.len() { + // if the stack grew, we need to record the new values + step.push_stack = Some(interp.stack.data()[stack.len()..].to_vec()); + } } if self.config.record_memory_snapshots { @@ -406,7 +408,7 @@ where if self.config.record_logs { trace.ordering.push(LogCallOrder::Log(trace.logs.len())); - trace.logs.push(RawLog { topics: topics.to_vec(), data: data.clone() }); + trace.logs.push(Log::new_unchecked(topics.to_vec(), data.clone())); } } diff --git a/crates/revm/revm-inspectors/src/tracing/types.rs b/crates/revm/revm-inspectors/src/tracing/types.rs index ec0e7269a080..b5e6415bf435 100644 --- a/crates/revm/revm-inspectors/src/tracing/types.rs +++
b/crates/revm/revm-inspectors/src/tracing/types.rs @@ -1,8 +1,9 @@ //! Types for representing call trace items. use crate::tracing::{config::TraceStyle, utils::convert_memory}; +pub use alloy_primitives::Log; +use alloy_primitives::{Address, Bytes, U256, U64}; use alloy_sol_types::decode_revert_reason; -use reth_primitives::{Address, Bytes, B256, U256, U64}; use reth_rpc_types::trace::{ geth::{CallFrame, CallLogFrame, GethDefaultTracingOptions, StructLog}, parity::{ @@ -10,168 +11,62 @@ use reth_rpc_types::trace::{ SelfdestructAction, TraceOutput, TransactionTrace, }, }; -use revm::interpreter::{ - opcode, CallContext, CallScheme, CreateScheme, InstructionResult, OpCode, Stack, -}; +use revm::interpreter::{opcode, CallContext, CallScheme, CreateScheme, InstructionResult, OpCode}; use serde::{Deserialize, Serialize}; use std::collections::{BTreeMap, VecDeque}; -/// A unified representation of a call -#[derive(Clone, Copy, Debug, Default, Eq, PartialEq, Serialize, Deserialize)] -#[serde(rename_all = "UPPERCASE")] -#[allow(missing_docs)] -pub enum CallKind { - #[default] - Call, - StaticCall, - CallCode, - DelegateCall, - Create, - Create2, -} - -impl CallKind { - /// Returns true if the call is a create - #[inline] - pub fn is_any_create(&self) -> bool { - matches!(self, CallKind::Create | CallKind::Create2) - } - - /// Returns true if the call is a delegate of some sorts - #[inline] - pub fn is_delegate(&self) -> bool { - matches!(self, CallKind::DelegateCall | CallKind::CallCode) - } - - /// Returns true if the call is [CallKind::StaticCall]. - #[inline] - pub fn is_static_call(&self) -> bool { - matches!(self, CallKind::StaticCall) - } -} - -impl std::fmt::Display for CallKind { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - CallKind::Call => { - write!(f, "CALL") - } - CallKind::StaticCall => { - write!(f, "STATICCALL") - } - CallKind::CallCode => { - write!(f, "CALLCODE") - } - CallKind::DelegateCall => { - write!(f, "DELEGATECALL") - } - CallKind::Create => { - write!(f, "CREATE") - } - CallKind::Create2 => { - write!(f, "CREATE2") - } - } - } -} - -impl From<CallScheme> for CallKind { - fn from(scheme: CallScheme) -> Self { - match scheme { - CallScheme::Call => CallKind::Call, - CallScheme::StaticCall => CallKind::StaticCall, - CallScheme::CallCode => CallKind::CallCode, - CallScheme::DelegateCall => CallKind::DelegateCall, - } - } -} - -impl From<CreateScheme> for CallKind { - fn from(create: CreateScheme) -> Self { - match create { - CreateScheme::Create => CallKind::Create, - CreateScheme::Create2 { .. } => CallKind::Create2, - } - } -} - -impl From<CallKind> for ActionType { - fn from(kind: CallKind) -> Self { - match kind { - CallKind::Call | CallKind::StaticCall | CallKind::DelegateCall | CallKind::CallCode => { - ActionType::Call - } - CallKind::Create => ActionType::Create, - CallKind::Create2 => ActionType::Create, - } - } -} - -impl From<CallKind> for CallType { - fn from(ty: CallKind) -> Self { - match ty { - CallKind::Call => CallType::Call, - CallKind::StaticCall => CallType::StaticCall, - CallKind::CallCode => CallType::CallCode, - CallKind::DelegateCall => CallType::DelegateCall, - CallKind::Create => CallType::None, - CallKind::Create2 => CallType::None, - } - } -} - /// A trace of a call.
#[derive(Clone, Debug, PartialEq, Eq)] -pub(crate) struct CallTrace { +pub struct CallTrace { /// The depth of the call - pub(crate) depth: usize, + pub depth: usize, /// Whether the call was successful - pub(crate) success: bool, + pub success: bool, /// caller of this call - pub(crate) caller: Address, + pub caller: Address, /// The destination address of the call or the address from the created contract. /// /// In other words, this is the callee if the [CallKind::Call] or the address of the created /// contract if [CallKind::Create]. - pub(crate) address: Address, + pub address: Address, /// Whether this is a call to a precompile /// /// Note: This is an Option because not all tracers make use of this - pub(crate) maybe_precompile: Option<bool>, + pub maybe_precompile: Option<bool>, /// Holds the target for the selfdestruct refund target if `status` is /// [InstructionResult::SelfDestruct] - pub(crate) selfdestruct_refund_target: Option<Address>
, + pub selfdestruct_refund_target: Option<Address>
, /// The kind of call this is - pub(crate) kind: CallKind, + pub kind: CallKind, /// The value transferred in the call - pub(crate) value: U256, + pub value: U256, /// The calldata for the call, or the init code for contract creations - pub(crate) data: Bytes, + pub data: Bytes, /// The return data of the call if this was not a contract creation, otherwise it is the /// runtime bytecode of the created contract - pub(crate) output: Bytes, + pub output: Bytes, /// The gas cost of the call - pub(crate) gas_used: u64, + pub gas_used: u64, /// The gas limit of the call - pub(crate) gas_limit: u64, + pub gas_limit: u64, /// The status of the trace's call - pub(crate) status: InstructionResult, + pub status: InstructionResult, /// call context of the runtime - pub(crate) call_context: Option<Box<CallContext>>, + pub call_context: Option<Box<CallContext>>, /// Opcode-level execution steps - pub(crate) steps: Vec<CallTraceStep>, + pub steps: Vec<CallTraceStep>, } impl CallTrace { - // Returns true if the status code is an error or revert, See [InstructionResult::Revert] + /// Returns true if the status code is an error or revert, See [InstructionResult::Revert] #[inline] - pub(crate) fn is_error(&self) -> bool { + pub fn is_error(&self) -> bool { self.status.is_error() } - // Returns true if the status code is a revert + /// Returns true if the status code is a revert #[inline] - pub(crate) fn is_revert(&self) -> bool { + pub fn is_revert(&self) -> bool { self.status == InstructionResult::Revert } @@ -225,26 +120,26 @@ impl Default for CallTrace { /// A node in the arena #[derive(Default, Debug, Clone, PartialEq, Eq)] -pub(crate) struct CallTraceNode { +pub struct CallTraceNode { /// Parent node index in the arena - pub(crate) parent: Option<usize>, + pub parent: Option<usize>, /// Children node indexes in the arena - pub(crate) children: Vec<usize>, + pub children: Vec<usize>, /// This node's index in the arena - pub(crate) idx: usize, + pub idx: usize, /// The call trace - pub(crate) trace: CallTrace, - /// Logs - pub(crate) logs: Vec<RawLog>, + pub trace: CallTrace, + /// Recorded logs, if enabled + pub logs: Vec<Log>, /// Ordering of child calls and logs - pub(crate) ordering: Vec<LogCallOrder>, + pub ordering: Vec<LogCallOrder>, } impl CallTraceNode { /// Returns the call context's execution address /// /// See `Inspector::call` impl of [TracingInspector](crate::tracing::TracingInspector) - pub(crate) fn execution_address(&self) -> Address { + pub fn execution_address(&self) -> Address { if self.trace.kind.is_delegate() { self.trace.caller } else { @@ -258,7 +153,7 @@ impl CallTraceNode { /// /// If the slot is accessed more than once, the result only includes the first time it was /// accessed, in other words it only returns the original value of the slot.
- pub(crate) fn touched_slots(&self) -> BTreeMap<U256, U256> { + pub fn touched_slots(&self) -> BTreeMap<U256, U256> { let mut touched_slots = BTreeMap::new(); for change in self.trace.steps.iter().filter_map(|s| s.storage_change.as_ref()) { match touched_slots.entry(change.key) { @@ -429,8 +324,6 @@ impl CallTraceNode { } /// Converts this call trace into an _empty_ geth [CallFrame] - /// - /// Caution: this does not include any of the child calls pub(crate) fn geth_empty_call_frame(&self, include_logs: bool) -> CallFrame { let mut call_frame = CallFrame { typ: self.trace.kind.to_string(), @@ -465,7 +358,7 @@ .iter() .map(|log| CallLogFrame { address: Some(self.execution_address()), - topics: Some(log.topics.clone()), + topics: Some(log.topics().to_vec()), data: Some(log.data.clone()), }) .collect(); @@ -475,6 +368,110 @@ } } +/// A unified representation of a call +#[derive(Clone, Copy, Debug, Default, Eq, PartialEq, Serialize, Deserialize)] +#[serde(rename_all = "UPPERCASE")] +#[allow(missing_docs)] +pub enum CallKind { + #[default] + Call, + StaticCall, + CallCode, + DelegateCall, + Create, + Create2, +} + +impl CallKind { + /// Returns true if the call is a create + #[inline] + pub fn is_any_create(&self) -> bool { + matches!(self, CallKind::Create | CallKind::Create2) + } + + /// Returns true if the call is a delegate of some sorts + #[inline] + pub fn is_delegate(&self) -> bool { + matches!(self, CallKind::DelegateCall | CallKind::CallCode) + } + + /// Returns true if the call is [CallKind::StaticCall]. + #[inline] + pub fn is_static_call(&self) -> bool { + matches!(self, CallKind::StaticCall) + } +} + +impl std::fmt::Display for CallKind { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + CallKind::Call => { + write!(f, "CALL") + } + CallKind::StaticCall => { + write!(f, "STATICCALL") + } + CallKind::CallCode => { + write!(f, "CALLCODE") + } + CallKind::DelegateCall => { + write!(f, "DELEGATECALL") + } + CallKind::Create => { + write!(f, "CREATE") + } + CallKind::Create2 => { + write!(f, "CREATE2") + } + } + } +} + +impl From<CallScheme> for CallKind { + fn from(scheme: CallScheme) -> Self { + match scheme { + CallScheme::Call => CallKind::Call, + CallScheme::StaticCall => CallKind::StaticCall, + CallScheme::CallCode => CallKind::CallCode, + CallScheme::DelegateCall => CallKind::DelegateCall, + } + } +} + +impl From<CreateScheme> for CallKind { + fn from(create: CreateScheme) -> Self { + match create { + CreateScheme::Create => CallKind::Create, + CreateScheme::Create2 { .. } => CallKind::Create2, + } + } +} + +impl From<CallKind> for ActionType { + fn from(kind: CallKind) -> Self { + match kind { + CallKind::Call | CallKind::StaticCall | CallKind::DelegateCall | CallKind::CallCode => { + ActionType::Call + } + CallKind::Create => ActionType::Create, + CallKind::Create2 => ActionType::Create, + } + } +} + +impl From<CallKind> for CallType { + fn from(ty: CallKind) -> Self { + match ty { + CallKind::Call => CallType::Call, + CallKind::StaticCall => CallType::StaticCall, + CallKind::CallCode => CallType::CallCode, + CallKind::DelegateCall => CallType::DelegateCall, + CallKind::Create => CallType::None, + CallKind::Create2 => CallType::None, + } + } +} + pub(crate) struct CallTraceStepStackItem<'a> { /// The trace node that contains this step pub(crate) trace_node: &'a CallTraceNode, @@ -485,59 +482,49 @@ pub(crate) struct CallTraceStepStackItem<'a> { } /// Ordering enum for calls and logs -/// -/// i.e.
if Call 0 occurs before Log 0, it will be pushed into the `CallTraceNode`'s ordering before -/// the log. #[derive(Debug, Clone, PartialEq, Eq)] -pub(crate) enum LogCallOrder { +pub enum LogCallOrder { + /// Contains the index of the corresponding log Log(usize), + /// Contains the index of the corresponding trace node Call(usize), } -/// Ethereum log. -#[derive(Debug, Clone, PartialEq, Eq)] -pub(crate) struct RawLog { - /// Indexed event params are represented as log topics. - pub(crate) topics: Vec<B256>, - /// Others are just plain data. - pub(crate) data: Bytes, -} /// Represents a tracked call step during execution #[derive(Clone, Debug, PartialEq, Eq)] -pub(crate) struct CallTraceStep { +pub struct CallTraceStep { // Fields filled in `step` /// Call depth - pub(crate) depth: u64, + pub depth: u64, /// Program counter before step execution - pub(crate) pc: usize, + pub pc: usize, /// Opcode to be executed - pub(crate) op: OpCode, + pub op: OpCode, /// Current contract address - pub(crate) contract: Address, + pub contract: Address, /// Stack before step execution - pub(crate) stack: Stack, + pub stack: Option<Vec<U256>>, /// The new stack items placed by this step if any - pub(crate) push_stack: Option<Vec<U256>>, + pub push_stack: Option<Vec<U256>>, /// All allocated memory in a step /// /// This will be empty if memory capture is disabled - pub(crate) memory: RecordedMemory, + pub memory: RecordedMemory, /// Size of memory at the beginning of the step - pub(crate) memory_size: usize, + pub memory_size: usize, /// Remaining gas before step execution - pub(crate) gas_remaining: u64, + pub gas_remaining: u64, /// Gas refund counter before step execution - pub(crate) gas_refund_counter: u64, + pub gas_refund_counter: u64, // Fields filled in `step_end` /// Gas cost of step execution - pub(crate) gas_cost: u64, + pub gas_cost: u64, /// Change of the contract state after step execution (effect of the SLOAD/SSTORE instructions) - pub(crate) storage_change: Option<StorageChange>, + pub storage_change: Option<StorageChange>, /// Final status of the step /// /// This is set after the step was executed. - pub(crate) status: InstructionResult, + pub status: InstructionResult, } // === impl CallTraceStep === @@ -568,7 +555,7 @@ impl CallTraceStep { }; if opts.is_stack_enabled() { - log.stack = Some(self.stack.data().clone()); + log.stack = self.stack.clone(); } if opts.is_memory_enabled() { @@ -616,25 +603,37 @@ /// from an SSTORE or SLOAD instruction. #[allow(clippy::upper_case_acronyms)] #[derive(Clone, Copy, Debug, PartialEq, Eq)] -pub(crate) enum StorageChangeReason { +pub enum StorageChangeReason { + /// SLOAD opcode SLOAD, + /// SSTORE opcode SSTORE, } -/// Represents a storage change during execution +/// Represents a storage change during execution. +/// +/// This maps to evm internals: +/// [JournalEntry::StorageChange](revm::JournalEntry::StorageChange) +/// +/// It is used to track both storage changes and warm loads of a storage slot. For a warm load of a +/// slot (EIP-2929 access list), `had_value` will be `None`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)] -pub(crate) struct StorageChange { - pub(crate) key: U256, - pub(crate) value: U256, - pub(crate) had_value: Option<U256>, - pub(crate) reason: StorageChangeReason, +pub struct StorageChange { + /// key of the storage slot + pub key: U256, + /// Current value of the storage slot + pub value: U256, + /// The previous value of the storage slot, if any + pub had_value: Option<U256>, + /// How this storage was accessed + pub reason: StorageChangeReason, } /// Represents the memory captured during execution /// /// This is a wrapper around the [SharedMemory](revm::interpreter::SharedMemory) context memory. #[derive(Debug, Clone, PartialEq, Eq, Default)] -pub(crate) struct RecordedMemory(pub(crate) Vec<u8>); +pub struct RecordedMemory(pub(crate) Vec<u8>); impl RecordedMemory { #[inline] @@ -642,8 +641,9 @@ Self(mem) } + /// Returns the memory as a byte slice #[inline] - pub(crate) fn as_bytes(&self) -> &[u8] { + pub fn as_bytes(&self) -> &[u8] { &self.0 } @@ -652,19 +652,27 @@ self.0.resize(size, 0); } + /// Returns the size of the memory #[inline] - pub(crate) fn len(&self) -> usize { + pub fn len(&self) -> usize { self.0.len() } + /// Returns whether the memory is empty #[inline] - pub(crate) fn is_empty(&self) -> bool { + pub fn is_empty(&self) -> bool { self.0.is_empty() } /// Converts the memory into 32byte hex chunks #[inline] - pub(crate) fn memory_chunks(&self) -> Vec<String> { + pub fn memory_chunks(&self) -> Vec<String> { convert_memory(self.as_bytes()) } } + +impl AsRef<[u8]> for RecordedMemory { + fn as_ref(&self) -> &[u8] { + self.as_bytes() + } +} diff --git a/crates/revm/revm-inspectors/src/tracing/utils.rs b/crates/revm/revm-inspectors/src/tracing/utils.rs index a50edd89519e..c8cd2dc48b5b 100644 --- a/crates/revm/revm-inspectors/src/tracing/utils.rs +++ b/crates/revm/revm-inspectors/src/tracing/utils.rs @@ -1,9 +1,10 @@ //! Util functions for revm related ops -use reth_primitives::{hex, revm_primitives::db::DatabaseRef, Address, Bytes, B256, KECCAK_EMPTY}; +use alloy_primitives::{hex, Address, Bytes, B256}; use revm::{ interpreter::CreateInputs, - primitives::{CreateScheme, SpecId}, + primitives::{CreateScheme, SpecId, KECCAK_EMPTY}, + DatabaseRef, }; /// creates the memory data in 32byte chunks diff --git a/crates/revm/src/factory.rs b/crates/revm/src/factory.rs index 6e326b5cd362..c35b1ad11af6 100644 --- a/crates/revm/src/factory.rs +++ b/crates/revm/src/factory.rs @@ -7,14 +7,14 @@ use reth_primitives::ChainSpec; use reth_provider::{ExecutorFactory, PrunableBlockExecutor, StateProvider}; use std::sync::Arc; -/// Factory that spawn Executor. +/// Factory for creating [EVMProcessor]. #[derive(Clone, Debug)] -pub struct Factory { +pub struct EvmProcessorFactory { chain_spec: Arc<ChainSpec>, stack: Option<InspectorStack>, } -impl Factory { +impl EvmProcessorFactory { /// Create new factory pub fn new(chain_spec: Arc<ChainSpec>) -> Self { Self { chain_spec, stack: None } @@ -33,7 +33,7 @@ } } -impl ExecutorFactory for Factory { +impl ExecutorFactory for EvmProcessorFactory { fn with_state<'a, SP: StateProvider + 'a>( &'a self, sp: SP, diff --git a/crates/revm/src/lib.rs b/crates/revm/src/lib.rs index b1714e00578e..116269da668b 100644 --- a/crates/revm/src/lib.rs +++ b/crates/revm/src/lib.rs @@ -22,7 +22,7 @@ pub mod processor; pub mod state_change; /// revm executor factory.
-pub use factory::Factory; +pub use factory::EvmProcessorFactory; /// reexport for convenience pub use reth_revm_inspectors::*; diff --git a/crates/revm/src/processor.rs b/crates/revm/src/processor.rs index f78c0411703f..1f30ef4dd985 100644 --- a/crates/revm/src/processor.rs +++ b/crates/revm/src/processor.rs @@ -568,6 +568,7 @@ mod tests { use reth_provider::{ AccountReader, BlockHashReader, BundleStateWithReceipts, StateRootProvider, }; + use reth_trie::updates::TrieUpdates; use revm::{Database, TransitionState}; use std::collections::HashMap; @@ -627,6 +628,13 @@ fn state_root(&self, _bundle_state: &BundleStateWithReceipts) -> ProviderResult<B256> { unimplemented!("state root computation is not supported") } + + fn state_root_with_updates( + &self, + _bundle_state: &BundleStateWithReceipts, + ) -> ProviderResult<(B256, TrieUpdates)> { + unimplemented!("state root computation is not supported") + } } impl StateProvider for StateProviderTest { diff --git a/crates/rpc/rpc-builder/src/lib.rs b/crates/rpc/rpc-builder/src/lib.rs index b23029b75fc2..bd2a8db3b711 100644 --- a/crates/rpc/rpc-builder/src/lib.rs +++ b/crates/rpc/rpc-builder/src/lib.rs @@ -2081,6 +2081,20 @@ mod tests { assert_eq!(selection, RethRpcModule::EthCallBundle); } + #[test] + fn parse_eth_call_bundle_selection() { + let selection = "eth,admin,debug,eth-call-bundle".parse::<RpcModuleSelection>().unwrap(); + assert_eq!( + selection, + RpcModuleSelection::Selection(vec![ + RethRpcModule::Eth, + RethRpcModule::Admin, + RethRpcModule::Debug, + RethRpcModule::EthCallBundle, + ]) + ); + } + #[test] fn parse_rpc_module_selection() { let selection = "all".parse::<RpcModuleSelection>().unwrap(); diff --git a/crates/rpc/rpc-testing-util/tests/it/trace.rs b/crates/rpc/rpc-testing-util/tests/it/trace.rs index f3d4ee7e7b3c..c2192b0f1e63 100644 --- a/crates/rpc/rpc-testing-util/tests/it/trace.rs +++ b/crates/rpc/rpc-testing-util/tests/it/trace.rs @@ -1,7 +1,9 @@ use futures::StreamExt; use jsonrpsee::http_client::HttpClientBuilder; use reth_rpc_api_testing_util::{trace::TraceApiExt, utils::parse_env_url}; -use reth_rpc_types::trace::{filter::TraceFilter, parity::TraceType}; +use reth_rpc_types::trace::{ + filter::TraceFilter, parity::TraceType, tracerequest::TraceCallRequest, +}; use std::{collections::HashSet, time::Instant}; /// This is intended to be run locally against a running node.
/// @@ -67,3 +69,26 @@ async fn trace_filters() { println!("Duration since test start: {:?}", start_time.elapsed()); } } + +#[tokio::test(flavor = "multi_thread")] +#[ignore] +async fn trace_call() { + let url = parse_env_url("RETH_RPC_TEST_NODE_URL").unwrap(); + let client = HttpClientBuilder::default().build(url).unwrap(); + let trace_call_request = TraceCallRequest::default(); + let mut stream = client.trace_call_stream(trace_call_request); + let start_time = Instant::now(); + + while let Some(result) = stream.next().await { + match result { + Ok(trace_result) => { + println!("Trace Result: {:?}", trace_result); + } + Err((error, request)) => { + eprintln!("Error for request {:?}: {:?}", request, error); + } + } + } + + println!("Completed in {:?}", start_time.elapsed()); +} diff --git a/crates/rpc/rpc-types-compat/src/block.rs b/crates/rpc/rpc-types-compat/src/block.rs index 570697dffb76..578a47b36929 100644 --- a/crates/rpc/rpc-types-compat/src/block.rs +++ b/crates/rpc/rpc-types-compat/src/block.rs @@ -2,7 +2,9 @@ use crate::transaction::from_recovered_with_block_context; use alloy_rlp::Encodable; -use reth_primitives::{Block as PrimitiveBlock, Header as PrimitiveHeader, B256, U256, U64}; +use reth_primitives::{ + Block as PrimitiveBlock, BlockWithSenders, Header as PrimitiveHeader, B256, U256, U64, +}; use reth_rpc_types::{Block, BlockError, BlockTransactions, BlockTransactionsKind, Header}; /// Converts the given primitive block into a [Block] response with the given @@ -10,7 +12,7 @@ use reth_rpc_types::{Block, BlockError, BlockTransactions, BlockTransactionsKind /// /// If a `block_hash` is provided, then this is used, otherwise the block hash is computed. pub fn from_block( - block: PrimitiveBlock, + block: BlockWithSenders, total_difficulty: U256, kind: BlockTransactionsKind, block_hash: Option<B256>, @@ -29,7 +31,7 @@ pub fn from_block( /// This will populate the `transactions` field with only the hashes of the transactions in the /// block: [BlockTransactions::Hashes] pub fn from_block_with_tx_hashes( - block: PrimitiveBlock, + block: BlockWithSenders, total_difficulty: U256, block_hash: Option<B256>, ) -> Block { @@ -39,7 +41,7 @@ pub fn from_block_with_tx_hashes( from_block_with_transactions( block.length(), block_hash, - block, + block.block, total_difficulty, BlockTransactions::Hashes(transactions), ) @@ -51,35 +53,38 @@ pub fn from_block_with_tx_hashes( /// This will populate the `transactions` field with the _full_ /// [Transaction](reth_rpc_types::Transaction) objects: [BlockTransactions::Full] pub fn from_block_full( - mut block: PrimitiveBlock, + mut block: BlockWithSenders, total_difficulty: U256, block_hash: Option<B256>, ) -> Result<Block, BlockError> { - let block_hash = block_hash.unwrap_or_else(|| block.header.hash_slow()); - let block_number = block.number; - let base_fee_per_gas = block.base_fee_per_gas; + let block_hash = block_hash.unwrap_or_else(|| block.block.header.hash_slow()); + let block_number = block.block.number; + let base_fee_per_gas = block.block.base_fee_per_gas; // NOTE: we can safely remove the body here because not needed to finalize the `Block` in // `from_block_with_transactions`, however we need to compute the length before - let block_length = block.length(); - let body = std::mem::take(&mut block.body); + let block_length = block.block.length(); + let body = std::mem::take(&mut block.block.body); + let transactions_with_senders = body.into_iter().zip(block.senders); + let transactions = transactions_with_senders + .enumerate() + .map(|(idx, (tx, sender))| { + let
signed_tx_ec_recovered = tx.with_signer(sender); - let mut transactions = Vec::with_capacity(block.body.len()); - for (idx, tx) in body.into_iter().enumerate() { - let signed_tx = tx.into_ecrecovered().ok_or(BlockError::InvalidSignature)?; - transactions.push(from_recovered_with_block_context( - signed_tx, - block_hash, - block_number, - base_fee_per_gas, - U256::from(idx), - )) - } + from_recovered_with_block_context( + signed_tx_ec_recovered, + block_hash, + block_number, + base_fee_per_gas, + U256::from(idx), + ) + }) + .collect::<Vec<_>>(); Ok(from_block_with_transactions( block_length, block_hash, - block, + block.block, total_difficulty, BlockTransactions::Full(transactions), )) diff --git a/crates/rpc/rpc-types-compat/src/log.rs b/crates/rpc/rpc-types-compat/src/log.rs index 2d2ebb2e98c9..16d59692cfbb 100644 --- a/crates/rpc/rpc-types-compat/src/log.rs +++ b/crates/rpc/rpc-types-compat/src/log.rs @@ -21,35 +21,3 @@ pub fn from_primitive_log(log: reth_primitives::Log) -> reth_rpc_types::Log { pub fn to_primitive_log(log: reth_rpc_types::Log) -> reth_primitives::Log { reth_primitives::Log { address: log.address, topics: log.topics, data: log.data } } - -/// Converts a primitive `AccessList` structure from the `reth_primitives` module into the -/// corresponding RPC type. -#[inline] -pub fn from_primitive_access_list(list: reth_primitives::AccessList) -> reth_rpc_types::AccessList { - let converted_list: Vec<reth_rpc_types::AccessListItem> = list - .0 - .into_iter() - .map(|item| reth_rpc_types::AccessListItem { - address: item.address, - storage_keys: item.storage_keys, - }) - .collect(); - - reth_rpc_types::AccessList(converted_list) -} - -/// Converts a primitive `AccessList` structure from the `reth_primitives` module into the -/// corresponding RPC type. -#[inline] -pub fn to_primitive_access_list(list: reth_rpc_types::AccessList) -> reth_primitives::AccessList { - let converted_list: Vec<reth_primitives::AccessListItem> = list - .0 - .into_iter() - .map(|item| reth_primitives::AccessListItem { - address: item.address, - storage_keys: item.storage_keys, - }) - .collect(); - - reth_primitives::AccessList(converted_list) -} diff --git a/crates/rpc/rpc-types-compat/src/transaction/mod.rs b/crates/rpc/rpc-types-compat/src/transaction/mod.rs index 8c37a8a8f02d..414dc8063aa2 100644 --- a/crates/rpc/rpc-types-compat/src/transaction/mod.rs +++ b/crates/rpc/rpc-types-compat/src/transaction/mod.rs @@ -162,22 +162,6 @@ pub fn from_primitive_access_list( ) } -/// Convert [reth_rpc_types::AccessList] to [reth_primitives::AccessList] -pub fn to_primitive_access_list( - access_list: reth_rpc_types::AccessList, -) -> reth_primitives::AccessList { - reth_primitives::AccessList( - access_list - .0 - .into_iter() - .map(|item| reth_primitives::AccessListItem { - address: item.address.0.into(), - storage_keys: item.storage_keys.into_iter().map(|key| key.0.into()).collect(), - }) - .collect(), - ) -} - /// Convert [TransactionSignedEcRecovered] to [CallRequest] pub fn transaction_to_call_request(tx: TransactionSignedEcRecovered) -> CallRequest { let from = tx.signer(); diff --git a/crates/rpc/rpc-types-compat/src/transaction/typed.rs b/crates/rpc/rpc-types-compat/src/transaction/typed.rs index 19adf6652c2d..55d316040d4e 100644 --- a/crates/rpc/rpc-types-compat/src/transaction/typed.rs +++ b/crates/rpc/rpc-types-compat/src/transaction/typed.rs @@ -1,5 +1,3 @@ -use crate::log::to_primitive_access_list; - /// Converts a typed transaction request into a primitive transaction.
/// /// Returns `None` if any of the following are true: @@ -30,7 +28,7 @@ pub fn to_primitive_transaction( to: to_primitive_transaction_kind(tx.kind), value: tx.value.into(), input: tx.input, - access_list: to_primitive_access_list(tx.access_list), + access_list: tx.access_list.into(), }), TypedTransactionRequest::EIP1559(tx) => Transaction::Eip1559(TxEip1559 { chain_id: tx.chain_id, @@ -40,7 +38,7 @@ pub fn to_primitive_transaction( to: to_primitive_transaction_kind(tx.kind), value: tx.value.into(), input: tx.input, - access_list: to_primitive_access_list(tx.access_list), + access_list: tx.access_list.into(), max_priority_fee_per_gas: tx.max_priority_fee_per_gas.to(), }), TypedTransactionRequest::EIP4844(tx) => Transaction::Eip4844(TxEip4844 { @@ -51,7 +49,7 @@ pub fn to_primitive_transaction( max_priority_fee_per_gas: tx.max_priority_fee_per_gas.to(), to: to_primitive_transaction_kind(tx.kind), value: tx.value.into(), - access_list: to_primitive_access_list(tx.access_list), + access_list: tx.access_list.into(), blob_versioned_hashes: tx.blob_versioned_hashes, max_fee_per_blob_gas: tx.max_fee_per_blob_gas.to(), input: tx.input, diff --git a/crates/rpc/rpc-types/src/eth/block.rs b/crates/rpc/rpc-types/src/eth/block.rs index cdad6cf12749..17ecf8ce755d 100644 --- a/crates/rpc/rpc-types/src/eth/block.rs +++ b/crates/rpc/rpc-types/src/eth/block.rs @@ -938,4 +938,11 @@ mod tests { let block2 = serde_json::from_str::<Block>(&serialized).unwrap(); assert_eq!(block, block2); } + + #[test] + fn compact_block_number_serde() { + let num: BlockNumberOrTag = 1u64.into(); + let serialized = serde_json::to_string(&num).unwrap(); + assert_eq!(serialized, "\"0x1\""); + } } diff --git a/crates/rpc/rpc-types/src/eth/engine/payload.rs b/crates/rpc/rpc-types/src/eth/engine/payload.rs index 739a6a10ad41..9ae18a351b10 100644 --- a/crates/rpc/rpc-types/src/eth/engine/payload.rs +++ b/crates/rpc/rpc-types/src/eth/engine/payload.rs @@ -447,7 +447,7 @@ pub struct OptimismPayloadAttributes { #[serde( default, skip_serializing_if = "Option::is_none", - deserialize_with = "crate::serde_helpers::u64_hex::u64_hex_opt::deserialize" + deserialize_with = "crate::serde_helpers::u64_hex_opt::deserialize" )] pub gas_limit: Option<u64>, } diff --git a/crates/rpc/rpc-types/src/eth/fee.rs b/crates/rpc/rpc-types/src/eth/fee.rs index c7717994ea18..5fb145ab68a3 100644 --- a/crates/rpc/rpc-types/src/eth/fee.rs +++ b/crates/rpc/rpc-types/src/eth/fee.rs @@ -36,9 +36,8 @@ pub struct FeeHistory { /// /// # Note /// - /// The `Option` is only for compatability with Erigon and Geth. - #[serde(skip_serializing_if = "Vec::is_empty")] - #[serde(default)] + /// Empty list is skipped only for compatibility with Erigon and Geth. + #[serde(default, skip_serializing_if = "Vec::is_empty")] pub base_fee_per_gas: Vec<U256>, /// An array of block gas used ratios. These are calculated as the ratio /// of `gasUsed` and `gasLimit`. /// /// # Note /// /// The `Option` is only for compatibility with Erigon and Geth. - #[serde(skip_serializing_if = "Vec::is_empty")] - #[serde(default)] pub gas_used_ratio: Vec<f64>, /// Lowest number block of the returned range. pub oldest_block: U256, /// An (optional) array of effective priority fee per gas data points from a single /// block. All zeroes are returned if the block is empty.
- #[serde(default)] + #[serde(skip_serializing_if = "Option::is_none")] pub reward: Option>>, } diff --git a/crates/rpc/rpc-types/src/lib.rs b/crates/rpc/rpc-types/src/lib.rs index 8c9e22485bbd..f71f7660810e 100644 --- a/crates/rpc/rpc-types/src/lib.rs +++ b/crates/rpc/rpc-types/src/lib.rs @@ -20,7 +20,7 @@ mod otterscan; mod peer; pub mod relay; mod rpc; -mod serde_helpers; +pub mod serde_helpers; pub use admin::*; pub use eth::*; diff --git a/crates/rpc/rpc-types/src/serde_helpers/mod.rs b/crates/rpc/rpc-types/src/serde_helpers/mod.rs index 1c45b0d56d42..adeb4a24583f 100644 --- a/crates/rpc/rpc-types/src/serde_helpers/mod.rs +++ b/crates/rpc/rpc-types/src/serde_helpers/mod.rs @@ -1,22 +1,18 @@ //! Serde helpers for primitive types. -use alloy_primitives::U256; -use serde::{Deserialize, Deserializer, Serializer}; +use alloy_primitives::B256; +use serde::Serializer; pub mod json_u256; +pub use json_u256::JsonU256; + +/// Helpers for dealing with numbers. pub mod num; +pub use num::*; + /// Storage related helpers. pub mod storage; -pub mod u64_hex; - -/// Deserializes the input into a U256, accepting both 0x-prefixed hex and decimal strings with -/// arbitrary precision, defined by serde_json's [`Number`](serde_json::Number). -pub fn from_int_or_hex<'de, D>(deserializer: D) -> Result -where - D: Deserializer<'de>, -{ - num::NumberOrHexU256::deserialize(deserializer)?.try_into_u256() -} +pub use storage::JsonStorageKey; /// Serialize a byte vec as a hex string _without_ the "0x" prefix. /// @@ -28,3 +24,11 @@ where { s.serialize_str(&alloy_primitives::hex::encode(x.as_ref())) } + +/// Serialize a [B256] as a hex string _without_ the "0x" prefix. +pub fn serialize_b256_hex_string_no_prefix(x: &B256, s: S) -> Result +where + S: Serializer, +{ + s.serialize_str(&format!("{x:x}")) +} diff --git a/crates/rpc/rpc-types/src/serde_helpers/num.rs b/crates/rpc/rpc-types/src/serde_helpers/num.rs index d1e6959065fe..4c34471cd7d0 100644 --- a/crates/rpc/rpc-types/src/serde_helpers/num.rs +++ b/crates/rpc/rpc-types/src/serde_helpers/num.rs @@ -69,6 +69,68 @@ impl<'de> Deserialize<'de> for U64HexOrNumber { } } +/// serde functions for handling `u64` as [U64] +pub mod u64_hex { + use alloy_primitives::U64; + use serde::{Deserialize, Deserializer, Serialize, Serializer}; + + /// Deserializes an `u64` from [U64] accepting a hex quantity string with optional 0x prefix + pub fn deserialize<'de, D>(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + U64::deserialize(deserializer).map(|val| val.to()) + } + + /// Serializes u64 as hex string + pub fn serialize(value: &u64, s: S) -> Result { + U64::from(*value).serialize(s) + } +} + +/// serde functions for handling `Option` as [U64] +pub mod u64_hex_opt { + use alloy_primitives::U64; + use serde::{Deserialize, Deserializer, Serialize, Serializer}; + + /// Serializes u64 as hex string + pub fn serialize(value: &Option, s: S) -> Result { + match value { + Some(val) => U64::from(*val).serialize(s), + None => s.serialize_none(), + } + } + + /// Deserializes an `Option` from [U64] accepting a hex quantity string with optional 0x prefix + pub fn deserialize<'de, D>(deserializer: D) -> Result, D::Error> + where + D: Deserializer<'de>, + { + Ok(U64::deserialize(deserializer) + .map_or(None, |v| Some(u64::from_be_bytes(v.to_be_bytes())))) + } +} + +/// serde functions for handling primitive `u64` as [U64] +pub mod u64_hex_or_decimal { + use crate::serde_helpers::num::U64HexOrNumber; + use serde::{Deserialize, Deserializer, Serialize, Serializer}; + + /// 
Deserializes an `u64` accepting a hex quantity string with optional 0x prefix or + /// a number + pub fn deserialize<'de, D>(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + U64HexOrNumber::deserialize(deserializer).map(Into::into) + } + + /// Serializes u64 as hex string + pub fn serialize(value: &u64, s: S) -> Result { + U64HexOrNumber::from(*value).serialize(s) + } +} + /// serde functions for handling primitive optional `u64` as [U64] pub mod u64_hex_or_decimal_opt { use crate::serde_helpers::num::U64HexOrNumber; @@ -137,3 +199,25 @@ where { NumberOrHexU256::deserialize(deserializer)?.try_into_u256() } + +#[cfg(test)] +mod tests { + use super::*; + use serde::{Deserialize, Serialize}; + + #[test] + fn test_hex_u64() { + #[derive(Debug, Serialize, Deserialize, PartialEq, Eq)] + struct Value { + #[serde(with = "u64_hex")] + inner: u64, + } + + let val = Value { inner: 1000 }; + let s = serde_json::to_string(&val).unwrap(); + assert_eq!(s, "{\"inner\":\"0x3e8\"}"); + + let deserialized: Value = serde_json::from_str(&s).unwrap(); + assert_eq!(val, deserialized); + } +} diff --git a/crates/rpc/rpc-types/src/serde_helpers/u64_hex.rs b/crates/rpc/rpc-types/src/serde_helpers/u64_hex.rs deleted file mode 100644 index e73061cdc936..000000000000 --- a/crates/rpc/rpc-types/src/serde_helpers/u64_hex.rs +++ /dev/null @@ -1,33 +0,0 @@ -//! Helper to deserialize an `u64` from [U64] accepting a hex quantity string with optional 0x -//! prefix - -use alloy_primitives::U64; -use serde::{Deserialize, Deserializer, Serialize, Serializer}; - -/// Deserializes an `u64` from [U64] accepting a hex quantity string with optional 0x prefix -pub fn deserialize<'de, D>(deserializer: D) -> Result -where - D: Deserializer<'de>, -{ - U64::deserialize(deserializer).map(|val| val.to()) -} - -/// Serializes u64 as hex string -pub fn serialize(value: &u64, s: S) -> Result { - U64::from(*value).serialize(s) -} - -/// serde functions for handling `Option` as [U64] -pub mod u64_hex_opt { - use alloy_primitives::U64; - use serde::{Deserialize, Deserializer}; - - /// Deserializes an `Option` from [U64] accepting a hex quantity string with optional 0x prefix - pub fn deserialize<'de, D>(deserializer: D) -> Result, D::Error> - where - D: Deserializer<'de>, - { - Ok(U64::deserialize(deserializer) - .map_or(None, |v| Some(u64::from_be_bytes(v.to_be_bytes())))) - } -} diff --git a/crates/rpc/rpc/src/debug.rs b/crates/rpc/rpc/src/debug.rs index b8b313d804bf..392c2172161e 100644 --- a/crates/rpc/rpc/src/debug.rs +++ b/crates/rpc/rpc/src/debug.rs @@ -19,7 +19,7 @@ use reth_primitives::{ db::{DatabaseCommit, DatabaseRef}, BlockEnv, CfgEnv, }, - Account, Address, Block, BlockId, BlockNumberOrTag, Bytes, TransactionSigned, B256, + Address, Block, BlockId, BlockNumberOrTag, Bytes, TransactionSigned, B256, }; use reth_provider::{BlockReaderIdExt, HeaderProvider, StateProviderBox}; use reth_revm::{ @@ -458,7 +458,7 @@ where opts: GethDebugTracingOptions, env: Env, at: BlockId, - db: &mut SubState>, + db: &mut SubState, ) -> EthResult<(GethTrace, revm_primitives::State)> { let GethDebugTracingOptions { config, tracer, tracer_config, .. 
} = opts; @@ -610,16 +610,7 @@ where while let Some(req) = stream.next().await { match req { JsDbRequest::Basic { address, resp } => { - let acc = db - .basic_ref(address) - .map(|maybe_acc| { - maybe_acc.map(|acc| Account { - nonce: acc.nonce, - balance: acc.balance, - bytecode_hash: Some(acc.code_hash), - }) - }) - .map_err(|err| err.to_string()); + let acc = db.basic_ref(address).map_err(|err| err.to_string()); let _ = resp.send(acc); } JsDbRequest::Code { code_hash, resp } => { diff --git a/crates/rpc/rpc/src/eth/api/block.rs b/crates/rpc/rpc/src/eth/api/block.rs index 3cd33be201f9..c1c835107cdb 100644 --- a/crates/rpc/rpc/src/eth/api/block.rs +++ b/crates/rpc/rpc/src/eth/api/block.rs @@ -146,13 +146,23 @@ where &self, block_id: impl Into, ) -> EthResult> { + self.block_with_senders(block_id) + .await + .map(|maybe_block| maybe_block.map(|block| block.block)) + } + + /// Returns the block object for the given block id. + pub(crate) async fn block_with_senders( + &self, + block_id: impl Into, + ) -> EthResult> { let block_id = block_id.into(); if block_id.is_pending() { // Pending block can be fetched directly without need for caching - let maybe_pending = self.provider().pending_block()?; + let maybe_pending = self.provider().pending_block_with_senders()?; return if maybe_pending.is_some() { - return Ok(maybe_pending) + Ok(maybe_pending) } else { self.local_pending_block().await } @@ -163,7 +173,7 @@ where None => return Ok(None), }; - Ok(self.cache().get_sealed_block(block_hash).await?) + Ok(self.cache().get_sealed_block_with_senders(block_hash).await?) } /// Returns the populated rpc block object for the given block id. @@ -175,7 +185,7 @@ where block_id: impl Into, full: bool, ) -> EthResult> { - let block = match self.block(block_id).await? { + let block = match self.block_with_senders(block_id).await? { Some(block) => block, None => return Ok(None), }; @@ -184,7 +194,7 @@ where .provider() .header_td_by_number(block.number)? 
.ok_or(EthApiError::UnknownBlockNumber)?; - let block = from_block(block.into(), total_difficulty, full.into(), Some(block_hash))?; + let block = from_block(block.unseal(), total_difficulty, full.into(), Some(block_hash))?; Ok(Some(block.into())) } } diff --git a/crates/rpc/rpc/src/eth/api/call.rs b/crates/rpc/rpc/src/eth/api/call.rs index ce53f247d2bb..a90d64ecc5cf 100644 --- a/crates/rpc/rpc/src/eth/api/call.rs +++ b/crates/rpc/rpc/src/eth/api/call.rs @@ -21,7 +21,6 @@ use reth_rpc_types::{ state::StateOverride, AccessListWithGasUsed, BlockError, Bundle, CallRequest, EthCallResponse, StateContext, }; -use reth_rpc_types_compat::log::{from_primitive_access_list, to_primitive_access_list}; use reth_transaction_pool::TransactionPool; use revm::{ db::{CacheDB, DatabaseRef}, @@ -392,8 +391,7 @@ where let initial = request.access_list.take().unwrap_or_default(); let precompiles = get_precompiles(env.cfg.spec_id); - let mut inspector = - AccessListInspector::new(to_primitive_access_list(initial), from, to, precompiles); + let mut inspector = AccessListInspector::new(initial, from, to, precompiles); let (result, env) = inspect(&mut db, env, &mut inspector)?; match result.result { @@ -410,10 +408,10 @@ where let access_list = inspector.into_access_list(); // calculate the gas used using the access list - request.access_list = Some(from_primitive_access_list(access_list.clone())); + request.access_list = Some(access_list.clone()); let gas_used = self.estimate_gas_with(env.cfg, env.block, request, db.db.state())?; - Ok(AccessListWithGasUsed { access_list: from_primitive_access_list(access_list), gas_used }) + Ok(AccessListWithGasUsed { access_list, gas_used }) } } diff --git a/crates/rpc/rpc/src/eth/api/mod.rs b/crates/rpc/rpc/src/eth/api/mod.rs index b30b4db562db..4ff22571dec9 100644 --- a/crates/rpc/rpc/src/eth/api/mod.rs +++ b/crates/rpc/rpc/src/eth/api/mod.rs @@ -15,7 +15,7 @@ use reth_interfaces::RethResult; use reth_network_api::NetworkInfo; use reth_primitives::{ revm_primitives::{BlockEnv, CfgEnv}, - Address, BlockId, BlockNumberOrTag, ChainInfo, SealedBlock, B256, U256, U64, + Address, BlockId, BlockNumberOrTag, ChainInfo, SealedBlockWithSenders, B256, U256, U64, }; use reth_provider::{ BlockReaderIdExt, ChainSpecProvider, EvmEnvProvider, StateProviderBox, StateProviderFactory, @@ -206,7 +206,7 @@ where /// Returns the state at the given [BlockId] enum. /// /// Note: if not [BlockNumberOrTag::Pending] then this will only return canonical state. See also - pub fn state_at_block_id(&self, at: BlockId) -> EthResult> { + pub fn state_at_block_id(&self, at: BlockId) -> EthResult { Ok(self.provider().state_by_block_id(at)?) } @@ -216,7 +216,7 @@ where pub fn state_at_block_id_or_latest( &self, block_id: Option, - ) -> EthResult> { + ) -> EthResult { if let Some(block_id) = block_id { self.state_at_block_id(block_id) } else { @@ -225,12 +225,12 @@ where } /// Returns the state at the given block number - pub fn state_at_hash(&self, block_hash: B256) -> RethResult> { + pub fn state_at_hash(&self, block_hash: B256) -> RethResult { Ok(self.provider().history_by_block_hash(block_hash)?) } /// Returns the _latest_ state - pub fn latest_state(&self) -> RethResult> { + pub fn latest_state(&self) -> RethResult { Ok(self.provider().latest()?) } } @@ -246,7 +246,7 @@ where /// /// If no pending block is available, this will derive it from the `latest` block pub(crate) fn pending_block_env_and_cfg(&self) -> EthResult { - let origin = if let Some(pending) = self.provider().pending_block()? 
{ + let origin = if let Some(pending) = self.provider().pending_block_with_senders()? { PendingBlockEnvOrigin::ActualPending(pending) } else { // no pending block from the CL yet, so we use the latest block and modify the env @@ -281,7 +281,7 @@ where } /// Returns the locally built pending block - pub(crate) async fn local_pending_block(&self) -> EthResult> { + pub(crate) async fn local_pending_block(&self) -> EthResult> { let pending = self.pending_block_env_and_cfg()?; if pending.origin.is_actual_pending() { return Ok(pending.origin.into_actual_pending()) diff --git a/crates/rpc/rpc/src/eth/api/pending_block.rs b/crates/rpc/rpc/src/eth/api/pending_block.rs index 827dfec1a17b..3c6f6b58793c 100644 --- a/crates/rpc/rpc/src/eth/api/pending_block.rs +++ b/crates/rpc/rpc/src/eth/api/pending_block.rs @@ -9,7 +9,7 @@ use reth_primitives::{ BlockEnv, CfgEnv, EVMError, Env, InvalidTransaction, ResultAndState, SpecId, }, Block, BlockId, BlockNumberOrTag, ChainSpec, Header, IntoRecoveredTransaction, Receipt, - Receipts, SealedBlock, SealedHeader, B256, EMPTY_OMMER_ROOT_HASH, U256, + Receipts, SealedBlockWithSenders, SealedHeader, B256, EMPTY_OMMER_ROOT_HASH, U256, }; use reth_provider::{BundleStateWithReceipts, ChainSpecProvider, StateProviderFactory}; use reth_revm::{ @@ -42,7 +42,7 @@ impl PendingBlockEnv { self, client: &Client, pool: &Pool, - ) -> EthResult + ) -> EthResult where Client: StateProviderFactory + ChainSpecProvider, Pool: TransactionPool, @@ -61,6 +61,7 @@ impl PendingBlockEnv { let block_number = block_env.number.to::(); let mut executed_txs = Vec::new(); + let mut senders = Vec::new(); let mut best_txs = pool.best_transactions_with_base_fee(base_fee); let (withdrawals, withdrawals_root) = match origin { @@ -176,7 +177,9 @@ impl PendingBlockEnv { })); // append transaction to the list of executed transactions - executed_txs.push(tx.into_signed()); + let (tx, sender) = tx.to_components(); + executed_txs.push(tx); + senders.push(sender); } // executes the withdrawals and commits them to the Database and BundleState. @@ -236,9 +239,7 @@ impl PendingBlockEnv { // seal the block let block = Block { header, body: executed_txs, ommers: vec![], withdrawals }; - let sealed_block = block.seal_slow(); - - Ok(sealed_block) + Ok(SealedBlockWithSenders { block: block.seal_slow(), senders }) } } @@ -286,7 +287,7 @@ where #[derive(Clone, Debug)] pub(crate) enum PendingBlockEnvOrigin { /// The pending block as received from the CL. - ActualPending(SealedBlock), + ActualPending(SealedBlockWithSenders), /// The header of the latest block DerivedFromLatest(SealedHeader), } @@ -298,7 +299,7 @@ impl PendingBlockEnvOrigin { } /// Consumes the type and returns the actual pending block. 
- pub(crate) fn into_actual_pending(self) -> Option { + pub(crate) fn into_actual_pending(self) -> Option { match self { PendingBlockEnvOrigin::ActualPending(block) => Some(block), _ => None, @@ -337,7 +338,7 @@ impl PendingBlockEnvOrigin { #[derive(Debug)] pub(crate) struct PendingBlock { /// The cached pending block - pub(crate) block: SealedBlock, + pub(crate) block: SealedBlockWithSenders, /// Timestamp when the pending block is considered outdated pub(crate) expires_at: Instant, } diff --git a/crates/rpc/rpc/src/eth/api/transactions.rs b/crates/rpc/rpc/src/eth/api/transactions.rs index 78307ee33372..58af7c7697a9 100644 --- a/crates/rpc/rpc/src/eth/api/transactions.rs +++ b/crates/rpc/rpc/src/eth/api/transactions.rs @@ -51,7 +51,7 @@ use revm::L1BlockInfo; use std::ops::Div; /// Helper alias type for the state's [CacheDB] -pub(crate) type StateCacheDB<'r> = CacheDB>>; +pub(crate) type StateCacheDB = CacheDB>; /// Commonly used transaction related functions for the [EthApi] type in the `eth_` namespace. /// @@ -63,17 +63,17 @@ pub trait EthTransactions: Send + Sync { fn call_gas_limit(&self) -> u64; /// Returns the state at the given [BlockId] - fn state_at(&self, at: BlockId) -> EthResult>; + fn state_at(&self, at: BlockId) -> EthResult; /// Executes the closure with the state that corresponds to the given [BlockId]. fn with_state_at_block(&self, at: BlockId, f: F) -> EthResult where - F: FnOnce(StateProviderBox<'_>) -> EthResult; + F: FnOnce(StateProviderBox) -> EthResult; /// Executes the closure with the state that corresponds to the given [BlockId] on a new task async fn spawn_with_state_at_block(&self, at: BlockId, f: F) -> EthResult where - F: FnOnce(StateProviderBox<'_>) -> EthResult + Send + 'static, + F: FnOnce(StateProviderBox) -> EthResult + Send + 'static, T: Send + 'static; /// Returns the revm evm env for the requested [BlockId] @@ -154,7 +154,7 @@ pub trait EthTransactions: Send + Sync { f: F, ) -> EthResult where - F: for<'r> FnOnce(StateCacheDB<'r>, Env) -> EthResult + Send + 'static, + F: FnOnce(StateCacheDB, Env) -> EthResult + Send + 'static, R: Send + 'static; /// Executes the call request at the given [BlockId]. @@ -175,7 +175,7 @@ pub trait EthTransactions: Send + Sync { inspector: I, ) -> EthResult<(ResultAndState, Env)> where - I: for<'r> Inspector> + Send + 'static; + I: Inspector + Send + 'static; /// Executes the transaction on top of the given [BlockId] with a tracer configured by the /// config. 
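A toy model, not reth code, of why the hunks above can drop the higher-ranked `for<'a>`/`for<'r>` closure bounds: once `StateProviderBox` (and hence `StateCacheDB`) has no lifetime parameter, the state can be handed to a closure by value and moved across task boundaries with ordinary `Send + 'static` bounds. `Db` below is a hypothetical stand-in for the owned `CacheDB<StateProviderBox>`:

// Sketch only: an owned database argument needs just a plain FnOnce bound.
struct Db; // stand-in for CacheDB<StateProviderBox>

fn spawn_with_db<F, R>(db: Db, f: F) -> std::thread::JoinHandle<R>
where
    F: FnOnce(Db) -> R + Send + 'static,
    R: Send + 'static,
{
    // With the old borrowed form (for<'r> FnOnce(StateCacheDB<'r>, ...)), moving
    // the state into another thread/task required threading lifetimes through.
    std::thread::spawn(move || f(db))
}

fn main() {
    assert_eq!(spawn_with_db(Db, |_db| 42u64).join().unwrap(), 42);
}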
@@ -209,9 +209,7 @@ pub trait EthTransactions: Send + Sync { f: F, ) -> EthResult where - F: for<'a> FnOnce(TracingInspector, ResultAndState, StateCacheDB<'a>) -> EthResult - + Send - + 'static, + F: FnOnce(TracingInspector, ResultAndState, StateCacheDB) -> EthResult + Send + 'static, R: Send + 'static; /// Fetches the transaction and the transaction's block @@ -236,12 +234,7 @@ pub trait EthTransactions: Send + Sync { f: F, ) -> EthResult> where - F: for<'a> FnOnce( - TransactionInfo, - TracingInspector, - ResultAndState, - StateCacheDB<'a>, - ) -> EthResult + F: FnOnce(TransactionInfo, TracingInspector, ResultAndState, StateCacheDB) -> EthResult + Send + 'static, R: Send + 'static; @@ -269,7 +262,7 @@ pub trait EthTransactions: Send + Sync { TracingInspector, ExecutionResult, &'a State, - &'a CacheDB>>, + &'a CacheDB>, ) -> EthResult + Send + 'static, @@ -293,7 +286,7 @@ pub trait EthTransactions: Send + Sync { TracingInspector, ExecutionResult, &'a State, - &'a CacheDB>>, + &'a CacheDB>, ) -> EthResult + Send + 'static, @@ -312,13 +305,13 @@ where self.inner.gas_cap } - fn state_at(&self, at: BlockId) -> EthResult> { + fn state_at(&self, at: BlockId) -> EthResult { self.state_at_block_id(at) } fn with_state_at_block(&self, at: BlockId, f: F) -> EthResult where - F: FnOnce(StateProviderBox<'_>) -> EthResult, + F: FnOnce(StateProviderBox) -> EthResult, { let state = self.state_at(at)?; f(state) @@ -326,7 +319,7 @@ where async fn spawn_with_state_at_block(&self, at: BlockId, f: F) -> EthResult where - F: FnOnce(StateProviderBox<'_>) -> EthResult + Send + 'static, + F: FnOnce(StateProviderBox) -> EthResult + Send + 'static, T: Send + 'static, { self.spawn_tracing_task_with(move |this| { @@ -595,7 +588,7 @@ where f: F, ) -> EthResult where - F: for<'r> FnOnce(StateCacheDB<'r>, Env) -> EthResult + Send + 'static, + F: FnOnce(StateCacheDB, Env) -> EthResult + Send + 'static, R: Send + 'static, { let (cfg, block_env, at) = self.evm_env_at(at).await?; @@ -638,7 +631,7 @@ where inspector: I, ) -> EthResult<(ResultAndState, Env)> where - I: for<'r> Inspector> + Send + 'static, + I: Inspector + Send + 'static, { self.spawn_with_call_at(request, at, overrides, move |db, env| inspect(db, env, inspector)) .await @@ -672,9 +665,7 @@ where f: F, ) -> EthResult where - F: for<'a> FnOnce(TracingInspector, ResultAndState, StateCacheDB<'a>) -> EthResult - + Send - + 'static, + F: FnOnce(TracingInspector, ResultAndState, StateCacheDB) -> EthResult + Send + 'static, R: Send + 'static, { self.spawn_with_state_at_block(at, move |state| { @@ -712,12 +703,7 @@ where f: F, ) -> EthResult> where - F: for<'a> FnOnce( - TransactionInfo, - TracingInspector, - ResultAndState, - StateCacheDB<'a>, - ) -> EthResult + F: FnOnce(TransactionInfo, TracingInspector, ResultAndState, StateCacheDB) -> EthResult + Send + 'static, R: Send + 'static, @@ -764,7 +750,7 @@ where TracingInspector, ExecutionResult, &'a State, - &'a CacheDB>>, + &'a CacheDB>, ) -> EthResult + Send + 'static, @@ -786,7 +772,7 @@ where TracingInspector, ExecutionResult, &'a State, - &'a CacheDB>>, + &'a CacheDB>, ) -> EthResult + Send + 'static, diff --git a/crates/rpc/rpc/src/eth/cache/mod.rs b/crates/rpc/rpc/src/eth/cache/mod.rs index fb1f65d1db59..2ca9406cb10a 100644 --- a/crates/rpc/rpc/src/eth/cache/mod.rs +++ b/crates/rpc/rpc/src/eth/cache/mod.rs @@ -2,9 +2,12 @@ use futures::{future::Either, Stream, StreamExt}; use reth_interfaces::provider::{ProviderError, ProviderResult}; -use reth_primitives::{Block, Receipt, SealedBlock, TransactionSigned, B256}; 
+use reth_primitives::{ + Block, BlockHashOrNumber, BlockWithSenders, Receipt, SealedBlock, SealedBlockWithSenders, + TransactionSigned, TransactionSignedEcRecovered, B256, +}; use reth_provider::{ - BlockReader, BlockSource, CanonStateNotification, EvmEnvProvider, StateProviderFactory, + BlockReader, CanonStateNotification, EvmEnvProvider, StateProviderFactory, TransactionVariant, }; use reth_tasks::{TaskSpawner, TokioTaskExecutor}; use revm::primitives::{BlockEnv, CfgEnv}; @@ -29,13 +32,13 @@ mod metrics; mod multi_consumer; pub use multi_consumer::MultiConsumerLruCache; -/// The type that can send the response to a requested [Block] -type BlockResponseSender = oneshot::Sender>>; - /// The type that can send the response to a requested [Block] type BlockTransactionsResponseSender = oneshot::Sender>>>; +/// The type that can send the response to a requested [BlockWithSenders] +type BlockWithSendersResponseSender = oneshot::Sender>>; + /// The type that can send the response to the requested receipts of a block. type ReceiptsResponseSender = oneshot::Sender>>>; @@ -44,9 +47,9 @@ type EnvResponseSender = oneshot::Sender>; type BlockLruCache = MultiConsumerLruCache< B256, - Block, + BlockWithSenders, L, - Either, + Either, >; type ReceiptsLruCache = MultiConsumerLruCache, L, ReceiptsResponseSender>; @@ -128,26 +131,30 @@ impl EthStateCache { /// Requests the [Block] for the block hash /// /// Returns `None` if the block does not exist. - pub(crate) async fn get_block(&self, block_hash: B256) -> ProviderResult> { + pub async fn get_block(&self, block_hash: B256) -> ProviderResult> { let (response_tx, rx) = oneshot::channel(); - let _ = self.to_service.send(CacheAction::GetBlock { block_hash, response_tx }); - rx.await.map_err(|_| ProviderError::CacheServiceUnavailable)? + let _ = self.to_service.send(CacheAction::GetBlockWithSenders { block_hash, response_tx }); + let block_with_senders_res = + rx.await.map_err(|_| ProviderError::CacheServiceUnavailable)?; + + if let Ok(Some(block_with_senders)) = block_with_senders_res { + Ok(Some(block_with_senders.block)) + } else { + Ok(None) + } } /// Requests the [Block] for the block hash, sealed with the given block hash. /// /// Returns `None` if the block does not exist. - pub(crate) async fn get_sealed_block( - &self, - block_hash: B256, - ) -> ProviderResult> { + pub async fn get_sealed_block(&self, block_hash: B256) -> ProviderResult> { Ok(self.get_block(block_hash).await?.map(|block| block.seal(block_hash))) } /// Requests the transactions of the [Block] /// /// Returns `None` if the block does not exist. - pub(crate) async fn get_block_transactions( + pub async fn get_block_transactions( &self, block_hash: B256, ) -> ProviderResult>> { @@ -156,8 +163,21 @@ impl EthStateCache { rx.await.map_err(|_| ProviderError::CacheServiceUnavailable)? } + /// Requests the ecrecovered transactions of the [Block] + /// + /// Returns `None` if the block does not exist. + pub async fn get_block_transactions_ecrecovered( + &self, + block_hash: B256, + ) -> ProviderResult>> { + Ok(self + .get_block_with_senders(block_hash) + .await? + .map(|block| block.into_transactions_ecrecovered().collect())) + } + /// Fetches both transactions and receipts for the given block hash. 
- pub(crate) async fn get_transactions_and_receipts( + pub async fn get_transactions_and_receipts( &self, block_hash: B256, ) -> ProviderResult, Vec)>> { @@ -169,20 +189,39 @@ impl EthStateCache { Ok(transactions.zip(receipts)) } - /// Requests the [Receipt] for the block hash + /// Requests the [BlockWithSenders] for the block hash /// - /// Returns `None` if the block was not found. - pub(crate) async fn get_receipts( + /// Returns `None` if the block does not exist. + pub async fn get_block_with_senders( &self, block_hash: B256, - ) -> ProviderResult>> { + ) -> ProviderResult> { + let (response_tx, rx) = oneshot::channel(); + let _ = self.to_service.send(CacheAction::GetBlockWithSenders { block_hash, response_tx }); + rx.await.map_err(|_| ProviderError::CacheServiceUnavailable)? + } + + /// Requests the [SealedBlockWithSenders] for the block hash + /// + /// Returns `None` if the block does not exist. + pub async fn get_sealed_block_with_senders( + &self, + block_hash: B256, + ) -> ProviderResult> { + Ok(self.get_block_with_senders(block_hash).await?.map(|block| block.seal(block_hash))) + } + + /// Requests the [Receipt] for the block hash + /// + /// Returns `None` if the block was not found. + pub async fn get_receipts(&self, block_hash: B256) -> ProviderResult>> { let (response_tx, rx) = oneshot::channel(); let _ = self.to_service.send(CacheAction::GetReceipts { block_hash, response_tx }); rx.await.map_err(|_| ProviderError::CacheServiceUnavailable)? } /// Fetches both receipts and block for the given block hash. - pub(crate) async fn get_block_and_receipts( + pub async fn get_block_and_receipts( &self, block_hash: B256, ) -> ProviderResult)>> { @@ -198,7 +237,7 @@ impl EthStateCache { /// /// Returns an error if the corresponding header (required for populating the envs) was not /// found. - pub(crate) async fn get_evm_env(&self, block_hash: B256) -> ProviderResult<(CfgEnv, BlockEnv)> { + pub async fn get_evm_env(&self, block_hash: B256) -> ProviderResult<(CfgEnv, BlockEnv)> { let (response_tx, rx) = oneshot::channel(); let _ = self.to_service.send(CacheAction::GetEnv { block_hash, response_tx }); rx.await.map_err(|_| ProviderError::CacheServiceUnavailable)? 
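Editor's note: the cache handle methods above all share one shape — push a `CacheAction` onto an unbounded channel, then await a oneshot reply from the service task (a closed channel is mapped to `ProviderError::CacheServiceUnavailable`). Below is a minimal, self-contained sketch of that request/response pattern using only tokio; every type name in it is an illustrative stand-in, not the reth API.

```rust
// Sketch of the oneshot request/response pattern used by EthStateCache.
// Assumes tokio with the "full" feature; all names here are hypothetical.
use tokio::sync::{mpsc, oneshot};

#[derive(Clone, Debug)]
struct Block; // stand-in for a cached block type

enum CacheAction {
    GetBlock { key: u64, response_tx: oneshot::Sender<Option<Block>> },
}

#[derive(Clone)]
struct CacheHandle {
    to_service: mpsc::UnboundedSender<CacheAction>,
}

impl CacheHandle {
    /// Sends a request to the service task and awaits the oneshot response.
    async fn get_block(&self, key: u64) -> Option<Block> {
        let (response_tx, rx) = oneshot::channel();
        // If the service task is gone, this sketch treats it as a miss;
        // the real code surfaces ProviderError::CacheServiceUnavailable.
        let _ = self.to_service.send(CacheAction::GetBlock { key, response_tx });
        rx.await.ok().flatten()
    }
}

#[tokio::main]
async fn main() {
    let (to_service, mut from_handles) = mpsc::unbounded_channel();
    // The service task owns all cache state and serves every cloned handle.
    tokio::spawn(async move {
        while let Some(CacheAction::GetBlock { key, response_tx }) = from_handles.recv().await {
            let hit = (key == 1).then(|| Block);
            let _ = response_tx.send(hit);
        }
    });
    let handle = CacheHandle { to_service };
    assert!(handle.get_block(1).await.is_some());
    assert!(handle.get_block(2).await.is_none());
}
```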
@@ -228,7 +267,7 @@ pub(crate) struct EthStateCacheService< LimitReceipts = ByLength, LimitEnvs = ByLength, > where - LimitBlocks: Limiter, + LimitBlocks: Limiter, LimitReceipts: Limiter>, LimitEnvs: Limiter, { @@ -255,17 +294,18 @@ where Provider: StateProviderFactory + BlockReader + EvmEnvProvider + Clone + Unpin + 'static, Tasks: TaskSpawner + Clone + 'static, { - fn on_new_block(&mut self, block_hash: B256, res: ProviderResult>) { + fn on_new_block(&mut self, block_hash: B256, res: ProviderResult>) { if let Some(queued) = self.full_block_cache.remove(&block_hash) { // send the response to queued senders for tx in queued { match tx { - Either::Left(block_tx) => { - let _ = block_tx.send(res.clone()); + Either::Left(block_with_senders) => { + let _ = block_with_senders.send(res.clone()); } Either::Right(transaction_tx) => { let _ = transaction_tx.send( - res.clone().map(|maybe_block| maybe_block.map(|block| block.body)), + res.clone() + .map(|maybe_block| maybe_block.map(|block| block.block.body)), ); } } @@ -316,8 +356,7 @@ where } Some(action) => { match action { - CacheAction::GetBlock { block_hash, response_tx } => { - // check if block is cached + CacheAction::GetBlockWithSenders { block_hash, response_tx } => { if let Some(block) = this.full_block_cache.get(&block_hash).cloned() { let _ = response_tx.send(Ok(Some(block))); continue @@ -333,10 +372,14 @@ where let _permit = rate_limiter.acquire().await; // Only look in the database to prevent situations where // looking up the tree is blocking - let res = provider - .find_block_by_hash(block_hash, BlockSource::Database); - let _ = action_tx - .send(CacheAction::BlockResult { block_hash, res }); + let block_sender = provider.block_with_senders( + BlockHashOrNumber::Hash(block_hash), + TransactionVariant::WithHash, + ); + let _ = action_tx.send(CacheAction::BlockWithSendersResult { + block_hash, + res: block_sender, + }); })); } } @@ -357,10 +400,14 @@ where let _permit = rate_limiter.acquire().await; // Only look in the database to prevent situations where // looking up the tree is blocking - let res = provider - .find_block_by_hash(block_hash, BlockSource::Database); - let _ = action_tx - .send(CacheAction::BlockResult { block_hash, res }); + let res = provider.block_with_senders( + BlockHashOrNumber::Hash(block_hash), + TransactionVariant::WithHash, + ); + let _ = action_tx.send(CacheAction::BlockWithSendersResult { + block_hash, + res, + }); })); } } @@ -413,12 +460,20 @@ where })); } } - CacheAction::BlockResult { block_hash, res } => { - this.on_new_block(block_hash, res); - } CacheAction::ReceiptsResult { block_hash, res } => { this.on_new_receipts(block_hash, res); } + CacheAction::BlockWithSendersResult { block_hash, res } => match res { + Ok(Some(block_with_senders)) => { + this.on_new_block(block_hash, Ok(Some(block_with_senders))); + } + Ok(None) => { + this.on_new_block(block_hash, Ok(None)); + } + Err(e) => { + this.on_new_block(block_hash, Err(e)); + } + }, CacheAction::EnvResult { block_hash, res } => { let res = *res; if let Some(queued) = this.evm_env_cache.remove(&block_hash) { @@ -457,14 +512,14 @@ where /// All message variants sent through the channel enum CacheAction { - GetBlock { block_hash: B256, response_tx: BlockResponseSender }, + GetBlockWithSenders { block_hash: B256, response_tx: BlockWithSendersResponseSender }, GetBlockTransactions { block_hash: B256, response_tx: BlockTransactionsResponseSender }, GetEnv { block_hash: B256, response_tx: EnvResponseSender }, GetReceipts { block_hash: B256,
response_tx: ReceiptsResponseSender }, - BlockResult { block_hash: B256, res: ProviderResult> }, + BlockWithSendersResult { block_hash: B256, res: ProviderResult> }, ReceiptsResult { block_hash: B256, res: ProviderResult>> }, EnvResult { block_hash: B256, res: Box> }, - CacheNewCanonicalChain { blocks: Vec, receipts: Vec }, + CacheNewCanonicalChain { blocks: Vec, receipts: Vec }, } struct BlockReceipts { @@ -483,13 +538,13 @@ where // we're only interested in new committed blocks let (blocks, state) = committed.inner(); - let blocks = blocks.iter().map(|(_, block)| block.block.clone()).collect::>(); + let blocks = blocks.iter().map(|(_, block)| block.clone()).collect::>(); // also cache all receipts of the blocks let mut receipts = Vec::with_capacity(blocks.len()); for block in &blocks { let block_receipts = BlockReceipts { - block_hash: block.hash, + block_hash: block.block.hash, receipts: state.receipts_by_block(block.number).to_vec(), }; receipts.push(block_receipts); diff --git a/crates/rpc/rpc/src/result.rs b/crates/rpc/rpc/src/result.rs index 43ceb5d94945..c37ced80179e 100644 --- a/crates/rpc/rpc/src/result.rs +++ b/crates/rpc/rpc/src/result.rs @@ -8,12 +8,11 @@ use reth_rpc_types::engine::PayloadError; use std::fmt::Display; /// Helper trait to easily convert various `Result` types into [`RpcResult`] -pub trait ToRpcResult { +pub trait ToRpcResult: Sized { /// Converts the error of the [Result] to an [RpcResult] via the `Err` [Display] impl. fn to_rpc_result(self) -> RpcResult where Err: Display, - Self: Sized, { self.map_internal_err(|err| err.to_string()) } diff --git a/crates/rpc/rpc/src/trace.rs b/crates/rpc/rpc/src/trace.rs index 5399ec49071a..7f0e8c5e543a 100644 --- a/crates/rpc/rpc/src/trace.rs +++ b/crates/rpc/rpc/src/trace.rs @@ -68,7 +68,7 @@ where /// Executes the given call and returns a number of possible traces for it. 
pub async fn trace_call(&self, trace_request: TraceCallRequest) -> EthResult { let at = trace_request.block_id.unwrap_or(BlockId::Number(BlockNumberOrTag::Latest)); - let config = tracing_config(&trace_request.trace_types); + let config = TracingInspectorConfig::from_parity_config(&trace_request.trace_types); let overrides = EvmOverrides::new(trace_request.state_overrides, trace_request.block_overrides); let mut inspector = TracingInspector::new(config); @@ -103,7 +103,7 @@ where let tx = tx_env_with_recovered(&tx.into_ecrecovered_transaction()); let env = Env { cfg, block, tx }; - let config = tracing_config(&trace_types); + let config = TracingInspectorConfig::from_parity_config(&trace_types); self.inner .eth_api @@ -148,7 +148,7 @@ where &mut db, Default::default(), )?; - let config = tracing_config(&trace_types); + let config = TracingInspectorConfig::from_parity_config(&trace_types); let mut inspector = TracingInspector::new(config); let (res, _) = inspect(&mut db, env, &mut inspector)?; @@ -180,7 +180,7 @@ where hash: B256, trace_types: HashSet, ) -> EthResult { - let config = tracing_config(&trace_types); + let config = TracingInspectorConfig::from_parity_config(&trace_types); self.inner .eth_api .spawn_trace_transaction_in_block(hash, config, move |_, inspector, res, db| { @@ -403,7 +403,7 @@ where .eth_api .trace_block_with( block_id, - tracing_config(&trace_types), + TracingInspectorConfig::from_parity_config(&trace_types), move |tx_info, inspector, res, state, db| { let mut full_trace = inspector.into_parity_builder().into_trace_results(&res, &trace_types); @@ -549,18 +549,6 @@ struct TraceApiInner { blocking_task_guard: BlockingTaskGuard, } -/// Returns the [TracingInspectorConfig] depending on the enabled [TraceType]s -/// -/// Note: the parity statediffs can be populated entirely via the execution result, so we don't need -/// statediff recording -#[inline] -fn tracing_config(trace_types: &HashSet) -> TracingInspectorConfig { - let needs_vm_trace = trace_types.contains(&TraceType::VmTrace); - TracingInspectorConfig::default_parity() - .set_steps(needs_vm_trace) - .set_memory_snapshots(needs_vm_trace) -} - /// Helper to construct a [`LocalizedTransactionTrace`] that describes a reward to the block /// beneficiary. 
fn reward_trace(header: &SealedHeader, reward: RewardAction) -> LocalizedTransactionTrace { @@ -578,32 +566,3 @@ fn reward_trace(header: &SealedHeader, reward: RewardAction) -> LocalizedTransac }, } } - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_parity_config() { - let mut s = HashSet::new(); - s.insert(TraceType::StateDiff); - let config = tracing_config(&s); - // not required - assert!(!config.record_steps); - assert!(!config.record_state_diff); - - let mut s = HashSet::new(); - s.insert(TraceType::VmTrace); - let config = tracing_config(&s); - assert!(config.record_steps); - assert!(!config.record_state_diff); - - let mut s = HashSet::new(); - s.insert(TraceType::VmTrace); - s.insert(TraceType::StateDiff); - let config = tracing_config(&s); - assert!(config.record_steps); - // not required for StateDiff - assert!(!config.record_state_diff); - } -} diff --git a/crates/snapshot/src/segments/headers.rs b/crates/snapshot/src/segments/headers.rs index d6852c73dec3..2bf73b2f7cf6 100644 --- a/crates/snapshot/src/segments/headers.rs +++ b/crates/snapshot/src/segments/headers.rs @@ -37,7 +37,7 @@ impl Segment for Headers { fn snapshot( &self, - provider: &DatabaseProviderRO<'_, DB>, + provider: &DatabaseProviderRO, directory: impl AsRef, range: RangeInclusive, ) -> ProviderResult<()> { diff --git a/crates/snapshot/src/segments/mod.rs b/crates/snapshot/src/segments/mod.rs index ec9061ebcb19..88cdc52ef685 100644 --- a/crates/snapshot/src/segments/mod.rs +++ b/crates/snapshot/src/segments/mod.rs @@ -31,7 +31,7 @@ pub trait Segment: Default { /// file's save location. fn snapshot( &self, - provider: &DatabaseProviderRO<'_, DB>, + provider: &DatabaseProviderRO, directory: impl AsRef, range: RangeInclusive, ) -> ProviderResult<()>; @@ -42,7 +42,7 @@ pub trait Segment: Default { /// Generates the dataset to train a zstd dictionary with the most recent rows (at most 1000). fn dataset_for_compression>( &self, - provider: &DatabaseProviderRO<'_, DB>, + provider: &DatabaseProviderRO, range: &RangeInclusive, range_len: usize, ) -> ProviderResult>> { @@ -58,7 +58,7 @@ pub trait Segment: Default { /// Returns a [`NippyJar`] according to the desired configuration. The `directory` parameter /// determines the snapshot file's save location. 
pub(crate) fn prepare_jar( - provider: &DatabaseProviderRO<'_, DB>, + provider: &DatabaseProviderRO, directory: impl AsRef, segment: SnapshotSegment, segment_config: SegmentConfig, diff --git a/crates/snapshot/src/segments/receipts.rs b/crates/snapshot/src/segments/receipts.rs index c40949a0dd65..4b82a7133a4e 100644 --- a/crates/snapshot/src/segments/receipts.rs +++ b/crates/snapshot/src/segments/receipts.rs @@ -34,7 +34,7 @@ impl Segment for Receipts { fn snapshot( &self, - provider: &DatabaseProviderRO<'_, DB>, + provider: &DatabaseProviderRO, directory: impl AsRef, block_range: RangeInclusive, ) -> ProviderResult<()> { diff --git a/crates/snapshot/src/segments/transactions.rs b/crates/snapshot/src/segments/transactions.rs index 4367f1ce0a7f..585bc9625e42 100644 --- a/crates/snapshot/src/segments/transactions.rs +++ b/crates/snapshot/src/segments/transactions.rs @@ -34,7 +34,7 @@ impl Segment for Transactions { fn snapshot( &self, - provider: &DatabaseProviderRO<'_, DB>, + provider: &DatabaseProviderRO, directory: impl AsRef, block_range: RangeInclusive, ) -> ProviderResult<()> { diff --git a/crates/snapshot/src/snapshotter.rs b/crates/snapshot/src/snapshotter.rs index d9c1f6aeb003..729b0c1b974f 100644 --- a/crates/snapshot/src/snapshotter.rs +++ b/crates/snapshot/src/snapshotter.rs @@ -5,14 +5,13 @@ use reth_db::database::Database; use reth_interfaces::{RethError, RethResult}; use reth_primitives::{ snapshot::{iter_snapshots, HighestSnapshots}, - BlockNumber, ChainSpec, TxNumber, + BlockNumber, TxNumber, }; use reth_provider::{BlockReader, DatabaseProviderRO, ProviderFactory, TransactionsProviderExt}; use std::{ collections::HashMap, ops::RangeInclusive, path::{Path, PathBuf}, - sync::Arc, }; use tokio::sync::watch; use tracing::warn; @@ -94,15 +93,14 @@ impl SnapshotTargets { impl Snapshotter { /// Creates a new [Snapshotter]. 
pub fn new( - db: DB, + provider_factory: ProviderFactory, snapshots_path: impl AsRef, - chain_spec: Arc, block_interval: u64, ) -> RethResult { let (highest_snapshots_notifier, highest_snapshots_tracker) = watch::channel(None); let mut snapshotter = Self { - provider_factory: ProviderFactory::new(db, chain_spec), + provider_factory, snapshots_path: snapshots_path.as_ref().into(), highest_snapshots: HighestSnapshots::default(), highest_snapshots_notifier, @@ -291,7 +289,7 @@ impl Snapshotter { fn get_snapshot_target_tx_range( &self, - provider: &DatabaseProviderRO<'_, DB>, + provider: &DatabaseProviderRO, block_to_tx_number_cache: &mut HashMap, highest_snapshot: Option, block_range: &RangeInclusive, @@ -329,16 +327,14 @@ mod tests { test_utils::{generators, generators::random_block_range}, RethError, }; - use reth_primitives::{snapshot::HighestSnapshots, B256, MAINNET}; - use reth_stages::test_utils::TestTransaction; + use reth_primitives::{snapshot::HighestSnapshots, B256}; + use reth_stages::test_utils::TestStageDB; #[test] fn new() { - let tx = TestTransaction::default(); + let db = TestStageDB::default(); let snapshots_dir = tempfile::TempDir::new().unwrap(); - let snapshotter = - Snapshotter::new(tx.inner_raw(), snapshots_dir.into_path(), MAINNET.clone(), 2) - .unwrap(); + let snapshotter = Snapshotter::new(db.factory, snapshots_dir.into_path(), 2).unwrap(); assert_eq!( *snapshotter.highest_snapshot_receiver().borrow(), @@ -348,16 +344,14 @@ mod tests { #[test] fn get_snapshot_targets() { - let tx = TestTransaction::default(); + let db = TestStageDB::default(); let snapshots_dir = tempfile::TempDir::new().unwrap(); let mut rng = generators::rng(); let blocks = random_block_range(&mut rng, 0..=3, B256::ZERO, 2..3); - tx.insert_blocks(blocks.iter(), None).expect("insert blocks"); + db.insert_blocks(blocks.iter(), None).expect("insert blocks"); - let mut snapshotter = - Snapshotter::new(tx.inner_raw(), snapshots_dir.into_path(), MAINNET.clone(), 2) - .unwrap(); + let mut snapshotter = Snapshotter::new(db.factory, snapshots_dir.into_path(), 2).unwrap(); // Snapshot targets has data per part up to the passed finalized block number, // respecting the block interval diff --git a/crates/stages/Cargo.toml b/crates/stages/Cargo.toml index 6da02ad00a70..890bc135ed02 100644 --- a/crates/stages/Cargo.toml +++ b/crates/stages/Cargo.toml @@ -50,6 +50,7 @@ aquamarine.workspace = true itertools.workspace = true rayon.workspace = true num-traits = "0.2.15" +auto_impl = "1" [dev-dependencies] # reth diff --git a/crates/stages/benches/criterion.rs b/crates/stages/benches/criterion.rs index 9e55781b7e74..2f73ec71f9f1 100644 --- a/crates/stages/benches/criterion.rs +++ b/crates/stages/benches/criterion.rs @@ -9,8 +9,8 @@ use reth_primitives::{stage::StageCheckpoint, MAINNET}; use reth_provider::ProviderFactory; use reth_stages::{ stages::{MerkleStage, SenderRecoveryStage, TotalDifficultyStage, TransactionLookupStage}, - test_utils::TestTransaction, - ExecInput, Stage, UnwindInput, + test_utils::TestStageDB, + ExecInput, Stage, StageExt, UnwindInput, }; use std::{path::PathBuf, sync::Arc}; @@ -123,9 +123,9 @@ fn measure_stage_with_path( label: String, ) where S: Clone + Stage, - F: Fn(S, &TestTransaction, StageRange), + F: Fn(S, &TestStageDB, StageRange), { - let tx = TestTransaction::new(&path); + let tx = TestStageDB::new(&path); let (input, _) = stage_range; group.bench_function(label, move |b| { @@ -136,9 +136,13 @@ fn measure_stage_with_path( }, |_| async { let mut stage = stage.clone(); - let factory 
= ProviderFactory::new(tx.tx.db(), MAINNET.clone()); + let factory = ProviderFactory::new(tx.factory.db(), MAINNET.clone()); let provider = factory.provider_rw().unwrap(); - stage.execute(&provider, input).await.unwrap(); + stage + .execute_ready(input) + .await + .and_then(|_| stage.execute(&provider, input)) + .unwrap(); provider.commit().unwrap(); }, ) @@ -153,7 +157,7 @@ fn measure_stage( label: String, ) where S: Clone + Stage, - F: Fn(S, &TestTransaction, StageRange), + F: Fn(S, &TestStageDB, StageRange), { let path = setup::txs_testdata(block_interval.end); diff --git a/crates/stages/benches/setup/account_hashing.rs b/crates/stages/benches/setup/account_hashing.rs index 341dbd42b61d..a94a8250aede 100644 --- a/crates/stages/benches/setup/account_hashing.rs +++ b/crates/stages/benches/setup/account_hashing.rs @@ -5,7 +5,7 @@ use reth_db::{ use reth_primitives::stage::StageCheckpoint; use reth_stages::{ stages::{AccountHashingStage, SeedOpts}, - test_utils::TestTransaction, + test_utils::TestStageDB, ExecInput, UnwindInput, }; use std::path::{Path, PathBuf}; @@ -31,8 +31,8 @@ pub fn prepare_account_hashing(num_blocks: u64) -> (PathBuf, AccountHashingStage fn find_stage_range(db: &Path) -> StageRange { let mut stage_range = None; - TestTransaction::new(db) - .tx + TestStageDB::new(db) + .factory .view(|tx| { let mut cursor = tx.cursor_read::()?; let from = cursor.first()?.unwrap().0; @@ -62,8 +62,8 @@ fn generate_testdata_db(num_blocks: u64) -> (PathBuf, StageRange) { // create the dirs std::fs::create_dir_all(&path).unwrap(); println!("Account Hashing testdata not found, generating to {:?}", path.display()); - let tx = TestTransaction::new(&path); - let provider = tx.inner_rw(); + let tx = TestStageDB::new(&path); + let provider = tx.provider_rw(); let _accounts = AccountHashingStage::seed(&provider, opts); provider.commit().expect("failed to commit"); } diff --git a/crates/stages/benches/setup/mod.rs b/crates/stages/benches/setup/mod.rs index f5c45be9b96e..3850ca44fa55 100644 --- a/crates/stages/benches/setup/mod.rs +++ b/crates/stages/benches/setup/mod.rs @@ -16,7 +16,7 @@ use reth_primitives::{Account, Address, SealedBlock, B256, MAINNET}; use reth_provider::ProviderFactory; use reth_stages::{ stages::{AccountHashingStage, StorageHashingStage}, - test_utils::TestTransaction, + test_utils::TestStageDB, ExecInput, Stage, UnwindInput, }; use reth_trie::StateRoot; @@ -34,24 +34,23 @@ pub(crate) type StageRange = (ExecInput, UnwindInput); pub(crate) fn stage_unwind>( stage: S, - tx: &TestTransaction, + db: &TestStageDB, range: StageRange, ) { let (_, unwind) = range; tokio::runtime::Runtime::new().unwrap().block_on(async { let mut stage = stage.clone(); - let factory = ProviderFactory::new(tx.tx.db(), MAINNET.clone()); + let factory = ProviderFactory::new(db.factory.db(), MAINNET.clone()); let provider = factory.provider_rw().unwrap(); // Clear previous run stage .unwind(&provider, unwind) - .await .map_err(|e| { format!( "{e}\nMake sure your test database at `{}` isn't too old and incompatible with newer stage changes.", - tx.path.as_ref().unwrap().display() + db.path.as_ref().unwrap().display() ) }) .unwrap(); @@ -62,27 +61,25 @@ pub(crate) fn stage_unwind>( pub(crate) fn unwind_hashes>( stage: S, - tx: &TestTransaction, + db: &TestStageDB, range: StageRange, ) { let (input, unwind) = range; - tokio::runtime::Runtime::new().unwrap().block_on(async { - let mut stage = stage.clone(); - let factory = ProviderFactory::new(tx.tx.db(), MAINNET.clone()); - let provider = 
factory.provider_rw().unwrap(); + let mut stage = stage.clone(); + let factory = ProviderFactory::new(db.factory.db(), MAINNET.clone()); + let provider = factory.provider_rw().unwrap(); - StorageHashingStage::default().unwind(&provider, unwind).await.unwrap(); - AccountHashingStage::default().unwind(&provider, unwind).await.unwrap(); + StorageHashingStage::default().unwind(&provider, unwind).unwrap(); + AccountHashingStage::default().unwind(&provider, unwind).unwrap(); - // Clear previous run - stage.unwind(&provider, unwind).await.unwrap(); + // Clear previous run + stage.unwind(&provider, unwind).unwrap(); - AccountHashingStage::default().execute(&provider, input).await.unwrap(); - StorageHashingStage::default().execute(&provider, input).await.unwrap(); + AccountHashingStage::default().execute(&provider, input).unwrap(); + StorageHashingStage::default().execute(&provider, input).unwrap(); - provider.commit().unwrap(); - }); + provider.commit().unwrap(); } // Helper for generating testdata for the benchmarks. @@ -108,7 +105,7 @@ pub(crate) fn txs_testdata(num_blocks: u64) -> PathBuf { // create the dirs std::fs::create_dir_all(&path).unwrap(); println!("Transactions testdata not found, generating to {:?}", path.display()); - let tx = TestTransaction::new(&path); + let tx = TestStageDB::new(&path); let accounts: BTreeMap = concat([ random_eoa_account_range(&mut rng, 0..n_eoa), @@ -130,7 +127,8 @@ pub(crate) fn txs_testdata(num_blocks: u64) -> PathBuf { tx.insert_accounts_and_storages(start_state.clone()).unwrap(); // make first block after genesis have valid state root - let (root, updates) = StateRoot::new(tx.inner_rw().tx_ref()).root_with_updates().unwrap(); + let (root, updates) = + StateRoot::new(tx.provider_rw().tx_ref()).root_with_updates().unwrap(); let second_block = blocks.get_mut(1).unwrap(); let cloned_second = second_block.clone(); let mut updated_header = cloned_second.header.unseal(); @@ -156,7 +154,7 @@ pub(crate) fn txs_testdata(num_blocks: u64) -> PathBuf { // make last block have valid state root let root = { - let tx_mut = tx.inner_rw(); + let tx_mut = tx.provider_rw(); let root = StateRoot::new(tx_mut.tx_ref()).root().unwrap(); tx_mut.commit().unwrap(); root diff --git a/crates/stages/src/error.rs b/crates/stages/src/error.rs index 180a8ca5ae24..4a4df0d26987 100644 --- a/crates/stages/src/error.rs +++ b/crates/stages/src/error.rs @@ -50,6 +50,9 @@ pub enum StageError { #[source] error: Box, }, + /// The headers stage is missing a sync gap. + #[error("missing sync gap")] + MissingSyncGap, /// The stage encountered a database error. #[error("internal database error occurred: {0}")] Database(#[from] DbError), @@ -59,6 +62,10 @@ pub enum StageError { /// Invalid checkpoint passed to the stage #[error("invalid stage checkpoint: {0}")] StageCheckpoint(u64), + /// Missing download buffer on stage execution. + /// Returned if stage execution was called without polling for readiness. + #[error("missing download buffer")] + MissingDownloadBuffer, /// Download channel closed #[error("download channel closed")] ChannelClosed, @@ -94,6 +101,8 @@ impl StageError { StageError::Download(_) | StageError::DatabaseIntegrity(_) | StageError::StageCheckpoint(_) | + StageError::MissingDownloadBuffer | + StageError::MissingSyncGap | StageError::ChannelClosed | StageError::Fatal(_) ) diff --git a/crates/stages/src/lib.rs b/crates/stages/src/lib.rs index f30471182c8a..e59597ebabc1 100644 --- a/crates/stages/src/lib.rs +++ b/crates/stages/src/lib.rs @@ -13,42 +13,46 @@ //! //! ``` //!
# use std::sync::Arc; -//! # use reth_db::test_utils::create_test_rw_db; //! # use reth_downloaders::bodies::bodies::BodiesDownloaderBuilder; //! # use reth_downloaders::headers::reverse_headers::ReverseHeadersDownloaderBuilder; //! # use reth_interfaces::consensus::Consensus; //! # use reth_interfaces::test_utils::{TestBodiesClient, TestConsensus, TestHeadersClient}; -//! # use reth_revm::Factory; +//! # use reth_revm::EvmProcessorFactory; //! # use reth_primitives::{PeerId, MAINNET, B256}; //! # use reth_stages::Pipeline; //! # use reth_stages::sets::DefaultStages; -//! # use reth_stages::stages::HeaderSyncMode; //! # use tokio::sync::watch; +//! # use reth_provider::ProviderFactory; +//! # use reth_provider::HeaderSyncMode; +//! # use reth_provider::test_utils::create_test_provider_factory; +//! # +//! # let chain_spec = MAINNET.clone(); //! # let consensus: Arc = Arc::new(TestConsensus::default()); //! # let headers_downloader = ReverseHeadersDownloaderBuilder::default().build( //! # Arc::new(TestHeadersClient::default()), //! # consensus.clone() //! # ); -//! # let db = create_test_rw_db(); +//! # let provider_factory = create_test_provider_factory(); //! # let bodies_downloader = BodiesDownloaderBuilder::default().build( //! # Arc::new(TestBodiesClient { responder: |_| Ok((PeerId::ZERO, vec![]).into()) }), //! # consensus.clone(), -//! # db.clone() +//! # provider_factory.clone() //! # ); //! # let (tip_tx, tip_rx) = watch::channel(B256::default()); -//! # let factory = Factory::new(MAINNET.clone()); +//! # let executor_factory = EvmProcessorFactory::new(chain_spec.clone()); //! // Create a pipeline that can fully sync //! # let pipeline = //! Pipeline::builder() //! .with_tip_sender(tip_tx) //! .add_stages(DefaultStages::new( +//! provider_factory.clone(), //! HeaderSyncMode::Tip(tip_rx), //! consensus, //! headers_downloader, //! bodies_downloader, -//! factory, +//! executor_factory, //! )) -//! .build(db, MAINNET.clone()); +//! .build(provider_factory); //! ``` //! //! ## Feature Flags diff --git a/crates/stages/src/pipeline/builder.rs b/crates/stages/src/pipeline/builder.rs index b5a0a2d409a1..3e160577fddc 100644 --- a/crates/stages/src/pipeline/builder.rs +++ b/crates/stages/src/pipeline/builder.rs @@ -1,8 +1,7 @@ -use std::sync::Arc; - use crate::{pipeline::BoxedStage, MetricEventsSender, Pipeline, Stage, StageSet}; use reth_db::database::Database; -use reth_primitives::{stage::StageId, BlockNumber, ChainSpec, B256}; +use reth_primitives::{stage::StageId, BlockNumber, B256}; +use reth_provider::ProviderFactory; use tokio::sync::watch; /// Builds a [`Pipeline`]. @@ -68,13 +67,10 @@ where } /// Builds the final [`Pipeline`] using the given database. - /// - /// Note: it's expected that this is either an [Arc] or an Arc wrapper type. 
- pub fn build(self, db: DB, chain_spec: Arc) -> Pipeline { + pub fn build(self, provider_factory: ProviderFactory) -> Pipeline { let Self { stages, max_block, tip_tx, metrics_tx } = self; Pipeline { - db, - chain_spec, + provider_factory, stages, max_block, tip_tx, diff --git a/crates/stages/src/pipeline/event.rs b/crates/stages/src/pipeline/event.rs index 05d7945d3319..d5b02610a541 100644 --- a/crates/stages/src/pipeline/event.rs +++ b/crates/stages/src/pipeline/event.rs @@ -1,5 +1,8 @@ use crate::stage::{ExecOutput, UnwindInput, UnwindOutput}; -use reth_primitives::stage::{StageCheckpoint, StageId}; +use reth_primitives::{ + stage::{StageCheckpoint, StageId}, + BlockNumber, +}; use std::fmt::{Display, Formatter}; /// An event emitted by a [Pipeline][crate::Pipeline]. @@ -12,13 +15,15 @@ use std::fmt::{Display, Formatter}; #[derive(Debug, PartialEq, Eq, Clone)] pub enum PipelineEvent { /// Emitted when a stage is about to be run. - Running { + Run { /// Pipeline stages progress. pipeline_stages_progress: PipelineStagesProgress, /// The stage that is about to be run. stage_id: StageId, /// The previous checkpoint of the stage. checkpoint: Option, + /// The block number up to which the stage is running, if known. + target: Option, }, /// Emitted when a stage has run a single time. Ran { @@ -30,7 +35,7 @@ pub enum PipelineEvent { result: ExecOutput, }, /// Emitted when a stage is about to be unwound. - Unwinding { + Unwind { /// The stage that is about to be unwound. stage_id: StageId, /// The unwind parameters. diff --git a/crates/stages/src/pipeline/mod.rs b/crates/stages/src/pipeline/mod.rs index f5955a5dffbc..344510b23332 100644 --- a/crates/stages/src/pipeline/mod.rs +++ b/crates/stages/src/pipeline/mod.rs @@ -1,15 +1,17 @@ use crate::{ error::*, BlockErrorKind, ExecInput, ExecOutput, MetricEvent, MetricEventsSender, Stage, - StageError, UnwindInput, + StageError, StageExt, UnwindInput, }; use futures_util::Future; use reth_db::database::Database; use reth_primitives::{ - constants::BEACON_CONSENSUS_REORG_UNWIND_DEPTH, stage::StageId, BlockNumber, ChainSpec, B256, + constants::BEACON_CONSENSUS_REORG_UNWIND_DEPTH, + stage::{StageCheckpoint, StageId}, + BlockNumber, B256, }; use reth_provider::{ProviderFactory, StageCheckpointReader, StageCheckpointWriter}; use reth_tokio_util::EventListeners; -use std::{pin::Pin, sync::Arc}; +use std::pin::Pin; use tokio::sync::watch; use tokio_stream::wrappers::UnboundedReceiverStream; use tracing::*; @@ -91,10 +93,8 @@ pub type PipelineWithResult = (Pipeline, Result { - /// The Database - db: DB, - /// Chain spec - chain_spec: Arc, + /// Provider factory. + provider_factory: ProviderFactory, /// All configured stages in the order they will be executed. stages: Vec>, /// The maximum block number to sync to. 
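Editor's note: the hunks that follow lean on this new `provider_factory` field — instead of rebuilding a `ProviderFactory` from `(db, chain_spec)` in `register_metrics`, `unwind`, and the run loop, the pipeline mints read-only and read-write providers from the one stored factory. A minimal sketch of that design choice follows; every name in it is a hypothetical stand-in, not reth's exact API.

```rust
// Sketch of the refactor direction: one owned factory, providers on demand.
#[derive(Clone)]
struct ProviderFactory<DB> {
    db: DB,
}

impl<DB: Clone> ProviderFactory<DB> {
    /// Mint a read-only provider (cf. `provider()?` in the diff).
    fn provider(&self) -> Provider<DB> {
        Provider { _db: self.db.clone(), writable: false }
    }
    /// Mint a read-write provider (cf. `provider_rw()?` in the diff).
    fn provider_rw(&self) -> Provider<DB> {
        Provider { _db: self.db.clone(), writable: true }
    }
}

struct Provider<DB> {
    _db: DB,
    writable: bool,
}

impl<DB> Provider<DB> {
    /// Committing consumes the provider; a fresh one is minted afterwards.
    fn commit(self) {}
}

struct Pipeline<DB> {
    // Before: `db: DB` plus `chain_spec`, with a factory rebuilt per method.
    provider_factory: ProviderFactory<DB>,
}

impl<DB: Clone> Pipeline<DB> {
    fn unwind(&mut self) {
        // One read-write provider per commit boundary, as in the diff.
        let provider_rw = self.provider_factory.provider_rw();
        assert!(provider_rw.writable);
        provider_rw.commit();
    }
}

fn main() {
    let mut pipeline = Pipeline { provider_factory: ProviderFactory { db: () } };
    pipeline.unwind();
}
```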
@@ -139,8 +139,7 @@ where /// Registers progress metrics for each registered stage pub fn register_metrics(&mut self) -> Result<(), PipelineError> { let Some(metrics_tx) = &mut self.metrics_tx else { return Ok(()) }; - let factory = ProviderFactory::new(&self.db, self.chain_spec.clone()); - let provider = factory.provider()?; + let provider = self.provider_factory.provider()?; for stage in &self.stages { let stage_id = stage.id(); @@ -217,10 +216,7 @@ where let stage_id = stage.id(); trace!(target: "sync::pipeline", stage = %stage_id, "Executing stage"); - let next = self - .execute_stage_to_completion(previous_stage, stage_index) - .instrument(info_span!("execute", stage = %stage_id)) - .await?; + let next = self.execute_stage_to_completion(previous_stage, stage_index).await?; trace!(target: "sync::pipeline", stage = %stage_id, ?next, "Completed stage"); @@ -232,15 +228,13 @@ where } ControlFlow::Continue { block_number } => self.progress.update(block_number), ControlFlow::Unwind { target, bad_block } => { - self.unwind(target, Some(bad_block.number)).await?; + self.unwind(target, Some(bad_block.number))?; return Ok(ControlFlow::Unwind { target, bad_block }) } } - let factory = ProviderFactory::new(&self.db, self.chain_spec.clone()); - previous_stage = Some( - factory + self.provider_factory .provider()? .get_stage_checkpoint(stage_id)? .unwrap_or_default() @@ -254,7 +248,7 @@ where /// Unwind the stages to the target block. /// /// If the unwind is due to a bad block the number of that block should be specified. - pub async fn unwind( + pub fn unwind( &mut self, to: BlockNumber, bad_block: Option, @@ -262,8 +256,7 @@ where // Unwind stages in reverse order of execution let unwind_pipeline = self.stages.iter_mut().rev(); - let factory = ProviderFactory::new(&self.db, self.chain_spec.clone()); - let mut provider_rw = factory.provider_rw()?; + let mut provider_rw = self.provider_factory.provider_rw()?; for stage in unwind_pipeline { let stage_id = stage.id(); @@ -291,9 +284,9 @@ where ); while checkpoint.block_number > to { let input = UnwindInput { checkpoint, unwind_to: to, bad_block }; - self.listeners.notify(PipelineEvent::Unwinding { stage_id, input }); + self.listeners.notify(PipelineEvent::Unwind { stage_id, input }); - let output = stage.unwind(&provider_rw, input).await; + let output = stage.unwind(&provider_rw, input); match output { Ok(unwind_output) => { checkpoint = unwind_output.checkpoint; @@ -320,7 +313,7 @@ where .notify(PipelineEvent::Unwound { stage_id, result: unwind_output }); provider_rw.commit()?; - provider_rw = factory.provider_rw()?; + provider_rw = self.provider_factory.provider_rw()?; } Err(err) => { self.listeners.notify(PipelineEvent::Error { stage_id }); @@ -345,11 +338,8 @@ where let mut made_progress = false; let target = self.max_block.or(previous_stage); - let factory = ProviderFactory::new(&self.db, self.chain_spec.clone()); - let mut provider_rw = factory.provider_rw()?; - loop { - let prev_checkpoint = provider_rw.get_stage_checkpoint(stage_id)?; + let prev_checkpoint = self.provider_factory.get_stage_checkpoint(stage_id)?; let stage_reached_max_block = prev_checkpoint .zip(self.max_block) @@ -370,43 +360,32 @@ where }) } - self.listeners.notify(PipelineEvent::Running { + let exec_input = ExecInput { target, checkpoint: prev_checkpoint }; + + if let Err(err) = stage.execute_ready(exec_input).await { + self.listeners.notify(PipelineEvent::Error { stage_id }); + match on_stage_error(&self.provider_factory, stage_id, prev_checkpoint, err)? 
{ + Some(ctrl) => return Ok(ctrl), + None => continue, + }; + } + + self.listeners.notify(PipelineEvent::Run { pipeline_stages_progress: event::PipelineStagesProgress { current: stage_index + 1, total: total_stages, }, stage_id, checkpoint: prev_checkpoint, + target, }); - match stage - .execute(&provider_rw, ExecInput { target, checkpoint: prev_checkpoint }) - .await - { + let provider_rw = self.provider_factory.provider_rw()?; + match stage.execute(&provider_rw, exec_input) { Ok(out @ ExecOutput { checkpoint, done }) => { made_progress |= checkpoint.block_number != prev_checkpoint.unwrap_or_default().block_number; - if let Some(progress) = checkpoint.entities() { - debug!( - target: "sync::pipeline", - stage = %stage_id, - checkpoint = checkpoint.block_number, - ?target, - %progress, - %done, - "Stage committed progress" - ); - } else { - debug!( - target: "sync::pipeline", - stage = %stage_id, - checkpoint = checkpoint.block_number, - ?target, - %done, - "Stage committed progress" - ); - } if let Some(metrics_tx) = &mut self.metrics_tx { let _ = metrics_tx.send(MetricEvent::StageCheckpoint { stage_id, @@ -425,9 +404,7 @@ where result: out.clone(), }); - // TODO: Make the commit interval configurable provider_rw.commit()?; - provider_rw = factory.provider_rw()?; if done { let block_number = checkpoint.block_number; @@ -439,94 +416,95 @@ where } } Err(err) => { + drop(provider_rw); self.listeners.notify(PipelineEvent::Error { stage_id }); - - let out = if let StageError::DetachedHead { local_head, header, error } = err { - warn!(target: "sync::pipeline", stage = %stage_id, ?local_head, ?header, ?error, "Stage encountered detached head"); - - // We unwind because of a detached head. - let unwind_to = local_head - .number - .saturating_sub(BEACON_CONSENSUS_REORG_UNWIND_DEPTH) - .max(1); - Ok(ControlFlow::Unwind { target: unwind_to, bad_block: local_head }) - } else if let StageError::Block { block, error } = err { - match error { - BlockErrorKind::Validation(validation_error) => { - error!( - target: "sync::pipeline", - stage = %stage_id, - bad_block = %block.number, - "Stage encountered a validation error: {validation_error}" - ); - - // FIXME: When handling errors, we do not commit the database - // transaction. This leads to the Merkle - // stage not clearing its checkpoint, and - // restarting from an invalid place. - drop(provider_rw); - provider_rw = factory.provider_rw()?; - provider_rw.save_stage_checkpoint_progress( - StageId::MerkleExecute, - vec![], - )?; - provider_rw.save_stage_checkpoint( - StageId::MerkleExecute, - prev_checkpoint.unwrap_or_default(), - )?; - provider_rw.commit()?; - - // We unwind because of a validation error. If the unwind itself - // fails, we bail entirely, - // otherwise we restart the execution loop from the - // beginning. - Ok(ControlFlow::Unwind { - target: prev_checkpoint.unwrap_or_default().block_number, - bad_block: block, - }) - } - BlockErrorKind::Execution(execution_error) => { - error!( - target: "sync::pipeline", - stage = %stage_id, - bad_block = %block.number, - "Stage encountered an execution error: {execution_error}" - ); - - // We unwind because of an execution error. If the unwind itself - // fails, we bail entirely, - // otherwise we restart - // the execution loop from the beginning. - Ok(ControlFlow::Unwind { - target: prev_checkpoint.unwrap_or_default().block_number, - bad_block: block, - }) - } - } - } else if err.is_fatal() { - error!( - target: "sync::pipeline", - stage = %stage_id, - "Stage encountered a fatal error: {err}." 
- ); - Err(err.into()) - } else { - // On other errors we assume they are recoverable if we discard the - // transaction and run the stage again. - warn!( - target: "sync::pipeline", - stage = %stage_id, - "Stage encountered a non-fatal error: {err}. Retrying..." - ); - continue - }; - return out + if let Some(ctrl) = + on_stage_error(&self.provider_factory, stage_id, prev_checkpoint, err)? + { + return Ok(ctrl) + } } } } } } +fn on_stage_error( + factory: &ProviderFactory, + stage_id: StageId, + prev_checkpoint: Option, + err: StageError, +) -> Result, PipelineError> { + if let StageError::DetachedHead { local_head, header, error } = err { + warn!(target: "sync::pipeline", stage = %stage_id, ?local_head, ?header, ?error, "Stage encountered detached head"); + + // We unwind because of a detached head. + let unwind_to = + local_head.number.saturating_sub(BEACON_CONSENSUS_REORG_UNWIND_DEPTH).max(1); + Ok(Some(ControlFlow::Unwind { target: unwind_to, bad_block: local_head })) + } else if let StageError::Block { block, error } = err { + match error { + BlockErrorKind::Validation(validation_error) => { + error!( + target: "sync::pipeline", + stage = %stage_id, + bad_block = %block.number, + "Stage encountered a validation error: {validation_error}" + ); + + // FIXME: When handling errors, we do not commit the database transaction. This + // leads to the Merkle stage not clearing its checkpoint, and restarting from an + // invalid place. + let provider_rw = factory.provider_rw()?; + provider_rw.save_stage_checkpoint_progress(StageId::MerkleExecute, vec![])?; + provider_rw.save_stage_checkpoint( + StageId::MerkleExecute, + prev_checkpoint.unwrap_or_default(), + )?; + provider_rw.commit()?; + + // We unwind because of a validation error. If the unwind itself + // fails, we bail entirely, + // otherwise we restart the execution loop from the + // beginning. + Ok(Some(ControlFlow::Unwind { + target: prev_checkpoint.unwrap_or_default().block_number, + bad_block: block, + })) + } + BlockErrorKind::Execution(execution_error) => { + error!( + target: "sync::pipeline", + stage = %stage_id, + bad_block = %block.number, + "Stage encountered an execution error: {execution_error}" + ); + + // We unwind because of an execution error. If the unwind itself + // fails, we bail entirely, + // otherwise we restart + // the execution loop from the beginning. + Ok(Some(ControlFlow::Unwind { + target: prev_checkpoint.unwrap_or_default().block_number, + bad_block: block, + })) + } + } + } else if err.is_fatal() { + error!(target: "sync::pipeline", stage = %stage_id, "Stage encountered a fatal error: {err}"); + Err(err.into()) + } else { + // On other errors we assume they are recoverable if we discard the + // transaction and run the stage again. + warn!( + target: "sync::pipeline", + stage = %stage_id, + "Stage encountered a non-fatal error: {err}. Retrying..." 
+ ); + Ok(None) + } +} + impl std::fmt::Debug for Pipeline { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("Pipeline") @@ -542,13 +520,13 @@ mod tests { use super::*; use crate::{test_utils::TestStage, UnwindOutput}; use assert_matches::assert_matches; - use reth_db::test_utils::create_test_rw_db; use reth_interfaces::{ consensus, provider::ProviderError, test_utils::{generators, generators::random_header}, }; - use reth_primitives::{stage::StageCheckpoint, MAINNET}; + use reth_primitives::stage::StageCheckpoint; + use reth_provider::test_utils::create_test_provider_factory; use tokio_stream::StreamExt; #[test] @@ -581,7 +559,7 @@ mod tests { /// Runs a simple pipeline. #[tokio::test] async fn run_pipeline() { - let db = create_test_rw_db(); + let provider_factory = create_test_provider_factory(); let mut pipeline = Pipeline::builder() .add_stage( @@ -593,7 +571,7 @@ mod tests { .add_exec(Ok(ExecOutput { checkpoint: StageCheckpoint::new(10), done: true })), ) .with_max_block(10) - .build(db, MAINNET.clone()); + .build(provider_factory); let events = pipeline.events(); // Run pipeline @@ -605,20 +583,22 @@ mod tests { assert_eq!( events.collect::>().await, vec![ - PipelineEvent::Running { + PipelineEvent::Run { pipeline_stages_progress: PipelineStagesProgress { current: 1, total: 2 }, stage_id: StageId::Other("A"), - checkpoint: None + checkpoint: None, + target: Some(10), }, PipelineEvent::Ran { pipeline_stages_progress: PipelineStagesProgress { current: 1, total: 2 }, stage_id: StageId::Other("A"), result: ExecOutput { checkpoint: StageCheckpoint::new(20), done: true }, }, - PipelineEvent::Running { + PipelineEvent::Run { pipeline_stages_progress: PipelineStagesProgress { current: 2, total: 2 }, stage_id: StageId::Other("B"), - checkpoint: None + checkpoint: None, + target: Some(10), }, PipelineEvent::Ran { pipeline_stages_progress: PipelineStagesProgress { current: 2, total: 2 }, @@ -632,7 +612,7 @@ mod tests { /// Unwinds a simple pipeline. 
#[tokio::test] async fn unwind_pipeline() { - let db = create_test_rw_db(); + let provider_factory = create_test_provider_factory(); let mut pipeline = Pipeline::builder() .add_stage( @@ -651,7 +631,7 @@ mod tests { .add_unwind(Ok(UnwindOutput { checkpoint: StageCheckpoint::new(1) })), ) .with_max_block(10) - .build(db, MAINNET.clone()); + .build(provider_factory); let events = pipeline.events(); // Run pipeline @@ -660,7 +640,7 @@ mod tests { pipeline.run().await.expect("Could not run pipeline"); // Unwind - pipeline.unwind(1, None).await.expect("Could not unwind pipeline"); + pipeline.unwind(1, None).expect("Could not unwind pipeline"); }); // Check that the stages were unwound in reverse order @@ -668,30 +648,33 @@ mod tests { events.collect::>().await, vec![ // Executing - PipelineEvent::Running { + PipelineEvent::Run { pipeline_stages_progress: PipelineStagesProgress { current: 1, total: 3 }, stage_id: StageId::Other("A"), - checkpoint: None + checkpoint: None, + target: Some(10), }, PipelineEvent::Ran { pipeline_stages_progress: PipelineStagesProgress { current: 1, total: 3 }, stage_id: StageId::Other("A"), result: ExecOutput { checkpoint: StageCheckpoint::new(100), done: true }, }, - PipelineEvent::Running { + PipelineEvent::Run { pipeline_stages_progress: PipelineStagesProgress { current: 2, total: 3 }, stage_id: StageId::Other("B"), - checkpoint: None + checkpoint: None, + target: Some(10), }, PipelineEvent::Ran { pipeline_stages_progress: PipelineStagesProgress { current: 2, total: 3 }, stage_id: StageId::Other("B"), result: ExecOutput { checkpoint: StageCheckpoint::new(10), done: true }, }, - PipelineEvent::Running { + PipelineEvent::Run { pipeline_stages_progress: PipelineStagesProgress { current: 3, total: 3 }, stage_id: StageId::Other("C"), - checkpoint: None + checkpoint: None, + target: Some(10), }, PipelineEvent::Ran { pipeline_stages_progress: PipelineStagesProgress { current: 3, total: 3 }, @@ -699,7 +682,7 @@ mod tests { result: ExecOutput { checkpoint: StageCheckpoint::new(20), done: true }, }, // Unwinding - PipelineEvent::Unwinding { + PipelineEvent::Unwind { stage_id: StageId::Other("C"), input: UnwindInput { checkpoint: StageCheckpoint::new(20), @@ -711,7 +694,7 @@ mod tests { stage_id: StageId::Other("C"), result: UnwindOutput { checkpoint: StageCheckpoint::new(1) }, }, - PipelineEvent::Unwinding { + PipelineEvent::Unwind { stage_id: StageId::Other("B"), input: UnwindInput { checkpoint: StageCheckpoint::new(10), @@ -723,7 +706,7 @@ mod tests { stage_id: StageId::Other("B"), result: UnwindOutput { checkpoint: StageCheckpoint::new(1) }, }, - PipelineEvent::Unwinding { + PipelineEvent::Unwind { stage_id: StageId::Other("A"), input: UnwindInput { checkpoint: StageCheckpoint::new(100), @@ -742,7 +725,7 @@ mod tests { /// Unwinds a pipeline with intermediate progress. 
#[tokio::test] async fn unwind_pipeline_with_intermediate_progress() { - let db = create_test_rw_db(); + let provider_factory = create_test_provider_factory(); let mut pipeline = Pipeline::builder() .add_stage( @@ -755,7 +738,7 @@ mod tests { .add_exec(Ok(ExecOutput { checkpoint: StageCheckpoint::new(10), done: true })), ) .with_max_block(10) - .build(db, MAINNET.clone()); + .build(provider_factory); let events = pipeline.events(); // Run pipeline @@ -764,7 +747,7 @@ mod tests { pipeline.run().await.expect("Could not run pipeline"); // Unwind - pipeline.unwind(50, None).await.expect("Could not unwind pipeline"); + pipeline.unwind(50, None).expect("Could not unwind pipeline"); }); // Check that the stages were unwound in reverse order @@ -772,20 +755,22 @@ mod tests { events.collect::>().await, vec![ // Executing - PipelineEvent::Running { + PipelineEvent::Run { pipeline_stages_progress: PipelineStagesProgress { current: 1, total: 2 }, stage_id: StageId::Other("A"), - checkpoint: None + checkpoint: None, + target: Some(10), }, PipelineEvent::Ran { pipeline_stages_progress: PipelineStagesProgress { current: 1, total: 2 }, stage_id: StageId::Other("A"), result: ExecOutput { checkpoint: StageCheckpoint::new(100), done: true }, }, - PipelineEvent::Running { + PipelineEvent::Run { pipeline_stages_progress: PipelineStagesProgress { current: 2, total: 2 }, stage_id: StageId::Other("B"), - checkpoint: None + checkpoint: None, + target: Some(10), }, PipelineEvent::Ran { pipeline_stages_progress: PipelineStagesProgress { current: 2, total: 2 }, @@ -795,7 +780,7 @@ mod tests { // Unwinding // Nothing to unwind in stage "B" PipelineEvent::Skipped { stage_id: StageId::Other("B") }, - PipelineEvent::Unwinding { + PipelineEvent::Unwind { stage_id: StageId::Other("A"), input: UnwindInput { checkpoint: StageCheckpoint::new(100), @@ -825,7 +810,7 @@ mod tests { /// - The pipeline finishes #[tokio::test] async fn run_pipeline_with_unwind() { - let db = create_test_rw_db(); + let provider_factory = create_test_provider_factory(); let mut pipeline = Pipeline::builder() .add_stage( @@ -850,7 +835,7 @@ mod tests { .add_exec(Ok(ExecOutput { checkpoint: StageCheckpoint::new(10), done: true })), ) .with_max_block(10) - .build(db, MAINNET.clone()); + .build(provider_factory); let events = pipeline.events(); // Run pipeline @@ -862,23 +847,25 @@ mod tests { assert_eq!( events.collect::>().await, vec![ - PipelineEvent::Running { + PipelineEvent::Run { pipeline_stages_progress: PipelineStagesProgress { current: 1, total: 2 }, stage_id: StageId::Other("A"), - checkpoint: None + checkpoint: None, + target: Some(10), }, PipelineEvent::Ran { pipeline_stages_progress: PipelineStagesProgress { current: 1, total: 2 }, stage_id: StageId::Other("A"), result: ExecOutput { checkpoint: StageCheckpoint::new(10), done: true }, }, - PipelineEvent::Running { + PipelineEvent::Run { pipeline_stages_progress: PipelineStagesProgress { current: 2, total: 2 }, stage_id: StageId::Other("B"), - checkpoint: None + checkpoint: None, + target: Some(10), }, PipelineEvent::Error { stage_id: StageId::Other("B") }, - PipelineEvent::Unwinding { + PipelineEvent::Unwind { stage_id: StageId::Other("A"), input: UnwindInput { checkpoint: StageCheckpoint::new(10), @@ -890,20 +877,22 @@ mod tests { stage_id: StageId::Other("A"), result: UnwindOutput { checkpoint: StageCheckpoint::new(0) }, }, - PipelineEvent::Running { + PipelineEvent::Run { pipeline_stages_progress: PipelineStagesProgress { current: 1, total: 2 }, stage_id: StageId::Other("A"), - 
checkpoint: Some(StageCheckpoint::new(0)) + checkpoint: Some(StageCheckpoint::new(0)), + target: Some(10), }, PipelineEvent::Ran { pipeline_stages_progress: PipelineStagesProgress { current: 1, total: 2 }, stage_id: StageId::Other("A"), result: ExecOutput { checkpoint: StageCheckpoint::new(10), done: true }, }, - PipelineEvent::Running { + PipelineEvent::Run { pipeline_stages_progress: PipelineStagesProgress { current: 2, total: 2 }, stage_id: StageId::Other("B"), - checkpoint: None + checkpoint: None, + target: Some(10), }, PipelineEvent::Ran { pipeline_stages_progress: PipelineStagesProgress { current: 2, total: 2 }, @@ -918,7 +907,7 @@ mod tests { #[tokio::test] async fn pipeline_error_handling() { // Non-fatal - let db = create_test_rw_db(); + let provider_factory = create_test_provider_factory(); let mut pipeline = Pipeline::builder() .add_stage( TestStage::new(StageId::Other("NonFatal")) @@ -926,17 +915,17 @@ mod tests { .add_exec(Ok(ExecOutput { checkpoint: StageCheckpoint::new(10), done: true })), ) .with_max_block(10) - .build(db, MAINNET.clone()); + .build(provider_factory); let result = pipeline.run().await; assert_matches!(result, Ok(())); // Fatal - let db = create_test_rw_db(); + let provider_factory = create_test_provider_factory(); let mut pipeline = Pipeline::builder() .add_stage(TestStage::new(StageId::Other("Fatal")).add_exec(Err( StageError::DatabaseIntegrity(ProviderError::BlockBodyIndicesNotFound(5)), ))) - .build(db, MAINNET.clone()); + .build(provider_factory); let result = pipeline.run().await; assert_matches!( result, diff --git a/crates/stages/src/sets.rs b/crates/stages/src/sets.rs index f49714e0133e..ac97fd5c742c 100644 --- a/crates/stages/src/sets.rs +++ b/crates/stages/src/sets.rs @@ -12,33 +12,32 @@ //! ```no_run //! # use reth_stages::Pipeline; //! # use reth_stages::sets::{OfflineStages}; -//! # use reth_revm::Factory; +//! # use reth_revm::EvmProcessorFactory; //! # use reth_primitives::MAINNET; -//! use reth_db::test_utils::create_test_rw_db; +//! # use reth_provider::test_utils::create_test_provider_factory; //! -//! # let factory = Factory::new(MAINNET.clone()); -//! # let db = create_test_rw_db(); +//! # let executor_factory = EvmProcessorFactory::new(MAINNET.clone()); +//! # let provider_factory = create_test_provider_factory(); //! // Build a pipeline with all offline stages. -//! # let pipeline = -//! Pipeline::builder().add_stages(OfflineStages::new(factory)).build(db, MAINNET.clone()); +//! # let pipeline = Pipeline::builder().add_stages(OfflineStages::new(executor_factory)).build(provider_factory); //! ``` //! //! ```ignore //! # use reth_stages::Pipeline; //! # use reth_stages::{StageSet, sets::OfflineStages}; -//! # use reth_revm::Factory; +//! # use reth_revm::EvmProcessorFactory; //! # use reth_primitives::MAINNET; //! // Build a pipeline with all offline stages and a custom stage at the end. -//! # let factory = Factory::new(MAINNET.clone()); +//! # let executor_factory = EvmProcessorFactory::new(MAINNET.clone()); //! Pipeline::builder() //! .add_stages( -//! OfflineStages::new(factory).builder().add_stage(MyCustomStage) +//! OfflineStages::new(executor_factory).builder().add_stage(MyCustomStage) //! ) //! .build(); //! 
``` use crate::{ stages::{ - AccountHashingStage, BodyStage, ExecutionStage, FinishStage, HeaderStage, HeaderSyncMode, + AccountHashingStage, BodyStage, ExecutionStage, FinishStage, HeaderStage, IndexAccountHistoryStage, IndexStorageHistoryStage, MerkleStage, SenderRecoveryStage, StorageHashingStage, TotalDifficultyStage, TransactionLookupStage, }, @@ -49,7 +48,7 @@ use reth_interfaces::{ consensus::Consensus, p2p::{bodies::downloader::BodyDownloader, headers::downloader::HeaderDownloader}, }; -use reth_provider::ExecutorFactory; +use reth_provider::{ExecutorFactory, HeaderSyncGapProvider, HeaderSyncMode}; use std::sync::Arc; /// A set containing all stages to run a fully syncing instance of reth. @@ -75,16 +74,17 @@ use std::sync::Arc; /// - [`IndexAccountHistoryStage`] /// - [`FinishStage`] #[derive(Debug)] -pub struct DefaultStages { +pub struct DefaultStages { /// Configuration for the online stages - online: OnlineStages, + online: OnlineStages, /// Executor factory needed for the execution stage executor_factory: EF, } -impl DefaultStages { +impl DefaultStages { /// Create a new set of default stages with default values. pub fn new( + provider: Provider, header_mode: HeaderSyncMode, consensus: Arc, header_downloader: H, @@ -95,13 +95,19 @@ impl DefaultStages { EF: ExecutorFactory, { Self { - online: OnlineStages::new(header_mode, consensus, header_downloader, body_downloader), + online: OnlineStages::new( + provider, + header_mode, + consensus, + header_downloader, + body_downloader, + ), executor_factory, } } } -impl DefaultStages +impl DefaultStages where EF: ExecutorFactory, { @@ -114,9 +120,10 @@ where } } -impl StageSet for DefaultStages +impl StageSet for DefaultStages where DB: Database, + Provider: HeaderSyncGapProvider + 'static, H: HeaderDownloader + 'static, B: BodyDownloader + 'static, EF: ExecutorFactory, @@ -131,7 +138,9 @@ where /// These stages *can* be run without network access if the specified downloaders are /// themselves offline. #[derive(Debug)] -pub struct OnlineStages { +pub struct OnlineStages { + /// Sync gap provider for the headers stage. + provider: Provider, /// The sync mode for the headers stage. header_mode: HeaderSyncMode, /// The consensus engine used to validate incoming data. @@ -142,60 +151,64 @@ pub struct OnlineStages { body_downloader: B, } -impl OnlineStages { +impl OnlineStages { /// Create a new set of online stages with default values. pub fn new( + provider: Provider, header_mode: HeaderSyncMode, consensus: Arc, header_downloader: H, body_downloader: B, ) -> Self { - Self { header_mode, consensus, header_downloader, body_downloader } + Self { provider, header_mode, consensus, header_downloader, body_downloader } } } -impl OnlineStages +impl OnlineStages where + Provider: HeaderSyncGapProvider + 'static, H: HeaderDownloader + 'static, B: BodyDownloader + 'static, { /// Create a new builder using the given headers stage. pub fn builder_with_headers( - headers: HeaderStage, + headers: HeaderStage, body_downloader: B, consensus: Arc, ) -> StageSetBuilder { StageSetBuilder::default() .add_stage(headers) .add_stage(TotalDifficultyStage::new(consensus.clone())) - .add_stage(BodyStage { downloader: body_downloader, consensus }) + .add_stage(BodyStage::new(body_downloader)) } /// Create a new builder using the given bodies stage.
pub fn builder_with_bodies( bodies: BodyStage, + provider: Provider, mode: HeaderSyncMode, header_downloader: H, consensus: Arc, ) -> StageSetBuilder { StageSetBuilder::default() - .add_stage(HeaderStage::new(header_downloader, mode)) + .add_stage(HeaderStage::new(provider, header_downloader, mode)) .add_stage(TotalDifficultyStage::new(consensus.clone())) .add_stage(bodies) } } -impl StageSet for OnlineStages +impl StageSet for OnlineStages where DB: Database, + Provider: HeaderSyncGapProvider + 'static, H: HeaderDownloader + 'static, B: BodyDownloader + 'static, { fn builder(self) -> StageSetBuilder { StageSetBuilder::default() - .add_stage(HeaderStage::new(self.header_downloader, self.header_mode)) + .add_stage(HeaderStage::new(self.provider, self.header_downloader, self.header_mode)) .add_stage(TotalDifficultyStage::new(self.consensus.clone())) - .add_stage(BodyStage { downloader: self.body_downloader, consensus: self.consensus }) + .add_stage(BodyStage::new(self.body_downloader)) } } diff --git a/crates/stages/src/stage.rs b/crates/stages/src/stage.rs index 95e397cbe8a1..aa8360b7d9d8 100644 --- a/crates/stages/src/stage.rs +++ b/crates/stages/src/stage.rs @@ -1,5 +1,4 @@ use crate::error::StageError; -use async_trait::async_trait; use reth_db::database::Database; use reth_primitives::{ stage::{StageCheckpoint, StageId}, @@ -8,7 +7,9 @@ use reth_primitives::{ use reth_provider::{BlockReader, DatabaseProviderRW, ProviderError, TransactionsProvider}; use std::{ cmp::{max, min}, + future::poll_fn, ops::{Range, RangeInclusive}, + task::{Context, Poll}, }; /// Stage execution input, see [Stage::execute]. @@ -75,7 +76,7 @@ impl ExecInput { /// the number of transactions exceeds the threshold. pub fn next_block_range_with_transaction_threshold( &self, - provider: &DatabaseProviderRW<'_, DB>, + provider: &DatabaseProviderRW, tx_threshold: u64, ) -> Result<(Range, RangeInclusive, bool), StageError> { let start_block = self.next_block(); @@ -189,24 +190,70 @@ pub struct UnwindOutput { /// Stages are executed as part of a pipeline where they are executed serially. /// /// Stages receive [`DatabaseProviderRW`]. -#[async_trait] +#[auto_impl::auto_impl(Box)] pub trait Stage: Send + Sync { /// Get the ID of the stage. /// /// Stage IDs must be unique. fn id(&self) -> StageId; + /// Returns `Poll::Ready(Ok(()))` when the stage is ready to execute the given range. + /// + /// This method is heavily inspired by [tower](https://crates.io/crates/tower)'s `Service` trait. + /// Any asynchronous tasks or communication should be handled in `poll_ready`, e.g. moving + /// downloaded items from downloaders to an internal buffer in the stage. + /// + /// If the stage has any pending external state, then `Poll::Pending` is returned. + /// + /// If `Poll::Ready(Err(_))` is returned, the stage may not be able to execute anymore + /// depending on the specific error. In that case, an unwind must be issued instead. + /// + /// Once `Poll::Ready(Ok(()))` is returned, the stage may be executed once using `execute`. + /// Until the stage has been executed, repeated calls to `poll_ready` must return either + /// `Poll::Ready(Ok(()))` or `Poll::Ready(Err(_))`. + /// + /// Note that `poll_ready` may reserve shared resources that are consumed in a subsequent call + /// of `execute`, e.g. internal buffers. It is crucial for implementations to not assume that + /// `execute` will always be invoked and to ensure that those resources are appropriately + /// released if the stage is dropped before `execute` is called. 
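To make the readiness contract described above concrete, here is a minimal, self-contained sketch of the reserve-then-execute handshake. `BufferedStage` and its error type are illustrative stand-ins rather than the real reth `Stage` trait; the `poll_fn` bridge at the bottom is the same adaptation that the `StageExt::execute_ready` helper introduced later in this diff performs for real stages.

```rust
use std::future::poll_fn;
use std::task::{Context, Poll};

// Illustrative stand-in for a stage: not the real reth `Stage` trait.
struct BufferedStage {
    buffer: Option<Vec<u64>>, // stands in for downloaded items
}

impl BufferedStage {
    // Reserve work for the next `execute` call; always ready in this toy.
    fn poll_execute_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), &'static str>> {
        if self.buffer.is_none() {
            // A real stage would poll its downloader here.
            self.buffer = Some(vec![1, 2, 3]);
        }
        Poll::Ready(Ok(()))
    }

    // Consume the reserved work exactly once.
    fn execute(&mut self) -> Result<usize, &'static str> {
        self.buffer.take().map(|items| items.len()).ok_or("missing buffer")
    }
}

#[tokio::main]
async fn main() -> Result<(), &'static str> {
    let mut stage = BufferedStage { buffer: None };
    // `poll_fn` adapts the poll-style readiness check into a future.
    poll_fn(|cx| stage.poll_execute_ready(cx)).await?;
    assert_eq!(stage.execute()?, 3);
    Ok(())
}
```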
+ /// + /// For the same reason, it is also important that any shared resources do not exhibit + /// unbounded growth on repeated calls to `poll_ready`. + /// + /// Unwinds may happen without consulting `poll_ready` first. + fn poll_execute_ready( + &mut self, + _cx: &mut Context<'_>, + _input: ExecInput, + ) -> Poll> { + Poll::Ready(Ok(())) + } + /// Execute the stage. - async fn execute( + /// It is expected that the stage will write all necessary data to the database + /// upon invoking this method. + fn execute( &mut self, - provider: &DatabaseProviderRW<'_, &DB>, + provider: &DatabaseProviderRW, input: ExecInput, ) -> Result; /// Unwind the stage. - async fn unwind( + fn unwind( &mut self, - provider: &DatabaseProviderRW<'_, &DB>, + provider: &DatabaseProviderRW, input: UnwindInput, ) -> Result; } + +/// [Stage] trait extension. +#[async_trait::async_trait] +pub trait StageExt: Stage { + /// Utility extension for the `Stage` trait that invokes `Stage::poll_execute_ready` + /// with [poll_fn] context. For more information see [Stage::poll_execute_ready]. + async fn execute_ready(&mut self, input: ExecInput) -> Result<(), StageError> { + poll_fn(|cx| self.poll_execute_ready(cx, input)).await + } +} + +impl> StageExt for S {} diff --git a/crates/stages/src/stages/bodies.rs b/crates/stages/src/stages/bodies.rs index 8da7e6511ed3..56001595cf76 100644 --- a/crates/stages/src/stages/bodies.rs +++ b/crates/stages/src/stages/bodies.rs @@ -8,13 +8,10 @@ use reth_db::{ transaction::{DbTx, DbTxMut}, DatabaseError, }; -use reth_interfaces::{ - consensus::Consensus, - p2p::bodies::{downloader::BodyDownloader, response::BlockResponse}, -}; +use reth_interfaces::p2p::bodies::{downloader::BodyDownloader, response::BlockResponse}; use reth_primitives::stage::{EntitiesCheckpoint, StageCheckpoint, StageId}; use reth_provider::DatabaseProviderRW; -use std::sync::Arc; +use std::task::{ready, Context, Poll}; use tracing::*; // TODO(onbjerg): Metrics and events (gradual status for e.g. CLI) @@ -51,33 +48,63 @@ use tracing::*; #[derive(Debug)] pub struct BodyStage { /// The body downloader. - pub downloader: D, - /// The consensus engine. - pub consensus: Arc, + downloader: D, + /// Block response buffer. + buffer: Option>, +} + +impl BodyStage { + /// Create new bodies stage from downloader. + pub fn new(downloader: D) -> Self { + Self { downloader, buffer: None } + } } -#[async_trait::async_trait] impl Stage for BodyStage { /// Return the id of the stage fn id(&self) -> StageId { StageId::Bodies } + fn poll_execute_ready( + &mut self, + cx: &mut Context<'_>, + input: ExecInput, + ) -> Poll> { + if input.target_reached() || self.buffer.is_some() { + return Poll::Ready(Ok(())) + } + + // Update the header range on the downloader + self.downloader.set_download_range(input.next_block_range())?; + + // Poll next downloader item. + let maybe_next_result = ready!(self.downloader.try_poll_next_unpin(cx)); + + // Task downloader can return `None` only if the response relaying channel was closed. This + // is a fatal error to prevent the pipeline from running forever. + let response = match maybe_next_result { + Some(Ok(downloaded)) => { + self.buffer = Some(downloaded); + Ok(()) + } + Some(Err(err)) => Err(err.into()), + None => Err(StageError::ChannelClosed), + }; + Poll::Ready(response) + } + /// Download block bodies from the last checkpoint for this stage up until the latest synced /// header, limited by the stage's batch size. 
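The `poll_execute_ready` body added to `BodyStage` follows a pattern worth calling out: return early if a batch is already buffered, poll the downloader with `ready!`, and treat a terminated stream as fatal. Below is a dependency-light re-creation over a plain `futures` stream; all names are stand-ins, and only `ready!` and `try_poll_next_unpin` are the real APIs being demonstrated.

```rust
use futures::stream::{self, TryStreamExt};
use std::future::poll_fn;
use std::task::{ready, Context, Poll};

// Illustrative re-creation of the buffering logic in
// `BodyStage::poll_execute_ready`, over a plain futures stream.
fn poll_buffer<S>(
    downloader: &mut S,
    buffer: &mut Option<Vec<u8>>,
    cx: &mut Context<'_>,
) -> Poll<Result<(), String>>
where
    S: futures::TryStream<Ok = Vec<u8>, Error = String> + Unpin,
{
    // Nothing to do if a batch is already buffered for `execute`.
    if buffer.is_some() {
        return Poll::Ready(Ok(()));
    }
    // `ready!` propagates Poll::Pending if the stream is not ready yet.
    let result = match ready!(downloader.try_poll_next_unpin(cx)) {
        Some(Ok(batch)) => {
            *buffer = Some(batch);
            Ok(())
        }
        Some(Err(err)) => Err(err),
        // A finished stream means the relaying channel closed: fatal.
        None => Err("channel closed".to_string()),
    };
    Poll::Ready(result)
}

#[tokio::main]
async fn main() -> Result<(), String> {
    let mut downloader = stream::iter(vec![Ok::<_, String>(vec![1u8, 2, 3])]);
    let mut buffer = None;
    poll_fn(|cx| poll_buffer(&mut downloader, &mut buffer, cx)).await?;
    assert_eq!(buffer, Some(vec![1u8, 2, 3]));
    Ok(())
}
```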
- async fn execute( + fn execute( &mut self, - provider: &DatabaseProviderRW<'_, &DB>, + provider: &DatabaseProviderRW, input: ExecInput, ) -> Result { if input.target_reached() { return Ok(ExecOutput::done(input.checkpoint())) } - - let range = input.next_block_range(); - // Update the header range on the downloader - self.downloader.set_download_range(range.clone())?; - let (from_block, to_block) = range.into_inner(); + let (from_block, to_block) = input.next_block_range().into_inner(); // Cursors used to write bodies, ommers and transactions let tx = provider.tx_ref(); @@ -92,15 +119,10 @@ impl Stage for BodyStage { debug!(target: "sync::stages::bodies", stage_progress = from_block, target = to_block, start_tx_id = next_tx_num, "Commencing sync"); - // Task downloader can return `None` only if the response relaying channel was closed. This - // is a fatal error to prevent the pipeline from running forever. - let downloaded_bodies = - self.downloader.try_next().await?.ok_or(StageError::ChannelClosed)?; - - trace!(target: "sync::stages::bodies", bodies_len = downloaded_bodies.len(), "Writing blocks"); - + let buffer = self.buffer.take().ok_or(StageError::MissingDownloadBuffer)?; + trace!(target: "sync::stages::bodies", bodies_len = buffer.len(), "Writing blocks"); let mut highest_block = from_block; - for response in downloaded_bodies { + for response in buffer { // Write block let block_number = response.block_number(); @@ -161,11 +183,13 @@ impl Stage for BodyStage { } /// Unwind the stage. - async fn unwind( + fn unwind( &mut self, - provider: &DatabaseProviderRW<'_, &DB>, + provider: &DatabaseProviderRW, input: UnwindInput, ) -> Result { + self.buffer.take(); + let tx = provider.tx_ref(); // Cursors to unwind bodies, ommers let mut body_cursor = tx.cursor_write::()?; @@ -221,7 +245,7 @@ impl Stage for BodyStage { // beforehand how many bytes we need to download. So the good solution would be to measure the // progress in gas as a proxy to size. Execution stage uses a similar approach. fn stage_checkpoint( - provider: &DatabaseProviderRW<'_, DB>, + provider: &DatabaseProviderRW, ) -> Result { Ok(EntitiesCheckpoint { processed: provider.tx_ref().entries::()? as u64, @@ -416,7 +440,7 @@ mod tests { // Delete a transaction runner - .tx() + .db() .commit(|tx| { let mut tx_cursor = tx.cursor_write::()?; tx_cursor.last()?.expect("Could not read last transaction"); @@ -447,7 +471,7 @@ mod tests { use crate::{ stages::bodies::BodyStage, test_utils::{ - ExecuteStageTestRunner, StageTestRunner, TestRunnerError, TestTransaction, + ExecuteStageTestRunner, StageTestRunner, TestRunnerError, TestStageDB, UnwindStageTestRunner, }, ExecInput, ExecOutput, UnwindInput, @@ -455,7 +479,6 @@ mod tests { use futures_util::Stream; use reth_db::{ cursor::DbCursorRO, - database::Database, models::{StoredBlockBodyIndices, StoredBlockOmmers}, tables, test_utils::TempDatabase, @@ -476,10 +499,10 @@ mod tests { test_utils::{ generators, generators::{random_block_range, random_signed_tx}, - TestConsensus, }, }; use reth_primitives::{BlockBody, BlockNumber, SealedBlock, SealedHeader, TxNumber, B256}; + use reth_provider::ProviderFactory; use std::{ collections::{HashMap, VecDeque}, ops::RangeInclusive, @@ -505,20 +528,14 @@ mod tests { /// A helper struct for running the [BodyStage]. 
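One detail of the new `unwind` above deserves emphasis: `self.buffer.take()` drops any buffered responses before unwinding, since a batch reserved by `poll_execute_ready` may cover a range the unwind is about to roll back. A toy model of that hygiene, with illustrative types:

```rust
// Toy model of the buffer hygiene in BodyStage::unwind; types are stand-ins.
struct Buffered {
    buffer: Option<Vec<u64>>,
}

impl Buffered {
    fn unwind(&mut self) {
        // Drop any batch reserved by a prior readiness check: it may
        // describe blocks that this unwind is rolling back.
        self.buffer.take();
    }

    fn execute(&mut self) -> Result<usize, &'static str> {
        self.buffer.take().map(|b| b.len()).ok_or("missing download buffer")
    }
}

fn main() {
    let mut stage = Buffered { buffer: Some(vec![7, 8]) };
    stage.unwind();
    // The next execute must go through readiness again before it can run.
    assert!(stage.execute().is_err());
}
```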
pub(crate) struct BodyTestRunner { - pub(crate) consensus: Arc, responses: HashMap, - tx: TestTransaction, + db: TestStageDB, batch_size: u64, } impl Default for BodyTestRunner { fn default() -> Self { - Self { - consensus: Arc::new(TestConsensus::default()), - responses: HashMap::default(), - tx: TestTransaction::default(), - batch_size: 1000, - } + Self { responses: HashMap::default(), db: TestStageDB::default(), batch_size: 1000 } } } @@ -535,19 +552,16 @@ mod tests { impl StageTestRunner for BodyTestRunner { type S = BodyStage; - fn tx(&self) -> &TestTransaction { - &self.tx + fn db(&self) -> &TestStageDB { + &self.db } fn stage(&self) -> Self::S { - BodyStage { - downloader: TestBodyDownloader::new( - self.tx.inner_raw(), - self.responses.clone(), - self.batch_size, - ), - consensus: self.consensus.clone(), - } + BodyStage::new(TestBodyDownloader::new( + self.db.factory.clone(), + self.responses.clone(), + self.batch_size, + )) } } @@ -560,10 +574,10 @@ mod tests { let end = input.target(); let mut rng = generators::rng(); let blocks = random_block_range(&mut rng, start..=end, GENESIS_HASH, 0..2); - self.tx.insert_headers_with_td(blocks.iter().map(|block| &block.header))?; + self.db.insert_headers_with_td(blocks.iter().map(|block| &block.header))?; if let Some(progress) = blocks.first() { // Insert last progress data - self.tx.commit(|tx| { + self.db.commit(|tx| { let body = StoredBlockBodyIndices { first_tx_num: 0, tx_count: progress.body.len() as u64, @@ -611,16 +625,16 @@ mod tests { impl UnwindStageTestRunner for BodyTestRunner { fn validate_unwind(&self, input: UnwindInput) -> Result<(), TestRunnerError> { - self.tx.ensure_no_entry_above::( + self.db.ensure_no_entry_above::( input.unwind_to, |key| key, )?; - self.tx + self.db .ensure_no_entry_above::(input.unwind_to, |key| key)?; if let Some(last_tx_id) = self.get_last_tx_id()? { - self.tx + self.db .ensure_no_entry_above::(last_tx_id, |key| key)?; - self.tx.ensure_no_entry_above::( + self.db.ensure_no_entry_above::( last_tx_id, |key| key, )?; @@ -632,7 +646,7 @@ mod tests { impl BodyTestRunner { /// Get the last available tx id if any pub(crate) fn get_last_tx_id(&self) -> Result, TestRunnerError> { - let last_body = self.tx.query(|tx| { + let last_body = self.db.query(|tx| { let v = tx.cursor_read::()?.last()?; Ok(v) })?; @@ -650,7 +664,7 @@ mod tests { prev_progress: BlockNumber, highest_block: BlockNumber, ) -> Result<(), TestRunnerError> { - self.tx.query(|tx| { + self.db.query(|tx| { // Acquire cursors on body related tables let mut headers_cursor = tx.cursor_read::()?; let mut bodies_cursor = tx.cursor_read::()?; @@ -741,7 +755,7 @@ mod tests { /// A [BodyDownloader] that is backed by an internal [HashMap] for testing. 
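The `TestBodyDownloader` defined next now reads headers through a `ProviderFactory` instead of a raw database handle. As a dependency-free miniature of the range join it performs (the `CanonicalHeaders` number-to-hash index against the `Headers` table, pairing each header with its hash in the "sealing" step), with `BTreeMap`s standing in for the real tables:

```rust
use std::collections::BTreeMap;

// Miniature of the join in TestBodyDownloader::set_download_range below.
// All types here are stand-ins for the real reth tables.
fn collect_sealed(
    canonical: &BTreeMap<u64, [u8; 32]>, // CanonicalHeaders: number -> hash
    headers: &BTreeMap<u64, String>,     // Headers: number -> header body
    range: std::ops::RangeInclusive<u64>,
) -> Vec<(String, [u8; 32])> {
    canonical
        .range(range)
        .map(|(number, hash)| {
            let header = headers.get(number).expect("missing header").clone();
            (header, *hash) // a "sealed" header is the header plus its hash
        })
        .collect()
}

fn main() {
    let canonical = BTreeMap::from([(1, [1; 32]), (2, [2; 32])]);
    let headers = BTreeMap::from([(1, "h1".to_string()), (2, "h2".to_string())]);
    let sealed = collect_sealed(&canonical, &headers, 1..=2);
    assert_eq!(sealed.len(), 2);
    assert_eq!(sealed[0], ("h1".to_string(), [1; 32]));
}
```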
#[derive(Debug)] pub(crate) struct TestBodyDownloader { - db: Arc>, + provider_factory: ProviderFactory>>, responses: HashMap, headers: VecDeque, batch_size: u64, @@ -749,11 +763,11 @@ mod tests { impl TestBodyDownloader { pub(crate) fn new( - db: Arc>, + provider_factory: ProviderFactory>>, responses: HashMap, batch_size: u64, ) -> Self { - Self { db, responses, headers: VecDeque::default(), batch_size } + Self { provider_factory, responses, headers: VecDeque::default(), batch_size } } } @@ -762,22 +776,19 @@ mod tests { &mut self, range: RangeInclusive, ) -> DownloadResult<()> { - self.headers = - VecDeque::from(self.db.view(|tx| -> DownloadResult> { - let mut header_cursor = tx.cursor_read::()?; - - let mut canonical_cursor = tx.cursor_read::()?; - let walker = canonical_cursor.walk_range(range)?; - - let mut headers = Vec::default(); - for entry in walker { - let (num, hash) = entry?; - let (_, header) = - header_cursor.seek_exact(num)?.expect("missing header"); - headers.push(header.seal(hash)); - } - Ok(headers) - })??); + let provider = self.provider_factory.provider()?; + let mut header_cursor = provider.tx_ref().cursor_read::()?; + + let mut canonical_cursor = + provider.tx_ref().cursor_read::()?; + let walker = canonical_cursor.walk_range(range)?; + + for entry in walker { + let (num, hash) = entry?; + let (_, header) = header_cursor.seek_exact(num)?.expect("missing header"); + self.headers.push_back(header.seal(hash)); + } + Ok(()) } } diff --git a/crates/stages/src/stages/execution.rs b/crates/stages/src/stages/execution.rs index a53bef070211..41a26165c9bb 100644 --- a/crates/stages/src/stages/execution.rs +++ b/crates/stages/src/stages/execution.rs @@ -110,7 +110,7 @@ impl ExecutionStage { /// Execute the stage. pub fn execute_inner( &mut self, - provider: &DatabaseProviderRW<'_, &DB>, + provider: &DatabaseProviderRW, input: ExecInput, ) -> Result { if input.target_reached() { @@ -228,7 +228,7 @@ impl ExecutionStage { /// been previously executed. fn adjust_prune_modes( &self, - provider: &DatabaseProviderRW<'_, &DB>, + provider: &DatabaseProviderRW, start_block: u64, max_block: u64, ) -> Result { @@ -247,7 +247,7 @@ impl ExecutionStage { } fn execution_checkpoint( - provider: &DatabaseProviderRW<'_, &DB>, + provider: &DatabaseProviderRW, start_block: BlockNumber, max_block: BlockNumber, checkpoint: StageCheckpoint, @@ -314,7 +314,7 @@ fn execution_checkpoint( } fn calculate_gas_used_from_headers( - provider: &DatabaseProviderRW<'_, &DB>, + provider: &DatabaseProviderRW, range: RangeInclusive, ) -> Result { let mut gas_total = 0; @@ -331,7 +331,6 @@ fn calculate_gas_used_from_headers( Ok(gas_total) } -#[async_trait::async_trait] impl Stage for ExecutionStage { /// Return the id of the stage fn id(&self) -> StageId { @@ -339,18 +338,18 @@ impl Stage for ExecutionStage { } /// Execute the stage - async fn execute( + fn execute( &mut self, - provider: &DatabaseProviderRW<'_, &DB>, + provider: &DatabaseProviderRW, input: ExecInput, ) -> Result { self.execute_inner(provider, input) } /// Unwind the stage. 
- async fn unwind( + fn unwind( &mut self, - provider: &DatabaseProviderRW<'_, &DB>, + provider: &DatabaseProviderRW, input: UnwindInput, ) -> Result { let tx = provider.tx_ref(); @@ -492,7 +491,7 @@ impl ExecutionStageThresholds { #[cfg(test)] mod tests { use super::*; - use crate::test_utils::TestTransaction; + use crate::test_utils::TestStageDB; use alloy_rlp::Decodable; use assert_matches::assert_matches; use reth_db::{models::AccountBeforeTx, test_utils::create_test_rw_db}; @@ -501,14 +500,15 @@ mod tests { ChainSpecBuilder, PruneModes, SealedBlock, StorageEntry, B256, MAINNET, U256, }; use reth_provider::{AccountReader, BlockWriter, ProviderFactory, ReceiptProvider}; - use reth_revm::Factory; + use reth_revm::EvmProcessorFactory; use std::sync::Arc; - fn stage() -> ExecutionStage { - let factory = - Factory::new(Arc::new(ChainSpecBuilder::mainnet().berlin_activated().build())); + fn stage() -> ExecutionStage { + let executor_factory = EvmProcessorFactory::new(Arc::new( + ChainSpecBuilder::mainnet().berlin_activated().build(), + )); ExecutionStage::new( - factory, + executor_factory, ExecutionStageThresholds { max_blocks: Some(100), max_changes: None, @@ -685,8 +685,8 @@ mod tests { provider.commit().unwrap(); let provider = factory.provider_rw().unwrap(); - let mut execution_stage = stage(); - let output = execution_stage.execute(&provider, input).await.unwrap(); + let mut execution_stage: ExecutionStage = stage(); + let output = execution_stage.execute(&provider, input).unwrap(); provider.commit().unwrap(); assert_matches!(output, ExecOutput { checkpoint: StageCheckpoint { @@ -787,7 +787,7 @@ mod tests { // execute let provider = factory.provider_rw().unwrap(); let mut execution_stage = stage(); - let result = execution_stage.execute(&provider, input).await.unwrap(); + let result = execution_stage.execute(&provider, input).unwrap(); provider.commit().unwrap(); let provider = factory.provider_rw().unwrap(); @@ -797,7 +797,6 @@ mod tests { &provider, UnwindInput { checkpoint: result.checkpoint, unwind_to: 0, bad_block: None }, ) - .await .unwrap(); assert_matches!(result, UnwindOutput { @@ -828,9 +827,8 @@ mod tests { #[tokio::test] async fn test_selfdestruct() { - let test_tx = TestTransaction::default(); - let factory = ProviderFactory::new(test_tx.tx.as_ref(), MAINNET.clone()); - let provider = factory.provider_rw().unwrap(); + let test_db = TestStageDB::default(); + let provider = test_db.factory.provider_rw().unwrap(); let input = ExecInput { target: Some(1), checkpoint: None }; let mut genesis_rlp = 
hex!("f901f8f901f3a00000000000000000000000000000000000000000000000000000000000000000a01dcc4de8dec75d7aab85b567b6ccd41ad312451b948a7413f0a142fd40d49347942adc25665018aa1fe0e6bc666dac8fc2697ff9baa0c9ceb8372c88cb461724d8d3d87e8b933f6fc5f679d4841800e662f4428ffd0da056e81f171bcc55a6ff8345e692c0f86e5b48e01b996cadc001622fb5e363b421a056e81f171bcc55a6ff8345e692c0f86e5b48e01b996cadc001622fb5e363b421b90100000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000008302000080830f4240808000a00000000000000000000000000000000000000000000000000000000000000000880000000000000000c0c0").as_slice(); let genesis = SealedBlock::decode(&mut genesis_rlp).unwrap(); @@ -855,7 +853,7 @@ mod tests { Account { nonce: 0, balance: U256::ZERO, bytecode_hash: Some(code_hash) }; // set account - let provider = factory.provider_rw().unwrap(); + let provider = test_db.factory.provider_rw().unwrap(); provider.tx_ref().put::(caller_address, caller_info).unwrap(); provider .tx_ref() @@ -884,13 +882,13 @@ mod tests { provider.commit().unwrap(); // execute - let provider = factory.provider_rw().unwrap(); + let provider = test_db.factory.provider_rw().unwrap(); let mut execution_stage = stage(); - let _ = execution_stage.execute(&provider, input).await.unwrap(); + let _ = execution_stage.execute(&provider, input).unwrap(); provider.commit().unwrap(); // assert unwind stage - let provider = factory.provider_rw().unwrap(); + let provider = test_db.factory.provider_rw().unwrap(); assert_eq!(provider.basic_account(destroyed_address), Ok(None), "Account was destroyed"); assert_eq!( @@ -900,8 +898,8 @@ mod tests { ); // drops tx so that it returns write privilege to test_tx drop(provider); - let plain_accounts = test_tx.table::().unwrap(); - let plain_storage = test_tx.table::().unwrap(); + let plain_accounts = test_db.table::().unwrap(); + let plain_storage = test_db.table::().unwrap(); assert_eq!( plain_accounts, @@ -926,8 +924,8 @@ mod tests { ); assert!(plain_storage.is_empty()); - let account_changesets = test_tx.table::().unwrap(); - let storage_changesets = test_tx.table::().unwrap(); + let account_changesets = test_db.table::().unwrap(); + let storage_changesets = test_db.table::().unwrap(); assert_eq!( account_changesets, diff --git a/crates/stages/src/stages/finish.rs b/crates/stages/src/stages/finish.rs index 751c4e37bfe1..341be77dd1e1 100644 --- a/crates/stages/src/stages/finish.rs +++ b/crates/stages/src/stages/finish.rs @@ -11,23 +11,22 @@ use reth_provider::DatabaseProviderRW; #[non_exhaustive] pub struct FinishStage; -#[async_trait::async_trait] impl Stage for FinishStage { fn id(&self) -> StageId { StageId::Finish } - async fn execute( + fn execute( &mut self, - _provider: &DatabaseProviderRW<'_, &DB>, + _provider: &DatabaseProviderRW, input: ExecInput, ) -> Result { Ok(ExecOutput { checkpoint: StageCheckpoint::new(input.target()), done: true }) } - async fn unwind( + fn unwind( &mut self, - _provider: &DatabaseProviderRW<'_, &DB>, + _provider: &DatabaseProviderRW, input: UnwindInput, ) -> Result { Ok(UnwindOutput { checkpoint: StageCheckpoint::new(input.unwind_to) }) @@ -39,7 
+38,7 @@ mod tests { use super::*; use crate::test_utils::{ stage_test_suite_ext, ExecuteStageTestRunner, StageTestRunner, TestRunnerError, - TestTransaction, UnwindStageTestRunner, + TestStageDB, UnwindStageTestRunner, }; use reth_interfaces::test_utils::{ generators, @@ -51,14 +50,14 @@ mod tests { #[derive(Default)] struct FinishTestRunner { - tx: TestTransaction, + db: TestStageDB, } impl StageTestRunner for FinishTestRunner { type S = FinishStage; - fn tx(&self) -> &TestTransaction { - &self.tx + fn db(&self) -> &TestStageDB { + &self.db } fn stage(&self) -> Self::S { @@ -73,7 +72,7 @@ mod tests { let start = input.checkpoint().block_number; let mut rng = generators::rng(); let head = random_header(&mut rng, start, None); - self.tx.insert_headers_with_td(std::iter::once(&head))?; + self.db.insert_headers_with_td(std::iter::once(&head))?; // use previous progress as seed size let end = input.target.unwrap_or_default() + 1; @@ -83,7 +82,7 @@ mod tests { } let mut headers = random_header_range(&mut rng, start + 1..end, head.hash()); - self.tx.insert_headers_with_td(headers.iter())?; + self.db.insert_headers_with_td(headers.iter())?; headers.insert(0, head); Ok(headers) } diff --git a/crates/stages/src/stages/hashing_account.rs b/crates/stages/src/stages/hashing_account.rs index 896bfc9762b1..308cfa71ea27 100644 --- a/crates/stages/src/stages/hashing_account.rs +++ b/crates/stages/src/stages/hashing_account.rs @@ -21,8 +21,8 @@ use std::{ cmp::max, fmt::Debug, ops::{Range, RangeInclusive}, + sync::mpsc, }; -use tokio::sync::mpsc; use tracing::*; /// Account hashing stage hashes plain account. @@ -79,7 +79,7 @@ impl AccountHashingStage { /// Proceeds to go to the `BlockTransitionIndex` end, go back `transitions` and change the /// account state in the `AccountChangeSet` table. pub fn seed( - provider: &DatabaseProviderRW<'_, DB>, + provider: &DatabaseProviderRW, opts: SeedOpts, ) -> Result, StageError> { use reth_db::models::AccountBeforeTx; @@ -125,7 +125,6 @@ impl AccountHashingStage { } } -#[async_trait::async_trait] impl Stage for AccountHashingStage { /// Return the id of the stage fn id(&self) -> StageId { @@ -133,9 +132,9 @@ impl Stage for AccountHashingStage { } /// Execute the stage. - async fn execute( + fn execute( &mut self, - provider: &DatabaseProviderRW<'_, &DB>, + provider: &DatabaseProviderRW, input: ExecInput, ) -> Result { if input.target_reached() { @@ -190,7 +189,7 @@ impl Stage for AccountHashingStage { ) { // An _unordered_ channel to receive results from a rayon job - let (tx, rx) = mpsc::unbounded_channel(); + let (tx, rx) = mpsc::channel(); channels.push(rx); let chunk = chunk.collect::, _>>()?; @@ -205,8 +204,8 @@ impl Stage for AccountHashingStage { let mut hashed_batch = Vec::with_capacity(self.commit_threshold as usize); // Iterate over channels and append the hashed accounts. - for mut channel in channels { - while let Some(hashed) = channel.recv().await { + for channel in channels { + while let Ok(hashed) = channel.recv() { hashed_batch.push(hashed); } } @@ -265,9 +264,9 @@ impl Stage for AccountHashingStage { } /// Unwind the stage. - async fn unwind( + fn unwind( &mut self, - provider: &DatabaseProviderRW<'_, &DB>, + provider: &DatabaseProviderRW, input: UnwindInput, ) -> Result { let (range, unwind_progress, _) = @@ -289,7 +288,7 @@ impl Stage for AccountHashingStage { } fn stage_checkpoint_progress( - provider: &DatabaseProviderRW<'_, &DB>, + provider: &DatabaseProviderRW, ) -> Result { Ok(EntitiesCheckpoint { processed: provider.tx_ref().entries::()? 
as u64, @@ -342,7 +341,7 @@ mod tests { done: true, }) if block_number == previous_stage && processed == total && - total == runner.tx.table::().unwrap().len() as u64 + total == runner.db.table::().unwrap().len() as u64 ); // Validate the stage execution @@ -369,7 +368,7 @@ mod tests { let result = rx.await.unwrap(); let fifth_address = runner - .tx + .db .query(|tx| { let (address, _) = tx .cursor_read::()? @@ -399,9 +398,9 @@ mod tests { }, done: false }) if address == fifth_address && - total == runner.tx.table::().unwrap().len() as u64 + total == runner.db.table::().unwrap().len() as u64 ); - assert_eq!(runner.tx.table::().unwrap().len(), 5); + assert_eq!(runner.db.table::().unwrap().len(), 5); // second run, hash next five accounts. input.checkpoint = Some(result.unwrap().checkpoint); @@ -426,9 +425,9 @@ mod tests { }, done: true }) if processed == total && - total == runner.tx.table::().unwrap().len() as u64 + total == runner.db.table::().unwrap().len() as u64 ); - assert_eq!(runner.tx.table::().unwrap().len(), 10); + assert_eq!(runner.db.table::().unwrap().len(), 10); // Validate the stage execution assert!(runner.validate_execution(input, result.ok()).is_ok(), "execution validation"); @@ -438,14 +437,14 @@ mod tests { use super::*; use crate::{ stages::hashing_account::AccountHashingStage, - test_utils::{StageTestRunner, TestTransaction}, + test_utils::{StageTestRunner, TestStageDB}, ExecInput, ExecOutput, UnwindInput, }; use reth_db::{cursor::DbCursorRO, tables, transaction::DbTx}; use reth_primitives::Address; pub(crate) struct AccountHashingTestRunner { - pub(crate) tx: TestTransaction, + pub(crate) db: TestStageDB, commit_threshold: u64, clean_threshold: u64, } @@ -463,7 +462,7 @@ mod tests { /// Iterates over PlainAccount table and checks that the accounts match the ones /// in the HashedAccount table pub(crate) fn check_hashed_accounts(&self) -> Result<(), TestRunnerError> { - self.tx.query(|tx| { + self.db.query(|tx| { let mut acc_cursor = tx.cursor_read::()?; let mut hashed_acc_cursor = tx.cursor_read::()?; @@ -482,7 +481,7 @@ mod tests { /// Same as check_hashed_accounts, only that checks with the old account state, /// namely, the same account with nonce - 1 and balance - 1. 
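The `AccountHashingStage` hunks above swap tokio's unbounded channel for `std::sync::mpsc` when collecting results from rayon workers, which is what lets `execute` become synchronous. A self-contained sketch of that fan-out/fan-in shape; the multiplication below is only a stand-in for the real keccak hashing of account addresses:

```rust
use std::sync::mpsc;

// Fan work out to rayon's thread pool, one channel per chunk, then drain
// each channel with a blocking recv() instead of awaiting a tokio channel.
fn hash_in_parallel(chunks: Vec<Vec<u64>>) -> Vec<u64> {
    let mut channels = Vec::with_capacity(chunks.len());
    for chunk in chunks {
        let (tx, rx) = mpsc::channel();
        channels.push(rx);
        // rayon::spawn runs the closure on the global thread pool.
        rayon::spawn(move || {
            for item in chunk {
                // Stand-in for hashing an account address.
                let _ = tx.send(item.wrapping_mul(31));
            }
            // Dropping `tx` closes the channel, ending the recv loop below.
        });
    }
    let mut out = Vec::new();
    for rx in channels {
        // recv() blocks until a result arrives; Err means the worker is done.
        while let Ok(hashed) = rx.recv() {
            out.push(hashed);
        }
    }
    out
}

fn main() {
    let hashed = hash_in_parallel(vec![vec![1, 2], vec![3]]);
    assert_eq!(hashed.len(), 3);
}
```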
pub(crate) fn check_old_hashed_accounts(&self) -> Result<(), TestRunnerError> { - self.tx.query(|tx| { + self.db.query(|tx| { let mut acc_cursor = tx.cursor_read::()?; let mut hashed_acc_cursor = tx.cursor_read::()?; @@ -507,19 +506,15 @@ mod tests { impl Default for AccountHashingTestRunner { fn default() -> Self { - Self { - tx: TestTransaction::default(), - commit_threshold: 1000, - clean_threshold: 1000, - } + Self { db: TestStageDB::default(), commit_threshold: 1000, clean_threshold: 1000 } } } impl StageTestRunner for AccountHashingTestRunner { type S = AccountHashingStage; - fn tx(&self) -> &TestTransaction { - &self.tx + fn db(&self) -> &TestStageDB { + &self.db } fn stage(&self) -> Self::S { @@ -535,7 +530,7 @@ mod tests { type Seed = Vec<(Address, Account)>; fn seed_execution(&mut self, input: ExecInput) -> Result { - let provider = self.tx.inner_rw(); + let provider = self.db.factory.provider_rw()?; let res = Ok(AccountHashingStage::seed( &provider, SeedOpts { blocks: 1..=input.target(), accounts: 0..10, txs: 0..3 }, diff --git a/crates/stages/src/stages/hashing_storage.rs b/crates/stages/src/stages/hashing_storage.rs index 2580b58c9783..d508846a43c4 100644 --- a/crates/stages/src/stages/hashing_storage.rs +++ b/crates/stages/src/stages/hashing_storage.rs @@ -44,7 +44,6 @@ impl Default for StorageHashingStage { } } -#[async_trait::async_trait] impl Stage for StorageHashingStage { /// Return the id of the stage fn id(&self) -> StageId { @@ -52,9 +51,9 @@ impl Stage for StorageHashingStage { } /// Execute the stage. - async fn execute( + fn execute( &mut self, - provider: &DatabaseProviderRW<'_, &DB>, + provider: &DatabaseProviderRW, input: ExecInput, ) -> Result { let tx = provider.tx_ref(); @@ -191,9 +190,9 @@ impl Stage for StorageHashingStage { } /// Unwind the stage. - async fn unwind( + fn unwind( &mut self, - provider: &DatabaseProviderRW<'_, &DB>, + provider: &DatabaseProviderRW, input: UnwindInput, ) -> Result { let (range, unwind_progress, _) = @@ -214,7 +213,7 @@ impl Stage for StorageHashingStage { } fn stage_checkpoint_progress( - provider: &DatabaseProviderRW<'_, &DB>, + provider: &DatabaseProviderRW, ) -> Result { Ok(EntitiesCheckpoint { processed: provider.tx_ref().entries::()? as u64, @@ -227,7 +226,7 @@ mod tests { use super::*; use crate::test_utils::{ stage_test_suite_ext, ExecuteStageTestRunner, StageTestRunner, TestRunnerError, - TestTransaction, UnwindStageTestRunner, + TestStageDB, UnwindStageTestRunner, }; use assert_matches::assert_matches; use rand::Rng; @@ -283,7 +282,7 @@ mod tests { }, .. }) if processed == previous_checkpoint.progress.processed + 1 && - total == runner.tx.table::().unwrap().len() as u64); + total == runner.db.table::().unwrap().len() as u64); // Continue from checkpoint input.checkpoint = Some(checkpoint); @@ -297,7 +296,7 @@ mod tests { }, .. }) if processed == total && - total == runner.tx.table::().unwrap().len() as u64); + total == runner.db.table::().unwrap().len() as u64); // Validate the stage execution assert!( @@ -332,7 +331,7 @@ mod tests { let result = rx.await.unwrap(); let (progress_address, progress_key) = runner - .tx + .db .query(|tx| { let (address, entry) = tx .cursor_read::()? 
@@ -364,9 +363,9 @@ mod tests { }, done: false }) if address == progress_address && storage == progress_key && - total == runner.tx.table::().unwrap().len() as u64 + total == runner.db.table::().unwrap().len() as u64 ); - assert_eq!(runner.tx.table::().unwrap().len(), 500); + assert_eq!(runner.db.table::().unwrap().len(), 500); // second run with commit threshold of 2 to check if subkey is set. runner.set_commit_threshold(2); @@ -376,7 +375,7 @@ mod tests { let result = rx.await.unwrap(); let (progress_address, progress_key) = runner - .tx + .db .query(|tx| { let (address, entry) = tx .cursor_read::()? @@ -410,9 +409,9 @@ mod tests { }, done: false }) if address == progress_address && storage == progress_key && - total == runner.tx.table::().unwrap().len() as u64 + total == runner.db.table::().unwrap().len() as u64 ); - assert_eq!(runner.tx.table::().unwrap().len(), 502); + assert_eq!(runner.db.table::().unwrap().len(), 502); // third last run, hash rest of storages. runner.set_commit_threshold(1000); @@ -442,11 +441,11 @@ mod tests { }, done: true }) if processed == total && - total == runner.tx.table::().unwrap().len() as u64 + total == runner.db.table::().unwrap().len() as u64 ); assert_eq!( - runner.tx.table::().unwrap().len(), - runner.tx.table::().unwrap().len() + runner.db.table::().unwrap().len(), + runner.db.table::().unwrap().len() ); // Validate the stage execution @@ -454,22 +453,22 @@ mod tests { } struct StorageHashingTestRunner { - tx: TestTransaction, + db: TestStageDB, commit_threshold: u64, clean_threshold: u64, } impl Default for StorageHashingTestRunner { fn default() -> Self { - Self { tx: TestTransaction::default(), commit_threshold: 1000, clean_threshold: 1000 } + Self { db: TestStageDB::default(), commit_threshold: 1000, clean_threshold: 1000 } } } impl StageTestRunner for StorageHashingTestRunner { type S = StorageHashingStage; - fn tx(&self) -> &TestTransaction { - &self.tx + fn db(&self) -> &TestStageDB { + &self.db } fn stage(&self) -> Self::S { @@ -494,7 +493,7 @@ mod tests { let blocks = random_block_range(&mut rng, stage_progress..=end, B256::ZERO, 0..3); - self.tx.insert_headers(blocks.iter().map(|block| &block.header))?; + self.db.insert_headers(blocks.iter().map(|block| &block.header))?; let iter = blocks.iter(); let mut next_tx_num = 0; @@ -502,7 +501,7 @@ mod tests { for progress in iter { // Insert last progress data let block_number = progress.number; - self.tx.commit(|tx| { + self.db.commit(|tx| { progress.body.iter().try_for_each( |transaction| -> Result<(), reth_db::DatabaseError> { tx.put::(transaction.hash(), next_tx_num)?; @@ -553,7 +552,8 @@ mod tests { first_tx_num = next_tx_num; - tx.put::(progress.number, body) + tx.put::(progress.number, body)?; + Ok(()) })?; } @@ -593,7 +593,7 @@ mod tests { } fn check_hashed_storage(&self) -> Result<(), TestRunnerError> { - self.tx + self.db .query(|tx| { let mut storage_cursor = tx.cursor_dup_read::()?; let mut hashed_storage_cursor = @@ -662,7 +662,7 @@ mod tests { fn unwind_storage(&self, input: UnwindInput) -> Result<(), TestRunnerError> { tracing::debug!("unwinding storage..."); let target_block = input.unwind_to; - self.tx.commit(|tx| { + self.db.commit(|tx| { let mut storage_cursor = tx.cursor_dup_write::()?; let mut changeset_cursor = tx.cursor_dup_read::()?; diff --git a/crates/stages/src/stages/headers.rs b/crates/stages/src/stages/headers.rs index e57b736d61e6..b57fcd279df9 100644 --- a/crates/stages/src/stages/headers.rs +++ b/crates/stages/src/stages/headers.rs @@ -7,33 +7,19 @@ use reth_db::{ 
transaction::{DbTx, DbTxMut}, }; use reth_interfaces::{ - p2p::headers::{ - downloader::{HeaderDownloader, SyncTarget}, - error::HeadersDownloaderError, - }, + p2p::headers::{downloader::HeaderDownloader, error::HeadersDownloaderError}, provider::ProviderError, }; use reth_primitives::{ stage::{ CheckpointBlockRange, EntitiesCheckpoint, HeadersCheckpoint, StageCheckpoint, StageId, }, - BlockHashOrNumber, BlockNumber, SealedHeader, B256, + BlockHashOrNumber, BlockNumber, SealedHeader, }; -use reth_provider::DatabaseProviderRW; -use tokio::sync::watch; +use reth_provider::{DatabaseProviderRW, HeaderSyncGap, HeaderSyncGapProvider, HeaderSyncMode}; +use std::task::{ready, Context, Poll}; use tracing::*; -/// The header sync mode. -#[derive(Debug)] -pub enum HeaderSyncMode { - /// A sync mode in which the stage continuously requests the downloader for - /// next blocks. - Continuous, - /// A sync mode in which the stage polls the receiver for the next tip - /// to download from. - Tip(watch::Receiver), -} - /// The headers stage. /// /// The headers stage downloads all block headers from the highest block in the local database to @@ -48,27 +34,33 @@ pub enum HeaderSyncMode { /// NOTE: This stage downloads headers in reverse. Upon returning the control flow to the pipeline, /// the stage checkpoint is not updated until this stage is done. #[derive(Debug)] -pub struct HeaderStage { +pub struct HeaderStage { + /// Database handle. + provider: Provider, /// Strategy for downloading the headers - downloader: D, + downloader: Downloader, /// The sync mode for the stage. mode: HeaderSyncMode, + /// Current sync gap. + sync_gap: Option, + /// Header buffer. + buffer: Option>, } // === impl HeaderStage === -impl HeaderStage +impl HeaderStage where - D: HeaderDownloader, + Downloader: HeaderDownloader, { /// Create a new header stage - pub fn new(downloader: D, mode: HeaderSyncMode) -> Self { - Self { downloader, mode } + pub fn new(database: Provider, downloader: Downloader, mode: HeaderSyncMode) -> Self { + Self { provider: database, downloader, mode, sync_gap: None, buffer: None } } fn is_stage_done( &self, - tx: &>::TXMut, + tx: &::TXMut, checkpoint: u64, ) -> Result { let mut header_cursor = tx.cursor_read::()?; @@ -79,75 +71,12 @@ where Ok(header_cursor.next()?.map(|(next_num, _)| head_num + 1 == next_num).unwrap_or_default()) } - /// Get the head and tip of the range we need to sync - /// - /// See also [SyncTarget] - async fn get_sync_gap( - &mut self, - provider: &DatabaseProviderRW<'_, &DB>, - checkpoint: u64, - ) -> Result { - // Create a cursor over canonical header hashes - let mut cursor = provider.tx_ref().cursor_read::()?; - let mut header_cursor = provider.tx_ref().cursor_read::()?; - - // Get head hash and reposition the cursor - let (head_num, head_hash) = cursor - .seek_exact(checkpoint)? - .ok_or_else(|| ProviderError::HeaderNotFound(checkpoint.into()))?; - - // Construct head - let (_, head) = header_cursor - .seek_exact(head_num)? - .ok_or_else(|| ProviderError::HeaderNotFound(head_num.into()))?; - let local_head = head.seal(head_hash); - - // Look up the next header - let next_header = cursor - .next()? - .map(|(next_num, next_hash)| -> Result { - let (_, next) = header_cursor - .seek_exact(next_num)? - .ok_or_else(|| ProviderError::HeaderNotFound(next_num.into()))?; - Ok(next.seal(next_hash)) - }) - .transpose()?; - - // Decide the tip or error out on invalid input. 
- // If the next element found in the cursor is not the "expected" next block per our current - // checkpoint, then there is a gap in the database and we should start downloading in - // reverse from there. Else, it should use whatever the forkchoice state reports. - let target = match next_header { - Some(header) if checkpoint + 1 != header.number => SyncTarget::Gap(header), - None => self - .next_sync_target(head_num) - .await - .ok_or(StageError::StageCheckpoint(checkpoint))?, - _ => return Err(StageError::StageCheckpoint(checkpoint)), - }; - - Ok(SyncGap { local_head, target }) - } - - async fn next_sync_target(&mut self, head: BlockNumber) -> Option { - match self.mode { - HeaderSyncMode::Tip(ref mut rx) => { - let tip = rx.wait_for(|tip| !tip.is_zero()).await.ok()?; - Some(SyncTarget::Tip(*tip)) - } - HeaderSyncMode::Continuous => { - trace!(target: "sync::stages::headers", head, "No next header found, using continuous sync strategy"); - Some(SyncTarget::TipNum(head + 1)) - } - } - } - /// Write downloaded headers to the given transaction /// /// Note: this writes the headers with rising block numbers. fn write_headers( &self, - tx: &>::TXMut, + tx: &::TXMut, headers: Vec, ) -> Result, StageError> { trace!(target: "sync::stages::headers", len = headers.len(), "writing headers"); @@ -178,10 +107,10 @@ where } } -#[async_trait::async_trait] -impl Stage for HeaderStage +impl Stage for HeaderStage where DB: Database, + Provider: HeaderSyncGapProvider, D: HeaderDownloader, { /// Return the id of the stage @@ -189,20 +118,28 @@ where StageId::Headers } - /// Download the headers in reverse order (falling block numbers) - /// starting from the tip of the chain - async fn execute( + fn poll_execute_ready( &mut self, - provider: &DatabaseProviderRW<'_, &DB>, + cx: &mut Context<'_>, input: ExecInput, - ) -> Result { - let tx = provider.tx_ref(); + ) -> Poll> { let current_checkpoint = input.checkpoint(); + // Return if buffer already has some items. + if self.buffer.is_some() { + // TODO: review + trace!( + target: "sync::stages::headers", + checkpoint = %current_checkpoint.block_number, + "Buffer is not empty" + ); + return Poll::Ready(Ok(())) + } + // Lookup the head and tip of the sync range - let gap = self.get_sync_gap(provider, current_checkpoint.block_number).await?; - let local_head = gap.local_head.number; + let gap = self.provider.sync_gap(self.mode.clone(), current_checkpoint.block_number)?; let tip = gap.target.tip(); + self.sync_gap = Some(gap.clone()); // Nothing to sync if gap.is_closed() { @@ -212,7 +149,7 @@ where target = ?tip, "Target block already reached" ); - return Ok(ExecOutput::done(current_checkpoint)) + return Poll::Ready(Ok(())) } debug!(target: "sync::stages::headers", ?tip, head = ?gap.local_head.hash(), "Commencing sync"); @@ -220,31 +157,45 @@ where // let the downloader know what to sync self.downloader.update_sync_gap(gap.local_head, gap.target); - // The downloader returns the headers in descending order starting from the tip - // down to the local head (latest block in db). - // Task downloader can return `None` only if the response relaying channel was closed. This - // is a fatal error to prevent the pipeline from running forever. 
- let downloaded_headers = match self.downloader.next().await { - Some(Ok(headers)) => headers, + let result = match ready!(self.downloader.poll_next_unpin(cx)) { + Some(Ok(headers)) => { + info!(target: "sync::stages::headers", len = headers.len(), "Received headers"); + self.buffer = Some(headers); + Ok(()) + } Some(Err(HeadersDownloaderError::DetachedHead { local_head, header, error })) => { error!(target: "sync::stages::headers", ?error, "Cannot attach header to head"); - return Err(StageError::DetachedHead { local_head, header, error }) + Err(StageError::DetachedHead { local_head, header, error }) } - None => return Err(StageError::ChannelClosed), + None => Err(StageError::ChannelClosed), }; + Poll::Ready(result) + } - info!(target: "sync::stages::headers", len = downloaded_headers.len(), "Received headers"); + /// Download the headers in reverse order (falling block numbers) + /// starting from the tip of the chain + fn execute( + &mut self, + provider: &DatabaseProviderRW, + input: ExecInput, + ) -> Result { + let current_checkpoint = input.checkpoint(); + let gap = self.sync_gap.clone().ok_or(StageError::MissingSyncGap)?; + if gap.is_closed() { + return Ok(ExecOutput::done(current_checkpoint)) + } + + let local_head = gap.local_head.number; + let tip = gap.target.tip(); + + let downloaded_headers = self.buffer.take().ok_or(StageError::MissingDownloadBuffer)?; let tip_block_number = match tip { // If tip is hash and it equals to the first downloaded header's hash, we can use // the block number of this header as tip. - BlockHashOrNumber::Hash(hash) => downloaded_headers.first().and_then(|header| { - if header.hash == hash { - Some(header.number) - } else { - None - } - }), + BlockHashOrNumber::Hash(hash) => downloaded_headers + .first() + .and_then(|header| (header.hash == hash).then_some(header.number)), // If tip is number, we can just grab it and not resolve using downloaded headers. BlockHashOrNumber::Number(number) => Some(number), }; @@ -254,13 +205,14 @@ where // syncing towards, we need to take into account already synced headers from the database. // It is `None`, if tip didn't change and we're still downloading headers for previously // calculated gap. + let tx = provider.tx_ref(); let target_block_number = if let Some(tip_block_number) = tip_block_number { let local_max_block_number = tx .cursor_read::()? .last()? .map(|(canonical_block, _)| canonical_block); - Some(tip_block_number.max(local_max_block_number.unwrap_or(tip_block_number))) + Some(tip_block_number.max(local_max_block_number.unwrap_or_default())) } else { None }; @@ -278,18 +230,17 @@ where // `target_block_number` is guaranteed to be `Some`, because on the first iteration // we download the header for missing tip and use its block number. _ => { + let target = target_block_number.expect("No downloaded header for tip found"); HeadersCheckpoint { block_range: CheckpointBlockRange { from: input.checkpoint().block_number, - to: target_block_number.expect("No downloaded header for tip found"), + to: target, }, progress: EntitiesCheckpoint { // Set processed to the local head block number + number // of block already filled in the gap. - processed: local_head + - (target_block_number.unwrap_or_default() - - tip_block_number.unwrap_or_default()), - total: target_block_number.expect("No downloaded header for tip found"), + processed: local_head + (target - tip_block_number.unwrap_or_default()), + total: target, }, } } @@ -326,12 +277,14 @@ where } /// Unwind the stage. 
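The sync-gap lookup that used to live in `HeaderStage::get_sync_gap` above, including the `rx.wait_for(|tip| !tip.is_zero())` wait in `next_sync_target`, now sits behind `HeaderSyncGapProvider`. The watch-channel mechanics themselves are unchanged; here is a toy version of the tip wait, with a fixed-size byte array standing in for `B256`:

```rust
use tokio::sync::watch;

// Sketch of the Tip sync mode plumbing: the engine publishes new tips on a
// watch channel and the header stage waits for the first non-zero value.
#[tokio::main]
async fn main() {
    let (tx, mut rx) = watch::channel([0u8; 32]); // stand-in for B256::ZERO
    tokio::spawn(async move {
        tx.send([1u8; 32]).unwrap();
    });
    // wait_for resolves once the predicate holds for the latest value;
    // the watch channel retains it, so there is no startup race.
    let tip = rx.wait_for(|tip| tip.iter().any(|b| *b != 0)).await.unwrap();
    assert_eq!(*tip, [1u8; 32]);
}
```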
- async fn unwind( + fn unwind( &mut self, - provider: &DatabaseProviderRW<'_, &DB>, + provider: &DatabaseProviderRW, input: UnwindInput, ) -> Result { - // TODO: handle bad block + self.buffer.take(); + self.sync_gap.take(); + provider.unwind_table_by_walker::( input.unwind_to + 1, )?; @@ -359,46 +312,22 @@ where } } -/// Represents a gap to sync: from `local_head` to `target` -#[derive(Debug)] -pub struct SyncGap { - /// The local head block. Represents lower bound of sync range. - pub local_head: SealedHeader, - - /// The sync target. Represents upper bound of sync range. - pub target: SyncTarget, -} - -// === impl SyncGap === - -impl SyncGap { - /// Returns `true` if the gap from the head to the target was closed - #[inline] - pub fn is_closed(&self) -> bool { - match self.target.tip() { - BlockHashOrNumber::Hash(hash) => self.local_head.hash() == hash, - BlockHashOrNumber::Number(num) => self.local_head.number == num, - } - } -} - #[cfg(test)] mod tests { - use super::*; use crate::test_utils::{ stage_test_suite, ExecuteStageTestRunner, StageTestRunner, UnwindStageTestRunner, }; use assert_matches::assert_matches; - use rand::Rng; - use reth_interfaces::test_utils::{generators, generators::random_header}; - use reth_primitives::{stage::StageUnitCheckpoint, B256, MAINNET}; + use reth_interfaces::test_utils::generators::random_header; + use reth_primitives::{stage::StageUnitCheckpoint, B256}; use reth_provider::ProviderFactory; use test_runner::HeadersTestRunner; mod test_runner { use super::*; - use crate::test_utils::{TestRunnerError, TestTransaction}; + use crate::test_utils::{TestRunnerError, TestStageDB}; + use reth_db::{test_utils::TempDatabase, DatabaseEnv}; use reth_downloaders::headers::reverse_headers::{ ReverseHeadersDownloader, ReverseHeadersDownloaderBuilder, }; @@ -409,12 +338,13 @@ mod tests { use reth_primitives::U256; use reth_provider::{BlockHashReader, BlockNumReader, HeaderProvider}; use std::sync::Arc; + use tokio::sync::watch; pub(crate) struct HeadersTestRunner { pub(crate) client: TestHeadersClient, channel: (watch::Sender, watch::Receiver), downloader_factory: Box D + Send + Sync + 'static>, - tx: TestTransaction, + db: TestStageDB, } impl Default for HeadersTestRunner { @@ -431,23 +361,24 @@ mod tests { 1000, ) }), - tx: TestTransaction::default(), + db: TestStageDB::default(), } } } impl StageTestRunner for HeadersTestRunner { - type S = HeaderStage; + type S = HeaderStage>>, D>; - fn tx(&self) -> &TestTransaction { - &self.tx + fn db(&self) -> &TestStageDB { + &self.db } fn stage(&self) -> Self::S { - HeaderStage { - mode: HeaderSyncMode::Tip(self.channel.1.clone()), - downloader: (*self.downloader_factory)(), - } + HeaderStage::new( + self.db.factory.clone(), + (*self.downloader_factory)(), + HeaderSyncMode::Tip(self.channel.1.clone()), + ) } } @@ -459,9 +390,10 @@ mod tests { let mut rng = generators::rng(); let start = input.checkpoint().block_number; let head = random_header(&mut rng, start, None); - self.tx.insert_headers(std::iter::once(&head))?; + self.db.insert_headers(std::iter::once(&head))?; // patch td table for `update_head` call - self.tx.commit(|tx| tx.put::(head.number, U256::ZERO.into()))?; + self.db + .commit(|tx| Ok(tx.put::(head.number, U256::ZERO.into())?))?; // use previous checkpoint as seed size let end = input.target.unwrap_or_default() + 1; @@ -484,7 +416,7 @@ mod tests { let initial_checkpoint = input.checkpoint().block_number; match output { Some(output) if output.checkpoint.block_number > initial_checkpoint => { - let provider = 
self.tx.factory.provider()?; + let provider = self.db.factory.provider()?; for block_num in (initial_checkpoint..output.checkpoint.block_number).rev() { // look up the header hash @@ -511,7 +443,7 @@ mod tests { headers.last().unwrap().hash() } else { let tip = random_header(&mut generators::rng(), 0, None); - self.tx.insert_headers(std::iter::once(&tip))?; + self.db.insert_headers(std::iter::once(&tip))?; tip.hash() }; self.send_tip(tip); @@ -536,7 +468,7 @@ mod tests { .stream_batch_size(500) .build(client.clone(), Arc::new(TestConsensus::default())) }), - tx: TestTransaction::default(), + db: TestStageDB::default(), } } } @@ -546,10 +478,10 @@ mod tests { &self, block: BlockNumber, ) -> Result<(), TestRunnerError> { - self.tx + self.db .ensure_no_entry_above_by_value::(block, |val| val)?; - self.tx.ensure_no_entry_above::(block, |key| key)?; - self.tx.ensure_no_entry_above::(block, |key| key)?; + self.db.ensure_no_entry_above::(block, |key| key)?; + self.db.ensure_no_entry_above::(block, |key| key)?; Ok(()) } @@ -599,65 +531,6 @@ mod tests { assert!(runner.validate_execution(input, result.ok()).is_ok(), "validation failed"); } - /// Test the head and tip range lookup - #[tokio::test] - async fn head_and_tip_lookup() { - let runner = HeadersTestRunner::default(); - let factory = ProviderFactory::new(runner.tx().tx.as_ref(), MAINNET.clone()); - let provider = factory.provider_rw().unwrap(); - let tx = provider.tx_ref(); - let mut stage = runner.stage(); - - let mut rng = generators::rng(); - - let consensus_tip = rng.gen(); - runner.send_tip(consensus_tip); - - // Genesis - let checkpoint = 0; - let head = random_header(&mut rng, 0, None); - let gap_fill = random_header(&mut rng, 1, Some(head.hash())); - let gap_tip = random_header(&mut rng, 2, Some(gap_fill.hash())); - - // Empty database - assert_matches!( - stage.get_sync_gap(&provider, checkpoint).await, - Err(StageError::DatabaseIntegrity(ProviderError::HeaderNotFound(block_number))) - if block_number.as_number().unwrap() == checkpoint - ); - - // Checkpoint and no gap - tx.put::(head.number, head.hash()) - .expect("failed to write canonical"); - tx.put::(head.number, head.clone().unseal()) - .expect("failed to write header"); - - let gap = stage.get_sync_gap(&provider, checkpoint).await.unwrap(); - assert_eq!(gap.local_head, head); - assert_eq!(gap.target.tip(), consensus_tip.into()); - - // Checkpoint and gap - tx.put::(gap_tip.number, gap_tip.hash()) - .expect("failed to write canonical"); - tx.put::(gap_tip.number, gap_tip.clone().unseal()) - .expect("failed to write header"); - - let gap = stage.get_sync_gap(&provider, checkpoint).await.unwrap(); - assert_eq!(gap.local_head, head); - assert_eq!(gap.target.tip(), gap_tip.parent_hash.into()); - - // Checkpoint and gap closed - tx.put::(gap_fill.number, gap_fill.hash()) - .expect("failed to write canonical"); - tx.put::(gap_fill.number, gap_fill.clone().unseal()) - .expect("failed to write header"); - - assert_matches!( - stage.get_sync_gap(&provider, checkpoint).await, - Err(StageError::StageCheckpoint(_checkpoint)) if _checkpoint == checkpoint - ); - } - /// Execute the stage in two steps #[tokio::test] async fn execute_from_previous_checkpoint() { diff --git a/crates/stages/src/stages/index_account_history.rs b/crates/stages/src/stages/index_account_history.rs index 0945538c3dcb..355a63a7d5c2 100644 --- a/crates/stages/src/stages/index_account_history.rs +++ b/crates/stages/src/stages/index_account_history.rs @@ -35,7 +35,6 @@ impl Default for IndexAccountHistoryStage { } } 
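For reference, the deleted `head_and_tip_lookup` test exercised a small decision rule: if the next stored canonical header is not `checkpoint + 1`, sync the gap below it in reverse; with no next header, fall back to the forkchoice tip; and a next header at exactly `checkpoint + 1` signals an inconsistent checkpoint. A pure-logic sketch of that rule, where `Target` is an illustrative stand-in for `SyncTarget` and friends:

```rust
// Stand-in for reth's SyncTarget; not the real type.
#[derive(Debug, PartialEq)]
enum Target {
    /// Fill the hole below this already-stored block number.
    Gap(u64),
    /// Download from the tip announced by the consensus engine.
    Tip(u64),
}

fn decide_target(checkpoint: u64, next_stored: Option<u64>, tip: u64) -> Result<Target, &'static str> {
    match next_stored {
        // A stored header above a hole: download in reverse from there.
        Some(n) if n != checkpoint + 1 => Ok(Target::Gap(n)),
        // No stored header past the checkpoint: use the forkchoice tip.
        None => Ok(Target::Tip(tip)),
        // The next header is exactly checkpoint + 1: the checkpoint is stale.
        Some(_) => Err("stage checkpoint inconsistent"),
    }
}

fn main() {
    assert_eq!(decide_target(0, Some(5), 99), Ok(Target::Gap(5)));
    assert_eq!(decide_target(0, None, 99), Ok(Target::Tip(99)));
    assert!(decide_target(0, Some(1), 99).is_err());
}
```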
-#[async_trait::async_trait] impl Stage for IndexAccountHistoryStage { /// Return the id of the stage fn id(&self) -> StageId { @@ -43,9 +42,9 @@ impl Stage for IndexAccountHistoryStage { } /// Execute the stage. - async fn execute( + fn execute( &mut self, - provider: &DatabaseProviderRW<'_, &DB>, + provider: &DatabaseProviderRW, mut input: ExecInput, ) -> Result { if let Some((target_prunable_block, prune_mode)) = self @@ -86,9 +85,9 @@ impl Stage for IndexAccountHistoryStage { } /// Unwind the stage. - async fn unwind( + fn unwind( &mut self, - provider: &DatabaseProviderRW<'_, &DB>, + provider: &DatabaseProviderRW, input: UnwindInput, ) -> Result { let (range, unwind_progress, _) = @@ -106,7 +105,7 @@ mod tests { use super::*; use crate::test_utils::{ stage_test_suite_ext, ExecuteStageTestRunner, StageTestRunner, TestRunnerError, - TestTransaction, UnwindStageTestRunner, + TestStageDB, UnwindStageTestRunner, }; use itertools::Itertools; use reth_db::{ @@ -123,8 +122,7 @@ mod tests { generators, generators::{random_block_range, random_changeset_range, random_contract_account_range}, }; - use reth_primitives::{address, Address, BlockNumber, PruneMode, B256, MAINNET}; - use reth_provider::ProviderFactory; + use reth_primitives::{address, Address, BlockNumber, PruneMode, B256}; use std::collections::BTreeMap; const ADDRESS: Address = address!("0000000000000000000000000000000000000001"); @@ -154,9 +152,9 @@ mod tests { .collect() } - fn partial_setup(tx: &TestTransaction) { + fn partial_setup(db: &TestStageDB) { // setup - tx.commit(|tx| { + db.commit(|tx| { // we just need first and last tx.put::( 0, @@ -178,26 +176,24 @@ mod tests { .unwrap() } - async fn run(tx: &TestTransaction, run_to: u64) { + fn run(db: &TestStageDB, run_to: u64) { let input = ExecInput { target: Some(run_to), ..Default::default() }; let mut stage = IndexAccountHistoryStage::default(); - let factory = ProviderFactory::new(tx.tx.as_ref(), MAINNET.clone()); - let provider = factory.provider_rw().unwrap(); - let out = stage.execute(&provider, input).await.unwrap(); + let provider = db.factory.provider_rw().unwrap(); + let out = stage.execute(&provider, input).unwrap(); assert_eq!(out, ExecOutput { checkpoint: StageCheckpoint::new(5), done: true }); provider.commit().unwrap(); } - async fn unwind(tx: &TestTransaction, unwind_from: u64, unwind_to: u64) { + fn unwind(db: &TestStageDB, unwind_from: u64, unwind_to: u64) { let input = UnwindInput { checkpoint: StageCheckpoint::new(unwind_from), unwind_to, ..Default::default() }; let mut stage = IndexAccountHistoryStage::default(); - let factory = ProviderFactory::new(tx.tx.as_ref(), MAINNET.clone()); - let provider = factory.provider_rw().unwrap(); - let out = stage.unwind(&provider, input).await.unwrap(); + let provider = db.factory.provider_rw().unwrap(); + let out = stage.unwind(&provider, input).unwrap(); assert_eq!(out, UnwindOutput { checkpoint: StageCheckpoint::new(unwind_to) }); provider.commit().unwrap(); } @@ -205,116 +201,116 @@ mod tests { #[tokio::test] async fn insert_index_to_empty() { // init - let tx = TestTransaction::default(); + let db = TestStageDB::default(); // setup - partial_setup(&tx); + partial_setup(&db); // run - run(&tx, 5).await; + run(&db, 5); // verify - let table = cast(tx.table::().unwrap()); + let table = cast(db.table::().unwrap()); assert_eq!(table, BTreeMap::from([(shard(u64::MAX), vec![4, 5])])); // unwind - unwind(&tx, 5, 0).await; + unwind(&db, 5, 0); // verify initial state - let table = tx.table::().unwrap(); + let table = 
db.table::().unwrap(); assert!(table.is_empty()); } #[tokio::test] async fn insert_index_to_not_empty_shard() { // init - let tx = TestTransaction::default(); + let db = TestStageDB::default(); // setup - partial_setup(&tx); - tx.commit(|tx| { + partial_setup(&db); + db.commit(|tx| { tx.put::(shard(u64::MAX), list(&[1, 2, 3])).unwrap(); Ok(()) }) .unwrap(); // run - run(&tx, 5).await; + run(&db, 5); // verify - let table = cast(tx.table::().unwrap()); + let table = cast(db.table::().unwrap()); assert_eq!(table, BTreeMap::from([(shard(u64::MAX), vec![1, 2, 3, 4, 5]),])); // unwind - unwind(&tx, 5, 0).await; + unwind(&db, 5, 0); // verify initial state - let table = cast(tx.table::().unwrap()); + let table = cast(db.table::().unwrap()); assert_eq!(table, BTreeMap::from([(shard(u64::MAX), vec![1, 2, 3]),])); } #[tokio::test] async fn insert_index_to_full_shard() { // init - let tx = TestTransaction::default(); + let db = TestStageDB::default(); let full_list = vec![3; NUM_OF_INDICES_IN_SHARD]; // setup - partial_setup(&tx); - tx.commit(|tx| { + partial_setup(&db); + db.commit(|tx| { tx.put::(shard(u64::MAX), list(&full_list)).unwrap(); Ok(()) }) .unwrap(); // run - run(&tx, 5).await; + run(&db, 5); // verify - let table = cast(tx.table::().unwrap()); + let table = cast(db.table::().unwrap()); assert_eq!( table, BTreeMap::from([(shard(3), full_list.clone()), (shard(u64::MAX), vec![4, 5])]) ); // unwind - unwind(&tx, 5, 0).await; + unwind(&db, 5, 0); // verify initial state - let table = cast(tx.table::().unwrap()); + let table = cast(db.table::().unwrap()); assert_eq!(table, BTreeMap::from([(shard(u64::MAX), full_list)])); } #[tokio::test] async fn insert_index_to_fill_shard() { // init - let tx = TestTransaction::default(); + let db = TestStageDB::default(); let mut close_full_list = vec![1; NUM_OF_INDICES_IN_SHARD - 2]; // setup - partial_setup(&tx); - tx.commit(|tx| { + partial_setup(&db); + db.commit(|tx| { tx.put::(shard(u64::MAX), list(&close_full_list)).unwrap(); Ok(()) }) .unwrap(); // run - run(&tx, 5).await; + run(&db, 5); // verify close_full_list.push(4); close_full_list.push(5); - let table = cast(tx.table::().unwrap()); + let table = cast(db.table::().unwrap()); assert_eq!(table, BTreeMap::from([(shard(u64::MAX), close_full_list.clone()),])); // unwind - unwind(&tx, 5, 0).await; + unwind(&db, 5, 0); // verify initial state close_full_list.pop(); close_full_list.pop(); - let table = cast(tx.table::().unwrap()); + let table = cast(db.table::().unwrap()); assert_eq!(table, BTreeMap::from([(shard(u64::MAX), close_full_list),])); // verify initial state @@ -323,46 +319,46 @@ mod tests { #[tokio::test] async fn insert_index_second_half_shard() { // init - let tx = TestTransaction::default(); + let db = TestStageDB::default(); let mut close_full_list = vec![1; NUM_OF_INDICES_IN_SHARD - 1]; // setup - partial_setup(&tx); - tx.commit(|tx| { + partial_setup(&db); + db.commit(|tx| { tx.put::(shard(u64::MAX), list(&close_full_list)).unwrap(); Ok(()) }) .unwrap(); // run - run(&tx, 5).await; + run(&db, 5); // verify close_full_list.push(4); - let table = cast(tx.table::().unwrap()); + let table = cast(db.table::().unwrap()); assert_eq!( table, BTreeMap::from([(shard(4), close_full_list.clone()), (shard(u64::MAX), vec![5])]) ); // unwind - unwind(&tx, 5, 0).await; + unwind(&db, 5, 0); // verify initial state close_full_list.pop(); - let table = cast(tx.table::().unwrap()); + let table = cast(db.table::().unwrap()); assert_eq!(table, BTreeMap::from([(shard(u64::MAX), close_full_list),])); } 
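These tests encode the sharding invariant for history indices: the open shard keyed by `u64::MAX` absorbs new block numbers, and once it overflows, the older entries close into a shard keyed by their highest contained block. A simplified model of that split, using a small stand-in capacity instead of `NUM_OF_INDICES_IN_SHARD`:

```rust
use std::collections::BTreeMap;

const SHARD_CAPACITY: usize = 3; // stand-in for NUM_OF_INDICES_IN_SHARD

// Simplified model of the shard split: not the real reth implementation.
fn insert_into_shards(shards: &mut BTreeMap<u64, Vec<u64>>, block: u64) {
    let open = shards.entry(u64::MAX).or_default();
    open.push(block);
    if open.len() > SHARD_CAPACITY {
        // Close the full prefix under a key equal to its highest block;
        // the newest entry stays in the open (u64::MAX) shard.
        let newest = open.pop().unwrap();
        let closed = std::mem::replace(open, vec![newest]);
        let key = *closed.last().unwrap();
        shards.insert(key, closed);
    }
}

fn main() {
    let mut shards = BTreeMap::new();
    for block in 1..=5 {
        insert_into_shards(&mut shards, block);
    }
    // Mirrors the shape asserted by insert_index_to_full_shard above.
    assert_eq!(shards[&3], vec![1, 2, 3]);
    assert_eq!(shards[&u64::MAX], vec![4, 5]);
}
```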
#[tokio::test] async fn insert_index_to_third_shard() { // init - let tx = TestTransaction::default(); + let db = TestStageDB::default(); let full_list = vec![1; NUM_OF_INDICES_IN_SHARD]; // setup - partial_setup(&tx); - tx.commit(|tx| { + partial_setup(&db); + db.commit(|tx| { tx.put::(shard(1), list(&full_list)).unwrap(); tx.put::(shard(2), list(&full_list)).unwrap(); tx.put::(shard(u64::MAX), list(&[2, 3])).unwrap(); @@ -370,10 +366,10 @@ mod tests { }) .unwrap(); - run(&tx, 5).await; + run(&db, 5); // verify - let table = cast(tx.table::().unwrap()); + let table = cast(db.table::().unwrap()); assert_eq!( table, BTreeMap::from([ @@ -384,10 +380,10 @@ mod tests { ); // unwind - unwind(&tx, 5, 0).await; + unwind(&db, 5, 0); // verify initial state - let table = cast(tx.table::().unwrap()); + let table = cast(db.table::().unwrap()); assert_eq!( table, BTreeMap::from([ @@ -401,10 +397,10 @@ mod tests { #[tokio::test] async fn insert_index_with_prune_mode() { // init - let tx = TestTransaction::default(); + let db = TestStageDB::default(); // setup - tx.commit(|tx| { + db.commit(|tx| { // we just need first and last tx.put::( 0, @@ -432,43 +428,42 @@ mod tests { prune_mode: Some(PruneMode::Before(36)), ..Default::default() }; - let factory = ProviderFactory::new(tx.tx.as_ref(), MAINNET.clone()); - let provider = factory.provider_rw().unwrap(); - let out = stage.execute(&provider, input).await.unwrap(); + let provider = db.factory.provider_rw().unwrap(); + let out = stage.execute(&provider, input).unwrap(); assert_eq!(out, ExecOutput { checkpoint: StageCheckpoint::new(20000), done: true }); provider.commit().unwrap(); // verify - let table = cast(tx.table::().unwrap()); + let table = cast(db.table::().unwrap()); assert_eq!(table, BTreeMap::from([(shard(u64::MAX), vec![36, 100])])); // unwind - unwind(&tx, 20000, 0).await; + unwind(&db, 20000, 0); // verify initial state - let table = tx.table::().unwrap(); + let table = db.table::().unwrap(); assert!(table.is_empty()); } stage_test_suite_ext!(IndexAccountHistoryTestRunner, index_account_history); struct IndexAccountHistoryTestRunner { - pub(crate) tx: TestTransaction, + pub(crate) db: TestStageDB, commit_threshold: u64, prune_mode: Option, } impl Default for IndexAccountHistoryTestRunner { fn default() -> Self { - Self { tx: TestTransaction::default(), commit_threshold: 1000, prune_mode: None } + Self { db: TestStageDB::default(), commit_threshold: 1000, prune_mode: None } } } impl StageTestRunner for IndexAccountHistoryTestRunner { type S = IndexAccountHistoryStage; - fn tx(&self) -> &TestTransaction { - &self.tx + fn db(&self) -> &TestStageDB { + &self.db } fn stage(&self) -> Self::S { @@ -501,7 +496,7 @@ mod tests { ); // add block changeset from block 1. 
- self.tx.insert_changesets(transitions, Some(start))?; + self.db.insert_changesets(transitions, Some(start))?; Ok(()) } @@ -523,7 +518,7 @@ mod tests { ExecOutput { checkpoint: StageCheckpoint::new(input.target()), done: true } ); - let provider = self.tx.inner(); + let provider = self.db.factory.provider()?; let mut changeset_cursor = provider.tx_ref().cursor_read::()?; @@ -569,7 +564,7 @@ mod tests { }; } - let table = cast(self.tx.table::().unwrap()); + let table = cast(self.db.table::().unwrap()); assert_eq!(table, result); } Ok(()) @@ -578,7 +573,7 @@ mod tests { impl UnwindStageTestRunner for IndexAccountHistoryTestRunner { fn validate_unwind(&self, _input: UnwindInput) -> Result<(), TestRunnerError> { - let table = self.tx.table::().unwrap(); + let table = self.db.table::().unwrap(); assert!(table.is_empty()); Ok(()) } diff --git a/crates/stages/src/stages/index_storage_history.rs b/crates/stages/src/stages/index_storage_history.rs index b1e27aed1809..c189a90c320b 100644 --- a/crates/stages/src/stages/index_storage_history.rs +++ b/crates/stages/src/stages/index_storage_history.rs @@ -34,7 +34,6 @@ impl Default for IndexStorageHistoryStage { } } -#[async_trait::async_trait] impl Stage for IndexStorageHistoryStage { /// Return the id of the stage fn id(&self) -> StageId { @@ -42,9 +41,9 @@ impl Stage for IndexStorageHistoryStage { } /// Execute the stage. - async fn execute( + fn execute( &mut self, - provider: &DatabaseProviderRW<'_, &DB>, + provider: &DatabaseProviderRW, mut input: ExecInput, ) -> Result { if let Some((target_prunable_block, prune_mode)) = self @@ -84,9 +83,9 @@ impl Stage for IndexStorageHistoryStage { } /// Unwind the stage. - async fn unwind( + fn unwind( &mut self, - provider: &DatabaseProviderRW<'_, &DB>, + provider: &DatabaseProviderRW, input: UnwindInput, ) -> Result { let (range, unwind_progress, _) = @@ -103,7 +102,7 @@ mod tests { use super::*; use crate::test_utils::{ stage_test_suite_ext, ExecuteStageTestRunner, StageTestRunner, TestRunnerError, - TestTransaction, UnwindStageTestRunner, + TestStageDB, UnwindStageTestRunner, }; use itertools::Itertools; use reth_db::{ @@ -122,9 +121,8 @@ mod tests { generators::{random_block_range, random_changeset_range, random_contract_account_range}, }; use reth_primitives::{ - address, b256, Address, BlockNumber, PruneMode, StorageEntry, B256, MAINNET, U256, + address, b256, Address, BlockNumber, PruneMode, StorageEntry, B256, U256, }; - use reth_provider::ProviderFactory; use std::collections::BTreeMap; const ADDRESS: Address = address!("0000000000000000000000000000000000000001"); @@ -164,9 +162,9 @@ mod tests { .collect() } - fn partial_setup(tx: &TestTransaction) { + fn partial_setup(db: &TestStageDB) { // setup - tx.commit(|tx| { + db.commit(|tx| { // we just need first and last tx.put::( 0, @@ -188,26 +186,24 @@ mod tests { .unwrap() } - async fn run(tx: &TestTransaction, run_to: u64) { + fn run(db: &TestStageDB, run_to: u64) { let input = ExecInput { target: Some(run_to), ..Default::default() }; let mut stage = IndexStorageHistoryStage::default(); - let factory = ProviderFactory::new(tx.tx.as_ref(), MAINNET.clone()); - let provider = factory.provider_rw().unwrap(); - let out = stage.execute(&provider, input).await.unwrap(); + let provider = db.factory.provider_rw().unwrap(); + let out = stage.execute(&provider, input).unwrap(); assert_eq!(out, ExecOutput { checkpoint: StageCheckpoint::new(5), done: true }); provider.commit().unwrap(); } - async fn unwind(tx: &TestTransaction, unwind_from: u64, unwind_to: u64) 
{ + fn unwind(db: &TestStageDB, unwind_from: u64, unwind_to: u64) { let input = UnwindInput { checkpoint: StageCheckpoint::new(unwind_from), unwind_to, ..Default::default() }; let mut stage = IndexStorageHistoryStage::default(); - let factory = ProviderFactory::new(tx.tx.as_ref(), MAINNET.clone()); - let provider = factory.provider_rw().unwrap(); - let out = stage.unwind(&provider, input).await.unwrap(); + let provider = db.factory.provider_rw().unwrap(); + let out = stage.unwind(&provider, input).unwrap(); assert_eq!(out, UnwindOutput { checkpoint: StageCheckpoint::new(unwind_to) }); provider.commit().unwrap(); } @@ -215,119 +211,119 @@ mod tests { #[tokio::test] async fn insert_index_to_empty() { // init - let tx = TestTransaction::default(); + let db = TestStageDB::default(); // setup - partial_setup(&tx); + partial_setup(&db); // run - run(&tx, 5).await; + run(&db, 5); // verify - let table = cast(tx.table::().unwrap()); + let table = cast(db.table::().unwrap()); assert_eq!(table, BTreeMap::from([(shard(u64::MAX), vec![4, 5]),])); // unwind - unwind(&tx, 5, 0).await; + unwind(&db, 5, 0); // verify initial state - let table = tx.table::().unwrap(); + let table = db.table::().unwrap(); assert!(table.is_empty()); } #[tokio::test] async fn insert_index_to_not_empty_shard() { // init - let tx = TestTransaction::default(); + let db = TestStageDB::default(); // setup - partial_setup(&tx); - tx.commit(|tx| { + partial_setup(&db); + db.commit(|tx| { tx.put::(shard(u64::MAX), list(&[1, 2, 3])).unwrap(); Ok(()) }) .unwrap(); // run - run(&tx, 5).await; + run(&db, 5); // verify - let table = cast(tx.table::().unwrap()); + let table = cast(db.table::().unwrap()); assert_eq!(table, BTreeMap::from([(shard(u64::MAX), vec![1, 2, 3, 4, 5]),])); // unwind - unwind(&tx, 5, 0).await; + unwind(&db, 5, 0); // verify initial state - let table = cast(tx.table::().unwrap()); + let table = cast(db.table::().unwrap()); assert_eq!(table, BTreeMap::from([(shard(u64::MAX), vec![1, 2, 3]),])); } #[tokio::test] async fn insert_index_to_full_shard() { // init - let tx = TestTransaction::default(); + let db = TestStageDB::default(); let _input = ExecInput { target: Some(5), ..Default::default() }; // change does not matter only that account is present in changeset. 
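// Context for the shard(...) expectations in these tests: history indices are
// stored in shards per key. The open-ended, most recent shard is keyed with
// the u64::MAX sentinel so lookups can always find it, and once a shard holds
// NUM_OF_INDICES_IN_SHARD entries it is re-keyed by its highest contained
// block number; that is why a closed shard of 3s shows up as shard(3) while
// the new blocks start a fresh shard(u64::MAX).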
let full_list = vec![3; NUM_OF_INDICES_IN_SHARD]; // setup - partial_setup(&tx); - tx.commit(|tx| { + partial_setup(&db); + db.commit(|tx| { tx.put::(shard(u64::MAX), list(&full_list)).unwrap(); Ok(()) }) .unwrap(); // run - run(&tx, 5).await; + run(&db, 5); // verify - let table = cast(tx.table::().unwrap()); + let table = cast(db.table::().unwrap()); assert_eq!( table, BTreeMap::from([(shard(3), full_list.clone()), (shard(u64::MAX), vec![4, 5])]) ); // unwind - unwind(&tx, 5, 0).await; + unwind(&db, 5, 0); // verify initial state - let table = cast(tx.table::().unwrap()); + let table = cast(db.table::().unwrap()); assert_eq!(table, BTreeMap::from([(shard(u64::MAX), full_list)])); } #[tokio::test] async fn insert_index_to_fill_shard() { // init - let tx = TestTransaction::default(); + let db = TestStageDB::default(); let mut close_full_list = vec![1; NUM_OF_INDICES_IN_SHARD - 2]; // setup - partial_setup(&tx); - tx.commit(|tx| { + partial_setup(&db); + db.commit(|tx| { tx.put::(shard(u64::MAX), list(&close_full_list)).unwrap(); Ok(()) }) .unwrap(); // run - run(&tx, 5).await; + run(&db, 5); // verify close_full_list.push(4); close_full_list.push(5); - let table = cast(tx.table::().unwrap()); + let table = cast(db.table::().unwrap()); assert_eq!(table, BTreeMap::from([(shard(u64::MAX), close_full_list.clone()),])); // unwind - unwind(&tx, 5, 0).await; + unwind(&db, 5, 0); // verify initial state close_full_list.pop(); close_full_list.pop(); - let table = cast(tx.table::().unwrap()); + let table = cast(db.table::().unwrap()); assert_eq!(table, BTreeMap::from([(shard(u64::MAX), close_full_list),])); // verify initial state @@ -336,46 +332,46 @@ mod tests { #[tokio::test] async fn insert_index_second_half_shard() { // init - let tx = TestTransaction::default(); + let db = TestStageDB::default(); let mut close_full_list = vec![1; NUM_OF_INDICES_IN_SHARD - 1]; // setup - partial_setup(&tx); - tx.commit(|tx| { + partial_setup(&db); + db.commit(|tx| { tx.put::(shard(u64::MAX), list(&close_full_list)).unwrap(); Ok(()) }) .unwrap(); // run - run(&tx, 5).await; + run(&db, 5); // verify close_full_list.push(4); - let table = cast(tx.table::().unwrap()); + let table = cast(db.table::().unwrap()); assert_eq!( table, BTreeMap::from([(shard(4), close_full_list.clone()), (shard(u64::MAX), vec![5])]) ); // unwind - unwind(&tx, 5, 0).await; + unwind(&db, 5, 0); // verify initial state close_full_list.pop(); - let table = cast(tx.table::().unwrap()); + let table = cast(db.table::().unwrap()); assert_eq!(table, BTreeMap::from([(shard(u64::MAX), close_full_list),])); } #[tokio::test] async fn insert_index_to_third_shard() { // init - let tx = TestTransaction::default(); + let db = TestStageDB::default(); let full_list = vec![1; NUM_OF_INDICES_IN_SHARD]; // setup - partial_setup(&tx); - tx.commit(|tx| { + partial_setup(&db); + db.commit(|tx| { tx.put::(shard(1), list(&full_list)).unwrap(); tx.put::(shard(2), list(&full_list)).unwrap(); tx.put::(shard(u64::MAX), list(&[2, 3])).unwrap(); @@ -383,10 +379,10 @@ mod tests { }) .unwrap(); - run(&tx, 5).await; + run(&db, 5); // verify - let table = cast(tx.table::().unwrap()); + let table = cast(db.table::().unwrap()); assert_eq!( table, BTreeMap::from([ @@ -397,10 +393,10 @@ mod tests { ); // unwind - unwind(&tx, 5, 0).await; + unwind(&db, 5, 0); // verify initial state - let table = cast(tx.table::().unwrap()); + let table = cast(db.table::().unwrap()); assert_eq!( table, BTreeMap::from([ @@ -414,10 +410,10 @@ mod tests { #[tokio::test] async fn 
insert_index_with_prune_mode() { // init - let tx = TestTransaction::default(); + let db = TestStageDB::default(); // setup - tx.commit(|tx| { + db.commit(|tx| { // we just need first and last tx.put::( 0, @@ -445,43 +441,42 @@ mod tests { prune_mode: Some(PruneMode::Before(36)), ..Default::default() }; - let factory = ProviderFactory::new(tx.tx.as_ref(), MAINNET.clone()); - let provider = factory.provider_rw().unwrap(); - let out = stage.execute(&provider, input).await.unwrap(); + let provider = db.factory.provider_rw().unwrap(); + let out = stage.execute(&provider, input).unwrap(); assert_eq!(out, ExecOutput { checkpoint: StageCheckpoint::new(20000), done: true }); provider.commit().unwrap(); // verify - let table = cast(tx.table::().unwrap()); + let table = cast(db.table::().unwrap()); assert_eq!(table, BTreeMap::from([(shard(u64::MAX), vec![36, 100]),])); // unwind - unwind(&tx, 20000, 0).await; + unwind(&db, 20000, 0); // verify initial state - let table = tx.table::().unwrap(); + let table = db.table::().unwrap(); assert!(table.is_empty()); } stage_test_suite_ext!(IndexStorageHistoryTestRunner, index_storage_history); struct IndexStorageHistoryTestRunner { - pub(crate) tx: TestTransaction, + pub(crate) db: TestStageDB, commit_threshold: u64, prune_mode: Option, } impl Default for IndexStorageHistoryTestRunner { fn default() -> Self { - Self { tx: TestTransaction::default(), commit_threshold: 1000, prune_mode: None } + Self { db: TestStageDB::default(), commit_threshold: 1000, prune_mode: None } } } impl StageTestRunner for IndexStorageHistoryTestRunner { type S = IndexStorageHistoryStage; - fn tx(&self) -> &TestTransaction { - &self.tx + fn db(&self) -> &TestStageDB { + &self.db } fn stage(&self) -> Self::S { @@ -514,7 +509,7 @@ mod tests { ); // add block changeset from block 1. 
- self.tx.insert_changesets(transitions, Some(start))?; + self.db.insert_changesets(transitions, Some(start))?; Ok(()) } @@ -536,7 +531,7 @@ mod tests { ExecOutput { checkpoint: StageCheckpoint::new(input.target()), done: true } ); - let provider = self.tx.inner(); + let provider = self.db.factory.provider()?; let mut changeset_cursor = provider.tx_ref().cursor_read::()?; @@ -587,7 +582,7 @@ mod tests { }; } - let table = cast(self.tx.table::().unwrap()); + let table = cast(self.db.table::().unwrap()); assert_eq!(table, result); } Ok(()) @@ -596,7 +591,7 @@ mod tests { impl UnwindStageTestRunner for IndexStorageHistoryTestRunner { fn validate_unwind(&self, _input: UnwindInput) -> Result<(), TestRunnerError> { - let table = self.tx.table::().unwrap(); + let table = self.db.table::().unwrap(); assert!(table.is_empty()); Ok(()) } diff --git a/crates/stages/src/stages/merkle.rs b/crates/stages/src/stages/merkle.rs index b65ab48d1b76..e4311f163476 100644 --- a/crates/stages/src/stages/merkle.rs +++ b/crates/stages/src/stages/merkle.rs @@ -79,7 +79,7 @@ impl MerkleStage { /// Gets the hashing progress pub fn get_execution_checkpoint( &self, - provider: &DatabaseProviderRW<'_, &DB>, + provider: &DatabaseProviderRW, ) -> Result, StageError> { let buf = provider.get_stage_checkpoint_progress(StageId::MerkleExecute)?.unwrap_or_default(); @@ -95,7 +95,7 @@ impl MerkleStage { /// Saves the hashing progress pub fn save_execution_checkpoint( &mut self, - provider: &DatabaseProviderRW<'_, &DB>, + provider: &DatabaseProviderRW, checkpoint: Option, ) -> Result<(), StageError> { let mut buf = vec![]; @@ -111,7 +111,6 @@ impl MerkleStage { } } -#[async_trait::async_trait] impl Stage for MerkleStage { /// Return the id of the stage fn id(&self) -> StageId { @@ -124,9 +123,9 @@ impl Stage for MerkleStage { } /// Execute the stage. - async fn execute( + fn execute( &mut self, - provider: &DatabaseProviderRW<'_, &DB>, + provider: &DatabaseProviderRW, input: ExecInput, ) -> Result { let threshold = match self { @@ -257,9 +256,9 @@ impl Stage for MerkleStage { } /// Unwind the stage. 
- async fn unwind( + fn unwind( &mut self, - provider: &DatabaseProviderRW<'_, &DB>, + provider: &DatabaseProviderRW, input: UnwindInput, ) -> Result { let tx = provider.tx_ref(); @@ -336,7 +335,7 @@ mod tests { use super::*; use crate::test_utils::{ stage_test_suite_ext, ExecuteStageTestRunner, StageTestRunner, TestRunnerError, - TestTransaction, UnwindStageTestRunner, + TestStageDB, UnwindStageTestRunner, }; use assert_matches::assert_matches; use reth_db::{ @@ -390,8 +389,8 @@ mod tests { done: true }) if block_number == previous_stage && processed == total && total == ( - runner.tx.table::().unwrap().len() + - runner.tx.table::().unwrap().len() + runner.db.table::().unwrap().len() + + runner.db.table::().unwrap().len() ) as u64 ); @@ -430,8 +429,8 @@ mod tests { done: true }) if block_number == previous_stage && processed == total && total == ( - runner.tx.table::().unwrap().len() + - runner.tx.table::().unwrap().len() + runner.db.table::().unwrap().len() + + runner.db.table::().unwrap().len() ) as u64 ); @@ -440,21 +439,21 @@ mod tests { } struct MerkleTestRunner { - tx: TestTransaction, + db: TestStageDB, clean_threshold: u64, } impl Default for MerkleTestRunner { fn default() -> Self { - Self { tx: TestTransaction::default(), clean_threshold: 10000 } + Self { db: TestStageDB::default(), clean_threshold: 10000 } } } impl StageTestRunner for MerkleTestRunner { type S = MerkleStage; - fn tx(&self) -> &TestTransaction { - &self.tx + fn db(&self) -> &TestStageDB { + &self.db } fn stage(&self) -> Self::S { @@ -477,7 +476,7 @@ mod tests { .into_iter() .collect::>(); - self.tx.insert_accounts_and_storages( + self.db.insert_accounts_and_storages( accounts.iter().map(|(addr, acc)| (*addr, (*acc, std::iter::empty()))), )?; @@ -496,7 +495,7 @@ mod tests { let head_hash = sealed_head.hash(); let mut blocks = vec![sealed_head]; blocks.extend(random_block_range(&mut rng, start..=end, head_hash, 0..3)); - self.tx.insert_blocks(blocks.iter(), None)?; + self.db.insert_blocks(blocks.iter(), None)?; let (transitions, final_state) = random_changeset_range( &mut rng, @@ -506,11 +505,11 @@ mod tests { 0..256, ); // add block changeset from block 1. 
- self.tx.insert_changesets(transitions, Some(start))?; - self.tx.insert_accounts_and_storages(final_state)?; + self.db.insert_changesets(transitions, Some(start))?; + self.db.insert_accounts_and_storages(final_state)?; // Calculate state root - let root = self.tx.query(|tx| { + let root = self.db.query(|tx| { let mut accounts = BTreeMap::default(); let mut accounts_cursor = tx.cursor_read::()?; let mut storage_cursor = tx.cursor_dup_read::()?; @@ -534,10 +533,11 @@ mod tests { })?; let last_block_number = end; - self.tx.commit(|tx| { + self.db.commit(|tx| { let mut last_header = tx.get::(last_block_number)?.unwrap(); last_header.state_root = root; - tx.put::(last_block_number, last_header) + tx.put::(last_block_number, last_header)?; + Ok(()) })?; Ok(blocks) @@ -562,7 +562,7 @@ mod tests { fn before_unwind(&self, input: UnwindInput) -> Result<(), TestRunnerError> { let target_block = input.unwind_to + 1; - self.tx + self.db .commit(|tx| { let mut storage_changesets_cursor = tx.cursor_dup_read::().unwrap(); diff --git a/crates/stages/src/stages/mod.rs b/crates/stages/src/stages/mod.rs index c0173747a8ec..ffe8ae1da1f6 100644 --- a/crates/stages/src/stages/mod.rs +++ b/crates/stages/src/stages/mod.rs @@ -42,7 +42,7 @@ mod tests { use crate::{ stage::Stage, stages::{ExecutionStage, IndexAccountHistoryStage, IndexStorageHistoryStage}, - test_utils::TestTransaction, + test_utils::TestStageDB, ExecInput, }; use alloy_rlp::Decodable; @@ -50,36 +50,35 @@ mod tests { cursor::DbCursorRO, mdbx::{cursor::Cursor, RW}, tables, + test_utils::TempDatabase, transaction::{DbTx, DbTxMut}, AccountHistory, DatabaseEnv, }; use reth_interfaces::test_utils::generators::{self, random_block}; use reth_primitives::{ address, hex_literal::hex, keccak256, Account, Bytecode, ChainSpecBuilder, PruneMode, - PruneModes, SealedBlock, MAINNET, U256, + PruneModes, SealedBlock, U256, }; use reth_provider::{ - AccountExtReader, BlockWriter, DatabaseProviderRW, ProviderFactory, ReceiptProvider, - StorageReader, + AccountExtReader, BlockWriter, ProviderFactory, ReceiptProvider, StorageReader, }; - use reth_revm::Factory; + use reth_revm::EvmProcessorFactory; use std::sync::Arc; #[tokio::test] #[ignore] async fn test_prune() { - let test_tx = TestTransaction::default(); - let factory = Arc::new(ProviderFactory::new(test_tx.tx.db(), MAINNET.clone())); + let test_db = TestStageDB::default(); - let provider = factory.provider_rw().unwrap(); + let provider_rw = test_db.factory.provider_rw().unwrap(); let tip = 66; let input = ExecInput { target: Some(tip), checkpoint: None }; let mut genesis_rlp = 
hex!("f901faf901f5a00000000000000000000000000000000000000000000000000000000000000000a01dcc4de8dec75d7aab85b567b6ccd41ad312451b948a7413f0a142fd40d49347942adc25665018aa1fe0e6bc666dac8fc2697ff9baa045571b40ae66ca7480791bbb2887286e4e4c4b1b298b191c889d6959023a32eda056e81f171bcc55a6ff8345e692c0f86e5b48e01b996cadc001622fb5e363b421a056e81f171bcc55a6ff8345e692c0f86e5b48e01b996cadc001622fb5e363b421b901000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000083020000808502540be400808000a00000000000000000000000000000000000000000000000000000000000000000880000000000000000c0c0").as_slice(); let genesis = SealedBlock::decode(&mut genesis_rlp).unwrap(); let mut block_rlp = hex!("f90262f901f9a075c371ba45999d87f4542326910a11af515897aebce5265d3f6acd1f1161f82fa01dcc4de8dec75d7aab85b567b6ccd41ad312451b948a7413f0a142fd40d49347942adc25665018aa1fe0e6bc666dac8fc2697ff9baa098f2dcd87c8ae4083e7017a05456c14eea4b1db2032126e27b3b1563d57d7cc0a08151d548273f6683169524b66ca9fe338b9ce42bc3540046c828fd939ae23bcba03f4e5c2ec5b2170b711d97ee755c160457bb58d8daa338e835ec02ae6860bbabb901000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000083020000018502540be40082a8798203e800a00000000000000000000000000000000000000000000000000000000000000000880000000000000000f863f861800a8405f5e10094100000000000000000000000000000000000000080801ba07e09e26678ed4fac08a249ebe8ed680bf9051a5e14ad223e4b2b9d26e0208f37a05f6e3f188e3e6eab7d7d3b6568f5eac7d687b08d307d3154ccd8c87b4630509bc0").as_slice(); let block = SealedBlock::decode(&mut block_rlp).unwrap(); - provider.insert_block(genesis, None, None).unwrap(); - provider.insert_block(block.clone(), None, None).unwrap(); + provider_rw.insert_block(genesis, None, None).unwrap(); + provider_rw.insert_block(block.clone(), None, None).unwrap(); // Fill with bogus blocks to respect PruneMode distance. 
let mut head = block.hash; @@ -87,22 +86,22 @@ mod tests { for block_number in 2..=tip { let nblock = random_block(&mut rng, block_number, Some(head), Some(0), Some(0)); head = nblock.hash; - provider.insert_block(nblock, None, None).unwrap(); + provider_rw.insert_block(nblock, None, None).unwrap(); } - provider.commit().unwrap(); + provider_rw.commit().unwrap(); // insert pre state - let provider = factory.provider_rw().unwrap(); + let provider_rw = test_db.factory.provider_rw().unwrap(); let code = hex!("5a465a905090036002900360015500"); let code_hash = keccak256(hex!("5a465a905090036002900360015500")); - provider + provider_rw .tx_ref() .put::( address!("1000000000000000000000000000000000000000"), Account { nonce: 0, balance: U256::ZERO, bytecode_hash: Some(code_hash) }, ) .unwrap(); - provider + provider_rw .tx_ref() .put::( address!("a94f5374fce5edbc8e2a8697c15331677e6ebf0b"), @@ -113,23 +112,25 @@ mod tests { }, ) .unwrap(); - provider + provider_rw .tx_ref() .put::(code_hash, Bytecode::new_raw(code.to_vec().into())) .unwrap(); - provider.commit().unwrap(); + provider_rw.commit().unwrap(); - let check_pruning = |factory: Arc>, + let check_pruning = |factory: ProviderFactory>>, prune_modes: PruneModes, expect_num_receipts: usize, expect_num_acc_changesets: usize, expect_num_storage_changesets: usize| async move { - let provider: DatabaseProviderRW<'_, &DatabaseEnv> = factory.provider_rw().unwrap(); + let provider = factory.provider_rw().unwrap(); // Check execution and create receipts and changesets according to the pruning // configuration let mut execution_stage = ExecutionStage::new( - Factory::new(Arc::new(ChainSpecBuilder::mainnet().berlin_activated().build())), + EvmProcessorFactory::new(Arc::new( + ChainSpecBuilder::mainnet().berlin_activated().build(), + )), ExecutionStageThresholds { max_blocks: Some(100), max_changes: None, @@ -139,7 +140,7 @@ mod tests { prune_modes.clone(), ); - execution_stage.execute(&provider, input).await.unwrap(); + execution_stage.execute(&provider, input).unwrap(); assert_eq!( provider.receipts_by_block(1.into()).unwrap().unwrap().len(), expect_num_receipts @@ -163,10 +164,10 @@ mod tests { if let Some(PruneMode::Full) = prune_modes.account_history { // Full is not supported - assert!(acc_indexing_stage.execute(&provider, input).await.is_err()); + assert!(acc_indexing_stage.execute(&provider, input).is_err()); } else { - acc_indexing_stage.execute(&provider, input).await.unwrap(); - let mut account_history: Cursor<'_, RW, AccountHistory> = + acc_indexing_stage.execute(&provider, input).unwrap(); + let mut account_history: Cursor = provider.tx_ref().cursor_read::().unwrap(); assert_eq!(account_history.walk(None).unwrap().count(), expect_num_acc_changesets); } @@ -179,9 +180,9 @@ mod tests { if let Some(PruneMode::Full) = prune_modes.storage_history { // Full is not supported - assert!(acc_indexing_stage.execute(&provider, input).await.is_err()); + assert!(acc_indexing_stage.execute(&provider, input).is_err()); } else { - storage_indexing_stage.execute(&provider, input).await.unwrap(); + storage_indexing_stage.execute(&provider, input).unwrap(); let mut storage_history = provider.tx_ref().cursor_read::().unwrap(); @@ -195,34 +196,34 @@ mod tests { // In an unpruned configuration there is 1 receipt, 3 changed accounts and 1 changed // storage. 
let mut prune = PruneModes::none(); - check_pruning(factory.clone(), prune.clone(), 1, 3, 1).await; + check_pruning(test_db.factory.clone(), prune.clone(), 1, 3, 1).await; prune.receipts = Some(PruneMode::Full); prune.account_history = Some(PruneMode::Full); prune.storage_history = Some(PruneMode::Full); // This will result in error for account_history and storage_history, which is caught. - check_pruning(factory.clone(), prune.clone(), 0, 0, 0).await; + check_pruning(test_db.factory.clone(), prune.clone(), 0, 0, 0).await; prune.receipts = Some(PruneMode::Before(1)); prune.account_history = Some(PruneMode::Before(1)); prune.storage_history = Some(PruneMode::Before(1)); - check_pruning(factory.clone(), prune.clone(), 1, 3, 1).await; + check_pruning(test_db.factory.clone(), prune.clone(), 1, 3, 1).await; prune.receipts = Some(PruneMode::Before(2)); prune.account_history = Some(PruneMode::Before(2)); prune.storage_history = Some(PruneMode::Before(2)); // The one account is the miner - check_pruning(factory.clone(), prune.clone(), 0, 1, 0).await; + check_pruning(test_db.factory.clone(), prune.clone(), 0, 1, 0).await; prune.receipts = Some(PruneMode::Distance(66)); prune.account_history = Some(PruneMode::Distance(66)); prune.storage_history = Some(PruneMode::Distance(66)); - check_pruning(factory.clone(), prune.clone(), 1, 3, 1).await; + check_pruning(test_db.factory.clone(), prune.clone(), 1, 3, 1).await; prune.receipts = Some(PruneMode::Distance(64)); prune.account_history = Some(PruneMode::Distance(64)); prune.storage_history = Some(PruneMode::Distance(64)); // The one account is the miner - check_pruning(factory.clone(), prune.clone(), 0, 1, 0).await; + check_pruning(test_db.factory.clone(), prune.clone(), 0, 1, 0).await; } } diff --git a/crates/stages/src/stages/sender_recovery.rs b/crates/stages/src/stages/sender_recovery.rs index 80ffb040a057..a7b19ca57a2d 100644 --- a/crates/stages/src/stages/sender_recovery.rs +++ b/crates/stages/src/stages/sender_recovery.rs @@ -16,9 +16,8 @@ use reth_primitives::{ use reth_provider::{ BlockReader, DatabaseProviderRW, HeaderProvider, ProviderError, PruneCheckpointReader, }; -use std::fmt::Debug; +use std::{fmt::Debug, sync::mpsc}; use thiserror::Error; -use tokio::sync::mpsc; use tracing::*; /// The sender recovery stage iterates over existing transactions, @@ -44,7 +43,6 @@ impl Default for SenderRecoveryStage { } } -#[async_trait::async_trait] impl Stage for SenderRecoveryStage { /// Return the id of the stage fn id(&self) -> StageId { @@ -56,9 +54,9 @@ impl Stage for SenderRecoveryStage { /// collect transactions within that range, /// recover signer for each transaction and store entries in /// the [`TxSenders`][reth_db::tables::TxSenders] table. - async fn execute( + fn execute( &mut self, - provider: &DatabaseProviderRW<'_, &DB>, + provider: &DatabaseProviderRW, input: ExecInput, ) -> Result { if input.target_reached() { @@ -110,7 +108,7 @@ impl Stage for SenderRecoveryStage { for chunk in &tx_walker.chunks(chunk_size) { // An _unordered_ channel to receive results from a rayon job - let (recovered_senders_tx, recovered_senders_rx) = mpsc::unbounded_channel(); + let (recovered_senders_tx, recovered_senders_rx) = mpsc::channel(); channels.push(recovered_senders_rx); // Note: Unfortunate side-effect of how chunk is designed in itertools (it is not Send) let chunk: Vec<_> = chunk.collect(); @@ -128,8 +126,8 @@ impl Stage for SenderRecoveryStage { } // Iterate over channels and append the sender in the order that they are received. 
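// The sender-recovery hunk above swaps tokio's unbounded channel for
// std::sync::mpsc: the workers run on rayon's thread pool (which the stage
// already uses), so a blocking receiver suffices and no async runtime is
// needed. A minimal standalone sketch of the pattern, with stand-in work in
// place of signature recovery:
use std::sync::mpsc;

fn process_in_chunks(items: Vec<u64>) -> Vec<u64> {
    let mut channels = Vec::new();
    for chunk in items.chunks(100).map(|c| c.to_vec()) {
        // An _unordered_ channel per chunk to receive results from a rayon job.
        let (tx, rx) = mpsc::channel();
        channels.push(rx);
        rayon::spawn(move || {
            for item in chunk {
                let _ = tx.send(item * 2); // stand-in for recover_sender
            }
        });
    }
    let mut out = Vec::new();
    // Drain channels in submission order; recv() blocks until the job's
    // sender is dropped, then returns Err and we move on to the next chunk.
    for rx in channels {
        while let Ok(v) = rx.recv() {
            out.push(v);
        }
    }
    out
}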
- for mut channel in channels { - while let Some(recovered) = channel.recv().await { + for channel in channels { + while let Ok(recovered) = channel.recv() { let (tx_id, sender) = match recovered { Ok(result) => result, Err(error) => { @@ -168,9 +166,9 @@ impl Stage for SenderRecoveryStage { } /// Unwind the stage. - async fn unwind( + fn unwind( &mut self, - provider: &DatabaseProviderRW<'_, &DB>, + provider: &DatabaseProviderRW, input: UnwindInput, ) -> Result { let (_, unwind_to, _) = input.unwind_block_range_with_threshold(self.commit_threshold); @@ -209,7 +207,7 @@ fn recover_sender( } fn stage_checkpoint( - provider: &DatabaseProviderRW<'_, &DB>, + provider: &DatabaseProviderRW, ) -> Result { let pruned_entries = provider .get_prune_checkpoint(PruneSegment::SenderRecovery)? @@ -252,14 +250,14 @@ mod tests { }; use reth_primitives::{ stage::StageUnitCheckpoint, BlockNumber, PruneCheckpoint, PruneMode, SealedBlock, - TransactionSigned, B256, MAINNET, + TransactionSigned, B256, }; - use reth_provider::{ProviderFactory, PruneCheckpointWriter, TransactionsProvider}; + use reth_provider::{PruneCheckpointWriter, TransactionsProvider}; use super::*; use crate::test_utils::{ stage_test_suite_ext, ExecuteStageTestRunner, StageTestRunner, TestRunnerError, - TestTransaction, UnwindStageTestRunner, + TestStageDB, UnwindStageTestRunner, }; stage_test_suite_ext!(SenderRecoveryTestRunner, sender_recovery); @@ -290,7 +288,7 @@ mod tests { ) }) .collect::>(); - runner.tx.insert_blocks(blocks.iter(), None).expect("failed to insert blocks"); + runner.db.insert_blocks(blocks.iter(), None).expect("failed to insert blocks"); let rx = runner.execute(input); @@ -324,9 +322,9 @@ mod tests { // Manually seed once with full input range let seed = random_block_range(&mut rng, stage_progress + 1..=previous_stage, B256::ZERO, 0..4); // set tx count range high enough to hit the threshold - runner.tx.insert_blocks(seed.iter(), None).expect("failed to seed execution"); + runner.db.insert_blocks(seed.iter(), None).expect("failed to seed execution"); - let total_transactions = runner.tx.table::().unwrap().len() as u64; + let total_transactions = runner.db.table::().unwrap().len() as u64; let first_input = ExecInput { target: Some(previous_stage), @@ -350,7 +348,7 @@ mod tests { ExecOutput { checkpoint: StageCheckpoint::new(expected_progress).with_entities_stage_checkpoint( EntitiesCheckpoint { - processed: runner.tx.table::().unwrap().len() as u64, + processed: runner.db.table::().unwrap().len() as u64, total: total_transactions } ), @@ -381,11 +379,11 @@ mod tests { #[test] fn stage_checkpoint_pruned() { - let tx = TestTransaction::default(); + let db = TestStageDB::default(); let mut rng = generators::rng(); let blocks = random_block_range(&mut rng, 0..=100, B256::ZERO, 0..10); - tx.insert_blocks(blocks.iter(), None).expect("insert blocks"); + db.insert_blocks(blocks.iter(), None).expect("insert blocks"); let max_pruned_block = 30; let max_processed_block = 70; @@ -401,9 +399,9 @@ mod tests { tx_number += 1; } } - tx.insert_transaction_senders(tx_senders).expect("insert tx hash numbers"); + db.insert_transaction_senders(tx_senders).expect("insert tx hash numbers"); - let provider = tx.inner_rw(); + let provider = db.factory.provider_rw().unwrap(); provider .save_prune_checkpoint( PruneSegment::SenderRecovery, @@ -421,10 +419,7 @@ mod tests { .expect("save stage checkpoint"); provider.commit().expect("commit"); - let db = tx.inner_raw(); - let factory = ProviderFactory::new(db.as_ref(), MAINNET.clone()); - let 
provider = factory.provider_rw().expect("provider rw"); - + let provider = db.factory.provider_rw().unwrap(); assert_eq!( stage_checkpoint(&provider).expect("stage checkpoint"), EntitiesCheckpoint { @@ -438,13 +433,13 @@ mod tests { } struct SenderRecoveryTestRunner { - tx: TestTransaction, + db: TestStageDB, threshold: u64, } impl Default for SenderRecoveryTestRunner { fn default() -> Self { - Self { threshold: 1000, tx: TestTransaction::default() } + Self { threshold: 1000, db: TestStageDB::default() } } } @@ -461,16 +456,17 @@ mod tests { /// not empty. fn ensure_no_senders_by_block(&self, block: BlockNumber) -> Result<(), TestRunnerError> { let body_result = self - .tx - .inner_rw() + .db + .factory + .provider_rw()? .block_body_indices(block)? .ok_or(ProviderError::BlockBodyIndicesNotFound(block)); match body_result { Ok(body) => self - .tx + .db .ensure_no_entry_above::(body.last_tx_num(), |key| key)?, Err(_) => { - assert!(self.tx.table_is_empty::()?); + assert!(self.db.table_is_empty::()?); } }; @@ -481,8 +477,8 @@ mod tests { impl StageTestRunner for SenderRecoveryTestRunner { type S = SenderRecoveryStage; - fn tx(&self) -> &TestTransaction { - &self.tx + fn db(&self) -> &TestStageDB { + &self.db } fn stage(&self) -> Self::S { @@ -499,7 +495,7 @@ mod tests { let end = input.target(); let blocks = random_block_range(&mut rng, stage_progress..=end, B256::ZERO, 0..2); - self.tx.insert_blocks(blocks.iter(), None)?; + self.db.insert_blocks(blocks.iter(), None)?; Ok(blocks) } @@ -510,7 +506,7 @@ mod tests { ) -> Result<(), TestRunnerError> { match output { Some(output) => { - let provider = self.tx.inner(); + let provider = self.db.factory.provider()?; let start_block = input.next_block(); let end_block = output.checkpoint.block_number; diff --git a/crates/stages/src/stages/total_difficulty.rs b/crates/stages/src/stages/total_difficulty.rs index ea1e20630d4a..d523cf4ce850 100644 --- a/crates/stages/src/stages/total_difficulty.rs +++ b/crates/stages/src/stages/total_difficulty.rs @@ -41,7 +41,6 @@ impl TotalDifficultyStage { } } -#[async_trait::async_trait] impl Stage for TotalDifficultyStage { /// Return the id of the stage fn id(&self) -> StageId { @@ -49,9 +48,9 @@ impl Stage for TotalDifficultyStage { } /// Write total difficulty entries - async fn execute( + fn execute( &mut self, - provider: &DatabaseProviderRW<'_, &DB>, + provider: &DatabaseProviderRW, input: ExecInput, ) -> Result { let tx = provider.tx_ref(); @@ -99,9 +98,9 @@ impl Stage for TotalDifficultyStage { } /// Unwind the stage. - async fn unwind( + fn unwind( &mut self, - provider: &DatabaseProviderRW<'_, &DB>, + provider: &DatabaseProviderRW, input: UnwindInput, ) -> Result { let (_, unwind_to, _) = input.unwind_block_range_with_threshold(self.commit_threshold); @@ -116,7 +115,7 @@ impl Stage for TotalDifficultyStage { } fn stage_checkpoint( - provider: &DatabaseProviderRW<'_, DB>, + provider: &DatabaseProviderRW, ) -> Result { Ok(EntitiesCheckpoint { processed: provider.tx_ref().entries::()? 
as u64, @@ -139,7 +138,7 @@ mod tests { use super::*; use crate::test_utils::{ stage_test_suite_ext, ExecuteStageTestRunner, StageTestRunner, TestRunnerError, - TestTransaction, UnwindStageTestRunner, + TestStageDB, UnwindStageTestRunner, }; stage_test_suite_ext!(TotalDifficultyTestRunner, total_difficulty); @@ -172,7 +171,7 @@ mod tests { total })) }, done: false }) if block_number == expected_progress && processed == 1 + threshold && - total == runner.tx.table::().unwrap().len() as u64 + total == runner.db.table::().unwrap().len() as u64 ); // Execute second time @@ -190,14 +189,14 @@ mod tests { total })) }, done: true }) if block_number == previous_stage && processed == total && - total == runner.tx.table::().unwrap().len() as u64 + total == runner.db.table::().unwrap().len() as u64 ); assert!(runner.validate_execution(first_input, result.ok()).is_ok(), "validation failed"); } struct TotalDifficultyTestRunner { - tx: TestTransaction, + db: TestStageDB, consensus: Arc, commit_threshold: u64, } @@ -205,7 +204,7 @@ mod tests { impl Default for TotalDifficultyTestRunner { fn default() -> Self { Self { - tx: Default::default(), + db: Default::default(), consensus: Arc::new(TestConsensus::default()), commit_threshold: 500, } @@ -215,8 +214,8 @@ mod tests { impl StageTestRunner for TotalDifficultyTestRunner { type S = TotalDifficultyStage; - fn tx(&self) -> &TestTransaction { - &self.tx + fn db(&self) -> &TestStageDB { + &self.db } fn stage(&self) -> Self::S { @@ -235,15 +234,16 @@ mod tests { let mut rng = generators::rng(); let start = input.checkpoint().block_number; let head = random_header(&mut rng, start, None); - self.tx.insert_headers(std::iter::once(&head))?; - self.tx.commit(|tx| { + self.db.insert_headers(std::iter::once(&head))?; + self.db.commit(|tx| { let td: U256 = tx .cursor_read::()? .last()? 
.map(|(_, v)| v) .unwrap_or_default() .into(); - tx.put::(head.number, (td + head.difficulty).into()) + tx.put::(head.number, (td + head.difficulty).into())?; + Ok(()) })?; // use previous progress as seed size @@ -254,7 +254,7 @@ mod tests { } let mut headers = random_header_range(&mut rng, start + 1..end, head.hash()); - self.tx.insert_headers(headers.iter())?; + self.db.insert_headers(headers.iter())?; headers.insert(0, head); Ok(headers) } @@ -268,7 +268,7 @@ mod tests { let initial_stage_progress = input.checkpoint().block_number; match output { Some(output) if output.checkpoint.block_number > initial_stage_progress => { - let provider = self.tx.inner(); + let provider = self.db.factory.provider()?; let mut header_cursor = provider.tx_ref().cursor_read::()?; let (_, mut current_header) = header_cursor @@ -302,7 +302,7 @@ mod tests { impl TotalDifficultyTestRunner { fn check_no_td_above(&self, block: BlockNumber) -> Result<(), TestRunnerError> { - self.tx.ensure_no_entry_above::(block, |num| num)?; + self.db.ensure_no_entry_above::(block, |num| num)?; Ok(()) } diff --git a/crates/stages/src/stages/tx_lookup.rs b/crates/stages/src/stages/tx_lookup.rs index 758fa403320b..a741bed28582 100644 --- a/crates/stages/src/stages/tx_lookup.rs +++ b/crates/stages/src/stages/tx_lookup.rs @@ -42,7 +42,6 @@ impl TransactionLookupStage { } } -#[async_trait::async_trait] impl Stage for TransactionLookupStage { /// Return the id of the stage fn id(&self) -> StageId { @@ -50,9 +49,9 @@ impl Stage for TransactionLookupStage { } /// Write transaction hash -> id entries - async fn execute( + fn execute( &mut self, - provider: &DatabaseProviderRW<'_, &DB>, + provider: &DatabaseProviderRW, mut input: ExecInput, ) -> Result { if let Some((target_prunable_block, prune_mode)) = self @@ -128,9 +127,9 @@ impl Stage for TransactionLookupStage { } /// Unwind the stage. - async fn unwind( + fn unwind( &mut self, - provider: &DatabaseProviderRW<'_, &DB>, + provider: &DatabaseProviderRW, input: UnwindInput, ) -> Result { let tx = provider.tx_ref(); @@ -165,7 +164,7 @@ impl Stage for TransactionLookupStage { } fn stage_checkpoint( - provider: &DatabaseProviderRW<'_, &DB>, + provider: &DatabaseProviderRW, ) -> Result { let pruned_entries = provider .get_prune_checkpoint(PruneSegment::TransactionLookup)? @@ -187,7 +186,7 @@ mod tests { use super::*; use crate::test_utils::{ stage_test_suite_ext, ExecuteStageTestRunner, StageTestRunner, TestRunnerError, - TestTransaction, UnwindStageTestRunner, + TestStageDB, UnwindStageTestRunner, }; use assert_matches::assert_matches; use reth_interfaces::test_utils::{ @@ -196,11 +195,8 @@ mod tests { }; use reth_primitives::{ stage::StageUnitCheckpoint, BlockNumber, PruneCheckpoint, PruneMode, SealedBlock, B256, - MAINNET, - }; - use reth_provider::{ - BlockReader, ProviderError, ProviderFactory, PruneCheckpointWriter, TransactionsProvider, }; + use reth_provider::{BlockReader, ProviderError, PruneCheckpointWriter, TransactionsProvider}; use std::ops::Sub; // Implement stage test suite. 
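// Why the commit closures above gained a trailing Ok(()): the callback now
// returns ProviderResult<()> rather than the raw database error type, so a
// final put can no longer be the closure's tail expression; `?` performs the
// error conversion and the closure returns Ok(()) itself. A minimal sketch
// with stand-in error types:
struct DbError;
struct ProviderError;

impl From<DbError> for ProviderError {
    fn from(_: DbError) -> Self {
        ProviderError
    }
}

fn put() -> Result<(), DbError> {
    Ok(())
}

fn commit_callback() -> Result<(), ProviderError> {
    put()?; // `?` applies the From<DbError> conversion
    Ok(())
}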
@@ -231,7 +227,7 @@ mod tests { ) }) .collect::>(); - runner.tx.insert_blocks(blocks.iter(), None).expect("failed to insert blocks"); + runner.db.insert_blocks(blocks.iter(), None).expect("failed to insert blocks"); let rx = runner.execute(input); @@ -247,7 +243,7 @@ mod tests { total })) }, done: true }) if block_number == previous_stage && processed == total && - total == runner.tx.table::().unwrap().len() as u64 + total == runner.db.table::().unwrap().len() as u64 ); // Validate the stage execution @@ -270,9 +266,9 @@ mod tests { // Seed only once with full input range let seed = random_block_range(&mut rng, stage_progress + 1..=previous_stage, B256::ZERO, 0..4); // set tx count range high enough to hit the threshold - runner.tx.insert_blocks(seed.iter(), None).expect("failed to seed execution"); + runner.db.insert_blocks(seed.iter(), None).expect("failed to seed execution"); - let total_txs = runner.tx.table::().unwrap().len() as u64; + let total_txs = runner.db.table::().unwrap().len() as u64; // Execute first time let result = runner.execute(first_input).await.unwrap(); @@ -291,7 +287,7 @@ mod tests { ExecOutput { checkpoint: StageCheckpoint::new(expected_progress).with_entities_stage_checkpoint( EntitiesCheckpoint { - processed: runner.tx.table::().unwrap().len() as u64, + processed: runner.db.table::().unwrap().len() as u64, total: total_txs } ), @@ -335,7 +331,7 @@ mod tests { // Seed only once with full input range let seed = random_block_range(&mut rng, stage_progress + 1..=previous_stage, B256::ZERO, 0..2); - runner.tx.insert_blocks(seed.iter(), None).expect("failed to seed execution"); + runner.db.insert_blocks(seed.iter(), None).expect("failed to seed execution"); runner.set_prune_mode(PruneMode::Before(prune_target)); @@ -353,7 +349,7 @@ mod tests { total })) }, done: true }) if block_number == previous_stage && processed == total && - total == runner.tx.table::().unwrap().len() as u64 + total == runner.db.table::().unwrap().len() as u64 ); // Validate the stage execution @@ -362,11 +358,11 @@ mod tests { #[test] fn stage_checkpoint_pruned() { - let tx = TestTransaction::default(); + let db = TestStageDB::default(); let mut rng = generators::rng(); let blocks = random_block_range(&mut rng, 0..=100, B256::ZERO, 0..10); - tx.insert_blocks(blocks.iter(), None).expect("insert blocks"); + db.insert_blocks(blocks.iter(), None).expect("insert blocks"); let max_pruned_block = 30; let max_processed_block = 70; @@ -381,9 +377,9 @@ mod tests { tx_hash_number += 1; } } - tx.insert_tx_hash_numbers(tx_hash_numbers).expect("insert tx hash numbers"); + db.insert_tx_hash_numbers(tx_hash_numbers).expect("insert tx hash numbers"); - let provider = tx.inner_rw(); + let provider = db.factory.provider_rw().unwrap(); provider .save_prune_checkpoint( PruneSegment::TransactionLookup, @@ -402,10 +398,7 @@ mod tests { .expect("save stage checkpoint"); provider.commit().expect("commit"); - let db = tx.inner_raw(); - let factory = ProviderFactory::new(db.as_ref(), MAINNET.clone()); - let provider = factory.provider_rw().expect("provider rw"); - + let provider = db.factory.provider_rw().unwrap(); assert_eq!( stage_checkpoint(&provider).expect("stage checkpoint"), EntitiesCheckpoint { @@ -419,14 +412,14 @@ mod tests { } struct TransactionLookupTestRunner { - tx: TestTransaction, + db: TestStageDB, commit_threshold: u64, prune_mode: Option, } impl Default for TransactionLookupTestRunner { fn default() -> Self { - Self { tx: TestTransaction::default(), commit_threshold: 1000, prune_mode: None } + Self { db: 
TestStageDB::default(), commit_threshold: 1000, prune_mode: None } } } @@ -448,17 +441,18 @@ mod tests { /// not empty. fn ensure_no_hash_by_block(&self, number: BlockNumber) -> Result<(), TestRunnerError> { let body_result = self - .tx - .inner_rw() + .db + .factory + .provider_rw()? .block_body_indices(number)? .ok_or(ProviderError::BlockBodyIndicesNotFound(number)); match body_result { - Ok(body) => self.tx.ensure_no_entry_above_by_value::( + Ok(body) => self.db.ensure_no_entry_above_by_value::( body.last_tx_num(), |key| key, )?, Err(_) => { - assert!(self.tx.table_is_empty::()?); + assert!(self.db.table_is_empty::()?); } }; @@ -469,8 +463,8 @@ mod tests { impl StageTestRunner for TransactionLookupTestRunner { type S = TransactionLookupStage; - fn tx(&self) -> &TestTransaction { - &self.tx + fn db(&self) -> &TestStageDB { + &self.db } fn stage(&self) -> Self::S { @@ -490,7 +484,7 @@ mod tests { let mut rng = generators::rng(); let blocks = random_block_range(&mut rng, stage_progress + 1..=end, B256::ZERO, 0..2); - self.tx.insert_blocks(blocks.iter(), None)?; + self.db.insert_blocks(blocks.iter(), None)?; Ok(blocks) } @@ -501,7 +495,7 @@ mod tests { ) -> Result<(), TestRunnerError> { match output { Some(output) => { - let provider = self.tx.inner(); + let provider = self.db.factory.provider()?; if let Some((target_prunable_block, _)) = self .prune_mode diff --git a/crates/stages/src/test_utils/mod.rs b/crates/stages/src/test_utils/mod.rs index b9fe397a8873..b74b3e9455c1 100644 --- a/crates/stages/src/test_utils/mod.rs +++ b/crates/stages/src/test_utils/mod.rs @@ -10,7 +10,7 @@ pub(crate) use runner::{ }; mod test_db; -pub use test_db::TestTransaction; +pub use test_db::TestStageDB; mod stage; pub use stage::TestStage; diff --git a/crates/stages/src/test_utils/runner.rs b/crates/stages/src/test_utils/runner.rs index 9bc08638d34f..17289b9cf20b 100644 --- a/crates/stages/src/test_utils/runner.rs +++ b/crates/stages/src/test_utils/runner.rs @@ -1,6 +1,6 @@ -use super::TestTransaction; -use crate::{ExecInput, ExecOutput, Stage, StageError, UnwindInput, UnwindOutput}; -use reth_db::DatabaseEnv; +use super::TestStageDB; +use crate::{ExecInput, ExecOutput, Stage, StageError, StageExt, UnwindInput, UnwindOutput}; +use reth_db::{test_utils::TempDatabase, DatabaseEnv}; use reth_interfaces::db::DatabaseError; use reth_primitives::MAINNET; use reth_provider::{ProviderError, ProviderFactory}; @@ -19,10 +19,10 @@ pub(crate) enum TestRunnerError { /// A generic test runner for stages. pub(crate) trait StageTestRunner { - type S: Stage + 'static; + type S: Stage>> + 'static; /// Return a reference to the database. - fn tx(&self) -> &TestTransaction; + fn db(&self) -> &TestStageDB; /// Return an instance of a Stage. fn stage(&self) -> Self::S; @@ -45,13 +45,14 @@ pub(crate) trait ExecuteStageTestRunner: StageTestRunner { /// Run [Stage::execute] and return a receiver for the result. 
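// Two details in the runner hunk above. First, the associated-type bound is
// generic over the concrete test database; with its type parameters written
// out in full it reads:
//
//     type S: Stage<Arc<TempDatabase<DatabaseEnv>>> + 'static;
//
// Second, because Stage::execute is now synchronous, a stage run splits into
// an async readiness step (StageExt::execute_ready, awaited on the runtime)
// followed by the blocking database work, as the spawned task below shows.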
fn execute(&self, input: ExecInput) -> oneshot::Receiver> { let (tx, rx) = oneshot::channel(); - let (db, mut stage) = (self.tx().inner_raw(), self.stage()); + let (db, mut stage) = (self.db().factory.clone(), self.stage()); tokio::spawn(async move { - let factory = ProviderFactory::new(db.db(), MAINNET.clone()); - let provider = factory.provider_rw().unwrap(); - - let result = stage.execute(&provider, input).await; - provider.commit().expect("failed to commit"); + let result = stage.execute_ready(input).await.and_then(|_| { + let provider_rw = db.provider_rw().unwrap(); + let result = stage.execute(&provider_rw, input); + provider_rw.commit().expect("failed to commit"); + result + }); tx.send(result).expect("failed to send message") }); rx @@ -71,12 +72,10 @@ pub(crate) trait UnwindStageTestRunner: StageTestRunner { /// Run [Stage::unwind] and return a receiver for the result. async fn unwind(&self, input: UnwindInput) -> Result { let (tx, rx) = oneshot::channel(); - let (db, mut stage) = (self.tx().inner_raw(), self.stage()); + let (db, mut stage) = (self.db().factory.clone(), self.stage()); tokio::spawn(async move { - let factory = ProviderFactory::new(db.db(), MAINNET.clone()); - let provider = factory.provider_rw().unwrap(); - - let result = stage.unwind(&provider, input).await; + let provider = db.provider_rw().unwrap(); + let result = stage.unwind(&provider, input); provider.commit().expect("failed to commit"); tx.send(result).expect("failed to send result"); }); diff --git a/crates/stages/src/test_utils/stage.rs b/crates/stages/src/test_utils/stage.rs index 65ea51362dfb..a76e46e67cd6 100644 --- a/crates/stages/src/test_utils/stage.rs +++ b/crates/stages/src/test_utils/stage.rs @@ -40,15 +40,14 @@ impl TestStage { } } -#[async_trait::async_trait] impl Stage for TestStage { fn id(&self) -> StageId { self.id } - async fn execute( + fn execute( &mut self, - _: &DatabaseProviderRW<'_, &DB>, + _: &DatabaseProviderRW, _input: ExecInput, ) -> Result { self.exec_outputs @@ -56,9 +55,9 @@ impl Stage for TestStage { .unwrap_or_else(|| panic!("Test stage {} executed too many times.", self.id)) } - async fn unwind( + fn unwind( &mut self, - _: &DatabaseProviderRW<'_, &DB>, + _: &DatabaseProviderRW, _input: UnwindInput, ) -> Result { self.unwind_outputs diff --git a/crates/stages/src/test_utils/test_db.rs b/crates/stages/src/test_utils/test_db.rs index 56361f21295a..4582bb86acf4 100644 --- a/crates/stages/src/test_utils/test_db.rs +++ b/crates/stages/src/test_utils/test_db.rs @@ -1,15 +1,15 @@ use reth_db::{ common::KeyValue, cursor::{DbCursorRO, DbCursorRW, DbDupCursorRO}, - database::DatabaseGAT, + database::Database, models::{AccountBeforeTx, StoredBlockBodyIndices}, table::{Table, TableRow}, tables, test_utils::{create_test_rw_db, create_test_rw_db_with_path, TempDatabase}, - transaction::{DbTx, DbTxGAT, DbTxMut, DbTxMutGAT}, + transaction::{DbTx, DbTxMut}, DatabaseEnv, DatabaseError as DbError, }; -use reth_interfaces::{test_utils::generators::ChangeSet, RethResult}; +use reth_interfaces::{provider::ProviderResult, test_utils::generators::ChangeSet, RethResult}; use reth_primitives::{ keccak256, Account, Address, BlockNumber, Receipt, SealedBlock, SealedHeader, StorageEntry, TxHash, TxNumber, B256, MAINNET, U256, @@ -18,80 +18,50 @@ use reth_provider::{DatabaseProviderRO, DatabaseProviderRW, HistoryWriter, Provi use std::{ borrow::Borrow, collections::BTreeMap, - ops::RangeInclusive, + ops::{Deref, RangeInclusive}, path::{Path, PathBuf}, sync::Arc, }; -/// The [TestTransaction] is used as 
an internal -/// database for testing stage implementation. -/// -/// ```rust,ignore -/// let tx = TestTransaction::default(); -/// stage.execute(&mut tx.container(), input); -/// ``` +/// Test database that is used for testing stage implementations. #[derive(Debug)] -pub struct TestTransaction { - /// DB - pub tx: Arc>, - pub path: Option, +pub struct TestStageDB { pub factory: ProviderFactory>>, } -impl Default for TestTransaction { - /// Create a new instance of [TestTransaction] +impl Default for TestStageDB { + /// Create a new instance of [TestStageDB] fn default() -> Self { - let tx = create_test_rw_db(); - Self { tx: tx.clone(), path: None, factory: ProviderFactory::new(tx, MAINNET.clone()) } + Self { factory: ProviderFactory::new(create_test_rw_db(), MAINNET.clone()) } } } -impl TestTransaction { +impl TestStageDB { pub fn new(path: &Path) -> Self { - let tx = create_test_rw_db_with_path(path); - Self { - tx: tx.clone(), - path: Some(path.to_path_buf()), - factory: ProviderFactory::new(tx, MAINNET.clone()), - } - } - - /// Return a database wrapped in [DatabaseProviderRW]. - pub fn inner_rw(&self) -> DatabaseProviderRW<'_, Arc>> { - self.factory.provider_rw().expect("failed to create db container") - } - - /// Return a database wrapped in [DatabaseProviderRO]. - pub fn inner(&self) -> DatabaseProviderRO<'_, Arc>> { - self.factory.provider().expect("failed to create db container") - } - - /// Get a pointer to an internal database. - pub fn inner_raw(&self) -> Arc> { - self.tx.clone() + Self { factory: ProviderFactory::new(create_test_rw_db_with_path(path), MAINNET.clone()) } } /// Invoke a callback with transaction committing it afterwards - pub fn commit(&self, f: F) -> Result<(), DbError> + pub fn commit(&self, f: F) -> ProviderResult<()> where - F: FnOnce(&>::TXMut) -> Result<(), DbError>, + F: FnOnce(&::TXMut) -> ProviderResult<()>, { - let mut tx = self.inner_rw(); + let mut tx = self.factory.provider_rw()?; f(tx.tx_ref())?; tx.commit().expect("failed to commit"); Ok(()) } /// Invoke a callback with a read transaction - pub fn query(&self, f: F) -> Result + pub fn query(&self, f: F) -> ProviderResult where - F: FnOnce(&>::TX) -> Result, + F: FnOnce(&::TX) -> ProviderResult, { - f(self.inner().tx_ref()) + f(self.factory.provider()?.tx_ref()) } /// Check if the table is empty - pub fn table_is_empty(&self) -> Result { + pub fn table_is_empty(&self) -> ProviderResult { self.query(|tx| { let last = tx.cursor_read::()?.last()?; Ok(last.is_none()) @@ -99,70 +69,21 @@ impl TestTransaction { } /// Return full table as Vec - pub fn table(&self) -> Result>, DbError> + pub fn table(&self) -> ProviderResult>> where T::Key: Default + Ord, { self.query(|tx| { - tx.cursor_read::()? + Ok(tx + .cursor_read::()? .walk(Some(T::Key::default()))? - .collect::, DbError>>() - }) - } - - /// Map a collection of values and store them in the database. - /// This function commits the transaction before exiting. - /// - /// ```rust,ignore - /// let tx = TestTransaction::default(); - /// tx.map_put::(&items, |item| item)?; - /// ``` - #[allow(dead_code)] - pub fn map_put(&self, values: &[S], mut map: F) -> Result<(), DbError> - where - T: Table, - S: Clone, - F: FnMut(&S) -> TableRow, - { - self.commit(|tx| { - values.iter().try_for_each(|src| { - let (k, v) = map(src); - tx.put::(k, v) - }) - }) - } - - /// Transform a collection of values using a callback and store - /// them in the database. The callback additionally accepts the - /// optional last element that was stored. 
- /// This function commits the transaction before exiting. - /// - /// ```rust,ignore - /// let tx = TestTransaction::default(); - /// tx.transform_append::(&items, |prev, item| prev.unwrap_or_default() + item)?; - /// ``` - #[allow(dead_code)] - pub fn transform_append(&self, values: &[S], mut transform: F) -> Result<(), DbError> - where - T: Table, - ::Value: Clone, - S: Clone, - F: FnMut(&Option<::Value>, &S) -> TableRow, - { - self.commit(|tx| { - let mut cursor = tx.cursor_write::()?; - let mut last = cursor.last()?.map(|(_, v)| v); - values.iter().try_for_each(|src| { - let (k, v) = transform(&last, src); - last = Some(v.clone()); - cursor.append(k, v) - }) + .collect::, DbError>>()?) }) } /// Check that there is no table entry above a given /// number by [Table::Key] - pub fn ensure_no_entry_above(&self, num: u64, mut selector: F) -> Result<(), DbError> + pub fn ensure_no_entry_above(&self, num: u64, mut selector: F) -> ProviderResult<()> where T: Table, F: FnMut(T::Key) -> BlockNumber, @@ -182,7 +103,7 @@ impl TestTransaction { &self, num: u64, mut selector: F, - ) -> Result<(), DbError> + ) -> ProviderResult<()> where T: Table, F: FnMut(T::Value) -> BlockNumber, @@ -206,17 +127,19 @@ impl TestTransaction { /// Insert ordered collection of [SealedHeader] into the corresponding tables /// that are supposed to be populated by the headers stage. - pub fn insert_headers<'a, I>(&self, headers: I) -> Result<(), DbError> + pub fn insert_headers<'a, I>(&self, headers: I) -> ProviderResult<()> where I: Iterator, { - self.commit(|tx| headers.into_iter().try_for_each(|header| Self::insert_header(tx, header))) + self.commit(|tx| { + Ok(headers.into_iter().try_for_each(|header| Self::insert_header(tx, header))?) + }) } /// Inserts total difficulty of headers into the corresponding tables. /// - /// Superset functionality of [TestTransaction::insert_headers]. - pub fn insert_headers_with_td<'a, I>(&self, headers: I) -> Result<(), DbError> + /// Superset functionality of [TestStageDB::insert_headers]. + pub fn insert_headers_with_td<'a, I>(&self, headers: I) -> ProviderResult<()> where I: Iterator, { @@ -225,16 +148,16 @@ impl TestTransaction { headers.into_iter().try_for_each(|header| { Self::insert_header(tx, header)?; td += header.difficulty; - tx.put::(header.number, td.into()) + Ok(tx.put::(header.number, td.into())?) }) }) } /// Insert ordered collection of [SealedBlock] into corresponding tables. - /// Superset functionality of [TestTransaction::insert_headers]. + /// Superset functionality of [TestStageDB::insert_headers]. /// /// Assumes that there's a single transition for each transaction (i.e. no block rewards). - pub fn insert_blocks<'a, I>(&self, blocks: I, tx_offset: Option) -> Result<(), DbError> + pub fn insert_blocks<'a, I>(&self, blocks: I, tx_offset: Option) -> ProviderResult<()> where I: Iterator, { @@ -266,45 +189,45 @@ impl TestTransaction { }) } - pub fn insert_tx_hash_numbers(&self, tx_hash_numbers: I) -> Result<(), DbError> + pub fn insert_tx_hash_numbers(&self, tx_hash_numbers: I) -> ProviderResult<()> where I: IntoIterator, { self.commit(|tx| { tx_hash_numbers.into_iter().try_for_each(|(tx_hash, tx_num)| { // Insert into tx hash numbers table. - tx.put::(tx_hash, tx_num) + Ok(tx.put::(tx_hash, tx_num)?) }) }) } /// Insert collection of ([TxNumber], [Receipt]) into the corresponding table. 
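// A sketch of the new harness type as implied by this file, with the generic
// parameters written out in full: TestStageDB is a thin wrapper around a
// ProviderFactory over a temporary database, and its helpers return
// ProviderResult instead of the raw database error.
pub struct TestStageDB {
    pub factory: ProviderFactory<Arc<TempDatabase<DatabaseEnv>>>,
}

impl Default for TestStageDB {
    fn default() -> Self {
        Self { factory: ProviderFactory::new(create_test_rw_db(), MAINNET.clone()) }
    }
}
// Typical use in a stage test, with no hand-rolled ProviderFactory:
//
//     let db = TestStageDB::default();
//     let provider_rw = db.factory.provider_rw().unwrap();
//     // ... execute a stage against provider_rw ...
//     provider_rw.commit().unwrap();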
- pub fn insert_receipts(&self, receipts: I) -> Result<(), DbError> + pub fn insert_receipts(&self, receipts: I) -> ProviderResult<()> where I: IntoIterator, { self.commit(|tx| { receipts.into_iter().try_for_each(|(tx_num, receipt)| { // Insert into receipts table. - tx.put::(tx_num, receipt) + Ok(tx.put::(tx_num, receipt)?) }) }) } - pub fn insert_transaction_senders(&self, transaction_senders: I) -> Result<(), DbError> + pub fn insert_transaction_senders(&self, transaction_senders: I) -> ProviderResult<()> where I: IntoIterator, { self.commit(|tx| { transaction_senders.into_iter().try_for_each(|(tx_num, sender)| { // Insert into receipts table. - tx.put::(tx_num, sender) + Ok(tx.put::(tx_num, sender)?) }) }) } /// Insert collection of ([Address], [Account]) into corresponding tables. - pub fn insert_accounts_and_storages(&self, accounts: I) -> Result<(), DbError> + pub fn insert_accounts_and_storages(&self, accounts: I) -> ProviderResult<()> where I: IntoIterator, S: IntoIterator, @@ -350,7 +273,7 @@ impl TestTransaction { &self, changesets: I, block_offset: Option, - ) -> Result<(), DbError> + ) -> ProviderResult<()> where I: IntoIterator, { @@ -369,14 +292,14 @@ impl TestTransaction { // Insert into storage changeset. old_storage.into_iter().try_for_each(|entry| { - tx.put::(block_address, entry) + Ok(tx.put::(block_address, entry)?) }) }) }) }) } - pub fn insert_history(&self, changesets: I, block_offset: Option) -> RethResult<()> + pub fn insert_history(&self, changesets: I, block_offset: Option) -> ProviderResult<()> where I: IntoIterator, { @@ -392,10 +315,10 @@ impl TestTransaction { } } - let provider = self.factory.provider_rw()?; - provider.insert_account_history_index(accounts)?; - provider.insert_storage_history_index(storages)?; - provider.commit()?; + let provider_rw = self.factory.provider_rw()?; + provider_rw.insert_account_history_index(accounts)?; + provider_rw.insert_storage_history_index(storages)?; + provider_rw.commit()?; Ok(()) } diff --git a/crates/storage/codecs/src/lib.rs b/crates/storage/codecs/src/lib.rs index fee674a23abf..5b7a5ee1610e 100644 --- a/crates/storage/codecs/src/lib.rs +++ b/crates/storage/codecs/src/lib.rs @@ -31,7 +31,7 @@ use revm_primitives::{ /// Regarding the `specialized_to/from_compact` methods: Mainly used as a workaround for not being /// able to specialize an impl over certain types like `Vec`/`Option` where `T` is a fixed /// size array like `Vec`. -pub trait Compact { +pub trait Compact: Sized { /// Takes a buffer which can be written to. *Ideally*, it returns the length written to. fn to_compact(self, buf: &mut B) -> usize where @@ -43,24 +43,18 @@ pub trait Compact { /// `len` can either be the `buf` remaining length, or the length of the compacted type. /// /// It will panic, if `len` is smaller than `buf.len()`. - fn from_compact(buf: &[u8], len: usize) -> (Self, &[u8]) - where - Self: Sized; + fn from_compact(buf: &[u8], len: usize) -> (Self, &[u8]); /// "Optional": If there's no good reason to use it, don't. fn specialized_to_compact(self, buf: &mut B) -> usize where B: bytes::BufMut + AsMut<[u8]>, - Self: Sized, { self.to_compact(buf) } /// "Optional": If there's no good reason to use it, don't. 
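For illustration, an implementor under the new `Compact: Sized` bound drops the per-method `where Self: Sized` clauses. `MyFlag` below is a hypothetical newtype, not part of this patch:

```rust,ignore
struct MyFlag(bool);

impl Compact for MyFlag {
    fn to_compact<B>(self, buf: &mut B) -> usize
    where
        B: bytes::BufMut + AsMut<[u8]>,
    {
        buf.put_u8(self.0 as u8);
        1
    }

    // No `where Self: Sized` needed on the method any longer.
    fn from_compact(buf: &[u8], _len: usize) -> (Self, &[u8]) {
        (Self(buf[0] != 0), &buf[1..])
    }
}
```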
- fn specialized_from_compact(buf: &[u8], len: usize) -> (Self, &[u8]) - where - Self: Sized, - { + fn specialized_from_compact(buf: &[u8], len: usize) -> (Self, &[u8]) { Self::from_compact(buf, len) } } diff --git a/crates/storage/db/src/abstraction/common.rs b/crates/storage/db/src/abstraction/common.rs index 29a1e34294d9..9bce16e397d2 100644 --- a/crates/storage/db/src/abstraction/common.rs +++ b/crates/storage/db/src/abstraction/common.rs @@ -1,3 +1,5 @@ +use crate::{abstraction::table::*, DatabaseError}; + /// A key-value pair for table `T`. pub type KeyValue = (::Key, ::Value); @@ -16,13 +18,20 @@ pub type IterPairResult = Option, DatabaseError>>; /// A value only result for table `T`. pub type ValueOnlyResult = Result::Value>, DatabaseError>; -use crate::{abstraction::table::*, DatabaseError}; - -// Sealed trait helper to prevent misuse of the API. +// Sealed trait helper to prevent misuse of the Database API. mod sealed { + use crate::{database::Database, mock::DatabaseMock, DatabaseEnv}; + use std::sync::Arc; + + /// Sealed trait to limit the implementors of the Database trait. pub trait Sealed: Sized {} - #[allow(missing_debug_implementations)] - pub struct Bounds(T); - impl Sealed for Bounds {} + + impl Sealed for &DB {} + impl Sealed for Arc {} + impl Sealed for DatabaseEnv {} + impl Sealed for DatabaseMock {} + + #[cfg(any(test, feature = "test-utils"))] + impl Sealed for crate::test_utils::TempDatabase {} } -pub(crate) use sealed::{Bounds, Sealed}; +pub(crate) use sealed::Sealed; diff --git a/crates/storage/db/src/abstraction/database.rs b/crates/storage/db/src/abstraction/database.rs index eacf845bb7bd..e185b4438263 100644 --- a/crates/storage/db/src/abstraction/database.rs +++ b/crates/storage/db/src/abstraction/database.rs @@ -1,35 +1,31 @@ use crate::{ - common::{Bounds, Sealed}, + abstraction::common::Sealed, table::TableImporter, transaction::{DbTx, DbTxMut}, DatabaseError, }; use std::{fmt::Debug, sync::Arc}; -/// Implements the GAT method from: -/// . +/// Main Database trait that can open read-only and read-write transactions. /// -/// Sealed trait which cannot be implemented by 3rd parties, exposed only for implementers -pub trait DatabaseGAT<'a, __ImplicitBounds: Sealed = Bounds<&'a Self>>: Send + Sync { - /// RO database transaction - type TX: DbTx + Send + Sync + Debug; - /// RW database transaction - type TXMut: DbTxMut + DbTx + TableImporter + Send + Sync + Debug; -} +/// Sealed trait which cannot be implemented by 3rd parties, exposed only for consumption. +pub trait Database: Send + Sync + Sealed { + /// Read-Only database transaction + type TX: DbTx + Send + Sync + Debug + 'static; + /// Read-Write database transaction + type TXMut: DbTxMut + DbTx + TableImporter + Send + Sync + Debug + 'static; -/// Main Database trait that spawns transactions to be executed. -pub trait Database: for<'a> DatabaseGAT<'a> { /// Create read only transaction. - fn tx(&self) -> Result<>::TX, DatabaseError>; + fn tx(&self) -> Result; /// Create read write transaction only possible if database is open with write access. - fn tx_mut(&self) -> Result<>::TXMut, DatabaseError>; + fn tx_mut(&self) -> Result; /// Takes a function and passes a read-only transaction into it, making sure it's closed in the /// end of the execution. fn view(&self, f: F) -> Result where - F: FnOnce(&>::TX) -> T, + F: FnOnce(&Self::TX) -> T, { let tx = self.tx()?; @@ -43,7 +39,7 @@ pub trait Database: for<'a> DatabaseGAT<'a> { /// the end of the execution. 
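From the caller's side, the desugaring means the transaction type is nameable as plain `DB::TX` instead of `<DB as DatabaseGAT<'_>>::TX`. A hedged sketch of a generic read helper built on `view` (table and types illustrative):

```rust,ignore
fn read_header<DB: Database>(db: &DB, num: u64) -> Result<Option<Header>, DatabaseError> {
    // `view` opens a read-only `DB::TX`, runs the closure, then closes the tx.
    db.view(|tx| tx.get::<tables::Headers>(num))?
}
```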
fn update(&self, f: F) -> Result where - F: FnOnce(&>::TXMut) -> T, + F: FnOnce(&Self::TXMut) -> T, { let tx = self.tx_mut()?; @@ -54,34 +50,27 @@ pub trait Database: for<'a> DatabaseGAT<'a> { } } -// Generic over Arc -impl<'a, DB: Database> DatabaseGAT<'a> for Arc { - type TX = >::TX; - type TXMut = >::TXMut; -} - impl Database for Arc { - fn tx(&self) -> Result<>::TX, DatabaseError> { + type TX = ::TX; + type TXMut = ::TXMut; + + fn tx(&self) -> Result { ::tx(self) } - fn tx_mut(&self) -> Result<>::TXMut, DatabaseError> { + fn tx_mut(&self) -> Result { ::tx_mut(self) } } -// Generic over reference -impl<'a, DB: Database> DatabaseGAT<'a> for &DB { - type TX = >::TX; - type TXMut = >::TXMut; -} - impl Database for &DB { - fn tx(&self) -> Result<>::TX, DatabaseError> { + type TX = ::TX; + type TXMut = ::TXMut; + fn tx(&self) -> Result { ::tx(self) } - fn tx_mut(&self) -> Result<>::TXMut, DatabaseError> { + fn tx_mut(&self) -> Result { ::tx_mut(self) } } diff --git a/crates/storage/db/src/abstraction/mock.rs b/crates/storage/db/src/abstraction/mock.rs index 737797008085..d5427b49fbe7 100644 --- a/crates/storage/db/src/abstraction/mock.rs +++ b/crates/storage/db/src/abstraction/mock.rs @@ -1,17 +1,16 @@ //! Mock database -use std::{collections::BTreeMap, ops::RangeBounds}; - use crate::{ common::{PairResult, ValueOnlyResult}, cursor::{ DbCursorRO, DbCursorRW, DbDupCursorRO, DbDupCursorRW, DupWalker, RangeWalker, ReverseWalker, Walker, }, - database::{Database, DatabaseGAT}, + database::Database, table::{DupSort, Table, TableImporter}, - transaction::{DbTx, DbTxGAT, DbTxMut, DbTxMutGAT}, + transaction::{DbTx, DbTxMut}, DatabaseError, }; +use std::{collections::BTreeMap, ops::RangeBounds}; /// Mock database used for testing with inner BTreeMap structure /// TODO @@ -22,21 +21,17 @@ pub struct DatabaseMock { } impl Database for DatabaseMock { - fn tx(&self) -> Result<>::TX, DatabaseError> { + type TX = TxMock; + type TXMut = TxMock; + fn tx(&self) -> Result { Ok(TxMock::default()) } - fn tx_mut(&self) -> Result<>::TXMut, DatabaseError> { + fn tx_mut(&self) -> Result { Ok(TxMock::default()) } } -impl<'a> DatabaseGAT<'a> for DatabaseMock { - type TX = TxMock; - - type TXMut = TxMock; -} - /// Mock read only tx #[derive(Debug, Clone, Default)] pub struct TxMock { @@ -44,17 +39,10 @@ pub struct TxMock { _table: BTreeMap, Vec>, } -impl<'a> DbTxGAT<'a> for TxMock { +impl DbTx for TxMock { type Cursor = CursorMock; type DupCursor = CursorMock; -} - -impl<'a> DbTxMutGAT<'a> for TxMock { - type CursorMut = CursorMock; - type DupCursorMut = CursorMock; -} -impl DbTx for TxMock { fn get(&self, _key: T::Key) -> Result, DatabaseError> { todo!() } @@ -65,13 +53,11 @@ impl DbTx for TxMock { fn abort(self) {} - fn cursor_read(&self) -> Result<>::Cursor, DatabaseError> { + fn cursor_read(&self) -> Result, DatabaseError> { Ok(CursorMock { _cursor: 0 }) } - fn cursor_dup_read( - &self, - ) -> Result<>::DupCursor, DatabaseError> { + fn cursor_dup_read(&self) -> Result, DatabaseError> { Ok(CursorMock { _cursor: 0 }) } @@ -81,6 +67,9 @@ impl DbTx for TxMock { } impl DbTxMut for TxMock { + type CursorMut = CursorMock; + type DupCursorMut = CursorMock; + fn put(&self, _key: T::Key, _value: T::Value) -> Result<(), DatabaseError> { todo!() } @@ -97,15 +86,11 @@ impl DbTxMut for TxMock { todo!() } - fn cursor_write( - &self, - ) -> Result<>::CursorMut, DatabaseError> { + fn cursor_write(&self) -> Result, DatabaseError> { todo!() } - fn cursor_dup_write( - &self, - ) -> Result<>::DupCursorMut, DatabaseError> { + fn 
cursor_dup_write(&self) -> Result, DatabaseError> { todo!() } } @@ -120,57 +105,48 @@ pub struct CursorMock { impl DbCursorRO for CursorMock { fn first(&mut self) -> PairResult { - todo!() + Ok(None) } fn seek_exact(&mut self, _key: T::Key) -> PairResult { - todo!() + Ok(None) } fn seek(&mut self, _key: T::Key) -> PairResult { - todo!() + Ok(None) } fn next(&mut self) -> PairResult { - todo!() + Ok(None) } fn prev(&mut self) -> PairResult { - todo!() + Ok(None) } fn last(&mut self) -> PairResult { - todo!() + Ok(None) } fn current(&mut self) -> PairResult { - todo!() + Ok(None) } - fn walk(&mut self, _start_key: Option) -> Result, DatabaseError> - where - Self: Sized, - { + fn walk(&mut self, _start_key: Option) -> Result, DatabaseError> { todo!() } fn walk_range( &mut self, _range: impl RangeBounds, - ) -> Result, DatabaseError> - where - Self: Sized, - { + ) -> Result, DatabaseError> { todo!() } fn walk_back( &mut self, _start_key: Option, - ) -> Result, DatabaseError> - where - Self: Sized, - { + ) -> Result, DatabaseError> { todo!() } } @@ -200,10 +176,7 @@ impl DbDupCursorRO for CursorMock { &mut self, _key: Option<::Key>, _subkey: Option<::SubKey>, - ) -> Result, DatabaseError> - where - Self: Sized, - { + ) -> Result, DatabaseError> { todo!() } } diff --git a/crates/storage/db/src/abstraction/transaction.rs b/crates/storage/db/src/abstraction/transaction.rs index bbbd775d7a16..472563f3520b 100644 --- a/crates/storage/db/src/abstraction/transaction.rs +++ b/crates/storage/db/src/abstraction/transaction.rs @@ -1,39 +1,16 @@ use crate::{ - common::{Bounds, Sealed}, cursor::{DbCursorRO, DbCursorRW, DbDupCursorRO, DbDupCursorRW}, table::{DupSort, Table}, DatabaseError, }; -/// Implements the GAT method from: -/// . -/// -/// Sealed trait which cannot be implemented by 3rd parties, exposed only for implementers -pub trait DbTxGAT<'a, __ImplicitBounds: Sealed = Bounds<&'a Self>>: Send + Sync { - /// Cursor GAT +/// Read only transaction +pub trait DbTx: Send + Sync { + /// Cursor type for this read-only transaction type Cursor: DbCursorRO + Send + Sync; - /// DupCursor GAT + /// DupCursor type for this read-only transaction type DupCursor: DbDupCursorRO + DbCursorRO + Send + Sync; -} - -/// Implements the GAT method from: -/// . -/// -/// Sealed trait which cannot be implemented by 3rd parties, exposed only for implementers -pub trait DbTxMutGAT<'a, __ImplicitBounds: Sealed = Bounds<&'a Self>>: Send + Sync { - /// Cursor GAT - type CursorMut: DbCursorRW + DbCursorRO + Send + Sync; - /// DupCursor GAT - type DupCursorMut: DbDupCursorRW - + DbCursorRW - + DbDupCursorRO - + DbCursorRO - + Send - + Sync; -} -/// Read only transaction -pub trait DbTx: for<'a> DbTxGAT<'a> { /// Get value fn get(&self, key: T::Key) -> Result, DatabaseError>; /// Commit for read only transaction will consume and free transaction and allows @@ -42,17 +19,25 @@ pub trait DbTx: for<'a> DbTxGAT<'a> { /// Aborts transaction fn abort(self); /// Iterate over read only values in table. - fn cursor_read(&self) -> Result<>::Cursor, DatabaseError>; + fn cursor_read(&self) -> Result, DatabaseError>; /// Iterate over read only values in dup sorted table. - fn cursor_dup_read( - &self, - ) -> Result<>::DupCursor, DatabaseError>; + fn cursor_dup_read(&self) -> Result, DatabaseError>; /// Returns number of entries in the table. 
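Dropping the `where Self: Sized` bounds on the walk methods does not change call sites; typical walking code stays as before (a sketch, table name assumed):

```rust,ignore
let mut cursor = tx.cursor_read::<tables::Headers>()?;
// `walk(None)` starts at the first entry; items are Result<(Key, Value)>.
for entry in cursor.walk(None)? {
    let (number, header) = entry?;
    // ...
}
```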
fn entries(&self) -> Result; } /// Read write transaction that allows writing to database -pub trait DbTxMut: for<'a> DbTxMutGAT<'a> { +pub trait DbTxMut: Send + Sync { + /// Read-Write Cursor type + type CursorMut: DbCursorRW + DbCursorRO + Send + Sync; + /// Read-Write DupCursor type + type DupCursorMut: DbDupCursorRW + + DbCursorRW + + DbDupCursorRO + + DbCursorRO + + Send + + Sync; + /// Put value to database fn put(&self, key: T::Key, value: T::Value) -> Result<(), DatabaseError>; /// Delete value from database @@ -61,11 +46,7 @@ pub trait DbTxMut: for<'a> DbTxMutGAT<'a> { /// Clears database. fn clear(&self) -> Result<(), DatabaseError>; /// Cursor mut - fn cursor_write( - &self, - ) -> Result<>::CursorMut, DatabaseError>; + fn cursor_write(&self) -> Result, DatabaseError>; /// DupCursor mut. - fn cursor_dup_write( - &self, - ) -> Result<>::DupCursorMut, DatabaseError>; + fn cursor_dup_write(&self) -> Result, DatabaseError>; } diff --git a/crates/storage/db/src/implementation/mdbx/cursor.rs b/crates/storage/db/src/implementation/mdbx/cursor.rs index 181e13e05421..63017be2d524 100644 --- a/crates/storage/db/src/implementation/mdbx/cursor.rs +++ b/crates/storage/db/src/implementation/mdbx/cursor.rs @@ -17,15 +17,15 @@ use crate::{ use reth_libmdbx::{self, Error as MDBXError, TransactionKind, WriteFlags, RO, RW}; /// Read only Cursor. -pub type CursorRO<'tx, T> = Cursor<'tx, RO, T>; +pub type CursorRO = Cursor; /// Read write cursor. -pub type CursorRW<'tx, T> = Cursor<'tx, RW, T>; +pub type CursorRW = Cursor; /// Cursor wrapper to access KV items. #[derive(Debug)] -pub struct Cursor<'tx, K: TransactionKind, T: Table> { +pub struct Cursor { /// Inner `libmdbx` cursor. - pub(crate) inner: reth_libmdbx::Cursor<'tx, K>, + pub(crate) inner: reth_libmdbx::Cursor, /// Cache buffer that receives compressed values. buf: Vec, /// Whether to record metrics or not. @@ -34,11 +34,8 @@ pub struct Cursor<'tx, K: TransactionKind, T: Table> { _dbi: PhantomData, } -impl<'tx, K: TransactionKind, T: Table> Cursor<'tx, K, T> { - pub(crate) fn new_with_metrics( - inner: reth_libmdbx::Cursor<'tx, K>, - with_metrics: bool, - ) -> Self { +impl Cursor { + pub(crate) fn new_with_metrics(inner: reth_libmdbx::Cursor, with_metrics: bool) -> Self { Self { inner, buf: Vec::new(), with_metrics, _dbi: PhantomData } } @@ -81,7 +78,7 @@ macro_rules! 
compress_to_buf_or_ref { }; } -impl DbCursorRO for Cursor<'_, K, T> { +impl DbCursorRO for Cursor { fn first(&mut self) -> PairResult { decode!(self.inner.first()) } @@ -110,10 +107,7 @@ impl DbCursorRO for Cursor<'_, K, T> { decode!(self.inner.get_current()) } - fn walk(&mut self, start_key: Option) -> Result, DatabaseError> - where - Self: Sized, - { + fn walk(&mut self, start_key: Option) -> Result, DatabaseError> { let start = if let Some(start_key) = start_key { self.inner .set_range(start_key.encode().as_ref()) @@ -129,10 +123,7 @@ impl DbCursorRO for Cursor<'_, K, T> { fn walk_range( &mut self, range: impl RangeBounds, - ) -> Result, DatabaseError> - where - Self: Sized, - { + ) -> Result, DatabaseError> { let start = match range.start_bound().cloned() { Bound::Included(key) => self.inner.set_range(key.encode().as_ref()), Bound::Excluded(_key) => { @@ -149,10 +140,7 @@ impl DbCursorRO for Cursor<'_, K, T> { fn walk_back( &mut self, start_key: Option, - ) -> Result, DatabaseError> - where - Self: Sized, - { + ) -> Result, DatabaseError> { let start = if let Some(start_key) = start_key { decode!(self.inner.set_range(start_key.encode().as_ref())) } else { @@ -164,7 +152,7 @@ impl DbCursorRO for Cursor<'_, K, T> { } } -impl DbDupCursorRO for Cursor<'_, K, T> { +impl DbDupCursorRO for Cursor { /// Returns the next `(key, value)` pair of a DUPSORT table. fn next_dup(&mut self) -> PairResult { decode!(self.inner.next_dup()) @@ -210,16 +198,14 @@ impl DbDupCursorRO for Cursor<'_, K, T> { let start = match (key, subkey) { (Some(key), Some(subkey)) => { // encode key and decode it after. - let key = key.encode().as_ref().to_vec(); - + let key: Vec = key.encode().into(); self.inner .get_both_range(key.as_ref(), subkey.encode().as_ref()) .map_err(|e| DatabaseError::Read(e.into()))? .map(|val| decoder::((Cow::Owned(key), val))) } (Some(key), None) => { - let key = key.encode().as_ref().to_vec(); - + let key: Vec = key.encode().into(); self.inner .set(key.as_ref()) .map_err(|e| DatabaseError::Read(e.into()))? @@ -227,8 +213,7 @@ impl DbDupCursorRO for Cursor<'_, K, T> { } (None, Some(subkey)) => { if let Some((key, _)) = self.first()? { - let key = key.encode().as_ref().to_vec(); - + let key: Vec = key.encode().into(); self.inner .get_both_range(key.as_ref(), subkey.encode().as_ref()) .map_err(|e| DatabaseError::Read(e.into()))? @@ -245,7 +230,7 @@ impl DbDupCursorRO for Cursor<'_, K, T> { } } -impl DbCursorRW for Cursor<'_, RW, T> { +impl DbCursorRW for Cursor { /// Database operation that will update an existing row if a specified value already /// exists in a table, and insert a new row if the specified value doesn't already exist /// @@ -328,7 +313,7 @@ impl DbCursorRW for Cursor<'_, RW, T> { } } -impl DbDupCursorRW for Cursor<'_, RW, T> { +impl DbDupCursorRW for Cursor { fn delete_current_duplicates(&mut self) -> Result<(), DatabaseError> { self.execute_with_operation_metric(Operation::CursorDeleteCurrentDuplicates, None, |this| { this.inner.del(WriteFlags::NO_DUP_DATA).map_err(|e| DatabaseError::Delete(e.into())) diff --git a/crates/storage/db/src/implementation/mdbx/mod.rs b/crates/storage/db/src/implementation/mdbx/mod.rs index ccb9cee30903..e532ae928e40 100644 --- a/crates/storage/db/src/implementation/mdbx/mod.rs +++ b/crates/storage/db/src/implementation/mdbx/mod.rs @@ -1,7 +1,7 @@ //! Module that interacts with MDBX. 
use crate::{ - database::{Database, DatabaseGAT}, + database::Database, tables::{TableType, Tables}, utils::default_page_size, DatabaseError, @@ -40,20 +40,18 @@ pub struct DatabaseEnv { with_metrics: bool, } -impl<'a> DatabaseGAT<'a> for DatabaseEnv { +impl Database for DatabaseEnv { type TX = tx::Tx; type TXMut = tx::Tx; -} -impl Database for DatabaseEnv { - fn tx(&self) -> Result<>::TX, DatabaseError> { + fn tx(&self) -> Result { Ok(Tx::new_with_metrics( self.inner.begin_ro_txn().map_err(|e| DatabaseError::InitTx(e.into()))?, self.with_metrics, )) } - fn tx_mut(&self) -> Result<>::TXMut, DatabaseError> { + fn tx_mut(&self) -> Result { Ok(Tx::new_with_metrics( self.inner.begin_rw_txn().map_err(|e| DatabaseError::InitTx(e.into()))?, self.with_metrics, diff --git a/crates/storage/db/src/implementation/mdbx/tx.rs b/crates/storage/db/src/implementation/mdbx/tx.rs index 3798587d8280..a9b4488accb3 100644 --- a/crates/storage/db/src/implementation/mdbx/tx.rs +++ b/crates/storage/db/src/implementation/mdbx/tx.rs @@ -7,13 +7,26 @@ use crate::{ }, table::{Compress, DupSort, Encode, Table, TableImporter}, tables::{utils::decode_one, Tables, NUM_TABLES}, - transaction::{DbTx, DbTxGAT, DbTxMut, DbTxMutGAT}, + transaction::{DbTx, DbTxMut}, DatabaseError, }; use parking_lot::RwLock; use reth_interfaces::db::{DatabaseWriteError, DatabaseWriteOperation}; use reth_libmdbx::{ffi::DBI, Transaction, TransactionKind, WriteFlags, RW}; -use std::{marker::PhantomData, str::FromStr, sync::Arc, time::Instant}; +use reth_tracing::tracing::debug; +use std::{ + backtrace::Backtrace, + marker::PhantomData, + str::FromStr, + sync::{ + atomic::{AtomicBool, Ordering}, + Arc, + }, + time::{Duration, Instant}, +}; + +/// Duration after which we emit the log about long-lived database transactions. +const LONG_TRANSACTION_DURATION: Duration = Duration::from_secs(60); /// Wrapper for the libmdbx transaction. #[derive(Debug)] @@ -38,12 +51,7 @@ impl Tx { /// Creates new `Tx` object with a `RO` or `RW` transaction and optionally enables metrics. pub fn new_with_metrics(inner: Transaction, with_metrics: bool) -> Self { let metrics_handler = with_metrics.then(|| { - let handler = MetricsHandler:: { - txn_id: inner.id(), - start: Instant::now(), - close_recorded: false, - _marker: PhantomData, - }; + let handler = MetricsHandler::::new(inner.id()); TransactionMetrics::record_open(handler.transaction_mode()); handler }); @@ -75,7 +83,7 @@ impl Tx { } /// Create db Cursor - pub fn new_cursor(&self) -> Result, DatabaseError> { + pub fn new_cursor(&self) -> Result, DatabaseError> { let inner = self .inner .cursor_with_dbi(self.get_dbi::()?) @@ -95,11 +103,12 @@ impl Tx { ) -> R { if let Some(mut metrics_handler) = self.metrics_handler.take() { metrics_handler.close_recorded = true; + metrics_handler.log_backtrace_on_long_transaction(); let start = Instant::now(); let result = f(self); - let close_duration = start.elapsed(); let open_duration = metrics_handler.start.elapsed(); + let close_duration = start.elapsed(); TransactionMetrics::record_close( metrics_handler.transaction_mode(), @@ -124,7 +133,8 @@ impl Tx { value_size: Option, f: impl FnOnce(&Transaction) -> R, ) -> R { - if self.metrics_handler.is_some() { + if let Some(metrics_handler) = &self.metrics_handler { + metrics_handler.log_backtrace_on_long_transaction(); OperationMetrics::record(T::NAME, operation, value_size, || f(&self.inner)) } else { f(&self.inner) @@ -138,13 +148,26 @@ struct MetricsHandler { txn_id: u64, /// The time when transaction has started. 
start: Instant, - /// If true, the metric about transaction closing has already been recorded and we don't need + /// If `true`, the metric about transaction closing has already been recorded and we don't need /// to do anything on [Drop::drop]. close_recorded: bool, + /// If `true`, the backtrace of transaction has already been recorded and logged. + /// See [MetricsHandler::log_backtrace_on_long_transaction]. + backtrace_recorded: AtomicBool, _marker: PhantomData, } impl MetricsHandler { + fn new(txn_id: u64) -> Self { + Self { + txn_id, + start: Instant::now(), + close_recorded: false, + backtrace_recorded: AtomicBool::new(false), + _marker: PhantomData, + } + } + const fn transaction_mode(&self) -> TransactionMode { if K::IS_READ_ONLY { TransactionMode::ReadOnly @@ -152,11 +175,38 @@ impl MetricsHandler { TransactionMode::ReadWrite } } + + /// Logs the backtrace of current call if the duration that the transaction has been open is + /// more than [LONG_TRANSACTION_DURATION]. + /// The backtrace is recorded and logged just once, guaranteed by `backtrace_recorded` atomic. + /// + /// NOTE: Backtrace is recorded using [Backtrace::force_capture], so `RUST_BACKTRACE` env var is + /// not needed. + fn log_backtrace_on_long_transaction(&self) { + if self.backtrace_recorded.load(Ordering::Relaxed) { + return + } + + let open_duration = self.start.elapsed(); + if open_duration > LONG_TRANSACTION_DURATION { + self.backtrace_recorded.store(true, Ordering::Relaxed); + + let backtrace = Backtrace::force_capture(); + debug!( + target: "storage::db::mdbx", + ?open_duration, + ?backtrace, + "The database transaction has been open for too long" + ); + } + } } impl Drop for MetricsHandler { fn drop(&mut self) { if !self.close_recorded { + self.log_backtrace_on_long_transaction(); + TransactionMetrics::record_close( self.transaction_mode(), TransactionOutcome::Drop, @@ -167,19 +217,12 @@ impl Drop for MetricsHandler { } } -impl<'a, K: TransactionKind> DbTxGAT<'a> for Tx { - type Cursor = Cursor<'a, K, T>; - type DupCursor = Cursor<'a, K, T>; -} - -impl<'a, K: TransactionKind> DbTxMutGAT<'a> for Tx { - type CursorMut = Cursor<'a, RW, T>; - type DupCursorMut = Cursor<'a, RW, T>; -} - impl TableImporter for Tx {} impl DbTx for Tx { + type Cursor = Cursor; + type DupCursor = Cursor; + fn get(&self, key: T::Key) -> Result::Value>, DatabaseError> { self.execute_with_operation_metric::(Operation::Get, None, |tx| { tx.get(self.get_dbi::()?, key.encode().as_ref()) @@ -202,14 +245,12 @@ impl DbTx for Tx { } // Iterate over read only values in database. - fn cursor_read(&self) -> Result<>::Cursor, DatabaseError> { + fn cursor_read(&self) -> Result, DatabaseError> { self.new_cursor() } /// Iterate over read only values in database. 
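To see the new diagnostic fire, it should be enough to hold a metrics-enabled transaction open past `LONG_TRANSACTION_DURATION`; a rough, untested sketch:

```rust,ignore
let tx = db.tx()?; // DatabaseEnv opened with metrics enabled
std::thread::sleep(Duration::from_secs(61));
// The next instrumented operation (or the Drop impl) captures the backtrace
// once, guarded by the `backtrace_recorded` AtomicBool, and logs via debug!.
let _ = tx.get::<tables::Headers>(0)?;
```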
- fn cursor_dup_read( - &self, - ) -> Result<>::DupCursor, DatabaseError> { + fn cursor_dup_read(&self) -> Result, DatabaseError> { self.new_cursor() } @@ -224,6 +265,9 @@ impl DbTx for Tx { } impl DbTxMut for Tx { + type CursorMut = Cursor; + type DupCursorMut = Cursor; + fn put(&self, key: T::Key, value: T::Value) -> Result<(), DatabaseError> { let key = key.encode(); let value = value.compress(); @@ -268,15 +312,11 @@ impl DbTxMut for Tx { Ok(()) } - fn cursor_write( - &self, - ) -> Result<>::CursorMut, DatabaseError> { + fn cursor_write(&self) -> Result, DatabaseError> { self.new_cursor() } - fn cursor_dup_write( - &self, - ) -> Result<>::DupCursorMut, DatabaseError> { + fn cursor_dup_write(&self) -> Result, DatabaseError> { self.new_cursor() } } diff --git a/crates/storage/db/src/lib.rs b/crates/storage/db/src/lib.rs index 250177dfb182..e813bf0d1169 100644 --- a/crates/storage/db/src/lib.rs +++ b/crates/storage/db/src/lib.rs @@ -153,7 +153,7 @@ pub fn open_db(path: &Path, log_level: Option) -> eyre::Result DatabaseGAT<'a> for TempDatabase { - type TX = >::TX; - type TXMut = >::TXMut; - } - impl Database for TempDatabase { - fn tx(&self) -> Result<>::TX, DatabaseError> { + type TX = ::TX; + type TXMut = ::TXMut; + fn tx(&self) -> Result { self.db().tx() } - fn tx_mut(&self) -> Result<>::TXMut, DatabaseError> { + fn tx_mut(&self) -> Result { self.db().tx_mut() } } diff --git a/crates/storage/db/src/tables/models/accounts.rs b/crates/storage/db/src/tables/models/accounts.rs index 57533f57783e..3fe6122e2399 100644 --- a/crates/storage/db/src/tables/models/accounts.rs +++ b/crates/storage/db/src/tables/models/accounts.rs @@ -41,10 +41,7 @@ impl Compact for AccountBeforeTx { acc_len + 20 } - fn from_compact(mut buf: &[u8], len: usize) -> (Self, &[u8]) - where - Self: Sized, - { + fn from_compact(mut buf: &[u8], len: usize) -> (Self, &[u8]) { let address = Address::from_slice(&buf[..20]); buf.advance(20); diff --git a/crates/storage/libmdbx-rs/Cargo.toml b/crates/storage/libmdbx-rs/Cargo.toml index dc65f34faae2..7acda0894fc4 100644 --- a/crates/storage/libmdbx-rs/Cargo.toml +++ b/crates/storage/libmdbx-rs/Cargo.toml @@ -22,8 +22,6 @@ thiserror.workspace = true ffi = { package = "reth-mdbx-sys", path = "./mdbx-sys" } -lifetimed-bytes = { version = "0.1", optional = true } - [features] default = [] return-borrowed = [] diff --git a/crates/storage/libmdbx-rs/benches/cursor.rs b/crates/storage/libmdbx-rs/benches/cursor.rs index 78044e45b9fe..89c87c6f417a 100644 --- a/crates/storage/libmdbx-rs/benches/cursor.rs +++ b/crates/storage/libmdbx-rs/benches/cursor.rs @@ -33,7 +33,7 @@ fn bench_get_seq_iter(c: &mut Criterion) { count += 1; } - fn iterate(cursor: &mut Cursor<'_, K>) -> Result<()> { + fn iterate(cursor: &mut Cursor) -> Result<()> { let mut i = 0; for result in cursor.iter::() { let (key_len, data_len) = result?; diff --git a/crates/storage/libmdbx-rs/src/codec.rs b/crates/storage/libmdbx-rs/src/codec.rs index f313492d7ac4..024c869cfb63 100644 --- a/crates/storage/libmdbx-rs/src/codec.rs +++ b/crates/storage/libmdbx-rs/src/codec.rs @@ -3,11 +3,9 @@ use derive_more::*; use std::{borrow::Cow, slice}; /// Implement this to be able to decode data values -pub trait TableObject<'tx> { +pub trait TableObject: Sized { /// Decodes the object from the given bytes. - fn decode(data_val: &[u8]) -> Result - where - Self: Sized; + fn decode(data_val: &[u8]) -> Result; /// Decodes the value directly from the given MDBX_val pointer. 
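With the `'tx` parameter gone, a custom decoder only needs `TableObject: Sized` and the one required method. A hypothetical big-endian wrapper as illustration (`Result` here is the crate's alias over its `Error`):

```rust,ignore
struct U64BE(u64);

impl TableObject for U64BE {
    fn decode(data_val: &[u8]) -> Result<Self> {
        let bytes: [u8; 8] =
            data_val.try_into().map_err(|_| Error::DecodeErrorLenDiff)?;
        Ok(Self(u64::from_be_bytes(bytes)))
    }
    // `decode_val` keeps its default body, now taking MDBX_val by value.
}
```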
/// @@ -17,18 +15,14 @@ pub trait TableObject<'tx> { #[doc(hidden)] unsafe fn decode_val( _: *const ffi::MDBX_txn, - data_val: &ffi::MDBX_val, - ) -> Result - where - Self: Sized, - { + data_val: ffi::MDBX_val, + ) -> Result { let s = slice::from_raw_parts(data_val.iov_base as *const u8, data_val.iov_len); - - TableObject::decode(s) + Self::decode(s) } } -impl<'tx> TableObject<'tx> for Cow<'tx, [u8]> { +impl<'tx> TableObject for Cow<'tx, [u8]> { fn decode(_: &[u8]) -> Result { unreachable!() } @@ -36,7 +30,7 @@ impl<'tx> TableObject<'tx> for Cow<'tx, [u8]> { #[doc(hidden)] unsafe fn decode_val( _txn: *const ffi::MDBX_txn, - data_val: &ffi::MDBX_val, + data_val: ffi::MDBX_val, ) -> Result { let s = slice::from_raw_parts(data_val.iov_base as *const u8, data_val.iov_len); @@ -55,38 +49,20 @@ impl<'tx> TableObject<'tx> for Cow<'tx, [u8]> { } } -#[cfg(feature = "lifetimed-bytes")] -impl<'tx> TableObject<'tx> for lifetimed_bytes::Bytes<'tx> { - fn decode(_: &[u8]) -> Result { - unreachable!() - } - - #[doc(hidden)] - unsafe fn decode_val( - txn: *const ffi::MDBX_txn, - data_val: &ffi::MDBX_val, - ) -> Result { - Cow::<'tx, [u8]>::decode_val::(txn, data_val).map(From::from) - } -} - -impl<'tx> TableObject<'tx> for Vec { - fn decode(data_val: &[u8]) -> Result - where - Self: Sized, - { +impl TableObject for Vec { + fn decode(data_val: &[u8]) -> Result { Ok(data_val.to_vec()) } } -impl<'tx> TableObject<'tx> for () { +impl TableObject for () { fn decode(_: &[u8]) -> Result { Ok(()) } unsafe fn decode_val( _: *const ffi::MDBX_txn, - _: &ffi::MDBX_val, + _: ffi::MDBX_val, ) -> Result { Ok(()) } @@ -96,20 +72,14 @@ impl<'tx> TableObject<'tx> for () { #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Deref, DerefMut)] pub struct ObjectLength(pub usize); -impl<'tx> TableObject<'tx> for ObjectLength { - fn decode(data_val: &[u8]) -> Result - where - Self: Sized, - { +impl TableObject for ObjectLength { + fn decode(data_val: &[u8]) -> Result { Ok(Self(data_val.len())) } } -impl<'tx, const LEN: usize> TableObject<'tx> for [u8; LEN] { - fn decode(data_val: &[u8]) -> Result - where - Self: Sized, - { +impl TableObject for [u8; LEN] { + fn decode(data_val: &[u8]) -> Result { if data_val.len() != LEN { return Err(Error::DecodeErrorLenDiff) } diff --git a/crates/storage/libmdbx-rs/src/cursor.rs b/crates/storage/libmdbx-rs/src/cursor.rs index dd242586e073..a5cb2a3830a8 100644 --- a/crates/storage/libmdbx-rs/src/cursor.rs +++ b/crates/storage/libmdbx-rs/src/cursor.rs @@ -2,7 +2,7 @@ use crate::{ error::{mdbx_result, Error, Result}, flags::*, mdbx_try_optional, - transaction::{TransactionKind, TransactionPtr, RW}, + transaction::{TransactionKind, RW}, TableObject, Transaction, }; use ffi::{ @@ -15,26 +15,24 @@ use libc::c_void; use std::{borrow::Cow, fmt, marker::PhantomData, mem, ptr}; /// A cursor for navigating the items within a database. 
-pub struct Cursor<'txn, K> +pub struct Cursor where K: TransactionKind, { - txn: TransactionPtr, + txn: Transaction, cursor: *mut ffi::MDBX_cursor, - _marker: PhantomData, } -impl<'txn, K> Cursor<'txn, K> +impl Cursor where K: TransactionKind, { - pub(crate) fn new(txn: &'txn Transaction, dbi: ffi::MDBX_dbi) -> Result { + pub(crate) fn new(txn: Transaction, dbi: ffi::MDBX_dbi) -> Result { let mut cursor: *mut ffi::MDBX_cursor = ptr::null_mut(); - let txn = txn.txn_ptr(); unsafe { mdbx_result(txn.txn_execute(|txn| ffi::mdbx_cursor_open(txn, dbi, &mut cursor)))?; } - Ok(Self { txn, cursor, _marker: PhantomData }) + Ok(Self { txn, cursor }) } fn new_at_position(other: &Self) -> Result { @@ -43,7 +41,7 @@ where let res = ffi::mdbx_cursor_copy(other.cursor(), cursor); - let s = Self { txn: other.txn.clone(), cursor, _marker: PhantomData }; + let s = Self { txn: other.txn.clone(), cursor }; mdbx_result(res)?; @@ -59,6 +57,22 @@ where self.cursor } + /// Returns an iterator over the raw key value slices. + #[allow(clippy::needless_lifetimes)] + pub fn iter_slices<'a>(&'a self) -> IntoIter<'a, K, Cow<'a, [u8]>, Cow<'a, [u8]>> { + self.into_iter() + } + + /// Returns an iterator over database items. + #[allow(clippy::should_implement_trait)] + pub fn into_iter(&self) -> IntoIter<'_, K, Key, Value> + where + Key: TableObject, + Value: TableObject, + { + IntoIter::new(self.clone(), MDBX_NEXT, MDBX_NEXT) + } + /// Retrieves a key/data pair from the cursor. Depending on the cursor op, /// the current key may be returned. fn get( @@ -68,8 +82,8 @@ where op: MDBX_cursor_op, ) -> Result<(Option, Value, bool)> where - Key: TableObject<'txn>, - Value: TableObject<'txn>, + Key: TableObject, + Value: TableObject, { unsafe { let mut key_val = slice_to_val(key); @@ -87,12 +101,12 @@ where let key_out = { // MDBX wrote in new key if key_ptr != key_val.iov_base { - Some(Key::decode_val::(txn, &key_val)?) + Some(Key::decode_val::(txn, key_val)?) } else { None } }; - let data_out = Value::decode_val::(txn, &data_val)?; + let data_out = Value::decode_val::(txn, data_val)?; Ok((key_out, data_out, v)) }) } @@ -105,7 +119,7 @@ where op: MDBX_cursor_op, ) -> Result> where - Value: TableObject<'txn>, + Value: TableObject, { let (_, v, _) = mdbx_try_optional!(self.get::<(), Value>(key, data, op)); @@ -119,8 +133,8 @@ where op: MDBX_cursor_op, ) -> Result> where - Key: TableObject<'txn>, - Value: TableObject<'txn>, + Key: TableObject, + Value: TableObject, { let (k, v, _) = mdbx_try_optional!(self.get(key, data, op)); @@ -130,8 +144,8 @@ where /// Position at first key/data item. pub fn first(&mut self) -> Result> where - Key: TableObject<'txn>, - Value: TableObject<'txn>, + Key: TableObject, + Value: TableObject, { self.get_full(None, None, MDBX_FIRST) } @@ -139,7 +153,7 @@ where /// [DatabaseFlags::DUP_SORT]-only: Position at first data item of current key. pub fn first_dup(&mut self) -> Result> where - Value: TableObject<'txn>, + Value: TableObject, { self.get_value(None, None, MDBX_FIRST_DUP) } @@ -147,7 +161,7 @@ where /// [DatabaseFlags::DUP_SORT]-only: Position at key/data pair. pub fn get_both(&mut self, k: &[u8], v: &[u8]) -> Result> where - Value: TableObject<'txn>, + Value: TableObject, { self.get_value(Some(k), Some(v), MDBX_GET_BOTH) } @@ -156,7 +170,7 @@ where /// equal to specified data. 
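`iter_slices` stands in for the removed `IntoIterator` impl, as the freelist fix in environment.rs further down shows. Usage sketch:

```rust,ignore
let txn = env.begin_ro_txn()?;
let cursor = txn.cursor(&db)?;
// Borrows the cursor and yields raw Cow<'_, [u8]> key/value pairs.
for kv in cursor.iter_slices() {
    let (key, value) = kv?;
    // ...
}
```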
pub fn get_both_range(&mut self, k: &[u8], v: &[u8]) -> Result> where - Value: TableObject<'txn>, + Value: TableObject, { self.get_value(Some(k), Some(v), MDBX_GET_BOTH_RANGE) } @@ -164,8 +178,8 @@ where /// Return key/data at current cursor position. pub fn get_current(&mut self) -> Result> where - Key: TableObject<'txn>, - Value: TableObject<'txn>, + Key: TableObject, + Value: TableObject, { self.get_full(None, None, MDBX_GET_CURRENT) } @@ -174,7 +188,7 @@ where /// Move cursor to prepare for [Self::next_multiple()]. pub fn get_multiple(&mut self) -> Result> where - Value: TableObject<'txn>, + Value: TableObject, { self.get_value(None, None, MDBX_GET_MULTIPLE) } @@ -182,8 +196,8 @@ where /// Position at last key/data item. pub fn last(&mut self) -> Result> where - Key: TableObject<'txn>, - Value: TableObject<'txn>, + Key: TableObject, + Value: TableObject, { self.get_full(None, None, MDBX_LAST) } @@ -191,7 +205,7 @@ where /// DupSort-only: Position at last data item of current key. pub fn last_dup(&mut self) -> Result> where - Value: TableObject<'txn>, + Value: TableObject, { self.get_value(None, None, MDBX_LAST_DUP) } @@ -200,8 +214,8 @@ where #[allow(clippy::should_implement_trait)] pub fn next(&mut self) -> Result> where - Key: TableObject<'txn>, - Value: TableObject<'txn>, + Key: TableObject, + Value: TableObject, { self.get_full(None, None, MDBX_NEXT) } @@ -209,8 +223,8 @@ where /// [DatabaseFlags::DUP_SORT]-only: Position at next data item of current key. pub fn next_dup(&mut self) -> Result> where - Key: TableObject<'txn>, - Value: TableObject<'txn>, + Key: TableObject, + Value: TableObject, { self.get_full(None, None, MDBX_NEXT_DUP) } @@ -219,8 +233,8 @@ where /// cursor position. Move cursor to prepare for MDBX_NEXT_MULTIPLE. pub fn next_multiple(&mut self) -> Result> where - Key: TableObject<'txn>, - Value: TableObject<'txn>, + Key: TableObject, + Value: TableObject, { self.get_full(None, None, MDBX_NEXT_MULTIPLE) } @@ -228,8 +242,8 @@ where /// Position at first data item of next key. pub fn next_nodup(&mut self) -> Result> where - Key: TableObject<'txn>, - Value: TableObject<'txn>, + Key: TableObject, + Value: TableObject, { self.get_full(None, None, MDBX_NEXT_NODUP) } @@ -237,8 +251,8 @@ where /// Position at previous data item. pub fn prev(&mut self) -> Result> where - Key: TableObject<'txn>, - Value: TableObject<'txn>, + Key: TableObject, + Value: TableObject, { self.get_full(None, None, MDBX_PREV) } @@ -246,8 +260,8 @@ where /// [DatabaseFlags::DUP_SORT]-only: Position at previous data item of current key. pub fn prev_dup(&mut self) -> Result> where - Key: TableObject<'txn>, - Value: TableObject<'txn>, + Key: TableObject, + Value: TableObject, { self.get_full(None, None, MDBX_PREV_DUP) } @@ -255,8 +269,8 @@ where /// Position at last data item of previous key. pub fn prev_nodup(&mut self) -> Result> where - Key: TableObject<'txn>, - Value: TableObject<'txn>, + Key: TableObject, + Value: TableObject, { self.get_full(None, None, MDBX_PREV_NODUP) } @@ -264,7 +278,7 @@ where /// Position at specified key. pub fn set(&mut self, key: &[u8]) -> Result> where - Value: TableObject<'txn>, + Value: TableObject, { self.get_value(Some(key), None, MDBX_SET) } @@ -272,8 +286,8 @@ where /// Position at specified key, return both key and data. 
pub fn set_key(&mut self, key: &[u8]) -> Result> where - Key: TableObject<'txn>, - Value: TableObject<'txn>, + Key: TableObject, + Value: TableObject, { self.get_full(Some(key), None, MDBX_SET_KEY) } @@ -281,8 +295,8 @@ where /// Position at first key greater than or equal to specified key. pub fn set_range(&mut self, key: &[u8]) -> Result> where - Key: TableObject<'txn>, - Value: TableObject<'txn>, + Key: TableObject, + Value: TableObject, { self.get_full(Some(key), None, MDBX_SET_RANGE) } @@ -291,8 +305,8 @@ where /// duplicate data items. pub fn prev_multiple(&mut self) -> Result> where - Key: TableObject<'txn>, - Value: TableObject<'txn>, + Key: TableObject, + Value: TableObject, { self.get_full(None, None, MDBX_PREV_MULTIPLE) } @@ -308,26 +322,26 @@ where /// exactly and [true] if the next pair was returned. pub fn set_lowerbound(&mut self, key: &[u8]) -> Result> where - Key: TableObject<'txn>, - Value: TableObject<'txn>, + Key: TableObject, + Value: TableObject, { let (k, v, found) = mdbx_try_optional!(self.get(Some(key), None, MDBX_SET_LOWERBOUND)); Ok(Some((found, k.unwrap(), v))) } - /// Iterate over database items. The iterator will begin with item next - /// after the cursor, and continue until the end of the database. For new - /// cursors, the iterator will begin with the first item in the database. + /// Returns an iterator over database items. + /// + /// The iterator will begin with item next after the cursor, and continue until the end of the + /// database. For new cursors, the iterator will begin with the first item in the database. /// /// For databases with duplicate data items ([DatabaseFlags::DUP_SORT]), the /// duplicate data items of each key will be returned before moving on to /// the next key. - pub fn iter(&mut self) -> Iter<'txn, '_, K, Key, Value> + pub fn iter(&mut self) -> Iter<'_, K, Key, Value> where - Self: Sized, - Key: TableObject<'txn>, - Value: TableObject<'txn>, + Key: TableObject, + Value: TableObject, { Iter::new(self, ffi::MDBX_NEXT, ffi::MDBX_NEXT) } @@ -337,11 +351,10 @@ where /// For databases with duplicate data items ([DatabaseFlags::DUP_SORT]), the /// duplicate data items of each key will be returned before moving on to /// the next key. - pub fn iter_start(&mut self) -> Iter<'txn, '_, K, Key, Value> + pub fn iter_start(&mut self) -> Iter<'_, K, Key, Value> where - Self: Sized, - Key: TableObject<'txn>, - Value: TableObject<'txn>, + Key: TableObject, + Value: TableObject, { Iter::new(self, ffi::MDBX_FIRST, ffi::MDBX_NEXT) } @@ -351,10 +364,10 @@ where /// For databases with duplicate data items ([DatabaseFlags::DUP_SORT]), the /// duplicate data items of each key will be returned before moving on to /// the next key. - pub fn iter_from(&mut self, key: &[u8]) -> Iter<'txn, '_, K, Key, Value> + pub fn iter_from(&mut self, key: &[u8]) -> Iter<'_, K, Key, Value> where - Key: TableObject<'txn>, - Value: TableObject<'txn>, + Key: TableObject, + Value: TableObject, { let res: Result> = self.set_range(key); if let Err(error) = res { @@ -366,30 +379,30 @@ where /// Iterate over duplicate database items. The iterator will begin with the /// item next after the cursor, and continue until the end of the database. /// Each item will be returned as an iterator of its duplicates. 
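Typed iteration now carries only the `'cur` borrow of the cursor rather than a transaction lifetime; for example (key/value types hypothetical):

```rust,ignore
let mut cursor = txn.cursor(&db)?;
// Seek to the first key >= b"key5", then iterate to the end of the table.
for item in cursor.iter_from::<Vec<u8>, Vec<u8>>(b"key5") {
    let (key, value) = item?;
    // ...
}
```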
- pub fn iter_dup(&mut self) -> IterDup<'txn, '_, K, Key, Value> + pub fn iter_dup(&mut self) -> IterDup<'_, K, Key, Value> where - Key: TableObject<'txn>, - Value: TableObject<'txn>, + Key: TableObject, + Value: TableObject, { IterDup::new(self, ffi::MDBX_NEXT) } /// Iterate over duplicate database items starting from the beginning of the /// database. Each item will be returned as an iterator of its duplicates. - pub fn iter_dup_start(&mut self) -> IterDup<'txn, '_, K, Key, Value> + pub fn iter_dup_start(&mut self) -> IterDup<'_, K, Key, Value> where - Key: TableObject<'txn>, - Value: TableObject<'txn>, + Key: TableObject, + Value: TableObject, { IterDup::new(self, ffi::MDBX_FIRST) } /// Iterate over duplicate items in the database starting from the given /// key. Each item will be returned as an iterator of its duplicates. - pub fn iter_dup_from(&mut self, key: &[u8]) -> IterDup<'txn, '_, K, Key, Value> + pub fn iter_dup_from(&mut self, key: &[u8]) -> IterDup<'_, K, Key, Value> where - Key: TableObject<'txn>, - Value: TableObject<'txn>, + Key: TableObject, + Value: TableObject, { let res: Result> = self.set_range(key); if let Err(error) = res { @@ -399,10 +412,10 @@ where } /// Iterate over the duplicates of the item in the database with the given key. - pub fn iter_dup_of(&mut self, key: &[u8]) -> Iter<'txn, '_, K, Key, Value> + pub fn iter_dup_of(&mut self, key: &[u8]) -> Iter<'_, K, Key, Value> where - Key: TableObject<'txn>, - Value: TableObject<'txn>, + Key: TableObject, + Value: TableObject, { let res: Result> = self.set(key); match res { @@ -417,7 +430,7 @@ where } } -impl<'txn> Cursor<'txn, RW> { +impl Cursor { /// Puts a key/data pair into the database. The cursor will be positioned at /// the new data item, or on failure usually near it. pub fn put(&mut self, key: &[u8], data: &[u8], flags: WriteFlags) -> Result<()> { @@ -449,7 +462,7 @@ impl<'txn> Cursor<'txn, RW> { } } -impl<'txn, K> Clone for Cursor<'txn, K> +impl Clone for Cursor where K: TransactionKind, { @@ -458,7 +471,7 @@ where } } -impl<'txn, K> fmt::Debug for Cursor<'txn, K> +impl fmt::Debug for Cursor where K: TransactionKind, { @@ -467,7 +480,7 @@ where } } -impl<'txn, K> Drop for Cursor<'txn, K> +impl Drop for Cursor where K: TransactionKind, { @@ -485,28 +498,16 @@ unsafe fn slice_to_val(slice: Option<&[u8]>) -> ffi::MDBX_val { } } -unsafe impl<'txn, K> Send for Cursor<'txn, K> where K: TransactionKind {} -unsafe impl<'txn, K> Sync for Cursor<'txn, K> where K: TransactionKind {} - -impl<'txn, K> IntoIterator for Cursor<'txn, K> -where - K: TransactionKind, -{ - type Item = Result<(Cow<'txn, [u8]>, Cow<'txn, [u8]>)>; - type IntoIter = IntoIter<'txn, K, Cow<'txn, [u8]>, Cow<'txn, [u8]>>; - - fn into_iter(self) -> Self::IntoIter { - IntoIter::new(self, MDBX_NEXT, MDBX_NEXT) - } -} +unsafe impl Send for Cursor where K: TransactionKind {} +unsafe impl Sync for Cursor where K: TransactionKind {} /// An iterator over the key/value pairs in an MDBX database. #[derive(Debug)] -pub enum IntoIter<'txn, K, Key, Value> +pub enum IntoIter<'cur, K, Key, Value> where K: TransactionKind, - Key: TableObject<'txn>, - Value: TableObject<'txn>, + Key: TableObject, + Value: TableObject, { /// An iterator that returns an error on every call to [Iter::next()]. /// Cursor.iter*() creates an Iter of this type when MDBX returns an error @@ -521,7 +522,7 @@ where /// fails for some reason. Ok { /// The MDBX cursor with which to iterate. 
- cursor: Cursor<'txn, K>, + cursor: Cursor, /// The first operation to perform when the consumer calls [Iter::next()]. op: ffi::MDBX_cursor_op, @@ -529,33 +530,33 @@ where /// The next and subsequent operations to perform. next_op: ffi::MDBX_cursor_op, - _marker: PhantomData, + _marker: PhantomData<(&'cur (), Key, Value)>, }, } -impl<'txn, K, Key, Value> IntoIter<'txn, K, Key, Value> +impl<'cur, K, Key, Value> IntoIter<'cur, K, Key, Value> where K: TransactionKind, - Key: TableObject<'txn>, - Value: TableObject<'txn>, + Key: TableObject, + Value: TableObject, { /// Creates a new iterator backed by the given cursor. - fn new(cursor: Cursor<'txn, K>, op: ffi::MDBX_cursor_op, next_op: ffi::MDBX_cursor_op) -> Self { - IntoIter::Ok { cursor, op, next_op, _marker: PhantomData } + fn new(cursor: Cursor, op: ffi::MDBX_cursor_op, next_op: ffi::MDBX_cursor_op) -> Self { + IntoIter::Ok { cursor, op, next_op, _marker: Default::default() } } } -impl<'txn, K, Key, Value> Iterator for IntoIter<'txn, K, Key, Value> +impl<'cur, K, Key, Value> Iterator for IntoIter<'cur, K, Key, Value> where K: TransactionKind, - Key: TableObject<'txn>, - Value: TableObject<'txn>, + Key: TableObject, + Value: TableObject, { type Item = Result<(Key, Value)>; fn next(&mut self) -> Option { match self { - Self::Ok { cursor, op, next_op, _marker } => { + Self::Ok { cursor, op, next_op, .. } => { let mut key = ffi::MDBX_val { iov_len: 0, iov_base: ptr::null_mut() }; let mut data = ffi::MDBX_val { iov_len: 0, iov_base: ptr::null_mut() }; let op = mem::replace(op, *next_op); @@ -563,11 +564,11 @@ where cursor.txn.txn_execute(|txn| { match ffi::mdbx_cursor_get(cursor.cursor(), &mut key, &mut data, op) { ffi::MDBX_SUCCESS => { - let key = match Key::decode_val::(txn, &key) { + let key = match Key::decode_val::(txn, key) { Ok(v) => v, Err(e) => return Some(Err(e)), }; - let data = match Value::decode_val::(txn, &data) { + let data = match Value::decode_val::(txn, data) { Ok(v) => v, Err(e) => return Some(Err(e)), }; @@ -589,11 +590,11 @@ where /// An iterator over the key/value pairs in an MDBX database. #[derive(Debug)] -pub enum Iter<'txn, 'cur, K, Key, Value> +pub enum Iter<'cur, K, Key, Value> where K: TransactionKind, - Key: TableObject<'txn>, - Value: TableObject<'txn>, + Key: TableObject, + Value: TableObject, { /// An iterator that returns an error on every call to [Iter::next()]. /// Cursor.iter*() creates an Iter of this type when MDBX returns an error @@ -608,7 +609,7 @@ where /// fails for some reason. Ok { /// The MDBX cursor with which to iterate. - cursor: &'cur mut Cursor<'txn, K>, + cursor: &'cur mut Cursor, /// The first operation to perform when the consumer calls [Iter::next()]. op: ffi::MDBX_cursor_op, @@ -616,31 +617,31 @@ where /// The next and subsequent operations to perform. next_op: ffi::MDBX_cursor_op, - _marker: PhantomData, + _marker: PhantomData, }, } -impl<'txn, 'cur, K, Key, Value> Iter<'txn, 'cur, K, Key, Value> +impl<'cur, K, Key, Value> Iter<'cur, K, Key, Value> where K: TransactionKind, - Key: TableObject<'txn>, - Value: TableObject<'txn>, + Key: TableObject, + Value: TableObject, { /// Creates a new iterator backed by the given cursor. 
fn new( - cursor: &'cur mut Cursor<'txn, K>, + cursor: &'cur mut Cursor, op: ffi::MDBX_cursor_op, next_op: ffi::MDBX_cursor_op, ) -> Self { - Iter::Ok { cursor, op, next_op, _marker: PhantomData } + Iter::Ok { cursor, op, next_op, _marker: Default::default() } } } -impl<'txn, 'cur, K, Key, Value> Iterator for Iter<'txn, 'cur, K, Key, Value> +impl<'cur, K, Key, Value> Iterator for Iter<'cur, K, Key, Value> where K: TransactionKind, - Key: TableObject<'txn>, - Value: TableObject<'txn>, + Key: TableObject, + Value: TableObject, { type Item = Result<(Key, Value)>; @@ -654,11 +655,11 @@ where cursor.txn.txn_execute(|txn| { match ffi::mdbx_cursor_get(cursor.cursor(), &mut key, &mut data, op) { ffi::MDBX_SUCCESS => { - let key = match Key::decode_val::(txn, &key) { + let key = match Key::decode_val::(txn, key) { Ok(v) => v, Err(e) => return Some(Err(e)), }; - let data = match Value::decode_val::(txn, &data) { + let data = match Value::decode_val::(txn, data) { Ok(v) => v, Err(e) => return Some(Err(e)), }; @@ -682,11 +683,11 @@ where /// /// The yielded items of the iterator are themselves iterators over the duplicate values for a /// specific key. -pub enum IterDup<'txn, 'cur, K, Key, Value> +pub enum IterDup<'cur, K, Key, Value> where K: TransactionKind, - Key: TableObject<'txn>, - Value: TableObject<'txn>, + Key: TableObject, + Value: TableObject, { /// An iterator that returns an error on every call to Iter.next(). /// Cursor.iter*() creates an Iter of this type when MDBX returns an error @@ -701,45 +702,45 @@ where /// fails for some reason. Ok { /// The MDBX cursor with which to iterate. - cursor: &'cur mut Cursor<'txn, K>, + cursor: &'cur mut Cursor, /// The first operation to perform when the consumer calls Iter.next(). op: MDBX_cursor_op, - _marker: PhantomData, + _marker: PhantomData, }, } -impl<'txn, 'cur, K, Key, Value> IterDup<'txn, 'cur, K, Key, Value> +impl<'cur, K, Key, Value> IterDup<'cur, K, Key, Value> where K: TransactionKind, - Key: TableObject<'txn>, - Value: TableObject<'txn>, + Key: TableObject, + Value: TableObject, { /// Creates a new iterator backed by the given cursor. 
- fn new(cursor: &'cur mut Cursor<'txn, K>, op: MDBX_cursor_op) -> Self { - IterDup::Ok { cursor, op, _marker: PhantomData } + fn new(cursor: &'cur mut Cursor, op: MDBX_cursor_op) -> Self { + IterDup::Ok { cursor, op, _marker: Default::default() } } } -impl<'txn, 'cur, K, Key, Value> fmt::Debug for IterDup<'txn, 'cur, K, Key, Value> +impl<'cur, K, Key, Value> fmt::Debug for IterDup<'cur, K, Key, Value> where K: TransactionKind, - Key: TableObject<'txn>, - Value: TableObject<'txn>, + Key: TableObject, + Value: TableObject, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("IterDup").finish() } } -impl<'txn, 'cur, K, Key, Value> Iterator for IterDup<'txn, 'cur, K, Key, Value> +impl<'cur, K, Key, Value> Iterator for IterDup<'cur, K, Key, Value> where K: TransactionKind, - Key: TableObject<'txn>, - Value: TableObject<'txn>, + Key: TableObject, + Value: TableObject, { - type Item = IntoIter<'txn, K, Key, Value>; + type Item = IntoIter<'cur, K, Key, Value>; fn next(&mut self) -> Option { match self { diff --git a/crates/storage/libmdbx-rs/src/environment.rs b/crates/storage/libmdbx-rs/src/environment.rs index 80d63cc1b241..0b83a243c0f4 100644 --- a/crates/storage/libmdbx-rs/src/environment.rs +++ b/crates/storage/libmdbx-rs/src/environment.rs @@ -208,7 +208,7 @@ impl Environment { let db = Database::freelist_db(); let cursor = txn.cursor(&db)?; - for result in cursor { + for result in cursor.iter_slices() { let (_key, value) = result?; if value.len() < size_of::() { return Err(Error::Corrupted) diff --git a/crates/storage/libmdbx-rs/src/transaction.rs b/crates/storage/libmdbx-rs/src/transaction.rs index c75168e0452b..2330c3a3f891 100644 --- a/crates/storage/libmdbx-rs/src/transaction.rs +++ b/crates/storage/libmdbx-rs/src/transaction.rs @@ -107,11 +107,6 @@ where self.inner.txn_execute(f) } - /// Returns a copy of the pointer to the underlying MDBX transaction. - pub(crate) fn txn_ptr(&self) -> TransactionPtr { - self.inner.txn.clone() - } - /// Returns a copy of the raw pointer to the underlying MDBX transaction. #[doc(hidden)] pub fn txn(&self) -> *mut ffi::MDBX_txn { @@ -151,9 +146,9 @@ where /// returned. Retrieval of other items requires the use of /// [Cursor]. If the item is not in the database, then /// [None] will be returned. - pub fn get<'txn, Key>(&'txn self, dbi: ffi::MDBX_dbi, key: &[u8]) -> Result> + pub fn get(&self, dbi: ffi::MDBX_dbi, key: &[u8]) -> Result> where - Key: TableObject<'txn>, + Key: TableObject, { let key_val: ffi::MDBX_val = ffi::MDBX_val { iov_len: key.len(), iov_base: key.as_ptr() as *mut c_void }; @@ -161,7 +156,7 @@ where self.txn_execute(|txn| unsafe { match ffi::mdbx_get(txn, dbi, &key_val, &mut data_val) { - ffi::MDBX_SUCCESS => Key::decode_val::(txn, &data_val).map(Some), + ffi::MDBX_SUCCESS => Key::decode_val::(txn, data_val).map(Some), ffi::MDBX_NOTFOUND => Ok(None), err_code => Err(Error::from_err_code(err_code)), } @@ -257,13 +252,31 @@ where } /// Open a new cursor on the given database. - pub fn cursor(&self, db: &Database) -> Result> { - Cursor::new(self, db.dbi()) + pub fn cursor(&self, db: &Database) -> Result> { + Cursor::new(self.clone(), db.dbi()) } /// Open a new cursor on the given dbi. 
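Since `Transaction` now clones by bumping its inner `Arc`, `cursor()` hands each cursor its own handle and several cursors can coexist on one transaction. A sketch of the assumed semantics:

```rust,ignore
let txn = env.begin_ro_txn()?;
let first = txn.cursor(&db)?;  // holds a Transaction clone
let second = txn.cursor(&db)?; // independent handle, same underlying txn
// Dropping `txn` here should not end the MDBX transaction: the cursors
// keep it alive through their own clones.
```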
- pub fn cursor_with_dbi(&self, dbi: ffi::MDBX_dbi) -> Result> { - Cursor::new(self, dbi) + pub fn cursor_with_dbi(&self, dbi: ffi::MDBX_dbi) -> Result> { + Cursor::new(self.clone(), dbi) + } +} + +impl Clone for Transaction +where + K: TransactionKind, +{ + fn clone(&self) -> Self { + Self { inner: Arc::clone(&self.inner) } + } +} + +impl fmt::Debug for Transaction +where + K: TransactionKind, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("RoTransaction").finish_non_exhaustive() } } @@ -499,15 +512,6 @@ impl Transaction { } } -impl fmt::Debug for Transaction -where - K: TransactionKind, -{ - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("RoTransaction").finish_non_exhaustive() - } -} - /// A shareable pointer to an MDBX transaction. #[derive(Clone)] pub(crate) struct TransactionPtr { diff --git a/crates/storage/provider/Cargo.toml b/crates/storage/provider/Cargo.toml index d022ca972017..1cecabfce4a2 100644 --- a/crates/storage/provider/Cargo.toml +++ b/crates/storage/provider/Cargo.toml @@ -55,7 +55,7 @@ assert_matches.workspace = true rand.workspace = true [features] -test-utils = ["alloy-rlp"] +test-utils = ["alloy-rlp", "reth-db/test-utils"] optimism = [ "reth-primitives/optimism", "reth-interfaces/optimism" diff --git a/crates/storage/provider/src/bundle_state/bundle_state_with_receipts.rs b/crates/storage/provider/src/bundle_state/bundle_state_with_receipts.rs index c1ee9b65f477..2a65909fdb65 100644 --- a/crates/storage/provider/src/bundle_state/bundle_state_with_receipts.rs +++ b/crates/storage/provider/src/bundle_state/bundle_state_with_receipts.rs @@ -13,6 +13,7 @@ use reth_primitives::{ }; use reth_trie::{ hashed_cursor::{HashedPostState, HashedPostStateCursorFactory, HashedStorage}, + updates::TrieUpdates, StateRoot, StateRootError, }; use revm::{db::states::BundleState, primitives::AccountInfo}; @@ -154,6 +155,20 @@ impl BundleStateWithReceipts { hashed_state.sorted() } + /// Returns [StateRoot] calculator. + fn state_root_calculator<'a, 'b, TX: DbTx>( + &self, + tx: &'a TX, + hashed_post_state: &'b HashedPostState, + ) -> StateRoot<'a, TX, HashedPostStateCursorFactory<'a, 'b, TX>> { + let (account_prefix_set, storage_prefix_set) = hashed_post_state.construct_prefix_sets(); + let hashed_cursor_factory = HashedPostStateCursorFactory::new(tx, hashed_post_state); + StateRoot::new(tx) + .with_hashed_cursor_factory(hashed_cursor_factory) + .with_changed_account_prefixes(account_prefix_set) + .with_changed_storage_prefixes(storage_prefix_set) + } + /// Calculate the state root for this [BundleState]. /// Internally, function calls [Self::hash_state_slow] to obtain the [HashedPostState]. /// Afterwards, it retrieves the prefixsets from the [HashedPostState] and uses them to @@ -196,13 +211,17 @@ impl BundleStateWithReceipts { /// The state root for this [BundleState]. pub fn state_root_slow(&self, tx: &TX) -> Result { let hashed_post_state = self.hash_state_slow(); - let (account_prefix_set, storage_prefix_set) = hashed_post_state.construct_prefix_sets(); - let hashed_cursor_factory = HashedPostStateCursorFactory::new(tx, &hashed_post_state); - StateRoot::new(tx) - .with_hashed_cursor_factory(hashed_cursor_factory) - .with_changed_account_prefixes(account_prefix_set) - .with_changed_storage_prefixes(storage_prefix_set) - .root() + self.state_root_calculator(tx, &hashed_post_state).root() + } + + /// Calculates the state root for this [BundleState] and returns it alongside trie updates. 
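Downstream, the pairing lets callers compute the root and collect the trie updates in a single calculator walk; a usage sketch (the provider binding is illustrative):

```rust,ignore
let provider = factory.provider_rw()?;
let (state_root, trie_updates) =
    bundle_state.state_root_slow_with_updates(provider.tx_ref())?;
```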
+ /// See [Self::state_root_slow] for more info. + pub fn state_root_slow_with_updates( + &self, + tx: &TX, + ) -> Result<(B256, TrieUpdates), StateRootError> { + let hashed_post_state = self.hash_state_slow(); + self.state_root_calculator(tx, &hashed_post_state).root_with_updates() + } /// Transforms a block number to the index of the block. @@ -362,7 +381,7 @@ impl BundleStateWithReceipts { #[cfg(test)] mod tests { use super::*; - use crate::{AccountReader, BundleStateWithReceipts, ProviderFactory}; + use crate::{test_utils::create_test_provider_factory, AccountReader, BundleStateWithReceipts}; use reth_db::{ cursor::{DbCursorRO, DbDupCursorRO}, database::Database, @@ -372,7 +391,7 @@ mod tests { transaction::DbTx, }; use reth_primitives::{ - revm::compat::into_reth_acc, Address, Receipt, Receipts, StorageEntry, B256, MAINNET, U256, + revm::compat::into_reth_acc, Address, Receipt, Receipts, StorageEntry, B256, U256, }; use reth_trie::test_utils::state_root; use revm::{ @@ -394,8 +413,7 @@ mod tests { #[test] fn write_to_db_account_info() { - let db = create_test_rw_db(); - let factory = ProviderFactory::new(db, MAINNET.clone()); + let factory = create_test_provider_factory(); let provider = factory.provider_rw().unwrap(); let address_a = Address::ZERO; @@ -533,8 +551,7 @@ mod tests { #[test] fn write_to_db_storage() { - let db = create_test_rw_db(); - let factory = ProviderFactory::new(db, MAINNET.clone()); + let factory = create_test_provider_factory(); let provider = factory.provider_rw().unwrap(); let address_a = Address::ZERO; @@ -722,8 +739,7 @@ mod tests { #[test] fn write_to_db_multiple_selfdestructs() { - let db = create_test_rw_db(); - let factory = ProviderFactory::new(db, MAINNET.clone()); + let factory = create_test_provider_factory(); let provider = factory.provider_rw().unwrap(); let address1 = Address::random(); @@ -1031,8 +1047,7 @@ mod tests { #[test] fn storage_change_after_selfdestruct_within_block() { - let db = create_test_rw_db(); - let factory = ProviderFactory::new(db, MAINNET.clone()); + let factory = create_test_provider_factory(); let provider = factory.provider_rw().unwrap(); let address1 = Address::random(); diff --git a/crates/storage/provider/src/bundle_state/state_changes.rs b/crates/storage/provider/src/bundle_state/state_changes.rs index 765fc0ee20c8..a62606dedebc 100644 --- a/crates/storage/provider/src/bundle_state/state_changes.rs +++ b/crates/storage/provider/src/bundle_state/state_changes.rs @@ -19,7 +19,7 @@ impl From for StateChanges { } impl StateChanges { - /// Write the post state to the database. + /// Write the bundle state to the database. pub fn write_to_db(mut self, tx: &TX) -> Result<(), DatabaseError> { // sort all entries so they can be written to the database in a more performant way, // and to take a smaller memory footprint. @@ -28,28 +28,28 @@ impl StateChanges { self.0.contracts.par_sort_by_key(|a| a.0); // Write new account state - tracing::trace!(target: "provider::post_state", len = self.0.accounts.len(), "Writing new account state"); + tracing::trace!(target: "provider::bundle_state", len = self.0.accounts.len(), "Writing new account state"); let mut accounts_cursor = tx.cursor_write::()?; // write account to database.
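The pairing above is worth spelling out: `state_root_slow` and the new `state_root_slow_with_updates` share `state_root_calculator`, and the latter additionally returns the `TrieUpdates` collected while computing the root. A hedged usage sketch (assumes `tx` implements `DbTx` and `bundle` is a `BundleStateWithReceipts`; both names are hypothetical):

```rust
// Sketch: both entry points hash the bundle state, then drive the same
// StateRoot calculator; only the second one collects trie updates.
let root = bundle.state_root_slow(tx)?;
let (root_with_updates, trie_updates) = bundle.state_root_slow_with_updates(tx)?;
assert_eq!(root, root_with_updates);
// Persisting `trie_updates` to the database is left to the caller.
```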
for (address, account) in self.0.accounts.into_iter() { if let Some(account) = account { - tracing::trace!(target: "provider::post_state", ?address, "Updating plain state account"); + tracing::trace!(target: "provider::bundle_state", ?address, "Updating plain state account"); accounts_cursor.upsert(address, into_reth_acc(account))?; } else if accounts_cursor.seek_exact(address)?.is_some() { - tracing::trace!(target: "provider::post_state", ?address, "Deleting plain state account"); + tracing::trace!(target: "provider::bundle_state", ?address, "Deleting plain state account"); accounts_cursor.delete_current()?; } } // Write bytecode - tracing::trace!(target: "provider::post_state", len = self.0.contracts.len(), "Writing bytecodes"); + tracing::trace!(target: "provider::bundle_state", len = self.0.contracts.len(), "Writing bytecodes"); let mut bytecodes_cursor = tx.cursor_write::()?; for (hash, bytecode) in self.0.contracts.into_iter() { bytecodes_cursor.upsert(hash, Bytecode(bytecode))?; } // Write new storage state and wipe storage if needed. - tracing::trace!(target: "provider::post_state", len = self.0.storage.len(), "Writing new storage state"); + tracing::trace!(target: "provider::bundle_state", len = self.0.storage.len(), "Writing new storage state"); let mut storages_cursor = tx.cursor_dup_write::()?; for PlainStorageChangeset { address, wipe_storage, storage } in self.0.storage.into_iter() { // Wiping of storage. @@ -65,7 +65,7 @@ impl StateChanges { storage.par_sort_unstable_by_key(|a| a.key); for entry in storage.into_iter() { - tracing::trace!(target: "provider::post_state", ?address, ?entry.key, "Updating plain state storage"); + tracing::trace!(target: "provider::bundle_state", ?address, ?entry.key, "Updating plain state storage"); if let Some(db_entry) = storages_cursor.seek_by_key_subkey(address, entry.key)? { if db_entry.key == entry.key { storages_cursor.delete_current()?; diff --git a/crates/storage/provider/src/chain.rs b/crates/storage/provider/src/chain.rs index 53bc31a941ce..3a6ae3fffc9f 100644 --- a/crates/storage/provider/src/chain.rs +++ b/crates/storage/provider/src/chain.rs @@ -16,16 +16,29 @@ use std::{borrow::Cow, collections::BTreeMap, fmt}; /// Used inside the BlockchainTree. #[derive(Clone, Debug, Default, PartialEq, Eq)] pub struct Chain { + /// All blocks in this chain. + blocks: BTreeMap, /// The state of all accounts after execution of the _all_ blocks in this chain's range from /// [Chain::first] to [Chain::tip], inclusive. /// /// This state also contains the individual changes that lead to the current state. - pub state: BundleStateWithReceipts, - /// All blocks in this chain. - pub blocks: BTreeMap, + state: BundleStateWithReceipts, } impl Chain { + /// Create new Chain from blocks and state. + pub fn new( + blocks: impl IntoIterator, + state: BundleStateWithReceipts, + ) -> Self { + Self { blocks: BTreeMap::from_iter(blocks.into_iter().map(|b| (b.number, b))), state } + } + + /// Create new Chain from a single block and its state. + pub fn from_block(block: SealedBlockWithSenders, state: BundleStateWithReceipts) -> Self { + Self::new([block], state) + } + /// Get the blocks in this chain. pub fn blocks(&self) -> &BTreeMap { &self.blocks @@ -58,9 +71,12 @@ impl Chain { /// Returns the block with matching hash. 
pub fn block(&self, block_hash: BlockHash) -> Option<&SealedBlock> { - self.blocks - .iter() - .find_map(|(_num, block)| (block.hash() == block_hash).then_some(&block.block)) + self.block_with_senders(block_hash).map(|block| &block.block) + } + + /// Returns the block with matching hash. + pub fn block_with_senders(&self, block_hash: BlockHash) -> Option<&SealedBlockWithSenders> { + self.blocks.iter().find_map(|(_num, block)| (block.hash() == block_hash).then_some(block)) } /// Returns the post state of the block at the `block_number`, or `None` if the block is not known @@ -96,18 +112,6 @@ impl Chain { ForkBlock { number: first.number.saturating_sub(1), hash: first.parent_hash } } - /// Get the block number at which this chain forked. - #[track_caller] - pub fn fork_block_number(&self) -> BlockNumber { - self.first().number.saturating_sub(1) - } - - /// Get the block hash at which this chain forked. - #[track_caller] - pub fn fork_block_hash(&self) -> BlockHash { - self.first().parent_hash - } - /// Get the first block in this chain. #[track_caller] pub fn first(&self) -> &SealedBlockWithSenders { @@ -124,11 +128,6 @@ impl Chain { self.blocks.last_key_value().expect("Chain should have at least one block").1 } - /// Create new chain with given blocks and post state. - pub fn new(blocks: Vec, state: BundleStateWithReceipts) -> Self { - Self { state, blocks: blocks.into_iter().map(|b| (b.number, b)).collect() } - } - /// Returns the length of the chain. pub fn len(&self) -> usize { self.blocks.len() } @@ -160,22 +159,30 @@ impl Chain { receipt_attch } + /// Append a single block with its state to the chain. + /// This method assumes that the block's attachment to the chain has already been validated. + pub fn append_block(&mut self, block: SealedBlockWithSenders, state: BundleStateWithReceipts) { + self.blocks.insert(block.number, block); + self.state.extend(state); + } + /// Merge two chains by appending the given chain into the current one. /// /// The state of accounts for this chain is set to the state of the newest chain.
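Since `blocks` and `state` are private now, `Chain` values are built through the constructors above; a short sketch of the reworked API (see `append_chain` just below for merging whole chains; `block1`, `block2`, and the bundle states are hypothetical values produced by block execution):

```rust
// Sketch: `block1`/`block2` are SealedBlockWithSenders, `state1`/`state2`
// are BundleStateWithReceipts (hypothetical, produced elsewhere).
let mut chain = Chain::from_block(block1, state1);
// The caller must already have validated that `block2` extends the tip.
chain.append_block(block2, state2);
assert_eq!(chain.len(), 2);
// `fork_block` replaces the removed fork_block_number/fork_block_hash pair.
assert_eq!(chain.fork_block().hash, chain.first().parent_hash);
```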
- pub fn append_chain(&mut self, chain: Chain) -> RethResult<()> { + pub fn append_chain(&mut self, other: Chain) -> RethResult<()> { let chain_tip = self.tip(); - if chain_tip.hash != chain.fork_block_hash() { + let other_fork_block = other.fork_block(); + if chain_tip.hash != other_fork_block.hash { return Err(BlockExecutionError::AppendChainDoesntConnect { chain_tip: Box::new(chain_tip.num_hash()), - other_chain_fork: Box::new(chain.fork_block()), + other_chain_fork: Box::new(other_fork_block), } .into()) } // Insert blocks from other chain - self.blocks.extend(chain.blocks); - self.state.extend(chain.state); + self.blocks.extend(other.blocks); + self.state.extend(other.state); Ok(()) } diff --git a/crates/storage/provider/src/lib.rs b/crates/storage/provider/src/lib.rs index 87118a6351c1..194c60d500c5 100644 --- a/crates/storage/provider/src/lib.rs +++ b/crates/storage/provider/src/lib.rs @@ -21,11 +21,11 @@ pub use traits::{ BlockWriter, BlockchainTreePendingStateProvider, BundleStateDataProvider, CanonChainTracker, CanonStateNotification, CanonStateNotificationSender, CanonStateNotifications, CanonStateSubscriptions, ChainSpecProvider, ChangeSetReader, EvmEnvProvider, ExecutorFactory, - HashingWriter, HeaderProvider, HistoryWriter, PrunableBlockExecutor, PruneCheckpointReader, - PruneCheckpointWriter, ReceiptProvider, ReceiptProviderIdExt, StageCheckpointReader, - StageCheckpointWriter, StateProvider, StateProviderBox, StateProviderFactory, - StateRootProvider, StorageReader, TransactionVariant, TransactionsProvider, - TransactionsProviderExt, WithdrawalsProvider, + HashingWriter, HeaderProvider, HeaderSyncGap, HeaderSyncGapProvider, HeaderSyncMode, + HistoryWriter, PrunableBlockExecutor, PruneCheckpointReader, PruneCheckpointWriter, + ReceiptProvider, ReceiptProviderIdExt, StageCheckpointReader, StageCheckpointWriter, + StateProvider, StateProviderBox, StateProviderFactory, StateRootProvider, StorageReader, + TransactionVariant, TransactionsProvider, TransactionsProviderExt, WithdrawalsProvider, }; /// Provider trait implementations. diff --git a/crates/storage/provider/src/providers/bundle_state_provider.rs b/crates/storage/provider/src/providers/bundle_state_provider.rs index 46d9ae702ffd..f2b0e5fdf468 100644 --- a/crates/storage/provider/src/providers/bundle_state_provider.rs +++ b/crates/storage/provider/src/providers/bundle_state_provider.rs @@ -4,6 +4,7 @@ use crate::{ }; use reth_interfaces::provider::{ProviderError, ProviderResult}; use reth_primitives::{trie::AccountProof, Account, Address, BlockNumber, Bytecode, B256}; +use reth_trie::updates::TrieUpdates; /// A state provider that either resolves to data in a wrapped [`crate::BundleStateWithReceipts`], /// or an underlying state provider. @@ -11,14 +12,14 @@ use reth_primitives::{trie::AccountProof, Account, Address, BlockNumber, Bytecod pub struct BundleStateProvider { /// The inner state provider. 
pub(crate) state_provider: SP, - /// Post state data, - pub(crate) post_state_data_provider: BSDP, + /// Bundle state data, + pub(crate) bundle_state_data_provider: BSDP, } impl BundleStateProvider { - /// Create new post-state provider - pub fn new(state_provider: SP, post_state_data_provider: BSDP) -> Self { - Self { state_provider, post_state_data_provider } + /// Create new bundle state provider + pub fn new(state_provider: SP, bundle_state_data_provider: BSDP) -> Self { + Self { state_provider, bundle_state_data_provider } } } @@ -28,7 +29,7 @@ impl BlockHashReader for BundleStateProvider { fn block_hash(&self, block_number: BlockNumber) -> ProviderResult> { - let block_hash = self.post_state_data_provider.block_hash(block_number); + let block_hash = self.bundle_state_data_provider.block_hash(block_number); if block_hash.is_some() { return Ok(block_hash) } @@ -48,7 +49,7 @@ impl AccountReader for BundleStateProvider { fn basic_account(&self, address: Address) -> ProviderResult> { - if let Some(account) = self.post_state_data_provider.state().account(&address) { + if let Some(account) = self.bundle_state_data_provider.state().account(&address) { Ok(account) } else { self.state_provider.basic_account(address) @@ -59,11 +60,20 @@ impl AccountReader impl StateRootProvider for BundleStateProvider { - fn state_root(&self, post_state: &BundleStateWithReceipts) -> ProviderResult { - let mut state = self.post_state_data_provider.state().clone(); - state.extend(post_state.clone()); + fn state_root(&self, bundle_state: &BundleStateWithReceipts) -> ProviderResult { + let mut state = self.bundle_state_data_provider.state().clone(); + state.extend(bundle_state.clone()); self.state_provider.state_root(&state) } + + fn state_root_with_updates( + &self, + bundle_state: &BundleStateWithReceipts, + ) -> ProviderResult<(B256, TrieUpdates)> { + let mut state = self.bundle_state_data_provider.state().clone(); + state.extend(bundle_state.clone()); + self.state_provider.state_root_with_updates(&state) + } } impl StateProvider @@ -76,7 +86,7 @@ impl StateProvider ) -> ProviderResult> { let u256_storage_key = storage_key.into(); if let Some(value) = - self.post_state_data_provider.state().storage(&account, u256_storage_key) + self.bundle_state_data_provider.state().storage(&account, u256_storage_key) { return Ok(Some(value)) } @@ -85,7 +95,7 @@ impl StateProvider } fn bytecode_by_hash(&self, code_hash: B256) -> ProviderResult> { - if let Some(bytecode) = self.post_state_data_provider.state().bytecode(&code_hash) { + if let Some(bytecode) = self.bundle_state_data_provider.state().bytecode(&code_hash) { return Ok(Some(bytecode)) } diff --git a/crates/storage/provider/src/providers/database/mod.rs b/crates/storage/provider/src/providers/database/mod.rs index 38b4be901d27..0d1ca70ab465 100644 --- a/crates/storage/provider/src/providers/database/mod.rs +++ b/crates/storage/provider/src/providers/database/mod.rs @@ -5,8 +5,9 @@ use crate::{ }, traits::{BlockSource, ReceiptProvider}, BlockHashReader, BlockNumReader, BlockReader, ChainSpecProvider, EvmEnvProvider, - HeaderProvider, ProviderError, PruneCheckpointReader, StageCheckpointReader, StateProviderBox, - TransactionVariant, TransactionsProvider, WithdrawalsProvider, + HeaderProvider, HeaderSyncGap, HeaderSyncGapProvider, HeaderSyncMode, ProviderError, + PruneCheckpointReader, StageCheckpointReader, StateProviderBox, TransactionVariant, + TransactionsProvider, WithdrawalsProvider, }; use reth_db::{database::Database, init_db, models::StoredBlockBodyIndices, 
DatabaseEnv}; use reth_interfaces::{db::LogLevel, provider::ProviderResult, RethError, RethResult}; @@ -14,9 +15,9 @@ use reth_primitives::{ snapshot::HighestSnapshots, stage::{StageCheckpoint, StageId}, Address, Block, BlockHash, BlockHashOrNumber, BlockNumber, BlockWithSenders, ChainInfo, - ChainSpec, Header, PruneCheckpoint, PruneSegment, Receipt, SealedBlock, SealedHeader, - TransactionMeta, TransactionSigned, TransactionSignedNoHash, TxHash, TxNumber, Withdrawal, - B256, U256, + ChainSpec, Header, PruneCheckpoint, PruneSegment, Receipt, SealedBlock, SealedBlockWithSenders, + SealedHeader, TransactionMeta, TransactionSigned, TransactionSignedNoHash, TxHash, TxNumber, + Withdrawal, B256, U256, }; use revm::primitives::{BlockEnv, CfgEnv}; use std::{ @@ -49,7 +50,7 @@ impl ProviderFactory { /// Returns a provider with a created `DbTx` inside, which allows fetching data from the /// database using different types of providers. Example: [`HeaderProvider`] /// [`BlockHashReader`]. This may fail if the inner read database transaction fails to open. - pub fn provider(&self) -> ProviderResult> { + pub fn provider(&self) -> ProviderResult> { let mut provider = DatabaseProvider::new(self.db.tx()?, self.chain_spec.clone()); if let Some(snapshot_provider) = &self.snapshot_provider { @@ -63,7 +64,7 @@ impl ProviderFactory { /// data from the database using different types of providers. Example: [`HeaderProvider`] /// [`BlockHashReader`]. This may fail if the inner read/write database transaction fails to /// open. - pub fn provider_rw(&self) -> ProviderResult> { + pub fn provider_rw(&self) -> ProviderResult> { let mut provider = DatabaseProvider::new_rw(self.db.tx_mut()?, self.chain_spec.clone()); if let Some(snapshot_provider) = &self.snapshot_provider { @@ -122,7 +123,7 @@ impl Clone for ProviderFactory { impl ProviderFactory { /// Storage provider for latest block - pub fn latest(&self) -> ProviderResult> { + pub fn latest(&self) -> ProviderResult { trace!(target: "providers::db", "Returning latest state provider"); Ok(Box::new(LatestStateProvider::new(self.db.tx()?))) } @@ -131,7 +132,7 @@ impl ProviderFactory { fn state_provider_by_block_number( &self, mut block_number: BlockNumber, - ) -> ProviderResult> { + ) -> ProviderResult { let provider = self.provider()?; if block_number == provider.best_block_number().unwrap_or_default() && @@ -174,17 +175,14 @@ impl ProviderFactory { pub fn history_by_block_number( &self, block_number: BlockNumber, - ) -> ProviderResult> { + ) -> ProviderResult { let state_provider = self.state_provider_by_block_number(block_number)?; trace!(target: "providers::db", ?block_number, "Returning historical state provider for block number"); Ok(state_provider) } /// Storage provider for state at that given block hash - pub fn history_by_block_hash( - &self, - block_hash: BlockHash, - ) -> ProviderResult> { + pub fn history_by_block_hash(&self, block_hash: BlockHash) -> ProviderResult { let block_number = self .provider()? .block_number(block_hash)? 
@@ -196,6 +194,16 @@ impl ProviderFactory { } } +impl HeaderSyncGapProvider for ProviderFactory { + fn sync_gap( + &self, + mode: HeaderSyncMode, + highest_uninterrupted_block: BlockNumber, + ) -> RethResult { + self.provider()?.sync_gap(mode, highest_uninterrupted_block) + } +} + impl HeaderProvider for ProviderFactory { fn header(&self, block_hash: &BlockHash) -> ProviderResult> { self.provider()?.header(block_hash) @@ -282,6 +290,10 @@ impl BlockReader for ProviderFactory { self.provider()?.pending_block() } + fn pending_block_with_senders(&self) -> ProviderResult> { + self.provider()?.pending_block_with_senders() + } + fn pending_block_and_receipts(&self) -> ProviderResult)>> { self.provider()?.pending_block_and_receipts() } @@ -477,33 +489,37 @@ impl PruneCheckpointReader for ProviderFactory { #[cfg(test)] mod tests { use super::ProviderFactory; - use crate::{BlockHashReader, BlockNumReader, BlockWriter, TransactionsProvider}; + use crate::{ + test_utils::create_test_provider_factory, BlockHashReader, BlockNumReader, BlockWriter, + HeaderSyncGapProvider, HeaderSyncMode, TransactionsProvider, + }; use alloy_rlp::Decodable; use assert_matches::assert_matches; - use reth_db::{ - tables, - test_utils::{create_test_rw_db, ERROR_TEMPDIR}, - DatabaseEnv, + use rand::Rng; + use reth_db::{tables, test_utils::ERROR_TEMPDIR, transaction::DbTxMut, DatabaseEnv}; + use reth_interfaces::{ + provider::ProviderError, + test_utils::{ + generators, + generators::{random_block, random_header}, + }, + RethError, }; - use reth_interfaces::test_utils::{generators, generators::random_block}; use reth_primitives::{ hex_literal::hex, ChainSpecBuilder, PruneMode, PruneModes, SealedBlock, TxNumber, B256, }; use std::{ops::RangeInclusive, sync::Arc}; + use tokio::sync::watch; #[test] fn common_history_provider() { - let chain_spec = ChainSpecBuilder::mainnet().build(); - let db = create_test_rw_db(); - let provider = ProviderFactory::new(db, Arc::new(chain_spec)); - let _ = provider.latest(); + let factory = create_test_provider_factory(); + let _ = factory.latest(); } #[test] fn default_chain_info() { - let chain_spec = ChainSpecBuilder::mainnet().build(); - let db = create_test_rw_db(); - let factory = ProviderFactory::new(db, Arc::new(chain_spec)); + let factory = create_test_provider_factory(); let provider = factory.provider().unwrap(); let chain_info = provider.chain_info().expect("should be ok"); @@ -513,9 +529,7 @@ mod tests { #[test] fn provider_flow() { - let chain_spec = ChainSpecBuilder::mainnet().build(); - let db = create_test_rw_db(); - let factory = ProviderFactory::new(db, Arc::new(chain_spec)); + let factory = create_test_provider_factory(); let provider = factory.provider().unwrap(); provider.block_hash(0).unwrap(); let provider_rw = factory.provider_rw().unwrap(); @@ -542,9 +556,7 @@ mod tests { #[test] fn insert_block_with_prune_modes() { - let chain_spec = ChainSpecBuilder::mainnet().build(); - let db = create_test_rw_db(); - let factory = ProviderFactory::new(db, Arc::new(chain_spec)); + let factory = create_test_provider_factory(); let mut block_rlp = 
hex!("f9025ff901f7a0c86e8cc0310ae7c531c758678ddbfd16fc51c8cef8cec650b032de9869e8b94fa01dcc4de8dec75d7aab85b567b6ccd41ad312451b948a7413f0a142fd40d49347942adc25665018aa1fe0e6bc666dac8fc2697ff9baa050554882fbbda2c2fd93fdc466db9946ea262a67f7a76cc169e714f105ab583da00967f09ef1dfed20c0eacfaa94d5cd4002eda3242ac47eae68972d07b106d192a0e3c8b47fbfc94667ef4cceb17e5cc21e3b1eebd442cebb27f07562b33836290db90100000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000008302000001830f42408238108203e800a00000000000000000000000000000000000000000000000000000000000000000880000000000000000f862f860800a83061a8094095e7baea6a6c7c4c2dfeb977efac326af552d8780801ba072ed817487b84ba367d15d2f039b5fc5f087d0a8882fbdf73e8cb49357e1ce30a0403d800545b8fc544f92ce8124e2255f8c3c6af93f28243a120585d4c4c6a2a3c0").as_slice(); let block = SealedBlock::decode(&mut block_rlp).unwrap(); @@ -580,9 +592,7 @@ mod tests { #[test] fn get_take_block_transaction_range_recover_senders() { - let chain_spec = ChainSpecBuilder::mainnet().build(); - let db = create_test_rw_db(); - let factory = ProviderFactory::new(db, Arc::new(chain_spec)); + let factory = create_test_provider_factory(); let mut rng = generators::rng(); let block = random_block(&mut rng, 0, None, Some(3), None); @@ -618,4 +628,71 @@ mod tests { ) } } + + #[test] + fn header_sync_gap_lookup() { + let factory = create_test_provider_factory(); + let provider = factory.provider_rw().unwrap(); + + let mut rng = generators::rng(); + let consensus_tip = rng.gen(); + let (_tip_tx, tip_rx) = watch::channel(consensus_tip); + let mode = HeaderSyncMode::Tip(tip_rx); + + // Genesis + let checkpoint = 0; + let head = random_header(&mut rng, 0, None); + let gap_fill = random_header(&mut rng, 1, Some(head.hash())); + let gap_tip = random_header(&mut rng, 2, Some(gap_fill.hash())); + + // Empty database + assert_matches!( + provider.sync_gap(mode.clone(), checkpoint), + Err(RethError::Provider(ProviderError::HeaderNotFound(block_number))) + if block_number.as_number().unwrap() == checkpoint + ); + + // Checkpoint and no gap + provider + .tx_ref() + .put::(head.number, head.hash()) + .expect("failed to write canonical"); + provider + .tx_ref() + .put::(head.number, head.clone().unseal()) + .expect("failed to write header"); + + let gap = provider.sync_gap(mode.clone(), checkpoint).unwrap(); + assert_eq!(gap.local_head, head); + assert_eq!(gap.target.tip(), consensus_tip.into()); + + // Checkpoint and gap + provider + .tx_ref() + .put::(gap_tip.number, gap_tip.hash()) + .expect("failed to write canonical"); + provider + .tx_ref() + .put::(gap_tip.number, gap_tip.clone().unseal()) + .expect("failed to write header"); + + let gap = provider.sync_gap(mode.clone(), checkpoint).unwrap(); + assert_eq!(gap.local_head, head); + assert_eq!(gap.target.tip(), gap_tip.parent_hash.into()); + + // Checkpoint and gap closed + provider + .tx_ref() + .put::(gap_fill.number, gap_fill.hash()) + .expect("failed to write canonical"); + provider + .tx_ref() + .put::(gap_fill.number, gap_fill.clone().unseal()) + .expect("failed to write header"); + + assert_matches!( + 
provider.sync_gap(mode, checkpoint), + Err(RethError::Provider(ProviderError::InconsistentHeaderGap)) + ); + } } diff --git a/crates/storage/provider/src/providers/database/provider.rs b/crates/storage/provider/src/providers/database/provider.rs index 198aeb5533ac..244011c7b619 100644 --- a/crates/storage/provider/src/providers/database/provider.rs +++ b/crates/storage/provider/src/providers/database/provider.rs @@ -5,16 +5,16 @@ use crate::{ AccountExtReader, BlockSource, ChangeSetReader, ReceiptProvider, StageCheckpointWriter, }, AccountReader, BlockExecutionWriter, BlockHashReader, BlockNumReader, BlockReader, BlockWriter, - Chain, EvmEnvProvider, HashingWriter, HeaderProvider, HistoryWriter, OriginalValuesKnown, - ProviderError, PruneCheckpointReader, PruneCheckpointWriter, StageCheckpointReader, - StorageReader, TransactionVariant, TransactionsProvider, TransactionsProviderExt, - WithdrawalsProvider, + Chain, EvmEnvProvider, HashingWriter, HeaderProvider, HeaderSyncGap, HeaderSyncGapProvider, + HeaderSyncMode, HistoryWriter, OriginalValuesKnown, ProviderError, PruneCheckpointReader, + PruneCheckpointWriter, StageCheckpointReader, StorageReader, TransactionVariant, + TransactionsProvider, TransactionsProviderExt, WithdrawalsProvider, }; use itertools::{izip, Itertools}; use reth_db::{ common::KeyValue, cursor::{DbCursorRO, DbCursorRW, DbDupCursorRO}, - database::{Database, DatabaseGAT}, + database::Database, models::{ sharded_key, storage_sharded_key::StorageShardedKey, AccountBeforeTx, BlockNumberAddress, ShardedKey, StoredBlockBodyIndices, StoredBlockOmmers, StoredBlockWithdrawals, @@ -24,7 +24,11 @@ use reth_db::{ transaction::{DbTx, DbTxMut}, BlockNumberList, DatabaseError, }; -use reth_interfaces::provider::{ProviderResult, RootMismatch}; +use reth_interfaces::{ + p2p::headers::downloader::SyncTarget, + provider::{ProviderResult, RootMismatch}, + RethError, RethResult, +}; use reth_primitives::{ keccak256, revm::{ @@ -51,39 +55,37 @@ use std::{ use tracing::{debug, warn}; /// A [`DatabaseProvider`] that holds a read-only database transaction. -pub type DatabaseProviderRO<'this, DB> = DatabaseProvider<>::TX>; +pub type DatabaseProviderRO = DatabaseProvider<::TX>; /// A [`DatabaseProvider`] that holds a read-write database transaction. /// /// Ideally this would be an alias type. However, there's some weird compiler error (), that forces us to wrap this in a struct instead. /// Once that issue is solved, we can probably revert back to being an alias type. #[derive(Debug)] -pub struct DatabaseProviderRW<'this, DB: Database>( - pub DatabaseProvider<>::TXMut>, -); +pub struct DatabaseProviderRW(pub DatabaseProvider<::TXMut>); -impl<'this, DB: Database> Deref for DatabaseProviderRW<'this, DB> { - type Target = DatabaseProvider<>::TXMut>; +impl Deref for DatabaseProviderRW { + type Target = DatabaseProvider<::TXMut>; fn deref(&self) -> &Self::Target { &self.0 } } -impl DerefMut for DatabaseProviderRW<'_, DB> { +impl DerefMut for DatabaseProviderRW { fn deref_mut(&mut self) -> &mut Self::Target { &mut self.0 } } -impl<'this, DB: Database> DatabaseProviderRW<'this, DB> { +impl DatabaseProviderRW { /// Commit database transaction pub fn commit(self) -> ProviderResult { self.0.commit() } /// Consume `DbTx` or `DbTxMut`. 
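For orientation, a condensed sketch of the lookup that the `header_sync_gap_lookup` test above exercises (assumes a `factory: ProviderFactory<DB>` and a `checkpoint: BlockNumber`, both hypothetical here):

```rust
// Sketch: ask the provider what still needs to be downloaded.
let (_tip_tx, tip_rx) = tokio::sync::watch::channel(B256::ZERO);
let gap = factory.sync_gap(HeaderSyncMode::Tip(tip_rx), checkpoint)?;
if !gap.is_closed() {
    // Download headers in reverse, from `gap.target.tip()` down to
    // `gap.local_head`.
}
```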
- pub fn into_tx(self) -> >::TXMut { + pub fn into_tx(self) -> ::TXMut { self.0.into_tx() } } @@ -868,6 +870,57 @@ impl ChangeSetReader for DatabaseProvider { } } +impl HeaderSyncGapProvider for DatabaseProvider { + fn sync_gap( + &self, + mode: HeaderSyncMode, + highest_uninterrupted_block: BlockNumber, + ) -> RethResult { + // Create a cursor over canonical header hashes + let mut cursor = self.tx.cursor_read::()?; + let mut header_cursor = self.tx.cursor_read::()?; + + // Get head hash and reposition the cursor + let (head_num, head_hash) = cursor + .seek_exact(highest_uninterrupted_block)? + .ok_or_else(|| ProviderError::HeaderNotFound(highest_uninterrupted_block.into()))?; + + // Construct head + let (_, head) = header_cursor + .seek_exact(head_num)? + .ok_or_else(|| ProviderError::HeaderNotFound(head_num.into()))?; + let local_head = head.seal(head_hash); + + // Look up the next header + let next_header = cursor + .next()? + .map(|(next_num, next_hash)| -> Result { + let (_, next) = header_cursor + .seek_exact(next_num)? + .ok_or_else(|| ProviderError::HeaderNotFound(next_num.into()))?; + Ok(next.seal(next_hash)) + }) + .transpose()?; + + // Decide the tip or error out on invalid input. + // If the next element found in the cursor is not the "expected" next block per our current + // checkpoint, then there is a gap in the database and we should start downloading in + // reverse from there. Else, it should use whatever the forkchoice state reports. + let target = match next_header { + Some(header) if highest_uninterrupted_block + 1 != header.number => { + SyncTarget::Gap(header) + } + None => match mode { + HeaderSyncMode::Tip(rx) => SyncTarget::Tip(*rx.borrow()), + HeaderSyncMode::Continuous => SyncTarget::TipNum(head_num + 1), + }, + _ => return Err(ProviderError::InconsistentHeaderGap.into()), + }; + + Ok(HeaderSyncGap { local_head, target }) + } +} + impl HeaderProvider for DatabaseProvider { fn header(&self, block_hash: &BlockHash) -> ProviderResult> { if let Some(num) = self.block_number(*block_hash)? 
{ @@ -1020,6 +1073,10 @@ impl BlockReader for DatabaseProvider { Ok(None) } + fn pending_block_with_senders(&self) -> ProviderResult> { + Ok(None) + } + fn pending_block_and_receipts(&self) -> ProviderResult)>> { Ok(None) } diff --git a/crates/storage/provider/src/providers/mod.rs b/crates/storage/provider/src/providers/mod.rs index 898b5a39c065..11dba5e6a6c4 100644 --- a/crates/storage/provider/src/providers/mod.rs +++ b/crates/storage/provider/src/providers/mod.rs @@ -252,6 +252,10 @@ where Ok(self.tree.pending_block()) } + fn pending_block_with_senders(&self) -> ProviderResult> { + Ok(self.tree.pending_block_with_senders()) + } + fn pending_block_and_receipts(&self) -> ProviderResult)>> { Ok(self.tree.pending_block_and_receipts()) } @@ -508,7 +512,7 @@ where Tree: BlockchainTreePendingStateProvider + BlockchainTreeViewer, { /// Storage provider for latest block - fn latest(&self) -> ProviderResult> { + fn latest(&self) -> ProviderResult { trace!(target: "providers::blockchain", "Getting latest block state provider"); self.database.latest() } @@ -516,18 +520,18 @@ where fn history_by_block_number( &self, block_number: BlockNumber, - ) -> ProviderResult> { + ) -> ProviderResult { trace!(target: "providers::blockchain", ?block_number, "Getting history by block number"); self.ensure_canonical_block(block_number)?; self.database.history_by_block_number(block_number) } - fn history_by_block_hash(&self, block_hash: BlockHash) -> ProviderResult> { + fn history_by_block_hash(&self, block_hash: BlockHash) -> ProviderResult { trace!(target: "providers::blockchain", ?block_hash, "Getting history by block hash"); self.database.history_by_block_hash(block_hash) } - fn state_by_block_hash(&self, block: BlockHash) -> ProviderResult> { + fn state_by_block_hash(&self, block: BlockHash) -> ProviderResult { trace!(target: "providers::blockchain", ?block, "Getting state by block hash"); let mut state = self.history_by_block_hash(block); @@ -546,7 +550,7 @@ where /// /// If there's no pending block available then the latest state provider is returned: /// [Self::latest] - fn pending(&self) -> ProviderResult> { + fn pending(&self) -> ProviderResult { trace!(target: "providers::blockchain", "Getting provider for pending state"); if let Some(block) = self.tree.pending_block_num_hash() { @@ -559,10 +563,7 @@ where self.latest() } - fn pending_state_by_hash( - &self, - block_hash: B256, - ) -> ProviderResult>> { + fn pending_state_by_hash(&self, block_hash: B256) -> ProviderResult> { if let Some(state) = self.tree.find_pending_state_provider(block_hash) { return Ok(Some(self.pending_with_provider(state)?)) } @@ -571,14 +572,14 @@ where fn pending_with_provider( &self, - post_state_data: Box, - ) -> ProviderResult> { - let canonical_fork = post_state_data.canonical_fork(); + bundle_state_data: Box, + ) -> ProviderResult { + let canonical_fork = bundle_state_data.canonical_fork(); trace!(target: "providers::blockchain", ?canonical_fork, "Returning post state provider"); let state_provider = self.history_by_block_hash(canonical_fork.hash)?; - let post_state_provider = BundleStateProvider::new(state_provider, post_state_data); - Ok(Box::new(post_state_provider)) + let bundle_state_provider = BundleStateProvider::new(state_provider, bundle_state_data); + Ok(Box::new(bundle_state_provider)) } } @@ -640,6 +641,10 @@ where self.tree.block_by_hash(block_hash) } + fn block_with_senders_by_hash(&self, block_hash: BlockHash) -> Option { + self.tree.block_with_senders_by_hash(block_hash) + } + fn buffered_block_by_hash(&self, 
block_hash: BlockHash) -> Option { self.tree.buffered_block_by_hash(block_hash) } diff --git a/crates/storage/provider/src/providers/snapshot/mod.rs b/crates/storage/provider/src/providers/snapshot/mod.rs index a5244c78e891..26f180e853de 100644 --- a/crates/storage/provider/src/providers/snapshot/mod.rs +++ b/crates/storage/provider/src/providers/snapshot/mod.rs @@ -41,19 +41,17 @@ impl Deref for LoadedJar { #[cfg(test)] mod test { use super::*; - use crate::{HeaderProvider, ProviderFactory}; + use crate::{test_utils::create_test_provider_factory, HeaderProvider}; use rand::{self, seq::SliceRandom}; use reth_db::{ cursor::DbCursorRO, - database::Database, snapshot::create_snapshot_T1_T2_T3, - test_utils::create_test_rw_db, transaction::{DbTx, DbTxMut}, - CanonicalHeaders, DatabaseError, HeaderNumbers, HeaderTD, Headers, RawTable, + CanonicalHeaders, HeaderNumbers, HeaderTD, Headers, RawTable, }; use reth_interfaces::test_utils::generators::{self, random_header_range}; use reth_nippy_jar::NippyJar; - use reth_primitives::{BlockNumber, B256, MAINNET, U256}; + use reth_primitives::{BlockNumber, B256, U256}; #[test] fn test_snap() { @@ -64,8 +62,7 @@ mod test { SegmentHeader::new(range.clone(), range.clone(), SnapshotSegment::Headers); // Data sources - let db = create_test_rw_db(); - let factory = ProviderFactory::new(&db, MAINNET.clone()); + let factory = create_test_provider_factory(); let snap_path = tempfile::tempdir().unwrap(); let snap_file = snap_path.path().join(SnapshotSegment::Headers.filename(&range, &range)); @@ -76,21 +73,19 @@ mod test { B256::random(), ); - db.update(|tx| -> Result<(), DatabaseError> { - let mut td = U256::ZERO; - for header in headers.clone() { - td += header.header.difficulty; - let hash = header.hash(); - - tx.put::(header.number, hash)?; - tx.put::(header.number, header.clone().unseal())?; - tx.put::(header.number, td.into())?; - tx.put::(hash, header.number)?; - } - Ok(()) - }) - .unwrap() - .unwrap(); + let mut provider_rw = factory.provider_rw().unwrap(); + let tx = provider_rw.tx_mut(); + let mut td = U256::ZERO; + for header in headers.clone() { + td += header.header.difficulty; + let hash = header.hash(); + + tx.put::(header.number, hash).unwrap(); + tx.put::(header.number, header.clone().unseal()).unwrap(); + tx.put::(header.number, td.into()).unwrap(); + tx.put::(hash, header.number).unwrap(); + } + provider_rw.commit().unwrap(); // Create Snapshot { @@ -107,7 +102,8 @@ mod test { nippy_jar = nippy_jar.with_cuckoo_filter(row_count as usize + 10).with_fmph(); } - let tx = db.tx().unwrap(); + let provider = factory.provider().unwrap(); + let tx = provider.tx_ref(); // Hacky type inference. 
TODO fix let mut none_vec = Some(vec![vec![vec![0u8]].into_iter()]); @@ -127,7 +123,7 @@ mod test { BlockNumber, SegmentHeader, >( - &tx, range, None, none_vec, Some(hashes), row_count as usize, &mut nippy_jar + tx, range, None, none_vec, Some(hashes), row_count as usize, &mut nippy_jar ) .unwrap(); } diff --git a/crates/storage/provider/src/providers/state/historical.rs b/crates/storage/provider/src/providers/state/historical.rs index c76ea75d53e9..7d7fabe2018e 100644 --- a/crates/storage/provider/src/providers/state/historical.rs +++ b/crates/storage/provider/src/providers/state/historical.rs @@ -14,6 +14,7 @@ use reth_interfaces::provider::ProviderResult; use reth_primitives::{ trie::AccountProof, Account, Address, BlockNumber, Bytecode, StorageKey, StorageValue, B256, }; +use reth_trie::updates::TrieUpdates; /// State provider for a given block number which takes a tx reference. /// @@ -198,7 +199,14 @@ impl<'b, TX: DbTx> BlockHashReader for HistoricalStateProviderRef<'b, TX> { } impl<'b, TX: DbTx> StateRootProvider for HistoricalStateProviderRef<'b, TX> { - fn state_root(&self, _post_state: &BundleStateWithReceipts) -> ProviderResult { + fn state_root(&self, _bundle_state: &BundleStateWithReceipts) -> ProviderResult { + Err(ProviderError::StateRootNotAvailableForHistoricalBlock) + } + + fn state_root_with_updates( + &self, + _bundle_state: &BundleStateWithReceipts, + ) -> ProviderResult<(B256, TrieUpdates)> { Err(ProviderError::StateRootNotAvailableForHistoricalBlock) } } @@ -217,10 +225,10 @@ impl<'b, TX: DbTx> StateProvider for HistoricalStateProviderRef<'b, TX> { .cursor_dup_read::()? .seek_by_key_subkey((changeset_block_number, address).into(), storage_key)? .filter(|entry| entry.key == storage_key) - .ok_or(ProviderError::StorageChangesetNotFound { + .ok_or_else(|| ProviderError::StorageChangesetNotFound { block_number: changeset_block_number, address, - storage_key, + storage_key: Box::new(storage_key), })? .value, )), diff --git a/crates/storage/provider/src/providers/state/latest.rs b/crates/storage/provider/src/providers/state/latest.rs index 1b45555fc218..df515f78e202 100644 --- a/crates/storage/provider/src/providers/state/latest.rs +++ b/crates/storage/provider/src/providers/state/latest.rs @@ -12,6 +12,7 @@ use reth_primitives::{ keccak256, trie::AccountProof, Account, Address, BlockNumber, Bytecode, StorageKey, StorageValue, B256, }; +use reth_trie::updates::TrieUpdates; /// State provider over latest state that takes tx reference. #[derive(Debug)] @@ -62,6 +63,15 @@ impl<'b, TX: DbTx> StateRootProvider for LatestStateProviderRef<'b, TX> { fn state_root(&self, bundle_state: &BundleStateWithReceipts) -> ProviderResult { bundle_state.state_root_slow(self.db).map_err(|err| ProviderError::Database(err.into())) } + + fn state_root_with_updates( + &self, + bundle_state: &BundleStateWithReceipts, + ) -> ProviderResult<(B256, TrieUpdates)> { + bundle_state + .state_root_slow_with_updates(self.db) + .map_err(|err| ProviderError::Database(err.into())) + } } impl<'b, TX: DbTx> StateProvider for LatestStateProviderRef<'b, TX> { diff --git a/crates/storage/provider/src/providers/state/macros.rs b/crates/storage/provider/src/providers/state/macros.rs index 67b3c33f50af..300b2c2ec4f5 100644 --- a/crates/storage/provider/src/providers/state/macros.rs +++ b/crates/storage/provider/src/providers/state/macros.rs @@ -32,6 +32,7 @@ macro_rules! delegate_provider_impls { for $target => StateRootProvider $(where [$($generics)*])? 
{ fn state_root(&self, state: &crate::BundleStateWithReceipts) -> reth_interfaces::provider::ProviderResult; + fn state_root_with_updates(&self, state: &crate::BundleStateWithReceipts) -> reth_interfaces::provider::ProviderResult<(reth_primitives::B256, reth_trie::updates::TrieUpdates)>; } AccountReader $(where [$($generics)*])? { fn basic_account(&self, address: reth_primitives::Address) -> reth_interfaces::provider::ProviderResult>; diff --git a/crates/storage/provider/src/test_utils/blocks.rs b/crates/storage/provider/src/test_utils/blocks.rs index 30ee18f6c95e..3162266bc781 100644 --- a/crates/storage/provider/src/test_utils/blocks.rs +++ b/crates/storage/provider/src/test_utils/blocks.rs @@ -10,7 +10,7 @@ use reth_primitives::{ use std::collections::HashMap; /// Assert genesis block -pub fn assert_genesis_block(provider: &DatabaseProviderRW<'_, DB>, g: SealedBlock) { +pub fn assert_genesis_block(provider: &DatabaseProviderRW, g: SealedBlock) { let n = g.number; let h = B256::ZERO; let tx = provider; diff --git a/crates/storage/provider/src/test_utils/mock.rs b/crates/storage/provider/src/test_utils/mock.rs index c4689ac57234..8a1cb6ca379e 100644 --- a/crates/storage/provider/src/test_utils/mock.rs +++ b/crates/storage/provider/src/test_utils/mock.rs @@ -12,9 +12,10 @@ use reth_interfaces::provider::{ProviderError, ProviderResult}; use reth_primitives::{ keccak256, trie::AccountProof, Account, Address, Block, BlockHash, BlockHashOrNumber, BlockId, BlockNumber, BlockWithSenders, Bytecode, Bytes, ChainInfo, ChainSpec, Header, Receipt, - SealedBlock, SealedHeader, StorageKey, StorageValue, TransactionMeta, TransactionSigned, - TransactionSignedNoHash, TxHash, TxNumber, B256, U256, + SealedBlock, SealedBlockWithSenders, SealedHeader, StorageKey, StorageValue, TransactionMeta, + TransactionSigned, TransactionSignedNoHash, TxHash, TxNumber, B256, U256, }; +use reth_trie::updates::TrieUpdates; use revm::primitives::{BlockEnv, CfgEnv}; use std::{ collections::{BTreeMap, HashMap}, @@ -437,6 +438,10 @@ impl BlockReader for MockEthProvider { Ok(None) } + fn pending_block_with_senders(&self) -> ProviderResult> { + Ok(None) + } + fn pending_block_and_receipts(&self) -> ProviderResult)>> { Ok(None) } @@ -496,7 +501,14 @@ impl AccountReader for MockEthProvider { } impl StateRootProvider for MockEthProvider { - fn state_root(&self, _state: &BundleStateWithReceipts) -> ProviderResult { + fn state_root(&self, _bundle_state: &BundleStateWithReceipts) -> ProviderResult { + todo!() + } + + fn state_root_with_updates( + &self, + _bundle_state: &BundleStateWithReceipts, + ) -> ProviderResult<(B256, TrieUpdates)> { todo!() } } @@ -573,73 +585,67 @@ impl EvmEnvProvider for MockEthProvider { } impl StateProviderFactory for MockEthProvider { - fn latest(&self) -> ProviderResult> { + fn latest(&self) -> ProviderResult { Ok(Box::new(self.clone())) } - fn history_by_block_number(&self, _block: BlockNumber) -> ProviderResult> { + fn history_by_block_number(&self, _block: BlockNumber) -> ProviderResult { Ok(Box::new(self.clone())) } - fn history_by_block_hash(&self, _block: BlockHash) -> ProviderResult> { + fn history_by_block_hash(&self, _block: BlockHash) -> ProviderResult { Ok(Box::new(self.clone())) } - fn state_by_block_hash(&self, _block: BlockHash) -> ProviderResult> { + fn state_by_block_hash(&self, _block: BlockHash) -> ProviderResult { Ok(Box::new(self.clone())) } - fn pending(&self) -> ProviderResult> { + fn pending(&self) -> ProviderResult { Ok(Box::new(self.clone())) } - fn pending_state_by_hash( - 
&self, - _block_hash: B256, - ) -> ProviderResult>> { + fn pending_state_by_hash(&self, _block_hash: B256) -> ProviderResult> { Ok(Some(Box::new(self.clone()))) } fn pending_with_provider<'a>( &'a self, - _post_state_data: Box, - ) -> ProviderResult> { + _bundle_state_data: Box, + ) -> ProviderResult { Ok(Box::new(self.clone())) } } impl StateProviderFactory for Arc { - fn latest(&self) -> ProviderResult> { + fn latest(&self) -> ProviderResult { Ok(Box::new(self.clone())) } - fn history_by_block_number(&self, _block: BlockNumber) -> ProviderResult> { + fn history_by_block_number(&self, _block: BlockNumber) -> ProviderResult { Ok(Box::new(self.clone())) } - fn history_by_block_hash(&self, _block: BlockHash) -> ProviderResult> { + fn history_by_block_hash(&self, _block: BlockHash) -> ProviderResult { Ok(Box::new(self.clone())) } - fn state_by_block_hash(&self, _block: BlockHash) -> ProviderResult> { + fn state_by_block_hash(&self, _block: BlockHash) -> ProviderResult { Ok(Box::new(self.clone())) } - fn pending(&self) -> ProviderResult> { + fn pending(&self) -> ProviderResult { Ok(Box::new(self.clone())) } - fn pending_state_by_hash( - &self, - _block_hash: B256, - ) -> ProviderResult>> { + fn pending_state_by_hash(&self, _block_hash: B256) -> ProviderResult> { Ok(Some(Box::new(self.clone()))) } fn pending_with_provider<'a>( &'a self, - _post_state_data: Box, - ) -> ProviderResult> { + _bundle_state_data: Box, + ) -> ProviderResult { Ok(Box::new(self.clone())) } } diff --git a/crates/storage/provider/src/test_utils/mod.rs b/crates/storage/provider/src/test_utils/mod.rs index bbbe973908af..0da47c47940b 100644 --- a/crates/storage/provider/src/test_utils/mod.rs +++ b/crates/storage/provider/src/test_utils/mod.rs @@ -1,3 +1,11 @@ +use crate::ProviderFactory; +use reth_db::{ + test_utils::{create_test_rw_db, TempDatabase}, + DatabaseEnv, +}; +use reth_primitives::{ChainSpec, MAINNET}; +use std::sync::Arc; + pub mod blocks; mod events; mod executor; @@ -8,3 +16,16 @@ pub use events::TestCanonStateSubscriptions; pub use executor::{TestExecutor, TestExecutorFactory}; pub use mock::{ExtendedAccount, MockEthProvider}; pub use noop::NoopProvider; + +/// Creates test provider factory with mainnet chain spec. +pub fn create_test_provider_factory() -> ProviderFactory>> { + create_test_provider_factory_with_chain_spec(MAINNET.clone()) +} + +/// Creates test provider factory with provided chain spec. 
+pub fn create_test_provider_factory_with_chain_spec( + chain_spec: Arc, +) -> ProviderFactory>> { + let db = create_test_rw_db(); + ProviderFactory::new(db, chain_spec) +} diff --git a/crates/storage/provider/src/test_utils/noop.rs b/crates/storage/provider/src/test_utils/noop.rs index 45258bc69263..dc36ac948fc3 100644 --- a/crates/storage/provider/src/test_utils/noop.rs +++ b/crates/storage/provider/src/test_utils/noop.rs @@ -14,9 +14,10 @@ use reth_primitives::{ trie::AccountProof, Account, Address, Block, BlockHash, BlockHashOrNumber, BlockId, BlockNumber, Bytecode, ChainInfo, ChainSpec, Header, PruneCheckpoint, PruneSegment, Receipt, SealedBlock, - SealedHeader, StorageKey, StorageValue, TransactionMeta, TransactionSigned, - TransactionSignedNoHash, TxHash, TxNumber, B256, MAINNET, U256, + SealedBlockWithSenders, SealedHeader, StorageKey, StorageValue, TransactionMeta, + TransactionSigned, TransactionSignedNoHash, TxHash, TxNumber, B256, MAINNET, U256, }; +use reth_trie::updates::TrieUpdates; use revm::primitives::{BlockEnv, CfgEnv}; use std::{ ops::{RangeBounds, RangeInclusive}, @@ -84,6 +85,10 @@ impl BlockReader for NoopProvider { Ok(None) } + fn pending_block_with_senders(&self) -> ProviderResult> { + Ok(None) + } + fn pending_block_and_receipts(&self) -> ProviderResult)>> { Ok(None) } @@ -274,6 +279,13 @@ impl StateRootProvider for NoopProvider { fn state_root(&self, _state: &BundleStateWithReceipts) -> ProviderResult { Ok(B256::default()) } + + fn state_root_with_updates( + &self, + _bundle_state: &BundleStateWithReceipts, + ) -> ProviderResult<(B256, TrieUpdates)> { + Ok((B256::default(), TrieUpdates::default())) + } } impl StateProvider for NoopProvider { @@ -339,37 +351,34 @@ impl EvmEnvProvider for NoopProvider { } impl StateProviderFactory for NoopProvider { - fn latest(&self) -> ProviderResult> { + fn latest(&self) -> ProviderResult { Ok(Box::new(*self)) } - fn history_by_block_number(&self, _block: BlockNumber) -> ProviderResult> { + fn history_by_block_number(&self, _block: BlockNumber) -> ProviderResult { Ok(Box::new(*self)) } - fn history_by_block_hash(&self, _block: BlockHash) -> ProviderResult> { + fn history_by_block_hash(&self, _block: BlockHash) -> ProviderResult { Ok(Box::new(*self)) } - fn state_by_block_hash(&self, _block: BlockHash) -> ProviderResult> { + fn state_by_block_hash(&self, _block: BlockHash) -> ProviderResult { Ok(Box::new(*self)) } - fn pending(&self) -> ProviderResult> { + fn pending(&self) -> ProviderResult { Ok(Box::new(*self)) } - fn pending_state_by_hash( - &self, - _block_hash: B256, - ) -> ProviderResult>> { + fn pending_state_by_hash(&self, _block_hash: B256) -> ProviderResult> { Ok(Some(Box::new(*self))) } fn pending_with_provider<'a>( &'a self, - _post_state_data: Box, - ) -> ProviderResult> { + _bundle_state_data: Box, + ) -> ProviderResult { Ok(Box::new(*self)) } } diff --git a/crates/storage/provider/src/traits/block.rs b/crates/storage/provider/src/traits/block.rs index 76690580620b..44951a3fcac8 100644 --- a/crates/storage/provider/src/traits/block.rs +++ b/crates/storage/provider/src/traits/block.rs @@ -85,6 +85,12 @@ pub trait BlockReader: /// and the caller does not know the hash. fn pending_block(&self) -> ProviderResult>; + /// Returns the pending block if available + /// + /// Note: This returns a [SealedBlockWithSenders] because it's expected that this is sealed by + /// the provider and the caller does not know the hash. 
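The two test-utils constructors above exist to replace the recurring `create_test_rw_db` plus `ProviderFactory::new` boilerplate seen throughout the updated tests; a sketch of the intended shape of a test (the assertion mirrors the `default_chain_info` test earlier in this diff):

```rust
// Sketch: temp database plus mainnet chain spec in one call.
use reth_provider::test_utils::create_test_provider_factory;

#[test]
fn uses_test_factory() {
    let factory = create_test_provider_factory();
    let provider = factory.provider().unwrap();
    let chain_info = provider.chain_info().expect("should be ok");
    assert_eq!(chain_info.best_number, 0);
}
```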
+ fn pending_block_with_senders(&self) -> ProviderResult>; + /// Returns the pending block and receipts if available. fn pending_block_and_receipts(&self) -> ProviderResult)>>; diff --git a/crates/storage/provider/src/traits/header_sync_gap.rs b/crates/storage/provider/src/traits/header_sync_gap.rs new file mode 100644 index 000000000000..576a26a9e8c7 --- /dev/null +++ b/crates/storage/provider/src/traits/header_sync_gap.rs @@ -0,0 +1,50 @@ +use auto_impl::auto_impl; +use reth_interfaces::{p2p::headers::downloader::SyncTarget, RethResult}; +use reth_primitives::{BlockHashOrNumber, BlockNumber, SealedHeader, B256}; +use tokio::sync::watch; + +/// The header sync mode. +#[derive(Clone, Debug)] +pub enum HeaderSyncMode { + /// A sync mode in which the stage continuously requests the next blocks from the + /// downloader. + Continuous, + /// A sync mode in which the stage polls the receiver for the next tip + /// to download from. + Tip(watch::Receiver), +} + +/// Represents a gap to sync: from `local_head` to `target`. +#[derive(Clone, Debug)] +pub struct HeaderSyncGap { + /// The local head block. Represents the lower bound of the sync range. + pub local_head: SealedHeader, + + /// The sync target. Represents the upper bound of the sync range. + pub target: SyncTarget, +} + +impl HeaderSyncGap { + /// Returns `true` if the gap from the head to the target has been closed. + #[inline] + pub fn is_closed(&self) -> bool { + match self.target.tip() { + BlockHashOrNumber::Hash(hash) => self.local_head.hash() == hash, + BlockHashOrNumber::Number(num) => self.local_head.number == num, + } + } +} + +/// Client trait for determining the current headers sync gap. +#[auto_impl(&, Arc)] +pub trait HeaderSyncGapProvider: Send + Sync { + /// Finds the current sync gap for the headers, depending on the [HeaderSyncMode] and the last + /// uninterrupted block number. The last uninterrupted block is the block number before + /// which there are no gaps. It's up to the caller to ensure that the last uninterrupted block + /// is determined correctly. + fn sync_gap( + &self, + mode: HeaderSyncMode, + highest_uninterrupted_block: BlockNumber, + ) -> RethResult; +} diff --git a/crates/storage/provider/src/traits/mod.rs b/crates/storage/provider/src/traits/mod.rs index 8134a19613af..64f806f5f2b2 100644 --- a/crates/storage/provider/src/traits/mod.rs +++ b/crates/storage/provider/src/traits/mod.rs @@ -27,6 +27,9 @@ pub use chain_info::CanonChainTracker; mod header; pub use header::HeaderProvider; +mod header_sync_gap; +pub use header_sync_gap::{HeaderSyncGap, HeaderSyncGapProvider, HeaderSyncMode}; + mod receipts; pub use receipts::{ReceiptProvider, ReceiptProviderIdExt}; diff --git a/crates/storage/provider/src/traits/state.rs b/crates/storage/provider/src/traits/state.rs index 617550441177..8cb985f359d8 100644 --- a/crates/storage/provider/src/traits/state.rs +++ b/crates/storage/provider/src/traits/state.rs @@ -6,9 +6,10 @@ use reth_primitives::{ trie::AccountProof, Address, BlockHash, BlockId, BlockNumHash, BlockNumber, BlockNumberOrTag, Bytecode, StorageKey, StorageValue, B256, KECCAK_EMPTY, U256, }; +use reth_trie::updates::TrieUpdates; /// Type alias for a boxed [StateProvider]. -pub type StateProviderBox<'a> = Box; +pub type StateProviderBox = Box; /// An abstraction for a type that provides state data. #[auto_impl(&, Arc, Box)] @@ -99,13 +100,13 @@ pub trait StateProvider: BlockHashReader + AccountReader + StateRootProvider + S /// to be used, since block `n` was executed on its parent block's state.
pub trait StateProviderFactory: BlockIdReader + Send + Sync { /// Storage provider for latest block. - fn latest(&self) -> ProviderResult>; + fn latest(&self) -> ProviderResult; /// Returns a [StateProvider] indexed by the given [BlockId]. /// /// Note: if a number or hash is provided this will __only__ look at historical (canonical) /// state. - fn state_by_block_id(&self, block_id: BlockId) -> ProviderResult> { + fn state_by_block_id(&self, block_id: BlockId) -> ProviderResult { match block_id { BlockId::Number(block_number) => self.state_by_block_number_or_tag(block_number), BlockId::Hash(block_hash) => self.history_by_block_hash(block_hash.into()), @@ -118,7 +119,7 @@ pub trait StateProviderFactory: BlockIdReader + Send + Sync { fn state_by_block_number_or_tag( &self, number_or_tag: BlockNumberOrTag, - ) -> ProviderResult> { + ) -> ProviderResult { match number_or_tag { BlockNumberOrTag::Latest => self.latest(), BlockNumberOrTag::Finalized => { @@ -152,40 +153,37 @@ pub trait StateProviderFactory: BlockIdReader + Send + Sync { /// /// /// Note: this only looks at historical blocks, not pending blocks. - fn history_by_block_number(&self, block: BlockNumber) -> ProviderResult>; + fn history_by_block_number(&self, block: BlockNumber) -> ProviderResult; /// Returns a historical [StateProvider] indexed by the given block hash. /// /// Note: this only looks at historical blocks, not pending blocks. - fn history_by_block_hash(&self, block: BlockHash) -> ProviderResult>; + fn history_by_block_hash(&self, block: BlockHash) -> ProviderResult; /// Returns _any_ [StateProvider] with matching block hash. /// /// This will return a [StateProvider] for either a historical or pending block. - fn state_by_block_hash(&self, block: BlockHash) -> ProviderResult>; + fn state_by_block_hash(&self, block: BlockHash) -> ProviderResult; /// Storage provider for pending state. /// /// Represents the state at the block that extends the canonical chain by one. /// If there's no `pending` block, then this is equal to [StateProviderFactory::latest]. - fn pending(&self) -> ProviderResult>; + fn pending(&self) -> ProviderResult; /// Storage provider for pending state for the given block hash. /// /// Represents the state at the block that extends the canonical chain. /// /// If the block couldn't be found, returns `None`. - fn pending_state_by_hash( - &self, - block_hash: B256, - ) -> ProviderResult>>; + fn pending_state_by_hash(&self, block_hash: B256) -> ProviderResult>; - /// Return a [StateProvider] that contains post state data provider. + /// Returns a [StateProvider] that contains the given bundle state data provider. /// Used to inspect or execute a transaction on the pending state. fn pending_with_provider( &self, - post_state_data: Box, - ) -> ProviderResult>; + bundle_state_data: Box, - ) -> ProviderResult; } /// Blockchain trait provider that gives access to the blockchain state that is not yet committed @@ -232,6 +230,17 @@ pub trait BundleStateDataProvider: Send + Sync { /// A type that can compute the state root of a given post state. #[auto_impl[Box,&, Arc]] pub trait StateRootProvider: Send + Sync { - /// Returns the state root of the BundleState on top of the current state. - fn state_root(&self, post_state: &BundleStateWithReceipts) -> ProviderResult; + /// Returns the state root of the `BundleState` on top of the current state. + /// + /// NOTE: It is recommended to provide a different implementation from + /// `state_root_with_updates` since it affects the memory usage during state root + /// computation.
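Because `StateProviderBox` is no longer parameterized by a lifetime, factories can hand out owned boxed providers that may be stored or moved freely. A hedged sketch (any `factory: impl StateProviderFactory`):

```rust
// Sketch: the boxed provider is 'static now, not tied to the factory borrow.
let state: StateProviderBox = factory.latest()?;
let _account = state.basic_account(Address::ZERO)?;
```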
+ fn state_root(&self, bundle_state: &BundleStateWithReceipts) -> ProviderResult; + + /// Returns the state root of the BundleState on top of the current state with trie + /// updates to be committed to the database. + fn state_root_with_updates( + &self, + bundle_state: &BundleStateWithReceipts, + ) -> ProviderResult<(B256, TrieUpdates)>; } diff --git a/crates/tasks/src/lib.rs b/crates/tasks/src/lib.rs index c3f1195f1f29..02fe964de1d2 100644 --- a/crates/tasks/src/lib.rs +++ b/crates/tasks/src/lib.rs @@ -11,7 +11,7 @@ use crate::{ metrics::{IncCounterOnDrop, TaskExecutorMetrics}, - shutdown::{signal, Shutdown, Signal}, + shutdown::{signal, GracefulShutdown, GracefulShutdownGuard, Shutdown, Signal}, }; use dyn_clone::DynClone; use futures_util::{ @@ -22,6 +22,10 @@ use std::{ any::Any, fmt::{Display, Formatter}, pin::Pin, + sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, + }, task::{ready, Context, Poll}, }; use tokio::{ @@ -29,7 +33,7 @@ use tokio::{ sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender}, task::JoinHandle, }; -use tracing::error; +use tracing::{debug, error}; use tracing_futures::Instrument; pub mod metrics; @@ -147,10 +151,12 @@ pub struct TaskManager { panicked_tasks_rx: UnboundedReceiver, /// The [Signal] to fire when all tasks should be shut down. /// - /// This is fired on drop. - _signal: Signal, + /// This is fired when dropped. + signal: Option, /// Receiver of the shutdown signal. on_shutdown: Shutdown, + /// How many [GracefulShutdown] tasks are currently active + graceful_tasks: Arc, } // === impl TaskManager === @@ -159,8 +165,15 @@ impl TaskManager { /// Create a new instance connected to the given handle's tokio runtime. pub fn new(handle: Handle) -> Self { let (panicked_tasks_tx, panicked_tasks_rx) = unbounded_channel(); - let (_signal, on_shutdown) = signal(); - Self { handle, panicked_tasks_tx, panicked_tasks_rx, _signal, on_shutdown } + let (signal, on_shutdown) = signal(); + Self { + handle, + panicked_tasks_tx, + panicked_tasks_rx, + signal: Some(signal), + on_shutdown, + graceful_tasks: Arc::new(AtomicUsize::new(0)), + } } /// Returns a new [`TaskExecutor`] that can spawn new tasks onto the tokio runtime this type is @@ -171,8 +184,36 @@ impl TaskManager { on_shutdown: self.on_shutdown.clone(), panicked_tasks_tx: self.panicked_tasks_tx.clone(), metrics: Default::default(), + graceful_tasks: Arc::clone(&self.graceful_tasks), } } + + /// Fires the shutdown signal and awaits until all tasks are shut down. + pub fn graceful_shutdown(self) { + let _ = self.do_graceful_shutdown(None); + } + + /// Fires the shutdown signal and awaits until all tasks are shut down. + /// + /// Returns true if all tasks were shut down before the timeout elapsed. + pub fn graceful_shutdown_with_timeout(self, timeout: std::time::Duration) -> bool { + self.do_graceful_shutdown(Some(timeout)) + } + + fn do_graceful_shutdown(self, timeout: Option) -> bool { + drop(self.signal); + let when = timeout.map(|t| std::time::Instant::now() + t); + while self.graceful_tasks.load(Ordering::Relaxed) > 0 { + if when.map(|when| std::time::Instant::now() > when).unwrap_or(false) { + debug!("graceful shutdown timed out"); + return false + } + std::hint::spin_loop(); + } + + debug!("gracefully shut down"); + true + } } /// An endless future that resolves if a critical task panicked.
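End to end, the new graceful-shutdown pieces compose as follows; a sketch that mirrors the tests further down (the task body and the 5-second budget are hypothetical):

```rust
// Sketch: runtime and manager wiring per this diff's API.
let runtime = tokio::runtime::Runtime::new().unwrap();
let manager = TaskManager::new(runtime.handle().clone());
let executor = manager.executor();

executor.spawn_critical_with_graceful_shutdown_signal("indexer", |shutdown| async move {
    let guard = shutdown.await; // resolves once the shutdown signal fires
    // ... flush buffers, finish in-flight work ...
    drop(guard); // dropping the guard tells the manager this task is done
});

// Fires the signal, then spins until every guard is dropped (or 5s elapse).
let clean = manager.graceful_shutdown_with_timeout(std::time::Duration::from_secs(5));
assert!(clean);
```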
@@ -232,6 +273,8 @@ pub struct TaskExecutor {
     panicked_tasks_tx: UnboundedSender<PanickedTaskError>,
     // Task Executor Metrics
     metrics: TaskExecutorMetrics,
+    /// How many [GracefulShutdown] tasks are currently active.
+    graceful_tasks: Arc<AtomicUsize>,
 }
 
 // === impl TaskExecutor ===
 
@@ -382,7 +425,7 @@ impl TaskExecutor {
     /// This spawns a critical task onto the runtime.
     ///
     /// If this task panics, the [`TaskManager`] is notified.
-    pub fn spawn_critical_with_signal<F>(
+    pub fn spawn_critical_with_shutdown_signal<F>(
         &self,
         name: &'static str,
         f: impl FnOnce(Shutdown) -> F,
@@ -407,6 +450,55 @@ impl TaskExecutor {
 
         self.handle.spawn(task)
     }
+
+    /// This spawns a critical task onto the runtime.
+    ///
+    /// If this task panics, the [TaskManager] is notified.
+    /// The [TaskManager] will wait until the given future has completed before shutting down.
+    ///
+    /// # Example
+    ///
+    /// ```no_run
+    /// # async fn t(executor: reth_tasks::TaskExecutor) {
+    ///
+    /// executor.spawn_critical_with_graceful_shutdown_signal("grace", |shutdown| async move {
+    ///     // await the shutdown signal
+    ///     let guard = shutdown.await;
+    ///     // do work before exiting the program
+    ///     tokio::time::sleep(std::time::Duration::from_secs(1)).await;
+    ///     // allow graceful shutdown
+    ///     drop(guard);
+    /// });
+    /// # }
+    /// ```
+    pub fn spawn_critical_with_graceful_shutdown_signal<F>(
+        &self,
+        name: &'static str,
+        f: impl FnOnce(GracefulShutdown) -> F,
+    ) -> JoinHandle<()>
+    where
+        F: Future<Output = ()> + Send + 'static,
+    {
+        let panicked_tasks_tx = self.panicked_tasks_tx.clone();
+        let on_shutdown = GracefulShutdown::new(
+            self.on_shutdown.clone(),
+            GracefulShutdownGuard::new(Arc::clone(&self.graceful_tasks)),
+        );
+        let fut = f(on_shutdown);
+
+        // wrap the task in catch unwind
+        let task = std::panic::AssertUnwindSafe(fut)
+            .catch_unwind()
+            .map_err(move |error| {
+                let task_error = PanickedTaskError::new(name, error);
+                error!("{task_error}");
+                let _ = panicked_tasks_tx.send(task_error);
+            })
+            .map(|_| ())
+            .in_current_span();
+
+        self.handle.spawn(task)
+    }
 }
 
 impl TaskSpawner for TaskExecutor {
@@ -444,7 +536,7 @@ enum TaskKind {
 #[cfg(test)]
 mod tests {
     use super::*;
-    use std::time::Duration;
+    use std::{sync::atomic::AtomicBool, time::Duration};
 
     #[test]
     fn test_cloneable() {
@@ -521,4 +613,70 @@ mod tests {
 
         handle.block_on(shutdown);
     }
+
+    #[test]
+    fn test_manager_graceful_shutdown() {
+        let runtime = tokio::runtime::Runtime::new().unwrap();
+        let handle = runtime.handle().clone();
+        let manager = TaskManager::new(handle.clone());
+        let executor = manager.executor();
+
+        let val = Arc::new(AtomicBool::new(false));
+        let c = val.clone();
+        executor.spawn_critical_with_graceful_shutdown_signal("grace", |shutdown| async move {
+            let _guard = shutdown.await;
+            tokio::time::sleep(Duration::from_millis(200)).await;
+            c.store(true, Ordering::Relaxed);
+        });
+
+        manager.graceful_shutdown();
+        assert!(val.load(Ordering::Relaxed));
+    }
+
+    #[test]
+    fn test_manager_graceful_shutdown_many() {
+        let runtime = tokio::runtime::Runtime::new().unwrap();
+        let handle = runtime.handle().clone();
+        let manager = TaskManager::new(handle.clone());
+        let executor = manager.executor();
+        let _e = executor.clone();
+
+        let counter = Arc::new(AtomicUsize::new(0));
+        let num = 10;
+        for _ in 0..num {
+            let c = counter.clone();
+            executor.spawn_critical_with_graceful_shutdown_signal(
+                "grace",
+                move |shutdown| async move {
+                    let _guard = shutdown.await;
+                    tokio::time::sleep(Duration::from_millis(200)).await;
+                    c.fetch_add(1, Ordering::SeqCst);
+                },
+            );
+        }
+
+        manager.graceful_shutdown();
+        assert_eq!(counter.load(Ordering::Relaxed), num);
+    }
+
+    #[test]
+    fn test_manager_graceful_shutdown_timeout() {
+        let runtime = tokio::runtime::Runtime::new().unwrap();
+        let handle = runtime.handle().clone();
+        let manager = TaskManager::new(handle.clone());
+        let executor = manager.executor();
+
+        let timeout = Duration::from_millis(500);
+        let val = Arc::new(AtomicBool::new(false));
+        let val2 = val.clone();
+        executor.spawn_critical_with_graceful_shutdown_signal("grace", |shutdown| async move {
+            let _guard = shutdown.await;
+            tokio::time::sleep(timeout * 3).await;
+            val2.store(true, Ordering::Relaxed);
+            unreachable!("should not be reached");
+        });
+
+        manager.graceful_shutdown_with_timeout(timeout);
+        assert!(!val.load(Ordering::Relaxed));
+    }
 }
diff --git a/crates/tasks/src/shutdown.rs b/crates/tasks/src/shutdown.rs
index 6264841ae242..5cc012d8ec74 100644
--- a/crates/tasks/src/shutdown.rs
+++ b/crates/tasks/src/shutdown.rs
@@ -7,10 +7,63 @@ use futures_util::{
 use std::{
     future::Future,
     pin::Pin,
-    task::{Context, Poll},
+    sync::{atomic::AtomicUsize, Arc},
+    task::{ready, Context, Poll},
 };
 use tokio::sync::oneshot;
 
+/// A Future that resolves when the shutdown event has been fired.
+///
+/// The [TaskManager](crate::TaskManager) waits for the [GracefulShutdownGuard] returned by this
+/// future to be dropped before completing its shutdown.
+#[derive(Debug)]
+pub struct GracefulShutdown {
+    shutdown: Shutdown,
+    guard: Option<GracefulShutdownGuard>,
+}
+
+impl GracefulShutdown {
+    pub(crate) fn new(shutdown: Shutdown, guard: GracefulShutdownGuard) -> Self {
+        Self { shutdown, guard: Some(guard) }
+    }
+}
+
+impl Future for GracefulShutdown {
+    type Output = GracefulShutdownGuard;
+
+    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
+        ready!(self.shutdown.poll_unpin(cx));
+        Poll::Ready(self.get_mut().guard.take().expect("Future polled after completion"))
+    }
+}
+
+impl Clone for GracefulShutdown {
+    fn clone(&self) -> Self {
+        Self {
+            shutdown: self.shutdown.clone(),
+            guard: self.guard.as_ref().map(|g| GracefulShutdownGuard::new(Arc::clone(&g.0))),
+        }
+    }
+}
+
+/// A guard that fires once dropped to signal the [TaskManager](crate::TaskManager) that the
+/// [GracefulShutdown] has completed.
+#[derive(Debug)]
+#[must_use = "if unused the task will not be gracefully shutdown"]
+pub struct GracefulShutdownGuard(Arc<AtomicUsize>);
+
+impl GracefulShutdownGuard {
+    pub(crate) fn new(counter: Arc<AtomicUsize>) -> Self {
+        counter.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
+        Self(counter)
+    }
+}
+
+impl Drop for GracefulShutdownGuard {
+    fn drop(&mut self) {
+        self.0.fetch_sub(1, std::sync::atomic::Ordering::SeqCst);
+    }
+}
+
 /// A Future that resolves when the shutdown event has been fired.
 #[derive(Debug, Clone)]
 pub struct Shutdown(Shared<oneshot::Receiver<()>>);
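Editor's note: the guard mechanics above are easy to miss in diff form, so here is the idea in miniature as a standalone, runnable sketch (`Guard` is a stand-in for `GracefulShutdownGuard`, not code from this PR):

```rust
use std::sync::{
    atomic::{AtomicUsize, Ordering},
    Arc,
};

// Stand-in for GracefulShutdownGuard: creation bumps a shared counter and
// Drop decrements it, so a waiter can poll the counter to learn when every
// guarded task has finished its cleanup.
struct Guard(Arc<AtomicUsize>);

impl Guard {
    fn new(counter: Arc<AtomicUsize>) -> Self {
        counter.fetch_add(1, Ordering::SeqCst);
        Self(counter)
    }
}

impl Drop for Guard {
    fn drop(&mut self) {
        self.0.fetch_sub(1, Ordering::SeqCst);
    }
}

fn main() {
    let active = Arc::new(AtomicUsize::new(0));
    let guard = Guard::new(Arc::clone(&active));
    assert_eq!(active.load(Ordering::SeqCst), 1); // one task still cleaning up
    drop(guard); // cleanup finished
    assert_eq!(active.load(Ordering::SeqCst), 0); // shutdown may now complete
}
```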
diff --git a/crates/transaction-pool/src/lib.rs b/crates/transaction-pool/src/lib.rs
index de60ab7c5fff..84c9e15ecf9d 100644
--- a/crates/transaction-pool/src/lib.rs
+++ b/crates/transaction-pool/src/lib.rs
@@ -473,6 +473,13 @@ where
         self.pool.get_transactions_by_sender(sender)
     }
 
+    fn get_transactions_by_origin(
+        &self,
+        origin: TransactionOrigin,
+    ) -> Vec<Arc<ValidPoolTransaction<Self::Transaction>>> {
+        self.pool.get_transactions_by_origin(origin)
+    }
+
     fn unique_senders(&self) -> HashSet<Address> {
         self.pool.unique_senders()
     }
diff --git a/crates/transaction-pool/src/noop.rs b/crates/transaction-pool/src/noop.rs
index 6e6be8f2b8f4..a1cadb5f6597 100644
--- a/crates/transaction-pool/src/noop.rs
+++ b/crates/transaction-pool/src/noop.rs
@@ -216,6 +216,13 @@ impl TransactionPool for NoopTransactionPool {
         }
         Err(BlobStoreError::MissingSidecar(tx_hashes[0]))
     }
+
+    fn get_transactions_by_origin(
+        &self,
+        _origin: TransactionOrigin,
+    ) -> Vec<Arc<ValidPoolTransaction<Self::Transaction>>> {
+        vec![]
+    }
 }
 
 /// A [`TransactionValidator`] that does nothing.
diff --git a/crates/transaction-pool/src/pool/mod.rs b/crates/transaction-pool/src/pool/mod.rs
index 7242099c1588..cc60955470b4 100644
--- a/crates/transaction-pool/src/pool/mod.rs
+++ b/crates/transaction-pool/src/pool/mod.rs
@@ -695,6 +695,14 @@ where
         self.pool.read().get_transactions_by_sender(sender_id)
     }
 
+    /// Returns all transactions that were submitted with the given [TransactionOrigin]
+    pub(crate) fn get_transactions_by_origin(
+        &self,
+        origin: TransactionOrigin,
+    ) -> Vec<Arc<ValidPoolTransaction<T::Transaction>>> {
+        self.pool.read().all().transactions_iter().filter(|tx| tx.origin == origin).collect()
+    }
+
     /// Returns all the transactions belonging to the hashes.
     ///
     /// If no transaction exists, it is skipped.
diff --git a/crates/transaction-pool/src/traits.rs b/crates/transaction-pool/src/traits.rs
index 91455df3d866..3888b1078666 100644
--- a/crates/transaction-pool/src/traits.rs
+++ b/crates/transaction-pool/src/traits.rs
@@ -306,6 +306,27 @@ pub trait TransactionPool: Send + Sync + Clone {
         sender: Address,
     ) -> Vec<Arc<ValidPoolTransaction<Self::Transaction>>>;
 
+    /// Returns all transactions that were submitted with the given [TransactionOrigin]
+    fn get_transactions_by_origin(
+        &self,
+        origin: TransactionOrigin,
+    ) -> Vec<Arc<ValidPoolTransaction<Self::Transaction>>>;
+
+    /// Returns all transactions that were submitted as [TransactionOrigin::Local]
+    fn get_local_transactions(&self) -> Vec<Arc<ValidPoolTransaction<Self::Transaction>>> {
+        self.get_transactions_by_origin(TransactionOrigin::Local)
+    }
+
+    /// Returns all transactions that were submitted as [TransactionOrigin::Private]
+    fn get_private_transactions(&self) -> Vec<Arc<ValidPoolTransaction<Self::Transaction>>> {
+        self.get_transactions_by_origin(TransactionOrigin::Private)
+    }
+
+    /// Returns all transactions that were submitted as [TransactionOrigin::External]
+    fn get_external_transactions(&self) -> Vec<Arc<ValidPoolTransaction<Self::Transaction>>> {
+        self.get_transactions_by_origin(TransactionOrigin::External)
+    }
+
     /// Returns a set of all senders of transactions in the pool
     fn unique_senders(&self) -> HashSet<Address>;
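Editor's note: since the three convenience getters are defaulted wrappers over `get_transactions_by_origin`, call sites stay trivial. A hedged sketch of consuming the new API (the function and its use of `len()` are illustrative, not from this diff):

```rust
use reth_transaction_pool::TransactionPool;

// Count pooled transactions per origin; works with any TransactionPool
// implementation, including the NoopTransactionPool above (which simply
// returns empty vectors).
fn pool_breakdown<P: TransactionPool>(pool: &P) -> (usize, usize, usize) {
    (
        pool.get_local_transactions().len(),
        pool.get_private_transactions().len(),
        pool.get_external_transactions().len(),
    )
}
```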
diff --git a/crates/trie/Cargo.toml b/crates/trie/Cargo.toml
index 1aaf2be53be9..43b87026c83f 100644
--- a/crates/trie/Cargo.toml
+++ b/crates/trie/Cargo.toml
@@ -34,7 +34,7 @@ triehash = { version = "0.8", optional = true }
 # reth
 reth-primitives = { workspace = true, features = ["test-utils", "arbitrary"] }
 reth-db = { workspace = true, features = ["test-utils"] }
-reth-provider.workspace = true
+reth-provider = { workspace = true, features = ["test-utils"] }
 
 # trie
 triehash = "0.8"
diff --git a/crates/trie/src/hashed_cursor/default.rs b/crates/trie/src/hashed_cursor/default.rs
index 5641c289280a..d49feedd1849 100644
--- a/crates/trie/src/hashed_cursor/default.rs
+++ b/crates/trie/src/hashed_cursor/default.rs
@@ -2,13 +2,13 @@ use super::{HashedAccountCursor, HashedCursorFactory, HashedStorageCursor};
 use reth_db::{
     cursor::{DbCursorRO, DbDupCursorRO},
     tables,
-    transaction::{DbTx, DbTxGAT},
+    transaction::DbTx,
 };
 use reth_primitives::{Account, StorageEntry, B256};
 
 impl<'a, TX: DbTx> HashedCursorFactory for &'a TX {
-    type AccountCursor = <TX as DbTxGAT<'a>>::Cursor<tables::HashedAccount>;
-    type StorageCursor = <TX as DbTxGAT<'a>>::DupCursor<tables::HashedStorage>;
+    type AccountCursor = <TX as DbTx>::Cursor<tables::HashedAccount>;
+    type StorageCursor = <TX as DbTx>::DupCursor<tables::HashedStorage>;
 
     fn hashed_account_cursor(&self) -> Result<Self::AccountCursor, reth_db::DatabaseError> {
         self.cursor_read::<tables::HashedAccount>()
diff --git a/crates/trie/src/hashed_cursor/post_state.rs b/crates/trie/src/hashed_cursor/post_state.rs
index b6dff3027b11..62e028657fc0 100644
--- a/crates/trie/src/hashed_cursor/post_state.rs
+++ b/crates/trie/src/hashed_cursor/post_state.rs
@@ -3,7 +3,7 @@ use crate::prefix_set::{PrefixSet, PrefixSetMut};
 use reth_db::{
     cursor::{DbCursorRO, DbDupCursorRO},
     tables,
-    transaction::{DbTx, DbTxGAT},
+    transaction::DbTx,
 };
 use reth_primitives::{trie::Nibbles, Account, StorageEntry, B256, U256};
 use std::collections::{HashMap, HashSet};
@@ -171,9 +171,9 @@ impl<'a, 'b, TX> HashedPostStateCursorFactory<'a, 'b, TX> {
 impl<'a, 'b, TX: DbTx> HashedCursorFactory for HashedPostStateCursorFactory<'a, 'b, TX> {
     type AccountCursor =
-        HashedPostStateAccountCursor<'b, <TX as DbTxGAT<'a>>::Cursor<tables::HashedAccount>>;
+        HashedPostStateAccountCursor<'b, <TX as DbTx>::Cursor<tables::HashedAccount>>;
     type StorageCursor =
-        HashedPostStateStorageCursor<'b, <TX as DbTxGAT<'a>>::DupCursor<tables::HashedStorage>>;
+        HashedPostStateStorageCursor<'b, <TX as DbTx>::DupCursor<tables::HashedStorage>>;
 
     fn hashed_account_cursor(&self) -> Result<Self::AccountCursor, reth_db::DatabaseError> {
         let cursor = self.tx.cursor_read::<tables::HashedAccount>()?;
diff --git a/crates/trie/src/proof.rs b/crates/trie/src/proof.rs
index f37e55e1f43b..eb7c438c4b55 100644
--- a/crates/trie/src/proof.rs
+++ b/crates/trie/src/proof.rs
@@ -166,10 +166,10 @@ mod tests {
     use super::*;
     use crate::StateRoot;
     use once_cell::sync::Lazy;
-    use reth_db::{database::Database, test_utils::create_test_rw_db};
+    use reth_db::database::Database;
     use reth_interfaces::RethResult;
     use reth_primitives::{Account, Bytes, Chain, ChainSpec, StorageEntry, HOLESKY, MAINNET, U256};
-    use reth_provider::{HashingWriter, ProviderFactory};
+    use reth_provider::{test_utils::create_test_provider_factory, HashingWriter, ProviderFactory};
     use std::{str::FromStr, sync::Arc};
 
     /*
@@ -197,8 +197,10 @@ mod tests {
         path.into_iter().map(Bytes::from_str).collect::<Result<Vec<_>, _>>().unwrap()
     }
 
-    fn insert_genesis<DB: Database>(db: DB, chain_spec: Arc<ChainSpec>) -> RethResult<()> {
-        let provider_factory = ProviderFactory::new(db, chain_spec.clone());
+    fn insert_genesis<DB: Database>(
+        provider_factory: &ProviderFactory<DB>,
+        chain_spec: Arc<ChainSpec>,
+    ) -> RethResult<()> {
         let mut provider = provider_factory.provider_rw()?;
 
         // Hash accounts and insert them into hashing table.
@@ -233,10 +235,8 @@ mod tests {
     #[test]
     fn testspec_proofs() {
         // Create test database and insert genesis accounts.
-        let db = create_test_rw_db();
-        insert_genesis(db.clone(), TEST_SPEC.clone()).unwrap();
-
-        let tx = db.tx().unwrap();
+        let factory = create_test_provider_factory();
+        insert_genesis(&factory, TEST_SPEC.clone()).unwrap();
 
         let data = Vec::from([
             (
@@ -277,9 +277,10 @@ mod tests {
             ),
         ]);
 
+        let provider = factory.provider().unwrap();
         for (target, expected_proof) in data {
             let target = Address::from_str(target).unwrap();
-            let account_proof = Proof::new(&tx).account_proof(target, &[]).unwrap();
+            let account_proof = Proof::new(provider.tx_ref()).account_proof(target, &[]).unwrap();
             pretty_assertions::assert_eq!(
                 account_proof.proof,
                 expected_proof,
@@ -291,14 +292,14 @@ mod tests {
     #[test]
     fn testspec_empty_storage_proof() {
         // Create test database and insert genesis accounts.
-        let db = create_test_rw_db();
-        insert_genesis(db.clone(), TEST_SPEC.clone()).unwrap();
-
-        let tx = db.tx().unwrap();
+        let factory = create_test_provider_factory();
+        insert_genesis(&factory, TEST_SPEC.clone()).unwrap();
 
         let target = Address::from_str("0x1ed9b1dd266b607ee278726d324b855a093394a6").unwrap();
         let slots = Vec::from([B256::with_last_byte(1), B256::with_last_byte(3)]);
-        let account_proof = Proof::new(&tx).account_proof(target, &slots).unwrap();
+
+        let provider = factory.provider().unwrap();
+        let account_proof = Proof::new(provider.tx_ref()).account_proof(target, &slots).unwrap();
 
         assert_eq!(account_proof.storage_root, EMPTY_ROOT_HASH, "expected empty storage root");
         assert_eq!(slots.len(), account_proof.storage_proofs.len());
@@ -310,8 +311,8 @@ mod tests {
     #[test]
     fn mainnet_genesis_account_proof() {
         // Create test database and insert genesis accounts.
-        let db = create_test_rw_db();
-        insert_genesis(db.clone(), MAINNET.clone()).unwrap();
+        let factory = create_test_provider_factory();
+        insert_genesis(&factory, MAINNET.clone()).unwrap();
 
         // Address from mainnet genesis allocation.
         // keccak256 - `0xcf67b71c90b0d523dd5004cf206f325748da347685071b34812e21801f5270c4`
@@ -326,16 +327,16 @@ mod tests {
             "0xf8719f20b71c90b0d523dd5004cf206f325748da347685071b34812e21801f5270c4b84ff84d80890ad78ebc5ac6200000a056e81f171bcc55a6ff8345e692c0f86e5b48e01b996cadc001622fb5e363b421a0c5d2460186f7233c927e7db2dcc703c0e500b653ca82273b7bfad8045d85a470"
         ]);
 
-        let tx = db.tx().unwrap();
-        let account_proof = Proof::new(&tx).account_proof(target, &[]).unwrap();
+        let provider = factory.provider().unwrap();
+        let account_proof = Proof::new(provider.tx_ref()).account_proof(target, &[]).unwrap();
         pretty_assertions::assert_eq!(account_proof.proof, expected_account_proof);
     }
 
     #[test]
     fn mainnet_genesis_account_proof_nonexistent() {
         // Create test database and insert genesis accounts.
-        let db = create_test_rw_db();
-        insert_genesis(db.clone(), MAINNET.clone()).unwrap();
+        let factory = create_test_provider_factory();
+        insert_genesis(&factory, MAINNET.clone()).unwrap();
 
         // Address that does not exist in mainnet genesis allocation.
         // keccak256 - `0x18f415ffd7f66bb1924d90f0e82fb79ca8c6d8a3473cd9a95446a443b9db1761`
@@ -348,18 +349,16 @@ mod tests {
             "0xf901d1a0b7c55b381eb205712a2f5d1b7d6309ac725da79ab159cb77dc2783af36e6596da0b3b48aa390e0f3718b486ccc32b01682f92819e652315c1629058cd4d9bb1545a0e3c0cc68af371009f14416c27e17f05f4f696566d2ba45362ce5711d4a01d0e4a0bad1e085e431b510508e2a9e3712633a414b3fe6fd358635ab206021254c1e10a0f8407fe8d5f557b9e012d52e688139bd932fec40d48630d7ff4204d27f8cc68da08c6ca46eff14ad4950e65469c394ca9d6b8690513b1c1a6f91523af00082474c80a0630c034178cb1290d4d906edf28688804d79d5e37a3122c909adab19ac7dc8c5a059f6d047c5d1cc75228c4517a537763cb410c38554f273e5448a53bc3c7166e7a0d842f53ce70c3aad1e616fa6485d3880d15c936fcc306ec14ae35236e5a60549a0218ee2ee673c69b4e1b953194b2568157a69085b86e4f01644fa06ab472c6cf9a016a35a660ea496df7c0da646378bfaa9562f401e42a5c2fe770b7bbe22433585a0dd0fbbe227a4d50868cdbb3107573910fd97131ea8d835bef81d91a2fc30b175a06aafa3d78cf179bf055bd5ec629be0ff8352ce0aec9125a4d75be3ee7eb71f10a01d6817ef9f64fcbb776ff6df0c83138dcd2001bd752727af3e60f4afc123d8d58080"
         ]);
 
-        let tx = db.tx().unwrap();
-        let account_proof = Proof::new(&tx).account_proof(target, &[]).unwrap();
+        let provider = factory.provider().unwrap();
+        let account_proof = Proof::new(provider.tx_ref()).account_proof(target, &[]).unwrap();
         pretty_assertions::assert_eq!(account_proof.proof, expected_account_proof);
     }
 
     #[test]
     fn holesky_deposit_contract_proof() {
         // Create test database and insert genesis accounts.
-        let db = create_test_rw_db();
-        insert_genesis(db.clone(), HOLESKY.clone()).unwrap();
-
-        let tx = db.tx().unwrap();
+        let factory = create_test_provider_factory();
+        insert_genesis(&factory, HOLESKY.clone()).unwrap();
 
         let target = Address::from_str("0x4242424242424242424242424242424242424242").unwrap();
         // existent
@@ -435,7 +434,8 @@ mod tests {
             ])
         };
 
-        let account_proof = Proof::new(&tx).account_proof(target, &slots).unwrap();
+        let provider = factory.provider().unwrap();
+        let account_proof = Proof::new(provider.tx_ref()).account_proof(target, &slots).unwrap();
         pretty_assertions::assert_eq!(account_proof, expected);
     }
 }
diff --git a/crates/trie/src/trie.rs b/crates/trie/src/trie.rs
index 7a7d3a5a0d87..4fdc41a7dd66 100644
--- a/crates/trie/src/trie.rs
+++ b/crates/trie/src/trie.rs
@@ -496,7 +496,7 @@ mod tests {
     use reth_db::{
         cursor::{DbCursorRO, DbCursorRW, DbDupCursorRO},
         tables,
-        test_utils::create_test_rw_db,
+        test_utils::TempDatabase,
         transaction::DbTxMut,
         DatabaseEnv,
     };
@@ -505,10 +505,10 @@ mod tests {
         keccak256,
         proofs::triehash::KeccakHasher,
         trie::{BranchNodeCompact, TrieMask},
-        Account, Address, StorageEntry, B256, MAINNET, U256,
+        Account, Address, StorageEntry, B256, U256,
     };
-    use reth_provider::{DatabaseProviderRW, ProviderFactory};
-    use std::{collections::BTreeMap, ops::Mul, str::FromStr};
+    use reth_provider::{test_utils::create_test_provider_factory, DatabaseProviderRW};
+    use std::{collections::BTreeMap, ops::Mul, str::FromStr, sync::Arc};
 
     fn insert_account(
         tx: &impl DbTxMut,
@@ -532,8 +532,7 @@ mod tests {
     }
 
     fn incremental_vs_full_root(inputs: &[&str], modified: &str) {
-        let db = create_test_rw_db();
-        let factory = ProviderFactory::new(db.as_ref(), MAINNET.clone());
+        let factory = create_test_provider_factory();
         let tx = factory.provider_rw().unwrap();
         let hashed_address = B256::with_last_byte(1);
@@ -598,8 +597,7 @@
         let (address, storage) = item;
 
         let hashed_address = keccak256(address);
-        let db = create_test_rw_db();
-        let factory = ProviderFactory::new(db.as_ref(), MAINNET.clone());
+        let factory = create_test_provider_factory();
         let tx = factory.provider_rw().unwrap();
         for (key, value) in &storage {
             tx.tx_ref().put::<tables::HashedStorage>(
@@ -656,8 +654,7 @@
     #[test]
     // This ensures we return an empty root when there are no storage entries
     fn test_empty_storage_root() {
-        let db = create_test_rw_db();
-        let factory = ProviderFactory::new(db.as_ref(), MAINNET.clone());
+        let factory = create_test_provider_factory();
         let tx = factory.provider_rw().unwrap();
 
         let address = Address::random();
@@ -678,8 +675,7 @@
     #[test]
     // This ensures that the walker goes over all the storage slots
     fn test_storage_root() {
-        let db = create_test_rw_db();
-        let factory = ProviderFactory::new(db.as_ref(), MAINNET.clone());
+        let factory = create_test_provider_factory();
        let tx = factory.provider_rw().unwrap();
 
         let address = Address::random();
@@ -720,8 +716,7 @@
         let hashed_entries_total = state.len() +
             state.values().map(|(_, slots)| slots.len()).sum::<usize>();
 
-        let db = create_test_rw_db();
-        let factory = ProviderFactory::new(db.as_ref(), MAINNET.clone());
+        let factory = create_test_provider_factory();
         let tx = factory.provider_rw().unwrap();
 
         for (address, (account, storage)) in &state {
@@ -759,8 +754,7 @@
     }
 
     fn test_state_root_with_state(state: State) {
-        let db = create_test_rw_db();
-        let factory = ProviderFactory::new(db.as_ref(), MAINNET.clone());
+        let factory = create_test_provider_factory();
         let tx = factory.provider_rw().unwrap();
 
         for (address, (account, storage)) in &state {
@@ -786,8 +780,7 @@
 
     #[test]
     fn storage_root_regression() {
-        let db = create_test_rw_db();
-        let factory = ProviderFactory::new(db.as_ref(), MAINNET.clone());
+        let factory = create_test_provider_factory();
         let tx = factory.provider_rw().unwrap();
         // Some address whose hash starts with 0xB041
         let address3 = Address::from_str("16b07afd1c635f77172e842a000ead9a2a222459").unwrap();
@@ -831,8 +824,7 @@
             .map(|(slot, val)| (B256::from_str(slot).unwrap(), U256::from(val))),
         );
 
-        let db = create_test_rw_db();
-        let factory = ProviderFactory::new(db.as_ref(), MAINNET.clone());
+        let factory = create_test_provider_factory();
         let tx = factory.provider_rw().unwrap();
 
         let mut hashed_account_cursor =
@@ -1138,8 +1130,7 @@
     #[test]
     fn account_trie_around_extension_node() {
-        let db = create_test_rw_db();
-        let factory = ProviderFactory::new(db.db(), MAINNET.clone());
+        let factory = create_test_provider_factory();
         let tx = factory.provider_rw().unwrap();
 
         let expected = extension_node_trie(&tx);
@@ -1164,8 +1155,7 @@
     #[test]
     fn account_trie_around_extension_node_with_dbtrie() {
-        let db = create_test_rw_db();
-        let factory = ProviderFactory::new(db.db(), MAINNET.clone());
+        let factory = create_test_provider_factory();
         let tx = factory.provider_rw().unwrap();
 
         let expected = extension_node_trie(&tx);
@@ -1193,8 +1183,7 @@
     #[test]
     fn fuzz_state_root_incremental(account_changes: [BTreeMap<B256, U256>; 5]) {
         tokio::runtime::Runtime::new().unwrap().block_on(async {
-            let db = create_test_rw_db();
-            let factory = ProviderFactory::new(db.as_ref(), MAINNET.clone());
+            let factory = create_test_provider_factory();
             let tx = factory.provider_rw().unwrap();
             let mut hashed_account_cursor =
                 tx.tx_ref().cursor_write::<tables::HashedAccount>().unwrap();
@@ -1227,8 +1216,7 @@
     #[test]
     fn storage_trie_around_extension_node() {
-        let db = create_test_rw_db();
-        let factory = ProviderFactory::new(db.db(), MAINNET.clone());
+        let factory = create_test_provider_factory();
         let tx = factory.provider_rw().unwrap();
 
         let hashed_address = B256::random();
@@ -1254,7 +1242,7 @@
     }
 
     fn extension_node_storage_trie(
-        tx: &DatabaseProviderRW<'_, &DatabaseEnv>,
+        tx: &DatabaseProviderRW<Arc<TempDatabase<DatabaseEnv>>>,
         hashed_address: B256,
     ) -> (B256, HashMap<Nibbles, BranchNodeCompact>) {
         let value = U256::from(1);
@@ -1282,7 +1270,7 @@
         (root, updates)
     }
 
-    fn extension_node_trie(tx: &DatabaseProviderRW<'_, &DatabaseEnv>) -> B256 {
+    fn extension_node_trie(tx: &DatabaseProviderRW<Arc<TempDatabase<DatabaseEnv>>>) -> B256 {
         let a =
             Account { nonce: 0, balance: U256::from(1u64), bytecode_hash: Some(B256::random()) };
         let val = encode_account(a, None);
diff --git a/crates/trie/src/trie_cursor/account_cursor.rs b/crates/trie/src/trie_cursor/account_cursor.rs
index 815396ab0ac1..0fe241760a06 100644
--- a/crates/trie/src/trie_cursor/account_cursor.rs
+++ b/crates/trie/src/trie_cursor/account_cursor.rs
@@ -46,16 +46,14 @@ mod tests {
     use reth_db::{
         cursor::{DbCursorRO, DbCursorRW},
         tables,
-        test_utils::create_test_rw_db,
         transaction::DbTxMut,
     };
-    use reth_primitives::{hex_literal::hex, MAINNET};
-    use reth_provider::ProviderFactory;
+    use reth_primitives::hex_literal::hex;
+    use reth_provider::test_utils::create_test_provider_factory;
 
     #[test]
     fn test_account_trie_order() {
-        let db = create_test_rw_db();
-        let factory = ProviderFactory::new(db.as_ref(), MAINNET.clone());
+        let factory = create_test_provider_factory();
         let provider = factory.provider_rw().unwrap();
         let mut cursor = provider.tx_ref().cursor_write::<tables::AccountsTrie>().unwrap();
diff --git a/crates/trie/src/trie_cursor/storage_cursor.rs b/crates/trie/src/trie_cursor/storage_cursor.rs
index 032ff97f64fe..19fe1b281914 100644
--- a/crates/trie/src/trie_cursor/storage_cursor.rs
+++ b/crates/trie/src/trie_cursor/storage_cursor.rs
@@ -60,20 +60,14 @@ where
 mod tests {
 
     use super::*;
-    use reth_db::{
-        cursor::DbCursorRW, tables, test_utils::create_test_rw_db, transaction::DbTxMut,
-    };
-    use reth_primitives::{
-        trie::{BranchNodeCompact, StorageTrieEntry},
-        MAINNET,
-    };
-    use reth_provider::ProviderFactory;
+    use reth_db::{cursor::DbCursorRW, tables, transaction::DbTxMut};
+    use reth_primitives::trie::{BranchNodeCompact, StorageTrieEntry};
+    use reth_provider::test_utils::create_test_provider_factory;
 
     // tests that upsert and seek match on the storagetrie cursor
     #[test]
     fn test_storage_cursor_abstraction() {
-        let db = create_test_rw_db();
-        let factory = ProviderFactory::new(db.as_ref(), MAINNET.clone());
+        let factory = create_test_provider_factory();
         let provider = factory.provider_rw().unwrap();
         let mut cursor = provider.tx_ref().cursor_dup_write::<tables::StoragesTrie>().unwrap();
diff --git a/crates/trie/src/walker.rs b/crates/trie/src/walker.rs
index 402977bfb1f0..4ad38fe190cc 100644
--- a/crates/trie/src/walker.rs
+++ b/crates/trie/src/walker.rs
@@ -252,11 +252,9 @@ mod tests {
         prefix_set::PrefixSetMut,
         trie_cursor::{AccountTrieCursor, StorageTrieCursor},
     };
-    use reth_db::{
-        cursor::DbCursorRW, tables, test_utils::create_test_rw_db, transaction::DbTxMut,
-    };
-    use reth_primitives::{trie::StorageTrieEntry, MAINNET};
-    use reth_provider::ProviderFactory;
+    use reth_db::{cursor::DbCursorRW, tables, transaction::DbTxMut};
+    use reth_primitives::trie::StorageTrieEntry;
+    use reth_provider::test_utils::create_test_provider_factory;
 
     #[test]
     fn walk_nodes_with_common_prefix() {
@@ -281,9 +279,7 @@
             vec![0x5, 0x8, 0x2],
         ];
 
-        let db = create_test_rw_db();
-
-        let factory = ProviderFactory::new(db.as_ref(), MAINNET.clone());
+        let factory = create_test_provider_factory();
         let tx = factory.provider_rw().unwrap();
 
         let mut account_cursor = tx.tx_ref().cursor_write::<tables::AccountsTrie>().unwrap();
@@ -327,8 +323,7 @@
     #[test]
     fn cursor_rootnode_with_changesets() {
-        let db = create_test_rw_db();
-        let factory = ProviderFactory::new(db.as_ref(), MAINNET.clone());
+        let factory = create_test_provider_factory();
         let tx = factory.provider_rw().unwrap();
         let mut cursor = tx.tx_ref().cursor_dup_write::<tables::StoragesTrie>().unwrap();
diff --git a/docs/crates/db.md b/docs/crates/db.md
index 679ddcb016a9..cf0161d2b5c3 100644
--- a/docs/crates/db.md
+++ b/docs/crates/db.md
@@ -65,24 +65,23 @@ There are many tables within the node, all used to store different types of data.
 
 ## Database
 
-Reth's database design revolves around it's main [Database trait](https://github.com/paradigmxyz/reth/blob/eaca2a4a7fbbdc2f5cd15eab9a8a18ede1891bda/crates/storage/db/src/abstraction/database.rs#L21), which takes advantage of [generic associated types](https://blog.rust-lang.org/2022/10/28/gats-stabilization.html) and [a few design tricks](https://sabrinajewson.org/blog/the-better-alternative-to-lifetime-gats#the-better-gats) to implement the database's functionality across many types. Let's take a quick look at the `Database` trait and how it works.
+Reth's database design revolves around its main [Database trait](https://github.com/paradigmxyz/reth/blob/eaca2a4a7fbbdc2f5cd15eab9a8a18ede1891bda/crates/storage/db/src/abstraction/database.rs#L21), which implements the database's functionality across many types. Let's take a quick look at the `Database` trait and how it works.
 
 [File: crates/storage/db/src/abstraction/database.rs](https://github.com/paradigmxyz/reth/blob/eaca2a4a7fbbdc2f5cd15eab9a8a18ede1891bda/crates/storage/db/src/abstraction/database.rs#L21)
 
 ```rust ignore
 /// Main Database trait that spawns transactions to be executed.
-pub trait Database: for<'a> DatabaseGAT<'a> {
-    /// Create read only transaction.
-    fn tx(&self) -> Result<<Self as DatabaseGAT<'_>>::TX, Error>;
-
-    /// Create read write transaction only possible if database is open with write access.
-    fn tx_mut(&self) -> Result<<Self as DatabaseGAT<'_>>::TXMut, Error>;
+pub trait Database {
+    /// RO database transaction
+    type TX: DbTx + Send + Sync + Debug;
+    /// RW database transaction
+    type TXMut: DbTxMut + DbTx + TableImporter + Send + Sync + Debug;
 
     /// Takes a function and passes a read-only transaction into it, making sure it's closed in the
     /// end of the execution.
     fn view<T, F>(&self, f: F) -> Result<T, Error>
     where
-        F: Fn(&<Self as DatabaseGAT<'_>>::TX) -> T,
+        F: Fn(&<Self as Database>::TX) -> T,
     {
         let tx = self.tx()?;
@@ -96,7 +95,7 @@
     /// the end of the execution.
     fn update<T, F>(&self, f: F) -> Result<T, Error>
     where
-        F: Fn(&<Self as DatabaseGAT<'_>>::TXMut) -> T,
+        F: Fn(&<Self as Database>::TXMut) -> T,
     {
         let tx = self.tx_mut()?;
@@ -116,7 +115,7 @@ Any type that implements the `Database` trait can create a database transaction:
 
 pub struct Transaction<'this, DB: Database> {
     /// A handle to the DB.
     pub(crate) db: &'this DB,
-    tx: Option<<DB as DatabaseGAT<'this>>::TXMut>,
+    tx: Option<<DB as Database>::TXMut>,
 }
 
 //--snip--
@@ -134,26 +133,10 @@ where
 }
 ```
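Editor's note: with `TX`/`TXMut` as plain associated types, the closure bounds above read naturally at call sites. A minimal sketch of the `view` pattern, hedged: the table choice and the exact error type are illustrative (the snippet in the doc uses `Error`, the crate itself uses `DatabaseError`):

```rust
use reth_db::{database::Database, tables, transaction::DbTx};

// Read one header via a short-lived read-only transaction. `view` opens
// the transaction, runs the closure, and makes sure it is closed after.
fn read_header<DB: Database>(
    db: &DB,
    number: u64,
) -> Result<Option<reth_primitives::Header>, reth_db::DatabaseError> {
    db.view(|tx| tx.get::<tables::Headers>(number))?
}
```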
-The `Database` trait also implements the `DatabaseGAT` trait which defines two associated types `TX` and `TXMut`.
+The `Database` trait defines two associated types `TX` and `TXMut`.
 
 [File: crates/storage/db/src/abstraction/database.rs](https://github.com/paradigmxyz/reth/blob/main/crates/storage/db/src/abstraction/database.rs#L11)
 
-```rust ignore
-/// Implements the GAT method from:
-/// https://sabrinajewson.org/blog/the-better-alternative-to-lifetime-gats#the-better-gats.
-///
-/// Sealed trait which cannot be implemented by 3rd parties, exposed only for implementers
-pub trait DatabaseGAT<'a, __ImplicitBounds: Sealed = Bounds<&'a Self>>: Send + Sync {
-    /// RO database transaction
-    type TX: DbTx + Send + Sync;
-    /// RW database transaction
-    type TXMut: DbTxMut + DbTx + Send + Sync;
-}
-```
-
-In Rust, associated types are like generics in that they can be any type fitting the generic's definition, with the difference being that associated types are associated with a trait and can only be used in the context of that trait.
-
-In the code snippet above, the `DatabaseGAT` trait has two associated types, `TX` and `TXMut`.
 
 The `TX` type can be any type that implements the `DbTx` trait, which provides a set of functions to interact with read only transactions.
 
@@ -161,26 +144,40 @@ The `TX` type can be any type that implements the `DbTx` trait, which provides a
 ```rust ignore
 /// Read only transaction
-pub trait DbTx: for<'a> DbTxGAT<'a> {
+pub trait DbTx: Send + Sync {
+    /// Cursor type for this read-only transaction
+    type Cursor<T: Table>: DbCursorRO<T> + Send + Sync;
+    /// DupCursor type for this read-only transaction
+    type DupCursor<T: DupSort>: DbDupCursorRO<T> + DbCursorRO<T> + Send + Sync;
+
     /// Get value
     fn get<T: Table>(&self, key: T::Key) -> Result<Option<T::Value>, Error>;
     /// Commit for read only transaction will consume and free transaction and allows
     /// freeing of memory pages
     fn commit(self) -> Result<bool, Error>;
     /// Iterate over read only values in table.
-    fn cursor<T: Table>(&self) -> Result<<Self as DbTxGAT<'_>>::Cursor<T>, Error>;
+    fn cursor<T: Table>(&self) -> Result<Self::Cursor<T>, Error>;
     /// Iterate over read only values in dup sorted table.
-    fn cursor_dup<T: DupSort>(&self) -> Result<<Self as DbTxGAT<'_>>::DupCursor<T>, Error>;
+    fn cursor_dup<T: DupSort>(&self) -> Result<Self::DupCursor<T>, Error>;
 }
 ```
 
-The `TXMut` type can be any type that implements the `DbTxMut` trait, which provides a set of functions to interact with read/write transactions.
+The `TXMut` type can be any type that implements the `DbTxMut` trait, which provides a set of functions to interact with read/write transactions and the associated cursor types.
 
 [File: crates/storage/db/src/abstraction/transaction.rs](https://github.com/paradigmxyz/reth/blob/main/crates/storage/db/src/abstraction/transaction.rs#L49)
 
 ```rust ignore
 /// Read write transaction that allows writing to database
-pub trait DbTxMut: for<'a> DbTxMutGAT<'a> {
+pub trait DbTxMut: Send + Sync {
+    /// Read-Write Cursor type
+    type CursorMut<T: Table>: DbCursorRW<T> + DbCursorRO<T> + Send + Sync;
+    /// Read-Write DupCursor type
+    type DupCursorMut<T: DupSort>: DbDupCursorRW<T>
+        + DbCursorRW<T>
+        + DbDupCursorRO<T>
+        + DbCursorRO<T>
+        + Send
+        + Sync;
+
     /// Put value to database
     fn put<T: Table>(&self, key: T::Key, value: T::Value) -> Result<(), Error>;
     /// Delete value from database
@@ -188,11 +185,11 @@
     /// Clears database.
     fn clear<T: Table>(&self) -> Result<(), Error>;
     /// Cursor for writing
-    fn cursor_write<T: Table>(&self) -> Result<<Self as DbTxMutGAT<'_>>::CursorMut<T>, Error>;
+    fn cursor_write<T: Table>(&self) -> Result<Self::CursorMut<T>, Error>;
     /// DupCursor for writing
     fn cursor_dup_write<T: DupSort>(
         &self,
-    ) -> Result<<Self as DbTxMutGAT<'_>>::DupCursorMut<T>, Error>;
+    ) -> Result<Self::DupCursorMut<T>, Error>;
 }
 ```
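Editor's note: to make the cursor associated types concrete, here is a hedged sketch of opening a write cursor through these traits. The table choice is illustrative, and `first()` comes from the `DbCursorRO` trait:

```rust
use reth_db::{
    cursor::DbCursorRO,
    tables,
    transaction::{DbTx, DbTxMut},
};
use reth_primitives::{Account, Address};

// Open a read-write cursor over the plain account state table and peek at
// the first entry; the concrete cursor type is the transaction's
// associated `CursorMut`, selected by the turbofish.
fn first_account<TX: DbTx + DbTxMut>(
    tx: &TX,
) -> Result<Option<(Address, Account)>, reth_db::DatabaseError> {
    let mut cursor = tx.cursor_write::<tables::PlainAccountState>()?;
    cursor.first()
}
```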
 
@@ -220,14 +217,14 @@ where
 
 //--snip--
 
 impl<'a, DB: Database> Deref for Transaction<'a, DB> {
-    type Target = <DB as DatabaseGAT<'a>>::TXMut;
+    type Target = <DB as Database>::TXMut;
 
     fn deref(&self) -> &Self::Target {
         self.tx.as_ref().expect("Tried getting a reference to a non-existent transaction")
     }
 }
 ```
 
-The `Transaction` struct implements the `Deref` trait, which returns a reference to its `tx` field, which is a `TxMut`. Recall that `TxMut` is a generic type on the `DatabaseGAT` trait, which is defined as `type TXMut: DbTxMut + DbTx + Send + Sync;`, giving it access to all of the functions available to `DbTx`, including the `DbTx::get()` function.
+The `Transaction` struct implements the `Deref` trait, which returns a reference to its `tx` field, which is a `TXMut`. Recall that `TXMut` is an associated type on the `Database` trait, defined as `type TXMut: DbTxMut + DbTx + TableImporter + Send + Sync + Debug;`, giving it access to all of the functions available to `DbTx`, including the `DbTx::get()` function.
 
 Notice that the function uses a [turbofish](https://techblog.tonsser.com/posts/what-is-rusts-turbofish) to define which table to use when passing in the `key` to the `DbTx::get()` function. Taking a quick look at the function definition, a generic `T` is defined that implements the `Table` trait mentioned at the beginning of this chapter.
 
@@ -268,19 +265,6 @@ This next example uses the `DbTx::cursor()` method to get a `Cursor`. The `Curso
 
 ```
 
-We are almost at the last stop in the tour of the `db` crate. In addition to the methods provided by the `DbTx` and `DbTxMut` traits, `DbTx` also inherits the `DbTxGAT` trait, while `DbTxMut` inherits `DbTxMutGAT`. These next two traits provide various associated types related to cursors as well as methods to utilize the cursor types.
-
-[File: crates/storage/db/src/abstraction/transaction.rs](https://github.com/paradigmxyz/reth/blob/main/crates/storage/db/src/abstraction/transaction.rs#L12-L17)
-
-```rust ignore
-pub trait DbTxGAT<'a, __ImplicitBounds: Sealed = Bounds<&'a Self>>: Send + Sync {
-    /// Cursor GAT
-    type Cursor<T: Table>: DbCursorRO<'a, T> + Send + Sync;
-    /// DupCursor GAT
-    type DupCursor<T: DupSort>: DbDupCursorRO<'a, T> + DbCursorRO<'a, T> + Send + Sync;
-}
-```
-
 Let's look at an example of how cursors are used. The code snippet below contains the `unwind` method from the `BodyStage` defined in the `stages` crate. This function is responsible for unwinding any changes to the database if there is an error when executing the body stage within the Reth pipeline.
 
 [File: crates/stages/src/stages/bodies.rs](https://github.com/paradigmxyz/reth/blob/main/crates/stages/src/stages/bodies.rs#L205-L238)
diff --git a/testing/ef-tests/src/cases/blockchain_test.rs b/testing/ef-tests/src/cases/blockchain_test.rs
index 5d9a4bf868d4..3706f13ffd35 100644
--- a/testing/ef-tests/src/cases/blockchain_test.rs
+++ b/testing/ef-tests/src/cases/blockchain_test.rs
@@ -101,9 +101,9 @@ impl Case for BlockchainTestCase {
 
         // Call execution stage
         {
-            let mut stage = ExecutionStage::new_with_factory(reth_revm::Factory::new(
-                Arc::new(case.network.clone().into()),
-            ));
+            let mut stage = ExecutionStage::new_with_factory(
+                reth_revm::EvmProcessorFactory::new(Arc::new(case.network.clone().into())),
+            );
 
             let target = last_block.as_ref().map(|b| b.number);
             tokio::runtime::Builder::new_current_thread()
@@ -111,8 +111,7 @@ impl Case for BlockchainTestCase {
                 .expect("Could not build tokio RT")
                 .block_on(async {
                     // ignore error
-                    let _ =
-                        stage.execute(&provider, ExecInput { target, checkpoint: None }).await;
+                    let _ = stage.execute(&provider, ExecInput { target, checkpoint: None });
                 });
         }