Changed LearningRateLogger to LearningRateMonitor #3251

Merged: 8 commits, Sep 3, 2020
Changes from 4 commits
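In short, this PR renames the LearningRateLogger callback to LearningRateMonitor and keeps the old name as a deprecated alias until v0.10.0. A minimal before/after sketch of the user-facing change (the Trainer arguments are illustrative only):

    # Before (now deprecated):
    from pytorch_lightning.callbacks import LearningRateLogger
    lr_logger = LearningRateLogger(logging_interval='step')

    # After:
    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import LearningRateMonitor

    lr_monitor = LearningRateMonitor(logging_interval='step')
    trainer = Trainer(callbacks=[lr_monitor])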
docs/source/callbacks.rst (2 changes: 1 addition & 1 deletion)
@@ -120,7 +120,7 @@ Lightning has a few built-in callbacks.

----------------

-.. automodule:: pytorch_lightning.callbacks.lr_logger
+.. automodule:: pytorch_lightning.callbacks.lr_monitor
    :noindex:
    :exclude-members:
        _extract_lr,
pytorch_lightning/callbacks/__init__.py (13 changes: 7 additions & 6 deletions)
@@ -1,18 +1,19 @@
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
+from pytorch_lightning.callbacks.gpu_stats_monitor import GPUStatsMonitor
from pytorch_lightning.callbacks.gradient_accumulation_scheduler import GradientAccumulationScheduler
-from pytorch_lightning.callbacks.lr_logger import LearningRateLogger
+from pytorch_lightning.callbacks.lr_monitor import LearningRateLogger, LearningRateMonitor
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
-from pytorch_lightning.callbacks.progress import ProgressBarBase, ProgressBar
-from pytorch_lightning.callbacks.gpu_stats_monitor import GPUStatsMonitor
+from pytorch_lightning.callbacks.progress import ProgressBar, ProgressBarBase

__all__ = [
    'Callback',
    'EarlyStopping',
-    'ModelCheckpoint',
+    'GPUStatsMonitor',
    'GradientAccumulationScheduler',
-    'LearningRateLogger',
-    'ProgressBarBase',
+    'LearningRateMonitor',
+    'ModelCheckpoint',
    'ProgressBar',
-    'GPUStatsMonitor'
+    'ProgressBarBase',
]
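Note that the old import path keeps working for now: lr_monitor re-exports LearningRateLogger as a deprecated subclass of LearningRateMonitor (see the shim at the end of lr_monitor.py below). A quick sanity check, assuming this branch is installed:

    from pytorch_lightning.callbacks import LearningRateLogger, LearningRateMonitor

    # The alias subclasses the new callback, so existing isinstance/issubclass
    # checks keep passing during the deprecation window.
    assert issubclass(LearningRateLogger, LearningRateMonitor)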
pytorch_lightning/callbacks/lr_monitor.py (renamed from lr_logger.py)
@@ -14,10 +14,10 @@

r"""

-Learning Rate Logger
-====================
+Learning Rate Monitor
+=====================

-Log learning rate for lr schedulers during training
+Monitors and logs the learning rate for lr schedulers during training.

"""

@@ -28,9 +28,9 @@
from pytorch_lightning.utilities.exceptions import MisconfigurationException


-class LearningRateLogger(Callback):
+class LearningRateMonitor(Callback):
    r"""
-    Automatically logs learning rate for learning rate schedulers during training.
+    Automatically monitors and logs the learning rate for learning rate schedulers during training.

    Args:
        logging_interval: set to `epoch` or `step` to log `lr` of all optimizers
@@ -40,9 +40,9 @@ class LearningRateLogger(Callback):
    Example::

        >>> from pytorch_lightning import Trainer
-        >>> from pytorch_lightning.callbacks import LearningRateLogger
-        >>> lr_logger = LearningRateLogger(logging_interval='step')
-        >>> trainer = Trainer(callbacks=[lr_logger])
+        >>> from pytorch_lightning.callbacks import LearningRateMonitor
+        >>> lr_monitor = LearningRateMonitor(logging_interval='step')
+        >>> trainer = Trainer(callbacks=[lr_monitor])

    Logging names are automatically determined based on optimizer class name.
    In case of multiple optimizers of same type, they will be named `Adam`,
@@ -57,6 +57,7 @@ def configure_optimizer(self):
            lr_scheduler = {'scheduler': torch.optim.lr_scheduler.LambdaLR(optimizer, ...),
                            'name': 'my_logging_name'}
            return [optimizer], [lr_scheduler]
+
    """
    def __init__(self, logging_interval: Optional[str] = None):
        if logging_interval not in (None, 'step', 'epoch'):
@@ -69,18 +70,19 @@ def __init__(self, logging_interval: Optional[str] = None):
        self.lr_sch_names = []

    def on_train_start(self, trainer, pl_module):
-        """ Called before training, determines unique names for all lr
-        schedulers in the case of multiple of the same type or in
-        the case of multiple parameter groups
+        """
+        Called before training, determines unique names for all lr
+        schedulers in the case of multiple of the same type or in
+        the case of multiple parameter groups
        """
        if not trainer.logger:
            raise MisconfigurationException(
-                'Cannot use LearningRateLogger callback with Trainer that has no logger.'
+                'Cannot use LearningRateMonitor callback with Trainer that has no logger.'
            )

        if not trainer.lr_schedulers:
            rank_zero_warn(
-                'You are using LearningRateLogger callback with models that'
+                'You are using LearningRateMonitor callback with models that'
                ' have no learning rate schedulers. Please see documentation'
                ' for `configure_optimizers` method.', RuntimeWarning
            )
@@ -135,6 +137,7 @@ def _find_names(self, lr_schedulers):
            else:
                opt_name = 'lr-' + sch.optimizer.__class__.__name__
                i, name = 1, opt_name
+
                # Multiple schedulers of the same type
                while True:
                    if name not in names:
@@ -154,3 +157,10 @@ def _find_names(self, lr_schedulers):
            self.lr_sch_names.append(name)

        return names
+
+
+class LearningRateLogger(LearningRateMonitor):
+    def __init__(self, *args, **kwargs):
+        rank_zero_warn("`LearningRateLogger` is now `LearningRateMonitor`"
+                       " and this will be removed in v0.10.0", DeprecationWarning)
+        super().__init__(*args, **kwargs)
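The alias warns once at construction time via rank_zero_warn, which delegates to Python's standard warnings machinery on rank zero. A short sketch of observing the warning (assumes a single-process run):

    import warnings

    from pytorch_lightning.callbacks import LearningRateLogger

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter('always')
        callback = LearningRateLogger()  # still constructs a fully functional LearningRateMonitor

    assert any(issubclass(w.category, DeprecationWarning) for w in caught)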
test_lr_monitor.py (renamed from test_lr_logger.py)
@@ -1,109 +1,127 @@
import pytest

-import tests.base.develop_utils as tutils
from pytorch_lightning import Trainer
-from pytorch_lightning.callbacks import LearningRateLogger
+from pytorch_lightning.callbacks import LearningRateMonitor
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base import EvalModelTemplate
+import tests.base.develop_utils as tutils


-def test_lr_logger_single_lr(tmpdir):
+def test_lr_monitor_single_lr(tmpdir):
    """ Test that learning rates are extracted and logged for single lr scheduler. """
    tutils.reset_seed()

    model = EvalModelTemplate()
    model.configure_optimizers = model.configure_optimizers__single_scheduler

-    lr_logger = LearningRateLogger()
+    lr_monitor = LearningRateMonitor()
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=2,
        limit_val_batches=0.1,
        limit_train_batches=0.5,
-        callbacks=[lr_logger],
+        callbacks=[lr_monitor],
    )
    result = trainer.fit(model)
    assert result

-    assert lr_logger.lrs, 'No learning rates logged'
-    assert len(lr_logger.lrs) == len(trainer.lr_schedulers), \
+    assert lr_monitor.lrs, 'No learning rates logged'
+    assert len(lr_monitor.lrs) == len(trainer.lr_schedulers), \
        'Number of learning rates logged does not match number of lr schedulers'
-    assert all([k in ['lr-Adam'] for k in lr_logger.lrs.keys()]), \
+    assert all([k in ['lr-Adam'] for k in lr_monitor.lrs.keys()]), \
        'Names of learning rates not set correctly'


-def test_lr_logger_no_lr(tmpdir):
+def test_lr_monitor_no_lr_scheduler(tmpdir):
    tutils.reset_seed()

    model = EvalModelTemplate()

-    lr_logger = LearningRateLogger()
+    lr_monitor = LearningRateMonitor()
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=2,
        limit_val_batches=0.1,
        limit_train_batches=0.5,
-        callbacks=[lr_logger],
+        callbacks=[lr_monitor],
    )

-    with pytest.warns(RuntimeWarning):
+    with pytest.warns(RuntimeWarning, match='have no learning rate schedulers'):
        result = trainer.fit(model)
    assert result


+def test_lr_monitor_no_logger(tmpdir):
+    tutils.reset_seed()
+
+    model = EvalModelTemplate()
+
+    lr_monitor = LearningRateMonitor()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=1,
+        callbacks=[lr_monitor],
+        logger=False
+    )
+
+    with pytest.raises(MisconfigurationException, match='Trainer that has no logger'):
+        trainer.fit(model)


@pytest.mark.parametrize("logging_interval", ['step', 'epoch'])
-def test_lr_logger_multi_lrs(tmpdir, logging_interval):
+def test_lr_monitor_multi_lrs(tmpdir, logging_interval):
    """ Test that learning rates are extracted and logged for multi lr schedulers. """
    tutils.reset_seed()

    model = EvalModelTemplate()
    model.configure_optimizers = model.configure_optimizers__multiple_schedulers

-    lr_logger = LearningRateLogger(logging_interval=logging_interval)
+    lr_monitor = LearningRateMonitor(logging_interval=logging_interval)
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=2,
        limit_val_batches=0.1,
        limit_train_batches=0.5,
-        callbacks=[lr_logger],
+        callbacks=[lr_monitor],
    )
    result = trainer.fit(model)
    assert result

-    assert lr_logger.lrs, 'No learning rates logged'
-    assert len(lr_logger.lrs) == len(trainer.lr_schedulers), \
+    assert lr_monitor.lrs, 'No learning rates logged'
+    assert len(lr_monitor.lrs) == len(trainer.lr_schedulers), \
        'Number of learning rates logged does not match number of lr schedulers'
-    assert all([k in ['lr-Adam', 'lr-Adam-1'] for k in lr_logger.lrs.keys()]), \
+    assert all([k in ['lr-Adam', 'lr-Adam-1'] for k in lr_monitor.lrs.keys()]), \
        'Names of learning rates not set correctly'

    if logging_interval == 'step':
        expected_number_logged = trainer.global_step
    if logging_interval == 'epoch':
        expected_number_logged = trainer.max_epochs

-    assert all(len(lr) == expected_number_logged for lr in lr_logger.lrs.values()), \
+    assert all(len(lr) == expected_number_logged for lr in lr_monitor.lrs.values()), \
        'Length of logged learning rates do not match the expected number'


-def test_lr_logger_param_groups(tmpdir):
+def test_lr_monitor_param_groups(tmpdir):
    """ Test that learning rates are extracted and logged when the optimizer has parameter groups. """
    tutils.reset_seed()

    model = EvalModelTemplate()
    model.configure_optimizers = model.configure_optimizers__param_groups

-    lr_logger = LearningRateLogger()
+    lr_monitor = LearningRateMonitor()
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=2,
        limit_val_batches=0.1,
        limit_train_batches=0.5,
-        callbacks=[lr_logger],
+        callbacks=[lr_monitor],
    )
    result = trainer.fit(model)
    assert result

-    assert lr_logger.lrs, 'No learning rates logged'
-    assert len(lr_logger.lrs) == 2 * len(trainer.lr_schedulers), \
+    assert lr_monitor.lrs, 'No learning rates logged'
+    assert len(lr_monitor.lrs) == 2 * len(trainer.lr_schedulers), \
        'Number of learning rates logged does not match number of param groups'
-    assert all([k in ['lr-Adam/pg1', 'lr-Adam/pg2'] for k in lr_logger.lrs.keys()]), \
+    assert all([k in ['lr-Adam/pg1', 'lr-Adam/pg2'] for k in lr_monitor.lrs.keys()]), \
        'Names of learning rates not set correctly'
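The diff updates the existing tests for the new name but does not exercise the deprecation shim itself. A hypothetical follow-up test in the same style (not part of this PR) could pin that behavior:

    def test_lr_logger_deprecation_warning():
        """ Hypothetical: the deprecated alias should warn and still construct. """
        with pytest.warns(DeprecationWarning, match='LearningRateMonitor'):
            lr_logger = LearningRateLogger()
        assert isinstance(lr_logger, LearningRateMonitor)

(This sketch would also require importing LearningRateLogger in the test module.)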