Skip to content

Commit

Permalink
Merge branch 'master' of github.com:keras-team/keras
Browse files Browse the repository at this point in the history
  • Loading branch information
fchollet committed Jun 25, 2024
2 parents c03e7b0 + f5d3087 commit 208e70d
Show file tree
Hide file tree
Showing 2 changed files with 116 additions and 365 deletions.
119 changes: 55 additions & 64 deletions keras/src/utils/dataset_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import numpy as np

from keras.src import tree
from keras.src.api_export import keras_export
from keras.src.utils import io_utils
from keras.src.utils.module_utils import tensorflow as tf
Expand Down Expand Up @@ -137,16 +138,7 @@ def _convert_dataset_to_list(
data_size_warning_flag,
start_time,
):
if dataset_type_spec in [tuple, list]:
# The try-except here is for NumPy 1.24 compatibility, see:
# https://numpy.org/neps/nep-0034-infer-dtype-is-object.html
try:
arr = np.array(sample)
except ValueError:
arr = np.array(sample, dtype=object)
dataset_as_list.append(arr)
else:
dataset_as_list.append(sample)
dataset_as_list.append(sample)

return dataset_as_list

Expand All @@ -169,23 +161,23 @@ def _get_data_iterator_from_dataset(dataset, dataset_type_spec):
"Please provide a non-empty list of arrays."
)

if _get_type_spec(dataset[0]) is np.ndarray:
expected_shape = dataset[0].shape
for i, element in enumerate(dataset):
if np.array(element).shape[0] != expected_shape[0]:
raise ValueError(
"Received a list of NumPy arrays with different "
f"lengths. Mismatch found at index {i}, "
f"Expected shape={expected_shape} "
f"Received shape={np.array(element).shape}."
"Please provide a list of NumPy arrays with "
"the same length."
)
else:
raise ValueError(
"Expected a list of `numpy.ndarray` objects,"
f"Received: {type(dataset[0])}"
)
expected_shape = None
for i, element in enumerate(dataset):
if not isinstance(element, np.ndarray):
raise ValueError(
"Expected a list of `numpy.ndarray` objects,"
f"Received: {type(element)} at index {i}."
)
if expected_shape is None:
expected_shape = element.shape
elif element.shape[0] != expected_shape[0]:
raise ValueError(
"Received a list of NumPy arrays with different lengths."
f"Mismatch found at index {i}, "
f"Expected shape={expected_shape} "
f"Received shape={np.array(element).shape}."
"Please provide a list of NumPy arrays of the same length."
)

return iter(zip(*dataset))
elif dataset_type_spec == tuple:
Expand All @@ -195,23 +187,23 @@ def _get_data_iterator_from_dataset(dataset, dataset_type_spec):
"Please provide a non-empty tuple of arrays."
)

if _get_type_spec(dataset[0]) is np.ndarray:
expected_shape = dataset[0].shape
for i, element in enumerate(dataset):
if np.array(element).shape[0] != expected_shape[0]:
raise ValueError(
"Received a tuple of NumPy arrays with different "
f"lengths. Mismatch found at index {i}, "
f"Expected shape={expected_shape} "
f"Received shape={np.array(element).shape}."
"Please provide a tuple of NumPy arrays with "
"the same length."
)
else:
raise ValueError(
"Expected a tuple of `numpy.ndarray` objects, "
f"Received: {type(dataset[0])}"
)
expected_shape = None
for i, element in enumerate(dataset):
if not isinstance(element, np.ndarray):
raise ValueError(
"Expected a tuple of `numpy.ndarray` objects,"
f"Received: {type(element)} at index {i}."
)
if expected_shape is None:
expected_shape = element.shape
elif element.shape[0] != expected_shape[0]:
raise ValueError(
"Received a tuple of NumPy arrays with different lengths."
f"Mismatch found at index {i}, "
f"Expected shape={expected_shape} "
f"Received shape={np.array(element).shape}."
"Please provide a tuple of NumPy arrays of the same length."
)

return iter(zip(*dataset))
elif dataset_type_spec == tf.data.Dataset:
Expand Down Expand Up @@ -436,23 +428,24 @@ def _restore_dataset_from_list(
dataset_as_list, dataset_type_spec, original_dataset
):
"""Restore the dataset from the list of arrays."""
if dataset_type_spec in [tuple, list]:
return tuple(np.array(sample) for sample in zip(*dataset_as_list))
elif dataset_type_spec == tf.data.Dataset:
if isinstance(original_dataset.element_spec, dict):
restored_dataset = {}
for d in dataset_as_list:
for k, v in d.items():
if k not in restored_dataset:
restored_dataset[k] = [v]
else:
restored_dataset[k].append(v)
return restored_dataset
else:
return tuple(np.array(sample) for sample in zip(*dataset_as_list))
if dataset_type_spec in [tuple, list, tf.data.Dataset] or is_torch_dataset(
original_dataset
):
# Save structure by taking the first element.
element_spec = dataset_as_list[0]
# Flatten each element.
dataset_as_list = [tree.flatten(sample) for sample in dataset_as_list]
# Combine respective elements at all indices.
dataset_as_list = [np.array(sample) for sample in zip(*dataset_as_list)]
# Recreate the original structure of elements.
dataset_as_list = tree.pack_sequence_as(element_spec, dataset_as_list)
# Turn lists to tuples as tf.data will fail on lists.
return tree.traverse(
lambda x: tuple(x) if isinstance(x, list) else x,
dataset_as_list,
top_down=False,
)

elif is_torch_dataset(original_dataset):
return tuple(np.array(sample) for sample in zip(*dataset_as_list))
return dataset_as_list


Expand All @@ -477,14 +470,12 @@ def _get_type_spec(dataset):
return list
elif isinstance(dataset, np.ndarray):
return np.ndarray
elif isinstance(dataset, dict):
return dict
elif isinstance(dataset, tf.data.Dataset):
return tf.data.Dataset
elif is_torch_dataset(dataset):
from torch.utils.data import Dataset as torchDataset
from torch.utils.data import Dataset as TorchDataset

return torchDataset
return TorchDataset
else:
return None

Expand Down
Loading

0 comments on commit 208e70d

Please sign in to comment.