Skip to content

Commit

Permalink
Update io pystan (#1585)
Browse files Browse the repository at this point in the history
* update io_pystan.py to follow schema convention

* refactor io_pystan.py

* update CHANGELOG.md

* remove lp from rename dict and refactor the function
  • Loading branch information
madhucharan authored Mar 4, 2021
1 parent 4443da3 commit 5a00626
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 5 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
* Integrate `index_origin` with all the library ([1201](https://github.com/arviz-devs/arviz/pull/1201))
* Fix pareto k threshold typo in reloo function ([1580](https://github.com/arviz-devs/arviz/pull/1580))
* Preserve shape from Stan code in `from_cmdstanpy` ([1579](https://github.com/arviz-devs/arviz/pull/1579))
* Updated `from_pystan` converters to follow schema convention ([1585](https://github.com/arviz-devs/arviz/pull/1585)
* Used generator instead of list wherever possible ([1588](https://github.com/arviz-devs/arviz/pull/1588))
* Correctly use chain index when constructing PyMC3 `DefaultTrace` in `from_pymc3` ([1590](https://github.com/arviz-devs/arviz/pull/1590))

Expand Down
25 changes: 20 additions & 5 deletions arviz/data/io_pystan.py
Original file line number Diff line number Diff line change
Expand Up @@ -627,6 +627,14 @@ def get_sample_stats(fit, warmup=False):
"""Extract sample stats from PyStan fit."""
dtypes = {"divergent__": bool, "n_leapfrog__": np.int64, "treedepth__": np.int64}

rename_dict = {
"divergent": "diverging",
"n_leapfrog": "n_steps",
"treedepth": "tree_depth",
"stepsize": "step_size",
"accept_stat": "acceptance_rate",
}

ndraws_warmup = fit.sim["warmup2"]
if max(ndraws_warmup) == 0:
warmup = False
Expand All @@ -653,17 +661,16 @@ def get_sample_stats(fit, warmup=False):
dtype = dtypes.get(key)
values = values.astype(dtype)
name = re.sub("__$", "", key)
name = "diverging" if name == "divergent" else name
name = rename_dict.get(name, name)
data[name] = values

data_warmup = OrderedDict()
if warmup:
for key, values in extraction_warmup.items():
values = np.stack(values, axis=0)
dtype = dtypes.get(key)
values = values.astype(dtype)
values = values.astype(dtypes.get(key))
name = re.sub("__$", "", key)
name = "diverging" if name == "divergent" else name
name = rename_dict.get(name, name)
data_warmup[name] = values

return data, data_warmup
Expand Down Expand Up @@ -773,6 +780,14 @@ def get_sample_stats_stan3(fit, variables=None, ignore=None):
"""Extract sample stats from PyStan3 fit."""
dtypes = {"divergent__": bool, "n_leapfrog__": np.int64, "treedepth__": np.int64}

rename_dict = {
"divergent": "diverging",
"n_leapfrog": "n_steps",
"treedepth": "tree_depth",
"stepsize": "step_size",
"accept_stat": "acceptance_rate",
}

if isinstance(variables, str):
variables = [variables]
if isinstance(ignore, str):
Expand All @@ -789,7 +804,7 @@ def get_sample_stats_stan3(fit, variables=None, ignore=None):
dtype = dtypes.get(key)
values = values.astype(dtype)
name = re.sub("__$", "", key)
name = "diverging" if name == "divergent" else name
name = rename_dict.get(name, name)
data[name] = values

return data
Expand Down

0 comments on commit 5a00626

Please sign in to comment.