Skip to content

Commit

Permalink
Channel Upgrade Ack (#4372)
Browse files Browse the repository at this point in the history
  • Loading branch information
chatton authored Aug 17, 2023
1 parent cd62620 commit e2c8326
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 79 deletions.
4 changes: 2 additions & 2 deletions modules/core/04-channel/keeper/export_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ import (
)

// StartFlushing is a wrapper around startFlushing to allow the function to be directly called in tests.
func (k Keeper) StartFlushing(ctx sdk.Context, portID, channelID string) error {
return k.startFlushing(ctx, portID, channelID)
func (k Keeper) StartFlushing(ctx sdk.Context, portID, channelID string, upgrade *types.Upgrade) error {
return k.startFlushing(ctx, portID, channelID, upgrade)
}

// ValidateSelfUpgradeFields is a wrapper around validateSelfUpgradeFields to allow the function to be directly called in tests.
Expand Down
49 changes: 21 additions & 28 deletions modules/core/04-channel/keeper/upgrade.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ func (k Keeper) ChanUpgradeTry(
return types.Upgrade{}, errorsmod.Wrap(err, "failed to verify counterparty upgrade")
}

if err := k.startFlushing(ctx, portID, channelID); err != nil {
if err := k.startFlushing(ctx, portID, channelID, &upgrade); err != nil {
return types.Upgrade{}, err
}

Expand Down Expand Up @@ -242,7 +242,6 @@ func (k Keeper) ChanUpgradeAck(
Counterparty: types.NewCounterparty(portID, channelID),
Version: channel.Version,
UpgradeSequence: channel.UpgradeSequence,
FlushStatus: types.NOTINFLUSH, // TODO: remove flush status from channel end
}

// verify the counterparty channel state containing the upgrade sequence
Expand Down Expand Up @@ -274,26 +273,29 @@ func (k Keeper) ChanUpgradeAck(
return errorsmod.Wrapf(types.ErrUpgradeNotFound, "failed to retrieve channel upgrade: port ID (%s) channel ID (%s)", portID, channelID)
}

// optimistically accept version that TRY chain proposes and pass this to callback for confirmation
// in the crossing hello case, we do not modify version that our TRY call returned and instead enforce
// that both TRY calls returned the same version
if channel.IsOpen() {
upgrade.Fields.Version = counterpartyUpgrade.Fields.Version
}

// if upgrades are not compatible by ACK step, then we restore the channel
if err := k.checkForUpgradeCompatibility(ctx, upgrade.Fields, counterpartyUpgrade.Fields); err != nil {
return types.NewUpgradeError(channel.UpgradeSequence, err)
}

if channel.IsOpen() {
if err := k.startFlushing(ctx, portID, channelID, &upgrade); err != nil {
return err
}
}

timeout := counterpartyUpgrade.Timeout
if hasPassed, err := timeout.HasPassed(ctx); hasPassed {
return types.NewUpgradeError(channel.UpgradeSequence, errorsmod.Wrap(err, "counterparty upgrade timeout has passed"))
}

if err := k.startFlushing(ctx, portID, channelID); err != nil {
return err
}

// in the crossing hellos case, the versions returned by both on TRY must be the same
if channel.State == types.TRYUPGRADE {
if upgrade.Fields.Version != counterpartyUpgrade.Fields.Version {
return types.NewUpgradeError(channel.UpgradeSequence, errorsmod.Wrap(types.ErrIncompatibleCounterpartyUpgrade, "both channel ends must agree on the same version"))
}
}

return nil
}

Expand All @@ -308,16 +310,12 @@ func (k Keeper) WriteUpgradeAckChannel(ctx sdk.Context, portID, channelID string
panic(fmt.Sprintf("could not find existing channel when updating channel state in successful ChanUpgradeAck step, channelID: %s, portID: %s", channelID, portID))
}

previousState := channel.State
channel.State = types.ACKUPGRADE
channel.FlushStatus = types.FLUSHING

if !k.HasInflightPackets(ctx, portID, channelID) {
channel.FlushStatus = types.FLUSHCOMPLETE
channel.State = types.STATE_FLUSHCOMPLETE
} else {
k.SetCounterpartyUpgrade(ctx, portID, channelID, counterpartyUpgrade)
}

k.SetCounterpartyLastPacketSequence(ctx, portID, channelID, counterpartyUpgrade.LatestSequenceSend)
k.SetCounterpartyUpgrade(ctx, portID, channelID, counterpartyUpgrade)
k.SetChannel(ctx, portID, channelID, channel)

upgrade, found := k.GetUpgrade(ctx, portID, channelID)
Expand All @@ -329,7 +327,7 @@ func (k Keeper) WriteUpgradeAckChannel(ctx sdk.Context, portID, channelID string

k.SetUpgrade(ctx, portID, channelID, upgrade)

k.Logger(ctx).Info("channel state updated", "port-id", portID, "channel-id", channelID, "previous-state", previousState, "new-state", types.ACKUPGRADE.String())
k.Logger(ctx).Info("channel state updated", "port-id", portID, "channel-id", channelID, "state", channel.State.String())
emitChannelUpgradeAckEvent(ctx, portID, channelID, channel, upgrade)
}

Expand Down Expand Up @@ -702,7 +700,7 @@ func (k Keeper) WriteUpgradeTimeoutChannel(

// startFlushing will set the upgrade last packet send and continue blocking the upgrade from continuing until all
// in-flight packets have been flushed.
func (k Keeper) startFlushing(ctx sdk.Context, portID, channelID string) error {
func (k Keeper) startFlushing(ctx sdk.Context, portID, channelID string, upgrade *types.Upgrade) error {
channel, found := k.GetChannel(ctx, portID, channelID)
if !found {
return errorsmod.Wrapf(types.ErrChannelNotFound, "port ID (%s) channel ID (%s)", portID, channelID)
Expand All @@ -720,19 +718,14 @@ func (k Keeper) startFlushing(ctx sdk.Context, portID, channelID string) error {
channel.State = types.STATE_FLUSHING
k.SetChannel(ctx, portID, channelID, channel)

upgrade, found := k.GetUpgrade(ctx, portID, channelID)
if !found {
return errorsmod.Wrapf(types.ErrUpgradeNotFound, "port ID (%s) channel ID (%s)", portID, channelID)
}

nextSequenceSend, found := k.GetNextSequenceSend(ctx, portID, channelID)
if !found {
return errorsmod.Wrapf(types.ErrSequenceSendNotFound, "port ID (%s) channel ID (%s)", portID, channelID)
}

upgrade.LatestSequenceSend = nextSequenceSend - 1
upgrade.Timeout = getUpgradeTimeout()
k.SetUpgrade(ctx, portID, channelID, upgrade)
k.SetUpgrade(ctx, portID, channelID, *upgrade)

return nil
}
Expand Down
84 changes: 35 additions & 49 deletions modules/core/04-channel/keeper/upgrade_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -409,30 +409,29 @@ func (suite *KeeperTestSuite) TestChanUpgradeAck() {
malleate func()
expError error
}{
// TODO: uncomment and handle failing tests
// {
// "success",
// func() {},
// nil,
// },
// {
// "success with later upgrade sequence",
// func() {
// channel := path.EndpointA.GetChannel()
// channel.UpgradeSequence = 10
// path.EndpointA.SetChannel(channel)
{
"success",
func() {},
nil,
},
{
"success with later upgrade sequence",
func() {
channel := path.EndpointA.GetChannel()
channel.UpgradeSequence = 10
path.EndpointA.SetChannel(channel)

// channel = path.EndpointB.GetChannel()
// channel.UpgradeSequence = 10
// path.EndpointB.SetChannel(channel)
channel = path.EndpointB.GetChannel()
channel.UpgradeSequence = 10
path.EndpointB.SetChannel(channel)

// suite.coordinator.CommitBlock(suite.chainA, suite.chainB)
suite.coordinator.CommitBlock(suite.chainA, suite.chainB)

// err := path.EndpointA.UpdateClient()
// suite.Require().NoError(err)
// },
// nil,
// },
err := path.EndpointA.UpdateClient()
suite.Require().NoError(err)
},
nil,
},
{
"channel not found",
func() {
Expand Down Expand Up @@ -516,21 +515,6 @@ func (suite *KeeperTestSuite) TestChanUpgradeAck() {
},
types.NewUpgradeError(1, types.ErrIncompatibleCounterpartyUpgrade),
},
// {
// "channel end version mismatch on crossing hellos",
// func() {
// channel := path.EndpointA.GetChannel()
// channel.State = types.TRYUPGRADE

// path.EndpointA.SetChannel(channel)

// upgrade := path.EndpointA.GetChannelUpgrade()
// upgrade.Fields.Version = "invalid-version"

// path.EndpointA.SetChannelUpgrade(upgrade)
// },
// types.NewUpgradeError(1, types.ErrIncompatibleCounterpartyUpgrade),
// },
{
"counterparty timeout has elapsed",
func() {
Expand Down Expand Up @@ -588,6 +572,11 @@ func (suite *KeeperTestSuite) TestChanUpgradeAck() {
expPass := tc.expError == nil
if expPass {
suite.Require().NoError(err)

channel := path.EndpointA.GetChannel()
// ChanUpgradeAck will set the channel state to STATE_FLUSHING
// It will be set to FLUSHING_COMPLETE in the write function.
suite.Require().Equal(types.STATE_FLUSHING, channel.State)
} else {
suite.assertUpgradeError(err, tc.expError)
}
Expand Down Expand Up @@ -650,10 +639,6 @@ func (suite *KeeperTestSuite) TestWriteChannelUpgradeAck() {
upgrade := path.EndpointA.GetChannelUpgrade()
suite.Require().Equal(mock.UpgradeVersion, upgrade.Fields.Version)

actualCounterpartyLastSequenceSend, ok := suite.chainA.GetSimApp().IBCKeeper.ChannelKeeper.GetCounterpartyLastPacketSequence(suite.chainA.GetContext(), path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID)
suite.Require().True(ok)
suite.Require().Equal(proposedUpgrade.LatestSequenceSend, actualCounterpartyLastSequenceSend)

events := ctx.EventManager().Events().ToABCIEvents()
expEvents := ibctesting.EventsMap{
types.EventTypeChannelUpgradeAck: {
Expand All @@ -674,14 +659,14 @@ func (suite *KeeperTestSuite) TestWriteChannelUpgradeAck() {

ibctesting.AssertEvents(&suite.Suite, expEvents, events)

counterpartyUpgrade, ok := suite.chainA.GetSimApp().IBCKeeper.ChannelKeeper.GetCounterpartyUpgrade(suite.chainA.GetContext(), path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID)
suite.Require().True(ok)
suite.Require().Equal(proposedUpgrade, counterpartyUpgrade)

if tc.hasPacketCommitments {
suite.Require().Equal(types.FLUSHING, channel.FlushStatus)
if !tc.hasPacketCommitments {
suite.Require().Equal(types.STATE_FLUSHCOMPLETE, channel.State)
_, ok := suite.chainA.GetSimApp().IBCKeeper.ChannelKeeper.GetCounterpartyUpgrade(suite.chainA.GetContext(), path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID)
suite.Require().False(ok)
} else {
suite.Require().Equal(types.FLUSHCOMPLETE, channel.FlushStatus)
counterpartyUpgrade, ok := suite.chainA.GetSimApp().IBCKeeper.ChannelKeeper.GetCounterpartyUpgrade(suite.chainA.GetContext(), path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID)
suite.Require().True(ok)
suite.Require().Equal(proposedUpgrade, counterpartyUpgrade)
}
})
}
Expand Down Expand Up @@ -1419,17 +1404,18 @@ func (suite *KeeperTestSuite) TestStartFlush() {
err = path.EndpointB.ChanUpgradeInit()
suite.Require().NoError(err)

upgrade := path.EndpointB.GetChannelUpgrade()

tc.malleate()

err = suite.chainB.GetSimApp().IBCKeeper.ChannelKeeper.StartFlushing(
suite.chainB.GetContext(), path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID,
suite.chainB.GetContext(), path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID, &upgrade,
)

if tc.expError != nil {
suite.assertUpgradeError(err, tc.expError)
} else {
channel := path.EndpointB.GetChannel()
upgrade := path.EndpointB.GetChannelUpgrade()

nextSequenceSend, ok := suite.chainB.GetSimApp().IBCKeeper.ChannelKeeper.GetNextSequenceSend(suite.chainB.GetContext(), path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID)
suite.Require().True(ok)
Expand Down

0 comments on commit e2c8326

Please sign in to comment.