Add rmm::prefetch() and DeviceBuffer.prefetch() (#1573)
This adds two `rmm::prefetch()` functions in C++:
1. `rmm::prefetch(void *ptr, size_t bytes, device, stream)`
2. `rmm::prefetch<T>(cuda::std::span<T> data, device, stream)`

Item 2 enables prefetching the containers that RMM provides (`device_uvector`, `device_scalar`), which support conversion to `cuda::std::span`. To enable that conversion for `device_scalar`, this change adds `device_scalar::size()`.

Note that a `device_buffer` must be prefetched using item 1, because a `span<void>` cannot be created.

In Python, this adds `DeviceBuffer.prefetch()` because that's really the only RMM Python data type to prefetch. There is *one* Cython use of `device_uvector` in cuDF `join` that we might need to add prefetch support for later.

`prefetch` is a no-op on non-managed memory. Rather than querying the memory type up front, it simply treats the `cudaErrorInvalidValue` that `cudaMemPrefetchAsync` returns for non-managed pointers as success (see the sketch below).
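
For illustration, here is a minimal C++ usage sketch of the new API. It is not part of this commit: the sizes and memory-resource setup are hypothetical, and it assumes the `cuda::std::span` conversion of `device_uvector` described above.

```cpp
#include <rmm/cuda_stream.hpp>
#include <rmm/device_buffer.hpp>
#include <rmm/device_uvector.hpp>
#include <rmm/mr/device/cuda_memory_resource.hpp>
#include <rmm/mr/device/managed_memory_resource.hpp>
#include <rmm/mr/device/per_device_resource.hpp>
#include <rmm/prefetch.hpp>

int main()
{
  rmm::mr::managed_memory_resource managed_mr;
  rmm::mr::set_current_device_resource(&managed_mr);  // allocate managed memory by default
  rmm::cuda_stream stream;

  // Overload 1: untyped memory. A device_buffer must use this form
  // because a span<void> cannot be created.
  rmm::device_buffer buf{256, stream};
  rmm::prefetch(buf.data(), buf.size(), rmm::get_current_cuda_device(), stream);

  // Overload 2: typed containers that convert to cuda::std::span.
  // The template argument is explicit because implicit conversions
  // do not participate in template argument deduction.
  rmm::device_uvector<int> vec(100, stream);
  rmm::prefetch<int>(vec, rmm::get_current_cuda_device(), stream);

  // Prefetching non-managed memory is a no-op rather than an error.
  rmm::mr::cuda_memory_resource device_mr;
  rmm::device_buffer unmanaged{256, stream, &device_mr};
  rmm::prefetch(unmanaged.data(), unmanaged.size(), rmm::get_current_cuda_device(), stream);

  return 0;
}
```

The last call returns without migrating anything: `cudaMemPrefetchAsync` reports `cudaErrorInvalidValue` for the non-managed allocation, and `rmm::prefetch` swallows exactly that error.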

Authors:
  - Mark Harris (https://github.com/harrism)

Approvers:
  - Lawrence Mitchell (https://github.com/wence-)
  - Rong Ou (https://github.com/rongou)
  - Jake Hemstad (https://github.com/jrhemstad)
  - Michael Schellenberger Costa (https://github.com/miscco)
  - Vyas Ramasubramani (https://github.com/vyasr)

URL: #1573
harrism authored and rongou committed Jun 17, 2024
1 parent e61bf52 commit 366615f
Showing 10 changed files with 320 additions and 13 deletions.
8 changes: 8 additions & 0 deletions include/rmm/cuda_device.hpp
@@ -22,6 +22,9 @@

namespace rmm {

struct cuda_device_id;
inline cuda_device_id get_current_cuda_device();

/**
* @addtogroup cuda_device_management
* @{
@@ -34,6 +37,11 @@ namespace rmm {
struct cuda_device_id {
  using value_type = int;  ///< Integer type used for device identifier

  /**
   * @brief Construct a `cuda_device_id` from the current device
   */
  cuda_device_id() noexcept : id_{get_current_cuda_device().value()} {}

  /**
   * @brief Construct a `cuda_device_id` from the specified integer value.
   *
6 changes: 6 additions & 0 deletions include/rmm/device_scalar.hpp
@@ -43,6 +43,7 @@ class device_scalar {
  static_assert(std::is_trivially_copyable<T>::value, "Scalar type must be trivially copyable");

  using value_type = typename device_uvector<T>::value_type;  ///< T, the type of the scalar element
  using size_type = typename device_uvector<T>::size_type;  ///< The type used for the size
  using reference = typename device_uvector<T>::reference;  ///< value_type&
  using const_reference = typename device_uvector<T>::const_reference;  ///< const value_type&
  using pointer =
@@ -254,6 +255,11 @@
    return static_cast<const_pointer>(_storage.data());
  }

  /**
   * @briefreturn{The size of the scalar: always 1}
   */
  [[nodiscard]] constexpr size_type size() const noexcept { return 1; }

  /**
   * @briefreturn{Stream associated with the device memory allocation}
   */
77 changes: 77 additions & 0 deletions include/rmm/prefetch.hpp
@@ -0,0 +1,77 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <rmm/cuda_device.hpp>
#include <rmm/cuda_stream_view.hpp>
#include <rmm/error.hpp>

#include <cuda/std/span>

namespace rmm {

/**
* @addtogroup utilities
* @{
* @file
*/

/**
* @brief Prefetch memory to the specified device on the specified stream.
*
* This function is a no-op if the pointer is not to CUDA managed memory.
*
* @throw rmm::cuda_error if the prefetch fails.
*
* @param ptr The pointer to the memory to prefetch
* @param size The number of bytes to prefetch
* @param device The device to prefetch to
* @param stream The stream to use for the prefetch
*/
void prefetch(void const* ptr,
              std::size_t size,
              rmm::cuda_device_id device,
              rmm::cuda_stream_view stream)
{
  auto result = cudaMemPrefetchAsync(ptr, size, device.value(), stream.value());
  // InvalidValue error is raised when non-managed memory is passed to cudaMemPrefetchAsync;
  // we should treat this as a no-op
  if (result != cudaErrorInvalidValue && result != cudaSuccess) { RMM_CUDA_TRY(result); }
}

/**
* @brief Prefetch a span of memory to the specified device on the specified stream.
*
* This function is a no-op if the buffer is not backed by CUDA managed memory.
*
* @throw rmm::cuda_error if the prefetch fails.
*
* @param data The span to prefetch
* @param device The device to prefetch to
* @param stream The stream to use for the prefetch
*/
template <typename T>
void prefetch(cuda::std::span<T const> data,
              rmm::cuda_device_id device,
              rmm::cuda_stream_view stream)
{
  prefetch(data.data(), data.size_bytes(), device, stream);
}

/** @} */ // end of group

} // namespace rmm
1 change: 1 addition & 0 deletions python/rmm/docs/conf.py
@@ -209,6 +209,7 @@
intersphinx_mapping = {
    "python": ("https://docs.python.org/3", None),
    "numba": ("https://numba.readthedocs.io/en/stable", None),
    "cuda-python": ("https://nvidia.github.io/cuda-python/", None),
}

# Config numpydoc
53 changes: 44 additions & 9 deletions python/rmm/docs/guide.md
@@ -27,10 +27,10 @@ There are two ways to use RMM in Python code:
RMM provides a `MemoryResource` abstraction to control _how_ device
memory is allocated in both the above uses.

-### DeviceBuffers
+### `DeviceBuffer` Objects

-A DeviceBuffer represents an **untyped, uninitialized device memory
-allocation**. DeviceBuffers can be created by providing the
+A `DeviceBuffer` represents an **untyped, uninitialized device memory
+allocation**. `DeviceBuffer`s can be created by providing the
size of the allocation in bytes:

```python
@@ -48,7 +48,7 @@ can be accessed via the `.size` and `.ptr` attributes respectively:
140202544726016
```

-DeviceBuffers can also be created by copying data from host memory:
+`DeviceBuffer`s can also be created by copying data from host memory:

```python
>>> import rmm
@@ -59,15 +59,50 @@ DeviceBuffers can also be created by copying data from host memory:
24
```

-Conversely, the data underlying a DeviceBuffer can be copied to the
-host:
+Conversely, the data underlying a `DeviceBuffer` can be copied to the host:

```python
>>> np.frombuffer(buf.tobytes())
array([1., 2., 3.])
```

-### MemoryResource objects
+#### Prefetching a `DeviceBuffer`

[CUDA Unified Memory](
https://developer.nvidia.com/blog/unified-memory-cuda-beginners/
), also known as managed memory, can be allocated using an
`rmm.mr.ManagedMemoryResource` explicitly, or by calling `rmm.reinitialize`
with `managed_memory=True`.

A `DeviceBuffer` backed by managed memory or other
migratable memory (such as
[HMM/ATS](https://developer.nvidia.com/blog/simplifying-gpu-application-development-with-heterogeneous-memory-management/)
memory) may be prefetched to a specified device, for example to reduce or eliminate page faults.

```python
>>> import rmm
>>> rmm.reinitialize(managed_memory=True)
>>> buf = rmm.DeviceBuffer(size=100)
>>> buf.prefetch()
```

The above example prefetches the `DeviceBuffer` memory to the current CUDA device
on the stream that the `DeviceBuffer` last used (e.g. at construction). The
destination device ID and stream are optional parameters.

```python
>>> import rmm
>>> rmm.reinitialize(managed_memory=True)
>>> from rmm._cuda.stream import Stream
>>> stream = Stream()
>>> buf = rmm.DeviceBuffer(size=100, stream=stream)
>>> buf.prefetch(device=3, stream=stream) # prefetch to device on stream.
```

`DeviceBuffer.prefetch()` is a no-op if the `DeviceBuffer` is not backed
by migratable memory.

### `MemoryResource` objects

`MemoryResource` objects are used to configure how device memory allocations are made by
RMM.
@@ -122,13 +157,13 @@ Similarly, to use a pool of managed memory:
>>> rmm.mr.set_current_device_resource(pool)
```

-Other MemoryResources include:
+Other `MemoryResource`s include:

* `FixedSizeMemoryResource` for allocating fixed blocks of memory
* `BinningMemoryResource` for allocating blocks within specified "bin" sizes from different memory
resources

-MemoryResources are highly configurable and can be composed together in different ways.
+`MemoryResource`s are highly configurable and can be composed together in different ways.
See `help(rmm.mr)` for more information.

## Using RMM with third-party libraries
15 changes: 15 additions & 0 deletions python/rmm/rmm/_lib/device_buffer.pxd
@@ -23,6 +23,21 @@ from rmm._lib.memory_resource cimport (
)


cdef extern from "rmm/mr/device/per_device_resource.hpp" namespace "rmm" nogil:
    cdef cppclass cuda_device_id:
        ctypedef int value_type
        cuda_device_id()
        cuda_device_id(value_type id)
        value_type value()

    cdef cuda_device_id get_current_cuda_device()

cdef extern from "rmm/prefetch.hpp" namespace "rmm" nogil:
    cdef void prefetch(const void* ptr,
                       size_t bytes,
                       cuda_device_id device,
                       cuda_stream_view stream) except +

cdef extern from "rmm/device_buffer.hpp" namespace "rmm" nogil:
    cdef cppclass device_buffer:
        device_buffer()
25 changes: 25 additions & 0 deletions python/rmm/rmm/_lib/device_buffer.pyx
@@ -138,6 +138,31 @@ cdef class DeviceBuffer:
        }
        return intf

    def prefetch(self, device=None, stream=None):
        """Prefetch buffer data to the specified device on the specified stream.

        Assumes the storage for this DeviceBuffer is CUDA managed memory
        (unified memory). If it is not, this function is a no-op.

        Parameters
        ----------
        device : optional
            The CUDA device to which to prefetch the memory for this buffer.
            Defaults to the current CUDA device. To prefetch to the CPU, pass
            `~cuda.cudart.cudaCpuDeviceId` as the device.
        stream : optional
            CUDA stream to use for prefetching. Defaults to self.stream
        """
        cdef cuda_device_id dev = (get_current_cuda_device()
                                   if device is None
                                   else cuda_device_id(device))
        cdef Stream strm = self.stream if stream is None else stream
        with nogil:
            prefetch(self.c_obj.get()[0].data(),
                     self.c_obj.get()[0].size(),
                     dev,
                     strm.view())

    def copy(self):
        """Returns a copy of DeviceBuffer.
27 changes: 27 additions & 0 deletions python/rmm/rmm/tests/test_rmm.py
@@ -20,6 +20,7 @@
import warnings
from itertools import product

import cuda.cudart as cudart
import numpy as np
import pytest
from numba import cuda
@@ -284,6 +285,32 @@ def test_rmm_device_buffer_pickle_roundtrip(hb):
    assert hb3 == hb


def assert_prefetched(buffer, device_id):
    # 4 is the size in bytes of the int value returned for the
    # LastPrefetchLocation range attribute.
    err, dev = cudart.cudaMemRangeGetAttribute(
        4,
        cudart.cudaMemRangeAttribute.cudaMemRangeAttributeLastPrefetchLocation,
        buffer.ptr,
        buffer.size,
    )
    assert err == cudart.cudaError_t.cudaSuccess
    assert dev == device_id


@pytest.mark.parametrize(
    "managed, pool", list(product([False, True], [False, True]))
)
def test_rmm_device_buffer_prefetch(pool, managed):
    rmm.reinitialize(pool_allocator=pool, managed_memory=managed)
    db = rmm.DeviceBuffer.to_device(np.zeros(256, dtype="u1"))
    if managed:
        assert_prefetched(db, cudart.cudaInvalidDeviceId)
    db.prefetch()  # just test that it doesn't throw
    if managed:
        err, device = cudart.cudaGetDevice()
        assert err == cudart.cudaError_t.cudaSuccess
        assert_prefetched(db, device)


@pytest.mark.parametrize("stream", [cuda.default_stream(), cuda.stream()])
def test_rmm_pool_numba_stream(stream):
    rmm.reinitialize(pool_allocator=True)
11 changes: 7 additions & 4 deletions tests/CMakeLists.txt
@@ -1,5 +1,5 @@
# =============================================================================
-# Copyright (c) 2018-2023, NVIDIA CORPORATION.
+# Copyright (c) 2018-2024, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
@@ -174,12 +174,15 @@ ConfigureTest(DEVICE_BUFFER_TEST device_buffer_tests.cu)
# device scalar tests
ConfigureTest(DEVICE_SCALAR_TEST device_scalar_tests.cpp)

-# logger tests
-ConfigureTest(LOGGER_TEST logger_tests.cpp)
-
# uvector tests
ConfigureTest(DEVICE_UVECTOR_TEST device_uvector_tests.cpp GPUS 1 PERCENT 60)

+# prefetch tests
+ConfigureTest(PREFETCH_TEST prefetch_tests.cpp)
+
+# logger tests
+ConfigureTest(LOGGER_TEST logger_tests.cpp)

# arena MR tests
ConfigureTest(ARENA_MR_TEST mr/device/arena_mr_tests.cpp GPUS 1 PERCENT 60)
