From 695c23e445e9f25480685748c66474c5acda0c4a Mon Sep 17 00:00:00 2001 From: Ricardo Date: Thu, 21 Jan 2021 18:32:53 +0100 Subject: [PATCH] Do not change input inplace --- pymc3/math.py | 9 +++++---- pymc3/tests/test_math.py | 3 +++ 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/pymc3/math.py b/pymc3/math.py index 7516c26f138..aff54d13b71 100644 --- a/pymc3/math.py +++ b/pymc3/math.py @@ -244,11 +244,12 @@ def log1mexp_numpy(x): https://cran.r-project.org/web/packages/Rmpfr/vignettes/log1mexp-note.pdf """ x = np.asarray(x) - mask = x < 0.6931471805599453 - x[mask] = np.log(-np.expm1(-x[mask])) + out = np.empty_like(x) + mask = x < 0.6931471805599453 # log(2) + out[mask] = np.log(-np.expm1(-x[mask])) mask = ~mask - x[mask] = np.log1p(-np.exp(-x[mask])) - return x + out[mask] = np.log1p(-np.exp(-x[mask])) + return out def flatten_list(tensors): diff --git a/pymc3/tests/test_math.py b/pymc3/tests/test_math.py index 005bf8bdf7e..b31319021fd 100644 --- a/pymc3/tests/test_math.py +++ b/pymc3/tests/test_math.py @@ -133,6 +133,7 @@ def test_log1pexp(): def test_log1mexp(): vals = np.array([-1, 0, 1e-20, 1e-4, 10, 100, 1e20]) + vals_ = vals.copy() # import mpmath # mpmath.mp.dps = 1000 # [float(mpmath.log(1 - mpmath.exp(-x))) for x in vals] @@ -151,6 +152,8 @@ def test_log1mexp(): npt.assert_allclose(actual, expected) actual_ = log1mexp_numpy(vals) npt.assert_allclose(actual_, expected) + # Check that input was not changed in place + npt.assert_allclose(vals, vals_) def test_log1mexp_numpy_no_warning():