diff --git a/README.md b/README.md index 5eb4aa5..30cc767 100644 --- a/README.md +++ b/README.md @@ -233,9 +233,18 @@ export type SimulationRequest = { gasLimit: number; value: string; blockNumber?: number; // if not specified, latest used, + stateOverrides?: Record; formatTrace?: boolean; }; +export type StateOverride = { + balance?: string; + nonce?: number; + code?: string; + state?: Record; + stateDiff?: Record; +}; + export type SimulationResponse = { simulationId: string; gasUsed: number; diff --git a/src/errors.rs b/src/errors.rs index 2265674..68b1ef4 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -10,16 +10,6 @@ pub struct ErrorMessage { pub message: String, } -#[derive(Debug)] -pub struct FromHexError; - -impl Reject for FromHexError {} - -#[derive(Debug)] -pub struct FromDecStrError; - -impl Reject for FromDecStrError {} - #[derive(Debug)] pub struct NoURLForChainIdError; @@ -50,6 +40,11 @@ pub struct StateNotFound(); impl Reject for StateNotFound {} +#[derive(Debug)] +pub struct OverrideError; + +impl Reject for OverrideError {} + #[derive(Debug)] pub struct EvmError(pub Report); @@ -65,12 +60,6 @@ pub async fn handle_rejection(err: Rejection) -> Result } else if let Some(_e) = err.find::() { code = StatusCode::NOT_FOUND; message = "STATE_NOT_FOUND".to_string(); - } else if let Some(FromHexError) = err.find() { - code = StatusCode::BAD_REQUEST; - message = "FROM_HEX_ERROR".to_string(); - } else if let Some(FromDecStrError) = err.find() { - code = StatusCode::BAD_REQUEST; - message = "FROM_DEC_STR_ERROR".to_string(); } else if let Some(NoURLForChainIdError) = err.find() { code = StatusCode::BAD_REQUEST; message = "CHAIN_ID_NOT_SUPPORTED".to_string(); @@ -86,6 +75,9 @@ pub async fn handle_rejection(err: Rejection) -> Result } else if let Some(_e) = err.find::() { code = StatusCode::BAD_REQUEST; message = "INVALID_BLOCK_NUMBERS".to_string(); + } else if let Some(_e) = err.find::() { + code = StatusCode::INTERNAL_SERVER_ERROR; + message = "OVERRIDE_ERROR".to_string(); } else if let Some(_e) = err.find::() { if _e.0.to_string().contains("CallGasCostMoreThanGasLimit") { code = StatusCode::BAD_REQUEST; diff --git a/src/evm.rs b/src/evm.rs index b460870..3720977 100644 --- a/src/evm.rs +++ b/src/evm.rs @@ -1,4 +1,6 @@ -use ethers::abi::{Address, Uint}; +use std::collections::HashMap; + +use ethers::abi::{Address, Hash, Uint}; use ethers::core::types::Log; use ethers::types::Bytes; use foundry_config::Chain; @@ -7,10 +9,13 @@ use foundry_evm::executor::{opts::EvmOpts, Backend, ExecutorBuilder}; use foundry_evm::trace::identifier::{EtherscanIdentifier, SignaturesIdentifier}; use foundry_evm::trace::node::CallTraceNode; use foundry_evm::trace::{CallTraceArena, CallTraceDecoder, CallTraceDecoderBuilder}; +use foundry_evm::utils::{h160_to_b160, u256_to_ru256}; +use revm::db::DatabaseRef; use revm::interpreter::InstructionResult; -use revm::primitives::Env; +use revm::primitives::{Account, Bytecode, Env, StorageSlot}; +use revm::DatabaseCommit; -use crate::errors::EvmError; +use crate::errors::{EvmError, OverrideError}; use crate::simulation::CallTrace; #[derive(Debug, Clone)] @@ -36,6 +41,12 @@ impl From for CallTrace { } } +#[derive(Debug, Clone, PartialEq)] +pub struct StorageOverride { + pub slots: HashMap, + pub diff: bool, +} + pub struct Evm { executor: Executor, decoder: CallTraceDecoder, @@ -155,6 +166,56 @@ impl Evm { }) } + pub fn override_account( + &mut self, + address: Address, + balance: Option, + nonce: Option, + code: Option, + storage: Option, + ) -> Result<(), OverrideError> { + let address = h160_to_b160(address); + let mut account = Account { + info: self + .executor + .backend() + .basic(address) + .map_err(|_| OverrideError)? + .unwrap_or_default(), + ..Account::new_not_existing() + }; + + if let Some(balance) = balance { + account.info.balance = u256_to_ru256(balance); + } + if let Some(nonce) = nonce { + account.info.nonce = nonce; + } + if let Some(code) = code { + account.info.code = Some(Bytecode::new_raw(code.to_vec().into())); + } + if let Some(storage) = storage { + // If we do a "full storage override", make sure to set this flag so + // that existing storage slots are cleared, and unknown ones aren't + // fetched from the forked node. + account.storage_cleared = !storage.diff; + account + .storage + .extend(storage.slots.into_iter().map(|(key, value)| { + ( + u256_to_ru256(Uint::from_big_endian(key.as_bytes())), + StorageSlot::new(u256_to_ru256(value)), + ) + })); + } + + self.executor + .backend_mut() + .commit([(address, account)].into_iter().collect()); + + Ok(()) + } + pub async fn call_raw_committing( &mut self, from: Address, diff --git a/src/simulation.rs b/src/simulation.rs index 4cd7c48..6ab293b 100644 --- a/src/simulation.rs +++ b/src/simulation.rs @@ -1,8 +1,9 @@ +use std::collections::HashMap; use std::str::FromStr; use std::sync::Arc; use dashmap::mapref::one::RefMut; -use ethers::abi::{Address, Uint}; +use ethers::abi::{Address, Hash, Uint}; use ethers::core::types::Log; use ethers::types::Bytes; use foundry_evm::CallKind; @@ -10,14 +11,14 @@ use revm::interpreter::InstructionResult; use serde::{Deserialize, Serialize}; use tokio::sync::Mutex; use uuid::Uuid; -use warp::reject::custom; use warp::reply::Json; use warp::Rejection; use crate::errors::{ - FromDecStrError, FromHexError, IncorrectChainIdError, InvalidBlockNumbersError, - MultipleChainIdsError, NoURLForChainIdError, StateNotFound, + IncorrectChainIdError, InvalidBlockNumbersError, MultipleChainIdsError, NoURLForChainIdError, + StateNotFound, }; +use crate::evm::StorageOverride; use crate::SharedSimulationState; use super::config::Config; @@ -31,8 +32,9 @@ pub struct SimulationRequest { pub to: Address, pub data: Option, pub gas_limit: u64, - pub value: Option, + pub value: Option, pub block_number: Option, + pub state_overrides: Option>, pub format_trace: Option, } @@ -69,6 +71,44 @@ pub struct StatefulSimulationEndResponse { pub success: bool, } +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct StateOverride { + pub balance: Option, + pub nonce: Option, + pub code: Option, + #[serde(flatten)] + pub state: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +#[serde(untagged)] +pub enum State { + Full { + state: HashMap, + }, + #[serde(rename_all = "camelCase")] + Diff { + state_diff: HashMap, + }, +} + +impl From for StorageOverride { + fn from(value: State) -> Self { + let (slots, diff) = match value { + State::Full { state } => (state, false), + State::Diff { state_diff } => (state_diff, true), + }; + + StorageOverride { + slots: slots + .into_iter() + .map(|(key, value)| (key, value.into())) + .collect(), + diff, + } + } +} + #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] #[serde(rename_all = "camelCase")] pub struct CallTrace { @@ -78,6 +118,32 @@ pub struct CallTrace { pub value: Uint, } +#[derive(Debug, Default, Clone, Copy, Serialize, PartialEq)] +#[serde(transparent)] +pub struct PermissiveUint(pub Uint); + +impl From for Uint { + fn from(value: PermissiveUint) -> Self { + value.0 + } +} + +impl<'de> Deserialize<'de> for PermissiveUint { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + // Accept value in hex or decimal formats + let value = String::deserialize(deserializer)?; + let parsed = if value.starts_with("0x") { + Uint::from_str(&value).map_err(serde::de::Error::custom)? + } else { + Uint::from_dec_str(&value).map_err(serde::de::Error::custom)? + }; + Ok(Self(parsed)) + } +} + fn chain_id_to_fork_url(chain_id: u64) -> Result { match chain_id { // ethereum @@ -113,22 +179,21 @@ async fn run( transaction: SimulationRequest, commit: bool, ) -> Result { - // Accept value in hex or decimal formats - let value = if let Some(value) = transaction.value { - if value.starts_with("0x") { - Some(Uint::from_str(value.as_str()).map_err(|_err| custom(FromHexError))?) - } else { - Some(Uint::from_dec_str(value.as_str()).map_err(|_err| custom(FromDecStrError))?) - } - } else { - None - }; + for (address, state_override) in transaction.state_overrides.into_iter().flatten() { + evm.override_account( + address, + state_override.balance.map(Uint::from), + state_override.nonce, + state_override.code, + state_override.state.map(StorageOverride::from), + )?; + } let result = if commit { evm.call_raw_committing( transaction.from, transaction.to, - value, + transaction.value.map(Uint::from), transaction.data, transaction.gas_limit, transaction.format_trace.unwrap_or_default(), @@ -138,7 +203,7 @@ async fn run( evm.call_raw( transaction.from, transaction.to, - value, + transaction.value.map(Uint::from), transaction.data, transaction.format_trace.unwrap_or_default(), ) diff --git a/tests/api.rs b/tests/api.rs index 2c1ba6c..32b2bb7 100644 --- a/tests/api.rs +++ b/tests/api.rs @@ -154,6 +154,41 @@ async fn post_simulate_zerox_swap() { assert!(body.success); } +#[tokio::test(flavor = "multi_thread")] +async fn post_simulate_state_overrides() { + let filter = filter(); + + let json = serde_json::json!({ + "chainId": 1, + "from": "0xd8dA6BF26964aF9D7eEd9e03E53415D37aA96045", + "to": "0xDEf1CA1fb7FBcDC777520aa7f396b4E015F497aB", + "data": "0x70a08231000000000000000000000000d8da6bf26964af9d7eed9e03e53415d37aa96045", + "gasLimit": 5000000, + "stateOverrides": { + "0xDEf1CA1fb7FBcDC777520aa7f396b4E015F497aB": { + "stateDiff": { + "0xfca351f4d96129454cfc8ef7930b638ac71fea35eb69ee3b8d959496beb04a33": + "123456789012345678901234567890" + } + } + } + }); + + let res = warp::test::request() + .method("POST") + .path("/simulate") + .json(&json) + .reply(&filter) + .await; + + assert_eq!(res.status(), 200); + + let body: SimulationResponse = serde_json::from_slice(res.body()).unwrap(); + let result = U256::from_big_endian(&body.return_data); + + assert_eq!(result.as_u128(), 123456789012345678901234567890); +} + #[tokio::test(flavor = "multi_thread")] async fn post_simulate_bundle_single_zerox_swap() { let filter = filter();