Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix lr finder for optimizers with states #3897

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed `current_epoch` and `global_step` properties mismatch between `Trainer` and `LightningModule` ([#3785](https://github.com/PyTorchLightning/pytorch-lightning/pull/3785))

- Fixed learning rate scheduler for optimizers with internal state ([#3897](https://github.com/PyTorchLightning/pytorch-lightning/pull/3897))

## [0.9.0] - YYYY-MM-DD

### Added
Expand Down
65 changes: 40 additions & 25 deletions pytorch_lightning/tuner/lr_finder.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,12 @@
# limitations under the License.
import importlib
import os
from typing import List, Optional, Sequence, Union
from typing import List, Optional, Sequence, Union, Callable
from functools import wraps

import numpy as np
import torch
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader

Expand Down Expand Up @@ -165,13 +167,7 @@ def lr_find(
trainer.save_checkpoint(str(save_path))

# Configure optimizer and scheduler
optimizers, _, _ = trainer.init_optimizers(model)

if len(optimizers) != 1:
raise MisconfigurationException(
f'`model.configure_optimizers()` returned {len(optimizers)}, but'
' learning rate finder only works with single optimizer')
model.configure_optimizers = lr_finder._get_new_optimizer(optimizers[0])
model.configure_optimizers = lr_finder._exchange_scheduler(model.configure_optimizers)

# Fit, lr & loss logged in callback
trainer.fit(model,
Expand Down Expand Up @@ -261,28 +257,47 @@ def __init__(self, mode: str, lr_min: float, lr_max: float, num_training: int):
self.results = {}
self._total_batch_idx = 0 # for debug purpose

def _get_new_optimizer(self, optimizer: torch.optim.Optimizer):
""" Construct a new `configure_optimizers()` method, that has a optimizer
with initial lr set to lr_min and a scheduler that will either
linearly or exponentially increase the lr to lr_max in num_training steps.

Args:
optimizer: instance of `torch.optim.Optimizer`

def _exchange_scheduler(self, configure_optimizers: Callable):
""" Decorate configure_optimizers methods such that it returns the users
originally specified optimizer together with a new scheduler that
that takes care of the learning rate search.
"""
new_lrs = [self.lr_min] * len(optimizer.param_groups)
for param_group, new_lr in zip(optimizer.param_groups, new_lrs):
param_group["lr"] = new_lr
param_group["initial_lr"] = new_lr

args = (optimizer, self.lr_max, self.num_training)
scheduler = _LinearLR(*args) if self.mode == 'linear' else _ExponentialLR(*args)
@wraps(configure_optimizers)
def func():
# Decide the structure of the output from configure_optimizers
# Same logic as method `init_optimizers` in trainer/optimizers.py
optim_conf = configure_optimizers()
if isinstance(optim_conf, Optimizer):
optimizers = [optim_conf]
elif isinstance(optim_conf, (list, tuple)) and len(optim_conf) == 2 \
and isinstance(optim_conf[0], list):
optimizers, _ = optim_conf
elif isinstance(optim_conf, dict):
optimizers = [optim_conf["optimizer"]]
elif isinstance(optim_conf, (list, tuple)) and isinstance(optim_conf[0], dict):
optimizers = [opt_dict["optimizer"] for opt_dict in optim_conf]
elif isinstance(optim_conf, (list, tuple)):
optimizers = [optim_conf]

if len(optimizers) != 1:
raise MisconfigurationException(
f'`model.configure_optimizers()` returned {len(optimizers)}, but'
' learning rate finder only works with single optimizer')

optimizer = optimizers[0]

new_lrs = [self.lr_min] * len(optimizer.param_groups)
for param_group, new_lr in zip(optimizer.param_groups, new_lrs):
param_group["lr"] = new_lr
param_group["initial_lr"] = new_lr

args = (optimizer, self.lr_max, self.num_training)
scheduler = _LinearLR(*args) if self.mode == 'linear' else _ExponentialLR(*args)

def configure_optimizers():
return [optimizer], [{'scheduler': scheduler,
'interval': 'step'}]

return configure_optimizers
return func

def plot(self, suggest: bool = False, show: bool = False):
""" Plot results from lr_find run
Expand Down
4 changes: 4 additions & 0 deletions tests/base/model_optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ def configure_optimizers__lbfgs(self):
optimizer = optim.LBFGS(self.parameters(), lr=self.learning_rate)
return optimizer

def configure_optimizers__adagrad(self):
optimizer = optim.Adagrad(self.parameters(), lr=self.learning_rate)
return optimizer

def configure_optimizers__multiple_optimizers(self):
"""
return whatever optimizers we want here.
Expand Down
5 changes: 4 additions & 1 deletion tests/trainer/test_lr_finder.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,11 +131,14 @@ def test_trainer_arg_str(tmpdir, use_hparams):
'Learning rate was not altered after running learning rate finder'


def test_call_to_trainer_method(tmpdir):
@pytest.mark.parametrize('optimizer', ['Adam', 'Adagrad'])
def test_call_to_trainer_method(tmpdir, optimizer):
""" Test that directly calling the trainer method works """

hparams = EvalModelTemplate.get_default_hparams()
model = EvalModelTemplate(**hparams)
if optimizer == 'adagrad':
model.configure_optimizers = model.configure_optimizers__adagrad

before_lr = hparams.get('learning_rate')
# logger file to get meta
Expand Down