introduce has_len_all_ranks() to check the length of dataloader across ranks #9827

Merged

Changes from 3 commits

Commits (24)
e31710f
introduce `has_len_all_ranks()`, update tests
ninginthecloud Oct 6, 2021
87a5a63
update CHANGELOG.md
ninginthecloud Oct 6, 2021
4ac7426
change staticmethod and hook attribute naming
ninginthecloud Oct 8, 2021
8860149
Merge branch 'master' into fix/valueerror_has_len_9785_2
ninginthecloud Oct 25, 2021
fda4c65
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 25, 2021
43057d3
fix typo
ninginthecloud Oct 26, 2021
aa78bb5
remove non-essential comment
ninginthecloud Oct 26, 2021
6b07d00
Merge branch 'master' into fix/valueerror_has_len_9785_2
ninginthecloud Oct 26, 2021
19fc124
fix merge error and comment format
ninginthecloud Oct 27, 2021
d41c114
try to fix test_tpu.py failure
ninginthecloud Oct 27, 2021
9e78e0d
Merge branch 'master' into fix/valueerror_has_len_9785_2
awaelchli Oct 28, 2021
51d8f7e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 28, 2021
927dfde
Merge branch 'master' into fix/valueerror_has_len_9785_2
tchaton Nov 1, 2021
079bf17
Merge branch 'master' into fix/valueerror_has_len_9785_2
rohitgr7 Nov 2, 2021
09107e1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 2, 2021
a8f8435
update on comments
rohitgr7 Nov 2, 2021
e0b376e
chlog
rohitgr7 Nov 2, 2021
ee691fe
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 2, 2021
03e854d
chlog
rohitgr7 Nov 2, 2021
5af44d2
update
rohitgr7 Nov 2, 2021
51cac01
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 2, 2021
c5a11b9
try fix
rohitgr7 Nov 2, 2021
0b0a1e8
Revert back TPUSpawn changes
kaushikb11 Nov 2, 2021
44835b0
Update test
kaushikb11 Nov 2, 2021
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -257,6 +257,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Raise an exception if using `amp_level` with native `amp_backend` ([#9755](https://github.com/PyTorchLightning/pytorch-lightning/pull/9755))

- Raise `MisconfigurationException` when the total length of `dataloader` across ranks is zero, and warn when the total length is non-zero but the length on the local rank is zero ([#9821](https://github.com/PyTorchLightning/pytorch-lightning/pull/9821))


- Update the logic to check for accumulation steps with deepspeed ([#9826](https://github.com/PyTorchLightning/pytorch-lightning/pull/9826))

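To make the two behaviors described in the changelog entry concrete, here is a minimal single-device sketch that mirrors the unit test added at the end of this PR. It reuses `Trainer(fast_dev_run=True)` and the `BoringModel`/`RandomDataset` test helpers, as the new test does; the `match` strings are only substrings of the new error message.

import pytest
from torch.utils.data import DataLoader

from pytorch_lightning import Trainer
from pytorch_lightning.utilities.data import has_len_all_ranks
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers.boring_model import BoringModel, RandomDataset

trainer = Trainer(fast_dev_run=True)
model = BoringModel()

# Case 1: total length across all ranks is zero -> hard failure.
empty_loader = DataLoader(RandomDataset(32, 0))
with pytest.raises(MisconfigurationException, match="across ranks is zero"):
    has_len_all_ranks(empty_loader, trainer.training_type_plugin, model)

# Case 2: a non-empty dataloader behaves like the old `has_len` check.
loader = DataLoader(RandomDataset(32, 64), batch_size=8)
assert has_len_all_ranks(loader, trainer.training_type_plugin, model)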
4 changes: 4 additions & 0 deletions pytorch_lightning/core/hooks.py
@@ -311,9 +311,13 @@ def __init__(self) -> None:
prepare_data_per_node:
If True, each LOCAL_RANK=0 will call prepare data.
Otherwise only NODE_RANK=0, LOCAL_RANK=0 will prepare data.
allow_zero_length_dataloader_with_multiple_devices:
If True, a dataloader with zero length on the local rank is allowed.
Default value is False.
"""
super().__init__()
self.prepare_data_per_node: bool = True
self.allow_zero_length_dataloader_with_multiple_devices: bool = False

def prepare_data(self) -> None:
"""Use this to download and prepare data.
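For context, a minimal sketch of how a user could opt in to the new `allow_zero_length_dataloader_with_multiple_devices` attribute from their own module. The subclass and sample counts are hypothetical; it reuses the `BoringModel`/`RandomDataset` helpers from the test suite.

from torch.utils.data import DataLoader

from tests.helpers.boring_model import BoringModel, RandomDataset


class UnevenShardsModel(BoringModel):
    def __init__(self, num_local_samples: int = 0):
        super().__init__()
        # Opt in: an empty dataloader on this rank only emits a warning,
        # as long as the summed length across all ranks is non-zero.
        self.allow_zero_length_dataloader_with_multiple_devices = True
        self.num_local_samples = num_local_samples

    def train_dataloader(self) -> DataLoader:
        return DataLoader(RandomDataset(32, self.num_local_samples), batch_size=2)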
34 changes: 21 additions & 13 deletions pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -36,7 +36,7 @@
set_shared_parameters,
)
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.data import has_len
from pytorch_lightning.utilities.data import has_len_all_ranks
from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
@@ -82,35 +82,40 @@ def world_size(self) -> int:
def root_device(self) -> torch.device:
return xm.xla_device()

@staticmethod
def _validate_dataloader(dataloaders: Union[List[DataLoader], DataLoader]) -> None:
def _validate_dataloader(
self,
dataloaders: Union[List[DataLoader], DataLoader],
module: Union["pl.LightningModule", "pl.LightningDataModule"],
) -> None:
if not isinstance(dataloaders, list):
dataloaders = [dataloaders]

for dataloader in dataloaders:
if not has_len(dataloader):
if not has_len_all_ranks(dataloader, self, module):
raise MisconfigurationException(
"TPUs do not currently support IterableDataset objects, the dataset must implement `__len__`."
" HINT: You can mock the length on your dataset to bypass this MisconfigurationException."
)

@staticmethod
def _validate_patched_dataloaders(model: Module) -> None:
def _validate_patched_dataloaders(
self,
model: Module,
) -> None:
"""Validate and fail fast if the dataloaders were passed directly to fit."""
if hasattr(model, "train_dataloader") and isinstance(model.train_dataloader, _PatchDataLoader):
TPUSpawnPlugin._validate_dataloader(model.train_dataloader.dataloader)
self._validate_dataloader(model.train_dataloader.dataloader, model)

if hasattr(model, "val_dataloader") and isinstance(model.val_dataloader, _PatchDataLoader):
TPUSpawnPlugin._validate_dataloader(model.val_dataloader.dataloader)
self._validate_dataloader(model.val_dataloader.dataloader, model)

if hasattr(model, "test_dataloader") and isinstance(model.test_dataloader, _PatchDataLoader):
TPUSpawnPlugin._validate_dataloader(model.test_dataloader.dataloader)
self._validate_dataloader(model.test_dataloader.dataloader, model)

if hasattr(model, "predict_dataloader") and isinstance(model.predict_dataloader, _PatchDataLoader):
TPUSpawnPlugin._validate_dataloader(model.predict_dataloader.dataloader)
self._validate_dataloader(model.predict_dataloader.dataloader, model)

def connect(self, model: "pl.LightningModule") -> None:
TPUSpawnPlugin._validate_patched_dataloaders(model)
self._validate_patched_dataloaders(model)
self.wrapped_model = xmp.MpModelWrapper(LightningDistributedModule(model))
return super().connect(model)

@@ -135,8 +140,11 @@ def is_distributed(self) -> bool:
# HOST_WORLD_SIZE is None outside the xmp.spawn process
return os.getenv(xenv.HOST_WORLD_SIZE, None) and self.world_size != 1

def process_dataloader(self, dataloader: DataLoader) -> MpDeviceLoader:
TPUSpawnPlugin._validate_dataloader(dataloader)
def process_dataloader(
self,
dataloader: DataLoader,
) -> MpDeviceLoader:
self._validate_dataloader(dataloader, self.lightning_module)
return MpDeviceLoader(dataloader, self.root_device)

def configure_ddp(self) -> None:
18 changes: 14 additions & 4 deletions pytorch_lightning/trainer/data_loading.py
@@ -37,7 +37,7 @@
CaptureMapDataset,
FastForwardSampler,
)
from pytorch_lightning.utilities.data import has_iterable_dataset, has_len
from pytorch_lightning.utilities.data import has_iterable_dataset, has_len_all_ranks
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _fault_tolerant_training
from pytorch_lightning.utilities.model_helpers import is_overridden
@@ -344,7 +344,12 @@ def reset_train_dataloader(self, model: Optional["pl.LightningModule"] = None) -
# wrap the sequence of train loaders to a CombinedLoader object for computing the num_training_batches
self.train_dataloader = CombinedLoader(self.train_dataloader, self.data_connector.multiple_trainloader_mode)

self.num_training_batches = len(self.train_dataloader) if has_len(self.train_dataloader) else float("inf")
module = model or self.lightning_module or self.datamodule
self.num_training_batches = (
len(self.train_dataloader)
if has_len_all_ranks(self.train_dataloader, self.training_type_plugin, module)
else float("inf")
)

if isinstance(self.limit_train_batches, int) or self.limit_train_batches == 0.0:
self.num_training_batches = min(self.num_training_batches, int(self.limit_train_batches))
@@ -369,7 +374,7 @@ def reset_train_dataloader(self, model: Optional["pl.LightningModule"] = None) -
"If you want to disable validation set `limit_val_batches` to 0.0 instead."
)
else:
if not has_len(self.train_dataloader):
if not has_len_all_ranks(self.train_dataloader, self.training_type_plugin, module):
if self.val_check_interval == 1.0:
self.val_check_batch = float("inf")
else:
@@ -448,9 +453,14 @@ def _reset_eval_dataloader(

# determine number of batches
# datasets could be none, 1 or 2+
module = model or self.lightning_module or self.datamodule
if len(dataloaders) != 0:
for i, dataloader in enumerate(dataloaders):
num_batches = len(dataloader) if has_len(dataloader) else float("inf")
num_batches = (
len(dataloader)
if has_len_all_ranks(dataloader, self.training_type_plugin, module)
else float("inf")
)
self._worker_check(dataloader, f"{mode.dataloader_prefix}_dataloader {i}")

# percent or num_steps
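The pattern above (real length for sized dataloaders, `float("inf")` otherwise) is applied to both train and eval loaders. A small illustrative sketch follows; `_num_batches` is a hypothetical helper used only to restate the logic, not part of this PR.

from typing import Union

from torch.utils.data import DataLoader

from pytorch_lightning.utilities.data import has_len_all_ranks


def _num_batches(dataloader: DataLoader, training_type_plugin, module) -> Union[int, float]:
    # Sized dataloaders report their real length; anything without a usable
    # `__len__` across ranks is treated as an infinite stream of batches.
    if has_len_all_ranks(dataloader, training_type_plugin, module):
        return len(dataloader)
    return float("inf")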
11 changes: 7 additions & 4 deletions pytorch_lightning/tuner/batch_size_scaling.py
@@ -16,11 +16,13 @@
import uuid
from typing import Optional, Tuple

from torch.utils.data import DataLoader

import pytorch_lightning as pl
from pytorch_lightning.loggers.base import DummyLogger
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.cloud_io import get_filesystem
from pytorch_lightning.utilities.data import has_len
from pytorch_lightning.utilities.data import has_len_all_ranks
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.memory import garbage_collection_cuda, is_oom_error
from pytorch_lightning.utilities.parsing import lightning_getattr, lightning_hasattr, lightning_setattr
@@ -255,13 +257,14 @@ def _adjust_batch_size(
if desc:
log.info(f"Batch size {batch_size} {desc}, trying batch size {new_size}")

if not _is_valid_batch_size(new_size, trainer.train_dataloader):
if not _is_valid_batch_size(new_size, trainer.train_dataloader, trainer):
new_size = min(new_size, len(trainer.train_dataloader.dataset))

changed = new_size != batch_size
lightning_setattr(model, batch_arg_name, new_size)
return new_size, changed


def _is_valid_batch_size(current_size, dataloader):
return not has_len(dataloader) or current_size <= len(dataloader)
def _is_valid_batch_size(current_size: int, dataloader: DataLoader, trainer: "pl.Trainer"):
module = trainer.lightning_module or trainer.datamodule
return not has_len_all_ranks(dataloader, trainer.training_type_plugin, module) or current_size <= len(dataloader)
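Conceptually, the updated check accepts a candidate batch size when the dataloader length is unknown (treated as infinite) or when the candidate does not exceed the number of available batches. A hypothetical standalone restatement, not part of the PR:

def is_valid_batch_size(new_size: int, num_batches: float) -> bool:
    # `num_batches` is len(dataloader) for sized loaders and float("inf")
    # when `has_len_all_ranks` returns False.
    return num_batches == float("inf") or new_size <= num_batches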
50 changes: 50 additions & 0 deletions pytorch_lightning/utilities/data.py
@@ -17,7 +17,9 @@
import torch
from torch.utils.data import DataLoader, IterableDataset

import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException

BType = Union[torch.Tensor, str, Mapping[Any, "BType"], Iterable["BType"]]

@@ -75,6 +77,54 @@ def has_len(dataloader: DataLoader) -> bool:
return has_len


def has_len_all_ranks(
dataloader: DataLoader,
training_type: "pl.TrainingTypePlugin",
model: Union["pl.LightningModule", "pl.LightningDataModule"],
) -> bool:
"""Checks if a given Dataloader has ``__len__`` method implemented i.e. if it is a finite dataloader or
infinite dataloader.

Raises:
MisconfigurationException:
If the total length of the `Dataloader` across ranks is zero, as it requires at least one batch
"""
try:
total_length = training_type.reduce(torch.tensor(len(dataloader)).to(model.device), reduce_op="sum")
local_length = len(dataloader)

if total_length == 0:
raise MisconfigurationException(
" Total length of `Dataloader` across ranks is zero. Please make sure that it returns at least 1 batch"
)
if total_length > 0 and local_length == 0:
if model.allow_zero_length_dataloader_with_multiple_devices:
rank_zero_warn(
"Total length of `Dataloader` across ranks is non-zero, but local rank has zero length."
" Please be cautious of uneven batch length."
)
has_len = False
else:
raise MisconfigurationException(
"`Dataloader` within local rank has zero length. Please make sure that it returns at least 1 batch"
)
else:
has_len = True
except TypeError:
has_len = False
except NotImplementedError: # e.g. raised by torchtext if a batch_size_fn is used
has_len = False

if has_len and has_iterable_dataset(dataloader):
rank_zero_warn(
"Your `IterableDataset` has `__len__` defined."
" In combination with multi-process data loading (when num_workers > 1),"
" `__len__` could be inaccurate if each worker is not configured independently"
" to avoid having duplicate data."
)
return has_len


def get_len(dataloader: DataLoader) -> Union[int, float]:
"""Return the length of the given DataLoader.

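The core idea of `has_len_all_ranks` is to sum the local dataloader length over every rank before deciding whether the loader is usable. The PR does this through `training_type.reduce(..., reduce_op="sum")`; the sketch below shows the same idea with plain `torch.distributed` purely for illustration. It assumes a process group is already initialized and is not the plugin's actual reduce implementation.

import torch
import torch.distributed as dist
from torch.utils.data import DataLoader


def total_length_across_ranks(dataloader: DataLoader) -> int:
    # Each rank contributes its local number of batches; the all-reduce sums
    # the contributions so every rank sees the same total.
    local_length = torch.tensor(len(dataloader))
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(local_length, op=dist.ReduceOp.SUM)
    return int(local_length.item())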
12 changes: 6 additions & 6 deletions tests/trainer/test_dataloaders.py
@@ -26,7 +26,7 @@
import tests.helpers.pipelines as tpipes
from pytorch_lightning import Callback, seed_everything, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.utilities.data import has_iterable_dataset, has_len
from pytorch_lightning.utilities.data import has_iterable_dataset, has_len_all_ranks
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base import EvalModelTemplate
from tests.helpers.boring_model import BoringModel, RandomDataset, RandomIterableDataset, RandomIterableDatasetWithLen
@@ -265,7 +265,7 @@ def test_inf_dataloaders_with_limit_percent_batches(tmpdir, limit_train_batches,

num_batches = 128 / batch_size
for dl in (train_dl, val_dl, test_dl):
if has_len(dl):
if has_len_all_ranks(dl, trainer.training_type_plugin, model):
assert len(dl) == num_batches
else:
assert sum(1 for _ in dl) == num_batches
@@ -846,10 +846,10 @@ def __len__(self):
return len(original_dataset)

# with __len__ defined
trainer = Trainer(default_root_dir=tmpdir, max_steps=3)
dataloader = DataLoader(IterableWithLen(), batch_size=16)
assert has_len(dataloader)
assert has_len_all_ranks(dataloader, trainer.training_type_plugin, model)
assert has_iterable_dataset(dataloader)
trainer = Trainer(default_root_dir=tmpdir, max_steps=3)
with pytest.warns(UserWarning, match="Your `IterableDataset` has `__len__` defined."):
trainer.validate(model, val_dataloaders=[dataloader])
with pytest.warns(UserWarning, match="Your `IterableDataset` has `__len__` defined."):
@@ -860,10 +860,10 @@ def __len__(self):
trainer.predict(model, dataloaders=[dataloader])

# without __len__ defined
trainer = Trainer(default_root_dir=tmpdir, max_steps=3)
dataloader = DataLoader(IterableWithoutLen(), batch_size=16)
assert not has_len(dataloader)
assert not has_len_all_ranks(dataloader, trainer.training_type_plugin, model)
assert has_iterable_dataset(dataloader)
trainer = Trainer(default_root_dir=tmpdir, max_steps=3)
trainer.validate(model, val_dataloaders=dataloader)
trainer.fit(model, train_dataloader=dataloader, val_dataloaders=[dataloader])
trainer.test(model, test_dataloaders=dataloader)
22 changes: 20 additions & 2 deletions tests/utilities/test_data.py
@@ -2,8 +2,16 @@
import torch
from torch.utils.data.dataloader import DataLoader

from pytorch_lightning.utilities.data import extract_batch_size, get_len, has_iterable_dataset, has_len
from tests.helpers.boring_model import RandomDataset, RandomIterableDataset
from pytorch_lightning import Trainer
from pytorch_lightning.utilities.data import (
extract_batch_size,
get_len,
has_iterable_dataset,
has_len,
has_len_all_ranks,
)
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers.boring_model import BoringModel, RandomDataset, RandomIterableDataset


def test_extract_batch_size():
@@ -53,3 +61,13 @@ def test_get_len():

assert isinstance(value, float)
assert value == float("inf")


def test_has_len_all_rank():
trainer = Trainer(fast_dev_run=True)
model = BoringModel()

with pytest.raises(MisconfigurationException, match="Total length of `Dataloader` across ranks is zero."):
assert not has_len_all_ranks(DataLoader(RandomDataset(0, 0)), trainer.training_type_plugin, model)

assert has_len_all_ranks(DataLoader(RandomDataset(1, 1)), trainer.training_type_plugin, model)