diff --git a/.forge-snapshots/BinPoolManagerBytecodeSize.snap b/.forge-snapshots/BinPoolManagerBytecodeSize.snap index 7e31e3d0..eeb037e3 100644 --- a/.forge-snapshots/BinPoolManagerBytecodeSize.snap +++ b/.forge-snapshots/BinPoolManagerBytecodeSize.snap @@ -1 +1 @@ -23634 \ No newline at end of file +23822 \ No newline at end of file diff --git a/.forge-snapshots/CLPoolManagerBytecodeSize.snap b/.forge-snapshots/CLPoolManagerBytecodeSize.snap index 86561671..a90ad789 100644 --- a/.forge-snapshots/CLPoolManagerBytecodeSize.snap +++ b/.forge-snapshots/CLPoolManagerBytecodeSize.snap @@ -1 +1 @@ -20704 \ No newline at end of file +20886 \ No newline at end of file diff --git a/src/ProtocolFees.sol b/src/ProtocolFees.sol index 8a3d1eff..c6f0f216 100644 --- a/src/ProtocolFees.sol +++ b/src/ProtocolFees.sol @@ -10,11 +10,10 @@ import {ProtocolFeeLibrary} from "./libraries/ProtocolFeeLibrary.sol"; import {PoolKey} from "./types/PoolKey.sol"; import {PoolId} from "./types/PoolId.sol"; import {IVault} from "./interfaces/IVault.sol"; -import {BipsLibrary} from "./libraries/BipsLibrary.sol"; +import {CustomRevert} from "./libraries/CustomRevert.sol"; abstract contract ProtocolFees is IProtocolFees, Owner { using ProtocolFeeLibrary for uint24; - using BipsLibrary for uint256; /// @inheritdoc IProtocolFees mapping(Currency currency => uint256 amount) public protocolFeesAccrued; @@ -25,10 +24,6 @@ abstract contract ProtocolFees is IProtocolFees, Owner { /// @inheritdoc IProtocolFees IVault public immutable vault; - // a percentage of the block.gaslimit denoted in basis points, used as the gas limit for fee controller calls - // 100 bps is 1%, at 30M gas, the limit is 300K - uint256 private constant BLOCK_LIMIT_BPS = 100; - constructor(IVault _vault) { vault = _vault; } @@ -45,18 +40,11 @@ abstract contract ProtocolFees is IProtocolFees, Owner { } /// @notice Fetch the protocol fee for a given pool - /// @dev the success of this function is false if the call fails or the returned fees are invalid - /// @dev to prevent an invalid protocol fee controller from blocking pools from being initialized - /// the success of this function is NOT checked on initialize and if the call fails, the protocol fees are set to 0. + /// @dev Revert if call to protocolFeeController fails or if return value is not 32 bytes + /// However if the call to protocolFeeController succeed, it can still revert if the return value is too large + /// @return protocolFee The protocol fee for the pool function _fetchProtocolFee(PoolKey memory key) internal returns (uint24 protocolFee) { if (address(protocolFeeController) != address(0)) { - uint256 controllerGasLimit = block.gaslimit.calculatePortion(BLOCK_LIMIT_BPS); - - // note that EIP-150 mandates that calls requesting more than 63/64ths of remaining gas - // will be allotted no more than this amount, so controllerGasLimit must be set with this - // in mind. - if (gasleft() < controllerGasLimit) revert ProtocolFeeCannotBeFetched(); - address targetProtocolFeeController = address(protocolFeeController); bytes memory data = abi.encodeCall(IProtocolFeeController.protocolFeeForPool, (key)); @@ -64,17 +52,28 @@ abstract contract ProtocolFees is IProtocolFees, Owner { uint256 returnData; assembly ("memory-safe") { // only load the first 32 bytes of the return data to prevent gas griefing - success := call(controllerGasLimit, targetProtocolFeeController, 0, add(data, 0x20), mload(data), 0, 32) - // if success is false this wont actually be returned, instead 0 will be returned + success := call(gas(), targetProtocolFeeController, 0, add(data, 0x20), mload(data), 0, 32) + + // load the return data returnData := mload(0) - // success if return data size is 32 bytes + // success if return data size is also 32 bytes success := and(success, eq(returndatasize(), 32)) } - // Ensure return data does not overflow a uint24 and that the underlying fees are within bounds. - protocolFee = - success && (returnData == uint24(returnData)) && uint24(returnData).validate() ? uint24(returnData) : 0; + // revert if call fails or return size is not 32 bytes + if (!success) { + CustomRevert.bubbleUpAndRevertWith( + targetProtocolFeeController, bytes4(data), ProtocolFeeCannotBeFetched.selector + ); + } + + if (returnData == uint24(returnData) && uint24(returnData).validate()) { + protocolFee = uint24(returnData); + } else { + // revert if return value overflow a uint24 or greater than max protocol fee + revert ProtocolFeeTooLarge(uint24(returnData)); + } } } diff --git a/src/interfaces/IProtocolFees.sol b/src/interfaces/IProtocolFees.sol index 332d3d87..b10b204f 100644 --- a/src/interfaces/IProtocolFees.sol +++ b/src/interfaces/IProtocolFees.sol @@ -10,8 +10,9 @@ import {IVault} from "./IVault.sol"; interface IProtocolFees { /// @notice Thrown when the protocol fee exceeds the upper limit. error ProtocolFeeTooLarge(uint24 fee); - /// @notice Thrown when not enough gas is provided to look up the protocol fee + /// @notice Thrown when calls to protocolFeeController fails or return size is not 32 bytes error ProtocolFeeCannotBeFetched(); + /// @notice Thrown when user not authorized to set or collect protocol fee error InvalidCaller(); diff --git a/src/test/MockFeePoolManager.sol b/src/test/MockFeePoolManager.sol index 8a4ef820..642b22da 100644 --- a/src/test/MockFeePoolManager.sol +++ b/src/test/MockFeePoolManager.sol @@ -23,7 +23,7 @@ contract MockFeePoolManager is ProtocolFees { uint24 protocolFee; } - constructor(IVault vault, uint256 controllerGasLimit) ProtocolFees(vault) {} + constructor(IVault vault) ProtocolFees(vault) {} function initialize(PoolKey memory key) external { PoolId id = key.toId(); diff --git a/src/test/fee/MockProtocolFeeController.sol b/src/test/fee/MockProtocolFeeController.sol index 15044adb..a9112a06 100644 --- a/src/test/fee/MockProtocolFeeController.sol +++ b/src/test/fee/MockProtocolFeeController.sol @@ -25,8 +25,10 @@ contract MockProtocolFeeController is IProtocolFeeController { /// @notice Reverts on call contract RevertingMockProtocolFeeController is IProtocolFeeController { + error DevsBlock(); + function protocolFeeForPool(PoolKey memory /* key */ ) external pure returns (uint24) { - revert(); + revert DevsBlock(); } } diff --git a/test/ProtocolFees.t.sol b/test/ProtocolFees.t.sol index a72ed801..23fc5347 100644 --- a/test/ProtocolFees.t.sol +++ b/test/ProtocolFees.t.sol @@ -38,7 +38,7 @@ contract ProtocolFeesTest is Test { function setUp() public { vault = new MockVault(); - poolManager = new MockFeePoolManager(IVault(address(vault)), 500_000); + poolManager = new MockFeePoolManager(IVault(address(vault))); feeController = new MockProtocolFeeController(); revertingFeeController = new RevertingMockProtocolFeeController(); outOfBoundsFeeController = new OutOfBoundsMockProtocolFeeController(); @@ -75,72 +75,106 @@ contract ProtocolFeesTest is Test { assertEq(protocolFee1, 0); } - function testInit_WhenFeeController_ProtocolFeeCannotBeFetched() public { - MockFeePoolManager poolManagerWithLowControllerGasLimit = - new MockFeePoolManager(IVault(address(vault)), 5000_000); - PoolKey memory _key = PoolKey({ - currency0: Currency.wrap(address(token0)), - currency1: Currency.wrap(address(token1)), - hooks: IHooks(address(0)), - poolManager: IPoolManager(address(poolManagerWithLowControllerGasLimit)), - fee: uint24(0), // fee not used in the setup - parameters: 0x00 - }); - poolManagerWithLowControllerGasLimit.setProtocolFeeController(feeController); + function test_Init_ProtocolFeeTooLarge() public { + uint24 protocolFee = + _buildProtocolFee(ProtocolFeeLibrary.MAX_PROTOCOL_FEE + 1, ProtocolFeeLibrary.MAX_PROTOCOL_FEE + 1); + feeController.setProtocolFeeForPool(key, protocolFee); + poolManager.setProtocolFeeController(IProtocolFeeController(address(feeController))); - vm.expectRevert(IProtocolFees.ProtocolFeeCannotBeFetched.selector); - poolManagerWithLowControllerGasLimit.initialize{gas: 2000_000}(_key); + vm.expectRevert(abi.encodeWithSelector(IProtocolFees.ProtocolFeeTooLarge.selector, protocolFee)); + poolManager.initialize(key); + } + + function testFuzz_Init_WhenOutOfGasForProtocolFeeController(uint256 gasLimit) public { + gasLimit = bound(gasLimit, 10_000, 100_000); // 10_000 gas will have out of gas revert + + uint24 protocolFee = _buildProtocolFee(ProtocolFeeLibrary.MAX_PROTOCOL_FEE, ProtocolFeeLibrary.MAX_PROTOCOL_FEE); + feeController.setProtocolFeeForPool(key, protocolFee); + poolManager.setProtocolFeeController(IProtocolFeeController(address(feeController))); + + try poolManager.initialize{gas: gasLimit}(key) { + // txn success, verify if protocol fee is set + uint24 fetchedProtocolFee = poolManager.pools(key.toId()); + assertEq(fetchedProtocolFee, protocolFee); + } catch { + // txn reverted, can ignore checking + } } function testInit_WhenFeeControllerRevert() public { poolManager.setProtocolFeeController(revertingFeeController); - poolManager.initialize(key); - assertEq(poolManager.getProtocolFee(key), 0); + vm.expectRevert( + abi.encodeWithSelector( + CustomRevert.WrappedError.selector, + address(revertingFeeController), + IProtocolFeeController.protocolFeeForPool.selector, + abi.encodeWithSelector(RevertingMockProtocolFeeController.DevsBlock.selector), + abi.encodeWithSelector(IProtocolFees.ProtocolFeeCannotBeFetched.selector) + ) + ); + poolManager.initialize(key); } function testInit_WhenFeeControllerOutOfBound() public { poolManager.setProtocolFeeController(outOfBoundsFeeController); assertEq(address(poolManager.protocolFeeController()), address(outOfBoundsFeeController)); - poolManager.initialize(key); - assertEq(poolManager.getProtocolFee(key), 0); + vm.expectRevert( + abi.encodeWithSelector(IProtocolFees.ProtocolFeeTooLarge.selector, ProtocolFeeLibrary.MAX_PROTOCOL_FEE + 1) + ); + poolManager.initialize(key); } function testInit_WhenFeeControllerOverflow() public { poolManager.setProtocolFeeController(overflowFeeController); assertEq(address(poolManager.protocolFeeController()), address(overflowFeeController)); - poolManager.initialize(key); - assertEq(poolManager.getProtocolFee(key), 0); + // 0xFFFFFFFFAAA001 from OverflowMockProtocolFeeController + vm.expectRevert( + abi.encodeWithSelector(IProtocolFees.ProtocolFeeTooLarge.selector, uint24(uint256(0xFFFFFFFFAAA001))) + ); + poolManager.initialize(key); } function testInit_WhenFeeControllerInvalidReturnSize() public { poolManager.setProtocolFeeController(invalidReturnSizeFeeController); assertEq(address(poolManager.protocolFeeController()), address(invalidReturnSizeFeeController)); + + vm.expectRevert( + abi.encodeWithSelector( + CustomRevert.WrappedError.selector, + address(invalidReturnSizeFeeController), + IProtocolFeeController.protocolFeeForPool.selector, + abi.encode(address(invalidReturnSizeFeeController), address(invalidReturnSizeFeeController)), + abi.encodeWithSelector(IProtocolFees.ProtocolFeeCannotBeFetched.selector) + ) + ); poolManager.initialize(key); assertEq(poolManager.getProtocolFee(key), 0); } - function testInitFuzz(uint24 fee) public { + function testInitFuzz(uint24 protocolFee) public { poolManager.setProtocolFeeController(feeController); vm.mockCall( - address(feeController), abi.encodeCall(IProtocolFeeController.protocolFeeForPool, key), abi.encode(fee) + address(feeController), + abi.encodeCall(IProtocolFeeController.protocolFeeForPool, key), + abi.encode(protocolFee) ); - poolManager.initialize(key); - - if (fee != 0) { - uint24 fee0 = fee % 4096; - uint24 fee1 = fee >> 12; + if (protocolFee != 0) { + uint24 fee0 = protocolFee % 4096; + uint24 fee1 = protocolFee >> 12; if (fee0 > ProtocolFeeLibrary.MAX_PROTOCOL_FEE || fee1 > ProtocolFeeLibrary.MAX_PROTOCOL_FEE) { // invalid fee, fallback to 0 - assertEq(poolManager.getProtocolFee(key), 0); + vm.expectRevert(abi.encodeWithSelector(IProtocolFees.ProtocolFeeTooLarge.selector, protocolFee)); + poolManager.initialize(key); } else { - assertEq(poolManager.getProtocolFee(key), fee); + poolManager.initialize(key); + assertEq(poolManager.getProtocolFee(key), protocolFee); } } } diff --git a/test/pool-cl/CLProtocolFees.t.sol b/test/pool-cl/CLProtocolFees.t.sol index 029becc0..559c4d01 100644 --- a/test/pool-cl/CLProtocolFees.t.sol +++ b/test/pool-cl/CLProtocolFees.t.sol @@ -28,6 +28,7 @@ import {IVault} from "../../src/interfaces/IVault.sol"; import {ProtocolFeeLibrary} from "../../src/libraries/ProtocolFeeLibrary.sol"; import {CLPoolGetter} from "./helpers/CLPoolGetter.sol"; import {CLSlot0} from "../../src/pool-cl/types/CLSlot0.sol"; +import {CustomRevert} from "../../src/libraries/CustomRevert.sol"; contract CLProtocolFeesTest is Test, Deployers, TokenFixture, GasSnapshot { using Hooks for IHooks; @@ -164,19 +165,27 @@ contract CLProtocolFeesTest is Test, Deployers, TokenFixture, GasSnapshot { assertEq(currency1.balanceOf(address(protocolFeeController)), expectedProtocolFees); } + /// @dev this should not happen as ProtocolFeeController is owned by PCS and theres no incentive for PCS to block pool creation function testMaliciousProtocolFeeControllerReturnHugeData() public { IProtocolFeeController controller = IProtocolFeeController(address(new MaliciousProtocolFeeController())); manager.setProtocolFeeController(controller); // the original pool has already been initialized, hence we need to pick a new pool key.fee = 6000; - uint256 gasBefore = gasleft(); + // payload from MaliciousProtocolFeeController + bytes memory payload = new bytes(230_000); + payload[payload.length - 1] = 0x01; + + vm.expectRevert( + abi.encodeWithSelector( + CustomRevert.WrappedError.selector, + address(controller), + IProtocolFeeController.protocolFeeForPool.selector, + abi.encode(payload), + abi.encodeWithSelector(IProtocolFees.ProtocolFeeCannotBeFetched.selector) + ) + ); manager.initialize(key, SQRT_RATIO_1_1); - uint256 gasConsumed = gasBefore - gasleft(); - /// @dev Return data size 230k would consume almost all the gas speicified in the controllerGasLimit i.e. 500k - /// And the gas consumed by the tx would be more than 800K if the payload is copied to the caller context. - /// The following assertion makes sure this doesn't happen. - assertLe(gasConsumed, 800_000, "gas griefing vector"); } }