Replace usage of elementwise cuda forward kernel in Compare_all_op (#…
JamesLim-sy authored Jul 5, 2021
1 parent 4d16724 commit ea1a0d4
Showing 3 changed files with 50 additions and 56 deletions.
22 changes: 3 additions & 19 deletions paddle/fluid/operators/controlflow/compare_all_op.cc
@@ -30,29 +30,13 @@ class CompareReduceOpKernel
     auto* x = context.Input<Tensor>("X");
     auto* y = context.Input<Tensor>("Y");
     auto* z = context.Output<Tensor>("Out");
-    bool shape_same = true;
-
     Tensor tmp;
-    framework::DDim x_dims = x->dims();
-    framework::DDim y_dims = y->dims();
-
-    // judge the two inputs shape is same, if not same, just return false
-    if (x_dims.size() != y_dims.size()) {
-      shape_same = false;
-    } else {
-      for (auto i = 0; i < x_dims.size(); i++) {
-        if (x_dims[i] != y_dims[i]) {
-          shape_same = false;
-          break;
-        }
-      }
-    }
-
     bool* z_data = z->mutable_data<bool>(context.GetPlace());
-    if (!shape_same) {
+
+    if (x->dims() != y->dims()) {
       z_data[0] = false;
     } else {
-      tmp.mutable_data<bool>(x_dims, context.GetPlace());
+      tmp.mutable_data<bool>(x->dims(), context.GetPlace());
       if (x->numel() == 1 && y->numel() == 1) {
         bool* z_data = tmp.mutable_data<bool>(context.GetPlace());
         z_data[0] = Functor()(x->data<T>()[0], y->data<T>()[0]);
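Aside: the rank-and-element loop removed above computes exactly what a direct shape comparison expresses, which is why the new code can rely on x->dims() != y->dims(). A minimal standalone sketch of that equivalence, assuming plain std::vector shapes and a hypothetical ShapeSame helper rather than Paddle's DDim:

// Sketch only: ShapeSame is a hypothetical helper, not a Paddle API.
// Two shapes match when they have the same rank and the same extent in
// every dimension, which is what vector equality already checks.
#include <cstdint>
#include <vector>

bool ShapeSame(const std::vector<int64_t>& x_dims,
               const std::vector<int64_t>& y_dims) {
  return x_dims == y_dims;  // equivalent to the removed explicit loop
}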
83 changes: 47 additions & 36 deletions paddle/fluid/operators/controlflow/compare_all_op.cu
@@ -14,14 +14,18 @@ limitations under the License. */
 
 #include <thrust/fill.h>
 #include "paddle/fluid/operators/controlflow/compare_all_op.h"
+#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
 #include "paddle/fluid/operators/reduce_ops/cub_reduce.h"
 
+namespace ops = paddle::operators;
+namespace plat = paddle::platform;
+
 namespace paddle {
 namespace operators {
 
 template <typename T>
 struct IdentityFunctor {
   HOSTDEVICE explicit inline IdentityFunctor() {}
 
   HOSTDEVICE inline T operator()(const T& x) const { return x; }
 };

@@ -33,6 +37,24 @@ struct BitwiseAdd {
     return a & b;
   }
 };
+
+template <typename T, typename Enable = void>
+struct CudaEqualReduceFunctor {
+  using ELEM_TYPE = T;
+  HOSTDEVICE bool operator()(const T args[]) const {
+    return (args[0] == args[1]);
+  }
+};
+
+template <typename T>
+struct CudaEqualReduceFunctor<
+    T, typename std::enable_if<std::is_floating_point<T>::value>::type> {
+  using ELEM_TYPE = T;
+  HOSTDEVICE bool operator()(const T args[]) const {
+    return fabs(static_cast<double>(args[0] - args[1])) < 1e-8;
+  }
+};
+
 template <typename DeviceContext, typename Functor>
 class CompareReduceOpKernel
     : public framework::OpKernel<typename Functor::ELEM_TYPE> {
@@ -44,32 +66,22 @@ class CompareReduceOpKernel
     auto* x = context.Input<Tensor>("X");
     auto* y = context.Input<Tensor>("Y");
     auto* z = context.Output<Tensor>("Out");
-    bool shape_same = true;
 
+    bool* z_data = z->mutable_data<bool>(context.GetPlace());
     Tensor tmp;
-    framework::DDim x_dims = x->dims();
-    framework::DDim y_dims = y->dims();
-
-    if (x_dims.size() != y_dims.size()) {
-      shape_same = false;
-    } else {
-      for (auto i = 0; i < x_dims.size(); i++) {
-        if (x_dims[i] != y_dims[i]) {
-          shape_same = false;
-          break;
-        }
-      }
-    }
-
-    bool* z_data = z->mutable_data<bool>(context.GetPlace());
-    if (!shape_same) {
+    if (x->dims() != y->dims()) {
       thrust::device_ptr<bool> z_dev_ptr(z_data);
       thrust::fill(z_dev_ptr, z_dev_ptr + 1, false);
       return;
     } else {
-      tmp.mutable_data<bool>(x_dims, context.GetPlace());
-      ElementwiseComputeEx<Functor, DeviceContext, T, bool>(context, x, y, 0,
-                                                            Functor(), &tmp);
+      tmp.mutable_data<bool>(x->dims(), context.GetPlace());
+      const auto& cuda_ctx =
+          context.template device_context<platform::CUDADeviceContext>();
+      std::vector<const framework::Tensor*> ins = {x, y};
+      std::vector<framework::Tensor*> outs = {&tmp};
+      LaunchSameDimsElementwiseCudaKernel<ElementwiseType::kBinary, T, bool>(
+          cuda_ctx, ins, &outs, Functor());
+
       // Reduce by 'bitwise and' operator
       std::vector<int> reduce_dims;
       reduce_dims.resize(tmp.dims().size());
@@ -85,18 +97,17 @@ class CompareReduceOpKernel
 }  // namespace operators
 }  // namespace paddle
 
-#define REGISTER_COMPARE_REDUCE_CUDA_KERNEL(op_type, functor)             \
-  REGISTER_OP_CUDA_KERNEL(                                                \
-      op_type, paddle::operators::CompareReduceOpKernel<                  \
-                   paddle::platform::CUDADeviceContext, functor<bool>>,   \
-      paddle::operators::CompareReduceOpKernel<                           \
-          paddle::platform::CUDADeviceContext, functor<int>>,             \
-      paddle::operators::CompareReduceOpKernel<                           \
-          paddle::platform::CUDADeviceContext, functor<int64_t>>,         \
-      paddle::operators::CompareReduceOpKernel<                           \
-          paddle::platform::CUDADeviceContext, functor<float>>,           \
-      paddle::operators::CompareReduceOpKernel<                           \
-          paddle::platform::CUDADeviceContext, functor<double>>);
-
-REGISTER_COMPARE_REDUCE_CUDA_KERNEL(equal_all,
-                                    paddle::operators::EqualReduceFunctor);
+#define REGISTER_COMPARE_REDUCE_CUDA_KERNEL(op_type, functor)                 \
+  REGISTER_OP_CUDA_KERNEL(                                                    \
+      op_type,                                                                \
+      ops::CompareReduceOpKernel<plat::CUDADeviceContext, ops::functor<bool>>, \
+      ops::CompareReduceOpKernel<plat::CUDADeviceContext, ops::functor<int>>, \
+      ops::CompareReduceOpKernel<plat::CUDADeviceContext,                     \
+                                 ops::functor<int64_t>>,                      \
+      ops::CompareReduceOpKernel<plat::CUDADeviceContext,                     \
+                                 ops::functor<float>>,                        \
+      ops::CompareReduceOpKernel<plat::CUDADeviceContext,                     \
+                                 ops::functor<double>>);
+
+REGISTER_COMPARE_REDUCE_CUDA_KERNEL(equal_all, CudaEqualReduceFunctor)
 #undef REGISTER_COMPARE_REDUCE_CUDA_KERNEL
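For reference, the rewritten CUDA path amounts to an elementwise comparison followed by an all-true reduction with bitwise AND (the BitwiseAdd functor above). A host-side C++ sketch of those semantics, assuming illustrative names (ElemEqual and EqualAll are not Paddle APIs) and mirroring the 1e-8 absolute tolerance of CudaEqualReduceFunctor:

// Sketch only: models the equal_all semantics on std::vector inputs.
#include <cmath>
#include <cstdio>
#include <vector>

// Exact equality for integral types; absolute-tolerance equality for
// floating-point types, matching the specialized functor above.
template <typename T>
bool ElemEqual(T a, T b) {
  return a == b;
}

bool ElemEqual(double a, double b) { return std::fabs(a - b) < 1e-8; }
bool ElemEqual(float a, float b) {
  return ElemEqual(static_cast<double>(a), static_cast<double>(b));
}

// false if the shapes differ; otherwise the bitwise AND of every
// elementwise comparison, i.e. "all elements equal".
template <typename T>
bool EqualAll(const std::vector<T>& x, const std::vector<T>& y) {
  if (x.size() != y.size()) return false;
  bool all_equal = true;
  for (size_t i = 0; i < x.size(); ++i) {
    all_equal &= ElemEqual(x[i], y[i]);
  }
  return all_equal;
}

int main() {
  std::vector<float> a = {1.0f, 2.0f, 3.0f};
  std::vector<float> b = {1.0f, 2.0f, 3.0f};
  std::vector<float> c = {1.0f, 2.5f, 3.0f};
  std::printf("%d %d\n", EqualAll(a, b), EqualAll(a, c));  // prints: 1 0
  return 0;
}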
1 change: 0 additions & 1 deletion paddle/fluid/operators/controlflow/compare_op.cu
@@ -59,7 +59,6 @@ struct CudaNotEqualFunctor<
 template <typename Functor, typename InverseFunctor>
 class CompareOpKernel<platform::CUDADeviceContext, Functor, InverseFunctor>
     : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
- public:
  public:
   using InT = typename Functor::ELEMENT_TYPE;
   using OutT = bool;
