forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
ForeachBinaryOpList.cu
479 lines (445 loc) · 19.3 KB
/
ForeachBinaryOpList.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/Dispatch.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/native/ForeachUtils.h>
#include <ATen/native/cuda/ForeachFunctors.cuh>
#include <ATen/native/cuda/ForeachMinMaxFunctors.cuh>
#include <functional>
#include <type_traits>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_foreach_add_native.h>
#include <ATen/ops/_foreach_clamp_max_native.h>
#include <ATen/ops/_foreach_clamp_min_native.h>
#include <ATen/ops/_foreach_copy_native.h>
#include <ATen/ops/_foreach_div_native.h>
#include <ATen/ops/_foreach_mul_native.h>
#include <ATen/ops/_foreach_pow_native.h>
#include <ATen/ops/_foreach_sub_native.h>
#include <ATen/ops/empty_like_native.h>
#endif
namespace at::native {
// Out-of-place fast path for a binary op applied pairwise to two tensor
// lists. One result tensor per element of `tensors1` is allocated with
// at::native::empty_like, then a single multi_tensor_apply is launched over
// the three lists {tensors1, tensors2, results} (result list at index 2).
//
// T     - element type selected by the caller's AT_DISPATCH.
// Op    - binary functor template, instantiated at opmath_t precision.
// alpha - converted to opmath_t and forwarded to BinaryOpListAlphaFunctor
//         (presumably scales the second operand as for add/sub with alpha;
//         confirm in ForeachFunctors.cuh).
// Returns the freshly allocated result tensors.
template <typename T, template <class> class Op>
std::vector<Tensor> foreach_tensor_list_op(
    TensorList tensors1,
    TensorList tensors2,
    const Scalar& alpha = 1) {
  std::vector<at::Tensor> vec_res;
  vec_res.reserve(tensors1.size());
  for (const auto& t : tensors1) {
    vec_res.emplace_back(at::native::empty_like(t));
  }

  std::vector<std::vector<at::Tensor>> tensor_lists;
  tensor_lists.reserve(3);
  tensor_lists.emplace_back(tensors1.vec());
  tensor_lists.emplace_back(tensors2.vec());
  tensor_lists.emplace_back(std::move(vec_res));

  using opmath_t = at::opmath_type<T>;
  multi_tensor_apply<3>(
      tensor_lists,
      BinaryOpListAlphaFunctor<
          T,
          /* depth */ 3,
          /* r_args_depth */ 2,
          /* res_arg_index */ 2>(),
      Op<opmath_t>(),
      alpha.to<opmath_t>());

  // `tensor_lists` is a local that dies here; move the results out instead of
  // copying the vector (which would re-allocate and bump every refcount).
  return std::move(tensor_lists[2]);
}
// In-place fast path for a binary op applied pairwise to two tensor lists.
// Only two lists are handed to multi_tensor_apply; res_arg_index == 0 makes
// the functor write its result back into `tensors1`.
//
// T     - element type selected by the caller's AT_DISPATCH.
// Op    - binary functor template, instantiated at opmath_t precision.
// alpha - converted to opmath_t and forwarded to BinaryOpListAlphaFunctor.
template <typename T, template <class> class Op>
void foreach_tensor_list_op_(
    TensorList tensors1,
    TensorList tensors2,
    const Scalar& alpha = 1) {
  using opmath_t = at::opmath_type<T>;
  using Functor = BinaryOpListAlphaFunctor<
      T,
      /* depth */ 2,
      /* r_args_depth */ 2,
      /* res_arg_index */ 0>;

  std::vector<std::vector<at::Tensor>> lists;
  lists.reserve(2);
  lists.emplace_back(tensors1.vec());
  lists.emplace_back(tensors2.vec());

  const auto alpha_opmath = alpha.to<opmath_t>();
  multi_tensor_apply<2>(lists, Functor(), Op<opmath_t>(), alpha_opmath);

  // `tensors1` was mutated in place; record that via increment_version.
  increment_version(tensors1);
}
// Out-of-place dtype dispatch keyed on tensors1[0]'s scalar type. It covers
// the ALL_TYPES_AND_COMPLEX set plus bool, bfloat16 and half, and forwards to
// foreach_tensor_list_op for the matched element type.
template <template <class> class Op>
std::vector<Tensor> all_types_complex_bool_half_bfloat16(
    TensorList tensors1,
    TensorList tensors2,
    const Scalar& alpha = 1) {
  const auto dispatch_dtype = tensors1[0].scalar_type();
  return AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
      kBool, kBFloat16, kHalf, dispatch_dtype, "foreach_binary_op_list_cuda", [&] {
        return foreach_tensor_list_op<scalar_t, Op>(tensors1, tensors2, alpha);
      });
}
// In-place dtype dispatch keyed on tensors1[0]'s scalar type. It covers the
// ALL_TYPES_AND_COMPLEX set plus bool, bfloat16 and half, and forwards to
// foreach_tensor_list_op_ for the matched element type.
template <template <class> class Op>
void all_types_complex_bool_half_bfloat16_(
    TensorList tensors1,
    TensorList tensors2,
    const Scalar& alpha = 1) {
  const auto dispatch_dtype = tensors1[0].scalar_type();
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
      kBool, kBFloat16, kHalf, dispatch_dtype, "foreach_binary_op_list_cuda_", [&] {
        foreach_tensor_list_op_<scalar_t, Op>(tensors1, tensors2, alpha);
      });
}
// Out-of-place dtype dispatch keyed on tensors1[0]'s scalar type. It covers
// the ALL_TYPES set plus bfloat16 and half (no complex, no bool), and
// forwards to foreach_tensor_list_op for the matched element type.
template <template <class> class Op>
std::vector<Tensor> all_types_half_bfloat16(
    TensorList tensors1,
    TensorList tensors2,
    const Scalar& alpha = 1) {
  const auto dispatch_dtype = tensors1[0].scalar_type();
  return AT_DISPATCH_ALL_TYPES_AND2(
      kBFloat16, kHalf, dispatch_dtype, "foreach_binary_op_list_cuda", [&] {
        return foreach_tensor_list_op<scalar_t, Op>(tensors1, tensors2, alpha);
      });
}
// In-place dtype dispatch keyed on tensors1[0]'s scalar type. It covers the
// ALL_TYPES_AND_COMPLEX set plus bfloat16 and half (no bool), and forwards to
// foreach_tensor_list_op_ for the matched element type.
template <template <class> class Op>
void all_types_complex_half_bfloat16_(
    TensorList tensors1,
    TensorList tensors2,
    const Scalar& alpha = 1) {
  const auto dispatch_dtype = tensors1[0].scalar_type();
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(
      kBFloat16, kHalf, dispatch_dtype, "foreach_binary_op_list_cuda_", [&] {
        foreach_tensor_list_op_<scalar_t, Op>(tensors1, tensors2, alpha);
      });
}
// In-place dtype dispatch keyed on tensors1[0]'s scalar type. It covers the
// ALL_TYPES set plus bfloat16 and half (no complex, no bool), and forwards to
// foreach_tensor_list_op_ for the matched element type.
template <template <class> class Op>
void all_types_half_bfloat16_(
    TensorList tensors1,
    TensorList tensors2,
    const Scalar& alpha = 1) {
  const auto dispatch_dtype = tensors1[0].scalar_type();
  AT_DISPATCH_ALL_TYPES_AND2(
      kBFloat16, kHalf, dispatch_dtype, "foreach_binary_op_list_cuda_", [&] {
        foreach_tensor_list_op_<scalar_t, Op>(tensors1, tensors2, alpha);
      });
}
// Out-of-place dtype dispatch keyed on tensors1[0]'s scalar type. It covers
// the ALL_TYPES_AND_COMPLEX set plus bfloat16 and half (no bool), and
// forwards to foreach_tensor_list_op for the matched element type.
template <template <class> class Op>
std::vector<Tensor> all_types_complex_half_bfloat16(
    TensorList tensors1,
    TensorList tensors2,
    const Scalar& alpha = 1) {
  const auto dispatch_dtype = tensors1[0].scalar_type();
  return AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(
      kBFloat16, kHalf, dispatch_dtype, "foreach_binary_op_list_cuda", [&] {
        return foreach_tensor_list_op<scalar_t, Op>(tensors1, tensors2, alpha);
      });
}
// Defines two entry points for a list-list binary op without alpha:
//   foreach_tensor_<NAME>_list_kernel_cuda_ (in place) and
//   foreach_tensor_<NAME>_list_kernel_cuda  (out of place).
// Both validate the lists with check_foreach_api_restrictions. They take the
// multi_tensor_apply fast path when can_use_fast_route accepts the inputs and
// otherwise fall through to the matching `_slow` implementation.
//   FUNCTION    - dtype-dispatch wrapper; the in-place form is FUNCTION##_.
//   NAME        - op name spliced into the generated function names.
//   OP          - binary functor template handed to FUNCTION.
//   DIVISION_OP - passed as the third argument of can_use_fast_route.
#define FOREACH_BINARY_OP_LIST(FUNCTION, NAME, OP, DIVISION_OP)          \
  void foreach_tensor_##NAME##_list_kernel_cuda_(                        \
      TensorList tensors1, TensorList tensors2) {                        \
    check_foreach_api_restrictions(tensors1, tensors2);                  \
    if (can_use_fast_route(tensors1, tensors2, DIVISION_OP)) {           \
      FUNCTION##_<OP>(tensors1, tensors2);                               \
      return;                                                            \
    }                                                                    \
    at::native::foreach_tensor_##NAME##_list_kernel_slow_(               \
        tensors1, tensors2);                                             \
  }                                                                      \
                                                                         \
  std::vector<Tensor> foreach_tensor_##NAME##_list_kernel_cuda(          \
      TensorList tensors1, TensorList tensors2) {                        \
    check_foreach_api_restrictions(tensors1, tensors2);                  \
    if (can_use_fast_route(tensors1, tensors2, DIVISION_OP)) {           \
      return FUNCTION<OP>(tensors1, tensors2);                           \
    }                                                                    \
    return at::native::foreach_tensor_##NAME##_list_kernel_slow(         \
        tensors1, tensors2);                                             \
  }
