Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add functions checkForUpgradeCompatibility and syncUpgradeSequence #4352

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions modules/core/04-channel/keeper/export_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,13 @@ func (k Keeper) StartFlushUpgradeHandshake(
func (k Keeper) ValidateSelfUpgradeFields(ctx sdk.Context, proposedUpgrade types.UpgradeFields, currentChannel types.Channel) error {
return k.validateSelfUpgradeFields(ctx, proposedUpgrade, currentChannel)
}

// CheckForUpgradeCompatibility is a wrapper around checkForUpgradeCompatibility to allow the function to be directly called in tests.
func (k Keeper) CheckForUpgradeCompatibility(ctx sdk.Context, upgradeFields, counterpartyUpgradeFields types.UpgradeFields) error {
return k.checkForUpgradeCompatibility(ctx, upgradeFields, counterpartyUpgradeFields)
}

// SyncUpgradeSequence is a wrapper around syncUpgradeSequence to allow the function to be directly called in tests.
func (k Keeper) SyncUpgradeSequence(ctx sdk.Context, portID, channelID string, channel types.Channel, counterpartyUpgradeSequence uint64) error {
return k.syncUpgradeSequence(ctx, portID, channelID, channel, counterpartyUpgradeSequence)
}
55 changes: 32 additions & 23 deletions modules/core/04-channel/keeper/upgrade.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ func (k Keeper) ChanUpgradeInit(

// WriteUpgradeInitChannel writes a channel which has successfully passed the UpgradeInit handshake step.
// An event is emitted for the handshake step.
func (k Keeper) WriteUpgradeInitChannel(ctx sdk.Context, portID, channelID string, upgrade types.Upgrade) {
func (k Keeper) WriteUpgradeInitChannel(ctx sdk.Context, portID, channelID string, upgrade types.Upgrade) types.Channel {
defer telemetry.IncrCounter(1, "ibc", "channel", "upgrade-init")

channel, found := k.GetChannel(ctx, portID, channelID)
Expand All @@ -60,6 +60,7 @@ func (k Keeper) WriteUpgradeInitChannel(ctx sdk.Context, portID, channelID strin
k.Logger(ctx).Info("channel state updated", "port-id", portID, "channel-id", channelID, "previous-state", types.OPEN.String(), "new-state", types.INITUPGRADE.String())

emitChannelUpgradeInitEvent(ctx, portID, channelID, channel, upgrade)
return channel
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am a bit of a command-query separation purist :) so I do like the convention that state-mutating functions should not return a value.

Copy link
Contributor Author

@chatton chatton Aug 16, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We end up in the situation where functions that call this function end up with a stale channel afterwards, so we would need to remember to do an extra get operation on the channel each time we call this if we want an up-to-date channel.

This is also in line with what is happening with WriteUpgradeTryChannel for similar reasons.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can definitely come back to this later, I do agree with @crodriguezvega's preferred approach, but I think for now its fine so that we can make progress quickly on the refactor.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

linking relevant issue opened for the try version #3825 (since I randomly bumped into it)

}

// ChanUpgradeTry is called by a module to accept the first step of a channel upgrade handshake initiated by
Expand Down Expand Up @@ -122,7 +123,7 @@ func (k Keeper) ChanUpgradeTry(

// NOTE: OnChanUpgradeInit will not be executed by the application

k.WriteUpgradeInitChannel(ctx, portID, channelID, upgrade)
channel = k.WriteUpgradeInitChannel(ctx, portID, channelID, upgrade)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is required as the upgrade sequence is modified here, without this the channel variable defined above becomes stale.


case types.INITUPGRADE:
// crossing hellos
Expand Down Expand Up @@ -166,6 +167,14 @@ func (k Keeper) ChanUpgradeTry(
return types.Upgrade{}, errorsmod.Wrap(err, "failed to verify counterparty channel state")
}

if err := k.syncUpgradeSequence(ctx, portID, channelID, channel, counterpartyUpgradeSequence); err != nil {
return types.Upgrade{}, err
}

if err := k.checkForUpgradeCompatibility(ctx, upgrade.Fields, counterpartyUpgradeFields); err != nil {
return types.Upgrade{}, errorsmod.Wrap(err, "failed upgrade compatibility check")
}

// verifies the proof that a particular proposed upgrade has been stored in the upgrade path of the counterparty
if err := k.connectionKeeper.VerifyChannelUpgrade(
ctx,
Expand Down Expand Up @@ -299,7 +308,7 @@ func (k Keeper) ChanUpgradeAck(
return errorsmod.Wrapf(types.ErrUpgradeNotFound, "failed to retrieve channel upgrade: port ID (%s) channel ID (%s)", portID, channelID)
}

if err := k.checkForUpgradeCompatibility(ctx, upgrade.Fields, counterpartyUpgrade); err != nil {
if err := k.checkForUpgradeCompatibility(ctx, upgrade.Fields, counterpartyUpgrade.Fields); err != nil {
return types.NewUpgradeError(channel.UpgradeSequence, err)
}

Expand Down Expand Up @@ -730,52 +739,52 @@ func (k Keeper) startFlushUpgradeHandshake(
return errorsmod.Wrapf(connectiontypes.ErrInvalidConnectionState, "connection state is not OPEN (got %s)", connectiontypes.State(connection.GetState()).String())
}

// the current upgrade handshake must only continue if both channels are using the same upgrade sequence,
// otherwise an error receipt must be written so that the upgrade handshake may be attempted again with synchronized sequences
if counterpartyChannel.UpgradeSequence != channel.UpgradeSequence {
// save the previous upgrade sequence for the error message
prevUpgradeSequence := channel.UpgradeSequence
return nil
}

// syncUpgradeSequence ensures current upgrade handshake only continues if both channels are using the same upgrade sequence,
// otherwise an upgrade error is returned so that an error receipt will be written so that the upgrade handshake may be attempted again with synchronized sequences.
func (k Keeper) syncUpgradeSequence(ctx sdk.Context, portID, channelID string, channel types.Channel, counterpartyUpgradeSequence uint64) error {
// save the previous upgrade sequence for the error message
prevUpgradeSequence := channel.UpgradeSequence

if counterpartyUpgradeSequence != channel.UpgradeSequence {
// error on the higher sequence so that both chains synchronize on a fresh sequence
channel.UpgradeSequence = sdkmath.Max(counterpartyChannel.UpgradeSequence, channel.UpgradeSequence)
channel.UpgradeSequence = sdkmath.Max(counterpartyUpgradeSequence, channel.UpgradeSequence)
k.SetChannel(ctx, portID, channelID, channel)

return types.NewUpgradeError(channel.UpgradeSequence, errorsmod.Wrapf(
types.ErrIncompatibleCounterpartyUpgrade, "expected upgrade sequence (%d) to match counterparty upgrade sequence (%d)", prevUpgradeSequence, counterpartyChannel.UpgradeSequence),
types.ErrInvalidUpgradeSequence, "expected upgrade sequence (%d) to match counterparty upgrade sequence (%d)", prevUpgradeSequence, counterpartyUpgradeSequence),
)
}

if err := k.checkForUpgradeCompatibility(ctx, proposedUpgradeFields, counterpartyUpgrade); err != nil {
return types.NewUpgradeError(channel.UpgradeSequence, err)
}

return nil
}

// checkForUpgradeCompatibility checks performs stateful validation of self upgrade fields relative to counterparty upgrade.
func (k Keeper) checkForUpgradeCompatibility(ctx sdk.Context, proposedUpgradeFields types.UpgradeFields, counterpartyUpgrade types.Upgrade) error {
func (k Keeper) checkForUpgradeCompatibility(ctx sdk.Context, upgradeFields, counterpartyUpgradeFields types.UpgradeFields) error {
// assert that both sides propose the same channel ordering
if proposedUpgradeFields.Ordering != counterpartyUpgrade.Fields.Ordering {
return errorsmod.Wrapf(types.ErrIncompatibleCounterpartyUpgrade, "expected upgrade ordering (%s) to match counterparty upgrade ordering (%s)", proposedUpgradeFields.Ordering, counterpartyUpgrade.Fields.Ordering)
if upgradeFields.Ordering != counterpartyUpgradeFields.Ordering {
return errorsmod.Wrapf(types.ErrIncompatibleCounterpartyUpgrade, "expected upgrade ordering (%s) to match counterparty upgrade ordering (%s)", upgradeFields.Ordering, counterpartyUpgradeFields.Ordering)
}

proposedConnection, found := k.connectionKeeper.GetConnection(ctx, proposedUpgradeFields.ConnectionHops[0])
connection, found := k.connectionKeeper.GetConnection(ctx, upgradeFields.ConnectionHops[0])
if !found {
// NOTE: this error is expected to be unreachable as the proposed upgrade connectionID should have been
// validated in the upgrade INIT and TRY handlers
return errorsmod.Wrap(connectiontypes.ErrConnectionNotFound, proposedUpgradeFields.ConnectionHops[0])
return errorsmod.Wrap(connectiontypes.ErrConnectionNotFound, upgradeFields.ConnectionHops[0])
}

if proposedConnection.GetState() != int32(connectiontypes.OPEN) {
if connection.GetState() != int32(connectiontypes.OPEN) {
// NOTE: this error is expected to be unreachable as the proposed upgrade connectionID should have been
// validated in the upgrade INIT and TRY handlers
return errorsmod.Wrapf(connectiontypes.ErrInvalidConnectionState, "expected proposed connection to be OPEN (got %s)", connectiontypes.State(proposedConnection.GetState()).String())
return errorsmod.Wrapf(connectiontypes.ErrInvalidConnectionState, "expected proposed connection to be OPEN (got %s)", connectiontypes.State(connection.GetState()).String())
}

// connectionHops can change in a channelUpgrade, however both sides must still be each other's counterparty.
if counterpartyUpgrade.Fields.ConnectionHops[0] != proposedConnection.GetCounterparty().GetConnectionID() {
if counterpartyUpgradeFields.ConnectionHops[0] != connection.GetCounterparty().GetConnectionID() {
return errorsmod.Wrapf(
types.ErrIncompatibleCounterpartyUpgrade, "counterparty upgrade connection end is not a counterparty of self proposed connection end (%s != %s)", counterpartyUpgrade.Fields.ConnectionHops[0], proposedConnection.GetCounterparty().GetConnectionID())
types.ErrIncompatibleCounterpartyUpgrade, "counterparty upgrade connection end is not a counterparty of self proposed connection end (%s != %s)", counterpartyUpgradeFields.ConnectionHops[0], connection.GetCounterparty().GetConnectionID())
}

return nil
Expand Down
215 changes: 156 additions & 59 deletions modules/core/04-channel/keeper/upgrade_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -252,10 +252,10 @@ func (suite *KeeperTestSuite) TestChanUpgradeTry() {
func() {
counterpartyUpgrade.Fields.ConnectionHops = []string{ibctesting.InvalidID}
},
commitmenttypes.ErrInvalidProof,
types.ErrIncompatibleCounterpartyUpgrade,
},
{
"startFlushUpgradeHandshake fails due to incompatible upgrades, chainB proposes a new connection hop that does not match counterparty",
"fails due to incompatible upgrades, chainB proposes a new connection hop that does not match counterparty",
func() {
// reuse existing connection to create a new connection in a non OPEN state
connection := path.EndpointB.GetConnection()
Expand All @@ -267,16 +267,16 @@ func (suite *KeeperTestSuite) TestChanUpgradeTry() {
suite.chainB.GetSimApp().GetIBCKeeper().ConnectionKeeper.SetConnection(suite.chainB.GetContext(), proposedConnectionID, connection)
proposedUpgrade.Fields.ConnectionHops[0] = proposedConnectionID
},
types.NewUpgradeError(1, types.ErrIncompatibleCounterpartyUpgrade),
types.ErrIncompatibleCounterpartyUpgrade,
},
{
"startFlushUpgradeHandshake fails due to mismatch in upgrade sequences",
"fails due to mismatch in upgrade sequences",
func() {
channel := path.EndpointB.GetChannel()
channel.UpgradeSequence = 5
path.EndpointB.SetChannel(channel)
},
types.NewUpgradeError(6, types.ErrIncompatibleCounterpartyUpgrade), // max sequence + 1 will be returned
types.NewUpgradeError(6, types.ErrInvalidUpgradeSequence), // max sequence + 1 will be returned
},
}

Expand Down Expand Up @@ -1445,59 +1445,6 @@ func (suite *KeeperTestSuite) TestStartFlushUpgradeHandshake() {
},
connectiontypes.ErrInvalidConnectionState,
},
{
"upgrade sequence mismatch, endpointB channel upgrade sequence is ahead",
func() {
channel := path.EndpointB.GetChannel()
channel.UpgradeSequence++
path.EndpointB.SetChannel(channel)
},
types.NewUpgradeError(2, types.ErrIncompatibleCounterpartyUpgrade), // max sequence will be returned
},
{
"upgrade ordering is not the same on both sides",
func() {
upgrade.Fields.Ordering = types.ORDERED
},
types.NewUpgradeError(1, types.ErrIncompatibleCounterpartyUpgrade),
},
{
"proposed connection is not found",
func() {
upgrade.Fields.ConnectionHops[0] = ibctesting.InvalidID
},
types.NewUpgradeError(1, connectiontypes.ErrConnectionNotFound),
},
{
"proposed connection is not in OPEN state",
func() {
// reuse existing connection to create a new connection in a non OPEN state
connectionEnd := path.EndpointB.GetConnection()
connectionEnd.State = connectiontypes.UNINITIALIZED
connectionEnd.Counterparty.ConnectionId = counterpartyUpgrade.Fields.ConnectionHops[0] // both sides must be each other's counterparty

// set proposed connection in state
proposedConnectionID := "connection-100"
suite.chainB.GetSimApp().GetIBCKeeper().ConnectionKeeper.SetConnection(suite.chainB.GetContext(), proposedConnectionID, connectionEnd)
upgrade.Fields.ConnectionHops[0] = proposedConnectionID
},
types.NewUpgradeError(1, connectiontypes.ErrInvalidConnectionState),
},
{
"proposed connection ends are not each other's counterparty",
func() {
// reuse existing connection to create a new connection in a non OPEN state
connectionEnd := path.EndpointB.GetConnection()
// ensure counterparty connectionID does not match connectionID set in counterparty proposed upgrade
connectionEnd.Counterparty.ConnectionId = "connection-50"

// set proposed connection in state
proposedConnectionID := "connection-100"
suite.chainB.GetSimApp().GetIBCKeeper().ConnectionKeeper.SetConnection(suite.chainB.GetContext(), proposedConnectionID, connectionEnd)
upgrade.Fields.ConnectionHops[0] = proposedConnectionID
},
types.NewUpgradeError(1, types.ErrIncompatibleCounterpartyUpgrade),
},
}

for _, tc := range testCases {
Expand Down Expand Up @@ -1659,7 +1606,7 @@ func (suite *KeeperTestSuite) assertUpgradeError(actualError, expError error) {
suite.Require().Equal(expUpgradeError.GetErrorReceipt(), upgradeError.GetErrorReceipt())
}

suite.Require().True(errorsmod.IsOf(actualError, expError), actualError)
suite.Require().True(errorsmod.IsOf(actualError, expError), fmt.Sprintf("expected error: %s, actual error: %s", expError, actualError))
}

// TestAbortHandshake tests that when the channel handshake is aborted, the channel state
Expand Down Expand Up @@ -1780,3 +1727,153 @@ func (suite *KeeperTestSuite) TestAbortHandshake() {
})
}
}

func (suite *KeeperTestSuite) TestCheckForUpgradeCompatibility() {
var (
path *ibctesting.Path
upgradeFields types.UpgradeFields
counterpartyUpgradeFields types.UpgradeFields
)

testCases := []struct {
name string
malleate func()
expError error
}{
{
"success",
func() {},
nil,
},
{
"upgrade ordering is not the same on both sides",
func() {
upgradeFields.Ordering = types.ORDERED
},
types.ErrIncompatibleCounterpartyUpgrade,
},
{
"proposed connection is not found",
func() {
upgradeFields.ConnectionHops[0] = ibctesting.InvalidID
},
connectiontypes.ErrConnectionNotFound,
},
{
"proposed connection is not in OPEN state",
func() {
// reuse existing connection to create a new connection in a non OPEN state
connectionEnd := path.EndpointB.GetConnection()
connectionEnd.State = connectiontypes.UNINITIALIZED
connectionEnd.Counterparty.ConnectionId = counterpartyUpgradeFields.ConnectionHops[0] // both sides must be each other's counterparty

// set proposed connection in state
proposedConnectionID := "connection-100"
suite.chainB.GetSimApp().GetIBCKeeper().ConnectionKeeper.SetConnection(suite.chainB.GetContext(), proposedConnectionID, connectionEnd)
upgradeFields.ConnectionHops[0] = proposedConnectionID
},
connectiontypes.ErrInvalidConnectionState,
},
{
"proposed connection ends are not each other's counterparty",
func() {
// reuse existing connection to create a new connection in a non OPEN state
connectionEnd := path.EndpointB.GetConnection()
// ensure counterparty connectionID does not match connectionID set in counterparty proposed upgrade
connectionEnd.Counterparty.ConnectionId = "connection-50"

// set proposed connection in state
proposedConnectionID := "connection-100"
suite.chainB.GetSimApp().GetIBCKeeper().ConnectionKeeper.SetConnection(suite.chainB.GetContext(), proposedConnectionID, connectionEnd)
upgradeFields.ConnectionHops[0] = proposedConnectionID
},
types.ErrIncompatibleCounterpartyUpgrade,
},
}

for _, tc := range testCases {
tc := tc
suite.Run(tc.name, func() {
suite.SetupTest()

path = ibctesting.NewPath(suite.chainA, suite.chainB)
suite.coordinator.Setup(path)

path.EndpointA.ChannelConfig.ProposedUpgrade.Fields.Version = mock.UpgradeVersion
path.EndpointB.ChannelConfig.ProposedUpgrade.Fields.Version = mock.UpgradeVersion

err := path.EndpointA.ChanUpgradeInit()
suite.Require().NoError(err)

upgradeFields = path.EndpointA.GetProposedUpgrade().Fields
counterpartyUpgradeFields = path.EndpointB.GetProposedUpgrade().Fields

tc.malleate()

err = suite.chainB.GetSimApp().IBCKeeper.ChannelKeeper.CheckForUpgradeCompatibility(suite.chainB.GetContext(), upgradeFields, counterpartyUpgradeFields)
if tc.expError != nil {
suite.Require().ErrorIs(err, tc.expError)
} else {
suite.Require().NoError(err)
}
})
}
}

func (suite *KeeperTestSuite) TestSyncUpgradeSequence() {
var (
path *ibctesting.Path
counterpartyUpgradeSequence uint64
)

testCases := []struct {
name string
malleate func()
expError error
}{
{
"success",
func() {},
nil,
},
{
"upgrade sequence mismatch, endpointB channel upgrade sequence is ahead",
func() {
channel := path.EndpointB.GetChannel()
channel.UpgradeSequence = 10
path.EndpointB.SetChannel(channel)
},
types.NewUpgradeError(10, types.ErrInvalidUpgradeSequence), // max sequence will be returned
},
}

for _, tc := range testCases {
tc := tc
suite.Run(tc.name, func() {
suite.SetupTest()

path = ibctesting.NewPath(suite.chainA, suite.chainB)
suite.coordinator.Setup(path)

path.EndpointA.ChannelConfig.ProposedUpgrade.Fields.Version = mock.UpgradeVersion
path.EndpointB.ChannelConfig.ProposedUpgrade.Fields.Version = mock.UpgradeVersion

err := path.EndpointA.ChanUpgradeInit()
suite.Require().NoError(err)

err = path.EndpointB.ChanUpgradeInit()
suite.Require().NoError(err)

counterpartyUpgradeSequence = 1

tc.malleate()

err = suite.chainB.GetSimApp().IBCKeeper.ChannelKeeper.SyncUpgradeSequence(suite.chainB.GetContext(), path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID, path.EndpointB.GetChannel(), counterpartyUpgradeSequence)
if tc.expError != nil {
suite.Require().ErrorIs(err, tc.expError)
} else {
suite.Require().NoError(err)
}
})
}
}