diff --git a/.pylintrc b/.pylintrc index 54e92986bb..f629f9ed8d 100644 --- a/.pylintrc +++ b/.pylintrc @@ -70,7 +70,8 @@ disable=missing-docstring, unsubscriptable-object, cyclic-import, ungrouped-imports, - + not-an-iterable, + no-member, #TODO: Remove this once todos are done fixme diff --git a/arviz/data/io_cmdstanpy.py b/arviz/data/io_cmdstanpy.py index 638fba4844..67481bb8ee 100644 --- a/arviz/data/io_cmdstanpy.py +++ b/arviz/data/io_cmdstanpy.py @@ -494,10 +494,17 @@ def _unpack_fit(fit, items, save_warmup): else: raise ValueError("fit data, unknown variable: {}".format(item)) if save_warmup: - sample_warmup[item] = draws[:num_warmup, :, col_idxs] - sample[item] = draws[num_warmup:, :, col_idxs] + if len(col_idxs) == 1: + sample_warmup[item] = np.squeeze(draws[:num_warmup, :, col_idxs], axis=2) + sample[item] = np.squeeze(draws[num_warmup:, :, col_idxs], axis=2) + else: + sample_warmup[item] = draws[:num_warmup, :, col_idxs] + sample[item] = draws[num_warmup:, :, col_idxs] else: - sample[item] = draws[:, :, col_idxs] + if len(col_idxs) == 1: + sample[item] = np.squeeze(draws[:, :, col_idxs], axis=2) + else: + sample[item] = draws[:, :, col_idxs] return sample, sample_warmup diff --git a/arviz/tests/external_tests/test_data_cmdstanpy.py b/arviz/tests/external_tests/test_data_cmdstanpy.py index bf3e12d9e5..6d64ff480b 100644 --- a/arviz/tests/external_tests/test_data_cmdstanpy.py +++ b/arviz/tests/external_tests/test_data_cmdstanpy.py @@ -308,6 +308,7 @@ def test_sampler_stats(self, data, eight_schools_params): test_dict = {"sample_stats": ["lp", "diverging"]} fails = check_multiple_attrs(test_dict, inference_data) assert not fails + assert len(inference_data.sample_stats.lp.shape) == 2 # pylint: disable=no-member def test_inference_data(self, data, eight_schools_params): inference_data1 = self.get_inference_data(data, eight_schools_params) @@ -354,6 +355,8 @@ def test_inference_data(self, data, eight_schools_params): test_dict = {"posterior": ["theta"], "prior": ["theta"]} fails = check_multiple_attrs(test_dict, inference_data4) assert not fails + assert len(inference_data4.posterior.theta.shape) == 3 # pylint: disable=no-member + assert len(inference_data4.posterior.mu.shape) == 2 # pylint: disable=no-member def test_inference_data_warmup(self, data, eight_schools_params): inference_data_true_is_true = self.get_inference_data_warmup_true_is_true( @@ -429,4 +432,6 @@ def test_inference_data_warmup(self, data, eight_schools_params): assert "warmup_posterior" not in inference_data_false_is_false assert "warmup_predictions" not in inference_data_false_is_false assert "warmup_log_likelihood" not in inference_data_false_is_false - assert "warmup_prior" not in inference_data_false_is_false + assert ( + "warmup_prior" not in inference_data_false_is_false + ) # pylint: disable=redefined-outer-name