Skip to content

Commit 7ac088c

Browse files
Merge pull request #20699 from pearu:pearu/gammainc
PiperOrigin-RevId: 735878582
2 parents 99c9106 + 82b2591 commit 7ac088c

File tree

4 files changed

+45
-20
lines changed

4 files changed

+45
-20
lines changed

jax/_src/lax/special.py

+17-13
Original file line numberDiff line numberDiff line change
@@ -303,15 +303,16 @@ def body_fn(vals):
303303

304304
def igamma_impl(a, x, *, dtype):
305305
is_nan = bitwise_or(_isnan(a), _isnan(x))
306-
x_is_zero = eq(x, _const(x, 0))
307306
x_is_infinity = eq(x, _const(x, float('inf')))
308-
domain_error = bitwise_or(lt(x, _const(x, 0)), le(a, _const(a, 0)))
309-
use_igammac = bitwise_and(gt(x, _const(x, 1)), gt(x, a))
307+
a_is_zero = eq(a, _const(a, 0))
308+
x_is_zero = eq(x, _const(x, 0))
309+
domain_error = _reduce(bitwise_or, [lt(x, _const(x, 0)), lt(a, _const(a, 0)), bitwise_and(a_is_zero, x_is_zero)])
310+
311+
use_igammac = bitwise_and(ge(x, _const(x, 1)), gt(x, a))
310312
ax = a * log(x) - x - lgamma(a)
311313
underflow = lt(ax, -log(dtypes.finfo(dtype).max))
312314
ax = exp(ax)
313-
enabled = bitwise_not(
314-
_reduce(bitwise_or,[x_is_zero, domain_error, underflow, is_nan]))
315+
enabled = bitwise_not(_reduce(bitwise_or, [x_is_zero, domain_error, underflow, is_nan, x_is_infinity]))
315316

316317
output = select(
317318
use_igammac,
@@ -323,8 +324,7 @@ def igamma_impl(a, x, *, dtype):
323324
)
324325
output = select(x_is_zero, full_like(a, 0), output)
325326
output = select(x_is_infinity, full_like(a, 1), output)
326-
output = select(bitwise_or(domain_error, is_nan),
327-
full_like(a, float('nan')), output)
327+
output = select(domain_error, full_like(a, float('nan')), output)
328328
return output
329329

330330
def _igammac_continued_fraction(ax, x, a, enabled, dtype, mode):
@@ -433,22 +433,26 @@ def body_fn(vals):
433433
raise ValueError(f"Invalid mode: {mode}")
434434

435435
def igammac_impl(a, x, *, dtype):
436-
out_of_range = bitwise_or(le(x, _const(x, 0)), le(a, _const(a, 0)))
436+
is_nan = bitwise_or(_isnan(a), _isnan(x))
437+
a_is_zero = eq(a, _const(a, 0))
438+
x_is_zero = eq(x, _const(x, 0))
439+
x_is_infinity = eq(x, _const(x, float('inf')))
440+
domain_error = _reduce(bitwise_or, [lt(x, _const(x, 0)), lt(a, _const(a, 0)), bitwise_and(a_is_zero, x_is_zero)])
437441
use_igamma = bitwise_or(lt(x, _const(x, 1)), lt(x, a))
438442
ax = a * log(x) - x - lgamma(a)
439443
underflow = lt(ax, -log(dtypes.finfo(dtype).max))
440-
enabled = bitwise_not(bitwise_or(out_of_range, underflow))
444+
enabled = bitwise_not(_reduce(bitwise_or, [domain_error, underflow, is_nan, x_is_infinity, a_is_zero]))
441445
ax = exp(ax)
442446

443447
igamma_call = _igamma_series(ax, x, a, bitwise_and(enabled, use_igamma),
444448
dtype, IgammaMode.VALUE)
445449
igammac_cf_call = _igammac_continued_fraction(ax, x, a,
446450
bitwise_and(enabled, bitwise_not(use_igamma)), dtype, IgammaMode.VALUE)
447451

448-
result = select(use_igamma, _const(a, 1) - igamma_call, igammac_cf_call)
449-
x_is_infinity = eq(x, _const(x, float('inf')))
450-
result = select(x_is_infinity, full_like(result, 0), result)
451-
return select(out_of_range, full_like(a, 1), result)
452+
output = select(use_igamma, _const(a, 1) - igamma_call, igammac_cf_call)
453+
output = select(bitwise_or(x_is_infinity, a_is_zero), full_like(output, 0), output)
454+
output = select(domain_error, full_like(a, float('nan')), output)
455+
return output
452456

453457
def igamma_grad_a_impl(a, x, *, dtype):
454458
is_nan = bitwise_or(_isnan(a), _isnan(x))

