From c54227e3ef24942e5067e9924b5992975c36b7d3 Mon Sep 17 00:00:00 2001
From: rohitgr7
Date: Thu, 7 Oct 2021 19:37:56 +0530
Subject: [PATCH 1/2] reset val

---
 .../logger_connector/logger_connector.py       |  2 +-
 pytorch_lightning/tuner/batch_size_scaling.py  |  1 +
 tests/tuner/test_scale_batch_size.py           | 15 +++++++++++++++
 3 files changed, 17 insertions(+), 1 deletion(-)

diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
index 85323e92dc7e5..8b0a59a45b84e 100644
--- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
+++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
@@ -181,7 +181,7 @@ def update_eval_epoch_metrics(self) -> List[_OUT_DICT]:
 
         # log results of evaluation
         if (
-            self.trainer.state.fn != TrainerFn.FITTING
+            self.trainer.state.fn not in (TrainerFn.FITTING, TrainerFn.TUNING)
             and self.trainer.evaluating
             and self.trainer.is_global_zero
             and self.trainer.verbose_evaluate
diff --git a/pytorch_lightning/tuner/batch_size_scaling.py b/pytorch_lightning/tuner/batch_size_scaling.py
index 0ecc983994afd..8316793ea01b0 100644
--- a/pytorch_lightning/tuner/batch_size_scaling.py
+++ b/pytorch_lightning/tuner/batch_size_scaling.py
@@ -166,6 +166,7 @@ def _run_power_scaling(
         if changed:
             # Force the train dataloader to reset as the batch size has changed
             trainer.reset_train_dataloader(model)
+            trainer.reset_val_dataloader(model)
         else:
             break
     return new_size
diff --git a/tests/tuner/test_scale_batch_size.py b/tests/tuner/test_scale_batch_size.py
index 5786bae594265..32b6f1db41ac9 100644
--- a/tests/tuner/test_scale_batch_size.py
+++ b/tests/tuner/test_scale_batch_size.py
@@ -48,6 +48,9 @@ def __init__(self, batch_size):
     def train_dataloader(self):
         return DataLoader(RandomDataset(32, 64), batch_size=getattr(self, "batch_size", 1))
 
+    def val_dataloader(self):
+        return DataLoader(RandomDataset(32, 64), batch_size=getattr(self, "batch_size", 1))
+
 
 @pytest.mark.parametrize(["model_bs", "dm_bs"], [(2, -1), (2, 2), (2, None), (None, 2), (16, 16)])
 def test_scale_batch_size_method_with_model_or_datamodule(tmpdir, model_bs, dm_bs):
@@ -266,3 +269,15 @@ def __init__(self):
         trainer.tune(model)
     with pytest.raises(ValueError, match="could either be `power` or `binsearch`"):
         trainer.tuner.scale_batch_size(model, mode="ThisModeDoesNotExist")
+
+
+def test_dataloader_reset_with_scale_batch_size(tmpdir):
+    """Test that train and val dataloaders are reset at every update in scale batch size."""
+    model = BatchSizeModel(batch_size=16)
+    scale_batch_size_kwargs = {"max_trials": 5, "init_val": 4}
+
+    trainer = Trainer(max_epochs=2, auto_scale_batch_size=True)
+    new_batch_size = trainer.tune(model, scale_batch_size_kwargs=scale_batch_size_kwargs)["scale_batch_size"]
+
+    assert trainer.train_dataloader.loaders.batch_size == new_batch_size
+    assert trainer.val_dataloaders[0].batch_size == new_batch_size

From 801e22155c41c7912cf1fe8e20fb342ea401ec0e Mon Sep 17 00:00:00 2001
From: rohitgr7
Date: Thu, 7 Oct 2021 19:44:11 +0530
Subject: [PATCH 2/2] chlog

---
 CHANGELOG.md | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 01076b46ec073..252b1d7cb983d 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -465,6 +465,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed missing arguments when saving hyperparameters from the parent class but not from the child class ([#9800](https://github.com/PyTorchLightning/pytorch-lightning/pull/9800))
 
+- Reset `val_dataloader` in `tuner/batch_size_scaling` ([#9857](https://github.com/PyTorchLightning/pytorch-lightning/pull/9857))
+
+
 ## [1.4.9] - 2021-09-30
 
 - Fixed `lr_find` to generate same results on multiple calls ([#9704](https://github.com/PyTorchLightning/pytorch-lightning/pull/9704))
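
For reviewers who want to reproduce the fix outside the test suite, here is a minimal standalone sketch. It assumes the 1.4-era tuner API exercised by this patch (`auto_scale_batch_size`, `trainer.tune`, `scale_batch_size_kwargs`); `RandomDataset` and `TuneModel` are local stand-ins modeled on the repository's test helpers, not imports from it.

```python
# Standalone sketch (not part of the patch): shows that after this fix the
# tuner resets *both* dataloaders, so each one picks up the tuned batch size.
import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer


class RandomDataset(Dataset):
    """Stand-in for the test helper: `length` samples of width `size`."""

    def __init__(self, size: int, length: int):
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return len(self.data)


class TuneModel(LightningModule):
    """Exposes `self.batch_size`, which `scale_batch_size` reads and updates."""

    def __init__(self, batch_size: int = 2):
        super().__init__()
        self.batch_size = batch_size
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        return self.layer(batch).sum()

    def validation_step(self, batch, batch_idx):
        self.layer(batch)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def train_dataloader(self):
        return DataLoader(RandomDataset(32, 64), batch_size=self.batch_size)

    def val_dataloader(self):
        # Before this patch, this loader kept its stale batch size after tuning.
        return DataLoader(RandomDataset(32, 64), batch_size=self.batch_size)


model = TuneModel()
trainer = Trainer(max_epochs=1, auto_scale_batch_size=True)
result = trainer.tune(model, scale_batch_size_kwargs={"max_trials": 3, "init_val": 4})
new_batch_size = result["scale_batch_size"]

# With `reset_val_dataloader` added to `_run_power_scaling`, both loaders
# now agree on the tuned value, mirroring the assertions in the new test.
assert trainer.train_dataloader.loaders.batch_size == new_batch_size
assert trainer.val_dataloaders[0].batch_size == new_batch_size
```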