diff --git a/paddle/fluid/operators/controlflow/compare_all_op.cu b/paddle/fluid/operators/controlflow/compare_all_op.cu index 5c52af4d28cae..8fa52b95c85ea 100644 --- a/paddle/fluid/operators/controlflow/compare_all_op.cu +++ b/paddle/fluid/operators/controlflow/compare_all_op.cu @@ -14,8 +14,8 @@ limitations under the License. */ #include #include "paddle/fluid/operators/controlflow/compare_all_op.h" -#include "paddle/fluid/operators/reduce_ops/cub_reduce.h" #include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h" +#include "paddle/fluid/operators/reduce_ops/cub_reduce.h" namespace ops = paddle::operators; namespace plat = paddle::platform; @@ -42,16 +42,16 @@ template struct CudaEqualReduceFunctor { using ELEM_TYPE = T; HOSTDEVICE bool operator()(const T args[]) const { - return (args[0] == args[1]); + return (args[0] == args[1]); } }; template struct CudaEqualReduceFunctor< - T, typename std::enable_if::value>::type > { + T, typename std::enable_if::value>::type> { using ELEM_TYPE = T; - HOSTDEVICE bool operator()(const T args[]) const { - return fabs(static_cast(args[0] - args[1])) < 1e-8; + HOSTDEVICE bool operator()(const T args[]) const { + return fabs(static_cast(args[0] - args[1])) < 1e-8; } }; @@ -75,13 +75,12 @@ class CompareReduceOpKernel return; } else { tmp.mutable_data(x->dims(), context.GetPlace()); - auto functor = Functor(); const auto& cuda_ctx = - context.template device_context(); + context.template device_context(); std::vector ins = {x, y}; std::vector outs = {&tmp}; LaunchSameDimsElementwiseCudaKernel( - cuda_ctx, ins, &outs, functor); + cuda_ctx, ins, &outs, Functor(); // Reduce by 'bitwise and' operator std::vector reduce_dims; @@ -98,14 +97,17 @@ class CompareReduceOpKernel } // namespace operators } // namespace paddle -#define REGISTER_COMPARE_REDUCE_CUDA_KERNEL(op_type, functor) \ - REGISTER_OP_CUDA_KERNEL( \ - op_type, \ +#define REGISTER_COMPARE_REDUCE_CUDA_KERNEL(op_type, functor) \ + REGISTER_OP_CUDA_KERNEL( \ + op_type, \ ops::CompareReduceOpKernel>, \ - ops::CompareReduceOpKernel>, \ - ops::CompareReduceOpKernel>, \ - ops::CompareReduceOpKernel>, \ - ops::CompareReduceOpKernel>); + ops::CompareReduceOpKernel>, \ + ops::CompareReduceOpKernel>, \ + ops::CompareReduceOpKernel>, \ + ops::CompareReduceOpKernel>); REGISTER_COMPARE_REDUCE_CUDA_KERNEL(equal_all, CudaEqualReduceFunctor) #undef REGISTER_COMPARE_REDUCE_CUDA_KERNEL