Skip to content

Commit

Permalink
feat(cheatcodes): specify reverter address in expectReverts
Browse files Browse the repository at this point in the history
  • Loading branch information
grandizzy committed Aug 29, 2024
1 parent 0d83028 commit 8479102
Show file tree
Hide file tree
Showing 6 changed files with 337 additions and 49 deletions.
20 changes: 20 additions & 0 deletions crates/cheatcodes/assets/cheatcodes.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 4 additions & 0 deletions crates/cheatcodes/spec/src/vm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -799,6 +799,10 @@ interface Vm {
#[cheatcode(group = Testing, safety = Unsafe)]
function expectRevert(bytes calldata revertData) external;

/// Expects an error with any revert data on next call to reverter address.
#[cheatcode(group = Testing, safety = Unsafe)]
function expectRevert(address reverter) external;

/// Expects an error on next call that starts with the revert data.
#[cheatcode(group = Testing, safety = Unsafe)]
function expectPartialRevert(bytes4 revertData) external;
Expand Down
71 changes: 28 additions & 43 deletions crates/cheatcodes/src/inspector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -601,7 +601,7 @@ impl Cheatcodes {
fn create_end_common<DB>(
&mut self,
ecx: &mut EvmContext<DB>,
mut outcome: CreateOutcome,
outcome: CreateOutcome,
) -> CreateOutcome
where
DB: DatabaseExt,
Expand Down Expand Up @@ -637,28 +637,11 @@ impl Cheatcodes {
if ecx.journaled_state.depth() <= expected_revert.depth &&
matches!(expected_revert.kind, ExpectedRevertKind::Default)
{
let expected_revert = std::mem::take(&mut self.expected_revert).unwrap();
return match expect::handle_expect_revert(
false,
true,
expected_revert.reason.as_deref(),
expected_revert.partial_match,
outcome.result.result,
outcome.result.output.clone(),
return expect::process_create_expect_revert(
std::mem::take(&mut self.expected_revert).unwrap(),
&self.config.available_artifacts,
) {
Ok((address, retdata)) => {
outcome.result.result = InstructionResult::Return;
outcome.result.output = retdata;
outcome.address = address;
outcome
}
Err(err) => {
outcome.result.result = InstructionResult::Revert;
outcome.result.output = err.abi_encode().into();
outcome
}
};
outcome,
)
}
}

Expand Down Expand Up @@ -1107,7 +1090,25 @@ impl<DB: DatabaseExt> Inspector<DB> for Cheatcodes {
}

// Handle expected reverts
if let Some(expected_revert) = &self.expected_revert {
if let Some(expected_revert) = &mut self.expected_revert {
// Record address that reverted the call (to compare later with expected reverter).
if expected_revert.reverted_by.is_none() && outcome.result.result.is_revert() {
expected_revert.reverted_by = Some(call.target_address);
}

// If call target address is the expected reverter address (if set), then process
// expected revert.
if let Some(reverter) = expected_revert.reverter {
if call.target_address == reverter {
return expect::process_call_expect_revert(
std::mem::take(&mut self.expected_revert).unwrap(),
cheatcode_call,
&self.config.available_artifacts,
outcome,
);
}
}

if ecx.journaled_state.depth() <= expected_revert.depth {
let needs_processing = match expected_revert.kind {
ExpectedRevertKind::Default => !cheatcode_call,
Expand All @@ -1119,28 +1120,12 @@ impl<DB: DatabaseExt> Inspector<DB> for Cheatcodes {
};

if needs_processing {
let expected_revert = std::mem::take(&mut self.expected_revert).unwrap();
return match expect::handle_expect_revert(
return expect::process_call_expect_revert(
std::mem::take(&mut self.expected_revert).unwrap(),
cheatcode_call,
false,
expected_revert.reason.as_deref(),
expected_revert.partial_match,
outcome.result.result,
outcome.result.output.clone(),
&self.config.available_artifacts,
) {
Err(error) => {
trace!(expected=?expected_revert, ?error, status=?outcome.result.result, "Expected revert mismatch");
outcome.result.result = InstructionResult::Revert;
outcome.result.output = error.abi_encode().into();
outcome
}
Ok((_, retdata)) => {
outcome.result.result = InstructionResult::Return;
outcome.result.output = retdata;
outcome
}
};
outcome,
);
}

// Flip `pending_processing` flag for cheatcode revert expectations, marking that
Expand Down
147 changes: 141 additions & 6 deletions crates/cheatcodes/src/test/expect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@ use alloy_sol_types::{SolError, SolValue};
use foundry_common::ContractsByArtifact;
use foundry_evm_core::decode::RevertDecoder;
use revm::interpreter::{
return_ok, InstructionResult, Interpreter, InterpreterAction, InterpreterResult,
return_ok, CallOutcome, CreateOutcome, InstructionResult, Interpreter, InterpreterAction,
InterpreterResult,
};
use spec::Vm;
use std::collections::{hash_map::Entry, HashMap};
Expand Down Expand Up @@ -80,6 +81,10 @@ pub struct ExpectedRevert {
pub kind: ExpectedRevertKind,
/// If true then only the first 4 bytes of expected data returned by the revert are checked.
pub partial_match: bool,
/// Contract expected to revert current call.
pub reverter: Option<Address>,
/// Contract that reverted current call.
pub reverted_by: Option<Address>,
}

#[derive(Clone, Debug)]
Expand Down Expand Up @@ -288,7 +293,7 @@ impl Cheatcode for expectEmitAnonymous_3Call {
impl Cheatcode for expectRevert_0Call {
fn apply_stateful<DB: DatabaseExt>(&self, ccx: &mut CheatsCtxt<DB>) -> Result {
let Self {} = self;
expect_revert(ccx.state, None, ccx.ecx.journaled_state.depth(), false, false)
expect_revert(ccx.state, None, ccx.ecx.journaled_state.depth(), false, false, None)
}
}

Expand All @@ -301,14 +306,36 @@ impl Cheatcode for expectRevert_1Call {
ccx.ecx.journaled_state.depth(),
false,
false,
None,
)
}
}

impl Cheatcode for expectRevert_2Call {
fn apply_stateful<DB: DatabaseExt>(&self, ccx: &mut CheatsCtxt<DB>) -> Result {
let Self { revertData } = self;
expect_revert(ccx.state, Some(revertData), ccx.ecx.journaled_state.depth(), false, false)
expect_revert(
ccx.state,
Some(revertData),
ccx.ecx.journaled_state.depth(),
false,
false,
None,
)
}
}

impl Cheatcode for expectRevert_3Call {
fn apply_stateful<DB: DatabaseExt>(&self, ccx: &mut CheatsCtxt<DB>) -> Result {
let Self { reverter } = self;
expect_revert(
ccx.state,
None,
ccx.ecx.journaled_state.depth(),
false,
false,
Some(*reverter),
)
}
}

Expand All @@ -321,13 +348,14 @@ impl Cheatcode for expectPartialRevertCall {
ccx.ecx.journaled_state.depth(),
false,
true,
None,
)
}
}

impl Cheatcode for _expectCheatcodeRevert_0Call {
fn apply_stateful<DB: DatabaseExt>(&self, ccx: &mut CheatsCtxt<DB>) -> Result {
expect_revert(ccx.state, None, ccx.ecx.journaled_state.depth(), true, false)
expect_revert(ccx.state, None, ccx.ecx.journaled_state.depth(), true, false, None)
}
}

Expand All @@ -340,14 +368,22 @@ impl Cheatcode for _expectCheatcodeRevert_1Call {
ccx.ecx.journaled_state.depth(),
true,
false,
None,
)
}
}

impl Cheatcode for _expectCheatcodeRevert_2Call {
fn apply_stateful<DB: DatabaseExt>(&self, ccx: &mut CheatsCtxt<DB>) -> Result {
let Self { revertData } = self;
expect_revert(ccx.state, Some(revertData), ccx.ecx.journaled_state.depth(), true, false)
expect_revert(
ccx.state,
Some(revertData),
ccx.ecx.journaled_state.depth(),
true,
false,
None,
)
}
}

Expand Down Expand Up @@ -568,6 +604,7 @@ fn expect_revert(
depth: u64,
cheatcode: bool,
partial_match: bool,
reverter: Option<Address>,
) -> Result {
ensure!(
state.expected_revert.is_none(),
Expand All @@ -582,11 +619,109 @@ fn expect_revert(
ExpectedRevertKind::Default
},
partial_match,
reverter,
reverted_by: None,
});
Ok(Default::default())
}

