Ensure PSD-safe factorization in constructor of MultivariateNormal
#2297
Conversation
Force-pushed from e06f8da to d0e35e4, then from d0e35e4 to 570d43f ("… to ensure PSD-safe factorization").
Thanks! An alternative to making everything into a `LinearOperator` would be, upon receiving a dense tensor, rather than just storing it and passing it along, to compute the Cholesky decomposition with jitter and then pass that as the `scale_tril` to the torch distribution. The downside of that is of course that this would do a lot of compute upon construction of the object, potentially unnecessarily. Another hack could be to mock some of the torch distribution code so that it uses the psd-safe Cholesky decomposition internally (though that seems very hacky and potentially problematic).
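A minimal sketch of that eager alternative, assuming `psd_safe_cholesky` as exposed by `linear_operator.utils.cholesky` (the import path is an assumption; adjust if the project exposes it elsewhere):

```python
import torch
from linear_operator.utils.cholesky import psd_safe_cholesky


def mvn_from_dense(loc: torch.Tensor, covariance_matrix: torch.Tensor):
    # Eagerly factorize the dense covariance with the jittered, PSD-safe
    # Cholesky, then hand the factor to torch as scale_tril so torch's own
    # (strictly positive-definite) cholesky call is never reached.
    scale_tril = psd_safe_cholesky(covariance_matrix)
    return torch.distributions.MultivariateNormal(loc, scale_tril=scale_tril)
```

As noted above, the factorization cost is paid at construction time even if the factor is never needed.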
```python
# will fail if the covariance matrix is semi-definite, whereas DenseLinearOperator ends up
# calling _psd_safe_cholesky, which factorizes semi-definite matrices by adding to the diagonal.
if isinstance(covariance_matrix, Tensor):
    self._islazy = False  # to allow _unbroadcasted_scale_tril setter
```
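For context, a small illustration of the difference the comment describes (the rank-deficient matrix and the fallback behavior are assumptions about `linear_operator`'s PSD-safe path, not taken from this diff):

```python
import torch
from linear_operator.operators import DenseLinearOperator

# A PSD but rank-deficient covariance: plain torch.linalg.cholesky rejects it,
# while the LinearOperator path falls back to a jittered factorization.
cov = torch.tensor([[1.0, 1.0], [1.0, 1.0]])

# torch.linalg.cholesky(cov)  # raises: the matrix is not positive-definite
chol = DenseLinearOperator(cov).cholesky()  # routes to _psd_safe_cholesky
```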
It seems odd to have `_islazy` set to `True` if the covariance matrix is indeed a `LinearOperator`. I guess the "lazy" nomenclature is a bit outdated anyway with the move to `LinearOperator`.
```python
event_shape = self.loc.shape[-1:]
```
```python
# TODO: Integrate argument validation for LinearOperators into torch.distribution validation logic
```
Do you mean changing the torch code to validate `LinearOperator` inputs? That might be somewhat challenging to do if we want to use `LinearOperator`s there explicitly. What would work is to make changes in pure torch that would make it easier to use `LinearOperator` objects by means of the `__torch_function__` interface we define in `LinearOperator`.
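For readers unfamiliar with that hook, here is a minimal, self-contained sketch of the `__torch_function__` protocol being referred to; the wrapper class and the jitter value are purely illustrative, not `LinearOperator`'s actual implementation:

```python
import torch


class WrappedMatrix:
    """Illustrative tensor wrapper using the __torch_function__ protocol:
    torch functions called on an instance are intercepted here, which is the
    mechanism by which torch.linalg.cholesky could be redirected to a
    PSD-safe (jittered) factorization."""

    def __init__(self, tensor: torch.Tensor):
        self.tensor = tensor

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        # Unwrap our instances back to plain tensors ...
        args = tuple(a.tensor if isinstance(a, cls) else a for a in args)
        if func is torch.linalg.cholesky:
            # ... and swap in a jittered factorization for the strict one.
            jitter = 1e-6 * torch.eye(args[0].shape[-1], dtype=args[0].dtype)
            return torch.linalg.cholesky(args[0] + jitter)
        return func(*args, **kwargs)


# torch dispatches through the override, so even a singular matrix factorizes:
cov = WrappedMatrix(torch.tensor([[1.0, 1.0], [1.0, 1.0]]))
L = torch.linalg.cholesky(cov)
```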
There was a recent BoTorch issue that was caused by a positive semi-definite matrix being passed to `MultivariateNormal` as a `Tensor`, which causes the constructor to fail because PyTorch's constructor calls `cholesky` on the tensor. This commit upstreams the corresponding BoTorch PR to ensure that all covariance matrices are `LinearOperator` types, thereby triggering `_psd_safe_cholesky` whenever `cholesky` is called.
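A hypothetical repro of the failure mode this addresses (the specific matrix is an assumption; any PSD-but-singular covariance triggers it):

```python
import torch
from gpytorch.distributions import MultivariateNormal

loc = torch.zeros(2)
# PSD but rank-deficient: a strictly positive-definite Cholesky fails on it.
cov = torch.tensor([[1.0, 1.0], [1.0, 1.0]])

# Previously, a raw Tensor flowed straight into torch's constructor, which
# eagerly calls cholesky and raises on the singular matrix. With the
# covariance wrapped as a LinearOperator, factorization goes through
# _psd_safe_cholesky and succeeds.
mvn = MultivariateNormal(loc, cov)
sample = mvn.rsample()
```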