diff --git a/paddle/phi/kernels/funcs/dropout_impl.cu.h b/paddle/phi/kernels/funcs/dropout_impl.cu.h index a1fc2c225ecf2a..23756b3bdde960 100644 --- a/paddle/phi/kernels/funcs/dropout_impl.cu.h +++ b/paddle/phi/kernels/funcs/dropout_impl.cu.h @@ -34,6 +34,7 @@ limitations under the License. */ #include "paddle/phi/kernels/funcs/functors.h" #include "paddle/phi/kernels/primitive/compute_primitives.h" #include "paddle/phi/kernels/primitive/datamover_primitives.h" +#include "paddle/phi/kernels/scale_kernel.h" namespace phi { namespace funcs { @@ -255,17 +256,6 @@ __global__ void VectorizedGeneratorMask(const size_t n, } } -template -void ScaleByDropoutFactor(const phi::GPUContext& dev_ctx, - const phi::DenseTensor& x, - phi::DenseTensor* y, - MT factor) { - std::vector ins = {&x}; - std::vector outs = {y}; - auto functor = phi::funcs::ScaleFunctor(factor); - phi::funcs::ElementwiseKernel(dev_ctx, ins, &outs, functor); -} - template void DropoutFwGPUKernelDriver( const phi::GPUContext& dev_ctx, @@ -389,7 +379,7 @@ void DropoutFwGPUKernelDriver( using MT = typename phi::dtype::MPTypeTrait::Type; MT factor = static_cast(1.0f - dropout_prob); // y = factor * x - ScaleByDropoutFactor(dev_ctx, x, y, factor); + phi::ScaleKernel(dev_ctx, x, factor, 0.0f, false, y); } } } @@ -425,7 +415,8 @@ void DropoutGradGPUKernelDriver(const phi::GPUContext& dev_ctx, if (is_test) { MT factor = static_cast(upscale_in_train ? 1.0f : 1.0f - dropout_prob); // y = factor * x - ScaleByDropoutFactor(dev_ctx, grad_y, grad_x, factor); + phi::ScaleKernel( + dev_ctx, grad_y, factor, 0.0f, false, grad_x); } else { if (upscale_in_train && dropout_prob == 1.0f) { #ifdef PADDLE_WITH_HIP diff --git a/paddle/phi/kernels/sparse/gpu/unary_kernel.cu b/paddle/phi/kernels/sparse/gpu/unary_kernel.cu index 3b6e84664f93cd..22392a9fea5d5f 100644 --- a/paddle/phi/kernels/sparse/gpu/unary_kernel.cu +++ b/paddle/phi/kernels/sparse/gpu/unary_kernel.cu @@ -17,22 +17,12 @@ #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/funcs/elementwise_base.h" +#include "paddle/phi/kernels/scale_kernel.h" #include "paddle/phi/kernels/sparse/impl/unary_kernel_impl.h" namespace phi { namespace sparse { -template -struct DivScalarFunctor { - T value_; - - explicit DivScalarFunctor(T value) : value_(value) {} - - __device__ __forceinline__ T operator()(const T x) const { - return x / value_; - } -}; - template void DivScalarCooKernel(const Context& dev_ctx, const SparseCooTensor& x, @@ -40,10 +30,8 @@ void DivScalarCooKernel(const Context& dev_ctx, SparseCooTensor* out) { EmptyLikeCooKernel(dev_ctx, x, out); - std::vector ins = {&(x.values())}; - std::vector outs = {out->mutable_values()}; - DivScalarFunctor func(static_cast(scalar)); - funcs::ElementwiseKernel>(dev_ctx, ins, &outs, func); + phi::ScaleKernel( + dev_ctx, x.values(), 1 / scalar, 0.0f, false, out->mutable_values()); } template @@ -53,10 +41,8 @@ void DivScalarCsrKernel(const Context& dev_ctx, SparseCsrTensor* out) { EmptyLikeCsrKernel(dev_ctx, x, out); - std::vector ins = {&(x.values())}; - std::vector outs = {out->mutable_values()}; - DivScalarFunctor func(static_cast(scalar)); - funcs::ElementwiseKernel>(dev_ctx, ins, &outs, func); + phi::ScaleKernel( + dev_ctx, x.values(), 1 / scalar, 0.0f, false, out->mutable_values()); } } // namespace sparse