diff --git a/modules/apps/29-fee/ibc_module.go b/modules/apps/29-fee/ibc_module.go index bdb3bcadabd..2edce45c6eb 100644 --- a/modules/apps/29-fee/ibc_module.go +++ b/modules/apps/29-fee/ibc_module.go @@ -147,21 +147,12 @@ func (im IBCModule) OnChanCloseInit( return err } - // delete fee enabled on channel - // and refund any remaining fees escrowed on channel - im.keeper.DeleteFeeEnabled(ctx, portID, channelID) - err := im.keeper.RefundFeesOnChannel(ctx, portID, channelID) - // error should only be non-nil if there is a bug in the code - // that causes module account to have insufficient funds to refund - // all escrowed fees on the channel. - // Disable all channels to allow for coordinated fix to the issue - // and mitigate/reverse damage. - // NOTE: Underlying application's packets will still go through, but - // fee module will be disabled for all channels - if err != nil { - im.keeper.DisableAllChannels(ctx) + if err := im.keeper.RefundFeesOnChannelClosure(ctx, portID, channelID); err != nil { + return err } + im.keeper.DeleteFeeEnabled(ctx, portID, channelID) + return nil } @@ -171,21 +162,17 @@ func (im IBCModule) OnChanCloseConfirm( portID, channelID string, ) error { - // delete fee enabled on channel - // and refund any remaining fees escrowed on channel - im.keeper.DeleteFeeEnabled(ctx, portID, channelID) - err := im.keeper.RefundFeesOnChannel(ctx, portID, channelID) - // error should only be non-nil if there is a bug in the code - // that causes module account to have insufficient funds to refund - // all escrowed fees on the channel. - // Disable all channels to allow for coordinated fix to the issue - // and mitigate/reverse damage. - // NOTE: Underlying application's packets will still go through, but - // fee module will be disabled for all channels - if err != nil { - im.keeper.DisableAllChannels(ctx) + if err := im.app.OnChanCloseConfirm(ctx, portID, channelID); err != nil { + return nil } - return im.app.OnChanCloseConfirm(ctx, portID, channelID) + + if err := im.keeper.RefundFeesOnChannelClosure(ctx, portID, channelID); err != nil { + return err + } + + im.keeper.DeleteFeeEnabled(ctx, portID, channelID) + + return nil } // OnRecvPacket implements the IBCModule interface. diff --git a/modules/apps/29-fee/keeper/escrow.go b/modules/apps/29-fee/keeper/escrow.go index d3e96bad312..671ee2447a0 100644 --- a/modules/apps/29-fee/keeper/escrow.go +++ b/modules/apps/29-fee/keeper/escrow.go @@ -109,37 +109,57 @@ func (k Keeper) distributeFee(ctx sdk.Context, receiver sdk.AccAddress, fee sdk. } } -func (k Keeper) RefundFeesOnChannel(ctx sdk.Context, portID, channelID string) error { +// RefundFeesOnChannelClosure will refund all fees associated with the given port and channel identifiers. +// If the escrow account runs out of balance then fee logic will be disabled for all channels as this +// implies a severe bug. +func (k Keeper) RefundFeesOnChannelClosure(ctx sdk.Context, portID, channelID string) error { + identifiedPacketFees := k.GetIdentifiedPacketFeesForChannel(ctx, portID, channelID) - var refundErr error + // cache context before trying to distribute fees + // if the escrow account has insufficient balance then we want to avoid partially distributing fees + cacheCtx, writeFn := ctx.CacheContext() + + for _, identifiedPacketFee := range identifiedPacketFees { + for _, packetFee := range identifiedPacketFee.PacketFees { + + if !k.EscrowAccountHasBalance(cacheCtx, packetFee.Fee) { + // if the escrow account does not have sufficient funds then there must exist a severe bug + // the fee module should be locked until manual intervention fixes the issue + // a locked fee module will simply skip fee logic, all channels will temporarily function as + // fee disabled channels + // NOTE: we use the uncached context to lock the fee module so that the state changes from + // locking the fee module are persisted + lockFeeModule(ctx) + + // return a nil error so state changes are committed but distribution stops + return nil + } - k.IteratePacketFeesInEscrow(ctx, portID, channelID, func(packetFees types.PacketFees) (stop bool) { - for _, identifiedFee := range packetFees.PacketFees { - refundAccAddr, err := sdk.AccAddressFromBech32(identifiedFee.RefundAddress) + refundAccAddr, err := sdk.AccAddressFromBech32(packetFee.RefundAddress) if err != nil { - refundErr = err - return true + return err } // refund all fees to refund address // Use SendCoins rather than the module account send functions since refund address may be a user account or module address. // if any `SendCoins` call returns an error, we return error and stop iteration - if err = k.bankKeeper.SendCoinsFromModuleToAccount(ctx, types.ModuleName, refundAccAddr, identifiedFee.Fee.RecvFee); err != nil { - refundErr = err - return true + if err = k.bankKeeper.SendCoinsFromModuleToAccount(cacheCtx, types.ModuleName, refundAccAddr, packetFee.Fee.RecvFee); err != nil { + return err } - if err = k.bankKeeper.SendCoinsFromModuleToAccount(ctx, types.ModuleName, refundAccAddr, identifiedFee.Fee.AckFee); err != nil { - refundErr = err - return true + if err = k.bankKeeper.SendCoinsFromModuleToAccount(cacheCtx, types.ModuleName, refundAccAddr, packetFee.Fee.AckFee); err != nil { + return err } - if err = k.bankKeeper.SendCoinsFromModuleToAccount(ctx, types.ModuleName, refundAccAddr, identifiedFee.Fee.TimeoutFee); err != nil { - refundErr = err - return true + if err = k.bankKeeper.SendCoinsFromModuleToAccount(cacheCtx, types.ModuleName, refundAccAddr, packetFee.Fee.TimeoutFee); err != nil { + return err } + } - return false - }) + k.DeleteFeesInEscrow(cacheCtx, identifiedPacketFee.PacketId) + } - return refundErr + // write the cache + writeFn() + + return nil } diff --git a/modules/apps/29-fee/keeper/escrow_test.go b/modules/apps/29-fee/keeper/escrow_test.go index 81103197aa5..8e613253352 100644 --- a/modules/apps/29-fee/keeper/escrow_test.go +++ b/modules/apps/29-fee/keeper/escrow_test.go @@ -259,7 +259,7 @@ func (suite *KeeperTestSuite) TestDistributeTimeoutFee() { suite.Require().True(hasBalance) } -func (suite *KeeperTestSuite) TestRefundFeesOnChannel() { +func (suite *KeeperTestSuite) TestRefundFeesOnChannelClosure() { suite.coordinator.Setup(suite.path) // setup @@ -296,7 +296,7 @@ func (suite *KeeperTestSuite) TestRefundFeesOnChannel() { suite.Require().NoError(err) // check that refunding all fees on channel-0 refunds all fees except for fee on channel-1 - err = suite.chainA.GetSimApp().IBCFeeKeeper.RefundFeesOnChannel(suite.chainA.GetContext(), suite.path.EndpointA.ChannelConfig.PortID, suite.path.EndpointA.ChannelID) + err = suite.chainA.GetSimApp().IBCFeeKeeper.RefundFeesOnChannelClosure(suite.chainA.GetContext(), suite.path.EndpointA.ChannelConfig.PortID, suite.path.EndpointA.ChannelID) suite.Require().NoError(err, "refund fees returned unexpected error") // add fee sent to channel-1 to after balance to recover original balance @@ -312,6 +312,6 @@ func (suite *KeeperTestSuite) TestRefundFeesOnChannel() { suite.chainA.GetSimApp().BankKeeper.SendCoinsFromModuleToAccount(suite.chainA.GetContext(), types.ModuleName, refundAcc, fee.TimeoutFee) - err = suite.chainA.GetSimApp().IBCFeeKeeper.RefundFeesOnChannel(suite.chainA.GetContext(), suite.path.EndpointA.ChannelConfig.PortID, suite.path.EndpointA.ChannelID) + err = suite.chainA.GetSimApp().IBCFeeKeeper.RefundFeesOnChannelClosure(suite.chainA.GetContext(), suite.path.EndpointA.ChannelConfig.PortID, suite.path.EndpointA.ChannelID) suite.Require().Error(err, "refund fees returned no error with insufficient balance on module account") } diff --git a/modules/apps/29-fee/keeper/keeper.go b/modules/apps/29-fee/keeper/keeper.go index 9b7583ea055..ec6884a96e7 100644 --- a/modules/apps/29-fee/keeper/keeper.go +++ b/modules/apps/29-fee/keeper/keeper.go @@ -72,11 +72,35 @@ func (k Keeper) GetNextSequenceSend(ctx sdk.Context, portID, channelID string) ( return k.channelKeeper.GetNextSequenceSend(ctx, portID, channelID) } -// GetFeeAccount returns the ICS29 Fee ModuleAccount address +// GetFeeModuleAddress returns the ICS29 Fee ModuleAccount address func (k Keeper) GetFeeModuleAddress() sdk.AccAddress { return k.authKeeper.GetModuleAddress(types.ModuleName) } +// EscrowAccountHasBalance +func (k Keeper) EscrowAccountHasBalance(ctx sdk.Context, fee types.Fee) bool { + for _, coin := range fee.Total() { + if !k.bankKeeper.HasBalance(ctx, k.GetFeeModuleAddress(), coin) { + return false + } + } + + return true +} + +// lockFeeModule sets a flag to determine if fee handling logic should run for the given channel +// identified by channel and port identifiers. +func (k Keeper) lockFeeModule(ctx sdk.Context) { + store := ctx.KVStore(k.storeKey) + store.Set(types.KeyLocked(), []byte{1}) +} + +// IsLocked indicates if the fee module is locked +func (k Keeper) IsLocked(ctx sdk.Context) bool { + store := ctx.KVStore(k.storeKey) + return store.Has(types.KeyLocked()) +} + // SetFeeEnabled sets a flag to determine if fee handling logic should run for the given channel // identified by channel and port identifiers. func (k Keeper) SetFeeEnabled(ctx sdk.Context, portID, channelID string) { @@ -120,21 +144,6 @@ func (k Keeper) GetAllFeeEnabledChannels(ctx sdk.Context) []types.FeeEnabledChan return enabledChArr } -// DisableAllChannels will disable the fee module for all channels. -// Only called if the module enters into an invalid state -// e.g. ModuleAccount has insufficient balance to refund users. -// In this case, chain developers should investigate the issue, fix it, -// and then re-enable the fee module in a coordinated upgrade. -func (k Keeper) DisableAllChannels(ctx sdk.Context) { - store := ctx.KVStore(k.storeKey) - iterator := sdk.KVStorePrefixIterator(store, []byte(types.FeeEnabledKeyPrefix)) - - defer iterator.Close() - for ; iterator.Valid(); iterator.Next() { - store.Delete(iterator.Key()) - } -} - // SetCounterpartyAddress maps the destination chain relayer address to the source relayer address // The receiving chain must store the mapping from: address -> counterpartyAddress for the given channel func (k Keeper) SetCounterpartyAddress(ctx sdk.Context, address, counterpartyAddress, channelID string) { @@ -284,34 +293,27 @@ func (k Keeper) DeleteFeesInEscrow(ctx sdk.Context, packetID channeltypes.Packet store.Delete(key) } -// IteratePacketFeesInEscrow iterates over all the fees on the given channel currently escrowed and calls the provided callback -// if the callback returns true, then iteration is stopped. -func (k Keeper) IteratePacketFeesInEscrow(ctx sdk.Context, portID, channelID string, cb func(packetFees types.PacketFees) (stop bool)) { +// GetIdentifiedPacketFeesForChannel returns all the currently escrowed fees on a given channel. +func (k Keeper) GetIdentifiedPacketFeesForChannel(ctx sdk.Context, portID, channelID string) []types.IdentifiedPacketFees { + var identifiedPacketFees []types.IdentifiedPacketFees + store := ctx.KVStore(k.storeKey) iterator := sdk.KVStorePrefixIterator(store, types.KeyFeesInEscrowChannelPrefix(portID, channelID)) defer iterator.Close() for ; iterator.Valid(); iterator.Next() { - packetFees := k.MustUnmarshalFees(iterator.Value()) - if cb(packetFees) { - break + packetId, err := types.ParseKeyFeesInEscrow(string(iterator.Key())) + if err != nil { + panic(err) } - } -} -// IterateChannelFeesInEscrow iterates over all the fees on the given channel currently escrowed and calls the provided callback -// if the callback returns true, then iteration is stopped. -func (k Keeper) IterateChannelFeesInEscrow(ctx sdk.Context, portID, channelID string, cb func(identifiedFee types.IdentifiedPacketFee) (stop bool)) { - store := ctx.KVStore(k.storeKey) - iterator := sdk.KVStorePrefixIterator(store, types.KeyFeeInEscrowChannelPrefix(portID, channelID)) + packetFees := k.MustUnmarshalFees(iterator.Value()) - defer iterator.Close() - for ; iterator.Valid(); iterator.Next() { - identifiedFee := k.MustUnmarshalFee(iterator.Value()) - if cb(identifiedFee) { - break - } + identifiedFee := types.NewIdentifiedPacketFees(packetId, packetFees.PacketFees) + identifiedPacketFees = append(identifiedPacketFees, identifiedFee) } + + return identifiedPacketFees } // Deletes the fee associated with the given packetId diff --git a/modules/apps/29-fee/keeper/keeper_test.go b/modules/apps/29-fee/keeper/keeper_test.go index c9ad4a5f10b..973716a7cb6 100644 --- a/modules/apps/29-fee/keeper/keeper_test.go +++ b/modules/apps/29-fee/keeper/keeper_test.go @@ -83,28 +83,13 @@ func (suite *KeeperTestSuite) TestFeeInEscrow() { suite.chainA.GetSimApp().IBCFeeKeeper.DeleteFeeInEscrow(suite.chainA.GetContext(), packetId) // iterate over remaining fees - arr := []int64{} - expectedArr := []int64{1, 2, 4, 5} - suite.chainA.GetSimApp().IBCFeeKeeper.IterateChannelFeesInEscrow(suite.chainA.GetContext(), suite.path.EndpointA.ChannelConfig.PortID, suite.path.EndpointA.ChannelID, func(identifiedFee types.IdentifiedPacketFee) (stop bool) { - arr = append(arr, int64(identifiedFee.PacketId.Sequence)) - return false - }) - suite.Require().Equal(expectedArr, arr, "did not retrieve expected fees during iteration") -} - -func (suite *KeeperTestSuite) TestDisableAllChannels() { - suite.chainA.GetSimApp().IBCFeeKeeper.SetFeeEnabled(suite.chainA.GetContext(), "port1", "channel1") - suite.chainA.GetSimApp().IBCFeeKeeper.SetFeeEnabled(suite.chainA.GetContext(), "port2", "channel2") - suite.chainA.GetSimApp().IBCFeeKeeper.SetFeeEnabled(suite.chainA.GetContext(), "port3", "channel3") - - suite.chainA.GetSimApp().IBCFeeKeeper.DisableAllChannels(suite.chainA.GetContext()) - - suite.Require().False(suite.chainA.GetSimApp().IBCFeeKeeper.IsFeeEnabled(suite.chainA.GetContext(), "port1", "channel1"), - "fee is still enabled on channel-1 after DisableAllChannels call") - suite.Require().False(suite.chainA.GetSimApp().IBCFeeKeeper.IsFeeEnabled(suite.chainA.GetContext(), "port2", "channel2"), - "fee is still enabled on channel-2 after DisableAllChannels call") - suite.Require().False(suite.chainA.GetSimApp().IBCFeeKeeper.IsFeeEnabled(suite.chainA.GetContext(), "port3", "channel3"), - "fee is still enabled on channel-3 after DisableAllChannels call") + //arr := []int64{} + //expectedArr := []int64{1, 2, 4, 5} + //suite.chainA.GetSimApp().IBCFeeKeeper.IterateChannelFeesInEscrow(suite.chainA.GetContext(), suite.path.EndpointA.ChannelConfig.PortID, suite.path.EndpointA.ChannelID, func(identifiedFee types.IdentifiedPacketFee) (stop bool) { + // arr = append(arr, int64(identifiedFee.PacketId.Sequence)) + // return false + //}) + //suite.Require().Equal(expectedArr, arr, "did not retrieve expected fees during iteration") } func (suite *KeeperTestSuite) TestGetAllIdentifiedPacketFees() { diff --git a/modules/apps/29-fee/keeper/msg_server.go b/modules/apps/29-fee/keeper/msg_server.go index fdac1d27874..c8ffb853799 100644 --- a/modules/apps/29-fee/keeper/msg_server.go +++ b/modules/apps/29-fee/keeper/msg_server.go @@ -31,6 +31,10 @@ func (k Keeper) RegisterCounterpartyAddress(goCtx context.Context, msg *types.Ms func (k Keeper) PayPacketFee(goCtx context.Context, msg *types.MsgPayPacketFee) (*types.MsgPayPacketFeeResponse, error) { ctx := sdk.UnwrapSDKContext(goCtx) + if k.IsLocked(ctx) { + return nil, types.ErrFeeModuleLocked + } + // get the next sequence sequence, found := k.GetNextSequenceSend(ctx, msg.SourcePortId, msg.SourceChannelId) if !found { @@ -57,6 +61,10 @@ func (k Keeper) PayPacketFee(goCtx context.Context, msg *types.MsgPayPacketFee) func (k Keeper) PayPacketFeeAsync(goCtx context.Context, msg *types.MsgPayPacketFeeAsync) (*types.MsgPayPacketFeeAsyncResponse, error) { ctx := sdk.UnwrapSDKContext(goCtx) + if k.IsLocked(ctx) { + return nil, types.ErrFeeModuleLocked + } + if err := k.EscrowPacketFee(ctx, msg.PacketId, msg.PacketFee); err != nil { return nil, err } diff --git a/modules/apps/29-fee/types/errors.go b/modules/apps/29-fee/types/errors.go index 75fdd436c91..1e35c92e2b1 100644 --- a/modules/apps/29-fee/types/errors.go +++ b/modules/apps/29-fee/types/errors.go @@ -15,4 +15,5 @@ var ( ErrForwardRelayerAddressNotFound = sdkerrors.Register(ModuleName, 8, "forward relayer address not found") ErrFeeNotEnabled = sdkerrors.Register(ModuleName, 9, "fee module is not enabled for this channel. If this error occurs after channel setup, fee module may not be enabled") ErrRelayerNotFoundForAsyncAck = sdkerrors.Register(ModuleName, 10, "relayer address must be stored for async WriteAcknowledgement") + ErrFeeModuleLocked = sdkerrors.Register(ModuleName, 11, "the fee module is currently locked") ) diff --git a/modules/apps/29-fee/types/keys.go b/modules/apps/29-fee/types/keys.go index f39cd9b5d44..a93fec7c79f 100644 --- a/modules/apps/29-fee/types/keys.go +++ b/modules/apps/29-fee/types/keys.go @@ -45,6 +45,12 @@ const ( AttributeKeyTimeoutFee = "timeout_fee" ) +// KeyLocked returns the key used to lock and unlock the fee module. This key is used +// in the presence of a severe bug. +func KeyLocked() []byte { + return []byte("locked") +} + // KeyFeeEnabled returns the key that stores a flag to determine if fee logic should // be enabled for the given port and channel identifiers. func KeyFeeEnabled(portID, channelID string) []byte {