Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Fixing tvmgpu issue & not restoring tvmop checks #18818

Merged
merged 16 commits into from
Aug 22, 2020
2 changes: 2 additions & 0 deletions contrib/tvmop/compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,8 @@ def get_cuda_arch(arch):
# we create libtvmop.o first, which gives us chance to link tvm_runtime together with the libtvmop
# to allow mxnet find external helper functions in libtvm_runtime
func_binary.save(arguments.target_path + "/libtvmop.o")
if len(func_binary.imported_modules):
func_binary.imported_modules[0].save(arguments.target_path + "/libtvmop.cubin")
ld_path = arguments.target_path if arguments.ld_path is None else arguments.ld_path
create_shared(arguments.target_path + "/libtvmop.so",
arguments.target_path + "/libtvmop.o",
Expand Down
9 changes: 9 additions & 0 deletions src/c_api/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1365,6 +1365,15 @@ int MXGetVersion(int *out) {
int MXLoadTVMOp(const char *libpath) {
API_BEGIN();
tvm::runtime::TVMOpModule::Get()->Load(libpath);
tvm::runtime::TVMOpModule *global_module = tvm::runtime::TVMOpModule::Get();
global_module->Load(libpath);
#if MXNET_USE_CUDA
std::string libpathstr(libpath);
std::string cubinpath = libpathstr.substr(0, libpathstr.size() - 11) + "libtvmop.cubin";
tvm::runtime::TVMOpModule cubin_module;
cubin_module.Load(cubinpath);
global_module->Import(cubin_module);
#endif
API_END();
}

Expand Down
6 changes: 6 additions & 0 deletions src/operator/tvmop/op_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,12 @@ void TVMOpModule::Load(const std::string &filepath) {
*module_ptr_ = module;
}

void TVMOpModule::Import(const TVMOpModule& module) {
CHECK(module_ptr_ != nullptr) << "module_ptr_ is not initialized.";
std::lock_guard<std::mutex> lock(mutex_);
module_ptr_->Import(*(module.module_ptr_));
}

PackedFunc GetFunction(const std::shared_ptr<Module> &module,
const std::string &op_name,
const std::vector<mxnet::TBlob> &args) {
Expand Down
2 changes: 2 additions & 0 deletions src/operator/tvmop/op_module.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ class TVMOpModule {
// Load TVM operators binary
void Load(const std::string& filepath);

void Import(const TVMOpModule& module);

void Call(const std::string& func_name,
const mxnet::OpContext& ctx,
const std::vector<mxnet::TBlob>& args) const;
Expand Down