【Hackathon 5 No.21】Add LinearLR to Paddle (PaddlePaddle#57724)
Asthestarsfalll authored and Frida-a committed Oct 14, 2023
1 parent 1fc628f commit cb4da91
Showing 2 changed files with 180 additions and 0 deletions.
120 changes: 120 additions & 0 deletions python/paddle/optimizer/lr.py
@@ -45,6 +45,7 @@
'MultiplicativeDecay',
'OneCycleLR',
'CyclicLR',
'LinearLR',
]


@@ -2229,6 +2230,125 @@ def get_lr(self):
return lr


class LinearLR(LRScheduler):
r"""
Set the learning rate according to a linear schedule.
The learning rate starts at ``start_factor * learning_rate`` and increases linearly to ``end_factor * learning_rate`` over ``total_steps`` iterations.
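In closed form (derivable from the recursive update in ``get_lr`` below), for epoch :math:`t \le total\_steps`:

.. math::

    lr_t = learning\_rate \times \left( start\_factor + \left( end\_factor - start\_factor \right) \times \frac{t}{total\_steps} \right)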
Args:
learning_rate (float): The initial learning rate. It is a python float number.
total_steps (int): Number of iterations over which the learning rate increases from the start learning rate to the end learning rate.
start_factor (float): The start learning rate is defined by `start_factor * learning_rate`. Default: 1./3.
end_factor (float): The end learning rate is defined by `end_factor * learning_rate`. Default: 1.0.
last_epoch (int, optional): The index of the last epoch. Can be set to resume training. Default: -1, which means the initial learning rate.
verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False``.
Returns:
``LinearLR`` instance to schedule learning rate.
Examples:
.. code-block:: python
:name: code-dynamic
>>> # Example 1: train on default dynamic graph mode
>>> import paddle
>>> linear = paddle.nn.Linear(10, 10)
>>> scheduler = paddle.optimizer.lr.LinearLR(learning_rate=0.5, total_steps=5, verbose=True)
>>> sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
>>> for epoch in range(5):
... for batch_id in range(20):
... x = paddle.uniform([10, 10])
... out = linear(x)
... loss = paddle.mean(out)
... loss.backward()
... sgd.step()
... sgd.clear_gradients()
... scheduler.step()
.. code-block:: python
:name: code-static
>>> # Example 2: train on static graph mode
>>> import paddle
>>> import numpy as np
>>> paddle.enable_static()
>>> main_prog = paddle.static.Program()
>>> start_prog = paddle.static.Program()
>>> with paddle.static.program_guard(main_prog, start_prog):
... x = paddle.static.data(name='x', shape=[None, 4, 5])
... y = paddle.static.data(name='y', shape=[None, 4, 5])
... z = paddle.static.nn.fc(x, 100)
... loss = paddle.mean(z)
... scheduler = paddle.optimizer.lr.LinearLR(learning_rate=0.5,
... total_steps=5, verbose=True)
... sgd = paddle.optimizer.SGD(learning_rate=scheduler)
... sgd.minimize(loss)
...
>>> exe = paddle.static.Executor()
>>> exe.run(start_prog)
>>> for epoch in range(5):
... for batch_id in range(20):
... out = exe.run(
... main_prog,
... feed={
... 'x': np.random.randn(3, 4, 5).astype('float32'),
... 'y': np.random.randn(3, 4, 5).astype('float32')
... },
... fetch_list=loss.name)
... scheduler.step()
"""

def __init__(
self,
learning_rate,
total_steps,
start_factor=1.0 / 3,
end_factor=1.0,
last_epoch=-1,
verbose=False,
):
if start_factor > 1.0 or start_factor <= 0:
raise ValueError(
"`start_factor` must be greater than 0 and less or equal to 1, but got {}".format(
start_factor
)
)

if end_factor > 1.0 or end_factor < 0:
raise ValueError(
"`end_factor` must be greater than 0 and less than 1, but got {}".format(
end_factor
)
)

if total_steps <= 0:
raise ValueError(
f"`total_steps` must be greater than 0, but got {total_steps}"
)

self.start_factor = start_factor
self.end_factor = end_factor
self.total_steps = total_steps

super().__init__(learning_rate, last_epoch, verbose)

def get_lr(self):
    if self.last_epoch == 0:
        # First call: start from start_factor * base learning rate.
        return self.base_lr * self.start_factor
    elif self.last_epoch > self.total_steps:
        # The ramp is finished; keep the learning rate constant.
        return self.last_lr
    else:
        # Multiplicative update, equivalent to the closed-form schedule
        # base_lr * (start_factor + (end_factor - start_factor) * last_epoch / total_steps).
        base_lr = self.total_steps * self.start_factor
        cur_factor = self.end_factor - self.start_factor
        factor = 1.0 + cur_factor / (
            base_lr + (self.last_epoch - 1) * cur_factor
        )
        return self.last_lr * factor
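A quick standalone sanity check (an illustrative sketch, not part of the commit; the values below are arbitrary example settings) that the recursive factor above reproduces the closed-form linear ramp:

# Illustrative sketch: verify the recursive update against the closed form.
base_lr, start, end, total = 0.5, 1.0 / 3, 1.0, 5
lr = base_lr * start  # value returned when last_epoch == 0
for epoch in range(1, total + 1):
    cur = end - start
    lr *= 1.0 + cur / (total * start + (epoch - 1) * cur)
    closed = base_lr * (start + cur * epoch / total)
    assert abs(lr - closed) < 1e-12, (epoch, lr, closed)
print("recursive update matches closed form")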


def autoincreased_step_counter(counter_name=None, begin=1, step=1):
"""
:api_attr: Static Graph
60 changes: 60 additions & 0 deletions test/legacy_test/test_lr_scheduler.py
@@ -464,6 +464,31 @@ def exp_range(x):
return base_learning_rate + base_height * scale_fn(eval(scale_mode))


linear_last_lr = None


# Python reference implementation mirroring paddle.optimizer.lr.LinearLR.get_lr().
def linear_lr(
epoch_num,
learning_rate,
total_steps,
start_factor=1.0 / 3,
end_factor=1.0,
verbose=False,
):
global linear_last_lr
if epoch_num == 0:
linear_last_lr = learning_rate * start_factor
return linear_last_lr
elif epoch_num > total_steps:
return linear_last_lr
else:
base_lr = total_steps * start_factor
cur_factor = end_factor - start_factor
factor = 1.0 + cur_factor / (base_lr + (epoch_num - 1) * cur_factor)
linear_last_lr *= factor
return linear_last_lr
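For intuition, an illustrative call sequence (the values here are assumed for the example, not taken from the test file) showing the ramp this reference produces:

# Illustrative only: linear_lr must be called with consecutive epoch numbers,
# since it keeps state in the module-level linear_last_lr.
for epoch in range(8):
    print(epoch, round(linear_lr(epoch, learning_rate=0.5, total_steps=5), 4))
# 0 -> 0.1667, 1 -> 0.2333, ..., 5 -> 0.5, then constant at 0.5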


class TestLRScheduler(unittest.TestCase):
def _test_static(self, python_func, paddle_api, kwarg, place):
scheduler = paddle_api(**kwarg)
@@ -711,6 +736,19 @@ def test_scheduler(self):
paddle.optimizer.lr.PiecewiseDecay(
boundaries=[100, 200], values=[0.5, 0.1]
)
# check negative total_steps
with self.assertRaises(ValueError):
paddle.optimizer.lr.LinearLR(learning_rate=1, total_steps=-1)
# check start_factor
with self.assertRaises(ValueError):
paddle.optimizer.lr.LinearLR(
learning_rate=1, total_steps=5, start_factor=2
)
# check end_factor
with self.assertRaises(ValueError):
paddle.optimizer.lr.LinearLR(
learning_rate=1, total_steps=5, end_factor=2
)

func_api_kwargs = [
(
@@ -944,6 +982,28 @@ def test_scheduler(self):
"verbose": False,
},
),
(
linear_lr,
paddle.optimizer.lr.LinearLR,
{
"learning_rate": 0.2,
"total_steps": 40,
"start_factor": 0.5,
"end_factor": 1,
"verbose": False,
},
),
(
linear_lr,
paddle.optimizer.lr.LinearLR,
{
"learning_rate": 0.2,
"total_steps": 5,
"start_factor": 0.2,
"end_factor": 0.5,
"verbose": False,
},
),
]
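Each (python_func, paddle_api, kwarg) triple above is fed to the test helpers, which compare the Python reference against the Paddle scheduler epoch by epoch. A minimal sketch of that comparison, assuming Paddle's LRScheduler semantics (step() advances last_epoch and refreshes last_lr):

# Minimal sketch of the reference-vs-scheduler comparison (assumptions above).
import paddle

kwarg = {"learning_rate": 0.2, "total_steps": 5, "start_factor": 0.2, "end_factor": 0.5}
scheduler = paddle.optimizer.lr.LinearLR(**kwarg)
for epoch in range(10):
    expected = linear_lr(epoch, **kwarg)
    assert abs(scheduler.last_lr - expected) < 1e-8  # lr for the current epoch
    scheduler.step()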

for python_func, paddle_api, kwarg in func_api_kwargs:
