Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

perf wExp (1) #11

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 4 additions & 6 deletions src/irm/Irm.sol
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,8 @@ contract Irm is IIrm {

/// @notice Constructor.
/// @param morpho The address of Morpho.
/// @param lnJumpFactor The log of the jump factor (scaled by WAD). Warning: lnJumpFactor <= 3 must hold. Above
/// that, the approximations in wExp are considered too large.
/// @param speedFactor The speed factor (scaled by WAD). Warning: |speedFactor * error * elapsed| <= 3 must hold.
/// Above that, the approximations in wExp are considered too large.
/// @param lnJumpFactor The log of the jump factor (scaled by WAD).
/// @param speedFactor The speed factor (scaled by WAD).
/// @param targetUtilization The target utilization (scaled by WAD). Should be between 0 and 1.
/// @param initialRate The initial rate (scaled by WAD).
constructor(
Expand Down Expand Up @@ -111,13 +109,13 @@ contract Irm is IIrm {
int256 errDelta = err - marketIrm[id].prevErr;

// Safe "unchecked" cast because LN_JUMP_FACTOR <= type(int256).max.
uint256 jumpMultiplier = MathLib.wExp12(errDelta.wMulDown(int256(LN_JUMP_FACTOR)));
uint256 jumpMultiplier = MathLib.wExp(errDelta.wMulDown(int256(LN_JUMP_FACTOR)));
// Safe "unchecked" cast because SPEED_FACTOR <= type(int256).max.
int256 speed = int256(SPEED_FACTOR).wMulDown(err);
uint256 elapsed = block.timestamp - market.lastUpdate;
// Safe "unchecked" cast because elapsed <= block.timestamp.
int256 linearVariation = speed * int256(elapsed);
uint256 variationMultiplier = MathLib.wExp12(linearVariation);
uint256 variationMultiplier = MathLib.wExp(linearVariation);

// newBorrowRate = prevBorrowRate * jumpMultiplier * variationMultiplier.
uint256 borrowRateAfterJump = marketIrm[id].prevBorrowRate.wMulDown(jumpMultiplier);
Expand Down
69 changes: 53 additions & 16 deletions src/irm/libraries/MathLib.sol
Original file line number Diff line number Diff line change
Expand Up @@ -11,23 +11,60 @@ library MathLib {
using {wDivDown} for int256;
using {wMulDown} for int256;

/// @dev 12th-order Taylor polynomial of e^x, for x around 0.
/// @dev The input is limited to a range between -3 and 3.
/// @dev The approximation error is less than 1% between -3 and 3.
function wExp12(int256 x) internal pure returns (uint256) {
x = x >= -3 * WAD_INT ? x : -3 * WAD_INT;
x = x <= 3 * WAD_INT ? x : 3 * WAD_INT;

// `N` should be even otherwise the result can be negative.
int256 N = 12;
int256 res = WAD_INT;
int256 monomial = WAD_INT;
for (int256 k = 1; k <= N; k++) {
monomial = monomial.wMulDown(x) / k;
res += monomial;
/// @dev Returns an approximation of exp.
/// @dev Greatly inspired by https://xn--2-umb.com/22/exp-ln/#fixed-point-exponential-function
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it seems this is excatly solmate's version why not importing it instead? Do we want to do the same as in blue?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmmm no not really Solmate's one is in full assembly

Copy link
Contributor

@MerlinEgalite MerlinEgalite Sep 4, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not full assmebly there's just the sdiv

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes mb

function wExp(int256 x) internal pure returns (uint256) {
unchecked {
int256 r;

// When the result is < 0.5 we return zero. This happens when x <= floor(log(0.5e18) * 1e18) ~ -42e18
if (x <= -42139678854452767551) return 0;

// When the result is > (2**255 - 1) / 1e18 we can not represent it as an int. This happens when x >=
// floor(log((2**255 - 1) / 1e18) * 1e18) ~ 135.
if (x >= 135305999368893231589) revert("EXP_OVERFLOW");

// x is now in the range (-42, 136) * 1e18. Convert to (-42, 136) * 2**96 for more intermediate precision
// and a binary basis. This base conversion is a multiplication by 1e18 / 2**96 = 5**18 / 2**78.
x = (x << 78) / 5 ** 18;

// Reduce range of x to (-½ ln 2, ½ ln 2) * 2**96 by factoring out powers of two such that exp(x) =
// exp(x') * 2**k, where k is an integer. Solving this gives k = round(x / log(2)) and x' = x - k * log(2).
int256 k = ((x << 96) / 54916777467707473351141471128 + 2 ** 95) >> 96;
x = x - k * 54916777467707473351141471128;

// k is in the range [-61, 195].

// Evaluate using a (6, 7)-term rational approximation. p is made monic, we'll multiply by a scale factor
// later.
int256 y = x + 1346386616545796478920950773328;
y = ((y * x) >> 96) + 57155421227552351082224309758442;
int256 p = y + x - 94201549194550492254356042504812;
p = ((p * y) >> 96) + 28719021644029726153956944680412240;
p = p * x + (4385272521454847904659076985693276 << 96);

// We leave p in 2**192 basis so we don't need to scale it back up for the division.
int256 q = x - 2855989394907223263936484059900;
q = ((q * x) >> 96) + 50020603652535783019961831881945;
q = ((q * x) >> 96) - 533845033583426703283633433725380;
q = ((q * x) >> 96) + 3604857256930695427073651918091429;
q = ((q * x) >> 96) - 14423608567350463180887372962807573;
q = ((q * x) >> 96) + 26449188498355588339934803723976023;

r = p / q;

// r should be in the range (0.09, 0.25) * 2**96.

// We now need to multiply r by:
// * the scale factor s = ~6.031367120.
// * the 2**k factor from the range reduction.
// * the 1e18 / 2**96 factor for base conversion.
// We do this all at once, with an intermediate result in 2**213 basis, so the final right shift is always
// by a positive amount.
r = int256((uint256(r) * 3822833074963236453042738258902158003155416615667) >> uint256(195 - k));

return uint256(r);
}
// Safe "unchecked" cast because `N` is even.
return uint256(res);
}

function wMulDown(int256 a, int256 b) internal pure returns (int256) {
Expand Down
4 changes: 2 additions & 2 deletions test/irm/IrmTest.sol
Original file line number Diff line number Diff line change
Expand Up @@ -169,9 +169,9 @@ contract IrmTest is Test {
int256 errDelta = err - prevErr;
uint256 elapsed = block.timestamp - market1.lastUpdate;

uint256 jumpMultiplier = MathLib.wExp12(errDelta.wMulDown(int256(LN2)));
uint256 jumpMultiplier = MathLib.wExp(errDelta.wMulDown(int256(LN2)));
int256 speed = int256(SPEED_FACTOR).wMulDown(err);
uint256 variationMultiplier = MathLib.wExp12(speed * int256(elapsed));
uint256 variationMultiplier = MathLib.wExp(speed * int256(elapsed));
uint256 expectedBorrowRateAfterJump = INITIAL_RATE.wMulDown(jumpMultiplier);
uint256 expectedNewBorrowRate = INITIAL_RATE.wMulDown(jumpMultiplier).wMulDown(variationMultiplier);

Expand Down
27 changes: 15 additions & 12 deletions test/irm/MathLibTest.sol
Original file line number Diff line number Diff line change
Expand Up @@ -9,22 +9,25 @@ contract MathLibTest is Test {
using MathLib for uint256;

function testWExp() public {
assertEq(MathLib.wExp12(-5 ether), MathLib.wExp12(-3 ether));
assertApproxEqRel(MathLib.wExp12(-3 ether), 0.04978706836 ether, 0.005 ether);
assertApproxEqRel(MathLib.wExp12(-2 ether), 0.13533528323 ether, 0.00001 ether);
assertApproxEqRel(MathLib.wExp12(-1 ether), 0.36787944117 ether, 0.00000001 ether);
assertEq(MathLib.wExp12(0 ether), 1.0 ether);
assertApproxEqRel(MathLib.wExp12(1 ether), 2.71828182846 ether, 0.00000001 ether);
assertApproxEqRel(MathLib.wExp12(2 ether), 7.38905609893 ether, 0.00001 ether);
assertApproxEqRel(MathLib.wExp12(3 ether), 20.0855369232 ether, 0.001 ether);
assertEq(MathLib.wExp12(5 ether), MathLib.wExp12(3 ether));
assertApproxEqRel(MathLib.wExp(-3 ether), 0.04978706836 ether, 0.005 ether);
assertApproxEqRel(MathLib.wExp(-2 ether), 0.13533528323 ether, 0.00001 ether);
assertApproxEqRel(MathLib.wExp(-1 ether), 0.36787944117 ether, 0.00000001 ether);
assertEq(MathLib.wExp(0 ether), 1.0 ether);
assertApproxEqRel(MathLib.wExp(1 ether), 2.71828182846 ether, 0.00000001 ether);
assertApproxEqRel(MathLib.wExp(2 ether), 7.38905609893 ether, 0.00001 ether);
assertApproxEqRel(MathLib.wExp(3 ether), 20.0855369232 ether, 0.001 ether);
}

function testWExp(int256 x) public {
// Bounds between log(5e-18) and log((2**255 - 1) / 1e18).
x = bound(x, -42 ether, 58.76 ether);
if (x >= 0) assertGe(MathLib.wExp(x), WAD + uint256(x));
if (x < 0) assertLe(MathLib.wExp(x), WAD);
}

function testWExpRef(int256 x) public {
x = bound(x, -3 ether, 3 ether);
assertGe(int256(MathLib.wExp12(x)), int256(WAD) + x);
if (x < 0) assertLe(MathLib.wExp12(x), WAD);
assertApproxEqRel(MathLib.wExp12(x), wExpRef(x), 0.01 ether);
assertApproxEqRel(MathLib.wExp(x), wExpRef(x), 0.01 ether);
}
}

Expand Down