[auto parallel] fix reshape infer shape (#60632)
FeixLiu authored Jan 10, 2024
1 parent a159cd1 commit 3324c9d
Showing 2 changed files with 42 additions and 7 deletions.
40 changes: 33 additions & 7 deletions paddle/phi/api/yaml/generator/dist_api_gen.py
@@ -454,21 +454,47 @@
# TODO(GhostScreaming): Support aliquant condition.
# Specialized Code, for example, reshape needs to calculate local_shape
RESHAPE_CALCULATE_LOCAL_SHAPE_TEMPLATE = """
    // The dist_input_x is a dist tensor, the dims() func return the global dims.
    auto x_shape = dist_input_x->dims();
    auto x_numel = dist_input_x->numel();
    bool visit_negative = false;
    std::vector<int64_t> local_shape;
    for (size_t i = 0; i < shape.GetData().size(); i++) {
      auto& out_dist_attr = PADDLE_GET_CONST(phi::distributed::TensorDistAttr, spmd_info.second[0]);
      if (out_dist_attr.dims_mapping()[i] >= 0) {
        int64_t shape_i = shape.GetData()[i];
        if (shape_i == 0) {
          shape_i = x_shape[i];
        } else if (shape_i == -1) {
          PADDLE_ENFORCE(not visit_negative,
                         phi::errors::InvalidArgument(
                             "Reshape can only have one -1 in the shape."));
          visit_negative = true;
          int64_t non_negative_product = 1;
          for (size_t j = 0; j < shape.GetData().size(); j++) {
            if (i == j) {
              continue;
            }
            int64_t tmp_j = shape.GetData()[j];
            if (tmp_j == 0) {
              tmp_j = x_shape[j];
            }
            non_negative_product *= tmp_j;
          }
          PADDLE_ENFORCE(x_numel % non_negative_product == 0,
                         phi::errors::InvalidArgument("Cannot infer real shape for -1."));
          shape_i = x_numel / non_negative_product;
        }
        int64_t dim = out_dist_attr.dims_mapping()[i];
        int64_t mesh_dim = out_dist_attr.process_mesh().shape()[dim];
        // TODO: Support aliquant condition.
        PADDLE_ENFORCE(shape_i % mesh_dim == 0,
                       phi::errors::InvalidArgument(
                           "Reshape only support local shape dim is divisible "
                           "by the mesh dim, however local_shape[%lld] is %lld "
                           "and shard mesh dims is %lld.", i, shape_i, mesh_dim));
        local_shape.push_back(shape_i / mesh_dim);
      } else {
        local_shape.push_back(shape.GetData()[i]);
      }
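For readers outside the code generator, the rule this template encodes can be sketched in plain Python. The helper below is illustrative only, not part of Paddle's API (the name infer_global_and_local_shape and its parameters are assumptions for this sketch); it mirrors the generated C++: a 0 entry copies the input dim, at most one -1 absorbs the remaining elements, and each sharded output dim must divide evenly by its mesh dim.

    # Illustrative sketch only; mirrors the rule in the generated C++ above.
    from math import prod

    def infer_global_and_local_shape(x_shape, target, dims_mapping, mesh_shape):
        """Resolve 0 / -1 entries in target, then shard each mapped dim."""
        numel = prod(x_shape)
        # A 0 entry copies the corresponding input dim.
        resolved = [x_shape[i] if s == 0 else s for i, s in enumerate(target)]
        negatives = [i for i, s in enumerate(resolved) if s == -1]
        assert len(negatives) <= 1, "Reshape can only have one -1 in the shape."
        if negatives:
            i = negatives[0]
            rest = prod(s for j, s in enumerate(resolved) if j != i)
            assert numel % rest == 0, "Cannot infer real shape for -1."
            resolved[i] = numel // rest
        local = []
        for i, s in enumerate(resolved):
            dim = dims_mapping[i]  # >= 0 means this output dim is sharded
            if dim >= 0:
                mesh_dim = mesh_shape[dim]
                # Aliquant (non-divisible) shards are not supported upstream yet.
                assert s % mesh_dim == 0, "dim must be divisible by the mesh dim"
                local.append(s // mesh_dim)
            else:
                local.append(s)
        return resolved, local

    # Matches the new test below: global [30, 20, 10], local [15, 20, 10].
    print(infer_global_and_local_shape([10, 20, 30], [-1, 0, 10], [0, -1, -1], [2]))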
9 changes: 9 additions & 0 deletions test/auto_parallel/semi_auto_parallel_for_reshape.py
@@ -47,6 +47,14 @@ def test_reshape_forward(self):
        self.check_placements(output, [dist.Shard(0)])
        self.check_placements(input.grad, [dist.Shard(0)])

    def test_reshape_infer_shape(self):
        mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
        x = paddle.ones([10, 20, 30])
        x = dist.shard_tensor(x, mesh, [Shard(0)])
        y = x.reshape([-1, 0, x.shape[0]])
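        # numel is 10 * 20 * 30 = 6000; the 0 keeps the input's dim 1 (20) and
        # x.shape[0] is an explicit 10, so -1 resolves to 6000 / (20 * 10) = 30.
        # Shard(0) on the 2-process mesh then halves global dim 0: 30 / 2 = 15.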
        assert y.shape == [30, 20, 10]
        assert y._local_shape == [15, 20, 10]

    def run_test_case(self):
        if self._backend == "cpu":
            paddle.set_device("cpu")
@@ -55,6 +63,7 @@ def run_test_case(self):
        else:
            raise ValueError("Only support cpu or gpu backend.")
        self.test_reshape_forward()
        self.test_reshape_infer_shape()


if __name__ == '__main__':
