diff --git a/crates/engine/tree/benches/state_root_task.rs b/crates/engine/tree/benches/state_root_task.rs index 391fd333d12f..9055190fb6f4 100644 --- a/crates/engine/tree/benches/state_root_task.rs +++ b/crates/engine/tree/benches/state_root_task.rs @@ -13,7 +13,11 @@ use reth_provider::{ HashingWriter, ProviderFactory, }; use reth_testing_utils::generators::{self, Rng}; -use reth_trie::TrieInput; +use reth_trie::{ + hashed_cursor::HashedPostStateCursorFactory, proof::ProofBlindedProviderFactory, + trie_cursor::InMemoryTrieCursorFactory, TrieInput, +}; +use reth_trie_db::{DatabaseHashedCursorFactory, DatabaseTrieCursorFactory}; use revm_primitives::{ Account as RevmAccount, AccountInfo, AccountStatus, Address, EvmState, EvmStorageSlot, HashMap, B256, KECCAK_EMPTY, U256, @@ -139,20 +143,38 @@ fn bench_state_root(c: &mut Criterion) { consistent_view: ConsistentDbView::new(factory, None), input: trie_input, }; + let provider = config.consistent_view.provider_ro().unwrap(); + let nodes_sorted = config.input.nodes.clone().into_sorted(); + let state_sorted = config.input.state.clone().into_sorted(); + let prefix_sets = Arc::new(config.input.prefix_sets.clone()); - (config, state_updates) + (config, state_updates, provider, nodes_sorted, state_sorted, prefix_sets) }, - |(config, state_updates)| { - let task = StateRootTask::new(config); - let mut hook = task.state_hook(); - let handle = task.spawn(); - - for update in state_updates { - hook.on_state(&update) - } - drop(hook); - - black_box(handle.wait_for_result().expect("task failed")); + |(config, state_updates, provider, nodes_sorted, state_sorted, prefix_sets)| { + let blinded_provider_factory = ProofBlindedProviderFactory::new( + InMemoryTrieCursorFactory::new( + DatabaseTrieCursorFactory::new(provider.tx_ref()), + &nodes_sorted, + ), + HashedPostStateCursorFactory::new( + DatabaseHashedCursorFactory::new(provider.tx_ref()), + &state_sorted, + ), + prefix_sets, + ); + + black_box(std::thread::scope(|scope| { + let task = StateRootTask::new(config, blinded_provider_factory); + let mut hook = task.state_hook(); + let handle = task.spawn(scope); + + for update in state_updates { + hook.on_state(&update) + } + drop(hook); + + handle.wait_for_result().expect("task failed") + })); }, ) }, diff --git a/crates/engine/tree/src/tree/root.rs b/crates/engine/tree/src/tree/root.rs index 72b18d49f52c..7254cc882a7e 100644 --- a/crates/engine/tree/src/tree/root.rs +++ b/crates/engine/tree/src/tree/root.rs @@ -14,7 +14,8 @@ use reth_trie::{ use reth_trie_db::DatabaseProof; use reth_trie_parallel::root::ParallelStateRootError; use reth_trie_sparse::{ - errors::{SparseStateTrieResult, SparseTrieErrorKind}, + blinded::{BlindedProvider, BlindedProviderFactory}, + errors::{SparseStateTrieResult, SparseTrieError, SparseTrieErrorKind}, SparseStateTrie, }; use revm_primitives::{keccak256, EvmState, B256}; @@ -25,6 +26,7 @@ use std::{ mpsc::{self, channel, Receiver, Sender}, Arc, }, + thread::{self}, time::{Duration, Instant}, }; use tracing::{debug, error, trace}; @@ -68,7 +70,7 @@ pub struct StateRootConfig { /// Messages used internally by the state root task #[derive(Debug)] #[allow(dead_code)] -pub enum StateRootMessage { +pub enum StateRootMessage { /// New state update from transaction execution StateUpdate(EvmState), /// Proof calculation completed for a specific state update @@ -83,7 +85,7 @@ pub enum StateRootMessage { /// State root calculation completed RootCalculated { /// The updated sparse trie - trie: Box, + trie: Box>, /// Time taken to calculate the root elapsed: Duration, }, @@ -159,24 +161,24 @@ impl ProofSequencer { /// A wrapper for the sender that signals completion when dropped #[allow(dead_code)] -pub(crate) struct StateHookSender(Sender); +pub(crate) struct StateHookSender(Sender>); #[allow(dead_code)] -impl StateHookSender { - pub(crate) const fn new(inner: Sender) -> Self { +impl StateHookSender { + pub(crate) const fn new(inner: Sender>) -> Self { Self(inner) } } -impl Deref for StateHookSender { - type Target = Sender; +impl Deref for StateHookSender { + type Target = Sender>; fn deref(&self) -> &Self::Target { &self.0 } } -impl Drop for StateHookSender { +impl Drop for StateHookSender { fn drop(&mut self) { // Send completion signal when the sender is dropped let _ = self.0.send(StateRootMessage::FinishedStateUpdates); @@ -224,24 +226,24 @@ fn evm_state_to_hashed_post_state(update: EvmState) -> HashedPostState { /// to the tree. /// Then it updates relevant leaves according to the result of the transaction. #[derive(Debug)] -pub struct StateRootTask { +pub struct StateRootTask { /// Task configuration. config: StateRootConfig, /// Receiver for state root related messages. - rx: Receiver, + rx: Receiver>, /// Sender for state root related messages. - tx: Sender, + tx: Sender>, /// Proof targets that have been already fetched. fetched_proof_targets: MultiProofTargets, /// Proof sequencing handler. proof_sequencer: ProofSequencer, /// The sparse trie used for the state root calculation. If [`None`], then update is in /// progress. - sparse_trie: Option>, + sparse_trie: Option>>, } #[allow(dead_code)] -impl StateRootTask +impl<'env, Factory, ABP, SBP, BPF> StateRootTask where Factory: DatabaseProviderFactory + StateCommitmentProvider @@ -249,9 +251,15 @@ where + Send + Sync + 'static, + ABP: BlindedProvider + Send + Sync + 'env, + SBP: BlindedProvider + Send + Sync + 'env, + BPF: BlindedProviderFactory + + Send + + Sync + + 'env, { /// Creates a new state root task with the unified message channel - pub fn new(config: StateRootConfig) -> Self { + pub fn new(config: StateRootConfig, blinded_provider: BPF) -> Self { let (tx, rx) = channel(); Self { @@ -260,18 +268,19 @@ where tx, fetched_proof_targets: Default::default(), proof_sequencer: ProofSequencer::new(), - sparse_trie: Some(Box::new(SparseStateTrie::default().with_updates(true))), + sparse_trie: Some(Box::new(SparseStateTrie::new(blinded_provider).with_updates(true))), } } /// Spawns the state root task and returns a handle to await its result. - pub fn spawn(self) -> StateRootHandle { + pub fn spawn<'scope>(self, scope: &'scope thread::Scope<'scope, 'env>) -> StateRootHandle { let (tx, rx) = mpsc::sync_channel(1); std::thread::Builder::new() .name("State Root Task".to_string()) - .spawn(move || { + .spawn_scoped(scope, move || { debug!(target: "engine::tree", "Starting state root task"); - let result = self.run(); + + let result = rayon::scope(|scope| self.run(scope)); let _ = tx.send(result); }) .expect("failed to spawn state root thread"); @@ -294,12 +303,13 @@ where /// /// Returns proof targets derived from the state update. fn on_state_update( + scope: &rayon::Scope<'env>, view: ConsistentDbView, input: Arc, update: EvmState, fetched_proof_targets: &mut MultiProofTargets, proof_sequence_number: u64, - state_root_message_sender: Sender, + state_root_message_sender: Sender>, ) { let hashed_state_update = evm_state_to_hashed_post_state(update); @@ -309,7 +319,7 @@ where } // Dispatch proof gathering for this state update - rayon::spawn(move || { + scope.spawn(move |_| { let provider = match view.provider_ro() { Ok(provider) => provider, Err(error) => { @@ -362,7 +372,12 @@ where } /// Spawns root calculation with the current state and proofs. - fn spawn_root_calculation(&mut self, state: HashedPostState, multiproof: MultiProof) { + fn spawn_root_calculation( + &mut self, + scope: &rayon::Scope<'env>, + state: HashedPostState, + multiproof: MultiProof, + ) { let Some(trie) = self.sparse_trie.take() else { return }; trace!( @@ -376,7 +391,7 @@ where let targets = get_proof_targets(&state, &HashMap::default()); let tx = self.tx.clone(); - rayon::spawn(move || { + scope.spawn(move |_| { let result = update_sparse_trie(trie, multiproof, targets, state); match result { Ok((trie, elapsed)) => { @@ -394,7 +409,7 @@ where }); } - fn run(mut self) -> StateRootResult { + fn run(mut self, scope: &rayon::Scope<'env>) -> StateRootResult { let mut current_state_update = HashedPostState::default(); let mut current_multiproof = MultiProof::default(); let mut updates_received = 0; @@ -414,6 +429,7 @@ where "Received new state update" ); Self::on_state_update( + scope, self.config.consistent_view.clone(), self.config.input.clone(), update, @@ -443,7 +459,11 @@ where current_multiproof.extend(combined_proof); current_state_update.extend(combined_state_update); } else { - self.spawn_root_calculation(combined_state_update, combined_proof); + self.spawn_root_calculation( + scope, + combined_state_update, + combined_proof, + ); } } } @@ -481,6 +501,7 @@ where "Spawning subsequent root calculation" ); self.spawn_root_calculation( + scope, std::mem::take(&mut current_state_update), std::mem::take(&mut current_multiproof), ); @@ -555,12 +576,16 @@ fn get_proof_targets( /// Updates the sparse trie with the given proofs and state, and returns the updated trie and the /// time it took. -fn update_sparse_trie( - mut trie: Box, +fn update_sparse_trie< + ABP: BlindedProvider + Send + Sync, + SBP: BlindedProvider + Send + Sync, + BPF: BlindedProviderFactory + Send + Sync, +>( + mut trie: Box>, multiproof: MultiProof, targets: MultiProofTargets, state: HashedPostState, -) -> SparseStateTrieResult<(Box, Duration)> { +) -> SparseStateTrieResult<(Box>, Duration)> { trace!(target: "engine::root::sparse", "Updating sparse trie"); let started_at = Instant::now(); @@ -582,12 +607,10 @@ fn update_sparse_trie( trace!(target: "engine::root::sparse", ?address, "Wiping storage"); storage_trie.wipe()?; } - for (slot, value) in storage.storage { let slot_nibbles = Nibbles::unpack(slot); if value.is_zero() { trace!(target: "engine::root::sparse", ?address, ?slot, "Removing storage slot"); - storage_trie.remove_leaf(&slot_nibbles)?; } else { trace!(target: "engine::root::sparse", ?address, ?slot, "Updating storage slot"); @@ -629,7 +652,11 @@ mod tests { providers::ConsistentDbView, test_utils::create_test_provider_factory, HashingWriter, }; use reth_testing_utils::generators::{self, Rng}; - use reth_trie::{test_utils::state_root, TrieInput}; + use reth_trie::{ + hashed_cursor::HashedPostStateCursorFactory, proof::ProofBlindedProviderFactory, + test_utils::state_root, trie_cursor::InMemoryTrieCursorFactory, TrieInput, + }; + use reth_trie_db::{DatabaseHashedCursorFactory, DatabaseTrieCursorFactory}; use revm_primitives::{ Account as RevmAccount, AccountInfo, AccountStatus, Address, EvmState, EvmStorageSlot, HashMap, B256, KECCAK_EMPTY, U256, @@ -746,16 +773,32 @@ mod tests { consistent_view: ConsistentDbView::new(factory, None), input: Arc::new(TrieInput::from_state(hashed_state)), }; - let task = StateRootTask::new(config); - let mut state_hook = task.state_hook(); - let handle = task.spawn(); + let provider = config.consistent_view.provider_ro().unwrap(); + let nodes_sorted = config.input.nodes.clone().into_sorted(); + let state_sorted = config.input.state.clone().into_sorted(); + let blinded_provider_factory = ProofBlindedProviderFactory::new( + InMemoryTrieCursorFactory::new( + DatabaseTrieCursorFactory::new(provider.tx_ref()), + &nodes_sorted, + ), + HashedPostStateCursorFactory::new( + DatabaseHashedCursorFactory::new(provider.tx_ref()), + &state_sorted, + ), + Arc::new(config.input.prefix_sets.clone()), + ); + let (root_from_task, _) = std::thread::scope(|std_scope| { + let task = StateRootTask::new(config, blinded_provider_factory); + let mut state_hook = task.state_hook(); + let handle = task.spawn(std_scope); - for update in state_updates { - state_hook.on_state(&update); - } - drop(state_hook); + for update in state_updates { + state_hook.on_state(&update); + } + drop(state_hook); - let (root_from_task, _) = handle.wait_for_result().expect("task failed"); + handle.wait_for_result().expect("task failed") + }); let root_from_base = state_root(accumulated_state); assert_eq!( diff --git a/crates/trie/trie/src/proof/blinded.rs b/crates/trie/trie/src/proof/blinded.rs index 1383453f344d..55f8bdfbc48c 100644 --- a/crates/trie/trie/src/proof/blinded.rs +++ b/crates/trie/trie/src/proof/blinded.rs @@ -33,8 +33,8 @@ impl ProofBlindedProviderFactory { impl BlindedProviderFactory for ProofBlindedProviderFactory where - T: TrieCursorFactory + Clone, - H: HashedCursorFactory + Clone, + T: TrieCursorFactory + Clone + Send + Sync, + H: HashedCursorFactory + Clone + Send + Sync, { type AccountNodeProvider = ProofBlindedAccountProvider; type StorageNodeProvider = ProofBlindedStorageProvider; @@ -81,8 +81,8 @@ impl ProofBlindedAccountProvider { impl BlindedProvider for ProofBlindedAccountProvider where - T: TrieCursorFactory + Clone, - H: HashedCursorFactory + Clone, + T: TrieCursorFactory + Clone + Send + Sync, + H: HashedCursorFactory + Clone + Send + Sync, { type Error = SparseTrieError; @@ -125,8 +125,8 @@ impl ProofBlindedStorageProvider { impl BlindedProvider for ProofBlindedStorageProvider where - T: TrieCursorFactory + Clone, - H: HashedCursorFactory + Clone, + T: TrieCursorFactory + Clone + Send + Sync, + H: HashedCursorFactory + Clone + Send + Sync, { type Error = SparseTrieError; diff --git a/crates/trie/trie/src/witness.rs b/crates/trie/trie/src/witness.rs index 5e56cbf21c71..f5e6d46d6ce3 100644 --- a/crates/trie/trie/src/witness.rs +++ b/crates/trie/trie/src/witness.rs @@ -75,8 +75,8 @@ impl TrieWitness { impl TrieWitness where - T: TrieCursorFactory + Clone, - H: HashedCursorFactory + Clone, + T: TrieCursorFactory + Clone + Send + Sync, + H: HashedCursorFactory + Clone + Send + Sync, { /// Compute the state transition witness for the trie. Gather all required nodes /// to apply `state` on top of the current trie state.