From e2c83264c5028a90850c7e822f583427486cbea2 Mon Sep 17 00:00:00 2001 From: Cian Hatton Date: Thu, 17 Aug 2023 14:48:39 +0000 Subject: [PATCH] Channel Upgrade Ack (#4372) --- modules/core/04-channel/keeper/export_test.go | 4 +- modules/core/04-channel/keeper/upgrade.go | 49 +++++------ .../core/04-channel/keeper/upgrade_test.go | 84 ++++++++----------- 3 files changed, 58 insertions(+), 79 deletions(-) diff --git a/modules/core/04-channel/keeper/export_test.go b/modules/core/04-channel/keeper/export_test.go index cc8a606f760..76eb21dac45 100644 --- a/modules/core/04-channel/keeper/export_test.go +++ b/modules/core/04-channel/keeper/export_test.go @@ -11,8 +11,8 @@ import ( ) // StartFlushing is a wrapper around startFlushing to allow the function to be directly called in tests. -func (k Keeper) StartFlushing(ctx sdk.Context, portID, channelID string) error { - return k.startFlushing(ctx, portID, channelID) +func (k Keeper) StartFlushing(ctx sdk.Context, portID, channelID string, upgrade *types.Upgrade) error { + return k.startFlushing(ctx, portID, channelID, upgrade) } // ValidateSelfUpgradeFields is a wrapper around validateSelfUpgradeFields to allow the function to be directly called in tests. diff --git a/modules/core/04-channel/keeper/upgrade.go b/modules/core/04-channel/keeper/upgrade.go index 3c24e0826eb..dfabc5e9c47 100644 --- a/modules/core/04-channel/keeper/upgrade.go +++ b/modules/core/04-channel/keeper/upgrade.go @@ -177,7 +177,7 @@ func (k Keeper) ChanUpgradeTry( return types.Upgrade{}, errorsmod.Wrap(err, "failed to verify counterparty upgrade") } - if err := k.startFlushing(ctx, portID, channelID); err != nil { + if err := k.startFlushing(ctx, portID, channelID, &upgrade); err != nil { return types.Upgrade{}, err } @@ -242,7 +242,6 @@ func (k Keeper) ChanUpgradeAck( Counterparty: types.NewCounterparty(portID, channelID), Version: channel.Version, UpgradeSequence: channel.UpgradeSequence, - FlushStatus: types.NOTINFLUSH, // TODO: remove flush status from channel end } // verify the counterparty channel state containing the upgrade sequence @@ -274,26 +273,29 @@ func (k Keeper) ChanUpgradeAck( return errorsmod.Wrapf(types.ErrUpgradeNotFound, "failed to retrieve channel upgrade: port ID (%s) channel ID (%s)", portID, channelID) } + // optimistically accept version that TRY chain proposes and pass this to callback for confirmation + // in the crossing hello case, we do not modify version that our TRY call returned and instead enforce + // that both TRY calls returned the same version + if channel.IsOpen() { + upgrade.Fields.Version = counterpartyUpgrade.Fields.Version + } + + // if upgrades are not compatible by ACK step, then we restore the channel if err := k.checkForUpgradeCompatibility(ctx, upgrade.Fields, counterpartyUpgrade.Fields); err != nil { return types.NewUpgradeError(channel.UpgradeSequence, err) } + if channel.IsOpen() { + if err := k.startFlushing(ctx, portID, channelID, &upgrade); err != nil { + return err + } + } + timeout := counterpartyUpgrade.Timeout if hasPassed, err := timeout.HasPassed(ctx); hasPassed { return types.NewUpgradeError(channel.UpgradeSequence, errorsmod.Wrap(err, "counterparty upgrade timeout has passed")) } - if err := k.startFlushing(ctx, portID, channelID); err != nil { - return err - } - - // in the crossing hellos case, the versions returned by both on TRY must be the same - if channel.State == types.TRYUPGRADE { - if upgrade.Fields.Version != counterpartyUpgrade.Fields.Version { - return types.NewUpgradeError(channel.UpgradeSequence, errorsmod.Wrap(types.ErrIncompatibleCounterpartyUpgrade, "both channel ends must agree on the same version")) - } - } - return nil } @@ -308,16 +310,12 @@ func (k Keeper) WriteUpgradeAckChannel(ctx sdk.Context, portID, channelID string panic(fmt.Sprintf("could not find existing channel when updating channel state in successful ChanUpgradeAck step, channelID: %s, portID: %s", channelID, portID)) } - previousState := channel.State - channel.State = types.ACKUPGRADE - channel.FlushStatus = types.FLUSHING - if !k.HasInflightPackets(ctx, portID, channelID) { - channel.FlushStatus = types.FLUSHCOMPLETE + channel.State = types.STATE_FLUSHCOMPLETE + } else { + k.SetCounterpartyUpgrade(ctx, portID, channelID, counterpartyUpgrade) } - k.SetCounterpartyLastPacketSequence(ctx, portID, channelID, counterpartyUpgrade.LatestSequenceSend) - k.SetCounterpartyUpgrade(ctx, portID, channelID, counterpartyUpgrade) k.SetChannel(ctx, portID, channelID, channel) upgrade, found := k.GetUpgrade(ctx, portID, channelID) @@ -329,7 +327,7 @@ func (k Keeper) WriteUpgradeAckChannel(ctx sdk.Context, portID, channelID string k.SetUpgrade(ctx, portID, channelID, upgrade) - k.Logger(ctx).Info("channel state updated", "port-id", portID, "channel-id", channelID, "previous-state", previousState, "new-state", types.ACKUPGRADE.String()) + k.Logger(ctx).Info("channel state updated", "port-id", portID, "channel-id", channelID, "state", channel.State.String()) emitChannelUpgradeAckEvent(ctx, portID, channelID, channel, upgrade) } @@ -702,7 +700,7 @@ func (k Keeper) WriteUpgradeTimeoutChannel( // startFlushing will set the upgrade last packet send and continue blocking the upgrade from continuing until all // in-flight packets have been flushed. -func (k Keeper) startFlushing(ctx sdk.Context, portID, channelID string) error { +func (k Keeper) startFlushing(ctx sdk.Context, portID, channelID string, upgrade *types.Upgrade) error { channel, found := k.GetChannel(ctx, portID, channelID) if !found { return errorsmod.Wrapf(types.ErrChannelNotFound, "port ID (%s) channel ID (%s)", portID, channelID) @@ -720,11 +718,6 @@ func (k Keeper) startFlushing(ctx sdk.Context, portID, channelID string) error { channel.State = types.STATE_FLUSHING k.SetChannel(ctx, portID, channelID, channel) - upgrade, found := k.GetUpgrade(ctx, portID, channelID) - if !found { - return errorsmod.Wrapf(types.ErrUpgradeNotFound, "port ID (%s) channel ID (%s)", portID, channelID) - } - nextSequenceSend, found := k.GetNextSequenceSend(ctx, portID, channelID) if !found { return errorsmod.Wrapf(types.ErrSequenceSendNotFound, "port ID (%s) channel ID (%s)", portID, channelID) @@ -732,7 +725,7 @@ func (k Keeper) startFlushing(ctx sdk.Context, portID, channelID string) error { upgrade.LatestSequenceSend = nextSequenceSend - 1 upgrade.Timeout = getUpgradeTimeout() - k.SetUpgrade(ctx, portID, channelID, upgrade) + k.SetUpgrade(ctx, portID, channelID, *upgrade) return nil } diff --git a/modules/core/04-channel/keeper/upgrade_test.go b/modules/core/04-channel/keeper/upgrade_test.go index 3c0e740eead..87ee416c27d 100644 --- a/modules/core/04-channel/keeper/upgrade_test.go +++ b/modules/core/04-channel/keeper/upgrade_test.go @@ -409,30 +409,29 @@ func (suite *KeeperTestSuite) TestChanUpgradeAck() { malleate func() expError error }{ - // TODO: uncomment and handle failing tests - // { - // "success", - // func() {}, - // nil, - // }, - // { - // "success with later upgrade sequence", - // func() { - // channel := path.EndpointA.GetChannel() - // channel.UpgradeSequence = 10 - // path.EndpointA.SetChannel(channel) + { + "success", + func() {}, + nil, + }, + { + "success with later upgrade sequence", + func() { + channel := path.EndpointA.GetChannel() + channel.UpgradeSequence = 10 + path.EndpointA.SetChannel(channel) - // channel = path.EndpointB.GetChannel() - // channel.UpgradeSequence = 10 - // path.EndpointB.SetChannel(channel) + channel = path.EndpointB.GetChannel() + channel.UpgradeSequence = 10 + path.EndpointB.SetChannel(channel) - // suite.coordinator.CommitBlock(suite.chainA, suite.chainB) + suite.coordinator.CommitBlock(suite.chainA, suite.chainB) - // err := path.EndpointA.UpdateClient() - // suite.Require().NoError(err) - // }, - // nil, - // }, + err := path.EndpointA.UpdateClient() + suite.Require().NoError(err) + }, + nil, + }, { "channel not found", func() { @@ -516,21 +515,6 @@ func (suite *KeeperTestSuite) TestChanUpgradeAck() { }, types.NewUpgradeError(1, types.ErrIncompatibleCounterpartyUpgrade), }, - // { - // "channel end version mismatch on crossing hellos", - // func() { - // channel := path.EndpointA.GetChannel() - // channel.State = types.TRYUPGRADE - - // path.EndpointA.SetChannel(channel) - - // upgrade := path.EndpointA.GetChannelUpgrade() - // upgrade.Fields.Version = "invalid-version" - - // path.EndpointA.SetChannelUpgrade(upgrade) - // }, - // types.NewUpgradeError(1, types.ErrIncompatibleCounterpartyUpgrade), - // }, { "counterparty timeout has elapsed", func() { @@ -588,6 +572,11 @@ func (suite *KeeperTestSuite) TestChanUpgradeAck() { expPass := tc.expError == nil if expPass { suite.Require().NoError(err) + + channel := path.EndpointA.GetChannel() + // ChanUpgradeAck will set the channel state to STATE_FLUSHING + // It will be set to FLUSHING_COMPLETE in the write function. + suite.Require().Equal(types.STATE_FLUSHING, channel.State) } else { suite.assertUpgradeError(err, tc.expError) } @@ -650,10 +639,6 @@ func (suite *KeeperTestSuite) TestWriteChannelUpgradeAck() { upgrade := path.EndpointA.GetChannelUpgrade() suite.Require().Equal(mock.UpgradeVersion, upgrade.Fields.Version) - actualCounterpartyLastSequenceSend, ok := suite.chainA.GetSimApp().IBCKeeper.ChannelKeeper.GetCounterpartyLastPacketSequence(suite.chainA.GetContext(), path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID) - suite.Require().True(ok) - suite.Require().Equal(proposedUpgrade.LatestSequenceSend, actualCounterpartyLastSequenceSend) - events := ctx.EventManager().Events().ToABCIEvents() expEvents := ibctesting.EventsMap{ types.EventTypeChannelUpgradeAck: { @@ -674,14 +659,14 @@ func (suite *KeeperTestSuite) TestWriteChannelUpgradeAck() { ibctesting.AssertEvents(&suite.Suite, expEvents, events) - counterpartyUpgrade, ok := suite.chainA.GetSimApp().IBCKeeper.ChannelKeeper.GetCounterpartyUpgrade(suite.chainA.GetContext(), path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID) - suite.Require().True(ok) - suite.Require().Equal(proposedUpgrade, counterpartyUpgrade) - - if tc.hasPacketCommitments { - suite.Require().Equal(types.FLUSHING, channel.FlushStatus) + if !tc.hasPacketCommitments { + suite.Require().Equal(types.STATE_FLUSHCOMPLETE, channel.State) + _, ok := suite.chainA.GetSimApp().IBCKeeper.ChannelKeeper.GetCounterpartyUpgrade(suite.chainA.GetContext(), path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID) + suite.Require().False(ok) } else { - suite.Require().Equal(types.FLUSHCOMPLETE, channel.FlushStatus) + counterpartyUpgrade, ok := suite.chainA.GetSimApp().IBCKeeper.ChannelKeeper.GetCounterpartyUpgrade(suite.chainA.GetContext(), path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID) + suite.Require().True(ok) + suite.Require().Equal(proposedUpgrade, counterpartyUpgrade) } }) } @@ -1419,17 +1404,18 @@ func (suite *KeeperTestSuite) TestStartFlush() { err = path.EndpointB.ChanUpgradeInit() suite.Require().NoError(err) + upgrade := path.EndpointB.GetChannelUpgrade() + tc.malleate() err = suite.chainB.GetSimApp().IBCKeeper.ChannelKeeper.StartFlushing( - suite.chainB.GetContext(), path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID, + suite.chainB.GetContext(), path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID, &upgrade, ) if tc.expError != nil { suite.assertUpgradeError(err, tc.expError) } else { channel := path.EndpointB.GetChannel() - upgrade := path.EndpointB.GetChannelUpgrade() nextSequenceSend, ok := suite.chainB.GetSimApp().IBCKeeper.ChannelKeeper.GetNextSequenceSend(suite.chainB.GetContext(), path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID) suite.Require().True(ok)