Skip to content

Commit

Permalink
Set the lr var's dtype to fp32 when creating a fp16 lr var in optimizer…
Browse files Browse the repository at this point in the history
… if the user does not intend to use global fp16. (#44840)
  • Loading branch information
FeixLiu authored Aug 4, 2022
1 parent 9a17f05 commit 9e39d74
Showing 1 changed file with 12 additions and 9 deletions.
21 changes: 12 additions & 9 deletions python/paddle/optimizer/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,19 +385,23 @@ def get_opti_var_name_list(self):
return self._opti_name_list

def _create_global_learning_rate(self):
# lr var can't be float16, for pure fp16 training, should extra handle the dtype for lr
_lr_dtype = paddle.get_default_dtype(
) if self._dtype is None else self._dtype
_lr_dtype = paddle.float32 if (
paddle.get_default_dtype() != "float16"
and _lr_dtype == paddle.float16) else _lr_dtype
if isinstance(self._learning_rate, LRScheduler):
lr_var = self._global_learning_rate()
# only create global lr_var once
if not isinstance(lr_var, framework.Variable):
lr_name = unique_name.generate('learning_rate')
self._learning_rate._var_name = lr_name
lr_var = self.helper.create_global_variable(
name=lr_name,
shape=[1],
persistable=True,
stop_gradient=True,
dtype=paddle.get_default_dtype()
if self._dtype is None else self._dtype)
lr_var = self.helper.create_global_variable(name=lr_name,
shape=[1],
persistable=True,
stop_gradient=True,
dtype=_lr_dtype)
main_prog = framework.default_main_program()
main_prog.lr_sheduler = self._learning_rate
main_prog.lr_var = lr_var
Expand All @@ -419,8 +423,7 @@ def _create_global_learning_rate(self):
name=unique_name.generate("learning_rate"),
shape=[1],
value=float(self._learning_rate),
dtype=paddle.get_default_dtype()
if self._dtype is None else self._dtype,
dtype=_lr_dtype,
persistable=True)

@framework.dygraph_only
Expand Down

0 comments on commit 9e39d74

Please sign in to comment.