diff --git a/src/runtime/relax_vm/vm.cc b/src/runtime/relax_vm/vm.cc index 9c25ca5307e0..ff9e0cec5fc5 100644 --- a/src/runtime/relax_vm/vm.cc +++ b/src/runtime/relax_vm/vm.cc @@ -282,6 +282,22 @@ class VirtualMachineImpl : public VirtualMachine { * \brief Initialize function pool. */ void InitFuncPool(); + + /*! + * \brief A RAII wrapper that pushes and pops VM frames. + */ + class FrameGuard { + public: + VirtualMachineImpl* vm; + explicit FrameGuard(VirtualMachineImpl* vm, std::unique_ptr frame) : vm(vm) { + vm->frames_.emplace_back(std::move(frame)); + } + ~FrameGuard() { + ICHECK_GT(vm->frames_.size(), 0); + vm->pc_ = vm->frames_.back()->return_pc; + vm->frames_.pop_back(); + } + }; //------------------------------------------------- // Instruction interpretations. //------------------------------------------------- @@ -289,17 +305,10 @@ class VirtualMachineImpl : public VirtualMachine { * \brief Push a call frame onto the call stack. * \param ret_pc The program counter to return to. * \param vm_func The function to be pushed to the call stack. + * \return A RAII wrapper that pops the frame when going out of scope. */ - void PushFrame(Index ret_pc, const VMFuncInfo& vm_func) { - frames_.emplace_back(std::make_unique(ret_pc, vm_func.register_file_size)); - } - /*! - * \brief Pop a frame off the call stack. - */ - void PopFrame() { - ICHECK_GT(frames_.size(), 0); - pc_ = frames_.back()->return_pc; - frames_.pop_back(); + FrameGuard PushFrame(Index ret_pc, const VMFuncInfo& vm_func) { + return FrameGuard(this, std::make_unique(ret_pc, vm_func.register_file_size)); } /*! * \brief Write to a VM register. @@ -733,7 +742,7 @@ RegType VirtualMachineImpl::InvokeBytecode(Index gf_idx, const std::vectorGetInstruction(pc_); - PushFrame(this->pc_, gfunc); + auto guard = PushFrame(this->pc_, gfunc); // Get new frame and set the caller info. VMFrame* curr_frame = frames_.back().get(); if (curr_instr.op == Opcode::Call) { @@ -883,14 +892,13 @@ void VirtualMachineImpl::RunLoop() { // the dispatch loop. return_value_ = ReadRegister(curr_frame, instr.result); RegName caller_return_register = curr_frame->caller_return_register; - PopFrame(); - if (frames_.size() == 0) { - // directly return if no frame in the call stack. + if (frames_.size() <= 1) { + // directly return if no other frame in the call stack. } else { // return from a local call. // Update the current frame to be the parent frame. - curr_frame = frames_.back().get(); - WriteRegister(curr_frame, caller_return_register, return_value_); + VMFrame* parent_frame = frames_.end()[-2].get(); + WriteRegister(parent_frame, caller_return_register, return_value_); } return; } diff --git a/tests/python/relax/test_vm_execbuilder.py b/tests/python/relax/test_vm_execbuilder.py index 5d9491dad7fc..4c15d8013bf3 100644 --- a/tests/python/relax/test_vm_execbuilder.py +++ b/tests/python/relax/test_vm_execbuilder.py @@ -15,11 +15,13 @@ # specific language governing permissions and limitations # under the License. """Lowest level testing VM. Test execbuilder and execution.""" -import tvm -import pytest import numpy as np -from tvm import relax, TVMError +import pytest + +import tvm +from tvm import TVMError, relax from tvm.relax.testing.vm import check_saved_func +from tvm.script import relax as R def test_vm_execute(): @@ -264,5 +266,32 @@ def test_vm_invoke_closure(): ) +def test_vm_stack_restore_after_failure(): + @tvm.script.ir_module + class Module: + @R.function + def main(inp: R.Tensor((10, 10), dtype="float32")) -> R.Tensor((10, 10), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((10, 10), dtype="float32") = R.multiply(inp, R.const(2, "float32")) + gv: R.Tensor((10, 10), dtype="float32") = lv + R.output(gv) + return gv + + mod = relax.transform.LegalizeOps()(Module) + ex = relax.build(mod, "llvm") + vm = relax.VirtualMachine(ex, tvm.cpu()) + + correct_input = tvm.nd.array(np.random.normal(size=(10, 10)).astype("float32")) + incorrect_input = tvm.nd.array(np.random.normal(size=(12, 10)).astype("float32")) + + try: + vm["main"](incorrect_input) + except RuntimeError: + pass + + # VM should executes correctly after encountered incorrect shape in previous invocation + vm["main"](correct_input) + + if __name__ == "__main__": tvm.testing.main()