Skip to content

Commit

Permalink
feat(provider): subscribe to new blocks if possible in heartbeat (all…
Browse files Browse the repository at this point in the history
…oy-rs#1321)

* feat(provider): subscribe to new blocks if possible in heartbeat

* msg

* wasm

* fix

* move into stream

* feat: lazily subscribe to newHeads

* chore: remove async from get_heart

* testname
  • Loading branch information
DaniPopes authored and lwedge99 committed Oct 8, 2024
1 parent e85ab25 commit ea78674
Show file tree
Hide file tree
Showing 11 changed files with 256 additions and 163 deletions.
4 changes: 2 additions & 2 deletions crates/network-primitives/src/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -177,9 +177,9 @@ pub trait HeaderResponse {
/// Block JSON-RPC response.
pub trait BlockResponse {
/// Header type
type Header;
type Header: HeaderResponse;
/// Transaction type
type Transaction;
type Transaction: TransactionResponse;

/// Block header
fn header(&self) -> &Self::Header;
Expand Down
169 changes: 124 additions & 45 deletions crates/provider/src/chain.rs → crates/provider/src/blocks.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
use alloy_network::{Ethereum, Network};
use alloy_primitives::{BlockNumber, U64};
use alloy_rpc_client::{NoParams, PollerBuilder, WeakClient};
use alloy_rpc_types_eth::Block;
use alloy_transport::{RpcError, Transport};
use async_stream::stream;
use futures::{Stream, StreamExt};
use lru::LruCache;
use std::{marker::PhantomData, num::NonZeroUsize};

#[cfg(feature = "pubsub")]
use futures::{future::Either, FutureExt};

/// The size of the block cache.
const BLOCK_CACHE_SIZE: NonZeroUsize = unsafe { NonZeroUsize::new_unchecked(10) };

Expand All @@ -17,38 +19,89 @@ const MAX_RETRIES: usize = 3;
/// Default block number for when we don't have a block yet.
const NO_BLOCK_NUMBER: BlockNumber = BlockNumber::MAX;

pub(crate) struct ChainStreamPoller<T, N = Ethereum> {
/// Streams new blocks from the client.
pub(crate) struct NewBlocks<T, N: Network = Ethereum> {
client: WeakClient<T>,
poll_task: PollerBuilder<T, NoParams, U64>,
/// The next block to yield.
/// [`NO_BLOCK_NUMBER`] indicates that it will be updated on the first poll.
/// Only used by the polling task.
next_yield: BlockNumber,
known_blocks: LruCache<BlockNumber, Block>,
/// LRU cache of known blocks. Only used by the polling task.
known_blocks: LruCache<BlockNumber, N::BlockResponse>,
_phantom: PhantomData<N>,
}

impl<T: Transport + Clone, N: Network> ChainStreamPoller<T, N> {
pub(crate) fn from_weak_client(w: WeakClient<T>) -> Self {
Self::new(w)
}

impl<T: Transport + Clone, N: Network> NewBlocks<T, N> {
pub(crate) fn new(client: WeakClient<T>) -> Self {
Self::with_next_yield(client, NO_BLOCK_NUMBER)
}

/// Can be used to force the poller to start at a specific block number.
/// Mostly useful for tests.
fn with_next_yield(client: WeakClient<T>, next_yield: BlockNumber) -> Self {
Self {
client: client.clone(),
poll_task: PollerBuilder::new(client, "eth_blockNumber", []),
next_yield,
client,
next_yield: NO_BLOCK_NUMBER,
known_blocks: LruCache::new(BLOCK_CACHE_SIZE),
_phantom: PhantomData,
}
}

pub(crate) fn into_stream(mut self) -> impl Stream<Item = Block> + 'static {
#[cfg(test)]
const fn with_next_yield(mut self, next_yield: u64) -> Self {
self.next_yield = next_yield;
self
}

pub(crate) fn into_stream(self) -> impl Stream<Item = N::BlockResponse> + 'static {
// Return a stream that lazily subscribes to `newHeads` on the first poll.
#[cfg(feature = "pubsub")]
if let Some(client) = self.client.upgrade() {
if client.pubsub_frontend().is_some() {
let subscriber = self.into_subscription_stream().map(futures::stream::iter);
let subscriber = futures::stream::once(subscriber);
return Either::Left(subscriber.flatten().flatten());
}
}

// Returns a stream that lazily initializes an `eth_blockNumber` polling task on the first
// poll, mapped with `eth_getBlockByNumber`.
#[cfg(feature = "pubsub")]
let right = Either::Right;
#[cfg(not(feature = "pubsub"))]
let right = std::convert::identity;
right(self.into_poll_stream())
}

#[cfg(feature = "pubsub")]
async fn into_subscription_stream(
self,
) -> Option<impl Stream<Item = N::BlockResponse> + 'static> {
let Some(client) = self.client.upgrade() else {
debug!("client dropped");
return None;
};
let Some(pubsub) = client.pubsub_frontend() else {
error!("pubsub_frontend returned None after being Some");
return None;
};
let id = match client.request("eth_subscribe", ("newHeads",)).await {
Ok(id) => id,
Err(err) => {
error!(%err, "failed to subscribe to newHeads");
return None;
}
};
let sub = match pubsub.get_subscription(id).await {
Ok(sub) => sub,
Err(err) => {
error!(%err, "failed to get subscription");
return None;
}
};
Some(sub.into_typed::<N::BlockResponse>().into_stream())
}

fn into_poll_stream(mut self) -> impl Stream<Item = N::BlockResponse> + 'static {
stream! {
let mut poll_task = self.poll_task.spawn().into_stream_raw();
// Spawned lazily on the first `poll`.
let poll_task_builder: PollerBuilder<T, NoParams, U64> =
PollerBuilder::new(self.client.clone(), "eth_blockNumber", []);
let mut poll_task = poll_task_builder.spawn().into_stream_raw();
'task: loop {
// Clear any buffered blocks.
while let Some(known_block) = self.known_blocks.pop(&self.next_yield) {
Expand All @@ -71,6 +124,7 @@ impl<T: Transport + Clone, N: Network> ChainStreamPoller<T, N> {
}
};
let block_number = block_number.to::<u64>();
trace!(%block_number, "got block number");
if self.next_yield == NO_BLOCK_NUMBER {
assert!(block_number < NO_BLOCK_NUMBER, "too many blocks");
self.next_yield = block_number;
Expand Down Expand Up @@ -125,64 +179,89 @@ impl<T: Transport + Clone, N: Network> ChainStreamPoller<T, N> {

#[cfg(all(test, feature = "anvil-api"))] // Tests rely heavily on ability to mine blocks on demand.
mod tests {
use std::{future::Future, time::Duration};

use crate::{ext::AnvilApi, ProviderBuilder};
use super::*;
use crate::{ext::AnvilApi, Provider, ProviderBuilder};
use alloy_node_bindings::Anvil;
use alloy_primitives::U256;
use alloy_rpc_client::ReqwestClient;

use super::*;
use std::{future::Future, time::Duration};

fn init_tracing() {
let _ = tracing_subscriber::fmt::try_init();
}

async fn with_timeout<T: Future>(fut: T) -> T::Output {
tokio::select! {
_ = tokio::time::sleep(Duration::from_secs(1)) => panic!("Operation timed out"),
out = fut => out,
}
async fn timeout<T: Future>(future: T) -> T::Output {
try_timeout(future).await.expect("Timeout")
}

async fn try_timeout<T: Future>(future: T) -> Option<T::Output> {
tokio::time::timeout(Duration::from_secs(2), future).await.ok()
}

#[tokio::test]
async fn yield_block_http() {
yield_block(false).await;
}
#[tokio::test]
async fn yield_block() {
#[cfg(feature = "ws")]
async fn yield_block_ws() {
yield_block(true).await;
}
async fn yield_block(ws: bool) {
init_tracing();

let anvil = Anvil::new().spawn();

let client = ReqwestClient::new_http(anvil.endpoint_url());
let poller: ChainStreamPoller<_, Ethereum> =
ChainStreamPoller::with_next_yield(client.get_weak(), 1);
let mut stream = Box::pin(poller.into_stream());
let url = if ws { anvil.ws_endpoint() } else { anvil.endpoint() };
let provider = ProviderBuilder::new().on_builtin(&url).await.unwrap();

let new_blocks = NewBlocks::<_, Ethereum>::new(provider.weak_client()).with_next_yield(1);
let mut stream = Box::pin(new_blocks.into_stream());
if ws {
let _ = try_timeout(stream.next()).await; // Subscribe to newHeads.
}

// We will also use provider to manipulate anvil instance via RPC.
let provider = ProviderBuilder::new().on_http(anvil.endpoint_url());
provider.anvil_mine(Some(U256::from(1)), None).await.unwrap();

let block = with_timeout(stream.next()).await.expect("Block wasn't fetched");
let block = timeout(stream.next()).await.expect("Block wasn't fetched");
assert_eq!(block.header.number, 1);
}

#[tokio::test]
async fn yield_many_blocks() {
async fn yield_many_blocks_http() {
yield_many_blocks(false).await;
}
#[tokio::test]
#[cfg(feature = "ws")]
async fn yield_many_blocks_ws() {
yield_many_blocks(true).await;
}
async fn yield_many_blocks(ws: bool) {
// Make sure that we can process more blocks than fits in the cache.
const BLOCKS_TO_MINE: usize = BLOCK_CACHE_SIZE.get() + 1;

init_tracing();

let anvil = Anvil::new().spawn();

let client = ReqwestClient::new_http(anvil.endpoint_url());
let poller: ChainStreamPoller<_, Ethereum> =
ChainStreamPoller::with_next_yield(client.get_weak(), 1);
let stream = Box::pin(poller.into_stream());
let url = if ws { anvil.ws_endpoint() } else { anvil.endpoint() };
let provider = ProviderBuilder::new().on_builtin(&url).await.unwrap();

let new_blocks = NewBlocks::<_, Ethereum>::new(provider.weak_client()).with_next_yield(1);
let mut stream = Box::pin(new_blocks.into_stream());
if ws {
let _ = try_timeout(stream.next()).await; // Subscribe to newHeads.
}

// We will also use provider to manipulate anvil instance via RPC.
let provider = ProviderBuilder::new().on_http(anvil.endpoint_url());
provider.anvil_mine(Some(U256::from(BLOCKS_TO_MINE)), None).await.unwrap();

let blocks = with_timeout(stream.take(BLOCKS_TO_MINE).collect::<Vec<_>>()).await;
let blocks = timeout(stream.take(BLOCKS_TO_MINE).collect::<Vec<_>>()).await;
assert_eq!(blocks.len(), BLOCKS_TO_MINE);
let first = blocks[0].header.number;
assert_eq!(first, 1);
for (i, block) in blocks.iter().enumerate() {
assert_eq!(block.header.number, first + i as u64);
}
}
}
2 changes: 1 addition & 1 deletion crates/provider/src/ext/anvil.rs
Original file line number Diff line number Diff line change
Expand Up @@ -797,7 +797,7 @@ mod tests {
}

#[tokio::test]
async fn test_anvil_set_block_timestamp_interval_anvil_remove_block_timestamp_interval() {
async fn test_anvil_block_timestamp_interval() {
let provider = ProviderBuilder::new().on_anvil();

provider.anvil_set_block_timestamp_interval(1).await.unwrap();
Expand Down
Loading

0 comments on commit ea78674

Please sign in to comment.