-
Notifications
You must be signed in to change notification settings - Fork 3
/
THCTensorMathReduce.cuh
348 lines (296 loc) · 9.42 KB
/
THCTensorMathReduce.cuh
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
#ifndef THC_TENSORMATH_REDUCE_CUH
#define THC_TENSORMATH_REDUCE_CUH
#include <THC/THCTensorMath.h>
#include <THC/THCGeneral.h>
#include <THC/THCNumerics.cuh>
#include <THC/THCReduce.cuh>
#include <THC/THCReduceAll.cuh>
#include <THC/THCTensorCopy.hpp>
#include <THC/THCThrustAllocator.cuh>
#include <thrust/functional.h>
#include <thrust/device_ptr.h>
#include <thrust/transform_reduce.h>
#include <thrust/inner_product.h>
#if CUDA_VERSION >= 7000 || defined __HIP_PLATFORM_HCC__
#include <thrust/system/cuda/execution_policy.h>
#endif
/*
Reductions that (only) operate on accumulate types.
*/
/*
 * Running state for a Welford-style mean/variance reduction.
 *
 * T is the type the statistics are kept in; U is the type of a single input
 * element, which is implicitly convertible to a one-sample WelfordData.
 */
template <typename T, typename U>
struct WelfordData {
  T mean_;     // mean of the samples folded in so far
  T m_2_n_;    // accumulated second-moment term (divided by a count to get variance)
  int count_;  // number of samples folded in; do we need int64_t?

  // Intentionally leaves every member uninitialized: an initializing default
  // constructor triggers a dynamic-initialization warning when this type is
  // placed in static shared memory inside a CUDA kernel. Call reset() to zero.
  __host__ __device__ WelfordData() {}

  // Puts the accumulator into the "no samples seen" state.
  __host__ __device__ void reset() {
    count_ = 0;
    mean_ = T(0);
    m_2_n_ = T(0);
  }

  // Implicit conversion from one input element: a single-sample accumulator.
  __host__ __device__ WelfordData(const U data_)
      : mean_(static_cast<T>(data_)),
        m_2_n_(static_cast<T>(0)),
        count_(1) {}

  // Member-wise copy from a non-volatile source.
  __host__ __device__ WelfordData(const WelfordData &t) {
    mean_ = t.mean_;
    m_2_n_ = t.m_2_n_;
    count_ = t.count_;
  }

  // Member-wise copy from a volatile source.
  __host__ __device__ WelfordData(const volatile WelfordData &t) {
    mean_ = t.mean_;
    m_2_n_ = t.m_2_n_;
    count_ = t.count_;
  }

  // Member-wise assignment into a volatile destination.
  __host__ __device__ volatile WelfordData& operator = (const volatile WelfordData &t) volatile {
    count_ = t.count_;
    m_2_n_ = t.m_2_n_;
    mean_ = t.mean_;
    return *this;
  }

  // Member-wise assignment into a non-volatile destination.
  __host__ __device__ WelfordData& operator = (const WelfordData &t) {
    count_ = t.count_;
    m_2_n_ = t.m_2_n_;
    mean_ = t.mean_;
    return *this;
  }
};
/*
 * Identity functor: returns a copy of its argument, unchanged.
 */
template <typename T>
struct ModifyWelford {
  inline __device__ T operator()(const T &a) const {
    return T(a);
  }
};
/*
 * Binary combine step for two WelfordData partials.
 *
 * The merged count is the sum of the two counts. The merged mean is the
 * count-weighted average of the two means. The merged m_2_n_ is the sum of
 * both m_2_n_ terms plus a correction
 *   (a.count_ * b.count_ / merged count) * (a.mean_ - b.mean_)^2.
 */
template <typename T, typename U>
struct ReduceWelford {
  inline __device__ WelfordData<T, U> operator()(const WelfordData<T, U> &a, const WelfordData<T, U> &b) const {
    typedef THCNumerics<T> Num;

    WelfordData<T, U> merged;
    merged.count_ = THCNumerics<int>::add(a.count_, b.count_);

    // Clamp the divisor to 1 so merging two empty partials does not divide by zero.
    const T inv_count = Num::div(1.0, max(1, merged.count_));

    const T weighted_sum = Num::add(Num::mul(a.mean_, a.count_),
                                    Num::mul(b.mean_, b.count_));
    merged.mean_ = Num::mul(weighted_sum, inv_count);

    const T mean_gap = Num::sub(a.mean_, b.mean_);
    const T correction =
        Num::mul(inv_count,
                 Num::mul(a.count_,
                          Num::mul(b.count_, Num::pow(mean_gap, 2))));
    merged.m_2_n_ = Num::add(a.m_2_n_, Num::add(b.m_2_n_, correction));

    return merged;
  }
};
/*
 * Finalizer for a Welford reduction: turns an accumulated WelfordData into a
 * variance, or a standard deviation when apply_sqrt is set.
 *
 * NOTE(review): despite its name, a non-zero `unbiased` divides by count_
 * and a zero value divides by count_ - 1. Presumably callers pass a "biased"
 * flag here -- confirm against the call sites before renaming or inverting.
 */
