@@ -121,42 +121,42 @@ static OffsetCalculator<N> make_offset_calculator(const TensorIterator& iter) {
 }
 
 template <int nt, int vt, typename func_t>
-static void launch_kernel(int64_t N, const func_t& f) {
+static void launch_kernel(int64_t N, func_t&& f) {
   TORCH_INTERNAL_ASSERT(N >= 0 && N <= std::numeric_limits<int32_t>::max());
   if (N == 0) {
     return;
   }
   dim3 block(nt);
   dim3 grid((N + block.x * vt - 1) / (block.x * vt));
   auto stream = at::cuda::getCurrentCUDAStream();
-  elementwise_kernel<nt, vt, func_t><<<grid, block, 0, stream>>>(N, f);
+  elementwise_kernel<nt, vt, func_t><<<grid, block, 0, stream>>>(N, std::move(f));
   AT_CUDA_CHECK(cudaGetLastError());
 }
 
 template <typename traits, typename func_t, typename index_t, size_t... INDEX>
 C10_HOST_DEVICE typename traits::result_type
-invoke_impl(const func_t &f, char *const C10_RESTRICT data[], const index_t strides[], int i,
+invoke_impl(func_t &f, char *const C10_RESTRICT data[], const index_t strides[], int i,
             std::index_sequence<INDEX...>) {
   return f(*(typename traits::template arg<INDEX>::type*)(data[INDEX] + i * strides[INDEX])...);
 }
 
 template <typename func_t, typename index_t, typename traits = function_traits<func_t>>
 C10_HOST_DEVICE typename traits::result_type
-invoke(const func_t &f, char *const C10_RESTRICT data[], const index_t strides[], int i) {
+invoke(func_t &f, char *const C10_RESTRICT data[], const index_t strides[], int i) {
   using Indices = std::make_index_sequence<traits::arity>;
   return invoke_impl<traits>(f, data, strides, i, Indices{});
 }
 
 template <typename traits, typename func_t, typename index_t, size_t... I>
 C10_HOST_DEVICE typename traits::result_type
-invoke_impl(const func_t &f, char *const C10_RESTRICT data[], const index_t strides[], const ScalarType dtypes[], int i,
+invoke_impl(func_t &f, char *const C10_RESTRICT data[], const index_t strides[], const ScalarType dtypes[], int i,
             std::index_sequence<I...>) {
   return f(c10::fetch_and_cast<typename traits::template arg<I>::type>(dtypes[I], data[I] + i * strides[I])...);
 }
 
 template <typename func_t, typename index_t, typename traits = function_traits<func_t>>
 C10_HOST_DEVICE typename traits::result_type
-invoke(const func_t &f, char *const C10_RESTRICT data[], const index_t strides[], const ScalarType dtypes[], int i) {
+invoke(func_t &f, char *const C10_RESTRICT data[], const index_t strides[], const ScalarType dtypes[], int i) {
   using Indices = std::make_index_sequence<traits::arity>;
   return invoke_impl<traits>(f, data, strides, dtypes, i, Indices{});
 }
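The pattern in this hunk is consistent: launch_kernel now takes the functor as an rvalue reference and moves it into the kernel, while the invoke/invoke_impl helpers take it by non-const reference. A plausible reading is that the functor is no longer required to be const-callable. A minimal standalone sketch of the constraint being lifted (the names below are illustrative, not PyTorch APIs):

struct RunningSum {
  int total = 0;
  int operator()(int x) { total += x; return total; }  // non-const call operator
};

// Mirrors the old signature: a const reference cannot call a non-const operator().
template <typename func_t>
int invoke_old(const func_t& f, int x) {
  // return f(x);  // would not compile: f is const, operator() is not
  (void)f;
  return x;
}

// Mirrors the new signature: a non-const lvalue reference can.
template <typename func_t>
int invoke_new(func_t& f, int x) {
  return f(x);
}

int main() {
  RunningSum f;
  return invoke_new(f, 3);  // ok: f.total is updated to 3
}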
@@ -167,7 +167,7 @@ invoke(const func_t &f, char *const C10_RESTRICT data[], const index_t strides[]
 namespace modern {
 
 template <typename func_t, typename policy_t>
-__device__ inline void elementwise_kernel_helper(func_t f, policy_t policy) {
+__device__ inline void elementwise_kernel_helper(func_t& f, policy_t policy) {
   using traits = function_traits<func_t>;
   using return_t = typename traits::result_type;
   using args_t = typename traits::ArgsTuple;
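Here the device-side helper receives the functor by non-const reference rather than by value, so the (possibly stateful) copy owned by the calling kernel is applied directly instead of being copied again. An illustrative CUDA sketch of that calling pattern, with hypothetical names (apply_one, example_kernel) that are not part of the patch:

// The kernel owns a per-thread copy of the functor and hands it to a helper
// by non-const reference; a stateful functor can mutate its own members.
template <typename func_t>
__device__ inline void apply_one(func_t& f, float* out, const float* in, int i) {
  out[i] = f(in[i]);
}

template <typename func_t>
__global__ void example_kernel(func_t f, float* out, const float* in, int n) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < n) {
    apply_one(f, out, in, idx);
  }
}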
@@ -218,7 +218,7 @@ __global__ void unrolled_elementwise_kernel(int N, func_t f, array_t data, inp_c
 
 // this function assume trivial 1d and no dynamic casting
 template <typename func_t, typename array_t>
-static inline void launch_vectorized_kernel(int64_t N, const func_t& f, array_t data) {
+static inline void launch_vectorized_kernel(int64_t N, func_t& f, array_t data) {
   TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits<int32_t>::max());
   using traits = function_traits<func_t>;
   int64_t grid = (N + block_work_size - 1) / block_work_size;
@@ -256,7 +256,7 @@ static inline void launch_unrolled_kernel(int64_t N, const func_t& f, array_t da
 
 
 template <typename func_t>
-void gpu_kernel_impl(TensorIterator& iter, const func_t& f) {
+void gpu_kernel_impl(TensorIterator& iter, func_t f) {
   using traits = function_traits<func_t>;
   using arg0_t = typename traits::result_type;
   constexpr int ntensors = traits::arity + 1;
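Passing f by value gives gpu_kernel_impl its own copy of the functor, which the [=] lambdas further down can in turn capture and mutate without touching the caller's object. A small host-side sketch of that by-value/capture pattern, with hypothetical names (dispatch, Stateful) that are not the real API:

template <typename func_t>
void dispatch(func_t f) {                           // by value: a local, independent copy
  auto body = [=](int i) mutable { return f(i); };  // the lambda captures a copy of that copy
  for (int i = 0; i < 4; ++i) {
    body(i);                                        // mutations stay inside `body`
  }
}

struct Stateful {
  int seen = 0;
  int operator()(int x) { return x + seen++; }  // non-const call operator
};

int main() {
  Stateful s;
  dispatch(s);  // s itself is left untouched; only the copies are mutated
}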
@@ -300,28 +300,28 @@ void gpu_kernel_impl(TensorIterator& iter, const func_t& f) {
     }
 
     if (needs_dynamic_casting<func_t>::check(iter)) {
-      legacy::launch_kernel<launch_size_1d, 1>(numel, [=]GPU_LAMBDA(int idx) {
+      legacy::launch_kernel<launch_size_1d, 1>(numel, [=]GPU_LAMBDA(int idx) mutable {
         void* out = data[0] + strides[0] * idx;
         arg0_t result = legacy::invoke(f, &data.data[1], &strides.data[1], &dtypes.data[1], idx);
         c10::cast_and_store<arg0_t>(dtypes[0], out, result);
       });
     } else {
-      legacy::launch_kernel<launch_size_1d, 1>(numel, [=]GPU_LAMBDA(int idx) {
+      legacy::launch_kernel<launch_size_1d, 1>(numel, [=]GPU_LAMBDA(int idx) mutable {
         arg0_t* out = (arg0_t*)(data[0] + strides[0] * idx);
         *out = legacy::invoke(f, &data.data[1], &strides.data[1], idx);
       });
     }
   } else {
     auto offset_calc = legacy::make_offset_calculator<traits::arity + 1>(iter);
     if (needs_dynamic_casting<func_t>::check(iter)) {
-      legacy::launch_kernel<launch_size_nd, launch_bound2>(numel, [=]GPU_LAMBDA(int idx) {
+      legacy::launch_kernel<launch_size_nd, launch_bound2>(numel, [=]GPU_LAMBDA(int idx) mutable {
         auto offsets = offset_calc.get(idx);
         void* out = data[0] + offsets[0];
         arg0_t result = legacy::invoke(f, &data.data[1], &offsets.data[1], &dtypes.data[1], 1);
         c10::cast_and_store<arg0_t>(dtypes[0], out, result);
       });
     } else {
-      legacy::launch_kernel<launch_size_nd, launch_bound2>(numel, [=]GPU_LAMBDA(int idx) {
+      legacy::launch_kernel<launch_size_nd, launch_bound2>(numel, [=]GPU_LAMBDA(int idx) mutable {
         auto offsets = offset_calc.get(idx);
         arg0_t* out = (arg0_t*)(data[0] + offsets[0]);
         *out = legacy::invoke(f, &data.data[1], &offsets.data[1], 1);
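Marking the [=] lambdas mutable matters because a non-mutable lambda's captured copies are const inside its call operator, so a functor whose operator() is non-const could not be invoked through the captured f. A standalone sketch of that rule (hypothetical Counter/call, not PyTorch code):

struct Counter {
  int n = 0;
  int operator()(int x) { return x + n++; }  // non-const call operator
};

template <typename func_t>
int call(func_t& f, int x) { return f(x); }  // like the updated legacy::invoke: non-const reference

int main() {
  Counter c;
  // auto bad = [=](int i) { return call(c, i); };        // error: the captured copy of c is
  //                                                      // const, so operator() can't be called
  auto good = [=](int i) mutable { return call(c, i); };  // ok: the copy is mutable inside
  return good(1);
}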