
implemented predict_y and predict_noise #894

Merged
merged 4 commits into from
Feb 5, 2025

Conversation

Collaborator

@hstojic hstojic commented Jan 21, 2025

This is an important change: it will affect the DE model quite a bit, so it's important to test the downstream consequences.

  • the predict method now outputs estimation uncertainty, which is more appropriate
  • DE now supports a predict_y method, which outputs the combined uncertainty
  • there is a new predict_noise method that returns the mean and variance of the noise (i.e. the aleatoric uncertainty), which comes in handy in some situations

Note that trajectory sampler behaviour will be affected as well, but only in diversify mode.
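To illustrate the split the PR describes, here is a minimal numpy sketch of the three quantities (shapes and names are hypothetical, not the trieste implementation): predict reports epistemic uncertainty across ensemble members, predict_noise summarises the per-member noise variances, and predict_y combines both.

```python
import numpy as np

# Hypothetical ensemble output: 5 members, 3 query points, 1 output dim
rng = np.random.default_rng(0)
ensemble_means = rng.normal(size=(5, 3, 1))
ensemble_vars = rng.uniform(0.1, 0.2, size=(5, 3, 1))

# predict: mean prediction and its epistemic (estimation) uncertainty,
# i.e. the spread of the member means
mean = ensemble_means.mean(axis=0)
epistemic_var = ensemble_means.var(axis=0)

# predict_noise: mean and variance of the member noise variances
# (the aleatoric uncertainty)
noise_mean = ensemble_vars.mean(axis=0)
noise_var = ensemble_vars.var(axis=0)

# predict_y: total predictive variance combines both sources
total_var = epistemic_var + noise_mean
```

The combined variance is always at least the epistemic variance, since the mean noise variance is non-negative.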

tests/unit/models/keras/test_models.py (outdated, resolved)
trieste/models/keras/sampler.py (outdated, resolved)
@@ -252,7 +258,42 @@ def ensemble_distributions(self, query_points: TensorType) -> tuple[tfd.Distribu

def predict_encoded(self, query_points: TensorType) -> tuple[TensorType, TensorType]:
Collaborator


Can this be changed so that it doesn't drop the information about the aleatoric uncertainty? Currently Secondmind's calibration sampling logic uses IndependentReparametrizationSampler, which generates samples based on predictions made with predict, which in turn calls through to predict_encoded.

For this to be suitable for query point generation in active learning, we need to know both the epistemic and the aleatoric uncertainty. Hypothetically, we could avoid IndependentReparametrizationSampler and instead write a new sampler that uses predict_ensemble_encoded rather than predict, since that does return information about the aleatoric uncertainty. However, this would be a significant change in the calibration product.

An implementation could look something like:

        # reduce over the ensemble axis to get moments of the member predictions
        ensemble_means, ensemble_vars = self.predict_ensemble_encoded(query_points)
        predicted_means = tf.math.reduce_mean(ensemble_means, axis=-3)
        epistemic_variance = tf.math.reduce_variance(ensemble_means, axis=-3)
        aleatoric_variance = tf.math.reduce_mean(ensemble_vars, axis=-3)
        aleatoric_variance_var = tf.math.reduce_variance(ensemble_vars, axis=-3)

        # stack the mean prediction with the mean noise, and their variances
        means = tf.concat([predicted_means, aleatoric_variance], axis=-1)
        variances = tf.concat([epistemic_variance, aleatoric_variance_var], axis=-1)

        return means, variances

Collaborator


Don't let this issue block this PR from being merged, though; I appreciate that it's likely a difficult change to make, since it requires updating all usages and would be a breaking change for anyone using DeepEnsemble.predict. We can discuss how to integrate DeepEnsemble for active learning in a future PR.

Collaborator Author


I spoke to @uri-granta about this, and it's not clear that this usage fits the interfaces well. We decided to proceed as is for now and think through how to deal with this kind of use case.

@ChrisMorter ChrisMorter self-requested a review February 4, 2025 10:26
Collaborator

@uri-granta uri-granta left a comment


LGTM! Just one comment. Also, as discussed, we should avoid changing predict to include the aleatoric uncertainty for now, though there may be a number of different ways we could do this in the future without breaking existing usages.


return unflatten(predicted_means), unflatten(predicted_vars)
def predict_noise(self, query_points: TensorType) -> tuple[TensorType, TensorType]:
return self.predict_noise_encoded(self.encode(query_points))
Collaborator


This should have a docstring too, as it's part of the external API. And probably also a check_shapes:

@check_shapes(
    "query_points: [broadcast batch..., D]",
    "return[0]: [batch..., E...]",
    "return[1]: [batch..., E...]",
)

Also, does this method generalise to other models beyond DeepEnsemble? If so, we could preemptively define SupportsPredictNoise and EncodedSupportsPredictNoise protocols, just like [Encoded]SupportsPredictY.
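For reference, such protocols are typically small structural types; a minimal sketch of what a SupportsPredictNoise protocol might look like (names mirror the existing [Encoded]SupportsPredictY convention mentioned above, but this is an assumption, not code from the PR):

```python
from typing import Protocol, Tuple, runtime_checkable


@runtime_checkable
class SupportsPredictNoise(Protocol):
    """A model that can report the mean and variance of its noise
    (aleatoric uncertainty) at given query points."""

    def predict_noise(self, query_points) -> Tuple[object, object]:
        ...


# Any class with a matching predict_noise method satisfies the protocol
# structurally, without inheriting from it.
class DummyEnsemble:
    def predict_noise(self, query_points):
        # placeholder mean and variance, for illustration only
        return query_points, query_points
```

With @runtime_checkable, isinstance checks verify only that the method exists, which is how acquisition code could dispatch on model capability.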

Collaborator Author


Other models in trieste don't have this, so it's probably not needed at the moment.

Collaborator Author


Added a docstring and check_shapes.

@@ -277,29 +318,38 @@ def predict_encoded(self, query_points: TensorType) -> tuple[TensorType, TensorT
:return: The predicted mean and variance of the observations at the specified
``query_points``.
Collaborator


Suggested change
``query_points``.
``query_points``, including noise contributions.

Collaborator

@avullo avullo left a comment


lgtm!

@hstojic hstojic merged commit 462c8b8 into develop Feb 5, 2025
16 checks passed
@hstojic hstojic deleted the hstojic/de_fix_predict_and_sampler branch February 5, 2025 17:15
4 participants