Skip to content

Commit

Permalink
[PyTorch] Use pinned memory for zero_cuda_out (pytorch#134712)
Browse files Browse the repository at this point in the history
Summary: This diff allocates a pinned (page-locked) host tensor as the destination of the device-to-host copy of the nonzero count in the nonzero CUDA out op (nonzero_cuda_out_impl).

Differential Revision: D61759262

Pull Request resolved: pytorch#134712
Approved by: https://github.com/zyan0
  • Loading branch information
banitag1 authored and tolleybot committed Sep 14, 2024
1 parent 75f8ad5 commit fd99b4a
Showing 1 changed file with 11 additions and 1 deletion.
12 changes: 11 additions & 1 deletion aten/src/ATen/native/cuda/Nonzero.cu
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/EmptyTensor.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <ATen/cuda/EmptyTensor.h>
Expand Down Expand Up @@ -70,7 +71,16 @@ void nonzero_cuda_out_impl(const Tensor& self, Tensor& out){
auto temp_storage = allocator.allocate(temp_storage_bytes);
cub::DeviceReduce::Sum(temp_storage.get(), temp_storage_bytes, itr, (int*)num_nonzeros.get(), N, stream);
int num_nonzeros_h;
at::cuda::memcpy_and_sync(&num_nonzeros_h, num_nonzeros.get(), sizeof(int), cudaMemcpyDeviceToHost, stream);
auto pinned_num_nonzeros_h = at::detail::empty_cpu(
{1}, /* size */
c10::CppTypeToScalarType<int>(), /* dtype */
std::nullopt, /* layout */
std::nullopt, /* device */
true, /* pin_memory */
std::nullopt /* memory format */
);
at::cuda::memcpy_and_sync((void *)pinned_num_nonzeros_h.const_data_ptr<int>(), num_nonzeros.get(), sizeof(int), cudaMemcpyDeviceToHost, stream);
num_nonzeros_h = (int)*(pinned_num_nonzeros_h.const_data_ptr<int>());
//expected output size is num_nonzeros x ndim
//we are producing output with size {num_nonzeros, ndim} and strides {1, num_nonzeros} (that is, transposed ndim x num_nonzeros output)
//we are able to directly use passed output with this size and strides, and we can also (per contract)
Expand Down

0 comments on commit fd99b4a

Please sign in to comment.