// Alpha-taking variant of the list-list entry-point generator (used for
// add/sub). It defines foreach_tensor_<NAME>_list_kernel_cuda_ (in place) and
// foreach_tensor_<NAME>_list_kernel_cuda (out of place). `alpha` is forwarded
// to can_use_fast_route, to the FUNCTION dispatch wrapper and to the `_slow`
// fallback.
//   FUNCTION - dtype-dispatch wrapper; the in-place form is FUNCTION##_.
//   NAME     - op name spliced into the generated function names.
//   OP       - binary functor template handed to FUNCTION.
#define FOREACH_BINARY_OP_LIST_ALPHA(FUNCTION, NAME, OP)                     \
  void foreach_tensor_##NAME##_list_kernel_cuda_(                            \
      TensorList tensors1, TensorList tensors2, const Scalar& alpha) {       \
    check_foreach_api_restrictions(tensors1, tensors2);                      \
    if (can_use_fast_route({tensors1, tensors2}, alpha)) {                   \
      FUNCTION##_<OP>(tensors1, tensors2, alpha);                            \
      return;                                                                \
    }                                                                        \
    at::native::foreach_tensor_##NAME##_list_kernel_slow_(                   \
        tensors1, tensors2, alpha);                                          \
  }                                                                          \
                                                                             \
  std::vector<Tensor> foreach_tensor_##NAME##_list_kernel_cuda(              \
      TensorList tensors1, TensorList tensors2, const Scalar& alpha) {       \
    check_foreach_api_restrictions(tensors1, tensors2);                      \
    if (can_use_fast_route({tensors1, tensors2}, alpha)) {                   \
      return FUNCTION<OP>(tensors1, tensors2, alpha);                        \
    }                                                                        \
    return at::native::foreach_tensor_##NAME##_list_kernel_slow(             \
        tensors1, tensors2, alpha);                                          \
  }
