Skip to content

Commit def4f58

Browse files
committed
adapt grad clip for moe layer
1 parent bfbb12d commit def4f58

File tree

1 file changed

+1
-12
lines changed

1 file changed

+1
-12
lines changed

python/paddle/nn/clip.py

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -717,7 +717,6 @@ def _dygraph_clip(self, params_grads):
717717
sum_square_list = []
718718
sum_square_list_fp16 = []
719719
sum_square_list_fp32 = []
720-
flag_auto_hybrid_pp = True # Determine whether to use the new dynamic graph semi-automatic parallel pp framework
721720
if len(params_grads) > 0 and len(params_grads[0]) > 0:
722721
src_mesh = params_grads[0][0].process_mesh
723722
else:
@@ -743,7 +742,6 @@ def _dygraph_clip(self, params_grads):
743742
# if the gradient mesh is not equal to src mesh
744743
# do reshard to get the result of squared_l2 from other pp stage mesh
745744
if src_mesh is not None and g.process_mesh != src_mesh:
746-
flag_auto_hybrid_pp = False
747745
pp_mesh = get_complete_pp_mesh(g.process_mesh)
748746
if set(g.process_mesh.process_ids) < set(pp_mesh.process_ids):
749747
sum_square = dist.reshard(
@@ -800,7 +798,7 @@ def async_add_n(var_list):
800798
# then performs pp group communication reduce(sum) to get correct global_norm_var.
801799
# For complete alignment with old dygraph semi-auto parallel PP logic,
802800
# refer to NOTE: align ClipGradByGlobalNorm in auto_parallel_align_mode
803-
if flag_auto_hybrid_pp and src_mesh is not None:
801+
if src_mesh is not None:
804802
g_mesh = dist.get_mesh()
805803
if (
806804
g_mesh
@@ -884,15 +882,6 @@ def async_add_n(var_list):
884882
"Reshard a sharded tensor from a local mesh to a global mesh is not supported"
885883
)
886884
else:
887-
pp_mesh = get_complete_pp_mesh(g.process_mesh)
888-
889-
if set(g.process_mesh.process_ids) < set(
890-
pp_mesh.process_ids
891-
):
892-
clip_input = dist.reshard(
893-
clip_input, pp_mesh, clip_input.placements
894-
)
895-
896885
clip_input = paddle.distributed.reshard(
897886
clip_input, g.process_mesh, clip_input.placements
898887
)

0 commit comments

Comments
 (0)