diff --git a/CHANGELOG.md b/CHANGELOG.md index 1fa60ab87b4..393bb199bf6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,8 @@ ### Improvements: * Upgraded the minimum compiler version to v0.5.2: this removes many Solidity warnings that were false positives. + * `Counter`'s API has been improved, and is now used by `ERC721` (though it is still in `drafts`). + * `ERC721`'s transfers are now more gas efficient due to removal of unnecessary `SafeMath` calls. * Fixed variable shadowing issues. ### Bugfixes: diff --git a/contracts/drafts/Counters.sol b/contracts/drafts/Counters.sol index 9097aadaf44..7a68039cc7c 100644 --- a/contracts/drafts/Counters.sol +++ b/contracts/drafts/Counters.sol @@ -1,24 +1,37 @@ pragma solidity ^0.5.2; +import "../math/SafeMath.sol"; + /** * @title Counters * @author Matt Condon (@shrugs) - * @dev Provides an incrementing uint256 id acquired by the `Counter#next` getter. - * Use this for issuing ERC721 ids or keeping track of request ids, anything you want, really. + * @dev Provides counters that can only be incremented or decremented by one. This can be used e.g. to track the number + * of elements in a mapping, issuing ERC721 ids, or counting request ids * - * Include with `using Counters` for Counters.Counter;` - * @notice Does not allow an Id of 0, which is popularly used to signify a null state in solidity. - * Does not protect from overflows, but if you have 2^256 ids, you have other problems. - * (But actually, it's generally impossible to increment a counter this many times, energy wise - * so it's not something you have to worry about.) + * Include with `using Counter for Counter.Counter;` + * Since it is not possible to overflow a 256 bit integer with increments of one, `increment` can skip the SafeMath + * overflow check, thereby saving gas. This does assume however correct usage, in that the underlying `_value` is never + * directly accessed. */ library Counters { + using SafeMath for uint256; + struct Counter { - uint256 current; // default: 0 + // This variable should never be directly accessed by users of the library: interactions must be restricted to + // the library's function. As of Solidity v0.5.2, this cannot be enforced, though there is a proposal to add + // this feature: see https://github.com/ethereum/solidity/issues/4637 + uint256 _value; // default: 0 + } + + function current(Counter storage counter) internal view returns (uint256) { + return counter._value; + } + + function increment(Counter storage counter) internal { + counter._value += 1; } - function next(Counter storage index) internal returns (uint256) { - index.current += 1; - return index.current; + function decrement(Counter storage counter) internal { + counter._value = counter._value.sub(1); } } diff --git a/contracts/mocks/CountersImpl.sol b/contracts/mocks/CountersImpl.sol index 13026491f69..21d4f04c6b0 100644 --- a/contracts/mocks/CountersImpl.sol +++ b/contracts/mocks/CountersImpl.sol @@ -5,13 +5,17 @@ import "../drafts/Counters.sol"; contract CountersImpl { using Counters for Counters.Counter; - uint256 public theId; + Counters.Counter private _counter; - // use whatever key you want to track your counters - mapping(string => Counters.Counter) private _counters; + function current() public view returns (uint256) { + return _counter.current(); + } + + function increment() public { + _counter.increment(); + } - function doThing(string memory key) public returns (uint256) { - theId = _counters[key].next(); - return theId; + function decrement() public { + _counter.decrement(); } } diff --git a/contracts/token/ERC721/ERC721.sol b/contracts/token/ERC721/ERC721.sol index 878801be685..d26ce72e56d 100644 --- a/contracts/token/ERC721/ERC721.sol +++ b/contracts/token/ERC721/ERC721.sol @@ -4,6 +4,7 @@ import "./IERC721.sol"; import "./IERC721Receiver.sol"; import "../../math/SafeMath.sol"; import "../../utils/Address.sol"; +import "../../drafts/Counters.sol"; import "../../introspection/ERC165.sol"; /** @@ -13,6 +14,7 @@ import "../../introspection/ERC165.sol"; contract ERC721 is ERC165, IERC721 { using SafeMath for uint256; using Address for address; + using Counters for Counters.Counter; // Equals to `bytes4(keccak256("onERC721Received(address,address,uint256,bytes)"))` // which can be also obtained as `IERC721Receiver(0).onERC721Received.selector` @@ -25,7 +27,7 @@ contract ERC721 is ERC165, IERC721 { mapping (uint256 => address) private _tokenApprovals; // Mapping from owner to number of owned token - mapping (address => uint256) private _ownedTokensCount; + mapping (address => Counters.Counter) private _ownedTokensCount; // Mapping from owner to operator approvals mapping (address => mapping (address => bool)) private _operatorApprovals; @@ -56,7 +58,7 @@ contract ERC721 is ERC165, IERC721 { */ function balanceOf(address owner) public view returns (uint256) { require(owner != address(0)); - return _ownedTokensCount[owner]; + return _ownedTokensCount[owner].current(); } /** @@ -200,7 +202,7 @@ contract ERC721 is ERC165, IERC721 { require(!_exists(tokenId)); _tokenOwner[tokenId] = to; - _ownedTokensCount[to] = _ownedTokensCount[to].add(1); + _ownedTokensCount[to].increment(); emit Transfer(address(0), to, tokenId); } @@ -217,7 +219,7 @@ contract ERC721 is ERC165, IERC721 { _clearApproval(tokenId); - _ownedTokensCount[owner] = _ownedTokensCount[owner].sub(1); + _ownedTokensCount[owner].decrement(); _tokenOwner[tokenId] = address(0); emit Transfer(owner, address(0), tokenId); @@ -245,8 +247,8 @@ contract ERC721 is ERC165, IERC721 { _clearApproval(tokenId); - _ownedTokensCount[from] = _ownedTokensCount[from].sub(1); - _ownedTokensCount[to] = _ownedTokensCount[to].add(1); + _ownedTokensCount[from].decrement(); + _ownedTokensCount[to].increment(); _tokenOwner[tokenId] = to; diff --git a/test/drafts/Counters.test.js b/test/drafts/Counters.test.js index dc51ca55c71..ad10d51c633 100644 --- a/test/drafts/Counters.test.js +++ b/test/drafts/Counters.test.js @@ -1,37 +1,58 @@ -const { BN } = require('openzeppelin-test-helpers'); +const { shouldFail } = require('openzeppelin-test-helpers'); const CountersImpl = artifacts.require('CountersImpl'); -const EXPECTED = [new BN(1), new BN(2), new BN(3), new BN(4)]; -const KEY1 = web3.utils.sha3('key1'); -const KEY2 = web3.utils.sha3('key2'); - -contract('Counters', function ([_, owner]) { +contract('Counters', function () { beforeEach(async function () { - this.mock = await CountersImpl.new({ from: owner }); + this.counter = await CountersImpl.new(); + }); + + it('starts at zero', async function () { + (await this.counter.current()).should.be.bignumber.equal('0'); }); - context('custom key', async function () { - it('should return expected values', async function () { - for (const expectedId of EXPECTED) { - await this.mock.doThing(KEY1, { from: owner }); - const actualId = await this.mock.theId(); - actualId.should.be.bignumber.equal(expectedId); - } + describe('increment', function () { + it('increments the current value by one', async function () { + await this.counter.increment(); + (await this.counter.current()).should.be.bignumber.equal('1'); + }); + + it('can be called multiple times', async function () { + await this.counter.increment(); + await this.counter.increment(); + await this.counter.increment(); + + (await this.counter.current()).should.be.bignumber.equal('3'); }); }); - context('parallel keys', async function () { - it('should return expected values for each counter', async function () { - for (const expectedId of EXPECTED) { - await this.mock.doThing(KEY1, { from: owner }); - let actualId = await this.mock.theId(); - actualId.should.be.bignumber.equal(expectedId); - - await this.mock.doThing(KEY2, { from: owner }); - actualId = await this.mock.theId(); - actualId.should.be.bignumber.equal(expectedId); - } + describe('decrement', function () { + beforeEach(async function () { + await this.counter.increment(); + (await this.counter.current()).should.be.bignumber.equal('1'); + }); + + it('decrements the current value by one', async function () { + await this.counter.decrement(); + (await this.counter.current()).should.be.bignumber.equal('0'); + }); + + it('reverts if the current value is 0', async function () { + await this.counter.decrement(); + await shouldFail.reverting(this.counter.decrement()); + }); + + it('can be called multiple times', async function () { + await this.counter.increment(); + await this.counter.increment(); + + (await this.counter.current()).should.be.bignumber.equal('3'); + + await this.counter.decrement(); + await this.counter.decrement(); + await this.counter.decrement(); + + (await this.counter.current()).should.be.bignumber.equal('0'); }); }); });