Skip to content

Commit

Permalink
apply formatting after iter_arrow to speed up format -> map, filter f…
Browse files Browse the repository at this point in the history
…or iterable datasets (#7207)

* apply formatting after iter_arrow

* add support for formatting to map iteration

* formatted iterator for filter

* fix filtered formatting

* option to disable formatting for outputs of map

* remove format_outputs kwarg

* rename batched_examples_iterator -> inputs_iterator

* support arbitrary input formatting in filtered examples iterable iter arrow

* preserve formatting on filtered shuffle

* pass token_per_repo_id to python_feature_decoder in formatters

* implement FormattedExamplesIterator

* fix formatted examples iterable

* restore is_typed property

* pass formatting config to formatted examples iterable

* fix formatter init

* map examples iterable expects to receive a RebatchedArrowExamplesIterable instance

* only apply features if they exist

* fix shuffle and shard

* remove formatting from FilteredExamplesIterable

* run pre commit

* filtered iter_arrow always allowed if available

* filtered examples iterable needs formatting when iter_arrow enabled

* only iter arrow on filter if formatting is set

* add features property to support feature inference

* fix features property

* don't re-encode features

* avoid re-encoding outputs of map

* map should not preserve formatting

* update comment

* update map features property

* return bool for mapped ex iterable is typed

* pass return features to mapped examples iterable constructor

* don't iter arrow with formatted filter to avoid re formatting

* avoid re-formatting data

* rename return features -> features

* update refs to return_features

* decode features in batched map

* preserve formatting in with_format

* fix features (mapped ex iterable)

* remove formatted examples iterable from with_format

* avoid reapplying features when chaining filter, map

* preserve formatting in map

* fix tests

* style

* fix tests

---------

Co-authored-by: Quentin Lhoest <42851186+lhoestq@users.noreply.github.com>
Co-authored-by: Quentin Lhoest <lhoest.q@gmail.com>
  • Loading branch information
3 people authored Jan 14, 2025
1 parent 7a1a84b commit 75e61d1
Show file tree
Hide file tree
Showing 7 changed files with 239 additions and 74 deletions.
24 changes: 16 additions & 8 deletions src/datasets/formatting/formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,11 +215,14 @@ def extract_batch(self, pa_table: pa.Table) -> pd.DataFrame:


class PythonFeaturesDecoder:
def __init__(self, features: Optional[Features]):
def __init__(
self, features: Optional[Features], token_per_repo_id: Optional[Dict[str, Union[str, bool, None]]] = None
):
self.features = features
self.token_per_repo_id = token_per_repo_id

def decode_row(self, row: dict) -> dict:
return self.features.decode_example(row) if self.features else row
return self.features.decode_example(row, token_per_repo_id=self.token_per_repo_id) if self.features else row

def decode_column(self, column: list, column_name: str) -> list:
return self.features.decode_column(column, column_name) if self.features else column
Expand Down Expand Up @@ -393,9 +396,14 @@ class Formatter(Generic[RowFormat, ColumnFormat, BatchFormat]):
numpy_arrow_extractor = NumpyArrowExtractor
pandas_arrow_extractor = PandasArrowExtractor

def __init__(self, features: Optional[Features] = None):
def __init__(
self,
features: Optional[Features] = None,
token_per_repo_id: Optional[Dict[str, Union[str, bool, None]]] = None,
):
self.features = features
self.python_features_decoder = PythonFeaturesDecoder(self.features)
self.token_per_repo_id = token_per_repo_id
self.python_features_decoder = PythonFeaturesDecoder(self.features, self.token_per_repo_id)
self.pandas_features_decoder = PandasFeaturesDecoder(self.features)

def __call__(self, pa_table: pa.Table, query_type: str) -> Union[RowFormat, ColumnFormat, BatchFormat]:
Expand Down Expand Up @@ -433,8 +441,8 @@ def format_batch(self, pa_table: pa.Table) -> pa.Table:


class PythonFormatter(Formatter[Mapping, list, Mapping]):
def __init__(self, features=None, lazy=False):
super().__init__(features)
def __init__(self, features=None, lazy=False, token_per_repo_id=None):
super().__init__(features, token_per_repo_id)
self.lazy = lazy

def format_row(self, pa_table: pa.Table) -> Mapping:
Expand Down Expand Up @@ -484,8 +492,8 @@ class CustomFormatter(Formatter[dict, ColumnFormat, dict]):
to return.
"""

def __init__(self, transform: Callable[[dict], dict], features=None, **kwargs):
super().__init__(features=features)
def __init__(self, transform: Callable[[dict], dict], features=None, token_per_repo_id=None, **kwargs):
super().__init__(features=features, token_per_repo_id=token_per_repo_id)
self.transform = transform

def format_row(self, pa_table: pa.Table) -> dict:
Expand Down
4 changes: 2 additions & 2 deletions src/datasets/formatting/jax_formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@


class JaxFormatter(TensorFormatter[Mapping, "jax.Array", Mapping]):
def __init__(self, features=None, device=None, **jnp_array_kwargs):
super().__init__(features=features)
def __init__(self, features=None, device=None, token_per_repo_id=None, **jnp_array_kwargs):
super().__init__(features=features, token_per_repo_id=token_per_repo_id)
import jax
from jaxlib.xla_client import Device

Expand Down
4 changes: 2 additions & 2 deletions src/datasets/formatting/np_formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@


class NumpyFormatter(TensorFormatter[Mapping, np.ndarray, Mapping]):
def __init__(self, features=None, **np_array_kwargs):
super().__init__(features=features)
def __init__(self, features=None, token_per_repo_id=None, **np_array_kwargs):
super().__init__(features=features, token_per_repo_id=token_per_repo_id)
self.np_array_kwargs = np_array_kwargs

def _consolidate(self, column):
Expand Down
4 changes: 2 additions & 2 deletions src/datasets/formatting/tf_formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@


class TFFormatter(TensorFormatter[Mapping, "tf.Tensor", Mapping]):
def __init__(self, features=None, **tf_tensor_kwargs):
super().__init__(features=features)
def __init__(self, features=None, token_per_repo_id=None, **tf_tensor_kwargs):
super().__init__(features=features, token_per_repo_id=token_per_repo_id)
self.tf_tensor_kwargs = tf_tensor_kwargs
import tensorflow as tf # noqa: F401 - import tf at initialization

Expand Down
4 changes: 2 additions & 2 deletions src/datasets/formatting/torch_formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@


class TorchFormatter(TensorFormatter[Mapping, "torch.Tensor", Mapping]):
def __init__(self, features=None, **torch_tensor_kwargs):
super().__init__(features=features)
def __init__(self, features=None, token_per_repo_id=None, **torch_tensor_kwargs):
super().__init__(features=features, token_per_repo_id=token_per_repo_id)
self.torch_tensor_kwargs = torch_tensor_kwargs
import torch # noqa import torch at initialization

Expand Down
Loading

0 comments on commit 75e61d1

Please sign in to comment.