Skip to content

Commit

Permalink
Monkeypatch instructions for migrating away from the old rv.logp API
Browse files Browse the repository at this point in the history
  • Loading branch information
michaelosthege authored and ricardoV94 committed Sep 23, 2021
1 parent 55d455a commit 8e54fc9
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 0 deletions.
13 changes: 13 additions & 0 deletions pymc3/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,13 @@ def get_moment(op, rv, size, *rv_inputs):
return new_cls


def _make_nice_attr_error(oldcode: str, newcode: str):
def fn(*args, **kwargs):
raise AttributeError(f"The `{oldcode}` method was removed. Instead use `{newcode}`.`")

return fn


class Distribution(metaclass=DistributionMeta):
"""Statistical distribution"""

Expand Down Expand Up @@ -243,6 +250,9 @@ def __new__(
functools.partial(str_for_dist, formatting="latex"), rv_out
)

rv_out.logp = _make_nice_attr_error("rv.logp(x)", "pm.logp(rv, x)")
rv_out.logcdf = _make_nice_attr_error("rv.logcdf(x)", "pm.logcdf(rv, x)")
rv_out.random = _make_nice_attr_error("rv.random()", "rv.eval()")
return rv_out

@classmethod
Expand Down Expand Up @@ -333,6 +343,9 @@ def dist(
rv_out.update = (rng, new_rng)
rng.default_update = new_rng

rv_out.logp = _make_nice_attr_error("rv.logp(x)", "pm.logp(rv, x)")
rv_out.logcdf = _make_nice_attr_error("rv.logcdf(x)", "pm.logcdf(rv, x)")
rv_out.random = _make_nice_attr_error("rv.random()", "rv.eval()")
return rv_out


Expand Down
25 changes: 25 additions & 0 deletions pymc3/tests/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3262,3 +3262,28 @@ def test_distinct_rvs():
pp_samples_2 = pm.sample_prior_predictive(samples=2)

assert np.array_equal(pp_samples["y"], pp_samples_2["y"])


@pytest.mark.parametrize(
"method,newcode",
[
("logp", r"pm.logp\(rv, x\)"),
("logcdf", r"pm.logcdf\(rv, x\)"),
("random", r"rv.eval\(\)"),
],
)
def test_logp_gives_migration_instructions(method, newcode):
rv = pm.Normal.dist()
f = getattr(rv, method)
with pytest.raises(AttributeError, match=rf"use `{newcode}`"):
f()

# A dim-induced resize of the rv created by the `.dist()` API,
# happening in Distribution.__new__ would make us loose the monkeypatches.
# So this triggers it to test if the monkeypatch still works.
with pm.Model(coords={"year": [2019, 2021, 2022]}):
rv = pm.Normal("n", dims="year")
f = getattr(rv, method)
with pytest.raises(AttributeError, match=rf"use `{newcode}`"):
f()
pass

0 comments on commit 8e54fc9

Please sign in to comment.