Skip to content

Commit

Permalink
Remove deprecated Distribution kwargs (#7488)
Browse files Browse the repository at this point in the history
* Remove deprecated Distribution kwargs

Removing these to reduce cognitive load for an
eventual migration to a function-based
distribution.

These were deprecated in #5109, in 2021, as part
of pymc 4 being released. We're on 5.x so these
should be safe.

* Replace deprecated testval arg with initval in test

I keep the test, since it seems to cover behaviour not tested elsewhere.
  • Loading branch information
thomasaarholt authored Sep 3, 2024
1 parent 253513b commit d596afb
Show file tree
Hide file tree
Showing 4 changed files with 2 additions and 100 deletions.
44 changes: 0 additions & 44 deletions pymc/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,19 +97,6 @@ class DistributionMeta(ABCMeta):
"""

def __new__(cls, name, bases, clsdict):
# Forcefully deprecate old v3 `Distribution`s
if "random" in clsdict:

def _random(*args, **kwargs):
warnings.warn(
"The old `Distribution.random` interface is deprecated.",
FutureWarning,
stacklevel=2,
)
return clsdict["random"](*args, **kwargs)

clsdict["random"] = _random

rv_op = clsdict.setdefault("rv_op", None)
rv_type = clsdict.setdefault("rv_type", None)

Expand Down Expand Up @@ -206,13 +193,6 @@ def support_point(op, rv, *dist_params):
return new_cls


def _make_nice_attr_error(oldcode: str, newcode: str):
def fn(*args, **kwargs):
raise AttributeError(f"The `{oldcode}` method was removed. Instead use `{newcode}`.`")

return fn


class _class_or_instancemethod(classmethod):
"""Allow a method to be called both as a classmethod and an instancemethod,
giving priority to the instancemethod.
Expand Down Expand Up @@ -510,14 +490,6 @@ def __new__(
"for a standalone distribution."
)

if "testval" in kwargs:
initval = kwargs.pop("testval")
warnings.warn(
"The `testval` argument is deprecated; use `initval`.",
FutureWarning,
stacklevel=2,
)

if not isinstance(name, string_types):
raise TypeError(f"Name needs to be a string but got: {name}")

Expand Down Expand Up @@ -551,10 +523,6 @@ def __new__(
rv_out._repr_latex_ = types.MethodType(
functools.partial(str_for_dist, formatting="latex"), rv_out
)

rv_out.logp = _make_nice_attr_error("rv.logp(x)", "pm.logp(rv, x)")
rv_out.logcdf = _make_nice_attr_error("rv.logcdf(x)", "pm.logcdf(rv, x)")
rv_out.random = _make_nice_attr_error("rv.random()", "pm.draw(rv)")
return rv_out

@classmethod
Expand Down Expand Up @@ -582,15 +550,6 @@ def dist(
rv : TensorVariable
The created random variable tensor.
"""
if "testval" in kwargs:
kwargs.pop("testval")
warnings.warn(
"The `.dist(testval=...)` argument is deprecated and has no effect. "
"Initial values for sampling/optimization can be specified with `initval` in a modelcontext. "
"For using PyTensor's test value features, you must assign the `.tag.test_value` yourself.",
FutureWarning,
stacklevel=2,
)
if "initval" in kwargs:
raise TypeError(
"Unexpected keyword argument `initval`. "
Expand All @@ -617,9 +576,6 @@ def dist(
create_size = find_size(shape=shape, size=size, ndim_supp=ndim_supp)
rv_out = cls.rv_op(*dist_params, size=create_size, **kwargs)

rv_out.logp = _make_nice_attr_error("rv.logp(x)", "pm.logp(rv, x)")
rv_out.logcdf = _make_nice_attr_error("rv.logcdf(x)", "pm.logcdf(rv, x)")
rv_out.random = _make_nice_attr_error("rv.random()", "pm.draw(rv)")
_add_future_warning_tag(rv_out)
return rv_out

Expand Down
25 changes: 0 additions & 25 deletions tests/distributions/test_distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,31 +84,6 @@ def test_issue_4499(self):
npt.assert_almost_equal(m.compile_logp()({"x": np.ones(10)}), 0 * 10)


@pytest.mark.parametrize(
"method,newcode",
[
("logp", r"pm.logp\(rv, x\)"),
("logcdf", r"pm.logcdf\(rv, x\)"),
("random", r"pm.draw\(rv\)"),
],
)
def test_logp_gives_migration_instructions(method, newcode):
rv = pm.Normal.dist()
f = getattr(rv, method)
with pytest.raises(AttributeError, match=rf"use `{newcode}`"):
f()

# A dim-induced resize of the rv created by the `.dist()` API,
# happening in Distribution.__new__ would make us loose the monkeypatches.
# So this triggers it to test if the monkeypatch still works.
with pm.Model(coords={"year": [2019, 2021, 2022]}):
rv = pm.Normal("n", dims="year")
f = getattr(rv, method)
with pytest.raises(AttributeError, match=rf"use `{newcode}`"):
f()
pass


def test_all_distributions_have_support_points():
import pymc.distributions as dist_module

Expand Down
4 changes: 2 additions & 2 deletions tests/model/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -660,8 +660,8 @@ def test_initial_point():

b_initval = np.array(0.3, dtype=pytensor.config.floatX)

with pytest.warns(FutureWarning), model:
b = pm.Uniform("b", testval=b_initval)
with model:
b = pm.Uniform("b", initval=b_initval)

b_initval_trans = model.rvs_to_transforms[b].forward(b_initval, *b.owner.inputs).eval()

Expand Down
29 changes: 0 additions & 29 deletions tests/test_initial_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# limitations under the License.
import cloudpickle
import numpy as np
import numpy.testing as npt
import pytensor
import pytensor.tensor as pt
import pytest
Expand All @@ -34,34 +33,6 @@ def transform_back(rv, transformed, model) -> np.ndarray:
return model.rvs_to_transforms[rv].backward(transformed, *rv.owner.inputs).eval()


class TestInitvalAssignment:
def test_dist_warnings_and_errors(self):
with pytest.warns(FutureWarning, match="argument is deprecated and has no effect"):
rv = pm.Exponential.dist(lam=1, testval=0.5)
assert not hasattr(rv.tag, "test_value")

with pytest.raises(TypeError, match="Unexpected keyword argument `initval`."):
pm.Normal.dist(1, 2, initval=None)
pass

def test_new_warnings(self):
with pm.Model() as pmodel:
with pytest.warns(FutureWarning, match="`testval` argument is deprecated"):
rv = pm.Uniform("u", 0, 1, testval=0.75)
initial_point = pmodel.initial_point(random_seed=0)
npt.assert_allclose(
initial_point["u_interval__"], transform_fwd(rv, 0.75, model=pmodel)
)
assert not hasattr(rv.tag, "test_value")
pass

def test_valid_string_strategy(self):
with pm.Model() as pmodel:
pm.Uniform("x", 0, 1, size=2, initval="unknown")
with pytest.raises(ValueError, match="Invalid string strategy: unknown"):
pmodel.initial_point(random_seed=0)


class TestInitvalEvaluation:
def test_make_initial_point_fns_per_chain_checks_kwargs(self):
with pm.Model() as pmodel:
Expand Down

0 comments on commit d596afb

Please sign in to comment.