diff --git a/CHANGELOG.md b/CHANGELOG.md index 37945300774d5..f809e66c6b7ad 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -546,6 +546,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed issue with non-init dataclass fields in `apply_to_collection` ([#9963](https://github.com/PyTorchLightning/pytorch-lightning/issues/9963)) +- Reset `val_dataloader` in `tuner/batch_size_scaling` for binsearch ([#9975](https://github.com/PyTorchLightning/pytorch-lightning/pull/9975)) + ## [1.4.9] - 2021-09-30 diff --git a/pytorch_lightning/tuner/batch_size_scaling.py b/pytorch_lightning/tuner/batch_size_scaling.py index d3fd0822aa39f..42f9ce084a43c 100644 --- a/pytorch_lightning/tuner/batch_size_scaling.py +++ b/pytorch_lightning/tuner/batch_size_scaling.py @@ -205,6 +205,7 @@ def _run_binsearch_scaling( if changed: # Force the train dataloader to reset as the batch size has changed trainer.reset_train_dataloader(model) + trainer.reset_val_dataloader(model) else: break diff --git a/tests/tuner/test_scale_batch_size.py b/tests/tuner/test_scale_batch_size.py index 5e4d1af1277c7..9dbb24d9edf30 100644 --- a/tests/tuner/test_scale_batch_size.py +++ b/tests/tuner/test_scale_batch_size.py @@ -274,10 +274,11 @@ def __init__(self): trainer.tuner.scale_batch_size(model, mode="ThisModeDoesNotExist") -def test_dataloader_reset_with_scale_batch_size(tmpdir): +@pytest.mark.parametrize("scale_method", ["power", "binsearch"]) +def test_dataloader_reset_with_scale_batch_size(tmpdir, scale_method): """Test that train and val dataloaders are reset at every update in scale batch size.""" model = BatchSizeModel(batch_size=16) - scale_batch_size_kwargs = {"max_trials": 5, "init_val": 4} + scale_batch_size_kwargs = {"max_trials": 5, "init_val": 4, "mode": scale_method} trainer = Trainer(max_epochs=2, auto_scale_batch_size=True) new_batch_size = trainer.tune(model, scale_batch_size_kwargs=scale_batch_size_kwargs)["scale_batch_size"]