Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ensure PSD-safe factorization in constructor of MultivariateNormal #2297

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

SebastianAment
Copy link
Contributor

There was a recent BoTorch issue that was caused by a positive semi-definite matrix being passed to MultivariateNormal as a Tensor, which causes the constructor to fail because PyTorch's constructor calls cholesky on the tensor. This commit upstreams the corresponding BoTorch PR to ensure that all covariance matrices are LinearOperator types, thereby triggering _psd_safe_cholesky, whenever cholesky is called.

@SebastianAment SebastianAment force-pushed the mvn-constructor branch 3 times, most recently from e06f8da to d0e35e4 Compare March 10, 2023 18:36
Copy link
Collaborator

@Balandat Balandat left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! An alternative to making everything into a LinearOperator would be upon receiving a dense tensor, rather than just storing that and passing it along, to compute the cholesky decomposition with jitted and then pass that as the scale_tril to the torch distribution. The downside of that is ofc that this would do a lot of compute upon construction of the object, potentially unnecessary. Another hack could be to mock some of the torch distribution code so that it users the psd-safe cholesky decomposition internally (though that seems very hacky and potentially problematic).

# will fail if the covariance matrix is semi-definite, whereas DenseLinearOperator ends up
# calling _psd_safe_cholesky, which factorizes semi-definite matrices by adding to the diagonal.
if isinstance(covariance_matrix, Tensor):
self._islazy = False # to allow _unbroadcasted_scale_tril setter
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems odd to have _islazy set to True if the covariance matrix is indeed a LinearOperator. I guess the "lazy" nomenclature is a bit outdated anyway with the move to LinearOperator.


event_shape = self.loc.shape[-1:]

# TODO: Integrate argument validation for LinearOperators into torch.distribution validation logic
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you mean changing the torch code to validate LinearOperator inputs? That might be somewhat challenging to do if we want to use LinearOperators there explicitly. What would work is to make changes in pure torch that would make it easier to use LinearOperator objects by means of the __torch_function__ interface we define in LinearOperator.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants