Skip to content

Commit

Permalink
UserWarning if doing predictive sampling with models containing Poten…
Browse files Browse the repository at this point in the history
…tials (#4419)

* Raise warning when sampling with Potentials

* Add warning to fast_sample_ppc and unittests

* Add release note

* Avoid sampling in unittests
  • Loading branch information
ricardoV94 authored Jan 16, 2021
1 parent 2a3d9a3 commit 4e2c099
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 0 deletions.
1 change: 1 addition & 0 deletions RELEASE-NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ It also brings some dreadfully awaited fixes, so be sure to go through the chang
- Fixed `MatrixNormal` random method to work with parameters as random variables. (see [#4368](https://github.com/pymc-devs/pymc3/pull/4368))
- Update the `logcdf` method of several continuous distributions to return -inf for invalid parameters and values, and raise an informative error when multiple values cannot be evaluated in a single call. (see [4393](https://github.com/pymc-devs/pymc3/pull/4393))
- Improve numerical stability in `logp` and `logcdf` methods of `ExGaussian` (see [#4407](https://github.com/pymc-devs/pymc3/pull/4407))
- Issue UserWarning when doing prior or posterior predictive sampling with models containing Potential factors (see [#4419](https://github.com/pymc-devs/pymc3/pull/4419))

## PyMC3 3.10.0 (7 December 2020)

Expand Down
8 changes: 8 additions & 0 deletions pymc3/distributions/posterior_predictive.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,14 @@ def fast_sample_posterior_predictive(

model = modelcontext(model)
assert model is not None

if model.potentials:
warnings.warn(
"The effect of Potentials on other parameters is ignored during posterior predictive sampling. "
"This is likely to lead to invalid or biased predictive samples.",
UserWarning,
)

with model:

if keep_size and samples is not None:
Expand Down
23 changes: 23 additions & 0 deletions pymc3/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -1692,6 +1692,13 @@ def sample_posterior_predictive(

model = modelcontext(model)

if model.potentials:
warnings.warn(
"The effect of Potentials on other parameters is ignored during posterior predictive sampling. "
"This is likely to lead to invalid or biased predictive samples.",
UserWarning,
)

if var_names is not None:
vars_ = [model[x] for x in var_names]
else:
Expand Down Expand Up @@ -1791,6 +1798,15 @@ def sample_posterior_predictive_w(
if models is None:
models = [modelcontext(models)] * len(traces)

for model in models:
if model.potentials:
warnings.warn(
"The effect of Potentials on other parameters is ignored during posterior predictive sampling. "
"This is likely to lead to invalid or biased predictive samples.",
UserWarning,
)
break

if weights is None:
weights = [1] * len(traces)

Expand Down Expand Up @@ -1903,6 +1919,13 @@ def sample_prior_predictive(
"""
model = modelcontext(model)

if model.potentials:
warnings.warn(
"The effect of Potentials on other parameters is ignored during prior predictive sampling. "
"This is likely to lead to invalid or biased predictive samples.",
UserWarning,
)

if var_names is None:
prior_pred_vars = model.observed_RVs
prior_vars = (
Expand Down
36 changes: 36 additions & 0 deletions pymc3/tests/test_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -722,6 +722,21 @@ def test_variable_type(self):
assert ppc["a"].dtype.kind == "f"
assert ppc["b"].dtype.kind == "i"

def test_potentials_warning(self):
warning_msg = "The effect of Potentials on other parameters is ignored during"
with pm.Model() as m:
a = pm.Normal("a", 0, 1)
p = pm.Potential("p", a + 1)
obs = pm.Normal("obs", a, 1, observed=5)

trace = az.from_dict({"a": np.random.rand(10)})
with m:
with pytest.warns(UserWarning, match=warning_msg):
pm.sample_posterior_predictive(trace, samples=5)

with pytest.warns(UserWarning, match=warning_msg):
pm.fast_sample_posterior_predictive(trace, samples=5)


class TestSamplePPCW(SeededTest):
def test_sample_posterior_predictive_w(self):
Expand Down Expand Up @@ -773,6 +788,17 @@ def test_sample_posterior_predictive_w(self):
):
pm.sample_posterior_predictive_w([trace_0, trace_2], 100, [model_0, model_2])

def test_potentials_warning(self):
warning_msg = "The effect of Potentials on other parameters is ignored during"
with pm.Model() as m:
a = pm.Normal("a", 0, 1)
p = pm.Potential("p", a + 1)
obs = pm.Normal("obs", a, 1, observed=5)

trace = az.from_dict({"a": np.random.rand(10)})
with pytest.warns(UserWarning, match=warning_msg):
pm.sample_posterior_predictive_w(samples=5, traces=[trace, trace], models=[m, m])


@pytest.mark.parametrize(
"method",
Expand Down Expand Up @@ -1012,6 +1038,16 @@ def test_bounded_dist(self):
prior_trace = pm.sample_prior_predictive(5)
assert prior_trace["x"].shape == (5, 3, 1)

def test_potentials_warning(self):
warning_msg = "The effect of Potentials on other parameters is ignored during"
with pm.Model() as m:
a = pm.Normal("a", 0, 1)
p = pm.Potential("p", a + 1)

with m:
with pytest.warns(UserWarning, match=warning_msg):
pm.sample_prior_predictive(samples=5)


class TestSamplePosteriorPredictive:
def test_point_list_arg_bug_fspp(self, point_list_arg_bug_fixture):
Expand Down

0 comments on commit 4e2c099

Please sign in to comment.