diff --git a/.gitignore b/.gitignore index 2c8249ac36..d0111fb0bd 100644 --- a/.gitignore +++ b/.gitignore @@ -55,7 +55,8 @@ coverage.xml *.log # Sphinx documentation -docs/source/_build/ +doc/build/ +doc/source/savefig # PyBuilder target/ diff --git a/.projections.json b/.projections.json new file mode 100644 index 0000000000..fc9f1f1578 --- /dev/null +++ b/.projections.json @@ -0,0 +1,17 @@ +{ + "arviz/plots/backends/matplotlib/*.py": { + "alternate": "arviz/plots/backends/bokeh/{}.py", + "related": "arviz/plots/{}.py", + "type": "mpl" + }, + "arviz/plots/backends/bokeh/*.py": { + "alternate": "arviz/plots/backends/matplotlib/{}.py", + "related": "arviz/plots/{}.py", + "type": "bokeh" + }, + "arviz/plots/*.py": { + "alternate": "arviz/plots/backends/matplotlib/{}.py", + "related": "arviz/plots/backends/bokeh/{}.py", + "type": "base" + } +} diff --git a/CHANGELOG.md b/CHANGELOG.md index 045bb4490f..62b636b9e4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,12 +2,18 @@ ## v0.x.x Unreleased ### New features +* Added `labeller` argument to enable label customization in plots and summary ([1201](https://github.com/arviz-devs/arviz/pull/1201)) +* Added `arviz.labels` module with classes and utilities ([1201](https://github.com/arviz-devs/arviz/pull/1201)) ### Maintenance and fixes +* Enforced using coordinate values as default labels ([1201](https://github.com/arviz-devs/arviz/pull/1201)) +* Integrate `index_origin` with all the library ([1201](https://github.com/arviz-devs/arviz/pull/1201)) ### Deprecation +* Deprecated `index_origin` and `order` arguments in `az.summary` ([1201](https://github.com/arviz-devs/arviz/pull/1201)) ### Documentation +* Added "Label guide" page and API section for `arviz.labels` module ([1201](https://github.com/arviz-devs/arviz/pull/1201)) ## v0.11.2 (2021 Feb 21) ### New features @@ -15,6 +21,7 @@ * Added confidence interval band to auto-correlation plot ([1535](https://github.com/arviz-devs/arviz/pull/1535)) ### Maintenance and fixes 
+* Updated CmdStanPy converter for compatibility with versions >=0.9.68 ([1558](https://github.com/arviz-devs/arviz/pull/1558) and [1564](https://github.com/arviz-devs/arviz/pull/1564)) * Updated `from_cmdstanpy`, `from_cmdstan`, `from_numpyro` and `from_pymc3` converters to follow schema convention ([1550](https://github.com/arviz-devs/arviz/pull/1550), [1541](https://github.com/arviz-devs/arviz/pull/1541), [1525](https://github.com/arviz-devs/arviz/pull/1525) and [1555](https://github.com/arviz-devs/arviz/pull/1555)) * Fix calculation of mode as point estimate ([1552](https://github.com/arviz-devs/arviz/pull/1552)) * Remove variable name from legend in posterior predictive plot ([1559](https://github.com/arviz-devs/arviz/pull/1559)) diff --git a/arviz/data/base.py b/arviz/data/base.py index f44176ad60..bd01291859 100644 --- a/arviz/data/base.py +++ b/arviz/data/base.py @@ -17,6 +17,7 @@ import json # type: ignore from .. import __version__, utils +from ..rcparams import rcParams CoordSpec = Dict[str, List[Any]] DimSpec = Dict[str, List[str]] @@ -49,7 +50,13 @@ def wrapped(cls, *args, **kwargs): def generate_dims_coords( - shape, var_name, dims=None, coords=None, default_dims=None, skip_event_dims=None + shape, + var_name, + dims=None, + coords=None, + default_dims=None, + index_origin=None, + skip_event_dims=None, ): """Generate default dimensions and coordinates for a variable. @@ -70,6 +77,9 @@ def generate_dims_coords( when manipulating Monte Carlo traces, the ``default_dims`` would be ``["chain" , "draw"]`` which ArviZ uses as its own names for dimensions of MCMC traces. + index_origin : int, optional + Starting value of integer coordinate values. Defaults to the value in rcParam + ``data.index_origin``. 
skip_event_dims : bool, default False Returns @@ -79,6 +89,8 @@ def generate_dims_coords( dict[str] -> list[str] Default coords """ + if index_origin is None: + index_origin = rcParams["data.index_origin"] if default_dims is None: default_dims = [] if dims is None: @@ -127,19 +139,30 @@ def generate_dims_coords( dims[idx] = dim_name dim_name = dims[idx] if dim_name not in coords: - coords[dim_name] = utils.arange(dim_len) + coords[dim_name] = np.arange(index_origin, dim_len + index_origin) coords = {key: coord for key, coord in coords.items() if any(key == dim for dim in dims)} return dims, coords -def numpy_to_data_array(ary, *, var_name="data", coords=None, dims=None, skip_event_dims=None): +def numpy_to_data_array( + ary, + *, + var_name="data", + coords=None, + dims=None, + default_dims=None, + index_origin=None, + skip_event_dims=None, +): """Convert a numpy array to an xarray.DataArray. - The first two dimensions will be (chain, draw), and any remaining + By default, the first two dimensions will be (chain, draw), and any remaining dimensions will be "shape". - If the numpy array is 1d, this dimension is interpreted as draw - If the numpy array is 2d, it is interpreted as (chain, draw) - If the numpy array is 3 or more dimensions, the last dimensions are kept as shapes. + * If the numpy array is 1d, this dimension is interpreted as draw + * If the numpy array is 2d, it is interpreted as (chain, draw) + * If the numpy array is 3 or more dimensions, the last dimensions are kept as shapes. + + To modify this behaviour, use ``default_dims``. Parameters ---------- @@ -154,6 +177,11 @@ def numpy_to_data_array(ary, *, var_name="data", coords=None, dims=None, skip_ev is the name of the dimension, the values are the index values. dims : List(str) A list of coordinate names for the variable + default_dims : list of str, optional + Passed to :py:func:`generate_dims_coords`. 
Defaults to ``["chain", "draw"]``, and + an empty list is accepted + index_origin : int, optional + Passed to :py:func:`generate_dims_coords` skip_event_dims : bool Returns @@ -162,37 +190,43 @@ def numpy_to_data_array(ary, *, var_name="data", coords=None, dims=None, skip_ev Will have the same data as passed, but with coordinates and dimensions """ # manage and transform copies - default_dims = ["chain", "draw"] - ary = utils.two_de(ary) - n_chains, n_samples, *shape = ary.shape - if n_chains > n_samples: - warnings.warn( - "More chains ({n_chains}) than draws ({n_samples}). " - "Passed array should have shape (chains, draws, *shape)".format( - n_chains=n_chains, n_samples=n_samples - ), - UserWarning, - ) + if default_dims is None: + default_dims = ["chain", "draw"] + if "chain" in default_dims and "draw" in default_dims: + ary = utils.two_de(ary) + n_chains, n_samples, *_ = ary.shape + if n_chains > n_samples: + warnings.warn( + "More chains ({n_chains}) than draws ({n_samples}). " + "Passed array should have shape (chains, draws, *shape)".format( + n_chains=n_chains, n_samples=n_samples + ), + UserWarning, + ) + else: + ary = utils.one_de(ary) dims, coords = generate_dims_coords( - shape, + ary.shape[len(default_dims) :], var_name, dims=dims, coords=coords, default_dims=default_dims, + index_origin=index_origin, skip_event_dims=skip_event_dims, ) # reversed order for default dims: 'chain', 'draw' - if "draw" not in dims: + if "draw" not in dims and "draw" in default_dims: dims = ["draw"] + dims - if "chain" not in dims: + if "chain" not in dims and "chain" in default_dims: dims = ["chain"] + dims - if "chain" not in coords: - coords["chain"] = utils.arange(n_chains) - if "draw" not in coords: - coords["draw"] = utils.arange(n_samples) + index_origin = rcParams["data.index_origin"] + if "chain" not in coords and "chain" in default_dims: + coords["chain"] = np.arange(index_origin, n_chains + index_origin) + if "draw" not in coords and "draw" in default_dims: + 
coords["draw"] = np.arange(index_origin, n_samples + index_origin) # filter coords based on the dims coords = {key: xr.IndexVariable((key,), data=coords[key]) for key in dims} @@ -200,7 +234,15 @@ def numpy_to_data_array(ary, *, var_name="data", coords=None, dims=None, skip_ev def dict_to_dataset( - data, *, attrs=None, library=None, coords=None, dims=None, skip_event_dims=None + data, + *, + attrs=None, + library=None, + coords=None, + dims=None, + default_dims=None, + index_origin=None, + skip_event_dims=None, ): """Convert a dictionary of numpy arrays to an xarray.Dataset. @@ -217,6 +259,10 @@ def dict_to_dataset( dims : dict[str] -> list[str] Dimensions of each variable. The keys are variable names, values are lists of coordinates. + default_dims : list of str, optional + Passed to :py:func:`numpy_to_data_array` + index_origin : int, optional + Passed to :py:func:`numpy_to_data_array` skip_event_dims : bool If True, cut extra dims whenever present to match the shape of the data. Necessary for PPLs which have the same name in both observed data and log @@ -238,7 +284,13 @@ def dict_to_dataset( data_vars = {} for key, values in data.items(): data_vars[key] = numpy_to_data_array( - values, var_name=key, coords=coords, dims=dims.get(key), skip_event_dims=skip_event_dims + values, + var_name=key, + coords=coords, + dims=dims.get(key), + default_dims=default_dims, + index_origin=index_origin, + skip_event_dims=skip_event_dims, ) return xr.Dataset(data_vars=data_vars, attrs=make_attrs(attrs=attrs, library=library)) @@ -312,7 +364,7 @@ def wrapped(self, *args, **kwargs): return None if _inplace else out description_default = """{method_name} method is extended from xarray.Dataset methods. 
- + {description}For more info see :meth:`xarray:xarray.Dataset.{method_name}` """.format( description=description, method_name=func.__name__ # pylint: disable=no-member diff --git a/arviz/data/io_cmdstan.py b/arviz/data/io_cmdstan.py index 70143d5478..8248db81b2 100644 --- a/arviz/data/io_cmdstan.py +++ b/arviz/data/io_cmdstan.py @@ -7,11 +7,10 @@ from typing import Dict, List, Optional, Union import numpy as np -import xarray as xr from .. import utils from ..rcparams import rcParams -from .base import CoordSpec, DimSpec, dict_to_dataset, generate_dims_coords, requires +from .base import CoordSpec, DimSpec, dict_to_dataset, requires from .inference_data import InferenceData _log = logging.getLogger(__name__) @@ -51,6 +50,7 @@ def __init__( predictions_constant_data=None, predictions_constant_data_var=None, log_likelihood=None, + index_origin=None, coords=None, dims=None, disable_glob=False, @@ -80,6 +80,7 @@ def __init__( self.attrs_prior = None self.save_warmup = rcParams["data.save_warmup"] if save_warmup is None else save_warmup + self.index_origin = index_origin if dtypes is None: self.dtypes = {} @@ -200,8 +201,20 @@ def posterior_to_xarray(self): data = _unpack_ndarrays(self.posterior[0], valid_cols, self.dtypes) data_warmup = _unpack_ndarrays(self.posterior[1], valid_cols, self.dtypes) return ( - dict_to_dataset(data, coords=self.coords, dims=self.dims, attrs=self.attrs), - dict_to_dataset(data_warmup, coords=self.coords, dims=self.dims, attrs=self.attrs), + dict_to_dataset( + data, + coords=self.coords, + dims=self.dims, + attrs=self.attrs, + index_origin=self.index_origin, + ), + dict_to_dataset( + data_warmup, + coords=self.coords, + dims=self.dims, + attrs=self.attrs, + index_origin=self.index_origin, + ), ) @requires("posterior") @@ -231,12 +244,14 @@ def sample_stats_to_xarray(self): coords=self.coords, dims=self.dims, attrs={item: key for key, item in rename_dict.items()}, + index_origin=self.index_origin, ), dict_to_dataset( data_warmup, 
coords=self.coords, dims=self.dims, attrs={item: key for key, item in rename_dict.items()}, + index_origin=self.index_origin, ), ) @@ -284,8 +299,20 @@ def posterior_predictive_to_xarray(self): attrs = None return ( - dict_to_dataset(data, coords=self.coords, dims=self.dims, attrs=attrs), - dict_to_dataset(data_warmup, coords=self.coords, dims=self.dims, attrs=attrs), + dict_to_dataset( + data, + coords=self.coords, + dims=self.dims, + attrs=attrs, + index_origin=self.index_origin, + ), + dict_to_dataset( + data_warmup, + coords=self.coords, + dims=self.dims, + attrs=attrs, + index_origin=self.index_origin, + ), ) @requires("posterior") @@ -330,8 +357,20 @@ def predictions_to_xarray(self): attrs = None return ( - dict_to_dataset(data, coords=self.coords, dims=self.dims, attrs=attrs), - dict_to_dataset(data_warmup, coords=self.coords, dims=self.dims, attrs=attrs), + dict_to_dataset( + data, + coords=self.coords, + dims=self.dims, + attrs=attrs, + index_origin=self.index_origin, + ), + dict_to_dataset( + data_warmup, + coords=self.coords, + dims=self.dims, + attrs=attrs, + index_origin=self.index_origin, + ), ) @requires("posterior") @@ -377,10 +416,20 @@ def log_likelihood_to_xarray(self): attrs = None return ( dict_to_dataset( - data, coords=self.coords, dims=self.dims, attrs=attrs, skip_event_dims=True + data, + coords=self.coords, + dims=self.dims, + attrs=attrs, + index_origin=self.index_origin, + skip_event_dims=True, ), dict_to_dataset( - data_warmup, coords=self.coords, dims=self.dims, attrs=attrs, skip_event_dims=True + data_warmup, + coords=self.coords, + dims=self.dims, + attrs=attrs, + index_origin=self.index_origin, + skip_event_dims=True, ), ) @@ -410,9 +459,19 @@ def prior_to_xarray(self): data = _unpack_ndarrays(self.prior[0], valid_cols, self.dtypes) data_warmup = _unpack_ndarrays(self.prior[1], valid_cols, self.dtypes) return ( - dict_to_dataset(data, coords=self.coords, dims=self.dims, attrs=self.attrs_prior), dict_to_dataset( - data_warmup, 
coords=self.coords, dims=self.dims, attrs=self.attrs_prior + data, + coords=self.coords, + dims=self.dims, + attrs=self.attrs_prior, + index_origin=self.index_origin, + ), + dict_to_dataset( + data_warmup, + coords=self.coords, + dims=self.dims, + attrs=self.attrs_prior, + index_origin=self.index_origin, ), ) @@ -443,12 +502,14 @@ def sample_stats_prior_to_xarray(self): coords=self.coords, dims=self.dims, attrs={item: key for key, item in rename_dict.items()}, + index_origin=self.index_origin, ), dict_to_dataset( data_warmup, coords=self.coords, dims=self.dims, attrs={item: key for key, item in rename_dict.items()}, + index_origin=self.index_origin, ), ) @@ -491,8 +552,20 @@ def prior_predictive_to_xarray(self): data_warmup = _unpack_ndarrays(self.prior[1], columns, self.dtypes) attrs = None return ( - dict_to_dataset(data, coords=self.coords, dims=self.dims, attrs=attrs), - dict_to_dataset(data_warmup, coords=self.coords, dims=self.dims, attrs=attrs), + dict_to_dataset( + data, + coords=self.coords, + dims=self.dims, + attrs=attrs, + index_origin=self.index_origin, + ), + dict_to_dataset( + data_warmup, + coords=self.coords, + dims=self.dims, + attrs=attrs, + index_origin=self.index_origin, + ), ) @requires("observed_data") @@ -506,13 +579,14 @@ def observed_data_to_xarray(self): for key, vals in observed_data_raw.items(): if variables is not None and key not in variables: continue - vals = utils.one_de(vals) - val_dims = self.dims.get(key) - val_dims, coords = generate_dims_coords( - vals.shape, key, dims=val_dims, coords=self.coords - ) - observed_data[key] = xr.DataArray(vals, dims=val_dims, coords=coords) - return xr.Dataset(data_vars=observed_data) + observed_data[key] = utils.one_de(vals) + return dict_to_dataset( + observed_data, + coords=self.coords, + dims=self.dims, + default_dims=[], + index_origin=self.index_origin, + ) @requires("constant_data") def constant_data_to_xarray(self): @@ -525,13 +599,14 @@ def constant_data_to_xarray(self): for key, vals 
in constant_data_raw.items(): if variables is not None and key not in variables: continue - vals = utils.one_de(vals) - val_dims = self.dims.get(key) - val_dims, coords = generate_dims_coords( - vals.shape, key, dims=val_dims, coords=self.coords - ) - constant_data[key] = xr.DataArray(vals, dims=val_dims, coords=coords) - return xr.Dataset(data_vars=constant_data) + constant_data[key] = utils.one_de(vals) + return dict_to_dataset( + constant_data, + coords=self.coords, + dims=self.dims, + default_dims=[], + index_origin=self.index_origin, + ) @requires("predictions_constant_data") def predictions_constant_data_to_xarray(self): @@ -545,12 +620,14 @@ def predictions_constant_data_to_xarray(self): if variables is not None and key not in variables: continue vals = utils.one_de(vals) - val_dims = self.dims.get(key) - val_dims, coords = generate_dims_coords( - vals.shape, key, dims=val_dims, coords=self.coords - ) - predictions_constant_data[key] = xr.DataArray(vals, dims=val_dims, coords=coords) - return xr.Dataset(data_vars=predictions_constant_data) + predictions_constant_data[key] = utils.one_de(vals) + return dict_to_dataset( + predictions_constant_data, + coords=self.coords, + dims=self.dims, + default_dims=[], + index_origin=self.index_origin, + ) def to_inference_data(self): """Convert all available data to an InferenceData object. @@ -821,6 +898,7 @@ def from_cmdstan( predictions_constant_data: Optional[str] = None, predictions_constant_data_var: Optional[Union[str, List[str]]] = None, log_likelihood: Optional[Union[str, List[str]]] = None, + index_origin: Optional[int] = None, coords: Optional[CoordSpec] = None, dims: Optional[DimSpec] = None, disable_glob: Optional[bool] = False, @@ -862,6 +940,9 @@ def from_cmdstan( If not defined, all data variables are imported. log_likelihood : str or list of str, optional Pointwise log_likelihood for the data. + index_origin : int, optional + Starting value of integer coordinate values. 
Defaults to the value in rcParam + ``data.index_origin``. coords : dict of {str: array_like}, optional A dictionary containing the values that are used as index. The key is the name of the dimension, the values are the index values. @@ -893,6 +974,7 @@ def from_cmdstan( predictions_constant_data=predictions_constant_data, predictions_constant_data_var=predictions_constant_data_var, log_likelihood=log_likelihood, + index_origin=index_origin, coords=coords, dims=dims, disable_glob=disable_glob, diff --git a/arviz/data/io_cmdstanpy.py b/arviz/data/io_cmdstanpy.py index 3ec97e8fe3..176346aaa3 100644 --- a/arviz/data/io_cmdstanpy.py +++ b/arviz/data/io_cmdstanpy.py @@ -5,11 +5,9 @@ from copy import deepcopy import numpy as np -import xarray as xr -from .. import utils from ..rcparams import rcParams -from .base import dict_to_dataset, generate_dims_coords, make_attrs, requires +from .base import dict_to_dataset, make_attrs, requires from .inference_data import InferenceData _log = logging.getLogger(__name__) @@ -32,6 +30,7 @@ def __init__( constant_data=None, predictions_constant_data=None, log_likelihood=None, + index_origin=None, coords=None, dims=None, save_warmup=None, @@ -45,6 +44,7 @@ def __init__( self.constant_data = constant_data self.predictions_constant_data = predictions_constant_data self.log_likelihood = log_likelihood + self.index_origin = index_origin self.coords = coords self.dims = dims @@ -133,9 +133,19 @@ def stats_to_xarray(self, fit): if data_warmup: data_warmup[name] = data_warmup.pop(item).astype(dtypes.get(item, float)) return ( - dict_to_dataset(data, library=self.cmdstanpy, coords=self.coords, dims=self.dims), dict_to_dataset( - data_warmup, library=self.cmdstanpy, coords=self.coords, dims=self.dims + data, + library=self.cmdstanpy, + coords=self.coords, + dims=self.dims, + index_origin=self.index_origin, + ), + dict_to_dataset( + data_warmup, + library=self.cmdstanpy, + coords=self.coords, + dims=self.dims, + index_origin=self.index_origin, 
), ) @@ -171,9 +181,19 @@ def predictive_to_xarray(self, names, fit): ) return ( - dict_to_dataset(data, library=self.cmdstanpy, coords=self.coords, dims=self.dims), dict_to_dataset( - data_warmup, library=self.cmdstanpy, coords=self.coords, dims=self.dims + data, + library=self.cmdstanpy, + coords=self.coords, + dims=self.dims, + index_origin=self.index_origin, + ), + dict_to_dataset( + data_warmup, + library=self.cmdstanpy, + coords=self.coords, + dims=self.dims, + index_origin=self.index_origin, ), ) @@ -200,9 +220,19 @@ def predictions_to_xarray(self): ) return ( - dict_to_dataset(data, library=self.cmdstanpy, coords=self.coords, dims=self.dims), dict_to_dataset( - data_warmup, library=self.cmdstanpy, coords=self.coords, dims=self.dims + data, + library=self.cmdstanpy, + coords=self.coords, + dims=self.dims, + index_origin=self.index_origin, + ), + dict_to_dataset( + data_warmup, + library=self.cmdstanpy, + coords=self.coords, + dims=self.dims, + index_origin=self.index_origin, ), ) @@ -233,6 +263,7 @@ def log_likelihood_to_xarray(self): library=self.cmdstanpy, coords=self.coords, dims=self.dims, + index_origin=self.index_origin, skip_event_dims=True, ), dict_to_dataset( @@ -240,6 +271,7 @@ def log_likelihood_to_xarray(self): library=self.cmdstanpy, coords=self.coords, dims=self.dims, + index_origin=self.index_origin, skip_event_dims=True, ), ) @@ -275,51 +307,57 @@ def prior_to_xarray(self): ) return ( - dict_to_dataset(data, library=self.cmdstanpy, coords=self.coords, dims=self.dims), dict_to_dataset( - data_warmup, library=self.cmdstanpy, coords=self.coords, dims=self.dims + data, + library=self.cmdstanpy, + coords=self.coords, + dims=self.dims, + index_origin=self.index_origin, + ), + dict_to_dataset( + data_warmup, + library=self.cmdstanpy, + coords=self.coords, + dims=self.dims, + index_origin=self.index_origin, ), ) @requires("observed_data") def observed_data_to_xarray(self): """Convert observed data to xarray.""" - observed_data = {} - for key, vals in 
self.observed_data.items(): - vals = utils.one_de(vals) - val_dims = self.dims.get(key) if self.dims is not None else None - val_dims, coords = generate_dims_coords( - vals.shape, key, dims=val_dims, coords=self.coords - ) - observed_data[key] = xr.DataArray(vals, dims=val_dims, coords=coords) - return xr.Dataset(data_vars=observed_data, attrs=make_attrs(library=self.cmdstanpy)) + return dict_to_dataset( + self.observed_data, + library=self.cmdstanpy, + coords=self.coords, + dims=self.dims, + default_dims=[], + index_origin=self.index_origin, + ) @requires("constant_data") def constant_data_to_xarray(self): """Convert constant data to xarray.""" - constant_data = {} - for key, vals in self.constant_data.items(): - vals = utils.one_de(vals) - val_dims = self.dims.get(key) if self.dims is not None else None - val_dims, coords = generate_dims_coords( - vals.shape, key, dims=val_dims, coords=self.coords - ) - constant_data[key] = xr.DataArray(vals, dims=val_dims, coords=coords) - return xr.Dataset(data_vars=constant_data, attrs=make_attrs(library=self.cmdstanpy)) + return dict_to_dataset( + self.constant_data, + library=self.cmdstanpy, + coords=self.coords, + dims=self.dims, + default_dims=[], + index_origin=self.index_origin, + ) @requires("predictions_constant_data") def predictions_constant_data_to_xarray(self): """Convert constant data to xarray.""" - predictions_constant_data = {} - for key, vals in self.predictions_constant_data.items(): - vals = utils.one_de(vals) - val_dims = self.dims.get(key) if self.dims is not None else None - val_dims, coords = generate_dims_coords( - vals.shape, key, dims=val_dims, coords=self.coords - ) - predictions_constant_data[key] = xr.DataArray(vals, dims=val_dims, coords=coords) - return xr.Dataset( - data_vars=predictions_constant_data, attrs=make_attrs(library=self.cmdstanpy) + return dict_to_dataset( + self.predictions_constant_data, + library=self.cmdstanpy, + coords=self.coords, + dims=self.dims, + 
attrs=make_attrs(library=self.cmdstanpy), + default_dims=[], + index_origin=self.index_origin, ) def to_inference_data(self): @@ -614,6 +652,7 @@ def from_cmdstanpy( constant_data=None, predictions_constant_data=None, log_likelihood=None, + index_origin=None, coords=None, dims=None, save_warmup=None, @@ -643,6 +682,9 @@ def from_cmdstanpy( Constant data for predictions used in the sampling. log_likelihood : str, list of str Pointwise log_likelihood for the data. + index_origin : int, optional + Starting value of integer coordinate values. Defaults to the value in rcParam + ``data.index_origin``. coords : dict of str or dict of iterable A dictionary containing the values that are used as index. The key is the name of the dimension, the values are the index values. @@ -666,6 +708,7 @@ def from_cmdstanpy( constant_data=constant_data, predictions_constant_data=predictions_constant_data, log_likelihood=log_likelihood, + index_origin=index_origin, coords=coords, dims=dims, save_warmup=save_warmup, diff --git a/arviz/data/io_dict.py b/arviz/data/io_dict.py index 9ef820d787..9b87f25e0a 100644 --- a/arviz/data/io_dict.py +++ b/arviz/data/io_dict.py @@ -1,11 +1,9 @@ """Dictionary specific conversion code.""" import warnings +from typing import Optional -import xarray as xr - -from .. 
import utils from ..rcparams import rcParams -from .base import dict_to_dataset, generate_dims_coords, make_attrs, requires +from .base import dict_to_dataset, requires from .inference_data import WARMUP_TAG, InferenceData @@ -33,6 +31,7 @@ def __init__( warmup_log_likelihood=None, warmup_sample_stats=None, save_warmup=None, + index_origin=None, coords=None, dims=None, pred_dims=None, @@ -63,6 +62,8 @@ def __init__( if coords is None else {**coords, **pred_coords} ) + self.index_origin = index_origin + self.coords = coords self.dims = dims self.pred_dims = dims if pred_dims is None else pred_dims self.attrs = {} if attrs is None else attrs @@ -92,10 +93,20 @@ def posterior_to_xarray(self): return ( dict_to_dataset( - data, library=None, coords=self.coords, dims=self.dims, attrs=self.attrs + data, + library=None, + coords=self.coords, + dims=self.dims, + attrs=self.attrs, + index_origin=self.index_origin, ), dict_to_dataset( - data_warmup, library=None, coords=self.coords, dims=self.dims, attrs=self.attrs + data_warmup, + library=None, + coords=self.coords, + dims=self.dims, + attrs=self.attrs, + index_origin=self.index_origin, ), ) @@ -119,10 +130,20 @@ def sample_stats_to_xarray(self): return ( dict_to_dataset( - data, library=None, coords=self.coords, dims=self.dims, attrs=self.attrs + data, + library=None, + coords=self.coords, + dims=self.dims, + attrs=self.attrs, + index_origin=self.index_origin, ), dict_to_dataset( - data_warmup, library=None, coords=self.coords, dims=self.dims, attrs=self.attrs + data_warmup, + library=None, + coords=self.coords, + dims=self.dims, + attrs=self.attrs, + index_origin=self.index_origin, ), ) @@ -143,6 +164,7 @@ def log_likelihood_to_xarray(self): coords=self.coords, dims=self.dims, attrs=self.attrs, + index_origin=self.index_origin, skip_event_dims=True, ), dict_to_dataset( @@ -151,6 +173,7 @@ def log_likelihood_to_xarray(self): coords=self.coords, dims=self.dims, attrs=self.attrs, + index_origin=self.index_origin, 
skip_event_dims=True, ), ) @@ -167,10 +190,20 @@ def posterior_predictive_to_xarray(self): return ( dict_to_dataset( - data, library=None, coords=self.coords, dims=self.dims, attrs=self.attrs + data, + library=None, + coords=self.coords, + dims=self.dims, + attrs=self.attrs, + index_origin=self.index_origin, ), dict_to_dataset( - data_warmup, library=None, coords=self.coords, dims=self.dims, attrs=self.attrs + data_warmup, + library=None, + coords=self.coords, + dims=self.dims, + attrs=self.attrs, + index_origin=self.index_origin, ), ) @@ -186,10 +219,20 @@ def predictions_to_xarray(self): return ( dict_to_dataset( - data, library=None, coords=self.coords, dims=self.pred_dims, attrs=self.attrs + data, + library=None, + coords=self.coords, + dims=self.pred_dims, + attrs=self.attrs, + index_origin=self.index_origin, ), dict_to_dataset( - data_warmup, library=None, coords=self.coords, dims=self.pred_dims, attrs=self.attrs + data_warmup, + library=None, + coords=self.coords, + dims=self.pred_dims, + attrs=self.attrs, + index_origin=self.index_origin, ), ) @@ -201,7 +244,12 @@ def prior_to_xarray(self): raise TypeError("DictConverter.prior is not a dictionary") return dict_to_dataset( - data, library=None, coords=self.coords, dims=self.dims, attrs=self.attrs + data, + library=None, + coords=self.coords, + dims=self.dims, + attrs=self.attrs, + index_origin=self.index_origin, ) @requires("sample_stats_prior") @@ -212,7 +260,12 @@ def sample_stats_prior_to_xarray(self): raise TypeError("DictConverter.sample_stats_prior is not a dictionary") return dict_to_dataset( - data, library=None, coords=self.coords, dims=self.dims, attrs=self.attrs + data, + library=None, + coords=self.coords, + dims=self.dims, + attrs=self.attrs, + index_origin=self.index_origin, ) @requires("prior_predictive") @@ -223,25 +276,29 @@ def prior_predictive_to_xarray(self): raise TypeError("DictConverter.prior_predictive is not a dictionary") return dict_to_dataset( - data, library=None, 
coords=self.coords, dims=self.dims, attrs=self.attrs + data, + library=None, + coords=self.coords, + dims=self.dims, + attrs=self.attrs, + index_origin=self.index_origin, ) - def data_to_xarray(self, dct, group, dims=None): + def data_to_xarray(self, data, group, dims=None): """Convert data to xarray.""" - data = dct if not isinstance(data, dict): raise TypeError("DictConverter.{} is not a dictionary".format(group)) if dims is None: dims = {} if self.dims is None else self.dims - new_data = dict() - for key, vals in data.items(): - vals = utils.one_de(vals) - val_dims = dims.get(key) - val_dims, coords = generate_dims_coords( - vals.shape, key, dims=val_dims, coords=self.coords - ) - new_data[key] = xr.DataArray(vals, dims=val_dims, coords=coords) - return xr.Dataset(data_vars=new_data, attrs=make_attrs(attrs=self.attrs, library=None)) + return dict_to_dataset( + data, + library=None, + coords=self.coords, + dims=self.dims, + default_dims=[], + attrs=self.attrs, + index_origin=self.index_origin, + ) @requires("observed_data") def observed_data_to_xarray(self): @@ -304,6 +361,7 @@ def from_dict( warmup_log_likelihood=None, warmup_sample_stats=None, save_warmup=None, + index_origin: Optional[int] = None, coords=None, dims=None, pred_dims=None, @@ -336,6 +394,7 @@ def from_dict( save_warmup : bool Save warmup iterations InferenceData object. If not defined, use default defined by the rcParams. + index_origin : int, optional coords : dict[str, iterable] A dictionary containing the values that are used as index. The key is the name of the dimension, the values are the index values. 
@@ -370,6 +429,7 @@ def from_dict( warmup_log_likelihood=warmup_log_likelihood, warmup_sample_stats=warmup_sample_stats, save_warmup=save_warmup, + index_origin=index_origin, coords=coords, dims=dims, pred_dims=pred_dims, diff --git a/arviz/data/io_emcee.py b/arviz/data/io_emcee.py index 50d66a5912..4cebb18727 100644 --- a/arviz/data/io_emcee.py +++ b/arviz/data/io_emcee.py @@ -85,6 +85,7 @@ def _verify_names(sampler, var_names, arg_names, slices): return var_names, arg_names, slices +# pylint: disable=too-many-instance-attributes class EmceeConverter: """Encapsulate emcee specific logic.""" @@ -97,6 +98,7 @@ def __init__( arg_groups=None, blob_names=None, blob_groups=None, + index_origin=None, coords=None, dims=None, ): @@ -108,6 +110,7 @@ def __init__( self.arg_groups = arg_groups self.blob_names = blob_names self.blob_groups = blob_groups + self.index_origin = index_origin self.coords = coords self.dims = dims import emcee @@ -124,7 +127,13 @@ def posterior_to_xarray(self): if hasattr(self.sampler, "get_chain") else self.sampler.chain[(..., idx)] ) - return dict_to_dataset(data, library=self.emcee, coords=self.coords, dims=self.dims) + return dict_to_dataset( + data, + library=self.emcee, + coords=self.coords, + dims=self.dims, + index_origin=self.index_origin, + ) def args_to_xarray(self): """Convert emcee args to observed and constant_data xarray Datasets.""" @@ -157,7 +166,11 @@ def args_to_xarray(self): ) arg_dims = dims.get(arg_name) arg_dims, coords = generate_dims_coords( - arg_array.shape, arg_name, dims=arg_dims, coords=self.coords + arg_array.shape, + arg_name, + dims=arg_dims, + coords=self.coords, + index_origin=self.index_origin, ) # filter coords based on the dims coords = {key: xr.IndexVariable((key,), data=coords[key]) for key in arg_dims} @@ -227,7 +240,11 @@ def blobs_to_dict(self): ) for key, values in blob_dict.items(): blob_dict[key] = dict_to_dataset( - values, library=self.emcee, coords=self.coords, dims=self.dims + values, + 
library=self.emcee, + coords=self.coords, + dims=self.dims, + index_origin=self.index_origin, ) return blob_dict @@ -248,6 +265,7 @@ def from_emcee( arg_groups=None, blob_names=None, blob_groups=None, + index_origin=None, coords=None, dims=None, ): @@ -461,6 +479,7 @@ def from_emcee( arg_groups=arg_groups, blob_names=blob_names, blob_groups=blob_groups, + index_origin=index_origin, coords=coords, dims=dims, ).to_inference_data() diff --git a/arviz/data/io_numpyro.py b/arviz/data/io_numpyro.py index a99704b020..622905633a 100644 --- a/arviz/data/io_numpyro.py +++ b/arviz/data/io_numpyro.py @@ -3,10 +3,9 @@ from typing import Callable, Optional import numpy as np -import xarray as xr from .. import utils -from .base import dict_to_dataset, generate_dims_coords, make_attrs, requires +from .base import dict_to_dataset, requires from .inference_data import InferenceData _log = logging.getLogger(__name__) @@ -30,6 +29,7 @@ def __init__( predictions=None, constant_data=None, predictions_constant_data=None, + index_origin=None, coords=None, dims=None, pred_dims=None, @@ -51,6 +51,7 @@ def __init__( Dictionary containing constant data variables mapped to their values. predictions_constant_data: dict Constant data used for out-of-sample predictions. 
+ index_origin : int, optional coords : dict[str] -> list[str] Map of dimensions to coordinates dims : dict[str] -> list[str] @@ -69,6 +70,7 @@ def __init__( self.predictions = predictions self.constant_data = constant_data self.predictions_constant_data = predictions_constant_data + self.index_origin = index_origin self.coords = coords self.dims = dims self.pred_dims = pred_dims @@ -127,7 +129,13 @@ def arbitrary_element(dct): def posterior_to_xarray(self): """Convert the posterior to an xarray dataset.""" data = self._samples - return dict_to_dataset(data, library=self.numpyro, coords=self.coords, dims=self.dims) + return dict_to_dataset( + data, + library=self.numpyro, + coords=self.coords, + dims=self.dims, + index_origin=self.index_origin, + ) @requires("posterior") def sample_stats_to_xarray(self): @@ -147,7 +155,13 @@ def sample_stats_to_xarray(self): data[name] = value if stat == "num_steps": data["tree_depth"] = np.log2(value).astype(int) + 1 - return dict_to_dataset(data, library=self.numpyro, dims=None, coords=self.coords) + return dict_to_dataset( + data, + library=self.numpyro, + dims=None, + coords=self.coords, + index_origin=self.index_origin, + ) @requires("posterior") @requires("model") def log_likelihood_to_xarray(self): @@ -163,7 +177,12 @@ def log_likelihood_to_xarray(self): shape = (self.nchains, self.ndraws) + log_like.shape[1:] data[obs_name] = np.reshape(log_like.copy(), shape) return dict_to_dataset( - data, library=self.numpyro, dims=self.dims, coords=self.coords, skip_event_dims=True + data, + library=self.numpyro, + dims=self.dims, + coords=self.coords, + index_origin=self.index_origin, + skip_event_dims=True, ) def translate_posterior_predictive_dict_to_xarray(self, dct, dims): @@ -181,7 +200,13 @@ def translate_posterior_predictive_dict_to_xarray(self, dct, dims): "posterior predictive shape not compatible with number of chains and draws. " "This can mean that some draws or even whole chains are not represented." 
) - return dict_to_dataset(data, library=self.numpyro, coords=self.coords, dims=dims) + return dict_to_dataset( + data, + library=self.numpyro, + coords=self.coords, + dims=dims, + index_origin=self.index_origin, + ) @requires("posterior_predictive") def posterior_predictive_to_xarray(self): @@ -217,6 +242,7 @@ def priors_to_xarray(self): library=self.numpyro, coords=self.coords, dims=self.dims, + index_origin=self.index_origin, ) ) return priors_dict @@ -225,47 +251,38 @@ def priors_to_xarray(self): @requires("model") def observed_data_to_xarray(self): """Convert observed data to xarray.""" - if self.dims is None: - dims = {} - else: - dims = self.dims - observed_data = {} - for name, vals in self.observations.items(): - vals = utils.one_de(vals) - val_dims = dims.get(name) - val_dims, coords = generate_dims_coords( - vals.shape, name, dims=val_dims, coords=self.coords - ) - # filter coords based on the dims - coords = {key: xr.IndexVariable((key,), data=coords[key]) for key in val_dims} - observed_data[name] = xr.DataArray(vals, dims=val_dims, coords=coords) - return xr.Dataset(data_vars=observed_data, attrs=make_attrs(library=self.numpyro)) - - def convert_constant_data_to_xarray(self, dct, dims): - """Convert constant_data or predictions_constant_data to xarray.""" - if dims is None: - dims = {} - constant_data = {} - for name, vals in dct.items(): - vals = utils.one_de(vals) - val_dims = dims.get(name) - val_dims, coords = generate_dims_coords( - vals.shape, name, dims=val_dims, coords=self.coords - ) - # filter coords based on the dims - coords = {key: xr.IndexVariable((key,), data=coords[key]) for key in val_dims} - constant_data[name] = xr.DataArray(vals, dims=val_dims, coords=coords) - return xr.Dataset(data_vars=constant_data, attrs=make_attrs(library=self.numpyro)) + return dict_to_dataset( + self.observations, + library=self.numpyro, + dims=self.dims, + coords=self.coords, + default_dims=[], + index_origin=self.index_origin, + ) 
@requires("constant_data") def constant_data_to_xarray(self): """Convert constant_data to xarray.""" - return self.convert_constant_data_to_xarray(self.constant_data, self.dims) + return dict_to_dataset( + self.constant_data, + library=self.numpyro, + dims=self.dims, + coords=self.coords, + default_dims=[], + index_origin=self.index_origin, + ) @requires("predictions_constant_data") def predictions_constant_data_to_xarray(self): """Convert predictions_constant_data to xarray.""" - return self.convert_constant_data_to_xarray(self.predictions_constant_data, self.pred_dims) + return dict_to_dataset( + self.predictions_constant_data, + library=self.numpyro, + dims=self.pred_dims, + coords=self.coords, + default_dims=[], + index_origin=self.index_origin, + ) def to_inference_data(self): """Convert all available data to an InferenceData object. @@ -297,6 +314,7 @@ def from_numpyro( predictions=None, constant_data=None, predictions_constant_data=None, + index_origin=None, coords=None, dims=None, pred_dims=None, @@ -321,6 +339,7 @@ def from_numpyro( Dictionary containing constant data variables mapped to their values. predictions_constant_data: dict Constant data used for out-of-sample predictions. 
+ index_origin : int, optional coords : dict[str] -> list[str] Map of dimensions to coordinates dims : dict[str] -> list[str] @@ -337,6 +356,7 @@ def from_numpyro( predictions=predictions, constant_data=constant_data, predictions_constant_data=predictions_constant_data, + index_origin=index_origin, coords=coords, dims=dims, pred_dims=pred_dims, diff --git a/arviz/labels.py b/arviz/labels.py new file mode 100644 index 0000000000..1a1af0491e --- /dev/null +++ b/arviz/labels.py @@ -0,0 +1,210 @@ +# pylint: disable=unused-argument +"""Utilities to generate labels from xarray objects.""" +from typing import Union + +__all__ = [ + "mix_labellers", + "BaseLabeller", + "DimCoordLabeller", + "DimIdxLabeller", + "MapLabeller", + "NoRepeatLabeller", + "NoModelLabeller", +] + + +def mix_labellers(labellers, class_name="MixtureLabeller"): + """Combine Labeller classes dynamically. + + The Labeller class aims to split plot labeling in ArviZ into atomic tasks to maximize + extensibility, and the few classes provided are designed with small deviations + from the base class, in many cases only one method is modified by the child class. + It is to be expected then to want to use multiple classes "at once". + + This function helps combine classes dynamically. + + Parameters + ---------- + labellers : iterable of types + Iterable of Labeller types to combine + class_name : str, optional + The name of the generated class + + Returns + ------- + type + Mixture class object. *It is not initialized* + + Examples + -------- + Combine the :class:`~arviz.labels.DimCoordLabeller` with the + :class:`~arviz.labels.MapLabeller` to generate labels in the style of the + ``DimCoordLabeller`` but using the mappings defined by ``MapLabeller``. + Note that this works even though both modify the same methods because + ``MapLabeller`` implements the mapping and then calls `super().method`. + + .. 
ipython:: + + In [1]: from arviz.labels import mix_labellers, DimCoordLabeller, MapLabeller + ...: l1 = DimCoordLabeller() + ...: sel = {"dim1": "a", "dim2": "top"} + ...: print(f"Output of DimCoordLabeller alone > {l1.sel_to_str(sel, sel)}") + ...: l2 = MapLabeller(dim_map={"dim1": "$d_1$", "dim2": r"$d_2$"}) + ...: print(f"Output of MapLabeller alone > {l2.sel_to_str(sel, sel)}") + ...: l3 = mix_labellers( + ...: (MapLabeller, DimCoordLabeller) + ...: )(dim_map={"dim1": "$d_1$", "dim2": r"$d_2$"}) + ...: print(f"Output of mixture labeller > {l3.sel_to_str(sel, sel)}") + + We can see how the mappings are taken into account as well as the dim+coord style. However, + the order in the ``labellers`` arg iterator is important! See for yourself: + + .. ipython:: python + + l4 = mix_labellers( + (DimCoordLabeller, MapLabeller) + )(dim_map={"dim1": "$d_1$", "dim2": r"$d_2$"}) + print(f"Output of inverted mixture labeller > {l4.sel_to_str(sel, sel)}") + + """ + return type(class_name, labellers, {}) + + +class BaseLabeller: + """WIP.""" + + def dim_coord_to_str(self, dim, coord_val, coord_idx): + """WIP.""" + return f"{coord_val}" + + def sel_to_str(self, sel: dict, isel: dict): + """WIP.""" + if not sel: + return "" + return ", ".join( + [ + self.dim_coord_to_str(dim, v, i) + for (dim, v), (_, i) in zip(sel.items(), isel.items()) + ] + ) + + def var_name_to_str(self, var_name: Union[str, None]): + """WIP.""" + return var_name + + def var_pp_to_str(self, var_name, pp_var_name): + """WIP.""" + var_name_str = self.var_name_to_str(var_name) + pp_var_name_str = self.var_name_to_str(pp_var_name) + return f"{var_name_str} / {pp_var_name_str}" + + def model_name_to_str(self, model_name): + """WIP.""" + return model_name + + def make_label_vert(self, var_name: Union[str, None], sel: dict, isel: dict): + """WIP.""" + var_name_str = self.var_name_to_str(var_name) + sel_str = self.sel_to_str(sel, isel) + if not sel_str: + return var_name_str + if var_name_str is None: + return sel_str 
+ return f"{var_name_str}\n{sel_str}" + + def make_label_flat(self, var_name: str, sel: dict, isel: dict): + """WIP.""" + var_name_str = self.var_name_to_str(var_name) + sel_str = self.sel_to_str(sel, isel) + if not sel_str: + return var_name_str + if var_name is None: + return sel_str + return f"{var_name_str}[{sel_str}]" + + def make_pp_label(self, var_name, pp_var_name, sel, isel): + """WIP.""" + names = self.var_pp_to_str(var_name, pp_var_name) + return self.make_label_vert(names, sel, isel) + + def make_model_label(self, model_name, label): + """WIP.""" + model_name_str = self.model_name_to_str(model_name) + if model_name_str is None: + return label + return f"{model_name}: {label}" + + +class DimCoordLabeller(BaseLabeller): + """WIP.""" + + def dim_coord_to_str(self, dim, coord_val, coord_idx): + """WIP.""" + return f"{dim}: {coord_val}" + + +class IdxLabeller(BaseLabeller): + """WIP.""" + + def dim_coord_to_str(self, dim, coord_val, coord_idx): + """WIP.""" + return f"{coord_idx}" + + +class DimIdxLabeller(BaseLabeller): + """WIP.""" + + def dim_coord_to_str(self, dim, coord_val, coord_idx): + """WIP.""" + return f"{dim}#{coord_idx}" + + +class MapLabeller(BaseLabeller): + """WIP.""" + + def __init__(self, var_name_map=None, dim_map=None, coord_map=None, model_name_map=None): + """WIP.""" + self.var_name_map = {} if var_name_map is None else var_name_map + self.dim_map = {} if dim_map is None else dim_map + self.coord_map = {} if coord_map is None else coord_map + self.model_name_map = {} if model_name_map is None else model_name_map + + def dim_coord_to_str(self, dim, coord_val, coord_idx): + """WIP.""" + dim_str = self.dim_map.get(dim, dim) + coord_str = self.coord_map.get(dim, {}).get(coord_val, coord_val) + return super().dim_coord_to_str(dim_str, coord_str, coord_idx) + + def var_name_to_str(self, var_name): + """WIP.""" + var_name_str = self.var_name_map.get(var_name, var_name) + return super().var_name_to_str(var_name_str) + + def 
model_name_to_str(self, model_name): + """WIP.""" + model_name_str = self.var_name_map.get(model_name, model_name) + return super().model_name_to_str(model_name_str) + + +class NoRepeatLabeller(BaseLabeller): + """WIP.""" + + def __init__(self): + """WIP.""" + self.current_var = None + + def var_name_to_str(self, var_name): + """WIP.""" + current_var = getattr(self, "current_var", None) + if var_name == current_var: + return "" + self.current_var = var_name + return var_name + + +class NoModelLabeller(BaseLabeller): + """WIP.""" + + def make_model_label(self, model_name, label): + """WIP.""" + return label diff --git a/arviz/plots/autocorrplot.py b/arviz/plots/autocorrplot.py index ab6c46b2e9..f700a20eb3 100644 --- a/arviz/plots/autocorrplot.py +++ b/arviz/plots/autocorrplot.py @@ -1,8 +1,10 @@ """Autocorrelation plot of data.""" from ..data import convert_to_dataset +from ..labels import BaseLabeller +from ..sel_utils import xarray_var_iter from ..rcparams import rcParams from ..utils import _var_names -from .plot_utils import default_grid, filter_plotters_list, get_plotting_function, xarray_var_iter +from .plot_utils import default_grid, filter_plotters_list, get_plotting_function def plot_autocorr( @@ -14,6 +16,7 @@ def plot_autocorr( grid=None, figsize=None, textsize=None, + labeller=None, ax=None, backend=None, backend_config=None, @@ -52,6 +55,9 @@ def plot_autocorr( textsize: float Text size scaling factor for labels, titles and lines. If None it will be autoscaled based on figsize. + labeller : labeller instance, optional + Class providing the method `make_label_vert` to generate the labels in the plot titles. + Read the :ref:`label_guide` for more details and usage examples. ax: numpy array-like of matplotlib axes or bokeh figures, optional A 2D array of locations into which to plot the densities. If not supplied, Arviz will create its own array of plot areas (and return it). 
@@ -110,6 +116,9 @@ def plot_autocorr( if max_lag is None: max_lag = min(100, data["draw"].shape[0]) + if labeller is None: + labeller = BaseLabeller() + plotters = filter_plotters_list( list(xarray_var_iter(data, var_names, combined)), "plot_autocorr" ) @@ -124,6 +133,7 @@ def plot_autocorr( cols=cols, combined=combined, textsize=textsize, + labeller=labeller, backend_kwargs=backend_kwargs, show=show, ) diff --git a/arviz/plots/backends/bokeh/autocorrplot.py b/arviz/plots/backends/bokeh/autocorrplot.py index 89b07cc9ba..97bd8d20d6 100644 --- a/arviz/plots/backends/bokeh/autocorrplot.py +++ b/arviz/plots/backends/bokeh/autocorrplot.py @@ -5,7 +5,7 @@ from ....stats import autocorr -from ...plot_utils import _scale_fig_size, make_label +from ...plot_utils import _scale_fig_size from .. import show_layout from . import backend_kwarg_defaults, create_axes_grid @@ -19,6 +19,7 @@ def plot_autocorr( cols, combined, textsize, + labeller, backend_config, backend_kwargs, show, @@ -27,7 +28,7 @@ def plot_autocorr( if backend_config is None: backend_config = {} - len_y = plotters[0][2].size + len_y = plotters[0][-1].size backend_config.setdefault("bounds_x_range", (0, len_y)) backend_config = { @@ -69,7 +70,7 @@ def plot_autocorr( start=-1, end=1, bounds=backend_config["bounds_y_range"], min_interval=0.1 ) - for (var_name, selection, x), ax in zip( + for (var_name, selection, isel, x), ax in zip( plotters, (item for item in axes.flatten() if item is not None) ): x_prime = x @@ -89,7 +90,7 @@ def plot_autocorr( ) title = Title() - title.text = make_label(var_name, selection) + title.text = labeller.make_label_vert(var_name, selection, isel) ax.title = title ax.x_range = data_range_x ax.y_range = data_range_y diff --git a/arviz/plots/backends/bokeh/bpvplot.py b/arviz/plots/backends/bokeh/bpvplot.py index d2c834ee78..94e9bbb057 100644 --- a/arviz/plots/backends/bokeh/bpvplot.py +++ b/arviz/plots/backends/bokeh/bpvplot.py @@ -36,6 +36,7 @@ def plot_bpv( color, figsize, textsize, 
+ labeller, plot_ref_kwargs, backend_kwargs, show, @@ -83,8 +84,8 @@ def plot_bpv( ) for i, ax_i in enumerate((item for item in axes.flatten() if item is not None)): - var_name, _, obs_vals = obs_plotters[i] - pp_var_name, _, pp_vals = pp_plotters[i] + var_name, sel, isel, obs_vals = obs_plotters[i] + pp_var_name, _, _, pp_vals = pp_plotters[i] obs_vals = obs_vals.flatten() pp_vals = pp_vals.reshape(total_pp_samples, -1) @@ -175,12 +176,8 @@ def plot_bpv( obs_vals.mean(), 0, fill_color=color, line_color="black", size=markersize ) - if var_name != pp_var_name: - xlabel = "{} / {}".format(var_name, pp_var_name) - else: - xlabel = var_name _title = Title() - _title.text = xlabel + _title.text = labeller.make_pp_label(var_name, pp_var_name, sel, isel) ax_i.title = _title size = str(int(ax_labelsize)) ax_i.title.text_font_size = f"{size}pt" diff --git a/arviz/plots/backends/bokeh/compareplot.py b/arviz/plots/backends/bokeh/compareplot.py index 4b8dd726f1..45e09b575b 100644 --- a/arviz/plots/backends/bokeh/compareplot.py +++ b/arviz/plots/backends/bokeh/compareplot.py @@ -44,9 +44,6 @@ def plot_compare( yticks_pos = list(yticks_pos) if plot_ic_diff: - yticks_labels[0] = comp_df.index[0] - yticks_labels[2::2] = comp_df.index[1:] - ax.yaxis.ticker = yticks_pos ax.yaxis.major_label_overrides = { dtype(key): value @@ -77,7 +74,6 @@ def plot_compare( ax.multi_line(err_xs, err_ys, line_color=plot_kwargs.get("color_dse", "grey")) else: - yticks_labels = comp_df.index ax.yaxis.ticker = yticks_pos[::2] ax.yaxis.major_label_overrides = { key: value for key, value in zip(yticks_pos[::2], yticks_labels) diff --git a/arviz/plots/backends/bokeh/densityplot.py b/arviz/plots/backends/bokeh/densityplot.py index eb8dcec100..b14e06a914 100644 --- a/arviz/plots/backends/bokeh/densityplot.py +++ b/arviz/plots/backends/bokeh/densityplot.py @@ -8,7 +8,7 @@ from ....stats import hdi from ....stats.density_utils import get_bins, histogram, kde -from ...plot_utils import _scale_fig_size, 
calculate_point_estimate, make_label, vectorized_to_hex +from ...plot_utils import _scale_fig_size, calculate_point_estimate, vectorized_to_hex from .. import show_layout from . import backend_kwarg_defaults, create_axes_grid @@ -25,6 +25,7 @@ def plot_density( rows, cols, textsize, + labeller, hdi_prob, point_estimate, hdi_markers, @@ -78,8 +79,8 @@ def plot_density( legend_items = defaultdict(list) for m_idx, plotters in enumerate(to_plot): - for var_name, selection, values in plotters: - label = make_label(var_name, selection) + for var_name, selection, isel, values in plotters: + label = labeller.make_label_vert(var_name, selection, isel) if data_labels: data_label = data_labels[m_idx] diff --git a/arviz/plots/backends/bokeh/distcomparisonplot.py b/arviz/plots/backends/bokeh/distcomparisonplot.py index 3e3ed9855e..38d5bcdc58 100644 --- a/arviz/plots/backends/bokeh/distcomparisonplot.py +++ b/arviz/plots/backends/bokeh/distcomparisonplot.py @@ -10,6 +10,7 @@ def plot_dist_comparison( legend, groups, textsize, + labeller, prior_kwargs, posterior_kwargs, observed_kwargs, diff --git a/arviz/plots/backends/bokeh/essplot.py b/arviz/plots/backends/bokeh/essplot.py index 1a09f3c3e7..94206e63ee 100644 --- a/arviz/plots/backends/bokeh/essplot.py +++ b/arviz/plots/backends/bokeh/essplot.py @@ -5,9 +5,9 @@ from bokeh.models.annotations import Legend, Title from scipy.stats import rankdata -from ...plot_utils import _scale_fig_size, make_label from .. import show_layout from . 
import backend_kwarg_defaults, create_axes_grid +from ...plot_utils import _scale_fig_size def plot_ess( @@ -31,6 +31,7 @@ def plot_ess( n_samples, relative, min_ess, + labeller, ylabel, rug, rug_kind, @@ -61,7 +62,7 @@ def plot_ess( else: ax = np.atleast_2d(ax) - for (var_name, selection, x), ax_ in zip( + for (var_name, selection, isel, x), ax_ in zip( plotters, (item for item in ax.flatten() if item is not None) ): bulk_points = ax_.circle(np.asarray(xdata), np.asarray(x), size=6) @@ -153,7 +154,7 @@ def plot_ess( ax_.legend.click_policy = "hide" title = Title() - title.text = make_label(var_name, selection) + title.text = labeller.make_label_vert(var_name, selection, isel) ax_.title = title ax_.xaxis.axis_label = "Total number of draws" if kind == "evolution" else "Quantile" diff --git a/arviz/plots/backends/bokeh/forestplot.py b/arviz/plots/backends/bokeh/forestplot.py index 662b449587..b58a353582 100644 --- a/arviz/plots/backends/bokeh/forestplot.py +++ b/arviz/plots/backends/bokeh/forestplot.py @@ -10,12 +10,13 @@ from bokeh.models.annotations import Title from bokeh.models.tickers import FixedTicker +from ....sel_utils import xarray_var_iter from ....rcparams import rcParams from ....stats import hdi from ....stats.density_utils import get_bins, histogram, kde from ....stats.diagnostics import _ess, _rhat from ....utils import conditional_jit -from ...plot_utils import _scale_fig_size, make_label, xarray_var_iter +from ...plot_utils import _scale_fig_size from .. import show_layout from . 
import backend_kwarg_defaults @@ -49,6 +50,8 @@ def plot_forest( ridgeplot_truncate, ridgeplot_quantiles, textsize, + legend, + labeller, ess, r_hat, backend_config, @@ -57,7 +60,12 @@ def plot_forest( ): """Bokeh forest plot.""" plot_handler = PlotHandler( - datasets, var_names=var_names, model_names=model_names, combined=combined, colors=colors + datasets, + var_names=var_names, + model_names=model_names, + combined=combined, + colors=colors, + labeller=labeller, ) if figsize is None: @@ -195,7 +203,7 @@ class PlotHandler: # pylint: disable=inconsistent-return-statements - def __init__(self, datasets, var_names, model_names, combined, colors): + def __init__(self, datasets, var_names, model_names, combined, colors, labeller): self.data = datasets if model_names is None: @@ -233,6 +241,7 @@ def __init__(self, datasets, var_names, model_names, combined, colors): colors = [colors for _ in self.data] self.colors = list(reversed(colors)) # y-values are upside down + self.labeller = labeller self.plotters = self.make_plotters() @@ -247,6 +256,7 @@ def make_plotters(self): model_names=self.model_names, combined=self.combined, colors=self.colors, + labeller=self.labeller, ) y = plotters[var_name].y_max() return plotters @@ -540,13 +550,14 @@ def y_max(self): class VarHandler: """Handle individual variable logic.""" - def __init__(self, var_name, data, y_start, model_names, combined, colors): + def __init__(self, var_name, data, y_start, model_names, combined, colors, labeller): self.var_name = var_name self.data = data self.y_start = y_start self.model_names = model_names self.combined = combined self.colors = colors + self.labeller = labeller self.model_color = dict(zip(self.model_names, self.colors)) max_chains = max(datum.chain.max().values for datum in data) self.chain_offset = len(data) * 0.45 / max(1, max_chains) @@ -573,7 +584,7 @@ def iterator(self): reverse_selections=True, ) datum_list = list(datum_iter) - for _, selection, values in datum_list: + for _, 
selection, isel, values in datum_list: selection_list.append(selection) if not selection: var_name = self.var_name @@ -581,7 +592,7 @@ def iterator(self): var_name = self.var_name + ":" else: var_name = "" - label = make_label(var_name, selection, position="beside") + label = self.labeller.make_label_flat(var_name, selection, isel) if label not in label_dict: label_dict[label] = OrderedDict() if name not in label_dict[label]: diff --git a/arviz/plots/backends/bokeh/jointplot.py b/arviz/plots/backends/bokeh/jointplot.py index 616e325b98..e35fd2452b 100644 --- a/arviz/plots/backends/bokeh/jointplot.py +++ b/arviz/plots/backends/bokeh/jointplot.py @@ -4,9 +4,10 @@ from ...distplot import plot_dist from ...kdeplot import plot_kde -from ...plot_utils import _scale_fig_size, make_label from .. import show_layout from . import backend_kwarg_defaults +from ...plot_utils import _scale_fig_size +from ....sel_utils import make_label def plot_joint( @@ -80,8 +81,8 @@ def plot_joint( axjoin.yaxis.axis_label = y_var_name # Flatten data - x = plotters[0][2].flatten() - y = plotters[1][2].flatten() + x = plotters[0][-1].flatten() + y = plotters[1][-1].flatten() if kind == "scatter": axjoin.circle(x, y, **joint_kwargs) diff --git a/arviz/plots/backends/bokeh/loopitplot.py b/arviz/plots/backends/bokeh/loopitplot.py index 030f5153e2..e0bc81d0a2 100644 --- a/arviz/plots/backends/bokeh/loopitplot.py +++ b/arviz/plots/backends/bokeh/loopitplot.py @@ -34,6 +34,7 @@ def plot_loo_pit( y, color, textsize, + labeller, hdi_prob, plot_kwargs, backend_kwargs, @@ -63,15 +64,21 @@ def plot_loo_pit( plot_kwargs.setdefault("color", to_hex(color)) plot_kwargs.setdefault("linewidth", linewidth * 1.4) if isinstance(y, str): - label = ("{} LOO-PIT ECDF" if ecdf else "{} LOO-PIT").format(y) + label = "LOO-PIT ECDF" if ecdf else "LOO-PIT" + xlabel = y elif isinstance(y, DataArray) and y.name is not None: - label = ("{} LOO-PIT ECDF" if ecdf else "{} LOO-PIT").format(y.name) + label = "LOO-PIT ECDF" if 
ecdf else "LOO-PIT" + xlabel = y.name elif isinstance(y_hat, str): - label = ("{} LOO-PIT ECDF" if ecdf else "{} LOO-PIT").format(y_hat) + label = "LOO-PIT ECDF" if ecdf else "LOO-PIT" + xlabel = y_hat elif isinstance(y_hat, DataArray) and y_hat.name is not None: - label = ("{} LOO-PIT ECDF" if ecdf else "{} LOO-PIT").format(y_hat.name) + label = "LOO-PIT ECDF" if ecdf else "LOO-PIT" + xlabel = y_hat.name else: label = "LOO-PIT ECDF" if ecdf else "LOO-PIT" + xlabel = "" + xlabel = labeller.var_name_to_str(xlabel) plot_kwargs.setdefault("legend_label", label) @@ -204,6 +211,7 @@ def plot_loo_pit( ) # Sets xlim(0, 1) + ax.xaxis.axis_label = xlabel ax.line(0, 0) ax.line(1, 0) show_layout(ax, show) diff --git a/arviz/plots/backends/bokeh/mcseplot.py b/arviz/plots/backends/bokeh/mcseplot.py index 0caccc8178..51142f9524 100644 --- a/arviz/plots/backends/bokeh/mcseplot.py +++ b/arviz/plots/backends/bokeh/mcseplot.py @@ -5,7 +5,7 @@ from scipy.stats import rankdata from ....stats.stats_utils import quantile as _quantile -from ...plot_utils import _scale_fig_size, make_label +from ...plot_utils import _scale_fig_size from .. import show_layout from . 
import backend_kwarg_defaults, create_axes_grid @@ -26,6 +26,7 @@ def plot_mcse( mean_mcse, sd_mcse, textsize, + labeller, text_kwargs, # pylint: disable=unused-argument rug_kwargs, extra_kwargs, @@ -61,7 +62,7 @@ def plot_mcse( else: ax = np.atleast_2d(ax) - for (var_name, selection, x), ax_ in zip( + for (var_name, selection, isel, x), ax_ in zip( plotters, (item for item in ax.flatten() if item is not None) ): if errorbar or rug: @@ -164,7 +165,7 @@ def plot_mcse( ax_.add_glyph(cds_rug, glyph) title = Title() - title.text = make_label(var_name, selection) + title.text = labeller.make_label_vert(var_name, selection, isel) ax_.title = title ax_.xaxis.axis_label = "Quantile" diff --git a/arviz/plots/backends/bokeh/posteriorplot.py b/arviz/plots/backends/bokeh/posteriorplot.py index f77ec05bba..324e3d6e61 100644 --- a/arviz/plots/backends/bokeh/posteriorplot.py +++ b/arviz/plots/backends/bokeh/posteriorplot.py @@ -12,7 +12,6 @@ _scale_fig_size, calculate_point_estimate, format_sig_figs, - make_label, round_num, ) from .. 
import show_layout @@ -38,6 +37,7 @@ def plot_posterior( textsize, ref_val, rope, + labeller, kwargs, backend_kwargs, show, @@ -66,7 +66,7 @@ def plot_posterior( else: ax = np.atleast_2d(ax) idx = 0 - for (var_name, selection, x), ax_ in zip( + for (var_name, selection, isel, x), ax_ in zip( plotters, (item for item in ax.flatten() if item is not None) ): _plot_posterior_op( @@ -92,7 +92,7 @@ def plot_posterior( ) idx += 1 _title = Title() - _title.text = make_label(var_name, selection) + _title.text = labeller.make_label_vert(var_name, selection, isel) ax_.title = _title show_layout(ax, show) diff --git a/arviz/plots/backends/bokeh/ppcplot.py b/arviz/plots/backends/bokeh/ppcplot.py index fe229dae96..11fe625992 100644 --- a/arviz/plots/backends/bokeh/ppcplot.py +++ b/arviz/plots/backends/bokeh/ppcplot.py @@ -29,6 +29,7 @@ def plot_ppc( jitter, total_pp_samples, legend, # pylint: disable=unused-argument + labeller, group, # pylint: disable=unused-argument animation_kwargs, # pylint: disable=unused-argument num_pp_samples, @@ -82,8 +83,8 @@ def plot_ppc( raise ValueError("jitter must be >=0.") for i, ax_i in enumerate((item for item in axes.flatten() if item is not None)): - var_name, _, obs_vals = obs_plotters[i] - pp_var_name, _, pp_vals = pp_plotters[i] + var_name, sel, isel, obs_vals = obs_plotters[i] + pp_var_name, _, _, pp_vals = pp_plotters[i] dtype = predictive_dataset[pp_var_name].dtype.kind # flatten non-specified dimensions @@ -271,11 +272,7 @@ def plot_ppc( ax_i.yaxis.minor_tick_line_color = None ax_i.yaxis.major_label_text_font_size = "0pt" - if var_name != pp_var_name: - xlabel = "{} / {}".format(var_name, pp_var_name) - else: - xlabel = var_name - ax_i.xaxis.axis_label = xlabel + ax_i.xaxis.axis_label = labeller.make_pp_label(var_name, pp_var_name, sel, isel) show_layout(axes, show) diff --git a/arviz/plots/backends/bokeh/rankplot.py b/arviz/plots/backends/bokeh/rankplot.py index bb7fb1dd8d..1ccccfe72f 100644 --- 
a/arviz/plots/backends/bokeh/rankplot.py +++ b/arviz/plots/backends/bokeh/rankplot.py @@ -6,7 +6,7 @@ from bokeh.models.tickers import FixedTicker from ....stats.density_utils import histogram -from ...plot_utils import _scale_fig_size, make_label, compute_ranks +from ...plot_utils import _scale_fig_size, compute_ranks from .. import show_layout from . import backend_kwarg_defaults, create_axes_grid @@ -23,6 +23,7 @@ def plot_rank( colors, ref_line, labels, + labeller, ref_line_kwargs, bar_kwargs, vlines_kwargs, @@ -71,7 +72,7 @@ def plot_rank( else: axes = np.atleast_2d(axes) - for ax, (var_name, selection, var_data) in zip( + for ax, (var_name, selection, isel, var_data) in zip( (item for item in axes.flatten() if item is not None), plotters ): ranks = compute_ranks(var_data) @@ -138,7 +139,7 @@ def plot_rank( ax.yaxis.major_label_text_font_size = "0pt" _title = Title() - _title.text = make_label(var_name, selection) + _title.text = labeller.make_label_vert(var_name, selection, isel) ax.title = _title show_layout(axes, show) diff --git a/arviz/plots/backends/bokeh/traceplot.py b/arviz/plots/backends/bokeh/traceplot.py index 9efe9ca0b6..ed2c32f2ba 100644 --- a/arviz/plots/backends/bokeh/traceplot.py +++ b/arviz/plots/backends/bokeh/traceplot.py @@ -10,10 +10,11 @@ from bokeh.models.annotations import Title from ...distplot import plot_dist -from ...plot_utils import _scale_fig_size, make_label, xarray_var_iter +from ...plot_utils import _scale_fig_size from ...rankplot import plot_rank from .. import show_layout from . 
import backend_kwarg_defaults, dealiase_sel_kwargs +from ....sel_utils import xarray_var_iter def plot_trace( @@ -31,6 +32,7 @@ def plot_trace( combined, chain_prop, legend, + labeller, plot_kwargs, fill_kwargs, rug_kwargs, @@ -167,7 +169,7 @@ def plot_trace( cds_var_groups = {} draw_name = "draw" - for var_name, selection, value in list( + for var_name, selection, isel, value in list( xarray_var_iter(data, var_names=var_names, combined=True) ): if selection: @@ -204,7 +206,7 @@ def plot_trace( cds_data = {chain_idx: ColumnDataSource(cds) for chain_idx, cds in cds_data.items()} - for idx, (var_name, selection, value) in enumerate(plotters): + for idx, (var_name, selection, isel, value) in enumerate(plotters): value = np.atleast_2d(value) if len(value.shape) == 2: @@ -269,7 +271,7 @@ def plot_trace( for col in (0, 1): _title = Title() - _title.text = make_label(var_name, selection) + _title.text = labeller.make_label_vert(var_name, selection, isel) axes[idx, col].title = _title axes[idx, col].y_range = DataRange1d( bounds=backend_config["bounds_y_range"], min_interval=0.1 diff --git a/arviz/plots/backends/bokeh/violinplot.py b/arviz/plots/backends/bokeh/violinplot.py index 6445b89a48..393f1f4ede 100644 --- a/arviz/plots/backends/bokeh/violinplot.py +++ b/arviz/plots/backends/bokeh/violinplot.py @@ -4,7 +4,7 @@ from ....stats import hdi from ....stats.density_utils import get_bins, histogram, kde -from ...plot_utils import _scale_fig_size, make_label +from ...plot_utils import _scale_fig_size from .. import show_layout from . 
import backend_kwarg_defaults, create_axes_grid @@ -23,6 +23,7 @@ def plot_violin( rug_kwargs, bw, textsize, + labeller, circular, hdi_prob, quartiles, @@ -58,7 +59,7 @@ def plot_violin( else: ax = np.atleast_2d(ax) - for (var_name, selection, x), ax_ in zip( + for (var_name, selection, isel, x), ax_ in zip( plotters, (item for item in ax.flatten() if item is not None) ): val = x.flatten() @@ -90,7 +91,7 @@ def plot_violin( _title = Title() _title.align = "center" - _title.text = make_label(var_name, selection) + _title.text = labeller.make_label_vert(var_name, selection, isel) ax_.title = _title ax_.xaxis.major_tick_line_color = None ax_.xaxis.minor_tick_line_color = None diff --git a/arviz/plots/backends/matplotlib/autocorrplot.py b/arviz/plots/backends/matplotlib/autocorrplot.py index 3ae90310d1..d80b5d38fa 100644 --- a/arviz/plots/backends/matplotlib/autocorrplot.py +++ b/arviz/plots/backends/matplotlib/autocorrplot.py @@ -3,7 +3,7 @@ import numpy as np from ....stats import autocorr -from ...plot_utils import _scale_fig_size, make_label +from ...plot_utils import _scale_fig_size from . 
import backend_kwarg_defaults, backend_show, create_axes_grid @@ -16,6 +16,7 @@ def plot_autocorr( cols, combined, textsize, + labeller, backend_kwargs, show, ): @@ -45,7 +46,7 @@ def plot_autocorr( backend_kwargs=backend_kwargs, ) - for (var_name, selection, x), ax in zip(plotters, np.ravel(axes)): + for (var_name, selection, isel, x), ax in zip(plotters, np.ravel(axes)): x_prime = x if combined: x_prime = x.flatten() @@ -55,7 +56,9 @@ def plot_autocorr( ax.fill_between([0, max_lag], -c_i, c_i, color="0.75") ax.vlines(x=np.arange(0, max_lag), ymin=0, ymax=y[0:max_lag], lw=linewidth) - ax.set_title(make_label(var_name, selection), fontsize=titlesize, wrap=True) + ax.set_title( + labeller.make_label_vert(var_name, selection, isel), fontsize=titlesize, wrap=True + ) ax.tick_params(labelsize=xt_labelsize) if np.asarray(axes).size > 0: diff --git a/arviz/plots/backends/matplotlib/bpvplot.py b/arviz/plots/backends/matplotlib/bpvplot.py index fb9dbbb204..5dea9f7040 100644 --- a/arviz/plots/backends/matplotlib/bpvplot.py +++ b/arviz/plots/backends/matplotlib/bpvplot.py @@ -9,7 +9,6 @@ from ...plot_utils import ( _scale_fig_size, is_valid_quantile, - make_label, sample_reference_distribution, ) from . 
import backend_kwarg_defaults, backend_show, create_axes_grid, matplotlib_kwarg_dealiaser @@ -34,6 +33,7 @@ def plot_bpv( color, figsize, textsize, + labeller, plot_ref_kwargs, backend_kwargs, show, @@ -82,8 +82,8 @@ def plot_bpv( ) for i, ax_i in enumerate(np.ravel(axes)[:length_plotters]): - var_name, selection, obs_vals = obs_plotters[i] - pp_var_name, _, pp_vals = pp_plotters[i] + var_name, selection, isel, obs_vals = obs_plotters[i] + pp_var_name, _, _, pp_vals = pp_plotters[i] obs_vals = obs_vals.flatten() pp_vals = pp_vals.reshape(total_pp_samples, -1) @@ -167,11 +167,9 @@ def plot_bpv( obs_vals.mean(), 0, "o", color=color, markeredgecolor="k", markersize=markersize ) - if var_name != pp_var_name: - xlabel = "{} / {}".format(var_name, pp_var_name) - else: - xlabel = var_name - ax_i.set_title(make_label(xlabel, selection), fontsize=ax_labelsize) + ax_i.set_title( + labeller.make_pp_label(var_name, pp_var_name, selection, isel), fontsize=ax_labelsize + ) if backend_show(show): plt.show() diff --git a/arviz/plots/backends/matplotlib/compareplot.py b/arviz/plots/backends/matplotlib/compareplot.py index bdcd9e414e..7c92d36c21 100644 --- a/arviz/plots/backends/matplotlib/compareplot.py +++ b/arviz/plots/backends/matplotlib/compareplot.py @@ -42,8 +42,6 @@ def plot_compare( _, ax = create_axes_grid(1, backend_kwargs=backend_kwargs) if plot_ic_diff: - yticks_labels[0] = comp_df.index[0] - yticks_labels[2::2] = comp_df.index[1:] ax.set_yticks(yticks_pos) ax.errorbar( x=comp_df[information_criterion].iloc[1:], @@ -56,7 +54,6 @@ def plot_compare( ) else: - yticks_labels = comp_df.index ax.set_yticks(yticks_pos[::2]) if plot_standard_error: diff --git a/arviz/plots/backends/matplotlib/densityplot.py b/arviz/plots/backends/matplotlib/densityplot.py index 19084778de..dd2239ebe5 100644 --- a/arviz/plots/backends/matplotlib/densityplot.py +++ b/arviz/plots/backends/matplotlib/densityplot.py @@ -6,7 +6,7 @@ from ....stats import hdi from ....stats.density_utils import 
get_bins, kde -from ...plot_utils import _scale_fig_size, calculate_point_estimate, make_label +from ...plot_utils import _scale_fig_size, calculate_point_estimate from . import backend_kwarg_defaults, backend_show, create_axes_grid @@ -22,6 +22,7 @@ def plot_density( rows, cols, textsize, + labeller, hdi_prob, point_estimate, hdi_markers, @@ -68,8 +69,8 @@ def plot_density( axis_map = {label: ax_ for label, ax_ in zip(all_labels, np.ravel(ax))} for m_idx, plotters in enumerate(to_plot): - for var_name, selection, values in plotters: - label = make_label(var_name, selection) + for var_name, selection, isel, values in plotters: + label = labeller.make_label_vert(var_name, selection, isel) _d_helper( values.flatten(), label, diff --git a/arviz/plots/backends/matplotlib/distcomparisonplot.py b/arviz/plots/backends/matplotlib/distcomparisonplot.py index 99a35b9535..0684c31bbe 100644 --- a/arviz/plots/backends/matplotlib/distcomparisonplot.py +++ b/arviz/plots/backends/matplotlib/distcomparisonplot.py @@ -3,7 +3,7 @@ import numpy as np from ...distplot import plot_dist -from ...plot_utils import _scale_fig_size, make_label +from ...plot_utils import _scale_fig_size from . 
import backend_kwarg_defaults, backend_show @@ -16,6 +16,7 @@ def plot_dist_comparison( legend, groups, textsize, + labeller, prior_kwargs, posterior_kwargs, observed_kwargs, @@ -93,12 +94,12 @@ def plot_dist_comparison( else observed_kwargs ) for idx2, ( - var, - selection, + var_name, + sel, + isel, data, ) in enumerate(plotter): - label = make_label(var, selection) - label = f"{group} {label}" + label = f"{group}" plot_dist( data, label=label if legend else None, @@ -111,6 +112,8 @@ def plot_dist_comparison( ax=axes[idx2, -1], **kwargs, ) + if idx == 0: + axes[idx2, -1].set_xlabel(labeller.make_label_vert(var_name, sel, isel)) if backend_show(show): plt.show() diff --git a/arviz/plots/backends/matplotlib/essplot.py b/arviz/plots/backends/matplotlib/essplot.py index 5e0078177a..02e49fe2ae 100644 --- a/arviz/plots/backends/matplotlib/essplot.py +++ b/arviz/plots/backends/matplotlib/essplot.py @@ -3,7 +3,7 @@ import numpy as np from scipy.stats import rankdata -from ...plot_utils import _scale_fig_size, make_label +from ...plot_utils import _scale_fig_size from . 
import backend_kwarg_defaults, backend_show, create_axes_grid, matplotlib_kwarg_dealiaser @@ -28,6 +28,7 @@ def plot_ess( n_samples, relative, min_ess, + labeller, ylabel, rug, rug_kind, @@ -95,7 +96,7 @@ def plot_ess( backend_kwargs=backend_kwargs, ) - for (var_name, selection, x), ax_ in zip(plotters, np.ravel(ax)): + for (var_name, selection, isel, x), ax_ in zip(plotters, np.ravel(ax)): ax_.plot(xdata, x, **kwargs) if kind == "evolution": ess_tail = ess_tail_dataset[var_name].sel(**selection) @@ -143,7 +144,9 @@ def plot_ess( ax_.axhline(400 / n_samples if relative else min_ess, **hline_kwargs) - ax_.set_title(make_label(var_name, selection), fontsize=titlesize, wrap=True) + ax_.set_title( + labeller.make_label_vert(var_name, selection, isel), fontsize=titlesize, wrap=True + ) ax_.tick_params(labelsize=xt_labelsize) ax_.set_xlabel( "Total number of draws" if kind == "evolution" else "Quantile", fontsize=ax_labelsize diff --git a/arviz/plots/backends/matplotlib/forestplot.py b/arviz/plots/backends/matplotlib/forestplot.py index 2c5be5743b..ba41ebbc43 100644 --- a/arviz/plots/backends/matplotlib/forestplot.py +++ b/arviz/plots/backends/matplotlib/forestplot.py @@ -5,12 +5,14 @@ import matplotlib.pyplot as plt import numpy as np from matplotlib.colors import to_rgba +from matplotlib.lines import Line2D from ....stats import hdi from ....stats.density_utils import get_bins, histogram, kde from ....stats.diagnostics import _ess, _rhat +from ....sel_utils import xarray_var_iter from ....utils import conditional_jit -from ...plot_utils import _scale_fig_size, make_label, xarray_var_iter +from ...plot_utils import _scale_fig_size from . 
import backend_kwarg_defaults, backend_show @@ -43,6 +45,8 @@ def plot_forest( ridgeplot_truncate, ridgeplot_quantiles, textsize, + legend, + labeller, ess, r_hat, backend_kwargs, @@ -51,7 +55,12 @@ def plot_forest( ): """Matplotlib forest plot.""" plot_handler = PlotHandler( - datasets, var_names=var_names, model_names=model_names, combined=combined, colors=colors + datasets, + var_names=var_names, + model_names=model_names, + combined=combined, + colors=colors, + labeller=labeller, ) if figsize is None: @@ -152,6 +161,8 @@ def plot_forest( if kind == "ridgeplot": # space at the top y_max += ridgeplot_overlap axes[0].set_ylim(-all_plotters[0].group_offset, y_max) + if legend: + plot_handler.legend(ax=axes[0]) if backend_show(show): plt.show() @@ -164,14 +175,14 @@ class PlotHandler: # pylint: disable=inconsistent-return-statements - def __init__(self, datasets, var_names, model_names, combined, colors): + def __init__(self, datasets, var_names, model_names, combined, colors, labeller): self.data = datasets if model_names is None: if len(self.data) > 1: model_names = ["Model {}".format(idx) for idx, _ in enumerate(self.data)] else: - model_names = [""] + model_names = [None] elif len(model_names) != len(self.data): raise ValueError("The number of model names does not match the number of models") @@ -192,11 +203,13 @@ def __init__(self, datasets, var_names, model_names, combined, colors): self.combined = combined if colors == "cycle": + # TODO: Use matplotlib prop cycle instead colors = ["C{}".format(idx) for idx, _ in enumerate(self.data)] elif isinstance(colors, str): colors = [colors for _ in self.data] self.colors = list(reversed(colors)) # y-values are upside down + self.labeller = labeller self.plotters = self.make_plotters() @@ -211,6 +224,7 @@ def make_plotters(self): model_names=self.model_names, combined=self.combined, colors=self.colors, + labeller=self.labeller, ) y = plotters[var_name].y_max() return plotters @@ -230,6 +244,11 @@ def label_idxs(): 
return label_idxs() + def legend(self, ax): + """Add legend with colorcoded model info.""" + handles = [Line2D([], [], color=c) for c in self.colors] + ax.legend(handles=handles, labels=self.model_names) + def display_multiple_ropes(self, rope, ax, y, linewidth, var_name, selection): """Display ROPE when more than one interval is provided.""" for sel in rope.get(var_name, []): @@ -479,16 +498,18 @@ def y_max(self): return max(p.y_max() for p in self.plotters.values()) +# pylint: disable=too-many-instance-attributes class VarHandler: """Handle individual variable logic.""" - def __init__(self, var_name, data, y_start, model_names, combined, colors): + def __init__(self, var_name, data, y_start, model_names, combined, colors, labeller): self.var_name = var_name self.data = data self.y_start = y_start self.model_names = model_names self.combined = combined self.colors = colors + self.labeller = labeller self.model_color = dict(zip(self.model_names, self.colors)) max_chains = max(datum.chain.max().values for datum in data) self.chain_offset = len(data) * 0.45 / max(1, max_chains) @@ -515,15 +536,13 @@ def iterator(self): reverse_selections=True, ) datum_list = list(datum_iter) - for _, selection, values in datum_list: + for _, selection, isel, values in datum_list: selection_list.append(selection) - if not selection: + if not selection or not len(selection_list) % len(datum_list): var_name = self.var_name - elif not len(selection_list) % len(datum_list): - var_name = self.var_name + ":" else: var_name = "" - label = make_label(var_name, selection, position="beside") + label = self.labeller.make_label_flat(var_name, selection, isel) if label not in label_dict: label_dict[label] = OrderedDict() if name not in label_dict[label]: @@ -533,10 +552,7 @@ def iterator(self): y = self.y_start for idx, (label, model_data) in enumerate(label_dict.items()): for model_name, value_list in model_data.items(): - if model_name: - row_label = "{}: {}".format(model_name, label) - else: - 
row_label = label + row_label = self.labeller.make_model_label(model_name, label) for values in value_list: yield y, row_label, label, selection_list[idx], values, self.model_color[ model_name diff --git a/arviz/plots/backends/matplotlib/jointplot.py b/arviz/plots/backends/matplotlib/jointplot.py index e14570808b..de4bf746a3 100644 --- a/arviz/plots/backends/matplotlib/jointplot.py +++ b/arviz/plots/backends/matplotlib/jointplot.py @@ -4,8 +4,9 @@ from ...distplot import plot_dist from ...kdeplot import plot_kde -from ...plot_utils import _scale_fig_size, make_label +from ...plot_utils import _scale_fig_size from . import backend_kwarg_defaults, backend_show, matplotlib_kwarg_dealiaser +from ....sel_utils import make_label def plot_joint( @@ -77,8 +78,8 @@ def plot_joint( axjoin.tick_params(labelsize=xt_labelsize) # Flatten data - x = plotters[0][2].flatten() - y = plotters[1][2].flatten() + x = plotters[0][-1].flatten() + y = plotters[1][-1].flatten() if kind == "scatter": axjoin.scatter(x, y, **joint_kwargs) diff --git a/arviz/plots/backends/matplotlib/loopitplot.py b/arviz/plots/backends/matplotlib/loopitplot.py index af5f563289..b0da43245d 100644 --- a/arviz/plots/backends/matplotlib/loopitplot.py +++ b/arviz/plots/backends/matplotlib/loopitplot.py @@ -29,6 +29,7 @@ def plot_loo_pit( plot_unif_kwargs, loo_pit_kde, legend, + labeller, y_hat, y, color, @@ -58,15 +59,21 @@ def plot_loo_pit( plot_kwargs["color"] = to_hex(color) plot_kwargs.setdefault("linewidth", linewidth * 1.4) if isinstance(y, str): - label = ("{} LOO-PIT ECDF" if ecdf else "{} LOO-PIT").format(y) + label = "LOO-PIT ECDF" if ecdf else "LOO-PIT" + xlabel = y elif isinstance(y, DataArray) and y.name is not None: - label = ("{} LOO-PIT ECDF" if ecdf else "{} LOO-PIT").format(y.name) + label = "LOO-PIT ECDF" if ecdf else "LOO-PIT" + xlabel = y.name elif isinstance(y_hat, str): - label = ("{} LOO-PIT ECDF" if ecdf else "{} LOO-PIT").format(y_hat) + label = "LOO-PIT ECDF" if ecdf else "LOO-PIT" + 
xlabel = y_hat elif isinstance(y_hat, DataArray) and y_hat.name is not None: - label = ("{} LOO-PIT ECDF" if ecdf else "{} LOO-PIT").format(y_hat.name) + label = "LOO-PIT ECDF" if ecdf else "LOO-PIT" + xlabel = y_hat.name else: label = "LOO-PIT ECDF" if ecdf else "LOO-PIT" + xlabel = "" + xlabel = labeller.var_name_to_str(y) plot_kwargs.setdefault("label", label) plot_kwargs.setdefault("zorder", 5) @@ -126,6 +133,7 @@ def plot_loo_pit( ax.plot(x_vals, loo_pit_kde, **plot_kwargs) ax.set_xlim(0, 1) ax.set_ylim(0, None) + ax.set_xlabel(xlabel) ax.tick_params(labelsize=xt_labelsize) if legend: if not (use_hdi or (ecdf and ecdf_fill)): diff --git a/arviz/plots/backends/matplotlib/mcseplot.py b/arviz/plots/backends/matplotlib/mcseplot.py index 39252db5e9..6410270884 100644 --- a/arviz/plots/backends/matplotlib/mcseplot.py +++ b/arviz/plots/backends/matplotlib/mcseplot.py @@ -4,7 +4,7 @@ from scipy.stats import rankdata from ....stats.stats_utils import quantile as _quantile -from ...plot_utils import _scale_fig_size, make_label +from ...plot_utils import _scale_fig_size from . 
import backend_kwarg_defaults, backend_show, create_axes_grid, matplotlib_kwarg_dealiaser @@ -24,6 +24,7 @@ def plot_mcse( mean_mcse, sd_mcse, textsize, + labeller, text_kwargs, rug_kwargs, extra_kwargs, @@ -78,7 +79,7 @@ def plot_mcse( backend_kwargs=backend_kwargs, ) - for (var_name, selection, x), ax_ in zip(plotters, np.ravel(ax)): + for (var_name, selection, isel, x), ax_ in zip(plotters, np.ravel(ax)): if errorbar or rug: values = data[var_name].sel(**selection).values.flatten() if errorbar: @@ -132,7 +133,9 @@ def plot_mcse( ax_.plot(rug_x, rug_y, **rug_kwargs) ax_.axhline(y_min, color="k", linewidth=_linewidth, alpha=0.7) - ax_.set_title(make_label(var_name, selection), fontsize=titlesize, wrap=True) + ax_.set_title( + labeller.make_label_vert(var_name, selection, isel), fontsize=titlesize, wrap=True + ) ax_.tick_params(labelsize=xt_labelsize) ax_.set_xlabel("Quantile", fontsize=ax_labelsize, wrap=True) ax_.set_ylabel( diff --git a/arviz/plots/backends/matplotlib/posteriorplot.py b/arviz/plots/backends/matplotlib/posteriorplot.py index 6fe519cdbc..2b0c286880 100644 --- a/arviz/plots/backends/matplotlib/posteriorplot.py +++ b/arviz/plots/backends/matplotlib/posteriorplot.py @@ -11,7 +11,6 @@ _scale_fig_size, calculate_point_estimate, format_sig_figs, - make_label, round_num, ) from . 
import backend_kwarg_defaults, backend_show, create_axes_grid, matplotlib_kwarg_dealiaser @@ -36,6 +35,7 @@ def plot_posterior( textsize, ref_val, rope, + labeller, kwargs, backend_kwargs, show, @@ -69,7 +69,7 @@ def plot_posterior( backend_kwargs=backend_kwargs, ) idx = 0 - for (var_name, selection, x), ax_ in zip(plotters, np.ravel(ax)): + for (var_name, selection, isel, x), ax_ in zip(plotters, np.ravel(ax)): _plot_posterior_op( idx, x.flatten(), @@ -92,7 +92,9 @@ def plot_posterior( **kwargs, ) idx += 1 - ax_.set_title(make_label(var_name, selection), fontsize=titlesize, wrap=True) + ax_.set_title( + labeller.make_label_vert(var_name, selection, isel), fontsize=titlesize, wrap=True + ) if backend_show(show): plt.show() diff --git a/arviz/plots/backends/matplotlib/ppcplot.py b/arviz/plots/backends/matplotlib/ppcplot.py index 320dc58469..e72a059e68 100644 --- a/arviz/plots/backends/matplotlib/ppcplot.py +++ b/arviz/plots/backends/matplotlib/ppcplot.py @@ -8,7 +8,7 @@ from ....stats.density_utils import get_bins, histogram, kde from ...kdeplot import plot_kde -from ...plot_utils import _scale_fig_size, make_label +from ...plot_utils import _scale_fig_size from . 
import backend_kwarg_defaults, backend_show, create_axes_grid _log = logging.getLogger(__name__) @@ -34,6 +34,7 @@ def plot_ppc( jitter, total_pp_samples, legend, + labeller, group, animation_kwargs, num_pp_samples, @@ -111,8 +112,8 @@ def plot_ppc( raise ValueError("All axes must be on the same figure for animation to work") for i, ax_i in enumerate(np.ravel(axes)[:length_plotters]): - var_name, selection, obs_vals = obs_plotters[i] - pp_var_name, _, pp_vals = pp_plotters[i] + var_name, selection, isel, obs_vals = obs_plotters[i] + pp_var_name, _, _, pp_vals = pp_plotters[i] dtype = predictive_dataset[pp_var_name].dtype.kind # flatten non-specified dimensions @@ -342,11 +343,9 @@ def plot_ppc( ax_i.set_yticks([]) - if var_name != pp_var_name: - xlabel = "{} / {}".format(var_name, pp_var_name) - else: - xlabel = var_name - ax_i.set_xlabel(make_label(xlabel, selection), fontsize=ax_labelsize) + ax_i.set_xlabel( + labeller.make_pp_label(var_name, pp_var_name, selection, isel), fontsize=ax_labelsize + ) if legend: if i == 0: diff --git a/arviz/plots/backends/matplotlib/rankplot.py b/arviz/plots/backends/matplotlib/rankplot.py index 00642d1e14..35e0ac7a79 100644 --- a/arviz/plots/backends/matplotlib/rankplot.py +++ b/arviz/plots/backends/matplotlib/rankplot.py @@ -3,7 +3,7 @@ import numpy as np from ....stats.density_utils import histogram -from ...plot_utils import _scale_fig_size, make_label, compute_ranks +from ...plot_utils import _scale_fig_size, compute_ranks from . 
import backend_kwarg_defaults, backend_show, create_axes_grid @@ -19,6 +19,7 @@ def plot_rank( colors, ref_line, labels, + labeller, ref_line_kwargs, bar_kwargs, vlines_kwargs, @@ -64,7 +65,7 @@ def plot_rank( backend_kwargs=backend_kwargs, ) - for ax, (var_name, selection, var_data) in zip(np.ravel(axes), plotters): + for ax, (var_name, selection, isel, var_data) in zip(np.ravel(axes), plotters): ranks = compute_ranks(var_data) bin_ary = np.histogram_bin_edges(ranks, bins=bins, range=(0, ranks.size)) all_counts = np.empty((len(ranks), len(bin_ary) - 1)) @@ -107,7 +108,7 @@ def plot_rank( ax.set_xlabel("Rank (all chains)", fontsize=ax_labelsize) ax.set_yticks(y_ticks) ax.set_yticklabels(np.arange(len(y_ticks))) - ax.set_title(make_label(var_name, selection), fontsize=titlesize) + ax.set_title(labeller.make_label_vert(var_name, selection, isel), fontsize=titlesize) else: ax.set_yticks([]) diff --git a/arviz/plots/backends/matplotlib/traceplot.py b/arviz/plots/backends/matplotlib/traceplot.py index 3f25209f43..1ce7d510c4 100644 --- a/arviz/plots/backends/matplotlib/traceplot.py +++ b/arviz/plots/backends/matplotlib/traceplot.py @@ -10,7 +10,7 @@ from ....stats.density_utils import get_bins from ...distplot import plot_dist -from ...plot_utils import _scale_fig_size, format_coords_as_labels, make_label +from ...plot_utils import _scale_fig_size, format_coords_as_labels from ...rankplot import plot_rank from . 
import backend_kwarg_defaults, backend_show, dealiase_sel_kwargs, matplotlib_kwarg_dealiaser @@ -30,6 +30,7 @@ def plot_trace( combined, chain_prop, legend, + labeller, plot_kwargs, fill_kwargs, rug_kwargs, @@ -221,7 +222,7 @@ def plot_trace( spec = gridspec.GridSpec(ncols=2, nrows=len(plotters), figure=fig) # pylint: disable=too-many-nested-blocks - for idx, (var_name, selection, value) in enumerate(plotters): + for idx, (var_name, selection, isel, value) in enumerate(plotters): for idy in range(2): value = np.atleast_2d(value) @@ -325,7 +326,10 @@ def plot_trace( if circular: y = 0.13 if selection else 0.12 ax.set_title( - make_label(var_name, selection), fontsize=titlesize, wrap=True, y=textsize * y + labeller.make_label_vert(var_name, selection, isel), + fontsize=titlesize, + wrap=True, + y=textsize * y, ) ax.tick_params(labelsize=xt_labelsize) diff --git a/arviz/plots/backends/matplotlib/violinplot.py b/arviz/plots/backends/matplotlib/violinplot.py index 2ded74caa9..3e8474f1a7 100644 --- a/arviz/plots/backends/matplotlib/violinplot.py +++ b/arviz/plots/backends/matplotlib/violinplot.py @@ -4,7 +4,7 @@ from ....stats import hdi from ....stats.density_utils import get_bins, histogram, kde -from ...plot_utils import _scale_fig_size, make_label +from ...plot_utils import _scale_fig_size from . 
import backend_kwarg_defaults, backend_show, create_axes_grid, matplotlib_kwarg_dealiaser @@ -22,6 +22,7 @@ def plot_violin( rug_kwargs, bw, textsize, + labeller, circular, hdi_prob, quartiles, @@ -64,7 +65,7 @@ def plot_violin( ax = np.atleast_1d(ax) current_col = 0 - for (var_name, selection, x), ax_ in zip(plotters, ax.flatten()): + for (var_name, selection, isel, x), ax_ in zip(plotters, ax.flatten()): val = x.flatten() if val[0].dtype.kind == "i": dens = cat_hist(val, rug, shade, ax_, **shade_kwargs) @@ -83,7 +84,7 @@ def plot_violin( ax_.plot([0, 0], hdi_probs, lw=linewidth, color="k", solid_capstyle="round") ax_.plot(0, per[-1], "wo", ms=linewidth * 1.5) - ax_.set_title(make_label(var_name, selection), fontsize=ax_labelsize) + ax_.set_title(labeller.make_label_vert(var_name, selection, isel), fontsize=ax_labelsize) ax_.set_xticks([]) ax_.tick_params(labelsize=xt_labelsize) ax_.grid(None, axis="x") diff --git a/arviz/plots/bpvplot.py b/arviz/plots/bpvplot.py index 6c8b4afdaa..f358bd8ea1 100644 --- a/arviz/plots/bpvplot.py +++ b/arviz/plots/bpvplot.py @@ -1,9 +1,11 @@ """Bayesian p-value Posterior/Prior predictive plot.""" import numpy as np +from ..labels import BaseLabeller from ..rcparams import rcParams from ..utils import _var_names -from .plot_utils import default_grid, filter_plotters_list, get_plotting_function, xarray_var_iter +from .plot_utils import default_grid, filter_plotters_list, get_plotting_function +from ..sel_utils import xarray_var_iter def plot_bpv( @@ -20,6 +22,7 @@ def plot_bpv( grid=None, figsize=None, textsize=None, + labeller=None, data_pairs=None, var_names=None, filter_vars=None, @@ -93,6 +96,9 @@ def plot_bpv( For example, `data_pairs = {'y' : 'y_hat'}` If None, it will assume that the observed data and the posterior/prior predictive data have the same variable name. + labeller : labeller instance, optional + Class providing the method `make_pp_label` to generate the labels in the plot titles. 
+ Read the :ref:`label_guide` for more details and usage examples. var_names : list of variable names Variables to be plotted, if `None` all variable are plotted. Prefix the variables by `~` when you want to exclude them from the plot. @@ -186,6 +192,9 @@ def plot_bpv( if data_pairs is None: data_pairs = {} + if labeller is None: + labeller = BaseLabeller() + if backend is None: backend = rcParams["plot.backend"] backend = backend.lower() @@ -260,6 +269,7 @@ def plot_bpv( color=color, figsize=figsize, textsize=textsize, + labeller=labeller, plot_ref_kwargs=plot_ref_kwargs, backend_kwargs=backend_kwargs, show=show, diff --git a/arviz/plots/compareplot.py b/arviz/plots/compareplot.py index 65527ff6e5..4893e69c4c 100644 --- a/arviz/plots/compareplot.py +++ b/arviz/plots/compareplot.py @@ -1,6 +1,7 @@ """Summary plot for model comparison.""" import numpy as np +from ..labels import BaseLabeller from ..rcparams import rcParams from .plot_utils import get_plotting_function @@ -13,6 +14,7 @@ def plot_compare( order_by_rank=True, figsize=None, textsize=None, + labeller=None, plot_kwargs=None, ax=None, backend=None, @@ -50,6 +52,9 @@ def plot_compare( textsize: float Text size scaling factor for labels, titles and lines. If None it will be autoscaled based on figsize. + labeller : labeller instance, optional + Class providing the method `model_name_to_str` to generate the labels in the plot. + Read the :ref:`label_guide` for more details and usage examples. plot_kwargs : dict, optional Optional arguments for plot elements. 
Currently accepts 'color_ic', 'marker_ic', 'color_insample_dev', 'marker_insample_dev', 'color_dse', @@ -92,10 +97,19 @@ def plot_compare( if plot_kwargs is None: plot_kwargs = {} + if labeller is None: + labeller = BaseLabeller() + yticks_pos, step = np.linspace(0, -1, (comp_df.shape[0] * 2) - 1, retstep=True) yticks_pos[1::2] = yticks_pos[1::2] + step / 2 + labels = [labeller.model_name_to_str(model_name) for model_name in comp_df.index] - yticks_labels = [""] * len(yticks_pos) + if plot_ic_diff: + yticks_labels = [""] * len(yticks_pos) + yticks_labels[0] = labels[0] + yticks_labels[2::2] = labels[1:] + else: + yticks_labels = labels _information_criterion = ["loo", "waic"] column_index = [c.lower() for c in comp_df.columns] diff --git a/arviz/plots/densityplot.py b/arviz/plots/densityplot.py index d8afe14c75..cd890b86ee 100644 --- a/arviz/plots/densityplot.py +++ b/arviz/plots/densityplot.py @@ -2,9 +2,13 @@ import warnings from ..data import convert_to_dataset +from ..labels import BaseLabeller +from ..sel_utils import ( + xarray_var_iter, +) from ..rcparams import rcParams from ..utils import _var_names -from .plot_utils import default_grid, get_plotting_function, make_label, xarray_var_iter +from .plot_utils import default_grid, get_plotting_function # pylint:disable-msg=too-many-function-args @@ -25,6 +29,7 @@ def plot_density( grid=None, figsize=None, textsize=None, + labeller=None, ax=None, backend=None, backend_kwargs=None, @@ -91,6 +96,9 @@ def plot_density( textsize: Optional[float] Text size scaling factor for labels, titles and lines. If None it will be autoscaled based on figsize. + labeller : labeller instance, optional + Class providing the method `make_label_vert` to generate the labels in the plot titles. + Read the :ref:`label_guide` for more details and usage examples. ax: numpy array-like of matplotlib axes or bokeh figures, optional A 2D array of locations into which to plot the densities. 
If not supplied, Arviz will create its own array of plot areas (and return it). @@ -168,6 +176,9 @@ def plot_density( if transform is not None: datasets = [transform(dataset) for dataset in datasets] + if labeller is None: + labeller = BaseLabeller() + var_names = _var_names(var_names, datasets) n_data = len(datasets) @@ -193,8 +204,8 @@ def plot_density( length_plotters = [] for plotters in to_plot: length_plotters.append(len(plotters)) - for var_name, selection, _ in plotters: - label = make_label(var_name, selection) + for var_name, selection, isel, _ in plotters: + label = labeller.make_label_vert(var_name, selection, isel) if label not in all_labels: all_labels.append(label) length_plotters = len(all_labels) @@ -211,8 +222,8 @@ def plot_density( to_plot = [ [ (var_name, selection, values) - for var_name, selection, values in plotters - if make_label(var_name, selection) in all_labels + for var_name, selection, isel, values in plotters + if labeller.make_label_vert(var_name, selection, isel) in all_labels ] for plotters in to_plot ] @@ -237,6 +248,7 @@ def plot_density( rows=rows, cols=cols, textsize=textsize, + labeller=labeller, hdi_prob=hdi_prob, point_estimate=point_estimate, hdi_markers=hdi_markers, diff --git a/arviz/plots/distcomparisonplot.py b/arviz/plots/distcomparisonplot.py index 2b91257ae5..6c129d75ab 100644 --- a/arviz/plots/distcomparisonplot.py +++ b/arviz/plots/distcomparisonplot.py @@ -1,7 +1,9 @@ """Density Comparison plot.""" +from ..labels import BaseLabeller from ..rcparams import rcParams from ..utils import _var_names, get_coords -from .plot_utils import get_plotting_function, xarray_var_iter +from .plot_utils import get_plotting_function +from ..sel_utils import xarray_var_iter def plot_dist_comparison( @@ -13,6 +15,7 @@ def plot_dist_comparison( coords=None, transform=None, legend=True, + labeller=None, ax=None, prior_kwargs=None, posterior_kwargs=None, @@ -51,6 +54,9 @@ def plot_dist_comparison( Function to transform data (defaults to 
None i.e. the identity function) legend : bool Add legend to figure. By default True. + labeller : labeller instance, optional + Class providing the method `make_pp_label` to generate the labels in the plot. + Read the :ref:`label_guide` for more details and usage examples. ax: axes, optional Matplotlib axes: The ax argument should have shape (nvars, 3), where the last column is for the combined before/after plots and columns 0 and 1 are @@ -94,6 +100,9 @@ def plot_dist_comparison( if coords is None: coords = {} + if labeller is None: + labeller = BaseLabeller() + datasets = [] groups = [] for group in all_groups: @@ -137,6 +146,7 @@ def plot_dist_comparison( legend=legend, groups=groups, textsize=textsize, + labeller=labeller, prior_kwargs=prior_kwargs, posterior_kwargs=posterior_kwargs, observed_kwargs=observed_kwargs, diff --git a/arviz/plots/essplot.py b/arviz/plots/essplot.py index 891870b100..23713b3d19 100644 --- a/arviz/plots/essplot.py +++ b/arviz/plots/essplot.py @@ -3,10 +3,12 @@ import xarray as xr from ..data import convert_to_dataset +from ..labels import BaseLabeller from ..rcparams import rcParams +from ..sel_utils import xarray_var_iter from ..stats import ess from ..utils import _var_names, get_coords -from .plot_utils import default_grid, filter_plotters_list, get_plotting_function, xarray_var_iter +from .plot_utils import default_grid, filter_plotters_list, get_plotting_function def plot_ess( @@ -24,6 +26,7 @@ def plot_ess( n_points=20, extra_methods=False, min_ess=400, + labeller=None, ax=None, extra_kwargs=None, text_kwargs=None, @@ -74,6 +77,9 @@ def plot_ess( Plot mean and sd ESS as horizontal lines. Not taken into account in evolution kind min_ess: int Minimum number of ESS desired. + labeller : labeller instance, optional + Class providing the method `make_label_vert` to generate the labels in the plot titles. + Read the :ref:`label_guide` for more details and usage examples. 
ax: numpy array-like of matplotlib axes or bokeh figures, optional A 2D array of locations into which to plot the densities. If not supplied, Arviz will create its own array of plot areas (and return it). @@ -173,6 +179,8 @@ def plot_ess( coords = {} if "chain" in coords or "draw" in coords: raise ValueError("chain and draw are invalid coordinates for this kind of plot") + if labeller is None: + labeller = BaseLabeller() extra_methods = False if kind == "evolution" else extra_methods data = get_coords(convert_to_dataset(idata, group="posterior"), coords) @@ -273,6 +281,7 @@ def plot_ess( n_samples=n_samples, relative=relative, min_ess=min_ess, + labeller=labeller, ylabel=ylabel, rug=rug, rug_kind=rug_kind, diff --git a/arviz/plots/forestplot.py b/arviz/plots/forestplot.py index 9b440d6b92..7babc0cebc 100644 --- a/arviz/plots/forestplot.py +++ b/arviz/plots/forestplot.py @@ -1,5 +1,6 @@ """Forest plot.""" from ..data import convert_to_dataset +from ..labels import BaseLabeller, NoModelLabeller from ..rcparams import rcParams from ..utils import _var_names, get_coords from .plot_utils import get_plotting_function @@ -23,6 +24,8 @@ def plot_forest( textsize=None, linewidth=None, markersize=None, + legend=True, + labeller=None, ridgeplot_alpha=None, ridgeplot_overlap=2, ridgeplot_kind="auto", @@ -88,6 +91,12 @@ def plot_forest( Line width throughout. If None it will be autoscaled based on figsize. markersize: int Markersize throughout. If None it will be autoscaled based on figsize. + legend : bool, optional + Show a legend with the color encoded model information. + Defaults to true if there are multiple models + labeller : labeller instance, optional + Class providing the method `make_model_label` to generate the labels in the plot. + Read the :ref:`label_guide` for more details and usage examples. ridgeplot_alpha: float Transparency for ridgeplot fill. If 0, border is colored by model, otherwise a black outline is used. 
@@ -182,10 +191,15 @@ def plot_forest( """ if not isinstance(data, (list, tuple)): data = [data] + if len(data) == 1: + legend = False if coords is None: coords = {} + if labeller is None: + labeller = NoModelLabeller() if legend else BaseLabeller() + datasets = [convert_to_dataset(datum) for datum in reversed(data)] if transform is not None: datasets = [transform(dataset) for dataset in datasets] @@ -233,6 +247,8 @@ def plot_forest( ridgeplot_truncate=ridgeplot_truncate, ridgeplot_quantiles=ridgeplot_quantiles, textsize=textsize, + legend=legend, + labeller=labeller, ess=ess, r_hat=r_hat, backend_kwargs=backend_kwargs, diff --git a/arviz/plots/jointplot.py b/arviz/plots/jointplot.py index a5a53dc75b..31503a1fac 100644 --- a/arviz/plots/jointplot.py +++ b/arviz/plots/jointplot.py @@ -2,9 +2,10 @@ import warnings from ..data import convert_to_dataset +from ..sel_utils import xarray_var_iter from ..rcparams import rcParams from ..utils import _var_names, get_coords -from .plot_utils import get_plotting_function, xarray_var_iter +from .plot_utils import get_plotting_function def plot_joint( diff --git a/arviz/plots/khatplot.py b/arviz/plots/khatplot.py index eb473ed621..50e02c2fd0 100644 --- a/arviz/plots/khatplot.py +++ b/arviz/plots/khatplot.py @@ -143,7 +143,7 @@ def plot_khat( References ---------- - ..[1] Vehtari, A., Simpson, D., Gelman, A., Yao, Y., Gabry, J., + .. [1] Vehtari, A., Simpson, D., Gelman, A., Yao, Y., Gabry, J., 2019. Pareto Smoothed Importance Sampling. arXiv:1507.02646 [stat]. 
""" diff --git a/arviz/plots/loopitplot.py b/arviz/plots/loopitplot.py index 14bc55a52a..89f7a9167c 100644 --- a/arviz/plots/loopitplot.py +++ b/arviz/plots/loopitplot.py @@ -2,6 +2,7 @@ import numpy as np import scipy.stats as stats +from ..labels import BaseLabeller from ..rcparams import rcParams from ..stats import loo_pit as _loo_pit from ..stats.density_utils import kde @@ -20,6 +21,7 @@ def plot_loo_pit( hdi_prob=None, figsize=None, textsize=None, + labeller=None, color="C0", legend=True, ax=None, @@ -66,6 +68,9 @@ def plot_loo_pit( If None, size is (8 + numvars, 8 + numvars) textsize: int, optional Text size for labels. If None it will be autoscaled based on figsize. + labeller : labeller instance, optional + Class providing the method `make_pp_label` to generate the labels in the plot titles. + Read the :ref:`label_guide` for more details and usage examples. color : str or array_like, optional Color of the LOO-PIT estimated pdf plot. If ``plot_unif_kwargs`` has no "color" key, an slightly lighter color than this argument will be used for the uniform kde lines. 
@@ -127,6 +132,9 @@ def plot_loo_pit( if ecdf and use_hdi: raise ValueError("use_hdi is incompatible with ecdf plot") + if labeller is None: + labeller = BaseLabeller() + loo_pit = _loo_pit(idata=idata, y=y, y_hat=y_hat, log_weights=log_weights) loo_pit = loo_pit.flatten() if isinstance(loo_pit, np.ndarray) else loo_pit.values.flatten() @@ -184,6 +192,7 @@ def plot_loo_pit( plot_unif_kwargs=plot_unif_kwargs, loo_pit_kde=loo_pit_kde, textsize=textsize, + labeller=labeller, color=color, legend=legend, y_hat=y_hat, diff --git a/arviz/plots/mcseplot.py b/arviz/plots/mcseplot.py index 1545314120..e9b381179c 100644 --- a/arviz/plots/mcseplot.py +++ b/arviz/plots/mcseplot.py @@ -3,10 +3,12 @@ import xarray as xr from ..data import convert_to_dataset -from ..rcparams import rcParams +from ..labels import BaseLabeller +from ..sel_utils import xarray_var_iter from ..stats import mcse +from ..rcparams import rcParams from ..utils import _var_names, get_coords -from .plot_utils import default_grid, filter_plotters_list, get_plotting_function, xarray_var_iter +from .plot_utils import default_grid, filter_plotters_list, get_plotting_function def plot_mcse( @@ -22,6 +24,7 @@ def plot_mcse( rug=False, rug_kind="diverging", n_points=20, + labeller=None, ax=None, rug_kwargs=None, extra_kwargs=None, @@ -68,6 +71,9 @@ def plot_mcse( n_points: int Number of points for which to plot their quantile/local ess or number of subsets in the evolution plot. + labeller : labeller instance, optional + Class providing the method `make_label_vert` to generate the labels in the plot titles. + Read the :ref:`label_guide` for more details and usage examples. ax: numpy array-like of matplotlib axes or bokeh figures, optional A 2D array of locations into which to plot the densities. If not supplied, Arviz will create its own array of plot areas (and return it). 
@@ -119,6 +125,9 @@ def plot_mcse( if "chain" in coords or "draw" in coords: raise ValueError("chain and draw are invalid coordinates for this kind of plot") + if labeller is None: + labeller = BaseLabeller() + data = get_coords(convert_to_dataset(idata, group="posterior"), coords) var_names = _var_names(var_names, data, filter_vars) @@ -154,6 +163,7 @@ def plot_mcse( mean_mcse=mean_mcse, sd_mcse=sd_mcse, textsize=textsize, + labeller=labeller, text_kwargs=text_kwargs, rug_kwargs=rug_kwargs, extra_kwargs=extra_kwargs, diff --git a/arviz/plots/pairplot.py b/arviz/plots/pairplot.py index 94c50c7d37..d6cd497a67 100644 --- a/arviz/plots/pairplot.py +++ b/arviz/plots/pairplot.py @@ -5,14 +5,11 @@ import numpy as np from ..data import convert_to_dataset +from ..labels import BaseLabeller +from ..sel_utils import xarray_to_ndarray, xarray_var_iter +from .plot_utils import get_plotting_function from ..rcparams import rcParams from ..utils import _var_names, get_coords -from .plot_utils import ( - get_plotting_function, - xarray_to_ndarray, - xarray_var_iter, - make_label, -) def plot_pair( @@ -31,6 +28,7 @@ def plot_pair( fill_last=False, divergences=False, colorbar=False, + labeller=None, ax=None, divergences_kwargs=None, scatter_kwargs=None, @@ -91,6 +89,9 @@ def plot_pair( colorbar: bool If True a colorbar will be included as part of the plot (Defaults to False). Only works when kind=hexbin + labeller : labeller instance, optional + Class providing the method `make_label_vert` to generate the labels in the plot. + Read the :ref:`label_guide` for more details and usage examples. ax: axes, optional Matplotlib axes or bokeh figures. 
divergences_kwargs: dicts, optional @@ -191,13 +192,18 @@ def plot_pair( if coords is None: coords = {} + if labeller is None: + labeller = BaseLabeller() + # Get posterior draws and combine chains dataset = convert_to_dataset(data, group=group) var_names = _var_names(var_names, dataset, filter_vars) plotters = list( xarray_var_iter(get_coords(dataset, coords), var_names=var_names, combined=True) ) - flat_var_names = [make_label(var_name, selection) for var_name, selection, _ in plotters] + flat_var_names = [ + labeller.make_label_vert(var_name, sel, isel) for var_name, sel, isel, _ in plotters + ] divergent_data = None diverging_mask = None diff --git a/arviz/plots/parallelplot.py b/arviz/plots/parallelplot.py index dd610435ef..3016cc1ee8 100644 --- a/arviz/plots/parallelplot.py +++ b/arviz/plots/parallelplot.py @@ -3,10 +3,12 @@ from scipy.stats import rankdata from ..data import convert_to_dataset +from ..labels import BaseLabeller +from ..sel_utils import xarray_to_ndarray from ..rcparams import rcParams from ..stats.stats_utils import stats_variance_2d as svar from ..utils import _numba_var, _var_names, get_coords -from .plot_utils import get_plotting_function, xarray_to_ndarray +from .plot_utils import get_plotting_function def plot_parallel( @@ -20,6 +22,7 @@ def plot_parallel( colornd="k", colord="C1", shadend=0.025, + labeller=None, ax=None, norm_method=None, backend=None, @@ -62,6 +65,9 @@ def plot_parallel( shadend: float Alpha blending value for non-divergent points, between 0 (invisible) and 1 (opaque). Defaults to .025 + labeller : labeller instance, optional + Class providing the method `make_label_vert` to generate the labels in the plot. + Read the :ref:`label_guide` for more details and usage examples. ax: axes, optional Matplotlib axes or bokeh figures. 
norm_method: str @@ -104,6 +110,9 @@ def plot_parallel( if coords is None: coords = {} + if labeller is None: + labeller = BaseLabeller() + # Get diverging draws and combine chains divergent_data = convert_to_dataset(data, group="sample_stats") _, diverging_mask = xarray_to_ndarray(divergent_data, var_names=("diverging",), combined=True) @@ -113,7 +122,10 @@ def plot_parallel( posterior_data = convert_to_dataset(data, group="posterior") var_names = _var_names(var_names, posterior_data, filter_vars) var_names, _posterior = xarray_to_ndarray( - get_coords(posterior_data, coords), var_names=var_names, combined=True + get_coords(posterior_data, coords), + var_names=var_names, + combined=True, + label_fun=labeller.make_label_vert, ) if len(var_names) < 2: raise ValueError("Number of variables to be plotted must be 2 or greater.") diff --git a/arviz/plots/plot_utils.py b/arviz/plots/plot_utils.py index 553457a405..16ba7b7cd5 100644 --- a/arviz/plots/plot_utils.py +++ b/arviz/plots/plot_utils.py @@ -1,13 +1,11 @@ """Utilities for plotting.""" import importlib import warnings -from itertools import product, tee from typing import Any, Dict import matplotlib as mpl import numpy as np import packaging -import xarray as xr from matplotlib.colors import to_hex from scipy.stats import mode, rankdata from scipy.interpolate import CubicSpline @@ -143,51 +141,6 @@ def in_bounds(val): return rows, cols -def selection_to_string(selection): - """Convert dictionary of coordinates to a string for labels. - - Parameters - ---------- - selection : dict[Any] -> Any - - Returns - ------- - str - key1: value1, key2: value2, ... - """ - return ", ".join(["{}".format(v) for _, v in selection.items()]) - - -def make_label(var_name, selection, position="below"): - """Consistent labelling for plots. 
- - Parameters - ---------- - var_name : str - Name of the variable - - selection : dict[Any] -> Any - Coordinates of the variable - position : str - Whether to position the coordinates' label "below" (default) or "beside" - the name of the variable - - Returns - ------- - label - A text representation of the label - """ - if selection: - sel = selection_to_string(selection) - if position == "below": - sep = "\n" - elif position == "beside": - sep = " " - else: - sep = sel = "" - return "{}{}{}".format(var_name, sep, sel) - - def format_sig_figs(value, default=None): """Get a default number of significant figures. @@ -222,179 +175,6 @@ def round_num(n, round_to): return "{n:.{sig_figs}g}".format(n=n, sig_figs=sig_figs) -def purge_duplicates(list_in): - """Remove duplicates from list while preserving order. - - Parameters - ---------- - list_in: Iterable - - Returns - ------- - list - List of first occurrences in order - """ - # Algorithm taken from Stack Overflow, - # https://stackoverflow.com/questions/480214. Content by Georgy - # Skorobogatov (https://stackoverflow.com/users/7851470/georgy) and - # Markus Jarderot - # (https://stackoverflow.com/users/22364/markus-jarderot), licensed - # under CC-BY-SA 4.0. - # https://creativecommons.org/licenses/by-sa/4.0/. - - seen = set() - seen_add = seen.add - return [x for x in list_in if not (x in seen or seen_add(x))] - - -def _dims(data, var_name, skip_dims): - return [dim for dim in data[var_name].dims if dim not in skip_dims] - - -def _zip_dims(new_dims, vals): - return [{k: v for k, v in zip(new_dims, prod)} for prod in product(*vals)] - - -def xarray_sel_iter(data, var_names=None, combined=False, skip_dims=None, reverse_selections=False): - """Convert xarray data to an iterator over variable names and selections. - - Iterates over each var_name and all of its coordinates, returning the variable - names and selections that allow properly obtain the data from ``data`` as desired. 
- - Parameters - ---------- - data : xarray.Dataset - Posterior data in an xarray - - var_names : iterator of strings (optional) - Should be a subset of data.data_vars. Defaults to all of them. - - combined : bool - Whether to combine chains or leave them separate - - skip_dims : set - dimensions to not iterate over - - reverse_selections : bool - Whether to reverse selections before iterating. - - Returns - ------- - Iterator of (var_name: str, selection: dict(str, any)) - The string is the variable name, the dictionary are coordinate names to values,. - To get the values of the variable at these coordinates, do - ``data[var_name].sel(**selection)``. - """ - if skip_dims is None: - skip_dims = set() - - if combined: - skip_dims = skip_dims.union({"chain", "draw"}) - else: - skip_dims.add("draw") - - if var_names is None: - if isinstance(data, xr.Dataset): - var_names = list(data.data_vars) - elif isinstance(data, xr.DataArray): - var_names = [data.name] - data = {data.name: data} - - for var_name in var_names: - if var_name in data: - new_dims = _dims(data, var_name, skip_dims) - vals = [purge_duplicates(data[var_name][dim].values) for dim in new_dims] - dims = _zip_dims(new_dims, vals) - if reverse_selections: - dims = reversed(dims) - - for selection in dims: - yield var_name, selection - - -def xarray_var_iter(data, var_names=None, combined=False, skip_dims=None, reverse_selections=False): - """Convert xarray data to an iterator over vectors. - - Iterates over each var_name and all of its coordinates, returning the 1d - data. - - Parameters - ---------- - data : xarray.Dataset - Posterior data in an xarray - - var_names : iterator of strings (optional) - Should be a subset of data.data_vars. Defaults to all of them. - - combined : bool - Whether to combine chains or leave them separate - - skip_dims : set - dimensions to not iterate over - - reverse_selections : bool - Whether to reverse selections before iterating. 
- - Returns - ------- - Iterator of (str, dict(str, any), np.array) - The string is the variable name, the dictionary are coordinate names to values, - and the array are the values of the variable at those coordinates. - """ - data_to_sel = data - if var_names is None and isinstance(data, xr.DataArray): - data_to_sel = {data.name: data} - - for var_name, selection in xarray_sel_iter( - data, - var_names=var_names, - combined=combined, - skip_dims=skip_dims, - reverse_selections=reverse_selections, - ): - yield var_name, selection, data_to_sel[var_name].sel(**selection).values - - -def xarray_to_ndarray(data, *, var_names=None, combined=True): - """Take xarray data and unpacks into variables and data into list and numpy array respectively. - - Assumes that chain and draw are in coordinates - - Parameters - ---------- - data: xarray.DataSet - Data in an xarray from an InferenceData object. Examples include posterior or sample_stats - - var_names: iter - Should be a subset of data.data_vars not including chain and draws. 
Defaults to all of them - - combined: bool - Whether to combine chain into one array - - Returns - ------- - var_names: list - List of variable names - data: np.array - Data values - """ - data_to_sel = data - if var_names is None and isinstance(data, xr.DataArray): - data_to_sel = {data.name: data} - - iterator1, iterator2 = tee(xarray_sel_iter(data, var_names=var_names, combined=combined)) - vars_and_sel = list(iterator1) - unpacked_var_names = [make_label(var_name, selection) for var_name, selection in vars_and_sel] - - # Merge chains and variables, check dtype to be compatible with divergences data - data0 = data_to_sel[vars_and_sel[0][0]].sel(**vars_and_sel[0][1]) - unpacked_data = np.empty((len(unpacked_var_names), data0.size), dtype=data0.dtype) - for idx, (var_name, selection) in enumerate(iterator2): - unpacked_data[idx] = data_to_sel[var_name].sel(**selection).values.flatten() - - return unpacked_var_names, unpacked_data - - def color_from_dim(dataarray, dim_name): """Return colors and color mapping of a DataArray using coord values as color code. diff --git a/arviz/plots/posteriorplot.py b/arviz/plots/posteriorplot.py index ec6585dbf5..47f2f4be97 100644 --- a/arviz/plots/posteriorplot.py +++ b/arviz/plots/posteriorplot.py @@ -1,8 +1,10 @@ """Plot posterior densities.""" from ..data import convert_to_dataset -from ..rcparams import rcParams +from ..labels import BaseLabeller +from ..sel_utils import xarray_var_iter from ..utils import _var_names, get_coords -from .plot_utils import default_grid, filter_plotters_list, get_plotting_function, xarray_var_iter +from ..rcparams import rcParams +from .plot_utils import default_grid, filter_plotters_list, get_plotting_function def plot_posterior( @@ -26,6 +28,7 @@ def plot_posterior( bw="default", circular=False, bins=None, + labeller=None, ax=None, backend=None, backend_kwargs=None, @@ -99,6 +102,9 @@ def plot_posterior( Controls the number of bins, accepts the same keywords `matplotlib.hist()` does. 
Only works if `kind == hist`. If None (default) it will use `auto` for continuous variables and `range(xmin, xmax + 1)` for discrete variables. + labeller : labeller instance, optional + Class providing the method `make_label_vert` to generate the labels in the plot titles. + Read the :ref:`label_guide` for more details and usage examples. ax: numpy array-like of matplotlib axes or bokeh figures, optional A 2D array of locations into which to plot the densities. If not supplied, Arviz will create its own array of plot areas (and return it). @@ -209,6 +215,9 @@ def plot_posterior( if coords is None: coords = {} + if labeller is None: + labeller = BaseLabeller() + if hdi_prob is None: hdi_prob = rcParams["stats.hdi_prob"] elif hdi_prob not in (None, "hide"): @@ -246,6 +255,7 @@ def plot_posterior( textsize=textsize, ref_val=ref_val, rope=rope, + labeller=labeller, kwargs=kwargs, backend_kwargs=backend_kwargs, show=show, diff --git a/arviz/plots/ppcplot.py b/arviz/plots/ppcplot.py index 5daf7ba9c0..a0378e8ca2 100644 --- a/arviz/plots/ppcplot.py +++ b/arviz/plots/ppcplot.py @@ -4,9 +4,11 @@ import numpy as np +from ..labels import BaseLabeller +from ..sel_utils import xarray_var_iter from ..rcparams import rcParams from ..utils import _var_names -from .plot_utils import default_grid, filter_plotters_list, get_plotting_function, xarray_var_iter +from .plot_utils import default_grid, filter_plotters_list, get_plotting_function _log = logging.getLogger(__name__) @@ -33,6 +35,7 @@ def plot_ppc( animated=False, animation_kwargs=None, legend=True, + labeller=None, ax=None, backend=None, backend_kwargs=None, @@ -123,6 +126,9 @@ def plot_ppc( Keywords passed to `animation.FuncAnimation`. Ignored with matploblib backend. legend : bool Add legend to figure. By default True. + labeller : labeller instance, optional + Class providing the method `make_pp_label` to generate the labels in the plot titles. + Read the :ref:`label_guide` for more details and usage examples. 
ax: numpy array-like of matplotlib axes or bokeh figures, optional A 2D array of locations into which to plot the densities. If not supplied, Arviz will create its own array of plot areas (and return it). @@ -235,6 +241,9 @@ def plot_ppc( if coords is None: coords = {} + if labeller is None: + labeller = BaseLabeller() + if random_seed is not None: np.random.seed(random_seed) @@ -306,6 +315,7 @@ def plot_ppc( observed=observed, total_pp_samples=total_pp_samples, legend=legend, + labeller=labeller, group=group, animation_kwargs=animation_kwargs, num_pp_samples=num_pp_samples, diff --git a/arviz/plots/rankplot.py b/arviz/plots/rankplot.py index 86695bf3f4..7d4b2707a1 100644 --- a/arviz/plots/rankplot.py +++ b/arviz/plots/rankplot.py @@ -4,10 +4,12 @@ import matplotlib.pyplot as plt from ..data import convert_to_dataset +from ..labels import BaseLabeller +from ..sel_utils import xarray_var_iter from ..rcparams import rcParams from ..stats.density_utils import _sturges_formula from ..utils import _var_names -from .plot_utils import default_grid, filter_plotters_list, get_plotting_function, xarray_var_iter +from .plot_utils import default_grid, filter_plotters_list, get_plotting_function def plot_rank( @@ -21,6 +23,7 @@ def plot_rank( colors="cycle", ref_line=True, labels=True, + labeller=None, grid=None, figsize=None, ax=None, @@ -78,6 +81,9 @@ def plot_rank( Whether to include a dashed line showing where a uniform distribution would lie labels: bool whether to plot or not the x and y labels, defaults to True + labeller : labeller instance, optional + Class providing the method `make_label_vert` to generate the labels in the plot titles. + Read the :ref:`label_guide` for more details and usage examples. grid : tuple Number of rows and columns. Defaults to None, the rows and columns are automatically inferred. 
@@ -163,6 +169,9 @@ def plot_rank( if bins is None: bins = _sturges_formula(posterior_data, mult=2) + if labeller is None: + labeller = BaseLabeller() + rows, cols = default_grid(length_plotters, grid=grid) chains = len(posterior_data.chain) @@ -188,6 +197,7 @@ def plot_rank( colors=colors, ref_line=ref_line, labels=labels, + labeller=labeller, ref_line_kwargs=ref_line_kwargs, bar_kwargs=bar_kwargs, vlines_kwargs=vlines_kwargs, diff --git a/arviz/plots/separationplot.py b/arviz/plots/separationplot.py index c3911f952b..f7aab8fce2 100644 --- a/arviz/plots/separationplot.py +++ b/arviz/plots/separationplot.py @@ -77,10 +77,9 @@ def plot_separation( References ---------- - * Greenhill, B. *et al.*, The Separation Plot: A New Visual Method - for Evaluating the Fit of Binary Models, *American Journal of - Political Science, (2011) see - https://doi.org/10.1111/j.1540-5907.2011.00525.x + .. [1] Greenhill, B. *et al.*, The Separation Plot: A New Visual Method + for Evaluating the Fit of Binary Models, *American Journal of + Political Science*, (2011) see https://doi.org/10.1111/j.1540-5907.2011.00525.x Examples -------- diff --git a/arviz/plots/traceplot.py b/arviz/plots/traceplot.py index 79d6dd3f48..2c82e81247 100644 --- a/arviz/plots/traceplot.py +++ b/arviz/plots/traceplot.py @@ -3,9 +3,11 @@ from typing import Any, Callable, List, Mapping, Optional, Tuple, Union from ..data import CoordSpec, InferenceData, convert_to_dataset +from ..labels import BaseLabeller from ..rcparams import rcParams +from ..sel_utils import xarray_var_iter from ..utils import _var_names, get_coords -from .plot_utils import KwargSpec, get_plotting_function, xarray_var_iter +from .plot_utils import KwargSpec, get_plotting_function def plot_trace( @@ -32,6 +34,7 @@ def plot_trace( hist_kwargs: Optional[KwargSpec] = None, trace_kwargs: Optional[KwargSpec] = None, rank_kwargs: Optional[KwargSpec] = None, + labeller=None, axes=None, backend: Optional[str] = None, backend_config: Optional[KwargSpec] 
= None, @@ -92,6 +95,9 @@ def plot_trace( Extra keyword arguments passed to `arviz.plot_dist`. Only affects continuous variables. trace_kwargs: dict, optional Extra keyword arguments passed to `plt.plot` + labeller : labeller instance, optional + Class providing the method `make_label_vert` to generate the labels in the plot titles. + Read the :ref:`label_guide` for more details and usage examples. rank_kwargs : dict, optional Extra keyword arguments passed to `arviz.plot_rank` axes: axes, optional @@ -169,6 +175,9 @@ def plot_trace( if coords is None: coords = {} + if labeller is None: + labeller = BaseLabeller() + if divergences: divergence_data = get_coords( divergence_data, {k: v for k, v in coords.items() if k in ("chain", "draw")} @@ -226,6 +235,7 @@ def plot_trace( combined=combined, chain_prop=chain_prop, legend=legend, + labeller=labeller, # Generated kwargs divergence_data=divergence_data, # skip_dims=skip_dims, diff --git a/arviz/plots/violinplot.py b/arviz/plots/violinplot.py index c44144bc1b..c6d03efaa4 100644 --- a/arviz/plots/violinplot.py +++ b/arviz/plots/violinplot.py @@ -1,8 +1,10 @@ """Plot posterior traces as violin plot.""" from ..data import convert_to_dataset -from ..rcparams import rcParams +from ..labels import BaseLabeller +from ..sel_utils import xarray_var_iter from ..utils import _var_names -from .plot_utils import default_grid, filter_plotters_list, get_plotting_function, xarray_var_iter +from ..rcparams import rcParams +from .plot_utils import default_grid, filter_plotters_list, get_plotting_function def plot_violin( @@ -21,6 +23,7 @@ def plot_violin( grid=None, figsize=None, textsize=None, + labeller=None, ax=None, shade_kwargs=None, rug_kwargs=None, @@ -77,6 +80,9 @@ def plot_violin( textsize: int Text size of the point_estimates, axis ticks, and highest density interval. If None it will be autoscaled based on figsize. 
+ labeller : labeller instance, optional + Class providing the method `make_label_vert` to generate the labels in the plot titles. + Read the :ref:`label_guide` for more details and usage examples. sharex: bool Defaults to True, violinplots share a common x-axis scale. sharey: bool @@ -120,6 +126,9 @@ def plot_violin( >>> az.plot_violin(data, var_names="tau", transform=np.log) """ + if labeller is None: + labeller = BaseLabeller() + data = convert_to_dataset(data, group="posterior") if transform is not None: data = transform(data) @@ -151,6 +160,7 @@ def plot_violin( rug_kwargs=rug_kwargs, bw=bw, textsize=textsize, + labeller=labeller, circular=circular, hdi_prob=hdi_prob, quartiles=quartiles, diff --git a/arviz/sel_utils.py b/arviz/sel_utils.py new file mode 100644 index 0000000000..36216e237f --- /dev/null +++ b/arviz/sel_utils.py @@ -0,0 +1,234 @@ +"""Utilities for selecting and iterating on xarray objects.""" +from itertools import product, tee + +import numpy as np +import xarray as xr + +from .labels import BaseLabeller + +__all__ = ["xarray_sel_iter", "xarray_var_iter", "xarray_to_ndarray"] + + +def selection_to_string(selection): + """Convert dictionary of coordinates to a string for labels. + + Parameters + ---------- + selection : dict[Any] -> Any + + Returns + ------- + str + key1: value1, key2: value2, ... + """ + return ", ".join(["{}".format(v) for _, v in selection.items()]) + + +def make_label(var_name, selection, position="below"): + """Consistent labelling for plots. 
+ + Parameters + ---------- + var_name : str + Name of the variable + + selection : dict[Any] -> Any + Coordinates of the variable + position : str + Whether to position the coordinates' label "below" (default) or "beside" + the name of the variable + + Returns + ------- + label + A text representation of the label + """ + if selection: + sel = selection_to_string(selection) + if position == "below": + base = "{}\n{}" + elif position == "beside": + base = "{}[{}]" + else: + sel = "" + base = "{}{}" + return base.format(var_name, sel) + + +def purge_duplicates(list_in): + """Remove duplicates from list while preserving order. + + Parameters + ---------- + list_in: Iterable + + Returns + ------- + list + List of first occurrences in order + """ + # Algorithm taken from Stack Overflow, + # https://stackoverflow.com/questions/480214. Content by Georgy + # Skorobogatov (https://stackoverflow.com/users/7851470/georgy) and + # Markus Jarderot + # (https://stackoverflow.com/users/22364/markus-jarderot), licensed + # under CC-BY-SA 4.0. + # https://creativecommons.org/licenses/by-sa/4.0/. + + seen = set() + seen_add = seen.add + return [x for x in list_in if not (x in seen or seen_add(x))] + + +def _dims(data, var_name, skip_dims): + return [dim for dim in data[var_name].dims if dim not in skip_dims] + + +def _zip_dims(new_dims, vals): + return [{k: v for k, v in zip(new_dims, prod)} for prod in product(*vals)] + + +def xarray_sel_iter(data, var_names=None, combined=False, skip_dims=None, reverse_selections=False): + """Convert xarray data to an iterator over variable names and selections. + + Iterates over each var_name and all of its coordinates, returning the variable + names and selections that allow properly obtain the data from ``data`` as desired. + + Parameters + ---------- + data : xarray.Dataset + Posterior data in an xarray + + var_names : iterator of strings (optional) + Should be a subset of data.data_vars. Defaults to all of them. 
+ + combined : bool + Whether to combine chains or leave them separate + + skip_dims : set + dimensions to not iterate over + + reverse_selections : bool + Whether to reverse selections before iterating. + + Returns + ------- + Iterator of (var_name: str, selection: dict(str, any)) + The string is the variable name, the dictionary are coordinate names to values,. + To get the values of the variable at these coordinates, do + ``data[var_name].sel(**selection)``. + """ + if skip_dims is None: + skip_dims = set() + + if combined: + skip_dims = skip_dims.union({"chain", "draw"}) + else: + skip_dims.add("draw") + + if var_names is None: + if isinstance(data, xr.Dataset): + var_names = list(data.data_vars) + elif isinstance(data, xr.DataArray): + var_names = [data.name] + data = {data.name: data} + + for var_name in var_names: + if var_name in data: + new_dims = _dims(data, var_name, skip_dims) + vals = [purge_duplicates(data[var_name][dim].values) for dim in new_dims] + dims = _zip_dims(new_dims, vals) + idims = _zip_dims(new_dims, [range(len(v)) for v in vals]) + if reverse_selections: + dims = reversed(dims) + idims = reversed(idims) + + for selection, iselection in zip(dims, idims): + yield var_name, selection, iselection + + +def xarray_var_iter(data, var_names=None, combined=False, skip_dims=None, reverse_selections=False): + """Convert xarray data to an iterator over vectors. + + Iterates over each var_name and all of its coordinates, returning the 1d + data. + + Parameters + ---------- + data : xarray.Dataset + Posterior data in an xarray + + var_names : iterator of strings (optional) + Should be a subset of data.data_vars. Defaults to all of them. + + combined : bool + Whether to combine chains or leave them separate + + skip_dims : set + dimensions to not iterate over + + reverse_selections : bool + Whether to reverse selections before iterating. 
+ + Returns + ------- + Iterator of (str, dict(str, any), np.array) + The string is the variable name, the dictionary are coordinate names to values, + and the array are the values of the variable at those coordinates. + """ + data_to_sel = data + if var_names is None and isinstance(data, xr.DataArray): + data_to_sel = {data.name: data} + + for var_name, selection, iselection in xarray_sel_iter( + data, + var_names=var_names, + combined=combined, + skip_dims=skip_dims, + reverse_selections=reverse_selections, + ): + yield var_name, selection, iselection, data_to_sel[var_name].sel(**selection).values + + +def xarray_to_ndarray(data, *, var_names=None, combined=True, label_fun=None): + """Take xarray data and unpacks into variables and data into list and numpy array respectively. + + Assumes that chain and draw are in coordinates + + Parameters + ---------- + data: xarray.DataSet + Data in an xarray from an InferenceData object. Examples include posterior or sample_stats + + var_names: iter + Should be a subset of data.data_vars not including chain and draws. 
Defaults to all of them + + combined: bool + Whether to combine chain into one array + + Returns + ------- + var_names: list + List of variable names + data: np.array + Data values + """ + if label_fun is None: + label_fun = BaseLabeller().make_label_vert + data_to_sel = data + if var_names is None and isinstance(data, xr.DataArray): + data_to_sel = {data.name: data} + + iterator1, iterator2 = tee(xarray_sel_iter(data, var_names=var_names, combined=combined)) + vars_and_sel = list(iterator1) + unpacked_var_names = [ + label_fun(var_name, selection, isel) for var_name, selection, isel in vars_and_sel + ] + + # Merge chains and variables, check dtype to be compatible with divergences data + data0 = data_to_sel[vars_and_sel[0][0]].sel(**vars_and_sel[0][1]) + unpacked_data = np.empty((len(unpacked_var_names), data0.size), dtype=data0.dtype) + for idx, (var_name, selection, _) in enumerate(iterator2): + unpacked_data[idx] = data_to_sel[var_name].sel(**selection).values.flatten() + + return unpacked_var_names, unpacked_data diff --git a/arviz/stats/stats.py b/arviz/stats/stats.py index 978c73be60..8ccdd606a6 100644 --- a/arviz/stats/stats.py +++ b/arviz/stats/stats.py @@ -1,9 +1,8 @@ # pylint: disable=too-many-lines """Statistical functions in ArviZ.""" import warnings -from collections import OrderedDict from copy import deepcopy -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, List, Optional, Tuple, Union import numpy as np import pandas as pd @@ -12,7 +11,7 @@ from scipy.optimize import minimize from arviz import _log -from ..data import CoordSpec, DimSpec, InferenceData, convert_to_dataset, convert_to_inference_data +from ..data import InferenceData, convert_to_dataset, convert_to_inference_data from ..rcparams import rcParams from ..utils import Numba, _numba_var, _var_names, get_coords from .density_utils import get_bins as _get_bins @@ -25,6 +24,8 @@ from .stats_utils import make_ufunc as _make_ufunc 
from .stats_utils import stats_variance_2d as svar from .stats_utils import wrap_xarray_ufunc as _wrap_xarray_ufunc +from ..sel_utils import xarray_var_iter +from ..labels import BaseLabeller if TYPE_CHECKING: from typing_extensions import Literal @@ -974,11 +975,11 @@ def summary( stat_funcs=None, extend=True, hdi_prob=None, - order: "Literal['C', 'F']" = "C", - index_origin=None, skipna=False, - coords: Optional[CoordSpec] = None, - dims: Optional[DimSpec] = None, + labeller=None, + coords=None, + index_origin=None, + order=None, ) -> Union[pd.DataFrame, xr.Dataset]: """Create a data frame with summary statistics. @@ -996,6 +997,8 @@ def summary( interpret var_names as substrings of the real variables names. If "regex", interpret var_names as regular expressions on the real variables names. A la `pandas.filter`. + coords: Dict[str, List[Any]], optional + Coordinate subset for which to calculate the summary. group: str Select a group for summary. Defaults to "posterior", "prior" or first group in that order, depending what groups exists. @@ -1023,18 +1026,19 @@ def summary( hdi_prob: float, optional Highest density interval to compute. Defaults to 0.94. This is only meaningful when ``stat_funcs`` is None. - order: {"C", "F"} - If fmt is "wide", use either C or F unpacking order. Defaults to C. - index_origin: int - If fmt is "wide, select n-based indexing for multivariate parameters. - Defaults to rcParam data.index.origin, which is 0. skipna: bool If true ignores nan values when computing the summary statistics, it does not affect the behaviour of the functions passed to ``stat_funcs``. Defaults to false. - coords: Dict[str, List[Any]], optional - Coordinates specification to be used if the ``fmt`` is ``'xarray'``. - dims: Dict[str, List[str]], optional - Dimensions specification for the variables to be used if the ``fmt`` is ``'xarray'``. 
+ labeller : labeller instance, optional + Class providing the method `make_label_flat` to generate the labels in the plot titles. + For more details on ``labeller`` usage see :ref:`label_guide` + credible_interval: float, optional + deprecated: Please see hdi_prob + order + deprecated: order is now ignored. + index_origin + deprecated: index_origin is now ignored, modify the coordinate values to change the + value used in summary. Returns ------- @@ -1096,13 +1100,18 @@ def summary( """ _log.cache = [] - extra_args = {} # type: Dict[str, Any] - if coords is not None: - extra_args["coords"] = coords - if dims is not None: - extra_args["dims"] = dims - if index_origin is None: + if coords is None: + coords = {} + + if index_origin is not None: + warnings.warn( + "index_origin has been deprecated. summary now shows coordinate values, " + "to change the label shown, modify the coordinate values before calling sumary", + DeprecationWarning, + ) index_origin = rcParams["data.index_origin"] + if labeller is None: + labeller = BaseLabeller() if hdi_prob is None: hdi_prob = rcParams["stats.hdi_prob"] else: @@ -1125,22 +1134,24 @@ def summary( raise TypeError(f"InferenceData does not contain group: {group}") dataset = data[group] else: - dataset = convert_to_dataset(data, group="posterior", **extra_args) + dataset = convert_to_dataset(data, group="posterior") var_names = _var_names(var_names, dataset, filter_vars) dataset = dataset if var_names is None else dataset[var_names] + dataset = get_coords(dataset, coords) fmt_group = ("wide", "long", "xarray") if not isinstance(fmt, str) or (fmt.lower() not in fmt_group): raise TypeError(f"Invalid format: '{fmt}'. Formatting options are: {fmt_group}") - unpack_order_group = ("C", "F") - if not isinstance(order, str) or (order.upper() not in unpack_order_group): - raise TypeError(f"Invalid order: '{order}'. 
Unpacking options are: {unpack_order_group}") - kind_group = ("all", "stats", "diagnostics") if not isinstance(kind, str) or kind not in kind_group: raise TypeError(f"Invalid kind: '{kind}'. Kind options are: {kind_group}") + if order is not None: + warnings.warn( + "order has been deprecated. summary now shows coordinate values.", DeprecationWarning + ) + alpha = 1 - hdi_prob extra_metrics = [] @@ -1267,28 +1278,18 @@ def summary( joined = ( xr.concat(metrics, dim="metric").assign_coords(metric=metric_names).reset_coords(drop=True) ) + n_metrics = len(metric_names) + n_vars = np.sum([joined[var].size // n_metrics for var in joined.data_vars]) if fmt.lower() == "wide": - dfs = [] - for var_name, values in joined.data_vars.items(): - if len(values.shape[1:]): - index_metric = list(values.metric.values) - data_dict = OrderedDict() - for idx in np.ndindex(values.shape[1:] if order == "C" else values.shape[1:][::-1]): - if order == "F": - idx = tuple(idx[::-1]) - ser = pd.Series(values[(Ellipsis, *idx)].values, index=index_metric) - key_index = ",".join(map(str, (i + index_origin for i in idx))) - key = f"{var_name}[{key_index}]" - data_dict[key] = ser - df = pd.DataFrame.from_dict(data_dict, orient="index") - df = df.loc[list(data_dict.keys())] - else: - df = values.to_dataframe() - df.index = list(df.index) - df = df.T - dfs.append(df) - summary_df = pd.concat(dfs, sort=False) + summary_df = pd.DataFrame(np.full((n_vars, n_metrics), np.nan), columns=metric_names) + indexs = [] + for i, (var_name, sel, isel, values) in enumerate( + xarray_var_iter(joined, skip_dims={"metric"}) + ): + summary_df.iloc[i] = values + indexs.append(labeller.make_label_flat(var_name, sel, isel)) + summary_df.index = indexs elif fmt.lower() == "long": df = joined.to_dataframe().reset_index().set_index("metric") df.index = list(df.index) diff --git a/arviz/tests/base_tests/test_data_zarr.py b/arviz/tests/base_tests/test_data_zarr.py index 7a77a2587f..d42cb15ffa 100644 --- 
a/arviz/tests/base_tests/test_data_zarr.py +++ b/arviz/tests/base_tests/test_data_zarr.py @@ -1,5 +1,4 @@ # pylint: disable=redefined-outer-name -import importlib import os import shutil from collections.abc import MutableMapping @@ -16,14 +15,10 @@ draws, eight_schools_params, running_on_ci, + importorskip, ) -pytestmark = pytest.mark.skipif( # pylint: disable=invalid-name - importlib.util.find_spec("zarr") is None and not running_on_ci(), - reason="test requires zarr which is not installed", -) - -import zarr # pylint: disable=wrong-import-position, wrong-import-order +zarr = importorskip("zarr") # pylint: disable=invalid-name class TestDataZarr: diff --git a/arviz/tests/base_tests/test_diagnostics.py b/arviz/tests/base_tests/test_diagnostics.py index b6d5675fc9..6c9bd57985 100644 --- a/arviz/tests/base_tests/test_diagnostics.py +++ b/arviz/tests/base_tests/test_diagnostics.py @@ -8,9 +8,9 @@ from numpy.testing import assert_almost_equal from ...data import from_cmdstan, load_arviz_data -from ...plots.plot_utils import xarray_var_iter from ...rcparams import rc_context, rcParams from ...stats import bfmi, ess, mcse, rhat +from ...sel_utils import xarray_var_iter from ...stats.diagnostics import ( _ess, _ess_quantile, @@ -148,7 +148,7 @@ def test_deterministic(self): "mcse_quantile30": lambda x: mcse(x, method="quantile", prob=0.3), } results = {} - for key, coord_dict, vals in xarray_var_iter(posterior.posterior, combined=True): + for key, coord_dict, _, vals in xarray_var_iter(posterior.posterior, combined=True): if coord_dict: key = key + ".{}".format(list(coord_dict.values())[0] + 1) results[key] = {func_name: func(vals) for func_name, func in funcs.items()} diff --git a/arviz/tests/base_tests/test_plot_utils.py b/arviz/tests/base_tests/test_plot_utils.py index e616f7d1fe..4edc978bb3 100644 --- a/arviz/tests/base_tests/test_plot_utils.py +++ b/arviz/tests/base_tests/test_plot_utils.py @@ -14,10 +14,10 @@ make_2d, set_bokeh_circular_ticks_labels, 
vectorized_to_hex, - xarray_to_ndarray, - xarray_var_iter, compute_ranks, ) +from ...sel_utils import xarray_to_ndarray, xarray_sel_iter + from ...rcparams import rc_context from ...stats.density_utils import get_bins from ...utils import get_coords @@ -91,7 +91,7 @@ def test_dataset_to_numpy_combined(sample_dataset): assert (data[var_names.index("tau")] == tau.reshape(1, 6)).all() -def test_xarray_var_iter_ordering(): +def test_xarray_sel_iter_ordering(): """Assert that coordinate names stay the provided order""" coords = list("dcba") data = from_dict( # pylint: disable=no-member @@ -100,21 +100,21 @@ def test_xarray_var_iter_ordering(): dims={"x": ["in_order"]}, ).posterior - coord_names = [sel["in_order"] for _, sel, _ in xarray_var_iter(data)] + coord_names = [sel["in_order"] for _, sel, _ in xarray_sel_iter(data)] assert coord_names == coords -def test_xarray_var_iter_ordering_combined(sample_dataset): # pylint: disable=invalid-name +def test_xarray_sel_iter_ordering_combined(sample_dataset): # pylint: disable=invalid-name """Assert that varname order stays consistent when chains are combined""" _, _, data = sample_dataset - var_names = [var for (var, _, _) in xarray_var_iter(data, var_names=None, combined=True)] + var_names = [var for (var, _, _) in xarray_sel_iter(data, var_names=None, combined=True)] assert set(var_names) == {"mu", "tau"} -def test_xarray_var_iter_ordering_uncombined(sample_dataset): # pylint: disable=invalid-name +def test_xarray_sel_iter_ordering_uncombined(sample_dataset): # pylint: disable=invalid-name """Assert that varname order stays consistent when chains are not combined""" _, _, data = sample_dataset - var_names = [(var, selection) for (var, selection, _) in xarray_var_iter(data, var_names=None)] + var_names = [(var, selection) for (var, selection, _) in xarray_sel_iter(data, var_names=None)] assert len(var_names) == 4 for var_name in var_names: @@ -126,13 +126,13 @@ def test_xarray_var_iter_ordering_uncombined(sample_dataset): # 
pylint: disable ] -def test_xarray_var_data_array(sample_dataset): # pylint: disable=invalid-name +def test_xarray_sel_data_array(sample_dataset): # pylint: disable=invalid-name """Assert that varname order stays consistent when chains are combined Touches code that is hard to reach. """ _, _, data = sample_dataset - var_names = [var for (var, _, _) in xarray_var_iter(data.mu, var_names=None, combined=True)] + var_names = [var for (var, _, _) in xarray_sel_iter(data.mu, var_names=None, combined=True)] assert set(var_names) == {"mu"} diff --git a/arviz/tests/base_tests/test_plots_matplotlib.py b/arviz/tests/base_tests/test_plots_matplotlib.py index ee1ae3deed..69c4073b96 100644 --- a/arviz/tests/base_tests/test_plots_matplotlib.py +++ b/arviz/tests/base_tests/test_plots_matplotlib.py @@ -1312,42 +1312,6 @@ def test_plot_loo_pit_incompatible_args(models): plot_loo_pit(idata=models.model_1, y="y", ecdf=True, use_hdi=True) -@pytest.mark.parametrize( - "args", - [ - {"y": "str"}, - {"y": "DataArray", "y_hat": "str"}, - {"y": "ndarray", "y_hat": "str"}, - {"y": "ndarray", "y_hat": "DataArray"}, - {"y": "ndarray", "y_hat": "ndarray"}, - ], -) -def test_plot_loo_pit_label(models, args): - assert_name = args["y"] != "ndarray" or args.get("y_hat") != "ndarray" - - if args["y"] == "str": - y = "y" - elif args["y"] == "DataArray": - y = models.model_1.observed_data.y - elif args["y"] == "ndarray": - y = models.model_1.observed_data.y.values - - if args.get("y_hat") == "str": - y_hat = "y" - elif args.get("y_hat") == "DataArray": - y_hat = models.model_1.posterior_predictive.y.stack(sample=("chain", "draw")) - elif args.get("y_hat") == "ndarray": - y_hat = models.model_1.posterior_predictive.y.stack(sample=("chain", "draw")).values - else: - y_hat = None - - ax = plot_loo_pit(idata=models.model_1, y=y, y_hat=y_hat) - if assert_name: - assert "y" in ax.get_legend_handles_labels()[1][0] - else: - assert "y" not in ax.get_legend_handles_labels()[1][0] - - @pytest.mark.parametrize( 
"kwargs", [ diff --git a/arviz/tests/base_tests/test_stats.py b/arviz/tests/base_tests/test_stats.py index 4416ed7a70..ce7b4ad413 100644 --- a/arviz/tests/base_tests/test_stats.py +++ b/arviz/tests/base_tests/test_stats.py @@ -261,40 +261,24 @@ def test_summary_fmt(centered_eight, fmt): assert summary(centered_eight, fmt=fmt) is not None -@pytest.mark.parametrize("order", ["C", "F"]) -def test_summary_unpack_order(order): - data = from_dict({"a": np.random.randn(4, 100, 4, 5, 3)}) - az_summary = summary(data, order=order, fmt="wide") +def test_summary_labels(): + coords1 = list("abcd") + coords2 = np.arange(1, 6) + data = from_dict( + {"a": np.random.randn(4, 100, 4, 5)}, + coords={"dim1": coords1, "dim2": coords2}, + dims={"a": ["dim1", "dim2"]}, + ) + az_summary = summary(data, fmt="wide") assert az_summary is not None - if order != "F": - first_index = 4 - second_index = 5 - third_index = 3 - else: - first_index = 3 - second_index = 5 - third_index = 4 column_order = [] - for idx1 in range(first_index): - for idx2 in range(second_index): - for idx3 in range(third_index): - if order != "F": - column_order.append("a[{},{},{}]".format(idx1, idx2, idx3)) - else: - column_order.append("a[{},{},{}]".format(idx3, idx2, idx1)) + for coord1 in coords1: + for coord2 in coords2: + column_order.append("a[{}, {}]".format(coord1, coord2)) for col1, col2 in zip(list(az_summary.index), column_order): assert col1 == col2 -@pytest.mark.parametrize("origin", [0, 1, 2, 3]) -def test_summary_index_origin(origin): - data = from_dict({"a": np.random.randn(2, 50, 10)}) - az_summary = summary(data, index_origin=origin, fmt="wide") - assert az_summary is not None - for i, col in enumerate(list(az_summary.index)): - assert col == "a[{}]".format(i + origin) - - @pytest.mark.parametrize( "stat_funcs", [[np.var], {"var": np.var, "var2": lambda x: np.var(x) ** 2}] ) @@ -306,12 +290,12 @@ def test_summary_stat_func(centered_eight, stat_funcs): def test_summary_nan(centered_eight): 
centered_eight = deepcopy(centered_eight) - centered_eight.posterior.theta[:, :, 0] = np.nan + centered_eight.posterior["theta"].loc[{"school": "Deerfield"}] = np.nan summary_xarray = summary(centered_eight) assert summary_xarray is not None - assert summary_xarray.loc["theta[0]"].isnull().all() + assert summary_xarray.loc["theta[Deerfield]"].isnull().all() assert ( - summary_xarray.loc[[ix for ix in summary_xarray.index if ix != "theta[0]"]] + summary_xarray.loc[[ix for ix in summary_xarray.index if ix != "theta[Deerfield]"]] .notnull() .all() .all() @@ -320,9 +304,9 @@ def test_summary_nan(centered_eight): def test_summary_skip_nan(centered_eight): centered_eight = deepcopy(centered_eight) - centered_eight.posterior.theta[:, :10, 1] = np.nan + centered_eight.posterior["theta"].loc[{"draw": slice(10), "school": "Deerfield"}] = np.nan summary_xarray = summary(centered_eight) - theta_1 = summary_xarray.loc["theta[1]"].isnull() + theta_1 = summary_xarray.loc["theta[Deerfield]"].isnull() assert summary_xarray is not None assert ~theta_1[:4].all() assert theta_1[4:].all() @@ -334,16 +318,14 @@ def test_summary_bad_fmt(centered_eight, fmt): summary(centered_eight, fmt=fmt) -@pytest.mark.parametrize("kind", [1, "bad_kind"]) -def test_summary_bad_kind(centered_eight, kind): - with pytest.raises(TypeError, match="Invalid kind"): - summary(centered_eight, kind=kind) +def test_summary_order_deprecation(centered_eight): + with pytest.warns(DeprecationWarning, match="order"): + summary(centered_eight, order="C") -@pytest.mark.parametrize("order", [1, "bad_order"]) -def test_summary_bad_unpack_order(centered_eight, order): - with pytest.raises(TypeError, match="Invalid order"): - summary(centered_eight, order=order) +def test_summary_index_origin_deprecation(centered_eight): + with pytest.warns(DeprecationWarning, match="index_origin"): + summary(centered_eight, index_origin=1) @pytest.mark.parametrize("scale", ["log", "negative_log", "deviance"]) diff --git a/arviz/utils.py 
b/arviz/utils.py index a2230e2cd5..af7d857d50 100644 --- a/arviz/utils.py +++ b/arviz/utils.py @@ -406,7 +406,7 @@ def _cov_1d(x): return np.dot(x.T, x.conj()) / ddof -@conditional_jit(cache=True) +# @conditional_jit(cache=True) def _cov(data): if data.ndim == 1: return _cov_1d(data) diff --git a/doc/source/api/index.rst b/doc/source/api/index.rst index 0b4228875c..3e3010e689 100644 --- a/doc/source/api/index.rst +++ b/doc/source/api/index.rst @@ -15,6 +15,7 @@ API Reference stats_utils data inference_data + plot_utils utils rcparams wrappers diff --git a/doc/source/api/plot_utils.md b/doc/source/api/plot_utils.md new file mode 100644 index 0000000000..df91cd75b3 --- /dev/null +++ b/doc/source/api/plot_utils.md @@ -0,0 +1,44 @@ +```{eval-rst} +.. currentmodule:: arviz.labels +``` +# Plot utils + +(labeller_api)= +## Labellers +See also the {ref}`label_guide` + +```{eval-rst} +.. autosummary:: + :toctree: generated/ + + BaseLabeller + DimCoordLabeller + IdxLabeller + DimIdxLabeller + MapLabeller + NoModelLabeller + NoRepeatLabeller +``` + +## Labeling utils + +```{eval-rst} +.. autosummary:: + :toctree: generated/ + + mix_labellers +``` + +## Xarray utils +Low level functions to iterate over xarray objects. + +```{eval-rst} +.. currentmodule:: arviz.sel_utils + +.. 
autosummary:: + :toctree: generated/ + + xarray_sel_iter + xarray_var_iter + xarray_to_ndarray +``` diff --git a/doc/source/conf.py b/doc/source/conf.py index a633995989..19e5581714 100644 --- a/doc/source/conf.py +++ b/doc/source/conf.py @@ -82,7 +82,8 @@ # # MyST related params -jupyter_execute_notebooks = "off" +jupyter_execute_notebooks = "auto" +execution_excludepatterns = ['*.ipynb'] myst_heading_anchors = 3 panels_add_bootstrap_css = False diff --git a/doc/source/user_guide/data_structures.md b/doc/source/user_guide/data_structures.md index 52ed5c9b60..8502f615ef 100644 --- a/doc/source/user_guide/data_structures.md +++ b/doc/source/user_guide/data_structures.md @@ -1,6 +1,8 @@ # Data structures ```{toctree} +:maxdepth: 2 +label_guide ../schema/schema ``` diff --git a/doc/source/user_guide/label_guide.rst b/doc/source/user_guide/label_guide.rst new file mode 100644 index 0000000000..935376efde --- /dev/null +++ b/doc/source/user_guide/label_guide.rst @@ -0,0 +1,265 @@ +.. _label_guide: + +=========== +Label guide +=========== + +Basic labeling +-------------- + +All ArviZ plotting functions and some stats functions take an optional ``labeller`` argument. +By default, labels show the variable name and the coordinate value +(for multidimensinal variables only). +The first example below uses this default labeling. + +.. ipython:: + + In [1]: import arviz as az + ...: schools = az.load_arviz_data("centered_eight") + ...: az.summary(schools) + +Thanks to being powered by xarray, ArviZ supports label based indexing. +We can therefore use the labels we have seen in the summary to plot only a subset of the variables, +the one we are interested in. +Provided we know that the coordinate values shown for theta correspond to the `school` dimension, +we can plot only ``tau`` to better inspect it's 1.03 :func:`~arviz.rhat` and +``theta`` for ``Choate`` and ``St. Paul's``, the ones with higher means: + +.. 
ipython:: python

+    @savefig label_guide_plot_trace.png
+    az.plot_trace(schools, var_names=["tau", "theta"], coords={"school": ["Choate", "St. Paul's"]}, compact=False);
+
+So far so good, we can identify some issues for low ``tau`` values which is a great start.
+But say we want to make a report on Deerfield, Hotchkiss and Lawrenceville schools to
+see the probability of ``theta > 5`` and we have to present it somewhere with math notation.
+Our default labels show ``theta``, not $\theta$ (generated from ``$\theta$`` using $\LaTeX$).
+
+Fear not, we can use the labeller argument to customize the labels.
+The ``arviz.labels`` module contains some classes that cover some common customization needs.
+
+In this case, we can use :class:`~arviz.labels.MapLabeller` and
+tell it to rename the variable name ``theta`` to ``$\theta$``, like so:
+
+.. ipython::
+
+    In [1]: import arviz.labels as azl
+       ...: labeller = azl.MapLabeller(var_name_map={"theta": r"$\theta$"})
+       ...: coords = {"school": ["Deerfield", "Hotchkiss", "Lawrenceville"]}
+
+    @savefig label_guide_plot_posterior.png
+    In [1]: az.plot_posterior(schools, var_names="theta", coords=coords, labeller=labeller, ref_val=5);
+
+You can see the labellers available in ArviZ at :ref:`their API reference page <labeller_api>`.
+Their names aim to be descriptive and they all have examples in their docstring.
+For further customization continue reading this guide.
+
+Sorting labels
+--------------
+
+Labels in ArviZ can generally be sorted in two ways,
+using the arguments passed to ArviZ plotting functions or
+sorting the underlying xarray Dataset.
+The first one is more convenient for one-time ordering
+whereas the second is better if you want plots consistently sorted that way and
+is also more flexible; using ArviZ args is more limited.
+
+Both alternatives have an important limitation though.
+Multidimensional variables are always together.
+We can sort ``theta, mu, tau`` in any order, and within ``theta`` we can sort the schools in any order,
+but it's not possible to show half the schools, then ``mu`` and ``tau`` and then the rest of the schools.
+
+Sorting variable names
+......................
+
+.. ipython::
+
+    In [1]: var_order = ["theta", "mu", "tau"]
+
+.. tabbed:: ArviZ args
+
+    We can pass a list with the variable names sorted to modify the order in which they appear
+    when calling ArviZ functions.
+
+    .. ipython::
+
+        In [1]: az.summary(schools, var_names=var_order)
+
+.. tabbed:: xarray
+
+    In xarray, subsetting the Dataset with a sorted list of variable names will order the Dataset.
+
+    .. ipython::
+
+        In [1]: schools.posterior = schools.posterior[var_order]
+           ...: az.summary(schools)
+
+Sorting coordinate values
+.........................
+
+We may also want to sort the schools by their mean.
+To do so we first have to get the means of each school:
+
+.. ipython::
+
+    In [1]: school_means = schools.posterior["theta"].mean(("chain", "draw"))
+       ...: school_means
+
+We can then use this DataArray result to sort the coordinate values for ``theta``.
+Again we have two alternatives:
+
+.. tabbed:: ArviZ args
+
+    Here the first step is to sort the coordinate values so we can pass them as `coords` argument and
+    choose the order of the rows.
+    If we want to manually sort the schools, `sorted_schools` can be defined straight away as a list.
+
+    .. ipython::
+
+        In [1]: sorted_schools = schools.posterior["school"].sortby(school_means)
+           ...: az.summary(schools, var_names="theta", coords={"school": sorted_schools})
+
+.. tabbed:: xarray
+
+    We can use the :meth:`~xarray.Dataset.sortby` method to order our coordinate values straight at the source.
+
+    .. ipython::
+
+        In [1]: schools.posterior = schools.posterior.sortby(school_means)
+           ...: az.summary(schools, var_names="theta")
+
+Sorting dimensions
+..................
+
+In some cases, our multidimensional variables may not have only a length ``n`` dimension
+(in addition to the ``chain`` and ``draw`` ones)
+but could also have multiple dimensions.
+Let's imagine we have performed a set of fixed experiments on several days on multiple subjects,
+three data dimensions overall.
+
+We will create a fake inference data with data mimicking this situation to show how to sort dimensions.
+To keep things short and not clutter the guide too much with unnecessary output lines,
+we will stick to a posterior of a single variable and the dimension sizes will be ``2, 3, 4``.
+
+.. ipython::
+
+    In [1]: from numpy.random import default_rng
+       ...: import pandas as pd
+       ...: rng = default_rng()
+       ...: samples = rng.normal(size=(4, 500, 2, 3, 4))
+       ...: coords = {
+       ...:     "subject": ["ecoli", "pseudomonas", "clostridium"],
+       ...:     "date": ["1-3-2020", "2-4-2020", "1-5-2020", "1-6-2020"],
+       ...:     "experiment": [1, 2]
+       ...: }
+       ...: experiments = az.from_dict(
+       ...:     posterior={"b": samples}, dims={"b": ["experiment", "subject", "date"]}, coords=coords
+       ...: )
+       ...: experiments.posterior
+
+Given how we have constructed our dataset, the default order is ``experiment, subject, date``.
+
+.. dropdown:: Click to see the default summary
+
+    .. ipython::
+
+        In [1]: az.summary(experiments)
+
+However, we actually want to have the dimensions in this order: ``subject, date, experiment``.
+And in this case, we need to modify the underlying xarray object in order to get the desired result:
+
+.. ipython:: python
+
+    dim_order = ("chain", "draw", "subject", "date", "experiment")
+    experiments = experiments.posterior.transpose(*dim_order)
+    az.summary(experiments)
+
+Note however that we don't need to overwrite or store the modified xarray object.
+Doing ``az.summary(experiments.posterior.transpose(*dim_order))`` would work just the same
+if we only want to use this order once.
+ +Labeling with indexes +--------------------- + +As you may have seen, there are labellers with ``Idx`` in their name: +:class:`~arviz.labels.IdxLabeller` and :class:`~arviz.labels.DimIdxLabeller`, +which show the positional index of the values instead of their corresponding coordinate value. + +We have seen before that we can use the ``coords`` argument or +the :meth:`~arviz.InferenceData.sel` method to select data based on the coordinate values. +Similarly, we can use the :meth:`~arviz.InferenceData.isel` method to select data based on positional indexes. + +.. ipython:: python + + az.summary(schools, labeller=azl.IdxLabeller()) + +After seeing this summary, we use ``isel`` to generate the summary of a subset only. + +.. ipython:: python + + az.summary(schools.isel(school=[2, 5, 7]), labeller=azl.IdxLabeller()) + +.. warning:: + + Positional indexing is NOT label based indexing with numbers! + +The positional indexes shown will correspond to the ordinal position *in the subsetted object*. +If you are not subsetting the object, you can use these indexes with ``isel`` without problem. +However, if you are subsetting the data (either directly or with the ``coords`` argument) +and want to use the positional indexes shown, you need to use them on the corresponding subset. + +An example. If you use a dict named ``coords`` when calling a plotting function, +for ``isel`` to work it has to be called on +``original_idata.sel(**coords).isel()`` and +not on ``original_idata.isel()`` + +Labeller mixtures +----------------- + +In some cases, none of the available labellers will do the right job. +One case where this is bound to happen is with ``plot_forest``. +When setting ``legend=True`` it does not really make sense to add the model name to the tick labels. +``plot_forest`` knows that, and if no ``labeller`` is passed, it uses either +:class:`~arviz.labels.BaseLabeller` or :class:`~arviz.labels.NoModelLabeller` depending on the value of ``legend``. 
+If we do want to use the ``labeller`` argument however, we have to make sure to enforce this default ourselves: + +.. ipython:: python + + schools2 = az.load_arviz_data("non_centered_eight") + + @savefig default_plot_forest.png + az.plot_forest( + (schools, schools2), + model_names=("centered", "non_centered"), + coords={"school": ["Deerfield", "Lawrenceville", "Mt. Hermon"]}, + figsize=(10,7), + labeller=azl.DimCoordLabeller(), + legend=True + ); + +There is a lot of repeated information now. +The variable names, dims and coords are shown for both models and +the models are labeled both in the legend and in the labels of the y axis. +For cases like this, ArviZ provides a convenience function :func:`~arviz.labels.mix_labellers` +that combines labeller classes for some extra customization. +Labeller classes aim to split labeling into atomic tasks and have a method per task to maximize extensibility. +Thus, many new labellers can be created with this mixer function alone without needing to write a new class from scratch. + +.. ipython:: python + + MixtureLabeller = azl.mix_labellers((azl.DimCoordLabeller, azl.NoModelLabeller)) + + @savefig mixture_plot_forest.png + az.plot_forest( + (schools, schools2), + model_names=("centered", "non_centered"), + coords={"school": ["Deerfield", "Lawrenceville", "Mt. Hermon"]}, + figsize=(10,7), + labeller=MixtureLabeller(), + legend=True + ); + +Custom labellers +---------------- + +Section in construction...