-
Notifications
You must be signed in to change notification settings - Fork 0
/
LibDiamond.sol
350 lines (327 loc) · 12.1 KB
/
LibDiamond.sol
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
// SPDX-License-Identifier: MIT
pragma solidity ^0.8.9;
/******************************************************************************\
* Author: Nick Mudge <nick@perfectabstractions.com> (https://twitter.com/mudgen)
* EIP-2535 Diamonds: https://eips.ethereum.org/EIPS/eip-2535
/******************************************************************************/
import {IDiamondCut} from "../interfaces/IDiamondCut.sol";
// Remember to add the loupe functions from DiamondLoupeFacet to the diamond.
// The loupe functions are required by the EIP2535 Diamonds standard
/// Raised by LibDiamond.initializeDiamondCut when the delegatecall to the
/// initialization contract fails without returning any revert data.
/// @param _initializationContractAddress address that was delegatecalled
/// @param _calldata calldata that was passed to the delegatecall
error InitializationFunctionReverted(
address _initializationContractAddress,
bytes _calldata
);
/// @title LibDiamond
/// @notice Storage layout and internal helpers for an EIP-2535 diamond:
/// contract ownership, the selector-to-facet registry, and diamond cuts.
/// @dev All state lives in a single DiamondStorage struct placed at the slot
/// DIAMOND_STORAGE_POSITION. Every function is internal.
library LibDiamond {
    // Slot at which the DiamondStorage struct is located.
    bytes32 constant DIAMOND_STORAGE_POSITION =
        keccak256("diamond.standard.diamond.storage");

    // Facet serving a selector, plus the selector's index inside that
    // facet's functionSelectors array.
    struct FacetAddressAndPosition {
        address facetAddress;
        uint96 functionSelectorPosition; // position in facetFunctionSelectors.functionSelectors array
    }

    // Selectors served by a facet, plus the facet's index inside the
    // facetAddresses array.
    struct FacetFunctionSelectors {
        bytes4[] functionSelectors;
        uint256 facetAddressPosition; // position of facetAddress in facetAddresses array
    }

    // NOTE: field order defines the storage layout; do not reorder.
    struct DiamondStorage {
        // maps function selector to the facet address and
        // the position of the selector in the facetFunctionSelectors.selectors array
        mapping(bytes4 => FacetAddressAndPosition) selectorToFacetAndPosition;
        // maps facet addresses to function selectors
        mapping(address => FacetFunctionSelectors) facetFunctionSelectors;
        // facet addresses
        address[] facetAddresses;
        // Used to query if a contract implements an interface.
        // Used to implement ERC-165.
        mapping(bytes4 => bool) supportedInterfaces;
        // owner of the contract
        address contractOwner;
    }

    /// @notice Returns a storage pointer to the diamond's shared state.
    /// @return ds reference to the struct located at DIAMOND_STORAGE_POSITION
    function diamondStorage()
        internal
        pure
        returns (DiamondStorage storage ds)
    {
        bytes32 position = DIAMOND_STORAGE_POSITION;
        // Point the returned storage reference at the fixed slot.
        assembly {
            ds.slot := position
        }
    }

    /// @notice Emitted by setContractOwner whenever the stored owner is written.
    /// @param previousOwner owner recorded before the write
    /// @param newOwner owner recorded after the write
    event OwnershipTransferred(
        address indexed previousOwner,
        address indexed newOwner
    );

    /// @notice Stores `_newOwner` as the contract owner and emits OwnershipTransferred.
    /// @dev Performs no access check and no zero-address check itself.
    /// @param _newOwner address to record as owner
    function setContractOwner(address _newOwner) internal {
        DiamondStorage storage ds = diamondStorage();
        address previousOwner = ds.contractOwner;
        ds.contractOwner = _newOwner;
        emit OwnershipTransferred(previousOwner, _newOwner);
    }

    /// @notice Reads the stored contract owner.
    /// @return contractOwner_ the address held in DiamondStorage.contractOwner
    function contractOwner() internal view returns (address contractOwner_) {
        contractOwner_ = diamondStorage().contractOwner;
    }

    /// @notice Reverts unless msg.sender equals the stored contract owner.
    function enforceIsContractOwner() internal view {
        require(
            msg.sender == diamondStorage().contractOwner,
            "LibDiamond: Must be contract owner"
        );
    }

    /// @notice Emitted by diamondCut after all facet cuts have been applied.
    /// @param _diamondCut the cuts that were applied
    /// @param _init initialization contract address passed to diamondCut
    /// @param _calldata initialization calldata passed to diamondCut
    event DiamondCut(
        IDiamondCut.FacetCut[] _diamondCut,
        address _init,
        bytes _calldata
    );

    // Internal function version of diamondCut
    /// @notice Applies every cut in order, emits DiamondCut, then runs the
    /// optional initializer.
    /// @dev The event is emitted before initializeDiamondCut is called. Any
    /// revert in a cut or in the initializer reverts the whole call.
    /// @param _diamondCut add/replace/remove operations to apply
    /// @param _init contract to delegatecall afterwards; address(0) skips it
    /// @param _calldata calldata for the delegatecall to `_init`
    function diamondCut(
        IDiamondCut.FacetCut[] memory _diamondCut,
        address _init,
        bytes memory _calldata
    ) internal {
        for (
            uint256 facetIndex;
            facetIndex < _diamondCut.length;
            facetIndex++
        ) {
            IDiamondCut.FacetCutAction action = _diamondCut[facetIndex].action;
            if (action == IDiamondCut.FacetCutAction.Add) {
                addFunctions(
                    _diamondCut[facetIndex].facetAddress,
                    _diamondCut[facetIndex].functionSelectors
                );
            } else if (action == IDiamondCut.FacetCutAction.Replace) {
                replaceFunctions(
                    _diamondCut[facetIndex].facetAddress,
                    _diamondCut[facetIndex].functionSelectors
                );
            } else if (action == IDiamondCut.FacetCutAction.Remove) {
                removeFunctions(
                    _diamondCut[facetIndex].facetAddress,
                    _diamondCut[facetIndex].functionSelectors
                );
            } else {
                revert("LibDiamondCut: Incorrect FacetCutAction");
            }
        }
        emit DiamondCut(_diamondCut, _init, _calldata);
        initializeDiamondCut(_init, _calldata);
    }

    /// @notice Registers new selectors as served by `_facetAddress`.
    /// @dev Reverts if the selector list is empty, if `_facetAddress` is
    /// address(0), or if any selector is already mapped to a facet. When the
    /// facet currently has no selectors it is first registered via addFacet.
    /// @param _facetAddress facet that will serve the selectors
    /// @param _functionSelectors selectors to add
    function addFunctions(
        address _facetAddress,
        bytes4[] memory _functionSelectors
    ) internal {
        require(
            _functionSelectors.length > 0,
            "LibDiamondCut: No selectors in facet to cut"
        );
        DiamondStorage storage ds = diamondStorage();
        require(
            _facetAddress != address(0),
            "LibDiamondCut: Add facet can't be address(0)"
        );
        // Next free index in the facet's selector array.
        uint96 selectorPosition = uint96(
            ds.facetFunctionSelectors[_facetAddress].functionSelectors.length
        );
        // add new facet address if it does not exist
        if (selectorPosition == 0) {
            addFacet(ds, _facetAddress);
        }
        for (
            uint256 selectorIndex;
            selectorIndex < _functionSelectors.length;
            selectorIndex++
        ) {
            bytes4 selector = _functionSelectors[selectorIndex];
            address oldFacetAddress = ds
                .selectorToFacetAndPosition[selector]
                .facetAddress;
            require(
                oldFacetAddress == address(0),
                "LibDiamondCut: Can't add function that already exists"
            );
            addFunction(ds, selector, selectorPosition, _facetAddress);
            selectorPosition++;
        }
    }

    /// @notice Re-points existing selectors to `_facetAddress`.
    /// @dev Reverts if the selector list is empty, if `_facetAddress` is
    /// address(0), or if a selector is already served by `_facetAddress`.
    /// Each selector is detached with removeFunction, so its revert
    /// conditions (selector not registered, or served by address(this))
    /// apply here too.
    /// @param _facetAddress facet that will serve the selectors from now on
    /// @param _functionSelectors selectors to replace
    function replaceFunctions(
        address _facetAddress,
        bytes4[] memory _functionSelectors
    ) internal {
        require(
            _functionSelectors.length > 0,
            "LibDiamondCut: No selectors in facet to cut"
        );
        DiamondStorage storage ds = diamondStorage();
        // Message names the Replace action so the revert reason identifies
        // which kind of cut was rejected.
        require(
            _facetAddress != address(0),
            "LibDiamondCut: Replace facet can't be address(0)"
        );
        // Next free index in the new facet's selector array.
        uint96 selectorPosition = uint96(
            ds.facetFunctionSelectors[_facetAddress].functionSelectors.length
        );
        // add new facet address if it does not exist
        if (selectorPosition == 0) {
            addFacet(ds, _facetAddress);
        }
        for (
            uint256 selectorIndex;
            selectorIndex < _functionSelectors.length;
            selectorIndex++
        ) {
            bytes4 selector = _functionSelectors[selectorIndex];
            address oldFacetAddress = ds
                .selectorToFacetAndPosition[selector]
                .facetAddress;
            require(
                oldFacetAddress != _facetAddress,
                "LibDiamondCut: Can't replace function with same function"
            );
            removeFunction(ds, oldFacetAddress, selector);
            addFunction(ds, selector, selectorPosition, _facetAddress);
            selectorPosition++;
        }
    }

    /// @notice Unregisters the given selectors from whichever facet serves them.
    /// @dev Reverts if the selector list is empty or if `_facetAddress` is not
    /// address(0). Each selector is removed with removeFunction, which
    /// reverts when the selector is not registered or is served by
    /// address(this).
    /// @param _facetAddress must be address(0) for a remove cut
    /// @param _functionSelectors selectors to remove
    function removeFunctions(
        address _facetAddress,
        bytes4[] memory _functionSelectors
    ) internal {
        require(
            _functionSelectors.length > 0,
            "LibDiamondCut: No selectors in facet to cut"
        );
        DiamondStorage storage ds = diamondStorage();
        // A remove cut carries no facet address; the serving facet is looked
        // up per selector below.
        require(
            _facetAddress == address(0),
            "LibDiamondCut: Remove facet address must be address(0)"
        );
        for (
            uint256 selectorIndex;
            selectorIndex < _functionSelectors.length;
            selectorIndex++
        ) {
            bytes4 selector = _functionSelectors[selectorIndex];
            address oldFacetAddress = ds
                .selectorToFacetAndPosition[selector]
                .facetAddress;
            removeFunction(ds, oldFacetAddress, selector);
        }
    }

    /// @notice Appends `_facetAddress` to the facet list and records its index.
    /// @dev Reverts if `_facetAddress` has no contract code.
    /// @param ds diamond storage pointer
    /// @param _facetAddress facet to register
    function addFacet(DiamondStorage storage ds, address _facetAddress)
        internal
    {
        enforceHasContractCode(
            _facetAddress,
            "LibDiamondCut: New facet has no code"
        );
        // The index is the array length before the push.
        ds.facetFunctionSelectors[_facetAddress].facetAddressPosition = ds
            .facetAddresses
            .length;
        ds.facetAddresses.push(_facetAddress);
    }

    /// @notice Records `_selector` as served by `_facetAddress`.
    /// @dev Performs no validation; callers pass the index at which the
    /// selector will land in the facet's selector array.
    /// @param ds diamond storage pointer
    /// @param _selector selector to register
    /// @param _selectorPosition index stored for the selector
    /// @param _facetAddress facet that serves the selector
    function addFunction(
        DiamondStorage storage ds,
        bytes4 _selector,
        uint96 _selectorPosition,
        address _facetAddress
    ) internal {
        ds
            .selectorToFacetAndPosition[_selector]
            .functionSelectorPosition = _selectorPosition;
        ds.facetFunctionSelectors[_facetAddress].functionSelectors.push(
            _selector
        );
        ds.selectorToFacetAndPosition[_selector].facetAddress = _facetAddress;
    }

    /// @notice Unregisters `_selector` from `_facetAddress`; drops the facet
    /// from the facet list once it has no selectors left.
    /// @dev Reverts if `_facetAddress` is address(0) or address(this). Both
    /// arrays are compacted by moving the last element into the vacated
    /// index and popping the tail.
    /// @param ds diamond storage pointer
    /// @param _facetAddress facet currently serving the selector
    /// @param _selector selector to remove
    function removeFunction(
        DiamondStorage storage ds,
        address _facetAddress,
        bytes4 _selector
    ) internal {
        require(
            _facetAddress != address(0),
            "LibDiamondCut: Can't remove function that doesn't exist"
        );
        // an immutable function is a function defined directly in a diamond
        require(
            _facetAddress != address(this),
            "LibDiamondCut: Can't remove immutable function"
        );
        // replace selector with last selector, then delete last selector
        uint256 selectorPosition = ds
            .selectorToFacetAndPosition[_selector]
            .functionSelectorPosition;
        uint256 lastSelectorPosition = ds
            .facetFunctionSelectors[_facetAddress]
            .functionSelectors
            .length - 1;
        // if not the same then replace _selector with lastSelector
        if (selectorPosition != lastSelectorPosition) {
            bytes4 lastSelector = ds
                .facetFunctionSelectors[_facetAddress]
                .functionSelectors[lastSelectorPosition];
            ds.facetFunctionSelectors[_facetAddress].functionSelectors[
                selectorPosition
            ] = lastSelector;
            ds
                .selectorToFacetAndPosition[lastSelector]
                .functionSelectorPosition = uint96(selectorPosition);
        }
        // delete the last selector
        ds.facetFunctionSelectors[_facetAddress].functionSelectors.pop();
        delete ds.selectorToFacetAndPosition[_selector];
        // if no more selectors for facet address then delete the facet address
        if (lastSelectorPosition == 0) {
            // replace facet address with last facet address and delete last facet address
            uint256 lastFacetAddressPosition = ds.facetAddresses.length - 1;
            uint256 facetAddressPosition = ds
                .facetFunctionSelectors[_facetAddress]
                .facetAddressPosition;
            if (facetAddressPosition != lastFacetAddressPosition) {
                address lastFacetAddress = ds.facetAddresses[
                    lastFacetAddressPosition
                ];
                ds.facetAddresses[facetAddressPosition] = lastFacetAddress;
                ds
                    .facetFunctionSelectors[lastFacetAddress]
                    .facetAddressPosition = facetAddressPosition;
            }
            ds.facetAddresses.pop();
            delete ds
                .facetFunctionSelectors[_facetAddress]
                .facetAddressPosition;
        }
    }

    /// @notice Delegatecalls `_init` with `_calldata` unless `_init` is address(0).
    /// @dev Reverts if `_init` has no code. On a failed delegatecall the
    /// callee's revert data is re-thrown unchanged; if there is none,
    /// InitializationFunctionReverted is raised instead.
    /// @param _init initialization contract, or address(0) to skip
    /// @param _calldata calldata forwarded to `_init`
    function initializeDiamondCut(address _init, bytes memory _calldata)
        internal
    {
        if (_init == address(0)) {
            return;
        }
        enforceHasContractCode(
            _init,
            "LibDiamondCut: _init address has no code"
        );
        (bool success, bytes memory error) = _init.delegatecall(_calldata);
        if (!success) {
            if (error.length > 0) {
                // bubble up error
                /// @solidity memory-safe-assembly
                assembly {
                    // `error` points at the length word; the payload starts
                    // 32 bytes after it.
                    let returndata_size := mload(error)
                    revert(add(32, error), returndata_size)
                }
            } else {
                revert InitializationFunctionReverted(_init, _calldata);
            }
        }
    }

    /// @notice Reverts with `_errorMessage` when `_contract` has no code.
    /// @param _contract address whose code size is checked
    /// @param _errorMessage revert reason used when the code size is zero
    function enforceHasContractCode(
        address _contract,
        string memory _errorMessage
    ) internal view {
        // Built-in member instead of inline assembly for the code-size check.
        require(_contract.code.length > 0, _errorMessage);
    }
}