Add cluster to LaunchConfig to support thread block clusters on Hopper #261

Merged (13 commits, Dec 13, 2024)
30 changes: 25 additions & 5 deletions cuda_core/cuda/core/experimental/_launcher.py
@@ -7,6 +7,7 @@
from typing import Optional, Union

from cuda import cuda
from cuda.core.experimental._device import Device
from cuda.core.experimental._kernel_arg_handler import ParamHolder
from cuda.core.experimental._module import Kernel
from cuda.core.experimental._stream import Stream
@@ -38,10 +39,14 @@ class LaunchConfig:
----------
grid : Union[tuple, int]
Collection of threads that will execute a kernel function.
cluster : Union[tuple, int]
Group of blocks (Thread Block Cluster) that will execute on the same
GPU Processing Cluster (GPC). Blocks within a cluster have access to
distributed shared memory and can be explicitly synchronized.
block : Union[tuple, int]
Group of threads (Thread Block) that will execute on the same
multiprocessor. Threads within a thread block have access to
shared memory and can be explicitly synchronized.
streaming multiprocessor (SM). Threads within a thread block have
access to shared memory and can be explicitly synchronized.
stream : :obj:`Stream`
The stream establishing the stream ordering semantic of a
launch.
@@ -53,13 +58,22 @@ class LaunchConfig:

# TODO: expand LaunchConfig to include other attributes
grid: Union[tuple, int] = None
cluster: Union[tuple, int] = None
block: Union[tuple, int] = None
stream: Stream = None
shmem_size: Optional[int] = None

def __post_init__(self):
_lazy_init()
self.grid = self._cast_to_3_tuple(self.grid)
self.block = self._cast_to_3_tuple(self.block)
# thread block clusters are supported starting with H100
if self.cluster is not None:
if not _use_ex:
raise CUDAError("thread block clusters require cuda.bindings & driver 11.8+")
if Device().compute_capability < (9, 0):
raise CUDAError("thread block clusters are not supported on devices with compute capability < 9.0")
self.cluster = self._cast_to_3_tuple(self.cluster)
# we handle "stream=None" in the launch API
if self.stream is not None and not isinstance(self.stream, Stream):
try:
@@ -69,8 +83,6 @@ def __post_init__(self):
if self.shmem_size is None:
self.shmem_size = 0

_lazy_init()

def _cast_to_3_tuple(self, cfg):
if isinstance(cfg, int):
if cfg < 1:
@@ -133,7 +145,15 @@ def launch(kernel, config, *kernel_args):
drv_cfg.blockDimX, drv_cfg.blockDimY, drv_cfg.blockDimZ = config.block
drv_cfg.hStream = config.stream.handle
drv_cfg.sharedMemBytes = config.shmem_size
drv_cfg.numAttrs = 0 # TODO
attrs = [] # TODO: support more attributes
if config.cluster:
attr = cuda.CUlaunchAttribute()
attr.id = cuda.CUlaunchAttributeID.CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION
dim = attr.value.clusterDim
dim.x, dim.y, dim.z = config.cluster
attrs.append(attr)
drv_cfg.numAttrs = len(attrs)
drv_cfg.attrs = attrs
handle_return(cuda.cuLaunchKernelEx(drv_cfg, int(kernel._handle), args_ptr, 0))
else:
# TODO: check if config has any unsupported attrs
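
For quick reference, a minimal sketch of how the new `cluster` option is exercised through the public API (condensed from the example added in this PR; `ker` is assumed to be an already-compiled `Kernel`, and a device with compute capability 9.0 or higher is assumed):

```python
from cuda.core.experimental import Device, LaunchConfig, launch

dev = Device()
dev.set_current()

# grid is given in thread blocks; cluster=2 groups every 2 blocks into one
# thread block cluster (Hopper+, compute capability >= 9.0)
config = LaunchConfig(grid=4, cluster=2, block=32, stream=dev.default_stream)

launch(ker, config)  # `ker` obtained via Program(...).compile(...).get_kernel(...)
dev.sync()
```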
14 changes: 14 additions & 0 deletions cuda_core/docs/source/release/0.1.1-notes.md
@@ -12,6 +12,20 @@ Released on Dec XX, 2024
- Add a `cuda.core.experimental.system` module for querying system- or process- wide information.
- Support TCC devices with a default synchronous memory resource to avoid the use of memory pools

## New features

- Add `LaunchConfig.cluster` to support thread block clusters on Hopper GPUs.

## Enhancements

- Ensure "ltoir" is a valid code type to `ObjectCode`.
- Improve test coverage.
- Enforce code formatting.

## Bug fixes

- Eliminate potential class destruction issues.
- Fix circular import when handling a foreign CUDA stream.

## Limitations

65 changes: 65 additions & 0 deletions cuda_core/examples/thread_block_cluster.py
@@ -0,0 +1,65 @@
# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED.
#
# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE

import os
import sys

from cuda.core.experimental import Device, LaunchConfig, Program, launch

# prepare include
cuda_path = os.environ.get("CUDA_PATH", os.environ.get("CUDA_HOME"))
if cuda_path is None:
print("this demo requires a valid CUDA_PATH environment variable set", file=sys.stderr)
sys.exit(0)
cuda_include_path = os.path.join(cuda_path, "include")

# print cluster info using a kernel
code = r"""
#include <cooperative_groups.h>

namespace cg = cooperative_groups;

extern "C"
__global__ void check_cluster_info() {
auto g = cg::this_grid();
auto b = cg::this_thread_block();
if (g.cluster_rank() == 0 && g.block_rank() == 0 && g.thread_rank() == 0) {
printf("grid dim: (%u, %u, %u)\n", g.dim_blocks().x, g.dim_blocks().y, g.dim_blocks().z);
printf("cluster dim: (%u, %u, %u)\n", g.dim_clusters().x, g.dim_clusters().y, g.dim_clusters().z);
printf("block dim: (%u, %u, %u)\n", b.dim_threads().x, b.dim_threads().y, b.dim_threads().z);
}
}
"""

dev = Device()
arch = dev.compute_capability
if arch < (9, 0):
print(
"this demo requires compute capability >= 9.0 (since thread block cluster is a hardware feature)",
file=sys.stderr,
)
sys.exit(0)
arch = "".join(f"{i}" for i in arch)

# prepare program & compile kernel
dev.set_current()
prog = Program(code, code_type="c++")
mod = prog.compile(
target_type="cubin",
# TODO: update this after NVIDIA/cuda-python#237 is merged
options=(f"-arch=sm_{arch}", "-std=c++17", f"-I{cuda_include_path}"),
)
ker = mod.get_kernel("check_cluster_info")

# prepare launch config
grid = 4
cluster = 2
block = 32
config = LaunchConfig(grid=grid, cluster=cluster, block=block, stream=dev.default_stream)

# launch kernel on the default stream
launch(ker, config)
dev.sync()

print("done!")
3 changes: 3 additions & 0 deletions cuda_core/tests/example_tests/utils.py
@@ -42,6 +42,9 @@ def run_example(samples_path, filename, env=None):
break
else:
raise
except SystemExit:
# for samples that early return due to any missing requirements
pytest.skip(f"skip {filename}")
except Exception as e:
msg = "\n"
msg += f"Got error ({filename}):\n"