@@ -379,8 +379,34 @@ def test_logcumsumexp_basic():
379379 x = dpt .ones (10 , dtype = dt )
380380 r = dpt .cumulative_logsumexp (x )
381381
382- x_np = dpt .asnumpy (x )
383- expected = np .logaddexp .accumulate (x_np , dtype = dt )
382+ expected = 1 + np .log (np .arange (1 , 11 , dtype = dt ))
384383
385- tol = 32 * dpt .finfo (dt ).resolution
384+ tol = 4 * dpt .finfo (dt ).resolution
386385 assert np .allclose (dpt .asnumpy (r ), expected , atol = tol , rtol = tol )
386+
387+
388+ def geometric_series_closed_form (n , dtype = None , device = None ):
389+ """Closed form for cumulative_logsumexp(dpt.arange(-n, 0))
390+
391+ :math:`r[k] == -n + k + log(1 - exp(-k-1)) - log(1-exp(-1))`
392+ """
393+ x = dpt .arange (- n , 0 , dtype = dtype , device = device )
394+ y = dpt .arange (- 1 , - n - 1 , step = - 1 , dtype = dtype , device = device )
395+ y = dpt .exp (y , out = y )
396+ y = dpt .negative (y , out = y )
397+ y = dpt .log1p (y , out = y )
398+ y -= y [0 ]
399+ return x + y
400+
401+
402+ @pytest .mark .parametrize ("fpdt" , rfp_types )
403+ def test_cumulative_logsumexp_closed_form (fpdt ):
404+ q = get_queue_or_skip ()
405+ skip_if_dtype_not_supported (fpdt , q )
406+
407+ n = 128
408+ r = dpt .cumulative_logsumexp (dpt .arange (- n , 0 , dtype = fpdt , device = q ))
409+ expected = geometric_series_closed_form (n , dtype = fpdt , device = q )
410+
411+ tol = 4 * dpt .finfo (fpdt ).eps
412+ assert dpt .allclose (r , expected , atol = tol , rtol = tol )
0 commit comments