
Added kl_divergence for multivariate normals
lumip committed Sep 29, 2023
1 parent bcf38d7 commit 304cab7
Showing 2 changed files with 112 additions and 1 deletion.
64 changes: 63 additions & 1 deletion numpyro/distributions/kl.py
@@ -27,16 +27,18 @@

from multipledispatch import dispatch

from jax import lax
from jax import lax, vmap
import jax.numpy as jnp
from jax.scipy.special import betaln, digamma, gammaln
from jax.scipy.linalg import solve_triangular

from numpyro.distributions.continuous import (
    Beta,
    Dirichlet,
    Gamma,
    Kumaraswamy,
    Normal,
    MultivariateNormal,
    Weibull,
)
from numpyro.distributions.discrete import CategoricalProbs
@@ -134,6 +136,66 @@ def kl_divergence(p, q):
    return 0.5 * (var_ratio + t1 - 1 - jnp.log(var_ratio))


@dispatch(MultivariateNormal, MultivariateNormal)
def kl_divergence(p: MultivariateNormal, q: MultivariateNormal):
    # cf https://statproofbook.github.io/P/mvn-kl.html

    if p.event_shape != q.event_shape:
        raise ValueError(
            "Distributions must have the same event shape, but are"
            f" {p.event_shape} and {q.event_shape} for p and q, respectively."
        )

    if p.batch_shape != q.batch_shape:
        raise ValueError(
            "Distributions must have the same batch shape, but are"
            f" {p.batch_shape} and {q.batch_shape} for p and q, respectively."
        )

    assert len(p.event_shape) == 1, "event_shape must be one-dimensional"
    D = p.event_shape[0]

    assert p.mean.shape == p.batch_shape + p.event_shape
    assert q.mean.shape == p.mean.shape

    def _single_mvn_kl(p_mean, p_scale_tril, q_mean, q_scale_tril):
        assert jnp.ndim(p_mean) == 1
        assert jnp.ndim(q_mean) == 1
        assert jnp.ndim(p_scale_tril) == 2
        assert jnp.ndim(q_scale_tril) == 2

        # half log-determinants from the diagonals of the Cholesky factors
        p_half_log_det = jnp.log(jnp.diagonal(p_scale_tril)).sum(-1)
        q_half_log_det = jnp.log(jnp.diagonal(q_scale_tril)).sum(-1)
        log_det_ratio = 2 * (p_half_log_det - q_half_log_det)

        Lq_inv = solve_triangular(q_scale_tril, jnp.eye(D), lower=True)

        # tr(Sigma_q^{-1} Sigma_p)
        tr = jnp.sum(jnp.diagonal(Lq_inv.T @ (Lq_inv @ p_scale_tril) @ p_scale_tril.T))

        # Mahalanobis term (mu_p - mu_q)^T Sigma_q^{-1} (mu_p - mu_q)
        z = jnp.matmul(Lq_inv, p_mean - q_mean)
        t1 = jnp.dot(z, z)

        return 0.5 * (tr + t1 - D - log_det_ratio)

    # flatten the batch dimensions, compute one KL per batch element, then restore the shape
    p_mean_flat = jnp.reshape(p.mean, (-1, D))
    p_scale_tril_flat = jnp.reshape(p.scale_tril, (-1, D, D))

    q_mean_flat = jnp.reshape(q.mean, (-1, D))
    q_scale_tril_flat = jnp.reshape(q.scale_tril, (-1, D, D))

    kl_flat = vmap(_single_mvn_kl)(
        p_mean_flat, p_scale_tril_flat, q_mean_flat, q_scale_tril_flat
    )
    assert jnp.ndim(kl_flat) == 1

    kl = jnp.reshape(kl_flat, p.batch_shape)
    return kl
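
For reference, the closed-form identity that the per-element helper above evaluates (the standard result stated on the statproofbook page cited in the code; the notation here is mine) is

    \mathrm{KL}\bigl(\mathcal{N}(\mu_p, \Sigma_p) \,\|\, \mathcal{N}(\mu_q, \Sigma_q)\bigr)
      = \tfrac{1}{2} \Bigl[ \operatorname{tr}\bigl(\Sigma_q^{-1} \Sigma_p\bigr)
        + (\mu_p - \mu_q)^\top \Sigma_q^{-1} (\mu_p - \mu_q)
        - D
        + \ln \tfrac{\det \Sigma_q}{\det \Sigma_p} \Bigr]

With the Cholesky parameterization Sigma = L L^T used by MultivariateNormal, the trace and Mahalanobis terms are formed from L_q^{-1} (computed once with solve_triangular), and the log-determinant ratio reduces to twice the difference of the summed log-diagonals of L_p and L_q, which is what _single_mvn_kl computes.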


@dispatch(Beta, Beta)
def kl_divergence(p, q):
    # From https://en.wikipedia.org/wiki/Beta_distribution#Quantities_of_information_(entropy)
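A minimal usage sketch of the new dispatch (not part of the commit; the shapes and parameter values below are made up for illustration). Once both arguments are MultivariateNormal, the multipledispatch registration above is picked up automatically:

import numpy as np
import numpyro.distributions as dist
from numpyro.distributions.kl import kl_divergence

# Two batched 3-dimensional Gaussians with batch shape (2,).
p = dist.MultivariateNormal(
    loc=np.zeros((2, 3)),
    scale_tril=np.tile(np.eye(3), (2, 1, 1)),
)
q = dist.MultivariateNormal(
    loc=np.ones((2, 3)),
    covariance_matrix=np.tile(2.0 * np.eye(3), (2, 1, 1)),
)

kl = kl_divergence(p, q)  # one KL value per batch element, shape (2,)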
49 changes: 49 additions & 0 deletions test/test_distributions.py
@@ -2849,6 +2849,55 @@ def test_kl_expanded_normal(batch_shape, event_shape):
    assert_allclose(actual, expected)


@pytest.mark.parametrize("batch_shape", [(), (1,), (2, 3)], ids=str)
def test_kl_multivariate_normal_consistency_with_independent_normals(batch_shape):
event_shape = (5, )
shape = batch_shape + event_shape

def make_dists():
mus = np.random.normal(size=shape)
scales = np.exp(np.random.normal(size=shape))
scales = np.ones(shape)

def diagonalize(v, ignore_axes: int):
if ignore_axes == 0:
return jnp.diag(v)
return vmap(diagonalize, in_axes=(0, None))(v, ignore_axes - 1)
scale_tril = diagonalize(scales, len(batch_shape))
return (
dist.Normal(mus, scales).to_event(len(event_shape)),
dist.MultivariateNormal(mus, scale_tril=scale_tril)
)

p_uni, p_mvn = make_dists()
q_uni, q_mvn = make_dists()

actual = kl_divergence(
p_mvn, q_mvn
)
expected = kl_divergence(
p_uni, q_uni
)
assert_allclose(actual, expected, atol=1e-6)


def test_kl_multivariate_normal_nondiagonal_covariance():
    p_mvn = dist.MultivariateNormal(np.zeros(2), covariance_matrix=np.eye(2))
    q_mvn = dist.MultivariateNormal(
        np.ones(2),
        covariance_matrix=np.array([[2.0, 0.8], [0.8, 0.5]]),
    )

    actual = kl_divergence(p_mvn, q_mvn)
    # Reference value from the closed-form multivariate-normal KL expression.
    expected = 3.21138
    assert_allclose(actual, expected, atol=2e-5)
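
As a sanity check on the hard-coded reference value, plugging the test's parameters into the closed-form expression quoted under the kl.py change gives (rounded):

    det(Sigma_q) = 2 * 0.5 - 0.8^2 = 0.36
    tr(Sigma_q^{-1} Sigma_p) = tr(Sigma_q^{-1}) = (0.5 + 2) / 0.36 ≈ 6.9444
    (mu_p - mu_q)^T Sigma_q^{-1} (mu_p - mu_q) = (0.5 - 0.8 - 0.8 + 2) / 0.36 = 2.5
    KL ≈ 0.5 * (6.9444 + 2.5 - 2 + ln 0.36) ≈ 0.5 * 6.4228 ≈ 3.21138

which matches the expected value used above.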


@pytest.mark.parametrize("shape", [(), (4,), (2, 3)], ids=str)
@pytest.mark.parametrize(
"p_dist, q_dist",
