Skip to content

Commit

Permalink
Add ChanUpgradeCancel core handler (#3846)
Browse files Browse the repository at this point in the history
  • Loading branch information
chatton authored Jun 16, 2023
1 parent 1b0edff commit c0943cc
Show file tree
Hide file tree
Showing 2 changed files with 160 additions and 0 deletions.
46 changes: 46 additions & 0 deletions modules/core/04-channel/keeper/upgrade.go
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,49 @@ func (k Keeper) WriteUpgradeAckChannel(
emitChannelUpgradeAckEvent(ctx, portID, channelID, channel, proposedUpgrade)
}

// ChanUpgradeCancel is called by a module to cancel a channel upgrade that is in progress.
func (k Keeper) ChanUpgradeCancel(ctx sdk.Context, portID, channelID string, errorReceipt types.ErrorReceipt, errorReceiptProof []byte, proofHeight clienttypes.Height) error {
channel, found := k.GetChannel(ctx, portID, channelID)
if !found {
return errorsmod.Wrapf(types.ErrChannelNotFound, "port ID (%s) channel ID (%s)", portID, channelID)
}

// the channel state must be in INITUPGRADE or TRYUPGRADE
if !collections.Contains(channel.State, []types.State{types.INITUPGRADE, types.TRYUPGRADE}) {
return errorsmod.Wrapf(types.ErrInvalidChannelState, "expected one of [%s, %s], got %s", types.INITUPGRADE, types.TRYUPGRADE, channel.State)
}

// get underlying connection for proof verification
connection, err := k.GetConnection(ctx, channel.ConnectionHops[0])
if err != nil {
return errorsmod.Wrap(err, "failed to retrieve connection using the channel connection hops")
}

if connection.GetState() != int32(connectiontypes.OPEN) {
return errorsmod.Wrapf(
connectiontypes.ErrInvalidConnectionState,
"connection state is not OPEN (got %s)", connectiontypes.State(connection.GetState()).String(),
)
}

if err := k.connectionKeeper.VerifyChannelUpgradeError(ctx, portID, channelID, connection, errorReceipt, errorReceiptProof, proofHeight); err != nil {
return errorsmod.Wrap(err, "failed to verify counterparty error receipt")
}

// If counterparty sequence is less than the current sequence, abort the transaction since this error receipt is from a previous upgrade.
// Otherwise, set our upgrade sequence to the counterparty's error sequence + 1 so that both sides start with a fresh sequence.
currentSequence := channel.UpgradeSequence
counterpartySequence := errorReceipt.Sequence
if counterpartySequence < currentSequence {
return errorsmod.Wrap(types.ErrInvalidUpgradeSequence, "error sequence must be less than current sequence")
}

channel.UpgradeSequence = errorReceipt.Sequence + 1
k.SetChannel(ctx, portID, channelID, channel)

return nil
}

// WriteUpgradeCancelChannel writes a channel which has canceled the upgrade process.Auxiliary upgrade state is
// also deleted.
func (k Keeper) WriteUpgradeCancelChannel(ctx sdk.Context, portID, channelID string) {
Expand All @@ -245,8 +288,11 @@ func (k Keeper) WriteUpgradeCancelChannel(ctx sdk.Context, portID, channelID str
panic(fmt.Sprintf("could not find existing channel when updating channel state, channelID: %s, portID: %s", channelID, portID))
}

previousState := channel.State

k.restoreChannel(ctx, portID, channelID, channel)

k.Logger(ctx).Info("channel state updated", "port-id", portID, "channel-id", channelID, "previous-state", previousState, "new-state", types.OPEN.String())
emitChannelUpgradeCancelEvent(ctx, portID, channelID, channel, upgrade)
}

Expand Down
114 changes: 114 additions & 0 deletions modules/core/04-channel/keeper/upgrade_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"fmt"

errorsmod "cosmossdk.io/errors"
sdk "github.com/cosmos/cosmos-sdk/types"

clienttypes "github.com/cosmos/ibc-go/v7/modules/core/02-client/types"
connectiontypes "github.com/cosmos/ibc-go/v7/modules/core/03-connection/types"
Expand Down Expand Up @@ -876,3 +877,116 @@ func (suite *KeeperTestSuite) TestAbortHandshake() {
})
}
}

