Skip to content

Commit

Permalink
support warmup groups in from_pymc3
Browse files Browse the repository at this point in the history
  • Loading branch information
OriolAbril committed May 4, 2020
1 parent 7ab50b0 commit 379c803
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 15 deletions.
70 changes: 55 additions & 15 deletions arviz/data/io_pymc3.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from .. import utils
from .inference_data import InferenceData, concat
from .base import requires, dict_to_dataset, generate_dims_coords, make_attrs, CoordSpec, DimSpec
from ..rcparams import rcParams

if TYPE_CHECKING:
import pymc3 as pm
Expand Down Expand Up @@ -60,7 +61,8 @@ def __init__(
predictions=None,
coords: Optional[Coords] = None,
dims: Optional[Dims] = None,
model=None
model=None,
save_warmup=None,
):
import pymc3
import theano
Expand All @@ -70,6 +72,7 @@ def __init__(
self.pymc3 = pymc3
self.theano = theano

self.save_warmup = rcParams["data.save_warmup"] if save_warmup is None else save_warmup
self.trace = trace

# this permits us to get the model from command-line argument or from with model:
Expand All @@ -86,9 +89,11 @@ def __init__(
0
].model
self.nchains = trace.nchains if hasattr(trace, "nchains") else 1
self.ndraws = len(trace)
self.ndraws = trace.report.n_draws
self.attrs = {"t_sampling": trace.report.t_sampling}
else:
self.nchains = self.ndraws = 0
self.attrs = None

if self.model is None:
warnings.warn(
Expand Down Expand Up @@ -151,7 +156,7 @@ def log_likelihood_vals_point(self, point, var, log_like_fun):

@requires("trace")
@requires("model")
def _extract_log_likelihood(self):
def _extract_log_likelihood(self, trace):
"""Compute log likelihood of each observation."""
# If we have predictions, then we have a thinned trace which does not
# support extracting a log likelihood.
Expand All @@ -165,18 +170,18 @@ def _extract_log_likelihood(self):
]
try:
log_likelihood_dict = self.pymc3.sampling._DefaultTrace( # pylint: disable=protected-access
len(self.trace.chains)
len(trace.chains)
)
except AttributeError:
raise AttributeError(
"Installed version of ArviZ requires PyMC3>=3.8. Please upgrade with "
"`pip install pymc3>=3.8` or `conda install -c conda-forge pymc3>=3.8`."
)
for var, log_like_fun in cached:
for chain in self.trace.chains:
for chain in trace.chains:
log_like_chain = [
self.log_likelihood_vals_point(point, var, log_like_fun)
for point in self.trace.points([chain])
for point in trace.points([chain])
]
log_likelihood_dict.insert(var.name, np.stack(log_like_chain), chain)
return log_likelihood_dict.trace_dict
Expand All @@ -188,21 +193,45 @@ def posterior_to_xarray(self):
self.trace.varnames, include_transformed=False
)
data = {}
data_warmup = {}
for var_name in var_names:
data[var_name] = np.array(self.trace.get_values(var_name, combine=False, squeeze=False))
return dict_to_dataset(data, library=self.pymc3, coords=self.coords, dims=self.dims)
if self.save_warmup:
data_warmup[var_name] = np.array(
self.trace[: -self.ndraws].get_values(var_name, combine=False, squeeze=False)
)
data[var_name] = np.array(
self.trace[-self.ndraws :].get_values(var_name, combine=False, squeeze=False)
)
return (
dict_to_dataset(data, library=self.pymc3, coords=self.coords, dims=self.dims),
dict_to_dataset(data_warmup, library=self.pymc3, coords=self.coords, dims=self.dims),
)

@requires("trace")
def sample_stats_to_xarray(self):
"""Extract sample_stats from PyMC3 trace."""
data = {}
rename_key = {"model_logp": "lp"}
data = {}
data_warmup = {}
for stat in self.trace.stat_names:
name = rename_key.get(stat, stat)
data[name] = np.array(self.trace.get_sampler_stats(stat, combine=False))

return dict_to_dataset(data, library=self.pymc3, dims=None, coords=self.coords)
if name == "tune":
continue
if self.save_warmup:
data_warmup[name] = np.array(
self.trace[: -self.ndraws].get_sampler_stats(stat, combine=False)
)
data[name] = np.array(self.trace[-self.ndraws :].get_sampler_stats(stat, combine=False))

return (
dict_to_dataset(
data, library=self.pymc3, dims=None, coords=self.coords, attrs=self.attrs
),
dict_to_dataset(
data_warmup, library=self.pymc3, dims=None, coords=self.coords, attrs=self.attrs
),
)

@requires("trace")
@requires("model")
Expand All @@ -211,14 +240,20 @@ def log_likelihood_to_xarray(self):
if self.predictions or not self.log_likelihood:
return None
try:
data = self._extract_log_likelihood()
data = self._extract_log_likelihood(self.trace[-self.ndraws :])
except TypeError:
warnings.warn(
"""Could not compute log_likelihood, it will be omitted.
Check your model object or set log_likelihood=False"""
)
return None
return dict_to_dataset(data, library=self.pymc3, dims=self.dims, coords=self.coords)
data_warmup = {}
if self.save_warmup:
data_warmup = self._extract_log_likelihood(self.trace[: -self.ndraws])
return (
dict_to_dataset(data, library=self.pymc3, dims=self.dims, coords=self.coords),
dict_to_dataset(data_warmup, library=self.pymc3, dims=self.dims, coords=self.coords),
)

def translate_posterior_predictive_dict_to_xarray(self, dct) -> xr.Dataset:
"""Take Dict of variables to numpy ndarrays (samples) and translate into dataset."""
Expand Down Expand Up @@ -375,7 +410,7 @@ def to_inference_data(self):
id_dict["predictions_constant_data"] = self.constant_data_to_xarray()
else:
id_dict["constant_data"] = self.constant_data_to_xarray()
return InferenceData(**id_dict)
return InferenceData(save_warmup=self.save_warmup, **id_dict)


def from_pymc3(
Expand All @@ -386,7 +421,8 @@ def from_pymc3(
log_likelihood: Union[bool, Iterable[str]] = True,
coords: Optional[CoordSpec] = None,
dims: Optional[DimSpec] = None,
model: Optional[Model] = None
model: Optional[Model] = None,
save_warmup: Optional[bool] = None,
) -> InferenceData:
"""Convert pymc3 data into an InferenceData object.
Expand All @@ -413,6 +449,9 @@ def from_pymc3(
model : pymc3.Model, optional
Model used to generate ``trace``. It is not necessary to pass ``model`` if in
``with`` context.
save_warmup : bool, optional
Save warmup iterations in the InferenceData object. If not defined, use the default
defined by the rcParams.
Returns
-------
Expand All @@ -426,6 +465,7 @@ def from_pymc3(
coords=coords,
dims=dims,
model=model,
save_warmup=save_warmup,
).to_inference_data()


Expand Down
1 change: 1 addition & 0 deletions scripts/container.sh
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ if [[ $* == *--clear-cache* ]]; then
echo "Removing cached files and models"
find -type d -name __pycache__ -exec rm -rf {} +
rm -f arviz/tests/saved_models/*.pkl
rm -f arviz/tests/saved_models/*.pkl.gzip
fi

# Build container for use of testing or notebook
Expand Down

0 comments on commit 379c803

Please sign in to comment.