Skip to content

Commit

Permalink
add test
Browse files Browse the repository at this point in the history
  • Loading branch information
shekhirin committed Jul 10, 2024
1 parent 47669ba commit 232c6ae
Showing 1 changed file with 47 additions and 38 deletions.
85 changes: 47 additions & 38 deletions crates/stages/stages/src/stages/prune.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,26 +68,24 @@ impl<DB: Database> Stage<DB> for PruneStage {
mod tests {
use super::*;
use crate::test_utils::{
stage_test_suite_ext, ExecuteStageTestRunner, StageTestRunner, TestRunnerError,
TestStageDB, UnwindStageTestRunner,
stage_test_suite_ext, ExecuteStageTestRunner, StageTestRunner, StorageKind,
TestRunnerError, TestStageDB, UnwindStageTestRunner,
};
use reth_primitives::SealedHeader;
use reth_provider::providers::StaticFileWriter;
use reth_testing_utils::{
generators,
generators::{random_header, random_header_range},
use reth_primitives::{SealedBlock, B256};
use reth_provider::{
providers::StaticFileWriter, TransactionsProvider, TransactionsProviderExt,
};
use reth_prune::PruneMode;
use reth_testing_utils::generators::{self, random_block_range};

stage_test_suite_ext!(FinishTestRunner, finish);
stage_test_suite_ext!(PruneTestRunner, prune);

#[derive(Default)]
struct FinishTestRunner {
struct PruneTestRunner {
db: TestStageDB,
prune_modes: PruneModes,
commit_threshold: usize,
}

impl StageTestRunner for FinishTestRunner {
impl StageTestRunner for PruneTestRunner {
type S = PruneStage;

fn db(&self) -> &TestStageDB {
Expand All @@ -96,32 +94,33 @@ mod tests {

fn stage(&self) -> Self::S {
PruneStage {
prune_modes: self.prune_modes.clone(),
commit_threshold: self.commit_threshold,
prune_modes: PruneModes {
sender_recovery: Some(PruneMode::Full),
..Default::default()
},
commit_threshold: usize::MAX,
}
}
}

impl ExecuteStageTestRunner for FinishTestRunner {
type Seed = Vec<SealedHeader>;
impl ExecuteStageTestRunner for PruneTestRunner {
type Seed = Vec<SealedBlock>;

fn seed_execution(&mut self, input: ExecInput) -> Result<Self::Seed, TestRunnerError> {
let start = input.checkpoint().block_number;
let mut rng = generators::rng();
let head = random_header(&mut rng, start, None);
self.db.insert_headers_with_td(std::iter::once(&head))?;

// use previous progress as seed size
let end = input.target.unwrap_or_default() + 1;

if start + 1 >= end {
return Ok(Vec::default())
}

let mut headers = random_header_range(&mut rng, start + 1..end, head.hash());
self.db.insert_headers_with_td(headers.iter())?;
headers.insert(0, head);
Ok(headers)
let blocks = random_block_range(
&mut rng,
input.checkpoint().block_number..=input.target(),
B256::ZERO,
1..3,
);
self.db.insert_blocks(blocks.iter(), StorageKind::Static)?;
self.db.insert_transaction_senders(
blocks.iter().flat_map(|block| block.body.iter()).enumerate().map(|(i, tx)| {
(i as u64, tx.recover_signer().expect("failed to recover signer"))
}),
)?;
Ok(blocks)
}

fn validate_execution(
Expand All @@ -130,18 +129,28 @@ mod tests {
output: Option<ExecOutput>,
) -> Result<(), TestRunnerError> {
if let Some(output) = output {
assert!(output.done, "stage should always be done");
assert_eq!(
output.checkpoint.block_number,
input.target(),
"stage progress should always match progress of previous stage"
);
let start_block = input.next_block();
let end_block = output.checkpoint.block_number;

if start_block > end_block {
return Ok(())
}

assert!(output.done);
assert_eq!(output.checkpoint.block_number, input.target());

// Verify that the senders are pruned
let provider = self.db.factory.provider()?;
let tx_range =
provider.transaction_range_by_block_range(start_block..=end_block)?;
let senders = self.db.factory.provider()?.senders_by_tx_range(tx_range)?;
assert!(senders.is_empty());
}
Ok(())
}
}

impl UnwindStageTestRunner for FinishTestRunner {
impl UnwindStageTestRunner for PruneTestRunner {
fn validate_unwind(&self, _input: UnwindInput) -> Result<(), TestRunnerError> {
Ok(())
}
Expand Down

0 comments on commit 232c6ae

Please sign in to comment.