diff --git a/pymc3/tests/test_data_container.py b/pymc3/tests/test_data_container.py index 99d6b693c7d..4aa20eedb8b 100644 --- a/pymc3/tests/test_data_container.py +++ b/pymc3/tests/test_data_container.py @@ -159,22 +159,40 @@ def test_shared_data_as_rv_input(self): """ with pm.Model() as m: x = pm.Data("x", [1.0, 2.0, 3.0]) - _ = pm.Normal("y", mu=x, size=3) - trace = pm.sample( - chains=1, return_inferencedata=False, compute_convergence_checks=False + y = pm.Normal("y", mu=x, size=(2, 3)) + assert y.eval().shape == (2, 3) + idata = pm.sample( + chains=1, + tune=500, + draws=550, + return_inferencedata=True, + compute_convergence_checks=False, ) + samples = idata.posterior["y"] + assert samples.shape == (1, 550, 2, 3) np.testing.assert_allclose(np.array([1.0, 2.0, 3.0]), x.get_value(), atol=1e-1) - np.testing.assert_allclose(np.array([1.0, 2.0, 3.0]), trace["y"].mean(0), atol=1e-1) + np.testing.assert_allclose( + np.array([1.0, 2.0, 3.0]), samples.mean(("chain", "draw", "y_dim_0")), atol=1e-1 + ) with m: pm.set_data({"x": np.array([2.0, 4.0, 6.0])}) - trace = pm.sample( - chains=1, return_inferencedata=False, compute_convergence_checks=False + assert y.eval().shape == (2, 3) + idata = pm.sample( + chains=1, + tune=500, + draws=620, + return_inferencedata=True, + compute_convergence_checks=False, ) + samples = idata.posterior["y"] + assert samples.shape == (1, 620, 2, 3) np.testing.assert_allclose(np.array([2.0, 4.0, 6.0]), x.get_value(), atol=1e-1) - np.testing.assert_allclose(np.array([2.0, 4.0, 6.0]), trace["y"].mean(0), atol=1e-1) + np.testing.assert_allclose( + np.array([2.0, 4.0, 6.0]), samples.mean(("chain", "draw", "y_dim_0")), atol=1e-1 + ) def test_shared_scalar_as_rv_input(self): # See https://github.com/pymc-devs/pymc3/issues/3139