grad and vmap are not composable with gamma sampler #1789

Closed
fehiepsi opened this issue Nov 30, 2019 · 2 comments · Fixed by #1790

@fehiepsi
Contributor

Here is repro code that triggers `NotImplementedError: Forward-mode differentiation rule for 'while' not implemented` when trying to take the gradient of `g`:

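```python
import jax
import jax.numpy as np
jax.config.update('jax_platform_name', 'cpu')

# Draw a single Gamma(a, 1) sample with a fixed PRNG key.
def f(a):
    return jax.random.gamma(jax.random.PRNGKey(0), a)

# Batch the sampler over a vector of shape parameters, then reduce to a scalar.
def g(x):
    return np.sum(jax.vmap(f)(x))

print(g(np.ones(3)))            # evaluating g itself
print(jax.grad(g)(np.ones(3)))  # taking the gradient of g raised the NotImplementedError
```
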
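A minimal sketch for checking this composition on a JAX release that includes the fix (the issue was closed by #1790): take the gradient through `vmap` and compare it with per-element gradients computed without `vmap`. This assumes a recent JAX version and uses the `jnp` alias; the input array `x` is an arbitrary choice for illustration.

```python
import jax
import jax.numpy as jnp


def f(a):
    # Same sampler as the repro: one Gamma(a, 1) draw with a fixed key.
    return jax.random.gamma(jax.random.PRNGKey(0), a)


def g(x):
    # Batch the sampler over x, then reduce to a scalar so grad applies.
    return jnp.sum(jax.vmap(f)(x))


x = jnp.array([0.5, 1.0, 2.0])  # arbitrary shape parameters (assumption)

# Gradient through vmap: the combination that raised NotImplementedError.
batched = jax.grad(g)(x)

# Reference: differentiate each element separately, with no vmap involved.
looped = jnp.stack([jax.grad(f)(xi) for xi in x])

print(batched, looped)
```

Because the key is not batched, every element of the `vmap` uses the same key as the corresponding scalar call, so the two gradients should agree.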
@shoyer
Collaborator

shoyer commented Nov 30, 2019

This is another manifestation of #1249.

@fehiepsi
Contributor Author

Oh... Thanks, @shoyer! That seems to be the problem I'm running into. I'll see whether I can go back to the approach from several months ago, which didn't use `custom_transform` for the gamma sampler.
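The comment above refers to JAX's old `custom_transforms` mechanism, which has since been superseded by `jax.custom_jvp` and `jax.custom_vjp`. As a minimal sketch, and not JAX's own implementation, a gamma sampler's gradient can be attached with `jax.custom_jvp` via implicit reparameterization: holding the CDF value F(z; a) fixed gives dz/da = -(∂F/∂a) / (∂F/∂z), where ∂F/∂z is the Gamma(a, 1) density. The helper `gamma_sample` is hypothetical, and the example assumes a recent JAX release.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import gammainc
from jax.scipy.stats import gamma as gamma_dist


# Hypothetical helper (not part of JAX): a gamma sampler with a hand-written JVP.
@jax.custom_jvp
def gamma_sample(key, a):
    # Primal: an ordinary Gamma(a, 1) draw using the given key.
    return jax.random.gamma(key, a)


@gamma_sample.defjvp
def gamma_sample_jvp(primals, tangents):
    key, a = primals
    _, a_dot = tangents  # the key's tangent is a float0 zero; ignore it
    z = jax.random.gamma(key, a)
    # Implicit reparameterization: with u = F(z; a) held fixed,
    # dz/da = -(dF/da) / (dF/dz). F is the regularized lower incomplete
    # gamma function, and dF/dz is the Gamma(a, 1) density at z.
    dF_da = jax.jvp(lambda a_: gammainc(a_, z), (a,), (jnp.ones_like(a),))[1]
    dz_da = -dF_da / gamma_dist.pdf(z, a)
    return z, dz_da * a_dot


key = jax.random.PRNGKey(0)
x = jnp.array([0.5, 1.0, 2.0])  # arbitrary shape parameters (assumption)


def g(x):
    # vmap over the shape parameter, then reduce, as in the repro.
    return jnp.sum(jax.vmap(lambda a: gamma_sample(key, a))(x))


print(jax.grad(g)(x))
```

Since the JVP rule uses only ordinary traceable operations, `vmap` and `grad` should compose with it; its gradient can be compared against `jax.grad` applied to `jax.random.gamma` directly.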
