Support JIT compilation for CUDA driver & bindings 11.x #188

Merged
merged 7 commits on Oct 28, 2024
41 changes: 30 additions & 11 deletions cuda_core/cuda/core/experimental/_launcher.py
@@ -3,6 +3,7 @@
 # SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE
 
 from dataclasses import dataclass
+import importlib.metadata
 from typing import Optional, Union
 
 import numpy as np
@@ -19,6 +20,7 @@
 class LaunchConfig:
     """
     """
+    # TODO: expand LaunchConfig to include other attributes
     grid: Union[tuple, int] = None
     block: Union[tuple, int] = None
     stream: Stream = None
@@ -63,28 +65,45 @@ def _cast_to_3_tuple(self, cfg):
         raise ValueError
 
 
+# binding availability depends on cuda-python version
+py_major_minor = tuple(int(v) for v in (
+    importlib.metadata.version("cuda-python").split(".")[:2]))
+driver_ver = handle_return(cuda.cuDriverGetVersion())
+use_ex = (driver_ver >= 11080) and (py_major_minor >= (11, 8))
+
+
 def launch(kernel, config, *kernel_args):
     if not isinstance(kernel, Kernel):
         raise ValueError
     config = check_or_create_options(LaunchConfig, config, "launch config")
+    if config.stream is None:
+        raise CUDAError("stream cannot be None")
 
-    driver_ver = handle_return(cuda.cuDriverGetVersion())
-    if driver_ver >= 12000:
+    # TODO: can we ensure kernel_args is valid/safe to use here?
+    # TODO: merge with HelperKernelParams?
+    kernel_args = ParamHolder(kernel_args)
+    args_ptr = kernel_args.ptr
+
+    # Note: CUkernel can still be launched via the old cuLaunchKernel and we do not care
+    # about the CUfunction/CUkernel difference (which depends on whether the "old" or
+    # "new" module loading APIs are in use). We check both binding & driver versions here
+    # mainly to see if the "Ex" API is available and if so we use it, as it's more feature
+    # rich.
+    if use_ex:
         drv_cfg = cuda.CUlaunchConfig()
         drv_cfg.gridDimX, drv_cfg.gridDimY, drv_cfg.gridDimZ = config.grid
         drv_cfg.blockDimX, drv_cfg.blockDimY, drv_cfg.blockDimZ = config.block
-        if config.stream is None:
-            raise CUDAError("stream cannot be None")
         drv_cfg.hStream = config.stream._handle
         drv_cfg.sharedMemBytes = config.shmem_size
-        drv_cfg.numAttrs = 0  # FIXME
-
-        # TODO: merge with HelperKernelParams?
-        kernel_args = ParamHolder(kernel_args)
-        args_ptr = kernel_args.ptr
-
+        drv_cfg.numAttrs = 0  # TODO
         handle_return(cuda.cuLaunchKernelEx(
             drv_cfg, int(kernel._handle), args_ptr, 0))
     else:
-        raise NotImplementedError("TODO")
+        # TODO: check if config has any unsupported attrs
+        handle_return(cuda.cuLaunchKernel(
+            int(kernel._handle),
+            *config.grid,
+            *config.block,
+            config.shmem_size,
+            config.stream._handle,
+            args_ptr, 0))
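
The use_ex gate boils down to: take the cuLaunchKernelEx path only when both the driver and the cuda-python bindings are at least 11.8, since cuLaunchKernelEx first appeared in CUDA 11.8. A standalone sketch of the same check (not part of this diff; error handling is inlined here rather than going through handle_return):

import importlib.metadata

from cuda import cuda


def can_use_launch_ex() -> bool:
    # cuDriverGetVersion() encodes versions as major*1000 + minor*10,
    # so 11.8 -> 11080 and 12.0 -> 12000.
    err, driver_ver = cuda.cuDriverGetVersion()
    if err != cuda.CUresult.CUDA_SUCCESS:
        raise RuntimeError(f"cuDriverGetVersion failed: {err}")
    # The installed bindings can be older than the driver (or vice versa),
    # so both sides of the pair are checked.
    binding_ver = tuple(
        int(v) for v in importlib.metadata.version("cuda-python").split(".")[:2])
    return driver_ver >= 11080 and binding_ver >= (11, 8)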
36 changes: 25 additions & 11 deletions cuda_core/cuda/core/experimental/_module.py
@@ -2,23 +2,33 @@
 #
 # SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE
 
+import importlib.metadata
+
 from cuda import cuda, cudart
 from cuda.core.experimental._utils import handle_return
 
 
 _backend = {
-    "new": {
-        "file": cuda.cuLibraryLoadFromFile,
-        "data": cuda.cuLibraryLoadData,
-        "kernel": cuda.cuLibraryGetKernel,
-    },
     "old": {
         "file": cuda.cuModuleLoad,
         "data": cuda.cuModuleLoadDataEx,
         "kernel": cuda.cuModuleGetFunction,
     },
 }
 
+# binding availability depends on cuda-python version
+py_major_ver = int(importlib.metadata.version("cuda-python").split(".")[0])
+if py_major_ver >= 12:
+    _backend["new"] = {
+        "file": cuda.cuLibraryLoadFromFile,
+        "data": cuda.cuLibraryLoadData,
+        "kernel": cuda.cuLibraryGetKernel,
+    }
+    _kernel_ctypes = (cuda.CUfunction, cuda.CUkernel)
+else:
+    _kernel_ctypes = (cuda.CUfunction,)
+driver_ver = handle_return(cuda.cuDriverGetVersion())
 
 
 class Kernel:
 
@@ -29,13 +39,15 @@ def __init__(self):
 
     @staticmethod
     def _from_obj(obj, mod):
-        assert isinstance(obj, (cuda.CUkernel, cuda.CUfunction))
+        assert isinstance(obj, _kernel_ctypes)
         assert isinstance(mod, ObjectCode)
         ker = Kernel.__new__(Kernel)
         ker._handle = obj
         ker._module = mod
         return ker
 
+    # TODO: implement from_handle()
+
 
 class ObjectCode:
 
@@ -48,8 +60,8 @@ def __init__(self, module, code_type, jit_options=None, *,
             raise ValueError
         self._handle = None
 
-        driver_ver = handle_return(cuda.cuDriverGetVersion())
-        self._loader = _backend["new"] if driver_ver >= 12000 else _backend["old"]
+        backend = "new" if (py_major_ver >= 12 and driver_ver >= 12000) else "old"
+        self._loader = _backend[backend]
 
         if isinstance(module, str):
             if driver_ver < 12000 and jit_options is not None:
@@ -60,12 +72,12 @@ def __init__(self, module, code_type, jit_options=None, *,
             assert isinstance(module, bytes)
             if jit_options is None:
                 jit_options = {}
-            if driver_ver >= 12000:
+            if backend == "new":
                 args = (module, list(jit_options.keys()), list(jit_options.values()), len(jit_options),
                         # TODO: support library options
                         [], [], 0)
-            else:
-                args = (module, len(jit_options), jit_options.keys(), jit_options.values())
+            else:  # "old" backend
+                args = (module, len(jit_options), list(jit_options.keys()), list(jit_options.values()))
             self._handle = handle_return(self._loader["data"](*args))
 
         self._code_type = code_type
@@ -83,3 +95,5 @@ def get_kernel(self, name):
         name = name.encode()
         data = handle_return(self._loader["kernel"](self._handle, name))
         return Kernel._from_obj(data, self)
+
+    # TODO: implement from_handle()
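
With both files updated, the same user-facing flow works on an 11.x stack by falling back to cuModuleLoadDataEx and cuLaunchKernel under the hood. A hypothetical end-to-end sketch (Device, set_current(), and create_stream() are assumed from the surrounding cuda.core.experimental package; the PTX bytes, kernel name, and JIT option value are placeholders):

from cuda import cuda
from cuda.core.experimental import Device
from cuda.core.experimental._launcher import LaunchConfig, launch
from cuda.core.experimental._module import ObjectCode

dev = Device()
dev.set_current()
stream = dev.create_stream()

ptx = b"..."  # placeholder: PTX emitted by a separate compilation step

# On 11.x bindings this resolves to the "old" cuModuleLoadDataEx loader;
# on 12.x bindings with a 12.x driver it uses cuLibraryLoadData instead.
# The key is a real CUjit_option enum; the value is illustrative only.
mod = ObjectCode(ptx, "ptx",
                 jit_options={cuda.CUjit_option.CU_JIT_MAX_REGISTERS: 32})
ker = mod.get_kernel("my_kernel")

# launch() picks cuLaunchKernelEx when driver and bindings are >= 11.8,
# and plain cuLaunchKernel otherwise; both paths require a stream.
launch(ker, LaunchConfig(grid=1, block=1, stream=stream))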