Skip to content

Commit

Permalink
refactor (#60968)
Browse files Browse the repository at this point in the history
  • Loading branch information
cocoshe authored Jan 19, 2024
1 parent f86f9dd commit 99717e9
Show file tree
Hide file tree
Showing 14 changed files with 258 additions and 394 deletions.
51 changes: 0 additions & 51 deletions paddle/phi/kernels/copysign_grad_kernel.h

This file was deleted.

61 changes: 0 additions & 61 deletions paddle/phi/kernels/copysign_kernel.h

This file was deleted.

83 changes: 0 additions & 83 deletions paddle/phi/kernels/cpu/copysign_grad_kernel.cc

This file was deleted.

52 changes: 0 additions & 52 deletions paddle/phi/kernels/cpu/copysign_kernel.cc

This file was deleted.

37 changes: 37 additions & 0 deletions paddle/phi/kernels/cpu/elementwise_grad_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,28 @@ void MinimumGradKernel(const Context& dev_ctx,
dev_ctx, x, y, dout, dout, axis, dx, dy, MinGradDx<T>(), MinGradDy<T>());
}

template <typename T, typename Context>
void CopySignGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& out_grad,
DenseTensor* x_grad,
DenseTensor* y_grad) {
funcs::ElementwiseGradPreProcess(out_grad, x_grad);
int axis = -1;
phi::funcs::
ElemwiseGradCompute<Context, T, CopySignGradDX<T>, CopySignGradDY<T>>(
dev_ctx,
x,
y,
out_grad,
out_grad,
axis,
x_grad,
y_grad,
CopySignGradDX<T>(),
CopySignGradDY<T>());
}
} // namespace phi

PD_REGISTER_KERNEL(fmax_grad,
Expand Down Expand Up @@ -107,3 +129,18 @@ PD_REGISTER_KERNEL(elementwise_pow_grad,
int,
int64_t,
phi::dtype::bfloat16) {}

PD_REGISTER_KERNEL(copysign_grad,
CPU,
ALL_LAYOUT,
phi::CopySignGradKernel,
bool,
uint8_t,
int8_t,
int16_t,
int,
int64_t,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16) {}
32 changes: 32 additions & 0 deletions paddle/phi/kernels/cpu/elementwise_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,23 @@ void HeavisideKernel(const Context& dev_ctx,
dev_ctx, x, y, funcs::ElementwiseHeavisideFunctor<T>(), out);
}

template <typename T, typename Context>
void CopySignKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* out) {
dev_ctx.template Alloc<T>(out);
auto x_dims = x.dims();
auto y_dims = y.dims();
if (x_dims.size() >= y_dims.size()) {
funcs::ElementwiseCompute<funcs::CopySignFunctor<T>, T>(
dev_ctx, x, y, funcs::CopySignFunctor<T>(), out);
} else {
funcs::ElementwiseCompute<funcs::InverseCopySignFunctor<T>, T>(
dev_ctx, x, y, funcs::InverseCopySignFunctor<T>(), out);
}
}

} // namespace phi

using complex64 = ::phi::dtype::complex<float>;
Expand Down Expand Up @@ -148,3 +165,18 @@ PD_REGISTER_KERNEL(heaviside,
double,
int,
int64_t) {}

PD_REGISTER_KERNEL(copysign,
CPU,
ALL_LAYOUT,
phi::CopySignKernel,
bool,
uint8_t,
int8_t,
int16_t,
int,
int64_t,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16) {}
9 changes: 9 additions & 0 deletions paddle/phi/kernels/elementwise_grad_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,4 +66,13 @@ void ElementwisePowGradKernel(const Context& dev_ctx,
const DenseTensor& dout,
DenseTensor* dx,
DenseTensor* dy);

template <typename T, typename Context>
void CopySignGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& out_grad,
DenseTensor* x_grad,
DenseTensor* y_grad);

} // namespace phi
6 changes: 6 additions & 0 deletions paddle/phi/kernels/elementwise_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,12 @@ void HeavisideKernel(const Context& dev_ctx,
const DenseTensor& y,
DenseTensor* out);

template <typename T, typename Context>
void CopySignKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* out);

template <typename T, typename Context>
DenseTensor Maximum(const Context& dev_ctx,
const DenseTensor& x,
Expand Down
Loading

0 comments on commit 99717e9

Please sign in to comment.