Skip to content

Commit

Permalink
Add template argument into functor
Browse files Browse the repository at this point in the history
  • Loading branch information
JamesLim-sy committed Jun 24, 2021
1 parent 308307e commit 93fb64b
Showing 1 changed file with 17 additions and 15 deletions.
32 changes: 17 additions & 15 deletions paddle/fluid/operators/controlflow/compare_all_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ limitations under the License. */

#include <thrust/fill.h>
#include "paddle/fluid/operators/controlflow/compare_all_op.h"
#include "paddle/fluid/operators/reduce_ops/cub_reduce.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
#include "paddle/fluid/operators/reduce_ops/cub_reduce.h"

namespace ops = paddle::operators;
namespace plat = paddle::platform;
Expand All @@ -42,16 +42,16 @@ template <typename T, typename Enable = void>
struct CudaEqualReduceFunctor {
using ELEM_TYPE = T;
HOSTDEVICE bool operator()(const T args[]) const {
return (args[0] == args[1]);
return (args[0] == args[1]);
}
};

template <typename T>
struct CudaEqualReduceFunctor<
T, typename std::enable_if<std::is_floating_point<T>::value>::type > {
T, typename std::enable_if<std::is_floating_point<T>::value>::type> {
using ELEM_TYPE = T;
HOSTDEVICE bool operator()(const T args[]) const {
return fabs(static_cast<double>(args[0] - args[1])) < 1e-8;
HOSTDEVICE bool operator()(const T args[]) const {
return fabs(static_cast<double>(args[0] - args[1])) < 1e-8;
}
};

Expand All @@ -75,13 +75,12 @@ class CompareReduceOpKernel
return;
} else {
tmp.mutable_data<bool>(x->dims(), context.GetPlace());
auto functor = Functor();
const auto& cuda_ctx =
context.template device_context<platform::CUDADeviceContext>();
context.template device_context<platform::CUDADeviceContext>();
std::vector<const framework::Tensor*> ins = {x, y};
std::vector<framework::Tensor*> outs = {&tmp};
LaunchSameDimsElementwiseCudaKernel<ElementwiseType::kBinary, T, bool>(
cuda_ctx, ins, &outs, functor);
cuda_ctx, ins, &outs, Functor<T>();

// Reduce by 'bitwise and' operator
std::vector<int> reduce_dims;
Expand All @@ -98,14 +97,17 @@ class CompareReduceOpKernel
} // namespace operators
} // namespace paddle

#define REGISTER_COMPARE_REDUCE_CUDA_KERNEL(op_type, functor) \
REGISTER_OP_CUDA_KERNEL( \
op_type, \
#define REGISTER_COMPARE_REDUCE_CUDA_KERNEL(op_type, functor) \
REGISTER_OP_CUDA_KERNEL( \
op_type, \
ops::CompareReduceOpKernel<plat::CUDADeviceContext, ops::functor<bool>>, \
ops::CompareReduceOpKernel<plat::CUDADeviceContext, ops::functor<int>>, \
ops::CompareReduceOpKernel<plat::CUDADeviceContext, ops::functor<int64_t>>, \
ops::CompareReduceOpKernel<plat::CUDADeviceContext, ops::functor<float>>, \
ops::CompareReduceOpKernel<plat::CUDADeviceContext, ops::functor<double>>);
ops::CompareReduceOpKernel<plat::CUDADeviceContext, ops::functor<int>>, \
ops::CompareReduceOpKernel<plat::CUDADeviceContext, \
ops::functor<int64_t>>, \
ops::CompareReduceOpKernel<plat::CUDADeviceContext, \
ops::functor<float>>, \
ops::CompareReduceOpKernel<plat::CUDADeviceContext, \
ops::functor<double>>);

REGISTER_COMPARE_REDUCE_CUDA_KERNEL(equal_all, CudaEqualReduceFunctor)
#undef REGISTER_COMPARE_REDUCE_CUDA_KERNEL

0 comments on commit 93fb64b

Please sign in to comment.