Skip to content

Commit

Permalink
try fixing tests
Browse files Browse the repository at this point in the history
  • Loading branch information
OriolAbril committed Oct 11, 2022
1 parent c343010 commit c8a3c95
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 18 deletions.
5 changes: 3 additions & 2 deletions pymc/tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1263,8 +1263,9 @@ def test_interval_missing_observations():

# Make sure that the observed values are newly generated samples and that
# the observed and deterministic matche
pp_trace = pm.sample_posterior_predictive(
trace, return_inferencedata=False, keep_size=False
pp_idata = pm.sample_posterior_predictive(trace)
pp_trace = pp_idata.posterior_predictive.stack(sample=["chain", "draw"]).transpose(
"sample", ...
)
assert np.all(np.var(pp_trace["theta1"], 0) > 0.0)
assert np.all(np.var(pp_trace["theta2"], 0) > 0.0)
Expand Down
18 changes: 2 additions & 16 deletions pymc/tests/test_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,21 +505,6 @@ def test_partial_trace_sample():
assert "b" not in idata.posterior


def test_chain_idx():
# see https://github.com/pymc-devs/pymc/issues/4469
with pm.Model():
mu = pm.Normal("mu")
x = pm.Normal("x", mu=mu, sigma=1, observed=np.asarray(3))
# note draws-tune must be >100 AND we need an observed RV for this to properly
# trigger convergence checks, which is one particular case in which this failed
# before
idata = pm.sample(draws=150, tune=10, chain_idx=1)

ppc = pm.sample_posterior_predictive(idata)
# TODO FIXME: Assert something.
ppc = pm.sample_posterior_predictive(idata)


@pytest.mark.parametrize(
"n_points, tune, expected_length, expected_n_traces",
[
Expand Down Expand Up @@ -655,7 +640,8 @@ def test_normal_scalar(self):
assert ppc["a"].shape == (nchains, ndraws)

# test default case
ppc = pm.sample_posterior_predictive(trace, var_names=["a"])
idata_ppc = pm.sample_posterior_predictive(trace, var_names=["a"])
ppc = idata_ppc.posterior_predictive
assert "a" in ppc
assert ppc["a"].shape == (nchains, ndraws)
# mu's standard deviation may have changed thanks to a's observed
Expand Down

0 comments on commit c8a3c95

Please sign in to comment.