diff --git a/CHANGELOG.md b/CHANGELOG.md
index 4e30bb038e696..a21c00c22051c 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -470,6 +470,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed missing arguments when saving hyperparameters from the parent class but not from the child class ([#9800](https://github.com/PyTorchLightning/pytorch-lightning/pull/9800))
 
 
+- Reset `val_dataloader` in `tuner/batch_size_scaling` ([#9857](https://github.com/PyTorchLightning/pytorch-lightning/pull/9857))
+
+
 ## [1.4.9] - 2021-09-30
 
 - Fixed `lr_find` to generate same results on multiple calls ([#9704](https://github.com/PyTorchLightning/pytorch-lightning/pull/9704))
diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
index 85323e92dc7e5..8b0a59a45b84e 100644
--- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
+++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
@@ -181,7 +181,7 @@ def update_eval_epoch_metrics(self) -> List[_OUT_DICT]:
 
         # log results of evaluation
         if (
-            self.trainer.state.fn != TrainerFn.FITTING
+            self.trainer.state.fn not in (TrainerFn.FITTING, TrainerFn.TUNING)
             and self.trainer.evaluating
             and self.trainer.is_global_zero
             and self.trainer.verbose_evaluate
diff --git a/pytorch_lightning/tuner/batch_size_scaling.py b/pytorch_lightning/tuner/batch_size_scaling.py
index 0ecc983994afd..8316793ea01b0 100644
--- a/pytorch_lightning/tuner/batch_size_scaling.py
+++ b/pytorch_lightning/tuner/batch_size_scaling.py
@@ -166,6 +166,7 @@ def _run_power_scaling(
         if changed:
             # Force the train dataloader to reset as the batch size has changed
             trainer.reset_train_dataloader(model)
+            trainer.reset_val_dataloader(model)
         else:
             break
     return new_size
diff --git a/tests/tuner/test_scale_batch_size.py b/tests/tuner/test_scale_batch_size.py
index 5786bae594265..32b6f1db41ac9 100644
--- a/tests/tuner/test_scale_batch_size.py
+++ b/tests/tuner/test_scale_batch_size.py
@@ -48,6 +48,9 @@ def __init__(self, batch_size):
     def train_dataloader(self):
         return DataLoader(RandomDataset(32, 64), batch_size=getattr(self, "batch_size", 1))
 
+    def val_dataloader(self):
+        return DataLoader(RandomDataset(32, 64), batch_size=getattr(self, "batch_size", 1))
+
 
 @pytest.mark.parametrize(["model_bs", "dm_bs"], [(2, -1), (2, 2), (2, None), (None, 2), (16, 16)])
 def test_scale_batch_size_method_with_model_or_datamodule(tmpdir, model_bs, dm_bs):
@@ -266,3 +269,15 @@ def __init__(self):
     trainer.tune(model)
     with pytest.raises(ValueError, match="could either be `power` or `binsearch`"):
         trainer.tuner.scale_batch_size(model, mode="ThisModeDoesNotExist")
+
+
+def test_dataloader_reset_with_scale_batch_size(tmpdir):
+    """Test that train and val dataloaders are reset at every update in scale batch size."""
+    model = BatchSizeModel(batch_size=16)
+    scale_batch_size_kwargs = {"max_trials": 5, "init_val": 4}
+
+    trainer = Trainer(max_epochs=2, auto_scale_batch_size=True)
+    new_batch_size = trainer.tune(model, scale_batch_size_kwargs=scale_batch_size_kwargs)["scale_batch_size"]
+
+    assert trainer.train_dataloader.loaders.batch_size == new_batch_size
+    assert trainer.val_dataloaders[0].batch_size == new_batch_size