
Commit 3a68493

baskrahmer, awaelchli, and pre-commit-ci[bot] authored
Log LearningRateMonitor values to Trainer.callback_metrics for EarlyStopping (#17626)
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 2ce9758 commit 3a68493
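With this change, the values tracked by `LearningRateMonitor` (e.g. `lr-SGD`) are also written to `trainer.callback_metrics`, so other callbacks such as `EarlyStopping` can monitor the current learning rate. Below is a minimal usage sketch modeled on the test added in this commit; the model class name and the stopping threshold are illustrative, not part of the patch.

import torch
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import EarlyStopping, LearningRateMonitor
from lightning.pytorch.demos.boring_classes import BoringModel


class DecayingLRModel(BoringModel):
    # Hypothetical model: SGD with a StepLR schedule that halves the LR every epoch.
    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.5)
        return [optimizer], [scheduler]


# LearningRateMonitor now mirrors "lr-SGD" into trainer.callback_metrics, so
# EarlyStopping can stop the run once the scheduled LR falls below the threshold.
trainer = Trainer(
    max_epochs=10,
    callbacks=[
        LearningRateMonitor(),
        EarlyStopping(monitor="lr-SGD", mode="min", stopping_threshold=0.02, check_on_train_epoch_end=True),
    ],
)
trainer.fit(DecayingLRModel())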

3 files changed (+46 lines, -1 line)


src/lightning/pytorch/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -54,6 +54,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added non-layer param count to the model summary ([#17005](https://github.com/Lightning-AI/lightning/pull/17005))
 
 
+- Updated `LearningRateMonitor` to log monitored values to `trainer.callback_metrics` ([#17626](https://github.com/Lightning-AI/lightning/pull/17626))
+
+
 ### Changed
 
 - Removed the limitation to call `self.trainer.model.parameters()` in `LightningModule.configure_optimizers()` ([#17309](https://github.com/Lightning-AI/lightning/pull/17309))

src/lightning/pytorch/callbacks/lr_monitor.py

Lines changed: 5 additions & 0 deletions
@@ -23,6 +23,7 @@
 from collections import defaultdict
 from typing import Any, DefaultDict, Dict, List, Optional, Set, Tuple, Type
 
+import torch
 from torch.optim.optimizer import Optimizer
 
 import lightning.pytorch as pl
@@ -193,6 +194,10 @@ def _extract_stats(self, trainer: "pl.Trainer", interval: str) -> Dict[str, float]:
             current_stat = self._get_lr_momentum_stat(opt, names)
             latest_stat.update(current_stat)
 
+        trainer.callback_metrics.update(
+            {name: torch.tensor(value, device=trainer.strategy.root_device) for name, value in latest_stat.items()}
+        )
+
         return latest_stat
 
     def _get_lr_momentum_stat(self, optimizer: Optimizer, names: List[str]) -> Dict[str, float]:
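After this patch, each time `_extract_stats` runs the monitored values are also available as tensors in `trainer.callback_metrics`, under the same keys the monitor already uses for logging (e.g. `lr-SGD`, or `lr-Adam/pg1` when there are parameter groups). A rough sketch of reading one of those values from another callback; the `PrintLR` class and the `lr-SGD` key assume a single SGD optimizer and are only for illustration.

from lightning.pytorch.callbacks import Callback


class PrintLR(Callback):
    # Hypothetical callback: reads the value LearningRateMonitor mirrored into callback_metrics.
    def on_train_epoch_end(self, trainer, pl_module):
        lr = trainer.callback_metrics.get("lr-SGD")  # tensor placed on trainer.strategy.root_device
        if lr is not None:
            print(f"current learning rate: {lr.item():.4f}")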

tests/tests_pytorch/callbacks/test_lr_monitor.py

Lines changed: 38 additions & 1 deletion
@@ -16,7 +16,7 @@
 from torch import optim
 
 from lightning.pytorch import Trainer
-from lightning.pytorch.callbacks import LearningRateMonitor
+from lightning.pytorch.callbacks import EarlyStopping, LearningRateMonitor
 from lightning.pytorch.callbacks.callback import Callback
 from lightning.pytorch.callbacks.finetuning import BackboneFinetuning
 from lightning.pytorch.demos.boring_classes import BoringModel
@@ -626,3 +626,40 @@ def configure_optimizers(self):
     assert list(lr_monitor.last_momentum_values) == ["lr-Adam/pg1-momentum", "lr-Adam/pg2-momentum"]
     assert all(val == momentum for val in lr_monitor.last_momentum_values.values())
     assert all(all(val == lr for val in lr_monitor.lrs[lr_key]) for lr_key in lr_monitor.lrs)
+
+
+def test_lr_monitor_update_callback_metrics(tmpdir):
+    """Test that the `LearningRateMonitor` callback updates trainer.callback_metrics."""
+
+    class TestModel(BoringModel):
+        def configure_optimizers(self):
+            optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
+            lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.5)
+            return [optimizer], [lr_scheduler]
+
+    monitor_key = "lr-SGD"
+    stop_threshold = 0.02
+    expected_stop_epoch = 3
+
+    lr_monitor = LearningRateMonitor()
+    lr_es = EarlyStopping(
+        monitor=monitor_key, mode="min", stopping_threshold=stop_threshold, check_on_train_epoch_end=True
+    )
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        callbacks=[lr_monitor, lr_es],
+        max_epochs=5,
+        limit_val_batches=0,
+        limit_train_batches=2,
+        logger=CSVLogger(tmpdir),
+    )
+    model = TestModel()
+    trainer.fit(model)
+
+    assert monitor_key in trainer.callback_metrics
+    assert lr_monitor.lrs[monitor_key] == [0.1, 0.05, 0.025, 0.0125]
+    assert min(lr_monitor.lrs[monitor_key][:expected_stop_epoch]) > stop_threshold
+    assert len(lr_monitor.lrs[monitor_key][expected_stop_epoch:]) == 1
+    assert min(lr_monitor.lrs[monitor_key][expected_stop_epoch:]) < stop_threshold
+    assert trainer.current_epoch - 1 == expected_stop_epoch
+    assert lr_es.stopped_epoch == expected_stop_epoch
