
Commit 39a101d

zasdfgbnm authored and facebook-github-bot committed
Make GPU loops support mutable lambda (pytorch#35015)
Summary:
I will need it for pytorch#34004.

The `mutable` qualifier allows a lambda to capture some values by copy and modify its own copy. This is useful for random kernels: we capture an RNG `state`, initialize it on the first run, and use the initialized state afterwards:

```C++
gpu_kernel(iter, [state, initialized](scalar_t arg) mutable -> scalar_t {
  if (!initialized) {
    curand_init(..., state);
    initialized = true;
  }
  return some_math(curand_uniform(state), arg);
});
```

The `operator()` of a `mutable` lambda is not `const`, so it cannot be passed as a constant reference, nor can it be called from inside a non-`mutable` lambda.

Example usage:

```C++
auto t = at::empty({4096}, kCUDA);
float thread_work_index_ = 0;
auto iter = TensorIterator::nullary_op(t);
gpu_kernel(iter, [thread_work_index_]GPU_LAMBDA() mutable -> float {
  return thread_work_index_++;
});
```

Pull Request resolved: pytorch#35015

Differential Revision: D20624698

Pulled By: ngimel

fbshipit-source-id: 06e3987793451cd514181d20252510297e2d28a9
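To make that constraint concrete, here is a minimal host-only C++ sketch (not part of this commit) of why a `mutable` lambda cannot be invoked through a `const` reference: the by-value captures behave like ordinary members of the closure object, and its `operator()` is non-`const`.

```C++
#include <iostream>

int main() {
  int counter = 0;
  // Captures `counter` by value; `mutable` lets the lambda modify its own copy.
  auto next = [counter]() mutable { return counter++; };

  std::cout << next() << "\n";   // 0
  std::cout << next() << "\n";   // 1
  std::cout << counter << "\n";  // 0 -- the original variable is untouched

  // const auto& cref = next;
  // cref();  // would not compile: operator() of a mutable lambda is not const

  return 0;
}
```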
1 parent edad9c1 commit 39a101d

5 files changed: +47 −29 lines changed

aten/src/ATen/native/cuda/CUDALoops.cuh

+13-13
@@ -121,42 +121,42 @@ static OffsetCalculator<N> make_offset_calculator(const TensorIterator& iter) {
 }
 
 template<int nt, int vt, typename func_t>
-static void launch_kernel(int64_t N, const func_t& f) {
+static void launch_kernel(int64_t N, func_t &&f) {
   TORCH_INTERNAL_ASSERT(N >= 0 && N <= std::numeric_limits<int32_t>::max());
   if (N == 0) {
     return;
   }
   dim3 block(nt);
   dim3 grid((N + block.x * vt - 1) / (block.x * vt));
   auto stream = at::cuda::getCurrentCUDAStream();
-  elementwise_kernel<nt, vt, func_t><<<grid, block, 0, stream>>>(N, f);
+  elementwise_kernel<nt, vt, func_t><<<grid, block, 0, stream>>>(N, std::move(f));
   AT_CUDA_CHECK(cudaGetLastError());
 }
 
 template <typename traits, typename func_t, typename index_t, size_t... INDEX>
 C10_HOST_DEVICE typename traits::result_type
-invoke_impl(const func_t &f, char *const C10_RESTRICT data[], const index_t strides[], int i,
+invoke_impl(func_t &f, char *const C10_RESTRICT data[], const index_t strides[], int i,
             std::index_sequence<INDEX...>) {
   return f(*(typename traits::template arg<INDEX>::type*)(data[INDEX] + i * strides[INDEX])...);
 }
 
 template <typename func_t, typename index_t, typename traits = function_traits<func_t>>
 C10_HOST_DEVICE typename traits::result_type
-invoke(const func_t &f, char *const C10_RESTRICT data[], const index_t strides[], int i) {
+invoke(func_t &f, char *const C10_RESTRICT data[], const index_t strides[], int i) {
   using Indices = std::make_index_sequence<traits::arity>;
   return invoke_impl<traits>(f, data, strides, i, Indices{});
 }
 
 template <typename traits, typename func_t, typename index_t, size_t... I>
 C10_HOST_DEVICE typename traits::result_type
-invoke_impl(const func_t &f, char *const C10_RESTRICT data[], const index_t strides[], const ScalarType dtypes[], int i,
+invoke_impl(func_t &f, char *const C10_RESTRICT data[], const index_t strides[], const ScalarType dtypes[], int i,
             std::index_sequence<I...>) {
   return f(c10::fetch_and_cast<typename traits::template arg<I>::type>(dtypes[I], data[I] + i * strides[I])...);
 }
 
 template <typename func_t, typename index_t, typename traits = function_traits<func_t>>
 C10_HOST_DEVICE typename traits::result_type
-invoke(const func_t &f, char *const C10_RESTRICT data[], const index_t strides[], const ScalarType dtypes[], int i) {
+invoke(func_t &f, char *const C10_RESTRICT data[], const index_t strides[], const ScalarType dtypes[], int i) {
   using Indices = std::make_index_sequence<traits::arity>;
   return invoke_impl<traits>(f, data, strides, dtypes, i, Indices{});
 }

@@ -167,7 +167,7 @@ invoke(const func_t &f, char *const C10_RESTRICT data[], const index_t strides[]
 namespace modern {
 
 template<typename func_t, typename policy_t>
-__device__ inline void elementwise_kernel_helper(func_t f, policy_t policy) {
+__device__ inline void elementwise_kernel_helper(func_t &f, policy_t policy) {
   using traits = function_traits<func_t>;
   using return_t = typename traits::result_type;
   using args_t = typename traits::ArgsTuple;

@@ -218,7 +218,7 @@ __global__ void unrolled_elementwise_kernel(int N, func_t f, array_t data, inp_c
 
 // this function assume trivial 1d and no dynamic casting
 template<typename func_t, typename array_t>
-static inline void launch_vectorized_kernel(int64_t N, const func_t& f, array_t data) {
+static inline void launch_vectorized_kernel(int64_t N, func_t& f, array_t data) {
   TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits<int32_t>::max());
   using traits = function_traits<func_t>;
   int64_t grid = (N + block_work_size - 1) / block_work_size;

@@ -256,7 +256,7 @@ static inline void launch_unrolled_kernel(int64_t N, const func_t& f, array_t da
 
 
 template <typename func_t>
-void gpu_kernel_impl(TensorIterator& iter, const func_t& f) {
+void gpu_kernel_impl(TensorIterator& iter, func_t f) {
   using traits = function_traits<func_t>;
   using arg0_t = typename traits::result_type;
   constexpr int ntensors = traits::arity + 1;

@@ -300,28 +300,28 @@ void gpu_kernel_impl(TensorIterator& iter, const func_t& f) {
     }
 
     if (needs_dynamic_casting<func_t>::check(iter)) {
-      legacy::launch_kernel<launch_size_1d, 1>(numel, [=]GPU_LAMBDA(int idx) {
+      legacy::launch_kernel<launch_size_1d, 1>(numel, [=]GPU_LAMBDA(int idx) mutable {
         void* out = data[0] + strides[0] * idx;
         arg0_t result = legacy::invoke(f, &data.data[1], &strides.data[1], &dtypes.data[1], idx);
         c10::cast_and_store<arg0_t>(dtypes[0], out, result);
       });
     } else {
-      legacy::launch_kernel<launch_size_1d, 1>(numel, [=]GPU_LAMBDA(int idx) {
+      legacy::launch_kernel<launch_size_1d, 1>(numel, [=]GPU_LAMBDA(int idx) mutable {
         arg0_t* out = (arg0_t*)(data[0] + strides[0] * idx);
         *out = legacy::invoke(f, &data.data[1], &strides.data[1], idx);
       });
     }
   } else {
     auto offset_calc = legacy::make_offset_calculator<traits::arity + 1>(iter);
     if (needs_dynamic_casting<func_t>::check(iter)) {
-      legacy::launch_kernel<launch_size_nd, launch_bound2>(numel, [=]GPU_LAMBDA(int idx) {
+      legacy::launch_kernel<launch_size_nd, launch_bound2>(numel, [=]GPU_LAMBDA(int idx) mutable {
        auto offsets = offset_calc.get(idx);
        void* out = data[0] + offsets[0];
        arg0_t result = legacy::invoke(f, &data.data[1], &offsets.data[1], &dtypes.data[1], 1);
        c10::cast_and_store<arg0_t>(dtypes[0], out, result);
       });
     } else {
-      legacy::launch_kernel<launch_size_nd, launch_bound2>(numel, [=]GPU_LAMBDA(int idx) {
+      legacy::launch_kernel<launch_size_nd, launch_bound2>(numel, [=]GPU_LAMBDA(int idx) mutable {
        auto offsets = offset_calc.get(idx);
        arg0_t* out = (arg0_t*)(data[0] + offsets[0]);
        *out = legacy::invoke(f, &data.data[1], &offsets.data[1], 1);
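The `mutable` added to the `[=]` wrapper lambdas above is needed once `legacy::invoke` takes `func_t&` instead of `const func_t&`: the wrapper captures `f` by value, and without `mutable` that copy is effectively `const` inside `operator()` and cannot be used to call a `mutable` user lambda. A minimal host-only sketch of the same constraint, using hypothetical names (`invoke`, `wrapper`) that are not part of the PyTorch code:

```C++
#include <cassert>

// Stand-in for legacy::invoke after this change: it takes a non-const reference.
template <typename func_t>
int invoke(func_t& f, int x) {
  return f(x);
}

int main() {
  int calls = 0;
  // A stateful user lambda, analogous to the RNG example in the summary.
  auto f = [calls](int x) mutable { ++calls; return x + calls; };

  // The wrapper captures `f` by value. Without `mutable` here, the captured
  // copy of `f` would be const inside operator(), and a mutable lambda
  // cannot be called through a const object.
  auto wrapper = [=](int idx) mutable { return invoke(f, idx); };

  assert(wrapper(10) == 11);  // wrapper's copy of f: calls == 1
  assert(wrapper(10) == 12);  // the copy's state persists across calls
  return 0;
}
```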

aten/src/ATen/native/cuda/ROCmLoops.cuh

+13-13
@@ -96,42 +96,42 @@ static OffsetCalculator<N> make_offset_calculator(const TensorIterator& iter) {
 }
 
 template<int nt, int vt, typename func_t>
-static void launch_kernel(int64_t N, const func_t& f) {
+static void launch_kernel(int64_t N, func_t&& f) {
   TORCH_INTERNAL_ASSERT(N >= 0 && N <= std::numeric_limits<int32_t>::max());
   if (N == 0) {
     return;
   }
   dim3 block(nt);
   dim3 grid((N + block.x * vt - 1) / (block.x * vt));
   auto stream = at::cuda::getCurrentCUDAStream();
-  elementwise_kernel<nt, vt, func_t><<<grid, block, 0, stream>>>(N, f);
+  elementwise_kernel<nt, vt, func_t><<<grid, block, 0, stream>>>(N, std::move(f));
   AT_CUDA_CHECK(cudaGetLastError());
 }
 
 template <typename traits, typename func_t, typename index_t, size_t... INDEX>
 C10_HOST_DEVICE typename traits::result_type
-invoke_impl(const func_t &f, char *const C10_RESTRICT data[], const index_t strides[], int i,
+invoke_impl(func_t &f, char *const C10_RESTRICT data[], const index_t strides[], int i,
             std::index_sequence<INDEX...>) {
   return f(*(typename traits::template arg<INDEX>::type*)(data[INDEX] + i * strides[INDEX])...);
 }
 
 template <typename func_t, typename index_t, typename traits = function_traits<func_t>>
 C10_HOST_DEVICE typename traits::result_type
-invoke(const func_t &f, char *const C10_RESTRICT data[], const index_t strides[], int i) {
+invoke(func_t &f, char *const C10_RESTRICT data[], const index_t strides[], int i) {
   using Indices = std::make_index_sequence<traits::arity>;
   return invoke_impl<traits>(f, data, strides, i, Indices{});
 }
 
 template <typename traits, typename func_t, typename index_t, size_t... I>
 C10_HOST_DEVICE typename traits::result_type
-invoke_impl(const func_t &f, char *const C10_RESTRICT data[], const index_t strides[], const ScalarType dtypes[], int i,
+invoke_impl(func_t &f, char *const C10_RESTRICT data[], const index_t strides[], const ScalarType dtypes[], int i,
             std::index_sequence<I...>) {
   return f(c10::fetch_and_cast<typename traits::template arg<I>::type>(dtypes[I], data[I] + i * strides[I])...);
 }
 
 template <typename func_t, typename index_t, typename traits = function_traits<func_t>>
 C10_HOST_DEVICE typename traits::result_type
-invoke(const func_t &f, char *const C10_RESTRICT data[], const index_t strides[], const ScalarType dtypes[], int i) {
+invoke(func_t &f, char *const C10_RESTRICT data[], const index_t strides[], const ScalarType dtypes[], int i) {
   using Indices = std::make_index_sequence<traits::arity>;
   return invoke_impl<traits>(f, data, strides, dtypes, i, Indices{});
 }

@@ -259,7 +259,7 @@ __global__ void elementwise_kernel(int N, func_t f, array_t data) {
 
 // TODO (@zasdfgbnm): this function assume trivial 1d and no dynamic casting
 template<typename func_t, typename array_t, std::enable_if_t<detail::has_same_arg_types<func_t>::value, int> = 0>
-static void launch_kernel(int64_t N, const func_t& f, array_t data) {
+static void launch_kernel(int64_t N, func_t f, array_t data) {
   TORCH_INTERNAL_ASSERT(N >= 0 && N <= std::numeric_limits<int32_t>::max());
   if (N == 0) {
     return;

@@ -271,13 +271,13 @@ static void launch_kernel(int64_t N, const func_t& f, array_t data) {
 }
 
 template<typename func_t, typename array_t, std::enable_if_t<!detail::has_same_arg_types<func_t>::value, int> = 0>
-static void launch_kernel(int64_t N, const func_t& f, array_t data) {}
+static void launch_kernel(int64_t N, func_t f, array_t data) {}
 
 } // namespace modern
 
 
 template <typename func_t>
-void gpu_kernel_impl(TensorIterator& iter, const func_t& f) {
+void gpu_kernel_impl(TensorIterator& iter, func_t f) {
   using traits = function_traits<func_t>;
   using arg0_t = typename traits::result_type;
   constexpr int ntensors = traits::arity + 1;

@@ -304,30 +304,30 @@ void gpu_kernel_impl(TensorIterator& iter, const func_t& f) {
     }
 
     if (needs_dynamic_casting<func_t>::check(iter)) {
-      legacy::launch_kernel<launch_size_1d, 1>(numel, [=]GPU_LAMBDA(int idx) {
+      legacy::launch_kernel<launch_size_1d, 1>(numel, [=]GPU_LAMBDA(int idx) mutable {
        void* out = data[0] + strides[0] * idx;
        arg0_t result = legacy::invoke(f, &data.data[1], &strides.data[1], &dtypes.data[1], idx);
        c10::cast_and_store<arg0_t>(dtypes[0], out, result);
       });
     } else if (iter.has_contiguous_first_dim() && modern::detail::has_same_arg_types<func_t>::value) {
       modern::launch_kernel(numel, f, data);
     } else {
-      legacy::launch_kernel<launch_size_1d, 1>(numel, [=]GPU_LAMBDA(int idx) {
+      legacy::launch_kernel<launch_size_1d, 1>(numel, [=]GPU_LAMBDA(int idx) mutable {
        arg0_t* out = (arg0_t*)(data[0] + strides[0] * idx);
        *out = legacy::invoke(f, &data.data[1], &strides.data[1], idx);
       });
     }
   } else {
     auto offset_calc = legacy::make_offset_calculator<traits::arity + 1>(iter);
     if (needs_dynamic_casting<func_t>::check(iter)) {
-      legacy::launch_kernel<launch_size_nd, launch_bound2>(numel, [=]GPU_LAMBDA(int idx) {
+      legacy::launch_kernel<launch_size_nd, launch_bound2>(numel, [=]GPU_LAMBDA(int idx) mutable {
        auto offsets = offset_calc.get(idx);
        void* out = data[0] + offsets[0];
        arg0_t result = legacy::invoke(f, &data.data[1], &offsets.data[1], &dtypes.data[1], 1);
        c10::cast_and_store<arg0_t>(dtypes[0], out, result);
       });
     } else {
-      legacy::launch_kernel<launch_size_nd, launch_bound2>(numel, [=]GPU_LAMBDA(int idx) {
+      legacy::launch_kernel<launch_size_nd, launch_bound2>(numel, [=]GPU_LAMBDA(int idx) mutable {
        auto offsets = offset_calc.get(idx);
        arg0_t* out = (arg0_t*)(data[0] + offsets[0]);
        *out = legacy::invoke(f, &data.data[1], &offsets.data[1], 1);

aten/src/ATen/test/CMakeLists.txt

+1-1
@@ -47,7 +47,7 @@ list(APPEND ATen_CUDA_TEST_SRCS
   ${CMAKE_CURRENT_SOURCE_DIR}/cuda_optional_test.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/cuda_packedtensoraccessor_test.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/cuda_tensor_interop_test.cpp
-  ${CMAKE_CURRENT_SOURCE_DIR}/cuda_vectorized_test.cu
+  ${CMAKE_CURRENT_SOURCE_DIR}/cuda_loops_test.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/cuda_generator_test.cu)
 if (CAFFE2_USE_CUDNN)
   list(APPEND ATen_CUDA_TEST_SRCS

aten/src/ATen/test/cuda_vectorized_test.cu → aten/src/ATen/test/cuda_loops_test.cu

+18
@@ -5,6 +5,7 @@
 #include <ATen/cuda/CUDAContext.h>
 #include <ATen/core/Array.h>
 
+using namespace at;
 using namespace at::native;
 using namespace at::native::memory;
 

@@ -25,6 +26,23 @@ void reset_buffers() {
   }
 }
 
+Tensor thread_work_index() {
+  auto t = at::empty({4096 * thread_work_size}, kCUDA);
+  float thread_work_index_ = 0;
+  auto iter = TensorIterator::nullary_op(t);
+  gpu_kernel(iter, [thread_work_index_]GPU_LAMBDA() mutable -> float {
+    return thread_work_index_++;
+  });
+  return t;
+}
+
+TEST(TestLoops, MutableLambda) {
+  auto t = thread_work_index();
+  for (float i = 0; i < thread_work_size; i++) {
+    ASSERT_EQ((t == i).to(kLong).sum().item<int64_t>(), 4096);
+  }
+}
+
 #ifdef __HIP_PLATFORM_HCC__
 TEST(TestLoops, HasSameArgTypes) {
   // This is a compile-time unit test. If this file compiles without error,

aten/tools/run_tests.sh

+2-2
@@ -39,8 +39,8 @@ fi
 if [[ -x ./cuda_half_test ]]; then
   ./cuda_half_test
 fi
-if [[ -x ./cuda_vectorized_test ]]; then
-  ./cuda_vectorized_test
+if [[ -x ./cuda_loops_test ]]; then
+  ./cuda_loops_test
 fi
 if [[ -x ./cuda_distributions_test ]]; then
   ./cuda_distributions_test
