Skip to content

Commit

Permalink
fix lifetimes
Browse files Browse the repository at this point in the history
  • Loading branch information
shekhirin committed Dec 11, 2024
1 parent 9cc0964 commit 1ae7f81
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 36 deletions.
48 changes: 35 additions & 13 deletions crates/engine/tree/benches/state_root_task.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,11 @@ use reth_provider::{
HashingWriter, ProviderFactory,
};
use reth_testing_utils::generators::{self, Rng};
use reth_trie::TrieInput;
use reth_trie::{
hashed_cursor::HashedPostStateCursorFactory, proof::ProofBlindedProviderFactory,
trie_cursor::InMemoryTrieCursorFactory, TrieInput,
};
use reth_trie_db::{DatabaseHashedCursorFactory, DatabaseTrieCursorFactory};
use revm_primitives::{
Account as RevmAccount, AccountInfo, AccountStatus, Address, EvmState, EvmStorageSlot, HashMap,
B256, KECCAK_EMPTY, U256,
Expand Down Expand Up @@ -139,20 +143,38 @@ fn bench_state_root(c: &mut Criterion) {
consistent_view: ConsistentDbView::new(factory, None),
input: trie_input,
};
let provider = config.consistent_view.provider_ro().unwrap();
let nodes_sorted = config.input.nodes.clone().into_sorted();
let state_sorted = config.input.state.clone().into_sorted();
let prefix_sets = Arc::new(config.input.prefix_sets.clone());

(config, state_updates)
(config, state_updates, provider, nodes_sorted, state_sorted, prefix_sets)
},
|(config, state_updates)| {
let task = StateRootTask::new(config);
let mut hook = task.state_hook();
let handle = task.spawn();

for update in state_updates {
hook.on_state(&update)
}
drop(hook);

black_box(handle.wait_for_result().expect("task failed"));
|(config, state_updates, provider, nodes_sorted, state_sorted, prefix_sets)| {
let blinded_provider_factory = ProofBlindedProviderFactory::new(
InMemoryTrieCursorFactory::new(
DatabaseTrieCursorFactory::new(provider.tx_ref()),
&nodes_sorted,
),
HashedPostStateCursorFactory::new(
DatabaseHashedCursorFactory::new(provider.tx_ref()),
&state_sorted,
),
prefix_sets,
);

black_box(std::thread::scope(|scope| {
let task = StateRootTask::new(config, blinded_provider_factory);
let mut hook = task.state_hook();
let handle = task.spawn(scope);

for update in state_updates {
hook.on_state(&update)
}
drop(hook);

handle.wait_for_result().expect("task failed")
}));
},
)
},
Expand Down
65 changes: 42 additions & 23 deletions crates/engine/tree/src/tree/root.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ use std::{
mpsc::{self, channel, Receiver, Sender},
Arc,
},
thread::{self},
time::{Duration, Instant},
};
use tracing::{debug, error, trace};
Expand Down Expand Up @@ -242,17 +243,20 @@ pub struct StateRootTask<Factory, BPF: BlindedProviderFactory> {
}

