Skip to content

Commit

Permalink
Add test for cumulative_logsumexp for geometric series summation, tes…
Browse files Browse the repository at this point in the history
…ting against closed form
  • Loading branch information
oleksandr-pavlyk committed Mar 31, 2024
1 parent 6520b8b commit 4bd02b4
Showing 1 changed file with 29 additions and 3 deletions.
32 changes: 29 additions & 3 deletions dpctl/tests/test_tensor_accumulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,8 +379,34 @@ def test_logcumsumexp_basic():
x = dpt.ones(10, dtype=dt)
r = dpt.cumulative_logsumexp(x)

x_np = dpt.asnumpy(x)
expected = np.logaddexp.accumulate(x_np, dtype=dt)
expected = 1 + np.log(np.arange(1, 11, dtype=dt))

tol = 32 * dpt.finfo(dt).resolution
tol = 4 * dpt.finfo(dt).resolution
assert np.allclose(dpt.asnumpy(r), expected, atol=tol, rtol=tol)


def geometric_series_closed_form(n, dtype=None, device=None):
"""Closed form for cumulative_logsumexp(dpt.arange(-n, 0))
:math:`r[k] == -n + k + log(1 - exp(-k-1)) - log(1-exp(-1))`
"""
x = dpt.arange(-n, 0, dtype=dtype, device=device)
y = dpt.arange(-1, -n - 1, step=-1, dtype=dtype, device=device)
y = dpt.exp(y, out=y)
y = dpt.negative(y, out=y)
y = dpt.log1p(y, out=y)
y -= y[0]
return x + y


@pytest.mark.parametrize("fpdt", rfp_types)
def test_cumulative_logsumexp_closed_form(fpdt):
q = get_queue_or_skip()
skip_if_dtype_not_supported(fpdt, q)

n = 128
r = dpt.cumulative_logsumexp(dpt.arange(-n, 0, dtype=fpdt, device=q))
expected = geometric_series_closed_form(n, dtype=fpdt, device=q)

tol = 4 * dpt.finfo(fpdt).eps
assert dpt.allclose(r, expected, atol=tol, rtol=tol)

0 comments on commit 4bd02b4

Please sign in to comment.