Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace usage of elementwise cuda forward kernel in Compare_all_op #33754

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 3 additions & 19 deletions paddle/fluid/operators/controlflow/compare_all_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,29 +30,13 @@ class CompareReduceOpKernel
auto* x = context.Input<Tensor>("X");
auto* y = context.Input<Tensor>("Y");
auto* z = context.Output<Tensor>("Out");
bool shape_same = true;

Tensor tmp;
framework::DDim x_dims = x->dims();
framework::DDim y_dims = y->dims();

// judge the two inputs shape is same, if not same, just return false
if (x_dims.size() != y_dims.size()) {
shape_same = false;
} else {
for (auto i = 0; i < x_dims.size(); i++) {
if (x_dims[i] != y_dims[i]) {
shape_same = false;
break;
}
}
}

bool* z_data = z->mutable_data<bool>(context.GetPlace());
if (!shape_same) {

if (x->dims() != y->dims()) {
z_data[0] = false;
} else {
tmp.mutable_data<bool>(x_dims, context.GetPlace());
tmp.mutable_data<bool>(x->dims(), context.GetPlace());
if (x->numel() == 1 && y->numel() == 1) {
bool* z_data = tmp.mutable_data<bool>(context.GetPlace());
z_data[0] = Functor()(x->data<T>()[0], y->data<T>()[0]);
Expand Down
83 changes: 47 additions & 36 deletions paddle/fluid/operators/controlflow/compare_all_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,18 @@ limitations under the License. */

#include <thrust/fill.h>
#include "paddle/fluid/operators/controlflow/compare_all_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
#include "paddle/fluid/operators/reduce_ops/cub_reduce.h"

namespace ops = paddle::operators;
namespace plat = paddle::platform;

namespace paddle {
namespace operators {

// Pass-through unary transform: yields its argument unchanged. Supplied to
// the cub reduction below when no per-element mapping is required.
template <typename T>
struct IdentityFunctor {
  HOSTDEVICE explicit inline IdentityFunctor() {}
  HOSTDEVICE inline T operator()(const T& val) const { return val; }
};

Expand All @@ -33,6 +37,24 @@ struct BitwiseAdd {
return a & b;
}
};

// Elementwise predicate for equal_all on non-floating-point types:
// plain exact equality of the two packed operands.
template <typename T, typename Enable = void>
struct CudaEqualReduceFunctor {
  using ELEM_TYPE = T;
  HOSTDEVICE bool operator()(const T args[]) const {
    return args[0] == args[1];
  }
};

// Specialization for floating-point element types: two values compare equal
// when they differ by less than 1e-8 (absolute tolerance), matching the CPU
// EqualReduceFunctor's behavior for equal_all.
template <typename T>
struct CudaEqualReduceFunctor<
    T, typename std::enable_if<std::is_floating_point<T>::value>::type> {
  using ELEM_TYPE = T;
  HOSTDEVICE bool operator()(const T args[]) const {
    // Widen each operand to double BEFORE subtracting. The original formed
    // the difference in T (e.g. float) and only then cast, which loses
    // precision and can overflow to inf for large-magnitude operands.
    return fabs(static_cast<double>(args[0]) - static_cast<double>(args[1])) <
           1e-8;
  }
};

template <typename DeviceContext, typename Functor>
class CompareReduceOpKernel
: public framework::OpKernel<typename Functor::ELEM_TYPE> {
Expand All @@ -44,32 +66,22 @@ class CompareReduceOpKernel
auto* x = context.Input<Tensor>("X");
auto* y = context.Input<Tensor>("Y");
auto* z = context.Output<Tensor>("Out");
bool shape_same = true;

bool* z_data = z->mutable_data<bool>(context.GetPlace());
Tensor tmp;
framework::DDim x_dims = x->dims();
framework::DDim y_dims = y->dims();

if (x_dims.size() != y_dims.size()) {
shape_same = false;
} else {
for (auto i = 0; i < x_dims.size(); i++) {
if (x_dims[i] != y_dims[i]) {
shape_same = false;
break;
}
}
}

bool* z_data = z->mutable_data<bool>(context.GetPlace());
if (!shape_same) {
if (x->dims() != y->dims()) {
thrust::device_ptr<bool> z_dev_ptr(z_data);
thrust::fill(z_dev_ptr, z_dev_ptr + 1, false);
return;
} else {
tmp.mutable_data<bool>(x_dims, context.GetPlace());
ElementwiseComputeEx<Functor, DeviceContext, T, bool>(context, x, y, 0,
Functor(), &tmp);
tmp.mutable_data<bool>(x->dims(), context.GetPlace());
const auto& cuda_ctx =
context.template device_context<platform::CUDADeviceContext>();
std::vector<const framework::Tensor*> ins = {x, y};
std::vector<framework::Tensor*> outs = {&tmp};
LaunchSameDimsElementwiseCudaKernel<ElementwiseType::kBinary, T, bool>(
cuda_ctx, ins, &outs, Functor());

// Reduce by 'bitwise and' operator
std::vector<int> reduce_dims;
reduce_dims.resize(tmp.dims().size());
Expand All @@ -85,18 +97,17 @@ class CompareReduceOpKernel
} // namespace operators
} // namespace paddle

#define REGISTER_COMPARE_REDUCE_CUDA_KERNEL(op_type, functor) \
REGISTER_OP_CUDA_KERNEL( \
op_type, paddle::operators::CompareReduceOpKernel< \
paddle::platform::CUDADeviceContext, functor<bool>>, \
paddle::operators::CompareReduceOpKernel< \
paddle::platform::CUDADeviceContext, functor<int>>, \
paddle::operators::CompareReduceOpKernel< \
paddle::platform::CUDADeviceContext, functor<int64_t>>, \
paddle::operators::CompareReduceOpKernel< \
paddle::platform::CUDADeviceContext, functor<float>>, \
paddle::operators::CompareReduceOpKernel< \
paddle::platform::CUDADeviceContext, functor<double>>);

REGISTER_COMPARE_REDUCE_CUDA_KERNEL(equal_all,
paddle::operators::EqualReduceFunctor);
// Registers the CUDA kernel for `op_type` once per supported element type
// (bool, int, int64_t, float, double), each instantiating
// CompareReduceOpKernel with the given elementwise comparison functor.
#define REGISTER_COMPARE_REDUCE_CUDA_KERNEL(op_type, functor)                 \
  REGISTER_OP_CUDA_KERNEL(                                                    \
      op_type,                                                                \
      ops::CompareReduceOpKernel<plat::CUDADeviceContext, ops::functor<bool>>, \
      ops::CompareReduceOpKernel<plat::CUDADeviceContext, ops::functor<int>>, \
      ops::CompareReduceOpKernel<plat::CUDADeviceContext,                     \
                                 ops::functor<int64_t>>,                      \
      ops::CompareReduceOpKernel<plat::CUDADeviceContext,                     \
                                 ops::functor<float>>,                        \
      ops::CompareReduceOpKernel<plat::CUDADeviceContext,                     \
                                 ops::functor<double>>);

// equal_all: reduces elementwise (near-)equality with bitwise-and to a
// single bool.
REGISTER_COMPARE_REDUCE_CUDA_KERNEL(equal_all, CudaEqualReduceFunctor)
#undef REGISTER_COMPARE_REDUCE_CUDA_KERNEL
1 change: 0 additions & 1 deletion paddle/fluid/operators/controlflow/compare_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ struct CudaNotEqualFunctor<
template <typename Functor, typename InverseFunctor>
class CompareOpKernel<platform::CUDADeviceContext, Functor, InverseFunctor>
: public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
public:
public:
using InT = typename Functor::ELEMENT_TYPE;
using OutT = bool;
Expand Down