diff --git a/paddle/phi/kernels/funcs/broadcast_function.h b/paddle/phi/kernels/funcs/broadcast_function.h index b473d68b68ba9..ecdfa7abcfd42 100644 --- a/paddle/phi/kernels/funcs/broadcast_function.h +++ b/paddle/phi/kernels/funcs/broadcast_function.h @@ -585,26 +585,16 @@ void BroadcastKernel(const KPDevice &ctx, Functor func) { std::vector dims_size; dims_size.reserve(ins.size()); - bool no_broadcast_flag = true; for (auto *in : ins) { - no_broadcast_flag &= ins[0]->dims() == in->dims(); dims_size.emplace_back(in->dims().size()); } - if (ins.size() > 0 && outs->size() > 0) { - no_broadcast_flag &= outs->at(0)->dims() == ins[0]->dims(); - } - - if (no_broadcast_flag) { - phi::funcs::ElementwiseKernel(ctx, ins, outs, func); - } else { - axis = axis == -1 - ? *std::max_element(dims_size.begin(), dims_size.end()) - - *std::min_element(dims_size.begin(), dims_size.end()) - : axis; - BroadcastKernelForDifferentVecSize( - ctx, ins, outs, axis, func); - } + axis = axis == -1 + ? *std::max_element(dims_size.begin(), dims_size.end()) - + *std::min_element(dims_size.begin(), dims_size.end()) + : axis; + BroadcastKernelForDifferentVecSize( + ctx, ins, outs, axis, func); } template diff --git a/paddle/phi/kernels/gpu/gelu_grad_kernel.cu b/paddle/phi/kernels/gpu/gelu_grad_kernel.cu index 1e21f8d4267bc..1f33d5c901f29 100644 --- a/paddle/phi/kernels/gpu/gelu_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/gelu_grad_kernel.cu @@ -81,11 +81,13 @@ void GeluGradKernel(const Context& dev_ctx, } } #endif - phi::funcs::BroadcastKernel( - dev_ctx, ins, &outs, 0, GeluWithApproximateGradFunctor()); + using Functor = GeluWithApproximateGradFunctor; + phi::funcs::ElementwiseKernel( + dev_ctx, ins, &outs, Functor()); } else { - phi::funcs::BroadcastKernel( - dev_ctx, ins, &outs, 0, GeluWithoutApproximateGradFunctor()); + using Functor = GeluWithoutApproximateGradFunctor; + phi::funcs::ElementwiseKernel( + dev_ctx, ins, &outs, Functor()); } } diff --git a/paddle/phi/kernels/gpu/gelu_kernel.cu b/paddle/phi/kernels/gpu/gelu_kernel.cu index ce6dda2d6cc65..00dc58df0d826 100644 --- a/paddle/phi/kernels/gpu/gelu_kernel.cu +++ b/paddle/phi/kernels/gpu/gelu_kernel.cu @@ -71,11 +71,13 @@ void GeluKernel(const Context& dev_ctx, } } #endif - phi::funcs::BroadcastKernel( - dev_ctx, ins, &outs, 0, GeluWithApproximateFunctor()); + using Functor = GeluWithApproximateFunctor; + phi::funcs::ElementwiseKernel( + dev_ctx, ins, &outs, Functor()); } else { - phi::funcs::BroadcastKernel( - dev_ctx, ins, &outs, 0, GeluWithoutApproximateFunctor()); + using Functor = GeluWithoutApproximateFunctor; + phi::funcs::ElementwiseKernel( + dev_ctx, ins, &outs, Functor()); } } diff --git a/paddle/phi/kernels/gpu/reduce_grad.h b/paddle/phi/kernels/gpu/reduce_grad.h index e1f7419fb7a01..ed6cc0c3c2022 100644 --- a/paddle/phi/kernels/gpu/reduce_grad.h +++ b/paddle/phi/kernels/gpu/reduce_grad.h @@ -43,22 +43,19 @@ void ReduceGrad(const GPUContext& dev_ctx, })); } -template class TransformOp> +template void ReduceGradKernel(const Context& dev_ctx, const DenseTensor& x, const DenseTensor& out_grad, const std::vector& dims, bool keep_dim, bool reduce_all, - DenseTensor* x_grad) { + DenseTensor* x_grad, + Functor functor) { auto* in_x = &x; auto* d_out = &out_grad; auto* d_x = x_grad; - auto pt_out_dtype = x.dtype(); - // get reduce_dim and reduce_num for reduce_mean_grad int dim_size = in_x->dims().size(); std::vector reduce_dims = @@ -79,14 +76,10 @@ void ReduceGradKernel(const Context& dev_ctx, auto pt_d_out = new_d_out; auto pt_d_x = *d_x; - using MPType = typename kps::details::MPTypeTrait::Type; - - phi::ReduceGrad>( - dev_ctx, - &pt_d_out, - &pt_d_x, - pt_out_dtype, - TransformOp(reduce_num)); + std::vector inputs = {&pt_d_out}; + std::vector outputs = {&pt_d_x}; + funcs::BroadcastKernel( + dev_ctx, inputs, &outputs, 0, functor); } } // namespace phi diff --git a/paddle/phi/kernels/gpu/reduce_mean_grad_kernel.cu b/paddle/phi/kernels/gpu/reduce_mean_grad_kernel.cu index b81a5e50dfb3e..50564a339ddc0 100644 --- a/paddle/phi/kernels/gpu/reduce_mean_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/reduce_mean_grad_kernel.cu @@ -29,8 +29,23 @@ void ReduceMeanGradKernel(const Context& dev_ctx, bool keep_dim, bool reduce_all, DenseTensor* x_grad) { - ReduceGradKernel( - dev_ctx, x, out_grad, dims, keep_dim, reduce_all, x_grad); + int dim_size = x.dims().size(); + std::vector reduce_dims = + funcs::details::GetReduceDim(dims, dim_size, reduce_all); + int reduce_num = 1; + for (auto i : reduce_dims) { + reduce_num *= (x.dims())[i]; + } + using MPType = typename kps::details::MPTypeTrait::Type; + ReduceGradKernel>( + dev_ctx, + x, + out_grad, + dims, + keep_dim, + reduce_all, + x_grad, + kps::DivideFunctor(reduce_num)); } } // namespace phi diff --git a/paddle/phi/kernels/gpu/reduce_sum_grad_kernel.cu b/paddle/phi/kernels/gpu/reduce_sum_grad_kernel.cu index 1ad6b8fefe7e4..8b111641cfa40 100644 --- a/paddle/phi/kernels/gpu/reduce_sum_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/reduce_sum_grad_kernel.cu @@ -29,8 +29,40 @@ void ReduceSumGradKernel(const Context& dev_ctx, bool keep_dim, bool reduce_all, DenseTensor* x_grad) { - ReduceGradKernel( - dev_ctx, x, out_grad, dims, keep_dim, reduce_all, x_grad); + using MPType = typename kps::details::MPTypeTrait::Type; + auto out_dtype = x.dtype(); + auto* in_x = &x; + auto* d_out = &out_grad; + auto* d_x = x_grad; + + // get reduce_dim and reduce_num for reduce_mean_grad + int dim_size = in_x->dims().size(); + std::vector reduce_dims = + funcs::details::GetReduceDim(dims, dim_size, reduce_all); + + auto update_dims = vectorize(d_x->dims()); + int reduce_num = 1; + for (auto i : reduce_dims) { + reduce_num *= (in_x->dims())[i]; + update_dims[i] = 1; + } + // make new tensor + DenseTensor new_d_out(d_out->dtype()); + new_d_out.ShareDataWith(*d_out); + new_d_out.Resize(phi::make_ddim(update_dims)); + + dev_ctx.Alloc(d_x, x.dtype()); + auto pt_out_dtype = x.dtype(); + auto pt_d_out = new_d_out; + auto pt_d_x = *d_x; + std::vector inputs = {&pt_d_out}; + std::vector outputs = {&pt_d_x}; + phi::ReduceGrad>( + dev_ctx, + &pt_d_out, + &pt_d_x, + pt_out_dtype, + kps::IdentityFunctor()); } } // namespace phi @@ -48,4 +80,3 @@ PD_REGISTER_KERNEL(sum_grad, int64_t, phi::dtype::complex, phi::dtype::complex) {} - diff --git a/paddle/phi/kernels/gpu/where_kernel.cu b/paddle/phi/kernels/gpu/where_kernel.cu index a0be388065f4b..441be02b99efa 100644 --- a/paddle/phi/kernels/gpu/where_kernel.cu +++ b/paddle/phi/kernels/gpu/where_kernel.cu @@ -40,8 +40,7 @@ void WhereKernel(const Context& ctx, ctx.template Alloc(out); CondFunctor func; - funcs::BroadcastKernel( - ctx, ins, &outs, -1, func); + funcs::ElementwiseKernel, 1>(ctx, ins, &outs, func); } } // namespace phi diff --git a/paddle/phi/kernels/kps/bitwise_kernel.cu b/paddle/phi/kernels/kps/bitwise_kernel.cu index 44859785f2fb8..285b18927af80 100644 --- a/paddle/phi/kernels/kps/bitwise_kernel.cu +++ b/paddle/phi/kernels/kps/bitwise_kernel.cu @@ -51,9 +51,9 @@ void BitwiseNotKernel(const Context& dev_ctx, dev_ctx.template Alloc(out); std::vector ins = {&x}; std::vector outs = {out}; - funcs::BitwiseNotFunctor func; - funcs::BroadcastKernel( - dev_ctx, ins, &outs, -1, func); + funcs::BitwiseNotFunctor unary_func; + funcs::ElementwiseKernel>( + dev_ctx, ins, &outs, unary_func); } } // namespace phi