Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: Fix RefundFeesOnChannel #1029

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 14 additions & 27 deletions modules/apps/29-fee/ibc_module.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,21 +147,12 @@ func (im IBCModule) OnChanCloseInit(
return err
}

// delete fee enabled on channel
// and refund any remaining fees escrowed on channel
im.keeper.DeleteFeeEnabled(ctx, portID, channelID)
err := im.keeper.RefundFeesOnChannel(ctx, portID, channelID)
// error should only be non-nil if there is a bug in the code
// that causes module account to have insufficient funds to refund
// all escrowed fees on the channel.
// Disable all channels to allow for coordinated fix to the issue
// and mitigate/reverse damage.
// NOTE: Underlying application's packets will still go through, but
// fee module will be disabled for all channels
if err != nil {
im.keeper.DisableAllChannels(ctx)
if err := im.keeper.RefundFeesOnChannelClosure(ctx, portID, channelID); err != nil {
return err
}

im.keeper.DeleteFeeEnabled(ctx, portID, channelID)

return nil
}

Expand All @@ -171,21 +162,17 @@ func (im IBCModule) OnChanCloseConfirm(
portID,
channelID string,
) error {
// delete fee enabled on channel
// and refund any remaining fees escrowed on channel
im.keeper.DeleteFeeEnabled(ctx, portID, channelID)
err := im.keeper.RefundFeesOnChannel(ctx, portID, channelID)
// error should only be non-nil if there is a bug in the code
// that causes module account to have insufficient funds to refund
// all escrowed fees on the channel.
// Disable all channels to allow for coordinated fix to the issue
// and mitigate/reverse damage.
// NOTE: Underlying application's packets will still go through, but
// fee module will be disabled for all channels
if err != nil {
im.keeper.DisableAllChannels(ctx)
if err := im.app.OnChanCloseConfirm(ctx, portID, channelID); err != nil {
return nil
}
return im.app.OnChanCloseConfirm(ctx, portID, channelID)

if err := im.keeper.RefundFeesOnChannelClosure(ctx, portID, channelID); err != nil {
return err
}

im.keeper.DeleteFeeEnabled(ctx, portID, channelID)

return nil
}

// OnRecvPacket implements the IBCModule interface.
Expand Down
58 changes: 39 additions & 19 deletions modules/apps/29-fee/keeper/escrow.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,37 +109,57 @@ func (k Keeper) distributeFee(ctx sdk.Context, receiver sdk.AccAddress, fee sdk.
}
}