func (suite *KeeperTestSuite) TestChanUpgradeCancel() {
var (
path *ibctesting.Path
errorReceipt types.ErrorReceipt
errorReceiptProof []byte
proofHeight clienttypes.Height
)
tests := []struct {
name string
malleate func()
expError error
}{
{
name: "success",
malleate: func() {},
expError: nil,
},
{
name: "invalid channel state",
malleate: func() {
channel := path.EndpointA.GetChannel()
channel.State = types.INIT
path.EndpointA.SetChannel(channel)
},
expError: types.ErrInvalidChannelState,
},
{
name: "channel not found",
malleate: func() {
path.EndpointA.Chain.DeleteKey(host.ChannelKey(path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID))
},
expError: types.ErrChannelNotFound,
},
{
name: "connection not found",
malleate: func() {
channel := path.EndpointA.GetChannel()
channel.ConnectionHops = []string{"connection-100"}
path.EndpointA.SetChannel(channel)
},
expError: connectiontypes.ErrConnectionNotFound,
},
{
name: "counter partyupgrade sequence less than current sequence",
malleate: func() {
var ok bool
errorReceipt, ok = suite.chainB.GetSimApp().IBCKeeper.ChannelKeeper.GetUpgradeErrorReceipt(suite.chainB.GetContext(), path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID)
suite.Require().True(ok)

// the channel sequence will be 1
errorReceipt.Sequence = 0

suite.chainB.GetSimApp().IBCKeeper.ChannelKeeper.SetUpgradeErrorReceipt(suite.chainB.GetContext(), path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID, errorReceipt)

suite.coordinator.CommitBlock(suite.chainB)
suite.Require().NoError(path.EndpointA.UpdateClient())

upgradeErrorReceiptKey := host.ChannelUpgradeErrorKey(path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID)
errorReceiptProof, proofHeight = suite.chainB.QueryProof(upgradeErrorReceiptKey)
},
expError: types.ErrInvalidUpgradeSequence,
},
}

for _, tc := range tests {
tc := tc
suite.Run(tc.name, func() {
suite.SetupTest()

path = ibctesting.NewPath(suite.chainA, suite.chainB)
suite.coordinator.Setup(path)

path.EndpointA.ChannelConfig.ProposedUpgrade.Fields.Version = mock.UpgradeVersion
path.EndpointB.ChannelConfig.ProposedUpgrade.Fields.Version = mock.UpgradeVersion

suite.Require().NoError(path.EndpointA.ChanUpgradeInit())

suite.Require().NoError(path.EndpointB.UpdateClient())

// cause the upgrade to fail on chain b so an error receipt is written.
suite.chainB.GetSimApp().IBCMockModule.IBCApp.OnChanUpgradeTry = func(
ctx sdk.Context, portID, channelID string, order types.Order, connectionHops []string, counterpartyVersion string,
) (string, error) {
return "", fmt.Errorf("mock app callback failed")
}

suite.Require().NoError(path.EndpointB.ChanUpgradeTry())

suite.Require().NoError(path.EndpointA.UpdateClient())

upgradeErrorReceiptKey := host.ChannelUpgradeErrorKey(path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID)
errorReceiptProof, proofHeight = suite.chainB.QueryProof(upgradeErrorReceiptKey)

var ok bool
errorReceipt, ok = suite.chainB.GetSimApp().IBCKeeper.ChannelKeeper.GetUpgradeErrorReceipt(suite.chainB.GetContext(), path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID)
suite.Require().True(ok)

tc.malleate()

err := suite.chainA.GetSimApp().IBCKeeper.ChannelKeeper.ChanUpgradeCancel(suite.chainA.GetContext(), path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID, errorReceipt, errorReceiptProof, proofHeight)

expPass := tc.expError == nil
if expPass {
suite.Require().NoError(err)
channel := path.EndpointA.GetChannel()
suite.Require().Equal(errorReceipt.Sequence+1, channel.UpgradeSequence, "upgrade sequence should be incremented")
} else {
suite.Require().ErrorIs(err, tc.expError)
}
})
}
}

0 comments on commit c0943cc

Please sign in to comment.