Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Transient version of ReentrancyGuard #4988

Merged
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/witty-chicken-smile.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
'openzeppelin-solidity': minor
---

`ReentrancyGuardTransient`: Added a variant of `ReentrancyGuard` that uses transient storage.
50 changes: 50 additions & 0 deletions contracts/mocks/ReentrancyTransientMock.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
// SPDX-License-Identifier: MIT

pragma solidity ^0.8.24;

import {ReentrancyGuardTransient} from "../utils/ReentrancyGuardTransient.sol";
import {ReentrancyAttack} from "./ReentrancyAttack.sol";

contract ReentrancyTransientMock is ReentrancyGuardTransient {
uint256 public counter;

constructor() {
counter = 0;
}

function callback() external nonReentrant {
_count();
}

function countLocalRecursive(uint256 n) public nonReentrant {
if (n > 0) {
_count();
countLocalRecursive(n - 1);
}
}

function countThisRecursive(uint256 n) public nonReentrant {
if (n > 0) {
_count();
(bool success, ) = address(this).call(abi.encodeCall(this.countThisRecursive, (n - 1)));
require(success, "ReentrancyTransientMock: failed call");
}
}

function countAndCall(ReentrancyAttack attacker) public nonReentrant {
_count();
attacker.callSender(abi.encodeCall(this.callback, ()));
}

function _count() private {
counter += 1;
}

function guardedCheckEntered() public nonReentrant {
require(_reentrancyGuardEntered());
}

function unguardedCheckNotEntered() public view {
require(!_reentrancyGuardEntered());
}
}
3 changes: 3 additions & 0 deletions contracts/utils/README.adoc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ Miscellaneous contracts and libraries containing utility functions you can use t
* {MerkleProof}: Functions for verifying https://en.wikipedia.org/wiki/Merkle_tree[Merkle Tree] proofs.
* {EIP712}: Contract with functions to allow processing signed typed structure data according to https://eips.ethereum.org/EIPS/eip-712[EIP-712].
* {ReentrancyGuard}: A modifier that can prevent reentrancy during certain functions.
* {ReentrancyGuardTransient}: Variant of {ReentrancyGuard} that uses transient storage (EIP-1153).
ernestognw marked this conversation as resolved.
Show resolved Hide resolved
* {Pausable}: A common emergency response mechanism that can pause functionality while a remediation is pending.
* {Nonces}: Utility for tracking and verifying address nonces that only increment.
* {ERC165, ERC165Checker}: Utilities for inspecting interfaces supported by contracts.
Expand Down Expand Up @@ -65,6 +66,8 @@ Because Solidity does not support generic types, {EnumerableMap} and {Enumerable

{{ReentrancyGuard}}

{{ReentrancyGuardTransient}}

{{Pausable}}

{{Nonces}}
Expand Down
3 changes: 3 additions & 0 deletions contracts/utils/ReentrancyGuard.sol
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ pragma solidity ^0.8.20;
* those functions `private`, and then adding `external` `nonReentrant` entry
* points to them.
*
* NOTE: If EIP-1153 (transient storage) is available on the targeted network, you
* should consider using {TransientReentrancyGuard} instead.
ernestognw marked this conversation as resolved.
Show resolved Hide resolved
*
* TIP: If you would like to learn more about reentrancy and alternative ways
* to protect against it, check out our blog post
* https://blog.openzeppelin.com/reentrancy-after-istanbul/[Reentrancy After Istanbul].
Expand Down
58 changes: 58 additions & 0 deletions contracts/utils/ReentrancyGuardTransient.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
// SPDX-License-Identifier: MIT

pragma solidity ^0.8.24;

import {StorageSlot} from "./StorageSlot.sol";

/**
* @dev Variant of {ReentrancyGuard} that uses transient storage.
*
* NOTE: This variant only works on networks where EIP-1153 is available.
*/
abstract contract ReentrancyGuardTransient {
using StorageSlot for *;

// keccak256(abi.encode(uint256(keccak256("openzeppelin.storage.ReentrancyGuard")) - 1)) & ~bytes32(uint256(0xff))
bytes32 private constant REENTRANCY_GUARD_STORAGE =
0x9b779b17422d0df92223018b32b4d1fa46e071723d6817e2486d003becc55f00;

/**
* @dev Unauthorized reentrant call.
*/
error ReentrancyGuardReentrantCall();

/**
* @dev Prevents a contract from calling itself, directly or indirectly.
* Calling a `nonReentrant` function from another `nonReentrant`
* function is not supported. It is possible to prevent this from happening
* by making the `nonReentrant` function external, and making it call a
* `private` function that does the actual work.
*/
modifier nonReentrant() {
_nonReentrantBefore();
_;
_nonReentrantAfter();
}

function _nonReentrantBefore() private {
// On the first call to nonReentrant, _status will be NOT_ENTERED
if (_reentrancyGuardEntered()) {
revert ReentrancyGuardReentrantCall();
}

// Any calls to nonReentrant after this point will fail
REENTRANCY_GUARD_STORAGE.asBoolean().tstore(true);
}

function _nonReentrantAfter() private {
REENTRANCY_GUARD_STORAGE.asBoolean().tstore(false);
}

/**
* @dev Returns true if the reentrancy guard is currently set to "entered", which indicates there is a
* `nonReentrant` function in the call stack.
*/
function _reentrancyGuardEntered() internal view returns (bool) {
return REENTRANCY_GUARD_STORAGE.asBoolean().tload();
}
}
87 changes: 45 additions & 42 deletions test/utils/ReentrancyGuard.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -2,46 +2,49 @@ const { ethers } = require('hardhat');
const { expect } = require('chai');
const { loadFixture } = require('@nomicfoundation/hardhat-network-helpers');

async function fixture() {
const mock = await ethers.deployContract('ReentrancyMock');
return { mock };
}

describe('ReentrancyGuard', function () {
beforeEach(async function () {
Object.assign(this, await loadFixture(fixture));
});

it('nonReentrant function can be called', async function () {
expect(await this.mock.counter()).to.equal(0n);
await this.mock.callback();
expect(await this.mock.counter()).to.equal(1n);
});

it('does not allow remote callback', async function () {
const attacker = await ethers.deployContract('ReentrancyAttack');
await expect(this.mock.countAndCall(attacker)).to.be.revertedWith('ReentrancyAttack: failed call');
});

it('_reentrancyGuardEntered should be true when guarded', async function () {
await this.mock.guardedCheckEntered();
for (const variant of ['', 'Transient']) {
describe(`Reentrancy${variant}Guard`, function () {
async function fixture() {
const name = `Reentrancy${variant}Mock`;
const mock = await ethers.deployContract(name);
return { name, mock };
}

beforeEach(async function () {
Object.assign(this, await loadFixture(fixture));
});

it('nonReentrant function can be called', async function () {
expect(await this.mock.counter()).to.equal(0n);
await this.mock.callback();
expect(await this.mock.counter()).to.equal(1n);
});

it('does not allow remote callback', async function () {
const attacker = await ethers.deployContract('ReentrancyAttack');
await expect(this.mock.countAndCall(attacker)).to.be.revertedWith('ReentrancyAttack: failed call');
});

it('_reentrancyGuardEntered should be true when guarded', async function () {
await this.mock.guardedCheckEntered();
});

it('_reentrancyGuardEntered should be false when unguarded', async function () {
await this.mock.unguardedCheckNotEntered();
});

// The following are more side-effects than intended behavior:
// I put them here as documentation, and to monitor any changes
// in the side-effects.
it('does not allow local recursion', async function () {
await expect(this.mock.countLocalRecursive(10n)).to.be.revertedWithCustomError(
this.mock,
'ReentrancyGuardReentrantCall',
);
});

it('does not allow indirect local recursion', async function () {
await expect(this.mock.countThisRecursive(10n)).to.be.revertedWith(`${this.name}: failed call`);
});
});

it('_reentrancyGuardEntered should be false when unguarded', async function () {
await this.mock.unguardedCheckNotEntered();
});

// The following are more side-effects than intended behavior:
// I put them here as documentation, and to monitor any changes
// in the side-effects.
it('does not allow local recursion', async function () {
await expect(this.mock.countLocalRecursive(10n)).to.be.revertedWithCustomError(
this.mock,
'ReentrancyGuardReentrantCall',
);
});

it('does not allow indirect local recursion', async function () {
await expect(this.mock.countThisRecursive(10n)).to.be.revertedWith('ReentrancyMock: failed call');
});
});
}
Loading