
Support theano.shared in jax_funcify #73

Closed
twiecki opened this issue Oct 4, 2020 · 4 comments
Labels: enhancement (New feature or request), JAX (Involves JAX transpilation), question (Further information is requested)

Comments

@twiecki
Contributor

twiecki commented Oct 4, 2020

Currently, using a theano.shared variable leads to a MissingInputError; see pymc-devs/pymc#4142.

twiecki added the bug, JAX, and important labels Oct 4, 2020
@brandonwillard
Member

@junpenglao, is there anything similar to shared variables in JAX? We can always convert them into standard arrays, but then changing the shared value won't have any effect on the JAX-compiled function.
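A minimal JAX sketch of the concern above, assuming a shared variable's value is simply captured as a constant (the names `offset`, `f_closure`, and `f_explicit` are hypothetical, for illustration only). A value closed over by a jitted function is read once at trace time, so rebinding it afterwards has no effect on later calls; passing the value as an explicit argument does pick up changes.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for a shared variable's current value.
offset = 1.0


@jax.jit
def f_closure(x):
    # `offset` is read when the function is first traced and is baked
    # into the compiled computation as a constant.
    return x + offset


@jax.jit
def f_explicit(x, offset_value):
    # The "shared" value is an ordinary traced input here.
    return x + offset_value


x = jnp.float32(2.0)
first = f_closure(x)  # traces and compiles with offset == 1.0
offset = 10.0  # analogous to calling set_value on a shared variable
stale = f_closure(x)  # cache hit: still computes x + 1.0
fresh = f_explicit(x, offset)  # uses the new value: x + 10.0
```

This matches the documented JAX behavior for impure functions that read globals: the cached trace is reused as long as the argument shapes and dtypes are unchanged.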

@brandonwillard
Member

See my comment in the parent issue for an explanation of the real problem.

For the sake of completeness, I've added a direct test for shared variables in #76.

We can change the purpose of this issue to center on true shared variable support (i.e. the ability to call shared_var.set_value(...) and have the change take effect in a JAX-jitted function), if that's even possible, or we can simply add a warning and close this issue after merging #76 with that warning. @twiecki, what do you think?
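One way "true" shared variable support could look on the JAX side is sketched below, assuming the shared state lives outside the jitted function and is fed in as an explicit argument on every call. `SharedValue` and `with_shared` are hypothetical names: they mimic Theano's `get_value`/`set_value` interface but are not part of Theano, JAX, or jax_funcify.

```python
import jax
import jax.numpy as jnp


class SharedValue:
    """Hypothetical mutable container mimicking get_value/set_value."""

    def __init__(self, value):
        self._value = jnp.asarray(value)

    def get_value(self):
        return self._value

    def set_value(self, value):
        self._value = jnp.asarray(value)


def with_shared(jitted_fn, *shared):
    """Hypothetical wrapper: pass the current shared values as trailing
    arguments on every call instead of baking them in at trace time."""

    def wrapper(*args):
        return jitted_fn(*args, *(s.get_value() for s in shared))

    return wrapper


@jax.jit
def scaled_sum(x, scale):
    return jnp.sum(x) * scale


scale = SharedValue(2.0)
f = with_shared(scaled_sum, scale)

x = jnp.arange(3.0)  # [0., 1., 2.], so sum(x) == 3.0
before = f(x)  # 3.0 * 2.0
scale.set_value(5.0)
after = f(x)  # 3.0 * 5.0, reusing the compiled function (same shape/dtype)
```

Because the value enters as a traced argument rather than a constant, `set_value` takes effect without recompilation as long as its shape and dtype stay the same. In spirit, this is the bookkeeping that theano.function already handles for shared variables.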

brandonwillard removed the bug and important labels Oct 4, 2020
@brandonwillard
Member

brandonwillard commented Oct 4, 2020

Just to be clear, shared variables work as expected under JAX compilation in Theano (i.e. using theano.function), but the functions returned by jax_funcify are oblivious to shared variables and their functionality. As a matter of fact, jax_funcify should never even see a shared variable.

With this in mind, we shouldn't add any sort of warnings; we should simply let jax_funcify err if it is somehow faced with a shared variable. Likewise, the functionality we're considering here would be a nice feature of jax_funcify, but we can't consider it a requirement, because in that sense it's already fulfilled within the framework that defines shared variables (i.e. Theano and its theano.functions).

In other words, adding shared variable support to sample_numpyro_nuts may require a new framework of its own. FYI: if we had samplers written in Theano using Ops with JAX conversions, this functionality would already be supported, since everything would be compiled via theano.function!

brandonwillard added the enhancement and question labels Oct 4, 2020
brandonwillard changed the title from "Support theano.shared on JAX" to "Support theano.shared in jax_funcify" Oct 4, 2020
@brandonwillard
Member

brandonwillard commented Oct 8, 2020

I've added comments to pymc-devs/pymc#4142 that should clarify the situation and explain what a viable solution would look like. In the meantime, I'm going to close this issue.
