Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix zero division error for empty dataloaders #12885

Merged
merged 11 commits into from
May 3, 2022
12 changes: 9 additions & 3 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Make positional arguments required for classes passed into the `add_argparse_args` function. ([#12504](https://github.com/PyTorchLightning/pytorch-lightning/pull/12504))


- Raise an error if there are insufficient training batches when using a float value of `limit_train_batches` ([#12885](https://github.com/PyTorchLightning/pytorch-lightning/pull/12885))


-

### Deprecated
Expand Down Expand Up @@ -179,6 +182,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Enable mixed precision in `DDPFullyShardedStrategy` when `precision=16` ([#12965](https://github.com/PyTorchLightning/pytorch-lightning/pull/12965))


- Fixed `TQDMProgressBar` reset and update to show correct time estimation ([#12889](https://github.com/PyTorchLightning/pytorch-lightning/pull/12889))


- Fixed an issue causing zero-division error for empty dataloaders ([#12885](https://github.com/PyTorchLightning/pytorch-lightning/pull/12885))


## [1.6.2] - 2022-04-27

### Fixed
Expand All @@ -189,9 +198,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed support for `ModelCheckpoint` monitors with dots ([#12783](https://github.com/PyTorchLightning/pytorch-lightning/pull/12783))


- Fixed `TQDMProgressBar` reset and update to show correct time estimation ([#12889](https://github.com/PyTorchLightning/pytorch-lightning/pull/12889))


## [1.6.1] - 2022-04-13

### Changed
Expand Down
30 changes: 19 additions & 11 deletions pytorch_lightning/trainer/connectors/data_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,37 +381,45 @@ def _reset_eval_dataloader(
loader_num_batches = []

# determine number of batches
# datasets could be none, 1 or 2+
module = model or self.trainer.lightning_module or self.datamodule
if len(dataloaders) != 0:
for i, dataloader in enumerate(dataloaders):
orig_num_batches = num_batches = (
len(dataloader) if has_len_all_ranks(dataloader, self.trainer.strategy, module) else float("inf")
)

if orig_num_batches == 0:
loader_num_batches.append(orig_num_batches)
continue

self._worker_check(dataloader, f"{mode.dataloader_prefix}_dataloader {i}")

# percent or num_steps
limit_eval_batches = getattr(self.trainer, f"limit_{mode.dataloader_prefix}_batches")

# limit num batches either as a percent or num steps
if isinstance(limit_eval_batches, int):
num_batches = min(num_batches, int(limit_eval_batches))
elif num_batches != float("inf"):
num_batches = int(num_batches * limit_eval_batches)
num_batches = min(orig_num_batches, limit_eval_batches)
elif isinstance(limit_eval_batches, float) and orig_num_batches != float("inf"):
num_batches = int(orig_num_batches * limit_eval_batches)
elif limit_eval_batches != 1.0:
raise MisconfigurationException(
f"When using an IterableDataset for `limit_{mode}_batches`,"
f" `Trainer(limit_{mode.dataloader_prefix}_batches)` must be `1.0` or an int. An int k"
f" specifies `num_{mode.dataloader_prefix}_batches` to use."
f"When using an `IterableDataset`, `Trainer(limit_{mode.dataloader_prefix}_batches)` must be"
f" `1.0` or an int. An int specifies `num_{mode.dataloader_prefix}_batches` to use."
)

if num_batches == 0 and limit_eval_batches > 0.0 and isinstance(limit_eval_batches, float):
min_pct = 1.0 / len(dataloader)
if (
num_batches == 0
and limit_eval_batches > 0.0
and isinstance(limit_eval_batches, float)
and orig_num_batches != float("inf")
):
min_percentage = 1.0 / orig_num_batches
raise MisconfigurationException(
f"You requested to check {limit_eval_batches} of the `{mode.dataloader_prefix}_dataloader` but"
f" {limit_eval_batches} * {orig_num_batches} < 1. Please increase the"
f" `limit_{mode.dataloader_prefix}_batches` flag. Try at least"
f" `limit_{mode.dataloader_prefix}_batches={min_pct}`"
f" `limit_{mode.dataloader_prefix}_batches` argument. Try at least"
f" `limit_{mode.dataloader_prefix}_batches={min_percentage}`"
)

loader_num_batches.append(num_batches)
Expand Down
31 changes: 23 additions & 8 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1814,21 +1814,25 @@ def reset_train_dataloader(self, model: Optional["pl.LightningModule"] = None) -
self.train_dataloader = CombinedLoader(loaders, self._data_connector.multiple_trainloader_mode)

module = model or self.lightning_module or self.datamodule
self.num_training_batches = (
orig_train_batches = self.num_training_batches = (
len(self.train_dataloader)
if has_len_all_ranks(self.train_dataloader, self.strategy, module)
else float("inf")
)
if orig_train_batches == 0:
return

# store epoch of dataloader reset for reload_dataloaders_every_n_epochs
self._last_train_dl_reload_epoch = self.current_epoch
carmocca marked this conversation as resolved.
Show resolved Hide resolved

if isinstance(self.limit_train_batches, int):
self.num_training_batches = min(self.num_training_batches, int(self.limit_train_batches))
self.num_training_batches = min(orig_train_batches, self.limit_train_batches)
elif self.num_training_batches != float("inf"):
self.num_training_batches = int(self.num_training_batches * self.limit_train_batches)
self.num_training_batches = int(orig_train_batches * self.limit_train_batches)
elif self.limit_train_batches != 1.0:
raise MisconfigurationException(
"When using an IterableDataset for `limit_train_batches`,"
" `Trainer(limit_train_batches)` must be `1.0` or an int. An int k specifies"
" `num_training_batches` to use."
"When using an `IterableDataset`, `Trainer(limit_train_batches)` must be `1.0` or an int."
"An int specifies `num_training_batches` to use."
)

if isinstance(self.val_check_interval, int):
Expand Down Expand Up @@ -1862,8 +1866,19 @@ def reset_train_dataloader(self, model: Optional["pl.LightningModule"] = None) -
category=PossibleUserWarning,
)

# store epoch of dataloader reset for reload_dataloaders_every_n_epochs
self._last_train_dl_reload_epoch = self.current_epoch
if (
self.num_training_batches == 0
and self.limit_train_batches > 0.0
and isinstance(self.limit_train_batches, float)
and orig_train_batches != float("inf")
):
min_percentage = 1.0 / orig_train_batches
raise MisconfigurationException(
f"You requested to check {self.limit_train_batches} of the `train_dataloader` but"
f" {self.limit_train_batches} * {orig_train_batches} < 1. Please increase the"
rohitgr7 marked this conversation as resolved.
Show resolved Hide resolved
f" `limit_train_batches` argument. Try at least"
f" `limit_train_batches={min_percentage}`"
)

def reset_val_dataloader(self, model: Optional["pl.LightningModule"] = None) -> None:
"""Resets the validation dataloader and determines the number of batches.
Expand Down
14 changes: 14 additions & 0 deletions tests/trainer/connectors/test_data_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,3 +539,17 @@ def val_dataloader(self):
trainer._data_connector.attach_data(model)
trainer.reset_val_dataloader(model)
assert trainer.val_dataloaders[0].sampler.shuffle == shuffle


def test_error_raised_with_insufficient_float_limit_train_dataloader():
batch_size = 16
dl = DataLoader(RandomDataset(32, batch_size * 9), batch_size=batch_size)
trainer = Trainer(limit_train_batches=0.1)
model = BoringModel()

trainer._data_connector.attach_data(model=model, train_dataloaders=dl)
with pytest.raises(
MisconfigurationException,
match="Please increase the `limit_train_batches` argument. Try at least",
):
trainer.reset_train_dataloader(model)
21 changes: 13 additions & 8 deletions tests/trainer/test_dataloaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,13 +519,18 @@ def test_mixing_of_dataloader_options(tmpdir, ckpt_path):
def test_warning_on_zero_len_dataloader(tmpdir):
"""Test that a warning is raised if a zero-length dataloader is defined."""
model = BoringModel()
trainer = Trainer(
default_root_dir=tmpdir,
fast_dev_run=1,
)
dataloader = DataLoader(RandomDataset(32, 0))
with pytest.warns(UserWarning, match="returned 0 length"):
trainer.fit(model, dataloader)
trainer = Trainer()
train_dataloader = DataLoader(RandomDataset(32, 0))
val_dataloader = DataLoader(RandomDataset(32, 0))
trainer._data_connector.attach_data(model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)

with pytest.warns(UserWarning, match="Total length of `CombinedLoader` across ranks is zero"):
trainer.reset_train_dataloader(model)
assert trainer.num_training_batches == 0

with pytest.warns(UserWarning, match="Total length of `DataLoader` across ranks is zero"):
trainer.reset_val_dataloader(model)
assert trainer.num_val_batches == [0]


@RunIf(skip_windows=True)
Expand Down Expand Up @@ -962,7 +967,7 @@ def test_inf_dataloader_raise_error_with_partial_batch_limits(tmpdir, stage, dat
trainer = Trainer(**trainer_kwargs)
trainer_fn = "fit" if stage == RunningStage.TRAINING else stage.value

with pytest.raises(MisconfigurationException, match=r"using an IterableDataset .* must be `1.0` or an int"):
with pytest.raises(MisconfigurationException, match=r"IterableDataset`.*limit_.*_batches\)`.*`1.0` or an int"):
getattr(trainer, trainer_fn)(model)


Expand Down
52 changes: 0 additions & 52 deletions tests/trainer/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1416,58 +1416,6 @@ def test_predict_return_predictions_cpu(return_predictions, precision, tmpdir):
assert preds[0].dtype == (torch.float64 if precision == 64 else torch.float32)


@pytest.mark.parametrize(
["limit_train_batches", "global_step", "num_training_batches", "current_epoch", "should_train"],
[(0.2, 0, 0, 0, False), (0.5, 10, 2, 5, True)],
)
def test_disabled_training_for_insufficient_limit_train_batches(
carmocca marked this conversation as resolved.
Show resolved Hide resolved
tmpdir, limit_train_batches, global_step, num_training_batches, current_epoch, should_train
):
"""Verify when `limit_train_batches` is float & between [0.0, 1.0] and.

`int(self.num_training_batches * self.limit_train_batches) == 0`, the training loop is disabled.
"""

class CurrentModel(BoringModel):

training_step_invoked = False
training_epoch_end_invoked = False

def training_step(self, *args, **kwargs):
self.training_step_invoked = True
return super().training_step(*args, **kwargs)

def training_epoch_end(self, *args, **kwargs):
self.training_epoch_end_invoked = True
return super().training_epoch_end(*args, **kwargs)

dataset_len = 100
batch_size = 25

train = RandomDataset(32, length=dataset_len)
train_loader = DataLoader(train, batch_size=batch_size)

model = CurrentModel()

trainer = Trainer(default_root_dir=tmpdir, max_epochs=5, limit_train_batches=limit_train_batches)
trainer.fit(model, train_loader)

params_string = f"""`limit_train_batches={limit_train_batches}`, `dataset_len={dataset_len}`
& `batch_size={batch_size}` as
`num_training_batches={num_training_batches}`"""
if should_train:
error_string = f"should run with {params_string}"
else:
error_string = f"should not run with {params_string}"

assert trainer.state.finished, f"Training failed with {trainer.state}"
assert trainer.global_step == global_step
assert trainer.num_training_batches == num_training_batches
assert trainer.current_epoch == current_epoch
assert model.training_step_invoked == should_train, f"`training_step` {error_string}"
assert model.training_epoch_end_invoked == should_train, f"`training_epoch_end` {error_string}"


@pytest.mark.parametrize(["max_steps", "max_epochs", "global_step"], [(10, 5, 10), (20, None, 20)])
def test_repeated_fit_calls_with_max_epochs_and_steps(tmpdir, max_steps, max_epochs, global_step):
"""Ensure that the training loop is bound by `max_steps` and `max_epochs` for repeated calls of `trainer.fit`,
Expand Down