From 140dab0199dfb751951ba99175295c07feb00264 Mon Sep 17 00:00:00 2001 From: Farhan Reynaldo Date: Sat, 13 Nov 2021 22:41:00 +0700 Subject: [PATCH] add categorical moment (#5176) Co-authored-by: Farhan Reynaldo --- pymc/distributions/discrete.py | 11 ++++++----- pymc/tests/test_distributions_moments.py | 20 ++++++++++++++++++++ 2 files changed, 26 insertions(+), 5 deletions(-) diff --git a/pymc/distributions/discrete.py b/pymc/distributions/discrete.py index e190bf7a12e..bfe03f81160 100644 --- a/pymc/distributions/discrete.py +++ b/pymc/distributions/discrete.py @@ -1167,13 +1167,14 @@ class Categorical(Discrete): def dist(cls, p, **kwargs): p = at.as_tensor_variable(floatX(p)) - - # mode = at.argmax(p, axis=-1) - # if mode.ndim == 1: - # mode = at.squeeze(mode) - return super().dist([p], **kwargs) + def get_moment(rv, size, p): + mode = at.argmax(p, axis=-1) + if not rv_size_is_none(size): + mode = at.full(size, mode) + return mode + def logp(value, p): r""" Calculate log-probability of Categorical distribution at specified value. diff --git a/pymc/tests/test_distributions_moments.py b/pymc/tests/test_distributions_moments.py index 8897080b86b..22269736274 100644 --- a/pymc/tests/test_distributions_moments.py +++ b/pymc/tests/test_distributions_moments.py @@ -8,6 +8,7 @@ Beta, BetaBinomial, Binomial, + Categorical, Cauchy, ChiSquared, Constant, @@ -728,3 +729,22 @@ def test_logitnormal_moment(mu, sigma, size, expected): with Model() as model: LogitNormal("x", mu=mu, sigma=sigma, size=size) assert_moment_is_expected(model, expected) + + +@pytest.mark.parametrize( + "p, size, expected", + [ + (np.array([0.1, 0.3, 0.6]), None, 2), + (np.array([0.6, 0.1, 0.3]), 5, np.full(5, 0)), + (np.full((2, 3), np.array([0.6, 0.1, 0.3])), None, [0, 0]), + ( + np.full((2, 3), np.array([0.1, 0.3, 0.6])), + (3, 2), + np.full((3, 2), [2, 2]), + ), + ], +) +def test_categorical_moment(p, size, expected): + with Model() as model: + Categorical("x", p=p, size=size) + assert_moment_is_expected(model, expected)