diff --git a/python/tvm/relax/vm.py b/python/tvm/relax/vm.py index cef7530264..cd25e572b2 100644 --- a/python/tvm/relax/vm.py +++ b/python/tvm/relax/vm.py @@ -126,7 +126,7 @@ def _setup_device(self, dev: Device, memory_cfg: Union[str, Dict[Device, str]]) init_args.append(device.device_id) alloc_type = memory_cfg[device] if device in memory_cfg else default_alloc_type init_args.append(alloc_type) - _ffi_api.VirtualMachineInit(self.module, *init_args) + self.module["vm_initialization"](*init_args) def __getitem__(self, key: str) -> PackedFunc: return self.module[key] diff --git a/src/runtime/relax_vm/vm.cc b/src/runtime/relax_vm/vm.cc index 65e42473a6..671820ee3b 100644 --- a/src/runtime/relax_vm/vm.cc +++ b/src/runtime/relax_vm/vm.cc @@ -29,6 +29,27 @@ namespace relax_vm { PackedFunc VirtualMachine::GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) { + if (name == "vm_initialization") { + // initialize the VirtualMachine, takes variable-length arguments + // first argument is a runtime::Module, followed by one or more device_type, device_id, + // and the AllocatorType associated with the device. + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + ICHECK_EQ(args.size() % 3, 0); + std::vector devices; + std::vector alloc_types; + for (int i = 0; i < args.size(); i += 3) { + Device dev; + int device_type = args[i]; + dev.device_type = DLDeviceType(device_type); + dev.device_id = args[i + 1]; + int type = args[i + 2]; + devices.push_back(dev); + alloc_types.push_back(AllocatorType(type)); + } + this->Init(devices, alloc_types); + }); + } + const auto& m = exec_->global_map; if (m.find(name) != m.end()) { Index gf_idx = m.at(name); @@ -220,27 +241,6 @@ inline RegType VirtualMachine::ReadRegister(VMFrame* frame, Index r) const { return frame->register_file[r]; } -// initialize the VirtualMachine, takes variable-length arguments -// first argument is a runtime::Module, followed by one or more device_type, device_id, -// and the AllocatorType associated with the device. -TVM_REGISTER_GLOBAL("relax.VirtualMachineInit").set_body([](TVMArgs args, TVMRetValue* rv) { - ICHECK_EQ(args.size() % 3, 1); - runtime::Module mod = args[0]; - auto vm = static_cast(mod.operator->()); - std::vector devices; - std::vector alloc_types; - for (int i = 0; i < args.size() / 3; ++i) { - Device dev; - int device_type = args[i * 3 + 1]; - dev.device_type = DLDeviceType(device_type); - dev.device_id = args[i * 3 + 2]; - int type = args[i * 3 + 3]; - devices.push_back(dev); - alloc_types.push_back(AllocatorType(type)); - } - vm->Init(devices, alloc_types); -}); - } // namespace relax_vm } // namespace runtime } // namespace tvm