From 167685e524127bf1f61a21fe750b07d5dce8b090 Mon Sep 17 00:00:00 2001
From: Roc <30228238+sljlp@users.noreply.github.com>
Date: Tue, 31 Aug 2021 13:37:31 +0800
Subject: [PATCH] [hybrid][npu] fix npu clear float status in pipeline
 (#35165) (#35295)

Co-authored-by: WangXi
---
 python/paddle/fluid/optimizer.py | 15 +++++++++++----
 1 file changed, 11 insertions(+), 4 deletions(-)

diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py
index 9e87681c4bef3..378902d8dde81 100755
--- a/python/paddle/fluid/optimizer.py
+++ b/python/paddle/fluid/optimizer.py
@@ -4654,15 +4654,22 @@ def _add_op_device_attr_for_op(self, op, idx, block):
                 op.type == 'elementwise_div'):
             device = f"{self._device}:all"
             op._set_attr(self._op_device_key, device)
-        elif self._is_weight_decay_op(op) and op.type == 'scale':
-            # set AdamW decay_coeff to device:all
-            op._set_attr(self._op_device_key, f"{self._device}:all")
         elif op.type == "alloc_float_status" or op.type == "clear_float_status":
             op._set_attr(self._op_device_key, f"{self._device}:all")
+            # NOTE(wangxi): NPU should only clear the float status
+            # once at each batch step
+            op._set_attr(self._op_role_key, self._op_role.LRSched)
+
+            float_status_name = op.output_arg_names[0]
+            float_status_var = block.var(float_status_name)
+            # FIXME(wangxi): pipeline lr schedule will exec on sub_scope(0)
+            # while update will exec on sub_scope(last_micro_step), should
+            # set persistable to use global scope
+            float_status_var.persistable = True
         else:
             other_known_ops = [
                 'update_loss_scaling', 'reduce_any', 'concat', 'sum',
-                'check_finite_and_unscale', 'alloc_float_status', 'memcpy'
+                'check_finite_and_unscale', 'memcpy'
             ]
             assert op.type in other_known_ops, "For other ops without " \
                 "op_device set, they must be one of {}, but it " \
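
The FIXME above hinges on how Paddle scopes variables under pipeline
parallelism: each micro-step executes in its own sub-scope, so a float-status
variable cleared by the LRSched-role op on sub_scope(0) would be invisible to
the optimizer update running on sub_scope(last_micro_step) unless the variable
is persistable, which pins it to the global scope. Below is a minimal
standalone sketch of that flag, not part of the applied patch; the variable
name and shape are made up for illustration, and it assumes a Paddle 2.x build
where the static-graph paddle.fluid API is still available.

    import paddle
    import paddle.fluid as fluid

    paddle.enable_static()

    prog = fluid.Program()
    with fluid.program_guard(prog):
        block = prog.global_block()
        # Stand-in for the float-status output variable in the patch; a
        # freshly created variable is not persistable by default, so it
        # lives only in the (sub-)scope that materializes it.
        float_status = block.create_var(
            name="float_status", shape=[8], dtype="float32")
        print(float_status.persistable)  # False

        # What the patch does for the real float-status variable: pin it
        # to the global scope so every micro-step's sub-scope reads the
        # same underlying tensor.
        float_status.persistable = True
        print(float_status.persistable)  # True

In the patch itself the same flag is flipped on the actual float-status
output variable, looked up via block.var(op.output_arg_names[0]), rather than
on a hand-created variable as in this sketch.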