Skip to content

Commit ccb9862

Browse files
committed
Add support for ERC721 and ERC1155
1 parent dc0ea03 commit ccb9862

File tree

12 files changed

+394
-130
lines changed

12 files changed

+394
-130
lines changed

contracts/ERC1155Pods.sol

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
// SPDX-License-Identifier: MIT
2+
3+
pragma solidity ^0.8.0;
4+
5+
import "@openzeppelin/contracts/token/ERC1155/ERC1155.sol";
6+
7+
import "./interfaces/IERC1155Pods.sol";
8+
import "./TokenPodsLib.sol";
9+
import "./libs/ReentrancyGuard.sol";
10+
11+
abstract contract ERC1155Pods is ERC1155, IERC1155Pods, ReentrancyGuardExt {
12+
using TokenPodsLib for TokenPodsLib.Data;
13+
using ReentrancyGuardLib for ReentrancyGuardLib.Data;
14+
15+
error PodsLimitReachedForAccount();
16+
17+
uint256 public immutable podsLimit;
18+
19+
ReentrancyGuardLib.Data private _guard;
20+
mapping(uint256 => TokenPodsLib.Data) private _pods;
21+
22+
constructor(uint256 podsLimit_) {
23+
podsLimit = podsLimit_;
24+
_guard.init();
25+
}
26+
27+
function hasPod(address account, address pod, uint256 id) public view virtual returns(bool) {
28+
return _pods[id].hasPod(account, pod);
29+
}
30+
31+
function podsCount(address account, uint256 id) public view virtual returns(uint256) {
32+
return _pods[id].podsCount(account);
33+
}
34+
35+
function podAt(address account, uint256 index, uint256 id) public view virtual returns(address) {
36+
return _pods[id].podAt(account, index);
37+
}
38+
39+
function pods(address account, uint256 id) public view virtual returns(address[] memory) {
40+
return _pods[id].pods(account);
41+
}
42+
43+
function balanceOf(address account, uint256 id) public nonReentrantView(_guard) view override(IERC1155, ERC1155) virtual returns(uint256) {
44+
return super.balanceOf(account, id);
45+
}
46+
47+
function podBalanceOf(address pod, address account, uint256 id) public nonReentrantView(_guard) view returns(uint256) {
48+
return _pods[id].podBalanceOf(account, pod, balanceOf(msg.sender, id));
49+
}
50+
51+
function addPod(address pod, uint256 id) public virtual {
52+
if (_pods[id].addPod(msg.sender, pod, balanceOf(msg.sender, id)) > podsLimit) revert PodsLimitReachedForAccount();
53+
}
54+
55+
function removePod(address pod, uint256 id) public virtual {
56+
_pods[id].removePod(msg.sender, pod, balanceOf(msg.sender, id));
57+
}
58+
59+
function removeAllPods(uint256 id) public virtual {
60+
_pods[id].removeAllPods(msg.sender, balanceOf(msg.sender, id));
61+
}
62+
63+
// ERC1155 Overrides
64+
65+
function _afterTokenTransfer(
66+
address operator,
67+
address from,
68+
address to,
69+
uint256[] memory ids,
70+
uint256[] memory amounts,
71+
bytes memory data
72+
) internal nonReentrant(_guard) override virtual {
73+
super._afterTokenTransfer(operator, from, to, ids, amounts, data);
74+
75+
unchecked {
76+
for (uint256 i = 0; i < ids.length; i++) {
77+
_pods[ids[i]].updateBalancesWithTokenId(from, to, amounts[i], ids[i]);
78+
}
79+
}
80+
}
81+
}

contracts/ERC20Pods.sol

Lines changed: 14 additions & 113 deletions
Original file line numberDiff line numberDiff line change
@@ -6,164 +6,65 @@ import "@openzeppelin/contracts/token/ERC20/ERC20.sol";
66
import "@1inch/solidity-utils/contracts/libraries/AddressSet.sol";
77

88
import "./interfaces/IERC20Pods.sol";
9-
import "./interfaces/IPod.sol";
9+
import "./TokenPodsLib.sol";
1010
import "./libs/ReentrancyGuard.sol";
1111

