Investigate if lr_scheduler from segmentation can use PyTorch's schedulers #4438
Is there any

I had a look also when Francisco raised the ticket, but couldn't see anything compatible, TBH.

It might not be implemented yet. I think we should check whether this type of scheduler has been used in more papers since then, which could justify adding it to PyTorch.

I'm removing the "good first issue" tag because I think there isn't such a scheduler in Core, and a more thorough investigation would be needed to resolve this. Perhaps coordinating with Core to add it is worth it, but that's not a great Bootcamp task.
Hi guys, I'm working on this issue, as reported here. However, I think I need some extra information about the expected behavior of the scheduler. So far, I have considered the following resources:

Let's go through them one by one. I'm going to fix some parameters to make a fair comparison.
import torch

# Fixed parameters for the comparison (lr inferred from the printed output below):
lr = 0.001
v = torch.zeros(1, requires_grad=True)
data_loader = range(0, 5)
optimizer = torch.optim.SGD([v], lr=lr)
lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer,
    lambda step: (1 - step / len(data_loader)) ** 1.0
)
>>> for i in data_loader:
>>>     lr_scheduler.step(i)
>>>     print(i, optimizer.param_groups[0]['lr'])
0 0.001
1 0.0008
2 0.0006
3 0.0004
4 0.00019999999999999996

import tensorflow as tf

end_learning_rate = 0.0001  # inferred from the output below
max_decay_step = 4  # inferred from the output below
power = 1.0
poly = tf.keras.optimizers.schedules.PolynomialDecay(
    lr,
    max_decay_step,
    end_learning_rate=end_learning_rate,
    power=power,
    cycle=False,
    name=None,
)
>>> for i in range(0, 5):
>>>     print(i, poly(i))
0 tf.Tensor(0.001, shape=(), dtype=float32)
1 tf.Tensor(0.00077499996, shape=(), dtype=float32)
2 tf.Tensor(0.00055, shape=(), dtype=float32)
3 tf.Tensor(0.000325, shape=(), dtype=float32)
4 tf.Tensor(1e-04, shape=(), dtype=float32)

# PolynomialLRDecay is the third-party PyTorch implementation referenced above.
max_decay_steps = max_decay_step
optim = torch.optim.SGD([torch.zeros(10)], lr=lr)
scheduler = PolynomialLRDecay(
    optim,
    max_decay_steps=max_decay_steps,
    end_learning_rate=end_learning_rate,
    power=power,
)
>>> for i in range(0, 5):
>>>     scheduler.step()
>>>     print(i, optim.param_groups[0]['lr'])
0 0.00055
1 0.000325
2 0.0001
3 0.0001
4 0.0001
>>> for i in range(0, 5):
>>>     scheduler.step(i)
>>>     print(i, optim.param_groups[0]['lr'])
0 0.0007750000000000001
1 0.0007750000000000001
2 0.00055
3 0.000325
4 0.0001

Open issues:
@federicopozzi33 These are all very good questions. Unfortunately, I wasn't too familiar with the API of Schedulers, so in order to answer them I had to implement it and experiment. Here is the proposed implementation:

import warnings
import torch
from torch.optim.lr_scheduler import _LRScheduler
class PolynomialLR(_LRScheduler):
    def __init__(self, optimizer, total_iters=5, min_lr=0.0, power=1.0, last_epoch=-1, verbose=False):
        self.total_iters = total_iters
        if isinstance(min_lr, list) or isinstance(min_lr, tuple):
            if len(min_lr) != len(optimizer.param_groups):
                raise ValueError("expected {} min_lrs, got {}".format(len(optimizer.param_groups), len(min_lr)))
            self.min_lrs = list(min_lr)
        else:
            self.min_lrs = [min_lr] * len(optimizer.param_groups)
        self.power = power
        super().__init__(optimizer, last_epoch, verbose)

    def get_lr(self):
        if not self._get_lr_called_within_step:
            warnings.warn(
                "To get the last learning rate computed by the scheduler, " "please use `get_last_lr()`.", UserWarning
            )

        if self.last_epoch == 0:
            return [group["lr"] for group in self.optimizer.param_groups]

        if self.last_epoch > self.total_iters:
            return [self.min_lrs[i] for i in range(len(self.optimizer.param_groups))]

        return [
            self.min_lrs[i]
            + ((1.0 - self.last_epoch / self.total_iters) / (1.0 - (self.last_epoch - 1) / self.total_iters))
            ** self.power
            * (group["lr"] - self.min_lrs[i])
            for i, group in enumerate(self.optimizer.param_groups)
        ]

    def _get_closed_form_lr(self):
        return [
            (
                self.min_lrs[i]
                + (1.0 - min(self.total_iters, self.last_epoch) / self.total_iters) ** self.power
                * (base_lr - self.min_lrs[i])
            )
            for i, base_lr in enumerate(self.base_lrs)
        ]
# Test it
lr = 0.001
total_iters = 5
power = 1.0

scheduler = PolynomialLR(
    torch.optim.SGD([torch.zeros(1)], lr=lr),
    total_iters=total_iters,
    min_lr=0.0,  # Using 0 because the Lambda doesn't support this option
    power=power,
)
scheduler2 = torch.optim.lr_scheduler.LambdaLR(
    torch.optim.SGD([torch.zeros(1)], lr=lr), lambda step: (1 - step / total_iters) ** power
)

for i in range(0, total_iters):
    print(i, scheduler.optimizer.param_groups[0]["lr"], scheduler2.optimizer.param_groups[0]["lr"])
    scheduler.step()
    scheduler2.step()

Here are some answers to your questions:
Though I think we can use the above implementation as-is, to be able to contribute it to PyTorch core we need tests, docs and a few more bells and whistles. I believe the PR pytorch/pytorch#60836 is a good example of what needs to be done. If you are up for it, you can start a PR and I can help you get it merged. Alternatively, I can finish it off and find you a different primitive. Let me know what you prefer.
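As a rough illustration (my sketch, not PyTorch core's actual test harness), such a contribution would include a test comparing the proposed PolynomialLR above against an equivalent LambdaLR schedule:

import torch

def test_polynomial_lr_matches_lambda_lr():
    # Hypothetical pytest-style check: the proposed PolynomialLR (defined above)
    # should track an equivalent LambdaLR schedule step by step.
    lr, total_iters, power = 0.001, 5, 1.0
    opt1 = torch.optim.SGD([torch.zeros(1)], lr=lr)
    opt2 = torch.optim.SGD([torch.zeros(1)], lr=lr)
    sched1 = PolynomialLR(opt1, total_iters=total_iters, min_lr=0.0, power=power)
    sched2 = torch.optim.lr_scheduler.LambdaLR(opt2, lambda step: (1 - step / total_iters) ** power)
    for _ in range(total_iters):
        assert abs(opt1.param_groups[0]["lr"] - opt2.param_groups[0]["lr"]) < 1e-12
        sched1.step()
        sched2.step()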
Hi @datumbox, thank you for your help. I have some doubts about the meaning of
I didn't find any references for some parts of the formula you used for the decayed LR. Although the values seem correct to me, I have some doubts about the part:
Could you explain it in more detail?
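For context, here is a minimal numeric sketch (illustrative values, not from the original exchange) showing that the recursive factor in get_lr() telescopes to the closed form used in _get_closed_form_lr():

# Sketch with illustrative values: applying ((1 - t/T) / (1 - (t-1)/T)) ** power
# step by step reproduces the closed form min_lr + (1 - t/T) ** power * (base_lr - min_lr).
base_lr, min_lr, total_iters, power = 0.001, 0.0, 5, 1.0

lr = base_lr
for t in range(1, total_iters + 1):
    factor = ((1.0 - t / total_iters) / (1.0 - (t - 1) / total_iters)) ** power
    lr = min_lr + factor * (lr - min_lr)
    closed = min_lr + (1.0 - t / total_iters) ** power * (base_lr - min_lr)
    print(t, lr, closed)  # the two columns agree up to floating-point rounding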
Ok, I get what you mean, but I was referring to this.
Yeah, I'm putting the pieces together. I will open a PR soon.
Correct the
The API of Schedulers is a bit weird. The changes on the
Sounds good, make sure you tag me on the PR.
The scheduler has been implemented (see pytorch/pytorch#82769). All that remains is to update the segmentation training script to use the newly implemented scheduler, once a new version of PyTorch is released.
That's correct. In fact, once the scheduler makes it to the nightly, we can make the change. Not sure if it made it into today's build or if it will appear tomorrow, but you can start a PR and I'll review/test/merge it soon. Would that work for you?
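As a rough sketch of the intended change (hypothetical stand-ins for the training script's optimizer and iteration count, not the actual torchvision patch), the native scheduler would be used like this:

import torch

# Hypothetical stand-ins; in references/segmentation/train.py these come from the training setup.
params = [torch.zeros(1, requires_grad=True)]
optimizer = torch.optim.SGD(params, lr=0.01)
total_iters = 300  # e.g. iterations per epoch times number of epochs

# Native scheduler, available once pytorch/pytorch#82769 ships in a nightly/release:
lr_scheduler = torch.optim.lr_scheduler.PolynomialLR(optimizer, total_iters=total_iters, power=0.9)

for _ in range(total_iters):
    optimizer.step()
    lr_scheduler.step()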
Back when it was initially implemented in 2019, the LR scheduler in the segmentation reference scripts couldn't be implemented with native PyTorch schedulers, so we had to resort to LambdaLR:

vision/references/segmentation/train.py, lines 136 to 138 at 9275cc6
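For context, the referenced lines build a polynomial-decay schedule via LambdaLR, roughly like the sketch below (an approximate reconstruction with hypothetical stand-ins, not the exact permalink content):

import torch

# Hypothetical stand-ins for the training script's optimizer and schedule length.
model_params = [torch.zeros(1, requires_grad=True)]
optimizer = torch.optim.SGD(model_params, lr=0.01)
iters_per_epoch, epochs = 100, 30

# Roughly what the referenced LambdaLR workaround looks like:
lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lambda x: (1 - x / (iters_per_epoch * epochs)) ** 0.9
)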
It might be that this is now available in PyTorch natively, and this can be simplified.
cc @datumbox