diff --git a/app/app.go b/app/app.go index 2545a2ef6d..85cd58d8bd 100644 --- a/app/app.go +++ b/app/app.go @@ -12,6 +12,7 @@ import ( tmproto "github.com/cometbft/cometbft/proto/tendermint/types" dbm "github.com/cosmos/cosmos-db" "github.com/cosmos/gogoproto/proto" + ibccallbacks "github.com/cosmos/ibc-go/modules/apps/callbacks" "github.com/cosmos/ibc-go/modules/capability" capabilitykeeper "github.com/cosmos/ibc-go/modules/capability/keeper" capabilitytypes "github.com/cosmos/ibc-go/modules/capability/types" @@ -650,10 +651,10 @@ func NewWasmApp( wasmOpts..., ) - // Create Transfer Stack - var transferStack porttypes.IBCModule - transferStack = transfer.NewIBCModule(app.TransferKeeper) - transferStack = ibcfee.NewIBCMiddleware(transferStack, app.IBCFeeKeeper) + // Create fee enabled wasm ibc Stack + var wasmStack porttypes.IBCModule + wasmStackIBCHandler := wasm.NewIBCHandler(app.WasmKeeper, app.IBCKeeper.ChannelKeeper, app.IBCFeeKeeper) + wasmStack = ibcfee.NewIBCMiddleware(wasmStackIBCHandler, app.IBCFeeKeeper) // Create Interchain Accounts Stack // SendPacket, since it is originating from the application to core IBC: @@ -663,7 +664,13 @@ func NewWasmApp( // see https://medium.com/the-interchain-foundation/ibc-go-v6-changes-to-interchain-accounts-and-how-it-impacts-your-chain-806c185300d7 var noAuthzModule porttypes.IBCModule icaControllerStack = icacontroller.NewIBCMiddleware(noAuthzModule, app.ICAControllerKeeper) + // app.ICAAuthModule = icaControllerStack.(ibcmock.IBCModule) + icaControllerStack = icacontroller.NewIBCMiddleware(icaControllerStack, app.ICAControllerKeeper) + icaControllerStack = ibccallbacks.NewIBCMiddleware(icaControllerStack, app.IBCFeeKeeper, wasmStackIBCHandler, wasm.DefaultMaxIBCCallbackGas) + icaICS4Wrapper := icaControllerStack.(porttypes.ICS4Wrapper) icaControllerStack = ibcfee.NewIBCMiddleware(icaControllerStack, app.IBCFeeKeeper) + // Since the callbacks middleware itself is an ics4wrapper, it needs to be passed to the ica controller keeper + app.ICAControllerKeeper.WithICS4Wrapper(icaICS4Wrapper) // RecvPacket, message that originates from core IBC and goes down to app, the flow is: // channel.RecvPacket -> fee.OnRecvPacket -> icaHost.OnRecvPacket @@ -671,10 +678,14 @@ func NewWasmApp( icaHostStack = icahost.NewIBCModule(app.ICAHostKeeper) icaHostStack = ibcfee.NewIBCMiddleware(icaHostStack, app.IBCFeeKeeper) - // Create fee enabled wasm ibc Stack - var wasmStack porttypes.IBCModule - wasmStack = wasm.NewIBCHandler(app.WasmKeeper, app.IBCKeeper.ChannelKeeper, app.IBCFeeKeeper) - wasmStack = ibcfee.NewIBCMiddleware(wasmStack, app.IBCFeeKeeper) + // Create Transfer Stack + var transferStack porttypes.IBCModule + transferStack = transfer.NewIBCModule(app.TransferKeeper) + transferStack = ibccallbacks.NewIBCMiddleware(transferStack, app.IBCFeeKeeper, wasmStackIBCHandler, wasm.DefaultMaxIBCCallbackGas) + transferICS4Wrapper := transferStack.(porttypes.ICS4Wrapper) + transferStack = ibcfee.NewIBCMiddleware(transferStack, app.IBCFeeKeeper) + // Since the callbacks middleware itself is an ics4wrapper, it needs to be passed to the ica controller keeper + app.TransferKeeper.WithICS4Wrapper(transferICS4Wrapper) // Create static IBC router, add app routes, then set and seal it ibcRouter := porttypes.NewRouter(). diff --git a/go.mod b/go.mod index 55b7dc6672..90d55fd641 100644 --- a/go.mod +++ b/go.mod @@ -45,6 +45,7 @@ require ( cosmossdk.io/x/upgrade v0.1.3 github.com/cometbft/cometbft v0.38.9 github.com/cosmos/cosmos-db v1.0.2 + github.com/cosmos/ibc-go/modules/apps/callbacks v0.2.1-0.20231113120333-342c00b0f8bd github.com/cosmos/ibc-go/modules/capability v1.0.0 github.com/cosmos/ibc-go/v8 v8.3.2 github.com/distribution/reference v0.5.0 diff --git a/go.sum b/go.sum index 03f5de0cee..32ee3b3169 100644 --- a/go.sum +++ b/go.sum @@ -364,6 +364,8 @@ github.com/cosmos/gogoproto v1.5.0 h1:SDVwzEqZDDBoslaeZg+dGE55hdzHfgUA40pEanMh52 github.com/cosmos/gogoproto v1.5.0/go.mod h1:iUM31aofn3ymidYG6bUR5ZFrk+Om8p5s754eMUcyp8I= github.com/cosmos/iavl v1.2.0 h1:kVxTmjTh4k0Dh1VNL046v6BXqKziqMDzxo93oh3kOfM= github.com/cosmos/iavl v1.2.0/go.mod h1:HidWWLVAtODJqFD6Hbne2Y0q3SdxByJepHUOeoH4LiI= +github.com/cosmos/ibc-go/modules/apps/callbacks v0.2.1-0.20231113120333-342c00b0f8bd h1:Lx+/5dZ/nN6qPXP2Ofog6u1fmlkCFA1ElcOconnofEM= +github.com/cosmos/ibc-go/modules/apps/callbacks v0.2.1-0.20231113120333-342c00b0f8bd/go.mod h1:JWfpWVKJKiKtd53/KbRoKfxWl8FsT2GPcNezTOk0o5Q= github.com/cosmos/ibc-go/modules/capability v1.0.0 h1:r/l++byFtn7jHYa09zlAdSeevo8ci1mVZNO9+V0xsLE= github.com/cosmos/ibc-go/modules/capability v1.0.0/go.mod h1:D81ZxzjZAe0ZO5ambnvn1qedsFQ8lOwtqicG6liLBco= github.com/cosmos/ibc-go/v8 v8.3.2 h1:8X1oHHKt2Bh9hcExWS89rntLaCKZp2EjFTUSxKlPhGI= diff --git a/tests/e2e/README.md b/tests/e2e/README.md index dae38fe2df..b72bc4a1dc 100644 --- a/tests/e2e/README.md +++ b/tests/e2e/README.md @@ -1,3 +1,3 @@ # End To End Testing - e2e -Scenario tests that run against on or multiple chain instances. +Scenario tests that run against one or multiple chain instances. diff --git a/tests/e2e/ibc_callbacks_test.go b/tests/e2e/ibc_callbacks_test.go new file mode 100644 index 0000000000..03d438c08d --- /dev/null +++ b/tests/e2e/ibc_callbacks_test.go @@ -0,0 +1,225 @@ +package e2e_test + +import ( + "encoding/json" + "fmt" + "testing" + "time" + + wasmvmtypes "github.com/CosmWasm/wasmvm/v2/types" + ibcfee "github.com/cosmos/ibc-go/v8/modules/apps/29-fee/types" + ibctransfertypes "github.com/cosmos/ibc-go/v8/modules/apps/transfer/types" + channeltypes "github.com/cosmos/ibc-go/v8/modules/core/04-channel/types" + ibctesting "github.com/cosmos/ibc-go/v8/testing" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + sdkmath "cosmossdk.io/math" + + sdk "github.com/cosmos/cosmos-sdk/types" + + "github.com/CosmWasm/wasmd/app" + "github.com/CosmWasm/wasmd/tests/e2e" + wasmibctesting "github.com/CosmWasm/wasmd/x/wasm/ibctesting" + "github.com/CosmWasm/wasmd/x/wasm/types" +) + +func TestIBCCallbacks(t *testing.T) { + // scenario: + // given two chains + // with an ics-20 channel established + // and an ibc-callbacks contract deployed on chain A and B each + // when the contract on A sends an IBCMsg::Transfer to the contract on B + // then the contract on B should receive a destination chain callback + // and the contract on A should receive a source chain callback with the result (ack or timeout) + marshaler := app.MakeEncodingConfig(t).Codec + coord := wasmibctesting.NewCoordinator(t, 2) + chainA := coord.GetChain(wasmibctesting.GetChainID(1)) + chainB := coord.GetChain(wasmibctesting.GetChainID(2)) + + actorChainA := sdk.AccAddress(chainA.SenderPrivKey.PubKey().Address()) + oneToken := sdk.NewCoins(sdk.NewCoin(sdk.DefaultBondDenom, sdkmath.NewInt(1))) + + path := wasmibctesting.NewPath(chainA, chainB) + path.EndpointA.ChannelConfig = &ibctesting.ChannelConfig{ + PortID: ibctransfertypes.PortID, + Version: string(marshaler.MustMarshalJSON(&ibcfee.Metadata{FeeVersion: ibcfee.Version, AppVersion: ibctransfertypes.Version})), + Order: channeltypes.UNORDERED, + } + path.EndpointB.ChannelConfig = &ibctesting.ChannelConfig{ + PortID: ibctransfertypes.PortID, + Version: string(marshaler.MustMarshalJSON(&ibcfee.Metadata{FeeVersion: ibcfee.Version, AppVersion: ibctransfertypes.Version})), + Order: channeltypes.UNORDERED, + } + // with an ics-20 transfer channel setup between both chains + coord.Setup(path) + + // with an ibc-callbacks contract deployed on chain A + codeIDonA := chainA.StoreCodeFile("./testdata/ibc_callbacks.wasm").CodeID + + // and on chain B + codeIDonB := chainB.StoreCodeFile("./testdata/ibc_callbacks.wasm").CodeID + + type TransferExecMsg struct { + ToAddress string `json:"to_address"` + ChannelID string `json:"channel_id"` + TimeoutSeconds uint32 `json:"timeout_seconds"` + } + // ExecuteMsg is the ibc-callbacks contract's execute msg + type ExecuteMsg struct { + Transfer *TransferExecMsg `json:"transfer"` + } + type QueryMsg struct { + CallbackStats struct{} `json:"callback_stats"` + } + type QueryResp struct { + IBCAckCallbacks []wasmvmtypes.IBCPacketAckMsg `json:"ibc_ack_callbacks"` + IBCTimeoutCallbacks []wasmvmtypes.IBCPacketTimeoutMsg `json:"ibc_timeout_callbacks"` + IBCDestinationCallbacks []wasmvmtypes.IBCDestinationCallbackMsg `json:"ibc_destination_callbacks"` + } + + specs := map[string]struct { + contractMsg ExecuteMsg + // expAck is true if the packet is relayed, false if it times out + expAck bool + }{ + "success": { + contractMsg: ExecuteMsg{ + Transfer: &TransferExecMsg{ + ChannelID: path.EndpointA.ChannelID, + TimeoutSeconds: 100, + }, + }, + expAck: true, + }, + "timeout": { + contractMsg: ExecuteMsg{ + Transfer: &TransferExecMsg{ + ChannelID: path.EndpointA.ChannelID, + TimeoutSeconds: 1, + }, + }, + expAck: false, + }, + } + + for name, spec := range specs { + t.Run(name, func(t *testing.T) { + contractAddrA := chainA.InstantiateContract(codeIDonA, []byte(`{}`)) + require.NotEmpty(t, contractAddrA) + contractAddrB := chainB.InstantiateContract(codeIDonB, []byte(`{}`)) + require.NotEmpty(t, contractAddrB) + + if spec.contractMsg.Transfer != nil && spec.contractMsg.Transfer.ToAddress == "" { + spec.contractMsg.Transfer.ToAddress = contractAddrB.String() + } + contractMsgBz, err := json.Marshal(spec.contractMsg) + require.NoError(t, err) + + // when the contract on chain A sends an IBCMsg::Transfer to the contract on chain B + execMsg := types.MsgExecuteContract{ + Sender: actorChainA.String(), + Contract: contractAddrA.String(), + Msg: contractMsgBz, + Funds: oneToken, + } + _, err = chainA.SendMsgs(&execMsg) + require.NoError(t, err) + + if spec.expAck { + // and the packet is relayed + require.NoError(t, coord.RelayAndAckPendingPackets(path)) + + // then the contract on chain B should receive a receive callback + var response QueryResp + chainB.SmartQuery(contractAddrB.String(), QueryMsg{CallbackStats: struct{}{}}, &response) + assert.Empty(t, response.IBCAckCallbacks) + assert.Empty(t, response.IBCTimeoutCallbacks) + assert.Len(t, response.IBCDestinationCallbacks, 1) + + // and the receive callback should contain the ack + assert.Equal(t, []byte("{\"result\":\"AQ==\"}"), response.IBCDestinationCallbacks[0].Ack.Data) + + // and the contract on chain A should receive a callback with the ack + chainA.SmartQuery(contractAddrA.String(), QueryMsg{CallbackStats: struct{}{}}, &response) + assert.Len(t, response.IBCAckCallbacks, 1) + assert.Empty(t, response.IBCTimeoutCallbacks) + assert.Empty(t, response.IBCDestinationCallbacks) + + // and the ack result should be the ics20 success ack + assert.Equal(t, []byte(`{"result":"AQ=="}`), response.IBCAckCallbacks[0].Acknowledgement.Data) + } else { + // and the packet times out + require.NoError(t, coord.TimeoutPendingPackets(path)) + + // then the contract on chain B should not receive anything + var response QueryResp + chainB.SmartQuery(contractAddrB.String(), QueryMsg{CallbackStats: struct{}{}}, &response) + assert.Empty(t, response.IBCAckCallbacks) + assert.Empty(t, response.IBCTimeoutCallbacks) + assert.Empty(t, response.IBCDestinationCallbacks) + + // and the contract on chain A should receive a callback with the timeout result + chainA.SmartQuery(contractAddrA.String(), QueryMsg{CallbackStats: struct{}{}}, &response) + assert.Empty(t, response.IBCAckCallbacks) + assert.Len(t, response.IBCTimeoutCallbacks, 1) + assert.Empty(t, response.IBCDestinationCallbacks) + } + }) + } +} + +func TestIBCCallbacksWithoutEntrypoints(t *testing.T) { + // scenario: + // given two chains + // with an ics-20 channel established + // and a reflect contract deployed on chain A and B each + // when the contract on A sends an IBCMsg::Transfer to the contract on B + // then the VM should try to call the callback on B and fail gracefully + // and should try to call the callback on A and fail gracefully + marshaler := app.MakeEncodingConfig(t).Codec + coord := wasmibctesting.NewCoordinator(t, 2) + chainA := coord.GetChain(wasmibctesting.GetChainID(1)) + chainB := coord.GetChain(wasmibctesting.GetChainID(2)) + + oneToken := sdk.NewCoin(sdk.DefaultBondDenom, sdkmath.NewInt(1)) + + path := wasmibctesting.NewPath(chainA, chainB) + path.EndpointA.ChannelConfig = &ibctesting.ChannelConfig{ + PortID: ibctransfertypes.PortID, + Version: string(marshaler.MustMarshalJSON(&ibcfee.Metadata{FeeVersion: ibcfee.Version, AppVersion: ibctransfertypes.Version})), + Order: channeltypes.UNORDERED, + } + path.EndpointB.ChannelConfig = &ibctesting.ChannelConfig{ + PortID: ibctransfertypes.PortID, + Version: string(marshaler.MustMarshalJSON(&ibcfee.Metadata{FeeVersion: ibcfee.Version, AppVersion: ibctransfertypes.Version})), + Order: channeltypes.UNORDERED, + } + // with an ics-20 transfer channel setup between both chains + coord.Setup(path) + + // with a reflect contract deployed on chain A and B + contractAddrA := e2e.InstantiateReflectContract(t, chainA) + chainA.Fund(contractAddrA, oneToken.Amount) + contractAddrB := e2e.InstantiateReflectContract(t, chainA) + + // when the contract on A sends an IBCMsg::Transfer to the contract on B + memo := fmt.Sprintf(`{"src_callback":{"address":"%v"},"dest_callback":{"address":"%v"}}`, contractAddrA.String(), contractAddrB.String()) + e2e.MustExecViaReflectContract(t, chainA, contractAddrA, wasmvmtypes.CosmosMsg{ + IBC: &wasmvmtypes.IBCMsg{ + Transfer: &wasmvmtypes.TransferMsg{ + ToAddress: contractAddrB.String(), + ChannelID: path.EndpointA.ChannelID, + Amount: wasmvmtypes.NewCoin(oneToken.Amount.Uint64(), oneToken.Denom), + Timeout: wasmvmtypes.IBCTimeout{ + Timestamp: uint64(chainA.LastHeader.GetTime().Add(time.Second * 100).UnixNano()), + }, + Memo: memo, + }, + }, + }) + + // and the packet is relayed without problems + require.NoError(t, coord.RelayAndAckPendingPackets(path)) + assert.Empty(t, chainA.PendingSendPackets) +} diff --git a/tests/e2e/testdata/ibc_callbacks.wasm b/tests/e2e/testdata/ibc_callbacks.wasm new file mode 100644 index 0000000000..63519505aa Binary files /dev/null and b/tests/e2e/testdata/ibc_callbacks.wasm differ diff --git a/x/wasm/ibc.go b/x/wasm/ibc.go index f58e0dd644..b0311f5227 100644 --- a/x/wasm/ibc.go +++ b/x/wasm/ibc.go @@ -5,6 +5,7 @@ import ( wasmvmtypes "github.com/CosmWasm/wasmvm/v2/types" capabilitytypes "github.com/cosmos/ibc-go/modules/capability/types" + clienttypes "github.com/cosmos/ibc-go/v8/modules/core/02-client/types" channeltypes "github.com/cosmos/ibc-go/v8/modules/core/04-channel/types" porttypes "github.com/cosmos/ibc-go/v8/modules/core/05-port/types" host "github.com/cosmos/ibc-go/v8/modules/core/24-host" @@ -18,6 +19,12 @@ import ( "github.com/CosmWasm/wasmd/x/wasm/types" ) +// DefaultMaxIBCCallbackGas is the default value of maximum gas that an IBC callback can use. +// If the callback uses more gas, it will be out of gas and the contract state changes will be reverted, +// but the transaction will be committed. +// Pass this to the callbacks middleware or choose a custom value. +const DefaultMaxIBCCallbackGas = uint64(1_000_000) + var _ porttypes.IBCModule = IBCHandler{} // internal interface that is implemented by ibc middleware @@ -326,22 +333,146 @@ func (i IBCHandler) OnTimeoutPacket(ctx sdk.Context, packet channeltypes.Packet, return nil } -func newIBCPacket(packet channeltypes.Packet) wasmvmtypes.IBCPacket { +// IBCSendPacketCallback implements the IBC Callbacks ContractKeeper interface +// see https://github.com/cosmos/ibc-go/blob/main/docs/architecture/adr-008-app-caller-cbs.md#contractkeeper +func (i IBCHandler) IBCSendPacketCallback( + cachedCtx sdk.Context, + sourcePort string, + sourceChannel string, + timeoutHeight clienttypes.Height, + timeoutTimestamp uint64, + packetData []byte, + contractAddress, + packetSenderAddress string, +) error { + _, err := validateSender(contractAddress, packetSenderAddress) + if err != nil { + return err + } + + // no-op, since we are not interested in this callback + return nil +} + +// IBCOnAcknowledgementPacketCallback implements the IBC Callbacks ContractKeeper interface +// see https://github.com/cosmos/ibc-go/blob/main/docs/architecture/adr-008-app-caller-cbs.md#contractkeeper +func (i IBCHandler) IBCOnAcknowledgementPacketCallback( + cachedCtx sdk.Context, + packet channeltypes.Packet, + acknowledgement []byte, + relayer sdk.AccAddress, + contractAddress, + packetSenderAddress string, +) error { + contractAddr, err := validateSender(contractAddress, packetSenderAddress) + if err != nil { + return err + } + + msg := wasmvmtypes.IBCSourceCallbackMsg{ + Acknowledgement: &wasmvmtypes.IBCAckCallbackMsg{ + Acknowledgement: wasmvmtypes.IBCAcknowledgement{Data: acknowledgement}, + OriginalPacket: newIBCPacket(packet), + Relayer: relayer.String(), + }, + } + err = i.keeper.IBCSourceCallback(cachedCtx, contractAddr, msg) + if err != nil { + return errorsmod.Wrap(err, "on source chain callback ack") + } + + return nil +} + +// IBCOnTimeoutPacketCallback implements the IBC Callbacks ContractKeeper interface +// see https://github.com/cosmos/ibc-go/blob/main/docs/architecture/adr-008-app-caller-cbs.md#contractkeeper +func (i IBCHandler) IBCOnTimeoutPacketCallback( + cachedCtx sdk.Context, + packet channeltypes.Packet, + relayer sdk.AccAddress, + contractAddress, + packetSenderAddress string, +) error { + contractAddr, err := validateSender(contractAddress, packetSenderAddress) + if err != nil { + return err + } + + msg := wasmvmtypes.IBCSourceCallbackMsg{ + Timeout: &wasmvmtypes.IBCTimeoutCallbackMsg{ + Packet: newIBCPacket(packet), + Relayer: relayer.String(), + }, + } + err = i.keeper.IBCSourceCallback(cachedCtx, contractAddr, msg) + if err != nil { + return errorsmod.Wrap(err, "on source chain callback timeout") + } + return nil +} + +// IBCReceivePacketCallback implements the IBC Callbacks ContractKeeper interface +// see https://github.com/cosmos/ibc-go/blob/main/docs/architecture/adr-008-app-caller-cbs.md#contractkeeper +func (i IBCHandler) IBCReceivePacketCallback( + cachedCtx sdk.Context, + packet ibcexported.PacketI, + ack ibcexported.Acknowledgement, + contractAddress string, +) error { + // sender validation makes no sense here, as the receiver is never the sender + contractAddr, err := sdk.AccAddressFromBech32(contractAddress) + if err != nil { + return err + } + + msg := wasmvmtypes.IBCDestinationCallbackMsg{ + Ack: wasmvmtypes.IBCAcknowledgement{Data: ack.Acknowledgement()}, + Packet: newIBCPacket(packet), + } + + err = i.keeper.IBCDestinationCallback(cachedCtx, contractAddr, msg) + if err != nil { + return errorsmod.Wrap(err, "on destination chain callback") + } + + return nil +} + +func validateSender(contractAddr, senderAddr string) (sdk.AccAddress, error) { + contractAddress, err := sdk.AccAddressFromBech32(contractAddr) + if err != nil { + return nil, errorsmod.Wrapf(err, "contract address") + } + senderAddress, err := sdk.AccAddressFromBech32(senderAddr) + if err != nil { + return nil, errorsmod.Wrapf(err, "packet sender address") + } + + // We only allow the contract that sent the message to receive source chain callbacks for it. + if !contractAddress.Equals(senderAddress) { + return nil, errorsmod.Wrapf(types.ErrExecuteFailed, "contract address %s does not match packet sender %s", contractAddr, senderAddress) + } + + return contractAddress, nil +} + +func newIBCPacket(packet ibcexported.PacketI) wasmvmtypes.IBCPacket { timeout := wasmvmtypes.IBCTimeout{ - Timestamp: packet.TimeoutTimestamp, + Timestamp: packet.GetTimeoutTimestamp(), } - if !packet.TimeoutHeight.IsZero() { + timeoutHeight := packet.GetTimeoutHeight() + if !timeoutHeight.IsZero() { timeout.Block = &wasmvmtypes.IBCTimeoutBlock{ - Height: packet.TimeoutHeight.RevisionHeight, - Revision: packet.TimeoutHeight.RevisionNumber, + Height: timeoutHeight.GetRevisionHeight(), + Revision: timeoutHeight.GetRevisionNumber(), } } return wasmvmtypes.IBCPacket{ - Data: packet.Data, - Src: wasmvmtypes.IBCEndpoint{ChannelID: packet.SourceChannel, PortID: packet.SourcePort}, - Dest: wasmvmtypes.IBCEndpoint{ChannelID: packet.DestinationChannel, PortID: packet.DestinationPort}, - Sequence: packet.Sequence, + Data: packet.GetData(), + Src: wasmvmtypes.IBCEndpoint{ChannelID: packet.GetSourceChannel(), PortID: packet.GetSourcePort()}, + Dest: wasmvmtypes.IBCEndpoint{ChannelID: packet.GetDestChannel(), PortID: packet.GetDestPort()}, + Sequence: packet.GetSequence(), Timeout: timeout, } } diff --git a/x/wasm/keeper/relay.go b/x/wasm/keeper/relay.go index 65cb51dbf4..a12293197b 100644 --- a/x/wasm/keeper/relay.go +++ b/x/wasm/keeper/relay.go @@ -270,6 +270,74 @@ func (k Keeper) OnTimeoutPacket( return k.handleIBCBasicContractResponse(ctx, contractAddr, contractInfo.IBCPortID, res.Ok) } +// IBCSourceCallback calls the contract to let it know the packet triggered by its +// IBC-callbacks-enabled message either timed out or was acknowledged. +func (k Keeper) IBCSourceCallback( + ctx sdk.Context, + contractAddr sdk.AccAddress, + msg wasmvmtypes.IBCSourceCallbackMsg, +) error { + defer telemetry.MeasureSince(time.Now(), "wasm", "contract", "ibc-source-chain-callback") + + contractInfo, codeInfo, prefixStore, err := k.contractInstance(ctx, contractAddr) + if err != nil { + return err + } + + env := types.NewEnv(ctx, contractAddr) + querier := k.newQueryHandler(ctx, contractAddr) + + gasLeft := k.runtimeGasForContract(ctx) + res, gasUsed, execErr := k.wasmVM.IBCSourceCallback(codeInfo.CodeHash, env, msg, prefixStore, cosmwasmAPI, querier, ctx.GasMeter(), gasLeft, costJSONDeserialization) + k.consumeRuntimeGas(ctx, gasUsed) + if execErr != nil { + return errorsmod.Wrap(types.ErrExecuteFailed, execErr.Error()) + } + if res == nil { + // If this gets executed, that's a bug in wasmvm + return errorsmod.Wrap(types.ErrVMError, "internal wasmvm error") + } + if res.Err != "" { + return types.MarkErrorDeterministic(errorsmod.Wrap(types.ErrExecuteFailed, res.Err)) + } + + return k.handleIBCBasicContractResponse(ctx, contractAddr, contractInfo.IBCPortID, res.Ok) +} + +// IBCDestinationCallback calls the contract to let it know that it received a packet of an +// IBC-callbacks-enabled message that was acknowledged. +func (k Keeper) IBCDestinationCallback( + ctx sdk.Context, + contractAddr sdk.AccAddress, + msg wasmvmtypes.IBCDestinationCallbackMsg, +) error { + defer telemetry.MeasureSince(time.Now(), "wasm", "contract", "ibc-destination-chain-callback") + + contractInfo, codeInfo, prefixStore, err := k.contractInstance(ctx, contractAddr) + if err != nil { + return err + } + + env := types.NewEnv(ctx, contractAddr) + querier := k.newQueryHandler(ctx, contractAddr) + + gasLeft := k.runtimeGasForContract(ctx) + res, gasUsed, execErr := k.wasmVM.IBCDestinationCallback(codeInfo.CodeHash, env, msg, prefixStore, cosmwasmAPI, querier, ctx.GasMeter(), gasLeft, costJSONDeserialization) + k.consumeRuntimeGas(ctx, gasUsed) + if execErr != nil { + return errorsmod.Wrap(types.ErrExecuteFailed, execErr.Error()) + } + if res == nil { + // If this gets executed, that's a bug in wasmvm + return errorsmod.Wrap(types.ErrVMError, "internal wasmvm error") + } + if res.Err != "" { + return types.MarkErrorDeterministic(errorsmod.Wrap(types.ErrExecuteFailed, res.Err)) + } + + return k.handleIBCBasicContractResponse(ctx, contractAddr, contractInfo.IBCPortID, res.Ok) +} + func (k Keeper) handleIBCBasicContractResponse(ctx sdk.Context, addr sdk.AccAddress, id string, res *wasmvmtypes.IBCBasicResponse) error { _, err := k.handleContractResponse(ctx, addr, id, res.Messages, res.Attributes, nil, res.Events) return err diff --git a/x/wasm/keeper/wasmtesting/mock_engine.go b/x/wasm/keeper/wasmtesting/mock_engine.go index f35386d7fd..14a3976d1a 100644 --- a/x/wasm/keeper/wasmtesting/mock_engine.go +++ b/x/wasm/keeper/wasmtesting/mock_engine.go @@ -22,27 +22,29 @@ var _ types.WasmEngine = &MockWasmEngine{} // MockWasmEngine implements types.WasmEngine for testing purpose. One or multiple messages can be stubbed. // Without a stub function a panic is thrown. type MockWasmEngine struct { - StoreCodeFn func(codeID wasmvm.WasmCode, gasLimit uint64) (wasmvm.Checksum, uint64, error) - StoreCodeUncheckedFn func(codeID wasmvm.WasmCode) (wasmvm.Checksum, error) - AnalyzeCodeFn func(codeID wasmvm.Checksum) (*wasmvmtypes.AnalysisReport, error) - InstantiateFn func(codeID wasmvm.Checksum, env wasmvmtypes.Env, info wasmvmtypes.MessageInfo, initMsg []byte, store wasmvm.KVStore, goapi wasmvm.GoAPI, querier wasmvm.Querier, gasMeter wasmvm.GasMeter, gasLimit uint64, deserCost wasmvmtypes.UFraction) (*wasmvmtypes.ContractResult, uint64, error) - ExecuteFn func(codeID wasmvm.Checksum, env wasmvmtypes.Env, info wasmvmtypes.MessageInfo, executeMsg []byte, store wasmvm.KVStore, goapi wasmvm.GoAPI, querier wasmvm.Querier, gasMeter wasmvm.GasMeter, gasLimit uint64, deserCost wasmvmtypes.UFraction) (*wasmvmtypes.ContractResult, uint64, error) - QueryFn func(codeID wasmvm.Checksum, env wasmvmtypes.Env, queryMsg []byte, store wasmvm.KVStore, goapi wasmvm.GoAPI, querier wasmvm.Querier, gasMeter wasmvm.GasMeter, gasLimit uint64, deserCost wasmvmtypes.UFraction) (*wasmvmtypes.QueryResult, uint64, error) - MigrateFn func(codeID wasmvm.Checksum, env wasmvmtypes.Env, migrateMsg []byte, store wasmvm.KVStore, goapi wasmvm.GoAPI, querier wasmvm.Querier, gasMeter wasmvm.GasMeter, gasLimit uint64, deserCost wasmvmtypes.UFraction) (*wasmvmtypes.ContractResult, uint64, error) - SudoFn func(codeID wasmvm.Checksum, env wasmvmtypes.Env, sudoMsg []byte, store wasmvm.KVStore, goapi wasmvm.GoAPI, querier wasmvm.Querier, gasMeter wasmvm.GasMeter, gasLimit uint64, deserCost wasmvmtypes.UFraction) (*wasmvmtypes.ContractResult, uint64, error) - ReplyFn func(codeID wasmvm.Checksum, env wasmvmtypes.Env, reply wasmvmtypes.Reply, store wasmvm.KVStore, goapi wasmvm.GoAPI, querier wasmvm.Querier, gasMeter wasmvm.GasMeter, gasLimit uint64, deserCost wasmvmtypes.UFraction) (*wasmvmtypes.ContractResult, uint64, error) - GetCodeFn func(codeID wasmvm.Checksum) (wasmvm.WasmCode, error) - CleanupFn func() - IBCChannelOpenFn func(codeID wasmvm.Checksum, env wasmvmtypes.Env, msg wasmvmtypes.IBCChannelOpenMsg, store wasmvm.KVStore, goapi wasmvm.GoAPI, querier wasmvm.Querier, gasMeter wasmvm.GasMeter, gasLimit uint64, deserCost wasmvmtypes.UFraction) (*wasmvmtypes.IBCChannelOpenResult, uint64, error) - IBCChannelConnectFn func(codeID wasmvm.Checksum, env wasmvmtypes.Env, msg wasmvmtypes.IBCChannelConnectMsg, store wasmvm.KVStore, goapi wasmvm.GoAPI, querier wasmvm.Querier, gasMeter wasmvm.GasMeter, gasLimit uint64, deserCost wasmvmtypes.UFraction) (*wasmvmtypes.IBCBasicResult, uint64, error) - IBCChannelCloseFn func(codeID wasmvm.Checksum, env wasmvmtypes.Env, msg wasmvmtypes.IBCChannelCloseMsg, store wasmvm.KVStore, goapi wasmvm.GoAPI, querier wasmvm.Querier, gasMeter wasmvm.GasMeter, gasLimit uint64, deserCost wasmvmtypes.UFraction) (*wasmvmtypes.IBCBasicResult, uint64, error) - IBCPacketReceiveFn func(codeID wasmvm.Checksum, env wasmvmtypes.Env, msg wasmvmtypes.IBCPacketReceiveMsg, store wasmvm.KVStore, goapi wasmvm.GoAPI, querier wasmvm.Querier, gasMeter wasmvm.GasMeter, gasLimit uint64, deserCost wasmvmtypes.UFraction) (*wasmvmtypes.IBCReceiveResult, uint64, error) - IBCPacketAckFn func(codeID wasmvm.Checksum, env wasmvmtypes.Env, msg wasmvmtypes.IBCPacketAckMsg, store wasmvm.KVStore, goapi wasmvm.GoAPI, querier wasmvm.Querier, gasMeter wasmvm.GasMeter, gasLimit uint64, deserCost wasmvmtypes.UFraction) (*wasmvmtypes.IBCBasicResult, uint64, error) - IBCPacketTimeoutFn func(codeID wasmvm.Checksum, env wasmvmtypes.Env, msg wasmvmtypes.IBCPacketTimeoutMsg, store wasmvm.KVStore, goapi wasmvm.GoAPI, querier wasmvm.Querier, gasMeter wasmvm.GasMeter, gasLimit uint64, deserCost wasmvmtypes.UFraction) (*wasmvmtypes.IBCBasicResult, uint64, error) - PinFn func(checksum wasmvm.Checksum) error - UnpinFn func(checksum wasmvm.Checksum) error - GetMetricsFn func() (*wasmvmtypes.Metrics, error) - GetPinMetricsFn func() (*wasmvmtypes.PinnedMetrics, error) + StoreCodeFn func(codeID wasmvm.WasmCode, gasLimit uint64) (wasmvm.Checksum, uint64, error) + StoreCodeUncheckedFn func(codeID wasmvm.WasmCode) (wasmvm.Checksum, error) + AnalyzeCodeFn func(codeID wasmvm.Checksum) (*wasmvmtypes.AnalysisReport, error) + InstantiateFn func(codeID wasmvm.Checksum, env wasmvmtypes.Env, info wasmvmtypes.MessageInfo, initMsg []byte, store wasmvm.KVStore, goapi wasmvm.GoAPI, querier wasmvm.Querier, gasMeter wasmvm.GasMeter, gasLimit uint64, deserCost wasmvmtypes.UFraction) (*wasmvmtypes.ContractResult, uint64, error) + ExecuteFn func(codeID wasmvm.Checksum, env wasmvmtypes.Env, info wasmvmtypes.MessageInfo, executeMsg []byte, store wasmvm.KVStore, goapi wasmvm.GoAPI, querier wasmvm.Querier, gasMeter wasmvm.GasMeter, gasLimit uint64, deserCost wasmvmtypes.UFraction) (*wasmvmtypes.ContractResult, uint64, error) + QueryFn func(codeID wasmvm.Checksum, env wasmvmtypes.Env, queryMsg []byte, store wasmvm.KVStore, goapi wasmvm.GoAPI, querier wasmvm.Querier, gasMeter wasmvm.GasMeter, gasLimit uint64, deserCost wasmvmtypes.UFraction) (*wasmvmtypes.QueryResult, uint64, error) + MigrateFn func(codeID wasmvm.Checksum, env wasmvmtypes.Env, migrateMsg []byte, store wasmvm.KVStore, goapi wasmvm.GoAPI, querier wasmvm.Querier, gasMeter wasmvm.GasMeter, gasLimit uint64, deserCost wasmvmtypes.UFraction) (*wasmvmtypes.ContractResult, uint64, error) + SudoFn func(codeID wasmvm.Checksum, env wasmvmtypes.Env, sudoMsg []byte, store wasmvm.KVStore, goapi wasmvm.GoAPI, querier wasmvm.Querier, gasMeter wasmvm.GasMeter, gasLimit uint64, deserCost wasmvmtypes.UFraction) (*wasmvmtypes.ContractResult, uint64, error) + ReplyFn func(codeID wasmvm.Checksum, env wasmvmtypes.Env, reply wasmvmtypes.Reply, store wasmvm.KVStore, goapi wasmvm.GoAPI, querier wasmvm.Querier, gasMeter wasmvm.GasMeter, gasLimit uint64, deserCost wasmvmtypes.UFraction) (*wasmvmtypes.ContractResult, uint64, error) + GetCodeFn func(codeID wasmvm.Checksum) (wasmvm.WasmCode, error) + CleanupFn func() + IBCChannelOpenFn func(codeID wasmvm.Checksum, env wasmvmtypes.Env, msg wasmvmtypes.IBCChannelOpenMsg, store wasmvm.KVStore, goapi wasmvm.GoAPI, querier wasmvm.Querier, gasMeter wasmvm.GasMeter, gasLimit uint64, deserCost wasmvmtypes.UFraction) (*wasmvmtypes.IBCChannelOpenResult, uint64, error) + IBCChannelConnectFn func(codeID wasmvm.Checksum, env wasmvmtypes.Env, msg wasmvmtypes.IBCChannelConnectMsg, store wasmvm.KVStore, goapi wasmvm.GoAPI, querier wasmvm.Querier, gasMeter wasmvm.GasMeter, gasLimit uint64, deserCost wasmvmtypes.UFraction) (*wasmvmtypes.IBCBasicResult, uint64, error) + IBCChannelCloseFn func(codeID wasmvm.Checksum, env wasmvmtypes.Env, msg wasmvmtypes.IBCChannelCloseMsg, store wasmvm.KVStore, goapi wasmvm.GoAPI, querier wasmvm.Querier, gasMeter wasmvm.GasMeter, gasLimit uint64, deserCost wasmvmtypes.UFraction) (*wasmvmtypes.IBCBasicResult, uint64, error) + IBCPacketReceiveFn func(codeID wasmvm.Checksum, env wasmvmtypes.Env, msg wasmvmtypes.IBCPacketReceiveMsg, store wasmvm.KVStore, goapi wasmvm.GoAPI, querier wasmvm.Querier, gasMeter wasmvm.GasMeter, gasLimit uint64, deserCost wasmvmtypes.UFraction) (*wasmvmtypes.IBCReceiveResult, uint64, error) + IBCPacketAckFn func(codeID wasmvm.Checksum, env wasmvmtypes.Env, msg wasmvmtypes.IBCPacketAckMsg, store wasmvm.KVStore, goapi wasmvm.GoAPI, querier wasmvm.Querier, gasMeter wasmvm.GasMeter, gasLimit uint64, deserCost wasmvmtypes.UFraction) (*wasmvmtypes.IBCBasicResult, uint64, error) + IBCPacketTimeoutFn func(codeID wasmvm.Checksum, env wasmvmtypes.Env, msg wasmvmtypes.IBCPacketTimeoutMsg, store wasmvm.KVStore, goapi wasmvm.GoAPI, querier wasmvm.Querier, gasMeter wasmvm.GasMeter, gasLimit uint64, deserCost wasmvmtypes.UFraction) (*wasmvmtypes.IBCBasicResult, uint64, error) + IBCSourceCallbackFn func(codeID wasmvm.Checksum, env wasmvmtypes.Env, msg wasmvmtypes.IBCSourceCallbackMsg, store wasmvm.KVStore, goapi wasmvm.GoAPI, querier wasmvm.Querier, gasMeter wasmvm.GasMeter, gasLimit uint64, deserCost wasmvmtypes.UFraction) (*wasmvmtypes.IBCBasicResult, uint64, error) + IBCDestinationCallbackFn func(codeID wasmvm.Checksum, env wasmvmtypes.Env, msg wasmvmtypes.IBCDestinationCallbackMsg, store wasmvm.KVStore, goapi wasmvm.GoAPI, querier wasmvm.Querier, gasMeter wasmvm.GasMeter, gasLimit uint64, deserCost wasmvmtypes.UFraction) (*wasmvmtypes.IBCBasicResult, uint64, error) + PinFn func(checksum wasmvm.Checksum) error + UnpinFn func(checksum wasmvm.Checksum) error + GetMetricsFn func() (*wasmvmtypes.Metrics, error) + GetPinMetricsFn func() (*wasmvmtypes.PinnedMetrics, error) } func (m *MockWasmEngine) IBCChannelOpen(codeID wasmvm.Checksum, env wasmvmtypes.Env, msg wasmvmtypes.IBCChannelOpenMsg, store wasmvm.KVStore, goapi wasmvm.GoAPI, querier wasmvm.Querier, gasMeter wasmvm.GasMeter, gasLimit uint64, deserCost wasmvmtypes.UFraction) (*wasmvmtypes.IBCChannelOpenResult, uint64, error) { @@ -87,6 +89,20 @@ func (m *MockWasmEngine) IBCPacketTimeout(codeID wasmvm.Checksum, env wasmvmtype return m.IBCPacketTimeoutFn(codeID, env, msg, store, goapi, querier, gasMeter, gasLimit, deserCost) } +func (m MockWasmEngine) IBCSourceCallback(codeID wasmvm.Checksum, env wasmvmtypes.Env, msg wasmvmtypes.IBCSourceCallbackMsg, store wasmvm.KVStore, goapi wasmvm.GoAPI, querier wasmvm.Querier, gasMeter wasmvm.GasMeter, gasLimit uint64, deserCost wasmvmtypes.UFraction) (*wasmvmtypes.IBCBasicResult, uint64, error) { + if m.IBCSourceCallbackFn == nil { + panic("not expected to be called") + } + return m.IBCSourceCallbackFn(codeID, env, msg, store, goapi, querier, gasMeter, gasLimit, deserCost) +} + +func (m MockWasmEngine) IBCDestinationCallback(codeID wasmvm.Checksum, env wasmvmtypes.Env, msg wasmvmtypes.IBCDestinationCallbackMsg, store wasmvm.KVStore, goapi wasmvm.GoAPI, querier wasmvm.Querier, gasMeter wasmvm.GasMeter, gasLimit uint64, deserCost wasmvmtypes.UFraction) (*wasmvmtypes.IBCBasicResult, uint64, error) { + if m.IBCDestinationCallbackFn == nil { + panic("not expected to be called") + } + return m.IBCDestinationCallbackFn(codeID, env, msg, store, goapi, querier, gasMeter, gasLimit, deserCost) +} + func (m *MockWasmEngine) StoreCode(codeID wasmvm.WasmCode, gasLimit uint64) (wasmvm.Checksum, uint64, error) { if m.StoreCodeFn == nil { panic("not supposed to be called!") diff --git a/x/wasm/types/exported_keepers.go b/x/wasm/types/exported_keepers.go index cbe0d734f1..0ece2c0207 100644 --- a/x/wasm/types/exported_keepers.go +++ b/x/wasm/types/exported_keepers.go @@ -116,6 +116,16 @@ type IBCContractKeeper interface { contractAddr sdk.AccAddress, msg wasmvmtypes.IBCPacketTimeoutMsg, ) error + IBCSourceCallback( + ctx sdk.Context, + contractAddr sdk.AccAddress, + msg wasmvmtypes.IBCSourceCallbackMsg, + ) error + IBCDestinationCallback( + ctx sdk.Context, + contractAddr sdk.AccAddress, + msg wasmvmtypes.IBCDestinationCallbackMsg, + ) error // ClaimCapability allows the transfer module to claim a capability // that IBC module passes to it ClaimCapability(ctx sdk.Context, cap *capabilitytypes.Capability, name string) error diff --git a/x/wasm/types/wasmer_engine.go b/x/wasm/types/wasmer_engine.go index fb53494353..31114e67cb 100644 --- a/x/wasm/types/wasmer_engine.go +++ b/x/wasm/types/wasmer_engine.go @@ -233,6 +233,36 @@ type WasmEngine interface { deserCost wasmvmtypes.UFraction, ) (*wasmvmtypes.IBCBasicResult, uint64, error) + // IBCSourceCallback is available on IBC-callbacks-enabled contracts and is called when an + // IBC-callbacks-enabled IBC message previously sent by this contract is either acknowledged or + // times out. + IBCSourceCallback( + checksum wasmvm.Checksum, + env wasmvmtypes.Env, + msg wasmvmtypes.IBCSourceCallbackMsg, + store wasmvm.KVStore, + goapi wasmvm.GoAPI, + querier wasmvm.Querier, + gasMeter wasmvm.GasMeter, + gasLimit uint64, + deserCost wasmvmtypes.UFraction, + ) (*wasmvmtypes.IBCBasicResult, uint64, error) + + // IBCSourceCallback is available on IBC-callbacks-enabled contracts and is called when an + // IBC-callbacks-enabled IBC message previously sent by this contract is either acknowledged or + // times out. + IBCDestinationCallback( + checksum wasmvm.Checksum, + env wasmvmtypes.Env, + msg wasmvmtypes.IBCDestinationCallbackMsg, + store wasmvm.KVStore, + goapi wasmvm.GoAPI, + querier wasmvm.Querier, + gasMeter wasmvm.GasMeter, + gasLimit uint64, + deserCost wasmvmtypes.UFraction, + ) (*wasmvmtypes.IBCBasicResult, uint64, error) + // Pin pins a code to an in-memory cache, such that is // always loaded quickly when executed. // Pin is idempotent.