Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add to_dict method to InferenceData object #1223

Merged
merged 18 commits into from
Aug 13, 2020
Merged
55 changes: 55 additions & 0 deletions arviz/data/inference_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,61 @@ def to_netcdf(self, filename, compress=True, groups=None):
empty_netcdf_file.close()
return filename

def to_dict(self, groups=None):
    """Convert InferenceData to a dictionary following xarray naming conventions.

    Parameters
    ----------
    groups : list or str, optional
        Names of the groups to include in the dictionary. Defaults to every
        group present in the object. Names that do not match a present
        group are ignored. A single name may be given as a string.

    Returns
    -------
    dict
        One entry per selected group, mapping each variable name to its
        values, plus the keys ``coords``, ``dims``, ``pred_dims`` and
        ``attrs``. ``attrs`` is a copy of the attributes of the first
        selected group, or ``None`` when no group is selected.
    """
    ret = {"coords": {}, "dims": {}, "pred_dims": {}}
    attrs = None
    if self._groups_all:  # nothing to export when no group is present
        if groups is None:
            groups = self._groups_all
        else:
            if isinstance(groups, str):
                # A bare string would otherwise be matched by substring,
                # e.g. "posterior" in "posterior_predictive".
                groups = [groups]
            # Keep the order of the groups as stored on the object.
            groups = [group for group in self._groups_all if group in groups]

        for group in groups:
            xr_data = getattr(self, group)
            # Prediction groups keep their dims apart from the other groups.
            if group in ("predictions", "predictions_constant_data"):
                dims_key = "pred_dims"
            else:
                dims_key = "dims"

            data = {}
            for key, value in xr_data.data_vars.items():
                data[key] = value.values
                dims = []
                # NOTE(review): dims are listed in coordinate order, which is
                # assumed to match the variable's dimension order - confirm.
                for k_, v_ in value.coords.items():
                    # Skip the sampling dims and the "<var>_dim_N" names.
                    if k_ not in ("chain", "draw") and not k_.startswith(key + "_dim_"):
                        dims.append(k_)
                        ret["coords"][k_] = v_.values
                if dims:
                    ret[dims_key][key] = dims
            # Assigned once per group, so a selected group is always present
            # in the result even when it holds no variables.
            ret[group] = data

            if attrs is None:
                # Copy so that callers mutating the returned attrs cannot
                # alter the attributes stored on this object.
                attrs = dict(xr_data.attrs)
            elif attrs != xr_data.attrs:
                warnings.warn(
                    "The attributes are not same for all groups. "
                    "Considering only the first group `attrs`"
                )

    ret["attrs"] = attrs
    return ret

def __add__(self, other):
    """Return the concatenation of this object with ``other``.

    Delegates to ``concat`` with ``copy=True`` and ``inplace=False``.
    """
    combined = concat(self, other, copy=True, inplace=False)
    return combined
Expand Down
63 changes: 42 additions & 21 deletions arviz/data/io_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@ def __init__(
constant_data=None,
predictions_constant_data=None,
coords=None,
dims=None
dims=None,
pred_dims=None,
attrs=None,
):
self.posterior = posterior
self.posterior_predictive = posterior_predictive
Expand All @@ -41,6 +43,10 @@ def __init__(
self.predictions_constant_data = predictions_constant_data
self.coords = coords
self.dims = dims
self.pred_dims = dims if pred_dims is None else pred_dims
self.attrs = {} if attrs is None else attrs
self.attrs.pop("created_at", None)
self.attrs.pop("arviz_version", None)

@requires("posterior")
def posterior_to_xarray(self):
Expand All @@ -56,7 +62,9 @@ def posterior_to_xarray(self):
UserWarning,
)

return dict_to_dataset(data, library=None, coords=self.coords, dims=self.dims)
return dict_to_dataset(
data, library=None, coords=self.coords, dims=self.dims, attrs=self.attrs
)

@requires("sample_stats")
def sample_stats_to_xarray(self):
Expand All @@ -73,7 +81,9 @@ def sample_stats_to_xarray(self):
PendingDeprecationWarning,
)

return dict_to_dataset(data, library=None, coords=self.coords, dims=self.dims)
return dict_to_dataset(
data, library=None, coords=self.coords, dims=self.dims, attrs=self.attrs
)

@requires("log_likelihood")
def log_likelihood_to_xarray(self):
Expand All @@ -82,7 +92,9 @@ def log_likelihood_to_xarray(self):
if not isinstance(data, dict):
raise TypeError("DictConverter.log_likelihood is not a dictionary")

return dict_to_dataset(data, library=None, coords=self.coords, dims=self.dims)
return dict_to_dataset(
data, library=None, coords=self.coords, dims=self.dims, attrs=self.attrs
)

