Skip to content

Commit 61117c4

Browse files
AmxxRenanSouza2ernestognw
authored
Bound lookup in arrays with duplicate (#4842)
Co-authored-by: RenanSouza2 <renan.rodrigues.souza1@gmail.com> Co-authored-by: ernestognw <ernestognw@gmail.com>
1 parent 7439664 commit 61117c4

File tree

6 files changed

+204
-35
lines changed

6 files changed

+204
-35
lines changed

.changeset/flat-turtles-repeat.md

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
'openzeppelin-solidity': minor
3+
---
4+
5+
`Arrays`: deprecate `findUpperBound` in favor of the new `lowerBound`.

.changeset/thick-pumpkins-report.md

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
'openzeppelin-solidity': minor
3+
---
4+
5+
`Arrays`: add new functions `lowerBound`, `upperBound`, `lowerBoundMemory` and `upperBoundMemory` for lookups in sorted arrays with potential duplicates.

contracts/mocks/ArraysMock.sol

+18-2
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,24 @@ contract Uint256ArraysMock {
1313
_array = array;
1414
}
1515

16-
function findUpperBound(uint256 element) external view returns (uint256) {
17-
return _array.findUpperBound(element);
16+
function findUpperBound(uint256 value) external view returns (uint256) {
17+
return _array.findUpperBound(value);
18+
}
19+
20+
function lowerBound(uint256 value) external view returns (uint256) {
21+
return _array.lowerBound(value);
22+
}
23+
24+
function upperBound(uint256 value) external view returns (uint256) {
25+
return _array.upperBound(value);
26+
}
27+
28+
function lowerBoundMemory(uint256[] memory array, uint256 value) external pure returns (uint256) {
29+
return array.lowerBoundMemory(value);
30+
}
31+
32+
function upperBoundMemory(uint256[] memory array, uint256 value) external pure returns (uint256) {
33+
return array.upperBoundMemory(value);
1834
}
1935

2036
function unsafeAccess(uint256 pos) external view returns (uint256) {

contracts/utils/Arrays.sol

+132-2
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,12 @@ library Arrays {
1818
* values in the array are strictly less than `element`), the array length is
1919
* returned. Time complexity O(log n).
2020
*
21-
* `array` is expected to be sorted in ascending order, and to contain no
22-
* repeated elements.
21+
* NOTE: The `array` is expected to be sorted in ascending order, and to
22+
* contain no repeated elements.
23+
*
24+
* IMPORTANT: Deprecated. This implementation behaves as {lowerBound} but lacks
25+
* support for repeated elements in the array. The {lowerBound} function should
26+
* be used instead.
2327
*/
2428
function findUpperBound(uint256[] storage array, uint256 element) internal view returns (uint256) {
2529
uint256 low = 0;
@@ -49,6 +53,132 @@ library Arrays {
4953
}
5054
}
5155

56+
/**
57+
* @dev Searches an `array` sorted in ascending order and returns the first
58+
* index that contains a value greater or equal than `element`. If no such index
59+
* exists (i.e. all values in the array are strictly less than `element`), the array
60+
* length is returned. Time complexity O(log n).
61+
*
62+
* See C++'s https://en.cppreference.com/w/cpp/algorithm/lower_bound[lower_bound].
63+
*/
64+
function lowerBound(uint256[] storage array, uint256 element) internal view returns (uint256) {
65+
uint256 low = 0;
66+
uint256 high = array.length;
67+
68+
if (high == 0) {
69+
return 0;
70+
}
71+
72+
while (low < high) {
73+
uint256 mid = Math.average(low, high);
74+
75+
// Note that mid will always be strictly less than high (i.e. it will be a valid array index)
76+
// because Math.average rounds towards zero (it does integer division with truncation).
77+
if (unsafeAccess(array, mid).value < element) {
78+
// this cannot overflow because mid < high
79+
unchecked {
80+
low = mid + 1;
81+
}
82+
} else {
83+
high = mid;
84+
}
85+
}
86+
87+
return low;
88+
}
89+
90+
/**
91+
* @dev Searches an `array` sorted in ascending order and returns the first
92+
* index that contains a value strictly greater than `element`. If no such index
93+
* exists (i.e. all values in the array are strictly less than `element`), the array
94+
* length is returned. Time complexity O(log n).
95+
*
96+
* See C++'s https://en.cppreference.com/w/cpp/algorithm/upper_bound[upper_bound].
97+
*/
98+
function upperBound(uint256[] storage array, uint256 element) internal view returns (uint256) {
99+
uint256 low = 0;
100+
uint256 high = array.length;
101+
102+
if (high == 0) {
103+
return 0;
104+
}
105+
106+
while (low < high) {
107+
uint256 mid = Math.average(low, high);
108+
109+
// Note that mid will always be strictly less than high (i.e. it will be a valid array index)
110+
// because Math.average rounds towards zero (it does integer division with truncation).
111+
if (unsafeAccess(array, mid).value > element) {
112+
high = mid;
113+
} else {
114+
// this cannot overflow because mid < high
115+
unchecked {
116+
low = mid + 1;
117+
}
118+
}
119+
}
120+
121+
return low;
122+
}
123+
124+
/**
125+
* @dev Same as {lowerBound}, but with an array in memory.
126+
*/
127+
function lowerBoundMemory(uint256[] memory array, uint256 element) internal pure returns (uint256) {
128+
uint256 low = 0;
129+
uint256 high = array.length;
130+
131+
if (high == 0) {
132+
return 0;
133+
}
134+
135+
while (low < high) {
136+
uint256 mid = Math.average(low, high);
137+
138+
// Note that mid will always be strictly less than high (i.e. it will be a valid array index)
139+
// because Math.average rounds towards zero (it does integer division with truncation).
140+
if (unsafeMemoryAccess(array, mid) < element) {
141+
// this cannot overflow because mid < high
142+
unchecked {
143+
low = mid + 1;
144+
}
145+
} else {
146+
high = mid;
147+
}
148+
}
149+
150+
return low;
151+
}
152+
153+
/**
154+
* @dev Same as {upperBound}, but with an array in memory.
155+
*/
156+
function upperBoundMemory(uint256[] memory array, uint256 element) internal pure returns (uint256) {
157+
uint256 low = 0;
158+
uint256 high = array.length;
159+
160+
if (high == 0) {
161+
return 0;
162+
}
163+
164+
while (low < high) {
165+
uint256 mid = Math.average(low, high);
166+
167+
// Note that mid will always be strictly less than high (i.e. it will be a valid array index)
168+
// because Math.average rounds towards zero (it does integer division with truncation).
169+
if (unsafeMemoryAccess(array, mid) > element) {
170+
high = mid;
171+
} else {
172+
// this cannot overflow because mid < high
173+
unchecked {
174+
low = mid + 1;
175+
}
176+
}
177+
}
178+
179+
return low;
180+
}
181+
52182
/**
53183
* @dev Access an array in an "unsafe" way. Skips solidity "index-out-of-range" check.
54184
*

test/utils/Arrays.test.js

+41-26
Original file line numberDiff line numberDiff line change
@@ -4,22 +4,22 @@ const { loadFixture } = require('@nomicfoundation/hardhat-network-helpers');
44

55
const { randomArray, generators } = require('../helpers/random');
66

7-
// See https://en.cppreference.com/w/cpp/algorithm/ranges/lower_bound
7+
// See https://en.cppreference.com/w/cpp/algorithm/lower_bound
88
const lowerBound = (array, value) => {
99
const i = array.findIndex(element => value <= element);
1010
return i == -1 ? array.length : i;
1111
};
1212

1313
// See https://en.cppreference.com/w/cpp/algorithm/upper_bound
14-
// const upperBound = (array, value) => {
15-
// const i = array.findIndex(element => value < element);
16-
// return i == -1 ? array.length : i;
17-
// };
14+
const upperBound = (array, value) => {
15+
const i = array.findIndex(element => value < element);
16+
return i == -1 ? array.length : i;
17+
};
1818

1919
const hasDuplicates = array => array.some((v, i) => array.indexOf(v) != i);
2020

2121
describe('Arrays', function () {
22-
describe('findUpperBound', function () {
22+
describe('search', function () {
2323
for (const [title, { array, tests }] of Object.entries({
2424
'Even number of elements': {
2525
array: [11n, 12n, 13n, 14n, 15n, 16n, 17n, 18n, 19n, 20n],
@@ -82,40 +82,55 @@ describe('Arrays', function () {
8282
});
8383

8484
for (const [name, input] of Object.entries(tests)) {
85-
it(name, async function () {
86-
// findUpperBound does not support duplicated
87-
if (hasDuplicates(array)) this.skip();
88-
expect(await this.mock.findUpperBound(input)).to.equal(lowerBound(array, input));
85+
describe(name, function () {
86+
it('[deprecated] findUpperBound', async function () {
87+
// findUpperBound does not support duplicated
88+
if (hasDuplicates(array)) {
89+
expect(await this.mock.findUpperBound(input)).to.be.equal(upperBound(array, input) - 1);
90+
} else {
91+
expect(await this.mock.findUpperBound(input)).to.be.equal(lowerBound(array, input));
92+
}
93+
});
94+
95+
it('lowerBound', async function () {
96+
expect(await this.mock.lowerBound(input)).to.be.equal(lowerBound(array, input));
97+
expect(await this.mock.lowerBoundMemory(array, input)).to.be.equal(lowerBound(array, input));
98+
});
99+
100+
it('upperBound', async function () {
101+
expect(await this.mock.upperBound(input)).to.be.equal(upperBound(array, input));
102+
expect(await this.mock.upperBoundMemory(array, input)).to.be.equal(upperBound(array, input));
103+
});
89104
});
90105
}
91106
});
92107
}
93108
});
94109

95110
describe('unsafeAccess', function () {
96-
const contractCases = {
111+
for (const [title, { artifact, elements }] of Object.entries({
97112
address: { artifact: 'AddressArraysMock', elements: randomArray(generators.address, 10) },
98113
bytes32: { artifact: 'Bytes32ArraysMock', elements: randomArray(generators.bytes32, 10) },
99114
uint256: { artifact: 'Uint256ArraysMock', elements: randomArray(generators.uint256, 10) },
100-
};
101-
102-
const fixture = async () => {
103-
const contracts = {};
104-
for (const [name, { artifact, elements }] of Object.entries(contractCases)) {
105-
contracts[name] = await ethers.deployContract(artifact, [elements]);
106-
}
107-
return { contracts };
108-
};
115+
})) {
116+
describe(title, function () {
117+
const fixture = async () => {
118+
return { mock: await ethers.deployContract(artifact, [elements]) };
119+
};
109120

110-
beforeEach(async function () {
111-
Object.assign(this, await loadFixture(fixture));
112-
});
121+
beforeEach(async function () {
122+
Object.assign(this, await loadFixture(fixture));
123+
});
113124

114-
for (const [name, { elements }] of Object.entries(contractCases)) {
115-
it(name, async function () {
116125
for (const i in elements) {
117-
expect(await this.contracts[name].unsafeAccess(i)).to.equal(elements[i]);
126+
it(`unsafeAccess within bounds #${i}`, async function () {
127+
expect(await this.mock.unsafeAccess(i)).to.equal(elements[i]);
128+
});
118129
}
130+
131+
it('unsafeAccess outside bounds', async function () {
132+
await expect(this.mock.unsafeAccess(elements.length)).to.not.be.rejected;
133+
});
119134
});
120135
}
121136
});

test/utils/structs/Checkpoints.test.js

+3-5
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,6 @@ const { loadFixture } = require('@nomicfoundation/hardhat-network-helpers');
44

55
const { VALUE_SIZES } = require('../../../scripts/generate/templates/Checkpoints.opts');
66

7-
const last = array => (array.length ? array[array.length - 1] : undefined);
8-
97
describe('Checkpoints', function () {
108
for (const length of VALUE_SIZES) {
119
describe(`Trace${length}`, function () {
@@ -81,7 +79,7 @@ describe('Checkpoints', function () {
8179
it('returns latest value', async function () {
8280
const latest = this.checkpoints.at(-1);
8381
expect(await this.methods.latest()).to.equal(latest.value);
84-
expect(await this.methods.latestCheckpoint()).to.have.ordered.members([true, latest.key, latest.value]);
82+
expect(await this.methods.latestCheckpoint()).to.deep.equal([true, latest.key, latest.value]);
8583
});
8684

8785
it('cannot push values in the past', async function () {
@@ -115,7 +113,7 @@ describe('Checkpoints', function () {
115113

116114
it('upper lookup & upperLookupRecent', async function () {
117115
for (let i = 0; i < 14; ++i) {
118-
const value = last(this.checkpoints.filter(x => i >= x.key))?.value || 0n;
116+
const value = this.checkpoints.findLast(x => i >= x.key)?.value || 0n;
119117

120118
expect(await this.methods.upperLookup(i)).to.equal(value);
121119
expect(await this.methods.upperLookupRecent(i)).to.equal(value);
@@ -137,7 +135,7 @@ describe('Checkpoints', function () {
137135
}
138136

139137
for (let i = 0; i < 25; ++i) {
140-
const value = last(allCheckpoints.filter(x => i >= x.key))?.value || 0n;
138+
const value = allCheckpoints.findLast(x => i >= x.key)?.value || 0n;
141139
expect(await this.methods.upperLookup(i)).to.equal(value);
142140
expect(await this.methods.upperLookupRecent(i)).to.equal(value);
143141
}

0 commit comments

Comments
 (0)