@@ -16,7 +16,7 @@
 from torch import optim

 from lightning.pytorch import Trainer
-from lightning.pytorch.callbacks import LearningRateMonitor
+from lightning.pytorch.callbacks import EarlyStopping, LearningRateMonitor
 from lightning.pytorch.callbacks.callback import Callback
 from lightning.pytorch.callbacks.finetuning import BackboneFinetuning
 from lightning.pytorch.demos.boring_classes import BoringModel
@@ -626,3 +626,40 @@ def configure_optimizers(self):
     assert list(lr_monitor.last_momentum_values) == ["lr-Adam/pg1-momentum", "lr-Adam/pg2-momentum"]
     assert all(val == momentum for val in lr_monitor.last_momentum_values.values())
     assert all(all(val == lr for val in lr_monitor.lrs[lr_key]) for lr_key in lr_monitor.lrs)
+
+
+def test_lr_monitor_update_callback_metrics(tmpdir):
+    """Test that the `LearningRateMonitor` callback updates trainer.callback_metrics."""
+
+    class TestModel(BoringModel):
+        def configure_optimizers(self):
+            optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
+            lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.5)
+            return [optimizer], [lr_scheduler]
+
+    monitor_key = "lr-SGD"
+    stop_threshold = 0.02
+    expected_stop_epoch = 3
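+    # StepLR with gamma=0.5 halves the learning rate each epoch (0.1 -> 0.05 -> 0.025 -> 0.0125),
+    # so the monitored value first drops below stop_threshold (0.02) at epoch 3.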
+
+    lr_monitor = LearningRateMonitor()
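+    # EarlyStopping can monitor "lr-SGD" only because LearningRateMonitor publishes the
+    # logged learning rates to trainer.callback_metrics.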
+    lr_es = EarlyStopping(
+        monitor=monitor_key, mode="min", stopping_threshold=stop_threshold, check_on_train_epoch_end=True
+    )
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        callbacks=[lr_monitor, lr_es],
+        max_epochs=5,
+        limit_val_batches=0,
+        limit_train_batches=2,
+        logger=CSVLogger(tmpdir),
+    )
+    model = TestModel()
+    trainer.fit(model)
+
+    assert monitor_key in trainer.callback_metrics
+    assert lr_monitor.lrs[monitor_key] == [0.1, 0.05, 0.025, 0.0125]
+    assert min(lr_monitor.lrs[monitor_key][:expected_stop_epoch]) > stop_threshold
+    assert len(lr_monitor.lrs[monitor_key][expected_stop_epoch:]) == 1
+    assert min(lr_monitor.lrs[monitor_key][expected_stop_epoch:]) < stop_threshold
+    assert trainer.current_epoch - 1 == expected_stop_epoch
+    assert lr_es.stopped_epoch == expected_stop_epoch