Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revamping HPD #1117

Merged
merged 24 commits into from
Apr 2, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,12 @@
* Add `num_chains` and `pred_dims` arguments to io_pyro #1090
* Integrate jointplot into pairplot, add point-estimate and overlay of plot kinds (#1079)
* Allow xarray.Dataarray input for plots.(#1120)
* Revamped the `hpd` function to make it work with multidimensional arrays, InferenceData and xarray objects (#1117)
* Skip test for optional/extra dependencies when not installed (#1113)
### Maintenance and fixes
* Fixed behaviour of `credible_interval=None` in `plot_posterior` (#1115)
* Fixed hist kind of `plot_dist` with multidimensional input (#1115)
* Fixed `TypeError` in `transform` argument of `plot_density` and `plot_forest` when `InferenceData is a list or tuple (#1121)`
* Fixed `TypeError` in `transform` argument of `plot_density` and `plot_forest` when `InferenceData` is a list or tuple (#1121)
### Deprecation

### Documentation
Expand Down
218 changes: 150 additions & 68 deletions arviz/stats/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from scipy.optimize import minimize
import xarray as xr

from ..plots.plot_utils import _fast_kde, get_bins
from ..plots.plot_utils import _fast_kde, get_bins, get_coords
from ..data import convert_to_inference_data, convert_to_dataset, InferenceData, CoordSpec, DimSpec
from .diagnostics import _multichain_statistics, _mc_error, ess
from .stats_utils import (
Expand Down Expand Up @@ -305,16 +305,29 @@ def _ic_matrix(ics, ic_i):
return rows, cols, ic_i_val


def hpd(ary, credible_interval=None, circular=False, multimodal=False, skipna=False):
def hpd(
ary,
credible_interval=None,
circular=False,
multimodal=False,
skipna=False,
group="posterior",
var_names=None,
coords=None,
max_modes=10,
**kwargs
OriolAbril marked this conversation as resolved.
Show resolved Hide resolved
):
"""
Calculate highest posterior density (HPD) of array for given credible_interval.

The HPD is the minimum width Bayesian credible interval (BCI).

Parameters
----------
ary : Numpy array
An array containing posterior samples
ary : obj
object containing posterior samples.
Any object that can be converted to an az.InferenceData object.
Refer to documentation of az.convert_to_dataset for details.
credible_interval : float, optional
Credible interval to compute. Defaults to 0.94.
circular : bool, optional
Expand All @@ -326,10 +339,22 @@ def hpd(ary, credible_interval=None, circular=False, multimodal=False, skipna=Fa
modes are well separated.
skipna : bool
If true ignores nan values when computing the hpd interval. Defaults to false.
group : str, optional
Specifies which InferenceData group should be used to calculate hpd.
Defaults to 'posterior'
var_names : list, optional
Names of variables to include in the hpd report
coords: mapping, optional
Specifies the subset over which to calculate hpd.
max_modes: int, optional
Specifies the maximum number of modes for the multimodal case.
kwargs : dict, optional
Additional keywords passed to `wrap_xarray_ufunc`.
See the docstring of :obj:`wrap_xarray_ufunc method </.stats_utils.wrap_xarray_ufunc>`.

Returns
-------
np.ndarray
np.ndarray or xarray.Dataset, depending upon input
lower(s) and upper(s) values of the interval(s).

Examples
Expand All @@ -342,91 +367,148 @@ def hpd(ary, credible_interval=None, circular=False, multimodal=False, skipna=Fa
...: import numpy as np
...: data = np.random.normal(size=2000)
...: az.hpd(data, credible_interval=.68)

Calculate the hpd of a dataset:

.. ipython::

In [1]: import arviz as az
...: data = az.load_arviz_data('centered_eight')
...: az.hpd(data)

We can also calculate the hpd of some of the variables of dataset:

.. ipython::

In [1]: az.hpd(data, var_names=["mu", "theta"])

If we want to calculate the hpd over specified dimension of dataset,
we can pass `input_core_dims` by kwargs:

.. ipython::

In [1]: az.hpd(data, input_core_dims = [["chain"]])

We can also calculate the hpd over a particular selection over all groups:

.. ipython::

In [1]: az.hpd(data, coords={"chain":[0, 1, 3]}, input_core_dims = [["draw"]])

"""
if credible_interval is None:
credible_interval = rcParams["stats.credible_interval"]
else:
if not 1 >= credible_interval > 0:
raise ValueError("The value of credible_interval should be in the interval (0, 1]")

if ary.ndim > 1:
hpd_array = np.array(
[
hpd(
row,
credible_interval=credible_interval,
circular=circular,
multimodal=multimodal,
)
for row in ary.T
]
)
return hpd_array
func_kwargs = {
"credible_interval": credible_interval,
"skipna": skipna,
percygautam marked this conversation as resolved.
Show resolved Hide resolved
"out_shape": (max_modes, 2,) if multimodal else (2,),
}
kwargs.setdefault("output_core_dims", [["hpd", "mode"] if multimodal else ["hpd"]])
if not multimodal:
func_kwargs["circular"] = circular
else:
func_kwargs["max_modes"] = max_modes

if multimodal:
if skipna:
ary = ary[~np.isnan(ary)]
func = _hpd_multimodal if multimodal else _hpd

if ary.dtype.kind == "f":
density, lower, upper = _fast_kde(ary)
range_x = upper - lower
dx = range_x / len(density)
bins = np.linspace(lower, upper, len(density))
else:
bins = get_bins(ary)
_, density, _ = histogram(ary, bins=bins)
dx = np.diff(bins)[0]
isarray = isinstance(ary, np.ndarray)
if isarray and ary.ndim <= 1:
func_kwargs.pop("out_shape")
hpd_data = func(ary, **func_kwargs) # pylint: disable=unexpected-keyword-arg
return hpd_data[~np.isnan(hpd_data).all(axis=1), :] if multimodal else hpd_data

density *= dx
if isarray and ary.ndim == 2:
kwargs.setdefault("input_core_dims", [["chain"]])

idx = np.argsort(-density)
intervals = bins[idx][density[idx].cumsum() <= credible_interval]
intervals.sort()
ary = convert_to_dataset(ary, group=group)
if coords is not None:
ary = get_coords(ary, coords)
var_names = _var_names(var_names, ary)
ary = ary[var_names] if var_names else ary
OriolAbril marked this conversation as resolved.
Show resolved Hide resolved

hpd_data = _wrap_xarray_ufunc(func, ary, func_kwargs=func_kwargs, **kwargs)
hpd_data = hpd_data.dropna("mode", how="all") if multimodal else hpd_data
return hpd_data.x.values if isarray else hpd_data


def _hpd(ary, credible_interval, circular, skipna):
"""Compute hpd over the flattened array."""
ary = ary.flatten()
if skipna:
nans = np.isnan(ary)
if not nans.all():
ary = ary[~nans]
n = len(ary)

intervals_splitted = np.split(intervals, np.where(np.diff(intervals) >= dx * 1.1)[0] + 1)
if circular:
mean = st.circmean(ary, high=np.pi, low=-np.pi)
ary = ary - mean
ary = np.arctan2(np.sin(ary), np.cos(ary))

hpd_intervals = []
for interval in intervals_splitted:
if interval.size == 0:
hpd_intervals.append((bins[0], bins[0]))
else:
hpd_intervals.append((interval[0], interval[-1]))
ary = np.sort(ary)
interval_idx_inc = int(np.floor(credible_interval * n))
n_intervals = n - interval_idx_inc
interval_width = ary[interval_idx_inc:] - ary[:n_intervals]

hpd_intervals = np.array(hpd_intervals)
if len(interval_width) == 0:
raise ValueError("Too few elements for interval calculation. ")

else:
if skipna:
nans = np.isnan(ary)
if not nans.all():
ary = ary[~nans]
n = len(ary)
min_idx = np.argmin(interval_width)
hdi_min = ary[min_idx]
hdi_max = ary[min_idx + interval_idx_inc]

if circular:
mean = st.circmean(ary, high=np.pi, low=-np.pi)
ary = ary - mean
ary = np.arctan2(np.sin(ary), np.cos(ary))
if circular:
hdi_min = hdi_min + mean
hdi_max = hdi_max + mean
hdi_min = np.arctan2(np.sin(hdi_min), np.cos(hdi_min))
hdi_max = np.arctan2(np.sin(hdi_max), np.cos(hdi_max))

ary = np.sort(ary)
interval_idx_inc = int(np.floor(credible_interval * n))
n_intervals = n - interval_idx_inc
interval_width = ary[interval_idx_inc:] - ary[:n_intervals]
hpd_intervals = np.array([hdi_min, hdi_max])

if len(interval_width) == 0:
raise ValueError("Too few elements for interval calculation. ")
return hpd_intervals

min_idx = np.argmin(interval_width)
hdi_min = ary[min_idx]
hdi_max = ary[min_idx + interval_idx_inc]

if circular:
hdi_min = hdi_min + mean
hdi_max = hdi_max + mean
hdi_min = np.arctan2(np.sin(hdi_min), np.cos(hdi_min))
hdi_max = np.arctan2(np.sin(hdi_max), np.cos(hdi_max))
def _hpd_multimodal(ary, credible_interval, skipna, max_modes):
"""Compute hpd if the distribution is multimodal."""
ary = ary.flatten()
if skipna:
ary = ary[~np.isnan(ary)]

hpd_intervals = np.array([hdi_min, hdi_max])
if ary.dtype.kind == "f":
density, lower, upper = _fast_kde(ary)
range_x = upper - lower
dx = range_x / len(density)
bins = np.linspace(lower, upper, len(density))
else:
bins = get_bins(ary)
_, density, _ = histogram(ary, bins=bins)
dx = np.diff(bins)[0]

return hpd_intervals
density *= dx

idx = np.argsort(-density)
intervals = bins[idx][density[idx].cumsum() <= credible_interval]
intervals.sort()

intervals_splitted = np.split(intervals, np.where(np.diff(intervals) >= dx * 1.1)[0] + 1)

hpd_intervals = np.full((max_modes, 2,), np.nan,)
for i, interval in enumerate(intervals_splitted):
if i == max_modes:
warnings.warn(
"found more modes than {0}, returning only the first {0} modes", max_modes
)
break
if interval.size == 0:
hpd_intervals[i] = np.asarray([bins[0], bins[0]])
else:
hpd_intervals[i] = np.asarray([interval[0], interval[-1]])

return np.array(hpd_intervals)


def loo(data, pointwise=False, reff=None, scale=None):
Expand Down
46 changes: 46 additions & 0 deletions arviz/tests/base_tests/test_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,52 @@ def test_hpd():
assert_array_almost_equal(interval, [-1.88, 1.88], 2)


def test_hpd_2darray():
normal_sample = np.random.randn(12000, 5)
result = hpd(normal_sample)
assert result.shape == (5, 2,)


def test_hpd_multidimension():
percygautam marked this conversation as resolved.
Show resolved Hide resolved
normal_sample = np.random.randn(12000, 10, 3)
result = hpd(normal_sample)
assert result.shape == (3, 2,)


def test_hpd_idata(centered_eight):
data = centered_eight.posterior
result = hpd(data)
assert isinstance(result, Dataset)
assert result.dims == {"school": 8, "hpd": 2}

result = hpd(data, input_core_dims=[["chain"]])
assert isinstance(result, Dataset)
assert result.dims == {"draw": 500, "hpd": 2, "school": 8}


def test_hpd_idata_varnames(centered_eight):
data = centered_eight.posterior
result = hpd(data, var_names=["mu", "theta"])
assert isinstance(result, Dataset)
assert result.dims == {"hpd": 2, "school": 8}
assert list(result.data_vars.keys()) == ["mu", "theta"]

percygautam marked this conversation as resolved.
Show resolved Hide resolved

def test_hpd_idata_group(centered_eight):
result_posterior = hpd(centered_eight, group="posterior", var_names="mu")
result_prior = hpd(centered_eight, group="prior", var_names="mu")
assert result_prior.dims == {"hpd": 2}
range_posterior = result_posterior.mu.values[1] - result_posterior.mu.values[0]
range_prior = result_prior.mu.values[1] - result_prior.mu.values[0]
assert range_posterior < range_prior


def test_hpd_coords(centered_eight):
data = centered_eight.posterior
result = hpd(data, coords={"chain": [0, 1, 3]}, input_core_dims=[["draw"]])
assert_array_equal(result.coords["chain"], [0, 1, 3])


def test_hpd_multimodal():
normal_sample = np.concatenate(
(np.random.normal(-4, 1, 2500000), np.random.normal(2, 0.5, 2500000))
Expand Down