diff --git a/modules/apps/29-fee/ibc_module_test.go b/modules/apps/29-fee/ibc_module_test.go index 2a54fc03ec3..4ca50d2c612 100644 --- a/modules/apps/29-fee/ibc_module_test.go +++ b/modules/apps/29-fee/ibc_module_test.go @@ -631,7 +631,7 @@ func (suite *FeeTestSuite) TestOnAcknowledgementPacket() { }.Acknowledgement() expectedRelayerBalance = packetFee.Fee.AckFee - expectedBalance = expectedBalance.Add(packetFee.Fee.RecvFee[0]) + expectedBalance = expectedBalance.Add(packetFee.Fee.RecvFee...) }, true, }, @@ -759,8 +759,9 @@ func (suite *FeeTestSuite) TestOnTimeoutPacket() { relayerAddr = suite.chainA.GetSimApp().AccountKeeper.GetModuleAccount(suite.chainA.GetContext(), transfertypes.ModuleName).GetAddress() expectedBalance = originalBalance. - Add(packetFee.Fee.RecvFee[0]). - Add(packetFee.Fee.AckFee[0]) + Add(packetFee.Fee.RecvFee...). + Add(packetFee.Fee.AckFee...). + Add(packetFee.Fee.TimeoutFee...) }, false, }, diff --git a/modules/apps/29-fee/keeper/escrow.go b/modules/apps/29-fee/keeper/escrow.go index 40aad06e3be..f0c18e379ba 100644 --- a/modules/apps/29-fee/keeper/escrow.go +++ b/modules/apps/29-fee/keeper/escrow.go @@ -1,6 +1,7 @@ package keeper import ( + "bytes" "fmt" sdk "github.com/cosmos/cosmos-sdk/types" @@ -61,17 +62,17 @@ func (k Keeper) DistributePacketFees(ctx sdk.Context, forwardRelayer string, rev // distribute fee to valid forward relayer address otherwise refund the fee if !forwardAddr.Empty() && !k.bankKeeper.BlockedAddr(forwardAddr) { // distribute fee for forward relaying - k.distributeFee(ctx, forwardAddr, packetFee.Fee.RecvFee) + k.distributeFee(ctx, forwardAddr, refundAddr, packetFee.Fee.RecvFee) } else { // refund onRecv fee as forward relayer is not valid address - k.distributeFee(ctx, refundAddr, packetFee.Fee.RecvFee) + k.distributeFee(ctx, refundAddr, refundAddr, packetFee.Fee.RecvFee) } // distribute fee for reverse relaying - k.distributeFee(ctx, reverseRelayer, packetFee.Fee.AckFee) + k.distributeFee(ctx, reverseRelayer, refundAddr, packetFee.Fee.AckFee) // refund timeout fee for unused timeout - k.distributeFee(ctx, refundAddr, packetFee.Fee.TimeoutFee) + k.distributeFee(ctx, refundAddr, refundAddr, packetFee.Fee.TimeoutFee) } } @@ -85,31 +86,42 @@ func (k Keeper) DistributePacketFeesOnTimeout(ctx sdk.Context, timeoutRelayer sd } // refund receive fee for unused forward relaying - k.distributeFee(ctx, refundAddr, feeInEscrow.Fee.RecvFee) + k.distributeFee(ctx, refundAddr, refundAddr, feeInEscrow.Fee.RecvFee) // refund ack fee for unused reverse relaying - k.distributeFee(ctx, refundAddr, feeInEscrow.Fee.AckFee) + k.distributeFee(ctx, refundAddr, refundAddr, feeInEscrow.Fee.AckFee) // distribute fee for timeout relaying - k.distributeFee(ctx, timeoutRelayer, feeInEscrow.Fee.TimeoutFee) + k.distributeFee(ctx, timeoutRelayer, refundAddr, feeInEscrow.Fee.TimeoutFee) } } // distributeFee will attempt to distribute the escrowed fee to the receiver address. // If the distribution fails for any reason (such as the receiving address being blocked), // the state changes will be discarded. -func (k Keeper) distributeFee(ctx sdk.Context, receiver sdk.AccAddress, fee sdk.Coins) { +func (k Keeper) distributeFee(ctx sdk.Context, receiver, refundAccAddress sdk.AccAddress, fee sdk.Coins) { // cache context before trying to distribute fees cacheCtx, writeFn := ctx.CacheContext() err := k.bankKeeper.SendCoinsFromModuleToAccount(cacheCtx, types.ModuleName, receiver, fee) - if err == nil { - // write the cache - writeFn() + if err != nil { + if bytes.Equal(receiver, refundAccAddress) { + return // if sending to the refund address already failed, then return (no-op) + } - // NOTE: The context returned by CacheContext() refers to a new EventManager, so it needs to explicitly set events to the original context. - ctx.EventManager().EmitEvents(cacheCtx.EventManager().Events()) + // if an error is returned from x/bank and the receiver is not the refundAccAddress + // then attempt to refund the fee to the original sender + err := k.bankKeeper.SendCoinsFromModuleToAccount(cacheCtx, types.ModuleName, refundAccAddress, fee) + if err != nil { + return // if sending to the refund address fails, no-op + } } + + // write the cache + writeFn() + + // NOTE: The context returned by CacheContext() refers to a new EventManager, so it needs to explicitly set events to the original context. + ctx.EventManager().EmitEvents(cacheCtx.EventManager().Events()) } func (k Keeper) RefundFeesOnChannel(ctx sdk.Context, portID, channelID string) error { diff --git a/modules/apps/29-fee/keeper/escrow_test.go b/modules/apps/29-fee/keeper/escrow_test.go index b03da9f12ca..052a8cdc9c1 100644 --- a/modules/apps/29-fee/keeper/escrow_test.go +++ b/modules/apps/29-fee/keeper/escrow_test.go @@ -122,64 +122,93 @@ func (suite *KeeperTestSuite) TestEscrowPacketFee() { func (suite *KeeperTestSuite) TestDistributeFee() { var ( - reverseRelayer sdk.AccAddress - forwardRelayer string - refundAcc sdk.AccAddress - refundAccBal sdk.Coin - fee types.Fee - packetID channeltypes.PacketId + forwardRelayer string + forwardRelayerBal sdk.Coin + reverseRelayer sdk.AccAddress + reverseRelayerBal sdk.Coin + refundAcc sdk.AccAddress + refundAccBal sdk.Coin + packetFee types.PacketFee ) - validSeq := uint64(1) - testCases := []struct { name string malleate func() expResult func() }{ { - "success", func() {}, func() { + "success", + func() {}, + func() { // check if the reverse relayer is paid - hasBalance := suite.chainA.GetSimApp().BankKeeper.HasBalance(suite.chainA.GetContext(), reverseRelayer, fee.AckFee[0].Add(fee.AckFee[0])) - suite.Require().True(hasBalance) + expectedReverseAccBal := reverseRelayerBal.Add(defaultAckFee[0]).Add(defaultAckFee[0]) + balance := suite.chainA.GetSimApp().BankKeeper.GetBalance(suite.chainA.GetContext(), reverseRelayer, sdk.DefaultBondDenom) + suite.Require().Equal(expectedReverseAccBal, balance) // check if the forward relayer is paid forward, err := sdk.AccAddressFromBech32(forwardRelayer) suite.Require().NoError(err) - hasBalance = suite.chainA.GetSimApp().BankKeeper.HasBalance(suite.chainA.GetContext(), forward, fee.RecvFee[0].Add(fee.RecvFee[0])) - suite.Require().True(hasBalance) + + expectedForwardAccBal := forwardRelayerBal.Add(defaultReceiveFee[0]).Add(defaultReceiveFee[0]) + balance = suite.chainA.GetSimApp().BankKeeper.GetBalance(suite.chainA.GetContext(), forward, sdk.DefaultBondDenom) + suite.Require().Equal(expectedForwardAccBal, balance) // check if the refund acc has been refunded the timeoutFee - expectedRefundAccBal := refundAccBal.Add(fee.TimeoutFee[0].Add(fee.TimeoutFee[0])) - hasBalance = suite.chainA.GetSimApp().BankKeeper.HasBalance(suite.chainA.GetContext(), refundAcc, expectedRefundAccBal) - suite.Require().True(hasBalance) + expectedRefundAccBal := refundAccBal.Add(defaultTimeoutFee[0].Add(defaultTimeoutFee[0])) + balance = suite.chainA.GetSimApp().BankKeeper.GetBalance(suite.chainA.GetContext(), refundAcc, sdk.DefaultBondDenom) + suite.Require().Equal(expectedRefundAccBal, balance) // check the module acc wallet is now empty - hasBalance = suite.chainA.GetSimApp().BankKeeper.HasBalance(suite.chainA.GetContext(), suite.chainA.GetSimApp().IBCFeeKeeper.GetFeeModuleAddress(), sdk.Coin{Denom: sdk.DefaultBondDenom, Amount: sdk.NewInt(0)}) - suite.Require().True(hasBalance) + balance = suite.chainA.GetSimApp().BankKeeper.GetBalance(suite.chainA.GetContext(), suite.chainA.GetSimApp().IBCFeeKeeper.GetFeeModuleAddress(), sdk.DefaultBondDenom) + suite.Require().Equal(sdk.NewCoin(sdk.DefaultBondDenom, sdk.NewInt(0)), balance) }, }, { - "invalid forward address", func() { + "invalid forward address", + func() { forwardRelayer = "invalid address" }, func() { - // check if the refund acc has been refunded the timeoutFee & onRecvFee - expectedRefundAccBal := refundAccBal.Add(fee.TimeoutFee[0]).Add(fee.RecvFee[0]).Add(fee.TimeoutFee[0]).Add(fee.RecvFee[0]) - hasBalance := suite.chainA.GetSimApp().BankKeeper.HasBalance(suite.chainA.GetContext(), refundAcc, expectedRefundAccBal) - suite.Require().True(hasBalance) - + // check if the refund acc has been refunded the timeoutFee & recvFee + expectedRefundAccBal := refundAccBal.Add(defaultTimeoutFee[0]).Add(defaultReceiveFee[0]).Add(defaultTimeoutFee[0]).Add(defaultReceiveFee[0]) + balance := suite.chainA.GetSimApp().BankKeeper.GetBalance(suite.chainA.GetContext(), refundAcc, sdk.DefaultBondDenom) + suite.Require().Equal(expectedRefundAccBal, balance) }, }, { - "invalid forward address: blocked address", func() { + "invalid forward address: blocked address", + func() { forwardRelayer = suite.chainA.GetSimApp().AccountKeeper.GetModuleAccount(suite.chainA.GetContext(), transfertypes.ModuleName).GetAddress().String() }, func() { - // check if the refund acc has been refunded the timeoutFee & onRecvFee - expectedRefundAccBal := refundAccBal.Add(fee.TimeoutFee[0]).Add(fee.RecvFee[0]).Add(fee.TimeoutFee[0]).Add(fee.RecvFee[0]) - hasBalance := suite.chainA.GetSimApp().BankKeeper.HasBalance(suite.chainA.GetContext(), refundAcc, expectedRefundAccBal) - suite.Require().True(hasBalance) + // check if the refund acc has been refunded the timeoutFee & recvFee + expectedRefundAccBal := refundAccBal.Add(defaultTimeoutFee[0]).Add(defaultReceiveFee[0]).Add(defaultTimeoutFee[0]).Add(defaultReceiveFee[0]) + balance := suite.chainA.GetSimApp().BankKeeper.GetBalance(suite.chainA.GetContext(), refundAcc, sdk.DefaultBondDenom) + suite.Require().Equal(expectedRefundAccBal, balance) + }, + }, + { + "invalid receiver address: ack fee returned to sender", + func() { + reverseRelayer = suite.chainA.GetSimApp().AccountKeeper.GetModuleAccount(suite.chainA.GetContext(), transfertypes.ModuleName).GetAddress() + }, + func() { + // check if the refund acc has been refunded the timeoutFee & ackFee + expectedRefundAccBal := refundAccBal.Add(defaultTimeoutFee[0]).Add(defaultAckFee[0]).Add(defaultTimeoutFee[0]).Add(defaultAckFee[0]) + balance := suite.chainA.GetSimApp().BankKeeper.GetBalance(suite.chainA.GetContext(), refundAcc, sdk.DefaultBondDenom) + suite.Require().Equal(expectedRefundAccBal, balance) + }, + }, + { + "invalid refund address: no-op, timeout fee remains in escrow", + func() { + packetFee.RefundAddress = suite.chainA.GetSimApp().AccountKeeper.GetModuleAccount(suite.chainA.GetContext(), transfertypes.ModuleName).GetAddress().String() + }, + func() { + // check if the module acc contains the timeoutFee + expectedModuleAccBal := sdk.NewCoin(sdk.DefaultBondDenom, defaultTimeoutFee.Add(defaultTimeoutFee...).AmountOf(sdk.DefaultBondDenom)) + balance := suite.chainA.GetSimApp().BankKeeper.GetBalance(suite.chainA.GetContext(), suite.chainA.GetSimApp().IBCFeeKeeper.GetFeeModuleAddress(), sdk.DefaultBondDenom) + suite.Require().Equal(expectedModuleAccBal, balance) }, }, } @@ -191,30 +220,29 @@ func (suite *KeeperTestSuite) TestDistributeFee() { suite.SetupTest() // reset suite.coordinator.Setup(suite.path) // setup channel - // setup - refundAcc = suite.chainA.SenderAccount.GetAddress() - reverseRelayer = sdk.AccAddress(secp256k1.GenPrivKey().PubKey().Address()) + // setup accounts forwardRelayer = sdk.AccAddress(secp256k1.GenPrivKey().PubKey().Address()).String() + reverseRelayer = sdk.AccAddress(secp256k1.GenPrivKey().PubKey().Address()) + refundAcc = suite.chainA.SenderAccount.GetAddress() - packetID = channeltypes.NewPacketId(suite.path.EndpointA.ChannelConfig.PortID, suite.path.EndpointA.ChannelID, validSeq) - fee = types.Fee{ - RecvFee: defaultReceiveFee, - AckFee: defaultAckFee, - TimeoutFee: defaultTimeoutFee, - } + packetID := channeltypes.NewPacketId(suite.path.EndpointA.ChannelConfig.PortID, suite.path.EndpointA.ChannelID, 1) + fee := types.NewFee(defaultReceiveFee, defaultAckFee, defaultTimeoutFee) // escrow the packet fee & store the fee in state - packetFee := types.NewPacketFee(fee, refundAcc.String(), []string{}) - + packetFee = types.NewPacketFee(fee, refundAcc.String(), []string{}) err := suite.chainA.GetSimApp().IBCFeeKeeper.EscrowPacketFee(suite.chainA.GetContext(), packetID, packetFee) suite.Require().NoError(err) + // escrow a second packet fee to test with multiple fees distributed err = suite.chainA.GetSimApp().IBCFeeKeeper.EscrowPacketFee(suite.chainA.GetContext(), packetID, packetFee) suite.Require().NoError(err) tc.malleate() - // refundAcc balance after escrow + // fetch the account balances before fee distribution (forward, reverse, refund) + forwardAccAddress, _ := sdk.AccAddressFromBech32(forwardRelayer) + forwardRelayerBal = suite.chainA.GetSimApp().BankKeeper.GetBalance(suite.chainA.GetContext(), forwardAccAddress, sdk.DefaultBondDenom) + reverseRelayerBal = suite.chainA.GetSimApp().BankKeeper.GetBalance(suite.chainA.GetContext(), reverseRelayer, sdk.DefaultBondDenom) refundAccBal = suite.chainA.GetSimApp().BankKeeper.GetBalance(suite.chainA.GetContext(), refundAcc, sdk.DefaultBondDenom) suite.chainA.GetSimApp().IBCFeeKeeper.DistributePacketFees(suite.chainA.GetContext(), forwardRelayer, reverseRelayer, []types.PacketFee{packetFee, packetFee})