jax/_src/scipy/stats/gamma.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,8 @@ def sf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) ->
198198
- :func:`jax.scipy.stats.gamma.logsf`
199199
"""
200200
x, a, loc, scale = promote_args_inexact("gamma.sf", x, a, loc, scale)
201-
return gammaincc(a, lax.div(lax.sub(x, loc), scale))
201+
y = lax.div(lax.sub(x, loc), scale)
202+
return jnp.where(lax.lt(y, _lax_const(y, 0)), 1, gammaincc(a, y))
202203

203204

204205
def logsf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:

jax/experimental/jax2tf/tests/jax2tf_limitations.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -865,15 +865,15 @@ def igammac(cls, harness: test_harnesses.Harness):
865865
def custom_assert(tst, result_jax, result_tf, *, args, tol,
866866
err_msg): # noqa: F811
867867
arg1, arg2 = args
868-
# lax.igammac returns 1. when arg1 <= 0; tf.math.igammac returns NaN
868+
# lax.igammac returns nan. when arg1 <= 0; tf.math.igammac returns 1
869869
special_cases = (arg1 <= 0.) | (arg2 <= 0)
870870
nr_special_cases = np.count_nonzero(special_cases)
871871
tst.assertAllClose(
872-
np.full((nr_special_cases,), 1., dtype=dtype),
872+
np.full((nr_special_cases,), np.nan, dtype=dtype),
873873
result_jax[special_cases],
874874
err_msg=err_msg)
875875
tst.assertAllClose(
876-
np.full((nr_special_cases,), np.nan, dtype=dtype),
876+
np.full((nr_special_cases,), 1, dtype=dtype),
877877
result_tf[special_cases],
878878
err_msg=err_msg)
879879
# non-special cases are equal
@@ -892,12 +892,12 @@ def custom_assert(tst, result_jax, result_tf, *, args, tol,
892892
custom_numeric(dtypes=[np.float64], tol=1e-9),
893893
custom_numeric(devices="gpu", tol=1e-3),
894894
custom_numeric(
895+
modes=("compiled",),
895896
custom_assert=custom_assert,
896-
devices=("cpu", "gpu"),
897+
devices=("cpu", "gpu", "tpu"),
897898
description=(
898899
"May return different results at undefined points "
899-
"(both arguments less or equal 0). JAX returns `NaN` and TF returns 0 or "
900-
"JAX returns 1 and TF returns `NaN`")),
900+
"(both arguments less or equal 0). JAX returns `NaN` and TF returns 1")),
901901
]
902902

903903
@classmethod

tests/lax_scipy_special_functions_test.py

+20
Original file line numberDiff line numberDiff line change
@@ -287,6 +287,26 @@ def testExpiDisableJit(self):
287287
result_jit = lsp_special.expi(x)
288288
self.assertAllClose(result_jit, result_nojit)
289289

290+
def testGammaIncBoundaryValues(self):
291+
dtype = jax.numpy.zeros(0).dtype # default float dtype.
292+
nan = float('nan')
293+
inf = float('inf')
294+
args_maker = lambda: [np.array([0, 0, 0, 1, nan, 1, nan, 0, 1, nan]).astype(dtype),
295+
np.array([0, 1, 2, 0, 1, nan, nan, inf, inf, inf]).astype(dtype)]
296+
rtol = 1E-3 if jtu.test_device_matches(["tpu"]) else 1e-5
297+
self._CheckAgainstNumpy(osp_special.gammainc, lsp_special.gammainc, args_maker, rtol=rtol)
298+
self._CompileAndCheck(lsp_special.gammainc, args_maker, rtol=rtol)
299+
300+
def testGammaIncCBoundaryValues(self):
301+
dtype = jax.numpy.zeros(0).dtype # default float dtype.
302+
nan = float('nan')
303+
inf = float('inf')
304+
args_maker = lambda: [np.array([0, 0, 0, 1, nan, 1, nan, 0, 1, nan, 1]).astype(dtype),
305+
np.array([0, 1, 2, 0, 1, nan, nan, inf, inf, inf, -1]).astype(dtype)]
306+
rtol = 1E-3 if jtu.test_device_matches(["tpu"]) else 1e-5
307+
self._CheckAgainstNumpy(osp_special.gammaincc, lsp_special.gammaincc, args_maker, rtol=rtol)
308+
self._CompileAndCheck(lsp_special.gammaincc, args_maker, rtol=rtol)
309+
290310

291311
if __name__ == "__main__":
292312
absltest.main(testLoader=jtu.JaxTestLoader())

0 commit comments

Comments
 (0)