diff --git a/crates/contract/src/call.rs b/crates/contract/src/call.rs index c7d164e6acb..1d1f0a38171 100644 --- a/crates/contract/src/call.rs +++ b/crates/contract/src/call.rs @@ -447,7 +447,7 @@ impl, D: CallDecoder, N: Network> CallBu /// If this is not desired, use [`call_raw`](Self::call_raw) to get the raw output data. #[doc(alias = "eth_call")] #[doc(alias = "call_with_overrides")] - pub fn call(&self) -> EthCall<'_, '_, '_, D, T, N> { + pub fn call(&self) -> EthCall<'_, '_, D, T, N> { self.call_raw().with_decoder(&self.decoder) } @@ -457,7 +457,7 @@ impl, D: CallDecoder, N: Network> CallBu /// Does not decode the output of the call, returning the raw output data instead. /// /// See [`call`](Self::call) for more information. - pub fn call_raw(&self) -> EthCall<'_, '_, '_, (), T, N> { + pub fn call_raw(&self) -> EthCall<'_, '_, (), T, N> { let call = self.provider.call(&self.request).block(self.block); let call = match &self.state { Some(state) => call.overrides(state), diff --git a/crates/contract/src/eth_call.rs b/crates/contract/src/eth_call.rs index c4dd8c227bb..2f4e9db089c 100644 --- a/crates/contract/src/eth_call.rs +++ b/crates/contract/src/eth_call.rs @@ -24,18 +24,18 @@ mod private { /// An [`alloy_provider::EthCall`] with an abi decoder. #[must_use = "EthCall must be awaited to execute the call"] #[derive(Clone, Debug)] -pub struct EthCall<'req, 'state, 'coder, D, T, N> +pub struct EthCall<'req, 'coder, D, T, N> where T: Transport + Clone, N: Network, D: CallDecoder, { - inner: alloy_provider::EthCall<'req, 'state, T, N, Bytes>, + inner: alloy_provider::EthCall<'req, T, N, Bytes>, decoder: &'coder D, } -impl<'req, 'state, 'coder, D, T, N> EthCall<'req, 'state, 'coder, D, T, N> +impl<'req, 'coder, D, T, N> EthCall<'req, 'coder, D, T, N> where T: Transport + Clone, N: Network, @@ -43,25 +43,25 @@ where { /// Create a new [`EthCall`]. pub const fn new( - inner: alloy_provider::EthCall<'req, 'state, T, N, Bytes>, + inner: alloy_provider::EthCall<'req, T, N, Bytes>, decoder: &'coder D, ) -> Self { Self { inner, decoder } } } -impl<'req, 'state, T, N> EthCall<'req, 'state, 'static, (), T, N> +impl<'req, T, N> EthCall<'req, 'static, (), T, N> where T: Transport + Clone, N: Network, { /// Create a new [`EthCall`]. - pub const fn new_raw(inner: alloy_provider::EthCall<'req, 'state, T, N, Bytes>) -> Self { + pub const fn new_raw(inner: alloy_provider::EthCall<'req, T, N, Bytes>) -> Self { Self::new(inner, &RAW_CODER) } } -impl<'req, 'state, 'coder, D, T, N> EthCall<'req, 'state, 'coder, D, T, N> +impl<'req, 'coder, D, T, N> EthCall<'req, 'coder, D, T, N> where T: Transport + Clone, N: Network, @@ -71,7 +71,7 @@ where pub fn with_decoder<'new_coder, E>( self, decoder: &'new_coder E, - ) -> EthCall<'req, 'state, 'new_coder, E, T, N> + ) -> EthCall<'req, 'new_coder, E, T, N> where E: CallDecoder, { @@ -79,7 +79,7 @@ where } /// Set the state overrides for this call. - pub fn overrides(mut self, overrides: &'state StateOverride) -> Self { + pub fn overrides(mut self, overrides: &'req StateOverride) -> Self { self.inner = self.inner.overrides(overrides); self } @@ -91,19 +91,18 @@ where } } -impl<'req, 'state, T, N> From> - for EthCall<'req, 'state, 'static, (), T, N> +impl<'req, T, N> From> + for EthCall<'req, 'static, (), T, N> where T: Transport + Clone, N: Network, { - fn from(inner: alloy_provider::EthCall<'req, 'state, T, N, Bytes>) -> Self { + fn from(inner: alloy_provider::EthCall<'req, T, N, Bytes>) -> Self { Self { inner, decoder: &RAW_CODER } } } -impl<'req, 'state, 'coder, D, T, N> std::future::IntoFuture - for EthCall<'req, 'state, 'coder, D, T, N> +impl<'req, 'coder, D, T, N> std::future::IntoFuture for EthCall<'req, 'coder, D, T, N> where D: CallDecoder + Unpin, T: Transport + Clone, @@ -111,7 +110,7 @@ where { type Output = Result; - type IntoFuture = EthCallFut<'req, 'state, 'coder, D, T, N>; + type IntoFuture = EthCallFut<'req, 'coder, D, T, N>; fn into_future(self) -> Self::IntoFuture { EthCallFut { inner: self.inner.into_future(), decoder: self.decoder } @@ -121,20 +120,19 @@ where /// Future for the [`EthCall`] type. This future wraps an RPC call with an abi /// decoder. #[must_use = "futures do nothing unless you `.await` or poll them"] -#[derive(Clone, Debug)] +#[derive(Debug)] #[allow(unnameable_types)] -pub struct EthCallFut<'req, 'state, 'coder, D, T, N> +pub struct EthCallFut<'req, 'coder, D, T, N> where T: Transport + Clone, N: Network, D: CallDecoder, { - inner: as IntoFuture>::IntoFuture, + inner: as IntoFuture>::IntoFuture, decoder: &'coder D, } -impl<'req, 'state, 'coder, D, T, N> std::future::Future - for EthCallFut<'req, 'state, 'coder, D, T, N> +impl<'req, 'coder, D, T, N> std::future::Future for EthCallFut<'req, 'coder, D, T, N> where D: CallDecoder + Unpin, T: Transport + Clone, diff --git a/crates/json-rpc/src/request.rs b/crates/json-rpc/src/request.rs index bc8fb11b4e7..2dac9b67092 100644 --- a/crates/json-rpc/src/request.rs +++ b/crates/json-rpc/src/request.rs @@ -78,6 +78,14 @@ impl Request { pub fn set_subscription_status(&mut self, sub: bool) { self.meta.set_subscription_status(sub); } + + /// Change type of the request parameters. + pub fn map_params( + self, + map: impl FnOnce(Params) -> NewParams, + ) -> Request { + Request { meta: self.meta, params: map(self.params) } + } } /// A [`Request`] that has been partially serialized. @@ -113,11 +121,12 @@ where impl Request<&Params> where - Params: Clone, + Params: ToOwned, + Params::Owned: RpcParam, { /// Clone the request, including the request parameters. - pub fn into_owned_params(self) -> Request { - Request { meta: self.meta, params: self.params.clone() } + pub fn into_owned_params(self) -> Request { + Request { meta: self.meta, params: self.params.to_owned() } } } diff --git a/crates/provider/src/ext/trace.rs b/crates/provider/src/ext/trace.rs index a0b748b5a9c..58944e8963d 100644 --- a/crates/provider/src/ext/trace.rs +++ b/crates/provider/src/ext/trace.rs @@ -112,14 +112,14 @@ where trace_types: &'b [TraceType], ) -> RpcWithBlock::TransactionRequest, &'b [TraceType]), TraceResults> { - RpcWithBlock::new(self.weak_client(), "trace_call", (request, trace_types)) + self.client().request("trace_call", (request, trace_types)).into() } fn trace_call_many<'a>( &self, request: TraceCallList<'a, N>, ) -> RpcWithBlock,), Vec> { - RpcWithBlock::new(self.weak_client(), "trace_callMany", (request,)) + self.client().request("trace_callMany", (request,)).into() } async fn trace_transaction( diff --git a/crates/provider/src/lib.rs b/crates/provider/src/lib.rs index a59bdc8c0fb..9105f58e530 100644 --- a/crates/provider/src/lib.rs +++ b/crates/provider/src/lib.rs @@ -28,12 +28,11 @@ extern crate tracing; mod builder; pub use builder::{Identity, ProviderBuilder, ProviderLayer, Stack}; +mod chain; + pub mod ext; pub mod fillers; -pub mod layers; - -mod chain; mod heart; pub use heart::{ @@ -41,10 +40,12 @@ pub use heart::{ PendingTransactionError, WatchTxError, }; +pub mod layers; + mod provider; pub use provider::{ - builder, EthCall, FilterPollerBuilder, Provider, RootProvider, RpcWithBlock, SendableTx, - WalletProvider, + builder, Caller, EthCall, EthCallParams, FilterPollerBuilder, ParamsWithBlock, Provider, + ProviderCall, RootProvider, RpcWithBlock, SendableTx, WalletProvider, }; pub mod utils; diff --git a/crates/provider/src/provider/caller.rs b/crates/provider/src/provider/caller.rs new file mode 100644 index 00000000000..abd6db8c565 --- /dev/null +++ b/crates/provider/src/provider/caller.rs @@ -0,0 +1,47 @@ +use crate::ProviderCall; +use alloy_json_rpc::{RpcParam, RpcReturn}; +use alloy_rpc_client::WeakClient; +use alloy_transport::{RpcError, Transport, TransportErrorKind, TransportResult}; +use std::borrow::Cow; + +// TODO: Make `EthCall` specific. Ref: https://github.com/alloy-rs/alloy/pull/788#discussion_r1748862509. + +/// Trait that helpes convert `EthCall` into a `ProviderCall`. +pub trait Caller: Send + Sync +where + T: Transport + Clone, + Params: RpcParam, + Resp: RpcReturn, +{ + /// Method that needs to be implemented to convert to a `ProviderCall`. + /// + /// This method handles serialization of the params and sends the request to relevant data + /// source and returns a `ProviderCall`. + fn call( + &self, + method: Cow<'static, str>, + params: Params, + ) -> TransportResult>; +} + +impl Caller for WeakClient +where + T: Transport + Clone, + Params: RpcParam, + Resp: RpcReturn, +{ + fn call( + &self, + method: Cow<'static, str>, + params: Params, + ) -> TransportResult> { + let client = self.upgrade().ok_or_else(TransportErrorKind::backend_gone)?; + + // serialize the params + let ser = serde_json::to_value(params).map_err(RpcError::ser_err)?; + + let rpc_call = client.request(method, ser); + + Ok(ProviderCall::RpcCall(rpc_call)) + } +} diff --git a/crates/provider/src/provider/call.rs b/crates/provider/src/provider/eth_call.rs similarity index 61% rename from crates/provider/src/provider/call.rs rename to crates/provider/src/provider/eth_call.rs index c79d3944c36..83433490568 100644 --- a/crates/provider/src/provider/call.rs +++ b/crates/provider/src/provider/eth_call.rs @@ -1,24 +1,23 @@ use alloy_eips::BlockId; use alloy_json_rpc::RpcReturn; use alloy_network::Network; -use alloy_rpc_client::{RpcCall, WeakClient}; use alloy_rpc_types_eth::state::StateOverride; -use alloy_transport::{Transport, TransportErrorKind, TransportResult}; +use alloy_transport::{Transport, TransportResult}; use futures::FutureExt; use serde::ser::SerializeSeq; -use std::{future::Future, marker::PhantomData, task::Poll}; +use std::{future::Future, marker::PhantomData, sync::Arc, task::Poll}; -type RunningFut<'req, 'state, T, N, Resp, Output, Map> = - RpcCall, Resp, Output, Map>; +use crate::{Caller, ProviderCall}; +/// The parameters for an `"eth_call"` RPC request. #[derive(Clone, Debug)] -struct EthCallParams<'req, 'state, N: Network> { +pub struct EthCallParams<'req, N: Network> { data: &'req N::TransactionRequest, block: Option, - overrides: Option<&'state StateOverride>, + overrides: Option<&'req StateOverride>, } -impl serde::Serialize for EthCallParams<'_, '_, N> { +impl serde::Serialize for EthCallParams<'_, N> { fn serialize(&self, serializer: S) -> Result { let len = if self.overrides.is_some() { 3 } else { 2 }; @@ -37,21 +36,22 @@ impl serde::Serialize for EthCallParams<'_, '_, N> { } /// The [`EthCallFut`] future is the future type for an `eth_call` RPC request. -#[derive(Clone, Debug)] +#[derive(Debug)] #[doc(hidden)] // Not public API. #[allow(unnameable_types)] #[pin_project::pin_project] -pub struct EthCallFut<'req, 'state, T, N, Resp, Output, Map>( - EthCallFutInner<'req, 'state, T, N, Resp, Output, Map>, -) +pub struct EthCallFut<'req, T, N, Resp, Output, Map> where T: Transport + Clone, N: Network, Resp: RpcReturn, - Map: Fn(Resp) -> Output; + Output: 'static, + Map: Fn(Resp) -> Output, +{ + inner: EthCallFutInner<'req, T, N, Resp, Output, Map>, +} -#[derive(Clone, Debug)] -enum EthCallFutInner<'req, 'state, T, N, Resp, Output, Map> +enum EthCallFutInner<'req, T, N, Resp, Output, Map> where T: Transport + Clone, N: Network, @@ -59,18 +59,45 @@ where Map: Fn(Resp) -> Output, { Preparing { - client: WeakClient, + caller: Arc, Resp>>, data: &'req N::TransactionRequest, - overrides: Option<&'state StateOverride>, + overrides: Option<&'req StateOverride>, block: Option, method: &'static str, map: Map, }, - Running(RunningFut<'req, 'state, T, N, Resp, Output, Map>), + Running { + map: Map, + fut: ProviderCall, + }, Polling, } -impl<'req, 'state, T, N, Resp, Output, Map> EthCallFutInner<'req, 'state, T, N, Resp, Output, Map> +impl<'req, T, N, Resp, Output, Map> core::fmt::Debug + for EthCallFutInner<'req, T, N, Resp, Output, Map> +where + T: Transport + Clone, + N: Network, + Resp: RpcReturn, + Output: 'static, + Map: Fn(Resp) -> Output, +{ + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + match self { + Self::Preparing { caller: _, data, overrides, block, method, map: _ } => f + .debug_struct("Preparing") + .field("data", data) + .field("overrides", overrides) + .field("block", block) + .field("method", method) + .finish(), + Self::Running { .. } => f.debug_tuple("Running").finish(), + Self::Polling => f.debug_tuple("Polling").finish(), + } + } +} + +impl<'req, T, N, Resp, Output, Map> EthCallFut<'req, T, N, Resp, Output, Map> where T: Transport + Clone, N: Network, @@ -80,43 +107,40 @@ where { /// Returns `true` if the future is in the preparing state. const fn is_preparing(&self) -> bool { - matches!(self, Self::Preparing { .. }) + matches!(self.inner, EthCallFutInner::Preparing { .. }) } /// Returns `true` if the future is in the running state. const fn is_running(&self) -> bool { - matches!(self, Self::Running(..)) + matches!(self.inner, EthCallFutInner::Running { .. }) } fn poll_preparing(&mut self, cx: &mut std::task::Context<'_>) -> Poll> { - let Self::Preparing { client, data, overrides, block, method, map } = - std::mem::replace(self, Self::Polling) + let EthCallFutInner::Preparing { caller, data, overrides, block, method, map } = + std::mem::replace(&mut self.inner, EthCallFutInner::Polling) else { unreachable!("bad state") }; - let client = match client.upgrade().ok_or_else(TransportErrorKind::backend_gone) { - Ok(client) => client, - Err(e) => return Poll::Ready(Err(e)), - }; - let params = EthCallParams { data, block, overrides }; - let fut = client.request(method, params).map_resp(map); + let fut = caller.call(method.into(), params)?; + + self.inner = EthCallFutInner::Running { map, fut }; - *self = Self::Running(fut); self.poll_running(cx) } fn poll_running(&mut self, cx: &mut std::task::Context<'_>) -> Poll> { - let Self::Running(ref mut call) = self else { unreachable!("bad state") }; + let EthCallFutInner::Running { ref map, ref mut fut } = self.inner else { + unreachable!("bad state") + }; - call.poll_unpin(cx) + fut.poll_unpin(cx).map(|res| res.map(map)) } } -impl<'req, 'state, T, N, Resp, Output, Map> Future - for EthCallFut<'req, 'state, T, N, Resp, Output, Map> +impl<'req, T, N, Resp, Output, Map> Future for EthCallFut<'req, T, N, Resp, Output, Map> where T: Transport + Clone, N: Network, @@ -130,7 +154,7 @@ where self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> std::task::Poll { - let this = &mut self.get_mut().0; + let this = self.get_mut(); if this.is_preparing() { this.poll_preparing(cx) } else if this.is_running() { @@ -146,34 +170,52 @@ where /// /// [`Provider::call`]: crate::Provider::call #[must_use = "EthCall must be awaited to execute the call"] -#[derive(Debug, Clone)] -pub struct EthCall<'req, 'state, T, N, Resp, Output = Resp, Map = fn(Resp) -> Output> +#[derive(Clone)] +pub struct EthCall<'req, T, N, Resp, Output = Resp, Map = fn(Resp) -> Output> where T: Transport + Clone, N: Network, Resp: RpcReturn, Map: Fn(Resp) -> Output, { - client: WeakClient, - + caller: Arc, Resp>>, data: &'req N::TransactionRequest, - overrides: Option<&'state StateOverride>, + overrides: Option<&'req StateOverride>, block: Option, method: &'static str, map: Map, _pd: PhantomData (Resp, Output)>, } -impl<'req, T, N, Resp> EthCall<'req, 'static, T, N, Resp> +impl<'req, T, N, Resp> core::fmt::Debug for EthCall<'req, T, N, Resp> +where + T: Transport + Clone, + N: Network, + Resp: RpcReturn, +{ + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + f.debug_struct("EthCall") + .field("method", &self.method) + .field("data", &self.data) + .field("block", &self.block) + .field("overrides", &self.overrides) + .finish() + } +} + +impl<'req, T, N, Resp> EthCall<'req, T, N, Resp> where T: Transport + Clone, N: Network, Resp: RpcReturn, { /// Create a new CallBuilder. - pub const fn new(client: WeakClient, data: &'req N::TransactionRequest) -> Self { + pub fn new( + caller: impl Caller, Resp> + 'static, + data: &'req N::TransactionRequest, + ) -> Self { Self { - client, + caller: Arc::new(caller), data, overrides: None, block: None, @@ -184,9 +226,12 @@ where } /// Create new EthCall for gas estimates. - pub const fn gas_estimate(client: WeakClient, data: &'req N::TransactionRequest) -> Self { + pub fn gas_estimate( + caller: impl Caller, Resp> + 'static, + data: &'req N::TransactionRequest, + ) -> Self { Self { - client, + caller: Arc::new(caller), data, overrides: None, block: None, @@ -197,23 +242,33 @@ where } } -impl<'req, 'state, T, N, Resp, Output, Map> EthCall<'req, 'state, T, N, Resp, Output, Map> +impl<'req, T, N, Resp, Output, Map> EthCall<'req, T, N, Resp, Output, Map> where T: Transport + Clone, N: Network, Resp: RpcReturn, Map: Fn(Resp) -> Output, { - /// Map the response to a different type. + /// Map the response to a different type. This is usable for converting + /// the response to a more usable type, e.g. changing `U64` to `u64`. + /// + /// ## Note + /// + /// Carefully review the rust documentation on [fn pointers] before passing + /// them to this function. Unless the pointer is specifically coerced to a + /// `fn(_) -> _`, the `NewMap` will be inferred as that function's unique + /// type. This can lead to confusing error messages. + /// + /// [fn pointers]: https://doc.rust-lang.org/std/primitive.fn.html#creating-function-pointers pub fn map_resp( self, map: NewMap, - ) -> EthCall<'req, 'state, T, N, Resp, NewOutput, NewMap> + ) -> EthCall<'req, T, N, Resp, NewOutput, NewMap> where NewMap: Fn(Resp) -> NewOutput, { EthCall { - client: self.client, + caller: self.caller, data: self.data, overrides: self.overrides, block: self.block, @@ -224,7 +279,7 @@ where } /// Set the state overrides for this call. - pub const fn overrides(mut self, overrides: &'state StateOverride) -> Self { + pub const fn overrides(mut self, overrides: &'req StateOverride) -> Self { self.overrides = Some(overrides); self } @@ -236,8 +291,8 @@ where } } -impl<'req, 'state, T, N, Resp, Output, Map> std::future::IntoFuture - for EthCall<'req, 'state, T, N, Resp, Output, Map> +impl<'req, T, N, Resp, Output, Map> std::future::IntoFuture + for EthCall<'req, T, N, Resp, Output, Map> where T: Transport + Clone, N: Network, @@ -247,17 +302,19 @@ where { type Output = TransportResult; - type IntoFuture = EthCallFut<'req, 'state, T, N, Resp, Output, Map>; + type IntoFuture = EthCallFut<'req, T, N, Resp, Output, Map>; fn into_future(self) -> Self::IntoFuture { - EthCallFut(EthCallFutInner::Preparing { - client: self.client, - data: self.data, - overrides: self.overrides, - block: self.block, - method: self.method, - map: self.map, - }) + EthCallFut { + inner: EthCallFutInner::Preparing { + caller: self.caller, + data: self.data, + overrides: self.overrides, + block: self.block, + method: self.method, + map: self.map, + }, + } } } @@ -288,7 +345,7 @@ mod test { let overrides = StateOverride::default(); // Expected: [data] - let params: EthCallParams<'_, '_, Ethereum> = + let params: EthCallParams<'_, Ethereum> = EthCallParams { data: &data, block: None, overrides: None }; assert_eq!(params.data, &data); @@ -300,7 +357,7 @@ mod test { ); // Expected: [data, block, overrides] - let params: EthCallParams<'_, '_, Ethereum> = + let params: EthCallParams<'_, Ethereum> = EthCallParams { data: &data, block: Some(block), overrides: Some(&overrides) }; assert_eq!(params.data, &data); @@ -312,7 +369,7 @@ mod test { ); // Expected: [data, (default), overrides] - let params: EthCallParams<'_, '_, Ethereum> = + let params: EthCallParams<'_, Ethereum> = EthCallParams { data: &data, block: None, overrides: Some(&overrides) }; assert_eq!(params.data, &data); @@ -324,7 +381,7 @@ mod test { ); // Expected: [data, block] - let params: EthCallParams<'_, '_, Ethereum> = + let params: EthCallParams<'_, Ethereum> = EthCallParams { data: &data, block: Some(block), overrides: None }; assert_eq!(params.data, &data); diff --git a/crates/provider/src/provider/mod.rs b/crates/provider/src/provider/mod.rs index 0d8e939ed5a..8cd167497e6 100644 --- a/crates/provider/src/provider/mod.rs +++ b/crates/provider/src/provider/mod.rs @@ -1,5 +1,8 @@ -mod call; -pub use call::EthCall; +mod eth_call; +pub use eth_call::{EthCall, EthCallParams}; + +mod prov_call; +pub use prov_call::ProviderCall; mod root; pub use root::{builder, RootProvider}; @@ -14,4 +17,7 @@ mod wallet; pub use wallet::WalletProvider; mod with_block; -pub use with_block::RpcWithBlock; +pub use with_block::{ParamsWithBlock, RpcWithBlock}; + +mod caller; +pub use caller::Caller; diff --git a/crates/provider/src/provider/prov_call.rs b/crates/provider/src/provider/prov_call.rs new file mode 100644 index 00000000000..8e6b9f6119d --- /dev/null +++ b/crates/provider/src/provider/prov_call.rs @@ -0,0 +1,268 @@ +use alloy_json_rpc::{RpcParam, RpcReturn}; +use alloy_rpc_client::{RpcCall, Waiter}; +use alloy_transport::{Transport, TransportResult}; +use futures::FutureExt; +use pin_project::pin_project; +use serde_json::value::RawValue; +use std::{ + future::Future, + pin::Pin, + task::{self, Poll}, +}; +use tokio::sync::oneshot; + +/// The primary future type for the [`Provider`]. +/// +/// This future abstracts over several potential data sources. It allows +/// providers to: +/// - produce data via an [`RpcCall`] +/// - produce data by waiting on a batched RPC [`Waiter`] +/// - proudce data via an arbitrary boxed future +/// - produce data in any synchronous way +/// +/// [`Provider`]: crate::Provider +#[pin_project(project = ProviderCallProj)] +pub enum ProviderCall Output> +where + Conn: Transport + Clone, + Params: RpcParam, + Resp: RpcReturn, + Map: Fn(Resp) -> Output, +{ + /// An underlying call to an RPC server. + RpcCall(RpcCall), + /// A waiter for a batched call to a remote RPC server. + Waiter(Waiter), + /// A boxed future. + BoxedFuture(Pin> + Send>>), + /// The output, produces synchronously. + Ready(Option>), +} + +impl ProviderCall +where + Conn: Transport + Clone, + Params: RpcParam, + Resp: RpcReturn, + Map: Fn(Resp) -> Output, +{ + /// Instantiate a new [`ProviderCall`] from the output. + pub const fn ready(output: TransportResult) -> Self { + Self::Ready(Some(output)) + } + + /// True if this is an RPC call. + pub const fn is_rpc_call(&self) -> bool { + matches!(self, Self::RpcCall(_)) + } + + /// Fallible cast to [`RpcCall`] + pub const fn as_rpc_call(&self) -> Option<&RpcCall> { + match self { + Self::RpcCall(call) => Some(call), + _ => None, + } + } + + /// Fallible cast to mutable [`RpcCall`] + pub fn as_mut_rpc_call(&mut self) -> Option<&mut RpcCall> { + match self { + Self::RpcCall(call) => Some(call), + _ => None, + } + } + + /// True if this is a waiter. + pub const fn is_waiter(&self) -> bool { + matches!(self, Self::Waiter(_)) + } + + /// Fallible cast to [`Waiter`] + pub const fn as_waiter(&self) -> Option<&Waiter> { + match self { + Self::Waiter(waiter) => Some(waiter), + _ => None, + } + } + + /// Fallible cast to mutable [`Waiter`] + pub fn as_mut_waiter(&mut self) -> Option<&mut Waiter> { + match self { + Self::Waiter(waiter) => Some(waiter), + _ => None, + } + } + + /// True if this is a boxed future. + pub const fn is_boxed_future(&self) -> bool { + matches!(self, Self::BoxedFuture(_)) + } + + /// Fallible cast to a boxed future. + pub const fn as_boxed_future( + &self, + ) -> Option<&Pin> + Send>>> { + match self { + Self::BoxedFuture(fut) => Some(fut), + _ => None, + } + } + + /// True if this is a ready value. + pub const fn is_ready(&self) -> bool { + matches!(self, Self::Ready(_)) + } + + /// Fallible cast to a ready value. + /// + /// # Panics + /// + /// Panics if the future is already complete + pub const fn as_ready(&self) -> Option<&TransportResult> { + match self { + Self::Ready(Some(output)) => Some(output), + Self::Ready(None) => panic!("tried to access ready value after taking"), + _ => None, + } + } + + /// Set a function to map the response into a different type. This is + /// useful for transforming the response into a more usable type, e.g. + /// changing `U64` to `u64`. + /// + /// This function fails if the inner future is not an [`RpcCall`] or + /// [`Waiter`]. + /// + /// ## Note + /// + /// Carefully review the rust documentation on [fn pointers] before passing + /// them to this function. Unless the pointer is specifically coerced to a + /// `fn(_) -> _`, the `NewMap` will be inferred as that function's unique + /// type. This can lead to confusing error messages. + /// + /// [fn pointers]: https://doc.rust-lang.org/std/primitive.fn.html#creating-function-pointers + pub fn map_resp( + self, + map: NewMap, + ) -> Result, Self> + where + NewMap: Fn(Resp) -> NewOutput + Clone, + { + match self { + Self::RpcCall(call) => Ok(ProviderCall::RpcCall(call.map_resp(map))), + Self::Waiter(waiter) => Ok(ProviderCall::Waiter(waiter.map_resp(map))), + _ => Err(self), + } + } +} + +impl ProviderCall +where + Conn: Transport + Clone, + Params: RpcParam, + Params: ToOwned, + Params::Owned: RpcParam, + Resp: RpcReturn, + Map: Fn(Resp) -> Output, +{ + /// Convert this call into one with owned params, by cloning the params. + /// + /// # Panics + /// + /// Panics if called after the request has been polled. + pub fn into_owned_params(self) -> ProviderCall { + match self { + Self::RpcCall(call) => ProviderCall::RpcCall(call.into_owned_params()), + _ => panic!(), + } + } +} + +impl std::fmt::Debug for ProviderCall +where + Conn: Transport + Clone, + Params: RpcParam, + Resp: RpcReturn, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::RpcCall(call) => f.debug_tuple("RpcCall").field(call).finish(), + Self::Waiter { .. } => f.debug_struct("Waiter").finish_non_exhaustive(), + Self::BoxedFuture(_) => f.debug_struct("BoxedFuture").finish_non_exhaustive(), + Self::Ready(_) => f.debug_struct("Ready").finish_non_exhaustive(), + } + } +} + +impl From> + for ProviderCall +where + Conn: Transport + Clone, + Params: RpcParam, + Resp: RpcReturn, + Map: Fn(Resp) -> Output, +{ + fn from(call: RpcCall) -> Self { + Self::RpcCall(call) + } +} + +impl From> + for ProviderCall Resp> +where + Conn: Transport + Clone, + Params: RpcParam, + Resp: RpcReturn, +{ + fn from(waiter: Waiter) -> Self { + Self::Waiter(waiter) + } +} + +impl + From> + Send>>> + for ProviderCall +where + Conn: Transport + Clone, + Params: RpcParam, + Resp: RpcReturn, + Map: Fn(Resp) -> Output, +{ + fn from(fut: Pin> + Send>>) -> Self { + Self::BoxedFuture(fut) + } +} + +impl From>>> + for ProviderCall +where + Conn: Transport + Clone, + Params: RpcParam, + Resp: RpcReturn, +{ + fn from(rx: oneshot::Receiver>>) -> Self { + Waiter::from(rx).into() + } +} + +impl Future for ProviderCall +where + Conn: Transport + Clone, + Params: RpcParam, + Resp: RpcReturn, + Output: 'static, + Map: Fn(Resp) -> Output, +{ + type Output = TransportResult; + + fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> task::Poll { + match self.as_mut().project() { + ProviderCallProj::RpcCall(call) => call.poll_unpin(cx), + ProviderCallProj::Waiter(waiter) => waiter.poll_unpin(cx), + ProviderCallProj::BoxedFuture(fut) => fut.poll_unpin(cx), + ProviderCallProj::Ready(output) => { + Poll::Ready(output.take().expect("output taken twice")) + } + } + } +} diff --git a/crates/provider/src/provider/trait.rs b/crates/provider/src/provider/trait.rs index c522c8a3799..d04fa821786 100644 --- a/crates/provider/src/provider/trait.rs +++ b/crates/provider/src/provider/trait.rs @@ -1,10 +1,9 @@ //! Ethereum JSON-RPC provider. - use crate::{ heart::PendingTransactionError, utils::{self, Eip1559Estimation, EstimatorFunction}, EthCall, Identity, PendingTransaction, PendingTransactionBuilder, PendingTransactionConfig, - ProviderBuilder, RootProvider, RpcWithBlock, SendableTx, + ProviderBuilder, ProviderCall, RootProvider, RpcWithBlock, SendableTx, }; use alloy_eips::eip2718::Encodable2718; use alloy_json_rpc::{RpcError, RpcParam, RpcReturn}; @@ -16,7 +15,7 @@ use alloy_primitives::{ hex, Address, BlockHash, BlockNumber, Bytes, StorageKey, StorageValue, TxHash, B256, U128, U256, U64, }; -use alloy_rpc_client::{ClientRef, NoParams, PollerBuilder, RpcCall, WeakClient}; +use alloy_rpc_client::{ClientRef, NoParams, PollerBuilder, WeakClient}; use alloy_rpc_types_eth::{ AccessListResult, BlockId, BlockNumberOrTag, EIP1186AccountProofResponse, FeeHistory, Filter, FilterChanges, Log, SyncStatus, @@ -100,18 +99,24 @@ pub trait Provider: /// Gets the accounts in the remote node. This is usually empty unless you're using a local /// node. - async fn get_accounts(&self) -> TransportResult> { - self.client().request_noparams("eth_accounts").await + fn get_accounts(&self) -> ProviderCall> { + self.client().request_noparams("eth_accounts").into() } /// Returns the base fee per blob gas (blob gas price) in wei. - async fn get_blob_base_fee(&self) -> TransportResult { - self.client().request_noparams("eth_blobBaseFee").await.map(|fee: U128| fee.to::()) + fn get_blob_base_fee(&self) -> ProviderCall { + self.client() + .request_noparams("eth_blobBaseFee") + .map_resp(utils::convert_u128 as fn(U128) -> u128) + .into() } /// Get the last block number available. - fn get_block_number(&self) -> RpcCall { - self.client().request_noparams("eth_blockNumber").map_resp(crate::utils::convert_u64) + fn get_block_number(&self) -> ProviderCall { + self.client() + .request_noparams("eth_blockNumber") + .map_resp(crate::utils::convert_u64 as fn(U64) -> u64) + .into() } /// Execute a smart contract call with a transaction request and state @@ -138,29 +143,21 @@ pub trait Provider: /// # let tx = alloy_rpc_types_eth::transaction::TransactionRequest::default(); /// // Execute a call on the latest block, with no state overrides /// let output = provider.call(&tx).await?; - /// // Execute a call with a block ID. - /// let output = provider.call(&tx).block(1.into()).await?; - /// // Execute a call with state overrides. - /// let output = provider.call(&tx).overrides(&my_overrides).await?; /// # Ok(()) /// # } /// ``` - /// - /// # Note - /// - /// Not all client implementations support state overrides. #[doc(alias = "eth_call")] #[doc(alias = "call_with_overrides")] - fn call<'req, 'state>( - &self, - tx: &'req N::TransactionRequest, - ) -> EthCall<'req, 'state, T, N, Bytes> { + fn call<'req>(&self, tx: &'req N::TransactionRequest) -> EthCall<'req, T, N, Bytes> { EthCall::new(self.weak_client(), tx) } - /// Gets the chain ID. - fn get_chain_id(&self) -> RpcCall { - self.client().request_noparams("eth_chainId").map_resp(crate::utils::convert_u64) + /// Gets the chain ID. + fn get_chain_id(&self) -> ProviderCall { + self.client() + .request_noparams("eth_chainId") + .map_resp(crate::utils::convert_u64 as fn(U64) -> u64) + .into() } /// Create an [EIP-2930] access list. @@ -170,7 +167,7 @@ pub trait Provider: &self, request: &'a N::TransactionRequest, ) -> RpcWithBlock { - RpcWithBlock::new(self.weak_client(), "eth_createAccessList", request) + self.client().request("eth_createAccessList", request).into() } /// This function returns an [`EthCall`] which can be used to get a gas estimate, @@ -186,7 +183,7 @@ pub trait Provider: fn estimate_gas<'req>( &self, tx: &'req N::TransactionRequest, - ) -> EthCall<'req, 'static, T, N, U128, u128> { + ) -> EthCall<'req, T, N, U128, u128> { EthCall::gas_estimate(self.weak_client(), tx).map_resp(crate::utils::convert_u128) } @@ -242,21 +239,24 @@ pub trait Provider: } /// Gets the current gas price in wei. - fn get_gas_price(&self) -> RpcCall { - self.client().request_noparams("eth_gasPrice").map_resp(crate::utils::convert_u128) + fn get_gas_price(&self) -> ProviderCall { + self.client() + .request_noparams("eth_gasPrice") + .map_resp(crate::utils::convert_u128 as fn(U128) -> u128) + .into() } /// Retrieves account information ([Account](alloy_consensus::Account)) for the given [Address] /// at the particular [BlockId]. fn get_account(&self, address: Address) -> RpcWithBlock { - RpcWithBlock::new(self.weak_client(), "eth_getAccount", address) + self.client().request("eth_getAccount", address).into() } /// Gets the balance of the account. /// /// Defaults to the latest block. See also [`RpcWithBlock::block_id`]. - fn get_balance(&self, address: Address) -> RpcWithBlock { - RpcWithBlock::new(self.weak_client(), "eth_getBalance", address) + fn get_balance(&self, address: Address) -> RpcWithBlock { + self.client().request("eth_getBalance", address).into() } /// Gets a block by either its hash, tag, or number, with full transactions or only hashes. @@ -324,16 +324,16 @@ pub trait Provider: } /// Gets the selected block [BlockId] receipts. - async fn get_block_receipts( + fn get_block_receipts( &self, block: BlockId, - ) -> TransportResult>> { - self.client().request("eth_getBlockReceipts", (block,)).await + ) -> ProviderCall>> { + self.client().request("eth_getBlockReceipts", (block,)).into() } /// Gets the bytecode located at the corresponding [Address]. fn get_code_at(&self, address: Address) -> RpcWithBlock { - RpcWithBlock::new(self.weak_client(), "eth_getCode", address) + self.client().request("eth_getCode", address).into() } /// Watch for new blocks by polling the provider with @@ -501,7 +501,7 @@ pub trait Provider: address: Address, keys: Vec, ) -> RpcWithBlock), EIP1186AccountProofResponse> { - RpcWithBlock::new(self.weak_client(), "eth_getProof", (address, keys)) + self.client().request("eth_getProof", (address, keys)).into() } /// Gets the specified storage value from [Address]. @@ -510,15 +510,15 @@ pub trait Provider: address: Address, key: U256, ) -> RpcWithBlock { - RpcWithBlock::new(self.weak_client(), "eth_getStorageAt", (address, key)) + self.client().request("eth_getStorageAt", (address, key)).into() } /// Gets a transaction by its [TxHash]. - async fn get_transaction_by_hash( + fn get_transaction_by_hash( &self, hash: TxHash, - ) -> TransportResult> { - self.client().request("eth_getTransactionByHash", (hash,)).await + ) -> ProviderCall> { + self.client().request("eth_getTransactionByHash", (hash,)).into() } /// Returns the EIP-2718 encoded transaction if it exists, see also @@ -529,24 +529,32 @@ pub trait Provider: /// [TxEip4844](alloy_consensus::transaction::eip4844::TxEip4844). /// /// This can be decoded into [TxEnvelope](alloy_consensus::transaction::TxEnvelope). - async fn get_raw_transaction_by_hash(&self, hash: TxHash) -> TransportResult> { - self.client().request("eth_getRawTransactionByHash", (hash,)).await + fn get_raw_transaction_by_hash( + &self, + hash: TxHash, + ) -> ProviderCall> { + self.client().request("eth_getRawTransactionByHash", (hash,)).into() } /// Gets the transaction count (AKA "nonce") of the corresponding address. #[doc(alias = "get_nonce")] #[doc(alias = "get_account_nonce")] - fn get_transaction_count(&self, address: Address) -> RpcWithBlock { - RpcWithBlock::new(self.weak_client(), "eth_getTransactionCount", address) - .map_resp(crate::utils::convert_u64) + fn get_transaction_count( + &self, + address: Address, + ) -> RpcWithBlock u64> { + self.client() + .request("eth_getTransactionCount", address) + .map_resp(crate::utils::convert_u64 as fn(U64) -> u64) + .into() } /// Gets a transaction receipt if it exists, by its [TxHash]. - async fn get_transaction_receipt( + fn get_transaction_receipt( &self, hash: TxHash, - ) -> TransportResult> { - self.client().request("eth_getTransactionReceipt", (hash,)).await + ) -> ProviderCall> { + self.client().request("eth_getTransactionReceipt", (hash,)).into() } /// Gets an uncle block through the tag [BlockId] and index [u64]. @@ -581,11 +589,11 @@ pub trait Provider: } /// Returns a suggestion for the current `maxPriorityFeePerGas` in wei. - async fn get_max_priority_fee_per_gas(&self) -> TransportResult { + fn get_max_priority_fee_per_gas(&self) -> ProviderCall { self.client() .request_noparams("eth_maxPriorityFeePerGas") - .await - .map(|fee: U128| fee.to::()) + .map_resp(utils::convert_u128 as fn(U128) -> u128) + .into() } /// Notify the provider that we are interested in new blocks. @@ -869,25 +877,28 @@ pub trait Provider: } /// Gets syncing info. - async fn syncing(&self) -> TransportResult { - self.client().request_noparams("eth_syncing").await + fn syncing(&self) -> ProviderCall { + self.client().request_noparams("eth_syncing").into() } /// Gets the client version. #[doc(alias = "web3_client_version")] - async fn get_client_version(&self) -> TransportResult { - self.client().request_noparams("web3_clientVersion").await + fn get_client_version(&self) -> ProviderCall { + self.client().request_noparams("web3_clientVersion").into() } /// Gets the `Keccak-256` hash of the given data. #[doc(alias = "web3_sha3")] - async fn get_sha3(&self, data: &[u8]) -> TransportResult { - self.client().request("web3_sha3", (hex::encode_prefixed(data),)).await + fn get_sha3(&self, data: &[u8]) -> ProviderCall { + self.client().request("web3_sha3", (hex::encode_prefixed(data),)).into() } /// Gets the network ID. Same as `eth_chainId`. - fn get_net_version(&self) -> RpcCall { - self.client().request_noparams("net_version").map_resp(crate::utils::convert_u64) + fn get_net_version(&self) -> ProviderCall { + self.client() + .request_noparams("net_version") + .map_resp(crate::utils::convert_u64 as fn(U64) -> u64) + .into() } /* ---------------------------------------- raw calls --------------------------------------- */ @@ -1383,14 +1394,35 @@ mod tests { assert_eq!(0, num.to::()) } + #[cfg(feature = "anvil-api")] #[tokio::test] async fn gets_transaction_count() { init_tracing(); let provider = ProviderBuilder::new().on_anvil(); - let count = provider - .get_transaction_count(address!("328375e18E7db8F1CA9d9bA8bF3E9C94ee34136A")) - .await - .unwrap(); + let accounts = provider.get_accounts().await.unwrap(); + let sender = accounts[0]; + + // Initial tx count should be 0 + let count = provider.get_transaction_count(sender).await.unwrap(); + assert_eq!(count, 0); + + // Send Tx + let tx = TransactionRequest { + value: Some(U256::from(100)), + from: Some(sender), + to: Some(address!("d8dA6BF26964aF9D7eEd9e03E53415D37aA96045").into()), + gas_price: Some(20e9 as u128), + gas: Some(21000), + ..Default::default() + }; + let _ = provider.send_transaction(tx).await.unwrap().get_receipt().await; + + // Tx count should be 1 + let count = provider.get_transaction_count(sender).await.unwrap(); + assert_eq!(count, 1); + + // Tx count should be 0 at block 0 + let count = provider.get_transaction_count(sender).block_id(0.into()).await.unwrap(); assert_eq!(count, 0); } @@ -1652,6 +1684,9 @@ mod tests { .with_input(bytes!("06fdde03")); // `name()` let result = provider.call(&req).await.unwrap(); assert_eq!(String::abi_decode(&result, true).unwrap(), "Wrapped Ether"); + + let result = provider.call(&req).block(0.into()).await.unwrap(); + assert_eq!(result.to_string(), "0x"); } #[tokio::test] diff --git a/crates/provider/src/provider/with_block.rs b/crates/provider/src/provider/with_block.rs index 471e310c156..d3fcd300957 100644 --- a/crates/provider/src/provider/with_block.rs +++ b/crates/provider/src/provider/with_block.rs @@ -1,37 +1,60 @@ use alloy_eips::BlockId; -use alloy_json_rpc::{RpcError, RpcParam, RpcReturn}; +use alloy_json_rpc::{RpcParam, RpcReturn}; use alloy_primitives::B256; -use alloy_rpc_client::{RpcCall, WeakClient}; -use alloy_transport::{Transport, TransportErrorKind, TransportResult}; -use futures::FutureExt; -use std::{ - borrow::Cow, - future::{Future, IntoFuture}, - marker::PhantomData, - task::Poll, -}; +use alloy_rpc_client::RpcCall; +use alloy_transport::{Transport, TransportResult}; +use std::future::IntoFuture; -/// States of the [`RpcWithBlock`] future. -#[derive(Clone)] -enum States Output> +use crate::ProviderCall; + +/// Helper struct that houses the params along with the BlockId. +#[derive(Debug, Clone)] +pub struct ParamsWithBlock { + params: Params, + block_id: BlockId, +} + +impl serde::Serialize for ParamsWithBlock { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + // Serialize params to a Value first + let mut ser = serde_json::to_value(&self.params).map_err(serde::ser::Error::custom)?; + + // serialize the block id + let block_id = serde_json::to_value(self.block_id).map_err(serde::ser::Error::custom)?; + + if let serde_json::Value::Array(ref mut arr) = ser { + arr.push(block_id); + } else if ser.is_null() { + ser = serde_json::Value::Array(vec![block_id]); + } else { + ser = serde_json::Value::Array(vec![ser, block_id]); + } + + ser.serialize(serializer) + } +} + +type ProviderCallProducer = + Box ProviderCall, Resp, Output, Map> + Send>; + +/// Container for varous types of calls dependent on a block id. +enum WithBlockInner Output> where T: Transport + Clone, Params: RpcParam, Resp: RpcReturn, Map: Fn(Resp) -> Output, { - Invalid, - Preparing { - client: WeakClient, - method: Cow<'static, str>, - params: Params, - block_id: BlockId, - map: Map, - }, - Running(RpcCall), + /// [RpcCall] which params are getting wrapped into [ParamsWithBlock] once the block id is set. + RpcCall(RpcCall), + /// Closure that produces a [ProviderCall] once the block id is set. + ProviderCall(ProviderCallProducer), } -impl core::fmt::Debug for States +impl core::fmt::Debug for WithBlockInner where T: Transport + Clone, Params: RpcParam, @@ -40,163 +63,77 @@ where { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - Self::Invalid => f.debug_tuple("Invalid").finish(), - Self::Preparing { client, method, params, block_id, .. } => f - .debug_struct("Preparing") - .field("client", client) - .field("method", method) - .field("params", params) - .field("block_id", block_id) - .finish(), - Self::Running(arg0) => f.debug_tuple("Running").field(arg0).finish(), + Self::RpcCall(call) => f.debug_tuple("RpcCall").field(call).finish(), + Self::ProviderCall(_) => f.debug_struct("ProviderCall").finish(), } } } -/// A future for [`RpcWithBlock`]. Simple wrapper around [`RpcCall`]. -#[derive(Debug, Clone)] +/// A struct that takes an optional [`BlockId`] parameter. +/// +/// This resolves to a [`ProviderCall`] that will execute the call on the specified block. +/// +/// By default this will use "latest". #[pin_project::pin_project] -#[allow(unnameable_types)] -pub struct RpcWithBlockFut +#[derive(Debug)] +pub struct RpcWithBlock Output> where T: Transport + Clone, Params: RpcParam, Resp: RpcReturn, - Map: Fn(Resp) -> Output, + Map: Fn(Resp) -> Output + Clone, { - state: States, + inner: WithBlockInner, + block_id: BlockId, } -impl RpcWithBlockFut +impl RpcWithBlock where T: Transport + Clone, Params: RpcParam, Resp: RpcReturn, - Output: 'static, - Map: Fn(Resp) -> Output, + Map: Fn(Resp) -> Output + Clone, { - fn poll_preparing( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> Poll> { - let this = self.project(); - let States::Preparing { client, method, params, block_id, map } = - std::mem::replace(this.state, States::Invalid) - else { - unreachable!("bad state") - }; - - let mut fut = { - // make sure the client still exists - let client = match client.upgrade().ok_or_else(TransportErrorKind::backend_gone) { - Ok(client) => client, - Err(e) => return Poll::Ready(Err(e)), - }; - - // serialize the params - let ser = serde_json::to_value(params).map_err(RpcError::ser_err); - let mut ser = match ser { - Ok(ser) => ser, - Err(e) => return Poll::Ready(Err(e)), - }; - - // serialize the block id - let block_id = serde_json::to_value(block_id).map_err(RpcError::ser_err); - let block_id = match block_id { - Ok(block_id) => block_id, - Err(e) => return Poll::Ready(Err(e)), - }; - - // append the block id to the params - if let serde_json::Value::Array(ref mut arr) = ser { - arr.push(block_id); - } else if ser.is_null() { - ser = serde_json::Value::Array(vec![block_id]); - } else { - ser = serde_json::Value::Array(vec![ser, block_id]); - } - - // create the call - client.request(method.clone(), ser).map_resp(map) - }; - // poll the call immediately - match fut.poll_unpin(cx) { - Poll::Ready(value) => Poll::Ready(value), - Poll::Pending => { - *this.state = States::Running(fut); - Poll::Pending - } - } + /// Create a new [`RpcWithBlock`] from a [`RpcCall`]. + pub fn new_rpc(inner: RpcCall) -> Self { + Self { inner: WithBlockInner::RpcCall(inner), block_id: Default::default() } } - fn poll_running( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> Poll> { - let States::Running(call) = self.project().state else { unreachable!("bad state") }; - call.poll_unpin(cx) + /// Create a new [`RpcWithBlock`] from a closure producing a [`ProviderCall`]. + pub fn new_provider(get_call: F) -> Self + where + F: Fn(BlockId) -> ProviderCall, Resp, Output, Map> + + Send + + 'static, + { + let get_call = Box::new(get_call); + Self { inner: WithBlockInner::ProviderCall(get_call), block_id: Default::default() } } } -impl Future for RpcWithBlockFut +impl From> + for RpcWithBlock where T: Transport + Clone, Params: RpcParam, Resp: RpcReturn, - Output: 'static, - Map: Fn(Resp) -> Output, + Map: Fn(Resp) -> Output + Clone, { - type Output = TransportResult; - - fn poll(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll { - if matches!(self.state, States::Preparing { .. }) { - self.poll_preparing(cx) - } else if matches!(self.state, States::Running { .. }) { - self.poll_running(cx) - } else { - panic!("bad state") - } + fn from(inner: RpcCall) -> Self { + Self::new_rpc(inner) } } -/// An [`RpcCall`] that takes an optional [`BlockId`] parameter. By default -/// this will use "latest". -#[derive(Debug, Clone)] -pub struct RpcWithBlock Output> -where - T: Transport + Clone, - Params: RpcParam, - Resp: RpcReturn, - Map: Fn(Resp) -> Output, -{ - client: WeakClient, - method: Cow<'static, str>, - params: Params, - block_id: BlockId, - map: Map, - _pd: PhantomData (Resp, Output)>, -} - -impl RpcWithBlock +impl From for RpcWithBlock where T: Transport + Clone, Params: RpcParam, Resp: RpcReturn, + Map: Fn(Resp) -> Output + Clone, + F: Fn(BlockId) -> ProviderCall, Resp, Output, Map> + Send + 'static, { - /// Create a new [`RpcWithBlock`] instance. - pub fn new( - client: WeakClient, - method: impl Into>, - params: Params, - ) -> Self { - Self { - client, - method: method.into(), - params, - block_id: Default::default(), - map: std::convert::identity, - _pd: PhantomData, - } + fn from(inner: F) -> Self { + Self::new_provider(inner) } } @@ -205,26 +142,8 @@ where T: Transport + Clone, Params: RpcParam, Resp: RpcReturn, - Map: Fn(Resp) -> Output, + Map: Fn(Resp) -> Output + Clone, { - /// Map the response. - pub fn map_resp( - self, - map: NewMap, - ) -> RpcWithBlock - where - NewMap: Fn(Resp) -> NewOutput, - { - RpcWithBlock { - client: self.client, - method: self.method, - params: self.params, - block_id: self.block_id, - map, - _pd: PhantomData, - } - } - /// Set the block id. pub const fn block_id(mut self, block_id: BlockId) -> Self { self.block_id = block_id; @@ -280,21 +199,20 @@ where Params: RpcParam, Resp: RpcReturn, Output: 'static, - Map: Fn(Resp) -> Output, + Map: Fn(Resp) -> Output + Clone, { type Output = TransportResult; - type IntoFuture = RpcWithBlockFut; + type IntoFuture = ProviderCall, Resp, Output, Map>; fn into_future(self) -> Self::IntoFuture { - RpcWithBlockFut { - state: States::Preparing { - client: self.client, - method: self.method, - params: self.params, - block_id: self.block_id, - map: self.map, - }, + match self.inner { + WithBlockInner::RpcCall(rpc_call) => { + let block_id = self.block_id; + let rpc_call = rpc_call.map_params(|params| ParamsWithBlock { params, block_id }); + ProviderCall::RpcCall(rpc_call) + } + WithBlockInner::ProviderCall(get_call) => get_call(self.block_id), } } } diff --git a/crates/rpc-client/src/batch.rs b/crates/rpc-client/src/batch.rs index 7affdac2994..0993d22825f 100644 --- a/crates/rpc-client/src/batch.rs +++ b/crates/rpc-client/src/batch.rs @@ -4,7 +4,8 @@ use alloy_json_rpc::{ RpcReturn, SerializedRequest, }; use alloy_transport::{Transport, TransportError, TransportErrorKind, TransportResult}; -use futures::channel::oneshot; +use futures::FutureExt; +use pin_project::pin_project; use serde_json::value::RawValue; use std::{ borrow::Cow, @@ -12,8 +13,12 @@ use std::{ future::{Future, IntoFuture}, marker::PhantomData, pin::Pin, - task::{self, ready, Poll}, + task::{ + self, ready, + Poll::{self, Ready}, + }, }; +use tokio::sync::oneshot; pub(crate) type Channel = oneshot::Sender>>; pub(crate) type ChannelMap = HashMap; @@ -35,29 +40,58 @@ pub struct BatchRequest<'a, T> { /// Awaits a single response for a request that has been included in a batch. #[must_use = "A Waiter does nothing unless the corresponding BatchRequest is sent via `send_batch` and `.await`, AND the Waiter is awaited."] +#[pin_project] #[derive(Debug)] -pub struct Waiter { +pub struct Waiter Output> { + #[pin] rx: oneshot::Receiver>>, - _resp: PhantomData Resp>, + map: Option, + _resp: PhantomData (Output, Resp)>, +} + +impl Waiter { + /// Map the response to a different type. This is usable for converting + /// the response to a more usable type, e.g. changing `U64` to `u64`. + /// + /// ## Note + /// + /// Carefully review the rust documentation on [fn pointers] before passing + /// them to this function. Unless the pointer is specifically coerced to a + /// `fn(_) -> _`, the `NewMap` will be inferred as that function's unique + /// type. This can lead to confusing error messages. + /// + /// [fn pointers]: https://doc.rust-lang.org/std/primitive.fn.html#creating-function-pointers + pub fn map_resp(self, map: NewMap) -> Waiter + where + NewMap: FnOnce(Resp) -> NewOutput, + { + Waiter { rx: self.rx, map: Some(map), _resp: PhantomData } + } } impl From>>> for Waiter { fn from(rx: oneshot::Receiver>>) -> Self { - Self { rx, _resp: PhantomData } + Self { rx, map: Some(std::convert::identity), _resp: PhantomData } } } -impl std::future::Future for Waiter +impl std::future::Future for Waiter where Resp: RpcReturn, + Map: FnOnce(Resp) -> Output, { - type Output = TransportResult; + type Output = TransportResult; + + fn poll(self: std::pin::Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll { + let this = self.get_mut(); - fn poll(mut self: std::pin::Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll { - Pin::new(&mut self.rx).poll(cx).map(|resp| match resp { - Ok(resp) => try_deserialize_ok(resp), - Err(e) => Err(TransportErrorKind::custom(e)), - }) + match ready!(this.rx.poll_unpin(cx)) { + Ok(resp) => { + let resp: Result = try_deserialize_ok(resp); + Ready(resp.map(this.map.take().expect("polled after completion"))) + } + Err(e) => Poll::Ready(Err(TransportErrorKind::custom(e))), + } } } diff --git a/crates/rpc-client/src/call.rs b/crates/rpc-client/src/call.rs index a1efca47f50..8c9efc60fab 100644 --- a/crates/rpc-client/src/call.rs +++ b/crates/rpc-client/src/call.rs @@ -4,13 +4,14 @@ use alloy_json_rpc::{ }; use alloy_transport::{RpcFut, Transport, TransportError, TransportResult}; use core::panic; +use futures::FutureExt; use serde_json::value::RawValue; use std::{ fmt, future::Future, marker::PhantomData, pin::Pin, - task::{self, Poll::Ready}, + task::{self, ready, Poll::Ready}, }; use tower::Service; @@ -139,11 +140,11 @@ pub struct RpcCall Output> where Conn: Transport + Clone, Params: RpcParam, - Map: Fn(Resp) -> Output, + Map: FnOnce(Resp) -> Output, { #[pin] state: CallState, - map: Map, + map: Option, _pd: core::marker::PhantomData (Resp, Output)>, } @@ -151,7 +152,7 @@ impl core::fmt::Debug for RpcCall Output, + Map: FnOnce(Resp) -> Output, { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { f.debug_struct("RpcCall").field("state", &self.state).finish() @@ -167,7 +168,7 @@ where pub fn new(req: Request, connection: Conn) -> Self { Self { state: CallState::Prepared { request: Some(req), connection }, - map: std::convert::identity, + map: Some(std::convert::identity), _pd: PhantomData, } } @@ -177,17 +178,27 @@ impl RpcCall where Conn: Transport + Clone, Params: RpcParam, - Map: Fn(Resp) -> Output, + Map: FnOnce(Resp) -> Output, { - /// Set a function to map the response into a different type. + /// Map the response to a different type. This is usable for converting + /// the response to a more usable type, e.g. changing `U64` to `u64`. + /// + /// ## Note + /// + /// Carefully review the rust documentation on [fn pointers] before passing + /// them to this function. Unless the pointer is specifically coerced to a + /// `fn(_) -> _`, the `NewMap` will be inferred as that function's unique + /// type. This can lead to confusing error messages. + /// + /// [fn pointers]: https://doc.rust-lang.org/std/primitive.fn.html#creating-function-pointers pub fn map_resp( self, map: NewMap, ) -> RpcCall where - NewMap: Fn(Resp) -> NewOutput, + NewMap: FnOnce(Resp) -> NewOutput, { - RpcCall { state: self.state, map, _pd: PhantomData } + RpcCall { state: self.state, map: Some(map), _pd: PhantomData } } /// Returns `true` if the request is a subscription. @@ -249,20 +260,37 @@ where }; request.as_mut().expect("no request in prepared") } + + /// Map the params of the request into a new type. + pub fn map_params( + self, + map: impl Fn(Params) -> NewParams, + ) -> RpcCall { + let CallState::Prepared { request, connection } = self.state else { + panic!("Cannot get request after request has been sent"); + }; + let request = request.expect("no request in prepared").map_params(map); + RpcCall { + state: CallState::Prepared { request: Some(request), connection }, + map: self.map, + _pd: PhantomData, + } + } } impl RpcCall where Conn: Transport + Clone, - Params: RpcParam + Clone, - Map: Fn(Resp) -> Output, + Params: RpcParam + ToOwned, + Params::Owned: RpcParam, + Map: FnOnce(Resp) -> Output, { /// Convert this call into one with owned params, by cloning the params. /// /// # Panics /// - /// Panics if called after the request has been sent. - pub fn into_owned_params(self) -> RpcCall { + /// Panics if called after the request has been polled. + pub fn into_owned_params(self) -> RpcCall { let CallState::Prepared { request, connection } = self.state else { panic!("Cannot get params after request has been sent"); }; @@ -282,7 +310,7 @@ where Params: RpcParam + 'a, Resp: RpcReturn, Output: 'static, - Map: Fn(Resp) -> Output + Send + 'a, + Map: FnOnce(Resp) -> Output + Send + 'a, { /// Convert this future into a boxed, pinned future, erasing its type. pub fn boxed(self) -> RpcFut<'a, Output> { @@ -296,13 +324,16 @@ where Params: RpcParam, Resp: RpcReturn, Output: 'static, - Map: Fn(Resp) -> Output, + Map: FnOnce(Resp) -> Output, { type Output = TransportResult; fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> task::Poll { trace!(?self.state, "polling RpcCall"); - let this = self.project(); - this.state.poll(cx).map(try_deserialize_ok).map(|r| r.map(this.map)) + + let this = self.get_mut(); + let resp = try_deserialize_ok(ready!(this.state.poll_unpin(cx))); + + Ready(resp.map(this.map.take().expect("polled after completion"))) } }