Replies: 2 comments
-
Thanks for the report! This is unimplemented because there is no easy way to compute the derivative of `igamma_grad_a` with the available XLA primitives (it can be expressed in terms of the Meijer G-function). I don't know of a way to compute that in JAX, but if you find one, it would be a welcome contribution to the JAX package. Until then, the …
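For context, here are the standard partial derivatives of the regularized lower incomplete gamma function $P(a, x)$, which is what `gammainc` computes ($\psi$ is the digamma function). The $x$-derivatives and the mixed term are elementary, while $\partial P/\partial a$ contains an integral with a $\ln t$ weight. That term is the one whose closed form needs special functions such as the Meijer G-function, which is consistent with the second derivative in $a$ being the one that fails.

```math
\begin{aligned}
P(a,x) &= \frac{1}{\Gamma(a)} \int_0^x t^{a-1} e^{-t}\,dt, \\
\frac{\partial P}{\partial x} &= \frac{x^{a-1} e^{-x}}{\Gamma(a)}, \qquad
\frac{\partial^2 P}{\partial x^2} = \frac{x^{a-2} e^{-x}\,(a-1-x)}{\Gamma(a)}, \\
\frac{\partial^2 P}{\partial a\,\partial x} &= \frac{x^{a-1} e^{-x}}{\Gamma(a)}\bigl(\ln x - \psi(a)\bigr), \\
\frac{\partial P}{\partial a} &= \frac{1}{\Gamma(a)} \int_0^x t^{a-1} e^{-t} \ln t\,dt \;-\; \psi(a)\,P(a,x).
\end{aligned}
```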
-
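Meanwhile, you can try a finite-difference approximation of the JVP rule using `finitediffx`:

```python
import jax
import finitediffx as fdx
from jax.experimental import enable_x64
from jax.scipy.special import gammainc

# Wrap gammainc so its JVP is computed by finite differences
# instead of JAX's built-in differentiation rule.
fd_gammainc = fdx.define_fdjvp(gammainc, offsets=fdx.Offset(4))

with enable_x64():
    # Compare the first derivative (w.r.t. `a`) against JAX's autodiff rule.
    ad_grad = jax.grad(gammainc)(0.1, 0.2)
    fd_grad = jax.grad(fd_gammainc)(0.1, 0.2)
    ad_fd_diff = abs(ad_grad - fd_grad)
    print(f"{ad_fd_diff:.3e}")
    # Second derivative w.r.t. `a`, which fails with the built-in rule.
    print(jax.hessian(fd_gammainc)(0.1, 0.2))

# Output:
# 1.081e-12
# 0.5696510461191906
```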
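If an extra dependency is not an option, a hybrid is another possibility. This is a minimal sketch under stated assumptions, not a JAX API: the helper `gammainc_hessian` and the step size `h` are hypothetical. It uses autodiff for every Hessian entry that never differentiates `igamma_grad_a` (the pure `x` term, and the mixed term obtained by differentiating the `x`-gradient with respect to `a`), and a central difference of the exact autodiff gradient for the pure `a` term.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import gammainc

# Finite differences need float64 to be useful.
jax.config.update("jax_enable_x64", True)


def gammainc_hessian(a, x, h=1e-5):
    """Hypothetical helper: 2x2 Hessian of P(a, x) = gammainc(a, x) w.r.t. (a, x).

    Entries that never differentiate `igamma_grad_a` use autodiff; d2P/da2 is a
    central difference of the exact autodiff gradient in `a`.
    """
    dP_da = jax.grad(gammainc, argnums=0)  # first order in `a`: supported
    dP_dx = jax.grad(gammainc, argnums=1)  # x**(a-1) * exp(-x) / Gamma(a)

    # Exact: the x-gradient is built from elementary ops, so differentiate it again.
    d2P_dx2 = jax.grad(dP_dx, argnums=1)(a, x)
    # Exact: mixed partial taken as d/da of the x-gradient (lgamma -> digamma).
    d2P_dadx = jax.grad(dP_dx, argnums=0)(a, x)
    # Approximate: central difference of the exact first derivative in `a`.
    d2P_da2 = (dP_da(a + h, x) - dP_da(a - h, x)) / (2 * h)

    return jnp.array([[d2P_da2, d2P_dadx],
                      [d2P_dadx, d2P_dx2]])


H = gammainc_hessian(0.1, 0.2)
print(H)
```

Only `H[0, 0]` is approximated, so the finite-difference error stays confined to that entry; the step size `h` trades truncation error against round-off and may need tuning for inputs far from `(0.1, 0.2)`.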
-
In my work, I need the Hessian of a function that involves the incomplete gamma function.
However, I found that `jax.hessian()` fails on `jax.scipy.special.gammainc()`.
For example, computing the gradient works fine:
However, computing the Hessian raises an error:
NotImplementedError: Differentiation rule for 'igamma_grad_a' not implemented
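A minimal sketch of the failure, assuming the gradient is taken with respect to `a` (the default `argnums=0`) at the point `(0.1, 0.2)` used in the finite-difference reply; these inputs are illustrative, not from the original report.

```python
import jax
from jax.scipy.special import gammainc

jax.config.update("jax_enable_x64", True)

a, x = 0.1, 0.2  # illustrative inputs (assumption)

# First derivative w.r.t. `a` works: the JVP rule of igamma evaluates
# the igamma_grad_a primitive, which only needs a forward evaluation here.
grad_a = jax.grad(gammainc)(a, x)
print("dP/da =", grad_a)

# The Hessian needs the derivative *of* igamma_grad_a, which has no rule.
try:
    print(jax.hessian(gammainc)(a, x))
except NotImplementedError as err:
    print("NotImplementedError:", err)
```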
Is there any way to compute the Hessian of `jax.scipy.special.gammainc()`?