Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support torch dataloader without torch formatting #5357

Merged
merged 7 commits into from
Dec 15, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions docs/source/use_with_pytorch.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ Like `torch.utils.data.Dataset` objects, a [`Dataset`] can be passed directly to
### Optimize data loading

There are several ways you can increase the speed your data is loaded which can save you time, especially if you are working with large datasets.
PyTorch offers parallelized data loading, retrieving batches of indices instead of individually, and streaming to progressively download datasets.
PyTorch offers parallelized data loading, retrieving batches of indices instead of individually, and streaming to iterate over the dataset without downloading it on disk.

#### Use multiple Workers

Expand Down Expand Up @@ -200,23 +200,23 @@ You must use a `BatchSampler` if you want the transform to be given full batches

### Stream data

Loading a dataset in streaming mode is useful to progressively download the data you need while iterating over the dataset.
Set the format of a streaming dataset to `torch`, and it inherits from `torch.utils.data.IterableDataset` so you can pass it to a `DataLoader`:
Loading a dataset in streaming mode allows one to iterate over the dataset without downloading it on disk.
An iterable dataset from `datasets` inherits from `torch.utils.data.IterableDataset` so you can pass it to a `DataLoader`:

```py
>>> import numpy as np
>>> from datasets import Dataset, load_dataset
>>> from torch.utils.data import DataLoader
>>> data = np.random.rand(10_000)
>>> Dataset.from_dict({"data": data}).push_to_hub("<username>/my_dataset") # Upload to the Hugging Face Hub
>>> ds = load_dataset("<username>/my_dataset", streaming=True, split="train").with_format("torch")
>>> ds = load_dataset("<username>/my_dataset", streaming=True, split="train")
>>> dataloader = DataLoader(ds, batch_size=32)
```

If the dataset is split in several shards (i.e. if the dataset consists of multiple data files), then you can stream in parallel using `num_workers`:

