using tf.data for fit method in DeepEnsemble model #890

Closed
wants to merge 2 commits into from
121 changes: 109 additions & 12 deletions tests/unit/models/keras/test_models.py
@@ -560,7 +560,7 @@ def test_deep_ensemble_prepare_data_call(


def test_deep_ensemble_deep_copyable() -> None:
example_data = _get_example_data([10, 3], [10, 3])
example_data = _get_example_data([160, 3], [160, 3])
model, _, _ = trieste_deep_ensemble_model(example_data, 2, False, False)
model_copy = copy.deepcopy(model)

@@ -570,7 +570,7 @@ def test_deep_ensemble_deep_copyable() -> None:
npt.assert_allclose(variance_f, variance_f_copy)

# check that updating the original doesn't break or change the deepcopy
new_example_data = _get_example_data([20, 3], [20, 3])
new_example_data = _get_example_data([320, 3], [320, 3])
model.update(new_example_data)
model.optimize(new_example_data)

@@ -582,7 +582,7 @@ def test_deep_ensemble_deep_copyable() -> None:
npt.assert_array_compare(operator.__ne__, variance_f_updated, variance_f)

# check that we can also update the copy
newer_example_data = _get_example_data([30, 3], [30, 3])
newer_example_data = _get_example_data([640, 3], [640, 3])
model_copy.update(newer_example_data)
model_copy.optimize(newer_example_data)

@@ -595,7 +595,7 @@ def test_deep_ensemble_deep_copyable() -> None:


def test_deep_ensemble_tf_saved_model() -> None:
example_data = _get_example_data([10, 3], [10, 3])
example_data = _get_example_data([16, 3], [16, 3])
model, _, _ = trieste_deep_ensemble_model(example_data, 2, False, False)

with tempfile.TemporaryDirectory() as path:
@@ -638,9 +638,9 @@ def _sample(query_points: TensorType, num_samples: int) -> TensorType:


def test_deep_ensemble_deep_copies_optimizer_state() -> None:
example_data = _get_example_data([10, 3], [10, 3])
example_data = _get_example_data([100, 3], [100, 3])
model, _, _ = trieste_deep_ensemble_model(example_data, 2, False, False)
new_example_data = _get_example_data([20, 3], [20, 3])
new_example_data = _get_example_data([120, 3], [120, 3])
model.update(new_example_data)
assert not keras_optimizer_weights(model.model.optimizer)
model.optimize(new_example_data)
@@ -681,11 +681,11 @@ def test_deep_ensemble_deep_copies_optimizer_state() -> None:
],
)
def test_deep_ensemble_deep_copies_different_callback_types(callbacks: list[Callback]) -> None:
example_data = _get_example_data([10, 3], [10, 3])
example_data = _get_example_data([160, 3], [160, 3])
model, _, _ = trieste_deep_ensemble_model(example_data, 2, False, False)
model.optimizer.fit_args["callbacks"] = callbacks

new_example_data = _get_example_data([20, 3], [20, 3])
new_example_data = _get_example_data([320, 3], [320, 3])
model.update(new_example_data)
model.optimize(new_example_data)

@@ -697,7 +697,7 @@ def test_deep_ensemble_deep_copies_different_callback_types(callbacks: list[Call


def test_deep_ensemble_deep_copies_optimizer_callback_models() -> None:
example_data = _get_example_data([10, 3], [10, 3])
example_data = _get_example_data([16, 3], [16, 3])
keras_ensemble = trieste_keras_ensemble_model(example_data, _ENSEMBLE_SIZE, False)
model = DeepEnsemble(keras_ensemble)
new_example_data = _get_example_data([20, 3], [20, 3])
@@ -716,7 +716,7 @@ def test_deep_ensemble_deep_copies_optimizer_callback_models() -> None:


def test_deep_ensemble_deep_copies_optimizer_without_callbacks() -> None:
example_data = _get_example_data([10, 3], [10, 3])
example_data = _get_example_data([16, 3], [16, 3])
keras_ensemble = trieste_keras_ensemble_model(example_data, _ENSEMBLE_SIZE, False)
model = DeepEnsemble(keras_ensemble)
del model.optimizer.fit_args["callbacks"]
@@ -727,7 +727,7 @@ def test_deep_ensemble_deep_copies_optimizer_without_callbacks() -> None:


def test_deep_ensemble_deep_copies_optimization_history() -> None:
example_data = _get_example_data([10, 3], [10, 3])
example_data = _get_example_data([16, 3], [16, 3])
keras_ensemble = trieste_keras_ensemble_model(example_data, _ENSEMBLE_SIZE, False)
model = DeepEnsemble(keras_ensemble)
model.optimize(example_data)
@@ -752,7 +752,7 @@ def test_deep_ensemble_log(
mocked_summary_histogram: unittest.mock.MagicMock,
use_dataset: bool,
) -> None:
example_data = _get_example_data([10, 3], [10, 3])
example_data = _get_example_data([16, 3], [16, 3])
keras_ensemble = trieste_keras_ensemble_model(example_data, _ENSEMBLE_SIZE, False)
model = DeepEnsemble(keras_ensemble)
model.optimize(example_data)
@@ -777,3 +777,100 @@ def test_deep_ensemble_log(

assert mocked_summary_scalar.call_count == num_scalars
assert mocked_summary_histogram.call_count == num_histogram


def test_deep_ensemble_prepare_tf_data_returns_dataset() -> None:
example_data = _get_example_data([100, 1])
model, _, _ = trieste_deep_ensemble_model(example_data, _ENSEMBLE_SIZE, False, False)

x, y = model.prepare_dataset(example_data)
dataset = model.prepare_tf_data(x, y, batch_size=10, num_points=100)

assert isinstance(dataset, tf.data.Dataset)


def test_deep_ensemble_prepare_tf_data_batch_size() -> None:
example_data = _get_example_data([100, 1])
model, _, _ = trieste_deep_ensemble_model(example_data, _ENSEMBLE_SIZE, False, False)

x, y = model.prepare_dataset(example_data)
batch_size = 10
dataset = model.prepare_tf_data(x, y, batch_size=batch_size, num_points=100)

for batch_x, batch_y in dataset:
for key in batch_x:
assert batch_x[key].shape[0] == batch_size
for key in batch_y:
assert batch_y[key].shape[0] == batch_size


@random_seed
def test_deep_ensemble_prepare_tf_data_shuffling() -> None:
example_data = _get_example_data([100, 1])
model, _, _ = trieste_deep_ensemble_model(example_data, _ENSEMBLE_SIZE, False, False)

x, y = model.prepare_dataset(example_data)
dataset = model.prepare_tf_data(x, y, batch_size=10, num_points=100)

# Get first batch from two iterations
first_iter = next(iter(dataset))
second_iter = next(iter(dataset))

# Check that the batches are different (shuffled)
for key in first_iter[0]:
assert not tf.reduce_all(first_iter[0][key] == second_iter[0][key])


@pytest.mark.parametrize("input_dim", [1, 3, 5])
def test_deep_ensemble_prepare_tf_data_shapes_and_types(num_outputs: int, input_dim: int) -> None:
example_data = _get_example_data([100, input_dim], [100, num_outputs])
model, _, _ = trieste_deep_ensemble_model(example_data, _ENSEMBLE_SIZE, False, False)

x, y = model.prepare_dataset(example_data)
dataset = model.prepare_tf_data(x, y, batch_size=10, num_points=100)

# Check shapes and types of the dataset elements
for batch_x, batch_y in dataset:
# Check input shapes and types
for key in batch_x:
assert batch_x[key].shape[-1] == input_dim
assert batch_x[key].dtype == example_data.query_points.dtype

# Check output shapes and types
for key in batch_y:
assert batch_y[key].shape[-1] == num_outputs
assert batch_y[key].dtype == example_data.observations.dtype


def test_deep_ensemble_prepare_tf_data_validation_split() -> None:
example_data = _get_example_data([100, 1])
model, _, _ = trieste_deep_ensemble_model(example_data, _ENSEMBLE_SIZE, False, False)

x, y = model.prepare_dataset(example_data)
validation_split = 0.2
train_dataset, val_dataset = model.prepare_tf_data(
x, y, batch_size=10, num_points=100, validation_split=validation_split
)

# Check that we get two datasets
assert isinstance(train_dataset, tf.data.Dataset)
assert isinstance(val_dataset, tf.data.Dataset)

# Check approximate sizes (might be off by 1 due to rounding)
train_size = sum(1 for _ in train_dataset)
val_size = sum(1 for _ in val_dataset)
assert abs(train_size * 10 - 80) <= 1 # 80% of 100 = 80 samples
assert abs(val_size * 10 - 20) <= 1 # 20% of 100 = 20 samples


def test_deep_ensemble_prepare_tf_data_invalid_validation_split() -> None:
example_data = _get_example_data([100, 1])
model, _, _ = trieste_deep_ensemble_model(example_data, _ENSEMBLE_SIZE, False, False)

x, y = model.prepare_dataset(example_data)

with pytest.raises(ValueError):
model.prepare_tf_data(x, y, batch_size=10, num_points=100, validation_split=1.0)

with pytest.raises(ValueError):
model.prepare_tf_data(x, y, batch_size=10, num_points=100, validation_split=-0.1)
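For reference, a minimal sketch of the workflow the new tests above exercise, written in the context of this test module (it assumes the existing helpers _get_example_data, trieste_deep_ensemble_model and _ENSEMBLE_SIZE, plus the module's tensorflow import):

example_data = _get_example_data([100, 1])
model, _, _ = trieste_deep_ensemble_model(example_data, _ENSEMBLE_SIZE, False, False)

# prepare_dataset returns the per-network input/output dicts; prepare_tf_data
# wraps them in a shuffled, batched tf.data.Dataset.
x, y = model.prepare_dataset(example_data)
dataset = model.prepare_tf_data(x, y, batch_size=10, num_points=100)
assert isinstance(dataset, tf.data.Dataset)

# With a validation split, a (train, validation) pair is returned instead.
train_dataset, val_dataset = model.prepare_tf_data(
    x, y, batch_size=10, num_points=100, validation_split=0.2
)
for batch_x, batch_y in train_dataset:
    ...  # each tensor in batch_x / batch_y has leading dimension batch_size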
91 changes: 83 additions & 8 deletions trieste/models/keras/models.py
@@ -14,8 +14,9 @@

from __future__ import annotations

import copy
import re
from typing import Any, Dict, Mapping, Optional
from typing import Any, Dict, Mapping, Optional, Union

import dill
import tensorflow as tf
@@ -377,6 +378,62 @@ def update_encoded(self, dataset: Dataset) -> None:
"""
return

def prepare_tf_data(
self,
x: dict[str, TensorType],
y: dict[str, TensorType],
batch_size: int,
num_points: int,
validation_split: float = 0.0,
) -> Union[tf.data.Dataset, tuple[tf.data.Dataset, tf.data.Dataset]]:
Collaborator
might be nicer to always return a tuple?

Suggested change
) -> Union[tf.data.Dataset, tuple[tf.data.Dataset, tf.data.Dataset]]:
) -> tuple[tf.data.Dataset, Optional[tf.data.Dataset]]:
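With a two-element return like that, callers could always unpack the result; a minimal caller-side sketch (assuming the second element is None when no validation split is requested, and illustrative batch_size/num_points values):

train_dataset, val_dataset = self.prepare_tf_data(x, y, batch_size=32, num_points=100)
if val_dataset is not None:
    fit_args_copy["validation_data"] = val_dataset  # as optimize_encoded does below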

"""
Prepare data for optimization as a `tf.data.Dataset`. This method gives the user more control
over the data pipeline, e.g. shuffling, batching, prefetching, repeating, etc.

:param x: Dictionary of input tensors
:param y: Dictionary of output tensors
:param batch_size: Batch size for the dataset
:param num_points: Number of data points
:param validation_split: Float between 0 and 1, fraction of data to use for validation
:return: If validation_split is 0, returns a single dataset for training.
If validation_split > 0, returns a tuple of (training_dataset, validation_dataset)
"""
if not 0.0 <= validation_split < 1.0:
raise ValueError("validation_split must be between 0 and 1")
Collaborator
Suggested change
raise ValueError("validation_split must be between 0 and 1")
raise ValueError(f"validation_split must be between 0 and 1: got {validation_split}")


dataset = tf.data.Dataset.from_tensor_slices((x, y))

if validation_split > 0:
# Calculate split sizes
val_size = int(num_points * validation_split)
Collaborator
Suggested change
val_size = int(num_points * validation_split)
val_size = round(num_points * validation_split)
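The two differ only through truncation; an illustration with hypothetical numbers:

num_points, validation_split = 30, 0.29
int(num_points * validation_split)    # 8 - truncates 8.7 down, always shrinking the validation set
round(num_points * validation_split)  # 9 - nearest whole number of validation points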

train_size = num_points - val_size

# Shuffle before splitting to ensure randomness
dataset = dataset.shuffle(num_points, reshuffle_each_iteration=True)

# Split into train and validation
train_dataset = dataset.take(train_size)
val_dataset = dataset.skip(train_size)

# Prepare training dataset
train_dataset = (
train_dataset.prefetch(tf.data.AUTOTUNE)
.shuffle(train_size, reshuffle_each_iteration=True)
.batch(batch_size, drop_remainder=True)
)

# Prepare validation dataset
val_dataset = val_dataset.prefetch(tf.data.AUTOTUNE)

return train_dataset, val_dataset
else:
# Original behavior when no validation split is requested
Collaborator
Q: is this really the same as the original behaviour?

return (
dataset.prefetch(tf.data.AUTOTUNE)
.shuffle(train_size, reshuffle_each_iteration=True)
Collaborator
(I think?)

Suggested change
.shuffle(train_size, reshuffle_each_iteration=True)
.shuffle(num_points, reshuffle_each_iteration=True)
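Note that train_size is only assigned inside the validation_split > 0 branch, so referencing it here in the else branch would raise an UnboundLocalError at runtime; using num_points also gives the shuffle buffer the full dataset size, which tf.data needs for a complete shuffle.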

.batch(batch_size, drop_remainder=True)
)

def optimize_encoded(self, dataset: Dataset) -> tf_keras.callbacks.History:
"""
Optimize the underlying Keras ensemble model with the specified ``dataset``.
@@ -393,20 +450,38 @@ def optimize_encoded(self, dataset: Dataset) -> tf_keras.callbacks.History:

:param dataset: The data with which to optimize the model.
"""
fit_args = dict(self.optimizer.fit_args)
fit_args_copy = copy.deepcopy(dict(self.optimizer.fit_args))

# Tell optimizer how many epochs have been used before: the optimizer will "continue"
# optimization across multiple BO iterations rather than start fresh at each iteration.
# This allows us to monitor training across iterations.

if "epochs" in fit_args:
fit_args["epochs"] = fit_args["epochs"] + self._absolute_epochs
if "epochs" in fit_args_copy:
fit_args_copy["epochs"] = fit_args_copy["epochs"] + self._absolute_epochs

x, y = self.prepare_dataset(dataset)

validation_split = fit_args_copy.pop("validation_split", 0.0)
tf_data = self.prepare_tf_data(
Collaborator
(if you change the return type above as suggested)

Suggested change
tf_data = self.prepare_tf_data(
train_dataset, val_dataset = self.prepare_tf_data(

x,
y,
batch_size=fit_args_copy["batch_size"],
Collaborator
"batch_size" isn't guaranteed to exist for a user-supplied fit_args

Suggested change
batch_size=fit_args_copy["batch_size"],
batch_size=fit_args_copy.get("batch_size"),

Collaborator Author
ah, well spotted, I was remembering BatchOptimizer...
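One possible resolution, sketched under the assumption that falling back to 32 (the Keras fit() default for array inputs) is acceptable, since a None batch size cannot be passed on to dataset.batch():

batch_size=fit_args_copy.get("batch_size", 32),  # hypothetical fallback, not part of this PR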

num_points=dataset.observations.shape[0],
validation_split=validation_split,
)
fit_args_copy["batch_size"] = None # batching is done in prepare_tf_data

if validation_split > 0:
train_dataset, val_dataset = tf_data
fit_args_copy["validation_data"] = val_dataset
Collaborator
should we maybe raise an exception if "train_dataset, val_dataset" is already present in the fit_args?
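A minimal guard along those lines, assuming the key in question is Keras's validation_data argument (a sketch, not part of this PR):

if "validation_data" in fit_args_copy:
    raise ValueError(
        "validation_split cannot be combined with validation_data in fit_args"
    )
fit_args_copy["validation_data"] = val_dataset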

history = self.model.fit(
train_dataset, **fit_args_copy, initial_epoch=self._absolute_epochs
)
else:
history = self.model.fit(tf_data, **fit_args_copy, initial_epoch=self._absolute_epochs)
Comment on lines +476 to +480
Collaborator
Suggested change
history = self.model.fit(
train_dataset, **fit_args_copy, initial_epoch=self._absolute_epochs
)
else:
history = self.model.fit(tf_data, **fit_args_copy, initial_epoch=self._absolute_epochs)
history = self.model.fit(tf_data, **fit_args_copy, initial_epoch=self._absolute_epochs)


history = self.model.fit(
x=x,
y=y,
**fit_args,
tf_data,
**fit_args_copy,
initial_epoch=self._absolute_epochs,
)
if self._continuous_optimisation: