diff --git a/paddle/fluid/operators/controlflow/compare_op.cu b/paddle/fluid/operators/controlflow/compare_op.cu index a52920d9e8701..cc0c46adb119a 100644 --- a/paddle/fluid/operators/controlflow/compare_op.cu +++ b/paddle/fluid/operators/controlflow/compare_op.cu @@ -21,21 +21,21 @@ namespace plat = paddle::platform; namespace paddle { namespace operators { -#define DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(Func, op) \ +#define DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(func, op) \ template \ - struct Func##Functor { \ + struct func { \ using ELEMENT_TYPE = T; \ inline HOSTDEVICE bool operator()(const T* args) const { \ return args[0] op args[1]; \ } \ }; -DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(CudaLessThan, <) -DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(CudaLessEqual, <=) -DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(CudaGreaterThan, >) -DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(CudaGreaterEqual, >=) -DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(CudaEqual, ==) -DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(CudaNotEqual, !=) +DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(CudaLessThanFunctor, <) +DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(CudaLessEqualFunctor, <=) +DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(CudaGreaterThanFunctor, >) +DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(CudaGreaterEqualFunctor, >=) +DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(CudaEqualFunctor, ==) +DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(CudaNotEqualFunctor, !=) #undef DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT template @@ -67,10 +67,12 @@ class CompareOpKernel auto functor = Functor(); std::vector ins; std::vector outs; + const auto& cuda_ctx = + ctx.template device_context(); - PackTensorsIntoVector(ctx, &ins, &outs); + int axis = PackTensorsIntoVector(ctx, &ins, &outs); LaunchElementwiseCudaKernel( - ctx, ins, &outs, functor); + cuda_ctx, ins, &outs, axis, functor); } }; @@ -79,19 +81,16 @@ class CompareOpKernel #define REGISTER_CUDA_COMPARE_KERNEL(op_type, func) \ REGISTER_OP_CUDA_KERNEL( \ - op_type, ops::CompareOpKernel, void>, \ - ops::CompareOpKernel, void>, \ - ops::CompareOpKernel, \ - void>, \ - ops::CompareOpKernel, void>); + op_type, \ + ops::CompareOpKernel, void>, \ + ops::CompareOpKernel, void>, \ + ops::CompareOpKernel, void>, \ + ops::CompareOpKernel, void>); -REGISTER_CUDA_COMPARE_KERNEL(equal, CudaEqual) -REGISTER_CUDA_COMPARE_KERNEL(not_equal, CudaNotEqual) -REGISTER_CUDA_COMPARE_KERNEL(less_than, CudaLessThan) -REGISTER_CUDA_COMPARE_KERNEL(less_equal, CudaLessEqual) -REGISTER_CUDA_COMPARE_KERNEL(greater_than, CudaGreaterThan) -REGISTER_CUDA_COMPARE_KERNEL(greater_equal, CudaGreaterEqual) +REGISTER_CUDA_COMPARE_KERNEL(equal, CudaEqualFunctor) +REGISTER_CUDA_COMPARE_KERNEL(not_equal, CudaNotEqualFunctor) +REGISTER_CUDA_COMPARE_KERNEL(less_than, CudaLessThanFunctor) +REGISTER_CUDA_COMPARE_KERNEL(less_equal, CudaLessEqualFunctor) +REGISTER_CUDA_COMPARE_KERNEL(greater_than, CudaGreaterThanFunctor) +REGISTER_CUDA_COMPARE_KERNEL(greater_equal, CudaGreaterEqualFunctor) #undef REGISTER_CUDA_COMPARE_KERNEL diff --git a/paddle/fluid/operators/elementwise/elementwise_add_op.cu b/paddle/fluid/operators/elementwise/elementwise_add_op.cu index 8a1e529c99619..aff0cb281642e 100644 --- a/paddle/fluid/operators/elementwise/elementwise_add_op.cu +++ b/paddle/fluid/operators/elementwise/elementwise_add_op.cu @@ -44,9 +44,12 @@ class ElementwiseAddKernel void Compute(const framework::ExecutionContext& ctx) const override { std::vector ins; std::vector outs; - PackTensorsIntoVector(ctx, &ins, &outs); + const auto& cuda_ctx = + ctx.template device_context(); + + int axis = PackTensorsIntoVector(ctx, &ins, &outs); LaunchElementwiseCudaKernel( - ctx, ins, &outs, CudaAddFunctor()); + cuda_ctx, ins, &outs, axis, CudaAddFunctor()); } }; diff --git a/paddle/fluid/operators/elementwise/elementwise_max_op.cu b/paddle/fluid/operators/elementwise/elementwise_max_op.cu index 6e055e8f054fa..483b21d07fab1 100644 --- a/paddle/fluid/operators/elementwise/elementwise_max_op.cu +++ b/paddle/fluid/operators/elementwise/elementwise_max_op.cu @@ -34,10 +34,12 @@ class ElementwiseMaxKernel void Compute(const framework::ExecutionContext& ctx) const override { std::vector ins; std::vector outs; - + const auto& cuda_ctx = + ctx.template device_context(); + int axis = PackTensorsIntoVector(ctx, &ins, &outs); LaunchElementwiseCudaKernel( - ctx, ins, &outs, axis, CudaMaxFunctor()); + cuda_ctx, ins, &outs, axis, CudaMaxFunctor()); } }; diff --git a/paddle/fluid/operators/elementwise/elementwise_min_op.cu b/paddle/fluid/operators/elementwise/elementwise_min_op.cu index 946f1e36509a2..88faaf257af45 100644 --- a/paddle/fluid/operators/elementwise/elementwise_min_op.cu +++ b/paddle/fluid/operators/elementwise/elementwise_min_op.cu @@ -34,10 +34,12 @@ class ElementwiseMinKernel void Compute(const framework::ExecutionContext& ctx) const override { std::vector ins; std::vector outs; - PackTensorsIntoVector(ctx, &ins, &outs); + const auto& cuda_ctx = + ctx.template device_context(); + int axis = PackTensorsIntoVector(ctx, &ins, &outs); LaunchElementwiseCudaKernel( - ctx, ins, &outs, CudaMinFunctor()); + cuda_ctx, ins, &outs, axis, CudaMinFunctor()); } }; diff --git a/paddle/fluid/operators/elementwise/elementwise_mul_op.cu b/paddle/fluid/operators/elementwise/elementwise_mul_op.cu index e68c2319094ea..973f2305cc778 100644 --- a/paddle/fluid/operators/elementwise/elementwise_mul_op.cu +++ b/paddle/fluid/operators/elementwise/elementwise_mul_op.cu @@ -36,15 +36,25 @@ class ElementwiseMulKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { + int axis = -1; auto x_var = ctx.InputVar("X"); - PADDLE_ENFORCE_EQ(x_var != nullptr, true, - platform::errors::InvalidArgument( - "Cannot get input Variable X, Variable name = %s.", - ctx.InputName("X"))); + PADDLE_ENFORCE_NOT_NULL( + x_var, platform::errors::InvalidArgument( + "Cannot get input Variable X, Variable name = %s.", + ctx.InputName("X"))); auto* y = ctx.Input("Y"); framework::Tensor x, *z; - if (x_var->IsType()) { + std::vector ins; + std::vector outs; + const auto& cuda_ctx = + ctx.template device_context(); + + if (x_var->IsType()) { + x = x_var->Get(); + z = ctx.Output("Out"); + axis = PackTensorsIntoVector(ctx, &ins, &outs); + } else if (x_var->IsType()) { PADDLE_ENFORCE_EQ(y->dims().size() == 1 && y->dims()[0] == 1, true, platform::errors::InvalidArgument( "For elementwise_op, if X is Sparse, Y must be " @@ -58,21 +68,22 @@ class ElementwiseMulKernel out_sele->mutable_value()->Resize(x_sele.value().dims()); out_sele->mutable_value()->mutable_data(ctx.GetPlace(), x.type()); z = ctx.Output("Out")->mutable_value(); - } else if (x_var->IsType()) { - x = x_var->Get(); - z = ctx.Output("Out"); + z->mutable_data(ctx.GetPlace()); + outs.emplace_back(z); + ins.emplace_back(&x); + ins.emplace_back(y); + + axis = ctx.HasAttr("axis") ? ctx.Attr("axis") : -1; + axis = axis == -1 ? std::abs(y->dims().size() - x.dims().size()) : axis; } else { PADDLE_THROW(platform::errors::InvalidArgument( "X's type[%s] is not supported by elementwise_op. X's type should be " "LoDTensor or SelectedRows.", framework::ToTypeName(x_var->Type()))); } - z->mutable_data(ctx.GetPlace()); - std::vector ins = {&x, y}; - std::vector outs = {z}; LaunchElementwiseCudaKernel( - ctx, ins, &outs, CudaMulFunctor()); + cuda_ctx, ins, &outs, axis, CudaMulFunctor()); } }; diff --git a/paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h b/paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h index 9bcefa351296e..74216d6a9d4d5 100644 --- a/paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h +++ b/paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h @@ -506,26 +506,18 @@ void LaunchBroadcastElementwiseCudaKernel( template void LaunchElementwiseCudaKernel( - const framework::ExecutionContext &ctx, + const platform::CUDADeviceContext &cuda_ctx, const std::vector &ins, - std::vector *outs, Functor func) { - std::vector dims_size; + std::vector *outs, int axis, Functor func) { bool no_broadcast_flag = true; for (auto *in : ins) { no_broadcast_flag = ins[0]->dims() == in->dims(); - dims_size.emplace_back(in->dims().size()); } - const auto &cuda_ctx = - ctx.template device_context(); + if (no_broadcast_flag) { LaunchSameDimsElementwiseCudaKernel(cuda_ctx, ins, outs, func); } else { - int axis = ctx.HasAttr("axis") ? ctx.Attr("axis") : -1; - axis = axis == -1 - ? *std::max_element(dims_size.begin(), dims_size.end()) - - *std::min_element(dims_size.begin(), dims_size.end()) - : axis; LaunchBroadcastElementwiseCudaKernel(cuda_ctx, ins, outs, axis, func); } diff --git a/paddle/fluid/operators/elementwise/elementwise_op_function.h b/paddle/fluid/operators/elementwise/elementwise_op_function.h index 05b78bcf6ad66..d19c75eaf3de0 100644 --- a/paddle/fluid/operators/elementwise/elementwise_op_function.h +++ b/paddle/fluid/operators/elementwise/elementwise_op_function.h @@ -64,20 +64,24 @@ namespace operators { * To pack the input and output tnesors into vector for * LaunchElementwiseCudaKernel */ -template -void PackTensorsIntoVector(const framework::ExecutionContext &ctx, - std::vector *ins, - std::vector *outs) { +template +int PackTensorsIntoVector(const framework::ExecutionContext &ctx, + std::vector *ins, + std::vector *outs) { + int axis = -1; auto *x = ctx.Input("X"); auto *y = ctx.Input("Y"); auto *z = ctx.Output("Out"); - z->mutable_data(ctx.GetPlace()); - ins->emplace_back(x); + z->mutable_data(ctx.GetPlace()); outs->emplace_back(z); + ins->emplace_back(x); if (y != nullptr) { ins->emplace_back(y); + axis = ctx.HasAttr("axis") ? ctx.Attr("axis") : -1; + axis = axis == -1 ? std::abs(y->dims().size() - x->dims().size()) : axis; } + return axis; } /* diff --git a/paddle/fluid/operators/elementwise/elementwise_pow_op.cu b/paddle/fluid/operators/elementwise/elementwise_pow_op.cu index 2ff2d40102047..5335f274ef126 100644 --- a/paddle/fluid/operators/elementwise/elementwise_pow_op.cu +++ b/paddle/fluid/operators/elementwise/elementwise_pow_op.cu @@ -42,9 +42,12 @@ class ElementwisePowKernel void Compute(const framework::ExecutionContext& ctx) const override { std::vector ins; std::vector outs; - PackTensorsIntoVector(ctx, &ins, &outs); + const auto& cuda_ctx = + ctx.template device_context(); + + int axis = PackTensorsIntoVector(ctx, &ins, &outs); LaunchElementwiseCudaKernel( - ctx, ins, &outs, CudaPowFunctor()); + cuda_ctx, ins, &outs, axis, CudaPowFunctor()); } }; diff --git a/paddle/fluid/operators/elementwise/elementwise_sub_op.cu b/paddle/fluid/operators/elementwise/elementwise_sub_op.cu index 3555d7dbf8d11..da9610243f7c4 100644 --- a/paddle/fluid/operators/elementwise/elementwise_sub_op.cu +++ b/paddle/fluid/operators/elementwise/elementwise_sub_op.cu @@ -36,9 +36,12 @@ class ElementwiseSubKernel void Compute(const framework::ExecutionContext& ctx) const override { std::vector ins; std::vector outs; - PackTensorsIntoVector(ctx, &ins, &outs); + const auto& cuda_ctx = + ctx.template device_context(); + + int axis = PackTensorsIntoVector(ctx, &ins, &outs); LaunchElementwiseCudaKernel( - ctx, ins, &outs, CudaSubFunctor()); + cuda_ctx, ins, &outs, axis, CudaSubFunctor()); } };