-
Notifications
You must be signed in to change notification settings - Fork 3
/
THCNumerics.cuh
349 lines (304 loc) · 19.3 KB
/
THCNumerics.cuh
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
#ifndef THC_NUMERICS_INC
#define THC_NUMERICS_INC
#include <cstdlib>
#include <limits>
#include <cuda.h>
#include <assert.h>
#include <TH/THHalf.h>
#include <ATen/ATen.h>
#include <ATen/cuda/NumericLimits.cuh>
// WARNING: THCNumerics is being deprecated. Please follow the comments
// in this file to learn about the replacement usages.
// Comments on usage:
// - lt, le, gt, ge, eq, add, mul, sub, div and other binary ops (as well as
//   unary ops such as neg) can be implemented using CUDA_apply_utils or a
//   binary CUDA kernel.
// - Check NumericLimits.cuh for specialized math functions.
// - Note how __half and at::Half can be cast, for instance:
//   static_cast<at::Half>(std::sin(static_cast<at::Half>(a)));
// Primary template, intentionally left without members. Only the explicit
// specializations in this header supply the numeric helpers. Naming a member
// of THCNumerics<T> for any other T therefore fails to compile.
template <typename T>
struct THCNumerics {};
// Integer exponentiation by squaring: returns `a` raised to the power `b`.
//
// Preconditions:
//   - T is an integral type (the loop uses `b & 1`).
//   - b >= 0 (checked with assert; negative exponents are not supported).
// The result is accumulated in T, so it overflows exactly as repeated
// multiplication in T would.
template <typename T>
static inline __host__ __device__ T powi(T a, T b) {
  // Same check as THCNumerics<T>::ge(b, 0) for every integral specialization,
  // written directly so this helper does not depend on THCNumerics.
  assert(b >= 0);
  T result = 1;
  while (b) {
    if (b & 1) {
      result *= a;
    }
    b /= 2;
    // Square the base only while exponent bits remain. Squaring after the
    // last bit is wasted work and can overflow (undefined behavior for signed
    // T) even when the final result is representable, e.g.
    // powi<int32_t>(50000, 1).
    if (b) {
      a *= a;
    }
  }
  return result;
}
// DEPRECATED: For integral types, use math functions from std and NumericLimits.cuh.
// Use binary_kernel or CUDA_apply_utils for arithmetic.
//
// THCNumerics specialization for uint8_t.
// - Limits forward to at::numeric_limits<uint8_t>.
// - add/sub/mul are evaluated after integer promotion and converted back to
//   uint8_t on return, i.e. they wrap modulo 256.
// - div does not check for a zero divisor.
template <>
struct THCNumerics<uint8_t> {
  // Limits.
  static inline __host__ __device__ uint8_t min() { return at::numeric_limits<uint8_t>::lowest(); }
  static inline __host__ __device__ uint8_t max() { return at::numeric_limits<uint8_t>::max(); }
  static inline __host__ __device__ uint8_t lower_bound() { return at::numeric_limits<uint8_t>::lower_bound(); }
  static inline __host__ __device__ uint8_t upper_bound() { return at::numeric_limits<uint8_t>::upper_bound(); }

  // Comparisons.
  static inline __host__ __device__ bool lt(uint8_t a, uint8_t b) { return a < b; }
  static inline __host__ __device__ bool le(uint8_t a, uint8_t b) { return a <= b; }
  static inline __host__ __device__ bool gt(uint8_t a, uint8_t b) { return a > b; }
  static inline __host__ __device__ bool ge(uint8_t a, uint8_t b) { return a >= b; }
  static inline __host__ __device__ bool eq(uint8_t a, uint8_t b) { return a == b; }
  static inline __host__ __device__ bool ne(uint8_t a, uint8_t b) { return !(a == b); }

  // Arithmetic: computed in int, then narrowed back to uint8_t explicitly.
  static inline __host__ __device__ uint8_t add(uint8_t a, uint8_t b) { return static_cast<uint8_t>(a + b); }
  static inline __host__ __device__ uint8_t mul(uint8_t a, uint8_t b) { return static_cast<uint8_t>(a * b); }
  static inline __host__ __device__ uint8_t sub(uint8_t a, uint8_t b) { return static_cast<uint8_t>(a - b); }
  static inline __host__ __device__ uint8_t div(uint8_t a, uint8_t b) { return static_cast<uint8_t>(a / b); }
  static inline __host__ __device__ uint8_t pow(uint8_t a, uint8_t b) { return powi<uint8_t>(a, b); }

