Skip to content

Commit

Permalink
set warmup groups as public attribute
Browse files Browse the repository at this point in the history
  • Loading branch information
OriolAbril committed May 4, 2020
1 parent 26ff7d7 commit 1688e18
Show file tree
Hide file tree
Showing 6 changed files with 19 additions and 20 deletions.
2 changes: 1 addition & 1 deletion arviz/data/inference_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
"predictions_constant_data",
]

WARMUP_TAG = "_warmup_"
WARMUP_TAG = "warmup_"

SUPPORTED_GROUPS_WARMUP = [
"{}posterior".format(WARMUP_TAG),
Expand Down
10 changes: 5 additions & 5 deletions arviz/data/inference_data.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@ class InferenceData:
prior: Optional[xr.Dataset]
prior_predictive: Optional[xr.Dataset]
sample_stats_prior: Optional[xr.Dataset]
_warmup_posterior: Optional[xr.Dataset]
_warmup_posterior_predictive: Optional[xr.Dataset]
_warmup_predictions: Optional[xr.Dataset]
_warmup_log_likelihood: Optional[xr.Dataset]
_warmup_sample_stats: Optional[xr.Dataset]
warmup_posterior: Optional[xr.Dataset]
warmup_posterior_predictive: Optional[xr.Dataset]
warmup_predictions: Optional[xr.Dataset]
warmup_log_likelihood: Optional[xr.Dataset]
warmup_sample_stats: Optional[xr.Dataset]
def __init__(self, **kwargs): ...
def __repr__(self) -> str: ...
def __delattr__(self, group: str) -> None: ...
Expand Down
4 changes: 2 additions & 2 deletions arviz/plots/backends/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ def to_cds(
"posterior_groups_warmup"}
- posterior_groups: posterior, posterior_predictive, sample_stats
- prior_groups: prior, prior_predictive, sample_stats_prior
- posterior_groups_warmup: _warmup_posterior, _warmup_posterior_predictive,
_warmup_sample_stats
- posterior_groups_warmup: warmup_posterior, warmup_posterior_predictive,
warmup_sample_stats
ignore_groups : str or list of str, optional
Ignore specific groups from CDS.
dimension : str, or list of str, optional
Expand Down
15 changes: 7 additions & 8 deletions arviz/tests/external_tests/test_data_pymc.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,16 +417,15 @@ def test_save_warmup(self, save_warmup):
test_dict = {
"posterior": ["u1", "n1"],
"sample_stats": ["~tune", "accept"],
f"{prefix}_warmup_posterior": ["u1", "n1"],
f"{prefix}_warmup_sample_stats": ["~tune"],
"~_warmup_log_likelihood": [],
f"{prefix}warmup_posterior": ["u1", "n1"],
f"{prefix}warmup_sample_stats": ["~tune"],
"~warmup_log_likelihood": [],
"~log_likelihood": [],
}
fails = check_multiple_attrs(test_dict, idata)
assert not fails
assert idata.posterior.sizes["chain"] == 2
assert idata.posterior.sizes["draw"] == 200
assert idata.posterior.dims["chain"] == 2
assert idata.posterior.dims["draw"] == 200
if save_warmup:
# pylint: disable=protected-access
assert idata._warmup_posterior.dims["chain"] == 2
assert idata._warmup_posterior.dims["draw"] == 100
assert idata.warmup_posterior.dims["chain"] == 2
assert idata.warmup_posterior.dims["draw"] == 100
2 changes: 1 addition & 1 deletion arviz/tests/external_tests/test_data_pystan.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def test_inference_data(self, data, eight_schools_params):
}
if pystan_version() == 2:
test_dict.update(
{"_warmup_posterior": ["theta"], "_warmup_sample_stats": ["diverging", "lp"]}
{"warmup_posterior": ["theta"], "warmup_sample_stats": ["diverging", "lp"]}
)
fails = check_multiple_attrs(test_dict, inference_data4)
assert not fails
Expand Down
6 changes: 3 additions & 3 deletions arviz/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,8 +406,8 @@ def flatten_inference_data_to_dict(
{"posterior_groups", "prior_groups", "posterior_groups_warmup"}
- posterior_groups: posterior, posterior_predictive, sample_stats
- prior_groups: prior, prior_predictive, sample_stats_prior
- posterior_groups_warmup: _warmup_posterior, _warmup_posterior_predictive,
_warmup_sample_stats
- posterior_groups_warmup: warmup_posterior, warmup_posterior_predictive,
warmup_sample_stats
ignore_groups : str or list of str, optional
Ignore specific groups from CDS.
dimension : str, or list of str, optional
Expand Down Expand Up @@ -455,7 +455,7 @@ def flatten_inference_data_to_dict(
elif groups.lower() == "prior_groups":
groups = ["prior", "prior_predictive", "sample_stats_prior"]
elif groups.lower() == "posterior_groups_warmup":
groups = ["_warmup_posterior", "_warmup_posterior_predictive", "_warmup_sample_stats"]
groups = ["warmup_posterior", "warmup_posterior_predictive", "warmup_sample_stats"]
else:
raise TypeError(
(
Expand Down

0 comments on commit 1688e18

Please sign in to comment.