diff --git a/crates/cheatcodes/spec/src/vm.rs b/crates/cheatcodes/spec/src/vm.rs index a9421d087daa..50bdd26f3aa4 100644 --- a/crates/cheatcodes/spec/src/vm.rs +++ b/crates/cheatcodes/spec/src/vm.rs @@ -53,6 +53,14 @@ interface Vm { SelfDestruct, /// Synthetic access indicating the current context has resumed after a previous sub-context (AccountAccess). Resume, + /// The account's balance was read. + Balance, + /// The account's codesize was read. + Extcodesize, + /// The account's code was copied. + Extcodecopy, + /// The account's codehash was read. + Extcodehash, } /// An Ethereum log. Returned by `getRecordedLogs`. diff --git a/crates/cheatcodes/src/inspector.rs b/crates/cheatcodes/src/inspector.rs index b154955330af..390c16fb92f5 100644 --- a/crates/cheatcodes/src/inspector.rs +++ b/crates/cheatcodes/src/inspector.rs @@ -505,6 +505,60 @@ impl Inspector for Cheatcodes { data.journaled_state.depth(), ); } + // Record account accesses via the EXT family of opcodes + opcode::EXTCODECOPY | + opcode::EXTCODESIZE | + opcode::EXTCODEHASH | + opcode::BALANCE => { + let kind = match interpreter.current_opcode() { + opcode::EXTCODECOPY => crate::Vm::AccountAccessKind::Extcodecopy, + opcode::EXTCODESIZE => crate::Vm::AccountAccessKind::Extcodesize, + opcode::EXTCODEHASH => crate::Vm::AccountAccessKind::Extcodehash, + opcode::BALANCE => crate::Vm::AccountAccessKind::Balance, + _ => unreachable!(), + }; + let address = Address::from_word(B256::from(try_or_continue!(interpreter + .stack() + .peek(0)))); + let balance; + let initialized; + if let Ok((acc, _)) = data.journaled_state.load_account(address, data.db) { + initialized = acc.info.exists(); + balance = acc.info.balance; + } else { + initialized = false; + balance = U256::ZERO; + } + let account_access = crate::Vm::AccountAccess { + chainInfo: crate::Vm::ChainInfo { + forkId: data.db.active_fork_id().unwrap_or_default(), + chainId: U256::from(data.env.cfg.chain_id), + }, + accessor: interpreter.contract().address, + account: address, + kind, + initialized, + oldBalance: balance, + newBalance: balance, + value: U256::ZERO, + data: vec![], + reverted: false, + deployedCode: vec![], + storageAccesses: vec![], + }; + let access = AccountAccess { + access: account_access, + // use current depth; EXT* opcodes are not creating new contexts + depth: data.journaled_state.depth(), + }; + // Record the EXT* call as an account access at the current depth + // (future storage accesses will be recorded in a new "Resume" context) + if let Some(last) = recorded_account_diffs_stack.last_mut() { + last.push(access); + } else { + recorded_account_diffs_stack.push(vec![access]); + } + } _ => (), } } diff --git a/testdata/cheats/RecordAccountAccesses.t.sol b/testdata/cheats/RecordAccountAccesses.t.sol index 4ab707edbfac..15c1780b1a35 100644 --- a/testdata/cheats/RecordAccountAccesses.t.sol +++ b/testdata/cheats/RecordAccountAccesses.t.sol @@ -121,6 +121,19 @@ contract NestedRunner { } } +/// Helper contract that uses all three EXT* opcodes on a given address +contract ExtChecker { + function checkExts(address a) external { + assembly { + let x := extcodesize(a) + let y := extcodehash(a) + extcodecopy(a, x, y, 0) + // sstore to check that storage accesses are correctly stored in a new access with a "resume" context + sstore(0, balance(a)) + } + } +} + /// @notice Helper contract that writes to storage in a nested call contract NestedStorer { mapping(bytes32 key => uint256 value) slots; @@ -196,6 +209,7 @@ contract RecordAccountAccessesTest is DSTest { Create2or create2or; StorageAccessor test1; StorageAccessor test2; + ExtChecker extChecker; function setUp() public { runner = new NestedRunner(); @@ -203,6 +217,7 @@ contract RecordAccountAccessesTest is DSTest { create2or = new Create2or(); test1 = new StorageAccessor(); test2 = new StorageAccessor(); + extChecker = new ExtChecker(); } function testStorageAccessDelegateCall() public { @@ -211,7 +226,7 @@ contract RecordAccountAccessesTest is DSTest { cheats.startStateDiffRecording(); address(proxy).call(abi.encodeCall(StorageAccessor.read, bytes32(uint256(1234)))); - Vm.AccountAccess[] memory called = cheats.stopAndReturnStateDiff(); + Vm.AccountAccess[] memory called = filterExtcodesizeForLegacyTests(cheats.stopAndReturnStateDiff()); assertEq(called.length, 2, "incorrect length"); @@ -246,7 +261,7 @@ contract RecordAccountAccessesTest is DSTest { two.write(bytes32(uint256(5678)), bytes32(uint256(123469))); two.write(bytes32(uint256(5678)), bytes32(uint256(1234))); - Vm.AccountAccess[] memory called = cheats.stopAndReturnStateDiff(); + Vm.AccountAccess[] memory called = filterExtcodesizeForLegacyTests(cheats.stopAndReturnStateDiff()); assertEq(called.length, 4, "incorrect length"); assertEq(called[0].storageAccesses.length, 1, "incorrect storage length"); @@ -317,7 +332,7 @@ contract RecordAccountAccessesTest is DSTest { // contract calls to self in constructor SelfCaller caller = new SelfCaller{value: 2 ether}("hello2 world2"); - Vm.AccountAccess[] memory called = cheats.stopAndReturnStateDiff(); + Vm.AccountAccess[] memory called = filterExtcodesizeForLegacyTests(cheats.stopAndReturnStateDiff()); assertEq(called.length, 6); assertEq( called[0], @@ -430,7 +445,7 @@ contract RecordAccountAccessesTest is DSTest { uint256 initBalance = address(this).balance; cheats.startStateDiffRecording(); try this.revertingCall{value: 1 ether}(address(1234), "") {} catch {} - Vm.AccountAccess[] memory called = cheats.stopAndReturnStateDiff(); + Vm.AccountAccess[] memory called = filterExtcodesizeForLegacyTests(cheats.stopAndReturnStateDiff()); assertEq(called.length, 2); assertEq( called[0], @@ -485,7 +500,7 @@ contract RecordAccountAccessesTest is DSTest { /// @param shouldRevert Whether the first call should revert function runNested(bool shouldRevert, bool expectFirstCall) public { try runner.run{value: 1 ether}(shouldRevert) {} catch {} - Vm.AccountAccess[] memory called = cheats.stopAndReturnStateDiff(); + Vm.AccountAccess[] memory called = filterExtcodesizeForLegacyTests(cheats.stopAndReturnStateDiff()); assertEq(called.length, 7 + toUint(expectFirstCall), "incorrect length"); uint256 startingIndex = toUint(expectFirstCall); @@ -737,7 +752,7 @@ contract RecordAccountAccessesTest is DSTest { function testNestedStorage() public { cheats.startStateDiffRecording(); nestedStorer.run(); - Vm.AccountAccess[] memory called = cheats.stopAndReturnStateDiff(); + Vm.AccountAccess[] memory called = filterExtcodesizeForLegacyTests(cheats.stopAndReturnStateDiff()); assertEq(called.length, 3, "incorrect account access length"); assertEq(called[0].storageAccesses.length, 2, "incorrect run storage length"); @@ -858,7 +873,7 @@ contract RecordAccountAccessesTest is DSTest { bytes memory creationCode = abi.encodePacked(type(ConstructorStorer).creationCode, abi.encode(true)); address hypotheticalStorer = deriveCreate2Address(address(create2or), bytes32(0), keccak256(creationCode)); - Vm.AccountAccess[] memory called = cheats.stopAndReturnStateDiff(); + Vm.AccountAccess[] memory called = filterExtcodesizeForLegacyTests(cheats.stopAndReturnStateDiff()); assertEq(called.length, 3, "incorrect account access length"); assertEq(toUint(called[0].kind), toUint(Vm.AccountAccessKind.Create), "incorrect kind"); assertEq(toUint(called[1].kind), toUint(Vm.AccountAccessKind.Call), "incorrect kind"); @@ -967,7 +982,7 @@ contract RecordAccountAccessesTest is DSTest { try create2or.create2(bytes32(0), creationCode) {} catch {} address hypotheticalAddress = deriveCreate2Address(address(create2or), bytes32(0), keccak256(creationCode)); - Vm.AccountAccess[] memory called = cheats.stopAndReturnStateDiff(); + Vm.AccountAccess[] memory called = filterExtcodesizeForLegacyTests(cheats.stopAndReturnStateDiff()); assertEq(called.length, 3, "incorrect length"); assertEq( called[1], @@ -1013,7 +1028,7 @@ contract RecordAccountAccessesTest is DSTest { this.startRecordingFromLowerDepth(); address a = address(new SelfDestructor{value: 1 ether}(address(this))); address b = address(new SelfDestructor{value: 1 ether}(address(bytes20("doesn't exist yet")))); - Vm.AccountAccess[] memory called = cheats.stopAndReturnStateDiff(); + Vm.AccountAccess[] memory called = filterExtcodesizeForLegacyTests(cheats.stopAndReturnStateDiff()); assertEq(called.length, 5, "incorrect length"); assertEq( called[1], @@ -1093,12 +1108,58 @@ contract RecordAccountAccessesTest is DSTest { StorageAccessor a = new StorageAccessor(); cheats.stopBroadcast(); - Vm.AccountAccess[] memory called = cheats.stopAndReturnStateDiff(); + Vm.AccountAccess[] memory called = filterExtcodesizeForLegacyTests(cheats.stopAndReturnStateDiff()); assertEq(called.length, 1, "incorrect length"); assertEq(toUint(called[0].kind), toUint(Vm.AccountAccessKind.Create)); assertEq(called[0].account, address(a)); } + /// @notice Test that EXT* opcodes are recorded as account accesses + function testExtOpcodes() public { + cheats.startStateDiffRecording(); + extChecker.checkExts(address(1234)); + Vm.AccountAccess[] memory called = cheats.stopAndReturnStateDiff(); + assertEq(called.length, 7, "incorrect length"); + // initial solidity extcodesize check for calling extChecker + assertEq(toUint(called[0].kind), toUint(Vm.AccountAccessKind.Extcodesize)); + // call to extChecker + assertEq(toUint(called[1].kind), toUint(Vm.AccountAccessKind.Call)); + // extChecker checks + assertEq(toUint(called[2].kind), toUint(Vm.AccountAccessKind.Extcodesize)); + assertEq(toUint(called[3].kind), toUint(Vm.AccountAccessKind.Extcodehash)); + assertEq(toUint(called[4].kind), toUint(Vm.AccountAccessKind.Extcodecopy)); + assertEq(toUint(called[5].kind), toUint(Vm.AccountAccessKind.Balance)); + // resume of extChecker to hold SSTORE access + assertEq(toUint(called[6].kind), toUint(Vm.AccountAccessKind.Resume)); + assertEq(called[6].storageAccesses.length, 1, "incorrect length"); + } + + /** + * @notice Filter out extcodesize account accesses for legacy tests written before + * EXT* opcodes were supported. + */ + function filterExtcodesizeForLegacyTests(Vm.AccountAccess[] memory inArr) + internal + pure + returns (Vm.AccountAccess[] memory out) + { + // allocate max length for out array + out = new Vm.AccountAccess[](inArr.length); + // track end size + uint256 size; + for (uint256 i = 0; i < inArr.length; ++i) { + // only append if not extcodesize + if (inArr[i].kind != Vm.AccountAccessKind.Extcodesize) { + out[size] = inArr[i]; + ++size; + } + } + // manually truncate out array + assembly { + mstore(out, size) + } + } + function startRecordingFromLowerDepth() external { cheats.startStateDiffRecording(); assembly { diff --git a/testdata/cheats/Vm.sol b/testdata/cheats/Vm.sol index 4fc5c3ce51b2..ab301787465b 100644 --- a/testdata/cheats/Vm.sol +++ b/testdata/cheats/Vm.sol @@ -7,7 +7,7 @@ pragma solidity ^0.8.4; interface Vm { error CheatcodeError(string message); enum CallerMode { None, Broadcast, RecurrentBroadcast, Prank, RecurrentPrank } - enum AccountAccessKind { Call, DelegateCall, CallCode, StaticCall, Create, SelfDestruct, Resume } + enum AccountAccessKind { Call, DelegateCall, CallCode, StaticCall, Create, SelfDestruct, Resume, Balance, Extcodesize, Extcodecopy, Extcodehash } struct Log { bytes32[] topics; bytes data; address emitter; } struct Rpc { string key; string url; } struct EthGetLogs { address emitter; bytes32[] topics; bytes data; bytes32 blockHash; uint64 blockNumber; bytes32 transactionHash; uint64 transactionIndex; uint256 logIndex; bool removed; }