diff --git a/python/tvm/tir/build.py b/python/tvm/tir/build.py index 3eb364853390..98e549cc9c32 100644 --- a/python/tvm/tir/build.py +++ b/python/tvm/tir/build.py @@ -17,14 +17,14 @@ # pylint: disable=invalid-name """The build utils in python.""" -from typing import Union, Optional, Dict, Tuple +from typing import Dict, Optional, Tuple, Union import tvm from tvm import ir -from tvm.runtime import ndarray -from tvm.tir import PrimFunc from tvm.ir.module import IRModule +from tvm.runtime import ndarray from tvm.target import Target +from tvm.tir import PrimFunc def split_host_device_mods(mod: IRModule) -> Tuple[IRModule, Dict[Target, IRModule]]: @@ -100,10 +100,12 @@ def main(self_handle: T.handle, args: T.handle, num_args: T.int32, result: T.han - Device kernel functions: use `calling_conv: 2` (kDeviceKernelLaunch) """ - host_mod = tvm.tir.transform.Filter(lambda f: "cpu" in str(f.attrs.get("target", "cpu")))(mod) - device_mod = tvm.tir.transform.Filter(lambda f: "cpu" not in str(f.attrs.get("target", "cpu")))( - mod - ) + def is_host_func(f): + target = f.attrs.get("target", tvm.target.Target("llvm")) + return str(target.kind) in ["llvm", "c"] + + host_mod = tvm.tir.transform.Filter(is_host_func)(mod) + device_mod = tvm.tir.transform.Filter(lambda f: not is_host_func(f))(mod) # TODO(syfeng): Here we use str as key since target hash is not correct target_str2target = {} device_func_dict = {} diff --git a/tests/python/codegen/test_target_codegen_cuda.py b/tests/python/codegen/test_target_codegen_cuda.py index 28dfb6b9d4cb..a304cb1e41c7 100644 --- a/tests/python/codegen/test_target_codegen_cuda.py +++ b/tests/python/codegen/test_target_codegen_cuda.py @@ -820,10 +820,12 @@ def main( for tx in T.thread_binding(length, "threadIdx.x"): C[bx, tx] = Module.add(A[bx, tx], B[bx, tx]) # Call from device - # If we set host to llvm, it will raise an error of - # "the tir.ret should be transformed to return zero before the llvm code generation." - # Need to revisit this. 
- target = tvm.target.Target("cuda", host="c") + # 1. If we set host to llvm, it will raise an error of + # "the tir.ret should be transformed to return zero before the llvm code generation." + # Need to revisit this. + # 2. We set a dummy mcpu value for testing purposes, + # in order to avoid checking whether a function is host or device based on the "cpu" substring. + target = tvm.target.Target({"kind": "cuda", "mcpu": "dummy_mcpu"}, host="c") lib = tvm.compile(Module, target=target) cuda_code = lib.mod.imported_modules[0].get_source() assert 'extern "C" __device__ int add(int a, int b) {\n return (a + b);\n}' in cuda_code