diff --git a/crates/cheatcodes/assets/cheatcodes.json b/crates/cheatcodes/assets/cheatcodes.json index 574d310ca19f..fd2441efc43e 100644 --- a/crates/cheatcodes/assets/cheatcodes.json +++ b/crates/cheatcodes/assets/cheatcodes.json @@ -4713,7 +4713,7 @@ }, { "func": { - "id": "expectPartialRevert", + "id": "expectPartialRevert_0", "description": "Expects an error on next call that starts with the revert data.", "declaration": "function expectPartialRevert(bytes4 revertData) external;", "visibility": "external", @@ -4731,6 +4731,26 @@ "status": "stable", "safety": "unsafe" }, + { + "func": { + "id": "expectPartialRevert_1", + "description": "Expects an error on next call to reverter address, that starts with the revert data.", + "declaration": "function expectPartialRevert(bytes4 revertData, address reverter) external;", + "visibility": "external", + "mutability": "", + "signature": "expectPartialRevert(bytes4,address)", + "selector": "0x51aa008a", + "selectorBytes": [ + 81, + 170, + 0, + 138 + ] + }, + "group": "testing", + "status": "stable", + "safety": "unsafe" + }, { "func": { "id": "expectRevert_0", @@ -4791,6 +4811,66 @@ "status": "stable", "safety": "unsafe" }, + { + "func": { + "id": "expectRevert_3", + "description": "Expects an error with any revert data on next call to reverter address.", + "declaration": "function expectRevert(address reverter) external;", + "visibility": "external", + "mutability": "", + "signature": "expectRevert(address)", + "selector": "0xd814f38a", + "selectorBytes": [ + 216, + 20, + 243, + 138 + ] + }, + "group": "testing", + "status": "stable", + "safety": "unsafe" + }, + { + "func": { + "id": "expectRevert_4", + "description": "Expects an error from reverter address on next call, with any revert data.", + "declaration": "function expectRevert(bytes4 revertData, address reverter) external;", + "visibility": "external", + "mutability": "", + "signature": "expectRevert(bytes4,address)", + "selector": "0x260bc5de", + "selectorBytes": [ + 38, + 11, + 197, + 222 + ] + }, + "group": "testing", + "status": "stable", + "safety": "unsafe" + }, + { + "func": { + "id": "expectRevert_5", + "description": "Expects an error from reverter address on next call, that exactly matches the revert data.", + "declaration": "function expectRevert(bytes calldata revertData, address reverter) external;", + "visibility": "external", + "mutability": "", + "signature": "expectRevert(bytes,address)", + "selector": "0x61ebcf12", + "selectorBytes": [ + 97, + 235, + 207, + 18 + ] + }, + "group": "testing", + "status": "stable", + "safety": "unsafe" + }, { "func": { "id": "expectSafeMemory", diff --git a/crates/cheatcodes/spec/src/vm.rs b/crates/cheatcodes/spec/src/vm.rs index 575ccfa84a5a..0fc51b8d047a 100644 --- a/crates/cheatcodes/spec/src/vm.rs +++ b/crates/cheatcodes/spec/src/vm.rs @@ -799,10 +799,26 @@ interface Vm { #[cheatcode(group = Testing, safety = Unsafe)] function expectRevert(bytes calldata revertData) external; + /// Expects an error with any revert data on next call to reverter address. + #[cheatcode(group = Testing, safety = Unsafe)] + function expectRevert(address reverter) external; + + /// Expects an error from reverter address on next call, with any revert data. + #[cheatcode(group = Testing, safety = Unsafe)] + function expectRevert(bytes4 revertData, address reverter) external; + + /// Expects an error from reverter address on next call, that exactly matches the revert data. + #[cheatcode(group = Testing, safety = Unsafe)] + function expectRevert(bytes calldata revertData, address reverter) external; + /// Expects an error on next call that starts with the revert data. #[cheatcode(group = Testing, safety = Unsafe)] function expectPartialRevert(bytes4 revertData) external; + /// Expects an error on next call to reverter address, that starts with the revert data. + #[cheatcode(group = Testing, safety = Unsafe)] + function expectPartialRevert(bytes4 revertData, address reverter) external; + /// Expects an error on next cheatcode call with any revert data. #[cheatcode(group = Testing, safety = Unsafe, status = Internal)] function _expectCheatcodeRevert() external; diff --git a/crates/cheatcodes/src/inspector.rs b/crates/cheatcodes/src/inspector.rs index 1ff2a6e999dc..18b42061834c 100644 --- a/crates/cheatcodes/src/inspector.rs +++ b/crates/cheatcodes/src/inspector.rs @@ -641,10 +641,10 @@ impl Cheatcodes { return match expect::handle_expect_revert( false, true, - expected_revert.reason.as_deref(), - expected_revert.partial_match, + &expected_revert, outcome.result.result, outcome.result.output.clone(), + None, &self.config.available_artifacts, ) { Ok((address, retdata)) => { @@ -1123,10 +1123,10 @@ impl Inspector for Cheatcodes { return match expect::handle_expect_revert( cheatcode_call, false, - expected_revert.reason.as_deref(), - expected_revert.partial_match, + &expected_revert, outcome.result.result, outcome.result.output.clone(), + Some(call.target_address), &self.config.available_artifacts, ) { Err(error) => { diff --git a/crates/cheatcodes/src/test/expect.rs b/crates/cheatcodes/src/test/expect.rs index 83899041515d..e467dd314b28 100644 --- a/crates/cheatcodes/src/test/expect.rs +++ b/crates/cheatcodes/src/test/expect.rs @@ -80,6 +80,8 @@ pub struct ExpectedRevert { pub kind: ExpectedRevertKind, /// If true then only the first 4 bytes of expected data returned by the revert are checked. pub partial_match: bool, + /// Contract expected to revert next call. + pub reverter: Option
, } #[derive(Clone, Debug)] @@ -288,7 +290,7 @@ impl Cheatcode for expectEmitAnonymous_3Call { impl Cheatcode for expectRevert_0Call { fn apply_stateful(&self, ccx: &mut CheatsCtxt) -> Result { let Self {} = self; - expect_revert(ccx.state, None, ccx.ecx.journaled_state.depth(), false, false) + expect_revert(ccx.state, None, ccx.ecx.journaled_state.depth(), false, false, None) } } @@ -301,6 +303,7 @@ impl Cheatcode for expectRevert_1Call { ccx.ecx.journaled_state.depth(), false, false, + None, ) } } @@ -308,11 +311,60 @@ impl Cheatcode for expectRevert_1Call { impl Cheatcode for expectRevert_2Call { fn apply_stateful(&self, ccx: &mut CheatsCtxt) -> Result { let Self { revertData } = self; - expect_revert(ccx.state, Some(revertData), ccx.ecx.journaled_state.depth(), false, false) + expect_revert( + ccx.state, + Some(revertData), + ccx.ecx.journaled_state.depth(), + false, + false, + None, + ) } } -impl Cheatcode for expectPartialRevertCall { +impl Cheatcode for expectRevert_3Call { + fn apply_stateful(&self, ccx: &mut CheatsCtxt) -> Result { + let Self { reverter } = self; + expect_revert( + ccx.state, + None, + ccx.ecx.journaled_state.depth(), + false, + false, + Some(*reverter), + ) + } +} + +impl Cheatcode for expectRevert_4Call { + fn apply_stateful(&self, ccx: &mut CheatsCtxt) -> Result { + let Self { revertData, reverter } = self; + expect_revert( + ccx.state, + Some(revertData.as_ref()), + ccx.ecx.journaled_state.depth(), + false, + false, + Some(*reverter), + ) + } +} + +impl Cheatcode for expectRevert_5Call { + fn apply_stateful(&self, ccx: &mut CheatsCtxt) -> Result { + let Self { revertData, reverter } = self; + expect_revert( + ccx.state, + Some(revertData), + ccx.ecx.journaled_state.depth(), + false, + false, + Some(*reverter), + ) + } +} + +impl Cheatcode for expectPartialRevert_0Call { fn apply_stateful(&self, ccx: &mut CheatsCtxt) -> Result { let Self { revertData } = self; expect_revert( @@ -321,13 +373,28 @@ impl Cheatcode for expectPartialRevertCall { ccx.ecx.journaled_state.depth(), false, true, + None, + ) + } +} + +impl Cheatcode for expectPartialRevert_1Call { + fn apply_stateful(&self, ccx: &mut CheatsCtxt) -> Result { + let Self { revertData, reverter } = self; + expect_revert( + ccx.state, + Some(revertData.as_ref()), + ccx.ecx.journaled_state.depth(), + false, + true, + Some(*reverter), ) } } impl Cheatcode for _expectCheatcodeRevert_0Call { fn apply_stateful(&self, ccx: &mut CheatsCtxt) -> Result { - expect_revert(ccx.state, None, ccx.ecx.journaled_state.depth(), true, false) + expect_revert(ccx.state, None, ccx.ecx.journaled_state.depth(), true, false, None) } } @@ -340,6 +407,7 @@ impl Cheatcode for _expectCheatcodeRevert_1Call { ccx.ecx.journaled_state.depth(), true, false, + None, ) } } @@ -347,7 +415,14 @@ impl Cheatcode for _expectCheatcodeRevert_1Call { impl Cheatcode for _expectCheatcodeRevert_2Call { fn apply_stateful(&self, ccx: &mut CheatsCtxt) -> Result { let Self { revertData } = self; - expect_revert(ccx.state, Some(revertData), ccx.ecx.journaled_state.depth(), true, false) + expect_revert( + ccx.state, + Some(revertData), + ccx.ecx.journaled_state.depth(), + true, + false, + None, + ) } } @@ -568,6 +643,7 @@ fn expect_revert( depth: u64, cheatcode: bool, partial_match: bool, + reverter: Option
, ) -> Result { ensure!( state.expected_revert.is_none(), @@ -582,6 +658,7 @@ fn expect_revert( ExpectedRevertKind::Default }, partial_match, + reverter, }); Ok(Default::default()) } @@ -589,10 +666,10 @@ fn expect_revert( pub(crate) fn handle_expect_revert( is_cheatcode: bool, is_create: bool, - expected_revert: Option<&[u8]>, - partial_match: bool, + expected_revert: &ExpectedRevert, status: InstructionResult, retdata: Bytes, + target_address: Option
, known_contracts: &Option, ) -> Result<(Option
, Bytes)> { let success_return = || { @@ -605,19 +682,33 @@ pub(crate) fn handle_expect_revert( ensure!(!matches!(status, return_ok!()), "next call did not revert as expected"); + // If expected reverter address is set then check it matches the actual reverter. + if let (Some(expected_reverter), Some(actual_reverter)) = + (expected_revert.reverter, target_address) + { + if expected_reverter != actual_reverter { + return Err(fmt_err!( + "Reverter != expected reverter: {} != {}", + actual_reverter, + expected_reverter + )); + } + } + + let expected_reason = expected_revert.reason.as_deref(); // If None, accept any revert. - let Some(expected_revert) = expected_revert else { + let Some(expected_reason) = expected_reason else { return Ok(success_return()); }; - if !expected_revert.is_empty() && retdata.is_empty() { + if !expected_reason.is_empty() && retdata.is_empty() { bail!("call reverted as expected, but without data"); } let mut actual_revert: Vec = retdata.into(); // Compare only the first 4 bytes if partial match. - if partial_match && actual_revert.get(..4) == expected_revert.get(..4) { + if expected_revert.partial_match && actual_revert.get(..4) == expected_reason.get(..4) { return Ok(success_return()) } @@ -631,8 +722,8 @@ pub(crate) fn handle_expect_revert( } } - if actual_revert == expected_revert || - (is_cheatcode && memchr::memmem::find(&actual_revert, expected_revert).is_some()) + if actual_revert == expected_reason || + (is_cheatcode && memchr::memmem::find(&actual_revert, expected_reason).is_some()) { Ok(success_return()) } else { @@ -640,7 +731,7 @@ pub(crate) fn handle_expect_revert( let decoder = RevertDecoder::new().with_abis(contracts.iter().map(|(_, c)| &c.abi)); ( &decoder.decode(actual_revert.as_slice(), Some(status)), - &decoder.decode(expected_revert, Some(status)), + &decoder.decode(expected_reason, Some(status)), ) } else { let stringify = |data: &[u8]| { @@ -652,7 +743,7 @@ pub(crate) fn handle_expect_revert( } hex::encode_prefixed(data) }; - (&stringify(&actual_revert), &stringify(expected_revert)) + (&stringify(&actual_revert), &stringify(expected_reason)) }; Err(fmt_err!("Error != expected error: {} != {}", actual, expected,)) } diff --git a/crates/evm/core/src/decode.rs b/crates/evm/core/src/decode.rs index 29f448bcebd5..eacf9b8ace9c 100644 --- a/crates/evm/core/src/decode.rs +++ b/crates/evm/core/src/decode.rs @@ -144,11 +144,31 @@ impl RevertDecoder { let e = Vm::expectRevert_2Call::abi_decode_raw(data, false).ok()?; return self.maybe_decode(&e.revertData[..], status); } + // `expectRevert(bytes,address)` + Vm::expectRevert_5Call::SELECTOR => { + let e = Vm::expectRevert_5Call::abi_decode_raw(data, false).ok()?; + return self.maybe_decode(&e.revertData[..], status); + } // `expectRevert(bytes4)` Vm::expectRevert_1Call::SELECTOR => { let e = Vm::expectRevert_1Call::abi_decode_raw(data, false).ok()?; return self.maybe_decode(&e.revertData[..], status); } + // `expectRevert(bytes4,address)` + Vm::expectRevert_4Call::SELECTOR => { + let e = Vm::expectRevert_4Call::abi_decode_raw(data, false).ok()?; + return self.maybe_decode(&e.revertData[..], status); + } + // `expectPartialRevert(bytes4)` + Vm::expectPartialRevert_0Call::SELECTOR => { + let e = Vm::expectPartialRevert_0Call::abi_decode_raw(data, false).ok()?; + return self.maybe_decode(&e.revertData[..], status); + } + // `expectPartialRevert(bytes4,address)` + Vm::expectPartialRevert_1Call::SELECTOR => { + let e = Vm::expectPartialRevert_1Call::abi_decode_raw(data, false).ok()?; + return self.maybe_decode(&e.revertData[..], status); + } _ => {} } diff --git a/crates/forge/tests/cli/test_cmd.rs b/crates/forge/tests/cli/test_cmd.rs index 23ffa890d603..3017dd9b9f69 100644 --- a/crates/forge/tests/cli/test_cmd.rs +++ b/crates/forge/tests/cli/test_cmd.rs @@ -1834,3 +1834,73 @@ contract CounterTest is DSTest { ... "#]]); }); + +// Tests `expectRevert` with specific reverter address. +forgetest_init!(test_expect_revert_with_reverter, |prj, cmd| { + prj.wipe_contracts(); + prj.insert_ds_test(); + prj.insert_vm(); + prj.clear(); + + prj.add_source( + "Reverter.t.sol", + r#"pragma solidity 0.8.24; +import {Vm} from "./Vm.sol"; +import {DSTest} from "./test.sol"; +contract Reverter { + error CustomError(); + function withRevert() public pure { + revert CustomError(); + } + function withNoRevert() public pure {} +} +contract ReverterTest is DSTest { + Vm vm = Vm(HEVM_ADDRESS); + error CustomError(); + function test_match_reverters() public { + Reverter reverter = new Reverter(); + vm.expectRevert(address(reverter)); + reverter.withRevert(); + vm.expectRevert(CustomError.selector, address(reverter)); + reverter.withRevert(); + vm.expectPartialRevert(CustomError.selector, address(reverter)); + reverter.withRevert(); + } + function test_next_call_fail() public { + Reverter reverter = new Reverter(); + vm.expectRevert(address(reverter)); + reverter.withNoRevert(); + } + function test_wrong_reverter_fail_1() public { + Reverter reverter = new Reverter(); + Reverter bReverter = new Reverter(); + vm.expectRevert(address(bReverter)); + reverter.withRevert(); + } + function test_wrong_reverter_fail_2() public { + Reverter reverter = new Reverter(); + Reverter bReverter = new Reverter(); + vm.expectRevert(CustomError.selector, address(bReverter)); + reverter.withRevert(); + } + function test_wrong_reverter_fail_3() public { + Reverter reverter = new Reverter(); + Reverter bReverter = new Reverter(); + vm.expectPartialRevert(CustomError.selector, address(bReverter)); + reverter.withRevert(); + } +} + "#, + ) + .unwrap(); + + cmd.args(["test"]).assert_failure().stdout_eq(str![[r#" +... +[PASS] test_match_reverters() ([GAS]) +[FAIL. Reason: next call did not revert as expected] test_next_call_fail() ([GAS]) +[FAIL. Reason: Reverter != expected reverter: 0x5615dEB798BB3E4dFa0139dFa1b3D433Cc23b72f != 0x2e234DAe75C793f67A35089C9d99245E1C58470b] test_wrong_reverter_fail_1() ([GAS]) +[FAIL. Reason: Reverter != expected reverter: 0x5615dEB798BB3E4dFa0139dFa1b3D433Cc23b72f != 0x2e234DAe75C793f67A35089C9d99245E1C58470b] test_wrong_reverter_fail_2() ([GAS]) +[FAIL. Reason: Reverter != expected reverter: 0x5615dEB798BB3E4dFa0139dFa1b3D433Cc23b72f != 0x2e234DAe75C793f67A35089C9d99245E1C58470b] test_wrong_reverter_fail_3() ([GAS]) +... +"#]]); +}); diff --git a/testdata/cheats/Vm.sol b/testdata/cheats/Vm.sol index b929053dac01..c7dfe913023e 100644 --- a/testdata/cheats/Vm.sol +++ b/testdata/cheats/Vm.sol @@ -232,9 +232,13 @@ interface Vm { function expectEmit() external; function expectEmit(address emitter) external; function expectPartialRevert(bytes4 revertData) external; + function expectPartialRevert(bytes4 revertData, address reverter) external; function expectRevert() external; function expectRevert(bytes4 revertData) external; function expectRevert(bytes calldata revertData) external; + function expectRevert(address reverter) external; + function expectRevert(bytes4 revertData, address reverter) external; + function expectRevert(bytes calldata revertData, address reverter) external; function expectSafeMemory(uint64 min, uint64 max) external; function expectSafeMemoryCall(uint64 min, uint64 max) external; function fee(uint256 newBasefee) external;