Skip to content
This repository has been archived by the owner on Oct 19, 2024. It is now read-only.

Add dev-rpc middleware #640

Merged
merged 5 commits into from
Dec 3, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@
- Fix `http Provider` data race when generating new request `id`s.
- Add support for `net_version` RPC method.
[595](https://github.com/gakonst/ethers-rs/pull/595)
- Add support for `evm_snapshot` and `evm_revert` dev RPC methods.
[640](https://github.com/gakonst/ethers-rs/pull/640)

### Unreleased

Expand Down
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ ws = ["ethers-providers/ws"]
ipc = ["ethers-providers/ipc"]
rustls = ["ethers-providers/rustls"]
openssl = ["ethers-providers/openssl"]
dev-rpc = ["ethers-providers/dev-rpc"]
## signers
ledger = ["ethers-signers/ledger"]
yubi = ["ethers-signers/yubi"]
Expand Down
1 change: 1 addition & 0 deletions ethers-providers/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,4 @@ ipc = ["tokio", "tokio/io-util", "tokio-util", "bytes"]

openssl = ["tokio-tungstenite/native-tls", "reqwest/native-tls"]
rustls = ["tokio-tungstenite/rustls-tls", "reqwest/rustls-tls"]
dev-rpc = []
4 changes: 4 additions & 0 deletions ethers-providers/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ use std::{error::Error, fmt::Debug, future::Future, pin::Pin, str::FromStr};

pub use provider::{FilterKind, Provider, ProviderError};

// feature-enabled support for dev-rpc methods
#[cfg(feature = "dev-rpc")]
pub use provider::dev_rpc::DevRpcMiddleware;

/// A simple gas escalation policy
pub type EscalationPolicy = Box<dyn Fn(U256, usize) -> U256 + Send + Sync>;

Expand Down
177 changes: 177 additions & 0 deletions ethers-providers/src/provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1021,6 +1021,183 @@ impl TryFrom<String> for Provider<HttpProvider> {
}
}

/// A middleware supporting development-specific JSON RPC methods
///
/// # Example
///
///```
/// use ethers_providers::{Provider, Http, Middleware, DevRpcMiddleware};
/// use ethers_core::types::TransactionRequest;
/// use ethers_core::utils::Ganache;
/// use std::convert::TryFrom;
///
/// # #[tokio::main]
/// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
/// let ganache = Ganache::new().spawn();
/// let provider = Provider::<Http>::try_from(ganache.endpoint()).unwrap();
/// let client = DevRpcMiddleware::new(provider);
///
/// // snapshot the initial state
/// let block0 = client.get_block_number().await.unwrap();
/// let snap_id = client.snapshot().await.unwrap();
///
/// // send a transaction
/// let accounts = client.get_accounts().await?;
/// let from = accounts[0];
/// let to = accounts[1];
/// let balance_before = client.get_balance(to, None).await?;
/// let tx = TransactionRequest::new().to(to).value(1000).from(from);
/// client.send_transaction(tx, None).await?.await?;
/// let balance_after = client.get_balance(to, None).await?;
/// assert_eq!(balance_after, balance_before + 1000);
///
/// // revert to snapshot
/// client.revert_to_snapshot(snap_id).await.unwrap();
/// let balance_after_revert = client.get_balance(to, None).await?;
/// assert_eq!(balance_after_revert, balance_before);
/// # Ok(())
/// # }
/// ```
#[cfg(feature = "dev-rpc")]
pub mod dev_rpc {
use crate::{FromErr, Middleware, ProviderError};
use async_trait::async_trait;
use ethers_core::types::U256;
use thiserror::Error;

use std::fmt::Debug;

#[derive(Clone, Debug)]
pub struct DevRpcMiddleware<M>(M);

#[derive(Error, Debug)]
pub enum DevRpcMiddlewareError<M: Middleware> {
#[error("{0}")]
MiddlewareError(M::Error),

#[error("{0}")]
ProviderError(ProviderError),

#[error("Could not revert to snapshot")]
NoSnapshot,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe I misunderstood this lint, but it seems to not jive with the rest of the Error naming convention for ethers.

}

#[async_trait]
impl<M: Middleware> Middleware for DevRpcMiddleware<M> {
type Error = DevRpcMiddlewareError<M>;
type Provider = M::Provider;
type Inner = M;

fn inner(&self) -> &M {
&self.0
}
}

impl<M: Middleware> FromErr<M::Error> for DevRpcMiddlewareError<M> {
fn from(src: M::Error) -> DevRpcMiddlewareError<M> {
DevRpcMiddlewareError::MiddlewareError(src)
}
}

impl<M> From<ProviderError> for DevRpcMiddlewareError<M>
where
M: Middleware,
{
fn from(src: ProviderError) -> Self {
Self::ProviderError(src)
}
}

impl<M: Middleware> DevRpcMiddleware<M> {
pub fn new(inner: M) -> Self {
Self(inner)
}

// both ganache and hardhat increment snapshot id even if no state has changed
pub async fn snapshot(&self) -> Result<U256, DevRpcMiddlewareError<M>> {
self.provider().request::<(), U256>("evm_snapshot", ()).await.map_err(From::from)
}

pub async fn revert_to_snapshot(&self, id: U256) -> Result<(), DevRpcMiddlewareError<M>> {
let ok = self
.provider()
.request::<[U256; 1], bool>("evm_revert", [id])
.await
.map_err(DevRpcMiddlewareError::ProviderError)?;
if ok {
Ok(())
} else {
Err(DevRpcMiddlewareError::NoSnapshot)
}
}
}
#[cfg(test)]
// Celo blocks can not get parsed when used with Ganache
#[cfg(not(feature = "celo"))]
mod tests {
use super::*;
use crate::{Http, Provider};
use ethers_core::utils::Ganache;
use std::convert::TryFrom;

#[tokio::test]
async fn test_snapshot() {
// launch ganache
let ganache = Ganache::new().spawn();
let provider = Provider::<Http>::try_from(ganache.endpoint()).unwrap();
let client = DevRpcMiddleware::new(provider);

// snapshot initial state
let block0 = client.get_block_number().await.unwrap();
let time0 = client.get_block(block0).await.unwrap().unwrap().timestamp;
let snap_id0 = client.snapshot().await.unwrap();

// mine a new block
client.provider().mine(1).await.unwrap();

// snapshot state
let block1 = client.get_block_number().await.unwrap();
let time1 = client.get_block(block1).await.unwrap().unwrap().timestamp;
let snap_id1 = client.snapshot().await.unwrap();

// mine some blocks
client.provider().mine(5).await.unwrap();

// snapshot state
let block2 = client.get_block_number().await.unwrap();
let time2 = client.get_block(block2).await.unwrap().unwrap().timestamp;
let snap_id2 = client.snapshot().await.unwrap();

// mine some blocks
client.provider().mine(5).await.unwrap();

// revert_to_snapshot should reset state to snap id
client.revert_to_snapshot(snap_id2).await.unwrap();
let block = client.get_block_number().await.unwrap();
let time = client.get_block(block).await.unwrap().unwrap().timestamp;
assert_eq!(block, block2);
assert_eq!(time, time2);

client.revert_to_snapshot(snap_id1).await.unwrap();
let block = client.get_block_number().await.unwrap();
let time = client.get_block(block).await.unwrap().unwrap().timestamp;
assert_eq!(block, block1);
assert_eq!(time, time1);

// revert_to_snapshot should throw given non-existent or
// previously used snapshot
let result = client.revert_to_snapshot(snap_id1).await;
assert!(result.is_err());

client.revert_to_snapshot(snap_id0).await.unwrap();
let block = client.get_block_number().await.unwrap();
let time = client.get_block(block).await.unwrap().unwrap().timestamp;
assert_eq!(block, block0);
assert_eq!(time, time0);
}
}
}

#[cfg(test)]
#[cfg(not(target_arch = "wasm32"))]
mod tests {
Expand Down