Skip to content

Commit

Permalink
add categorical moment (#5176)
Browse files Browse the repository at this point in the history
Co-authored-by: Farhan Reynaldo <farhanreynaldo@gmail.com>
  • Loading branch information
farhanreynaldo and farhanreynaldo authored Nov 13, 2021
1 parent 7485ccc commit 140dab0
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 5 deletions.
11 changes: 6 additions & 5 deletions pymc/distributions/discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -1167,13 +1167,14 @@ class Categorical(Discrete):
def dist(cls, p, **kwargs):

p = at.as_tensor_variable(floatX(p))

# mode = at.argmax(p, axis=-1)
# if mode.ndim == 1:
# mode = at.squeeze(mode)

return super().dist([p], **kwargs)

def get_moment(rv, size, p):
mode = at.argmax(p, axis=-1)
if not rv_size_is_none(size):
mode = at.full(size, mode)
return mode

def logp(value, p):
r"""
Calculate log-probability of Categorical distribution at specified value.
Expand Down
20 changes: 20 additions & 0 deletions pymc/tests/test_distributions_moments.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
Beta,
BetaBinomial,
Binomial,
Categorical,
Cauchy,
ChiSquared,
Constant,
Expand Down Expand Up @@ -728,3 +729,22 @@ def test_logitnormal_moment(mu, sigma, size, expected):
with Model() as model:
LogitNormal("x", mu=mu, sigma=sigma, size=size)
assert_moment_is_expected(model, expected)


@pytest.mark.parametrize(
"p, size, expected",
[
(np.array([0.1, 0.3, 0.6]), None, 2),
(np.array([0.6, 0.1, 0.3]), 5, np.full(5, 0)),
(np.full((2, 3), np.array([0.6, 0.1, 0.3])), None, [0, 0]),
(
np.full((2, 3), np.array([0.1, 0.3, 0.6])),
(3, 2),
np.full((3, 2), [2, 2]),
),
],
)
def test_categorical_moment(p, size, expected):
with Model() as model:
Categorical("x", p=p, size=size)
assert_moment_is_expected(model, expected)

0 comments on commit 140dab0

Please sign in to comment.