Skip to content

Commit

Permalink
Merge d64a5fd into 29b35ce
Browse files Browse the repository at this point in the history
  • Loading branch information
OriolAbril authored Feb 21, 2021
2 parents 29b35ce + d64a5fd commit 423d8a1
Show file tree
Hide file tree
Showing 73 changed files with 1,702 additions and 695 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@ coverage.xml
*.log

# Sphinx documentation
docs/source/_build/
doc/build/
doc/source/savefig

# PyBuilder
target/
Expand Down
17 changes: 17 additions & 0 deletions .projections.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
{
"arviz/plots/backends/matplotlib/*.py": {
"alternate": "arviz/plots/backends/bokeh/{}.py",
"related": "arviz/plots/{}.py",
"type": "mpl"
},
"arviz/plots/backends/bokeh/*.py": {
"alternate": "arviz/plots/backends/matplotlib/{}.py",
"related": "arviz/plots/{}.py",
"type": "bokeh"
},
"arviz/plots/*.py": {
"alternate": "arviz/plots/backends/matplotlib/{}.py",
"related": "arviz/plots/backends/bokeh/{}.py",
"type": "base"
}
}
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,26 @@

## v0.x.x Unreleased
### New features
* Added `labeller` argument to enable label customization in plots and summary ([1201](https://github.com/arviz-devs/arviz/pull/1201))
* Added `arviz.labels` module with classes and utilities ([1201](https://github.com/arviz-devs/arviz/pull/1201))

### Maintenance and fixes
* Enforced using coordinate values as default labels ([1201](https://github.com/arviz-devs/arviz/pull/1201))
* Integrated `index_origin` throughout the library ([1201](https://github.com/arviz-devs/arviz/pull/1201))

### Deprecation
* Deprecated `index_origin` and `order` arguments in `az.summary` ([1201](https://github.com/arviz-devs/arviz/pull/1201))

### Documentation
* Added "Label guide" page and API section for `arviz.labels` module ([1201](https://github.com/arviz-devs/arviz/pull/1201))

## v0.11.2 (2021 Feb 21)
### New features
* Added `to_zarr` and `from_zarr` methods to InferenceData ([1518](https://github.com/arviz-devs/arviz/pull/1518))
* Added confidence interval band to auto-correlation plot ([1535](https://github.com/arviz-devs/arviz/pull/1535))

### Maintenance and fixes
* Updated CmdStanPy converter for compatibility with versions >=0.9.68 ([1558](https://github.com/arviz-devs/arviz/pull/1558) and [1564](https://github.com/arviz-devs/arviz/pull/1564))
* Updated `from_cmdstanpy`, `from_cmdstan`, `from_numpyro` and `from_pymc3` converters to follow schema convention ([1550](https://github.com/arviz-devs/arviz/pull/1550), [1541](https://github.com/arviz-devs/arviz/pull/1541), [1525](https://github.com/arviz-devs/arviz/pull/1525) and [1555](https://github.com/arviz-devs/arviz/pull/1555))
* Fix calculation of mode as point estimate ([1552](https://github.com/arviz-devs/arviz/pull/1552))
* Remove variable name from legend in posterior predictive plot ([1559](https://github.com/arviz-devs/arviz/pull/1559))
Expand Down
108 changes: 80 additions & 28 deletions arviz/data/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import json # type: ignore

from .. import __version__, utils
from ..rcparams import rcParams

CoordSpec = Dict[str, List[Any]]
DimSpec = Dict[str, List[str]]
Expand Down Expand Up @@ -49,7 +50,13 @@ def wrapped(cls, *args, **kwargs):


def generate_dims_coords(
shape, var_name, dims=None, coords=None, default_dims=None, skip_event_dims=None
shape,
var_name,
dims=None,
coords=None,
default_dims=None,
index_origin=None,
skip_event_dims=None,
):
"""Generate default dimensions and coordinates for a variable.
Expand All @@ -70,6 +77,9 @@ def generate_dims_coords(
when manipulating Monte Carlo traces, the ``default_dims`` would be
``["chain" , "draw"]`` which ArviZ uses as its own names for dimensions
of MCMC traces.
index_origin : int, optional
Starting value of integer coordinate values. Defaults to the value in rcParam
``data.index_origin``.
skip_event_dims : bool, default False
Returns
Expand All @@ -79,6 +89,8 @@ def generate_dims_coords(
dict[str] -> list[str]
Default coords
"""
if index_origin is None:
index_origin = rcParams["data.index_origin"]
if default_dims is None:
default_dims = []
if dims is None:
Expand Down Expand Up @@ -127,19 +139,30 @@ def generate_dims_coords(
dims[idx] = dim_name
dim_name = dims[idx]
if dim_name not in coords:
coords[dim_name] = utils.arange(dim_len)
coords[dim_name] = np.arange(index_origin, dim_len + index_origin)
coords = {key: coord for key, coord in coords.items() if any(key == dim for dim in dims)}
return dims, coords


def numpy_to_data_array(ary, *, var_name="data", coords=None, dims=None, skip_event_dims=None):
def numpy_to_data_array(
ary,
*,
var_name="data",
coords=None,
dims=None,
default_dims=None,
index_origin=None,
skip_event_dims=None,
):
"""Convert a numpy array to an xarray.DataArray.
The first two dimensions will be (chain, draw), and any remaining
By default, the first two dimensions will be (chain, draw), and any remaining
dimensions will be "shape".
If the numpy array is 1d, this dimension is interpreted as draw
If the numpy array is 2d, it is interpreted as (chain, draw)
If the numpy array is 3 or more dimensions, the last dimensions are kept as shapes.
* If the numpy array is 1d, this dimension is interpreted as draw
* If the numpy array is 2d, it is interpreted as (chain, draw)
* If the numpy array is 3 or more dimensions, the last dimensions are kept as shapes.
To modify this behaviour, use ``default_dims``.
Parameters
----------
Expand All @@ -154,6 +177,11 @@ def numpy_to_data_array(ary, *, var_name="data", coords=None, dims=None, skip_ev
is the name of the dimension, the values are the index values.
dims : List(str)
A list of coordinate names for the variable
default_dims : list of str, optional
Passed to :py:func:`generate_dims_coords`. Defaults to ``["chain", "draw"]``, and
an empty list is accepted
index_origin : int, optional
Passed to :py:func:`generate_dims_coords`
skip_event_dims : bool
Returns
Expand All @@ -162,45 +190,59 @@ def numpy_to_data_array(ary, *, var_name="data", coords=None, dims=None, skip_ev
Will have the same data as passed, but with coordinates and dimensions
"""
# manage and transform copies
default_dims = ["chain", "draw"]
ary = utils.two_de(ary)
n_chains, n_samples, *shape = ary.shape
if n_chains > n_samples:
warnings.warn(
"More chains ({n_chains}) than draws ({n_samples}). "
"Passed array should have shape (chains, draws, *shape)".format(
n_chains=n_chains, n_samples=n_samples
),
UserWarning,
)
if default_dims is None:
default_dims = ["chain", "draw"]
if "chain" in default_dims and "draw" in default_dims:
ary = utils.two_de(ary)
n_chains, n_samples, *_ = ary.shape
if n_chains > n_samples:
warnings.warn(
"More chains ({n_chains}) than draws ({n_samples}). "
"Passed array should have shape (chains, draws, *shape)".format(
n_chains=n_chains, n_samples=n_samples
),
UserWarning,
)
else:
ary = utils.one_de(ary)

dims, coords = generate_dims_coords(
shape,
ary.shape[len(default_dims) :],
var_name,
dims=dims,
coords=coords,
default_dims=default_dims,
index_origin=index_origin,
skip_event_dims=skip_event_dims,
)

# reversed order for default dims: 'chain', 'draw'
if "draw" not in dims:
if "draw" not in dims and "draw" in default_dims:
dims = ["draw"] + dims
if "chain" not in dims:
if "chain" not in dims and "chain" in default_dims:
dims = ["chain"] + dims

if "chain" not in coords:
coords["chain"] = utils.arange(n_chains)
if "draw" not in coords:
coords["draw"] = utils.arange(n_samples)
index_origin = rcParams["data.index_origin"]
if "chain" not in coords and "chain" in default_dims:
coords["chain"] = np.arange(index_origin, n_chains + index_origin)
if "draw" not in coords and "draw" in default_dims:
coords["draw"] = np.arange(index_origin, n_samples + index_origin)

# filter coords based on the dims
coords = {key: xr.IndexVariable((key,), data=coords[key]) for key in dims}
return xr.DataArray(ary, coords=coords, dims=dims)


def dict_to_dataset(
data, *, attrs=None, library=None, coords=None, dims=None, skip_event_dims=None
data,
*,
attrs=None,
library=None,
coords=None,
dims=None,
default_dims=None,
index_origin=None,
skip_event_dims=None,
):
"""Convert a dictionary of numpy arrays to an xarray.Dataset.
Expand All @@ -217,6 +259,10 @@ def dict_to_dataset(
dims : dict[str] -> list[str]
Dimensions of each variable. The keys are variable names, values are lists of
coordinates.
default_dims : list of str, optional
Passed to :py:func:`numpy_to_data_array`
index_origin : int, optional
Passed to :py:func:`numpy_to_data_array`
skip_event_dims : bool
If True, cut extra dims whenever present to match the shape of the data.
Necessary for PPLs which have the same name in both observed data and log
Expand All @@ -238,7 +284,13 @@ def dict_to_dataset(
data_vars = {}
for key, values in data.items():
data_vars[key] = numpy_to_data_array(
values, var_name=key, coords=coords, dims=dims.get(key), skip_event_dims=skip_event_dims
values,
var_name=key,
coords=coords,
dims=dims.get(key),
default_dims=default_dims,
index_origin=index_origin,
skip_event_dims=skip_event_dims,
)
return xr.Dataset(data_vars=data_vars, attrs=make_attrs(attrs=attrs, library=library))

Expand Down Expand Up @@ -312,7 +364,7 @@ def wrapped(self, *args, **kwargs):
return None if _inplace else out

description_default = """{method_name} method is extended from xarray.Dataset methods.
{description}For more info see :meth:`xarray:xarray.Dataset.{method_name}`
""".format(
description=description, method_name=func.__name__ # pylint: disable=no-member
Expand Down
Loading

0 comments on commit 423d8a1

Please sign in to comment.