1212
abstract contract ERC20Pods is ERC20, IERC20Pods, ReentrancyGuardExt {
13-
using AddressSet for AddressSet.Data;
14-
using AddressArray for AddressArray.Data;
13+
using TokenPodsLib for TokenPodsLib.Data;
1514
using ReentrancyGuardLib for ReentrancyGuardLib.Data;
1615

17-
error PodAlreadyAdded();
18-
error PodNotFound();
19-
error InvalidPodAddress();
2016
error PodsLimitReachedForAccount();
21-
error InsufficientGas();
22-
23-
uint256 private constant _POD_CALL_GAS_LIMIT = 200_000;
2417

2518
uint256 public immutable podsLimit;
2619

2720
ReentrancyGuardLib.Data private _guard;
28-
mapping(address => AddressSet.Data) private _pods;
21+
TokenPodsLib.Data private _pods;
2922

3023
constructor(uint256 podsLimit_) {
3124
podsLimit = podsLimit_;
3225
_guard.init();
3326
}
3427

3528
function hasPod(address account, address pod) public view virtual returns(bool) {
36-
return _pods[account].contains(pod);
29+
return _pods.hasPod(account, pod);
3730
}
3831

3932
function podsCount(address account) public view virtual returns(uint256) {
40-
return _pods[account].length();
33+
return _pods.podsCount(account);
4134
}
4235

4336
function podAt(address account, uint256 index) public view virtual returns(address) {
44-
return _pods[account].at(index);
37+
return _pods.podAt(account, index);
4538
}
4639

4740
function pods(address account) public view virtual returns(address[] memory) {
48-
return _pods[account].items.get();
41+
return _pods.pods(account);
4942
}
5043

51-
function balanceOf(address account) public nonReentrantView(_guard) view override(IERC20, ERC20) returns(uint256) {
44+
function balanceOf(address account) public nonReentrantView(_guard) view override(IERC20, ERC20) virtual returns(uint256) {
5245
return super.balanceOf(account);
5346
}
5447

55-
function podBalanceOf(address pod, address account) public nonReentrantView(_guard) view returns(uint256) {
56-
if (hasPod(account, pod)) {
57-
return balanceOf(account);
58-
}
59-
return 0;
48+
function podBalanceOf(address pod, address account) public nonReentrantView(_guard) view virtual returns(uint256) {
49+
return _pods.podBalanceOf(account, pod, balanceOf(account));
6050
}
6151

6252
function addPod(address pod) public virtual {
63-
_addPod(msg.sender, pod);
53+
if (_pods.addPod(msg.sender, pod, balanceOf(msg.sender)) > podsLimit) revert PodsLimitReachedForAccount();
6454
}
6555

6656
function removePod(address pod) public virtual {
67-
_removePod(msg.sender, pod);
57+
_pods.removePod(msg.sender, pod, balanceOf(msg.sender));
6858
}
6959

7060
function removeAllPods() public virtual {
71-
_removeAllPods(msg.sender);
72-
}
73-
74-
function _addPod(address account, address pod) internal virtual {
75-
if (pod == address(0)) revert InvalidPodAddress();
76-
if (!_pods[account].add(pod)) revert PodAlreadyAdded();
77-
if (_pods[account].length() > podsLimit) revert PodsLimitReachedForAccount();
78-
79-
uint256 balance = balanceOf(account);
80-
if (balance > 0) {
81-
_updateBalances(pod, address(0), account, balance);
82-
}
83-
}
84-
85-
function _removePod(address account, address pod) internal virtual {
86-
if (!_pods[account].remove(pod)) revert PodNotFound();
87-
88-
uint256 balance = balanceOf(account);
89-
if (balance > 0) {
90-
_updateBalances(pod, account, address(0), balance);
91-
}
92-
}
93-
94-
function _removeAllPods(address account) internal virtual {
95-
address[] memory items = _pods[account].items.get();
96-
uint256 balance = balanceOf(account);
97-
unchecked {
98-
for (uint256 i = items.length; i > 0; i--) {
99-
if (balance > 0) {
100-
_updateBalances(items[i - 1], account, address(0), balance);
101-
}
102-
_pods[account].remove(items[i - 1]);
103-
}
104-
}
105-
}
106-
107-
/// @notice Assembly implementation of the gas limited call to avoid return gas bomb,
108-
// moreover call to a destructed pod would also revert even inside try-catch block in Solidity 0.8.17
109-
/// @dev try IPod(pod).updateBalances{gas: _POD_CALL_GAS_LIMIT}(from, to, amount) {} catch {}
110-
function _updateBalances(address pod, address from, address to, uint256 amount) private {
111-
bytes4 selector = IPod.updateBalances.selector;
112-
bytes4 exception = InsufficientGas.selector;
113-
assembly { // solhint-disable-line no-inline-assembly
114-
let ptr := mload(0x40)
115-
mstore(ptr, selector)
116-
mstore(add(ptr, 0x04), from)
117-
mstore(add(ptr, 0x24), to)
118-
mstore(add(ptr, 0x44), amount)
119-
120-
if lt(div(mul(gas(), 63), 64), _POD_CALL_GAS_LIMIT) {
121-
mstore(0, exception)
122-
revert(0, 4)
123-
}
124-
pop(call(_POD_CALL_GAS_LIMIT, pod, 0, ptr, 0x64, 0, 0))
125-
}
61+
_pods.removeAllPods(msg.sender, balanceOf(msg.sender));
12662
}
12763

12864
// ERC20 Overrides
12965

13066
function _afterTokenTransfer(address from, address to, uint256 amount) internal nonReentrant(_guard) override virtual {
13167
super._afterTokenTransfer(from, to, amount);
132-
133-
unchecked {
134-
if (amount > 0 && from != to) {
135-
address[] memory a = _pods[from].items.get();
136-
address[] memory b = _pods[to].items.get();
137-
uint256 aLength = a.length;
138-
uint256 bLength = b.length;
139-
140-
for (uint256 i = 0; i < aLength; i++) {
141-
address pod = a[i];
142-
143-
uint256 j;
144-
for (j = 0; j < bLength; j++) {
145-
if (pod == b[j]) {
146-
// Both parties are participating of the same Pod
147-
_updateBalances(pod, from, to, amount);
148-
b[j] = address(0);
149-
break;
150-
}
151-
}
152-
153-
if (j == bLength) {
154-
// Sender is participating in a Pod, but receiver is not
155-
_updateBalances(pod, from, address(0), amount);
156-
}
157-
}
158-
159-
for (uint256 j = 0; j < bLength; j++) {
160-
address pod = b[j];
161-
if (pod != address(0)) {
162-
// Receiver is participating in a Pod, but sender is not
163-
_updateBalances(pod, address(0), to, amount);
164-
}
165-
}
166-
}
167-
}
68+
_pods.updateBalances(from, to, amount);
16869
}
16970
}

contracts/ERC721Pods.sol

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
2+
// SPDX-License-Identifier: MIT
3+
4+
pragma solidity ^0.8.0;
5+
6+
import "@openzeppelin/contracts/token/ERC721/ERC721.sol";
7+
import "@1inch/solidity-utils/contracts/libraries/AddressSet.sol";
8+
9+
import "./interfaces/IERC721Pods.sol";
10+
import "./TokenPodsLib.sol";
11+
import "./libs/ReentrancyGuard.sol";
12+
13+
abstract contract ERC721Pods is ERC721, IERC721Pods, ReentrancyGuardExt {
14+
using TokenPodsLib for TokenPodsLib.Data;
15+
using ReentrancyGuardLib for ReentrancyGuardLib.Data;
16+
17+
error PodsLimitReachedForAccount();
18+
19+
uint256 public immutable podsLimit;
20+
21+
ReentrancyGuardLib.Data private _guard;
22+
TokenPodsLib.Data private _pods;
23+
24+
constructor(uint256 podsLimit_) {
25+
podsLimit = podsLimit_;
26+
_guard.init();
27+
}
28+
29+
function hasPod(address account, address pod) public view virtual returns(bool) {
30+
return _pods.hasPod(account, pod);
31+
}
32+
33+
function podsCount(address account) public view virtual returns(uint256) {
34+
return _pods.podsCount(account);
35+
}
36+
37+
function podAt(address account, uint256 index) public view virtual returns(address) {
38+
return _pods.podAt(account, index);
39+
}
40+
41+
function pods(address account) public view virtual returns(address[] memory) {
42+
return _pods.pods(account);
43+
}
44+
45+
function balanceOf(address account) public nonReentrantView(_guard) view override(IERC721, ERC721) virtual returns(uint256) {
46+
return super.balanceOf(account);
47+
}
48+
49+
function podBalanceOf(address pod, address account) public nonReentrantView(_guard) view virtual returns(uint256) {
50+
return _pods.podBalanceOf(account, pod, balanceOf(account));
51+
}
52+
53+
function addPod(address pod) public virtual {
54+
if (_pods.addPod(msg.sender, pod, balanceOf(msg.sender)) > podsLimit) revert PodsLimitReachedForAccount();
55+
}
56+
57+
function removePod(address pod) public virtual {
58+
_pods.removePod(msg.sender, pod, balanceOf(msg.sender));
59+
}
60+
61+
function removeAllPods() public virtual {
62+
_pods.removeAllPods(msg.sender, balanceOf(msg.sender));
63+
}
64+
65+
// ERC721 Overrides
66+
67+
function _afterTokenTransfer(address from, address to, uint256 firstTokenId, uint256 batchSize) internal nonReentrant(_guard) override virtual {
68+
super._afterTokenTransfer(from, to, firstTokenId, batchSize);
69+
_pods.updateBalances(from, to, batchSize);
70+
}
71+
}

contracts/Pod.sol

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,30 @@ abstract contract Pod is IPod {
88
error AccessDenied();
99

1010
address public immutable token;
11+
uint256 public immutable tokenId;
1112

1213
modifier onlyToken {
1314
if (msg.sender != token) revert AccessDenied();
1415
_;
1516
}
1617

17-
constructor(address token_) {
18+
modifier onlyTokenId(uint256 id) {
19+
if (id != tokenId) revert AccessDenied();
20+
_;
21+
}
22+
23+
constructor(address token_, uint256 tokenId_) {
1824
token = token_;
25+
tokenId = tokenId_;
26+
}
27+
28+
function updateBalancesWithTokenId(address from, address to, uint256 amount, uint256 id) external onlyToken onlyTokenId(id) {
29+
_updateBalances(from, to, amount);
30+
}
31+
32+
function updateBalances(address from, address to, uint256 amount) external onlyToken {
33+
_updateBalances(from, to, amount);
1934
}
35+
36+
function _updateBalances(address from, address to, uint256 amount) internal virtual;
2037
}

0 commit comments

Comments
 (0)