diff --git a/paddle/phi/api/yaml/generator/dist_api_gen.py b/paddle/phi/api/yaml/generator/dist_api_gen.py
index 5880f11b30fa0..e7dfc4c50563b 100644
--- a/paddle/phi/api/yaml/generator/dist_api_gen.py
+++ b/paddle/phi/api/yaml/generator/dist_api_gen.py
@@ -454,21 +454,47 @@
 # TODO(GhostScreaming): Support aliquant condition.
 # Specialized Code, for example, reshape needs to calculate local_shape
 RESHAPE_CALCULATE_LOCAL_SHAPE_TEMPLATE = """
+
+    // The dist_input_x is a dist tensor, the dims() func return the global dims.
+    auto x_shape = dist_input_x->dims();
+    auto x_numel = dist_input_x->numel();
+    bool visit_negative = false;
     std::vector<int64_t> local_shape;
     for (size_t i = 0; i < shape.GetData().size(); i++) {
       auto& out_dist_attr = PADDLE_GET_CONST(phi::distributed::TensorDistAttr, spmd_info.second[0]);
       if (out_dist_attr.dims_mapping()[i] >= 0) {
+        int64_t shape_i = shape.GetData()[i];
+        if (shape_i == 0) {
+          shape_i = x_shape[i];
+        } else if (shape_i == -1) {
+          PADDLE_ENFORCE(not visit_negative,
+                         phi::errors::InvalidArgument(
+                             "Reshape can only have one -1 in the shape."));
+          visit_negative = true;
+          int64_t non_negative_product = 1;
+          for (size_t j = 0; j < shape.GetData().size(); j++) {
+            if (i == j) {
+              continue;
+            }
+            int64_t tmp_j = shape.GetData()[j];
+            if (tmp_j == 0) {
+              tmp_j = x_shape[j];
+            }
+            non_negative_product *= tmp_j;
+          }
+          PADDLE_ENFORCE(x_numel % non_negative_product == 0,
+                         phi::errors::InvalidArgument("Cannot infer real shape for -1."));
+          shape_i = x_numel / non_negative_product;
+        }
         int64_t dim = out_dist_attr.dims_mapping()[i];
         int64_t mesh_dim = out_dist_attr.process_mesh().shape()[dim];
         // TODO: Support aliquant condition.
-        PADDLE_ENFORCE_EQ(shape.GetData()[i] % mesh_dim,
-                          0,
+        PADDLE_ENFORCE(shape_i % mesh_dim == 0,
                        phi::errors::InvalidArgument(
-                          "Reshape only support local shape dim is divisible"
-                          "by the mesh dim, however local_shape[%d] is %d",
-                          "and shard mesh dims is %d",
-                          i, shape.GetData()[i], mesh_dim));
-        local_shape.push_back(shape.GetData()[i] / mesh_dim);
+                           "Reshape only support local shape dim is divisible "
+                           "by the mesh dim, however local_shape[%lld] is %lld "
+                           "and shard mesh dims is %lld.", i, shape_i, mesh_dim));
+        local_shape.push_back(shape_i / mesh_dim);
       } else {
         local_shape.push_back(shape.GetData()[i]);
       }
diff --git a/test/auto_parallel/semi_auto_parallel_for_reshape.py b/test/auto_parallel/semi_auto_parallel_for_reshape.py
index 6fac20bbb8afd..ac194353655b7 100644
--- a/test/auto_parallel/semi_auto_parallel_for_reshape.py
+++ b/test/auto_parallel/semi_auto_parallel_for_reshape.py
@@ -47,6 +47,14 @@ def test_reshape_forward(self):
         self.check_placements(output, [dist.Shard(0)])
         self.check_placements(input.grad, [dist.Shard(0)])
 
+    def test_reshape_infer_shape(self):
+        mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
+        x = paddle.ones([10, 20, 30])
+        x = dist.shard_tensor(x, mesh, [Shard(0)])
+        y = x.reshape([-1, 0, x.shape[0]])
+        assert y.shape == [30, 20, 10]
+        assert y._local_shape == [15, 20, 10]
+
     def run_test_case(self):
         if self._backend == "cpu":
             paddle.set_device("cpu")
@@ -55,6 +63,7 @@ def run_test_case(self):
         else:
             raise ValueError("Only support cpu or gpu backend.")
         self.test_reshape_forward()
+        self.test_reshape_infer_shape()
 
 
 if __name__ == '__main__':
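
A minimal Python sketch (illustrative only, not part of this diff) of the local-shape inference that the generated C++ template above implements: resolve 0 (keep the corresponding input dim) and a single -1 (inferred dim) in the target shape, then divide every mesh-mapped dim by its mesh size. The helper name infer_local_shape and its argument layout are assumptions for illustration, not Paddle APIs.

import math

def infer_local_shape(global_shape, target_shape, dims_mapping, mesh_shape):
    """Compute the per-rank (local) shape for a reshape on a sharded tensor."""
    numel = math.prod(global_shape)
    # 0 means "keep the corresponding input dim".
    resolved = [global_shape[i] if s == 0 else s for i, s in enumerate(target_shape)]
    # At most one -1 is allowed; it absorbs whatever is left of numel.
    if resolved.count(-1) > 1:
        raise ValueError("Reshape can only have one -1 in the shape.")
    if -1 in resolved:
        known = math.prod(s for s in resolved if s != -1)
        if numel % known != 0:
            raise ValueError("Cannot infer real shape for -1.")
        resolved[resolved.index(-1)] = numel // known
    # Dims mapped onto a mesh axis are split evenly across it; others stay global.
    local_shape = []
    for i, s in enumerate(resolved):
        mesh_dim = dims_mapping[i]
        if mesh_dim >= 0:
            if s % mesh_shape[mesh_dim] != 0:
                raise ValueError("Local shape dim must be divisible by the mesh dim.")
            local_shape.append(s // mesh_shape[mesh_dim])
        else:
            local_shape.append(s)
    return local_shape

# Mirrors the new test case: [10, 20, 30] sharded on dim 0 of a 2-way mesh,
# reshaped to [-1, 0, 10] -> global [30, 20, 10], local [15, 20, 10].
print(infer_local_shape([10, 20, 30], [-1, 0, 10], [0, -1, -1], [2]))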