func (k Keeper) RefundFeesOnChannel(ctx sdk.Context, portID, channelID string) error {
// RefundFeesOnChannelClosure will refund all fees associated with the given port and channel identifiers.
// If the escrow account runs out of balance then fee logic will be disabled for all channels as this
// implies a severe bug.
func (k Keeper) RefundFeesOnChannelClosure(ctx sdk.Context, portID, channelID string) error {
identifiedPacketFees := k.GetIdentifiedPacketFeesForChannel(ctx, portID, channelID)

var refundErr error
// cache context before trying to distribute fees
// if the escrow account has insufficient balance then we want to avoid partially distributing fees
cacheCtx, writeFn := ctx.CacheContext()

for _, identifiedPacketFee := range identifiedPacketFees {
for _, packetFee := range identifiedPacketFee.PacketFees {

if !k.EscrowAccountHasBalance(cacheCtx, packetFee.Fee) {
// if the escrow account does not have sufficient funds then there must exist a severe bug
// the fee module should be locked until manual intervention fixes the issue
// a locked fee module will simply skip fee logic, all channels will temporarily function as
// fee disabled channels
// NOTE: we use the uncached context to lock the fee module so that the state changes from
// locking the fee module are persisted
lockFeeModule(ctx)

// return a nil error so state changes are committed but distribution stops
return nil
}

k.IteratePacketFeesInEscrow(ctx, portID, channelID, func(packetFees types.PacketFees) (stop bool) {
for _, identifiedFee := range packetFees.PacketFees {
refundAccAddr, err := sdk.AccAddressFromBech32(identifiedFee.RefundAddress)
refundAccAddr, err := sdk.AccAddressFromBech32(packetFee.RefundAddress)
if err != nil {
refundErr = err
return true
return err
}

// refund all fees to refund address
// Use SendCoins rather than the module account send functions since refund address may be a user account or module address.
// if any `SendCoins` call returns an error, we return error and stop iteration
if err = k.bankKeeper.SendCoinsFromModuleToAccount(ctx, types.ModuleName, refundAccAddr, identifiedFee.Fee.RecvFee); err != nil {
refundErr = err
return true
if err = k.bankKeeper.SendCoinsFromModuleToAccount(cacheCtx, types.ModuleName, refundAccAddr, packetFee.Fee.RecvFee); err != nil {
return err
}
if err = k.bankKeeper.SendCoinsFromModuleToAccount(ctx, types.ModuleName, refundAccAddr, identifiedFee.Fee.AckFee); err != nil {
refundErr = err
return true
if err = k.bankKeeper.SendCoinsFromModuleToAccount(cacheCtx, types.ModuleName, refundAccAddr, packetFee.Fee.AckFee); err != nil {
return err
}
if err = k.bankKeeper.SendCoinsFromModuleToAccount(ctx, types.ModuleName, refundAccAddr, identifiedFee.Fee.TimeoutFee); err != nil {
refundErr = err
return true
if err = k.bankKeeper.SendCoinsFromModuleToAccount(cacheCtx, types.ModuleName, refundAccAddr, packetFee.Fee.TimeoutFee); err != nil {
return err
}

}

return false
})
k.DeleteFeesInEscrow(cacheCtx, identifiedPacketFee.PacketId)
}

return refundErr
// write the cache
writeFn()

return nil
}
6 changes: 3 additions & 3 deletions modules/apps/29-fee/keeper/escrow_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ func (suite *KeeperTestSuite) TestDistributeTimeoutFee() {
suite.Require().True(hasBalance)
}

func (suite *KeeperTestSuite) TestRefundFeesOnChannel() {
func (suite *KeeperTestSuite) TestRefundFeesOnChannelClosure() {
suite.coordinator.Setup(suite.path)

// setup
Expand Down Expand Up @@ -296,7 +296,7 @@ func (suite *KeeperTestSuite) TestRefundFeesOnChannel() {
suite.Require().NoError(err)

// check that refunding all fees on channel-0 refunds all fees except for fee on channel-1
err = suite.chainA.GetSimApp().IBCFeeKeeper.RefundFeesOnChannel(suite.chainA.GetContext(), suite.path.EndpointA.ChannelConfig.PortID, suite.path.EndpointA.ChannelID)
err = suite.chainA.GetSimApp().IBCFeeKeeper.RefundFeesOnChannelClosure(suite.chainA.GetContext(), suite.path.EndpointA.ChannelConfig.PortID, suite.path.EndpointA.ChannelID)
suite.Require().NoError(err, "refund fees returned unexpected error")

// add fee sent to channel-1 to after balance to recover original balance
Expand All @@ -312,6 +312,6 @@ func (suite *KeeperTestSuite) TestRefundFeesOnChannel() {

suite.chainA.GetSimApp().BankKeeper.SendCoinsFromModuleToAccount(suite.chainA.GetContext(), types.ModuleName, refundAcc, fee.TimeoutFee)

err = suite.chainA.GetSimApp().IBCFeeKeeper.RefundFeesOnChannel(suite.chainA.GetContext(), suite.path.EndpointA.ChannelConfig.PortID, suite.path.EndpointA.ChannelID)
err = suite.chainA.GetSimApp().IBCFeeKeeper.RefundFeesOnChannelClosure(suite.chainA.GetContext(), suite.path.EndpointA.ChannelConfig.PortID, suite.path.EndpointA.ChannelID)
suite.Require().Error(err, "refund fees returned no error with insufficient balance on module account")
}
72 changes: 37 additions & 35 deletions modules/apps/29-fee/keeper/keeper.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,11 +72,35 @@ func (k Keeper) GetNextSequenceSend(ctx sdk.Context, portID, channelID string) (
return k.channelKeeper.GetNextSequenceSend(ctx, portID, channelID)
}

// GetFeeAccount returns the ICS29 Fee ModuleAccount address
// GetFeeModuleAddress returns the ICS29 Fee ModuleAccount address
func (k Keeper) GetFeeModuleAddress() sdk.AccAddress {
return k.authKeeper.GetModuleAddress(types.ModuleName)
}

// EscrowAccountHasBalance
func (k Keeper) EscrowAccountHasBalance(ctx sdk.Context, fee types.Fee) bool {
for _, coin := range fee.Total() {
if !k.bankKeeper.HasBalance(ctx, k.GetFeeModuleAddress(), coin) {
return false
}
}

return true
}

// lockFeeModule sets a flag to determine if fee handling logic should run for the given channel
// identified by channel and port identifiers.
func (k Keeper) lockFeeModule(ctx sdk.Context) {
store := ctx.KVStore(k.storeKey)
store.Set(types.KeyLocked(), []byte{1})
}

// IsLocked indicates if the fee module is locked
func (k Keeper) IsLocked(ctx sdk.Context) bool {
store := ctx.KVStore(k.storeKey)
return store.Has(types.KeyLocked())
}

// SetFeeEnabled sets a flag to determine if fee handling logic should run for the given channel
// identified by channel and port identifiers.
func (k Keeper) SetFeeEnabled(ctx sdk.Context, portID, channelID string) {
Expand Down Expand Up @@ -120,21 +144,6 @@ func (k Keeper) GetAllFeeEnabledChannels(ctx sdk.Context) []types.FeeEnabledChan
return enabledChArr
}

// DisableAllChannels will disable the fee module for all channels.
// Only called if the module enters into an invalid state
// e.g. ModuleAccount has insufficient balance to refund users.
// In this case, chain developers should investigate the issue, fix it,
// and then re-enable the fee module in a coordinated upgrade.
func (k Keeper) DisableAllChannels(ctx sdk.Context) {
store := ctx.KVStore(k.storeKey)
iterator := sdk.KVStorePrefixIterator(store, []byte(types.FeeEnabledKeyPrefix))

defer iterator.Close()
for ; iterator.Valid(); iterator.Next() {
store.Delete(iterator.Key())
}
}

// SetCounterpartyAddress maps the destination chain relayer address to the source relayer address
// The receiving chain must store the mapping from: address -> counterpartyAddress for the given channel
func (k Keeper) SetCounterpartyAddress(ctx sdk.Context, address, counterpartyAddress, channelID string) {
Expand Down Expand Up @@ -284,34 +293,27 @@ func (k Keeper) DeleteFeesInEscrow(ctx sdk.Context, packetID channeltypes.Packet
store.Delete(key)
}

// IteratePacketFeesInEscrow iterates over all the fees on the given channel currently escrowed and calls the provided callback
// if the callback returns true, then iteration is stopped.
func (k Keeper) IteratePacketFeesInEscrow(ctx sdk.Context, portID, channelID string, cb func(packetFees types.PacketFees) (stop bool)) {
// GetIdentifiedPacketFeesForChannel returns all the currently escrowed fees on a given channel.
func (k Keeper) GetIdentifiedPacketFeesForChannel(ctx sdk.Context, portID, channelID string) []types.IdentifiedPacketFees {
var identifiedPacketFees []types.IdentifiedPacketFees

store := ctx.KVStore(k.storeKey)
iterator := sdk.KVStorePrefixIterator(store, types.KeyFeesInEscrowChannelPrefix(portID, channelID))

defer iterator.Close()
for ; iterator.Valid(); iterator.Next() {
packetFees := k.MustUnmarshalFees(iterator.Value())
if cb(packetFees) {
break
packetId, err := types.ParseKeyFeesInEscrow(string(iterator.Key()))
if err != nil {
panic(err)
}
}
}

// IterateChannelFeesInEscrow iterates over all the fees on the given channel currently escrowed and calls the provided callback
// if the callback returns true, then iteration is stopped.
func (k Keeper) IterateChannelFeesInEscrow(ctx sdk.Context, portID, channelID string, cb func(identifiedFee types.IdentifiedPacketFee) (stop bool)) {
store := ctx.KVStore(k.storeKey)
iterator := sdk.KVStorePrefixIterator(store, types.KeyFeeInEscrowChannelPrefix(portID, channelID))
packetFees := k.MustUnmarshalFees(iterator.Value())

defer iterator.Close()
for ; iterator.Valid(); iterator.Next() {
identifiedFee := k.MustUnmarshalFee(iterator.Value())
if cb(identifiedFee) {
break
}
identifiedFee := types.NewIdentifiedPacketFees(packetId, packetFees.PacketFees)
identifiedPacketFees = append(identifiedPacketFees, identifiedFee)
}

return identifiedPacketFees
}

// Deletes the fee associated with the given packetId
Expand Down
29 changes: 7 additions & 22 deletions modules/apps/29-fee/keeper/keeper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,28 +83,13 @@ func (suite *KeeperTestSuite) TestFeeInEscrow() {
suite.chainA.GetSimApp().IBCFeeKeeper.DeleteFeeInEscrow(suite.chainA.GetContext(), packetId)

// iterate over remaining fees
arr := []int64{}
expectedArr := []int64{1, 2, 4, 5}
suite.chainA.GetSimApp().IBCFeeKeeper.IterateChannelFeesInEscrow(suite.chainA.GetContext(), suite.path.EndpointA.ChannelConfig.PortID, suite.path.EndpointA.ChannelID, func(identifiedFee types.IdentifiedPacketFee) (stop bool) {
arr = append(arr, int64(identifiedFee.PacketId.Sequence))
return false
})
suite.Require().Equal(expectedArr, arr, "did not retrieve expected fees during iteration")
}

func (suite *KeeperTestSuite) TestDisableAllChannels() {
suite.chainA.GetSimApp().IBCFeeKeeper.SetFeeEnabled(suite.chainA.GetContext(), "port1", "channel1")
suite.chainA.GetSimApp().IBCFeeKeeper.SetFeeEnabled(suite.chainA.GetContext(), "port2", "channel2")
suite.chainA.GetSimApp().IBCFeeKeeper.SetFeeEnabled(suite.chainA.GetContext(), "port3", "channel3")

suite.chainA.GetSimApp().IBCFeeKeeper.DisableAllChannels(suite.chainA.GetContext())

suite.Require().False(suite.chainA.GetSimApp().IBCFeeKeeper.IsFeeEnabled(suite.chainA.GetContext(), "port1", "channel1"),
"fee is still enabled on channel-1 after DisableAllChannels call")
suite.Require().False(suite.chainA.GetSimApp().IBCFeeKeeper.IsFeeEnabled(suite.chainA.GetContext(), "port2", "channel2"),
"fee is still enabled on channel-2 after DisableAllChannels call")
suite.Require().False(suite.chainA.GetSimApp().IBCFeeKeeper.IsFeeEnabled(suite.chainA.GetContext(), "port3", "channel3"),
"fee is still enabled on channel-3 after DisableAllChannels call")
//arr := []int64{}
//expectedArr := []int64{1, 2, 4, 5}
//suite.chainA.GetSimApp().IBCFeeKeeper.IterateChannelFeesInEscrow(suite.chainA.GetContext(), suite.path.EndpointA.ChannelConfig.PortID, suite.path.EndpointA.ChannelID, func(identifiedFee types.IdentifiedPacketFee) (stop bool) {
// arr = append(arr, int64(identifiedFee.PacketId.Sequence))
// return false
//})
//suite.Require().Equal(expectedArr, arr, "did not retrieve expected fees during iteration")
}

func (suite *KeeperTestSuite) TestGetAllIdentifiedPacketFees() {
Expand Down
8 changes: 8 additions & 0 deletions modules/apps/29-fee/keeper/msg_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ func (k Keeper) RegisterCounterpartyAddress(goCtx context.Context, msg *types.Ms
func (k Keeper) PayPacketFee(goCtx context.Context, msg *types.MsgPayPacketFee) (*types.MsgPayPacketFeeResponse, error) {
ctx := sdk.UnwrapSDKContext(goCtx)

if k.IsLocked(ctx) {
return nil, types.ErrFeeModuleLocked
}

// get the next sequence
sequence, found := k.GetNextSequenceSend(ctx, msg.SourcePortId, msg.SourceChannelId)
if !found {
Expand All @@ -57,6 +61,10 @@ func (k Keeper) PayPacketFee(goCtx context.Context, msg *types.MsgPayPacketFee)
func (k Keeper) PayPacketFeeAsync(goCtx context.Context, msg *types.MsgPayPacketFeeAsync) (*types.MsgPayPacketFeeAsyncResponse, error) {
ctx := sdk.UnwrapSDKContext(goCtx)

if k.IsLocked(ctx) {
return nil, types.ErrFeeModuleLocked
}

if err := k.EscrowPacketFee(ctx, msg.PacketId, msg.PacketFee); err != nil {
return nil, err
}
Expand Down
1 change: 1 addition & 0 deletions modules/apps/29-fee/types/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,5 @@ var (
ErrForwardRelayerAddressNotFound = sdkerrors.Register(ModuleName, 8, "forward relayer address not found")
ErrFeeNotEnabled = sdkerrors.Register(ModuleName, 9, "fee module is not enabled for this channel. If this error occurs after channel setup, fee module may not be enabled")
ErrRelayerNotFoundForAsyncAck = sdkerrors.Register(ModuleName, 10, "relayer address must be stored for async WriteAcknowledgement")
ErrFeeModuleLocked = sdkerrors.Register(ModuleName, 11, "the fee module is currently locked")
)
6 changes: 6 additions & 0 deletions modules/apps/29-fee/types/keys.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,12 @@ const (
AttributeKeyTimeoutFee = "timeout_fee"
)

// KeyLocked returns the key used to lock and unlock the fee module. This key is used
// in the presence of a severe bug.
func KeyLocked() []byte {
return []byte("locked")
}

// KeyFeeEnabled returns the key that stores a flag to determine if fee logic should
// be enabled for the given port and channel identifiers.
func KeyFeeEnabled(portID, channelID string) []byte {
Expand Down