Skip to content

Commit

Permalink
propagate py/driver ver check to launch
Browse files Browse the repository at this point in the history
  • Loading branch information
leofang committed Oct 26, 2024
1 parent b64f337 commit 7587684
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 7 deletions.
9 changes: 6 additions & 3 deletions cuda_core/cuda/core/experimental/_launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,11 @@ def launch(kernel, config, *kernel_args):
kernel_args = ParamHolder(kernel_args)
args_ptr = kernel_args.ptr

driver_ver = handle_return(cuda.cuDriverGetVersion())
if driver_ver >= 12000:
# Note: CUkernel can still be launched via the old cuLaunchKernel. We check ._backend
# here not because of the CUfunction/CUkernel difference (which depends on whether the
# "old" or "new" module loading APIs are in use), but only as a proxy to check if
# both binding & driver versions support the "Ex" API, which is more feature rich.
if kernel._backend == "new":
drv_cfg = cuda.CUlaunchConfig()
drv_cfg.gridDimX, drv_cfg.gridDimY, drv_cfg.gridDimZ = config.grid
drv_cfg.blockDimX, drv_cfg.blockDimY, drv_cfg.blockDimZ = config.block
Expand All @@ -86,7 +89,7 @@ def launch(kernel, config, *kernel_args):
drv_cfg.numAttrs = 0 # TODO
handle_return(cuda.cuLaunchKernelEx(
drv_cfg, int(kernel._handle), args_ptr, 0))
else:
else: # "old" backend
# TODO: check if config has any unsupported attrs
handle_return(cuda.cuLaunchKernel(
int(kernel._handle),
Expand Down
10 changes: 6 additions & 4 deletions cuda_core/cuda/core/experimental/_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,26 +32,27 @@

class Kernel:

__slots__ = ("_handle", "_module",)
__slots__ = ("_handle", "_module", "_backend")

def __init__(self):
raise NotImplementedError("directly constructing a Kernel instance is not supported")

@staticmethod
def _from_obj(obj, mod):
def _from_obj(obj, mod, backend):
assert isinstance(obj, _kernel_ctypes)
assert isinstance(mod, ObjectCode)
ker = Kernel.__new__(Kernel)
ker._handle = obj
ker._module = mod
ker._backend = backend
return ker

# TODO: implement from_handle()


class ObjectCode:

__slots__ = ("_handle", "_code_type", "_module", "_loader", "_sym_map")
__slots__ = ("_handle", "_code_type", "_module", "_loader", "_loader_backend", "_sym_map")
_supported_code_type = ("cubin", "ptx", "fatbin")

def __init__(self, module, code_type, jit_options=None, *,
Expand All @@ -62,6 +63,7 @@ def __init__(self, module, code_type, jit_options=None, *,

backend = "new" if (py_major_ver >= 12 and driver_ver >= 12000) else "old"
self._loader = _backend[backend]
self._loader_backend = backend

if isinstance(module, str):
if driver_ver < 12000 and jit_options is not None:
Expand Down Expand Up @@ -94,6 +96,6 @@ def get_kernel(self, name):
except KeyError:
name = name.encode()
data = handle_return(self._loader["kernel"](self._handle, name))
return Kernel._from_obj(data, self)
return Kernel._from_obj(data, self, self._loader_backend)

# TODO: implement from_handle()

0 comments on commit 7587684

Please sign in to comment.