  // Classification: integers are never NaN or infinite.
  static inline __host__ __device__ bool isnan(uint8_t /*a*/) { return false; }
  static inline __host__ __device__ bool isinf(uint8_t /*a*/) { return false; }
};
#ifdef _MSC_VER
// Suppress warning C4804: '/': unsafe use of type 'bool' in operation
#pragma warning( push )
#pragma warning( disable : 4804 )
#endif
// THCNumerics specialization for bool.
// Arithmetic matches C++ bool -> int promotion followed by conversion back to
// bool:
//   add == logical OR, mul == logical AND, sub == logical XOR.
// NOTE(review): div(a, false) is an integer division by zero (undefined
// behavior). Callers presumably never pass a false divisor; confirm.
template <>
struct THCNumerics<bool> {
  // Limits.
  static inline __host__ __device__ bool min() { return at::numeric_limits<bool>::lowest(); }
  static inline __host__ __device__ bool max() { return at::numeric_limits<bool>::max(); }
  static inline __host__ __device__ bool lower_bound() { return at::numeric_limits<bool>::lower_bound(); }
  static inline __host__ __device__ bool upper_bound() { return at::numeric_limits<bool>::upper_bound(); }

  // Comparisons (false < true).
  static inline __host__ __device__ bool lt(bool a, bool b) { return a < b; }
  static inline __host__ __device__ bool le(bool a, bool b) { return a <= b; }
  static inline __host__ __device__ bool gt(bool a, bool b) { return a > b; }
  static inline __host__ __device__ bool ge(bool a, bool b) { return a >= b; }
  static inline __host__ __device__ bool eq(bool a, bool b) { return a == b; }
  static inline __host__ __device__ bool ne(bool a, bool b) { return a != b; }

  // Arithmetic written as the equivalent logical operations.
  static inline __host__ __device__ bool add(bool a, bool b) { return a || b; }  // (a + b) != 0
  static inline __host__ __device__ bool mul(bool a, bool b) { return a && b; }
  static inline __host__ __device__ bool sub(bool a, bool b) { return a != b; }  // (a - b) != 0
  // Kept as integer division to preserve the exact original semantics.
  static inline __host__ __device__ bool div(bool a, bool b) { return a / b; }

  // Classification: bools are never NaN or infinite.
  static inline __host__ __device__ bool isnan(bool /*a*/) { return false; }
  static inline __host__ __device__ bool isinf(bool /*a*/) { return false; }
};
#ifdef _MSC_VER
#pragma warning( pop )
#endif
// THCNumerics specialization for int8_t.
// - Limits forward to at::numeric_limits<int8_t>.
// - add/sub/mul/div are evaluated after integer promotion and narrowed back
//   to int8_t on return.
// - Overflow and division by zero are not checked.
template <>
struct THCNumerics<int8_t> {
  // Limits.
  static inline __host__ __device__ int8_t min() { return at::numeric_limits<int8_t>::lowest(); }
  static inline __host__ __device__ int8_t max() { return at::numeric_limits<int8_t>::max(); }
  static inline __host__ __device__ int8_t lower_bound() { return at::numeric_limits<int8_t>::lower_bound(); }
  static inline __host__ __device__ int8_t upper_bound() { return at::numeric_limits<int8_t>::upper_bound(); }

  // Comparisons.
  static inline __host__ __device__ bool lt(int8_t a, int8_t b) { return a < b; }
  static inline __host__ __device__ bool le(int8_t a, int8_t b) { return a <= b; }
  static inline __host__ __device__ bool gt(int8_t a, int8_t b) { return a > b; }
  static inline __host__ __device__ bool ge(int8_t a, int8_t b) { return a >= b; }
  static inline __host__ __device__ bool eq(int8_t a, int8_t b) { return a == b; }
  static inline __host__ __device__ bool ne(int8_t a, int8_t b) { return !(a == b); }

  // Arithmetic: computed in int, then narrowed back to int8_t explicitly.
  static inline __host__ __device__ int8_t add(int8_t a, int8_t b) { return static_cast<int8_t>(a + b); }
  static inline __host__ __device__ int8_t mul(int8_t a, int8_t b) { return static_cast<int8_t>(a * b); }
  static inline __host__ __device__ int8_t sub(int8_t a, int8_t b) { return static_cast<int8_t>(a - b); }
  static inline __host__ __device__ int8_t div(int8_t a, int8_t b) { return static_cast<int8_t>(a / b); }
  static inline __host__ __device__ int8_t pow(int8_t a, int8_t b) { return powi<int8_t>(a, b); }

