Skip to content

Commit

Permalink
Parametrize Binomial and Categorical distributions via logit_p (#5637)
Browse files Browse the repository at this point in the history
  • Loading branch information
purna135 authored Mar 21, 2022
1 parent 5b31ec7 commit 52682eb
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 6 deletions.
1 change: 1 addition & 0 deletions RELEASE-NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ This includes API changes we did not warn about since at least `3.11.0` (2021-01
- Adding support for blackjax's NUTS sampler `pymc.sampling_jax` (see [#5477](ihttps://github.com/pymc-devs/pymc/pull/5477))
- `pymc.sampling_jax` samplers support `log_likelihood`, `observed_data`, and `sample_stats` in returned InferenceData object (see [#5189](https://github.com/pymc-devs/pymc/pull/5189))
- Adding support for `pm.Deterministic` in `pymc.sampling_jax` (see [#5182](https://github.com/pymc-devs/pymc/pull/5182))
- Added an alternative parametrization, `logit_p` to `pm.Binomial` and `pm.Categorical` distributions (see [5637](https://github.com/pymc-devs/pymc/pull/5637)).
- ...


Expand Down
30 changes: 25 additions & 5 deletions pymc/distributions/discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,15 +106,25 @@ class Binomial(Discrete):
Parameters
----------
n: int
n : int
Number of Bernoulli trials (n >= 0).
p: float
p : float
Probability of success in each trial (0 < p < 1).
logit_p : float
Alternative log odds for the probability of success.
"""
rv_op = binomial

@classmethod
def dist(cls, n, p, *args, **kwargs):
def dist(cls, n, p=None, logit_p=None, *args, **kwargs):
if p is not None and logit_p is not None:
raise ValueError("Incompatible parametrization. Can't specify both p and logit_p.")
elif p is None and logit_p is None:
raise ValueError("Incompatible parametrization. Must specify either p or logit_p.")

if logit_p is not None:
p = at.sigmoid(logit_p)

n = at.as_tensor_variable(intX(n))
p = at.as_tensor_variable(floatX(p))
return super().dist([n, p], **kwargs)
Expand Down Expand Up @@ -1245,14 +1255,24 @@ class Categorical(Discrete):
Parameters
----------
p: array of floats
p : array of floats
p > 0 and the elements of p must sum to 1. They will be automatically
rescaled otherwise.
logit_p : float
Alternative log odds for the probability of success.
"""
rv_op = categorical

@classmethod
def dist(cls, p, **kwargs):
def dist(cls, p=None, logit_p=None, **kwargs):
if p is not None and logit_p is not None:
raise ValueError("Incompatible parametrization. Can't specify both p and logit_p.")
elif p is None and logit_p is None:
raise ValueError("Incompatible parametrization. Must specify either p or logit_p.")

if logit_p is not None:
p = pm.math.softmax(logit_p, axis=-1)

if isinstance(p, np.ndarray) or isinstance(p, list):
if (np.asarray(p) < 0).any():
raise ValueError(f"Negative `p` parameters are not valid, got: {p}")
Expand Down
45 changes: 44 additions & 1 deletion pymc/tests/test_distributions_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def random_polyagamma(*args, **kwargs):
raise RuntimeError("polyagamma package is not installed!")


from scipy.special import expit
from scipy.special import expit, softmax

import pymc as pm

Expand Down Expand Up @@ -1006,6 +1006,25 @@ class TestBinomial(BaseTestDistributionRandom):
checks_to_run = ["check_pymc_params_match_rv_op"]


class TestLogitBinomial(BaseTestDistributionRandom):
pymc_dist = pm.Binomial
pymc_dist_params = {"n": 100, "logit_p": 0.5}
expected_rv_op_params = {"n": 100, "p": expit(0.5)}
tests_to_run = ["check_pymc_params_match_rv_op"]

@pytest.mark.parametrize(
"n, p, logit_p, expected",
[
(5, None, None, "Must specify either p or logit_p."),
(5, 0.5, 0.5, "Can't specify both p and logit_p."),
],
)
def test_binomial_init_fail(self, n, p, logit_p, expected):
with pm.Model() as model:
with pytest.raises(ValueError, match=f"Incompatible parametrization. {expected}"):
pm.Binomial("x", n=n, p=p, logit_p=logit_p)


class TestNegativeBinomial(BaseTestDistributionRandom):
pymc_dist = pm.NegativeBinomial
pymc_dist_params = {"n": 100, "p": 0.33}
Expand Down Expand Up @@ -1411,6 +1430,30 @@ class TestCategorical(BaseTestDistributionRandom):
]


class TestLogitCategorical(BaseTestDistributionRandom):
pymc_dist = pm.Categorical
pymc_dist_params = {"logit_p": np.array([[0.28, 0.62, 0.10], [0.28, 0.62, 0.10]])}
expected_rv_op_params = {
"p": softmax(np.array([[0.28, 0.62, 0.10], [0.28, 0.62, 0.10]]), axis=-1)
}
tests_to_run = [
"check_pymc_params_match_rv_op",
"check_rv_size",
]

@pytest.mark.parametrize(
"p, logit_p, expected",
[
(None, None, "Must specify either p or logit_p."),
(0.5, 0.5, "Can't specify both p and logit_p."),
],
)
def test_categorical_init_fail(self, p, logit_p, expected):
with pm.Model() as model:
with pytest.raises(ValueError, match=f"Incompatible parametrization. {expected}"):
pm.Categorical("x", p=p, logit_p=logit_p)


class TestGeometric(BaseTestDistributionRandom):
pymc_dist = pm.Geometric
pymc_dist_params = {"p": 0.9}
Expand Down

0 comments on commit 52682eb

Please sign in to comment.