Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[hybrid check] improve pipeline stage check #34193

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 12 additions & 7 deletions python/paddle/fluid/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4663,6 +4663,7 @@ def _check_validation(self, block):
pre_stage_id = None
decrease_flag = False
in_optimize = False
in_forward = True
for op in block.ops:
if not op._has_kernel(op.type):
assert op.type == "conditional_block" and (
Expand All @@ -4680,6 +4681,8 @@ def _check_validation(self, block):
valid_op_role_value)
if int(op_role) == int(self._op_role.Optimize):
in_optimize = True
if int(op_role) == int(self._op_role.Backward):
in_forward = False

assert op.has_attr(self._op_device_key), (
"op ({}) has no {} attribute.".format(op.type,
Expand Down Expand Up @@ -4707,14 +4710,16 @@ def _check_validation(self, block):
"but the interval of op={} and prev op is {}".format(op, interval)
# Stages must appear in order, e.g. Forward(0 1 2 3 4), Backward(4 3 2 1 0).
# An unordered sequence such as Forward(0 1 2 3 4 3 4) is reported as an error.
if interval == -1:
decrease_flag = True
if interval == 1:
# FIXME(wangxi): recompute failed
if in_forward:
assert interval >= 0, \
"Pipeline stage must be sequential increment in Forward, prev_stage={}, " \
"please check the stage of op={}".format(pre_stage_id, op)
else:
# FIXME(wangxi): recompute check failed
pass
#assert decrease_flag is False, \
# "Pipeline stage must be in order, " \
# "please check the stage of op={}".format(op)
#assert interval <=0, \
# "Pipeline stage must be sequential decrement in Backward, prev_stage={}, " \
# "please check the stage of op={}".format(pre_stage_id, op)
pre_stage_id = stage_id

return device_list
Expand Down