Skip to content

Commit

Permalink
Fix Dirichlet.logp (#4454)
Browse files Browse the repository at this point in the history
* Fix Dirichlet.logp by checking number of categories > 1 only at event dims

* Update test_distributions.py

* Removed the shape validation check to even work for last dimensional shape as 1.

Modified the `test_dirichlet` function to check for the same.

* Added a test to check Dirichlet.logp with different batch shapes.

* Tested exact Dirichlet.logp values againt scipy implementation

Given a mention in RELEASE-NOTES.md
  • Loading branch information
Sayam753 authored Feb 6, 2021
1 parent b6660f9 commit 0c21de4
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 2 deletions.
1 change: 1 addition & 0 deletions RELEASE-NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
- `math.log1mexp_numpy` no longer raises RuntimeWarning when given very small inputs. These were commonly observed during NUTS sampling (see [#4428](https://github.com/pymc-devs/pymc3/pull/4428)).
- `ScalarSharedVariable` can now be used as an input to other RVs directly (see [#4445](https://github.com/pymc-devs/pymc3/pull/4445)).
- `pm.sample` and `pm.find_MAP` no longer change the `start` argument (see [#4458](https://github.com/pymc-devs/pymc3/pull/4458)).
- Fixed `Dirichlet.logp` method to work with unit batch or event shapes (see [#4454](https://github.com/pymc-devs/pymc3/pull/4454)).

## PyMC3 3.11.0 (21 January 2021)

Expand Down
1 change: 0 additions & 1 deletion pymc3/distributions/multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,7 +522,6 @@ def logp(self, value):
tt.sum(logpow(value, a - 1) - gammaln(a), axis=-1) + gammaln(tt.sum(a, axis=-1)),
tt.all(value >= 0),
tt.all(value <= 1),
np.logical_not(a.broadcastable),
tt.all(a > 0),
broadcast_conditions=False,
)
Expand Down
13 changes: 12 additions & 1 deletion pymc3/tests/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1692,10 +1692,21 @@ def test_lkj(self, x, eta, n, lp):
decimals = select_by_precision(float64=6, float32=4)
assert_almost_equal(model.fastlogp(pt), lp, decimal=decimals, err_msg=str(pt))

@pytest.mark.parametrize("n", [2, 3])
@pytest.mark.parametrize("n", [1, 2, 3])
def test_dirichlet(self, n):
self.pymc3_matches_scipy(Dirichlet, Simplex(n), {"a": Vector(Rplus, n)}, dirichlet_logpdf)

@pytest.mark.parametrize("dist_shape", [1, (2, 1), (1, 2), (2, 4, 3)])
def test_dirichlet_with_batch_shapes(self, dist_shape):
a = np.ones(dist_shape)
with pm.Model() as model:
d = pm.Dirichlet("a", a=a)

pymc3_res = d.distribution.logp(d.tag.test_value).eval()
for idx in np.ndindex(a.shape[:-1]):
scipy_res = scipy.stats.dirichlet(a[idx]).logpdf(d.tag.test_value[idx])
assert_almost_equal(pymc3_res[idx], scipy_res)

def test_dirichlet_shape(self):
a = tt.as_tensor_variable(np.r_[1, 2])
with pytest.warns(DeprecationWarning):
Expand Down

0 comments on commit 0c21de4

Please sign in to comment.