From 0c958c5a1f8ca8926eeaf1a2035725c15417830e Mon Sep 17 00:00:00 2001
From: "Xinyao(Alvin) Sun"
Date: Mon, 24 May 2021 02:21:45 -0600
Subject: [PATCH] Fix dataloaders are not reset when tuning the model (#7566)

Co-authored-by: Carlos Mocholi
---
 CHANGELOG.md                                  |  2 +
 pytorch_lightning/tuner/batch_size_scaling.py | 10 +++-
 tests/tuner/test_scale_batch_size.py          | 47 +++++++++++--------
 3 files changed, 38 insertions(+), 21 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 35089c9f993a1..1239f349e8f5f 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -120,6 +120,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
+- Fixed dataloaders are not reset when tuning the model ([#7566](https://github.com/PyTorchLightning/pytorch-lightning/pull/7566))
+
 - Fixed parsing of multiple training dataloaders ([#7433](https://github.com/PyTorchLightning/pytorch-lightning/pull/7433))
 
 
diff --git a/pytorch_lightning/tuner/batch_size_scaling.py b/pytorch_lightning/tuner/batch_size_scaling.py
index 120a95a5084b1..d114c36a60104 100644
--- a/pytorch_lightning/tuner/batch_size_scaling.py
+++ b/pytorch_lightning/tuner/batch_size_scaling.py
@@ -160,7 +160,10 @@ def _run_power_scaling(
             else:
                 raise  # some other error not memory related
 
-        if not changed:
+        if changed:
+            # Force the train dataloader to reset as the batch size has changed
+            trainer.reset_train_dataloader(model)
+        else:
             break
 
     return new_size
@@ -192,7 +195,10 @@ def _run_binsearch_scaling(
             else:
                 new_size, changed = _adjust_batch_size(trainer, batch_arg_name, factor=2.0, desc='succeeded')
 
-            if not changed:
+            if changed:
+                # Force the train dataloader to reset as the batch size has changed
+                trainer.reset_train_dataloader(model)
+            else:
                 break
 
         except RuntimeError as exception:
diff --git a/tests/tuner/test_scale_batch_size.py b/tests/tuner/test_scale_batch_size.py
index 7d4e05000d5da..f9e132662b220 100644
--- a/tests/tuner/test_scale_batch_size.py
+++ b/tests/tuner/test_scale_batch_size.py
@@ -24,14 +24,14 @@
 from pytorch_lightning.utilities import AMPType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.base import EvalModelTemplate
-from tests.helpers import BoringDataModule, BoringModel
+from tests.helpers import BoringDataModule, BoringModel, RandomDataset
 from tests.helpers.datamodules import MNISTDataModule
 from tests.helpers.runif import RunIf
 
 
 class BatchSizeDataModule(BoringDataModule):
 
-    def __init__(self, batch_size=None):
+    def __init__(self, batch_size):
         super().__init__()
         if batch_size is not None:
             self.batch_size = batch_size
@@ -42,21 +42,23 @@ def train_dataloader(self):
 
 class BatchSizeModel(BoringModel):
 
-    def __init__(self, batch_size=None):
+    def __init__(self, batch_size):
         super().__init__()
         if batch_size is not None:
             self.batch_size = batch_size
 
+    def train_dataloader(self):
+        return DataLoader(RandomDataset(32, 64), batch_size=getattr(self, "batch_size", 1))
 
-@pytest.mark.parametrize(
-    "model,datamodule", [
-        (BatchSizeModel(2), None),
-        (BatchSizeModel(2), BatchSizeDataModule(2)),
-        (BatchSizeModel(2), BatchSizeDataModule(None)),
-        (BatchSizeModel(None), BatchSizeDataModule(2)),
-    ]
-)
-def test_scale_batch_size_method_with_model_or_datamodule(tmpdir, model, datamodule):
+
+@pytest.mark.parametrize(["model_bs", "dm_bs"], [
+    (2, -1),
+    (2, 2),
+    (2, None),
+    (None, 2),
+    (16, 16),
+])
+def test_scale_batch_size_method_with_model_or_datamodule(tmpdir, model_bs, dm_bs):
     """ Test the tuner method `Tuner.scale_batch_size` with a datamodule. """
     trainer = Trainer(
         default_root_dir=tmpdir,
@@ -65,14 +67,21 @@ def test_scale_batch_size_method_with_model_or_datamod
         max_epochs=1,
     )
     tuner = Tuner(trainer)
-    new_batch_size = tuner.scale_batch_size(
-        model=model, mode="binsearch", init_val=4, max_trials=2, datamodule=datamodule
-    )
+
+    model = BatchSizeModel(model_bs)
+    datamodule = BatchSizeDataModule(dm_bs) if dm_bs != -1 else None
+
+    new_batch_size = tuner.scale_batch_size(model, mode="binsearch", init_val=4, max_trials=2, datamodule=datamodule)
     assert new_batch_size == 16
-    if hasattr(model, "batch_size"):
-        assert model.batch_size == 16
-    if datamodule is not None and hasattr(datamodule, "batch_size"):
-        assert datamodule.batch_size == 16
+
+    if model_bs is not None:
+        assert model.batch_size == new_batch_size
+        if dm_bs == -1:
+            # datamodule batch size takes precedence
+            assert trainer.train_dataloader.loaders.batch_size == new_batch_size
+    if dm_bs not in (-1, None):
+        assert datamodule.batch_size == new_batch_size
+        assert trainer.train_dataloader.loaders.batch_size == new_batch_size
 
 
 def test_model_reset_correctly(tmpdir):