diff --git a/numpyro/distributions/kl.py b/numpyro/distributions/kl.py
index 2fcfecb2f..3d79dca01 100644
--- a/numpyro/distributions/kl.py
+++ b/numpyro/distributions/kl.py
@@ -85,12 +85,12 @@ def kl_divergence(p, q):
 
 @dispatch(Delta, Distribution)
 def kl_divergence(p, q):
-    return -q.log_prob(p.v)
+    return -q.log_prob(p.v) + p.log_density
 
 
 @dispatch(Delta, ExpandedDistribution)
 def kl_divergence(p, q):
-    return -q.log_prob(p.v)
+    return -q.log_prob(p.v) + p.log_density
 
 
 @dispatch(Independent, Independent)
diff --git a/test/test_distributions.py b/test/test_distributions.py
index dbcb21218..0362adf28 100644
--- a/test/test_distributions.py
+++ b/test/test_distributions.py
@@ -2814,6 +2814,15 @@ def test_kl_delta_normal_shape(batch_shape):
     assert kl_divergence(p, q).shape == batch_shape
 
 
+def test_kl_delta_normal():
+    v = np.random.normal()
+    loc = np.random.normal()
+    scale = np.exp(np.random.normal())
+    p = dist.Delta(v, 10.0)
+    q = dist.Normal(loc, scale)
+    assert_allclose(kl_divergence(p, q), 10.0 - q.log_prob(v))
+
+
 @pytest.mark.parametrize("batch_shape", [(), (4,), (2, 3)], ids=str)
 @pytest.mark.parametrize("event_shape", [(), (4,), (2, 3)], ids=str)
 def test_kl_independent_normal(batch_shape, event_shape):