template <typename T, typename U>
struct VarianceWelford {
  VarianceWelford(const int _unbiased, const bool _apply_sqrt)
      : unbiased(_unbiased), apply_sqrt(_apply_sqrt) {}

  inline __device__ T operator()(const WelfordData<T, U> &a) const {
    const int divisor = unbiased ? a.count_ : a.count_ - 1;
    const T variance = THCNumerics<T>::div(a.m_2_n_, divisor);
    if (!apply_sqrt) {
      return variance;
    }
    return THCNumerics<T>::sqrt(variance);
  }

  const int unbiased;     // non-zero: divide by count_; zero: divide by count_ - 1
  const bool apply_sqrt;  // true: return the square root of the variance
};
/*
 * Binary functor returning THCNumerics<T>::add(a, b).
 */
template <typename T>
struct ReduceAdd {
  inline __device__ T operator()(const T a, const T b) const {
    const T total = THCNumerics<T>::add(a, b);
    return total;
  }
};
/*
 * Binary functor returning THCNumerics<T>::mul(a, b).
 */
template <typename T>
struct ReduceMultiply {
  inline __device__ T operator()(const T a, const T b) const {
    const T product = THCNumerics<T>::mul(a, b);
    return product;
  }
};
/*
 * Unary functor that divides its argument by a divisor fixed at construction.
 */
template <typename T>
struct ReduceDivide {
  ReduceDivide(const T _divisor)
      : divisor(_divisor) {}

  inline __device__ T operator()(const T x) const {
    const T quotient = THCNumerics<T>::div(x, divisor);
    return quotient;
  }

  const T divisor;  // right-hand operand of every division
};
/*
 * Unary functor that raises its argument to an exponent fixed at construction.
 */
template <typename T>
struct ReducePow {
  ReducePow(const T _exponent)
      : exponent(_exponent) {}

  inline __device__ T operator()(const T x) const {
    const T powered = THCNumerics<T>::pow(x, exponent);
    return powered;
  }

  const T exponent;  // power applied to every input
};
/*
 * Unary functor returning (x - mean)^2 for a mean fixed at construction.
 */
template <typename T>
struct SquareFunctor {
  SquareFunctor(const T _mean)
      : mean(_mean) {}

  inline __device__ T operator()(const T x) const {
    // The deviation is computed once and multiplied by itself.
    const T deviation = THCNumerics<T>::sub(x, mean);
    return THCNumerics<T>::mul(deviation, deviation);
  }

  const T mean;  // value subtracted from every input before squaring
};
/*
 * Binary logical AND on byte flags: yields 1 when both inputs are non-zero,
 * otherwise 0.
 */
struct LogicalAll {
  inline __device__ unsigned char operator()(const unsigned char x,
                                             const unsigned char y) const {
    const bool both_set = x && y;
    return static_cast<unsigned char>(both_set);
  }
};
/*
 * Binary logical OR on byte flags: yields 1 when at least one input is
 * non-zero, otherwise 0.
 */
struct LogicalAny {
  inline __device__ unsigned char operator()(const unsigned char x,
                                             const unsigned char y) const {
    const bool either_set = x || y;
    return static_cast<unsigned char>(either_set);
  }
};
/*
 * Returns a when THCNumerics<T>::gt(a, b) holds, otherwise b.
 */
template<typename T>
inline __device__ T THCMax(const T a, const T b) {
  if (THCNumerics<T>::gt(a, b)) {
    return a;
  }
  return b;
}
/*
 * Rescales, in place, every row of `data` whose norm exceeds `maxnorm`.
 *
 * Layout: block b works on the `size` consecutive elements that start at
 * data + size * b.
 *
 * Launch preconditions (implied by the code, not checked):
 *   - blockDim.x <= 32: the shared buffer has 32 slots indexed by threadIdx.x.
 *   - blockDim.x is a power of two: the halving reduction drops partial
 *     results otherwise.
 *
 * data     rows to renormalize; read and written
 * value    norm exponent p; INFINITY selects the max-abs norm
 * size     number of elements per row
 * maxnorm  rows with norm greater than this are scaled by
 *          maxnorm / (norm + 1e-7)
 */
