diff --git a/python/paddle/amp/grad_scaler.py b/python/paddle/amp/grad_scaler.py
index 85e6f6efc6ba3..b6a38e0e28589 100644
--- a/python/paddle/amp/grad_scaler.py
+++ b/python/paddle/amp/grad_scaler.py
@@ -18,7 +18,7 @@
 
 import numpy as np
 
-from paddle import _legacy_C_ops
+from paddle import _C_ops, _legacy_C_ops
 from paddle.fluid import core, in_dygraph_mode
 from paddle.fluid.data_feeder import check_type
 from paddle.fluid.dygraph import to_variable
@@ -228,11 +228,9 @@ def minimize(self, optimizer, *args, **kwargs):
 
         optimize_ops, params_grads = (None, None)
 
-        if self._found_inf:
-            self._cache_founf_inf = True
-        else:
-            optimize_ops, params_grads = optimizer.minimize(*args, **kwargs)
-            self._cache_founf_inf = False
+        optimizer._set_auxiliary_var('found_inf', self._found_inf)
+        optimize_ops, params_grads = optimizer.minimize(*args, **kwargs)
+        self._cache_founf_inf = optimizer._get_auxiliary_var('found_inf')
 
         if self._use_dynamic_loss_scaling:
             # uopdate the scale
@@ -330,6 +328,9 @@ def _unscale(self, optimizer):
                     param_grads_fp16,
                     self._temp_found_inf_fp16,
                 )
+                self._found_inf = _C_ops.bitwise_or(
+                    self._found_inf, self._temp_found_inf_fp16
+                )
             if len(param_grads_bf16):
                 _legacy_C_ops.check_finite_and_unscale(
                     param_grads_bf16,
@@ -338,6 +339,9 @@ def _unscale(self, optimizer):
                     param_grads_bf16,
                     self._temp_found_inf_bf16,
                 )
+                self._found_inf = _C_ops.bitwise_or(
+                    self._found_inf, self._temp_found_inf_bf16
+                )
             if len(param_grads_fp32):
                 _legacy_C_ops.check_finite_and_unscale(
                     param_grads_fp32,
@@ -346,6 +350,9 @@ def _unscale(self, optimizer):
                     param_grads_fp32,
                     self._temp_found_inf_fp32,
                 )
+                self._found_inf = _C_ops.bitwise_or(
+                    self._found_inf, self._temp_found_inf_fp32
+                )
         else:
             if len(param_grads_fp16):
                 _legacy_C_ops.check_finite_and_unscale(
@@ -354,6 +361,9 @@ def _unscale(self, optimizer):
                     param_grads_fp16,
                     self._temp_found_inf_fp16,
                 )
+                self._found_inf = _C_ops.bitwise_or(
+                    self._found_inf, self._temp_found_inf_fp16
+                )
             if len(param_grads_bf16):
                 _legacy_C_ops.check_finite_and_unscale(
                     param_grads_bf16,
@@ -361,6 +371,9 @@ def _unscale(self, optimizer):
                     param_grads_bf16,
                     self._temp_found_inf_bf16,
                 )
+                self._found_inf = _C_ops.bitwise_or(
+                    self._found_inf, self._temp_found_inf_bf16
+                )
             if len(param_grads_fp32):
                 _legacy_C_ops.check_finite_and_unscale(
                     param_grads_fp32,
@@ -368,12 +381,9 @@ def _unscale(self, optimizer):
                     param_grads_fp32,
                     self._temp_found_inf_fp32,
                 )
-
-        self._found_inf = (
-            self._temp_found_inf_fp16
-            or self._temp_found_inf_bf16
-            or self._temp_found_inf_fp32
-        )
+                self._found_inf = _C_ops.bitwise_or(
+                    self._found_inf, self._temp_found_inf_fp32
+                )
 
         optimizer_state["state"] = OptimizerState.UNSCALED
 
@@ -761,11 +771,9 @@ def step(self, optimizer):
         if optimizer_state["state"] is OptimizerState.INIT:
             self._unscale(optimizer)
 
-        if self._found_inf:
-            self._cache_founf_inf = True
-        else:
-            optimizer.step()
-            self._cache_founf_inf = False
+        optimizer._set_auxiliary_var('found_inf', self._found_inf)
+        optimizer.step()
+        self._cache_founf_inf = optimizer._get_auxiliary_var('found_inf')
 
         optimizer_state["state"] = OptimizerState.STEPPED
 
diff --git a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_gradscaler.py b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_gradscaler.py
index 144dc8b6586c3..c12843f106562 100644
--- a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_gradscaler.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_gradscaler.py
@@ -41,11 +41,9 @@ def minimize(self, optimizer, *args, **kwargs):
 
         optimize_ops, params_grads = (None, None)
 
-        if self._found_inf:
-            self._cache_founf_inf = True
-        else:
-            optimize_ops, params_grads = optimizer.minimize(*args, **kwargs)
-            self._cache_founf_inf = False
+        optimizer._set_auxiliary_var('found_inf', self._found_inf)
+        optimize_ops, params_grads = optimizer.minimize(*args, **kwargs)
+        self._cache_founf_inf = optimizer._get_auxiliary_var('found_inf')
 
         if self._use_dynamic_loss_scaling:
             self._update()
diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_utils.py b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_utils.py
index b1ab777964777..361b421bbae4b 100644
--- a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_utils.py
+++ b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_utils.py
@@ -19,10 +19,10 @@
 import numpy as np
 
 import paddle
-from paddle import _legacy_C_ops
+from paddle import _C_ops, _legacy_C_ops
 from paddle.common_ops_import import dygraph_only
+from paddle.fluid import core
 from paddle.fluid.dygraph import to_variable
-from paddle.framework import core
 from paddle.nn import clip
 
 
@@ -231,6 +231,9 @@ def unscale_method(self, optimizer):
                     param_grads_fp16,
                     temp_found_inf_fp16,
                 )
+                self._found_inf = _C_ops.bitwise_or(
+                    self._found_inf, temp_found_inf_fp16
+                )
             if len(param_grads_fp32):
                 _legacy_C_ops.check_finite_and_unscale(
                     param_grads_fp32,
@@ -238,15 +241,17 @@ def unscale_method(self, optimizer):
                     param_grads_fp32,
                     temp_found_inf_fp32,
                 )
+                self._found_inf = _C_ops.bitwise_or(
+                    self._found_inf, temp_found_inf_fp32
+                )
 
-        self._found_inf = 1 if temp_found_inf_fp16 or temp_found_inf_fp32 else 0
-        is_found_inf = paddle.to_tensor([self._found_inf], dtype="int32")
+        self._found_inf = self._found_inf.cast("int32")
         paddle.distributed.all_reduce(
-            is_found_inf, op=paddle.distributed.ReduceOp.SUM, group=None
+            self._found_inf, op=paddle.distributed.ReduceOp.MAX, group=None
         )
-        self._found_inf = is_found_inf.numpy()[0]
+        self._found_inf = self._found_inf.cast("bool")
 
     scaler._unscale = MethodType(unscale_method, scaler)
     return scaler
diff --git a/python/paddle/distributed/fleet/scaler.py b/python/paddle/distributed/fleet/scaler.py
index 003265a86123f..a06b73fd0c3ff 100755
--- a/python/paddle/distributed/fleet/scaler.py
+++ b/python/paddle/distributed/fleet/scaler.py
@@ -17,7 +17,7 @@
 import numpy as np
 
 import paddle
-from paddle import _legacy_C_ops
+from paddle import _C_ops, _legacy_C_ops
 from paddle.distributed import fleet
 from paddle.fluid.dygraph import to_variable
 from paddle.framework import core
@@ -73,6 +73,9 @@ def unscale_method(self, optimizer):
                 param_grads_fp16,
                 temp_found_inf_fp16,
             )
+            self._found_inf = _C_ops.bitwise_or(
+                self._found_inf, temp_found_inf_fp16
+            )
         if len(param_grads_fp32):
             _legacy_C_ops.check_finite_and_unscale(
                 param_grads_fp32,
@@ -80,17 +83,19 @@ def unscale_method(self, optimizer):
                 param_grads_fp32,
                 temp_found_inf_fp32,
             )
+            self._found_inf = _C_ops.bitwise_or(
+                self._found_inf, temp_found_inf_fp32
+            )
 
-        self._found_inf = 1 if temp_found_inf_fp16 or temp_found_inf_fp32 else 0
-        is_found_inf = paddle.to_tensor([self._found_inf], dtype="int32")
+        self._found_inf = self._found_inf.cast("int32")
         # TODO(shenliang03) Since dp allreduce in the optimizer is
         # after the gradscaler, check_finite needs to synchronize global
         # information. In the future, we should use check_group to speed.
         paddle.distributed.all_reduce(
-            is_found_inf, op=paddle.distributed.ReduceOp.MAX, group=None
+            self._found_inf, op=paddle.distributed.ReduceOp.MAX, group=None
         )
-        self._found_inf = is_found_inf.numpy()[0]
+        self._found_inf = self._found_inf.cast("bool")
 
     # Only data_parallel doesn't need to modify scaler
     fleet_env = fleet.fleet
diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py
index 52501992ef9e9..c5aa80c749027 100755
--- a/python/paddle/fluid/optimizer.py
+++ b/python/paddle/fluid/optimizer.py
@@ -893,11 +893,18 @@ def _create_optimization_pass(self, parameters_and_grads):
         self._create_global_learning_rate()
 
         if in_dygraph_mode():
-            for param_and_grad in parameters_and_grads:
-                if param_and_grad[1] is None:
-                    continue
-                if param_and_grad[0].trainable is True:
-                    self._append_optimize_op(target_block, param_and_grad)
+            found_inf = self._get_auxiliary_var('found_inf')
+            if found_inf:
+                if isinstance(found_inf, core.eager.Tensor):
+                    self._set_auxiliary_var('found_inf', True)
+            else:
+                if isinstance(found_inf, core.eager.Tensor):
+                    self._set_auxiliary_var('found_inf', False)
+                for param_and_grad in parameters_and_grads:
+                    if param_and_grad[1] is None:
+                        continue
+                    if param_and_grad[0].trainable is True:
+                        self._append_optimize_op(target_block, param_and_grad)
         else:
             for param_and_grad in parameters_and_grads:
                 if param_and_grad[1] is None:
diff --git a/python/paddle/optimizer/adam.py b/python/paddle/optimizer/adam.py
index 070efdff2d126..9c827496e8b2e 100644
--- a/python/paddle/optimizer/adam.py
+++ b/python/paddle/optimizer/adam.py
@@ -360,8 +360,6 @@ def _append_optimize_op(self, block, param_and_grad):
 
         # create the adam optimize op
         if framework.in_dygraph_mode():
-            found_inf = self._get_auxiliary_var('found_inf')
-
             _beta1 = (
                 self._beta1
                 if not isinstance(self._beta1, Variable)
@@ -382,7 +380,7 @@ def _append_optimize_op(self, block, param_and_grad):
                 beta1_pow_acc,
                 beta2_pow_acc,
                 master_weight,
-                found_inf,
+                None,
                 _beta1,
                 _beta2,
                 self._epsilon,
@@ -693,21 +691,28 @@ def _append_optimize_multi_tensor_op(
                     if master_weight is not None
                     else None
                 )
-                _, _, _, _, _, _ = _C_ops.merged_adam_(
-                    self._param_dict[key][param_group_idx],
-                    grad_dict[key],
-                    lr_dict[key],
-                    self._moment1_dict[key][param_group_idx],
-                    self._moment2_dict[key][param_group_idx],
-                    self._beta1_pow_acc_dict[key][param_group_idx],
-                    self._beta2_pow_acc_dict[key][param_group_idx],
-                    master_weight,
-                    _beta1,
-                    _beta2,
-                    self._epsilon,
-                    find_master,
-                    False,
-                )
+                found_inf = self._get_auxiliary_var('found_inf')
+                if found_inf:
+                    if isinstance(found_inf, core.eager.Tensor):
+                        self._set_auxiliary_var('found_inf', True)
+                else:
+                    if isinstance(found_inf, core.eager.Tensor):
+                        self._set_auxiliary_var('found_inf', False)
+                    _, _, _, _, _, _ = _C_ops.merged_adam_(
+                        self._param_dict[key][param_group_idx],
+                        grad_dict[key],
+                        lr_dict[key],
+                        self._moment1_dict[key][param_group_idx],
+                        self._moment2_dict[key][param_group_idx],
+                        self._beta1_pow_acc_dict[key][param_group_idx],
+                        self._beta2_pow_acc_dict[key][param_group_idx],
+                        master_weight,
+                        _beta1,
+                        _beta2,
+                        self._epsilon,
+                        find_master,
+                        False,
+                    )
             else:
                 inputs = {
                     "Param": self._param_dict[key][param_group_idx],
diff --git a/python/paddle/optimizer/adamw.py b/python/paddle/optimizer/adamw.py
index a4d304b451e7b..5a75e6d243696 100644
--- a/python/paddle/optimizer/adamw.py
+++ b/python/paddle/optimizer/adamw.py
@@ -491,7 +491,6 @@ def _append_optimize_op(self, block, param_and_grad):
                 else self._beta2.numpy().item(0)
             )
 
-            found_inf = self._get_auxiliary_var('found_inf')
             _, _, _, _, _, _ = _C_ops.adamw_(
                 param_and_grad[0],
                 param_and_grad[1],
@@ -501,7 +500,7 @@ def _append_optimize_op(self, block, param_and_grad):
                 beta1_pow_acc,
                 beta2_pow_acc,
                 master_weight,
-                found_inf,
+                None,
                 _beta1,
                 _beta2,
                 self._epsilon,
diff --git a/python/paddle/optimizer/lamb.py b/python/paddle/optimizer/lamb.py
index e531e785e319f..57904cd44a86c 100644
--- a/python/paddle/optimizer/lamb.py
+++ b/python/paddle/optimizer/lamb.py
@@ -293,7 +293,6 @@ def _append_optimize_op(self, block, param_and_grad):
             self._used_master_weights[p_name] = master_weight.name
         else:
             master_weight = None
-        found_inf = self._get_auxiliary_var('found_inf')
 
         if framework.in_dygraph_mode():
             _C_ops.lamb_(
@@ -305,7 +304,7 @@ def _append_optimize_op(self, block, param_and_grad):
                 beta1_pow_acc,
                 beta2_pow_acc,
                 master_weight,
-                found_inf,
+                None,
                 weight_decay,
                 self._beta1,
                 self._beta2,
@@ -343,6 +342,7 @@ def _append_optimize_op(self, block, param_and_grad):
             inputs["MasterParam"] = master_weight
             outputs["MasterParamOut"] = master_weight
 
+        found_inf = self._get_auxiliary_var('found_inf')
         if found_inf:
             inputs["SkipUpdate"] = found_inf
 
diff --git a/python/paddle/optimizer/momentum.py b/python/paddle/optimizer/momentum.py
index 3b20777599fb0..bff9c1209e708 100644
--- a/python/paddle/optimizer/momentum.py
+++ b/python/paddle/optimizer/momentum.py
@@ -530,19 +530,30 @@ def _append_optimize_multi_tensor_op(
                 )
 
                 if in_dygraph_mode():
-                    _, _, _ = _C_ops.merged_momentum_(
-                        self._param_dict[key][param_group_idx],
-                        grad_dict[key],
-                        self._velocity_dict[key][param_group_idx],
-                        lr_dict[key],
-                        master_weight,
-                        self._momentum,
-                        self._use_nesterov,
-                        self._regularization_method_dict[key][param_group_idx],
-                        self._regularization_coeff_dict[key][param_group_idx],
-                        find_master,
-                        self._rescale_grad,
-                    )
+                    found_inf = self._get_auxiliary_var('found_inf')
+                    if found_inf:
+                        if isinstance(found_inf, core.eager.Tensor):
+                            self._set_auxiliary_var('found_inf', True)
+                    else:
+                        if isinstance(found_inf, core.eager.Tensor):
+                            self._set_auxiliary_var('found_inf', False)
+                        _, _, _ = _C_ops.merged_momentum_(
+                            self._param_dict[key][param_group_idx],
+                            grad_dict[key],
+                            self._velocity_dict[key][param_group_idx],
+                            lr_dict[key],
+                            master_weight,
+                            self._momentum,
+                            self._use_nesterov,
+                            self._regularization_method_dict[key][
+                                param_group_idx
+                            ],
+                            self._regularization_coeff_dict[key][
+                                param_group_idx
+                            ],
+                            find_master,
+                            self._rescale_grad,
+                        )
                 else:
                     inputs = {
                         "Param": self._param_dict[key][param_group_idx],
diff --git a/python/paddle/optimizer/optimizer.py b/python/paddle/optimizer/optimizer.py
index 1799461254ced..cad226952be41 100644
--- a/python/paddle/optimizer/optimizer.py
+++ b/python/paddle/optimizer/optimizer.py
@@ -920,31 +920,38 @@ def _create_optimization_pass(
             self._create_accumulators(target_block, params_acc_dict)
 
         if framework._non_static_mode():
-            if isinstance(parameters_and_grads, list):
-                for param_and_grad in parameters_and_grads:
-                    if param_and_grad[1] is None:
-                        continue
-                    if param_and_grad[0].stop_gradient is False:
-                        self._append_optimize_op(
-                            target_block, param_and_grad
-                        )
+            found_inf = self._get_auxiliary_var('found_inf')
+            if found_inf:
+                if isinstance(found_inf, core.eager.Tensor):
+                    self._set_auxiliary_var('found_inf', True)
             else:
-                for param_and_grad in parameters_and_grads['params']:
-                    if param_and_grad[1] is None:
-                        continue
-                    if param_and_grad[0].stop_gradient is False:
-                        param_grad_dict = dict()
-                        param_grad_dict['params'] = param_and_grad
-                        param_grad_dict.update(
-                            {
-                                k: v
-                                for k, v in parameters_and_grads.items()
-                                if k != 'params'
-                            }
-                        )
-                        self._append_optimize_op(
-                            target_block, param_grad_dict
-                        )
+                if isinstance(found_inf, core.eager.Tensor):
+                    self._set_auxiliary_var('found_inf', False)
+                if isinstance(parameters_and_grads, list):
+                    for param_and_grad in parameters_and_grads:
+                        if param_and_grad[1] is None:
+                            continue
+                        if param_and_grad[0].stop_gradient is False:
+                            self._append_optimize_op(
+                                target_block, param_and_grad
+                            )
+                else:
+                    for param_and_grad in parameters_and_grads['params']:
+                        if param_and_grad[1] is None:
+                            continue
+                        if param_and_grad[0].stop_gradient is False:
+                            param_grad_dict = dict()
+                            param_grad_dict['params'] = param_and_grad
+                            param_grad_dict.update(
+                                {
+                                    k: v
+                                    for k, v in parameters_and_grads.items()
+                                    if k != 'params'
+                                }
+                            )
+                            self._append_optimize_op(
+                                target_block, param_grad_dict
+                            )
         else:
             for param_and_grad in parameters_and_grads:
                 if param_and_grad[1] is None:
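
Note (editorial sketch): the recurring change above replaces the scaler-side "if self._found_inf:" branch with a handshake through the optimizer's auxiliary variables: the scaler writes 'found_inf' into the optimizer, the optimizer decides on its own whether to skip the parameter update (converting a tensor flag into a plain bool along the way), and the scaler reads the decision back into _cache_founf_inf. The plain-Python sketch below only illustrates that control flow; ToyOptimizer and toy_scaler_step are hypothetical stand-ins, not Paddle APIs, and the tensor-to-bool conversion is omitted.

# Illustrative sketch of the found_inf handshake; ToyOptimizer and
# toy_scaler_step are hypothetical stand-ins, not part of Paddle's API.


class ToyOptimizer:
    def __init__(self):
        self._auxiliary_vars = {}
        self.applied_steps = 0

    def _set_auxiliary_var(self, key, val):
        self._auxiliary_vars[key] = val

    def _get_auxiliary_var(self, key):
        return self._auxiliary_vars.get(key, None)

    def step(self):
        # Mirrors _create_optimization_pass in the patch: the optimizer
        # itself checks 'found_inf' and skips the update when it is set.
        if self._get_auxiliary_var('found_inf'):
            return
        self.applied_steps += 1


def toy_scaler_step(scaler_found_inf, optimizer):
    # Mirrors GradScaler.step after the patch: no scaler-side branch;
    # the skip decision is delegated to the optimizer.
    optimizer._set_auxiliary_var('found_inf', scaler_found_inf)
    optimizer.step()
    # The scaler caches this result (as _cache_founf_inf in the patch).
    return optimizer._get_auxiliary_var('found_inf')


opt = ToyOptimizer()
assert toy_scaler_step(False, opt) is False and opt.applied_steps == 1
assert toy_scaler_step(True, opt) is True and opt.applied_steps == 1  # skipped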