Skip to content

Commit

Permalink
feat(cheatcodes): expectRevert with specific reverter address
Browse files Browse the repository at this point in the history
  • Loading branch information
grandizzy committed Sep 2, 2024
1 parent d75318c commit c672bda
Show file tree
Hide file tree
Showing 7 changed files with 300 additions and 19 deletions.
82 changes: 81 additions & 1 deletion crates/cheatcodes/assets/cheatcodes.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

16 changes: 16 additions & 0 deletions crates/cheatcodes/spec/src/vm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -799,10 +799,26 @@ interface Vm {
#[cheatcode(group = Testing, safety = Unsafe)]
function expectRevert(bytes calldata revertData) external;

/// Expects an error with any revert data on next call to reverter address.
#[cheatcode(group = Testing, safety = Unsafe)]
function expectRevert(address reverter) external;

/// Expects an error from reverter address on next call, with any revert data.
#[cheatcode(group = Testing, safety = Unsafe)]
function expectRevert(bytes4 revertData, address reverter) external;

/// Expects an error from reverter address on next call, that exactly matches the revert data.
#[cheatcode(group = Testing, safety = Unsafe)]
function expectRevert(bytes calldata revertData, address reverter) external;

/// Expects an error on next call that starts with the revert data.
#[cheatcode(group = Testing, safety = Unsafe)]
function expectPartialRevert(bytes4 revertData) external;

/// Expects an error on next call to reverter address, that starts with the revert data.
#[cheatcode(group = Testing, safety = Unsafe)]
function expectPartialRevert(bytes4 revertData, address reverter) external;

/// Expects an error on next cheatcode call with any revert data.
#[cheatcode(group = Testing, safety = Unsafe, status = Internal)]
function _expectCheatcodeRevert() external;
Expand Down
8 changes: 4 additions & 4 deletions crates/cheatcodes/src/inspector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -641,10 +641,10 @@ impl Cheatcodes {
return match expect::handle_expect_revert(
false,
true,
expected_revert.reason.as_deref(),
expected_revert.partial_match,
&expected_revert,
outcome.result.result,
outcome.result.output.clone(),
None,
&self.config.available_artifacts,
) {
Ok((address, retdata)) => {
Expand Down Expand Up @@ -1123,10 +1123,10 @@ impl<DB: DatabaseExt> Inspector<DB> for Cheatcodes {
return match expect::handle_expect_revert(
cheatcode_call,
false,
expected_revert.reason.as_deref(),
expected_revert.partial_match,
&expected_revert,
outcome.result.result,
outcome.result.output.clone(),
Some(call.target_address),
&self.config.available_artifacts,
) {
Err(error) => {
Expand Down
119 changes: 105 additions & 14 deletions crates/cheatcodes/src/test/expect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ pub struct ExpectedRevert {
pub kind: ExpectedRevertKind,
/// If true then only the first 4 bytes of expected data returned by the revert are checked.
pub partial_match: bool,
/// Contract expected to revert next call.
pub reverter: Option<Address>,
}

#[derive(Clone, Debug)]
Expand Down Expand Up @@ -288,7 +290,7 @@ impl Cheatcode for expectEmitAnonymous_3Call {
impl Cheatcode for expectRevert_0Call {
fn apply_stateful<DB: DatabaseExt>(&self, ccx: &mut CheatsCtxt<DB>) -> Result {
let Self {} = self;
expect_revert(ccx.state, None, ccx.ecx.journaled_state.depth(), false, false)
expect_revert(ccx.state, None, ccx.ecx.journaled_state.depth(), false, false, None)
}
}

Expand All @@ -301,18 +303,68 @@ impl Cheatcode for expectRevert_1Call {
ccx.ecx.journaled_state.depth(),
false,
false,
None,
)
}
}

impl Cheatcode for expectRevert_2Call {
fn apply_stateful<DB: DatabaseExt>(&self, ccx: &mut CheatsCtxt<DB>) -> Result {
let Self { revertData } = self;
expect_revert(ccx.state, Some(revertData), ccx.ecx.journaled_state.depth(), false, false)
expect_revert(
ccx.state,
Some(revertData),
ccx.ecx.journaled_state.depth(),
false,
false,
None,
)
}
}

impl Cheatcode for expectPartialRevertCall {
impl Cheatcode for expectRevert_3Call {
fn apply_stateful<DB: DatabaseExt>(&self, ccx: &mut CheatsCtxt<DB>) -> Result {
let Self { reverter } = self;
expect_revert(
ccx.state,
None,
ccx.ecx.journaled_state.depth(),
false,
false,
Some(*reverter),
)
}
}

impl Cheatcode for expectRevert_4Call {
fn apply_stateful<DB: DatabaseExt>(&self, ccx: &mut CheatsCtxt<DB>) -> Result {
let Self { revertData, reverter } = self;
expect_revert(
ccx.state,
Some(revertData.as_ref()),
ccx.ecx.journaled_state.depth(),
false,
false,
Some(*reverter),
)
}
}

impl Cheatcode for expectRevert_5Call {
fn apply_stateful<DB: DatabaseExt>(&self, ccx: &mut CheatsCtxt<DB>) -> Result {
let Self { revertData, reverter } = self;
expect_revert(
ccx.state,
Some(revertData),
ccx.ecx.journaled_state.depth(),
false,
false,
Some(*reverter),
)
}
}

impl Cheatcode for expectPartialRevert_0Call {
fn apply_stateful<DB: DatabaseExt>(&self, ccx: &mut CheatsCtxt<DB>) -> Result {
let Self { revertData } = self;
expect_revert(
Expand All @@ -321,13 +373,28 @@ impl Cheatcode for expectPartialRevertCall {
ccx.ecx.journaled_state.depth(),
false,
true,
None,
)
}
}

impl Cheatcode for expectPartialRevert_1Call {
fn apply_stateful<DB: DatabaseExt>(&self, ccx: &mut CheatsCtxt<DB>) -> Result {
let Self { revertData, reverter } = self;
expect_revert(
ccx.state,
Some(revertData.as_ref()),
ccx.ecx.journaled_state.depth(),
false,
true,
Some(*reverter),
)
}
}

impl Cheatcode for _expectCheatcodeRevert_0Call {
fn apply_stateful<DB: DatabaseExt>(&self, ccx: &mut CheatsCtxt<DB>) -> Result {
expect_revert(ccx.state, None, ccx.ecx.journaled_state.depth(), true, false)
expect_revert(ccx.state, None, ccx.ecx.journaled_state.depth(), true, false, None)
}
}

Expand All @@ -340,14 +407,22 @@ impl Cheatcode for _expectCheatcodeRevert_1Call {
ccx.ecx.journaled_state.depth(),
true,
false,
None,
)
}
}

impl Cheatcode for _expectCheatcodeRevert_2Call {
fn apply_stateful<DB: DatabaseExt>(&self, ccx: &mut CheatsCtxt<DB>) -> Result {
let Self { revertData } = self;
expect_revert(ccx.state, Some(revertData), ccx.ecx.journaled_state.depth(), true, false)
expect_revert(
ccx.state,
Some(revertData),
ccx.ecx.journaled_state.depth(),
true,
false,
None,
)
}
}

Expand Down Expand Up @@ -568,6 +643,7 @@ fn expect_revert(
depth: u64,
cheatcode: bool,
partial_match: bool,
reverter: Option<Address>,
) -> Result {
ensure!(
state.expected_revert.is_none(),
Expand All @@ -582,17 +658,18 @@ fn expect_revert(
ExpectedRevertKind::Default
},
partial_match,
reverter,
});
Ok(Default::default())
}

pub(crate) fn handle_expect_revert(
is_cheatcode: bool,
is_create: bool,
expected_revert: Option<&[u8]>,
partial_match: bool,
expected_revert: &ExpectedRevert,
status: InstructionResult,
retdata: Bytes,
target_address: Option<Address>,
known_contracts: &Option<ContractsByArtifact>,
) -> Result<(Option<Address>, Bytes)> {
let success_return = || {
Expand All @@ -605,19 +682,33 @@ pub(crate) fn handle_expect_revert(

ensure!(!matches!(status, return_ok!()), "next call did not revert as expected");

// If expected reverter address is set then check it matches the actual reverter.
if let (Some(expected_reverter), Some(actual_reverter)) =
(expected_revert.reverter, target_address)
{
if expected_reverter != actual_reverter {
return Err(fmt_err!(
"Reverter != expected reverter: {} != {}",
actual_reverter,
expected_reverter
));
}
}

let expected_reason = expected_revert.reason.as_deref();
// If None, accept any revert.
let Some(expected_revert) = expected_revert else {
let Some(expected_reason) = expected_reason else {
return Ok(success_return());
};

if !expected_revert.is_empty() && retdata.is_empty() {
if !expected_reason.is_empty() && retdata.is_empty() {
bail!("call reverted as expected, but without data");
}

let mut actual_revert: Vec<u8> = retdata.into();

// Compare only the first 4 bytes if partial match.
if partial_match && actual_revert.get(..4) == expected_revert.get(..4) {
if expected_revert.partial_match && actual_revert.get(..4) == expected_reason.get(..4) {
return Ok(success_return())
}

Expand All @@ -631,16 +722,16 @@ pub(crate) fn handle_expect_revert(
}
}

if actual_revert == expected_revert ||
(is_cheatcode && memchr::memmem::find(&actual_revert, expected_revert).is_some())
if actual_revert == expected_reason ||
(is_cheatcode && memchr::memmem::find(&actual_revert, expected_reason).is_some())
{
Ok(success_return())
} else {
let (actual, expected) = if let Some(contracts) = known_contracts {
let decoder = RevertDecoder::new().with_abis(contracts.iter().map(|(_, c)| &c.abi));
(
&decoder.decode(actual_revert.as_slice(), Some(status)),
&decoder.decode(expected_revert, Some(status)),
&decoder.decode(expected_reason, Some(status)),
)
} else {
let stringify = |data: &[u8]| {
Expand All @@ -652,7 +743,7 @@ pub(crate) fn handle_expect_revert(
}
hex::encode_prefixed(data)
};
(&stringify(&actual_revert), &stringify(expected_revert))
(&stringify(&actual_revert), &stringify(expected_reason))
};
Err(fmt_err!("Error != expected error: {} != {}", actual, expected,))
}
Expand Down
Loading

0 comments on commit c672bda

Please sign in to comment.