From f063002cccea182d43ed64328d0b5df7157e3fba Mon Sep 17 00:00:00 2001
From: hertschuh <1091026+hertschuh@users.noreply.github.com>
Date: Fri, 26 Apr 2024 13:29:20 -0700
Subject: [PATCH] Add support for infinite `PyDataset`s. (#19624)

`PyDataset` now uses the `num_batches` property instead of `__len__` to
support `None`, which is how one indicates the dataset is infinite.
Note that infinite datasets are not shuffled.

Fixes https://github.com/keras-team/keras/issues/19528

Also added exception reporting when using multithreading /
multiprocessing. Previously, the program would just hang with no error
reported.
---
 keras/src/trainers/data_adapters/__init__.py  |   7 ++
 .../data_adapters/py_dataset_adapter.py       | 114 +++++++++++-------
 .../data_adapters/py_dataset_adapter_test.py  | 104 ++++++++++++++--
 3 files changed, 174 insertions(+), 51 deletions(-)

diff --git a/keras/src/trainers/data_adapters/__init__.py b/keras/src/trainers/data_adapters/__init__.py
index 41f2a91f11a..3dc04b75498 100644
--- a/keras/src/trainers/data_adapters/__init__.py
+++ b/keras/src/trainers/data_adapters/__init__.py
@@ -71,6 +71,13 @@ def get_data_adapter(
             "sample_weights", "the sample weights", "PyDataset"
         )
         return PyDatasetAdapter(x, class_weight=class_weight, shuffle=shuffle)
+        # TODO: should we warn or not?
+        # if x.num_batches is None and shuffle:
+        #     warnings.warn(
+        #         "`shuffle=True` was passed, but will be ignored since the "
+        #         "data `x` was provided as a infinite PyDataset. The "
+        #         "PyDataset is expected to already be shuffled."
+        #     )
     elif is_torch_dataloader(x):
         if y is not None:
             raise_unsupported_arg("y", "the targets", "torch DataLoader")
diff --git a/keras/src/trainers/data_adapters/py_dataset_adapter.py b/keras/src/trainers/data_adapters/py_dataset_adapter.py
index 71ab2a67736..daa56a1313f 100644
--- a/keras/src/trainers/data_adapters/py_dataset_adapter.py
+++ b/keras/src/trainers/data_adapters/py_dataset_adapter.py
@@ -1,3 +1,4 @@
+import itertools
 import multiprocessing.dummy
 import queue
 import random
@@ -153,23 +154,26 @@ def __getitem__(self, index):
         """
         raise NotImplementedError

-    def __len__(self):
-        """Number of batch in the PyDataset.
+    @property
+    def num_batches(self):
+        """Number of batches in the PyDataset.

         Returns:
-            The number of batches in the PyDataset.
+            The number of batches in the PyDataset or `None` to indicate that
+            the dataset is infinite.
         """
-        raise NotImplementedError
+        # For backwards compatibility, support `__len__`.
+        if hasattr(self, "__len__"):
+            return len(self)
+        raise NotImplementedError(
+            "You need to implement the `num_batches` property:\n\n"
+            "@property\ndef num_batches(self):\n  return ..."
+        )

     def on_epoch_end(self):
         """Method called at the end of every epoch."""
         pass

-    def __iter__(self):
-        """Create a generator that iterate over the PyDataset."""
-        for i in range(len(self)):
-            yield self[i]
-

 class PyDatasetAdapter(DataAdapter):
     """Adapter for `keras.utils.PyDataset` instances."""
@@ -234,23 +238,33 @@ def generator_fn():
         else:

             def generator_fn():
-                order = range(len(self.py_dataset))
-                if self.shuffle:
+                num_batches = self.py_dataset.num_batches
+                indices = (
+                    range(num_batches)
+                    if num_batches is not None
+                    else itertools.count()
+                )
+                if self.shuffle and num_batches is not None:
                     # Match the shuffle convention in OrderedEnqueuer.
-                    order = list(order)
-                    random.shuffle(order)
+                    indices = list(indices)
+                    random.shuffle(indices)

-                for i in order:
+                for i in indices:
                     yield self.py_dataset[i]

         return generator_fn

     def _get_iterator(self):
+        num_batches = self.py_dataset.num_batches
         gen_fn = self._make_multiprocessed_generator_fn()
         for i, batch in enumerate(gen_fn()):
             batch = self._standardize_batch(batch)
             yield batch
-            if i >= len(self.py_dataset) - 1 and self.enqueuer:
+            if (
+                self.enqueuer
+                and num_batches is not None
+                and i >= num_batches - 1
+            ):
                 self.enqueuer.stop()

     def get_numpy_iterator(self):
@@ -262,11 +276,11 @@ def get_jax_iterator(self):
     def get_tf_dataset(self):
         from keras.src.utils.module_utils import tensorflow as tf

+        num_batches = self.py_dataset.num_batches
         if self._output_signature is None:
-            num_samples = min(
-                data_adapter_utils.NUM_BATCHES_FOR_TENSOR_SPEC,
-                len(self.py_dataset),
-            )
+            num_samples = data_adapter_utils.NUM_BATCHES_FOR_TENSOR_SPEC
+            if num_batches is not None:
+                num_samples = min(num_samples, num_batches)
             batches = [
                 self._standardize_batch(self.py_dataset[i])
                 for i in range(num_samples)
@@ -277,7 +291,7 @@ def get_tf_dataset(self):
             self._get_iterator,
             output_signature=self._output_signature,
         )
-        if self.shuffle:
+        if self.shuffle and num_batches is not None:
             ds = ds.shuffle(8)
         ds = ds.prefetch(tf.data.AUTOTUNE)
         return ds
@@ -292,7 +306,7 @@ def on_epoch_end(self):

     @property
     def num_batches(self):
-        return len(self.py_dataset)
+        return self.py_dataset.num_batches

     @property
     def batch_size(self):
@@ -520,31 +534,40 @@ def _wait_queue(self):

     def _run(self):
         """Submits request to the executor and queue the `Future` objects."""
-        indices = list(range(len(self.py_dataset)))
-        if self.shuffle:
-            random.shuffle(indices)
-        self._send_py_dataset()  # Share the initial py_dataset
-        while True:
-            with closing(self.executor_fn(_SHARED_SEQUENCES)) as executor:
-                for i in indices:
-                    if self.stop_signal.is_set():
-                        return
-
-                    self.queue.put(
-                        executor.apply_async(get_index, (self.uid, i)),
-                        block=True,
-                    )
-
-            # Done with the current epoch, waiting for the final batches
-            self._wait_queue()
-
-            if self.stop_signal.is_set():
-                # We're done
-                return
-
-            # Call the internal on epoch end.
-            self.py_dataset.on_epoch_end()
-            self._send_py_dataset()  # Update the pool
+        try:
+            num_batches = self.py_dataset.num_batches
+            indices = (
+                range(num_batches)
+                if num_batches is not None
+                else itertools.count()
+            )
+            if self.shuffle and num_batches is not None:
+                indices = list(indices)
+                random.shuffle(indices)
+            self._send_py_dataset()  # Share the initial py_dataset
+            while True:
+                with closing(self.executor_fn(_SHARED_SEQUENCES)) as executor:
+                    for i in indices:
+                        if self.stop_signal.is_set():
+                            return
+
+                        self.queue.put(
+                            executor.apply_async(get_index, (self.uid, i)),
+                            block=True,
+                        )
+
+                # Done with the current epoch, waiting for the final batches
+                self._wait_queue()
+
+                if self.stop_signal.is_set():
+                    # We're done
+                    return
+
+                # Call the internal on epoch end.
+                self.py_dataset.on_epoch_end()
+                self._send_py_dataset()  # Update the pool
+        except Exception as e:
+            self.queue.put(e)  # Report exception

     def get(self):
         """Creates a generator to extract data from the queue.
@@ -558,7 +581,10 @@ def get(self):
         """
         while self.is_running():
             try:
-                inputs = self.queue.get(block=True, timeout=5).get()
+                value = self.queue.get(block=True, timeout=5)
+                if isinstance(value, Exception):
+                    raise value  # Propagate exception from other thread
+                inputs = value.get()
                 if self.is_running():
                     self.queue.task_done()
                 if inputs is not None:
diff --git a/keras/src/trainers/data_adapters/py_dataset_adapter_test.py b/keras/src/trainers/data_adapters/py_dataset_adapter_test.py
index b1be7002ac5..7c41971db56 100644
--- a/keras/src/trainers/data_adapters/py_dataset_adapter_test.py
+++ b/keras/src/trainers/data_adapters/py_dataset_adapter_test.py
@@ -15,21 +15,34 @@ class ExamplePyDataset(py_dataset_adapter.PyDataset):
     def __init__(
-        self, x_set, y_set, sample_weight=None, batch_size=32, delay=0, **kwargs
+        self,
+        x_set,
+        y_set,
+        sample_weight=None,
+        batch_size=32,
+        delay=0,
+        infinite=False,
+        **kwargs
     ):
         super().__init__(**kwargs)
         self.x, self.y = x_set, y_set
         self.batch_size = batch_size
         self.sample_weight = sample_weight
         self.delay = delay
+        self.infinite = infinite

-    def __len__(self):
+    @property
+    def num_batches(self):
+        if self.infinite:
+            return None
         return math.ceil(len(self.x) / self.batch_size)

     def __getitem__(self, idx):
         # Create artificial delay to test multiprocessing
         time.sleep(self.delay)
+        if self.infinite:
+            idx = idx % math.ceil(len(self.x) / self.batch_size)
         # Return x, y for batch idx.
         low = idx * self.batch_size
         # Cap upper bound at array length; the last batch may be smaller
@@ -48,7 +61,8 @@ def __init__(self, inputs, batch_size=32, **kwargs):
         self.inputs = inputs
         self.batch_size = batch_size

-    def __len__(self):
+    @property
+    def num_batches(self):
         return math.ceil(len(self.inputs["x"]) / self.batch_size)

     def __getitem__(self, idx):
@@ -98,6 +112,7 @@ class PyDatasetAdapterTest(testing.TestCase, parameterized.TestCase):
                 "dataset_type": "torch",
             },
         ],
+        infinite=[True, False],
        iterator_type=["np", "tf", "jax", "torch"],
        shuffle=[True, False],
    )
@@ -106,6 +121,7 @@ def test_basic_flow(
         self,
         shuffle,
         dataset_type,
+        infinite,
         iterator_type,
         workers=0,
         use_multiprocessing=False,
@@ -127,6 +143,7 @@ def test_basic_flow(
             workers=workers,
             use_multiprocessing=use_multiprocessing,
             max_queue_size=max_queue_size,
+            infinite=infinite,
         )
         adapter = py_dataset_adapter.PyDatasetAdapter(
             py_dataset, shuffle=shuffle
@@ -157,10 +174,16 @@ def test_basic_flow(
             self.assertEqual(by.shape, (16, 2))
             for i in range(by.shape[0]):
                 sample_order.append(by[i, 0])
-        if shuffle:
-            self.assertNotAllClose(sample_order, list(range(64)))
+            if infinite and len(sample_order) >= 128:
+                break
+        expected_order = list(range(64))
+        if infinite:
+            # When the dataset is infinite, we cycle through the data twice.
+            expected_order = expected_order + expected_order
+        if shuffle and not infinite:
+            self.assertNotAllClose(sample_order, expected_order)
         else:
-            self.assertAllClose(sample_order, list(range(64)))
+            self.assertAllClose(sample_order, expected_order)

         # TODO: test class_weight
         # TODO: test sample weights
@@ -240,7 +263,8 @@ def test_dict_inputs(self):
     def test_with_different_shapes(self, iterator_type):

         class TestPyDataset(py_dataset_adapter.PyDataset):
-            def __len__(self):
+            @property
+            def num_batches(self):
                 return 3

             def __getitem__(self, idx):
@@ -284,3 +308,69 @@ def __getitem__(self, idx):
         else:
             self.assertEqual(bx.shape, (2, 6))
             self.assertEqual(by.shape, (2, 2))
+
+    @parameterized.named_parameters(
+        named_product(
+            [
+                {
+                    "testcase_name": "multiprocessing",
+                    "workers": 2,
+                    "use_multiprocessing": True,
+                    "max_queue_size": 10,
+                },
+                {
+                    "testcase_name": "multithreading",
+                    "workers": 2,
+                    "max_queue_size": 10,
+                },
+                {
+                    "testcase_name": "single",
+                },
+            ],
+            iterator_type=["np", "tf", "jax", "torch"],
+        )
+    )
+    def test_exception_reported(
+        self,
+        iterator_type,
+        workers=0,
+        use_multiprocessing=False,
+        max_queue_size=0,
+    ):
+        class ExceptionPyDataset(py_dataset_adapter.PyDataset):
+
+            @property
+            def num_batches(self):
+                return 4
+
+            def __getitem__(self, index):
+                if index < 2:
+                    return (
+                        np.random.random((64, 4)).astype("float32"),
+                        np.random.random((64, 2)).astype("float32"),
+                    )
+                raise ValueError("Excepted exception")
+
+        adapter = py_dataset_adapter.PyDatasetAdapter(
+            ExceptionPyDataset(), shuffle=False
+        )
+
+        expected_exception_class = ValueError
+        if iterator_type == "np":
+            it = adapter.get_numpy_iterator()
+        elif iterator_type == "tf":
+            it = adapter.get_tf_dataset()
+            # tf.data wraps the exception
+            expected_exception_class = tf.errors.InvalidArgumentError
+        elif iterator_type == "jax":
+            it = adapter.get_jax_iterator()
+        elif iterator_type == "torch":
+            it = adapter.get_torch_dataloader()
+
+        it = iter(it)
+        next(it)
+        next(it)
+        with self.assertRaisesRegex(
+            expected_exception_class, "Excepted exception"
+        ):
+            next(it)
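
Note (illustration, not part of the patch): below is a minimal sketch of what an infinite `PyDataset` looks like under the new API. The `InfiniteDataset` name, the random data, and the `keras.utils.PyDataset` import path are assumptions made for the example; the `num_batches` / `__getitem__` pattern mirrors the `infinite=True` branch of `ExamplePyDataset` in the test file above.

import math

import numpy as np
from keras.utils import PyDataset  # assumed public alias of the class patched above


class InfiniteDataset(PyDataset):
    """Illustrative dataset that cycles over x/y forever."""

    def __init__(self, x, y, batch_size=32, **kwargs):
        # workers / use_multiprocessing / max_queue_size flow to PyDataset
        # through **kwargs, exactly as in the test helpers above.
        super().__init__(**kwargs)
        self.x, self.y = x, y
        self.batch_size = batch_size

    @property
    def num_batches(self):
        # Returning None is what now marks the dataset as infinite.
        return None

    def __getitem__(self, idx):
        # Indices are drawn from itertools.count(), so wrap them onto the data.
        idx = idx % math.ceil(len(self.x) / self.batch_size)
        low = idx * self.batch_size
        high = min(low + self.batch_size, len(self.x))
        return self.x[low:high], self.y[low:high]


dataset = InfiniteDataset(
    np.random.random((64, 4)).astype("float32"),
    np.random.random((64, 2)).astype("float32"),
    workers=2,
    use_multiprocessing=True,
)

Because the adapter can no longer infer an epoch length from such a dataset, the caller is expected to bound the epoch itself (for example `model.fit(dataset, steps_per_epoch=100)`), and, as the commit message notes, shuffling is skipped for infinite datasets.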