diff --git a/modules/core/04-channel/keeper/upgrade.go b/modules/core/04-channel/keeper/upgrade.go index 51c2747aadc..de7905bc2ca 100644 --- a/modules/core/04-channel/keeper/upgrade.go +++ b/modules/core/04-channel/keeper/upgrade.go @@ -596,6 +596,94 @@ func (k Keeper) startFlushUpgradeHandshake( return nil } +// ChanUpgradeOpen is called by a module to complete the channel upgrade handshake and move the channel back to an OPEN state. +// This method should only be called after both channels have flushed any in-flight packets. +func (k Keeper) ChanUpgradeOpen( + ctx sdk.Context, + portID, + channelID string, + counterpartyChannelState types.State, + proofChannel []byte, + proofHeight clienttypes.Height, +) error { + if k.hasInflightPackets(ctx, portID, channelID) { + return errorsmod.Wrapf(types.ErrPendingInflightPackets, "port ID (%s) channel ID (%s)", portID, channelID) + } + + channel, found := k.GetChannel(ctx, portID, channelID) + if !found { + return errorsmod.Wrapf(types.ErrChannelNotFound, "port ID (%s) channel ID (%s)", portID, channelID) + } + + if !collections.Contains(channel.State, []types.State{types.TRYUPGRADE, types.ACKUPGRADE}) { + return errorsmod.Wrapf(types.ErrInvalidChannelState, "expected one of [%s, %s], got %s", types.TRYUPGRADE, types.ACKUPGRADE, channel.State) + } + + if channel.FlushStatus != types.FLUSHCOMPLETE { + return errorsmod.Wrapf(types.ErrInvalidFlushStatus, "expected %s, got %s", types.FLUSHCOMPLETE, channel.FlushStatus) + } + + connection, err := k.GetConnection(ctx, channel.ConnectionHops[0]) + if err != nil { + return errorsmod.Wrap(err, "failed to retrieve connection using the channel connection hops") + } + + if connection.GetState() != int32(connectiontypes.OPEN) { + return errorsmod.Wrapf(connectiontypes.ErrInvalidConnectionState, "connection state is not OPEN (got %s)", connectiontypes.State(connection.GetState()).String()) + } + + var counterpartyChannel types.Channel + counterpartyHops := []string{connection.GetCounterparty().GetConnectionID()} + switch counterpartyChannelState { + case types.OPEN: + upgrade, found := k.GetUpgrade(ctx, portID, channelID) + if !found { + return errorsmod.Wrapf(types.ErrUpgradeNotFound, "failed to retrieve channel upgrade: port ID (%s) channel ID (%s)", portID, channelID) + } + counterpartyChannel = types.Channel{ + State: types.OPEN, + Ordering: channel.Ordering, + ConnectionHops: upgrade.Fields.ConnectionHops, + Counterparty: types.NewCounterparty(portID, channelID), + Version: upgrade.Fields.GetVersion(), + UpgradeSequence: channel.UpgradeSequence, + FlushStatus: types.NOTINFLUSH, + } + + case types.TRYUPGRADE: + // If the counterparty is in TRYUPGRADE, then we must have gone through the ACKUPGRADE step. + if channel.State != types.ACKUPGRADE { + return errorsmod.Wrapf(types.ErrInvalidChannelState, "expected %s, got %s", types.ACKUPGRADE, channel.State) + } + counterpartyChannel = types.Channel{ + State: types.TRYUPGRADE, + Ordering: channel.Ordering, + ConnectionHops: counterpartyHops, + Counterparty: types.NewCounterparty(portID, channelID), + Version: channel.Version, + UpgradeSequence: channel.UpgradeSequence, + FlushStatus: types.FLUSHCOMPLETE, + } + + case types.ACKUPGRADE: + counterpartyChannel = types.Channel{ + State: types.ACKUPGRADE, + Ordering: channel.Ordering, + ConnectionHops: counterpartyHops, + Counterparty: types.NewCounterparty(portID, channelID), + Version: channel.Version, + UpgradeSequence: channel.UpgradeSequence, + FlushStatus: types.FLUSHCOMPLETE, + } + default: + panic(fmt.Sprintf("counterparty channel state should be in one of [%s, %s, %s]; got %s", types.TRYUPGRADE, types.ACKUPGRADE, types.OPEN, counterpartyChannelState)) + } + + return k.connectionKeeper.VerifyChannelState( + ctx, connection, proofHeight, proofChannel, portID, channelID, counterpartyChannel, + ) +} + // WriteUpgradeOpenChannel writes the agreed upon upgrade fields to the channel, sets the channel flush status to NOTINFLUSH and sets the channel state back to OPEN. This can be called in one of two cases: // - In the UpgradeAck step of the handshake if both sides have already flushed all in-flight packets. // - In the UpgradeOpen step of the handshake. diff --git a/modules/core/04-channel/types/errors.go b/modules/core/04-channel/types/errors.go index 50f3e9f8a0b..e1efa4e369e 100644 --- a/modules/core/04-channel/types/errors.go +++ b/modules/core/04-channel/types/errors.go @@ -50,4 +50,5 @@ var ( ErrUpgradeRestoreFailed = errorsmod.Register(SubModuleName, 34, "restore failed") ErrUpgradeTimeout = errorsmod.Register(SubModuleName, 35, "upgrade timed-out") ErrInvalidUpgradeTimeout = errorsmod.Register(SubModuleName, 36, "upgrade timeout is invalid") + ErrPendingInflightPackets = errorsmod.Register(SubModuleName, 37, "pending inflight packets exist") )