diff --git a/.gitattributes b/.gitattributes index 8c8fc427..00407cdc 100644 --- a/.gitattributes +++ b/.gitattributes @@ -1 +1,7 @@ cuda/_version.py export-subst + +* text eol=lf + +# we do not own any headers checked in, don't touch them +*.h binary +*.hpp binary diff --git a/cuda_core/cuda/core/experimental/__init__.py b/cuda_core/cuda/core/experimental/__init__.py index 9b978398..a45f4d77 100644 --- a/cuda_core/cuda/core/experimental/__init__.py +++ b/cuda_core/cuda/core/experimental/__init__.py @@ -2,6 +2,7 @@ # # SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE +from cuda.core.experimental import utils from cuda.core.experimental._device import Device from cuda.core.experimental._event import EventOptions from cuda.core.experimental._launcher import LaunchConfig, launch diff --git a/cuda_core/cuda/core/experimental/_memoryview.pyx b/cuda_core/cuda/core/experimental/_memoryview.pyx index af6b3adf..d8eba464 100644 --- a/cuda_core/cuda/core/experimental/_memoryview.pyx +++ b/cuda_core/cuda/core/experimental/_memoryview.pyx @@ -20,14 +20,67 @@ from cuda.core.experimental._utils import handle_return @cython.dataclasses.dataclass cdef class StridedMemoryView: - + """A dataclass holding metadata of a strided dense array/tensor. + + A :obj:`StridedMemoryView` instance can be created in two ways: + + 1. Using the :obj:`args_viewable_as_strided_memory` decorator (recommended) + 2. Explicit construction, see below + + This object supports both DLPack (up to v1.0) and CUDA Array Interface + (CAI) v3. When wrapping an arbitrary object it will try the DLPack protocol + first, then the CAI protocol. A :obj:`BufferError` is raised if neither is + supported. 
+ + Since either way would take a consumer stream, for DLPack it is passed to + ``obj.__dlpack__()`` as-is (except for :obj:`None`, see below); for CAI, a + stream order will be established between the consumer stream and the + producer stream (from ``obj.__cuda_array_interface__["stream"]``), as if + ``cudaStreamWaitEvent`` is called by this method. + + To opt out of the stream ordering operation in either DLPack or CAI, + please pass ``stream_ptr=-1``. Note that this deviates (on purpose) + from the semantics of ``obj.__dlpack__(stream=None, ...)`` since ``cuda.core`` + does not encourage using the (legacy) default/null stream, but is + consistent with the CAI's semantics. For DLPack, ``stream=-1`` will be + internally passed to ``obj.__dlpack__()`` instead. + + Attributes + ---------- + ptr : int + Pointer to the tensor buffer (as a Python `int`). + shape: tuple + Shape of the tensor. + strides: tuple + Strides of the tensor (in **counts**, not bytes). + dtype: numpy.dtype + Data type of the tensor. + device_id: int + The device ID for where the tensor is located. It is -1 for CPU tensors + (meaning those only accessible from the host). + is_device_accessible: bool + Whether the tensor data can be accessed on the GPU. + readonly: bool + Whether the tensor data is read-only (i.e. cannot be modified in place). + exporting_obj: Any + A reference to the original tensor object that is being viewed. + + Parameters + ---------- + obj : Any + Any object that supports either DLPack (up to v1.0) or CUDA Array + Interface (v3). + stream_ptr: int + The pointer address (as Python `int`) to the **consumer** stream. + Stream ordering will be properly established unless ``-1`` is passed. + """ # TODO: switch to use Cython's cdef typing? 
ptr: int = None shape: tuple = None strides: tuple = None # in counts, not bytes dtype: numpy.dtype = None device_id: int = None # -1 for CPU - device_accessible: bool = None + is_device_accessible: bool = None readonly: bool = None exporting_obj: Any = None @@ -48,7 +101,7 @@ cdef class StridedMemoryView: + f" strides={self.strides},\n" + f" dtype={get_simple_repr(self.dtype)},\n" + f" device_id={self.device_id},\n" - + f" device_accessible={self.device_accessible},\n" + + f" is_device_accessible={self.is_device_accessible},\n" + f" readonly={self.readonly},\n" + f" exporting_obj={get_simple_repr(self.exporting_obj)})") @@ -99,28 +152,25 @@ cdef class _StridedMemoryViewProxy: cdef StridedMemoryView view_as_dlpack(obj, stream_ptr, view=None): cdef int dldevice, device_id, i - cdef bint device_accessible, versioned, is_readonly + cdef bint is_device_accessible, versioned, is_readonly + is_device_accessible = False dldevice, device_id = obj.__dlpack_device__() if dldevice == _kDLCPU: - device_accessible = False assert device_id == 0 + device_id = -1 if stream_ptr is None: raise BufferError("stream=None is ambiguous with view()") elif stream_ptr == -1: stream_ptr = None elif dldevice == _kDLCUDA: - device_accessible = True + assert device_id >= 0 + is_device_accessible = True # no need to check other stream values, it's a pass-through if stream_ptr is None: raise BufferError("stream=None is ambiguous with view()") - elif dldevice == _kDLCUDAHost: - device_accessible = True - assert device_id == 0 - # just do a pass-through without any checks, as pinned memory can be - # accessed on both host and device - elif dldevice == _kDLCUDAManaged: - device_accessible = True - # just do a pass-through without any checks, as managed memory can be + elif dldevice in (_kDLCUDAHost, _kDLCUDAManaged): + is_device_accessible = True + # just do a pass-through without any checks, as pinned/managed memory can be # accessed on both host and device else: raise BufferError("device not 
supported") @@ -171,7 +221,7 @@ cdef StridedMemoryView view_as_dlpack(obj, stream_ptr, view=None): buf.strides = None buf.dtype = dtype_dlpack_to_numpy(&dl_tensor.dtype) buf.device_id = device_id - buf.device_accessible = device_accessible + buf.is_device_accessible = is_device_accessible buf.readonly = is_readonly buf.exporting_obj = obj @@ -261,7 +311,7 @@ cdef StridedMemoryView view_as_cai(obj, stream_ptr, view=None): if buf.strides is not None: # convert to counts buf.strides = tuple(s // buf.dtype.itemsize for s in buf.strides) - buf.device_accessible = True + buf.is_device_accessible = True buf.device_id = handle_return( cuda.cuPointerGetAttribute( cuda.CUpointer_attribute.CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL, @@ -284,7 +334,34 @@ cdef StridedMemoryView view_as_cai(obj, stream_ptr, view=None): return buf -def viewable(tuple arg_indices): +def args_viewable_as_strided_memory(tuple arg_indices): + """Decorator to create proxy objects to :obj:`StridedMemoryView` for the + specified positional arguments. + + This allows array/tensor attributes to be accessed inside the function + implementation, while keeping the function body array-library-agnostic (if + desired). + + Inside the decorated function, the specified arguments become instances + of an (undocumented) proxy type, regardless of their original source. A + :obj:`StridedMemoryView` instance can be obtained by passing the (consumer) + stream pointer (as a Python `int`) to the proxy's ``view()`` method. For + example: + + .. code-block:: python + + @args_viewable_as_strided_memory((1,)) + def my_func(arg0, arg1, arg2, stream: Stream): + # arg1 can be any object supporting DLPack or CUDA Array Interface + view = arg1.view(stream.handle) + assert isinstance(view, StridedMemoryView) + ... + + Parameters + ---------- + arg_indices : tuple + The indices of the target positional arguments. 
+ """ def wrapped_func_with_indices(func): @functools.wraps(func) def wrapped_func(*args, **kwargs): diff --git a/cuda_core/cuda/core/experimental/utils.py b/cuda_core/cuda/core/experimental/utils.py index 0717b41a..cc9a437d 100644 --- a/cuda_core/cuda/core/experimental/utils.py +++ b/cuda_core/cuda/core/experimental/utils.py @@ -2,3 +2,7 @@ # # SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE +from cuda.core.experimental._memoryview import ( + StridedMemoryView, # noqa: F401 + args_viewable_as_strided_memory, # noqa: F401 +) diff --git a/cuda_core/docs/source/api.rst b/cuda_core/docs/source/api.rst index 1cb9811b..558c3ec8 100644 --- a/cuda_core/docs/source/api.rst +++ b/cuda_core/docs/source/api.rst @@ -31,3 +31,18 @@ CUDA compilation toolchain :toctree: generated/ Program + + +.. module:: cuda.core.experimental.utils + +Utility functions +----------------- + +.. autosummary:: + :toctree: generated/ + + args_viewable_as_strided_memory + + :template: dataclass.rst + + StridedMemoryView diff --git a/cuda_core/docs/source/conf.py b/cuda_core/docs/source/conf.py index ce37b3aa..4621e887 100644 --- a/cuda_core/docs/source/conf.py +++ b/cuda_core/docs/source/conf.py @@ -34,6 +34,7 @@ "sphinx.ext.autodoc", "sphinx.ext.autosummary", "sphinx.ext.napoleon", + "sphinx.ext.intersphinx", "myst_nb", "enum_tools.autoenum", "sphinx_copybutton", @@ -82,3 +83,11 @@ # skip cmdline prompts copybutton_exclude = ".linenos, .gp" + +intersphinx_mapping = { + "python": ("https://docs.python.org/3/", None), + "numpy": ("https://numpy.org/doc/stable/", None), +} + +napoleon_google_docstring = False +napoleon_numpy_docstring = True diff --git a/cuda_core/docs/source/release.md b/cuda_core/docs/source/release.md index 48e24786..55090b0b 100644 --- a/cuda_core/docs/source/release.md +++ b/cuda_core/docs/source/release.md @@ -5,5 +5,6 @@ maxdepth: 3 --- + 0.1.1 0.1.0 ``` diff --git a/cuda_core/docs/source/release/0.1.1-notes.md b/cuda_core/docs/source/release/0.1.1-notes.md new file 
mode 100644 index 00000000..473352a4 --- /dev/null +++ b/cuda_core/docs/source/release/0.1.1-notes.md @@ -0,0 +1,13 @@ +# `cuda.core` Release notes + +Released on Dec XX, 2024 + +## Highlights +- Add `StridedMemoryView` and `@args_viewable_as_strided_memory` that provide a concrete + implementation of DLPack & CUDA Array Interface support. + + +## Limitations + +- All APIs are currently *experimental* and subject to change without deprecation notice. + Please kindly share your feedback with us so that we can make `cuda.core` better! diff --git a/cuda_core/docs/versions.json b/cuda_core/docs/versions.json index 4163fd31..41664534 100644 --- a/cuda_core/docs/versions.json +++ b/cuda_core/docs/versions.json @@ -1,4 +1,5 @@ { "latest" : "latest", + "0.1.1" : "0.1.1", "0.1.0" : "0.1.0" } diff --git a/cuda_core/examples/strided_memory_view.py b/cuda_core/examples/strided_memory_view.py new file mode 100644 index 00000000..564d7fa0 --- /dev/null +++ b/cuda_core/examples/strided_memory_view.py @@ -0,0 +1,171 @@ +# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED. +# +# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE + +# ################################################################################ +# +# This demo aims to illustrate two takeaways: +# +# 1. The similarity between CPU and GPU JIT-compilation with C++ sources +# 2. How to use StridedMemoryView to interface with foreign C/C++ functions +# +# To facilitate this demo, we use cffi (https://cffi.readthedocs.io/) for the CPU +# path, which can be easily installed from pip or conda following their instructions. +# We also use NumPy/CuPy as the CPU/GPU array container. 
+# +# ################################################################################ + +import importlib +import string +import sys + +try: + from cffi import FFI +except ImportError: + print("cffi is not installed, the CPU example will be skipped", file=sys.stderr) + FFI = None +try: + import cupy as cp +except ImportError: + print("cupy is not installed, the GPU example will be skipped", file=sys.stderr) + cp = None +import numpy as np + +from cuda.core.experimental import Device, LaunchConfig, Program, launch +from cuda.core.experimental.utils import StridedMemoryView, args_viewable_as_strided_memory + +# ################################################################################ +# +# Usually this entire code block is in a separate file, built as a Python extension +# module that can be imported by users at run time. For illustrative purposes we +# use JIT compilation to make this demo self-contained. +# +# Here we assume an in-place operation, equivalent to the following NumPy code: +# +# >>> arr = ... +# >>> assert arr.dtype == np.int32 +# >>> assert arr.ndim == 1 +# >>> arr += np.arange(arr.size, dtype=arr.dtype) +# +# is implemented for both CPU and GPU at low-level, with the following C function +# signature: +func_name = "inplace_plus_arange_N" +func_sig = f"void {func_name}(int* data, size_t N)" + +# Here is a concrete (very naive!) implementation on CPU: +if FFI: + cpu_code = string.Template(r""" + extern "C" + $func_sig { + for (size_t i = 0; i < N; i++) { + data[i] += i; + } + } + """).substitute(func_sig=func_sig) + # This is cffi's way of JIT compiling & loading a CPU function. cffi builds an + # extension module that has the Python binding to the underlying C function. + # For more details, please refer to cffi's documentation. 
+ cpu_prog = FFI() + cpu_prog.cdef(f"{func_sig};") + cpu_prog.set_source( + "_cpu_obj", + cpu_code, + source_extension=".cpp", + extra_compile_args=["-std=c++11"], + ) + cpu_prog.compile() + cpu_func = getattr(importlib.import_module("_cpu_obj.lib"), func_name) + +# Here is a concrete (again, very naive!) implementation on GPU: +if cp: + gpu_code = string.Template(r""" + extern "C" + __global__ $func_sig { + const size_t tid = threadIdx.x + blockIdx.x * blockDim.x; + const size_t stride_size = gridDim.x * blockDim.x; + for (size_t i = tid; i < N; i += stride_size) { + data[i] += i; + } + } + """).substitute(func_sig=func_sig) + gpu_prog = Program(gpu_code, code_type="c++") + # To know the GPU's compute capability, we need to identify which GPU to use. + dev = Device(0) + arch = "".join(f"{i}" for i in dev.compute_capability) + mod = gpu_prog.compile( + target_type="cubin", + # TODO: update this after NVIDIA/cuda-python#237 is merged + options=(f"-arch=sm_{arch}", "-std=c++11"), + ) + gpu_ker = mod.get_kernel(func_name) + +# Now we are prepared to run the code from the user's perspective! +# +# ################################################################################ + + +# Below, as a user we want to perform the said in-place operation on either CPU +# or GPU, by calling the corresponding function implemented "elsewhere" (done above). + + +# We assume the 0-th argument supports either DLPack or CUDA Array Interface (both +# of which are supported by StridedMemoryView). +@args_viewable_as_strided_memory((0,)) +def my_func(arr, work_stream): + # Create a memory view over arr (assumed to be a 1D array of int32). The stream + # ordering is taken care of, so that arr can be safely accessed on our work + # stream (ordered after a data stream on which arr is potentially prepared). 
+ view = arr.view(work_stream.handle if work_stream else -1) + assert isinstance(view, StridedMemoryView) + assert len(view.shape) == 1 + assert view.dtype == np.int32 + + size = view.shape[0] + # DLPack also supports host arrays. We want to know if the array data is + # accessible from the GPU, and dispatch to the right routine accordingly. + if view.is_device_accessible: + block = 256 + grid = (size + block - 1) // block + config = LaunchConfig(grid=grid, block=block, stream=work_stream) + launch(gpu_ker, config, view.ptr, np.uint64(size)) + # Here we're being conservative and synchronize over our work stream, + # assuming we do not know the data stream; if we know then we could + # just order the data stream after the work stream here, e.g. + # + # data_stream.wait(work_stream) + # + # without an expensive synchronization (with respect to the host). + work_stream.sync() + else: + cpu_func(cpu_prog.cast("int*", view.ptr), size) + + +# This takes the CPU path +if FFI: + # Create input array on CPU + arr_cpu = np.zeros(1024, dtype=np.int32) + print(f"before: {arr_cpu[:10]=}") + + # Run the workload + my_func(arr_cpu, None) + + # Check the result + print(f"after: {arr_cpu[:10]=}") + assert np.allclose(arr_cpu, np.arange(1024, dtype=np.int32)) + + +# This takes the GPU path +if cp: + dev.set_current() + s = dev.create_stream() + # Create input array on GPU + arr_gpu = cp.ones(1024, dtype=cp.int32) + print(f"before: {arr_gpu[:10]=}") + + # Run the workload + my_func(arr_gpu, s) + + # Check the result + print(f"after: {arr_gpu[:10]=}") + assert cp.allclose(arr_gpu, 1 + cp.arange(1024, dtype=cp.int32)) + s.close() diff --git a/cuda_core/tests/conftest.py b/cuda_core/tests/conftest.py index bb99fb33..59e5883f 100644 --- a/cuda_core/tests/conftest.py +++ b/cuda_core/tests/conftest.py @@ -1,36 +1,56 @@ -# Copyright 2024 NVIDIA Corporation. All rights reserved. 
-# -# Please refer to the NVIDIA end user license agreement (EULA) associated -# with this source code for terms and conditions that govern your use of -# this software. Any use, reproduction, disclosure, or distribution of -# this software and related documentation outside the terms of the EULA -# is strictly prohibited. -try: - from cuda.bindings import driver -except ImportError: - from cuda import cuda as driver - -import pytest - -from cuda.core.experimental import Device, _device -from cuda.core.experimental._utils import handle_return - - -@pytest.fixture(scope="function") -def init_cuda(): - device = Device() - device.set_current() - yield - _device_unset_current() - - -def _device_unset_current(): - handle_return(driver.cuCtxPopCurrent()) - with _device._tls_lock: - del _device._tls.devices - - -@pytest.fixture(scope="function") -def deinit_cuda(): - yield - _device_unset_current() +# Copyright 2024 NVIDIA Corporation. All rights reserved. +# +# Please refer to the NVIDIA end user license agreement (EULA) associated +# with this source code for terms and conditions that govern your use of +# this software. Any use, reproduction, disclosure, or distribution of +# this software and related documentation outside the terms of the EULA +# is strictly prohibited. 
+ +import glob +import os +import sys + +try: + from cuda.bindings import driver +except ImportError: + from cuda import cuda as driver + +import pytest + +from cuda.core.experimental import Device, _device +from cuda.core.experimental._utils import handle_return + + +@pytest.fixture(scope="function") +def init_cuda(): + device = Device() + device.set_current() + yield + _device_unset_current() + + +def _device_unset_current(): + handle_return(driver.cuCtxPopCurrent()) + with _device._tls_lock: + del _device._tls.devices + + +@pytest.fixture(scope="function") +def deinit_cuda(): + yield + _device_unset_current() + + +# samples relying on cffi could fail as the modules cannot be imported +sys.path.append(os.getcwd()) + + +@pytest.fixture(scope="session", autouse=True) +def clean_up_cffi_files(): + yield + files = glob.glob(os.path.join(os.getcwd(), "_cpu_obj*")) + for f in files: + try: # noqa: SIM105 + os.remove(f) + except FileNotFoundError: + pass # noqa: SIM105 diff --git a/cuda_core/tests/test_utils.py b/cuda_core/tests/test_utils.py new file mode 100644 index 00000000..0926a549 --- /dev/null +++ b/cuda_core/tests/test_utils.py @@ -0,0 +1,137 @@ +# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED. 
+# +# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE + +try: + import cupy as cp +except ImportError: + cp = None +try: + from numba import cuda as numba_cuda +except ImportError: + numba_cuda = None +import numpy as np +import pytest + +from cuda.core.experimental import Device +from cuda.core.experimental.utils import StridedMemoryView, args_viewable_as_strided_memory + + +def convert_strides_to_counts(strides, itemsize): + return tuple(s // itemsize for s in strides) + + +@pytest.mark.parametrize( + "in_arr,", + ( + np.empty(3, dtype=np.int32), + np.empty((6, 6), dtype=np.float64)[::2, ::2], + np.empty((3, 4), order="F"), + np.empty((), dtype=np.float16), + # readonly is fixed recently (numpy/numpy#26501) + pytest.param( + np.frombuffer(b""), + marks=pytest.mark.skipif( + tuple(int(i) for i in np.__version__.split(".")[:2]) < (2, 1), reason="need numpy 2.1.0+" + ), + ), + ), +) +class TestViewCPU: + def test_args_viewable_as_strided_memory_cpu(self, in_arr): + @args_viewable_as_strided_memory((0,)) + def my_func(arr): + # stream_ptr=-1 means "the consumer does not care" + view = arr.view(-1) + self._check_view(view, in_arr) + + my_func(in_arr) + + def test_strided_memory_view_cpu(self, in_arr): + # stream_ptr=-1 means "the consumer does not care" + view = StridedMemoryView(in_arr, stream_ptr=-1) + self._check_view(view, in_arr) + + def _check_view(self, view, in_arr): + assert isinstance(view, StridedMemoryView) + assert view.ptr == in_arr.ctypes.data + assert view.shape == in_arr.shape + strides_in_counts = convert_strides_to_counts(in_arr.strides, in_arr.dtype.itemsize) + if in_arr.flags.c_contiguous: + assert view.strides is None + else: + assert view.strides == strides_in_counts + assert view.dtype == in_arr.dtype + assert view.device_id == -1 + assert view.is_device_accessible is False + assert view.exporting_obj is in_arr + assert view.readonly is not in_arr.flags.writeable + + +def gpu_array_samples(): + # TODO: this function would initialize 
the device at test collection time + samples = [] + if cp is not None: + samples += [ + (cp.empty(3, dtype=cp.complex64), False), + (cp.empty((6, 6), dtype=cp.float64)[::2, ::2], True), + (cp.empty((3, 4), order="F"), True), + ] + # Numba's device_array is the only known array container that does not + # support DLPack (so that we get to test the CAI coverage). + if numba_cuda is not None: + samples += [ + (numba_cuda.device_array((2,), dtype=np.int8), False), + (numba_cuda.device_array((4, 2), dtype=np.float32), True), + ] + return samples + + +def gpu_array_ptr(arr): + if cp is not None and isinstance(arr, cp.ndarray): + return arr.data.ptr + if numba_cuda is not None and isinstance(arr, numba_cuda.cudadrv.devicearray.DeviceNDArray): + return arr.device_ctypes_pointer.value + raise NotImplementedError(f"{arr=}") + + +@pytest.mark.parametrize("in_arr,use_stream", (*gpu_array_samples(),)) +class TestViewGPU: + def test_args_viewable_as_strided_memory_gpu(self, in_arr, use_stream): + # TODO: use the device fixture? + dev = Device() + dev.set_current() + # This is the consumer stream + s = dev.create_stream() if use_stream else None + + @args_viewable_as_strided_memory((0,)) + def my_func(arr): + view = arr.view(s.handle if s else -1) + self._check_view(view, in_arr, dev) + + my_func(in_arr) + + def test_strided_memory_view_cpu(self, in_arr, use_stream): + # TODO: use the device fixture? 
+ dev = Device() + dev.set_current() + # This is the consumer stream + s = dev.create_stream() if use_stream else None + + view = StridedMemoryView(in_arr, stream_ptr=s.handle if s else -1) + self._check_view(view, in_arr, dev) + + def _check_view(self, view, in_arr, dev): + assert isinstance(view, StridedMemoryView) + assert view.ptr == gpu_array_ptr(in_arr) + assert view.shape == in_arr.shape + strides_in_counts = convert_strides_to_counts(in_arr.strides, in_arr.dtype.itemsize) + if in_arr.flags["C_CONTIGUOUS"]: + assert view.strides in (None, strides_in_counts) + else: + assert view.strides == strides_in_counts + assert view.dtype == in_arr.dtype + assert view.device_id == dev.device_id + assert view.is_device_accessible is True + assert view.exporting_obj is in_arr + # can't test view.readonly with CuPy or Numba...