Added kl_divergence for multivariate normals (pyro-ppl#1654)
* Added kl_divergence for multivariate normals

* style fixes

making the linter tests happy

* kl for multivariate normal now deals with possible batch shapes
lumip authored Oct 27, 2023
1 parent 9ca209d commit eaa29a0
Showing 3 changed files with 122 additions and 3 deletions.
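For orientation, here is a minimal usage sketch of what this commit enables (the distribution parameters below are illustrative assumptions, not taken from the diff): after this change, kl_divergence dispatches on a pair of MultivariateNormal distributions and broadcasts their batch shapes.

# Usage sketch (illustrative values; import paths assumed from the files changed below).
import jax.numpy as jnp
import numpyro.distributions as dist
from numpyro.distributions.kl import kl_divergence

p = dist.MultivariateNormal(jnp.zeros((4, 3)), covariance_matrix=jnp.eye(3))   # batch shape (4,)
q = dist.MultivariateNormal(jnp.ones(3), covariance_matrix=2.0 * jnp.eye(3))   # batch shape ()
kl = kl_divergence(p, q)   # closed-form KL, broadcast over batch shapes -> shape (4,)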
14 changes: 11 additions & 3 deletions numpyro/distributions/continuous.py
@@ -1276,6 +1276,16 @@ def _batch_solve_triangular(A, B):
return X


def _batch_trace_from_cholesky(L):
"""Computes the trace of matrix X given it's Cholesky decomposition matrix L.
:param jnp.ndarray(..., M, M) L: An array with lower triangular structure in the last two dimensions.
:return: Trace of X, where X = L L^T
"""
return jnp.square(L).sum((-1, -2))


class MatrixNormal(Distribution):
"""
Matrix variate normal distribution as described in [1] but with a lower_triangular parametrization,
@@ -1358,9 +1368,7 @@ def log_prob(self, values):
diff_col_solve = _batch_solve_triangular(
A=self.scale_tril_column, B=jnp.swapaxes(diff_row_solve, -2, -1)
)
-        batched_trace_term = jnp.square(
-            diff_col_solve.reshape(diff_col_solve.shape[:-2] + (-1,))
-        ).sum(-1)
+        batched_trace_term = _batch_trace_from_cholesky(diff_col_solve)

log_prob = -0.5 * batched_trace_term - log_det_term

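As a sanity check for the new helper (a sketch, not part of the commit): for any matrix L, the sum of its squared entries equals tr(L L^T), which is exactly what _batch_trace_from_cholesky exploits. The snippet below compares it against a batched jnp.trace.

# Sketch verifying the trace identity used by _batch_trace_from_cholesky.
import jax.numpy as jnp
from jax import random

L = jnp.tril(random.normal(random.PRNGKey(0), (2, 4, 4)))   # batch of lower-triangular factors
X = L @ jnp.swapaxes(L, -1, -2)                              # X = L L^T
direct = jnp.trace(X, axis1=-2, axis2=-1)                    # batched trace of X
via_frobenius = jnp.square(L).sum((-1, -2))                  # what the helper computes
assert jnp.allclose(direct, via_frobenius, atol=1e-5)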
49 changes: 49 additions & 0 deletions numpyro/distributions/kl.py
@@ -36,8 +36,11 @@
Dirichlet,
Gamma,
Kumaraswamy,
MultivariateNormal,
Normal,
Weibull,
_batch_solve_triangular,
_batch_trace_from_cholesky,
)
from numpyro.distributions.discrete import CategoricalProbs
from numpyro.distributions.distribution import (
@@ -134,6 +137,52 @@ def kl_divergence(p, q):
return 0.5 * (var_ratio + t1 - 1 - jnp.log(var_ratio))


@dispatch(MultivariateNormal, MultivariateNormal)
def kl_divergence(p: MultivariateNormal, q: MultivariateNormal):
# cf https://statproofbook.github.io/P/mvn-kl.html

def _shapes_are_broadcastable(first_shape, second_shape):
try:
jnp.broadcast_shapes(first_shape, second_shape)
return True
except ValueError:
return False

if p.event_shape != q.event_shape:
raise ValueError(
"Distributions must have the same event shape, but are"
f" {p.event_shape} and {q.event_shape} for p and q, respectively."
)

try:
result_batch_shape = jnp.broadcast_shapes(p.batch_shape, q.batch_shape)
except ValueError as ve:
raise ValueError(
"Distributions must have broadcastble batch shapes, "
f"but have {p.batch_shape} and {q.batch_shape} for p and q,"
"respectively."
) from ve

assert len(p.event_shape) == 1, "event_shape must be one-dimensional"
D = p.event_shape[0]

p_half_log_det = jnp.log(jnp.diagonal(p.scale_tril, axis1=-2, axis2=-1)).sum(-1)
q_half_log_det = jnp.log(jnp.diagonal(q.scale_tril, axis1=-2, axis2=-1)).sum(-1)

log_det_ratio = 2 * (p_half_log_det - q_half_log_det)
assert _shapes_are_broadcastable(log_det_ratio.shape, result_batch_shape)

Lq_inv = _batch_solve_triangular(q.scale_tril, jnp.eye(D))

tr = _batch_trace_from_cholesky(Lq_inv @ p.scale_tril)
assert _shapes_are_broadcastable(tr.shape, result_batch_shape)

t1 = jnp.square(Lq_inv @ (p.loc - q.loc)[..., jnp.newaxis]).sum((-2, -1))
assert _shapes_are_broadcastable(t1.shape, result_batch_shape)

return 0.5 * (tr + t1 - D - log_det_ratio)


@dispatch(Beta, Beta)
def kl_divergence(p, q):
# From https://en.wikipedia.org/wiki/Beta_distribution#Quantities_of_information_(entropy)
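For reference, the new MultivariateNormal kl_divergence added in kl.py above implements the standard closed form KL(p || q) = 0.5 * (tr(Sigma_q^-1 Sigma_p) + (mu_q - mu_p)^T Sigma_q^-1 (mu_q - mu_p) - D + log det Sigma_q - log det Sigma_p), with the trace term evaluated through Cholesky factors as tr(Sigma_q^-1 Sigma_p) = ||L_q^-1 L_p||_F^2. A Monte Carlo cross-check is sketched below (distribution parameters, sample count, and tolerance are arbitrary assumptions, not from the diff); it estimates E_p[log p(x) - log q(x)] and compares it to the closed form.

# Monte Carlo cross-check sketch for the closed-form KL (loose tolerance).
import jax.numpy as jnp
from jax import random
import numpyro.distributions as dist
from numpyro.distributions.kl import kl_divergence

p = dist.MultivariateNormal(
    jnp.array([0.0, 1.0]), covariance_matrix=jnp.array([[1.0, 0.3], [0.3, 2.0]])
)
q = dist.MultivariateNormal(
    jnp.array([1.0, -1.0]), covariance_matrix=jnp.array([[2.0, 0.8], [0.8, 0.5]])
)

x = p.sample(random.PRNGKey(0), (200_000,))               # draws from p
mc_estimate = jnp.mean(p.log_prob(x) - q.log_prob(x))     # E_p[log p - log q]
closed_form = kl_divergence(p, q)
assert jnp.allclose(mc_estimate, closed_form, rtol=0.05, atol=0.05)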
62 changes: 62 additions & 0 deletions test/test_distributions.py
@@ -2849,6 +2849,68 @@ def test_kl_expanded_normal(batch_shape, event_shape):
assert_allclose(actual, expected)


@pytest.mark.parametrize(
"batch_shape_p, batch_shape_q",
[
((1,), (1,)),
((2, 3), (2, 3)),
((5, 1, 3), (2, 3)),
((1, 3), (5, 2, 3)),
],
ids=str,
)
@pytest.mark.parametrize("single_scale_p", [False, True], ids=str)
@pytest.mark.parametrize("single_loc_p", [False, True], ids=str)
@pytest.mark.parametrize("single_scale_q", [False, True], ids=str)
@pytest.mark.parametrize("single_loc_q", [False, True], ids=str)
def test_kl_multivariate_normal_consistency_with_independent_normals(
batch_shape_p,
batch_shape_q,
single_scale_p,
single_loc_p,
single_scale_q,
single_loc_q,
):
event_shape = (5,)

def make_dists(loc_batch_shape, scales_batch_shape):
mus = np.random.normal(size=loc_batch_shape + event_shape)
scales = np.exp(np.random.normal(size=scales_batch_shape + event_shape) * 0.1)

def diagonalize(v, ignore_axes: int):
if ignore_axes == 0:
return jnp.diag(v)
return vmap(diagonalize, in_axes=(0, None))(v, ignore_axes - 1)

scale_tril = diagonalize(scales, len(scales_batch_shape))
return (
dist.Normal(mus, scales).to_event(len(event_shape)),
dist.MultivariateNormal(mus, scale_tril=scale_tril),
)

p_uni, p_mvn = make_dists(
() if single_loc_p else batch_shape_p, () if single_scale_p else batch_shape_p
)
q_uni, q_mvn = make_dists(
() if single_loc_q else batch_shape_q, () if single_scale_q else batch_shape_q
)

actual = kl_divergence(p_mvn, q_mvn)
expected = kl_divergence(p_uni, q_uni)
assert_allclose(actual, expected, atol=1e-5)


def test_kl_multivariate_normal_nondiagonal_covariance():
p_mvn = dist.MultivariateNormal(np.zeros(2), covariance_matrix=np.eye(2))
q_mvn = dist.MultivariateNormal(
np.ones(2), covariance_matrix=np.array([[2, 0.8], [0.8, 0.5]])
)

actual = kl_divergence(p_mvn, q_mvn)
expected = 3.21138
assert_allclose(actual, expected, atol=2e-5)


@pytest.mark.parametrize("shape", [(), (4,), (2, 3)], ids=str)
@pytest.mark.parametrize(
"p_dist, q_dist",
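The hard-coded expectation of 3.21138 in test_kl_multivariate_normal_nondiagonal_covariance above can be reproduced by evaluating the closed form directly; a plain-NumPy sketch (not part of the test suite):

# Reproduces the reference value used in the nondiagonal-covariance test.
import numpy as np

mu_p, cov_p = np.zeros(2), np.eye(2)
mu_q, cov_q = np.ones(2), np.array([[2.0, 0.8], [0.8, 0.5]])

cov_q_inv = np.linalg.inv(cov_q)
diff = mu_q - mu_p
kl = 0.5 * (
    np.trace(cov_q_inv @ cov_p)               # tr(Sigma_q^-1 Sigma_p)
    + diff @ cov_q_inv @ diff                 # Mahalanobis term
    - 2                                       # event dimension D
    + np.log(np.linalg.det(cov_q) / np.linalg.det(cov_p))
)
print(kl)  # approximately 3.21138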
