Here is repro code that triggers the error NotImplementedError: Forward-mode differentiation rule for 'while' not implemented when trying to take the gradient of g:
import jax
import jax.numpy as np
jax.config.update('jax_platform_name', 'cpu')
def f(a):
    return jax.random.gamma(jax.random.PRNGKey(0), a)
def g(x):
    return np.sum(jax.vmap(f)(x))
print(g(np.ones(3)))
print(jax.grad(g)(np.ones(3)))
Oh... Thanks, @shoyer! That looks like the problem I'm hitting. I'll see whether I can go back to the version from several months ago that did not use custom_transform for the gamma sampler.