Skip to content

Commit

Permalink
chore: ctrl port connection id validation (#454)
Browse files Browse the repository at this point in the history
* adding pipe char | to identifier regex

* updating GeneratePortID to use pipe delimiter in favour of dash, ParseAddressFromVersion to use Split in favour of TrimPrefix

* adding CounterpartyHops method to expected channel keeper interface

* updating tests to satisy delimiter updates

* adding connection seq validation of ctrl port id and updating tests

* cleanup

* adding defensive check for ParseAddressFromVersion

* adding conn sequence parsing funcs to pkg types

* moving conn sequence validation to reusable func

* updating error msgs, adding tests for conn seq parsers

* adding expected sequence to error msgs

* updating ParseCtrlConnSequence to ParseControllerConnSequence

* fixing counterparty port error

* Update modules/apps/27-interchain-accounts/keeper/handshake.go

Co-authored-by: colin axnér <25233464+colin-axner@users.noreply.github.com>

* Update modules/apps/27-interchain-accounts/keeper/handshake.go

Co-authored-by: colin axnér <25233464+colin-axner@users.noreply.github.com>

* Update modules/apps/27-interchain-accounts/types/account.go

Co-authored-by: colin axnér <25233464+colin-axner@users.noreply.github.com>

* removing pipe from valid identifier regex

* adding error returns to parsing funcs, updating tests, error messages

* separting imports in keys.go

* updating handshake tests

* Update modules/apps/27-interchain-accounts/types/keys.go

Co-authored-by: colin axnér <25233464+colin-axner@users.noreply.github.com>

* renaming validation func, removing parenthesis in error msgs

* renaming func validateControllerPort -> validateControllerPortParams

Co-authored-by: colin axnér <25233464+colin-axner@users.noreply.github.com>
  • Loading branch information
damiannolan and colin-axner authored Oct 7, 2021
1 parent 1f87f2e commit f129376
Show file tree
Hide file tree
Showing 9 changed files with 455 additions and 59 deletions.
88 changes: 82 additions & 6 deletions modules/apps/27-interchain-accounts/keeper/handshake.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
capabilitytypes "github.com/cosmos/cosmos-sdk/x/capability/types"

"github.com/cosmos/ibc-go/v2/modules/apps/27-interchain-accounts/types"
connectiontypes "github.com/cosmos/ibc-go/v2/modules/core/03-connection/types"
channeltypes "github.com/cosmos/ibc-go/v2/modules/core/04-channel/types"
porttypes "github.com/cosmos/ibc-go/v2/modules/core/05-port/types"
host "github.com/cosmos/ibc-go/v2/modules/core/24-host"
Expand All @@ -31,10 +32,25 @@ func (k Keeper) OnChanOpenInit(
version string,
) error {
if order != channeltypes.ORDERED {
return sdkerrors.Wrapf(channeltypes.ErrInvalidChannelOrdering, "invalid channel ordering: %s, expected %s", order.String(), channeltypes.ORDERED.String())
return sdkerrors.Wrapf(channeltypes.ErrInvalidChannelOrdering, "expected %s channel, got %s", channeltypes.ORDERED, order)
}

connSequence, err := types.ParseControllerConnSequence(portID)
if err != nil {
return sdkerrors.Wrapf(err, "expected format %s, got %s", types.ControllerPortFormat, portID)
}

counterpartyConnSequence, err := types.ParseHostConnSequence(portID)
if err != nil {
return sdkerrors.Wrapf(err, "expected format %s, got %s", types.ControllerPortFormat, portID)
}

if err := k.validateControllerPortParams(ctx, channelID, portID, connSequence, counterpartyConnSequence); err != nil {
return sdkerrors.Wrapf(err, "failed to validate controller port %s", portID)
}

if counterparty.PortId != types.PortID {
return sdkerrors.Wrapf(porttypes.ErrInvalidPort, "counterparty port-id must be '%s', (%s != %s)", types.PortID, counterparty.PortId, types.PortID)
return sdkerrors.Wrapf(porttypes.ErrInvalidPort, "expected %s, got %s", types.PortID, counterparty.PortId)
}

if err := types.ValidateVersion(version); err != nil {
Expand All @@ -43,7 +59,7 @@ func (k Keeper) OnChanOpenInit(

existingChannelID, found := k.GetActiveChannel(ctx, portID)
if found {
return sdkerrors.Wrapf(porttypes.ErrInvalidPort, "existing active channel (%s) for portID (%s)", existingChannelID, portID)
return sdkerrors.Wrapf(porttypes.ErrInvalidPort, "existing active channel %s for portID %s", existingChannelID, portID)
}

// Claim channel capability passed back by IBC module
Expand All @@ -70,7 +86,25 @@ func (k Keeper) OnChanOpenTry(
counterpartyVersion string,
) error {
if order != channeltypes.ORDERED {
return sdkerrors.Wrapf(channeltypes.ErrInvalidChannelOrdering, "invalid channel ordering: %s, expected %s", order.String(), channeltypes.ORDERED.String())
return sdkerrors.Wrapf(channeltypes.ErrInvalidChannelOrdering, "expected %s channel, got %s", channeltypes.ORDERED, order)
}

if portID != types.PortID {
return sdkerrors.Wrapf(porttypes.ErrInvalidPort, "expected %s, got %s", types.PortID, portID)
}

connSequence, err := types.ParseHostConnSequence(counterparty.PortId)
if err != nil {
return sdkerrors.Wrapf(err, "expected format %s, got %s", types.ControllerPortFormat, counterparty.PortId)
}

counterpartyConnSequence, err := types.ParseControllerConnSequence(counterparty.PortId)
if err != nil {
return sdkerrors.Wrapf(err, "expected format %s, got %s", types.ControllerPortFormat, counterparty.PortId)
}

if err := k.validateControllerPortParams(ctx, channelID, portID, connSequence, counterpartyConnSequence); err != nil {
return sdkerrors.Wrapf(err, "failed to validate controller port %s", counterparty.PortId)
}

if err := types.ValidateVersion(version); err != nil {
Expand All @@ -89,7 +123,11 @@ func (k Keeper) OnChanOpenTry(

// Check to ensure that the version string contains the expected address generated from the Counterparty portID
accAddr := types.GenerateAddress(k.accountKeeper.GetModuleAddress(types.ModuleName), counterparty.PortId)
parsedAddr := types.ParseAddressFromVersion(version)
parsedAddr, err := types.ParseAddressFromVersion(version)
if err != nil {
return sdkerrors.Wrapf(err, "expected format <app-version%saccount-address>, got %s", types.Delimiter, version)
}

if parsedAddr != accAddr.String() {
return sdkerrors.Wrapf(types.ErrInvalidAccountAddress, "version contains invalid account address: expected %s, got %s", parsedAddr, accAddr)
}
Expand All @@ -116,7 +154,11 @@ func (k Keeper) OnChanOpenAck(

k.SetActiveChannel(ctx, portID, channelID)

accAddr := types.ParseAddressFromVersion(counterpartyVersion)
accAddr, err := types.ParseAddressFromVersion(counterpartyVersion)
if err != nil {
return sdkerrors.Wrapf(err, "expected format <app-version%saccount-address>, got %s", types.Delimiter, counterpartyVersion)
}

k.SetInterchainAccountAddress(ctx, portID, accAddr)

return nil
Expand All @@ -130,3 +172,37 @@ func (k Keeper) OnChanOpenConfirm(
) error {
return nil
}

// validateControllerPortParams asserts the provided connection sequence and counterparty connection sequence
// match that of the associated connection stored in state
func (k Keeper) validateControllerPortParams(ctx sdk.Context, channelID, portID string, connectionSeq, counterpartyConnectionSeq uint64) error {
channel, found := k.channelKeeper.GetChannel(ctx, portID, channelID)
if !found {
return sdkerrors.Wrapf(channeltypes.ErrChannelNotFound, "port ID %s channel ID %s", portID, channelID)
}

counterpartyHops, found := k.channelKeeper.CounterpartyHops(ctx, channel)
if !found {
return sdkerrors.Wrap(connectiontypes.ErrConnectionNotFound, channel.ConnectionHops[0])
}

connSeq, err := connectiontypes.ParseConnectionSequence(channel.ConnectionHops[0])
if err != nil {
return sdkerrors.Wrapf(err, "failed to parse connection sequence %s", channel.ConnectionHops[0])
}

counterpartyConnSeq, err := connectiontypes.ParseConnectionSequence(counterpartyHops[0])
if err != nil {
return sdkerrors.Wrapf(err, "failed to parse counterparty connection sequence %s", counterpartyHops[0])
}

if connSeq != connectionSeq {
return sdkerrors.Wrapf(connectiontypes.ErrInvalidConnection, "sequence mismatch, expected %d, got %d", connSeq, connectionSeq)
}

if counterpartyConnSeq != counterpartyConnectionSeq {
return sdkerrors.Wrapf(connectiontypes.ErrInvalidConnection, "counterparty sequence mismatch, expected %d, got %d", counterpartyConnSeq, counterpartyConnectionSeq)
}

return nil
}
186 changes: 161 additions & 25 deletions modules/apps/27-interchain-accounts/keeper/handshake_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,33 +23,94 @@ func (suite *KeeperTestSuite) TestOnChanOpenInit() {
}{

{
"success", func() {}, true,
"success",
func() {
path.EndpointA.SetChannel(*channel)
},
true,
},
{
"invalid order - UNORDERED", func() {
"invalid order - UNORDERED",
func() {
channel.Ordering = channeltypes.UNORDERED
}, false,
},
false,
},
{
"invalid counterparty port ID", func() {
channel.Counterparty.PortId = ibctesting.MockPort
}, false,
"invalid port ID",
func() {
path.EndpointA.ChannelConfig.PortID = "invalid-port-id"
},
false,
},
{
"invalid version", func() {
"invalid counterparty port ID",
func() {
path.EndpointA.SetChannel(*channel)
channel.Counterparty.PortId = "invalid-port-id"
},
false,
},
{
"invalid version",
func() {
path.EndpointA.SetChannel(*channel)
channel.Version = "version"
}, false,
},
false,
},
{
"channel is already active", func() {
"channel not found",
func() {
path.EndpointA.ChannelID = "invalid-channel-id"
},
false,
},
{
"connection not found",
func() {
channel.ConnectionHops = []string{"invalid-connnection-id"}
path.EndpointA.SetChannel(*channel)
},
false,
},
{
"invalid connection sequence",
func() {
portID, err := types.GeneratePortID(TestOwnerAddress, "connection-1", "connection-0")
suite.Require().NoError(err)

path.EndpointA.ChannelConfig.PortID = portID
path.EndpointA.SetChannel(*channel)
},
false,
},
{
"invalid counterparty connection sequence",
func() {
portID, err := types.GeneratePortID(TestOwnerAddress, "connection-0", "connection-1")
suite.Require().NoError(err)

path.EndpointA.ChannelConfig.PortID = portID
path.EndpointA.SetChannel(*channel)
},
false,
},
{
"channel is already active",
func() {
suite.chainA.GetSimApp().ICAKeeper.SetActiveChannel(suite.chainA.GetContext(), path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID)
}, false,
},
false,
},
{
"capability already claimed", func() {
"capability already claimed",
func() {
path.EndpointA.SetChannel(*channel)
err := suite.chainA.GetSimApp().ScopedICAKeeper.ClaimCapability(suite.chainA.GetContext(), chanCap, host.ChannelCapabilityPath(path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID))
suite.Require().NoError(err)
}, false,
},
false,
},
}

Expand Down Expand Up @@ -97,7 +158,6 @@ func (suite *KeeperTestSuite) TestOnChanOpenInit() {
}
}

// ChainA is controller, ChainB is host chain
func (suite *KeeperTestSuite) TestOnChanOpenTry() {
var (
channel *channeltypes.Channel
Expand All @@ -113,33 +173,105 @@ func (suite *KeeperTestSuite) TestOnChanOpenTry() {
}{

{
"success", func() {}, true,
"success",
func() {
path.EndpointB.SetChannel(*channel)
},
true,
},
{
"invalid order - UNORDERED", func() {
"invalid order - UNORDERED",
func() {
channel.Ordering = channeltypes.UNORDERED
}, false,
},
false,
},
{
"invalid port",
func() {
path.EndpointB.ChannelConfig.PortID = "invalid-port-id"
},
false,
},
{
"invalid counterparty port",
func() {
channel.Counterparty.PortId = "invalid-port-id"
},
false,
},
{
"invalid version", func() {
"channel not found",
func() {
path.EndpointB.ChannelID = "invalid-channel-id"
},
false,
},
{
"connection not found",
func() {
channel.ConnectionHops = []string{"invalid-connnection-id"}
path.EndpointB.SetChannel(*channel)
},
false,
},
{
"invalid connection sequence",
func() {
portID, err := types.GeneratePortID(TestOwnerAddress, "connection-0", "connection-1")
suite.Require().NoError(err)

channel.Counterparty.PortId = portID
path.EndpointB.SetChannel(*channel)
},
false,
},
{
"invalid counterparty connection sequence",
func() {
portID, err := types.GeneratePortID(TestOwnerAddress, "connection-1", "connection-0")
suite.Require().NoError(err)

channel.Counterparty.PortId = portID
path.EndpointB.SetChannel(*channel)
},
false,
},
{
"invalid version",
func() {
channel.Version = "version"
}, false,
path.EndpointB.SetChannel(*channel)
},
false,
},
{
"invalid counterparty version", func() {
"invalid counterparty version",
func() {
counterpartyVersion = "version"
}, false,
path.EndpointB.SetChannel(*channel)
},
false,
},
{
"capability already claimed", func() {
"capability already claimed",
func() {
path.EndpointB.SetChannel(*channel)
err := suite.chainB.GetSimApp().ScopedICAKeeper.ClaimCapability(suite.chainB.GetContext(), chanCap, host.ChannelCapabilityPath(path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID))
suite.Require().NoError(err)
}, false,
},
false,
},
{
"invalid account address", func() {
channel.Counterparty.PortId = "invalid-port-id"
}, false,
"invalid account address",
func() {
portID, err := types.GeneratePortID("invalid-owner-addr", "connection-0", "connection-0")
suite.Require().NoError(err)

channel.Counterparty.PortId = portID
path.EndpointB.SetChannel(*channel)
},
false,
},
}

Expand All @@ -155,6 +287,10 @@ func (suite *KeeperTestSuite) TestOnChanOpenTry() {
err := InitInterchainAccount(path.EndpointA, TestOwnerAddress)
suite.Require().NoError(err)

// set the channel id on host
channelSequence := path.EndpointB.Chain.App.GetIBCKeeper().ChannelKeeper.GetNextChannelSequence(path.EndpointB.Chain.GetContext())
path.EndpointB.ChannelID = channeltypes.FormatChannelIdentifier(channelSequence)

// default values
counterparty := channeltypes.NewCounterparty(path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID)
channel = &channeltypes.Channel{
Expand Down
3 changes: 1 addition & 2 deletions modules/apps/27-interchain-accounts/keeper/keeper_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package keeper_test

import (
"fmt"
"testing"

sdk "github.com/cosmos/cosmos-sdk/types"
Expand All @@ -21,7 +20,7 @@ var (
// TestOwnerAddress defines a reusable bech32 address for testing purposes
TestOwnerAddress = "cosmos17dtl0mjt3t77kpuhg2edqzjpszulwhgzuj9ljs"
// TestPortID defines a resuable port identifier for testing purposes
TestPortID = fmt.Sprintf("%s-0-0-%s", types.VersionPrefix, TestOwnerAddress)
TestPortID, _ = types.GeneratePortID(TestOwnerAddress, "connection-0", "connection-0")
// TestVersion defines a resuable interchainaccounts version string for testing purposes
TestVersion = types.NewAppVersion(types.VersionPrefix, TestAccAddress.String())
)
Expand Down
Loading

0 comments on commit f129376

Please sign in to comment.