template<typename T, typename AccT>
__global__ void THCTensor_kernel_renorm(T *data,
                                        const AccT value,
                                        const ptrdiff_t size,
                                        const AccT maxnorm) {
  // One partial result per thread of the block.
  __shared__ AccT buffer[32];

  const int64_t tid = threadIdx.x;
  const int64_t block_id = blockIdx.x;
  const int64_t nthreads = blockDim.x;

  T *const row = data + size * block_id;
  buffer[tid] = scalar_cast<AccT>(0);

  const bool use_inf_norm =
      THCNumerics<AccT>::eq(value, scalar_cast<AccT, float>(INFINITY));

  AccT norm;
  if (use_inf_norm) {
    // Per-thread partial: largest |x| among the elements this thread visits.
    for (ptrdiff_t i = tid; i < size; i += nthreads) {
      const AccT magnitude =
          static_cast<AccT>(std::abs(scalar_cast<AccT>(row[i])));
      buffer[tid] = THCMax<AccT>(buffer[tid], magnitude);
    }

    // Block-wide max by repeated halving; the barrier precedes every step.
    for (unsigned int offset = blockDim.x >> 1; offset > 0; offset >>= 1) {
      __syncthreads();
      if (tid < offset) {
        buffer[tid] = THCMax<AccT>(buffer[tid], buffer[tid + offset]);
      }
    }

    // Make the final value in slot 0 visible to every thread.
    __syncthreads();
    norm = buffer[0];
  } else {
    // Per-thread partial: sum of |x|^value over the elements this thread visits.
    for (ptrdiff_t i = tid; i < size; i += nthreads) {
      const AccT magnitude =
          static_cast<AccT>(std::abs(scalar_cast<AccT>(row[i])));
      buffer[tid] = THCNumerics<AccT>::add(
          buffer[tid], THCNumerics<AccT>::pow(magnitude, value));
    }

    // Block-wide sum by repeated halving; the barrier precedes every step.
    for (unsigned int offset = blockDim.x >> 1; offset > 0; offset >>= 1) {
      __syncthreads();
      if (tid < offset) {
        buffer[tid] = THCNumerics<AccT>::add(buffer[tid], buffer[tid + offset]);
      }
    }

    // Make the final value in slot 0 visible, then take the 1/value root.
    __syncthreads();
    norm = THCNumerics<AccT>::pow(buffer[0], static_cast<AccT>(1) / value);
  }

  // Rows already within the limit are left untouched. All barriers are behind
  // us at this point, so returning early is safe.
  if (!THCNumerics<AccT>::gt(norm, maxnorm)) {
    return;
  }

  // Scale factor; a small epsilon is added to the norm before dividing.
  const AccT scale = THCNumerics<AccT>::div(
      maxnorm, THCNumerics<AccT>::add(norm, scalar_cast<AccT>(1e-7)));

  for (ptrdiff_t i = tid; i < size; i += nthreads) {
    const AccT element = scalar_cast<AccT>(row[i]);
    row[i] = scalar_cast<T>(THCNumerics<AccT>::mul(element, scale));
  }
}
/*
 * Maps a value to an indicator: 0 when the input compares equal to zero,
 * otherwise 1 (both expressed in T).
 */
template <typename T>
struct TensorNonZeroOp {
  TensorNonZeroOp() {}

  __host__ __device__ T operator()(const T lhs) const {
    const T zero = scalar_cast<T>(0);
    const bool is_zero = THCNumerics<T>::eq(lhs, zero);
    if (!is_zero) {
      return scalar_cast<T>(1);
    }
    return zero;
  }
};
/*
  Fuses conversions and a TensorDistOp. Needed for Thrust.

  Both inputs are converted from T to AccT, and the per-element term returned
  depends on the exponent fixed at construction:
    exponent == 0 : 0 if x == y, else 1
    exponent == 1 : |x - y|
    exponent == 2 : pow(x - y, exponent)   (no abs taken)
    otherwise     : pow(|x - y|, exponent)
*/
template <typename T, typename AccT>
struct ThrustTensorDistOp {
  ThrustTensorDistOp(AccT _exponent)
      : exponent(_exponent) {}

