Make tfds.data_source picklable.
PiperOrigin-RevId: 636824581
marcenacp authored and The TensorFlow Datasets Authors committed May 28, 2024
1 parent 6bbba45 commit 1e4bf10
Showing 7 changed files with 110 additions and 44 deletions.
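In practical terms, the change lets a data source created with `tfds.data_source` (or `builder.as_data_source`) survive a pickle round trip, which PyGrain needs when it ships data sources to worker processes. A minimal usage sketch, assuming a dataset that is already prepared locally in a random-access file format (the dataset name below is only an example):

import pickle
import tensorflow_datasets as tfds

# Any prepared dataset stored in a random-access format (e.g. ArrayRecord) works here.
ds = tfds.data_source('mnist', split='train')
restored = pickle.loads(pickle.dumps(ds))  # Must round-trip for PyGrain workers.
assert len(restored) == len(ds)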
18 changes: 2 additions & 16 deletions tensorflow_datasets/core/data_sources/array_record.py
@@ -20,13 +20,8 @@
"""

import dataclasses
from typing import Any, Optional

from tensorflow_datasets.core import dataset_info as dataset_info_lib
from tensorflow_datasets.core import decode
from tensorflow_datasets.core import splits as splits_lib
from tensorflow_datasets.core.data_sources import base
from tensorflow_datasets.core.utils import type_utils
from tensorflow_datasets.core.utils.lazy_imports_utils import array_record_data_source


@@ -42,18 +37,9 @@ class ArrayRecordDataSource(base.BaseDataSource):
source.
"""

dataset_info: dataset_info_lib.DatasetInfo
split: splits_lib.Split = None
decoders: Optional[type_utils.TreeDict[decode.partial_decode.DecoderArg]] = (
None
)
# In order to lazy load array_record, we don't load
# `array_record_data_source.ArrayRecordDataSource` here.
data_source: Any = dataclasses.field(init=False)
length: int = dataclasses.field(init=False)

def __post_init__(self):
file_instructions = base.file_instructions(self.dataset_info, self.split)
dataset_info = self.dataset_builder.info
file_instructions = base.file_instructions(dataset_info, self.split)
self.data_source = array_record_data_source.ArrayRecordDataSource(
file_instructions
)
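The `array_record_data_source` import above goes through `lazy_imports_utils`, so the heavy ArrayRecord dependency is only loaded when a data source is actually built. A generic sketch of the lazy-import idea, not the TFDS helper itself:

import importlib
import types


def lazy_import(module_name: str) -> types.ModuleType:
  """Returns a proxy module that imports `module_name` on first attribute access."""

  class _LazyModule(types.ModuleType):

    def __getattr__(self, item):
      # The real import happens here, the first time an attribute is needed.
      module = importlib.import_module(module_name)
      return getattr(module, item)

  return _LazyModule(module_name)


json = lazy_import('json')        # Nothing imported yet.
print(json.dumps({'id': 0}))      # The real `json` module loads on this call.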
32 changes: 24 additions & 8 deletions tensorflow_datasets/core/data_sources/base.py
@@ -17,12 +17,14 @@

from collections.abc import MappingView, Sequence
import dataclasses
import functools
import typing
from typing import Any, Generic, Iterable, Protocol, SupportsIndex, TypeVar

from tensorflow_datasets.core import dataset_info as dataset_info_lib
from tensorflow_datasets.core import decode
from tensorflow_datasets.core import splits as splits_lib
from tensorflow_datasets.core.features import top_level_feature
from tensorflow_datasets.core.utils import shard_utils
from tensorflow_datasets.core.utils import type_utils
from tensorflow_datasets.core.utils.lazy_imports_utils import tree
@@ -54,6 +56,14 @@ def file_instructions(
return split_dict[split].file_instructions


class _DatasetBuilder(Protocol):
"""Protocol for the DatasetBuilder to avoid cyclic imports."""

@property
def info(self) -> dataset_info_lib.DatasetInfo:
...


@dataclasses.dataclass
class BaseDataSource(MappingView, Sequence):
"""Base DataSource to override all dunder methods with the deserialization.
@@ -64,22 +74,28 @@ class BaseDataSource(MappingView, Sequence):
deserialization/decoding.
Attributes:
dataset_info: The DatasetInfo of the
dataset_builder: The dataset builder.
split: The split to load in the data source.
decoders: Optional decoders for decoding.
data_source: The underlying data source to initialize in the __post_init__.
"""

dataset_info: dataset_info_lib.DatasetInfo
dataset_builder: _DatasetBuilder
split: splits_lib.Split | None = None
decoders: type_utils.TreeDict[decode.partial_decode.DecoderArg] | None = None
data_source: DataSource[Any] = dataclasses.field(init=False)

@functools.cached_property
def _features(self) -> top_level_feature.TopLevelFeature:
"""Caches features because we log the use of dataset_builder.info."""
features = self.dataset_builder.info.features
if not features:
raise ValueError('No feature defined in the dataset builder.')
return features

def __getitem__(self, key: SupportsIndex) -> Any:
record = self.data_source[key.__index__()]
return self.dataset_info.features.deserialize_example_np(
record, decoders=self.decoders
)
return self._features.deserialize_example_np(record, decoders=self.decoders)

def __getitems__(self, keys: Sequence[int]) -> Sequence[Any]:
"""Retrieves items by batch.
@@ -98,24 +114,24 @@ def __getitems__(self, keys: Sequence[int]) -> Sequence[Any]:
if not keys:
return []
records = self.data_source.__getitems__(keys)
features = self.dataset_info.features
if len(keys) != len(records):
raise IndexError(
f'Requested {len(keys)} records but got'
f' {len(records)} records.'
f'{keys=}, {records=}'
)
return [
features.deserialize_example_np(record, decoders=self.decoders)
self._features.deserialize_example_np(record, decoders=self.decoders)
for record in records
]

def __repr__(self) -> str:
decoders_repr = (
tree.map_structure(type, self.decoders) if self.decoders else None
)
name = self.dataset_builder.info.name
return (
f'{self.__class__.__name__}(name={self.dataset_info.name}, '
f'{self.__class__.__name__}(name={name}, '
f'split={self.split!r}, '
f'decoders={decoders_repr})'
)
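The `_DatasetBuilder` protocol above is plain structural typing: `BaseDataSource` never imports the concrete `DatasetBuilder` class, it only requires an object that exposes an `info` property, which is what breaks the import cycle. A standalone sketch of the pattern, with illustrative names that are not part of TFDS:

from dataclasses import dataclass
from typing import Protocol


@dataclass
class FakeInfo:
  name: str


class HasInfo(Protocol):
  """Structural type: any object exposing an `info` property matches."""

  @property
  def info(self) -> FakeInfo:
    ...


def describe(builder: HasInfo) -> str:
  # No concrete builder class is imported here, mirroring how the protocol
  # above avoids a cyclic import with the real DatasetBuilder.
  return builder.info.name


@dataclass
class FakeBuilder:  # Never declares that it implements HasInfo.
  info: FakeInfo


print(describe(FakeBuilder(info=FakeInfo(name='dummy'))))  # -> dummy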
49 changes: 39 additions & 10 deletions tensorflow_datasets/core/data_sources/base_test.py
@@ -15,13 +15,15 @@

"""Tests for all data sources."""

import pickle
from unittest import mock

import cloudpickle
from etils import epath
import pytest
import tensorflow_datasets as tfds
from tensorflow_datasets import testing
from tensorflow_datasets.core import dataset_builder
from tensorflow_datasets.core import dataset_builder as dataset_builder_lib
from tensorflow_datasets.core import dataset_info as dataset_info_lib
from tensorflow_datasets.core import decode
from tensorflow_datasets.core import file_adapters
@@ -77,7 +79,7 @@ def mocked_parquet_dataset():
)
def test_read_write(
tmp_path: epath.Path,
builder_cls: dataset_builder.DatasetBuilder,
builder_cls: dataset_builder_lib.DatasetBuilder,
file_format: file_adapters.FileFormat,
):
builder = builder_cls(data_dir=tmp_path, file_format=file_format)
@@ -106,28 +108,36 @@ def test_read_write(
]


def create_dataset_info(file_format: file_adapters.FileFormat):
def create_dataset_builder(
file_format: file_adapters.FileFormat,
) -> dataset_builder_lib.DatasetBuilder:
with mock.patch.object(splits_lib, 'SplitInfo') as split_mock:
split_mock.return_value.name = 'train'
split_mock.return_value.file_instructions = _FILE_INSTRUCTIONS
dataset_info = mock.create_autospec(dataset_info_lib.DatasetInfo)
dataset_info.file_format = file_format
dataset_info.splits = {'train': split_mock()}
dataset_info.name = 'dataset_name'
return dataset_info

dataset_builder = mock.create_autospec(dataset_builder_lib.DatasetBuilder)
dataset_builder.info = dataset_info

return dataset_builder


@pytest.mark.parametrize(
'data_source_cls',
_DATA_SOURCE_CLS,
)
def test_missing_split_raises_error(data_source_cls):
dataset_info = create_dataset_info(file_adapters.FileFormat.ARRAY_RECORD)
dataset_builder = create_dataset_builder(
file_adapters.FileFormat.ARRAY_RECORD
)
with pytest.raises(
ValueError,
match="Unknown split 'doesnotexist'.",
):
data_source_cls(dataset_info, split='doesnotexist')
data_source_cls(dataset_builder, split='doesnotexist')


@pytest.mark.usefixtures(*_FIXTURES)
@@ -136,8 +146,10 @@ def test_missing_split_raises_error(data_source_cls):
_DATA_SOURCE_CLS,
)
def test_repr_returns_meaningful_string_without_decoders(data_source_cls):
dataset_info = create_dataset_info(file_adapters.FileFormat.ARRAY_RECORD)
source = data_source_cls(dataset_info, split='train')
dataset_builder = create_dataset_builder(
file_adapters.FileFormat.ARRAY_RECORD
)
source = data_source_cls(dataset_builder, split='train')
name = data_source_cls.__name__
assert (
repr(source) == f"{name}(name=dataset_name, split='train', decoders=None)"
@@ -150,9 +162,11 @@ def test_repr_returns_meaningful_string_without_decoders(data_source_cls):
_DATA_SOURCE_CLS,
)
def test_repr_returns_meaningful_string_with_decoders(data_source_cls):
dataset_info = create_dataset_info(file_adapters.FileFormat.ARRAY_RECORD)
dataset_builder = create_dataset_builder(
file_adapters.FileFormat.ARRAY_RECORD
)
source = data_source_cls(
dataset_info,
dataset_builder,
split='train',
decoders={'my_feature': decode.SkipDecoding()},
)
@@ -181,3 +195,18 @@ def test_data_source_is_sliceable():
file_instructions = mock_array_record_data_source.call_args_list[1].args[0]
assert file_instructions[0].skip == 0
assert file_instructions[0].take == 30000


# PyGrain requires that data sources are picklable.
@pytest.mark.parametrize(
'file_format',
file_adapters.FileFormat.with_random_access(),
)
@pytest.mark.parametrize('pickle_module', [pickle, cloudpickle])
def test_data_source_is_picklable_after_use(file_format, pickle_module):
with tfds.testing.tmp_dir() as data_dir:
builder = tfds.testing.DummyDataset(data_dir=data_dir)
builder.download_and_prepare(file_format=file_format)
data_source = builder.as_data_source(split='train')
assert data_source[0] == {'id': 0}
assert pickle_module.loads(pickle_module.dumps(data_source))[0] == {'id': 0}
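A note on why the new test reads an element before pickling: `functools.cached_property` (used for `_features` above) stores the computed value in the instance `__dict__`, so whatever it caches is pickled together with the data source once it has been accessed. A minimal sketch of that behavior, not TFDS code:

import functools
import pickle


class Source:

  @functools.cached_property
  def features(self):
    # Stand-in for an expensive features object computed on first access.
    return {'id': 'int64'}


s = Source()
_ = s.features  # First access stores the value in s.__dict__['features'].
restored = pickle.loads(pickle.dumps(s))
print(restored.features)  # -> {'id': 'int64'}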
3 changes: 2 additions & 1 deletion tensorflow_datasets/core/data_sources/parquet.py
@@ -57,7 +57,8 @@ class ParquetDataSource(base.BaseDataSource):
"""ParquetDataSource to read from a ParquetDataset."""

def __post_init__(self):
file_instructions = base.file_instructions(self.dataset_info, self.split)
dataset_info = self.dataset_builder.info
file_instructions = base.file_instructions(dataset_info, self.split)
filenames = [
file_instruction.filename for file_instruction in file_instructions
]
4 changes: 2 additions & 2 deletions tensorflow_datasets/core/dataset_builder.py
@@ -774,13 +774,13 @@ def build_single_data_source(
file_format = self.info.file_format
if file_format == file_adapters.FileFormat.ARRAY_RECORD:
return array_record.ArrayRecordDataSource(
self.info,
self,
split=split,
decoders=decoders,
)
elif file_format == file_adapters.FileFormat.PARQUET:
return parquet.ParquetDataSource(
self.info,
self,
split=split,
decoders=decoders,
)
39 changes: 32 additions & 7 deletions tensorflow_datasets/testing/mocking.py
@@ -120,9 +120,7 @@ def _getitems(
_getitem(self, record_key, generator, serialized=serialized)
for record_key in record_keys
]
if serialized:
return np.array(items)
return items
return np.asarray(items)


def _deserialize_example_np(serialized_example, *, decoders=None):
@@ -173,6 +171,7 @@ def mock_data(
as_data_source_fn: Optional[Callable[..., Sequence[Any]]] = None,
data_dir: Optional[str] = None,
mock_array_record_data_source: Optional[PickableDataSourceMock] = None,
use_in_multiprocessing: bool = False,
) -> Iterator[None]:
"""Mock tfds to generate random data.
@@ -262,6 +261,10 @@ def as_dataset(self, *args, **kwargs):
mock_array_record_data_source: Overwrite a mock for the underlying
ArrayRecord data source if it is used. Note: If used the same mock will be
used for all data sources loaded within this context.
use_in_multiprocessing: If True, the mock will use a multiprocessing-safe
approach to generate the data. It's notably useful for PyGrain. The goal
is to migrate the codebase to this mode by default. Find a more detailed
explanation of this parameter in a comment in the code below.
Yields:
None
@@ -361,9 +364,31 @@ def mock_as_data_source(self, split, decoders=None, **kwargs):
if split is None:
split = {s: s for s in self.info.splits}

generator_cls, features, _, _ = _get_fake_data_components(
decoders, self.info.features
)
features = self.info.features
if use_in_multiprocessing:
# In multiprocessing, we generate serialized data. The data is then
# re-deserialized by the feature as it would normally happen in TFDS. In
# this approach, we don't need to monkey-patch workers to propagate the
# information that deserialize_example_np should be a no-op. Indeed, doing
# so is difficult as PyGrain uses the `spawn` multiprocessing mode. Users
# of tfds.testing.mock_data in the codebase started relying on the
# function not serializing (for example, they don't have TensorFlow in
# their dependency), so we cannot have use_in_multiprocessing by default.
# ┌─────────────┐
# │ Main process│
# └─┬──────┬────┘
# ┌───────▼─┐ ┌─▼───────┐
# │ worker1 │ │ worker2 │ ...
# └───────┬─┘ └─┬───────┘
# serialized data by the generator
# ┌───────▼─┐ ┌─▼───────┐
# │ tfds 1 │ │ tfds 2 │ ...
# └───────┬─┘ └─┬───────┘
# deserialized data
generator_cls = SerializedRandomFakeGenerator
else:
# We generate already deserialized data with the generator.
generator_cls, _, _, _ = _get_fake_data_components(decoders, features)
generator = generator_cls(features, num_examples)

if actual_policy == MockPolicy.USE_CODE:
@@ -399,7 +424,7 @@ def mock_as_data_source(self, split, decoders=None, **kwargs):

def build_single_data_source(split):
single_data_source = array_record.ArrayRecordDataSource(
dataset_info=self.info, split=split, decoders=decoders
dataset_builder=self, split=split, decoders=decoders
)
return single_data_source

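The comment block above explains why `use_in_multiprocessing` generates serialized examples instead of monkey-patching `deserialize_example_np`: PyGrain spawns its workers, and a spawned process re-imports modules from scratch, so patches applied in the parent never reach it. A self-contained sketch of that pitfall, unrelated to the TFDS code itself:

import multiprocessing as mp


def deserialize(x):
  return f'real({x})'


def worker(x):
  # A spawned worker re-imports this module, so it sees the original
  # `deserialize`, not the patched one bound in the parent below.
  return deserialize(x)


if __name__ == '__main__':
  deserialize = lambda x: f'patched({x})'  # Parent-only monkey patch.
  with mp.get_context('spawn').Pool(1) as pool:
    print(pool.map(worker, ['a']))  # -> ['real(a)'], not ['patched(a)']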
9 changes: 9 additions & 0 deletions tensorflow_datasets/testing/mocking_test.py
@@ -392,3 +392,12 @@ def test_as_data_source_fn():
assert imagenet[0] == 'foo'
assert imagenet[1] == 'bar'
assert imagenet[2] == 'baz'


# PyGrain requires that data sources are picklable.
def test_mocked_data_source_is_pickable():
with tfds.testing.mock_data(num_examples=2):
data_source = tfds.data_source('imagenet2012', split='train')
pickled_and_unpickled_data_source = pickle.loads(pickle.dumps(data_source))
assert len(pickled_and_unpickled_data_source) == 2
assert isinstance(pickled_and_unpickled_data_source[0]['image'], np.ndarray)