// Arithmetic ops. add/sub take an `alpha` scalar; mul/div do not.
FOREACH_BINARY_OP_LIST_ALPHA(all_types_complex_bool_half_bfloat16, add, std::plus);
FOREACH_BINARY_OP_LIST_ALPHA(all_types_complex_bool_half_bfloat16, sub, std::minus);
FOREACH_BINARY_OP_LIST(
    all_types_complex_bool_half_bfloat16, mul, std::multiplies, /*division_op*/ false);
FOREACH_BINARY_OP_LIST(
    all_types_complex_bool_half_bfloat16, div, std::divides, /*division_op*/ true);

// NOTE(crcrpar): `all_types_half_bfloat16` does not dispatch on bool, so
// `division_op` is temporarily set to true for clamp_max/clamp_min. This
// presumably keeps bool inputs off the fast path (the same mechanism the pow
// note below relies on for integer inputs).
FOREACH_BINARY_OP_LIST(all_types_half_bfloat16, clamp_max, minimum, /*division_op*/ true);
FOREACH_BINARY_OP_LIST(all_types_half_bfloat16, clamp_min, maximum, /*division_op*/ true);

// NOTE(crcrpar): [Why is foreach_pow's division_op=true?]
// It routes integer inputs to the slow path, because with integer inputs the
// fast path's results differ from the slow path's.
// TODO: investigate the discrepancy.
FOREACH_BINARY_OP_LIST(
    all_types_complex_half_bfloat16, pow, power_functor, /*division_op*/ true);
