From 3ca492aa6383fbe1fc9a916b88c86f74448c63b9 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Mon, 19 Oct 2020 13:34:39 -0700 Subject: [PATCH 01/15] Initial rebase --- src/common/cuda/rtc/backward_functions-inl.h | 67 +++ src/common/cuda/rtc/forward_functions-inl.h | 5 + src/common/cuda/rtc/reducer-inl.h | 399 +++++++++++++++++ src/operator/mshadow_op.h | 2 +- src/operator/nn/group_norm-inl.h | 114 +++-- src/operator/nn/layer_norm-inl.h | 361 ++++++++++----- src/operator/nn/moments-inl.h | 9 + .../broadcast_reduce_customized-inl.cuh | 415 ------------------ .../linalg/broadcast_reduce_customized-inl.h | 7 - src/operator/numpy/linalg/np_norm-inl.h | 98 +++-- src/operator/numpy/np_broadcast_reduce_op.h | 95 +++- .../numpy/np_broadcast_reduce_op_boolean.cu | 8 +- .../numpy/np_broadcast_reduce_op_value.cu | 16 +- src/operator/numpy/np_constraint_check.h | 5 + src/operator/numpy/np_cross-inl.h | 5 + src/operator/numpy/np_kron-inl.h | 11 +- src/operator/numpy/np_tensordot_op-inl.h | 23 +- src/operator/numpy/np_where_op-inl.h | 51 ++- src/operator/numpy/random/dist_common.h | 80 ++++ src/operator/numpy/random/np_exponential_op.h | 5 + src/operator/numpy/random/np_gamma_op.h | 5 + .../numpy/random/np_location_scale_op.h | 72 +-- src/operator/numpy/random/np_normal_op.h | 74 +--- src/operator/numpy/random/np_pareto_op.h | 28 +- src/operator/numpy/random/np_rayleigh_op.h | 28 +- src/operator/numpy/random/np_weibull_op.h | 28 +- src/operator/quantization/quantize_v2-inl.h | 7 + src/operator/quantization/requantize-inl.h | 13 + src/operator/random/pdf_op.h | 19 +- src/operator/tensor/broadcast_reduce-inl.cuh | 408 ----------------- src/operator/tensor/broadcast_reduce-inl.h | 4 - .../tensor/broadcast_reduce_minmax_value.cu | 6 +- src/operator/tensor/broadcast_reduce_op.cc | 201 +++++++++ src/operator/tensor/broadcast_reduce_op.h | 50 ++- .../tensor/broadcast_reduce_op_value.cu | 3 +- .../tensor/broadcast_reduce_prod_value.cu | 6 +- .../tensor/broadcast_reduce_sum_value.cu | 9 +- .../tensor/elemwise_binary_broadcast_op.cc | 8 +- src/operator/tensor/matrix_op-inl.h | 8 + src/operator/tensor/reduce_rtc.cc | 24 +- tests/python/unittest/test_numpy_op.py | 6 +- 41 files changed, 1454 insertions(+), 1329 deletions(-) delete mode 100644 src/operator/numpy/linalg/broadcast_reduce_customized-inl.cuh delete mode 100644 src/operator/tensor/broadcast_reduce-inl.cuh create mode 100644 src/operator/tensor/broadcast_reduce_op.cc diff --git a/src/common/cuda/rtc/backward_functions-inl.h b/src/common/cuda/rtc/backward_functions-inl.h index 168dc686e7ad..17c0190ef3fe 100644 --- a/src/common/cuda/rtc/backward_functions-inl.h +++ b/src/common/cuda/rtc/backward_functions-inl.h @@ -466,6 +466,73 @@ __device__ inline DType prelu_grad(const DType val, return (val > 0) ? 
0 : val; } +template +__device__ inline typename type_util::mixed_type::type +gamma_implicit_grad(const DType a_in, const DType2 x_in) { + using OType = typename type_util::mixed_type::type; + const OType a = a_in; + const OType x = x_in; + if (x < 0.8f) { + OType numer = 1; + OType denom = a; + OType series1 = numer / denom; + OType series2 = numer / (denom * denom); +#pragma unroll + for (int i = 1; i <= 5; i++) { + numer *= -x / static_cast(i); + denom += 1; + series1 += numer / denom; + series2 += numer / (denom * denom); + } + OType pow_x_alpha = op::power(x, a); + OType gamma_pdf = op::power(x, a - 1) * op::exp(-x); + OType gamma_cdf = pow_x_alpha * series1; + OType gamma_cdf_alpha = + (op::log(x) - OType(special_functions::cephes::psi(a))) * + gamma_cdf - + pow_x_alpha * series2; + OType result = -gamma_cdf_alpha / gamma_pdf; + return op::isnan(result) ? 0.f : result; + } + if (a > 8.0f) { + if (0.9f * a <= x && x <= 1.1f * a) { + OType numer_1 = 1 + 24 * a * (1 + 12 * a); + OType numer_2 = 1440 * (a * a) + 6 * x * (53 - 120 * x) - + 65 * x * x / a + a * (107 + 3600 * x); + OType denom = 1244160 * (a * a) * (a * a); + return numer_1 * numer_2 / denom; + } + OType denom = op::sqrt(8 * a); + OType term2 = denom / (a - x); + OType term3 = + op::power(x - a - a * op::log(x / a), static_cast(-1.5)); + OType term23 = (x < a) ? term2 - term3 : term2 + term3; + OType term1 = op::log(x / a) * term23 - + op::sqrt(2 / a) * (a + x) / ((a - x) * (a - x)); + OType stirling = 1.f + 1.f / (12.f * a) * (1.f + 1.f / (24.f * a)); + OType numer = x * term1; + return -stirling * numer / denom; + } + OType u = op::log(x / a); + OType v = op::log(a); + OType coef_uv[3][8] = { + {0.16009398, -0.094634809, 0.025146376, -0.0030648343, 1, 0.32668115, + 0.10406089, 0.0014179084}, + {0.53487893, 0.1298071, 0.065735949, -0.0015649758, 0.16639465, + 0.020070113, -0.0035938915, -0.00058392623}, + {0.040121004, -0.0065914022, -0.0026286047, -0.0013441777, 0.017050642, + -0.0021309326, 0.00085092367, -1.5247877e-07}, + }; + OType coef_v[8]; +#pragma unroll + for (int i = 0; i < 8; i++) { + coef_v[i] = coef_uv[0][i] + u * (coef_uv[1][i] + u * coef_uv[2][i]); + } + OType p = coef_v[0] + v * (coef_v[1] + v * (coef_v[2] + v * coef_v[3])); + OType q = coef_v[4] + v * (coef_v[5] + v * (coef_v[6] + v * coef_v[7])); + return op::exp(p / q); +} + } // namespace op )code"; diff --git a/src/common/cuda/rtc/forward_functions-inl.h b/src/common/cuda/rtc/forward_functions-inl.h index 6568bae58318..b872d5b27b5f 100644 --- a/src/common/cuda/rtc/forward_functions-inl.h +++ b/src/common/cuda/rtc/forward_functions-inl.h @@ -898,6 +898,11 @@ __device__ inline bool_t np_logical_not(const DType val) { return !static_cast(val); } +template +__device__ inline bool_t NonZero(const DType val) { + return val != 0; +} + #undef DEFINE_UNARY_MATH_FUNC template diff --git a/src/common/cuda/rtc/reducer-inl.h b/src/common/cuda/rtc/reducer-inl.h index 93b702788c46..259d0e060a57 100644 --- a/src/common/cuda/rtc/reducer-inl.h +++ b/src/common/cuda/rtc/reducer-inl.h @@ -94,6 +94,405 @@ struct sum { residual = 0; } }; + +/*! \brief maximum reducer */ +struct maximum { + /*! \brief do reduction into dst */ + template + __device__ inline static void Reduce(volatile DType& dst, volatile DType2 src) { // NOLINT(*) + if (!util::isnan(dst)) { + if (!(dst >= src)) dst = src; + } + } + /*! 
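// ---------------------------------------------------------------------------
// [Illustrative sketch, not part of the patch] gamma_implicit_grad above
// computes an implicit-reparameterization gradient d x / d alpha for a sample
// x ~ Gamma(alpha, 1): the x < 0.8 branch literally forms
// -(d CDF / d alpha) / pdf. The standalone host program below evaluates the
// same identity numerically; the helper name reg_lower_gamma is made up here.
#include <cmath>
#include <cstdio>

// Regularized lower incomplete gamma P(a, x) via its power series
// (converges quickly for the moderate x used below).
double reg_lower_gamma(double a, double x) {
  double term = 1.0 / a, sum = term;
  for (int n = 1; n < 200; ++n) {
    term *= x / (a + n);
    sum += term;
    if (term < 1e-16 * sum) break;
  }
  return sum * std::exp(-x + a * std::log(x) - std::lgamma(a));
}

int main() {
  const double a = 0.5, x = 0.3, eps = 1e-6;
  // dF/dalpha by central differences; dF/dx is the Gamma(a, 1) density.
  const double dF_da =
      (reg_lower_gamma(a + eps, x) - reg_lower_gamma(a - eps, x)) / (2 * eps);
  const double pdf = std::exp((a - 1) * std::log(x) - x - std::lgamma(a));
  std::printf("dx/dalpha at (a=%.1f, x=%.1f): %f\n", a, x, -dF_da / pdf);
  return 0;
}
// ---------------------------------------------------------------------------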
\brief do reduction into dst */ + template + __device__ inline static void Reduce(volatile DType& dst, volatile DType2 src, + volatile DType& none) { + Reduce(dst, src); + } + /*! \brief combine the results of two reducers */ + template + __device__ inline static void Merge(volatile DType& dst_val, volatile DType& src_val) { + Reduce(dst_val, src_val); + } + /*! \brief combine the results of two reducers */ + template + __device__ inline static void Merge(volatile DType& dst_val, volatile DType& dst_residual, + volatile DType& src_val, volatile DType& src_residual) { + Reduce(dst_val, src_val); + } + /*! \brief finalize reduction result */ + template + __device__ inline static void Finalize(volatile DType& dst) {} + /*! \brief finalize reduction result */ + template + __device__ inline static void Finalize(volatile DType& dst, volatile DType& none) {} + /*! + *\brief set the initial value during reduction + */ + template + __device__ inline static void SetInitValue(DType &initv) { + initv = -2*DBL_MAX; + } + /*! + *\brief set the initial value during reduction + */ + template + __device__ inline static void SetInitValue(DType &initv, DType &none) { + SetInitValue(initv); + } +}; + +/*! \brief minimum reducer */ +struct minimum { + /*! \brief do reduction into dst */ + template + __device__ inline static void Reduce(volatile DType& dst, volatile DType2 src) { + if (!util::isnan(dst)) { + if (!(dst <= src)) dst = src; + } + } + /*! \brief do reduction into dst */ + template + __device__ inline static void Reduce(volatile DType& dst, volatile DType2 src, + volatile DType& none) { + Reduce(dst, src); + } + /*! \brief combine the results of two reducers */ + template + __device__ inline static void Merge(volatile DType& dst_val, volatile DType& src_val) { + Reduce(dst_val, src_val); + } + /*! \brief combine the results of two reducers */ + template + __device__ inline static void Merge(volatile DType& dst_val, volatile DType& dst_residual, + volatile DType& src_val, volatile DType& src_residual) { + Reduce(dst_val, src_val); + } + /*! \brief finalize reduction result */ + template + __device__ inline static void Finalize(volatile DType& dst) {} + /*! \brief finalize reduction result */ + template + __device__ inline static void Finalize(volatile DType& dst, volatile DType& none) {} + /*! + *\brief set the initial value during reduction + */ + template + __device__ inline static void SetInitValue(DType &initv) { + initv = 2*DBL_MAX; + } + /*! + *\brief set the initial value during reduction + */ + template + __device__ inline static void SetInitValue(DType &initv, DType &none) { + SetInitValue(initv); + } +}; + +/*! \brief product reducer */ +struct product { + /*! \brief do reduction into dst */ + template + __device__ inline static void Reduce(volatile DType& dst, volatile DType2 src) { + dst = op::mul(dst, src); + } + /*! \brief do reduction into dst */ + template + __device__ inline static void Reduce(volatile DType& dst, volatile DType2 src, + volatile DType& none) { + Reduce(dst, src); + } + /*! \brief combine the results of two reducers */ + template + __device__ inline static void Merge(volatile DType& dst_val, volatile DType& src_val) { + Reduce(dst_val, src_val); + } + /*! \brief combine the results of two reducers */ + template + __device__ inline static void Merge(volatile DType& dst_val, volatile DType& dst_residual, + volatile DType& src_val, volatile DType& src_residual) { + Reduce(dst_val, src_val); + } + /*! 
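// ---------------------------------------------------------------------------
// [Illustrative sketch, not part of the patch] The maximum/minimum reducers
// above use `if (!(dst >= src)) dst = src;` rather than `if (src > dst)`.
// Every comparison involving NaN is false, so the running result becomes NaN
// as soon as either operand is NaN and then stays NaN, i.e. the reduction is
// NaN-propagating. A minimal host-side demonstration:
#include <cmath>
#include <cstdio>
#include <limits>

float running_max(const float* vals, int n) {
  float dst = -std::numeric_limits<float>::infinity();  // cf. SetInitValue's -2*DBL_MAX
  for (int i = 0; i < n; ++i) {
    if (!std::isnan(dst)) {
      if (!(dst >= vals[i])) dst = vals[i];  // also fires when vals[i] is NaN
    }
  }
  return dst;
}

int main() {
  const float clean[] = {1.f, 5.f, 2.f};
  const float dirty[] = {1.f, std::nanf(""), 2.f};
  std::printf("%f %f\n", running_max(clean, 3), running_max(dirty, 3));  // 5, nan
  return 0;
}
// ---------------------------------------------------------------------------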
\brief finalize reduction result */ + template + __device__ inline static void Finalize(volatile DType& dst) {} + /*! \brief finalize reduction result */ + template + __device__ inline static void Finalize(volatile DType& dst, volatile DType& none) {} + /*! + *\brief set the initial value during reduction + */ + template + __device__ inline static void SetInitValue(DType &initv) { + initv = 1; + } + /*! + *\brief set the initial value during reduction + */ + template + __device__ inline static void SetInitValue(DType &initv, DType &none) { + SetInitValue(initv); + } +}; + +/*! \brief sum reducer that ignores NaN values in the input */ +struct nansum { + /*! \brief do reduction into dst */ + template + __device__ inline static void Reduce(volatile DType& dst, volatile DType2 src) { + if (util::isnan(src)) return; + dst = op::add(dst, src); + } + /*! \brief do reduction into dst */ + template + __device__ inline static void Reduce(volatile DType& dst, volatile DType src, + volatile DType& residual) { + if (util::isnan(src)) return; + DType y = src - residual; + DType t = dst + y; + residual = (t - dst) - y; + dst = t; + } + /*! \brief combine the results of two reducers */ + template + __device__ inline static void Merge(volatile DType& dst_val, volatile DType& src_val) { + Reduce(dst_val, src_val); + } + /*! \brief combine the results of two reducers */ + template + __device__ inline static void Merge(volatile DType& dst_val, volatile DType& dst_residual, + volatile DType& src_val, volatile DType& src_residual) { + DType t1 = dst_val + src_val; + DType e = t1 - src_val; + DType t2 = ((src_val - e) + (dst_val - (t1 - e))) + dst_residual + src_residual; + dst_val = t1 + t2; + dst_residual = t2 - (dst_val - t1); + } + /*! \brief finalize reduction result */ + template + __device__ inline static void Finalize(volatile DType& dst) {} + /*! \brief finalize reduction result */ + template + __device__ inline static void Finalize(volatile DType& dst, volatile DType& none) {} + /*! + *\brief set the initial value during reduction + */ + template + __device__ inline static void SetInitValue(DType & initv) { + initv = 0; + } + /*! + *\brief set the initial value during reduction + */ + template + __device__ inline static void SetInitValue(DType &initv, DType &residual) { + SetInitValue(initv); + residual = 0; + } +}; + +/*! \brief product reducer that ignores NaN values in the input */ +struct nanprod { + /*! \brief do reduction into dst */ + template + __device__ inline static void Reduce(volatile DType& dst, volatile DType2 src) { + if (util::isnan(src)) return; + dst = op::mul(dst, src); + } + /*! \brief do reduction into dst */ + template + __device__ inline static void Reduce(volatile DType& dst, volatile DType src, + volatile DType& none) { + Reduce(dst, src); + } + /*! \brief combine the results of two reducers */ + template + __device__ inline static void Merge(volatile DType& dst_val, volatile DType& src_val) { + Reduce(dst_val, src_val); + } + /*! \brief combine the results of two reducers */ + template + __device__ inline static void Merge(volatile DType& dst_val, volatile DType& dst_residual, + volatile DType& src_val, volatile DType& src_residual) { + Reduce(dst_val, src_val); + } + /*! \brief finalize reduction */ + template + __device__ inline static void Finalize(volatile DType& dst) {} + /*! \brief finalize reduction */ + template + __device__ inline static void Finalize(volatile DType& dst, volatile DType& none) {} + /*! 
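// ---------------------------------------------------------------------------
// [Illustrative sketch, not part of the patch] The `residual` threaded through
// the sum / nansum reducers above is Kahan compensation: it keeps the
// low-order bits that are rounded away when a small value is added to a large
// running sum, and Merge() combines two partial (value, residual) pairs
// without discarding that compensation. A minimal host-side demonstration:
#include <cstdio>

int main() {
  float plain = 1e8f;
  float kahan = 1e8f, residual = 0.f;
  for (int i = 0; i < 10000; ++i) {
    const float src = 0.25f;
    plain += src;                      // 0.25f is rounded away at this magnitude
    const float y = src - residual;    // same update as the Reduce() above
    const float t = kahan + y;
    residual = (t - kahan) - y;
    kahan = t;
  }
  std::printf("plain=%.1f kahan=%.1f exact=%.1f\n", plain, kahan, 1e8 + 2500.0);
  return 0;
}
// ---------------------------------------------------------------------------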
+ *\brief set the initial value during reduction + */ + template + __device__ inline static void SetInitValue(DType & initv) { + initv = 1; + } + /*! + *\brief set the initial value during reduction + */ + template + __device__ inline static void SetInitValue(DType &initv, DType &none) { + SetInitValue(initv); + } +}; + +struct nrm2 { + /*! \brief do reduction into dst */ + template + __device__ inline static void Reduce(volatile AType& sum_of_squares, volatile DType src) { + sum_of_squares = op::add(sum_of_square, src * src); + } + /*! \brief do stable reduction into dst */ + template + __device__ inline static void Reduce(volatile AType& sum_of_squares, + volatile DType src, volatile DType& scale) { + if (src != 0) { + DType abs = op::abs(src); + if (scale < abs) { + sum_of_squares = 1 + sum_of_squares * (scale / abs) * (scale / abs); + scale = abs; + } else { + sum_of_squares = sum_of_squares + (abs / scale) * (abs / scale); + } + } + } + /*! \brief combine the results of two reducers */ + template + __device__ inline static void Merge(volatile DType& dst_val, volatile DType& src_val) { + dst_val = op::add(dst_val, src_val); + } + /*! \brief combine the results of two reducers */ + template + __device__ inline static void Merge(volatile DType& dst_ssq, volatile DType& dst_scale, + volatile DType& src_ssq, volatile DType& src_scale) { + if (dst_scale != 0 && dst_scale >= src_scale) { + dst_ssq = dst_ssq + src_ssq * (src_scale / dst_scale) * (src_scale / dst_scale); + } else if (src_scale != 0 && dst_scale < src_scale) { + dst_ssq = src_ssq + dst_ssq * (dst_scale / src_scale) * (dst_scale / src_scale); + dst_scale = src_scale; + } + } + /*! \brief finalize reduction result */ + template + __device__ inline static void Finalize(volatile DType& sum_of_squares) { + sum_of_squares = op::sqrt(sum_of_squares); + } + /*! \brief finalize reduction result */ + template + __device__ inline static void Finalize(volatile DType& sum_of_squares, volatile DType& scale) { + sum_of_squares = scale * op::sqrt(sum_of_squares); + } + /*! + *\brief set the initial value during reduction + */ + template + __device__ inline static void SetInitValue(DType &sum_of_squares) { + sum_of_squares = 0; + } + /*! + *\brief set the initial value during reduction + */ + template + __device__ inline static void SetInitValue(DType &sum_of_squares, DType &scale) { + SetInitValue(sum_of_squares); + scale = 0; + } +}; + +struct nrmlp { + double lp; + /* \brief power for Lp norm */ + __device__ inline static double lp_power(volatile double src, volatile double p) { + if (p != 0.0) { + if (src == 0.0) { + return src; + } else { + return op::power(src, p); + } + } else { // 0-norm, sparsity + return static_cast(src != 0); + } + } + + /*! \brief do reduction into dst */ + template + __device__ inline void Reduce(volatile AType& sum_of_powers, volatile DType src) { + if (src != 0) { + sum_of_powers += AType(lp_power(static_cast(src), lp)); + } + } + + /*! \brief do stable reduction into dst */ + template + __device__ inline void Reduce(volatile AType& sum_of_powers, volatile DType src, + volatile DType& scale) { + if (src != 0) { + DType src_abs = op::abs(src); + if (scale < src_abs) { + sum_of_powers = sum_of_powers * AType(lp_power(static_cast(scale / src_abs), lp)); + sum_of_powers = sum_of_powers + 1; + scale = src_abs; + } else { + sum_of_powers = sum_of_powers + AType(lp_power(static_cast(src_abs / scale), lp)); + } + } + } + + /*! 
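// ---------------------------------------------------------------------------
// [Illustrative sketch, not part of the patch] The two-argument Reduce() of the
// nrm2 / nrmlp reducers above keeps a running pair (scale, accumulator):
// `scale` tracks the largest |x| seen so far and squares (or |x|^p for nrmlp)
// are accumulated relative to it, so intermediates never overflow or
// underflow; Finalize() then returns scale * sqrt(sum_of_squares). A host-side
// version of the nrm2 variant:
#include <cmath>
#include <cstdio>

double stable_nrm2(const double* x, int n) {
  double scale = 0.0, ssq = 0.0;               // SetInitValue
  for (int i = 0; i < n; ++i) {
    if (x[i] == 0.0) continue;
    const double a = std::fabs(x[i]);
    if (scale < a) {                            // rescale what was accumulated
      ssq = 1.0 + ssq * (scale / a) * (scale / a);
      scale = a;
    } else {
      ssq += (a / scale) * (a / scale);
    }
  }
  return scale * std::sqrt(ssq);                // Finalize
}

int main() {
  const double big[] = {1e200, 1e200};          // naive sum of squares -> inf
  std::printf("%g\n", stable_nrm2(big, 2));     // ~1.41421e+200
  return 0;
}
// ---------------------------------------------------------------------------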
\brief combine the results of two reducers */ + template + __device__ inline static void Merge(volatile DType& dst_val, volatile DType& src_val) { + dst_val = dst_val + src_val; + } + + /*! \brief combine the results of two reducers */ + template + __device__ inline static void Merge(volatile DType& dst_ssq, volatile DType& dst_scale, + volatile DType& src_ssq, volatile DType& src_scale) { + if (dst_scale != 0 && dst_scale >= src_scale) { + dst_ssq = dst_ssq + src_ssq * DType(lp_power(static_cast(src_scale / dst_scale), 2)); + } else if (src_scale != 0 && dst_scale < src_scale) { + dst_ssq = src_ssq + dst_ssq * DType(lp_power(static_cast(dst_scale / src_scale), 2)); + dst_scale = src_scale; + } + } + + /*! \brief finalize reduction result */ + template + __device__ inline void Finalize(volatile DType& sum_of_powers) { + if (lp != 0.0) { + sum_of_powers = DType(lp_power(static_cast(sum_of_powers), 1.0 / lp)); + } + } + + /*! \brief finalize reduction result */ + template + __device__ inline void Finalize(volatile DType& sum_of_powers, volatile DType& scale) { + if (lp != 0.0) { + sum_of_powers = scale * DType(lp_power(static_cast(sum_of_powers), 1.0 / lp)); + } + } + + /*! + *\brief set the initial value during reduction + */ + template + __device__ inline static void SetInitValue(DType &sum_of_powers) { + sum_of_powers = 0; + } + + /*! + *\brief set the initial value during reduction + */ + template + __device__ inline static void SetInitValue(DType &sum_of_powers, DType &scale) { + SetInitValue(sum_of_powers); + scale = 0; + } +}; } // namespace red )code"; diff --git a/src/operator/mshadow_op.h b/src/operator/mshadow_op.h index ccc39ab1d8bf..71053367da90 100644 --- a/src/operator/mshadow_op.h +++ b/src/operator/mshadow_op.h @@ -904,7 +904,7 @@ template<> MSHADOW_XINLINE mshadow::half::half_t mod_rgrad::Map (mshadow::half::half_t a, mshadow::half::half_t b) { - return mshadow::half::half_t(-::floorf(static_cast(a/b))); + return mshadow::half::half_t(-::floorf(static_cast(a)/static_cast(b))); } struct rmod : public mxnet_op::tunable { diff --git a/src/operator/nn/group_norm-inl.h b/src/operator/nn/group_norm-inl.h index da30192231c7..1850062f3ac7 100644 --- a/src/operator/nn/group_norm-inl.h +++ b/src/operator/nn/group_norm-inl.h @@ -123,14 +123,23 @@ void GroupNormCompute(const nnvm::NodeAttrs& attrs, workspace = ctx.requested[0].get_space_typed(Shape1(workspace_size), s); // Calculate mean +#if !defined(__CUDACC__) MSHADOW_REAL_TYPE_SWITCH(data.type_flag_, DType, { BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, { broadcast::Reduce( s, mean_, req[0], workspace, data_); - Tensor mean_data_tensor = mean_.FlatTo1D(s); - mean_data_tensor /= scalar(channel_size); }); }); +#else + BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, { + broadcast::RTCReduce(ctx, mean_, req[0], workspace, + data_, "red::sum{}", NDim, "identity"); + }); +#endif // !defined(__CUDACC__) + MSHADOW_REAL_TYPE_SWITCH(data.type_flag_, DType, { + Tensor mean_data_tensor = mean_.FlatTo1D(s); + mean_data_tensor /= scalar(channel_size); + }); TBlob data_grp = data.reshape(temp_data_shape); const TBlob& mean_grp = mean.reshape(moments_shape); @@ -150,15 +159,25 @@ void GroupNormCompute(const nnvm::NodeAttrs& attrs, // Calculate std const TBlob centered_out = outputs[groupnorm::kOut].reshape(red_src_shape); +#if !defined(__CUDACC__) MSHADOW_REAL_TYPE_SWITCH(output_grp.type_flag_, DType, { BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, { broadcast::Reduce( s, std_, req[0], workspace, centered_out); - Tensor std_data_tensor 
= std_.FlatTo1D(s); - std_data_tensor = F(std_data_tensor / scalar(channel_size) - + scalar(param.eps)); }); }); +#else + BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, { + broadcast::RTCReduce(ctx, std_, req[0], + workspace, centered_out, + "red::sum{}", NDim, "square"); + }); +#endif + MSHADOW_REAL_TYPE_SWITCH(output_grp.type_flag_, DType, { + Tensor std_data_tensor = std_.FlatTo1D(s); + std_data_tensor = F(std_data_tensor / scalar(channel_size) + + scalar(param.eps)); + }); // Calculate data = data / std #if !defined(__CUDACC__) @@ -300,14 +319,6 @@ void GroupNormGradCompute(const nnvm::NodeAttrs& attrs, BinaryBroadcastCompute(attrs, ctx, {normalized_data, std_}, {kWriteTo}, {normalized_data}); -#else - BinaryBroadcastRTCCompute {"sub"}(attrs, ctx, - {data_, mean_}, - {kWriteTo}, {normalized_data}); - BinaryBroadcastRTCCompute {"div"}(attrs, ctx, - {normalized_data, std_}, - {kWriteTo}, {normalized_data}); -#endif // !defined(__CUDACC__) // Calculate grad_beta if (req[2] != kNullOp) { MSHADOW_REAL_TYPE_SWITCH(outputs[2].type_flag_, DType, { @@ -319,13 +330,8 @@ void GroupNormGradCompute(const nnvm::NodeAttrs& attrs, }); } // Calculate grad_gamma, it will be sum(ograd * normalized_data, exclude_axis) -#if !defined(__CUDACC__) ElemwiseBinaryOp::Compute(attrs, ctx, {normalized_data, ograd}, {kWriteTo}, {ograd_mult}); -#else - ElemwiseBinaryRTCCompute {"mul"}(attrs, ctx, {normalized_data, ograd}, - {kWriteTo}, {ograd_mult}); -#endif // !defined(__CUDACC__) if (req[1] != kNullOp) { MSHADOW_REAL_TYPE_SWITCH(outputs[1].type_flag_, DType, { BROADCAST_NDIM_SWITCH(red_exclude_dst_shape.ndim(), NDim, { @@ -335,6 +341,32 @@ void GroupNormGradCompute(const nnvm::NodeAttrs& attrs, }); }); } +#else + BinaryBroadcastRTCCompute {"sub"}(attrs, ctx, + {data_, mean_}, + {kWriteTo}, {normalized_data}); + BinaryBroadcastRTCCompute {"div"}(attrs, ctx, + {normalized_data, std_}, + {kWriteTo}, {normalized_data}); + // Calculate grad_beta + if (req[2] != kNullOp) { + BROADCAST_NDIM_SWITCH(red_exclude_dst_shape.ndim(), NDim, { + broadcast::RTCReduce(ctx, outputs[2].reshape(red_exclude_dst_shape), + req[2], workspace, ograd.reshape(red_exclude_src_shape), + "red::sum{}", NDim, "identity"); + }); + } + // Calculate grad_gamma, it will be sum(ograd * normalized_data, exclude_axis) + ElemwiseBinaryRTCCompute {"mul"}(attrs, ctx, {normalized_data, ograd}, + {kWriteTo}, {ograd_mult}); + if (req[1] != kNullOp) { + BROADCAST_NDIM_SWITCH(red_exclude_dst_shape.ndim(), NDim, { + broadcast::RTCReduce(ctx, outputs[1].reshape(red_exclude_dst_shape), + req[1], workspace, ograd_mult.reshape(red_exclude_src_shape), + "red::sum{}", NDim, "identity"); + }); + } +#endif // !defined(__CUDACC__) // Calculate grad_data: // ograd_mult = ograd * gamma / std @@ -350,15 +382,6 @@ void GroupNormGradCompute(const nnvm::NodeAttrs& attrs, BinaryBroadcastCompute(attrs, ctx, {ograd_mult, std_}, {kWriteTo}, {ograd_mult}); -#else - BinaryBroadcastRTCCompute {"mul"}(attrs, ctx, - {inputs[0], gamma}, - {kWriteTo}, - {ograd_mult.reshape(data.shape_)}); - BinaryBroadcastRTCCompute {"div"}(attrs, ctx, - {ograd_mult, std_}, - {kWriteTo}, {ograd_mult}); -#endif // !defined(__CUDACC__) MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, { BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, { broadcast::Reduce( @@ -368,19 +391,11 @@ void GroupNormGradCompute(const nnvm::NodeAttrs& attrs, Tensor red_out_tensor = red_out.FlatTo1D(s); red_out_tensor /= scalar(N); }); -#if !defined(__CUDACC__) BinaryBroadcastCompute(attrs, ctx, {ograd_mult, red_out}, 
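// ---------------------------------------------------------------------------
// [Illustrative sketch, not part of the patch] In GroupNormCompute above, the
// CPU branch (broadcast::Reduce under the MSHADOW/BROADCAST switches) and the
// new GPU branch (broadcast::RTCReduce with the "red::sum{}" reducer string
// and the "square" op) compute the same per-group statistics:
// mean = sum(x)/C followed by std = sqrt(sum((x - mean)^2)/C + eps).
// A scalar version of that math for a single group of C values:
#include <cmath>
#include <cstdio>

int main() {
  const float x[] = {1.f, 2.f, 3.f, 4.f};         // one group, C = 4
  const int C = 4;
  const float eps = 1e-5f;
  float mean = 0.f, var = 0.f;
  for (int i = 0; i < C; ++i) mean += x[i];
  mean /= C;                                       // first reduce + divide
  for (int i = 0; i < C; ++i) var += (x[i] - mean) * (x[i] - mean);
  const float std_ = std::sqrt(var / C + eps);     // second reduce + sqrt
  for (int i = 0; i < C; ++i) std::printf("%f ", (x[i] - mean) / std_);
  std::printf("\n");
  return 0;
}
// ---------------------------------------------------------------------------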
{req[0]}, {output_}); ElemwiseBinaryOp::Compute(attrs, ctx, {ograd_mult, normalized_data}, {kWriteTo}, {ograd_mult}); -#else - BinaryBroadcastRTCCompute {"sub"}(attrs, ctx, - {ograd_mult, red_out}, - {req[0]}, {output_}); - ElemwiseBinaryRTCCompute {"mul"}(attrs, ctx, {ograd_mult, normalized_data}, - {kWriteTo}, {ograd_mult}); -#endif // !defined(__CUDACC__) MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, { BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, { broadcast::Reduce( @@ -390,11 +405,38 @@ void GroupNormGradCompute(const nnvm::NodeAttrs& attrs, Tensor red_out_tensor = red_out.FlatTo1D(s); red_out_tensor /= scalar(-N); }); -#if !defined(__CUDACC__) BinaryBroadcastCompute(attrs, ctx, {normalized_data, red_out}, {kAddTo}, {output_}); #else + BinaryBroadcastRTCCompute {"mul"}(attrs, ctx, + {inputs[0], gamma}, + {kWriteTo}, + {ograd_mult.reshape(data.shape_)}); + BinaryBroadcastRTCCompute {"div"}(attrs, ctx, + {ograd_mult, std_}, + {kWriteTo}, {ograd_mult}); + BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, { + broadcast::RTCReduce(ctx, red_out.reshape(red_dst_shape), kWriteTo, workspace, + ograd_mult.reshape(red_src_shape), "red::sum{}", NDim, "identity"); + }); + MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, { + Tensor red_out_tensor = red_out.FlatTo1D(s); + red_out_tensor /= scalar(N); + }); + BinaryBroadcastRTCCompute {"sub"}(attrs, ctx, + {ograd_mult, red_out}, + {req[0]}, {output_}); + ElemwiseBinaryRTCCompute {"mul"}(attrs, ctx, {ograd_mult, normalized_data}, + {kWriteTo}, {ograd_mult}); + BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, { + broadcast::RTCReduce(ctx, red_out.reshape(red_dst_shape), kWriteTo, workspace, + ograd_mult.reshape(red_src_shape), "red::sum{}", NDim, "identity"); + }); + MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, { + Tensor red_out_tensor = red_out.FlatTo1D(s); + red_out_tensor /= scalar(-N); + }); BinaryBroadcastRTCCompute {"mul"}(attrs, ctx, {normalized_data, red_out}, {kAddTo}, {output_}); diff --git a/src/operator/nn/layer_norm-inl.h b/src/operator/nn/layer_norm-inl.h index b440e3d96952..0dcbdf511a03 100644 --- a/src/operator/nn/layer_norm-inl.h +++ b/src/operator/nn/layer_norm-inl.h @@ -38,6 +38,7 @@ #include "../operator_common.h" #include "../mxnet_op.h" #include "../tensor/broadcast_reduce_op.h" +#include "mxnet/tuple.h" namespace mxnet { namespace op { @@ -114,6 +115,7 @@ void LayerNormComputeGeneral(const nnvm::NodeAttrs& attrs, }); workspace = ctx.requested[0].get_space_typed(Shape1(workspace_size), s); +#if !defined(__CUDACC__) bool safe_acc = dmlc::GetEnv("MXNET_SAFE_ACCUMULATION", true); if (!safe_acc && inputs[0].type_flag_ == mshadow::kFloat16) { common::LogOnce("MXNET_SAFE_ACCUMULATION=1 is recommended for float16 inputs for LayerNorm. 
" @@ -136,15 +138,9 @@ void LayerNormComputeGeneral(const nnvm::NodeAttrs& attrs, }); }); // Calculate data = data - mean -#if !defined(__CUDACC__) BinaryBroadcastCompute(attrs, ctx, {inputs[0], outputs[layernorm::kMean]}, {kWriteTo}, {outputs[0]}); -#else - BinaryBroadcastRTCCompute {"sub"}(attrs, ctx, - {inputs[0], outputs[layernorm::kMean]}, - {kWriteTo}, {outputs[0]}); -#endif // !defined(__CUDACC__) // Calculate std const TBlob centered_out = outputs[0].reshape(red_src_shape); MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, { @@ -161,7 +157,6 @@ void LayerNormComputeGeneral(const nnvm::NodeAttrs& attrs, + scalar(param.eps)); }); }); -#if !defined(__CUDACC__) // Calculate data = data / std BinaryBroadcastCompute(attrs, ctx, {outputs[0], outputs[layernorm::kStd]}, @@ -175,6 +170,30 @@ void LayerNormComputeGeneral(const nnvm::NodeAttrs& attrs, {outputs[0], beta}, {kWriteTo}, {outputs[0]}); #else + // Calculate mean + BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, { + broadcast::RTCReduce(ctx, mean_data, req[0], workspace, in_data, + "red::sum{}", NDim, "identity"); + }); + MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, { + Tensor mean_data_tensor = mean_data.FlatTo1D(s); + mean_data_tensor /= scalar(channel_size); + }); + // Calculate data = data - mean + BinaryBroadcastRTCCompute {"sub"}(attrs, ctx, + {inputs[0], outputs[layernorm::kMean]}, + {kWriteTo}, {outputs[0]}); + // Calculate std + const TBlob centered_out = outputs[0].reshape(red_src_shape); + BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, { + broadcast::RTCReduce(ctx, std_data, req[0], workspace, centered_out, + "red::sum{}", NDim, "square"); + }); + MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, { + Tensor std_data_tensor = std_data.FlatTo1D(s); + std_data_tensor = F(std_data_tensor / scalar(channel_size) + + scalar(param.eps)); + }); // Calculate data = data / std BinaryBroadcastRTCCompute {"div"}(attrs, ctx, {outputs[0], outputs[layernorm::kStd]}, @@ -187,7 +206,7 @@ void LayerNormComputeGeneral(const nnvm::NodeAttrs& attrs, BinaryBroadcastRTCCompute {"add"}(attrs, ctx, {outputs[0], beta}, {kWriteTo}, {outputs[0]}); -#endif // !defined(__CUDACC__) +#endif } template @@ -196,95 +215,56 @@ void LayerNormGradCompute(const nnvm::NodeAttrs& attrs, const std::vector& req, const std::vector& outputs); -/* -Calculate the gradient of layer normalization. 
-We have the following gradient for gamma, beta and x: - -\bar{x} = (x - mean) / std -w = og * r / std -grad_gamma = sum(\bar{x} og, exclude_axis) -grad_beta = sum(og, exclude_axis) -grad_x = w - mean(w, axis) - \bar{x} * mean(w * \bar{x}, axis) -*/ template -void LayerNormGradComputeGeneral(const nnvm::NodeAttrs& attrs, - const OpContext& ctx, const std::vector& inputs, - const std::vector& req, - const std::vector& outputs) { +void LayerNormGradComputeGeneralImpl(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const TBlob& ograd, + const TBlob& data, + const TBlob& gamma, + const TBlob& mean, + const TBlob& std, + const TBlob& normalized_data, + const TBlob& ograd_mult, + const TBlob& red_out, + const std::vector& req, + const std::vector& outputs, + const mshadow::Tensor& workspace, + const mxnet::TShape& red_dst_shape, + const mxnet::TShape& red_src_shape, + const mxnet::TShape& red_exclude_dst_shape, + const mxnet::TShape& red_exclude_src_shape, + const int channel_size); + +#ifndef __CUDACC__ +template <> +void LayerNormGradComputeGeneralImpl(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const TBlob& ograd, + const TBlob& data, + const TBlob& gamma, + const TBlob& mean, + const TBlob& std, + const TBlob& normalized_data, + const TBlob& ograd_mult, + const TBlob& red_out, + const std::vector& req, + const std::vector& outputs, + const mshadow::Tensor& workspace, + const mxnet::TShape& red_dst_shape, + const mxnet::TShape& red_src_shape, + const mxnet::TShape& red_exclude_dst_shape, + const mxnet::TShape& red_exclude_src_shape, + const int channel_size) { using namespace mshadow; using namespace mshadow::expr; - CHECK_EQ(inputs.size(), 5U); - const LayerNormParam& param = nnvm::get(attrs.parsed); - int axis = param.axis; - if (axis < 0) { - axis += inputs[0].ndim(); - } - CHECK(axis >= 0 && axis < inputs[0].ndim()) << "Channel axis out of range: " << param.axis; - Stream *s = ctx.get_stream(); - // Reshape gamma to be broadcastable - mxnet::TShape new_param_shape(inputs[0].shape_.begin(), inputs[0].shape_.end()); - for (int i = 0; i < inputs[0].ndim(); i++) { - if (i != axis) { - new_param_shape[i] = 1; - } - } - const TBlob ograd = inputs[0]; - const TBlob data = inputs[1]; - const TBlob gamma = inputs[2].reshape(new_param_shape); - const TBlob mean = inputs[3]; - const TBlob std = inputs[4]; - // Prepare the necessary shapes for reduction - mxnet::TShape red_src_shape, red_dst_shape, red_exclude_src_shape, red_exclude_dst_shape; - BroadcastReduceShapeCompact(ograd.shape_, mean.shape_, &red_src_shape, &red_dst_shape); - BroadcastReduceShapeCompact(ograd.shape_, gamma.shape_, - &red_exclude_src_shape, &red_exclude_dst_shape); - int channel_size = red_src_shape.Size() / red_dst_shape.Size(); - // Initialize the workspace + Construct the temporary TBlobs - Tensor workspace; - size_t reduce_workspace_size = 0; - size_t data_size = 0; - size_t red_out_size = 0; - MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, { - data_size = sizeof(DType) * data.Size(); - red_out_size = sizeof(DType) * mean.Size(); - // There are two types of reduction workloads: reduce over axis and reduce exclude axis - // We take the maximum of the workspace sizes required by these workloads. - // Also, we explicitly set the req_type=kAddto in case we want to use it. 
- reduce_workspace_size = - std::max(reduce_workspace_size, - broadcast::ReduceWorkspaceSize(s, red_dst_shape, - kAddTo, red_src_shape, - sizeof(DType))); - reduce_workspace_size = - std::max(reduce_workspace_size, - broadcast::ReduceWorkspaceSize(s, red_exclude_dst_shape, kAddTo, - red_exclude_src_shape, - sizeof(DType))); - }); - workspace = ctx.requested[0].get_space_typed( - Shape1(reduce_workspace_size + data_size * 2 + red_out_size), s); - const TBlob normalized_data = TBlob(workspace.dptr_ + reduce_workspace_size, - data.shape_, data.dev_mask(), data.type_flag_, data.dev_id()); - const TBlob ograd_mult = TBlob(workspace.dptr_ + reduce_workspace_size + data_size, - ograd.shape_, ograd.dev_mask(), ograd.type_flag_, ograd.dev_id()); - const TBlob red_out = TBlob(workspace.dptr_ + reduce_workspace_size + data_size * 2, - mean.shape_, mean.dev_mask(), mean.type_flag_, mean.dev_id()); + Stream *s = ctx.get_stream(); // Compute normalized_data = (data - mean) / std -#if !defined(__CUDACC__) - BinaryBroadcastCompute(attrs, ctx, + BinaryBroadcastCompute(attrs, ctx, {data, mean}, {kWriteTo}, {normalized_data}); - BinaryBroadcastCompute(attrs, ctx, + BinaryBroadcastCompute(attrs, ctx, {normalized_data, std}, {kWriteTo}, {normalized_data}); -#else - BinaryBroadcastRTCCompute {"sub"}(attrs, ctx, - {data, mean}, - {kWriteTo}, {normalized_data}); - BinaryBroadcastRTCCompute {"div"}(attrs, ctx, - {normalized_data, std}, - {kWriteTo}, {normalized_data}); -#endif // !defined(__CUDACC__) // Calculate grad_beta bool safe_acc = dmlc::GetEnv("MXNET_SAFE_ACCUMULATION", true); if (req[2] != kNullOp) { @@ -303,13 +283,8 @@ void LayerNormGradComputeGeneral(const nnvm::NodeAttrs& attrs, }); } // Calculate grad_gamma, it will be sum(ograd * normalized_data, exclude_axis) -#if !defined(__CUDACC__) - ElemwiseBinaryOp::Compute(attrs, ctx, {normalized_data, ograd}, + ElemwiseBinaryOp::Compute(attrs, ctx, {normalized_data, ograd}, {kWriteTo}, {ograd_mult}); -#else - ElemwiseBinaryRTCCompute {"mul"}(attrs, ctx, {normalized_data, ograd}, - {kWriteTo}, {ograd_mult}); -#endif // !defined(__CUDACC__) if (req[1] != kNullOp) { MSHADOW_REAL_TYPE_SWITCH(outputs[1].type_flag_, DType, { BROADCAST_NDIM_SWITCH(red_exclude_dst_shape.ndim(), NDim, { @@ -330,21 +305,12 @@ void LayerNormGradComputeGeneral(const nnvm::NodeAttrs& attrs, // grad_data = ograd_mult - mean(ograd_mult, axis) // + normalized_data * (-mean(normalized_data * ograd_mult, axis)) if (req[0] != kNullOp) { -#if !defined(__CUDACC__) - BinaryBroadcastCompute(attrs, ctx, + BinaryBroadcastCompute(attrs, ctx, {ograd, gamma}, {kWriteTo}, {ograd_mult}); - BinaryBroadcastCompute(attrs, ctx, + BinaryBroadcastCompute(attrs, ctx, {ograd_mult, std}, {kWriteTo}, {ograd_mult}); -#else - BinaryBroadcastRTCCompute {"mul"}(attrs, ctx, - {ograd, gamma}, - {kWriteTo}, {ograd_mult}); - BinaryBroadcastRTCCompute {"div"}(attrs, ctx, - {ograd_mult, std}, - {kWriteTo}, {ograd_mult}); -#endif // !defined(__CUDACC__) MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, { BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, { if (safe_acc) { @@ -357,22 +323,14 @@ void LayerNormGradComputeGeneral(const nnvm::NodeAttrs& attrs, ograd_mult.reshape(red_src_shape)); } }); - Tensor red_out_tensor = red_out.FlatTo1D(s); + Tensor red_out_tensor = red_out.FlatTo1D(s); red_out_tensor /= scalar(channel_size); }); -#if !defined(__CUDACC__) - BinaryBroadcastCompute(attrs, ctx, + BinaryBroadcastCompute(attrs, ctx, {ograd_mult, red_out}, {req[0]}, {outputs[0]}); - ElemwiseBinaryOp::Compute(attrs, ctx, 
{ograd_mult, normalized_data}, + ElemwiseBinaryOp::Compute(attrs, ctx, {ograd_mult, normalized_data}, {kWriteTo}, {ograd_mult}); -#else - BinaryBroadcastRTCCompute {"sub"}(attrs, ctx, - {ograd_mult, red_out}, - {req[0]}, {outputs[0]}); - ElemwiseBinaryRTCCompute {"mul"}(attrs, ctx, {ograd_mult, normalized_data}, - {kWriteTo}, {ograd_mult}); -#endif // !defined(__CUDACC__) MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, { BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, { if (safe_acc) { @@ -385,20 +343,181 @@ void LayerNormGradComputeGeneral(const nnvm::NodeAttrs& attrs, ograd_mult.reshape(red_src_shape)); } }); - Tensor red_out_tensor = red_out.FlatTo1D(s); + Tensor red_out_tensor = red_out.FlatTo1D(s); red_out_tensor /= scalar(- channel_size); }); -#if !defined(__CUDACC__) - BinaryBroadcastCompute(attrs, ctx, + BinaryBroadcastCompute(attrs, ctx, {normalized_data, red_out}, {kAddTo}, {outputs[0]}); + } +} + #else + +template <> +void LayerNormGradComputeGeneralImpl(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const TBlob& ograd, + const TBlob& data, + const TBlob& gamma, + const TBlob& mean, + const TBlob& std, + const TBlob& normalized_data, + const TBlob& ograd_mult, + const TBlob& red_out, + const std::vector& req, + const std::vector& outputs, + const mshadow::Tensor& workspace, + const mxnet::TShape& red_dst_shape, + const mxnet::TShape& red_src_shape, + const mxnet::TShape& red_exclude_dst_shape, + const mxnet::TShape& red_exclude_src_shape, + const int channel_size) { + using namespace mshadow; + using namespace mshadow::expr; + Stream *s = ctx.get_stream(); + // Compute normalized_data = (data - mean) / std + BinaryBroadcastRTCCompute {"sub"}(attrs, ctx, + {data, mean}, + {kWriteTo}, {normalized_data}); + BinaryBroadcastRTCCompute {"div"}(attrs, ctx, + {normalized_data, std}, + {kWriteTo}, {normalized_data}); + // Calculate grad_beta + if (req[2] != kNullOp) { + BROADCAST_NDIM_SWITCH(red_exclude_dst_shape.ndim(), NDim, { + broadcast::RTCReduce(ctx, outputs[2].reshape(red_exclude_dst_shape), req[2], workspace, + ograd.reshape(red_exclude_src_shape), "red::sum{}", NDim, "identity"); + }); + } + // Calculate grad_gamma, it will be sum(ograd * normalized_data, exclude_axis) + ElemwiseBinaryRTCCompute {"mul"}(attrs, ctx, {normalized_data, ograd}, + {kWriteTo}, {ograd_mult}); + if (req[1] != kNullOp) { + BROADCAST_NDIM_SWITCH(red_exclude_dst_shape.ndim(), NDim, { + broadcast::RTCReduce(ctx, outputs[1].reshape(red_exclude_dst_shape), req[1], workspace, + ograd_mult.reshape(red_exclude_src_shape), "red::sum{}", NDim, + "identity"); + }); + } + // Calculate grad_data: + // ograd_mult = ograd * gamma / std + // grad_data = ograd_mult - mean(ograd_mult, axis) + // + normalized_data * (-mean(normalized_data * ograd_mult, axis)) + if (req[0] != kNullOp) { + BinaryBroadcastRTCCompute {"mul"}(attrs, ctx, + {ograd, gamma}, + {kWriteTo}, {ograd_mult}); + BinaryBroadcastRTCCompute {"div"}(attrs, ctx, + {ograd_mult, std}, + {kWriteTo}, {ograd_mult}); + BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, { + broadcast::RTCReduce(ctx, red_out.reshape(red_dst_shape), kWriteTo, workspace, + ograd_mult.reshape(red_src_shape), "red::sum{}", NDim, "identity"); + }); + MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, { + Tensor red_out_tensor = red_out.FlatTo1D(s); + red_out_tensor /= scalar(channel_size); + }); + BinaryBroadcastRTCCompute {"sub"}(attrs, ctx, + {ograd_mult, red_out}, + {req[0]}, {outputs[0]}); + ElemwiseBinaryRTCCompute {"mul"}(attrs, ctx, {ograd_mult, 
normalized_data}, + {kWriteTo}, {ograd_mult}); + BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, { + broadcast::RTCReduce(ctx, red_out.reshape(red_dst_shape), kWriteTo, workspace, + ograd_mult.reshape(red_src_shape), "red::sum{}", NDim, "identity"); + }); + MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, { + Tensor red_out_tensor = red_out.FlatTo1D(s); + red_out_tensor /= scalar(- channel_size); + }); BinaryBroadcastRTCCompute {"mul"}(attrs, ctx, {normalized_data, red_out}, {kAddTo}, {outputs[0]}); -#endif // !defined(__CUDACC__) } } +#endif + +/* +Calculate the gradient of layer normalization. +We have the following gradient for gamma, beta and x: + +\bar{x} = (x - mean) / std +w = og * r / std +grad_gamma = sum(\bar{x} og, exclude_axis) +grad_beta = sum(og, exclude_axis) +grad_x = w - mean(w, axis) - \bar{x} * mean(w * \bar{x}, axis) +*/ +template +void LayerNormGradComputeGeneral(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mshadow; + using namespace mshadow::expr; + CHECK_EQ(inputs.size(), 5U); + const LayerNormParam& param = nnvm::get(attrs.parsed); + int axis = param.axis; + if (axis < 0) { + axis += inputs[0].ndim(); + } + CHECK(axis >= 0 && axis < inputs[0].ndim()) << "Channel axis out of range: " << param.axis; + Stream *s = ctx.get_stream(); + // Reshape gamma to be broadcastable + mxnet::TShape new_param_shape(inputs[0].shape_.begin(), inputs[0].shape_.end()); + for (int i = 0; i < inputs[0].ndim(); i++) { + if (i != axis) { + new_param_shape[i] = 1; + } + } + const TBlob ograd = inputs[0]; + const TBlob data = inputs[1]; + const TBlob gamma = inputs[2].reshape(new_param_shape); + const TBlob mean = inputs[3]; + const TBlob std = inputs[4]; + // Prepare the necessary shapes for reduction + mxnet::TShape red_src_shape, red_dst_shape, red_exclude_src_shape, red_exclude_dst_shape; + BroadcastReduceShapeCompact(ograd.shape_, mean.shape_, &red_src_shape, &red_dst_shape); + BroadcastReduceShapeCompact(ograd.shape_, gamma.shape_, + &red_exclude_src_shape, &red_exclude_dst_shape); + int channel_size = red_src_shape.Size() / red_dst_shape.Size(); + // Initialize the workspace + Construct the temporary TBlobs + Tensor workspace; + size_t reduce_workspace_size = 0; + size_t data_size = 0; + size_t red_out_size = 0; + MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, { + data_size = sizeof(DType) * data.Size(); + red_out_size = sizeof(DType) * mean.Size(); + // There are two types of reduction workloads: reduce over axis and reduce exclude axis + // We take the maximum of the workspace sizes required by these workloads. + // Also, we explicitly set the req_type=kAddto in case we want to use it. 
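// ---------------------------------------------------------------------------
// [Illustrative sketch, not part of the patch] A standalone numeric check of
// the layer-norm backward formulas quoted in the comment above, taking
// w = og * gamma / std (the `r` in that comment corresponds to gamma here):
//   x_hat  = (x - mean) / std
//   grad_x = w - mean(w) - x_hat * mean(w * x_hat)
// The analytic value is compared against a central finite difference of the
// scalar loss sum(og * LayerNorm(x)); all names below are illustrative.
#include <cmath>
#include <cstdio>
#include <vector>

static std::vector<double> layer_norm(const std::vector<double>& x,
                                      double gamma, double beta, double eps) {
  const int n = static_cast<int>(x.size());
  double mean = 0.0, var = 0.0;
  for (double v : x) mean += v;
  mean /= n;
  for (double v : x) var += (v - mean) * (v - mean);
  const double std_ = std::sqrt(var / n + eps);
  std::vector<double> y(n);
  for (int i = 0; i < n; ++i) y[i] = gamma * (x[i] - mean) / std_ + beta;
  return y;
}

int main() {
  const std::vector<double> x  = {0.5, -1.0, 2.0, 0.25};
  const std::vector<double> og = {1.0, -0.5, 0.25, 2.0};
  const double gamma = 1.5, beta = 0.1, eps = 1e-5;
  const int n = static_cast<int>(x.size());

  // Analytic gradient w.r.t. x[0], following the formula from the comment.
  double mean = 0.0, var = 0.0;
  for (double v : x) mean += v;
  mean /= n;
  for (double v : x) var += (v - mean) * (v - mean);
  const double std_ = std::sqrt(var / n + eps);
  std::vector<double> xhat(n), w(n);
  double mean_w = 0.0, mean_wx = 0.0;
  for (int i = 0; i < n; ++i) {
    xhat[i] = (x[i] - mean) / std_;
    w[i] = og[i] * gamma / std_;
    mean_w += w[i] / n;
    mean_wx += w[i] * xhat[i] / n;
  }
  const double analytic = w[0] - mean_w - xhat[0] * mean_wx;

  // Finite-difference check on the same coordinate.
  const double h = 1e-6;
  std::vector<double> xp = x, xm = x;
  xp[0] += h;
  xm[0] -= h;
  const std::vector<double> yp = layer_norm(xp, gamma, beta, eps);
  const std::vector<double> ym = layer_norm(xm, gamma, beta, eps);
  double fd = 0.0;
  for (int i = 0; i < n; ++i) fd += og[i] * (yp[i] - ym[i]);
  fd /= 2.0 * h;

  std::printf("analytic=%.6f finite-diff=%.6f\n", analytic, fd);
  return 0;
}
// ---------------------------------------------------------------------------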
+ reduce_workspace_size = + std::max(reduce_workspace_size, + broadcast::ReduceWorkspaceSize(s, red_dst_shape, + kAddTo, red_src_shape, + sizeof(DType))); + reduce_workspace_size = + std::max(reduce_workspace_size, + broadcast::ReduceWorkspaceSize(s, red_exclude_dst_shape, kAddTo, + red_exclude_src_shape, + sizeof(DType))); + }); + workspace = ctx.requested[0].get_space_typed( + Shape1(reduce_workspace_size + data_size * 2 + red_out_size), s); + const TBlob normalized_data = TBlob(workspace.dptr_ + reduce_workspace_size, + data.shape_, data.dev_mask(), data.type_flag_, data.dev_id()); + const TBlob ograd_mult = TBlob(workspace.dptr_ + reduce_workspace_size + data_size, + ograd.shape_, ograd.dev_mask(), ograd.type_flag_, ograd.dev_id()); + const TBlob red_out = TBlob(workspace.dptr_ + reduce_workspace_size + data_size * 2, + mean.shape_, mean.dev_mask(), mean.type_flag_, mean.dev_id()); + + LayerNormGradComputeGeneralImpl(attrs, ctx, ograd, data, gamma, mean, std, normalized_data, + ograd_mult, red_out, req, outputs, workspace, red_dst_shape, + red_src_shape, red_exclude_dst_shape, red_exclude_src_shape, + channel_size); +} } // namespace op } // namespace mxnet diff --git a/src/operator/nn/moments-inl.h b/src/operator/nn/moments-inl.h index 6a9bdc54b905..f8d8130a7692 100644 --- a/src/operator/nn/moments-inl.h +++ b/src/operator/nn/moments-inl.h @@ -126,7 +126,11 @@ inline void MomentsForwardImpl(const OpContext& ctx, small = ReduceAxesShapeImpl(inputs[0].shape_, axes, true, false); } +#if !defined(__CUDACC__) ReduceAxesComputeImpl(ctx, {data}, {req[0]}, {mean}, small); +#else + ReduceAxesRTCComputeImpl(ctx, {data}, {req[0]}, {mean}, small, "red::sum{}", true); +#endif MSHADOW_TYPE_SWITCH(data.type_flag_, DType, { Shape<6> data_shape, mean_shape; for (int i = 0; i < 6; ++i) { @@ -137,8 +141,13 @@ inline void MomentsForwardImpl(const OpContext& ctx, ctx.requested[0].get_space_typed(Shape1(data.shape_.Size()), s);; Kernel::Launch(s, data.shape_.Size(), temp_data.dptr_, data.dptr(), mean.dptr(), data_shape, mean_shape); +#if !defined(__CUDACC__) ReduceAxesComputeImpl( ctx, {TBlob(temp_data).reshape(data.shape_)}, {kWriteTo}, {var}, small); +#else + ReduceAxesRTCComputeImpl(ctx, {TBlob(temp_data).reshape(data.shape_)}, + {kWriteTo}, {var}, small, "red::sum{}", true); +#endif }); } diff --git a/src/operator/numpy/linalg/broadcast_reduce_customized-inl.cuh b/src/operator/numpy/linalg/broadcast_reduce_customized-inl.cuh deleted file mode 100644 index d4374edc9828..000000000000 --- a/src/operator/numpy/linalg/broadcast_reduce_customized-inl.cuh +++ /dev/null @@ -1,415 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! 
- * Copyright (c) 2015-2020 by Contributors - * \file broadcast_reduce_customized-inl.cuh - * \brief Customized CUDA implementations for binary broadcast and reduce - * \author MXNet contributors -*/ -#ifndef MXNET_OPERATOR_NUMPY_LINALG_BROADCAST_REDUCE_INL_CUSTOMIZED_CUH_ -#define MXNET_OPERATOR_NUMPY_LINALG_BROADCAST_REDUCE_INL_CUSTOMIZED_CUH_ - -#include "../../tensor/broadcast_reduce-inl.cuh" - -using namespace mshadow::cuda; - -template -__launch_bounds__(nthread_reduce) -__global__ void reduce_kernel_wr(const int N, const int M, const bool addto, - const DType* __restrict big, OType *small, - const Shape big_shape0, const Shape small_shape, - const Shape big_shape, const Shape big_stride, - const int Mnext, const bool do_transpose, - Reducer* reducer) { - extern __shared__ char shTileChar[]; - AType* shTile = (AType*)(shTileChar); - const int tid = threadIdx.x + threadIdx.y*blockDim.x; - const int bx = (do_transpose) ? blockDim.y : blockDim.x; - const int by = (do_transpose) ? blockDim.x : blockDim.y; - const int tidx = (do_transpose) ? tid / by : threadIdx.x; - const int tidy = (do_transpose) ? tid % by : threadIdx.y; - // bool need_clean = !reducer; - // reducer = reducer ? reducer : new Reducer(); - for (int m0 = blockIdx.y; m0 < Mnext; m0 += gridDim.y) { - // This TB handles M range [Mstart, ...., Mend - 1] - const int Mstart = (int)((uint64_t)M*(uint64_t)m0/(uint64_t)Mnext); - const int Mend = (int)((uint64_t)M*(uint64_t)(m0 + 1)/(uint64_t)Mnext); - for (int idx0 = blockIdx.x*bx; idx0 < N; idx0 += bx*gridDim.x) { - int idx = idx0 + tidx; - Shape coord = unravel(idx, small_shape); - int idx_big0 = ravel(coord, big_shape0); - - AType val, residual; - reducer->SetInitValue(val, residual); - if (idx < N) { - for (int k = tidy + Mstart; k < Mend; k += by*unroll) { - int idx_big[unroll]; - #pragma unroll - for (int u=0;u < unroll;u++) { - idx_big[u] = idx_big0 + unravel_dot(k + u*by, big_shape, big_stride); - } - DType tmp[unroll]; - #pragma unroll - for (int u=0;u < unroll;u++) { - if (k + u*by < Mend) { - tmp[u] = OP::Map(big[idx_big[u]]); - } - } - #pragma unroll - for (int u=0;u < unroll;u++) { - if (k + u*by < Mend) reducer->Reduce(val, AType(tmp[u]), residual); - } - } - } - - // Shared memory block bx * by. Reduction is along by. Final result is in tidy=0 - if (by > 1) { - // Fix bx to avoid bank conflicts. Assumes warpSize number of banks - const int fbx = (do_transpose && ((bx & (warpSize - 1)) == 0)) ? 
(bx + 1) : bx; - const int it0 = tidx + tidy*fbx; - shTile[it0 * 2] = val; - shTile[it0 * 2 + 1] = residual; - __syncthreads(); - for (int t=1;t < by;t <<= 1) { - AType tmp, tmp_residual; - reducer->SetInitValue(tmp, tmp_residual); - if (tidy + t < by) { - tmp = shTile[(it0 + t*fbx) * 2]; - tmp_residual = shTile[(it0 + t*fbx) * 2 + 1]; - } - __syncthreads(); - reducer->Merge(shTile[it0 * 2], shTile[it0 * 2 + 1], tmp, tmp_residual); - __syncthreads(); - } - if (idx < N && tidy == 0) { - reducer->Finalize(shTile[tidx * 2], shTile[tidx * 2 + 1]); - assign(&small[idx + m0*N], addto, OType(shTile[tidx * 2])); - } - } else { - if (idx < N) { - reducer->Finalize(val, residual); - assign(&small[idx + m0*N], addto, OType(val)); - } - } - } - } - // if (need_clean) { - // delete reducer; - // } -} - -template -__launch_bounds__(nthread_reduce) -__global__ void reduce_kernel_wr(const int N, const int M, const bool addto, - const DType* __restrict big, const DType* __restrict lhs, - const DType* __restrict rhs, DType *small, - const Shape big_shape0, const Shape lhs_shape0, - const Shape rhs_shape0, const Shape small_shape, - const Shape big_shape, const Shape lhs_shape, - const Shape rhs_shape, const Shape big_stride, - const Shape lhs_stride, const Shape rhs_stride, - const int Mnext, const bool do_transpose, - Reducer* reducer) { - extern __shared__ char shTileChar[]; - DType* shTile = (DType*)(shTileChar); - const int tid = threadIdx.x + threadIdx.y*blockDim.x; - const int bx = (do_transpose) ? blockDim.y : blockDim.x; - const int by = (do_transpose) ? blockDim.x : blockDim.y; - const int tidx = (do_transpose) ? tid / by : threadIdx.x; - const int tidy = (do_transpose) ? tid % by : threadIdx.y; - // bool need_clean = !reducer; - // reducer = reducer ? reducer : new Reducer(); - for (int m0 = blockIdx.y; m0 < Mnext; m0 += gridDim.y) { - // This TB handles M range [Mstart, ...., Mend - 1] - const int Mstart = (int)((uint64_t)M*(uint64_t)m0/(uint64_t)Mnext); - const int Mend = (int)((uint64_t)M*(uint64_t)(m0 + 1)/(uint64_t)Mnext); - for (int idx0 = blockIdx.x*bx; idx0 < N; idx0 += bx*gridDim.x) { - int idx = idx0 + tidx; - Shape coord = unravel(idx, small_shape); - int idx_big0 = ravel(coord, big_shape0); - int idx_lhs0 = ravel(coord, lhs_shape0); - int idx_rhs0 = ravel(coord, rhs_shape0); - - DType val, residual; - reducer->SetInitValue(val, residual); - if (idx < N) { - for (int k = tidy + Mstart; k < Mend; k += by*unroll) { - int idx_big[unroll]; - int idx_lhs[unroll]; - int idx_rhs[unroll]; - #pragma unroll - for (int u=0;u < unroll;u++) { - idx_big[u] = idx_big0 + unravel_dot(k + u*by, big_shape, big_stride); - idx_lhs[u] = idx_lhs0 + unravel_dot(k + u*by, lhs_shape, lhs_stride); - idx_rhs[u] = idx_rhs0 + unravel_dot(k + u*by, rhs_shape, rhs_stride); - } - DType tmp[unroll]; - #pragma unroll - for (int u=0;u < unroll;u++) { - if (k + u*by < Mend) { - tmp[u] = OP1::Map(big[idx_big[u]], OP2::Map(lhs[idx_lhs[u]], rhs[idx_rhs[u]])); - } - } - #pragma unroll - for (int u=0;u < unroll;u++) { - if (k + u*by < Mend) reducer->Reduce(val, tmp[u], residual); - } - } - } - - // Shared memory block bx * by. Reduction is along by. Final result is in tidy=0 - if (by > 1) { - // Fix bx to avoid bank conflicts. Assumes warpSize number of banks - const int fbx = (do_transpose && ((bx & (warpSize - 1)) == 0)) ? 
(bx + 1) : bx; - const int it0 = tidx + tidy*fbx; - shTile[it0 * 2] = val; - shTile[it0 * 2 + 1] = residual; - __syncthreads(); - for (int t=1;t < by;t <<= 1) { - DType tmp, tmp_residual; - reducer->SetInitValue(tmp, tmp_residual); - if (tidy + t < by) { - tmp = shTile[(it0 + t*fbx) * 2]; - tmp_residual = shTile[(it0 + t*fbx) * 2 + 1]; - } - __syncthreads(); - reducer->Merge(shTile[it0 * 2], shTile[it0 * 2 + 1], tmp, tmp_residual); - __syncthreads(); - } - if (idx < N && tidy == 0) { - reducer->Finalize(shTile[tidx * 2], shTile[tidx * 2 + 1]); - assign(&small[idx + m0*N], addto, shTile[tidx * 2]); - } - } else { - if (idx < N) { - reducer->Finalize(val, residual); - assign(&small[idx + m0*N], addto, val); - } - } - } - } - // if (need_clean) { - // delete reducer; - // } -} - -// Simple reduction of lines when M is small -template -__launch_bounds__(kMaxThreadsPerBlock) -__global__ void reduce_lines_kernel_wr(const int N, const int M, const bool addto, - const int small_in_stride, const DType* __restrict small_in, DType *small_out, - Reducer* reducer) { - for (int idx = threadIdx.x + blockIdx.x*blockDim.x; idx < N; idx += blockDim.x*gridDim.x) { - - DType val, residual; - reducer->SetInitValue(val, residual); - for (int k = 0; k < M; k++) { - reducer->Reduce(val, small_in[idx + k*small_in_stride], residual); - } - - if (idx < N) { - reducer->Finalize(val, residual); - assign(&small_out[idx], addto, val); - } - - } -} - -template -__launch_bounds__(kMaxThreadsPerBlock) -__global__ void reduce_kernel_M1_wr(const int N, const bool addto, - const DType* __restrict big, OType *small, const Shape bshape, - const Shape sshape, Reducer* reducer) { - for (int idx = threadIdx.x + blockIdx.x*blockDim.x; idx < N; idx += blockDim.x*gridDim.x) { - Shape coord = unravel(idx, sshape); - int j = ravel(coord, bshape); - AType val, residual; - reducer->SetInitValue(val, residual); - reducer->Reduce(val, AType(OP::Map(big[j])), residual); - reducer->Finalize(val, residual); - assign(&small[idx], addto, OType(val)); - } -} - -template -__launch_bounds__(kMaxThreadsPerBlock) -__global__ void reduce_kernel_M1_wr(const int N, const bool addto, - const DType* __restrict big, - const DType* __restrict lhs, - const DType* __restrict rhs, - DType *small, - const Shape big_shape, - const Shape lhs_shape, - const Shape rhs_shape, - const Shape small_shape, - Reducer* reducer) { - for (int idx = threadIdx.x + blockIdx.x*blockDim.x; idx < N; idx += blockDim.x*gridDim.x) { - Shape coord = unravel(idx, small_shape); - int idx_big = ravel(coord, big_shape); - int idx_lhs = ravel(coord, lhs_shape); - int idx_rhs = ravel(coord, rhs_shape); - DType val, residual; - reducer->SetInitValue(val, residual); - reducer->Reduce(val, OP1::Map(big[idx_big], OP2::Map(lhs[idx_lhs], rhs[idx_rhs])), residual); - reducer->Finalize(val, residual); - assign(&small[idx], addto, val); - } -} - -#define KERNEL_UNROLL_SWITCH(do_unroll, unrollAmount, unrollVar, ...) \ - if (do_unroll) { \ - const int unrollVar = unrollAmount; \ - {__VA_ARGS__} \ - } else { \ - const int unrollVar = 1; \ - {__VA_ARGS__} \ - } - -template -void ReduceImplWithReducer(cudaStream_t stream, const TBlob& small, const OpReqType req, - const TBlob& big, const Tensor& workspace, - const ReduceImplConfig& config, - Reducer* reducer = nullptr) { - bool need_clean = !reducer; - reducer = reducer ? 
reducer : new Reducer(); - if (config.M == 1) { - reduce_kernel_M1_wr - <<< config.kernel_1.gridDim, config.kernel_1.blockDim, 0, stream >>>( - config.N, req == kAddTo, big.dptr(), small.dptr(), big.shape_.get(), - small.shape_.get(), reducer); - MSHADOW_CUDA_POST_KERNEL_CHECK(reduce_kernel_M1_wr); - } else { - OType* small_dptr = small.dptr(); - bool addto = (req == kAddTo); - if (config.Mnext > 1) { - // small_dptr[] is N*Mnext*sizeof(DType) bytes - small_dptr = reinterpret_cast(workspace.dptr_); - addto = false; - // Check that the workspace is contigiuous - CHECK_EQ(workspace.CheckContiguous(), true); - // Check that we have enough storage - CHECK_GE(workspace.size(0), config.workspace_size); - } - - const int by = (config.kernel_1.do_transpose) ? - config.kernel_1.blockDim.x : config.kernel_1.blockDim.y; - const bool do_unroll = ( config.M / (by*config.Mnext) >= unroll_reduce ); - KERNEL_UNROLL_SWITCH(do_unroll, unroll_reduce, UNROLL, { - reduce_kernel_wr - <<< config.kernel_1.gridDim, config.kernel_1.blockDim, config.kernel_1.shMemSize, stream>>>( - config.N, config.M, addto, big.dptr(), small_dptr, big.shape_.get(), - small.shape_.get(), config.rshape.get(), config.rstride.get(), - config.Mnext, config.kernel_1.do_transpose, reducer); - }); - MSHADOW_CUDA_POST_KERNEL_CHECK(reduce_kernel_wr); - - if (config.Mnext > 1) { - reduce_lines_kernel_wr - <<< config.kernel_2.gridSize, config.kernel_2.blockSize, 0, stream >>> - (config.N, config.Mnext, req == kAddTo, config.N, small_dptr, small.dptr(), reducer); - MSHADOW_CUDA_POST_KERNEL_CHECK(reduce_lines_kernel_wr); - } - } - if (need_clean) { - delete reducer; - } -} - -template -void ReduceImplWithReducer(cudaStream_t stream, const TBlob& small, const TBlob& lhs, const TBlob& rhs, - const OpReqType req, const TBlob& big, const Tensor& workspace, - const ReduceImplConfig& config, Reducer* reducer = nullptr) { - bool need_clean = !reducer; - reducer = reducer ? reducer : new Reducer(); - if (config.M == 1) { - reduce_kernel_M1_wr - <<< config.kernel_1.gridDim, config.kernel_1.blockDim, 0, stream >>>( - config.N, req == kAddTo, big.dptr(), lhs.dptr(), rhs.dptr(), - small.dptr(), big.shape_.get(), lhs.shape_.get(), - rhs.shape_.get(), small.shape_.get()); - MSHADOW_CUDA_POST_KERNEL_CHECK(reduce_kernel_M1_wr); - } else { - DType* small_dptr = small.dptr(); - bool addto = (req == kAddTo); - if (config.Mnext > 1) { - // small_dptr[] is N*Mnext*sizeof(DType) bytes - small_dptr = reinterpret_cast(workspace.dptr_); - addto = false; - // Check that the workspace is contigiuous - CHECK_EQ(workspace.CheckContiguous(), true); - // Check that we have enough storage - CHECK_GE(workspace.size(0), config.workspace_size); - } - - const int by = (config.kernel_1.do_transpose) ? 
- config.kernel_1.blockDim.x : config.kernel_1.blockDim.y; - const bool do_unroll = ( config.M / (by*config.Mnext) >= unroll_reduce ); - KERNEL_UNROLL_SWITCH(do_unroll, unroll_reduce, UNROLL, { - reduce_kernel_wr - <<< config.kernel_1.gridDim, config.kernel_1.blockDim, config.kernel_1.shMemSize, stream>>>( - config.N, config.M, addto, big.dptr(), lhs.dptr(), rhs.dptr(), - small_dptr, big.shape_.get(), lhs.shape_.get(), - rhs.shape_.get(), small.shape_.get(), config.rshape, config.lhs_shape, - config.rhs_shape, config.rstride, config.lhs_stride, config.rhs_stride, config.Mnext, - config.kernel_1.do_transpose, reducer); - MSHADOW_CUDA_POST_KERNEL_CHECK(reduce_kernel_wr); - }); - - if (config.Mnext > 1) { - reduce_lines_kernel_wr - <<< config.kernel_2.gridSize, config.kernel_2.blockSize, 0, stream >>> - (config.N, config.Mnext, req == kAddTo, config.N, small_dptr, small.dptr(), reducer); - MSHADOW_CUDA_POST_KERNEL_CHECK(reduce_lines_kernel_wr); - } - } - if (need_clean) { - delete reducer; - } -} - -#undef KERNEL_UNROLL_SWITCH - -template -void ReduceWithReducer(Stream *s, const TBlob& small, const OpReqType req, - const Tensor& workspace, const TBlob& big, Reducer* reducer = nullptr) { - if (req == kNullOp) return; - cudaStream_t stream = Stream::GetStream(s); - bool need_clean = !reducer; - reducer = reducer ? reducer : new Reducer(); - ReduceImplConfig config(small.shape_, big.shape_, nullptr, nullptr, sizeof(DType)); - if (safe_acc) { - MXNET_ACC_TYPE_SWITCH(mshadow::DataType::kFlag, DataType, AType, { - typedef typename std::conditional::type AccType; - MSHADOW_TYPE_SWITCH(small.type_flag_, OType, { - typedef typename std::conditional::type OutType; - config = ReduceImplConfig(small.shape_, big.shape_, nullptr, nullptr, sizeof(AccType)); - ReduceImplWithReducer( - stream, small, req, big, workspace, config, reducer); - }); - }); - } else { - ReduceImplWithReducer(stream, small, req, big, workspace, config, reducer); - } - if (need_clean) { - delete reducer; - } -} - -#endif // MXNET_OPERATOR_NUMPY_LINALG_BROADCAST_REDUCE_INL_CUSTOMIZED_CUH_ diff --git a/src/operator/numpy/linalg/broadcast_reduce_customized-inl.h b/src/operator/numpy/linalg/broadcast_reduce_customized-inl.h index 0226df45f960..2941d54fb56c 100644 --- a/src/operator/numpy/linalg/broadcast_reduce_customized-inl.h +++ b/src/operator/numpy/linalg/broadcast_reduce_customized-inl.h @@ -54,12 +54,6 @@ MSHADOW_XINLINE void seq_reduce_assign_wr(const index_t idx, const size_t M, con assign(&small[idx], addto, OType(val)); } -#ifdef __CUDACC__ -#include "broadcast_reduce_customized-inl.cuh" -#include "../../tensor/broadcast_reduce-inl.cuh" - -#else - template void seq_reduce_compute_wr(const size_t N, const size_t M, const bool addto, const DType *big, OType *small, const Shape bshape, @@ -177,7 +171,6 @@ void ReduceWithReducer(Stream *s, const TBlob& small, const OpReqType req, reducer); } -#endif } // namespace broadcast } // namespace op } // namespace mxnet diff --git a/src/operator/numpy/linalg/np_norm-inl.h b/src/operator/numpy/linalg/np_norm-inl.h index b26e68086852..b8ab439106fe 100644 --- a/src/operator/numpy/linalg/np_norm-inl.h +++ b/src/operator/numpy/linalg/np_norm-inl.h @@ -285,18 +285,10 @@ void NumpyLpNormCompute(const nnvm::NodeAttrs& attrs, } else if (param.ord == std::numeric_limits::infinity()) { // inf norm LOG(FATAL) << "inf norm handled in front-end."; } else { +#ifndef __CUDACC__ mshadow_op::nrmlp host_reducer(param.ord); mshadow_op::nrmlp *reducer_instance = nullptr; -#ifdef __CUDACC__ - Stream *s = 
ctx.get_stream(); - cudaStream_t copy_stream = mshadow::Stream::GetStream(s); - cudaMalloc(reinterpret_cast(&reducer_instance), sizeof(mshadow_op::nrmlp)); - cudaMemcpyAsync(reducer_instance, &host_reducer, sizeof(mshadow_op::nrmlp), - cudaMemcpyHostToDevice, copy_stream); - cudaStreamSynchronize(copy_stream); -#else reducer_instance = &host_reducer; -#endif if (safe_acc) { ReduceAxesComputeImplWithReducer( ctx, inputs, req, outputs, small, reducer_instance); @@ -304,8 +296,10 @@ void NumpyLpNormCompute(const nnvm::NodeAttrs& attrs, ReduceAxesComputeImplWithReducer( ctx, inputs, req, outputs, small, reducer_instance); } -#ifdef __CUDACC__ - cudaFree(reducer_instance); +#else + ReduceAxesRTCComputeImpl( + ctx, inputs, req, outputs, small, "red::nrmlp{" + std::to_string(param.ord) + "}", + false, "abs"); #endif } } @@ -443,8 +437,13 @@ void NumpyMatrixNormCompute(const nnvm::NodeAttrs& attrs, } if (param.flag == 1) { // Frobenius norm - ReduceAxesComputeImplWithReducer( +#if !defined(__CUDACC__) + ReduceAxesComputeImpl( ctx, inputs, req, outputs, reduced_shape); +#else + ReduceAxesRTCComputeImpl( + ctx, inputs, req, outputs, reduced_shape, "red::nrm2{}", false, "identity"); +#endif return; } @@ -456,6 +455,7 @@ void NumpyMatrixNormCompute(const nnvm::NodeAttrs& attrs, MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, { TBlob temp = outputs[1].reshape(sum_shape); std::vector sum_output({temp}); +#if !defined(__CUDACC__) ReduceAxesComputeImpl( ctx, inputs, req, sum_output, sum_shape); if (param.ord > 0) { @@ -465,6 +465,16 @@ void NumpyMatrixNormCompute(const nnvm::NodeAttrs& attrs, ReduceAxesComputeImpl( ctx, sum_output, req, outputs, reduced_shape); } +#else + ReduceAxesRTCComputeImpl(ctx, inputs, req, sum_output, sum_shape, "red::sum{}", false, "abs"); + if (param.ord > 0) { + ReduceAxesRTCComputeImpl(ctx, sum_output, req, outputs, reduced_shape, + "red::maximum{}", false); + } else { + ReduceAxesRTCComputeImpl(ctx, sum_output, req, outputs, reduced_shape, + "red::minimum{}", false); + } +#endif // MXNET_USE_CUDA }); return; } @@ -500,6 +510,7 @@ void NumpyMatrixNormCompute(const nnvm::NodeAttrs& attrs, L_trans[mat_axis[1]] = 1; } + std::vector eigen; MSHADOW_SGL_DBL_TYPE_SWITCH(outputs[0].type_flag_, DType, { Tensor UT = outputs[1].get_with_shape(Shape3(batch_dim, row_dim, row_dim), s); @@ -523,32 +534,46 @@ void NumpyMatrixNormCompute(const nnvm::NodeAttrs& attrs, Tensor svd_input = workspace.get_with_shape(Shape3(batch_dim, row_dim, col_dim), s); gesvd::op(svd_input, UT, L, V, ctx, attrs, &svd_workspace); - TBlob workspace0(reinterpret_cast(temp.dptr_), L_trans, temp.dev_mask(), temp.dev_id()); TransposeImpl(ctx.run_ctx, TBlob(L).reshape(L_shape), workspace0, reduce_axes); - std::vector eigen({ workspace0 }); - if (param.flag == 2) { // nuclear norm - ReduceAxesComputeImpl( + eigen.emplace_back(workspace0); + }); + +#if !defined(__CUDACC__) + if (param.flag == 2) { // nuclear norm + ReduceAxesComputeImpl( + ctx, eigen, req, outputs, reduced_shape); + } else if (dmlc::GetEnv("MXNET_SAFE_ACCUMULATION", true)) { + if (ord == 2) { + ReduceAxesComputeImpl( + ctx, eigen, req, outputs, reduced_shape); + } else if (ord == -2) { + ReduceAxesComputeImpl( ctx, eigen, req, outputs, reduced_shape); - } else if (dmlc::GetEnv("MXNET_SAFE_ACCUMULATION", true)) { - if (ord == 2) { - ReduceAxesComputeImpl( - ctx, eigen, req, outputs, reduced_shape); - } else if (ord == -2) { - ReduceAxesComputeImpl( - ctx, eigen, req, outputs, reduced_shape); - } - } else { - if (ord == 2) { - ReduceAxesComputeImpl( - 
ctx, eigen, req, outputs, reduced_shape); - } else if (ord == -2) { - ReduceAxesComputeImpl( - ctx, eigen, req, outputs, reduced_shape); - } } - }); + } else { + if (ord == 2) { + ReduceAxesComputeImpl( + ctx, eigen, req, outputs, reduced_shape); + } else if (ord == -2) { + ReduceAxesComputeImpl( + ctx, eigen, req, outputs, reduced_shape); + } + } +#else + if (param.flag == 2) { // nuclear norm + ReduceAxesRTCComputeImpl(ctx, eigen, req, outputs, reduced_shape, "red::sum{}", false); + } else { + if (ord == 2) { + ReduceAxesRTCComputeImpl(ctx, eigen, req, outputs, reduced_shape, + "red::maximum{}", false, "abs"); + } else if (ord == -2) { + ReduceAxesRTCComputeImpl(ctx, eigen, req, outputs, reduced_shape, + "red::minimum{}", false, "abs"); + } + } +#endif } template @@ -784,8 +809,13 @@ void NumpyNormComputeForward(const nnvm::NodeAttrs& attrs, std::vector flat_outputs({ outputs[0].reshape(TShape(1, 1)) }); - ReduceAxesComputeImplWithReducer( +#if !defined(__CUDACC__) + ReduceAxesComputeImpl( ctx, flat_inputs, req, flat_outputs, TShape(1, 1)); +#else + ReduceAxesRTCComputeImpl( + ctx, flat_inputs, req, flat_outputs, TShape(1, 1), "red::nrm2{}", false, "identity"); +#endif return; } diff --git a/src/operator/numpy/np_broadcast_reduce_op.h b/src/operator/numpy/np_broadcast_reduce_op.h index 3b505b788ae9..93b027b572f1 100644 --- a/src/operator/numpy/np_broadcast_reduce_op.h +++ b/src/operator/numpy/np_broadcast_reduce_op.h @@ -298,7 +298,7 @@ void NumpyReduceAxesCompute(const nnvm::NodeAttrs& attrs, const std::vector& outputs) { using namespace mshadow; if (req[0] == kNullOp) return; - const NumpyReduceAxesParam& param = nnvm::get(attrs.parsed); + const auto& param = nnvm::get(attrs.parsed); if (param.initial.has_value()) { LOG(FATAL) << "initial is not supported yet"; } @@ -839,13 +839,25 @@ void NumpyWeightedAverageComputeImpl(const nnvm::NodeAttrs& attrs, // Compute weighted data TBlob wa = TBlob(temp_data_ptr, data.shape_, xpu::kDevMask); - NP_BROADCAST_REDUCE_OP_BROADCAST(mul)( - attrs, ctx, {data, weights}, {kWriteTo}, {wa}); + TBlob sum_of_wa; + if constexpr (std::is_same::value) { + BinaryBroadcastCompute( + attrs, ctx, {data, weights}, {kWriteTo}, {wa}); + + // Compute sum of weighted data + sum_of_wa = TBlob(temp_sum_ptr, small1, xpu::kDevMask); + ReduceAxesComputeWithWorkspaceImpl( + ctx, {wa}, {kWriteTo}, {sum_of_wa}, workspace, src_shape, dst_shape); + } else { +#if MXNET_USE_CUDA + BinaryBroadcastRTCCompute {"mul"}(attrs, ctx, {data, weights}, {kWriteTo}, {wa}); - // Compute sum of weighted data - TBlob sum_of_wa = TBlob(temp_sum_ptr, small1, xpu::kDevMask); - ReduceAxesComputeWithWorkspaceImpl( - ctx, {wa}, {kWriteTo}, {sum_of_wa}, workspace, src_shape, dst_shape); + // Compute sum of weighted data + sum_of_wa = TBlob(temp_sum_ptr, small1, xpu::kDevMask); + ReduceAxesRTCComputeWithWorkspaceImpl(ctx, {wa}, {kWriteTo}, {sum_of_wa}, "red::sum{}", + false, "identity", workspace, src_shape, dst_shape); +#endif + } if (!back) { const TBlob& avg = outputs[0]; const TBlob& sum_of_weights = outputs[1]; @@ -853,12 +865,22 @@ void NumpyWeightedAverageComputeImpl(const nnvm::NodeAttrs& attrs, BroadcastReduceShapeCompact(weights.shape_, small2, &w_src_shape, &w_dst_shape); // Compute sum of weight TBlob scl = sum_of_weights.reshape(small2); - ReduceAxesComputeWithWorkspaceImpl( - ctx, {weights}, {kWriteTo}, {scl}, workspace, w_src_shape, w_dst_shape); - - // Compute avg and assign output - NP_BROADCAST_REDUCE_OP_BROADCAST(div)( - attrs, ctx, {sum_of_wa, scl}, req, {avg.reshape(small1)}); + 
if constexpr (std::is_same::value) { + ReduceAxesComputeWithWorkspaceImpl( + ctx, {weights}, {kWriteTo}, {scl}, workspace, w_src_shape, w_dst_shape); + // Compute avg and assign output + BinaryBroadcastCompute( + attrs, ctx, {sum_of_wa, scl}, req, {avg.reshape(small1)}); + } else { +#if MXNET_USE_CUDA + ReduceAxesRTCComputeWithWorkspaceImpl(ctx, {weights}, {kWriteTo}, {scl}, "red::sum{}", + false, "identity", workspace, w_src_shape, + w_dst_shape); + // Compute avg and assign output + BinaryBroadcastRTCCompute {"div"}( + attrs, ctx, {sum_of_wa, scl}, req, {avg.reshape(small1)}); +#endif + } } else { // Compute and assign the derivatives of a and weights const TBlob& igrad_a = outputs[0]; @@ -924,8 +946,14 @@ void NumpyWeightedAverageForward(const nnvm::NodeAttrs& attrs, auto ret = outputs[1].FlatTo1D(s); ret = scalar(data.shape_.Size()/small.Size()); // Compute mean - ReduceAxesComputeImpl( - ctx, inputs, req, {outputs[0]}, small); + if constexpr (std::is_same::value) { + ReduceAxesComputeImpl( + ctx, inputs, req, {outputs[0]}, small); + } else { +#if MXNET_USE_CUDA + ReduceAxesRTCComputeImpl(ctx, inputs, req, {outputs[0]}, small, "red::sum{}", true); +#endif + } } else { NumpyWeightedAverageComputeImpl( attrs, ctx, inputs, req, outputs, param.axis); @@ -1010,8 +1038,15 @@ void NumpyMomentsForward(const nnvm::NodeAttrs& attrs, char *workspace_ptr = temp_mem.dptr_ + temp_data_size; Tensor workspace(workspace_ptr, Shape1(workspace_size), s); // Compute mean - ReduceAxesComputeWithWorkspaceImpl( - ctx, inputs, {kWriteTo}, {mean}, workspace, src_shape, dst_shape); + if constexpr (std::is_same::value) { + ReduceAxesComputeWithWorkspaceImpl( + ctx, inputs, {kWriteTo}, {mean}, workspace, src_shape, dst_shape); + } else { +#if MXNET_USE_CUDA + ReduceAxesRTCComputeWithWorkspaceImpl(ctx, inputs, {kWriteTo}, {mean}, "red::sum{}", + true, "identity", workspace, src_shape, dst_shape); +#endif + } // Compute data - mean Shape<6> data_shape, mean_shape; for (int i = 0; i < 6; ++i) { @@ -1022,11 +1057,22 @@ void NumpyMomentsForward(const nnvm::NodeAttrs& attrs, data.dptr(), mean.dptr(), data_shape, mean_shape); Tensor temp_data_tensor(temp_data_ptr, Shape1(data.shape_.Size()), s); TBlob temp_data_blob = TBlob(temp_data_tensor).reshape(data.shape_); - ReduceAxesComputeWithWorkspaceImpl( - ctx, {temp_data_blob}, {req[0]}, {moment}, workspace, src_shape, dst_shape, param.ddof); - if (sqrt) { - Tensor moment_tensor = moment.FlatTo1D(s); - moment_tensor = F(moment_tensor); + if constexpr (std::is_same::value) { + ReduceAxesComputeWithWorkspaceImpl( + ctx, {temp_data_blob}, {req[0]}, {moment}, workspace, src_shape, dst_shape, param.ddof); + if (sqrt && req[0] != kNullOp) { + Tensor moment_tensor = moment.FlatTo1D(s); + moment_tensor = F(moment_tensor); + } + } else { +#if MXNET_USE_CUDA + ReduceAxesRTCComputeWithWorkspaceImpl(ctx, {temp_data_blob}, {req[0]}, {moment}, + "red::sum{}", true, "identity", workspace, + src_shape, dst_shape, param.ddof); + if (sqrt && req[0] != kNullOp) { + UnaryRTCCompute {"sqrt"}({}, ctx, {moment}, {kWriteInplace}, {moment}); + } +#endif } }); }); @@ -1065,6 +1111,7 @@ void NumpyBroadcastToBackward(const nnvm::NodeAttrs& attrs, for (int i = 0; i < igrad_shape.ndim(); ++i) { expanded_igrad_shape[i + ndim_delta] = igrad_shape[i]; } +#if !defined(__CUDACC__) if (NeedSafeAcc(inputs[0].type_flag_, outputs[0].type_flag_)) { ReduceAxesComputeImpl( ctx, inputs, req, {outputs[0].reshape(expanded_igrad_shape)}, expanded_igrad_shape); @@ -1072,6 +1119,10 @@ void NumpyBroadcastToBackward(const 
nnvm::NodeAttrs& attrs, ReduceAxesComputeImpl( ctx, inputs, req, {outputs[0].reshape(expanded_igrad_shape)}, expanded_igrad_shape); } +#else + ReduceAxesRTCComputeImpl(ctx, inputs, req, {outputs[0].reshape(expanded_igrad_shape)}, + expanded_igrad_shape, "red::sum{}", false); +#endif } template diff --git a/src/operator/numpy/np_broadcast_reduce_op_boolean.cu b/src/operator/numpy/np_broadcast_reduce_op_boolean.cu index d3247b743bc5..405ae4b38eb7 100644 --- a/src/operator/numpy/np_broadcast_reduce_op_boolean.cu +++ b/src/operator/numpy/np_broadcast_reduce_op_boolean.cu @@ -29,12 +29,12 @@ namespace mxnet { namespace op { NNVM_REGISTER_OP(_npi_any) -.set_attr("FCompute", NumpyReduceAxesBoolCompute); +.set_attr("FCompute", ReduceAxesRTCCompute + {"NonZero", "red::sum{}", false}); NNVM_REGISTER_OP(_npi_all) -.set_attr("FCompute", NumpyReduceAxesBoolCompute); +.set_attr("FCompute", ReduceAxesRTCCompute + {"NonZero", "red::product{}", false}); } // namespace op } // namespace mxnet diff --git a/src/operator/numpy/np_broadcast_reduce_op_value.cu b/src/operator/numpy/np_broadcast_reduce_op_value.cu index 422097d20181..602057324af2 100644 --- a/src/operator/numpy/np_broadcast_reduce_op_value.cu +++ b/src/operator/numpy/np_broadcast_reduce_op_value.cu @@ -27,25 +27,30 @@ namespace mxnet { namespace op { NNVM_REGISTER_OP(_npi_sum) -.set_attr("FCompute", NumpyReduceAxesCompute); +.set_attr("FCompute", + ReduceAxesRTCCompute{"identity", "red::sum{}", false}); NNVM_REGISTER_OP(_backward_npi_sum) .set_attr("FCompute", NumpyReduceAxesBackwardUseNone); NNVM_REGISTER_OP(_npi_max) -.set_attr("FCompute", NumpyReduceAxesNoDTypeCompute); +.set_attr("FCompute", ReduceAxesRTCCompute + {"identity", "red::maximum{}", false}); NNVM_REGISTER_OP(_backward_npi_max) .set_attr("FCompute", NumpyReduceAxesNoDTypeBackward); NNVM_REGISTER_OP(_npi_min) -.set_attr("FCompute", NumpyReduceAxesNoDTypeCompute); +.set_attr("FCompute", + ReduceAxesRTCCompute{"identity", + "red::minimum{}", false}); NNVM_REGISTER_OP(_backward_npi_min) .set_attr("FCompute", NumpyReduceAxesNoDTypeBackward); NNVM_REGISTER_OP(_npi_prod) -.set_attr("FCompute", NumpyReduceAxesCompute); +.set_attr("FCompute", ReduceAxesRTCCompute{"identity", + "red::product{}", false}); NNVM_REGISTER_OP(_backward_npi_prod) .set_attr("FCompute", NumpyReduceAxesBackwardUseInOut); @@ -57,7 +62,8 @@ NNVM_REGISTER_OP(_backward_np_average) .set_attr("FCompute", NumpyWeightedAverageBackward); NNVM_REGISTER_OP(_npi_mean) -.set_attr("FCompute", NumpyReduceAxesCompute); +.set_attr("FCompute", ReduceAxesRTCCompute{"identity", + "red::sum{}", true}); NNVM_REGISTER_OP(_backward_np_mean) .set_attr("FCompute", NumpyReduceAxesBackwardUseNone); diff --git a/src/operator/numpy/np_constraint_check.h b/src/operator/numpy/np_constraint_check.h index 80beaa3a0bf5..01c54b650616 100644 --- a/src/operator/numpy/np_constraint_check.h +++ b/src/operator/numpy/np_constraint_check.h @@ -56,9 +56,14 @@ void ConstraintCheckForward(const nnvm::NodeAttrs& attrs, const OpContext& ctx, CHECK_EQ(outputs.size(), 1U); const ConstraintCheckParam& param = nnvm::get(attrs.parsed); +#if !defined(__CUDACC__) ReduceAxesComputeImpl(ctx, inputs, req, outputs, outputs[0].shape_); +#else + ReduceAxesRTCComputeImpl(ctx, inputs, req, outputs, + outputs[0].shape_, "red::product{}"); +#endif std::string msg = param.msg; bool red_output = true; GetReduceOutput(ctx.get_stream(), outputs[0], &red_output); diff --git a/src/operator/numpy/np_cross-inl.h b/src/operator/numpy/np_cross-inl.h index ab64564dc85d..3543bac302a5 100644 
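The GPU registrations and compute calls above swap compile-time template parameters for runtime strings: an element-wise map ("identity", "NonZero", "abs"), a reducer ("red::sum{}", "red::maximum{}", "red::product{}", ...) and a normalize flag, so that _npi_mean becomes a sum with normalize=true and _npi_any becomes a NonZero map fed into a sum. The following toy program is only a sketch of that string-keyed selection idea in plain C++; it is not MXNet's ReduceAxesRTCCompute and does no runtime compilation (function and variable names are made up).

#include <functional>
#include <iostream>
#include <map>
#include <string>
#include <vector>

// Pick the map op and the reducer by name, as the RTC registrations above do
// with their {"<map>", "<reducer>", <normalize>} arguments.
double reduce_by_name(const std::vector<double>& x, const std::string& map_op,
                      const std::string& reducer, bool normalize) {
  static const std::map<std::string, std::function<double(double)>> maps = {
      {"identity", [](double v) { return v; }},
      {"NonZero",  [](double v) { return v != 0.0 ? 1.0 : 0.0; }},
  };
  static const std::map<std::string,
                        std::function<double(double, double)>> reducers = {
      {"red::sum{}",     [](double a, double b) { return a + b; }},
      {"red::product{}", [](double a, double b) { return a * b; }},
      {"red::maximum{}", [](double a, double b) { return a > b ? a : b; }},
  };
  const auto& m = maps.at(map_op);
  const auto& r = reducers.at(reducer);
  double acc = m(x.at(0));
  for (size_t i = 1; i < x.size(); ++i) acc = r(acc, m(x[i]));
  // normalize=true divides by the number of reduced elements (sum -> mean).
  return normalize ? acc / static_cast<double>(x.size()) : acc;
}

int main() {
  const std::vector<double> x = {1.0, 2.0, 3.0, 4.0};
  std::cout << reduce_by_name(x, "identity", "red::sum{}", true) << '\n';   // 2.5, a mean
  std::cout << reduce_by_name(x, "NonZero", "red::sum{}", false) << '\n';   // 4, count of non-zeros
}

In the real GPU path the strings name device-side functors and the matching kernel is generated and compiled on demand, rather than being looked up in a host-side table.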
--- a/src/operator/numpy/np_cross-inl.h +++ b/src/operator/numpy/np_cross-inl.h @@ -691,8 +691,13 @@ struct ReduceImplWrap { Stream *s = ctx.get_stream(); // Reduce work_in to work_out. SUM_NDIM_SWITCH(work_out.ndim(), NDim, { +#if !defined(__CUDACC__) op::broadcast::Reduce( s, work_out, kWriteTo, workspace_tensor, work_in); +#else + op::broadcast::RTCReduce(ctx, work_out, kWriteTo, workspace_tensor, work_in, + "red::sum{}", NDim, "identity"); +#endif }); // Copy work_out to out_data. MXNET_ASSIGN_REQ_SWITCH(out_req, req_type, { diff --git a/src/operator/numpy/np_kron-inl.h b/src/operator/numpy/np_kron-inl.h index 0d72921691a9..affd221f574d 100644 --- a/src/operator/numpy/np_kron-inl.h +++ b/src/operator/numpy/np_kron-inl.h @@ -230,8 +230,15 @@ void KronOpBackwardImpl(const OpContext& ctx, ctx.requested[0].get_space_typed(Shape1(ograd.shape_.Size()), s); ASSIGN_DISPATCH(workspace, kWriteTo, tensor_ * ograd_); - ReduceAxesComputeImpl( - ctx, {TBlob(workspace)}, {scalar_req}, {TBlob(scalar_grad_)}, scalar_grad_.shape_); + if constexpr (std::is_same::value) { + ReduceAxesComputeImpl( + ctx, {TBlob(workspace)}, {scalar_req}, {TBlob(scalar_grad_)}, scalar_grad_.shape_); + } else { +#if MXNET_USE_CUDA + ReduceAxesRTCComputeImpl(ctx, {TBlob(workspace)}, {scalar_req}, {TBlob(scalar_grad_)}, + scalar_grad_.shape_, "red::sum{}", false); +#endif + } } else { MXNET_NDIM_SWITCH(oshape.ndim(), ndim, { Shape ashape_ = oshape.get(); diff --git a/src/operator/numpy/np_tensordot_op-inl.h b/src/operator/numpy/np_tensordot_op-inl.h index bbdfab4cebac..aad1a8f3d530 100644 --- a/src/operator/numpy/np_tensordot_op-inl.h +++ b/src/operator/numpy/np_tensordot_op-inl.h @@ -424,8 +424,16 @@ void TensordotBackwardImpl(const Tuple& a_axes_summed, workspace.stream_); ASSIGN_DISPATCH(dtypespace, kWriteTo, tensor_ * out_grad_); - ReduceAxesComputeImpl( - ctx, {TBlob(dtypespace)}, {scalar_req}, {TBlob(scalar_grad_)}, scalar_grad_.shape_); + if constexpr (std::is_same::value) { + ReduceAxesComputeImpl( + ctx, {TBlob(dtypespace)}, {scalar_req}, {TBlob(scalar_grad_)}, scalar_grad_.shape_); + } else { +#if MXNET_USE_CUDA + ReduceAxesRTCComputeImpl(ctx, {TBlob(dtypespace)}, {scalar_req}, + {TBlob(scalar_grad_)}, scalar_grad_.shape_, + "red::sum{}"); +#endif + } } else { // Two tensors of at least 1 dimensions. Tuple a_axes_remained; @@ -734,8 +742,15 @@ void TensordotIntAxesBackwardImpl(const int axes, ctx.requested[0].get_space_typed(Shape1(out_grad.shape_.Size()), s); ASSIGN_DISPATCH(workspace, kWriteTo, tensor_ * out_grad_); - ReduceAxesComputeImpl( - ctx, {TBlob(workspace)}, {scalar_req}, {TBlob(scalar_grad_)}, scalar_grad_.shape_); + if constexpr (std::is_same::value) { + ReduceAxesComputeImpl( + ctx, {TBlob(workspace)}, {scalar_req}, {TBlob(scalar_grad_)}, scalar_grad_.shape_); + } else { +#if MXNET_USE_CUDA + ReduceAxesRTCComputeImpl(ctx, {TBlob(workspace)}, {scalar_req}, {TBlob(scalar_grad_)}, + scalar_grad_.shape_, "red::sum{}"); +#endif + } } else { // Two tensors of at least 1 dimensions. 
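Several backward implementations above (np_kron, np_tensordot, np_where) now branch at compile time on whether the device template parameter is the CPU type, so only CPU builds instantiate the templated reducer while GPU builds call the string-based RTC reducer. A minimal stand-alone sketch of that dispatch pattern, assuming C++17 and using toy cpu/gpu tag structs in place of mshadow's device types:

#include <iostream>
#include <string>
#include <type_traits>

struct cpu {};  // stand-ins for the real device tag types
struct gpu {};

template <typename xpu>
void reduce_sum(const std::string& what) {
  if constexpr (std::is_same<xpu, cpu>::value) {
    // CPU: the reducer is a template parameter resolved at build time.
    std::cout << "CPU path: templated sum over " << what << '\n';
  } else {
    // GPU: the reducer is named by a string such as "red::sum{}" and the
    // kernel is produced by the runtime-compilation (RTC) machinery.
    std::cout << "GPU path: RTC reducer \"red::sum{}\" over " << what << '\n';
  }
}

int main() {
  reduce_sum<cpu>("scalar gradient");
  reduce_sum<gpu>("scalar gradient");
}

Because the discarded branch of an if constexpr inside a template is never instantiated, each build only has to compile the path it actually uses.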
Tuple a_axes_summed; diff --git a/src/operator/numpy/np_where_op-inl.h b/src/operator/numpy/np_where_op-inl.h index 10ec081b2a8f..3f60e93d4121 100644 --- a/src/operator/numpy/np_where_op-inl.h +++ b/src/operator/numpy/np_where_op-inl.h @@ -245,12 +245,19 @@ inline void NumpyWhereOpBackward(const nnvm::NodeAttrs& attrs, mxnet_op::Kernel, xpu>::Launch( s, ograd.Size(), req[0], cstride, oshape, cond.dptr(), ograd.dptr(), workspace.dptr_); - if (NeedSafeAcc(dx.type_flag_, dx.type_flag_)) { - ReduceAxesComputeImpl(ctx, {TBlob(workspace)}, {req[0]}, - {dx.reshape(expanded_lshape)}, expanded_lshape); + if constexpr (std::is_same::value) { + if (NeedSafeAcc(dx.type_flag_, dx.type_flag_)) { + ReduceAxesComputeImpl(ctx, {TBlob(workspace)}, {req[0]}, + {dx.reshape(expanded_lshape)}, expanded_lshape); + } else { + ReduceAxesComputeImpl(ctx, {TBlob(workspace)}, {req[0]}, + {dx.reshape(expanded_lshape)}, expanded_lshape); + } } else { - ReduceAxesComputeImpl(ctx, {TBlob(workspace)}, {req[0]}, - {dx.reshape(expanded_lshape)}, expanded_lshape); +#if MXNET_USE_CUDA + ReduceAxesRTCComputeImpl(ctx, {TBlob(workspace)}, {req[0]}, + {dx.reshape(expanded_lshape)}, expanded_lshape, "red::sum{}"); +#endif // MXNET_USE_CUDA } } // process right output @@ -267,12 +274,19 @@ inline void NumpyWhereOpBackward(const nnvm::NodeAttrs& attrs, mxnet_op::Kernel, xpu>::Launch( s, ograd.Size(), req[1], cstride, oshape, cond.dptr(), ograd.dptr(), workspace.dptr_); - if (NeedSafeAcc(dy.type_flag_, dy.type_flag_)) { - ReduceAxesComputeImpl(ctx, {TBlob(workspace)}, {req[1]}, - {dy.reshape(expanded_rshape)}, expanded_rshape); + if constexpr (std::is_same::value) { + if (NeedSafeAcc(dy.type_flag_, dy.type_flag_)) { + ReduceAxesComputeImpl(ctx, {TBlob(workspace)}, {req[1]}, + {dy.reshape(expanded_rshape)}, expanded_rshape); + } else { + ReduceAxesComputeImpl(ctx, {TBlob(workspace)}, {req[1]}, + {dy.reshape(expanded_rshape)}, expanded_rshape); + } } else { - ReduceAxesComputeImpl(ctx, {TBlob(workspace)}, {req[1]}, - {dy.reshape(expanded_rshape)}, expanded_rshape); +#if MXNET_USE_CUDA + ReduceAxesRTCComputeImpl(ctx, {TBlob(workspace)}, {req[1]}, + {dy.reshape(expanded_rshape)}, expanded_rshape, "red::sum{}"); +#endif // MXNET_USE_CUDA } } }); @@ -383,12 +397,19 @@ inline void NumpyWhereScalarOpBackward(const nnvm::NodeAttrs& attrs, mxnet_op::Kernel, xpu>::Launch( s, ograd.Size(), req[0], cstride, oshape, cond.dptr(), ograd.dptr(), workspace.dptr_); - if (NeedSafeAcc(dx.type_flag_, dx.type_flag_)) { - ReduceAxesComputeImpl(ctx, {TBlob(workspace)}, {req[0]}, - {dx.reshape(expanded_lshape)}, expanded_lshape); + if constexpr (std::is_same::value) { + if (NeedSafeAcc(dx.type_flag_, dx.type_flag_)) { + ReduceAxesComputeImpl(ctx, {TBlob(workspace)}, {req[0]}, + {dx.reshape(expanded_lshape)}, expanded_lshape); + } else { + ReduceAxesComputeImpl(ctx, {TBlob(workspace)}, {req[0]}, + {dx.reshape(expanded_lshape)}, expanded_lshape); + } } else { - ReduceAxesComputeImpl(ctx, {TBlob(workspace)}, {req[0]}, - {dx.reshape(expanded_lshape)}, expanded_lshape); +#if MXNET_USE_CUDA + ReduceAxesRTCComputeImpl(ctx, {TBlob(workspace)}, {req[0]}, + {dx.reshape(expanded_lshape)}, expanded_lshape, "red::sum{}"); +#endif // MXNET_USE_CUDA } } }); diff --git a/src/operator/numpy/random/dist_common.h b/src/operator/numpy/random/dist_common.h index ab8afe95f0b1..28ef098c8f80 100644 --- a/src/operator/numpy/random/dist_common.h +++ b/src/operator/numpy/random/dist_common.h @@ -277,6 +277,86 @@ inline bool TwoparamsDistOpConcatShape(const nnvm::NodeAttrs &attrs, 
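For reference, the np_where backward hunks above first mask the incoming gradient with the condition and then sum-reduce the masked values over the axes that were broadcast, via the templated ReduceAxesComputeImpl on CPU or ReduceAxesRTCComputeImpl with "red::sum{}" on GPU. A tiny self-contained illustration of that mask-then-reduce computation with made-up 2x3 data (plain C++, not MXNet code):

#include <iostream>
#include <vector>

int main() {
  const int rows = 2, cols = 3;
  const std::vector<int>    cond  = {1, 0, 1,  0, 1, 1};        // condition, shape (2, 3)
  const std::vector<double> ograd = {.1, .2, .3, .4, .5, .6};   // output gradient, shape (2, 3)
  std::vector<double> dx(cols, 0.0);  // gradient of x, which was broadcast from shape (3,)
  for (int r = 0; r < rows; ++r)
    for (int c = 0; c < cols; ++c)
      dx[c] += cond[r * cols + c] ? ograd[r * cols + c] : 0.0;  // mask, then reduce over rows
  for (double v : dx) std::cout << v << ' ';  // prints 0.1 0.5 0.9
  std::cout << '\n';
}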
return true; } +template +inline void CommonReparamBackwardImpl(const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs, + const mxnet::TShape& new_lshape, + const mxnet::TShape& new_rshape, + const mxnet::TShape& new_oshape) { + using namespace mshadow; + using namespace mshadow::expr; + using namespace broadcast; + Stream *s = ctx.get_stream(); + const TBlob lgrad = outputs[0].reshape(new_lshape); + const TBlob rgrad = outputs[1].reshape(new_rshape); + const TBlob ograd = inputs[0].reshape(new_oshape); + // Mean + const TBlob lhs = inputs[2].reshape(new_lshape); + // Scale + const TBlob rhs = inputs[3].reshape(new_rshape); + const TBlob samples = inputs[4].reshape(new_oshape); + const TBlob noise = inputs[5].reshape(new_oshape); + size_t workspace_size_l = ReduceWorkspaceSize( + s, lgrad.shape_, req[0], ograd.shape_, lhs.shape_, rhs.shape_, sizeof(DType)); + size_t workspace_size_r = ReduceWorkspaceSize( + s, rgrad.shape_, req[1], ograd.shape_, lhs.shape_, rhs.shape_, sizeof(DType)); + size_t workspace_size = std::max(workspace_size_l, workspace_size_r); + Tensor workspace = + ctx.requested[0].get_space_typed(Shape1(workspace_size), s); +#if !defined(__CUDACC__) + Reduce( + s, lgrad, req[0], workspace, ograd); + Reduce( + s, rgrad, req[1], workspace, ograd, noise, rhs); +#else + RTCReduce(ctx, lgrad, req[0], workspace, ograd, "red::sum{}", ndim, "identity"); + RTCReduce(ctx, rgrad, req[1], workspace, ograd, noise, rhs, "red::sum{}", ndim, "mul", "left"); +#endif +} + +template +inline void CommonScalarReparamBackwardImpl(const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs, + const mxnet::TShape& new_ishape, + const mxnet::TShape& new_oshape, + const bool loc_is_tensor = false) { + using namespace mshadow; + using namespace mshadow::expr; + using namespace broadcast; + Stream *s = ctx.get_stream(); + const TBlob igrad = outputs[0].reshape(new_ishape); + // inputs: [grad_from_samples, grad_from_noise(invisible), input_tensor, + // samples, noise] + const TBlob ograd = inputs[0].reshape(new_oshape); + const TBlob itensor = inputs[2].reshape(new_ishape); + const TBlob samples = inputs[3].reshape(new_oshape); + const TBlob noise = inputs[4].reshape(new_oshape); + size_t workspace_size = + ReduceWorkspaceSize(s, igrad.shape_, req[0], ograd.shape_, sizeof(DType)); + Tensor workspace = + ctx.requested[0].get_space_typed(Shape1(workspace_size), s); +#if !defined(__CUDACC__) + if (loc_is_tensor) { + Reduce(s, igrad, req[0], + workspace, ograd); + } else { + Reduce( + s, igrad, req[0], workspace, ograd, noise, noise); + } +#else + if (loc_is_tensor) { + RTCReduce(ctx, igrad, req[0], workspace, ograd, "red::sum{}", ndim, "identity"); + } else { + RTCReduce(ctx, igrad, req[0], workspace, ograd, noise, noise, "red::sum{}", + ndim, "mul", "left"); + } +#endif +} + } // namespace op } // namespace mxnet diff --git a/src/operator/numpy/random/np_exponential_op.h b/src/operator/numpy/random/np_exponential_op.h index 374b3b428eba..ec4b57695f2e 100644 --- a/src/operator/numpy/random/np_exponential_op.h +++ b/src/operator/numpy/random/np_exponential_op.h @@ -174,8 +174,13 @@ inline void ExponentialReparamBackwardImpl(const OpContext& ctx, ReduceWorkspaceSize(s, igrad.shape_, req[0], ograd.shape_, sizeof(DType)); Tensor workspace = ctx.requested[0].get_space_typed(Shape1(workspace_size), s); +#if !defined(__CUDACC__) Reduce( s, igrad, req[0], workspace, ograd, noise, noise); +#else + RTCReduce(ctx, igrad, 
req[0], workspace, ograd, noise, noise, + "red::sum{}", ndim, "mul", "left"); +#endif } template diff --git a/src/operator/numpy/random/np_gamma_op.h b/src/operator/numpy/random/np_gamma_op.h index a0f3299f4d84..29fe493f62a6 100644 --- a/src/operator/numpy/random/np_gamma_op.h +++ b/src/operator/numpy/random/np_gamma_op.h @@ -426,8 +426,13 @@ inline void GammaReparamBackwardImpl(const OpContext& ctx, s, samples.Size(), samples.dptr(), samples.dptr(), DType(scale)); Tensor workspace = ctx.requested[0].get_space_typed(Shape1(workspace_size), s); +#if !defined(__CUDACC__) Reduce( s, igrad, req[0], workspace, ograd, alpha, samples); +#else + RTCReduce(ctx, igrad, req[0], workspace, ograd, alpha, samples, "red::sum{}", ndim, + "mul", "gamma_implicit_grad"); +#endif Kernel, xpu>::Launch( s, igrad.Size(), igrad.dptr(), igrad.dptr(), DType(scale)); // Convert samples back, otherwise the output would be corrupted. diff --git a/src/operator/numpy/random/np_location_scale_op.h b/src/operator/numpy/random/np_location_scale_op.h index 73403f37f1f0..0179a572bf3f 100644 --- a/src/operator/numpy/random/np_location_scale_op.h +++ b/src/operator/numpy/random/np_location_scale_op.h @@ -275,72 +275,6 @@ void NumpyLocationScaleForward(const nnvm::NodeAttrs &attrs, } } -template -inline void LocationScaleReparamBackwardImpl(const OpContext& ctx, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs, - const mxnet::TShape& new_lshape, - const mxnet::TShape& new_rshape, - const mxnet::TShape& new_oshape) { - using namespace mshadow; - using namespace mshadow::expr; - using namespace broadcast; - Stream *s = ctx.get_stream(); - const TBlob lgrad = outputs[0].reshape(new_lshape); - const TBlob rgrad = outputs[1].reshape(new_rshape); - const TBlob ograd = inputs[0].reshape(new_oshape); - // Mean - const TBlob lhs = inputs[2].reshape(new_lshape); - // Scale - const TBlob rhs = inputs[3].reshape(new_rshape); - const TBlob samples = inputs[4].reshape(new_oshape); - const TBlob noise = inputs[5].reshape(new_oshape); - size_t workspace_size_l = ReduceWorkspaceSize( - s, lgrad.shape_, req[0], ograd.shape_, lhs.shape_, rhs.shape_, sizeof(DType)); - size_t workspace_size_r = ReduceWorkspaceSize( - s, rgrad.shape_, req[1], ograd.shape_, lhs.shape_, rhs.shape_, sizeof(DType)); - size_t workspace_size = std::max(workspace_size_l, workspace_size_r); - Tensor workspace = - ctx.requested[0].get_space_typed(Shape1(workspace_size), s); - Reduce( - s, lgrad, req[0], workspace, ograd); - Reduce( - s, rgrad, req[1], workspace, ograd, noise, rhs); -} - -template -inline void ScalarLocationScaleReparamBackwardImpl(const OpContext& ctx, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs, - const mxnet::TShape& new_ishape, - const mxnet::TShape& new_oshape, - const bool loc_is_tensor) { - using namespace mshadow; - using namespace mshadow::expr; - using namespace broadcast; - Stream *s = ctx.get_stream(); - const TBlob igrad = outputs[0].reshape(new_ishape); - // inputs: [grad_from_samples, grad_from_noise(invisible), input_tensor, - // samples, noise] - const TBlob ograd = inputs[0].reshape(new_oshape); - const TBlob itensor = inputs[2].reshape(new_ishape); - const TBlob samples = inputs[3].reshape(new_oshape); - const TBlob noise = inputs[4].reshape(new_oshape); - size_t workspace_size = - ReduceWorkspaceSize(s, igrad.shape_, req[0], ograd.shape_, sizeof(DType)); - Tensor workspace = - ctx.requested[0].get_space_typed(Shape1(workspace_size), s); - if (loc_is_tensor) { - 
Reduce(s, igrad, req[0], - workspace, ograd); - } else { - Reduce( - s, igrad, req[0], workspace, ograd, noise, noise); - } -} - // Allow logistic and gumbel sampling to be differentiable, // using reparameterization trick described in: // Auto-encoding variational bayes. @@ -359,7 +293,7 @@ void LocationScaleReparamBackward(const nnvm::NodeAttrs& attrs, if (outputs.size() == 0U) { return; } - const NumpyLocationScaleParam ¶m = nnvm::get(attrs.parsed); + const auto ¶m = nnvm::get(attrs.parsed); // [tensor tensor] case if (inputs.size() == 6U) { mxnet::TShape new_lshape, new_rshape, new_oshape; @@ -367,7 +301,7 @@ void LocationScaleReparamBackward(const nnvm::NodeAttrs& attrs, &new_lshape, &new_rshape, &new_oshape); MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, { BROADCAST_NDIM_SWITCH(ndim, NDim, { - LocationScaleReparamBackwardImpl( + CommonReparamBackwardImpl( ctx, inputs, req, outputs, new_lshape, new_rshape, new_oshape); }); }); @@ -380,7 +314,7 @@ void LocationScaleReparamBackward(const nnvm::NodeAttrs& attrs, bool loc_is_tensor = !param.loc.has_value(); MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, { BROADCAST_NDIM_SWITCH(ndim, NDim, { - ScalarLocationScaleReparamBackwardImpl( + CommonScalarReparamBackwardImpl( ctx, inputs, req, outputs, new_ishape, new_oshape, loc_is_tensor); }); }); diff --git a/src/operator/numpy/random/np_normal_op.h b/src/operator/numpy/random/np_normal_op.h index e43d98de0168..06b5bfaabf05 100644 --- a/src/operator/numpy/random/np_normal_op.h +++ b/src/operator/numpy/random/np_normal_op.h @@ -161,7 +161,7 @@ void NumpyNormalForward(const nnvm::NodeAttrs &attrs, const std::vector &outputs) { using namespace mshadow; using namespace mxnet_op; - const NumpyNormalParam ¶m = nnvm::get(attrs.parsed); + const auto ¶m = nnvm::get(attrs.parsed); Stream *s = ctx.get_stream(); // Generate base random number. 
Random *prnd = ctx.requested[0].get_random(s); @@ -240,72 +240,6 @@ void NumpyNormalForward(const nnvm::NodeAttrs &attrs, } } -template -inline void NormalReparamBackwardImpl(const OpContext& ctx, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs, - const mxnet::TShape& new_lshape, - const mxnet::TShape& new_rshape, - const mxnet::TShape& new_oshape) { - using namespace mshadow; - using namespace mshadow::expr; - using namespace broadcast; - Stream *s = ctx.get_stream(); - const TBlob lgrad = outputs[0].reshape(new_lshape); - const TBlob rgrad = outputs[1].reshape(new_rshape); - const TBlob ograd = inputs[0].reshape(new_oshape); - // Mean - const TBlob lhs = inputs[2].reshape(new_lshape); - // Variance - const TBlob rhs = inputs[3].reshape(new_rshape); - const TBlob samples = inputs[4].reshape(new_oshape); - const TBlob noise = inputs[5].reshape(new_oshape); - size_t workspace_size_l = ReduceWorkspaceSize( - s, lgrad.shape_, req[0], ograd.shape_, lhs.shape_, rhs.shape_, sizeof(DType)); - size_t workspace_size_r = ReduceWorkspaceSize( - s, rgrad.shape_, req[1], ograd.shape_, lhs.shape_, rhs.shape_, sizeof(DType)); - size_t workspace_size = std::max(workspace_size_l, workspace_size_r); - Tensor workspace = - ctx.requested[0].get_space_typed(Shape1(workspace_size), s); - Reduce(s, - lgrad, req[0], workspace, ograd); - Reduce( - s, rgrad, req[1], workspace, ograd, noise, rhs); -} - -template -inline void ScalarNormalReparamBackwardImpl(const OpContext& ctx, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs, - const mxnet::TShape& new_ishape, - const mxnet::TShape& new_oshape, - const bool loc_is_tensor) { - using namespace mshadow; - using namespace mshadow::expr; - using namespace broadcast; - Stream *s = ctx.get_stream(); - const TBlob igrad = outputs[0].reshape(new_ishape); - // inputs: [grad_from_samples, grad_from_noise(invisible), input_tensor, - // samples, noise] - const TBlob ograd = inputs[0].reshape(new_oshape); - const TBlob itensor = inputs[2].reshape(new_ishape); - const TBlob samples = inputs[3].reshape(new_oshape); - const TBlob noise = inputs[4].reshape(new_oshape); - size_t workspace_size = - ReduceWorkspaceSize(s, igrad.shape_, req[0], ograd.shape_, sizeof(DType)); - Tensor workspace = - ctx.requested[0].get_space_typed(Shape1(workspace_size), s); - if (loc_is_tensor) { - Reduce(s, igrad, req[0], - workspace, ograd); - } else { - Reduce( - s, igrad, req[0], workspace, ograd, noise, noise); - } -} - // Allow normal sampling to be differentiable, // using reparameterization trick described in: // Auto-encoding variational bayes. 
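The per-distribution backward helpers removed above all compute the same reparameterization gradients, which is why the call sites below can switch to the shared CommonReparamBackwardImpl / CommonScalarReparamBackwardImpl: for a location-scale sample s = loc + scale * eps, ds/dloc = 1 and ds/dscale = eps, so the loc gradient is the output gradient sum-reduced as-is and the scale gradient is the output gradient multiplied by the stored noise before the same reduction (the "red::sum{}" / "mul" / "left" combination on the RTC path). A small self-contained numeric sketch with made-up values (plain C++, not MXNet code):

#include <iostream>
#include <vector>

int main() {
  // Gradient flowing into the samples, and the noise that generated them.
  const std::vector<double> ograd = {0.1, -0.2, 0.3};
  const std::vector<double> eps   = {1.5, -0.7, 0.4};
  double dloc = 0.0, dscale = 0.0;
  for (size_t i = 0; i < ograd.size(); ++i) {
    dloc   += ograd[i];           // reduce(ograd)       -> gradient of loc
    dscale += ograd[i] * eps[i];  // reduce(ograd * eps) -> gradient of scale
  }
  std::cout << "dloc = " << dloc << ", dscale = " << dscale << '\n';  // roughly 0.2 and 0.41
}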
@@ -324,7 +258,7 @@ void NormalReparamBackward(const nnvm::NodeAttrs& attrs, if (outputs.size() == 0U) { return; } - const NumpyNormalParam ¶m = nnvm::get(attrs.parsed); + const auto ¶m = nnvm::get(attrs.parsed); // [tensor tensor] case if (inputs.size() == 6U) { mxnet::TShape new_lshape, new_rshape, new_oshape; @@ -332,7 +266,7 @@ void NormalReparamBackward(const nnvm::NodeAttrs& attrs, &new_lshape, &new_rshape, &new_oshape); MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, { BROADCAST_NDIM_SWITCH(ndim, NDim, { - NormalReparamBackwardImpl( + CommonReparamBackwardImpl( ctx, inputs, req, outputs, new_lshape, new_rshape, new_oshape); }); }); @@ -345,7 +279,7 @@ void NormalReparamBackward(const nnvm::NodeAttrs& attrs, bool loc_is_tensor = !param.loc.has_value(); MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, { BROADCAST_NDIM_SWITCH(ndim, NDim, { - ScalarNormalReparamBackwardImpl( + CommonScalarReparamBackwardImpl( ctx, inputs, req, outputs, new_ishape, new_oshape, loc_is_tensor); }); }); diff --git a/src/operator/numpy/random/np_pareto_op.h b/src/operator/numpy/random/np_pareto_op.h index 5e5d26aae4d2..16731c126324 100644 --- a/src/operator/numpy/random/np_pareto_op.h +++ b/src/operator/numpy/random/np_pareto_op.h @@ -155,32 +155,6 @@ void NumpyParetoForward(const nnvm::NodeAttrs &attrs, } } -template -inline void ScalarParetoReparamBackwardImpl(const OpContext& ctx, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs, - const mxnet::TShape& new_ishape, - const mxnet::TShape& new_oshape) { - using namespace mshadow; - using namespace mshadow::expr; - using namespace broadcast; - Stream *s = ctx.get_stream(); - const TBlob igrad = outputs[0].reshape(new_ishape); - // inputs: [grad_from_samples, grad_from_noise(invisible), input_tensor, - // samples, noise] - const TBlob ograd = inputs[0].reshape(new_oshape); - const TBlob itensor = inputs[2].reshape(new_ishape); - const TBlob samples = inputs[3].reshape(new_oshape); - const TBlob noise = inputs[4].reshape(new_oshape); - size_t workspace_size = - ReduceWorkspaceSize(s, igrad.shape_, req[0], ograd.shape_, sizeof(DType)); - Tensor workspace = - ctx.requested[0].get_space_typed(Shape1(workspace_size), s); - Reduce( - s, igrad, req[0], workspace, ograd, noise, noise); - } - template void ParetoReparamBackward(const nnvm::NodeAttrs& attrs, const OpContext& ctx, @@ -202,7 +176,7 @@ if (inputs.size() == 5U) { &new_ishape, &new_ishape, &new_oshape); MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, { BROADCAST_NDIM_SWITCH(ndim, NDim, { - ScalarParetoReparamBackwardImpl( + CommonScalarReparamBackwardImpl( ctx, inputs, reqs, outputs, new_ishape, new_oshape); }); }); diff --git a/src/operator/numpy/random/np_rayleigh_op.h b/src/operator/numpy/random/np_rayleigh_op.h index 0f940e511a32..75c4784a515e 100644 --- a/src/operator/numpy/random/np_rayleigh_op.h +++ b/src/operator/numpy/random/np_rayleigh_op.h @@ -153,32 +153,6 @@ void NumpyRayleighForward(const nnvm::NodeAttrs &attrs, } } -template -inline void ScalarRayleighReparamBackwardImpl(const OpContext& ctx, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs, - const mxnet::TShape& new_ishape, - const mxnet::TShape& new_oshape) { - using namespace mshadow; - using namespace mshadow::expr; - using namespace broadcast; - Stream *s = ctx.get_stream(); - const TBlob igrad = outputs[0].reshape(new_ishape); - // inputs: [grad_from_samples, grad_from_noise(invisible), input_tensor, - // samples, noise] - const TBlob ograd = 
inputs[0].reshape(new_oshape); - const TBlob itensor = inputs[2].reshape(new_ishape); - const TBlob samples = inputs[3].reshape(new_oshape); - const TBlob noise = inputs[4].reshape(new_oshape); - size_t workspace_size = - ReduceWorkspaceSize(s, igrad.shape_, req[0], ograd.shape_, sizeof(DType)); - Tensor workspace = - ctx.requested[0].get_space_typed(Shape1(workspace_size), s); - Reduce( - s, igrad, req[0], workspace, ograd, noise, noise); -} - template void RayleighReparamBackward(const nnvm::NodeAttrs& attrs, const OpContext& ctx, @@ -200,7 +174,7 @@ void RayleighReparamBackward(const nnvm::NodeAttrs& attrs, &new_ishape, &new_ishape, &new_oshape); MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, { BROADCAST_NDIM_SWITCH(ndim, NDim, { - ScalarRayleighReparamBackwardImpl( + CommonScalarReparamBackwardImpl( ctx, inputs, req, outputs, new_ishape, new_oshape); }); }); diff --git a/src/operator/numpy/random/np_weibull_op.h b/src/operator/numpy/random/np_weibull_op.h index 970dc859b97b..a7d6d5d2c405 100644 --- a/src/operator/numpy/random/np_weibull_op.h +++ b/src/operator/numpy/random/np_weibull_op.h @@ -155,32 +155,6 @@ void NumpyWeibullForward(const nnvm::NodeAttrs &attrs, } } -template -inline void ScalarWeibullReparamBackwardImpl(const OpContext& ctx, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs, - const mxnet::TShape& new_ishape, - const mxnet::TShape& new_oshape) { - using namespace mshadow; - using namespace mshadow::expr; - using namespace broadcast; - Stream *s = ctx.get_stream(); - const TBlob igrad = outputs[0].reshape(new_ishape); - // inputs: [grad_from_samples, grad_from_noise(invisible), input_tensor, - // samples, noise] - const TBlob ograd = inputs[0].reshape(new_oshape); - const TBlob itensor = inputs[2].reshape(new_ishape); - const TBlob samples = inputs[3].reshape(new_oshape); - const TBlob noise = inputs[4].reshape(new_oshape); - size_t workspace_size = - ReduceWorkspaceSize(s, igrad.shape_, req[0], ograd.shape_, sizeof(DType)); - Tensor workspace = - ctx.requested[0].get_space_typed(Shape1(workspace_size), s); - Reduce( - s, igrad, req[0], workspace, ograd, noise, noise); - } - template void WeibullReparamBackward(const nnvm::NodeAttrs& attrs, const OpContext& ctx, @@ -202,7 +176,7 @@ if (inputs.size() == 5U) { &new_ishape, &new_ishape, &new_oshape); MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, { BROADCAST_NDIM_SWITCH(ndim, NDim, { - ScalarWeibullReparamBackwardImpl( + CommonScalarReparamBackwardImpl( ctx, inputs, reqs, outputs, new_ishape, new_oshape); }); }); diff --git a/src/operator/quantization/quantize_v2-inl.h b/src/operator/quantization/quantize_v2-inl.h index d8814cc6cb20..cfbdb7f8e0ab 100644 --- a/src/operator/quantization/quantize_v2-inl.h +++ b/src/operator/quantization/quantize_v2-inl.h @@ -205,10 +205,17 @@ class QuantizeV2Operator { dev_id); Tensor workspace(temp_space.dptr_ + 2 * actual_float_size, Shape1(temp_reduce_size), s); +#if !defined(__CUDACC__) broadcast::Reduce( s, in_min_t.reshape(dst_shape), kWriteTo, workspace, inputs[0].reshape(src_shape)); broadcast::Reduce( s, in_max_t.reshape(dst_shape), kWriteTo, workspace, inputs[0].reshape(src_shape)); +#else + broadcast::RTCReduce(ctx, in_min_t.reshape(dst_shape), kWriteTo, workspace, + inputs[0].reshape(src_shape), "red::minimum{}", 2, "identity"); + broadcast::RTCReduce(ctx, in_max_t.reshape(dst_shape), kWriteTo, workspace, + inputs[0].reshape(src_shape), "red::maximum{}", 2, "identity"); +#endif if (out_type == mshadow::kUint8) { Kernel::Launch( 
s, outputs[0].Size(), outputs[0].dptr(), outputs[1].dptr(), diff --git a/src/operator/quantization/requantize-inl.h b/src/operator/quantization/requantize-inl.h index 2bdc3a712961..56686708dba4 100644 --- a/src/operator/quantization/requantize-inl.h +++ b/src/operator/quantization/requantize-inl.h @@ -148,6 +148,7 @@ void RequantizeForward(const nnvm::NodeAttrs& attrs, temp_space.dptr_ + 8) + 1, Shape1(1), xpu::kDevMask, dev_id); Tensor workspace( temp_space.dptr_+2*actual_float_size+2*actual_quantized_size, Shape1(temp_reduce_size), s); +#if !defined(__CUDACC__) broadcast::Reduce( s, actual_min_quantized.reshape(dst_shape), kWriteTo, workspace, inputs[0].reshape(src_shape)); @@ -158,6 +159,18 @@ void RequantizeForward(const nnvm::NodeAttrs& attrs, broadcast::Reduce( s, actual_max_quantized.reshape(dst_shape), kWriteTo, workspace, inputs[0].reshape(src_shape)); +#else + broadcast::RTCReduce(ctx, actual_min_quantized.reshape(dst_shape), + kWriteTo, workspace, inputs[0].reshape(src_shape), + "red::minimum{}", 2, "identity"); + Kernel::Launch(s, 1, + actual_min_float.dptr_, actual_min_quantized.dptr(), + inputs[1].dptr(), inputs[2].dptr()); + + broadcast::RTCReduce(ctx, actual_max_quantized.reshape(dst_shape), + kWriteTo, workspace, inputs[0].reshape(src_shape), + "red::maximum{}", 2, "identity"); +#endif Kernel::Launch(s, 1, actual_max_float.dptr_, actual_max_quantized.dptr(), inputs[1].dptr(), inputs[2].dptr()); diff --git a/src/operator/random/pdf_op.h b/src/operator/random/pdf_op.h index 57bddfc2b1fe..4dd49fc50721 100644 --- a/src/operator/random/pdf_op.h +++ b/src/operator/random/pdf_op.h @@ -607,11 +607,22 @@ void PdfOpBackward(const nnvm::NodeAttrs& attrs, } Tensor red_work( tmp_space.dptr_ + pnum * outputs[0].Size() * sizeof(DType), Shape1(red_work_size), s); - broadcast::Reduce( - s, outputs[1].reshape(dst_shape), req[1], red_work, grads[1].reshape(src_shape)); - if (pnum == 2) { + if constexpr (std::is_same::value) { broadcast::Reduce( - s, outputs[2].reshape(dst_shape), req[2], red_work, grads[2].reshape(src_shape)); + s, outputs[1].reshape(dst_shape), req[1], red_work, grads[1].reshape(src_shape)); + if (pnum == 2) { + broadcast::Reduce( + s, outputs[2].reshape(dst_shape), req[2], red_work, grads[2].reshape(src_shape)); + } + } else { +#if MXNET_USE_CUDA + broadcast::RTCReduce(ctx, outputs[1].reshape(dst_shape), req[1], red_work, + grads[1].reshape(src_shape), "red::sum{}", 2, "identity"); + if (pnum == 2) { + broadcast::RTCReduce(ctx, outputs[2].reshape(dst_shape), req[2], red_work, + grads[2].reshape(src_shape), "red::sum{}", 2, "identity"); + } +#endif } }); } diff --git a/src/operator/tensor/broadcast_reduce-inl.cuh b/src/operator/tensor/broadcast_reduce-inl.cuh deleted file mode 100644 index c7a7c478cbb3..000000000000 --- a/src/operator/tensor/broadcast_reduce-inl.cuh +++ /dev/null @@ -1,408 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * Copyright (c) 2015-2020 by Contributors - * \file broadcast_reduce-inl.cuh - * \brief CUDA implementations for binary broadcast and reduce - * \author Antti-Pekka Hynninen, Przemyslaw Tredak -*/ -#ifndef MXNET_OPERATOR_TENSOR_BROADCAST_REDUCE_INL_CUH_ -#define MXNET_OPERATOR_TENSOR_BROADCAST_REDUCE_INL_CUH_ - -using namespace mshadow::cuda; - -template -__launch_bounds__(nthread_reduce) -__global__ void reduce_kernel(const int N, const int M, const bool addto, - const DType* __restrict big, OType *small, - const Shape big_shape0, const Shape small_shape, - const Shape big_shape, const Shape big_stride, - const int Mnext, const bool do_transpose) { - extern __shared__ char shTileChar[]; - AType* shTile = (AType*)(shTileChar); - const int tid = threadIdx.x + threadIdx.y*blockDim.x; - const int bx = (do_transpose) ? blockDim.y : blockDim.x; - const int by = (do_transpose) ? blockDim.x : blockDim.y; - const int tidx = (do_transpose) ? tid / by : threadIdx.x; - const int tidy = (do_transpose) ? tid % by : threadIdx.y; - for (int m0 = blockIdx.y; m0 < Mnext; m0 += gridDim.y) { - // This TB handles M range [Mstart, ...., Mend - 1] - const int Mstart = (int)((uint64_t)M*(uint64_t)m0/(uint64_t)Mnext); - const int Mend = (int)((uint64_t)M*(uint64_t)(m0 + 1)/(uint64_t)Mnext); - for (int idx0 = blockIdx.x*bx; idx0 < N; idx0 += bx*gridDim.x) { - int idx = idx0 + tidx; - Shape coord = mxnet_op::unravel(idx, small_shape); - int idx_big0 = mxnet_op::ravel(coord, big_shape0); - - AType val, residual; - Reducer::SetInitValue(val, residual); - if (idx < N) { - for (int k = tidy + Mstart; k < Mend; k += by*unroll) { - int idx_big[unroll]; - #pragma unroll - for (int u=0;u < unroll;u++) { - idx_big[u] = idx_big0 + mxnet_op::unravel_dot(k + u*by, big_shape, big_stride); - } - DType tmp[unroll]; - #pragma unroll - for (int u=0;u < unroll;u++) { - if (k + u*by < Mend) { - tmp[u] = OP::Map(big[idx_big[u]]); - } - } - #pragma unroll - for (int u=0;u < unroll;u++) { - if (k + u*by < Mend) Reducer::Reduce(val, AType(tmp[u]), residual); - } - } - } - - // Shared memory block bx * by. Reduction is along by. Final result is in tidy=0 - if (by > 1) { - // Fix bx to avoid bank conflicts. Assumes warpSize number of banks - const int fbx = (do_transpose && ((bx & (warpSize - 1)) == 0)) ? 
(bx + 1) : bx; - const int it0 = tidx + tidy*fbx; - shTile[it0 * 2] = val; - shTile[it0 * 2 + 1] = residual; - __syncthreads(); - for (int t=1;t < by;t <<= 1) { - AType tmp, tmp_residual; - Reducer::SetInitValue(tmp, tmp_residual); - if (tidy + t < by) { - tmp = shTile[(it0 + t*fbx) * 2]; - tmp_residual = shTile[(it0 + t*fbx) * 2 + 1]; - } - __syncthreads(); - Reducer::Merge(shTile[it0 * 2], shTile[it0 * 2 + 1], tmp, tmp_residual); - __syncthreads(); - } - if (idx < N && tidy == 0) { - Reducer::Finalize(shTile[tidx * 2], shTile[tidx * 2 + 1]); - assign(&small[idx + m0*N], addto, OType(shTile[tidx * 2])); - } - } else { - if (idx < N) { - Reducer::Finalize(val, residual); - assign(&small[idx + m0*N], addto, OType(val)); - } - } - } - } -} - -template -__launch_bounds__(nthread_reduce) -__global__ void reduce_kernel(const int N, const int M, const bool addto, - const DType* __restrict big, const DType* __restrict lhs, - const DType* __restrict rhs, DType *small, - const Shape big_shape0, const Shape lhs_shape0, - const Shape rhs_shape0, const Shape small_shape, - const Shape big_shape, const Shape lhs_shape, - const Shape rhs_shape, const Shape big_stride, - const Shape lhs_stride, const Shape rhs_stride, - const int Mnext, const bool do_transpose) { - extern __shared__ char shTileChar[]; - DType* shTile = (DType*)(shTileChar); - const int tid = threadIdx.x + threadIdx.y*blockDim.x; - const int bx = (do_transpose) ? blockDim.y : blockDim.x; - const int by = (do_transpose) ? blockDim.x : blockDim.y; - const int tidx = (do_transpose) ? tid / by : threadIdx.x; - const int tidy = (do_transpose) ? tid % by : threadIdx.y; - for (int m0 = blockIdx.y; m0 < Mnext; m0 += gridDim.y) { - // This TB handles M range [Mstart, ...., Mend - 1] - const int Mstart = (int)((uint64_t)M*(uint64_t)m0/(uint64_t)Mnext); - const int Mend = (int)((uint64_t)M*(uint64_t)(m0 + 1)/(uint64_t)Mnext); - for (int idx0 = blockIdx.x*bx; idx0 < N; idx0 += bx*gridDim.x) { - int idx = idx0 + tidx; - Shape coord = mxnet_op::unravel(idx, small_shape); - int idx_big0 = mxnet_op::ravel(coord, big_shape0); - int idx_lhs0 = mxnet_op::ravel(coord, lhs_shape0); - int idx_rhs0 = mxnet_op::ravel(coord, rhs_shape0); - - DType val, residual; - Reducer::SetInitValue(val, residual); - if (idx < N) { - for (int k = tidy + Mstart; k < Mend; k += by*unroll) { - int idx_big[unroll]; - int idx_lhs[unroll]; - int idx_rhs[unroll]; - #pragma unroll - for (int u=0;u < unroll;u++) { - idx_big[u] = idx_big0 + mxnet_op::unravel_dot(k + u*by, big_shape, big_stride); - idx_lhs[u] = idx_lhs0 + mxnet_op::unravel_dot(k + u*by, lhs_shape, lhs_stride); - idx_rhs[u] = idx_rhs0 + mxnet_op::unravel_dot(k + u*by, rhs_shape, rhs_stride); - } - DType tmp[unroll]; - #pragma unroll - for (int u=0;u < unroll;u++) { - if (k + u*by < Mend) { - tmp[u] = OP1::Map(big[idx_big[u]], OP2::Map(lhs[idx_lhs[u]], rhs[idx_rhs[u]])); - } - } - #pragma unroll - for (int u=0;u < unroll;u++) { - if (k + u*by < Mend) Reducer::Reduce(val, tmp[u], residual); - } - } - } - - // Shared memory block bx * by. Reduction is along by. Final result is in tidy=0 - if (by > 1) { - // Fix bx to avoid bank conflicts. Assumes warpSize number of banks - const int fbx = (do_transpose && ((bx & (warpSize - 1)) == 0)) ? 
(bx + 1) : bx; - const int it0 = tidx + tidy*fbx; - shTile[it0 * 2] = val; - shTile[it0 * 2 + 1] = residual; - __syncthreads(); - for (int t=1;t < by;t <<= 1) { - DType tmp, tmp_residual; - Reducer::SetInitValue(tmp, tmp_residual); - if (tidy + t < by) { - tmp = shTile[(it0 + t*fbx) * 2]; - tmp_residual = shTile[(it0 + t*fbx) * 2 + 1]; - } - __syncthreads(); - Reducer::Merge(shTile[it0 * 2], shTile[it0 * 2 + 1], tmp, tmp_residual); - __syncthreads(); - } - if (idx < N && tidy == 0) { - Reducer::Finalize(shTile[tidx * 2], shTile[tidx * 2 + 1]); - assign(&small[idx + m0*N], addto, shTile[tidx * 2]); - } - } else { - if (idx < N) { - Reducer::Finalize(val, residual); - assign(&small[idx + m0*N], addto, val); - } - } - } - } -} - -// Simple reduction of lines when M is small -template -__launch_bounds__(kMaxThreadsPerBlock) -__global__ void reduce_lines_kernel(const int N, const int M, const bool addto, - const int small_in_stride, const DType* __restrict small_in, DType *small_out) { - for (int idx = threadIdx.x + blockIdx.x*blockDim.x; idx < N; idx += blockDim.x*gridDim.x) { - - DType val, residual; - Reducer::SetInitValue(val, residual); - for (int k = 0; k < M; k++) { - Reducer::Reduce(val, small_in[idx + k*small_in_stride], residual); - } - - if (idx < N) { - Reducer::Finalize(val, residual); - assign(&small_out[idx], addto, val); - } - - } -} - -template -__launch_bounds__(kMaxThreadsPerBlock) -__global__ void reduce_kernel_M1(const int N, const bool addto, - const DType* __restrict big, OType *small, const Shape bshape, - const Shape sshape) { - for (int idx = threadIdx.x + blockIdx.x*blockDim.x; idx < N; idx += blockDim.x*gridDim.x) { - Shape coord = mxnet_op::unravel(idx, sshape); - int j = mxnet_op::ravel(coord, bshape); - AType val, residual; - Reducer::SetInitValue(val, residual); - Reducer::Reduce(val, AType(OP::Map(big[j])), residual); - Reducer::Finalize(val, residual); - assign(&small[idx], addto, OType(val)); - } -} - -template -__launch_bounds__(kMaxThreadsPerBlock) -__global__ void reduce_kernel_M1(const int N, const bool addto, - const DType* __restrict big, - const DType* __restrict lhs, - const DType* __restrict rhs, - DType *small, - const Shape big_shape, - const Shape lhs_shape, - const Shape rhs_shape, - const Shape small_shape) { - for (int idx = threadIdx.x + blockIdx.x*blockDim.x; idx < N; idx += blockDim.x*gridDim.x) { - Shape coord = mxnet_op::unravel(idx, small_shape); - int idx_big = mxnet_op::ravel(coord, big_shape); - int idx_lhs = mxnet_op::ravel(coord, lhs_shape); - int idx_rhs = mxnet_op::ravel(coord, rhs_shape); - DType val, residual; - Reducer::SetInitValue(val, residual); - Reducer::Reduce(val, OP1::Map(big[idx_big], OP2::Map(lhs[idx_lhs], rhs[idx_rhs])), residual); - Reducer::Finalize(val, residual); - assign(&small[idx], addto, val); - } -} - -#define KERNEL_UNROLL_SWITCH(do_unroll, unrollAmount, unrollVar, ...) 
\ - if (do_unroll) { \ - const int unrollVar = unrollAmount; \ - {__VA_ARGS__} \ - } else { \ - const int unrollVar = 1; \ - {__VA_ARGS__} \ - } - -template -void ReduceImpl(cudaStream_t stream, const TBlob& small, const OpReqType req, - const TBlob& big, const Tensor& workspace, - const ReduceImplConfig& config) { - if (config.M == 1) { - reduce_kernel_M1 - <<< config.kernel_1.gridDim, config.kernel_1.blockDim, 0, stream >>>( - config.N, req == kAddTo, big.dptr(), small.dptr(), big.shape_.get(), - small.shape_.get()); - MSHADOW_CUDA_POST_KERNEL_CHECK(reduce_kernel_M1); - } else { - OType* small_dptr = small.dptr(); - bool addto = (req == kAddTo); - if (config.Mnext > 1) { - // small_dptr[] is N*Mnext*sizeof(DType) bytes - small_dptr = reinterpret_cast(workspace.dptr_); - addto = false; - // Check that the workspace is contigiuous - CHECK_EQ(workspace.CheckContiguous(), true); - // Check that we have enough storage - CHECK_GE(workspace.size(0), config.workspace_size); - } - - const int by = (config.kernel_1.do_transpose) ? - config.kernel_1.blockDim.x : config.kernel_1.blockDim.y; - const bool do_unroll = ( config.M / (by*config.Mnext) >= unroll_reduce ); - KERNEL_UNROLL_SWITCH(do_unroll, unroll_reduce, UNROLL, { - reduce_kernel - <<< config.kernel_1.gridDim, config.kernel_1.blockDim, config.kernel_1.shMemSize, stream>>>( - config.N, config.M, addto, big.dptr(), small_dptr, big.shape_.get(), - small.shape_.get(), config.rshape.get(), config.rstride.get(), - config.Mnext, config.kernel_1.do_transpose); - }); - MSHADOW_CUDA_POST_KERNEL_CHECK(reduce_kernel); - - if (config.Mnext > 1) { - reduce_lines_kernel - <<< config.kernel_2.gridSize, config.kernel_2.blockSize, 0, stream >>> - (config.N, config.Mnext, req == kAddTo, config.N, small_dptr, small.dptr()); - MSHADOW_CUDA_POST_KERNEL_CHECK(reduce_lines_kernel); - } - } -} - -template -void ReduceImpl(cudaStream_t stream, const TBlob& small, const TBlob& lhs, const TBlob& rhs, - const OpReqType req, const TBlob& big, const Tensor& workspace, - const ReduceImplConfig& config) { - if (config.M == 1) { - reduce_kernel_M1 - <<< config.kernel_1.gridDim, config.kernel_1.blockDim, 0, stream >>>( - config.N, req == kAddTo, big.dptr(), lhs.dptr(), rhs.dptr(), - small.dptr(), big.shape_.get(), lhs.shape_.get(), - rhs.shape_.get(), small.shape_.get()); - MSHADOW_CUDA_POST_KERNEL_CHECK(reduce_kernel_M1); - } else { - DType* small_dptr = small.dptr(); - bool addto = (req == kAddTo); - if (config.Mnext > 1) { - // small_dptr[] is N*Mnext*sizeof(DType) bytes - small_dptr = reinterpret_cast(workspace.dptr_); - addto = false; - // Check that the workspace is contigiuous - CHECK_EQ(workspace.CheckContiguous(), true); - // Check that we have enough storage - CHECK_GE(workspace.size(0), config.workspace_size); - } - - const int by = (config.kernel_1.do_transpose) ? 
- config.kernel_1.blockDim.x : config.kernel_1.blockDim.y; - const bool do_unroll = ( config.M / (by*config.Mnext) >= unroll_reduce ); - KERNEL_UNROLL_SWITCH(do_unroll, unroll_reduce, UNROLL, { - reduce_kernel - <<< config.kernel_1.gridDim, config.kernel_1.blockDim, config.kernel_1.shMemSize, stream>>>( - config.N, config.M, addto, big.dptr(), lhs.dptr(), rhs.dptr(), - small_dptr, big.shape_.get(), lhs.shape_.get(), - rhs.shape_.get(), small.shape_.get(), config.rshape.get(), - config.lhs_shape.get(), config.rhs_shape.get(), config.rstride.get(), - config.lhs_stride.get(), config.rhs_stride.get(), config.Mnext, - config.kernel_1.do_transpose); - MSHADOW_CUDA_POST_KERNEL_CHECK(reduce_kernel); - }); - - if (config.Mnext > 1) { - reduce_lines_kernel - <<< config.kernel_2.gridSize, config.kernel_2.blockSize, 0, stream >>> - (config.N, config.Mnext, req == kAddTo, config.N, small_dptr, small.dptr()); - MSHADOW_CUDA_POST_KERNEL_CHECK(reduce_lines_kernel); - } - } -} - -#undef KERNEL_UNROLL_SWITCH - -template -void Reduce(Stream *s, const TBlob& small, const OpReqType req, - const Tensor& workspace, const TBlob& big) { - if (req == kNullOp) return; - cudaStream_t stream = Stream::GetStream(s); - ReduceImplConfig config(small.shape_, big.shape_, nullptr, nullptr, sizeof(DType)); - if (safe_acc) { - MXNET_ACC_TYPE_SWITCH(mshadow::DataType::kFlag, DataType, AType, { - typedef typename std::conditional::type AccType; - MSHADOW_TYPE_SWITCH(small.type_flag_, OType, { - typedef typename std::conditional::type OutType; - config = ReduceImplConfig(small.shape_, big.shape_, nullptr, nullptr, - sizeof(AccType)); - ReduceImpl( - stream, small, req, big, workspace, config); - }); - }); - } else { - ReduceImpl(stream, small, req, big, workspace, config); - } -} - -template -void ReduceBool(Stream *s, const TBlob& small, const OpReqType req, - const Tensor& workspace, const TBlob& big) { - if (req == kNullOp) return; - cudaStream_t stream = Stream::GetStream(s); - ReduceImplConfig config(small.shape_, big.shape_, nullptr, nullptr, sizeof(DType)); - ReduceImpl(stream, small, req, big, workspace, config); -} - -template -void ReduceWithExtraMem(Stream* s, const TBlob& small, const OpReqType req, - const Tensor& workspace, const TBlob& big) {}; - -template -void Reduce(Stream *s, const TBlob& small, const OpReqType req, - const Tensor& workspace, const TBlob& big, - const TBlob& lhs, const TBlob& rhs) { - if (req == kNullOp) return; - cudaStream_t stream = Stream::GetStream(s); - ReduceImplConfig config(small.shape_, big.shape_, &lhs.shape_, &rhs.shape_, sizeof(DType)); - ReduceImpl(stream, small, lhs, rhs, req, big, workspace, config); -} - -#endif //MXNET_OPERATOR_TENSOR_BROADCAST_REDUCE_INL_CUH_ diff --git a/src/operator/tensor/broadcast_reduce-inl.h b/src/operator/tensor/broadcast_reduce-inl.h index ad3bd2a2bec9..987ab73536f4 100644 --- a/src/operator/tensor/broadcast_reduce-inl.h +++ b/src/operator/tensor/broadcast_reduce-inl.h @@ -652,10 +652,6 @@ inline size_t ReduceWorkspaceSize(Stream *s, const ::mxnet::TShape& small, return config.workspace_size; } -#ifdef __CUDACC__ -#include "broadcast_reduce-inl.cuh" -#endif - #endif // MXNET_USE_CUDA template diff --git a/src/operator/tensor/broadcast_reduce_minmax_value.cu b/src/operator/tensor/broadcast_reduce_minmax_value.cu index baf79feb5c60..c8cb757cd9a3 100644 --- a/src/operator/tensor/broadcast_reduce_minmax_value.cu +++ b/src/operator/tensor/broadcast_reduce_minmax_value.cu @@ -28,13 +28,15 @@ namespace mxnet { namespace op { NNVM_REGISTER_OP(max) 
-.set_attr("FCompute", ReduceAxesCompute); +.set_attr("FCompute", ReduceAxesRTCCompute + {"identity", "red::maximum{}", false}); NNVM_REGISTER_OP(_backward_max) .set_attr("FCompute", ReduceAxesBackwardUseInOut); NNVM_REGISTER_OP(min) -.set_attr("FCompute", ReduceAxesCompute); +.set_attr("FCompute", ReduceAxesRTCCompute + {"identity", "red::minimum{}", false}); NNVM_REGISTER_OP(_backward_min) .set_attr("FCompute", ReduceAxesBackwardUseInOut); diff --git a/src/operator/tensor/broadcast_reduce_op.cc b/src/operator/tensor/broadcast_reduce_op.cc new file mode 100644 index 000000000000..b5440406be8c --- /dev/null +++ b/src/operator/tensor/broadcast_reduce_op.cc @@ -0,0 +1,201 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "broadcast_reduce_op.h" +#include "../numpy/np_broadcast_reduce_op.h" +#include "elemwise_binary_scalar_op.h" +#include "mxnet/tuple.h" +#include + +namespace mxnet { +namespace op { + +#if MXNET_USE_CUDA + +void ReduceAxesRTCComputeWithWorkspaceImpl(const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs, + const std::string& reducer, + const bool normalize, + const std::string& OP, + const mshadow::Tensor& workspace, + const mxnet::TShape& src_shape, + const mxnet::TShape& dst_shape, + const int ddof) { + const TBlob in_data = inputs[0].reshape(src_shape); + const TBlob out_data = outputs[0].reshape(dst_shape); + BROADCAST_NDIM_SWITCH(dst_shape.ndim(), NDim, { + broadcast::RTCReduce(ctx, out_data, req[0], workspace, in_data, reducer, NDim, OP); + }); + if (normalize) { + NumpyBinaryScalarParam p{}; + p.scalar = static_cast(src_shape.Size()/dst_shape.Size() - ddof); + NodeAttrs a; + a.parsed = p; + BinaryScalarRTCCompute {"div"}(a, ctx, {out_data}, {kWriteInplace}, {out_data}); + } +} + +void ReduceAxesRTCComputeImpl(const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs, + const mxnet::TShape& small, + const std::string& reducer, + const bool normalize, + const std::string& OP, + const int ddof) { + using namespace mshadow; + + mxnet::TShape src_shape, dst_shape; + BroadcastReduceShapeCompact(inputs[0].shape_, small, &src_shape, &dst_shape); + const TBlob in_data = inputs[0].reshape(src_shape); + const TBlob out_data = outputs[0].reshape(dst_shape); + Stream* s = ctx.get_stream(); + size_t workspace_size = broadcast::ReduceWorkspaceSize( + s, out_data.shape_, req[0], in_data.shape_, + common::mshadow_type_info(inputs[0].type_flag_).size); + Tensor workspace = + ctx.requested[0].get_space_typed(Shape1(workspace_size), s); + ReduceAxesRTCComputeWithWorkspaceImpl(ctx, inputs, req, outputs, reducer, normalize, + OP, workspace, src_shape, dst_shape, ddof); +} + +namespace { +template +void PrepareReduce(const 
Param& param, + const std::vector& inputs, + const std::vector& outputs, + mxnet::TShape* shape, int* ddof); + +template <> +void PrepareReduce(const ReduceAxesParam& param, + const std::vector& inputs, + const std::vector& outputs, + mxnet::TShape* small, int* ddof) { + if (param.keepdims) { + *small = outputs[0].shape_; + } else { + *small = ReduceAxesShapeImpl(inputs[0].shape_, param.axis, true, param.exclude); + } + + *ddof = 0; +} + +template <> +void PrepareReduce(const NumpyReduceAxesNoDTypeParam& param, + const std::vector& inputs, + const std::vector& outputs, + mxnet::TShape* small, int* ddof) { + if (param.initial.has_value()) { + LOG(FATAL) << "initial is not supported yet"; + } + if (param.keepdims) { + *small = outputs[0].shape_; + } else { + *small = NumpyReduceAxesShapeImpl(inputs[0].shape_, param.axis, true); + } + + *ddof = 0; +} + +template <> +void PrepareReduce(const NumpyReduceAxesParam& param, + const std::vector& inputs, + const std::vector& outputs, + mxnet::TShape* small, int* ddof) { + if (param.initial.has_value()) { + LOG(FATAL) << "initial is not supported yet"; + } + if (param.keepdims) { + *small = outputs[0].shape_; + } else { + *small = NumpyReduceAxesShapeImpl(inputs[0].shape_, param.axis, true); + } + + *ddof = 0; +} + +template <> +void PrepareReduce(const NumpyReduceAxesBoolParam& param, + const std::vector& inputs, + const std::vector& outputs, + mxnet::TShape* small, int* ddof) { + if (param.keepdims) { + *small = outputs[0].shape_; + } else { + *small = NumpyReduceAxesShapeImpl(inputs[0].shape_, param.axis, true); + } + + *ddof = 0; +} + +} // namespace + +template +void ReduceAxesRTCCompute::operator()(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + if (req[0] == kNullOp) return; + mxnet::TShape small; + int ddof; + const auto& param = nnvm::get(attrs.parsed); + CHECK_NE(req[0], kWriteInplace) << "Reduce does not support write in-place"; + PrepareReduce(param, inputs, outputs, &small, &ddof); + if (outputs[0].shape_.Size() == 0U) return; // zero-size tensor + if (inputs[0].shape_.Size() == 0) { + if (normalize && mxnet::common::is_float(outputs[0].type_flag_)) { + LOG(WARNING) << "WARNING: Mean of empty slice."; + NumpyBinaryScalarParam p{}; + p.scalar = std::numeric_limits::quiet_NaN(); + NodeAttrs a; + a.parsed = p; + BinaryScalarRTCCompute {"right"} (a, ctx, outputs, {kWriteTo}, outputs); + } else { + if (normalize) { + LOG(WARNING) << "WARNING: nan is outside the range of"<< + "representable values of type 'int'"; + } + if (init == 0 && req[0] == kAddTo) return; + NumpyBinaryScalarParam p{}; + p.scalar = init; + NodeAttrs a; + a.parsed = p; + BinaryScalarRTCCompute {"right"} (a, ctx, outputs, {req[0]}, outputs); + } + return; + } + + ReduceAxesRTCComputeImpl(ctx, inputs, req, outputs, small, reducer, normalize, OP, ddof); +} + +template struct ReduceAxesRTCCompute; +template struct ReduceAxesRTCCompute; +template struct ReduceAxesRTCCompute; +template struct ReduceAxesRTCCompute; +template struct ReduceAxesRTCCompute; +template struct ReduceAxesRTCCompute; + +#endif + +} // namespace op +} // namespace mxnet diff --git a/src/operator/tensor/broadcast_reduce_op.h b/src/operator/tensor/broadcast_reduce_op.h index c4e3dae35ab9..65ee0202701d 100644 --- a/src/operator/tensor/broadcast_reduce_op.h +++ b/src/operator/tensor/broadcast_reduce_op.h @@ -700,6 +700,44 @@ void ReduceAxesCompute(const nnvm::NodeAttrs& attrs, ReduceAxesComputeImpl(ctx, inputs, req, 
outputs, small); } +#if MXNET_USE_CUDA + +template +struct ReduceAxesRTCCompute { + std::string OP; + std::string reducer; + bool normalize; + + void operator()(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs); +}; + +void ReduceAxesRTCComputeImpl(const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs, + const mxnet::TShape& small, + const std::string& reducer, + const bool normalize = false, + const std::string& OP = "identity", + const int ddof = 0); + +void ReduceAxesRTCComputeWithWorkspaceImpl(const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs, + const std::string& reducer, + const bool normalize, + const std::string& OP, + const mshadow::Tensor& workspace, + const mxnet::TShape& src_shape, + const mxnet::TShape& dst_shape, + const int ddof = 0); +#endif + template struct ReduceCsrKernel; @@ -1480,7 +1518,8 @@ void LpNormCompute(const nnvm::NodeAttrs& attrs, } else { small = ReduceAxesShapeImpl(inputs[0].shape_, param.axis, true, false); } - bool safe_acc = dmlc::GetEnv("MXNET_SAFE_ACCUMULATION", true); +#if !defined(__CUDACC__) + bool safe_acc = dmlc::GetEnv("MXNET_SAFE_ACCUMULATION", false); if (!safe_acc && inputs[0].type_flag_ == mshadow::kFloat16) { common::LogOnce("MXNET_SAFE_ACCUMULATION=1 is recommended for LpNorm with float16 inputs. " "See https://mxnet.apache.org/api/faq/env_var " @@ -1503,6 +1542,15 @@ void LpNormCompute(const nnvm::NodeAttrs& attrs, ctx, inputs, req, outputs, small); } } +#else + const std::string &red = param.ord == 1 + ? "red::sum{}" + : "red::nrm2{}"; + const std::string &op = param.ord == 1 + ? "abs" + : "identity"; + ReduceAxesRTCComputeImpl(ctx, inputs, req, outputs, small, red, false, op); +#endif } template diff --git a/src/operator/tensor/broadcast_reduce_op_value.cu b/src/operator/tensor/broadcast_reduce_op_value.cu index 35b3c0272db8..f7c28341fed5 100644 --- a/src/operator/tensor/broadcast_reduce_op_value.cu +++ b/src/operator/tensor/broadcast_reduce_op_value.cu @@ -37,7 +37,8 @@ NNVM_REGISTER_OP(broadcast_like) .set_attr("FCompute", BroadcastCompute); NNVM_REGISTER_OP(_broadcast_backward) -.set_attr("FCompute", ReduceAxesCompute); +.set_attr("FCompute", ReduceAxesRTCCompute{"identity", + "red::sum{}", false}); } // namespace op } // namespace mxnet diff --git a/src/operator/tensor/broadcast_reduce_prod_value.cu b/src/operator/tensor/broadcast_reduce_prod_value.cu index 5731de308064..7e7a95b50677 100644 --- a/src/operator/tensor/broadcast_reduce_prod_value.cu +++ b/src/operator/tensor/broadcast_reduce_prod_value.cu @@ -28,13 +28,15 @@ namespace mxnet { namespace op { NNVM_REGISTER_OP(prod) -.set_attr("FCompute", ReduceAxesCompute); +.set_attr("FCompute", ReduceAxesRTCCompute + {"identity", "red::product{}", false}); NNVM_REGISTER_OP(_backward_prod) .set_attr("FCompute", ReduceAxesBackwardUseInOut); NNVM_REGISTER_OP(nanprod) -.set_attr("FCompute", ReduceAxesCompute); +.set_attr("FCompute", ReduceAxesRTCCompute + {"identity", "red::nanprod{}", false}); NNVM_REGISTER_OP(_backward_nanprod) .set_attr("FCompute", ReduceAxesBackwardUseInOut); diff --git a/src/operator/tensor/broadcast_reduce_sum_value.cu b/src/operator/tensor/broadcast_reduce_sum_value.cu index 2385d36f35b0..40a8ed8d17bf 100644 --- a/src/operator/tensor/broadcast_reduce_sum_value.cu +++ b/src/operator/tensor/broadcast_reduce_sum_value.cu @@ -28,19 +28,22 @@ namespace mxnet { namespace op { 
NNVM_REGISTER_OP(sum) -.set_attr("FCompute", ReduceAxesCompute); +.set_attr("FCompute", ReduceAxesRTCCompute{"identity", + "red::sum{}", false}); NNVM_REGISTER_OP(_backward_sum) .set_attr("FCompute", ReduceAxesBackwardUseNone); NNVM_REGISTER_OP(mean) -.set_attr("FCompute", ReduceAxesCompute); +.set_attr("FCompute", ReduceAxesRTCCompute{"identity", + "red::sum{}", true}); NNVM_REGISTER_OP(_backward_mean) .set_attr("FCompute", ReduceAxesBackwardUseNone); NNVM_REGISTER_OP(nansum) -.set_attr("FCompute", ReduceAxesCompute); +.set_attr("FCompute", ReduceAxesRTCCompute{"identity", + "red::nansum{}", false}); NNVM_REGISTER_OP(_backward_nansum) .set_attr("FCompute", ReduceAxesBackwardUseInOut); diff --git a/src/operator/tensor/elemwise_binary_broadcast_op.cc b/src/operator/tensor/elemwise_binary_broadcast_op.cc index 2f9832a173f6..34515025e604 100644 --- a/src/operator/tensor/elemwise_binary_broadcast_op.cc +++ b/src/operator/tensor/elemwise_binary_broadcast_op.cc @@ -376,10 +376,10 @@ void BinaryBroadcastRTCBackwardUseNone::operator()(const nnvm::NodeAttrs& attrs, if (out.shape_.Size() != 0) { broadcast::RTCReduce(ctx, lhs, req[0], workspace, out, - "red::sum", NDim, LOP); + "red::sum{}", NDim, LOP); broadcast::RTCReduce(ctx, rhs, req[1], workspace, out, - "red::sum", NDim, ROP); + "red::sum{}", NDim, ROP); } else { using namespace common::cuda::rtc::util; if (lhs.shape_.Size() != 0) { @@ -434,12 +434,12 @@ void BinaryBroadcastRTCBackwardUseIn::operator()(const nnvm::NodeAttrs& attrs, ctx.requested[0].get_space_typed(Shape1(workspace_size), s); if (req[0] != kNullOp) { broadcast::RTCReduce(ctx, lgrad, req[0], workspace, - ograd, lhs, rhs, "red::sum", NDim, + ograd, lhs, rhs, "red::sum{}", NDim, "mul", LOP); } if (req[1] != kNullOp) { broadcast::RTCReduce(ctx, rgrad, req[1], workspace, - ograd, lhs, rhs, "red::sum", NDim, + ograd, lhs, rhs, "red::sum{}", NDim, "mul", ROP); } }); diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h index 7bc623b493ef..fa61268d8fa1 100644 --- a/src/operator/tensor/matrix_op-inl.h +++ b/src/operator/tensor/matrix_op-inl.h @@ -2046,8 +2046,12 @@ void RepeatOpBackward(const nnvm::NodeAttrs& attrs, inputs[0].type_flag_, inputs[0].dev_id()); std::vector newInputs = {iblob}; +#if !defined(__CUDACC__) ReduceAxesComputeImpl( ctx, newInputs, req, newOutputs, rshapes.first); +#else + ReduceAxesRTCComputeImpl(ctx, newInputs, req, newOutputs, rshapes.first, "red::sum{}", false); +#endif } struct TileParam : public dmlc::Parameter { @@ -2238,8 +2242,12 @@ void TileOpBackward(const nnvm::NodeAttrs& attrs, inputs[0].type_flag_, inputs[0].dev_id()); std::vector newInputs = {iblob}; +#if !defined(__CUDACC__) ReduceAxesComputeImpl( ctx, newInputs, req, newOutputs, rshapes.first); +#else + ReduceAxesRTCComputeImpl(ctx, newInputs, req, newOutputs, rshapes.first, "red::sum{}", false); +#endif } struct ReverseParam : public dmlc::Parameter { diff --git a/src/operator/tensor/reduce_rtc.cc b/src/operator/tensor/reduce_rtc.cc index 9e2d6d3f2a53..f7caf90dbc52 100644 --- a/src/operator/tensor/reduce_rtc.cc +++ b/src/operator/tensor/reduce_rtc.cc @@ -107,7 +107,7 @@ __global__ void reduce_kernel(const int N, const int M, const bool addto, } AType val, residual; - REDUCER::SetInitValue(val, residual); + REDUCER.SetInitValue(val, residual); if (idx < N) { for (index_t k = tidy + Mstart; k < Mend; k += by*UNROLL) { index_t idx_big[UNROLL]; @@ -133,7 +133,7 @@ __global__ void reduce_kernel(const int N, const int M, const bool addto, } #pragma unroll for (int u=0;u 
< UNROLL;u++) { - if (k + u*by < Mend) REDUCER::Reduce(val, tmp[u], residual); + if (k + u*by < Mend) REDUCER.Reduce(val, tmp[u], residual); } } } @@ -148,17 +148,17 @@ __global__ void reduce_kernel(const int N, const int M, const bool addto, __syncthreads(); for (int t=1;t < by;t <<= 1) { AType tmp, tmp_residual; - REDUCER::SetInitValue(tmp, tmp_residual); + REDUCER.SetInitValue(tmp, tmp_residual); if (tidy + t < by) { tmp = shTile[(it0 + t*fbx) * 2]; tmp_residual = shTile[(it0 + t*fbx) * 2 + 1]; } __syncthreads(); - REDUCER::Merge(shTile[it0 * 2], shTile[it0 * 2 + 1], tmp, tmp_residual); + REDUCER.Merge(shTile[it0 * 2], shTile[it0 * 2 + 1], tmp, tmp_residual); __syncthreads(); } if (idx < N && tidy == 0) { - REDUCER::Finalize(shTile[tidx * 2], shTile[tidx * 2 + 1]); + REDUCER.Finalize(shTile[tidx * 2], shTile[tidx * 2 + 1]); if (addto) { small[idx + m0 * N] = OType::to(op::add(OType::from(small[idx + m0 * N]), shTile[tidx * 2])); @@ -168,7 +168,7 @@ __global__ void reduce_kernel(const int N, const int M, const bool addto, } } else { if (idx < N) { - REDUCER::Finalize(val, residual); + REDUCER.Finalize(val, residual); if (addto) { small[idx + m0 * N] = OType::to(op::add(OType::from(small[idx + m0 * N]), val)); @@ -191,15 +191,15 @@ __global__ void reduce_lines_kernel(const index_t N, const index_t M, using OType = AccType; for (index_t idx = threadIdx.x + blockIdx.x*blockDim.x; idx < N; idx += blockDim.x*gridDim.x) { typename OType::type val, residual; - REDUCER::SetInitValue(val, residual); + REDUCER.SetInitValue(val, residual); for (int k = 0; k < M; k++) { - REDUCER::Reduce(val, + REDUCER.Reduce(val, OType::from(reinterpret_cast(small_in)[idx + k*small_in_stride]), residual); } if (idx < N) { - REDUCER::Finalize(val, residual); + REDUCER.Finalize(val, residual); if (req == OpReqType::kAddTo) { small_out[idx] = OType::to(op::add(OType::from(small_out[idx]), val)); } else { @@ -359,10 +359,10 @@ __global__ void reduce_kernel_M1(const int N, idx_rhs[0] = util::ravel(coord, params.rhs_shape); } typename OType::type val, residual; - REDUCER::SetInitValue(val, residual); + REDUCER.SetInitValue(val, residual); const int u = 0; - REDUCER::Reduce(val, FUNC, residual); - REDUCER::Finalize(val, residual); + REDUCER.Reduce(val, FUNC, residual); + REDUCER.Finalize(val, residual); if (req == OpReqType::kAddTo) { const auto temp = op::add(val, OType::from(small[idx])); small[idx] = OType::to(temp); diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index 580478ebcec3..2ecb596483fb 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -2987,8 +2987,10 @@ def hybrid_forward(self, F, a, b, *args, **kwargs): if isinstance(dtype, tuple): assert len(dtype) == 2 ldtype, rdtype = dtype - np_test_x1 = _np.random.uniform(low, high, lshape).astype(ldtype) - np_test_x2 = _np.random.uniform(low, high, rshape).astype(rdtype) + npldtype = ldtype if dtype != _np.float16 else _np.float32 + nprdtype = rdtype if dtype != _np.float16 else _np.float32 + np_test_x1 = _np.random.uniform(low, high, lshape).astype(ldtype).astype(npldtype) + np_test_x2 = _np.random.uniform(low, high, rshape).astype(rdtype).astype(nprdtype) mx_test_x1 = mx.numpy.array(np_test_x1, dtype=ldtype) mx_test_x2 = mx.numpy.array(np_test_x2, dtype=rdtype) for hybridize in [True, False]: From 91f715f40d06f65ac09e76c99af9514ecf10cace Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Tue, 27 Oct 2020 14:30:39 -0700 Subject: [PATCH 02/15] Fixes after merge --- 
src/operator/numpy/np_broadcast_reduce_op.cuh | 44 --- src/operator/numpy/np_broadcast_reduce_op.h | 53 ++-- src/operator/tensor/reduce_rtc.cc | 276 ++++++++++-------- 3 files changed, 191 insertions(+), 182 deletions(-) delete mode 100644 src/operator/numpy/np_broadcast_reduce_op.cuh diff --git a/src/operator/numpy/np_broadcast_reduce_op.cuh b/src/operator/numpy/np_broadcast_reduce_op.cuh deleted file mode 100644 index ec50f283cefa..000000000000 --- a/src/operator/numpy/np_broadcast_reduce_op.cuh +++ /dev/null @@ -1,44 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * Copyright (c) 2015-2020 by Contributors - * \file np_broadcast_reduce-inl.cuh - * \brief GPU implementations for numpy binary broadcast ops - * \author Zhaoqi Zhu -*/ -#ifndef MXNET_OPERATOR_NUMPY_NP_BROADCAST_REDUCE_OP_CUH_ -#define MXNET_OPERATOR_NUMPY_NP_BROADCAST_REDUCE_OP_CUH_ - -using namespace mshadow::cuda; -using namespace mshadow; -using namespace broadcast; - -template -void NumpyArgMinMaxReduce(Stream *s, const TBlob& in_data, const TBlob& out_data, - const Tensor& workspace) { - cudaStream_t stream = Stream::GetStream(s); - ReduceImplConfig config(out_data.shape_, in_data.shape_, nullptr, nullptr); - - ReduceImpl> - (stream, out_data, kWriteTo, in_data, workspace, config); -} - -#endif // MXNET_OPERATOR_NUMPY_NP_BROADCAST_REDUCE_OP_CUH_ diff --git a/src/operator/numpy/np_broadcast_reduce_op.h b/src/operator/numpy/np_broadcast_reduce_op.h index d7d5698f8d8b..4562963b16e4 100644 --- a/src/operator/numpy/np_broadcast_reduce_op.h +++ b/src/operator/numpy/np_broadcast_reduce_op.h @@ -494,10 +494,6 @@ void NumpyArgMinMaxReduce(mshadow::Stream *s, const TBlob& in_data, const T in_data.shape_.get(), out_data.shape_.get(), rshape, rstride); } -#ifdef __CUDACC__ -#include "np_broadcast_reduce_op.cuh" -#endif - template void NumpyArgMinMaxCompute(const nnvm::NodeAttrs& attrs, const OpContext& ctx, @@ -508,7 +504,7 @@ void NumpyArgMinMaxCompute(const nnvm::NodeAttrs& attrs, using namespace mshadow::expr; if (req[0] == kNullOp) return; // parse param - const ReduceAxisParam& param = nnvm::get(attrs.parsed); + const auto& param = nnvm::get(attrs.parsed); mshadow::Stream *s = ctx.get_stream(); TBlob out = outputs[0]; TBlob in = inputs[0]; @@ -537,34 +533,45 @@ void NumpyArgMinMaxCompute(const nnvm::NodeAttrs& attrs, small = NumpyReduceAxesShapeImpl(in.shape_, axes, true); mxnet::TShape src_shape, dst_shape; BroadcastReduceShapeCompact(in.shape_, small, &src_shape, &dst_shape); + const TBlob in_data = in.reshape(src_shape); + // request a work space + size_t workspace_size = broadcast::ReduceWorkspaceSize(s, dst_shape, req[0], src_shape); +#ifndef __CUDACC__ MSHADOW_TYPE_SWITCH_WITH_BOOL(in.type_flag_, DType, { // define OType typedef 
mxnet::op::mshadow_op::IndexedNum OType; - // request a work space - size_t workspace_size = sizeof(OType) * out.shape_.Size(); - Tensor workspace = - ctx.requested[0].get_space_typed(Shape1(workspace_size), s); - // set up intermediate output - TBlob intermediate = out; - intermediate.dptr_ = reinterpret_cast(workspace.dptr_); - // reshape the input and intermediate output tensor - const TBlob in_data = in.reshape(src_shape); - const TBlob intermediate_out_data = intermediate.reshape(dst_shape); // switch dim BROADCAST_NDIM_SWITCH(dst_shape.ndim(), NDim, { - size_t workspace_size = broadcast::ReduceWorkspaceSize( - s, intermediate_out_data.shape_, req[0], in_data.shape_); + constexpr size_t align_size = 1024; + const size_t aligned_first_workspace_size = ((workspace_size + align_size - 1) / align_size) + * align_size; + workspace_size = aligned_first_workspace_size + + sizeof(OType) * out.shape_.Size(); Tensor workspace = - ctx.requested[0].get_space_typed(Shape1(workspace_size), s); + ctx.requested[0].get_space_typed(Shape1(workspace_size), s); + // set up intermediate output + TBlob intermediate = out; + intermediate.dptr_ = reinterpret_cast(workspace.dptr_ + + aligned_first_workspace_size); + // reshape the input and intermediate output tensor + const TBlob intermediate_out_data = intermediate.reshape(dst_shape); NumpyArgMinMaxReduce(s, in_data, intermediate_out_data, workspace); + // parse the indices from the intermediate tensor back to the actual output tensor + using namespace mxnet_op; + Kernel::Launch( + s, out.shape_.Size(), outputs[0].dptr(), + static_cast(intermediate_out_data.dptr_)); }); - // parse the indices from the intermediate tensor back to the actual output tensor - using namespace mxnet_op; - Kernel::Launch( - s, out.shape_.Size(), outputs[0].dptr(), - static_cast(intermediate_out_data.dptr_)); }); +#else + Tensor workspace = + ctx.requested[0].get_space_typed(Shape1(workspace_size), s); + BROADCAST_NDIM_SWITCH(dst_shape.ndim(), NDim, { + broadcast::RTCReduce(ctx, outputs[0].reshape(dst_shape), req[0], workspace, in_data, + "red::argmax{}", NDim, "identity", true); + }); +#endif } template diff --git a/src/operator/tensor/reduce_rtc.cc b/src/operator/tensor/reduce_rtc.cc index 8b05df70f9ad..6fef0e002a46 100644 --- a/src/operator/tensor/reduce_rtc.cc +++ b/src/operator/tensor/reduce_rtc.cc @@ -49,39 +49,39 @@ struct reduce_kernel_params { const char reduce_function_code[] = R"code( #define FUNC OP(IType0::from(big[idx_big[u]])) +using AType = typename AccType::type; )code"; const char reduce_function_use_input_code[] = R"code( #define FUNC OP1(IType0::from(big[idx_big[u]]), \ OP2(IType1::from(lhs[idx_lhs[u]]), \ IType2::from(rhs[idx_rhs[u]]))) +using AType = typename AccType::type; )code"; const char reduce_function_index_code[] = R"code( -#define FUNC OP(IType0::from(big[idx_big[u]], k + u*by)) +#define FUNC AType(OP(IType0::from(big[idx_big[u]])), index) template -struct AccTypeIndex { - using type = AccTypeIndex; +struct AccIndex { index_t idx; - typename AccType::type num; + T num; - __device__ static inline type from(const T& val, const index_t i) { - return {AccType::from(val), i}; - } + __device__ inline AccIndex() {} + __device__ inline AccIndex(const T& val, const index_t idx) : num(val), idx(idx) {} - __device__ static inline index_t to(const type& val) { - return val.idx; + __device__ inline operator index_t() const volatile { + return idx; } - template - __device__ inline type& operator=(const AccTypeIndex& other) { + __device__ inline AccIndex& 
operator=(const AccIndex& other) { idx = other.idx; num = other.num; + return *this; } -} +}; -#define AccType AccTypeIndex +using AType = AccIndex::type>; )code"; const char reduce_kernel_code[] = R"code( @@ -98,22 +98,107 @@ struct reduce_kernel_params { index_t rhs_shape[util::MAX_DIM]; }; -__launch_bounds__(kRTCMaxThreadsPerBlock) -__global__ void reduce_kernel(const int N, const int M, const bool addto, - const InputType0* __restrict big, - const InputType1* __restrict lhs, - const InputType2* __restrict rhs, - OutputType0 *small, - const reduce_kernel_params params, - const int Mnext) { +inline __device__ AType reduce(const index_t idx, const int tidx, + const int tidy, const int N, + const index_t Mstart, const index_t Mend, + const InputType0* __restrict big, + const InputType1* __restrict lhs, + const InputType2* __restrict rhs, + const reduce_kernel_params& params) { extern __shared__ char shTileChar[]; using IType0 = AccType; using IType1 = AccType; using IType2 = AccType; using OType = AccType; - using MixedType = typename type_util::mixed_type::type; - using AType = typename AccType::type; AType* shTile = (AType*)(shTileChar); + const int bx = (do_transpose) ? blockDim.y : blockDim.x; + const int by = (do_transpose) ? blockDim.x : blockDim.y; + index_t coord[ndim]; + util::unravel(idx, params.small_shape, coord); + index_t idx_big0, idx_lhs0, idx_rhs0; + idx_big0 = util::ravel(coord, params.big_shape); + if (use_input) { + idx_lhs0 = util::ravel(coord, params.lhs_shape0); + idx_rhs0 = util::ravel(coord, params.rhs_shape0); + } + + AType val, residual; + REDUCER.SetInitValue(val, residual); + if (idx < N) { + for (index_t k = tidy + Mstart; k < Mend; k += by*UNROLL) { + index_t idx_big[UNROLL]; + index_t idx_lhs[UNROLL]; + index_t idx_rhs[UNROLL]; + #pragma unroll + for (int u=0;u < UNROLL;u++) { + idx_big[u] = idx_big0 + util::unravel_dot(k + u*by, params.rshape, + params.rstride); + if (use_input) { + idx_lhs[u] = idx_lhs0 + util::unravel_dot(k + u*by, params.lhs_shape, + params.lhs_stride); + idx_rhs[u] = idx_rhs0 + util::unravel_dot(k + u*by, params.rhs_shape, + params.rhs_stride); + } + } + AType tmp[UNROLL]; + #pragma unroll + for (int u=0;u < UNROLL;u++) { + if (k + u*by < Mend) { + const index_t index = k + u*by; + tmp[u] = FUNC; + } + } + #pragma unroll + for (int u=0;u < UNROLL;u++) { + if (k + u*by < Mend) REDUCER.Reduce(val, tmp[u], residual); + } + } + } + + // Shared memory block bx * by. Reduction is along by. Final result is in tidy=0 + if (by > 1) { + // Fix bx to avoid bank conflicts. Assumes warpSize number of banks + const int fbx = (do_transpose && ((bx & (warpSize - 1)) == 0)) ? 
(bx + 1) : bx; + const int it0 = tidx + tidy*fbx; + shTile[it0 * 2] = val; + shTile[it0 * 2 + 1] = residual; + __syncthreads(); + for (int t=1;t < by;t <<= 1) { + AType tmp, tmp_residual; + REDUCER.SetInitValue(tmp, tmp_residual); + if (tidy + t < by) { + tmp = shTile[(it0 + t*fbx) * 2]; + tmp_residual = shTile[(it0 + t*fbx) * 2 + 1]; + } + __syncthreads(); + REDUCER.Merge(shTile[it0 * 2], shTile[it0 * 2 + 1], tmp, tmp_residual); + __syncthreads(); + } + if (idx < N && tidy == 0) { + REDUCER.Finalize(shTile[tidx * 2], shTile[tidx * 2 + 1]); + return shTile[tidx * 2]; + } else { + return AType(); + } + } else { + if (idx < N) { + REDUCER.Finalize(val, residual); + return val; + } else { + return AType(); + } + } +} + +__launch_bounds__(kRTCMaxThreadsPerBlock) +__global__ void reduce_kernel_single(const int N, const int M, + const InputType0* __restrict big, + const InputType1* __restrict lhs, + const InputType2* __restrict rhs, + OutputType0 *small, + const reduce_kernel_params params, + const int Mnext) { + using OType = AccType; const int tid = threadIdx.x + threadIdx.y*blockDim.x; const int bx = (do_transpose) ? blockDim.y : blockDim.x; const int by = (do_transpose) ? blockDim.x : blockDim.y; @@ -124,117 +209,77 @@ __global__ void reduce_kernel(const int N, const int M, const bool addto, const index_t Mstart = (index_t)((int64)M*(int64)m0/(int64)Mnext); const index_t Mend = (index_t)((int64)M*(int64)(m0 + 1)/(int64)Mnext); for (index_t idx0 = blockIdx.x*bx; idx0 < N; idx0 += bx*gridDim.x) { - int idx = idx0 + tidx; - index_t coord[ndim]; - util::unravel(idx, params.small_shape, coord); - index_t idx_big0, idx_lhs0, idx_rhs0; - idx_big0 = util::ravel(coord, params.big_shape); - if (use_input) { - idx_lhs0 = util::ravel(coord, params.lhs_shape0); - idx_rhs0 = util::ravel(coord, params.rhs_shape0); - } - - AType val, residual; - REDUCER.SetInitValue(val, residual); - if (idx < N) { - for (index_t k = tidy + Mstart; k < Mend; k += by*UNROLL) { - index_t idx_big[UNROLL]; - index_t idx_lhs[UNROLL]; - index_t idx_rhs[UNROLL]; - #pragma unroll - for (int u=0;u < UNROLL;u++) { - idx_big[u] = idx_big0 + util::unravel_dot(k + u*by, params.rshape, - params.rstride); - if (use_input) { - idx_lhs[u] = idx_lhs0 + util::unravel_dot(k + u*by, params.lhs_shape, - params.lhs_stride); - idx_rhs[u] = idx_rhs0 + util::unravel_dot(k + u*by, params.rhs_shape, - params.rhs_stride); - } - } - AType tmp[UNROLL]; - #pragma unroll - for (int u=0;u < UNROLL;u++) { - if (k + u*by < Mend) { - tmp[u] = FUNC; - } - } - #pragma unroll - for (int u=0;u < UNROLL;u++) { - if (k + u*by < Mend) REDUCER.Reduce(val, tmp[u], residual); - } + const index_t idx = idx0 + tidx; + AType val = reduce(idx, tidx, tidy, N, Mstart, Mend, big, lhs, rhs, params); + if (idx < N && (by == 1 || tidy == 0)) { + if (req == OpReqType::kAddTo) { + small[idx + m0 * N] = OType::to(op::add(OType::from(small[idx + m0 * N]), + static_cast(val))); + } else { + small[idx + m0 * N] = OType::to(val); } } + } + } +} - // Shared memory block bx * by. Reduction is along by. Final result is in tidy=0 - if (by > 1) { - // Fix bx to avoid bank conflicts. Assumes warpSize number of banks - const int fbx = (do_transpose && ((bx & (warpSize - 1)) == 0)) ? 
(bx + 1) : bx; - const int it0 = tidx + tidy*fbx; - shTile[it0 * 2] = val; - shTile[it0 * 2 + 1] = residual; - __syncthreads(); - for (int t=1;t < by;t <<= 1) { - AType tmp, tmp_residual; - REDUCER.SetInitValue(tmp, tmp_residual); - if (tidy + t < by) { - tmp = shTile[(it0 + t*fbx) * 2]; - tmp_residual = shTile[(it0 + t*fbx) * 2 + 1]; - } - __syncthreads(); - REDUCER.Merge(shTile[it0 * 2], shTile[it0 * 2 + 1], tmp, tmp_residual); - __syncthreads(); - } - if (idx < N && tidy == 0) { - REDUCER.Finalize(shTile[tidx * 2], shTile[tidx * 2 + 1]); - if (addto) { - small[idx + m0 * N] = OType::to(op::add(OType::from(small[idx + m0 * N]), - shTile[tidx * 2])); - } else { - small[idx + m0 * N] = OType::to(shTile[tidx * 2]); - } - } - } else { - if (idx < N) { - REDUCER.Finalize(val, residual); - if (addto) { - small[idx + m0 * N] = OType::to(op::add(OType::from(small[idx + m0 * N]), - val)); - } else { - small[idx + m0 * N] = OType::to(val); - } - } +__launch_bounds__(kRTCMaxThreadsPerBlock) +__global__ void reduce_kernel_multi(const int N, const int M, + const InputType0* __restrict big, + const InputType1* __restrict lhs, + const InputType2* __restrict rhs, + AType *small, + const reduce_kernel_params params, + const int Mnext) { + const int tid = threadIdx.x + threadIdx.y*blockDim.x; + const int bx = (do_transpose) ? blockDim.y : blockDim.x; + const int by = (do_transpose) ? blockDim.x : blockDim.y; + const int tidx = (do_transpose) ? tid / by : threadIdx.x; + const int tidy = (do_transpose) ? tid % by : threadIdx.y; + for (int m0 = blockIdx.y; m0 < Mnext; m0 += gridDim.y) { + // This TB handles M range [Mstart, ...., Mend - 1] + const index_t Mstart = (index_t)((int64)M*(int64)m0/(int64)Mnext); + const index_t Mend = (index_t)((int64)M*(int64)(m0 + 1)/(int64)Mnext); + for (index_t idx0 = blockIdx.x*bx; idx0 < N; idx0 += bx*gridDim.x) { + const index_t idx = idx0 + tidx; + AType val = reduce(idx, tidx, tidy, N, Mstart, Mend, big, lhs, rhs, params); + if (idx < N && (by == 1 || tidy == 0)) { + small[idx + m0 * N] = val; } } } } + )code"; const char reduce_lines_kernel_code[] = R"code( +using MixedType = typename type_util::mixed_type::type; +using AType = typename AccType::type; + __launch_bounds__(kRTCMaxThreadsPerBlock) __global__ void reduce_lines_kernel(const index_t N, const index_t M, const index_t small_in_stride, - const OutputType0* __restrict small_in, + const AType* __restrict small_in, OutputType0 *small_out) { using OType = AccType; for (index_t idx = threadIdx.x + blockIdx.x*blockDim.x; idx < N; idx += blockDim.x*gridDim.x) { - typename OType::type val, residual; + AType val, residual; REDUCER.SetInitValue(val, residual); for (int k = 0; k < M; k++) { REDUCER.Reduce(val, - OType::from(reinterpret_cast(small_in)[idx + k*small_in_stride]), + small_in[idx + k*small_in_stride], residual); } if (idx < N) { REDUCER.Finalize(val, residual); if (req == OpReqType::kAddTo) { - small_out[idx] = OType::to(op::add(OType::from(small_out[idx]), val)); + small_out[idx] = OType::to(op::add(OType::from(small_out[idx]), + static_cast(val))); } else { small_out[idx] = OType::to(val); } } - } } )code"; @@ -247,11 +292,9 @@ void RTCReduceImpl(Stream *s, const TBlob& small, const bool addto, const bool use_index = false) { using namespace common::cuda::rtc; void* small_dptr = small.dptr_; - bool first_kernel_addto = addto; if (config.Mnext > 1) { // small_dptr[] is N*Mnext*sizeof(DType) bytes small_dptr = workspace.dptr_; - first_kernel_addto = false; // Check that the workspace is contigiuous 
CHECK_EQ(workspace.CheckContiguous(), true); // Check that we have enough storage @@ -310,7 +353,6 @@ void RTCReduceImpl(Stream *s, const TBlob& small, const bool addto, std::vector args; args.emplace_back(&config.N); args.emplace_back(&config.M); - args.emplace_back(&first_kernel_addto); args.emplace_back(&big.dptr_); if (lhs != nullptr) { args.emplace_back(&(lhs->dptr_)); @@ -326,8 +368,9 @@ void RTCReduceImpl(Stream *s, const TBlob& small, const bool addto, const auto &function_code = (lhs == nullptr) ? (use_index ? reduce_function_index_code : reduce_function_code) : reduce_function_use_input_code; + const auto& kernel_name = (config.Mnext > 1) ? "reduce_kernel_multi" : "reduce_kernel_single"; auto reduce_kernel_func = get_function(code + function_code, - "reduce_kernel", + kernel_name, reduce_kernel_code, dev_id); launch(reduce_kernel_func, config.kernel_1.gridDim, @@ -377,9 +420,11 @@ __global__ void reduce_kernel_M1(const int N, using IType1 = AccType; using IType2 = AccType; using OType = AccType; - for (int idx = threadIdx.x + blockIdx.x*blockDim.x; idx < N; idx += blockDim.x*gridDim.x) { + for (index_t index = threadIdx.x + blockIdx.x*blockDim.x; + index < N; + index += blockDim.x*gridDim.x) { index_t coord[ndim]; - util::unravel(idx, params.small_shape, coord); + util::unravel(index, params.small_shape, coord); index_t idx_big[1]; idx_big[0] = util::ravel(coord, params.big_shape); index_t idx_lhs[1], idx_rhs[1]; @@ -387,16 +432,17 @@ __global__ void reduce_kernel_M1(const int N, idx_lhs[0] = util::ravel(coord, params.lhs_shape); idx_rhs[0] = util::ravel(coord, params.rhs_shape); } - typename OType::type val, residual; + AType val, residual; REDUCER.SetInitValue(val, residual); const int u = 0; REDUCER.Reduce(val, FUNC, residual); REDUCER.Finalize(val, residual); if (req == OpReqType::kAddTo) { - const auto temp = op::add(val, OType::from(small[idx])); - small[idx] = OType::to(temp); + const auto temp = op::add(static_cast(val), + OType::from(small[index])); + small[index] = OType::to(temp); } else { - small[idx] = OType::to(val); + small[index] = OType::to(static_cast(val)); } } } From fe5656fcd0a52df249a7d62be5d308ac8d185920 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Thu, 29 Oct 2020 09:30:56 -0700 Subject: [PATCH 03/15] Fixes --- src/common/cuda/rtc.cc | 1 + src/common/cuda/rtc/reducer-inl.h | 90 ++++++++++-- src/common/cuda/rtc/special_functions-inl.h | 5 - src/common/cuda/rtc/util-inl.h | 138 ++++++++++++++++++ src/operator/mshadow_op.h | 8 +- src/operator/numpy/np_broadcast_reduce_op.cc | 85 +++++++++++ src/operator/numpy/np_broadcast_reduce_op.h | 23 +-- .../numpy/np_broadcast_reduce_op_index.cu | 4 +- src/operator/tensor/broadcast_reduce-inl.h | 3 +- tests/python/unittest/test_numpy_op.py | 4 +- 10 files changed, 322 insertions(+), 39 deletions(-) create mode 100644 src/operator/numpy/np_broadcast_reduce_op.cc diff --git a/src/common/cuda/rtc.cc b/src/common/cuda/rtc.cc index 8f3b3391f5e4..486f9afe53b1 100644 --- a/src/common/cuda/rtc.cc +++ b/src/common/cuda/rtc.cc @@ -116,6 +116,7 @@ CUfunction get_function(const std::string ¶meters, std::string(fp16_support_string) + "\n" + type_support_string + "\n" + util_string + "\n" + + limits + "\n" + special_functions_definitions + '\n' + vectorization_support_string + "\n" + function_definitions_util + "\n" + diff --git a/src/common/cuda/rtc/reducer-inl.h b/src/common/cuda/rtc/reducer-inl.h index d7750265d1cf..ac6bb22442b0 100644 --- a/src/common/cuda/rtc/reducer-inl.h +++ b/src/common/cuda/rtc/reducer-inl.h @@ -27,6 
+27,7 @@ namespace common { namespace cuda { namespace rtc { + const char reducer[] = R"code( namespace red { @@ -132,7 +133,7 @@ struct maximum { */ template __device__ inline static void SetInitValue(DType &initv) { - initv = -2*DBL_MAX; + initv = limits::NegInfValue(); } /*! *\brief set the initial value during reduction @@ -180,7 +181,7 @@ struct minimum { */ template __device__ inline static void SetInitValue(DType &initv) { - initv = 2*DBL_MAX; + initv = limits::PosInfValue(); } /*! *\brief set the initial value during reduction @@ -499,7 +500,7 @@ struct argmax { /*! \brief do reduction into dst */ template __device__ inline static void Reduce(volatile AType& dst, volatile DType src) { - if (dst.num < src.num) { + if (dst.num < src.num || (dst.num == src.num && dst.idx > src.idx)) { dst.num = src.num; dst.idx = src.idx; } @@ -507,27 +508,27 @@ struct argmax { /*! \brief do stable reduction into dst */ template __device__ inline static void Reduce(volatile AType& dst, volatile DType src, - volatile DType& residual) { - if (dst.num < src.num) { + volatile DType&) { + if (dst.num < src.num || (dst.num == src.num && dst.idx > src.idx)) { dst.num = src.num; dst.idx = src.idx; } } /*! \brief combine the results of two reducers */ template - __device__ inline static void Merge(volatile DType& dst_val, volatile DType& src_val) { - if (dst_val.num < src_val.num) { - dst_val.num = src_val.num; - dst_val.idx = src_val.idx; + __device__ inline static void Merge(volatile DType& dst, volatile DType& src) { + if (dst.num < src.num || (dst.num == src.num && dst.idx > src.idx)) { + dst.num = src.num; + dst.idx = src.idx; } } /*! \brief combine the results of two reducers */ template - __device__ inline static void Merge(volatile DType& dst_val, volatile DType& dst_residual, - volatile DType& src_val, volatile DType& src_residual) { - if (dst_val.num < src_val.num) { - dst_val.num = src_val.num; - dst_val.idx = src_val.idx; + __device__ inline static void Merge(volatile DType& dst, volatile DType&, + volatile DType& src, volatile DType&) { + if (dst.num < src.num || (dst.num == src.num && dst.idx > src.idx)) { + dst.num = src.num; + dst.idx = src.idx; } } /*! \brief finalize reduction */ @@ -541,14 +542,71 @@ struct argmax { */ template __device__ inline static void SetInitValue(DType &initv) { - initv.num = -2 * DBL_MAX; + initv.num = limits::NegInfValue(); } /*! *\brief set the initial value during reduction */ template __device__ inline static void SetInitValue(DType &initv, DType &) { - initv.num = -2 * DBL_MAX; + initv.num = limits::NegInfValue(); + } +}; + +struct argmin { + /*! \brief do reduction into dst */ + template + __device__ inline static void Reduce(volatile AType& dst, volatile DType src) { + if (dst.num > src.num || (dst.num == src.num && dst.idx > src.idx)) { + dst.num = src.num; + dst.idx = src.idx; + } + } + /*! \brief do stable reduction into dst */ + template + __device__ inline static void Reduce(volatile AType& dst, volatile DType src, + volatile DType& residual) { + if (dst.num > src.num || (dst.num == src.num && dst.idx > src.idx)) { + dst.num = src.num; + dst.idx = src.idx; + } + } + /*! \brief combine the results of two reducers */ + template + __device__ inline static void Merge(volatile DType& dst, volatile DType& src) { + if (dst.num > src.num || (dst.num == src.num && dst.idx > src.idx)) { + dst.num = src.num; + dst.idx = src.idx; + } + } + /*! 
\brief combine the results of two reducers */ + template + __device__ inline static void Merge(volatile DType& dst, volatile DType&, + volatile DType& src, volatile DType&) { + if (dst.num > src.num || (dst.num == src.num && dst.idx > src.idx)) { + dst.num = src.num; + dst.idx = src.idx; + } + } + /*! \brief finalize reduction */ + template + __device__ inline static void Finalize(volatile DType& dst) {} + /*! \brief finalize reduction */ + template + __device__ inline static void Finalize(volatile DType& dst, volatile DType& residual) {} + /*! + *\brief set the initial value during reduction + */ + template + __device__ inline static void SetInitValue(DType &initv) { + initv.num = limits::PosInfValue(); + } + /*! + *\brief set the initial value during reduction + */ + template + __device__ inline static void SetInitValue(DType &initv, DType &residual) { + initv.num = limits::PosInfValue(); } }; } // namespace red diff --git a/src/common/cuda/rtc/special_functions-inl.h b/src/common/cuda/rtc/special_functions-inl.h index d64afb51e2c1..7110e4737d0f 100644 --- a/src/common/cuda/rtc/special_functions-inl.h +++ b/src/common/cuda/rtc/special_functions-inl.h @@ -50,11 +50,6 @@ namespace rtc { // Direct inquiries to 30 Frost Street, Cambridge, MA 02140 // const char special_functions_definitions[] = R"code( -constexpr double DBL_MAX = 1.7976931348623157081e+308; -constexpr float FLT_MAX = 3.4028234663852885981e+38; -#define inf ((float)1e50) -#define nan (inf - inf) - namespace op { namespace special_functions { diff --git a/src/common/cuda/rtc/util-inl.h b/src/common/cuda/rtc/util-inl.h index 372390fdc117..adae4c613e6c 100644 --- a/src/common/cuda/rtc/util-inl.h +++ b/src/common/cuda/rtc/util-inl.h @@ -379,6 +379,144 @@ __device__ inline bool isnan(volatile const float16 &val) { } // namespace util )code"; + +const char limits[] = R"code( +constexpr double DBL_MAX = 1.7976931348623157081e+308; +constexpr float FLT_MAX = 3.4028234663852885981e+38; +#define inf ((float)1e50) +#define nan (inf - inf) + +namespace limits { + +template +__device__ inline DType MinValue(void); + +template<> +__device__ inline float MinValue(void) { + return -FLT_MAX; +} +/*! \brief minimum value of double */ +template<> +__device__ inline double MinValue(void) { + return -DBL_MAX; +} +/*! \brief minimum value of uint8 */ +template<> +__device__ inline uint8 MinValue(void) { + return 0; +} +/*! \brief minimum value of int8_t */ +template<> +__device__ inline int8 MinValue(void) { + return -128; +} +/*! \brief minimum value of int32 */ +template<> +__device__ inline int32 MinValue(void) { + return -2147483648; +} +/*! \brief minimum value of int64_t */ +template<> +__device__ inline int64 MinValue(void) { + return -9223372036854775808LL; +} +/*! \brief minimum value of bool */ +template<> +__device__ inline bool MinValue(void) { + return false; +} +/*! \brief minimum value of bool_t */ +template<> +__device__ inline bool_t MinValue(void) { + return MinValue(); +} + +/*! + * \brief negative infinity of certain types + * \tparam DType data type + */ +template +__device__ inline DType NegInfValue(void) { + return MinValue(); +} +/*! \brief negative infinity value of float */ +template<> +__device__ inline float NegInfValue(void) { + return -inf; +} +/*! \brief negative infinity value of double */ +template<> +__device__ inline double NegInfValue(void) { + return -inf; +} + +/*! + * \brief maximum value of certain types + * \tparam DType data type + */ +template +__device__ inline DType MaxValue(void); +/*! 
\brief maximum value of float */ +template<> +__device__ inline float MaxValue(void) { + return FLT_MAX; +} +/*! \brief maximum value of double */ +template<> +__device__ inline double MaxValue(void) { + return DBL_MAX; +} +/*! \brief maximum value of uint8 */ +template<> +__device__ inline uint8 MaxValue(void) { + return 255; +} +/*! \brief maximum value of int8 */ +template<> +__device__ inline int8 MaxValue(void) { + return 127; +} +/*! \brief maximum value of int32 */ +template<> +__device__ inline int32 MaxValue(void) { + return 2147483647; +} +/*! \brief maximum value of int64 */ +template<> +__device__ inline int64 MaxValue(void) { + return 9223372036854775807LL; +} +/*! \brief maximum value of bool */ +template<> +__device__ inline bool MaxValue(void) { + return true; +} +/*! \brief maximum value of bool_t */ +template<> +__device__ inline bool_t MaxValue(void) { + return MaxValue(); +} +/*! + * \brief positive infinity of certain types + * \tparam DType data type + */ +template +__device__ inline DType PosInfValue(void) { + return MaxValue(); +} +/*! \brief positive infinity value of float */ +template<> +__device__ inline float PosInfValue(void) { + return inf; +} +/*! \brief positive infinity value of double */ +template<> +__device__ inline double PosInfValue(void) { + return inf; +} + +} // namespace limits +)code"; } // namespace rtc } // namespace cuda } // namespace common diff --git a/src/operator/mshadow_op.h b/src/operator/mshadow_op.h index 4949a7e47267..ff2fd20306ae 100644 --- a/src/operator/mshadow_op.h +++ b/src/operator/mshadow_op.h @@ -1569,7 +1569,7 @@ struct argmax { /*! \brief do reduction into dst */ template MSHADOW_XINLINE static void Reduce(volatile AType& dst, volatile DType src) { // NOLINT(*) - if (dst.num < src.num) { + if (dst.num < src.num || (dst.num == src.num && dst.idx > src.idx)) { dst.num = src.num; dst.idx = src.idx; } @@ -1577,7 +1577,7 @@ struct argmax { /*! \brief do stable reduction into dst */ template MSHADOW_XINLINE static void Reduce(volatile AType& dst, volatile DType src, volatile DType& residual) { // NOLINT(*) - if (dst.num < src.num) { + if (dst.num < src.num || (dst.num == src.num && dst.idx > src.idx)) { dst.num = src.num; dst.idx = src.idx; } @@ -1585,7 +1585,7 @@ struct argmax { /*! \brief combine the results of two reducers */ template MSHADOW_XINLINE static void Merge(volatile DType& dst_val, volatile DType& src_val) { // NOLINT(*) - if (dst_val.num < src_val.num) { + if (dst_val.num < src_val.num || (dst_val.num == src_val.num && dst_val.idx > src_val.idx)) { dst_val.num = src_val.num; dst_val.idx = src_val.idx; } @@ -1593,7 +1593,7 @@ struct argmax { /*! \brief combine the results of two reducers */ template MSHADOW_XINLINE static void Merge(volatile DType& dst_val, volatile DType& dst_residual, volatile DType& src_val, volatile DType& src_residual) { // NOLINT(*) - if (dst_val.num < src_val.num) { + if (dst_val.num < src_val.num || (dst_val.num == src_val.num && dst_val.idx > src_val.idx)) { dst_val.num = src_val.num; dst_val.idx = src_val.idx; } diff --git a/src/operator/numpy/np_broadcast_reduce_op.cc b/src/operator/numpy/np_broadcast_reduce_op.cc new file mode 100644 index 000000000000..4b64a1a29169 --- /dev/null +++ b/src/operator/numpy/np_broadcast_reduce_op.cc @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2020 by Contributors + * \file np_broadcast_reduce_op.cc + * \brief Function definitions of NumPy-compatible + * broadcast and reduce operators + */ + +#include "np_broadcast_reduce_op.h" + +namespace mxnet { +namespace op { +#if MXNET_USE_CUDA + +void NumpyArgMinMaxRTCCompute::operator()(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mshadow; + using namespace mshadow::expr; + if (req[0] == kNullOp) return; + // parse param + const auto& param = nnvm::get(attrs.parsed); + mshadow::Stream *s = ctx.get_stream(); + TBlob out = outputs[0]; + TBlob in = inputs[0]; + // do some shape checks + if (in.shape_.ndim() != 0) { + if (param.axis.has_value()) { + // cannot do argmax in an empty dimension + int axis = param.axis.value(); + axis = CheckAxis(axis, in.shape_.ndim()); + CHECK_NE(in.shape_[axis], 0) + << "searching input tensor of shape " << inputs[0].shape_ + << " along axis = " << axis << " of zero dim-size is not allowed"; + } else { + // cannot do argmax on an empty array + CHECK_NE(in.shape_.Size(), 0U) << "attempt to search an empty sequence"; + } + } + if (in.shape_.Size() == 0U) return; // zero-size tensor + // prepare shape + dmlc::optional> axes; + if (param.axis.has_value()) { + mxnet::Tuple t({param.axis.value()}); + axes = dmlc::optional>(t); + } + TShape small; + small = NumpyReduceAxesShapeImpl(in.shape_, axes, true); + mxnet::TShape src_shape, dst_shape; + BroadcastReduceShapeCompact(in.shape_, small, &src_shape, &dst_shape); + const TBlob in_data = in.reshape(src_shape); + // request a work space + size_t workspace_size = broadcast::ReduceWorkspaceSize(s, dst_shape, req[0], src_shape); + Tensor workspace = + ctx.requested[0].get_space_typed(Shape1(workspace_size), s); + BROADCAST_NDIM_SWITCH(dst_shape.ndim(), NDim, { + broadcast::RTCReduce(ctx, outputs[0].reshape(dst_shape), req[0], workspace, in_data, + reducer, NDim, "identity", true); + }); +} + +#endif // MXNET_USE_CUDA + +} // namespace op +} // namespace mxnet diff --git a/src/operator/numpy/np_broadcast_reduce_op.h b/src/operator/numpy/np_broadcast_reduce_op.h index 4562963b16e4..fc050a506aa4 100644 --- a/src/operator/numpy/np_broadcast_reduce_op.h +++ b/src/operator/numpy/np_broadcast_reduce_op.h @@ -536,7 +536,6 @@ void NumpyArgMinMaxCompute(const nnvm::NodeAttrs& attrs, const TBlob in_data = in.reshape(src_shape); // request a work space size_t workspace_size = broadcast::ReduceWorkspaceSize(s, dst_shape, req[0], src_shape); -#ifndef __CUDACC__ MSHADOW_TYPE_SWITCH_WITH_BOOL(in.type_flag_, DType, { // define OType typedef mxnet::op::mshadow_op::IndexedNum OType; @@ -564,16 +563,22 @@ void NumpyArgMinMaxCompute(const nnvm::NodeAttrs& attrs, static_cast(intermediate_out_data.dptr_)); }); }); -#else - Tensor workspace = - ctx.requested[0].get_space_typed(Shape1(workspace_size), s); - 
BROADCAST_NDIM_SWITCH(dst_shape.ndim(), NDim, { - broadcast::RTCReduce(ctx, outputs[0].reshape(dst_shape), req[0], workspace, in_data, - "red::argmax{}", NDim, "identity", true); - }); -#endif } +#if MXNET_USE_CUDA + +struct NumpyArgMinMaxRTCCompute { + std::string reducer; + + void operator()(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs); +}; + +#endif + template inline void NumpyReduceAxesBackwardUseNone(const nnvm::NodeAttrs& attrs, const OpContext& ctx, diff --git a/src/operator/numpy/np_broadcast_reduce_op_index.cu b/src/operator/numpy/np_broadcast_reduce_op_index.cu index 892d04679422..eb6086c9d7fe 100644 --- a/src/operator/numpy/np_broadcast_reduce_op_index.cu +++ b/src/operator/numpy/np_broadcast_reduce_op_index.cu @@ -28,10 +28,10 @@ namespace mxnet { namespace op { NNVM_REGISTER_OP(_npi_argmax) -.set_attr("FCompute", NumpyArgMinMaxCompute); +.set_attr("FCompute", NumpyArgMinMaxRTCCompute{"red::argmax{}"}); NNVM_REGISTER_OP(_npi_argmin) -.set_attr("FCompute", NumpyArgMinMaxCompute); +.set_attr("FCompute", NumpyArgMinMaxRTCCompute{"red::argmin{}"}); } // namespace op } // namespace mxnet diff --git a/src/operator/tensor/broadcast_reduce-inl.h b/src/operator/tensor/broadcast_reduce-inl.h index efcae76458ad..47b9cb5dce25 100644 --- a/src/operator/tensor/broadcast_reduce-inl.h +++ b/src/operator/tensor/broadcast_reduce-inl.h @@ -507,7 +507,8 @@ struct ReduceImplConfig { lhs_shape(small.ndim(), 1), lhs_stride(small.ndim(), 1), rhs_shape(small.ndim(), 1), rhs_stride(small.ndim(), 1) { // The largest reduction type currently is (index_t, double) struct - constexpr size_t max_type_size = sizeof(double) + sizeof(index_t); + // aligned to 16B + constexpr size_t max_type_size = 2 * sizeof(double); constexpr int maxLoopPerTB = 64; int ndim = small.ndim(); diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index 091c3e4e1efd..50a2a06657ef 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -4311,7 +4311,7 @@ def test_np_argmin_argmax(): ((3, 5, 7), 2, False), ((3, 5, 7, 9, 11), -3, False), ] - dtypes = ['float16', 'float32', 'float64'] + dtypes = ['float16', 'float32', 'float64', 'bool', 'int32'] ops = ['argmin', 'argmax'] class TestArgExtreme(HybridBlock): @@ -4326,7 +4326,7 @@ def hybrid_forward(self, F, x): for op_name in ops: for shape, axis, throw_exception in workloads: for dtype in dtypes: - a = np.random.uniform(size=shape, dtype=dtype) + a = np.random.uniform(low=0, high=100, size=shape).astype(dtype) if throw_exception: # Cannot use assert_exception because sometimes the main thread # proceeds to `assert False` before the exception is thrown From f44314a68bf9d9c29267f56fd5e0272817ce6728 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Thu, 29 Oct 2020 13:54:10 -0700 Subject: [PATCH 04/15] Fix lint --- src/operator/tensor/broadcast_reduce_op.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/src/operator/tensor/broadcast_reduce_op.cc b/src/operator/tensor/broadcast_reduce_op.cc index 92c625d345c5..30d4d0f1f213 100644 --- a/src/operator/tensor/broadcast_reduce_op.cc +++ b/src/operator/tensor/broadcast_reduce_op.cc @@ -21,7 +21,6 @@ #include "../numpy/np_broadcast_reduce_op.h" #include "elemwise_binary_scalar_op.h" #include "mxnet/tuple.h" -#include namespace mxnet { namespace op { From 9b561e32c2c23a2ab06975a3839e048ea92ce6a9 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Thu, 29 Oct 2020 14:38:17 
-0700 Subject: [PATCH 05/15] Fix lint for real --- src/operator/tensor/broadcast_reduce_op.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/src/operator/tensor/broadcast_reduce_op.cc b/src/operator/tensor/broadcast_reduce_op.cc index 30d4d0f1f213..9348ad1bd2bf 100644 --- a/src/operator/tensor/broadcast_reduce_op.cc +++ b/src/operator/tensor/broadcast_reduce_op.cc @@ -18,6 +18,7 @@ */ #include "broadcast_reduce_op.h" +#include #include "../numpy/np_broadcast_reduce_op.h" #include "elemwise_binary_scalar_op.h" #include "mxnet/tuple.h" From 9acc3f55abcf92738a242f2e4ddb7377258aff22 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Fri, 30 Oct 2020 12:51:45 -0700 Subject: [PATCH 06/15] Cleaning and code reuse --- src/operator/nn/moments-inl.h | 4 +- src/operator/numpy/linalg/np_norm-inl.h | 19 +- src/operator/numpy/np_broadcast_reduce_op.h | 246 +++++++++----------- src/operator/numpy/np_kron-inl.h | 40 ++-- src/operator/numpy/np_tensordot_op-inl.h | 33 ++- src/operator/numpy/np_where_op-inl.h | 59 ++--- src/operator/random/pdf_op.h | 51 ++-- src/operator/tensor/broadcast_reduce_op.cc | 53 ++--- src/operator/tensor/broadcast_reduce_op.h | 34 ++- src/operator/tensor/matrix_op-inl.h | 6 +- 10 files changed, 256 insertions(+), 289 deletions(-) diff --git a/src/operator/nn/moments-inl.h b/src/operator/nn/moments-inl.h index a4398458eb7c..b8915dce2397 100644 --- a/src/operator/nn/moments-inl.h +++ b/src/operator/nn/moments-inl.h @@ -129,7 +129,7 @@ inline void MomentsForwardImpl(const OpContext& ctx, #if !defined(__CUDACC__) ReduceAxesComputeImpl(ctx, {data}, {req[0]}, {mean}, small); #else - ReduceAxesRTCComputeImpl(ctx, {data}, {req[0]}, {mean}, small, "red::sum{}", true); + ReduceAxesRTCComputeImpl(ctx, {data}, {req[0]}, {mean}, small, "red::sum{}", nullptr, true); #endif MSHADOW_TYPE_SWITCH(data.type_flag_, DType, { Shape<6> data_shape, mean_shape; @@ -146,7 +146,7 @@ inline void MomentsForwardImpl(const OpContext& ctx, ctx, {TBlob(temp_data).reshape(data.shape_)}, {kWriteTo}, {var}, small); #else ReduceAxesRTCComputeImpl(ctx, {TBlob(temp_data).reshape(data.shape_)}, - {kWriteTo}, {var}, small, "red::sum{}", true); + {kWriteTo}, {var}, small, "red::sum{}", nullptr, true); #endif }); } diff --git a/src/operator/numpy/linalg/np_norm-inl.h b/src/operator/numpy/linalg/np_norm-inl.h index b8ab439106fe..29a33179ab7f 100644 --- a/src/operator/numpy/linalg/np_norm-inl.h +++ b/src/operator/numpy/linalg/np_norm-inl.h @@ -299,7 +299,7 @@ void NumpyLpNormCompute(const nnvm::NodeAttrs& attrs, #else ReduceAxesRTCComputeImpl( ctx, inputs, req, outputs, small, "red::nrmlp{" + std::to_string(param.ord) + "}", - false, "abs"); + nullptr, false, "abs"); #endif } } @@ -442,7 +442,7 @@ void NumpyMatrixNormCompute(const nnvm::NodeAttrs& attrs, ctx, inputs, req, outputs, reduced_shape); #else ReduceAxesRTCComputeImpl( - ctx, inputs, req, outputs, reduced_shape, "red::nrm2{}", false, "identity"); + ctx, inputs, req, outputs, reduced_shape, "red::nrm2{}", nullptr, false, "identity"); #endif return; } @@ -466,13 +466,14 @@ void NumpyMatrixNormCompute(const nnvm::NodeAttrs& attrs, ctx, sum_output, req, outputs, reduced_shape); } #else - ReduceAxesRTCComputeImpl(ctx, inputs, req, sum_output, sum_shape, "red::sum{}", false, "abs"); + ReduceAxesRTCComputeImpl(ctx, inputs, req, sum_output, sum_shape, + "red::sum{}", nullptr, false, "abs"); if (param.ord > 0) { ReduceAxesRTCComputeImpl(ctx, sum_output, req, outputs, reduced_shape, - "red::maximum{}", false); + "red::maximum{}", nullptr, false); } else { 
ReduceAxesRTCComputeImpl(ctx, sum_output, req, outputs, reduced_shape, - "red::minimum{}", false); + "red::minimum{}", nullptr, false); } #endif // MXNET_USE_CUDA }); @@ -563,14 +564,14 @@ void NumpyMatrixNormCompute(const nnvm::NodeAttrs& attrs, } #else if (param.flag == 2) { // nuclear norm - ReduceAxesRTCComputeImpl(ctx, eigen, req, outputs, reduced_shape, "red::sum{}", false); + ReduceAxesRTCComputeImpl(ctx, eigen, req, outputs, reduced_shape, "red::sum{}", nullptr, false); } else { if (ord == 2) { ReduceAxesRTCComputeImpl(ctx, eigen, req, outputs, reduced_shape, - "red::maximum{}", false, "abs"); + "red::maximum{}", nullptr, false, "abs"); } else if (ord == -2) { ReduceAxesRTCComputeImpl(ctx, eigen, req, outputs, reduced_shape, - "red::minimum{}", false, "abs"); + "red::minimum{}", nullptr, false, "abs"); } } #endif @@ -814,7 +815,7 @@ void NumpyNormComputeForward(const nnvm::NodeAttrs& attrs, ctx, flat_inputs, req, flat_outputs, TShape(1, 1)); #else ReduceAxesRTCComputeImpl( - ctx, flat_inputs, req, flat_outputs, TShape(1, 1), "red::nrm2{}", false, "identity"); + ctx, flat_inputs, req, flat_outputs, TShape(1, 1), "red::nrm2{}", nullptr, false, "identity"); #endif return; } diff --git a/src/operator/numpy/np_broadcast_reduce_op.h b/src/operator/numpy/np_broadcast_reduce_op.h index fc050a506aa4..df58ef2a1a27 100644 --- a/src/operator/numpy/np_broadcast_reduce_op.h +++ b/src/operator/numpy/np_broadcast_reduce_op.h @@ -675,36 +675,6 @@ struct NumpyMomentsParam : public dmlc::Parameter { } }; -template -void ReduceAxesComputeWithWorkspaceImpl(const OpContext& ctx, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs, - const mshadow::Tensor& workspace, - const mxnet::TShape& src_shape, - const mxnet::TShape& dst_shape, - const int ddof = 0) { - using namespace mshadow; - using namespace mshadow::expr; - - Stream *s = ctx.get_stream(); - MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, { - MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, OType, { - const TBlob in_data = inputs[0].reshape(src_shape); - const TBlob out_data = outputs[0].reshape(dst_shape); - BROADCAST_NDIM_SWITCH(dst_shape.ndim(), NDim, { - broadcast::Reduce( - s, out_data, req[0], workspace, in_data); - if (normalize) { - auto out = out_data.FlatTo2D(s); - out /= scalar(src_shape.Size()/dst_shape.Size() - ddof); - } - }); - }); - }); -} - struct NumpyWeightedAverageParam : public dmlc::Parameter { dmlc::optional> axis; bool returned; @@ -883,13 +853,6 @@ struct avg_grad_w_1D_kernel { } }; -// Windows has issues with #ifdefs inside MSHADOW_TYPE_SWITCH -#ifndef __CUDACC__ -#define NP_BROADCAST_REDUCE_OP_BROADCAST(OP) BinaryBroadcastCompute -#else -#define NP_BROADCAST_REDUCE_OP_BROADCAST(OP) BinaryBroadcastRTCCompute {#OP} -#endif - template void NumpyWeightedAverageComputeImpl(const nnvm::NodeAttrs& attrs, const OpContext& ctx, @@ -926,6 +889,9 @@ void NumpyWeightedAverageComputeImpl(const nnvm::NodeAttrs& attrs, weights = weights.reshape(new_w_shape); small2 = TShape(new_w_shape.ndim(), 1); } + TBlob wa; + TBlob sum_of_wa; + Tensor workspace; MSHADOW_TYPE_SWITCH(data.type_flag_, DType, { // Get temp space size_t temp_data_size = data.shape_.Size() * sizeof(DType); @@ -938,56 +904,49 @@ void NumpyWeightedAverageComputeImpl(const nnvm::NodeAttrs& attrs, size_t temp_mem_size = temp_data_size + temp_sum_size + workspace_size; Tensor temp_mem = ctx.requested[0].get_space_typed(Shape1(temp_mem_size), s); - DType *temp_data_ptr = reinterpret_cast(temp_mem.dptr_); - DType *temp_sum_ptr = 
reinterpret_cast(temp_mem.dptr_ + temp_data_size); + auto *temp_data_ptr = reinterpret_cast(temp_mem.dptr_); + auto *temp_sum_ptr = reinterpret_cast(temp_mem.dptr_ + temp_data_size); char *workspace_ptr = temp_mem.dptr_ + temp_data_size + temp_sum_size; - Tensor workspace(workspace_ptr, Shape1(workspace_size), s); + workspace = Tensor(workspace_ptr, Shape1(workspace_size), s); // Compute weighted data - TBlob wa = TBlob(temp_data_ptr, data.shape_, xpu::kDevMask); - TBlob sum_of_wa; - if constexpr (std::is_same::value) { - BinaryBroadcastCompute( - attrs, ctx, {data, weights}, {kWriteTo}, {wa}); - - // Compute sum of weighted data - sum_of_wa = TBlob(temp_sum_ptr, small1, xpu::kDevMask); - ReduceAxesComputeWithWorkspaceImpl( - ctx, {wa}, {kWriteTo}, {sum_of_wa}, workspace, src_shape, dst_shape); - } else { -#if MXNET_USE_CUDA - BinaryBroadcastRTCCompute {"mul"}(attrs, ctx, {data, weights}, {kWriteTo}, {wa}); + wa = TBlob(temp_data_ptr, data.shape_, xpu::kDevMask); + sum_of_wa = TBlob(temp_sum_ptr, small1, xpu::kDevMask); + }); +#if !defined(__CUDACC__) + BinaryBroadcastCompute( + attrs, ctx, {data, weights}, {kWriteTo}, {wa}); - // Compute sum of weighted data - sum_of_wa = TBlob(temp_sum_ptr, small1, xpu::kDevMask); - ReduceAxesRTCComputeWithWorkspaceImpl(ctx, {wa}, {kWriteTo}, {sum_of_wa}, "red::sum{}", - false, "identity", workspace, src_shape, dst_shape); + // Compute sum of weighted data + ReduceAxesComputeImpl( + ctx, {wa}, {kWriteTo}, {sum_of_wa}, small1, &workspace); +#else + BinaryBroadcastRTCCompute {"mul"}(attrs, ctx, {data, weights}, {kWriteTo}, {wa}); + + // Compute sum of weighted data + ReduceAxesRTCComputeImpl(ctx, {wa}, {kWriteTo}, {sum_of_wa}, small1, "red::sum{}", + &workspace, false, "identity"); #endif - } - if (!back) { - const TBlob& avg = outputs[0]; - const TBlob& sum_of_weights = outputs[1]; - TShape w_src_shape, w_dst_shape; - BroadcastReduceShapeCompact(weights.shape_, small2, &w_src_shape, &w_dst_shape); - // Compute sum of weight - TBlob scl = sum_of_weights.reshape(small2); - if constexpr (std::is_same::value) { - ReduceAxesComputeWithWorkspaceImpl( - ctx, {weights}, {kWriteTo}, {scl}, workspace, w_src_shape, w_dst_shape); - // Compute avg and assign output - BinaryBroadcastCompute( - attrs, ctx, {sum_of_wa, scl}, req, {avg.reshape(small1)}); - } else { -#if MXNET_USE_CUDA - ReduceAxesRTCComputeWithWorkspaceImpl(ctx, {weights}, {kWriteTo}, {scl}, "red::sum{}", - false, "identity", workspace, w_src_shape, - w_dst_shape); - // Compute avg and assign output - BinaryBroadcastRTCCompute {"div"}( - attrs, ctx, {sum_of_wa, scl}, req, {avg.reshape(small1)}); + if (!back) { + const TBlob& avg = outputs[0]; + const TBlob& sum_of_weights = outputs[1]; + // Compute sum of weight + TBlob scl = sum_of_weights.reshape(small2); +#if !defined(__CUDACC__) + ReduceAxesComputeImpl( + ctx, {weights}, {kWriteTo}, {scl}, small2, &workspace); + // Compute avg and assign output + BinaryBroadcastCompute( + attrs, ctx, {sum_of_wa, scl}, req, {avg.reshape(small1)}); +#else + ReduceAxesRTCComputeImpl(ctx, {weights}, {kWriteTo}, {scl}, small2, "red::sum{}", + &workspace, false, "identity"); + // Compute avg and assign output + BinaryBroadcastRTCCompute {"div"}( + attrs, ctx, {sum_of_wa, scl}, req, {avg.reshape(small1)}); #endif - } - } else { + } else { + MSHADOW_TYPE_SWITCH(data.type_flag_, DType, { // Compute and assign the derivatives of a and weights const TBlob& igrad_a = outputs[0]; const TBlob& igrad_w = outputs[1]; @@ -1026,12 +985,10 @@ void NumpyWeightedAverageComputeImpl(const 
nnvm::NodeAttrs& attrs, } }); }) - } - }); + }); + } } -#undef NP_BROADCAST_REDUCE_OP_BROADCAST - template void NumpyWeightedAverageForward(const nnvm::NodeAttrs& attrs, const OpContext& ctx, @@ -1044,27 +1001,29 @@ void NumpyWeightedAverageForward(const nnvm::NodeAttrs& attrs, CHECK_NE(req[0], kWriteInplace) << "Average does not support write in-place"; const auto& param = nnvm::get(attrs.parsed); const TBlob& data = inputs[0]; + TShape small; MSHADOW_TYPE_SWITCH(data.type_flag_, DType, { if (!param.weighted) { - TShape small = NumpyReduceAxesShapeImpl(data.shape_, param.axis, true); + small = NumpyReduceAxesShapeImpl(data.shape_, param.axis, true); // Compute sum of weights which equals to the product of sizes of reduced axes Stream* s = ctx.get_stream(); auto ret = outputs[1].FlatTo1D(s); ret = scalar(data.shape_.Size()/small.Size()); - // Compute mean - if constexpr (std::is_same::value) { - ReduceAxesComputeImpl( - ctx, inputs, req, {outputs[0]}, small); - } else { -#if MXNET_USE_CUDA - ReduceAxesRTCComputeImpl(ctx, inputs, req, {outputs[0]}, small, "red::sum{}", true); -#endif - } - } else { - NumpyWeightedAverageComputeImpl( - attrs, ctx, inputs, req, outputs, param.axis); } }); + if (!param.weighted) { + // Compute mean +#if !defined(__CUDACC__) + ReduceAxesComputeImpl( + ctx, inputs, req, {outputs[0]}, small); +#else + ReduceAxesRTCComputeImpl(ctx, inputs, req, {outputs[0]}, small, + "red::sum{}", nullptr, true); +#endif + } else { + NumpyWeightedAverageComputeImpl( + attrs, ctx, inputs, req, outputs, param.axis); + } } template @@ -1130,58 +1089,63 @@ void NumpyMomentsForward(const nnvm::NodeAttrs& attrs, mxnet::TShape src_shape, dst_shape; BroadcastReduceShapeCompact(data.shape_, small, &src_shape, &dst_shape); + Tensor temp_mem; + Tensor workspace; + MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, { + // Get workspace and temp space for data - mean + size_t workspace_size = 0; + workspace_size = broadcast::ReduceWorkspaceSize( + s, dst_shape, req[0], src_shape); + size_t temp_data_size = data.shape_.Size() * sizeof(DType); + size_t temp_mem_size = temp_data_size + workspace_size; + temp_mem = ctx.requested[0].get_space_typed(Shape1(temp_mem_size), s); + char *workspace_ptr = temp_mem.dptr_ + temp_data_size; + workspace = Tensor(workspace_ptr, Shape1(workspace_size), s); + }); + // Compute mean +#if !defined(__CUDACC__) + ReduceAxesComputeImpl( + ctx, inputs, {kWriteTo}, {mean}, small, &workspace); +#else + ReduceAxesRTCComputeImpl(ctx, inputs, {kWriteTo}, {mean}, small, "red::sum{}", + &workspace, true, "identity"); +#endif + // Compute data - mean + Shape<6> data_shape, mean_shape; + for (int i = 0; i < 6; ++i) { + data_shape[i] = (i < data.shape_.ndim()) ? data.shape_[i] : 1; + mean_shape[i] = (i < small.ndim()) ? 
small[i] : 1; + } +#if !defined(__CUDACC__) MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, { MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, OType, { - // Get workspace and temp space for data - mean - size_t workspace_size = 0; - workspace_size = broadcast::ReduceWorkspaceSize( - s, dst_shape, req[0], src_shape); - size_t temp_data_size = data.shape_.Size() * sizeof(DType); - size_t temp_mem_size = temp_data_size + workspace_size; - Tensor temp_mem = - ctx.requested[0].get_space_typed(Shape1(temp_mem_size), s); DType *temp_data_ptr = reinterpret_cast(temp_mem.dptr_); - char *workspace_ptr = temp_mem.dptr_ + temp_data_size; - Tensor workspace(workspace_ptr, Shape1(workspace_size), s); - // Compute mean - if constexpr (std::is_same::value) { - ReduceAxesComputeWithWorkspaceImpl( - ctx, inputs, {kWriteTo}, {mean}, workspace, src_shape, dst_shape); - } else { -#if MXNET_USE_CUDA - ReduceAxesRTCComputeWithWorkspaceImpl(ctx, inputs, {kWriteTo}, {mean}, "red::sum{}", - true, "identity", workspace, src_shape, dst_shape); -#endif - } - // Compute data - mean - Shape<6> data_shape, mean_shape; - for (int i = 0; i < 6; ++i) { - data_shape[i] = (i < data.shape_.ndim()) ? data.shape_[i] : 1; - mean_shape[i] = (i < small.ndim()) ? small[i] : 1; - } Kernel::Launch(s, data_shape.Size(), temp_data_ptr, data.dptr(), mean.dptr(), data_shape, mean_shape); Tensor temp_data_tensor(temp_data_ptr, Shape1(data.shape_.Size()), s); TBlob temp_data_blob = TBlob(temp_data_tensor).reshape(data.shape_); - if constexpr (std::is_same::value) { - ReduceAxesComputeWithWorkspaceImpl( - ctx, {temp_data_blob}, {req[0]}, {moment}, workspace, src_shape, dst_shape, param.ddof); - if (sqrt && req[0] != kNullOp) { - Tensor moment_tensor = moment.FlatTo1D(s); - moment_tensor = F(moment_tensor); - } - } else { -#if MXNET_USE_CUDA - ReduceAxesRTCComputeWithWorkspaceImpl(ctx, {temp_data_blob}, {req[0]}, {moment}, - "red::sum{}", true, "identity", workspace, - src_shape, dst_shape, param.ddof); - if (sqrt && req[0] != kNullOp) { - UnaryRTCCompute {"sqrt"}({}, ctx, {moment}, {kWriteInplace}, {moment}); - } -#endif + ReduceAxesComputeImpl( + ctx, {temp_data_blob}, {req[0]}, {moment}, small, &workspace, param.ddof); + if (sqrt && req[0] != kNullOp) { + Tensor moment_tensor = moment.FlatTo1D(s); + moment_tensor = F(moment_tensor); } }); }); +#else + MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, { + DType *temp_data_ptr = reinterpret_cast(temp_mem.dptr_); + Kernel::Launch(s, data_shape.Size(), temp_data_ptr, + data.dptr(), mean.dptr(), data_shape, mean_shape); + Tensor temp_data_tensor(temp_data_ptr, Shape1(data.shape_.Size()), s); + TBlob temp_data_blob = TBlob(temp_data_tensor).reshape(data.shape_); + ReduceAxesRTCComputeImpl(ctx, {temp_data_blob}, {req[0]}, {moment}, small, + "red::sum{}", &workspace, true, "identity", param.ddof); + if (sqrt && req[0] != kNullOp) { + UnaryRTCCompute {"sqrt"}({}, ctx, {moment}, {kWriteInplace}, {moment}); + } + }); +#endif } template @@ -1227,7 +1191,7 @@ void NumpyBroadcastToBackward(const nnvm::NodeAttrs& attrs, } #else ReduceAxesRTCComputeImpl(ctx, inputs, req, {outputs[0].reshape(expanded_igrad_shape)}, - expanded_igrad_shape, "red::sum{}", false); + expanded_igrad_shape, "red::sum{}", nullptr, false); #endif } diff --git a/src/operator/numpy/np_kron-inl.h b/src/operator/numpy/np_kron-inl.h index d37d3c94e807..49aee60b578c 100644 --- a/src/operator/numpy/np_kron-inl.h +++ b/src/operator/numpy/np_kron-inl.h @@ -188,6 +188,14 @@ void KronOpForwardImpl(const OpContext& ctx, }); } +#if 
!defined(__CUDACC__) +#define NP_KRON_REDUCE_AXES(safe_acc, workspace, ...) \ + ReduceAxesComputeImpl(__VA_ARGS__, &workspace) +#else +#define NP_KRON_REDUCE_AXES(safe_acc, workspace, ...) \ + ReduceAxesRTCComputeImpl(__VA_ARGS__, "red::sum{}", &workspace) +#endif + template void KronOpBackwardImpl(const OpContext& ctx, const std::vector& req, @@ -226,19 +234,23 @@ void KronOpBackwardImpl(const OpContext& ctx, const OpReqType& scalar_req = (ashape.ndim() == 0) ? req[0] : req[1]; ASSIGN_DISPATCH(tensor_grad_, tensor_req, broadcast_scalar(scalar_, tensor_grad_.shape_) * ograd_); - Tensor workspace = - ctx.requested[0].get_space_typed(Shape1(ograd.shape_.Size()), s); - ASSIGN_DISPATCH(workspace, kWriteTo, tensor_ * ograd_); - - if constexpr (std::is_same::value) { - ReduceAxesComputeImpl( - ctx, {TBlob(workspace)}, {scalar_req}, {TBlob(scalar_grad_)}, scalar_grad_.shape_); - } else { -#if MXNET_USE_CUDA - ReduceAxesRTCComputeImpl(ctx, {TBlob(workspace)}, {scalar_req}, {TBlob(scalar_grad_)}, - scalar_grad_.shape_, "red::sum{}", false); -#endif - } + TShape src_shape, dst_shape; + BroadcastReduceShapeCompact(ograd.shape_, scalar_grad_.shape_, &src_shape, &dst_shape); + size_t workspace_size = broadcast::ReduceWorkspaceSize(s, dst_shape, + {scalar_req}, src_shape); + constexpr size_t align_size = 1024; + const size_t aligned_first_workspace_size = ((workspace_size + align_size - 1) / align_size) + * align_size; + workspace_size = aligned_first_workspace_size + ograd.shape_.Size() * sizeof(DType); + Tensor workspace = + ctx.requested[0].get_space_typed(Shape1(workspace_size), s); + Tensor temp(reinterpret_cast(workspace.dptr_ + + aligned_first_workspace_size), + Shape1(ograd.shape_.Size())); + ASSIGN_DISPATCH(temp, kWriteTo, tensor_ * ograd_); + + NP_KRON_REDUCE_AXES(true, workspace, ctx, {TBlob(temp)}, {scalar_req}, + {TBlob(scalar_grad_)}, scalar_grad_.shape_); } else { MXNET_NDIM_SWITCH(oshape.ndim(), ndim, { Shape ashape_ = oshape.get(); @@ -283,6 +295,8 @@ void KronOpBackwardImpl(const OpContext& ctx, }); } +#undef NP_KRON_REDUCE_AXES + template inline void KronOpForward(const nnvm::NodeAttrs& attrs, const OpContext& ctx, diff --git a/src/operator/numpy/np_tensordot_op-inl.h b/src/operator/numpy/np_tensordot_op-inl.h index 30cd59733a04..47e42d01d6a6 100644 --- a/src/operator/numpy/np_tensordot_op-inl.h +++ b/src/operator/numpy/np_tensordot_op-inl.h @@ -370,6 +370,14 @@ inline mxnet::TShape GetReverseShape(const mxnet::Tuple& shape) { return shape2; } +#if !defined(__CUDACC__) +#define NP_TENSORDOT_REDUCE_AXES(safe_acc, ...) \ + ReduceAxesComputeImpl(__VA_ARGS__) +#else +#define NP_TENSORDOT_REDUCE_AXES(safe_acc, ...) \ + ReduceAxesRTCComputeImpl(__VA_ARGS__, "red::sum{}") +#endif + /** * calculates tensordot derivative. */ @@ -424,16 +432,8 @@ void TensordotBackwardImpl(const Tuple& a_axes_summed, workspace.stream_); ASSIGN_DISPATCH(dtypespace, kWriteTo, tensor_ * out_grad_); - if constexpr (std::is_same::value) { - ReduceAxesComputeImpl( - ctx, {TBlob(dtypespace)}, {scalar_req}, {TBlob(scalar_grad_)}, scalar_grad_.shape_); - } else { -#if MXNET_USE_CUDA - ReduceAxesRTCComputeImpl(ctx, {TBlob(dtypespace)}, {scalar_req}, - {TBlob(scalar_grad_)}, scalar_grad_.shape_, - "red::sum{}"); -#endif - } + NP_TENSORDOT_REDUCE_AXES(true, ctx, {TBlob(dtypespace)}, {scalar_req}, + {TBlob(scalar_grad_)}, scalar_grad_.shape_); } else { // Two tensors of at least 1 dimensions. 
Tuple a_axes_remained; @@ -742,15 +742,8 @@ void TensordotIntAxesBackwardImpl(const int axes, ctx.requested[0].get_space_typed(Shape1(out_grad.shape_.Size()), s); ASSIGN_DISPATCH(workspace, kWriteTo, tensor_ * out_grad_); - if constexpr (std::is_same::value) { - ReduceAxesComputeImpl( - ctx, {TBlob(workspace)}, {scalar_req}, {TBlob(scalar_grad_)}, scalar_grad_.shape_); - } else { -#if MXNET_USE_CUDA - ReduceAxesRTCComputeImpl(ctx, {TBlob(workspace)}, {scalar_req}, {TBlob(scalar_grad_)}, - scalar_grad_.shape_, "red::sum{}"); -#endif - } + NP_TENSORDOT_REDUCE_AXES(true, ctx, {TBlob(workspace)}, {scalar_req}, + {TBlob(scalar_grad_)}, scalar_grad_.shape_); } else { // Two tensors of at least 1 dimensions. Tuple a_axes_summed; @@ -774,6 +767,8 @@ void TensordotIntAxesBackwardImpl(const int axes, }); } +#undef NP_TENSORDOT_REDUCE_AXES + /** * backward function. */ diff --git a/src/operator/numpy/np_where_op-inl.h b/src/operator/numpy/np_where_op-inl.h index 4e899d6a8560..17020b90dbe3 100644 --- a/src/operator/numpy/np_where_op-inl.h +++ b/src/operator/numpy/np_where_op-inl.h @@ -175,6 +175,12 @@ inline void NumpyWhereOpForward(const nnvm::NodeAttrs& attrs, }); } +#if !defined(__CUDACC__) +#define NP_WHERE_REDUCE_AXES(safe_acc, ...) ReduceAxesComputeImpl(__VA_ARGS__) +#else +#define NP_WHERE_REDUCE_AXES(safe_acc, ...) ReduceAxesRTCComputeImpl(__VA_ARGS__, "red::sum{}") +#endif + template inline void NumpyWhereOpBackward(const nnvm::NodeAttrs& attrs, const OpContext& ctx, @@ -245,19 +251,12 @@ inline void NumpyWhereOpBackward(const nnvm::NodeAttrs& attrs, mxnet_op::Kernel, xpu>::Launch( s, ograd.Size(), req[0], cstride, oshape, cond.dptr(), ograd.dptr(), workspace.dptr_); - if constexpr (std::is_same::value) { - if (NeedSafeAcc(dx.type_flag_, dx.type_flag_)) { - ReduceAxesComputeImpl(ctx, {TBlob(workspace)}, {req[0]}, - {dx.reshape(expanded_lshape)}, expanded_lshape); - } else { - ReduceAxesComputeImpl(ctx, {TBlob(workspace)}, {req[0]}, - {dx.reshape(expanded_lshape)}, expanded_lshape); - } + if (NeedSafeAcc(dx.type_flag_, dx.type_flag_)) { + NP_WHERE_REDUCE_AXES(true, ctx, {TBlob(workspace)}, {req[0]}, + {dx.reshape(expanded_lshape)}, expanded_lshape); } else { -#if MXNET_USE_CUDA - ReduceAxesRTCComputeImpl(ctx, {TBlob(workspace)}, {req[0]}, - {dx.reshape(expanded_lshape)}, expanded_lshape, "red::sum{}"); -#endif // MXNET_USE_CUDA + NP_WHERE_REDUCE_AXES(false, ctx, {TBlob(workspace)}, {req[0]}, + {dx.reshape(expanded_lshape)}, expanded_lshape); } } // process right output @@ -274,19 +273,12 @@ inline void NumpyWhereOpBackward(const nnvm::NodeAttrs& attrs, mxnet_op::Kernel, xpu>::Launch( s, ograd.Size(), req[1], cstride, oshape, cond.dptr(), ograd.dptr(), workspace.dptr_); - if constexpr (std::is_same::value) { - if (NeedSafeAcc(dy.type_flag_, dy.type_flag_)) { - ReduceAxesComputeImpl(ctx, {TBlob(workspace)}, {req[1]}, - {dy.reshape(expanded_rshape)}, expanded_rshape); - } else { - ReduceAxesComputeImpl(ctx, {TBlob(workspace)}, {req[1]}, - {dy.reshape(expanded_rshape)}, expanded_rshape); - } + if (NeedSafeAcc(dy.type_flag_, dy.type_flag_)) { + NP_WHERE_REDUCE_AXES(true, ctx, {TBlob(workspace)}, {req[1]}, + {dy.reshape(expanded_rshape)}, expanded_rshape); } else { -#if MXNET_USE_CUDA - ReduceAxesRTCComputeImpl(ctx, {TBlob(workspace)}, {req[1]}, - {dy.reshape(expanded_rshape)}, expanded_rshape, "red::sum{}"); -#endif // MXNET_USE_CUDA + NP_WHERE_REDUCE_AXES(false, ctx, {TBlob(workspace)}, {req[1]}, + {dy.reshape(expanded_rshape)}, expanded_rshape); } } }); @@ -397,25 +389,20 @@ inline void 
NumpyWhereScalarOpBackward(const nnvm::NodeAttrs& attrs, mxnet_op::Kernel, xpu>::Launch( s, ograd.Size(), req[0], cstride, oshape, cond.dptr(), ograd.dptr(), workspace.dptr_); - if constexpr (std::is_same::value) { - if (NeedSafeAcc(dx.type_flag_, dx.type_flag_)) { - ReduceAxesComputeImpl(ctx, {TBlob(workspace)}, {req[0]}, - {dx.reshape(expanded_lshape)}, expanded_lshape); - } else { - ReduceAxesComputeImpl(ctx, {TBlob(workspace)}, {req[0]}, - {dx.reshape(expanded_lshape)}, expanded_lshape); - } + if (NeedSafeAcc(dx.type_flag_, dx.type_flag_)) { + NP_WHERE_REDUCE_AXES(true, ctx, {TBlob(workspace)}, {req[0]}, + {dx.reshape(expanded_lshape)}, expanded_lshape); } else { -#if MXNET_USE_CUDA - ReduceAxesRTCComputeImpl(ctx, {TBlob(workspace)}, {req[0]}, - {dx.reshape(expanded_lshape)}, expanded_lshape, "red::sum{}"); -#endif // MXNET_USE_CUDA + NP_WHERE_REDUCE_AXES(false, ctx, {TBlob(workspace)}, {req[0]}, + {dx.reshape(expanded_lshape)}, expanded_lshape); } } }); }); } +#undef NP_WHERE_REDUCE_AXES + template inline void NumpyWhereScalar2OpForward(const nnvm::NodeAttrs& attrs, const OpContext& ctx, diff --git a/src/operator/random/pdf_op.h b/src/operator/random/pdf_op.h index db17bbb58fad..fcd2d407b319 100644 --- a/src/operator/random/pdf_op.h +++ b/src/operator/random/pdf_op.h @@ -586,10 +586,11 @@ void PdfOpBackward(const nnvm::NodeAttrs& attrs, const PdfParam& param = nnvm::get(attrs.parsed); const size_t N(outputs[1].Size()); const TShape src_shape(Shape2(N, outputs[0].Size() / N)), dst_shape(Shape2(N, 1)); + const size_t red_work_size(broadcast::ReduceWorkspaceSize( + s, dst_shape, kAddTo, src_shape)); +#if !defined(__CUDACC__) // Inputs to PdfOpBackward: grad, samples, parm1, parm2, pdf. MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, { - const size_t red_work_size(broadcast::ReduceWorkspaceSize( - s, dst_shape, kAddTo, src_shape)); const size_t tmp_size(outputs[0].Size() * pnum * sizeof(DType) + red_work_size); Tensor tmp_space = ctx.requested[0].get_space_typed(Shape1(tmp_size), s); @@ -607,24 +608,42 @@ void PdfOpBackward(const nnvm::NodeAttrs& attrs, } Tensor red_work( tmp_space.dptr_ + pnum * outputs[0].Size() * sizeof(DType), Shape1(red_work_size), s); - if constexpr (std::is_same::value) { + broadcast::Reduce( + s, outputs[1].reshape(dst_shape), req[1], red_work, grads[1].reshape(src_shape)); + if (pnum == 2) { broadcast::Reduce( - s, outputs[1].reshape(dst_shape), req[1], red_work, grads[1].reshape(src_shape)); - if (pnum == 2) { - broadcast::Reduce( - s, outputs[2].reshape(dst_shape), req[2], red_work, grads[2].reshape(src_shape)); - } + s, outputs[2].reshape(dst_shape), req[2], red_work, grads[2].reshape(src_shape)); + } + }); +#else + // Inputs to PdfOpBackward: grad, samples, parm1, parm2, pdf. 
+ MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, { + const size_t tmp_size(outputs[0].Size() * pnum * sizeof(DType) + red_work_size); + Tensor tmp_space = + ctx.requested[0].get_space_typed(Shape1(tmp_size), s); + std::vector grads = {outputs[0]}; + grads.push_back(TBlob(tmp_space.dptr_, outputs[0].shape_, + outputs[1].dev_mask(), outputs[1].type_flag_, outputs[1].dev_id())); + if (pnum == 2) { + grads.push_back(TBlob(tmp_space.dptr_ + outputs[0].Size() * sizeof(DType), outputs[0].shape_, + outputs[2].dev_mask(), outputs[2].type_flag_, outputs[2].dev_id())); + } + if (param.is_log) { + PdfGradCaller, pnum, vparm>::op(inputs, req, grads, s); } else { -#if MXNET_USE_CUDA - broadcast::RTCReduce(ctx, outputs[1].reshape(dst_shape), req[1], red_work, - grads[1].reshape(src_shape), "red::sum{}", 2, "identity"); - if (pnum == 2) { - broadcast::RTCReduce(ctx, outputs[2].reshape(dst_shape), req[2], red_work, - grads[2].reshape(src_shape), "red::sum{}", 2, "identity"); - } -#endif + PdfGradCaller, pnum, vparm>::op(inputs, req, grads, s); + } + Tensor red_work( + tmp_space.dptr_ + pnum * outputs[0].Size() * sizeof(DType), Shape1(red_work_size), s); + broadcast::RTCReduce(ctx, outputs[1].reshape(dst_shape), req[1], red_work, + grads[1].reshape(src_shape), "red::sum{}", 2, "identity"); + if (pnum == 2) { + broadcast::RTCReduce(ctx, outputs[2].reshape(dst_shape), req[2], red_work, + grads[2].reshape(src_shape), "red::sum{}", 2, "identity"); } }); + +#endif } } // namespace op diff --git a/src/operator/tensor/broadcast_reduce_op.cc b/src/operator/tensor/broadcast_reduce_op.cc index 9348ad1bd2bf..483787ec7b0a 100644 --- a/src/operator/tensor/broadcast_reduce_op.cc +++ b/src/operator/tensor/broadcast_reduce_op.cc @@ -28,37 +28,13 @@ namespace op { #if MXNET_USE_CUDA -void ReduceAxesRTCComputeWithWorkspaceImpl(const OpContext& ctx, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs, - const std::string& reducer, - const bool normalize, - const std::string& OP, - const mshadow::Tensor& workspace, - const mxnet::TShape& src_shape, - const mxnet::TShape& dst_shape, - const int ddof) { - const TBlob in_data = inputs[0].reshape(src_shape); - const TBlob out_data = outputs[0].reshape(dst_shape); - BROADCAST_NDIM_SWITCH(dst_shape.ndim(), NDim, { - broadcast::RTCReduce(ctx, out_data, req[0], workspace, in_data, reducer, NDim, OP); - }); - if (normalize) { - NumpyBinaryScalarParam p{}; - p.scalar = static_cast(src_shape.Size()/dst_shape.Size() - ddof); - NodeAttrs a; - a.parsed = p; - BinaryScalarRTCCompute {"div"}(a, ctx, {out_data}, {kWriteInplace}, {out_data}); - } -} - void ReduceAxesRTCComputeImpl(const OpContext& ctx, const std::vector& inputs, const std::vector& req, const std::vector& outputs, const mxnet::TShape& small, const std::string& reducer, + const mshadow::Tensor* workspace, const bool normalize, const std::string& OP, const int ddof) { @@ -67,12 +43,25 @@ void ReduceAxesRTCComputeImpl(const OpContext& ctx, mxnet::TShape src_shape, dst_shape; BroadcastReduceShapeCompact(inputs[0].shape_, small, &src_shape, &dst_shape); Stream* s = ctx.get_stream(); - size_t workspace_size = broadcast::ReduceWorkspaceSize( - s, dst_shape, req[0], src_shape); - Tensor workspace = - ctx.requested[0].get_space_typed(Shape1(workspace_size), s); - ReduceAxesRTCComputeWithWorkspaceImpl(ctx, inputs, req, outputs, reducer, normalize, - OP, workspace, src_shape, dst_shape, ddof); + Tensor w; + if (workspace == nullptr) { + size_t workspace_size = broadcast::ReduceWorkspaceSize( + s, 
dst_shape, req[0], src_shape); + w = ctx.requested[0].get_space_typed(Shape1(workspace_size), s); + workspace = &w; + } + const TBlob in_data = inputs[0].reshape(src_shape); + const TBlob out_data = outputs[0].reshape(dst_shape); + BROADCAST_NDIM_SWITCH(dst_shape.ndim(), NDim, { + broadcast::RTCReduce(ctx, out_data, req[0], *workspace, in_data, reducer, NDim, OP); + }); + if (normalize) { + NumpyBinaryScalarParam p{}; + p.scalar = static_cast(src_shape.Size()/dst_shape.Size() - ddof); + NodeAttrs a; + a.parsed = p; + BinaryScalarRTCCompute {"div"}(a, ctx, {out_data}, {kWriteInplace}, {out_data}); + } } namespace { @@ -182,7 +171,7 @@ void ReduceAxesRTCCompute::operator()(const nnvm::NodeAttrs& attrs, return; } - ReduceAxesRTCComputeImpl(ctx, inputs, req, outputs, small, reducer, normalize, OP, ddof); + ReduceAxesRTCComputeImpl(ctx, inputs, req, outputs, small, reducer, nullptr, normalize, OP, ddof); } template struct ReduceAxesRTCCompute; diff --git a/src/operator/tensor/broadcast_reduce_op.h b/src/operator/tensor/broadcast_reduce_op.h index 462c81343d44..ce0fade99bfc 100644 --- a/src/operator/tensor/broadcast_reduce_op.h +++ b/src/operator/tensor/broadcast_reduce_op.h @@ -622,7 +622,9 @@ void ReduceAxesComputeImpl(const OpContext& ctx, const std::vector& inputs, const std::vector& req, const std::vector& outputs, - const mxnet::TShape& small) { + const mxnet::TShape& small, + const mshadow::Tensor* workspace = nullptr, + const int ddof = 0) { using namespace mshadow; using namespace mshadow::expr; @@ -634,15 +636,18 @@ void ReduceAxesComputeImpl(const OpContext& ctx, const TBlob in_data = inputs[0].reshape(src_shape); const TBlob out_data = outputs[0].reshape(dst_shape); BROADCAST_NDIM_SWITCH(dst_shape.ndim(), NDim, { - size_t workspace_size = broadcast::ReduceWorkspaceSize( - s, out_data.shape_, req[0], in_data.shape_); - Tensor workspace = - ctx.requested[0].get_space_typed(Shape1(workspace_size), s); + Tensor w; + if (workspace == nullptr) { + size_t workspace_size = broadcast::ReduceWorkspaceSize( + s, out_data.shape_, req[0], in_data.shape_); + w = ctx.requested[0].get_space_typed(Shape1(workspace_size), s); + workspace = &w; + } broadcast::Reduce( - s, out_data, req[0], workspace, in_data); + s, out_data, req[0], *workspace, in_data); if (normalize) { auto out = out_data.FlatTo2D(s); - out /= scalar(src_shape.Size()/dst_shape.Size()); + out /= scalar(src_shape.Size()/dst_shape.Size() - ddof); } }); }); @@ -721,21 +726,12 @@ void ReduceAxesRTCComputeImpl(const OpContext& ctx, const std::vector& outputs, const mxnet::TShape& small, const std::string& reducer, + const mshadow::Tensor* workspace = nullptr, + const bool normalize = false, const std::string& OP = "identity", const int ddof = 0); -void ReduceAxesRTCComputeWithWorkspaceImpl(const OpContext& ctx, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs, - const std::string& reducer, - const bool normalize, - const std::string& OP, - const mshadow::Tensor& workspace, - const mxnet::TShape& src_shape, - const mxnet::TShape& dst_shape, - const int ddof = 0); #endif template @@ -1549,7 +1545,7 @@ void LpNormCompute(const nnvm::NodeAttrs& attrs, const std::string &op = param.ord == 1 ? 
"abs" : "identity"; - ReduceAxesRTCComputeImpl(ctx, inputs, req, outputs, small, red, false, op); + ReduceAxesRTCComputeImpl(ctx, inputs, req, outputs, small, red, nullptr, false, op); #endif } diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h index fa61268d8fa1..1aaf0b52b242 100644 --- a/src/operator/tensor/matrix_op-inl.h +++ b/src/operator/tensor/matrix_op-inl.h @@ -2050,7 +2050,8 @@ void RepeatOpBackward(const nnvm::NodeAttrs& attrs, ReduceAxesComputeImpl( ctx, newInputs, req, newOutputs, rshapes.first); #else - ReduceAxesRTCComputeImpl(ctx, newInputs, req, newOutputs, rshapes.first, "red::sum{}", false); + ReduceAxesRTCComputeImpl(ctx, newInputs, req, newOutputs, rshapes.first, + "red::sum{}", nullptr, false); #endif } @@ -2246,7 +2247,8 @@ void TileOpBackward(const nnvm::NodeAttrs& attrs, ReduceAxesComputeImpl( ctx, newInputs, req, newOutputs, rshapes.first); #else - ReduceAxesRTCComputeImpl(ctx, newInputs, req, newOutputs, rshapes.first, "red::sum{}", false); + ReduceAxesRTCComputeImpl(ctx, newInputs, req, newOutputs, rshapes.first, + "red::sum{}", nullptr, false); #endif } From 4d147265873340eb6e5370f80bde1875038c35db Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Fri, 30 Oct 2020 17:19:26 -0700 Subject: [PATCH 07/15] Fix lint --- src/operator/numpy/np_where_op-inl.h | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/operator/numpy/np_where_op-inl.h b/src/operator/numpy/np_where_op-inl.h index 17020b90dbe3..43af21def5ae 100644 --- a/src/operator/numpy/np_where_op-inl.h +++ b/src/operator/numpy/np_where_op-inl.h @@ -176,7 +176,8 @@ inline void NumpyWhereOpForward(const nnvm::NodeAttrs& attrs, } #if !defined(__CUDACC__) -#define NP_WHERE_REDUCE_AXES(safe_acc, ...) ReduceAxesComputeImpl(__VA_ARGS__) +#define NP_WHERE_REDUCE_AXES(safe_acc, ...) \ + ReduceAxesComputeImpl(__VA_ARGS__) #else #define NP_WHERE_REDUCE_AXES(safe_acc, ...) ReduceAxesRTCComputeImpl(__VA_ARGS__, "red::sum{}") #endif From 7541de08cfa480345a84ad5f9758b27e833bf04d Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Mon, 2 Nov 2020 08:11:53 -0800 Subject: [PATCH 08/15] Try to WAR the maybe-uninitialized warning --- 3rdparty/mshadow/mshadow/tensor.h | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/3rdparty/mshadow/mshadow/tensor.h b/3rdparty/mshadow/mshadow/tensor.h index c92bf8d076d5..35d5f2e7bc47 100644 --- a/3rdparty/mshadow/mshadow/tensor.h +++ b/3rdparty/mshadow/mshadow/tensor.h @@ -432,7 +432,7 @@ struct Tensor: public TRValue, // struct memembers //-------------------------------- /*! \brief pointer to the data */ - DType *dptr_ = nullptr; + DType *dptr_; /*! \brief shape of the tensor */ Shape shape_; /*! @@ -449,13 +449,13 @@ struct Tensor: public TRValue, // functions //-------------------------------- /*! \brief default constructor */ - MSHADOW_XINLINE Tensor(void) : stream_(NULL) {} + MSHADOW_XINLINE Tensor(void) : dptr_(nullptr), stream_(nullptr) {} /*! \brief constructor from shape */ MSHADOW_XINLINE Tensor(const Shape &shape) - : shape_(shape), stream_(NULL) {} + : dptr_(nullptr), shape_(shape), stream_(nullptr) {} /*! \brief constructor from data pointer and shape, without stride */ MSHADOW_XINLINE Tensor(DType *dptr, const Shape &shape) - : dptr_(dptr), shape_(shape), stride_(shape[kSubdim]), stream_(NULL) {} + : dptr_(dptr), shape_(shape), stride_(shape[kSubdim]), stream_(nullptr) {} /*! 
\brief constructor from data pointer and shape, without stride */ MSHADOW_XINLINE Tensor(DType *dptr, const Shape &shape, Stream *stream) From ac278b15d382ef7b07df957cce314033dcbca1f5 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Mon, 2 Nov 2020 08:59:14 -0800 Subject: [PATCH 09/15] Second try --- 3rdparty/mshadow/mshadow/tensor.h | 8 ++++---- src/operator/numpy/np_broadcast_reduce_op.h | 21 ++++++++------------- 2 files changed, 12 insertions(+), 17 deletions(-) diff --git a/3rdparty/mshadow/mshadow/tensor.h b/3rdparty/mshadow/mshadow/tensor.h index 35d5f2e7bc47..c92bf8d076d5 100644 --- a/3rdparty/mshadow/mshadow/tensor.h +++ b/3rdparty/mshadow/mshadow/tensor.h @@ -432,7 +432,7 @@ struct Tensor: public TRValue, // struct memembers //-------------------------------- /*! \brief pointer to the data */ - DType *dptr_; + DType *dptr_ = nullptr; /*! \brief shape of the tensor */ Shape shape_; /*! @@ -449,13 +449,13 @@ struct Tensor: public TRValue, // functions //-------------------------------- /*! \brief default constructor */ - MSHADOW_XINLINE Tensor(void) : dptr_(nullptr), stream_(nullptr) {} + MSHADOW_XINLINE Tensor(void) : stream_(NULL) {} /*! \brief constructor from shape */ MSHADOW_XINLINE Tensor(const Shape &shape) - : dptr_(nullptr), shape_(shape), stream_(nullptr) {} + : shape_(shape), stream_(NULL) {} /*! \brief constructor from data pointer and shape, without stride */ MSHADOW_XINLINE Tensor(DType *dptr, const Shape &shape) - : dptr_(dptr), shape_(shape), stride_(shape[kSubdim]), stream_(nullptr) {} + : dptr_(dptr), shape_(shape), stride_(shape[kSubdim]), stream_(NULL) {} /*! \brief constructor from data pointer and shape, without stride */ MSHADOW_XINLINE Tensor(DType *dptr, const Shape &shape, Stream *stream) diff --git a/src/operator/numpy/np_broadcast_reduce_op.h b/src/operator/numpy/np_broadcast_reduce_op.h index df58ef2a1a27..9ce3967f4797 100644 --- a/src/operator/numpy/np_broadcast_reduce_op.h +++ b/src/operator/numpy/np_broadcast_reduce_op.h @@ -1089,19 +1089,14 @@ void NumpyMomentsForward(const nnvm::NodeAttrs& attrs, mxnet::TShape src_shape, dst_shape; BroadcastReduceShapeCompact(data.shape_, small, &src_shape, &dst_shape); - Tensor temp_mem; - Tensor workspace; - MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, { - // Get workspace and temp space for data - mean - size_t workspace_size = 0; - workspace_size = broadcast::ReduceWorkspaceSize( - s, dst_shape, req[0], src_shape); - size_t temp_data_size = data.shape_.Size() * sizeof(DType); - size_t temp_mem_size = temp_data_size + workspace_size; - temp_mem = ctx.requested[0].get_space_typed(Shape1(temp_mem_size), s); - char *workspace_ptr = temp_mem.dptr_ + temp_data_size; - workspace = Tensor(workspace_ptr, Shape1(workspace_size), s); - }); + // Get workspace and temp space for data - mean + size_t workspace_size = broadcast::ReduceWorkspaceSize(s, dst_shape, req[0], src_shape); + size_t temp_data_size = data.shape_.Size() * common::mshadow_type_info(inputs[0].type_flag_).size; + size_t temp_mem_size = temp_data_size + workspace_size; + Tensor temp_mem = + ctx.requested[0].get_space_typed(Shape1(temp_mem_size), s); + char *workspace_ptr = temp_mem.dptr_ + temp_data_size; + Tensor workspace(workspace_ptr, Shape1(workspace_size), s); // Compute mean #if !defined(__CUDACC__) ReduceAxesComputeImpl( From d13386256ebffbf9a5c5606f8e7c979e212b026a Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Mon, 2 Nov 2020 12:18:25 -0800 Subject: [PATCH 10/15] Fix Windows compilation --- src/operator/nn/moments-inl.h | 12 
+++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/operator/nn/moments-inl.h b/src/operator/nn/moments-inl.h index b8915dce2397..78c7e4a1cd44 100644 --- a/src/operator/nn/moments-inl.h +++ b/src/operator/nn/moments-inl.h @@ -131,6 +131,7 @@ inline void MomentsForwardImpl(const OpContext& ctx, #else ReduceAxesRTCComputeImpl(ctx, {data}, {req[0]}, {mean}, small, "red::sum{}", nullptr, true); #endif + TBlob temp; MSHADOW_TYPE_SWITCH(data.type_flag_, DType, { Shape<6> data_shape, mean_shape; for (int i = 0; i < 6; ++i) { @@ -141,14 +142,15 @@ inline void MomentsForwardImpl(const OpContext& ctx, ctx.requested[0].get_space_typed(Shape1(data.shape_.Size()), s);; Kernel::Launch(s, data.shape_.Size(), temp_data.dptr_, data.dptr(), mean.dptr(), data_shape, mean_shape); + temp = TBlob(temp_data); + }); #if !defined(__CUDACC__) - ReduceAxesComputeImpl( - ctx, {TBlob(temp_data).reshape(data.shape_)}, {kWriteTo}, {var}, small); + ReduceAxesComputeImpl( + ctx, {temp.reshape(data.shape_)}, {kWriteTo}, {var}, small); #else - ReduceAxesRTCComputeImpl(ctx, {TBlob(temp_data).reshape(data.shape_)}, - {kWriteTo}, {var}, small, "red::sum{}", nullptr, true); + ReduceAxesRTCComputeImpl(ctx, {temp.reshape(data.shape_)}, + {kWriteTo}, {var}, small, "red::sum{}", nullptr, true); #endif - }); } template From 9008e860012cbd71a3aa26052453a5c2e63357d4 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Mon, 2 Nov 2020 14:45:50 -0800 Subject: [PATCH 11/15] More fixes for Windows compilation --- src/operator/numpy/linalg/np_norm-inl.h | 44 ++++++++++++------------- src/operator/numpy/np_cross-inl.h | 6 ++-- 2 files changed, 25 insertions(+), 25 deletions(-) diff --git a/src/operator/numpy/linalg/np_norm-inl.h b/src/operator/numpy/linalg/np_norm-inl.h index 29a33179ab7f..60dee6ac3492 100644 --- a/src/operator/numpy/linalg/np_norm-inl.h +++ b/src/operator/numpy/linalg/np_norm-inl.h @@ -452,31 +452,29 @@ void NumpyMatrixNormCompute(const nnvm::NodeAttrs& attrs, if (param.ord != 2 && param.ord != -2) { // row norm or col norm TShape sum_shape = inputs[0].shape_; sum_shape[mat_axis[!(param.ord == 1 || param.ord == -1)]] = 1; - MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, { - TBlob temp = outputs[1].reshape(sum_shape); - std::vector sum_output({temp}); + TBlob temp = outputs[1].reshape(sum_shape); + std::vector sum_output({temp}); #if !defined(__CUDACC__) - ReduceAxesComputeImpl( - ctx, inputs, req, sum_output, sum_shape); - if (param.ord > 0) { - ReduceAxesComputeImpl( - ctx, sum_output, req, outputs, reduced_shape); - } else { - ReduceAxesComputeImpl( - ctx, sum_output, req, outputs, reduced_shape); - } + ReduceAxesComputeImpl( + ctx, inputs, req, sum_output, sum_shape); + if (param.ord > 0) { + ReduceAxesComputeImpl( + ctx, sum_output, req, outputs, reduced_shape); + } else { + ReduceAxesComputeImpl( + ctx, sum_output, req, outputs, reduced_shape); + } #else - ReduceAxesRTCComputeImpl(ctx, inputs, req, sum_output, sum_shape, - "red::sum{}", nullptr, false, "abs"); - if (param.ord > 0) { - ReduceAxesRTCComputeImpl(ctx, sum_output, req, outputs, reduced_shape, - "red::maximum{}", nullptr, false); - } else { - ReduceAxesRTCComputeImpl(ctx, sum_output, req, outputs, reduced_shape, - "red::minimum{}", nullptr, false); - } -#endif // MXNET_USE_CUDA - }); + ReduceAxesRTCComputeImpl(ctx, inputs, req, sum_output, sum_shape, + "red::sum{}", nullptr, false, "abs"); + if (param.ord > 0) { + ReduceAxesRTCComputeImpl(ctx, sum_output, req, outputs, reduced_shape, + "red::maximum{}", nullptr, false); + } else 
{ + ReduceAxesRTCComputeImpl(ctx, sum_output, req, outputs, reduced_shape, + "red::minimum{}", nullptr, false); + } +#endif return; } diff --git a/src/operator/numpy/np_cross-inl.h b/src/operator/numpy/np_cross-inl.h index ef6ea1e6cec4..7d5dd30d813b 100644 --- a/src/operator/numpy/np_cross-inl.h +++ b/src/operator/numpy/np_cross-inl.h @@ -689,15 +689,17 @@ struct ReduceImplWrap { const Tensor workspace_tensor) { Stream *s = ctx.get_stream(); // Reduce work_in to work_out. - SUM_NDIM_SWITCH(work_out.ndim(), NDim, { #if !defined(__CUDACC__) + SUM_NDIM_SWITCH(work_out.ndim(), NDim, { op::broadcast::Reduce( s, work_out, kWriteTo, workspace_tensor, work_in); + }); #else + SUM_NDIM_SWITCH(work_out.ndim(), NDim, { op::broadcast::RTCReduce(ctx, work_out, kWriteTo, workspace_tensor, work_in, "red::sum{}", NDim, "identity"); -#endif }); +#endif // Copy work_out to out_data. MXNET_ASSIGN_REQ_SWITCH(out_req, req_type, { mxnet_op::Kernel, xpu>::Launch( From 245ec8939006a5e1b6dace9ba2ebebeab2177797 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Thu, 5 Nov 2020 09:47:07 -0800 Subject: [PATCH 12/15] Breaking the strings to please Windows compiler --- src/common/cuda/rtc.cc | 4 +- src/common/cuda/rtc/backward_functions-inl.h | 178 ++++++++-------- src/common/cuda/rtc/reducer-inl.h | 204 +++++++++---------- 3 files changed, 195 insertions(+), 191 deletions(-) diff --git a/src/common/cuda/rtc.cc b/src/common/cuda/rtc.cc index 486f9afe53b1..9f70c1b2ef64 100644 --- a/src/common/cuda/rtc.cc +++ b/src/common/cuda/rtc.cc @@ -123,7 +123,9 @@ CUfunction get_function(const std::string ¶meters, function_definitions_binary + "\n" + function_definitions_unary + "\n" + backward_function_definitions + "\n" + - reducer + "\n"; + grad_function_definitions + "\n" + + reducer + "\n" + + logic_reducer + "\n"; std::string code_with_header = common_header + parameters + code; // If verbose mode, output kernel source, though not including the common header if (dmlc::GetEnv("MXNET_RTC_VERBOSE", false)) { diff --git a/src/common/cuda/rtc/backward_functions-inl.h b/src/common/cuda/rtc/backward_functions-inl.h index 17c0190ef3fe..886f5fabdb72 100644 --- a/src/common/cuda/rtc/backward_functions-inl.h +++ b/src/common/cuda/rtc/backward_functions-inl.h @@ -231,6 +231,98 @@ backward_square(const DTypeGrad grad, const DType val) { return 2 * val * grad; } +template +__device__ inline DType div_rgrad(const DType val, + const DType2 val2) { + return -val / (val2 * val2); +} + +template +__device__ inline typename type_util::mixed_type::type +backward_clip(const DTypeGrad grad, const DType val, + const float a_min, const float a_max) { + if (val > a_max || val < a_min) { + return 0; + } else { + return grad; + } +} + +template +__device__ inline typename type_util::mixed_type::type +backward_reciprocal(const DTypeGrad grad, const DType val) { + return -grad / (val * val); +} + +template +__device__ inline typename type_util::mixed_type::type +backward_erf(const DTypeGrad grad, const DType val) { + using mixed_type = typename type_util::mixed_type::type; + const mixed_type v = val; + constexpr mixed_type my_pi = pi; + return 2.0f / op::sqrt(my_pi) * op::exp(-(v*v)) * grad; +} + +template +__device__ inline typename type_util::mixed_type::type +backward_erfinv(const DTypeGrad grad, const DType val) { + using mixed_type = typename type_util::mixed_type::type; + constexpr mixed_type my_pi = pi; + const mixed_type g = grad; + const mixed_type v = val; + return 0.5f * op::sqrt(my_pi) * op::exp(v * v) * g; +} + +template +__device__ inline 
typename type_util::mixed_type::type +backward_gamma(const DTypeGrad grad, const DType val) { + using mixed_type = typename type_util::mixed_type::type; + const mixed_type v = val; + if (type_util::is_same::value) { + return grad * op::gamma(v) * op::special_functions::cephes::psi(v); + } else { + return grad * op::gamma(v) * op::special_functions::cephes::psi(v); + } +} + +template +__device__ inline typename type_util::mixed_type::type +backward_gammaln(const DTypeGrad grad, const DType val) { + using mixed_type = typename type_util::mixed_type::type; + const mixed_type v = val; + if (type_util::is_same::value) { + return grad * op::special_functions::cephes::psi(v); + } else { + return grad * op::special_functions::cephes::psi(v); + } +} + +template +__device__ inline typename type_util::mixed_type::type +backward_digamma(const DTypeGrad grad, const DType val) { + using mixed_type = typename type_util::mixed_type::type; + const mixed_type v = val; + if (type_util::is_same::value) { + return grad * op::special_functions::trigamma(v); + } else { + return grad * op::special_functions::trigamma(v); + } +} + +template +__device__ inline typename type_util::mixed_type::type +backward_gelu(const DTypeGrad grad, const DType val) { + return 0.5f * (grad + grad * op::erf(val / op::sqrt(2.0f)) + + val * backward_erf(grad, val / op::sqrt(2.0f)) / op::sqrt(2.0f)); +} + +} // namespace op + +)code"; + +const char grad_function_definitions[] = R"code( +namespace op { + template __device__ inline typename type_util::mixed_type::type rdiv_grad(const DType val, @@ -246,12 +338,6 @@ div_grad(const DType val, return op::reciprocal(temp); } -template -__device__ inline DType div_rgrad(const DType val, - const DType2 val2) { - return -val / (val2 * val2); -} - template __device__ inline DType mod_grad(const DType val, const DType2 val2) { @@ -362,85 +448,6 @@ rldexp_grad(const DType val, return val2 * op::power(static_cast(2), val) * op::log(static_cast(2)); } -template -__device__ inline typename type_util::mixed_type::type -backward_clip(const DTypeGrad grad, const DType val, - const float a_min, const float a_max) { - if (val > a_max || val < a_min) { - return 0; - } else { - return grad; - } -} - -template -__device__ inline typename type_util::mixed_type::type -backward_reciprocal(const DTypeGrad grad, const DType val) { - return -grad / (val * val); -} - -template -__device__ inline typename type_util::mixed_type::type -backward_erf(const DTypeGrad grad, const DType val) { - using mixed_type = typename type_util::mixed_type::type; - const mixed_type v = val; - constexpr mixed_type my_pi = pi; - return 2.0f / op::sqrt(my_pi) * op::exp(-(v*v)) * grad; -} - -template -__device__ inline typename type_util::mixed_type::type -backward_erfinv(const DTypeGrad grad, const DType val) { - using mixed_type = typename type_util::mixed_type::type; - constexpr mixed_type my_pi = pi; - const mixed_type g = grad; - const mixed_type v = val; - return 0.5f * op::sqrt(my_pi) * op::exp(v * v) * g; -} - -template -__device__ inline typename type_util::mixed_type::type -backward_gamma(const DTypeGrad grad, const DType val) { - using mixed_type = typename type_util::mixed_type::type; - const mixed_type v = val; - if (type_util::is_same::value) { - return grad * op::gamma(v) * op::special_functions::cephes::psi(v); - } else { - return grad * op::gamma(v) * op::special_functions::cephes::psi(v); - } -} - -template -__device__ inline typename type_util::mixed_type::type -backward_gammaln(const DTypeGrad grad, const DType val) { 
- using mixed_type = typename type_util::mixed_type::type; - const mixed_type v = val; - if (type_util::is_same::value) { - return grad * op::special_functions::cephes::psi(v); - } else { - return grad * op::special_functions::cephes::psi(v); - } -} - -template -__device__ inline typename type_util::mixed_type::type -backward_digamma(const DTypeGrad grad, const DType val) { - using mixed_type = typename type_util::mixed_type::type; - const mixed_type v = val; - if (type_util::is_same::value) { - return grad * op::special_functions::trigamma(v); - } else { - return grad * op::special_functions::trigamma(v); - } -} - -template -__device__ inline typename type_util::mixed_type::type -backward_gelu(const DTypeGrad grad, const DType val) { - return 0.5f * (grad + grad * op::erf(val / op::sqrt(2.0f)) + - val * backward_erf(grad, val / op::sqrt(2.0f)) / op::sqrt(2.0f)); -} - template __device__ inline DType smooth_l1_grad(const DType val, const DType2 scalar) { auto bsq = scalar * scalar; @@ -534,7 +541,6 @@ gamma_implicit_grad(const DType a_in, const DType2 x_in) { } } // namespace op - )code"; } // namespace rtc diff --git a/src/common/cuda/rtc/reducer-inl.h b/src/common/cuda/rtc/reducer-inl.h index ac6bb22442b0..f5b70d832594 100644 --- a/src/common/cuda/rtc/reducer-inl.h +++ b/src/common/cuda/rtc/reducer-inl.h @@ -29,10 +29,8 @@ namespace rtc { const char reducer[] = R"code( - namespace red { -/*! \brief sum reducer */ struct sum { /*! \brief do reduction into dst */ template @@ -96,103 +94,6 @@ struct sum { } }; -/*! \brief maximum reducer */ -struct maximum { - /*! \brief do reduction into dst */ - template - __device__ inline static void Reduce(volatile DType& dst, volatile DType2 src) { // NOLINT(*) - if (!util::isnan(dst)) { - if (!(dst >= src)) dst = src; - } - } - /*! \brief do reduction into dst */ - template - __device__ inline static void Reduce(volatile DType& dst, volatile DType2 src, - volatile DType& none) { - Reduce(dst, src); - } - /*! \brief combine the results of two reducers */ - template - __device__ inline static void Merge(volatile DType& dst_val, volatile DType& src_val) { - Reduce(dst_val, src_val); - } - /*! \brief combine the results of two reducers */ - template - __device__ inline static void Merge(volatile DType& dst_val, volatile DType& dst_residual, - volatile DType& src_val, volatile DType& src_residual) { - Reduce(dst_val, src_val); - } - /*! \brief finalize reduction result */ - template - __device__ inline static void Finalize(volatile DType& dst) {} - /*! \brief finalize reduction result */ - template - __device__ inline static void Finalize(volatile DType& dst, volatile DType& none) {} - /*! - *\brief set the initial value during reduction - */ - template - __device__ inline static void SetInitValue(DType &initv) { - initv = limits::NegInfValue(); - } - /*! - *\brief set the initial value during reduction - */ - template - __device__ inline static void SetInitValue(DType &initv, DType &none) { - SetInitValue(initv); - } -}; - -/*! \brief minimum reducer */ -struct minimum { - /*! \brief do reduction into dst */ - template - __device__ inline static void Reduce(volatile DType& dst, volatile DType2 src) { - if (!util::isnan(dst)) { - if (!(dst <= src)) dst = src; - } - } - /*! \brief do reduction into dst */ - template - __device__ inline static void Reduce(volatile DType& dst, volatile DType2 src, - volatile DType& none) { - Reduce(dst, src); - } - /*! 
\brief combine the results of two reducers */ - template - __device__ inline static void Merge(volatile DType& dst_val, volatile DType& src_val) { - Reduce(dst_val, src_val); - } - /*! \brief combine the results of two reducers */ - template - __device__ inline static void Merge(volatile DType& dst_val, volatile DType& dst_residual, - volatile DType& src_val, volatile DType& src_residual) { - Reduce(dst_val, src_val); - } - /*! \brief finalize reduction result */ - template - __device__ inline static void Finalize(volatile DType& dst) {} - /*! \brief finalize reduction result */ - template - __device__ inline static void Finalize(volatile DType& dst, volatile DType& none) {} - /*! - *\brief set the initial value during reduction - */ - template - __device__ inline static void SetInitValue(DType &initv) { - initv = limits::PosInfValue(); - } - /*! - *\brief set the initial value during reduction - */ - template - __device__ inline static void SetInitValue(DType &initv, DType &none) { - SetInitValue(initv); - } -}; - -/*! \brief product reducer */ struct product { /*! \brief do reduction into dst */ template @@ -238,7 +139,6 @@ struct product { } }; -/*! \brief sum reducer that ignores NaN values in the input */ struct nansum { /*! \brief do reduction into dst */ template @@ -294,7 +194,6 @@ struct nansum { } }; -/*! \brief product reducer that ignores NaN values in the input */ struct nanprod { /*! \brief do reduction into dst */ template @@ -495,7 +394,106 @@ struct nrmlp { } }; -/*! \brief arg max reducer */ +} // namespace red +)code"; + +const char logic_reducer[] = R"code( +namespace red { + +struct maximum { + /*! \brief do reduction into dst */ + template + __device__ inline static void Reduce(volatile DType& dst, volatile DType2 src) { // NOLINT(*) + if (!util::isnan(dst)) { + if (!(dst >= src)) dst = src; + } + } + /*! \brief do reduction into dst */ + template + __device__ inline static void Reduce(volatile DType& dst, volatile DType2 src, + volatile DType& none) { + Reduce(dst, src); + } + /*! \brief combine the results of two reducers */ + template + __device__ inline static void Merge(volatile DType& dst_val, volatile DType& src_val) { + Reduce(dst_val, src_val); + } + /*! \brief combine the results of two reducers */ + template + __device__ inline static void Merge(volatile DType& dst_val, volatile DType& dst_residual, + volatile DType& src_val, volatile DType& src_residual) { + Reduce(dst_val, src_val); + } + /*! \brief finalize reduction result */ + template + __device__ inline static void Finalize(volatile DType& dst) {} + /*! \brief finalize reduction result */ + template + __device__ inline static void Finalize(volatile DType& dst, volatile DType& none) {} + /*! + *\brief set the initial value during reduction + */ + template + __device__ inline static void SetInitValue(DType &initv) { + initv = limits::NegInfValue(); + } + /*! + *\brief set the initial value during reduction + */ + template + __device__ inline static void SetInitValue(DType &initv, DType &none) { + SetInitValue(initv); + } +}; + +struct minimum { + /*! \brief do reduction into dst */ + template + __device__ inline static void Reduce(volatile DType& dst, volatile DType2 src) { + if (!util::isnan(dst)) { + if (!(dst <= src)) dst = src; + } + } + /*! \brief do reduction into dst */ + template + __device__ inline static void Reduce(volatile DType& dst, volatile DType2 src, + volatile DType& none) { + Reduce(dst, src); + } + /*! 
\brief combine the results of two reducers */ + template + __device__ inline static void Merge(volatile DType& dst_val, volatile DType& src_val) { + Reduce(dst_val, src_val); + } + /*! \brief combine the results of two reducers */ + template + __device__ inline static void Merge(volatile DType& dst_val, volatile DType& dst_residual, + volatile DType& src_val, volatile DType& src_residual) { + Reduce(dst_val, src_val); + } + /*! \brief finalize reduction result */ + template + __device__ inline static void Finalize(volatile DType& dst) {} + /*! \brief finalize reduction result */ + template + __device__ inline static void Finalize(volatile DType& dst, volatile DType& none) {} + /*! + *\brief set the initial value during reduction + */ + template + __device__ inline static void SetInitValue(DType &initv) { + initv = limits::PosInfValue(); + } + /*! + *\brief set the initial value during reduction + */ + template + __device__ inline static void SetInitValue(DType &initv, DType &none) { + SetInitValue(initv); + } +}; + struct argmax { /*! \brief do reduction into dst */ template @@ -610,9 +608,7 @@ struct argmin { } }; } // namespace red - )code"; - } // namespace rtc } // namespace cuda } // namespace common From 8149bbd4be66eadfba868444f9d2ebca5b4e2e37 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Thu, 5 Nov 2020 13:00:12 -0800 Subject: [PATCH 13/15] Do not use the default stream in kron --- src/operator/numpy/np_kron-inl.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/operator/numpy/np_kron-inl.h b/src/operator/numpy/np_kron-inl.h index 49aee60b578c..bf0983f10157 100644 --- a/src/operator/numpy/np_kron-inl.h +++ b/src/operator/numpy/np_kron-inl.h @@ -246,7 +246,7 @@ void KronOpBackwardImpl(const OpContext& ctx, ctx.requested[0].get_space_typed(Shape1(workspace_size), s); Tensor temp(reinterpret_cast(workspace.dptr_ + aligned_first_workspace_size), - Shape1(ograd.shape_.Size())); + Shape1(ograd.shape_.Size()), s); ASSIGN_DISPATCH(temp, kWriteTo, tensor_ * ograd_); NP_KRON_REDUCE_AXES(true, workspace, ctx, {TBlob(temp)}, {scalar_req}, From 9194ea9db2726e6090a721a4031c49c0ee0e3595 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Fri, 6 Nov 2020 15:41:27 -0800 Subject: [PATCH 14/15] Fix argmin/argmax --- src/operator/tensor/reduce_rtc.cc | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/operator/tensor/reduce_rtc.cc b/src/operator/tensor/reduce_rtc.cc index 6fef0e002a46..ac39f6f09dc7 100644 --- a/src/operator/tensor/reduce_rtc.cc +++ b/src/operator/tensor/reduce_rtc.cc @@ -253,9 +253,6 @@ __global__ void reduce_kernel_multi(const int N, const int M, )code"; const char reduce_lines_kernel_code[] = R"code( -using MixedType = typename type_util::mixed_type::type; -using AType = typename AccType::type; - __launch_bounds__(kRTCMaxThreadsPerBlock) __global__ void reduce_lines_kernel(const index_t N, const index_t M, const index_t small_in_stride, @@ -385,7 +382,7 @@ void RTCReduceImpl(Stream *s, const TBlob& small, const bool addto, args.emplace_back(&small_dptr); args.emplace_back(&small.dptr_); - auto reduce_lines_kernel_func = get_function(code, + auto reduce_lines_kernel_func = get_function(code + function_code, "reduce_lines_kernel", reduce_lines_kernel_code, dev_id); From 2ff74b2a95a19ce9561bd22b629a93a0e64faa52 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Thu, 22 Apr 2021 13:36:56 -0700 Subject: [PATCH 15/15] Fix layernorm --- src/operator/nn/layer_norm-inl.h | 204 ------------------------------- src/operator/nn/layer_norm.cc | 
116 ++++++++++++++++++ src/operator/nn/layer_norm.cu | 83 +++++++++++++ 3 files changed, 199 insertions(+), 204 deletions(-) diff --git a/src/operator/nn/layer_norm-inl.h b/src/operator/nn/layer_norm-inl.h index b21cc880b908..79d09063ee6c 100644 --- a/src/operator/nn/layer_norm-inl.h +++ b/src/operator/nn/layer_norm-inl.h @@ -240,210 +240,6 @@ void LayerNormGradComputeGeneralImpl(const nnvm::NodeAttrs& attrs, const mxnet::TShape& red_exclude_src_shape, const int channel_size); -#ifndef __CUDACC__ -template <> -void LayerNormGradComputeGeneralImpl(const nnvm::NodeAttrs& attrs, - const OpContext& ctx, - const TBlob& ograd, - const TBlob& data, - const TBlob& gamma, - const TBlob& mean, - const TBlob& std, - const TBlob& normalized_data, - const TBlob& ograd_mult, - const TBlob& red_out, - const std::vector& req, - const std::vector& outputs, - const mshadow::Tensor& workspace, - const mxnet::TShape& red_dst_shape, - const mxnet::TShape& red_src_shape, - const mxnet::TShape& red_exclude_dst_shape, - const mxnet::TShape& red_exclude_src_shape, - const int channel_size) { - using namespace mshadow; - using namespace mshadow::expr; - Stream *s = ctx.get_stream(); - // Compute normalized_data = (data - mean) / std - BinaryBroadcastCompute(attrs, ctx, - {data, mean}, - {kWriteTo}, {normalized_data}); - BinaryBroadcastCompute(attrs, ctx, - {normalized_data, std}, - {kWriteTo}, {normalized_data}); - // Calculate grad_beta - bool safe_acc = dmlc::GetEnv("MXNET_SAFE_ACCUMULATION", true); - if (req[2] != kNullOp) { - MSHADOW_REAL_TYPE_SWITCH(outputs[2].type_flag_, DType, { - BROADCAST_NDIM_SWITCH(red_exclude_dst_shape.ndim(), NDim, { - if (!safe_acc) { - broadcast::Reduce( - s, outputs[2].reshape(red_exclude_dst_shape), req[2], workspace, - ograd.reshape(red_exclude_src_shape)); - } else { - broadcast::Reduce( - s, outputs[2].reshape(red_exclude_dst_shape), req[2], workspace, - ograd.reshape(red_exclude_src_shape)); - } - }); - }); - } - // Calculate grad_gamma, it will be sum(ograd * normalized_data, exclude_axis) - ElemwiseBinaryOp::Compute(attrs, ctx, {normalized_data, ograd}, - {kWriteTo}, {ograd_mult}); - if (req[1] != kNullOp) { - MSHADOW_REAL_TYPE_SWITCH(outputs[1].type_flag_, DType, { - BROADCAST_NDIM_SWITCH(red_exclude_dst_shape.ndim(), NDim, { - if (!safe_acc) { - broadcast::Reduce( - s, outputs[1].reshape(red_exclude_dst_shape), req[1], workspace, - ograd_mult.reshape(red_exclude_src_shape)); - } else { - broadcast::Reduce( - s, outputs[1].reshape(red_exclude_dst_shape), req[1], workspace, - ograd_mult.reshape(red_exclude_src_shape)); - } - }); - }); - } - // Calculate grad_data: - // ograd_mult = ograd * gamma / std - // grad_data = ograd_mult - mean(ograd_mult, axis) - // + normalized_data * (-mean(normalized_data * ograd_mult, axis)) - if (req[0] != kNullOp) { - BinaryBroadcastCompute(attrs, ctx, - {ograd, gamma}, - {kWriteTo}, {ograd_mult}); - BinaryBroadcastCompute(attrs, ctx, - {ograd_mult, std}, - {kWriteTo}, {ograd_mult}); - MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, { - BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, { - if (!safe_acc) { - broadcast::Reduce( - s, red_out.reshape(red_dst_shape), kWriteTo, workspace, - ograd_mult.reshape(red_src_shape)); - } else { - broadcast::Reduce( - s, red_out.reshape(red_dst_shape), kWriteTo, workspace, - ograd_mult.reshape(red_src_shape)); - } - }); - Tensor red_out_tensor = red_out.FlatTo1D(s); - red_out_tensor /= scalar(channel_size); - }); - BinaryBroadcastCompute(attrs, ctx, - {ograd_mult, red_out}, - {req[0]}, {outputs[0]}); 
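The surrounding broadcast/reduce calls implement the gradient documented in the `// Calculate grad_data` comment above: ograd_mult = ograd * gamma / std, then grad_data = ograd_mult - mean(ograd_mult, axis) - normalized_data * mean(normalized_data * ograd_mult, axis). Below is a minimal scalar sketch of that same formula, assuming a row-major [nrows, nchannel] layout normalized over the last axis; the function and parameter names are illustrative only, not MXNet APIs.

#include <cstddef>
#include <vector>

// Illustrative reference for the grad_data formula documented above (not MXNet
// code): layout is [nrows, nchannel], normalization over the last axis,
// mean/std_dev hold one value per row, gamma one value per channel.
void layer_norm_grad_data_ref(const std::vector<float>& ograd,
                              const std::vector<float>& data,
                              const std::vector<float>& gamma,
                              const std::vector<float>& mean,
                              const std::vector<float>& std_dev,
                              std::size_t nrows, std::size_t nchannel,
                              std::vector<float>* grad_data) {
  grad_data->assign(nrows * nchannel, 0.f);
  for (std::size_t r = 0; r < nrows; ++r) {
    float sum_om = 0.f, sum_nom = 0.f;
    for (std::size_t c = 0; c < nchannel; ++c) {
      const std::size_t i = r * nchannel + c;
      const float norm = (data[i] - mean[r]) / std_dev[r];  // normalized_data
      const float om = ograd[i] * gamma[c] / std_dev[r];    // ograd_mult
      sum_om += om;
      sum_nom += norm * om;
    }
    const float mean_om = sum_om / nchannel;
    const float mean_nom = sum_nom / nchannel;
    for (std::size_t c = 0; c < nchannel; ++c) {
      const std::size_t i = r * nchannel + c;
      const float norm = (data[i] - mean[r]) / std_dev[r];
      const float om = ograd[i] * gamma[c] / std_dev[r];
      // grad_data = ograd_mult - mean(ograd_mult)
      //             - normalized * mean(normalized * ograd_mult)
      (*grad_data)[i] = om - mean_om - norm * mean_nom;
    }
  }
}

In the MXNet code the second mean is obtained by dividing red_out by -channel_size and accumulating with kAddTo, which is equivalent to the explicit subtraction in the sketch.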
- ElemwiseBinaryOp::Compute(attrs, ctx, {ograd_mult, normalized_data}, - {kWriteTo}, {ograd_mult}); - MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, { - BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, { - if (!safe_acc) { - broadcast::Reduce( - s, red_out.reshape(red_dst_shape), kWriteTo, workspace, - ograd_mult.reshape(red_src_shape)); - } else { - broadcast::Reduce( - s, red_out.reshape(red_dst_shape), kWriteTo, workspace, - ograd_mult.reshape(red_src_shape)); - } - }); - Tensor red_out_tensor = red_out.FlatTo1D(s); - red_out_tensor /= scalar(- channel_size); - }); - BinaryBroadcastCompute(attrs, ctx, - {normalized_data, red_out}, - {kAddTo}, {outputs[0]}); - } -} - -#else - -template <> -void LayerNormGradComputeGeneralImpl(const nnvm::NodeAttrs& attrs, - const OpContext& ctx, - const TBlob& ograd, - const TBlob& data, - const TBlob& gamma, - const TBlob& mean, - const TBlob& std, - const TBlob& normalized_data, - const TBlob& ograd_mult, - const TBlob& red_out, - const std::vector& req, - const std::vector& outputs, - const mshadow::Tensor& workspace, - const mxnet::TShape& red_dst_shape, - const mxnet::TShape& red_src_shape, - const mxnet::TShape& red_exclude_dst_shape, - const mxnet::TShape& red_exclude_src_shape, - const int channel_size) { - using namespace mshadow; - using namespace mshadow::expr; - Stream *s = ctx.get_stream(); - // Compute normalized_data = (data - mean) / std - BinaryBroadcastRTCCompute {"sub"}(attrs, ctx, - {data, mean}, - {kWriteTo}, {normalized_data}); - BinaryBroadcastRTCCompute {"div"}(attrs, ctx, - {normalized_data, std}, - {kWriteTo}, {normalized_data}); - // Calculate grad_beta - if (req[2] != kNullOp) { - BROADCAST_NDIM_SWITCH(red_exclude_dst_shape.ndim(), NDim, { - broadcast::RTCReduce(ctx, outputs[2].reshape(red_exclude_dst_shape), req[2], workspace, - ograd.reshape(red_exclude_src_shape), "red::sum{}", NDim, "identity"); - }); - } - // Calculate grad_gamma, it will be sum(ograd * normalized_data, exclude_axis) - ElemwiseBinaryRTCCompute {"mul"}(attrs, ctx, {normalized_data, ograd}, - {kWriteTo}, {ograd_mult}); - if (req[1] != kNullOp) { - BROADCAST_NDIM_SWITCH(red_exclude_dst_shape.ndim(), NDim, { - broadcast::RTCReduce(ctx, outputs[1].reshape(red_exclude_dst_shape), req[1], workspace, - ograd_mult.reshape(red_exclude_src_shape), "red::sum{}", NDim, - "identity"); - }); - } - // Calculate grad_data: - // ograd_mult = ograd * gamma / std - // grad_data = ograd_mult - mean(ograd_mult, axis) - // + normalized_data * (-mean(normalized_data * ograd_mult, axis)) - if (req[0] != kNullOp) { - BinaryBroadcastRTCCompute {"mul"}(attrs, ctx, - {ograd, gamma}, - {kWriteTo}, {ograd_mult}); - BinaryBroadcastRTCCompute {"div"}(attrs, ctx, - {ograd_mult, std}, - {kWriteTo}, {ograd_mult}); - BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, { - broadcast::RTCReduce(ctx, red_out.reshape(red_dst_shape), kWriteTo, workspace, - ograd_mult.reshape(red_src_shape), "red::sum{}", NDim, "identity"); - }); - MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, { - Tensor red_out_tensor = red_out.FlatTo1D(s); - red_out_tensor /= scalar(channel_size); - }); - BinaryBroadcastRTCCompute {"sub"}(attrs, ctx, - {ograd_mult, red_out}, - {req[0]}, {outputs[0]}); - ElemwiseBinaryRTCCompute {"mul"}(attrs, ctx, {ograd_mult, normalized_data}, - {kWriteTo}, {ograd_mult}); - BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, { - broadcast::RTCReduce(ctx, red_out.reshape(red_dst_shape), kWriteTo, workspace, - ograd_mult.reshape(red_src_shape), "red::sum{}", NDim, "identity"); - }); 
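Similarly, the grad_beta and grad_gamma reductions over the excluded axes earlier in this function amount to per-channel sums of ograd and of ograd * normalized_data. A scalar sketch under the same assumed [nrows, nchannel] layout as above (illustrative names, not MXNet code):

#include <cstddef>
#include <vector>

// grad_beta[c]  = sum over rows of ograd
// grad_gamma[c] = sum over rows of ograd * normalized_data
void layer_norm_grad_gamma_beta_ref(const std::vector<float>& ograd,
                                    const std::vector<float>& data,
                                    const std::vector<float>& mean,
                                    const std::vector<float>& std_dev,
                                    std::size_t nrows, std::size_t nchannel,
                                    std::vector<float>* grad_gamma,
                                    std::vector<float>* grad_beta) {
  grad_gamma->assign(nchannel, 0.f);
  grad_beta->assign(nchannel, 0.f);
  for (std::size_t r = 0; r < nrows; ++r) {
    for (std::size_t c = 0; c < nchannel; ++c) {
      const std::size_t i = r * nchannel + c;
      const float norm = (data[i] - mean[r]) / std_dev[r];
      (*grad_beta)[c]  += ograd[i];
      (*grad_gamma)[c] += ograd[i] * norm;
    }
  }
}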
- MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, { - Tensor red_out_tensor = red_out.FlatTo1D(s); - red_out_tensor /= scalar(- channel_size); - }); - BinaryBroadcastRTCCompute {"mul"}(attrs, ctx, - {normalized_data, red_out}, - {kAddTo}, {outputs[0]}); - } -} -#endif - /* Calculate the gradient of layer normalization. We have the following gradient for gamma, beta and x: diff --git a/src/operator/nn/layer_norm.cc b/src/operator/nn/layer_norm.cc index 08847205f155..1a040fa6f7d0 100644 --- a/src/operator/nn/layer_norm.cc +++ b/src/operator/nn/layer_norm.cc @@ -268,6 +268,122 @@ void LayerNormCompute(const nnvm::NodeAttrs& attrs, LayerNormComputeGeneral(attrs, ctx, inputs, req, outputs); } +template <> +void LayerNormGradComputeGeneralImpl(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const TBlob& ograd, + const TBlob& data, + const TBlob& gamma, + const TBlob& mean, + const TBlob& std, + const TBlob& normalized_data, + const TBlob& ograd_mult, + const TBlob& red_out, + const std::vector& req, + const std::vector& outputs, + const mshadow::Tensor& workspace, + const mxnet::TShape& red_dst_shape, + const mxnet::TShape& red_src_shape, + const mxnet::TShape& red_exclude_dst_shape, + const mxnet::TShape& red_exclude_src_shape, + const int channel_size) { + using namespace mshadow; + using namespace mshadow::expr; + Stream *s = ctx.get_stream(); + // Compute normalized_data = (data - mean) / std + BinaryBroadcastCompute(attrs, ctx, + {data, mean}, + {kWriteTo}, {normalized_data}); + BinaryBroadcastCompute(attrs, ctx, + {normalized_data, std}, + {kWriteTo}, {normalized_data}); + // Calculate grad_beta + bool safe_acc = dmlc::GetEnv("MXNET_SAFE_ACCUMULATION", true); + if (req[2] != kNullOp) { + MSHADOW_REAL_TYPE_SWITCH(outputs[2].type_flag_, DType, { + BROADCAST_NDIM_SWITCH(red_exclude_dst_shape.ndim(), NDim, { + if (!safe_acc) { + broadcast::Reduce( + s, outputs[2].reshape(red_exclude_dst_shape), req[2], workspace, + ograd.reshape(red_exclude_src_shape)); + } else { + broadcast::Reduce( + s, outputs[2].reshape(red_exclude_dst_shape), req[2], workspace, + ograd.reshape(red_exclude_src_shape)); + } + }); + }); + } + // Calculate grad_gamma, it will be sum(ograd * normalized_data, exclude_axis) + ElemwiseBinaryOp::Compute(attrs, ctx, {normalized_data, ograd}, + {kWriteTo}, {ograd_mult}); + if (req[1] != kNullOp) { + MSHADOW_REAL_TYPE_SWITCH(outputs[1].type_flag_, DType, { + BROADCAST_NDIM_SWITCH(red_exclude_dst_shape.ndim(), NDim, { + if (!safe_acc) { + broadcast::Reduce( + s, outputs[1].reshape(red_exclude_dst_shape), req[1], workspace, + ograd_mult.reshape(red_exclude_src_shape)); + } else { + broadcast::Reduce( + s, outputs[1].reshape(red_exclude_dst_shape), req[1], workspace, + ograd_mult.reshape(red_exclude_src_shape)); + } + }); + }); + } + // Calculate grad_data: + // ograd_mult = ograd * gamma / std + // grad_data = ograd_mult - mean(ograd_mult, axis) + // + normalized_data * (-mean(normalized_data * ograd_mult, axis)) + if (req[0] != kNullOp) { + BinaryBroadcastCompute(attrs, ctx, + {ograd, gamma}, + {kWriteTo}, {ograd_mult}); + BinaryBroadcastCompute(attrs, ctx, + {ograd_mult, std}, + {kWriteTo}, {ograd_mult}); + MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, { + BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, { + if (!safe_acc) { + broadcast::Reduce( + s, red_out.reshape(red_dst_shape), kWriteTo, workspace, + ograd_mult.reshape(red_src_shape)); + } else { + broadcast::Reduce( + s, red_out.reshape(red_dst_shape), kWriteTo, workspace, + 
ograd_mult.reshape(red_src_shape)); + } + }); + Tensor red_out_tensor = red_out.FlatTo1D(s); + red_out_tensor /= scalar(channel_size); + }); + BinaryBroadcastCompute(attrs, ctx, + {ograd_mult, red_out}, + {req[0]}, {outputs[0]}); + ElemwiseBinaryOp::Compute(attrs, ctx, {ograd_mult, normalized_data}, + {kWriteTo}, {ograd_mult}); + MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, { + BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, { + if (!safe_acc) { + broadcast::Reduce( + s, red_out.reshape(red_dst_shape), kWriteTo, workspace, + ograd_mult.reshape(red_src_shape)); + } else { + broadcast::Reduce( + s, red_out.reshape(red_dst_shape), kWriteTo, workspace, + ograd_mult.reshape(red_src_shape)); + } + }); + Tensor red_out_tensor = red_out.FlatTo1D(s); + red_out_tensor /= scalar(- channel_size); + }); + BinaryBroadcastCompute(attrs, ctx, + {normalized_data, red_out}, + {kAddTo}, {outputs[0]}); + } +} + template<> void LayerNormGradCompute(const nnvm::NodeAttrs& attrs, const OpContext& ctx, const std::vector& inputs, diff --git a/src/operator/nn/layer_norm.cu b/src/operator/nn/layer_norm.cu index a60df412299a..9a33e0665ff4 100644 --- a/src/operator/nn/layer_norm.cu +++ b/src/operator/nn/layer_norm.cu @@ -29,6 +29,89 @@ using namespace mshadow::cuda; namespace mxnet { namespace op { +template <> +void LayerNormGradComputeGeneralImpl(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const TBlob& ograd, + const TBlob& data, + const TBlob& gamma, + const TBlob& mean, + const TBlob& std, + const TBlob& normalized_data, + const TBlob& ograd_mult, + const TBlob& red_out, + const std::vector& req, + const std::vector& outputs, + const mshadow::Tensor& workspace, + const mxnet::TShape& red_dst_shape, + const mxnet::TShape& red_src_shape, + const mxnet::TShape& red_exclude_dst_shape, + const mxnet::TShape& red_exclude_src_shape, + const int channel_size) { + using namespace mshadow; + using namespace mshadow::expr; + Stream *s = ctx.get_stream(); + // Compute normalized_data = (data - mean) / std + BinaryBroadcastRTCCompute {"sub"}(attrs, ctx, + {data, mean}, + {kWriteTo}, {normalized_data}); + BinaryBroadcastRTCCompute {"div"}(attrs, ctx, + {normalized_data, std}, + {kWriteTo}, {normalized_data}); + // Calculate grad_beta + if (req[2] != kNullOp) { + BROADCAST_NDIM_SWITCH(red_exclude_dst_shape.ndim(), NDim, { + broadcast::RTCReduce(ctx, outputs[2].reshape(red_exclude_dst_shape), req[2], workspace, + ograd.reshape(red_exclude_src_shape), "red::sum{}", NDim, "identity"); + }); + } + // Calculate grad_gamma, it will be sum(ograd * normalized_data, exclude_axis) + ElemwiseBinaryRTCCompute {"mul"}(attrs, ctx, {normalized_data, ograd}, + {kWriteTo}, {ograd_mult}); + if (req[1] != kNullOp) { + BROADCAST_NDIM_SWITCH(red_exclude_dst_shape.ndim(), NDim, { + broadcast::RTCReduce(ctx, outputs[1].reshape(red_exclude_dst_shape), req[1], workspace, + ograd_mult.reshape(red_exclude_src_shape), "red::sum{}", NDim, + "identity"); + }); + } + // Calculate grad_data: + // ograd_mult = ograd * gamma / std + // grad_data = ograd_mult - mean(ograd_mult, axis) + // + normalized_data * (-mean(normalized_data * ograd_mult, axis)) + if (req[0] != kNullOp) { + BinaryBroadcastRTCCompute {"mul"}(attrs, ctx, + {ograd, gamma}, + {kWriteTo}, {ograd_mult}); + BinaryBroadcastRTCCompute {"div"}(attrs, ctx, + {ograd_mult, std}, + {kWriteTo}, {ograd_mult}); + BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, { + broadcast::RTCReduce(ctx, red_out.reshape(red_dst_shape), kWriteTo, workspace, + ograd_mult.reshape(red_src_shape), 
"red::sum{}", NDim, "identity"); + }); + MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, { + Tensor red_out_tensor = red_out.FlatTo1D(s); + red_out_tensor /= scalar(channel_size); + }); + BinaryBroadcastRTCCompute {"sub"}(attrs, ctx, + {ograd_mult, red_out}, + {req[0]}, {outputs[0]}); + ElemwiseBinaryRTCCompute {"mul"}(attrs, ctx, {ograd_mult, normalized_data}, + {kWriteTo}, {ograd_mult}); + BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, { + broadcast::RTCReduce(ctx, red_out.reshape(red_dst_shape), kWriteTo, workspace, + ograd_mult.reshape(red_src_shape), "red::sum{}", NDim, "identity"); + }); + MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, { + Tensor red_out_tensor = red_out.FlatTo1D(s); + red_out_tensor /= scalar(- channel_size); + }); + BinaryBroadcastRTCCompute {"mul"}(attrs, ctx, + {normalized_data, red_out}, + {kAddTo}, {outputs[0]}); + } +} template __device__ __forceinline__ DType warp_shfl(DType value, int src_lane, int width = 32, unsigned int mask = 0xffffffff) {