Log LearningRateMonitor values to Trainer.callback_metrics for `EarlyStopping` (#17626)

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
3 people authored May 18, 2023
1 parent 2ce9758 commit 3a68493
Showing 3 changed files with 46 additions and 1 deletion.
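
Before the per-file diffs, here is a minimal usage sketch of what this commit enables: values logged by `LearningRateMonitor` (for example `lr-SGD`) are now also written to `trainer.callback_metrics`, so callbacks such as `EarlyStopping` can monitor them. This is illustrative only and mirrors the new test below; the key name `lr-SGD` assumes a single, unnamed SGD optimizer.

```python
import torch
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import EarlyStopping, LearningRateMonitor
from lightning.pytorch.demos.boring_classes import BoringModel


class Model(BoringModel):
    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.5)
        return [optimizer], [scheduler]


trainer = Trainer(
    max_epochs=5,
    limit_val_batches=0,
    callbacks=[
        # Logs "lr-SGD" once per epoch for this scheduler and, with this commit,
        # also mirrors the value into trainer.callback_metrics.
        LearningRateMonitor(),
        # Stops training once the monitored learning rate drops below 0.02.
        EarlyStopping(monitor="lr-SGD", mode="min", stopping_threshold=0.02, check_on_train_epoch_end=True),
    ],
)
trainer.fit(Model())
```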
3 changes: 3 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
@@ -54,6 +54,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added non-layer param count to the model summary ([#17005](https://github.com/Lightning-AI/lightning/pull/17005))


- Updated `LearningRateMonitor` to log monitored values to `trainer.callback_metrics` ([#17626](https://github.com/Lightning-AI/lightning/pull/17626))


### Changed

- Removed the limitation to call `self.trainer.model.parameters()` in `LightningModule.configure_optimizers()` ([#17309](https://github.com/Lightning-AI/lightning/pull/17309))
5 changes: 5 additions & 0 deletions src/lightning/pytorch/callbacks/lr_monitor.py
@@ -23,6 +23,7 @@
from collections import defaultdict
from typing import Any, DefaultDict, Dict, List, Optional, Set, Tuple, Type

import torch
from torch.optim.optimizer import Optimizer

import lightning.pytorch as pl
@@ -193,6 +194,10 @@ def _extract_stats(self, trainer: "pl.Trainer", interval: str) -> Dict[str, float]:
            current_stat = self._get_lr_momentum_stat(opt, names)
            latest_stat.update(current_stat)

        trainer.callback_metrics.update(
            {name: torch.tensor(value, device=trainer.strategy.root_device) for name, value in latest_stat.items()}
        )

        return latest_stat

    def _get_lr_momentum_stat(self, optimizer: Optimizer, names: List[str]) -> Dict[str, float]:
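
With this change in place, the monitored values are available to any callback (or user code) via `trainer.callback_metrics`, stored as tensors on `trainer.strategy.root_device`. A brief illustrative sketch, not part of the diff; the key `lr-SGD` again assumes a single, unnamed SGD optimizer:

```python
from lightning.pytorch.callbacks import Callback


class PrintLR(Callback):
    def on_train_epoch_end(self, trainer, pl_module):
        # The monitored value is a tensor placed on trainer.strategy.root_device.
        lr = trainer.callback_metrics.get("lr-SGD")
        if lr is not None:
            print(f"epoch {trainer.current_epoch}: lr = {lr.item():.6f}")
```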
39 changes: 38 additions & 1 deletion tests/tests_pytorch/callbacks/test_lr_monitor.py
@@ -16,7 +16,7 @@
from torch import optim

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import LearningRateMonitor
from lightning.pytorch.callbacks import EarlyStopping, LearningRateMonitor
from lightning.pytorch.callbacks.callback import Callback
from lightning.pytorch.callbacks.finetuning import BackboneFinetuning
from lightning.pytorch.demos.boring_classes import BoringModel
@@ -626,3 +626,40 @@ def configure_optimizers(self):
    assert list(lr_monitor.last_momentum_values) == ["lr-Adam/pg1-momentum", "lr-Adam/pg2-momentum"]
    assert all(val == momentum for val in lr_monitor.last_momentum_values.values())
    assert all(all(val == lr for val in lr_monitor.lrs[lr_key]) for lr_key in lr_monitor.lrs)


def test_lr_monitor_update_callback_metrics(tmpdir):
    """Test that the `LearningRateMonitor` callback updates trainer.callback_metrics."""

    class TestModel(BoringModel):
        def configure_optimizers(self):
            optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
            lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.5)
            return [optimizer], [lr_scheduler]

    # StepLR halves the learning rate every epoch: 0.1, 0.05, 0.025, 0.0125.
    # EarlyStopping monitors "lr-SGD" and stops once the value falls below 0.02,
    # which first happens at epoch 3 (0.0125).
    monitor_key = "lr-SGD"
    stop_threshold = 0.02
    expected_stop_epoch = 3

    lr_monitor = LearningRateMonitor()
    lr_es = EarlyStopping(
        monitor=monitor_key, mode="min", stopping_threshold=stop_threshold, check_on_train_epoch_end=True
    )
    trainer = Trainer(
        default_root_dir=tmpdir,
        callbacks=[lr_monitor, lr_es],
        max_epochs=5,
        limit_val_batches=0,
        limit_train_batches=2,
        logger=CSVLogger(tmpdir),
    )
    model = TestModel()
    trainer.fit(model)

    # The monitored key must be mirrored into callback_metrics, and training must stop
    # exactly when the learning rate crosses the stopping threshold.
    assert monitor_key in trainer.callback_metrics
    assert lr_monitor.lrs[monitor_key] == [0.1, 0.05, 0.025, 0.0125]
    assert min(lr_monitor.lrs[monitor_key][:expected_stop_epoch]) > stop_threshold
    assert len(lr_monitor.lrs[monitor_key][expected_stop_epoch:]) == 1
    assert min(lr_monitor.lrs[monitor_key][expected_stop_epoch:]) < stop_threshold
    assert trainer.current_epoch - 1 == expected_stop_epoch
    assert lr_es.stopped_epoch == expected_stop_epoch
