diff --git a/contracts/drafts/ERC20Snapshot.sol b/contracts/drafts/ERC20Snapshot.sol index 94c114fd9bc..a813d6004ab 100644 --- a/contracts/drafts/ERC20Snapshot.sol +++ b/contracts/drafts/ERC20Snapshot.sol @@ -18,14 +18,15 @@ contract ERC20Snapshot is ERC20 { // Snapshots store a value at the time a snapshot is taken (and a new snapshot id created), and the corresponding // snapshot id. Each account has individual snapshots taken on demand, as does the token's total supply. - // These two fields (value and id) belong together, but are not part of a struct so that functions that work on - // arrays can be called on them. - - mapping (address => uint256[]) private _accountSnapshotIds; - mapping (address => uint256[]) private _accountSnapshotValues; + // Snapshoted values have arrays of ids and the value corresponding to that id. These could be an array of a + // Snapshot struct, but that would impede usage of functions that work on an array. + struct Snapshots { + uint256[] ids; + uint256[] values; + } - uint256[] private _totalSupplySnapshotIds; - uint256[] private _totalSupplySnapshotValues; + mapping (address => Snapshots) private _accountBalanceSnapshots; + Snapshots private _totalSupplySnaphots; // Snapshot ids increase monotonically, with the first value being 1. An id of 0 is invalid. Counters.Counter private _currentSnapshotId; @@ -44,14 +45,13 @@ contract ERC20Snapshot is ERC20 { } function balanceOfAt(address account, uint256 snapshotId) public view returns (uint256) { - (bool snapshotted, uint256 value) = - _valueAt(snapshotId, _accountSnapshotValues[account], _accountSnapshotIds[account]); + (bool snapshotted, uint256 value) = _valueAt(snapshotId, _accountBalanceSnapshots[account]); return snapshotted ? value : balanceOf(account); } function totalSupplyAt(uint256 snapshotId) public view returns(uint256) { - (bool snapshotted, uint256 value) = _valueAt(snapshotId, _totalSupplySnapshotValues, _totalSupplySnapshotIds); + (bool snapshotted, uint256 value) = _valueAt(snapshotId, _totalSupplySnaphots); return snapshotted ? value : totalSupply(); } @@ -93,34 +93,34 @@ contract ERC20Snapshot is ERC20 { // In summary, we need to find an element in an array, returning the index of the smallest value that is larger if // it is not found, unless said value doesn't exist (e.g. when all values are smaller). Arrays.findUpperBound does // exactly this. - function _valueAt(uint256 snapshotId, uint256[] storage values, uint256[] storage ids) + function _valueAt(uint256 snapshotId, Snapshots storage snapshots) private view returns (bool, uint256) { require(snapshotId > 0); require(snapshotId <= _currentSnapshotId.current()); - uint256 index = ids.findUpperBound(snapshotId); + uint256 index = snapshots.ids.findUpperBound(snapshotId); - if (index == ids.length) { + if (index == snapshots.ids.length) { return (false, 0); } else { - return (true, values[index]); + return (true, snapshots.values[index]); } } function _updateAccountSnapshot(address account) private { - _updateSnapshot(_accountSnapshotValues[account], _accountSnapshotIds[account], balanceOf(account)); + _updateSnapshot(_accountBalanceSnapshots[account], balanceOf(account)); } function _updateTotalSupplySnapshot() private { - _updateSnapshot(_totalSupplySnapshotValues, _totalSupplySnapshotIds, totalSupply()); + _updateSnapshot(_totalSupplySnaphots, totalSupply()); } - function _updateSnapshot(uint256[] storage values, uint256[] storage ids, uint256 currentValue) private { + function _updateSnapshot(Snapshots storage snapshots, uint256 currentValue) private { uint256 currentId = _currentSnapshotId.current(); - if (_lastSnapshotId(ids) < currentId) { - ids.push(currentId); - values.push(currentValue); + if (_lastSnapshotId(snapshots.ids) < currentId) { + snapshots.ids.push(currentId); + snapshots.values.push(currentValue); } }