Skip to content

Commit

Permalink
Limit random_state to the only test that checks the content of a draw
Browse files Browse the repository at this point in the history
  • Loading branch information
mattiadg committed Oct 19, 2022
1 parent 29dbf85 commit bfc7452
Showing 1 changed file with 4 additions and 8 deletions.
12 changes: 4 additions & 8 deletions pymc/tests/test_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -625,27 +625,23 @@ def test_normal_scalar(self):
chains=nchains,
)

random_state = self.get_random_state()
with model:
# test list input
ppc0 = pm.sample_posterior_predictive(
10 * [model.initial_point()], return_inferencedata=False, random_seed=random_state
10 * [model.initial_point()], return_inferencedata=False
)
assert "a" in ppc0
assert len(ppc0["a"][0]) == 10
# test empty ppc
ppc = pm.sample_posterior_predictive(
trace, var_names=[], return_inferencedata=False, random_seed=random_state
)
ppc = pm.sample_posterior_predictive(trace, var_names=[], return_inferencedata=False)
assert len(ppc) == 0

# test keep_size parameter
ppc = pm.sample_posterior_predictive(
trace, return_inferencedata=False, random_seed=random_state
)
ppc = pm.sample_posterior_predictive(trace, return_inferencedata=False)
assert ppc["a"].shape == (nchains, ndraws)

# test default case
random_state = self.get_random_state()
idata_ppc = pm.sample_posterior_predictive(
trace, var_names=["a"], random_seed=random_state
)
Expand Down

0 comments on commit bfc7452

Please sign in to comment.