def _logsumexp(ary, *, b=None, b_inv=None, axis=None, keepdims=False, out=None, copy=True):
    """Compute ``log(b * sum(exp(ary)))`` in a numerically stable way.

    The maximum along the reduced axes is subtracted before exponentiation
    and added back afterwards, so large inputs do not overflow.

    Parameters
    ----------
    ary : array_like
        Input values. Integer, unsigned and boolean input is promoted to
        ``float64``; floating input keeps its dtype.
    b : scalar, optional
        Non-negative scaling factor applied to the sum. ``b == 0`` yields
        ``-inf``. Ignored when ``b_inv`` is given.
    b_inv : scalar, optional
        Inverse scaling factor, i.e. the sum is divided by ``b_inv``.
        Overrides ``b`` unless it is None. ``b_inv == 0`` yields ``+inf``.
    axis : None, int or sequence of int, optional
        Axis or axes to reduce over; negative values count from the last
        axis. None reduces over all axes.
    keepdims : bool, optional
        If True, reduced axes are kept in the result with length one.
    out : numpy.ndarray, optional
        Array that receives the result. Its shape must match the shape of
        the reduction result.
    copy : bool, optional
        If False, ``ary`` may be overwritten with intermediate values when it
        is already a floating point array.

    Returns
    -------
    numpy.ndarray or numpy scalar
        ``out`` when the result has at least one dimension, otherwise a
        scalar of the working dtype.
    """
    ary = np.asarray(ary)
    # exp/log results cannot be stored in integer-like arrays (the in-place
    # ufuncs below would refuse the cast), so promote those to float64.
    if ary.dtype.kind in "iub":
        ary = ary.astype(np.float64)
        copy = False  # astype already returned a private copy
    dtype = ary.dtype.type
    shape = ary.shape
    ndim = len(shape)

    # normalise ``axis`` to non-negative indices and collect the reduced axes
    if axis is None:
        reduced = tuple(range(ndim))
    elif isinstance(axis, Sequence):
        axis = tuple(axis_i if axis_i >= 0 else ndim + axis_i for axis_i in axis)
        reduced = axis
    else:
        axis = axis if axis >= 0 else ndim + axis
        reduced = (axis,)

    # shape of the per-slice maximum (always keeps dims) and of the result
    shape_max = tuple(1 if i in reduced else dim for i, dim in enumerate(shape))
    if keepdims:
        out_shape = shape_max
    else:
        out_shape = tuple(dim for i, dim in enumerate(shape) if i not in reduced)
    if out is None:
        out = np.empty(out_shape, dtype=dtype)

    # Degenerate scale factors give a constant result; skip the reduction but
    # still honour ``out`` and the scalar return convention.
    constant = None
    if b_inv is not None:
        if b_inv == 0:
            constant = np.inf
    elif b is not None and b == 0:
        constant = -np.inf
    if constant is not None:
        out[...] = constant
        return out if out.shape else dtype(out)

    ary_max = np.empty(shape_max, dtype=dtype)
    ary.max(axis=axis, keepdims=True, out=ary_max)
    # A non-finite maximum would turn ``ary - ary_max`` into nan (inf - inf).
    # Shifting by zero instead lets exp/log propagate -inf/+inf/nan correctly.
    np.copyto(ary_max, 0, where=~np.isfinite(ary_max))
    if copy:
        ary = ary.copy()
    ary -= ary_max
    np.exp(ary, out=ary)
    ary.sum(axis=axis, keepdims=keepdims, out=out)
    # log(0) == -inf is the correct answer for slices that are entirely -inf
    with np.errstate(divide="ignore"):
        np.log(out, out=out)
    if b_inv is not None:
        ary_max -= np.log(b_inv)
    elif b is not None:
        ary_max += np.log(b)
    # reshape (not squeeze): squeeze would also drop non-reduced length-1 axes
    out += ary_max.reshape(out_shape)
    # transform to scalar if possible
    return out if out.shape else dtype(out)
