Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ChanUpgradeOpen core handler. #3844

Merged
merged 9 commits into from
Jun 29, 2023
158 changes: 126 additions & 32 deletions modules/core/04-channel/keeper/upgrade.go
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,132 @@ func (k Keeper) WriteUpgradeAckChannel(
emitChannelUpgradeAckEvent(ctx, portID, channelID, channel, upgrade)
}

// ChanUpgradeOpen is called by a module to complete the channel upgrade handshake and move the channel back to an OPEN state.
// This method should only be called after both channels have flushed any in-flight packets.
// This method should only be called directly by the core IBC message server.
func (k Keeper) ChanUpgradeOpen(
ctx sdk.Context,
portID,
channelID string,
counterpartyChannelState types.State,
proofChannel []byte,
proofHeight clienttypes.Height,
) error {
if k.hasInflightPackets(ctx, portID, channelID) {
return errorsmod.Wrapf(types.ErrPendingInflightPackets, "port ID (%s) channel ID (%s)", portID, channelID)
}

channel, found := k.GetChannel(ctx, portID, channelID)
if !found {
return errorsmod.Wrapf(types.ErrChannelNotFound, "port ID (%s) channel ID (%s)", portID, channelID)
}

if !collections.Contains(channel.State, []types.State{types.TRYUPGRADE, types.ACKUPGRADE}) {
return errorsmod.Wrapf(types.ErrInvalidChannelState, "expected one of [%s, %s], got %s", types.TRYUPGRADE, types.ACKUPGRADE, channel.State)
}

if channel.FlushStatus != types.FLUSHCOMPLETE {
return errorsmod.Wrapf(types.ErrInvalidFlushStatus, "expected %s, got %s", types.FLUSHCOMPLETE, channel.FlushStatus)
}

connection, err := k.GetConnection(ctx, channel.ConnectionHops[0])
if err != nil {
return errorsmod.Wrap(err, "failed to retrieve connection using the channel connection hops")
}

if connection.GetState() != int32(connectiontypes.OPEN) {
return errorsmod.Wrapf(connectiontypes.ErrInvalidConnectionState, "connection state is not OPEN (got %s)", connectiontypes.State(connection.GetState()).String())
}

var counterpartyChannel types.Channel
switch counterpartyChannelState {
case types.OPEN:
upgrade, found := k.GetUpgrade(ctx, portID, channelID)
if !found {
return errorsmod.Wrapf(types.ErrUpgradeNotFound, "failed to retrieve channel upgrade: port ID (%s) channel ID (%s)", portID, channelID)
}

counterpartyChannel = types.Channel{
State: types.OPEN,
Ordering: upgrade.Fields.Ordering,
ConnectionHops: upgrade.Fields.ConnectionHops,
Counterparty: types.NewCounterparty(portID, channelID),
Version: upgrade.Fields.Version,
UpgradeSequence: channel.UpgradeSequence,
FlushStatus: types.NOTINFLUSH,
}

case types.TRYUPGRADE:
// If the counterparty is in TRYUPGRADE, then we must have gone through the ACKUPGRADE step.
if channel.State != types.ACKUPGRADE {
return errorsmod.Wrapf(types.ErrInvalidChannelState, "expected %s, got %s", types.ACKUPGRADE, channel.State)
}

counterpartyChannel = types.Channel{
State: types.TRYUPGRADE,
Ordering: channel.Ordering,
ConnectionHops: []string{connection.GetCounterparty().GetConnectionID()},
Counterparty: types.NewCounterparty(portID, channelID),
Version: channel.Version,
UpgradeSequence: channel.UpgradeSequence,
FlushStatus: types.FLUSHCOMPLETE,
}

case types.ACKUPGRADE:
counterpartyChannel = types.Channel{
State: types.ACKUPGRADE,
Ordering: channel.Ordering,
ConnectionHops: []string{connection.GetCounterparty().GetConnectionID()},
Counterparty: types.NewCounterparty(portID, channelID),
Version: channel.Version,
UpgradeSequence: channel.UpgradeSequence,
FlushStatus: types.FLUSHCOMPLETE,
}
Comment on lines +373 to +392
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nit] these two sections are identical with the exception of expected counterparty state, maybe we could handle these in a single case?

case types.TRYUPGRADE, types.ACKUPGRADE: ...

Happy to leave this out for now though, it might make some of the conditionals a little harder to reason about.

As an alternative, we could maybe add a function getExpectedCounterPartyChannel(currentState, ...)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

will remember to bring this up during review call, I'm thinking there's backing for keeping the switches separate to follow spec in structure.


default:
panic(fmt.Sprintf("counterparty channel state should be in one of [%s, %s, %s]; got %s", types.TRYUPGRADE, types.ACKUPGRADE, types.OPEN, counterpartyChannelState))
}