@requires("posterior_predictive")
def posterior_predictive_to_xarray(self):
Expand All @@ -91,7 +103,9 @@ def posterior_predictive_to_xarray(self):
if not isinstance(data, dict):
raise TypeError("DictConverter.posterior_predictive is not a dictionary")

return dict_to_dataset(data, library=None, coords=self.coords, dims=self.dims)
return dict_to_dataset(
data, library=None, coords=self.coords, dims=self.dims, attrs=self.attrs
)

@requires("predictions")
def predictions_to_xarray(self):
Expand All @@ -100,7 +114,9 @@ def predictions_to_xarray(self):
if not isinstance(data, dict):
raise TypeError("DictConverter.predictions is not a dictionary")

return dict_to_dataset(data, library=None, coords=self.coords, dims=self.dims)
return dict_to_dataset(
data, library=None, coords=self.coords, dims=self.pred_dims, attrs=self.attrs
)

@requires("prior")
def prior_to_xarray(self):
Expand All @@ -109,7 +125,9 @@ def prior_to_xarray(self):
if not isinstance(data, dict):
raise TypeError("DictConverter.prior is not a dictionary")

return dict_to_dataset(data, library=None, coords=self.coords, dims=self.dims)
return dict_to_dataset(
data, library=None, coords=self.coords, dims=self.dims, attrs=self.attrs
)

@requires("sample_stats_prior")
def sample_stats_prior_to_xarray(self):
Expand All @@ -118,7 +136,9 @@ def sample_stats_prior_to_xarray(self):
if not isinstance(data, dict):
raise TypeError("DictConverter.sample_stats_prior is not a dictionary")

return dict_to_dataset(data, library=None, coords=self.coords, dims=self.dims)
return dict_to_dataset(
data, library=None, coords=self.coords, dims=self.dims, attrs=self.attrs
)

@requires("prior_predictive")
def prior_predictive_to_xarray(self):
Expand All @@ -127,17 +147,17 @@ def prior_predictive_to_xarray(self):
if not isinstance(data, dict):
raise TypeError("DictConverter.prior_predictive is not a dictionary")

return dict_to_dataset(data, library=None, coords=self.coords, dims=self.dims)
return dict_to_dataset(
data, library=None, coords=self.coords, dims=self.dims, attrs=self.attrs
)

def data_to_xarray(self, dct, group):
def data_to_xarray(self, dct, group, dims=None):
"""Convert data to xarray."""
data = dct
if not isinstance(data, dict):
raise TypeError("DictConverter.{} is not a dictionary".format(group))
if self.dims is None:
dims = {}
else:
dims = self.dims
if dims is None:
dims = {} if self.dims is None else self.dims
new_data = dict()
for key, vals in data.items():
vals = utils.one_de(vals)
Expand All @@ -146,12 +166,12 @@ def data_to_xarray(self, dct, group):
vals.shape, key, dims=val_dims, coords=self.coords
)
new_data[key] = xr.DataArray(vals, dims=val_dims, coords=coords)
return xr.Dataset(data_vars=new_data, attrs=make_attrs(library=None))
return xr.Dataset(data_vars=new_data, attrs=make_attrs(attrs=self.attrs, library=None))

@requires("observed_data")
def observed_data_to_xarray(self):
"""Convert observed_data to xarray."""
return self.data_to_xarray(self.observed_data, group="observed_data")
return self.data_to_xarray(self.observed_data, group="observed_data", dims=self.dims)

@requires("constant_data")
def constant_data_to_xarray(self):
Expand All @@ -162,7 +182,7 @@ def constant_data_to_xarray(self):
def predictions_constant_data_to_xarray(self):
"""Convert predictions_constant_data to xarray."""
return self.data_to_xarray(
self.predictions_constant_data, group="predictions_constant_data"
self.predictions_constant_data, group="predictions_constant_data", dims=self.pred_dims
)

def to_inference_data(self):
Expand Down Expand Up @@ -203,13 +223,12 @@ def from_dict(
constant_data=None,
predictions_constant_data=None,
coords=None,
dims=None
dims=None,
pred_dims=None,
attrs=None,
percygautam marked this conversation as resolved.
Show resolved Hide resolved
):
"""Convert Dictionary data into an InferenceData object.

For a usage example read the
:doc:`Cookbook section on from_dict </notebooks/InferenceDataCookbook>`

percygautam marked this conversation as resolved.
Show resolved Hide resolved
Parameters
----------
posterior : dict
Expand Down Expand Up @@ -247,4 +266,6 @@ def from_dict(
predictions_constant_data=predictions_constant_data,
coords=coords,
dims=dims,
pred_dims=pred_dims,
attrs=attrs,
).to_inference_data()