@@ -717,6 +717,7 @@ def _dygraph_clip(self, params_grads):
         sum_square_list = []
         sum_square_list_fp16 = []
         sum_square_list_fp32 = []
+        flag_auto_hybrid_pp = True  # Determine whether the new dygraph semi-auto parallel PP framework is used
         if len(params_grads) > 0 and len(params_grads[0]) > 0:
             src_mesh = params_grads[0][0].process_mesh
         else:
@@ -742,8 +743,10 @@ def _dygraph_clip(self, params_grads):
             # if the gradient mesh is not equal to src mesh
             # do reshard to get the result of squared_l2 from other pp stage mesh
             if src_mesh is not None and g.process_mesh != src_mesh:
+                flag_auto_hybrid_pp = False
                 pp_mesh = get_complete_pp_mesh(g.process_mesh)
                 if set(g.process_mesh.process_ids) < set(pp_mesh.process_ids):
+                    flag_auto_hybrid_pp = True
                     sum_square = dist.reshard(
                         sum_square, pp_mesh, sum_square.placements
                     )
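The strict-subset test above decides whether a per-stage `sum_square` has to be resharded onto the complete pipeline mesh before the global norm is assembled, and whether `flag_auto_hybrid_pp` stays `True`. A minimal, plain-Python sketch of that decision; the helper name and process ids below are illustrative, not part of the patch:

```python
def needs_pp_reshard(grad_process_ids, complete_pp_process_ids):
    # True only when the gradient lives on a strict subset of the complete
    # pp mesh, i.e. on a single pipeline stage, so its squared L2 norm must
    # be resharded onto the full pp mesh before the global reduction.
    return set(grad_process_ids) < set(complete_pp_process_ids)

# A 2-stage pipeline: stage 0 owns ranks {0, 1}, the complete pp mesh owns {0, 1, 2, 3}.
print(needs_pp_reshard([0, 1], [0, 1, 2, 3]))        # True  -> reshard to pp_mesh
print(needs_pp_reshard([0, 1, 2, 3], [0, 1, 2, 3]))  # False -> already on the full mesh
```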
@@ -798,7 +801,7 @@ def async_add_n(var_list):
         # then performs pp group communication reduce(sum) to get correct global_norm_var.
         # For complete alignment with old dygraph semi-auto parallel PP logic,
         # refer to NOTE: align ClipGradByGlobalNorm in auto_parallel_align_mode
-        if src_mesh is not None:
+        if flag_auto_hybrid_pp and src_mesh is not None:
             g_mesh = dist.get_mesh()
             if (
                 g_mesh
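The comment block above describes the pp-group `reduce(sum)` that this change now gates on `flag_auto_hybrid_pp`. A small numeric sketch of what that reduction achieves; the per-stage values are made up for illustration:

```python
import math

# Each pp stage only sees its own parameters, so it holds a partial sum of
# squared gradients; reduce(sum) over the pp group adds those partials so
# that every stage ends up with the same global_norm_var.
stage_partial_sum_square = {0: 4.0, 1: 9.0}           # per-stage sum of g^2
global_sum = sum(stage_partial_sum_square.values())   # result of reduce(sum)
global_norm_var = math.sqrt(global_sum)
print(global_norm_var)  # sqrt(13.0) on every pipeline stage after the reduce
```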
@@ -882,6 +885,15 @@ def async_add_n(var_list):
882885 "Reshard a sharded tensor from a local mesh to a global mesh is not supported"
883886 )
884887 else :
888+ pp_mesh = get_complete_pp_mesh (g .process_mesh )
889+
890+ if set (g .process_mesh .process_ids ) < set (
891+ pp_mesh .process_ids
892+ ):
893+ clip_input = dist .reshard (
894+ clip_input , pp_mesh , clip_input .placements
895+ )
896+
885897 clip_input = paddle .distributed .reshard (
886898 clip_input , g .process_mesh , clip_input .placements
887899 )
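The new branch above mirrors the earlier hunk: `clip_input` is first resharded onto the complete pp mesh and only then onto the gradient's own mesh. A rough sketch of that two-step reshard pattern, not the patched code; the mesh shapes, rank ids, and script name are assumptions, and it would need to be launched on multiple devices, e.g. with `python -m paddle.distributed.launch --devices 0,1 reshard_sketch.py`:

```python
import paddle
import paddle.distributed as dist

stage0_mesh = dist.ProcessMesh([0], dim_names=["pp"])      # mesh of pp stage 0 only
full_pp_mesh = dist.ProcessMesh([0, 1], dim_names=["pp"])  # complete pipeline mesh

# clip_input starts out replicated on stage 0's mesh.
clip_input = dist.shard_tensor(paddle.to_tensor(0.5), stage0_mesh, [dist.Replicate()])

# Step 1: broadcast it onto the complete pp mesh with unchanged placements,
# mirroring `dist.reshard(clip_input, pp_mesh, clip_input.placements)`.
clip_input = dist.reshard(clip_input, full_pp_mesh, clip_input.placements)

# Step 2: reshard from the full pp mesh onto a gradient's own stage mesh,
# mirroring `paddle.distributed.reshard(clip_input, g.process_mesh, ...)`.
stage1_mesh = dist.ProcessMesh([1], dim_names=["pp"])
clip_input = dist.reshard(clip_input, stage1_mesh, clip_input.placements)
```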