From f9bee71ba7e1d0666ca572fd4cec1233d90b957f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=C3=A1s=20Migone?= Date: Tue, 29 Apr 2025 15:00:35 -0300 Subject: [PATCH] fix: allow partial RAV collection in subgraph service (OZ L-13) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Tomás Migone --- .../contracts/SubgraphService.sol | 17 +++- .../test/mocks/MockRewardsManager.sol | 15 ++-- .../subgraphService/SubgraphService.t.sol | 12 ++- .../subgraphService/collect/query/query.t.sol | 80 +++++++++++++++++-- 4 files changed, 105 insertions(+), 19 deletions(-) diff --git a/packages/subgraph-service/contracts/SubgraphService.sol b/packages/subgraph-service/contracts/SubgraphService.sol index f0d0ec25d..80c3f4b2b 100644 --- a/packages/subgraph-service/contracts/SubgraphService.sol +++ b/packages/subgraph-service/contracts/SubgraphService.sol @@ -268,12 +268,15 @@ contract SubgraphService is uint256 paymentCollected = 0; if (paymentType == IGraphPayments.PaymentTypes.QueryFee) { - IGraphTallyCollector.SignedRAV memory signedRav = abi.decode(data, (IGraphTallyCollector.SignedRAV)); + (IGraphTallyCollector.SignedRAV memory signedRav, uint256 tokensToCollect) = abi.decode( + data, + (IGraphTallyCollector.SignedRAV, uint256) + ); require( signedRav.rav.serviceProvider == indexer, SubgraphServiceIndexerMismatch(signedRav.rav.serviceProvider, indexer) ); - paymentCollected = _collectQueryFees(signedRav); + paymentCollected = _collectQueryFees(signedRav, tokensToCollect); } else if (paymentType == IGraphPayments.PaymentTypes.IndexingRewards) { (address allocationId, bytes32 poi) = abi.decode(data, (address, bytes32)); require( @@ -475,9 +478,14 @@ contract SubgraphService is * Emits a {QueryFeesCollected} event. * * @param _signedRav Signed RAV + * @param tokensToCollect The amount of tokens to collect. Allows partially collecting a RAV. If 0, the entire RAV will + * be collected. * @return The amount of fees collected */ - function _collectQueryFees(IGraphTallyCollector.SignedRAV memory _signedRav) private returns (uint256) { + function _collectQueryFees( + IGraphTallyCollector.SignedRAV memory _signedRav, + uint256 tokensToCollect + ) private returns (uint256) { address indexer = _signedRav.rav.serviceProvider; // Check that collectionId (256 bits) is a valid address (160 bits) @@ -502,7 +510,8 @@ contract SubgraphService is uint256 curationCut = _curation().isCurated(subgraphDeploymentId) ? curationFeesCut : 0; uint256 tokensCollected = _graphTallyCollector().collect( IGraphPayments.PaymentTypes.QueryFee, - abi.encode(_signedRav, curationCut) + abi.encode(_signedRav, curationCut), + tokensToCollect ); uint256 balanceAfter = _graphToken().balanceOf(address(this)); diff --git a/packages/subgraph-service/test/mocks/MockRewardsManager.sol b/packages/subgraph-service/test/mocks/MockRewardsManager.sol index a87c9c06a..3279e114c 100644 --- a/packages/subgraph-service/test/mocks/MockRewardsManager.sol +++ b/packages/subgraph-service/test/mocks/MockRewardsManager.sol @@ -14,7 +14,13 @@ interface IRewardsIssuer { ) external view - returns (bool isActive, address indexer, bytes32 subgraphDeploymentId, uint256 tokens, uint256 accRewardsPerAllocatedToken); + returns ( + bool isActive, + address indexer, + bytes32 subgraphDeploymentId, + uint256 tokens, + uint256 accRewardsPerAllocatedToken + ); } contract MockRewardsManager is IRewardsManager { @@ -71,13 +77,12 @@ contract MockRewardsManager is IRewardsManager { function takeRewards(address _allocationID) external returns (uint256) { address rewardsIssuer = msg.sender; - (bool isActive, , , uint256 tokens, uint256 accRewardsPerAllocatedToken) = IRewardsIssuer(rewardsIssuer).getAllocationData( - _allocationID - ); + (bool isActive, , , uint256 tokens, uint256 accRewardsPerAllocatedToken) = IRewardsIssuer(rewardsIssuer) + .getAllocationData(_allocationID); if (!isActive) { return 0; - } + } uint256 accRewardsPerTokens = tokens.mulPPM(rewardsPerSignal); uint256 rewards = accRewardsPerTokens - accRewardsPerAllocatedToken; diff --git a/packages/subgraph-service/test/subgraphService/SubgraphService.t.sol b/packages/subgraph-service/test/subgraphService/SubgraphService.t.sol index 5e310095b..5418ab435 100644 --- a/packages/subgraph-service/test/subgraphService/SubgraphService.t.sol +++ b/packages/subgraph-service/test/subgraphService/SubgraphService.t.sol @@ -278,7 +278,10 @@ contract SubgraphServiceTest is SubgraphServiceSharedTest { address _indexer, bytes memory _data ) private returns (uint256 paymentCollected) { - IGraphTallyCollector.SignedRAV memory signedRav = abi.decode(_data, (IGraphTallyCollector.SignedRAV)); + (IGraphTallyCollector.SignedRAV memory signedRav, uint256 tokensToCollect) = abi.decode( + _data, + (IGraphTallyCollector.SignedRAV, uint256) + ); address allocationId = address(uint160(uint256(signedRav.rav.collectionId))); Allocation.State memory allocation = subgraphService.getAllocation(allocationId); bytes32 subgraphDeploymentId = allocation.subgraphDeploymentId; @@ -293,7 +296,7 @@ contract SubgraphServiceTest is SubgraphServiceSharedTest { _indexer, payer ); - paymentCollected = signedRav.rav.valueAggregate - tokensCollected; + paymentCollected = tokensToCollect == 0 ? signedRav.rav.valueAggregate - tokensCollected : tokensToCollect; QueryFeeData memory queryFeeData = _queryFeeData(allocation.subgraphDeploymentId); uint256 tokensProtocol = paymentCollected.mulPPMRoundUp(queryFeeData.protocolPaymentCut); @@ -370,7 +373,10 @@ contract SubgraphServiceTest is SubgraphServiceSharedTest { CollectPaymentData memory collectPaymentDataBefore, CollectPaymentData memory collectPaymentDataAfter ) private view { - IGraphTallyCollector.SignedRAV memory signedRav = abi.decode(_data, (IGraphTallyCollector.SignedRAV)); + (IGraphTallyCollector.SignedRAV memory signedRav, uint256 tokensToCollect) = abi.decode( + _data, + (IGraphTallyCollector.SignedRAV, uint256) + ); Allocation.State memory allocation = subgraphService.getAllocation( address(uint160(uint256(signedRav.rav.collectionId))) ); diff --git a/packages/subgraph-service/test/subgraphService/collect/query/query.t.sol b/packages/subgraph-service/test/subgraphService/collect/query/query.t.sol index 28cc677cf..10068f745 100644 --- a/packages/subgraph-service/test/subgraphService/collect/query/query.t.sol +++ b/packages/subgraph-service/test/subgraphService/collect/query/query.t.sol @@ -39,7 +39,11 @@ contract SubgraphServiceRegisterTest is SubgraphServiceTest { return abi.encodePacked(r, s, v); } - function _getQueryFeeEncodedData(address indexer, uint128 tokens) private view returns (bytes memory) { + function _getQueryFeeEncodedData( + address indexer, + uint128 tokens, + uint256 tokensToCollect + ) private view returns (bytes memory) { IGraphTallyCollector.ReceiptAggregateVoucher memory rav = _getRAV( indexer, bytes32(uint256(uint160(allocationID))), @@ -49,7 +53,7 @@ contract SubgraphServiceRegisterTest is SubgraphServiceTest { (uint8 v, bytes32 r, bytes32 s) = vm.sign(signerPrivateKey, messageHash); bytes memory signature = abi.encodePacked(r, s, v); IGraphTallyCollector.SignedRAV memory signedRAV = IGraphTallyCollector.SignedRAV(rav, signature); - return abi.encode(signedRAV); + return abi.encode(signedRAV, tokensToCollect); } function _getRAV( @@ -109,7 +113,7 @@ contract SubgraphServiceRegisterTest is SubgraphServiceTest { _authorizeSigner(); resetPrank(users.indexer); - bytes memory data = _getQueryFeeEncodedData(users.indexer, uint128(tokensPayment)); + bytes memory data = _getQueryFeeEncodedData(users.indexer, uint128(tokensPayment), 0); _collect(users.indexer, IGraphPayments.PaymentTypes.QueryFee, data); } @@ -129,14 +133,14 @@ contract SubgraphServiceRegisterTest is SubgraphServiceTest { uint256 accTokensPayment = 0; for (uint i = 0; i < numPayments; i++) { accTokensPayment = accTokensPayment + tokensPayment; - bytes memory data = _getQueryFeeEncodedData(users.indexer, uint128(accTokensPayment)); + bytes memory data = _getQueryFeeEncodedData(users.indexer, uint128(accTokensPayment), 0); _collect(users.indexer, IGraphPayments.PaymentTypes.QueryFee, data); } } function testCollect_RevertWhen_NotAuthorized(uint256 tokens) public useIndexer useAllocation(tokens) { IGraphPayments.PaymentTypes paymentType = IGraphPayments.PaymentTypes.QueryFee; - bytes memory data = _getQueryFeeEncodedData(users.indexer, uint128(tokens)); + bytes memory data = _getQueryFeeEncodedData(users.indexer, uint128(tokens), 0); resetPrank(users.operator); vm.expectRevert( abi.encodeWithSelector( @@ -157,7 +161,7 @@ contract SubgraphServiceRegisterTest is SubgraphServiceTest { _createAndStartAllocation(newIndexer, tokens); // This data is for user.indexer allocationId - bytes memory data = _getQueryFeeEncodedData(newIndexer, uint128(tokens)); + bytes memory data = _getQueryFeeEncodedData(newIndexer, uint128(tokens), 0); resetPrank(newIndexer); vm.expectRevert( @@ -173,7 +177,7 @@ contract SubgraphServiceRegisterTest is SubgraphServiceTest { // Setup new indexer address newIndexer = makeAddr("newIndexer"); _createAndStartAllocation(newIndexer, tokens); - bytes memory data = _getQueryFeeEncodedData(users.indexer, uint128(tokens)); + bytes memory data = _getQueryFeeEncodedData(users.indexer, uint128(tokens), 0); vm.expectRevert( abi.encodeWithSelector(ISubgraphService.SubgraphServiceIndexerMismatch.selector, users.indexer, newIndexer) ); @@ -193,4 +197,66 @@ contract SubgraphServiceRegisterTest is SubgraphServiceTest { ); subgraphService.collect(users.indexer, IGraphPayments.PaymentTypes.QueryFee, data); } + + function testCollect_QueryFees_PartialCollect( + uint256 tokensAllocated, + uint256 tokensPayment + ) public useIndexer useAllocation(tokensAllocated) { + vm.assume(tokensAllocated > minimumProvisionTokens * stakeToFeesRatio); + uint256 maxTokensPayment = tokensAllocated / stakeToFeesRatio > type(uint128).max + ? type(uint128).max + : tokensAllocated / stakeToFeesRatio; + tokensPayment = bound(tokensPayment, minimumProvisionTokens, maxTokensPayment); + + resetPrank(users.gateway); + _deposit(tokensPayment); + _authorizeSigner(); + + uint256 beforeGatewayBalance = escrow.getBalance(users.gateway, address(graphTallyCollector), users.indexer); + uint256 beforeTokensCollected = graphTallyCollector.tokensCollected( + address(subgraphService), + bytes32(uint256(uint160(allocationID))), + users.indexer, + users.gateway + ); + + // Collect the RAV in two steps + resetPrank(users.indexer); + uint256 tokensToCollect = tokensPayment / 2; + bool oddTokensPayment = tokensPayment % 2 == 1; + bytes memory data = _getQueryFeeEncodedData(users.indexer, uint128(tokensPayment), tokensToCollect); + _collect(users.indexer, IGraphPayments.PaymentTypes.QueryFee, data); + + uint256 intermediateGatewayBalance = escrow.getBalance( + users.gateway, + address(graphTallyCollector), + users.indexer + ); + assertEq(intermediateGatewayBalance, beforeGatewayBalance - tokensToCollect); + uint256 intermediateTokensCollected = graphTallyCollector.tokensCollected( + address(subgraphService), + bytes32(uint256(uint160(allocationID))), + users.indexer, + users.gateway + ); + assertEq(intermediateTokensCollected, beforeTokensCollected + tokensToCollect); + + bytes memory data2 = _getQueryFeeEncodedData( + users.indexer, + uint128(tokensPayment), + oddTokensPayment ? tokensToCollect + 1 : tokensToCollect + ); + _collect(users.indexer, IGraphPayments.PaymentTypes.QueryFee, data2); + + // Check the indexer received the correct amount of tokens + uint256 afterGatewayBalance = escrow.getBalance(users.gateway, address(graphTallyCollector), users.indexer); + assertEq(afterGatewayBalance, beforeGatewayBalance - tokensPayment); + uint256 afterTokensCollected = graphTallyCollector.tokensCollected( + address(subgraphService), + bytes32(uint256(uint160(allocationID))), + users.indexer, + users.gateway + ); + assertEq(afterTokensCollected, intermediateTokensCollected + tokensToCollect + (oddTokensPayment ? 1 : 0)); + } }