Add IgammaGradA (jax-ml#2504)
srvasude authored May 6, 2020
1 parent 25a0c3b commit e51c7d7
Showing 2 changed files with 15 additions and 11 deletions.
22 changes: 13 additions & 9 deletions jax/lax/lax.py
@@ -196,6 +196,10 @@ def igammac(a: Array, x: Array) -> Array:
   r"""Elementwise complementary regularized incomplete gamma function."""
   return igammac_p.bind(_brcast(a, x), _brcast(x, a))
 
+def igamma_grad_a(a: Array, x: Array) -> Array:
+  r"""Elementwise derivative of the regularized incomplete gamma function."""
+  return igamma_grad_a_p.bind(_brcast(a, x), _brcast(x, a))
+
 def bessel_i0e(x: Array) -> Array:
   r"""Exponentially scaled modified Bessel function of order 0:
   :math:`\mathrm{i0e}(x) = e^{-|x|} \mathrm{i0}(x)`
@@ -1919,25 +1923,25 @@ def betainc_grad_not_implemented(g, a, b, x):
 digamma_p = standard_unop(_float, 'digamma')
 
 igamma_p = standard_naryop([_float, _float], 'igamma')
+igamma_grad_a_p = standard_naryop([_float, _float], 'igamma_grad_a')
 
 def igamma_gradx(g, a, x):
-  return g * exp(-x + (a - 1.) * log(x) - lgamma(a))
+  return g * exp(-x + (a - _ones(a)) * log(x) - lgamma(a))
 
-# TODO(srvasude): Igamma and Igammac gradient aren't supported with respect to
-# a. We can reuse some of the reparameterization code in the JAX gamma sampler,
-# but better to add an XLA op for this (which will also allow TF Igamma gradient
-# code to be XLA compiled).
-def gamma_grad_not_implemented(a, b, x):
-  raise ValueError("Igamma(c) gradient with respect to `a` not supported.")
+def igamma_grada(g, a, x):
+  return g * igamma_grad_a(a, x)
 
-ad.defjvp(igamma_p, gamma_grad_not_implemented, igamma_gradx)
+ad.defjvp(igamma_p, igamma_grada, igamma_gradx)
 
 igammac_p = standard_naryop([_float, _float], 'igammac')
 
 def igammac_gradx(g, a, x):
   return -igamma_gradx(g, a, x)
 
-ad.defjvp(igammac_p, gamma_grad_not_implemented, igammac_gradx)
+def igammac_grada(g, a, x):
+  return -igamma_grada(g, a, x)
+
+ad.defjvp(igammac_p, igammac_grada, igammac_gradx)
 
 bessel_i0e_p = standard_unop(_float, 'bessel_i0e')
 ad.defjvp2(bessel_i0e_p, lambda g, y, x: g * (bessel_i1e(x) - sign(x) * y))
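With this change, the JVP of the igamma primitive with respect to `a` no longer raises an error: it binds the new igamma_grad_a primitive (XLA's IgammaGradA op), while the `x` gradient keeps the closed form exp(-x + (a - 1) * log(x) - lgamma(a)), i.e. x**(a - 1) * exp(-x) / Gamma(a). A minimal sketch of what this enables through the public API (not part of the commit; the values and finite-difference check are illustrative only):

import jax
from jax.scipy.special import gammainc

a, x = 2.5, 1.3

# Gradient w.r.t. x: closed-form term, matching igamma_gradx above.
dP_dx = jax.grad(gammainc, argnums=1)(a, x)

# Gradient w.r.t. a: previously raised "gradient with respect to `a` not
# supported"; now lowered through the igamma_grad_a primitive.
dP_da = jax.grad(gammainc, argnums=0)(a, x)

# Rough central-difference sanity check on the a-gradient.
eps = 1e-4
fd = (gammainc(a + eps, x) - gammainc(a - eps, x)) / (2 * eps)
print(dP_da, fd)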
4 changes: 2 additions & 2 deletions tests/lax_scipy_test.py
@@ -60,8 +60,8 @@ def op_record(name, nargs, dtypes, rng_factory, test_grad, test_name=None):
     op_record("betaln", 2, float_dtypes, jtu.rand_positive, False),
     op_record("betainc", 3, float_dtypes, jtu.rand_positive, False),
     op_record("digamma", 1, float_dtypes, jtu.rand_positive, False),
-    op_record("gammainc", 2, float_dtypes, jtu.rand_positive, False),
-    op_record("gammaincc", 2, float_dtypes, jtu.rand_positive, False),
+    op_record("gammainc", 2, float_dtypes, jtu.rand_positive, True),
+    op_record("gammaincc", 2, float_dtypes, jtu.rand_positive, True),
     op_record("erf", 1, float_dtypes, jtu.rand_small_positive, True),
     op_record("erfc", 1, float_dtypes, jtu.rand_small_positive, True),
     op_record("erfinv", 1, float_dtypes, jtu.rand_small_positive, True),
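Flipping test_grad to True for gammainc and gammaincc makes the test harness exercise the new gradients numerically. Roughly, and assuming jax.test_util.check_grads (a sketch of the effect, not the actual test code):

from jax import test_util as jtu
from jax.scipy.special import gammainc, gammaincc

# Compare forward- and reverse-mode gradients against finite differences,
# as the harness does for records with test_grad=True.
args = (2.0, 1.5)
jtu.check_grads(gammainc, args, order=1, modes=["fwd", "rev"])
jtu.check_grads(gammaincc, args, order=1, modes=["fwd", "rev"])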