  __host__ __device__ AccT operator()(T _x, T _y) const {
    typedef THCNumerics<AccT> Num;

    const AccT diff = Num::sub(scalar_cast<AccT>(_x), scalar_cast<AccT>(_y));

    if (Num::eq(exponent, scalar_cast<AccT, float>(0))) {
      const AccT zero = scalar_cast<AccT>(0);
      if (Num::eq(diff, zero)) {
        return zero;
      }
      return scalar_cast<AccT>(1);
    }

    if (Num::eq(exponent, scalar_cast<AccT, float>(1))) {
      return static_cast<AccT>(std::abs(diff));
    }

    if (Num::eq(exponent, scalar_cast<AccT, float>(2))) {
      return Num::pow(diff, exponent);
    }

    return Num::pow(static_cast<AccT>(std::abs(diff)), exponent);
  }

  const AccT exponent;  // selects which per-element distance term is produced
};
#include <thrust/functional.h>
// Given the sum of values and the sum of squares, compute the variance or
// standard deviation.
//
// Template parameters:
//   flag        true  -> normalize by row_size;
//               false -> normalize by row_size - 1.
//   apply_sqrt  true  -> return the square root of the variance.
// Arguments:
//   sum       sum of the values
//   sum2      sum of the squared values
//   row_size  number of values that were summed
// A negative intermediate variance is clamped to zero before the optional
// square root.
template<typename T, bool flag, bool apply_sqrt>
__forceinline__ __device__ T THCTensor_computeVar(
    T sum,
    T sum2,
    const unsigned row_size) {
  typedef THCNumerics<T> Num;

  const T count = scalar_cast<T>(row_size);
  const T zero = scalar_cast<T>(0);

  const T mean = Num::div(sum, count);
  const T mean_sq = Num::mul(mean, mean);

  T variance;
  if (flag) {
    // sum2 / n - mean^2
    variance = Num::sub(Num::div(sum2, count), mean_sq);
  } else {
    // sum2 / (n - 1) - (n / (n - 1)) * mean^2
    const T count_m1 = scalar_cast<T>(row_size - 1);
    variance = Num::sub(Num::div(sum2, count_m1),
                        Num::mul(Num::div(count, count_m1), mean_sq));
  }

  if (Num::lt(variance, zero)) {
    variance = zero;
  }

  if (!apply_sqrt) {
    return variance;
  }
  return Num::sqrt(variance);
}
/*
 * Binary functor returning THCNumerics<T>::add(lhs, rhs).
 *
 * operator() is const-qualified: it reads no state, so the functor can be
 * invoked through const objects and const references, matching ReduceAdd.
 * Calls on non-const objects keep working unchanged.
 */
template <typename T>
struct AddOp {
  __device__ __forceinline__ T operator()(T const &lhs, T const &rhs) const {
    return THCNumerics<T>::add(lhs, rhs);
  }
};
/*
 * Binary functor returning THCNumerics<T>::mul(lhs, rhs).
 *
 * operator() is const-qualified: it reads no state, so the functor can be
 * invoked through const objects and const references, matching ReduceMultiply.
 * Calls on non-const objects keep working unchanged.
 */
template <typename T>
struct MulOp {
  __device__ __forceinline__ T operator()(T const &lhs, T const &rhs) const {
    return THCNumerics<T>::mul(lhs, rhs);
  }
};
/*
 * Binary functor returning lhs when THCNumerics<T>::gt(lhs, rhs) holds,
 * otherwise rhs.
 *
 * operator() is const-qualified: it reads no state, so the functor can be
 * invoked through const objects and const references. Calls on non-const
 * objects keep working unchanged.
 */
template <typename T>
struct MaxOp {
  __device__ __forceinline__ T operator()(T const &lhs, T const &rhs) const {
    return THCNumerics<T>::gt(lhs, rhs) ? lhs : rhs;
  }
};
/*
 * Binary functor returning lhs when THCNumerics<T>::lt(lhs, rhs) holds,
 * otherwise rhs.
 *
 * operator() is const-qualified: it reads no state, so the functor can be
 * invoked through const objects and const references. Calls on non-const
 * objects keep working unchanged.
 */
template <typename T>
struct MinOp {
  __device__ __forceinline__ T operator()(T const &lhs, T const &rhs) const {
    return THCNumerics<T>::lt(lhs, rhs) ? lhs : rhs;
  }
};
#endif // THC_TENSORMATH_REDUCE_CUH