# using tf.data for fit method in DeepEnsemble model #890
```diff
@@ -14,8 +14,9 @@
 from __future__ import annotations

+import copy
 import re
-from typing import Any, Dict, Mapping, Optional
+from typing import Any, Dict, Mapping, Optional, Union

 import dill
 import tensorflow as tf
```
```diff
@@ -377,6 +378,62 @@ def update_encoded(self, dataset: Dataset) -> None:
         """
         return

+    def prepare_tf_data(
+        self,
+        x: dict[str, TensorType],
+        y: dict[str, TensorType],
+        batch_size: int,
+        num_points: int,
+        validation_split: float = 0.0,
+    ) -> Union[tf.data.Dataset, tuple[tf.data.Dataset, tf.data.Dataset]]:
+        """
+        Prepare data for optimization as a `tf.data.Dataset`. This method allows the user
+        more control over the data pipeline, e.g. shuffling, batching, prefetching,
+        repeating, etc.
+
+        :param x: Dictionary of input tensors.
+        :param y: Dictionary of output tensors.
+        :param batch_size: Batch size for the dataset.
+        :param num_points: Number of data points.
+        :param validation_split: Float between 0 and 1, fraction of data to use for validation.
+        :return: If ``validation_split`` is 0, returns a single dataset for training.
+            If ``validation_split`` > 0, returns a tuple of
+            ``(training_dataset, validation_dataset)``.
+        """
+        if not 0.0 <= validation_split < 1.0:
+            raise ValueError("validation_split must be between 0 and 1")
+
+        dataset = tf.data.Dataset.from_tensor_slices((x, y))
+
+        if validation_split > 0:
+            # Calculate split sizes
+            val_size = int(num_points * validation_split)
+            train_size = num_points - val_size
+
+            # Shuffle before splitting to ensure randomness
+            dataset = dataset.shuffle(num_points, reshuffle_each_iteration=True)
+
+            # Split into train and validation
+            train_dataset = dataset.take(train_size)
+            val_dataset = dataset.skip(train_size)
+
+            # Prepare training dataset
+            train_dataset = (
+                train_dataset.prefetch(tf.data.AUTOTUNE)
+                .shuffle(train_size, reshuffle_each_iteration=True)
+                .batch(batch_size, drop_remainder=True)
+            )
+
+            # Prepare validation dataset
+            val_dataset = val_dataset.prefetch(tf.data.AUTOTUNE)
+
+            return train_dataset, val_dataset
+        else:
+            # Original behavior when no validation split is requested
+            return (
+                dataset.prefetch(tf.data.AUTOTUNE)
+                .shuffle(train_size, reshuffle_each_iteration=True)
+                .batch(batch_size, drop_remainder=True)
+            )
+
     def optimize_encoded(self, dataset: Dataset) -> tf_keras.callbacks.History:
         """
         Optimize the underlying Keras ensemble model with the specified ``dataset``.
```

> **Review comment** (on `# Original behavior when no validation split is requested`): Q: is this really the same as the original behaviour?

> **Review comment** (on `.shuffle(train_size, ...)` in the `else` branch): (I think?) `train_size` is only defined in the `validation_split > 0` branch, so presumably this should be `num_points`.
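For intuition, here is a minimal standalone sketch of the shuffle/take/skip pattern that `prepare_tf_data` builds on, using toy tensors (the dictionary keys and sizes are illustrative, not from the PR):

```python
import tensorflow as tf

num_points, batch_size, validation_split = 10, 4, 0.2
x = {"inputs": tf.random.uniform((num_points, 2))}
y = {"outputs": tf.random.uniform((num_points, 1))}

dataset = tf.data.Dataset.from_tensor_slices((x, y))

val_size = int(num_points * validation_split)  # 2 points held out
train_size = num_points - val_size             # 8 points for training

# Shuffle once with a fixed order, then split positionally: the first
# train_size elements go to training, the remainder to validation.
dataset = dataset.shuffle(num_points, seed=0, reshuffle_each_iteration=False)
train_dataset = dataset.take(train_size).batch(batch_size, drop_remainder=True)
val_dataset = dataset.skip(train_size).batch(val_size)

for batch_x, batch_y in train_dataset:
    print(batch_x["inputs"].shape)  # (4, 2)
```

One subtlety of this pattern: `take`/`skip` split by position in the stream, so if the upstream `shuffle` uses `reshuffle_each_iteration=True` (as the PR's code does), the two subsets are drawn from a different permutation on every pass over the data.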
```diff
@@ -393,20 +450,38 @@ def optimize_encoded(self, dataset: Dataset) -> tf_keras.callbacks.History:

         :param dataset: The data with which to optimize the model.
         """
-        fit_args = dict(self.optimizer.fit_args)
+        fit_args_copy = copy.deepcopy(dict(self.optimizer.fit_args))

         # Tell optimizer how many epochs have been used before: the optimizer will "continue"
         # optimization across multiple BO iterations rather than start fresh at each iteration.
         # This allows us to monitor training across iterations.

-        if "epochs" in fit_args:
-            fit_args["epochs"] = fit_args["epochs"] + self._absolute_epochs
+        if "epochs" in fit_args_copy:
+            fit_args_copy["epochs"] = fit_args_copy["epochs"] + self._absolute_epochs

         x, y = self.prepare_dataset(dataset)

+        validation_split = fit_args_copy.pop("validation_split", 0.0)
+        tf_data = self.prepare_tf_data(
```

> **Review comment** (suggested change on this line): (if you change the return type above as suggested) the call site would then unpack the pair directly, as `train_dataset, val_dataset = self.prepare_tf_data(`.
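For context, `fit_args` is the keyword dictionary the user hands to the optimizer wrapper, normally forwarded straight to `model.fit`; with this change a `validation_split` entry is popped out and handled by `prepare_tf_data` instead. A hypothetical configuration (values are illustrative):

```python
fit_args = {
    "batch_size": 16,
    "epochs": 100,
    "verbose": 0,
    "validation_split": 0.1,  # new: consumed by optimize_encoded, no longer passed to fit
}
```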
```diff
+            x,
+            y,
+            batch_size=fit_args_copy["batch_size"],
```

> **Review comment:** `"batch_size"` isn't guaranteed to exist for a user-supplied `fit_args`.
>
> **Reply:** ah, well spotted, I was remembering `BatchOptimizer`...
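One defensive variant (a sketch with an assumed default, not necessarily what the reviewer suggested) would fall back when the key is absent:

```python
# Use the user's batch size if given, otherwise fall back to a default.
# The default of 32 is an assumption, not something specified in the PR.
batch_size = fit_args_copy.pop("batch_size", 32)
```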
```diff
+            num_points=dataset.observations.shape[0],
+            validation_split=validation_split,
+        )
+        fit_args_copy["batch_size"] = None  # batching is done in prepare_tf_data
+
+        if validation_split > 0:
+            train_dataset, val_dataset = tf_data
+            fit_args_copy["validation_data"] = val_dataset
```

> **Review comment:** should we maybe raise an exception if `"validation_data"` is already present in the `fit_args`?
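Such a guard might look like the following (a sketch; the exception type and message are assumptions):

```python
# Fail loudly instead of silently overwriting a user-supplied validation set.
if "validation_data" in fit_args_copy:
    raise ValueError(
        "fit_args must not contain validation_data when validation_split is used"
    )
fit_args_copy["validation_data"] = val_dataset
```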
```diff
+            history = self.model.fit(
+                train_dataset, **fit_args_copy, initial_epoch=self._absolute_epochs
+            )
+        else:
+            history = self.model.fit(tf_data, **fit_args_copy, initial_epoch=self._absolute_epochs)
-        history = self.model.fit(
-            x=x,
-            y=y,
-            **fit_args,
-            initial_epoch=self._absolute_epochs,
-        )
         if self._continuous_optimisation:
```

> **Review comment** (on lines +476 to +480), with a suggested change that appears to collapse the two branches into a single call:
>
> ```python
> history = self.model.fit(
>     tf_data,
>     **fit_args_copy,
>     initial_epoch=self._absolute_epochs,
> )
> ```
> **Review comment** (on the `prepare_tf_data` return type): might be nicer to always return a tuple?
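One way to act on that suggestion (a standalone sketch, not the PR's code: the None-for-no-validation convention, the fixed shuffle seed, the batched validation set, and the free-function form are all assumptions):

```python
from typing import Optional, Tuple

import tensorflow as tf


def prepare_tf_data(
    x: dict, y: dict, batch_size: int, num_points: int, validation_split: float = 0.0
) -> Tuple[tf.data.Dataset, Optional[tf.data.Dataset]]:
    """Always return a (train, validation) pair; validation is None when there is no split."""
    if not 0.0 <= validation_split < 1.0:
        raise ValueError("validation_split must be between 0 and 1")

    dataset = tf.data.Dataset.from_tensor_slices((x, y))

    if validation_split > 0:
        val_size = int(num_points * validation_split)
        train_size = num_points - val_size
        # Fix the permutation so train and validation stay disjoint across epochs.
        dataset = dataset.shuffle(num_points, seed=0, reshuffle_each_iteration=False)
        train = dataset.take(train_size)
        # Batch the held-out points so Keras can consume them directly.
        val = dataset.skip(train_size).batch(val_size).prefetch(tf.data.AUTOTUNE)
    else:
        train_size = num_points
        train, val = dataset, None

    train = (
        train.shuffle(train_size, reshuffle_each_iteration=True)
        .batch(batch_size, drop_remainder=True)
        .prefetch(tf.data.AUTOTUNE)
    )
    return train, val
```

The call site then unpacks unconditionally, which also removes the `Union` return type that the earlier comment reacted to:

```python
train_dataset, val_dataset = prepare_tf_data(x, y, batch_size=16, num_points=100)
if val_dataset is not None:
    fit_args_copy["validation_data"] = val_dataset
history = model.fit(train_dataset, **fit_args_copy)
```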