Skip to content

Commit

Permalink
Add support for infinite PyDatasets.
Browse files Browse the repository at this point in the history
`PyDataset` now uses the `num_batches` property instead of `__len__` to support `None`, which is how one indicates the dataset is infinite. Note that infinite datasets are not shuffled.

Fixes keras-team#19528

Also added exception reporting when using multithreading / multiprocessing. Previously, the program would just hang with no error reported.
  • Loading branch information
hertschuh committed Apr 26, 2024
1 parent d7824ac commit ff2ba05
Show file tree
Hide file tree
Showing 3 changed files with 169 additions and 51 deletions.
7 changes: 7 additions & 0 deletions keras/src/trainers/data_adapters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,13 @@ def get_data_adapter(
"sample_weights", "the sample weights", "PyDataset"
)
return PyDatasetAdapter(x, class_weight=class_weight, shuffle=shuffle)
# TODO: should we warn or not?
# if x.num_batches is None and shuffle:
# warnings.warn(
# "`shuffle=True` was passed, but will be ignored since the "
# "data `x` was provided as an infinite PyDataset. The "
# "PyDataset is expected to already be shuffled."
# )
elif is_torch_dataloader(x):
if y is not None:
raise_unsupported_arg("y", "the targets", "torch DataLoader")
Expand Down
114 changes: 70 additions & 44 deletions keras/src/trainers/data_adapters/py_dataset_adapter.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import itertools
import multiprocessing.dummy
import queue
import random
Expand Down Expand Up @@ -153,23 +154,26 @@ def __getitem__(self, index):
"""
raise NotImplementedError

@property
def num_batches(self):
    """Number of batches in the PyDataset.

    Returns:
        The number of batches in the PyDataset or `None` to indicate that
        the dataset is infinite.

    Raises:
        NotImplementedError: If the object provides neither an override of
            this property nor a `__len__` method.
    """
    # For backwards compatibility, support `__len__`.
    if hasattr(self, "__len__"):
        return len(self)
    raise NotImplementedError(
        "You need to implement the `num_batches` property:\n\n"
        "@property\ndef num_batches(self):\n return ..."
    )

def on_epoch_end(self):
"""Method called at the end of every epoch."""
pass

def __iter__(self):
"""Create a generator that iterate over the PyDataset."""
for i in range(len(self)):
yield self[i]


class PyDatasetAdapter(DataAdapter):
"""Adapter for `keras.utils.PyDataset` instances."""
Expand Down Expand Up @@ -234,23 +238,33 @@ def generator_fn():
else:

def generator_fn():
    """Yield batches straight from the PyDataset, one index at a time.

    A finite dataset (`num_batches` is an int) is walked over
    `range(num_batches)`, shuffled first when `self.shuffle` is set. An
    infinite dataset (`num_batches` is `None`) is walked over an unbounded
    counter and is never shuffled, since it cannot be materialized.
    """
    num_batches = self.py_dataset.num_batches
    indices = (
        range(num_batches)
        if num_batches is not None
        else itertools.count()
    )
    if self.shuffle and num_batches is not None:
        # Match the shuffle convention in OrderedEnqueuer.
        indices = list(indices)
        random.shuffle(indices)

    for i in indices:
        yield self.py_dataset[i]

return generator_fn

def _get_iterator(self):
num_batches = self.py_dataset.num_batches
gen_fn = self._make_multiprocessed_generator_fn()
for i, batch in enumerate(gen_fn()):
batch = self._standardize_batch(batch)
yield batch
if i >= len(self.py_dataset) - 1 and self.enqueuer:
if (
self.enqueuer
and num_batches is not None
and i >= num_batches - 1
):
self.enqueuer.stop()

def get_numpy_iterator(self):
Expand All @@ -262,11 +276,11 @@ def get_jax_iterator(self):
def get_tf_dataset(self):
from keras.src.utils.module_utils import tensorflow as tf

num_batches = self.py_dataset.num_batches
if self._output_signature is None:
num_samples = min(
data_adapter_utils.NUM_BATCHES_FOR_TENSOR_SPEC,
len(self.py_dataset),
)
num_samples = data_adapter_utils.NUM_BATCHES_FOR_TENSOR_SPEC
if num_batches is not None:
num_samples = min(num_samples, num_batches)
batches = [
self._standardize_batch(self.py_dataset[i])
for i in range(num_samples)
Expand All @@ -277,7 +291,7 @@ def get_tf_dataset(self):
self._get_iterator,
output_signature=self._output_signature,
)
if self.shuffle:
if self.shuffle and num_batches is not None:
ds = ds.shuffle(8)
ds = ds.prefetch(tf.data.AUTOTUNE)
return ds
Expand All @@ -292,7 +306,7 @@ def on_epoch_end(self):

@property
def num_batches(self):
    """Number of batches reported by the wrapped PyDataset.

    Forwards `py_dataset.num_batches` unchanged, so a value of `None`
    (infinite dataset) is passed through rather than failing in `len()`.
    """
    return self.py_dataset.num_batches

@property
def batch_size(self):
Expand Down Expand Up @@ -520,31 +534,40 @@ def _wait_queue(self):

def _run(self):
    """Submit batch requests to the executor and queue the `Future` objects.

    Loops over epochs until `stop_signal` is set. For a finite dataset the
    indices are `range(num_batches)`, shuffled once up front when
    `self.shuffle` is set. For an infinite dataset (`num_batches` is
    `None`) the indices come from an unbounded counter and are never
    shuffled.

    Any exception raised while producing work is put on the queue instead
    of being lost in this thread, so the consumer side can detect and
    re-raise it rather than waiting forever.
    """
    try:
        num_batches = self.py_dataset.num_batches
        indices = (
            range(num_batches)
            if num_batches is not None
            else itertools.count()
        )
        if self.shuffle and num_batches is not None:
            indices = list(indices)
            random.shuffle(indices)
        self._send_py_dataset()  # Share the initial py_dataset
        while True:
            with closing(self.executor_fn(_SHARED_SEQUENCES)) as executor:
                for i in indices:
                    if self.stop_signal.is_set():
                        return

                    self.queue.put(
                        executor.apply_async(get_index, (self.uid, i)),
                        block=True,
                    )

                # Done with the current epoch, waiting for the final batches
                self._wait_queue()

                if self.stop_signal.is_set():
                    # We're done
                    return

            # Call the internal on epoch end.
            self.py_dataset.on_epoch_end()
            self._send_py_dataset()  # Update the pool
    except Exception as e:
        self.queue.put(e)  # Report exception

def get(self):
"""Creates a generator to extract data from the queue.
Expand All @@ -558,7 +581,10 @@ def get(self):
"""
while self.is_running():
try:
inputs = self.queue.get(block=True, timeout=5).get()
value = self.queue.get(block=True, timeout=5)
if isinstance(value, Exception):
raise value # Propagate exception from other thread
inputs = value.get()
if self.is_running():
self.queue.task_done()
if inputs is not None:
Expand Down
99 changes: 92 additions & 7 deletions keras/src/trainers/data_adapters/py_dataset_adapter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,34 @@

class ExamplePyDataset(py_dataset_adapter.PyDataset):
def __init__(
self, x_set, y_set, sample_weight=None, batch_size=32, delay=0, **kwargs
self,
x_set,
y_set,
sample_weight=None,
batch_size=32,
delay=0,
infinite=False,
**kwargs
):
super().__init__(**kwargs)
self.x, self.y = x_set, y_set
self.batch_size = batch_size
self.sample_weight = sample_weight
self.delay = delay
self.infinite = infinite

def __len__(self):
@property
def num_batches(self):
if self.infinite:
return None
return math.ceil(len(self.x) / self.batch_size)

def __getitem__(self, idx):
# Create artificial delay to test multiprocessing
time.sleep(self.delay)

if self.infinite:
idx = idx % math.ceil(len(self.x) / self.batch_size)
# Return x, y for batch idx.
low = idx * self.batch_size
# Cap upper bound at array length; the last batch may be smaller
Expand All @@ -48,7 +61,8 @@ def __init__(self, inputs, batch_size=32, **kwargs):
self.inputs = inputs
self.batch_size = batch_size

def __len__(self):
@property
def num_batches(self):
return math.ceil(len(self.inputs["x"]) / self.batch_size)

def __getitem__(self, idx):
Expand Down Expand Up @@ -98,6 +112,7 @@ class PyDatasetAdapterTest(testing.TestCase, parameterized.TestCase):
"dataset_type": "torch",
},
],
infinite=[True, False],
iterator_type=["np", "tf", "jax", "torch"],
shuffle=[True, False],
)
Expand All @@ -106,6 +121,7 @@ def test_basic_flow(
self,
shuffle,
dataset_type,
infinite,
iterator_type,
workers=0,
use_multiprocessing=False,
Expand All @@ -127,6 +143,7 @@ def test_basic_flow(
workers=workers,
use_multiprocessing=use_multiprocessing,
max_queue_size=max_queue_size,
infinite=infinite,
)
adapter = py_dataset_adapter.PyDatasetAdapter(
py_dataset, shuffle=shuffle
Expand Down Expand Up @@ -157,10 +174,16 @@ def test_basic_flow(
self.assertEqual(by.shape, (16, 2))
for i in range(by.shape[0]):
sample_order.append(by[i, 0])
if shuffle:
self.assertNotAllClose(sample_order, list(range(64)))
if infinite and len(sample_order) >= 128:
break
expected_order = list(range(64))
if infinite:
# When the dataset is infinite, we cycle through the data twice.
expected_order = expected_order + expected_order
if shuffle and not infinite:
self.assertNotAllClose(sample_order, expected_order)
else:
self.assertAllClose(sample_order, list(range(64)))
self.assertAllClose(sample_order, expected_order)

# TODO: test class_weight
# TODO: test sample weights
Expand Down Expand Up @@ -240,7 +263,8 @@ def test_dict_inputs(self):
def test_with_different_shapes(self, iterator_type):

class TestPyDataset(py_dataset_adapter.PyDataset):
def __len__(self):
@property
def num_batches(self):
return 3

def __getitem__(self, idx):
Expand Down Expand Up @@ -284,3 +308,64 @@ def __getitem__(self, idx):
else:
self.assertEqual(bx.shape, (2, 6))
self.assertEqual(by.shape, (2, 2))

@parameterized.named_parameters(
    named_product(
        [
            {
                "testcase_name": "multiprocessing",
                "workers": 2,
                "use_multiprocessing": True,
                "max_queue_size": 10,
            },
            {
                "testcase_name": "multithreading",
                "workers": 2,
                "max_queue_size": 10,
            },
            {
                "testcase_name": "single",
            },
        ],
        iterator_type=["np", "tf", "jax", "torch"],
    )
)
def test_exception_reported(
    self,
    iterator_type,
    workers=0,
    use_multiprocessing=False,
    max_queue_size=0,
):
    """Check that an exception raised in `__getitem__` reaches the caller.

    The dataset serves two valid batches and then raises. The third
    `next()` must surface that error for every iterator type and worker
    configuration.
    """

    class ExceptionPyDataset(py_dataset_adapter.PyDataset):

        @property
        def num_batches(self):
            return 4

        def __getitem__(self, index):
            if index < 2:
                return (
                    np.random.random((64, 4)).astype("float32"),
                    np.random.random((64, 2)).astype("float32"),
                )
            raise ValueError("Excepted exception")

    # Forward the worker settings; without this every parameterization
    # would exercise only the single-threaded code path.
    dataset = ExceptionPyDataset(
        workers=workers,
        use_multiprocessing=use_multiprocessing,
        max_queue_size=max_queue_size,
    )
    adapter = py_dataset_adapter.PyDatasetAdapter(dataset, shuffle=False)

    expected_exception_class = ValueError
    if iterator_type == "np":
        it = adapter.get_numpy_iterator()
    elif iterator_type == "tf":
        it = adapter.get_tf_dataset()
        # tf.data wraps the exception
        expected_exception_class = tf.errors.InvalidArgumentError
    elif iterator_type == "jax":
        it = adapter.get_jax_iterator()
    elif iterator_type == "torch":
        it = adapter.get_torch_dataloader()

    it = iter(it)
    next(it)
    next(it)
    with self.assertRaisesRegex(
        expected_exception_class, "Excepted exception"
    ):
        next(it)

0 comments on commit ff2ba05

Please sign in to comment.