Skip to content

Commit

Permalink
Do not change input inplace
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Jan 21, 2021
1 parent 62c6326 commit 695c23e
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 4 deletions.
9 changes: 5 additions & 4 deletions pymc3/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,11 +244,12 @@ def log1mexp_numpy(x):
https://cran.r-project.org/web/packages/Rmpfr/vignettes/log1mexp-note.pdf
"""
x = np.asarray(x)
mask = x < 0.6931471805599453
x[mask] = np.log(-np.expm1(-x[mask]))
out = np.empty_like(x)
mask = x < 0.6931471805599453 # log(2)
out[mask] = np.log(-np.expm1(-x[mask]))
mask = ~mask
x[mask] = np.log1p(-np.exp(-x[mask]))
return x
out[mask] = np.log1p(-np.exp(-x[mask]))
return out


def flatten_list(tensors):
Expand Down
3 changes: 3 additions & 0 deletions pymc3/tests/test_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ def test_log1pexp():

def test_log1mexp():
vals = np.array([-1, 0, 1e-20, 1e-4, 10, 100, 1e20])
vals_ = vals.copy()
# import mpmath
# mpmath.mp.dps = 1000
# [float(mpmath.log(1 - mpmath.exp(-x))) for x in vals]
Expand All @@ -151,6 +152,8 @@ def test_log1mexp():
npt.assert_allclose(actual, expected)
actual_ = log1mexp_numpy(vals)
npt.assert_allclose(actual_, expected)
# Check that input was not changed in place
npt.assert_allclose(vals, vals_)


def test_log1mexp_numpy_no_warning():
Expand Down

0 comments on commit 695c23e

Please sign in to comment.