
Shared variable issues when using NumPyro JAX sampler #4142

Closed
@mschmidt87

Description of your problem

I am trying out the new JAX-based sampler in the pymc3jax branch, presented in this notebook: https://gist.github.com/twiecki/f0a28dd06620aa86142931c1f10b5434
The notebook runs fine as it is, but if I register the model's data with the pm.Data constructor, I get a MissingInputError.

Essentially, I am replacing Cell 6 in the notebook with this code:

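import pymc3 as pm

# `data` (the radon DataFrame) and `n_counties` are defined in earlier notebook cells.
with pm.Model() as hierarchical_model:
    county_idx = pm.Data('county_idx', data.county_code.values.astype('int32'))
    floor = pm.Data('floor', data.floor.values)
    log_radon = pm.Data('log_radon', data.log_radon.values)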
    
    # Hyperpriors for group nodes
    mu_a = pm.Normal('mu_a', mu=0., sigma=100.)
    sigma_a = pm.HalfNormal('sigma_a', 5.)
    mu_b = pm.Normal('mu_b', mu=0., sigma=100.)
    sigma_b = pm.HalfNormal('sigma_b', 5.)

    # Intercept for each county, distributed around group mean mu_a
    # Above we just set mu and sd to a fixed value while here we
    # plug in a common group distribution for all a and b (which are
    # vectors of length n_counties).
    a = pm.Normal('a', mu=mu_a, sigma=sigma_a, shape=n_counties)
    # Slope for each county, distributed around group mean mu_b
    b = pm.Normal('b', mu=mu_b, sigma=sigma_b, shape=n_counties)

    # Model error
    eps = pm.HalfCauchy('eps', 5.)

    radon_est = a[county_idx] + b[county_idx]*floor

    # Data likelihood
    radon_like = pm.Normal('radon_like', mu=radon_est,
                           sigma=eps, observed=log_radon)

So I register the two input variables and the output variable as pm.Data objects and use them in place of the raw arrays in the model code.

With this change I can still sample with the standard samplers without problems, but the JAX sampler (Cell 10) fails.

Please provide the full traceback.

---------------------------------------------------------------------------
MissingInputError                         Traceback (most recent call last)
<timed exec> in <module>

/path/to/pymc3/pymc3/sampling_jax.py in sample_numpyro_nuts(draws, tune, chains, target_accept, random_seed, model, progress_bar)
    114     seed = jax.random.PRNGKey(random_seed)
    115 
--> 116     fgraph = theano.gof.FunctionGraph(model.free_RVs, [model.logpt])
    117     fns = theano.sandbox.jaxify.jax_funcify(fgraph)
    118     logp_fn_jax = fns[0]

/path/to/Theano-PyMC/theano/gof/fg.py in __init__(self, inputs, outputs, features, clone, update_mapping)
    174 
    175         for output in outputs:
--> 176             self.__import_r__(output, reason="init")
    177         for i, output in enumerate(outputs):
    178             output.clients.append(("output", i))

/path/to/Theano-PyMC/theano/gof/fg.py in __import_r__(self, variable, reason)
    347         # Imports the owners of the variables
    348         if variable.owner and variable.owner not in self.apply_nodes:
--> 349             self.__import__(variable.owner, reason=reason)
    350         elif (
    351             variable.owner is None

/path/to/Theano-PyMC/theano/gof/fg.py in __import__(self, apply_node, check, reason)
    399                             % (node.inputs.index(r), str(node))
    400                         )
--> 401                         raise MissingInputError(error_msg, variable=r)
    402 
    403         for node in new_nodes:

MissingInputError: Input 1 of the graph (indices start from 0), used to compute AdvancedSubtensor1(b, county_idx), was not provided and not given a value. Use the Theano flag exception_verbosity='high', for more information on this error.
 
Backtrace when that variable is created:

  File "/path/to/conda/env//lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 2866, in run_cell
    result = self._run_cell(
  File "/path/to/conda/env//lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 2895, in _run_cell
    return runner(coro)
  File "/path/to/conda/env//lib/python3.8/site-packages/IPython/core/async_helpers.py", line 68, in _pseudo_sync_runner
    coro.send(None)
  File "/path/to/conda/env//lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3071, in run_cell_async
    has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
  File "/path/to/conda/env//lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3263, in run_ast_nodes
    if (await self.run_code(code, result,  async_=asy)):
  File "/path/to/conda/env//lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3343, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-17-2577b767217d>", line 2, in <module>
    county_idx = pm.Data('county_idx', data.county_code.values.astype('int32'))
  File "/path/to/pymc3/pymc3/data.py", line 516, in __new__
    shared_object = theano.shared(pm.model.pandas_to_array(value), name)
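The traceback hints at the cause: `pm.Data` wraps the array in `theano.shared` (`data.py`, line 516), while `sample_numpyro_nuts` builds the `FunctionGraph` with only `model.free_RVs` as inputs. In the Theano check that raises here, a variable with no owner must be either a declared input or a constant. A shared variable is neither, whereas the raw NumPy arrays used in the original notebook are presumably turned into constants. The pure-Python toy below mimics that check. `Variable`, `check_graph` and `replace_shared` are hypothetical names, not Theano's API, and swapping the shared data for constants is only one possible workaround, not necessarily how PyMC3 resolves it.

```python
class Variable:
    """Toy graph node: a leaf when `inputs` is None, else the output of an op."""

    def __init__(self, name, inputs=None, kind="intermediate"):
        self.name = name
        self.inputs = inputs  # None marks a leaf (like `owner is None` in Theano)
        self.kind = kind      # "rv", "shared", "constant" or "intermediate"


class MissingInputError(Exception):
    pass


def check_graph(inputs, outputs):
    """Walk back from the outputs; every leaf must be a declared input or a constant."""
    seen = set()
    stack = list(outputs)
    while stack:
        var = stack.pop()
        if id(var) in seen:
            continue
        seen.add(id(var))
        if var.inputs is None:
            if var.kind != "constant" and not any(var is i for i in inputs):
                raise MissingInputError(f"{var.name} was not provided and not given a value")
        else:
            stack.extend(var.inputs)


def replace_shared(var, memo=None):
    """Copy the graph, turning every shared leaf into a constant leaf.

    Other leaves (e.g. free random variables) are reused as-is, so they
    still match the declared inputs.
    """
    if memo is None:
        memo = {}
    if id(var) in memo:
        return memo[id(var)]
    if var.inputs is None:
        new = Variable(var.name, kind="constant") if var.kind == "shared" else var
    else:
        new = Variable(var.name, [replace_shared(i, memo) for i in var.inputs], var.kind)
    memo[id(var)] = new
    return new


# Toy version of the failing piece of the model: logp depends on b[county_idx].
b = Variable("b", kind="rv")                        # free RV, declared as an input
county_idx = Variable("county_idx", kind="shared")  # what pm.Data creates
indexed = Variable("AdvancedSubtensor1(b, county_idx)", [b, county_idx])
logp = Variable("logp", [indexed])

try:
    check_graph(inputs=[b], outputs=[logp])
except MissingInputError as err:
    print("fails:", err)

check_graph(inputs=[b], outputs=[replace_shared(logp)])
print("passes once the shared data is a constant")
```

The toy only models the input check; it does not evaluate anything. In real code the analogous step would be to substitute each shared variable with a constant holding its current value before building the graph, at the cost of the data no longer being swappable via `pm.set_data` without rebuilding.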

Versions and main components

  • PyMC3 Version: checkout of the pymc3jax branch
  • Theano Version: checkout of the Theano-PyMC master branch
  • Python Version: 3.8
  • Operating system: macOS
  • How did you install PyMC3: manual installation of the branch
