diff --git a/.changeset/dirty-cobras-smile.md b/.changeset/dirty-cobras-smile.md new file mode 100644 index 00000000000..f03e936b5d2 --- /dev/null +++ b/.changeset/dirty-cobras-smile.md @@ -0,0 +1,5 @@ +--- +'openzeppelin-solidity': minor +--- + +`Arrays`: add a `sort` function. diff --git a/contracts/utils/Arrays.sol b/contracts/utils/Arrays.sol index 67c63a7e7b9..f1a77f37133 100644 --- a/contracts/utils/Arrays.sol +++ b/contracts/utils/Arrays.sol @@ -12,6 +12,69 @@ import {Math} from "./math/Math.sol"; library Arrays { using StorageSlot for bytes32; + /** + * @dev Sort an array (in memory) in increasing order. + * + * This function does the sorting "in place", meaning that it overrides the input. The object is returned for + * convenience, but that returned value can be discarded safely if the caller has a memory pointer to the array. + * + * NOTE: this function's cost is `O(n · log(n))` in average and `O(n²)` in the worst case, with n the length of the + * array. Using it in view functions that are executed through `eth_call` is safe, but one should be very careful + * when executing this as part of a transaction. If the array being sorted is too large, the sort operation may + * consume more gas than is available in a block, leading to potential DoS. + */ + function sort(uint256[] memory array) internal pure returns (uint256[] memory) { + _quickSort(array, 0, array.length); + return array; + } + + /** + * @dev Performs a quick sort on an array in memory. The array is sorted in increasing order. + * + * Invariant: `i <= j <= array.length`. This is the case when initially called by {sort} and is preserved in + * subcalls. + */ + function _quickSort(uint256[] memory array, uint256 i, uint256 j) private pure { + unchecked { + // Can't overflow given `i <= j` + if (j - i < 2) return; + + // Use first element as pivot + uint256 pivot = unsafeMemoryAccess(array, i); + // Position where the pivot should be at the end of the loop + uint256 index = i; + + for (uint256 k = i + 1; k < j; ++k) { + // Unsafe access is safe given `k < j <= array.length`. + if (unsafeMemoryAccess(array, k) < pivot) { + // If array[k] is smaller than the pivot, we increment the index and move array[k] there. + _swap(array, ++index, k); + } + } + + // Swap pivot into place + _swap(array, i, index); + + _quickSort(array, i, index); // Sort the left side of the pivot + _quickSort(array, index + 1, j); // Sort the right side of the pivot + } + } + + /** + * @dev Swaps the elements at positions `i` and `j` in the `arr` array. + */ + function _swap(uint256[] memory arr, uint256 i, uint256 j) private pure { + assembly { + let start := add(arr, 0x20) // Pointer to the first element of the array + let pos_i := add(start, mul(i, 0x20)) + let pos_j := add(start, mul(j, 0x20)) + let val_i := mload(pos_i) + let val_j := mload(pos_j) + mstore(pos_i, val_j) + mstore(pos_j, val_i) + } + } + /** * @dev Searches a sorted `array` and returns the first index that contains * a value greater or equal to `element`. If no such index exists (i.e. all @@ -238,7 +301,7 @@ library Arrays { * * WARNING: Only use if you are certain `pos` is lower than the array length. */ - function unsafeMemoryAccess(uint256[] memory arr, uint256 pos) internal pure returns (uint256 res) { + function unsafeMemoryAccess(address[] memory arr, uint256 pos) internal pure returns (address res) { assembly { res := mload(add(add(arr, 0x20), mul(pos, 0x20))) } @@ -249,7 +312,18 @@ library Arrays { * * WARNING: Only use if you are certain `pos` is lower than the array length. */ - function unsafeMemoryAccess(address[] memory arr, uint256 pos) internal pure returns (address res) { + function unsafeMemoryAccess(bytes32[] memory arr, uint256 pos) internal pure returns (bytes32 res) { + assembly { + res := mload(add(add(arr, 0x20), mul(pos, 0x20))) + } + } + + /** + * @dev Access an array in an "unsafe" way. Skips solidity "index-out-of-range" check. + * + * WARNING: Only use if you are certain `pos` is lower than the array length. + */ + function unsafeMemoryAccess(uint256[] memory arr, uint256 pos) internal pure returns (uint256 res) { assembly { res := mload(add(add(arr, 0x20), mul(pos, 0x20))) } diff --git a/scripts/generate/templates/Checkpoints.t.js b/scripts/generate/templates/Checkpoints.t.js index 7e6a738dbeb..baea5c315b2 100644 --- a/scripts/generate/templates/Checkpoints.t.js +++ b/scripts/generate/templates/Checkpoints.t.js @@ -7,8 +7,8 @@ const header = `\ pragma solidity ^0.8.20; import {Test} from "forge-std/Test.sol"; -import {SafeCast} from "../../../contracts/utils/math/SafeCast.sol"; -import {Checkpoints} from "../../../contracts/utils/structs/Checkpoints.sol"; +import {SafeCast} from "@openzeppelin/contracts/utils/math/SafeCast.sol"; +import {Checkpoints} from "@openzeppelin/contracts/utils/structs/Checkpoints.sol"; `; /* eslint-disable max-len */ diff --git a/test/utils/Arrays.t.sol b/test/utils/Arrays.t.sol new file mode 100644 index 00000000000..c3d147562ce --- /dev/null +++ b/test/utils/Arrays.t.sol @@ -0,0 +1,15 @@ +// SPDX-License-Identifier: MIT + +pragma solidity ^0.8.20; + +import {Test} from "forge-std/Test.sol"; +import {Arrays} from "@openzeppelin/contracts/utils/Arrays.sol"; + +contract ArraysTest is Test { + function testSort(uint256[] memory values) public { + Arrays.sort(values); + for (uint256 i = 1; i < values.length; ++i) { + assertLe(values[i - 1], values[i]); + } + } +} diff --git a/test/utils/Arrays.test.js b/test/utils/Arrays.test.js index dc35e49447f..ffe5d5a22fe 100644 --- a/test/utils/Arrays.test.js +++ b/test/utils/Arrays.test.js @@ -16,9 +16,56 @@ const upperBound = (array, value) => { return i == -1 ? array.length : i; }; +// By default, js "sort" cast to string and then sort in alphabetical order. Use this to sort numbers. +const compareNumbers = (a, b) => (a > b ? 1 : a < b ? -1 : 0); + const hasDuplicates = array => array.some((v, i) => array.indexOf(v) != i); describe('Arrays', function () { + const fixture = async () => { + return { mock: await ethers.deployContract('$Arrays') }; + }; + + beforeEach(async function () { + Object.assign(this, await loadFixture(fixture)); + }); + + describe('sort', function () { + for (const length of [0, 1, 2, 8, 32, 128]) { + it(`sort array of length ${length}`, async function () { + this.elements = randomArray(generators.uint256, length); + this.expected = Array.from(this.elements).sort(compareNumbers); + }); + + if (length > 1) { + it(`sort array of length ${length} (identical elements)`, async function () { + this.elements = Array(length).fill(generators.uint256()); + this.expected = this.elements; + }); + + it(`sort array of length ${length} (already sorted)`, async function () { + this.elements = randomArray(generators.uint256, length).sort(compareNumbers); + this.expected = this.elements; + }); + + it(`sort array of length ${length} (sorted in reverse order)`, async function () { + this.elements = randomArray(generators.uint256, length).sort(compareNumbers).reverse(); + this.expected = Array.from(this.elements).reverse(); + }); + + it(`sort array of length ${length} (almost sorted)`, async function () { + this.elements = randomArray(generators.uint256, length).sort(compareNumbers); + this.expected = Array.from(this.elements); + // rotate (move the last element to the front) for an almost sorted effect + this.elements.unshift(this.elements.pop()); + }); + } + } + afterEach(async function () { + expect(await this.mock.$sort(this.elements)).to.deep.equal(this.expected); + }); + }); + describe('search', function () { for (const [title, { array, tests }] of Object.entries({ 'Even number of elements': { @@ -74,7 +121,7 @@ describe('Arrays', function () { })) { describe(title, function () { const fixture = async () => { - return { mock: await ethers.deployContract('Uint256ArraysMock', [array]) }; + return { instance: await ethers.deployContract('Uint256ArraysMock', [array]) }; }; beforeEach(async function () { @@ -86,20 +133,20 @@ describe('Arrays', function () { it('[deprecated] findUpperBound', async function () { // findUpperBound does not support duplicated if (hasDuplicates(array)) { - expect(await this.mock.findUpperBound(input)).to.be.equal(upperBound(array, input) - 1); + expect(await this.instance.findUpperBound(input)).to.equal(upperBound(array, input) - 1); } else { - expect(await this.mock.findUpperBound(input)).to.be.equal(lowerBound(array, input)); + expect(await this.instance.findUpperBound(input)).to.equal(lowerBound(array, input)); } }); it('lowerBound', async function () { - expect(await this.mock.lowerBound(input)).to.be.equal(lowerBound(array, input)); - expect(await this.mock.lowerBoundMemory(array, input)).to.be.equal(lowerBound(array, input)); + expect(await this.instance.lowerBound(input)).to.equal(lowerBound(array, input)); + expect(await this.instance.lowerBoundMemory(array, input)).to.equal(lowerBound(array, input)); }); it('upperBound', async function () { - expect(await this.mock.upperBound(input)).to.be.equal(upperBound(array, input)); - expect(await this.mock.upperBoundMemory(array, input)).to.be.equal(upperBound(array, input)); + expect(await this.instance.upperBound(input)).to.equal(upperBound(array, input)); + expect(await this.instance.upperBoundMemory(array, input)).to.equal(upperBound(array, input)); }); }); } @@ -108,28 +155,44 @@ describe('Arrays', function () { }); describe('unsafeAccess', function () { - for (const [title, { artifact, elements }] of Object.entries({ + for (const [type, { artifact, elements }] of Object.entries({ address: { artifact: 'AddressArraysMock', elements: randomArray(generators.address, 10) }, bytes32: { artifact: 'Bytes32ArraysMock', elements: randomArray(generators.bytes32, 10) }, uint256: { artifact: 'Uint256ArraysMock', elements: randomArray(generators.uint256, 10) }, })) { - describe(title, function () { - const fixture = async () => { - return { mock: await ethers.deployContract(artifact, [elements]) }; - }; + describe(type, function () { + describe('storage', function () { + const fixture = async () => { + return { instance: await ethers.deployContract(artifact, [elements]) }; + }; - beforeEach(async function () { - Object.assign(this, await loadFixture(fixture)); - }); + beforeEach(async function () { + Object.assign(this, await loadFixture(fixture)); + }); - for (const i in elements) { - it(`unsafeAccess within bounds #${i}`, async function () { - expect(await this.mock.unsafeAccess(i)).to.equal(elements[i]); + for (const i in elements) { + it(`unsafeAccess within bounds #${i}`, async function () { + expect(await this.instance.unsafeAccess(i)).to.equal(elements[i]); + }); + } + + it('unsafeAccess outside bounds', async function () { + await expect(this.instance.unsafeAccess(elements.length)).to.not.be.rejected; }); - } + }); + + describe('memory', function () { + const fragment = `$unsafeMemoryAccess(${type}[] arr, uint256 pos)`; - it('unsafeAccess outside bounds', async function () { - await expect(this.mock.unsafeAccess(elements.length)).to.not.be.rejected; + for (const i in elements) { + it(`unsafeMemoryAccess within bounds #${i}`, async function () { + expect(await this.mock[fragment](elements, i)).to.equal(elements[i]); + }); + } + + it('unsafeMemoryAccess outside bounds', async function () { + await expect(this.mock[fragment](elements, elements.length)).to.not.be.rejected; + }); }); }); } diff --git a/test/utils/Base64.t.sol b/test/utils/Base64.t.sol index 2e610ed75d4..021ae03af09 100644 --- a/test/utils/Base64.t.sol +++ b/test/utils/Base64.t.sol @@ -3,7 +3,6 @@ pragma solidity ^0.8.20; import {Test} from "forge-std/Test.sol"; - import {Base64} from "@openzeppelin/contracts/utils/Base64.sol"; contract Base64Test is Test { diff --git a/test/utils/structs/Checkpoints.t.sol b/test/utils/structs/Checkpoints.t.sol index 72d209f4cba..1f4b344c57f 100644 --- a/test/utils/structs/Checkpoints.t.sol +++ b/test/utils/structs/Checkpoints.t.sol @@ -4,8 +4,8 @@ pragma solidity ^0.8.20; import {Test} from "forge-std/Test.sol"; -import {SafeCast} from "../../../contracts/utils/math/SafeCast.sol"; -import {Checkpoints} from "../../../contracts/utils/structs/Checkpoints.sol"; +import {SafeCast} from "@openzeppelin/contracts/utils/math/SafeCast.sol"; +import {Checkpoints} from "@openzeppelin/contracts/utils/structs/Checkpoints.sol"; contract CheckpointsTrace224Test is Test { using Checkpoints for Checkpoints.Trace224;