Skip to content

Commit

Permalink
Add bound to HyperGeometric logp (resolves #4366) (#4367)
Browse files Browse the repository at this point in the history
* - Add bound to HyperGeometric logp
- Pass unit tests when scipy logpmf returns nan

* - Add release-note

* - Replace tt.max and tt.min with tt.switch
  • Loading branch information
ricardoV94 authored Dec 22, 2020
1 parent 0ec65e5 commit 0402aab
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 2 deletions.
1 change: 1 addition & 0 deletions RELEASE-NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ It also brings some dreadfully awaited fixes, so be sure to go through the chang
- The notebook gallery has been moved to https://github.com/pymc-devs/pymc-examples (see [#4348](https://github.com/pymc-devs/pymc3/pull/4348)).
- `math.logsumexp` now matches `scipy.special.logsumexp` when arrays contain infinite values (see [#4360](https://github.com/pymc-devs/pymc3/pull/4360)).
- Fixed mathematical formulation in `MvStudentT` random method. (see [#4359](https://github.com/pymc-devs/pymc3/pull/4359))
- Fix issue in `logp` method of `HyperGeometric`. It now returns `-inf` for invalid parameters (see [4367](https://github.com/pymc-devs/pymc3/pull/4367))

## PyMC3 3.10.0 (7 December 2020)

Expand Down
5 changes: 4 additions & 1 deletion pymc3/distributions/discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -930,7 +930,10 @@ def logp(self, value):
- betaln(n - value + 1, bad - n + value + 1)
- betaln(tot + 1, 1)
)
return result
# value in [max(0, n - N + k), min(k, n)]
lower = tt.switch(tt.gt(n - N + k, 0), n - N + k, 0)
upper = tt.switch(tt.lt(k, n), k, n)
return bound(result, lower <= value, value <= upper)


class DiscreteUniform(Discrete):
Expand Down
6 changes: 5 additions & 1 deletion pymc3/tests/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -805,11 +805,15 @@ def test_geometric(self):
)

def test_hypergeometric(self):
def modified_scipy_hypergeom_logpmf(value, N, k, n):
original_res = sp.hypergeom.logpmf(value, N, k, n)
return original_res if not np.isnan(original_res) else -np.inf

self.pymc3_matches_scipy(
HyperGeometric,
Nat,
{"N": NatSmall, "k": NatSmall, "n": NatSmall},
lambda value, N, k, n: sp.hypergeom.logpmf(value, N, k, n),
lambda value, N, k, n: modified_scipy_hypergeom_logpmf(value, N, k, n),
)

def test_negative_binomial(self):
Expand Down

0 comments on commit 0402aab

Please sign in to comment.