diff --git a/ffi/python/tvm_ffi/cpp/load_inline.py b/ffi/python/tvm_ffi/cpp/load_inline.py index 111dee8d5276..3bc0fc4cbc73 100644 --- a/ffi/python/tvm_ffi/cpp/load_inline.py +++ b/ffi/python/tvm_ffi/cpp/load_inline.py @@ -326,10 +326,12 @@ def load_inline( cuda_sources: Sequence[str] | str, optional The CUDA source code. It can be a list of sources or a single source. functions: Mapping[str, str] | Sequence[str] | str, optional - The functions in cpp_sources that will be exported to the tvm ffi module. When a mapping is given, the keys - are the names of the exported functions, and the values are docstrings for the functions. When a sequence or a - single string is given, they are the functions needed to be exported, and the docstrings are set to empty - strings. A single function name can also be given as a string. + The functions in cpp_sources or cuda_source that will be exported to the tvm ffi module. When a mapping is + given, the keys are the names of the exported functions, and the values are docstrings for the functions. When + a sequence or a single string is given, they are the functions needed to be exported, and the docstrings are set + to empty strings. A single function name can also be given as a string. When cpp_sources is given, the functions + must be declared (not necessarily defined) in the cpp_sources. When cpp_sources is not given, the functions + must be defined in the cuda_sources. If not specified, no function will be exported. extra_cflags: Sequence[str], optional The extra compiler flags for C++ compilation. The default flags are: @@ -369,6 +371,7 @@ def load_inline( elif isinstance(cuda_sources, str): cuda_sources = [cuda_sources] cuda_source = "\n".join(cuda_sources) + with_cpp = len(cpp_sources) > 0 with_cuda = len(cuda_sources) > 0 extra_ldflags = extra_ldflags or [] @@ -381,8 +384,13 @@ def load_inline( functions = {functions: ""} elif isinstance(functions, Sequence): functions = {name: "" for name in functions} - cpp_source = _decorate_with_tvm_ffi(cpp_source, functions) - cuda_source = _decorate_with_tvm_ffi(cuda_source, {}) + + if with_cpp: + cpp_source = _decorate_with_tvm_ffi(cpp_source, functions) + cuda_source = _decorate_with_tvm_ffi(cuda_source, {}) + else: + cpp_source = _decorate_with_tvm_ffi(cpp_source, {}) + cuda_source = _decorate_with_tvm_ffi(cuda_source, functions) # determine the cache dir for the built module if build_directory is None: diff --git a/ffi/tests/python/test_load_inline.py b/ffi/tests/python/test_load_inline.py index 9a10476d8eff..28ca4b3709e8 100644 --- a/ffi/tests/python/test_load_inline.py +++ b/ffi/tests/python/test_load_inline.py @@ -159,9 +159,6 @@ def test_load_inline_cpp_build_dir(): def test_load_inline_cuda(): mod: Module = tvm_ffi.cpp.load_inline( name="hello", - cpp_sources=r""" - void add_one_cuda(DLTensor* x, DLTensor* y); - """, cuda_sources=r""" __global__ void AddOneKernel(float* x, float* y, int n) { int idx = blockIdx.x * blockDim.x + threadIdx.x;