From c40310d392b52fea4dbdb9f5c3b60efc8b4ac77a Mon Sep 17 00:00:00 2001
From: zhaoyingli
Date: Tue, 26 Jul 2022 19:56:54 +0800
Subject: [PATCH] retain dist op returns

---
 .../paddle/distributed/auto_parallel/dist_op.py   |  4 +---
 .../tests/unittests/auto_parallel/engine_api.py   |  4 ++--
 .../tests/unittests/auto_parallel_gpt_model.py    | 16 ++++++++--------
 .../unittests/test_auto_parallel_reshard_mppp.py  |  8 +-------
 4 files changed, 12 insertions(+), 20 deletions(-)

diff --git a/python/paddle/distributed/auto_parallel/dist_op.py b/python/paddle/distributed/auto_parallel/dist_op.py
index d48804b71fc3e..b6a77b778885f 100644
--- a/python/paddle/distributed/auto_parallel/dist_op.py
+++ b/python/paddle/distributed/auto_parallel/dist_op.py
@@ -267,6 +267,4 @@ def __call__(self, *args, **kwargs):
         dist_op = DistributedOperator(op, self._dist_attr)
         dist_op.dist_attr.mark_annotated_as(self._dist_attr)
         default_dist_ctx.add_dist_op_for_program(dist_op)
-        if isinstance(output, Variable):
-            output = [output]
-        return list(output)
+        return output
diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/engine_api.py b/python/paddle/fluid/tests/unittests/auto_parallel/engine_api.py
index ec757c03478de..9335d7d9d2e03 100644
--- a/python/paddle/fluid/tests/unittests/auto_parallel/engine_api.py
+++ b/python/paddle/fluid/tests/unittests/auto_parallel/engine_api.py
@@ -89,11 +89,11 @@ def __init__(self,
 
     def forward(self, input):
         out = auto.shard_op(self.norm, dist_attr={"process_mesh":
-                                                  PP_MESH_0})(input)[0]
+                                                  PP_MESH_0})(input)
         out = self.linear0(out)
         out = F.gelu(out, approximate=True)
         out = auto.shard_op(self.linear1, dist_attr={"process_mesh":
-                                                     PP_MESH_1})(out)[0]
+                                                     PP_MESH_1})(out)
         out = self.dropout(out)
         out = self.linear2(out)
         self.out = out
diff --git a/python/paddle/fluid/tests/unittests/auto_parallel_gpt_model.py b/python/paddle/fluid/tests/unittests/auto_parallel_gpt_model.py
index 4695f6a4a9425..87c746ab5d3b5 100644
--- a/python/paddle/fluid/tests/unittests/auto_parallel_gpt_model.py
+++ b/python/paddle/fluid/tests/unittests/auto_parallel_gpt_model.py
@@ -391,7 +391,7 @@ def forward(self,
                             mod,
                             dist_attr={
                                 "process_mesh": PP_MESH_LIST[mod.mesh_idx]
-                            })(output, memory, tgt_mask, use_cache, cache)[0]
+                            })(output, memory, tgt_mask, use_cache, cache)
                         auto.shard_tensor(
                             output,
                             dist_attr={
@@ -405,7 +405,7 @@ def forward(self,
                             mod,
                             dist_attr={
                                 "process_mesh": DPPP_MESH_LIST[mod.mesh_idx]
-                            })(output, memory, tgt_mask, use_cache, cache)[0]
+                            })(output, memory, tgt_mask, use_cache, cache)
                         auto.shard_tensor(
                             output,
                             dist_attr={
@@ -419,7 +419,7 @@ def forward(self,
                             mod,
                             dist_attr={
                                 "process_mesh": MPPP_MESH_LIST[mod.mesh_idx]
-                            })(output, memory, tgt_mask, use_cache, cache)[0]
+                            })(output, memory, tgt_mask, use_cache, cache)
                         auto.shard_tensor(
                             output,
                             dist_attr={
@@ -433,7 +433,7 @@ def forward(self,
                             mod,
                             dist_attr={
                                 "process_mesh": DPMPPP_MESH_LIST[mod.mesh_idx]
-                            })(output, memory, tgt_mask, use_cache, cache)[0]
+                            })(output, memory, tgt_mask, use_cache, cache)
                         auto.shard_tensor(
                             output,
                             dist_attr={
@@ -456,7 +456,7 @@ def forward(self,
                                     "process_mesh":
                                     PP_MESH_LIST[mod.mesh_idx]
                                 })(output, memory, tgt_mask,
-                                   use_cache, cache)[0]
+                                   use_cache, cache)
                             auto.shard_tensor(
                                 output,
                                 dist_attr={
@@ -471,7 +471,7 @@ def forward(self,
                                     "process_mesh":
                                     DPPP_MESH_LIST[mod.mesh_idx]
                                 })(output, memory, tgt_mask,
-                                   use_cache, cache)[0]
+                                   use_cache, cache)
                             auto.shard_tensor(
                                 output,
                                 dist_attr={
@@ -486,7 +486,7 @@ def forward(self,
                                     "process_mesh":
                                     MPPP_MESH_LIST[mod.mesh_idx]
                                 })(output, memory, tgt_mask,
-                                   use_cache, cache)[0]
+                                   use_cache, cache)
                             auto.shard_tensor(
                                 output,
                                 dist_attr={
@@ -500,7 +500,7 @@ def forward(self,
                             mod,
                             dist_attr={
                                 "process_mesh": DPMPPP_MESH_LIST[mod.mesh_idx]
-                            })(output, memory, tgt_mask, use_cache, cache)[0]
+                            })(output, memory, tgt_mask, use_cache, cache)
                         auto.shard_tensor(
                             output,
                             dist_attr={
diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_mppp.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_mppp.py
index 0e647a3db5b64..dfb314796a9ff 100644
--- a/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_mppp.py
+++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_mppp.py
@@ -255,12 +255,6 @@ def test_allgather(self):
                               "dims_mapping": [-1, -1]
                           })
 
-        # y = paddle.distributed.shard_op(paddle.matmul, process_mesh, {
-        #     x.name: [-1, -1],
-        #     w.name: [-1, -1]
-        # }, **{"x": x,
-        #       "y": w})[0]
-
         y = paddle.distributed.shard_op(paddle.matmul,
                                         dist_attr={
                                             "process_mesh": process_mesh,
@@ -270,7 +264,7 @@ def test_allgather(self):
                                             w: {
                                                 "dims_mapping": [-1, -1]
                                             }
-                                        })(x, w)[0]
+                                        })(x, w)
 
         rank_id = 0
         dist_context = DistributedContext()