Skip to content

Commit

Permalink
[Relax] Support callback as argument (#16542)
Browse files Browse the repository at this point in the history
Prior to this commit, calls from Relax to external PackedFuncs could
only be done through the TVM global registry.  While Relax functions
accepting a callback could be written as `callback_arg:
R.Callable(arg_struct_info, ret_struct_info)`, attempting to compile
these functions would raise an error during the `CodeGenVM` step of
`relax.build`.  In addition, the global registry is only queried when
initializing the `relax.VirtualMachine`, and so later changes requires
restarting the VM.

This commit updates both the `CodeGenVM` lowering pass and the relax
VM to support callbacks.  The is primarily intended for use with the
`LazyTransformParams` pass, to improve flexibility by avoiding use of
the global registry.
  • Loading branch information
Lunderberg authored Feb 13, 2024
1 parent bb2adbf commit 3c5ee30
Show file tree
Hide file tree
Showing 7 changed files with 219 additions and 23 deletions.
28 changes: 25 additions & 3 deletions include/tvm/runtime/relax_vm/bytecode.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ enum class Opcode {
Ret = 2U,
Goto = 3U,
If = 4U,
CallFromRegister = 5U,
};

/*! \brief A single virtual machine instruction.
Expand Down Expand Up @@ -183,10 +184,15 @@ struct Instruction {
/*! \brief The instruction opcode. */
Opcode op;
union {
struct /* Call */ {
struct /* Call, CallFromRegister */ {
/*! \brief The destination register. */
RegName dst;
/*! \brief The index into the packed function table. */
/*! \brief The index of the function.
*
* For `OpCode::Call`, this is an index into the table of static
* functions. For `OpCode::CallFromRegister`, this is an index
* of a register.
*/
Index func_idx;
/*! \brief The number of arguments to the packed function. */
Index num_args;
Expand All @@ -208,27 +214,43 @@ struct Instruction {
Index false_offset;
};
};

/*!
* \brief Construct a Call instruction.
* \param func_idx The index of the function to call.
* \param func_idx The index of the function to call within the
* static function table
* \param num_args The number of arguments.
* \param args The input arguments.
* \param dst The destination register.
* \return The call instruction.
*/
static Instruction Call(Index func_idx, Index num_args, Arg* args, RegName dst);

/*!
* \brief Construct a Call instruction.
* \param func_idx The index of the function to call within the
* current stack frame's registers.
* \param num_args The number of arguments.
* \param args The input arguments.
* \param dst The destination register.
* \return The call instruction.
*/
static Instruction CallFromRegister(Index func_idx, Index num_args, Arg* args, RegName dst);

/*!
* \brief Construct a return instruction.
* \param result The register containing the return value.
* \return The return instruction.
*/
static Instruction Ret(RegName result);

/*!
* \brief Construct a goto instruction.
* \param pc_offset The register containing the jump offset.
* \return The goto instruction.
*/
static Instruction Goto(RegName pc_offset);

/*!
* \brief Construct an If instruction.
* \param cond The register containing the cond value.
Expand Down
20 changes: 16 additions & 4 deletions src/relax/backend/vm/exec_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -138,10 +138,20 @@ void ExecBuilderNode::EndFunction(const std::string& func_name) {

void ExecBuilderNode::EmitCall(vm::Instruction::Arg func, std::vector<vm::Instruction::Arg> args,
vm::RegName dst) {
ICHECK(func.kind() == vm::Instruction::ArgKind::kFuncIdx);
Opcode op_code;
if (func.kind() == vm::Instruction::ArgKind::kFuncIdx) {
op_code = Opcode::Call;
} else if (func.kind() == vm::Instruction::ArgKind::kRegister) {
op_code = Opcode::CallFromRegister;
} else {
LOG(FATAL) << "VM instruction for a function must be either "
<< "kFuncIdx (static function ) "
<< "or kRegister (function passed as parameter), "
<< "but instead found " << func.kind();
}
// store instruction
exec_->instr_offset.push_back(exec_->instr_data.size());
exec_->instr_data.push_back(static_cast<ExecWord>(Opcode::Call));
exec_->instr_data.push_back(static_cast<ExecWord>(op_code));
exec_->instr_data.push_back(dst);
exec_->instr_data.push_back(func.value());
exec_->instr_data.push_back(args.size());
Expand Down Expand Up @@ -228,7 +238,8 @@ void ExecBuilderNode::CheckExecutable() {
for (size_t idx = start_instr; idx < end_instr; ++idx) {
Instruction instr = exec_->GetInstruction(idx);
switch (instr.op) {
case Opcode::Call: {
case Opcode::Call:
case Opcode::CallFromRegister: {
check_func_defined(Instruction::Arg::FuncIdx(instr.func_idx));
for (int i = 0; i < instr.num_args; ++i) {
check_reg_defined(instr.args[i]);
Expand Down Expand Up @@ -280,7 +291,8 @@ void ExecBuilderNode::Formalize() {
for (size_t idx = start_instr; idx < end_instr; ++idx) {
Instruction instr = this->exec_->GetInstruction(idx);
switch (instr.op) {
case Opcode::Call: {
case Opcode::Call:
case Opcode::CallFromRegister: {
// rewrite args
for (int i = 0; i < instr.num_args; ++i) {
if (instr.args[i].kind() == Instruction::ArgKind::kRegister &&
Expand Down
12 changes: 8 additions & 4 deletions src/runtime/library_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,15 @@ PackedFunc WrapPackedFunc(TVMBackendPackedCFunc faddr, const ObjectPtr<Object>&
return PackedFunc([faddr, sptr_to_self](TVMArgs args, TVMRetValue* rv) {
TVMValue ret_value;
int ret_type_code = kTVMNullptr;
int ret = (*faddr)(const_cast<TVMValue*>(args.values), const_cast<int*>(args.type_codes),
args.num_args, &ret_value, &ret_type_code, nullptr);
// NOTE: important to keep the original error message.
auto arg_values = const_cast<TVMValue*>(args.values);
auto arg_type_codes = const_cast<int*>(args.type_codes);
int ret =
(*faddr)(arg_values, arg_type_codes, args.num_args, &ret_value, &ret_type_code, nullptr);
// NOTE: It is important to keep the original error message.
// Using the `TVMThrowLastError()` function will also preserve the
// full stack trace for debugging in pdb.
if (ret != 0) {
LOG(FATAL) << TVMGetLastError();
TVMThrowLastError();
}
if (ret_type_code != kTVMNullptr) {
*rv = TVMRetValue::MoveFromCHost(ret_value, ret_type_code);
Expand Down
11 changes: 11 additions & 0 deletions src/runtime/relax_vm/bytecode.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,17 @@ Instruction Instruction::Call(Index func_idx, Index num_args, Instruction::Arg*
return instr;
}

Instruction Instruction::CallFromRegister(Index func_idx, Index num_args, Instruction::Arg* args,
RegName dst) {
Instruction instr;
instr.op = Opcode::CallFromRegister;
instr.dst = dst;
instr.func_idx = func_idx;
instr.num_args = num_args;
instr.args = args;
return instr;
}

Instruction Instruction::Ret(RegName result) {
Instruction instr;
instr.op = Opcode::Ret;
Expand Down
8 changes: 8 additions & 0 deletions src/runtime/relax_vm/executable.cc
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,14 @@ Instruction Executable::GetInstruction(Index i) const {
ExecWord* args = const_cast<ExecWord*>(&instr_data[offset + 4]);
return Instruction::Call(func_idx, num_args, reinterpret_cast<Instruction::Arg*>(args), dst);
}
case Opcode::CallFromRegister: {
RegName dst = instr_data[offset + 1];
Index func_idx = instr_data[offset + 2];
Index num_args = instr_data[offset + 3];
ExecWord* args = const_cast<ExecWord*>(&instr_data[offset + 4]);
return Instruction::CallFromRegister(func_idx, num_args,
reinterpret_cast<Instruction::Arg*>(args), dst);
}
case Opcode::Ret: {
RegName result = instr_data[offset + 1];
return Instruction::Ret(result);
Expand Down
39 changes: 27 additions & 12 deletions src/runtime/relax_vm/vm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -372,9 +372,10 @@ class VirtualMachineImpl : public VirtualMachine {
/*!
* \brief Run call instruction.
* \param curr_frame The current frame.
* \param callable The callable object, either PackedFunc or closure
* \param inst The call instruction.
*/
virtual void RunInstrCall(VMFrame* curr_frame, Instruction inst);
virtual void RunInstrCall(VMFrame* curr_frame, const ObjectRef& callable, Instruction inst);

/*! \brief Run VM dispatch loop. */
void RunLoop();
Expand Down Expand Up @@ -506,14 +507,18 @@ void VirtualMachineImpl::SetInput(std::string func_name, bool with_param_module,
//------------------------------------------
void VirtualMachineImpl::InvokeClosurePacked(const ObjectRef& closure_or_packedfunc, TVMArgs args,
TVMRetValue* rv) {
ICHECK(closure_or_packedfunc.defined())
<< "InvokeClosurePacked requires the callable object to be defined";

// run packed call if it is a packed func.
if (auto* packed = closure_or_packedfunc.as<PackedFunc::ContainerType>()) {
packed->CallPacked(args, rv);
return;
}
// run closure call.
auto* clo = closure_or_packedfunc.as<VMClosureObj>();
ICHECK(clo != nullptr) << "Function expects a closure or PackedFunc ";
ICHECK(clo != nullptr) << "Function expects a closure or PackedFunc, "
<< "but received " << closure_or_packedfunc->GetTypeKey();

std::vector<TVMValue> values(args.size() + 1);
std::vector<int> tcodes(args.size() + 1);
Expand Down Expand Up @@ -595,6 +600,8 @@ Optional<VMClosure> VirtualMachineImpl::GetClosureInternal(const String& func_na
auto impl = PackedFunc([gf_idx](TVMArgs args, TVMRetValue* rv) {
// Per convention, ctx ptr is a VirtualMachine*
VirtualMachine* ctx_ptr = static_cast<VirtualMachine*>(args[0].operator void*());
ICHECK(ctx_ptr) << "Context pointer for relax VM closure should be a VirtualMachine*, "
<< "but was NULL";

std::vector<RegType> inputs(args.size() - 1);
for (size_t i = 0; i < inputs.size(); ++i) {
Expand Down Expand Up @@ -644,7 +651,7 @@ RegType VirtualMachineImpl::InvokeBytecode(Index gf_idx, const std::vector<RegTy
auto guard = PushFrame(this->pc_, gfunc);
// Get new frame and set the caller info.
VMFrame* curr_frame = frames_.back().get();
if (curr_instr.op == Opcode::Call) {
if (curr_instr.op == Opcode::Call || curr_instr.op == Opcode::CallFromRegister) {
curr_frame->caller_return_register = curr_instr.dst;
}

Expand Down Expand Up @@ -688,8 +695,12 @@ void VirtualMachineImpl::InitFuncPool() {
}
}

void VirtualMachineImpl::RunInstrCall(VMFrame* curr_frame, Instruction instr) {
DLOG(INFO) << "\n pc = " << pc_ << ", execute: " << GetFuncName(instr.func_idx);
void VirtualMachineImpl::RunInstrCall(VMFrame* curr_frame, const ObjectRef& callable,
Instruction instr) {
ICHECK(callable.defined()) << "RunInstrCall requires the callable object to be defined";
auto func_name = instr.op == Opcode::Call ? GetFuncName(instr.func_idx) : "<dynamic>";

DLOG(INFO) << "\n pc = " << pc_ << ", execute: " << func_name;
int args_begin_offset = instrument_ != nullptr ? 4 : 0;
// Use the call arg stack from the current frame to increase reuse
// and avoid re-allocation
Expand Down Expand Up @@ -735,11 +746,11 @@ void VirtualMachineImpl::RunInstrCall(VMFrame* curr_frame, Instruction instr) {
ICHECK_LT(static_cast<size_t>(instr.func_idx), this->func_pool_.size());

if (instrument_ == nullptr) {
this->InvokeClosurePacked(func_pool_[instr.func_idx], args, &ret);
this->InvokeClosurePacked(callable, args, &ret);
} else {
// insert light-weight instrument callback
setter(0, func_pool_[instr.func_idx]);
setter(1, GetFuncName(instr.func_idx));
setter(0, callable);
setter(1, func_name);
setter(2, true);
setter(3, nullptr);
TVMRetValue rv;
Expand All @@ -758,7 +769,7 @@ void VirtualMachineImpl::RunInstrCall(VMFrame* curr_frame, Instruction instr) {
ret_kind = rv;
}
if (ret_kind != static_cast<int>(VMInstrumentReturnKind::kSkipRun)) {
this->InvokeClosurePacked(func_pool_[instr.func_idx], args, &ret);
this->InvokeClosurePacked(callable, args, &ret);
setter(2, false);
setter(3, ret);
instrument_.CallPacked(TVMArgs(values.data(), tcodes.data(), values.size()), &rv);
Expand All @@ -782,7 +793,11 @@ void VirtualMachineImpl::RunLoop() {
Instruction instr = exec_->GetInstruction(pc_);
switch (instr.op) {
case Opcode::Call: {
this->RunInstrCall(curr_frame, instr);
this->RunInstrCall(curr_frame, func_pool_[instr.func_idx], instr);
break;
}
case Opcode::CallFromRegister: {
this->RunInstrCall(curr_frame, ReadRegister(curr_frame, instr.func_idx), instr);
break;
}
case Opcode::Ret: {
Expand Down Expand Up @@ -1000,7 +1015,7 @@ class VirtualMachineProfiler : public VirtualMachineImpl {
}

protected:
void RunInstrCall(VMFrame* curr_frame, Instruction inst) override {
void RunInstrCall(VMFrame* curr_frame, const ObjectRef& callable, Instruction inst) override {
bool profiling = false;
if (prof_ && prof_->IsRunning()) {
auto f_name = GetFuncName(inst.func_idx);
Expand Down Expand Up @@ -1036,7 +1051,7 @@ class VirtualMachineProfiler : public VirtualMachineImpl {
}
}

VirtualMachineImpl::RunInstrCall(curr_frame, inst);
VirtualMachineImpl::RunInstrCall(curr_frame, callable, inst);

if (profiling) {
prof_->StopCall();
Expand Down
Loading

0 comments on commit 3c5ee30

Please sign in to comment.