err = k.connectionKeeper.VerifyChannelState(ctx, connection, proofHeight, proofChannel, portID, channelID, counterpartyChannel)
if err != nil {
return errorsmod.Wrapf(err, "failed to verify counterparty channel, expected counterparty channel state: %s", counterpartyChannel.String())
}

return nil
}

// WriteUpgradeOpenChannel writes the agreed upon upgrade fields to the channel, sets the channel flush status to NOTINFLUSH and sets the channel state back to OPEN. This can be called in one of two cases:
// - In the UpgradeAck step of the handshake if both sides have already flushed all in-flight packets.
// - In the UpgradeOpen step of the handshake.
func (k Keeper) writeUpgradeOpenChannel(ctx sdk.Context, portID, channelID string) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this will need to be exported if we are calling it from the message server layer right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yess! Will do so in #3895 to keep things relevant

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we need unit tests for this write function?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, guess the idea for Ack was test since there's the conditional there that leads to two different states. I do think we're lacking test cov on some checks in the write functions for the other handlers so I'd definitely be in favour of testing the rest in a follow up.

channel, found := k.GetChannel(ctx, portID, channelID)
if !found {
panic(fmt.Sprintf("could not find existing channel when updating channel state, channelID: %s, portID: %s", channelID, portID))
}

upgrade, found := k.GetUpgrade(ctx, portID, channelID)
if !found {
panic(fmt.Sprintf("could not find upgrade when updating channel state, channelID: %s, portID: %s", channelID, portID))
}

// Switch channel fields to upgrade fields and set channel state to OPEN
previousState := channel.State
channel.Ordering = upgrade.Fields.Ordering
channel.Version = upgrade.Fields.Version
channel.ConnectionHops = upgrade.Fields.ConnectionHops
channel.State = types.OPEN
channel.FlushStatus = types.NOTINFLUSH

k.SetChannel(ctx, portID, channelID, channel)

// Delete auxiliary state.
k.deleteUpgrade(ctx, portID, channelID)
k.deleteCounterpartyLastPacketSequence(ctx, portID, channelID)
Comment on lines +431 to +432
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

potential candidate for another helper fn since it's used in restoreChannel but definitely not a big deal.


k.Logger(ctx).Info("channel state updated", "port-id", portID, "channel-id", channelID, "previous-state", previousState.String(), "new-state", types.OPEN.String())
emitChannelUpgradeOpenEvent(ctx, portID, channelID, channel)
}

// ChanUpgradeCancel is called by a module to cancel a channel upgrade that is in progress.
func (k Keeper) ChanUpgradeCancel(ctx sdk.Context, portID, channelID string, errorReceipt types.ErrorReceipt, errorReceiptProof []byte, proofHeight clienttypes.Height) error {
channel, found := k.GetChannel(ctx, portID, channelID)
Expand Down Expand Up @@ -597,38 +723,6 @@ func (k Keeper) startFlushUpgradeHandshake(
return nil
}

// WriteUpgradeOpenChannel writes the agreed upon upgrade fields to the channel, sets the channel flush status to NOTINFLUSH and sets the channel state back to OPEN. This can be called in one of two cases:
// - In the UpgradeAck step of the handshake if both sides have already flushed all in-flight packets.
// - In the UpgradeOpen step of the handshake.
func (k Keeper) writeUpgradeOpenChannel(ctx sdk.Context, portID, channelID string) {
channel, found := k.GetChannel(ctx, portID, channelID)
if !found {
panic(fmt.Sprintf("could not find existing channel when updating channel state, channelID: %s, portID: %s", channelID, portID))
}

upgrade, found := k.GetUpgrade(ctx, portID, channelID)
if !found {
panic(fmt.Sprintf("could not find upgrade when updating channel state, channelID: %s, portID: %s", channelID, portID))
}

// Switch channel fields to upgrade fields and set channel state to OPEN
previousState := channel.State
channel.Ordering = upgrade.Fields.Ordering
channel.Version = upgrade.Fields.Version
channel.ConnectionHops = upgrade.Fields.ConnectionHops
channel.State = types.OPEN
channel.FlushStatus = types.NOTINFLUSH

k.SetChannel(ctx, portID, channelID, channel)

// Delete auxiliary state.
k.deleteUpgrade(ctx, portID, channelID)
k.deleteCounterpartyLastPacketSequence(ctx, portID, channelID)

k.Logger(ctx).Info("channel state updated", "port-id", portID, "channel-id", channelID, "previous-state", previousState.String(), "new-state", types.OPEN.String())
emitChannelUpgradeOpenEvent(ctx, portID, channelID, channel)
}

// validateUpgradeFields validates the proposed upgrade fields against the existing channel.
// It returns an error if the following constraints are not met:
// - there exists at least one valid proposed change to the existing channel fields
Expand Down
116 changes: 116 additions & 0 deletions modules/core/04-channel/keeper/upgrade_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -487,6 +487,122 @@ func (suite *KeeperTestSuite) TestChanUpgradeAck() {
}
}

