diff --git a/CHANGELOG.md b/CHANGELOG.md index 38a94bdc263..2ad1832b81a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,7 @@ * `Context`: making `_msgData` return `bytes calldata` instead of `bytes memory` ([#2492](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/2492)) * `ERC20`: Removed the `_setDecimals` function and the storage slot associated to decimals. ([#2502](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/2502)) * `Strings`: addition of a `toHexString` function. ([#2504](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/2504)) + * `EnumerableMap`: change implementation to optimize for `key → value` lookups instead of enumeration. ([#2518](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/2518)) ## 3.4.0 (2021-02-02) diff --git a/contracts/utils/EnumerableMap.sol b/contracts/utils/EnumerableMap.sol index 77fd125624f..afcf47ab9aa 100644 --- a/contracts/utils/EnumerableMap.sol +++ b/contracts/utils/EnumerableMap.sol @@ -2,6 +2,8 @@ pragma solidity ^0.8.0; +import "./EnumerableSet.sol"; + /** * @dev Library for managing an enumerable variant of Solidity's * https://solidity.readthedocs.io/en/latest/types.html#mapping-types[`mapping`] @@ -27,6 +29,8 @@ pragma solidity ^0.8.0; * supported. */ library EnumerableMap { + using EnumerableSet for EnumerableSet.Bytes32Set; + // To implement this library for multiple types with as little code // repetition as possible, we write it in terms of a generic Map type with // bytes32 keys and values. @@ -36,18 +40,11 @@ library EnumerableMap { // This means that we can only create new EnumerableMaps for types that fit // in bytes32. - struct MapEntry { - bytes32 _key; - bytes32 _value; - } - struct Map { - // Storage of map keys and values - MapEntry[] _entries; + // Storage of keys + EnumerableSet.Bytes32Set _keys; - // Position of the entry defined by a key in the `entries` array, plus 1 - // because index 0 means a key is not in the map. - mapping (bytes32 => uint256) _indexes; + mapping (bytes32 => bytes32) _values; } /** @@ -58,19 +55,8 @@ library EnumerableMap { * already present. */ function _set(Map storage map, bytes32 key, bytes32 value) private returns (bool) { - // We read and store the key's index to prevent multiple reads from the same storage slot - uint256 keyIndex = map._indexes[key]; - - if (keyIndex == 0) { // Equivalent to !contains(map, key) - map._entries.push(MapEntry({ _key: key, _value: value })); - // The entry is stored at length-1, but we add 1 to all indexes - // and use 0 as a sentinel value - map._indexes[key] = map._entries.length; - return true; - } else { - map._entries[keyIndex - 1]._value = value; - return false; - } + map._values[key] = value; + return map._keys.add(key); } /** @@ -79,51 +65,22 @@ library EnumerableMap { * Returns true if the key was removed from the map, that is if it was present. */ function _remove(Map storage map, bytes32 key) private returns (bool) { - // We read and store the key's index to prevent multiple reads from the same storage slot - uint256 keyIndex = map._indexes[key]; - - if (keyIndex != 0) { // Equivalent to contains(map, key) - // To delete a key-value pair from the _entries array in O(1), we swap the entry to delete with the last one - // in the array, and then remove the last entry (sometimes called as 'swap and pop'). - // This modifies the order of the array, as noted in {at}. - - uint256 toDeleteIndex = keyIndex - 1; - uint256 lastIndex = map._entries.length - 1; - - // When the entry to delete is the last one, the swap operation is unnecessary. However, since this occurs - // so rarely, we still do the swap anyway to avoid the gas cost of adding an 'if' statement. - - MapEntry storage lastEntry = map._entries[lastIndex]; - - // Move the last entry to the index where the entry to delete is - map._entries[toDeleteIndex] = lastEntry; - // Update the index for the moved entry - map._indexes[lastEntry._key] = toDeleteIndex + 1; // All indexes are 1-based - - // Delete the slot where the moved entry was stored - map._entries.pop(); - - // Delete the index for the deleted slot - delete map._indexes[key]; - - return true; - } else { - return false; - } + delete map._values[key]; + return map._keys.remove(key); } /** * @dev Returns true if the key is in the map. O(1). */ function _contains(Map storage map, bytes32 key) private view returns (bool) { - return map._indexes[key] != 0; + return map._keys.contains(key); } /** * @dev Returns the number of key-value pairs in the map. O(1). */ function _length(Map storage map) private view returns (uint256) { - return map._entries.length; + return map._keys.length(); } /** @@ -137,10 +94,8 @@ library EnumerableMap { * - `index` must be strictly less than {length}. */ function _at(Map storage map, uint256 index) private view returns (bytes32, bytes32) { - require(map._entries.length > index, "EnumerableMap: index out of bounds"); - - MapEntry storage entry = map._entries[index]; - return (entry._key, entry._value); + bytes32 key = map._keys.at(index); + return (key, map._values[key]); } /** @@ -148,9 +103,12 @@ library EnumerableMap { * Does not revert if `key` is not in the map. */ function _tryGet(Map storage map, bytes32 key) private view returns (bool, bytes32) { - uint256 keyIndex = map._indexes[key]; - if (keyIndex == 0) return (false, 0); // Equivalent to contains(map, key) - return (true, map._entries[keyIndex - 1]._value); // All indexes are 1-based + bytes32 value = map._values[key]; + if (value == bytes32(0)) { + return (_contains(map, key), bytes32(0)); + } else { + return (true, value); + } } /** @@ -161,9 +119,9 @@ library EnumerableMap { * - `key` must be in the map. */ function _get(Map storage map, bytes32 key) private view returns (bytes32) { - uint256 keyIndex = map._indexes[key]; - require(keyIndex != 0, "EnumerableMap: nonexistent key"); // Equivalent to contains(map, key) - return map._entries[keyIndex - 1]._value; // All indexes are 1-based + bytes32 value = map._values[key]; + require(value != 0 || _contains(map, key), "EnumerableMap: nonexistent key"); + return value; } /** @@ -173,9 +131,9 @@ library EnumerableMap { * message unnecessarily. For custom revert reasons use {_tryGet}. */ function _get(Map storage map, bytes32 key, string memory errorMessage) private view returns (bytes32) { - uint256 keyIndex = map._indexes[key]; - require(keyIndex != 0, errorMessage); // Equivalent to contains(map, key) - return map._entries[keyIndex - 1]._value; // All indexes are 1-based + bytes32 value = map._values[key]; + require(value != 0 || _contains(map, key), errorMessage); + return value; } // UintToAddressMap diff --git a/test/token/ERC721/ERC721.test.js b/test/token/ERC721/ERC721.test.js index 23f4088056c..d3d585429bd 100644 --- a/test/token/ERC721/ERC721.test.js +++ b/test/token/ERC721/ERC721.test.js @@ -801,7 +801,7 @@ contract('ERC721', function (accounts) { it('reverts if index is greater than supply', async function () { await expectRevert( - this.token.tokenByIndex(2), 'EnumerableMap: index out of bounds', + this.token.tokenByIndex(2), 'EnumerableSet: index out of bounds', ); }); @@ -908,7 +908,7 @@ contract('ERC721', function (accounts) { await this.token.burn(secondTokenId, { from: owner }); expect(await this.token.totalSupply()).to.be.bignumber.equal('0'); await expectRevert( - this.token.tokenByIndex(0), 'EnumerableMap: index out of bounds', + this.token.tokenByIndex(0), 'EnumerableSet: index out of bounds', ); });