Casting Tensor to LinearOperator in constructor of MultivariateNormal to ensure PSD-safe factorization
SebastianAment committed Mar 10, 2023
1 parent 442ad98 commit d0e35e4
Showing 3 changed files with 62 additions and 42 deletions.
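
For context, the failure this commit works around can be reproduced with PyTorch's stock distribution. A minimal sketch (the exception type matches the new test below; the exact error message varies across PyTorch versions):

import torch

mean = torch.zeros(3)
A = torch.randn(3, 1)
covmat = A @ A.T  # rank-1: positive semi-definite, but not positive definite

# torch.distributions validates positive definiteness and rejects covmat
try:
    torch.distributions.MultivariateNormal(mean, covariance_matrix=covmat, validate_args=True)
except ValueError as err:
    print(f"rejected: {err}")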
9 changes: 7 additions & 2 deletions gpytorch/distributions/multitask_multivariate_normal.py
@@ -2,7 +2,10 @@
 
 import torch
 from linear_operator import LinearOperator, to_linear_operator
-from linear_operator.operators import BlockDiagLinearOperator, BlockInterleavedLinearOperator, CatLinearOperator
+from linear_operator.operators import (BlockDiagLinearOperator,
+                                       BlockInterleavedLinearOperator,
+                                       CatLinearOperator)
+from torch import Tensor
 
 from .multivariate_normal import MultivariateNormal
 
@@ -24,7 +27,9 @@ class MultitaskMultivariateNormal(MultivariateNormal):
         w.r.t. inter-observation covariance for each task.
     """
 
-    def __init__(self, mean, covariance_matrix, validate_args=False, interleaved=True):
+    def __init__(
+        self, mean: Tensor, covariance_matrix: LinearOperator, validate_args: bool = False, interleaved: bool = True
+    ):
         if not torch.is_tensor(mean) and not isinstance(mean, LinearOperator):
             raise RuntimeError("The mean of a MultitaskMultivariateNormal must be a Tensor or LinearOperator")
 
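
The new annotations above reflect the usual calling convention. A sketch under assumed shapes (mean laid out as points x tasks; the covariance ranges over all point-task pairs; passing a plain Tensor covariance also works because the parent constructor below now casts it to a LinearOperator):

import torch
from gpytorch.distributions import MultitaskMultivariateNormal

n_points, n_tasks = 4, 2
mean = torch.randn(n_points, n_tasks)  # one mean per (point, task) pair
covar = torch.eye(n_points * n_tasks)  # joint covariance over all point-task pairs
mtmvn = MultitaskMultivariateNormal(mean, covar, interleaved=True)
print(mtmvn.event_shape)     # expected: torch.Size([4, 2])
print(mtmvn.sample().shape)  # expected: torch.Size([4, 2])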
81 changes: 41 additions & 40 deletions gpytorch/distributions/multivariate_normal.py
@@ -9,7 +9,8 @@
 
 import torch
 from linear_operator import to_dense, to_linear_operator
-from linear_operator.operators import DiagLinearOperator, LinearOperator, RootLinearOperator
+from linear_operator.operators import (DiagLinearOperator, LinearOperator,
+                                       RootLinearOperator)
 from torch import Tensor
 from torch.distributions import MultivariateNormal as TMultivariateNormal
 from torch.distributions.kl import register_kl
@@ -42,27 +43,36 @@ class MultivariateNormal(TMultivariateNormal, Distribution):
     :ivar torch.Tensor variance: The variance.
     """
 
-    def __init__(self, mean: Tensor, covariance_matrix: Union[Tensor, LinearOperator], validate_args: bool = False):
-        self._islazy = isinstance(mean, LinearOperator) or isinstance(covariance_matrix, LinearOperator)
-        if self._islazy:
-            if validate_args:
-                ms = mean.size(-1)
-                cs1 = covariance_matrix.size(-1)
-                cs2 = covariance_matrix.size(-2)
-                if not (ms == cs1 and ms == cs2):
-                    raise ValueError(f"Wrong shapes in {self._repr_sizes(mean, covariance_matrix)}")
-            self.loc = mean
-            self._covar = covariance_matrix
-            self.__unbroadcasted_scale_tril = None
-            self._validate_args = validate_args
-            batch_shape = torch.broadcast_shapes(self.loc.shape[:-1], covariance_matrix.shape[:-2])
-
-            event_shape = self.loc.shape[-1:]
-
-            # TODO: Integrate argument validation for LinearOperators into torch.distribution validation logic
-            super(TMultivariateNormal, self).__init__(batch_shape, event_shape, validate_args=False)
-        else:
-            super().__init__(loc=mean, covariance_matrix=covariance_matrix, validate_args=validate_args)
+    def __init__(
+        self,
+        mean: Union[Tensor, LinearOperator],
+        covariance_matrix: Union[Tensor, LinearOperator],
+        validate_args: bool = False,
+    ):
+        self._islazy = True
+        # casting Tensor to DenseLinearOperator because the super constructor calls cholesky, which
+        # will fail if the covariance matrix is semi-definite, whereas DenseLinearOperator ends up
+        # calling _psd_safe_cholesky, which factorizes semi-definite matrices by adding to the diagonal.
+        if isinstance(covariance_matrix, Tensor):
+            self._islazy = False  # to allow _unbroadcasted_scale_tril setter
+            covariance_matrix = to_linear_operator(covariance_matrix)
+
+        if validate_args:
+            ms = mean.size(-1)
+            cs1 = covariance_matrix.size(-1)
+            cs2 = covariance_matrix.size(-2)
+            if not (ms == cs1 and ms == cs2):
+                raise ValueError(f"Wrong shapes in {self._repr_sizes(mean, covariance_matrix)}")
+        self.loc = mean
+        self._covar = covariance_matrix
+        self.__unbroadcasted_scale_tril = None
+        self._validate_args = validate_args
+        batch_shape = torch.broadcast_shapes(self.loc.shape[:-1], covariance_matrix.shape[:-2])
+
+        event_shape = self.loc.shape[-1:]
+
+        # TODO: Integrate argument validation for LinearOperators into torch.distribution validation logic
+        super(TMultivariateNormal, self).__init__(batch_shape, event_shape, validate_args=False)
 
     def _extended_shape(self, sample_shape: torch.Size = torch.Size()) -> torch.Size:
         """
@@ -81,16 +91,16 @@ def _extended_shape(self, sample_shape: torch.Size = torch.Size()) -> torch.Size:
     def _repr_sizes(mean: Tensor, covariance_matrix: Union[Tensor, LinearOperator]) -> str:
         return f"MultivariateNormal(loc: {mean.size()}, scale: {covariance_matrix.size()})"
 
-    @property
+    @property  # not using lazy_property here, because it does not allow for setter below
     def _unbroadcasted_scale_tril(self) -> Tensor:
-        if self.islazy and self.__unbroadcasted_scale_tril is None:
+        if self.__unbroadcasted_scale_tril is None:
             # cache root decomposition
             ust = to_dense(self.lazy_covariance_matrix.cholesky())
             self.__unbroadcasted_scale_tril = ust
         return self.__unbroadcasted_scale_tril
 
     @_unbroadcasted_scale_tril.setter
-    def _unbroadcasted_scale_tril(self, ust: Tensor):
+    def _unbroadcasted_scale_tril(self, ust: Tensor) -> None:
         if self.islazy:
             raise NotImplementedError("Cannot set _unbroadcasted_scale_tril for lazy MVN distributions")
         else:
@@ -114,10 +124,7 @@ def base_sample_shape(self) -> torch.Size:
 
     @lazy_property
     def covariance_matrix(self) -> Tensor:
-        if self.islazy:
-            return self._covar.to_dense()
-        else:
-            return super().covariance_matrix
+        return self._covar.to_dense()
 
     def confidence_region(self) -> Tuple[Tensor, Tensor]:
         """
@@ -157,10 +164,7 @@ def get_base_samples(self, sample_shape: torch.Size = torch.Size()) -> Tensor:
 
     @lazy_property
    def lazy_covariance_matrix(self) -> LinearOperator:
-        if self.islazy:
-            return self._covar
-        else:
-            return to_linear_operator(super().covariance_matrix)
+        return self._covar
 
     def log_prob(self, value: Tensor) -> Tensor:
         r"""
@@ -304,13 +308,10 @@ def to_data_independent_dist(self) -> torch.distributions.Normal:
 
     @property
     def variance(self) -> Tensor:
-        if self.islazy:
-            # overwrite this since torch MVN uses unbroadcasted_scale_tril for this
-            diag = self.lazy_covariance_matrix.diagonal(dim1=-1, dim2=-2)
-            diag = diag.view(diag.shape[:-1] + self._event_shape)
-            variance = diag.expand(self._batch_shape + self._event_shape)
-        else:
-            variance = super().variance
+        # overwrite this since torch MVN uses unbroadcasted_scale_tril for this
+        diag = self.lazy_covariance_matrix.diagonal(dim1=-1, dim2=-2)
+        diag = diag.view(diag.shape[:-1] + self._event_shape)
+        variance = diag.expand(self._batch_shape + self._event_shape)
 
         # Check to make sure that variance isn't lower than minimum allowed value (default 1e-6).
         # This ensures that all variances are positive
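
The constructor comment above describes the mechanism: to_linear_operator wraps a dense Tensor in a DenseLinearOperator, whose cholesky() falls back to a PSD-safe factorization that retries with increasing diagonal jitter instead of failing outright. A minimal sketch of that route, using the same call chain as the new _unbroadcasted_scale_tril property:

import torch
from linear_operator import to_dense, to_linear_operator

A = torch.randn(5, 2)
covmat = A @ A.T  # rank-2 and semi-definite; torch.linalg.cholesky(covmat) would fail

L = to_dense(to_linear_operator(covmat).cholesky())
print(torch.allclose(L @ L.mT, covmat, atol=1e-3))  # True up to the added jitter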
14 changes: 14 additions & 0 deletions test/distributions/test_multivariate_normal.py
@@ -47,6 +47,20 @@ def test_multivariate_normal_non_lazy(self, cuda=False):
         self.assertTrue(mvn.sample(torch.Size([2])).shape == torch.Size([2, 3]))
         self.assertTrue(mvn.sample(torch.Size([2, 4])).shape == torch.Size([2, 4, 3]))
 
+        # testing with semi-definite input
+        A = torch.randn(len(mean), 1)
+        covmat = A @ A.T
+        handles_psd = False
+        try:
+            # the regular call fails:
+            # mvn = TMultivariateNormal(loc=mean, covariance_matrix=covmat, validate_args=True)
+            mvn = MultivariateNormal(mean=mean, covariance_matrix=covmat, validate_args=True)
+            mvn.sample()
+            handles_psd = True
+        except ValueError:
+            handles_psd = False
+        self.assertTrue(handles_psd)
+
     def test_multivariate_normal_non_lazy_cuda(self):
         if torch.cuda.is_available():
             with least_used_cuda_device():
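
End to end, dense inputs now go through the same lazy machinery; a usage sketch mirroring the test above (the 1e-6 floor comes from the clamping noted in the variance property):

import torch
from gpytorch.distributions import MultivariateNormal

mean = torch.zeros(3)
A = torch.randn(3, 1)
mvn = MultivariateNormal(mean, A @ A.T)  # semi-definite covariance is accepted
print(mvn.variance)        # diagonal of A @ A.T, floored at 1e-6
print(mvn.sample().shape)  # torch.Size([3])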
