From 956bf19bfc32cba87c752d5c47e43b37103e8fbd Mon Sep 17 00:00:00 2001 From: percygautam Date: Mon, 20 Jan 2020 02:41:41 +0530 Subject: [PATCH 01/17] Added api link feature to gallery examples --- doc/sphinxext/gallery_generator.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/doc/sphinxext/gallery_generator.py b/doc/sphinxext/gallery_generator.py index 391cee8879..a089f46030 100644 --- a/doc/sphinxext/gallery_generator.py +++ b/doc/sphinxext/gallery_generator.py @@ -40,6 +40,7 @@ def execfile(filename, globals=None, locals=None): .. image:: {img_file} **Python source code:** :download:`[download source: {fname}]<{fname}>` +**API documentation:** `{api_name} <../../generated/arviz.{api_name}>`_ .. literalinclude:: {fname} :lines: {end_line}- @@ -54,6 +55,7 @@ def execfile(filename, globals=None, locals=None): :source-position: none **Python source code:** :download:`[download source: {fname}]<{fname}>` +**API documentation:** `{api_name} <../../generated/arviz.{api_name}>`_ .. literalinclude:: {fname} :lines: {end_line}- @@ -236,6 +238,12 @@ def thumbfilename(self): pngfile = self.modulename + "_thumb.png" return pngfile + @property + def apiname(self): + name = self.modulename.split("_") + name = name[1::] + return "_".join(name) + @property def sphinxtag(self): return self.modulename @@ -320,6 +328,7 @@ def contents_entry(self): ".. raw:: html\n\n" "
<...>\n" "    <...>\n" + "    <...>\n" "    <...>\n" "    <...>{3}<...>
\n" @@ -380,6 +389,7 @@ def main(app): fname=ex.pyfilename, absfname=op.join(target_dir, ex.pyfilename), img_file=ex.pngfilename, + api_name=ex.apiname, ) with open(op.join(target_dir, ex.rstfilename), "w") as f: f.write(output) From 0dfb396b31b0e782a8da52e0ccc29e005a5cebcc Mon Sep 17 00:00:00 2001 From: percygautam Date: Mon, 20 Jan 2020 02:50:53 +0530 Subject: [PATCH 02/17] Added api link feature to gallery examples --- doc/sphinxext/gallery_generator.py | 1 - 1 file changed, 1 deletion(-) diff --git a/doc/sphinxext/gallery_generator.py b/doc/sphinxext/gallery_generator.py index a089f46030..458835d9f2 100644 --- a/doc/sphinxext/gallery_generator.py +++ b/doc/sphinxext/gallery_generator.py @@ -328,7 +328,6 @@ def contents_entry(self): ".. raw:: html\n\n" "
<...>\n" "    <...>\n" - "    <...>\n" "    <...>\n" "    <...>{3}<...>
\n" From 2057b29ccd4eddc79ffb304d99b656227430f37c Mon Sep 17 00:00:00 2001 From: percygautam Date: Mon, 20 Jan 2020 23:42:09 +0530 Subject: [PATCH 03/17] Removed error for getting api links --- doc/sphinxext/gallery_generator.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/doc/sphinxext/gallery_generator.py b/doc/sphinxext/gallery_generator.py index 458835d9f2..a5b9f8c017 100644 --- a/doc/sphinxext/gallery_generator.py +++ b/doc/sphinxext/gallery_generator.py @@ -240,9 +240,13 @@ def thumbfilename(self): @property def apiname(self): - name = self.modulename.split("_") - name = name[1::] - return "_".join(name) + name="" + with open(op.join(self.target_dir, self.pyfilename), "r") as file: + regex = r"az\.(plot\_[a-z_]+)\(" + matches = re.finditer(regex, file.read(), re.MULTILINE) + for matchNum, match in enumerate(matches, start=1): + name = match.group(1) + return name @property def sphinxtag(self): From 1960c82ddff9ebe6d4087549fbb321a3da37987e Mon Sep 17 00:00:00 2001 From: rpgoldman Date: Sun, 19 Jan 2020 16:47:42 -0600 Subject: [PATCH 04/17] Populate InferenceData with out-of-sample prediction results from PyMC3 predictive samples (#983) Adds from_pymc3_predictions to add predictions and constant_data_predictions groups of inference data objects. Co-authored-by: Oriol Abril --- CHANGELOG.md | 2 +- arviz/data/__init__.py | 3 +- arviz/data/base.py | 2 +- arviz/data/inference_data.pyi | 77 ++++++++++ arviz/data/io_pymc3.py | 257 ++++++++++++++++++++++++++-------- arviz/tests/test_data_pymc.py | 153 +++++++++++++++++++- 6 files changed, 430 insertions(+), 64 deletions(-) create mode 100644 arviz/data/inference_data.pyi diff --git a/CHANGELOG.md b/CHANGELOG.md index 50ba8c05bd..f396080571 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,7 +3,7 @@ ## v0.x.x Unreleased ### New features - +* Add out-of-sample predictions (`predictions` and `predictions_constant_data` groups) to pymc3 translations. (#983) ### Maintenance and fixes * Fixed bug in extracting prior samples for cmdstanpy. (#979) diff --git a/arviz/data/__init__.py b/arviz/data/__init__.py index 5226a41632..96372455d3 100644 --- a/arviz/data/__init__.py +++ b/arviz/data/__init__.py @@ -7,7 +7,7 @@ from .io_cmdstan import from_cmdstan from .io_cmdstanpy import from_cmdstanpy from .io_dict import from_dict -from .io_pymc3 import from_pymc3 +from .io_pymc3 import from_pymc3, from_pymc3_predictions from .io_pystan import from_pystan from .io_emcee import from_emcee from .io_pyro import from_pyro @@ -25,6 +25,7 @@ "convert_to_dataset", "convert_to_inference_data", "from_pymc3", + "from_pymc3_predictions", "from_pystan", "from_emcee", "from_cmdstan", diff --git a/arviz/data/base.py b/arviz/data/base.py index 5adf149e2c..995037715f 100644 --- a/arviz/data/base.py +++ b/arviz/data/base.py @@ -18,7 +18,7 @@ class requires: # pylint: disable=invalid-name If the decorator is called various times on the same function with different attributes, it will return None if one of them is missing. If instead a list of attributes is passed, it will return None if all attributes in the list are - missing. Both functionalities can be combines as desired. + missing. Both functionalities can be combined as desired. 
""" def __init__(self, *props): diff --git a/arviz/data/inference_data.pyi b/arviz/data/inference_data.pyi new file mode 100644 index 0000000000..9fdb660dce --- /dev/null +++ b/arviz/data/inference_data.pyi @@ -0,0 +1,77 @@ +from typing import Optional, List, overload, Iterable, TYPE_CHECKING + +if TYPE_CHECKING: + from typing_extensions import Literal +import xarray as xr + +# pylint has some problems with stub files... +# pylint: disable=unused-argument, multiple-statements + +class InferenceData: + posterior: Optional[xr.Dataset] + observations: Optional[xr.Dataset] + constant_data: Optional[xr.Dataset] + prior: Optional[xr.Dataset] + prior_predictive: Optional[xr.Dataset] + posterior_predictive: Optional[xr.Dataset] + predictions: Optional[xr.Dataset] + predictions_constant_data: Optional[xr.Dataset] + def __init__(self, **kwargs): ... + def __repr__(self) -> str: ... + def __delattr__(self, group: str) -> None: ... + def __add__(self, other: "InferenceData"): ... + @staticmethod + def from_netcdf(filename: str) -> "InferenceData": ... + def to_netcdf( + self, + filename: str, + compress: bool = True, + groups: Optional[List[str]] = None, # pylint: disable=line-too-long + ) -> str: ... + def sel( + self, inplace: bool = False, chain_prior: bool = False, **kwargs + ) -> "InferenceData": ... + +@overload +def concat( + *args, + dim: Optional[str] = None, + copy: bool = True, + inplace: "Literal[True]", + reset_dim: bool = True, +) -> None: ... +@overload +def concat( + *args, + dim: Optional[str] = None, + copy: bool = True, + inplace: "Literal[False]", + reset_dim: bool = True, +) -> InferenceData: ... +@overload +def concat( + ids: Iterable[InferenceData], + dim: Optional[str] = None, + *, + copy: bool = True, + inplace: "Literal[False]", + reset_dim: bool = True, +) -> InferenceData: ... +@overload +def concat( + ids: Iterable[InferenceData], + dim: Optional[str] = None, + *, + copy: bool = True, + inplace: "Literal[True]", + reset_dim: bool = True, +) -> None: ... +@overload +def concat( + ids: Iterable[InferenceData], + dim: Optional[str] = None, + *, + copy: bool = True, + inplace: bool = False, + reset_dim: bool = True, +) -> Optional[InferenceData]: ... diff --git a/arviz/data/io_pymc3.py b/arviz/data/io_pymc3.py index ec7c29ad4c..0fc7259484 100644 --- a/arviz/data/io_pymc3.py +++ b/arviz/data/io_pymc3.py @@ -1,29 +1,52 @@ """PyMC3-specific conversion code.""" import logging from typing import Dict, List, Any, Optional, TYPE_CHECKING +from types import ModuleType import numpy as np import xarray as xr from .. import utils -from .inference_data import InferenceData +from .inference_data import InferenceData, concat from .base import requires, dict_to_dataset, generate_dims_coords, make_attrs if TYPE_CHECKING: import pymc3 as pm + from pymc3 import MultiTrace, Model # pylint: disable=invalid-name + import theano + from typing import Set # pylint: disable=ungrouped-imports +else: + MultiTrace = Any # pylint: disable=invalid-name + Model = Any # pylint: disable=invalid-name + +___all__ = [""] _log = logging.getLogger(__name__) Coords = Dict[str, List[Any]] Dims = Dict[str, List[str]] +# random variable object ... 
+Var = Any # pylint: disable=invalid-name + + +def _monkey_patch_pymc3(pm: ModuleType) -> None: # pylint: disable=invalid-name + assert pm.__name__ == "pymc3" + def fixed_eq(self, other): + """Use object identity for MultiObservedRV equality.""" + return self is other -class PyMC3Converter: + if tuple([int(x) for x in pm.__version__.split(".")]) < (3, 9): # type: ignore + pm.model.MultiObservedRV.__eq__ = fixed_eq # type: ignore + + +class PyMC3Converter: # pylint: disable=too-many-instance-attributes """Encapsulate PyMC3 specific logic.""" model = None # type: Optional[pm.Model] nchains = None # type: int ndraws = None # type: int posterior_predictive = None # Type: Optional[Dict[str, np.ndarray]] + predictions = None # Type: Optional[Dict[str, np.ndarray]] prior = None # Type: Optional[Dict[str, np.ndarray]] def __init__( @@ -32,34 +55,40 @@ def __init__( trace=None, prior=None, posterior_predictive=None, + predictions=None, coords: Optional[Coords] = None, dims: Optional[Dims] = None, model=None ): import pymc3 + import theano + + _monkey_patch_pymc3(pymc3) self.pymc3 = pymc3 + self.theano = theano self.trace = trace + # this permits us to get the model from command-line argument or from with model: + try: + self.model = self.pymc3.modelcontext(model or self.model) + except TypeError: + self.model = None + # This next line is brittle and may not work forever, but is a secret # way to access the model from the trace. if trace is not None: - self.model = self.trace._straces[0].model # pylint: disable=protected-access + if self.model is None: + self.model = self.trace._straces[0].model # pylint: disable=protected-access self.nchains = trace.nchains if hasattr(trace, "nchains") else 1 self.ndraws = len(trace) else: - self.model = None self.nchains = self.ndraws = 0 - # this permits us to get the model from command-line argument or from with model: - try: - self.model = self.pymc3.modelcontext(model or self.model) - except TypeError: - self.model = None - self.prior = prior self.posterior_predictive = posterior_predictive + self.predictions = predictions def arbitrary_element(dct: Dict[Any, np.ndarray]) -> np.ndarray: return next(iter(dct.values())) @@ -68,29 +97,38 @@ def arbitrary_element(dct: Dict[Any, np.ndarray]) -> np.ndarray: # if you have a posterior_predictive built with keep_dims, # you'll lose here, but there's nothing I can do about that. 
self.nchains = 1 - aelem = ( - arbitrary_element(prior) - if posterior_predictive is None - else arbitrary_element(posterior_predictive) - ) + get_from = None + if predictions is not None: + get_from = predictions + elif prior is not None: + get_from = prior + elif posterior_predictive is not None: + get_from = posterior_predictive + if get_from is None: + # pylint: disable=line-too-long + raise ValueError( + """When constructing InferenceData must have at least + one of trace, prior, posterior_predictive or predictions.""" + ) + + aelem = arbitrary_element(get_from) self.ndraws = aelem.shape[0] self.coords = coords self.dims = dims - self.observations = ( - None - if self.trace is None - else True - if any( - hasattr(obs, "observations") - for obs in self.trace._straces[ # pylint: disable=protected-access - 0 - ].model.observed_RVs - ) - else None - ) - if self.observations is not None: - self.observations = {obs.name: obs.observations for obs in self.model.observed_RVs} + self.observations = self.find_observations() + + def find_observations(self) -> Optional[Dict[str, Var]]: + """If there are observations available, return them as a dictionary.""" + has_observations = False + if self.trace is not None: + assert self.model is not None, "Cannot identify observations without PymC3 model" + if any((hasattr(obs, "observations") for obs in self.model.observed_RVs)): + has_observations = True + if has_observations: + assert self.model is not None + return {obs.name: obs.observations for obs in self.model.observed_RVs} + return None @requires("trace") @requires("model") @@ -99,7 +137,9 @@ def _extract_log_likelihood(self): Return None if there is not exactly 1 observed random variable. """ - if len(self.model.observed_RVs) != 1: + # If we have predictions, then we have a thinned trace which does not + # support extracting a log likelihood. + if len(self.model.observed_RVs) != 1 or self.predictions: return None, None else: if self.dims is not None: @@ -155,11 +195,10 @@ def sample_stats_to_xarray(self): return dict_to_dataset(data, library=self.pymc3, dims=dims, coords=self.coords) - @requires("posterior_predictive") - def posterior_predictive_to_xarray(self): - """Convert posterior_predictive samples to xarray.""" + def translate_posterior_predictive_dict_to_xarray(self, dct) -> xr.Dataset: + """Take Dict of variables to numpy ndarrays (samples) and translate into dataset.""" data = {} - for k, ary in self.posterior_predictive.items(): + for k, ary in dct.items(): shape = ary.shape if shape[0] == self.nchains and shape[1] == self.ndraws: data[k] = ary @@ -167,12 +206,24 @@ def posterior_predictive_to_xarray(self): data[k] = ary.reshape((self.nchains, self.ndraws, *shape[1:])) else: data[k] = utils.expand_dims(ary) + # pylint: disable=line-too-long _log.warning( - "posterior predictive shape not compatible with number of chains and draws. " - "This can mean that some draws or even whole chains are not represented." + "posterior predictive variable %s's shape not compatible with number of chains and draws. 
" + "This can mean that some draws or even whole chains are not represented.", + k, ) return dict_to_dataset(data, library=self.pymc3, coords=self.coords, dims=self.dims) + @requires(["posterior_predictive"]) + def posterior_predictive_to_xarray(self): + """Convert posterior_predictive samples to xarray.""" + return self.translate_posterior_predictive_dict_to_xarray(self.posterior_predictive) + + @requires(["predictions"]) + def predictions_to_xarray(self): + """Convert predictions (out of sample predictions) to xarray.""" + return self.translate_posterior_predictive_dict_to_xarray(self.predictions) + def priors_to_xarray(self): """Convert prior samples (and if possible prior predictive too) to xarray.""" if self.prior is None: @@ -224,21 +275,34 @@ def observed_data_to_xarray(self): observed_data[name] = xr.DataArray(vals, dims=val_dims, coords=coords) return xr.Dataset(data_vars=observed_data, attrs=make_attrs(library=self.pymc3)) - @requires("trace") + @requires(["trace", "predictions"]) @requires("model") def constant_data_to_xarray(self): """Convert constant data to xarray.""" - model_vars = self.pymc3.util.get_default_varnames( # pylint: disable=no-member - self.trace.varnames, include_transformed=True - ) - if self.observations is not None: - model_vars.extend( - [obs.name for obs in self.observations.values() if hasattr(obs, "name")] + # For constant data, we are concerned only with deterministics and data. + # The constant data vars must be either pm.Data (TensorSharedVariable) or pm.Deterministic + constant_data_vars = {} # type: Dict[str, Var] + for var in self.model.deterministics: + ancestors = self.theano.tensor.gof.graph.ancestors(var.owner.inputs) + # no dependency on a random variable + if not any((isinstance(a, self.pymc3.model.PyMC3Variable) for a in ancestors)): + constant_data_vars[var.name] = var + + def is_data(name, var) -> bool: + assert self.model is not None + return ( + var not in self.model.deterministics + and var not in self.model.observed_RVs + and var not in self.model.free_RVs + and (self.observations is None or name not in self.observations) ) - model_vars.extend(self.observations.keys()) - constant_data_vars = { - name: var for name, var in self.model.named_vars.items() if name not in model_vars - } + + # I don't know how to find pm.Data, except that they are named variables that aren't + # observed or free RVs, nor are they deterministics, and then we eliminate observations. 
+ for name, var in self.model.named_vars.items(): + if is_data(name, var): + constant_data_vars[name] = var + if not constant_data_vars: return None if self.dims is None: @@ -249,6 +313,9 @@ def constant_data_to_xarray(self): for name, vals in constant_data_vars.items(): if hasattr(vals, "get_value"): vals = vals.get_value() + # this might be a Deterministic, and must be evaluated + elif hasattr(self.model[name], "eval"): + vals = self.model[name].eval() vals = np.atleast_1d(vals) val_dims = dims.get(name) val_dims, coords = generate_dims_coords( @@ -256,26 +323,32 @@ def constant_data_to_xarray(self): ) # filter coords based on the dims coords = {key: xr.IndexVariable((key,), data=coords[key]) for key in val_dims} - constant_data[name] = xr.DataArray(vals, dims=val_dims, coords=coords) + try: + constant_data[name] = xr.DataArray(vals, dims=val_dims, coords=coords) + except ValueError as e: # pylint: disable=invalid-name + raise ValueError("Error translating constant_data variable %s: %s" % (name, e)) return xr.Dataset(data_vars=constant_data, attrs=make_attrs(library=self.pymc3)) def to_inference_data(self): """Convert all available data to an InferenceData object. - Note that if groups can not be created (i.e., there is no `trace`, so + Note that if groups can not be created (e.g., there is no `trace`, so the `posterior` and `sample_stats` can not be extracted), then the InferenceData will not have those groups. """ - return InferenceData( - **{ - "posterior": self.posterior_to_xarray(), - "sample_stats": self.sample_stats_to_xarray(), - "posterior_predictive": self.posterior_predictive_to_xarray(), - **self.priors_to_xarray(), - "observed_data": self.observed_data_to_xarray(), - "constant_data": self.constant_data_to_xarray(), - } - ) + id_dict = { + "posterior": self.posterior_to_xarray(), + "sample_stats": self.sample_stats_to_xarray(), + "posterior_predictive": self.posterior_predictive_to_xarray(), + "predictions": self.predictions_to_xarray(), + **self.priors_to_xarray(), + "observed_data": self.observed_data_to_xarray(), + } + if self.predictions: + id_dict["predictions_constant_data"] = self.constant_data_to_xarray() + else: + id_dict["constant_data"] = self.constant_data_to_xarray() + return InferenceData(**id_dict) def from_pymc3( @@ -290,3 +363,69 @@ def from_pymc3( dims=dims, model=model, ).to_inference_data() + + +### Later I could have this return ``None`` if the ``idata_orig`` argument is supplied. But +### perhaps we should have an inplace argument? +def from_pymc3_predictions( + predictions, + posterior_trace: Optional[MultiTrace] = None, + model: Optional[Model] = None, + coords=None, + dims=None, + idata_orig: Optional[InferenceData] = None, + inplace: bool = False, +) -> InferenceData: + """Translate out-of-sample predictions into ``InferenceData``. + + Parameters + ---------- + predictions: Dict[str, np.ndarray] + The predictions are the return value of ``pymc3.sample_posterior_predictive``, + a dictionary of strings (variable names) to numpy ndarrays (draws). + posterior_trace: pm.MultiTrace + This should be a trace that has been thinned appropriately for + ``pymc3.sample_posterior_predictive``. Specifically, any variable whose shape is + a deterministic function of the shape of any predictor (explanatory, independent, etc.) + variables must be *removed* from this trace. + model: pymc3.Model + This argument is *not* optional, unlike in conventional uses of ``from_pymc3``. 
+ The reason is that the posterior_trace argument is likely to supply an incorrect + value of model. + coords: Dict[str, array-like[Any]] + Coordinates for the variables. Map from coordinate names to coordinate values. + dims: Dict[str, array-like[str]] + Map from variable name to ordered set of coordinate names. + idata_orig: InferenceData, optional + If supplied, then modify this inference data in place, adding ``predictions`` and + (if available) ``predictions_constant_data`` groups. If this is not supplied, make a + fresh InferenceData + inplace: boolean, optional + If idata_orig is supplied and inplace is True, merge the predictions into idata_orig, + rather than returning a fresh InferenceData object. + + Returns + ------- + InferenceData: + May be modified ``idata_orig``. + """ + if inplace and not idata_orig: + raise ValueError( + ( + "Do not pass True for inplace unless passing" + "an existing InferenceData as idata_orig" + ) + ) + new_idata = PyMC3Converter( + trace=posterior_trace, predictions=predictions, model=model, coords=coords, dims=dims + ).to_inference_data() + if idata_orig is None: + return new_idata + elif inplace: + concat([idata_orig, new_idata], dim=None, inplace=True) + return idata_orig + else: + # if we are not returning in place, then merge the old groups into the new inference + # data and return that. + concat([new_idata, idata_orig], dim=None, copy=True, inplace=True) + return new_idata diff --git a/arviz/tests/test_data_pymc.py b/arviz/tests/test_data_pymc.py index c91aa40960..cfbce630f5 100644 --- a/arviz/tests/test_data_pymc.py +++ b/arviz/tests/test_data_pymc.py @@ -1,10 +1,14 @@ # pylint: disable=no-member, invalid-name, redefined-outer-name +from sys import version_info +from typing import Tuple, Dict +import pytest + + import numpy as np from numpy import ma import pymc3 as pm -import pytest -from arviz import from_pymc3 +from arviz import from_pymc3, from_pymc3_predictions, InferenceData from .helpers import ( # pylint: disable=unused-import chains, check_multiple_attrs, @@ -38,6 +42,41 @@ def get_inference_data(self, data, eight_schools_params): posterior_predictive, ) + def get_predictions_inference_data( + self, data, eight_schools_params, inplace + ) -> Tuple[InferenceData, Dict[str, np.ndarray]]: + with data.model: + prior = pm.sample_prior_predictive() + posterior_predictive = pm.sample_posterior_predictive(data.obj) + + idata = from_pymc3( + trace=data.obj, + prior=prior, + coords={"school": np.arange(eight_schools_params["J"])}, + dims={"theta": ["school"], "eta": ["school"]}, + ) + assert isinstance(idata, InferenceData) + extended = from_pymc3_predictions( + posterior_predictive, idata_orig=idata, inplace=inplace + ) + assert isinstance(extended, InferenceData) + assert (id(idata) == id(extended)) == inplace + return (extended, posterior_predictive) + + def make_predictions_inference_data( + self, data, eight_schools_params + ) -> Tuple[InferenceData, Dict[str, np.ndarray]]: + with data.model: + posterior_predictive = pm.sample_posterior_predictive(data.obj) + idata = from_pymc3_predictions( + posterior_predictive, + posterior_trace=data.obj, + coords={"school": np.arange(eight_schools_params["J"])}, + dims={"theta": ["school"], "eta": ["school"]}, + ) + assert isinstance(idata, InferenceData) + return idata, posterior_predictive + def test_from_pymc(self, data, eight_schools_params, chains, draws): inference_data, posterior_predictive = self.get_inference_data(data, eight_schools_params) test_dict = { @@ -57,6 +96,66 @@ def 
test_from_pymc(self, data, eight_schools_params, chains, draws): np.isclose(ivalues[chain], values[chain * draws : (chain + 1) * draws]) ) + def test_from_pymc_predictions(self, data, eight_schools_params): + "Test that we can add predictions to a previously-existing InferenceData." + test_dict = { + "posterior": ["mu", "tau", "eta", "theta"], + "sample_stats": ["diverging", "log_likelihood"], + "predictions": ["obs"], + "prior": ["mu", "tau", "eta", "theta"], + "observed_data": ["obs"], + } + + # check adding non-destructively + inference_data, posterior_predictive = self.get_predictions_inference_data( + data, eight_schools_params, False + ) + fails = check_multiple_attrs(test_dict, inference_data) + assert not fails + for key, values in posterior_predictive.items(): + ivalues = inference_data.predictions[key] + assert ivalues.shape[0] == 1 # one chain in predictions + assert np.all(np.isclose(ivalues[0], values)) + + # check adding in place + inference_data, posterior_predictive = self.get_predictions_inference_data( + data, eight_schools_params, True + ) + fails = check_multiple_attrs(test_dict, inference_data) + assert not fails + for key, values in posterior_predictive.items(): + ivalues = inference_data.predictions[key] + assert ivalues.shape[0] == 1 # one chain in predictions + assert np.all(np.isclose(ivalues[0], values)) + + def test_from_pymc_predictions_new(self, data, eight_schools_params): + # check creating new + inference_data, posterior_predictive = self.make_predictions_inference_data( + data, eight_schools_params + ) + test_dict = { + "posterior": ["mu", "tau", "eta", "theta"], + "predictions": ["obs"], + "observed_data": ["obs"], + } + fails = check_multiple_attrs(test_dict, inference_data) + assert not fails + for key, values in posterior_predictive.items(): + ivalues = inference_data.predictions[key] + # could the following better be done by simply flattening both the ivalues + # and the values? 
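+ # (probably, yes: np.isclose(ivalues.values.ravel(), values.ravel()).all() would cover both branches, at the cost of the explicit shape checks below)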
+ if len(ivalues.shape) == 3: + ivalues_arr = np.reshape( + ivalues.values, (ivalues.shape[0] * ivalues.shape[1], ivalues.shape[2]) + ) + elif len(ivalues.shape) == 2: + ivalues_arr = np.reshape(ivalues.values, (ivalues.shape[0] * ivalues.shape[1])) + else: + raise ValueError("Unexpected values shape for variable %s" % key) + assert (ivalues.shape[0] == 2) and (ivalues.shape[1] == 500) + assert values.shape[0] == 1000 + assert np.all(np.isclose(ivalues_arr, values)) + def test_posterior_predictive_keep_size(self, data, chains, draws, eight_schools_params): with data.model: posterior_predictive = pm.sample_posterior_predictive(data.obj, keep_size=True) @@ -119,6 +218,9 @@ def test_multiple_observed_rv(self): assert not fails assert not hasattr(inference_data.sample_stats, "log_likelihood") + @pytest.mark.skipif( + version_info < (3, 6), reason="Requires updated PyMC3, which needs Python 3.6" + ) def test_multiple_observed_rv_without_observations(self): with pm.Model(): mu = pm.Normal("mu") @@ -154,6 +256,53 @@ def test_constant_data(self): fails = check_multiple_attrs(test_dict, inference_data) assert not fails + def test_constant_data_with_model_context(self): + with pm.Model(): + x = pm.Data("x", [1.0, 2.0, 3.0]) + y = pm.Data("y", [1.0, 2.0, 3.0]) + beta = pm.Normal("beta", 0, 1) + obs = pm.Normal("obs", x * beta, 1, observed=y) # pylint: disable=unused-variable + trace = pm.sample(100, tune=100) + + inference_data = from_pymc3(trace=trace) + test_dict = {"posterior": ["beta"], "observed_data": ["obs"], "constant_data": ["x"]} + fails = check_multiple_attrs(test_dict, inference_data) + assert not fails + + def test_predictions_constant_data(self): + with pm.Model(): + x = pm.Data("x", [1.0, 2.0, 3.0]) + y = pm.Data("y", [1.0, 2.0, 3.0]) + beta = pm.Normal("beta", 0, 1) + obs = pm.Normal("obs", x * beta, 1, observed=y) # pylint: disable=unused-variable + trace = pm.sample(100, tune=100) + + inference_data = from_pymc3(trace=trace) + test_dict = {"posterior": ["beta"], "observed_data": ["obs"], "constant_data": ["x"]} + fails = check_multiple_attrs(test_dict, inference_data) + assert not fails + + with pm.Model(): + x = pm.Data("x", [1.0, 2.0]) + y = pm.Data("y", [1.0, 2.0]) + beta = pm.Normal("beta", 0, 1) + obs = pm.Normal("obs", x * beta, 1, observed=y) # pylint: disable=unused-variable + predictive_trace = pm.sample_posterior_predictive(trace) + assert set(predictive_trace.keys()) == {"obs"} + # this should be four chains of 100 samples + # assert predictive_trace["obs"].shape == (400, 2) + # but the shape seems to vary between pymc3 versions + inference_data = from_pymc3_predictions(predictive_trace, posterior_trace=trace) + test_dict = {"posterior": ["beta"], "observed_data": ["obs"]} + fails = check_multiple_attrs(test_dict, inference_data) + assert not fails, "Posterior data not copied over as expected." + test_dict = {"predictions": ["obs"]} + fails = check_multiple_attrs(test_dict, inference_data) + assert not fails, "Predictions not instantiated as expected." + test_dict = {"predictions_constant_data": ["x"]} + fails = check_multiple_attrs(test_dict, inference_data) + assert not fails, "Predictions constant data not instantiated as expected." 
+ def test_no_trace(self): with pm.Model(): x = pm.Data("x", [1.0, 2.0, 3.0]) From 8b2b091b10490fa291df2154968f0d3a89b9b724 Mon Sep 17 00:00:00 2001 From: Osvaldo Martin Date: Mon, 20 Jan 2020 09:49:12 -0300 Subject: [PATCH 05/17] Violinplot: fix histogram, add rug (#997) * fix histogram, add rug * add rug to bokeh * remove redundant line, make bokeh plot looks closer to matplotlib, fix scale jitter --- arviz/plots/backends/bokeh/violinplot.py | 58 ++++++++++++++----- arviz/plots/backends/matplotlib/violinplot.py | 41 +++++++++---- arviz/plots/violinplot.py | 35 +++++++++-- 3 files changed, 100 insertions(+), 34 deletions(-) diff --git a/arviz/plots/backends/bokeh/violinplot.py b/arviz/plots/backends/bokeh/violinplot.py index 3dcf8d4c05..5f97b2284f 100644 --- a/arviz/plots/backends/bokeh/violinplot.py +++ b/arviz/plots/backends/bokeh/violinplot.py @@ -17,9 +17,12 @@ def plot_violin( figsize, rows, cols, + sharex, sharey, - kwargs_shade, + shade_kwargs, shade, + rug, + rug_kwargs, bw, credible_interval, linewidth, @@ -40,6 +43,7 @@ def plot_violin( len(plotters), rows, cols, + sharex=sharex, sharey=sharey, figsize=figsize, squeeze=False, @@ -54,17 +58,30 @@ def plot_violin( ): val = x.flatten() if val[0].dtype.kind == "i": - cat_hist(val, shade, ax_, **kwargs_shade) + dens = cat_hist(val, rug, shade, ax_, **shade_kwargs) else: - _violinplot(val, shade, bw, ax_, **kwargs_shade) + dens = _violinplot(val, rug, shade, bw, ax_, **shade_kwargs) + + if rug: + rug_x = -np.abs(np.random.normal(scale=max(dens) / 3.5, size=len(val))) + ax_.scatter(rug_x, val, **rug_kwargs) per = np.percentile(val, [25, 75, 50]) hpd_intervals = hpd(val, credible_interval, multimodal=False) if quartiles: - ax_.line([0, 0], per[:2], line_width=linewidth * 3, line_color="black") - ax_.line([0, 0], hpd_intervals, line_width=linewidth, line_color="black") - ax_.circle(0, per[-1]) + ax_.line( + [0, 0], per[:2], line_width=linewidth * 3, line_color="black", line_cap="round" + ) + ax_.line([0, 0], hpd_intervals, line_width=linewidth, line_color="black", line_cap="round") + ax_.circle( + 0, + per[-1], + line_color="white", + fill_color="white", + size=linewidth * 1.5, + line_width=linewidth, + ) _title = Title() _title.text = make_label(var_name, selection) @@ -80,35 +97,44 @@ def plot_violin( return ax -def _violinplot(val, shade, bw, ax, **kwargs_shade): +def _violinplot(val, rug, shade, bw, ax, **shade_kwargs): """Auxiliary function to plot violinplots.""" density, low_b, up_b = _fast_kde(val, bw=bw) x = np.linspace(low_b, up_b, len(density)) - x = np.concatenate([x, x[::-1]]) - density = np.concatenate([-density, density[::-1]]) + if not rug: + x = np.concatenate([x, x[::-1]]) + density = np.concatenate([-density, density[::-1]]) + + ax.harea(y=x, x1=density, x2=np.zeros_like(density), fill_alpha=shade, **shade_kwargs) - ax.patch(density, x, fill_alpha=shade, line_width=0, **kwargs_shade) + return density -def cat_hist(val, shade, ax, **kwargs_shade): +def cat_hist(val, rug, shade, ax, **shade_kwargs): """Auxiliary function to plot discrete-violinplots.""" bins = get_bins(val) _, binned_d, _ = histogram(val, bins=bins) bin_edges = np.linspace(np.min(val), np.max(val), len(bins)) - centers = 0.5 * (bin_edges + np.roll(bin_edges, 1))[:-1] heights = np.diff(bin_edges) + centers = bin_edges[:-1] + heights.mean() / 2 + right = 0.5 * binned_d - lefts = -0.5 * binned_d + if rug: + left = 0 + else: + left = -right ax.hbar( y=centers, - left=lefts, - right=-lefts, + left=left, + right=right, height=heights, fill_alpha=shade, 
line_alpha=shade, line_color=None, - **kwargs_shade + **shade_kwargs ) + + return binned_d diff --git a/arviz/plots/backends/matplotlib/violinplot.py b/arviz/plots/backends/matplotlib/violinplot.py index 4556207fcc..9c533103c6 100644 --- a/arviz/plots/backends/matplotlib/violinplot.py +++ b/arviz/plots/backends/matplotlib/violinplot.py @@ -14,9 +14,12 @@ def plot_violin( figsize, rows, cols, + sharex, sharey, - kwargs_shade, + shade_kwargs, shade, + rug, + rug_kwargs, bw, credible_interval, linewidth, @@ -28,24 +31,31 @@ def plot_violin( ): """Matplotlib violin plot.""" if ax is None: - _, ax = _create_axes_grid( + fig, ax = _create_axes_grid( len(plotters), rows, cols, + sharex=sharex, sharey=sharey, figsize=figsize, squeeze=False, backend_kwargs=backend_kwargs, ) + fig.set_constrained_layout(False) + fig.subplots_adjust(wspace=0) ax = np.atleast_1d(ax) for (var_name, selection, x), ax_ in zip(plotters, ax.flatten()): val = x.flatten() if val[0].dtype.kind == "i": - cat_hist(val, shade, ax_, **kwargs_shade) + dens = cat_hist(val, rug, shade, ax_, **shade_kwargs) else: - _violinplot(val, shade, bw, ax_, **kwargs_shade) + dens = _violinplot(val, rug, shade, bw, ax_, **shade_kwargs) + + if rug: + rug_x = -np.abs(np.random.normal(scale=max(dens) / 3.5, size=len(val))) + ax_.plot(rug_x, val, **rug_kwargs) per = np.percentile(val, [25, 75, 50]) hpd_intervals = hpd(val, credible_interval, multimodal=False) @@ -66,25 +76,32 @@ def plot_violin( return ax -def _violinplot(val, shade, bw, ax, **kwargs_shade): +def _violinplot(val, rug, shade, bw, ax, **shade_kwargs): """Auxiliary function to plot violinplots.""" density, low_b, up_b = _fast_kde(val, bw=bw) x = np.linspace(low_b, up_b, len(density)) - x = np.concatenate([x, x[::-1]]) - density = np.concatenate([-density, density[::-1]]) + if not rug: + x = np.concatenate([x, x[::-1]]) + density = np.concatenate([-density, density[::-1]]) - ax.fill_betweenx(x, density, alpha=shade, lw=0, **kwargs_shade) + ax.fill_betweenx(x, density, alpha=shade, lw=0, **shade_kwargs) + return density -def cat_hist(val, shade, ax, **kwargs_shade): +def cat_hist(val, rug, shade, ax, **shade_kwargs): """Auxiliary function to plot discrete-violinplots.""" bins = get_bins(val) _, binned_d, _ = histogram(val, bins=bins) bin_edges = np.linspace(np.min(val), np.max(val), len(bins)) - centers = 0.5 * (bin_edges + np.roll(bin_edges, 1))[:-1] heights = np.diff(bin_edges) + centers = bin_edges[:-1] + heights.mean() / 2 + + if rug: + left = None + else: + left = -0.5 * binned_d - lefts = -0.5 * binned_d - ax.barh(centers, binned_d, height=heights, left=lefts, alpha=shade, **kwargs_shade) + ax.barh(centers, binned_d, height=heights, left=left, alpha=shade, **shade_kwargs) + return binned_d diff --git a/arviz/plots/violinplot.py b/arviz/plots/violinplot.py index c6eb7e7f29..2f91d7f932 100644 --- a/arviz/plots/violinplot.py +++ b/arviz/plots/violinplot.py @@ -14,14 +14,17 @@ def plot_violin( data, var_names=None, quartiles=True, + rug=False, credible_interval=0.94, shade=0.35, bw=4.5, + sharex=True, sharey=True, figsize=None, textsize=None, ax=None, - kwargs_shade=None, + shade_kwargs=None, + rug_kwargs=None, backend=None, backend_kwargs=None, show=None, @@ -42,6 +45,8 @@ def plot_violin( quartiles : bool, optional Flag for plotting the interquartile range, in addition to the credible_interval*100% intervals. Defaults to True + rug : bool + If True adds a jittered rugplot. Defaults to False. credible_interval : float, optional Credible intervals. Defaults to 0.94. 
shade : float Alpha blending value for the shaded area under the curve, between 0 (no shade) and 1 (opaque). Defaults to 0 bw : float Bandwidth scaling factor. Should be larger than 0. The higher this number the smoother the KDE will be. Defaults to 4.5 which is essentially the same as the Scott's rule of thumb (the default rule used by SciPy). figsize : tuple Figure size. If None it will be defined automatically. textsize: int Text size of the point_estimates, axis ticks, and HPD. If None it will be autoscaled based on figsize. + sharex : bool + Defaults to True, violinplots share a common x-axis scale. sharey : bool Defaults to True, violinplots share a common y-axis scale. ax: axes, optional Matplotlib axes or bokeh figures. - kwargs_shade : dicts, optional + shade_kwargs : dicts, optional Additional keywords passed to `fill_between`, or `barh` to control the shade. + rug_kwargs : dict + Keywords passed to the rug plot. If True only the right half side of the violin will be + plotted. backend: str, optional Select plotting backend {"matplotlib","bokeh"}. Default "matplotlib". backend_kwargs: bool, optional @@ -81,15 +91,17 @@ def plot_violin( list(xarray_var_iter(data, var_names=var_names, combined=True)), "plot_violin" ) - if kwargs_shade is None: - kwargs_shade = {} + if shade_kwargs is None: + shade_kwargs = {} rows, cols = default_grid(len(plotters)) (figsize, ax_labelsize, _, xt_labelsize, linewidth, _) = _scale_fig_size( figsize, textsize, rows, cols ) - ax_labelsize *= 2 + + if rug_kwargs is None: + rug_kwargs = {} violinplot_kwargs = dict( ax=ax, @@ -97,9 +109,12 @@ def plot_violin( figsize=figsize, rows=rows, cols=cols, + sharex=sharex, sharey=sharey, - kwargs_shade=kwargs_shade, + shade_kwargs=shade_kwargs, shade=shade, + rug=rug, + rug_kwargs=rug_kwargs, bw=bw, credible_interval=credible_interval, linewidth=linewidth, @@ -115,6 +130,14 @@ def plot_violin( violinplot_kwargs.pop("ax_labelsize") violinplot_kwargs.pop("xt_labelsize") + rug_kwargs.setdefault("fill_alpha", 0.1) + rug_kwargs.setdefault("line_alpha", 0.1) + else: + rug_kwargs.setdefault("alpha", 0.1) + rug_kwargs.setdefault("marker", ".") + rug_kwargs.setdefault("linestyle", "") + # TODO: Add backend kwargs plot = get_plotting_function("plot_violin", "violinplot", backend) ax = plot(**violinplot_kwargs) From 4d3722c92e8ac2384591767ec94f1a5ea69d2a0b Mon Sep 17 00:00:00 2001 From: Osvaldo Martin Date: Mon, 20 Jan 2020 09:50:48 -0300 Subject: [PATCH 06/17] Update CHANGELOG.md --- CHANGELOG.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f396080571..1b1d8c6bfd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,13 +4,15 @@ ### New features * Add out-of-sample predictions (`predictions` and `predictions_constant_data` groups) to pymc3 translations. (#983) +* Violinplot: rug-plot option (#997) ### Maintenance and fixes * Fixed bug in extracting prior samples for cmdstanpy.
(#979) * Fix erroneous warning in traceplot (#989) * Correct bfmi denominator (#991) * Removed parallel from jit full (#996) -* Rename flat_inference_data_to_dict (#1003) +* Rename flat_inference_data_to_dict (#1003) +* Violinplot: fix histogram (#997) ### Deprecation From c3836bb6b7a79a19a8d5f9bf29b79ee2ffdfba3b Mon Sep 17 00:00:00 2001 From: amukh18 <45681148+amukh18@users.noreply.github.com> Date: Tue, 21 Jan 2020 02:41:25 +0530 Subject: [PATCH 07/17] Add group argument to plot_joint (#1012) --- arviz/plots/jointplot.py | 5 ++++- arviz/plots/ppcplot.py | 2 +- examples/matplotlib/mpl_plot_ppc.py | 2 +- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/arviz/plots/jointplot.py b/arviz/plots/jointplot.py index b73f350b41..86ae989eb8 100644 --- a/arviz/plots/jointplot.py +++ b/arviz/plots/jointplot.py @@ -6,6 +6,7 @@ def plot_joint( data, + group="posterior", var_names=None, coords=None, figsize=None, @@ -29,6 +30,8 @@ def plot_joint( data : obj Any object that can be converted to an az.InferenceData object Refer to documentation of az.convert_to_dataset for details + group : str, optional + Specifies which InferenceData group should be plotted. Defaults to ‘posterior’. var_names : str or iterable of str Variables to be plotted. iter of two variables or one variable (with subset having exactly 2 dimensions) are required. @@ -131,7 +134,7 @@ def plot_joint( ("Plot type {} not recognized." "Plot type must be in {}").format(kind, valid_kinds) ) - data = convert_to_dataset(data, group="posterior") + data = convert_to_dataset(data, group=group) if coords is None: coords = {} diff --git a/arviz/plots/ppcplot.py b/arviz/plots/ppcplot.py index 23ef64c668..ef6f4a3f16 100644 --- a/arviz/plots/ppcplot.py +++ b/arviz/plots/ppcplot.py @@ -127,7 +127,7 @@ def plot_ppc( >>> import arviz as az >>> data = az.load_arviz_data('radon') - >>> az.plot_ppc(data) + >>> az.plot_ppc(data,data_pairs={"obs":"obs"}) Plot the overlay with empirical CDFs. diff --git a/examples/matplotlib/mpl_plot_ppc.py b/examples/matplotlib/mpl_plot_ppc.py index 53ff435faa..8f3ed492eb 100644 --- a/examples/matplotlib/mpl_plot_ppc.py +++ b/examples/matplotlib/mpl_plot_ppc.py @@ -9,4 +9,4 @@ az.style.use("arviz-darkgrid") data = az.load_arviz_data("non_centered_eight") -az.plot_ppc(data, alpha=0.03, figsize=(12, 6), textsize=14) +az.plot_ppc(data, data_pairs={"obs": "obs"}, alpha=0.03, figsize=(12, 6), textsize=14) From cdbb49173779c28249bf2b7f2828db3d33679f54 Mon Sep 17 00:00:00 2001 From: Shashank jain Date: Tue, 21 Jan 2020 02:43:16 +0530 Subject: [PATCH 08/17] Integrate stats.ic_scale rcParam (#993) --- arviz/plots/elpdplot.py | 3 +- arviz/stats/stats.py | 28 +++++++-------- arviz/tests/test_plots_bokeh.py | 44 ++++++++--------------- examples/bokeh/bokeh_plot_dist.py | 4 +-- examples/bokeh/bokeh_plot_loo_pit_ecdf.py | 2 +- examples/bokeh/bokeh_plot_pair.py | 2 +- 6 files changed, 31 insertions(+), 52 deletions(-) diff --git a/arviz/plots/elpdplot.py b/arviz/plots/elpdplot.py index 49a0653d9e..0a3239ecf2 100644 --- a/arviz/plots/elpdplot.py +++ b/arviz/plots/elpdplot.py @@ -21,7 +21,7 @@ def plot_elpd( threshold=None, ax=None, ic=None, - scale="deviance", + scale=None, plot_kwargs=None, backend=None, backend_kwargs=None, @@ -109,6 +109,7 @@ def plot_elpd( """ valid_ics = ["waic", "loo"] ic = rcParams["stats.information_criterion"] if ic is None else ic.lower() + scale = rcParams["stats.ic_scale"] if scale is None else scale.lower() if ic not in valid_ics: raise ValueError( ("Information Criteria type {} not recognized." 
"IC must be in {}").format( diff --git a/arviz/stats/stats.py b/arviz/stats/stats.py index 68b98e4901..6281843ed4 100644 --- a/arviz/stats/stats.py +++ b/arviz/stats/stats.py @@ -42,13 +42,7 @@ def compare( - dataset_dict, - ic=None, - method="BB-pseudo-BMA", - b_samples=1000, - alpha=1, - seed=None, - scale="deviance", + dataset_dict, ic=None, method="BB-pseudo-BMA", b_samples=1000, alpha=1, seed=None, scale=None ): r"""Compare models based on WAIC or LOO cross-validation. @@ -136,7 +130,7 @@ def compare( """ names = list(dataset_dict.keys()) - scale = scale.lower() + scale = rcParams["stats.ic_scale"] if scale is None else scale.lower() if scale == "log": scale_value = 1 ascending = False @@ -421,7 +415,7 @@ def hpd(ary, credible_interval=0.94, circular=False, multimodal=False): return hpd_intervals -def loo(data, pointwise=False, reff=None, scale="deviance"): +def loo(data, pointwise=False, reff=None, scale=None): """Pareto-smoothed importance sampling leave-one-out cross-validation. Calculates leave-one-out (LOO) cross-validation for out of sample predictive model fit, @@ -493,12 +487,13 @@ def loo(data, pointwise=False, reff=None, scale="deviance"): shape = log_likelihood.shape n_samples = shape[-1] n_data_points = np.product(shape[:-1]) + scale = rcParams["stats.ic_scale"] if scale is None else scale.lower() - if scale.lower() == "deviance": + if scale == "deviance": scale_value = -2 - elif scale.lower() == "log": + elif scale == "log": scale_value = 1 - elif scale.lower() == "negative_log": + elif scale == "negative_log": scale_value = -1 else: raise TypeError('Valid scale values are "deviance", "log", "negative_log"') @@ -1101,7 +1096,7 @@ def summary( return summary_df -def waic(data, pointwise=False, scale="deviance"): +def waic(data, pointwise=False, scale=None): """Calculate the widely available information criterion. 
Also calculates the WAIC's standard error and the effective number of @@ -1165,12 +1160,13 @@ def waic(data, pointwise=False, scale="deviance"): if "log_likelihood" not in inference_data.sample_stats: raise TypeError("Data must include log_likelihood in sample_stats") log_likelihood = inference_data.sample_stats.log_likelihood + scale = rcParams["stats.ic_scale"] if scale is None else scale.lower() - if scale.lower() == "deviance": + if scale == "deviance": scale_value = -2 - elif scale.lower() == "log": + elif scale == "log": scale_value = 1 - elif scale.lower() == "negative_log": + elif scale == "negative_log": scale_value = -1 else: raise TypeError('Valid scale values are "deviance", "log", "negative_log"') diff --git a/arviz/tests/test_plots_bokeh.py b/arviz/tests/test_plots_bokeh.py index e484911b12..4fe3f4cf24 100644 --- a/arviz/tests/test_plots_bokeh.py +++ b/arviz/tests/test_plots_bokeh.py @@ -189,7 +189,7 @@ def test_plot_kde_1d(continuous_model): "kwargs", [ {"contour": True, "fill_last": False}, - {"contour": True, "contourf_kwargs": {"cmap": "plasma"},}, + {"contour": True, "contourf_kwargs": {"cmap": "plasma"}}, {"contour": False}, {"contour": False, "pcolormesh_kwargs": {"cmap": "plasma"}}, ], @@ -275,9 +275,7 @@ def test_plot_compare_no_ic(models): assert "['waic', 'loo']" in str(err.value) -@pytest.mark.parametrize( - "kwargs", [{}, {"ic": "loo"}, {"xlabels": True, "scale": "log"},], -) +@pytest.mark.parametrize("kwargs", [{}, {"ic": "loo"}, {"xlabels": True, "scale": "log"}]) @pytest.mark.parametrize("add_model", [False, True]) @pytest.mark.parametrize("use_elpddata", [False, True]) def test_plot_elpd(models, add_model, use_elpddata, kwargs): @@ -300,9 +298,7 @@ def test_plot_elpd(models, add_model, use_elpddata, kwargs): assert axes.shape[0] == len(model_dict) - 1 -@pytest.mark.parametrize( - "kwargs", [{}, {"ic": "loo"}, {"xlabels": True, "scale": "log"},], -) +@pytest.mark.parametrize("kwargs", [{}, {"ic": "loo"}, {"xlabels": True, "scale": "log"}]) @pytest.mark.parametrize("add_model", [False, True]) @pytest.mark.parametrize("use_elpddata", [False, True]) def test_plot_elpd_multidim(multidim_models, add_model, use_elpddata, kwargs): @@ -443,9 +439,7 @@ def test_plot_forest(models, model_fits, args_expected): def test_plot_forest_rope_exception(): with pytest.raises(ValueError) as err: - plot_forest( - {"x": [1]}, rope="not_correct_format", backend="bokeh", show=False, - ) + plot_forest({"x": [1]}, rope="not_correct_format", backend="bokeh", show=False) assert "Argument `rope` must be None, a dictionary like" in str(err.value) @@ -508,19 +502,15 @@ def test_plot_joint_discrete(discrete_model): def test_plot_joint_bad(models): with pytest.raises(ValueError): plot_joint( - models.model_1, var_names=("mu", "tau"), kind="bad_kind", backend="bokeh", show=False, + models.model_1, var_names=("mu", "tau"), kind="bad_kind", backend="bokeh", show=False ) with pytest.raises(Exception): - plot_joint( - models.model_1, var_names=("mu", "tau", "eta"), backend="bokeh", show=False, - ) + plot_joint(models.model_1, var_names=("mu", "tau", "eta"), backend="bokeh", show=False) with pytest.raises(ValueError): _, axes = list(range(5)) - plot_joint( - models.model_1, var_names=("mu", "tau"), ax=axes, backend="bokeh", show=False, - ) + plot_joint(models.model_1, var_names=("mu", "tau"), ax=axes, backend="bokeh", show=False) @pytest.mark.parametrize( @@ -614,7 +604,7 @@ def test_plot_loo_pit_incompatible_args(models): """Test error when both ecdf and use_hpd are True.""" with 
pytest.raises(ValueError, match="incompatible"): plot_loo_pit( - idata=models.model_1, y="y", ecdf=True, use_hpd=True, backend="bokeh", show=False, + idata=models.model_1, y="y", ecdf=True, use_hpd=True, backend="bokeh", show=False ) @@ -694,8 +684,8 @@ def test_plot_mcse_no_divergences(models): @pytest.mark.parametrize( "kwargs", [ - {"var_names": "theta", "divergences": True, "coords": {"theta_dim_0": [0, 1]},}, - {"divergences": True, "var_names": ["theta", "mu"],}, + {"var_names": "theta", "divergences": True, "coords": {"theta_dim_0": [0, 1]}}, + {"divergences": True, "var_names": ["theta", "mu"]}, {"kind": "kde", "var_names": ["theta"]}, {"kind": "hexbin", "var_names": ["theta"]}, {"kind": "hexbin", "var_names": ["theta"]}, @@ -760,7 +750,7 @@ def test_plot_parallel_exception(models, var_names): """Ensure that correct exception is raised when one variable is passed.""" with pytest.raises(ValueError): assert plot_parallel( - models.model_1, var_names=var_names, norm_method="foo", backend="bokeh", show=False, + models.model_1, var_names=var_names, norm_method="foo", backend="bokeh", show=False ) @@ -860,9 +850,7 @@ def test_plot_ppc_bad(models, kind): with pytest.raises(TypeError): plot_ppc(models.model_1, kind="bad_val", backend="bokeh", show=False) with pytest.raises(TypeError): - plot_ppc( - models.model_1, num_pp_samples="bad_val", backend="bokeh", show=False, - ) + plot_ppc(models.model_1, num_pp_samples="bad_val", backend="bokeh", show=False) @pytest.mark.parametrize("kind", ["kde", "cumulative", "scatter"]) @@ -913,13 +901,9 @@ def test_plot_posterior_bad(models): with pytest.raises(ValueError): plot_posterior(models.model_1, backend="bokeh", show=False, rope="bad_value") with pytest.raises(ValueError): - plot_posterior( - models.model_1, ref_val="bad_value", backend="bokeh", show=False, - ) + plot_posterior(models.model_1, ref_val="bad_value", backend="bokeh", show=False) with pytest.raises(ValueError): - plot_posterior( - models.model_1, point_estimate="bad_value", backend="bokeh", show=False, - ) + plot_posterior(models.model_1, point_estimate="bad_value", backend="bokeh", show=False) @pytest.mark.parametrize("point_estimate", ("mode", "mean", "median")) diff --git a/examples/bokeh/bokeh_plot_dist.py b/examples/bokeh/bokeh_plot_dist.py index 8725bd0690..52484f9275 100644 --- a/examples/bokeh/bokeh_plot_dist.py +++ b/examples/bokeh/bokeh_plot_dist.py @@ -16,9 +16,7 @@ ax_poisson = bkp.figure(**figure_kwargs) ax_normal = bkp.figure(**figure_kwargs) -az.plot_dist( - a, color="black", label="Poisson", ax=ax_poisson, backend="bokeh", show=False, -) +az.plot_dist(a, color="black", label="Poisson", ax=ax_poisson, backend="bokeh", show=False) az.plot_dist(b, color="red", label="Gaussian", ax=ax_normal, backend="bokeh", show=False) ax = row(ax_poisson, ax_normal) diff --git a/examples/bokeh/bokeh_plot_loo_pit_ecdf.py b/examples/bokeh/bokeh_plot_loo_pit_ecdf.py index 1bfa8dd635..9f9429e6b4 100644 --- a/examples/bokeh/bokeh_plot_loo_pit_ecdf.py +++ b/examples/bokeh/bokeh_plot_loo_pit_ecdf.py @@ -11,5 +11,5 @@ log_weights = az.psislw(-log_like)[0] ax = az.plot_loo_pit( - idata, y="y_like", log_weights=log_weights, ecdf=True, color="orange", backend="bokeh", + idata, y="y_like", log_weights=log_weights, ecdf=True, color="orange", backend="bokeh" ) diff --git a/examples/bokeh/bokeh_plot_pair.py b/examples/bokeh/bokeh_plot_pair.py index 16878bd0b8..0663e02819 100644 --- a/examples/bokeh/bokeh_plot_pair.py +++ b/examples/bokeh/bokeh_plot_pair.py @@ -10,5 +10,5 @@ coords = {"school": 
["Choate", "Deerfield"]} ax = az.plot_pair( - centered, var_names=["theta", "mu", "tau"], coords=coords, divergences=True, backend="bokeh", + centered, var_names=["theta", "mu", "tau"], coords=coords, divergences=True, backend="bokeh" ) From f38487bae12982e5e4cfda0ff0e5c77ee32fbfe8 Mon Sep 17 00:00:00 2001 From: Oriol Abril Date: Mon, 20 Jan 2020 23:04:37 +0100 Subject: [PATCH 09/17] Update changelog with some untracked changes (#1015) --- CHANGELOG.md | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1b1d8c6bfd..92f7723a79 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,9 +5,11 @@ ### New features * Add out-of-sample predictions (`predictions` and `predictions_constant_data` groups) to pymc3 translations. (#983) * Violinplot: rug-plot option (#997) +* Integrated rcParams `plot.point_estimate` (#994) and `stats.ic_scale` (#993) +* Added `group` argument to `plot_ppc` (#1008), `plot_pair` (#1009) and `plot_joint` (#1012) ### Maintenance and fixes -* Fixed bug in extracting prior samples for cmdstanpy. (#979) +* Fixed bug in extracting prior samples for cmdstanpy. (#979) * Fix erroneous warning in traceplot (#989) * Correct bfmi denominator (#991) * Removed parallel from jit full (#996) @@ -17,7 +19,7 @@ ### Deprecation ### Documentation -* Clarify the usage of "plot_joint" (#1001) +* Clarify the usage of "plot_joint" (#1001) ## v0.6.1 (2019 Dec 28) @@ -134,7 +136,7 @@ * And exception to plot compare ([#461](https://github.com/arviz-devs/arviz/pull/461)) * Add Docker Testing to travisCI ([#473](https://github.com/arviz-devs/arviz/pull/473)) * fix jointplot warning ([#478](https://github.com/arviz-devs/arviz/pull/478)) -* Fix tensorflow import bug ([#489](https://github.com/arviz-devs/arviz/pull/489)) +* Fix tensorflow import bug ([#489](https://github.com/arviz-devs/arviz/pull/489)) * Rename N_effective to S_effective ([#505](https://github.com/arviz-devs/arviz/pull/505)) From 2db39c233bcd267c5bbb6c93eeb47ef8c6bbbe0e Mon Sep 17 00:00:00 2001 From: amukh18 <45681148+amukh18@users.noreply.github.com> Date: Tue, 21 Jan 2020 23:43:53 +0530 Subject: [PATCH 10/17] Add data_pairs usage example for plot_ppc (#1007) --- arviz/plots/ppcplot.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/arviz/plots/ppcplot.py b/arviz/plots/ppcplot.py index ef6f4a3f16..8e3fbfedce 100644 --- a/arviz/plots/ppcplot.py +++ b/arviz/plots/ppcplot.py @@ -62,8 +62,8 @@ def plot_ppc( data_pairs : dict Dictionary containing relations between observed data and posterior/prior predictive data. Dictionary structure: - Key = data var_name - Value = posterior/prior predictive var_name + key = data var_name + value = posterior/prior predictive var_name For example, `data_pairs = {'y' : 'y_hat'}` If None, it will assume that the observed data and the posterior/prior predictive data have the same variable name. @@ -128,6 +128,7 @@ def plot_ppc( >>> import arviz as az >>> data = az.load_arviz_data('radon') >>> az.plot_ppc(data,data_pairs={"obs":"obs"}) + >>> #az.plot_ppc(data,data_pairs={"obs":"obs_hat"}) Plot the overlay with empirical CDFs. 
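A minimal sketch of the ``data_pairs`` mapping documented in PATCH 10 above, using synthetic draws and illustrative variable names ("y" for the observed data, "y_hat" for the predictive samples; neither name comes from the patches):

    import arviz as az
    import numpy as np

    # toy InferenceData where the observed variable ("y") and the
    # posterior predictive variable ("y_hat") carry different names
    idata = az.from_dict(
        posterior={"mu": np.random.randn(4, 500)},
        posterior_predictive={"y_hat": np.random.randn(4, 500, 8)},
        observed_data={"y": np.random.randn(8)},
    )

    # data_pairs maps observed names to predictive names
    az.plot_ppc(idata, data_pairs={"y": "y_hat"})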
From 08f36a41436336fb9ddbc4d86e594d8f308b2520 Mon Sep 17 00:00:00 2001 From: percygautam Date: Wed, 22 Jan 2020 00:46:58 +0530 Subject: [PATCH 11/17] changes --- doc/sphinxext/gallery_generator.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/doc/sphinxext/gallery_generator.py b/doc/sphinxext/gallery_generator.py index a5b9f8c017..ab595cb585 100644 --- a/doc/sphinxext/gallery_generator.py +++ b/doc/sphinxext/gallery_generator.py @@ -240,13 +240,10 @@ def thumbfilename(self): @property def apiname(self): - name="" with open(op.join(self.target_dir, self.pyfilename), "r") as file: regex = r"az\.(plot\_[a-z_]+)\(" - matches = re.finditer(regex, file.read(), re.MULTILINE) - for matchNum, match in enumerate(matches, start=1): - name = match.group(1) - return name + name = re.findall(regex, file.read()) + return name[0] if len(name) > 0 else None @property def sphinxtag(self): From 23fd65d127aa35ef8b117544e5b07ee82ccaa7b5 Mon Sep 17 00:00:00 2001 From: percygautam Date: Wed, 22 Jan 2020 00:58:48 +0530 Subject: [PATCH 12/17] added the change in changlog.md --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 92f7723a79..a6f00072b0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,6 +20,7 @@ ### Documentation * Clarify the usage of "plot_joint" (#1001) +* Show API link of the function in examples (#1013) ## v0.6.1 (2019 Dec 28) From dc25a7b399433be556b31aa86c43efea739b8f3d Mon Sep 17 00:00:00 2001 From: percygautam Date: Wed, 22 Jan 2020 19:28:24 +0530 Subject: [PATCH 13/17] minor fixes --- doc/sphinxext/gallery_generator.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/doc/sphinxext/gallery_generator.py b/doc/sphinxext/gallery_generator.py index ab595cb585..04e5cd0bd2 100644 --- a/doc/sphinxext/gallery_generator.py +++ b/doc/sphinxext/gallery_generator.py @@ -40,7 +40,7 @@ def execfile(filename, globals=None, locals=None): .. image:: {img_file} **Python source code:** :download:`[download source: {fname}]<{fname}>` -**API documentation:** `{api_name} <../../generated/arviz.{api_name}>`_ +**API documentation:** `{api_name} <{api_link}>`_ .. literalinclude:: {fname} :lines: {end_line}- @@ -55,7 +55,7 @@ def execfile(filename, globals=None, locals=None): :source-position: none **Python source code:** :download:`[download source: {fname}]<{fname}>` -**API documentation:** `{api_name} <../../generated/arviz.{api_name}>`_ +**API documentation:** `{api_name} <{api_link}>`_ .. 
literalinclude:: {fname} :lines: {end_line}- @@ -239,12 +239,16 @@ def thumbfilename(self): return pngfile @property - def apiname(self): + def apitext(self): with open(op.join(self.target_dir, self.pyfilename), "r") as file: regex = r"az\.(plot\_[a-z_]+)\(" name = re.findall(regex, file.read()) return name[0] if len(name) > 0 else None + @property + def apilink(self): + return " ../../generated/arviz."+self.apitext if self.apitext else None + @property def sphinxtag(self): return self.modulename @@ -389,7 +393,8 @@ def main(app): fname=ex.pyfilename, absfname=op.join(target_dir, ex.pyfilename), img_file=ex.pngfilename, - api_name=ex.apiname, + api_name=ex.apitext, + api_link=ex.apilink, ) with open(op.join(target_dir, ex.rstfilename), "w") as f: f.write(output) From d09f07c1dc5e2f9ae68986e45b4b3cce5937476d Mon Sep 17 00:00:00 2001 From: percygautam Date: Wed, 22 Jan 2020 19:37:23 +0530 Subject: [PATCH 14/17] linting fixes --- doc/sphinxext/gallery_generator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/sphinxext/gallery_generator.py b/doc/sphinxext/gallery_generator.py index 04e5cd0bd2..89f02105e2 100644 --- a/doc/sphinxext/gallery_generator.py +++ b/doc/sphinxext/gallery_generator.py @@ -247,7 +247,7 @@ def apitext(self): @property def apilink(self): - return " ../../generated/arviz."+self.apitext if self.apitext else None + return " ../../generated/arviz." + self.apitext if self.apitext else None @property def sphinxtag(self): From c013176a3d1c0f253f73f90adc4f7a6d54b311d5 Mon Sep 17 00:00:00 2001 From: percygautam Date: Wed, 22 Jan 2020 20:14:23 +0530 Subject: [PATCH 15/17] Corrected api text format --- doc/sphinxext/gallery_generator.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/doc/sphinxext/gallery_generator.py b/doc/sphinxext/gallery_generator.py index 89f02105e2..c9be99f7d7 100644 --- a/doc/sphinxext/gallery_generator.py +++ b/doc/sphinxext/gallery_generator.py @@ -40,7 +40,7 @@ def execfile(filename, globals=None, locals=None): .. image:: {img_file} **Python source code:** :download:`[download source: {fname}]<{fname}>` -**API documentation:** `{api_name} <{api_link}>`_ +**API documentation:** {api_name} .. literalinclude:: {fname} :lines: {end_line}- @@ -55,7 +55,7 @@ def execfile(filename, globals=None, locals=None): :source-position: none **Python source code:** :download:`[download source: {fname}]<{fname}>` -**API documentation:** `{api_name} <{api_link}>`_ +**API documentation:** {api_name} .. literalinclude:: {fname} :lines: {end_line}- @@ -239,15 +239,16 @@ def thumbfilename(self): return pngfile @property - def apitext(self): + def apiname(self): with open(op.join(self.target_dir, self.pyfilename), "r") as file: regex = r"az\.(plot\_[a-z_]+)\(" name = re.findall(regex, file.read()) - return name[0] if len(name) > 0 else None - - @property - def apilink(self): - return " ../../generated/arviz." + self.apitext if self.apitext else None + apitext = name[0] if len(name) > 0 else "" + return ( + "`" + apitext + " <../../generated/arviz." 
+ apitext + ">`_" if apitext else "No API Documentation available" ) @property def sphinxtag(self): @@ -393,8 +394,7 @@ def main(app): fname=ex.pyfilename, absfname=op.join(target_dir, ex.pyfilename), img_file=ex.pngfilename, - api_name=ex.apitext, - api_link=ex.apilink, + api_name=ex.apiname, ) with open(op.join(target_dir, ex.rstfilename), "w") as f: f.write(output) From 691d8af08b1fed35f7fcf496676f66879b46cc32 Mon Sep 17 00:00:00 2001 From: percygautam Date: Sat, 25 Jan 2020 00:42:11 +0530 Subject: [PATCH 16/17] final fixes --- doc/sphinxext/gallery_generator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/sphinxext/gallery_generator.py b/doc/sphinxext/gallery_generator.py index c9be99f7d7..ba14797c25 100644 --- a/doc/sphinxext/gallery_generator.py +++ b/doc/sphinxext/gallery_generator.py @@ -245,7 +245,7 @@ def apiname(self): name = re.findall(regex, file.read()) apitext = name[0] if len(name) > 0 else "" return ( - "`" + apitext + " <../../generated/arviz." + apitext + ">`_" + "`{apitext} <../../generated/arviz.{apitext}>`_".format(apitext=apitext) if apitext else "No API Documentation available" ) From a2b9bfd43aa82fb1fbce3a5a1c4a7b0d61ed8978 Mon Sep 17 00:00:00 2001 From: percygautam Date: Sat, 25 Jan 2020 00:44:44 +0530 Subject: [PATCH 17/17] final fixes --- doc/sphinxext/gallery_generator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/sphinxext/gallery_generator.py b/doc/sphinxext/gallery_generator.py index ba14797c25..0edfbe17fe 100644 --- a/doc/sphinxext/gallery_generator.py +++ b/doc/sphinxext/gallery_generator.py @@ -243,7 +243,7 @@ def apiname(self): with open(op.join(self.target_dir, self.pyfilename), "r") as file: regex = r"az\.(plot\_[a-z_]+)\(" name = re.findall(regex, file.read()) - apitext = name[0] if len(name) > 0 else "" + apitext = name[0] if name else "" return ( "`{apitext} <../../generated/arviz.{apitext}>`_".format(apitext=apitext) if apitext
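The ``apiname`` property that PATCHES 03 and 11-17 converge on is, in the end, a single regex lookup over the example source. A standalone sketch of that behavior (the ``source`` string is invented for illustration; the regex and the fallback text are the ones from the patches):

    import re

    # the first az.plot_*( call in a gallery script decides the API link
    source = 'az.style.use("arviz-darkgrid")\naz.plot_violin(data, rug=True)\n'
    regex = r"az\.(plot\_[a-z_]+)\("
    name = re.findall(regex, source)
    apitext = name[0] if name else ""
    print(apitext or "No API Documentation available")  # prints: plot_violin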