Skip to content

Commit

Permalink
adds beta-binomial mean and test cases (#5175)
Browse files Browse the repository at this point in the history
  • Loading branch information
morganstrom authored Nov 13, 2021
1 parent 073e26b commit 7485ccc
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 0 deletions.
6 changes: 6 additions & 0 deletions pymc/distributions/discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,12 @@ def dist(cls, alpha, beta, n, *args, **kwargs):
n = at.as_tensor_variable(intX(n))
return super().dist([n, alpha, beta], **kwargs)

def get_moment(rv, size, n, alpha, beta):
mean = at.round((n * alpha) / (alpha + beta))
if not rv_size_is_none(size):
mean = at.full(size, mean)
return mean

def logp(value, n, alpha, beta):
r"""
Calculate log-probability of BetaBinomial distribution at specified value.
Expand Down
16 changes: 16 additions & 0 deletions pymc/tests/test_distributions_moments.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from pymc.distributions import (
Bernoulli,
Beta,
BetaBinomial,
Binomial,
Cauchy,
ChiSquared,
Expand Down Expand Up @@ -212,6 +213,21 @@ def test_beta_moment(alpha, beta, size, expected):
assert_moment_is_expected(model, expected)


@pytest.mark.parametrize(
"n, alpha, beta, size, expected",
[
(10, 1, 1, None, 5),
(10, 1, 1, 5, np.full(5, 5)),
(10, 1, np.arange(1, 6), None, np.round(10 / np.arange(2, 7))),
(10, 1, np.arange(1, 6), (2, 5), np.full((2, 5), np.round(10 / np.arange(2, 7)))),
],
)
def test_beta_binomial_moment(alpha, beta, n, size, expected):
with Model() as model:
BetaBinomial("x", alpha=alpha, beta=beta, n=n, size=size)
assert_moment_is_expected(model, expected)


@pytest.mark.parametrize(
"nu, size, expected",
[
Expand Down

0 comments on commit 7485ccc

Please sign in to comment.