  // Classification: integers are never NaN or infinite.
  static inline __host__ __device__ bool isnan(int8_t /*a*/) { return false; }
  static inline __host__ __device__ bool isinf(int8_t /*a*/) { return false; }
};
// THCNumerics specialization for int16_t.
// - Limits forward to at::numeric_limits<int16_t>.
// - add/sub/mul/div are evaluated after integer promotion and narrowed back
//   to int16_t on return.
// - Overflow and division by zero are not checked.
template <>
struct THCNumerics<int16_t> {
  // Limits.
  static inline __host__ __device__ int16_t min() { return at::numeric_limits<int16_t>::lowest(); }
  static inline __host__ __device__ int16_t max() { return at::numeric_limits<int16_t>::max(); }
  static inline __host__ __device__ int16_t lower_bound() { return at::numeric_limits<int16_t>::lower_bound(); }
  static inline __host__ __device__ int16_t upper_bound() { return at::numeric_limits<int16_t>::upper_bound(); }

  // Comparisons.
  static inline __host__ __device__ bool lt(int16_t a, int16_t b) { return a < b; }
  static inline __host__ __device__ bool le(int16_t a, int16_t b) { return a <= b; }
  static inline __host__ __device__ bool gt(int16_t a, int16_t b) { return a > b; }
  static inline __host__ __device__ bool ge(int16_t a, int16_t b) { return a >= b; }
  static inline __host__ __device__ bool eq(int16_t a, int16_t b) { return a == b; }
  static inline __host__ __device__ bool ne(int16_t a, int16_t b) { return !(a == b); }

  // Arithmetic: computed in int, then narrowed back to int16_t explicitly.
  static inline __host__ __device__ int16_t add(int16_t a, int16_t b) { return static_cast<int16_t>(a + b); }
  static inline __host__ __device__ int16_t mul(int16_t a, int16_t b) { return static_cast<int16_t>(a * b); }
  static inline __host__ __device__ int16_t sub(int16_t a, int16_t b) { return static_cast<int16_t>(a - b); }
  static inline __host__ __device__ int16_t div(int16_t a, int16_t b) { return static_cast<int16_t>(a / b); }
  static inline __host__ __device__ int16_t pow(int16_t a, int16_t b) { return powi<int16_t>(a, b); }

  // Classification: integers are never NaN or infinite.
  static inline __host__ __device__ bool isnan(int16_t /*a*/) { return false; }
  static inline __host__ __device__ bool isinf(int16_t /*a*/) { return false; }
};
// THCNumerics specialization for int32_t.
// - Limits forward to at::numeric_limits<int32_t>.
// - Arithmetic is plain int32_t arithmetic; signed overflow and division by
//   zero are not checked.
template <>
struct THCNumerics<int32_t> {
  // Limits.
  static inline __host__ __device__ int32_t min() { return at::numeric_limits<int32_t>::lowest(); }
  static inline __host__ __device__ int32_t max() { return at::numeric_limits<int32_t>::max(); }
  static inline __host__ __device__ int32_t lower_bound() { return at::numeric_limits<int32_t>::lower_bound(); }
  static inline __host__ __device__ int32_t upper_bound() { return at::numeric_limits<int32_t>::upper_bound(); }

  // Comparisons.
  static inline __host__ __device__ bool lt(int32_t a, int32_t b) { return a < b; }
  static inline __host__ __device__ bool le(int32_t a, int32_t b) { return a <= b; }
  static inline __host__ __device__ bool gt(int32_t a, int32_t b) { return a > b; }
  static inline __host__ __device__ bool ge(int32_t a, int32_t b) { return a >= b; }
  static inline __host__ __device__ bool eq(int32_t a, int32_t b) { return a == b; }
  static inline __host__ __device__ bool ne(int32_t a, int32_t b) { return !(a == b); }

  // Arithmetic.
  static inline __host__ __device__ int32_t add(int32_t a, int32_t b) { return a + b; }
  static inline __host__ __device__ int32_t mul(int32_t a, int32_t b) { return a * b; }
  static inline __host__ __device__ int32_t sub(int32_t a, int32_t b) { return a - b; }
  static inline __host__ __device__ int32_t div(int32_t a, int32_t b) { return a / b; }
  static inline __host__ __device__ int32_t pow(int32_t a, int32_t b) { return powi<int32_t>(a, b); }

