From 5d253f18de9b8844b60b4d441e9d69f3af19e73b Mon Sep 17 00:00:00 2001 From: Leo Fang Date: Tue, 3 Dec 2024 04:37:11 +0000 Subject: [PATCH 1/7] add support for cluster --- cuda_core/cuda/core/experimental/_launcher.py | 26 ++++++++++++++++--- 1 file changed, 22 insertions(+), 4 deletions(-) diff --git a/cuda_core/cuda/core/experimental/_launcher.py b/cuda_core/cuda/core/experimental/_launcher.py index 55af5e30..157705cb 100644 --- a/cuda_core/cuda/core/experimental/_launcher.py +++ b/cuda_core/cuda/core/experimental/_launcher.py @@ -7,6 +7,7 @@ from typing import Optional, Union from cuda import cuda +from cuda.core.experimental._device import Device from cuda.core.experimental._kernel_arg_handler import ParamHolder from cuda.core.experimental._module import Kernel from cuda.core.experimental._stream import Stream @@ -38,10 +39,14 @@ class LaunchConfig: ---------- grid : Union[tuple, int] Collection of threads that will execute a kernel function. + cluster : Union[tuple, int] + Group of blocks (Thread Block Cluster) that will execute on the same + GPU Processing Cluster (GPC). Blocks within a cluster have access to + distributed shared memory and can be explicitly synchronized. block : Union[tuple, int] Group of threads (Thread Block) that will execute on the same - multiprocessor. Threads within a thread blocks have access to - shared memory and can be explicitly synchronized. + streaming multiprocessor (SM). Threads within a thread block have + access to shared memory and can be explicitly synchronized. stream : :obj:`Stream` The stream establishing the stream ordering semantic of a launch. 
@@ -53,13 +58,22 @@ class LaunchConfig: # TODO: expand LaunchConfig to include other attributes grid: Union[tuple, int] = None + cluster: Union[tuple, int] = None block: Union[tuple, int] = None stream: Stream = None shmem_size: Optional[int] = None def __post_init__(self): + _lazy_init() self.grid = self._cast_to_3_tuple(self.grid) self.block = self._cast_to_3_tuple(self.block) + # thread block clusters are supported starting H100 + if self.cluster is not None: + if not _use_ex: + raise CUDAError("thread block clusters require cuda.bindings & driver 11.8+") + if Device().compute_capability < (9, 0): + raise CUDAError("thread block clusters are not supported below Hopper") + self.cluster = self._cast_to_3_tuple(self.cluster) # we handle "stream=None" in the launch API if self.stream is not None and not isinstance(self.stream, Stream): try: @@ -69,8 +83,6 @@ def __post_init__(self): if self.shmem_size is None: self.shmem_size = 0 - _lazy_init() - def _cast_to_3_tuple(self, cfg): if isinstance(cfg, int): if cfg < 1: @@ -134,6 +146,12 @@ def launch(kernel, config, *kernel_args): drv_cfg.hStream = config.stream.handle drv_cfg.sharedMemBytes = config.shmem_size drv_cfg.numAttrs = 0 # TODO + if config.cluster: + drv_cfg.numAttrs += 1 + attr = cuda.CUlaunchAttribute() + attr.id = cuda.CUlaunchAttributeID.CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION + attr.value.clusterDim.x, attr.value.clusterDim.y, attr.value.clusterDim.z = config.cluster + drv_cfg.attrs.append(attr) handle_return(cuda.cuLaunchKernelEx(drv_cfg, int(kernel._handle), args_ptr, 0)) else: # TODO: check if config has any unsupported attrs From 4abe5206ac75224a1e03a1198a10f1daa5eec2f0 Mon Sep 17 00:00:00 2001 From: Leo Fang Date: Mon, 2 Dec 2024 21:16:59 -0800 Subject: [PATCH 2/7] add a code sample; apply a WAR to a potential bug --- cuda_core/cuda/core/experimental/_launcher.py | 2 +- cuda_core/examples/thread_block_cluster.py | 61 +++++++++++++++++++ 2 files changed, 62 insertions(+), 1 deletion(-) create mode 
100644 cuda_core/examples/thread_block_cluster.py diff --git a/cuda_core/cuda/core/experimental/_launcher.py b/cuda_core/cuda/core/experimental/_launcher.py index 157705cb..c7cb36ef 100644 --- a/cuda_core/cuda/core/experimental/_launcher.py +++ b/cuda_core/cuda/core/experimental/_launcher.py @@ -151,7 +151,7 @@ def launch(kernel, config, *kernel_args): attr = cuda.CUlaunchAttribute() attr.id = cuda.CUlaunchAttributeID.CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION attr.value.clusterDim.x, attr.value.clusterDim.y, attr.value.clusterDim.z = config.cluster - drv_cfg.attrs.append(attr) + drv_cfg.attrs = [attr] # TODO: WHAT!! handle_return(cuda.cuLaunchKernelEx(drv_cfg, int(kernel._handle), args_ptr, 0)) else: # TODO: check if config has any unsupported attrs diff --git a/cuda_core/examples/thread_block_cluster.py b/cuda_core/examples/thread_block_cluster.py new file mode 100644 index 00000000..e933c54d --- /dev/null +++ b/cuda_core/examples/thread_block_cluster.py @@ -0,0 +1,61 @@ +# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED. 
+# +# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE + +import os +import sys + +from cuda.core.experimental import Device, LaunchConfig, Program, launch + + +# prepare include +cuda_path = os.environ.get("CUDA_PATH", os.environ.get("CUDA_HOME")) +if cuda_path is None: + print("this demo requires a valid CUDA_PATH environment variable set") + sys.exit(0) +cuda_include_path = os.path.join(cuda_path, "include") + +# print cluster info +code = r""" +#include <cooperative_groups.h> + +namespace cg = cooperative_groups; + +extern "C" +__global__ void check_cluster_info() { + auto g = cg::this_grid(); + auto b = cg::this_thread_block(); + if (g.cluster_rank() == 0 && g.block_rank() == 0 && g.thread_rank() == 0) { + printf("grid dim: (%u, %u, %u)\n", g.dim_blocks().x, g.dim_blocks().y, g.dim_blocks().z); + printf("cluster dim: (%u, %u, %u)\n", g.dim_clusters().x, g.dim_clusters().y, g.dim_clusters().z); + printf("block dim: (%u, %u, %u)\n", b.dim_threads().x, b.dim_threads().y, b.dim_threads().z); + } +} +""" + +dev = Device() +dev.set_current() +arch = "".join(f"{i}" for i in dev.compute_capability) + +# prepare program +prog = Program(code, code_type="c++") +mod = prog.compile( + target_type="cubin", + # TODO: update this after NVIDIA/cuda-python#237 is merged + options=(f"-arch=sm_{arch}", "-std=c++17", f"-I{cuda_include_path}"), +) + +# run in single precision +ker = mod.get_kernel("check_cluster_info") + +# prepare launch +grid = 4 +cluster = 2 +block = 32 +config = LaunchConfig(grid=grid, cluster=cluster, block=block, stream=dev.default_stream) + +# launch kernel on the default stream +launch(ker, config) +dev.sync() + +print("done!") From 895c9fad0a9b0021a79a44ad34bcb8e9238e93df Mon Sep 17 00:00:00 2001 From: Leo Fang Date: Wed, 4 Dec 2024 03:08:08 +0000 Subject: [PATCH 3/7] more robust treatments --- cuda_core/cuda/core/experimental/_launcher.py | 10 +- cuda_core/examples/thread_block_cluster.py | 13 +- cuda_core/tests/example_tests/utils.py | 115 +++++++++--------- 3 files 
changed, 73 insertions(+), 65 deletions(-) diff --git a/cuda_core/cuda/core/experimental/_launcher.py b/cuda_core/cuda/core/experimental/_launcher.py index c7cb36ef..de4e7649 100644 --- a/cuda_core/cuda/core/experimental/_launcher.py +++ b/cuda_core/cuda/core/experimental/_launcher.py @@ -145,13 +145,15 @@ def launch(kernel, config, *kernel_args): drv_cfg.blockDimX, drv_cfg.blockDimY, drv_cfg.blockDimZ = config.block drv_cfg.hStream = config.stream.handle drv_cfg.sharedMemBytes = config.shmem_size - drv_cfg.numAttrs = 0 # TODO + attrs = [] # TODO: support more attributes if config.cluster: - drv_cfg.numAttrs += 1 attr = cuda.CUlaunchAttribute() attr.id = cuda.CUlaunchAttributeID.CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION - attr.value.clusterDim.x, attr.value.clusterDim.y, attr.value.clusterDim.z = config.cluster - drv_cfg.attrs = [attr] # TODO: WHAT!! + dim = attr.value.clusterDim + dim.x, dim.y, dim.z = config.cluster + attrs.append(attr) + drv_cfg.numAttrs = len(attrs) + drv_cfg.attrs = attrs handle_return(cuda.cuLaunchKernelEx(drv_cfg, int(kernel._handle), args_ptr, 0)) else: # TODO: check if config has any unsupported attrs diff --git a/cuda_core/examples/thread_block_cluster.py b/cuda_core/examples/thread_block_cluster.py index e933c54d..74181469 100644 --- a/cuda_core/examples/thread_block_cluster.py +++ b/cuda_core/examples/thread_block_cluster.py @@ -7,15 +7,14 @@ from cuda.core.experimental import Device, LaunchConfig, Program, launch - # prepare include cuda_path = os.environ.get("CUDA_PATH", os.environ.get("CUDA_HOME")) if cuda_path is None: - print("this demo requires a valid CUDA_PATH environment variable set") + print("this demo requires a valid CUDA_PATH environment variable set", file=sys.stderr) sys.exit(0) cuda_include_path = os.path.join(cuda_path, "include") -# print cluster info +# print cluster info using a kernel code = r""" #include <cooperative_groups.h> @@ -35,7 +34,11 @@ dev = Device() dev.set_current() -arch = "".join(f"{i}" for i in dev.compute_capability) +arch = 
dev.compute_capability +if arch < (9, 0): + print("this demo requires a Hopper GPU (since thread block cluster is a hardware feature)", file=sys.stderr) + sys.exit(0) +arch = "".join(f"{i}" for i in arch) # prepare program prog = Program(code, code_type="c++") @@ -48,7 +51,7 @@ # run in single precision ker = mod.get_kernel("check_cluster_info") -# prepare launch +# prepare launch config grid = 4 cluster = 2 block = 32 diff --git a/cuda_core/tests/example_tests/utils.py b/cuda_core/tests/example_tests/utils.py index f6ac3e15..731adedb 100644 --- a/cuda_core/tests/example_tests/utils.py +++ b/cuda_core/tests/example_tests/utils.py @@ -1,56 +1,59 @@ -# Copyright 2024 NVIDIA Corporation. All rights reserved. -# -# Please refer to the NVIDIA end user license agreement (EULA) associated -# with this source code for terms and conditions that govern your use of -# this software. Any use, reproduction, disclosure, or distribution of -# this software and related documentation outside the terms of the EULA -# is strictly prohibited. 
- -import gc -import os -import sys - -import cupy as cp -import pytest - - -class SampleTestError(Exception): - pass - - -def parse_python_script(filepath): - if not filepath.endswith(".py"): - raise ValueError(f"{filepath} not supported") - with open(filepath, encoding="utf-8") as f: - script = f.read() - return script - - -def run_example(samples_path, filename, env=None): - fullpath = os.path.join(samples_path, filename) - script = parse_python_script(fullpath) - try: - old_argv = sys.argv - sys.argv = [fullpath] - old_sys_path = sys.path.copy() - sys.path.append(samples_path) - exec(script, env if env else {}) - except ImportError as e: - # for samples requiring any of optional dependencies - for m in ("cupy",): - if f"No module named '{m}'" in str(e): - pytest.skip(f"{m} not installed, skipping related tests") - break - else: - raise - except Exception as e: - msg = "\n" - msg += f"Got error ({filename}):\n" - msg += str(e) - raise SampleTestError(msg) from e - finally: - sys.path = old_sys_path - sys.argv = old_argv - # further reduce the memory watermark - gc.collect() - cp.get_default_memory_pool().free_all_blocks() +# Copyright 2024 NVIDIA Corporation. All rights reserved. +# +# Please refer to the NVIDIA end user license agreement (EULA) associated +# with this source code for terms and conditions that govern your use of +# this software. Any use, reproduction, disclosure, or distribution of +# this software and related documentation outside the terms of the EULA +# is strictly prohibited. 
+ +import gc +import os +import sys + +import cupy as cp +import pytest + + +class SampleTestError(Exception): + pass + + +def parse_python_script(filepath): + if not filepath.endswith(".py"): + raise ValueError(f"{filepath} not supported") + with open(filepath, encoding="utf-8") as f: + script = f.read() + return script + + +def run_example(samples_path, filename, env=None): + fullpath = os.path.join(samples_path, filename) + script = parse_python_script(fullpath) + try: + old_argv = sys.argv + sys.argv = [fullpath] + old_sys_path = sys.path.copy() + sys.path.append(samples_path) + exec(script, env if env else {}) + except ImportError as e: + # for samples requiring any of optional dependencies + for m in ("cupy",): + if f"No module named '{m}'" in str(e): + pytest.skip(f"{m} not installed, skipping related tests") + break + else: + raise + except SystemExit: + # for samples that early return due to any missing requirements + pytest.skip(f"skip {filename}") + except Exception as e: + msg = "\n" + msg += f"Got error ({filename}):\n" + msg += str(e) + raise SampleTestError(msg) from e + finally: + sys.path = old_sys_path + sys.argv = old_argv + # further reduce the memory watermark + gc.collect() + cp.get_default_memory_pool().free_all_blocks() From a003b984ef5d7a278bbf4163de6a6c5cb95a44ae Mon Sep 17 00:00:00 2001 From: Leo Fang Date: Wed, 4 Dec 2024 03:18:42 +0000 Subject: [PATCH 4/7] add release note entries --- cuda_core/docs/source/release/0.1.1-notes.md | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/cuda_core/docs/source/release/0.1.1-notes.md b/cuda_core/docs/source/release/0.1.1-notes.md index 473352a4..db5bbcda 100644 --- a/cuda_core/docs/source/release/0.1.1-notes.md +++ b/cuda_core/docs/source/release/0.1.1-notes.md @@ -3,9 +3,24 @@ Released on Dec XX, 2024 ## Hightlights + - Add `StridedMemoryView` and `@args_viewable_as_strided_memory` that provide a concrete implementation of DLPack & CUDA Array Interface supports. 
+## New features + +- Add `LaunchConfig.cluster` to support thread block clusters on Hopper GPUs. + +## Enhancements + +- Ensure "ltoir" is a valid code type to `ObjectCode`. +- Improve test coverage. +- Enforce code formatting. + +## Bug fixes + +- Eliminate potential class destruction issues. +- Fix circular import when handling a foreign CUDA stream. ## Limitations From 20692de60649aa9e6d513736453a4c187f34880c Mon Sep 17 00:00:00 2001 From: Leo Fang Date: Tue, 3 Dec 2024 20:03:02 -0800 Subject: [PATCH 5/7] fix invalid context during test teardown --- cuda_core/tests/conftest.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/cuda_core/tests/conftest.py b/cuda_core/tests/conftest.py index 59e5883f..b67eeec2 100644 --- a/cuda_core/tests/conftest.py +++ b/cuda_core/tests/conftest.py @@ -30,6 +30,10 @@ def init_cuda(): def _device_unset_current(): + ctx = handle_return(driver.cuCtxGetCurrent()) + if int(ctx) == 0: + # no active context, do nothing + return handle_return(driver.cuCtxPopCurrent()) with _device._tls_lock: del _device._tls.devices From 5ac409b6c7b07cb4c400320a056acb8170ef4101 Mon Sep 17 00:00:00 2001 From: Leo Fang Date: Tue, 3 Dec 2024 20:03:22 -0800 Subject: [PATCH 6/7] improve comments in the code sample --- cuda_core/examples/thread_block_cluster.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/cuda_core/examples/thread_block_cluster.py b/cuda_core/examples/thread_block_cluster.py index 74181469..ee76b31d 100644 --- a/cuda_core/examples/thread_block_cluster.py +++ b/cuda_core/examples/thread_block_cluster.py @@ -33,22 +33,20 @@ """ dev = Device() -dev.set_current() arch = dev.compute_capability if arch < (9, 0): print("this demo requires a Hopper GPU (since thread block cluster is a hardware feature)", file=sys.stderr) sys.exit(0) arch = "".join(f"{i}" for i in arch) -# prepare program +# prepare program & compile kernel +dev.set_current() prog = Program(code, code_type="c++") mod = prog.compile( 
target_type="cubin", # TODO: update this after NVIDIA/cuda-python#237 is merged options=(f"-arch=sm_{arch}", "-std=c++17", f"-I{cuda_include_path}"), ) - -# run in single precision ker = mod.get_kernel("check_cluster_info") # prepare launch config From b8004e9bff153e6025373c60c8bc96487fc13035 Mon Sep 17 00:00:00 2001 From: ksimpson Date: Fri, 13 Dec 2024 10:59:33 -0800 Subject: [PATCH 7/7] switch from chip chen to compute capability in comments --- cuda_core/cuda/core/experimental/_launcher.py | 2 +- cuda_core/examples/thread_block_cluster.py | 5 ++++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/cuda_core/cuda/core/experimental/_launcher.py b/cuda_core/cuda/core/experimental/_launcher.py index de4e7649..fa37b817 100644 --- a/cuda_core/cuda/core/experimental/_launcher.py +++ b/cuda_core/cuda/core/experimental/_launcher.py @@ -72,7 +72,7 @@ def __post_init__(self): if not _use_ex: raise CUDAError("thread block clusters require cuda.bindings & driver 11.8+") if Device().compute_capability < (9, 0): - raise CUDAError("thread block clusters are not supported below Hopper") + raise CUDAError("thread block clusters are not supported on devices with compute capability < 9.0") self.cluster = self._cast_to_3_tuple(self.cluster) # we handle "stream=None" in the launch API if self.stream is not None and not isinstance(self.stream, Stream): diff --git a/cuda_core/examples/thread_block_cluster.py b/cuda_core/examples/thread_block_cluster.py index ee76b31d..fa70738d 100644 --- a/cuda_core/examples/thread_block_cluster.py +++ b/cuda_core/examples/thread_block_cluster.py @@ -35,7 +35,10 @@ dev = Device() arch = dev.compute_capability if arch < (9, 0): - print("this demo requires a Hopper GPU (since thread block cluster is a hardware feature)", file=sys.stderr) + print( + "this demo requires compute capability >= 9.0 (since thread block cluster is a hardware feature)", + file=sys.stderr, + ) sys.exit(0) arch = "".join(f"{i}" for i in arch)