diff --git a/arviz/data/io_cmdstanpy.py b/arviz/data/io_cmdstanpy.py index 095de475f1..7d74cd1c06 100644 --- a/arviz/data/io_cmdstanpy.py +++ b/arviz/data/io_cmdstanpy.py @@ -125,6 +125,13 @@ def posterior_to_xarray(self): def sample_stats_to_xarray(self): """Extract sample_stats from fit.""" dtypes = {"divergent__": bool, "n_leapfrog__": np.int64, "treedepth__": np.int64} + rename_dict = { + "divergent": "diverging", + "n_leapfrog": "n_steps", + "treedepth": "tree_depth", + "stepsize": "step_size", + "accept_stat": "acceptance_rate", + } columns = self.posterior.column_names valid_cols = [col for col in columns if col.endswith("__")] @@ -138,7 +145,7 @@ def sample_stats_to_xarray(self): for s_param in list(data.keys()): s_param_, *_ = s_param.split(".") name = re.sub("__$", "", s_param_) - name = "diverging" if name == "divergent" else name + name = rename_dict.get(name, name) data[name] = data.pop(s_param).astype(dtypes.get(s_param, float)) if data_warmup: data_warmup[name] = data_warmup.pop(s_param).astype(dtypes.get(s_param, float))