// Per-element conversion functor used by the foreach copy kernels. It
// converts one `src_t` value to `dst_t` with static_cast and is callable from
// device code only. src_t defaults to dst_t (plain same-type copy).
template <typename dst_t, typename src_t = dst_t>
struct Copy {
  __device__ __forceinline__ dst_t operator()(const src_t& value) {
    dst_t converted = static_cast<dst_t>(value);
    return converted;
  }
};
// Conversion from a complex source element. The complex source dtypes
// dispatched for foreach copy are ComplexFloat and ComplexDouble.
//   * complex -> complex<float>/complex<double>: static_cast keeps both parts.
//   * complex -> any other type: only the real part is kept, matching the
//     "Casting complex values to real discards the imaginary part" warning
//     the copy kernel emits.
// A single partial specialization over the component type replaces the two
// previously duplicated specializations for complex<float> and
// complex<double>.
template <typename dst_t, typename value_t>
struct Copy<dst_t, c10::complex<value_t>> {
  __device__ __forceinline__ dst_t
  operator()(const c10::complex<value_t>& x) {
    if constexpr (
        std::is_same_v<dst_t, c10::complex<double>> ||
        std::is_same_v<dst_t, c10::complex<float>>) {
      return static_cast<dst_t>(x);
    } else {
      // Complex -> real: drop the imaginary component.
      return static_cast<dst_t>(x.real());
    }
  }
};
// Dtype switch over every source dtype accepted by the foreach copy fast
// path. Inside the dispatched callable (__VA_ARGS__), the matched C++ type is
// available as `src_t`. Cases are grouped as: integral types, bool, floating
// types (double, float, half, bfloat16), the two complex types, and the four
// float8 variants.
#define AT_DISPATCH_SOURCE_TYPES(TYPE, NAME, ...)                            \
  AT_DISPATCH_SWITCH(                                                        \
      TYPE,                                                                  \
      NAME,                                                                  \
      AT_PRIVATE_CASE_TYPE_USING_HINT(                                       \
          at::ScalarType::Byte, src_t, __VA_ARGS__)                          \
      AT_PRIVATE_CASE_TYPE_USING_HINT(                                       \
          at::ScalarType::Char, src_t, __VA_ARGS__)                          \
      AT_PRIVATE_CASE_TYPE_USING_HINT(                                       \
          at::ScalarType::Short, src_t, __VA_ARGS__)                         \
      AT_PRIVATE_CASE_TYPE_USING_HINT(                                       \
          at::ScalarType::Int, src_t, __VA_ARGS__)                           \
      AT_PRIVATE_CASE_TYPE_USING_HINT(                                       \
          at::ScalarType::Long, src_t, __VA_ARGS__)                          \
      AT_PRIVATE_CASE_TYPE_USING_HINT(                                       \
          at::ScalarType::Bool, src_t, __VA_ARGS__)                          \
      AT_PRIVATE_CASE_TYPE_USING_HINT(                                       \
          at::ScalarType::Double, src_t, __VA_ARGS__)                        \
      AT_PRIVATE_CASE_TYPE_USING_HINT(                                       \
          at::ScalarType::Float, src_t, __VA_ARGS__)                         \
      AT_PRIVATE_CASE_TYPE_USING_HINT(                                       \
          at::ScalarType::Half, src_t, __VA_ARGS__)                          \
      AT_PRIVATE_CASE_TYPE_USING_HINT(                                       \
          at::ScalarType::BFloat16, src_t, __VA_ARGS__)                      \
      AT_PRIVATE_CASE_TYPE_USING_HINT(                                       \
          at::ScalarType::ComplexDouble, src_t, __VA_ARGS__)                 \
      AT_PRIVATE_CASE_TYPE_USING_HINT(                                       \
          at::ScalarType::ComplexFloat, src_t, __VA_ARGS__)                  \
      AT_PRIVATE_CASE_TYPE_USING_HINT(                                       \
          at::ScalarType::Float8_e4m3fn, src_t, __VA_ARGS__)                 \
      AT_PRIVATE_CASE_TYPE_USING_HINT(                                       \
          at::ScalarType::Float8_e4m3fnuz, src_t, __VA_ARGS__)               \
      AT_PRIVATE_CASE_TYPE_USING_HINT(                                       \
          at::ScalarType::Float8_e5m2, src_t, __VA_ARGS__)                   \
      AT_PRIVATE_CASE_TYPE_USING_HINT(                                       \
          at::ScalarType::Float8_e5m2fnuz, src_t, __VA_ARGS__))
