Skip to content

Commit

Permalink
Check that concentration parameters of Dirichlet distribution are all…
Browse files Browse the repository at this point in the history
… > 0 (#3853)

* Added check that a>0 in Dirichlet

* Cast a as array for tests

* Test a>0 only when a not an RV and convert to array when list

* Added test for init of Dirichlet with negative values

* Added release note

* Resolved conflict in release notes

* Escaped parenthesis in match regexp
  • Loading branch information
AlexAndorra authored Apr 3, 2020
1 parent 0456f39 commit c34ae3f
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 6 deletions.
1 change: 1 addition & 0 deletions RELEASE-NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
- `pm.sample` now takes 1000 draws and 1000 tuning samples by default, instead of 500 previously (see [#3855](https://github.com/pymc-devs/pymc3/pull/3855)).
- Dropped the outdated 'nuts' initialization method for `pm.sample` (see [#3863](https://github.com/pymc-devs/pymc3/pull/3863)).
- Moved argument division out of `NegativeBinomial` `random` method. Fixes [#3864](https://github.com/pymc-devs/pymc3/issues/3864) in the style of [#3509](https://github.com/pymc-devs/pymc3/pull/3509).
- The Dirichlet distribution now raises a ValueError when it's initialized with <= 0 values (see [#3853](https://github.com/pymc-devs/pymc3/pull/3853)).

## PyMC3 3.8 (November 29 2019)

Expand Down
10 changes: 10 additions & 0 deletions pymc3/distributions/multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,6 +488,16 @@ class Dirichlet(Continuous):

def __init__(self, a, transform=transforms.stick_breaking,
*args, **kwargs):

if not isinstance(a, pm.model.TensorVariable):
if not isinstance(a, list) and not isinstance(a, np.ndarray):
raise TypeError(
'The vector of concentration parameters (a) must be a python list '
'or numpy array.')
a = np.array(a)
if (a <= 0).any():
raise ValueError("All concentration parameters (a) must be > 0.")

shape = np.atleast_1d(a.shape)[-1]

kwargs.setdefault("shape", shape)
Expand Down
38 changes: 32 additions & 6 deletions pymc3/tests/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -944,17 +944,43 @@ def test_lkj(self, x, eta, n, lp):

@pytest.mark.parametrize('n', [2, 3])
def test_dirichlet(self, n):
self.pymc3_matches_scipy(Dirichlet, Simplex(
n), {'a': Vector(Rplus, n)}, dirichlet_logpdf)
self.pymc3_matches_scipy(
Dirichlet,
Simplex(n),
{'a': Vector(Rplus, n)},
dirichlet_logpdf
)

@pytest.mark.parametrize('n', [3, 4])
def test_dirichlet_init_fail(self, n):
with Model():
with pytest.raises(
ValueError,
match=r"All concentration parameters \(a\) must be > 0."
):
_ = Dirichlet('x', a=np.zeros(n), shape=n)
with pytest.raises(
ValueError,
match=r"All concentration parameters \(a\) must be > 0."
):
_ = Dirichlet('x', a=np.array([-1.] * n), shape=n)

def test_dirichlet_2D(self):
self.pymc3_matches_scipy(Dirichlet, MultiSimplex(2, 2),
{'a': Vector(Vector(Rplus, 2), 2)}, dirichlet_logpdf)
self.pymc3_matches_scipy(
Dirichlet,
MultiSimplex(2, 2),
{'a': Vector(Vector(Rplus, 2), 2)},
dirichlet_logpdf
)

@pytest.mark.parametrize('n', [2, 3])
def test_multinomial(self, n):
self.pymc3_matches_scipy(Multinomial, Vector(Nat, n), {'p': Simplex(n), 'n': Nat},
multinomial_logpdf)
self.pymc3_matches_scipy(
Multinomial,
Vector(Nat, n),
{'p': Simplex(n), 'n': Nat},
multinomial_logpdf
)

@pytest.mark.parametrize('p,n', [
[[.25, .25, .25, .25], 1],
Expand Down

0 comments on commit c34ae3f

Please sign in to comment.