Fix zero division error for empty dataloaders (#12885)
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
rohitgr7 and carmocca authored May 3, 2022
1 parent 5641836 commit 9bfbd9e
Showing 6 changed files with 78 additions and 82 deletions.
12 changes: 9 additions & 3 deletions CHANGELOG.md
@@ -61,6 +61,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Make positional arguments required for classes passed into the `add_argparse_args` function. ([#12504](https://github.com/PyTorchLightning/pytorch-lightning/pull/12504))


- Raise an error if there are insufficient training batches when using a float value of `limit_train_batches` ([#12885](https://github.com/PyTorchLightning/pytorch-lightning/pull/12885))


-

### Deprecated
@@ -179,6 +182,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Enable mixed precision in `DDPFullyShardedStrategy` when `precision=16` ([#12965](https://github.com/PyTorchLightning/pytorch-lightning/pull/12965))


- Fixed `TQDMProgressBar` reset and update to show correct time estimation ([#12889](https://github.com/PyTorchLightning/pytorch-lightning/pull/12889))


- Fixed an issue causing zero-division error for empty dataloaders ([#12885](https://github.com/PyTorchLightning/pytorch-lightning/pull/12885))


## [1.6.2] - 2022-04-27

### Fixed
@@ -189,9 +198,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed support for `ModelCheckpoint` monitors with dots ([#12783](https://github.com/PyTorchLightning/pytorch-lightning/pull/12783))


- Fixed `TQDMProgressBar` reset and update to show correct time estimation ([#12889](https://github.com/PyTorchLightning/pytorch-lightning/pull/12889))


## [1.6.1] - 2022-04-13

### Changed
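Note: a hedged illustration of the changelog entries above, assuming a training set that yields 9 batches; the batch count and the `model`/`train_dataloader` names are illustrative, not part of this commit.

```python
import pytorch_lightning as pl

# With 9 training batches, a float limit of 0.1 truncates to int(9 * 0.1) == 0,
# which this commit turns into an explicit error instead of silently training nothing.
trainer = pl.Trainer(limit_train_batches=0.1)
# trainer.fit(model, train_dataloader)  # would now raise MisconfigurationException

# An int limit (or a float of at least 1/9 here) keeps training enabled:
trainer = pl.Trainer(limit_train_batches=1)  # run exactly one training batch per epoch
```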
30 changes: 19 additions & 11 deletions pytorch_lightning/trainer/connectors/data_connector.py
@@ -381,37 +381,45 @@ def _reset_eval_dataloader(
loader_num_batches = []

# determine number of batches
# datasets could be none, 1 or 2+
module = model or self.trainer.lightning_module or self.datamodule
if len(dataloaders) != 0:
for i, dataloader in enumerate(dataloaders):
orig_num_batches = num_batches = (
len(dataloader) if has_len_all_ranks(dataloader, self.trainer.strategy, module) else float("inf")
)

if orig_num_batches == 0:
loader_num_batches.append(orig_num_batches)
continue

self._worker_check(dataloader, f"{mode.dataloader_prefix}_dataloader {i}")

# percent or num_steps
limit_eval_batches = getattr(self.trainer, f"limit_{mode.dataloader_prefix}_batches")

# limit num batches either as a percent or num steps
if isinstance(limit_eval_batches, int):
num_batches = min(num_batches, int(limit_eval_batches))
elif num_batches != float("inf"):
num_batches = int(num_batches * limit_eval_batches)
num_batches = min(orig_num_batches, limit_eval_batches)
elif isinstance(limit_eval_batches, float) and orig_num_batches != float("inf"):
num_batches = int(orig_num_batches * limit_eval_batches)
elif limit_eval_batches != 1.0:
raise MisconfigurationException(
f"When using an IterableDataset for `limit_{mode}_batches`,"
f" `Trainer(limit_{mode.dataloader_prefix}_batches)` must be `1.0` or an int. An int k"
f" specifies `num_{mode.dataloader_prefix}_batches` to use."
f"When using an `IterableDataset`, `Trainer(limit_{mode.dataloader_prefix}_batches)` must be"
f" `1.0` or an int. An int specifies `num_{mode.dataloader_prefix}_batches` to use."
)

if num_batches == 0 and limit_eval_batches > 0.0 and isinstance(limit_eval_batches, float):
min_pct = 1.0 / len(dataloader)
if (
num_batches == 0
and limit_eval_batches > 0.0
and isinstance(limit_eval_batches, float)
and orig_num_batches != float("inf")
):
min_percentage = 1.0 / orig_num_batches
raise MisconfigurationException(
f"You requested to check {limit_eval_batches} of the `{mode.dataloader_prefix}_dataloader` but"
f" {limit_eval_batches} * {orig_num_batches} < 1. Please increase the"
f" `limit_{mode.dataloader_prefix}_batches` flag. Try at least"
f" `limit_{mode.dataloader_prefix}_batches={min_pct}`"
f" `limit_{mode.dataloader_prefix}_batches` argument. Try at least"
f" `limit_{mode.dataloader_prefix}_batches={min_percentage}`"
)

loader_num_batches.append(num_batches)
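Note: the zero division fixed here came from the old error path, which evaluated `1.0 / len(dataloader)` for an empty dataloader. Below is a minimal, Lightning-free sketch of the per-dataloader logic after this change; `compute_eval_batches` and the plain `ValueError`s are illustrative stand-ins, not library API.

```python
def compute_eval_batches(orig_num_batches: float, limit) -> float:
    """Illustrative mirror of the limiting logic in `_reset_eval_dataloader`."""
    if orig_num_batches == 0:
        # new early exit: empty dataloaders keep 0 batches instead of reaching
        # `1.0 / len(dataloader)` and dividing by zero
        return 0
    if isinstance(limit, int):
        return min(orig_num_batches, limit)
    if isinstance(limit, float) and orig_num_batches != float("inf"):
        num_batches = int(orig_num_batches * limit)
        if num_batches == 0 and limit > 0.0:
            raise ValueError(f"{limit} * {orig_num_batches} < 1; try at least {1.0 / orig_num_batches}")
        return num_batches
    if limit != 1.0:
        raise ValueError("With an IterableDataset, the limit must be 1.0 or an int")
    return orig_num_batches  # infinite dataloader, no limiting


assert compute_eval_batches(0, 0.25) == 0               # previously: ZeroDivisionError
assert compute_eval_batches(8, 0.25) == 2                # 25% of 8 batches
assert compute_eval_batches(float("inf"), 1.0) == float("inf")
```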
31 changes: 23 additions & 8 deletions pytorch_lightning/trainer/trainer.py
@@ -1814,21 +1814,25 @@ def reset_train_dataloader(self, model: Optional["pl.LightningModule"] = None) -
self.train_dataloader = CombinedLoader(loaders, self._data_connector.multiple_trainloader_mode)

module = model or self.lightning_module or self.datamodule
self.num_training_batches = (
orig_train_batches = self.num_training_batches = (
len(self.train_dataloader)
if has_len_all_ranks(self.train_dataloader, self.strategy, module)
else float("inf")
)
if orig_train_batches == 0:
return

# store epoch of dataloader reset for reload_dataloaders_every_n_epochs
self._last_train_dl_reload_epoch = self.current_epoch

if isinstance(self.limit_train_batches, int):
self.num_training_batches = min(self.num_training_batches, int(self.limit_train_batches))
self.num_training_batches = min(orig_train_batches, self.limit_train_batches)
elif self.num_training_batches != float("inf"):
self.num_training_batches = int(self.num_training_batches * self.limit_train_batches)
self.num_training_batches = int(orig_train_batches * self.limit_train_batches)
elif self.limit_train_batches != 1.0:
raise MisconfigurationException(
"When using an IterableDataset for `limit_train_batches`,"
" `Trainer(limit_train_batches)` must be `1.0` or an int. An int k specifies"
" `num_training_batches` to use."
"When using an `IterableDataset`, `Trainer(limit_train_batches)` must be `1.0` or an int."
"An int specifies `num_training_batches` to use."
)

if isinstance(self.val_check_interval, int):
@@ -1862,8 +1866,19 @@ def reset_train_dataloader(self, model: Optional["pl.LightningModule"] = None) -
category=PossibleUserWarning,
)

# store epoch of dataloader reset for reload_dataloaders_every_n_epochs
self._last_train_dl_reload_epoch = self.current_epoch
if (
self.num_training_batches == 0
and self.limit_train_batches > 0.0
and isinstance(self.limit_train_batches, float)
and orig_train_batches != float("inf")
):
min_percentage = 1.0 / orig_train_batches
raise MisconfigurationException(
f"You requested to check {self.limit_train_batches} of the `train_dataloader` but"
f" {self.limit_train_batches} * {orig_train_batches} < 1. Please increase the"
f" `limit_train_batches` argument. Try at least"
f" `limit_train_batches={min_percentage}`"
)

def reset_val_dataloader(self, model: Optional["pl.LightningModule"] = None) -> None:
"""Resets the validation dataloader and determines the number of batches.
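Note: on the training side the same pattern applies, with one extra consequence: a float limit that truncates to zero batches used to silently disable training (see the test removed from `tests/trainer/test_trainer.py` below) and now raises. A hedged summary of the outcomes, with an illustrative helper name and result strings that are not part of the library.

```python
def train_outcome(orig_batches: float, limit) -> str:
    """Illustrative summary of how `reset_train_dataloader` resolves the limit."""
    if orig_batches == 0:
        return "warn: zero-length train dataloader, nothing to run"
    if isinstance(limit, int):
        return f"train {min(orig_batches, limit)} batches per epoch"
    if orig_batches == float("inf"):
        return "ok" if limit == 1.0 else "raise: IterableDataset needs limit 1.0 or an int"
    scaled = int(orig_batches * limit)
    if scaled == 0 and limit > 0.0:
        return "raise: limit * batches < 1 (used to silently skip training)"
    return f"train {scaled} batches per epoch"


for case in ((0, 0.5), (4, 0.2), (4, 0.5), (float("inf"), 0.5)):
    print(case, "->", train_outcome(*case))
```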
14 changes: 14 additions & 0 deletions tests/trainer/connectors/test_data_connector.py
@@ -539,3 +539,17 @@ def val_dataloader(self):
trainer._data_connector.attach_data(model)
trainer.reset_val_dataloader(model)
assert trainer.val_dataloaders[0].sampler.shuffle == shuffle


def test_error_raised_with_insufficient_float_limit_train_dataloader():
batch_size = 16
dl = DataLoader(RandomDataset(32, batch_size * 9), batch_size=batch_size)
trainer = Trainer(limit_train_batches=0.1)
model = BoringModel()

trainer._data_connector.attach_data(model=model, train_dataloaders=dl)
with pytest.raises(
MisconfigurationException,
match="Please increase the `limit_train_batches` argument. Try at least",
):
trainer.reset_train_dataloader(model)
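Note: the arithmetic behind the new test's values, checked in plain Python (nothing Lightning-specific assumed).

```python
batch_size, num_full_batches, limit = 16, 9, 0.1
dataset_len = batch_size * num_full_batches        # 144 samples -> 9 batches of 16
assert dataset_len // batch_size == num_full_batches
assert int(num_full_batches * limit) == 0          # 0.1 of 9 batches truncates to 0 -> error raised
print(f"suggested minimum: limit_train_batches={1.0 / num_full_batches:.4f}")  # ~0.1111
```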
21 changes: 13 additions & 8 deletions tests/trainer/test_dataloaders.py
@@ -519,13 +519,18 @@ def test_mixing_of_dataloader_options(tmpdir, ckpt_path):
def test_warning_on_zero_len_dataloader(tmpdir):
"""Test that a warning is raised if a zero-length dataloader is defined."""
model = BoringModel()
trainer = Trainer(
default_root_dir=tmpdir,
fast_dev_run=1,
)
dataloader = DataLoader(RandomDataset(32, 0))
with pytest.warns(UserWarning, match="returned 0 length"):
trainer.fit(model, dataloader)
trainer = Trainer()
train_dataloader = DataLoader(RandomDataset(32, 0))
val_dataloader = DataLoader(RandomDataset(32, 0))
trainer._data_connector.attach_data(model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)

with pytest.warns(UserWarning, match="Total length of `CombinedLoader` across ranks is zero"):
trainer.reset_train_dataloader(model)
assert trainer.num_training_batches == 0

with pytest.warns(UserWarning, match="Total length of `DataLoader` across ranks is zero"):
trainer.reset_val_dataloader(model)
assert trainer.num_val_batches == [0]


@RunIf(skip_windows=True)
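Note: the premise of the rewritten zero-length test can be checked without Lightning at all; `TensorDataset` here is only a stand-in for the test's `RandomDataset`.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

empty_loader = DataLoader(TensorDataset(torch.randn(0, 32)))
assert len(empty_loader) == 0  # the zero length that used to end in a ZeroDivisionError downstream
```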
@@ -962,7 +967,7 @@ def test_inf_dataloader_raise_error_with_partial_batch_limits(tmpdir, stage, dat
trainer = Trainer(**trainer_kwargs)
trainer_fn = "fit" if stage == RunningStage.TRAINING else stage.value

with pytest.raises(MisconfigurationException, match=r"using an IterableDataset .* must be `1.0` or an int"):
with pytest.raises(MisconfigurationException, match=r"IterableDataset`.*limit_.*_batches\)`.*`1.0` or an int"):
getattr(trainer, trainer_fn)(model)


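Note: a standalone sanity check (plain `re`, no Lightning needed) that the tightened `pytest.raises` pattern above matches the reworded error from `data_connector.py`; the message string below is reconstructed from the diff for the validation case.

```python
import re

message = (
    "When using an `IterableDataset`, `Trainer(limit_val_batches)` must be"
    " `1.0` or an int. An int specifies `num_val_batches` to use."
)
pattern = r"IterableDataset`.*limit_.*_batches\)`.*`1.0` or an int"
assert re.search(pattern, message) is not None
```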
52 changes: 0 additions & 52 deletions tests/trainer/test_trainer.py
@@ -1416,58 +1416,6 @@ def test_predict_return_predictions_cpu(return_predictions, precision, tmpdir):
assert preds[0].dtype == (torch.float64 if precision == 64 else torch.float32)


@pytest.mark.parametrize(
["limit_train_batches", "global_step", "num_training_batches", "current_epoch", "should_train"],
[(0.2, 0, 0, 0, False), (0.5, 10, 2, 5, True)],
)
def test_disabled_training_for_insufficient_limit_train_batches(
tmpdir, limit_train_batches, global_step, num_training_batches, current_epoch, should_train
):
"""Verify when `limit_train_batches` is float & between [0.0, 1.0] and.
`int(self.num_training_batches * self.limit_train_batches) == 0`, the training loop is disabled.
"""

class CurrentModel(BoringModel):

training_step_invoked = False
training_epoch_end_invoked = False

def training_step(self, *args, **kwargs):
self.training_step_invoked = True
return super().training_step(*args, **kwargs)

def training_epoch_end(self, *args, **kwargs):
self.training_epoch_end_invoked = True
return super().training_epoch_end(*args, **kwargs)

dataset_len = 100
batch_size = 25

train = RandomDataset(32, length=dataset_len)
train_loader = DataLoader(train, batch_size=batch_size)

model = CurrentModel()

trainer = Trainer(default_root_dir=tmpdir, max_epochs=5, limit_train_batches=limit_train_batches)
trainer.fit(model, train_loader)

params_string = f"""`limit_train_batches={limit_train_batches}`, `dataset_len={dataset_len}`
& `batch_size={batch_size}` as
`num_training_batches={num_training_batches}`"""
if should_train:
error_string = f"should run with {params_string}"
else:
error_string = f"should not run with {params_string}"

assert trainer.state.finished, f"Training failed with {trainer.state}"
assert trainer.global_step == global_step
assert trainer.num_training_batches == num_training_batches
assert trainer.current_epoch == current_epoch
assert model.training_step_invoked == should_train, f"`training_step` {error_string}"
assert model.training_epoch_end_invoked == should_train, f"`training_epoch_end` {error_string}"


@pytest.mark.parametrize(["max_steps", "max_epochs", "global_step"], [(10, 5, 10), (20, None, 20)])
def test_repeated_fit_calls_with_max_epochs_and_steps(tmpdir, max_steps, max_epochs, global_step):
"""Ensure that the training loop is bound by `max_steps` and `max_epochs` for repeated calls of `trainer.fit`,
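Note: the removed test covered the old behavior, where `int(num_training_batches * limit_train_batches) == 0` silently skipped training; after this commit the same configuration raises. A sketch of the arithmetic using the deleted test's own parametrization.

```python
dataset_len, batch_size = 100, 25
num_batches = dataset_len // batch_size             # 4 training batches
for limit in (0.2, 0.5):                            # the two removed parametrizations
    effective = int(num_batches * limit)
    if effective == 0:
        print(f"limit_train_batches={limit}: now raises MisconfigurationException (was: training silently skipped)")
    else:
        print(f"limit_train_batches={limit}: trains {effective} batches per epoch")
```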
