diff --git a/python/paddle/distributed/passes/auto_parallel_gradient_merge.py b/python/paddle/distributed/passes/auto_parallel_gradient_merge.py index c793639c5ba01..719f733305cac 100644 --- a/python/paddle/distributed/passes/auto_parallel_gradient_merge.py +++ b/python/paddle/distributed/passes/auto_parallel_gradient_merge.py @@ -232,6 +232,7 @@ def _append_gradient_merge_backward_op( attrs={ 'axis': -1, OP_ROLE_KEY: OpRole.Backward, + "op_namescope": "/auto_parallel/gradient_merge", }, ) @@ -375,8 +376,16 @@ def true_apply_gradient(): ) paddle.static.nn.cond(cond_var, true_fn=true_apply_gradient, false_fn=None) + cond_dist_attr = dist_context.get_tensor_dist_attr_for_program(cond_var) cond_op = main_program.global_block().ops[-1] cond_op._set_attr(OP_ROLE_KEY, OpRole.Optimize) + naive_set_dist_op_attr_for_program_by_mesh_and_mapping( + cond_op, + process_mesh=cond_dist_attr.process_mesh, + ref_mapping=cond_dist_attr.dims_mapping, + ctx=dist_context, + chunk_id=cond_dist_attr.chunk_id, + ) def parse_program(