diff --git a/.changeset/flat-turtles-repeat.md b/.changeset/flat-turtles-repeat.md new file mode 100644 index 00000000000..6b627201ac9 --- /dev/null +++ b/.changeset/flat-turtles-repeat.md @@ -0,0 +1,5 @@ +--- +'openzeppelin-solidity': minor +--- + +`Arrays`: deprecate `findUpperBound` in favor of the new `lowerBound`. diff --git a/.changeset/thick-pumpkins-report.md b/.changeset/thick-pumpkins-report.md new file mode 100644 index 00000000000..f17a208950c --- /dev/null +++ b/.changeset/thick-pumpkins-report.md @@ -0,0 +1,5 @@ +--- +'openzeppelin-solidity': minor +--- + +`Arrays`: add new functions `lowerBound`, `upperBound`, `lowerBoundMemory` and `upperBoundMemory` for lookups in sorted arrays with potential duplicates. diff --git a/contracts/mocks/ArraysMock.sol b/contracts/mocks/ArraysMock.sol index a00def29cb6..a2fbb6dea63 100644 --- a/contracts/mocks/ArraysMock.sol +++ b/contracts/mocks/ArraysMock.sol @@ -13,8 +13,24 @@ contract Uint256ArraysMock { _array = array; } - function findUpperBound(uint256 element) external view returns (uint256) { - return _array.findUpperBound(element); + function findUpperBound(uint256 value) external view returns (uint256) { + return _array.findUpperBound(value); + } + + function lowerBound(uint256 value) external view returns (uint256) { + return _array.lowerBound(value); + } + + function upperBound(uint256 value) external view returns (uint256) { + return _array.upperBound(value); + } + + function lowerBoundMemory(uint256[] memory array, uint256 value) external pure returns (uint256) { + return array.lowerBoundMemory(value); + } + + function upperBoundMemory(uint256[] memory array, uint256 value) external pure returns (uint256) { + return array.upperBoundMemory(value); } function unsafeAccess(uint256 pos) external view returns (uint256) { diff --git a/contracts/utils/Arrays.sol b/contracts/utils/Arrays.sol index aaab3ce592b..67c63a7e7b9 100644 --- a/contracts/utils/Arrays.sol +++ b/contracts/utils/Arrays.sol @@ -18,8 +18,12 @@ library Arrays { * values in the array are strictly less than `element`), the array length is * returned. Time complexity O(log n). * - * `array` is expected to be sorted in ascending order, and to contain no - * repeated elements. + * NOTE: The `array` is expected to be sorted in ascending order, and to + * contain no repeated elements. + * + * IMPORTANT: Deprecated. This implementation behaves as {lowerBound} but lacks + * support for repeated elements in the array. The {lowerBound} function should + * be used instead. */ function findUpperBound(uint256[] storage array, uint256 element) internal view returns (uint256) { uint256 low = 0; @@ -49,6 +53,132 @@ library Arrays { } } + /** + * @dev Searches an `array` sorted in ascending order and returns the first + * index that contains a value greater or equal than `element`. If no such index + * exists (i.e. all values in the array are strictly less than `element`), the array + * length is returned. Time complexity O(log n). + * + * See C++'s https://en.cppreference.com/w/cpp/algorithm/lower_bound[lower_bound]. + */ + function lowerBound(uint256[] storage array, uint256 element) internal view returns (uint256) { + uint256 low = 0; + uint256 high = array.length; + + if (high == 0) { + return 0; + } + + while (low < high) { + uint256 mid = Math.average(low, high); + + // Note that mid will always be strictly less than high (i.e. it will be a valid array index) + // because Math.average rounds towards zero (it does integer division with truncation). + if (unsafeAccess(array, mid).value < element) { + // this cannot overflow because mid < high + unchecked { + low = mid + 1; + } + } else { + high = mid; + } + } + + return low; + } + + /** + * @dev Searches an `array` sorted in ascending order and returns the first + * index that contains a value strictly greater than `element`. If no such index + * exists (i.e. all values in the array are strictly less than `element`), the array + * length is returned. Time complexity O(log n). + * + * See C++'s https://en.cppreference.com/w/cpp/algorithm/upper_bound[upper_bound]. + */ + function upperBound(uint256[] storage array, uint256 element) internal view returns (uint256) { + uint256 low = 0; + uint256 high = array.length; + + if (high == 0) { + return 0; + } + + while (low < high) { + uint256 mid = Math.average(low, high); + + // Note that mid will always be strictly less than high (i.e. it will be a valid array index) + // because Math.average rounds towards zero (it does integer division with truncation). + if (unsafeAccess(array, mid).value > element) { + high = mid; + } else { + // this cannot overflow because mid < high + unchecked { + low = mid + 1; + } + } + } + + return low; + } + + /** + * @dev Same as {lowerBound}, but with an array in memory. + */ + function lowerBoundMemory(uint256[] memory array, uint256 element) internal pure returns (uint256) { + uint256 low = 0; + uint256 high = array.length; + + if (high == 0) { + return 0; + } + + while (low < high) { + uint256 mid = Math.average(low, high); + + // Note that mid will always be strictly less than high (i.e. it will be a valid array index) + // because Math.average rounds towards zero (it does integer division with truncation). + if (unsafeMemoryAccess(array, mid) < element) { + // this cannot overflow because mid < high + unchecked { + low = mid + 1; + } + } else { + high = mid; + } + } + + return low; + } + + /** + * @dev Same as {upperBound}, but with an array in memory. + */ + function upperBoundMemory(uint256[] memory array, uint256 element) internal pure returns (uint256) { + uint256 low = 0; + uint256 high = array.length; + + if (high == 0) { + return 0; + } + + while (low < high) { + uint256 mid = Math.average(low, high); + + // Note that mid will always be strictly less than high (i.e. it will be a valid array index) + // because Math.average rounds towards zero (it does integer division with truncation). + if (unsafeMemoryAccess(array, mid) > element) { + high = mid; + } else { + // this cannot overflow because mid < high + unchecked { + low = mid + 1; + } + } + } + + return low; + } + /** * @dev Access an array in an "unsafe" way. Skips solidity "index-out-of-range" check. * diff --git a/test/utils/Arrays.test.js b/test/utils/Arrays.test.js index c585fee58e8..dc35e49447f 100644 --- a/test/utils/Arrays.test.js +++ b/test/utils/Arrays.test.js @@ -4,22 +4,22 @@ const { loadFixture } = require('@nomicfoundation/hardhat-network-helpers'); const { randomArray, generators } = require('../helpers/random'); -// See https://en.cppreference.com/w/cpp/algorithm/ranges/lower_bound +// See https://en.cppreference.com/w/cpp/algorithm/lower_bound const lowerBound = (array, value) => { const i = array.findIndex(element => value <= element); return i == -1 ? array.length : i; }; // See https://en.cppreference.com/w/cpp/algorithm/upper_bound -// const upperBound = (array, value) => { -// const i = array.findIndex(element => value < element); -// return i == -1 ? array.length : i; -// }; +const upperBound = (array, value) => { + const i = array.findIndex(element => value < element); + return i == -1 ? array.length : i; +}; const hasDuplicates = array => array.some((v, i) => array.indexOf(v) != i); describe('Arrays', function () { - describe('findUpperBound', function () { + describe('search', function () { for (const [title, { array, tests }] of Object.entries({ 'Even number of elements': { array: [11n, 12n, 13n, 14n, 15n, 16n, 17n, 18n, 19n, 20n], @@ -82,10 +82,25 @@ describe('Arrays', function () { }); for (const [name, input] of Object.entries(tests)) { - it(name, async function () { - // findUpperBound does not support duplicated - if (hasDuplicates(array)) this.skip(); - expect(await this.mock.findUpperBound(input)).to.equal(lowerBound(array, input)); + describe(name, function () { + it('[deprecated] findUpperBound', async function () { + // findUpperBound does not support duplicated + if (hasDuplicates(array)) { + expect(await this.mock.findUpperBound(input)).to.be.equal(upperBound(array, input) - 1); + } else { + expect(await this.mock.findUpperBound(input)).to.be.equal(lowerBound(array, input)); + } + }); + + it('lowerBound', async function () { + expect(await this.mock.lowerBound(input)).to.be.equal(lowerBound(array, input)); + expect(await this.mock.lowerBoundMemory(array, input)).to.be.equal(lowerBound(array, input)); + }); + + it('upperBound', async function () { + expect(await this.mock.upperBound(input)).to.be.equal(upperBound(array, input)); + expect(await this.mock.upperBoundMemory(array, input)).to.be.equal(upperBound(array, input)); + }); }); } }); @@ -93,29 +108,29 @@ describe('Arrays', function () { }); describe('unsafeAccess', function () { - const contractCases = { + for (const [title, { artifact, elements }] of Object.entries({ address: { artifact: 'AddressArraysMock', elements: randomArray(generators.address, 10) }, bytes32: { artifact: 'Bytes32ArraysMock', elements: randomArray(generators.bytes32, 10) }, uint256: { artifact: 'Uint256ArraysMock', elements: randomArray(generators.uint256, 10) }, - }; - - const fixture = async () => { - const contracts = {}; - for (const [name, { artifact, elements }] of Object.entries(contractCases)) { - contracts[name] = await ethers.deployContract(artifact, [elements]); - } - return { contracts }; - }; + })) { + describe(title, function () { + const fixture = async () => { + return { mock: await ethers.deployContract(artifact, [elements]) }; + }; - beforeEach(async function () { - Object.assign(this, await loadFixture(fixture)); - }); + beforeEach(async function () { + Object.assign(this, await loadFixture(fixture)); + }); - for (const [name, { elements }] of Object.entries(contractCases)) { - it(name, async function () { for (const i in elements) { - expect(await this.contracts[name].unsafeAccess(i)).to.equal(elements[i]); + it(`unsafeAccess within bounds #${i}`, async function () { + expect(await this.mock.unsafeAccess(i)).to.equal(elements[i]); + }); } + + it('unsafeAccess outside bounds', async function () { + await expect(this.mock.unsafeAccess(elements.length)).to.not.be.rejected; + }); }); } }); diff --git a/test/utils/structs/Checkpoints.test.js b/test/utils/structs/Checkpoints.test.js index 9458c486ab8..fd22544b955 100644 --- a/test/utils/structs/Checkpoints.test.js +++ b/test/utils/structs/Checkpoints.test.js @@ -4,8 +4,6 @@ const { loadFixture } = require('@nomicfoundation/hardhat-network-helpers'); const { VALUE_SIZES } = require('../../../scripts/generate/templates/Checkpoints.opts'); -const last = array => (array.length ? array[array.length - 1] : undefined); - describe('Checkpoints', function () { for (const length of VALUE_SIZES) { describe(`Trace${length}`, function () { @@ -81,7 +79,7 @@ describe('Checkpoints', function () { it('returns latest value', async function () { const latest = this.checkpoints.at(-1); expect(await this.methods.latest()).to.equal(latest.value); - expect(await this.methods.latestCheckpoint()).to.have.ordered.members([true, latest.key, latest.value]); + expect(await this.methods.latestCheckpoint()).to.deep.equal([true, latest.key, latest.value]); }); it('cannot push values in the past', async function () { @@ -115,7 +113,7 @@ describe('Checkpoints', function () { it('upper lookup & upperLookupRecent', async function () { for (let i = 0; i < 14; ++i) { - const value = last(this.checkpoints.filter(x => i >= x.key))?.value || 0n; + const value = this.checkpoints.findLast(x => i >= x.key)?.value || 0n; expect(await this.methods.upperLookup(i)).to.equal(value); expect(await this.methods.upperLookupRecent(i)).to.equal(value); @@ -137,7 +135,7 @@ describe('Checkpoints', function () { } for (let i = 0; i < 25; ++i) { - const value = last(allCheckpoints.filter(x => i >= x.key))?.value || 0n; + const value = allCheckpoints.findLast(x => i >= x.key)?.value || 0n; expect(await this.methods.upperLookup(i)).to.equal(value); expect(await this.methods.upperLookupRecent(i)).to.equal(value); }