From ef3b4c2bec0124de50a98652ed506c3084a4c900 Mon Sep 17 00:00:00 2001 From: madhu charan <62477860+madhucharan@users.noreply.github.com> Date: Fri, 26 Feb 2021 22:38:06 +0530 Subject: [PATCH 1/4] update io_pystan.py to follow schema convention --- arviz/data/io_pystan.py | 24 +++++++++++++++++++++--- 1 file changed, 21 insertions(+), 3 deletions(-) diff --git a/arviz/data/io_pystan.py b/arviz/data/io_pystan.py index 8690e9aa32..14b4b482b4 100644 --- a/arviz/data/io_pystan.py +++ b/arviz/data/io_pystan.py @@ -627,6 +627,15 @@ def get_sample_stats(fit, warmup=False): """Extract sample stats from PyStan fit.""" dtypes = {"divergent__": bool, "n_leapfrog__": np.int64, "treedepth__": np.int64} + rename_dict = { + "divergent": "diverging", + "n_leapfrog": "n_steps", + "treedepth": "tree_depth", + "stepsize": "step_size", + "accept_stat": "acceptance_rate", + "lp":"lp", + } + ndraws_warmup = fit.sim["warmup2"] if max(ndraws_warmup) == 0: warmup = False @@ -653,7 +662,7 @@ def get_sample_stats(fit, warmup=False): dtype = dtypes.get(key) values = values.astype(dtype) name = re.sub("__$", "", key) - name = "diverging" if name == "divergent" else name + name = rename_dict.get(name, name) data[name] = values data_warmup = OrderedDict() @@ -663,7 +672,7 @@ def get_sample_stats(fit, warmup=False): dtype = dtypes.get(key) values = values.astype(dtype) name = re.sub("__$", "", key) - name = "diverging" if name == "divergent" else name + name = rename_dict.get(name, name) data_warmup[name] = values return data, data_warmup @@ -773,6 +782,15 @@ def get_sample_stats_stan3(fit, variables=None, ignore=None): """Extract sample stats from PyStan3 fit.""" dtypes = {"divergent__": bool, "n_leapfrog__": np.int64, "treedepth__": np.int64} + rename_dict = { + "divergent": "diverging", + "n_leapfrog": "n_steps", + "treedepth": "tree_depth", + "stepsize": "step_size", + "accept_stat": "acceptance_rate", + "lp":"lp", + } + if isinstance(variables, str): variables = [variables] if 
isinstance(ignore, str): @@ -789,7 +807,7 @@ def get_sample_stats_stan3(fit, variables=None, ignore=None): dtype = dtypes.get(key) values = values.astype(dtype) name = re.sub("__$", "", key) - name = "diverging" if name == "divergent" else name + name = rename_dict.get(name, name) data[name] = values return data From f6fc67cd954446b27914df8b24aea48c3df9442f Mon Sep 17 00:00:00 2001 From: madhu charan <62477860+madhucharan@users.noreply.github.com> Date: Fri, 26 Feb 2021 22:55:56 +0530 Subject: [PATCH 2/4] refactor io_pystan.py --- arviz/data/io_pystan.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/arviz/data/io_pystan.py b/arviz/data/io_pystan.py index 14b4b482b4..50c5aebfef 100644 --- a/arviz/data/io_pystan.py +++ b/arviz/data/io_pystan.py @@ -633,7 +633,7 @@ def get_sample_stats(fit, warmup=False): "treedepth": "tree_depth", "stepsize": "step_size", "accept_stat": "acceptance_rate", - "lp":"lp", + "lp": "lp", } ndraws_warmup = fit.sim["warmup2"] @@ -659,8 +659,7 @@ def get_sample_stats(fit, warmup=False): data = OrderedDict() for key, values in extraction.items(): values = np.stack(values, axis=0) - dtype = dtypes.get(key) - values = values.astype(dtype) + values = values.astype(dtypes.get(key)) name = re.sub("__$", "", key) name = rename_dict.get(name, name) data[name] = values @@ -669,8 +668,7 @@ def get_sample_stats(fit, warmup=False): if warmup: for key, values in extraction_warmup.items(): values = np.stack(values, axis=0) - dtype = dtypes.get(key) - values = values.astype(dtype) + values = values.astype(dtypes.get(key)) name = re.sub("__$", "", key) name = rename_dict.get(name, name) data_warmup[name] = values @@ -788,7 +786,7 @@ def get_sample_stats_stan3(fit, variables=None, ignore=None): "treedepth": "tree_depth", "stepsize": "step_size", "accept_stat": "acceptance_rate", - "lp":"lp", + "lp": "lp", } if isinstance(variables, str): @@ -804,8 +802,7 @@ def get_sample_stats_stan3(fit, variables=None, ignore=None): values 
= fit._draws[fit._parameter_indexes(key)] # pylint: disable=protected-access values = values.reshape(new_shape, order="F") values = np.moveaxis(values, [-2, -1], [1, 0]) - dtype = dtypes.get(key) - values = values.astype(dtype) + values = values.astype(dtypes.get(key)) name = re.sub("__$", "", key) name = rename_dict.get(name, name) data[name] = values From 6f17a0479c5548651d6d9055acbdcbda51afdcdd Mon Sep 17 00:00:00 2001 From: madhu charan <62477860+madhucharan@users.noreply.github.com> Date: Fri, 26 Feb 2021 23:03:22 +0530 Subject: [PATCH 3/4] update CHANGELOG.md --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 03a2311339..3db5da31ba 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,7 @@ * Integrate `index_origin` with all the library ([1201](https://github.com/arviz-devs/arviz/pull/1201)) * Fix pareto k threshold typo in reloo function ([1580](https://github.com/arviz-devs/arviz/pull/1580)) * Preserve shape from Stan code in `from_cmdstanpy` ([1579](https://github.com/arviz-devs/arviz/pull/1579)) +* Updated `from_pystan` converters to follow schema convention ([1585](https://github.com/arviz-devs/arviz/pull/1585)) ### Deprecation * Deprecated `index_origin` and `order` arguments in `az.summary` ([1201](https://github.com/arviz-devs/arviz/pull/1201)) From a8dc57dbd95fb717450a38c59a46a262cdfca831 Mon Sep 17 00:00:00 2001 From: madhu charan <62477860+madhucharan@users.noreply.github.com> Date: Sun, 28 Feb 2021 18:24:24 +0530 Subject: [PATCH 4/4] remove lp from rename dict and refactor the function --- arviz/data/io_pystan.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/arviz/data/io_pystan.py b/arviz/data/io_pystan.py index 50c5aebfef..90ce5321bc 100644 --- a/arviz/data/io_pystan.py +++ b/arviz/data/io_pystan.py @@ -633,7 +633,6 @@ def get_sample_stats(fit, warmup=False): "treedepth": "tree_depth", "stepsize": "step_size", "accept_stat": "acceptance_rate", - "lp": "lp", }
ndraws_warmup = fit.sim["warmup2"] @@ -659,7 +658,8 @@ def get_sample_stats(fit, warmup=False): data = OrderedDict() for key, values in extraction.items(): values = np.stack(values, axis=0) - values = values.astype(dtypes.get(key)) + dtype = dtypes.get(key) + values = values.astype(dtype) name = re.sub("__$", "", key) name = rename_dict.get(name, name) data[name] = values @@ -786,7 +786,6 @@ def get_sample_stats_stan3(fit, variables=None, ignore=None): "treedepth": "tree_depth", "stepsize": "step_size", "accept_stat": "acceptance_rate", - "lp": "lp", } if isinstance(variables, str): @@ -802,7 +801,8 @@ def get_sample_stats_stan3(fit, variables=None, ignore=None): values = fit._draws[fit._parameter_indexes(key)] # pylint: disable=protected-access values = values.reshape(new_shape, order="F") values = np.moveaxis(values, [-2, -1], [1, 0]) - values = values.astype(dtypes.get(key)) + dtype = dtypes.get(key) + values = values.astype(dtype) name = re.sub("__$", "", key) name = rename_dict.get(name, name) data[name] = values