Skip to content

Commit

Permalink
Refactor test to use InferenceData
Browse files Browse the repository at this point in the history
  • Loading branch information
michaelosthege committed May 14, 2021
1 parent da43d64 commit aebc7e2
Showing 1 changed file with 25 additions and 7 deletions.
32 changes: 25 additions & 7 deletions pymc3/tests/test_data_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,22 +159,40 @@ def test_shared_data_as_rv_input(self):
"""
with pm.Model() as m:
x = pm.Data("x", [1.0, 2.0, 3.0])
_ = pm.Normal("y", mu=x, size=3)
trace = pm.sample(
chains=1, return_inferencedata=False, compute_convergence_checks=False
y = pm.Normal("y", mu=x, size=(2, 3))
assert y.eval().shape == (2, 3)
idata = pm.sample(
chains=1,
tune=500,
draws=550,
return_inferencedata=True,
compute_convergence_checks=False,
)
samples = idata.posterior["y"]
assert samples.shape == (1, 550, 2, 3)

np.testing.assert_allclose(np.array([1.0, 2.0, 3.0]), x.get_value(), atol=1e-1)
np.testing.assert_allclose(np.array([1.0, 2.0, 3.0]), trace["y"].mean(0), atol=1e-1)
np.testing.assert_allclose(
np.array([1.0, 2.0, 3.0]), samples.mean(("chain", "draw", "y_dim_0")), atol=1e-1
)

with m:
pm.set_data({"x": np.array([2.0, 4.0, 6.0])})
trace = pm.sample(
chains=1, return_inferencedata=False, compute_convergence_checks=False
assert y.eval().shape == (2, 3)
idata = pm.sample(
chains=1,
tune=500,
draws=620,
return_inferencedata=True,
compute_convergence_checks=False,
)
samples = idata.posterior["y"]
assert samples.shape == (1, 620, 2, 3)

np.testing.assert_allclose(np.array([2.0, 4.0, 6.0]), x.get_value(), atol=1e-1)
np.testing.assert_allclose(np.array([2.0, 4.0, 6.0]), trace["y"].mean(0), atol=1e-1)
np.testing.assert_allclose(
np.array([2.0, 4.0, 6.0]), samples.mean(("chain", "draw", "y_dim_0")), atol=1e-1
)

def test_shared_scalar_as_rv_input(self):
# See https://github.com/pymc-devs/pymc3/issues/3139
Expand Down

0 comments on commit aebc7e2

Please sign in to comment.