-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathActivationThresholdKernel.cu
54 lines (44 loc) · 1.24 KB
/
ActivationThresholdKernel.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
#define TORCH_ASSERT_NO_OPERATORS
#define _USE_MATH_DEFINES
#include <ATen/native/Activation.h>
#include <cmath>
#include <thrust/tuple.h>
#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/core/TensorBase.h>
#include <c10/core/Scalar.h>
#include <c10/cuda/CUDAMathCompat.h>
#include <ATen/cuda/ApplyGridUtils.cuh>
#include <ATen/cuda/detail/OffsetCalculator.cuh>
#include <ATen/native/cuda/Loops.cuh>
namespace at {
namespace native {
namespace {

// Elementwise threshold on the GPU.
//
// For each element pair (input_elem, other_elem) produced by `iter`, the
// output is `value` when input_elem <= threshold and other_elem otherwise.
// For floating types a NaN input_elem compares false against `threshold`,
// so other_elem is passed through in that case.
//
// NOTE(review): the binary lambda below implies `iter` carries two inputs
// (x, other) plus one output — confirm against the code that builds the
// iterator for threshold_stub.
template <typename scalar_t>
void threshold_kernel_impl(
    TensorIteratorBase& iter,
    scalar_t threshold,
    scalar_t value) {
  // Both scalars are captured by copy so the closure is self-contained when
  // it is invoked on the device.
  auto select_fn = [threshold, value] GPU_LAMBDA(
                       scalar_t input_elem, scalar_t other_elem) -> scalar_t {
    if (input_elem <= threshold) {
      return value;
    }
    return other_elem;
  };
  gpu_kernel_with_scalars(iter, select_fn);
}

// CUDA implementation registered for threshold_stub.
//
// Dispatches on the dtype of `iter` over ATen's "all types" set plus Half and
// BFloat16. Converts the `threshold` and `value` Scalars to that element type
// and launches threshold_kernel_impl. "threshold_cuda" is the op name handed
// to the dispatch macro.
static void threshold_kernel_cuda(
    TensorIteratorBase& iter,
    const Scalar& threshold,
    const Scalar& value) {
  const auto iter_dtype = iter.dtype();
  AT_DISPATCH_ALL_TYPES_AND2(
      at::ScalarType::Half,
      at::ScalarType::BFloat16,
      iter_dtype,
      "threshold_cuda",
      [&] {
        // `scalar_t` is defined by the dispatch macro for each dtype case.
        const scalar_t threshold_val = threshold.to<scalar_t>();
        const scalar_t fill_val = value.to<scalar_t>();
        threshold_kernel_impl<scalar_t>(iter, threshold_val, fill_val);
      });
}

} // namespace

// Bind the CUDA kernel to threshold_stub. The stub is presumably declared in
// ATen/native/Activation.h (included above) — confirm there.
REGISTER_DISPATCH(threshold_stub, &threshold_kernel_cuda);

} // namespace native
} // namespace at