fix some bugs for some optimizer not call super init, test=allcases
FeixLiu committed Aug 3, 2022
1 parent 80e2c56 commit 6b1ff3a
Showing 1 changed file with 10 additions and 12 deletions.
22 changes: 10 additions & 12 deletions python/paddle/optimizer/optimizer.py
@@ -233,11 +233,6 @@ def __init__(self,
         else:
             self._dtype = self._parameter_list[0].dtype
 
-        # learning rate can't be float16, for pure fp16 training, should extra handle the dtype for lr
-        self._lr_dtype = paddle.get_default_dtype(
-        ) if self._dtype is None else self._dtype
-        self._lr_dtype = paddle.float32 if self._lr_dtype == paddle.float16 else self._lr_dtype
-
         # each program should have a independent learning rate
         # program -> tensor(learning_rate)
         self._learning_rate_map = dict()
@@ -390,18 +385,21 @@ def get_opti_var_name_list(self):
         return self._opti_name_list
 
     def _create_global_learning_rate(self):
+        # learning rate can't be float16, for pure fp16 training, should extra handle the dtype for lr
+        _lr_dtype = paddle.get_default_dtype(
+        ) if self._dtype is None else self._dtype
+        _lr_dtype = paddle.float32 if _lr_dtype == paddle.float16 else _lr_dtype
         if isinstance(self._learning_rate, LRScheduler):
             lr_var = self._global_learning_rate()
             # only create global lr_var once
             if not isinstance(lr_var, framework.Variable):
                 lr_name = unique_name.generate('learning_rate')
                 self._learning_rate._var_name = lr_name
-                lr_var = self.helper.create_global_variable(
-                    name=lr_name,
-                    shape=[1],
-                    persistable=True,
-                    stop_gradient=True,
-                    dtype=self._lr_dtype)
+                lr_var = self.helper.create_global_variable(name=lr_name,
+                                                            shape=[1],
+                                                            persistable=True,
+                                                            stop_gradient=True,
+                                                            dtype=_lr_dtype)
                 main_prog = framework.default_main_program()
                 main_prog.lr_sheduler = self._learning_rate
                 main_prog.lr_var = lr_var
@@ -423,7 +421,7 @@ def _create_global_learning_rate(self):
                     name=unique_name.generate("learning_rate"),
                     shape=[1],
                     value=float(self._learning_rate),
-                    dtype=self._lr_dtype,
+                    dtype=_lr_dtype,
                     persistable=True)
 
     @framework.dygraph_only
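
For context, a minimal sketch of the failure mode this commit addresses (the Base and Child classes below are hypothetical stand-ins, not Paddle code): an attribute assigned in a base class's __init__ never exists on an instance whose subclass skips super().__init__(), so reading self._lr_dtype later raised AttributeError. Computing the dtype locally inside _create_global_learning_rate removes the dependency on the constructor having run.

    # Hypothetical sketch of the bug: an attribute set in Base.__init__
    # is missing when a subclass skips super().__init__(), which is what
    # broke self._lr_dtype for some optimizers.

    class Base:
        def __init__(self):
            self._lr_dtype = "float32"  # computed once in the constructor

        def create_lr_before_fix(self):
            # Relies on __init__ having run; raises AttributeError otherwise.
            return self._lr_dtype

        def create_lr_after_fix(self):
            # Compute the value locally where it is needed, mirroring the
            # commit's move of _lr_dtype into _create_global_learning_rate.
            _lr_dtype = "float32"
            return _lr_dtype

    class Child(Base):
        def __init__(self):
            pass  # does not call super().__init__(), like the buggy optimizers

    opt = Child()
    try:
        opt.create_lr_before_fix()
    except AttributeError as err:
        print("before fix:", err)  # 'Child' object has no attribute '_lr_dtype'
    print("after fix:", opt.create_lr_after_fix())  # float32

This also explains why the fix is a local variable rather than a lazily-set attribute: _create_global_learning_rate already has everything it needs (self._dtype and paddle.get_default_dtype()), so no constructor state is required.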
