
using tf.data for fit method in DeepEnsemble model #890

Closed
wants to merge 2 commits

Conversation

@hstojic (Collaborator) commented Jan 17, 2025

this simple change should improve memory handling, should be better optimized for GPUs, and generally gives the user more control over preparing data for training
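
For context, a minimal sketch of the idea, building a tf.data pipeline and handing it to fit; prepare_tf_data matches the name in the diff below, but the exact signature here is an assumption:

import tensorflow as tf

def prepare_tf_data(x, y, batch_size, num_points):
    # Build a tf.data pipeline from in-memory tensors: tf.data handles
    # shuffling, batching and prefetching itself, keeping memory bounded
    # and overlapping host-side data preparation with device-side training.
    dataset = tf.data.Dataset.from_tensor_slices((x, y))
    dataset = dataset.shuffle(num_points, reshuffle_each_iteration=True)
    return dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE)

The resulting dataset is then passed directly to model.fit in place of raw arrays.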

@hstojic requested review from uri-granta and avullo on January 17, 2025 14:19
@uri-granta (Collaborator) left a comment

Various comments below. Happy to review again (including the tests) once the tests are passing.

batch_size: int,
num_points: int,
validation_split: float = 0.0,
) -> Union[tf.data.Dataset, tuple[tf.data.Dataset, tf.data.Dataset]]:

might be nicer to always return a tuple?

Suggested change
) -> Union[tf.data.Dataset, tuple[tf.data.Dataset, tf.data.Dataset]]:
) -> tuple[tf.data.Dataset, Optional[tf.data.Dataset]]]:

If validation_split > 0, returns a tuple of (training_dataset, validation_dataset)
"""
if not 0.0 <= validation_split < 1.0:
raise ValueError("validation_split must be between 0 and 1")

Suggested change
raise ValueError("validation_split must be between 0 and 1")
raise ValueError(f"validation_split must be between 0 and 1: got {validation_split}")


if validation_split > 0:
# Calculate split sizes
val_size = int(num_points * validation_split)

Suggested change
val_size = int(num_points * validation_split)
val_size = round(num_points * validation_split)
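
The difference matters for small datasets, since int() truncates toward zero while round() picks the nearest integer:

num_points, validation_split = 10, 0.15
int(num_points * validation_split)    # 1 (1.5 truncated down)
round(num_points * validation_split)  # 2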

tf_data = self.prepare_tf_data(
x,
y,
batch_size=fit_args_copy["batch_size"],

"batch_size" isn't guaranteed to exist for a user-supplied fit_args

Suggested change
batch_size=fit_args_copy["batch_size"],
batch_size=fit_args_copy.get("batch_size"),
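
For reference, indexing raises on a missing key while .get returns a default; the fallback of 32 below is simply Keras' own default batch size:

fit_args = {"epochs": 100}                   # user-supplied, no batch_size
# fit_args["batch_size"]                     # would raise KeyError
batch_size = fit_args.get("batch_size")      # None when the key is absent
batch_size = fit_args.get("batch_size", 32)  # or fall back explicitly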

@hstojic (Author) replied:

ah, well spotted, I was remembering BatchOptimizer...


x, y = self.prepare_dataset(dataset)

validation_split = fit_args_copy.pop("validation_split", 0.0)
tf_data = self.prepare_tf_data(

(if you change the return type above as suggested)

Suggested change
tf_data = self.prepare_tf_data(
train_dataset, val_dataset = self.prepare_tf_data(
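
With the Optional return type suggested earlier, the call site stays uniform whatever the split; a sketch of how the surrounding method body could then look (it reuses names from the diff and is not a definitive implementation):

train_dataset, val_dataset = self.prepare_tf_data(
    x, y, batch_size=batch_size, num_points=num_points,
    validation_split=validation_split,
)
if val_dataset is not None:
    fit_args_copy["validation_data"] = val_dataset
history = self.model.fit(
    train_dataset, **fit_args_copy, initial_epoch=self._absolute_epochs
)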


if validation_split > 0:
train_dataset, val_dataset = tf_data
fit_args_copy["validation_data"] = val_dataset

should we maybe raise an exception if "validation_data" is already present in the fit_args?
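
Such a guard could look like this (a sketch; the message wording is illustrative):

if "validation_data" in fit_args_copy:
    raise ValueError(
        "fit_args must not contain 'validation_data'; "
        "request a validation set via validation_split instead"
    )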

Comment on lines +476 to +480
history = self.model.fit(
train_dataset, **fit_args_copy, initial_epoch=self._absolute_epochs
)
else:
history = self.model.fit(tf_data, **fit_args_copy, initial_epoch=self._absolute_epochs)

Suggested change
history = self.model.fit(
train_dataset, **fit_args_copy, initial_epoch=self._absolute_epochs
)
else:
history = self.model.fit(tf_data, **fit_args_copy, initial_epoch=self._absolute_epochs)
history = self.model.fit(tf_data, **fit_args_copy, initial_epoch=self._absolute_epochs)

# Original behavior when no validation split is requested
return (
dataset.prefetch(tf.data.AUTOTUNE)
.shuffle(train_size, reshuffle_each_iteration=True)

(I think?)

Suggested change
.shuffle(train_size, reshuffle_each_iteration=True)
.shuffle(num_points, reshuffle_each_iteration=True)
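
For reference, the argument to shuffle is the buffer size: only that many elements are held for sampling at once, so a buffer at least as large as the data gives a uniform shuffle, and reshuffle_each_iteration=True redraws the order on every epoch. A small illustration:

import tensorflow as tf

ds = tf.data.Dataset.range(6)
# A buffer of 2 only mixes nearby elements; a buffer >= dataset size
# gives a uniform shuffle over the whole dataset.
print(list(ds.shuffle(2, seed=0).as_numpy_iterator()))
print(list(ds.shuffle(6, seed=0).as_numpy_iterator()))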


return train_dataset, val_dataset
else:
# Original behavior when no validation split is requested

Q: is this really the same as the original behaviour?

@uri-granta self-requested a review on January 20, 2025 10:10
@pio-neil

I do have a small concern about the use of tf.data.Dataset.shuffle. When I did some testing with this before, the shuffle buffer (which is tf.data.Dataset's internal mechanism for shuffling data) used around 18 GB of extra memory. This was with a dataset of 30 million rows, with 15 inputs and one output, and a batch size of 1000. The shuffle buffer also has an impact on speed, but I suspect this is relatively minor compared to the model training time.

This might not be such a problem with smaller datasets. So perhaps it would be a good idea to do some benchmarking?

However, it's also not clear to me why we're introducing shuffling by default here, when AFAICT it wasn't there before? This seems like a change of behaviour. Do we expect this to improve model accuracy? It may be better to let the user of Trieste control this, rather than making it the default behaviour?
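
For scale: a full-dataset shuffle buffer materialises every row, so 30M rows × 16 values is already ~3.8 GB at float64 before TensorFlow's own overhead, consistent with the observation above. A common mitigation, sketched here under the assumption that some loss of shuffle uniformity is acceptable, is a bounded buffer:

import tensorflow as tf

def bounded_pipeline(x, y, batch_size, buffer_size=100_000):
    # Hold only buffer_size rows in the shuffle buffer rather than the
    # whole dataset, trading shuffle uniformity for bounded memory.
    ds = tf.data.Dataset.from_tensor_slices((x, y))
    ds = ds.shuffle(buffer_size, reshuffle_each_iteration=True)
    return ds.batch(batch_size).prefetch(tf.data.AUTOTUNE)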

@hstojic (Author) commented Feb 5, 2025

> I do have a small concern about the use of tf.data.Dataset.shuffle. When I did some testing with this before, the shuffle buffer (which is tf.data.Dataset's internal mechanism for shuffling data) used around 18 GB of extra memory. This was with a dataset of 30 million rows, with 15 inputs and one output, and a batch size of 1000. The shuffle buffer also has an impact on speed, but I suspect this is relatively minor compared to the model training time.

> This might not be such a problem with smaller datasets. So perhaps it would be a good idea to do some benchmarking?

I have already done some testing, on much smaller data than that, and haven't seen any adverse effects.
Users who handle bigger datasets can now do the data preparation themselves by subclassing the model and overriding the new method, if there are any speed/memory issues.

> However, it's also not clear to me why we're introducing shuffling by default here, when AFAICT it wasn't there before? This seems like a change of behaviour. Do we expect this to improve model accuracy? It may be better to let the user of Trieste control this, rather than making it the default behaviour?

The shuffle argument defaults to True in Keras' fit method, hence the default here, but it's definitely needed: we typically run many epochs, and shuffling the data at each epoch helps model training, so yes, it should improve model accuracy. The user can always override the new method; that was the purpose of extracting this data preparation step into a method.

@pio-neil commented Feb 5, 2025

> However, it's also not clear to me why we're introducing shuffling by default here, when AFAICT it wasn't there before? This seems like a change of behaviour. Do we expect this to improve model accuracy? It may be better to let the user of Trieste control this, rather than making it the default behaviour?

> The shuffle argument defaults to True in Keras' fit method, hence the default here, but it's definitely needed: we typically run many epochs, and shuffling the data at each epoch helps model training, so yes, it should improve model accuracy. The user can always override the new method; that was the purpose of extracting this data preparation step into a method.

Right, yes, I see. I was just wondering why we were doing things differently in terms of shuffling, but we aren't (since fit ignores the shuffle argument for datasets):

> shuffle: Boolean, whether to shuffle the training data before each epoch. This argument is ignored when x is a keras.utils.PyDataset, tf.data.Dataset, torch.utils.data.DataLoader or Python generator function.

@hstojic changed the title from "using tf.data for fit method instead of" to "using tf.data for fit method in DeepEnsemble model" on Feb 5, 2025
@hstojic (Author) commented Feb 5, 2025

after some further thinking, this is currently an unnecessary complication of the code, as for the smallish datasets we mainly deal with here the current code is good enough

@hstojic closed this on Feb 5, 2025