From b36d9ba91403a5ac0bf719fb02d886f37ebdc851 Mon Sep 17 00:00:00 2001 From: Hartikainen Ari Date: Sat, 12 Jan 2019 01:46:28 +0200 Subject: [PATCH 01/11] add stable logsumexp --- .pylintrc | 3 ++- arviz/stats/stats.py | 41 ++++++++++++++++++++++++++++++++++++----- 2 files changed, 38 insertions(+), 6 deletions(-) diff --git a/.pylintrc b/.pylintrc index 4f6b41dfb8..c1fdab2c4d 100644 --- a/.pylintrc +++ b/.pylintrc @@ -233,7 +233,8 @@ function-naming-style=snake_case #function-rgx= # Good variable names which should always be accepted, separated by a comma -good-names=i, +good-names=b, + i, j, k, t, diff --git a/arviz/stats/stats.py b/arviz/stats/stats.py index 7837646e54..dfd0d290e4 100644 --- a/arviz/stats/stats.py +++ b/arviz/stats/stats.py @@ -1,9 +1,9 @@ """Statistical functions in ArviZ.""" import warnings +from collections.abc import Sequence import numpy as np import pandas as pd -from scipy.special import logsumexp import scipy.stats as st from scipy.optimize import minimize import xarray as xr @@ -281,6 +281,37 @@ def hpd(x, credible_interval=0.94, circular=False): return np.array([hdi_min, hdi_max]) +def _logsumexp(ary, *, b=0, b_inv=None, axis=None, keepdims=False, out=None, copy=True): + """stable logsumexp when b>0""" + shape = ary.shape + shape_len = len(shape) + if isinstance(axis, Sequence): + axis = tuple(axis_i if axis_i >= 0 else shape_len + axis_i for axis_i in axis) + else: + axis = (axis if axis >= 0 else shape_len + axis,) + shape_max = tuple(1 if i in axis else d for i, d in enumerate(shape)) + if out is None: + if not keepdims: + out_shape = tuple(d for i, d in enumerate(shape) if i not in axis) + else: + out_shape = shape_max + out = np.empty(out_shape) + ary_max = np.empty(shape_max) + ary.max(axis=axis, keepdims=True, out=ary_max) + if copy: + ary = ary.copy() + ary += ary_max + np.exp(ary, out=ary) + ary.sum(axis=axis, keepdims=keepdims, out=out) + np.log(out, out=out) + if b_inv is not None: + ary_max -= np.log(b_inv) + elif b: + ary_max += np.log(b) + out += ary_max.squeeze() if not keepdims else ary_max + return out + + def loo(data, pointwise=False, reff=None): """Pareto-smoothed importance sampling leave-one-out cross-validation. @@ -346,11 +377,11 @@ def loo(data, pointwise=False, reff=None): ) warn_mg = 1 - loo_lppd_i = -2 * logsumexp(log_weights, axis=0) + loo_lppd_i = -2 * _logsumexp(log_weights, axis=0) loo_lppd = loo_lppd_i.sum() loo_lppd_se = (len(loo_lppd_i) * np.var(loo_lppd_i)) ** 0.5 - lppd = np.sum(logsumexp(log_likelihood, axis=0, b=1.0 / log_likelihood.shape[0])) + lppd = np.sum(_logsumexp(log_likelihood, axis=0, b=1.0 / log_likelihood.shape[0])) p_loo = lppd + (0.5 * loo_lppd) if pointwise: @@ -432,7 +463,7 @@ def psislw(log_weights, reff=1.0): # truncate smoothed values to the largest raw weight 0 x[x > 0] = 0 # renormalize weights - x -= logsumexp(x) + x -= _logsumexp(x) # store tail index k kss[i] = k @@ -845,7 +876,7 @@ def waic(data, pointwise=False): new_shape = (n_samples,) + log_likelihood.shape[2:] log_likelihood = log_likelihood.values.reshape(*new_shape) - lppd_i = logsumexp(log_likelihood, axis=0, b=1.0 / log_likelihood.shape[0]) + lppd_i = _logsumexp(log_likelihood, axis=0, b=1.0 / log_likelihood.shape[0]) vars_lpd = np.var(log_likelihood, axis=0) warn_mg = 0 From b750bd51a3614889a56deca5bd9364a467a44327 Mon Sep 17 00:00:00 2001 From: Hartikainen Ari Date: Sat, 12 Jan 2019 01:51:16 +0200 Subject: [PATCH 02/11] use b_inv --- arviz/stats/stats.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/arviz/stats/stats.py b/arviz/stats/stats.py index dfd0d290e4..a100e424b1 100644 --- a/arviz/stats/stats.py +++ b/arviz/stats/stats.py @@ -283,6 +283,7 @@ def hpd(x, credible_interval=0.94, circular=False): def _logsumexp(ary, *, b=0, b_inv=None, axis=None, keepdims=False, out=None, copy=True): """stable logsumexp when b>0""" + ary = np.asarray(ary) shape = ary.shape shape_len = len(shape) if isinstance(axis, Sequence): @@ -381,7 +382,7 @@ def loo(data, pointwise=False, reff=None): loo_lppd = loo_lppd_i.sum() loo_lppd_se = (len(loo_lppd_i) * np.var(loo_lppd_i)) ** 0.5 - lppd = np.sum(_logsumexp(log_likelihood, axis=0, b=1.0 / log_likelihood.shape[0])) + lppd = np.sum(_logsumexp(log_likelihood, axis=0, b_inv=log_likelihood.shape[0])) p_loo = lppd + (0.5 * loo_lppd) if pointwise: @@ -876,7 +877,7 @@ def waic(data, pointwise=False): new_shape = (n_samples,) + log_likelihood.shape[2:] log_likelihood = log_likelihood.values.reshape(*new_shape) - lppd_i = _logsumexp(log_likelihood, axis=0, b=1.0 / log_likelihood.shape[0]) + lppd_i = _logsumexp(log_likelihood, axis=0, b_inv=log_likelihood.shape[0]) vars_lpd = np.var(log_likelihood, axis=0) warn_mg = 0 From 5e62b1d37cd4d1693ba2009b7be7e5f20eab1ad8 Mon Sep 17 00:00:00 2001 From: Hartikainen Ari Date: Sat, 12 Jan 2019 02:23:09 +0200 Subject: [PATCH 03/11] handle none --- arviz/stats/stats.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/arviz/stats/stats.py b/arviz/stats/stats.py index a100e424b1..86fd360bf0 100644 --- a/arviz/stats/stats.py +++ b/arviz/stats/stats.py @@ -282,14 +282,14 @@ def hpd(x, credible_interval=0.94, circular=False): def _logsumexp(ary, *, b=0, b_inv=None, axis=None, keepdims=False, out=None, copy=True): - """stable logsumexp when b>0""" + """Stable logsumexp when b >= 0.""" ary = np.asarray(ary) shape = ary.shape shape_len = len(shape) if isinstance(axis, Sequence): axis = tuple(axis_i if axis_i >= 0 else shape_len + axis_i for axis_i in axis) else: - axis = (axis if axis >= 0 else shape_len + axis,) + axis = (axis if (axis is None) or (axis >= 0) else shape_len + axis,) shape_max = tuple(1 if i in axis else d for i, d in enumerate(shape)) if out is None: if not keepdims: From 0ebc3315fb06d3f86d3216d0a70b958c7a18c25b Mon Sep 17 00:00:00 2001 From: Ari Hartikainen Date: Sat, 12 Jan 2019 07:07:50 +0200 Subject: [PATCH 04/11] Change comparison --- arviz/stats/stats.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/arviz/stats/stats.py b/arviz/stats/stats.py index 86fd360bf0..d67a83ee58 100644 --- a/arviz/stats/stats.py +++ b/arviz/stats/stats.py @@ -288,12 +288,14 @@ def _logsumexp(ary, *, b=0, b_inv=None, axis=None, keepdims=False, out=None, cop shape_len = len(shape) if isinstance(axis, Sequence): axis = tuple(axis_i if axis_i >= 0 else shape_len + axis_i for axis_i in axis) + agroup = axis else: - axis = (axis if (axis is None) or (axis >= 0) else shape_len + axis,) - shape_max = tuple(1 if i in axis else d for i, d in enumerate(shape)) + axis = axis if (axis is None) or (axis >= 0) else shape_len + axis + agroup = axis, + shape_max = tuple(1 if i in agroup else d for i, d in enumerate(shape)) if out is None: if not keepdims: - out_shape = tuple(d for i, d in enumerate(shape) if i not in axis) + out_shape = tuple(d for i, d in enumerate(shape) if i not in agroup) else: out_shape = shape_max out = np.empty(out_shape) From 4da0893b04271666d91679fa9cd5110a935d1097 Mon Sep 17 00:00:00 2001 From: Hartikainen Ari Date: Sat, 12 Jan 2019 09:43:34 +0200 Subject: [PATCH 05/11] fix dtype casting and axis is None --- arviz/stats/stats.py | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/arviz/stats/stats.py b/arviz/stats/stats.py index d67a83ee58..ab3b2f7d1a 100644 --- a/arviz/stats/stats.py +++ b/arviz/stats/stats.py @@ -284,6 +284,9 @@ def hpd(x, credible_interval=0.94, circular=False): def _logsumexp(ary, *, b=0, b_inv=None, axis=None, keepdims=False, out=None, copy=True): """Stable logsumexp when b >= 0.""" ary = np.asarray(ary) + if ary.dtype.kind == "i": + ary = ary.astype(np.float64) + dtype = ary.dtype.type shape = ary.shape shape_len = len(shape) if isinstance(axis, Sequence): @@ -291,15 +294,23 @@ def _logsumexp(ary, *, b=0, b_inv=None, axis=None, keepdims=False, out=None, cop agroup = axis else: axis = axis if (axis is None) or (axis >= 0) else shape_len + axis - agroup = axis, - shape_max = tuple(1 if i in agroup else d for i, d in enumerate(shape)) + agroup = (axis,) + shape_max = ( + tuple(1 for _ in shape) + if axis is None + else tuple(1 if i in agroup else d for i, d in enumerate(shape)) + ) if out is None: if not keepdims: - out_shape = tuple(d for i, d in enumerate(shape) if i not in agroup) + out_shape = ( + tuple() + if axis is None + else tuple(d for i, d in enumerate(shape) if i not in agroup) + ) else: out_shape = shape_max - out = np.empty(out_shape) - ary_max = np.empty(shape_max) + out = np.empty(out_shape, dtype=dtype) + ary_max = np.empty(shape_max, dtype=dtype) ary.max(axis=axis, keepdims=True, out=ary_max) if copy: ary = ary.copy() @@ -312,7 +323,7 @@ def _logsumexp(ary, *, b=0, b_inv=None, axis=None, keepdims=False, out=None, cop elif b: ary_max += np.log(b) out += ary_max.squeeze() if not keepdims else ary_max - return out + return out if out.shape else dtype(out) def loo(data, pointwise=False, reff=None): From 7851f79881e8add5f6c8d3f881ad033d376d4d96 Mon Sep 17 00:00:00 2001 From: Hartikainen Ari Date: Sat, 12 Jan 2019 10:51:44 +0200 Subject: [PATCH 06/11] fix test back and fix logsumexp missing subtraction --- arviz/stats/stats.py | 2 +- arviz/tests/test_stats.py | 5 ----- 2 files changed, 1 insertion(+), 6 deletions(-) diff --git a/arviz/stats/stats.py b/arviz/stats/stats.py index ab3b2f7d1a..e9d6b30549 100644 --- a/arviz/stats/stats.py +++ b/arviz/stats/stats.py @@ -314,7 +314,7 @@ def _logsumexp(ary, *, b=0, b_inv=None, axis=None, keepdims=False, out=None, cop ary.max(axis=axis, keepdims=True, out=ary_max) if copy: ary = ary.copy() - ary += ary_max + ary -= ary_max np.exp(ary, out=ary) ary.sum(axis=axis, keepdims=keepdims, out=out) np.log(out, out=out) diff --git a/arviz/tests/test_stats.py b/arviz/tests/test_stats.py index 16a040b2be..7569e97db1 100644 --- a/arviz/tests/test_stats.py +++ b/arviz/tests/test_stats.py @@ -140,11 +140,6 @@ def test_waic_bad(centered_eight): waic(centered_eight) -@pytest.mark.xfail( - reason="Issue #509. " - "Numerical accuracy (logsumexp) prevents function to throw a warning." - "See https://github.com/arviz-devs/arviz/issues/509" -) def test_waic_warning(centered_eight): centered_eight = deepcopy(centered_eight) centered_eight.sample_stats["log_likelihood"][:, :250, 1] = 10 From 4e2c9a0d923ea4cab60a2f8a53668e209d3eafdd Mon Sep 17 00:00:00 2001 From: Hartikainen Ari Date: Sat, 12 Jan 2019 12:53:11 +0200 Subject: [PATCH 07/11] more tests --- arviz/tests/test_stats.py | 25 ++++++++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) diff --git a/arviz/tests/test_stats.py b/arviz/tests/test_stats.py index 7569e97db1..19369a3dbf 100644 --- a/arviz/tests/test_stats.py +++ b/arviz/tests/test_stats.py @@ -7,7 +7,7 @@ from ..data import load_arviz_data from ..stats import bfmi, compare, hpd, loo, r2_score, waic, psislw, summary -from ..stats.stats import _gpinv, _mc_error +from ..stats.stats import _gpinv, _mc_error, _logsumexp @pytest.fixture(scope="session") @@ -212,3 +212,26 @@ def test_gpinv(probs, kappa, sigma): else: probs = np.array([-0.1, 0.1, 0.1, 0.2, 0.3]) assert len(_gpinv(probs, kappa, sigma)) == len(probs) + +@pytest.mark.parametrize("ary" : np.random.randn(100,101).astype(np.float64)) +@pytest.mark.parametrize("ary" : np.random.randn(100,101).astype(np.float32)) +@pytest.mark.parametrize("ary" : np.random.randn(100,101).astype(np.int32)) +@pytest.mark.parametrize("ary" : np.random.randn(100,101).astype(np.int64)) +@pytest.mark.parametrize("axis" : None) +@pytest.mark.parametrize("axis" : 0) +@pytest.mark.parametrize("axis" : 1) +@pytest.mark.parametrize("axis" : (-2, -1)) +@pytest.mark.parametrize("b" : 0) +@pytest.mark.parametrize("b" : 1/100) +@pytest.mark.parametrize("b" : 1/101) +@pytest.mark.parametrize("b_inv" : None) +@pytest.mark.parametrize("b_inv" : 100) +@pytest.mark.parametrize("b_inv" : 101) +@pytest.mark.parametrize("keepdims" : 1/100) +@pytest.mark.parametrize("keepdims" : 1/101) +def test_logsumexp(ary, axis, b, b_inv): + assert _logsumexp(ary=ary, axis=axis, b=b, b_inv=b_inv, keepdims=keepdims, copy=True) not None + ary = ary.copy() + assert _logsumexp(ary=ary, axis=axis, b=b, b_inv=b_inv, keepdims=keepdims, copy=False) not None + out = np.empty(5) + assert _logsumexp(ary=np.random.randn(10,5), axis=0) From 35e342e8b4dd7ae07d343cb5c1846dd119024bc5 Mon Sep 17 00:00:00 2001 From: Hartikainen Ari Date: Sat, 12 Jan 2019 13:01:30 +0200 Subject: [PATCH 08/11] fix tests --- arviz/tests/test_stats.py | 35 +++++++++++++++-------------------- 1 file changed, 15 insertions(+), 20 deletions(-) diff --git a/arviz/tests/test_stats.py b/arviz/tests/test_stats.py index 19369a3dbf..27e0261157 100644 --- a/arviz/tests/test_stats.py +++ b/arviz/tests/test_stats.py @@ -213,25 +213,20 @@ def test_gpinv(probs, kappa, sigma): probs = np.array([-0.1, 0.1, 0.1, 0.2, 0.3]) assert len(_gpinv(probs, kappa, sigma)) == len(probs) -@pytest.mark.parametrize("ary" : np.random.randn(100,101).astype(np.float64)) -@pytest.mark.parametrize("ary" : np.random.randn(100,101).astype(np.float32)) -@pytest.mark.parametrize("ary" : np.random.randn(100,101).astype(np.int32)) -@pytest.mark.parametrize("ary" : np.random.randn(100,101).astype(np.int64)) -@pytest.mark.parametrize("axis" : None) -@pytest.mark.parametrize("axis" : 0) -@pytest.mark.parametrize("axis" : 1) -@pytest.mark.parametrize("axis" : (-2, -1)) -@pytest.mark.parametrize("b" : 0) -@pytest.mark.parametrize("b" : 1/100) -@pytest.mark.parametrize("b" : 1/101) -@pytest.mark.parametrize("b_inv" : None) -@pytest.mark.parametrize("b_inv" : 100) -@pytest.mark.parametrize("b_inv" : 101) -@pytest.mark.parametrize("keepdims" : 1/100) -@pytest.mark.parametrize("keepdims" : 1/101) -def test_logsumexp(ary, axis, b, b_inv): - assert _logsumexp(ary=ary, axis=axis, b=b, b_inv=b_inv, keepdims=keepdims, copy=True) not None + +@pytest.mark.parametrize("ary_dtype", [np.float64, np.float32, np.int32, np.int64]) +@pytest.mark.parametrize("axis", [None, 0, 1, (-2, -1)]) +@pytest.mark.parametrize("b", [None, 0, 1 / 100, 1 / 101]) +@pytest.mark.parametrize("b_inv", [None, 100, 101]) +@pytest.mark.parametrize("keepdims", [True, False]) +def test_logsumexp(ary_dtype, axis, b, b_inv, keepdims): + ary = np.random.randn(100, 101).astype(ary_dtype) + assert ( + _logsumexp(ary=ary, axis=axis, b=b, b_inv=b_inv, keepdims=keepdims, copy=True) is not None + ) ary = ary.copy() - assert _logsumexp(ary=ary, axis=axis, b=b, b_inv=b_inv, keepdims=keepdims, copy=False) not None + assert ( + _logsumexp(ary=ary, axis=axis, b=b, b_inv=b_inv, keepdims=keepdims, copy=False) is not None + ) out = np.empty(5) - assert _logsumexp(ary=np.random.randn(10,5), axis=0) + assert _logsumexp(ary=np.random.randn(10, 5), axis=0, out=out) is not None From 4697100890d8d3fcdc97067e7087d65b966009fa Mon Sep 17 00:00:00 2001 From: Hartikainen Ari Date: Sun, 13 Jan 2019 07:45:43 +0200 Subject: [PATCH 09/11] add test against scipy --- arviz/stats/stats.py | 4 ++++ arviz/tests/test_stats.py | 21 +++++++++++++++++++++ 2 files changed, 25 insertions(+) diff --git a/arviz/stats/stats.py b/arviz/stats/stats.py index e9d6b30549..e251cd2f6d 100644 --- a/arviz/stats/stats.py +++ b/arviz/stats/stats.py @@ -283,6 +283,7 @@ def hpd(x, credible_interval=0.94, circular=False): def _logsumexp(ary, *, b=0, b_inv=None, axis=None, keepdims=False, out=None, copy=True): """Stable logsumexp when b >= 0.""" + # check dimensions for result arrays ary = np.asarray(ary) if ary.dtype.kind == "i": ary = ary.astype(np.float64) @@ -300,6 +301,7 @@ def _logsumexp(ary, *, b=0, b_inv=None, axis=None, keepdims=False, out=None, cop if axis is None else tuple(1 if i in agroup else d for i, d in enumerate(shape)) ) + # create result arrays if out is None: if not keepdims: out_shape = ( @@ -311,6 +313,7 @@ def _logsumexp(ary, *, b=0, b_inv=None, axis=None, keepdims=False, out=None, cop out_shape = shape_max out = np.empty(out_shape, dtype=dtype) ary_max = np.empty(shape_max, dtype=dtype) + # calculations ary.max(axis=axis, keepdims=True, out=ary_max) if copy: ary = ary.copy() @@ -323,6 +326,7 @@ def _logsumexp(ary, *, b=0, b_inv=None, axis=None, keepdims=False, out=None, cop elif b: ary_max += np.log(b) out += ary_max.squeeze() if not keepdims else ary_max + # transform to scalar if possible return out if out.shape else dtype(out) diff --git a/arviz/tests/test_stats.py b/arviz/tests/test_stats.py index 27e0261157..2393031c86 100644 --- a/arviz/tests/test_stats.py +++ b/arviz/tests/test_stats.py @@ -3,8 +3,10 @@ import numpy as np from numpy.testing import assert_almost_equal, assert_array_almost_equal, assert_array_less import pytest +from scipy.special import logsumexp from scipy.stats import linregress + from ..data import load_arviz_data from ..stats import bfmi, compare, hpd, loo, r2_score, waic, psislw, summary from ..stats.stats import _gpinv, _mc_error, _logsumexp @@ -220,6 +222,12 @@ def test_gpinv(probs, kappa, sigma): @pytest.mark.parametrize("b_inv", [None, 100, 101]) @pytest.mark.parametrize("keepdims", [True, False]) def test_logsumexp(ary_dtype, axis, b, b_inv, keepdims): + """Test ArviZ implementation of logsumexp. + + Test also compares against Scipy implementation. + Case where b=0, they are equal. + Second case where b=x, and x is 1/(number of elements), they are almost equal. + """ ary = np.random.randn(100, 101).astype(ary_dtype) assert ( _logsumexp(ary=ary, axis=axis, b=b, b_inv=b_inv, keepdims=keepdims, copy=True) is not None @@ -230,3 +238,16 @@ def test_logsumexp(ary_dtype, axis, b, b_inv, keepdims): ) out = np.empty(5) assert _logsumexp(ary=np.random.randn(10, 5), axis=0, out=out) is not None + + # Scipy implementation + if b_inv is not None: + b_scipy = 1 / b_inv + elif b is None: + if b_inv is None: + b_scipy = 0 + else: + b_scipy = 1 / b_inv + scipy_results = logsumexp(ary, b=b_scipy, axis=axis, keepdims=keepdims) + arviz_results = _logsumexp(ary, b=b, b_inv=b_inv, axis=axis, keepdims=keepdims) + + assert_array_almost_equal(scipy_results, arviz_results) From de100a799a2d7cea75bc88d9fdc92f3a51e2cb25 Mon Sep 17 00:00:00 2001 From: Hartikainen Ari Date: Mon, 14 Jan 2019 01:50:45 +0200 Subject: [PATCH 10/11] fix b, b_inv behaviour --- arviz/stats/stats.py | 11 +++++++-- arviz/tests/test_stats.py | 51 ++++++++++++++++++++++++++++----------- 2 files changed, 46 insertions(+), 16 deletions(-) diff --git a/arviz/stats/stats.py b/arviz/stats/stats.py index e251cd2f6d..6c69496fe5 100644 --- a/arviz/stats/stats.py +++ b/arviz/stats/stats.py @@ -281,8 +281,11 @@ def hpd(x, credible_interval=0.94, circular=False): return np.array([hdi_min, hdi_max]) -def _logsumexp(ary, *, b=0, b_inv=None, axis=None, keepdims=False, out=None, copy=True): - """Stable logsumexp when b >= 0.""" +def _logsumexp(ary, *, b=None, b_inv=None, axis=None, keepdims=False, out=None, copy=True): + """Stable logsumexp when b >= 0 and b is scalar. + + b_inv overwrites b unless b_inv is None. + """ # check dimensions for result arrays ary = np.asarray(ary) if ary.dtype.kind == "i": @@ -312,6 +315,10 @@ def _logsumexp(ary, *, b=0, b_inv=None, axis=None, keepdims=False, out=None, cop else: out_shape = shape_max out = np.empty(out_shape, dtype=dtype) + if b_inv == 0: + return np.full_like(out, np.inf, dtype=dtype) if out.shape else np.inf + if b_inv is None and b == 0: + return np.full_like(out, -np.inf) if out.shape else -np.inf ary_max = np.empty(shape_max, dtype=dtype) # calculations ary.max(axis=axis, keepdims=True, out=ary_max) diff --git a/arviz/tests/test_stats.py b/arviz/tests/test_stats.py index 2393031c86..291bdf94d2 100644 --- a/arviz/tests/test_stats.py +++ b/arviz/tests/test_stats.py @@ -219,35 +219,58 @@ def test_gpinv(probs, kappa, sigma): @pytest.mark.parametrize("ary_dtype", [np.float64, np.float32, np.int32, np.int64]) @pytest.mark.parametrize("axis", [None, 0, 1, (-2, -1)]) @pytest.mark.parametrize("b", [None, 0, 1 / 100, 1 / 101]) +@pytest.mark.parametrize("keepdims", [True, False]) +def test_logsumexp_b(ary_dtype, axis, b, keepdims): + """Test ArviZ implementation of logsumexp. + + Test also compares against Scipy implementation. + Case where b=None, they are equal. (N=len(ary)) + Second case where b=x, and x is 1/(number of elements), they are almost equal. + + Test tests against b parameter. + """ + np.random.seed(17) + ary = np.random.randn(100, 101).astype(ary_dtype) + assert _logsumexp(ary=ary, axis=axis, b=b, keepdims=keepdims, copy=True) is not None + ary = ary.copy() + assert _logsumexp(ary=ary, axis=axis, b=b, keepdims=keepdims, copy=False) is not None + out = np.empty(5) + assert _logsumexp(ary=np.random.randn(10, 5), axis=0, out=out) is not None + + # Scipy implementation + scipy_results = logsumexp(ary, b=b, axis=axis, keepdims=keepdims) + arviz_results = _logsumexp(ary, b=b, axis=axis, keepdims=keepdims) + + assert_array_almost_equal(scipy_results, arviz_results) + + +@pytest.mark.parametrize("ary_dtype", [np.float64, np.float32, np.int32, np.int64]) +@pytest.mark.parametrize("axis", [None, 0, 1, (-2, -1)]) @pytest.mark.parametrize("b_inv", [None, 100, 101]) @pytest.mark.parametrize("keepdims", [True, False]) -def test_logsumexp(ary_dtype, axis, b, b_inv, keepdims): +def test_logsumexp_b_inv(ary_dtype, axis, b_inv, keepdims): """Test ArviZ implementation of logsumexp. Test also compares against Scipy implementation. - Case where b=0, they are equal. + Case where b=None, they are equal. (N=len(ary)) Second case where b=x, and x is 1/(number of elements), they are almost equal. + + Test tests against b_inv parameter. """ + np.random.seed(17) ary = np.random.randn(100, 101).astype(ary_dtype) - assert ( - _logsumexp(ary=ary, axis=axis, b=b, b_inv=b_inv, keepdims=keepdims, copy=True) is not None - ) + assert _logsumexp(ary=ary, axis=axis, b_inv=b_inv, keepdims=keepdims, copy=True) is not None ary = ary.copy() - assert ( - _logsumexp(ary=ary, axis=axis, b=b, b_inv=b_inv, keepdims=keepdims, copy=False) is not None - ) + assert _logsumexp(ary=ary, axis=axis, b_inv=b_inv, keepdims=keepdims, copy=False) is not None out = np.empty(5) assert _logsumexp(ary=np.random.randn(10, 5), axis=0, out=out) is not None # Scipy implementation if b_inv is not None: b_scipy = 1 / b_inv - elif b is None: - if b_inv is None: - b_scipy = 0 - else: - b_scipy = 1 / b_inv + else: + b_scipy = None scipy_results = logsumexp(ary, b=b_scipy, axis=axis, keepdims=keepdims) - arviz_results = _logsumexp(ary, b=b, b_inv=b_inv, axis=axis, keepdims=keepdims) + arviz_results = _logsumexp(ary, b_inv=b_inv, axis=axis, keepdims=keepdims) assert_array_almost_equal(scipy_results, arviz_results) From cae2218d461dad5aa3358ff15adde275f30a3e63 Mon Sep 17 00:00:00 2001 From: Hartikainen Ari Date: Mon, 14 Jan 2019 02:14:28 +0200 Subject: [PATCH 11/11] test when b_inv=0 --- arviz/tests/test_stats.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/arviz/tests/test_stats.py b/arviz/tests/test_stats.py index 291bdf94d2..fa3abe8104 100644 --- a/arviz/tests/test_stats.py +++ b/arviz/tests/test_stats.py @@ -246,7 +246,7 @@ def test_logsumexp_b(ary_dtype, axis, b, keepdims): @pytest.mark.parametrize("ary_dtype", [np.float64, np.float32, np.int32, np.int64]) @pytest.mark.parametrize("axis", [None, 0, 1, (-2, -1)]) -@pytest.mark.parametrize("b_inv", [None, 100, 101]) +@pytest.mark.parametrize("b_inv", [None, 0, 100, 101]) @pytest.mark.parametrize("keepdims", [True, False]) def test_logsumexp_b_inv(ary_dtype, axis, b_inv, keepdims): """Test ArviZ implementation of logsumexp. @@ -265,12 +265,13 @@ def test_logsumexp_b_inv(ary_dtype, axis, b_inv, keepdims): out = np.empty(5) assert _logsumexp(ary=np.random.randn(10, 5), axis=0, out=out) is not None - # Scipy implementation - if b_inv is not None: - b_scipy = 1 / b_inv - else: - b_scipy = None - scipy_results = logsumexp(ary, b=b_scipy, axis=axis, keepdims=keepdims) - arviz_results = _logsumexp(ary, b_inv=b_inv, axis=axis, keepdims=keepdims) + if b_inv != 0: + # Scipy implementation when b_inv != 0 + if b_inv is not None: + b_scipy = 1 / b_inv + else: + b_scipy = None + scipy_results = logsumexp(ary, b=b_scipy, axis=axis, keepdims=keepdims) + arviz_results = _logsumexp(ary, b_inv=b_inv, axis=axis, keepdims=keepdims) - assert_array_almost_equal(scipy_results, arviz_results) + assert_array_almost_equal(scipy_results, arviz_results)