func (suite *KeeperTestSuite) TestChanUpgradeOpen() {
var path *ibctesting.Path
testCases := []struct {
name string
malleate func()
expError error
}{
{
"success",
func() {},
nil,
},
{
"channel not found",
func() {
path.EndpointA.ChannelConfig.PortID = ibctesting.InvalidID
},
types.ErrChannelNotFound,
},

{
"channel state is not in TRYUPGRADE or ACKUPGRADE",
func() {
suite.Require().NoError(path.EndpointA.SetChannelState(types.OPEN))
},
types.ErrInvalidChannelState,
},

{
"channel has in-flight packets",
func() {
portID := path.EndpointA.ChannelConfig.PortID
channelID := path.EndpointA.ChannelID
// Set a dummy packet commitment to simulate in-flight packets
suite.chainA.GetSimApp().IBCKeeper.ChannelKeeper.SetPacketCommitment(suite.chainA.GetContext(), portID, channelID, 1, []byte("hash"))
},
types.ErrPendingInflightPackets,
},
{
"flush status is not FLUSHCOMPLETE",
func() {
channel := path.EndpointA.GetChannel()
DimitrisJim marked this conversation as resolved.
Show resolved Hide resolved
channel.FlushStatus = types.FLUSHING
path.EndpointA.SetChannel(channel)
},
types.ErrInvalidFlushStatus,
},
{
"connection not found",
func() {
channel := path.EndpointA.GetChannel()
channel.ConnectionHops = []string{"connection-100"}
path.EndpointA.SetChannel(channel)
},
connectiontypes.ErrConnectionNotFound,
},
{
"invalid connection state",
func() {
connectionEnd := path.EndpointA.GetConnection()
connectionEnd.State = connectiontypes.UNINITIALIZED
path.EndpointA.SetConnection(connectionEnd)
},
connectiontypes.ErrInvalidConnectionState,
},
}

for _, tc := range testCases {
tc := tc
suite.Run(tc.name, func() {
expPass := tc.expError == nil
DimitrisJim marked this conversation as resolved.
Show resolved Hide resolved
suite.SetupTest()

path = ibctesting.NewPath(suite.chainA, suite.chainB)
suite.coordinator.Setup(path)

path.EndpointA.ChannelConfig.ProposedUpgrade.Fields.Version = mock.UpgradeVersion
path.EndpointB.ChannelConfig.ProposedUpgrade.Fields.Version = mock.UpgradeVersion

err := path.EndpointB.ChanUpgradeInit()
suite.Require().NoError(err)

err = path.EndpointA.ChanUpgradeTry()
suite.Require().NoError(err)

err = path.EndpointB.ChanUpgradeAck()
suite.Require().NoError(err)

// TODO: Remove setting of FLUSHCOMPLETE once #3928 is completed
channelB := path.EndpointB.GetChannel()
channelB.FlushStatus = types.FLUSHCOMPLETE
path.EndpointB.SetChannel(channelB)

channelA := path.EndpointA.GetChannel()
channelA.FlushStatus = types.FLUSHCOMPLETE
path.EndpointA.SetChannel(channelA)

suite.coordinator.CommitBlock(suite.chainA, suite.chainB)
suite.Require().NoError(path.EndpointA.UpdateClient())

tc.malleate()

proofCounterpartyChannel, _, proofHeight := path.EndpointA.QueryChannelUpgradeProof()
err = suite.chainA.GetSimApp().IBCKeeper.ChannelKeeper.ChanUpgradeOpen(
suite.chainA.GetContext(), path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID,
path.EndpointB.GetChannel().State, proofCounterpartyChannel, proofHeight,
)
if expPass {
suite.Require().NoError(err)
} else {
suite.Require().ErrorIs(err, tc.expError)
}
})
}
}

func (suite *KeeperTestSuite) TestChanUpgradeTimeout() {
var (
path *ibctesting.Path
Expand Down
1 change: 1 addition & 0 deletions modules/core/04-channel/types/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,4 +50,5 @@ var (
ErrUpgradeRestoreFailed = errorsmod.Register(SubModuleName, 34, "restore failed")
ErrUpgradeTimeout = errorsmod.Register(SubModuleName, 35, "upgrade timed-out")
ErrInvalidUpgradeTimeout = errorsmod.Register(SubModuleName, 36, "upgrade timeout is invalid")
ErrPendingInflightPackets = errorsmod.Register(SubModuleName, 37, "pending inflight packets exist")
)