diff --git a/CHANGELOG.md b/CHANGELOG.md
index 209a8a4671028..8dfd49b605941 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -178,12 +178,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Changed default value of the `max_steps` Trainer argument from `None` to -1 ([#9460](https://github.com/PyTorchLightning/pytorch-lightning/pull/9460))
 - LightningModule now raises an error when calling `log(on_step=False, on_epoch=False)` ([#10227](https://github.com/PyTorchLightning/pytorch-lightning/pull/10227))
 - Quantization aware training observers are now disabled by default during validating/testing/predicting stages ([#8540](https://github.com/PyTorchLightning/pytorch-lightning/pull/8540))
+- Raise `MisconfigurationException` when the total length of `dataloader` across ranks is zero, and give a warning when the total length is non-zero but the length on the local rank is zero ([#9827](https://github.com/PyTorchLightning/pytorch-lightning/pull/9827))
 - Changed the model size calculation using `ByteCounter` ([#10123](https://github.com/PyTorchLightning/pytorch-lightning/pull/10123))
-
-
 - Enabled `on_load_checkpoint` for `LightningDataModule` for all `trainer_fn` ([#10238](https://github.com/PyTorchLightning/pytorch-lightning/pull/10238))
-
-
 - Allow separate config files for parameters with class type when LightningCLI is in subclass_mode=False ([#10286](https://github.com/PyTorchLightning/pytorch-lightning/pull/10286))
@@ -220,8 +217,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Deprecated `ClusterEnvironment.creates_children()` in favor of `ClusterEnvironment.creates_processes_externally` (property) ([#10106](https://github.com/PyTorchLightning/pytorch-lightning/pull/10106))
 - Deprecated `PrecisionPlugin.master_params()` in favor of `PrecisionPlugin.main_params()` ([#10105](https://github.com/PyTorchLightning/pytorch-lightning/pull/10105))
 - Deprecated `lr_sch_names` from `LearningRateMonitor` ([#10066](https://github.com/PyTorchLightning/pytorch-lightning/pull/10066))
-
-
 - Deprecated `ProgressBar` callback in favor of `TQDMProgressBar` ([#10134](https://github.com/PyTorchLightning/pytorch-lightning/pull/10134))
diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py
index 96c7c333f795f..d804be4a7cab2 100644
--- a/pytorch_lightning/core/hooks.py
+++ b/pytorch_lightning/core/hooks.py
@@ -314,9 +314,13 @@ def __init__(self) -> None:
             prepare_data_per_node:
                 If True, each LOCAL_RANK=0 will call prepare data.
                 Otherwise only NODE_RANK=0, LOCAL_RANK=0 will prepare data.
+            allow_zero_length_dataloader_with_multiple_devices:
+                If True, a dataloader with zero length within the local rank is allowed.
+                Default value is ``False``.
         """
         super().__init__()
         self.prepare_data_per_node: bool = True
+        self.allow_zero_length_dataloader_with_multiple_devices: bool = False

     def prepare_data(self) -> None:
         """Use this to download and prepare data.
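For reviewers, a minimal sketch (not part of the patch) of how user code would opt into the new attribute; the module name, layer shape, and training logic below are placeholder assumptions, only the attribute itself comes from this diff:

```python
# Hypothetical usage sketch of the new opt-in attribute (not part of this patch).
import torch
import pytorch_lightning as pl


class MyModel(pl.LightningModule):  # placeholder model name
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)
        # Added by this PR (defaults to False): tolerate a zero-length dataloader on
        # the local rank as long as the total length across all ranks is non-zero.
        self.allow_zero_length_dataloader_with_multiple_devices = True

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        return self(batch).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)
```

Since the attribute is declared on the shared data hooks, the same opt-in should also work from a `LightningDataModule`, which is why `has_len_all_ranks` below accepts either type.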
diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index 071eead5613b4..e149aef9a7997 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -37,7 +37,7 @@ CaptureMapDataset, FastForwardSampler, ) -from pytorch_lightning.utilities.data import has_iterable_dataset, has_len +from pytorch_lightning.utilities.data import has_iterable_dataset, has_len_all_ranks from pytorch_lightning.utilities.enums import DistributedType from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _fault_tolerant_training @@ -346,7 +346,12 @@ def reset_train_dataloader(self, model: Optional["pl.LightningModule"] = None) - # wrap the sequence of train loaders to a CombinedLoader object for computing the num_training_batches self.train_dataloader = CombinedLoader(self.train_dataloader, self._data_connector.multiple_trainloader_mode) - self.num_training_batches = len(self.train_dataloader) if has_len(self.train_dataloader) else float("inf") + module = model or self.lightning_module or self.datamodule + self.num_training_batches = ( + len(self.train_dataloader) + if has_len_all_ranks(self.train_dataloader, self.training_type_plugin, module) + else float("inf") + ) if isinstance(self.limit_train_batches, int) or self.limit_train_batches == 0.0: self.num_training_batches = min(self.num_training_batches, int(self.limit_train_batches)) @@ -371,7 +376,7 @@ def reset_train_dataloader(self, model: Optional["pl.LightningModule"] = None) - "If you want to disable validation set `limit_val_batches` to 0.0 instead." ) else: - if not has_len(self.train_dataloader): + if not has_len_all_ranks(self.train_dataloader, self.training_type_plugin, module): if self.val_check_interval == 1.0: self.val_check_batch = float("inf") else: @@ -452,9 +457,14 @@ def _reset_eval_dataloader( # determine number of batches # datasets could be none, 1 or 2+ + module = model or self.lightning_module or self.datamodule if len(dataloaders) != 0: for i, dataloader in enumerate(dataloaders): - num_batches = len(dataloader) if has_len(dataloader) else float("inf") + num_batches = ( + len(dataloader) + if has_len_all_ranks(dataloader, self.training_type_plugin, module) + else float("inf") + ) self._worker_check(dataloader, f"{mode.dataloader_prefix}_dataloader {i}") # percent or num_steps diff --git a/pytorch_lightning/tuner/batch_size_scaling.py b/pytorch_lightning/tuner/batch_size_scaling.py index 721c5e293bae8..faf2ee4f5bb9c 100644 --- a/pytorch_lightning/tuner/batch_size_scaling.py +++ b/pytorch_lightning/tuner/batch_size_scaling.py @@ -16,11 +16,13 @@ import uuid from typing import Optional, Tuple +from torch.utils.data import DataLoader + import pytorch_lightning as pl from pytorch_lightning.loggers.base import DummyLogger from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.cloud_io import get_filesystem -from pytorch_lightning.utilities.data import has_len +from pytorch_lightning.utilities.data import has_len_all_ranks from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.memory import garbage_collection_cuda, is_oom_error from pytorch_lightning.utilities.parsing import lightning_getattr, lightning_hasattr, lightning_setattr @@ -257,7 +259,7 @@ def _adjust_batch_size( if desc: log.info(f"Batch size {batch_size} {desc}, trying batch size {new_size}") - if not _is_valid_batch_size(new_size, 
trainer.train_dataloader):
+    if not _is_valid_batch_size(new_size, trainer.train_dataloader, trainer):
         new_size = min(new_size, len(trainer.train_dataloader.dataset))

     changed = new_size != batch_size
@@ -265,5 +267,6 @@
     return new_size, changed


-def _is_valid_batch_size(current_size, dataloader):
-    return not has_len(dataloader) or current_size <= len(dataloader)
+def _is_valid_batch_size(batch_size: int, dataloader: DataLoader, trainer: "pl.Trainer"):
+    module = trainer.lightning_module or trainer.datamodule
+    return not has_len_all_ranks(dataloader, trainer.training_type_plugin, module) or batch_size <= len(dataloader)
diff --git a/pytorch_lightning/utilities/data.py b/pytorch_lightning/utilities/data.py
index aa9e9b401f40a..a75afa775848b 100644
--- a/pytorch_lightning/utilities/data.py
+++ b/pytorch_lightning/utilities/data.py
@@ -17,7 +17,9 @@
 import torch
 from torch.utils.data import DataLoader, IterableDataset

+import pytorch_lightning as pl
 from pytorch_lightning.utilities import rank_zero_warn
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.warnings import WarningCache

 BType = Union[torch.Tensor, str, Mapping[Any, "BType"], Iterable["BType"]]
@@ -93,6 +95,55 @@ def has_len(dataloader: Union[DataLoader, Iterable]) -> bool:
     return has_len


+def has_len_all_ranks(
+    dataloader: DataLoader,
+    training_type: "pl.TrainingTypePlugin",
+    model: Union["pl.LightningModule", "pl.LightningDataModule"],
+) -> bool:
+    """Checks if a given Dataloader has ``__len__`` method implemented on all ranks, i.e. if it is a finite
+    dataloader or an infinite dataloader.
+
+    Raises:
+        MisconfigurationException:
+            If the length of the Dataloader is 0, as it requires at least one batch.
+    """
+    try:
+        total_length = training_type.reduce(torch.tensor(len(dataloader)).to(model.device), reduce_op="sum")
+        local_length = len(dataloader)
+
+        if total_length == 0:
+            raise MisconfigurationException(
+                "Total length of `Dataloader` across ranks is zero. Please make sure that it returns at least 1 batch."
+            )
+        if total_length > 0 and local_length == 0:
+            if model.allow_zero_length_dataloader_with_multiple_devices:
+                rank_zero_warn(
+                    "Total length of `Dataloader` across ranks is non-zero, but local rank has zero length."
+                    " Please be cautious of uneven batch length."
+                )
+                has_len = False
+            else:
+                raise MisconfigurationException(
+                    "`Dataloader` within local rank has zero length. Please make sure that it returns at least 1 batch."
+                )
+        else:
+            has_len = True
+
+    except TypeError:
+        has_len = False
+    except NotImplementedError:  # e.g. raised by torchtext if a batch_size_fn is used
+        has_len = False
+
+    if has_len and has_iterable_dataset(dataloader):
+        rank_zero_warn(
+            "Your `IterableDataset` has `__len__` defined."
+            " In combination with multi-process data loading (when num_workers > 1),"
+            " `__len__` could be inaccurate if each worker is not configured independently"
+            " to avoid having duplicate data."
+        )
+    return has_len
+
+
 def get_len(dataloader: DataLoader) -> Union[int, float]:
     """Return the length of the given DataLoader.
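As a quick illustration of the helper's behaviour in a single-process setting, here is a sketch that mirrors the unit test added further below; it assumes the repository's test helpers `BoringModel` and `RandomDataset` are available:

```python
# Single-process sketch of `has_len_all_ranks` (mirrors the new unit test).
from torch.utils.data import DataLoader

from pytorch_lightning import Trainer
from pytorch_lightning.utilities.data import has_len_all_ranks
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers.boring_model import BoringModel, RandomDataset

trainer = Trainer(fast_dev_run=True)
model = BoringModel()

# A finite, non-empty dataloader reports a usable length.
assert has_len_all_ranks(DataLoader(RandomDataset(1, 1)), trainer.training_type_plugin, model)

# An empty dataloader now fails fast instead of silently producing zero batches.
try:
    has_len_all_ranks(DataLoader(RandomDataset(0, 0)), trainer.training_type_plugin, model)
except MisconfigurationException as err:
    print(err)  # Total length of `Dataloader` across ranks is zero. ...
```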
diff --git a/tests/accelerators/test_tpu.py b/tests/accelerators/test_tpu.py index 8400adb3a11da..78e4c505bb99a 100644 --- a/tests/accelerators/test_tpu.py +++ b/tests/accelerators/test_tpu.py @@ -308,7 +308,6 @@ def test_xla_checkpoint_plugin_being_default(): def test_mp_device_dataloader_attribute(_): dataset = RandomDataset(32, 64) dataloader = TPUSpawnPlugin().process_dataloader(DataLoader(dataset)) - assert dataloader.dataset == dataset diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index ea31dbaf7d0a1..533eceb8018db 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -26,7 +26,7 @@ import tests.helpers.pipelines as tpipes from pytorch_lightning import Callback, seed_everything, Trainer from pytorch_lightning.callbacks import ModelCheckpoint -from pytorch_lightning.utilities.data import has_iterable_dataset, has_len +from pytorch_lightning.utilities.data import has_iterable_dataset, has_len_all_ranks from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.base import EvalModelTemplate from tests.helpers.boring_model import BoringModel, RandomDataset, RandomIterableDataset, RandomIterableDatasetWithLen @@ -265,7 +265,7 @@ def test_inf_dataloaders_with_limit_percent_batches(tmpdir, limit_train_batches, num_batches = 128 / batch_size for dl in (train_dl, val_dl, test_dl): - if has_len(dl): + if has_len_all_ranks(dl, trainer.training_type_plugin, model): assert len(dl) == num_batches else: assert sum(1 for _ in dl) == num_batches @@ -855,10 +855,10 @@ def __len__(self): return len(original_dataset) # with __len__ defined + trainer = Trainer(default_root_dir=tmpdir, max_steps=3) dataloader = DataLoader(IterableWithLen(), batch_size=16) - assert has_len(dataloader) + assert has_len_all_ranks(dataloader, trainer.training_type_plugin, model) assert has_iterable_dataset(dataloader) - trainer = Trainer(default_root_dir=tmpdir, max_steps=3) with pytest.warns(UserWarning, match="Your `IterableDataset` has `__len__` defined."): trainer.validate(model, val_dataloaders=[dataloader]) with pytest.warns(UserWarning, match="Your `IterableDataset` has `__len__` defined."): @@ -869,10 +869,10 @@ def __len__(self): trainer.predict(model, dataloaders=[dataloader]) # without __len__ defined + trainer = Trainer(default_root_dir=tmpdir, max_steps=3) dataloader = DataLoader(IterableWithoutLen(), batch_size=16) - assert not has_len(dataloader) + assert not has_len_all_ranks(dataloader, trainer.training_type_plugin, model) assert has_iterable_dataset(dataloader) - trainer = Trainer(default_root_dir=tmpdir, max_steps=3) trainer.validate(model, val_dataloaders=dataloader) trainer.fit(model, train_dataloader=dataloader, val_dataloaders=[dataloader]) trainer.test(model, test_dataloaders=dataloader) diff --git a/tests/utilities/test_data.py b/tests/utilities/test_data.py index 66a4c2790f720..acbe645515f55 100644 --- a/tests/utilities/test_data.py +++ b/tests/utilities/test_data.py @@ -2,8 +2,17 @@ import torch from torch.utils.data.dataloader import DataLoader -from pytorch_lightning.utilities.data import extract_batch_size, get_len, has_iterable_dataset, has_len, warning_cache -from tests.helpers.boring_model import RandomDataset, RandomIterableDataset +from pytorch_lightning import Trainer +from pytorch_lightning.utilities.data import ( + extract_batch_size, + get_len, + has_iterable_dataset, + has_len, + has_len_all_ranks, + warning_cache, +) +from pytorch_lightning.utilities.exceptions import 
MisconfigurationException +from tests.helpers.boring_model import BoringModel, RandomDataset, RandomIterableDataset def test_extract_batch_size(): @@ -73,3 +82,13 @@ def test_get_len(): assert isinstance(value, float) assert value == float("inf") + + +def test_has_len_all_rank(): + trainer = Trainer(fast_dev_run=True) + model = BoringModel() + + with pytest.raises(MisconfigurationException, match="Total length of `Dataloader` across ranks is zero."): + assert not has_len_all_ranks(DataLoader(RandomDataset(0, 0)), trainer.training_type_plugin, model) + + assert has_len_all_ranks(DataLoader(RandomDataset(1, 1)), trainer.training_type_plugin, model)
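A possible follow-up, not included in this patch, would be a test for the warning path (non-zero total length but an empty local rank). One hypothetical way to exercise it on a single device is to mock the plugin's `reduce`:

```python
# Hypothetical test-style sketch (not part of this patch): simulate a non-zero
# aggregated length while the local dataloader is empty.
from unittest import mock

import pytest
import torch
from torch.utils.data import DataLoader

from pytorch_lightning import Trainer
from pytorch_lightning.utilities.data import has_len_all_ranks
from tests.helpers.boring_model import BoringModel, RandomDataset


def test_warns_when_only_local_rank_is_empty():
    trainer = Trainer(fast_dev_run=True)
    model = BoringModel()
    model.allow_zero_length_dataloader_with_multiple_devices = True
    empty_loader = DataLoader(RandomDataset(32, 0))

    # Pretend another rank contributed batches so the reduced total is non-zero.
    with mock.patch.object(trainer.training_type_plugin, "reduce", return_value=torch.tensor(2)):
        with pytest.warns(UserWarning, match="local rank has zero length"):
            assert not has_len_all_ranks(empty_loader, trainer.training_type_plugin, model)
```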