Skip to content

Commit

Permalink
Branchless ternary, min and max methods (#4976)
Browse files Browse the repository at this point in the history
Co-authored-by: Hadrien Croubois <hadrien.croubois@gmail.com>
Co-authored-by: Ernesto García <ernestognw@gmail.com>
  • Loading branch information
3 people authored Apr 23, 2024
1 parent 60afc99 commit 4032b42
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 7 deletions.
5 changes: 5 additions & 0 deletions .changeset/spotty-falcons-explain.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
'openzeppelin-solidity': minor
---

`Math`, `SignedMath`: Add a branchless `ternary` function that computes`cond ? a : b` in constant gas cost.
26 changes: 21 additions & 5 deletions contracts/utils/math/Math.sol
Original file line number Diff line number Diff line change
Expand Up @@ -73,18 +73,34 @@ library Math {
}
}

/**
* @dev Branchless ternary evaluation for `a ? b : c`. Gas costs are constant.
*
* IMPORTANT: This function may reduce bytecode size and consume less gas when used standalone.
* However, the compiler may optimize Solidity ternary operations (i.e. `a ? b : c`) to only compute
* one branch when needed, making this function more expensive.
*/
function ternary(bool condition, uint256 a, uint256 b) internal pure returns (uint256) {
unchecked {
// branchless ternary works because:
// b ^ (a ^ b) == a
// b ^ 0 == b
return b ^ ((a ^ b) * SafeCast.toUint(condition));
}
}

/**
* @dev Returns the largest of two numbers.
*/
function max(uint256 a, uint256 b) internal pure returns (uint256) {
return a > b ? a : b;
return ternary(a > b, a, b);
}

/**
* @dev Returns the smallest of two numbers.
*/
function min(uint256 a, uint256 b) internal pure returns (uint256) {
return a < b ? a : b;
return ternary(a < b, a, b);
}

/**
Expand Down Expand Up @@ -114,7 +130,7 @@ library Math {
// but the largest value we can obtain is type(uint256).max - 1, which happens
// when a = type(uint256).max and b = 1.
unchecked {
return a == 0 ? 0 : (a - 1) / b + 1;
return SafeCast.toUint(a > 0) * ((a - 1) / b + 1);
}
}

Expand Down Expand Up @@ -147,7 +163,7 @@ library Math {

// Make sure the result is less than 2²⁵⁶. Also prevents denominator == 0.
if (denominator <= prod1) {
Panic.panic(denominator == 0 ? Panic.DIVISION_BY_ZERO : Panic.UNDER_OVERFLOW);
Panic.panic(ternary(denominator == 0, Panic.DIVISION_BY_ZERO, Panic.UNDER_OVERFLOW));
}

///////////////////////////////////////////////
Expand Down Expand Up @@ -268,7 +284,7 @@ library Math {
}

if (gcd != 1) return 0; // No inverse exists.
return x < 0 ? (n - uint256(-x)) : uint256(x); // Wrap the result if it's negative.
return ternary(x < 0, n - uint256(-x), uint256(x)); // Wrap the result if it's negative.
}
}

Expand Down
22 changes: 20 additions & 2 deletions contracts/utils/math/SignedMath.sol
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,40 @@

pragma solidity ^0.8.20;

import {SafeCast} from "./SafeCast.sol";

/**
* @dev Standard signed math utilities missing in the Solidity language.
*/
library SignedMath {
/**
* @dev Branchless ternary evaluation for `a ? b : c`. Gas costs are constant.
*
* IMPORTANT: This function may reduce bytecode size and consume less gas when used standalone.
* However, the compiler may optimize Solidity ternary operations (i.e. `a ? b : c`) to only compute
* one branch when needed, making this function more expensive.
*/
function ternary(bool condition, int256 a, int256 b) internal pure returns (int256) {
unchecked {
// branchless terinary works because:
// b ^ (a ^ b) == a
// b ^ 0 == b
return b ^ ((a ^ b) * int256(SafeCast.toUint(condition)));
}
}

/**
* @dev Returns the largest of two signed numbers.
*/
function max(int256 a, int256 b) internal pure returns (int256) {
return a > b ? a : b;
return ternary(a > b, a, b);
}

/**
* @dev Returns the smallest of two signed numbers.
*/
function min(int256 a, int256 b) internal pure returns (int256) {
return a < b ? a : b;
return ternary(a < b, a, b);
}

/**
Expand Down
10 changes: 10 additions & 0 deletions test/utils/math/Math.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,16 @@ import {Test, stdError} from "forge-std/Test.sol";
import {Math} from "@openzeppelin/contracts/utils/math/Math.sol";

contract MathTest is Test {
function testSelect(bool f, uint256 a, uint256 b) public {
assertEq(Math.ternary(f, a, b), f ? a : b);
}

// MIN & MAX
function testMinMax(uint256 a, uint256 b) public {
assertEq(Math.min(a, b), a < b ? a : b);
assertEq(Math.max(a, b), a > b ? a : b);
}

// CEILDIV
function testCeilDiv(uint256 a, uint256 b) public {
vm.assume(b > 0);
Expand Down
10 changes: 10 additions & 0 deletions test/utils/math/SignedMath.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,16 @@ import {Math} from "../../../contracts/utils/math/Math.sol";
import {SignedMath} from "../../../contracts/utils/math/SignedMath.sol";

contract SignedMathTest is Test {
function testSelect(bool f, int256 a, int256 b) public {
assertEq(SignedMath.ternary(f, a, b), f ? a : b);
}

// MIN & MAX
function testMinMax(int256 a, int256 b) public {
assertEq(SignedMath.min(a, b), a < b ? a : b);
assertEq(SignedMath.max(a, b), a > b ? a : b);
}

// MIN
function testMin(int256 a, int256 b) public {
int256 result = SignedMath.min(a, b);
Expand Down

0 comments on commit 4032b42

Please sign in to comment.