From 92f5c10b41b94ed5f8eec5f749d313d9cd105e35 Mon Sep 17 00:00:00 2001 From: yinwei <1871465933@qq.com> Date: Tue, 19 Sep 2023 20:48:08 +0800 Subject: [PATCH] Compile optimize frobenius_norm op and trace op (#57327) * add sqrtkerenl * update * add trace_kernel * update --------- Co-authored-by: niuliling123 --- paddle/phi/kernels/gpu/frobenius_norm_kernel.cu | 8 ++++---- paddle/phi/kernels/gpu/trace_kernel.cu | 5 +++-- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/paddle/phi/kernels/gpu/frobenius_norm_kernel.cu b/paddle/phi/kernels/gpu/frobenius_norm_kernel.cu index 9878aa6ee231d3..f2be0f073a87de 100644 --- a/paddle/phi/kernels/gpu/frobenius_norm_kernel.cu +++ b/paddle/phi/kernels/gpu/frobenius_norm_kernel.cu @@ -15,8 +15,10 @@ #include "paddle/phi/kernels/frobenius_norm_kernel.h" #include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/activation_kernel.h" #include "paddle/phi/kernels/funcs/activation_functor.h" #include "paddle/phi/kernels/gpu/reduce.h" + namespace phi { template @@ -30,10 +32,8 @@ void FrobeniusNormKernel(const Context& dev_ctx, auto out_dtype = x.dtype(); phi::Reduce( dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out); - std::vector ins = {out}; - std::vector outs = {out}; - auto functor = funcs::CudaSqrtFunctor(); - funcs::ElementwiseKernel(dev_ctx, ins, &outs, functor); + + SqrtKernel(dev_ctx, *out, out); } } // namespace phi diff --git a/paddle/phi/kernels/gpu/trace_kernel.cu b/paddle/phi/kernels/gpu/trace_kernel.cu index 304bf778094d3d..b5a38874ecfe4e 100644 --- a/paddle/phi/kernels/gpu/trace_kernel.cu +++ b/paddle/phi/kernels/gpu/trace_kernel.cu @@ -18,6 +18,7 @@ #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/funcs/diagonal.h" #include "paddle/phi/kernels/funcs/reduce_function.h" +#include "paddle/phi/kernels/reduce_sum_kernel.h" namespace phi { @@ -36,8 +37,8 @@ void TraceKernel(const Context& ctx, auto out_dim_size = out->dims().size(); if (out_dim_size == 0) out_dim_size = 1; reduce_dims.push_back(out_dim_size); - funcs::ReduceKernel>( - ctx, diag, out, kps::IdentityFunctor(), reduce_dims); + phi::SumKernel( + ctx, diag, reduce_dims, diag.dtype(), false, out); } else { phi::funcs::SetConstant functor; functor(ctx, out, static_cast(0));