Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

IterableDataset formatting in numpy/torch/tf/jax #5084

Closed
wants to merge 27 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/datasets/features/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -1177,6 +1177,8 @@ def encode_nested_example(schema, obj, level=0):
if isinstance(schema, dict):
if level == 0 and obj is None:
raise ValueError("Got None but expected a dictionary instead")
if obj is not None and len(schema) != len(obj):
obj = {k: obj.get(k) for k in schema}
return (
{
k: encode_nested_example(sub_schema, sub_obj, level=level + 1)
Expand Down
24 changes: 17 additions & 7 deletions src/datasets/formatting/dataset_wrappers/torch_iterable_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
import fsspec.asyn
import torch

from ...iterable_dataset import IterableDataset, _apply_feature_types
from ...formatting import get_formatter
from ...iterable_dataset import IterableDataset
from ...utils.logging import get_logger


Expand Down Expand Up @@ -51,14 +52,23 @@ def __iter__(self):
logger.debug(
f"dataloader worker#{worker_info.id}, ': Starting to iterate over {len(shards_indices)}/{self.n_shards} shards."
)
format_kwargs = self._format_kwargs if self._format_kwargs is not None else {}
features = self._resolve_features().features
formatter = get_formatter(self._format_type, features=features, **format_kwargs)
for shard_idx in shards_indices:
for key, example in self._iter_shard(shard_idx):
if self.features:
yield _apply_feature_types(
example, self.features, token_per_repo_id=self._token_per_repo_id
)
else:
yield example
example = features.encode_example(example)
formatted_example = formatter.format_example(
{
column_name: example[column_name]
for column_name in (example if self._format_columns is None else self._format_columns)
}
)
if self._output_all_columns:
for column_name in example:
if column_name not in formatted_example:
formatted_example[column_name] = example[column_name]
yield formatted_example
logger.debug(
f"dataloader worker#{worker_info.id}, ': Finished iterating over {len(shards_indices)}/{self.n_shards} shards."
)
Expand Down
17 changes: 15 additions & 2 deletions src/datasets/formatting/formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

from ..features import Features
from ..features.features import _ArrayXDExtensionType, _is_zero_copy_only, decode_nested_example, pandas_types_mapper
from ..table import Table
from ..table import InMemoryTable, Table
from ..utils.py_utils import no_op_if_value_is_null


Expand Down Expand Up @@ -271,7 +271,7 @@ class Formatter(Generic[RowFormat, ColumnFormat, BatchFormat]):
numpy_arrow_extractor = NumpyArrowExtractor
pandas_arrow_extractor = PandasArrowExtractor

def __init__(self, features=None, decoded=True):
def __init__(self, features: Optional[Features] = None, decoded=True):
self.features = features
self.decoded = decoded
self.python_features_decoder = PythonFeaturesDecoder(self.features)
Expand All @@ -294,6 +294,11 @@ def format_column(self, pa_table: pa.Table) -> ColumnFormat:
# Batch-formatting hook: receives a pyarrow table and is annotated to return `BatchFormat`.
# This implementation performs no formatting; it unconditionally raises.
def format_batch(self, pa_table: pa.Table) -> BatchFormat:
# NOTE(review): presumably meant to be overridden by concrete formatters - confirm against the subclasses.
raise NotImplementedError

def format_example(self, encoded_example: dict) -> dict:
    """Format a single encoded example by passing it through a one-row table.

    Args:
        encoded_example: the example to format, given as a dict.

    Returns:
        The result of calling `self.format_row` on a table built from
        a list containing only `encoded_example`.
    """
    # TODO(QL): optimize this to convert directly to the right format without using Arrow
    single_row_table = InMemoryTable.from_pylist([encoded_example])
    formatted_row = self.format_row(single_row_table)
    return formatted_row


class ArrowFormatter(Formatter[pa.Table, pa.Array, pa.Table]):
def format_row(self, pa_table: pa.Table) -> pa.Table:
Expand Down Expand Up @@ -325,6 +330,14 @@ def format_batch(self, pa_table: pa.Table) -> dict:
batch = self.python_features_decoder.decode_batch(batch)
return batch

def format_example(self, encoded_example: dict) -> dict:
    """Format a single encoded example without building a table when features are known.

    Args:
        encoded_example: the example to format, given as a dict.

    Returns:
        - the parent class's `format_example` result when `self.features` is None,
        - `encoded_example` itself, untouched, when `self.decoded` is falsy,
        - otherwise the result of `self.features.decode_example(encoded_example)`.
    """
    # Without features there is nothing to decode with: defer to the parent implementation.
    if self.features is None:
        return super().format_example(encoded_example)
    # Decoding disabled: hand the encoded example back as-is.
    if not self.decoded:
        return encoded_example
    return self.features.decode_example(encoded_example)


class PandasFormatter(Formatter):
def format_row(self, pa_table: pa.Table) -> pd.DataFrame:
Expand Down
116 changes: 70 additions & 46 deletions src/datasets/iterable_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from .arrow_dataset import DatasetInfoMixin
from .features import Features
from .features.features import FeatureType, _align_features, _check_if_features_can_be_aligned
from .formatting import PythonFormatter
from .formatting import PythonFormatter, get_format_type_from_alias, get_formatter
from .info import DatasetInfo
from .splits import NamedSplit
from .table import table_cast
Expand Down Expand Up @@ -207,7 +207,6 @@ class VerticallyConcatenatedMultiSourcesExamplesIterable(_BaseExamplesIterable):
We use `IterableDataset._resolve_features` to obtain the features of all the datasets to concatenate.

Then for each example, `IterableDataset` and `TypedExamplesIterable` automatically fill missing columns with None.
This is done with `_apply_feature_types`.
"""

def __init__(self, ex_iterables: List[_BaseExamplesIterable]):
Expand Down Expand Up @@ -259,7 +258,6 @@ class HorizontallyConcatenatedMultiSourcesExamplesIterable(_BaseExamplesIterable
We use `IterableDataset._resolve_features` to obtain the features of all the datasets to concatenate.

Then for each example, `IterableDataset` and `TypedExamplesIterable` automatically fill missing columns with None.
This is done with `_apply_feature_types`.
"""

def __init__(self, ex_iterables: List[_BaseExamplesIterable]):
Expand Down Expand Up @@ -632,21 +630,6 @@ def n_shards(self) -> int:
return self.ex_iterable.n_shards


def _apply_feature_types(
    example: dict, features: Features, token_per_repo_id: Dict[str, Union[str, bool, None]]
) -> dict:
    """Fill in missing columns, then encode and decode an example with `features`.

    Args:
        example: the raw example; it is copied, not modified in place.
        features: iterated to obtain the expected column names, and used for
            `encode_example` / `decode_example`.
        token_per_repo_id: forwarded as-is to `features.decode_example`.

    Returns:
        The value returned by `features.decode_example`.
    """
    # Work on a shallow copy so the caller's dict is left untouched.
    completed = dict(example)
    # Columns declared in the features but absent from the example default to None.
    for column_name in features:
        completed.setdefault(column_name, None)
    # Encoding is needed for ClassLabel feature types, for example.
    encoded = features.encode_example(completed)
    # Decoding is needed for the Audio feature, e.g.
    return features.decode_example(encoded, token_per_repo_id=token_per_repo_id)


class TypedExamplesIterable(_BaseExamplesIterable):
def __init__(
self,
Expand All @@ -660,9 +643,13 @@ def __init__(

def __iter__(self):
# Then for each example, `TypedExamplesIterable` automatically fills missing columns with None.
# This is done with `_apply_feature_types`.
for key, example in self.ex_iterable:
yield key, _apply_feature_types(example, self.features, token_per_repo_id=self.token_per_repo_id)
# we encode the example for ClassLabel feature types for example
# this also adds the missing columns
example = self.features.encode_example(example)
# Decode example for Audio feature, e.g.
example = self.features.decode_example(example, token_per_repo_id=self.token_per_repo_id)
yield key, example

def shuffle_data_sources(self, generator: np.random.Generator) -> "TypedExamplesIterable":
"""Shuffle the wrapped examples iterable."""
Expand Down Expand Up @@ -709,19 +696,33 @@ def __init__(
ex_iterable: _BaseExamplesIterable,
info: Optional[DatasetInfo] = None,
split: Optional[NamedSplit] = None,
format_type: Optional[str] = None,
format: Optional[dict] = None,
shuffling: Optional[ShufflingConfig] = None,
token_per_repo_id: Optional[Dict[str, Union[str, bool, None]]] = None,
):
info = info.copy() if info is not None else DatasetInfo()
DatasetInfoMixin.__init__(self, info=info, split=split)

self._ex_iterable = ex_iterable
self._format_type = format_type
self._shuffling = shuffling
self._epoch = 0
self._token_per_repo_id: Dict[str, Union[str, bool, None]] = token_per_repo_id or {}

format = format or {}
self._format_type: Optional[str] = format.get("type")
self._format_kwargs: dict = format.get("format_kwargs", {})
self._format_columns: Optional[list] = format.get("columns")
self._output_all_columns: bool = format.get("output_all_columns", False)

@property
def format(self):
    """Return the current formatting configuration as a dict.

    The dict has the keys "type", "format_kwargs", "columns" and
    "output_all_columns", filled from the corresponding private attributes.
    """
    option_names = ("type", "format_kwargs", "columns", "output_all_columns")
    option_values = (
        self._format_type,
        self._format_kwargs,
        self._format_columns,
        self._output_all_columns,
    )
    return dict(zip(option_names, option_values))

def _head(self, n=5):
    """Return the first `n` examples (at most) from `self._iter()` as a batch.

    The keys produced by `self._iter()` are dropped; the examples are passed
    to `_examples_to_batch`.
    """
    first_examples = [example for _, example in islice(self._iter(), n)]
    return _examples_to_batch(first_examples)

Expand Down Expand Up @@ -755,13 +756,22 @@ def _iter_shard(self, shard_idx: int):
yield from ex_iterable.shard_data_sources(shard_idx)

def __iter__(self):
format_kwargs = self._format_kwargs if self._format_kwargs is not None else {}
features = self._resolve_features().features
formatter = get_formatter(self._format_type, features=features, **format_kwargs)
for key, example in self._iter():
if self.features:
# `IterableDataset` automatically fills missing columns with None.
# This is done with `_apply_feature_types`.
yield _apply_feature_types(example, self.features, token_per_repo_id=self._token_per_repo_id)
else:
yield example
example = features.encode_example(example)
formatted_example = formatter.format_example(
{
column_name: example[column_name]
for column_name in (example if self._format_columns is None else self._format_columns)
}
)
if self._output_all_columns:
for column_name in example:
if column_name not in formatted_example:
formatted_example[column_name] = example[column_name]
yield formatted_example

@staticmethod
def from_generator(
Expand Down Expand Up @@ -820,25 +830,39 @@ def from_generator(
def with_format(
self,
type: Optional[str] = None,
columns: Optional[List] = None,
output_all_columns: bool = False,
**format_kwargs,
) -> "IterableDataset":
"""
Return a dataset with the specified format.
This method only supports the "torch" format for now.

Args:

type (`str`, optional, default None): if set to "torch", the returned dataset
will be a subclass of torch.utils.data.IterableDataset to be used in a DataLoader
type (`str`, *optional*, default `None`):
Either output type selected in [None, 'numpy', 'torch', 'tensorflow', 'pandas', 'arrow'].
If set to "torch", the returned dataset will be a subclass of torch.utils.data.IterableDataset to be used in a DataLoader.
columns (`List[str]`, *optional*, default `None`):
Columns to format in the output. `None` means that all columns are formatted (default).
output_all_columns (`bool`, default `False`):
Keep un-formatted columns as well in the output (as python objects).
**format_kwargs (additional keyword arguments): keyword arguments passed to the conversion function, such as `np.array`, `torch.tensor` or `tensorflow.ragged.constant`.

"""
# TODO(QL): add examples formatting to get tensors when using the "torch" format
# TODO(QL): add format_kwargs
# TODO(QL): add format_columns and return_all_columns
# TODO(QL): add pandas, numpy and tf formats
# Check that the format_type and format_kwargs are valid and make it possible to have a Formatter
type = get_format_type_from_alias(type)
_ = get_formatter(type, features=self.features, **format_kwargs)

return iterable_dataset(
ex_iterable=self._ex_iterable,
info=self._info.copy(),
split=self._split,
format_type=type,
format={
"type": type,
"columns": columns,
"output_all_columns": output_all_columns,
"format_kwargs": format_kwargs,
},
shuffling=copy.deepcopy(self._shuffling),
token_per_repo_id=self._token_per_repo_id,
)
Expand Down Expand Up @@ -949,7 +973,7 @@ def map(
ex_iterable=ex_iterable,
info=info,
split=self._split,
format_type=self._format_type,
format=self.format,
shuffling=copy.deepcopy(self._shuffling),
token_per_repo_id=self._token_per_repo_id,
)
Expand Down Expand Up @@ -1021,7 +1045,7 @@ def filter(
ex_iterable=ex_iterable,
info=info,
split=self._split,
format_type=self._format_type,
format=self.format,
shuffling=copy.deepcopy(self._shuffling),
token_per_repo_id=self._token_per_repo_id,
)
Expand Down Expand Up @@ -1087,7 +1111,7 @@ def shuffle(
).shuffle_data_sources(generator),
info=self._info.copy(),
split=self._split,
format_type=self._format_type,
format=self.format,
shuffling=shuffling,
token_per_repo_id=self._token_per_repo_id,
)
Expand Down Expand Up @@ -1128,7 +1152,7 @@ def skip(self, n) -> "IterableDataset":
ex_iterable=ex_iterable,
info=self._info.copy(),
split=self._split,
format_type=self._format_type,
format=self.format,
shuffling=copy.deepcopy(self._shuffling),
token_per_repo_id=self._token_per_repo_id,
)
Expand Down Expand Up @@ -1159,7 +1183,7 @@ def take(self, n) -> "IterableDataset":
ex_iterable=ex_iterable,
info=self._info.copy(),
split=self._split,
format_type=self._format_type,
format=self.format,
shuffling=copy.deepcopy(self._shuffling),
token_per_repo_id=self._token_per_repo_id,
)
Expand Down Expand Up @@ -1349,7 +1373,7 @@ def cast_column(self, column: str, feature: FeatureType) -> "IterableDataset":
ex_iterable=self._ex_iterable,
info=info,
split=self._split,
format_type=self._format_type,
format=self.format,
shuffling=copy.deepcopy(self._shuffling),
token_per_repo_id=self._token_per_repo_id,
)
Expand Down Expand Up @@ -1399,7 +1423,7 @@ def cast(
ex_iterable=self._ex_iterable,
info=info,
split=self._split,
format_type=self._format_type,
format=self.format,
shuffling=copy.deepcopy(self._shuffling),
token_per_repo_id=self._token_per_repo_id,
)
Expand All @@ -1417,7 +1441,7 @@ def _resolve_features(self):
ex_iterable=self._ex_iterable,
info=info,
split=self._split,
format_type=self._format_type,
format=self.format,
shuffling=copy.deepcopy(self._shuffling),
token_per_repo_id=self._token_per_repo_id,
)
Expand All @@ -1427,11 +1451,11 @@ def iterable_dataset(
ex_iterable: Iterable,
info: Optional[DatasetInfo] = None,
split: Optional[NamedSplit] = None,
format_type: Optional[str] = None,
format: Optional[dict] = None,
shuffling: Optional[ShufflingConfig] = None,
token_per_repo_id: Optional[Dict[str, Union[str, bool, None]]] = None,
):
if format_type is not None and format_type == "torch":
if format and format.get("type") == "torch":
from .formatting.dataset_wrappers.torch_iterable_dataset import TorchIterableDataset

cls = TorchIterableDataset
Expand All @@ -1441,7 +1465,7 @@ def iterable_dataset(
ex_iterable=ex_iterable,
info=info,
split=split,
format_type=format_type,
format=format,
shuffling=shuffling,
token_per_repo_id=token_per_repo_id,
)
Expand Down
10 changes: 5 additions & 5 deletions tests/test_arrow_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2412,11 +2412,11 @@ def test_format_ragged_vectors(self, in_memory):
self.assertIsInstance(dset[:2]["filename"], np.ndarray)
self.assertIsInstance(dset["filename"], np.ndarray)
self.assertIsInstance(dset[0]["vec"], np.ndarray)
self.assertIsInstance(dset[:2]["vec"], np.ndarray)
self.assertIsInstance(dset["vec"], np.ndarray)
# array is flat for ragged vectors in numpy
self.assertTupleEqual(dset[:2]["vec"].shape, (2,))
self.assertTupleEqual(dset["vec"][:2].shape, (2,))
# numpy doesn't support ragged tensors, so we should have lists
self.assertIsInstance(dset[:2]["vec"], list)
self.assertIsInstance(dset[:2]["vec"][0], np.ndarray)
self.assertIsInstance(dset["vec"], list)
self.assertIsInstance(dset["vec"][0], np.ndarray)

dset.set_format("torch")
self.assertIsNotNone(dset[0])
Expand Down