diff --git a/x/wasm/keeper/contract.go b/x/wasm/keeper/contract.go index f2c29b61e..c5f9b5a2a 100644 --- a/x/wasm/keeper/contract.go +++ b/x/wasm/keeper/contract.go @@ -437,6 +437,12 @@ func (k Keeper) queryToContract(ctx sdk.Context, contractAddress sdk.AccAddress, wasmVM = wasmVMs[0] } + // assert max depth to prevent stack overflow + if err := wasmVM.IncreaseQueryDepth(); err != nil { + return nil, err + } + defer wasmVM.DecreaseQueryDepth() + queryResult, gasUsed, err := wasmVM.Query( codeInfo.CodeHash, env, diff --git a/x/wasm/keeper/keeper.go b/x/wasm/keeper/keeper.go index 35d47e092..f71439dfa 100644 --- a/x/wasm/keeper/keeper.go +++ b/x/wasm/keeper/keeper.go @@ -70,16 +70,17 @@ func NewKeeper( wasmConfig.WriteVMMemoryCacheSize = config.DefaultWriteVMMemoryCacheSize } - writeWasmVM, err := wasmvm.NewVM( + var writeWasmVM types.WasmerEngine + if vm, err := wasmvm.NewVM( filepath.Join(homePath, config.DBDir), supportedFeatures, types.ContractMemoryLimit, wasmConfig.ContractDebugMode, wasmConfig.WriteVMMemoryCacheSize, - ) - - if err != nil { + ); err != nil { panic(err) + } else { + writeWasmVM = types.NewWasmerEngineWithQueryDepth(vm) } // prevent zero read vm @@ -95,16 +96,16 @@ func NewKeeper( numReadVms := wasmConfig.NumReadVMs wasmReadVMPool := make([]types.WasmerEngine, numReadVms) for i := uint32(0); i < numReadVms; i++ { - wasmReadVMPool[i], err = wasmvm.NewVM( + if vm, err := wasmvm.NewVM( filepath.Join(homePath, config.DBDir), supportedFeatures, types.ContractMemoryLimit, wasmConfig.ContractDebugMode, wasmConfig.ReadVMMemoryCacheSize, - ) - - if err != nil { + ); err != nil { panic(err) + } else { + wasmReadVMPool[i] = types.NewWasmerEngineWithQueryDepth(vm) } } diff --git a/x/wasm/keeper/recursive_test.go b/x/wasm/keeper/recursive_test.go index bba646065..b6989300e 100644 --- a/x/wasm/keeper/recursive_test.go +++ b/x/wasm/keeper/recursive_test.go @@ -342,3 +342,20 @@ func TestLimitRecursiveQueryGas(t *testing.T) { }) } } + +func TestLimitRecursiveQueryDepth(t *testing.T) { + + contractAddr, _, ctx, keeper, _ := initRecurseContract(t) + msg := buildQuery(t, Recurse{ + Depth: types.ContractMaxQueryDepth - 1, // need to include first query + }) + + _, err := keeper.queryToContract(ctx, contractAddr, msg) + require.NoError(t, err) + + msg = buildQuery(t, Recurse{ + Depth: types.ContractMaxQueryDepth, + }) + _, err = keeper.queryToContract(ctx, contractAddr, msg) + require.Error(t, err) +} diff --git a/x/wasm/types/errors.go b/x/wasm/types/errors.go index f518d9124..d2083b75a 100644 --- a/x/wasm/types/errors.go +++ b/x/wasm/types/errors.go @@ -23,4 +23,5 @@ var ( ErrExceedMaxContractMsgSize = sdkerrors.Register(ModuleName, 16, "exceeds max contract msg size limit") ErrExceedMaxContractDataSize = sdkerrors.Register(ModuleName, 17, "exceeds max contract data size limit") ErrReplyFailed = sdkerrors.Register(ModuleName, 18, "reply wasm contract failed") + ErrExceedMaxQueryDepth = sdkerrors.Register(ModuleName, 19, "exceed max query depth") ) diff --git a/x/wasm/types/wasmer_engine.go b/x/wasm/types/wasmer_engine.go index 0309602c3..393462bb8 100644 --- a/x/wasm/types/wasmer_engine.go +++ b/x/wasm/types/wasmer_engine.go @@ -5,6 +5,40 @@ import ( wasmvmtypes "github.com/CosmWasm/wasmvm/types" ) +// ContractMaxQueryDepth maximum recursive query depth allowed +const ContractMaxQueryDepth = 20 + +var _ WasmerEngine = &WasmerEngineWithQueryDepth{} + +// WasmerEngineWithQueryDepth VM wrapper with depth counter to prevent +// stack overflow +type WasmerEngineWithQueryDepth struct { + *wasmvm.VM + QueryDepth uint8 +} + +// NewWasmerEngineWithQueryDepth wrap wasmer engine with query depth checker +func NewWasmerEngineWithQueryDepth(wasmVM *wasmvm.VM) *WasmerEngineWithQueryDepth { + wasmerEngine := WasmerEngineWithQueryDepth{VM: wasmVM} + return &wasmerEngine +} + +// IncreaseQueryDepth increase execution depth by 1 and check whether +// the depth exceeds max one or not +func (wasmer *WasmerEngineWithQueryDepth) IncreaseQueryDepth() error { + if wasmer.QueryDepth >= ContractMaxQueryDepth { + return ErrExceedMaxQueryDepth + } + + wasmer.QueryDepth++ + return nil +} + +// DecreaseQueryDepth decrease execution depth by 1 +func (wasmer *WasmerEngineWithQueryDepth) DecreaseQueryDepth() { + wasmer.QueryDepth-- +} + // WasmerEngine defines the WASM contract runtime engine. type WasmerEngine interface { @@ -119,4 +153,11 @@ type WasmerEngine interface { // Cleanup should be called when no longer using this to free resources on the rust-side Cleanup() + + // IncreaseQueryDepth will increase query depth by 1 and assert the current depth + // reached out the maximum + IncreaseQueryDepth() error + + // DecreaseQueryDepth will decrease query depth by 1 + DecreaseQueryDepth() }