From 3c5ee30630760b3e42202e622517a92c18f11889 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 13 Feb 2024 12:07:09 -0600 Subject: [PATCH] [Relax] Support callback as argument (#16542) Prior to this commit, calls from Relax to external PackedFuncs could only be done through the TVM global registry. While Relax functions accepting a callback could be written as `callback_arg: R.Callable(arg_struct_info, ret_struct_info)`, attempting to compile these functions would raise an error during the `CodeGenVM` step of `relax.build`. In addition, the global registry is only queried when initializing the `relax.VirtualMachine`, and so later changes require restarting the VM. This commit updates both the `CodeGenVM` lowering pass and the relax VM to support callbacks. This is primarily intended for use with the `LazyTransformParams` pass, to improve flexibility by avoiding use of the global registry. --- include/tvm/runtime/relax_vm/bytecode.h | 28 +++- src/relax/backend/vm/exec_builder.cc | 20 ++- src/runtime/library_module.cc | 12 +- src/runtime/relax_vm/bytecode.cc | 11 ++ src/runtime/relax_vm/executable.cc | 8 ++ src/runtime/relax_vm/vm.cc | 39 ++++-- .../python/relax/test_vm_callback_function.py | 124 ++++++++++++++++++ 7 files changed, 219 insertions(+), 23 deletions(-) create mode 100644 tests/python/relax/test_vm_callback_function.py diff --git a/include/tvm/runtime/relax_vm/bytecode.h b/include/tvm/runtime/relax_vm/bytecode.h index 4526c6fffa1d..0db610ff42b2 100644 --- a/include/tvm/runtime/relax_vm/bytecode.h +++ b/include/tvm/runtime/relax_vm/bytecode.h @@ -58,6 +58,7 @@ enum class Opcode { Ret = 2U, Goto = 3U, If = 4U, + CallFromRegister = 5U, }; /*! \brief A single virtual machine instruction. @@ -183,10 +184,15 @@ struct Instruction { /*! \brief The instruction opcode. */ Opcode op; union { - struct /* Call */ { + struct /* Call, CallFromRegister */ { /*! \brief The destination register. */ RegName dst; - /*! 
\brief The index into the packed function table. */ + /*! \brief The index of the function. + * + * For `OpCode::Call`, this is an index into the table of static + * functions. For `OpCode::CallFromRegister`, this is an index + * of a register. + */ Index func_idx; /*! \brief The number of arguments to the packed function. */ Index num_args; @@ -208,27 +214,43 @@ struct Instruction { Index false_offset; }; }; + /*! * \brief Construct a Call instruction. - * \param func_idx The index of the function to call. + * \param func_idx The index of the function to call within the + * static function table * \param num_args The number of arguments. * \param args The input arguments. * \param dst The destination register. * \return The call instruction. */ static Instruction Call(Index func_idx, Index num_args, Arg* args, RegName dst); + + /*! + * \brief Construct a Call instruction. + * \param func_idx The index of the function to call within the + * current stack frame's registers. + * \param num_args The number of arguments. + * \param args The input arguments. + * \param dst The destination register. + * \return The call instruction. + */ + static Instruction CallFromRegister(Index func_idx, Index num_args, Arg* args, RegName dst); + /*! * \brief Construct a return instruction. * \param result The register containing the return value. * \return The return instruction. */ static Instruction Ret(RegName result); + /*! * \brief Construct a goto instruction. * \param pc_offset The register containing the jump offset. * \return The goto instruction. */ static Instruction Goto(RegName pc_offset); + /*! * \brief Construct an If instruction. * \param cond The register containing the cond value. 
diff --git a/src/relax/backend/vm/exec_builder.cc b/src/relax/backend/vm/exec_builder.cc index b5d932137be0..aa478122353d 100644 --- a/src/relax/backend/vm/exec_builder.cc +++ b/src/relax/backend/vm/exec_builder.cc @@ -138,10 +138,20 @@ void ExecBuilderNode::EndFunction(const std::string& func_name) { void ExecBuilderNode::EmitCall(vm::Instruction::Arg func, std::vector args, vm::RegName dst) { - ICHECK(func.kind() == vm::Instruction::ArgKind::kFuncIdx); + Opcode op_code; + if (func.kind() == vm::Instruction::ArgKind::kFuncIdx) { + op_code = Opcode::Call; + } else if (func.kind() == vm::Instruction::ArgKind::kRegister) { + op_code = Opcode::CallFromRegister; + } else { + LOG(FATAL) << "VM instruction for a function must be either " + << "kFuncIdx (static function ) " + << "or kRegister (function passed as parameter), " + << "but instead found " << func.kind(); + } // store instruction exec_->instr_offset.push_back(exec_->instr_data.size()); - exec_->instr_data.push_back(static_cast(Opcode::Call)); + exec_->instr_data.push_back(static_cast(op_code)); exec_->instr_data.push_back(dst); exec_->instr_data.push_back(func.value()); exec_->instr_data.push_back(args.size()); @@ -228,7 +238,8 @@ void ExecBuilderNode::CheckExecutable() { for (size_t idx = start_instr; idx < end_instr; ++idx) { Instruction instr = exec_->GetInstruction(idx); switch (instr.op) { - case Opcode::Call: { + case Opcode::Call: + case Opcode::CallFromRegister: { check_func_defined(Instruction::Arg::FuncIdx(instr.func_idx)); for (int i = 0; i < instr.num_args; ++i) { check_reg_defined(instr.args[i]); @@ -280,7 +291,8 @@ void ExecBuilderNode::Formalize() { for (size_t idx = start_instr; idx < end_instr; ++idx) { Instruction instr = this->exec_->GetInstruction(idx); switch (instr.op) { - case Opcode::Call: { + case Opcode::Call: + case Opcode::CallFromRegister: { // rewrite args for (int i = 0; i < instr.num_args; ++i) { if (instr.args[i].kind() == Instruction::ArgKind::kRegister && diff --git 
a/src/runtime/library_module.cc b/src/runtime/library_module.cc index bb5733ce013f..7b39bcd8da02 100644 --- a/src/runtime/library_module.cc +++ b/src/runtime/library_module.cc @@ -71,11 +71,15 @@ PackedFunc WrapPackedFunc(TVMBackendPackedCFunc faddr, const ObjectPtr& return PackedFunc([faddr, sptr_to_self](TVMArgs args, TVMRetValue* rv) { TVMValue ret_value; int ret_type_code = kTVMNullptr; - int ret = (*faddr)(const_cast(args.values), const_cast(args.type_codes), - args.num_args, &ret_value, &ret_type_code, nullptr); - // NOTE: important to keep the original error message. + auto arg_values = const_cast(args.values); + auto arg_type_codes = const_cast(args.type_codes); + int ret = + (*faddr)(arg_values, arg_type_codes, args.num_args, &ret_value, &ret_type_code, nullptr); + // NOTE: It is important to keep the original error message. + // Using the `TVMThrowLastError()` function will also preserve the + // full stack trace for debugging in pdb. if (ret != 0) { - LOG(FATAL) << TVMGetLastError(); + TVMThrowLastError(); } if (ret_type_code != kTVMNullptr) { *rv = TVMRetValue::MoveFromCHost(ret_value, ret_type_code); diff --git a/src/runtime/relax_vm/bytecode.cc b/src/runtime/relax_vm/bytecode.cc index 9084207848b5..30d3bebd5f33 100644 --- a/src/runtime/relax_vm/bytecode.cc +++ b/src/runtime/relax_vm/bytecode.cc @@ -42,6 +42,17 @@ Instruction Instruction::Call(Index func_idx, Index num_args, Instruction::Arg* return instr; } +Instruction Instruction::CallFromRegister(Index func_idx, Index num_args, Instruction::Arg* args, + RegName dst) { + Instruction instr; + instr.op = Opcode::CallFromRegister; + instr.dst = dst; + instr.func_idx = func_idx; + instr.num_args = num_args; + instr.args = args; + return instr; +} + Instruction Instruction::Ret(RegName result) { Instruction instr; instr.op = Opcode::Ret; diff --git a/src/runtime/relax_vm/executable.cc b/src/runtime/relax_vm/executable.cc index f45786c3da32..9de708f49a9a 100644 --- a/src/runtime/relax_vm/executable.cc +++ 
b/src/runtime/relax_vm/executable.cc @@ -134,6 +134,14 @@ Instruction Executable::GetInstruction(Index i) const { ExecWord* args = const_cast(&instr_data[offset + 4]); return Instruction::Call(func_idx, num_args, reinterpret_cast(args), dst); } + case Opcode::CallFromRegister: { + RegName dst = instr_data[offset + 1]; + Index func_idx = instr_data[offset + 2]; + Index num_args = instr_data[offset + 3]; + ExecWord* args = const_cast(&instr_data[offset + 4]); + return Instruction::CallFromRegister(func_idx, num_args, + reinterpret_cast(args), dst); + } case Opcode::Ret: { RegName result = instr_data[offset + 1]; return Instruction::Ret(result); diff --git a/src/runtime/relax_vm/vm.cc b/src/runtime/relax_vm/vm.cc index d7f943d5f40f..14a42df5f1e4 100644 --- a/src/runtime/relax_vm/vm.cc +++ b/src/runtime/relax_vm/vm.cc @@ -372,9 +372,10 @@ class VirtualMachineImpl : public VirtualMachine { /*! * \brief Run call instruction. * \param curr_frame The current frame. + * \param callable The callable object, either PackedFunc or closure * \param inst The call instruction. */ - virtual void RunInstrCall(VMFrame* curr_frame, Instruction inst); + virtual void RunInstrCall(VMFrame* curr_frame, const ObjectRef& callable, Instruction inst); /*! \brief Run VM dispatch loop. */ void RunLoop(); @@ -506,6 +507,9 @@ void VirtualMachineImpl::SetInput(std::string func_name, bool with_param_module, //------------------------------------------ void VirtualMachineImpl::InvokeClosurePacked(const ObjectRef& closure_or_packedfunc, TVMArgs args, TVMRetValue* rv) { + ICHECK(closure_or_packedfunc.defined()) + << "InvokeClosurePacked requires the callable object to be defined"; + // run packed call if it is a packed func. if (auto* packed = closure_or_packedfunc.as()) { packed->CallPacked(args, rv); @@ -513,7 +517,8 @@ void VirtualMachineImpl::InvokeClosurePacked(const ObjectRef& closure_or_packedf } // run closure call. 
auto* clo = closure_or_packedfunc.as(); - ICHECK(clo != nullptr) << "Function expects a closure or PackedFunc "; + ICHECK(clo != nullptr) << "Function expects a closure or PackedFunc, " + << "but received " << closure_or_packedfunc->GetTypeKey(); std::vector values(args.size() + 1); std::vector tcodes(args.size() + 1); @@ -595,6 +600,8 @@ Optional VirtualMachineImpl::GetClosureInternal(const String& func_na auto impl = PackedFunc([gf_idx](TVMArgs args, TVMRetValue* rv) { // Per convention, ctx ptr is a VirtualMachine* VirtualMachine* ctx_ptr = static_cast(args[0].operator void*()); + ICHECK(ctx_ptr) << "Context pointer for relax VM closure should be a VirtualMachine*, " + << "but was NULL"; std::vector inputs(args.size() - 1); for (size_t i = 0; i < inputs.size(); ++i) { @@ -644,7 +651,7 @@ RegType VirtualMachineImpl::InvokeBytecode(Index gf_idx, const std::vectorpc_, gfunc); // Get new frame and set the caller info. VMFrame* curr_frame = frames_.back().get(); - if (curr_instr.op == Opcode::Call) { + if (curr_instr.op == Opcode::Call || curr_instr.op == Opcode::CallFromRegister) { curr_frame->caller_return_register = curr_instr.dst; } @@ -688,8 +695,12 @@ void VirtualMachineImpl::InitFuncPool() { } } -void VirtualMachineImpl::RunInstrCall(VMFrame* curr_frame, Instruction instr) { - DLOG(INFO) << "\n pc = " << pc_ << ", execute: " << GetFuncName(instr.func_idx); +void VirtualMachineImpl::RunInstrCall(VMFrame* curr_frame, const ObjectRef& callable, + Instruction instr) { + ICHECK(callable.defined()) << "RunInstrCall requires the callable object to be defined"; + auto func_name = instr.op == Opcode::Call ? GetFuncName(instr.func_idx) : ""; + + DLOG(INFO) << "\n pc = " << pc_ << ", execute: " << func_name; int args_begin_offset = instrument_ != nullptr ? 
4 : 0; // Use the call arg stack from the current frame to increase reuse // and avoid re-allocation @@ -735,11 +746,11 @@ void VirtualMachineImpl::RunInstrCall(VMFrame* curr_frame, Instruction instr) { ICHECK_LT(static_cast(instr.func_idx), this->func_pool_.size()); if (instrument_ == nullptr) { - this->InvokeClosurePacked(func_pool_[instr.func_idx], args, &ret); + this->InvokeClosurePacked(callable, args, &ret); } else { // insert light-weight instrument callback - setter(0, func_pool_[instr.func_idx]); - setter(1, GetFuncName(instr.func_idx)); + setter(0, callable); + setter(1, func_name); setter(2, true); setter(3, nullptr); TVMRetValue rv; @@ -758,7 +769,7 @@ void VirtualMachineImpl::RunInstrCall(VMFrame* curr_frame, Instruction instr) { ret_kind = rv; } if (ret_kind != static_cast(VMInstrumentReturnKind::kSkipRun)) { - this->InvokeClosurePacked(func_pool_[instr.func_idx], args, &ret); + this->InvokeClosurePacked(callable, args, &ret); setter(2, false); setter(3, ret); instrument_.CallPacked(TVMArgs(values.data(), tcodes.data(), values.size()), &rv); @@ -782,7 +793,11 @@ void VirtualMachineImpl::RunLoop() { Instruction instr = exec_->GetInstruction(pc_); switch (instr.op) { case Opcode::Call: { - this->RunInstrCall(curr_frame, instr); + this->RunInstrCall(curr_frame, func_pool_[instr.func_idx], instr); + break; + } + case Opcode::CallFromRegister: { + this->RunInstrCall(curr_frame, ReadRegister(curr_frame, instr.func_idx), instr); break; } case Opcode::Ret: { @@ -1000,7 +1015,7 @@ class VirtualMachineProfiler : public VirtualMachineImpl { } protected: - void RunInstrCall(VMFrame* curr_frame, Instruction inst) override { + void RunInstrCall(VMFrame* curr_frame, const ObjectRef& callable, Instruction inst) override { bool profiling = false; if (prof_ && prof_->IsRunning()) { auto f_name = GetFuncName(inst.func_idx); @@ -1036,7 +1051,7 @@ class VirtualMachineProfiler : public VirtualMachineImpl { } } - VirtualMachineImpl::RunInstrCall(curr_frame, inst); + 
VirtualMachineImpl::RunInstrCall(curr_frame, callable, inst); if (profiling) { prof_->StopCall(); diff --git a/tests/python/relax/test_vm_callback_function.py b/tests/python/relax/test_vm_callback_function.py new file mode 100644 index 000000000000..29a502ad7f98 --- /dev/null +++ b/tests/python/relax/test_vm_callback_function.py @@ -0,0 +1,124 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +import tvm +import tvm.testing + +from tvm.script import relax as R + +import numpy as np + +exec_mode = tvm.testing.parameter("bytecode", "compiled") + +pytestmark = tvm.testing.parametrize_targets("llvm") + + +def test_pass_tensor_to_function(exec_mode, target, dev): + @R.function + def relax_func( + A: R.Tensor([16], "int32"), + callback: R.Callable([R.Tensor([16], "int32")], R.Tuple([])), + ): + B = R.multiply(A, R.const(2)) + _ = callback(B) + return R.tuple() + + ex = tvm.relax.build(tvm.IRModule.from_expr(relax_func), target=target, exec_mode=exec_mode) + vm = tvm.relax.VirtualMachine(ex, dev) + + from_callback = None + + def custom_callback(arr): + nonlocal from_callback + from_callback = arr + + np_A = np.arange(16, dtype="int32") + tvm_A = tvm.nd.array(np_A) + + vm["relax_func"](tvm_A, custom_callback) + + assert from_callback is not None + np.testing.assert_array_equal(np_A * 2, from_callback.numpy()) + + +def test_generate_tensor_in_function(exec_mode, target, dev): + @R.function + def relax_func( + callback: R.Callable([], R.Tensor([16], "int32")), + ): + A = callback() + B = R.multiply(A, R.const(2)) + return B + + ex = tvm.relax.build( + tvm.IRModule.from_expr(relax_func), + target=target, + exec_mode=exec_mode, + ) + vm = tvm.relax.VirtualMachine(ex, dev) + + np_A = np.arange(16, dtype="int32") + + def custom_callback(): + return tvm.nd.array(np_A) + + output = vm["relax_func"](custom_callback) + + np.testing.assert_array_equal(np_A * 2, output.numpy()) + + +def test_catch_exception_with_full_stack_trace(exec_mode, target, dev): + @R.function + def relax_func( + callback: R.Callable([], R.Tensor([16], "int32")), + ): + A = callback() + return A + + ex = tvm.relax.build( + tvm.IRModule.from_expr(relax_func), + target=target, + exec_mode=exec_mode, + ) + vm = tvm.relax.VirtualMachine(ex, dev) + + def custom_callback(): + local_var = 42 + raise RuntimeError("Error thrown from callback") + + try: + vm["relax_func"](custom_callback) + except 
RuntimeError as err: + stack = err.__traceback__ + while stack.tb_next is not None: + stack = stack.tb_next + frame = stack.tb_frame + + assert frame.f_code is custom_callback.__code__, ( + "Inner-most stack frame should be from Python callback, " + "even though that crosses an FFI boundary" + ) + assert frame.f_locals.get("local_var") == 42, ( + "Python __traceback__ should include local variables, " + "even though that crosses an FFI boundary" + ) + else: + raise RuntimeError("Exception thrown in callback was not propagated to calling scope") + + +if __name__ == "__main__": + tvm.testing.main()