Description
Description of your problem
I am trying to use the new JAX-based sampler in the `pymc3jax` branch, presented in this notebook: https://gist.github.com/twiecki/f0a28dd06620aa86142931c1f10b5434
I can run the notebook as is without problems, but if I register the model's data with the `pm.Data` constructor, I get a `MissingInputError`.
Essentially, I am replacing Cell 6 in the notebook with this code:
```python
import pymc3 as pm

# `data` (the radon DataFrame) and `n_counties` are defined in earlier cells of the notebook.
with pm.Model() as hierarchical_model:
    county_idx = pm.Data('county_idx', data.county_code.values.astype('int32'))
    floor = pm.Data('floor', data.floor.values)
    log_radon = pm.Data('log_radon', data.log_radon)

    # Hyperpriors for group nodes
    mu_a = pm.Normal('mu_a', mu=0., sigma=100.)
    sigma_a = pm.HalfNormal('sigma_a', 5.)
    mu_b = pm.Normal('mu_b', mu=0., sigma=100.)
    sigma_b = pm.HalfNormal('sigma_b', 5.)

    # Intercept for each county, distributed around group mean mu_a
    # Above we just set mu and sd to a fixed value while here we
    # plug in a common group distribution for all a and b (which are
    # vectors of length n_counties).
    a = pm.Normal('a', mu=mu_a, sigma=sigma_a, shape=n_counties)
    # Slope (floor effect) for each county, distributed around group mean mu_b
    b = pm.Normal('b', mu=mu_b, sigma=sigma_b, shape=n_counties)

    # Model error
    eps = pm.HalfCauchy('eps', 5.)

    radon_est = a[county_idx] + b[county_idx] * floor

    # Data likelihood
    radon_like = pm.Normal('radon_like', mu=radon_est,
                           sigma=eps, observed=log_radon)
```
So I register the two input variables (`county_idx` and `floor`) and the output variable (`log_radon`) as `pm.Data` objects and use them in place of the raw arrays in the rest of the model.
The standard sampler then runs without problems, but the JAX sampler (Cell 10) fails.
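For reference, the model above can be written out as follows, assuming PyMC3's usual parameterizations (`sigma` is a standard deviation, and the positional arguments of `HalfNormal` and `HalfCauchy` are scale parameters). Here c_i, x_i and y_i stand for the `county_idx`, `floor` and `log_radon` values of row i of `data`, and N is the number of rows. These three are data; only mu_a, sigma_a, mu_b, sigma_b, a, b and eps are free random variables.

$$
\begin{aligned}
\mu_a,\ \mu_b &\sim \mathcal{N}(0,\ 100^2) \\
\sigma_a,\ \sigma_b &\sim \operatorname{HalfNormal}(5) \\
a_j &\sim \mathcal{N}(\mu_a,\ \sigma_a^2), \quad j = 1, \dots, n_{\text{counties}} \\
b_j &\sim \mathcal{N}(\mu_b,\ \sigma_b^2), \quad j = 1, \dots, n_{\text{counties}} \\
\varepsilon &\sim \operatorname{HalfCauchy}(5) \\
y_i &\sim \mathcal{N}\left(a_{c_i} + b_{c_i} x_i,\ \varepsilon^2\right), \quad i = 1, \dots, N
\end{aligned}
$$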
Please provide the full traceback.
```
---------------------------------------------------------------------------
MissingInputError Traceback (most recent call last)
<timed exec> in <module>
/path/to/pymc3/pymc3/sampling_jax.py in sample_numpyro_nuts(draws, tune, chains, target_accept, random_seed, model, progress_bar)
114 seed = jax.random.PRNGKey(random_seed)
115
--> 116 fgraph = theano.gof.FunctionGraph(model.free_RVs, [model.logpt])
117 fns = theano.sandbox.jaxify.jax_funcify(fgraph)
118 logp_fn_jax = fns[0]
/path/to/Theano-PyMC/theano/gof/fg.py in __init__(self, inputs, outputs, features, clone, update_mapping)
174
175 for output in outputs:
--> 176 self.__import_r__(output, reason="init")
177 for i, output in enumerate(outputs):
178 output.clients.append(("output", i))
/path/to/Theano-PyMC/theano/gof/fg.py in __import_r__(self, variable, reason)
347 # Imports the owners of the variables
348 if variable.owner and variable.owner not in self.apply_nodes:
--> 349 self.__import__(variable.owner, reason=reason)
350 elif (
351 variable.owner is None
/path/to/Theano-PyMC/theano/gof/fg.py in __import__(self, apply_node, check, reason)
399 % (node.inputs.index(r), str(node))
400 )
--> 401 raise MissingInputError(error_msg, variable=r)
402
403 for node in new_nodes:
MissingInputError: Input 1 of the graph (indices start from 0), used to compute AdvancedSubtensor1(b, county_idx), was not provided and not given a value. Use the Theano flag exception_verbosity='high', for more information on this error.
Backtrace when that variable is created:
File "/path/to/conda/env//lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 2866, in run_cell
result = self._run_cell(
File "/path/to/conda/env//lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 2895, in _run_cell
return runner(coro)
File "/path/to/conda/env//lib/python3.8/site-packages/IPython/core/async_helpers.py", line 68, in _pseudo_sync_runner
coro.send(None)
File "/path/to/conda/env//lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3071, in run_cell_async
has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
File "/path/to/conda/env//lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3263, in run_ast_nodes
if (await self.run_code(code, result, async_=asy)):
File "/path/to/conda/env//lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3343, in run_code
exec(code_obj, self.user_global_ns, self.user_ns)
File "<ipython-input-17-2577b767217d>", line 2, in <module>
county_idx = pm.Data('county_idx', data.county_code.values.astype('int32'))
File "/path/to/pymc3/pymc3/data.py", line 516, in __new__
shared_object = theano.shared(pm.model.pandas_to_array(value), name)
```
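Reading the traceback, the failure appears to come from how the graph is built: `sample_numpyro_nuts` constructs `theano.gof.FunctionGraph(model.free_RVs, [model.logpt])`, so the only declared inputs are the free random variables, while `pm.Data` wraps each array in `theano.shared(...)`. The shared `county_idx` is therefore a leaf that is neither a declared input nor a constant, which is consistent with the "was not provided and not given a value" message. The toy graph below is a hypothetical, stdlib-only sketch of that kind of input check and of one possible workaround (freezing shared values into constants before building the graph). It is not Theano's implementation; `Var`, `Apply`, `build_graph` and `freeze_shared` are made-up names for illustration.

```python
# Toy expression graph (hypothetical, NOT Theano's implementation) mimicking the
# check behind MissingInputError: every non-constant leaf must be a declared input.


class Var:
    """Leaf node: a free parameter, a shared (mutable) data value, or a constant."""

    def __init__(self, name, kind, value=None):
        if kind not in ("free", "shared", "constant"):
            raise ValueError(f"unknown kind {kind!r}")
        self.name, self.kind, self.value = name, kind, value


class Apply:
    """Operation node applied to child nodes (the op is only a label here)."""

    def __init__(self, op_name, *children):
        self.op_name, self.children = op_name, children


class MissingInputError(Exception):
    pass


def leaves(node):
    """Yield every leaf Var reachable from `node`, depth first."""
    if isinstance(node, Var):
        yield node
    else:
        for child in node.children:
            yield from leaves(child)


def build_graph(inputs, output):
    """Accept the graph only if each non-constant leaf is one of `inputs`."""
    for leaf in leaves(output):
        if leaf.kind != "constant" and not any(leaf is inp for inp in inputs):
            raise MissingInputError(f"{leaf.name!r} was not provided")
    return inputs, output


def freeze_shared(node):
    """Return a copy of the graph with shared leaves replaced by constants."""
    if isinstance(node, Var):
        if node.kind == "shared":
            return Var(node.name, "constant", node.value)
        return node
    return Apply(node.op_name, *(freeze_shared(c) for c in node.children))


# Usage: b plays the role of a free RV, county_idx the role of a pm.Data container.
b = Var("b", "free")
county_idx = Var("county_idx", "shared", value=[0, 1, 1])
expr = Apply("AdvancedSubtensor1", b, county_idx)

try:
    build_graph([b], expr)  # analogous to FunctionGraph(model.free_RVs, [...])
except MissingInputError as err:
    print("error:", err)  # error: 'county_idx' was not provided

build_graph([b], freeze_shared(expr))  # passes: county_idx is now a constant
```

In the real code, the analogous step would presumably be substituting each shared variable in `model.logpt` with a constant holding its current value before creating the `FunctionGraph`; that is an untested assumption, not a confirmed fix.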
Versions and main components
- PyMC3 Version: checkout of the `pymc3jax` branch
- Theano Version: checkout of the `Theano-PyMC` master branch
- Python Version: 3.8
- Operating system: Mac OS
- How did you install PyMC3: manual installation of the branch