Skip to content

Commit

Permalink
[AutoParallel & Science] Miscellaneous improvements (#43139)
Browse files Browse the repository at this point in the history
* adapt for 10 loss

* partitioner support optimizer
  • Loading branch information
JZ-LIANG authored Jun 1, 2022
1 parent ff1789c commit f59bcb1
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -363,14 +363,16 @@ def forward(ctx, *args, **kwargs):
output_name)

# replicate op in dist program
dist_op_desc = main_block.append_op(type='nop').desc
dist_op_desc = main_block.desc.append_op()
dist_op_desc.copy_from(src_op.desc)
set_dist_op_desc_original_id(dist_op_desc, src_op.desc, ctx)
for input_name in src_op.desc.input_names():
dist_op_desc.set_input(input_name, kwargs[input_name])
for output_name in src_op.desc.output_names():
dist_op_desc.set_output(output_name, kwargs[output_name])

main_block._sync_with_cpp()

# data parallel synchronization for primtive operators
from paddle.incubate.autograd import prim_enabled
if prim_enabled():
Expand Down Expand Up @@ -426,6 +428,8 @@ def forward(ctx, *args, **kwargs):
op_attr.set_input_dims_mapping(param.name, dims_mapping)
ctx.set_op_dist_attr_for_program(new_op, op_attr)

startup_block._sync_with_cpp()

@staticmethod
def backward(ctx, *args, **kwargs):

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,13 +107,14 @@ def forward(ctx, *args, **kwargs):
output_name)

# replicate op in dist program
dist_op_desc = main_block.append_op(type='nop').desc
dist_op_desc = main_block.desc.append_op()
dist_op_desc.copy_from(src_op.desc)
set_dist_op_desc_original_id(dist_op_desc, src_op.desc, ctx)
for input_name in src_op.desc.input_names():
dist_op_desc.set_input(input_name, kwargs[input_name])
for output_name in src_op.desc.output_names():
dist_op_desc.set_output(output_name, kwargs[output_name])
main_block._sync_with_cpp()

# batch dimension synchronization
var_name = src_op.output_arg_names[0]
Expand Down
6 changes: 3 additions & 3 deletions python/paddle/distributed/auto_parallel/partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from .dist_attribute import OperatorDistributedAttribute
from .process_group import new_process_group
from .utils import set_dist_op_desc_original_id
from .utils import print_program_with_dist_attr, is_forward_op, is_backward_op, is_loss_op
from .utils import print_program_with_dist_attr, is_forward_op, is_backward_op, is_loss_op, is_optimize_op
from .operators.common import BACKWARD_ONLY_DIST_OPS

__varname_not_in_block__ = ["lod_tensor_blocking_queue_0"]
Expand Down Expand Up @@ -263,14 +263,14 @@ def partition_block(self, ref_block, target_block):
dist_op_backward_impl.backward(
self._dist_context, **kinputs, **koutputs,
**{"grad_var_to_var": grad_var_to_var})
elif int(op.attr('op_role')) == 2:
elif is_optimize_op(op):
kinputs, koutputs = dist_op_context.prepare_context(op)
dist_op_impl = get_distributed_operator_impl_container(
"default").get_impl(0)
dist_op_impl.backward(self._dist_context, **kinputs, **koutputs)
else:
raise NotImplementedError(
"partitioner only support forward op and backward op, but got {}".
"partitioner only support forward and backward, optimize ops, but got {}".
format(str(op)))

def _is_valid_annotated_program(self, program):
Expand Down
5 changes: 5 additions & 0 deletions python/paddle/distributed/auto_parallel/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1099,6 +1099,11 @@ def is_backward_op(op):
int(op.all_attrs()[OP_ROLE_KEY]) & int(OpRole.Backward)


def is_optimize_op(op):
    """Tell whether *op* carries the Optimize role bit.

    Returns False when the op has no role attribute at all; otherwise
    returns the (truthy/falsy) result of masking the role value with
    ``OpRole.Optimize`` — callers use this in boolean context only.
    """
    # Guard clause: ops without a role attribute are never optimize ops.
    if OP_ROLE_KEY not in op.attr_names:
        return False
    role_value = int(op.all_attrs()[OP_ROLE_KEY])
    return role_value & int(OpRole.Optimize)


def is_loss_op(op):
    """Tell whether *op* is the loss op.

    The loss op is identified by its role value being exactly the
    combination of the Forward and Loss role bits (not merely having the
    Loss bit set).
    """
    # Guard clause: ops without a role attribute cannot be the loss op.
    if OP_ROLE_KEY not in op.attr_names:
        return False
    # NOTE: exact equality is intentional — see docstring.
    loss_role = int(core.op_proto_and_checker_maker.OpRole.Forward) | int(
        core.op_proto_and_checker_maker.OpRole.Loss)
    return int(op.all_attrs()[OP_ROLE_KEY]) == loss_role
Expand Down

1 comment on commit f59bcb1

@paddle-bot-old
Copy link

@paddle-bot-old paddle-bot-old bot commented on f59bcb1 Jun 1, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🕵️ CI failures summary

🔍 PR: #245 Commit ID: f59bcb1 contains failed CI.

🔹 Failed: PR-CI-ROCM-Compile

Unknown Failed
Unknown Failed

Please sign in to comment.