Skip to content

Commit

Permalink
test: add more assets in invariant tests
Browse files Browse the repository at this point in the history
test: add an array of assets in CreateHandler
test: fuzz the asset in CreateHandler
test: remove unneeded Invariant_Test contract
test: remove checkUsers modifier
test: move test contracts in BaseHandler
test: run invariant_ContractBalanceGeStreamBalances test for all assets
test: add a helper function to prevent stack too deep error
test: use assume instead of "if() return"
  • Loading branch information
andreivladbrg committed Jun 6, 2024
1 parent 239d8c6 commit 628e595
Show file tree
Hide file tree
Showing 6 changed files with 137 additions and 129 deletions.
1 change: 1 addition & 0 deletions test/Base.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ abstract contract Base_Test is Assertions, Constants, Events, Modifiers, Test, U
}

function labelContracts() internal {
vm.label(address(assetWithoutDecimals), "AWD");
vm.label(address(dai), "DAI");
vm.label(address(flow), "Flow");
vm.label(address(usdc), "USDC");
Expand Down
38 changes: 29 additions & 9 deletions test/invariant/Flow.t.sol
Original file line number Diff line number Diff line change
@@ -1,19 +1,22 @@
// SPDX-License-Identifier: UNLICENSED
pragma solidity >=0.8.22 <0.9.0;

import { IERC20 } from "@openzeppelin/contracts/token/ERC20/IERC20.sol";

import { Helpers } from "src/libraries/Helpers.sol";

import { Invariant_Test } from "./Invariant.t.sol";
import { Base_Test } from "../Base.t.sol";
import { FlowCreateHandler } from "./handlers/FlowCreateHandler.sol";
import { FlowHandler } from "./handlers/FlowHandler.sol";
import { FlowStore } from "./stores/FlowStore.sol";

/// @notice Common invariant test logic needed across contracts that inherit from {SablierFlow}.
contract Flow_Invariant_Test is Invariant_Test {
contract Flow_Invariant_Test is Base_Test {
/*//////////////////////////////////////////////////////////////////////////
TEST CONTRACTS
//////////////////////////////////////////////////////////////////////////*/

IERC20[] internal assets;
FlowCreateHandler internal flowCreateHandler;
FlowHandler internal flowHandler;
FlowStore internal flowStore;
Expand All @@ -23,14 +26,20 @@ contract Flow_Invariant_Test is Invariant_Test {
//////////////////////////////////////////////////////////////////////////*/

function setUp() public virtual override {
Invariant_Test.setUp();
Base_Test.setUp();

// Declare the default assets.
assets.push(assetWithoutDecimals);
assets.push(dai);
assets.push(usdc);
assets.push(IERC20(address(usdt)));

// Deploy and the FlowStore contract.
flowStore = new FlowStore();

// Deploy the handlers.
flowHandler = new FlowHandler({ asset_: dai, flowStore_: flowStore, flow_: flow });
flowCreateHandler = new FlowCreateHandler({ asset_: dai, flowStore_: flowStore, flow_: flow });
flowHandler = new FlowHandler({ flowStore_: flowStore, flow_: flow });
flowCreateHandler = new FlowCreateHandler({ flowStore_: flowStore, flow_: flow, assets_: assets });

// Label the contracts.
vm.label({ account: address(flowStore), newLabel: "flowStore" });
Expand All @@ -42,6 +51,7 @@ contract Flow_Invariant_Test is Invariant_Test {
targetContract(address(flowCreateHandler));

// Prevent these contracts from being fuzzed as `msg.sender`.
excludeSender(address(flow));
excludeSender(address(flowStore));
excludeSender(address(flowHandler));
excludeSender(address(flowCreateHandler));
Expand All @@ -67,14 +77,24 @@ contract Flow_Invariant_Test is Invariant_Test {
/// @dev For a given asset, the sum of all stream balances normalized to the asset's decimal should never exceed
/// the asset balance of the flow contract.
function invariant_ContractBalanceGeStreamBalances() external view {
uint256 contractBalance = dai.balanceOf(address(flow));
// Check the invariant for each asset.
for (uint256 i = 0; i < assets.length; ++i) {
contractBalanceGeStreamBalances(assets[i]);
}
}

uint256 lastStreamId = flowStore.lastStreamId();
function contractBalanceGeStreamBalances(IERC20 asset) internal view {
uint256 contractBalance = asset.balanceOf(address(flow));
uint128 streamBalancesSumNormalized;

uint256 lastStreamId = flowStore.lastStreamId();
for (uint256 i = 0; i < lastStreamId; ++i) {
uint256 streamId = flowStore.streamIds(i);
streamBalancesSumNormalized +=
Helpers.calculateTransferAmount(flow.getBalance(streamId), flow.getAssetDecimals(streamId));

if (flow.getAsset(streamId) == asset) {
streamBalancesSumNormalized +=
Helpers.calculateTransferAmount(flow.getBalance(streamId), flow.getAssetDecimals(streamId));
}
}

assertGe(
Expand Down
18 changes: 0 additions & 18 deletions test/invariant/Invariant.t.sol

This file was deleted.

51 changes: 22 additions & 29 deletions test/invariant/handlers/BaseHandler.sol
Original file line number Diff line number Diff line change
@@ -1,25 +1,23 @@
// SPDX-License-Identifier: UNLICENSED
pragma solidity >=0.8.22 <0.9.0;

import { IERC20 } from "@openzeppelin/contracts/token/ERC20/IERC20.sol";
import { StdCheats } from "forge-std/src/StdCheats.sol";

import { ISablierFlow } from "src/interfaces/ISablierFlow.sol";

import { FlowStore } from "../stores/FlowStore.sol";
import { Constants } from "../../utils/Constants.sol";
import { Utils } from "../../utils/Utils.sol";

/// @notice Base contract with common logic needed by all handler contracts.
abstract contract BaseHandler is Constants, StdCheats, Utils {
/*//////////////////////////////////////////////////////////////////////////
CONSTANTS
VARIABLES
//////////////////////////////////////////////////////////////////////////*/

/// @dev Maximum number of streams that can be created during an invariant campaign.
uint256 internal constant MAX_STREAM_COUNT = 100;

/*//////////////////////////////////////////////////////////////////////////
VARIABLES
//////////////////////////////////////////////////////////////////////////*/

/// @dev Maps function names to the number of times they have been called.
mapping(string func => uint256 calls) public calls;

Expand All @@ -30,15 +28,16 @@ abstract contract BaseHandler is Constants, StdCheats, Utils {
TEST CONTRACTS
//////////////////////////////////////////////////////////////////////////*/

/// @dev Default ERC20 asset used for testing.
IERC20 public asset;
ISablierFlow public flow;
FlowStore public flowStore;

/*//////////////////////////////////////////////////////////////////////////
CONSTRUCTOR
//////////////////////////////////////////////////////////////////////////*/

constructor(IERC20 asset_) {
asset = asset_;
constructor(FlowStore flowStore_, ISablierFlow flow_) {
flowStore = flowStore_;
flow = flow_;
}

/*//////////////////////////////////////////////////////////////////////////
Expand All @@ -53,31 +52,25 @@ abstract contract BaseHandler is Constants, StdCheats, Utils {
_;
}

/// @dev Checks user assumptions.
modifier checkUsers(address sender, address recipient) {
// The protocol doesn't allow the sender or recipient to be the zero address.
if (sender == address(0) || recipient == address(0)) {
return;
}

// Prevent the contract itself from playing the role of any user.
if (sender == address(this) || recipient == address(this)) {
return;
}

_;
}

/// @dev Records a function call for instrumentation purposes.
modifier instrument(string memory functionName) {
calls[functionName]++;
totalCalls++;
_;
}

/// @dev Makes the provided sender the caller.
modifier useNewSender(address sender) {
resetPrank(sender);
_;
/*//////////////////////////////////////////////////////////////////////////
HELPERS
//////////////////////////////////////////////////////////////////////////*/

/// @dev Helper function to calculate the upper bound, based on the asset decimals, for the transfer amount.
function getTransferAmountUpperBound(uint8 assetDecimals) internal pure returns (uint128 upperBound) {
if (assetDecimals == 0) {
upperBound = 1_000_000;
} else if (assetDecimals == 6) {
upperBound = 1_000_000e6;
} else {
upperBound = 1_000_000e18;
}
}
}
106 changes: 69 additions & 37 deletions test/invariant/handlers/FlowCreateHandler.sol
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
pragma solidity >=0.8.22 <0.9.0;

import { IERC20 } from "@openzeppelin/contracts/token/ERC20/IERC20.sol";
import { IERC20Metadata } from "@openzeppelin/contracts/token/ERC20/extensions/IERC20Metadata.sol";

import { ISablierFlow } from "src/interfaces/ISablierFlow.sol";
import { Helpers } from "src/libraries/Helpers.sol";

import { FlowStore } from "../stores/FlowStore.sol";
import { BaseHandler } from "./BaseHandler.sol";
Expand All @@ -13,19 +15,31 @@ import { BaseHandler } from "./BaseHandler.sol";
/// the contracts.
contract FlowCreateHandler is BaseHandler {
/*//////////////////////////////////////////////////////////////////////////
TEST CONTRACTS
VARIABLES
//////////////////////////////////////////////////////////////////////////*/

ISablierFlow public flow;
FlowStore public flowStore;
/// @dev Default ERC20 assets used for testing.
IERC20[] public assets;
IERC20 currentAsset;

/*//////////////////////////////////////////////////////////////////////////
MODIFIERS
//////////////////////////////////////////////////////////////////////////*/

modifier useFuzzedAsset(uint256 assetIndexSeed) {
assetIndexSeed = _bound(assetIndexSeed, 0, assets.length - 1);
currentAsset = assets[assetIndexSeed];
_;
}

/*//////////////////////////////////////////////////////////////////////////
CONSTRUCTOR
//////////////////////////////////////////////////////////////////////////*/

constructor(IERC20 asset_, FlowStore flowStore_, ISablierFlow flow_) BaseHandler(asset_) {
flowStore = flowStore_;
flow = flow_;
constructor(FlowStore flowStore_, ISablierFlow flow_, IERC20[] memory assets_) BaseHandler(flowStore_, flow_) {
for (uint256 i = 0; i < assets_.length; ++i) {
assets.push(assets_[i]);
}
}

/*//////////////////////////////////////////////////////////////////////////
Expand All @@ -35,69 +49,87 @@ contract FlowCreateHandler is BaseHandler {
/// @dev Struct to prevent stack too deep error.
struct CreateParams {
uint256 timeJumpSeed;
uint256 assetIndexSeed;
address sender;
address recipient;
uint128 ratePerSecond;
bool isTransferable;
}

function create(CreateParams memory params)
public
instrument("create")
adjustTimestamp(params.timeJumpSeed)
checkUsers(params.sender, params.recipient)
useNewSender(params.sender)
{
// We don't want to create more than a certain number of streams.
if (flowStore.lastStreamId() >= MAX_STREAM_COUNT) {
return;
}
function create(CreateParams memory params) public {
_commonModifiersInit(params, "create");

// Bound the stream parameters.
params.ratePerSecond = uint128(_bound(params.ratePerSecond, 0.0001e18, 1e18));

// Create the stream.
uint256 streamId =
flow.create(params.sender, params.recipient, params.ratePerSecond, asset, params.isTransferable);
flow.create(params.sender, params.recipient, params.ratePerSecond, currentAsset, params.isTransferable);

// Store the stream id.
flowStore.pushStreamId(streamId, params.sender, params.recipient);
}

function createAndDeposit(
CreateParams memory params,
uint128 transferAmount
)
public
instrument("createAndDeposit")
adjustTimestamp(params.timeJumpSeed)
checkUsers(params.sender, params.recipient)
useNewSender(params.sender)
{
// We don't want to create more than a certain number of streams.
if (flowStore.lastStreamId() >= MAX_STREAM_COUNT) {
return;
}
function createAndDeposit(CreateParams memory params, uint128 transferAmount) public {
_commonModifiersInit(params, "createAndDeposit");

uint8 decimals = IERC20Metadata(address(currentAsset)).decimals();

// Calculate the upper bound, based on the asset decimals, for the transfer amount.
uint128 upperBound = getTransferAmountUpperBound(decimals);

// Bound the stream parameters.
params.ratePerSecond = uint128(_bound(params.ratePerSecond, 0.0001e18, 1e18));
transferAmount = uint128(_bound(transferAmount, 100e18, 1_000_000_000e18));
transferAmount = uint128(_bound(transferAmount, 100, upperBound));

// Mint enough assets to the Sender.
deal({ token: address(asset), to: params.sender, give: asset.balanceOf(params.sender) + transferAmount });
deal({
token: address(currentAsset),
to: params.sender,
give: currentAsset.balanceOf(params.sender) + transferAmount
});

// Approve {SablierFlow} to spend the assets.
asset.approve({ spender: address(flow), value: transferAmount });
currentAsset.approve({ spender: address(flow), value: transferAmount });

// Create the stream.
uint256 streamId = flow.createAndDeposit(
params.sender, params.recipient, params.ratePerSecond, asset, params.isTransferable, transferAmount
params.sender, params.recipient, params.ratePerSecond, currentAsset, params.isTransferable, transferAmount
);

// Store the stream id.
flowStore.pushStreamId(streamId, params.sender, params.recipient);

uint128 normalizedAmount = Helpers.calculateNormalizedAmount(transferAmount, decimals);

// Store the deposited amount.
flowStore.updateStreamDepositedAmountsSum(streamId, transferAmount);
flowStore.updateStreamDepositedAmountsSum(streamId, normalizedAmount);
}

/*//////////////////////////////////////////////////////////////////////////
HELPERS
//////////////////////////////////////////////////////////////////////////*/

/// @dev Helper function to avoid stack too deep error.
function _commonModifiersInit(
CreateParams memory params,
string memory functionName
)
internal
instrument(functionName)
adjustTimestamp(params.timeJumpSeed)
useFuzzedAsset(params.assetIndexSeed)
{
// We don't want to create more than a certain number of streams.
vm.assume(flowStore.lastStreamId() < MAX_STREAM_COUNT);

// The protocol doesn't allow the sender or recipient to be the zero address.
vm.assume(params.sender != address(0) && params.recipient != address(0));

// Prevent the contract itself from playing the role of any user.
vm.assume(params.sender != address(this) && params.recipient != address(this));

// Reset the caller.
resetPrank(params.sender);
}
}
Loading

0 comments on commit 628e595

Please sign in to comment.