"""
Copyright (c) 2025 by FlashInfer team.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from typing import List, Tuple

import cuda.bindings.driver as driver
import cuda.bindings.nvrtc as nvrtc
import cuda.bindings.runtime as runtime
import torch
from cuda.bindings.driver import CUdevice, CUdevResource


def _cudaGetErrorEnum(error):
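    """Return the printable name of a CUDA driver, runtime, or NVRTC error code."""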
    if isinstance(error, driver.CUresult):
        err, name = driver.cuGetErrorName(error)
        return name if err == driver.CUresult.CUDA_SUCCESS else "<unknown>"
    elif isinstance(error, runtime.cudaError_t):
        return runtime.cudaGetErrorName(error)[1]
    elif isinstance(error, nvrtc.nvrtcResult):
        return nvrtc.nvrtcGetErrorString(error)[1]
    else:
        raise RuntimeError(f"Unknown error type: {error}")


def checkCudaErrors(result):
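    """Unpack a cuda-python result tuple, raising on failure.

    cuda-python calls return ``(status, value0, value1, ...)``; this helper strips
    the status and returns the remaining value(s), if any.
    """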
    if result[0].value:
        raise RuntimeError(
            f"CUDA error code={result[0].value}({_cudaGetErrorEnum(result[0])})"
        )
    if len(result) == 1:
        return None
    elif len(result) == 2:
        return result[1]
    else:
        return result[1:]


def get_cudevice(dev: torch.device) -> CUdevice:
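    """Resolve a torch.device to a CUdevice handle, initializing the device if needed."""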
    try:
        cu_dev = checkCudaErrors(driver.cuDeviceGet(dev.index))
    except RuntimeError:
        # The device may not be initialized yet; initialize it and retry.
        checkCudaErrors(runtime.cudaInitDevice(dev.index, 0, 0))
        cu_dev = checkCudaErrors(driver.cuDeviceGet(dev.index))
    return cu_dev


def get_device_resource(cu_dev: CUdevice) -> CUdevResource:
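    """Query the aggregate SM resource of the device."""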
    return checkCudaErrors(
        driver.cuDeviceGetDevResource(
            cu_dev, driver.CUdevResourceType.CU_DEV_RESOURCE_TYPE_SM
        )
    )


def split_resource(
    resource: CUdevResource,
    num_groups: int,
    min_count: int,
) -> Tuple[List[CUdevResource], CUdevResource]:
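    """Split an SM resource into ``num_groups`` partitions with at least ``min_count``
    SMs each, returning the partitions and the remaining SMs as a separate resource.
    """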
    results, _, remaining = checkCudaErrors(
        driver.cuDevSmResourceSplitByCount(
            num_groups,
            resource,
            0,  # useFlags
            min_count,
        )
    )
    return results, remaining


def create_green_ctx_streams(
    cu_dev: CUdevice, resources: List[CUdevResource]
) -> List[torch.Stream]:
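    """Create a green context over each SM partition and return a non-blocking
    torch.cuda stream bound to each of them.
    """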
    streams = []
    for split in resources:
        # Describe the SM partition and create a green context over it.
        desc = checkCudaErrors(driver.cuDevResourceGenerateDesc([split], 1))
        green_ctx = checkCudaErrors(
            driver.cuGreenCtxCreate(
                desc, cu_dev, driver.CUgreenCtxCreate_flags.CU_GREEN_CTX_DEFAULT_STREAM
            )
        )
        # Create a stream bound to the green context and wrap it for PyTorch.
        stream = checkCudaErrors(
            driver.cuGreenCtxStreamCreate(
                green_ctx,
                driver.CUstream_flags.CU_STREAM_NON_BLOCKING,
                0,  # priority
            )
        )
        streams.append(torch.cuda.get_stream_from_external(stream))

    return streams


def split_device_green_ctx(
    dev: torch.device, num_groups: int, min_count: int
) -> Tuple[List[torch.Stream], List[CUdevResource]]:
    r"""
    Split the device into multiple `green contexts <https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GREEN__CONTEXTS.html>`_
    and return the corresponding streams and ``CUdevResource`` for each group, plus the remaining SMs.
    Green contexts allow concurrent execution of multiple kernels on different SM partitions.

    Args:
        dev: The device to split.
        num_groups: The number of groups to split the device into.
        min_count: Minimum number of SMs required for each group; it is adjusted
            upward to meet the alignment and granularity requirements.

    Returns:
        streams: The list of torch.Stream objects corresponding to the green contexts.
        resources: The list of CUdevResource objects corresponding to the green contexts.

    Example:
        >>> from flashinfer.green_ctx import split_device_green_ctx
        >>> import torch
        >>> dev = torch.device("cuda:0")
        >>> streams, resources = split_device_green_ctx(dev, 2, 16)
        >>> print([r.sm.smCount for r in resources])
        [16, 16, 100]
        >>> with torch.cuda.stream(streams[0]):
        ...     x = torch.randn(8192, 8192, device=dev, dtype=torch.bfloat16)
        ...     y = torch.randn(8192, 8192, device=dev, dtype=torch.bfloat16)
        ...     z = x @ y
        ...     print(z.shape)
        ...
        torch.Size([8192, 8192])

    Note:
        The length of the returned streams and resources is ``num_groups + 1``,
        where the last entry contains the remaining SMs.

    Raises:
        RuntimeError: When the requested SM allocation exceeds the device capacity:
            ``num_groups * round_up(min_count, 8) > num_sm``
    """
    cu_dev = get_cudevice(dev)
    resource = get_device_resource(cu_dev)
    results, remaining = split_resource(resource, num_groups, min_count)
    resources = results + [remaining]
    streams = create_green_ctx_streams(cu_dev, resources)
    return streams, resources