diff --git a/packages/horizon/contracts/interfaces/ITAPCollector.sol b/packages/horizon/contracts/interfaces/ITAPCollector.sol index 4d353f39c..9ff103696 100644 --- a/packages/horizon/contracts/interfaces/ITAPCollector.sol +++ b/packages/horizon/contracts/interfaces/ITAPCollector.sol @@ -2,6 +2,7 @@ pragma solidity 0.8.27; import { IPaymentsCollector } from "./IPaymentsCollector.sol"; +import { IGraphPayments } from "./IGraphPayments.sol"; /** * @title Interface for the {TAPCollector} contract @@ -166,6 +167,13 @@ interface ITAPCollector is IPaymentsCollector { */ error TAPCollectorInconsistentRAVTokens(uint256 tokens, uint256 tokensCollected); + /** + * Thrown when the attempting to collect more tokens than what it's owed + * @param tokensToCollect The amount of tokens to collect + * @param maxTokensToCollect The maximum amount of tokens to collect + */ + error TAPCollectorInvalidTokensToCollectAmount(uint256 tokensToCollect, uint256 maxTokensToCollect); + /** * @notice Authorize a signer to sign on behalf of the payer. * A signer can not be authorized for multiple payers even after revoking previous authorizations. @@ -228,4 +236,21 @@ interface ITAPCollector is IPaymentsCollector { * @return The hash of the RAV. */ function encodeRAV(ReceiptAggregateVoucher calldata rav) external view returns (bytes32); + + /** + * @notice See {IPaymentsCollector.collect} + * This variant adds the ability to partially collect a RAV by specifying the amount of tokens to collect. + * + * Requirements: + * - The amount of tokens to collect must be less than or equal to the total amount of tokens in the RAV minus + * the tokens already collected. + * @param paymentType The payment type to collect + * @param data Additional data required for the payment collection + * @param tokensToCollect The amount of tokens to collect + */ + function collect( + IGraphPayments.PaymentTypes paymentType, + bytes calldata data, + uint256 tokensToCollect + ) external returns (uint256); } diff --git a/packages/horizon/contracts/payments/collectors/TAPCollector.sol b/packages/horizon/contracts/payments/collectors/TAPCollector.sol index e5c491732..8c9ae5e2a 100644 --- a/packages/horizon/contracts/payments/collectors/TAPCollector.sol +++ b/packages/horizon/contracts/payments/collectors/TAPCollector.sol @@ -123,28 +123,15 @@ contract TAPCollector is EIP712, GraphDirectory, ITAPCollector { * @notice REVERT: This function may revert if ECDSA.recover fails, check ECDSA library for details. */ function collect(IGraphPayments.PaymentTypes paymentType, bytes memory data) external override returns (uint256) { - (SignedRAV memory signedRAV, uint256 dataServiceCut) = abi.decode(data, (SignedRAV, uint256)); - require( - signedRAV.rav.dataService == msg.sender, - TAPCollectorCallerNotDataService(msg.sender, signedRAV.rav.dataService) - ); - - address signer = _recoverRAVSigner(signedRAV); - require( - authorizedSigners[signer].payer != address(0) && !authorizedSigners[signer].revoked, - TAPCollectorInvalidRAVSigner() - ); - - // Check the service provider has an active provision with the data service - // This prevents an attack where the payer can deny the service provider from collecting payments - // by using a signer as data service to syphon off the tokens in the escrow to an account they control - uint256 tokensAvailable = _graphStaking().getProviderTokensAvailable( - signedRAV.rav.serviceProvider, - signedRAV.rav.dataService - ); - require(tokensAvailable > 0, TAPCollectorUnauthorizedDataService(signedRAV.rav.dataService)); + return _collect(paymentType, data, 0); + } - return _collect(paymentType, authorizedSigners[signer].payer, signedRAV, dataServiceCut); + function collect( + IGraphPayments.PaymentTypes paymentType, + bytes memory data, + uint256 tokensToCollect + ) external override returns (uint256) { + return _collect(paymentType, data, tokensToCollect); } /** @@ -166,28 +153,62 @@ contract TAPCollector is EIP712, GraphDirectory, ITAPCollector { */ function _collect( IGraphPayments.PaymentTypes _paymentType, - address _payer, - SignedRAV memory _signedRAV, - uint256 _dataServiceCut + bytes memory _data, + uint256 _tokensToCollect ) private returns (uint256) { - address dataService = _signedRAV.rav.dataService; - address receiver = _signedRAV.rav.serviceProvider; + (SignedRAV memory signedRAV, uint256 dataServiceCut) = abi.decode(_data, (SignedRAV, uint256)); + require( + signedRAV.rav.dataService == msg.sender, + TAPCollectorCallerNotDataService(msg.sender, signedRAV.rav.dataService) + ); - uint256 tokensRAV = _signedRAV.rav.valueAggregate; - uint256 tokensAlreadyCollected = tokensCollected[dataService][receiver][_payer]; + address signer = _recoverRAVSigner(signedRAV); require( - tokensRAV > tokensAlreadyCollected, - TAPCollectorInconsistentRAVTokens(tokensRAV, tokensAlreadyCollected) + authorizedSigners[signer].payer != address(0) && !authorizedSigners[signer].revoked, + TAPCollectorInvalidRAVSigner() ); + address payer = authorizedSigners[signer].payer; + address dataService = signedRAV.rav.dataService; + address receiver = signedRAV.rav.serviceProvider; + + // Check the service provider has an active provision with the data service + // This prevents an attack where the payer can deny the service provider from collecting payments + // by using a signer as data service to syphon off the tokens in the escrow to an account they control + { + uint256 tokensAvailable = _graphStaking().getProviderTokensAvailable( + signedRAV.rav.serviceProvider, + signedRAV.rav.dataService + ); + require(tokensAvailable > 0, TAPCollectorUnauthorizedDataService(signedRAV.rav.dataService)); + } + + uint256 tokensToCollect = 0; + { + uint256 tokensRAV = signedRAV.rav.valueAggregate; + uint256 tokensAlreadyCollected = tokensCollected[dataService][receiver][payer]; + require( + tokensRAV > tokensAlreadyCollected, + TAPCollectorInconsistentRAVTokens(tokensRAV, tokensAlreadyCollected) + ); + + if (_tokensToCollect == 0) { + tokensToCollect = tokensRAV - tokensAlreadyCollected; + } else { + require( + _tokensToCollect <= tokensRAV - tokensAlreadyCollected, + TAPCollectorInvalidTokensToCollectAmount(_tokensToCollect, tokensRAV - tokensAlreadyCollected) + ); + tokensToCollect = _tokensToCollect; + } + } - uint256 tokensToCollect = tokensRAV - tokensAlreadyCollected; - uint256 tokensDataService = tokensToCollect.mulPPM(_dataServiceCut); + uint256 tokensDataService = tokensToCollect.mulPPM(dataServiceCut); if (tokensToCollect > 0) { - tokensCollected[dataService][receiver][_payer] = tokensRAV; + tokensCollected[dataService][receiver][payer] += tokensToCollect; _graphPaymentsEscrow().collect( _paymentType, - _payer, + payer, receiver, tokensToCollect, dataService, @@ -195,15 +216,15 @@ contract TAPCollector is EIP712, GraphDirectory, ITAPCollector { ); } - emit PaymentCollected(_paymentType, _payer, receiver, tokensToCollect, dataService, tokensDataService); + emit PaymentCollected(_paymentType, payer, receiver, tokensToCollect, dataService, tokensDataService); emit RAVCollected( - _payer, + payer, dataService, receiver, - _signedRAV.rav.timestampNs, - _signedRAV.rav.valueAggregate, - _signedRAV.rav.metadata, - _signedRAV.signature + signedRAV.rav.timestampNs, + signedRAV.rav.valueAggregate, + signedRAV.rav.metadata, + signedRAV.signature ); return tokensToCollect; } diff --git a/packages/horizon/test/payments/tap-collector/TAPCollector.t.sol b/packages/horizon/test/payments/tap-collector/TAPCollector.t.sol index 1120c5b92..ac67d6552 100644 --- a/packages/horizon/test/payments/tap-collector/TAPCollector.t.sol +++ b/packages/horizon/test/payments/tap-collector/TAPCollector.t.sol @@ -119,12 +119,20 @@ contract TAPCollectorTest is HorizonStakingSharedTest, PaymentsEscrowSharedTest } function _collect(IGraphPayments.PaymentTypes _paymentType, bytes memory _data) internal { + __collect(_paymentType, _data, 0); + } + + function _collect(IGraphPayments.PaymentTypes _paymentType, bytes memory _data, uint256 _tokensToCollect) internal { + __collect(_paymentType, _data, _tokensToCollect); + } + + function __collect(IGraphPayments.PaymentTypes _paymentType, bytes memory _data, uint256 _tokensToCollect) internal { (ITAPCollector.SignedRAV memory signedRAV, uint256 dataServiceCut) = abi.decode(_data, (ITAPCollector.SignedRAV, uint256)); bytes32 messageHash = tapCollector.encodeRAV(signedRAV.rav); address _signer = ECDSA.recover(messageHash, signedRAV.signature); (address _payer, , ) = tapCollector.authorizedSigners(_signer); uint256 tokensAlreadyCollected = tapCollector.tokensCollected(signedRAV.rav.dataService, signedRAV.rav.serviceProvider, _payer); - uint256 tokensToCollect = signedRAV.rav.valueAggregate - tokensAlreadyCollected; + uint256 tokensToCollect = _tokensToCollect == 0 ? signedRAV.rav.valueAggregate - tokensAlreadyCollected : _tokensToCollect; uint256 tokensDataService = tokensToCollect.mulPPM(dataServiceCut); vm.expectEmit(address(tapCollector)); @@ -136,6 +144,7 @@ contract TAPCollectorTest is HorizonStakingSharedTest, PaymentsEscrowSharedTest signedRAV.rav.dataService, tokensDataService ); + vm.expectEmit(address(tapCollector)); emit ITAPCollector.RAVCollected( _payer, signedRAV.rav.dataService, @@ -145,11 +154,10 @@ contract TAPCollectorTest is HorizonStakingSharedTest, PaymentsEscrowSharedTest signedRAV.rav.metadata, signedRAV.signature ); - - uint256 tokensCollected = tapCollector.collect(_paymentType, _data); - assertEq(tokensCollected, tokensToCollect); + uint256 tokensCollected = _tokensToCollect == 0 ? tapCollector.collect(_paymentType, _data) : tapCollector.collect(_paymentType, _data, _tokensToCollect); uint256 tokensCollectedAfter = tapCollector.tokensCollected(signedRAV.rav.dataService, signedRAV.rav.serviceProvider, _payer); - assertEq(tokensCollectedAfter, signedRAV.rav.valueAggregate); + assertEq(tokensCollected, tokensToCollect); + assertEq(tokensCollectedAfter, _tokensToCollect == 0 ? signedRAV.rav.valueAggregate : tokensAlreadyCollected + _tokensToCollect); } } diff --git a/packages/horizon/test/payments/tap-collector/collect/collect.t.sol b/packages/horizon/test/payments/tap-collector/collect/collect.t.sol index ecaa1ccec..533358aaa 100644 --- a/packages/horizon/test/payments/tap-collector/collect/collect.t.sol +++ b/packages/horizon/test/payments/tap-collector/collect/collect.t.sol @@ -253,4 +253,47 @@ contract TAPCollectorCollectTest is TAPCollectorTest { resetPrank(users.verifier); _collect(IGraphPayments.PaymentTypes.QueryFee, data); } + + function testTAPCollector_CollectPartial( + uint256 tokens, + uint256 tokensToCollect + ) public useIndexer useProvisionDataService(users.verifier, 100, 0, 0) useGateway useSigner { + tokens = bound(tokens, 1, type(uint128).max); + tokensToCollect = bound(tokensToCollect, 1, tokens); + + _depositTokens(address(tapCollector), users.indexer, tokens); + + bytes memory data = _getQueryFeeEncodedData(signerPrivateKey, users.indexer, users.verifier, uint128(tokens)); + + resetPrank(users.verifier); + _collect(IGraphPayments.PaymentTypes.QueryFee, data, tokensToCollect); + } + + function testTAPCollector_CollectPartial_RevertWhen_AmountTooHigh( + uint256 tokens, + uint256 tokensToCollect + ) public useIndexer useProvisionDataService(users.verifier, 100, 0, 0) useGateway useSigner { + tokens = bound(tokens, 1, type(uint128).max - 1); + + _depositTokens(address(tapCollector), users.indexer, tokens); + + bytes memory data = _getQueryFeeEncodedData(signerPrivateKey, users.indexer, users.verifier, uint128(tokens)); + + resetPrank(users.verifier); + uint256 tokensAlreadyCollected = tapCollector.tokensCollected( + users.verifier, + users.indexer, + users.gateway + ); + tokensToCollect = bound(tokensToCollect, tokens - tokensAlreadyCollected + 1, type(uint128).max); + + vm.expectRevert( + abi.encodeWithSelector( + ITAPCollector.TAPCollectorInvalidTokensToCollectAmount.selector, + tokensToCollect, + tokens - tokensAlreadyCollected + ) + ); + tapCollector.collect(IGraphPayments.PaymentTypes.QueryFee, data, tokensToCollect); + } }