  // Classification: integers are never NaN or infinite.
  static inline __host__ __device__ bool isnan(int32_t /*a*/) { return false; }
  static inline __host__ __device__ bool isinf(int32_t /*a*/) { return false; }
};
// THCNumerics specialization for int64_t.
// - Limits forward to at::numeric_limits<int64_t>.
// - Arithmetic is plain int64_t arithmetic; signed overflow and division by
//   zero are not checked.
template <>
struct THCNumerics<int64_t> {
  // Limits.
  static inline __host__ __device__ int64_t min() { return at::numeric_limits<int64_t>::lowest(); }
  static inline __host__ __device__ int64_t max() { return at::numeric_limits<int64_t>::max(); }
  static inline __host__ __device__ int64_t lower_bound() { return at::numeric_limits<int64_t>::lower_bound(); }
  static inline __host__ __device__ int64_t upper_bound() { return at::numeric_limits<int64_t>::upper_bound(); }

  // Comparisons.
  static inline __host__ __device__ bool lt(int64_t a, int64_t b) { return a < b; }
  static inline __host__ __device__ bool le(int64_t a, int64_t b) { return a <= b; }
  static inline __host__ __device__ bool gt(int64_t a, int64_t b) { return a > b; }
  static inline __host__ __device__ bool ge(int64_t a, int64_t b) { return a >= b; }
  static inline __host__ __device__ bool eq(int64_t a, int64_t b) { return a == b; }
  static inline __host__ __device__ bool ne(int64_t a, int64_t b) { return a != b; }

  // Arithmetic. (The stray ';' that followed div's body has been removed to
  // match the sibling specializations and keep -Wextra-semi quiet.)
  static inline __host__ __device__ int64_t add(int64_t a, int64_t b) { return a + b; }
  static inline __host__ __device__ int64_t mul(int64_t a, int64_t b) { return a * b; }
  static inline __host__ __device__ int64_t sub(int64_t a, int64_t b) { return a - b; }
  static inline __host__ __device__ int64_t div(int64_t a, int64_t b) { return a / b; }
  static inline __host__ __device__ int64_t pow(int64_t a, int64_t b) { return powi<int64_t>(a, b); }

  // Classification: integers are never NaN or infinite.
  static inline __host__ __device__ bool isnan(int64_t a) { return false; }
  static inline __host__ __device__ bool isinf(int64_t a) { return false; }
};
// DEPRECATED: use math functions from std and NumericLimits.cuh
//
// THCNumerics specialization for at::Half.
// - Limits forward to at::numeric_limits<at::Half>.
// - Comparisons and arithmetic use at::Half's own operators.
// - sqrt/atan/pow call the global-namespace math functions on the value.
template <>
struct THCNumerics<at::Half> {
  // Limits.
  static inline __host__ __device__ at::Half min() { return at::numeric_limits<at::Half>::lowest(); }
  static inline __host__ __device__ at::Half max() { return at::numeric_limits<at::Half>::max(); }
  static inline __host__ __device__ at::Half lower_bound() { return at::numeric_limits<at::Half>::lower_bound(); }
  static inline __host__ __device__ at::Half upper_bound() { return at::numeric_limits<at::Half>::upper_bound(); }

  // Comparisons.
  static inline __host__ __device__ bool lt(at::Half a, at::Half b) { return a < b; }
  static inline __host__ __device__ bool le(at::Half a, at::Half b) { return a <= b; }
  static inline __host__ __device__ bool gt(at::Half a, at::Half b) { return a > b; }
  static inline __host__ __device__ bool ge(at::Half a, at::Half b) { return a >= b; }
  static inline __host__ __device__ bool eq(at::Half a, at::Half b) { return a == b; }
  static inline __host__ __device__ bool ne(at::Half a, at::Half b) { return a != b; }

  // Unary math.
  static inline __host__ __device__ at::Half sqrt(at::Half a) { return ::sqrt(a); }
  static inline __host__ __device__ at::Half atan(at::Half a) { return ::atan(a); }

  // Binary arithmetic.
  static inline __host__ __device__ at::Half add(at::Half a, at::Half b) { return a + b; }
  static inline __host__ __device__ at::Half div(at::Half a, at::Half b) { return a / b; }
  static inline __host__ __device__ at::Half mul(at::Half a, at::Half b) { return a * b; }
  static inline __host__ __device__ at::Half sub(at::Half a, at::Half b) { return a - b; }
  static inline __host__ __device__ at::Half pow(at::Half a, at::Half b) { return ::pow(a, b); }

