add workaround for data groups until next arviz release
OriolAbril committed Mar 26, 2021
1 parent 1a604c3 commit d5726e7
Showing 2 changed files with 55 additions and 31 deletions.
82 changes: 55 additions & 27 deletions pymc3/backends/arviz.py
@@ -20,7 +20,9 @@
from aesara.graph.basic import Constant
from aesara.tensor.sharedvar import SharedVariable
from arviz import InferenceData, concat, rcParams
from arviz.data.base import CoordSpec, DimSpec, dict_to_dataset, requires
from arviz.data.base import CoordSpec, DimSpec
from arviz.data.base import dict_to_dataset as _dict_to_dataset
from arviz.data.base import generate_dims_coords, make_attrs, requires

import pymc3

@@ -98,6 +100,37 @@ def insert(self, k: str, v, idx: int):
self.trace_dict[k][idx, :] = v


def dict_to_dataset(
data,
library=None,
coords=None,
dims=None,
attrs=None,
default_dims=None,
skip_event_dims=None,
index_origin=None,
):
"""Temporal workaround for dict_to_dataset.
Once ArviZ>0.11.2 release is available, only two changes are needed for everything to work.
1) this should be deleted, 2) dict_to_dataset should be imported as is from arviz, no underscore,
also remove unnecessary imports
"""
if default_dims is None:
return _dict_to_dataset(
data, library=library, coords=coords, dims=dims, skip_event_dims=skip_event_dims
)
else:
out_data = {}
for name, vals in data.items():
vals = np.atleast_1d(vals)
val_dims = dims.get(name)
val_dims, coords = generate_dims_coords(vals.shape, name, dims=val_dims, coords=coords)
# build per-variable coords so the caller-level ``coords`` dict is not clobbered between variables
val_coords = {key: xr.IndexVariable((key,), data=coords[key]) for key in val_dims}
out_data[name] = xr.DataArray(vals, dims=val_dims, coords=val_coords)
return xr.Dataset(data_vars=out_data, attrs=make_attrs(library=library))
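
For illustration only (not part of this commit), a minimal sketch of how the ``default_dims=[]`` branch behaves: variables keep only their own dims, with no chain/draw dimensions added, which is what the observed and constant data groups need. It assumes numpy, xarray, and a pymc3 build containing this module are installed; the variable and dim names are made up.

import numpy as np
import pymc3

from pymc3.backends.arviz import dict_to_dataset

# hypothetical observed values keyed by variable name
obs = {"y": np.arange(6).reshape(3, 2)}
ds = dict_to_dataset(
    obs,
    library=pymc3,
    coords={"city": ["a", "b"]},
    dims={"y": ["block", "city"]},
    default_dims=[],  # no chain/draw dims for observed/constant data groups
)
print(ds["y"].dims)  # ('block', 'city')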


class InferenceDataConverter: # pylint: disable=too-many-instance-attributes
"""Encapsulate InferenceData specific logic."""

@@ -196,14 +229,13 @@ def arbitrary_element(dct: Dict[Any, np.ndarray]) -> np.ndarray:
self.dims = {**model_dims, **self.dims}

self.density_dist_obs = density_dist_obs
self.observations, self.multi_observations = self.find_observations()
self.observations = self.find_observations()

def find_observations(self) -> Tuple[Optional[Dict[str, Var]], Optional[Dict[str, Var]]]:
def find_observations(self) -> Optional[Dict[str, Var]]:
"""If there are observations available, return them as a dictionary."""
if self.model is None:
return (None, None)
return None
observations = {}
multi_observations = {}
for obs in self.model.observed_RVs:
aux_obs = getattr(obs.tag, "observations", None)
if aux_obs is not None:
@@ -215,7 +247,7 @@ def find_observations(self) -> Tuple[Optional[Dict[str, Var]], Optional[Dict[str
else:
warnings.warn(f"No data for observation {obs}")

return observations, multi_observations
return observations

def split_trace(self) -> Tuple[Union[None, "MultiTrace"], Union[None, "MultiTrace"]]:
"""Split MultiTrace object into posterior and warmup.
@@ -302,15 +334,15 @@ def posterior_to_xarray(self):
coords=self.coords,
dims=self.dims,
attrs=self.attrs,
# index_origin=self.index_origin,
index_origin=self.index_origin,
),
dict_to_dataset(
data_warmup,
library=pymc3,
coords=self.coords,
dims=self.dims,
attrs=self.attrs,
# index_origin=self.index_origin,
index_origin=self.index_origin,
),
)

@@ -344,15 +376,15 @@ def sample_stats_to_xarray(self):
dims=None,
coords=self.coords,
attrs=self.attrs,
# index_origin=self.index_origin,
index_origin=self.index_origin,
),
dict_to_dataset(
data_warmup,
library=pymc3,
dims=None,
coords=self.coords,
attrs=self.attrs,
# index_origin=self.index_origin,
index_origin=self.index_origin,
),
)

@@ -385,15 +417,15 @@ def log_likelihood_to_xarray(self):
dims=self.dims,
coords=self.coords,
skip_event_dims=True,
# index_origin=self.index_origin,
index_origin=self.index_origin,
),
dict_to_dataset(
data_warmup,
library=pymc3,
dims=self.dims,
coords=self.coords,
skip_event_dims=True,
# index_origin=self.index_origin,
index_origin=self.index_origin,
),
)

@@ -415,11 +447,7 @@ def translate_posterior_predictive_dict_to_xarray(self, dct) -> xr.Dataset:
k,
)
return dict_to_dataset(
data,
library=pymc3,
coords=self.coords,
# dims=self.dims,
# index_origin=self.index_origin
data, library=pymc3, coords=self.coords, dims=self.dims, index_origin=self.index_origin
)

@requires(["posterior_predictive"])
@@ -454,25 +482,25 @@ def priors_to_xarray(self):
{k: np.expand_dims(self.prior[k], 0) for k in var_names},
library=pymc3,
coords=self.coords,
# dims=self.dims,
# index_origin=self.index_origin,
dims=self.dims,
index_origin=self.index_origin,
)
)
return priors_dict

@requires(["observations", "multi_observations"])
@requires("observations")
@requires("model")
def observed_data_to_xarray(self):
"""Convert observed data to xarray."""
if self.predictions:
return None
return dict_to_dataset(
{**self.observations, **self.multi_observations},
self.observations,
library=pymc3,
coords=self.coords,
# dims=self.dims,
# default_dims=[],
# index_origin=self.index_origin,
dims=self.dims,
default_dims=[],
index_origin=self.index_origin,
)

@requires(["trace", "predictions"])
Expand Down Expand Up @@ -517,9 +545,9 @@ def is_data(name, var) -> bool:
constant_data,
library=pymc3,
coords=self.coords,
# dims=self.dims,
# default_dims=[],
# index_origin=self.index_origin,
dims=self.dims,
default_dims=[],
index_origin=self.index_origin,
)

def to_inference_data(self):
4 changes: 0 additions & 4 deletions pymc3/tests/test_idata_conversion.py
@@ -570,10 +570,6 @@ def test_multivariate_observations(self):


class TestPyMC3WarmupHandling:
@pytest.mark.skipif(
not hasattr(pm.backends.base.SamplerReport, "n_draws"),
reason="requires pymc3 3.9 or higher",
)
@pytest.mark.parametrize("save_warmup", [False, True])
@pytest.mark.parametrize("chains", [1, 2])
@pytest.mark.parametrize("tune,draws", [(0, 50), (10, 40), (30, 0)])
