Skip to content

Commit dc0ea03

Browse files
committed
Fix reentrancy by protecting podBalanceOf() and balanceOf() from access during updateBalances() loop
1 parent 5fc07f7 commit dc0ea03

File tree

2 files changed

+58
-3
lines changed

2 files changed

+58
-3
lines changed

contracts/ERC20Pods.sol

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,12 @@ import "@1inch/solidity-utils/contracts/libraries/AddressSet.sol";
77

88
import "./interfaces/IERC20Pods.sol";
99
import "./interfaces/IPod.sol";
10+
import "./libs/ReentrancyGuard.sol";
1011

11-
abstract contract ERC20Pods is ERC20, IERC20Pods {
12+
abstract contract ERC20Pods is ERC20, IERC20Pods, ReentrancyGuardExt {
1213
using AddressSet for AddressSet.Data;
1314
using AddressArray for AddressArray.Data;
15+
using ReentrancyGuardLib for ReentrancyGuardLib.Data;
1416

1517
error PodAlreadyAdded();
1618
error PodNotFound();
@@ -22,10 +24,12 @@ abstract contract ERC20Pods is ERC20, IERC20Pods {
2224

2325
uint256 public immutable podsLimit;
2426

27+
ReentrancyGuardLib.Data private _guard;
2528
mapping(address => AddressSet.Data) private _pods;
2629

2730
constructor(uint256 podsLimit_) {
2831
podsLimit = podsLimit_;
32+
_guard.init();
2933
}
3034

3135
function hasPod(address account, address pod) public view virtual returns(bool) {
@@ -44,7 +48,11 @@ abstract contract ERC20Pods is ERC20, IERC20Pods {
4448
return _pods[account].items.get();
4549
}
4650

47-
function podBalanceOf(address pod, address account) public view returns(uint256) {
51+
function balanceOf(address account) public nonReentrantView(_guard) view override(IERC20, ERC20) returns(uint256) {
52+
return super.balanceOf(account);
53+
}
54+
55+
function podBalanceOf(address pod, address account) public nonReentrantView(_guard) view returns(uint256) {
4856
if (hasPod(account, pod)) {
4957
return balanceOf(account);
5058
}
@@ -119,7 +127,7 @@ abstract contract ERC20Pods is ERC20, IERC20Pods {
119127

120128
// ERC20 Overrides
121129

122-
function _afterTokenTransfer(address from, address to, uint256 amount) internal override virtual {
130+
function _afterTokenTransfer(address from, address to, uint256 amount) internal nonReentrant(_guard) override virtual {
123131
super._afterTokenTransfer(from, to, amount);
124132

125133
unchecked {

contracts/libs/ReentrancyGuard.sol

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
// SPDX-License-Identifier: MIT
2+
3+
pragma solidity ^0.8.0;
4+
5+
library ReentrancyGuardLib {
6+
error ReentrantCall();
7+
8+
uint256 private constant _NOT_ENTERED = 1;
9+
uint256 private constant _ENTERED = 2;
10+
11+
struct Data {
12+
uint256 _status;
13+
}
14+
15+
function init(Data storage self) internal {
16+
self._status = _NOT_ENTERED;
17+
}
18+
19+
function enter(Data storage self) internal {
20+
if (self._status == _ENTERED) revert ReentrantCall();
21+
self._status = _ENTERED;
22+
}
23+
24+
function exit(Data storage self) internal {
25+
self._status = _NOT_ENTERED;
26+
}
27+
28+
function check(Data storage self) internal view returns (bool) {
29+
return self._status == _ENTERED;
30+
}
31+
}
32+
33+
contract ReentrancyGuardExt {
34+
using ReentrancyGuardLib for ReentrancyGuardLib.Data;
35+
error AccessDenied();
36+
37+
modifier nonReentrant(ReentrancyGuardLib.Data storage self) {
38+
self.enter();
39+
_;
40+
self.exit();
41+
}
42+
43+
modifier nonReentrantView(ReentrancyGuardLib.Data storage self) {
44+
if (self.check()) revert ReentrancyGuardLib.ReentrantCall();
45+
_;
46+
}
47+
}

0 commit comments

Comments
 (0)