  // Classification.
  static inline __host__ __device__ bool isnan(at::Half a) {
#ifdef _MSC_VER
    // MSVC needs an explicit conversion to float here; the cause is unclear.
    // Related clang issue: https://reviews.llvm.org/D37906
    return ::isnan(static_cast<float>(a));
#else
    return ::isnan(a);
#endif
  }

  static inline __host__ __device__ bool isinf(at::Half a) {
#ifdef _MSC_VER
    // MSVC needs an explicit conversion to float here; the cause is unclear.
    // Related clang issue: https://reviews.llvm.org/D37906
    return ::isinf(static_cast<float>(a));
#else
    return ::isinf(a);
#endif
  }
};
// DEPRECATED: use math functions from std and cuda math API (if needed)
//
// THCNumerics specialization for float.
// - Limits forward to at::numeric_limits<float>.
// - Math uses the single-precision C functions (sqrtf, atanf, powf), so no
//   double promotion occurs.
template <>
struct THCNumerics<float> {
  // Limits.
  static inline __host__ __device__ float min() { return at::numeric_limits<float>::lowest(); }
  static inline __host__ __device__ float max() { return at::numeric_limits<float>::max(); }
  static inline __host__ __device__ float lower_bound() { return at::numeric_limits<float>::lower_bound(); }
  static inline __host__ __device__ float upper_bound() { return at::numeric_limits<float>::upper_bound(); }

  // Comparisons (IEEE semantics: every comparison with NaN is false except ne).
  static inline __host__ __device__ bool lt(float a, float b) { return a < b; }
  static inline __host__ __device__ bool le(float a, float b) { return a <= b; }
  static inline __host__ __device__ bool gt(float a, float b) { return a > b; }
  static inline __host__ __device__ bool ge(float a, float b) { return a >= b; }
  static inline __host__ __device__ bool eq(float a, float b) { return a == b; }
  static inline __host__ __device__ bool ne(float a, float b) { return a != b; }

  // Unary math.
  static inline __host__ __device__ float sqrt(float a) { return sqrtf(a); }
  static inline __host__ __device__ float atan(float a) { return atanf(a); }

  // Binary arithmetic.
  static inline __host__ __device__ float add(float a, float b) { return a + b; }
  static inline __host__ __device__ float div(float a, float b) { return a / b; }
  static inline __host__ __device__ float mul(float a, float b) { return a * b; }
  static inline __host__ __device__ float sub(float a, float b) { return a - b; }
  static inline __host__ __device__ float pow(float a, float b) { return powf(a, b); }

  // Classification.
  static inline __host__ __device__ bool isnan(float a) { return ::isnan(a); }
  static inline __host__ __device__ bool isinf(float a) { return ::isinf(a); }
};
// THCNumerics specialization for at::BFloat16.
// - Limits forward to at::numeric_limits<at::BFloat16>.
// - Math calls the single-precision C functions (sqrtf, atanf, powf, atan2f)
//   on the value; the float result is converted back to at::BFloat16 on return.
template <>
struct THCNumerics<at::BFloat16> {
  // Limits.
  static inline __host__ __device__ at::BFloat16 min() { return at::numeric_limits<at::BFloat16>::lowest(); }
  static inline __host__ __device__ at::BFloat16 max() { return at::numeric_limits<at::BFloat16>::max(); }
  static inline __host__ __device__ at::BFloat16 lower_bound() { return at::numeric_limits<at::BFloat16>::lower_bound(); }
  static inline __host__ __device__ at::BFloat16 upper_bound() { return at::numeric_limits<at::BFloat16>::upper_bound(); }

  // Comparisons.
  static inline __host__ __device__ bool lt(at::BFloat16 a, at::BFloat16 b) { return a < b; }
  static inline __host__ __device__ bool le(at::BFloat16 a, at::BFloat16 b) { return a <= b; }
  static inline __host__ __device__ bool gt(at::BFloat16 a, at::BFloat16 b) { return a > b; }
  static inline __host__ __device__ bool ge(at::BFloat16 a, at::BFloat16 b) { return a >= b; }
  static inline __host__ __device__ bool eq(at::BFloat16 a, at::BFloat16 b) { return a == b; }
  static inline __host__ __device__ bool ne(at::BFloat16 a, at::BFloat16 b) { return a != b; }

  // Unary math.
  static inline __host__ __device__ at::BFloat16 sqrt(at::BFloat16 a) { return sqrtf(a); }
  static inline __host__ __device__ at::BFloat16 atan(at::BFloat16 a) { return atanf(a); }

