Skip to content

Commit

Permalink
Add assumeNoRevert with reverters param
Browse files Browse the repository at this point in the history
  • Loading branch information
grandizzy committed Aug 30, 2024
1 parent 9842b57 commit 8a6a659
Show file tree
Hide file tree
Showing 6 changed files with 83 additions and 11 deletions.
22 changes: 21 additions & 1 deletion crates/cheatcodes/assets/cheatcodes.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 5 additions & 0 deletions crates/cheatcodes/spec/src/vm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -682,6 +682,11 @@ interface Vm {
#[cheatcode(group = Testing, safety = Safe)]
function assumeNoRevert() external pure;

/// Discard this run's fuzz inputs and generate new ones if call to one of the specified
/// addresses reverts.
#[cheatcode(group = Testing, safety = Safe)]
function assumeNoRevert(address[] calldata reverters) external pure;

/// Writes a breakpoint to jump to in the debugger.
#[cheatcode(group = Testing, safety = Safe)]
function breakpoint(string calldata char) external;
Expand Down
13 changes: 9 additions & 4 deletions crates/cheatcodes/src/inspector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1116,10 +1116,15 @@ impl<DB: DatabaseExt> Inspector<DB> for Cheatcodes {
// Handle assume no revert.
if let Some(assume_revert) = &self.assume_no_revert {
// Discard fuzz input of the reverted call if target address is not the cheatcode
// address and if we didn't exceed the depth where cheatcode was added.
if outcome.result.is_revert() &&
call.target_address != CHEATCODE_ADDRESS &&
ecx.journaled_state.depth() >= assume_revert.depth
// address and marked as a reverter address (if set) and if we didn't exceed the depth
// where cheatcode was added.
if !cheatcode_call &&
outcome.result.is_revert() &&
ecx.journaled_state.depth() >= assume_revert.depth &&
assume_revert
.reverters
.as_ref()
.map_or(true, |reverters| reverters.contains(&call.target_address))
{
outcome.result.output = Error::from(MAGIC_ASSUME).abi_encode().into();
return outcome;
Expand Down
22 changes: 18 additions & 4 deletions crates/cheatcodes/src/test/assume.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
use crate::{Cheatcode, Cheatcodes, CheatsCtxt, Error, Result};
use alloy_primitives::Address;
use foundry_evm_core::{backend::DatabaseExt, constants::MAGIC_ASSUME};
use spec::Vm::{assumeCall, assumeNoRevertCall};
use spec::Vm::{assumeCall, assumeNoRevert_0Call, assumeNoRevert_1Call};
use std::fmt::Debug;

#[derive(Clone, Debug)]
pub struct AssumeNoRevert {
/// The depth at which the cheatcode was added.
/// The call depth at which the cheatcode was added.
pub depth: u64,
/// Discard fuzz run inputs and generate new ones if one of these addresses reverted.
pub reverters: Option<Vec<Address>>,
}

impl Cheatcode for assumeCall {
Expand All @@ -20,10 +23,21 @@ impl Cheatcode for assumeCall {
}
}

impl Cheatcode for assumeNoRevertCall {
impl Cheatcode for assumeNoRevert_0Call {
fn apply_stateful<DB: DatabaseExt>(&self, ccx: &mut CheatsCtxt<DB>) -> Result {
ccx.state.assume_no_revert =
Some(AssumeNoRevert { depth: ccx.ecx.journaled_state.depth() });
Some(AssumeNoRevert { depth: ccx.ecx.journaled_state.depth(), reverters: None });
Ok(Default::default())
}
}

impl Cheatcode for assumeNoRevert_1Call {
fn apply_stateful<DB: DatabaseExt>(&self, ccx: &mut CheatsCtxt<DB>) -> Result {
let Self { reverters } = self;
ccx.state.assume_no_revert = Some(AssumeNoRevert {
depth: ccx.ecx.journaled_state.depth(),
reverters: Some(reverters.to_vec()),
});
Ok(Default::default())
}
}
31 changes: 29 additions & 2 deletions crates/forge/tests/cli/test_cmd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1854,10 +1854,14 @@ contract CounterWithRevert {
}
return 99999999;
}
function dummy() public pure {}
}
contract CounterRevertTest is DSTest {
Vm vm = Vm(HEVM_ADDRESS);
address[] reverters;
function test_count_with_revert_pass(uint256 a) public {
vm.assumeNoRevert();
CounterWithRevert counter = new CounterWithRevert();
Expand All @@ -1870,15 +1874,38 @@ contract CounterRevertTest is DSTest {
a = counter.count(a);
assertTrue(a != 99999999, "wrong value");
}
function test_count_with_revert_and_reverters_fail(uint256 a) public {
CounterWithRevert aCounter = new CounterWithRevert();
CounterWithRevert bCounter = new CounterWithRevert();
reverters.push(address(bCounter));
vm.assumeNoRevert(reverters);
// this will revert since aCounter is not in list of reverters
a = aCounter.count(a);
bCounter.dummy();
assertEq(a, 99999999);
}
function test_count_with_revert_and_reverters_pass(uint256 a) public {
CounterWithRevert aCounter = new CounterWithRevert();
CounterWithRevert bCounter = new CounterWithRevert();
reverters.push(address(aCounter));
reverters.push(address(bCounter));
vm.assumeNoRevert(reverters);
a = aCounter.count(a);
bCounter.dummy();
assertEq(a, 99999999);
}
}
"#,
)
.unwrap();

cmd.args(["test"]).with_no_redact().assert_failure().stdout_eq(str![[r#"
...
[FAIL. Reason: assertion failed; counterexample: [..]] test_count_with_revert_fail(uint256) (runs: 0, [..])
[PASS] test_count_with_revert_pass(uint256) (runs: 256, [..])
[FAIL. Reason: RandomError(); counterexample: [..]] test_count_with_revert_and_reverters_fail(uint256) [..]
[PASS] test_count_with_revert_and_reverters_pass(uint256) [..]
[FAIL. Reason: assertion failed; counterexample: [..]] test_count_with_revert_fail(uint256) [..]
[PASS] test_count_with_revert_pass(uint256) [..]
...
"#]]);
});
1 change: 1 addition & 0 deletions testdata/cheats/Vm.sol

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 8a6a659

Please sign in to comment.