
Vectorized ELBO does not work with beta-Bernoulli model #414

Closed
fehiepsi opened this issue Oct 27, 2019 · 4 comments
Labels: bug (Something isn't working), jax (This issue is specific to JAX)

Comments

@fehiepsi
Member

fehiepsi commented Oct 27, 2019

As observed in #413, running SVI on the beta-Bernoulli model with num_particles=2 raises the error

NotImplementedError: Forward-mode differentiation rule for 'while' not implemented

This looks like a JAX issue, but we need to isolate the bug before reporting it upstream.

@fehiepsi fehiepsi added the bug Something isn't working label Oct 27, 2019
@TuanNguyen27
Contributor

@fehiepsi From a quick glance, the error traces back to here: https://github.com/google/jax/blob/master/jax/interpreters/ad.py#L217

@fehiepsi
Member Author

Thanks, @TuanNguyen27! I believe this error comes from the while_loop in JAX's gamma sampler, but it is hard to write a small repro script. In any case, this issue should not block your PR; feel free to complete it and mark the failing tests as xfail.
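For reference, here is a minimal sketch of the kind of computation involved, assuming a recent JAX release that includes the fix. It builds a per-particle ELBO for a beta-Bernoulli model whose Beta guide is sampled with jax.random.beta (which is built on JAX's gamma sampler), vectorizes it over num_particles=2 with jax.vmap, and differentiates it with jax.grad. The function names, the Beta(1, 1) prior, and the exp parameterization of the guide are illustrative assumptions; this is not NumPyro's ELBO implementation and not a confirmed repro of the original error.

```python
import jax
import jax.numpy as jnp
from jax.scipy import stats


def elbo_particle(params, key, data):
    # Hypothetical guide parameters: log-concentrations of a Beta(a, b) guide.
    a, b = jnp.exp(params)
    # Reparameterized Beta sample; gradients flow through JAX's gamma sampler.
    p = jax.random.beta(key, a, b)
    p = jnp.clip(p, 1e-6, 1 - 1e-6)
    # Beta(1, 1) prior has log-density 0 on (0, 1), so it adds nothing here.
    log_lik = jnp.sum(data * jnp.log(p) + (1 - data) * jnp.log1p(-p))
    log_q = stats.beta.logpdf(p, a, b)
    return log_lik - log_q


def elbo(params, key, data, num_particles=2):
    # Vectorized ELBO: one independent key per particle, mapped with vmap.
    keys = jax.random.split(key, num_particles)
    vals = jax.vmap(elbo_particle, in_axes=(None, 0, None))(params, keys, data)
    return jnp.mean(vals)


data = jnp.array([1.0, 0.0, 1.0, 1.0, 0.0, 1.0])
params = jnp.zeros(2)  # a = b = 1 at initialization
key = jax.random.PRNGKey(0)

value = elbo(params, key, data)
grads = jax.grad(elbo)(params, key, data)
```

On an affected JAX version, the differentiation step (jax.grad over the vmapped sampler) is where an error like the one above would be expected to surface; on a fixed version, both calls should return finite values.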

@fehiepsi
Member Author

fehiepsi commented Dec 3, 2019

For context: this issue is tracked upstream at jax-ml/jax#1789.

@fehiepsi fehiepsi added the jax This issue is specific to JAX label Dec 22, 2019
@fehiepsi
Member Author

fehiepsi commented Jan 9, 2020

This is fixed in the latest jax release.

@fehiepsi fehiepsi closed this as completed Jan 9, 2020