diff --git a/Cargo.lock b/Cargo.lock index 054ad19dad5a..e83e20a03997 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -7224,9 +7224,10 @@ dependencies = [ "alloy-rlp", "alloy-rpc-types-engine", "assert_matches", + "criterion", + "crossbeam-channel", "futures", "metrics", - "pin-project", "reth-beacon-consensus", "reth-blockchain-tree", "reth-blockchain-tree-api", @@ -7261,7 +7262,6 @@ dependencies = [ "revm-primitives", "thiserror 1.0.69", "tokio", - "tokio-stream", "tracing", ] diff --git a/crates/engine/tree/Cargo.toml b/crates/engine/tree/Cargo.toml index 278457145e70..d6e1c80a7261 100644 --- a/crates/engine/tree/Cargo.toml +++ b/crates/engine/tree/Cargo.toml @@ -45,9 +45,7 @@ revm-primitives.workspace = true # common futures.workspace = true -pin-project.workspace = true tokio = { workspace = true, features = ["macros", "sync"] } -tokio-stream.workspace = true thiserror.workspace = true # metrics @@ -82,6 +80,12 @@ reth-chainspec.workspace = true alloy-rlp.workspace = true assert_matches.workspace = true +criterion.workspace = true +crossbeam-channel = "0.5.13" + +[[bench]] +name = "channel_perf" +harness = false [features] test-utils = [ diff --git a/crates/engine/tree/benches/channel_perf.rs b/crates/engine/tree/benches/channel_perf.rs new file mode 100644 index 000000000000..c1c65e0a68e1 --- /dev/null +++ b/crates/engine/tree/benches/channel_perf.rs @@ -0,0 +1,132 @@ +//! Benchmark comparing `std::sync::mpsc` and `crossbeam` channels for `StateRootTask`. + +#![allow(missing_docs)] + +use criterion::{criterion_group, criterion_main, BatchSize, BenchmarkId, Criterion}; +use revm_primitives::{ + Account, AccountInfo, AccountStatus, Address, EvmState, EvmStorage, EvmStorageSlot, HashMap, + B256, U256, +}; +use std::thread; + +/// Creates a mock state with the specified number of accounts for benchmarking +fn create_bench_state(num_accounts: usize) -> EvmState { + let mut state_changes = HashMap::default(); + + for i in 0..num_accounts { + let storage = + EvmStorage::from_iter([(U256::from(i), EvmStorageSlot::new(U256::from(i + 1)))]); + + let account = Account { + info: AccountInfo { + balance: U256::from(100), + nonce: 10, + code_hash: B256::random(), + code: Default::default(), + }, + storage, + status: AccountStatus::Loaded, + }; + + let address = Address::random(); + state_changes.insert(address, account); + } + + state_changes +} + +/// Simulated `StateRootTask` with `std::sync::mpsc` +struct StdStateRootTask { + rx: std::sync::mpsc::Receiver, +} + +impl StdStateRootTask { + const fn new(rx: std::sync::mpsc::Receiver) -> Self { + Self { rx } + } + + fn run(self) { + while let Ok(state) = self.rx.recv() { + criterion::black_box(state); + } + } +} + +/// Simulated `StateRootTask` with `crossbeam-channel` +struct CrossbeamStateRootTask { + rx: crossbeam_channel::Receiver, +} + +impl CrossbeamStateRootTask { + const fn new(rx: crossbeam_channel::Receiver) -> Self { + Self { rx } + } + + fn run(self) { + while let Ok(state) = self.rx.recv() { + criterion::black_box(state); + } + } +} + +/// Benchmarks the performance of different channel implementations for state streaming +fn bench_state_stream(c: &mut Criterion) { + let mut group = c.benchmark_group("state_stream_channels"); + group.sample_size(10); + + for size in &[1, 10, 100] { + let bench_setup = || { + let states: Vec<_> = (0..100).map(|_| create_bench_state(*size)).collect(); + states + }; + + group.bench_with_input(BenchmarkId::new("std_channel", size), size, |b, _| { + b.iter_batched( + bench_setup, + |states| { + let (tx, rx) = std::sync::mpsc::channel(); + let task = StdStateRootTask::new(rx); + + let processor = thread::spawn(move || { + task.run(); + }); + + for state in states { + tx.send(state).unwrap(); + } + drop(tx); + + processor.join().unwrap(); + }, + BatchSize::LargeInput, + ); + }); + + group.bench_with_input(BenchmarkId::new("crossbeam_channel", size), size, |b, _| { + b.iter_batched( + bench_setup, + |states| { + let (tx, rx) = crossbeam_channel::unbounded(); + let task = CrossbeamStateRootTask::new(rx); + + let processor = thread::spawn(move || { + task.run(); + }); + + for state in states { + tx.send(state).unwrap(); + } + drop(tx); + + processor.join().unwrap(); + }, + BatchSize::LargeInput, + ); + }); + } + + group.finish(); +} + +criterion_group!(benches, bench_state_stream); +criterion_main!(benches); diff --git a/crates/engine/tree/src/tree/root.rs b/crates/engine/tree/src/tree/root.rs index fbf6c3481384..45cf5a780310 100644 --- a/crates/engine/tree/src/tree/root.rs +++ b/crates/engine/tree/src/tree/root.rs @@ -1,18 +1,13 @@ //! State root task related functionality. -use futures::Stream; -use pin_project::pin_project; use reth_provider::providers::ConsistentDbView; use reth_trie::{updates::TrieUpdates, TrieInput}; use reth_trie_parallel::root::ParallelStateRootError; use revm_primitives::{EvmState, B256}; -use std::{ - future::Future, - pin::Pin, - sync::{mpsc, Arc}, - task::{Context, Poll}, +use std::sync::{ + mpsc::{self, Receiver, RecvError}, + Arc, }; -use tokio_stream::wrappers::UnboundedReceiverStream; use tracing::debug; /// Result of the state root calculation @@ -28,12 +23,43 @@ pub(crate) struct StateRootHandle { #[allow(dead_code)] impl StateRootHandle { + /// Creates a new handle from a receiver. + pub(crate) const fn new(rx: mpsc::Receiver) -> Self { + Self { rx } + } + /// Waits for the state root calculation to complete. pub(crate) fn wait_for_result(self) -> StateRootResult { self.rx.recv().expect("state root task was dropped without sending result") } } +/// Common configuration for state root tasks +#[derive(Debug)] +pub(crate) struct StateRootConfig { + /// View over the state in the database. + pub consistent_view: ConsistentDbView, + /// Latest trie input. + pub input: Arc, +} + +/// Wrapper for std channel receiver to maintain compatibility with `UnboundedReceiverStream` +#[allow(dead_code)] +pub(crate) struct StdReceiverStream { + rx: Receiver, +} + +#[allow(dead_code)] +impl StdReceiverStream { + pub(crate) const fn new(rx: Receiver) -> Self { + Self { rx } + } + + pub(crate) fn recv(&self) -> Result { + self.rx.recv() + } +} + /// Standalone task that receives a transaction state stream and updates relevant /// data structures to calculate state root. /// @@ -42,15 +68,12 @@ impl StateRootHandle { /// fetches the proofs for relevant accounts from the database and reveal them /// to the tree. /// Then it updates relevant leaves according to the result of the transaction. -#[pin_project] +#[allow(dead_code)] pub(crate) struct StateRootTask { - /// View over the state in the database. - consistent_view: ConsistentDbView, /// Incoming state updates. - #[pin] - state_stream: UnboundedReceiverStream, - /// Latest trie input. - input: Arc, + state_stream: StdReceiverStream, + /// Task configuration. + config: StateRootConfig, } #[allow(dead_code)] @@ -60,65 +83,109 @@ where { /// Creates a new `StateRootTask`. pub(crate) const fn new( - consistent_view: ConsistentDbView, - input: Arc, - state_stream: UnboundedReceiverStream, + config: StateRootConfig, + state_stream: StdReceiverStream, ) -> Self { - Self { consistent_view, state_stream, input } + Self { config, state_stream } } /// Spawns the state root task and returns a handle to await its result. pub(crate) fn spawn(self) -> StateRootHandle { - let (tx, rx) = mpsc::channel(); - - // Spawn the task that will process state updates and calculate the root - tokio::spawn(async move { - debug!(target: "engine::tree", "Starting state root task"); - let result = self.await; - let _ = tx.send(result); - }); + let (tx, rx) = mpsc::sync_channel(1); + std::thread::Builder::new() + .name("State Root Task".to_string()) + .spawn(move || { + debug!(target: "engine::tree", "Starting state root task"); + let result = self.run(); + let _ = tx.send(result); + }) + .expect("failed to spawn state root thread"); - StateRootHandle { rx } + StateRootHandle::new(rx) } /// Handles state updates. fn on_state_update( - _view: &ConsistentDbView, - _input: &Arc, + _view: &reth_provider::providers::ConsistentDbView, + _input: &std::sync::Arc, _state: EvmState, ) { + // Default implementation of state update handling // TODO: calculate hashed state update and dispatch proof gathering for it. } } -impl Future for StateRootTask +#[allow(dead_code)] +impl StateRootTask where Factory: Send + 'static, { - type Output = StateRootResult; - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let mut this = self.project(); - - // Process all items until the stream is closed - loop { - match this.state_stream.as_mut().poll_next(cx) { - Poll::Ready(Some(state)) => { - Self::on_state_update(this.consistent_view, this.input, state); - } - Poll::Ready(None) => { - // stream closed, return final result - return Poll::Ready(Ok((B256::default(), TrieUpdates::default()))); - } - Poll::Pending => { - return Poll::Pending; - } - } + fn run(self) -> StateRootResult { + while let Ok(state) = self.state_stream.recv() { + Self::on_state_update(&self.config.consistent_view, &self.config.input, state); } // TODO: // * keep track of proof calculation // * keep track of intermediate root computation // * return final state root result + Ok((B256::default(), TrieUpdates::default())) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use reth_provider::{providers::ConsistentDbView, test_utils::MockEthProvider}; + use reth_trie::TrieInput; + use revm_primitives::{ + Account, AccountInfo, AccountStatus, Address, EvmState, EvmStorage, EvmStorageSlot, + HashMap, B256, U256, + }; + use std::sync::Arc; + + fn create_mock_config() -> StateRootConfig { + let factory = MockEthProvider::default(); + let view = ConsistentDbView::new(factory, None); + let input = Arc::new(TrieInput::default()); + StateRootConfig { consistent_view: view, input } + } + + fn create_mock_state() -> revm_primitives::EvmState { + let mut state_changes: EvmState = HashMap::default(); + let storage = EvmStorage::from_iter([(U256::from(1), EvmStorageSlot::new(U256::from(2)))]); + let account = Account { + info: AccountInfo { + balance: U256::from(100), + nonce: 10, + code_hash: B256::random(), + code: Default::default(), + }, + storage, + status: AccountStatus::Loaded, + }; + + let address = Address::random(); + state_changes.insert(address, account); + + state_changes + } + + #[test] + fn test_state_root_task() { + let config = create_mock_config(); + let (tx, rx) = std::sync::mpsc::channel(); + let stream = StdReceiverStream::new(rx); + + let task = StateRootTask::new(config, stream); + let handle = task.spawn(); + + for _ in 0..10 { + tx.send(create_mock_state()).expect("failed to send state"); + } + drop(tx); + + let result = handle.wait_for_result(); + assert!(result.is_ok(), "sync block execution failed"); } }