Skip to content

Commit

Permalink
[VM] Initialize VM through packed function (#101)
Browse files Browse the repository at this point in the history
  • Loading branch information
MasterJH5574 authored and junrushao committed Feb 5, 2023
1 parent 2766afa commit 1ce5b18
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 22 deletions.
2 changes: 1 addition & 1 deletion python/tvm/relax/vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def _setup_device(self, dev: Device, memory_cfg: Union[str, Dict[Device, str]])
init_args.append(device.device_id)
alloc_type = memory_cfg[device] if device in memory_cfg else default_alloc_type
init_args.append(alloc_type)
_ffi_api.VirtualMachineInit(self.module, *init_args)
self.module["vm_initialization"](*init_args)

def __getitem__(self, key: str) -> PackedFunc:
return self.module[key]
Expand Down
42 changes: 21 additions & 21 deletions src/runtime/relax_vm/vm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,27 @@ namespace relax_vm {

PackedFunc VirtualMachine::GetFunction(const std::string& name,
const ObjectPtr<Object>& sptr_to_self) {
if (name == "vm_initialization") {
// initialize the VirtualMachine, takes variable-length arguments
// first argument is a runtime::Module, followed by one or more device_type, device_id,
// and the AllocatorType associated with the device.
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
ICHECK_EQ(args.size() % 3, 0);
std::vector<Device> devices;
std::vector<AllocatorType> alloc_types;
for (int i = 0; i < args.size(); i += 3) {
Device dev;
int device_type = args[i];
dev.device_type = DLDeviceType(device_type);
dev.device_id = args[i + 1];
int type = args[i + 2];
devices.push_back(dev);
alloc_types.push_back(AllocatorType(type));
}
this->Init(devices, alloc_types);
});
}

const auto& m = exec_->global_map;
if (m.find(name) != m.end()) {
Index gf_idx = m.at(name);
Expand Down Expand Up @@ -220,27 +241,6 @@ inline RegType VirtualMachine::ReadRegister(VMFrame* frame, Index r) const {
return frame->register_file[r];
}

// initialize the VirtualMachine, takes variable-length arguments
// first argument is a runtime::Module, followed by one or more device_type, device_id,
// and the AllocatorType associated with the device.
TVM_REGISTER_GLOBAL("relax.VirtualMachineInit").set_body([](TVMArgs args, TVMRetValue* rv) {
ICHECK_EQ(args.size() % 3, 1);
runtime::Module mod = args[0];
auto vm = static_cast<VirtualMachine*>(mod.operator->());
std::vector<Device> devices;
std::vector<AllocatorType> alloc_types;
for (int i = 0; i < args.size() / 3; ++i) {
Device dev;
int device_type = args[i * 3 + 1];
dev.device_type = DLDeviceType(device_type);
dev.device_id = args[i * 3 + 2];
int type = args[i * 3 + 3];
devices.push_back(dev);
alloc_types.push_back(AllocatorType(type));
}
vm->Init(devices, alloc_types);
});

} // namespace relax_vm
} // namespace runtime
} // namespace tvm

0 comments on commit 1ce5b18

Please sign in to comment.