Skip to content

Commit

Permalink
add nvcc support (apache#7668)
Browse files Browse the repository at this point in the history
  • Loading branch information
Laurawly authored and trevor-m committed May 11, 2021
1 parent 2f29ba8 commit 060105f
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 7 deletions.
16 changes: 10 additions & 6 deletions python/tvm/contrib/cc.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,12 +192,16 @@ def _fcompile(outputs, objects, options=None):

def _linux_compile(output, objects, options, compile_cmd="g++", compile_shared=False):
cmd = [compile_cmd]
if compile_shared or output.endswith(".so") or output.endswith(".dylib"):
cmd += ["-shared", "-fPIC"]
if sys.platform == "darwin":
cmd += ["-undefined", "dynamic_lookup"]
elif output.endswith(".obj"):
cmd += ["-c"]
if compile_cmd != "nvcc":
if compile_shared or output.endswith(".so") or output.endswith(".dylib"):
cmd += ["-shared", "-fPIC"]
if sys.platform == "darwin":
cmd += ["-undefined", "dynamic_lookup"]
elif output.endswith(".obj"):
cmd += ["-c"]
else:
if compile_shared or output.endswith(".so") or output.endswith(".dylib"):
cmd += ["--shared"]
cmd += ["-o", output]
if isinstance(objects, str):
cmd += [objects]
Expand Down
3 changes: 3 additions & 0 deletions python/tvm/runtime/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,9 @@ def export_library(self, file_name, fcompile=None, addons=None, workspace_dir=No
else:
assert module.type_key == "c"
object_format = "c"
if "cc" in kwargs:
if kwargs["cc"] == "nvcc":
object_format = "cu"
has_c_module = True
path_obj = os.path.join(workspace_dir, f"lib{index}.{object_format}")
module.save(path_obj)
Expand Down
2 changes: 2 additions & 0 deletions src/relay/backend/contrib/codegen_c/codegen_c.h
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,8 @@ class CodegenCBase {
std::string dtype;
if (runtime::TypeMatch(ttype->dtype, kDLFloat, 32)) {
dtype = "float";
} else if (runtime::TypeMatch(ttype->dtype, kDLFloat, 16)) {
dtype = "half";
} else if (runtime::TypeMatch(ttype->dtype, kDLInt, 32)) {
dtype = "int";
} else if (runtime::TypeMatch(ttype->dtype, kDLInt, 64)) {
Expand Down
2 changes: 1 addition & 1 deletion src/target/source/source_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ class CSourceModuleNode : public runtime::ModuleNode {
void SaveToFile(const std::string& file_name, const std::string& format) final {
std::string fmt = GetFileFormat(file_name, format);
std::string meta_file = GetMetaFilePath(file_name);
if (fmt == "c") {
if (fmt == "c" || fmt == "cu") {
ICHECK_NE(code_.length(), 0);
SaveBinaryToFile(file_name, code_);
} else {
Expand Down

0 comments on commit 060105f

Please sign in to comment.