Skip to content

Commit

Permalink
Fix race condition in device_scalar::set_value (#569)
Browse files Browse the repository at this point in the history
* Pass by ref instead of value.

* Improve docs for device_scalar asynchony.

* changelog

* Remove extraneous stream.

* Rename Dummy to Placeholder.

* Correct copy/paste docs.
  • Loading branch information
jrhemstad authored Sep 23, 2020
1 parent e3ce8c5 commit cce9081
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 8 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
- PR #545 Fix build to support using `clang` as the host compiler
- PR #534 Fix `pool_memory_resource` failure when init and max pool sizes are equal
- PR #546 Remove CUDA driver linking and correct NVTX macro.
- PR #569 Correct `device_scalar::set_value` to pass host value by reference to avoid copying from invalid value
- PR #559 Fix `align_down` to only change unaligned values.


Expand Down
52 changes: 44 additions & 8 deletions include/rmm/device_scalar.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,17 +114,35 @@ class device_scalar {
* (e.g. using `cudaStreamWaitEvent()` or `cudaStreamSynchronize()`) before and after calling
* this function, otherwise there may be a race condition.
*
* Does not synchronize `stream`.
* This function does not synchronize `stream` before returning. Therefore, the object
* referenced by `host_value` should not be destroyed or modified until `stream` has been
* synchronized. Otherwise, behavior is undefined.
*
* @note: This function incurs a host to device memcpy and should be used sparingly.
* Example:
* \code{cpp}
* rmm::device_scalar<int32_t> s;
*
* int v{42};
*
* // Copies 42 to device storage on `stream`. Does _not_ synchronize
* vec.set_value(v, stream);
* ...
* cudaStreamSynchronize(stream);
* // Synchronization is required before `v` can be modified
* v = 13;
* \endcode
*
* @throws `rmm::cuda_error` if copying `host_value` to device memory fails.
* @throws `rmm::cuda_error` if synchronizing `stream` fails.
*
* @param host_value The host value which will be copied to device
* @param stream CUDA stream on which to perform the copy
*/
template <typename Dummy = void>
auto set_value(T host_value, cudaStream_t stream = 0)
-> std::enable_if_t<std::is_fundamental<T>::value, Dummy>
template <typename Placeholder = void>
auto set_value(T const &host_value, cudaStream_t stream = 0)
-> std::enable_if_t<std::is_fundamental<T>::value, Placeholder>
{
if (host_value == T{0}) {
RMM_CUDA_TRY(cudaMemsetAsync(buffer.data(), 0, sizeof(T), stream));
Expand All @@ -141,17 +159,35 @@ class device_scalar {
* (e.g. using `cudaStreamWaitEvent()` or `cudaStreamSynchronize()`) before and after calling
* this function, otherwise there may be a race condition.
*
* Does not synchronize `stream`.
* This function does not synchronize `stream` before returning. Therefore, the object
* referenced by `host_value` should not be destroyed or modified until `stream` has been
* synchronized. Otherwise, behavior is undefined.
*
* @note: This function incurs a host to device memcpy and should be used sparingly.
* Example:
* \code{cpp}
* rmm::device_scalar<int32_t> s;
*
* int v{42};
*
* // Copies 42 to device storage on `stream`. Does _not_ synchronize
* vec.set_value(v, stream);
* ...
* cudaStreamSynchronize(stream);
* // Synchronization is required before `v` can be modified
* v = 13;
* \endcode
*
* @throws `rmm::cuda_error` if copying `host_value` to device memory fails
* @throws `rmm::cuda_error` if synchronizing `stream` fails
*
* @param host_value The host value which will be copied to device
* @param stream CUDA stream on which to perform the copy
*/
template <typename Dummy = void>
auto set_value(T host_value, cudaStream_t stream = 0)
-> std::enable_if_t<not std::is_fundamental<T>::value, Dummy>
template <typename Placeholder = void>
auto set_value(T const &host_value, cudaStream_t stream = 0)
-> std::enable_if_t<not std::is_fundamental<T>::value, Placeholder>
{
_memcpy(buffer.data(), &host_value, stream);
}
Expand Down

0 comments on commit cce9081

Please sign in to comment.