From c34ae3f54047b5c63a155d278bdda9c75fbdf650 Mon Sep 17 00:00:00 2001 From: Alexandre ANDORRA Date: Fri, 3 Apr 2020 18:02:47 +0200 Subject: [PATCH] Check that concentration parameters of Dirichlet distribution are all > 0 (#3853) * Added check that a>0 in Dirichlet * Cast a as array for tests * Test a>0 only when a not an RV and convert to array when list * Added test for init of Dirichlet with negative values * Added release note * Resolved conflict in release notes * Escaped parenthesis in match regexp --- RELEASE-NOTES.md | 1 + pymc3/distributions/multivariate.py | 10 ++++++++ pymc3/tests/test_distributions.py | 38 ++++++++++++++++++++++++----- 3 files changed, 43 insertions(+), 6 deletions(-) diff --git a/RELEASE-NOTES.md b/RELEASE-NOTES.md index 10639c14db..0fefc796da 100644 --- a/RELEASE-NOTES.md +++ b/RELEASE-NOTES.md @@ -20,6 +20,7 @@ - `pm.sample` now takes 1000 draws and 1000 tuning samples by default, instead of 500 previously (see [#3855](https://github.com/pymc-devs/pymc3/pull/3855)). - Dropped the outdated 'nuts' initialization method for `pm.sample` (see [#3863](https://github.com/pymc-devs/pymc3/pull/3863)). - Moved argument division out of `NegativeBinomial` `random` method. Fixes [#3864](https://github.com/pymc-devs/pymc3/issues/3864) in the style of [#3509](https://github.com/pymc-devs/pymc3/pull/3509). +- The Dirichlet distribution now raises a ValueError when it's initialized with <= 0 values (see [#3853](https://github.com/pymc-devs/pymc3/pull/3853)). ## PyMC3 3.8 (November 29 2019) diff --git a/pymc3/distributions/multivariate.py b/pymc3/distributions/multivariate.py index ece8eaf807..c198865738 100755 --- a/pymc3/distributions/multivariate.py +++ b/pymc3/distributions/multivariate.py @@ -488,6 +488,16 @@ class Dirichlet(Continuous): def __init__(self, a, transform=transforms.stick_breaking, *args, **kwargs): + + if not isinstance(a, pm.model.TensorVariable): + if not isinstance(a, list) and not isinstance(a, np.ndarray): + raise TypeError( + 'The vector of concentration parameters (a) must be a python list ' + 'or numpy array.') + a = np.array(a) + if (a <= 0).any(): + raise ValueError("All concentration parameters (a) must be > 0.") + shape = np.atleast_1d(a.shape)[-1] kwargs.setdefault("shape", shape) diff --git a/pymc3/tests/test_distributions.py b/pymc3/tests/test_distributions.py index 581b1a8358..9353405451 100644 --- a/pymc3/tests/test_distributions.py +++ b/pymc3/tests/test_distributions.py @@ -944,17 +944,43 @@ def test_lkj(self, x, eta, n, lp): @pytest.mark.parametrize('n', [2, 3]) def test_dirichlet(self, n): - self.pymc3_matches_scipy(Dirichlet, Simplex( - n), {'a': Vector(Rplus, n)}, dirichlet_logpdf) + self.pymc3_matches_scipy( + Dirichlet, + Simplex(n), + {'a': Vector(Rplus, n)}, + dirichlet_logpdf + ) + + @pytest.mark.parametrize('n', [3, 4]) + def test_dirichlet_init_fail(self, n): + with Model(): + with pytest.raises( + ValueError, + match=r"All concentration parameters \(a\) must be > 0." + ): + _ = Dirichlet('x', a=np.zeros(n), shape=n) + with pytest.raises( + ValueError, + match=r"All concentration parameters \(a\) must be > 0." + ): + _ = Dirichlet('x', a=np.array([-1.] * n), shape=n) def test_dirichlet_2D(self): - self.pymc3_matches_scipy(Dirichlet, MultiSimplex(2, 2), - {'a': Vector(Vector(Rplus, 2), 2)}, dirichlet_logpdf) + self.pymc3_matches_scipy( + Dirichlet, + MultiSimplex(2, 2), + {'a': Vector(Vector(Rplus, 2), 2)}, + dirichlet_logpdf + ) @pytest.mark.parametrize('n', [2, 3]) def test_multinomial(self, n): - self.pymc3_matches_scipy(Multinomial, Vector(Nat, n), {'p': Simplex(n), 'n': Nat}, - multinomial_logpdf) + self.pymc3_matches_scipy( + Multinomial, + Vector(Nat, n), + {'p': Simplex(n), 'n': Nat}, + multinomial_logpdf + ) @pytest.mark.parametrize('p,n', [ [[.25, .25, .25, .25], 1],