Skip to content

Commit

Permalink
Optimize compilation of frobenius_norm op and trace op (#57327)
Browse files Browse the repository at this point in the history
* add sqrt kernel

* update

* add trace_kernel

* update

---------

Co-authored-by: niuliling123 <sdust_niull@163.com>
  • Loading branch information
yinwei and AnnaTrainingG authored Sep 19, 2023
1 parent 86308a3 commit 92f5c10
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 6 deletions.
8 changes: 4 additions & 4 deletions paddle/phi/kernels/gpu/frobenius_norm_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,10 @@
#include "paddle/phi/kernels/frobenius_norm_kernel.h"

#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/activation_kernel.h"
#include "paddle/phi/kernels/funcs/activation_functor.h"
#include "paddle/phi/kernels/gpu/reduce.h"

namespace phi {

template <typename T, typename Context>
Expand All @@ -30,10 +32,8 @@ void FrobeniusNormKernel(const Context& dev_ctx,
auto out_dtype = x.dtype();
phi::Reduce<T, kps::AddFunctor, kps::SquareFunctor>(
dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out);
std::vector<const DenseTensor*> ins = {out};
std::vector<DenseTensor*> outs = {out};
auto functor = funcs::CudaSqrtFunctor<T>();
funcs::ElementwiseKernel<T>(dev_ctx, ins, &outs, functor);

SqrtKernel<T, Context>(dev_ctx, *out, out);
}

} // namespace phi
Expand Down
5 changes: 3 additions & 2 deletions paddle/phi/kernels/gpu/trace_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/diagonal.h"
#include "paddle/phi/kernels/funcs/reduce_function.h"
#include "paddle/phi/kernels/reduce_sum_kernel.h"

namespace phi {

Expand All @@ -36,8 +37,8 @@ void TraceKernel(const Context& ctx,
auto out_dim_size = out->dims().size();
if (out_dim_size == 0) out_dim_size = 1;
reduce_dims.push_back(out_dim_size);
funcs::ReduceKernel<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
ctx, diag, out, kps::IdentityFunctor<T>(), reduce_dims);
phi::SumKernel<T, Context>(
ctx, diag, reduce_dims, diag.dtype(), false, out);
} else {
phi::funcs::SetConstant<Context, T> functor;
functor(ctx, out, static_cast<T>(0));
Expand Down

0 comments on commit 92f5c10

Please sign in to comment.