Commit 87d24ed

allreduce when grad current_mesh != pp_mesh
1 parent def4f58 commit 87d24ed

File tree

1 file changed: +13 −1 lines changed


python/paddle/nn/clip.py

Lines changed: 13 additions & 1 deletion
@@ -717,6 +717,7 @@ def _dygraph_clip(self, params_grads):
         sum_square_list = []
         sum_square_list_fp16 = []
         sum_square_list_fp32 = []
+        flag_auto_hybrid_pp = True  # Determine whether to use the new dynamic graph semi-auto parallel PP framework
         if len(params_grads) > 0 and len(params_grads[0]) > 0:
             src_mesh = params_grads[0][0].process_mesh
         else:
@@ -742,8 +743,10 @@ def _dygraph_clip(self, params_grads):
             # if the gradient mesh is not equal to src mesh
             # do reshard to get the result of squared_l2 from other pp stage mesh
             if src_mesh is not None and g.process_mesh != src_mesh:
+                flag_auto_hybrid_pp = False
                 pp_mesh = get_complete_pp_mesh(g.process_mesh)
                 if set(g.process_mesh.process_ids) < set(pp_mesh.process_ids):
+                    flag_auto_hybrid_pp = True
                     sum_square = dist.reshard(
                         sum_square, pp_mesh, sum_square.placements
                     )
@@ -798,7 +801,7 @@ def async_add_n(var_list):
         # then performs pp group communication reduce(sum) to get correct global_norm_var.
         # For complete alignment with old dygraph semi-auto parallel PP logic,
         # refer to NOTE: align ClipGradByGlobalNorm in auto_parallel_align_mode
-        if src_mesh is not None:
+        if flag_auto_hybrid_pp and src_mesh is not None:
             g_mesh = dist.get_mesh()
             if (
                 g_mesh
@@ -882,6 +885,15 @@ def async_add_n(var_list):
                     "Reshard a sharded tensor from a local mesh to a global mesh is not supported"
                 )
             else:
+                pp_mesh = get_complete_pp_mesh(g.process_mesh)
+
+                if set(g.process_mesh.process_ids) < set(
+                    pp_mesh.process_ids
+                ):
+                    clip_input = dist.reshard(
+                        clip_input, pp_mesh, clip_input.placements
+                    )
+
                 clip_input = paddle.distributed.reshard(
                     clip_input, g.process_mesh, clip_input.placements
                 )
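
For context, the hunks above hinge on one check: a gradient whose process mesh covers only part of the complete PP mesh must have its partial squared-l2 result (and later its clip input) resharded onto the full PP mesh before it is combined, and flag_auto_hybrid_pp records whether that new path was taken. Below is a minimal, standalone sketch of that proper-subset test, using plain Python lists in place of ProcessMesh.process_ids; the name grad_needs_pp_reshard and the example rank ids are illustrative and not part of the patch.

# Illustrative sketch (not part of the commit): the proper-subset test that
# decides whether a gradient's partial result must be moved onto the
# complete PP mesh. Plain Python lists stand in for ProcessMesh.process_ids;
# no Paddle APIs are called here.

def grad_needs_pp_reshard(grad_mesh_ids, pp_mesh_ids):
    # True when the gradient lives only on part of the complete PP mesh,
    # i.e. its process ids form a proper subset of the PP mesh's ids.
    return set(grad_mesh_ids) < set(pp_mesh_ids)

pp_mesh_ids = [0, 1, 2, 3]  # hypothetical complete PP mesh (4 ranks)
print(grad_needs_pp_reshard([0, 1], pp_mesh_ids))       # True  -> reshard/allreduce onto pp_mesh
print(grad_needs_pp_reshard(pp_mesh_ids, pp_mesh_ids))  # False -> already on the complete mesh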
