
Commit

Reset val_dataloader in tuner/batch_size_scaling (#9857)
* reset val

* chlog
rohitgr7 authored Oct 11, 2021
1 parent 8740c80 commit d71501d
Showing 4 changed files with 20 additions and 1 deletion.
3 changes: 3 additions & 0 deletions CHANGELOG.md
```diff
@@ -470,6 +470,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed missing arguments when saving hyperparameters from the parent class but not from the child class ([#9800](https://github.com/PyTorchLightning/pytorch-lightning/pull/9800))
 
 
+- Reset `val_dataloader` in `tuner/batch_size_scaling` ([#9857](https://github.com/PyTorchLightning/pytorch-lightning/pull/9857))
+
+
 ## [1.4.9] - 2021-09-30
 
 - Fixed `lr_find` to generate same results on multiple calls ([#9704](https://github.com/PyTorchLightning/pytorch-lightning/pull/9704))
```

```diff
@@ -181,7 +181,7 @@ def update_eval_epoch_metrics(self) -> List[_OUT_DICT]:
 
         # log results of evaluation
         if (
-            self.trainer.state.fn != TrainerFn.FITTING
+            self.trainer.state.fn not in (TrainerFn.FITTING, TrainerFn.TUNING)
             and self.trainer.evaluating
             and self.trainer.is_global_zero
             and self.trainer.verbose_evaluate
```

1 change: 1 addition & 0 deletions pytorch_lightning/tuner/batch_size_scaling.py
```diff
@@ -166,6 +166,7 @@ def _run_power_scaling(
         if changed:
             # Force the train dataloader to reset as the batch size has changed
             trainer.reset_train_dataloader(model)
+            trainer.reset_val_dataloader(model)
         else:
             break
     return new_size
```
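
Why the extra reset is needed: a `DataLoader` captures its `batch_size` when it is constructed, so after the tuner mutates the model's `batch_size` attribute the existing validation loader keeps serving the old size until it is rebuilt. A minimal, plain-PyTorch sketch of that behaviour (the `ToyModule` class below is hypothetical, not Lightning code):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


class ToyModule:
    """Stand-in for a LightningModule whose loaders read `self.batch_size`."""

    def __init__(self, batch_size):
        self.batch_size = batch_size
        self.dataset = TensorDataset(torch.randn(64, 32))

    def val_dataloader(self):
        return DataLoader(self.dataset, batch_size=self.batch_size)


model = ToyModule(batch_size=4)
loader = model.val_dataloader()

model.batch_size = 16            # what the batch-size tuner effectively does
assert loader.batch_size == 4    # stale: the old loader still batches with 4

loader = model.val_dataloader()  # the "reset": rebuild from the current attribute
assert loader.batch_size == 16
```
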
15 changes: 15 additions & 0 deletions tests/tuner/test_scale_batch_size.py
```diff
@@ -48,6 +48,9 @@ def __init__(self, batch_size):
     def train_dataloader(self):
         return DataLoader(RandomDataset(32, 64), batch_size=getattr(self, "batch_size", 1))
 
+    def val_dataloader(self):
+        return DataLoader(RandomDataset(32, 64), batch_size=getattr(self, "batch_size", 1))
+
 
 @pytest.mark.parametrize(["model_bs", "dm_bs"], [(2, -1), (2, 2), (2, None), (None, 2), (16, 16)])
 def test_scale_batch_size_method_with_model_or_datamodule(tmpdir, model_bs, dm_bs):
@@ -266,3 +269,15 @@ def __init__(self):
     trainer.tune(model)
     with pytest.raises(ValueError, match="could either be `power` or `binsearch`"):
         trainer.tuner.scale_batch_size(model, mode="ThisModeDoesNotExist")
+
+
+def test_dataloader_reset_with_scale_batch_size(tmpdir):
+    """Test that train and val dataloaders are reset at every update in scale batch size."""
+    model = BatchSizeModel(batch_size=16)
+    scale_batch_size_kwargs = {"max_trials": 5, "init_val": 4}
+
+    trainer = Trainer(max_epochs=2, auto_scale_batch_size=True)
+    new_batch_size = trainer.tune(model, scale_batch_size_kwargs=scale_batch_size_kwargs)["scale_batch_size"]
+
+    assert trainer.train_dataloader.loaders.batch_size == new_batch_size
+    assert trainer.val_dataloaders[0].batch_size == new_batch_size
```
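
For completeness, here is a usage-level sketch of the code path the new test exercises. It reuses `BatchSizeModel` and the keyword arguments from this diff; treat it as illustrative of the 1.4/1.5-era tuner API rather than canonical documentation.

```python
from pytorch_lightning import Trainer

# BatchSizeModel is the test helper defined earlier in this diff.
model = BatchSizeModel(batch_size=4)
trainer = Trainer(auto_scale_batch_size="power", max_epochs=2)

# Option A: let tune() drive the scan, as the new test does.
result = trainer.tune(model, scale_batch_size_kwargs={"max_trials": 5, "init_val": 4})
new_size = result["scale_batch_size"]

# Option B: call the tuner directly (the entry point used by the mode-validation test above).
# new_size = trainer.tuner.scale_batch_size(model, mode="power", max_trials=5, init_val=4)

# With this commit, the validation loader is rebuilt alongside the train loader,
# so both now batch with the tuned size instead of the size they were first built with.
assert trainer.val_dataloaders[0].batch_size == new_size
```
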
