We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
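Calling `numpyro.infer.util.log_likelihood` on posterior samples from a model that uses `random_flax_module` fails with `ValueError: First argument passed to an init function should be a jax.PRNGKey or a dictionary mapping strings to jax.PRNGKey.` The error is raised inside `numpyro/contrib/module.py` (`flax_module`), when it calls `nn_module.init(rngs, ...)` while `log_likelihood` re-traces the model.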
Versions
import numpy as np
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS
from jax import random
import flax.linen as nn
from numpyro.contrib.module import random_flax_module

# Simulate N rows with one feature: y = 1 + 0.5 * x + Normal(0, 0.5) noise.
# NOTE: the order of the rng.normal draws determines the data, so keep it fixed.
rng = np.random.default_rng(99)
N = 1000
X = rng.normal(0, 1, size=(N,1))
mu = 1 + X @ np.array([0.5])
y = rng.normal(mu, 0.5)

# Simple linear layer: one Dense unit with a bias, named 'Dense' so that the
# prior keys below ("Dense.bias", "Dense.kernel") can refer to its parameters.
class Linear(nn.Module):
    @nn.compact
    def __call__(self, x):
        return nn.Dense(1, use_bias=True, name='Dense')(x)

def model(X, y=None):
    """Bayesian linear regression whose Dense weights get priors via random_flax_module."""
    # NOTE(review): HalfNormal(0.1) puts little prior mass near the simulated
    # noise scale 0.5; this is unrelated to the reported bug.
    sigma = numpyro.sample("sigma", dist.HalfNormal(0.1))
    priors = {"Dense.bias": dist.Normal(0,2.5), "Dense.kernel": dist.Normal(0,1)}
    mlp = random_flax_module("mlp", Linear(), prior=priors, input_shape=(X.shape[1],))
    with numpyro.plate("data", X.shape[0]):
        # Dense(1) yields a trailing axis of size 1; squeeze it so mu has one value per row.
        mu = numpyro.deterministic("mu", mlp(X).squeeze(-1))
        y = numpyro.sample("y", dist.Normal(mu, sigma), obs=y)

# Fit model with NUTS: 2 chains, 1000 warmup + 1000 posterior draws each.
kernel = NUTS(model, target_accept_prob=0.95)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=1000, num_chains=2)
mcmc.run(random.PRNGKey(0), X=X, y=y)

# Compute the log likelihood of the observed y under each posterior draw.
# This call raises the ValueError shown in the traceback.
numpyro.infer.util.log_likelihood(model, mcmc.get_samples(), X=X, y=y)
Here's the traceback
--------------------------------------------------------------------------- ValueError Traceback (most recent call last) Cell In[431], line 46 42 mcmc = MCMC(kernel, num_warmup=1000, num_samples=1000, num_chains=2) 43 mcmc.run(random.PRNGKey(0), X=X, y=y) ---> 46 numpyro.infer.util.log_likelihood(model, mcmc.get_samples(), X=X, y=y) File ~/Desktop/numpyro_bug_flax/.venv/lib/python3.11/site-packages/numpyro/infer/util.py:1147, in log_likelihood(model, posterior_samples, parallel, batch_ndims, *args, **kwargs) 1145 batch_size = int(np.prod(batch_shape)) 1146 chunk_size = batch_size if parallel else 1 -> 1147 return soft_vmap(single_loglik, posterior_samples, len(batch_shape), chunk_size) File ~/Desktop/numpyro_bug_flax/.venv/lib/python3.11/site-packages/numpyro/util.py:453, in soft_vmap(fn, xs, batch_ndims, chunk_size) 447 xs = jax.tree.map( 448 lambda x: jnp.reshape(x, prepend_shape + (chunk_size,) + jnp.shape(x)[1:]), 449 xs, 450 ) 451 fn = vmap(fn) --> 453 ys = lax.map(fn, xs) if num_chunks > 1 else fn(xs) 454 map_ndims = int(num_chunks > 1) + int(chunk_size > 1) 455 ys = jax.tree.map( 456 lambda y: jnp.reshape( 457 y, (int(np.prod(jnp.shape(y)[:map_ndims])),) + jnp.shape(y)[map_ndims:] 458 )[:batch_size], 459 ys, 460 ) [... 
skipping hidden 13 frame] File ~/Desktop/numpyro_bug_flax/.venv/lib/python3.11/site-packages/numpyro/infer/util.py:1122, in log_likelihood.<locals>.single_loglik(samples) 1118 def single_loglik(samples): 1119 substituted_model = ( 1120 substitute(model, samples) if isinstance(samples, dict) else model 1121 ) -> 1122 model_trace = trace(substituted_model).get_trace(*args, **kwargs) 1123 return { 1124 name: site["fn"].log_prob(site["value"]) 1125 for name, site in model_trace.items() 1126 if site["type"] == "sample" and site["is_observed"] 1127 } File ~/Desktop/numpyro_bug_flax/.venv/lib/python3.11/site-packages/numpyro/handlers.py:191, in trace.get_trace(self, *args, **kwargs) 183 def get_trace(self, *args, **kwargs) -> OrderedDict[str, Message]: 184 """ 185 Run the wrapped callable and return the recorded trace. 186 (...) 189 :return: `OrderedDict` containing the execution trace. 190 """ --> 191 self(*args, **kwargs) 192 return self.trace File ~/Desktop/numpyro_bug_flax/.venv/lib/python3.11/site-packages/numpyro/primitives.py:121, in Messenger.__call__(self, *args, **kwargs) 119 return self 120 with self: --> 121 return self.fn(*args, **kwargs) File ~/Desktop/numpyro_bug_flax/.venv/lib/python3.11/site-packages/numpyro/primitives.py:121, in Messenger.__call__(self, *args, **kwargs) 119 return self 120 with self: --> 121 return self.fn(*args, **kwargs) Cell In[431], line 26, in model(X, y) 24 def model(X, y=None): 25 sigma = numpyro.sample("sigma", dist.HalfNormal(0.1)) ---> 26 mlp = random_flax_module( 27 "mlp", 28 Linear(), 29 prior={"Dense.bias": dist.Normal(0,2.5), "Dense.kernel": dist.Normal(0,1)}, 30 input_shape=(X.shape[1],)) 32 # params = numpyro.deterministic("params", mlp.args[0]) 34 with numpyro.plate("data", X.shape[0]): File ~/Desktop/numpyro_bug_flax/.venv/lib/python3.11/site-packages/numpyro/contrib/module.py:371, in random_flax_module(name, nn_module, prior, input_shape, apply_rng, mutable, *args, **kwargs) 252 def random_flax_module( 253 name, 254 
nn_module, (...) 260 **kwargs, 261 ): 262 """ 263 A primitive to place a prior over the parameters of the Flax module `nn_module`. 264 (...) 369 >>> assert np.sqrt(np.mean(np.square(y_test - y_pred))) < 1 370 """ --> 371 nn = flax_module( 372 name, 373 nn_module, 374 *args, 375 input_shape=input_shape, 376 apply_rng=apply_rng, 377 mutable=mutable, 378 **kwargs, 379 ) 380 params = nn.args[0] 381 new_params = deepcopy(params) File ~/Desktop/numpyro_bug_flax/.venv/lib/python3.11/site-packages/numpyro/contrib/module.py:97, in flax_module(name, nn_module, input_shape, apply_rng, mutable, *args, **kwargs) 94 rngs[kind] = subkey 95 rngs["params"] = rng_key ---> 97 nn_vars = flax.core.unfreeze(nn_module.init(rngs, *args, **kwargs)) 98 if "params" not in nn_vars: 99 raise ValueError( 100 "Your nn_module does not have any parameter. Currently, it is not" 101 " supported in NumPyro. Please make a github issue if you need" 102 " that feature." 103 ) [... skipping hidden 4 frame] File ~/Desktop/numpyro_bug_flax/.venv/lib/python3.11/site-packages/flax/core/scope.py:1107, in init.<locals>.wrapper(rngs, *args, **kwargs) 1104 @functools.wraps(fn) 1105 def wrapper(rngs, *args, **kwargs) -> tuple[Any, VariableDict]: 1106 if not _is_valid_rng(rngs) and not _is_valid_rngs(rngs): -> 1107 raise ValueError( 1108 'First argument passed to an init function should be a ' 1109 '``jax.PRNGKey`` or a dictionary mapping strings to ' 1110 '``jax.PRNGKey``.' 1111 ) 1112 if not isinstance(rngs, (dict, FrozenDict)): 1113 rngs = {'params': rngs} ValueError: First argument passed to an init function should be a ``jax.PRNGKey`` or a dictionary mapping strings to ``jax.PRNGKey``.
Expected the `numpyro.infer.util.log_likelihood` function to work on a model that uses `random_flax_module`.
numpyro.infer.util.log_likelihood
The text was updated successfully, but these errors were encountered:
contrib/module.py
Successfully merging a pull request may close this issue.
Bug Description
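Calling `numpyro.infer.util.log_likelihood` on posterior samples from a model that uses `random_flax_module` fails with `ValueError: First argument passed to an init function should be a jax.PRNGKey or a dictionary mapping strings to jax.PRNGKey.` The error is raised inside `numpyro/contrib/module.py` (`flax_module`), when it calls `nn_module.init(rngs, ...)` while `log_likelihood` re-traces the model.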
Versions
Steps to Reproduce
Here's the traceback
Click to expand Traceback
Expected Behavior
Expected the `numpyro.infer.util.log_likelihood` function to work on a model that uses `random_flax_module`.
The text was updated successfully, but these errors were encountered: