Skip to content

Commit

Permalink
🔥 remove vars from sample_posterior_predictive, other random refactor…
Browse files Browse the repository at this point in the history
…ings
  • Loading branch information
MarcoGorelli committed Dec 14, 2020
1 parent dbcc49e commit ca0c28b
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 32 deletions.
2 changes: 1 addition & 1 deletion RELEASE-NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ This is the first release to support Python3.9 and to drop Python3.6.
- Fixed bug whereby partial traces returns after keyboard interrupt during parallel sampling had fewer draws than would've been available [#4318](https://github.com/pymc-devs/pymc3/pull/4318)
- Make `sample_shape` same across all contexts in `draw_values` (see [#4305](https://github.com/pymc-devs/pymc3/pull/4305)).
- Removed `theanof.set_theano_config` because it illegally touched Theano's privates (see [#4329](https://github.com/pymc-devs/pymc3/pull/4329)).

- In `sample_posterior_predictive` the `vars` kwarg was removed in favor of `var_names` (see [#4343](https://github.com/pymc-devs/pymc3/pull/4343)).

## PyMC3 3.10.0 (7 December 2020)

Expand Down
28 changes: 10 additions & 18 deletions pymc3/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@

from arviz import InferenceData
from fastprogress.fastprogress import progress_bar
from theano.tensor import Tensor

import pymc3 as pm

Expand Down Expand Up @@ -561,12 +560,11 @@ def sample(
_log.debug("Pickling error:", exec_info=True)
parallel = False
except AttributeError as e:
if str(e).startswith("AttributeError: Can't pickle"):
_log.warning("Could not pickle model, sampling singlethreaded.")
_log.debug("Pickling error:", exec_info=True)
parallel = False
else:
if not str(e).startswith("AttributeError: Can't pickle"):
raise
_log.warning("Could not pickle model, sampling singlethreaded.")
_log.debug("Pickling error:", exec_info=True)
parallel = False
if not parallel:
if has_population_samplers:
has_demcmc = np.any(
Expand Down Expand Up @@ -1602,7 +1600,6 @@ def sample_posterior_predictive(
trace,
samples: Optional[int] = None,
model: Optional[Model] = None,
vars: Optional[Iterable[Tensor]] = None,
var_names: Optional[List[str]] = None,
size: Optional[int] = None,
keep_size: Optional[bool] = False,
Expand Down Expand Up @@ -1696,14 +1693,9 @@ def sample_posterior_predictive(
model = modelcontext(model)

if var_names is not None:
if vars is not None:
raise IncorrectArgumentsError("Should not specify both vars and var_names arguments.")
else:
vars = [model[x] for x in var_names]
elif vars is not None: # var_names is None, and vars is not.
warnings.warn("vars argument is deprecated in favor of var_names.", DeprecationWarning)
if vars is None:
vars = model.observed_RVs
vars_ = [model[x] for x in var_names]
else:
vars_ = model.observed_RVs

if random_seed is not None:
np.random.seed(random_seed)
Expand All @@ -1729,8 +1721,8 @@ def sample_posterior_predictive(
else:
param = _trace[idx % len_trace]

values = draw_values(vars, point=param, size=size)
for k, v in zip(vars, values):
values = draw_values(vars_, point=param, size=size)
for k, v in zip(vars_, values):
ppc_trace_t.insert(k.name, v, idx)
except KeyboardInterrupt:
pass
Expand Down Expand Up @@ -1809,7 +1801,7 @@ def sample_posterior_predictive_w(
raise ValueError("The number of models and weights should be the same")

length_morv = len(models[0].observed_RVs)
if not all(len(i.observed_RVs) == length_morv for i in models):
if any(len(i.observed_RVs) != length_morv for i in models):
raise ValueError("The number of observed RVs should be the same for all models")

weights = np.asarray(weights)
Expand Down
14 changes: 1 addition & 13 deletions pymc3/tests/test_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,8 +406,7 @@ def test_normal_scalar(self):
ppc0 = pm.sample_posterior_predictive([model.test_point], samples=10)
ppc0 = pm.fast_sample_posterior_predictive([model.test_point], samples=10)
# deprecated argument is not introduced to fast version [2019/08/20:rpg]
with pytest.warns(DeprecationWarning):
ppc = pm.sample_posterior_predictive(trace, vars=[a])
ppc = pm.sample_posterior_predictive(trace, var_names=["a"])
# test empty ppc
ppc = pm.sample_posterior_predictive(trace, var_names=[])
assert len(ppc) == 0
Expand Down Expand Up @@ -518,8 +517,6 @@ def test_exceptions(self, caplog):
# Not for fast_sample_posterior_predictive
with pytest.raises(IncorrectArgumentsError):
ppc = pm.sample_posterior_predictive(trace, size=4, keep_size=True)
with pytest.raises(IncorrectArgumentsError):
ppc = pm.sample_posterior_predictive(trace, vars=[a], var_names=["a"])
# test wrong type argument
bad_trace = {"mu": stats.norm.rvs(size=1000)}
with pytest.raises(TypeError):
Expand Down Expand Up @@ -653,16 +650,7 @@ def test_deterministic_of_observed(self):

trace = pm.sample(100, chains=nchains)
np.random.seed(0)
with pytest.warns(DeprecationWarning):
ppc = pm.sample_posterior_predictive(
model=model,
trace=trace,
samples=len(trace) * nchains,
vars=(model.deterministics + model.basic_RVs),
)

rtol = 1e-5 if theano.config.floatX == "float64" else 1e-4
npt.assert_allclose(ppc["in_1"] + ppc["in_2"], ppc["out"], rtol=rtol)

np.random.seed(0)
ppc = pm.sample_posterior_predictive(
Expand Down

0 comments on commit ca0c28b

Please sign in to comment.