
Commit 2706062

feat: experimental support of green ctx (#1163)
## 📌 Description

Use cuda-python bindings to create green contexts for splitting SM resources.

Co-authored-by: Yi Pan <conlesspan@outlook.com>

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Reviewer Notes

cc @Conless
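For a quick sense of the new API, here is a minimal usage sketch based on the docstring in `flashinfer/green_ctx.py` below (the device and partition sizes are illustrative):

```python
import torch
from flashinfer.green_ctx import split_device_green_ctx

dev = torch.device("cuda:0")

# Split into 2 groups of at least 16 SMs each; the returned lists have
# num_groups + 1 entries, the last one covering the remaining SMs.
streams, resources = split_device_green_ctx(dev, num_groups=2, min_count=16)
print([r.sm.smCount for r in resources])  # e.g. [16, 16, 100] on a 132-SM GPU

with torch.cuda.stream(streams[0]):
    # Kernels launched here run only on the first group's SMs.
    x = torch.randn(4096, 4096, device=dev, dtype=torch.bfloat16)
    y = x @ x
```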
1 parent ba2470c commit 2706062

File tree

docs/api/green_ctx.rst
docs/index.rst
flashinfer/green_ctx.py
setup.py
tests/test_green_ctx.py

5 files changed: +209 -1 lines changed


docs/api/green_ctx.rst

Lines changed: 8 additions & 0 deletions
.. _apigreenctx:

flashinfer.green_ctx
====================

.. currentmodule:: flashinfer.green_ctx

.. autofunction:: split_device_green_ctx

docs/index.rst

Lines changed: 1 addition & 0 deletions
@@ -40,3 +40,4 @@ FlashInfer is a library and kernel generator for Large Language Models that prov
    api/rope
    api/activation
    api/quantization
+   api/green_ctx

flashinfer/green_ctx.py

Lines changed: 154 additions & 0 deletions
"""
Copyright (c) 2025 by FlashInfer team.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from typing import List, Tuple

import cuda.bindings.driver as driver
import cuda.bindings.runtime as runtime
import cuda.cudart as cudart
import cuda.nvrtc as nvrtc
import torch
from cuda.bindings.driver import CUdevice, CUdevResource


def _cudaGetErrorEnum(error):
    if isinstance(error, driver.CUresult):
        err, name = driver.cuGetErrorName(error)
        return name if err == driver.CUresult.CUDA_SUCCESS else "<unknown>"
    elif isinstance(error, runtime.cudaError_t):
        return cudart.cudaGetErrorName(error)[1]
    elif isinstance(error, nvrtc.nvrtcResult):
        return nvrtc.nvrtcGetErrorString(error)[1]
    else:
        raise RuntimeError(f"Unknown error type: {error}")


def checkCudaErrors(result):
    if result[0].value:
        raise RuntimeError(
            f"CUDA error code={result[0].value}({_cudaGetErrorEnum(result[0])})"
        )
    if len(result) == 1:
        return None
    elif len(result) == 2:
        return result[1]
    else:
        return result[1:]


def get_cudevice(dev: torch.device) -> CUdevice:
    try:
        cu_dev = checkCudaErrors(driver.cuDeviceGet(dev.index))
    except RuntimeError:
        runtime.cudaInitDevice(dev.index, 0, 0)
        cu_dev = checkCudaErrors(driver.cuDeviceGet(dev.index))
    return cu_dev


def get_device_resource(cu_dev: CUdevice) -> CUdevResource:
    return checkCudaErrors(
        driver.cuDeviceGetDevResource(
            cu_dev, driver.CUdevResourceType.CU_DEV_RESOURCE_TYPE_SM
        )
    )


def split_resource(
    resource: CUdevResource,
    num_groups: int,
    min_count: int,
) -> Tuple[List[CUdevResource], CUdevResource]:
    results, _, remaining = checkCudaErrors(
        driver.cuDevSmResourceSplitByCount(
            num_groups,
            resource,
            0,  # useFlags
            min_count,
        )
    )
    return results, remaining


def create_green_ctx_streams(
    cu_dev: CUdevice, resources: List[CUdevResource]
) -> List[torch.Stream]:
    streams = []
    for split in resources:
        desc = checkCudaErrors(driver.cuDevResourceGenerateDesc([split], 1))
        green_ctx = checkCudaErrors(
            driver.cuGreenCtxCreate(
                desc, cu_dev, driver.CUgreenCtxCreate_flags.CU_GREEN_CTX_DEFAULT_STREAM
            )
        )
        stream = checkCudaErrors(
            driver.cuGreenCtxStreamCreate(
                green_ctx,
                driver.CUstream_flags.CU_STREAM_NON_BLOCKING,
                0,  # priority
            )
        )
        streams.append(torch.cuda.get_stream_from_external(stream))

    return streams


def split_device_green_ctx(
    dev: torch.device, num_groups: int, min_count: int
) -> Tuple[List[torch.Stream], List[CUdevResource]]:
    r"""
    Split the device into multiple `green contexts <https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GREEN__CONTEXTS.html>`_,
    and return the corresponding streams and ``CUdevResource`` for each group plus the remaining SMs.
    Green contexts allow concurrent execution of multiple kernels on different SM partitions.

    Args:
        dev: The device to split.
        num_groups: The number of groups to split the device into.
        min_count: Minimum number of SMs required for each group; it will be adjusted to meet the
            alignment and granularity requirements.

    Returns:
        streams: The list of torch.Stream objects corresponding to the green contexts.
        resources: The list of CUdevResource objects corresponding to the green contexts.

    Example:
        >>> from flashinfer.green_ctx import split_device_green_ctx
        >>> import torch
        >>> dev = torch.device("cuda:0")
        >>> streams, resources = split_device_green_ctx(dev, 2, 16)
        >>> print([r.sm.smCount for r in resources])
        [16, 16, 100]
        >>> with torch.cuda.stream(streams[0]):
        ...     x = torch.randn(8192, 8192, device=dev, dtype=torch.bfloat16)
        ...     y = torch.randn(8192, 8192, device=dev, dtype=torch.bfloat16)
        ...     z = x @ y
        ...     print(z.shape)
        ...
        torch.Size([8192, 8192])

    Note:
        The length of the returned streams and resources is ``num_groups + 1``,
        where the last one contains the remaining SMs.

    Raises:
        RuntimeError: when the requested SM allocation exceeds the device capacity:
            ``num_groups * round_up(min_count, 8) > num_sm``
    """
    cu_dev = get_cudevice(dev)
    resource = get_device_resource(cu_dev)
    results, remaining = split_resource(resource, num_groups, min_count)
    resources = results + [remaining]
    streams = create_green_ctx_streams(cu_dev, resources)
    return streams, resources
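The docstring example exercises only a single partition. Since the point of green contexts is concurrent execution on disjoint SM sets, here is a hedged sketch of running one kernel per partition (sizes and synchronization pattern are illustrative, following standard torch.cuda stream usage):

```python
import torch
from flashinfer.green_ctx import split_device_green_ctx

dev = torch.device("cuda:0")
streams, resources = split_device_green_ctx(dev, num_groups=2, min_count=32)

# Prepare inputs on the default stream and make sure they are ready
# before the partition streams start consuming them.
a = torch.randn(4096, 4096, device=dev, dtype=torch.bfloat16)
b = torch.randn(4096, 4096, device=dev, dtype=torch.bfloat16)
torch.cuda.synchronize(dev)

# One matmul per SM partition; the kernels can overlap because each
# green context owns a disjoint set of SMs.
with torch.cuda.stream(streams[0]):
    c0 = a @ b
with torch.cuda.stream(streams[1]):
    c1 = b @ a

# Wait for both partitions before reading the results.
streams[0].synchronize()
streams[1].synchronize()
print(c0.shape, c1.shape)
```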

setup.py

Lines changed: 1 addition & 1 deletion
@@ -51,7 +51,7 @@ def generate_build_meta(aot_build_meta: dict) -> None:

 ext_modules = []
 cmdclass = {}
-install_requires = ["numpy", "torch", "ninja", "requests"]
+install_requires = ["numpy", "torch", "ninja", "requests", "cuda-python"]
 generate_build_meta({})

 if enable_aot:

tests/test_green_ctx.py

Lines changed: 45 additions & 0 deletions
import pytest
import torch

import flashinfer.green_ctx as green_ctx


@pytest.mark.parametrize("device", ["cuda:0"])
@pytest.mark.parametrize("num_groups", [1, 2, 3])
@pytest.mark.parametrize("min_count", [16, 32])
def test_green_ctx_creation(
    device: str,
    num_groups: int,
    min_count: int,
):
    streams, resources = green_ctx.split_device_green_ctx(
        torch.device(device), num_groups, min_count
    )

    assert len(resources) == num_groups + 1
    for resource in resources[:-1]:
        sm_count = resource.sm.smCount
        assert sm_count >= min_count


@pytest.mark.parametrize("device", ["cuda:0"])
@pytest.mark.parametrize("num_groups", [1, 2, 3])
@pytest.mark.parametrize("min_count", [16, 32])
def test_green_ctx_kernel_execution(
    device: str,
    num_groups: int,
    min_count: int,
):
    streams, resources = green_ctx.split_device_green_ctx(
        torch.device(device), num_groups, min_count
    )
    num_partitions = num_groups + 1
    assert len(streams) == num_partitions
    assert len(resources) == num_partitions

    for stream in streams:
        with torch.cuda.stream(stream):
            x = torch.randn(8192, 8192, device=device, dtype=torch.bfloat16)
            y = torch.randn(8192, 8192, device=device, dtype=torch.bfloat16)
            z = x @ y
            print(z.shape)
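One further assertion worth considering, sketched below and not part of this diff; it assumes the per-partition `smCount` fields and torch's `multi_processor_count` describe the same SM pool:

```python
def test_green_ctx_sm_accounting():
    # Hypothetical extra check: the group partitions plus the remainder
    # should not exceed the device's total SM count.
    dev = torch.device("cuda:0")
    _, resources = green_ctx.split_device_green_ctx(dev, num_groups=2, min_count=16)
    total_sms = torch.cuda.get_device_properties(dev).multi_processor_count
    assert sum(r.sm.smCount for r in resources) <= total_sms
```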
