From f59bcb1c781038b871154118f31658c0fff8b16a Mon Sep 17 00:00:00 2001 From: JZ-LIANG Date: Wed, 1 Jun 2022 15:43:47 +0800 Subject: [PATCH] [AutoParallel & Science] Miscellaneous improvements (#43139) * adapt for 10 loss * partitioner support optimizer --- .../distributed/auto_parallel/operators/dist_default.py | 6 +++++- .../distributed/auto_parallel/operators/dist_reduce_p.py | 3 ++- python/paddle/distributed/auto_parallel/partitioner.py | 6 +++--- python/paddle/distributed/auto_parallel/utils.py | 5 +++++ 4 files changed, 15 insertions(+), 5 deletions(-) diff --git a/python/paddle/distributed/auto_parallel/operators/dist_default.py b/python/paddle/distributed/auto_parallel/operators/dist_default.py index 78f30422e742f..e18cee6d42dca 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_default.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_default.py @@ -363,7 +363,7 @@ def forward(ctx, *args, **kwargs): output_name) # replicate op in dist program - dist_op_desc = main_block.append_op(type='nop').desc + dist_op_desc = main_block.desc.append_op() dist_op_desc.copy_from(src_op.desc) set_dist_op_desc_original_id(dist_op_desc, src_op.desc, ctx) for input_name in src_op.desc.input_names(): @@ -371,6 +371,8 @@ def forward(ctx, *args, **kwargs): for output_name in src_op.desc.output_names(): dist_op_desc.set_output(output_name, kwargs[output_name]) + main_block._sync_with_cpp() + # data parallel synchronization for primtive operators from paddle.incubate.autograd import prim_enabled if prim_enabled(): @@ -426,6 +428,8 @@ def forward(ctx, *args, **kwargs): op_attr.set_input_dims_mapping(param.name, dims_mapping) ctx.set_op_dist_attr_for_program(new_op, op_attr) + startup_block._sync_with_cpp() + @staticmethod def backward(ctx, *args, **kwargs): diff --git a/python/paddle/distributed/auto_parallel/operators/dist_reduce_p.py b/python/paddle/distributed/auto_parallel/operators/dist_reduce_p.py index 755dcab4be34f..3275bddd9b4cc 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_reduce_p.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_reduce_p.py @@ -107,13 +107,14 @@ def forward(ctx, *args, **kwargs): output_name) # replicate op in dist program - dist_op_desc = main_block.append_op(type='nop').desc + dist_op_desc = main_block.desc.append_op() dist_op_desc.copy_from(src_op.desc) set_dist_op_desc_original_id(dist_op_desc, src_op.desc, ctx) for input_name in src_op.desc.input_names(): dist_op_desc.set_input(input_name, kwargs[input_name]) for output_name in src_op.desc.output_names(): dist_op_desc.set_output(output_name, kwargs[output_name]) + main_block._sync_with_cpp() # batch dimension synchronization var_name = src_op.output_arg_names[0] diff --git a/python/paddle/distributed/auto_parallel/partitioner.py b/python/paddle/distributed/auto_parallel/partitioner.py index ce686fd6a5683..6a767e5afcdf6 100644 --- a/python/paddle/distributed/auto_parallel/partitioner.py +++ b/python/paddle/distributed/auto_parallel/partitioner.py @@ -25,7 +25,7 @@ from .dist_attribute import OperatorDistributedAttribute from .process_group import new_process_group from .utils import set_dist_op_desc_original_id -from .utils import print_program_with_dist_attr, is_forward_op, is_backward_op, is_loss_op +from .utils import print_program_with_dist_attr, is_forward_op, is_backward_op, is_loss_op, is_optimize_op from .operators.common import BACKWARD_ONLY_DIST_OPS __varname_not_in_block__ = ["lod_tensor_blocking_queue_0"] @@ -263,14 +263,14 @@ def partition_block(self, ref_block, target_block): dist_op_backward_impl.backward( self._dist_context, **kinputs, **koutputs, **{"grad_var_to_var": grad_var_to_var}) - elif int(op.attr('op_role')) == 2: + elif is_optimize_op(op): kinputs, koutputs = dist_op_context.prepare_context(op) dist_op_impl = get_distributed_operator_impl_container( "default").get_impl(0) dist_op_impl.backward(self._dist_context, **kinputs, **koutputs) else: raise NotImplementedError( - "partitioner only support forward op and backward op, but got {}". + "partitioner only support forward and backward, optimize ops, but got {}". format(str(op))) def _is_valid_annotated_program(self, program): diff --git a/python/paddle/distributed/auto_parallel/utils.py b/python/paddle/distributed/auto_parallel/utils.py index 42d90b0d4d619..7b198e288c636 100644 --- a/python/paddle/distributed/auto_parallel/utils.py +++ b/python/paddle/distributed/auto_parallel/utils.py @@ -1099,6 +1099,11 @@ def is_backward_op(op): int(op.all_attrs()[OP_ROLE_KEY]) & int(OpRole.Backward) +def is_optimize_op(op): + return OP_ROLE_KEY in op.attr_names and \ + int(op.all_attrs()[OP_ROLE_KEY]) & int(OpRole.Optimize) + + def is_loss_op(op): return OP_ROLE_KEY in op.attr_names and \ int(op.all_attrs()[OP_ROLE_KEY]) == (int(core.op_proto_and_checker_maker.OpRole.Forward) | int(core.op_proto_and_checker_maker.OpRole.Loss))