Support torch dataloader without torch formatting (#5357)
* dynamically add torch IterableDataset parent class

* tests

* docs

* subclass torch Iterable dataset in __init__

* fix tests

* Update src/datasets/iterable_dataset.py

Co-authored-by: Mario Šaško <mariosasko777@gmail.com>

* polina's comments

Co-authored-by: Mario Šaško <mariosasko777@gmail.com>
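
In short: once `torch` is installed, every streaming `IterableDataset` is also a `torch.utils.data.IterableDataset`, so it can be passed to a `DataLoader` without calling `.with_format("torch")` first. A minimal sketch of the resulting usage, mirroring the docs change below (the `c4` dataset is just the example used there):

```py
>>> import torch.utils.data
>>> from datasets import load_dataset
>>> from torch.utils.data import DataLoader
>>> ds = load_dataset("c4", "en", streaming=True, split="train")  # no .with_format("torch") needed
>>> isinstance(ds, torch.utils.data.IterableDataset)  # the parent class is added dynamically
True
>>> dataloader = DataLoader(ds, batch_size=32, num_workers=4)
```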
lhoestq and mariosasko authored Dec 15, 2022
1 parent 53c563c commit 0bec9f3
Showing 7 changed files with 144 additions and 138 deletions.
10 changes: 5 additions & 5 deletions docs/source/use_with_pytorch.mdx
@@ -150,7 +150,7 @@ Like `torch.utils.data.Dataset` objects, a [`Dataset`] can be passed directly to
### Optimize data loading

There are several ways you can increase the speed at which your data is loaded, which can save you time, especially if you are working with large datasets.
PyTorch offers parallelized data loading, retrieving batches of indices instead of individually, and streaming to progressively download datasets.
PyTorch offers parallelized data loading, retrieving batches of indices instead of one index at a time, and streaming to iterate over the dataset without downloading it to disk.

#### Use multiple Workers

@@ -200,23 +200,23 @@ You must use a `BatchSampler` if you want the transform to be given full batches

### Stream data

Loading a dataset in streaming mode is useful to progressively download the data you need while iterating over the dataset.
Set the format of a streaming dataset to `torch`, and it inherits from `torch.utils.data.IterableDataset` so you can pass it to a `DataLoader`:
Loading a dataset in streaming mode allows you to iterate over the dataset without downloading it to disk.
An iterable dataset from `datasets` inherits from `torch.utils.data.IterableDataset` so you can pass it to a `DataLoader`:

```py
>>> import numpy as np
>>> from datasets import Dataset, load_dataset
>>> from torch.utils.data import DataLoader
>>> data = np.random.rand(10_000)
>>> Dataset.from_dict({"data": data}).push_to_hub("<username>/my_dataset") # Upload to the Hugging Face Hub
>>> ds = load_dataset("<username>/my_dataset", streaming=True, split="train").with_format("torch")
>>> ds = load_dataset("<username>/my_dataset", streaming=True, split="train")
>>> dataloader = DataLoader(ds, batch_size=32)
```

If the dataset is split into several shards (i.e. if the dataset consists of multiple data files), then you can stream in parallel using `num_workers`:

```py
>>> ds = load_dataset("c4", "en", streaming=True, split="train").with_format("torch")
>>> ds = load_dataset("c4", "en", streaming=True, split="train")
>>> ds.n_shards
1024
>>> dataloader = DataLoader(ds, batch_size=32, num_workers=4)
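
Parallel streaming works by giving each dataloader worker a subset of the dataset's shards; `IterableDataset._iter_pytorch` below hands them out round-robin. A small self-contained sketch of that assignment (the shard and worker counts are illustrative):

```py
# Round-robin shard assignment: worker w reads shards w, w + num_workers, w + 2 * num_workers, ...
n_shards = 10
num_workers = 4
for worker_id in range(num_workers):
    shards = list(range(worker_id, n_shards, num_workers))
    print(f"worker {worker_id}: shards {shards}")
# worker 0: shards [0, 4, 8]
# worker 1: shards [1, 5, 9]
# worker 2: shards [2, 6]
# worker 3: shards [3, 7]
```

This is also why requesting more workers than `ds.n_shards` buys nothing: the extra workers receive no shards and stop immediately, which is exactly what the warning in `_iter_pytorch` reports.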
18 changes: 18 additions & 0 deletions src/datasets/filesystems/__init__.py
@@ -1,7 +1,9 @@
import importlib
import threading
from typing import List

import fsspec
import fsspec.asyn

from . import compression
from .hffilesystem import HfFileSystem
@@ -50,3 +52,19 @@ def is_remote_filesystem(fs: fsspec.AbstractFileSystem) -> bool:
return True
else:
return False


def _reset_fsspec_lock() -> None:
"""
Clear the reference to the fsspec loop and thread.
This is necessary, otherwise HTTPFileSystem hangs in the ML training loop.
Only required for fsspec >= 0.9.0.
See https://github.com/fsspec/gcsfs/issues/379
"""
if hasattr(fsspec.asyn, "reset_lock"):
# for fsspec > 2022.05.0
fsspec.asyn.reset_lock()
else:
fsspec.asyn.iothread[0] = None
fsspec.asyn.loop[0] = None
fsspec.asyn.lock = threading.Lock()
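
For context, this helper is called at the start of each dataloader worker's iteration (see `IterableDataset._iter_pytorch` further down), so a forked worker drops the event loop and lock inherited from the parent process and lazily creates its own. A minimal sketch of the intended call site; `iterate_shards` is a hypothetical stand-in for the worker's actual iteration code:

```py
from datasets.filesystems import _reset_fsspec_lock

def iterate_shards(shards):
    # Clear the fsspec loop/lock inherited from the parent process so this
    # worker creates fresh ones; otherwise HTTPFileSystem can hang after fork.
    _reset_fsspec_lock()
    for shard in shards:
        yield from shard
```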
Empty file.
68 changes: 0 additions & 68 deletions src/datasets/formatting/dataset_wrappers/torch_iterable_dataset.py

This file was deleted.

124 changes: 87 additions & 37 deletions src/datasets/iterable_dataset.py
@@ -1,24 +1,31 @@
import copy
import itertools
import sys
from collections import Counter
from copy import deepcopy
from dataclasses import dataclass
from itertools import cycle, islice
from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Union
from typing import Any, Callable, Dict, Iterator, List, Optional, Union

import numpy as np
import pyarrow as pa

from . import config
from .arrow_dataset import DatasetInfoMixin
from .features import Features
from .features.features import FeatureType, _align_features, _check_if_features_can_be_aligned
from .formatting import PythonFormatter
from .filesystems import _reset_fsspec_lock
from .formatting import PythonFormatter, get_format_type_from_alias
from .info import DatasetInfo
from .splits import NamedSplit
from .table import table_cast
from .utils.logging import get_logger
from .utils.sharding import _number_of_shards_in_gen_kwargs, _shuffle_gen_kwargs, _split_gen_kwargs


logger = get_logger(__name__)


def _infer_features_from_batch(batch: Dict[str, list], try_features: Optional[Features] = None) -> Features:
pa_table = pa.Table.from_pydict(batch)
if try_features is not None:
@@ -719,6 +726,15 @@ class ShufflingConfig:
generator: np.random.Generator


def _maybe_add_torch_iterable_dataset_parent_class(cls):
"""Add torch.utils.data.IterableDataset as a parent class if 'torch' is available"""
if config.TORCH_AVAILABLE:
import torch.utils.data

if torch.utils.data.IterableDataset not in cls.__bases__:
cls.__bases__ += (torch.utils.data.IterableDataset,)


class IterableDataset(DatasetInfoMixin):
"""A Dataset backed by an iterable."""

@@ -739,6 +755,15 @@ def __init__(
self._shuffling = shuffling
self._epoch = 0
self._token_per_repo_id: Dict[str, Union[str, bool, None]] = token_per_repo_id or {}
_maybe_add_torch_iterable_dataset_parent_class(self.__class__)

def __getstate__(self):
return self.__dict__

def __setstate__(self, d):
self.__dict__ = d
# Re-add torch iterable dataset as a parent class, since dynamically added parent classes are not kept when pickling
_maybe_add_torch_iterable_dataset_parent_class(self.__class__)

def _head(self, n=5):
return _examples_to_batch([x for key, x in islice(self._iter(), n)])
@@ -772,7 +797,55 @@ def _iter_shard(self, shard_idx: int):
ex_iterable = self._ex_iterable
yield from ex_iterable.shard_data_sources(shard_idx)

def _iter_pytorch(self, worker_info):
# fix for fsspec when using multiprocessing
_reset_fsspec_lock()
if worker_info is None: # single-process data loading, return the full iterator
yield from IterableDataset.__iter__(self)
else: # in a worker process
# check if there aren't too many workers
if worker_info.id == 0 and self.n_shards < worker_info.num_workers:
logger.warning(
f"Too many dataloader workers: {worker_info.num_workers} (max is dataset.n_shards={self.n_shards}). "
f"Stopping dataloader workers [{self.n_shards}...{worker_info.num_workers -1}]."
)
logger.warning(
f"To parallelize data loading, we give each process some shards (or data sources) to process. "
f"Therefore it's unnecessary to have a number of workers greater than dataset.n_shards={self.n_shards}."
f"To enable more parallelism, please split the dataset in more files than {self.n_shards}."
)
# split workload
shards_indices = list(range(worker_info.id, self.n_shards, worker_info.num_workers))
if shards_indices:
logger.debug(
f"dataloader worker#{worker_info.id}, ': Starting to iterate over {len(shards_indices)}/{self.n_shards} shards."
)
for shard_idx in shards_indices:
for key, example in self._iter_shard(shard_idx):
if self.features:
yield _apply_feature_types_on_example(
example, self.features, token_per_repo_id=self._token_per_repo_id
)
else:
yield example
logger.debug(
f"dataloader worker#{worker_info.id}, ': Finished iterating over {len(shards_indices)}/{self.n_shards} shards."
)
else:
logger.debug(
f"dataloader worker#{worker_info.id}, ': Stopping... Number of dataset shards < num_workers ({self.n_shards}<{worker_info.num_workers})."
)

def __iter__(self):
if "torch" in sys.modules:
import torch.utils.data

worker_info = torch.utils.data.get_worker_info()
if isinstance(self, torch.utils.data.IterableDataset) and worker_info is not None:
# We're a torch.utils.data.IterableDataset in a PyTorch worker process
yield from self._iter_pytorch(worker_info)
return

for key, example in self._iter():
if self.features:
# `IterableDataset` automatically fills missing columns with None.
@@ -872,11 +945,12 @@ def with_format(
type (`str`, optional, default None): if set to "torch", the returned dataset
will be a subclass of torch.utils.data.IterableDataset to be used in a DataLoader
"""
type = get_format_type_from_alias(type)
# TODO(QL): add examples formatting to get tensors when using the "torch" format
# TODO(QL): add format_kwargs
# TODO(QL): add format_columns and return_all_columns
# TODO(QL): add pandas, numpy and tf formats
return iterable_dataset(
return IterableDataset(
ex_iterable=self._ex_iterable,
info=self._info.copy(),
split=self._split,
@@ -987,7 +1061,7 @@ def map(
)
info = self.info.copy()
info.features = features
return iterable_dataset(
return IterableDataset(
ex_iterable=ex_iterable,
info=info,
split=self._split,
@@ -1059,7 +1133,7 @@ def filter(
batched=batched,
batch_size=batch_size,
)
return iterable_dataset(
return IterableDataset(
ex_iterable=ex_iterable,
info=info,
split=self._split,
@@ -1123,7 +1197,7 @@ def shuffle(
else:
generator = deepcopy(generator)
shuffling = ShufflingConfig(generator=generator)
return iterable_dataset(
return IterableDataset(
ex_iterable=BufferShuffledExamplesIterable(
self._ex_iterable, buffer_size=buffer_size, generator=generator
).shuffle_data_sources(generator),
@@ -1166,7 +1240,7 @@ def skip(self, n) -> "IterableDataset":
```
"""
ex_iterable = SkipExamplesIterable(self._ex_iterable, n)
return iterable_dataset(
return IterableDataset(
ex_iterable=ex_iterable,
info=self._info.copy(),
split=self._split,
@@ -1197,7 +1271,7 @@ def take(self, n) -> "IterableDataset":
```
"""
ex_iterable = TakeExamplesIterable(self._ex_iterable, n)
return iterable_dataset(
return IterableDataset(
ex_iterable=ex_iterable,
info=self._info.copy(),
split=self._split,
@@ -1387,7 +1461,7 @@ def cast_column(self, column: str, feature: FeatureType) -> "IterableDataset":
info.copy()
except ValueError:
info.task_templates = None
return iterable_dataset(
return IterableDataset(
ex_iterable=self._ex_iterable,
info=info,
split=self._split,
@@ -1437,7 +1511,7 @@ def cast(
info.copy()
except ValueError:
info.task_templates = None
return iterable_dataset(
return IterableDataset(
ex_iterable=self._ex_iterable,
info=info,
split=self._split,
@@ -1455,7 +1529,7 @@ def _resolve_features(self):
features = _infer_features_from_batch(self._head())
info = self.info.copy()
info.features = features
return iterable_dataset(
return IterableDataset(
ex_iterable=self._ex_iterable,
info=info,
split=self._split,
@@ -1465,30 +1539,6 @@ def _resolve_features(self):
)


def iterable_dataset(
ex_iterable: Iterable,
info: Optional[DatasetInfo] = None,
split: Optional[NamedSplit] = None,
format_type: Optional[str] = None,
shuffling: Optional[ShufflingConfig] = None,
token_per_repo_id: Optional[Dict[str, Union[str, bool, None]]] = None,
):
if format_type is not None and format_type == "torch":
from .formatting.dataset_wrappers.torch_iterable_dataset import TorchIterableDataset

cls = TorchIterableDataset
else:
cls = IterableDataset
return cls(
ex_iterable=ex_iterable,
info=info,
split=split,
format_type=format_type,
shuffling=shuffling,
token_per_repo_id=token_per_repo_id,
)


def _concatenate_iterable_datasets(
dsets: List[IterableDataset],
info: Optional[DatasetInfo] = None,
@@ -1546,7 +1596,7 @@ def _concatenate_iterable_datasets(
# Get all the auth tokens per repository - in case the datasets come from different private repositories
token_per_repo_id = {repo_id: token for dataset in dsets for repo_id, token in dataset._token_per_repo_id.items()}
# Return new dataset
return iterable_dataset(ex_iterable=ex_iterable, info=info, split=split, token_per_repo_id=token_per_repo_id)
return IterableDataset(ex_iterable=ex_iterable, info=info, split=split, token_per_repo_id=token_per_repo_id)


def _interleave_iterable_datasets(
@@ -1614,4 +1664,4 @@ def _interleave_iterable_datasets(
repo_id: token for dataset in datasets for repo_id, token in dataset._token_per_repo_id.items()
}
# Return new dataset
return iterable_dataset(ex_iterable=ex_iterable, info=info, split=split, token_per_repo_id=token_per_repo_id)
return IterableDataset(ex_iterable=ex_iterable, info=info, split=split, token_per_repo_id=token_per_repo_id)
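
The core trick of this diff is mutating `cls.__bases__` at instantiation time and re-applying the mutation on unpickling. A standalone sketch of the pattern, with illustrative stand-in classes rather than the library's own:

```py
import pickle

class TorchIterableDatasetStandIn:
    """Illustrative stand-in for torch.utils.data.IterableDataset."""

class InfoMixinStandIn:
    """Illustrative stand-in for DatasetInfoMixin."""

def _maybe_add_parent(cls):
    # Same pattern as _maybe_add_torch_iterable_dataset_parent_class above
    if TorchIterableDatasetStandIn not in cls.__bases__:
        cls.__bases__ += (TorchIterableDatasetStandIn,)

class MyIterableDataset(InfoMixinStandIn):
    def __init__(self):
        _maybe_add_parent(type(self))

    def __getstate__(self):
        return self.__dict__

    def __setstate__(self, d):
        self.__dict__ = d
        # A class re-imported in a fresh process (e.g. a "spawn" dataloader
        # worker) does not keep dynamically added bases, so re-add them here
        _maybe_add_parent(type(self))

ds = MyIterableDataset()
print(isinstance(ds, TorchIterableDatasetStandIn))   # True
ds2 = pickle.loads(pickle.dumps(ds))                 # __setstate__ runs on load
print(isinstance(ds2, TorchIterableDatasetStandIn))  # True
```

Doing it this way keeps `datasets` free of a hard torch dependency: the base class is only added when `config.TORCH_AVAILABLE` is true.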

1 comment on commit 0bec9f3

@github-actions

Show benchmarks

PyArrow==6.0.0

Show updated benchmarks!

Benchmark: benchmark_array_xd.json

| metric | new / old (diff) |
|---|---|
| read_batch_formatted_as_numpy after write_array2d | 0.009345 / 0.011353 (-0.002008) |
| read_batch_formatted_as_numpy after write_flattened_sequence | 0.005639 / 0.011008 (-0.005369) |
| read_batch_formatted_as_numpy after write_nested_sequence | 0.099485 / 0.038508 (0.060977) |
| read_batch_unformated after write_array2d | 0.035694 / 0.023109 (0.012584) |
| read_batch_unformated after write_flattened_sequence | 0.314439 / 0.275898 (0.038541) |
| read_batch_unformated after write_nested_sequence | 0.385209 / 0.323480 (0.061729) |
| read_col_formatted_as_numpy after write_array2d | 0.008399 / 0.007986 (0.000413) |
| read_col_formatted_as_numpy after write_flattened_sequence | 0.005463 / 0.004328 (0.001134) |
| read_col_formatted_as_numpy after write_nested_sequence | 0.075367 / 0.004250 (0.071116) |
| read_col_unformated after write_array2d | 0.047028 / 0.037052 (0.009975) |
| read_col_unformated after write_flattened_sequence | 0.316610 / 0.258489 (0.058121) |
| read_col_unformated after write_nested_sequence | 0.350743 / 0.293841 (0.056903) |
| read_formatted_as_numpy after write_array2d | 0.038023 / 0.128546 (-0.090524) |
| read_formatted_as_numpy after write_flattened_sequence | 0.012326 / 0.075646 (-0.063321) |
| read_formatted_as_numpy after write_nested_sequence | 0.335001 / 0.419271 (-0.084270) |
| read_unformated after write_array2d | 0.048672 / 0.043533 (0.005139) |
| read_unformated after write_flattened_sequence | 0.300149 / 0.255139 (0.045010) |
| read_unformated after write_nested_sequence | 0.329159 / 0.283200 (0.045960) |
| write_array2d | 0.112407 / 0.141683 (-0.029276) |
| write_flattened_sequence | 1.489946 / 1.452155 (0.037792) |
| write_nested_sequence | 1.501827 / 1.492716 (0.009111) |

Benchmark: benchmark_getitem_100B.json

| metric | new / old (diff) |
|---|---|
| get_batch_of_1024_random_rows | 0.242771 / 0.018006 (0.224765) |
| get_batch_of_1024_rows | 0.590216 / 0.000490 (0.589726) |
| get_first_row | 0.003586 / 0.000200 (0.003387) |
| get_last_row | 0.000127 / 0.000054 (0.000073) |

Benchmark: benchmark_indices_mapping.json

| metric | new / old (diff) |
|---|---|
| select | 0.029688 / 0.037411 (-0.007724) |
| shard | 0.115821 / 0.014526 (0.101295) |
| shuffle | 0.124833 / 0.176557 (-0.051724) |
| sort | 0.169860 / 0.737135 (-0.567275) |
| train_test_split | 0.131514 / 0.296338 (-0.164824) |

Benchmark: benchmark_iterating.json

| metric | new / old (diff) |
|---|---|
| read 5000 | 0.397650 / 0.215209 (0.182441) |
| read 50000 | 3.995692 / 2.077655 (1.918037) |
| read_batch 50000 10 | 1.785917 / 1.504120 (0.281797) |
| read_batch 50000 100 | 1.581189 / 1.541195 (0.039994) |
| read_batch 50000 1000 | 1.679670 / 1.468490 (0.211180) |
| read_formatted numpy 5000 | 0.692615 / 4.584777 (-3.892162) |
| read_formatted pandas 5000 | 3.911817 / 3.745712 (0.166105) |
| read_formatted tensorflow 5000 | 2.189286 / 5.269862 (-3.080576) |
| read_formatted torch 5000 | 1.384511 / 4.565676 (-3.181165) |
| read_formatted_batch numpy 5000 10 | 0.084399 / 0.424275 (-0.339876) |
| read_formatted_batch numpy 5000 1000 | 0.011903 / 0.007607 (0.004296) |
| shuffled read 5000 | 0.517683 / 0.226044 (0.291638) |
| shuffled read 50000 | 5.213863 / 2.268929 (2.944935) |
| shuffled read_batch 50000 10 | 2.221472 / 55.444624 (-53.223152) |
| shuffled read_batch 50000 100 | 1.900715 / 6.876477 (-4.975762) |
| shuffled read_batch 50000 1000 | 2.076686 / 2.142072 (-0.065386) |
| shuffled read_formatted numpy 5000 | 0.850590 / 4.805227 (-3.954637) |
| shuffled read_formatted_batch numpy 5000 10 | 0.169062 / 6.500664 (-6.331603) |
| shuffled read_formatted_batch numpy 5000 1000 | 0.062509 / 0.075469 (-0.012960) |

Benchmark: benchmark_map_filter.json

| metric | new / old (diff) |
|---|---|
| filter | 1.187837 / 1.841788 (-0.653951) |
| map fast-tokenizer batched | 15.944581 / 8.074308 (7.870273) |
| map identity | 14.368825 / 10.191392 (4.177433) |
| map identity batched | 0.174276 / 0.680424 (-0.506148) |
| map no-op batched | 0.029813 / 0.534201 (-0.504388) |
| map no-op batched numpy | 0.448567 / 0.579283 (-0.130716) |
| map no-op batched pandas | 0.453929 / 0.434364 (0.019565) |
| map no-op batched pytorch | 0.525962 / 0.540337 (-0.014375) |
| map no-op batched tensorflow | 0.627476 / 1.386936 (-0.759460) |
PyArrow==latest
Show updated benchmarks!

Benchmark: benchmark_array_xd.json

| metric | new / old (diff) |
|---|---|
| read_batch_formatted_as_numpy after write_array2d | 0.008017 / 0.011353 (-0.003336) |
| read_batch_formatted_as_numpy after write_flattened_sequence | 0.005556 / 0.011008 (-0.005453) |
| read_batch_formatted_as_numpy after write_nested_sequence | 0.098451 / 0.038508 (0.059943) |
| read_batch_unformated after write_array2d | 0.035139 / 0.023109 (0.012030) |
| read_batch_unformated after write_flattened_sequence | 0.379870 / 0.275898 (0.103972) |
| read_batch_unformated after write_nested_sequence | 0.408219 / 0.323480 (0.084739) |
| read_col_formatted_as_numpy after write_array2d | 0.006392 / 0.007986 (-0.001594) |
| read_col_formatted_as_numpy after write_flattened_sequence | 0.004248 / 0.004328 (-0.000080) |
| read_col_formatted_as_numpy after write_nested_sequence | 0.073471 / 0.004250 (0.069221) |
| read_col_unformated after write_array2d | 0.055458 / 0.037052 (0.018406) |
| read_col_unformated after write_flattened_sequence | 0.369992 / 0.258489 (0.111503) |
| read_col_unformated after write_nested_sequence | 0.432652 / 0.293841 (0.138811) |
| read_formatted_as_numpy after write_array2d | 0.036966 / 0.128546 (-0.091580) |
| read_formatted_as_numpy after write_flattened_sequence | 0.012422 / 0.075646 (-0.063224) |
| read_formatted_as_numpy after write_nested_sequence | 0.330333 / 0.419271 (-0.088938) |
| read_unformated after write_array2d | 0.048834 / 0.043533 (0.005301) |
| read_unformated after write_flattened_sequence | 0.369884 / 0.255139 (0.114745) |
| read_unformated after write_nested_sequence | 0.381391 / 0.283200 (0.098192) |
| write_array2d | 0.112831 / 0.141683 (-0.028851) |
| write_flattened_sequence | 1.444639 / 1.452155 (-0.007516) |
| write_nested_sequence | 1.577630 / 1.492716 (0.084913) |

Benchmark: benchmark_getitem_100B.json

| metric | new / old (diff) |
|---|---|
| get_batch_of_1024_random_rows | 0.276265 / 0.018006 (0.258259) |
| get_batch_of_1024_rows | 0.574328 / 0.000490 (0.573838) |
| get_first_row | 0.001280 / 0.000200 (0.001080) |
| get_last_row | 0.000087 / 0.000054 (0.000033) |

Benchmark: benchmark_indices_mapping.json

| metric | new / old (diff) |
|---|---|
| select | 0.033119 / 0.037411 (-0.004292) |
| shard | 0.121340 / 0.014526 (0.106814) |
| shuffle | 0.133485 / 0.176557 (-0.043072) |
| sort | 0.172415 / 0.737135 (-0.564721) |
| train_test_split | 0.136676 / 0.296338 (-0.159662) |

Benchmark: benchmark_iterating.json

| metric | new / old (diff) |
|---|---|
| read 5000 | 0.435048 / 0.215209 (0.219839) |
| read 50000 | 4.335761 / 2.077655 (2.258107) |
| read_batch 50000 10 | 2.147626 / 1.504120 (0.643506) |
| read_batch 50000 100 | 1.958791 / 1.541195 (0.417596) |
| read_batch 50000 1000 | 2.041364 / 1.468490 (0.572874) |
| read_formatted numpy 5000 | 0.706186 / 4.584777 (-3.878591) |
| read_formatted pandas 5000 | 3.895539 / 3.745712 (0.149827) |
| read_formatted tensorflow 5000 | 2.170610 / 5.269862 (-3.099252) |
| read_formatted torch 5000 | 1.383237 / 4.565676 (-3.182440) |
| read_formatted_batch numpy 5000 10 | 0.086942 / 0.424275 (-0.337333) |
| read_formatted_batch numpy 5000 1000 | 0.012643 / 0.007607 (0.005036) |
| shuffled read 5000 | 0.545663 / 0.226044 (0.319619) |
| shuffled read 50000 | 5.423307 / 2.268929 (3.154378) |
| shuffled read_batch 50000 10 | 2.692782 / 55.444624 (-52.751843) |
| shuffled read_batch 50000 100 | 2.319479 / 6.876477 (-4.556998) |
| shuffled read_batch 50000 1000 | 2.498471 / 2.142072 (0.356398) |
| shuffled read_formatted numpy 5000 | 0.857185 / 4.805227 (-3.948042) |
| shuffled read_formatted_batch numpy 5000 10 | 0.173342 / 6.500664 (-6.327322) |
| shuffled read_formatted_batch numpy 5000 1000 | 0.066027 / 0.075469 (-0.009442) |

Benchmark: benchmark_map_filter.json

| metric | new / old (diff) |
|---|---|
| filter | 1.291048 / 1.841788 (-0.550740) |
| map fast-tokenizer batched | 16.129752 / 8.074308 (8.055444) |
| map identity | 13.646840 / 10.191392 (3.455448) |
| map identity batched | 0.170251 / 0.680424 (-0.510173) |
| map no-op batched | 0.017673 / 0.534201 (-0.516527) |
| map no-op batched numpy | 0.436676 / 0.579283 (-0.142607) |
| map no-op batched pandas | 0.439206 / 0.434364 (0.004842) |
| map no-op batched pytorch | 0.506850 / 0.540337 (-0.033487) |
| map no-op batched tensorflow | 0.616839 / 1.386936 (-0.770097) |