pub(crate) fn handle_expect_revert(
/// Process `expectRevert` cheatcode for create outcome.
/// Similar with processing regular call but sets the outcome address if call succeeds.
pub(crate) fn process_create_expect_revert(
expected_revert: ExpectedRevert,
known_contracts: &Option<ContractsByArtifact>,
mut outcome: CreateOutcome,
) -> CreateOutcome {
let (result, output, address) = process_expect_revert(
expected_revert,
false,
true,
outcome.result.result,
outcome.result.output,
known_contracts,
);
outcome.result.result = result;
outcome.result.output = output;
if result == InstructionResult::Return {
outcome.address = address;
}
outcome
}

/// Process `expectRevert` cheatcode for call outcome.
pub(crate) fn process_call_expect_revert(
expected_revert: ExpectedRevert,
cheatcode_call: bool,
known_contracts: &Option<ContractsByArtifact>,
mut outcome: CallOutcome,
) -> CallOutcome {
let (result, output, _) = process_expect_revert(
expected_revert,
cheatcode_call,
false,
outcome.result.result,
outcome.result.output,
known_contracts,
);
outcome.result.result = result;
outcome.result.output = output;
outcome
}

/// Process `expectRevert`s by checking if call reverted with expected data and expected reverter
/// (if specified). Returns new result instruction, result output and dummy create address.
/// If expected revert mismatch then returns `InstructionResult::Revert` and failure reason
/// (`expectRevert` cheatcode failed).
/// - If expected revert matches actual revert then it checks for expected reverter (if specified)
/// to be the same as address that reverted the call.
/// If expected reverter is not the actual reverter then returns `InstructionResult::Revert` and
/// failure reason (`expectRevert` cheatcode failed).
/// Otherwise returns `InstructionResult::Return` (`expectRevert` cheatcode succeeded).
fn process_expect_revert(
expected_revert: ExpectedRevert,
cheatcode_call: bool,
is_create: bool,
result: InstructionResult,
output: Bytes,
known_contracts: &Option<ContractsByArtifact>,
) -> (InstructionResult, Bytes, Option<Address>) {
return match handle_expect_revert(
cheatcode_call,
is_create,
expected_revert.reason.as_deref(),
expected_revert.partial_match,
result,
output,
known_contracts,
) {
Err(error) => {
trace!(expected=?expected_revert, ?error, status=?result, "Expected revert mismatch");
(InstructionResult::Revert, error.abi_encode().into(), None)
}
Ok((address, retdata)) => {
match (expected_revert.reverter, expected_revert.reverted_by) {
(Some(reverter), Some(reverted_by)) => {
// If both expected reverter and call reverter are set, then make sure they're
// the same.
if reverter != reverted_by {
(
InstructionResult::Revert,
format!("Reverter != expected reverter: {reverted_by} != {reverter}")
.abi_encode()
.into(),
None,
)
} else {
(InstructionResult::Return, retdata, address)
}
}
_ => (InstructionResult::Return, retdata, address),
}
}
};
}

fn handle_expect_revert(
is_cheatcode: bool,
is_create: bool,
expected_revert: Option<&[u8]>,
Expand Down
Loading

0 comments on commit 8479102

Please sign in to comment.