
Create a NumPyroNUTS Op #4646

Merged (1 commit) May 12, 2021
Conversation

@brandonwillard (Contributor) commented Apr 15, 2021

This PR creates a NumPyroNUTS Op for better integration with Aesara's JAX backend. Closes #4142.

The current implementation doesn't un-transform the transformed variables.

Also, someone who knows the NumPyro functions needs to confirm that the output dimensions are correct.

Really, this is just a prototype that finally demonstrates the point I was trying to make in #4142. Someone who has more time/interest should take over this PR and add the finishing touches, because I don't know when I'll get around to doing it myself.
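The general idea, independent of the PR's actual implementation, is to wrap the entire NUTS run as a single node in the computation graph, whose inputs feed the model's log-likelihood and whose outputs are the posterior samples. A pure-Python schematic of that shape (the class and names below are illustrative stand-ins, not the real Aesara Op API):

```python
# Schematic of the "sampler as a graph Op" pattern (not the actual
# Aesara/PyMC3 implementation): the whole sampling run is one node
# whose outputs are the draws.

class SamplerOp:
    """Toy stand-in for an Op like NumPyroNUTS: it owns an inner
    log-likelihood function and produces samples when 'performed'."""

    def __init__(self, logp_fn, draws):
        self.logp_fn = logp_fn  # the "inner graph" being sampled
        self.draws = draws

    def perform(self, candidate_points):
        # A real Op would invoke NumPyro's MCMC here; this toy version
        # just returns the highest-logp candidate, repeated `draws` times.
        best = max(candidate_points, key=self.logp_fn)
        return [best] * self.draws

op = SamplerOp(logp_fn=lambda x: -(x - 3.0) ** 2, draws=5)
samples = op.perform([0.0, 1.0, 2.0, 3.0, 4.0])
```

The payoff of this structure is that the whole node, inner log-likelihood included, can be translated to a single JAX function and JIT-compiled end to end.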

@brandonwillard brandonwillard force-pushed the create-numpyro-op branch 3 times, most recently from 07fc65a to d24fbf7 on April 17, 2021 01:43
@brandonwillard brandonwillard marked this pull request as ready for review April 17, 2021 01:44
@brandonwillard (Contributor, Author)

The un-transforming is now done in the same JAX run as the sampling.

@brandonwillard brandonwillard linked an issue Apr 17, 2021 that may be closed by this pull request
@brandonwillard (Contributor, Author) commented Apr 18, 2021

Actually, it looks like something extra is needed for shared variables under the current approach.

Right now, shared variables used in a log-likelihood graph only ever appear in the "inner" FunctionGraph (i.e. the FunctionGraph created inside of the NumPyroNUTS Op). This means that storage is never set up for those shared variables and, even if it were, it wouldn't be used by the inner FunctionGraph.
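The storage problem described above can be illustrated without Aesara at all: if the inner function reads the shared value once when it is built, later updates to the outer storage cell are invisible to it. A minimal sketch (the dict-as-storage is a hypothetical stand-in for a shared variable's storage cell):

```python
# Sketch of the shared-variable problem: the "inner" function bakes in
# the shared value it saw at build time, so updating the outer storage
# afterwards has no effect on it.

shared_storage = {"sigma": 1.0}  # stand-in for a shared variable's storage

def build_inner_logp():
    sigma = shared_storage["sigma"]        # captured once, at build time
    return lambda x: -(x / sigma) ** 2     # never re-reads the storage cell

inner_logp = build_inner_logp()
shared_storage["sigma"] = 10.0             # update is invisible to the inner graph
stale = inner_logp(2.0)                    # still computed with sigma == 1.0
```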

@brandonwillard brandonwillard marked this pull request as draft April 18, 2021 23:59
@brandonwillard brandonwillard marked this pull request as ready for review April 19, 2021 00:20
@brandonwillard (Contributor, Author)

All right, I figured out a way to use the shared variables from within the "inner" FunctionGraph (i.e. PyMC3's log-likelihood graph); however, this approach probably won't work when/if the shared variables are updated within the log-likelihood graph.

That should be an acceptable limitation, because log-likelihood graphs probably shouldn't have any sort of "state". Plus, if they did, I don't think it would work with PyMC3's samplers either, for roughly the same reasons.
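The limitation can be sketched in plain Python: feeding shared values into the inner graph as explicit, read-only inputs works, but any "update" the inner graph makes to them is local and never propagates back to the outer storage (the names below are illustrative, not PyMC3's API):

```python
# Hedged sketch of the approach and its limitation: shared values are
# passed in as extra inputs (read-only); updates made inside the inner
# graph never reach the outer storage.

shared = {"count": 0}

def logp_with_state(x, count):
    # a "stateful" log-likelihood that tries to bump a shared counter
    count = count + 1        # local only: the outer dict is untouched
    return -x * x, count

value, new_count = logp_with_state(2.0, shared["count"])
# new_count is 1 inside the call, but shared["count"] remains 0,
# so the in-graph update is silently lost.
```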

@twiecki (Member) commented May 5, 2021

I tried running this on https://github.com/pymc-devs/pymc-examples/blob/main/examples/case_studies/factor_analysis.ipynb but got:

k = 2

with pm.Model() as PPCA:
    W = pm.Normal("W", size=(d, k))
    F = pm.Normal("F", size=(k, n))
    psi = pm.HalfNormal("psi", 1.0)
    X = pm.Normal("X", mu=tt.dot(W, F), sigma=psi, observed=Y)

    # select a subset of weights and factors to plot
    W_plot = pm.Deterministic("W_plot", W[1:3, 0])
    F_plot = pm.Deterministic("F_plot", F[0, 1:3])

    trace = pm.sampling_jax.sample_numpyro_nuts() #N_SAMPLE, chains=4, cores=1, init="advi+adapt_diag")

az.plot_trace(trace, ("W_plot", "F_plot", "psi"));

Compiling...
Compilation time =  0 days 00:00:00.158433
Sampling...

---------------------------------------------------------------------------
FilteredStackTrace                        Traceback (most recent call last)
<ipython-input-10-a13f8033a7ad> in <module>
     10 
---> 11     trace = pm.sampling_jax.sample_numpyro_nuts() #N_SAMPLE, chains=4, cores=1, init="advi+adapt_diag")
     12 

~/projects/pymc/pymc3/sampling_jax.py in sample_numpyro_nuts(draws, tune, chains, target_accept, random_seed, model, progress_bar, keep_untransformed)
    225 
--> 226     *mcmc_samples, leapfrogs_taken = _sample()
    227     tic3 = pd.Timestamp.now()

~/projects/aesara/aesara/compile/function/types.py in __call__(self, *args, **kwargs)
    974             outputs = (
--> 975                 self.fn()
    976                 if output_subset is None

~/projects/aesara/aesara/link/utils.py in streamline_default_f()
    186                 ):
--> 187                     thunk()
    188                     for old_s in old_storage:

~/projects/aesara/aesara/link/basic.py in thunk(fgraph, fgraph_jit, thunk_inputs, thunk_outputs)
    715         ):
--> 716             outputs = fgraph_jit(*[x[0] for x in thunk_inputs])
    717 

/var/folders/mn/0x4pxw0n61lf479ndp07r0gr0000gn/T/tmp3ap5m6jp in jax_funcified_fgraph()
      2 def jax_funcified_fgraph():
----> 3     W, F, auto_2707, auto_2708 = _sample(auto_2702, auto_2703, auto_2704)
      4     psi = exp(auto_2707)

~/projects/pymc/pymc3/sampling_jax.py in _sample(*inputs)
    148 
--> 149         pmap_numpyro.run(seed, init_params=current_state, extra_fields=("num_steps",))
    150         samples = pmap_numpyro.get_samples(group_by_chain=True)

~/miniconda3/envs/pymc3v4/lib/python3.9/site-packages/numpyro/infer/mcmc.py in run(self, rng_key, extra_fields, init_params, *args, **kwargs)
    503             elif self.chain_method == 'parallel':
--> 504                 states, last_state = pmap(partial_map_fn)(map_args)
    505             else:

~/miniconda3/envs/pymc3v4/lib/python3.9/site-packages/numpyro/infer/mcmc.py in _single_chain_mcmc(self, init, args, kwargs, collect_fields)
    332         if init_state is None:
--> 333             init_state = self.sampler.init(rng_key, self.num_warmup, init_params,
    334                                            model_args=args, model_kwargs=kwargs)

~/miniconda3/envs/pymc3v4/lib/python3.9/site-packages/numpyro/infer/hmc.py in init(self, rng_key, num_warmup, init_params, model_args, model_kwargs)
    504         if rng_key.ndim == 1:
--> 505             init_state = hmc_init_fn(init_params, rng_key)
    506         else:

~/miniconda3/envs/pymc3v4/lib/python3.9/site-packages/numpyro/infer/hmc.py in <lambda>(init_params, rng_key)
    487 
--> 488         hmc_init_fn = lambda init_params, rng_key: self._init_fn(  # noqa: E731
    489             init_params,

~/miniconda3/envs/pymc3v4/lib/python3.9/site-packages/numpyro/infer/hmc.py in init_kernel(init_params, num_warmup, step_size, inverse_mass_matrix, adapt_step_size, adapt_mass_matrix, dense_mass, target_accept_prob, trajectory_length, max_tree_depth, find_heuristic_step_size, forward_mode_differentiation, model_args, model_kwargs, rng_key)
    210         step_size = lax.convert_element_type(step_size, jnp.result_type(float))
--> 211         trajectory_length = lax.convert_element_type(trajectory_length, jnp.result_type(float))
    212         nonlocal wa_update, max_treedepth, vv_update, wa_steps, forward_mode_ad

FilteredStackTrace: TypeError: Value None with type <class 'NoneType'> is not a valid JAX type

The stack trace above excludes JAX-internal frames.
The following is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

TypeError                                 Traceback (most recent call last)
~/projects/aesara/aesara/link/utils.py in streamline_default_f()
    186                 ):
--> 187                     thunk()
    188                     for old_s in old_storage:

~/projects/aesara/aesara/link/basic.py in thunk(fgraph, fgraph_jit, thunk_inputs, thunk_outputs)
    715         ):
--> 716             outputs = fgraph_jit(*[x[0] for x in thunk_inputs])
    717 

~/miniconda3/envs/pymc3v4/lib/python3.9/site-packages/jax/_src/traceback_util.py in reraise_with_filtered_traceback(*args, **kwargs)
    138     try:
--> 139       return fun(*args, **kwargs)
    140     except Exception as e:

~/miniconda3/envs/pymc3v4/lib/python3.9/site-packages/jax/api.py in cache_miss(*args, **kwargs)
    331     flat_fun, out_tree = flatten_fun(f, in_tree)
--> 332     out_flat = xla.xla_call(
    333         flat_fun,

~/miniconda3/envs/pymc3v4/lib/python3.9/site-packages/jax/core.py in bind(self, fun, *args, **params)
   1401   def bind(self, fun, *args, **params):
-> 1402     return call_bind(self, fun, *args, **params)
   1403 

~/miniconda3/envs/pymc3v4/lib/python3.9/site-packages/jax/core.py in call_bind(primitive, fun, *args, **params)
   1392   with maybe_new_sublevel(top_trace):
-> 1393     outs = primitive.process(top_trace, fun, tracers, params)
   1394   return map(full_lower, apply_todos(env_trace_todo(), outs))

~/miniconda3/envs/pymc3v4/lib/python3.9/site-packages/jax/core.py in process(self, trace, fun, tracers, params)
   1404   def process(self, trace, fun, tracers, params):
-> 1405     return trace.process_call(self, fun, tracers, params)
   1406 

~/miniconda3/envs/pymc3v4/lib/python3.9/site-packages/jax/core.py in process_call(self, primitive, f, tracers, params)
    599   def process_call(self, primitive, f, tracers, params):
--> 600     return primitive.impl(f, *tracers, **params)
    601   process_map = process_call

~/miniconda3/envs/pymc3v4/lib/python3.9/site-packages/jax/interpreters/xla.py in _xla_call_impl(fun, device, backend, name, donated_invars, *args)
    575 def _xla_call_impl(fun: lu.WrappedFun, *args, device, backend, name, donated_invars):
--> 576   compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
    577                                *unsafe_map(arg_spec, args))

~/miniconda3/envs/pymc3v4/lib/python3.9/site-packages/jax/linear_util.py in memoized_fun(fun, *args)
    259     else:
--> 260       ans = call(fun, *args)
    261       cache[key] = (ans, fun.stores)

~/miniconda3/envs/pymc3v4/lib/python3.9/site-packages/jax/interpreters/xla.py in _xla_callable(fun, device, backend, name, donated_invars, *arg_specs)
    651   abstract_args, arg_devices = unzip2(arg_specs)
--> 652   jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(fun, abstract_args, transform_name="jit")
    653   if any(isinstance(c, core.Tracer) for c in consts):

~/miniconda3/envs/pymc3v4/lib/python3.9/site-packages/jax/interpreters/partial_eval.py in trace_to_jaxpr_final(fun, in_avals, transform_name)
   1208     main.jaxpr_stack = ()  # type: ignore
-> 1209     jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
   1210     del fun, main

~/miniconda3/envs/pymc3v4/lib/python3.9/site-packages/jax/interpreters/partial_eval.py in trace_to_subjaxpr_dynamic(fun, main, in_avals)
   1187     in_tracers = map(trace.new_arg, in_avals)
-> 1188     ans = fun.call_wrapped(*in_tracers)
   1189     out_tracers = map(trace.full_raise, ans)

~/miniconda3/envs/pymc3v4/lib/python3.9/site-packages/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
    165     try:
--> 166       ans = self.f(*args, **dict(self.params, **kwargs))
    167     except:

/var/folders/mn/0x4pxw0n61lf479ndp07r0gr0000gn/T/tmp3ap5m6jp in jax_funcified_fgraph()
      2 def jax_funcified_fgraph():
----> 3     W, F, auto_2707, auto_2708 = _sample(auto_2702, auto_2703, auto_2704)
      4     psi = exp(auto_2707)

~/projects/pymc/pymc3/sampling_jax.py in _sample(*inputs)
    148 
--> 149         pmap_numpyro.run(seed, init_params=current_state, extra_fields=("num_steps",))
    150         samples = pmap_numpyro.get_samples(group_by_chain=True)

~/miniconda3/envs/pymc3v4/lib/python3.9/site-packages/numpyro/infer/mcmc.py in run(self, rng_key, extra_fields, init_params, *args, **kwargs)
    503             elif self.chain_method == 'parallel':
--> 504                 states, last_state = pmap(partial_map_fn)(map_args)
    505             else:

~/miniconda3/envs/pymc3v4/lib/python3.9/site-packages/jax/_src/traceback_util.py in reraise_with_filtered_traceback(*args, **kwargs)
    138     try:
--> 139       return fun(*args, **kwargs)
    140     except Exception as e:

~/miniconda3/envs/pymc3v4/lib/python3.9/site-packages/jax/api.py in f_pmapped(*args, **kwargs)
   1563         closure=out_axes)
-> 1564     out = pxla.xla_pmap(
   1565         flat_fun, *args, backend=backend, axis_name=axis_name,

~/miniconda3/envs/pymc3v4/lib/python3.9/site-packages/jax/core.py in bind(self, fun, *args, **params)
   1460     assert len(params['in_axes']) == len(args)
-> 1461     return call_bind(self, fun, *args, **params)
   1462 

~/miniconda3/envs/pymc3v4/lib/python3.9/site-packages/jax/core.py in call_bind(primitive, fun, *args, **params)
   1392   with maybe_new_sublevel(top_trace):
-> 1393     outs = primitive.process(top_trace, fun, tracers, params)
   1394   return map(full_lower, apply_todos(env_trace_todo(), outs))

~/miniconda3/envs/pymc3v4/lib/python3.9/site-packages/jax/core.py in process(self, trace, fun, tracers, params)
   1463   def process(self, trace, fun, tracers, params):
-> 1464     return trace.process_map(self, fun, tracers, params)
   1465 

~/miniconda3/envs/pymc3v4/lib/python3.9/site-packages/jax/interpreters/partial_eval.py in process_map(self, map_primitive, f, tracers, params)
   1087     with core.extend_axis_env(axis_name, axis_size, None):  # type: ignore
-> 1088       jaxpr, reduced_out_avals, consts = trace_to_subjaxpr_dynamic(
   1089           f, self.main, reduced_in_avals)

~/miniconda3/envs/pymc3v4/lib/python3.9/site-packages/jax/interpreters/partial_eval.py in trace_to_subjaxpr_dynamic(fun, main, in_avals)
   1187     in_tracers = map(trace.new_arg, in_avals)
-> 1188     ans = fun.call_wrapped(*in_tracers)
   1189     out_tracers = map(trace.full_raise, ans)

~/miniconda3/envs/pymc3v4/lib/python3.9/site-packages/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
    165     try:
--> 166       ans = self.f(*args, **dict(self.params, **kwargs))
    167     except:

~/miniconda3/envs/pymc3v4/lib/python3.9/site-packages/numpyro/infer/mcmc.py in _single_chain_mcmc(self, init, args, kwargs, collect_fields)
    332         if init_state is None:
--> 333             init_state = self.sampler.init(rng_key, self.num_warmup, init_params,
    334                                            model_args=args, model_kwargs=kwargs)

~/miniconda3/envs/pymc3v4/lib/python3.9/site-packages/numpyro/infer/hmc.py in init(self, rng_key, num_warmup, init_params, model_args, model_kwargs)
    504         if rng_key.ndim == 1:
--> 505             init_state = hmc_init_fn(init_params, rng_key)
    506         else:

~/miniconda3/envs/pymc3v4/lib/python3.9/site-packages/numpyro/infer/hmc.py in <lambda>(init_params, rng_key)
    487 
--> 488         hmc_init_fn = lambda init_params, rng_key: self._init_fn(  # noqa: E731
    489             init_params,

~/miniconda3/envs/pymc3v4/lib/python3.9/site-packages/numpyro/infer/hmc.py in init_kernel(init_params, num_warmup, step_size, inverse_mass_matrix, adapt_step_size, adapt_mass_matrix, dense_mass, target_accept_prob, trajectory_length, max_tree_depth, find_heuristic_step_size, forward_mode_differentiation, model_args, model_kwargs, rng_key)
    210         step_size = lax.convert_element_type(step_size, jnp.result_type(float))
--> 211         trajectory_length = lax.convert_element_type(trajectory_length, jnp.result_type(float))
    212         nonlocal wa_update, max_treedepth, vv_update, wa_steps, forward_mode_ad

~/miniconda3/envs/pymc3v4/lib/python3.9/site-packages/jax/_src/lax/lax.py in convert_element_type(operand, new_dtype)
    428     operand = operand.__jax_array__()
--> 429   return _convert_element_type(operand, new_dtype, weak_type=False)
    430 

~/miniconda3/envs/pymc3v4/lib/python3.9/site-packages/jax/_src/lax/lax.py in _convert_element_type(operand, new_dtype, weak_type)
    457   else:
--> 458     return convert_element_type_p.bind(operand, new_dtype=new_dtype,
    459                                        weak_type=new_weak_type)

~/miniconda3/envs/pymc3v4/lib/python3.9/site-packages/jax/core.py in bind(self, *args, **params)
    257     top_trace = find_top_trace(args)
--> 258     tracers = map(top_trace.full_raise, args)
    259     out = top_trace.process_primitive(self, tracers, params)

~/miniconda3/envs/pymc3v4/lib/python3.9/site-packages/jax/_src/util.py in safe_map(f, *args)
     39     assert len(arg) == n, 'length mismatch: {}'.format(list(map(len, args)))
---> 40   return list(map(f, *args))
     41 

~/miniconda3/envs/pymc3v4/lib/python3.9/site-packages/jax/core.py in full_raise(self, val)
    364     if not isinstance(val, Tracer):
--> 365       return self.pure(val)
    366     val._assert_live()

~/miniconda3/envs/pymc3v4/lib/python3.9/site-packages/jax/interpreters/partial_eval.py in new_const(self, val)
   1011   def new_const(self, val):
-> 1012     aval = raise_to_shaped(get_aval(val), weak_type=dtypes.is_weakly_typed(val))
   1013     tracer = DynamicJaxprTracer(self, aval, source_info_util.current())

~/miniconda3/envs/pymc3v4/lib/python3.9/site-packages/jax/core.py in get_aval(x)
    926   else:
--> 927     return concrete_aval(x)
    928 

~/miniconda3/envs/pymc3v4/lib/python3.9/site-packages/jax/core.py in concrete_aval(x)
    918     return concrete_aval(x.__jax_array__())
--> 919   raise TypeError(f"Value {repr(x)} with type {type(x)} is not a valid JAX "
    920                    "type")

TypeError: Value None with type <class 'NoneType'> is not a valid JAX type

During handling of the above exception, another exception occurred:

TypeError                                 Traceback (most recent call last)
<ipython-input-10-a13f8033a7ad> in <module>
      9     F_plot = pm.Deterministic("F_plot", F[0, 1:3])
     10 
---> 11     trace = pm.sampling_jax.sample_numpyro_nuts() #N_SAMPLE, chains=4, cores=1, init="advi+adapt_diag")
     12 
     13 az.plot_trace(trace, ("W_plot", "F_plot", "psi"));

~/projects/pymc/pymc3/sampling_jax.py in sample_numpyro_nuts(draws, tune, chains, target_accept, random_seed, model, progress_bar, keep_untransformed)
    224     print("Sampling...")
    225 
--> 226     *mcmc_samples, leapfrogs_taken = _sample()
    227     tic3 = pd.Timestamp.now()
    228 

~/projects/aesara/aesara/compile/function/types.py in __call__(self, *args, **kwargs)
    973         try:
    974             outputs = (
--> 975                 self.fn()
    976                 if output_subset is None
    977                 else self.fn(output_subset=output_subset)

~/projects/aesara/aesara/link/utils.py in streamline_default_f()
    189                         old_s[0] = None
    190             except Exception:
--> 191                 raise_with_op(fgraph, node, thunk)
    192 
    193         f = streamline_default_f

~/projects/aesara/aesara/link/utils.py in raise_with_op(fgraph, node, thunk, exc_info, storage_map)
    519         # Some exception need extra parameter in inputs. So forget the
    520         # extra long error message in that case.
--> 521     raise exc_value.with_traceback(exc_trace)
    522 
    523 

~/projects/aesara/aesara/link/utils.py in streamline_default_f()
    185                     thunks, order, post_thunk_old_storage
    186                 ):
--> 187                     thunk()
    188                     for old_s in old_storage:
    189                         old_s[0] = None

~/projects/aesara/aesara/link/basic.py in thunk(fgraph, fgraph_jit, thunk_inputs, thunk_outputs)
    714             thunk_outputs=thunk_outputs,
    715         ):
--> 716             outputs = fgraph_jit(*[x[0] for x in thunk_inputs])
    717 
    718             for o_node, o_storage, o_val in zip(fgraph.outputs, thunk_outputs, outputs):

~/miniconda3/envs/pymc3v4/lib/python3.9/site-packages/jax/_src/traceback_util.py in reraise_with_filtered_traceback(*args, **kwargs)
    137   def reraise_with_filtered_traceback(*args, **kwargs):
    138     try:
--> 139       return fun(*args, **kwargs)
    140     except Exception as e:
    141       if not is_under_reraiser(e):

~/miniconda3/envs/pymc3v4/lib/python3.9/site-packages/jax/api.py in cache_miss(*args, **kwargs)
    330       _check_arg(arg)
    331     flat_fun, out_tree = flatten_fun(f, in_tree)
--> 332     out_flat = xla.xla_call(
    333         flat_fun,
    334         *args_flat,

~/miniconda3/envs/pymc3v4/lib/python3.9/site-packages/jax/core.py in bind(self, fun, *args, **params)
   1400 
   1401   def bind(self, fun, *args, **params):
-> 1402     return call_bind(self, fun, *args, **params)
   1403 
   1404   def process(self, trace, fun, tracers, params):

~/miniconda3/envs/pymc3v4/lib/python3.9/site-packages/jax/core.py in call_bind(primitive, fun, *args, **params)
   1391   tracers = map(top_trace.full_raise, args)
   1392   with maybe_new_sublevel(top_trace):
-> 1393     outs = primitive.process(top_trace, fun, tracers, params)
   1394   return map(full_lower, apply_todos(env_trace_todo(), outs))
   1395 

~/miniconda3/envs/pymc3v4/lib/python3.9/site-packages/jax/core.py in process(self, trace, fun, tracers, params)
   1403 
   1404   def process(self, trace, fun, tracers, params):
-> 1405     return trace.process_call(self, fun, tracers, params)
   1406 
   1407   def post_process(self, trace, out_tracers, params):

~/miniconda3/envs/pymc3v4/lib/python3.9/site-packages/jax/core.py in process_call(self, primitive, f, tracers, params)
    598 
    599   def process_call(self, primitive, f, tracers, params):
--> 600     return primitive.impl(f, *tracers, **params)
    601   process_map = process_call
    602 

~/miniconda3/envs/pymc3v4/lib/python3.9/site-packages/jax/interpreters/xla.py in _xla_call_impl(fun, device, backend, name, donated_invars, *args)
    574 
    575 def _xla_call_impl(fun: lu.WrappedFun, *args, device, backend, name, donated_invars):
--> 576   compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
    577                                *unsafe_map(arg_spec, args))
    578   try:

~/miniconda3/envs/pymc3v4/lib/python3.9/site-packages/jax/linear_util.py in memoized_fun(fun, *args)
    258       fun.populate_stores(stores)
    259     else:
--> 260       ans = call(fun, *args)
    261       cache[key] = (ans, fun.stores)
    262 

~/miniconda3/envs/pymc3v4/lib/python3.9/site-packages/jax/interpreters/xla.py in _xla_callable(fun, device, backend, name, donated_invars, *arg_specs)
    650 
    651   abstract_args, arg_devices = unzip2(arg_specs)
--> 652   jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(fun, abstract_args, transform_name="jit")
    653   if any(isinstance(c, core.Tracer) for c in consts):
    654     raise core.UnexpectedTracerError("Encountered an unexpected tracer.")

~/miniconda3/envs/pymc3v4/lib/python3.9/site-packages/jax/interpreters/partial_eval.py in trace_to_jaxpr_final(fun, in_avals, transform_name)
   1207     main.source_info = fun_sourceinfo(fun.f, transform_name)  # type: ignore
   1208     main.jaxpr_stack = ()  # type: ignore
-> 1209     jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
   1210     del fun, main
   1211   return jaxpr, out_avals, consts

~/miniconda3/envs/pymc3v4/lib/python3.9/site-packages/jax/interpreters/partial_eval.py in trace_to_subjaxpr_dynamic(fun, main, in_avals)
   1186     trace = DynamicJaxprTrace(main, core.cur_sublevel())
   1187     in_tracers = map(trace.new_arg, in_avals)
-> 1188     ans = fun.call_wrapped(*in_tracers)
   1189     out_tracers = map(trace.full_raise, ans)
   1190     jaxpr, out_avals, consts = frame.to_jaxpr(in_tracers, out_tracers)

~/miniconda3/envs/pymc3v4/lib/python3.9/site-packages/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
    164 
    165     try:
--> 166       ans = self.f(*args, **dict(self.params, **kwargs))
    167     except:
    168       # Some transformations yield from inside context managers, so we have to

/var/folders/mn/0x4pxw0n61lf479ndp07r0gr0000gn/T/tmp3ap5m6jp in jax_funcified_fgraph()
      1 
      2 def jax_funcified_fgraph():
----> 3     W, F, auto_2707, auto_2708 = _sample(auto_2702, auto_2703, auto_2704)
      4     psi = exp(auto_2707)
      5     return W, F, psi, auto_2708

~/projects/pymc/pymc3/sampling_jax.py in _sample(*inputs)
    147         )
    148 
--> 149         pmap_numpyro.run(seed, init_params=current_state, extra_fields=("num_steps",))
    150         samples = pmap_numpyro.get_samples(group_by_chain=True)
    151         leapfrogs_taken = pmap_numpyro.get_extra_fields(group_by_chain=True)["num_steps"]

~/miniconda3/envs/pymc3v4/lib/python3.9/site-packages/numpyro/infer/mcmc.py in run(self, rng_key, extra_fields, init_params, *args, **kwargs)
    502                 states, last_state = _laxmap(partial_map_fn, map_args)
    503             elif self.chain_method == 'parallel':
--> 504                 states, last_state = pmap(partial_map_fn)(map_args)
    505             else:
    506                 assert self.chain_method == 'vectorized'

~/miniconda3/envs/pymc3v4/lib/python3.9/site-packages/jax/_src/traceback_util.py in reraise_with_filtered_traceback(*args, **kwargs)
    137   def reraise_with_filtered_traceback(*args, **kwargs):
    138     try:
--> 139       return fun(*args, **kwargs)
    140     except Exception as e:
    141       if not is_under_reraiser(e):

~/miniconda3/envs/pymc3v4/lib/python3.9/site-packages/jax/api.py in f_pmapped(*args, **kwargs)
   1562         lambda: tuple(flatten_axes("pmap out_axes", out_tree(), out_axes)),
   1563         closure=out_axes)
-> 1564     out = pxla.xla_pmap(
   1565         flat_fun, *args, backend=backend, axis_name=axis_name,
   1566         axis_size=local_axis_size, global_axis_size=axis_size,

~/miniconda3/envs/pymc3v4/lib/python3.9/site-packages/jax/core.py in bind(self, fun, *args, **params)
   1459   def bind(self, fun, *args, **params):
   1460     assert len(params['in_axes']) == len(args)
-> 1461     return call_bind(self, fun, *args, **params)
   1462 
   1463   def process(self, trace, fun, tracers, params):

~/miniconda3/envs/pymc3v4/lib/python3.9/site-packages/jax/core.py in call_bind(primitive, fun, *args, **params)
   1391   tracers = map(top_trace.full_raise, args)
   1392   with maybe_new_sublevel(top_trace):
-> 1393     outs = primitive.process(top_trace, fun, tracers, params)
   1394   return map(full_lower, apply_todos(env_trace_todo(), outs))
   1395 

~/miniconda3/envs/pymc3v4/lib/python3.9/site-packages/jax/core.py in process(self, trace, fun, tracers, params)
   1462 
   1463   def process(self, trace, fun, tracers, params):
-> 1464     return trace.process_map(self, fun, tracers, params)
   1465 
   1466   def post_process(self, trace, out_tracers, params):

~/miniconda3/envs/pymc3v4/lib/python3.9/site-packages/jax/interpreters/partial_eval.py in process_map(self, map_primitive, f, tracers, params)
   1086                         for a, in_axis in zip(in_avals, params['in_axes'])]
   1087     with core.extend_axis_env(axis_name, axis_size, None):  # type: ignore
-> 1088       jaxpr, reduced_out_avals, consts = trace_to_subjaxpr_dynamic(
   1089           f, self.main, reduced_in_avals)
   1090       out_axes = params['out_axes_thunk']()

~/miniconda3/envs/pymc3v4/lib/python3.9/site-packages/jax/interpreters/partial_eval.py in trace_to_subjaxpr_dynamic(fun, main, in_avals)
   1186     trace = DynamicJaxprTrace(main, core.cur_sublevel())
   1187     in_tracers = map(trace.new_arg, in_avals)
-> 1188     ans = fun.call_wrapped(*in_tracers)
   1189     out_tracers = map(trace.full_raise, ans)
   1190     jaxpr, out_avals, consts = frame.to_jaxpr(in_tracers, out_tracers)

~/miniconda3/envs/pymc3v4/lib/python3.9/site-packages/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
    164 
    165     try:
--> 166       ans = self.f(*args, **dict(self.params, **kwargs))
    167     except:
    168       # Some transformations yield from inside context managers, so we have to

~/miniconda3/envs/pymc3v4/lib/python3.9/site-packages/numpyro/infer/mcmc.py in _single_chain_mcmc(self, init, args, kwargs, collect_fields)
    331         rng_key, init_state, init_params = init
    332         if init_state is None:
--> 333             init_state = self.sampler.init(rng_key, self.num_warmup, init_params,
    334                                            model_args=args, model_kwargs=kwargs)
    335         sample_fn, postprocess_fn = self._get_cached_fns()

~/miniconda3/envs/pymc3v4/lib/python3.9/site-packages/numpyro/infer/hmc.py in init(self, rng_key, num_warmup, init_params, model_args, model_kwargs)
    503         )
    504         if rng_key.ndim == 1:
--> 505             init_state = hmc_init_fn(init_params, rng_key)
    506         else:
    507             # XXX it is safe to run hmc_init_fn under vmap despite that hmc_init_fn changes some

~/miniconda3/envs/pymc3v4/lib/python3.9/site-packages/numpyro/infer/hmc.py in <lambda>(init_params, rng_key)
    486                              ' `potential_fn`.')
    487 
--> 488         hmc_init_fn = lambda init_params, rng_key: self._init_fn(  # noqa: E731
    489             init_params,
    490             num_warmup=num_warmup,

~/miniconda3/envs/pymc3v4/lib/python3.9/site-packages/numpyro/infer/hmc.py in init_kernel(init_params, num_warmup, step_size, inverse_mass_matrix, adapt_step_size, adapt_mass_matrix, dense_mass, target_accept_prob, trajectory_length, max_tree_depth, find_heuristic_step_size, forward_mode_differentiation, model_args, model_kwargs, rng_key)
    209         """
    210         step_size = lax.convert_element_type(step_size, jnp.result_type(float))
--> 211         trajectory_length = lax.convert_element_type(trajectory_length, jnp.result_type(float))
    212         nonlocal wa_update, max_treedepth, vv_update, wa_steps, forward_mode_ad
    213         forward_mode_ad = forward_mode_differentiation

~/miniconda3/envs/pymc3v4/lib/python3.9/site-packages/jax/_src/lax/lax.py in convert_element_type(operand, new_dtype)
    427   if hasattr(operand, '__jax_array__'):
    428     operand = operand.__jax_array__()
--> 429   return _convert_element_type(operand, new_dtype, weak_type=False)
    430 
    431 def _convert_element_type(operand: Array, new_dtype: Optional[DType] = None,

~/miniconda3/envs/pymc3v4/lib/python3.9/site-packages/jax/_src/lax/lax.py in _convert_element_type(operand, new_dtype, weak_type)
    456     return operand
    457   else:
--> 458     return convert_element_type_p.bind(operand, new_dtype=new_dtype,
    459                                        weak_type=new_weak_type)
    460 

~/miniconda3/envs/pymc3v4/lib/python3.9/site-packages/jax/core.py in bind(self, *args, **params)
    256             all(isinstance(arg, Tracer) or valid_jaxtype(arg) for arg in args)), args
    257     top_trace = find_top_trace(args)
--> 258     tracers = map(top_trace.full_raise, args)
    259     out = top_trace.process_primitive(self, tracers, params)
    260     return map(full_lower, out) if self.multiple_results else full_lower(out)

~/miniconda3/envs/pymc3v4/lib/python3.9/site-packages/jax/_src/util.py in safe_map(f, *args)
     38   for arg in args[1:]:
     39     assert len(arg) == n, 'length mismatch: {}'.format(list(map(len, args)))
---> 40   return list(map(f, *args))
     41 
     42 def unzip2(xys):

~/miniconda3/envs/pymc3v4/lib/python3.9/site-packages/jax/core.py in full_raise(self, val)
    363   def full_raise(self, val) -> 'Tracer':
    364     if not isinstance(val, Tracer):
--> 365       return self.pure(val)
    366     val._assert_live()
    367     level = self.level

~/miniconda3/envs/pymc3v4/lib/python3.9/site-packages/jax/interpreters/partial_eval.py in new_const(self, val)
   1010 
   1011   def new_const(self, val):
-> 1012     aval = raise_to_shaped(get_aval(val), weak_type=dtypes.is_weakly_typed(val))
   1013     tracer = DynamicJaxprTracer(self, aval, source_info_util.current())
   1014     self.frame.tracers.append(tracer)

~/miniconda3/envs/pymc3v4/lib/python3.9/site-packages/jax/core.py in get_aval(x)
    925     return x.aval
    926   else:
--> 927     return concrete_aval(x)
    928 
    929 

~/miniconda3/envs/pymc3v4/lib/python3.9/site-packages/jax/core.py in concrete_aval(x)
    917   if hasattr(x, '__jax_array__'):
    918     return concrete_aval(x.__jax_array__())
--> 919   raise TypeError(f"Value {repr(x)} with type {type(x)} is not a valid JAX "
    920                    "type")
    921 

TypeError: Value None with type <class 'NoneType'> is not a valid JAX type
Apply node that caused the error: NumPyroNUTS(TensorConstant{[[[ 0.963 .. -1.028]]]}, TensorConstant{[[[-0.128 .. -0.927]]]}, TensorConstant{(4,) of -0..3503390389})
Toposort index: 0
Inputs types: [TensorType(float64, 3D), TensorType(float64, 3D), TensorType(float64, vector)]
Inputs shapes: []
Inputs strides: []
Inputs values: []
Outputs clients: [['output'], ['output'], [Elemwise{Exp}[(0, 0)](NumPyroNUTS.2)], ['output']]

Backtrace when the node is created (use Aesara flag traceback__limit=N to make it longer):
  File "/Users/twiecki/miniconda3/envs/pymc3v4/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3165, in run_cell_async
    has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
  File "/Users/twiecki/miniconda3/envs/pymc3v4/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3357, in run_ast_nodes
    if (await self.run_code(code, result,  async_=asy)):
  File "/Users/twiecki/miniconda3/envs/pymc3v4/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3437, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-10-a13f8033a7ad>", line 11, in <module>
    trace = pm.sampling_jax.sample_numpyro_nuts() #N_SAMPLE, chains=4, cores=1, init="advi+adapt_diag")
  File "/Users/twiecki/projects/pymc/pymc3/sampling_jax.py", line 181, in sample_numpyro_nuts
    numpyro_samples = NumPyroNUTS(
  File "/Users/twiecki/projects/aesara/aesara/graph/op.py", line 270, in __call__
    node = self.make_node(*inputs, **kwargs)
  File "/Users/twiecki/projects/pymc/pymc3/sampling_jax.py", line 65, in make_node
    outputs = [
  File "/Users/twiecki/projects/pymc/pymc3/sampling_jax.py", line 66, in <listcomp>
    TensorType(v.dtype, self.samples_bcast + list(v.broadcastable))() for v in inputs

HINT: Use the Aesara flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.

brandonwillard (Contributor Author) commented May 5, 2021

@twiecki, are you running the code on this branch? It looks like you're using Theano-PyMC (e.g. you're using the alias tt), but the traceback says aesara?

I just ran this locally and it worked, although the Deterministics don't end up in the trace—but that's something we can add later:

import numpy as np

import arviz as az
import matplotlib

import pymc3 as pm

import aesara.tensor as at

from pymc3.sampling_jax import sample_numpyro_nuts


n = 250
k_true = 5
d = 9
err_sd = 2
N_SAMPLE = 350
M = np.random.binomial(1, 0.25, size=(k_true, n))
Q = np.hstack(
    [np.random.exponential(2 * k_true - k, size=(d, 1)) for k in range(k_true)]
) * np.random.binomial(1, 0.75, size=(d, k_true))
Y = np.round(1000 * np.dot(Q, M) + np.random.normal(size=(d, n)) * err_sd) / 1000

k = 2


with pm.Model() as PPCA:
    W = pm.Normal("W", size=(d, k))
    F = pm.Normal("F", size=(k, n))
    psi = pm.HalfNormal("psi", 1.0)
    X = pm.Normal("X", mu=at.dot(W, F), sigma=psi, observed=Y)
    W_plot = pm.Deterministic("W_plot", W[1:3, 0])
    F_plot = pm.Deterministic("F_plot", F[0, 1:3])

    trace = sample_numpyro_nuts()

trace.posterior['W_plot'] = trace.posterior.W[:, :, 1:3, 0]
trace.posterior['F_plot'] = trace.posterior.F[:, :, 0, 1:3]

az.plot_trace(trace, ("W_plot", "F_plot", "psi"));

[image: numpyro trace example plot]

twiecki (Member) commented May 5, 2021

This is aesara master and this branch, odd.

twiecki (Member) commented May 5, 2021

Which JAX and numpyro version are you running?

[screenshot of installed package versions]

brandonwillard (Contributor Author) commented May 5, 2021

Looks like I haven't rebased my local version of this branch since 94213ca, so your issue might've been caused by something introduced at or after that commit if you've rebased it locally.

I'll try those exact versions of jax and numpyro next.

brandonwillard (Contributor Author):

I just got a pip error when trying to install those exact versions together:

$ pip install jax==0.2.12 numpyro==0.6.0
...
The conflict is caused by:
    The user requested jax==0.2.12
    numpyro 0.6.0 depends on jax==0.2.10

brandonwillard (Contributor Author) commented May 5, 2021

It appears to work with numpyro==0.6.0 and jax==0.2.10, though.
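For reproducibility, the compatible pair reported in this thread can be pinned in a requirements fragment; a sketch, where only these two pins are taken from the thread:

```text
numpyro==0.6.0
jax==0.2.10
```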

twiecki (Member) commented May 5, 2021

I get the same error with jax==0.2.10.

brandonwillard (Contributor Author):

Oh, wait, if you're on the master branch of Aesara, that could be it.

twiecki (Member) commented May 5, 2021

I guess I should use the version that's pinned.

brandonwillard (Contributor Author):

I just tried it with Aesara master and that worked, too. Have you tried the code I posted above? Also, are you running this on Windows (e.g. in case it's somehow related to #4652)?

twiecki (Member) commented May 5, 2021

Yes, ran the code you posted above.

This is on OSX.

brandonwillard (Contributor Author):

Is your local branch in sync with this remote?

twiecki (Member) commented May 5, 2021

Yep, I checked it out today.

$ git log
commit 23f7a7b74b2426f65eee23a7007427a9bb383541 (HEAD -> create-numpyro-op, origin/pr/4646)
Author: Brandon T. Willard <brandonwillard@users.noreply.github.com>
Date:   Thu Apr 15 15:12:12 2021 -0500

    Create a NumPyro sampler Op for better JAX backend integration

commit 45cb4ebf36500e502481bdced6980dd9e630acca
Author: Ricardo <ricardo.vieira1994@gmail.com>
Date:   Fri Apr 16 13:00:40 2021 +0200

    Add auto_deterministics list to Model

    Ensures that when missing variables are present in the model, the automatic deterministic (x_observed + x_missing) only appears in predictive sampling and not normal sampling.
    Fixes `x` missing from prior_predictive when missing values were present (only `x_missing` was present)

brandonwillard (Contributor Author):

Since it's complaining about a None, and it appears to be happening when JAX traces _sample, you can try removing the keyword argument with None here.
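As an editorial sketch of the failure mode being described here (this is not the PR's code; `traced_fn` and `extra` are hypothetical names): JAX raises this kind of TypeError when a plain Python None reaches a traced numeric operation, which is exactly what a None-valued keyword argument can do once a function like `_sample` is traced.

```python
import jax
import jax.numpy as jnp

def traced_fn(x, extra=None):
    # `extra` stands in for a None-valued keyword argument that ends up
    # inside the traced computation
    return jnp.exp(x) + jnp.exp(extra)

try:
    jax.jit(traced_fn)(1.0)
    err = None
except TypeError as e:
    # Message resembles "Value None with type <class 'NoneType'> is not a
    # valid JAX type" (exact wording varies by JAX version)
    err = str(e)

print(err)
```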

twiecki (Member) commented May 5, 2021

Did not help.

twiecki (Member) commented May 5, 2021

Running this same code on v4 gives:

---------------------------------------------------------------------------
MissingInputError                         Traceback (most recent call last)
<ipython-input-1-c7429fc90d9e> in <module>
     33     #F_plot = pm.Deterministic("F_plot", F[0, 1:3])
     34 
---> 35     trace = sample_numpyro_nuts()
     36 
     37 trace.posterior['W_plot'] = trace.posterior.W[:, :, 1:3, 0]

~/projects/pymc/pymc3/sampling_jax.py in sample_numpyro_nuts(draws, tune, chains, target_accept, random_seed, model, progress_bar, keep_untransformed)
    134     seed = jax.random.PRNGKey(random_seed)
    135 
--> 136     fgraph = aesara.graph.fg.FunctionGraph(model.free_RVs, [model.logpt])
    137     fns = jax_funcify(fgraph)
    138     logp_fn_jax = fns[0]

~/projects/aesara/aesara/graph/fg.py in __init__(self, inputs, outputs, features, clone, update_mapping, memo, copy_inputs, copy_orphans)
    166 
    167         for output in outputs:
--> 168             self.import_var(output, reason="init")
    169         for i, output in enumerate(outputs):
    170             self.clients[output].append(("output", i))

~/projects/aesara/aesara/graph/fg.py in import_var(self, var, reason, import_missing)
    335         # Imports the owners of the variables
    336         if var.owner and var.owner not in self.apply_nodes:
--> 337             self.import_node(var.owner, reason=reason, import_missing=import_missing)
    338         elif (
    339             var.owner is None

~/projects/aesara/aesara/graph/fg.py in import_node(self, apply_node, check, reason, import_missing)
    400                                 "for more information on this error."
    401                             )
--> 402                             raise MissingInputError(error_msg, variable=var)
    403 
    404         for node in new_nodes:

MissingInputError: Input 0 (psi_log__) of the graph (indices start from 0), used to compute Elemwise{exp,no_inplace}(psi_log__), was not provided and not given a value. Use the Aesara flag exception_verbosity='high', for more information on this error.
 
Backtrace when that variable is created:

  File "/Users/twiecki/miniconda3/envs/pymc3v4/lib/python3.9/site-packages/IPython/core/async_helpers.py", line 68, in _pseudo_sync_runner
    coro.send(None)
  File "/Users/twiecki/miniconda3/envs/pymc3v4/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3165, in run_cell_async
    has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
  File "/Users/twiecki/miniconda3/envs/pymc3v4/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3357, in run_ast_nodes
    if (await self.run_code(code, result,  async_=asy)):
  File "/Users/twiecki/miniconda3/envs/pymc3v4/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3437, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-1-c7429fc90d9e>", line 30, in <module>
    psi = pm.HalfNormal("psi", 1.0)
  File "/Users/twiecki/projects/pymc/pymc3/distributions/distribution.py", line 303, in __new__
    rv_registered = model.register_rv(
  File "/Users/twiecki/projects/pymc/pymc3/model.py", line 1114, in register_rv
    self.create_value_var(rv_var, transform)
  File "/Users/twiecki/projects/pymc/pymc3/model.py", line 1258, in create_value_var
    value_var = rv_var.type()

brandonwillard (Contributor Author):

Does the test in this branch pass locally?

brandonwillard force-pushed the create-numpyro-op branch 2 times, most recently from 77736e1 to d950485 on May 5, 2021 at 20:30
brandonwillard (Contributor Author) commented May 5, 2021

I added a MacOS test for that model and it appears to have passed in CI.

twiecki (Member) commented May 6, 2021

Alright, after updating from this PR it works now! Did you change anything other than adding the OSX test? It's hard to see what's changing here when you rewrite commits.

(resolved review thread on pymc3/sampling_jax.py, now outdated)
ricardoV94 (Member) commented May 6, 2021

Failing test seems to be the one reported here: #4661

brandonwillard (Contributor Author):

> It's hard to see what's changing here if you rewrite commits.

From the GitHub UI, if you click on the force-pushed links, you can see exactly what changed (e.g. here's the last one).

Locally, a git diff against the remote provides exactly the same thing.
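The local workflow described above can be sketched in a throwaway repository (the branch and file names here are illustrative; against a real PR branch one would diff against e.g. origin/pr/4646 instead of a local base branch):

```shell
set -e
tmp=$(mktemp -d) && cd "$tmp"
git init -q
git -c user.email=ci@example.com -c user.name=ci commit -q --allow-empty -m base
base=$(git symbolic-ref --short HEAD)   # default branch name varies by git version
git checkout -q -b rewritten-branch     # stands in for the force-pushed PR branch
echo "new line" > sampling_jax.py
git add sampling_jax.py
git -c user.email=ci@example.com -c user.name=ci commit -qm "force-pushed change"
# Show exactly what the branch changes relative to the base it diverged from
git diff --stat "$base"...rewritten-branch
```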

brandonwillard (Contributor Author):

> Failing test seems to be the one reported here: #4661

Yes, it's a flaky test; I believe I've tried to deal with that one before.

twiecki previously approved these changes May 6, 2021
twiecki (Member) commented May 12, 2021

@brandonwillard Can we merge this?

brandonwillard (Contributor Author) commented May 12, 2021

@twiecki why is "Rebase and merge" disabled in this repo?

twiecki previously approved these changes May 12, 2021
twiecki (Member) commented May 12, 2021

@brandonwillard Because of conflicts, which I thought I had resolved.

brandonwillard (Contributor Author):

> @brandonwillard Because of conflicts, which I thought I had resolved.

It wasn't saying there was a conflict on this page, but I rebased the branch itself and that seems to have cleared it up.

twiecki merged commit 04e1271 into pymc-devs:v4 on May 12, 2021

Successfully merging this pull request may close these issues.

Shared variable issues when using NumPyro JAX sampler