data io - change from_cmdstanpy to work with CmdStanPy version 0.9.68 #1558

Merged: 12 commits, Feb 14, 2021
248 changes: 198 additions & 50 deletions arviz/data/io_cmdstanpy.py
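
For context, a minimal, hedged usage sketch of the conversion path this PR touches. The Stan program, data file, and generated-quantity names ("y_hat", "log_lik") are hypothetical; the sketch assumes CmdStanPy 0.9.68 or newer and the from_cmdstanpy keywords ArviZ already exposes:

import arviz as az
from cmdstanpy import CmdStanModel

# Hypothetical model and data; any CmdStanPy >= 0.9.68 fit exposes stan_vars_cols /
# sampler_vars_cols, which the updated converter detects via hasattr().
model = CmdStanModel(stan_file="eight_schools.stan")
fit = model.sample(data="eight_schools.data.json", chains=4)

idata = az.from_cmdstanpy(
    posterior=fit,
    posterior_predictive="y_hat",  # hypothetical generated quantity
    log_likelihood="log_lik",      # hypothetical generated quantity
)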
@@ -36,7 +36,7 @@ def __init__(
dims=None,
save_warmup=None,
):
self.posterior = posterior
self.posterior = posterior # CmdStanPy CmdStanMCMC object
self.posterior_predictive = posterior_predictive
self.predictions = predictions
self.prior = prior
@@ -57,6 +57,9 @@ def __init__(
@requires("posterior")
def posterior_to_xarray(self):
"""Extract posterior samples from output csv."""
if hasattr(self.posterior, "stan_vars_cols"):
return self.posterior_to_xarray_v68()

columns = self.posterior.column_names

# filter posterior_predictive, predictions and log_likelihood
@@ -124,6 +127,8 @@ def posterior_to_xarray(self):
@requires("posterior")
def sample_stats_to_xarray(self):
"""Extract sample_stats from fit."""
if hasattr(self.posterior, "sampler_vars_cols"):
return self.sample_stats_to_xarray_v68()
dtypes = {"divergent__": bool, "n_leapfrog__": np.int64, "treedepth__": np.int64}

columns = self.posterior.column_names
@@ -159,15 +164,24 @@ def posterior_predictive_to_xarray(self):
if isinstance(posterior_predictive, str):
posterior_predictive = [posterior_predictive]
posterior_predictive = set(posterior_predictive)
valid_cols = [
col for col in columns if col.split("[")[0].split(".")[0] in posterior_predictive
]
data, data_warmup = _unpack_frame(
self.posterior,
columns,
valid_cols,
self.save_warmup,
)

if hasattr(self.posterior, "stan_vars_cols"):
items = list(posterior_predictive)
data, data_warmup = _unpack_fit(
self.posterior,
items,
self.save_warmup,
)
else:
valid_cols = [
col for col in columns if col.split("[")[0].split(".")[0] in posterior_predictive
]
data, data_warmup = _unpack_frame(
self.posterior,
columns,
valid_cols,
self.save_warmup,
)

return (
dict_to_dataset(data, library=self.cmdstanpy, coords=self.coords, dims=self.dims),
@@ -186,13 +200,24 @@ def predictions_to_xarray(self):
if isinstance(predictions, str):
predictions = [predictions]
predictions = set(predictions)
valid_cols = [col for col in columns if col.split("[")[0].split(".")[0] in set(predictions)]
data, data_warmup = _unpack_frame(
self.posterior,
columns,
valid_cols,
self.save_warmup,
)

if hasattr(self.posterior, "stan_vars_cols"):
items = list(predictions)
data, data_warmup = _unpack_fit(
self.posterior,
items,
self.save_warmup,
)
else:
valid_cols = [
col for col in columns if col.split("[")[0].split(".")[0] in set(predictions)
]
data, data_warmup = _unpack_frame(
self.posterior,
columns,
valid_cols,
self.save_warmup,
)

return (
dict_to_dataset(data, library=self.cmdstanpy, coords=self.coords, dims=self.dims),
@@ -211,14 +236,24 @@ def log_likelihood_to_xarray(self):
if isinstance(log_likelihood, str):
log_likelihood = [log_likelihood]
log_likelihood = set(log_likelihood)
valid_cols = [col for col in columns if col.split("[")[0].split(".")[0] in log_likelihood]
data, data_warmup = _unpack_frame(
self.posterior,
columns,
valid_cols,
self.save_warmup,
)

if hasattr(self.posterior, "stan_vars_cols"):
items = list(log_likelihood)
data, data_warmup = _unpack_fit(
self.posterior,
items,
self.save_warmup,
)
else:
valid_cols = [
col for col in columns if col.split("[")[0].split(".")[0] in log_likelihood
]
data, data_warmup = _unpack_frame(
self.posterior,
columns,
valid_cols,
self.save_warmup,
)
return (
dict_to_dataset(
data,
@@ -239,32 +274,42 @@ def log_likelihood_to_xarray(self):
@requires("prior")
def prior_to_xarray(self):
"""Convert prior samples to xarray."""
# filter prior_predictive
columns = self.prior.column_names

# filter posterior_predictive and log_likelihood
prior_predictive = self.prior_predictive
if prior_predictive is None:
prior_predictive = []
elif isinstance(prior_predictive, str):
prior_predictive = [
col for col in columns if prior_predictive == col.split("[")[0].split(".")[0]
]
if hasattr(self.prior, "stan_vars_cols"):
items = list(self.prior.stan_vars_cols.keys())
if self.prior_predictive is not None:
try:
items = _filter(items, self.prior_predictive)
except ValueError:
pass
data, data_warmup = _unpack_fit(
self.prior,
items,
self.save_warmup,
)
else:
prior_predictive = [
col for col in columns if col.split("[")[0].split(".")[0] in set(prior_predictive)
]

invalid_cols = set(prior_predictive + [col for col in columns if col.endswith("__")])

valid_cols = [col for col in columns if col not in invalid_cols]

data, data_warmup = _unpack_frame(
self.prior,
columns,
valid_cols,
self.save_warmup,
)
columns = self.prior.column_names
prior_predictive = self.prior_predictive
if prior_predictive is None:
prior_predictive = []
elif isinstance(prior_predictive, str):
prior_predictive = [
col for col in columns if prior_predictive == col.split("[")[0].split(".")[0]
]
else:
prior_predictive = [
col
for col in columns
if col.split("[")[0].split(".")[0] in set(prior_predictive)
]
invalid_cols = set(prior_predictive + [col for col in columns if col.endswith("__")])
valid_cols = [col for col in columns if col not in invalid_cols]

data, data_warmup = _unpack_frame(
self.prior,
columns,
valid_cols,
self.save_warmup,
)

return (
dict_to_dataset(data, library=self.cmdstanpy, coords=self.coords, dims=self.dims),
@@ -393,6 +438,105 @@ def to_inference_data(self):
},
)

@requires("posterior")
def posterior_to_xarray_v68(self):
Contributor: Could we change the logic so this is "default" and the other function is "old"?

Contributor Author: yes!
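
A rough sketch of the dispatch the reviewer suggests, with the accessor-based path as the default; the method name posterior_to_xarray_pre_v68 is hypothetical and only illustrates the proposed rename of the legacy column-name implementation:

# Sketch of a converter method (not the code as merged)
@requires("posterior")
def posterior_to_xarray(self):
    """Extract posterior samples from the fit."""
    if hasattr(self.posterior, "stan_vars_cols"):
        # CmdStanPy >= 0.9.68: structured accessors, the default path
        return self.posterior_to_xarray_v68()
    # hypothetical name for the legacy implementation kept for older CmdStanPy
    return self.posterior_to_xarray_pre_v68()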

"""Extract posterior samples from output csv."""
items = list(self.posterior.stan_vars_cols.keys())
if self.posterior_predictive is not None:
try:
items = _filter(items, self.posterior_predictive)
except ValueError:
pass
if self.predictions is not None:
try:
items = _filter(items, self.predictions)
except ValueError:
pass
if self.log_likelihood is not None:
try:
items = _filter(items, self.log_likelihood)
except ValueError:
pass

valid_cols = []
for item in items:
valid_cols.extend(self.posterior.stan_vars_cols[item])

data, data_warmup = _unpack_fit(
self.posterior,
items,
self.save_warmup,
)
return (
dict_to_dataset(data, library=self.cmdstanpy, coords=self.coords, dims=self.dims),
dict_to_dataset(
data_warmup, library=self.cmdstanpy, coords=self.coords, dims=self.dims
),
)

@requires("posterior")
def sample_stats_to_xarray_v68(self):
"""Extract sample_stats from fit."""
dtypes = {"divergent__": bool, "n_leapfrog__": np.int64, "treedepth__": np.int64}
items = list(self.posterior.sampler_vars_cols.keys())
data, data_warmup = _unpack_fit(
self.posterior,
items,
self.save_warmup,
)
for item in items:
name = re.sub("__$", "", item)
name = "diverging" if name == "divergent" else name
data[name] = data.pop(item).astype(dtypes.get(item, float))
if data_warmup:
data_warmup[name] = data_warmup.pop(item).astype(dtypes.get(item, float))
return (
dict_to_dataset(data, library=self.cmdstanpy, coords=self.coords, dims=self.dims),
dict_to_dataset(
data_warmup, library=self.cmdstanpy, coords=self.coords, dims=self.dims
),
)
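
Once converted, the renamed sampler statistics can be read straight off the resulting InferenceData. A hedged example; the fit object and the max_treedepth value of 10 are assumptions:

import arviz as az

idata = az.from_cmdstanpy(posterior=fit)  # fit: a CmdStanMCMC object as in the sketch near the top
n_divergences = int(idata.sample_stats["diverging"].sum())
deep_transitions = int((idata.sample_stats["treedepth"] >= 10).sum())  # 10 = assumed max_treedepth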


def _filter(var_names, spec):
if isinstance(spec, str):
var_names.remove(spec)
elif isinstance(spec, list):
for item in spec:
var_names.remove(item)
elif isinstance(spec, dict):
for item in spec.keys():
var_names.remove(item)
return var_names
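
A quick illustration of _filter with hypothetical variable names; the spec may be a string, a list, or a dict, and the named entries are removed from the list in place (callers above wrap this in try/except ValueError for names that are absent):

_filter(["theta", "y_hat", "log_lik"], "y_hat")               # -> ["theta", "log_lik"]
_filter(["theta", "y_hat", "log_lik"], ["y_hat", "log_lik"])  # -> ["theta"]
_filter(["theta", "y_hat"], {"y_hat": ["school"]})            # -> ["theta"]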


def _unpack_fit(fit, items, save_warmup):
"""Transform fit to a dictionary of ndarrays keyed by variable name."""
num_warmup = 0
if save_warmup:
if not fit._save_warmup:  # pylint: disable=protected-access
save_warmup = False
else:
num_warmup = fit.num_draws_warmup

draws = np.swapaxes(fit.draws(inc_warmup=save_warmup), 0, 1)
sample = {}
sample_warmup = {}

for item in items:
if item in fit.stan_vars_cols:
col_idxs = fit.stan_vars_cols[item]
elif item in fit.sampler_vars_cols:
col_idxs = fit.sampler_vars_cols[item]
else:
raise ValueError("fit data, unknown variable: {}".format(item))
if save_warmup:
sample_warmup[item] = draws[:num_warmup, :, col_idxs]
sample[item] = draws[num_warmup:, :, col_idxs]
else:
sample[item] = draws[:, :, col_idxs]

return sample, sample_warmup
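
A hedged usage sketch for _unpack_fit; "theta" is a hypothetical parameter of the Stan program behind fit, and the second dict is empty unless the fit stored warmup draws:

sample, sample_warmup = _unpack_fit(fit, ["theta"], save_warmup=True)
theta_draws = sample["theta"]              # post-warmup draws for the columns of "theta"
if sample_warmup:
    theta_warmup = sample_warmup["theta"]  # warmup draws, present only when saved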


def _unpack_frame(fit, columns, valid_cols, save_warmup):
"""Transform fit to dictionary containing ndarrays.
@@ -414,7 +558,11 @@ def _unpack_frame(fit, columns, valid_cols, save_warmup):
if hasattr(fit, "draws"):
data = fit.draws(inc_warmup=save_warmup)
if save_warmup:
num_warmup = fit._draws_warmup # pylint: disable=protected-access
num_warmup = 0
if hasattr(fit, "_draws_warmup"): # pylint: disable=protected-access
num_warmup = fit._draws_warmup # pylint: disable=protected-access
else:
num_warmup = fit.num_draws_warmup
data_warmup = data[:num_warmup]
data = data[num_warmup:]
else: