Skip to content

Commit

Permalink
Revamping HPD (#1117)
Browse files Browse the repository at this point in the history
* revamped hpd function

* corrected hpd errors

* add docstring

* linting change

* minor nits

* allow input to be ndarray/dataset

* add tests for unimodal case

* test group and other minor nits for unimodal

* corrected failed tests

* added sel argument and its tests

* minor nits

* pydocstyle change

* minor changes

* hpd multimodal

* changes to hpd multimodal

* minor nits

* nits

* final nits

* add to changelog

* changes

* final changes
  • Loading branch information
percygautam authored Apr 2, 2020
1 parent 5c85b77 commit aca28da
Show file tree
Hide file tree
Showing 3 changed files with 198 additions and 69 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,12 @@
* Add `num_chains` and `pred_dims` arguments to io_pyro #1090
* Integrate jointplot into pairplot, add point-estimate and overlay of plot kinds (#1079)
* Allow xarray.Dataarray input for plots.(#1120)
* Revamped the `hpd` function to make it work with multidimensional arrays, InferenceData and xarray objects (#1117)
* Skip test for optional/extra dependencies when not installed (#1113)
### Maintenance and fixes
* Fixed behaviour of `credible_interval=None` in `plot_posterior` (#1115)
* Fixed hist kind of `plot_dist` with multidimensional input (#1115)
* Fixed `TypeError` in `transform` argument of `plot_density` and `plot_forest` when `InferenceData is a list or tuple (#1121)`
* Fixed `TypeError` in `transform` argument of `plot_density` and `plot_forest` when `InferenceData` is a list or tuple (#1121)
### Deprecation

### Documentation
Expand Down
218 changes: 150 additions & 68 deletions arviz/stats/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from scipy.optimize import minimize
import xarray as xr

from ..plots.plot_utils import _fast_kde, get_bins
from ..plots.plot_utils import _fast_kde, get_bins, get_coords
from ..data import convert_to_inference_data, convert_to_dataset, InferenceData, CoordSpec, DimSpec
from .diagnostics import _multichain_statistics, _mc_error, ess
from .stats_utils import (
Expand Down Expand Up @@ -305,16 +305,29 @@ def _ic_matrix(ics, ic_i):
return rows, cols, ic_i_val


def hpd(
    ary,
    credible_interval=None,
    circular=False,
    multimodal=False,
    skipna=False,
    group="posterior",
    var_names=None,
    coords=None,
    max_modes=10,
    **kwargs
):
    """
    Calculate highest posterior density (HPD) of array for given credible_interval.

    The HPD is the minimum width Bayesian credible interval (BCI).

    Parameters
    ----------
    ary : obj
        object containing posterior samples.
        Any object that can be converted to an az.InferenceData object.
        Refer to documentation of az.convert_to_dataset for details.
    credible_interval : float, optional
        Credible interval to compute. Defaults to 0.94.
    circular : bool, optional
        Whether to compute the hpd taking into account that the values are circular
        (i.e. in the range [-np.pi, np.pi]). Defaults to False. Only used when
        ``multimodal`` is False.
    multimodal : bool, optional
        If true it may compute more than one hpd interval if the distribution is
        multimodal and the modes are well separated.
    skipna : bool, optional
        If true ignores nan values when computing the hpd interval. Defaults to false.
    group : str, optional
        Specifies which InferenceData group should be used to calculate hpd.
        Defaults to 'posterior'.
    var_names : list, optional
        Names of variables to include in the hpd report.
    coords : mapping, optional
        Specifies the subset over which to calculate hpd.
    max_modes : int, optional
        Specifies the maximum number of modes for the multimodal case.
    kwargs : dict, optional
        Additional keywords passed to `wrap_xarray_ufunc`.
        See the docstring of :obj:`wrap_xarray_ufunc method </.stats_utils.wrap_xarray_ufunc>`.

    Returns
    -------
    np.ndarray or xarray.Dataset, depending upon input
        lower(s) and upper(s) values of the interval(s).

    Examples
    --------
    Calculate the hpd of a Normal random variable:

    .. ipython::

        In [1]: import arviz as az
           ...: import numpy as np
           ...: data = np.random.normal(size=2000)
           ...: az.hpd(data, credible_interval=.68)

    Calculate the hpd of a dataset:

    .. ipython::

        In [1]: import arviz as az
           ...: data = az.load_arviz_data('centered_eight')
           ...: az.hpd(data)

    We can also calculate the hpd of some of the variables of dataset:

    .. ipython::

        In [1]: az.hpd(data, var_names=["mu", "theta"])

    If we want to calculate the hpd over specified dimension of dataset,
    we can pass `input_core_dims` by kwargs:

    .. ipython::

        In [1]: az.hpd(data, input_core_dims = [["chain"]])

    We can also calculate the hpd over a particular selection over all groups:

    .. ipython::

        In [1]: az.hpd(data, coords={"chain":[0, 1, 3]}, input_core_dims = [["draw"]])

    """
    if credible_interval is None:
        credible_interval = rcParams["stats.credible_interval"]
    elif not 1 >= credible_interval > 0:
        raise ValueError("The value of credible_interval should be in the interval (0, 1]")

    # Keyword arguments forwarded to the per-sample worker (_hpd or _hpd_multimodal).
    # The multimodal worker returns up to `max_modes` (lower, upper) pairs, the
    # unimodal one a single pair.
    func_kwargs = {
        "credible_interval": credible_interval,
        "skipna": skipna,
        "out_shape": (max_modes, 2) if multimodal else (2,),
    }
    kwargs.setdefault("output_core_dims", [["hpd", "mode"] if multimodal else ["hpd"]])
    if not multimodal:
        func_kwargs["circular"] = circular
    else:
        func_kwargs["max_modes"] = max_modes

    func = _hpd_multimodal if multimodal else _hpd

    isarray = isinstance(ary, np.ndarray)
    if isarray and ary.ndim <= 1:
        # Plain 1-d (or 0-d) array: call the worker directly, no xarray machinery.
        func_kwargs.pop("out_shape")
        hpd_data = func(ary, **func_kwargs)  # pylint: disable=unexpected-keyword-arg
        # For the multimodal case drop the all-nan placeholder rows (unused modes).
        return hpd_data[~np.isnan(hpd_data).all(axis=1), :] if multimodal else hpd_data

    if isarray and ary.ndim == 2:
        # 2-d arrays are interpreted as (draw, chain) after conversion; reduce
        # over the chain dimension by default.
        kwargs.setdefault("input_core_dims", [["chain"]])

    ary = convert_to_dataset(ary, group=group)
    if coords is not None:
        ary = get_coords(ary, coords)
    var_names = _var_names(var_names, ary)
    ary = ary[var_names] if var_names else ary

    hpd_data = _wrap_xarray_ufunc(func, ary, func_kwargs=func_kwargs, **kwargs)
    hpd_data = hpd_data.dropna("mode", how="all") if multimodal else hpd_data
    # ndarray inputs are stored in the converted dataset under the name "x";
    # unwrap back to a bare ndarray so array in -> array out.
    return hpd_data.x.values if isarray else hpd_data


def _hpd(ary, credible_interval, circular, skipna):
"""Compute hpd over the flattened array."""
ary = ary.flatten()
if skipna:
nans = np.isnan(ary)
if not nans.all():
ary = ary[~nans]
n = len(ary)

intervals_splitted = np.split(intervals, np.where(np.diff(intervals) >= dx * 1.1)[0] + 1)
if circular:
mean = st.circmean(ary, high=np.pi, low=-np.pi)
ary = ary - mean
ary = np.arctan2(np.sin(ary), np.cos(ary))

hpd_intervals = []
for interval in intervals_splitted:
if interval.size == 0:
hpd_intervals.append((bins[0], bins[0]))
else:
hpd_intervals.append((interval[0], interval[-1]))
ary = np.sort(ary)
interval_idx_inc = int(np.floor(credible_interval * n))
n_intervals = n - interval_idx_inc
interval_width = ary[interval_idx_inc:] - ary[:n_intervals]

hpd_intervals = np.array(hpd_intervals)
if len(interval_width) == 0:
raise ValueError("Too few elements for interval calculation. ")

else:
if skipna:
nans = np.isnan(ary)
if not nans.all():
ary = ary[~nans]
n = len(ary)
min_idx = np.argmin(interval_width)
hdi_min = ary[min_idx]
hdi_max = ary[min_idx + interval_idx_inc]

if circular:
mean = st.circmean(ary, high=np.pi, low=-np.pi)
ary = ary - mean
ary = np.arctan2(np.sin(ary), np.cos(ary))
if circular:
hdi_min = hdi_min + mean
hdi_max = hdi_max + mean
hdi_min = np.arctan2(np.sin(hdi_min), np.cos(hdi_min))
hdi_max = np.arctan2(np.sin(hdi_max), np.cos(hdi_max))

ary = np.sort(ary)
interval_idx_inc = int(np.floor(credible_interval * n))
n_intervals = n - interval_idx_inc
interval_width = ary[interval_idx_inc:] - ary[:n_intervals]
hpd_intervals = np.array([hdi_min, hdi_max])

if len(interval_width) == 0:
raise ValueError("Too few elements for interval calculation. ")
return hpd_intervals

min_idx = np.argmin(interval_width)
hdi_min = ary[min_idx]
hdi_max = ary[min_idx + interval_idx_inc]

if circular:
hdi_min = hdi_min + mean
hdi_max = hdi_max + mean
hdi_min = np.arctan2(np.sin(hdi_min), np.cos(hdi_min))
hdi_max = np.arctan2(np.sin(hdi_max), np.cos(hdi_max))
def _hpd_multimodal(ary, credible_interval, skipna, max_modes):
    """Compute hpd if the distribution is multimodal.

    Estimates the density (via KDE for floats, histogram otherwise), keeps the
    highest-density bins until ``credible_interval`` mass is accumulated, and
    splits them into contiguous intervals, one per mode.

    Parameters
    ----------
    ary : np.ndarray
        Posterior samples; flattened before use.
    credible_interval : float
        Probability mass the union of intervals must contain, in (0, 1].
    skipna : bool
        Drop nan values before computing the intervals.
    max_modes : int
        Maximum number of modes to return; extra modes are dropped with a warning.

    Returns
    -------
    np.ndarray
        Array of shape ``(max_modes, 2)`` with one ``[lower, upper]`` row per
        detected mode; unused rows are filled with nan.
    """
    ary = ary.flatten()
    if skipna:
        ary = ary[~np.isnan(ary)]

    if ary.dtype.kind == "f":
        density, lower, upper = _fast_kde(ary)
        range_x = upper - lower
        dx = range_x / len(density)
        bins = np.linspace(lower, upper, len(density))
    else:
        # Discrete samples: use histogram bins instead of a KDE grid.
        bins = get_bins(ary)
        _, density, _ = histogram(ary, bins=bins)
        dx = np.diff(bins)[0]

    # Convert density values to per-bin probability mass.
    density *= dx

    # Greedily keep the highest-density bins until the requested mass is covered.
    idx = np.argsort(-density)
    intervals = bins[idx][density[idx].cumsum() <= credible_interval]
    intervals.sort()

    # A gap larger than one bin width (with 10% tolerance) separates two modes.
    intervals_splitted = np.split(intervals, np.where(np.diff(intervals) >= dx * 1.1)[0] + 1)

    hpd_intervals = np.full((max_modes, 2), np.nan)
    for i, interval in enumerate(intervals_splitted):
        if i == max_modes:
            # Bug fix: the message was never formatted and max_modes was passed
            # as the warning *category*, raising TypeError at runtime.
            warnings.warn(
                "found more modes than {0}, returning only the first {0} modes".format(max_modes)
            )
            break
        if interval.size == 0:
            hpd_intervals[i] = np.asarray([bins[0], bins[0]])
        else:
            hpd_intervals[i] = np.asarray([interval[0], interval[-1]])

    return np.array(hpd_intervals)


def loo(data, pointwise=False, reff=None, scale=None):
Expand Down
46 changes: 46 additions & 0 deletions arviz/tests/base_tests/test_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,52 @@ def test_hpd():
assert_array_almost_equal(interval, [-1.88, 1.88], 2)


def test_hpd_2darray():
    """hpd on a (draw, chain) 2d array yields one interval per chain."""
    samples = np.random.randn(12000, 5)
    intervals = hpd(samples)
    assert intervals.shape == (5, 2)


def test_hpd_multidimension():
    """hpd on a 3d array reduces the sample dims, keeping the trailing dim."""
    samples = np.random.randn(12000, 10, 3)
    intervals = hpd(samples)
    assert intervals.shape == (3, 2)


def test_hpd_idata(centered_eight):
    """hpd on a Dataset returns a Dataset; input_core_dims picks the reduced dim."""
    posterior = centered_eight.posterior

    out = hpd(posterior)
    assert isinstance(out, Dataset)
    assert out.dims == {"school": 8, "hpd": 2}

    out = hpd(posterior, input_core_dims=[["chain"]])
    assert isinstance(out, Dataset)
    assert out.dims == {"draw": 500, "hpd": 2, "school": 8}


def test_hpd_idata_varnames(centered_eight):
    """var_names restricts the hpd output to the selected variables."""
    posterior = centered_eight.posterior
    out = hpd(posterior, var_names=["mu", "theta"])
    assert isinstance(out, Dataset)
    assert out.dims == {"hpd": 2, "school": 8}
    assert list(out.data_vars.keys()) == ["mu", "theta"]


def test_hpd_idata_group(centered_eight):
    """The posterior hpd for mu must be narrower than the prior hpd."""
    posterior_hpd = hpd(centered_eight, group="posterior", var_names="mu")
    prior_hpd = hpd(centered_eight, group="prior", var_names="mu")
    assert prior_hpd.dims == {"hpd": 2}
    posterior_width = posterior_hpd.mu.values[1] - posterior_hpd.mu.values[0]
    prior_width = prior_hpd.mu.values[1] - prior_hpd.mu.values[0]
    assert posterior_width < prior_width


def test_hpd_coords(centered_eight):
    """coords subsets the data before the hpd is computed."""
    posterior = centered_eight.posterior
    out = hpd(posterior, coords={"chain": [0, 1, 3]}, input_core_dims=[["draw"]])
    assert_array_equal(out.coords["chain"], [0, 1, 3])


def test_hpd_multimodal():
normal_sample = np.concatenate(
(np.random.normal(-4, 1, 2500000), np.random.normal(2, 0.5, 2500000))
Expand Down

0 comments on commit aca28da

Please sign in to comment.