Skip to content

Commit

Permalink
LearningRateLogger in multi-scheduler setting (#1944)
Browse files Browse the repository at this point in the history
* fixed undesired behaviour due to dict.fromkeys

* a test for log length consistency

* runtime-warn if no schedulers are configured

* chlog

* move

Co-authored-by: Jirka <jirka@pytorchlightning.ai>
  • Loading branch information
2 people authored and justusschock committed Jun 29, 2020
1 parent 7d71dff commit c7ac477
Show file tree
Hide file tree
Showing 4 changed files with 118 additions and 84 deletions.
4 changes: 3 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed an issue with `Trainer.from_argparse_args` when passing in unknown Trainer args ([#1932](https://github.com/PyTorchLightning/pytorch-lightning/pull/1932))

- Fix bug related to logger not being reset correctly for model after tuner algorithms ([#1933](https://github.com/PyTorchLightning/pytorch-lightning/pull/1933))
- Fixed bug related to logger not being reset correctly for model after tuner algorithms ([#1933](https://github.com/PyTorchLightning/pytorch-lightning/pull/1933))

- Fixed `LearningRateLogger` in multi-scheduler setting ([#1944](https://github.com/PyTorchLightning/pytorch-lightning/pull/1944))


## [0.7.6] - 2020-05-16
Expand Down
17 changes: 10 additions & 7 deletions pytorch_lightning/callbacks/lr_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.utilities.exceptions import MisconfigurationException

from pytorch_lightning.utilities import rank_zero_warn


class LearningRateLogger(Callback):
r"""
Expand Down Expand Up @@ -45,21 +47,22 @@ def on_train_start(self, trainer, pl_module):
schedulers in the case of multiple of the same type or in
the case of multiple parameter groups
"""
if trainer.lr_schedulers == []:
raise MisconfigurationException(
'Cannot use LearningRateLogger callback with models that have no'
' learning rate schedulers. Please see documentation for'
' `configure_optimizers` method.')

if not trainer.logger:
raise MisconfigurationException(
'Cannot use LearningRateLogger callback with Trainer that has no logger.')

if not trainer.lr_schedulers:
rank_zero_warn(
'You are using LearningRateLogger callback with models that'
' have no learning rate schedulers. Please see documentation'
' for `configure_optimizers` method.', RuntimeWarning
)

# Find names for schedulers
names = self._find_names(trainer.lr_schedulers)

# Initialize for storing values
self.lrs = dict.fromkeys(names, [])
self.lrs = {name: [] for name in names}

def on_batch_start(self, trainer, pl_module):
latest_stat = self._extract_lr(trainer, 'step')
Expand Down
79 changes: 3 additions & 76 deletions tests/callbacks/test_callbacks.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from pathlib import Path

import pytest

import tests.base.utils as tutils
from pytorch_lightning import Callback
from pytorch_lightning import Trainer, LightningModule
from pytorch_lightning.callbacks import EarlyStopping, LearningRateLogger, ModelCheckpoint
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from tests.base import EvalModelTemplate
from pathlib import Path


def test_trainer_callback_system(tmpdir):
Expand Down Expand Up @@ -281,77 +282,3 @@ def test_model_checkpoint_path(tmpdir, logger_version, expected):

ckpt_version = Path(trainer.ckpt_path).parent.name
assert ckpt_version == expected


def test_lr_logger_single_lr(tmpdir):
""" Test that learning rates are extracted and logged for single lr scheduler"""
tutils.reset_seed()

model = EvalModelTemplate()
model.configure_optimizers = model.configure_optimizers__single_scheduler

lr_logger = LearningRateLogger()
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=5,
val_percent_check=0.1,
train_percent_check=0.5,
callbacks=[lr_logger]
)
results = trainer.fit(model)

assert results == 1
assert lr_logger.lrs, 'No learning rates logged'
assert len(lr_logger.lrs) == len(trainer.lr_schedulers), \
'Number of learning rates logged does not match number of lr schedulers'
assert all([k in ['lr-Adam'] for k in lr_logger.lrs.keys()]), \
'Names of learning rates not set correctly'


def test_lr_logger_multi_lrs(tmpdir):
""" Test that learning rates are extracted and logged for multi lr schedulers """
tutils.reset_seed()

model = EvalModelTemplate()
model.configure_optimizers = model.configure_optimizers__multiple_schedulers

lr_logger = LearningRateLogger()
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
val_percent_check=0.1,
train_percent_check=0.5,
callbacks=[lr_logger]
)
results = trainer.fit(model)

assert results == 1
assert lr_logger.lrs, 'No learning rates logged'
assert len(lr_logger.lrs) == len(trainer.lr_schedulers), \
'Number of learning rates logged does not match number of lr schedulers'
assert all([k in ['lr-Adam', 'lr-Adam-1'] for k in lr_logger.lrs.keys()]), \
'Names of learning rates not set correctly'


def test_lr_logger_param_groups(tmpdir):
""" Test that learning rates are extracted and logged for single lr scheduler"""
tutils.reset_seed()

model = EvalModelTemplate()
model.configure_optimizers = model.configure_optimizers__param_groups

lr_logger = LearningRateLogger()
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=5,
val_percent_check=0.1,
train_percent_check=0.5,
callbacks=[lr_logger]
)
results = trainer.fit(model)

assert lr_logger.lrs, 'No learning rates logged'
assert len(lr_logger.lrs) == 2 * len(trainer.lr_schedulers), \
'Number of learning rates logged does not match number of param groups'
assert all([k in ['lr-Adam/pg1', 'lr-Adam/pg2'] for k in lr_logger.lrs.keys()]), \
'Names of learning rates not set correctly'
102 changes: 102 additions & 0 deletions tests/callbacks/test_lr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
import pytest

import tests.base.utils as tutils
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import LearningRateLogger
from tests.base import EvalModelTemplate


def test_lr_logger_single_lr(tmpdir):
""" Test that learning rates are extracted and logged for single lr scheduler. """
tutils.reset_seed()

model = EvalModelTemplate()
model.configure_optimizers = model.configure_optimizers__single_scheduler

lr_logger = LearningRateLogger()
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=5,
val_percent_check=0.1,
train_percent_check=0.5,
callbacks=[lr_logger]
)
result = trainer.fit(model)
assert result

assert lr_logger.lrs, 'No learning rates logged'
assert len(lr_logger.lrs) == len(trainer.lr_schedulers), \
'Number of learning rates logged does not match number of lr schedulers'
assert all([k in ['lr-Adam'] for k in lr_logger.lrs.keys()]), \
'Names of learning rates not set correctly'


def test_lr_logger_no_lr(tmpdir):
tutils.reset_seed()

model = EvalModelTemplate()

lr_logger = LearningRateLogger()
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=5,
val_percent_check=0.1,
train_percent_check=0.5,
callbacks=[lr_logger]
)

with pytest.warns(RuntimeWarning):
result = trainer.fit(model)
assert result


def test_lr_logger_multi_lrs(tmpdir):
""" Test that learning rates are extracted and logged for multi lr schedulers. """
tutils.reset_seed()

model = EvalModelTemplate()
model.configure_optimizers = model.configure_optimizers__multiple_schedulers

lr_logger = LearningRateLogger()
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=10,
val_percent_check=0.1,
train_percent_check=0.5,
callbacks=[lr_logger]
)
result = trainer.fit(model)
assert result

assert lr_logger.lrs, 'No learning rates logged'
assert len(lr_logger.lrs) == len(trainer.lr_schedulers), \
'Number of learning rates logged does not match number of lr schedulers'
assert all([k in ['lr-Adam', 'lr-Adam-1'] for k in lr_logger.lrs.keys()]), \
'Names of learning rates not set correctly'
assert all(len(lr) == trainer.max_epochs for k, lr in lr_logger.lrs.items()), \
'Length of logged learning rates exceeds the number of epochs'


def test_lr_logger_param_groups(tmpdir):
""" Test that learning rates are extracted and logged for single lr scheduler. """
tutils.reset_seed()

model = EvalModelTemplate()
model.configure_optimizers = model.configure_optimizers__param_groups

lr_logger = LearningRateLogger()
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=5,
val_percent_check=0.1,
train_percent_check=0.5,
callbacks=[lr_logger]
)
result = trainer.fit(model)
assert result

assert lr_logger.lrs, 'No learning rates logged'
assert len(lr_logger.lrs) == 2 * len(trainer.lr_schedulers), \
'Number of learning rates logged does not match number of param groups'
assert all([k in ['lr-Adam/pg1', 'lr-Adam/pg2'] for k in lr_logger.lrs.keys()]), \
'Names of learning rates not set correctly'

0 comments on commit c7ac477

Please sign in to comment.