Skip to content

Commit

Permalink
Fix restarting attribute for lr finder (#15620)
Browse files Browse the repository at this point in the history
  • Loading branch information
justusschock authored Dec 8, 2022
1 parent d0b101c commit 15184c6
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 7 deletions.
4 changes: 4 additions & 0 deletions src/pytorch_lightning/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Removed deprecated `pytorch_lightning.utilities.memory.get_gpu_memory_map` in favor of `pytorch_lightning.accelerators.cuda.get_nvidia_gpu_stats` ([#15617](https://github.com/Lightning-AI/lightning/pull/15617))


- Temporarily removed support for Hydra multi-run ([#15737](https://github.com/Lightning-AI/lightning/pull/15737))


Expand All @@ -87,6 +88,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed issue with unsupported torch.inference_mode() on hpu backends ([#15918](https://github.com/Lightning-AI/lightning/pull/15918))

- Fixed `fit_loop.restarting` to be `False` for lr finder ([#15620](https://github.com/Lightning-AI/lightning/pull/15620))


## [1.8.3] - 2022-11-22

Expand All @@ -104,6 +107,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed the automatic fallback from `Trainer(strategy="ddp_spawn", ...)` to `Trainer(strategy="ddp", ...)` when on an LSF cluster ([#15103](https://github.com/PyTorchLightning/pytorch-lightning/issues/15103))



## [1.8.1] - 2022-11-10

### Added
Expand Down
2 changes: 1 addition & 1 deletion src/pytorch_lightning/callbacks/lr_finder.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def __init__(
max_lr: float = 1,
num_training_steps: int = 100,
mode: str = "exponential",
early_stop_threshold: float = 4.0,
early_stop_threshold: Optional[float] = 4.0,
update_attr: bool = False,
) -> None:
mode = mode.lower()
Expand Down
15 changes: 9 additions & 6 deletions src/pytorch_lightning/tuner/lr_finder.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ def lr_find(
max_lr: float = 1,
num_training: int = 100,
mode: str = "exponential",
early_stop_threshold: float = 4.0,
early_stop_threshold: Optional[float] = 4.0,
update_attr: bool = False,
) -> Optional[_LRFinder]:
"""See :meth:`~pytorch_lightning.tuner.tuning.Tuner.lr_find`"""
Expand All @@ -225,6 +225,8 @@ def lr_find(
ckpt_path = trainer.strategy.broadcast(ckpt_path)
trainer.save_checkpoint(ckpt_path)

start_steps = trainer.global_step

# Arguments we adjust during the lr finder, save for restoring
params = __lr_finder_dump_params(trainer)

Expand All @@ -245,7 +247,7 @@ def lr_find(
_try_loop_run(trainer, params)

# Prompt if we stopped early
if trainer.global_step != num_training:
if trainer.global_step != num_training + start_steps:
log.info(f"LR finder stopped early after {trainer.global_step} steps due to diverging loss.")

# Transfer results from callback to lr finder object
Expand All @@ -270,6 +272,7 @@ def lr_find(
# Restore initial state of model
trainer._checkpoint_connector.restore(ckpt_path)
trainer.strategy.remove_checkpoint(ckpt_path)
trainer.fit_loop.restarting = False # reset restarting flag as checkpoint restoring sets it to True

return lr_finder

Expand All @@ -289,7 +292,7 @@ def __lr_finder_dump_params(trainer: "pl.Trainer") -> Dict[str, Any]:
}


def __lr_finder_reset_params(trainer: "pl.Trainer", num_training: int, early_stop_threshold: float) -> None:
def __lr_finder_reset_params(trainer: "pl.Trainer", num_training: int, early_stop_threshold: Optional[float]) -> None:
from pytorch_lightning.loggers.logger import DummyLogger

trainer.strategy.lr_scheduler_configs = []
Expand All @@ -300,8 +303,8 @@ def __lr_finder_reset_params(trainer: "pl.Trainer", num_training: int, early_sto
trainer.callbacks = [_LRCallback(num_training, early_stop_threshold, progress_bar_refresh_rate=1)]
# No logging
trainer.logger = DummyLogger() if trainer.logger is not None else None
# Max step set to number of iterations
trainer.fit_loop.max_steps = num_training
# Max step set to number of iterations starting at current number of iterations
trainer.fit_loop.max_steps = num_training + trainer.global_step
trainer.limit_val_batches = num_training


Expand Down Expand Up @@ -340,7 +343,7 @@ class _LRCallback(Callback):
def __init__(
self,
num_training: int,
early_stop_threshold: float = 4.0,
early_stop_threshold: Optional[float] = 4.0,
progress_bar_refresh_rate: int = 0,
beta: float = 0.98,
):
Expand Down
47 changes: 47 additions & 0 deletions tests/tests_pytorch/tuner/test_lr_finder.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,6 +441,53 @@ def test_if_lr_finder_callback_already_configured():
trainer.tune(model)


def test_lr_finder_callback_restarting(tmpdir):
"""Test that `LearningRateFinder` does not set restarting=True when loading checkpoint."""

num_lr_steps = 100

class MyBoringModel(BoringModel):
def __init__(self):
super().__init__()
self.learning_rate = 0.123

def on_train_batch_start(self, batch, batch_idx):
if getattr(self, "_expected_max_steps", None) is not None:
assert self.trainer.fit_loop.max_steps == self._expected_max_steps

def configure_optimizers(self):
return torch.optim.SGD(self.parameters(), lr=self.learning_rate)

class CustomLearningRateFinder(LearningRateFinder):
milestones = (1,)

def lr_find(self, trainer, pl_module) -> None:
pl_module._expected_max_steps = trainer.global_step + self._num_training_steps
super().lr_find(trainer, pl_module)
pl_module._expected_max_steps = None
assert not trainer.fit_loop.restarting

def on_train_epoch_start(self, trainer, pl_module):
if trainer.current_epoch in self.milestones or trainer.current_epoch == 0:
self.lr_find(trainer, pl_module)

model = MyBoringModel()
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=3,
callbacks=[
CustomLearningRateFinder(early_stop_threshold=None, update_attr=True, num_training_steps=num_lr_steps)
],
limit_train_batches=10,
limit_val_batches=0,
limit_test_batches=0,
num_sanity_val_steps=0,
enable_model_summary=False,
)

trainer.fit(model)


@mock.patch.dict(os.environ, os.environ.copy(), clear=True)
@RunIf(standalone=True)
def test_lr_finder_with_ddp(tmpdir):
Expand Down

0 comments on commit 15184c6

Please sign in to comment.