Shared variable issues when using NumPyro JAX sampler #4142
Comments
pm.Data
inside a model in combination with JAX sampler
Thanks for reporting and testing this out! I think Theano is marking shared variables as inputs; we will need to work on the log_prob generation function to identify these. For now, you need to replace all the theano.shared variables and pm.Data objects in your model with NumPy arrays.
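As a rough illustration of that workaround, the sketch below unwraps shared-style containers into plain NumPy arrays before they are used in a model. The helper name `as_numpy_constant` and the `FakeShared` stand-in class are hypothetical and not part of PyMC3 or Theano; the only assumption about the real objects is that they expose `get_value()`, as Theano shared variables do.

```python
import numpy as np


def as_numpy_constant(value):
    """Return a plain NumPy array for `value`.

    Objects exposing `get_value()` (as Theano shared variables do) are
    unwrapped; anything else is passed through `np.asarray`.
    """
    if hasattr(value, "get_value"):
        value = value.get_value()
    return np.asarray(value)


class FakeShared:
    """Hypothetical stand-in for a shared variable, used only for this demo."""

    def __init__(self, value):
        self._value = np.asarray(value)

    def get_value(self):
        return self._value


x_shared = FakeShared([1.0, 2.0, 3.0])
x_const = as_numpy_constant(x_shared)  # plain ndarray, safe to pass to the model
print(type(x_const).__name__, x_const.shape)
```

With something like this, the model would be built from `x_const` instead of the shared container, at the cost of losing the ability to swap the data in place later.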
This issue isn't due to lack of support for shared variables in If we want to create a
As implied in my comment here, one way to add shared variable support to functions like One way to do that is to write simple wrapper
Here's a template for such an
import theano.tensor as tt
from theano.gof.op import Op
from theano.gof.graph import Apply
from theano.tensor.type import TensorType
from theano.sandbox.jaxify import jax_funcify


class NumPyroNUTS(Op):
    def __init__(self, draws=1000, tune=1000, chains=4):
        self.draws = draws
        self.tune = tune
        self.chains = chains
        super().__init__()

    def make_node(self, input_rvs):
        """Construct a node for the NumPyro NUTS sampler.

        Parameters
        ----------
        input_rvs : List[TensorVariable]
            The input variables, or `init_state`s, obtained from `model.free_RVs`, for example.
        """
        inputs = [tt.as_tensor(rv) for rv in input_rvs]
        # New, potentially broadcastable dimensions added by sampling
        broadcastable_sample_dims = [self.chains == 1, self.draws == 1]
        # These are the symbolic tensors/arrays that represent the posterior
        # samples for each variable. Calling the `TensorType` instance creates
        # the output variable; `Apply` expects variables, not types.
        outputs = [
            TensorType(rv.dtype, broadcastable_sample_dims + list(rv.broadcastable))()
            for rv in inputs
        ]
        return Apply(self, inputs, outputs)

    def perform(self, node, inputs, outputs):
        # `inputs` is a list containing the numeric initial state values.
        # Simply put, when we're in here, we have concrete numeric values for
        # the `inputs` specified in `make_node`, so we can do any
        # pure Python work we want. The only requirement is that we return
        # numeric values that correspond to the `outputs` specified in `make_node`
        # (i.e. they must have the same `dtype`s, number of dimensions, and broadcast
        # pattern).
        # This could compile[, cache,] and evaluate the JAX-jitted function just like
        # `sample_numpyro_nuts` does and return the numeric values (as a list of
        # samples, in the way `make_node` specifies them, of course).
        # `outputs` is a list containing a list for each output variable.
        # These lists need to be populated in place with the numeric sample
        # arrays for each variable (rebinding the name `outputs` has no effect).
        outputs[0][0] = ...


@jax_funcify.register(NumPyroNUTS)
def jax_funcify_NumPyroNUTS(op):
    draws = op.draws
    tune = op.tune
    chains = op.chains

    def numpyronuts(init, draws=draws, tune=tune, chains=chains):
        # Just return the JAX-jittable sampler function constructed in
        # `sample_numpyro_nuts` (e.g. `_sample` or some variant thereof).
        return ...

    return numpyronuts
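To make the two conventions in the template more concrete, here is a small NumPy-only sketch. It assumes samples are laid out as `(chains, draws, *variable_shape)`, as the broadcast pattern in `make_node` suggests, and it mimics how `perform` is handed a list of single-element lists to fill. The helper names `sample_broadcast_pattern` and `fill_output_storage` and the zero-filled sample arrays are hypothetical, introduced only for illustration.

```python
import numpy as np


def sample_broadcast_pattern(chains, draws, rv_broadcastable):
    """Broadcast pattern of a (chains, draws, *rv_shape) sample array."""
    return (chains == 1, draws == 1) + tuple(rv_broadcastable)


def fill_output_storage(output_storage, samples):
    """Write each sample array into its storage cell, as `perform` must."""
    for cell, value in zip(output_storage, samples):
        cell[0] = np.asarray(value)


chains, draws = 4, 10
# Placeholder "posterior samples" for a scalar and a length-3 vector variable.
samples = [np.zeros((chains, draws)), np.zeros((chains, draws, 3))]

# One single-element list per output variable, initially empty.
output_storage = [[None] for _ in samples]
fill_output_storage(output_storage, samples)

pattern = sample_broadcast_pattern(chains, draws, (False,))
```

The point of writing into `cell[0]` rather than returning a value is that the caller keeps a reference to each cell and reads the result from it afterwards.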
@brandonwillard This is cool. I wonder if there are any other benefits to this approach rather than getting
One other huge benefit: this approach works with symbolic shapes...
This looks interesting. One question I have here is what exactly will be the difference between the content of
Assuming that we want to create such an In other words,
@twiecki or @brandonwillard can this one be closed?
Is it fixed?
Ah yes, seems like it. Neat.
If you have questions about a specific use case, or you are not sure whether this is a bug or not, please post it to our discourse channel: https://discourse.pymc.io
Description of your problem
I am trying the new JAX-based sampler in the pymc3jax branch, presented in this notebook: https://gist.github.com/twiecki/f0a28dd06620aa86142931c1f10b5434
I can run the notebook as it is just fine, but if I register the data of the model using the pm.Data constructor, I get a MissingInputError.
Essentially, I am replacing Cell 6 in the notebook with this code:
So, I am registering the two input variables and the output variables as pm.Data objects and have replaced their uses in the code below. I can then run the standard sampler without problems, but the JAX sampler (Cell 10) fails.
Please provide the full traceback.
Versions and main components
pymc3jax branch
Theano-Pymc master branch