diff --git a/pyro/distributions/hmm.py b/pyro/distributions/hmm.py
index 2a38b19f11..2112d9979b 100644
--- a/pyro/distributions/hmm.py
+++ b/pyro/distributions/hmm.py
@@ -20,7 +20,7 @@
 )
 from pyro.ops.indexing import Vindex
 from pyro.ops.special import safe_log
-from pyro.ops.tensor_utils import cholesky, cholesky_solve
+from pyro.ops.tensor_utils import cholesky_solve, safe_cholesky

 from . import constraints
 from .torch import Categorical, Gamma, Independent, MultivariateNormal
@@ -628,9 +628,9 @@ def filter(self, value):

         # Convert to a distribution
         precision = logp.precision
-        loc = cholesky_solve(logp.info_vec.unsqueeze(-1), cholesky(precision)).squeeze(
-            -1
-        )
+        loc = cholesky_solve(
+            logp.info_vec.unsqueeze(-1), safe_cholesky(precision)
+        ).squeeze(-1)
         return MultivariateNormal(
             loc, precision_matrix=precision, validate_args=self._validate_args
         )
@@ -928,7 +928,7 @@ def filter(self, value):
             gamma_dist.concentration, gamma_dist.rate, validate_args=self._validate_args
         )
         # Conditional of last state on unit scale
-        scale_tril = cholesky(logp.precision)
+        scale_tril = safe_cholesky(logp.precision)
         loc = cholesky_solve(logp.info_vec.unsqueeze(-1), scale_tril).squeeze(-1)
         mvn = MultivariateNormal(
             loc, scale_tril=scale_tril, validate_args=self._validate_args
diff --git a/pyro/distributions/transforms/cholesky.py b/pyro/distributions/transforms/cholesky.py
index 3b890f5a3b..d2e4d22684 100644
--- a/pyro/distributions/transforms/cholesky.py
+++ b/pyro/distributions/transforms/cholesky.py
@@ -89,7 +89,7 @@ def log_abs_det_jacobian(self, x, y):

 class CholeskyTransform(Transform):
     r"""
-    Transform via the mapping :math:`y = cholesky(x)`, where `x` is a
+    Transform via the mapping :math:`y = safe_cholesky(x)`, where `x` is a
     positive definite matrix.
     """
     bijective = True
@@ -116,7 +116,7 @@ def log_abs_det_jacobian(self, x, y):

 class CorrMatrixCholeskyTransform(CholeskyTransform):
     r"""
-    Transform via the mapping :math:`y = cholesky(x)`, where `x` is a
+    Transform via the mapping :math:`y = safe_cholesky(x)`, where `x` is a
     correlation matrix.
     """
     bijective = True
diff --git a/pyro/ops/gaussian.py b/pyro/ops/gaussian.py
index 4541513495..12f17e973c 100644
--- a/pyro/ops/gaussian.py
+++ b/pyro/ops/gaussian.py
@@ -9,7 +9,7 @@
 from torch.nn.functional import pad

 from pyro.distributions.util import broadcast_shape
-from pyro.ops.tensor_utils import cholesky, matmul, matvecmul, triangular_solve
+from pyro.ops.tensor_utils import matmul, matvecmul, safe_cholesky, triangular_solve


 class Gaussian:
@@ -154,7 +154,7 @@ def rsample(
         """
         Reparameterized sampler.
         """
-        P_chol = cholesky(self.precision)
+        P_chol = safe_cholesky(self.precision)
         loc = self.info_vec.unsqueeze(-1).cholesky_solve(P_chol).squeeze(-1)
         shape = sample_shape + self.batch_shape + (self.dim(), 1)
         if noise is None:
@@ -254,7 +254,7 @@ def marginalize(self, left=0, right=0) -> "Gaussian":
         P_aa = self.precision[..., a, a]
         P_ba = self.precision[..., b, a]
         P_bb = self.precision[..., b, b]
-        P_b = cholesky(P_bb)
+        P_b = safe_cholesky(P_bb)
         P_a = triangular_solve(P_ba, P_b, upper=False)
         P_at = P_a.transpose(-1, -2)
         precision = P_aa - matmul(P_at, P_a)
@@ -277,7 +277,7 @@ def event_logsumexp(self) -> torch.Tensor:
         Integrates out all latent state (i.e. operating on event dimensions).
""" n = self.dim() - chol_P = cholesky(self.precision) + chol_P = safe_cholesky(self.precision) chol_P_u = triangular_solve( self.info_vec.unsqueeze(-1), chol_P, upper=False ).squeeze(-1) @@ -550,7 +550,7 @@ def gaussian_tensordot(x: Gaussian, y: Gaussian, dims: int = 0) -> Gaussian: b = xb + yb # Pbb + Qbb needs to be positive definite, so that we can malginalize out `b` (to have a finite integral) - L = cholesky(Pbb + Qbb) + L = safe_cholesky(Pbb + Qbb) LinvB = triangular_solve(B, L, upper=False) LinvBt = LinvB.transpose(-2, -1) Linvb = triangular_solve(b.unsqueeze(-1), L, upper=False) diff --git a/pyro/ops/tensor_utils.py b/pyro/ops/tensor_utils.py index 7efc847d59..f820aa72cf 100644 --- a/pyro/ops/tensor_utils.py +++ b/pyro/ops/tensor_utils.py @@ -7,6 +7,7 @@ from torch.fft import irfft, rfft _ROOT_TWO_INVERSE = 1.0 / math.sqrt(2.0) +CHOLESKY_RELATIVE_JITTER = 4.0 # in units of finfo.eps def as_complex(x): @@ -393,9 +394,19 @@ def inverse_haar_transform(x): return x -def cholesky(x): +def safe_cholesky(x): if x.size(-1) == 1: + if CHOLESKY_RELATIVE_JITTER: + x = x.clamp(min=torch.finfo(x.dtype).tiny) return x.sqrt() + + if CHOLESKY_RELATIVE_JITTER: + # Add adaptive jitter. + x = x.clone() + x_max = x.data.abs().max(-1).values + jitter = CHOLESKY_RELATIVE_JITTER * torch.finfo(x.dtype).eps * x_max + x.data.diagonal(dim1=-1, dim2=-2).add_(jitter) + return torch.linalg.cholesky(x) diff --git a/tests/ops/test_gaussian.py b/tests/ops/test_gaussian.py index f8fab903b7..0982d8eb91 100644 --- a/tests/ops/test_gaussian.py +++ b/tests/ops/test_gaussian.py @@ -15,6 +15,7 @@ AffineNormal, Gaussian, gaussian_tensordot, + matrix_and_gaussian_to_gaussian, matrix_and_mvn_to_gaussian, mvn_to_gaussian, sequential_gaussian_filter_sample, @@ -378,7 +379,7 @@ def test_gaussian_tensordot( nc = y_dim - dot_dims try: torch.linalg.cholesky(x.precision[..., na:, na:] + y.precision[..., :nb, :nb]) - except RuntimeError: + except Exception: pytest.skip("Cannot marginalize the common variables of two Gaussians.") z = gaussian_tensordot(x, y, dot_dims) @@ -557,3 +558,55 @@ def test_sequential_gaussian_filter_sample_antithetic( ) expected = torch.stack([sample, mean, 2 * mean - sample]) assert torch.allclose(sample3, expected) + + +@pytest.mark.filterwarnings("ignore:Singular matrix in cholesky") +@pytest.mark.parametrize("num_steps", [10, 100, 1000, 10000, 100000, 1000000]) +def test_sequential_gaussian_filter_sample_stability(num_steps): + # This tests long-chain filtering at low precision. + zero = torch.zeros((), dtype=torch.float) + eye = torch.eye(4, dtype=torch.float) + noise = torch.randn(num_steps, 4, dtype=torch.float, requires_grad=True) + trans_matrix = torch.tensor( + [ + [ + 0.8571434617042542, + -0.23285813629627228, + 0.05360094830393791, + -0.017088839784264565, + ], + [ + 0.7609677314758301, + 0.6596274971961975, + -0.022656921297311783, + 0.05166701227426529, + ], + [ + 3.0979342460632324, + 5.446939945220947, + -0.3425334692001343, + 0.01096670888364315, + ], + [ + -1.8180007934570312, + -0.4965082108974457, + -0.006048532668501139, + -0.08525419235229492, + ], + ], + dtype=torch.float, + requires_grad=True, + ) + + init = Gaussian(zero, zero.expand(4), eye) + trans = matrix_and_gaussian_to_gaussian( + trans_matrix, Gaussian(zero, zero.expand(4), eye) + ).expand((num_steps - 1,)) + + # Check numerically stabilized value. + x = sequential_gaussian_filter_sample(init, trans, (), noise) + assert torch.isfinite(x).all() + + # Check gradients. 
+    grads = torch.autograd.grad(x.sum(), [trans_matrix, noise])
+    assert all(torch.isfinite(g).all() for g in grads)
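
The gist of the patch: `safe_cholesky` replaces the bare `cholesky` wrapper and perturbs the diagonal by an amount scaled to the matrix's own magnitude (`CHOLESKY_RELATIVE_JITTER` many units of `finfo.eps`), so precision matrices that are positive definite only up to roundoff still factorize instead of raising. Below is a minimal standalone sketch of that idea; the helper name `jittered_cholesky` and the rank-1 demo matrix are illustrative only and do not appear in the patch.

```python
import torch

CHOLESKY_RELATIVE_JITTER = 4.0  # same convention as the patch: units of finfo.eps


def jittered_cholesky(x: torch.Tensor) -> torch.Tensor:
    # Mirror the patched safe_cholesky: add jitter proportional to each
    # row's largest entry, keeping the perturbation relative to the
    # matrix's scale rather than using a fixed absolute epsilon.
    x = x.clone()
    x_max = x.abs().max(-1).values  # per-row magnitude, shape (..., n)
    jitter = CHOLESKY_RELATIVE_JITTER * torch.finfo(x.dtype).eps * x_max
    x.diagonal(dim1=-2, dim2=-1).add_(jitter)
    return torch.linalg.cholesky(x)


# A rank-1, hence singular, PSD matrix: bare Cholesky fails on it.
v = torch.randn(4)
p = torch.outer(v, v)
try:
    torch.linalg.cholesky(p)
except Exception as e:
    # The exception type varies across torch versions (RuntimeError or
    # torch.linalg.LinAlgError), which is also why the test above broadened
    # `except RuntimeError` to `except Exception`.
    print("bare cholesky failed:", type(e).__name__)
print(jittered_cholesky(p))  # typically succeeds thanks to the relative jitter
```

The same mechanism is what lets `test_sequential_gaussian_filter_sample_stability` push chains up to 10**6 steps at float32: every Cholesky inside `rsample`, `marginalize`, and `gaussian_tensordot` now factorizes a slightly regularized matrix, so the filter can yield finite samples and gradients instead of failing or producing NaNs.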