Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revamping HPD #1117

Merged
merged 24 commits into from
Apr 2, 2020
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
140 changes: 75 additions & 65 deletions arviz/stats/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ def _ic_matrix(ics, ic_i):
return rows, cols, ic_i_val


def hpd(ary, credible_interval=None, circular=False, multimodal=False, skipna=False):
def hpd(ary, *, credible_interval=None, circular=False, multimodal=False, skipna=False, **kwargs):
percygautam marked this conversation as resolved.
Show resolved Hide resolved
"""
Calculate highest posterior density (HPD) of array for given credible_interval.

Expand All @@ -315,6 +315,8 @@ def hpd(ary, credible_interval=None, circular=False, multimodal=False, skipna=Fa
----------
ary : Numpy array
An array containing posterior samples
group : List
A list containing the dimensions over which to compute hpd
credible_interval : float, optional
Credible interval to compute. Defaults to 0.94.
circular : bool, optional
Expand Down Expand Up @@ -349,84 +351,92 @@ def hpd(ary, credible_interval=None, circular=False, multimodal=False, skipna=Fa
if not 1 >= credible_interval > 0:
raise ValueError("The value of credible_interval should be in the interval (0, 1]")

if ary.ndim > 1:
hpd_array = np.array(
[
hpd(
row,
credible_interval=credible_interval,
circular=circular,
multimodal=multimodal,
)
for row in ary.T
]
)
return hpd_array

if multimodal:
if skipna:
ary = ary[~np.isnan(ary)]

if ary.dtype.kind == "f":
density, lower, upper = _fast_kde(ary)
range_x = upper - lower
dx = range_x / len(density)
bins = np.linspace(lower, upper, len(density))
else:
bins = get_bins(ary)
_, density, _ = histogram(ary, bins=bins)
dx = np.diff(bins)[0]
return _hpd_multimodal(ary, credible_interval, skipna)

if isinstance(ary, np.ndarray):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This branch should only be taken if the array is 1d or 2d:

isarray = isinstance(ary, np.ndarray)
if isarray and ary.ndim <= 2:

If the array has 3 or more dimensions, it should assume ArviZ dim order: (chain, draw, *shape). hpd should still return a numpy array though:

...
hpd_data = _wrap_xarray_ufunc(func, ary, func_kwargs=func_kwargs, **kwargs)
hpd_data = hpd_data.dropna("mode", how="all") if multimodal else hpd_data
return hpd_data.x.values if isarray else hpd_data

ary = convert_to_dataset(ary)
percygautam marked this conversation as resolved.
Show resolved Hide resolved

kwargs.setdefault("input_core_dims", [["chain", "draw"]])
kwargs.setdefault("output_core_dims", [["hpd"]])
func_kwargs = {
"credible_interval": credible_interval,
"circular": circular,
"skipna": skipna,
"out_shape": (2,),
}
return _wrap_xarray_ufunc(_hpd, ary, func_kwargs=func_kwargs, **kwargs)
percygautam marked this conversation as resolved.
Show resolved Hide resolved

density *= dx

idx = np.argsort(-density)
intervals = bins[idx][density[idx].cumsum() <= credible_interval]
intervals.sort()
def _hpd(ary, credible_interval, circular, skipna):
"""Compute hpd over the flattened array."""
ary = ary.flatten()
if skipna:
nans = np.isnan(ary)
if not nans.all():
ary = ary[~nans]
n = len(ary)

intervals_splitted = np.split(intervals, np.where(np.diff(intervals) >= dx * 1.1)[0] + 1)
if circular:
mean = st.circmean(ary, high=np.pi, low=-np.pi)
ary = ary - mean
ary = np.arctan2(np.sin(ary), np.cos(ary))

hpd_intervals = []
for interval in intervals_splitted:
if interval.size == 0:
hpd_intervals.append((bins[0], bins[0]))
else:
hpd_intervals.append((interval[0], interval[-1]))
ary = np.sort(ary)
interval_idx_inc = int(np.floor(credible_interval * n))
n_intervals = n - interval_idx_inc
interval_width = ary[interval_idx_inc:] - ary[:n_intervals]

hpd_intervals = np.array(hpd_intervals)
if len(interval_width) == 0:
raise ValueError("Too few elements for interval calculation. ")

else:
if skipna:
nans = np.isnan(ary)
if not nans.all():
ary = ary[~nans]
n = len(ary)
min_idx = np.argmin(interval_width)
hdi_min = ary[min_idx]
hdi_max = ary[min_idx + interval_idx_inc]

if circular:
mean = st.circmean(ary, high=np.pi, low=-np.pi)
ary = ary - mean
ary = np.arctan2(np.sin(ary), np.cos(ary))
if circular:
hdi_min = hdi_min + mean
hdi_max = hdi_max + mean
hdi_min = np.arctan2(np.sin(hdi_min), np.cos(hdi_min))
hdi_max = np.arctan2(np.sin(hdi_max), np.cos(hdi_max))

ary = np.sort(ary)
interval_idx_inc = int(np.floor(credible_interval * n))
n_intervals = n - interval_idx_inc
interval_width = ary[interval_idx_inc:] - ary[:n_intervals]
hpd_intervals = np.array([hdi_min, hdi_max])

if len(interval_width) == 0:
raise ValueError("Too few elements for interval calculation. ")
return hpd_intervals

min_idx = np.argmin(interval_width)
hdi_min = ary[min_idx]
hdi_max = ary[min_idx + interval_idx_inc]

if circular:
hdi_min = hdi_min + mean
hdi_max = hdi_max + mean
hdi_min = np.arctan2(np.sin(hdi_min), np.cos(hdi_min))
hdi_max = np.arctan2(np.sin(hdi_max), np.cos(hdi_max))
def _hpd_multimodal(ary, credible_interval, skipna):
"""Compute hpd if the distribution is multimodal."""
percygautam marked this conversation as resolved.
Show resolved Hide resolved

hpd_intervals = np.array([hdi_min, hdi_max])
if skipna:
ary = ary[~np.isnan(ary)]

return hpd_intervals
if ary.dtype.kind == "f":
density, lower, upper = _fast_kde(ary)
range_x = upper - lower
dx = range_x / len(density)
bins = np.linspace(lower, upper, len(density))
else:
bins = get_bins(ary)
_, density, _ = histogram(ary, bins=bins)
dx = np.diff(bins)[0]

density *= dx

idx = np.argsort(-density)
intervals = bins[idx][density[idx].cumsum() <= credible_interval]
intervals.sort()

intervals_splitted = np.split(intervals, np.where(np.diff(intervals) >= dx * 1.1)[0] + 1)

hpd_intervals = []
percygautam marked this conversation as resolved.
Show resolved Hide resolved
for interval in intervals_splitted:
if interval.size == 0:
hpd_intervals.append((bins[0], bins[0]))
else:
hpd_intervals.append((interval[0], interval[-1]))
percygautam marked this conversation as resolved.
Show resolved Hide resolved

return np.array(hpd_intervals)


def loo(data, pointwise=False, reff=None, scale=None):
Expand Down
6 changes: 3 additions & 3 deletions arviz/tests/base_tests/test_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def non_centered_eight():
def test_hpd():
normal_sample = np.random.randn(5000000)
interval = hpd(normal_sample)
assert_array_almost_equal(interval, [-1.88, 1.88], 2)
assert_array_almost_equal(interval.x.values, [-1.88, 1.88], 2)

percygautam marked this conversation as resolved.
Show resolved Hide resolved

def test_hpd_multimodal():
Expand All @@ -58,7 +58,7 @@ def test_hpd_multimodal():
def test_hpd_circular():
normal_sample = np.random.vonmises(np.pi, 1, 5000000)
interval = hpd(normal_sample, circular=True)
assert_array_almost_equal(interval, [0.6, -0.6], 1)
assert_array_almost_equal(interval.x.values, [0.6, -0.6], 1)


def test_hpd_bad_ci():
Expand All @@ -72,7 +72,7 @@ def test_hpd_skipna():
interval = hpd(normal_sample[10:])
normal_sample[:10] = np.nan
interval_ = hpd(normal_sample, skipna=True)
assert_array_almost_equal(interval, interval_)
assert_array_almost_equal(interval.x.values, interval_.x.values)


def test_r2_score():
Expand Down