namespace {
// multi_tensor_apply functor that copies list 0 (elements of `src_t`) into
// list 1 (elements of `T`), applying `op` to each element. It is used by the
// foreach copy kernel when the source and destination dtypes differ.
//
// Work layout, as read from TensorListMetadata: block `blockIdx.x` processes
// chunk `block_to_chunk[blockIdx.x]` (of `chunk_size` elements) of tensor
// `block_to_tensor[blockIdx.x]`. Each thread handles kILP elements per step.
template <
    typename T,
    typename src_t,
    int depth,
    int r_args_depth,
    int res_arg_index>
struct CopyFunctor {
  // Only the {src, self} -> self layout is implemented.
  static_assert(depth == 2 && r_args_depth == 1 && res_arg_index == 1);

  template <typename Op>
  __device__ __forceinline__ void operator()(
      int chunk_size,
      TensorListMetadata<depth>& tl,
      Op op) {
    const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
    const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
    auto n = tl.numel_for_tensor[tensor_loc];

    src_t* src_ptr = (src_t*)tl.addresses[0][tensor_loc];
    src_ptr += chunk_idx * chunk_size;
    T* self_ptr = (T*)tl.addresses[1][tensor_loc];
    self_ptr += chunk_idx * chunk_size;

    const bool all_aligned{is_aligned(src_ptr) && is_aligned(self_ptr)};
    // Number of elements remaining from the start of this chunk.
    n -= chunk_idx * chunk_size;

    src_t src_args[kILP];
    T r_args[kILP];

    // To keep things simple, the aligned case uses a separate code path.
    if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) {
      // Vectorized path: each thread moves kILP contiguous elements per step.
      for (int64_t i_start = threadIdx.x;
           i_start * kILP < n && i_start * kILP < chunk_size;
           i_start += blockDim.x) {
        // load
        load_store(src_args, src_ptr, 0, i_start);
#pragma unroll
        for (int ii = 0; ii < kILP; ii++) {
          r_args[ii] = static_cast<T>(op(src_args[ii]));
        }
        // store
        load_store(self_ptr, r_args, i_start, 0);
      }
    } else {
      // Element-wise path: lane `ii` of a thread touches element
      // i_start + threadIdx.x + ii * blockDim.x of the chunk.
      for (int64_t i_start = 0; i_start < n && i_start < chunk_size;
           i_start += blockDim.x * kILP) {
#pragma unroll
        for (int ii = 0; ii < kILP; ii++) {
          const auto i = i_start + threadIdx.x + ii * blockDim.x;
          // Value-initialize every lane first. Otherwise out-of-range lanes
          // would hand an indeterminate value to `op` below, and reading or
          // converting one (e.g. a garbage float cast to an integer type) is
          // undefined behavior. store_args receives n/chunk_size and
          // presumably skips those lanes.
          src_args[ii] = src_t{};
          if (i < n && i < chunk_size) {
            src_args[ii] = src_ptr[i];
          }
        }
#pragma unroll
        for (int ii = 0; ii < kILP; ii++) {
          r_args[ii] = static_cast<T>(op(src_args[ii]));
        }
        store_args(self_ptr, r_args, i_start, chunk_size, n);
      }
    }
  }
};
} // anonymous namespace
// In-place `_foreach_copy_` CUDA kernel: copies src[i] into self[i] for every
// i. The multi_tensor_apply fast path is taken only when:
//   * all tensors of both lists pass _check_tensors_share_device_and_dtype
//     (called with skip_dtype_check, because self and src may differ),
//   * every `self` tensor has the dtype of self[0],
//   * every `src` tensor has the dtype of src[0], and
//   * the lists pass _check_tensors_share_sizes_and_strides.
// Otherwise the work is deferred to foreach_tensor_copy_list_kernel_slow_.
// `non_blocking` is only consulted by that slow path.
void foreach_tensor_copy_list_kernel_cuda_(
    TensorList self,
    TensorList src,
    const bool non_blocking) {
  check_foreach_api_restrictions(self, src);

  // True when every tensor in `tensors` has the dtype of tensors[0]. For an
  // empty list the predicate is never evaluated.
  const auto has_uniform_dtype = [](TensorList tensors) -> bool {
    return std::all_of(
        tensors.cbegin(), tensors.cend(), [&](const auto& t) -> bool {
          return t.dtype() == tensors[0].dtype();
        });
  };

  // The fast path dispatches once on self[0]'s dtype and once on src[0]'s
  // dtype, and the functors cast every tensor's data pointer to those element
  // types. Both lists must therefore be dtype-homogeneous. A mixed-dtype
  // `self` list would otherwise be written with the wrong element size.
  const bool use_fast_path =
      _check_tensors_share_device_and_dtype(
          {self, src}, /* skip_dtype_check */ true) &&
      has_uniform_dtype(self) && has_uniform_dtype(src) &&
      _check_tensors_share_sizes_and_strides({self, src});
  if (!use_fast_path) {
    return at::native::foreach_tensor_copy_list_kernel_slow_(
        self, src, non_blocking);
  }

  // multi_tensor_apply layout: list 0 is the source, list 1 the destination.
  std::vector<std::vector<at::Tensor>> tensor_lists{src.vec(), self.vec()};
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND7(
      ScalarType::Half,
      ScalarType::BFloat16,
      ScalarType::Bool,
      ScalarType::Float8_e4m3fn,
      ScalarType::Float8_e4m3fnuz,
      ScalarType::Float8_e5m2,
      ScalarType::Float8_e5m2fnuz,
      self[0].scalar_type(),
      "foreach_tensor_copy",
      [&]() {
        using opmath_t = at::opmath_type<scalar_t>;
        AT_DISPATCH_SOURCE_TYPES(src[0].scalar_type(), "foreach_tensor_copy", [&] {
          if constexpr (std::is_same_v<scalar_t, src_t>) {
            // Same dtype on both sides: plain element-wise copy.
            multi_tensor_apply<2>(
                tensor_lists,
                UnaryOpFunctor<
                    scalar_t,
                    /* depth */ 2,
                    /* r_args_depth */ 1,
                    /* res_arg_index */ 1>(),
                Copy<opmath_t, opmath_t>());
          } else {
            // Ref:
            // https://github.com/pytorch/pytorch/blob/656134c38f4737d13c3f43fc5c59470bc23c1d2f/aten/src/ATen/native/Copy.cpp#L299-L301
            if (!self[0].is_complex() && src[0].is_complex()) {
              TORCH_WARN_ONCE(
                  "Casting complex values to real discards the imaginary part");
            }
            multi_tensor_apply<2>(
                tensor_lists,
                CopyFunctor<
                    scalar_t,
                    src_t,
                    /* depth */ 2,
                    /* r_args_depth */ 1,
                    /* res_arg_index */ 1>(),
                Copy<scalar_t, src_t>());
          }
        });
      });
  increment_version(self);
}

#undef AT_DISPATCH_SOURCE_TYPES
} // namespace at::native