Skip to content

Commit

Permalink
[hybrid][npu] fix npu clear float status in pipeline (PaddlePaddle#35165
Browse files Browse the repository at this point in the history
) (PaddlePaddle#35295)

Co-authored-by: WangXi <wangxi16@baidu.com>
  • Loading branch information
sljlp and wangxicoding authored Aug 31, 2021
1 parent e64105f commit 167685e
Showing 1 changed file with 11 additions and 4 deletions.
15 changes: 11 additions & 4 deletions python/paddle/fluid/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4654,15 +4654,22 @@ def _add_op_device_attr_for_op(self, op, idx, block):
op.type == 'elementwise_div'):
device = f"{self._device}:all"
op._set_attr(self._op_device_key, device)
elif self._is_weight_decay_op(op) and op.type == 'scale':
# set AdamW decay_coeff to device:all
op._set_attr(self._op_device_key, f"{self._device}:all")
elif op.type == "alloc_float_status" or op.type == "clear_float_status":
op._set_attr(self._op_device_key, f"{self._device}:all")
# NOTE(wangxi): NPU should only clear the float status
# once at each batch step
op._set_attr(self._op_role_key, self._op_role.LRSched)

float_status_name = op.output_arg_names[0]
float_status_var = block.var(float_status_name)
# FIXME(wangxi): pipeline lr schedule will exec on sub_scope(0)
# while update will exec on sub_scope(last_micro_step), should
# set persistable to use global scope
float_status_var.persistable = True
else:
other_known_ops = [
'update_loss_scaling', 'reduce_any', 'concat', 'sum',
'check_finite_and_unscale', 'alloc_float_status', 'memcpy'
'check_finite_and_unscale', 'memcpy'
]
assert op.type in other_known_ops, "For other ops without " \
"op_device set, they must be one of {}, but it " \
Expand Down

0 comments on commit 167685e

Please sign in to comment.