|
41 | 41 | from pymc.blocking import DictToArrayBijection, RaveledVars
|
42 | 42 | from pymc.distributions import Normal, transforms
|
43 | 43 | from pymc.distributions.logprob import _joint_logp
|
| 44 | +from pymc.distributions.transforms import log |
44 | 45 | from pymc.exceptions import ImputationWarning, ShapeError, ShapeWarning
|
45 | 46 | from pymc.model import Point, ValueGradFunction, modelcontext
|
46 | 47 | from pymc.tests.helpers import SeededTest
|
47 | 48 | from pymc.tests.models import simple_model
|
| 49 | +from pymc.util import _FutureWarningValidatingScratchpad |
48 | 50 |
|
49 | 51 |
|
50 | 52 | class NewModel(pm.Model):
|
@@ -1406,3 +1408,59 @@ def test_deterministic(self):
|
1406 | 1408 | assert np.all(
|
1407 | 1409 | np.isclose(model.compile_logp(sum=False)({}), st.norm().logpdf(data_values))
|
1408 | 1410 | )
|
| 1411 | + |
| 1412 | + |
| 1413 | +def test_tag_future_warning_model(): |
| 1414 | + # Test no unexpected warnings |
| 1415 | + with warnings.catch_warnings(): |
| 1416 | + warnings.simplefilter("error") |
| 1417 | + |
| 1418 | + model = pm.Model() |
| 1419 | + |
| 1420 | + x = at.random.normal() |
| 1421 | + x.tag.something_else = "5" |
| 1422 | + x.tag.test_value = 0 |
| 1423 | + assert not isinstance(x.tag, _FutureWarningValidatingScratchpad) |
| 1424 | + |
| 1425 | + # Test that model changes the tag type, but copies exsiting contents |
| 1426 | + x = model.register_rv(x, name="x", transform=log) |
| 1427 | + assert isinstance(x.tag, _FutureWarningValidatingScratchpad) |
| 1428 | + assert x.tag.something_else == "5" |
| 1429 | + assert x.tag.test_value == 0 |
| 1430 | + |
| 1431 | + # Test expected warnings |
| 1432 | + with pytest.warns(FutureWarning, match="model.rvs_to_values"): |
| 1433 | + x_value = x.tag.value_var |
| 1434 | + |
| 1435 | + assert isinstance(x_value.tag, _FutureWarningValidatingScratchpad) |
| 1436 | + with pytest.warns(FutureWarning, match="model.rvs_to_transforms"): |
| 1437 | + transform = x_value.tag.transform |
| 1438 | + assert transform is log |
| 1439 | + |
| 1440 | + with pytest.raises(AttributeError): |
| 1441 | + x.tag.observations |
| 1442 | + |
| 1443 | + with pytest.warns(FutureWarning, match="model.rvs_to_total_sizes"): |
| 1444 | + total_size = x.tag.total_size |
| 1445 | + assert total_size is None |
| 1446 | + |
| 1447 | + # Cloning a node will keep the same tag type and contents |
| 1448 | + y = x.owner.clone().default_output() |
| 1449 | + assert y is not x |
| 1450 | + assert y.tag is not x.tag |
| 1451 | + assert isinstance(y.tag, _FutureWarningValidatingScratchpad) |
| 1452 | + y = model.register_rv(y, name="y", data=5) |
| 1453 | + assert isinstance(y.tag, _FutureWarningValidatingScratchpad) |
| 1454 | + |
| 1455 | + # Test expected warnings |
| 1456 | + with pytest.warns(FutureWarning, match="model.rvs_to_values"): |
| 1457 | + y_value = y.tag.value_var |
| 1458 | + with pytest.warns(FutureWarning, match="model.rvs_to_values"): |
| 1459 | + y_obs = y.tag.observations |
| 1460 | + assert y_value is y_obs |
| 1461 | + assert y_value.eval() == 5 |
| 1462 | + |
| 1463 | + assert isinstance(y_value.tag, _FutureWarningValidatingScratchpad) |
| 1464 | + with pytest.warns(FutureWarning, match="model.rvs_to_total_sizes"): |
| 1465 | + total_size = y.tag.total_size |
| 1466 | + assert total_size is None |
0 commit comments