From d404a2bf2f7e0ebdf9bd6d54a10b8fbe53506d43 Mon Sep 17 00:00:00 2001 From: Michael Heuer <20623991+Michael-A-Heuer@users.noreply.github.com> Date: Wed, 11 Oct 2023 10:53:06 +0200 Subject: [PATCH] feat: added semver comparision lib (#479) * feat: added semver comparision lib * docs: fix wrong NatSpec Co-authored-by: Mathias Scherer * fix: remove redundant await and async * style: move helpers to the file end --------- Co-authored-by: Mathias Scherer --- packages/contracts/CHANGELOG.md | 1 + packages/contracts/src/core/dao/DAO.sol | 6 +- .../test/utils/VersionComparisonLibTest.sol | 33 +++ .../utils/protocol/VersionComparisonLib.sol | 97 +++++++++ .../test/utils/version-comparison-lib.ts | 196 ++++++++++++++++++ 5 files changed, 331 insertions(+), 2 deletions(-) create mode 100644 packages/contracts/src/test/utils/VersionComparisonLibTest.sol create mode 100644 packages/contracts/src/utils/protocol/VersionComparisonLib.sol create mode 100644 packages/contracts/test/utils/version-comparison-lib.ts diff --git a/packages/contracts/CHANGELOG.md b/packages/contracts/CHANGELOG.md index e66d69bed..c926d5b9e 100644 --- a/packages/contracts/CHANGELOG.md +++ b/packages/contracts/CHANGELOG.md @@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- Added `VersionComparisonLib` to compare semantic versioning numbers. - Inherit `ProtocolVersion` in `Plugin`, `PluginCloneable`, `PluginUUPSUpgradeable`, `PluginSetup`, `PermissionCondition`, `PermissionConditionUpgradeable` `PluginSetupProcessor`, `PluginRepoRegistry`, `DAORegistry`, and `ENSSubdomainRegistrar`. - Added the `FunctionDeprecated` error to `DAO`. diff --git a/packages/contracts/src/core/dao/DAO.sol b/packages/contracts/src/core/dao/DAO.sol index 48ba76aa1..75ea8141d 100644 --- a/packages/contracts/src/core/dao/DAO.sol +++ b/packages/contracts/src/core/dao/DAO.sol @@ -15,6 +15,7 @@ import "@openzeppelin/contracts/interfaces/IERC1271.sol"; import {IProtocolVersion} from "../../utils/protocol/IProtocolVersion.sol"; import {ProtocolVersion} from "../../utils/protocol/ProtocolVersion.sol"; +import {VersionComparisonLib} from "../../utils/protocol/VersionComparisonLib.sol"; import {PermissionManager} from "../permission/PermissionManager.sol"; import {CallbackHandler} from "../utils/CallbackHandler.sol"; import {hasBit, flipBit} from "../utils/BitMap.sol"; @@ -39,6 +40,7 @@ contract DAO is { using SafeERC20Upgradeable for IERC20Upgradeable; using AddressUpgradeable for address; + using VersionComparisonLib for uint8[3]; /// @notice The ID of the permission required to call the `execute` function. bytes32 public constant EXECUTE_PERMISSION_ID = keccak256("EXECUTE_PERMISSION"); @@ -183,13 +185,13 @@ contract DAO is // Initialize `_reentrancyStatus` that was added in v1.3.0. // Register Interface `ProtocolVersion` that was added in v1.3.0. - if (_previousProtocolVersion[1] <= 2) { + if (_previousProtocolVersion.lt([1, 3, 0])) { _reentrancyStatus = _NOT_ENTERED; _registerInterface(type(IProtocolVersion).interfaceId); } // Revoke the `SET_SIGNATURE_VALIDATOR_PERMISSION` that was deprecated in v1.4.0. - if (_previousProtocolVersion[1] <= 3) { + if (_previousProtocolVersion.lt([1, 4, 0])) { _revoke({ _where: address(this), _who: address(this), diff --git a/packages/contracts/src/test/utils/VersionComparisonLibTest.sol b/packages/contracts/src/test/utils/VersionComparisonLibTest.sol new file mode 100644 index 000000000..da73ba22f --- /dev/null +++ b/packages/contracts/src/test/utils/VersionComparisonLibTest.sol @@ -0,0 +1,33 @@ +// SPDX-License-Identifier: AGPL-3.0-or-later + +pragma solidity 0.8.17; + +import {VersionComparisonLib} from "../../utils/protocol/VersionComparisonLib.sol"; + +contract VersionComparisonLibTest { + using VersionComparisonLib for uint8[3]; + + function eq(uint8[3] memory lhs, uint8[3] memory rhs) public pure returns (bool) { + return lhs.eq(rhs); + } + + function neq(uint8[3] memory lhs, uint8[3] memory rhs) public pure returns (bool) { + return lhs.neq(rhs); + } + + function lt(uint8[3] memory lhs, uint8[3] memory rhs) public pure returns (bool) { + return lhs.lt(rhs); + } + + function lte(uint8[3] memory lhs, uint8[3] memory rhs) public pure returns (bool) { + return lhs.lte(rhs); + } + + function gt(uint8[3] memory lhs, uint8[3] memory rhs) public pure returns (bool) { + return lhs.gt(rhs); + } + + function gte(uint8[3] memory lhs, uint8[3] memory rhs) public pure returns (bool) { + return lhs.gte(rhs); + } +} diff --git a/packages/contracts/src/utils/protocol/VersionComparisonLib.sol b/packages/contracts/src/utils/protocol/VersionComparisonLib.sol new file mode 100644 index 000000000..240449b24 --- /dev/null +++ b/packages/contracts/src/utils/protocol/VersionComparisonLib.sol @@ -0,0 +1,97 @@ +// SPDX-License-Identifier: AGPL-3.0-or-later + +pragma solidity ^0.8.8; + +library VersionComparisonLib { + /// @notice Equality comparator for two semantic version numbers. + /// @param lhs The left-hand side semantic version number. + /// @param rhs The right-hand side semantic version number. + /// @return Whether the two numbers are equal or not. + function eq(uint8[3] memory lhs, uint8[3] memory rhs) internal pure returns (bool) { + if (lhs[0] != rhs[0]) return false; + + if (lhs[1] != rhs[1]) return false; + + if (lhs[2] != rhs[2]) return false; + + return true; + } + + /// @notice Inequality comparator for two semantic version numbers. + /// @param lhs The left-hand side semantic version number. + /// @param rhs The right-hand side semantic version number. + /// @return Whether the two numbers are inequal or not. + function neq(uint8[3] memory lhs, uint8[3] memory rhs) internal pure returns (bool) { + if (lhs[0] != rhs[0]) return true; + + if (lhs[1] != rhs[1]) return true; + + if (lhs[2] != rhs[2]) return true; + + return false; + } + + /// @notice Less than comparator for two semantic version numbers. + /// @param lhs The left-hand side semantic version number. + /// @param rhs The right-hand side semantic version number. + /// @return Whether the first number is less than the second number or not. + function lt(uint8[3] memory lhs, uint8[3] memory rhs) internal pure returns (bool) { + if (lhs[0] < rhs[0]) return true; + if (lhs[0] > rhs[0]) return false; + + if (lhs[1] < rhs[1]) return true; + if (lhs[1] > rhs[1]) return false; + + if (lhs[2] < rhs[2]) return true; + + return false; + } + + /// @notice Less than or equal to comparator for two semantic version numbers. + /// @param lhs The left-hand side semantic version number. + /// @param rhs The right-hand side semantic version number. + /// @return Whether the first number is less than or equal to the second number or not. + function lte(uint8[3] memory lhs, uint8[3] memory rhs) internal pure returns (bool) { + if (lhs[0] < rhs[0]) return true; + if (lhs[0] > rhs[0]) return false; + + if (lhs[1] < rhs[1]) return true; + if (lhs[1] > rhs[1]) return false; + + if (lhs[2] <= rhs[2]) return true; + + return false; + } + + /// @notice Greater than comparator for two semantic version numbers. + /// @param lhs The left-hand side semantic version number. + /// @param rhs The right-hand side semantic version number. + /// @return Whether the first number is greater than the second number or not. + function gt(uint8[3] memory lhs, uint8[3] memory rhs) internal pure returns (bool) { + if (lhs[0] > rhs[0]) return true; + if (lhs[0] < rhs[0]) return false; + + if (lhs[1] > rhs[1]) return true; + if (lhs[1] < rhs[1]) return false; + + if (lhs[2] > rhs[2]) return true; + + return false; + } + + /// @notice Greater than or equal to comparator for two semantic version numbers. + /// @param lhs The left-hand side semantic version number. + /// @param rhs The right-hand side semantic version number. + /// @return Whether the first number is greater than or equal to the second number or not. + function gte(uint8[3] memory lhs, uint8[3] memory rhs) internal pure returns (bool) { + if (lhs[0] > rhs[0]) return true; + if (lhs[0] < rhs[0]) return false; + + if (lhs[1] > rhs[1]) return true; + if (lhs[1] < rhs[1]) return false; + + if (lhs[2] >= rhs[2]) return true; + + return false; + } +} diff --git a/packages/contracts/test/utils/version-comparison-lib.ts b/packages/contracts/test/utils/version-comparison-lib.ts new file mode 100644 index 000000000..4416a9210 --- /dev/null +++ b/packages/contracts/test/utils/version-comparison-lib.ts @@ -0,0 +1,196 @@ +import {expect} from 'chai'; +import {ethers} from 'hardhat'; + +import { + VersionComparisonLibTest, + VersionComparisonLibTest__factory, +} from '../../typechain'; + +type SemVer = [number, number, number]; + +describe('VersionComparisonLib', function () { + let cmp: VersionComparisonLibTest; + + before(async () => { + const signers = await ethers.getSigners(); + cmp = await new VersionComparisonLibTest__factory(signers[0]).deploy(); + }); + + describe('eq', async () => { + function eq(lhs: SemVer, rhs: SemVer): Promise { + return cmp.eq(lhs, rhs); + } + + it('returns true if lhs equals rhs', async () => { + await eqChecks(eq, true); + }); + + it('returns false if lhs does not equal rhs', async () => { + await ltChecks(eq, false); + await gtChecks(eq, false); + }); + }); + + describe('neq', async () => { + function neq(lhs: SemVer, rhs: SemVer): Promise { + return cmp.neq(lhs, rhs); + } + + it('returns true if lhs does not equal rhs', async () => { + await ltChecks(neq, true); + await gtChecks(neq, true); + }); + + it('returns false if lhs equals rhs', async () => { + await eqChecks(neq, false); + }); + }); + + describe('lt', async () => { + function lt(lhs: SemVer, rhs: SemVer): Promise { + return cmp.lt(lhs, rhs); + } + + it('returns true if lhs is less than rhs', async () => { + await ltChecks(lt, true); + }); + + it('returns false if lhs is not less than rhs', async () => { + await gtChecks(lt, false); + await eqChecks(lt, false); + }); + }); + + describe('lte', async () => { + function lte(lhs: SemVer, rhs: SemVer): Promise { + return cmp.lte(lhs, rhs); + } + + it('returns true if lhs is less than or equal to rhs', async () => { + await ltChecks(lte, true); + await eqChecks(lte, true); + }); + + it('returns false if lhs is not less than or equal to rhs', async () => { + await gtChecks(lte, false); + }); + }); + + describe('gt', async () => { + function gt(lhs: SemVer, rhs: SemVer): Promise { + return cmp.gt(lhs, rhs); + } + + it('returns true if lhs is greater than rhs', async () => { + await gtChecks(gt, true); + }); + + it('returns false if lhs is not greater than rhs', async () => { + await ltChecks(gt, false); + await eqChecks(gt, false); + }); + }); + + describe('gte', async () => { + function gte(lhs: SemVer, rhs: SemVer): Promise { + return cmp.gte(lhs, rhs); + } + + it('returns true if lhs is greater than or equal to rhs', async () => { + await gtChecks(gte, true); + await eqChecks(gte, true); + }); + + it('returns false if lhs is not greater than or equal to rhs', async () => { + await ltChecks(gte, false); + }); + }); +}); + +async function eqChecks( + func: (lhs: SemVer, rhs: SemVer) => Promise, + expected: boolean +) { + const results: boolean[] = await Promise.all([ + func([1, 1, 1], [1, 1, 1]), + // + func([0, 1, 1], [0, 1, 1]), + func([1, 0, 1], [1, 0, 1]), + func([1, 1, 0], [1, 1, 0]), + // + func([1, 0, 0], [1, 0, 0]), + func([0, 1, 0], [0, 1, 0]), + func([0, 0, 1], [0, 0, 1]), + // + func([0, 0, 0], [0, 0, 0]), + ]); + + // Check that all results match the expected value + expect(results.every(v => v === expected)).to.be.true; +} + +async function ltChecks( + func: (lhs: SemVer, rhs: SemVer) => Promise, + expected: boolean +) { + const results: boolean[] = await Promise.all([ + func([1, 1, 1], [2, 1, 1]), + func([1, 1, 1], [1, 2, 1]), + func([1, 1, 1], [1, 1, 2]), + // + func([1, 1, 1], [1, 2, 2]), + func([1, 1, 1], [2, 1, 2]), + func([1, 1, 1], [2, 2, 1]), + // + func([1, 1, 1], [2, 2, 2]), + // + func([1, 1, 0], [1, 2, 0]), + func([1, 1, 0], [2, 1, 0]), + // + func([1, 1, 0], [2, 2, 0]), + // + func([0, 1, 1], [0, 1, 2]), + func([0, 1, 1], [0, 2, 1]), + // + func([0, 1, 1], [0, 2, 2]), + // + func([1, 0, 0], [2, 0, 0]), + func([0, 1, 0], [0, 2, 0]), + func([0, 0, 1], [0, 0, 2]), + ]); + + // Check that all results match the expected value + expect(results.every(v => v === expected)).to.be.true; +} + +async function gtChecks( + func: (lhs: SemVer, rhs: SemVer) => Promise, + expected: boolean +) { + const results: boolean[] = await Promise.all([ + func([2, 1, 1], [1, 1, 1]), + func([1, 2, 1], [1, 1, 1]), + func([1, 1, 2], [1, 1, 1]), + // + func([1, 2, 2], [1, 1, 1]), + func([2, 1, 2], [1, 1, 1]), + func([2, 2, 1], [1, 1, 1]), + // + func([2, 2, 2], [1, 1, 1]), + // + func([1, 2, 0], [1, 1, 0]), + func([2, 1, 0], [1, 1, 0]), + func([2, 2, 0], [1, 1, 0]), + // + func([0, 1, 2], [0, 1, 1]), + func([0, 2, 1], [0, 1, 1]), + func([0, 2, 2], [0, 1, 1]), + // + func([2, 0, 0], [1, 0, 0]), + func([0, 2, 0], [0, 1, 0]), + func([0, 0, 2], [0, 0, 1]), + ]); + + // Check that all results match the expected value + expect(results.every(v => v === expected)).to.be.true; +}