From 5c30761018f704a960840e6b04f61860de61cf4a Mon Sep 17 00:00:00 2001 From: Ravin Kumar Date: Wed, 5 Jun 2019 07:41:09 -0700 Subject: [PATCH 1/2] Add failing test for single obs --- arviz/tests/test_data_pymc.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/arviz/tests/test_data_pymc.py b/arviz/tests/test_data_pymc.py index 4dc41a0315..4b18f1d54a 100644 --- a/arviz/tests/test_data_pymc.py +++ b/arviz/tests/test_data_pymc.py @@ -81,3 +81,12 @@ def test_multiple_observed_rv(self): fails = check_multiple_attrs(test_dict, inference_data) assert not fails assert not hasattr(inference_data.sample_stats, "log_likelihood") + + def test_single_observation(self): + with pm.Model(): + p = pm.Uniform("p", 0, 1) + pm.Binomial("w", p=p, n=2, observed=1) + trace = pm.sample(500, chains=2) + + inference_data = from_pymc3(trace=trace) + assert inference_data From 611cd7c8a4ca1017e8ac4c116887c5889a7d5e37 Mon Sep 17 00:00:00 2001 From: Ravin Kumar Date: Wed, 5 Jun 2019 08:06:47 -0700 Subject: [PATCH 2/2] Add 1d array fix --- arviz/data/io_pymc3.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/arviz/data/io_pymc3.py b/arviz/data/io_pymc3.py index b352da4a3e..4b26475f0f 100644 --- a/arviz/data/io_pymc3.py +++ b/arviz/data/io_pymc3.py @@ -44,7 +44,7 @@ def log_likelihood_vals_point(point): """Compute log likelihood for each observed point.""" log_like_vals = [] for var, log_like in cached: - log_like_val = log_like(point) + log_like_val = np.atleast_1d(log_like(point)) if var.missing_values: log_like_val = log_like_val[~var.observations.mask] log_like_vals.append(log_like_val)