diff --git a/paddle/phi/kernels/funcs/reduce_grad_functions.h b/paddle/phi/kernels/funcs/reduce_grad_functions.h
index 27de6176657e7f..e1de295be330b5 100644
--- a/paddle/phi/kernels/funcs/reduce_grad_functions.h
+++ b/paddle/phi/kernels/funcs/reduce_grad_functions.h
@@ -38,10 +38,10 @@ void ReduceGradFunctor(const Context& dev_ctx,
   auto x_dims = input0.dims();
   auto reduced_dims_v = common::vectorize(x_dims);
   std::vector<int> dims_ref = dims;
-  Eigen::array<int, D> broadcast_dim;
+  Eigen::array<int64_t, D> broadcast_dim;
   for (size_t i = 0; i < D; ++i) broadcast_dim[i] = 1;
 
-  int broad_cast_times = 1;
+  int64_t broad_cast_times = 1;
   for (size_t i = 0; i < dims_ref.size(); ++i) {
     if (dims_ref[i] < 0) {
       dims_ref[i] = x_rank + dims_ref[i];
@@ -142,7 +142,7 @@ void LaunchReduceGradKernel(const Context& dev_ctx,
     auto& place = *dev_ctx.eigen_device();
     auto broadcast_dim =
-        Eigen::array<int, 1>({{static_cast<int>(input0->numel())}});
+        Eigen::array<int64_t, 1>({{static_cast<int64_t>(input0->numel())}});
     functor(place,
             &x,
             &x_reduce,
diff --git a/paddle/phi/kernels/gpu/dist_kernel.cu b/paddle/phi/kernels/gpu/dist_kernel.cu
index a55458d59a2d57..1993caec70adb3 100644
--- a/paddle/phi/kernels/gpu/dist_kernel.cu
+++ b/paddle/phi/kernels/gpu/dist_kernel.cu
@@ -63,6 +63,18 @@ struct PowFunctor {
   Ty p_order_;
 };
 
+template <typename Tx,
+          typename Ty,
+          typename Tout>  // Tx is high precision, Tout is low/out precision
+struct PowFunctorHighPrecision {
+  HOSTDEVICE explicit inline PowFunctorHighPrecision(const Ty& p_order)
+      : p_order_(p_order) {}
+  HOSTDEVICE inline Tx operator()(const Tx x) const {
+    return static_cast<Tx>(pow(static_cast<Ty>(x), p_order_));
+  }
+  Ty p_order_;
+};
+
 template <typename T, typename Functor>
 __global__ void ReduceSumWithSubtract(
     const T* x, const T* y, T* out, int64_t N, Functor func) {
@@ -126,16 +138,17 @@ void DistKernel(const Context& dev_ctx,
   DenseTensor intermediate;
   const T* x_ptr = x.data<T>();
   const T* y_ptr = y.data<T>();
+  T* o_ptr = dev_ctx.template Alloc<T>(out);
   auto stream = dev_ctx.stream();
   auto xdim = x.dims();
   if (xdim == y.dims()) {  // same shape
-    auto n = x.numel();
+    int64_t n = x.numel();
+
     auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, n);
     intermediate.Resize(common::make_ddim({config.block_per_grid.x}));
     T* i_ptr = dev_ctx.template Alloc<T>(&intermediate);
-
     std::vector<int64_t> axis_dims = {static_cast<int64_t>(-1)};
     std::vector<int> reduce_axis =
         funcs::details::GetReduceDim(axis_dims, xdim.size(), true);
@@ -166,15 +179,23 @@ void DistKernel(const Context& dev_ctx,
       ReduceSumWithSubtract<T>
           <<<config.block_per_grid.x, config.thread_per_block.x, 0, stream>>>(
              x_ptr, y_ptr, i_ptr, n, OtherOrderFunctor<MT>(p_order));
-      phi::funcs::ReduceKernel<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
-          dev_ctx, intermediate, out, kps::IdentityFunctor<T>(), reduce_axis);
-
-      const DenseTensor* tmp_norm = out;
-      std::vector<const DenseTensor*> ins = {tmp_norm};
+      DenseTensor out_other;
+      out_other.Resize(out->dims());
+      dev_ctx.template Alloc<MT>(&out_other);
+
+      phi::funcs::
+          ReduceKernel<T, MT, kps::AddFunctor, kps::IdentityFunctor<MT>>(
+              dev_ctx,
+              intermediate,
+              &out_other,
+              kps::IdentityFunctor<MT>(),
+              reduce_axis);
+      std::vector<const DenseTensor*> ins = {&out_other};
       std::vector<DenseTensor*> outs = {out};
-      MT p_order_ = static_cast<MT>(static_cast<MT>(1.) / p_order);
+
+      MT p_order_ = static_cast<MT>(1.f / p_order);
       phi::funcs::ElementwiseKernel<T>(
-          dev_ctx, ins, &outs, PowFunctor<T, MT>(p_order_));
+          dev_ctx, ins, &outs, PowFunctorHighPrecision<MT, MT, T>(p_order_));
     }
   } else {
diff --git a/paddle/phi/kernels/gpu/p_norm_grad_kernel.cu b/paddle/phi/kernels/gpu/p_norm_grad_kernel.cu
index fdfed25b3dda8f..5efd6a36a5399f 100644
--- a/paddle/phi/kernels/gpu/p_norm_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/p_norm_grad_kernel.cu
@@ -42,10 +42,12 @@ struct AbsMaxAndMinGradFunctor {
 
 template <typename T>
 struct PNormGradFunctor {
+  using MT = typename phi::dtype::MPTypeTrait<T>::Type;
   HOSTDEVICE explicit inline PNormGradFunctor(float porder, float eps) {
-    this->porder = static_cast<T>(porder - 1.);
-    this->eps = static_cast<T>(eps);
+    this->porder = static_cast<MT>(porder - 1.);
+    this->eps = static_cast<MT>(eps);
   }
+
   template <typename Context,
             typename X,
             typename Y,
             typename DX,
             typename DY,
             typename Dim>
   void operator()(const Context& place,
                   X* x,
                   Y* y,
                   DX* dx,
                   DY* dy,
                   const Dim& dim,
                   int size) {
+    auto x_mt = x->template cast<MT>();
+    auto y_mt = y->template cast<MT>();
+    auto dy_mt = dy->template cast<MT>();
+
+    auto norm_pow = y_mt.pow(-this->porder);
+    auto mask_norm_nonzero = (y_mt != static_cast<MT>(0)).template cast<MT>();
+
+    // Set to 0 where porder < 0 and x == 0
+    MT zero = static_cast<MT>(0);
+    auto mask_x_zero = (x_mt == zero).template cast<MT>();
+
+    MT is_porder_negative =
+        this->porder < zero ? static_cast<MT>(1) : static_cast<MT>(0);
+    auto invalid_mask = (mask_x_zero * is_porder_negative);
+    auto safe_pow =
+        x_mt.abs().pow(this->porder) * (static_cast<MT>(1) - invalid_mask);
+
     dx->device(place) =
-        (*x).abs().pow(this->porder) * (*x).sign() * dy->broadcast(dim) *
-        (*y + y->constant(eps)).pow(-this->porder).broadcast(dim);
+        (safe_pow * x_mt.sign() * dy_mt.broadcast(dim) *
+         norm_pow.broadcast(dim) *
+         mask_norm_nonzero.broadcast(dim)  // Mask out positions where norm == 0
+         )
+            .template cast<T>();
   }
-  T porder;
-  T eps;
+
+  MT porder;
+  MT eps;
 };
 
 template <typename T>
diff --git a/paddle/phi/kernels/gpu/p_norm_kernel.cu b/paddle/phi/kernels/gpu/p_norm_kernel.cu
index 9b0515feb33544..8809b082b7a826 100644
--- a/paddle/phi/kernels/gpu/p_norm_kernel.cu
+++ b/paddle/phi/kernels/gpu/p_norm_kernel.cu
@@ -124,31 +124,38 @@ void PNormKernel(const Context& dev_ctx,
     phi::funcs::ElementwiseKernel<T>(
         dev_ctx, ins, &outs, UnsignedPowFunctor<T>(1. / porder));
 #else
+    DenseTensor out_temp;
+    out_temp.Resize(out_norm->dims());
+    dev_ctx.template Alloc<MT>(&out_temp);
+
     if (porder == 1.0) {
       // fast 1-norm
       phi::funcs::ReduceKernel<T, T, kps::AddFunctor, FabsFunctor<T>>(
           dev_ctx, *in_x, out_norm, FabsFunctor<T>(), reduce_axis);
     } else if (porder == 2.0) {
       // fast 2-norm
-      phi::funcs::ReduceKernel<T, T, kps::AddFunctor, SquareFunctor<T>>(
-          dev_ctx, *in_x, out_norm, SquareFunctor<T>(), reduce_axis);
+      phi::funcs::ReduceKernel<T, MT, kps::AddFunctor, SquareFunctor<T>>(
+          dev_ctx, *in_x, &out_temp, SquareFunctor<T>(), reduce_axis);
     } else if (porder == 3.0) {
       // fast 3-norm
-      phi::funcs::ReduceKernel<T, T, kps::AddFunctor, FabsCubicFunctor<T>>(
-          dev_ctx, *in_x, out_norm, FabsCubicFunctor<T>(), reduce_axis);
+      phi::funcs::ReduceKernel<T, MT, kps::AddFunctor, FabsCubicFunctor<T>>(
+          dev_ctx, *in_x, &out_temp, FabsCubicFunctor<T>(), reduce_axis);
     } else {
       // vanilla norm
-      phi::funcs::ReduceKernel<T, T, kps::AddFunctor, UnsignedPowFunctor<T>>(
-          dev_ctx, *in_x, out_norm, UnsignedPowFunctor<T>(porder), reduce_axis);
+      phi::funcs::ReduceKernel<T, MT, kps::AddFunctor, UnsignedPowFunctor<T>>(
+          dev_ctx,
+          *in_x,
+          &out_temp,
+          UnsignedPowFunctor<T>(porder),
+          reduce_axis);
     }
 
     if (porder != 1.0) {
-      // save computation when porder is 1.0
-      const DenseTensor* tmp_norm = out_norm;
-      std::vector<const DenseTensor*> ins = {tmp_norm};
+      std::vector<const DenseTensor*> ins = {&out_temp};
       std::vector<DenseTensor*> outs = {out_norm};
+      MT p_order_ = static_cast<MT>(1.f / porder);
       phi::funcs::ElementwiseKernel<T>(
-          dev_ctx, ins, &outs, UnsignedPowFunctor<T>(1. / porder));
+          dev_ctx, ins, &outs, UnsignedPowFunctor<T>(p_order_));
     }
 #endif
   }
 }
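
For reviewers who want to see the numerical idea outside of Paddle, here is a minimal standalone C++ sketch, not Paddle code (PowFunctorSketch and its template parameters are hypothetical names), of the pattern the dist/p_norm changes rely on: keep the reduced value and the final pow() in a higher-precision type, and cast to the storage dtype only once at the end instead of first rounding the reduced value to the low-precision output.

// Standalone sketch (not Paddle code): the names below are illustrative only.
// It mirrors the idea behind PowFunctorHighPrecision: evaluate pow() in a
// high-precision type and narrow to the output/storage type exactly once.
#include <cmath>
#include <cstdio>

template <typename Tx, typename Ty = Tx, typename Tout = Tx>
struct PowFunctorSketch {
  explicit PowFunctorSketch(Ty p_order) : p_order_(p_order) {}
  Tout operator()(Tx x) const {
    return static_cast<Tout>(std::pow(static_cast<Ty>(x), p_order_));
  }
  Ty p_order_;
};

int main() {
  // Finish a 3-norm: the reduced sum of |x_i|^3 is kept in double and the
  // 1/3 power is applied before narrowing to float for storage.
  double reduced_sum = 123456.789;
  PowFunctorSketch<double, double, float> finish(1.0 / 3.0);
  float norm = finish(reduced_sum);
  std::printf("3-norm ~= %f\n", norm);
  return 0;
}

With low-precision outputs such as fp16, applying pow(sum, 1/p) after the sum has already been rounded to the output dtype discards most of the benefit of accumulating in higher precision, so deferring the cast is the cheaper fix.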