Skip to content

Commit 5d7283e

Browse files
committed
Add FutureWarning when accessing RV mappings via tag
1 parent e470d13 commit 5d7283e

File tree

6 files changed

+135
-2
lines changed

6 files changed

+135
-2
lines changed

pymc/distributions/distribution.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@
4949
shape_from_dims,
5050
)
5151
from pymc.printing import str_for_dist
52-
from pymc.util import UNSET
52+
from pymc.util import UNSET, _add_future_warning_tag
5353
from pymc.vartypes import string_types
5454

5555
__all__ = [
@@ -371,6 +371,7 @@ def dist(
371371
rv_out.logp = _make_nice_attr_error("rv.logp(x)", "pm.logp(rv, x)")
372372
rv_out.logcdf = _make_nice_attr_error("rv.logcdf(x)", "pm.logcdf(rv, x)")
373373
rv_out.random = _make_nice_attr_error("rv.random()", "pm.draw(rv)")
374+
_add_future_warning_tag(rv_out)
374375
return rv_out
375376

376377

pymc/distributions/shape_utils.py

+2
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050

5151
from pymc.aesaraf import PotentialShapeType
5252
from pymc.exceptions import ShapeError
53+
from pymc.util import _add_future_warning_tag
5354

5455

5556
def to_tuple(shape):
@@ -600,6 +601,7 @@ def change_dist_size(
600601
new_size = tuple(new_size) # type: ignore
601602

602603
new_dist = _change_dist_size(dist.owner.op, dist, new_size=new_size, expand=expand)
604+
_add_future_warning_tag(new_dist)
603605

604606
new_dist.name = dist.name
605607
for k, v in dist.tag.__dict__.items():

pymc/model.py

+3
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@
6767
from pymc.util import (
6868
UNSET,
6969
WithMemoization,
70+
_add_future_warning_tag,
7071
get_transformed_name,
7172
get_value_vars_from_user_vars,
7273
get_var_name,
@@ -1314,6 +1315,7 @@ def register_rv(
13141315
"""
13151316
name = self.name_for(name)
13161317
rv_var.name = name
1318+
_add_future_warning_tag(rv_var)
13171319
rv_var.tag.total_size = total_size
13181320
self.rvs_to_total_sizes[rv_var] = total_size
13191321

@@ -1495,6 +1497,7 @@ def create_value_var(
14951497
if aesara.config.compute_test_value != "off":
14961498
value_var.tag.test_value = rv_var.tag.test_value
14971499

1500+
_add_future_warning_tag(value_var)
14981501
rv_var.tag.value_var = value_var
14991502

15001503
# Make the value variable a transformed value variable,

pymc/tests/distributions/test_distribution.py

+39-1
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,9 @@
2828

2929
from pymc.distributions import DiracDelta, Flat, MvNormal, MvStudentT, logp
3030
from pymc.distributions.distribution import SymbolicRandomVariable, _moment, moment
31-
from pymc.distributions.shape_utils import to_tuple
31+
from pymc.distributions.shape_utils import change_dist_size, to_tuple
3232
from pymc.tests.distributions.util import assert_moment_is_expected
33+
from pymc.util import _FutureWarningValidatingScratchpad
3334

3435

3536
class TestBugfixes:
@@ -358,3 +359,40 @@ class TestSymbolicRV(SymbolicRandomVariable):
358359
dirac_delta_2_ = DiracDelta.dist(10)
359360
node = TestSymbolicRV([], [dirac_delta_1_, dirac_delta_2_], ndim_supp=0)().owner
360361
assert get_measurable_outputs(node.op, node) == [node.outputs[default_output_idx]]
362+
363+
364+
def test_tag_future_warning_dist():
365+
# Test no unexpected warnings
366+
with warnings.catch_warnings():
367+
warnings.simplefilter("error")
368+
369+
x = pm.Normal.dist()
370+
assert isinstance(x.tag, _FutureWarningValidatingScratchpad)
371+
372+
x.tag.banana = "banana"
373+
assert x.tag.banana == "banana"
374+
375+
# Check we didn't break test_value filtering
376+
x.tag.test_value = np.array(1)
377+
assert x.tag.test_value == 1
378+
with pytest.raises(TypeError, match="Wrong number of dimensions"):
379+
x.tag.test_value = np.array([1, 1])
380+
assert x.tag.test_value == 1
381+
382+
# No warning if deprecated attribute is not present
383+
with pytest.raises(AttributeError):
384+
x.tag.value_var
385+
386+
# Warning if present
387+
x.tag.value_var = "1"
388+
with pytest.warns(FutureWarning, match="Use model.rvs_to_values"):
389+
value_var = x.tag.value_var
390+
assert value_var == "1"
391+
392+
# Check that PyMC method that copies tag contents does not erase special tag
393+
new_x = change_dist_size(x, new_size=5)
394+
assert new_x.tag is not x.tag
395+
assert isinstance(new_x.tag, _FutureWarningValidatingScratchpad)
396+
with pytest.warns(FutureWarning, match="Use model.rvs_to_values"):
397+
value_var = new_x.tag.value_var
398+
assert value_var == "1"

pymc/tests/test_model.py

+58
Original file line numberDiff line numberDiff line change
@@ -41,10 +41,12 @@
4141
from pymc.blocking import DictToArrayBijection, RaveledVars
4242
from pymc.distributions import Normal, transforms
4343
from pymc.distributions.logprob import _joint_logp
44+
from pymc.distributions.transforms import log
4445
from pymc.exceptions import ImputationWarning, ShapeError, ShapeWarning
4546
from pymc.model import Point, ValueGradFunction, modelcontext
4647
from pymc.tests.helpers import SeededTest
4748
from pymc.tests.models import simple_model
49+
from pymc.util import _FutureWarningValidatingScratchpad
4850

4951

5052
class NewModel(pm.Model):
@@ -1406,3 +1408,59 @@ def test_deterministic(self):
14061408
assert np.all(
14071409
np.isclose(model.compile_logp(sum=False)({}), st.norm().logpdf(data_values))
14081410
)
1411+
1412+
1413+
def test_tag_future_warning_model():
1414+
# Test no unexpected warnings
1415+
with warnings.catch_warnings():
1416+
warnings.simplefilter("error")
1417+
1418+
model = pm.Model()
1419+
1420+
x = at.random.normal()
1421+
x.tag.something_else = "5"
1422+
x.tag.test_value = 0
1423+
assert not isinstance(x.tag, _FutureWarningValidatingScratchpad)
1424+
1425+
# Test that model changes the tag type, but copies exsiting contents
1426+
x = model.register_rv(x, name="x", transform=log)
1427+
assert isinstance(x.tag, _FutureWarningValidatingScratchpad)
1428+
assert x.tag.something_else == "5"
1429+
assert x.tag.test_value == 0
1430+
1431+
# Test expected warnings
1432+
with pytest.warns(FutureWarning, match="model.rvs_to_values"):
1433+
x_value = x.tag.value_var
1434+
1435+
assert isinstance(x_value.tag, _FutureWarningValidatingScratchpad)
1436+
with pytest.warns(FutureWarning, match="model.rvs_to_transforms"):
1437+
transform = x_value.tag.transform
1438+
assert transform is log
1439+
1440+
with pytest.raises(AttributeError):
1441+
x.tag.observations
1442+
1443+
with pytest.warns(FutureWarning, match="model.rvs_to_total_sizes"):
1444+
total_size = x.tag.total_size
1445+
assert total_size is None
1446+
1447+
# Cloning a node will keep the same tag type and contents
1448+
y = x.owner.clone().default_output()
1449+
assert y is not x
1450+
assert y.tag is not x.tag
1451+
assert isinstance(y.tag, _FutureWarningValidatingScratchpad)
1452+
y = model.register_rv(y, name="y", data=5)
1453+
assert isinstance(y.tag, _FutureWarningValidatingScratchpad)
1454+
1455+
# Test expected warnings
1456+
with pytest.warns(FutureWarning, match="model.rvs_to_values"):
1457+
y_value = y.tag.value_var
1458+
with pytest.warns(FutureWarning, match="model.rvs_to_values"):
1459+
y_obs = y.tag.observations
1460+
assert y_value is y_obs
1461+
assert y_value.eval() == 5
1462+
1463+
assert isinstance(y_value.tag, _FutureWarningValidatingScratchpad)
1464+
with pytest.warns(FutureWarning, match="model.rvs_to_total_sizes"):
1465+
total_size = y.tag.total_size
1466+
assert total_size is None

pymc/util.py

+31
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
import functools
16+
import warnings
1617

1718
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast
1819

@@ -23,6 +24,7 @@
2324

2425
from aesara import Variable
2526
from aesara.compile import SharedVariable
27+
from aesara.graph.utils import ValidatingScratchpad
2628
from cachetools import LRUCache, cachedmethod
2729

2830

@@ -481,3 +483,32 @@ def get_value_vars_from_user_vars(
481483
)
482484

483485
return value_vars
486+
487+
488+
class _FutureWarningValidatingScratchpad(ValidatingScratchpad):
489+
def __getattribute__(self, name):
490+
for deprecated_names, alternative in (
491+
(("value_var", "observations"), "model.rvs_to_values[rv]"),
492+
(("transform",), "model.rvs_to_transforms[rv]"),
493+
(("total_size",), "model.rvs_to_total_sizes[rv]"),
494+
):
495+
if name in deprecated_names:
496+
try:
497+
super().__getattribute__(name)
498+
except AttributeError:
499+
pass
500+
else:
501+
warnings.warn(
502+
f"The tag attribute {name} is deprecated. Use {alternative} instead",
503+
FutureWarning,
504+
)
505+
return super().__getattribute__(name)
506+
507+
508+
def _add_future_warning_tag(var) -> None:
509+
old_tag = var.tag
510+
if not isinstance(old_tag, _FutureWarningValidatingScratchpad):
511+
new_tag = _FutureWarningValidatingScratchpad("test_value", var.type.filter)
512+
for k, v in old_tag.__dict__.items():
513+
new_tag.__dict__.setdefault(k, v)
514+
var.tag = new_tag

0 commit comments

Comments
 (0)