#[allow(dead_code)]
impl<Factory, ABP, SBP, BPF> StateRootTask<Factory, BPF>
impl<'env, Factory, ABP, SBP, BPF> StateRootTask<Factory, BPF>
where
Factory: DatabaseProviderFactory<Provider: BlockReader>
+ StateCommitmentProvider
+ Clone
+ Send
+ Sync
+ 'static,
ABP: BlindedProvider<Error = SparseTrieError> + Send + Sync,
SBP: BlindedProvider<Error = SparseTrieError> + Send + Sync,
BPF: BlindedProviderFactory<AccountNodeProvider = ABP, StorageNodeProvider = SBP> + Send + Sync,
ABP: BlindedProvider<Error = SparseTrieError> + Send + Sync + 'env,
SBP: BlindedProvider<Error = SparseTrieError> + Send + Sync + 'env,
BPF: BlindedProviderFactory<AccountNodeProvider = ABP, StorageNodeProvider = SBP>
+ Send
+ Sync
+ 'env,
{
/// Creates a new state root task with the unified message channel
pub fn new(config: StateRootConfig<Factory>, blinded_provider: BPF) -> Self {
Expand All @@ -269,13 +273,14 @@ where
}

/// Spawns the state root task and returns a handle to await its result.
pub fn spawn(self) -> StateRootHandle {
pub fn spawn<'scope>(self, scope: &'scope thread::Scope<'scope, 'env>) -> StateRootHandle {
let (tx, rx) = mpsc::sync_channel(1);
std::thread::Builder::new()
.name("State Root Task".to_string())
.spawn(move || {
.spawn_scoped(scope, move || {
debug!(target: "engine::tree", "Starting state root task");
let result = self.run();

let result = rayon::scope(|scope| self.run(scope));
let _ = tx.send(result);
})
.expect("failed to spawn state root thread");
Expand All @@ -298,6 +303,7 @@ where
///
/// Returns proof targets derived from the state update.
fn on_state_update(
scope: &rayon::Scope<'env>,
view: ConsistentDbView<Factory>,
input: Arc<TrieInput>,
update: EvmState,
Expand All @@ -313,7 +319,7 @@ where
}

// Dispatch proof gathering for this state update
rayon::spawn(move || {
scope.spawn(move |_| {
let provider = match view.provider_ro() {
Ok(provider) => provider,
Err(error) => {
Expand Down Expand Up @@ -366,7 +372,12 @@ where
}

/// Spawns root calculation with the current state and proofs.
fn spawn_root_calculation(&mut self, state: HashedPostState, multiproof: MultiProof) {
fn spawn_root_calculation(
&mut self,
scope: &rayon::Scope<'env>,
state: HashedPostState,
multiproof: MultiProof,
) {
let Some(trie) = self.sparse_trie.take() else { return };

trace!(
Expand All @@ -380,7 +391,7 @@ where
let targets = get_proof_targets(&state, &HashMap::default());

let tx = self.tx.clone();
rayon::spawn(move || {
scope.spawn(move |_| {
let result = update_sparse_trie(trie, multiproof, targets, state);
match result {
Ok((trie, elapsed)) => {
Expand All @@ -398,7 +409,7 @@ where
});
}

fn run(mut self) -> StateRootResult {
fn run(mut self, scope: &rayon::Scope<'env>) -> StateRootResult {
let mut current_state_update = HashedPostState::default();
let mut current_multiproof = MultiProof::default();
let mut updates_received = 0;
Expand All @@ -418,6 +429,7 @@ where
"Received new state update"
);
Self::on_state_update(
scope,
self.config.consistent_view.clone(),
self.config.input.clone(),
update,
Expand Down Expand Up @@ -447,7 +459,11 @@ where
current_multiproof.extend(combined_proof);
current_state_update.extend(combined_state_update);
} else {
self.spawn_root_calculation(combined_state_update, combined_proof);
self.spawn_root_calculation(
scope,
combined_state_update,
combined_proof,
);
}
}
}
Expand Down Expand Up @@ -485,6 +501,7 @@ where
"Spawning subsequent root calculation"
);
self.spawn_root_calculation(
scope,
std::mem::take(&mut current_state_update),
std::mem::take(&mut current_multiproof),
);
Expand Down Expand Up @@ -560,9 +577,9 @@ fn get_proof_targets(
/// Updates the sparse trie with the given proofs and state, and returns the updated trie and the
/// time it took.
fn update_sparse_trie<
ABP: BlindedProvider<Error = SparseTrieError>,
SBP: BlindedProvider<Error = SparseTrieError>,
BPF: BlindedProviderFactory<AccountNodeProvider = ABP, StorageNodeProvider = SBP>,
ABP: BlindedProvider<Error = SparseTrieError> + Send + Sync,
SBP: BlindedProvider<Error = SparseTrieError> + Send + Sync,
BPF: BlindedProviderFactory<AccountNodeProvider = ABP, StorageNodeProvider = SBP> + Send + Sync,
>(
mut trie: Box<SparseStateTrie<BPF>>,
multiproof: MultiProof,
Expand Down Expand Up @@ -770,16 +787,18 @@ mod tests {
),
Arc::new(config.input.prefix_sets.clone()),
);
let task = StateRootTask::new(config, blinded_provider_factory);
let mut state_hook = task.state_hook();
let handle = task.spawn();
let (root_from_task, _) = std::thread::scope(|std_scope| {
let task = StateRootTask::new(config, blinded_provider_factory);
let mut state_hook = task.state_hook();
let handle = task.spawn(std_scope);

for update in state_updates {
state_hook.on_state(&update);
}
drop(state_hook);
for update in state_updates {
state_hook.on_state(&update);
}
drop(state_hook);

let (root_from_task, _) = handle.wait_for_result().expect("task failed");
handle.wait_for_result().expect("task failed")
});
let root_from_base = state_root(accumulated_state);

assert_eq!(
Expand Down

0 comments on commit 1ae7f81

Please sign in to comment.