From 377078ab65a612a9691062a4d952d77f00847865 Mon Sep 17 00:00:00 2001
From: Yichen Zhang
Date: Tue, 19 Sep 2023 17:29:00 +0800
Subject: [PATCH 1/3] adapt reshape spmd rule to phi

---
 .../spmd_rules/reshape_spmd_rule.h            |  41 -----
 .../auto_parallel/spmd_rules/rules.h          |   4 -
 .../infermeta}/spmd_rules/dim_trans.cc        |  19 +--
 .../infermeta}/spmd_rules/dim_trans.h         |  11 +-
 .../infermeta/spmd_rules/reshape.cc}          | 159 +++++++++---------
 paddle/phi/infermeta/spmd_rules/reshape.h     |  32 ++++
 paddle/phi/infermeta/spmd_rules/rules.h       |   6 +
 .../spmd_rules/test_reshape_rule.py           |   4 +-
 8 files changed, 131 insertions(+), 145 deletions(-)
 delete mode 100644 paddle/fluid/distributed/auto_parallel/spmd_rules/reshape_spmd_rule.h
 rename paddle/{fluid/distributed/auto_parallel => phi/infermeta}/spmd_rules/dim_trans.cc (95%)
 rename paddle/{fluid/distributed/auto_parallel => phi/infermeta}/spmd_rules/dim_trans.h (94%)
 rename paddle/{fluid/distributed/auto_parallel/spmd_rules/reshape_spmd_rule.cc => phi/infermeta/spmd_rules/reshape.cc} (56%)
 create mode 100644 paddle/phi/infermeta/spmd_rules/reshape.h

diff --git a/paddle/fluid/distributed/auto_parallel/spmd_rules/reshape_spmd_rule.h b/paddle/fluid/distributed/auto_parallel/spmd_rules/reshape_spmd_rule.h
deleted file mode 100644
index 737455e0be6c8b..00000000000000
--- a/paddle/fluid/distributed/auto_parallel/spmd_rules/reshape_spmd_rule.h
+++ /dev/null
@@ -1,41 +0,0 @@
-/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#pragma once
-
-#include
-#include
-#include
-#include
-
-#include "paddle/fluid/distributed/auto_parallel/spmd_rules/common.h"
-
-namespace paddle {
-namespace distributed {
-namespace auto_parallel {
-
-class ReshapeSPMDRule : public SPMDRuleBase {
- public:
-  std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
-  InferForward(const std::vector<DistTensorSpec>& input_specs,
-               const paddle::framework::AttributeMap& attrs) override;
-
-  std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
-  InferBackward(const std::vector<DistTensorSpec>& input_specs,
-                const std::vector<DistTensorSpec>& output_specs,
-                const paddle::framework::AttributeMap& attrs) override;
-};
-} // namespace auto_parallel
-} // namespace distributed
-} // namespace paddle
diff --git a/paddle/fluid/distributed/auto_parallel/spmd_rules/rules.h b/paddle/fluid/distributed/auto_parallel/spmd_rules/rules.h
index 71f939ffd37850..46806ce4daab72 100644
--- a/paddle/fluid/distributed/auto_parallel/spmd_rules/rules.h
+++ b/paddle/fluid/distributed/auto_parallel/spmd_rules/rules.h
@@ -18,7 +18,6 @@
 #include "paddle/fluid/distributed/auto_parallel/spmd_rules/cross_entropy_with_softmax_spmd_rule.h"
 #include "paddle/fluid/distributed/auto_parallel/spmd_rules/embedding_spmd_rule.h"
 #include "paddle/fluid/distributed/auto_parallel/spmd_rules/replicated_spmd_rule.h"
-#include "paddle/fluid/distributed/auto_parallel/spmd_rules/reshape_spmd_rule.h"
 #include "paddle/fluid/distributed/auto_parallel/spmd_rules/softmax_spmd_rule.h"
 #include "paddle/fluid/distributed/auto_parallel/spmd_rules/split_spmd_rule.h"
 #include "paddle/fluid/distributed/auto_parallel/spmd_rules/transpose_spmd_rule.h"
@@ -50,9 +49,6 @@ REGISTER_SPMD_RULE(split_with_num, SplitSPMDRule);
 // transpose rule
 REGISTER_SPMD_RULE(transpose, TransposeSPMDRule);
 
-// reshape rule
-REGISTER_SPMD_RULE(reshape, ReshapeSPMDRule);
-
 } // namespace auto_parallel
 } // namespace distributed
 } // namespace paddle
diff --git a/paddle/fluid/distributed/auto_parallel/spmd_rules/dim_trans.cc b/paddle/phi/infermeta/spmd_rules/dim_trans.cc
similarity index 95%
rename from paddle/fluid/distributed/auto_parallel/spmd_rules/dim_trans.cc
rename to paddle/phi/infermeta/spmd_rules/dim_trans.cc
index 56aab1ec6093f9..d781cc415ae4c4 100644
--- a/paddle/fluid/distributed/auto_parallel/spmd_rules/dim_trans.cc
+++ b/paddle/phi/infermeta/spmd_rules/dim_trans.cc
@@ -12,17 +12,16 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
-#include "paddle/fluid/distributed/auto_parallel/spmd_rules/dim_trans.h"
+#include "paddle/phi/infermeta/spmd_rules/dim_trans.h"
 #include
 #include
 #include
 #include
-#include "paddle/fluid/distributed/auto_parallel/spmd_rules/dist_tensor_spec.h"
+#include "paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.h"
 #include "paddle/phi/core/enforce.h"
 
-namespace paddle {
+namespace phi {
 namespace distributed {
-namespace auto_parallel {
 
 static std::vector<DimTrans*> all_dim_trans;
@@ -289,10 +288,11 @@ void GetUsedInputDim(DimTrans* dim_trans, std::set<int64_t>* seen_dims) {
 }
 
 std::vector<std::vector<int64_t>> InferFromDimTrans(
-    const DistTensorSpec& input_spec, const std::vector<DimTrans*>& dim_trans) {
-  const std::vector<int64_t>& input_shape = input_spec.shape();
-  const std::vector<int64_t>& input_dims_mapping = input_spec.dims_mapping();
-  const ProcessMesh& mesh = input_spec.dist_attr().process_mesh();
+    const DistMetaTensor& input, const std::vector<DimTrans*>& dim_trans) {
+  std::vector<int64_t> input_shape = phi::vectorize(input.dims());
+  const std::vector<int64_t>& input_dims_mapping =
+      input.dist_attr().dims_mapping();
+  const ProcessMesh& mesh = input.dist_attr().process_mesh();
   const std::vector<int64_t>& mesh_shape = mesh.shape();
 
   std::set<int64_t> sharded_input_dims;
@@ -354,6 +354,5 @@ std::vector<std::vector<int64_t>> InferFromDimTrans(
   return {new_input_dims_mapping, out_dims_mapping};
 }
 
-} // namespace auto_parallel
 } // namespace distributed
-} // namespace paddle
+} // namespace phi
diff --git a/paddle/fluid/distributed/auto_parallel/spmd_rules/dim_trans.h b/paddle/phi/infermeta/spmd_rules/dim_trans.h
similarity index 94%
rename from paddle/fluid/distributed/auto_parallel/spmd_rules/dim_trans.h
rename to paddle/phi/infermeta/spmd_rules/dim_trans.h
index f196a0266d5d40..58ce07d0095c10 100644
--- a/paddle/fluid/distributed/auto_parallel/spmd_rules/dim_trans.h
+++ b/paddle/phi/infermeta/spmd_rules/dim_trans.h
@@ -17,11 +17,11 @@ limitations under the License. */
 
 #include
 #include
-#include "paddle/fluid/distributed/auto_parallel/spmd_rules/dist_tensor_spec.h"
+#include "paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.h"
+#include "paddle/phi/core/distributed/type_defs.h"
 
-namespace paddle {
+namespace phi {
 namespace distributed {
-namespace auto_parallel {
 
 // This is a base class to describe how each dimension in output tensor
 // is transformed from input tensor's axes. The transformation includes
@@ -153,8 +153,7 @@ DimTrans* make_split(DimTrans* dim,
 // leftmost output split axis can be sharded when its shape can be divisible
 // by the mesh dimension.
 std::vector<std::vector<int64_t>> InferFromDimTrans(
-    const DistTensorSpec& input_spec, const std::vector<DimTrans*>& dim_trans);
+    const DistMetaTensor& input_spec, const std::vector<DimTrans*>& dim_trans);
 
-} // namespace auto_parallel
 } // namespace distributed
-} // namespace paddle
+} // namespace phi
diff --git a/paddle/fluid/distributed/auto_parallel/spmd_rules/reshape_spmd_rule.cc b/paddle/phi/infermeta/spmd_rules/reshape.cc
similarity index 56%
rename from paddle/fluid/distributed/auto_parallel/spmd_rules/reshape_spmd_rule.cc
rename to paddle/phi/infermeta/spmd_rules/reshape.cc
index 5e0c2c5a92c5b7..64643d808c6138 100644
--- a/paddle/fluid/distributed/auto_parallel/spmd_rules/reshape_spmd_rule.cc
+++ b/paddle/phi/infermeta/spmd_rules/reshape.cc
@@ -12,14 +12,19 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
-#include "paddle/fluid/distributed/auto_parallel/spmd_rules/reshape_spmd_rule.h"
+#include "paddle/phi/infermeta/spmd_rules/reshape.h"
 
 #include
-#include "paddle/fluid/distributed/auto_parallel/spmd_rules/dim_trans.h"
+
+#include "glog/logging.h"
+
+#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h"
+#include "paddle/phi/core/distributed/auto_parallel/inferspmd_utils.h"
 #include "paddle/phi/core/distributed/auto_parallel/utils.h"
+#include "paddle/phi/infermeta/spmd_rules/dim_trans.h"
+#include "paddle/phi/infermeta/spmd_rules/utils.h"
 
-namespace paddle {
+namespace phi {
 namespace distributed {
-namespace auto_parallel {
 
 using phi::distributed::auto_parallel::str_join;
@@ -135,29 +140,27 @@ std::vector<DimTrans*> MakeReshapeDimTrans(
   return ret;
 }
 
-//
-std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
-paddle::distributed::auto_parallel::ReshapeSPMDRule::InferForward(
-    const std::vector<DistTensorSpec>& input_specs,
-    const paddle::framework::AttributeMap& attrs) {
-  // step0: Verify Input Args Based on Reshape Logic
-  int64_t ninputs = static_cast<int64_t>(input_specs.size());
+SpmdInfo ReshapeInferSpmd(const DistMetaTensor& x,
+                          const std::vector<int64_t>& shape) {
+  // Step0: Verify input args based on reshape logic
+  auto src_shape = phi::vectorize(x.dims());
+  int x_ndim = src_shape.size();
+  auto x_dist_attr_src = x.dist_attr();
+  std::vector<int64_t> x_dims_mapping = x_dist_attr_src.dims_mapping();
   PADDLE_ENFORCE_EQ(
-      ninputs,
-      1,
-      phi::errors::InvalidArgument("The size of InputSpec in reshape must "
-                                   "be equal to 1, but got [%d].",
-                                   ninputs));
-  VerifySpecs(input_specs, "reshape");
-
-  // step1: build the transformation from
-  // original shape to target shape
-  std::vector<int64_t> src_shape = input_specs[0].shape();
-  std::vector<int64_t> tgt_shape =
-      ExtractAttr<std::vector<int64_t>>("shape", attrs);
+      x_ndim,
+      x_dims_mapping.size(),
+      phi::errors::InvalidArgument("The Tensor X's rank [%d] and X's "
+                                   "dims_mapping size [%d] are not matched.",
+                                   x_ndim,
+                                   x_dims_mapping.size()));
+
+  // Step1: Build the transformation from
+  // the original shape to the target shape
 
   // handle the '0' values in target shape, '0' indicates
   // that the target shape is equal to the source shape
+  std::vector<int64_t> tgt_shape(shape);
   for (int64_t i = 0, n = static_cast<int64_t>(tgt_shape.size()); i < n; i++) {
     if (tgt_shape[i] == 0) {
       tgt_shape[i] = src_shape[i];
@@ -166,96 +169,88 @@ paddle::distributed::auto_parallel::ReshapeSPMDRule::InferForward(
     }
   }
 
   std::vector<DimTrans*> trans = MakeReshapeDimTrans(src_shape, tgt_shape);
 
-  // step2: infer the dims mapping of input (if reshard is
+  // Step2: Infer the dims mapping of input (if reshard is
   // needed) and output from the dimension transformation.
   std::vector<std::vector<int64_t>> dims_mapping_vec =
-      InferFromDimTrans(input_specs[0], trans);
+      InferFromDimTrans(x, trans);
 
-  // step3: update the dist attributes of input
-  // and output with the inferred dims mapping
-  TensorDistAttr new_input_dist_attr(input_specs[0].dist_attr());
-  new_input_dist_attr.set_dims_mapping(dims_mapping_vec[0]);
-  TensorDistAttr output_dist_attr(input_specs[0].dist_attr());
-  output_dist_attr.set_dims_mapping(dims_mapping_vec[1]);
+  // Step3: Update the dist attributes of input
+  // and output with the inferred dims mapping.
+  TensorDistAttr x_dist_attr_dst(x_dist_attr_src);
+  x_dist_attr_dst.set_dims_mapping(dims_mapping_vec[0]);
+  TensorDistAttr out_dist_attr(x_dist_attr_src);
+  out_dist_attr.set_dims_mapping(dims_mapping_vec[1]);
 
-  VLOG(4) << "Reshape: input_shape: [" << str_join(src_shape)
-          << "] output_shape: [" << str_join(tgt_shape) << "]";
+  VLOG(4) << "ReshapeInferSpmd: X shape: [" << str_join(src_shape)
+          << "] Out shape: [" << str_join(tgt_shape) << "]";
   VLOG(4) << "Transformation from input to output:";
   for (int64_t i = 0, n = static_cast<int64_t>(trans.size()); i < n; i++) {
     DimTrans* t = trans[i];
-    VLOG(4) << "\tOutput axis " << i << ": " << t->to_string();
+    VLOG(4) << "\tOut axis[" << i << "]: " << t->to_string();
   }
-  VLOG(4) << "input_dims_mapping: [" << str_join(dims_mapping_vec[0])
-          << "] output_dims_mapping: [" << str_join(dims_mapping_vec[1])
+  VLOG(4) << "X dims_mapping_src: [" << str_join(x_dims_mapping)
+          << "] dims_mapping_dst: [" << str_join(dims_mapping_vec[0])
+          << "]\n Out dims_mapping: [" << str_join(dims_mapping_vec[1])
           << "]\n\n";
 
   CleanUp();
 
-  return {{new_input_dist_attr}, {output_dist_attr}};
+  return {{x_dist_attr_dst}, {out_dist_attr}};
 }
 
-std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
-paddle::distributed::auto_parallel::ReshapeSPMDRule::InferBackward(
-    const std::vector<DistTensorSpec>& input_specs,
-    const std::vector<DistTensorSpec>& output_specs,
-    const paddle::framework::AttributeMap& attrs) {
-  // step0: Verify Input Args Based on Reshape Logic
-  int64_t ninputs = input_specs.size();
-  int64_t noutputs = output_specs.size();
+SpmdInfo ReshapeInferSpmdReverse(const DistMetaTensor& x,
+                                 const DistMetaTensor& out,
+                                 const std::vector<int64_t>& shape) {
+  // Step0: Verify input args based on reshape logic
+  auto x_shape = phi::vectorize(x.dims());
+  auto out_shape = phi::vectorize(out.dims());
+  int out_ndim = out_shape.size();
+  auto out_dist_attr_src = out.dist_attr();
+  std::vector<int64_t> out_dims_mapping = out_dist_attr_src.dims_mapping();
   PADDLE_ENFORCE_EQ(
-      ninputs,
-      1,
-      phi::errors::InvalidArgument("The size of InputSpec in reshape must "
-                                   "be equal to 1, but got [%d].",
-                                   ninputs));
-  PADDLE_ENFORCE_EQ(
-      noutputs,
-      1,
-      phi::errors::InvalidArgument("The size of OutputSpec in reshape must "
-                                   "be equal to 1, but got [%d].",
-                                   noutputs));
-  VerifySpecs(output_specs, "reshape");
-
-  // step1: build the transformation from the output shape
-  // to original shape. Inferbackward infers the dims mapping
+      out_ndim,
+      out_dims_mapping.size(),
+      phi::errors::InvalidArgument("The Tensor Out's rank [%d] and Out's "
+                                   "dims_mapping size [%d] are not matched.",
+                                   out_ndim,
+                                   out_dims_mapping.size()));
+
+  // Step1: Build the transformation from the output shape
+  // to original shape. This function infers the dims mapping
   // from output to input, we first get the transformation
   // from output to input so that we can infer the dims mapping
   // with the map from output axes to input axes.
-  // Shapes in Inferbackward don't contain -1 or 0, so they will
-  // not be modified and we can use ref here.
-  const std::vector<int64_t>& output_shape = output_specs[0].shape();
-  const std::vector<int64_t>& input_shape = input_specs[0].shape();
+  // Shapes in InferSpmdReverse don't contain -1 or 0, so they will
+  // not be modified and we can directly use them.
+  std::vector<DimTrans*> trans = MakeReshapeDimTrans(out_shape, x_shape);
 
-  std::vector<DimTrans*> trans = MakeReshapeDimTrans(output_shape, input_shape);
-
-  // step2: infer the dims mapping of input with
+  // Step2: Infer the dims mapping of input with
   // output's dims_mapping and the transformation.
   std::vector<std::vector<int64_t>> dims_mapping_vec =
-      InferFromDimTrans(output_specs[0], trans);
+      InferFromDimTrans(out, trans);
 
-  // step3: update the dist attributes of input
+  // Step3: Update the dist attributes of input
   // and output with the inferred dims mapping
-  TensorDistAttr new_output_dist_attr(output_specs[0].dist_attr());
-  new_output_dist_attr.set_dims_mapping(dims_mapping_vec[0]);
-  TensorDistAttr input_dist_attr(input_specs[0].dist_attr());
-  input_dist_attr.set_dims_mapping(dims_mapping_vec[1]);
+  TensorDistAttr out_dist_attr_dst(out_dist_attr_src);
+  out_dist_attr_dst.set_dims_mapping(dims_mapping_vec[0]);
+  TensorDistAttr x_dist_attr(x.dist_attr());
+  x_dist_attr.set_dims_mapping(dims_mapping_vec[1]);
 
-  VLOG(4) << "Reshape Inferbackward: output_shape: [" << str_join(output_shape)
-          << "] input_shape: [" << str_join(input_shape) << "]";
+  VLOG(4) << "ReshapeInferSpmdReverse: Out shape: [" << str_join(out_shape)
+          << "] X shape: [" << str_join(x_shape) << "]";
   VLOG(4) << "Transformation from output to input:";
   for (int64_t i = 0, n = trans.size(); i < n; i++) {
     DimTrans* t = trans[i];
-    VLOG(4) << "\tInput axis " << i << ": " << t->to_string();
+    VLOG(4) << "\tX axis[" << i << "]: " << t->to_string();
   }
-  VLOG(4) << "input_dims_mapping: [" << str_join(dims_mapping_vec[1])
-          << "] output_dims_mapping: [" << str_join(dims_mapping_vec[0])
-          << "]\n\n";
+  VLOG(4) << "Out dims_mapping_src: [" << str_join(out_dims_mapping) << "] "
+          << "dims_mapping_dst: [" << str_join(dims_mapping_vec[0]) << "]";
+  VLOG(4) << "X dims_mapping: [" << str_join(dims_mapping_vec[1]) << "]\n\n";
 
   CleanUp();
 
-  return {{input_dist_attr}, {new_output_dist_attr}};
+  return {{x_dist_attr}, {out_dist_attr_dst}};
 }
 
-} // namespace auto_parallel
 } // namespace distributed
-} // namespace paddle
+} // namespace phi
diff --git a/paddle/phi/infermeta/spmd_rules/reshape.h b/paddle/phi/infermeta/spmd_rules/reshape.h
new file mode 100644
index 00000000000000..394f31c2b8cf30
--- /dev/null
+++ b/paddle/phi/infermeta/spmd_rules/reshape.h
@@ -0,0 +1,32 @@
+/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include <vector>
+
+#include "paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.h"
+#include "paddle/phi/core/distributed/type_defs.h"
+
+namespace phi {
+namespace distributed {
+
+SpmdInfo ReshapeInferSpmd(const DistMetaTensor& x,
+                          const std::vector<int64_t>& shape);
+
+SpmdInfo ReshapeInferSpmdReverse(const DistMetaTensor& x,
+                                 const DistMetaTensor& out,
+                                 const std::vector<int64_t>& shape);
+} // namespace distributed
+} // namespace phi
diff --git a/paddle/phi/infermeta/spmd_rules/rules.h b/paddle/phi/infermeta/spmd_rules/rules.h
index cb01b8996a8c91..0d14f7da7abe93 100644
--- a/paddle/phi/infermeta/spmd_rules/rules.h
+++ b/paddle/phi/infermeta/spmd_rules/rules.h
@@ -22,6 +22,7 @@ limitations under the License. */
 #include "paddle/phi/infermeta/spmd_rules/matmul.h"
 #include "paddle/phi/infermeta/spmd_rules/reduction.h"
 #include "paddle/phi/infermeta/spmd_rules/replicated.h"
+#include "paddle/phi/infermeta/spmd_rules/reshape.h"
 
 /**
  * Design Notes:
@@ -464,5 +465,10 @@ PD_REGISTER_SPMD_RULE(
     PD_INFER_SPMD(phi::distributed::LayerNormInferSpmd),
     PD_INFER_SPMD(phi::distributed::LayerNormInferSpmdReverse));
 
+// reshape rule
+PD_REGISTER_SPMD_RULE(reshape,
+                      PD_INFER_SPMD(phi::distributed::ReshapeInferSpmd),
+                      PD_INFER_SPMD(phi::distributed::ReshapeInferSpmdReverse));
+
 } // namespace distributed
 } // namespace phi
diff --git a/test/auto_parallel/spmd_rules/test_reshape_rule.py b/test/auto_parallel/spmd_rules/test_reshape_rule.py
index dd7c248ca42fbd..1b9ae19b0348f3 100644
--- a/test/auto_parallel/spmd_rules/test_reshape_rule.py
+++ b/test/auto_parallel/spmd_rules/test_reshape_rule.py
@@ -14,17 +14,17 @@
 
 import unittest
 
-from paddle.distributed.auto_parallel.static.completion import get_spmd_rule
 from paddle.distributed.auto_parallel.static.dist_attribute import (
     DistTensorSpec,
     TensorDistAttr,
 )
 from paddle.distributed.fleet import auto
+from paddle.framework import core
 
 
 class TestReshapeSPMDRule(unittest.TestCase):
     def setUp(self):
-        self.rule = get_spmd_rule("reshape")
+        self.rule = core.get_phi_spmd_rule("reshape")
 
         x_shape = [6, 12, 48, 24]
         process_mesh = auto.ProcessMesh(mesh=[[0, 1, 2], [3, 4, 5]])

From 8c2732842dd947d1e35df271ebb923cc053d55a4 Mon Sep 17 00:00:00 2001
From: Yichen Zhang
Date: Thu, 21 Sep 2023 10:50:25 +0800
Subject: [PATCH 2/3] fix the bug when op attribute is vector type

---
 .../auto_parallel/inferspmd_utils.cc          | 23 +++++-
 paddle/phi/infermeta/spmd_rules/reshape.cc    |  6 +-
 .../spmd_rules/test_reshape_rule.py           | 79 ++++++++++---------
 3 files changed, 66 insertions(+), 42 deletions(-)

diff --git a/paddle/phi/core/distributed/auto_parallel/inferspmd_utils.cc b/paddle/phi/core/distributed/auto_parallel/inferspmd_utils.cc
index 24030b5d0ffa80..485e2f09a42e94 100644
--- a/paddle/phi/core/distributed/auto_parallel/inferspmd_utils.cc
+++ b/paddle/phi/core/distributed/auto_parallel/inferspmd_utils.cc
@@ -86,7 +86,28 @@ std::vector<bool> InferSpmdContext::AttrAt(size_t idx) const {
   } catch (paddle::bad_variant_access const& e) {
     PADDLE_THROW(phi::errors::InvalidArgument(
         "Attribute cast error in InferSpmd Context, the input attr type is "
-        "`%s`, but the expected attribute type is `bool`.",
+        "`%s`, but the expected attribute type is `std::vector<bool>`.",
         attrs_.at(idx).type().name()));
   }
 }
 
+template <>
+std::vector<int64_t> InferSpmdContext::AttrAt(size_t idx) const {
+  try {
+    auto attr = attrs_.at(idx);
+    if (attr.type() == typeid(std::vector<bool>)) {
+      std::vector<bool> val = PADDLE_GET_CONST(std::vector<bool>, attr);
+      return std::vector<int64_t>(val.begin(), val.end());
+    } else if (attr.type() == typeid(std::vector<int>)) {
+      std::vector<int> val = PADDLE_GET_CONST(std::vector<int>, attr);
+      return std::vector<int64_t>(val.begin(), val.end());
+    } else {
+      return PADDLE_GET_CONST(std::vector<int64_t>, attr);
+    }
+  } catch (paddle::bad_variant_access const& e) {
+    PADDLE_THROW(phi::errors::InvalidArgument(
+        "Attribute cast error in InferSpmd Context, the input attr type is "
+        "`%s`, but the expected attribute type is `std::vector<int64_t>`.",
+        attrs_.at(idx).type().name()));
+  }
+}
+
diff --git a/paddle/phi/infermeta/spmd_rules/reshape.cc b/paddle/phi/infermeta/spmd_rules/reshape.cc
index 64643d808c6138..4c95b846c87d03 100644
--- a/paddle/phi/infermeta/spmd_rules/reshape.cc
+++ b/paddle/phi/infermeta/spmd_rules/reshape.cc
@@ -76,9 +76,9 @@ std::vector<DimTrans*> MakeReshapeDimTrans(
   std::vector<int64_t> inferred_tgt_shape =
       InferTargetShape(tgt_shape, total_elem_num_src);
 
-  int64_t src_idx = 0, tgt_idx = 0;
-  int64_t s, t;
-  int64_t src_len, tgt_len;
+  int src_idx = 0, tgt_idx = 0;
+  int s, t;
+  int src_len, tgt_len;
   src_len = static_cast<int>(src_shape.size());
   tgt_len = static_cast<int>(inferred_tgt_shape.size());
   while (src_idx < src_len || tgt_idx < tgt_len) {
diff --git a/test/auto_parallel/spmd_rules/test_reshape_rule.py b/test/auto_parallel/spmd_rules/test_reshape_rule.py
index 1b9ae19b0348f3..7b461ccac8751f 100644
--- a/test/auto_parallel/spmd_rules/test_reshape_rule.py
+++ b/test/auto_parallel/spmd_rules/test_reshape_rule.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import unittest
+from collections import OrderedDict
 
 from paddle.distributed.auto_parallel.static.dist_attribute import (
     DistTensorSpec,
@@ -34,14 +35,14 @@ def setUp(self):
         x_tensor_dist_attr.process_mesh = process_mesh
         self.x_dist_tensor_spec = DistTensorSpec(x_shape, x_tensor_dist_attr)
 
-        self.attrs = {"shape": [1, 72, 48, 4, 6]}
+        self.attrs = OrderedDict([('shape', [1, 72, 48, 4, 6])])
 
     def test_reshape_infer_forward(self):
         # shape: [6, 12, 48, 24] --> [1, 72, 48, 4, 6]
         # dims_mapping: [0, -1, 1, -1] --> [0, -1, 1, -1] [-1, 0, 1, -1, -1]
         self.x_dist_tensor_spec.set_dims_mapping([0, -1, 1, -1])
         result_dist_attrs = self.rule.infer_forward(
-            [self.x_dist_tensor_spec], self.attrs
+            self.x_dist_tensor_spec, self.attrs['shape']
         )
         infered_input_dist_attrs = result_dist_attrs[0]
         infered_output_dist_attrs = result_dist_attrs[1]
@@ -59,7 +60,7 @@ def test_reshape_infer_forward(self):
         # dims_mapping: [-1, 0, -1, 1] --> [-1, -1, -1, -1] [-1, -1, -1, -1, -1]
         self.x_dist_tensor_spec.set_dims_mapping([-1, 0, -1, 1])
         result_dist_attrs = self.rule.infer_forward(
-            [self.x_dist_tensor_spec], self.attrs
+            self.x_dist_tensor_spec, self.attrs['shape']
         )
         infered_input_dist_attrs = result_dist_attrs[0]
         infered_output_dist_attrs = result_dist_attrs[1]
@@ -75,7 +76,7 @@ def test_reshape_infer_forward(self):
         # dims_mapping: [1, -1, -1, 0] --> [1, -1, -1, 0] [-1, 1, -1, 0, -1]
         self.x_dist_tensor_spec.set_dims_mapping([1, -1, -1, 0])
         result_dist_attrs = self.rule.infer_forward(
-            [self.x_dist_tensor_spec], self.attrs
+            self.x_dist_tensor_spec, self.attrs['shape']
         )
         infered_input_dist_attrs = result_dist_attrs[0]
         infered_output_dist_attrs = result_dist_attrs[1]
@@ -93,7 +94,7 @@ def test_reshape_infer_forward(self):
 
         self.x_dist_tensor_spec.set_dims_mapping([0, 1, -1, -1])
         result_dist_attrs = self.rule.infer_forward(
-            [self.x_dist_tensor_spec], self.attrs
+            self.x_dist_tensor_spec, self.attrs['shape']
         )
         infered_input_dist_attrs = result_dist_attrs[0]
         infered_output_dist_attrs = result_dist_attrs[1]
@@ -110,7 +111,7 @@ def test_reshape_infer_forward(self):
 
         self.x_dist_tensor_spec.set_dims_mapping([1, -1, -1, 0])
         result_dist_attrs = self.rule.infer_forward(
-            [self.x_dist_tensor_spec], self.attrs
+            self.x_dist_tensor_spec, self.attrs['shape']
         )
         infered_input_dist_attrs = result_dist_attrs[0]
         infered_output_dist_attrs = result_dist_attrs[1]
@@ -128,7 +129,7 @@ def test_reshape_infer_forward(self):
 
         self.x_dist_tensor_spec.set_dims_mapping([-1, -1, 0, 1])
         result_dist_attrs = self.rule.infer_forward(
-            [self.x_dist_tensor_spec], self.attrs
+            self.x_dist_tensor_spec, self.attrs['shape']
        )
         infered_input_dist_attrs = result_dist_attrs[0]
         infered_output_dist_attrs = result_dist_attrs[1]
@@ -145,7 +146,7 @@ def test_reshape_infer_forward(self):
         self.attrs["shape"] = [1, 72, 0, 4, 6]
self.x_dist_tensor_spec.set_dims_mapping([1, -1, -1, 0]) result_dist_attrs = self.rule.infer_forward( - [self.x_dist_tensor_spec], self.attrs + self.x_dist_tensor_spec, self.attrs['shape'] ) infered_input_dist_attrs = result_dist_attrs[0] infered_output_dist_attrs = result_dist_attrs[1] @@ -162,7 +163,7 @@ def test_reshape_infer_forward(self): self.attrs["shape"] = [6, 12, 48, 24] self.x_dist_tensor_spec.set_dims_mapping([-1, -1, 0, 1]) result_dist_attrs = self.rule.infer_forward( - [self.x_dist_tensor_spec], self.attrs + self.x_dist_tensor_spec, self.attrs['shape'] ) infered_input_dist_attrs = result_dist_attrs[0] infered_output_dist_attrs = result_dist_attrs[1] @@ -179,7 +180,7 @@ def test_reshape_infer_forward(self): self.attrs["shape"] = [72, 3, 16, 24] self.x_dist_tensor_spec.set_dims_mapping([0, -1, 1, -1]) result_dist_attrs = self.rule.infer_forward( - [self.x_dist_tensor_spec], self.attrs + self.x_dist_tensor_spec, self.attrs['shape'] ) infered_input_dist_attrs = result_dist_attrs[0] infered_output_dist_attrs = result_dist_attrs[1] @@ -196,7 +197,7 @@ def test_reshape_infer_forward(self): self.attrs["shape"] = [72, 3, 16, 24] self.x_dist_tensor_spec.set_dims_mapping([1, -1, 0, -1]) result_dist_attrs = self.rule.infer_forward( - [self.x_dist_tensor_spec], self.attrs + self.x_dist_tensor_spec, self.attrs['shape'] ) infered_input_dist_attrs = result_dist_attrs[0] infered_output_dist_attrs = result_dist_attrs[1] @@ -214,7 +215,7 @@ def test_reshape_infer_forward(self): self.attrs["shape"] = [6, 12, 48, 24] self.x_dist_tensor_spec.set_dims_mapping([-1, 1, -1, 0, -1]) result_dist_attrs = self.rule.infer_forward( - [self.x_dist_tensor_spec], self.attrs + self.x_dist_tensor_spec, self.attrs['shape'] ) infered_input_dist_attrs = result_dist_attrs[0] infered_output_dist_attrs = result_dist_attrs[1] @@ -232,7 +233,7 @@ def test_reshape_infer_forward(self): self.attrs["shape"] = [0, 0, -1, 192] self.x_dist_tensor_spec.set_dims_mapping([0, 1, -1]) result_dist_attrs = self.rule.infer_forward( - [self.x_dist_tensor_spec], self.attrs + self.x_dist_tensor_spec, self.attrs['shape'] ) infered_input_dist_attrs = result_dist_attrs[0] infered_output_dist_attrs = result_dist_attrs[1] @@ -246,7 +247,9 @@ def test_reshape_infer_forward(self): # raise error self.attrs["shape"] = [3, 24, 6, -1, -1] with self.assertRaises(BaseException): - self.rule.infer_forward([self.x_dist_tensor_spec], self.attrs) + self.rule.infer_forward( + self.x_dist_tensor_spec, self.attrs['shape'] + ) def test_reshape_infer_backward(self): process_mesh = auto.ProcessMesh(mesh=[[0, 1, 2], [3, 4, 5]]) @@ -262,9 +265,9 @@ def test_reshape_infer_backward(self): ) self.output_dist_tensor_spec.set_dims_mapping([-1, 0, 1, -1, -1]) result_dist_attrs = self.rule.infer_backward( - [self.x_dist_tensor_spec], - [self.output_dist_tensor_spec], - self.attrs, + self.x_dist_tensor_spec, + self.output_dist_tensor_spec, + self.attrs['shape'], ) infered_input_dist_attrs = result_dist_attrs[0] infered_output_dist_attrs = result_dist_attrs[1] @@ -283,9 +286,9 @@ def test_reshape_infer_backward(self): self.output_dist_tensor_spec.shape = [1, 72, 48, 4, 6] self.output_dist_tensor_spec.set_dims_mapping([-1, -1, -1, -1, -1]) result_dist_attrs = self.rule.infer_backward( - [self.x_dist_tensor_spec], - [self.output_dist_tensor_spec], - self.attrs, + self.x_dist_tensor_spec, + self.output_dist_tensor_spec, + self.attrs['shape'], ) infered_input_dist_attrs = result_dist_attrs[0] infered_output_dist_attrs = result_dist_attrs[1] @@ -302,9 +305,9 @@ def 
test_reshape_infer_backward(self):
         self.output_dist_tensor_spec.shape = [1, 72, 48, 4, 6]
         self.output_dist_tensor_spec.set_dims_mapping([-1, 1, -1, 0, -1])
         result_dist_attrs = self.rule.infer_backward(
-            [self.x_dist_tensor_spec],
-            [self.output_dist_tensor_spec],
-            self.attrs,
+            self.x_dist_tensor_spec,
+            self.output_dist_tensor_spec,
+            self.attrs['shape'],
         )
         infered_input_dist_attrs = result_dist_attrs[0]
         infered_output_dist_attrs = result_dist_attrs[1]
@@ -322,9 +325,9 @@ def test_reshape_infer_backward(self):
 
         self.output_dist_tensor_spec.set_dims_mapping([1, -1, -1, -1, 0])
         result_dist_attrs = self.rule.infer_backward(
-            [self.x_dist_tensor_spec],
-            [self.output_dist_tensor_spec],
-            self.attrs,
+            self.x_dist_tensor_spec,
+            self.output_dist_tensor_spec,
+            self.attrs['shape'],
         )
         infered_input_dist_attrs = result_dist_attrs[0]
         infered_output_dist_attrs = result_dist_attrs[1]
@@ -342,9 +345,9 @@ def test_reshape_infer_backward(self):
 
         self.output_dist_tensor_spec.set_dims_mapping([-1, -1, 0, -1, 1])
         result_dist_attrs = self.rule.infer_backward(
-            [self.x_dist_tensor_spec],
-            [self.output_dist_tensor_spec],
-            self.attrs,
+            self.x_dist_tensor_spec,
+            self.output_dist_tensor_spec,
+            self.attrs['shape'],
         )
         infered_input_dist_attrs = result_dist_attrs[0]
         infered_output_dist_attrs = result_dist_attrs[1]
@@ -361,9 +364,9 @@ def test_reshape_infer_backward(self):
         self.output_dist_tensor_spec.shape = [6, 12, 48, 24]
         self.output_dist_tensor_spec.set_dims_mapping([-1, -1, 0, 1])
         result_dist_attrs = self.rule.infer_backward(
-            [self.x_dist_tensor_spec],
-            [self.output_dist_tensor_spec],
-            self.attrs,
+            self.x_dist_tensor_spec,
+            self.output_dist_tensor_spec,
+            self.attrs['shape'],
         )
         infered_input_dist_attrs = result_dist_attrs[0]
         infered_output_dist_attrs = result_dist_attrs[1]
@@ -380,9 +383,9 @@ def test_reshape_infer_backward(self):
         self.output_dist_tensor_spec.shape = [72, 3, 16, 24]
         self.output_dist_tensor_spec.set_dims_mapping([0, 1, -1, -1])
         result_dist_attrs = self.rule.infer_backward(
-            [self.x_dist_tensor_spec],
-            [self.output_dist_tensor_spec],
-            self.attrs,
+            self.x_dist_tensor_spec,
+            self.output_dist_tensor_spec,
+            self.attrs['shape'],
         )
         infered_input_dist_attrs = result_dist_attrs[0]
         infered_output_dist_attrs = result_dist_attrs[1]
@@ -399,9 +402,9 @@ def test_reshape_infer_backward(self):
         self.output_dist_tensor_spec.shape = [72, 3, 16, 24]
         self.output_dist_tensor_spec.set_dims_mapping([1, -1, -1, -1])
         result_dist_attrs = self.rule.infer_backward(
-            [self.x_dist_tensor_spec],
-            [self.output_dist_tensor_spec],
-            self.attrs,
+            self.x_dist_tensor_spec,
+            self.output_dist_tensor_spec,
+            self.attrs['shape'],
         )
         infered_input_dist_attrs = result_dist_attrs[0]
         infered_output_dist_attrs = result_dist_attrs[1]

From e9c5e277d69d50c090b9ffa7d83ad67cf8da42f3 Mon Sep 17 00:00:00 2001
From: Yichen Zhang
Date: Thu, 21 Sep 2023 11:39:49 +0800
Subject: [PATCH 3/3] add two more unit test cases

---
 .../spmd_rules/test_reshape_rule.py           | 38 +++++++++++++++++++
 1 file changed, 38 insertions(+)

diff --git a/test/auto_parallel/spmd_rules/test_reshape_rule.py b/test/auto_parallel/spmd_rules/test_reshape_rule.py
index 7b461ccac8751f..a370580682d8cb 100644
--- a/test/auto_parallel/spmd_rules/test_reshape_rule.py
+++ b/test/auto_parallel/spmd_rules/test_reshape_rule.py
@@ -416,6 +416,44 @@ def test_reshape_infer_backward(self):
             infered_output_dist_attrs[0].dims_mapping, [1, -1, -1, -1]
         )
 
+        # shape: [6, 12, 48, 24] --> [1, 72, 48, 4, 6] (input --> output)
+        # dims_mapping: [-1, 0, -1, -1, 1] --> [0, -1, -1, -1], [-1, 0, -1, -1, -1] (output --> input, output)
+        self.output_dist_tensor_spec.shape = [1, 72, 48, 4, 6]
+        self.output_dist_tensor_spec.set_dims_mapping([-1, 0, -1, -1, 1])
+        result_dist_attrs = self.rule.infer_backward(
+            self.x_dist_tensor_spec,
+            self.output_dist_tensor_spec,
+            self.attrs['shape'],
+        )
+        infered_input_dist_attrs = result_dist_attrs[0]
+        infered_output_dist_attrs = result_dist_attrs[1]
+
+        self.assertEqual(
+            infered_input_dist_attrs[0].dims_mapping, [0, -1, -1, -1]
+        )
+        self.assertEqual(
+            infered_output_dist_attrs[0].dims_mapping, [-1, 0, -1, -1, -1]
+        )
+
+        # shape: [6, 12, 48, 24] --> [3, 24, 6, 8, 24] (input --> output)
+        # dims_mapping: [-1, 1, -1, -1, 0] --> [-1, -1, -1, 0], [-1, -1, -1, -1, 0] (output --> input, output)
+        self.output_dist_tensor_spec.shape = [3, 24, 6, 8, 24]
+        self.output_dist_tensor_spec.set_dims_mapping([-1, 1, -1, -1, 0])
+        result_dist_attrs = self.rule.infer_backward(
+            self.x_dist_tensor_spec,
+            self.output_dist_tensor_spec,
+            self.attrs['shape'],
+        )
+        infered_input_dist_attrs = result_dist_attrs[0]
+        infered_output_dist_attrs = result_dist_attrs[1]
+
+        self.assertEqual(
+            infered_input_dist_attrs[0].dims_mapping, [-1, -1, -1, 0]
+        )
+        self.assertEqual(
+            infered_output_dist_attrs[0].dims_mapping, [-1, -1, -1, -1, 0]
+        )
+
 
 if __name__ == "__main__":
     unittest.main()
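-- 
Editor's note, not part of the patch series: the sketch below shows how the
migrated rule is looked up and driven from Python once these patches are
applied. It mirrors the updated unit tests above (same mesh, input shape, and
dims_mapping), so treat it as an illustration of the new calling convention
rather than an official API sample.

# Minimal usage sketch for the phi reshape SPMD rule (mirrors
# test_reshape_rule.py above; assumes a Paddle build with this series).
from paddle.distributed.auto_parallel.static.dist_attribute import (
    DistTensorSpec,
    TensorDistAttr,
)
from paddle.distributed.fleet import auto
from paddle.framework import core

# The rule is now fetched from phi via core.get_phi_spmd_rule instead of
# the removed fluid-side get_spmd_rule helper.
rule = core.get_phi_spmd_rule("reshape")

process_mesh = auto.ProcessMesh(mesh=[[0, 1, 2], [3, 4, 5]])
x_dist_attr = TensorDistAttr()
x_dist_attr.process_mesh = process_mesh
x_spec = DistTensorSpec([6, 12, 48, 24], x_dist_attr)
x_spec.set_dims_mapping([0, -1, 1, -1])

# phi rules take the tensor spec and the attribute values directly,
# not a list of specs plus an attrs dict as the old fluid rules did.
result = rule.infer_forward(x_spec, [1, 72, 48, 4, 6])
input_dist_attrs, output_dist_attrs = result[0], result[1]

# Per the first forward test above, this prints [-1, 0, 1, -1, -1].
print(output_dist_attrs[0].dims_mapping)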