@@ -346,11 +402,11 @@ def loo(data, pointwise=False, reff=None): ) warn_mg = 1 - loo_lppd_i = -2 * logsumexp(log_weights, axis=0) + loo_lppd_i = -2 * _logsumexp(log_weights, axis=0) loo_lppd = loo_lppd_i.sum() loo_lppd_se = (len(loo_lppd_i) * np.var(loo_lppd_i)) ** 0.5 - lppd = np.sum(logsumexp(log_likelihood, axis=0, b=1.0 / log_likelihood.shape[0])) + lppd = np.sum(_logsumexp(log_likelihood, axis=0, b_inv=log_likelihood.shape[0])) p_loo = lppd + (0.5 * loo_lppd) if pointwise: @@ -432,7 +488,7 @@ def psislw(log_weights, reff=1.0): # truncate smoothed values to the largest raw weight 0 x[x > 0] = 0 # renormalize weights - x -= logsumexp(x) + x -= _logsumexp(x) # store tail index k kss[i] = k @@ -845,7 +901,7 @@ def waic(data, pointwise=False): new_shape = (n_samples,) + log_likelihood.shape[2:] log_likelihood = log_likelihood.values.reshape(*new_shape) - lppd_i = logsumexp(log_likelihood, axis=0, b=1.0 / log_likelihood.shape[0]) + lppd_i = _logsumexp(log_likelihood, axis=0, b_inv=log_likelihood.shape[0]) vars_lpd = np.var(log_likelihood, axis=0) warn_mg = 0 diff --git a/arviz/tests/test_stats.py b/arviz/tests/test_stats.py index 16a040b2be..fa3abe8104 100644 --- a/arviz/tests/test_stats.py +++ b/arviz/tests/test_stats.py @@ -3,11 +3,13 @@ import numpy as np from numpy.testing import assert_almost_equal, assert_array_almost_equal, assert_array_less import pytest +from scipy.special import logsumexp from scipy.stats import linregress + from ..data import load_arviz_data from ..stats import bfmi, compare, hpd, loo, r2_score, waic, psislw, summary -from ..stats.stats import _gpinv, _mc_error +from ..stats.stats import _gpinv, _mc_error, _logsumexp @pytest.fixture(scope="session") @@ -140,11 +142,6 @@ def test_waic_bad(centered_eight): waic(centered_eight) -@pytest.mark.xfail( - reason="Issue #509. " - "Numerical accuracy (logsumexp) prevents function to throw a warning." 
@pytest.mark.parametrize("ary_dtype", [np.float64, np.float32, np.int32, np.int64])
@pytest.mark.parametrize("axis", [None, 0, 1, (-2, -1)])
@pytest.mark.parametrize("b", [None, 0, 1 / 100, 1 / 101])
@pytest.mark.parametrize("keepdims", [True, False])
def test_logsumexp_b(ary_dtype, axis, b, keepdims):
    """Check ``_logsumexp`` with the ``b`` argument against SciPy.

    The function is first smoke-tested with ``copy=True``, ``copy=False`` and
    a preallocated ``out`` array, then its result is compared with
    ``scipy.special.logsumexp`` for the same ``b``, ``axis`` and ``keepdims``.
    """
    np.random.seed(17)
    data = np.random.randn(100, 101).astype(ary_dtype)
    options = dict(axis=axis, b=b, keepdims=keepdims)

    assert _logsumexp(ary=data, copy=True, **options) is not None
    # copy=False allows _logsumexp to work in place on floating input, so it
    # gets its own copy; the SciPy comparison below uses that same array.
    data = data.copy()
    assert _logsumexp(ary=data, copy=False, **options) is not None
    buffer = np.empty(5)
    assert _logsumexp(ary=np.random.randn(10, 5), axis=0, out=buffer) is not None

    # Scipy implementation
    from_scipy = logsumexp(data, **options)
    from_arviz = _logsumexp(data, **options)

    assert_array_almost_equal(from_scipy, from_arviz)


@pytest.mark.parametrize("ary_dtype", [np.float64, np.float32, np.int32, np.int64])
@pytest.mark.parametrize("axis", [None, 0, 1, (-2, -1)])
@pytest.mark.parametrize("b_inv", [None, 0, 100, 101])
@pytest.mark.parametrize("keepdims", [True, False])
def test_logsumexp_b_inv(ary_dtype, axis, b_inv, keepdims):
    """Check ``_logsumexp`` with the ``b_inv`` argument against SciPy.

    The function is first smoke-tested with ``copy=True``, ``copy=False`` and
    a preallocated ``out`` array. Unless ``b_inv`` is zero, its result is then
    compared with ``scipy.special.logsumexp`` called with ``b = 1 / b_inv``
    (or ``b=None`` when ``b_inv`` is None).
    """
    np.random.seed(17)
    data = np.random.randn(100, 101).astype(ary_dtype)
    options = dict(axis=axis, keepdims=keepdims)

    assert _logsumexp(ary=data, b_inv=b_inv, copy=True, **options) is not None
    # copy=False allows _logsumexp to work in place on floating input, so it
    # gets its own copy; the SciPy comparison below uses that same array.
    data = data.copy()
    assert _logsumexp(ary=data, b_inv=b_inv, copy=False, **options) is not None
    buffer = np.empty(5)
    assert _logsumexp(ary=np.random.randn(10, 5), axis=0, out=buffer) is not None

    if b_inv == 0:
        # no SciPy comparison is made for b_inv == 0
        return

    # Scipy implementation when b_inv != 0
    b_scipy = None if b_inv is None else 1 / b_inv
    from_scipy = logsumexp(data, b=b_scipy, **options)
    from_arviz = _logsumexp(data, b_inv=b_inv, **options)

    assert_array_almost_equal(from_scipy, from_arviz)