Error about NumPyro
ykawashima edited this page Nov 5, 2021 · 2 revisions
The following error occurs when you try to perform an MCMC run with NumPyro. It has been observed at least with the combination of JAX 0.2.16, jaxlib 0.1.68+cuda110, and NumPyro 0.6.0.
Traceback (most recent call last):
File "mcmc.py", line 162, in <module>
mcmc.run(rng_key_, y1=nflux)
File "/home/kawashimayi/anaconda3/lib/python3.8/site-packages/numpyro/infer/mcmc.py", line 498, in run
states_flat, last_state = partial_map_fn(map_args)
File "/home/kawashimayi/anaconda3/lib/python3.8/site-packages/numpyro/infer/mcmc.py", line 333, in _single_chain_mcmc
init_state = self.sampler.init(rng_key, self.num_warmup, init_params,
File "/home/kawashimayi/anaconda3/lib/python3.8/site-packages/numpyro/infer/hmc.py", line 505, in init
init_state = hmc_init_fn(init_params, rng_key)
File "/home/kawashimayi/anaconda3/lib/python3.8/site-packages/numpyro/infer/hmc.py", line 488, in <lambda>
hmc_init_fn = lambda init_params, rng_key: self._init_fn( # noqa: E731
File "/home/kawashimayi/anaconda3/lib/python3.8/site-packages/numpyro/infer/hmc.py", line 211, in init_kernel
trajectory_length = lax.convert_element_type(trajectory_length, jnp.result_type(float))
File "/home/kawashimayi/anaconda3/lib/python3.8/site-packages/jax/_src/lax/lax.py", line 425, in convert_element_type
return _convert_element_type(operand, new_dtype, weak_type=False)
File "/home/kawashimayi/anaconda3/lib/python3.8/site-packages/jax/_src/lax/lax.py", line 454, in _convert_element_type
return convert_element_type_p.bind(operand, new_dtype=new_dtype,
File "/home/kawashimayi/anaconda3/lib/python3.8/site-packages/jax/core.py", line 264, in bind
out = top_trace.process_primitive(self, tracers, params)
File "/home/kawashimayi/anaconda3/lib/python3.8/site-packages/jax/core.py", line 603, in process_primitive
return primitive.impl(*tracers, **params)
File "/home/kawashimayi/anaconda3/lib/python3.8/site-packages/jax/interpreters/xla.py", line 248, in apply_primitive
compiled_fun = xla_primitive_callable(prim, *unsafe_map(arg_spec, args), **params)
File "/home/kawashimayi/anaconda3/lib/python3.8/site-packages/jax/interpreters/xla.py", line 240, in arg_spec
aval = abstractify(x)
File "/home/kawashimayi/anaconda3/lib/python3.8/site-packages/jax/interpreters/xla.py", line 186, in abstractify
raise TypeError(f"Argument '{x}' of type '{type(x)}' is not a valid JAX type")
TypeError: Argument 'None' of type '<class 'NoneType'>' is not a valid JAX type
Solution: Update NumPyro to 0.7.0 (see also this website).
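To confirm whether an environment is affected, it can help to check the installed NumPyro version before starting the MCMC run. The following is a minimal sketch using only the Python standard library. The helper names (`parse_version`, `needs_upgrade`, `installed_numpyro_version`) are hypothetical and are not part of NumPyro, and the check assumes that any version from 0.7.0 onward contains the fix, which this page does not state explicitly.

```python
from importlib import metadata

# This page reports that updating NumPyro to 0.7.0 resolves the error.
# Assumption: every later version is also unaffected.
MIN_FIXED = (0, 7, 0)


def parse_version(text):
    """Convert '0.6.0' or '0.1.68+cuda110' to a 3-tuple of ints.

    Local suffixes after '+' and non-numeric tails such as 'rc1' are ignored.
    """
    core = text.split("+")[0]
    parts = []
    for piece in core.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
    # Pad so that '0.7' compares equal to '0.7.0'.
    while len(parts) < 3:
        parts.append(0)
    return tuple(parts[:3])


def needs_upgrade(installed, minimum=MIN_FIXED):
    """Return True if the installed version is older than the fixed one."""
    return parse_version(installed) < minimum


def installed_numpyro_version():
    """Return the installed NumPyro version string, or None if not installed."""
    try:
        return metadata.version("numpyro")
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    version = installed_numpyro_version()
    if version is None:
        print("NumPyro is not installed in this environment.")
    elif needs_upgrade(version):
        print(f"NumPyro {version} may hit this error; consider upgrading.")
    else:
        print(f"NumPyro {version} is at or above 0.7.0.")
```

If the check reports an older version, upgrading with `pip install --upgrade numpyro` is one way to apply the solution above; a conda-based environment may need the equivalent conda command instead.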