Add jitter to Cholesky factorization in Gaussian ops #3151

Merged (11 commits) on Oct 30, 2022
10 changes: 5 additions & 5 deletions pyro/distributions/hmm.py
@@ -20,7 +20,7 @@
)
from pyro.ops.indexing import Vindex
from pyro.ops.special import safe_log
-from pyro.ops.tensor_utils import cholesky, cholesky_solve
+from pyro.ops.tensor_utils import cholesky_solve, safe_cholesky

from . import constraints
from .torch import Categorical, Gamma, Independent, MultivariateNormal
@@ -628,9 +628,9 @@ def filter(self, value):

# Convert to a distribution
precision = logp.precision
-loc = cholesky_solve(logp.info_vec.unsqueeze(-1), cholesky(precision)).squeeze(
-    -1
-)
+loc = cholesky_solve(
+    logp.info_vec.unsqueeze(-1), safe_cholesky(precision)
+).squeeze(-1)
return MultivariateNormal(
loc, precision_matrix=precision, validate_args=self._validate_args
)
@@ -928,7 +928,7 @@ def filter(self, value):
gamma_dist.concentration, gamma_dist.rate, validate_args=self._validate_args
)
# Conditional of last state on unit scale
-scale_tril = cholesky(logp.precision)
+scale_tril = safe_cholesky(logp.precision)
loc = cholesky_solve(logp.info_vec.unsqueeze(-1), scale_tril).squeeze(-1)
mvn = MultivariateNormal(
loc, scale_tril=scale_tril, validate_args=self._validate_args
4 changes: 2 additions & 2 deletions pyro/distributions/transforms/cholesky.py
@@ -89,7 +89,7 @@ def log_abs_det_jacobian(self, x, y):

class CholeskyTransform(Transform):
r"""
-Transform via the mapping :math:`y = cholesky(x)`, where `x` is a
+Transform via the mapping :math:`y = safe_cholesky(x)`, where `x` is a
positive definite matrix.
"""
bijective = True
@@ -116,7 +116,7 @@ def log_abs_det_jacobian(self, x, y):

class CorrMatrixCholeskyTransform(CholeskyTransform):
r"""
-Transform via the mapping :math:`y = cholesky(x)`, where `x` is a
+Transform via the mapping :math:`y = safe_cholesky(x)`, where `x` is a
correlation matrix.
"""
bijective = True
10 changes: 5 additions & 5 deletions pyro/ops/gaussian.py
@@ -9,7 +9,7 @@
from torch.nn.functional import pad

from pyro.distributions.util import broadcast_shape
-from pyro.ops.tensor_utils import cholesky, matmul, matvecmul, triangular_solve
+from pyro.ops.tensor_utils import matmul, matvecmul, safe_cholesky, triangular_solve


class Gaussian:
@@ -154,7 +154,7 @@ def rsample(
"""
Reparameterized sampler.
"""
-P_chol = cholesky(self.precision)
+P_chol = safe_cholesky(self.precision)
loc = self.info_vec.unsqueeze(-1).cholesky_solve(P_chol).squeeze(-1)
shape = sample_shape + self.batch_shape + (self.dim(), 1)
if noise is None:
@@ -254,7 +254,7 @@ def marginalize(self, left=0, right=0) -> "Gaussian":
P_aa = self.precision[..., a, a]
P_ba = self.precision[..., b, a]
P_bb = self.precision[..., b, b]
-P_b = cholesky(P_bb)
+P_b = safe_cholesky(P_bb)
P_a = triangular_solve(P_ba, P_b, upper=False)
P_at = P_a.transpose(-1, -2)
precision = P_aa - matmul(P_at, P_a)
@@ -277,7 +277,7 @@ def event_logsumexp(self) -> torch.Tensor:
Integrates out all latent state (i.e. operating on event dimensions).
"""
n = self.dim()
-chol_P = cholesky(self.precision)
+chol_P = safe_cholesky(self.precision)
chol_P_u = triangular_solve(
self.info_vec.unsqueeze(-1), chol_P, upper=False
).squeeze(-1)
@@ -550,7 +550,7 @@ def gaussian_tensordot(x: Gaussian, y: Gaussian, dims: int = 0) -> Gaussian:
b = xb + yb

# Pbb + Qbb needs to be positive definite, so that we can marginalize out `b` (to have a finite integral)
-L = cholesky(Pbb + Qbb)
+L = safe_cholesky(Pbb + Qbb)
LinvB = triangular_solve(B, L, upper=False)
LinvBt = LinvB.transpose(-2, -1)
Linvb = triangular_solve(b.unsqueeze(-1), L, upper=False)
13 changes: 12 additions & 1 deletion pyro/ops/tensor_utils.py
@@ -7,6 +7,7 @@
from torch.fft import irfft, rfft

_ROOT_TWO_INVERSE = 1.0 / math.sqrt(2.0)
CHOLESKY_JITTER = 1.0 # in units of finfo.eps


def as_complex(x):
@@ -393,9 +394,19 @@ def inverse_haar_transform(x):
return x


-def cholesky(x):
+def safe_cholesky(x):
    if x.size(-1) == 1:
        if CHOLESKY_JITTER:
fehiepsi (Member) commented on Oct 28, 2022:
Do you intend to clamp by (CHOLESKY_JITTER * finfo(x.dtype).eps) ** 2 here?

fritzo (Member Author) replied:
My intention was to scale by about finfo(x.dtype).eps * x.max() so that the jitter was just barely detectable by the largest matrix entry before Cholesky factorizing. That way, if we set RELATIVE_CHOLESKY_JITTER = 1/2, jitter will only affect matrix entries less than half the size of the max. And it kind of makes sense to me that each additional bit of precision should require half as much jitter, so jitter should be proportional to finfo(x.dtype).eps. Mostly the proportional scaling helps us keep a constant RELATIVE_CHOLESKY_JITTER across float32 and float64.

What's your intuition behind the square here? Is it to keep the error constant post-Cholesky-factorization?
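
(A small illustration of that proportionality argument, not part of the PR: because the jitter scales with machine epsilon, the same relative setting yields dtype-appropriate absolute jitter.)

import torch

CHOLESKY_JITTER = 1.0  # relative jitter, in units of finfo.eps
for dtype in (torch.float32, torch.float64):
    x = 100.0 * torch.eye(3, dtype=dtype)
    jitter = CHOLESKY_JITTER * torch.finfo(dtype).eps * x.abs().max()
    print(dtype, float(jitter))  # ~1.2e-05 for float32, ~2.2e-14 for float64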

fehiepsi (Member) replied:
I missed the x_max term in the above comment. Using the square seems more consistent with the x.size(-1) > 1 case, but I like your clamp-by-tiny better.

Re x_max: using the global max makes sense, but I feel it might be better to use the max of each row instead. E.g. for the diagonal matrix diag(0.0001, 10000), the global jitter is large relative to the first diagonal term.
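
(Concretely, for that example in float32; an illustration, not part of the PR:)

import torch

# diag(0.0001, 10000) in float32, where eps ~= 1.19e-7:
x = torch.diag(torch.tensor([0.0001, 10000.0]))
eps = torch.finfo(x.dtype).eps
global_jitter = eps * x.abs().max()             # ~1.2e-3, about 12x the 0.0001 entry
rowwise_jitter = eps * x.abs().max(-1).values   # ~[1.2e-11, 1.2e-3], one scale per row
print(global_jitter, rowwise_jitter)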

fritzo (Member Author) replied:
Nice idea! I've switched to using a row-wise max. This required increasing CHOLESKY_RELATIVE_JITTER from 1.0 to 4.0, but this way still seems better 👍
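
(A rough sketch of what that row-wise variant could look like; the function name is hypothetical and the 4.0 value is taken from the comment above. The actual change landed in a later commit, not in this diff.)

import torch

CHOLESKY_RELATIVE_JITTER = 4.0  # in units of finfo.eps

def rowwise_safe_cholesky(x):
    # Scale the jitter for each diagonal entry by the largest absolute entry
    # in its own row, instead of by the global matrix max.
    x_max = x.abs().max(-1).values                       # shape (..., n)
    jitter = CHOLESKY_RELATIVE_JITTER * torch.finfo(x.dtype).eps * x_max
    return torch.linalg.cholesky(x + torch.diag_embed(jitter))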

martinjankowiak (Collaborator) commented:
Is this really preferred? This changes the eigenvalues and the eigenvectors, as opposed to jitter that is proportional to the identity (which only changes the eigenvalues).
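
(For contrast, a rough sketch of the identity-proportional alternative raised here; illustrative only, not code from this PR.)

import torch

CHOLESKY_RELATIVE_JITTER = 1.0  # illustrative value

def identity_jitter_cholesky(x):
    # Add the same scalar to every diagonal entry, i.e. jitter proportional to
    # the identity: eigenvalues shift by the jitter, eigenvectors are unchanged.
    x_max = x.abs().amax(dim=(-2, -1))                   # one scale per matrix
    jitter = CHOLESKY_RELATIVE_JITTER * torch.finfo(x.dtype).eps * x_max
    eye = torch.eye(x.size(-1), dtype=x.dtype, device=x.device)
    return torch.linalg.cholesky(x + jitter[..., None, None] * eye)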

fritzo (Member Author) replied on Oct 30, 2022:
@martinjankowiak that's a good point, I didn't know about eigenvector preservation. I'd be ok with either version.

One thing I like about @fehiepsi's solution is that users can, on their side, rotate the system before performing Gaussian ops; e.g. I'm approximately diagonalizing via QR:

# Rotate into (approximately) the eigenbasis of the transition matrix:
evals, evecs = torch.linalg.eig(transition_matrix)
Q, R = torch.linalg.qr(evecs.real)  # orthonormalize the real parts of the eigenvectors
transition_matrix = Q.T @ transition_matrix @ Q

which shrinks my diagonal perturbations

- [5.45, 1.81, 0.86, 0.76] * eps * CHOLESKY_RELATIVE_JITTER
+ [4.91, 1.67, 0.21, 0.17] * eps * CHOLESKY_RELATIVE_JITTER

            x = x.clamp(min=torch.finfo(x.dtype).tiny)
        return x.sqrt()

    if CHOLESKY_JITTER:
        # Add adaptive jitter.
        x = x.clone()
        x_max = x.data.reshape(*x.shape[:-2], -1).abs().max(-1, True).values
        jitter = CHOLESKY_JITTER * torch.finfo(x.dtype).eps * x_max
        x.data.diagonal(dim1=-1, dim2=-2).add_(jitter)

    return torch.linalg.cholesky(x)
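
(A rough usage sketch of the new helper; the import path follows this diff, and the comparison only holds for well-conditioned inputs where the jitter is negligible.)

import torch
from pyro.ops.tensor_utils import safe_cholesky  # name introduced in this PR

# Drop-in replacement for torch.linalg.cholesky; for a well-conditioned input
# the added jitter (~eps * max|x|) is negligible and the factors agree closely.
a = torch.randn(5, 5, dtype=torch.float64)
x = a @ a.T + torch.eye(5, dtype=torch.float64)   # symmetric positive definite
L = safe_cholesky(x)
print(torch.allclose(L, torch.linalg.cholesky(x)))  # True for this example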


55 changes: 54 additions & 1 deletion tests/ops/test_gaussian.py
@@ -15,6 +15,7 @@
AffineNormal,
Gaussian,
gaussian_tensordot,
matrix_and_gaussian_to_gaussian,
matrix_and_mvn_to_gaussian,
mvn_to_gaussian,
sequential_gaussian_filter_sample,
@@ -378,7 +379,7 @@ def test_gaussian_tensordot(
nc = y_dim - dot_dims
try:
torch.linalg.cholesky(x.precision[..., na:, na:] + y.precision[..., :nb, :nb])
-except RuntimeError:
+except Exception:
pytest.skip("Cannot marginalize the common variables of two Gaussians.")

z = gaussian_tensordot(x, y, dot_dims)
@@ -557,3 +558,55 @@ def test_sequential_gaussian_filter_sample_antithetic(
)
expected = torch.stack([sample, mean, 2 * mean - sample])
assert torch.allclose(sample3, expected)


@pytest.mark.filterwarnings("ignore:Singular matrix in cholesky")
@pytest.mark.parametrize("num_steps", [10, 100, 1000, 10000, 100000, 1000000])
def test_sequential_gaussian_filter_sample_stability(num_steps):
# This tests long-chain filtering at low precision.
zero = torch.zeros((), dtype=torch.float)
eye = torch.eye(4, dtype=torch.float)
noise = torch.randn(num_steps, 4, dtype=torch.float, requires_grad=True)
trans_matrix = torch.tensor(
[
[
0.8571434617042542,
-0.23285813629627228,
0.05360094830393791,
-0.017088839784264565,
],
[
0.7609677314758301,
0.6596274971961975,
-0.022656921297311783,
0.05166701227426529,
],
[
3.0979342460632324,
5.446939945220947,
-0.3425334692001343,
0.01096670888364315,
],
[
-1.8180007934570312,
-0.4965082108974457,
-0.006048532668501139,
-0.08525419235229492,
],
],
dtype=torch.float,
requires_grad=True,
)

init = Gaussian(zero, zero.expand(4), eye)
trans = matrix_and_gaussian_to_gaussian(
trans_matrix, Gaussian(zero, zero.expand(4), eye)
).expand((num_steps - 1,))

# Check numerically stabilized value.
x = sequential_gaussian_filter_sample(init, trans, (), noise)
assert torch.isfinite(x).all()

# Check gradients.
grads = torch.autograd.grad(x.sum(), [trans_matrix, noise])
assert all(torch.isfinite(g).all() for g in grads)