Modify compare logical inplace (PaddlePaddle#56888)
* fix error

* fix compare

* fix

* fix

* remove fluid

* fix inplace test

* fix and separate inplace impl
GGBond8488 authored Sep 13, 2023
1 parent 9e35a25 commit 5d0b968
Showing 6 changed files with 208 additions and 113 deletions.
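The pattern repeated across the diffs below: each compare/logical kernel now checks whether `out` shares `x`'s allocation (`out->IsSharedWith(x)`) and routes in-place calls to a dedicated `Inplace*KernelImpl` that re-types the shared buffer to `bool`, instead of threading both cases through one implementation. Here is a minimal standalone sketch of that dispatch, using a toy `Tensor` in place of `phi::DenseTensor` (all names and types in the sketch are illustrative, not Paddle's API):

```cpp
#include <cstdio>
#include <memory>
#include <vector>

// Toy stand-in for phi::DenseTensor: IsSharedWith mimics the holder-sharing
// check the real kernels use to detect in-place execution.
struct Tensor {
  std::shared_ptr<std::vector<float>> holder;
  bool IsSharedWith(const Tensor& other) const {
    return holder != nullptr && holder == other.holder;
  }
};

// Out-of-place path: out owns a fresh bool buffer.
void CompareKernelImpl(const Tensor& x, const Tensor& y, Tensor* out) {
  std::puts("out-of-place compare");
}

// In-place path: out aliases x, so keep a shallow copy of the input first.
void InplaceCompareKernelImpl(const Tensor& x, const Tensor& y, Tensor* out) {
  Tensor x_origin = x;  // shares the holder; input stays readable
  std::printf("in-place compare (out aliases x: %d)\n",
              static_cast<int>(x_origin.IsSharedWith(*out)));
}

// Mirrors the dispatch the new DEFINE_COMPARE_KERNEL macro generates.
void LessThanKernel(const Tensor& x, const Tensor& y, Tensor* out) {
  if (out->IsSharedWith(x)) {
    InplaceCompareKernelImpl(x, y, out);
  } else {
    CompareKernelImpl(x, y, out);
  }
}

int main() {
  Tensor x{std::make_shared<std::vector<float>>(4, 1.0f)};
  Tensor y{std::make_shared<std::vector<float>>(4, 2.0f)};
  Tensor out{std::make_shared<std::vector<float>>(4)};
  LessThanKernel(x, y, &out);  // takes the out-of-place path
  LessThanKernel(x, y, &x);    // out is x itself: takes the in-place path
}
```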
70 changes: 42 additions & 28 deletions paddle/phi/kernels/cpu/compare_kernel.cc
@@ -30,22 +30,34 @@ inline void CompareKernelImpl(const Context& ctx,
                               const DenseTensor& y,
                               int axis,
                               DenseTensor* out) {
-  if (!out->IsSharedWith(x)) {
-    ctx.template Alloc<bool>(out);
-    if (x.dims().size() >= y.dims().size()) {
-      funcs::ElementwiseCompute<Functor, T, bool>(
-          ctx, x, y, Functor(), out, axis);
-    } else {
-      funcs::ElementwiseCompute<InverseFunctor, T, bool>(
-          ctx, x, y, InverseFunctor(), out, axis);
-    }
-  } else {
-    if (x.dims().size() >= y.dims().size()) {
-      funcs::ElementwiseCompute<Functor, T, T>(ctx, x, y, Functor(), out, axis);
-    } else {
-      funcs::ElementwiseCompute<InverseFunctor, T, T>(
-          ctx, x, y, InverseFunctor(), out, axis);
-    }
-  }
+  ctx.template Alloc<bool>(out);
+  if (x.dims().size() >= y.dims().size()) {
+    funcs::ElementwiseCompute<Functor, T, bool>(
+        ctx, x, y, Functor(), out, axis);
+  } else {
+    funcs::ElementwiseCompute<InverseFunctor, T, bool>(
+        ctx, x, y, InverseFunctor(), out, axis);
+  }
}
+
+template <typename T,
+          typename Context,
+          typename Functor,
+          typename InverseFunctor>
+inline void InplaceCompareKernelImpl(const Context& ctx,
+                                     const DenseTensor& x,
+                                     const DenseTensor& y,
+                                     int axis,
+                                     DenseTensor* out) {
+  auto x_origin = x;
+  out->set_type(phi::DataType::BOOL);
+  ctx.template Alloc<bool>(out);
+  if (x_origin.dims().size() >= y.dims().size()) {
+    funcs::ElementwiseCompute<Functor, T, bool>(
+        ctx, x_origin, y, Functor(), out, axis);
+  } else {
+    funcs::ElementwiseCompute<InverseFunctor, T, bool>(
+        ctx, x_origin, y, InverseFunctor(), out, axis);
+  }
+}
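A detail worth noting in `InplaceCompareKernelImpl` above: when the op runs in place, `x` and `out` can be the same tensor, so `out->set_type(phi::DataType::BOOL)` would also re-type the input. Copying `x` into `x_origin` first keeps the original dtype/shape descriptor (while still sharing the allocation) for the read side. A toy model of that invariant, with illustrative types rather than phi's:

```cpp
#include <cassert>
#include <memory>
#include <vector>

// Toy model of the aliasing hazard the x_origin copy guards against.
enum class DType { FLOAT32, BOOL };

struct Tensor {
  DType dtype = DType::FLOAT32;
  std::shared_ptr<std::vector<char>> holder;  // raw allocation
};

int main() {
  Tensor x{DType::FLOAT32, std::make_shared<std::vector<char>>(16)};
  Tensor* out = &x;          // in-place: out and x are the same object
  Tensor x_origin = x;       // shallow copy: own descriptor, shared holder
  out->dtype = DType::BOOL;  // models out->set_type(phi::DataType::BOOL)
  assert(x.dtype == DType::BOOL);            // the input got re-typed, but...
  assert(x_origin.dtype == DType::FLOAT32);  // ...the copy still reads as T
  assert(x_origin.holder == out->holder);    // same underlying buffer
  return 0;
}
```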

@@ -92,19 +104,21 @@ PD_REGISTER_KERNEL(equal_all,
  kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);
}

-#define PD_REGISTER_COMPARE_KERNEL(name, func)            \
-  PD_REGISTER_KERNEL(name,                                \
-                     CPU,                                 \
-                     ALL_LAYOUT,                          \
-                     phi::func##Kernel,                   \
-                     bool,                                \
-                     int16_t,                             \
-                     int,                                 \
-                     int64_t,                             \
-                     float,                               \
-                     double,                              \
-                     phi::dtype::float16,                 \
-                     phi::dtype::bfloat16) {}
+#define PD_REGISTER_COMPARE_KERNEL(name, func)            \
+  PD_REGISTER_KERNEL(name,                                \
+                     CPU,                                 \
+                     ALL_LAYOUT,                          \
+                     phi::func##Kernel,                   \
+                     bool,                                \
+                     int16_t,                             \
+                     int,                                 \
+                     int64_t,                             \
+                     float,                               \
+                     double,                              \
+                     phi::dtype::float16,                 \
+                     phi::dtype::bfloat16) {              \
+    kernel->OutputAt(0).SetDataType(phi::DataType::BOOL); \
+  }
PD_REGISTER_COMPARE_KERNEL(less_than, LessThan)
PD_REGISTER_COMPARE_KERNEL(less_equal, LessEqual)
PD_REGISTER_COMPARE_KERNEL(greater_than, GreaterThan)
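The registration bodies now declare the output dtype explicitly. A plausible reading (an assumption about the framework mechanics, not something stated in this diff): `kernel->OutputAt(0).SetDataType(phi::DataType::BOOL)` records that the output is `bool` regardless of the input type, so dtype inference stays correct once the op can run in place. A toy registry illustrating the idea:

```cpp
#include <cassert>
#include <map>
#include <string>

// Toy kernel registry: a recorded output dtype overrides the default
// "output dtype == input dtype" inference.
enum class DType { FLOAT32, BOOL, UNSET };

struct KernelInfo {
  DType out_dtype = DType::UNSET;  // UNSET means "same as input"
};

int main() {
  std::map<std::string, KernelInfo> registry;
  registry["less_than"].out_dtype = DType::BOOL;  // models SetDataType(BOOL)
  DType in = DType::FLOAT32;
  const KernelInfo& k = registry["less_than"];
  DType out = (k.out_dtype == DType::UNSET) ? in : k.out_dtype;
  assert(out == DType::BOOL);
  return 0;
}
```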
67 changes: 46 additions & 21 deletions paddle/phi/kernels/cpu/logical_kernel.cc
@@ -24,20 +24,40 @@

namespace phi {

-#define DEFINE_LOGICAL_BINARY_KERNEL(type)                                  \
-  template <typename T, typename Context>                                   \
-  void Logical##type##Kernel(const Context& dev_ctx,                        \
-                             const DenseTensor& x,                          \
-                             const DenseTensor& y,                          \
-                             DenseTensor* out) {                            \
-    funcs::Logical##type##Functor<T> binary_func;                           \
-    if (out->IsSharedWith(x)) {                                             \
-      funcs::ElementwiseCompute<funcs::Logical##type##Functor<T>, T, T>(    \
-          dev_ctx, x, y, binary_func, out);                                 \
-    } else {                                                                \
-      funcs::ElementwiseCompute<funcs::Logical##type##Functor<T>, T, bool>( \
-          dev_ctx, x, y, binary_func, out);                                 \
-    }                                                                       \
+template <typename T, typename Context, typename Functor>
+void LogicalKernelImpl(const Context& dev_ctx,
+                       const DenseTensor& x,
+                       const DenseTensor& y,
+                       DenseTensor* out) {
+  Functor binary_func;
+  funcs::ElementwiseCompute<Functor, T, bool>(dev_ctx, x, y, binary_func, out);
+}
+
+template <typename T, typename Context, typename Functor>
+void InplaceLogicalKernelImpl(const Context& dev_ctx,
+                              const DenseTensor& x,
+                              const DenseTensor& y,
+                              DenseTensor* out) {
+  Functor binary_func;
+  auto x_origin = x;
+  out->set_type(phi::DataType::BOOL);
+  funcs::ElementwiseCompute<Functor, T, bool>(
+      dev_ctx, x_origin, y, binary_func, out);
+}
+
+#define DEFINE_LOGICAL_BINARY_KERNEL(type)                                    \
+  template <typename T, typename Context>                                     \
+  void Logical##type##Kernel(const Context& dev_ctx,                          \
+                             const DenseTensor& x,                            \
+                             const DenseTensor& y,                            \
+                             DenseTensor* out) {                              \
+    if (out->IsSharedWith(x)) {                                               \
+      InplaceLogicalKernelImpl<T, Context, funcs::Logical##type##Functor<T>>( \
+          dev_ctx, x, y, out);                                                \
+    } else {                                                                  \
+      LogicalKernelImpl<T, Context, funcs::Logical##type##Functor<T>>(        \
+          dev_ctx, x, y, out);                                                \
+    }                                                                         \
  }

DEFINE_LOGICAL_BINARY_KERNEL(And)
@@ -52,15 +72,18 @@ void LogicalNotKernel(const Context& dev_ctx,
funcs::LogicalNotFunctor<T> unary_func;

phi::Transform<Context> trans;
-  if (!out->IsSharedWith(x)) {
-    auto* out_ptr = dev_ctx.template Alloc<bool>(out);
-    trans(dev_ctx, x.data<T>(), x.data<T>() + x.numel(), out_ptr, unary_func);
-  } else {
-    trans(dev_ctx,
-          x.data<T>(),
-          x.data<T>() + x.numel(),
-          reinterpret_cast<T*>(out->data()),
-          unary_func);
+  if (out->IsSharedWith(x)) {
+    auto x_origin = x;
+    out->set_type(phi::DataType::BOOL);
+    auto* out_ptr = dev_ctx.template Alloc<bool>(out);
+    trans(dev_ctx,
+          x_origin.data<T>(),
+          x_origin.data<T>() + x_origin.numel(),
+          out_ptr,
+          unary_func);
+  } else {
+    auto* out_ptr = dev_ctx.template Alloc<bool>(out);
+    trans(dev_ctx, x.data<T>(), x.data<T>() + x.numel(), out_ptr, unary_func);
  }
}
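The not-kernel is a plain element-wise transform over the flat buffer. A standalone analogue, with `std::transform` standing in for `phi::Transform` and a lambda for `funcs::LogicalNotFunctor` (a simplification of the code above, not the real phi call):

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

int main() {
  std::vector<float> x = {0.0f, 1.0f, 2.0f, 0.0f};
  std::vector<bool> out(x.size());  // the kernel writes bools through out_ptr
  std::transform(x.begin(), x.end(), out.begin(),
                 [](float v) { return !static_cast<bool>(v); });
  assert(out[0] == true && out[1] == false);
  return 0;
}
```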

@@ -79,7 +102,9 @@
int8_t, \
phi::dtype::complex<float>, \
phi::dtype::complex<double>, \
-                      int16_t) {}
+                      int16_t) {                              \
+    kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);     \
+  }

REGISTER_LOGICAL_CPU_KERNEL(logical_and, And)
REGISTER_LOGICAL_CPU_KERNEL(logical_or, Or)
31 changes: 23 additions & 8 deletions paddle/phi/kernels/impl/compare_kernel_impl.h
@@ -30,20 +30,35 @@ inline void CompareKernelImpl(const Context& ctx,
int axis,
DenseTensor* out);

+template <typename T,
+          typename Context,
+          typename Functor,
+          typename InverseFunctor>
+inline void InplaceCompareKernelImpl(const Context& ctx,
+                                     const DenseTensor& x,
+                                     const DenseTensor& y,
+                                     int axis,
+                                     DenseTensor* out);

template <typename T, typename Context, typename Functor>
inline void CompareAllKernelImpl(const Context& ctx,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* out);

-#define DEFINE_COMPARE_KERNEL(name, functor, inverse_functor)               \
-  template <typename T, typename Context>                                   \
-  void name##Kernel(const Context& ctx,                                     \
-                    const DenseTensor& x,                                   \
-                    const DenseTensor& y,                                   \
-                    DenseTensor* out) {                                     \
-    CompareKernelImpl<T, Context, functor<T>, inverse_functor<T>>(          \
-        ctx, x, y, -1, out);                                                \
+#define DEFINE_COMPARE_KERNEL(name, functor, inverse_functor)               \
+  template <typename T, typename Context>                                   \
+  void name##Kernel(const Context& ctx,                                     \
+                    const DenseTensor& x,                                   \
+                    const DenseTensor& y,                                   \
+                    DenseTensor* out) {                                     \
+    if (out->IsSharedWith(x)) {                                             \
+      InplaceCompareKernelImpl<T, Context, functor<T>, inverse_functor<T>>( \
+          ctx, x, y, -1, out);                                              \
+    } else {                                                                \
+      CompareKernelImpl<T, Context, functor<T>, inverse_functor<T>>(        \
+          ctx, x, y, -1, out);                                              \
+    }                                                                       \
}

DEFINE_COMPARE_KERNEL(LessThan,
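For reference, an approximate hand-expansion of the macro for the `LessThan` instantiation begun above. The call's functor arguments are cut off in this view, so `funcs::LessThanFunctor` / `funcs::GreaterThanFunctor` are assumptions based on the naming convention, not read from the diff:

```cpp
// Approximate expansion of DEFINE_COMPARE_KERNEL(LessThan, ...); the functor
// pair is assumed, since the call site is truncated here.
template <typename T, typename Context>
void LessThanKernel(const Context& ctx,
                    const DenseTensor& x,
                    const DenseTensor& y,
                    DenseTensor* out) {
  if (out->IsSharedWith(x)) {
    InplaceCompareKernelImpl<T, Context, funcs::LessThanFunctor<T>,
                             funcs::GreaterThanFunctor<T>>(ctx, x, y, -1, out);
  } else {
    CompareKernelImpl<T, Context, funcs::LessThanFunctor<T>,
                      funcs::GreaterThanFunctor<T>>(ctx, x, y, -1, out);
  }
}
```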
54 changes: 34 additions & 20 deletions paddle/phi/kernels/kps/compare_kernel.cu
@@ -52,16 +52,27 @@ inline void CompareKernelImpl(const Context& ctx,
const DenseTensor& y,
int axis,
DenseTensor* out) {
-  if (!out->IsSharedWith(x)) {
-    ctx.template Alloc<bool>(out);
-  }
+  ctx.template Alloc<bool>(out);
  std::vector<const DenseTensor*> ins{&x, &y};
  std::vector<DenseTensor*> outs{out};
-  if (!out->IsSharedWith(x)) {
-    funcs::BroadcastKernel<bool>(ctx, ins, &outs, Functor(), axis);
-  } else {
-    funcs::BroadcastKernel<T>(ctx, ins, &outs, Functor(), axis);
-  }
+  funcs::BroadcastKernel<bool>(ctx, ins, &outs, Functor(), axis);
}

+template <typename T,
+          typename Context,
+          typename Functor,
+          typename InverseFunctor>
+inline void InplaceCompareKernelImpl(const Context& ctx,
+                                     const DenseTensor& x,
+                                     const DenseTensor& y,
+                                     int axis,
+                                     DenseTensor* out) {
+  auto x_origin = x;
+  ctx.template Alloc<bool>(out);
+  out->set_type(phi::DataType::BOOL);
+  std::vector<const DenseTensor*> ins{&x_origin, &y};
+  std::vector<DenseTensor*> outs{out};
+  funcs::BroadcastKernel<bool>(ctx, ins, &outs, Functor(), axis);
+}

#ifndef PADDLE_WITH_XPU_KP
@@ -134,18 +145,21 @@ PD_REGISTER_KERNEL(equal_all,
kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);
}

-#define PD_REGISTER_COMPARE_KERNEL(name, func)            \
-  PD_REGISTER_KERNEL(name,                                \
-                     KPS,                                 \
-                     ALL_LAYOUT,                          \
-                     phi::func##Kernel,                   \
-                     bool,                                \
-                     int16_t,                             \
-                     int,                                 \
-                     int64_t,                             \
-                     float,                               \
-                     double,                              \
-                     phi::dtype::float16) {}
+#define PD_REGISTER_COMPARE_KERNEL(name, func)            \
+  PD_REGISTER_KERNEL(name,                                \
+                     KPS,                                 \
+                     ALL_LAYOUT,                          \
+                     phi::func##Kernel,                   \
+                     bool,                                \
+                     int16_t,                             \
+                     int,                                 \
+                     int64_t,                             \
+                     float,                               \
+                     double,                              \
+                     phi::dtype::float16,                 \
+                     phi::dtype::bfloat16) {              \
+    kernel->OutputAt(0).SetDataType(phi::DataType::BOOL); \
+  }

PD_REGISTER_COMPARE_KERNEL(less_than, LessThan)
PD_REGISTER_COMPARE_KERNEL(less_equal, LessEqual)
77 changes: 52 additions & 25 deletions paddle/phi/kernels/kps/logical_kernel.cu
@@ -25,24 +25,45 @@

namespace phi {

-#define DEFINE_LOGICAL_BINARY_KERNEL(type)                              \
-  template <typename T, typename Context>                               \
-  void Logical##type##Kernel(const Context& dev_ctx,                    \
-                             const DenseTensor& x,                      \
-                             const DenseTensor& y,                      \
-                             DenseTensor* out) {                        \
-    if (!out->IsSharedWith(x)) {                                        \
-      dev_ctx.template Alloc<bool>(out);                                \
-    }                                                                   \
-                                                                        \
-    funcs::Logical##type##Functor<T> binary_func;                       \
-    std::vector<const DenseTensor*> ins = {&x, &y};                     \
-    std::vector<DenseTensor*> outs = {out};                             \
-    if (!out->IsSharedWith(x)) {                                        \
-      funcs::BroadcastKernel<bool>(dev_ctx, ins, &outs, binary_func);   \
-    } else {                                                            \
-      funcs::BroadcastKernel<T>(dev_ctx, ins, &outs, binary_func);      \
-    }                                                                   \
+template <typename T, typename Context, typename Functor>
+void LogicalKernelImpl(const Context& dev_ctx,
+                       const DenseTensor& x,
+                       const DenseTensor& y,
+                       DenseTensor* out) {
+  dev_ctx.template Alloc<bool>(out);
+  Functor binary_func;
+  std::vector<const DenseTensor*> ins = {&x, &y};
+  std::vector<DenseTensor*> outs = {out};
+  funcs::BroadcastKernel<bool>(dev_ctx, ins, &outs, binary_func);
+}
+
+template <typename T, typename Context, typename Functor>
+void InplaceLogicalKernelImpl(const Context& dev_ctx,
+                              const DenseTensor& x,
+                              const DenseTensor& y,
+                              DenseTensor* out) {
+  auto x_origin = x;
+  dev_ctx.template Alloc<bool>(out);
+  out->set_type(phi::DataType::BOOL);
+  Functor binary_func;
+  std::vector<const DenseTensor*> ins = {&x_origin, &y};
+  std::vector<DenseTensor*> outs = {out};
+  funcs::BroadcastKernel<bool>(dev_ctx, ins, &outs, binary_func);
+}
+
+#define DEFINE_LOGICAL_BINARY_KERNEL(type)                                    \
+  template <typename T, typename Context>                                     \
+  void Logical##type##Kernel(const Context& dev_ctx,                          \
+                             const DenseTensor& x,                            \
+                             const DenseTensor& y,                            \
+                             DenseTensor* out) {                              \
+    if (out->IsSharedWith(x)) {                                               \
+      InplaceLogicalKernelImpl<T, Context, funcs::Logical##type##Functor<T>>( \
+          dev_ctx, x, y, out);                                                \
+    } else {                                                                  \
+      LogicalKernelImpl<T, Context, funcs::Logical##type##Functor<T>>(        \
+          dev_ctx, x, y, out);                                                \
+    }                                                                         \
}

DEFINE_LOGICAL_BINARY_KERNEL(And)
@@ -56,14 +77,18 @@ void LogicalNotKernel(const Context& dev_ctx,
DenseTensor* out) {
  if (!out->IsSharedWith(x)) {
    dev_ctx.template Alloc<bool>(out);
-  }
-  funcs::LogicalNotFunctor<T> unary_func;
-  std::vector<const DenseTensor*> ins = {&x};
-  std::vector<DenseTensor*> outs = {out};
-  if (!out->IsSharedWith(x)) {
+    funcs::LogicalNotFunctor<T> unary_func;
+    std::vector<const DenseTensor*> ins = {&x};
+    std::vector<DenseTensor*> outs = {out};
    funcs::BroadcastKernel<bool>(dev_ctx, ins, &outs, unary_func);
  } else {
-    funcs::BroadcastKernel<T>(dev_ctx, ins, &outs, unary_func);
+    auto x_origin = x;
+    out->set_type(phi::DataType::BOOL);
+    dev_ctx.template Alloc<bool>(out);
+    funcs::LogicalNotFunctor<T> unary_func;
+    std::vector<const DenseTensor*> ins = {&x_origin};
+    std::vector<DenseTensor*> outs = {out};
+    funcs::BroadcastKernel<bool>(dev_ctx, ins, &outs, unary_func);
  }
}

@@ -99,7 +124,9 @@ PD_REGISTER_KERNEL(logical_xor, KPS, ALL_LAYOUT, phi::LogicalXorKernel, int) {
int8_t, \
phi::dtype::complex<float>, \
phi::dtype::complex<double>, \
-                      int16_t) {}
+                      int16_t) {                              \
+    kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);     \
+  }

REGISTER_LOGICAL_CUDA_KERNEL(logical_and, And)
REGISTER_LOGICAL_CUDA_KERNEL(logical_or, Or)
