Skip to content

Commit

Permalink
from_cmdstanpy - fix bug introduced by refactor (PR 1558) (#1564)
Browse files Browse the repository at this point in the history
* squeezing scalar vars

* added checks on variable shape

* lint fix

* remove checks on variable shape

* adding back shape checks

* pylint - diable no-member, not-an-iterable

* black

* Update .pylintrc indentation

* add commas

Co-authored-by: Oriol (ZBook) <oriol.abril.pla@gmail.com>
  • Loading branch information
mitzimorris and OriolAbril authored Feb 16, 2021
1 parent 8a6ca10 commit 1a7f83f
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 5 deletions.
3 changes: 2 additions & 1 deletion .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,8 @@ disable=missing-docstring,
unsubscriptable-object,
cyclic-import,
ungrouped-imports,

not-an-iterable,
no-member,
#TODO: Remove this once todos are done
fixme

Expand Down
13 changes: 10 additions & 3 deletions arviz/data/io_cmdstanpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,10 +501,17 @@ def _unpack_fit(fit, items, save_warmup):
else:
raise ValueError("fit data, unknown variable: {}".format(item))
if save_warmup:
sample_warmup[item] = draws[:num_warmup, :, col_idxs]
sample[item] = draws[num_warmup:, :, col_idxs]
if len(col_idxs) == 1:
sample_warmup[item] = np.squeeze(draws[:num_warmup, :, col_idxs], axis=2)
sample[item] = np.squeeze(draws[num_warmup:, :, col_idxs], axis=2)
else:
sample_warmup[item] = draws[:num_warmup, :, col_idxs]
sample[item] = draws[num_warmup:, :, col_idxs]
else:
sample[item] = draws[:, :, col_idxs]
if len(col_idxs) == 1:
sample[item] = np.squeeze(draws[:, :, col_idxs], axis=2)
else:
sample[item] = draws[:, :, col_idxs]

return sample, sample_warmup

Expand Down
7 changes: 6 additions & 1 deletion arviz/tests/external_tests/test_data_cmdstanpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,7 @@ def test_sampler_stats(self, data, eight_schools_params):
test_dict = {"sample_stats": ["lp", "diverging"]}
fails = check_multiple_attrs(test_dict, inference_data)
assert not fails
assert len(inference_data.sample_stats.lp.shape) == 2 # pylint: disable=no-member

def test_inference_data(self, data, eight_schools_params):
inference_data1 = self.get_inference_data(data, eight_schools_params)
Expand Down Expand Up @@ -354,6 +355,8 @@ def test_inference_data(self, data, eight_schools_params):
test_dict = {"posterior": ["theta"], "prior": ["theta"]}
fails = check_multiple_attrs(test_dict, inference_data4)
assert not fails
assert len(inference_data4.posterior.theta.shape) == 3 # pylint: disable=no-member
assert len(inference_data4.posterior.mu.shape) == 2 # pylint: disable=no-member

def test_inference_data_warmup(self, data, eight_schools_params):
inference_data_true_is_true = self.get_inference_data_warmup_true_is_true(
Expand Down Expand Up @@ -429,4 +432,6 @@ def test_inference_data_warmup(self, data, eight_schools_params):
assert "warmup_posterior" not in inference_data_false_is_false
assert "warmup_predictions" not in inference_data_false_is_false
assert "warmup_log_likelihood" not in inference_data_false_is_false
assert "warmup_prior" not in inference_data_false_is_false
assert (
"warmup_prior" not in inference_data_false_is_false
) # pylint: disable=redefined-outer-name

0 comments on commit 1a7f83f

Please sign in to comment.