From 0bec9f3063f2628145719a135785dd37400ce29b Mon Sep 17 00:00:00 2001 From: Quentin Lhoest <42851186+lhoestq@users.noreply.github.com> Date: Thu, 15 Dec 2022 20:15:54 +0100 Subject: [PATCH] Support torch dataloader without torch formatting (#5357) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * dynamically add torch IterableDataset parent class * tests * docs * subclass torch Iterable dataset in __init__ * fix tests * Update src/datasets/iterable_dataset.py Co-authored-by: Mario Šaško * polina's comments Co-authored-by: Mario Šaško --- docs/source/use_with_pytorch.mdx | 10 +- src/datasets/filesystems/__init__.py | 18 +++ .../formatting/dataset_wrappers/__init__.py | 0 .../torch_iterable_dataset.py | 68 ---------- src/datasets/iterable_dataset.py | 124 ++++++++++++------ tests/test_iterable_dataset.py | 52 ++++---- tests/utils.py | 10 +- 7 files changed, 144 insertions(+), 138 deletions(-) delete mode 100644 src/datasets/formatting/dataset_wrappers/__init__.py delete mode 100644 src/datasets/formatting/dataset_wrappers/torch_iterable_dataset.py diff --git a/docs/source/use_with_pytorch.mdx b/docs/source/use_with_pytorch.mdx index cdc6cb644c1..5a812b3c1e8 100644 --- a/docs/source/use_with_pytorch.mdx +++ b/docs/source/use_with_pytorch.mdx @@ -150,7 +150,7 @@ Like `torch.utils.data.Dataset` objects, a [`Dataset`] can be passed directly to ### Optimize data loading There are several ways you can increase the speed your data is loaded which can save you time, especially if you are working with large datasets. -PyTorch offers parallelized data loading, retrieving batches of indices instead of individually, and streaming to progressively download datasets. +PyTorch offers parallelized data loading, retrieving batches of indices instead of individually, and streaming to iterate over the dataset without downloading it on disk. #### Use multiple Workers @@ -200,8 +200,8 @@ You must use a `BatchSampler` if you want the transform to be given full batches ### Stream data -Loading a dataset in streaming mode is useful to progressively download the data you need while iterating over the dataset. -Set the format of a streaming dataset to `torch`, and it inherits from `torch.utils.data.IterableDataset` so you can pass it to a `DataLoader`: +Loading a dataset in streaming mode allows one to iterate over the dataset without downloading it on disk. +An iterable dataset from `datasets` inherits from `torch.utils.data.IterableDataset` so you can pass it to a `DataLoader`: ```py >>> import numpy as np @@ -209,14 +209,14 @@ Set the format of a streaming dataset to `torch`, and it inherits from `torch.ut >>> from torch.utils.data import DataLoader >>> data = np.random.rand(10_000) >>> Dataset.from_dict({"data": data}).push_to_hub("/my_dataset") # Upload to the Hugging Face Hub ->>> ds = load_dataset("/my_dataset", streaming=True, split="train").with_format("torch") +>>> ds = load_dataset("/my_dataset", streaming=True, split="train") >>> dataloader = DataLoader(ds, batch_size=32) ``` If the dataset is split in several shards (i.e. 
if the dataset consists of multiple data files), then you can stream in parallel using `num_workers`: ```py ->>> ds = load_dataset("c4", "en", streaming=True, split="train").with_format("torch") +>>> ds = load_dataset("c4", "en", streaming=True, split="train") >>> ds.n_shards 1024 >>> dataloader = DataLoader(ds, batch_size=32, num_workers=4) diff --git a/src/datasets/filesystems/__init__.py b/src/datasets/filesystems/__init__.py index 577fbb36c40..cbc4c5a7464 100644 --- a/src/datasets/filesystems/__init__.py +++ b/src/datasets/filesystems/__init__.py @@ -1,7 +1,9 @@ import importlib +import threading from typing import List import fsspec +import fsspec.asyn from . import compression from .hffilesystem import HfFileSystem @@ -50,3 +52,19 @@ def is_remote_filesystem(fs: fsspec.AbstractFileSystem) -> bool: return True else: return False + + +def _reset_fsspec_lock() -> None: + """ + Clear reference to the loop and thread. + This is necessary otherwise HTTPFileSystem hangs in the ML training loop. + Only required for fsspec >= 0.9.0 + See https://github.com/fsspec/gcsfs/issues/379 + """ + if hasattr(fsspec.asyn, "reset_lock"): + # for future fsspec>2022.05.0 + fsspec.asyn.reset_lock() + else: + fsspec.asyn.iothread[0] = None + fsspec.asyn.loop[0] = None + fsspec.asyn.lock = threading.Lock() diff --git a/src/datasets/formatting/dataset_wrappers/__init__.py b/src/datasets/formatting/dataset_wrappers/__init__.py deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/src/datasets/formatting/dataset_wrappers/torch_iterable_dataset.py b/src/datasets/formatting/dataset_wrappers/torch_iterable_dataset.py deleted file mode 100644 index ea80f5c76fc..00000000000 --- a/src/datasets/formatting/dataset_wrappers/torch_iterable_dataset.py +++ /dev/null @@ -1,68 +0,0 @@ -import threading - -import fsspec.asyn -import torch - -from ...iterable_dataset import IterableDataset, _apply_feature_types_on_example -from ...utils.logging import get_logger - - -logger = get_logger(__name__) - - -def _set_fsspec_for_multiprocess() -> None: - """ - Clear reference to the loop and thread. - This is necessary otherwise HTTPFileSystem hangs in the ML training loop. - Only required for fsspec >= 0.9.0 - See https://github.com/fsspec/gcsfs/issues/379 - """ - if hasattr(fsspec.asyn, "reset_lock"): - # for future fsspec>2022.05.0 - fsspec.asyn.reset_lock() - else: - fsspec.asyn.iothread[0] = None - fsspec.asyn.loop[0] = None - fsspec.asyn.lock = threading.Lock() - - -class TorchIterableDataset(IterableDataset, torch.utils.data.IterableDataset): - def __iter__(self): - # fix for fsspec when using multprocess - _set_fsspec_for_multiprocess() - worker_info = torch.utils.data.get_worker_info() - if worker_info is None: # single-process data loading, return the full iterator - yield from IterableDataset.__iter__(self) - else: # in a worker process - # check if there aren't too many workers - if worker_info.id == 0 and self.n_shards < worker_info.num_workers: - logger.warning( - f"Too many dataloader workers: {worker_info.num_workers} (max is dataset.n_shards={self.n_shards}). " - f"Stopping dataloader workers [{self.n_shards}...{worker_info.num_workers -1}]." - ) - logger.warning( - f"To parallelize data loading, we give each process some shards (or data sources) to process. " - f"Therefore it's unnecessary to have a number of workers greater than dataset.n_shards={self.n_shards}." - f"To enable more parallelism, please split the dataset in more files than {self.n_shards}." 
- ) - # split workload - shards_indices = list(range(worker_info.id, self.n_shards, worker_info.num_workers)) - if shards_indices: - logger.debug( - f"dataloader worker#{worker_info.id}, ': Starting to iterate over {len(shards_indices)}/{self.n_shards} shards." - ) - for shard_idx in shards_indices: - for key, example in self._iter_shard(shard_idx): - if self.features: - yield _apply_feature_types_on_example( - example, self.features, token_per_repo_id=self._token_per_repo_id - ) - else: - yield example - logger.debug( - f"dataloader worker#{worker_info.id}, ': Finished iterating over {len(shards_indices)}/{self.n_shards} shards." - ) - else: - logger.debug( - f"dataloader worker#{worker_info.id}, ': Stopping... Number of dataset shards < num_workers ({self.n_shards}<{worker_info.num_workers})." - ) diff --git a/src/datasets/iterable_dataset.py b/src/datasets/iterable_dataset.py index 65429768812..3234a3df744 100644 --- a/src/datasets/iterable_dataset.py +++ b/src/datasets/iterable_dataset.py @@ -1,24 +1,31 @@ import copy import itertools +import sys from collections import Counter from copy import deepcopy from dataclasses import dataclass from itertools import cycle, islice -from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Union +from typing import Any, Callable, Dict, Iterator, List, Optional, Union import numpy as np import pyarrow as pa +from . import config from .arrow_dataset import DatasetInfoMixin from .features import Features from .features.features import FeatureType, _align_features, _check_if_features_can_be_aligned -from .formatting import PythonFormatter +from .filesystems import _reset_fsspec_lock +from .formatting import PythonFormatter, get_format_type_from_alias from .info import DatasetInfo from .splits import NamedSplit from .table import table_cast +from .utils.logging import get_logger from .utils.sharding import _number_of_shards_in_gen_kwargs, _shuffle_gen_kwargs, _split_gen_kwargs +logger = get_logger(__name__) + + def _infer_features_from_batch(batch: Dict[str, list], try_features: Optional[Features] = None) -> Features: pa_table = pa.Table.from_pydict(batch) if try_features is not None: @@ -719,6 +726,15 @@ class ShufflingConfig: generator: np.random.Generator +def _maybe_add_torch_iterable_dataset_parent_class(cls): + """Add torch.utils.data.IterableDataset as a parent class if 'torch' is available""" + if config.TORCH_AVAILABLE: + import torch.utils.data + + if torch.utils.data.IterableDataset not in cls.__bases__: + cls.__bases__ += (torch.utils.data.IterableDataset,) + + class IterableDataset(DatasetInfoMixin): """A Dataset backed by an iterable.""" @@ -739,6 +755,15 @@ def __init__( self._shuffling = shuffling self._epoch = 0 self._token_per_repo_id: Dict[str, Union[str, bool, None]] = token_per_repo_id or {} + _maybe_add_torch_iterable_dataset_parent_class(self.__class__) + + def __getstate__(self): + return self.__dict__ + + def __setstate__(self, d): + self.__dict__ = d + # Re-add torch iterable dataset as a parent class, since dynamically added parent classes are not kept when pickling + _maybe_add_torch_iterable_dataset_parent_class(self.__class__) def _head(self, n=5): return _examples_to_batch([x for key, x in islice(self._iter(), n)]) @@ -772,7 +797,55 @@ def _iter_shard(self, shard_idx: int): ex_iterable = self._ex_iterable yield from ex_iterable.shard_data_sources(shard_idx) + def _iter_pytorch(self, worker_info): + # fix for fsspec when using multprocess + _reset_fsspec_lock() + if worker_info is None: # single-process 
data loading, return the full iterator + yield from IterableDataset.__iter__(self) + else: # in a worker process + # check if there aren't too many workers + if worker_info.id == 0 and self.n_shards < worker_info.num_workers: + logger.warning( + f"Too many dataloader workers: {worker_info.num_workers} (max is dataset.n_shards={self.n_shards}). " + f"Stopping dataloader workers [{self.n_shards}...{worker_info.num_workers -1}]." + ) + logger.warning( + f"To parallelize data loading, we give each process some shards (or data sources) to process. " + f"Therefore it's unnecessary to have a number of workers greater than dataset.n_shards={self.n_shards}." + f"To enable more parallelism, please split the dataset in more files than {self.n_shards}." + ) + # split workload + shards_indices = list(range(worker_info.id, self.n_shards, worker_info.num_workers)) + if shards_indices: + logger.debug( + f"dataloader worker#{worker_info.id}, ': Starting to iterate over {len(shards_indices)}/{self.n_shards} shards." + ) + for shard_idx in shards_indices: + for key, example in self._iter_shard(shard_idx): + if self.features: + yield _apply_feature_types_on_example( + example, self.features, token_per_repo_id=self._token_per_repo_id + ) + else: + yield example + logger.debug( + f"dataloader worker#{worker_info.id}, ': Finished iterating over {len(shards_indices)}/{self.n_shards} shards." + ) + else: + logger.debug( + f"dataloader worker#{worker_info.id}, ': Stopping... Number of dataset shards < num_workers ({self.n_shards}<{worker_info.num_workers})." + ) + def __iter__(self): + if "torch" in sys.modules: + import torch.utils.data + + worker_info = torch.utils.data.get_worker_info() + if isinstance(self, torch.utils.data.IterableDataset) and worker_info is not None: + # We're a torch.utils.data.IterableDataset in a PyTorch worker process + yield from self._iter_pytorch(worker_info) + return + for key, example in self._iter(): if self.features: # `IterableDataset` automatically fills missing columns with None. 
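The hunks above are the heart of the change: `torch.utils.data.IterableDataset` is appended to the class's `__bases__` only when torch is available, and `__setstate__` re-appends it because dynamically added parent classes are not kept when pickling (e.g. into `DataLoader` worker processes). A minimal, self-contained sketch of that pattern, using placeholder classes (`_TorchStub`, `_InfoMixin`) instead of the real torch and `datasets` classes:

```py
# Illustrative sketch only; _TorchStub stands in for torch.utils.data.IterableDataset
# and _InfoMixin for DatasetInfoMixin -- neither name exists in the patch.
import pickle


class _TorchStub:
    pass


class _InfoMixin:
    pass


def _maybe_add_parent(cls, parent=_TorchStub):
    # Append the extra base class only once; the real helper is additionally
    # gated on config.TORCH_AVAILABLE so importing torch never fails.
    if parent not in cls.__bases__:
        cls.__bases__ += (parent,)


class MyIterable(_InfoMixin):  # mirrors IterableDataset(DatasetInfoMixin)
    def __init__(self, data):
        self.data = list(data)
        _maybe_add_parent(self.__class__)

    def __getstate__(self):
        return self.__dict__

    def __setstate__(self, d):
        self.__dict__ = d
        # Dynamically added parent classes are not kept when pickling, so a
        # freshly imported class in a worker process needs them re-added here.
        _maybe_add_parent(self.__class__)

    def __iter__(self):
        yield from self.data


ds = MyIterable(range(3))
assert isinstance(ds, _TorchStub)
assert list(pickle.loads(pickle.dumps(ds))) == [0, 1, 2]
```

With this in place, `__iter__` only has to check `"torch" in sys.modules` and `torch.utils.data.get_worker_info()` to decide whether to route iteration through `_iter_pytorch`, which is what the hunk above implements.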
@@ -872,11 +945,12 @@ def with_format( type (`str`, optional, default None): if set to "torch", the returned dataset will be a subclass of torch.utils.data.IterableDataset to be used in a DataLoader """ + type = get_format_type_from_alias(type) # TODO(QL): add examples formatting to get tensors when using the "torch" format # TODO(QL): add format_kwargs # TODO(QL): add format_columns and return_all_columns # TODO(QL): add pandas, numpy and tf formats - return iterable_dataset( + return IterableDataset( ex_iterable=self._ex_iterable, info=self._info.copy(), split=self._split, @@ -987,7 +1061,7 @@ def map( ) info = self.info.copy() info.features = features - return iterable_dataset( + return IterableDataset( ex_iterable=ex_iterable, info=info, split=self._split, @@ -1059,7 +1133,7 @@ def filter( batched=batched, batch_size=batch_size, ) - return iterable_dataset( + return IterableDataset( ex_iterable=ex_iterable, info=info, split=self._split, @@ -1123,7 +1197,7 @@ def shuffle( else: generator = deepcopy(generator) shuffling = ShufflingConfig(generator=generator) - return iterable_dataset( + return IterableDataset( ex_iterable=BufferShuffledExamplesIterable( self._ex_iterable, buffer_size=buffer_size, generator=generator ).shuffle_data_sources(generator), @@ -1166,7 +1240,7 @@ def skip(self, n) -> "IterableDataset": ``` """ ex_iterable = SkipExamplesIterable(self._ex_iterable, n) - return iterable_dataset( + return IterableDataset( ex_iterable=ex_iterable, info=self._info.copy(), split=self._split, @@ -1197,7 +1271,7 @@ def take(self, n) -> "IterableDataset": ``` """ ex_iterable = TakeExamplesIterable(self._ex_iterable, n) - return iterable_dataset( + return IterableDataset( ex_iterable=ex_iterable, info=self._info.copy(), split=self._split, @@ -1387,7 +1461,7 @@ def cast_column(self, column: str, feature: FeatureType) -> "IterableDataset": info.copy() except ValueError: info.task_templates = None - return iterable_dataset( + return IterableDataset( ex_iterable=self._ex_iterable, info=info, split=self._split, @@ -1437,7 +1511,7 @@ def cast( info.copy() except ValueError: info.task_templates = None - return iterable_dataset( + return IterableDataset( ex_iterable=self._ex_iterable, info=info, split=self._split, @@ -1455,7 +1529,7 @@ def _resolve_features(self): features = _infer_features_from_batch(self._head()) info = self.info.copy() info.features = features - return iterable_dataset( + return IterableDataset( ex_iterable=self._ex_iterable, info=info, split=self._split, @@ -1465,30 +1539,6 @@ def _resolve_features(self): ) -def iterable_dataset( - ex_iterable: Iterable, - info: Optional[DatasetInfo] = None, - split: Optional[NamedSplit] = None, - format_type: Optional[str] = None, - shuffling: Optional[ShufflingConfig] = None, - token_per_repo_id: Optional[Dict[str, Union[str, bool, None]]] = None, -): - if format_type is not None and format_type == "torch": - from .formatting.dataset_wrappers.torch_iterable_dataset import TorchIterableDataset - - cls = TorchIterableDataset - else: - cls = IterableDataset - return cls( - ex_iterable=ex_iterable, - info=info, - split=split, - format_type=format_type, - shuffling=shuffling, - token_per_repo_id=token_per_repo_id, - ) - - def _concatenate_iterable_datasets( dsets: List[IterableDataset], info: Optional[DatasetInfo] = None, @@ -1546,7 +1596,7 @@ def _concatenate_iterable_datasets( # Get all the auth tokens per repository - in case the datasets come from different private repositories token_per_repo_id = {repo_id: token for dataset in dsets for 
repo_id, token in dataset._token_per_repo_id.items()} # Return new daset - return iterable_dataset(ex_iterable=ex_iterable, info=info, split=split, token_per_repo_id=token_per_repo_id) + return IterableDataset(ex_iterable=ex_iterable, info=info, split=split, token_per_repo_id=token_per_repo_id) def _interleave_iterable_datasets( @@ -1614,4 +1664,4 @@ def _interleave_iterable_datasets( repo_id: token for dataset in datasets for repo_id, token in dataset._token_per_repo_id.items() } # Return new daset - return iterable_dataset(ex_iterable=ex_iterable, info=info, split=split, token_per_repo_id=token_per_repo_id) + return IterableDataset(ex_iterable=ex_iterable, info=info, split=split, token_per_repo_id=token_per_repo_id) diff --git a/tests/test_iterable_dataset.py b/tests/test_iterable_dataset.py index de5d0dafb3e..be0149b0caf 100644 --- a/tests/test_iterable_dataset.py +++ b/tests/test_iterable_dataset.py @@ -8,6 +8,7 @@ from datasets import load_dataset from datasets.combine import concatenate_datasets, interleave_datasets from datasets.features import ClassLabel, Features, Value +from datasets.formatting import get_format_type_from_alias from datasets.info import DatasetInfo from datasets.iterable_dataset import ( BufferShuffledExamplesIterable, @@ -24,7 +25,6 @@ VerticallyConcatenatedMultiSourcesExamplesIterable, _batch_to_examples, _examples_to_batch, - iterable_dataset, ) from .utils import is_rng_equal, require_torch @@ -617,13 +617,6 @@ def test_iterable_dataset(): assert list(dataset) == expected -def test_iterable_dataset_factory(): - ex_iterable = ExamplesIterable(generate_examples_fn, {}) - dataset = iterable_dataset(ex_iterable) - assert isinstance(dataset, IterableDataset) - assert dataset._ex_iterable is ex_iterable - - def test_iterable_dataset_from_generator(): data = [ {"col_1": "0", "col_2": 0, "col_3": 0.0}, @@ -653,26 +646,26 @@ def gen(shard_names): @require_torch -def test_iterable_dataset_factory_torch_integration(): - import torch +def test_iterable_dataset_torch_integration(): ex_iterable = ExamplesIterable(generate_examples_fn, {}) - dataset = iterable_dataset(ex_iterable, format_type="torch") - assert isinstance(dataset, IterableDataset) + dataset = IterableDataset(ex_iterable) + import torch.utils.data + assert isinstance(dataset, torch.utils.data.IterableDataset) - assert dataset._format_type == "torch" + assert isinstance(dataset, IterableDataset) assert dataset._ex_iterable is ex_iterable @require_torch -def test_iterable_dataset_factory_torch_picklable(): +def test_iterable_dataset_torch_picklable(): import pickle ex_iterable = ExamplesIterable(generate_examples_fn, {}) - dataset = iterable_dataset(ex_iterable, format_type="torch") + dataset = IterableDataset(ex_iterable, format_type="torch") reloaded_dataset = pickle.loads(pickle.dumps(dataset)) - import torch + import torch.utils.data assert isinstance(reloaded_dataset, IterableDataset) assert isinstance(reloaded_dataset, torch.utils.data.IterableDataset) @@ -682,10 +675,10 @@ def test_iterable_dataset_factory_torch_picklable(): @require_torch def test_iterable_dataset_with_format_torch(): + ex_iterable = ExamplesIterable(generate_examples_fn, {}) + dataset = IterableDataset(ex_iterable) from torch.utils.data import DataLoader - ex_iterable = ExamplesIterable(generate_examples_fn, {}) - dataset = iterable_dataset(ex_iterable).with_format("torch") dataloader = DataLoader(dataset) assert len(list(dataloader)) == len(list(ex_iterable)) @@ -695,7 +688,7 @@ def test_iterable_dataset_torch_dataloader_parallel(): 
from torch.utils.data import DataLoader ex_iterable = ExamplesIterable(generate_examples_fn, {}) - dataset = iterable_dataset(ex_iterable).with_format("torch") + dataset = IterableDataset(ex_iterable).with_format("torch") dataloader = DataLoader(dataset, num_workers=2, batch_size=None) result = list(dataloader) expected = [example for _, example in ex_iterable] @@ -709,7 +702,7 @@ def test_sharded_iterable_dataset_torch_dataloader_parallel(n_shards, num_worker from torch.utils.data import DataLoader ex_iterable = ExamplesIterable(generate_examples_fn, {"filepaths": [f"{i}.txt" for i in range(n_shards)]}) - dataset = iterable_dataset(ex_iterable).with_format("torch") + dataset = IterableDataset(ex_iterable).with_format("torch") dataloader = DataLoader(dataset, batch_size=None, num_workers=num_workers) result = list(dataloader) expected = [example for _, example in ex_iterable] @@ -920,15 +913,22 @@ def test_iterable_dataset_features_cast_to_python(): assert list(dataset) == [{"timestamp": pd.Timestamp(2020, 1, 1).to_pydatetime(), "array": [1] * 5, "id": 0}] -@require_torch -@pytest.mark.parametrize("format_type", [None, "torch", "python"]) +@pytest.mark.parametrize( + "format_type", [None, "torch", "python", "tf", "tensorflow", "np", "numpy", "jax", "pd", "pandas"] +) def test_iterable_dataset_with_format(dataset: IterableDataset, format_type): formatted_dataset = dataset.with_format(format_type) - assert formatted_dataset._format_type == format_type - if format_type == "torch": - import torch + assert formatted_dataset._format_type == get_format_type_from_alias(format_type) - assert isinstance(formatted_dataset, torch.utils.data.IterableDataset) + +@require_torch +def test_iterable_dataset_is_torch_iterable_dataset(dataset: IterableDataset): + from torch.utils.data import DataLoader, _DatasetKind + + dataloader = DataLoader(dataset) + assert dataloader._dataset_kind == _DatasetKind.Iterable + out = list(dataloader) + assert len(out) == DEFAULT_N_EXAMPLES @pytest.mark.parametrize("n", [0, 2, int(1e10)]) diff --git a/tests/utils.py b/tests/utils.py index 397b2ed261d..dc0fc633371 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -19,6 +19,12 @@ from datasets import config +if config.PY_VERSION < version.parse("3.8"): + import importlib_metadata +else: + import importlib.metadata as importlib_metadata + + def parse_flag_from_env(key, default=False): try: value = os.environ[key] @@ -66,12 +72,12 @@ def parse_flag_from_env(key, default=False): ) require_torchaudio = pytest.mark.skipif( find_spec("torchaudio") is None - or version.parse(import_module("torchaudio").__version__) >= version.parse("0.12.0"), + or version.parse(importlib_metadata.version("torchaudio")) >= version.parse("0.12.0"), reason="test requires torchaudio<0.12", ) require_torchaudio_latest = pytest.mark.skipif( find_spec("torchaudio") is None - or version.parse(import_module("torchaudio").__version__) < version.parse("0.12.0"), + or version.parse(importlib_metadata.version("torchaudio")) < version.parse("0.12.0"), reason="test requires torchaudio>=0.12", )
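End-to-end, the patch lets a streaming dataset be passed straight to a `DataLoader` without calling `.with_format("torch")`. A usage sketch, assuming this patch and torch are installed and that the `c4` dataset from the docs example above is reachable:

```py
# Usage sketch (network access and an installed torch are assumed; "c4"/"en"
# is the dataset used in the updated docs above).
from datasets import load_dataset
from torch.utils.data import DataLoader

ds = load_dataset("c4", "en", streaming=True, split="train")  # plain IterableDataset
print(ds.n_shards)  # 1024 data files, so up to 1024 workers can be useful

# Each worker with id w iterates shards w, w + num_workers, w + 2 * num_workers, ...
# (the range() in _iter_pytorch above), so keep num_workers <= ds.n_shards to avoid
# the "Too many dataloader workers" warning.
dataloader = DataLoader(ds, batch_size=32, num_workers=4)

for batch in dataloader:
    # With the default collate_fn, batch["text"] is a list of 32 strings.
    break
```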