Skip tuner algorithms on fast dev (#3903)
* skip on fast dev

* fix error

* changelog

* fix recursive issue

* combine tests

* pep8

* move logic to base funcs

* fix mistake

* Update pytorch_lightning/tuner/lr_finder.py

Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>

* pep

Co-authored-by: William Falcon <waf2107@columbia.edu>
Co-authored-by: Nicki Skafte <nugginea@gmail.com>
Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
Co-authored-by: chaton <thomas@grid.ai>

(cherry picked from commit 4f3160b)
SkafteNicki authored and SeanNaren committed Nov 10, 2020
1 parent 243b0e8 commit 96362af
Showing 3 changed files with 34 additions and 0 deletions.
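
For context, a minimal sketch of the user-visible behavior this commit adds (reusing the repo's EvalModelTemplate test model, as the new test below does); this is an illustration, not part of the diff: with fast_dev_run=True, trainer.tune() warns and returns early instead of running the enabled tuner algorithms.

from pytorch_lightning import Trainer
from tests.base import EvalModelTemplate

model = EvalModelTemplate(**EvalModelTemplate.get_default_hparams())

# Both tuner algorithms are requested, but fast_dev_run short-circuits each of them:
# a UserWarning ("Skipping ... since `fast_dev_run=True`") is emitted and the call returns early.
trainer = Trainer(
    fast_dev_run=True,
    auto_scale_batch_size=True,
    auto_lr_find=True,
)
trainer.tune(model)
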
4 changes: 4 additions & 0 deletions pytorch_lightning/tuner/batch_size_scaling.py
@@ -69,6 +69,10 @@ def scale_batch_size(trainer,
        **fit_kwargs: remaining arguments to be passed to .fit(), e.g., dataloader
            or datamodule.
    """
    if trainer.fast_dev_run:
        rank_zero_warn('Skipping batch size scaler since `fast_dev_run=True`', UserWarning)
        return

    if not lightning_hasattr(model, batch_arg_name):
        raise MisconfigurationException(
            f'Field {batch_arg_name} not found in both `model` and `model.hparams`')
9 changes: 9 additions & 0 deletions pytorch_lightning/tuner/lr_finder.py
@@ -29,6 +29,7 @@
from pytorch_lightning.loggers.base import DummyLogger
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.parsing import lightning_hasattr, lightning_setattr
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.cloud_io import get_filesystem

# check if ipywidgets is installed before importing tqdm.auto
@@ -42,6 +43,10 @@
def _run_lr_finder_internally(trainer, model: LightningModule):
    """ Call lr finder internally during Trainer.fit() """
    lr_finder = lr_find(trainer, model)

    if lr_finder is None:
        return

    lr = lr_finder.suggestion()

    # TODO: log lr.results to self.logger
@@ -131,6 +136,10 @@ def lr_find(
        trainer.fit(model)
    """
    if trainer.fast_dev_run:
        rank_zero_warn('Skipping learning rate finder since `fast_dev_run=True`', UserWarning)
        return

    save_path = os.path.join(trainer.default_root_dir, 'lr_find_temp_model.ckpt')

    __lr_finder_dump_params(trainer, model)
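
A note on the design choice above: because lr_find now returns None under fast_dev_run, any caller that uses its result should guard against None, just as _run_lr_finder_internally does. A minimal sketch, assuming the module-level lr_find defined in the file shown above and the repo's EvalModelTemplate:

from pytorch_lightning import Trainer
from pytorch_lightning.tuner.lr_finder import lr_find
from tests.base import EvalModelTemplate

model = EvalModelTemplate(**EvalModelTemplate.get_default_hparams())
trainer = Trainer(fast_dev_run=True)

# Skipped under fast_dev_run: lr_find warns and returns None,
# so check the result before calling .suggestion() on it.
lr_finder = lr_find(trainer, model)
if lr_finder is not None:
    print(lr_finder.suggestion())
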
21 changes: 21 additions & 0 deletions tests/trainer/flags/test_fast_dev_run.py
@@ -0,0 +1,21 @@
import pytest
from pytorch_lightning import Trainer
from tests.base import EvalModelTemplate


@pytest.mark.parametrize('tuner_alg', ['batch size scaler', 'learning rate finder'])
def test_skip_on_fast_dev_run_batch_scaler(tmpdir, tuner_alg):
""" Test that tuner algorithms are skipped if fast dev run is enabled """

hparams = EvalModelTemplate.get_default_hparams()
model = EvalModelTemplate(**hparams)
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=2,
auto_scale_batch_size=True if tuner_alg == 'batch size scaler' else False,
auto_lr_find=True if tuner_alg == 'learning rate finder' else False,
fast_dev_run=True
)
expected_message = f'Skipping {tuner_alg} since `fast_dev_run=True`'
with pytest.warns(UserWarning, match=expected_message):
trainer.tune(model)
