Skip to content

Commit

Permalink
Move rmm::exec_policy over to async_resource_ref
Browse files Browse the repository at this point in the history
This rewrites `rmm::exec_policy` to take a `cuda::std::async_resource_ref` instead of a plain `rmm::device_memory_resource*`. This is completely opaque to the user as the underlying `rmm::thrust_allocator` already takes a `cuda::mr::async_resource_ref`

Fixes #1448

Closes
  • Loading branch information
miscco committed Jan 30, 2024
1 parent 0234f23 commit aba5c97
Showing 1 changed file with 11 additions and 6 deletions.
17 changes: 11 additions & 6 deletions include/rmm/exec_policy.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2020-2021, NVIDIA CORPORATION.
* Copyright (c) 2020-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -28,6 +28,8 @@
#include <thrust/system/cuda/execution_policy.h>
#include <thrust/version.h>

#include <cuda/memory_resource>

namespace rmm {
/**
* @addtogroup thrust_integrations
Expand All @@ -47,15 +49,17 @@ using thrust_exec_policy_t =
* that uses RMM for temporary memory allocation on the specified stream.
*/
class exec_policy : public thrust_exec_policy_t {
using async_resource_ref = cuda::mr::async_resource_ref<cuda::mr::device_accessible>;

public:
/**
* @brief Construct a new execution policy object
*
* @param stream The stream on which to allocate temporary memory
* @param mr The resource to use for allocating temporary memory
*/
explicit exec_policy(cuda_stream_view stream = cuda_stream_default,
rmm::mr::device_memory_resource* mr = mr::get_current_device_resource())
explicit exec_policy(cuda_stream_view stream = cuda_stream_default,
async_resource_ref mr = rmm::mr::get_current_device_resource())
: thrust_exec_policy_t(
thrust::cuda::par(rmm::mr::thrust_allocator<char>(stream, mr)).on(stream.value()))
{
Expand All @@ -77,10 +81,11 @@ using thrust_exec_policy_nosync_t =
* are not required for correctness.
*/
class exec_policy_nosync : public thrust_exec_policy_nosync_t {
using async_resource_ref = cuda::mr::async_resource_ref<cuda::mr::device_accessible>;

public:
explicit exec_policy_nosync(
cuda_stream_view stream = cuda_stream_default,
rmm::mr::device_memory_resource* mr = mr::get_current_device_resource())
explicit exec_policy_nosync(cuda_stream_view stream = cuda_stream_default,
async_resource_ref mr = rmm::mr::get_current_device_resource())
: thrust_exec_policy_nosync_t(
thrust::cuda::par_nosync(rmm::mr::thrust_allocator<char>(stream, mr)).on(stream.value()))
{
Expand Down

0 comments on commit aba5c97

Please sign in to comment.