This repository was archived by the owner on Nov 17, 2023. It is now read-only.

[FEATURE] Use RTC for reduction ops #19426

Merged
merged 29 commits on May 25, 2021
Commits
3ca492a
Initial rebase
ptrendx Oct 19, 2020
6d4fcc3
Merge commit '3faf6df22' into pr_rtc_reduce_ops
ptrendx Oct 26, 2020
fd0db00
Merge commit '75c62166e' into pr_rtc_reduce_ops
ptrendx Oct 26, 2020
91f715f
Fixes after merge
ptrendx Oct 27, 2020
a3694b9
Merge commit '187c75d6b' into pr_rtc_reduce_ops
ptrendx Oct 28, 2020
46060a1
Merge commit '95f9ea2c8' into pr_rtc_reduce_ops
ptrendx Oct 28, 2020
fc8b771
Merge commit '4b3be14a4' into pr_rtc_reduce_ops
ptrendx Oct 28, 2020
22b669d
Merge commit '8dc3652ab' into pr_rtc_reduce_ops
ptrendx Oct 28, 2020
a08fa72
Merge branch 'upstream' into pr_rtc_reduce_ops
ptrendx Oct 28, 2020
fe5656f
Fixes
ptrendx Oct 29, 2020
bd61456
Merge branch 'upstream' into pr_rtc_reduce_ops
ptrendx Oct 29, 2020
f44314a
Fix lint
ptrendx Oct 29, 2020
9b561e3
Fix lint for real
ptrendx Oct 29, 2020
9acc3f5
Cleaning and code reuse
ptrendx Oct 30, 2020
4d14726
Fix lint
ptrendx Oct 31, 2020
7541de0
Try to WAR the maybe-uninitialized warning
ptrendx Nov 2, 2020
ac278b1
Second try
ptrendx Nov 2, 2020
d133862
Fix Windows compilation
ptrendx Nov 2, 2020
9008e86
More fixes for Windows compilation
ptrendx Nov 2, 2020
245ec89
Breaking the strings to please Windows compiler
ptrendx Nov 5, 2020
8149bbd
Do not use the default stream in kron
ptrendx Nov 5, 2020
9194ea9
Fix argmin/argmax
ptrendx Nov 6, 2020
d6f4311
Merge branch 'upstream' into pr_rtc_reduce_ops
ptrendx Nov 30, 2020
442da8c
Merge branch 'upstream' into pr_rtc_reduce_ops
ptrendx Dec 18, 2020
0980621
Merge branch 'upstream' into pr_rtc_reduce_ops
ptrendx Jan 4, 2021
cff6699
Merge branch 'upstream' into pr_rtc_reduce_ops
ptrendx Apr 22, 2021
2ff74b2
Fix layernorm
ptrendx Apr 22, 2021
4a6f89a
Merge branch 'upstream' into pr_rtc_reduce_ops
ptrendx May 19, 2021
c457f55
Merge branch 'upstream' into pr_rtc_reduce_ops
ptrendx May 21, 2021
5 changes: 4 additions & 1 deletion src/common/cuda/rtc.cc
@@ -150,13 +150,16 @@ CUfunction get_function(const std::string &parameters,
std::string(fp16_support_string) + "\n" +
type_support_string + "\n" +
util_string + "\n" +
limits + "\n" +
special_functions_definitions + '\n' +
vectorization_support_string + "\n" +
function_definitions_util + "\n" +
function_definitions_binary + "\n" +
function_definitions_unary + "\n" +
backward_function_definitions + "\n" +
reducer + "\n";
grad_function_definitions + "\n" +
reducer + "\n" +
logic_reducer + "\n";
std::string code_with_header = common_header + parameters + code;
// If verbose mode, output kernel source, though not including the common header
if (dmlc::GetEnv("MXNET_RTC_VERBOSE", false)) {
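The hunk above extends the common header that gets prepended to every RTC kernel source (support strings, function and reducer definitions, and now the limits, gradient, and logic-reducer blocks) before the whole string is handed to the runtime compiler. For readers unfamiliar with that side, a minimal sketch of how such a concatenated source string is typically turned into a CUfunction with NVRTC follows; the helper name, compile options, and the absence of error handling and caching are illustrative assumptions, not MXNet's actual get_function:

#include <cuda.h>
#include <nvrtc.h>
#include <string>
#include <vector>

// Minimal sketch: compile a source string assembled the way the hunk above does
// (common header + parameters + kernel code) and look up one kernel by name.
// Assumes a CUDA context is already current; error handling and caching omitted.
CUfunction compile_and_get(const std::string &code_with_header, const char *kernel_name) {
  nvrtcProgram prog;
  nvrtcCreateProgram(&prog, code_with_header.c_str(), "mxnet_rtc.cu", 0, nullptr, nullptr);
  const char *opts[] = {"--std=c++14"};
  nvrtcCompileProgram(prog, 1, opts);

  size_t ptx_size = 0;
  nvrtcGetPTXSize(prog, &ptx_size);
  std::vector<char> ptx(ptx_size);
  nvrtcGetPTX(prog, ptx.data());
  nvrtcDestroyProgram(&prog);

  CUmodule module;
  cuModuleLoadData(&module, ptx.data());
  CUfunction kernel;
  cuModuleGetFunction(&kernel, module, kernel_name);
  return kernel;
}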
240 changes: 159 additions & 81 deletions src/common/cuda/rtc/backward_functions-inl.h
@@ -237,6 +237,98 @@ backward_square(const DTypeGrad grad, const DType val) {
return 2 * val * grad;
}

template <typename DType, typename DType2>
__device__ inline DType div_rgrad(const DType val,
const DType2 val2) {
return -val / (val2 * val2);
}

template <typename DType, typename DTypeGrad>
__device__ inline mixed_type<DTypeGrad, DType>
backward_clip(const DTypeGrad grad, const DType val,
const float a_min, const float a_max) {
if (val > a_max || val < a_min) {
return 0;
} else {
return grad;
}
}

template <typename DType, typename DTypeGrad>
__device__ inline mixed_type<DTypeGrad, DType>
backward_reciprocal(const DTypeGrad grad, const DType val) {
return -grad / (val * val);
}

template <typename DType, typename DTypeGrad>
__device__ inline mixed_type<DTypeGrad, DType>
backward_erf(const DTypeGrad grad, const DType val) {
using type = mixed_type<DTypeGrad, DType>;
const type v = val;
constexpr type my_pi = pi;
return 2.0f / op::sqrt(my_pi) * op::exp(-(v*v)) * grad;
}

template <typename DType, typename DTypeGrad>
__device__ inline mixed_type<DTypeGrad, DType>
backward_erfinv(const DTypeGrad grad, const DType val) {
using type = mixed_type<DTypeGrad, DType>;
constexpr type my_pi = pi;
const type g = grad;
const type v = val;
return 0.5f * op::sqrt(my_pi) * op::exp(v * v) * g;
}

template <typename DType, typename DTypeGrad>
__device__ inline mixed_type<DTypeGrad, DType>
backward_gamma(const DTypeGrad grad, const DType val) {
using type = mixed_type<DTypeGrad, DType>;
const type v = val;
if (type_util::is_same<DTypeGrad, double>::value) {
return grad * op::gamma(v) * op::special_functions::cephes::psi<double>(v);
} else {
return grad * op::gamma(v) * op::special_functions::cephes::psi<float>(v);
}
}

template <typename DType, typename DTypeGrad>
__device__ inline mixed_type<DTypeGrad, DType>
backward_gammaln(const DTypeGrad grad, const DType val) {
using type = mixed_type<DTypeGrad, DType>;
const type v = val;
if (type_util::is_same<DTypeGrad, double>::value) {
return grad * op::special_functions::cephes::psi<double>(v);
} else {
return grad * op::special_functions::cephes::psi<float>(v);
}
}

template <typename DType, typename DTypeGrad>
__device__ inline mixed_type<DTypeGrad, DType>
backward_digamma(const DTypeGrad grad, const DType val) {
using type = mixed_type<DTypeGrad, DType>;
const type v = val;
if (type_util::is_same<DTypeGrad, double>::value) {
return grad * op::special_functions::trigamma<double>(v);
} else {
return grad * op::special_functions::trigamma<float>(v);
}
}

template <typename DType, typename DTypeGrad>
__device__ inline mixed_type<DTypeGrad, DType>
backward_gelu(const DTypeGrad grad, const DType val) {
return 0.5f * (grad + grad * op::erf(val / op::sqrt(2.0f)) +
val * backward_erf(grad, val / op::sqrt(2.0f)) / op::sqrt(2.0f));
}

} // namespace op

)code";

const char grad_function_definitions[] = R"code(
namespace op {

template <typename DType, typename DType2>
__device__ inline mixed_type<DType, DType2>
rdiv_grad(const DType val,
@@ -252,12 +344,6 @@ div_grad(const DType val,
return op::reciprocal(temp);
}

template <typename DType, typename DType2>
__device__ inline DType div_rgrad(const DType val,
const DType2 val2) {
return -val / (val2 * val2);
}

template <typename DType, typename DType2>
__device__ inline DType mod_grad(const DType val,
const DType2 val2) {
@@ -368,80 +454,6 @@ rldexp_grad(const DType val,
return val2 * op::power(static_cast<type>(2), val) * op::log(static_cast<type>(2));
}

template <typename DType, typename DTypeGrad>
__device__ inline mixed_type<DTypeGrad, DType>
backward_clip(const DTypeGrad grad, const DType val,
const float a_min, const float a_max) {
if (val > a_max || val < a_min) {
return 0;
} else {
return grad;
}
}

template <typename DType, typename DTypeGrad>
__device__ inline mixed_type<DTypeGrad, DType>
backward_reciprocal(const DTypeGrad grad, const DType val) {
return -grad / (val * val);
}

template <typename DType, typename DTypeGrad>
__device__ inline mixed_type<DTypeGrad, DType>
backward_erf(const DTypeGrad grad, const DType val) {
const mixed_type<DTypeGrad, DType> v = val;
constexpr mixed_type<DTypeGrad, DType> my_pi = pi;
return 2.0f / op::sqrt(my_pi) * op::exp(-(v*v)) * grad;
}

template <typename DType, typename DTypeGrad>
__device__ inline mixed_type<DTypeGrad, DType>
backward_erfinv(const DTypeGrad grad, const DType val) {
constexpr mixed_type<DTypeGrad, DType> my_pi = pi;
const mixed_type<DTypeGrad, DType> g = grad;
const mixed_type<DTypeGrad, DType> v = val;
return 0.5f * op::sqrt(my_pi) * op::exp(v * v) * g;
}

template <typename DType, typename DTypeGrad>
__device__ inline mixed_type<DTypeGrad, DType>
backward_gamma(const DTypeGrad grad, const DType val) {
const mixed_type<DTypeGrad, DType> v = val;
if (type_util::is_same<DTypeGrad, double>::value) {
return grad * op::gamma(v) * op::special_functions::cephes::psi<double>(v);
} else {
return grad * op::gamma(v) * op::special_functions::cephes::psi<float>(v);
}
}

template <typename DType, typename DTypeGrad>
__device__ inline mixed_type<DTypeGrad, DType>
backward_gammaln(const DTypeGrad grad, const DType val) {
const mixed_type<DTypeGrad, DType> v = val;
if (type_util::is_same<DTypeGrad, double>::value) {
return grad * op::special_functions::cephes::psi<double>(v);
} else {
return grad * op::special_functions::cephes::psi<float>(v);
}
}

template <typename DType, typename DTypeGrad>
__device__ inline mixed_type<DTypeGrad, DType>
backward_digamma(const DTypeGrad grad, const DType val) {
const mixed_type<DTypeGrad, DType> v = val;
if (type_util::is_same<DTypeGrad, double>::value) {
return grad * op::special_functions::trigamma<double>(v);
} else {
return grad * op::special_functions::trigamma<float>(v);
}
}

template <typename DType, typename DTypeGrad>
__device__ inline mixed_type<DTypeGrad, DType>
backward_gelu(const DTypeGrad grad, const DType val) {
return 0.5f * (grad + grad * op::erf(val / op::sqrt(2.0f)) +
val * backward_erf(grad, val / op::sqrt(2.0f)) / op::sqrt(2.0f));
}

template <typename DType, typename DType2>
__device__ inline DType smooth_l1_grad(const DType val, const DType2 scalar) {
auto bsq = scalar * scalar;
@@ -467,8 +479,74 @@ __device__ inline DType prelu_grad(const DType val,
return (val > 0) ? 0 : val;
}

} // namespace op
template <typename DType, typename DType2>
__device__ inline mixed_type<DType2, DType>
gamma_implicit_grad(const DType a_in, const DType2 x_in) {
using OType = mixed_type<DType2, DType>;
const OType a = a_in;
const OType x = x_in;
if (x < 0.8f) {
OType numer = 1;
OType denom = a;
OType series1 = numer / denom;
OType series2 = numer / (denom * denom);
#pragma unroll
for (int i = 1; i <= 5; i++) {
numer *= -x / static_cast<DType>(i);
denom += 1;
series1 += numer / denom;
series2 += numer / (denom * denom);
}
OType pow_x_alpha = op::power(x, a);
OType gamma_pdf = op::power(x, a - 1) * op::exp(-x);
OType gamma_cdf = pow_x_alpha * series1;
OType gamma_cdf_alpha =
(op::log(x) - OType(special_functions::cephes::psi<float>(a))) *
gamma_cdf -
pow_x_alpha * series2;
OType result = -gamma_cdf_alpha / gamma_pdf;
return op::isnan(result) ? 0.f : result;
}
if (a > 8.0f) {
if (0.9f * a <= x && x <= 1.1f * a) {
OType numer_1 = 1 + 24 * a * (1 + 12 * a);
OType numer_2 = 1440 * (a * a) + 6 * x * (53 - 120 * x) -
65 * x * x / a + a * (107 + 3600 * x);
OType denom = 1244160 * (a * a) * (a * a);
return numer_1 * numer_2 / denom;
}
OType denom = op::sqrt(8 * a);
OType term2 = denom / (a - x);
OType term3 =
op::power(x - a - a * op::log(x / a), static_cast<OType>(-1.5));
OType term23 = (x < a) ? term2 - term3 : term2 + term3;
OType term1 = op::log(x / a) * term23 -
op::sqrt(2 / a) * (a + x) / ((a - x) * (a - x));
OType stirling = 1.f + 1.f / (12.f * a) * (1.f + 1.f / (24.f * a));
OType numer = x * term1;
return -stirling * numer / denom;
}
OType u = op::log(x / a);
OType v = op::log(a);
OType coef_uv[3][8] = {
{0.16009398, -0.094634809, 0.025146376, -0.0030648343, 1, 0.32668115,
0.10406089, 0.0014179084},
{0.53487893, 0.1298071, 0.065735949, -0.0015649758, 0.16639465,
0.020070113, -0.0035938915, -0.00058392623},
{0.040121004, -0.0065914022, -0.0026286047, -0.0013441777, 0.017050642,
-0.0021309326, 0.00085092367, -1.5247877e-07},
};
OType coef_v[8];
#pragma unroll
for (int i = 0; i < 8; i++) {
coef_v[i] = coef_uv[0][i] + u * (coef_uv[1][i] + u * coef_uv[2][i]);
}
OType p = coef_v[0] + v * (coef_v[1] + v * (coef_v[2] + v * coef_v[3]));
OType q = coef_v[4] + v * (coef_v[5] + v * (coef_v[6] + v * coef_v[7]));
return op::exp(p / q);
}

} // namespace op
)code";

} // namespace rtc
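For reference, the gamma_implicit_grad helper added in this file computes the implicit-reparameterization gradient of a Gamma(a, 1) sample x with respect to the shape parameter a. Writing F(x; a) for the Gamma CDF and f(x; a) for its density, the quantity all three branches approximate is the standard identity

\[
\frac{\partial x}{\partial a} = -\,\frac{\partial F(x; a) / \partial a}{f(x; a)} .
\]

The x < 0.8 branch evaluates this directly from a truncated series for F (the -gamma_cdf_alpha / gamma_pdf expression), the a > 8 branch uses an asymptotic expansion, and the remaining branch uses a rational approximation in u = log(x/a) and v = log(a).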
9 changes: 9 additions & 0 deletions src/common/cuda/rtc/forward_functions-inl.h
@@ -696,6 +696,10 @@ __device__ inline DType log_sigmoid(const DType val) {

template <typename DType>
__device__ inline DType softrelu(const DType val) {
// Avoid overflow of exp for large inputs.
// The threshold 20 is chosen such that softrelu(a) = a
// for a > 20 using floating precision.
if (val > 20) return val;
if (type_util::has_double_or_integral<DType>::value) {
return ::log(1 + ::exp(val));
} else {
@@ -936,6 +940,11 @@ __device__ inline bool_t np_logical_not(const DType val) {
return !static_cast<bool>(val);
}

template <typename DType>
__device__ inline bool_t NonZero(const DType val) {
return val != 0;
}

#undef DEFINE_UNARY_MATH_FUNC

template <typename DType>
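The early return added to softrelu relies on the correction term log1p(exp(-x)) dropping below float32 resolution once x exceeds 20, so log(1 + exp(x)) rounds to x itself. A small standalone host-side check of the single-precision case (separate from the PR code):

#include <cmath>
#include <cstdio>

int main() {
  // softrelu(x) = log(1 + exp(x)) = x + log1p(exp(-x)).
  // At x = 20 the correction term is ~2.06e-9, far below the float32
  // spacing around 20 (~1.9e-6), so the float32 result is exactly 20.
  const double correction = std::log1p(std::exp(-20.0));
  const float as_float = static_cast<float>(20.0 + correction);
  std::printf("correction = %.3e, float32 softrelu(20) = %.9g\n", correction, as_float);
  return 0;
}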