Skip to content

Commit

Permalink
Implement State Overrides for Simulations (#18)
Browse files Browse the repository at this point in the history
* Implement State Overrides for Simulations

* README documentation and E2E test

* Easier to see override in output
  • Loading branch information
nlordell authored Sep 25, 2023
1 parent c189669 commit b4eea73
Show file tree
Hide file tree
Showing 5 changed files with 198 additions and 36 deletions.
9 changes: 9 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -233,9 +233,18 @@ export type SimulationRequest = {
gasLimit: number;
value: string;
blockNumber?: number; // if not specified, latest used,
stateOverrides?: Record<string, StateOverride>;
formatTrace?: boolean;
};

export type StateOverride = {
balance?: string;
nonce?: number;
code?: string;
state?: Record<string, string>;
stateDiff?: Record<string, string>;
};

export type SimulationResponse = {
simulationId: string;
gasUsed: number;
Expand Down
24 changes: 8 additions & 16 deletions src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,6 @@ pub struct ErrorMessage {
pub message: String,
}

#[derive(Debug)]
pub struct FromHexError;

impl Reject for FromHexError {}

#[derive(Debug)]
pub struct FromDecStrError;

impl Reject for FromDecStrError {}

#[derive(Debug)]
pub struct NoURLForChainIdError;

Expand Down Expand Up @@ -50,6 +40,11 @@ pub struct StateNotFound();

impl Reject for StateNotFound {}

#[derive(Debug)]
pub struct OverrideError;

impl Reject for OverrideError {}

#[derive(Debug)]
pub struct EvmError(pub Report);

Expand All @@ -65,12 +60,6 @@ pub async fn handle_rejection(err: Rejection) -> Result<impl Reply, Infallible>
} else if let Some(_e) = err.find::<StateNotFound>() {
code = StatusCode::NOT_FOUND;
message = "STATE_NOT_FOUND".to_string();
} else if let Some(FromHexError) = err.find() {
code = StatusCode::BAD_REQUEST;
message = "FROM_HEX_ERROR".to_string();
} else if let Some(FromDecStrError) = err.find() {
code = StatusCode::BAD_REQUEST;
message = "FROM_DEC_STR_ERROR".to_string();
} else if let Some(NoURLForChainIdError) = err.find() {
code = StatusCode::BAD_REQUEST;
message = "CHAIN_ID_NOT_SUPPORTED".to_string();
Expand All @@ -86,6 +75,9 @@ pub async fn handle_rejection(err: Rejection) -> Result<impl Reply, Infallible>
} else if let Some(_e) = err.find::<InvalidBlockNumbersError>() {
code = StatusCode::BAD_REQUEST;
message = "INVALID_BLOCK_NUMBERS".to_string();
} else if let Some(_e) = err.find::<OverrideError>() {
code = StatusCode::INTERNAL_SERVER_ERROR;
message = "OVERRIDE_ERROR".to_string();
} else if let Some(_e) = err.find::<EvmError>() {
if _e.0.to_string().contains("CallGasCostMoreThanGasLimit") {
code = StatusCode::BAD_REQUEST;
Expand Down
67 changes: 64 additions & 3 deletions src/evm.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
use ethers::abi::{Address, Uint};
use std::collections::HashMap;

use ethers::abi::{Address, Hash, Uint};
use ethers::core::types::Log;
use ethers::types::Bytes;
use foundry_config::Chain;
Expand All @@ -7,10 +9,13 @@ use foundry_evm::executor::{opts::EvmOpts, Backend, ExecutorBuilder};
use foundry_evm::trace::identifier::{EtherscanIdentifier, SignaturesIdentifier};
use foundry_evm::trace::node::CallTraceNode;
use foundry_evm::trace::{CallTraceArena, CallTraceDecoder, CallTraceDecoderBuilder};
use foundry_evm::utils::{h160_to_b160, u256_to_ru256};
use revm::db::DatabaseRef;
use revm::interpreter::InstructionResult;
use revm::primitives::Env;
use revm::primitives::{Account, Bytecode, Env, StorageSlot};
use revm::DatabaseCommit;

use crate::errors::EvmError;
use crate::errors::{EvmError, OverrideError};
use crate::simulation::CallTrace;

#[derive(Debug, Clone)]
Expand All @@ -36,6 +41,12 @@ impl From<CallTraceNode> for CallTrace {
}
}

#[derive(Debug, Clone, PartialEq)]
pub struct StorageOverride {
pub slots: HashMap<Hash, Uint>,
pub diff: bool,
}

pub struct Evm {
executor: Executor,
decoder: CallTraceDecoder,
Expand Down Expand Up @@ -155,6 +166,56 @@ impl Evm {
})
}

pub fn override_account(
&mut self,
address: Address,
balance: Option<Uint>,
nonce: Option<u64>,
code: Option<Bytes>,
storage: Option<StorageOverride>,
) -> Result<(), OverrideError> {
let address = h160_to_b160(address);
let mut account = Account {
info: self
.executor
.backend()
.basic(address)
.map_err(|_| OverrideError)?
.unwrap_or_default(),
..Account::new_not_existing()
};

if let Some(balance) = balance {
account.info.balance = u256_to_ru256(balance);
}
if let Some(nonce) = nonce {
account.info.nonce = nonce;
}
if let Some(code) = code {
account.info.code = Some(Bytecode::new_raw(code.to_vec().into()));
}
if let Some(storage) = storage {
// If we do a "full storage override", make sure to set this flag so
// that existing storage slots are cleared, and unknown ones aren't
// fetched from the forked node.
account.storage_cleared = !storage.diff;
account
.storage
.extend(storage.slots.into_iter().map(|(key, value)| {
(
u256_to_ru256(Uint::from_big_endian(key.as_bytes())),
StorageSlot::new(u256_to_ru256(value)),
)
}));
}

self.executor
.backend_mut()
.commit([(address, account)].into_iter().collect());

Ok(())
}

pub async fn call_raw_committing(
&mut self,
from: Address,
Expand Down
99 changes: 82 additions & 17 deletions src/simulation.rs
Original file line number Diff line number Diff line change
@@ -1,23 +1,24 @@
use std::collections::HashMap;
use std::str::FromStr;
use std::sync::Arc;

use dashmap::mapref::one::RefMut;
use ethers::abi::{Address, Uint};
use ethers::abi::{Address, Hash, Uint};
use ethers::core::types::Log;
use ethers::types::Bytes;
use foundry_evm::CallKind;
use revm::interpreter::InstructionResult;
use serde::{Deserialize, Serialize};
use tokio::sync::Mutex;
use uuid::Uuid;
use warp::reject::custom;
use warp::reply::Json;
use warp::Rejection;

use crate::errors::{
FromDecStrError, FromHexError, IncorrectChainIdError, InvalidBlockNumbersError,
MultipleChainIdsError, NoURLForChainIdError, StateNotFound,
IncorrectChainIdError, InvalidBlockNumbersError, MultipleChainIdsError, NoURLForChainIdError,
StateNotFound,
};
use crate::evm::StorageOverride;
use crate::SharedSimulationState;

use super::config::Config;
Expand All @@ -31,8 +32,9 @@ pub struct SimulationRequest {
pub to: Address,
pub data: Option<Bytes>,
pub gas_limit: u64,
pub value: Option<String>,
pub value: Option<PermissiveUint>,
pub block_number: Option<u64>,
pub state_overrides: Option<HashMap<Address, StateOverride>>,
pub format_trace: Option<bool>,
}

Expand Down Expand Up @@ -69,6 +71,44 @@ pub struct StatefulSimulationEndResponse {
pub success: bool,
}

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct StateOverride {
pub balance: Option<PermissiveUint>,
pub nonce: Option<u64>,
pub code: Option<Bytes>,
#[serde(flatten)]
pub state: Option<State>,
}

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(untagged)]
pub enum State {
Full {
state: HashMap<Hash, PermissiveUint>,
},
#[serde(rename_all = "camelCase")]
Diff {
state_diff: HashMap<Hash, PermissiveUint>,
},
}

impl From<State> for StorageOverride {
fn from(value: State) -> Self {
let (slots, diff) = match value {
State::Full { state } => (state, false),
State::Diff { state_diff } => (state_diff, true),
};

StorageOverride {
slots: slots
.into_iter()
.map(|(key, value)| (key, value.into()))
.collect(),
diff,
}
}
}

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct CallTrace {
Expand All @@ -78,6 +118,32 @@ pub struct CallTrace {
pub value: Uint,
}

#[derive(Debug, Default, Clone, Copy, Serialize, PartialEq)]
#[serde(transparent)]
pub struct PermissiveUint(pub Uint);

impl From<PermissiveUint> for Uint {
fn from(value: PermissiveUint) -> Self {
value.0
}
}

impl<'de> Deserialize<'de> for PermissiveUint {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
// Accept value in hex or decimal formats
let value = String::deserialize(deserializer)?;
let parsed = if value.starts_with("0x") {
Uint::from_str(&value).map_err(serde::de::Error::custom)?
} else {
Uint::from_dec_str(&value).map_err(serde::de::Error::custom)?
};
Ok(Self(parsed))
}
}

fn chain_id_to_fork_url(chain_id: u64) -> Result<String, Rejection> {
match chain_id {
// ethereum
Expand Down Expand Up @@ -113,22 +179,21 @@ async fn run(
transaction: SimulationRequest,
commit: bool,
) -> Result<SimulationResponse, Rejection> {
// Accept value in hex or decimal formats
let value = if let Some(value) = transaction.value {
if value.starts_with("0x") {
Some(Uint::from_str(value.as_str()).map_err(|_err| custom(FromHexError))?)
} else {
Some(Uint::from_dec_str(value.as_str()).map_err(|_err| custom(FromDecStrError))?)
}
} else {
None
};
for (address, state_override) in transaction.state_overrides.into_iter().flatten() {
evm.override_account(
address,
state_override.balance.map(Uint::from),
state_override.nonce,
state_override.code,
state_override.state.map(StorageOverride::from),
)?;
}

let result = if commit {
evm.call_raw_committing(
transaction.from,
transaction.to,
value,
transaction.value.map(Uint::from),
transaction.data,
transaction.gas_limit,
transaction.format_trace.unwrap_or_default(),
Expand All @@ -138,7 +203,7 @@ async fn run(
evm.call_raw(
transaction.from,
transaction.to,
value,
transaction.value.map(Uint::from),
transaction.data,
transaction.format_trace.unwrap_or_default(),
)
Expand Down
35 changes: 35 additions & 0 deletions tests/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,41 @@ async fn post_simulate_zerox_swap() {
assert!(body.success);
}

#[tokio::test(flavor = "multi_thread")]
async fn post_simulate_state_overrides() {
let filter = filter();

let json = serde_json::json!({
"chainId": 1,
"from": "0xd8dA6BF26964aF9D7eEd9e03E53415D37aA96045",
"to": "0xDEf1CA1fb7FBcDC777520aa7f396b4E015F497aB",
"data": "0x70a08231000000000000000000000000d8da6bf26964af9d7eed9e03e53415d37aa96045",
"gasLimit": 5000000,
"stateOverrides": {
"0xDEf1CA1fb7FBcDC777520aa7f396b4E015F497aB": {
"stateDiff": {
"0xfca351f4d96129454cfc8ef7930b638ac71fea35eb69ee3b8d959496beb04a33":
"123456789012345678901234567890"
}
}
}
});

let res = warp::test::request()
.method("POST")
.path("/simulate")
.json(&json)
.reply(&filter)
.await;

assert_eq!(res.status(), 200);

let body: SimulationResponse = serde_json::from_slice(res.body()).unwrap();
let result = U256::from_big_endian(&body.return_data);

assert_eq!(result.as_u128(), 123456789012345678901234567890);
}

#[tokio::test(flavor = "multi_thread")]
async fn post_simulate_bundle_single_zerox_swap() {
let filter = filter();
Expand Down

0 comments on commit b4eea73

Please sign in to comment.