From 5040226e7d176efff36612978b6990535d1ac5b4 Mon Sep 17 00:00:00 2001 From: liuzhenhai93 Date: Wed, 3 Jan 2024 02:02:29 +0000 Subject: [PATCH 1/5] custom op spmd rule register --- cmake/inference_lib.cmake | 8 + paddle/extension.h | 1 - paddle/phi/api/ext/op_meta_info.h | 13 + paddle/phi/api/ext/spmd_infer.h | 140 ++++ paddle/phi/api/lib/op_meta_info.cc | 15 + paddle/phi/infermeta/spmd_rules/rules.cc | 605 ++++++++++++++++++ paddle/phi/infermeta/spmd_rules/rules.h | 590 ----------------- test/cpp/auto_parallel/CMakeLists.txt | 24 +- .../auto_parallel/custom_op_spmd_rule_test.cc | 79 +++ 9 files changed, 878 insertions(+), 597 deletions(-) create mode 100644 paddle/phi/api/ext/spmd_infer.h create mode 100644 paddle/phi/infermeta/spmd_rules/rules.cc create mode 100644 test/cpp/auto_parallel/custom_op_spmd_rule_test.cc diff --git a/cmake/inference_lib.cmake b/cmake/inference_lib.cmake index 517ac24cccc72..5d20dd2a90650 100755 --- a/cmake/inference_lib.cmake +++ b/cmake/inference_lib.cmake @@ -328,10 +328,18 @@ copy( inference_lib_dist SRCS ${PADDLE_SOURCE_DIR}/paddle/phi/core/visit_type.h DSTS ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/paddle/phi/core/) + copy( inference_lib_dist SRCS ${PADDLE_SOURCE_DIR}/paddle/phi/core/hostdevice.h DSTS ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/paddle/phi/core/) + +copy( + inference_lib_dist + SRCS ${PADDLE_SOURCE_DIR}/paddle/phi/core/distributed/auto_parallel/*.h + DSTS ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/paddle/phi/core/distributed/auto_parallel/ +) + copy( inference_lib_dist SRCS ${PADDLE_SOURCE_DIR}/paddle/fluid/platform/init_phi.h diff --git a/paddle/extension.h b/paddle/extension.h index c60ee269fb173..c534a13a22d97 100644 --- a/paddle/extension.h +++ b/paddle/extension.h @@ -22,5 +22,4 @@ limitations under the License. */ #endif // For initialization of DeviceContextPool and MemoryMethod #include "paddle/fluid/platform/init_phi.h" - static paddle::InitPhi g_init_phi; diff --git a/paddle/phi/api/ext/op_meta_info.h b/paddle/phi/api/ext/op_meta_info.h index e5273958504fd..9362b531c9320 100644 --- a/paddle/phi/api/ext/op_meta_info.h +++ b/paddle/phi/api/ext/op_meta_info.h @@ -23,6 +23,7 @@ limitations under the License. */ #include "paddle/common/exception.h" #include "paddle/phi/api/include/dll_decl.h" #include "paddle/phi/api/include/tensor.h" +#include "paddle/phi/core/distributed/type_defs.h" #include "paddle/utils/any.h" #include "paddle/utils/none.h" #include "paddle/utils/optional.h" @@ -995,6 +996,12 @@ struct TrtGetOutputDimsFuncImpl { #endif ////////////////////// Op Meta Info ////////////////////// +class CustomSpmdInferTensorArg; +class CustomSpmdInferAttrArg; + +using InferSpmdFunc = phi::distributed::SpmdInfo (*)( + const std::vector& inputs, + const std::vector& attrs); class PADDLE_API OpMetaInfo { public: @@ -1023,6 +1030,9 @@ class PADDLE_API OpMetaInfo { // format: PD_INFER_DTYPE(...) OpMetaInfo& SetInferDtypeFn(InferDtypeFunc&& func); + // format: PD_INFER_SPMD_RULE(...) + OpMetaInfo& SetInferSpmdFn(InferSpmdFunc&& func); + #ifdef PADDLE_WITH_TENSORRT // format: PD_TRT_INFER_SHAPE(...) OpMetaInfo& SetTrtInferShapeFn(TrtGetOutputDimsFunc&& func); @@ -1045,6 +1055,7 @@ class PADDLE_API OpMetaInfo { KernelFunc kernel_fn_{nullptr}; InferShapeFunc infer_shape_fn_{nullptr}; InferDtypeFunc infer_dtype_fn_{nullptr}; + InferSpmdFunc infer_spmd_fn_{nullptr}; #ifdef PADDLE_WITH_TENSORRT TrtGetOutputDimsFunc trt_infer_shape_fn_{nullptr}; std::vector trt_supports_format_config_; @@ -1068,6 +1079,7 @@ class OpMetaInfoHelper { static const KernelFunc& GetKernelFn(const paddle::OpMetaInfo& info); static const InferShapeFunc& GetInferShapeFn(const paddle::OpMetaInfo& info); static const InferDtypeFunc& GetInferDtypeFn(const paddle::OpMetaInfo& info); + static const InferSpmdFunc& GetInferSpmdFn(const paddle::OpMetaInfo& info); #ifdef PADDLE_WITH_TENSORRT static const TrtGetOutputDimsFunc& GetTrtInferShapeFn( @@ -1108,6 +1120,7 @@ class PADDLE_API OpMetaInfoBuilder { OpMetaInfoBuilder& SetKernelFn(KernelFunc func); OpMetaInfoBuilder& SetInferShapeFn(InferShapeFunc func); OpMetaInfoBuilder& SetInferDtypeFn(InferDtypeFunc func); + OpMetaInfoBuilder& SetInferSpmdFn(InferSpmdFunc func); #ifdef PADDLE_WITH_TENSORRT OpMetaInfoBuilder& SetTrtInferShapeFn(TrtGetOutputDimsFunc func); diff --git a/paddle/phi/api/ext/spmd_infer.h b/paddle/phi/api/ext/spmd_infer.h new file mode 100644 index 0000000000000..df4d177054a9a --- /dev/null +++ b/paddle/phi/api/ext/spmd_infer.h @@ -0,0 +1,140 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.h" +#include "paddle/phi/core/distributed/type_defs.h" + +namespace paddle { + +using CustomSpmdInferTensorArg = + paddle::variant>; + +using CustomSpmdInferAttrArg = paddle::any; +template +struct SpmdInferHelperTypeEnd {}; + +#define PD_SPECIALIZE_SpmdInferHelper_FOR_AttrType(attr_type) \ + template \ + struct SpmdInferHelper { \ + template \ + static phi::distributed::SpmdInfo InferSpmd( \ + const std::vector& inputs, \ + const std::vector& attrs, \ + const PreviousArgs&... pargs) { \ + try { \ + attr_type arg = paddle::any_cast(attrs[attr_idx]); \ + return SpmdInferHelper::template InferSpmd( \ + inputs, attrs, pargs..., arg); \ + } catch (paddle::bad_any_cast&) { \ + PD_THROW( \ + "Attribute cast error in custom operator SpmdInferFunc " \ + "function. " \ + "Expected " #attr_type \ + " value. SpmdInferFunc's attribute list must be exactly " \ + "same " \ + "as " \ + "Forward " \ + "KernelFn's attribute list except std::vector " \ + "attribute."); \ + } \ + } \ + } + +template +struct SpmdInferImpl; + +template +struct SpmdInferImpl { + static phi::distributed::SpmdInfo InferSpmd( + const std::vector& inputs, + const std::vector& attrs) { + return SpmdInferHelper>:: + template InferSpmd<0, 0>(inputs, attrs); + } + + private: + template + struct SpmdInferHelper; + + // Handle args for general tensor input case + template + struct SpmdInferHelper { + template + static phi::distributed::SpmdInfo InferSpmd( + const std::vector& inputs, + const std::vector& attrs, + PreviousArgs&... pargs) { + auto& arg = + PADDLE_GET_CONST(phi::distributed::DistMetaTensor, inputs[in_idx]); + return SpmdInferHelper::template InferSpmd( + inputs, attrs, pargs..., arg); + } + }; + + // Handle args for vector of Tensor input case + template + struct SpmdInferHelper&, + Tail...> { + template + static phi::distributed::SpmdInfo InferSpmd( + const std::vector& inputs, + const std::vector& attrs, + PreviousArgs&... pargs) { + auto& arg = PADDLE_GET_CONST( + std::vector, inputs[in_idx]); + return SpmdInferHelper::template InferSpmd( + inputs, attrs, pargs..., arg); + } + }; + + PD_SPECIALIZE_SpmdInferHelper_FOR_AttrType(bool); + PD_SPECIALIZE_SpmdInferHelper_FOR_AttrType(int); + PD_SPECIALIZE_SpmdInferHelper_FOR_AttrType(float); + PD_SPECIALIZE_SpmdInferHelper_FOR_AttrType(int64_t); + PD_SPECIALIZE_SpmdInferHelper_FOR_AttrType(const std::string&); + PD_SPECIALIZE_SpmdInferHelper_FOR_AttrType(const std::vector&); + PD_SPECIALIZE_SpmdInferHelper_FOR_AttrType(const std::vector&); + PD_SPECIALIZE_SpmdInferHelper_FOR_AttrType(const std::vector&); + PD_SPECIALIZE_SpmdInferHelper_FOR_AttrType(const std::vector&); + PD_SPECIALIZE_SpmdInferHelper_FOR_AttrType(const bool&); + PD_SPECIALIZE_SpmdInferHelper_FOR_AttrType(const int&); + PD_SPECIALIZE_SpmdInferHelper_FOR_AttrType(const float&); + PD_SPECIALIZE_SpmdInferHelper_FOR_AttrType(const int64_t&); + + PD_SPECIALIZE_SpmdInferHelper_FOR_AttrType(std::string); + PD_SPECIALIZE_SpmdInferHelper_FOR_AttrType(std::vector); + PD_SPECIALIZE_SpmdInferHelper_FOR_AttrType(std::vector); + PD_SPECIALIZE_SpmdInferHelper_FOR_AttrType(std::vector); + PD_SPECIALIZE_SpmdInferHelper_FOR_AttrType(const std::vector); + + // end: base template + template + struct SpmdInferHelper> { + template + static phi::distributed::SpmdInfo InferSpmd( + const std::vector& inputs, + const std::vector& attrs, + PreviousArgs&... pargs) { + return impl_fn(pargs...); + } + }; +}; + +#define PD_INFER_SPMD_RULE(...) \ + ::paddle::SpmdInferImpl::InferSpmd + +} // namespace paddle diff --git a/paddle/phi/api/lib/op_meta_info.cc b/paddle/phi/api/lib/op_meta_info.cc index 857c2930da45f..3cef3187193f7 100644 --- a/paddle/phi/api/lib/op_meta_info.cc +++ b/paddle/phi/api/lib/op_meta_info.cc @@ -358,6 +358,11 @@ OpMetaInfo& OpMetaInfo::SetInferDtypeFn(InferDtypeFunc&& func) { return *this; } +OpMetaInfo& OpMetaInfo::SetInferSpmdFn(InferSpmdFunc&& func) { + infer_spmd_fn_ = std::forward(func); + return *this; +} + #ifdef PADDLE_WITH_TENSORRT OpMetaInfo& OpMetaInfo::SetTrtInferShapeFn(TrtGetOutputDimsFunc&& func) { trt_infer_shape_fn_ = std::forward(func); @@ -407,6 +412,11 @@ const InferDtypeFunc& OpMetaInfoHelper::GetInferDtypeFn( return info.infer_dtype_fn_; } +const InferSpmdFunc& OpMetaInfoHelper::GetInferSpmdFn( + const paddle::OpMetaInfo& info) { + return info.infer_spmd_fn_; +} + #ifdef PADDLE_WITH_TENSORRT const TrtGetOutputDimsFunc& OpMetaInfoHelper::GetTrtInferShapeFn( const paddle::OpMetaInfo& info) { @@ -559,6 +569,11 @@ OpMetaInfoBuilder& OpMetaInfoBuilder::SetInferDtypeFn(InferDtypeFunc func) { return *this; } +OpMetaInfoBuilder& OpMetaInfoBuilder::SetInferSpmdFn(InferSpmdFunc func) { + info_ptr_->SetInferSpmdFn(std::forward(func)); + return *this; +} + #ifdef PADDLE_WITH_TENSORRT OpMetaInfoBuilder& OpMetaInfoBuilder::SetTrtInferShapeFn( TrtGetOutputDimsFunc func) { diff --git a/paddle/phi/infermeta/spmd_rules/rules.cc b/paddle/phi/infermeta/spmd_rules/rules.cc new file mode 100644 index 0000000000000..cef950dfd2d81 --- /dev/null +++ b/paddle/phi/infermeta/spmd_rules/rules.cc @@ -0,0 +1,605 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/infermeta/spmd_rules/rules.h" + +/** + * Design Notes: + * + * 1. SPMD info is the special meta info of DistTensor, so we put Spmd infer + * functions in `infermeta` directory. + * + * 2. Since the infer functions of Spmd forward and backward are closely related + * and need to be registered together, we manage them together in one file. + * + * 3. SPMD rules are much smaller than infermeta function, and we manage files + * in operator units. + * + * 4. The previous registration used some compile-time regular matching methods, + * which was less flexible, and the registration of SPMD rules here is declare + * directly in the header file + */ + +namespace phi { +namespace distributed { + +// matmul rule +PD_REGISTER_SPMD_RULE(matmul, + PD_INFER_SPMD(phi::distributed::MatmulInferSpmd), + PD_INFER_SPMD(phi::distributed::MatmulInferSpmdReverse)); +PD_REGISTER_SPMD_RULE(matmul_v2, // static mode + PD_INFER_SPMD(phi::distributed::MatmulInferSpmd), + PD_INFER_SPMD(phi::distributed::MatmulInferSpmdReverse)); + +PD_REGISTER_SPMD_RULE( + elementwise_unary, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); + +PD_REGISTER_SPMD_RULE( + elementwise_binary, + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse)); + +// default data parallel rule +PD_REGISTER_SPMD_RULE( + default_data_parallel, + PD_INFER_SPMD(phi::distributed::DefaultDataParallelInferSpmd), + PD_INFER_SPMD(phi::distributed::DefaultDataParallelInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + default_, + PD_INFER_SPMD(phi::distributed::DefaultDataParallelInferSpmd), + PD_INFER_SPMD(phi::distributed::DefaultDataParallelInferSpmdReverse)); + +// fused rope +PD_REGISTER_SPMD_RULE( + fused_rotary_position_embedding, + PD_INFER_SPMD(phi::distributed::FusedRopeInferSpmd), + PD_INFER_SPMD(phi::distributed::FusedRopeInferSpmdReverse)); + +// replicated rule /* for unittest */ +PD_REGISTER_SPMD_RULE( + replicated, + PD_INFER_SPMD(phi::distributed::ReplicatedInferSpmd), + PD_INFER_SPMD(phi::distributed::ReplicatedInferSpmdReverse)); + +// unsqueeze rule +PD_REGISTER_SPMD_RULE( + unsqueeze, + PD_INFER_SPMD(phi::distributed::UnsqueezeInferSpmd), + PD_INFER_SPMD(phi::distributed::UnsqueezeInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + unsqueeze2, + PD_INFER_SPMD(phi::distributed::UnsqueezeInferSpmd), + PD_INFER_SPMD(phi::distributed::UnsqueezeInferSpmdReverse)); + +// elementwise unary rule +PD_REGISTER_SPMD_RULE( + assign, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + hardswish, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + mish, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + relu6, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + swish, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + acos, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + acosh, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + asin, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + asinh, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + atan, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + atanh, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + bernoulli, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + bitwise_not, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + ceil, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + celu, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + clip, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + conj, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + cos, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + cosh, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + digamma, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + elu, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + erf, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + erfinv, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + exp, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + expm1, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + fill, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + floor, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + gelu, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + hardshrink, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + hardsigmoid, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + hardtanh, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + label_smooth, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + leaky_relu, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + lgamma, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + log, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + log10, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + log1p, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + log2, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + logical_not, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + logit, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + logsigmoid, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + poisson, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + pow, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + reciprocal, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + relu, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + round, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + rsqrt, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + scale, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + selu, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + sigmoid, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + sign, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + silu, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + sin, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + sinh, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + softplus, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + softshrink, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + softsign, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + sqrt, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + square, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + stanh, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + tan, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + tanh, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + tanh_shrink, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + thresholded_relu, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + trunc, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + dropout, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); + +// elementwise binary rule +PD_REGISTER_SPMD_RULE( + add, + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + elementwise_add, + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + divide, + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + elementwise_div, + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + elementwise_pow, + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + floor_divide, + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + fmin, + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + heaviside, + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + maximum, + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + minimum, + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + multiply, + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + elementwise_mul, + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + remainder, + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + subtract, + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + bitwise_and, + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + bitwise_or, + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + bitwise_xor, + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + fmax, + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + logical_and, + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + logical_or, + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + logical_xor, + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse)); + +PD_REGISTER_SPMD_RULE( + not_equal, + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse)); + +// TODO(pkuzyc): add multiary elementwise rule + +// reduction rule +PD_REGISTER_SPMD_RULE( + all, + PD_INFER_SPMD(phi::distributed::ReductionInferSpmd), + PD_INFER_SPMD(phi::distributed::ReductionInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + amax, + PD_INFER_SPMD(phi::distributed::ReductionInferSpmd), + PD_INFER_SPMD(phi::distributed::ReductionInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + amin, + PD_INFER_SPMD(phi::distributed::ReductionInferSpmd), + PD_INFER_SPMD(phi::distributed::ReductionInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + any, + PD_INFER_SPMD(phi::distributed::ReductionInferSpmd), + PD_INFER_SPMD(phi::distributed::ReductionInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + frobenius_norm, + PD_INFER_SPMD(phi::distributed::ReductionInferSpmd), + PD_INFER_SPMD(phi::distributed::ReductionInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + max, + PD_INFER_SPMD(phi::distributed::ReductionInferSpmd), + PD_INFER_SPMD(phi::distributed::ReductionInferSpmdReverse)); + +PD_REGISTER_SPMD_RULE( + reduce_max, + PD_INFER_SPMD(phi::distributed::ReductionInferSpmd), + PD_INFER_SPMD(phi::distributed::ReductionInferSpmdReverse)); + +PD_REGISTER_SPMD_RULE( + min, + PD_INFER_SPMD(phi::distributed::ReductionInferSpmd), + PD_INFER_SPMD(phi::distributed::ReductionInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + prod, + PD_INFER_SPMD(phi::distributed::ReductionInferSpmd), + PD_INFER_SPMD(phi::distributed::ReductionInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + sum, + PD_INFER_SPMD(phi::distributed::ReductionInferSpmd), + PD_INFER_SPMD(phi::distributed::ReductionInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + reduce_sum, // static + PD_INFER_SPMD(phi::distributed::ReductionInferSpmd), + PD_INFER_SPMD(phi::distributed::ReductionInferSpmdReverse)); + +// layer_norm +PD_REGISTER_SPMD_RULE( + layer_norm, + PD_INFER_SPMD(phi::distributed::LayerNormInferSpmd), + PD_INFER_SPMD(phi::distributed::LayerNormInferSpmdReverse)); + +PD_REGISTER_SPMD_RULE( + flash_attention, + PD_INFER_SPMD(phi::distributed::FlashAttInferSpmdStatic), + PD_INFER_SPMD(phi::distributed::FlashAttInferSpmdReverse)); + +// reshape rule +PD_REGISTER_SPMD_RULE(reshape, + PD_INFER_SPMD(phi::distributed::ReshapeInferSpmd), + PD_INFER_SPMD(phi::distributed::ReshapeInferSpmdReverse)); +PD_REGISTER_SPMD_RULE(reshape2, + PD_INFER_SPMD(phi::distributed::ReshapeInferSpmd), + PD_INFER_SPMD(phi::distributed::ReshapeInferSpmdReverse)); + +// squeeze rule +PD_REGISTER_SPMD_RULE(squeeze, + PD_INFER_SPMD(phi::distributed::SqueezeInferSpmd), + PD_INFER_SPMD(phi::distributed::SqueezeInferSpmdReverse)); +// flatten rule +PD_REGISTER_SPMD_RULE(flatten, + PD_INFER_SPMD(phi::distributed::FlattenInferSpmd), + PD_INFER_SPMD(phi::distributed::FlattenInferSpmdReverse)); + +// embedding rule +PD_REGISTER_SPMD_RULE( + embedding, + PD_INFER_SPMD(phi::distributed::EmbeddingInferSpmd), + PD_INFER_SPMD(phi::distributed::EmbeddingInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + lookup_table_v2, + PD_INFER_SPMD(phi::distributed::EmbeddingInferSpmd), + PD_INFER_SPMD(phi::distributed::EmbeddingInferSpmdReverse)); + +// split rule +PD_REGISTER_SPMD_RULE(split, + PD_INFER_SPMD(phi::distributed::SplitInferSpmd), + PD_INFER_SPMD(phi::distributed::SplitInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + split_with_num, + PD_INFER_SPMD(phi::distributed::SplitWithNumInferSpmd), + PD_INFER_SPMD(phi::distributed::SplitWithNumInferSpmdReverse)); + +// slice rule +PD_REGISTER_SPMD_RULE(slice, + PD_INFER_SPMD(phi::distributed::SliceInferSpmd), + PD_INFER_SPMD(phi::distributed::SliceInferSpmdReverse)); + +PD_REGISTER_SPMD_RULE(concat, + PD_INFER_SPMD(phi::distributed::ConcatInferSpmd), + PD_INFER_SPMD(phi::distributed::ConcatInferSpmdReverse)); + +// transpose rule +PD_REGISTER_SPMD_RULE( + transpose, + PD_INFER_SPMD(phi::distributed::TransposeInferSpmd), + PD_INFER_SPMD(phi::distributed::TransposeInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + transpose2, + PD_INFER_SPMD(phi::distributed::TransposeInferSpmd), + PD_INFER_SPMD(phi::distributed::TransposeInferSpmdReverse)); + +// softmax rule +PD_REGISTER_SPMD_RULE(softmax, + PD_INFER_SPMD(phi::distributed::SoftmaxInferSpmd), + PD_INFER_SPMD(phi::distributed::SoftmaxInferSpmdReverse)); + +PD_REGISTER_SPMD_RULE(log_softmax, + PD_INFER_SPMD(phi::distributed::SoftmaxInferSpmd), + PD_INFER_SPMD(phi::distributed::SoftmaxInferSpmdReverse)); + +PD_REGISTER_SPMD_RULE(where, + PD_INFER_SPMD(phi::distributed::WhereInferSpmd), + PD_INFER_SPMD(phi::distributed::WhereInferSpmdReverse)); + +PD_REGISTER_SPMD_RULE(triu, + PD_INFER_SPMD(phi::distributed::TriuInferSpmd), + PD_INFER_SPMD(phi::distributed::TriuInferSpmdReverse)); + +PD_REGISTER_SPMD_RULE( + tril_triu, + PD_INFER_SPMD(phi::distributed::TrilTriuInferSpmd), + PD_INFER_SPMD(phi::distributed::TrilTriuInferSpmdReverse)); + +PD_REGISTER_SPMD_RULE(tile, + PD_INFER_SPMD(phi::distributed::TileInferSpmd), + PD_INFER_SPMD(phi::distributed::TileInferSpmdReverse)); + +// cross_entropy_with_softmax +PD_REGISTER_SPMD_RULE( + cross_entropy_with_softmax, + PD_INFER_SPMD(phi::distributed::CrossEntropyWithSoftmaxInferSpmd), + PD_INFER_SPMD(phi::distributed::CrossEntropyWithSoftmaxInferSpmdReverse)); + +PD_REGISTER_SPMD_RULE( + softmax_with_cross_entropy, + PD_INFER_SPMD(phi::distributed::CrossEntropyWithSoftmaxInferSpmd), + PD_INFER_SPMD(phi::distributed::CrossEntropyWithSoftmaxInferSpmdReverse)); + +// fused_linear_param_grad_add got no reverse infer spmd rule +PD_REGISTER_SPMD_RULE( + fused_linear_param_grad_add, + PD_INFER_SPMD(phi::distributed::FusedLinearParamGradAddInferSpmd), + PD_INFER_SPMD( + phi::distributed::FusedLinearParamGradAddInferSpmdFakeReverse)); + +} // namespace distributed +} // namespace phi diff --git a/paddle/phi/infermeta/spmd_rules/rules.h b/paddle/phi/infermeta/spmd_rules/rules.h index 1015f61802bc4..37eab9f57ba73 100644 --- a/paddle/phi/infermeta/spmd_rules/rules.h +++ b/paddle/phi/infermeta/spmd_rules/rules.h @@ -44,593 +44,3 @@ limitations under the License. */ #include "paddle/phi/infermeta/spmd_rules/triu.h" #include "paddle/phi/infermeta/spmd_rules/unsqueeze.h" #include "paddle/phi/infermeta/spmd_rules/where.h" - -/** - * Design Notes: - * - * 1. SPMD info is the special meta info of DistTensor, so we put Spmd infer - * functions in `infermeta` directory. - * - * 2. Since the infer functions of Spmd forward and backward are closely related - * and need to be registered together, we manage them together in one file. - * - * 3. SPMD rules are much smaller than infermeta function, and we manage files - * in operator units. - * - * 4. The previous registration used some compile-time regular matching methods, - * which was less flexible, and the registration of SPMD rules here is declare - * directly in the header file - */ - -namespace phi { -namespace distributed { - -// matmul rule -PD_REGISTER_SPMD_RULE(matmul, - PD_INFER_SPMD(phi::distributed::MatmulInferSpmd), - PD_INFER_SPMD(phi::distributed::MatmulInferSpmdReverse)); -PD_REGISTER_SPMD_RULE(matmul_v2, // static mode - PD_INFER_SPMD(phi::distributed::MatmulInferSpmd), - PD_INFER_SPMD(phi::distributed::MatmulInferSpmdReverse)); - -PD_REGISTER_SPMD_RULE( - elementwise_unary, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); - -PD_REGISTER_SPMD_RULE( - elementwise_binary, - PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse)); - -// default data parallel rule -PD_REGISTER_SPMD_RULE( - default_data_parallel, - PD_INFER_SPMD(phi::distributed::DefaultDataParallelInferSpmd), - PD_INFER_SPMD(phi::distributed::DefaultDataParallelInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - default_, - PD_INFER_SPMD(phi::distributed::DefaultDataParallelInferSpmd), - PD_INFER_SPMD(phi::distributed::DefaultDataParallelInferSpmdReverse)); - -// fused rope -PD_REGISTER_SPMD_RULE( - fused_rotary_position_embedding, - PD_INFER_SPMD(phi::distributed::FusedRopeInferSpmd), - PD_INFER_SPMD(phi::distributed::FusedRopeInferSpmdReverse)); - -// replicated rule /* for unittest */ -PD_REGISTER_SPMD_RULE( - replicated, - PD_INFER_SPMD(phi::distributed::ReplicatedInferSpmd), - PD_INFER_SPMD(phi::distributed::ReplicatedInferSpmdReverse)); - -// unsqueeze rule -PD_REGISTER_SPMD_RULE( - unsqueeze, - PD_INFER_SPMD(phi::distributed::UnsqueezeInferSpmd), - PD_INFER_SPMD(phi::distributed::UnsqueezeInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - unsqueeze2, - PD_INFER_SPMD(phi::distributed::UnsqueezeInferSpmd), - PD_INFER_SPMD(phi::distributed::UnsqueezeInferSpmdReverse)); - -// elementwise unary rule -PD_REGISTER_SPMD_RULE( - assign, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - hardswish, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - mish, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - relu6, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - swish, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - acos, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - acosh, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - asin, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - asinh, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - atan, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - atanh, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - bernoulli, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - bitwise_not, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - ceil, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - celu, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - clip, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - conj, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - cos, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - cosh, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - digamma, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - elu, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - erf, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - erfinv, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - exp, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - expm1, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - fill, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - floor, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - gelu, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - hardshrink, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - hardsigmoid, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - hardtanh, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - label_smooth, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - leaky_relu, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - lgamma, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - log, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - log10, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - log1p, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - log2, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - logical_not, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - logit, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - logsigmoid, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - poisson, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - pow, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - reciprocal, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - relu, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - round, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - rsqrt, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - scale, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - selu, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - sigmoid, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - sign, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - silu, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - sin, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - sinh, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - softplus, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - softshrink, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - softsign, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - sqrt, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - square, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - stanh, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - tan, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - tanh, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - tanh_shrink, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - thresholded_relu, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - trunc, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - dropout, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); - -// elementwise binary rule -PD_REGISTER_SPMD_RULE( - add, - PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - elementwise_add, - PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - divide, - PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - elementwise_div, - PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - elementwise_pow, - PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - floor_divide, - PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - fmin, - PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - heaviside, - PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - maximum, - PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - minimum, - PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - multiply, - PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - elementwise_mul, - PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - remainder, - PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - subtract, - PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - bitwise_and, - PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - bitwise_or, - PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - bitwise_xor, - PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - fmax, - PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - logical_and, - PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - logical_or, - PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - logical_xor, - PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse)); - -PD_REGISTER_SPMD_RULE( - not_equal, - PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse)); - -// TODO(pkuzyc): add multiary elementwise rule - -// reduction rule -PD_REGISTER_SPMD_RULE( - all, - PD_INFER_SPMD(phi::distributed::ReductionInferSpmd), - PD_INFER_SPMD(phi::distributed::ReductionInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - amax, - PD_INFER_SPMD(phi::distributed::ReductionInferSpmd), - PD_INFER_SPMD(phi::distributed::ReductionInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - amin, - PD_INFER_SPMD(phi::distributed::ReductionInferSpmd), - PD_INFER_SPMD(phi::distributed::ReductionInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - any, - PD_INFER_SPMD(phi::distributed::ReductionInferSpmd), - PD_INFER_SPMD(phi::distributed::ReductionInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - frobenius_norm, - PD_INFER_SPMD(phi::distributed::ReductionInferSpmd), - PD_INFER_SPMD(phi::distributed::ReductionInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - max, - PD_INFER_SPMD(phi::distributed::ReductionInferSpmd), - PD_INFER_SPMD(phi::distributed::ReductionInferSpmdReverse)); - -PD_REGISTER_SPMD_RULE( - reduce_max, - PD_INFER_SPMD(phi::distributed::ReductionInferSpmd), - PD_INFER_SPMD(phi::distributed::ReductionInferSpmdReverse)); - -PD_REGISTER_SPMD_RULE( - min, - PD_INFER_SPMD(phi::distributed::ReductionInferSpmd), - PD_INFER_SPMD(phi::distributed::ReductionInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - prod, - PD_INFER_SPMD(phi::distributed::ReductionInferSpmd), - PD_INFER_SPMD(phi::distributed::ReductionInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - sum, - PD_INFER_SPMD(phi::distributed::ReductionInferSpmd), - PD_INFER_SPMD(phi::distributed::ReductionInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - reduce_sum, // static - PD_INFER_SPMD(phi::distributed::ReductionInferSpmd), - PD_INFER_SPMD(phi::distributed::ReductionInferSpmdReverse)); - -// layer_norm -PD_REGISTER_SPMD_RULE( - layer_norm, - PD_INFER_SPMD(phi::distributed::LayerNormInferSpmd), - PD_INFER_SPMD(phi::distributed::LayerNormInferSpmdReverse)); - -PD_REGISTER_SPMD_RULE( - flash_attention, - PD_INFER_SPMD(phi::distributed::FlashAttInferSpmdStatic), - PD_INFER_SPMD(phi::distributed::FlashAttInferSpmdReverse)); - -// reshape rule -PD_REGISTER_SPMD_RULE(reshape, - PD_INFER_SPMD(phi::distributed::ReshapeInferSpmd), - PD_INFER_SPMD(phi::distributed::ReshapeInferSpmdReverse)); -PD_REGISTER_SPMD_RULE(reshape2, - PD_INFER_SPMD(phi::distributed::ReshapeInferSpmd), - PD_INFER_SPMD(phi::distributed::ReshapeInferSpmdReverse)); - -// squeeze rule -PD_REGISTER_SPMD_RULE(squeeze, - PD_INFER_SPMD(phi::distributed::SqueezeInferSpmd), - PD_INFER_SPMD(phi::distributed::SqueezeInferSpmdReverse)); -// flatten rule -PD_REGISTER_SPMD_RULE(flatten, - PD_INFER_SPMD(phi::distributed::FlattenInferSpmd), - PD_INFER_SPMD(phi::distributed::FlattenInferSpmdReverse)); - -// embedding rule -PD_REGISTER_SPMD_RULE( - embedding, - PD_INFER_SPMD(phi::distributed::EmbeddingInferSpmd), - PD_INFER_SPMD(phi::distributed::EmbeddingInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - lookup_table_v2, - PD_INFER_SPMD(phi::distributed::EmbeddingInferSpmd), - PD_INFER_SPMD(phi::distributed::EmbeddingInferSpmdReverse)); - -// split rule -PD_REGISTER_SPMD_RULE(split, - PD_INFER_SPMD(phi::distributed::SplitInferSpmd), - PD_INFER_SPMD(phi::distributed::SplitInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - split_with_num, - PD_INFER_SPMD(phi::distributed::SplitWithNumInferSpmd), - PD_INFER_SPMD(phi::distributed::SplitWithNumInferSpmdReverse)); - -// slice rule -PD_REGISTER_SPMD_RULE(slice, - PD_INFER_SPMD(phi::distributed::SliceInferSpmd), - PD_INFER_SPMD(phi::distributed::SliceInferSpmdReverse)); - -PD_REGISTER_SPMD_RULE(concat, - PD_INFER_SPMD(phi::distributed::ConcatInferSpmd), - PD_INFER_SPMD(phi::distributed::ConcatInferSpmdReverse)); - -// transpose rule -PD_REGISTER_SPMD_RULE( - transpose, - PD_INFER_SPMD(phi::distributed::TransposeInferSpmd), - PD_INFER_SPMD(phi::distributed::TransposeInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - transpose2, - PD_INFER_SPMD(phi::distributed::TransposeInferSpmd), - PD_INFER_SPMD(phi::distributed::TransposeInferSpmdReverse)); - -// softmax rule -PD_REGISTER_SPMD_RULE(softmax, - PD_INFER_SPMD(phi::distributed::SoftmaxInferSpmd), - PD_INFER_SPMD(phi::distributed::SoftmaxInferSpmdReverse)); - -PD_REGISTER_SPMD_RULE(log_softmax, - PD_INFER_SPMD(phi::distributed::SoftmaxInferSpmd), - PD_INFER_SPMD(phi::distributed::SoftmaxInferSpmdReverse)); - -PD_REGISTER_SPMD_RULE(where, - PD_INFER_SPMD(phi::distributed::WhereInferSpmd), - PD_INFER_SPMD(phi::distributed::WhereInferSpmdReverse)); - -PD_REGISTER_SPMD_RULE(triu, - PD_INFER_SPMD(phi::distributed::TriuInferSpmd), - PD_INFER_SPMD(phi::distributed::TriuInferSpmdReverse)); - -PD_REGISTER_SPMD_RULE( - tril_triu, - PD_INFER_SPMD(phi::distributed::TrilTriuInferSpmd), - PD_INFER_SPMD(phi::distributed::TrilTriuInferSpmdReverse)); - -PD_REGISTER_SPMD_RULE(tile, - PD_INFER_SPMD(phi::distributed::TileInferSpmd), - PD_INFER_SPMD(phi::distributed::TileInferSpmdReverse)); - -// cross_entropy_with_softmax -PD_REGISTER_SPMD_RULE( - cross_entropy_with_softmax, - PD_INFER_SPMD(phi::distributed::CrossEntropyWithSoftmaxInferSpmd), - PD_INFER_SPMD(phi::distributed::CrossEntropyWithSoftmaxInferSpmdReverse)); - -PD_REGISTER_SPMD_RULE( - softmax_with_cross_entropy, - PD_INFER_SPMD(phi::distributed::CrossEntropyWithSoftmaxInferSpmd), - PD_INFER_SPMD(phi::distributed::CrossEntropyWithSoftmaxInferSpmdReverse)); - -// fused_linear_param_grad_add got no reverse infer spmd rule -PD_REGISTER_SPMD_RULE( - fused_linear_param_grad_add, - PD_INFER_SPMD(phi::distributed::FusedLinearParamGradAddInferSpmd), - PD_INFER_SPMD( - phi::distributed::FusedLinearParamGradAddInferSpmdFakeReverse)); - -} // namespace distributed -} // namespace phi diff --git a/test/cpp/auto_parallel/CMakeLists.txt b/test/cpp/auto_parallel/CMakeLists.txt index 311958d2e1031..39a7cd28f1c6d 100644 --- a/test/cpp/auto_parallel/CMakeLists.txt +++ b/test/cpp/auto_parallel/CMakeLists.txt @@ -15,20 +15,32 @@ if(WITH_DISTRIBUTE) SRCS dist_tensor_test.cc DEPS phi common) - paddle_test(spmd_rule_test SRCS spmd_rule_test.cc DEPS spmd_rule_test_util) + paddle_test(spmd_rule_test SRCS spmd_rule_test.cc DEPS spmd_rule_test_util + spmd_rules) paddle_test(softmax_grad_spmd_rule_test SRCS softmax_grad_spmd_rule_test.cc - DEPS spmd_rule_test_util) + DEPS spmd_rule_test_util spmd_rules) paddle_test(tile_spmd_rule_test SRCS tile_spmd_rule_test.cc DEPS - spmd_rule_test_util) + spmd_rule_test_util spmd_rules) paddle_test( fused_linear_param_grad_add_spmd_rule_test SRCS - fused_linear_param_grad_add_spmd_rule_test.cc DEPS spmd_rule_test_util) + fused_linear_param_grad_add_spmd_rule_test.cc DEPS spmd_rule_test_util + spmd_rules) - paddle_test(cross_entropy_softmax_spmd_rule_test SRCS - cross_entropy_softmax_spmd_rule_test.cc DEPS spmd_rule_test_util) + paddle_test( + cross_entropy_softmax_spmd_rule_test SRCS + cross_entropy_softmax_spmd_rule_test.cc DEPS spmd_rule_test_util spmd_rules) + + paddle_test( + custom_op_spmd_rule_test + SRCS + custom_op_spmd_rule_test.cc + DEPS + spmd_rule_test_util + spmd_rules + phi) endif() diff --git a/test/cpp/auto_parallel/custom_op_spmd_rule_test.cc b/test/cpp/auto_parallel/custom_op_spmd_rule_test.cc new file mode 100644 index 0000000000000..66fe9217f897c --- /dev/null +++ b/test/cpp/auto_parallel/custom_op_spmd_rule_test.cc @@ -0,0 +1,79 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/api/ext/spmd_infer.h" +#include "test/cpp/auto_parallel/spmd_rule_test_util.h" + +namespace paddle { +namespace distributed { +namespace auto_parallel { +TEST(CustomOp, Ctor) { + // test with concat rule + std::vector mesh_shape = {2, 2}; + std::vector process_ids = {0, 1, 2, 3}; + std::vector dim_names = {"x", "y"}; + ProcessMesh process_mesh(mesh_shape, process_ids, dim_names); + + std::vector> shapes = { + {16, 16, 16}, {4, 16, 16}, {2, 16, 16}}; + std::vector> dim_mappings = { + {-1, 0, 1}, {-1, 1, 0}, {-1, -1, 0}}; + std::vector> partial_status = {{}, {}, {1}}; + + auto build_inputs = [&] { + std::vector inputs; + for (int i = 0; i < 3; i++) { + auto t_dist_attr = TensorDistAttr(); + t_dist_attr.set_process_mesh(process_mesh); + t_dist_attr.set_dims_mapping(dim_mappings[i]); + t_dist_attr.set_dynamic_dims({false, false, false}); + auto input = phi::distributed::DistMetaTensor( + common::make_ddim(shapes[i]), t_dist_attr); + inputs.push_back(input); + } + return inputs; + }; + + // test 1, inputs are aligned according to cost, and partial status is cleared + auto inputs = build_inputs(); + + auto forward_spmd_func = + PD_INFER_SPMD_RULE(phi::distributed::ConcatInferSpmd); + int axis = 0; + std::vector infer_inputs = {inputs}; + std::vector attrs = {axis}; + + auto infered_dist_attrs = forward_spmd_func(infer_inputs, attrs); + // list of tensor => sigle tensor + EXPECT_EQ(infered_dist_attrs.first.size(), static_cast(1)); + EXPECT_EQ(infered_dist_attrs.second.size(), static_cast(1)); + EXPECT_TRUE( + paddle::holds_alternative>( + infered_dist_attrs.first[0])); + EXPECT_TRUE(paddle::holds_alternative( + infered_dist_attrs.second[0])); + auto& inputs_infer1 = + PADDLE_GET_CONST(std::vector, + infered_dist_attrs.first[0]); + + for (auto e : inputs_infer1) { + check_dim_mapping(e, {-1, 1, 0}); + check_partial_dims(e, {}); + } + check_dim_mapping(infered_dist_attrs.second[0], {-1, 1, 0}); + check_partial_dims(infered_dist_attrs.second[0], {}); +} +} // namespace auto_parallel +} // namespace distributed +} // namespace paddle From dd7f52fd1a325437efb8318027d01a35c2283139 Mon Sep 17 00:00:00 2001 From: liuzhenhai93 Date: Wed, 3 Jan 2024 02:05:46 +0000 Subject: [PATCH 2/5] custom op spmd rule register --- paddle/extension.h | 1 + 1 file changed, 1 insertion(+) diff --git a/paddle/extension.h b/paddle/extension.h index c534a13a22d97..c60ee269fb173 100644 --- a/paddle/extension.h +++ b/paddle/extension.h @@ -22,4 +22,5 @@ limitations under the License. */ #endif // For initialization of DeviceContextPool and MemoryMethod #include "paddle/fluid/platform/init_phi.h" + static paddle::InitPhi g_init_phi; From d385c689536922cc215f8505dc463f59f3fe7119 Mon Sep 17 00:00:00 2001 From: liuzhenhai93 Date: Wed, 3 Jan 2024 02:07:46 +0000 Subject: [PATCH 3/5] custom op spmd rule register --- cmake/inference_lib.cmake | 8 -------- 1 file changed, 8 deletions(-) diff --git a/cmake/inference_lib.cmake b/cmake/inference_lib.cmake index 5d20dd2a90650..517ac24cccc72 100755 --- a/cmake/inference_lib.cmake +++ b/cmake/inference_lib.cmake @@ -328,18 +328,10 @@ copy( inference_lib_dist SRCS ${PADDLE_SOURCE_DIR}/paddle/phi/core/visit_type.h DSTS ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/paddle/phi/core/) - copy( inference_lib_dist SRCS ${PADDLE_SOURCE_DIR}/paddle/phi/core/hostdevice.h DSTS ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/paddle/phi/core/) - -copy( - inference_lib_dist - SRCS ${PADDLE_SOURCE_DIR}/paddle/phi/core/distributed/auto_parallel/*.h - DSTS ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/paddle/phi/core/distributed/auto_parallel/ -) - copy( inference_lib_dist SRCS ${PADDLE_SOURCE_DIR}/paddle/fluid/platform/init_phi.h From 549868a46471489fdcdb400cf4b0f36078229e38 Mon Sep 17 00:00:00 2001 From: liuzhenhai93 Date: Wed, 3 Jan 2024 03:25:35 +0000 Subject: [PATCH 4/5] custom op spmd rule register --- cmake/inference_lib.cmake | 7 +++++++ paddle/phi/api/ext/op_meta_info.h | 7 +++++-- paddle/phi/core/distributed/type_defs.h | 1 + test/cpp/auto_parallel/custom_op_spmd_rule_test.cc | 10 ++++++++++ 4 files changed, 23 insertions(+), 2 deletions(-) diff --git a/cmake/inference_lib.cmake b/cmake/inference_lib.cmake index 517ac24cccc72..e907540e92757 100755 --- a/cmake/inference_lib.cmake +++ b/cmake/inference_lib.cmake @@ -328,6 +328,13 @@ copy( inference_lib_dist SRCS ${PADDLE_SOURCE_DIR}/paddle/phi/core/visit_type.h DSTS ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/paddle/phi/core/) + +copy( + inference_lib_dist + SRCS ${PADDLE_SOURCE_DIR}/paddle/phi/core/distributed/type_defs.h + DSTS ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/paddle/phi/core/distributed/ +) + copy( inference_lib_dist SRCS ${PADDLE_SOURCE_DIR}/paddle/phi/core/hostdevice.h diff --git a/paddle/phi/api/ext/op_meta_info.h b/paddle/phi/api/ext/op_meta_info.h index 9362b531c9320..2b73e28b44858 100644 --- a/paddle/phi/api/ext/op_meta_info.h +++ b/paddle/phi/api/ext/op_meta_info.h @@ -996,8 +996,11 @@ struct TrtGetOutputDimsFuncImpl { #endif ////////////////////// Op Meta Info ////////////////////// -class CustomSpmdInferTensorArg; -class CustomSpmdInferAttrArg; + +using CustomSpmdInferTensorArg = + paddle::variant>; +using CustomSpmdInferAttrArg = paddle::any; using InferSpmdFunc = phi::distributed::SpmdInfo (*)( const std::vector& inputs, diff --git a/paddle/phi/core/distributed/type_defs.h b/paddle/phi/core/distributed/type_defs.h index 1b7035c1a4528..a629fccbf9fbb 100644 --- a/paddle/phi/core/distributed/type_defs.h +++ b/paddle/phi/core/distributed/type_defs.h @@ -23,6 +23,7 @@ namespace phi { namespace distributed { class TensorDistAttr; +class DistMetaTensor; using ArgDistAttr = paddle::variant>; diff --git a/test/cpp/auto_parallel/custom_op_spmd_rule_test.cc b/test/cpp/auto_parallel/custom_op_spmd_rule_test.cc index 66fe9217f897c..6e51634e1df49 100644 --- a/test/cpp/auto_parallel/custom_op_spmd_rule_test.cc +++ b/test/cpp/auto_parallel/custom_op_spmd_rule_test.cc @@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include "paddle/phi/api/ext/op_meta_info.h" #include "paddle/phi/api/ext/spmd_infer.h" #include "test/cpp/auto_parallel/spmd_rule_test_util.h" @@ -74,6 +75,15 @@ TEST(CustomOp, Ctor) { check_dim_mapping(infered_dist_attrs.second[0], {-1, 1, 0}); check_partial_dims(infered_dist_attrs.second[0], {}); } + +TEST(CustomOp, Register) { + OpMetaInfoBuilder builder("test_custom_op_smpd", 0); + auto iter = OpMetaInfoMap::Instance().GetMap().find("test_custom_op_smpd"); + EXPECT_TRUE(iter != OpMetaInfoMap::Instance().GetMap().end()); + EXPECT_TRUE(OpMetaInfoHelper::GetInferSpmdFn(iter->second[0]) == nullptr); + builder.SetInferSpmdFn(PD_INFER_SPMD_RULE(phi::distributed::ConcatInferSpmd)); + EXPECT_TRUE(OpMetaInfoHelper::GetInferSpmdFn(iter->second[0]) != nullptr); +} } // namespace auto_parallel } // namespace distributed } // namespace paddle From 2c8934150e51a43746dcd5384071654a4e193993 Mon Sep 17 00:00:00 2001 From: liuzhenhai93 Date: Wed, 3 Jan 2024 12:55:45 +0000 Subject: [PATCH 5/5] polish --- tools/gpups_test.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tools/gpups_test.sh b/tools/gpups_test.sh index 91cc6627dd7e2..a482de9074eac 100644 --- a/tools/gpups_test.sh +++ b/tools/gpups_test.sh @@ -124,7 +124,7 @@ set +e ctest --output-on-failure -R "($parallel_list)" --timeout 120 -j4 | tee -a $tmpfile; test ${PIPESTATUS[0]} -eq 0; EXIT_CODE_1=$? -ctest --output-on-failure -R "($serial_list)" --timeout 120 -j1 | tee -a $tmpfile; test ${PIPESTATUS[0]} -eq 0; +ctest --output-on-failure -R "($serial_list)" --timeout 180 -j1 | tee -a $tmpfile; test ${PIPESTATUS[0]} -eq 0; EXIT_CODE_2=$? set -e