Skip to content

Commit

Permalink
Merge 7db1820 into 6fa1ce8
Browse files Browse the repository at this point in the history
  • Loading branch information
utkarsh-maheshwari authored Feb 9, 2021
2 parents 6fa1ce8 + 7db1820 commit e520a79
Showing 1 changed file with 8 additions and 1 deletion.
9 changes: 8 additions & 1 deletion arviz/data/io_cmdstanpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,13 @@ def posterior_to_xarray(self):
def sample_stats_to_xarray(self):
"""Extract sample_stats from fit."""
dtypes = {"divergent__": bool, "n_leapfrog__": np.int64, "treedepth__": np.int64}
rename_dict = {
"divergent": "diverging",
"n_leapfrog": "n_steps",
"treedepth": "tree_depth",
"stepsize": "step_size",
"accept_stat": "acceptance_rate",
}

columns = self.posterior.column_names
valid_cols = [col for col in columns if col.endswith("__")]
Expand All @@ -138,7 +145,7 @@ def sample_stats_to_xarray(self):
for s_param in list(data.keys()):
s_param_, *_ = s_param.split(".")
name = re.sub("__$", "", s_param_)
name = "diverging" if name == "divergent" else name
name = rename_dict.get(name, name)
data[name] = data.pop(s_param).astype(dtypes.get(s_param, float))
if data_warmup:
data_warmup[name] = data_warmup.pop(s_param).astype(dtypes.get(s_param, float))
Expand Down

0 comments on commit e520a79

Please sign in to comment.