From 85cd558a3f71a4e7559bda9a31c8e416d4297a16 Mon Sep 17 00:00:00 2001 From: Rohit Gupta Date: Fri, 28 Aug 2020 00:58:29 +0530 Subject: [PATCH] Follow up of #2892 (#3202) * Follow up of #2892 * typo * iterabledataset --- pytorch_lightning/trainer/data_loading.py | 45 +++------------------- pytorch_lightning/utilities/data.py | 47 +++++++++++++++++++++++ tests/trainer/test_dataloaders.py | 8 ++-- 3 files changed, 56 insertions(+), 44 deletions(-) create mode 100644 pytorch_lightning/utilities/data.py diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index ea564c950d9e7..e8914d66d6dad 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -15,24 +15,18 @@ import multiprocessing import platform from abc import ABC, abstractmethod -from distutils.version import LooseVersion from typing import Union, List, Tuple, Callable, Optional -import torch import torch.distributed as torch_distrib from torch.utils.data import DataLoader, RandomSampler, SequentialSampler from torch.utils.data.distributed import DistributedSampler from pytorch_lightning.core import LightningModule from pytorch_lightning.utilities import rank_zero_warn +from pytorch_lightning.utilities.data import has_iterable_dataset, has_len from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.debugging import InternalDebugger -try: - from torch.utils.data import IterableDataset - ITERABLE_DATASET_EXISTS = True -except ImportError: - ITERABLE_DATASET_EXISTS = False try: from apex import amp @@ -56,35 +50,6 @@ HOROVOD_AVAILABLE = True -def _has_iterable_dataset(dataloader: DataLoader): - return ITERABLE_DATASET_EXISTS and hasattr(dataloader, 'dataset') \ - and isinstance(dataloader.dataset, IterableDataset) - - -def _has_len(dataloader: DataLoader) -> bool: - """ Checks if a given Dataloader has __len__ method implemented i.e. if - it is a finite dataloader or infinite dataloader. """ - - try: - # try getting the length - if len(dataloader) == 0: - raise ValueError('`Dataloader` returned 0 length.' - ' Please make sure that your Dataloader at least returns 1 batch') - has_len = True - except TypeError: - has_len = False - except NotImplementedError: # e.g. raised by torchtext if a batch_size_fn is used - has_len = False - - if has_len and _has_iterable_dataset(dataloader) and LooseVersion(torch.__version__) >= LooseVersion("1.4.0"): - rank_zero_warn( - 'Your `IterableDataset` has `__len__` defined.' - ' In combination with multi-processing data loading (e.g. batch size > 1),' - ' this can lead to unintended side effects since the samples will be duplicated.' - ) - return has_len - - class TrainerDataLoadingMixin(ABC): # this is just a summary on variables used in this abstract class, @@ -147,7 +112,7 @@ def auto_add_sampler(self, dataloader: DataLoader, train: bool) -> DataLoader: # don't do anything if it's not a dataloader is_dataloader = isinstance(dataloader, DataLoader) # don't manipulate iterable datasets - is_iterable_ds = _has_iterable_dataset(dataloader) + is_iterable_ds = has_iterable_dataset(dataloader) if not is_dataloader or is_iterable_ds: return dataloader @@ -214,7 +179,7 @@ def reset_train_dataloader(self, model: LightningModule) -> None: # automatically add samplers self.train_dataloader = self.auto_add_sampler(self.train_dataloader, train=True) - self.num_training_batches = len(self.train_dataloader) if _has_len(self.train_dataloader) else float('inf') + self.num_training_batches = len(self.train_dataloader) if has_len(self.train_dataloader) else float('inf') self._worker_check(self.train_dataloader, 'train dataloader') if isinstance(self.limit_train_batches, int) or self.limit_train_batches == 0.0: @@ -238,7 +203,7 @@ def reset_train_dataloader(self, model: LightningModule) -> None: f'to the number of the training batches ({self.num_training_batches}). ' 'If you want to disable validation set `limit_val_batches` to 0.0 instead.') else: - if not _has_len(self.train_dataloader): + if not has_len(self.train_dataloader): if self.val_check_interval == 1.0: self.val_check_batch = float('inf') else: @@ -305,7 +270,7 @@ def _reset_eval_dataloader( # datasets could be none, 1 or 2+ if len(dataloaders) != 0: for i, dataloader in enumerate(dataloaders): - num_batches = len(dataloader) if _has_len(dataloader) else float('inf') + num_batches = len(dataloader) if has_len(dataloader) else float('inf') self._worker_check(dataloader, f'{mode} dataloader {i}') # percent or num_steps diff --git a/pytorch_lightning/utilities/data.py b/pytorch_lightning/utilities/data.py new file mode 100644 index 0000000000000..54f81f20f9ab7 --- /dev/null +++ b/pytorch_lightning/utilities/data.py @@ -0,0 +1,47 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from distutils.version import LooseVersion +import torch +from torch.utils.data import DataLoader, IterableDataset + +from pytorch_lightning.utilities import rank_zero_warn + + +def has_iterable_dataset(dataloader: DataLoader): + return hasattr(dataloader, 'dataset') and isinstance(dataloader.dataset, IterableDataset) + + +def has_len(dataloader: DataLoader) -> bool: + """ Checks if a given Dataloader has __len__ method implemented i.e. if + it is a finite dataloader or infinite dataloader. """ + + try: + # try getting the length + if len(dataloader) == 0: + raise ValueError('`Dataloader` returned 0 length.' + ' Please make sure that your Dataloader at least returns 1 batch') + has_len = True + except TypeError: + has_len = False + except NotImplementedError: # e.g. raised by torchtext if a batch_size_fn is used + has_len = False + + if has_len and has_iterable_dataset(dataloader) and LooseVersion(torch.__version__) >= LooseVersion("1.4.0"): + rank_zero_warn( + 'Your `IterableDataset` has `__len__` defined.' + ' In combination with multi-processing data loading (e.g. batch size > 1),' + ' this can lead to unintended side effects since the samples will be duplicated.' + ) + return has_len diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index adb8be0bac178..85b8241e2b5ac 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -11,7 +11,7 @@ import tests.base.develop_pipelines as tpipes from pytorch_lightning import Trainer, Callback -from pytorch_lightning.trainer.data_loading import _has_iterable_dataset, _has_len +from pytorch_lightning.utilities.data import has_iterable_dataset, has_len from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.base import EvalModelTemplate @@ -624,7 +624,7 @@ def test_warning_with_few_workers(mock, tmpdir, ckpt_path): reason="IterableDataset with __len__ before 1.4 raises", ) def test_warning_with_iterable_dataset_and_len(tmpdir): - """ Tests that a warning messages is shown when an IterableDataset defines `__len__`. """ + """ Tests that a warning message is shown when an IterableDataset defines `__len__`. """ model = EvalModelTemplate() original_dataset = model.train_dataloader().dataset @@ -637,8 +637,8 @@ def __len__(self): return len(original_dataset) dataloader = DataLoader(IterableWithLen(), batch_size=16) - assert _has_len(dataloader) - assert _has_iterable_dataset(dataloader) + assert has_len(dataloader) + assert has_iterable_dataset(dataloader) trainer = Trainer( default_root_dir=tmpdir, max_steps=3,