@@ -7,6 +7,7 @@
 from typing import Optional, Union

 from cuda import cuda
+from cuda.core.experimental._device import Device
 from cuda.core.experimental._kernel_arg_handler import ParamHolder
 from cuda.core.experimental._module import Kernel
 from cuda.core.experimental._stream import Stream
@@ -38,10 +39,14 @@ class LaunchConfig:
     ----------
     grid : Union[tuple, int]
         Collection of threads that will execute a kernel function.
+    cluster : Union[tuple, int]
+        Group of blocks (Thread Block Cluster) that will execute on the same
+        GPU Processing Cluster (GPC). Blocks within a cluster have access to
+        distributed shared memory and can be explicitly synchronized.
     block : Union[tuple, int]
         Group of threads (Thread Block) that will execute on the same
-        multiprocessor. Threads within a thread blocks have access to
-        shared memory and can be explicitly synchronized.
+        streaming multiprocessor (SM). Threads within a thread block have
+        access to shared memory and can be explicitly synchronized.
     stream : :obj:`Stream`
         The stream establishing the stream ordering semantic of a
         launch.
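For context (not part of the diff), a minimal sketch of how the new parameter reads at the call site. The grid/block shapes are illustrative assumptions, not values from this patch:

```python
# Hypothetical usage sketch: group every two thread blocks into one
# thread block cluster (requires Hopper or newer); shapes are examples.
from cuda.core.experimental import LaunchConfig

config = LaunchConfig(
    grid=(128, 1, 1),   # number of thread blocks in the grid
    cluster=(2, 1, 1),  # blocks per thread block cluster
    block=(256, 1, 1),  # threads per block
    shmem_size=0,
)
```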
@@ -53,13 +58,22 @@

     # TODO: expand LaunchConfig to include other attributes
     grid: Union[tuple, int] = None
+    cluster: Union[tuple, int] = None
     block: Union[tuple, int] = None
     stream: Stream = None
     shmem_size: Optional[int] = None

     def __post_init__(self):
+        _lazy_init()
         self.grid = self._cast_to_3_tuple(self.grid)
         self.block = self._cast_to_3_tuple(self.block)
+        # thread block clusters are supported starting H100
+        if self.cluster is not None:
+            if not _use_ex:
+                raise CUDAError("thread block clusters require cuda.bindings & driver 11.8+")
+            if Device().compute_capability < (9, 0):
+                raise CUDAError("thread block clusters are not supported below Hopper")
+            self.cluster = self._cast_to_3_tuple(self.cluster)
         # we handle "stream=None" in the launch API
         if self.stream is not None and not isinstance(self.stream, Stream):
            try:
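The validation added above can be sketched from the caller's side. Assuming `CUDAError` propagates as in the patch, a pre-Hopper device rejects any cluster config at construction time:

```python
# Sketch of the new __post_init__ gate: passing `cluster` on a device
# with compute capability < 9.0 should raise before any launch happens.
from cuda.core.experimental import Device, LaunchConfig

dev = Device()
dev.set_current()
if dev.compute_capability < (9, 0):
    try:
        LaunchConfig(grid=4, cluster=2, block=32)
    except Exception as e:  # CUDAError in the patch
        print("rejected as expected:", e)
else:
    cfg = LaunchConfig(grid=4, cluster=2, block=32)  # accepted on Hopper+
```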
@@ -69,8 +83,6 @@ def __post_init__(self):
         if self.shmem_size is None:
             self.shmem_size = 0

-        _lazy_init()
-
     def _cast_to_3_tuple(self, cfg):
         if isinstance(cfg, int):
             if cfg < 1:
@@ -134,6 +146,12 @@ def launch(kernel, config, *kernel_args):
         drv_cfg.hStream = config.stream.handle
         drv_cfg.sharedMemBytes = config.shmem_size
         drv_cfg.numAttrs = 0  # TODO
+        if config.cluster:
+            drv_cfg.numAttrs += 1
+            attr = cuda.CUlaunchAttribute()
+            attr.id = cuda.CUlaunchAttributeID.CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION
+            attr.value.clusterDim.x, attr.value.clusterDim.y, attr.value.clusterDim.z = config.cluster
+            drv_cfg.attrs.append(attr)
         handle_return(cuda.cuLaunchKernelEx(drv_cfg, int(kernel._handle), args_ptr, 0))
     else:
         # TODO: check if config has any unsupported attrs
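Putting it together, a hedged end-to-end sketch of launching through the new path. The kernel source, `Program`/`compile` calls, and names are assumptions about the surrounding cuda.core API, not part of this patch:

```python
# Hypothetical end-to-end sketch: compile a trivial kernel and launch it
# with a cluster configuration; all names here are placeholders.
from cuda.core.experimental import Device, LaunchConfig, Program, launch

dev = Device()
dev.set_current()
stream = dev.create_stream()

code = 'extern "C" __global__ void noop() {}'
prog = Program(code, "c++")
mod = prog.compile("cubin", name_expressions=("noop",))
kernel = mod.get_kernel("noop")

config = LaunchConfig(grid=8, cluster=2, block=64, stream=stream, shmem_size=0)
launch(kernel, config)  # dispatches via cuLaunchKernelEx when clusters are set
dev.sync()
```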