-
Notifications
You must be signed in to change notification settings - Fork 3
/
THCTensorMathPairwise.cu
56 lines (45 loc) · 1.26 KB
/
THCTensorMathPairwise.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
#include <THC/THCTensorMath.h>
#include <THC/THCGeneral.h>
#include <TH/THHalf.h>
#include <THC/THCTensorCopy.h>
#include <THC/THCApply.cuh>
#include <THC/THCNumerics.cuh>
#include <THC/THCTensorMathCompareT.cuh>
#include <THC/THCTensor.hpp>
// Pointwise functor that scales elements by a constant fixed at construction.
//
//   operator()(dst, src) : stores `*src * val` into `*dst`.
//   operator()(elem)     : multiplies `*elem` by `val` in place.
//
// The constructor carries no execution-space qualifier; only the call
// operators are marked __device__.
template <typename T>
struct TensorMulConstantOp {
  // Multiplier applied to every element the functor is invoked on.
  const T val;

  TensorMulConstantOp(T factor) : val(factor) {}

  // In-place form: the pointed-to element is overwritten with its scaled value.
  __device__ __forceinline__ void operator()(T* elem) {
    elem[0] *= val;
  }

  // Out-of-place form: reads one element from `src`, writes the product to `dst`.
  __device__ __forceinline__ void operator()(T* dst, T* src) {
    dst[0] = src[0] * val;
  }
};
// Pointwise functor computing the floating-point remainder of an element
// divided by a constant divisor, evaluated in single precision.
//
// Both the divisor (at construction) and each element (at call time) are
// converted to float, passed through fmodf, and the result is converted
// back to T.
//
// NOTE(review): because the arithmetic goes through float, values of T that
// a float cannot represent exactly (e.g. large integers) would be rounded
// before the remainder is taken -- confirm this is acceptable for every T
// this template is instantiated with.
template <typename T>
struct TensorFmodOp {
  // Divisor, stored already converted to float.
  const float val;

  TensorFmodOp(T divisor) : val(static_cast<float>(divisor)) {}

  // In-place form: replaces `*elem` with fmodf(*elem, val).
  __device__ __forceinline__ void operator()(T* elem) {
    const float dividend = static_cast<float>(elem[0]);
    elem[0] = static_cast<T>(fmodf(dividend, val));
  }

  // Out-of-place form: writes fmodf(*src, val) into `*dst`.
  __device__ __forceinline__ void operator()(T* dst, T* src) {
    const float dividend = static_cast<float>(src[0]);
    dst[0] = static_cast<T>(fmodf(dividend, val));
  }
};
// Full specialization of TensorFmodOp for double: the divisor is kept as a
// double and the remainder is computed with double-precision fmod, so no
// conversion to float takes place.
template <>
struct TensorFmodOp<double> {
  // Divisor used for every invocation.
  const double val;

  TensorFmodOp(double divisor) : val(divisor) {}

  // In-place form: replaces `*elem` with fmod(*elem, val).
  __device__ __forceinline__ void operator()(double* elem) {
    const double dividend = elem[0];
    elem[0] = fmod(dividend, val);
  }

  // Out-of-place form: writes fmod(*src, val) into `*dst`.
  __device__ __forceinline__ void operator()(double* dst, double* src) {
    const double dividend = src[0];
    dst[0] = fmod(dividend, val);
  }
};
#include <THC/generic/THCTensorMathPairwise.cu>
#include <THC/THCGenerateAllTypes.h>
#include <THC/generic/THCTensorMathPairwise.cu>
#include <THC/THCGenerateBoolType.h>