From a98e9979363ca44142c8ebd1f3896721ac4b59ef Mon Sep 17 00:00:00 2001 From: Yichen Zhang <32740647+pkuzyc@users.noreply.github.com> Date: Wed, 20 Sep 2023 10:08:17 +0800 Subject: [PATCH] [Semi-Auto] Move and adapt elementwise rule to phi (#57197) * adapt general spmd rule * polish details * add new rules * bugfix --------- Co-authored-by: Chen Weihang Co-authored-by: liangjianzhong --- .../spmd_rules/elementwise_spmd_rule.cc | 213 ----------- .../spmd_rules/elementwise_spmd_rule.h | 41 -- .../auto_parallel/spmd_rules/rules.h | 89 ----- .../phi/infermeta/spmd_rules/elementwise.cc | 313 +++++++++++++++ paddle/phi/infermeta/spmd_rules/elementwise.h | 38 ++ paddle/phi/infermeta/spmd_rules/rules.h | 357 +++++++++++++++++- paddle/phi/infermeta/spmd_rules/utils.cc | 26 ++ paddle/phi/infermeta/spmd_rules/utils.h | 12 + .../spmd_rules/test_elementwise_rule.py | 171 ++++----- 9 files changed, 829 insertions(+), 431 deletions(-) delete mode 100644 paddle/fluid/distributed/auto_parallel/spmd_rules/elementwise_spmd_rule.cc delete mode 100644 paddle/fluid/distributed/auto_parallel/spmd_rules/elementwise_spmd_rule.h create mode 100644 paddle/phi/infermeta/spmd_rules/elementwise.cc create mode 100644 paddle/phi/infermeta/spmd_rules/elementwise.h diff --git a/paddle/fluid/distributed/auto_parallel/spmd_rules/elementwise_spmd_rule.cc b/paddle/fluid/distributed/auto_parallel/spmd_rules/elementwise_spmd_rule.cc deleted file mode 100644 index 7904627cf7fb7..0000000000000 --- a/paddle/fluid/distributed/auto_parallel/spmd_rules/elementwise_spmd_rule.cc +++ /dev/null @@ -1,213 +0,0 @@ -/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
*/ - -#include "paddle/fluid/distributed/auto_parallel/spmd_rules/elementwise_spmd_rule.h" -#include "paddle/phi/core/distributed/auto_parallel/utils.h" - -namespace paddle { -namespace distributed { -namespace auto_parallel { -using phi::distributed::auto_parallel::str_join; - -std::pair, std::vector> -ElementwiseSPMDRule::InferForward( - const std::vector& input_specs, - const paddle::framework::AttributeMap& attrs) { - // step0: Verify Input Args Based on Elementwise Logic - int64_t ninputs = input_specs.size(); - PADDLE_ENFORCE_GT( - ninputs, - 0, - phi::errors::InvalidArgument("The size of InputSpec in elementwise must " - "be greater than 0, but got [%d].", - ninputs)); - VerifySpecs(input_specs, "elementwise"); - - // step1: Build Einsum Notation - std::string alphabet = "abcdefghijklmnopqrstuvwxyz"; - std::vector input_axes_vec; - int64_t max_ndim = 0; - for (int64_t i = 0; i < ninputs; ++i) { - int64_t ndim = input_specs[i].shape().size(); - if (ndim > max_ndim) { - max_ndim = ndim; - } - } - - // get einsum notation for each input, deal with broadcast - std::vector broadcast_axis_count(max_ndim, 0); - for (int64_t i = 0; i < ninputs; ++i) { - std::vector shape = input_specs[i].shape(); - int64_t ndim = shape.size(); - int64_t start_dim = max_ndim - ndim; - std::string axes_notation = GetBroadcastAxes(ndim, max_ndim, alphabet); - if (ninputs > 1) { - for (int64_t idim = 0; idim < max_ndim; idim++) { - // deal with the broadcast axes, record the - // input number at each broadcast axis - if (idim < start_dim) { - broadcast_axis_count[idim] += 1; - } else if (shape[idim - start_dim] == 1) { - broadcast_axis_count[idim] += 1; - // mark the broadcast axis to a special "1" - axes_notation[idim - start_dim] = '1'; - } - } - } - input_axes_vec.emplace_back(axes_notation); - } - - // get einsum notation for output - std::string output_axes = GetBroadcastAxes(max_ndim, max_ndim, alphabet); - for (int64_t idim = 0; idim < max_ndim; idim++) { - // if all inputs broadcast at this dimension, - // mark this axis in output as broadcast - if (broadcast_axis_count[idim] == ninputs) { - output_axes[idim] = '1'; - } - } - - // step2: Sharding Propogation - // step2.1: merge input shardings - std::vector>> axes_sharding_info; - axes_sharding_info = GetAxesDimsMappingPair(input_axes_vec, input_specs); - std::unordered_map axis_to_dim_map = - ShardingMergeForTensors(axes_sharding_info); - - // step2.2: infer output dimsmapping from merged input dimsmapping - std::vector output_dims_mapping = - GetDimsMappingForAxes(output_axes, axis_to_dim_map); - - // initialize output dist_attr's process_mesh, batch_dim and dynamic dims with - // input dist_attr. - TensorDistAttr output_dist_attr = - CopyTensorDistAttrForOutput(input_specs[0].dist_attr()); - output_dist_attr.set_dims_mapping(output_dims_mapping); - - std::vector new_input_dist_attrs; - std::vector output_dist_attrs; - - // step2.3: update inputs' dims mapping with merged one. 
- for (int64_t i = 0; i < ninputs; i++) { - const DistTensorSpec& spec = input_specs[i]; - TensorDistAttr dist_attr(spec.dist_attr()); - std::vector new_dims_mapping = - GetDimsMappingForAxes(input_axes_vec[i], axis_to_dim_map); - dist_attr.set_dims_mapping(new_dims_mapping); - new_input_dist_attrs.emplace_back(dist_attr); - } - - // step3: handle partial - // handle input tensor partial (TODO) - VLOG(4) << "ElementwiseSPMDRule InferForward:"; - for (int64_t i = 0; i < ninputs; i++) { - VLOG(4) << "Input" << std::to_string(i) << " shape: [" - << str_join(input_specs[i].shape()) << "] " - << "src_dims_mapping: [" << str_join(input_specs[i].dims_mapping()) - << "] " - << "dst_dims_mapping: [" - << str_join(new_input_dist_attrs[i].dims_mapping()) << "]"; - } - VLOG(4) << "Output dims_mapping: [" + str_join(output_dims_mapping) + "]\n\n"; - - output_dist_attrs.emplace_back(output_dist_attr); - return {new_input_dist_attrs, output_dist_attrs}; -} - -std::pair, std::vector> -ElementwiseSPMDRule::InferBackward( - const std::vector& input_specs, - const std::vector& output_specs, - const paddle::framework::AttributeMap& attrs) { - // step0: Verify Input Args Based on Elementwise Logic - int64_t ninputs = input_specs.size(); - int64_t noutputs = output_specs.size(); - PADDLE_ENFORCE_GT( - ninputs, - 0, - phi::errors::InvalidArgument("The size of InputSpec in elementwise must " - "be greater than 0, but got [%d].", - ninputs)); - PADDLE_ENFORCE_EQ( - noutputs, - 1, - phi::errors::InvalidArgument("The size of OutputSpec in elementwise must " - "be equal to 1, but got [%d].", - noutputs)); - VerifySpecs(output_specs, "elementwise_backward"); - - // step1: Build Einsum Notation - std::string alphabet = "abcdefghijklmnopqrstuvwxyz"; - std::vector input_axes_vec; - int64_t output_ndim = output_specs[0].shape().size(); - std::string output_axes = - GetBroadcastAxes(output_ndim, output_ndim, alphabet); - - // get einsum notation for each input, deal with broadcast - for (int64_t i = 0; i < ninputs; ++i) { - const std::vector& shape = input_specs[i].shape(); - int64_t ndim = shape.size(); - int64_t start_dim = output_ndim - ndim; - std::string axes_notation = GetBroadcastAxes(ndim, output_ndim, alphabet); - if (ninputs > 1) { - for (int64_t idim = 0; idim < output_ndim; idim++) { - // deal with the broadcast axes - if (idim >= start_dim && shape[idim - start_dim] == 1) { - // mark the broadcast axis to a special "1" - axes_notation[idim - start_dim] = '1'; - } - } - } - input_axes_vec.emplace_back(axes_notation); - } - - // step2: Sharding Propogation - // step2.1: get dim mapping for each output axis - std::unordered_map axis_to_dim_map = - ShardingMergeForTensors({{output_axes, output_specs[0].dims_mapping()}}); - - // step2.2: infer input dims mappings from output dims mapping - // and get the input distributed attributes to return - std::vector input_dist_attrs; - std::vector output_dist_attrs; - for (int64_t i = 0; i < ninputs; ++i) { - const DistTensorSpec& spec = input_specs[i]; - TensorDistAttr dist_attr(spec.dist_attr()); - std::vector dims_mapping = - GetDimsMappingForAxes(input_axes_vec[i], axis_to_dim_map); - dist_attr.set_dims_mapping(dims_mapping); - input_dist_attrs.emplace_back(dist_attr); - } - - output_dist_attrs.emplace_back(output_specs[0].dist_attr()); - - // step3: handle partial (TODO) - - VLOG(4) << "ElementwiseSPMDRule InferBackward:"; - VLOG(4) << "Output shape: [" << str_join(output_specs[0].shape()) - << "] dims_mapping: [" << str_join(output_specs[0].dims_mapping()) - << "]"; - 
for (int64_t i = 0; i < ninputs; i++) { - VLOG(4) << "Input" << std::to_string(i) << " shape: [" - << str_join(input_specs[i].shape()) << "] " - << "dims_mapping: [" << str_join(input_dist_attrs[i].dims_mapping()) - << "]"; - } - - return {input_dist_attrs, output_dist_attrs}; -} - -} // namespace auto_parallel -} // namespace distributed -} // namespace paddle diff --git a/paddle/fluid/distributed/auto_parallel/spmd_rules/elementwise_spmd_rule.h b/paddle/fluid/distributed/auto_parallel/spmd_rules/elementwise_spmd_rule.h deleted file mode 100644 index ed01d23252b21..0000000000000 --- a/paddle/fluid/distributed/auto_parallel/spmd_rules/elementwise_spmd_rule.h +++ /dev/null @@ -1,41 +0,0 @@ -/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#include -#include -#include -#include - -#include "paddle/fluid/distributed/auto_parallel/spmd_rules/common.h" - -namespace paddle { -namespace distributed { -namespace auto_parallel { - -class ElementwiseSPMDRule : public SPMDRuleBase { - public: - std::pair, std::vector> - InferForward(const std::vector& input_specs, - const paddle::framework::AttributeMap& attrs) override; - - std::pair, std::vector> - InferBackward(const std::vector& input_specs, - const std::vector& output_specs, - const paddle::framework::AttributeMap& attrs) override; -}; -} // namespace auto_parallel -} // namespace distributed -} // namespace paddle diff --git a/paddle/fluid/distributed/auto_parallel/spmd_rules/rules.h b/paddle/fluid/distributed/auto_parallel/spmd_rules/rules.h index ba468269b8230..c876fa59a7034 100644 --- a/paddle/fluid/distributed/auto_parallel/spmd_rules/rules.h +++ b/paddle/fluid/distributed/auto_parallel/spmd_rules/rules.h @@ -16,7 +16,6 @@ #include "paddle/fluid/distributed/auto_parallel/spmd_rules/common.h" #include "paddle/fluid/distributed/auto_parallel/spmd_rules/cross_entropy_with_softmax_spmd_rule.h" -#include "paddle/fluid/distributed/auto_parallel/spmd_rules/elementwise_spmd_rule.h" #include "paddle/fluid/distributed/auto_parallel/spmd_rules/embedding_spmd_rule.h" #include "paddle/fluid/distributed/auto_parallel/spmd_rules/layer_norm_spmd_rule.h" #include "paddle/fluid/distributed/auto_parallel/spmd_rules/reduction_spmd_rule.h" @@ -43,94 +42,6 @@ REGISTER_SPMD_RULE(min, ReductionSPMDRule); REGISTER_SPMD_RULE(prod, ReductionSPMDRule); REGISTER_SPMD_RULE(sum, ReductionSPMDRule); -// elementwise rule -REGISTER_SPMD_RULE(add, ElementwiseSPMDRule); -REGISTER_SPMD_RULE(assign, ElementwiseSPMDRule); -REGISTER_SPMD_RULE(assign_out_, ElementwiseSPMDRule); -REGISTER_SPMD_RULE(divide, ElementwiseSPMDRule); -REGISTER_SPMD_RULE(elementwise_pow, ElementwiseSPMDRule); -REGISTER_SPMD_RULE(exponential_, ElementwiseSPMDRule); -REGISTER_SPMD_RULE(floor_divide, ElementwiseSPMDRule); -REGISTER_SPMD_RULE(fmin, ElementwiseSPMDRule); -REGISTER_SPMD_RULE(hardswish, ElementwiseSPMDRule); -REGISTER_SPMD_RULE(heaviside, ElementwiseSPMDRule); -REGISTER_SPMD_RULE(maximum, ElementwiseSPMDRule); 
-REGISTER_SPMD_RULE(minimum, ElementwiseSPMDRule); -REGISTER_SPMD_RULE(mish, ElementwiseSPMDRule); -REGISTER_SPMD_RULE(multiply, ElementwiseSPMDRule); -REGISTER_SPMD_RULE(relu6, ElementwiseSPMDRule); -REGISTER_SPMD_RULE(remainder, ElementwiseSPMDRule); -REGISTER_SPMD_RULE(subtract, ElementwiseSPMDRule); -REGISTER_SPMD_RULE(swish, ElementwiseSPMDRule); -REGISTER_SPMD_RULE(acos, ElementwiseSPMDRule); -REGISTER_SPMD_RULE(acosh, ElementwiseSPMDRule); -REGISTER_SPMD_RULE(asin, ElementwiseSPMDRule); -REGISTER_SPMD_RULE(asinh, ElementwiseSPMDRule); -REGISTER_SPMD_RULE(atan, ElementwiseSPMDRule); -REGISTER_SPMD_RULE(atanh, ElementwiseSPMDRule); -REGISTER_SPMD_RULE(bernoulli, ElementwiseSPMDRule); -REGISTER_SPMD_RULE(bitwise_and, ElementwiseSPMDRule); -REGISTER_SPMD_RULE(bitwise_not, ElementwiseSPMDRule); -REGISTER_SPMD_RULE(bitwise_or, ElementwiseSPMDRule); -REGISTER_SPMD_RULE(bitwise_xor, ElementwiseSPMDRule); -REGISTER_SPMD_RULE(ceil, ElementwiseSPMDRule); -REGISTER_SPMD_RULE(celu, ElementwiseSPMDRule); -REGISTER_SPMD_RULE(clip, ElementwiseSPMDRule); -REGISTER_SPMD_RULE(conj, ElementwiseSPMDRule); -REGISTER_SPMD_RULE(cos, ElementwiseSPMDRule); -REGISTER_SPMD_RULE(cosh, ElementwiseSPMDRule); -REGISTER_SPMD_RULE(det, ElementwiseSPMDRule); -REGISTER_SPMD_RULE(digamma, ElementwiseSPMDRule); -REGISTER_SPMD_RULE(elu, ElementwiseSPMDRule); -REGISTER_SPMD_RULE(erf, ElementwiseSPMDRule); -REGISTER_SPMD_RULE(erfinv, ElementwiseSPMDRule); -REGISTER_SPMD_RULE(exp, ElementwiseSPMDRule); -REGISTER_SPMD_RULE(expm1, ElementwiseSPMDRule); -REGISTER_SPMD_RULE(fill, ElementwiseSPMDRule); -REGISTER_SPMD_RULE(floor, ElementwiseSPMDRule); -REGISTER_SPMD_RULE(fmax, ElementwiseSPMDRule); -REGISTER_SPMD_RULE(gelu, ElementwiseSPMDRule); -REGISTER_SPMD_RULE(hardshrink, ElementwiseSPMDRule); -REGISTER_SPMD_RULE(hardsigmoid, ElementwiseSPMDRule); -REGISTER_SPMD_RULE(hardtanh, ElementwiseSPMDRule); -REGISTER_SPMD_RULE(label_smooth, ElementwiseSPMDRule); -REGISTER_SPMD_RULE(leaky_relu, ElementwiseSPMDRule); -REGISTER_SPMD_RULE(lgamma, ElementwiseSPMDRule); -REGISTER_SPMD_RULE(log, ElementwiseSPMDRule); -REGISTER_SPMD_RULE(log10, ElementwiseSPMDRule); -REGISTER_SPMD_RULE(log1p, ElementwiseSPMDRule); -REGISTER_SPMD_RULE(log2, ElementwiseSPMDRule); -REGISTER_SPMD_RULE(logical_and, ElementwiseSPMDRule); -REGISTER_SPMD_RULE(logical_not, ElementwiseSPMDRule); -REGISTER_SPMD_RULE(logical_or, ElementwiseSPMDRule); -REGISTER_SPMD_RULE(logical_xor, ElementwiseSPMDRule); -REGISTER_SPMD_RULE(logit, ElementwiseSPMDRule); -REGISTER_SPMD_RULE(logsigmoid, ElementwiseSPMDRule); -REGISTER_SPMD_RULE(poisson, ElementwiseSPMDRule); -REGISTER_SPMD_RULE(pow, ElementwiseSPMDRule); -REGISTER_SPMD_RULE(reciprocal, ElementwiseSPMDRule); -REGISTER_SPMD_RULE(relu, ElementwiseSPMDRule); -REGISTER_SPMD_RULE(round, ElementwiseSPMDRule); -REGISTER_SPMD_RULE(rsqrt, ElementwiseSPMDRule); -REGISTER_SPMD_RULE(scale, ElementwiseSPMDRule); -REGISTER_SPMD_RULE(selu, ElementwiseSPMDRule); -REGISTER_SPMD_RULE(sigmoid, ElementwiseSPMDRule); -REGISTER_SPMD_RULE(sign, ElementwiseSPMDRule); -REGISTER_SPMD_RULE(silu, ElementwiseSPMDRule); -REGISTER_SPMD_RULE(sin, ElementwiseSPMDRule); -REGISTER_SPMD_RULE(sinh, ElementwiseSPMDRule); -REGISTER_SPMD_RULE(softplus, ElementwiseSPMDRule); -REGISTER_SPMD_RULE(softshrink, ElementwiseSPMDRule); -REGISTER_SPMD_RULE(softsign, ElementwiseSPMDRule); -REGISTER_SPMD_RULE(sqrt, ElementwiseSPMDRule); -REGISTER_SPMD_RULE(square, ElementwiseSPMDRule); -REGISTER_SPMD_RULE(stanh, ElementwiseSPMDRule); -REGISTER_SPMD_RULE(tan, 
ElementwiseSPMDRule); -REGISTER_SPMD_RULE(tanh, ElementwiseSPMDRule); -REGISTER_SPMD_RULE(tanh_shrink, ElementwiseSPMDRule); -REGISTER_SPMD_RULE(thresholded_relu, ElementwiseSPMDRule); -REGISTER_SPMD_RULE(trunc, ElementwiseSPMDRule); - // layer_norm rule REGISTER_SPMD_RULE(layer_norm, LayerNormSPMDRule); diff --git a/paddle/phi/infermeta/spmd_rules/elementwise.cc b/paddle/phi/infermeta/spmd_rules/elementwise.cc new file mode 100644 index 0000000000000..411c43de8cc41 --- /dev/null +++ b/paddle/phi/infermeta/spmd_rules/elementwise.cc @@ -0,0 +1,313 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/infermeta/spmd_rules/elementwise.h" + +#include "glog/logging.h" + +#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h" +#include "paddle/phi/core/distributed/auto_parallel/inferspmd_utils.h" +#include "paddle/phi/core/distributed/auto_parallel/utils.h" +#include "paddle/phi/infermeta/spmd_rules/utils.h" + +namespace phi { +namespace distributed { + +using phi::distributed::auto_parallel::str_join; + +////////////////// Utils Functions ////////////////// +std::string GetInputBroadcastNotation(const std::vector& shape, + const int max_ndim, + const std::string& alphabet, + std::vector* broadcast_axis_count) { + int ndim = shape.size(); + int start_dim = max_ndim - ndim; + std::string axes_notation = GetBroadcastAxes(ndim, max_ndim, alphabet); + + for (int idim = 0; idim < max_ndim; idim++) { + // deal with the broadcast axes, record the + // input number at each broadcast axis + if (idim < start_dim) { + (*broadcast_axis_count)[idim] += 1; + } else if (shape[idim - start_dim] == 1) { + (*broadcast_axis_count)[idim] += 1; + // mark the broadcast axis to a special "1" + axes_notation[idim - start_dim] = '1'; + } + } + return axes_notation; +} + +void GetBinaryNotations(const std::vector& x_shape, + const std::vector& y_shape, + std::string* x_axes, + std::string* y_axes, + std::string* out_axes) { + int x_ndim = x_shape.size(); + int y_ndim = y_shape.size(); + int max_ndim = std::max(x_ndim, y_ndim); + int ninputs = 2; + std::string alphabet = "abcdefghijklmnopqrstuvwxyz"; + std::vector input_ndims({x_ndim, y_ndim}); + + // get einsum notation for each input, deal with broadcast + std::vector broadcast_axis_count(max_ndim, 0); + *x_axes = GetInputBroadcastNotation( + x_shape, max_ndim, alphabet, &broadcast_axis_count); + *y_axes = GetInputBroadcastNotation( + y_shape, max_ndim, alphabet, &broadcast_axis_count); + + // get einsum notation for output + *out_axes = GetBroadcastAxes(max_ndim, max_ndim, alphabet); + for (int64_t idim = 0; idim < max_ndim; idim++) { + // if all inputs broadcast at this dimension, + // mark this axis in output as broadcast + if (broadcast_axis_count[idim] == ninputs) { + (*out_axes)[idim] = '1'; + } + } +} + +SpmdInfo ElementwiseUnaryInferSpmd(const DistMetaTensor& x) { + // Step0: Verify Input Args Based on Elementwise Logic + auto x_shape = phi::vectorize(x.dims()); + int x_ndim = 
x_shape.size();
+  TensorDistAttr x_dist_attr_src = x.dist_attr();
+  std::vector<int64_t> x_dims_mapping = x_dist_attr_src.dims_mapping();
+  PADDLE_ENFORCE_EQ(x_ndim,
+                    x_dims_mapping.size(),
+                    phi::errors::InvalidArgument(
+                        "ElementwiseUnary, The Tensor X's rank [%d] and X's "
+                        "dims_mapping size [%d] are not matched.",
+                        x_ndim,
+                        x_dims_mapping.size()));
+
+  // Step1: Build Einsum Notation
+  std::string alphabet = "abcdefghijklmnopqrstuvwxyz";
+  std::string x_axes = GetBroadcastAxes(x_ndim, x_ndim, alphabet);
+  std::string out_axes = x_axes;
+
+  // Step2: Sharding Propagation
+  // Step2.1: Merge input shardings
+  std::pair<std::string, std::vector<int64_t>> axes_sharding_info(
+      x_axes, x_dims_mapping);
+  std::unordered_map<std::string, int64_t> axis_to_dim_map =
+      ShardingMergeForTensors({axes_sharding_info});
+
+  // Step2.2: Infer output dims mapping from merged input dims mapping
+  std::vector<int64_t> out_dims_mapping =
+      GetDimsMappingForAxes(out_axes, axis_to_dim_map);
+
+  // initialize output dist_attr's process_mesh, batch_dim and dynamic dims with
+  // input dist_attr.
+  TensorDistAttr out_dist_attr = CopyTensorDistAttrForOutput(x_dist_attr_src);
+  out_dist_attr.set_dims_mapping(out_dims_mapping);
+
+  // Step3: Handle partial
+  // Handle input tensor partial (TODO)
+  VLOG(4) << "ElementwiseSPMDRule InferForward:";
+  VLOG(4) << "Input0 shape: [" << str_join(x_shape) << "] "
+          << "src_dims_mapping: [" << str_join(x_dims_mapping) << "] ";
+  VLOG(4) << "Output dims_mapping: [" + str_join(out_dims_mapping) + "]\n\n";
+
+  return {{x_dist_attr_src}, {out_dist_attr}};
+}
+
+SpmdInfo ElementwiseUnaryInferSpmdReverse(const DistMetaTensor& x,
+                                          const DistMetaTensor& out) {
+  // Step0: Verify Input Args Based on Elementwise Logic
+  auto x_shape = phi::vectorize(x.dims());
+  int x_ndim = x_shape.size();
+  auto out_shape = phi::vectorize(out.dims());
+  int out_ndim = out_shape.size();
+  TensorDistAttr out_dist_attr = out.dist_attr();
+  std::vector<int64_t> out_dims_mapping = out_dist_attr.dims_mapping();
+  PADDLE_ENFORCE_EQ(
+      out_ndim,
+      out_dims_mapping.size(),
+      phi::errors::InvalidArgument(
+          "ElementwiseUnaryReverse, The Tensor Out's rank [%d] and Out's "
+          "dims_mapping size [%d] are not matched.",
+          out_ndim,
+          out_dims_mapping.size()));
+  PADDLE_ENFORCE_EQ(
+      out_ndim,
+      x_ndim,
+      phi::errors::InvalidArgument(
+          "ElementwiseUnaryReverse, The Tensor Out's rank [%d] and X's "
+          "rank [%d] are not matched.",
+          out_ndim,
+          x_ndim));
+
+  // Step1: Build Einsum Notation
+  std::string alphabet = "abcdefghijklmnopqrstuvwxyz";
+  std::string x_axes = GetBroadcastAxes(x_ndim, x_ndim, alphabet);
+  std::string out_axes = x_axes;
+
+  // Step2: Sharding Propagation
+  // Step2.1: Merge output shardings
+  std::pair<std::string, std::vector<int64_t>> axes_sharding_info(
+      out_axes, out_dims_mapping);
+  std::unordered_map<std::string, int64_t> axis_to_dim_map =
+      ShardingMergeForTensors({axes_sharding_info});
+
+  // Step2.2: Infer input dims mapping from merged output dims mapping
+  std::vector<int64_t> x_dims_mapping =
+      GetDimsMappingForAxes(x_axes, axis_to_dim_map);
+  TensorDistAttr x_dist_attr(x.dist_attr());
+  x_dist_attr.set_dims_mapping(x_dims_mapping);
+
+  // Step3: Handle partial
+  // Handle output tensor partial (TODO)
+  VLOG(4) << "ElementwiseSPMDRule InferReverse:";
+  VLOG(4) << "Output0 shape: [" << str_join(out_shape) << "] "
+          << "dims_mapping: [" << str_join(out_dims_mapping) << "] ";
+  VLOG(4) << "Input0 dims_mapping: [" + str_join(x_dims_mapping) + "]\n\n";
+
+  return {{x_dist_attr}, {out_dist_attr}};
+}
+
+SpmdInfo ElementwiseBinaryInferSpmd(const DistMetaTensor& x,
+                                    const DistMetaTensor& y) {
+  // Step0: Verify Input Args Based on Elementwise Logic
+  auto x_shape = phi::vectorize(x.dims());
+  int x_ndim = x_shape.size();
+  auto y_shape = phi::vectorize(y.dims());
+  int y_ndim = y_shape.size();
+  TensorDistAttr x_dist_attr_src = x.dist_attr();
+  TensorDistAttr y_dist_attr_src = y.dist_attr();
+  std::vector<int64_t> x_dims_mapping = x_dist_attr_src.dims_mapping();
+  std::vector<int64_t> y_dims_mapping = y_dist_attr_src.dims_mapping();
+  PADDLE_ENFORCE_EQ(x_ndim,
+                    x_dims_mapping.size(),
+                    phi::errors::InvalidArgument(
+                        "ElementwiseBinary, The Tensor X's rank [%d] and X's "
+                        "dims_mapping size [%d] are not matched.",
+                        x_ndim,
+                        x_dims_mapping.size()));
+  PADDLE_ENFORCE_EQ(y_ndim,
+                    y_dims_mapping.size(),
+                    phi::errors::InvalidArgument(
+                        "ElementwiseBinary, The Tensor Y's rank [%d] and Y's "
+                        "dims_mapping size [%d] are not matched.",
+                        y_ndim,
+                        y_dims_mapping.size()));
+
+  // Step1: Build Einsum Notation
+  std::string x_axes, y_axes, out_axes;
+  GetBinaryNotations(x_shape, y_shape, &x_axes, &y_axes, &out_axes);
+
+  // Step2: Sharding Propagation
+  // Step2.1: Merge input shardings
+  std::unordered_map<std::string, int64_t> axis_to_dim_map =
+      ShardingMergeForTensors(
+          {{x_axes, x_dims_mapping}, {y_axes, y_dims_mapping}});
+
+  // Step2.2: Infer output dims mapping from merged input dims mapping
+  std::vector<int64_t> out_dims_mapping =
+      GetDimsMappingForAxes(out_axes, axis_to_dim_map);
+
+  // initialize output dist_attr's process_mesh, batch_dim and dynamic dims with
+  // input dist_attr.
+  TensorDistAttr out_dist_attr = CopyTensorDistAttrForOutput(x_dist_attr_src);
+  out_dist_attr.set_dims_mapping(out_dims_mapping);
+
+  // Step2.3: Update inputs' dims mapping with merged one.
+  TensorDistAttr x_dist_attr_dst(x_dist_attr_src);
+  TensorDistAttr y_dist_attr_dst(y_dist_attr_src);
+  x_dist_attr_dst.set_dims_mapping(
+      GetDimsMappingForAxes(x_axes, axis_to_dim_map));
+  y_dist_attr_dst.set_dims_mapping(
+      GetDimsMappingForAxes(y_axes, axis_to_dim_map));
+
+  // Step3: Handle partial
+  // Handle input tensor partial (TODO)
+  VLOG(4) << "ElementwiseSPMDRule InferForward:";
+  VLOG(4) << "Input0 shape: [" << str_join(x_shape) << "] "
+          << "src_dims_mapping: [" << str_join(x_dims_mapping) << "] "
+          << "dst_dims_mapping: [" << str_join(x_dist_attr_dst.dims_mapping())
+          << "]";
+  VLOG(4) << "Input1 shape: [" << str_join(y_shape) << "] "
+          << "src_dims_mapping: [" << str_join(y_dims_mapping) << "] "
+          << "dst_dims_mapping: [" << str_join(y_dist_attr_dst.dims_mapping())
+          << "]";
+  VLOG(4) << "Output dims_mapping: [" + str_join(out_dims_mapping) + "]\n\n";
+
+  return {{x_dist_attr_dst, y_dist_attr_dst}, {out_dist_attr}};
+}
+
+SpmdInfo ElementwiseBinaryInferSpmdReverse(const DistMetaTensor& x,
+                                           const DistMetaTensor& y,
+                                           const DistMetaTensor& out) {
+  // Step0: Verify Input Args Based on Elementwise Logic
+  auto x_shape = phi::vectorize(x.dims());
+  int x_ndim = x_shape.size();
+  auto y_shape = phi::vectorize(y.dims());
+  int y_ndim = y_shape.size();
+  auto out_shape = phi::vectorize(out.dims());
+  int out_ndim = out_shape.size();
+  int max_ndim = std::max(x_ndim, y_ndim);
+  TensorDistAttr out_dist_attr = out.dist_attr();
+  std::vector<int64_t> out_dims_mapping = out_dist_attr.dims_mapping();
+  PADDLE_ENFORCE_EQ(
+      out_ndim,
+      out_dims_mapping.size(),
+      phi::errors::InvalidArgument(
+          "ElementwiseBinaryReverse, The Tensor Out's rank [%d] and Out's "
+          "dims_mapping size [%d] are not matched.",
+          out_ndim,
+          out_dims_mapping.size()));
+  PADDLE_ENFORCE_EQ(
+      out_ndim,
+      max_ndim,
+      phi::errors::InvalidArgument(
+          "ElementwiseBinaryReverse, The Tensor Out's rank [%d] and the "
+          "max rank of inputs [%d] are not matched.",
+          out_ndim,
+          max_ndim));
+
+  // Step1: Build Einsum Notation
+  std::string x_axes, y_axes, out_axes;
+  GetBinaryNotations(x_shape, y_shape, &x_axes, &y_axes, &out_axes);
+
+  // Step2: Sharding Propagation
+  // Step2.1: Merge output shardings
+  std::unordered_map<std::string, int64_t> axis_to_dim_map =
+      ShardingMergeForTensors({{out_axes, out_dims_mapping}});
+
+  // Step2.2: Infer input dims mappings from merged output dims mapping
+  TensorDistAttr x_dist_attr_dst = x.dist_attr();
+  TensorDistAttr y_dist_attr_dst = y.dist_attr();
+  std::vector<int64_t> x_dims_mapping =
+      GetDimsMappingForAxes(x_axes, axis_to_dim_map);
+  std::vector<int64_t> y_dims_mapping =
+      GetDimsMappingForAxes(y_axes, axis_to_dim_map);
+  x_dist_attr_dst.set_dims_mapping(x_dims_mapping);
+  y_dist_attr_dst.set_dims_mapping(y_dims_mapping);
+
+  // Step3: Handle partial
+  // Handle input tensor partial (TODO)
+  VLOG(4) << "ElementwiseSPMDRule InferReverse:";
+  VLOG(4) << "Output shape: [" << str_join(out_shape) << "] dims_mapping: ["
+          << str_join(out_dims_mapping) << "]";
+  VLOG(4) << "Input0 shape: [" << str_join(x_shape) << "] "
+          << "dims_mapping: [" << str_join(x_dims_mapping) << "]";
+  VLOG(4) << "Input1 shape: [" << str_join(y_shape) << "] "
+          << "dims_mapping: [" << str_join(y_dims_mapping) << "]\n\n";
+
+  return {{x_dist_attr_dst, y_dist_attr_dst}, {out_dist_attr}};
+}
+
+} // namespace distributed
+} // namespace phi
diff --git a/paddle/phi/infermeta/spmd_rules/elementwise.h b/paddle/phi/infermeta/spmd_rules/elementwise.h
new file mode 100644
index 0000000000000..319d3ccbbdac1
--- /dev/null
+++ b/paddle/phi/infermeta/spmd_rules/elementwise.h
@@ -0,0 +1,38 @@
+/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include
+
+#include "paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.h"
+#include "paddle/phi/core/distributed/type_defs.h"
+
+namespace phi {
+namespace distributed {
+
+SpmdInfo ElementwiseUnaryInferSpmd(const DistMetaTensor& x);
+
+SpmdInfo ElementwiseUnaryInferSpmdReverse(const DistMetaTensor& x,
+                                          const DistMetaTensor& out);
+
+SpmdInfo ElementwiseBinaryInferSpmd(const DistMetaTensor& x,
+                                    const DistMetaTensor& y);
+
+SpmdInfo ElementwiseBinaryInferSpmdReverse(const DistMetaTensor& x,
+                                           const DistMetaTensor& y,
+                                           const DistMetaTensor& out);
+
+} // namespace distributed
+} // namespace phi
diff --git a/paddle/phi/infermeta/spmd_rules/rules.h b/paddle/phi/infermeta/spmd_rules/rules.h
index acd28a224b775..4406e17495d14 100644
--- a/paddle/phi/infermeta/spmd_rules/rules.h
+++ b/paddle/phi/infermeta/spmd_rules/rules.h
@@ -17,6 +17,7 @@ limitations under the License. */
 #include "paddle/phi/core/distributed/auto_parallel/inferspmd_utils.h"
 
 #include "paddle/phi/infermeta/spmd_rules/default_data_parallel.h"
+#include "paddle/phi/infermeta/spmd_rules/elementwise.h"
 #include "paddle/phi/infermeta/spmd_rules/matmul.h"
 #include "paddle/phi/infermeta/spmd_rules/replicated.h"
 
@@ -29,12 +30,12 @@ limitations under the License. */
 * 2.
Since the infer functions of Spmd forward and backward are closely related * and need to be registered together, we manage them together in one file. * - * 3. SPMD rules are less than infermeta function, and we manage files by - * operator. + * 3. SPMD rules are much smaller than infermeta function, and we manage files + * in operator units. * * 4. The previous registration used some compile-time regular matching methods, * which was less flexible, and the registration of SPMD rules here is declare - * directly in the header file. + * directly in the header file */ namespace phi { @@ -57,5 +58,355 @@ PD_REGISTER_SPMD_RULE( PD_INFER_SPMD(phi::distributed::ReplicatedInferSpmd), PD_INFER_SPMD(phi::distributed::ReplicatedInferSpmdReverse)); +// elementwise unary rule +PD_REGISTER_SPMD_RULE( + assign, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + hardswish, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + mish, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + relu6, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + swish, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + acos, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + acosh, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + asin, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + asinh, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + atan, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + atanh, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + bernoulli, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + bitwise_not, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + ceil, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + celu, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + clip, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + conj, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + 
PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + cos, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + cosh, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + digamma, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + elu, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + erf, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + erfinv, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + exp, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + expm1, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + fill, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + floor, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + gelu, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + hardshrink, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + hardsigmoid, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + hardtanh, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + label_smooth, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + leaky_relu, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + lgamma, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + log, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + log10, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + log1p, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + log2, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); 
+PD_REGISTER_SPMD_RULE( + logical_not, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + logit, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + logsigmoid, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + poisson, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + pow, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + reciprocal, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + relu, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + round, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + rsqrt, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + scale, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + selu, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + sigmoid, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + sign, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + silu, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + sin, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + sinh, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + softplus, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + softshrink, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + softsign, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + sqrt, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + square, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + stanh, + 
PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + tan, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + tanh, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + tanh_shrink, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + thresholded_relu, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + trunc, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); + +// elementwise binary rule +PD_REGISTER_SPMD_RULE( + add, + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + elementwise_add, + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + divide, + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + elementwise_div, + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + elementwise_pow, + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + floor_divide, + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + fmin, + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + heaviside, + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + maximum, + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + minimum, + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + multiply, + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + elementwise_mul, + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + remainder, + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + subtract, + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + bitwise_and, + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd), + 
PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse));
+PD_REGISTER_SPMD_RULE(
+    bitwise_or,
+    PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd),
+    PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse));
+PD_REGISTER_SPMD_RULE(
+    bitwise_xor,
+    PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd),
+    PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse));
+PD_REGISTER_SPMD_RULE(
+    fmax,
+    PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd),
+    PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse));
+PD_REGISTER_SPMD_RULE(
+    logical_and,
+    PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd),
+    PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse));
+PD_REGISTER_SPMD_RULE(
+    logical_or,
+    PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd),
+    PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse));
+PD_REGISTER_SPMD_RULE(
+    logical_xor,
+    PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd),
+    PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse));
+
+// TODO(pkuzyc): add multiary elementwise rule
+
 } // namespace distributed
 } // namespace phi
diff --git a/paddle/phi/infermeta/spmd_rules/utils.cc b/paddle/phi/infermeta/spmd_rules/utils.cc
index 03c94a970823d..31bfba2a0d433 100644
--- a/paddle/phi/infermeta/spmd_rules/utils.cc
+++ b/paddle/phi/infermeta/spmd_rules/utils.cc
@@ -164,5 +164,31 @@ TensorDistAttr GetReplicatedDistAttr(const TensorDistAttr& dist_attr) {
   return dst_dist_attr;
 }
 
+std::vector<int64_t> GetDimsMappingForAxes(
+    const std::string& axes,
+    const std::unordered_map<std::string, int64_t>& axis_to_dim_map,
+    const bool unsharded_miss_axis) {
+  std::vector<int64_t> dims_mapping;
+  for (int64_t i = 0, n = static_cast<int64_t>(axes.size()); i < n; i++) {
+    std::string axis = axes.substr(i, 1);
+    if (axis == "1") {
+      dims_mapping.emplace_back(-1);
+    } else {
+      auto iter = axis_to_dim_map.find(axis);
+      if (iter == axis_to_dim_map.end()) {
+        if (unsharded_miss_axis) {
+          dims_mapping.emplace_back(-1);
+        } else {
+          PADDLE_THROW(phi::errors::InvalidArgument(
+              "Tensor axis [%s] is not in axis_to_dim_map.", axis));
+        }
+      } else {
+        dims_mapping.emplace_back(iter->second);
+      }
+    }
+  }
+  return dims_mapping;
+}
+
 } // namespace distributed
 } // namespace phi
diff --git a/paddle/phi/infermeta/spmd_rules/utils.h b/paddle/phi/infermeta/spmd_rules/utils.h
index 807e43bf5bee7..cd16a95bceac7 100644
--- a/paddle/phi/infermeta/spmd_rules/utils.h
+++ b/paddle/phi/infermeta/spmd_rules/utils.h
@@ -126,5 +126,17 @@ struct VariadicSpmdRuleArgumentParser
   SpmdInfo InferBackward() { return Fn(inputs, outputs); }
 };
 } // namespace detail
+
+// Get the dims mapping for the given axes according to the sharding
+// information of the annotated axes after inferring forward or backward.
+// The parameter axes stores the axes notation of the tensor. "1" is a special
+// axis: for the axis "1", its dims mapping is set to -1.
+// If unsharded_miss_axis is true, "-1" is assigned to axes that have no key in
+// axis_to_dim_map.
+std::vector GetDimsMappingForAxes( + const std::string& axes, + const std::unordered_map& axis_to_dim_map, + const bool unsharded_miss_axis = false); + } // namespace distributed } // namespace phi diff --git a/test/auto_parallel/spmd_rules/test_elementwise_rule.py b/test/auto_parallel/spmd_rules/test_elementwise_rule.py index 59a121c4bf0b3..d7fe2b492815c 100644 --- a/test/auto_parallel/spmd_rules/test_elementwise_rule.py +++ b/test/auto_parallel/spmd_rules/test_elementwise_rule.py @@ -14,17 +14,18 @@ import unittest -from paddle.distributed.auto_parallel.static.completion import get_spmd_rule from paddle.distributed.auto_parallel.static.dist_attribute import ( DistTensorSpec, TensorDistAttr, ) from paddle.distributed.fleet import auto +from paddle.framework import core class TestElementwiseSPMDRule(unittest.TestCase): def setUp(self): - self.rule = get_spmd_rule("add") + self.unary_rule = core.get_phi_spmd_rule("relu") + self.binary_rule = core.get_phi_spmd_rule("add") x_shape = [64, 36] y_shape = [64, 36] @@ -42,14 +43,14 @@ def setUp(self): self.out_dist_tensor_spec = DistTensorSpec(self.x_dist_tensor_spec) - self.attrs = {} + self.attrs = [] def test_single_mesh_dim(self): # [0, -1], [-1, -1] --> [0, -1], [0, -1], [0, -1] self.x_dist_tensor_spec.set_dims_mapping([0, -1]) self.y_dist_tensor_spec.set_dims_mapping([-1, -1]) - result_dist_attrs = self.rule.infer_forward( - [self.x_dist_tensor_spec, self.y_dist_tensor_spec], self.attrs + result_dist_attrs = self.binary_rule.infer_forward( + self.x_dist_tensor_spec, self.y_dist_tensor_spec ) infered_input_dist_attrs = result_dist_attrs[0] infered_output_dist_attrs = result_dist_attrs[1] @@ -61,8 +62,8 @@ def test_single_mesh_dim(self): # [0, -1], [-1, 0] --> [0, -1], [0, -1], [0, -1] self.x_dist_tensor_spec.set_dims_mapping([0, -1]) self.y_dist_tensor_spec.set_dims_mapping([-1, 0]) - result_dist_attrs = self.rule.infer_forward( - [self.x_dist_tensor_spec, self.y_dist_tensor_spec], self.attrs + result_dist_attrs = self.binary_rule.infer_forward( + self.x_dist_tensor_spec, self.y_dist_tensor_spec ) infered_input_dist_attrs = result_dist_attrs[0] infered_output_dist_attrs = result_dist_attrs[1] @@ -75,8 +76,8 @@ def test_single_mesh_dim(self): self.x_dist_tensor_spec.set_dims_mapping([-1, -1]) self.y_dist_tensor_spec.set_dims_mapping([-1, -1]) - result_dist_attrs = self.rule.infer_forward( - [self.x_dist_tensor_spec, self.y_dist_tensor_spec], self.attrs + result_dist_attrs = self.binary_rule.infer_forward( + self.x_dist_tensor_spec, self.y_dist_tensor_spec ) infered_input_dist_attrs = result_dist_attrs[0] infered_output_dist_attrs = result_dist_attrs[1] @@ -88,8 +89,8 @@ def test_single_mesh_dim(self): # [-1, 0]--> [-1, 0], [-1, 0] self.x_dist_tensor_spec.set_dims_mapping([-1, 0]) - result_dist_attrs = self.rule.infer_forward( - [self.x_dist_tensor_spec], self.attrs + result_dist_attrs = self.unary_rule.infer_forward( + self.x_dist_tensor_spec ) infered_input_dist_attrs = result_dist_attrs[0] infered_output_dist_attrs = result_dist_attrs[1] @@ -105,8 +106,8 @@ def test_single_mesh_dim_broadcast(self): self.x_dist_tensor_spec.set_dims_mapping([0, -1, -1]) self.y_dist_tensor_spec.set_dims_mapping([-1]) - resulted_dist_attrs = self.rule.infer_forward( - [self.x_dist_tensor_spec, self.y_dist_tensor_spec], self.attrs + resulted_dist_attrs = self.binary_rule.infer_forward( + self.x_dist_tensor_spec, self.y_dist_tensor_spec ) infered_input_dist_attrs = resulted_dist_attrs[0] infered_output_dist_attrs = resulted_dist_attrs[1] @@ -123,8 +124,8 @@ def 
test_single_mesh_dim_broadcast(self): self.x_dist_tensor_spec.set_dims_mapping([-1, 0, -1]) self.y_dist_tensor_spec.set_dims_mapping([-1]) - resulted_dist_attrs = self.rule.infer_forward( - [self.x_dist_tensor_spec, self.y_dist_tensor_spec], self.attrs + resulted_dist_attrs = self.binary_rule.infer_forward( + self.x_dist_tensor_spec, self.y_dist_tensor_spec ) infered_input_dist_attrs = resulted_dist_attrs[0] infered_output_dist_attrs = resulted_dist_attrs[1] @@ -137,8 +138,8 @@ def test_single_mesh_dim_broadcast(self): self.x_dist_tensor_spec.set_dims_mapping([-1, -1, 0]) self.y_dist_tensor_spec.set_dims_mapping([-1]) - resulted_dist_attrs = self.rule.infer_forward( - [self.x_dist_tensor_spec, self.y_dist_tensor_spec], self.attrs + resulted_dist_attrs = self.binary_rule.infer_forward( + self.x_dist_tensor_spec, self.y_dist_tensor_spec ) infered_input_dist_attrs = resulted_dist_attrs[0] infered_output_dist_attrs = resulted_dist_attrs[1] @@ -150,8 +151,8 @@ def test_single_mesh_dim_broadcast(self): # [-1, -1, -1], [0] --> [-1, -1, 0], [0], [-1, -1, 0] self.x_dist_tensor_spec.set_dims_mapping([-1, -1, -1]) self.y_dist_tensor_spec.set_dims_mapping([0]) - resulted_dist_attrs = self.rule.infer_forward( - [self.x_dist_tensor_spec, self.y_dist_tensor_spec], self.attrs + resulted_dist_attrs = self.binary_rule.infer_forward( + self.x_dist_tensor_spec, self.y_dist_tensor_spec ) infered_input_dist_attrs = resulted_dist_attrs[0] infered_output_dist_attrs = resulted_dist_attrs[1] @@ -166,8 +167,8 @@ def test_single_mesh_dim_broadcast(self): self.x_dist_tensor_spec.set_dims_mapping([-1, 0, -1]) self.y_dist_tensor_spec.set_dims_mapping([-1, -1]) - resulted_dist_attrs = self.rule.infer_forward( - [self.x_dist_tensor_spec, self.y_dist_tensor_spec], self.attrs + resulted_dist_attrs = self.binary_rule.infer_forward( + self.x_dist_tensor_spec, self.y_dist_tensor_spec ) infered_input_dist_attrs = resulted_dist_attrs[0] infered_output_dist_attrs = resulted_dist_attrs[1] @@ -182,8 +183,8 @@ def test_single_mesh_dim_broadcast(self): self.x_dist_tensor_spec.set_dims_mapping([0, -1, -1, -1]) self.y_dist_tensor_spec.set_dims_mapping([-1, -1, -1]) - resulted_dist_attrs = self.rule.infer_forward( - [self.x_dist_tensor_spec, self.y_dist_tensor_spec], self.attrs + resulted_dist_attrs = self.binary_rule.infer_forward( + self.x_dist_tensor_spec, self.y_dist_tensor_spec ) infered_input_dist_attrs = resulted_dist_attrs[0] infered_output_dist_attrs = resulted_dist_attrs[1] @@ -200,8 +201,8 @@ def test_single_mesh_dim_broadcast(self): self.x_dist_tensor_spec.set_dims_mapping([-1, -1, -1, -1]) self.y_dist_tensor_spec.set_dims_mapping([0, -1, -1]) - resulted_dist_attrs = self.rule.infer_forward( - [self.x_dist_tensor_spec, self.y_dist_tensor_spec], self.attrs + resulted_dist_attrs = self.binary_rule.infer_forward( + self.x_dist_tensor_spec, self.y_dist_tensor_spec ) infered_input_dist_attrs = resulted_dist_attrs[0] infered_output_dist_attrs = resulted_dist_attrs[1] @@ -225,8 +226,8 @@ def test_multi_mesh_dim(self): self.x_dist_tensor_spec.set_dims_mapping([0, 1, -1]) self.y_dist_tensor_spec.set_dims_mapping([-1, -1, -1]) - resulted_dist_attrs = self.rule.infer_forward( - [self.x_dist_tensor_spec, self.y_dist_tensor_spec], self.attrs + resulted_dist_attrs = self.binary_rule.infer_forward( + self.x_dist_tensor_spec, self.y_dist_tensor_spec ) infered_input_dist_attrs = resulted_dist_attrs[0] infered_output_dist_attrs = resulted_dist_attrs[1] @@ -242,8 +243,8 @@ def test_multi_mesh_dim(self): # [0, -1, -1], [-1, 1, 0] --> [0, 1, 
-1], [0, 1, -1], [0, 1, -1] self.x_dist_tensor_spec.set_dims_mapping([0, -1, -1]) self.y_dist_tensor_spec.set_dims_mapping([-1, 1, 0]) - resulted_dist_attrs = self.rule.infer_forward( - [self.x_dist_tensor_spec, self.y_dist_tensor_spec], self.attrs + resulted_dist_attrs = self.binary_rule.infer_forward( + self.x_dist_tensor_spec, self.y_dist_tensor_spec ) infered_input_dist_attrs = resulted_dist_attrs[0] infered_output_dist_attrs = resulted_dist_attrs[1] @@ -263,8 +264,8 @@ def test_multi_mesh_dim_broadcast(self): self.x_dist_tensor_spec.set_dims_mapping([0, -1, -1]) self.y_dist_tensor_spec.set_dims_mapping([1]) - resulted_dist_attrs = self.rule.infer_forward( - [self.x_dist_tensor_spec, self.y_dist_tensor_spec], self.attrs + resulted_dist_attrs = self.binary_rule.infer_forward( + self.x_dist_tensor_spec, self.y_dist_tensor_spec ) infered_input_dist_attrs = resulted_dist_attrs[0] infered_output_dist_attrs = resulted_dist_attrs[1] @@ -281,8 +282,8 @@ def test_multi_mesh_dim_broadcast(self): self.x_dist_tensor_spec.set_dims_mapping([0, 1, -1]) self.y_dist_tensor_spec.set_dims_mapping([0]) - resulted_dist_attrs = self.rule.infer_forward( - [self.x_dist_tensor_spec, self.y_dist_tensor_spec], self.attrs + resulted_dist_attrs = self.binary_rule.infer_forward( + self.x_dist_tensor_spec, self.y_dist_tensor_spec ) infered_input_dist_attrs = resulted_dist_attrs[0] infered_output_dist_attrs = resulted_dist_attrs[1] @@ -297,8 +298,8 @@ def test_multi_mesh_dim_broadcast(self): self.x_dist_tensor_spec.set_dims_mapping([-1, -1, -1, 1]) self.y_dist_tensor_spec.set_dims_mapping([0, -1, 1]) - resulted_dist_attrs = self.rule.infer_forward( - [self.x_dist_tensor_spec, self.y_dist_tensor_spec], self.attrs + resulted_dist_attrs = self.binary_rule.infer_forward( + self.x_dist_tensor_spec, self.y_dist_tensor_spec ) infered_input_dist_attrs = resulted_dist_attrs[0] infered_output_dist_attrs = resulted_dist_attrs[1] @@ -314,10 +315,10 @@ def test_multi_mesh_dim_broadcast(self): def test_backward_single_mesh_dim(self): # [0, -1] --> [0, -1], [0, -1], [0, -1] (output --> inputs, output) self.out_dist_tensor_spec.set_dims_mapping([0, -1]) - result_dist_attrs = self.rule.infer_backward( - [self.x_dist_tensor_spec, self.y_dist_tensor_spec], - [self.out_dist_tensor_spec], - self.attrs, + result_dist_attrs = self.binary_rule.infer_backward( + self.x_dist_tensor_spec, + self.y_dist_tensor_spec, + self.out_dist_tensor_spec, ) infered_input_dist_attrs = result_dist_attrs[0] infered_output_dist_attrs = result_dist_attrs[1] @@ -329,10 +330,10 @@ def test_backward_single_mesh_dim(self): # [-1, -1] --> [-1, -1], [-1, -1], [-1, -1] (output --> inputs, output) self.out_dist_tensor_spec.set_dims_mapping([-1, -1]) - result_dist_attrs = self.rule.infer_backward( - [self.x_dist_tensor_spec, self.y_dist_tensor_spec], - [self.out_dist_tensor_spec], - self.attrs, + result_dist_attrs = self.binary_rule.infer_backward( + self.x_dist_tensor_spec, + self.y_dist_tensor_spec, + self.out_dist_tensor_spec, ) infered_input_dist_attrs = result_dist_attrs[0] infered_output_dist_attrs = result_dist_attrs[1] @@ -344,8 +345,8 @@ def test_backward_single_mesh_dim(self): # [-1, 0]--> [-1, 0], [-1, 0] (output --> inputs, output) self.out_dist_tensor_spec.set_dims_mapping([-1, 0]) - result_dist_attrs = self.rule.infer_backward( - [self.x_dist_tensor_spec], [self.out_dist_tensor_spec], self.attrs + result_dist_attrs = self.unary_rule.infer_backward( + self.x_dist_tensor_spec, self.out_dist_tensor_spec ) infered_input_dist_attrs = result_dist_attrs[0] 
infered_output_dist_attrs = result_dist_attrs[1] @@ -361,10 +362,10 @@ def test_backward_single_mesh_dim_broadcast(self): # [0, -1, -1] --> [0, -1, -1], [-1], [0, -1, -1] (output --> inputs, output) self.out_dist_tensor_spec.set_dims_mapping([0, -1, -1]) - resulted_dist_attrs = self.rule.infer_backward( - [self.x_dist_tensor_spec, self.y_dist_tensor_spec], - [self.out_dist_tensor_spec], - self.attrs, + resulted_dist_attrs = self.binary_rule.infer_backward( + self.x_dist_tensor_spec, + self.y_dist_tensor_spec, + self.out_dist_tensor_spec, ) infered_input_dist_attrs = resulted_dist_attrs[0] infered_output_dist_attrs = resulted_dist_attrs[1] @@ -380,10 +381,10 @@ def test_backward_single_mesh_dim_broadcast(self): # [-1, 0, -1] --> [-1, 0, -1], [-1], [-1, 0, -1] (output --> inputs, output) self.out_dist_tensor_spec.set_dims_mapping([-1, 0, -1]) - resulted_dist_attrs = self.rule.infer_backward( - [self.x_dist_tensor_spec, self.y_dist_tensor_spec], - [self.out_dist_tensor_spec], - self.attrs, + resulted_dist_attrs = self.binary_rule.infer_backward( + self.x_dist_tensor_spec, + self.y_dist_tensor_spec, + self.out_dist_tensor_spec, ) infered_input_dist_attrs = resulted_dist_attrs[0] infered_output_dist_attrs = resulted_dist_attrs[1] @@ -395,10 +396,10 @@ def test_backward_single_mesh_dim_broadcast(self): # [-1, -1, 0] --> [-1, -1, 0], [0], [-1, -1, 0] (output --> inputs, output) self.out_dist_tensor_spec.set_dims_mapping([-1, -1, 0]) - resulted_dist_attrs = self.rule.infer_backward( - [self.x_dist_tensor_spec, self.y_dist_tensor_spec], - [self.out_dist_tensor_spec], - self.attrs, + resulted_dist_attrs = self.binary_rule.infer_backward( + self.x_dist_tensor_spec, + self.y_dist_tensor_spec, + self.out_dist_tensor_spec, ) infered_input_dist_attrs = resulted_dist_attrs[0] infered_output_dist_attrs = resulted_dist_attrs[1] @@ -413,10 +414,10 @@ def test_backward_single_mesh_dim_broadcast(self): # [-1, 0, -1] --> [-1, 0, -1], [-1, -1], [-1, 0, -1] (output --> inputs, output) self.out_dist_tensor_spec.set_dims_mapping([-1, 0, -1]) - resulted_dist_attrs = self.rule.infer_backward( - [self.x_dist_tensor_spec, self.y_dist_tensor_spec], - [self.out_dist_tensor_spec], - self.attrs, + resulted_dist_attrs = self.binary_rule.infer_backward( + self.x_dist_tensor_spec, + self.y_dist_tensor_spec, + self.out_dist_tensor_spec, ) infered_input_dist_attrs = resulted_dist_attrs[0] infered_output_dist_attrs = resulted_dist_attrs[1] @@ -431,10 +432,10 @@ def test_backward_single_mesh_dim_broadcast(self): # [0, -1, -1, -1] --> [0, -1, -1, -1], [-1, -1, -1], [0, -1, -1, -1] (output --> inputs, output) self.out_dist_tensor_spec.set_dims_mapping([0, -1, -1, -1]) - resulted_dist_attrs = self.rule.infer_backward( - [self.x_dist_tensor_spec, self.y_dist_tensor_spec], - [self.out_dist_tensor_spec], - self.attrs, + resulted_dist_attrs = self.binary_rule.infer_backward( + self.x_dist_tensor_spec, + self.y_dist_tensor_spec, + self.out_dist_tensor_spec, ) infered_input_dist_attrs = resulted_dist_attrs[0] infered_output_dist_attrs = resulted_dist_attrs[1] @@ -450,10 +451,10 @@ def test_backward_single_mesh_dim_broadcast(self): # [-1, 0, -1, -1] --> [-1, -1, -1, -1], [0, -1, -1], [-1, 0, -1, -1] (output --> inputs, output) self.out_dist_tensor_spec.set_dims_mapping([-1, 0, -1, -1]) - resulted_dist_attrs = self.rule.infer_backward( - [self.x_dist_tensor_spec, self.y_dist_tensor_spec], - [self.out_dist_tensor_spec], - self.attrs, + resulted_dist_attrs = self.binary_rule.infer_backward( + self.x_dist_tensor_spec, + 
self.y_dist_tensor_spec, + self.out_dist_tensor_spec, ) infered_input_dist_attrs = resulted_dist_attrs[0] infered_output_dist_attrs = resulted_dist_attrs[1] @@ -477,10 +478,10 @@ def test_backward_multi_mesh_dim(self): # [0, 1, -1] --> [0, 1, -1], [0, 1, -1], [0, 1, -1] (output --> inputs, output) self.out_dist_tensor_spec.set_dims_mapping([0, 1, -1]) - resulted_dist_attrs = self.rule.infer_backward( - [self.x_dist_tensor_spec, self.y_dist_tensor_spec], - [self.out_dist_tensor_spec], - self.attrs, + resulted_dist_attrs = self.binary_rule.infer_backward( + self.x_dist_tensor_spec, + self.y_dist_tensor_spec, + self.out_dist_tensor_spec, ) infered_input_dist_attrs = resulted_dist_attrs[0] infered_output_dist_attrs = resulted_dist_attrs[1] @@ -504,10 +505,10 @@ def test_backward_multi_mesh_dim_broadcast(self): # [0, -1, 1] --> [0, -1, 1], [1], [0, -1, 1] (output --> inputs, output) self.out_dist_tensor_spec.set_dims_mapping([0, -1, 1]) - resulted_dist_attrs = self.rule.infer_backward( - [self.x_dist_tensor_spec, self.y_dist_tensor_spec], - [self.out_dist_tensor_spec], - self.attrs, + resulted_dist_attrs = self.binary_rule.infer_backward( + self.x_dist_tensor_spec, + self.y_dist_tensor_spec, + self.out_dist_tensor_spec, ) infered_input_dist_attrs = resulted_dist_attrs[0] infered_output_dist_attrs = resulted_dist_attrs[1] @@ -523,10 +524,10 @@ def test_backward_multi_mesh_dim_broadcast(self): # [0, 1, -1] --> [0, 1, -1], [-1], [0, 1, -1] (output --> inputs, output) self.out_dist_tensor_spec.set_dims_mapping([0, 1, -1]) - resulted_dist_attrs = self.rule.infer_backward( - [self.x_dist_tensor_spec, self.y_dist_tensor_spec], - [self.out_dist_tensor_spec], - self.attrs, + resulted_dist_attrs = self.binary_rule.infer_backward( + self.x_dist_tensor_spec, + self.y_dist_tensor_spec, + self.out_dist_tensor_spec, ) infered_input_dist_attrs = resulted_dist_attrs[0] infered_output_dist_attrs = resulted_dist_attrs[1] @@ -542,10 +543,10 @@ def test_backward_multi_mesh_dim_broadcast(self): # [-1, 0, -1, 1] --> [-1, -1, -1, 1], [0, -1, 1], [-1, 0, -1, 1] (output --> inputs, output) self.out_dist_tensor_spec.set_dims_mapping([-1, 0, -1, 1]) - resulted_dist_attrs = self.rule.infer_backward( - [self.x_dist_tensor_spec, self.y_dist_tensor_spec], - [self.out_dist_tensor_spec], - self.attrs, + resulted_dist_attrs = self.binary_rule.infer_backward( + self.x_dist_tensor_spec, + self.y_dist_tensor_spec, + self.out_dist_tensor_spec, ) infered_input_dist_attrs = resulted_dist_attrs[0] infered_output_dist_attrs = resulted_dist_attrs[1]
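
As a rough illustration of the einsum-style notation that GetInputBroadcastNotation and GetBinaryNotations build (and that the broadcast tests above exercise), the following standalone Python sketch mirrors the same scheme: trailing dimensions of each input share letters with the output, missing leading dimensions and size-1 dimensions are marked "1", and an output axis becomes "1" only when every input broadcasts along it. binary_notations is a hypothetical helper used only for illustration; it is not part of this patch or of Paddle's API.

def binary_notations(x_shape, y_shape):
    # Same alphabet-based scheme as GetBinaryNotations: letters are assigned to
    # the trailing dimensions, aligned with the output notation.
    alphabet = "abcdefghijklmnopqrstuvwxyz"
    max_ndim = max(len(x_shape), len(y_shape))
    out_axes = list(alphabet[:max_ndim])
    broadcast_count = [0] * max_ndim  # how many inputs broadcast on each axis

    def input_axes(shape):
        start = max_ndim - len(shape)
        axes = list(alphabet[start:max_ndim])
        for i in range(max_ndim):
            if i < start:
                broadcast_count[i] += 1  # dimension missing in this input
            elif shape[i - start] == 1:
                broadcast_count[i] += 1
                axes[i - start] = "1"  # size-1 dimension: mark as broadcast
        return "".join(axes)

    x_axes = input_axes(x_shape)
    y_axes = input_axes(y_shape)
    for i in range(max_ndim):
        if broadcast_count[i] == 2:  # both inputs broadcast on this axis
            out_axes[i] = "1"
    return x_axes, y_axes, "".join(out_axes)

print(binary_notations([64, 32, 48], [32, 48]))  # ('abc', 'bc', 'abc')
print(binary_notations([64, 1, 48], [1, 48]))    # ('a1c', '1c', 'a1c')

Sharding propagation then merges the dims mappings of all inputs axis letter by axis letter, and GetDimsMappingForAxes reads the merged mapping back out for each tensor's notation, with "1" axes always mapped to -1 (replicated).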