Commit

Fix dataloaders are not reset when tuning the model (#7566)
Co-authored-by: Carlos Mocholi <carlossmocholi@gmail.com>
Lucklyric and carmocca authored May 24, 2021
1 parent 299f2c4 commit 0c958c5
Showing 3 changed files with 38 additions and 21 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -120,6 +120,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ### Fixed
 
+- Fixed dataloaders are not reset when tuning the model ([#7566](https://github.com/PyTorchLightning/pytorch-lightning/pull/7566))
+
 
 - Fixed parsing of multiple training dataloaders ([#7433](https://github.com/PyTorchLightning/pytorch-lightning/pull/7433))

10 changes: 8 additions & 2 deletions pytorch_lightning/tuner/batch_size_scaling.py
@@ -160,7 +160,10 @@ def _run_power_scaling(
             else:
                 raise  # some other error not memory related
 
-        if not changed:
+        if changed:
+            # Force the train dataloader to reset as the batch size has changed
+            trainer.reset_train_dataloader(model)
+        else:
             break
     return new_size

@@ -192,7 +195,10 @@ def _run_binsearch_scaling(
             else:
                 new_size, changed = _adjust_batch_size(trainer, batch_arg_name, factor=2.0, desc='succeeded')
 
-            if not changed:
+            if changed:
+                # Force the train dataloader to reset as the batch size has changed
+                trainer.reset_train_dataloader(model)
+            else:
                 break
 
         except RuntimeError as exception:
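For context, the pattern enforced by this change can be seen in isolation in the sketch below. This is not Lightning's actual implementation, only a minimal stand-alone illustration: each successful trial doubles the batch size, and whenever the size changed the dataloader has to be rebuilt (the role played by trainer.reset_train_dataloader(model) above), otherwise every trial keeps iterating batches of the original size. The scale_batch_size_sketch helper and the use of a plain RuntimeError as a stand-in for a CUDA OOM error are illustrative assumptions, not code from this repository.

import torch
from torch.utils.data import DataLoader, TensorDataset

def scale_batch_size_sketch(dataset, init_val=4, max_trials=3):
    # hypothetical helper, not part of Lightning
    new_size = init_val
    loader = DataLoader(dataset, batch_size=new_size)
    for _ in range(max_trials):
        try:
            next(iter(loader))      # stand-in for running one training trial
            new_size, changed = new_size * 2, True
        except RuntimeError:        # stand-in for an out-of-memory error
            new_size, changed = max(new_size // 2, 1), False
        if changed:
            # equivalent of trainer.reset_train_dataloader(model): rebuild the
            # loader so the next trial actually uses the new batch size
            loader = DataLoader(dataset, batch_size=new_size)
        else:
            break
    return new_size

print(scale_batch_size_sketch(TensorDataset(torch.randn(64, 32))))  # 32 here, since no trial fails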
47 changes: 28 additions & 19 deletions tests/tuner/test_scale_batch_size.py
@@ -24,14 +24,14 @@
 from pytorch_lightning.utilities import AMPType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.base import EvalModelTemplate
-from tests.helpers import BoringDataModule, BoringModel
+from tests.helpers import BoringDataModule, BoringModel, RandomDataset
 from tests.helpers.datamodules import MNISTDataModule
 from tests.helpers.runif import RunIf
 
 
 class BatchSizeDataModule(BoringDataModule):
 
-    def __init__(self, batch_size=None):
+    def __init__(self, batch_size):
         super().__init__()
         if batch_size is not None:
             self.batch_size = batch_size
@@ -42,21 +42,23 @@ def train_dataloader(self):

 class BatchSizeModel(BoringModel):
 
-    def __init__(self, batch_size=None):
+    def __init__(self, batch_size):
         super().__init__()
         if batch_size is not None:
             self.batch_size = batch_size
 
+    def train_dataloader(self):
+        return DataLoader(RandomDataset(32, 64), batch_size=getattr(self, "batch_size", 1))
+
 
-@pytest.mark.parametrize(
-    "model,datamodule", [
-        (BatchSizeModel(2), None),
-        (BatchSizeModel(2), BatchSizeDataModule(2)),
-        (BatchSizeModel(2), BatchSizeDataModule(None)),
-        (BatchSizeModel(None), BatchSizeDataModule(2)),
-    ]
-)
-def test_scale_batch_size_method_with_model_or_datamodule(tmpdir, model, datamodule):
+@pytest.mark.parametrize(["model_bs", "dm_bs"], [
+    (2, -1),
+    (2, 2),
+    (2, None),
+    (None, 2),
+    (16, 16),
+])
+def test_scale_batch_size_method_with_model_or_datamodule(tmpdir, model_bs, dm_bs):
     """ Test the tuner method `Tuner.scale_batch_size` with a datamodule. """
     trainer = Trainer(
         default_root_dir=tmpdir,
@@ -65,14 +67,21 @@ def test_scale_batch_size_method_with_model_or_datamodule(tmpdir, model, datamodule):
         max_epochs=1,
     )
     tuner = Tuner(trainer)
-    new_batch_size = tuner.scale_batch_size(
-        model=model, mode="binsearch", init_val=4, max_trials=2, datamodule=datamodule
-    )
+
+    model = BatchSizeModel(model_bs)
+    datamodule = BatchSizeDataModule(dm_bs) if dm_bs != -1 else None
+
+    new_batch_size = tuner.scale_batch_size(model, mode="binsearch", init_val=4, max_trials=2, datamodule=datamodule)
     assert new_batch_size == 16
-    if hasattr(model, "batch_size"):
-        assert model.batch_size == 16
-    if datamodule is not None and hasattr(datamodule, "batch_size"):
-        assert datamodule.batch_size == 16
+
+    if model_bs is not None:
+        assert model.batch_size == new_batch_size
+        if dm_bs == -1:
+            # datamodule batch size takes precedence
+            assert trainer.train_dataloader.loaders.batch_size == new_batch_size
+    if dm_bs not in (-1, None):
+        assert datamodule.batch_size == new_batch_size
+        assert trainer.train_dataloader.loaders.batch_size == new_batch_size
 
 
 def test_model_reset_correctly(tmpdir):
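The user-facing path exercised by this fix is the batch size finder. Below is a hedged usage sketch, assuming the 1.3-era API (Trainer(auto_scale_batch_size=...) together with trainer.tune) and reusing the BoringModel/RandomDataset helpers from this repository's test suite; the TunedModel class is made up for illustration and is not part of the commit.

from torch.utils.data import DataLoader
from pytorch_lightning import Trainer
from tests.helpers import BoringModel, RandomDataset  # helpers from this repository

class TunedModel(BoringModel):
    def __init__(self, batch_size=2):
        super().__init__()
        self.batch_size = batch_size  # the attribute the tuner scales in place

    def train_dataloader(self):
        return DataLoader(RandomDataset(32, 64), batch_size=self.batch_size)

model = TunedModel()
trainer = Trainer(auto_scale_batch_size="binsearch", max_epochs=1)
trainer.tune(model)  # runs the batch size finder; with this fix the train
                     # dataloader is rebuilt whenever the batch size changes
print(model.batch_size)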
