Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(engine): integrate blinded provider factory into state root task #13294

Merged
merged 5 commits into from
Dec 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 35 additions & 13 deletions crates/engine/tree/benches/state_root_task.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,11 @@ use reth_provider::{
HashingWriter, ProviderFactory,
};
use reth_testing_utils::generators::{self, Rng};
use reth_trie::TrieInput;
use reth_trie::{
hashed_cursor::HashedPostStateCursorFactory, proof::ProofBlindedProviderFactory,
trie_cursor::InMemoryTrieCursorFactory, TrieInput,
};
use reth_trie_db::{DatabaseHashedCursorFactory, DatabaseTrieCursorFactory};
use revm_primitives::{
Account as RevmAccount, AccountInfo, AccountStatus, Address, EvmState, EvmStorageSlot, HashMap,
B256, KECCAK_EMPTY, U256,
Expand Down Expand Up @@ -139,20 +143,38 @@ fn bench_state_root(c: &mut Criterion) {
consistent_view: ConsistentDbView::new(factory, None),
input: trie_input,
};
let provider = config.consistent_view.provider_ro().unwrap();
let nodes_sorted = config.input.nodes.clone().into_sorted();
let state_sorted = config.input.state.clone().into_sorted();
let prefix_sets = Arc::new(config.input.prefix_sets.clone());

(config, state_updates)
(config, state_updates, provider, nodes_sorted, state_sorted, prefix_sets)
},
|(config, state_updates)| {
let task = StateRootTask::new(config);
let mut hook = task.state_hook();
let handle = task.spawn();

for update in state_updates {
hook.on_state(&update)
}
drop(hook);

black_box(handle.wait_for_result().expect("task failed"));
|(config, state_updates, provider, nodes_sorted, state_sorted, prefix_sets)| {
let blinded_provider_factory = ProofBlindedProviderFactory::new(
InMemoryTrieCursorFactory::new(
DatabaseTrieCursorFactory::new(provider.tx_ref()),
&nodes_sorted,
),
HashedPostStateCursorFactory::new(
DatabaseHashedCursorFactory::new(provider.tx_ref()),
&state_sorted,
),
prefix_sets,
);

black_box(std::thread::scope(|scope| {
let task = StateRootTask::new(config, blinded_provider_factory);
let mut hook = task.state_hook();
let handle = task.spawn(scope);

for update in state_updates {
hook.on_state(&update)
}
drop(hook);

handle.wait_for_result().expect("task failed")
}));
},
)
},
Expand Down
121 changes: 82 additions & 39 deletions crates/engine/tree/src/tree/root.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ use reth_trie::{
use reth_trie_db::DatabaseProof;
use reth_trie_parallel::root::ParallelStateRootError;
use reth_trie_sparse::{
errors::{SparseStateTrieResult, SparseTrieErrorKind},
blinded::{BlindedProvider, BlindedProviderFactory},
errors::{SparseStateTrieResult, SparseTrieError, SparseTrieErrorKind},
SparseStateTrie,
};
use revm_primitives::{keccak256, EvmState, B256};
Expand All @@ -25,6 +26,7 @@ use std::{
mpsc::{self, channel, Receiver, Sender},
Arc,
},
thread::{self},
time::{Duration, Instant},
};
use tracing::{debug, error, trace};
Expand Down Expand Up @@ -68,7 +70,7 @@ pub struct StateRootConfig<Factory> {
/// Messages used internally by the state root task
#[derive(Debug)]
#[allow(dead_code)]
pub enum StateRootMessage {
pub enum StateRootMessage<BPF: BlindedProviderFactory> {
/// New state update from transaction execution
StateUpdate(EvmState),
/// Proof calculation completed for a specific state update
Expand All @@ -83,7 +85,7 @@ pub enum StateRootMessage {
/// State root calculation completed
RootCalculated {
/// The updated sparse trie
trie: Box<SparseStateTrie>,
trie: Box<SparseStateTrie<BPF>>,
/// Time taken to calculate the root
elapsed: Duration,
},
Expand Down Expand Up @@ -159,24 +161,24 @@ impl ProofSequencer {

/// A wrapper for the sender that signals completion when dropped
#[allow(dead_code)]
pub(crate) struct StateHookSender(Sender<StateRootMessage>);
pub(crate) struct StateHookSender<BPF: BlindedProviderFactory>(Sender<StateRootMessage<BPF>>);

#[allow(dead_code)]
impl StateHookSender {
pub(crate) const fn new(inner: Sender<StateRootMessage>) -> Self {
impl<BPF: BlindedProviderFactory> StateHookSender<BPF> {
pub(crate) const fn new(inner: Sender<StateRootMessage<BPF>>) -> Self {
Self(inner)
}
}

impl Deref for StateHookSender {
type Target = Sender<StateRootMessage>;
impl<BPF: BlindedProviderFactory> Deref for StateHookSender<BPF> {
type Target = Sender<StateRootMessage<BPF>>;

fn deref(&self) -> &Self::Target {
&self.0
}
}

impl Drop for StateHookSender {
impl<BPF: BlindedProviderFactory> Drop for StateHookSender<BPF> {
fn drop(&mut self) {
// Send completion signal when the sender is dropped
let _ = self.0.send(StateRootMessage::FinishedStateUpdates);
Expand Down Expand Up @@ -224,34 +226,40 @@ fn evm_state_to_hashed_post_state(update: EvmState) -> HashedPostState {
/// to the tree.
/// Then it updates relevant leaves according to the result of the transaction.
#[derive(Debug)]
pub struct StateRootTask<Factory> {
pub struct StateRootTask<Factory, BPF: BlindedProviderFactory> {
/// Task configuration.
config: StateRootConfig<Factory>,
/// Receiver for state root related messages.
rx: Receiver<StateRootMessage>,
rx: Receiver<StateRootMessage<BPF>>,
/// Sender for state root related messages.
tx: Sender<StateRootMessage>,
tx: Sender<StateRootMessage<BPF>>,
/// Proof targets that have been already fetched.
fetched_proof_targets: MultiProofTargets,
/// Proof sequencing handler.
proof_sequencer: ProofSequencer,
/// The sparse trie used for the state root calculation. If [`None`], then update is in
/// progress.
sparse_trie: Option<Box<SparseStateTrie>>,
sparse_trie: Option<Box<SparseStateTrie<BPF>>>,
}

#[allow(dead_code)]
impl<Factory> StateRootTask<Factory>
impl<'env, Factory, ABP, SBP, BPF> StateRootTask<Factory, BPF>
where
Factory: DatabaseProviderFactory<Provider: BlockReader>
+ StateCommitmentProvider
+ Clone
+ Send
+ Sync
+ 'static,
ABP: BlindedProvider<Error = SparseTrieError> + Send + Sync + 'env,
SBP: BlindedProvider<Error = SparseTrieError> + Send + Sync + 'env,
BPF: BlindedProviderFactory<AccountNodeProvider = ABP, StorageNodeProvider = SBP>
+ Send
+ Sync
+ 'env,
{
/// Creates a new state root task with the unified message channel
pub fn new(config: StateRootConfig<Factory>) -> Self {
pub fn new(config: StateRootConfig<Factory>, blinded_provider: BPF) -> Self {
let (tx, rx) = channel();

Self {
Expand All @@ -260,18 +268,19 @@ where
tx,
fetched_proof_targets: Default::default(),
proof_sequencer: ProofSequencer::new(),
sparse_trie: Some(Box::new(SparseStateTrie::default().with_updates(true))),
sparse_trie: Some(Box::new(SparseStateTrie::new(blinded_provider).with_updates(true))),
}
}

/// Spawns the state root task and returns a handle to await its result.
pub fn spawn(self) -> StateRootHandle {
pub fn spawn<'scope>(self, scope: &'scope thread::Scope<'scope, 'env>) -> StateRootHandle {
let (tx, rx) = mpsc::sync_channel(1);
std::thread::Builder::new()
.name("State Root Task".to_string())
.spawn(move || {
.spawn_scoped(scope, move || {
debug!(target: "engine::tree", "Starting state root task");
let result = self.run();

let result = rayon::scope(|scope| self.run(scope));
let _ = tx.send(result);
})
.expect("failed to spawn state root thread");
Expand All @@ -294,12 +303,13 @@ where
///
/// Returns proof targets derived from the state update.
fn on_state_update(
scope: &rayon::Scope<'env>,
view: ConsistentDbView<Factory>,
input: Arc<TrieInput>,
update: EvmState,
fetched_proof_targets: &mut MultiProofTargets,
proof_sequence_number: u64,
state_root_message_sender: Sender<StateRootMessage>,
state_root_message_sender: Sender<StateRootMessage<BPF>>,
) {
let hashed_state_update = evm_state_to_hashed_post_state(update);

Expand All @@ -309,7 +319,7 @@ where
}

// Dispatch proof gathering for this state update
rayon::spawn(move || {
scope.spawn(move |_| {
let provider = match view.provider_ro() {
Ok(provider) => provider,
Err(error) => {
Expand Down Expand Up @@ -362,7 +372,12 @@ where
}

/// Spawns root calculation with the current state and proofs.
fn spawn_root_calculation(&mut self, state: HashedPostState, multiproof: MultiProof) {
fn spawn_root_calculation(
&mut self,
scope: &rayon::Scope<'env>,
state: HashedPostState,
multiproof: MultiProof,
) {
let Some(trie) = self.sparse_trie.take() else { return };

trace!(
Expand All @@ -376,7 +391,7 @@ where
let targets = get_proof_targets(&state, &HashMap::default());

let tx = self.tx.clone();
rayon::spawn(move || {
scope.spawn(move |_| {
let result = update_sparse_trie(trie, multiproof, targets, state);
match result {
Ok((trie, elapsed)) => {
Expand All @@ -394,7 +409,7 @@ where
});
}

fn run(mut self) -> StateRootResult {
fn run(mut self, scope: &rayon::Scope<'env>) -> StateRootResult {
let mut current_state_update = HashedPostState::default();
let mut current_multiproof = MultiProof::default();
let mut updates_received = 0;
Expand All @@ -414,6 +429,7 @@ where
"Received new state update"
);
Self::on_state_update(
scope,
self.config.consistent_view.clone(),
self.config.input.clone(),
update,
Expand Down Expand Up @@ -443,7 +459,11 @@ where
current_multiproof.extend(combined_proof);
current_state_update.extend(combined_state_update);
} else {
self.spawn_root_calculation(combined_state_update, combined_proof);
self.spawn_root_calculation(
scope,
combined_state_update,
combined_proof,
);
}
}
}
Expand Down Expand Up @@ -481,6 +501,7 @@ where
"Spawning subsequent root calculation"
);
self.spawn_root_calculation(
scope,
std::mem::take(&mut current_state_update),
std::mem::take(&mut current_multiproof),
);
Expand Down Expand Up @@ -555,12 +576,16 @@ fn get_proof_targets(

/// Updates the sparse trie with the given proofs and state, and returns the updated trie and the
/// time it took.
fn update_sparse_trie(
mut trie: Box<SparseStateTrie>,
fn update_sparse_trie<
ABP: BlindedProvider<Error = SparseTrieError> + Send + Sync,
SBP: BlindedProvider<Error = SparseTrieError> + Send + Sync,
BPF: BlindedProviderFactory<AccountNodeProvider = ABP, StorageNodeProvider = SBP> + Send + Sync,
>(
mut trie: Box<SparseStateTrie<BPF>>,
multiproof: MultiProof,
targets: MultiProofTargets,
state: HashedPostState,
) -> SparseStateTrieResult<(Box<SparseStateTrie>, Duration)> {
) -> SparseStateTrieResult<(Box<SparseStateTrie<BPF>>, Duration)> {
trace!(target: "engine::root::sparse", "Updating sparse trie");
let started_at = Instant::now();

Expand All @@ -582,12 +607,10 @@ fn update_sparse_trie(
trace!(target: "engine::root::sparse", ?address, "Wiping storage");
storage_trie.wipe()?;
}

for (slot, value) in storage.storage {
let slot_nibbles = Nibbles::unpack(slot);
if value.is_zero() {
trace!(target: "engine::root::sparse", ?address, ?slot, "Removing storage slot");

storage_trie.remove_leaf(&slot_nibbles)?;
} else {
trace!(target: "engine::root::sparse", ?address, ?slot, "Updating storage slot");
Expand Down Expand Up @@ -629,7 +652,11 @@ mod tests {
providers::ConsistentDbView, test_utils::create_test_provider_factory, HashingWriter,
};
use reth_testing_utils::generators::{self, Rng};
use reth_trie::{test_utils::state_root, TrieInput};
use reth_trie::{
hashed_cursor::HashedPostStateCursorFactory, proof::ProofBlindedProviderFactory,
test_utils::state_root, trie_cursor::InMemoryTrieCursorFactory, TrieInput,
};
use reth_trie_db::{DatabaseHashedCursorFactory, DatabaseTrieCursorFactory};
use revm_primitives::{
Account as RevmAccount, AccountInfo, AccountStatus, Address, EvmState, EvmStorageSlot,
HashMap, B256, KECCAK_EMPTY, U256,
Expand Down Expand Up @@ -746,16 +773,32 @@ mod tests {
consistent_view: ConsistentDbView::new(factory, None),
input: Arc::new(TrieInput::from_state(hashed_state)),
};
let task = StateRootTask::new(config);
let mut state_hook = task.state_hook();
let handle = task.spawn();
let provider = config.consistent_view.provider_ro().unwrap();
let nodes_sorted = config.input.nodes.clone().into_sorted();
let state_sorted = config.input.state.clone().into_sorted();
let blinded_provider_factory = ProofBlindedProviderFactory::new(
InMemoryTrieCursorFactory::new(
DatabaseTrieCursorFactory::new(provider.tx_ref()),
&nodes_sorted,
),
HashedPostStateCursorFactory::new(
DatabaseHashedCursorFactory::new(provider.tx_ref()),
&state_sorted,
),
Arc::new(config.input.prefix_sets.clone()),
);
let (root_from_task, _) = std::thread::scope(|std_scope| {
let task = StateRootTask::new(config, blinded_provider_factory);
let mut state_hook = task.state_hook();
let handle = task.spawn(std_scope);

for update in state_updates {
state_hook.on_state(&update);
}
drop(state_hook);
for update in state_updates {
state_hook.on_state(&update);
}
drop(state_hook);

let (root_from_task, _) = handle.wait_for_result().expect("task failed");
handle.wait_for_result().expect("task failed")
});
let root_from_base = state_root(accumulated_state);

assert_eq!(
Expand Down
Loading
Loading