diff --git a/CHANGELOG.md b/CHANGELOG.md index fe691a95a9..77d5d54c56 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,7 @@ * Integrate `index_origin` with all the library ([1201](https://github.com/arviz-devs/arviz/pull/1201)) * Fix pareto k threshold typo in reloo function ([1580](https://github.com/arviz-devs/arviz/pull/1580)) * Preserve shape from Stan code in `from_cmdstanpy` ([1579](https://github.com/arviz-devs/arviz/pull/1579)) +* Updated `from_pystan` converters to follow schema convention ([1585](https://github.com/arviz-devs/arviz/pull/1585) * Used generator instead of list wherever possible ([1588](https://github.com/arviz-devs/arviz/pull/1588)) * Correctly use chain index when constructing PyMC3 `DefaultTrace` in `from_pymc3` ([1590](https://github.com/arviz-devs/arviz/pull/1590)) diff --git a/arviz/data/io_pystan.py b/arviz/data/io_pystan.py index 8690e9aa32..90ce5321bc 100644 --- a/arviz/data/io_pystan.py +++ b/arviz/data/io_pystan.py @@ -627,6 +627,14 @@ def get_sample_stats(fit, warmup=False): """Extract sample stats from PyStan fit.""" dtypes = {"divergent__": bool, "n_leapfrog__": np.int64, "treedepth__": np.int64} + rename_dict = { + "divergent": "diverging", + "n_leapfrog": "n_steps", + "treedepth": "tree_depth", + "stepsize": "step_size", + "accept_stat": "acceptance_rate", + } + ndraws_warmup = fit.sim["warmup2"] if max(ndraws_warmup) == 0: warmup = False @@ -653,17 +661,16 @@ def get_sample_stats(fit, warmup=False): dtype = dtypes.get(key) values = values.astype(dtype) name = re.sub("__$", "", key) - name = "diverging" if name == "divergent" else name + name = rename_dict.get(name, name) data[name] = values data_warmup = OrderedDict() if warmup: for key, values in extraction_warmup.items(): values = np.stack(values, axis=0) - dtype = dtypes.get(key) - values = values.astype(dtype) + values = values.astype(dtypes.get(key)) name = re.sub("__$", "", key) - name = "diverging" if name == "divergent" else name + name = rename_dict.get(name, name) data_warmup[name] = values return data, data_warmup @@ -773,6 +780,14 @@ def get_sample_stats_stan3(fit, variables=None, ignore=None): """Extract sample stats from PyStan3 fit.""" dtypes = {"divergent__": bool, "n_leapfrog__": np.int64, "treedepth__": np.int64} + rename_dict = { + "divergent": "diverging", + "n_leapfrog": "n_steps", + "treedepth": "tree_depth", + "stepsize": "step_size", + "accept_stat": "acceptance_rate", + } + if isinstance(variables, str): variables = [variables] if isinstance(ignore, str): @@ -789,7 +804,7 @@ def get_sample_stats_stan3(fit, variables=None, ignore=None): dtype = dtypes.get(key) values = values.astype(dtype) name = re.sub("__$", "", key) - name = "diverging" if name == "divergent" else name + name = rename_dict.get(name, name) data[name] = values return data