From 7fd8ccb85aa542be4f3d68b2dce03931eeef94d6 Mon Sep 17 00:00:00 2001 From: Leo Fang Date: Fri, 25 Oct 2024 19:41:59 -0700 Subject: [PATCH] nit: cleaner treatment --- cuda_core/cuda/core/experimental/_launcher.py | 21 +++++++++++++------ cuda_core/cuda/core/experimental/_module.py | 10 ++++----- 2 files changed, 19 insertions(+), 12 deletions(-) diff --git a/cuda_core/cuda/core/experimental/_launcher.py b/cuda_core/cuda/core/experimental/_launcher.py index b1223365..c3af8866 100644 --- a/cuda_core/cuda/core/experimental/_launcher.py +++ b/cuda_core/cuda/core/experimental/_launcher.py @@ -3,6 +3,7 @@ # SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE from dataclasses import dataclass +import importlib.metadata from typing import Optional, Union import numpy as np @@ -64,6 +65,13 @@ def _cast_to_3_tuple(self, cfg): raise ValueError +# binding availability depends on cuda-python version +py_major_minor = tuple(int(v) for v in ( + importlib.metadata.version("cuda-python").split(".")[:2])) +driver_ver = handle_return(cuda.cuDriverGetVersion()) +use_ex = (driver_ver >= 11080) and (py_major_minor >= (11, 8)) + + def launch(kernel, config, *kernel_args): if not isinstance(kernel, Kernel): raise ValueError @@ -76,11 +84,12 @@ def launch(kernel, config, *kernel_args): kernel_args = ParamHolder(kernel_args) args_ptr = kernel_args.ptr - # Note: CUkernel can still be launched via the old cuLaunchKernel. We check ._backend - # here not because of the CUfunction/CUkernel difference (which depends on whether the - # "old" or "new" module loading APIs are in use), but only as a proxy to check if - # both binding & driver versions support the "Ex" API, which is more feature rich. - if kernel._backend == "new": + # Note: CUkernel can still be launched via the old cuLaunchKernel and we do not care + # about the CUfunction/CUkernel difference (which depends on whether the "old" or + # "new" module loading APIs are in use). We check both binding & driver versions here + # mainly to see if the "Ex" API is available and if so we use it, as it's more feature + # rich. + if use_ex: drv_cfg = cuda.CUlaunchConfig() drv_cfg.gridDimX, drv_cfg.gridDimY, drv_cfg.gridDimZ = config.grid drv_cfg.blockDimX, drv_cfg.blockDimY, drv_cfg.blockDimZ = config.block @@ -89,7 +98,7 @@ def launch(kernel, config, *kernel_args): drv_cfg.numAttrs = 0 # TODO handle_return(cuda.cuLaunchKernelEx( drv_cfg, int(kernel._handle), args_ptr, 0)) - else: # "old" backend + else: # TODO: check if config has any unsupported attrs handle_return(cuda.cuLaunchKernel( int(kernel._handle), diff --git a/cuda_core/cuda/core/experimental/_module.py b/cuda_core/cuda/core/experimental/_module.py index e5d0808f..a51ab24f 100644 --- a/cuda_core/cuda/core/experimental/_module.py +++ b/cuda_core/cuda/core/experimental/_module.py @@ -32,19 +32,18 @@ class Kernel: - __slots__ = ("_handle", "_module", "_backend") + __slots__ = ("_handle", "_module",) def __init__(self): raise NotImplementedError("directly constructing a Kernel instance is not supported") @staticmethod - def _from_obj(obj, mod, backend): + def _from_obj(obj, mod): assert isinstance(obj, _kernel_ctypes) assert isinstance(mod, ObjectCode) ker = Kernel.__new__(Kernel) ker._handle = obj ker._module = mod - ker._backend = backend return ker # TODO: implement from_handle() @@ -52,7 +51,7 @@ def _from_obj(obj, mod, backend): class ObjectCode: - __slots__ = ("_handle", "_code_type", "_module", "_loader", "_loader_backend", "_sym_map") + __slots__ = ("_handle", "_code_type", "_module", "_loader", "_sym_map") _supported_code_type = ("cubin", "ptx", "fatbin") def __init__(self, module, code_type, jit_options=None, *, @@ -63,7 +62,6 @@ def __init__(self, module, code_type, jit_options=None, *, backend = "new" if (py_major_ver >= 12 and driver_ver >= 12000) else "old" self._loader = _backend[backend] - self._loader_backend = backend if isinstance(module, str): if driver_ver < 12000 and jit_options is not None: @@ -96,6 +94,6 @@ def get_kernel(self, name): except KeyError: name = name.encode() data = handle_return(self._loader["kernel"](self._handle, name)) - return Kernel._from_obj(data, self, self._loader_backend) + return Kernel._from_obj(data, self) # TODO: implement from_handle()