Skip to content

Commit 5d253f1

Browse files
committed
add support for cluster
1 parent 6af4da3 commit 5d253f1

File tree

1 file changed

+22
-4
lines changed

1 file changed

+22
-4
lines changed

cuda_core/cuda/core/experimental/_launcher.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from typing import Optional, Union
88

99
from cuda import cuda
10+
from cuda.core.experimental._device import Device
1011
from cuda.core.experimental._kernel_arg_handler import ParamHolder
1112
from cuda.core.experimental._module import Kernel
1213
from cuda.core.experimental._stream import Stream
@@ -38,10 +39,14 @@ class LaunchConfig:
3839
----------
3940
grid : Union[tuple, int]
4041
Collection of threads that will execute a kernel function.
42+
cluster : Union[tuple, int]
43+
Group of blocks (Thread Block Cluster) that will execute on the same
44+
GPU Processing Cluster (GPC). Blocks within a cluster have access to
45+
distributed shared memory and can be explicitly synchronized.
4146
block : Union[tuple, int]
4247
Group of threads (Thread Block) that will execute on the same
43-
multiprocessor. Threads within a thread blocks have access to
44-
shared memory and can be explicitly synchronized.
48+
streaming multiprocessor (SM). Threads within a thread block have
49+
access to shared memory and can be explicitly synchronized.
4550
stream : :obj:`Stream`
4651
The stream establishing the stream ordering semantic of a
4752
launch.
@@ -53,13 +58,22 @@ class LaunchConfig:
5358

5459
# TODO: expand LaunchConfig to include other attributes
5560
grid: Union[tuple, int] = None
61+
cluster: Union[tuple, int] = None
5662
block: Union[tuple, int] = None
5763
stream: Stream = None
5864
shmem_size: Optional[int] = None
5965

6066
def __post_init__(self):
67+
_lazy_init()
6168
self.grid = self._cast_to_3_tuple(self.grid)
6269
self.block = self._cast_to_3_tuple(self.block)
70+
# thread block clusters are supported starting with H100
71+
if self.cluster is not None:
72+
if not _use_ex:
73+
raise CUDAError("thread block clusters require cuda.bindings & driver 11.8+")
74+
if Device().compute_capability < (9, 0):
75+
raise CUDAError("thread block clusters are not supported below Hopper")
76+
self.cluster = self._cast_to_3_tuple(self.cluster)
6377
# we handle "stream=None" in the launch API
6478
if self.stream is not None and not isinstance(self.stream, Stream):
6579
try:
@@ -69,8 +83,6 @@ def __post_init__(self):
6983
if self.shmem_size is None:
7084
self.shmem_size = 0
7185

72-
_lazy_init()
73-
7486
def _cast_to_3_tuple(self, cfg):
7587
if isinstance(cfg, int):
7688
if cfg < 1:
@@ -134,6 +146,12 @@ def launch(kernel, config, *kernel_args):
134146
drv_cfg.hStream = config.stream.handle
135147
drv_cfg.sharedMemBytes = config.shmem_size
136148
drv_cfg.numAttrs = 0 # TODO
149+
if config.cluster:
150+
drv_cfg.numAttrs += 1
151+
attr = cuda.CUlaunchAttribute()
152+
attr.id = cuda.CUlaunchAttributeID.CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION
153+
attr.value.clusterDim.x, attr.value.clusterDim.y, attr.value.clusterDim.z = config.cluster
154+
drv_cfg.attrs.append(attr)
137155
handle_return(cuda.cuLaunchKernelEx(drv_cfg, int(kernel._handle), args_ptr, 0))
138156
else:
139157
# TODO: check if config has any unsupported attrs

0 commit comments

Comments
 (0)