From 1e4bf10b8fbd2db5120e301386fc32357e49290c Mon Sep 17 00:00:00 2001
From: Pierre Marcenac
Date: Fri, 24 May 2024 00:59:31 -0700
Subject: [PATCH] Make tfds.data_source picklable.

PiperOrigin-RevId: 636824581
---
 .../core/data_sources/array_record.py         | 18 +------
 tensorflow_datasets/core/data_sources/base.py | 32 +++++++++---
 .../core/data_sources/base_test.py            | 49 +++++++++++++++----
 .../core/data_sources/parquet.py              |  3 +-
 tensorflow_datasets/core/dataset_builder.py   |  4 +-
 tensorflow_datasets/testing/mocking.py        | 39 ++++++++++++---
 tensorflow_datasets/testing/mocking_test.py   |  9 ++++
 7 files changed, 110 insertions(+), 44 deletions(-)

diff --git a/tensorflow_datasets/core/data_sources/array_record.py b/tensorflow_datasets/core/data_sources/array_record.py
index c8f3ca8fabe..88cf3cb1c82 100644
--- a/tensorflow_datasets/core/data_sources/array_record.py
+++ b/tensorflow_datasets/core/data_sources/array_record.py
@@ -20,13 +20,8 @@
 """

 import dataclasses
-from typing import Any, Optional

-from tensorflow_datasets.core import dataset_info as dataset_info_lib
-from tensorflow_datasets.core import decode
-from tensorflow_datasets.core import splits as splits_lib
 from tensorflow_datasets.core.data_sources import base
-from tensorflow_datasets.core.utils import type_utils
 from tensorflow_datasets.core.utils.lazy_imports_utils import array_record_data_source


@@ -42,18 +37,9 @@ class ArrayRecordDataSource(base.BaseDataSource):
   source.
   """

-  dataset_info: dataset_info_lib.DatasetInfo
-  split: splits_lib.Split = None
-  decoders: Optional[type_utils.TreeDict[decode.partial_decode.DecoderArg]] = (
-      None
-  )
-  # In order to lazy load array_record, we don't load
-  # `array_record_data_source.ArrayRecordDataSource` here.
-  data_source: Any = dataclasses.field(init=False)
-  length: int = dataclasses.field(init=False)
-
   def __post_init__(self):
-    file_instructions = base.file_instructions(self.dataset_info, self.split)
+    dataset_info = self.dataset_builder.info
+    file_instructions = base.file_instructions(dataset_info, self.split)
     self.data_source = array_record_data_source.ArrayRecordDataSource(
         file_instructions
     )
diff --git a/tensorflow_datasets/core/data_sources/base.py b/tensorflow_datasets/core/data_sources/base.py
index c70f736b92c..a6109d24b92 100644
--- a/tensorflow_datasets/core/data_sources/base.py
+++ b/tensorflow_datasets/core/data_sources/base.py
@@ -17,12 +17,14 @@

 from collections.abc import MappingView, Sequence
 import dataclasses
+import functools
 import typing
 from typing import Any, Generic, Iterable, Protocol, SupportsIndex, TypeVar

 from tensorflow_datasets.core import dataset_info as dataset_info_lib
 from tensorflow_datasets.core import decode
 from tensorflow_datasets.core import splits as splits_lib
+from tensorflow_datasets.core.features import top_level_feature
 from tensorflow_datasets.core.utils import shard_utils
 from tensorflow_datasets.core.utils import type_utils
 from tensorflow_datasets.core.utils.lazy_imports_utils import tree
@@ -54,6 +56,14 @@ def file_instructions(
   return split_dict[split].file_instructions


+class _DatasetBuilder(Protocol):
+  """Protocol for the DatasetBuilder to avoid cyclic imports."""
+
+  @property
+  def info(self) -> dataset_info_lib.DatasetInfo:
+    ...
+
+
 @dataclasses.dataclass
 class BaseDataSource(MappingView, Sequence):
   """Base DataSource to override all dunder methods with the deserialization.
@@ -64,22 +74,28 @@ class BaseDataSource(MappingView, Sequence):
   deserialization/decoding.
   Attributes:
-    dataset_info: The DatasetInfo of the
+    dataset_builder: The dataset builder.
     split: The split to load in the data source.
     decoders: Optional decoders for decoding.
     data_source: The underlying data source to initialize in the
       __post_init__.
   """

-  dataset_info: dataset_info_lib.DatasetInfo
+  dataset_builder: _DatasetBuilder
   split: splits_lib.Split | None = None
   decoders: type_utils.TreeDict[decode.partial_decode.DecoderArg] | None = None
   data_source: DataSource[Any] = dataclasses.field(init=False)

+  @functools.cached_property
+  def _features(self) -> top_level_feature.TopLevelFeature:
+    """Caches features because we log the use of dataset_builder.info."""
+    features = self.dataset_builder.info.features
+    if not features:
+      raise ValueError('No feature defined in the dataset builder.')
+    return features
+
   def __getitem__(self, key: SupportsIndex) -> Any:
     record = self.data_source[key.__index__()]
-    return self.dataset_info.features.deserialize_example_np(
-        record, decoders=self.decoders
-    )
+    return self._features.deserialize_example_np(record, decoders=self.decoders)

   def __getitems__(self, keys: Sequence[int]) -> Sequence[Any]:
     """Retrieves items by batch.
@@ -98,7 +114,6 @@ def __getitems__(self, keys: Sequence[int]) -> Sequence[Any]:
     if not keys:
       return []
     records = self.data_source.__getitems__(keys)
-    features = self.dataset_info.features
     if len(keys) != len(records):
       raise IndexError(
           f'Requested {len(keys)} records but got'
@@ -106,7 +121,7 @@
           f'{keys=}, {records=}'
       )
     return [
-        features.deserialize_example_np(record, decoders=self.decoders)
+        self._features.deserialize_example_np(record, decoders=self.decoders)
         for record in records
     ]

@@ -114,8 +129,9 @@ def __repr__(self) -> str:
     decoders_repr = (
         tree.map_structure(type, self.decoders) if self.decoders else None
     )
+    name = self.dataset_builder.info.name
     return (
-        f'{self.__class__.__name__}(name={self.dataset_info.name}, '
+        f'{self.__class__.__name__}(name={name}, '
         f'split={self.split!r}, '
         f'decoders={decoders_repr})'
     )
diff --git a/tensorflow_datasets/core/data_sources/base_test.py b/tensorflow_datasets/core/data_sources/base_test.py
index 6891a2cff7b..a6ad8151a21 100644
--- a/tensorflow_datasets/core/data_sources/base_test.py
+++ b/tensorflow_datasets/core/data_sources/base_test.py
@@ -15,13 +15,15 @@

 """Tests for all data sources."""

+import pickle
 from unittest import mock

+import cloudpickle
 from etils import epath
 import pytest
 import tensorflow_datasets as tfds
 from tensorflow_datasets import testing
-from tensorflow_datasets.core import dataset_builder
+from tensorflow_datasets.core import dataset_builder as dataset_builder_lib
 from tensorflow_datasets.core import dataset_info as dataset_info_lib
 from tensorflow_datasets.core import decode
 from tensorflow_datasets.core import file_adapters
@@ -77,7 +79,7 @@ def mocked_parquet_dataset():
 )
 def test_read_write(
     tmp_path: epath.Path,
-    builder_cls: dataset_builder.DatasetBuilder,
+    builder_cls: dataset_builder_lib.DatasetBuilder,
     file_format: file_adapters.FileFormat,
 ):
   builder = builder_cls(data_dir=tmp_path, file_format=file_format)
@@ -106,7 +108,9 @@ def test_read_write(
 ]


-def create_dataset_info(file_format: file_adapters.FileFormat):
+def create_dataset_builder(
+    file_format: file_adapters.FileFormat,
+) -> dataset_builder_lib.DatasetBuilder:
   with mock.patch.object(splits_lib, 'SplitInfo') as split_mock:
     split_mock.return_value.name = 'train'
     split_mock.return_value.file_instructions = _FILE_INSTRUCTIONS
@@ -114,7 +118,11 @@ def create_dataset_info(file_format: file_adapters.FileFormat):
     dataset_info.file_format = file_format
     dataset_info.splits = {'train': split_mock()}
     dataset_info.name = 'dataset_name'
-    return dataset_info
+
+    dataset_builder = mock.create_autospec(dataset_builder_lib.DatasetBuilder)
+    dataset_builder.info = dataset_info
+
+    return dataset_builder


 @pytest.mark.parametrize(
@@ -122,12 +130,14 @@ def create_dataset_info(file_format: file_adapters.FileFormat):
     _DATA_SOURCE_CLS,
 )
 def test_missing_split_raises_error(data_source_cls):
-  dataset_info = create_dataset_info(file_adapters.FileFormat.ARRAY_RECORD)
+  dataset_builder = create_dataset_builder(
+      file_adapters.FileFormat.ARRAY_RECORD
+  )
   with pytest.raises(
       ValueError,
       match="Unknown split 'doesnotexist'.",
   ):
-    data_source_cls(dataset_info, split='doesnotexist')
+    data_source_cls(dataset_builder, split='doesnotexist')


 @pytest.mark.usefixtures(*_FIXTURES)
@@ -136,8 +146,10 @@ def test_missing_split_raises_error(data_source_cls):
     _DATA_SOURCE_CLS,
 )
 def test_repr_returns_meaningful_string_without_decoders(data_source_cls):
-  dataset_info = create_dataset_info(file_adapters.FileFormat.ARRAY_RECORD)
-  source = data_source_cls(dataset_info, split='train')
+  dataset_builder = create_dataset_builder(
+      file_adapters.FileFormat.ARRAY_RECORD
+  )
+  source = data_source_cls(dataset_builder, split='train')
   name = data_source_cls.__name__
   assert (
       repr(source) == f"{name}(name=dataset_name, split='train', decoders=None)"
@@ -150,9 +162,11 @@ def test_repr_returns_meaningful_string_without_decoders(data_source_cls):
     _DATA_SOURCE_CLS,
 )
 def test_repr_returns_meaningful_string_with_decoders(data_source_cls):
-  dataset_info = create_dataset_info(file_adapters.FileFormat.ARRAY_RECORD)
+  dataset_builder = create_dataset_builder(
+      file_adapters.FileFormat.ARRAY_RECORD
+  )
   source = data_source_cls(
-      dataset_info,
+      dataset_builder,
       split='train',
       decoders={'my_feature': decode.SkipDecoding()},
   )
@@ -181,3 +195,18 @@ def test_data_source_is_sliceable():
   file_instructions = mock_array_record_data_source.call_args_list[1].args[0]
   assert file_instructions[0].skip == 0
   assert file_instructions[0].take == 30000
+
+
+# PyGrain requires that data sources are picklable.
+@pytest.mark.parametrize(
+    'file_format',
+    file_adapters.FileFormat.with_random_access(),
+)
+@pytest.mark.parametrize('pickle_module', [pickle, cloudpickle])
+def test_data_source_is_picklable_after_use(file_format, pickle_module):
+  with tfds.testing.tmp_dir() as data_dir:
+    builder = tfds.testing.DummyDataset(data_dir=data_dir)
+    builder.download_and_prepare(file_format=file_format)
+    data_source = builder.as_data_source(split='train')
+    assert data_source[0] == {'id': 0}
+    assert pickle_module.loads(pickle_module.dumps(data_source))[0] == {'id': 0}
diff --git a/tensorflow_datasets/core/data_sources/parquet.py b/tensorflow_datasets/core/data_sources/parquet.py
index 7fe8b19b85e..048bf18994e 100644
--- a/tensorflow_datasets/core/data_sources/parquet.py
+++ b/tensorflow_datasets/core/data_sources/parquet.py
@@ -57,7 +57,8 @@ class ParquetDataSource(base.BaseDataSource):
   """ParquetDataSource to read from a ParquetDataset."""

   def __post_init__(self):
-    file_instructions = base.file_instructions(self.dataset_info, self.split)
+    dataset_info = self.dataset_builder.info
+    file_instructions = base.file_instructions(dataset_info, self.split)
     filenames = [
         file_instruction.filename for file_instruction in file_instructions
     ]
diff --git a/tensorflow_datasets/core/dataset_builder.py b/tensorflow_datasets/core/dataset_builder.py
index b834cb970c8..39d713871e9 100644
--- a/tensorflow_datasets/core/dataset_builder.py
+++ b/tensorflow_datasets/core/dataset_builder.py
@@ -774,13 +774,13 @@ def build_single_data_source(
       file_format = self.info.file_format
       if file_format == file_adapters.FileFormat.ARRAY_RECORD:
         return array_record.ArrayRecordDataSource(
-            self.info,
+            self,
             split=split,
             decoders=decoders,
         )
       elif file_format == file_adapters.FileFormat.PARQUET:
         return parquet.ParquetDataSource(
-            self.info,
+            self,
             split=split,
             decoders=decoders,
         )
diff --git a/tensorflow_datasets/testing/mocking.py b/tensorflow_datasets/testing/mocking.py
index 9378204504c..863f36287a5 100644
--- a/tensorflow_datasets/testing/mocking.py
+++ b/tensorflow_datasets/testing/mocking.py
@@ -120,9 +120,7 @@ def _getitems(
         _getitem(self, record_key, generator, serialized=serialized)
         for record_key in record_keys
     ]
-    if serialized:
-      return np.array(items)
-    return items
+    return np.asarray(items)


 def _deserialize_example_np(serialized_example, *, decoders=None):
@@ -173,6 +171,7 @@ def mock_data(
     as_data_source_fn: Optional[Callable[..., Sequence[Any]]] = None,
     data_dir: Optional[str] = None,
     mock_array_record_data_source: Optional[PickableDataSourceMock] = None,
+    use_in_multiprocessing: bool = False,
 ) -> Iterator[None]:
   """Mock tfds to generate random data.

@@ -262,6 +261,10 @@ def as_dataset(self, *args, **kwargs):
     mock_array_record_data_source: Overwrite a mock for the underlying
       ArrayRecord data source if it is used. Note: If used the same mock will
       be used for all data sources loaded within this context.
+    use_in_multiprocessing: If True, the mock will use a multiprocessing-safe
+      approach to generate the data. It's notably useful for PyGrain. The goal
+      is to migrate the codebase to this mode by default. Find a more detailed
+      explanation of this parameter in a comment in the code below.
   Yields:
     None
@@ -361,9 +364,31 @@ def mock_as_data_source(self, split, decoders=None, **kwargs):
     if split is None:
       split = {s: s for s in self.info.splits}

-    generator_cls, features, _, _ = _get_fake_data_components(
-        decoders, self.info.features
-    )
+    features = self.info.features
+    if use_in_multiprocessing:
+      # In multiprocessing, we generate serialized data. The data is then
+      # re-deserialized by the feature as it would normally happen in TFDS. In
+      # this approach, we don't need to monkey-patch workers to propagate the
+      # information that deserialize_example_np should be a no-op. Indeed, doing
+      # so is difficult as PyGrain uses the `spawn` multiprocessing mode. Users
+      # of tfds.testing.mock_data in the codebase started relying on the
+      # function not serializing (for example, they don't have TensorFlow in
+      # their dependency), so we cannot have use_in_multiprocessing by default.
+      # ┌─────────────┐
+      # │ Main process│
+      # └─┬──────┬────┘
+      # ┌───────▼─┐ ┌─▼───────┐
+      # │ worker1 │ │ worker2 │ ...
+      # └───────┬─┘ └─┬───────┘
+      # serialized data by the generator
+      # ┌───────▼─┐ ┌─▼───────┐
+      # │ tfds 1 │ │ tfds 2 │ ...
+      # └───────┬─┘ └─┬───────┘
+      # deserialized data
+      generator_cls = SerializedRandomFakeGenerator
+    else:
+      # We generate already deserialized data with the generator.
+      generator_cls, _, _, _ = _get_fake_data_components(decoders, features)
     generator = generator_cls(features, num_examples)

     if actual_policy == MockPolicy.USE_CODE:
@@ -399,7 +424,7 @@ def mock_as_data_source(self, split, decoders=None, **kwargs):

     def build_single_data_source(split):
       single_data_source = array_record.ArrayRecordDataSource(
-          dataset_info=self.info, split=split, decoders=decoders
+          dataset_builder=self, split=split, decoders=decoders
       )
       return single_data_source

diff --git a/tensorflow_datasets/testing/mocking_test.py b/tensorflow_datasets/testing/mocking_test.py
index 3280e166512..d707e810cbf 100644
--- a/tensorflow_datasets/testing/mocking_test.py
+++ b/tensorflow_datasets/testing/mocking_test.py
@@ -392,3 +392,12 @@ def test_as_data_source_fn():
   assert imagenet[0] == 'foo'
   assert imagenet[1] == 'bar'
   assert imagenet[2] == 'baz'
+
+
+# PyGrain requires that data sources are picklable.
+def test_mocked_data_source_is_picklable():
+  with tfds.testing.mock_data(num_examples=2):
+    data_source = tfds.data_source('imagenet2012', split='train')
+    pickled_and_unpickled_data_source = pickle.loads(pickle.dumps(data_source))
+    assert len(pickled_and_unpickled_data_source) == 2
+    assert isinstance(pickled_and_unpickled_data_source[0]['image'], np.ndarray)
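
As exercised by the new tests above, a data source created through
`tfds.data_source` now survives a pickle round trip, which is what PyGrain's
`spawn`-based worker processes require. A minimal usage sketch mirroring those
tests (the dataset name and example count are illustrative; the same round
trip also works for real datasets prepared with a random-access file format
such as ArrayRecord or Parquet):

    import pickle

    import numpy as np
    import tensorflow_datasets as tfds

    # Under mock_data, tfds.data_source returns a mocked yet picklable source.
    with tfds.testing.mock_data(num_examples=2):
      data_source = tfds.data_source('imagenet2012', split='train')

      # Round-trip through pickle, as PyGrain's spawn-based workers do.
      restored = pickle.loads(pickle.dumps(data_source))
      assert len(restored) == 2
      assert isinstance(restored[0]['image'], np.ndarray)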