```py
>>> ds = load_dataset("c4", "en", streaming=True, split="train").with_format("torch")
>>> ds = load_dataset("c4", "en", streaming=True, split="train")
>>> ds.n_shards
1024
>>> dataloader = DataLoader(ds, batch_size=32, num_workers=4)
Expand Down
18 changes: 18 additions & 0 deletions src/datasets/filesystems/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import importlib
import threading
from typing import List

import fsspec
import fsspec.asyn

from . import compression
from .hffilesystem import HfFileSystem
Expand Down Expand Up @@ -48,3 +50,19 @@ def is_remote_filesystem(fs: fsspec.AbstractFileSystem) -> bool:
return True
else:
return False


def _reset_fsspec_lock() -> None:
"""
Clear reference to the loop and thread.
This is necessary otherwise HTTPFileSystem hangs in the ML training loop.
Only required for fsspec >= 0.9.0
See https://github.com/fsspec/gcsfs/issues/379
"""
if hasattr(fsspec.asyn, "reset_lock"):
# for future fsspec>2022.05.0
fsspec.asyn.reset_lock()
else:
fsspec.asyn.iothread[0] = None
fsspec.asyn.loop[0] = None
fsspec.asyn.lock = threading.Lock()
Empty file.
68 changes: 0 additions & 68 deletions src/datasets/formatting/dataset_wrappers/torch_iterable_dataset.py

This file was deleted.

124 changes: 87 additions & 37 deletions src/datasets/iterable_dataset.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,31 @@
import copy
import itertools
import sys
from collections import Counter
from copy import deepcopy
from dataclasses import dataclass
from itertools import cycle, islice
from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Union
from typing import Any, Callable, Dict, Iterator, List, Optional, Union

import numpy as np
import pyarrow as pa

from . import config
from .arrow_dataset import DatasetInfoMixin
from .features import Features
from .features.features import FeatureType, _align_features, _check_if_features_can_be_aligned
from .formatting import PythonFormatter
from .filesystems import _reset_fsspec_lock
from .formatting import PythonFormatter, get_format_type_from_alias
from .info import DatasetInfo
from .splits import NamedSplit
from .table import table_cast
from .utils.logging import get_logger
from .utils.sharding import _number_of_shards_in_gen_kwargs, _shuffle_gen_kwargs, _split_gen_kwargs


logger = get_logger(__name__)


def _infer_features_from_batch(batch: Dict[str, list], try_features: Optional[Features] = None) -> Features:
pa_table = pa.Table.from_pydict(batch)
if try_features is not None:
Expand Down Expand Up @@ -719,6 +726,15 @@ class ShufflingConfig:
generator: np.random.Generator


def _maybe_add_torch_iterable_dataset_parent_class(cls):
"""Add torch.utils.data.IterableDataset as a parent class if 'torch' is available"""
if config.TORCH_AVAILABLE:
import torch.utils.data

if torch.utils.data.IterableDataset not in cls.__bases__:
cls.__bases__ += (torch.utils.data.IterableDataset,)


class IterableDataset(DatasetInfoMixin):
"""A Dataset backed by an iterable."""

Expand All @@ -739,6 +755,15 @@ def __init__(
self._shuffling = shuffling
self._epoch = 0
self._token_per_repo_id: Dict[str, Union[str, bool, None]] = token_per_repo_id or {}
_maybe_add_torch_iterable_dataset_parent_class(self.__class__)

def __getstate__(self):
return self.__dict__

def __setstate__(self, d):
self.__dict__ = d
# Re-add torch iterable dataset as a parent class, since dynamically added parent classes are not kept when pickling
_maybe_add_torch_iterable_dataset_parent_class(self.__class__)

def _head(self, n=5):
return _examples_to_batch([x for key, x in islice(self._iter(), n)])
Expand Down Expand Up @@ -772,7 +797,55 @@ def _iter_shard(self, shard_idx: int):
ex_iterable = self._ex_iterable
yield from ex_iterable.shard_data_sources(shard_idx)

def _iter_pytorch(self, worker_info):
# fix for fsspec when using multprocess
_reset_fsspec_lock()
if worker_info is None: # single-process data loading, return the full iterator
yield from IterableDataset.__iter__(self)
else: # in a worker process
# check if there aren't too many workers
if worker_info.id == 0 and self.n_shards < worker_info.num_workers:
logger.warning(
f"Too many dataloader workers: {worker_info.num_workers} (max is dataset.n_shards={self.n_shards}). "
f"Stopping dataloader workers [{self.n_shards}...{worker_info.num_workers -1}]."
)
logger.warning(
f"To parallelize data loading, we give each process some shards (or data sources) to process. "
f"Therefore it's unnecessary to have a number of workers greater than dataset.n_shards={self.n_shards}."
f"To enable more parallelism, please split the dataset in more files than {self.n_shards}."
)
# split workload
shards_indices = list(range(worker_info.id, self.n_shards, worker_info.num_workers))
if shards_indices:
logger.debug(
f"dataloader worker#{worker_info.id}, ': Starting to iterate over {len(shards_indices)}/{self.n_shards} shards."
)
for shard_idx in shards_indices:
for key, example in self._iter_shard(shard_idx):
if self.features:
yield _apply_feature_types_on_example(
example, self.features, token_per_repo_id=self._token_per_repo_id
)
else:
yield example
logger.debug(
f"dataloader worker#{worker_info.id}, ': Finished iterating over {len(shards_indices)}/{self.n_shards} shards."
)
else:
logger.debug(
f"dataloader worker#{worker_info.id}, ': Stopping... Number of dataset shards < num_workers ({self.n_shards}<{worker_info.num_workers})."
)

def __iter__(self):
if "torch" in sys.modules:
import torch.utils.data

worker_info = torch.utils.data.get_worker_info()
if isinstance(self, torch.utils.data.IterableDataset) and worker_info is not None:
# We're a torch.utils.data.IterableDataset in a PyTorch worker process
yield from self._iter_pytorch(worker_info)
return

for key, example in self._iter():
if self.features:
# `IterableDataset` automatically fills missing columns with None.
Expand Down Expand Up @@ -872,11 +945,12 @@ def with_format(
type (`str`, optional, default None): if set to "torch", the returned dataset
will be a subclass of torch.utils.data.IterableDataset to be used in a DataLoader
"""
type = get_format_type_from_alias(type)
# TODO(QL): add examples formatting to get tensors when using the "torch" format
# TODO(QL): add format_kwargs
# TODO(QL): add format_columns and return_all_columns
# TODO(QL): add pandas, numpy and tf formats
return iterable_dataset(
return IterableDataset(
ex_iterable=self._ex_iterable,
info=self._info.copy(),
split=self._split,
Expand Down Expand Up @@ -987,7 +1061,7 @@ def map(
)
info = self.info.copy()
info.features = features
return iterable_dataset(
return IterableDataset(
ex_iterable=ex_iterable,
info=info,
split=self._split,
Expand Down Expand Up @@ -1059,7 +1133,7 @@ def filter(
batched=batched,
batch_size=batch_size,
)
return iterable_dataset(
return IterableDataset(
ex_iterable=ex_iterable,
info=info,
split=self._split,
Expand Down Expand Up @@ -1123,7 +1197,7 @@ def shuffle(
else:
generator = deepcopy(generator)
shuffling = ShufflingConfig(generator=generator)
return iterable_dataset(
return IterableDataset(
ex_iterable=BufferShuffledExamplesIterable(
self._ex_iterable, buffer_size=buffer_size, generator=generator
).shuffle_data_sources(generator),
Expand Down Expand Up @@ -1166,7 +1240,7 @@ def skip(self, n) -> "IterableDataset":
```
"""
ex_iterable = SkipExamplesIterable(self._ex_iterable, n)
return iterable_dataset(
return IterableDataset(
ex_iterable=ex_iterable,
info=self._info.copy(),
split=self._split,
Expand Down Expand Up @@ -1197,7 +1271,7 @@ def take(self, n) -> "IterableDataset":
```
"""
ex_iterable = TakeExamplesIterable(self._ex_iterable, n)
return iterable_dataset(
return IterableDataset(
ex_iterable=ex_iterable,
info=self._info.copy(),
split=self._split,
Expand Down Expand Up @@ -1387,7 +1461,7 @@ def cast_column(self, column: str, feature: FeatureType) -> "IterableDataset":
info.copy()
except ValueError:
info.task_templates = None
return iterable_dataset(
return IterableDataset(
ex_iterable=self._ex_iterable,
info=info,
split=self._split,
Expand Down Expand Up @@ -1437,7 +1511,7 @@ def cast(
info.copy()
except ValueError:
info.task_templates = None
return iterable_dataset(
return IterableDataset(
ex_iterable=self._ex_iterable,
info=info,
split=self._split,
Expand All @@ -1455,7 +1529,7 @@ def _resolve_features(self):
features = _infer_features_from_batch(self._head())
info = self.info.copy()
info.features = features
return iterable_dataset(
return IterableDataset(
ex_iterable=self._ex_iterable,
info=info,
split=self._split,
Expand All @@ -1465,30 +1539,6 @@ def _resolve_features(self):
)


def iterable_dataset(
ex_iterable: Iterable,
info: Optional[DatasetInfo] = None,
split: Optional[NamedSplit] = None,
format_type: Optional[str] = None,
shuffling: Optional[ShufflingConfig] = None,
token_per_repo_id: Optional[Dict[str, Union[str, bool, None]]] = None,
):
if format_type is not None and format_type == "torch":
from .formatting.dataset_wrappers.torch_iterable_dataset import TorchIterableDataset

cls = TorchIterableDataset
else:
cls = IterableDataset
return cls(
ex_iterable=ex_iterable,
info=info,
split=split,
format_type=format_type,
shuffling=shuffling,
token_per_repo_id=token_per_repo_id,
)


def _concatenate_iterable_datasets(
dsets: List[IterableDataset],
info: Optional[DatasetInfo] = None,
Expand Down Expand Up @@ -1546,7 +1596,7 @@ def _concatenate_iterable_datasets(
# Get all the auth tokens per repository - in case the datasets come from different private repositories
token_per_repo_id = {repo_id: token for dataset in dsets for repo_id, token in dataset._token_per_repo_id.items()}
# Return new daset
return iterable_dataset(ex_iterable=ex_iterable, info=info, split=split, token_per_repo_id=token_per_repo_id)
return IterableDataset(ex_iterable=ex_iterable, info=info, split=split, token_per_repo_id=token_per_repo_id)


def _interleave_iterable_datasets(
Expand Down Expand Up @@ -1614,4 +1664,4 @@ def _interleave_iterable_datasets(
repo_id: token for dataset in datasets for repo_id, token in dataset._token_per_repo_id.items()
}
# Return new daset
return iterable_dataset(ex_iterable=ex_iterable, info=info, split=split, token_per_repo_id=token_per_repo_id)
return IterableDataset(ex_iterable=ex_iterable, info=info, split=split, token_per_repo_id=token_per_repo_id)
Loading