From 07b926633212f4472dda5f177510022231a0b1cd Mon Sep 17 00:00:00 2001 From: Roman Krasiuk Date: Thu, 13 Oct 2022 15:18:55 +0300 Subject: [PATCH 01/13] headers stage scaffolding --- Cargo.lock | 4 + crates/interfaces/Cargo.toml | 1 + crates/interfaces/src/consensus.rs | 6 +- crates/interfaces/src/lib.rs | 3 + crates/interfaces/src/stages.rs | 32 ++++ crates/primitives/src/header.rs | 7 + crates/stages/Cargo.toml | 7 +- crates/stages/src/lib.rs | 1 + crates/stages/src/stages/headers.rs | 270 ++++++++++++++++++++++++++++ crates/stages/src/stages/mod.rs | 1 + 10 files changed, 328 insertions(+), 4 deletions(-) create mode 100644 crates/interfaces/src/stages.rs create mode 100644 crates/stages/src/stages/headers.rs create mode 100644 crates/stages/src/stages/mod.rs diff --git a/Cargo.lock b/Cargo.lock index a9f68173a820..a1af391c5cac 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1847,6 +1847,7 @@ version = "0.1.0" dependencies = [ "async-trait", "auto_impl", + "futures", "reth-primitives", "reth-rpc-types", "thiserror", @@ -1914,7 +1915,10 @@ name = "reth-stages" version = "0.1.0" dependencies = [ "async-trait", + "futures", + "rand", "reth-db", + "reth-interfaces", "reth-primitives", "tempfile", "thiserror", diff --git a/crates/interfaces/Cargo.toml b/crates/interfaces/Cargo.toml index e6d88f7da897..42da856618b4 100644 --- a/crates/interfaces/Cargo.toml +++ b/crates/interfaces/Cargo.toml @@ -13,3 +13,4 @@ async-trait = "0.1.57" thiserror = "1.0.37" auto_impl = "1.0" tokio = { version = "1.21.2", features = ["sync"] } +futures = "0.3" diff --git a/crates/interfaces/src/consensus.rs b/crates/interfaces/src/consensus.rs index d67d074af6e0..e6e3e5c47ff9 100644 --- a/crates/interfaces/src/consensus.rs +++ b/crates/interfaces/src/consensus.rs @@ -7,12 +7,12 @@ use tokio::sync::watch::Receiver; /// Consensus is a protocol that chooses canonical chain. /// We are checking validity of block header here. #[async_trait] -pub trait Consensus { +pub trait Consensus: Sync + Send { /// Get a receiver for the fork choice state - fn fork_choice_state(&self) -> Receiver; + fn forkchoice_state(&self) -> Receiver; /// Validate if header is correct and follows consensus specification - fn validate_header(&self, _header: &Header) -> Result<(), Error> { + fn validate_header(&self, _header: &Header, _parent: &Header) -> Result<(), Error> { Ok(()) } } diff --git a/crates/interfaces/src/lib.rs b/crates/interfaces/src/lib.rs index c7624b35de94..508551a40fb1 100644 --- a/crates/interfaces/src/lib.rs +++ b/crates/interfaces/src/lib.rs @@ -12,3 +12,6 @@ pub mod executor; /// Consensus traits. pub mod consensus; + +/// Stage sync related traits +pub mod stages; diff --git a/crates/interfaces/src/stages.rs b/crates/interfaces/src/stages.rs new file mode 100644 index 000000000000..9856dcf8ac00 --- /dev/null +++ b/crates/interfaces/src/stages.rs @@ -0,0 +1,32 @@ +use async_trait::async_trait; +use futures::Stream; +use reth_primitives::{rpc::BlockId, Header, H256, H512}; +use std::{collections::HashSet, pin::Pin}; + +/// The stream of messages +pub type MessageStream = Pin + Send>>; + +/// The header request struct +#[derive(Debug)] +pub struct HeaderRequest { + /// The starting block + pub start: BlockId, + /// The response max size + pub limit: u64, + /// Flag indicating whether the blocks should + /// arrive in reverse + pub reverse: bool, +} + +/// The block headers downloader client +#[async_trait] +pub trait HeadersClient: Send + Sync { + /// Update the current node status + async fn update_status(&mut self, height: u64, hash: H256, td: H256); + + /// Send the header request + async fn send_header_request(&self, id: u64, request: HeaderRequest) -> HashSet; + + /// Stream the header response messages + async fn stream_headers(&self) -> MessageStream<(u64, Vec
)>; +} diff --git a/crates/primitives/src/header.rs b/crates/primitives/src/header.rs index 5649e7b8e997..9f6b3bf604f7 100644 --- a/crates/primitives/src/header.rs +++ b/crates/primitives/src/header.rs @@ -98,6 +98,13 @@ impl Deref for HeaderLocked { } impl HeaderLocked { + /// Construct a new locked header. + /// Applicable when hash is known from + /// the database provided it's not corrupted. + pub fn new(header: Header, hash: H256) -> Self { + Self { header, hash } + } + /// Extract raw header that can be modified. pub fn unlock(self) -> Header { self.header diff --git a/crates/stages/Cargo.toml b/crates/stages/Cargo.toml index 23872f5a6b33..bde35d252c93 100644 --- a/crates/stages/Cargo.toml +++ b/crates/stages/Cargo.toml @@ -9,12 +9,17 @@ description = "Staged syncing primitives used in reth." [dependencies] reth-primitives = { path = "../primitives" } +reth-interfaces = { path = "../interfaces" } reth-db = { path = "../db" } -async-trait = "0.1.57" thiserror = "1.0.37" tracing = "0.1.36" tracing-futures = "0.2.5" tokio = { version = "1.21.2", features = ["sync"] } +rand = "0.8" # TODO: + +# async/futures +async-trait = "0.1.57" +futures = "0.3" [dev-dependencies] tokio = { version = "*", features = ["rt", "sync", "macros"] } diff --git a/crates/stages/src/lib.rs b/crates/stages/src/lib.rs index 7e4ce6525d76..3e57d1292367 100644 --- a/crates/stages/src/lib.rs +++ b/crates/stages/src/lib.rs @@ -12,6 +12,7 @@ mod error; mod id; mod pipeline; mod stage; +mod stages; mod util; pub use error::*; diff --git a/crates/stages/src/stages/headers.rs b/crates/stages/src/stages/headers.rs new file mode 100644 index 000000000000..0260c93aabe4 --- /dev/null +++ b/crates/stages/src/stages/headers.rs @@ -0,0 +1,270 @@ +use async_trait::async_trait; + +use crate::{ExecInput, ExecOutput, Stage, StageError, StageId, UnwindInput, UnwindOutput}; +use futures::StreamExt; +use rand::Rng; +use reth_db::{ + kv::{ + table::{Decode, Encode}, + tables, + tx::Tx, + }, + mdbx::{self, WriteFlags}, +}; +use reth_interfaces::{ + consensus::Consensus, + stages::{HeaderRequest, HeadersClient, MessageStream}, +}; +use reth_primitives::{rpc::BlockId, BlockNumber, Header, HeaderLocked, H256}; +use std::{sync::Arc, time::Duration}; +use thiserror::Error; +use tracing::*; + +const HEADERS: StageId = StageId("HEADERS"); + +// TODO: docs +// TODO: add tracing +pub struct HeaderStage { + pub consensus: Arc, + pub client: Arc, + pub batch_size: u64, + pub request_retries: usize, + pub request_timeout: usize, +} + +#[derive(Error, Debug)] +pub enum DownloadError { + /// Header validation failed + #[error("Failed to validate header {hash}. Details: {details}.")] + HeaderValidation { hash: H256, details: String }, + /// No headers reponse received + #[error("Failed to get headers for request {request_id}.")] + NoHeaderResponse { request_id: u64 }, + /// The stage encountered an internal error. + #[error(transparent)] + Internal(Box), +} + +impl DownloadError { + fn is_retryable(&self) -> bool { + matches!(self, DownloadError::NoHeaderResponse { .. }) + } +} + +#[async_trait] +impl<'db, E> Stage<'db, E> for HeaderStage +where + E: mdbx::EnvironmentKind, +{ + fn id(&self) -> StageId { + HEADERS + } + + /// Execute the stage. + async fn execute<'tx>( + &mut self, + tx: &mut Tx<'tx, mdbx::RW, E>, + input: ExecInput, + ) -> Result { + let last_block_num = + input.previous_stage.as_ref().map(|(_, block)| *block).unwrap_or_default(); + // TODO: check if in case of panic the node head needs to be updated + self.update_head(tx, last_block_num).await?; + + let mut stage_progress = last_block_num; + + // download the headers + // TODO: check if some upper block constraint is necessary + let last_hash: H256 = tx.get::(last_block_num)?.unwrap(); // TODO: + let last_header: Header = tx.get::((last_block_num, last_hash))?.unwrap(); // TODO: + let head = HeaderLocked::new(last_header, last_hash); + + let forkchoice_state = self.next_forkchoice_state(&head.hash()).await; + + let headers = match self.download(&head, forkchoice_state).await { + Ok(res) => res, + Err(e) => match e { + DownloadError::NoHeaderResponse { request_id } => { + warn!("no response for request {request_id}"); + return Ok(ExecOutput { stage_progress, reached_tip: false, done: false }) + } + DownloadError::HeaderValidation { hash, details } => { + warn!("validation error for header {hash}: {details}"); + return Err(StageError::Validation { block: last_block_num }) + } + DownloadError::Internal(e) => return Err(StageError::Internal(e)), + }, + }; + + let mut cursor_header_number = tx.cursor::()?; + let mut cursor_header = tx.cursor::()?; + let mut cursor_canonical = tx.cursor::()?; + let mut cursor_td = tx.cursor::()?; + let mut td = cursor_td.last()?.map(|((_, _), v)| v).unwrap(); // TODO: + + for header in headers { + if header.number == 0 { + continue + } + + let hash = header.hash(); + td += header.difficulty; + + cursor_header_number.put( + hash.to_fixed_bytes().to_vec(), + header.number, + Some(WriteFlags::APPEND), + )?; + cursor_header.put((header.number, hash), header, Some(WriteFlags::APPEND))?; + cursor_canonical.put(header.number, hash, Some(WriteFlags::APPEND))?; + cursor_td.put((header.number, hash), td, Some(WriteFlags::APPEND))?; + + stage_progress = header.number; + } + + Ok(ExecOutput { stage_progress, reached_tip: true, done: true }) + } + + /// Unwind the stage. + async fn unwind<'tx>( + &mut self, + tx: &mut Tx<'tx, mdbx::RW, E>, + input: UnwindInput, + ) -> Result> { + if let Some(bad_block) = input.bad_block { + todo!() + } + + todo!() + } +} + +impl HeaderStage { + async fn update_head<'tx, E: mdbx::EnvironmentKind>( + &self, + tx: &'tx mut Tx<'tx, mdbx::RW, E>, + height: BlockNumber, + ) -> Result<(), StageError> { + let hash = tx.get::(height)?.unwrap().decode(); + let td: Vec = tx.get::((height, hash))?.unwrap(); + self.client.update_status(height, hash, H256::from_slice(&td)); + Ok(()) + } + + async fn next_forkchoice_state(&self, head: &H256) -> (H256, H256) { + let mut state_rcv = self.consensus.forkchoice_state(); + loop { + state_rcv.changed().await; + let forkchoice = state_rcv.borrow(); + if !forkchoice.head_block_hash.is_zero() && forkchoice.head_block_hash != *head { + return (forkchoice.head_block_hash, forkchoice.finalized_block_hash) + } + } + } + + /// Download headers in batches with retries. + /// Returns the header collection in sorted ascending order + async fn download( + &self, + head: &HeaderLocked, + forkchoice_state: (H256, H256), + ) -> Result, DownloadError> { + let mut stream = self.client.stream_headers().await; + // the header order will be preserved during inserts + let mut retries = self.request_retries; + + let mut out = Vec::::new(); + loop { + match self.download_batch(head, &forkchoice_state, &mut stream, &mut out).await { + Ok(done) => { + if done { + return Ok(out) + } + } + Err(e) if e.is_retryable() && retries > 0 => { + retries -= 1; + } + Err(e) => return Err(e), + } + } + } + + /// Request and process the batch of headers + async fn download_batch( + &self, + head: &HeaderLocked, + (state_tip, state_finalized): &(H256, H256), + stream: &mut MessageStream<(u64, Vec
)>, + out: &mut Vec, + ) -> Result { + let request_id = rand::thread_rng().gen(); + let start = BlockId::Hash(out.first().map_or(state_tip.clone(), |h| h.parent_hash)); + let request = HeaderRequest { start, limit: self.batch_size, reverse: true }; + // TODO: timeout + let _ = self.client.send_header_request(request_id, request).await; + + let mut batch = self.receive_headers(stream, request_id).await?; + + out.reserve_exact(batch.len()); + batch.sort_unstable_by_key(|h| h.number); // TODO: revise: par_sort? + + let mut batch_iter = batch.into_iter().rev(); + while let Some(parent) = batch_iter.next() { + let parent = parent.lock(); + + if head.hash() == parent.hash() { + // we are done + return Ok(true) + } + + if let Some(tail_header) = out.first() { + if !(parent.hash() == tail_header.parent_hash && + parent.number + 1 == tail_header.number) + { + // cannot attach to the current buffer + // discard this batch + return Ok(false) + } + + self.consensus.validate_header(&tail_header, &parent).map_err(|e| { + DownloadError::HeaderValidation { hash: parent.hash(), details: e.to_string() } + })?; + } else if parent.hash() != *state_tip { + // the buffer is empty and the first header + // does not match the one we requested + // discard this batch + return Ok(false) + } + + out.insert(0, parent); + } + + Ok(false) + } + + /// Process header message stream and return the request by id. + /// The messages with empty headers are ignored. + async fn receive_headers( + &self, + stream: &mut MessageStream<(u64, Vec
)>, + request_id: u64, + ) -> Result, DownloadError> { + let timeout = tokio::time::sleep(Duration::from_secs(5)); + tokio::pin!(timeout); + let result = loop { + tokio::select! { + msg = stream.next() => { + match msg { + Some((id, headers)) if request_id == id && !headers.is_empty() => break Some(headers), + _ => (), + } + } + _ = &mut timeout => { + break None; + } + } + }; + + result.ok_or(DownloadError::NoHeaderResponse { request_id }) + } +} diff --git a/crates/stages/src/stages/mod.rs b/crates/stages/src/stages/mod.rs new file mode 100644 index 000000000000..0b97f4357b4a --- /dev/null +++ b/crates/stages/src/stages/mod.rs @@ -0,0 +1 @@ +mod headers; From 8284cc327b49a4e5505238332873b7ff8f66b4cc Mon Sep 17 00:00:00 2001 From: Roman Krasiuk Date: Sun, 16 Oct 2022 18:33:48 +0300 Subject: [PATCH 02/13] refactor to stream based approach, add some tests & docs --- Cargo.lock | 13 +- crates/db/src/kv/cursor.rs | 5 + crates/interfaces/src/consensus.rs | 8 +- crates/interfaces/src/stages.rs | 6 +- crates/primitives/Cargo.toml | 1 + crates/primitives/src/lib.rs | 3 + crates/stages/Cargo.toml | 5 +- crates/stages/src/lib.rs | 1 + crates/stages/src/stages/headers.rs | 446 ++++++++++++++++++++++------ crates/stages/src/stages/mod.rs | 3 +- 10 files changed, 387 insertions(+), 104 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index a1af391c5cac..dca7f6b9fff5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -53,6 +53,12 @@ version = "0.7.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8da52d66c7071e2e3fa2a1e5c6d088fec47b593032b254f5e980de8ea54454d6" +[[package]] +name = "assert_matches" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b34d609dfbaf33d6889b2b7106d3ca345eacad44200913df5ba02bfd31d2ba9" + [[package]] name = "async-lock" version = "2.5.0" @@ -1873,6 +1879,7 @@ name = "reth-primitives" version = "0.1.0" dependencies = [ "bytes", + "ethereum-types", "ethers-core", "fastrlp", "serde", @@ -1914,12 +1921,14 @@ dependencies = [ name = "reth-stages" version = "0.1.0" dependencies = [ + "assert_matches", "async-trait", "futures", "rand", "reth-db", "reth-interfaces", "reth-primitives", + "reth-rpc-types", "tempfile", "thiserror", "tokio", @@ -2518,9 +2527,9 @@ dependencies = [ [[package]] name = "tokio-stream" -version = "0.1.10" +version = "0.1.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f6edf2d6bc038a43d31353570e27270603f4648d18f5ed10c0e179abe43255af" +checksum = "d660770404473ccd7bc9f8b28494a811bc18542b915c0855c51e8f419d5223ce" dependencies = [ "futures-core", "pin-project-lite", diff --git a/crates/db/src/kv/cursor.rs b/crates/db/src/kv/cursor.rs index 0aaf66b176bf..b11f45f72be6 100644 --- a/crates/db/src/kv/cursor.rs +++ b/crates/db/src/kv/cursor.rs @@ -115,6 +115,11 @@ impl<'tx, T: Table> Cursor<'tx, RW, T> { .put(k.encode().as_ref(), v.encode().as_ref(), f.unwrap_or_default()) .map_err(KVError::Put) } + + /// Deletes the current `(key, value)` entry on `table` that the cursor is positioned at. + pub fn delete(&mut self) -> Result<(), KVError> { + self.inner.del(WriteFlags::CURRENT).map_err(KVError::Delete) + } } impl<'txn, K, T> Cursor<'txn, K, T> diff --git a/crates/interfaces/src/consensus.rs b/crates/interfaces/src/consensus.rs index e6e3e5c47ff9..0e01b97565da 100644 --- a/crates/interfaces/src/consensus.rs +++ b/crates/interfaces/src/consensus.rs @@ -1,16 +1,20 @@ use async_trait::async_trait; -use reth_primitives::Header; +use reth_primitives::{Header, H256}; use reth_rpc_types::engine::ForkchoiceState; +use std::fmt::Debug; use thiserror::Error; use tokio::sync::watch::Receiver; /// Consensus is a protocol that chooses canonical chain. /// We are checking validity of block header here. #[async_trait] -pub trait Consensus: Sync + Send { +pub trait Consensus: Sync + Send + Debug { /// Get a receiver for the fork choice state fn forkchoice_state(&self) -> Receiver; + /// Return the current chain tip + fn tip(&self) -> H256; + /// Validate if header is correct and follows consensus specification fn validate_header(&self, _header: &Header, _parent: &Header) -> Result<(), Error> { Ok(()) diff --git a/crates/interfaces/src/stages.rs b/crates/interfaces/src/stages.rs index 9856dcf8ac00..60d9d54766c0 100644 --- a/crates/interfaces/src/stages.rs +++ b/crates/interfaces/src/stages.rs @@ -1,7 +1,7 @@ use async_trait::async_trait; use futures::Stream; use reth_primitives::{rpc::BlockId, Header, H256, H512}; -use std::{collections::HashSet, pin::Pin}; +use std::{collections::HashSet, fmt::Debug, pin::Pin}; /// The stream of messages pub type MessageStream = Pin + Send>>; @@ -20,9 +20,9 @@ pub struct HeaderRequest { /// The block headers downloader client #[async_trait] -pub trait HeadersClient: Send + Sync { +pub trait HeadersClient: Send + Sync + Debug { /// Update the current node status - async fn update_status(&mut self, height: u64, hash: H256, td: H256); + async fn update_status(&self, height: u64, hash: H256, td: H256); /// Send the header request async fn send_header_request(&self, id: u64, request: HeaderRequest) -> HashSet; diff --git a/crates/primitives/Cargo.toml b/crates/primitives/Cargo.toml index d1b7da9ff56c..a21dd8ec1666 100644 --- a/crates/primitives/Cargo.toml +++ b/crates/primitives/Cargo.toml @@ -12,6 +12,7 @@ fastrlp = { version = "0.1.3" } ethers-core = { git = "https://github.com/gakonst/ethers-rs", default-features = false } bytes = "1.2" serde = "1.0" +ethereum-types = { version = "0.13.1", default-features = false } [dev-dependencies] serde_json = "1.0" \ No newline at end of file diff --git a/crates/primitives/src/lib.rs b/crates/primitives/src/lib.rs index 37c5f83a21a9..92c996ac2659 100644 --- a/crates/primitives/src/lib.rs +++ b/crates/primitives/src/lib.rs @@ -45,3 +45,6 @@ pub use ethers_core::{ types as rpc, types::{Bloom, Bytes, H160, H256, H512, H64, U256, U64}, }; + +// For uint to hash conversion +pub use ethereum_types::BigEndianHash; diff --git a/crates/stages/Cargo.toml b/crates/stages/Cargo.toml index bde35d252c93..7bf141b4c559 100644 --- a/crates/stages/Cargo.toml +++ b/crates/stages/Cargo.toml @@ -16,6 +16,7 @@ tracing = "0.1.36" tracing-futures = "0.2.5" tokio = { version = "1.21.2", features = ["sync"] } rand = "0.8" # TODO: +tokio-stream = "0.1.11" # async/futures async-trait = "0.1.57" @@ -25,4 +26,6 @@ futures = "0.3" tokio = { version = "*", features = ["rt", "sync", "macros"] } tokio-stream = "0.1.10" tempfile = "3.3.0" -reth-db = { path = "../db", features = ["test-utils"] } \ No newline at end of file +reth-db = { path = "../db", features = ["test-utils"] } +reth-rpc-types = { path = "../net/rpc-types" } +assert_matches = "1.5.0" \ No newline at end of file diff --git a/crates/stages/src/lib.rs b/crates/stages/src/lib.rs index 3e57d1292367..5217cce25e2e 100644 --- a/crates/stages/src/lib.rs +++ b/crates/stages/src/lib.rs @@ -19,3 +19,4 @@ pub use error::*; pub use id::*; pub use pipeline::*; pub use stage::*; +pub use stages::*; diff --git a/crates/stages/src/stages/headers.rs b/crates/stages/src/stages/headers.rs index 0260c93aabe4..00dafe3a711f 100644 --- a/crates/stages/src/stages/headers.rs +++ b/crates/stages/src/stages/headers.rs @@ -1,45 +1,57 @@ -use async_trait::async_trait; - use crate::{ExecInput, ExecOutput, Stage, StageError, StageId, UnwindInput, UnwindOutput}; -use futures::StreamExt; +use async_trait::async_trait; use rand::Rng; use reth_db::{ - kv::{ - table::{Decode, Encode}, - tables, - tx::Tx, - }, + kv::{table::Encode, tables, tx::Tx}, mdbx::{self, WriteFlags}, }; use reth_interfaces::{ consensus::Consensus, stages::{HeaderRequest, HeadersClient, MessageStream}, }; -use reth_primitives::{rpc::BlockId, BlockNumber, Header, HeaderLocked, H256}; +use reth_primitives::{rpc::BlockId, BigEndianHash, BlockNumber, Header, HeaderLocked, H256, U256}; use std::{sync::Arc, time::Duration}; use thiserror::Error; +use tokio_stream::StreamExt; use tracing::*; const HEADERS: StageId = StageId("HEADERS"); // TODO: docs // TODO: add tracing + +/// The headers stage implementation for staged sync +#[derive(Debug)] pub struct HeaderStage { + /// Consensus client implementation pub consensus: Arc, + /// Downloader client implementation pub client: Arc, + /// The batch size per one request pub batch_size: u64, + /// A single request timeout + pub request_timeout: u64, + /// The number of retries for downloadign pub request_retries: usize, - pub request_timeout: usize, } +/// The downloader error type #[derive(Error, Debug)] pub enum DownloadError { /// Header validation failed #[error("Failed to validate header {hash}. Details: {details}.")] - HeaderValidation { hash: H256, details: String }, + HeaderValidation { + /// Hash of header failing validation + hash: H256, + /// The details of validation failure + details: String, + }, /// No headers reponse received #[error("Failed to get headers for request {request_id}.")] - NoHeaderResponse { request_id: u64 }, + NoHeaderResponse { + /// The last request ID + request_id: u64, + }, /// The stage encountered an internal error. #[error(transparent)] Internal(Box), @@ -56,27 +68,33 @@ impl<'db, E> Stage<'db, E> for HeaderStage where E: mdbx::EnvironmentKind, { + /// Return the id of the stage fn id(&self) -> StageId { HEADERS } - /// Execute the stage. + /// Download the headers in reverse order + /// starting from the tip async fn execute<'tx>( &mut self, tx: &mut Tx<'tx, mdbx::RW, E>, input: ExecInput, - ) -> Result { + ) -> Result + where + 'db: 'tx, + { let last_block_num = input.previous_stage.as_ref().map(|(_, block)| *block).unwrap_or_default(); // TODO: check if in case of panic the node head needs to be updated self.update_head(tx, last_block_num).await?; - let mut stage_progress = last_block_num; - // download the headers // TODO: check if some upper block constraint is necessary - let last_hash: H256 = tx.get::(last_block_num)?.unwrap(); // TODO: - let last_header: Header = tx.get::((last_block_num, last_hash))?.unwrap(); // TODO: + let last_hash = + H256::from_uint(&tx.get::(last_block_num)?.unwrap()); + let last_header: Header = temp::decode_header( + tx.get::(temp::num_hash_to_key(last_block_num, last_hash))?.unwrap(), + ); let head = HeaderLocked::new(last_header, last_hash); let forkchoice_state = self.next_forkchoice_state(&head.hash()).await; @@ -85,8 +103,12 @@ where Ok(res) => res, Err(e) => match e { DownloadError::NoHeaderResponse { request_id } => { - warn!("no response for request {request_id}"); - return Ok(ExecOutput { stage_progress, reached_tip: false, done: false }) + warn!("no response for header request {request_id}"); + return Ok(ExecOutput { + stage_progress: last_block_num, + reached_tip: false, + done: false, + }) } DownloadError::HeaderValidation { hash, details } => { warn!("validation error for header {hash}: {details}"); @@ -96,32 +118,7 @@ where }, }; - let mut cursor_header_number = tx.cursor::()?; - let mut cursor_header = tx.cursor::()?; - let mut cursor_canonical = tx.cursor::()?; - let mut cursor_td = tx.cursor::()?; - let mut td = cursor_td.last()?.map(|((_, _), v)| v).unwrap(); // TODO: - - for header in headers { - if header.number == 0 { - continue - } - - let hash = header.hash(); - td += header.difficulty; - - cursor_header_number.put( - hash.to_fixed_bytes().to_vec(), - header.number, - Some(WriteFlags::APPEND), - )?; - cursor_header.put((header.number, hash), header, Some(WriteFlags::APPEND))?; - cursor_canonical.put(header.number, hash, Some(WriteFlags::APPEND))?; - cursor_td.put((header.number, hash), td, Some(WriteFlags::APPEND))?; - - stage_progress = header.number; - } - + let stage_progress = self.write_headers(tx, headers).await?; Ok(ExecOutput { stage_progress, reached_tip: true, done: true }) } @@ -135,29 +132,67 @@ where todo!() } - todo!() + let mut walker = tx.cursor::()?.walk(input.unwind_to + 1)?; + while let Some((_, hash)) = walker.next().transpose()? { + tx.delete::(hash.encode().to_vec(), None)?; + } + + // TODO: cleanup + let mut cur = tx.cursor::()?; + let mut entry = cur.last()?; + while let Some((key, _)) = entry { + let (num, _) = temp::num_hash_from_key(key); + if num <= input.unwind_to { + break + } + cur.delete()?; + entry = cur.prev()?; + } + + let mut cur = tx.cursor::()?; + let mut entry = cur.last()?; + while let Some((block_num, _)) = entry { + if block_num <= input.unwind_to { + break + } + cur.delete()?; + entry = cur.prev()?; + } + + let mut cur = tx.cursor::()?; + let mut entry = cur.last()?; + while let Some((key, _)) = entry { + let (num, _) = temp::num_hash_from_key(key); + if num <= input.unwind_to { + break + } + cur.delete()?; + entry = cur.prev()?; + } + + Ok(UnwindOutput { stage_progress: input.unwind_to }) } } impl HeaderStage { async fn update_head<'tx, E: mdbx::EnvironmentKind>( &self, - tx: &'tx mut Tx<'tx, mdbx::RW, E>, + tx: &'tx mut Tx<'_, mdbx::RW, E>, height: BlockNumber, ) -> Result<(), StageError> { - let hash = tx.get::(height)?.unwrap().decode(); - let td: Vec = tx.get::((height, hash))?.unwrap(); - self.client.update_status(height, hash, H256::from_slice(&td)); + let hash = H256::from_uint(&tx.get::(height)?.unwrap()); + let td: Vec = tx.get::(temp::num_hash_to_key(height, hash))?.unwrap(); + self.client.update_status(height, hash, H256::from_slice(&td)).await; Ok(()) } - async fn next_forkchoice_state(&self, head: &H256) -> (H256, H256) { + async fn next_forkchoice_state(&self, head: &H256) -> H256 { let mut state_rcv = self.consensus.forkchoice_state(); loop { - state_rcv.changed().await; + let _ = state_rcv.changed().await; let forkchoice = state_rcv.borrow(); if !forkchoice.head_block_hash.is_zero() && forkchoice.head_block_hash != *head { - return (forkchoice.head_block_hash, forkchoice.finalized_block_hash) + return forkchoice.head_block_hash } } } @@ -167,15 +202,15 @@ impl HeaderStage { async fn download( &self, head: &HeaderLocked, - forkchoice_state: (H256, H256), + tip: H256, ) -> Result, DownloadError> { let mut stream = self.client.stream_headers().await; - // the header order will be preserved during inserts + // Header order will be preserved during inserts let mut retries = self.request_retries; let mut out = Vec::::new(); loop { - match self.download_batch(head, &forkchoice_state, &mut stream, &mut out).await { + match self.download_batch(head, tip, &mut stream, &mut out).await { Ok(done) => { if done { return Ok(out) @@ -189,31 +224,42 @@ impl HeaderStage { } } - /// Request and process the batch of headers async fn download_batch( &self, head: &HeaderLocked, - (state_tip, state_finalized): &(H256, H256), + chain_tip: H256, stream: &mut MessageStream<(u64, Vec
)>, out: &mut Vec, ) -> Result { - let request_id = rand::thread_rng().gen(); - let start = BlockId::Hash(out.first().map_or(state_tip.clone(), |h| h.parent_hash)); - let request = HeaderRequest { start, limit: self.batch_size, reverse: true }; - // TODO: timeout - let _ = self.client.send_header_request(request_id, request).await; - - let mut batch = self.receive_headers(stream, request_id).await?; - - out.reserve_exact(batch.len()); - batch.sort_unstable_by_key(|h| h.number); // TODO: revise: par_sort? + // Request headers starting from tip or earliest cached + let start = out.first().map_or(chain_tip, |h| h.parent_hash); + let request_id = self.request_headers(start).await; + + // Filter stream by request id and non empty headers content + let stream = stream.filter(|(id, headers)| request_id == *id && !headers.is_empty()); + + // Wrap the stream with a timeout + let stream = stream.timeout(Duration::from_secs(self.request_timeout)); + + // Unwrap the latest stream message which will be either + // the msg with headers or timeout error + let headers = { + let mut h = match Box::pin(stream).try_next().await { + Ok(Some((_, h))) => h, + _ => return Err(DownloadError::NoHeaderResponse { request_id }), + }; + h.sort_unstable_by_key(|h| h.number); + h + }; - let mut batch_iter = batch.into_iter().rev(); - while let Some(parent) = batch_iter.next() { + // Iterate the headers in reverse + out.reserve_exact(headers.len()); + let mut headers_rev = headers.into_iter().rev(); + while let Some(parent) = headers_rev.next() { let parent = parent.lock(); if head.hash() == parent.hash() { - // we are done + // We've reached the target return Ok(true) } @@ -221,7 +267,7 @@ impl HeaderStage { if !(parent.hash() == tail_header.parent_hash && parent.number + 1 == tail_header.number) { - // cannot attach to the current buffer + // Cannot attach to the current buffer, // discard this batch return Ok(false) } @@ -229,10 +275,11 @@ impl HeaderStage { self.consensus.validate_header(&tail_header, &parent).map_err(|e| { DownloadError::HeaderValidation { hash: parent.hash(), details: e.to_string() } })?; - } else if parent.hash() != *state_tip { - // the buffer is empty and the first header + } else if parent.hash() != chain_tip { + // The buffer is empty and the first header // does not match the one we requested // discard this batch + // TODO: penalize the peer? return Ok(false) } @@ -242,29 +289,238 @@ impl HeaderStage { Ok(false) } - /// Process header message stream and return the request by id. - /// The messages with empty headers are ignored. - async fn receive_headers( + /// Perform a header request. Return the request ID + async fn request_headers(&self, start: H256) -> u64 { + let request_id = rand::thread_rng().gen(); + let request = + HeaderRequest { start: BlockId::Hash(start), limit: self.batch_size, reverse: true }; + let _ = self.client.send_header_request(request_id, request).await; + request_id + } + + /// Write downloaded headers to the database + async fn write_headers<'tx, E: mdbx::EnvironmentKind>( &self, - stream: &mut MessageStream<(u64, Vec
)>, - request_id: u64, - ) -> Result, DownloadError> { - let timeout = tokio::time::sleep(Duration::from_secs(5)); - tokio::pin!(timeout); - let result = loop { - tokio::select! { - msg = stream.next() => { - match msg { - Some((id, headers)) if request_id == id && !headers.is_empty() => break Some(headers), - _ => (), - } - } - _ = &mut timeout => { - break None; - } + tx: &'tx mut Tx<'_, mdbx::RW, E>, + headers: Vec, + ) -> Result { + let mut cursor_header_number = tx.cursor::()?; + let mut cursor_header = tx.cursor::()?; + let mut cursor_canonical = tx.cursor::()?; + let mut cursor_td = tx.cursor::()?; + let mut td = U256::from_big_endian(&cursor_td.last()?.map(|(_, v)| v).unwrap()); + + // TODO: comment + let mut latest = 0; + for header in headers { + if header.number == 0 { + continue } + + let hash = header.hash(); + let number = header.number; + let num_hash_key = temp::num_hash_to_key(header.number, hash); + + td += header.difficulty; + + cursor_header_number.put(hash.to_fixed_bytes().to_vec(), header.number, None)?; + cursor_header.put( + num_hash_key.clone(), + temp::encode_header(header.unlock()), + Some(WriteFlags::APPEND), + )?; + cursor_canonical.put(number, hash.into_uint(), Some(WriteFlags::APPEND))?; + cursor_td.put( + num_hash_key, + H256::from_uint(&td).as_bytes().to_vec(), + Some(WriteFlags::APPEND), + )?; + + latest = number; + } + + Ok(latest) + } +} + +// TODO: remove +mod temp { + use super::*; + + pub(crate) fn num_hash_to_key(number: BlockNumber, hash: H256) -> Vec { + let mut key = number.to_be_bytes().to_vec(); + key.extend(hash.0); + key + } + + pub(crate) fn num_hash_from_key(key: Vec) -> (BlockNumber, H256) { + todo!() + } + + pub(crate) fn encode_header(_header: Header) -> Vec { + todo!() + } + + pub(crate) fn decode_header(_bytes: Vec) -> Header { + todo!() + } +} + +#[cfg(test)] +mod tests { + use super::{DownloadError, HeaderStage}; + use assert_matches::assert_matches; + use reth_interfaces::stages::{HeaderRequest, MessageStream}; + use reth_primitives::{HeaderLocked, H256}; + use std::sync::Arc; + use tokio::sync::mpsc::{channel, Sender}; + use tokio_stream::{pending, wrappers::ReceiverStream, StreamExt}; + + fn setup_stage( + tx: Sender<(u64, HeaderRequest)>, + batch_size: u64, + request_timeout: u64, + request_retries: usize, + ) -> HeaderStage { + let client = utils::TestHeaderClient::new(tx); + let consensus = utils::TestConsensus::new(); + HeaderStage { + consensus: Arc::new(consensus), + client: Arc::new(client), + batch_size, + request_retries, + request_timeout, + } + } + + #[tokio::test] + async fn download_batch_timeout() { + let (tx, rx) = channel(1); + let (req_tx, _req_rx) = channel(1); + let (batch, timeout, retries) = (1, 1, 1); + let stage = setup_stage(req_tx, batch, timeout, retries); + + let mut stream = Box::pin(pending()) as MessageStream; + tokio::spawn(async move { + let result = stage + .download_batch(&HeaderLocked::default(), H256::zero(), &mut stream, &mut vec![]) + .await; + tx.send(result).await.unwrap(); + }); + + assert_matches!( + *ReceiverStream::new(rx).collect::>>().await, + [Err(DownloadError::NoHeaderResponse { .. })] + ); + } + + #[tokio::test] + async fn download_batch_timeout_on_invalid_messages() { + let (tx, rx) = channel(1); + let (req_tx, req_rx) = channel(3); + let (res_tx, res_rx) = channel(3); + + let (batch, timeout, retries) = (1, 5, 3); + let stage = setup_stage(req_tx, batch, timeout, retries); + + let mut stream = + Box::pin(ReceiverStream::new(res_rx)) as MessageStream; + tokio::spawn(async move { + let result = stage + .download_batch(&HeaderLocked::default(), H256::zero(), &mut stream, &mut vec![]) + .await; + tx.send(result).await.unwrap(); + }); + + let mut last_req_id = None; + let mut req_stream = ReceiverStream::new(req_rx); + while let Some((id, _req)) = req_stream.next().await { + // Since the receiving channel filters by id and message length - + // randomize the input to the tested filter + res_tx.send((id.saturating_add(id % 2), vec![])).await.unwrap(); + last_req_id = Some(id); + } + + assert_matches!( + *ReceiverStream::new(rx).collect::>>().await, + [Err(DownloadError::NoHeaderResponse { request_id })] if request_id == last_req_id.unwrap() + ); + } + + mod utils { + use async_trait::async_trait; + use reth_interfaces::{ + consensus::{self, Consensus}, + stages::{HeaderRequest, HeadersClient, MessageStream}, }; + use reth_primitives::{Header, H256, H512}; + use reth_rpc_types::engine::ForkchoiceState; + use std::collections::HashSet; + use tokio::sync::{mpsc::Sender, watch}; + + pub(crate) type HeaderResponse = (u64, Vec
); + + #[derive(Debug)] + pub(crate) struct TestHeaderClient { + tx: Sender<(u64, HeaderRequest)>, + } + + impl TestHeaderClient { + /// Construct a new test header downloader. + /// `tx` is the + pub(crate) fn new(tx: Sender<(u64, HeaderRequest)>) -> Self { + Self { tx } + } + } + + #[async_trait] + impl HeadersClient for TestHeaderClient { + async fn update_status(&self, _height: u64, _hash: H256, _td: H256) {} + + async fn send_header_request(&self, id: u64, request: HeaderRequest) -> HashSet { + self.tx.send((id, request)).await.unwrap(); + HashSet::default() + } + + async fn stream_headers(&self) -> MessageStream<(u64, Vec
)> { + todo!() + } + } + + /// Consensus client impl for testing + #[derive(Debug)] + pub(crate) struct TestConsensus { + chain_tip: H256, + } - result.ok_or(DownloadError::NoHeaderResponse { request_id }) + impl TestConsensus { + pub(crate) fn new() -> Self { + Self { chain_tip: H256::zero() } + } + + /// Set the chain tip + pub(crate) fn set_chain_tip(&mut self, tip: H256) { + self.chain_tip = tip; + } + } + + #[async_trait] + impl Consensus for TestConsensus { + fn forkchoice_state(&self) -> watch::Receiver { + todo!() + } + + fn tip(&self) -> H256 { + self.chain_tip + } + + fn validate_header( + &self, + _header: &Header, + _parent: &Header, + ) -> Result<(), consensus::Error> { + Ok(()) + } + } } } diff --git a/crates/stages/src/stages/mod.rs b/crates/stages/src/stages/mod.rs index 0b97f4357b4a..b4f2022b7d28 100644 --- a/crates/stages/src/stages/mod.rs +++ b/crates/stages/src/stages/mod.rs @@ -1 +1,2 @@ -mod headers; +/// The headers stage implementation +pub mod headers; From 23b85e12b2bbcbc824ffc543dd785c0405ed4f11 Mon Sep 17 00:00:00 2001 From: Roman Krasiuk Date: Sun, 16 Oct 2022 21:39:29 +0300 Subject: [PATCH 03/13] add consensus propagation test --- crates/stages/src/stages/headers.rs | 127 +++++++++++++++++++++------- 1 file changed, 95 insertions(+), 32 deletions(-) diff --git a/crates/stages/src/stages/headers.rs b/crates/stages/src/stages/headers.rs index 00dafe3a711f..27c192c51114 100644 --- a/crates/stages/src/stages/headers.rs +++ b/crates/stages/src/stages/headers.rs @@ -370,35 +370,25 @@ mod temp { mod tests { use super::{DownloadError, HeaderStage}; use assert_matches::assert_matches; + use rand::{self, Rng}; use reth_interfaces::stages::{HeaderRequest, MessageStream}; - use reth_primitives::{HeaderLocked, H256}; + use reth_primitives::{rpc::BlockId, Header, HeaderLocked, H256}; use std::sync::Arc; - use tokio::sync::mpsc::{channel, Sender}; + use tokio::sync::mpsc::channel; use tokio_stream::{pending, wrappers::ReceiverStream, StreamExt}; - fn setup_stage( - tx: Sender<(u64, HeaderRequest)>, - batch_size: u64, - request_timeout: u64, - request_retries: usize, - ) -> HeaderStage { - let client = utils::TestHeaderClient::new(tx); - let consensus = utils::TestConsensus::new(); - HeaderStage { - consensus: Arc::new(consensus), - client: Arc::new(client), - batch_size, - request_retries, - request_timeout, - } - } - #[tokio::test] async fn download_batch_timeout() { let (tx, rx) = channel(1); let (req_tx, _req_rx) = channel(1); - let (batch, timeout, retries) = (1, 1, 1); - let stage = setup_stage(req_tx, batch, timeout, retries); + let (batch_size, request_retries, request_timeout) = (1, 1, 1); + let stage = HeaderStage { + consensus: Arc::new(utils::TestConsensus::new()), + client: Arc::new(utils::TestHeaderClient::new(req_tx)), + batch_size, + request_retries, + request_timeout, + }; let mut stream = Box::pin(pending()) as MessageStream; tokio::spawn(async move { @@ -419,9 +409,14 @@ mod tests { let (tx, rx) = channel(1); let (req_tx, req_rx) = channel(3); let (res_tx, res_rx) = channel(3); - - let (batch, timeout, retries) = (1, 5, 3); - let stage = setup_stage(req_tx, batch, timeout, retries); + let (batch_size, request_retries, request_timeout) = (1, 3, 5); + let stage = HeaderStage { + consensus: Arc::new(utils::TestConsensus::new()), + client: Arc::new(utils::TestHeaderClient::new(req_tx)), + batch_size, + request_retries, + request_timeout, + }; let mut stream = Box::pin(ReceiverStream::new(res_rx)) as MessageStream; @@ -447,6 +442,47 @@ mod tests { ); } + #[tokio::test] + async fn download_batch_propagates_consensus_error() { + let (tx, rx) = channel(1); + let (req_tx, req_rx) = channel(3); + let (res_tx, res_rx) = channel(3); + let (batch_size, request_retries, request_timeout) = (1, 3, 5); + + let mut head_block = Header::default(); + head_block.state_root = H256::from_low_u64_be(rand::thread_rng().gen()); + let chain_tip = head_block.hash_slow(); + + let mut consensus = utils::TestConsensus::new(); + consensus.update_tip(chain_tip); + consensus.set_fail_validation(false); + + let stage = HeaderStage { + consensus: Arc::new(consensus), + client: Arc::new(utils::TestHeaderClient::new(req_tx)), + batch_size, + request_retries, + request_timeout, + }; + + let mut stream = + Box::pin(ReceiverStream::new(res_rx)) as MessageStream; + tokio::spawn(async move { + let result = stage + .download_batch(&HeaderLocked::default(), H256::zero(), &mut stream, &mut vec![]) + .await; + tx.send(result).await.unwrap(); + }); + + let mut req_stream = ReceiverStream::new(req_rx); + let request = req_stream.next().await; + assert_matches!( + request, + Some((_, HeaderRequest { start, .. })) + if matches!(start, BlockId::Hash(hash) if hash == chain_tip) + ); + } + mod utils { use async_trait::async_trait; use reth_interfaces::{ @@ -490,36 +526,63 @@ mod tests { /// Consensus client impl for testing #[derive(Debug)] pub(crate) struct TestConsensus { - chain_tip: H256, + /// Watcher over the forkchoice state + channel: (watch::Sender, watch::Receiver), + /// Flag whether the header validation should purposefully fail + fail_validation: bool, } impl TestConsensus { pub(crate) fn new() -> Self { - Self { chain_tip: H256::zero() } + Self { + channel: watch::channel(ForkchoiceState { + head_block_hash: H256::zero(), + finalized_block_hash: H256::zero(), + safe_block_hash: H256::zero(), + }), + fail_validation: false, + } + } + + /// Update the forkchoice state + pub(crate) fn update_tip(&mut self, tip: H256) { + let state = ForkchoiceState { + head_block_hash: tip, + finalized_block_hash: H256::zero(), + safe_block_hash: H256::zero(), + }; + self.channel.0.send(state).expect("updating forkchoice state failed"); } - /// Set the chain tip - pub(crate) fn set_chain_tip(&mut self, tip: H256) { - self.chain_tip = tip; + /// Update the validation flag + pub(crate) fn set_fail_validation(&mut self, val: bool) { + self.fail_validation = val; } } #[async_trait] impl Consensus for TestConsensus { + /// Return the watcher over the forkchoice state fn forkchoice_state(&self) -> watch::Receiver { - todo!() + self.channel.1.clone() } + /// Retrieve the current chain tip fn tip(&self) -> H256 { - self.chain_tip + self.channel.1.borrow().head_block_hash } + /// Validate the header against its parent fn validate_header( &self, _header: &Header, _parent: &Header, ) -> Result<(), consensus::Error> { - Ok(()) + if self.fail_validation { + Err(consensus::Error::ConsensusError) + } else { + Ok(()) + } } } } From cef747c870207ac99b151c8655dd10e7b4cafaf9 Mon Sep 17 00:00:00 2001 From: Roman Krasiuk Date: Sun, 16 Oct 2022 22:26:44 +0300 Subject: [PATCH 04/13] extract downloading strategy --- crates/stages/src/stages/headers/linear.rs | 249 +++++++++++++++++ crates/stages/src/stages/headers/mod.rs | 5 + .../stages/{headers.rs => headers/stage.rs} | 250 ++---------------- 3 files changed, 273 insertions(+), 231 deletions(-) create mode 100644 crates/stages/src/stages/headers/linear.rs create mode 100644 crates/stages/src/stages/headers/mod.rs rename crates/stages/src/stages/{headers.rs => headers/stage.rs} (58%) diff --git a/crates/stages/src/stages/headers/linear.rs b/crates/stages/src/stages/headers/linear.rs new file mode 100644 index 000000000000..d91bac4ae114 --- /dev/null +++ b/crates/stages/src/stages/headers/linear.rs @@ -0,0 +1,249 @@ +use super::stage::{DownloadError, Downloader}; +use async_trait::async_trait; +use rand::Rng; +use reth_interfaces::{ + consensus::Consensus, + stages::{HeaderRequest, HeadersClient, MessageStream}, +}; +use reth_primitives::{rpc::BlockId, Header, HeaderLocked, H256}; +use std::{sync::Arc, time::Duration}; +use tokio_stream::StreamExt; + +/// Download headers in batches +#[derive(Debug)] +pub struct LinearDownloader { + /// Consensus client implementation + pub consensus: Arc, + /// Downloader client implementation + pub client: Arc, + /// The batch size per one request + pub batch_size: u64, + /// A single request timeout + pub request_timeout: u64, + /// The number of retries for downloading + pub request_retries: usize, +} + +#[async_trait] +impl Downloader for LinearDownloader { + /// Download headers in batches with retries. + /// Returns the header collection in sorted ascending order + async fn download( + &self, + head: &HeaderLocked, + tip: H256, + ) -> Result, DownloadError> { + let mut stream = self.client.stream_headers().await; + // Header order will be preserved during inserts + let mut retries = self.request_retries; + + let mut out = Vec::::new(); + loop { + match self.download_batch(&head, tip, &mut stream, &mut out).await { + Ok(done) => { + if done { + return Ok(out) + } + } + Err(e) if e.is_retryable() && retries > 0 => { + retries -= 1; + } + Err(e) => return Err(e), + } + } + } +} + +impl LinearDownloader { + async fn download_batch( + &self, + head: &HeaderLocked, + chain_tip: H256, + stream: &mut MessageStream<(u64, Vec
)>, + out: &mut Vec, + ) -> Result { + // Request headers starting from tip or earliest cached + let start = out.first().map_or(chain_tip, |h| h.parent_hash); + let request_id = self.request_headers(start).await; + + // Filter stream by request id and non empty headers content + let stream = stream.filter(|(id, headers)| request_id == *id && !headers.is_empty()); + + // Wrap the stream with a timeout + let stream = stream.timeout(Duration::from_secs(self.request_timeout)); + + // Unwrap the latest stream message which will be either + // the msg with headers or timeout error + let headers = { + let mut h = match Box::pin(stream).try_next().await { + Ok(Some((_, h))) => h, + _ => return Err(DownloadError::NoHeaderResponse { request_id }), + }; + h.sort_unstable_by_key(|h| h.number); + h + }; + + // Iterate the headers in reverse + out.reserve_exact(headers.len()); + let mut headers_rev = headers.into_iter().rev(); + while let Some(parent) = headers_rev.next() { + let parent = parent.lock(); + + if head.hash() == parent.hash() { + // We've reached the target + return Ok(true) + } + + if let Some(tail_header) = out.first() { + if !(parent.hash() == tail_header.parent_hash && + parent.number + 1 == tail_header.number) + { + // Cannot attach to the current buffer, + // discard this batch + return Ok(false) + } + + self.consensus.validate_header(&tail_header, &parent).map_err(|e| { + DownloadError::HeaderValidation { hash: parent.hash(), details: e.to_string() } + })?; + } else if parent.hash() != chain_tip { + // The buffer is empty and the first header + // does not match the one we requested + // discard this batch + // TODO: penalize the peer? + return Ok(false) + } + + out.insert(0, parent); + } + + Ok(false) + } + + /// Perform a header request. Return the request ID + async fn request_headers(&self, start: H256) -> u64 { + let request_id = rand::thread_rng().gen(); + let request = + HeaderRequest { start: BlockId::Hash(start), limit: self.batch_size, reverse: true }; + let _ = self.client.send_header_request(request_id, request).await; + request_id + } +} + +#[cfg(test)] +mod tests { + use super::{super::stage::tests::utils, DownloadError, LinearDownloader}; + use assert_matches::assert_matches; + use rand::{self, Rng}; + use reth_interfaces::stages::{HeaderRequest, MessageStream}; + use reth_primitives::{rpc::BlockId, Header, HeaderLocked, H256}; + use std::sync::Arc; + use tokio::sync::mpsc::channel; + use tokio_stream::{pending, wrappers::ReceiverStream, StreamExt}; + + #[tokio::test] + async fn download_batch_timeout() { + let (tx, rx) = channel(1); + let (req_tx, _req_rx) = channel(1); + let (batch_size, request_retries, request_timeout) = (1, 1, 1); + let downloader = LinearDownloader { + consensus: Arc::new(utils::TestConsensus::new()), + client: Arc::new(utils::TestHeaderClient::new(req_tx)), + batch_size, + request_retries, + request_timeout, + }; + + let mut stream = Box::pin(pending()) as MessageStream; + tokio::spawn(async move { + let result = downloader + .download_batch(&HeaderLocked::default(), H256::zero(), &mut stream, &mut vec![]) + .await; + tx.send(result).await.unwrap(); + }); + + assert_matches!( + *ReceiverStream::new(rx).collect::>>().await, + [Err(DownloadError::NoHeaderResponse { .. })] + ); + } + + #[tokio::test] + async fn download_batch_timeout_on_invalid_messages() { + let (tx, rx) = channel(1); + let (req_tx, req_rx) = channel(3); + let (res_tx, res_rx) = channel(3); + let (batch_size, request_retries, request_timeout) = (1, 3, 5); + let downloader = LinearDownloader { + consensus: Arc::new(utils::TestConsensus::new()), + client: Arc::new(utils::TestHeaderClient::new(req_tx)), + batch_size, + request_retries, + request_timeout, + }; + + let mut stream = + Box::pin(ReceiverStream::new(res_rx)) as MessageStream; + tokio::spawn(async move { + let result = downloader + .download_batch(&HeaderLocked::default(), H256::zero(), &mut stream, &mut vec![]) + .await; + tx.send(result).await.unwrap(); + }); + + let mut last_req_id = None; + let mut req_stream = ReceiverStream::new(req_rx); + while let Some((id, _req)) = req_stream.next().await { + // Since the receiving channel filters by id and message length - + // randomize the input to the tested filter + res_tx.send((id.saturating_add(id % 2), vec![])).await.unwrap(); + last_req_id = Some(id); + } + + assert_matches!( + *ReceiverStream::new(rx).collect::>>().await, + [Err(DownloadError::NoHeaderResponse { request_id })] if request_id == last_req_id.unwrap() + ); + } + + #[tokio::test] + async fn download_batch_propagates_consensus_error() { + let (tx, rx) = channel(1); + let (req_tx, req_rx) = channel(3); + let (res_tx, res_rx) = channel(3); + let (batch_size, request_retries, request_timeout) = (1, 3, 5); + + let mut head_block = Header::default(); + head_block.state_root = H256::from_low_u64_be(rand::thread_rng().gen()); + let chain_tip = head_block.hash_slow(); + + let mut consensus = utils::TestConsensus::new(); + consensus.update_tip(chain_tip); + consensus.set_fail_validation(false); + + let downloader = LinearDownloader { + consensus: Arc::new(consensus), + client: Arc::new(utils::TestHeaderClient::new(req_tx)), + batch_size, + request_retries, + request_timeout, + }; + + let mut stream = + Box::pin(ReceiverStream::new(res_rx)) as MessageStream; + tokio::spawn(async move { + let result = downloader + .download_batch(&HeaderLocked::default(), H256::zero(), &mut stream, &mut vec![]) + .await; + tx.send(result).await.unwrap(); + }); + + let mut req_stream = ReceiverStream::new(req_rx); + let request = req_stream.next().await; + assert_matches!( + request, + Some((_, HeaderRequest { start, .. })) + if matches!(start, BlockId::Hash(hash) if hash == chain_tip) + ); + } +} diff --git a/crates/stages/src/stages/headers/mod.rs b/crates/stages/src/stages/headers/mod.rs new file mode 100644 index 000000000000..8785c28188af --- /dev/null +++ b/crates/stages/src/stages/headers/mod.rs @@ -0,0 +1,5 @@ +/// The headers stage implementation +pub mod stage; + +/// The downloading strategies +pub mod linear; diff --git a/crates/stages/src/stages/headers.rs b/crates/stages/src/stages/headers/stage.rs similarity index 58% rename from crates/stages/src/stages/headers.rs rename to crates/stages/src/stages/headers/stage.rs index 27c192c51114..91d86971a109 100644 --- a/crates/stages/src/stages/headers.rs +++ b/crates/stages/src/stages/headers/stage.rs @@ -10,29 +10,32 @@ use reth_interfaces::{ stages::{HeaderRequest, HeadersClient, MessageStream}, }; use reth_primitives::{rpc::BlockId, BigEndianHash, BlockNumber, Header, HeaderLocked, H256, U256}; -use std::{sync::Arc, time::Duration}; +use std::{fmt::Debug, sync::Arc}; use thiserror::Error; -use tokio_stream::StreamExt; use tracing::*; const HEADERS: StageId = StageId("HEADERS"); -// TODO: docs -// TODO: add tracing - /// The headers stage implementation for staged sync #[derive(Debug)] pub struct HeaderStage { + /// Strategy for downloading the headers + pub downloader: Arc, /// Consensus client implementation pub consensus: Arc, /// Downloader client implementation pub client: Arc, - /// The batch size per one request - pub batch_size: u64, - /// A single request timeout - pub request_timeout: u64, - /// The number of retries for downloadign - pub request_retries: usize, +} + +/// The header downloading strategy +#[async_trait] +pub trait Downloader: Sync + Send + Debug { + /// Download the headers + async fn download( + &self, + latest: &HeaderLocked, + tip: H256, + ) -> Result, DownloadError>; } /// The downloader error type @@ -58,7 +61,8 @@ pub enum DownloadError { } impl DownloadError { - fn is_retryable(&self) -> bool { + /// Returns bool indicating whether this error is retryable or fatal + pub fn is_retryable(&self) -> bool { matches!(self, DownloadError::NoHeaderResponse { .. }) } } @@ -99,7 +103,7 @@ where let forkchoice_state = self.next_forkchoice_state(&head.hash()).await; - let headers = match self.download(&head, forkchoice_state).await { + let headers = match self.downloader.download(&head, forkchoice_state).await { Ok(res) => res, Err(e) => match e { DownloadError::NoHeaderResponse { request_id } => { @@ -197,107 +201,6 @@ impl HeaderStage { } } - /// Download headers in batches with retries. - /// Returns the header collection in sorted ascending order - async fn download( - &self, - head: &HeaderLocked, - tip: H256, - ) -> Result, DownloadError> { - let mut stream = self.client.stream_headers().await; - // Header order will be preserved during inserts - let mut retries = self.request_retries; - - let mut out = Vec::::new(); - loop { - match self.download_batch(head, tip, &mut stream, &mut out).await { - Ok(done) => { - if done { - return Ok(out) - } - } - Err(e) if e.is_retryable() && retries > 0 => { - retries -= 1; - } - Err(e) => return Err(e), - } - } - } - - async fn download_batch( - &self, - head: &HeaderLocked, - chain_tip: H256, - stream: &mut MessageStream<(u64, Vec
)>, - out: &mut Vec, - ) -> Result { - // Request headers starting from tip or earliest cached - let start = out.first().map_or(chain_tip, |h| h.parent_hash); - let request_id = self.request_headers(start).await; - - // Filter stream by request id and non empty headers content - let stream = stream.filter(|(id, headers)| request_id == *id && !headers.is_empty()); - - // Wrap the stream with a timeout - let stream = stream.timeout(Duration::from_secs(self.request_timeout)); - - // Unwrap the latest stream message which will be either - // the msg with headers or timeout error - let headers = { - let mut h = match Box::pin(stream).try_next().await { - Ok(Some((_, h))) => h, - _ => return Err(DownloadError::NoHeaderResponse { request_id }), - }; - h.sort_unstable_by_key(|h| h.number); - h - }; - - // Iterate the headers in reverse - out.reserve_exact(headers.len()); - let mut headers_rev = headers.into_iter().rev(); - while let Some(parent) = headers_rev.next() { - let parent = parent.lock(); - - if head.hash() == parent.hash() { - // We've reached the target - return Ok(true) - } - - if let Some(tail_header) = out.first() { - if !(parent.hash() == tail_header.parent_hash && - parent.number + 1 == tail_header.number) - { - // Cannot attach to the current buffer, - // discard this batch - return Ok(false) - } - - self.consensus.validate_header(&tail_header, &parent).map_err(|e| { - DownloadError::HeaderValidation { hash: parent.hash(), details: e.to_string() } - })?; - } else if parent.hash() != chain_tip { - // The buffer is empty and the first header - // does not match the one we requested - // discard this batch - // TODO: penalize the peer? - return Ok(false) - } - - out.insert(0, parent); - } - - Ok(false) - } - - /// Perform a header request. Return the request ID - async fn request_headers(&self, start: H256) -> u64 { - let request_id = rand::thread_rng().gen(); - let request = - HeaderRequest { start: BlockId::Hash(start), limit: self.batch_size, reverse: true }; - let _ = self.client.send_header_request(request_id, request).await; - request_id - } - /// Write downloaded headers to the database async fn write_headers<'tx, E: mdbx::EnvironmentKind>( &self, @@ -367,123 +270,8 @@ mod temp { } #[cfg(test)] -mod tests { - use super::{DownloadError, HeaderStage}; - use assert_matches::assert_matches; - use rand::{self, Rng}; - use reth_interfaces::stages::{HeaderRequest, MessageStream}; - use reth_primitives::{rpc::BlockId, Header, HeaderLocked, H256}; - use std::sync::Arc; - use tokio::sync::mpsc::channel; - use tokio_stream::{pending, wrappers::ReceiverStream, StreamExt}; - - #[tokio::test] - async fn download_batch_timeout() { - let (tx, rx) = channel(1); - let (req_tx, _req_rx) = channel(1); - let (batch_size, request_retries, request_timeout) = (1, 1, 1); - let stage = HeaderStage { - consensus: Arc::new(utils::TestConsensus::new()), - client: Arc::new(utils::TestHeaderClient::new(req_tx)), - batch_size, - request_retries, - request_timeout, - }; - - let mut stream = Box::pin(pending()) as MessageStream; - tokio::spawn(async move { - let result = stage - .download_batch(&HeaderLocked::default(), H256::zero(), &mut stream, &mut vec![]) - .await; - tx.send(result).await.unwrap(); - }); - - assert_matches!( - *ReceiverStream::new(rx).collect::>>().await, - [Err(DownloadError::NoHeaderResponse { .. })] - ); - } - - #[tokio::test] - async fn download_batch_timeout_on_invalid_messages() { - let (tx, rx) = channel(1); - let (req_tx, req_rx) = channel(3); - let (res_tx, res_rx) = channel(3); - let (batch_size, request_retries, request_timeout) = (1, 3, 5); - let stage = HeaderStage { - consensus: Arc::new(utils::TestConsensus::new()), - client: Arc::new(utils::TestHeaderClient::new(req_tx)), - batch_size, - request_retries, - request_timeout, - }; - - let mut stream = - Box::pin(ReceiverStream::new(res_rx)) as MessageStream; - tokio::spawn(async move { - let result = stage - .download_batch(&HeaderLocked::default(), H256::zero(), &mut stream, &mut vec![]) - .await; - tx.send(result).await.unwrap(); - }); - - let mut last_req_id = None; - let mut req_stream = ReceiverStream::new(req_rx); - while let Some((id, _req)) = req_stream.next().await { - // Since the receiving channel filters by id and message length - - // randomize the input to the tested filter - res_tx.send((id.saturating_add(id % 2), vec![])).await.unwrap(); - last_req_id = Some(id); - } - - assert_matches!( - *ReceiverStream::new(rx).collect::>>().await, - [Err(DownloadError::NoHeaderResponse { request_id })] if request_id == last_req_id.unwrap() - ); - } - - #[tokio::test] - async fn download_batch_propagates_consensus_error() { - let (tx, rx) = channel(1); - let (req_tx, req_rx) = channel(3); - let (res_tx, res_rx) = channel(3); - let (batch_size, request_retries, request_timeout) = (1, 3, 5); - - let mut head_block = Header::default(); - head_block.state_root = H256::from_low_u64_be(rand::thread_rng().gen()); - let chain_tip = head_block.hash_slow(); - - let mut consensus = utils::TestConsensus::new(); - consensus.update_tip(chain_tip); - consensus.set_fail_validation(false); - - let stage = HeaderStage { - consensus: Arc::new(consensus), - client: Arc::new(utils::TestHeaderClient::new(req_tx)), - batch_size, - request_retries, - request_timeout, - }; - - let mut stream = - Box::pin(ReceiverStream::new(res_rx)) as MessageStream; - tokio::spawn(async move { - let result = stage - .download_batch(&HeaderLocked::default(), H256::zero(), &mut stream, &mut vec![]) - .await; - tx.send(result).await.unwrap(); - }); - - let mut req_stream = ReceiverStream::new(req_rx); - let request = req_stream.next().await; - assert_matches!( - request, - Some((_, HeaderRequest { start, .. })) - if matches!(start, BlockId::Hash(hash) if hash == chain_tip) - ); - } - - mod utils { +pub(crate) mod tests { + pub(crate) mod utils { use async_trait::async_trait; use reth_interfaces::{ consensus::{self, Consensus}, From b59f74f0bc5e5006c463aee2629f82b7909e4c6b Mon Sep 17 00:00:00 2001 From: Roman Krasiuk Date: Mon, 17 Oct 2022 15:35:05 +0300 Subject: [PATCH 05/13] cleanup tests and db encoding --- Cargo.lock | 1 + crates/db/src/kv/models/blocks.rs | 4 +- crates/stages/Cargo.toml | 2 +- crates/stages/src/stages/headers/linear.rs | 97 ++++++++++++---------- crates/stages/src/stages/headers/stage.rs | 97 ++++++++-------------- 5 files changed, 91 insertions(+), 110 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 83c7487c241d..2be42a01ceb0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3204,6 +3204,7 @@ dependencies = [ "futures-core", "pin-project-lite", "tokio", + "tokio-util", ] [[package]] diff --git a/crates/db/src/kv/models/blocks.rs b/crates/db/src/kv/models/blocks.rs index 4e319373e86e..aa54f2fa2e0a 100644 --- a/crates/db/src/kv/models/blocks.rs +++ b/crates/db/src/kv/models/blocks.rs @@ -20,9 +20,9 @@ pub type HeaderHash = H256; /// element as BlockNumber, helps out with querying/sorting. /// /// Since it's used as a key, the `BlockNumber` is not compressed when encoding it. -#[derive(Debug)] +#[derive(Debug, Clone, Copy)] #[allow(non_camel_case_types)] -pub struct BlockNumHash((BlockNumber, BlockHash)); +pub struct BlockNumHash(pub (BlockNumber, BlockHash)); impl BlockNumHash { /// Consumes `Self` and returns [`BlockNumber`], [`BlockHash`] diff --git a/crates/stages/Cargo.toml b/crates/stages/Cargo.toml index 766cf2423926..95549c482ee6 100644 --- a/crates/stages/Cargo.toml +++ b/crates/stages/Cargo.toml @@ -25,7 +25,7 @@ futures = "0.3" [dev-dependencies] tokio = { version = "*", features = ["rt", "sync", "macros"] } -tokio-stream = "0.1.10" +tokio-stream = { version = "0.1.11", features = ["sync"] } tempfile = "3.3.0" reth-db = { path = "../db", features = ["test-utils"] } reth-rpc-types = { path = "../net/rpc-types" } diff --git a/crates/stages/src/stages/headers/linear.rs b/crates/stages/src/stages/headers/linear.rs index 1ce87851778a..d03e660305f8 100644 --- a/crates/stages/src/stages/headers/linear.rs +++ b/crates/stages/src/stages/headers/linear.rs @@ -39,13 +39,13 @@ impl Downloader for LinearDownloader { let mut out = Vec::::new(); loop { - match self.download_batch(&head, tip, &mut stream, &mut out).await { + match self.download_batch(head, tip, &mut stream, &mut out).await { Ok(done) => { if done { return Ok(out) } } - Err(e) if e.is_retryable() && retries > 0 => { + Err(e) if e.is_retryable() && retries > 1 => { retries -= 1; } Err(e) => return Err(e), @@ -132,85 +132,89 @@ impl LinearDownloader { #[cfg(test)] mod tests { - use super::{super::stage::tests::utils, DownloadError, LinearDownloader}; + use super::{super::stage::tests::utils, DownloadError, Downloader, LinearDownloader}; use assert_matches::assert_matches; use rand::{self, Rng}; - use reth_interfaces::stages::{HeaderRequest, MessageStream}; + use reth_interfaces::stages::HeaderRequest; use reth_primitives::{rpc::BlockId, Header, HeaderLocked, H256}; use std::sync::Arc; - use tokio::sync::mpsc::channel; - use tokio_stream::{pending, wrappers::ReceiverStream, StreamExt}; + use tokio::sync::{broadcast, mpsc, oneshot}; + use tokio_stream::{wrappers::ReceiverStream, StreamExt}; #[tokio::test] async fn download_batch_timeout() { - let (tx, rx) = channel(1); - let (req_tx, _req_rx) = channel(1); + let (tx, rx) = oneshot::channel(); + let (req_tx, req_rx) = mpsc::channel(1); + let (_res_tx, res_rx) = broadcast::channel(1); let (batch_size, request_retries, request_timeout) = (1, 1, 1); + let downloader = LinearDownloader { consensus: Arc::new(utils::TestConsensus::new()), - client: Arc::new(utils::TestHeaderClient::new(req_tx)), + client: Arc::new(utils::TestHeaderClient::new(req_tx, res_rx)), batch_size, request_retries, request_timeout, }; - let mut stream = Box::pin(pending()) as MessageStream; tokio::spawn(async move { - let result = downloader - .download_batch(&HeaderLocked::default(), H256::zero(), &mut stream, &mut vec![]) - .await; - tx.send(result).await.unwrap(); + let result = downloader.download(&HeaderLocked::default(), H256::zero()).await; + tx.send(result).expect("failed to forward download response"); }); - assert_matches!( - *ReceiverStream::new(rx).collect::>>().await, - [Err(DownloadError::NoHeaderResponse { .. })] - ); + let requests = ReceiverStream::new(req_rx).collect::>().await; + assert_eq!(requests.len(), request_retries); + assert_matches!(rx.await, Ok(Err(DownloadError::NoHeaderResponse { .. }))); } #[tokio::test] async fn download_batch_timeout_on_invalid_messages() { - let (tx, rx) = channel(1); - let (req_tx, req_rx) = channel(3); - let (res_tx, res_rx) = channel(3); - let (batch_size, request_retries, request_timeout) = (1, 3, 5); + let (tx, rx) = oneshot::channel(); + let (req_tx, req_rx) = mpsc::channel(1); + let (res_tx, res_rx) = broadcast::channel(1); + let (batch_size, request_retries, request_timeout) = (1, 5, 1); + + let client = Arc::new(utils::TestHeaderClient::new(req_tx, res_rx)); let downloader = LinearDownloader { consensus: Arc::new(utils::TestConsensus::new()), - client: Arc::new(utils::TestHeaderClient::new(req_tx)), + client: client.clone(), batch_size, request_retries, request_timeout, }; - let mut stream = - Box::pin(ReceiverStream::new(res_rx)) as MessageStream; tokio::spawn(async move { - let result = downloader - .download_batch(&HeaderLocked::default(), H256::zero(), &mut stream, &mut vec![]) - .await; - tx.send(result).await.unwrap(); + let result = downloader.download(&HeaderLocked::default(), H256::zero()).await; + tx.send(result).expect("failed to forward download response"); }); + let mut num_of_reqs = 0; let mut last_req_id = None; - let mut req_stream = ReceiverStream::new(req_rx); + let mut req_stream = Box::pin(ReceiverStream::new(req_rx)); while let Some((id, _req)) = req_stream.next().await { // Since the receiving channel filters by id and message length - // randomize the input to the tested filter - res_tx.send((id.saturating_add(id % 2), vec![])).await.unwrap(); + res_tx.send((id.saturating_add(id % 2), vec![])).expect("failed to send response"); + num_of_reqs += 1; last_req_id = Some(id); + + if num_of_reqs == request_retries { + drop(res_tx); + break + } } + assert_eq!(num_of_reqs, request_retries); assert_matches!( - *ReceiverStream::new(rx).collect::>>().await, - [Err(DownloadError::NoHeaderResponse { request_id })] if request_id == last_req_id.unwrap() + rx.await, + Ok(Err(DownloadError::NoHeaderResponse { request_id })) if request_id == last_req_id.unwrap() ); } #[tokio::test] async fn download_batch_propagates_consensus_error() { - let (tx, rx) = channel(1); - let (req_tx, req_rx) = channel(3); - let (res_tx, res_rx) = channel(3); + let (tx, rx) = oneshot::channel(); + let (req_tx, req_rx) = mpsc::channel(1); + let (res_tx, res_rx) = broadcast::channel(1); let (batch_size, request_retries, request_timeout) = (1, 3, 5); let mut head_block = Header::default(); @@ -223,29 +227,32 @@ mod tests { let downloader = LinearDownloader { consensus: Arc::new(consensus), - client: Arc::new(utils::TestHeaderClient::new(req_tx)), + client: Arc::new(utils::TestHeaderClient::new(req_tx, res_rx)), batch_size, request_retries, request_timeout, }; - let mut stream = - Box::pin(ReceiverStream::new(res_rx)) as MessageStream; tokio::spawn(async move { - let result = downloader - .download_batch(&HeaderLocked::default(), H256::zero(), &mut stream, &mut vec![]) - .await; - tx.send(result).await.unwrap(); + let result = downloader.download(&HeaderLocked::default(), H256::zero()).await; + tx.send(result).expect("failed to forward download response"); }); - let mut req_stream = ReceiverStream::new(req_rx); - let request = req_stream.next().await; + let request = ReceiverStream::new(req_rx).next().await; assert_matches!( request, Some((_, HeaderRequest { start, .. })) if matches!(start, BlockId::Hash(hash) if hash == chain_tip) ); + let request = request.unwrap(); + res_tx.send((request.0, vec![head_block])).expect("failed to send header"); + + assert_matches!( + rx.await, + Ok(Err(DownloadError::HeaderValidation { hash, .. })) // TODO: + ); + // TODO: match the propagated error } } diff --git a/crates/stages/src/stages/headers/stage.rs b/crates/stages/src/stages/headers/stage.rs index 91d86971a109..feebe261b886 100644 --- a/crates/stages/src/stages/headers/stage.rs +++ b/crates/stages/src/stages/headers/stage.rs @@ -1,15 +1,16 @@ use crate::{ExecInput, ExecOutput, Stage, StageError, StageId, UnwindInput, UnwindOutput}; use async_trait::async_trait; -use rand::Rng; use reth_db::{ - kv::{table::Encode, tables, tx::Tx}, + kv::{ + blocks::BlockNumHash, + table::{Decode, Encode}, + tables, + tx::Tx, + }, mdbx::{self, WriteFlags}, }; -use reth_interfaces::{ - consensus::Consensus, - stages::{HeaderRequest, HeadersClient, MessageStream}, -}; -use reth_primitives::{rpc::BlockId, BigEndianHash, BlockNumber, Header, HeaderLocked, H256, U256}; +use reth_interfaces::{consensus::Consensus, stages::HeadersClient}; +use reth_primitives::{BigEndianHash, BlockNumber, Header, HeaderLocked, H256, U256}; use std::{fmt::Debug, sync::Arc}; use thiserror::Error; use tracing::*; @@ -94,11 +95,9 @@ where // download the headers // TODO: check if some upper block constraint is necessary - let last_hash = - H256::from_uint(&tx.get::(last_block_num)?.unwrap()); - let last_header: Header = temp::decode_header( - tx.get::(temp::num_hash_to_key(last_block_num, last_hash))?.unwrap(), - ); + let last_hash = tx.get::(last_block_num)?.unwrap(); + let last_header: Header = + tx.get::((last_block_num, last_hash).into())?.unwrap(); let head = HeaderLocked::new(last_header, last_hash); let forkchoice_state = self.next_forkchoice_state(&head.hash()).await; @@ -137,15 +136,14 @@ where } let mut walker = tx.cursor::()?.walk(input.unwind_to + 1)?; - while let Some((_, hash)) = walker.next().transpose()? { - tx.delete::(hash.encode().to_vec(), None)?; + while let Some(key) = walker.next().transpose()? { + tx.delete::(key.into(), None)?; } // TODO: cleanup let mut cur = tx.cursor::()?; let mut entry = cur.last()?; - while let Some((key, _)) = entry { - let (num, _) = temp::num_hash_from_key(key); + while let Some((BlockNumHash((num, _)), _)) = entry { if num <= input.unwind_to { break } @@ -165,8 +163,7 @@ where let mut cur = tx.cursor::()?; let mut entry = cur.last()?; - while let Some((key, _)) = entry { - let (num, _) = temp::num_hash_from_key(key); + while let Some((BlockNumHash((num, _)), _)) = entry { if num <= input.unwind_to { break } @@ -184,8 +181,8 @@ impl HeaderStage { tx: &'tx mut Tx<'_, mdbx::RW, E>, height: BlockNumber, ) -> Result<(), StageError> { - let hash = H256::from_uint(&tx.get::(height)?.unwrap()); - let td: Vec = tx.get::(temp::num_hash_to_key(height, hash))?.unwrap(); + let hash = tx.get::(height)?.unwrap(); + let td: Vec = tx.get::((height, hash).into())?.unwrap(); self.client.update_status(height, hash, H256::from_slice(&td)).await; Ok(()) } @@ -220,55 +217,26 @@ impl HeaderStage { continue } - let hash = header.hash(); - let number = header.number; - let num_hash_key = temp::num_hash_to_key(header.number, hash); + let key: BlockNumHash = (header.number, header.hash()).into(); + let header = header.unlock(); + latest = header.number; td += header.difficulty; - cursor_header_number.put(hash.to_fixed_bytes().to_vec(), header.number, None)?; - cursor_header.put( - num_hash_key.clone(), - temp::encode_header(header.unlock()), - Some(WriteFlags::APPEND), - )?; - cursor_canonical.put(number, hash.into_uint(), Some(WriteFlags::APPEND))?; + cursor_header_number.put(key, header.number, None)?; + cursor_header.put(key, header, Some(WriteFlags::APPEND))?; + cursor_canonical.put(key.0 .0, key.0 .1, Some(WriteFlags::APPEND))?; cursor_td.put( - num_hash_key, + key, H256::from_uint(&td).as_bytes().to_vec(), Some(WriteFlags::APPEND), )?; - - latest = number; } Ok(latest) } } -// TODO: remove -mod temp { - use super::*; - - pub(crate) fn num_hash_to_key(number: BlockNumber, hash: H256) -> Vec { - let mut key = number.to_be_bytes().to_vec(); - key.extend(hash.0); - key - } - - pub(crate) fn num_hash_from_key(key: Vec) -> (BlockNumber, H256) { - todo!() - } - - pub(crate) fn encode_header(_header: Header) -> Vec { - todo!() - } - - pub(crate) fn decode_header(_bytes: Vec) -> Header { - todo!() - } -} - #[cfg(test)] pub(crate) mod tests { pub(crate) mod utils { @@ -280,20 +248,25 @@ pub(crate) mod tests { use reth_primitives::{Header, H256, H512}; use reth_rpc_types::engine::ForkchoiceState; use std::collections::HashSet; - use tokio::sync::{mpsc::Sender, watch}; + use tokio::sync::{broadcast, mpsc::Sender, watch}; + use tokio_stream::{wrappers::BroadcastStream, StreamExt}; pub(crate) type HeaderResponse = (u64, Vec
); #[derive(Debug)] pub(crate) struct TestHeaderClient { tx: Sender<(u64, HeaderRequest)>, + rx: broadcast::Receiver, } impl TestHeaderClient { /// Construct a new test header downloader. - /// `tx` is the - pub(crate) fn new(tx: Sender<(u64, HeaderRequest)>) -> Self { - Self { tx } + /// TODO: + pub(crate) fn new( + tx: Sender<(u64, HeaderRequest)>, + rx: broadcast::Receiver, + ) -> Self { + Self { tx, rx } } } @@ -302,12 +275,12 @@ pub(crate) mod tests { async fn update_status(&self, _height: u64, _hash: H256, _td: H256) {} async fn send_header_request(&self, id: u64, request: HeaderRequest) -> HashSet { - self.tx.send((id, request)).await.unwrap(); + self.tx.send((id, request)).await.expect("failed to send request"); HashSet::default() } async fn stream_headers(&self) -> MessageStream<(u64, Vec
)> { - todo!() + Box::pin(BroadcastStream::new(self.rx.resubscribe()).filter_map(|e| e.ok())) } } From bd84305b78d11bd27ae6e16a4be28802b1f03ad1 Mon Sep 17 00:00:00 2001 From: Roman Krasiuk Date: Mon, 17 Oct 2022 15:39:29 +0300 Subject: [PATCH 06/13] comment --- crates/stages/src/stages/headers/stage.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/crates/stages/src/stages/headers/stage.rs b/crates/stages/src/stages/headers/stage.rs index feebe261b886..6167445532aa 100644 --- a/crates/stages/src/stages/headers/stage.rs +++ b/crates/stages/src/stages/headers/stage.rs @@ -261,7 +261,8 @@ pub(crate) mod tests { impl TestHeaderClient { /// Construct a new test header downloader. - /// TODO: + /// `tx` is the `Sender` for header requests + /// `rx` is the `Receiver` of header responses pub(crate) fn new( tx: Sender<(u64, HeaderRequest)>, rx: broadcast::Receiver, From d5fab6041774cdc0814286b5c007aac2d6003277 Mon Sep 17 00:00:00 2001 From: Roman Krasiuk Date: Tue, 18 Oct 2022 07:40:55 +0300 Subject: [PATCH 07/13] fix validation err propagation test --- crates/stages/src/stages/headers/linear.rs | 45 +++++++++++++--------- 1 file changed, 27 insertions(+), 18 deletions(-) diff --git a/crates/stages/src/stages/headers/linear.rs b/crates/stages/src/stages/headers/linear.rs index d03e660305f8..5ec3fd51b386 100644 --- a/crates/stages/src/stages/headers/linear.rs +++ b/crates/stages/src/stages/headers/linear.rs @@ -39,7 +39,8 @@ impl Downloader for LinearDownloader { let mut out = Vec::::new(); loop { - match self.download_batch(head, tip, &mut stream, &mut out).await { + let result = self.download_batch(head, tip, &mut stream, &mut out).await; + match result { Ok(done) => { if done { return Ok(out) @@ -138,11 +139,14 @@ mod tests { use reth_interfaces::stages::HeaderRequest; use reth_primitives::{rpc::BlockId, Header, HeaderLocked, H256}; use std::sync::Arc; - use tokio::sync::{broadcast, mpsc, oneshot}; + use tokio::sync::{ + broadcast, mpsc, + oneshot::{self, error::TryRecvError}, + }; use tokio_stream::{wrappers::ReceiverStream, StreamExt}; #[tokio::test] - async fn download_batch_timeout() { + async fn download_timeout() { let (tx, rx) = oneshot::channel(); let (req_tx, req_rx) = mpsc::channel(1); let (_res_tx, res_rx) = broadcast::channel(1); @@ -167,7 +171,7 @@ mod tests { } #[tokio::test] - async fn download_batch_timeout_on_invalid_messages() { + async fn download_timeout_on_invalid_messages() { let (tx, rx) = oneshot::channel(); let (req_tx, req_rx) = mpsc::channel(1); let (res_tx, res_rx) = broadcast::channel(1); @@ -189,7 +193,7 @@ mod tests { let mut num_of_reqs = 0; let mut last_req_id = None; - let mut req_stream = Box::pin(ReceiverStream::new(req_rx)); + let mut req_stream = ReceiverStream::new(req_rx); while let Some((id, _req)) = req_stream.next().await { // Since the receiving channel filters by id and message length - // randomize the input to the tested filter @@ -211,19 +215,25 @@ mod tests { } #[tokio::test] - async fn download_batch_propagates_consensus_error() { + async fn download_propagates_consensus_validation_error() { let (tx, rx) = oneshot::channel(); let (req_tx, req_rx) = mpsc::channel(1); let (res_tx, res_rx) = broadcast::channel(1); - let (batch_size, request_retries, request_timeout) = (1, 3, 5); + let (batch_size, request_retries, request_timeout) = (100, 2, 5); - let mut head_block = Header::default(); - head_block.state_root = H256::from_low_u64_be(rand::thread_rng().gen()); - let chain_tip = head_block.hash_slow(); + let mut tip_parent = Header::default(); + tip_parent.nonce = rand::thread_rng().gen(); + tip_parent.number = 1; + let parent_hash = tip_parent.hash_slow(); + + let mut tip_header = Header::default(); + tip_header.number = 2; + tip_header.nonce = rand::thread_rng().gen(); + tip_header.parent_hash = parent_hash; + let chain_tip = tip_header.hash_slow(); let mut consensus = utils::TestConsensus::new(); - consensus.update_tip(chain_tip); - consensus.set_fail_validation(false); + consensus.set_fail_validation(true); let downloader = LinearDownloader { consensus: Arc::new(consensus), @@ -234,11 +244,12 @@ mod tests { }; tokio::spawn(async move { - let result = downloader.download(&HeaderLocked::default(), H256::zero()).await; + let result = downloader.download(&HeaderLocked::default(), chain_tip).await; tx.send(result).expect("failed to forward download response"); }); - let request = ReceiverStream::new(req_rx).next().await; + let mut stream = Box::pin(ReceiverStream::new(req_rx)); + let request = stream.next().await; assert_matches!( request, Some((_, HeaderRequest { start, .. })) @@ -246,13 +257,11 @@ mod tests { ); let request = request.unwrap(); - res_tx.send((request.0, vec![head_block])).expect("failed to send header"); + res_tx.send((request.0, vec![tip_header, tip_parent])).expect("failed to send header"); assert_matches!( rx.await, - Ok(Err(DownloadError::HeaderValidation { hash, .. })) // TODO: + Ok(Err(DownloadError::HeaderValidation { hash, .. })) if hash == parent_hash ); - - // TODO: match the propagated error } } From 28a614f25d82ef73e4dde7b4ca05266857c83de6 Mon Sep 17 00:00:00 2001 From: Roman Krasiuk Date: Tue, 18 Oct 2022 12:14:11 +0300 Subject: [PATCH 08/13] add chain tip test & test runner --- crates/stages/src/stages/headers/linear.rs | 160 ++++++++++++++------- crates/stages/src/stages/headers/stage.rs | 7 +- 2 files changed, 111 insertions(+), 56 deletions(-) diff --git a/crates/stages/src/stages/headers/linear.rs b/crates/stages/src/stages/headers/linear.rs index 5ec3fd51b386..6de4b3bb88a7 100644 --- a/crates/stages/src/stages/headers/linear.rs +++ b/crates/stages/src/stages/headers/linear.rs @@ -139,57 +139,41 @@ mod tests { use reth_interfaces::stages::HeaderRequest; use reth_primitives::{rpc::BlockId, Header, HeaderLocked, H256}; use std::sync::Arc; - use tokio::sync::{ - broadcast, mpsc, - oneshot::{self, error::TryRecvError}, - }; + use tokio::sync::{broadcast, mpsc, oneshot::error::TryRecvError}; use tokio_stream::{wrappers::ReceiverStream, StreamExt}; #[tokio::test] async fn download_timeout() { - let (tx, rx) = oneshot::channel(); let (req_tx, req_rx) = mpsc::channel(1); let (_res_tx, res_rx) = broadcast::channel(1); - let (batch_size, request_retries, request_timeout) = (1, 1, 1); - - let downloader = LinearDownloader { - consensus: Arc::new(utils::TestConsensus::new()), - client: Arc::new(utils::TestHeaderClient::new(req_tx, res_rx)), - batch_size, - request_retries, - request_timeout, - }; - tokio::spawn(async move { - let result = downloader.download(&HeaderLocked::default(), H256::zero()).await; - tx.send(result).expect("failed to forward download response"); - }); + let runner = test_runner::LinearTestRunner::new(); + let retries = runner.retries; + let rx = runner.run( + utils::TestConsensus::new(), + utils::TestHeaderClient::new(req_tx, res_rx), + HeaderLocked::default(), + H256::zero(), + ); let requests = ReceiverStream::new(req_rx).collect::>().await; - assert_eq!(requests.len(), request_retries); + assert_eq!(requests.len(), retries); assert_matches!(rx.await, Ok(Err(DownloadError::NoHeaderResponse { .. }))); } #[tokio::test] async fn download_timeout_on_invalid_messages() { - let (tx, rx) = oneshot::channel(); let (req_tx, req_rx) = mpsc::channel(1); let (res_tx, res_rx) = broadcast::channel(1); - let (batch_size, request_retries, request_timeout) = (1, 5, 1); - - let client = Arc::new(utils::TestHeaderClient::new(req_tx, res_rx)); - let downloader = LinearDownloader { - consensus: Arc::new(utils::TestConsensus::new()), - client: client.clone(), - batch_size, - request_retries, - request_timeout, - }; - tokio::spawn(async move { - let result = downloader.download(&HeaderLocked::default(), H256::zero()).await; - tx.send(result).expect("failed to forward download response"); - }); + let runner = test_runner::LinearTestRunner::new(); + let retries = runner.retries; + let rx = runner.run( + utils::TestConsensus::new(), + utils::TestHeaderClient::new(req_tx, res_rx), + HeaderLocked::default(), + H256::zero(), + ); let mut num_of_reqs = 0; let mut last_req_id = None; @@ -201,13 +185,13 @@ mod tests { num_of_reqs += 1; last_req_id = Some(id); - if num_of_reqs == request_retries { + if num_of_reqs == retries { drop(res_tx); break } } - assert_eq!(num_of_reqs, request_retries); + assert_eq!(num_of_reqs, retries); assert_matches!( rx.await, Ok(Err(DownloadError::NoHeaderResponse { request_id })) if request_id == last_req_id.unwrap() @@ -216,10 +200,8 @@ mod tests { #[tokio::test] async fn download_propagates_consensus_validation_error() { - let (tx, rx) = oneshot::channel(); let (req_tx, req_rx) = mpsc::channel(1); let (res_tx, res_rx) = broadcast::channel(1); - let (batch_size, request_retries, request_timeout) = (100, 2, 5); let mut tip_parent = Header::default(); tip_parent.nonce = rand::thread_rng().gen(); @@ -235,18 +217,13 @@ mod tests { let mut consensus = utils::TestConsensus::new(); consensus.set_fail_validation(true); - let downloader = LinearDownloader { - consensus: Arc::new(consensus), - client: Arc::new(utils::TestHeaderClient::new(req_tx, res_rx)), - batch_size, - request_retries, - request_timeout, - }; - - tokio::spawn(async move { - let result = downloader.download(&HeaderLocked::default(), chain_tip).await; - tx.send(result).expect("failed to forward download response"); - }); + let runner = test_runner::LinearTestRunner::new(); + let rx = runner.run( + consensus, + utils::TestHeaderClient::new(req_tx, res_rx), + HeaderLocked::default(), + chain_tip, + ); let mut stream = Box::pin(ReceiverStream::new(req_rx)); let request = stream.next().await; @@ -264,4 +241,87 @@ mod tests { Ok(Err(DownloadError::HeaderValidation { hash, .. })) if hash == parent_hash ); } + + #[tokio::test] + async fn download_starts_with_chain_tip() { + let (req_tx, req_rx) = mpsc::channel(1); + let (res_tx, res_rx) = broadcast::channel(1); + + let mut tip_parent = Header::default(); + tip_parent.nonce = rand::thread_rng().gen(); + tip_parent.number = 1; + let parent_hash = tip_parent.hash_slow(); + + let mut tip = Header::default(); + tip.parent_hash = parent_hash; + tip.number = 2; + tip.nonce = rand::thread_rng().gen(); + + let runner = test_runner::LinearTestRunner::new(); + let mut rx = runner.run( + utils::TestConsensus::new(), + utils::TestHeaderClient::new(req_tx, res_rx), + tip_parent.clone().lock(), + tip.hash_slow(), + ); + + let mut stream = ReceiverStream::new(req_rx); + let request = stream.next().await.unwrap(); + let mut corrupted_tip = tip.clone(); + corrupted_tip.nonce = rand::thread_rng().gen(); + res_tx + .send((request.0, vec![corrupted_tip, tip_parent.clone()])) + .expect("failed to send header"); + assert_matches!(rx.try_recv(), Err(TryRecvError::Empty)); + + let request = stream.next().await.unwrap(); + res_tx + .send((request.0, vec![tip.clone(), tip_parent.clone()])) + .expect("failed to send header"); + + let result = rx.await; + assert_matches!(result, Ok(Ok(ref val)) if val.len() == 1); + assert_eq!(*result.unwrap().unwrap().first().unwrap(), tip.lock()); + } + + mod test_runner { + use super::*; + use reth_interfaces::{consensus::Consensus, stages::HeadersClient}; + use tokio::sync::oneshot; + + type DownloadResult = Result, DownloadError>; + + pub(crate) struct LinearTestRunner { + pub(crate) retries: usize, + test_ch: (oneshot::Sender, oneshot::Receiver), + } + + impl LinearTestRunner { + pub(crate) fn new() -> Self { + Self { test_ch: oneshot::channel(), retries: 5 } + } + + pub(crate) fn run( + self, + consensus: impl Consensus + 'static, + client: impl HeadersClient + 'static, + head: HeaderLocked, + tip: H256, + ) -> oneshot::Receiver { + let (tx, rx) = self.test_ch; + let downloader = LinearDownloader { + consensus: Arc::new(consensus), + client: Arc::new(client), + request_retries: self.retries, + batch_size: 100, + request_timeout: 3, + }; + tokio::spawn(async move { + let result = downloader.download(&head, tip).await; + tx.send(result).expect("failed to forward download response"); + }); + rx + } + } + } } diff --git a/crates/stages/src/stages/headers/stage.rs b/crates/stages/src/stages/headers/stage.rs index 6167445532aa..c45791ccccd5 100644 --- a/crates/stages/src/stages/headers/stage.rs +++ b/crates/stages/src/stages/headers/stage.rs @@ -1,12 +1,7 @@ use crate::{ExecInput, ExecOutput, Stage, StageError, StageId, UnwindInput, UnwindOutput}; use async_trait::async_trait; use reth_db::{ - kv::{ - blocks::BlockNumHash, - table::{Decode, Encode}, - tables, - tx::Tx, - }, + kv::{blocks::BlockNumHash, tables, tx::Tx}, mdbx::{self, WriteFlags}, }; use reth_interfaces::{consensus::Consensus, stages::HeadersClient}; From 7e39155437ab6cf00568e8425b66357d72c77763 Mon Sep 17 00:00:00 2001 From: Roman Krasiuk Date: Wed, 19 Oct 2022 22:59:34 +0300 Subject: [PATCH 09/13] stream attempt --- Cargo.lock | 3 + crates/net/rpc-types/src/eth/engine.rs | 2 +- crates/stages/Cargo.toml | 8 +- crates/stages/src/error.rs | 6 +- crates/stages/src/pipeline.rs | 6 +- crates/stages/src/stage.rs | 2 +- .../stages/src/stages/headers/downloader.rs | 93 ++++++ crates/stages/src/stages/headers/linear.rs | 313 +++++++++++++----- crates/stages/src/stages/headers/mod.rs | 6 +- crates/stages/src/stages/headers/parallel.rs | 132 ++++++++ crates/stages/src/stages/headers/stage.rs | 190 ++++++++--- 11 files changed, 609 insertions(+), 152 deletions(-) create mode 100644 crates/stages/src/stages/headers/downloader.rs create mode 100644 crates/stages/src/stages/headers/parallel.rs diff --git a/Cargo.lock b/Cargo.lock index 81aa33d16217..dac550cfba2a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2495,10 +2495,13 @@ dependencies = [ name = "reth-stages" version = "0.1.0" dependencies = [ + "anyhow", "aquamarine", "assert_matches", "async-trait", "futures", + "pin-project", + "pin-project-lite", "rand", "reth-db", "reth-interfaces", diff --git a/crates/net/rpc-types/src/eth/engine.rs b/crates/net/rpc-types/src/eth/engine.rs index 2559f40677e6..f5f4e93f4a80 100644 --- a/crates/net/rpc-types/src/eth/engine.rs +++ b/crates/net/rpc-types/src/eth/engine.rs @@ -25,7 +25,7 @@ pub struct ExecutionPayload { } /// This structure encapsulates the fork choice state -#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Clone, Debug, PartialEq, Eq, Default, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct ForkchoiceState { pub head_block_hash: H256, diff --git a/crates/stages/Cargo.toml b/crates/stages/Cargo.toml index 95549c482ee6..b62512348f2a 100644 --- a/crates/stages/Cargo.toml +++ b/crates/stages/Cargo.toml @@ -10,19 +10,25 @@ description = "Staged syncing primitives used in reth." [dependencies] reth-primitives = { path = "../primitives" } reth-interfaces = { path = "../interfaces" } +reth-rpc-types = { path = "../net/rpc-types" } reth-db = { path = "../db" } -thiserror = "1.0.37" tracing = "0.1.36" tracing-futures = "0.2.5" tokio = { version = "1.21.2", features = ["sync"] } rand = "0.8" # TODO: tokio-stream = "0.1.11" +pin-project = "1.0.12" +pin-project-lite = "0.2" aquamarine = "0.1.12" # async/futures async-trait = "0.1.57" futures = "0.3" +# error handling +thiserror = "1.0.37" +anyhow = "1.0.65" + [dev-dependencies] tokio = { version = "*", features = ["rt", "sync", "macros"] } tokio-stream = { version = "0.1.11", features = ["sync"] } diff --git a/crates/stages/src/error.rs b/crates/stages/src/error.rs index 2f501a358f19..f7b7e6bc058b 100644 --- a/crates/stages/src/error.rs +++ b/crates/stages/src/error.rs @@ -1,4 +1,4 @@ -use crate::pipeline::PipelineEvent; +use crate::{pipeline::PipelineEvent, Stage}; use reth_db::kv::KVError; use reth_primitives::BlockNumber; use thiserror::Error; @@ -20,7 +20,7 @@ pub enum StageError { Database(#[from] KVError), /// The stage encountered an internal error. #[error(transparent)] - Internal(Box), + Internal(#[from] anyhow::Error), } /// A pipeline execution error. @@ -37,5 +37,5 @@ pub enum PipelineError { Channel(#[from] SendError), /// The stage encountered an internal error. #[error(transparent)] - Internal(Box), + Internal(#[from] anyhow::Error), } diff --git a/crates/stages/src/pipeline.rs b/crates/stages/src/pipeline.rs index e64a7be57994..b049c8a0ef20 100644 --- a/crates/stages/src/pipeline.rs +++ b/crates/stages/src/pipeline.rs @@ -749,7 +749,7 @@ mod tests { pub(crate) struct TestStage { id: StageId, exec_outputs: VecDeque>, - unwind_outputs: VecDeque>>, + unwind_outputs: VecDeque>, } impl TestStage { @@ -764,7 +764,7 @@ mod tests { pub(crate) fn add_unwind( mut self, - output: Result>, + output: Result, ) -> Self { self.unwind_outputs.push_back(output); self @@ -797,7 +797,7 @@ mod tests { &mut self, _: &mut Tx<'tx, mdbx::RW, E>, _: UnwindInput, - ) -> Result> + ) -> Result where 'db: 'tx, { diff --git a/crates/stages/src/stage.rs b/crates/stages/src/stage.rs index 03cd6c965905..4ff23d2b3ae9 100644 --- a/crates/stages/src/stage.rs +++ b/crates/stages/src/stage.rs @@ -74,7 +74,7 @@ where &mut self, tx: &mut Tx<'tx, mdbx::RW, E>, input: UnwindInput, - ) -> Result> + ) -> Result where 'db: 'tx; } diff --git a/crates/stages/src/stages/headers/downloader.rs b/crates/stages/src/stages/headers/downloader.rs new file mode 100644 index 000000000000..d475d3678b2d --- /dev/null +++ b/crates/stages/src/stages/headers/downloader.rs @@ -0,0 +1,93 @@ +use async_trait::async_trait; +use rand::Rng; +use reth_interfaces::{ + consensus::Consensus, + stages::{HeaderRequest, HeadersClient, MessageStream}, +}; +use reth_primitives::{rpc::BlockId, Header, HeaderLocked, H256}; +use reth_rpc_types::engine::ForkchoiceState; +use std::{fmt::Debug, sync::Arc, time::Duration}; +use thiserror::Error; +use tokio_stream::StreamExt; + +/// The header downloading strategy +#[async_trait] +pub trait Downloader: Sync + Send + Debug { + /// The request timeout in seconds + fn timeout(&self) -> u64; + + /// Download the headers + async fn download( + &self, + client: Arc, + consensus: Arc, + head: &HeaderLocked, + forkchoice: &ForkchoiceState, + ) -> Result, DownloadError>; + + /// Perform a header request. Return the request ID + async fn download_headers( + &self, + stream: &mut MessageStream<(u64, Vec
)>, + client: Arc, + start: BlockId, + limit: u64, + ) -> Result, DownloadError> { + let request_id = rand::thread_rng().gen(); + let request = HeaderRequest { start, limit, reverse: true }; + let _ = client.send_header_request(request_id, request).await; + + // Filter stream by request id and non empty headers content + let stream = stream.filter(|(id, headers)| request_id == *id && !headers.is_empty()); + + // Wrap the stream with a timeout + let stream = stream.timeout(Duration::from_secs(self.timeout())); + match Box::pin(stream).try_next().await { + Ok(Some((_, h))) => Ok(h), + _ => return Err(DownloadError::NoHeaderResponse { request_id }), + } + } + + /// Validate whether the header is valid in relation to it's parent + fn validate( + &self, + consensus: Arc, + header: &HeaderLocked, + parent: &HeaderLocked, + ) -> Result { + if !(parent.hash() == header.parent_hash && parent.number + 1 == header.number) { + return Ok(false) + } + + consensus.validate_header(&header, &parent).map_err(|e| { + DownloadError::HeaderValidation { hash: parent.hash(), details: e.to_string() } + })?; + Ok(true) + } +} + +/// The downloader error type +#[derive(Error, Debug, Clone)] +pub enum DownloadError { + /// Header validation failed + #[error("Failed to validate header {hash}. Details: {details}.")] + HeaderValidation { + /// Hash of header failing validation + hash: H256, + /// The details of validation failure + details: String, + }, + /// No headers reponse received + #[error("Failed to get headers for request {request_id}.")] + NoHeaderResponse { + /// The last request ID + request_id: u64, + }, +} + +impl DownloadError { + /// Returns bool indicating whether this error is retryable or fatal + pub fn is_retryable(&self) -> bool { + matches!(self, DownloadError::NoHeaderResponse { .. }) + } +} diff --git a/crates/stages/src/stages/headers/linear.rs b/crates/stages/src/stages/headers/linear.rs index 6de4b3bb88a7..bba7a16b0fc0 100644 --- a/crates/stages/src/stages/headers/linear.rs +++ b/crates/stages/src/stages/headers/linear.rs @@ -1,21 +1,30 @@ -use super::stage::{DownloadError, Downloader}; +use super::downloader::{DownloadError, Downloader}; use async_trait::async_trait; +use futures::{future::BoxFuture, stream::BoxStream, Future, FutureExt}; +use pin_project::pin_project; +use pin_project_lite::pin_project as pin_project_lite; use rand::Rng; use reth_interfaces::{ consensus::Consensus, stages::{HeaderRequest, HeadersClient, MessageStream}, }; -use reth_primitives::{rpc::BlockId, Header, HeaderLocked, H256}; -use std::{sync::Arc, time::Duration}; -use tokio_stream::StreamExt; +use reth_primitives::{rpc::BlockId, Header, HeaderLocked, H256, H512}; +use reth_rpc_types::engine::ForkchoiceState; +use std::{ + collections::HashSet, + ops::DerefMut, + pin::Pin, + process::Output, + sync::{Arc, Mutex}, + task::{Context, Poll}, + time::Duration, +}; +use tokio::time::{Instant, Sleep}; +use tokio_stream::Stream; /// Download headers in batches #[derive(Debug)] pub struct LinearDownloader { - /// Consensus client implementation - pub consensus: Arc, - /// Downloader client implementation - pub client: Arc, /// The batch size per one request pub batch_size: u64, /// A single request timeout @@ -24,28 +33,50 @@ pub struct LinearDownloader { pub request_retries: usize, } +type HeaderIter = Box + Send>; #[async_trait] impl Downloader for LinearDownloader { + /// The request timeout + fn timeout(&self) -> u64 { + self.request_timeout + } + /// Download headers in batches with retries. /// Returns the header collection in sorted ascending order async fn download( &self, + client: Arc, + consensus: Arc, head: &HeaderLocked, - tip: H256, + forkchoice: &ForkchoiceState, ) -> Result, DownloadError> { - let mut stream = self.client.stream_headers().await; - // Header order will be preserved during inserts + let mut stream = client.stream_headers().await; let mut retries = self.request_retries; - let mut out = Vec::::new(); + // Header order will be preserved during inserts + let mut out = (Box::new(std::iter::empty()) as HeaderIter).peekable(); loop { - let result = self.download_batch(head, tip, &mut stream, &mut out).await; + let result = self + .download_batch( + &mut stream, + client.clone(), + consensus.clone(), + forkchoice, + head, + out.peek(), + ) + .await; match result { - Ok(done) => { - if done { - return Ok(out) + Ok(result) => match result { + LinearDownloadResult::Batch(headers) => { + out = (Box::new(headers.into_iter().chain(out)) as HeaderIter).peekable(); } - } + LinearDownloadResult::Finished(headers) => { + out = (Box::new(headers.into_iter().chain(out)) as HeaderIter).peekable(); + return Ok(out.collect()) + } + LinearDownloadResult::Ignore => (), + }, Err(e) if e.is_retryable() && retries > 1 => { retries -= 1; } @@ -55,85 +86,190 @@ impl Downloader for LinearDownloader { } } +/// The intermediate download result +#[derive(Debug)] +pub enum LinearDownloadResult { + /// Downloaded last batch up to tip + Finished(Vec), + /// Downloaded batch + Batch(Vec), + /// Ignore this batch + Ignore, +} + impl LinearDownloader { - async fn download_batch( - &self, - head: &HeaderLocked, - chain_tip: H256, - stream: &mut MessageStream<(u64, Vec
)>, - out: &mut Vec, - ) -> Result { + async fn download_batch<'a>( + &'a self, + stream: &'a mut MessageStream<(u64, Vec
)>, + client: Arc, + consensus: Arc, + forkchoice: &'a ForkchoiceState, + head: &'a HeaderLocked, + earliest: Option<&HeaderLocked>, + ) -> Result { // Request headers starting from tip or earliest cached - let start = out.first().map_or(chain_tip, |h| h.parent_hash); - let request_id = self.request_headers(start).await; - - // Filter stream by request id and non empty headers content - let stream = stream.filter(|(id, headers)| request_id == *id && !headers.is_empty()); - - // Wrap the stream with a timeout - let stream = stream.timeout(Duration::from_secs(self.request_timeout)); - - // Unwrap the latest stream message which will be either - // the msg with headers or timeout error - let headers = { - let mut h = match Box::pin(stream).try_next().await { - Ok(Some((_, h))) => h, - _ => return Err(DownloadError::NoHeaderResponse { request_id }), - }; - h.sort_unstable_by_key(|h| h.number); - h - }; - - // Iterate the headers in reverse - out.reserve_exact(headers.len()); - let mut headers_rev = headers.into_iter().rev(); - while let Some(parent) = headers_rev.next() { + let start = earliest.map_or(forkchoice.head_block_hash, |h| h.parent_hash); + let mut headers = self + .download_headers(stream, client.clone(), BlockId::Hash(start), self.batch_size) + .await?; + headers.sort_unstable_by_key(|h| h.number); + + let mut out = Vec::with_capacity(headers.len()); + // Iterate headers in reverse + for parent in headers.into_iter().rev() { let parent = parent.lock(); if head.hash() == parent.hash() { // We've reached the target - return Ok(true) + return Ok(LinearDownloadResult::Finished(out)) } - if let Some(tail_header) = out.first() { - if !(parent.hash() == tail_header.parent_hash && - parent.number + 1 == tail_header.number) - { - // Cannot attach to the current buffer, - // discard this batch - return Ok(false) + match out.first().or(earliest) { + Some(header) if !self.validate(consensus.clone(), header, &parent)? => { + return Ok(LinearDownloadResult::Ignore) } - - self.consensus.validate_header(&tail_header, &parent).map_err(|e| { - DownloadError::HeaderValidation { hash: parent.hash(), details: e.to_string() } - })?; - } else if parent.hash() != chain_tip { - // The buffer is empty and the first header - // does not match the one we requested - // discard this batch + // The buffer is empty and the first header does not match the tip, discard // TODO: penalize the peer? - return Ok(false) - } + None if parent.hash() != forkchoice.head_block_hash => { + return Ok(LinearDownloadResult::Ignore) + } + _ => (), + }; out.insert(0, parent); } - Ok(false) + Ok(LinearDownloadResult::Batch(out)) + } +} + +mod linear_stream { + use super::*; + + pin_project_lite! { + pub(crate) struct LinearDownloadStream<'a, S: Stream)>> { + #[pin] + stream: &'a mut S, + #[pin] + state: LinearStreamState<'a>, + client: &'a Arc, + consensus: &'a Arc, + tip: H256, + head: H256, + earliest: Option, + retries: usize, + } + } + + impl<'a, S: Stream)> + Unpin> LinearDownloadStream<'a, S> { + pub(crate) fn new( + stream: &'a mut S, + client: &'a Arc, + consensus: &'a Arc, + tip: H256, + head: H256, + retries: usize, + ) -> Self { + Self { + stream, + state: LinearStreamState::Prepare, + tip, + head, + client, + consensus, + earliest: None, + retries, + } + } + } + + enum LinearStreamState<'a> { + Prepare, + PollRequest(u64, BoxFuture<'a, HashSet>), + PollHeaders(u64), + Done, } - /// Perform a header request. Return the request ID - async fn request_headers(&self, start: H256) -> u64 { - let request_id = rand::thread_rng().gen(); - let request = - HeaderRequest { start: BlockId::Hash(start), limit: self.batch_size, reverse: true }; - let _ = self.client.send_header_request(request_id, request).await; - request_id + impl<'a, S: Stream)> + Unpin + Send> Stream + for LinearDownloadStream<'a, S> + { + type Item = Result, DownloadError>; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut this = self.project(); + match *this.state { + LinearStreamState::Prepare => { + let request_id = rand::thread_rng().gen(); + let request = + HeaderRequest { start: BlockId::Hash(*this.tip), limit: 1, reverse: true }; + this.state.set(LinearStreamState::PollRequest( + request_id, + this.client.send_header_request(request_id, request), + )); + Poll::Pending + } + LinearStreamState::PollRequest(req_id, ref mut fut) => match fut.poll_unpin(cx) { + Poll::Ready(_peers) => { + this.state.set(LinearStreamState::PollHeaders(req_id)); + Poll::Pending + } + Poll::Pending => Poll::Pending, + }, + LinearStreamState::PollHeaders(req_id) => match this.stream.poll_next(cx) { + Poll::Ready(Some((id, mut headers))) if id == req_id && !headers.is_empty() => { + headers.sort_unstable_by_key(|h| h.number); + + this.state.set(LinearStreamState::Prepare); + let mut out = Vec::with_capacity(headers.len()); + for parent in headers.into_iter().rev() { + let parent = parent.lock(); + if *this.head == parent.hash() { + this.state.set(LinearStreamState::Done); + break + } + + match out.first().or(this.earliest.as_ref()) { + Some(header) => { + if !(parent.hash() == header.parent_hash && + parent.number + 1 == header.number) + { + return Poll::Pending + } + + if let Err(e) = this.consensus.validate_header(&header, &parent) + { + return Poll::Ready(Some(Err( + DownloadError::HeaderValidation { + hash: parent.hash(), + details: e.to_string(), + }, + ))) + } + } + None if parent.hash() != *this.tip => return Poll::Pending, + _ => (), + }; + + out.insert(0, parent); + } + *this.earliest = Some(out.first().unwrap().clone()); + Poll::Ready(Some(Ok(out))) + } + _ => Poll::Pending, + }, + LinearStreamState::Done => Poll::Ready(None), + } + } + + fn size_hint(&self) -> (usize, Option) { + (1, None) + } } } #[cfg(test)] mod tests { - use super::{super::stage::tests::utils, DownloadError, Downloader, LinearDownloader}; + use super::{super::stage::tests::test_utils, DownloadError, Downloader, LinearDownloader}; use assert_matches::assert_matches; use rand::{self, Rng}; use reth_interfaces::stages::HeaderRequest; @@ -150,8 +286,8 @@ mod tests { let runner = test_runner::LinearTestRunner::new(); let retries = runner.retries; let rx = runner.run( - utils::TestConsensus::new(), - utils::TestHeaderClient::new(req_tx, res_rx), + test_utils::TestConsensus::new(), + test_utils::TestHeaderClient::new(req_tx, res_rx), HeaderLocked::default(), H256::zero(), ); @@ -169,8 +305,8 @@ mod tests { let runner = test_runner::LinearTestRunner::new(); let retries = runner.retries; let rx = runner.run( - utils::TestConsensus::new(), - utils::TestHeaderClient::new(req_tx, res_rx), + test_utils::TestConsensus::new(), + test_utils::TestHeaderClient::new(req_tx, res_rx), HeaderLocked::default(), H256::zero(), ); @@ -214,13 +350,13 @@ mod tests { tip_header.parent_hash = parent_hash; let chain_tip = tip_header.hash_slow(); - let mut consensus = utils::TestConsensus::new(); + let mut consensus = test_utils::TestConsensus::new(); consensus.set_fail_validation(true); let runner = test_runner::LinearTestRunner::new(); let rx = runner.run( consensus, - utils::TestHeaderClient::new(req_tx, res_rx), + test_utils::TestHeaderClient::new(req_tx, res_rx), HeaderLocked::default(), chain_tip, ); @@ -259,8 +395,8 @@ mod tests { let runner = test_runner::LinearTestRunner::new(); let mut rx = runner.run( - utils::TestConsensus::new(), - utils::TestHeaderClient::new(req_tx, res_rx), + test_utils::TestConsensus::new(), + test_utils::TestHeaderClient::new(req_tx, res_rx), tip_parent.clone().lock(), tip.hash_slow(), ); @@ -287,6 +423,7 @@ mod tests { mod test_runner { use super::*; use reth_interfaces::{consensus::Consensus, stages::HeadersClient}; + use reth_rpc_types::engine::ForkchoiceState; use tokio::sync::oneshot; type DownloadResult = Result, DownloadError>; @@ -301,7 +438,7 @@ mod tests { Self { test_ch: oneshot::channel(), retries: 5 } } - pub(crate) fn run( + pub(crate) fn run<'a>( self, consensus: impl Consensus + 'static, client: impl HeadersClient + 'static, @@ -310,14 +447,16 @@ mod tests { ) -> oneshot::Receiver { let (tx, rx) = self.test_ch; let downloader = LinearDownloader { - consensus: Arc::new(consensus), - client: Arc::new(client), request_retries: self.retries, batch_size: 100, request_timeout: 3, }; tokio::spawn(async move { - let result = downloader.download(&head, tip).await; + let mut forkchoice = ForkchoiceState::default(); + forkchoice.head_block_hash = tip; + let result = downloader + .download(Arc::new(client), Arc::new(consensus), &head, &forkchoice) + .await; tx.send(result).expect("failed to forward download response"); }); rx diff --git a/crates/stages/src/stages/headers/mod.rs b/crates/stages/src/stages/headers/mod.rs index 8785c28188af..336ddaf474cc 100644 --- a/crates/stages/src/stages/headers/mod.rs +++ b/crates/stages/src/stages/headers/mod.rs @@ -1,5 +1,9 @@ /// The headers stage implementation pub mod stage; -/// The downloading strategies +/// The downloader trait +pub mod downloader; +/// The linear downloading strategy pub mod linear; +/// The parallel downloading strategy +pub mod parallel; diff --git a/crates/stages/src/stages/headers/parallel.rs b/crates/stages/src/stages/headers/parallel.rs new file mode 100644 index 000000000000..e5801e24497b --- /dev/null +++ b/crates/stages/src/stages/headers/parallel.rs @@ -0,0 +1,132 @@ +use super::downloader::{DownloadError, Downloader}; +use async_trait::async_trait; +use reth_interfaces::{ + consensus::Consensus, + stages::{HeadersClient, MessageStream}, +}; +use reth_primitives::{rpc::BlockId, Header, HeaderLocked, H256}; +use reth_rpc_types::engine::ForkchoiceState; +use std::sync::Arc; + +/// TODO: +#[derive(Debug)] +pub struct ParallelDownloader { + /// The number of parallel requests + pub par_count: usize, + /// The batch size per one request + pub batch_size: u64, + /// A single request timeout + pub request_timeout: u64, + /// The number of retries for downloading + pub request_retries: usize, +} + +#[async_trait] +impl Downloader for ParallelDownloader { + /// The request timeout + fn timeout(&self) -> u64 { + self.request_timeout + } + + /// Download the headers + async fn download( + &self, + client: Arc, + consensus: Arc, + head: &HeaderLocked, + forkchoice: &ForkchoiceState, + ) -> Result, DownloadError> { + let mut stream = client.stream_headers().await; + let mut retries = self.request_retries; + let mut reached_finalized = false; + todo!() + // // Header order will be preserved during inserts + // let mut out = Vec::::new(); + + // // Request blocks by hash until finalized hash + // loop { + // let result = self + // .download_batch( + // client.clone(), + // consensus.clone(), + // head, + // forkchoice, + // &mut stream, + // &mut out, + // ) + // .await; + // match result { + // Ok(result) => match result { + // ParallelResult::Discard | ParallelResult::Continue => (), + // ParallelResult::ReachedFinalized => reached_finalized = true, + // ParallelResult::ReachedHead => return Ok(out), + // }, + // Err(e) if e.is_retryable() && retries > 1 => { + // retries -= 1; + // } + // Err(e) => return Err(e), + // } + // } + } +} + +enum ParallelResult { + Continue, + Discard, + ReachedFinalized, + ReachedHead, +} + +impl ParallelDownloader { + async fn download_batch<'a>( + &'a self, + client: Arc, + consensus: Arc, + head: &'a HeaderLocked, + forkchoice: &'a ForkchoiceState, + stream: &'a mut MessageStream<(u64, Vec
)>, + out: &'a mut Vec, + ) -> Result { + // Request headers starting from tip or earliest cached + let start = out.first().map_or(forkchoice.head_block_hash, |h| h.parent_hash); + let mut headers = self + .download_headers(stream, client.clone(), BlockId::Hash(start), self.batch_size) + .await?; + headers.sort_unstable_by_key(|h| h.number); + + // Iterate the headers in reverse + out.reserve_exact(headers.len()); + let mut headers_rev = headers.into_iter().rev(); + + let mut result = ParallelResult::Continue; + while let Some(parent) = headers_rev.next() { + let parent = parent.lock(); + + if parent.hash() == head.hash() { + // We've reached the target + return Ok(ParallelResult::ReachedHead) + } + + match out.first() { + Some(tail_header) if !self.validate(consensus.clone(), tail_header, &parent)? => { + // Cannot attach to the current buffer, discard + return Ok(ParallelResult::Discard) + } + // The buffer is empty and the first header does not match the tip, discard + // TODO: penalize the peer? + None if parent.hash() != forkchoice.head_block_hash => { + return Ok(ParallelResult::Discard) + } + _ => (), + }; + + if parent.hash() == forkchoice.finalized_block_hash { + result = ParallelResult::ReachedFinalized; + } + + out.insert(0, parent); + } + + Ok(result) + } +} diff --git a/crates/stages/src/stages/headers/stage.rs b/crates/stages/src/stages/headers/stage.rs index c45791ccccd5..bb776b3cc2cc 100644 --- a/crates/stages/src/stages/headers/stage.rs +++ b/crates/stages/src/stages/headers/stage.rs @@ -1,3 +1,4 @@ +use super::downloader::{DownloadError, Downloader}; use crate::{ExecInput, ExecOutput, Stage, StageError, StageId, UnwindInput, UnwindOutput}; use async_trait::async_trait; use reth_db::{ @@ -5,7 +6,8 @@ use reth_db::{ mdbx::{self, WriteFlags}, }; use reth_interfaces::{consensus::Consensus, stages::HeadersClient}; -use reth_primitives::{BigEndianHash, BlockNumber, Header, HeaderLocked, H256, U256}; +use reth_primitives::{BigEndianHash, BlockNumber, HeaderLocked, H256, U256}; +use reth_rpc_types::engine::ForkchoiceState; use std::{fmt::Debug, sync::Arc}; use thiserror::Error; use tracing::*; @@ -14,57 +16,34 @@ const HEADERS: StageId = StageId("HEADERS"); /// The headers stage implementation for staged sync #[derive(Debug)] -pub struct HeaderStage { +pub struct HeaderStage { /// Strategy for downloading the headers - pub downloader: Arc, + pub downloader: D, /// Consensus client implementation pub consensus: Arc, /// Downloader client implementation pub client: Arc, } -/// The header downloading strategy -#[async_trait] -pub trait Downloader: Sync + Send + Debug { - /// Download the headers - async fn download( - &self, - latest: &HeaderLocked, - tip: H256, - ) -> Result, DownloadError>; -} - -/// The downloader error type +/// The header stage error #[derive(Error, Debug)] -pub enum DownloadError { - /// Header validation failed - #[error("Failed to validate header {hash}. Details: {details}.")] - HeaderValidation { - /// Hash of header failing validation - hash: H256, - /// The details of validation failure - details: String, - }, - /// No headers reponse received - #[error("Failed to get headers for request {request_id}.")] - NoHeaderResponse { - /// The last request ID - request_id: u64, - }, - /// The stage encountered an internal error. - #[error(transparent)] - Internal(Box), +pub enum HeaderStageError { + #[error("no cannonical hash for block #{number}")] + NoCannonicalHash { number: BlockNumber }, + #[error("no cannonical hash for block #{number}")] + NoCannonicalHeader { number: BlockNumber }, + #[error("no header for block #{number} ({hash})")] + NoHeader { number: BlockNumber, hash: H256 }, } -impl DownloadError { - /// Returns bool indicating whether this error is retryable or fatal - pub fn is_retryable(&self) -> bool { - matches!(self, DownloadError::NoHeaderResponse { .. }) +impl Into for HeaderStageError { + fn into(self) -> StageError { + StageError::Internal(anyhow::Error::new(self)) } } #[async_trait] -impl<'db, E> Stage<'db, E> for HeaderStage +impl<'db, E, D: Downloader> Stage<'db, E> for HeaderStage where E: mdbx::EnvironmentKind, { @@ -90,14 +69,23 @@ where // download the headers // TODO: check if some upper block constraint is necessary - let last_hash = tx.get::(last_block_num)?.unwrap(); - let last_header: Header = - tx.get::((last_block_num, last_hash).into())?.unwrap(); + let last_hash = + tx.get::(last_block_num)?.ok_or_else(|| -> StageError { + HeaderStageError::NoCannonicalHash { number: last_block_num }.into() + })?; + let last_header = tx + .get::((last_block_num, last_hash).into())? + .ok_or_else(|| -> StageError { + HeaderStageError::NoHeader { number: last_block_num, hash: last_hash }.into() + })?; let head = HeaderLocked::new(last_header, last_hash); - let forkchoice_state = self.next_forkchoice_state(&head.hash()).await; - - let headers = match self.downloader.download(&head, forkchoice_state).await { + let forkchoice = self.next_forkchoice_state(&head.hash()).await; + let headers = match self + .downloader + .download(self.client.clone(), self.consensus.clone(), &head, &forkchoice) + .await + { Ok(res) => res, Err(e) => match e { DownloadError::NoHeaderResponse { request_id } => { @@ -112,7 +100,6 @@ where warn!("validation error for header {hash}: {details}"); return Err(StageError::Validation { block: last_block_num }) } - DownloadError::Internal(e) => return Err(StageError::Internal(e)), }, }; @@ -125,7 +112,7 @@ where &mut self, tx: &mut Tx<'tx, mdbx::RW, E>, input: UnwindInput, - ) -> Result> { + ) -> Result { if let Some(bad_block) = input.bad_block { todo!() } @@ -170,25 +157,27 @@ where } } -impl HeaderStage { +impl HeaderStage { async fn update_head<'tx, E: mdbx::EnvironmentKind>( &self, tx: &'tx mut Tx<'_, mdbx::RW, E>, height: BlockNumber, ) -> Result<(), StageError> { - let hash = tx.get::(height)?.unwrap(); - let td: Vec = tx.get::((height, hash).into())?.unwrap(); + let hash = tx.get::(height)?.ok_or_else(|| -> StageError { + HeaderStageError::NoCannonicalHeader { number: height }.into() + })?; + let td: Vec = tx.get::((height, hash).into())?.unwrap(); // TODO: self.client.update_status(height, hash, H256::from_slice(&td)).await; Ok(()) } - async fn next_forkchoice_state(&self, head: &H256) -> H256 { + async fn next_forkchoice_state(&self, head: &H256) -> ForkchoiceState { let mut state_rcv = self.consensus.forkchoice_state(); loop { let _ = state_rcv.changed().await; let forkchoice = state_rcv.borrow(); if !forkchoice.head_block_hash.is_zero() && forkchoice.head_block_hash != *head { - return forkchoice.head_block_hash + return forkchoice.clone() } } } @@ -223,7 +212,7 @@ impl HeaderStage { cursor_canonical.put(key.0 .0, key.0 .1, Some(WriteFlags::APPEND))?; cursor_td.put( key, - H256::from_uint(&td).as_bytes().to_vec(), + H256::from_uint(&td).as_bytes().to_vec(), // TODO: Some(WriteFlags::APPEND), )?; } @@ -234,15 +223,78 @@ impl HeaderStage { #[cfg(test)] pub(crate) mod tests { - pub(crate) mod utils { + use super::*; + use crate::util::db::TxContainer; + use assert_matches::assert_matches; + use reth_db::{ + kv::{test_utils as test_db_utils, EnvKind}, + mdbx, + }; + use tokio::sync::{broadcast, mpsc}; + + #[tokio::test] + async fn headers_stage_empty_db() { + let (req_tx, _req_rx) = mpsc::channel(1); + let (_res_tx, res_rx) = broadcast::channel(1); + + let mut stage = HeaderStage { + client: Arc::new(test_utils::TestHeaderClient::new(req_tx, res_rx)), + consensus: Arc::new(test_utils::TestConsensus::new()), + downloader: test_utils::TestDownloader::new(Ok(vec![])), + }; + + let mut db = test_db_utils::create_test_db::(EnvKind::RW); + let mut tx = TxContainer::new(&mut db).unwrap(); + + let input = ExecInput { previous_stage: None, stage_progress: None }; + assert_matches!( + stage.execute(tx.get_mut(), input).await, + Err(StageError::Internal(err)) + if matches!( + err.downcast_ref::(), + Some(HeaderStageError::NoCannonicalHeader { .. } + ) + ) + ); + } + + #[tokio::test] + // TODO: + async fn headers_stage_() { + let (req_tx, _req_rx) = mpsc::channel(1); + let (_res_tx, res_rx) = broadcast::channel(1); + + let mut stage = HeaderStage { + client: Arc::new(test_utils::TestHeaderClient::new(req_tx, res_rx)), + consensus: Arc::new(test_utils::TestConsensus::new()), + downloader: test_utils::TestDownloader::new(Ok(vec![])), + }; + + let mut db = test_db_utils::create_test_db::(EnvKind::RW); + let mut tx = TxContainer::new(&mut db).unwrap(); + + let input = ExecInput { previous_stage: None, stage_progress: None }; + assert_matches!( + stage.execute(tx.get_mut(), input).await, + Err(StageError::Internal(err)) + if matches!( + err.downcast_ref::(), + Some(HeaderStageError::NoCannonicalHeader { .. } + ) + ) + ); + } + + pub(crate) mod test_utils { + use super::super::{DownloadError, Downloader}; use async_trait::async_trait; use reth_interfaces::{ consensus::{self, Consensus}, stages::{HeaderRequest, HeadersClient, MessageStream}, }; - use reth_primitives::{Header, H256, H512}; + use reth_primitives::{Header, HeaderLocked, H256, H512}; use reth_rpc_types::engine::ForkchoiceState; - use std::collections::HashSet; + use std::{collections::HashSet, sync::Arc}; use tokio::sync::{broadcast, mpsc::Sender, watch}; use tokio_stream::{wrappers::BroadcastStream, StreamExt}; @@ -342,5 +394,33 @@ pub(crate) mod tests { } } } + + #[derive(Debug)] + pub(crate) struct TestDownloader { + result: Result, DownloadError>, + } + + impl TestDownloader { + pub(crate) fn new(result: Result, DownloadError>) -> Self { + Self { result } + } + } + + #[async_trait] + impl Downloader for TestDownloader { + fn timeout(&self) -> u64 { + 1 + } + + async fn download( + &self, + _: Arc, + _: Arc, + _: &HeaderLocked, + _: &ForkchoiceState, + ) -> Result, DownloadError> { + self.result.clone() + } + } } } From b27f25ce1d4f194db117016a5fae1961bab69227 Mon Sep 17 00:00:00 2001 From: Georgios Konstantopoulos Date: Wed, 19 Oct 2022 18:37:12 -0700 Subject: [PATCH 10/13] chore: replace boxed iterator with vector --- crates/stages/src/stages/headers/linear.rs | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/crates/stages/src/stages/headers/linear.rs b/crates/stages/src/stages/headers/linear.rs index bba7a16b0fc0..795b57cbd07b 100644 --- a/crates/stages/src/stages/headers/linear.rs +++ b/crates/stages/src/stages/headers/linear.rs @@ -33,7 +33,6 @@ pub struct LinearDownloader { pub request_retries: usize, } -type HeaderIter = Box + Send>; #[async_trait] impl Downloader for LinearDownloader { /// The request timeout @@ -54,7 +53,7 @@ impl Downloader for LinearDownloader { let mut retries = self.request_retries; // Header order will be preserved during inserts - let mut out = (Box::new(std::iter::empty()) as HeaderIter).peekable(); + let mut out = vec![]; loop { let result = self .download_batch( @@ -63,17 +62,23 @@ impl Downloader for LinearDownloader { consensus.clone(), forkchoice, head, - out.peek(), + out.get(0), ) .await; match result { Ok(result) => match result { LinearDownloadResult::Batch(headers) => { - out = (Box::new(headers.into_iter().chain(out)) as HeaderIter).peekable(); + // TODO: Should this instead be? + // headers.extend_from_slice(&out); + // out = headers; + out.extend_from_slice(&headers); } LinearDownloadResult::Finished(headers) => { - out = (Box::new(headers.into_iter().chain(out)) as HeaderIter).peekable(); - return Ok(out.collect()) + // TODO: Should this instead be? + // headers.extend_from_slice(&out); + // out = headers; + out.extend_from_slice(&headers); + return Ok(out) } LinearDownloadResult::Ignore => (), }, From 30018c734e19d5c5f9e2b7cd0b52dd6558542818 Mon Sep 17 00:00:00 2001 From: Georgios Konstantopoulos Date: Wed, 19 Oct 2022 23:26:33 -0700 Subject: [PATCH 11/13] feat(headers): replace arc dyn trait with generic (#107) * feat(interfaces): auto impl for ref/arc/box * feat(downloader): make consensus part of the downloader and a generic * impl generic for linear dl * impl generic for parallel dl * test(headers): make it work with generics * chore: rm dead code Co-authored-by: Roman Krasiuk --- Cargo.lock | 3 +- crates/interfaces/src/consensus.rs | 1 + crates/interfaces/src/stages.rs | 1 + crates/stages/Cargo.toml | 5 +- .../stages/src/stages/headers/downloader.rs | 11 +- crates/stages/src/stages/headers/linear.rs | 184 +++--------------- crates/stages/src/stages/headers/parallel.rs | 20 +- crates/stages/src/stages/headers/stage.rs | 22 ++- 8 files changed, 69 insertions(+), 178 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index dac550cfba2a..1c98f0903c6e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2500,8 +2500,7 @@ dependencies = [ "assert_matches", "async-trait", "futures", - "pin-project", - "pin-project-lite", + "once_cell", "rand", "reth-db", "reth-interfaces", diff --git a/crates/interfaces/src/consensus.rs b/crates/interfaces/src/consensus.rs index 0e01b97565da..9a66daa7d314 100644 --- a/crates/interfaces/src/consensus.rs +++ b/crates/interfaces/src/consensus.rs @@ -8,6 +8,7 @@ use tokio::sync::watch::Receiver; /// Consensus is a protocol that chooses canonical chain. /// We are checking validity of block header here. #[async_trait] +#[auto_impl::auto_impl(&, Arc, Box)] pub trait Consensus: Sync + Send + Debug { /// Get a receiver for the fork choice state fn forkchoice_state(&self) -> Receiver; diff --git a/crates/interfaces/src/stages.rs b/crates/interfaces/src/stages.rs index 60d9d54766c0..ee7722b55029 100644 --- a/crates/interfaces/src/stages.rs +++ b/crates/interfaces/src/stages.rs @@ -20,6 +20,7 @@ pub struct HeaderRequest { /// The block headers downloader client #[async_trait] +#[auto_impl::auto_impl(&, Arc, Box)] pub trait HeadersClient: Send + Sync + Debug { /// Update the current node status async fn update_status(&self, height: u64, hash: H256, td: H256); diff --git a/crates/stages/Cargo.toml b/crates/stages/Cargo.toml index b62512348f2a..e5c646455bb2 100644 --- a/crates/stages/Cargo.toml +++ b/crates/stages/Cargo.toml @@ -17,8 +17,6 @@ tracing-futures = "0.2.5" tokio = { version = "1.21.2", features = ["sync"] } rand = "0.8" # TODO: tokio-stream = "0.1.11" -pin-project = "1.0.12" -pin-project-lite = "0.2" aquamarine = "0.1.12" # async/futures @@ -35,4 +33,5 @@ tokio-stream = { version = "0.1.11", features = ["sync"] } tempfile = "3.3.0" reth-db = { path = "../db", features = ["test-utils"] } reth-rpc-types = { path = "../net/rpc-types" } -assert_matches = "1.5.0" \ No newline at end of file +assert_matches = "1.5.0" +once_cell = "1.15.0" diff --git a/crates/stages/src/stages/headers/downloader.rs b/crates/stages/src/stages/headers/downloader.rs index d475d3678b2d..a317fb72a97a 100644 --- a/crates/stages/src/stages/headers/downloader.rs +++ b/crates/stages/src/stages/headers/downloader.rs @@ -13,14 +13,20 @@ use tokio_stream::StreamExt; /// The header downloading strategy #[async_trait] pub trait Downloader: Sync + Send + Debug { + /// The Consensus used to verify block validity when + /// downloading + type Consensus: Consensus; + /// The request timeout in seconds fn timeout(&self) -> u64; + /// The consensus engine + fn consensus(&self) -> &Self::Consensus; + /// Download the headers async fn download( &self, client: Arc, - consensus: Arc, head: &HeaderLocked, forkchoice: &ForkchoiceState, ) -> Result, DownloadError>; @@ -51,7 +57,6 @@ pub trait Downloader: Sync + Send + Debug { /// Validate whether the header is valid in relation to it's parent fn validate( &self, - consensus: Arc, header: &HeaderLocked, parent: &HeaderLocked, ) -> Result { @@ -59,7 +64,7 @@ pub trait Downloader: Sync + Send + Debug { return Ok(false) } - consensus.validate_header(&header, &parent).map_err(|e| { + self.consensus().validate_header(&header, &parent).map_err(|e| { DownloadError::HeaderValidation { hash: parent.hash(), details: e.to_string() } })?; Ok(true) diff --git a/crates/stages/src/stages/headers/linear.rs b/crates/stages/src/stages/headers/linear.rs index 795b57cbd07b..9b8a76c09e5b 100644 --- a/crates/stages/src/stages/headers/linear.rs +++ b/crates/stages/src/stages/headers/linear.rs @@ -1,8 +1,6 @@ use super::downloader::{DownloadError, Downloader}; use async_trait::async_trait; use futures::{future::BoxFuture, stream::BoxStream, Future, FutureExt}; -use pin_project::pin_project; -use pin_project_lite::pin_project as pin_project_lite; use rand::Rng; use reth_interfaces::{ consensus::Consensus, @@ -24,7 +22,8 @@ use tokio_stream::Stream; /// Download headers in batches #[derive(Debug)] -pub struct LinearDownloader { +pub struct LinearDownloader<'a, C: Consensus> { + consensus: &'a C, /// The batch size per one request pub batch_size: u64, /// A single request timeout @@ -34,7 +33,13 @@ pub struct LinearDownloader { } #[async_trait] -impl Downloader for LinearDownloader { +impl<'a, C: Consensus> Downloader for LinearDownloader<'a, C> { + type Consensus = C; + + fn consensus(&self) -> &Self::Consensus { + self.consensus + } + /// The request timeout fn timeout(&self) -> u64 { self.request_timeout @@ -45,7 +50,6 @@ impl Downloader for LinearDownloader { async fn download( &self, client: Arc, - consensus: Arc, head: &HeaderLocked, forkchoice: &ForkchoiceState, ) -> Result, DownloadError> { @@ -56,14 +60,7 @@ impl Downloader for LinearDownloader { let mut out = vec![]; loop { let result = self - .download_batch( - &mut stream, - client.clone(), - consensus.clone(), - forkchoice, - head, - out.get(0), - ) + .download_batch(&mut stream, client.clone(), forkchoice, head, out.get(0)) .await; match result { Ok(result) => match result { @@ -102,12 +99,11 @@ pub enum LinearDownloadResult { Ignore, } -impl LinearDownloader { - async fn download_batch<'a>( +impl<'a, C: Consensus> LinearDownloader<'a, C> { + async fn download_batch( &'a self, stream: &'a mut MessageStream<(u64, Vec
)>, client: Arc, - consensus: Arc, forkchoice: &'a ForkchoiceState, head: &'a HeaderLocked, earliest: Option<&HeaderLocked>, @@ -130,7 +126,7 @@ impl LinearDownloader { } match out.first().or(earliest) { - Some(header) if !self.validate(consensus.clone(), header, &parent)? => { + Some(header) if !self.validate(header, &parent)? => { return Ok(LinearDownloadResult::Ignore) } // The buffer is empty and the first header does not match the tip, discard @@ -148,130 +144,6 @@ impl LinearDownloader { } } -mod linear_stream { - use super::*; - - pin_project_lite! { - pub(crate) struct LinearDownloadStream<'a, S: Stream)>> { - #[pin] - stream: &'a mut S, - #[pin] - state: LinearStreamState<'a>, - client: &'a Arc, - consensus: &'a Arc, - tip: H256, - head: H256, - earliest: Option, - retries: usize, - } - } - - impl<'a, S: Stream)> + Unpin> LinearDownloadStream<'a, S> { - pub(crate) fn new( - stream: &'a mut S, - client: &'a Arc, - consensus: &'a Arc, - tip: H256, - head: H256, - retries: usize, - ) -> Self { - Self { - stream, - state: LinearStreamState::Prepare, - tip, - head, - client, - consensus, - earliest: None, - retries, - } - } - } - - enum LinearStreamState<'a> { - Prepare, - PollRequest(u64, BoxFuture<'a, HashSet>), - PollHeaders(u64), - Done, - } - - impl<'a, S: Stream)> + Unpin + Send> Stream - for LinearDownloadStream<'a, S> - { - type Item = Result, DownloadError>; - - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let mut this = self.project(); - match *this.state { - LinearStreamState::Prepare => { - let request_id = rand::thread_rng().gen(); - let request = - HeaderRequest { start: BlockId::Hash(*this.tip), limit: 1, reverse: true }; - this.state.set(LinearStreamState::PollRequest( - request_id, - this.client.send_header_request(request_id, request), - )); - Poll::Pending - } - LinearStreamState::PollRequest(req_id, ref mut fut) => match fut.poll_unpin(cx) { - Poll::Ready(_peers) => { - this.state.set(LinearStreamState::PollHeaders(req_id)); - Poll::Pending - } - Poll::Pending => Poll::Pending, - }, - LinearStreamState::PollHeaders(req_id) => match this.stream.poll_next(cx) { - Poll::Ready(Some((id, mut headers))) if id == req_id && !headers.is_empty() => { - headers.sort_unstable_by_key(|h| h.number); - - this.state.set(LinearStreamState::Prepare); - let mut out = Vec::with_capacity(headers.len()); - for parent in headers.into_iter().rev() { - let parent = parent.lock(); - if *this.head == parent.hash() { - this.state.set(LinearStreamState::Done); - break - } - - match out.first().or(this.earliest.as_ref()) { - Some(header) => { - if !(parent.hash() == header.parent_hash && - parent.number + 1 == header.number) - { - return Poll::Pending - } - - if let Err(e) = this.consensus.validate_header(&header, &parent) - { - return Poll::Ready(Some(Err( - DownloadError::HeaderValidation { - hash: parent.hash(), - details: e.to_string(), - }, - ))) - } - } - None if parent.hash() != *this.tip => return Poll::Pending, - _ => (), - }; - - out.insert(0, parent); - } - *this.earliest = Some(out.first().unwrap().clone()); - Poll::Ready(Some(Ok(out))) - } - _ => Poll::Pending, - }, - LinearStreamState::Done => Poll::Ready(None), - } - } - - fn size_hint(&self) -> (usize, Option) { - (1, None) - } - } -} - #[cfg(test)] mod tests { use super::{super::stage::tests::test_utils, DownloadError, Downloader, LinearDownloader}; @@ -283,6 +155,16 @@ mod tests { use tokio::sync::{broadcast, mpsc, oneshot::error::TryRecvError}; use tokio_stream::{wrappers::ReceiverStream, StreamExt}; + use once_cell::sync::Lazy; + use test_utils::TestConsensus; + + static CONSENSUS: Lazy = Lazy::new(|| TestConsensus::new()); + static CONSENSUS_FAIL: Lazy = Lazy::new(|| { + let mut consensus = TestConsensus::new(); + consensus.set_fail_validation(true); + consensus + }); + #[tokio::test] async fn download_timeout() { let (req_tx, req_rx) = mpsc::channel(1); @@ -291,7 +173,7 @@ mod tests { let runner = test_runner::LinearTestRunner::new(); let retries = runner.retries; let rx = runner.run( - test_utils::TestConsensus::new(), + &*CONSENSUS, test_utils::TestHeaderClient::new(req_tx, res_rx), HeaderLocked::default(), H256::zero(), @@ -310,7 +192,7 @@ mod tests { let runner = test_runner::LinearTestRunner::new(); let retries = runner.retries; let rx = runner.run( - test_utils::TestConsensus::new(), + &*CONSENSUS, test_utils::TestHeaderClient::new(req_tx, res_rx), HeaderLocked::default(), H256::zero(), @@ -355,12 +237,9 @@ mod tests { tip_header.parent_hash = parent_hash; let chain_tip = tip_header.hash_slow(); - let mut consensus = test_utils::TestConsensus::new(); - consensus.set_fail_validation(true); - let runner = test_runner::LinearTestRunner::new(); let rx = runner.run( - consensus, + &*CONSENSUS_FAIL, test_utils::TestHeaderClient::new(req_tx, res_rx), HeaderLocked::default(), chain_tip, @@ -400,7 +279,7 @@ mod tests { let runner = test_runner::LinearTestRunner::new(); let mut rx = runner.run( - test_utils::TestConsensus::new(), + &*CONSENSUS, test_utils::TestHeaderClient::new(req_tx, res_rx), tip_parent.clone().lock(), tip.hash_slow(), @@ -443,9 +322,9 @@ mod tests { Self { test_ch: oneshot::channel(), retries: 5 } } - pub(crate) fn run<'a>( + pub(crate) fn run<'a, C: Consensus>( self, - consensus: impl Consensus + 'static, + consensus: &'static C, client: impl HeadersClient + 'static, head: HeaderLocked, tip: H256, @@ -454,14 +333,13 @@ mod tests { let downloader = LinearDownloader { request_retries: self.retries, batch_size: 100, + consensus, request_timeout: 3, }; tokio::spawn(async move { let mut forkchoice = ForkchoiceState::default(); forkchoice.head_block_hash = tip; - let result = downloader - .download(Arc::new(client), Arc::new(consensus), &head, &forkchoice) - .await; + let result = downloader.download(Arc::new(client), &head, &forkchoice).await; tx.send(result).expect("failed to forward download response"); }); rx diff --git a/crates/stages/src/stages/headers/parallel.rs b/crates/stages/src/stages/headers/parallel.rs index e5801e24497b..e045b1783610 100644 --- a/crates/stages/src/stages/headers/parallel.rs +++ b/crates/stages/src/stages/headers/parallel.rs @@ -10,7 +10,9 @@ use std::sync::Arc; /// TODO: #[derive(Debug)] -pub struct ParallelDownloader { +pub struct ParallelDownloader<'a, C> { + /// The consensus engine + pub consensus: &'a C, /// The number of parallel requests pub par_count: usize, /// The batch size per one request @@ -22,17 +24,22 @@ pub struct ParallelDownloader { } #[async_trait] -impl Downloader for ParallelDownloader { +impl<'a, C: Consensus> Downloader for ParallelDownloader<'a, C> { + type Consensus = C; + /// The request timeout fn timeout(&self) -> u64 { self.request_timeout } + fn consensus(&self) -> &Self::Consensus { + self.consensus + } + /// Download the headers async fn download( &self, client: Arc, - consensus: Arc, head: &HeaderLocked, forkchoice: &ForkchoiceState, ) -> Result, DownloadError> { @@ -77,11 +84,10 @@ enum ParallelResult { ReachedHead, } -impl ParallelDownloader { - async fn download_batch<'a>( +impl<'a, C: Consensus> ParallelDownloader<'a, C> { + async fn download_batch( &'a self, client: Arc, - consensus: Arc, head: &'a HeaderLocked, forkchoice: &'a ForkchoiceState, stream: &'a mut MessageStream<(u64, Vec
)>, @@ -108,7 +114,7 @@ impl ParallelDownloader { } match out.first() { - Some(tail_header) if !self.validate(consensus.clone(), tail_header, &parent)? => { + Some(tail_header) if !self.validate(tail_header, &parent)? => { // Cannot attach to the current buffer, discard return Ok(ParallelResult::Discard) } diff --git a/crates/stages/src/stages/headers/stage.rs b/crates/stages/src/stages/headers/stage.rs index bb776b3cc2cc..cdf3c0eb9927 100644 --- a/crates/stages/src/stages/headers/stage.rs +++ b/crates/stages/src/stages/headers/stage.rs @@ -16,11 +16,11 @@ const HEADERS: StageId = StageId("HEADERS"); /// The headers stage implementation for staged sync #[derive(Debug)] -pub struct HeaderStage { +pub struct HeaderStage { /// Strategy for downloading the headers pub downloader: D, /// Consensus client implementation - pub consensus: Arc, + pub consensus: C, /// Downloader client implementation pub client: Arc, } @@ -43,7 +43,7 @@ impl Into for HeaderStageError { } #[async_trait] -impl<'db, E, D: Downloader> Stage<'db, E> for HeaderStage +impl<'db, E, D: Downloader, C: Consensus> Stage<'db, E> for HeaderStage where E: mdbx::EnvironmentKind, { @@ -81,10 +81,7 @@ where let head = HeaderLocked::new(last_header, last_hash); let forkchoice = self.next_forkchoice_state(&head.hash()).await; - let headers = match self - .downloader - .download(self.client.clone(), self.consensus.clone(), &head, &forkchoice) - .await + let headers = match self.downloader.download(self.client.clone(), &head, &forkchoice).await { Ok(res) => res, Err(e) => match e { @@ -157,7 +154,7 @@ where } } -impl HeaderStage { +impl HeaderStage { async fn update_head<'tx, E: mdbx::EnvironmentKind>( &self, tx: &'tx mut Tx<'_, mdbx::RW, E>, @@ -266,7 +263,7 @@ pub(crate) mod tests { let mut stage = HeaderStage { client: Arc::new(test_utils::TestHeaderClient::new(req_tx, res_rx)), - consensus: Arc::new(test_utils::TestConsensus::new()), + consensus: &test_utils::TestConsensus::new(), downloader: test_utils::TestDownloader::new(Ok(vec![])), }; @@ -408,14 +405,19 @@ pub(crate) mod tests { #[async_trait] impl Downloader for TestDownloader { + type Consensus = TestConsensus; + fn timeout(&self) -> u64 { 1 } + fn consensus(&self) -> &Self::Consensus { + unimplemented!() + } + async fn download( &self, _: Arc, - _: Arc, _: &HeaderLocked, _: &ForkchoiceState, ) -> Result, DownloadError> { From 930f5bd03176f95f015fc3e9130589a058621c23 Mon Sep 17 00:00:00 2001 From: Roman Krasiuk Date: Thu, 20 Oct 2022 12:31:06 +0300 Subject: [PATCH 12/13] replace dyn client with generic --- Cargo.lock | 147 +++++++++----- .../stages/src/stages/headers/downloader.rs | 9 +- crates/stages/src/stages/headers/linear.rs | 179 +++++++----------- crates/stages/src/stages/headers/mod.rs | 2 - crates/stages/src/stages/headers/parallel.rs | 138 -------------- crates/stages/src/stages/headers/stage.rs | 104 +++++----- 6 files changed, 231 insertions(+), 348 deletions(-) delete mode 100644 crates/stages/src/stages/headers/parallel.rs diff --git a/Cargo.lock b/Cargo.lock index 1c98f0903c6e..cbec6d64e5ba 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -87,9 +87,9 @@ dependencies = [ [[package]] name = "async-trait" -version = "0.1.57" +version = "0.1.58" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "76464446b8bc32758d7e88ee1a804d9914cd9b1cb264c029899680b0be29826f" +checksum = "1e805d94e6b5001b651426cf4cd446b1ab5f319d27bab5c644f61de0a804360c" dependencies = [ "proc-macro2", "quote", @@ -163,9 +163,9 @@ checksum = "904dfeac50f3cdaba28fc6f57fdcddb75f49ed61346676a78c4ffe55877802fd" [[package]] name = "base64ct" -version = "1.5.2" +version = "1.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ea2b2456fd614d856680dcd9fcc660a51a820fa09daef2e49772b56a193c8474" +checksum = "b645a089122eccb6111b4f81cbc1a49f5900ac4666bb93ac027feaecf15607bf" [[package]] name = "beef" @@ -290,9 +290,9 @@ checksum = "b4ae4235e6dac0694637c763029ecea1a2ec9e4e06ec2729bd21ba4d9c863eb7" [[package]] name = "bumpalo" -version = "3.11.0" +version = "3.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c1ad822118d20d2c234f427000d5acc36eabe1e29a348c89b63dd60b13f28e5d" +checksum = "572f695136211188308f16ad2ca5c851a712c464060ae6974944458eb83880ba" [[package]] name = "byte-slice-cast" @@ -414,9 +414,9 @@ dependencies = [ [[package]] name = "clap" -version = "4.0.15" +version = "4.0.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6bf8832993da70a4c6d13c581f4463c2bdda27b9bf1c5498dc4365543abe6d6f" +checksum = "06badb543e734a2d6568e19a40af66ed5364360b9226184926f89d229b4b4267" dependencies = [ "atty", "bitflags", @@ -767,7 +767,7 @@ dependencies = [ [[package]] name = "enr" version = "0.6.2" -source = "git+https://github.com/sigp/enr#125f8a5f2deede3e47e852ea70fc9b6e1e9c6e50" +source = "git+https://github.com/sigp/enr#fba51d4473f1b6fcc66161cd593352b70995e702" dependencies = [ "base64", "bs58", @@ -848,7 +848,7 @@ dependencies = [ [[package]] name = "ethers-core" version = "0.17.0" -source = "git+https://github.com/gakonst/ethers-rs#a07581489a12b1007c3a261dc8df2bbdc4e27918" +source = "git+https://github.com/gakonst/ethers-rs#a9dd53da810d8eff82aa77e0f9297b4a453028e6" dependencies = [ "arrayvec", "bytes", @@ -963,9 +963,9 @@ checksum = "e6d5a32815ae3f33302d95fdcb2ce17862f8c65363dcfd29360480ba1001fc9c" [[package]] name = "futures" -version = "0.3.24" +version = "0.3.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7f21eda599937fba36daeb58a22e8f5cee2d14c4a17b5b7739c7c8e5e3b8230c" +checksum = "38390104763dc37a5145a53c29c63c1290b5d316d6086ec32c293f6736051bb0" dependencies = [ "futures-channel", "futures-core", @@ -978,9 +978,9 @@ dependencies = [ [[package]] name = "futures-channel" -version = "0.3.24" +version = "0.3.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "30bdd20c28fadd505d0fd6712cdfcb0d4b5648baf45faef7f852afb2399bb050" +checksum = "52ba265a92256105f45b719605a571ffe2d1f0fea3807304b522c1d778f79eed" dependencies = [ "futures-core", "futures-sink", @@ -988,15 +988,15 @@ dependencies = [ [[package]] name = "futures-core" -version = "0.3.24" +version = "0.3.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4e5aa3de05362c3fb88de6531e6296e85cde7739cccad4b9dfeeb7f6ebce56bf" +checksum = "04909a7a7e4633ae6c4a9ab280aeb86da1236243a77b694a49eacd659a4bd3ac" [[package]] name = "futures-executor" -version = "0.3.24" +version = "0.3.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9ff63c23854bee61b6e9cd331d523909f238fc7636290b96826e9cfa5faa00ab" +checksum = "7acc85df6714c176ab5edf386123fafe217be88c0840ec11f199441134a074e2" dependencies = [ "futures-core", "futures-task", @@ -1005,15 +1005,15 @@ dependencies = [ [[package]] name = "futures-io" -version = "0.3.24" +version = "0.3.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bbf4d2a7a308fd4578637c0b17c7e1c7ba127b8f6ba00b29f717e9655d85eb68" +checksum = "00f5fb52a06bdcadeb54e8d3671f8888a39697dcb0b81b23b55174030427f4eb" [[package]] name = "futures-macro" -version = "0.3.24" +version = "0.3.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "42cd15d1c7456c04dbdf7e88bcd69760d74f3a798d6444e16974b505b0e62f17" +checksum = "bdfb8ce053d86b91919aad980c220b1fb8401a9394410e1c289ed7e66b61835d" dependencies = [ "proc-macro2", "quote", @@ -1022,15 +1022,15 @@ dependencies = [ [[package]] name = "futures-sink" -version = "0.3.24" +version = "0.3.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "21b20ba5a92e727ba30e72834706623d94ac93a725410b6a6b6fbc1b07f7ba56" +checksum = "39c15cf1a4aa79df40f1bb462fb39676d0ad9e366c2a33b590d7c66f4f81fcf9" [[package]] name = "futures-task" -version = "0.3.24" +version = "0.3.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a6508c467c73851293f390476d4491cf4d227dbabcd4170f3bb6044959b294f1" +checksum = "2ffb393ac5d9a6eaa9d3fdf37ae2776656b706e200c8e16b1bdb227f5198e6ea" [[package]] name = "futures-timer" @@ -1044,9 +1044,9 @@ dependencies = [ [[package]] name = "futures-util" -version = "0.3.24" +version = "0.3.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "44fb6cb1be61cc1d2e43b262516aafcf63b241cffdb1d3fa115f91d9c7b09c90" +checksum = "197676987abd2f9cadff84926f410af1c183608d36641465df73ae8211dc65d6" dependencies = [ "futures-channel", "futures-core", @@ -1749,7 +1749,7 @@ dependencies = [ "libc", "log", "wasi", - "windows-sys", + "windows-sys 0.36.1", ] [[package]] @@ -1963,15 +1963,15 @@ dependencies = [ [[package]] name = "parking_lot_core" -version = "0.9.3" +version = "0.9.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "09a279cbf25cb0757810394fbc1e359949b59e348145c643a939a525692e6929" +checksum = "4dc9e0dc2adc1c69d09143aff38d3d30c5c3f0df0dad82e6d25547af174ebec0" dependencies = [ "cfg-if", "libc", "redox_syscall", "smallvec", - "windows-sys", + "windows-sys 0.42.0", ] [[package]] @@ -2280,7 +2280,7 @@ dependencies = [ name = "reth" version = "0.1.0" dependencies = [ - "clap 4.0.15", + "clap 4.0.17", "eyre", "reth-primitives", "serde", @@ -2688,9 +2688,9 @@ dependencies = [ [[package]] name = "rustls" -version = "0.20.6" +version = "0.20.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5aab8ee6c7097ed6057f43c187a62418d0c05a4bd5f18b3571db50ee0f9ce033" +checksum = "539a2bfe908f471bfa933876bd1eb6a19cf2176d375f82ef7f99530a40e48c2c" dependencies = [ "log", "ring", @@ -2783,7 +2783,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "88d6731146462ea25d9244b2ed5fd1d716d25c52e4d54aa4fb0f3c4e9854dbe2" dependencies = [ "lazy_static", - "windows-sys", + "windows-sys 0.36.1", ] [[package]] @@ -2907,9 +2907,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.86" +version = "1.0.87" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "41feea4228a6f1cd09ec7a3593a682276702cd67b5273544757dae23c096f074" +checksum = "6ce777b7b150d76b9cf60d28b55f5847135a003f7d7350c6be7a773508ce7d45" dependencies = [ "itoa", "ryu", @@ -2942,9 +2942,9 @@ dependencies = [ [[package]] name = "sha3" -version = "0.10.5" +version = "0.10.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e2904bea16a1ae962b483322a1c7b81d976029203aea1f461e51cd7705db7ba9" +checksum = "bdf0c33fae925bdc080598b84bc15c55e7b9a4a43b3c704da051f977469691c9" dependencies = [ "digest 0.10.5", "keccak", @@ -3558,43 +3558,100 @@ version = "0.36.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ea04155a16a59f9eab786fe12a4a450e75cdb175f9e0d80da1e17db09f55b8d2" dependencies = [ - "windows_aarch64_msvc", - "windows_i686_gnu", - "windows_i686_msvc", - "windows_x86_64_gnu", - "windows_x86_64_msvc", + "windows_aarch64_msvc 0.36.1", + "windows_i686_gnu 0.36.1", + "windows_i686_msvc 0.36.1", + "windows_x86_64_gnu 0.36.1", + "windows_x86_64_msvc 0.36.1", ] +[[package]] +name = "windows-sys" +version = "0.42.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a3e1820f08b8513f676f7ab6c1f99ff312fb97b553d30ff4dd86f9f15728aa7" +dependencies = [ + "windows_aarch64_gnullvm", + "windows_aarch64_msvc 0.42.0", + "windows_i686_gnu 0.42.0", + "windows_i686_msvc 0.42.0", + "windows_x86_64_gnu 0.42.0", + "windows_x86_64_gnullvm", + "windows_x86_64_msvc 0.42.0", +] + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.42.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41d2aa71f6f0cbe00ae5167d90ef3cfe66527d6f613ca78ac8024c3ccab9a19e" + [[package]] name = "windows_aarch64_msvc" version = "0.36.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9bb8c3fd39ade2d67e9874ac4f3db21f0d710bee00fe7cab16949ec184eeaa47" +[[package]] +name = "windows_aarch64_msvc" +version = "0.42.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dd0f252f5a35cac83d6311b2e795981f5ee6e67eb1f9a7f64eb4500fbc4dcdb4" + [[package]] name = "windows_i686_gnu" version = "0.36.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "180e6ccf01daf4c426b846dfc66db1fc518f074baa793aa7d9b9aaeffad6a3b6" +[[package]] +name = "windows_i686_gnu" +version = "0.42.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fbeae19f6716841636c28d695375df17562ca208b2b7d0dc47635a50ae6c5de7" + [[package]] name = "windows_i686_msvc" version = "0.36.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e2e7917148b2812d1eeafaeb22a97e4813dfa60a3f8f78ebe204bcc88f12f024" +[[package]] +name = "windows_i686_msvc" +version = "0.42.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "84c12f65daa39dd2babe6e442988fc329d6243fdce47d7d2d155b8d874862246" + [[package]] name = "windows_x86_64_gnu" version = "0.36.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4dcd171b8776c41b97521e5da127a2d86ad280114807d0b2ab1e462bc764d9e1" +[[package]] +name = "windows_x86_64_gnu" +version = "0.42.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bf7b1b21b5362cbc318f686150e5bcea75ecedc74dd157d874d754a2ca44b0ed" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.42.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09d525d2ba30eeb3297665bd434a54297e4170c7f1a44cad4ef58095b4cd2028" + [[package]] name = "windows_x86_64_msvc" version = "0.36.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c811ca4a8c853ef420abd8592ba53ddbbac90410fab6903b3e79972a631f7680" +[[package]] +name = "windows_x86_64_msvc" +version = "0.42.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f40009d85759725a34da6d89a94e63d7bdc50a862acf0dbc7c8e488f1edcb6f5" + [[package]] name = "wyz" version = "0.5.0" diff --git a/crates/stages/src/stages/headers/downloader.rs b/crates/stages/src/stages/headers/downloader.rs index a317fb72a97a..70fa12c61f25 100644 --- a/crates/stages/src/stages/headers/downloader.rs +++ b/crates/stages/src/stages/headers/downloader.rs @@ -16,6 +16,8 @@ pub trait Downloader: Sync + Send + Debug { /// The Consensus used to verify block validity when /// downloading type Consensus: Consensus; + /// The Client used to download the headers + type Client: HeadersClient; /// The request timeout in seconds fn timeout(&self) -> u64; @@ -23,10 +25,12 @@ pub trait Downloader: Sync + Send + Debug { /// The consensus engine fn consensus(&self) -> &Self::Consensus; + /// The headers client + fn client(&self) -> &Self::Client; + /// Download the headers async fn download( &self, - client: Arc, head: &HeaderLocked, forkchoice: &ForkchoiceState, ) -> Result, DownloadError>; @@ -35,13 +39,12 @@ pub trait Downloader: Sync + Send + Debug { async fn download_headers( &self, stream: &mut MessageStream<(u64, Vec
)>, - client: Arc, start: BlockId, limit: u64, ) -> Result, DownloadError> { let request_id = rand::thread_rng().gen(); let request = HeaderRequest { start, limit, reverse: true }; - let _ = client.send_header_request(request_id, request).await; + let _ = self.client().send_header_request(request_id, request).await; // Filter stream by request id and non empty headers content let stream = stream.filter(|(id, headers)| request_id == *id && !headers.is_empty()); diff --git a/crates/stages/src/stages/headers/linear.rs b/crates/stages/src/stages/headers/linear.rs index 9b8a76c09e5b..62a0936ad4b1 100644 --- a/crates/stages/src/stages/headers/linear.rs +++ b/crates/stages/src/stages/headers/linear.rs @@ -1,29 +1,19 @@ use super::downloader::{DownloadError, Downloader}; use async_trait::async_trait; -use futures::{future::BoxFuture, stream::BoxStream, Future, FutureExt}; -use rand::Rng; use reth_interfaces::{ consensus::Consensus, - stages::{HeaderRequest, HeadersClient, MessageStream}, + stages::{HeadersClient, MessageStream}, }; -use reth_primitives::{rpc::BlockId, Header, HeaderLocked, H256, H512}; +use reth_primitives::{rpc::BlockId, Header, HeaderLocked}; use reth_rpc_types::engine::ForkchoiceState; -use std::{ - collections::HashSet, - ops::DerefMut, - pin::Pin, - process::Output, - sync::{Arc, Mutex}, - task::{Context, Poll}, - time::Duration, -}; -use tokio::time::{Instant, Sleep}; -use tokio_stream::Stream; /// Download headers in batches #[derive(Debug)] -pub struct LinearDownloader<'a, C: Consensus> { +pub struct LinearDownloader<'a, C: Consensus, H: HeadersClient> { + /// The consensus client consensus: &'a C, + /// The headers client + client: &'a H, /// The batch size per one request pub batch_size: u64, /// A single request timeout @@ -33,13 +23,18 @@ pub struct LinearDownloader<'a, C: Consensus> { } #[async_trait] -impl<'a, C: Consensus> Downloader for LinearDownloader<'a, C> { +impl<'a, C: Consensus, H: HeadersClient> Downloader for LinearDownloader<'a, C, H> { type Consensus = C; + type Client = H; fn consensus(&self) -> &Self::Consensus { self.consensus } + fn client(&self) -> &Self::Client { + self.client + } + /// The request timeout fn timeout(&self) -> u64 { self.request_timeout @@ -49,19 +44,16 @@ impl<'a, C: Consensus> Downloader for LinearDownloader<'a, C> { /// Returns the header collection in sorted ascending order async fn download( &self, - client: Arc, head: &HeaderLocked, forkchoice: &ForkchoiceState, ) -> Result, DownloadError> { - let mut stream = client.stream_headers().await; + let mut stream = self.client().stream_headers().await; let mut retries = self.request_retries; // Header order will be preserved during inserts let mut out = vec![]; loop { - let result = self - .download_batch(&mut stream, client.clone(), forkchoice, head, out.get(0)) - .await; + let result = self.download_batch(&mut stream, forkchoice, head, out.get(0)).await; match result { Ok(result) => match result { LinearDownloadResult::Batch(headers) => { @@ -99,20 +91,18 @@ pub enum LinearDownloadResult { Ignore, } -impl<'a, C: Consensus> LinearDownloader<'a, C> { +impl<'a, C: Consensus, H: HeadersClient> LinearDownloader<'a, C, H> { async fn download_batch( &'a self, stream: &'a mut MessageStream<(u64, Vec
)>, - client: Arc, forkchoice: &'a ForkchoiceState, head: &'a HeaderLocked, earliest: Option<&HeaderLocked>, ) -> Result { // Request headers starting from tip or earliest cached let start = earliest.map_or(forkchoice.head_block_hash, |h| h.parent_hash); - let mut headers = self - .download_headers(stream, client.clone(), BlockId::Hash(start), self.batch_size) - .await?; + let mut headers = + self.download_headers(stream, BlockId::Hash(start), self.batch_size).await?; headers.sort_unstable_by_key(|h| h.number); let mut out = Vec::with_capacity(headers.len()); @@ -146,17 +136,16 @@ impl<'a, C: Consensus> LinearDownloader<'a, C> { #[cfg(test)] mod tests { - use super::{super::stage::tests::test_utils, DownloadError, Downloader, LinearDownloader}; + use super::{ + super::stage::tests::test_utils::{TestConsensus, TestHeaderClient}, + DownloadError, Downloader, LinearDownloader, + }; use assert_matches::assert_matches; - use rand::{self, Rng}; + use once_cell::sync::Lazy; + use rand::Rng; use reth_interfaces::stages::HeaderRequest; use reth_primitives::{rpc::BlockId, Header, HeaderLocked, H256}; - use std::sync::Arc; - use tokio::sync::{broadcast, mpsc, oneshot::error::TryRecvError}; - use tokio_stream::{wrappers::ReceiverStream, StreamExt}; - - use once_cell::sync::Lazy; - use test_utils::TestConsensus; + use tokio::sync::oneshot::error::TryRecvError; static CONSENSUS: Lazy = Lazy::new(|| TestConsensus::new()); static CONSENSUS_FAIL: Lazy = Lazy::new(|| { @@ -165,67 +154,49 @@ mod tests { consensus }); + static CLIENT: Lazy = Lazy::new(|| TestHeaderClient::new()); + #[tokio::test] async fn download_timeout() { - let (req_tx, req_rx) = mpsc::channel(1); - let (_res_tx, res_rx) = broadcast::channel(1); - let runner = test_runner::LinearTestRunner::new(); let retries = runner.retries; - let rx = runner.run( - &*CONSENSUS, - test_utils::TestHeaderClient::new(req_tx, res_rx), - HeaderLocked::default(), - H256::zero(), - ); - - let requests = ReceiverStream::new(req_rx).collect::>().await; + let rx = runner.run(&*CONSENSUS, &*CLIENT, HeaderLocked::default(), H256::zero()); + + let mut requests = vec![]; + CLIENT + .on_header_request(retries, |_id, req| { + requests.push(req); + }) + .await; assert_eq!(requests.len(), retries); assert_matches!(rx.await, Ok(Err(DownloadError::NoHeaderResponse { .. }))); } #[tokio::test] async fn download_timeout_on_invalid_messages() { - let (req_tx, req_rx) = mpsc::channel(1); - let (res_tx, res_rx) = broadcast::channel(1); - let runner = test_runner::LinearTestRunner::new(); let retries = runner.retries; - let rx = runner.run( - &*CONSENSUS, - test_utils::TestHeaderClient::new(req_tx, res_rx), - HeaderLocked::default(), - H256::zero(), - ); + let rx = runner.run(&*CONSENSUS, &*CLIENT, HeaderLocked::default(), H256::zero()); let mut num_of_reqs = 0; - let mut last_req_id = None; - let mut req_stream = ReceiverStream::new(req_rx); - while let Some((id, _req)) = req_stream.next().await { - // Since the receiving channel filters by id and message length - - // randomize the input to the tested filter - res_tx.send((id.saturating_add(id % 2), vec![])).expect("failed to send response"); - num_of_reqs += 1; - last_req_id = Some(id); - - if num_of_reqs == retries { - drop(res_tx); - break - } - } + let mut last_req_id: Option = None; + + CLIENT + .on_header_request(retries, |id, _req| { + num_of_reqs += 1; + last_req_id = Some(id); + CLIENT.send_header_response(id.saturating_add(id % 2), vec![]) + }) + .await; assert_eq!(num_of_reqs, retries); assert_matches!( rx.await, - Ok(Err(DownloadError::NoHeaderResponse { request_id })) if request_id == last_req_id.unwrap() - ); + Ok(Err(DownloadError::NoHeaderResponse { request_id })) if request_id == last_req_id.unwrap()); } #[tokio::test] async fn download_propagates_consensus_validation_error() { - let (req_tx, req_rx) = mpsc::channel(1); - let (res_tx, res_rx) = broadcast::channel(1); - let mut tip_parent = Header::default(); tip_parent.nonce = rand::thread_rng().gen(); tip_parent.number = 1; @@ -238,23 +209,19 @@ mod tests { let chain_tip = tip_header.hash_slow(); let runner = test_runner::LinearTestRunner::new(); - let rx = runner.run( - &*CONSENSUS_FAIL, - test_utils::TestHeaderClient::new(req_tx, res_rx), - HeaderLocked::default(), - chain_tip, - ); + let rx = runner.run(&*CONSENSUS_FAIL, &*CLIENT, HeaderLocked::default(), chain_tip); - let mut stream = Box::pin(ReceiverStream::new(req_rx)); - let request = stream.next().await; + let requests = CLIENT.on_header_request(1, |id, req| (id, req)).await; + + let request = requests.last(); assert_matches!( request, Some((_, HeaderRequest { start, .. })) - if matches!(start, BlockId::Hash(hash) if hash == chain_tip) + if matches!(start, BlockId::Hash(hash) if *hash == chain_tip) ); let request = request.unwrap(); - res_tx.send((request.0, vec![tip_header, tip_parent])).expect("failed to send header"); + CLIENT.send_header_response(request.0, vec![tip_header, tip_parent]); assert_matches!( rx.await, @@ -264,9 +231,6 @@ mod tests { #[tokio::test] async fn download_starts_with_chain_tip() { - let (req_tx, req_rx) = mpsc::channel(1); - let (res_tx, res_rx) = broadcast::channel(1); - let mut tip_parent = Header::default(); tip_parent.nonce = rand::thread_rng().gen(); tip_parent.number = 1; @@ -278,26 +242,22 @@ mod tests { tip.nonce = rand::thread_rng().gen(); let runner = test_runner::LinearTestRunner::new(); - let mut rx = runner.run( - &*CONSENSUS, - test_utils::TestHeaderClient::new(req_tx, res_rx), - tip_parent.clone().lock(), - tip.hash_slow(), - ); - - let mut stream = ReceiverStream::new(req_rx); - let request = stream.next().await.unwrap(); - let mut corrupted_tip = tip.clone(); - corrupted_tip.nonce = rand::thread_rng().gen(); - res_tx - .send((request.0, vec![corrupted_tip, tip_parent.clone()])) - .expect("failed to send header"); + let mut rx = runner.run(&*CONSENSUS, &*CLIENT, tip_parent.clone().lock(), tip.hash_slow()); + + CLIENT + .on_header_request(1, |id, _req| { + let mut corrupted_tip = tip.clone(); + corrupted_tip.nonce = rand::thread_rng().gen(); + CLIENT.send_header_response(id, vec![corrupted_tip, tip_parent.clone()]) + }) + .await; assert_matches!(rx.try_recv(), Err(TryRecvError::Empty)); - let request = stream.next().await.unwrap(); - res_tx - .send((request.0, vec![tip.clone(), tip_parent.clone()])) - .expect("failed to send header"); + CLIENT + .on_header_request(1, |id, _req| { + CLIENT.send_header_response(id, vec![tip.clone(), tip_parent.clone()]) + }) + .await; let result = rx.await; assert_matches!(result, Ok(Ok(ref val)) if val.len() == 1); @@ -322,24 +282,25 @@ mod tests { Self { test_ch: oneshot::channel(), retries: 5 } } - pub(crate) fn run<'a, C: Consensus>( + pub(crate) fn run<'a, C: Consensus, H: HeadersClient>( self, consensus: &'static C, - client: impl HeadersClient + 'static, + client: &'static H, head: HeaderLocked, tip: H256, ) -> oneshot::Receiver { let (tx, rx) = self.test_ch; let downloader = LinearDownloader { + consensus, + client, request_retries: self.retries, batch_size: 100, - consensus, request_timeout: 3, }; tokio::spawn(async move { let mut forkchoice = ForkchoiceState::default(); forkchoice.head_block_hash = tip; - let result = downloader.download(Arc::new(client), &head, &forkchoice).await; + let result = downloader.download(&head, &forkchoice).await; tx.send(result).expect("failed to forward download response"); }); rx diff --git a/crates/stages/src/stages/headers/mod.rs b/crates/stages/src/stages/headers/mod.rs index 336ddaf474cc..7191aa0902e1 100644 --- a/crates/stages/src/stages/headers/mod.rs +++ b/crates/stages/src/stages/headers/mod.rs @@ -5,5 +5,3 @@ pub mod stage; pub mod downloader; /// The linear downloading strategy pub mod linear; -/// The parallel downloading strategy -pub mod parallel; diff --git a/crates/stages/src/stages/headers/parallel.rs b/crates/stages/src/stages/headers/parallel.rs deleted file mode 100644 index e045b1783610..000000000000 --- a/crates/stages/src/stages/headers/parallel.rs +++ /dev/null @@ -1,138 +0,0 @@ -use super::downloader::{DownloadError, Downloader}; -use async_trait::async_trait; -use reth_interfaces::{ - consensus::Consensus, - stages::{HeadersClient, MessageStream}, -}; -use reth_primitives::{rpc::BlockId, Header, HeaderLocked, H256}; -use reth_rpc_types::engine::ForkchoiceState; -use std::sync::Arc; - -/// TODO: -#[derive(Debug)] -pub struct ParallelDownloader<'a, C> { - /// The consensus engine - pub consensus: &'a C, - /// The number of parallel requests - pub par_count: usize, - /// The batch size per one request - pub batch_size: u64, - /// A single request timeout - pub request_timeout: u64, - /// The number of retries for downloading - pub request_retries: usize, -} - -#[async_trait] -impl<'a, C: Consensus> Downloader for ParallelDownloader<'a, C> { - type Consensus = C; - - /// The request timeout - fn timeout(&self) -> u64 { - self.request_timeout - } - - fn consensus(&self) -> &Self::Consensus { - self.consensus - } - - /// Download the headers - async fn download( - &self, - client: Arc, - head: &HeaderLocked, - forkchoice: &ForkchoiceState, - ) -> Result, DownloadError> { - let mut stream = client.stream_headers().await; - let mut retries = self.request_retries; - let mut reached_finalized = false; - todo!() - // // Header order will be preserved during inserts - // let mut out = Vec::::new(); - - // // Request blocks by hash until finalized hash - // loop { - // let result = self - // .download_batch( - // client.clone(), - // consensus.clone(), - // head, - // forkchoice, - // &mut stream, - // &mut out, - // ) - // .await; - // match result { - // Ok(result) => match result { - // ParallelResult::Discard | ParallelResult::Continue => (), - // ParallelResult::ReachedFinalized => reached_finalized = true, - // ParallelResult::ReachedHead => return Ok(out), - // }, - // Err(e) if e.is_retryable() && retries > 1 => { - // retries -= 1; - // } - // Err(e) => return Err(e), - // } - // } - } -} - -enum ParallelResult { - Continue, - Discard, - ReachedFinalized, - ReachedHead, -} - -impl<'a, C: Consensus> ParallelDownloader<'a, C> { - async fn download_batch( - &'a self, - client: Arc, - head: &'a HeaderLocked, - forkchoice: &'a ForkchoiceState, - stream: &'a mut MessageStream<(u64, Vec
)>, - out: &'a mut Vec, - ) -> Result { - // Request headers starting from tip or earliest cached - let start = out.first().map_or(forkchoice.head_block_hash, |h| h.parent_hash); - let mut headers = self - .download_headers(stream, client.clone(), BlockId::Hash(start), self.batch_size) - .await?; - headers.sort_unstable_by_key(|h| h.number); - - // Iterate the headers in reverse - out.reserve_exact(headers.len()); - let mut headers_rev = headers.into_iter().rev(); - - let mut result = ParallelResult::Continue; - while let Some(parent) = headers_rev.next() { - let parent = parent.lock(); - - if parent.hash() == head.hash() { - // We've reached the target - return Ok(ParallelResult::ReachedHead) - } - - match out.first() { - Some(tail_header) if !self.validate(tail_header, &parent)? => { - // Cannot attach to the current buffer, discard - return Ok(ParallelResult::Discard) - } - // The buffer is empty and the first header does not match the tip, discard - // TODO: penalize the peer? - None if parent.hash() != forkchoice.head_block_hash => { - return Ok(ParallelResult::Discard) - } - _ => (), - }; - - if parent.hash() == forkchoice.finalized_block_hash { - result = ParallelResult::ReachedFinalized; - } - - out.insert(0, parent); - } - - Ok(result) - } -} diff --git a/crates/stages/src/stages/headers/stage.rs b/crates/stages/src/stages/headers/stage.rs index cdf3c0eb9927..bbe3c3b54ee0 100644 --- a/crates/stages/src/stages/headers/stage.rs +++ b/crates/stages/src/stages/headers/stage.rs @@ -16,13 +16,13 @@ const HEADERS: StageId = StageId("HEADERS"); /// The headers stage implementation for staged sync #[derive(Debug)] -pub struct HeaderStage { +pub struct HeaderStage { /// Strategy for downloading the headers pub downloader: D, /// Consensus client implementation pub consensus: C, /// Downloader client implementation - pub client: Arc, + pub client: H, } /// The header stage error @@ -43,7 +43,7 @@ impl Into for HeaderStageError { } #[async_trait] -impl<'db, E, D: Downloader, C: Consensus> Stage<'db, E> for HeaderStage +impl<'db, E, D: Downloader, C: Consensus, H: HeadersClient> Stage<'db, E> for HeaderStage where E: mdbx::EnvironmentKind, { @@ -81,8 +81,7 @@ where let head = HeaderLocked::new(last_header, last_hash); let forkchoice = self.next_forkchoice_state(&head.hash()).await; - let headers = match self.downloader.download(self.client.clone(), &head, &forkchoice).await - { + let headers = match self.downloader.download(&head, &forkchoice).await { Ok(res) => res, Err(e) => match e { DownloadError::NoHeaderResponse { request_id } => { @@ -154,7 +153,7 @@ where } } -impl HeaderStage { +impl HeaderStage { async fn update_head<'tx, E: mdbx::EnvironmentKind>( &self, tx: &'tx mut Tx<'_, mdbx::RW, E>, @@ -223,47 +222,24 @@ pub(crate) mod tests { use super::*; use crate::util::db::TxContainer; use assert_matches::assert_matches; + use once_cell::sync::Lazy; use reth_db::{ kv::{test_utils as test_db_utils, EnvKind}, mdbx, }; use tokio::sync::{broadcast, mpsc}; - #[tokio::test] - async fn headers_stage_empty_db() { - let (req_tx, _req_rx) = mpsc::channel(1); - let (_res_tx, res_rx) = broadcast::channel(1); - - let mut stage = HeaderStage { - client: Arc::new(test_utils::TestHeaderClient::new(req_tx, res_rx)), - consensus: Arc::new(test_utils::TestConsensus::new()), - downloader: test_utils::TestDownloader::new(Ok(vec![])), - }; - - let mut db = test_db_utils::create_test_db::(EnvKind::RW); - let mut tx = TxContainer::new(&mut db).unwrap(); + static CONSENSUS: Lazy = + Lazy::new(|| test_utils::TestConsensus::new()); - let input = ExecInput { previous_stage: None, stage_progress: None }; - assert_matches!( - stage.execute(tx.get_mut(), input).await, - Err(StageError::Internal(err)) - if matches!( - err.downcast_ref::(), - Some(HeaderStageError::NoCannonicalHeader { .. } - ) - ) - ); - } + static CLIENT: Lazy = + Lazy::new(|| test_utils::TestHeaderClient::new()); #[tokio::test] - // TODO: - async fn headers_stage_() { - let (req_tx, _req_rx) = mpsc::channel(1); - let (_res_tx, res_rx) = broadcast::channel(1); - + async fn headers_stage_empty_db() { let mut stage = HeaderStage { - client: Arc::new(test_utils::TestHeaderClient::new(req_tx, res_rx)), - consensus: &test_utils::TestConsensus::new(), + client: &*CLIENT, + consensus: &*CONSENSUS, downloader: test_utils::TestDownloader::new(Ok(vec![])), }; @@ -291,27 +267,49 @@ pub(crate) mod tests { }; use reth_primitives::{Header, HeaderLocked, H256, H512}; use reth_rpc_types::engine::ForkchoiceState; - use std::{collections::HashSet, sync::Arc}; - use tokio::sync::{broadcast, mpsc::Sender, watch}; + use std::{ + collections::HashSet, + sync::{Arc, Mutex}, + }; + use tokio::sync::{broadcast, mpsc, watch}; use tokio_stream::{wrappers::BroadcastStream, StreamExt}; pub(crate) type HeaderResponse = (u64, Vec
); #[derive(Debug)] pub(crate) struct TestHeaderClient { - tx: Sender<(u64, HeaderRequest)>, - rx: broadcast::Receiver, + req_tx: mpsc::Sender<(u64, HeaderRequest)>, + req_rx: Arc>>, + res_tx: broadcast::Sender, + res_rx: broadcast::Receiver, } impl TestHeaderClient { /// Construct a new test header downloader. - /// `tx` is the `Sender` for header requests - /// `rx` is the `Receiver` of header responses - pub(crate) fn new( - tx: Sender<(u64, HeaderRequest)>, - rx: broadcast::Receiver, - ) -> Self { - Self { tx, rx } + pub(crate) fn new() -> Self { + let (req_tx, req_rx) = mpsc::channel(1); + let (res_tx, res_rx) = broadcast::channel(1); + Self { req_tx, req_rx: Arc::new(Mutex::new(req_rx)), res_tx, res_rx } + } + + pub(crate) async fn on_header_request(&self, mut count: usize, mut f: F) -> Vec + where + F: FnMut(u64, HeaderRequest) -> T, + { + let mut rx = self.req_rx.lock().unwrap(); + let mut results = vec![]; + while let Some((id, req)) = rx.recv().await { + results.push(f(id, req)); + count -= 1; + if count == 0 { + break + } + } + return results + } + + pub(crate) fn send_header_response(&self, id: u64, headers: Vec
) { + self.res_tx.send((id, headers)).expect("failed to send header response"); } } @@ -320,12 +318,12 @@ pub(crate) mod tests { async fn update_status(&self, _height: u64, _hash: H256, _td: H256) {} async fn send_header_request(&self, id: u64, request: HeaderRequest) -> HashSet { - self.tx.send((id, request)).await.expect("failed to send request"); + self.req_tx.send((id, request)).await.expect("failed to send request"); HashSet::default() } async fn stream_headers(&self) -> MessageStream<(u64, Vec
)> { - Box::pin(BroadcastStream::new(self.rx.resubscribe()).filter_map(|e| e.ok())) + Box::pin(BroadcastStream::new(self.res_rx.resubscribe()).filter_map(|e| e.ok())) } } @@ -406,6 +404,7 @@ pub(crate) mod tests { #[async_trait] impl Downloader for TestDownloader { type Consensus = TestConsensus; + type Client = TestHeaderClient; fn timeout(&self) -> u64 { 1 @@ -415,9 +414,12 @@ pub(crate) mod tests { unimplemented!() } + fn client(&self) -> &Self::Client { + unimplemented!() + } + async fn download( &self, - _: Arc, _: &HeaderLocked, _: &ForkchoiceState, ) -> Result, DownloadError> { From 015d5bfb2ba1cb0789f67011edfd5e96a0619eb7 Mon Sep 17 00:00:00 2001 From: Roman Krasiuk Date: Thu, 20 Oct 2022 16:39:47 +0300 Subject: [PATCH 13/13] more tests & cleanup --- .../stages/src/stages/headers/downloader.rs | 5 +- crates/stages/src/stages/headers/linear.rs | 122 +++++++++++------- crates/stages/src/stages/headers/stage.rs | 22 +++- 3 files changed, 96 insertions(+), 53 deletions(-) diff --git a/crates/stages/src/stages/headers/downloader.rs b/crates/stages/src/stages/headers/downloader.rs index 70fa12c61f25..3f127d8dbe5f 100644 --- a/crates/stages/src/stages/headers/downloader.rs +++ b/crates/stages/src/stages/headers/downloader.rs @@ -1,12 +1,11 @@ use async_trait::async_trait; -use rand::Rng; use reth_interfaces::{ consensus::Consensus, stages::{HeaderRequest, HeadersClient, MessageStream}, }; use reth_primitives::{rpc::BlockId, Header, HeaderLocked, H256}; use reth_rpc_types::engine::ForkchoiceState; -use std::{fmt::Debug, sync::Arc, time::Duration}; +use std::{fmt::Debug, time::Duration}; use thiserror::Error; use tokio_stream::StreamExt; @@ -42,7 +41,7 @@ pub trait Downloader: Sync + Send + Debug { start: BlockId, limit: u64, ) -> Result, DownloadError> { - let request_id = rand::thread_rng().gen(); + let request_id = rand::random(); let request = HeaderRequest { start, limit, reverse: true }; let _ = self.client().send_header_request(request_id, request).await; diff --git a/crates/stages/src/stages/headers/linear.rs b/crates/stages/src/stages/headers/linear.rs index 62a0936ad4b1..c3a1c9ff680e 100644 --- a/crates/stages/src/stages/headers/linear.rs +++ b/crates/stages/src/stages/headers/linear.rs @@ -56,17 +56,15 @@ impl<'a, C: Consensus, H: HeadersClient> Downloader for LinearDownloader<'a, C, let result = self.download_batch(&mut stream, forkchoice, head, out.get(0)).await; match result { Ok(result) => match result { - LinearDownloadResult::Batch(headers) => { - // TODO: Should this instead be? - // headers.extend_from_slice(&out); - // out = headers; - out.extend_from_slice(&headers); + LinearDownloadResult::Batch(mut headers) => { + // TODO: fix + headers.extend_from_slice(&out); + out = headers; } - LinearDownloadResult::Finished(headers) => { - // TODO: Should this instead be? - // headers.extend_from_slice(&out); - // out = headers; - out.extend_from_slice(&headers); + LinearDownloadResult::Finished(mut headers) => { + // TODO: fix + headers.extend_from_slice(&out); + out = headers; return Ok(out) } LinearDownloadResult::Ignore => (), @@ -137,14 +135,16 @@ impl<'a, C: Consensus, H: HeadersClient> LinearDownloader<'a, C, H> { #[cfg(test)] mod tests { use super::{ - super::stage::tests::test_utils::{TestConsensus, TestHeaderClient}, + super::stage::tests::test_utils::{ + gen_block_range, gen_random_header, TestConsensus, TestHeaderClient, + }, DownloadError, Downloader, LinearDownloader, }; use assert_matches::assert_matches; use once_cell::sync::Lazy; - use rand::Rng; use reth_interfaces::stages::HeaderRequest; - use reth_primitives::{rpc::BlockId, Header, HeaderLocked, H256}; + use reth_primitives::{rpc::BlockId, HeaderLocked, H256}; + use test_runner::LinearTestRunner; use tokio::sync::oneshot::error::TryRecvError; static CONSENSUS: Lazy = Lazy::new(|| TestConsensus::new()); @@ -158,7 +158,7 @@ mod tests { #[tokio::test] async fn download_timeout() { - let runner = test_runner::LinearTestRunner::new(); + let runner = LinearTestRunner::new(); let retries = runner.retries; let rx = runner.run(&*CONSENSUS, &*CLIENT, HeaderLocked::default(), H256::zero()); @@ -174,7 +174,7 @@ mod tests { #[tokio::test] async fn download_timeout_on_invalid_messages() { - let runner = test_runner::LinearTestRunner::new(); + let runner = LinearTestRunner::new(); let retries = runner.retries; let rx = runner.run(&*CONSENSUS, &*CLIENT, HeaderLocked::default(), H256::zero()); @@ -185,7 +185,7 @@ mod tests { .on_header_request(retries, |id, _req| { num_of_reqs += 1; last_req_id = Some(id); - CLIENT.send_header_response(id.saturating_add(id % 2), vec![]) + CLIENT.send_header_response(id.saturating_add(id % 2), vec![]); }) .await; @@ -197,71 +197,95 @@ mod tests { #[tokio::test] async fn download_propagates_consensus_validation_error() { - let mut tip_parent = Header::default(); - tip_parent.nonce = rand::thread_rng().gen(); - tip_parent.number = 1; - let parent_hash = tip_parent.hash_slow(); - - let mut tip_header = Header::default(); - tip_header.number = 2; - tip_header.nonce = rand::thread_rng().gen(); - tip_header.parent_hash = parent_hash; - let chain_tip = tip_header.hash_slow(); - - let runner = test_runner::LinearTestRunner::new(); - let rx = runner.run(&*CONSENSUS_FAIL, &*CLIENT, HeaderLocked::default(), chain_tip); + let tip_parent = gen_random_header(1, None); + let tip = gen_random_header(2, Some(tip_parent.hash())); + + let rx = LinearTestRunner::new().run( + &*CONSENSUS_FAIL, + &*CLIENT, + HeaderLocked::default(), + tip.hash(), + ); let requests = CLIENT.on_header_request(1, |id, req| (id, req)).await; - let request = requests.last(); assert_matches!( request, Some((_, HeaderRequest { start, .. })) - if matches!(start, BlockId::Hash(hash) if *hash == chain_tip) + if matches!(start, BlockId::Hash(hash) if *hash == tip.hash()) ); let request = request.unwrap(); - CLIENT.send_header_response(request.0, vec![tip_header, tip_parent]); + CLIENT.send_header_response( + request.0, + vec![tip_parent.clone().unlock(), tip.clone().unlock()], + ); assert_matches!( rx.await, - Ok(Err(DownloadError::HeaderValidation { hash, .. })) if hash == parent_hash + Ok(Err(DownloadError::HeaderValidation { hash, .. })) if hash == tip_parent.hash() ); } #[tokio::test] async fn download_starts_with_chain_tip() { - let mut tip_parent = Header::default(); - tip_parent.nonce = rand::thread_rng().gen(); - tip_parent.number = 1; - let parent_hash = tip_parent.hash_slow(); - - let mut tip = Header::default(); - tip.parent_hash = parent_hash; - tip.number = 2; - tip.nonce = rand::thread_rng().gen(); + let head = gen_random_header(1, None); + let tip = gen_random_header(2, Some(head.hash())); - let runner = test_runner::LinearTestRunner::new(); - let mut rx = runner.run(&*CONSENSUS, &*CLIENT, tip_parent.clone().lock(), tip.hash_slow()); + let mut rx = LinearTestRunner::new().run(&*CONSENSUS, &*CLIENT, head.clone(), tip.hash()); CLIENT .on_header_request(1, |id, _req| { - let mut corrupted_tip = tip.clone(); - corrupted_tip.nonce = rand::thread_rng().gen(); - CLIENT.send_header_response(id, vec![corrupted_tip, tip_parent.clone()]) + let mut corrupted_tip = tip.clone().unlock(); + corrupted_tip.nonce = rand::random(); + CLIENT.send_header_response(id, vec![corrupted_tip, head.clone().unlock()]) }) .await; assert_matches!(rx.try_recv(), Err(TryRecvError::Empty)); CLIENT .on_header_request(1, |id, _req| { - CLIENT.send_header_response(id, vec![tip.clone(), tip_parent.clone()]) + CLIENT.send_header_response(id, vec![tip.clone().unlock(), head.clone().unlock()]) }) .await; let result = rx.await; assert_matches!(result, Ok(Ok(ref val)) if val.len() == 1); - assert_eq!(*result.unwrap().unwrap().first().unwrap(), tip.lock()); + assert_eq!(*result.unwrap().unwrap().first().unwrap(), tip); + } + + #[tokio::test] + async fn download_returns_headers_asc() { + let (start, end) = (100, 200); + let head = gen_random_header(start, None); + let headers = gen_block_range(start + 1..end, head.hash()); + let tip = headers.last().unwrap(); + + let rx = LinearTestRunner::new().run(&*CONSENSUS, &*CLIENT, head.clone(), tip.hash()); + + let mut idx = 0; + let chunk_size = 10; + let chunk_iter = headers.clone().into_iter().rev(); + // `usize::div_ceil` is unstable. ref: https://github.com/rust-lang/rust/issues/88581 + let count = (headers.len() + chunk_size - 1) / chunk_size; + CLIENT + .on_header_request(count + 1, |id, _req| { + let mut chunk = + chunk_iter.clone().skip(chunk_size * idx).take(chunk_size).peekable(); + idx += 1; + if chunk.peek().is_some() { + let headers: Vec<_> = chunk.map(|h| h.unlock()).collect(); + CLIENT.send_header_response(id, headers); + } else { + CLIENT.send_header_response(id, vec![head.clone().unlock()]) + } + }) + .await; + + let result = rx.await; + assert_matches!(result, Ok(Ok(_))); + let result = result.unwrap().unwrap(); + assert_eq!(result, headers); } mod test_runner { diff --git a/crates/stages/src/stages/headers/stage.rs b/crates/stages/src/stages/headers/stage.rs index bbe3c3b54ee0..a8900f148aa3 100644 --- a/crates/stages/src/stages/headers/stage.rs +++ b/crates/stages/src/stages/headers/stage.rs @@ -227,7 +227,7 @@ pub(crate) mod tests { kv::{test_utils as test_db_utils, EnvKind}, mdbx, }; - use tokio::sync::{broadcast, mpsc}; + use tokio::sync::mpsc; static CONSENSUS: Lazy = Lazy::new(|| test_utils::TestConsensus::new()); @@ -269,11 +269,31 @@ pub(crate) mod tests { use reth_rpc_types::engine::ForkchoiceState; use std::{ collections::HashSet, + ops::Range, sync::{Arc, Mutex}, }; use tokio::sync::{broadcast, mpsc, watch}; use tokio_stream::{wrappers::BroadcastStream, StreamExt}; + pub(crate) fn gen_block_range(rng: Range, head: H256) -> Vec { + let mut headers = Vec::with_capacity(rng.end.saturating_sub(rng.start) as usize); + for idx in rng { + headers.push(gen_random_header( + idx, + Some(headers.last().map(|h: &HeaderLocked| h.hash()).unwrap_or(head)), + )); + } + headers + } + + pub(crate) fn gen_random_header(number: u64, parent: Option) -> HeaderLocked { + let mut header = Header::default(); + header.number = number; + header.nonce = rand::random(); + header.parent_hash = parent.unwrap_or_default(); + header.lock() + } + pub(crate) type HeaderResponse = (u64, Vec
); #[derive(Debug)]