diff --git a/CMakeLists.txt b/CMakeLists.txt
index 4eb2468e4e2f..3667ed6ba974 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -832,7 +832,7 @@ endif()
 
 if (USE_CUDA AND USE_NVSHMEM)
-  include_directories(SYSTEM ${USE_NVSHMEM}/include)
+  target_include_directories(tvm_runtime_objs PUBLIC ${NVSHMEM_INCLUDE_DIR})
   find_library(NVSHMEM_HOST nvshmem_host ${NVSHMEM_LIB_DIR})
   find_library(NVSHMEM_DEVICE nvshmem_device ${NVSHMEM_LIB_DIR})
   target_link_libraries(tvm PRIVATE ${NVSHMEM_HOST} ${NVSHMEM_DEVICE})
diff --git a/include/tvm/runtime/disco/builtin.h b/include/tvm/runtime/disco/builtin.h
index 93bf0084db87..d594de7247c8 100644
--- a/include/tvm/runtime/disco/builtin.h
+++ b/include/tvm/runtime/disco/builtin.h
@@ -62,7 +62,7 @@ inline std::string ReduceKind2String(ReduceKind kind) {
  * \param device The default device used to initialize the RelaxVM
  * \return The RelaxVM as a runtime Module
  */
-TVM_DLL Module LoadVMModule(std::string path, Device device);
+TVM_DLL Module LoadVMModule(std::string path, Optional<Device> device);
 /*!
  * \brief Create an uninitialized empty NDArray
  * \param shape The shape of the NDArray
@@ -70,7 +70,7 @@ TVM_DLL Module LoadVMModule(std::string path, Device device);
 * \param device The device the NDArray is created on. If None, use the thread local default device
 * \return The NDArray created
 */
-TVM_DLL NDArray DiscoEmptyNDArray(ffi::Shape shape, DataType dtype, Device device);
+TVM_DLL NDArray DiscoEmptyNDArray(ffi::Shape shape, DataType dtype, Optional<Device> device);
 /*!
  * \brief Perform an allreduce operation using the underlying communication library
  * \param send The array send to perform allreduce on
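With `Optional<Device>` in the two signatures above, a missing device now means "use the calling worker's thread-local default device" rather than the old `Device{0, 0}` sentinel. A minimal Python-side sketch of relying on that default (not part of this patch; assumes a CUDA build and a working Disco session):

    from tvm.runtime import disco as di

    sess = di.ProcessSession(num_workers=2)
    # No explicit device: None flows through to runtime.disco.empty, so each
    # worker allocates on its thread-local default device.
    arr = sess.empty((4, 4), "float32")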
diff --git a/python/tvm/contrib/nvcc.py b/python/tvm/contrib/nvcc.py
index 45e2793fbb6f..c79305a739cd 100644
--- a/python/tvm/contrib/nvcc.py
+++ b/python/tvm/contrib/nvcc.py
@@ -21,6 +21,7 @@
 import os
 import subprocess
 import warnings
+from typing import Tuple
 
 import tvm.ffi
 from tvm.target import Target
@@ -29,7 +30,7 @@
 from . import utils
 
 
-def compile_cuda(code, target_format="ptx", arch=None, options=None, path_target=None):
+def compile_cuda(code, target_format=None, arch=None, options=None, path_target=None):
     """Compile cuda code with NVCC from env.
 
     Parameters
@@ -54,6 +55,15 @@ def compile_cuda(code, target_format="ptx", arch=None, options=None, path_target
     cubin : bytearray
         The bytearray of the cubin
     """
+    # Check for NVSHMEM dependency
+    nvshmem_include_path, nvshmem_lib_path = None, None
+    use_nvshmem = (
+        tvm.get_global_func("runtime.nvshmem.cumodule_init", allow_missing=True) is not None
+    )
+    if use_nvshmem:
+        target_format = "cubin"
+        nvshmem_include_path, nvshmem_lib_path = find_nvshmem_paths()
+
     if arch is None:
         # If None, then it will use `tvm.target.Target.current().arch`.
         # Target arch could be a str like "sm_xx", or a list, such as
@@ -68,6 +78,8 @@ def compile_cuda(code, target_format="ptx", arch=None, options=None, path_target
 
     temp = utils.tempdir()
     file_name = "tvm_kernels"
+    if target_format is None and not use_nvshmem:
+        target_format = "ptx"
     if target_format not in ["cubin", "ptx", "fatbin"]:
         raise ValueError("target_format must be in cubin, ptx, fatbin")
     temp_code = temp.relpath(f"{file_name}.cu")
@@ -89,6 +101,9 @@ def compile_cuda(code, target_format="ptx", arch=None, options=None, path_target
         out_file.write(code)
     file_target = path_target if path_target else temp_target
+    if use_nvshmem:
+        file_prefix = os.path.splitext(file_target)[0]
+        file_target = f"{file_prefix}.o"
     # in the first stage, compile to object file
     cmd = ["nvcc"]
     cmd += [f"--{target_format}", "-O3"]
     if kernels_output_dir is not None:
@@ -107,7 +122,12 @@ def compile_cuda(code, target_format="ptx", arch=None, options=None, path_target
             raise ValueError("options must be str or list of str")
 
     cmd += ["-o", file_target]
-    cmd += [temp_code]
+    if not use_nvshmem:
+        cmd += [temp_code]
+    else:
+        cmd += ["-c", temp_code]
+        cmd += ["-rdc=true"]
+        cmd += ["-I", nvshmem_include_path]
 
     # NOTE: ccbin option can be used to tell nvcc where to find the c++ compiler
     # just in case it is not in the path. On Windows it is not in the path by default.
@@ -127,6 +147,32 @@ def compile_cuda(code, target_format="ptx", arch=None, options=None, path_target
         msg += py_str(out)
         raise RuntimeError(msg)
 
+    # start second stage of compilation
+    if use_nvshmem:
+        cmd = ["nvlink"]
+        cmd += [f"-arch=sm_{compute_version}"]
+        cmd += [
+            "-L",
+            nvshmem_lib_path,
+        ]
+        cmd += ["-L", os.path.join(find_cuda_path(), "lib64")]
+        cmd += ["-l", "nvshmem_device"]
+        cmd += ["-l", "cudadevrt"]
+        cmd += ["-o", f"{file_prefix}.cubin"]
+        cmd += [file_target]
+
+        proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
+
+        (out, _) = proc.communicate()
+
+        if proc.returncode != 0:
+            msg = code
+            msg += "\nLink error:\n"
+            msg += py_str(out)
+            raise RuntimeError(msg)
+
+        file_target = f"{file_prefix}.cubin"
+
     with open(file_target, "rb") as f:
         data = bytearray(f.read())
         if not data:
@@ -198,6 +244,70 @@ def get_cuda_version(cuda_path=None):
     raise RuntimeError("Cannot read cuda version file")
 
 
+def find_nvshmem_paths() -> Tuple[str, str]:
+    """
+    Searches for the NVSHMEM include and library directories.
+    Returns:
+        A tuple containing the paths to the include directory and the library directory:
+        (include_path, lib_path)
+    """
+    candidate_roots = []
+
+    # 1. NVSHMEM_HOME env variable
+    if "NVSHMEM_HOME" in os.environ:
+        candidate_roots.append(os.environ["NVSHMEM_HOME"])
+
+    # 2. CUDA Toolkit
+    try:
+        cuda_home = find_cuda_path()
+        candidate_roots.append(cuda_home)
+    except RuntimeError:
+        pass
+
+    # 3. Other common system installation paths
+    candidate_roots.extend(["/usr/local", "/usr"])
+
+    seen = set()
+    unique_candidates = []
+    for path in candidate_roots:
+        if path and path not in seen:
+            seen.add(path)
+            unique_candidates.append(path)
+
+    for root in unique_candidates:
+        include_path = os.path.join(root, "include")
+        lib_paths_to_check = [
+            os.path.join(root, "lib64"),
+            os.path.join(root, "lib"),
+        ]
+
+        if os.path.isfile(os.path.join(include_path, "nvshmem.h")):
+            for lib_path in lib_paths_to_check:
+                if os.path.isfile(os.path.join(lib_path, "libnvshmem.a")):
+                    return include_path, lib_path
+
+    error_message = [
+        "Error: Could not find NVSHMEM installation.",
+        "Searched in the following locations:",
+    ]
+    error_message.extend([f"  - {path}" for path in unique_candidates])
+    error_message.extend(
+        [
+            "",
+            "Please ensure NVSHMEM is installed and try one of the following:",
+            (
+                "  1. Set the 'NVSHMEM_HOME' environment variable "
+                "to your NVSHMEM installation directory."
+            ),
+            (
+                "  2. Ensure your CUDA Toolkit installation includes NVSHMEM and "
+                "'nvcc' is on your PATH."
+            ),
+        ]
+    )
+    raise RuntimeError("\n".join(error_message))
+
+
 @tvm.ffi.register_func
 def tvm_callback_cuda_compile(code, target):  # pylint: disable=unused-argument
     """use nvcc to generate fatbin code for better optimization"""
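For reference, the two command lines produced by the NVSHMEM branch of `compile_cuda` above look roughly as follows. This is a sketch: `sm_90`, `/opt/nvshmem`, and `/usr/local/cuda` are illustrative stand-ins for the values that really come from the detected compute version, `find_nvshmem_paths()`, and `find_cuda_path()`:

    # Stage 1: nvcc compiles the generated .cu file to a relocatable object.
    stage1 = ["nvcc", "--cubin", "-O3", "-arch=sm_90",
              "-o", "tvm_kernels.o", "-c", "tvm_kernels.cu",
              "-rdc=true", "-I", "/opt/nvshmem/include"]
    # Stage 2: nvlink device-links the object against libnvshmem_device
    # (plus cudadevrt) into the final cubin.
    stage2 = ["nvlink", "-arch=sm_90", "-L", "/opt/nvshmem/lib",
              "-L", "/usr/local/cuda/lib64", "-l", "nvshmem_device",
              "-l", "cudadevrt", "-o", "tvm_kernels.cubin", "tvm_kernels.o"]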
diff --git a/python/tvm/runtime/disco/session.py b/python/tvm/runtime/disco/session.py
index c551eac428b7..bd0d3d8ed869 100644
--- a/python/tvm/runtime/disco/session.py
+++ b/python/tvm/runtime/disco/session.py
@@ -150,8 +150,6 @@ def empty(
             The created NDArray.
         """
-        if device is None:
-            device = Device(device_type=0, device_id=0)
         func = self._get_cached_method("runtime.disco.empty")
         return func(ShapeTuple(shape), dtype, device, worker0_only, in_group)
 
 
@@ -237,6 +235,12 @@ def _sync_worker(self, worker_id: int) -> None:
         """
         return _ffi_api.SessionSyncWorker(self, worker_id)  # type: ignore # pylint: disable=no-member
 
+    def _sync_all(self) -> None:
+        """Synchronize the controller with all workers in the current session,
+        waiting until every worker finishes executing all the existing instructions."""
+        for i in range(self.num_workers):
+            self._sync_worker(i)
+
     def sync_worker_0(self) -> None:
         """Synchronize the controller with worker-0, and it will wait until the
         worker-0 finishes executing all the existing instructions."""
@@ -302,8 +306,6 @@ def load_vm_module(
         module : DModule
             The loaded VM module.
         """
-        if device is None:
-            device = Device(device_type=0, device_id=0)
         func = self._get_cached_method("runtime.disco.load_vm_module")
         return DModule(func(path, device), self)
 
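A short usage sketch for the new `_sync_all` helper (assumes an already-initialized session; the leading underscore marks it as internal API):

    from tvm.runtime import disco as di

    sess = di.ProcessSession(num_workers=4)
    # ... enqueue instructions on the workers ...
    sess._sync_all()  # returns only after every worker has drained its queue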
diff --git a/src/runtime/contrib/nvshmem/init.cc b/src/runtime/contrib/nvshmem/init.cc
index 7b4a617a2501..6dea2281f714 100644
--- a/src/runtime/contrib/nvshmem/init.cc
+++ b/src/runtime/contrib/nvshmem/init.cc
@@ -106,6 +106,18 @@ void InitNVSHMEMWrapper(String args) {
   InitNVSHMEM(uid_64, num_workers, worker_id_start);
 }
 
+void NVSHMEMXCumoduleInit(void* cuModule) {
+  CUmodule mod = static_cast<CUmodule>(cuModule);
+  auto status = nvshmemx_init_status();
+  // The NVSHMEM library must have completed device initialization prior to
+  // nvshmemx_cumodule_init. If not, we skip the cumodule initialization.
+  if (status == NVSHMEM_STATUS_IS_INITIALIZED || status == NVSHMEM_STATUS_LIMITED_MPG ||
+      status == NVSHMEM_STATUS_FULL_MPG) {
+    int result = nvshmemx_cumodule_init(mod);
+    ICHECK_EQ(result, 0) << "nvshmemx_cumodule_init failed with error code: " << result;
+  }
+}
+
 TVM_FFI_REGISTER_GLOBAL("runtime.disco.nvshmem.init_nvshmem_uid").set_body_typed(InitNVSHMEMUID);
 
 TVM_FFI_REGISTER_GLOBAL("runtime.disco.nvshmem.init_nvshmem").set_body_typed(InitNVSHMEM);
@@ -113,5 +125,7 @@ TVM_FFI_REGISTER_GLOBAL("runtime.disco.nvshmem.init_nvshmem").set_body_typed(Ini
 TVM_FFI_REGISTER_GLOBAL("runtime.disco.nvshmem.init_nvshmem_wrapper")
     .set_body_typed(InitNVSHMEMWrapper);
 
+TVM_FFI_REGISTER_GLOBAL("runtime.nvshmem.cumodule_init").set_body_typed(NVSHMEMXCumoduleInit);
+
 }  // namespace runtime
 }  // namespace tvm
diff --git a/src/runtime/cuda/cuda_module.cc b/src/runtime/cuda/cuda_module.cc
index acb2dc6cdf11..a29d303acf7f 100644
--- a/src/runtime/cuda/cuda_module.cc
+++ b/src/runtime/cuda/cuda_module.cc
@@ -108,6 +108,10 @@ class CUDAModuleNode : public runtime::ModuleNode {
     // must recheck under the lock scope
     if (module_[device_id] == nullptr) {
       CUDA_DRIVER_CALL(cuModuleLoadData(&(module_[device_id]), data_.c_str()));
+      static auto nvshmem_init_hook = ffi::Function::GetGlobal("runtime.nvshmem.cumodule_init");
+      if (nvshmem_init_hook.has_value()) {
+        (*nvshmem_init_hook)(static_cast<void*>(module_[device_id]));
+      }
     }
     CUfunction func;
     CUresult result = cuModuleGetFunction(&func, module_[device_id], func_name.c_str());
@@ -124,6 +128,10 @@
     // must recheck under the lock scope
     if (module_[device_id] == nullptr) {
       CUDA_DRIVER_CALL(cuModuleLoadData(&(module_[device_id]), data_.c_str()));
+      static auto nvshmem_init_hook = ffi::Function::GetGlobal("runtime.nvshmem.cumodule_init");
+      if (nvshmem_init_hook.has_value()) {
+        (*nvshmem_init_hook)(static_cast<void*>(module_[device_id]));
+      }
     }
     CUdeviceptr global;
     size_t nbytes;
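The hook above is looked up once per process (a function-local `static`) and is simply absent when TVM is built without NVSHMEM, in which case module loading proceeds as before. The same presence check that `compile_cuda` performs can be reproduced from Python:

    import tvm

    hook = tvm.get_global_func("runtime.nvshmem.cumodule_init", allow_missing=True)
    print("NVSHMEM cumodule init registered:", hook is not None)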
diff --git a/src/runtime/disco/builtin.cc b/src/runtime/disco/builtin.cc
index 9cd76e673a45..9cd4c5eda4af 100644
--- a/src/runtime/disco/builtin.cc
+++ b/src/runtime/disco/builtin.cc
@@ -46,26 +46,27 @@ class DSOLibraryCache {
   std::mutex mutex_;
 };
 
-Module LoadVMModule(std::string path, Device device) {
+Module LoadVMModule(std::string path, Optional<Device> device) {
   static DSOLibraryCache cache;
   Module dso_mod = cache.Open(path);
-  device = UseDefaultDeviceIfNone(device);
+  Device dev = UseDefaultDeviceIfNone(device);
   ffi::Function vm_load_executable = dso_mod.GetFunction("vm_load_executable");
-  CHECK(vm_load_executable != nullptr)
-      << "ValueError: File `" << path
-      << "` is not built by RelaxVM, because `vm_load_executable` does not exist";
+  if (vm_load_executable == nullptr) {
+    // not built by RelaxVM, return the dso_mod directly
+    return dso_mod;
+  }
   auto mod = vm_load_executable().cast<Module>();
   ffi::Function vm_initialization = mod.GetFunction("vm_initialization");
   CHECK(vm_initialization != nullptr)
       << "ValueError: File `" << path
       << "` is not built by RelaxVM, because `vm_initialization` does not exist";
-  vm_initialization(static_cast<int>(device.device_type), static_cast<int>(device.device_id),
+  vm_initialization(static_cast<int>(dev.device_type), static_cast<int>(dev.device_id),
                     static_cast<int>(AllocatorType::kPooled), static_cast<int>(kDLCPU), 0,
                     static_cast<int>(AllocatorType::kPooled));
   return mod;
 }
 
-NDArray DiscoEmptyNDArray(ffi::Shape shape, DataType dtype, Device device) {
+NDArray DiscoEmptyNDArray(ffi::Shape shape, DataType dtype, Optional<Device> device) {
   return NDArray::Empty(shape, dtype, UseDefaultDeviceIfNone(device));
 }
@@ -123,7 +124,7 @@ void SyncWorker() {
 
 TVM_FFI_REGISTER_GLOBAL("runtime.disco.load_vm_module").set_body_typed(LoadVMModule);
 TVM_FFI_REGISTER_GLOBAL("runtime.disco.empty")
-    .set_body_typed([](ffi::Shape shape, DataType dtype, Device device, bool worker0_only,
+    .set_body_typed([](ffi::Shape shape, DataType dtype, Optional<Device> device, bool worker0_only,
                        bool in_group) -> Optional<NDArray> {
       int worker_id = WorkerId();
       int group_size =
diff --git a/src/runtime/disco/utils.h b/src/runtime/disco/utils.h
index fa58c73aa787..f0a10b6093d4 100644
--- a/src/runtime/disco/utils.h
+++ b/src/runtime/disco/utils.h
@@ -27,11 +27,8 @@
 namespace tvm {
 namespace runtime {
 
-inline Device UseDefaultDeviceIfNone(Device device) {
-  if (device.device_type == 0 && device.device_id == 0) {
-    return DiscoWorker::ThreadLocal()->default_device;
-  }
-  return device;
+inline Device UseDefaultDeviceIfNone(Optional<Device> device) {
+  return device.value_or(DiscoWorker::ThreadLocal()->default_device);
 }
 
 /*!
diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc
index be49fd39bd18..d4e1b785b866 100644
--- a/src/target/source/codegen_cuda.cc
+++ b/src/target/source/codegen_cuda.cc
@@ -297,19 +297,10 @@ std::string CodeGenCUDA::Finish() {
   decl_stream << "#define TVM_ENABLE_L2_PREFETCH 0\n";
   decl_stream << "#endif\n";
 
-  decl_stream << "\n#ifdef _WIN32\n";
-  decl_stream << "  using uint = unsigned int;\n";
-  decl_stream << "  using uchar = unsigned char;\n";
-  decl_stream << "  using ushort = unsigned short;\n";
-  decl_stream << "  using int64_t = long long;\n";
-  decl_stream << "  using uint64_t = unsigned long long;\n";
-  decl_stream << "#else\n";
-  decl_stream << "  #define uint unsigned int\n";
-  decl_stream << "  #define uchar unsigned char\n";
-  decl_stream << "  #define ushort unsigned short\n";
-  decl_stream << "  #define int64_t long long\n";
-  decl_stream << "  #define uint64_t unsigned long long\n";
-  decl_stream << "#endif\n";
+  decl_stream << "#include <cstdint>\n";
+  decl_stream << "using uint = unsigned int;\n";
+  decl_stream << "using uchar = unsigned char;\n";
+  decl_stream << "using ushort = unsigned short;\n";
 
   return CodeGenC::Finish();
 }
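A quick sanity check of the new preamble; a sketch that assumes a CUDA-enabled build, and the exact accessor for the imported device module may differ across TVM versions:

    import tvm
    from tvm.script import tir as T

    @T.prim_func
    def add_one(A: T.Buffer((16,), "float32"), B: T.Buffer((16,), "float32")):
        for i in T.thread_binding(16, thread="threadIdx.x"):
            with T.block("B"):
                v = T.axis.spatial(16, i)
                B[v] = A[v] + T.float32(1)

    lib = tvm.compile(add_one, target="cuda")
    rt_mod = getattr(lib, "mod", lib)  # unwrap in case compile returns a wrapper
    cuda_src = rt_mod.imported_modules[0].get_source()
    assert "#include <cstdint>" in cuda_src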
diff --git a/tests/python/disco/test_nvshmem.py b/tests/python/disco/test_nvshmem.py
index 1c4ffc9c4d08..d9976e05e50b 100644
--- a/tests/python/disco/test_nvshmem.py
+++ b/tests/python/disco/test_nvshmem.py
@@ -26,6 +26,8 @@
 from multiprocessing import Process
 from typing import Any, Callable, List
 
+from tvm.script import tir as T
+
 import tvm
 import tvm.testing
 
@@ -134,6 +136,55 @@ def test_nvshmem_empty(session_kind: di.Session, num_workers: int):
     sess.sync_worker_0()
 
 
+def test_nvshmem_compile():
+    if tvm.get_global_func("runtime.disco.nvshmem.init_nvshmem_uid", True) is None:
+        return
+
+    num_workers = 4
+    sess = di.ProcessSession(num_workers=num_workers)
+
+    f_init_nvshmem_uid = tvm.get_global_func("runtime.disco.nvshmem.init_nvshmem_uid")
+    uid = f_init_nvshmem_uid()
+    init_dfunc = sess.get_global_func("runtime.disco.nvshmem.init_nvshmem")
+    init_dfunc(uid, num_workers, 0)
+    sess.sync_worker_0()
+
+    @T.prim_func
+    def main(A: T.Buffer((8, 16), "float32"), B: T.Buffer((16, 8), "float32")):
+        for i in T.thread_binding(T.int64(8), thread="threadIdx.y"):
+            for j in T.thread_binding(T.int64(16), thread="threadIdx.x"):
+                with T.block("T_transpose"):
+                    v0 = T.axis.spatial(T.int64(8), i)
+                    v1 = T.axis.spatial(T.int64(16), j)
+                    T.reads(A[v0, v1])
+                    T.writes(B[v1, v0])
+                    B[v1, v0] = A[v0, v1]
+
+    with tempfile.TemporaryDirectory() as tmpdir:
+        path = tmpdir + "/test.so"
+        A_np = np.arange(8 * 16).astype("float32").reshape([8, 16])
+        B_np = np.zeros((16, 8), dtype="float32")
+        A_array = sess.empty(A_np.shape, "float32")
+        B_array = sess.empty(B_np.shape, "float32")
+        A_array.debug_copy_from(0, A_np)
+
+        target = tvm.target.Target("cuda")
+        tvm.compile(main, target=target).export_library(path)
+        mod = sess.load_vm_module(path)
+        mod["main"](A_array, B_array)
+
+        B_res = B_array.debug_get_from_remote(0).numpy()
+        np.testing.assert_equal(B_res, A_np.T)
+
+        # sync all workers to make sure the temporary files are cleaned up
+        # only after all workers finish the execution
+        sess._sync_all()
+
+    finalize_dfunc = sess.get_global_func("runtime.disco.nvshmem.finalize_nvshmem")
+    finalize_dfunc()
+    sess.sync_worker_0()
+
+
 if __name__ == "__main__":
     # After the first call to `nvshmem_init`, a subsequent call to `nvshmem_init`
     # or `nvshmem_init_thread` in the same program results in undefined behavior.
@@ -145,3 +196,8 @@ def test_nvshmem_empty(session_kind: di.Session, num_workers: int):
                 p = Process(target=test_func, args=[session_kind, num_workers])
                 p.start()
                 p.join()
+
+    # testing compilation flow
+    p = Process(target=test_nvshmem_compile)
+    p.start()
+    p.join()
diff --git a/tests/python/tir-transform/test_tir_transform_inject_ptx_async_copy.py b/tests/python/tir-transform/test_tir_transform_inject_ptx_async_copy.py
index 13487b42f00f..f620610f3977 100644
--- a/tests/python/tir-transform/test_tir_transform_inject_ptx_async_copy.py
+++ b/tests/python/tir-transform/test_tir_transform_inject_ptx_async_copy.py
@@ -254,20 +254,10 @@ def test_inject_async_copy_barrier():
 #else
 #define TVM_ENABLE_L2_PREFETCH 0
 #endif
-
-#ifdef _WIN32
-  using uint = unsigned int;
-  using uchar = unsigned char;
-  using ushort = unsigned short;
-  using int64_t = long long;
-  using uint64_t = unsigned long long;
-#else
-  #define uint unsigned int
-  #define uchar unsigned char
-  #define ushort unsigned short
-  #define int64_t long long
-  #define uint64_t unsigned long long
-#endif
+#include <cstdint>
+using uint = unsigned int;
+using uchar = unsigned char;
+using ushort = unsigned short;
 extern "C" __global__ void __launch_bounds__(16) main_kernel(float* __restrict__ A, float* __restrict__ B, float* __restrict__ C);
 extern "C" __global__ void __launch_bounds__(16) main_kernel(float* __restrict__ A, float* __restrict__ B, float* __restrict__ C) {
   __shared__ float A_shared[64];
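Running the new compile test end to end requires an NVSHMEM-enabled build and a discoverable NVSHMEM installation; `find_nvshmem_paths()` consults `NVSHMEM_HOME` first, then the CUDA toolkit and common system prefixes. A sketch, where `/opt/nvshmem` is an illustrative install prefix:

    import os

    os.environ["NVSHMEM_HOME"] = "/opt/nvshmem"  # set before any compilation runs

    from tvm.contrib.nvcc import find_nvshmem_paths
    include_dir, lib_dir = find_nvshmem_paths()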