  // Binary arithmetic.
  static inline __host__ __device__ at::BFloat16 add(at::BFloat16 a, at::BFloat16 b) { return a + b; }
  static inline __host__ __device__ at::BFloat16 div(at::BFloat16 a, at::BFloat16 b) { return a / b; }
  static inline __host__ __device__ at::BFloat16 mul(at::BFloat16 a, at::BFloat16 b) { return a * b; }
  static inline __host__ __device__ at::BFloat16 sub(at::BFloat16 a, at::BFloat16 b) { return a - b; }
  static inline __host__ __device__ at::BFloat16 pow(at::BFloat16 a, at::BFloat16 b) { return powf(a, b); }
  static inline __host__ __device__ at::BFloat16 atan2(at::BFloat16 a, at::BFloat16 b) { return atan2f(a, b); }

  // Classification.
  static inline __host__ __device__ bool isnan(at::BFloat16 a) {
#ifdef _MSC_VER
    // MSVC needs an explicit conversion to float here; the cause is unclear.
    // Related clang issue: https://reviews.llvm.org/D37906
    return ::isnan(static_cast<float>(a));
#else
    return ::isnan(a);
#endif
  }

  static inline __host__ __device__ bool isinf(at::BFloat16 a) {
#ifdef _MSC_VER
    // MSVC needs an explicit conversion to float here; the cause is unclear.
    // Related clang issue: https://reviews.llvm.org/D37906
    return ::isinf(static_cast<float>(a));
#else
    return ::isinf(a);
#endif
  }
};
// DEPRECATED: use math functions from std and cuda math API (if needed)
//
// THCNumerics specialization for double.
// - Limits forward to at::numeric_limits<double>.
// - Math uses the global-namespace double-precision functions.
template <>
struct THCNumerics<double> {
  // Limits.
  static inline __host__ __device__ double min() { return at::numeric_limits<double>::lowest(); }
  static inline __host__ __device__ double max() { return at::numeric_limits<double>::max(); }
  static inline __host__ __device__ double lower_bound() { return at::numeric_limits<double>::lower_bound(); }
  static inline __host__ __device__ double upper_bound() { return at::numeric_limits<double>::upper_bound(); }

  // Comparisons (IEEE semantics: every comparison with NaN is false except ne).
  static inline __host__ __device__ bool lt(double a, double b) { return a < b; }
  static inline __host__ __device__ bool le(double a, double b) { return a <= b; }
  static inline __host__ __device__ bool gt(double a, double b) { return a > b; }
  static inline __host__ __device__ bool ge(double a, double b) { return a >= b; }
  static inline __host__ __device__ bool eq(double a, double b) { return a == b; }
  static inline __host__ __device__ bool ne(double a, double b) { return a != b; }

  // Unary math.
  static inline __host__ __device__ double sqrt(double a) { return ::sqrt(a); }
  static inline __host__ __device__ double atan(double a) { return ::atan(a); }

  // Binary arithmetic.
  static inline __host__ __device__ double add(double a, double b) { return a + b; }
  static inline __host__ __device__ double div(double a, double b) { return a / b; }
  static inline __host__ __device__ double mul(double a, double b) { return a * b; }
  static inline __host__ __device__ double sub(double a, double b) { return a - b; }
  static inline __host__ __device__ double pow(double a, double b) { return ::pow(a, b); }

  // Classification.
  static inline __host__ __device__ bool isnan(double a) { return ::isnan(a); }
  static inline __host__ __device__ bool isinf(double a) { return ::isinf(a); }
};
// Converts a scalar of type In to type Out (usage: ScalarConvert<In, Out>::to(v)).
//
// Historical note (now outdated): this helper existed because `half` is a
// struct without a constructor/implicit conversion constructor, so scalar
// values needed an explicit conversion to the type a tensor expects.
// at::Half has implicit conversions for float and __half types, as well as
// constructors for __half and float types, so a plain cast now suffices.
template <typename In, typename Out>
struct ScalarConvert {
  static __host__ __device__ Out to(const In v) {
    // Functional-notation cast with a single operand: equivalent to (Out) v.
    return Out(v);
  }
};
// DEPRECATED: use static_cast in kernels instead of scalar_cast
//
// Converts `u` to type T by forwarding to ScalarConvert<U, T>::to.
// The target type is given explicitly and the source type is deduced,
// e.g. scalar_cast<float>(x).
template <typename T, typename U>
__host__ __device__ T scalar_cast(U u) {
  using Converter = ScalarConvert<U, T>;
  return Converter::to(u);
}
#endif // THC_NUMERICS_INC