From 6bc4907b7a06819ba6f19867baf4a22be635c129 Mon Sep 17 00:00:00 2001
From: Matthew Murray
Date: Wed, 8 Jan 2025 17:01:23 -0800
Subject: [PATCH] address review

---
 python/rmm/rmm/_cuda/stream.py            |  8 ++++----
 python/rmm/rmm/pylibrmm/device_buffer.pxd |  2 +-
 python/rmm/rmm/pylibrmm/stream.pxd        |  3 ++-
 python/rmm/rmm/pylibrmm/stream.pyx        | 25 +++++--------------------
 python/rmm/rmm/tests/test_rmm.py          |  8 ++++----
 5 files changed, 16 insertions(+), 30 deletions(-)

diff --git a/python/rmm/rmm/_cuda/stream.py b/python/rmm/rmm/_cuda/stream.py
index f508cec5b..f912b9ae6 100644
--- a/python/rmm/rmm/_cuda/stream.py
+++ b/python/rmm/rmm/_cuda/stream.py
@@ -22,10 +22,10 @@
 )
 
 __all__ = [
-    DEFAULT_STREAM,
-    LEGACY_DEFAULT_STREAM,
-    PER_THREAD_DEFAULT_STREAM,
-    Stream,
+    "DEFAULT_STREAM",
+    "LEGACY_DEFAULT_STREAM",
+    "PER_THREAD_DEFAULT_STREAM",
+    "Stream",
 ]
 
 warnings.warn(
diff --git a/python/rmm/rmm/pylibrmm/device_buffer.pxd b/python/rmm/rmm/pylibrmm/device_buffer.pxd
index 80ee52a48..295c2494e 100644
--- a/python/rmm/rmm/pylibrmm/device_buffer.pxd
+++ b/python/rmm/rmm/pylibrmm/device_buffer.pxd
@@ -15,9 +15,9 @@
 from libc.stdint cimport uintptr_t
 from libcpp.memory cimport unique_ptr
 
-from rmm.pylibrmm.stream cimport Stream
 from rmm.librmm.device_buffer cimport device_buffer
 from rmm.pylibrmm.memory_resource cimport DeviceMemoryResource
+from rmm.pylibrmm.stream cimport Stream
 
 
 cdef class DeviceBuffer:
diff --git a/python/rmm/rmm/pylibrmm/stream.pxd b/python/rmm/rmm/pylibrmm/stream.pxd
index 35d0d95b1..219b75864 100644
--- a/python/rmm/rmm/pylibrmm/stream.pxd
+++ b/python/rmm/rmm/pylibrmm/stream.pxd
@@ -1,4 +1,4 @@
-# Copyright (c) 2020-2024, NVIDIA CORPORATION.
+# Copyright (c) 2020-2025, NVIDIA CORPORATION.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -30,3 +30,4 @@ cdef class Stream:
     cdef void c_synchronize(self) except * nogil
     cdef bool c_is_default(self) except * nogil
     cdef void _init_with_new_cuda_stream(self) except *
+    cdef void _init_from_stream(self, Stream stream) except *
diff --git a/python/rmm/rmm/pylibrmm/stream.pyx b/python/rmm/rmm/pylibrmm/stream.pyx
index a02a80e23..327b2dada 100644
--- a/python/rmm/rmm/pylibrmm/stream.pyx
+++ b/python/rmm/rmm/pylibrmm/stream.pyx
@@ -34,38 +34,20 @@ cdef class Stream:
         ----------
         obj: optional
             * If None (the default), a new CUDA stream is created.
-            * If a stream that implements the __cuda_stream__ protocol
-              is provided, we use it.
             * If a Numba or CuPy stream is provided, we make a thin
               wrapper around it.
""" if obj is None: self._init_with_new_cuda_stream() return - elif hasattr(obj, "__cuda_stream__"): - protocol = getattr(obj, "__cuda_stream__") - if protocol[0] != 0: - raise ValueError("Only protocol version 0 is supported") - self._cuda_stream = obj - self.owner = obj + elif isinstance(obj, Stream): + self._init_from_stream(obj) else: - # TODO: Remove this branch when numba and cupy - # streams implement __cuda_stream__ try: self._init_from_numba_stream(obj) except TypeError: self._init_from_cupy_stream(obj) - @property - def __cuda_stream__(self): - """Return an instance of a __cuda_stream__ protocol.""" - return (0, self.handle) - - @property - def handle(self) -> int: - """Return the underlying cudaStream_t pointer address as Python int.""" - return int(self._cuda_stream) - @staticmethod cdef Stream _from_cudaStream_t(cudaStream_t s, object owner=None) except *: """ @@ -136,6 +118,9 @@ cdef class Stream: self._cuda_stream = stream.value() self._owner = stream + cdef void _init_from_stream(self, Stream stream) except *: + self._cuda_stream, self._owner = stream._cuda_stream, stream._owner + DEFAULT_STREAM = Stream._from_cudaStream_t(cuda_stream_default.value()) LEGACY_DEFAULT_STREAM = Stream._from_cudaStream_t(cuda_stream_legacy.value()) diff --git a/python/rmm/rmm/tests/test_rmm.py b/python/rmm/rmm/tests/test_rmm.py index 844e0e2d4..ee02d5d0e 100644 --- a/python/rmm/rmm/tests/test_rmm.py +++ b/python/rmm/rmm/tests/test_rmm.py @@ -26,10 +26,10 @@ from numba import cuda import rmm -import rmm.pylibrmm.stream from rmm.allocators.cupy import rmm_cupy_allocator from rmm.allocators.numba import RMMNumbaManager from rmm.pylibrmm.logger import level_enum +from rmm.pylibrmm.stream import Stream cuda.set_memory_manager(RMMNumbaManager) @@ -348,8 +348,8 @@ def test_rmm_device_buffer_prefetch(pool, managed): def test_rmm_pool_numba_stream(stream): rmm.reinitialize(pool_allocator=True) - stream = rmm.pylibrmm.stream.Stream(stream) - a = rmm.pylibrmm.device_buffer.DeviceBuffer(size=3, stream=stream) + stream = Stream(stream) + a = rmm.DeviceBuffer(size=3, stream=stream) assert a.size == 3 assert a.ptr != 0 @@ -695,7 +695,7 @@ def test_cuda_async_memory_resource_stream(nelems): # with a non-default stream works mr = rmm.mr.CudaAsyncMemoryResource() rmm.mr.set_current_device_resource(mr) - stream = rmm.pylibrmm.stream.Stream() + stream = Stream() expected = np.full(nelems, 5, dtype="u1") dbuf = rmm.DeviceBuffer.to_device(expected, stream=stream) result = np.asarray(dbuf.copy_to_host())