Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

log_likelihood fails with random_flax_module (and flax_module) #1991

Open
kylejcaron opened this issue Feb 28, 2025 · 0 comments · May be fixed by #1992
Open

log_likelihood fails with random_flax_module (and flax_module) #1991

kylejcaron opened this issue Feb 28, 2025 · 0 comments · May be fixed by #1992
Labels
bug Something isn't working

Comments

@kylejcaron
Copy link
Contributor

Bug Description

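Calling `numpyro.infer.util.log_likelihood` on a fitted model that uses `random_flax_module` (or `flax_module`) raises `ValueError: First argument passed to an init function should be a ``jax.PRNGKey`` or a dictionary mapping strings to ``jax.PRNGKey``.` The error comes from `nn_module.init(rngs, ...)` inside `flax_module`. That call apparently does not receive a valid PRNG key when `log_likelihood` re-traces the model (via `substitute` + `trace`) for each posterior sample.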

Versions

  • jax: '0.5.1',
  • numpyro: '0.17.0'
  • flax: '0.10.4'

Steps to Reproduce

import numpy as np
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS
from jax import random

import flax.linen as nn
from numpyro.contrib.module import random_flax_module

# Simulate a 1-feature linear-regression dataset: y ~ Normal(1 + 0.5 * x, 0.5).
N = 1000
rng = np.random.default_rng(99)

X = rng.normal(loc=0, scale=1, size=(N, 1))
mu = 1 + X @ np.array([0.5])
y = rng.normal(loc=mu, scale=0.5)

# Minimal Flax model: one dense layer (with bias) producing a single output.
class Linear(nn.Module):
    """Single-output linear layer.

    The layer is registered under the submodule name ``'Dense'``, so its
    parameters are addressed as ``Dense.bias`` / ``Dense.kernel`` in the
    ``prior`` dict passed to ``random_flax_module``.
    """

    @nn.compact
    def __call__(self, x):
        dense = nn.Dense(features=1, use_bias=True, name='Dense')
        return dense(x)

def model(X, y=None):
    """Regression model whose linear layer has sampled (random) weights.

    Samples an observation-noise scale ``sigma``. Then lifts the Flax
    ``Linear`` module into a Bayesian module via ``random_flax_module``,
    with independent priors on its bias and kernel. Finally scores ``y``
    under ``Normal(mlp(X), sigma)`` inside a ``data`` plate over the rows
    of ``X``.

    Args:
        X: input matrix; ``X.shape[0]`` sets the plate size and
            ``X.shape[1]`` is passed as the module's ``input_shape``.
        y: observed targets, or ``None`` to leave the ``y`` site unobserved.
    """
    sigma = numpyro.sample("sigma", dist.HalfNormal(0.1))

    # Keys refer to the params of the submodule named 'Dense' inside Linear.
    layer_priors = {
        "Dense.bias": dist.Normal(0, 2.5),
        "Dense.kernel": dist.Normal(0, 1),
    }
    mlp = random_flax_module(
        "mlp", Linear(), prior=layer_priors, input_shape=(X.shape[1],)
    )

    n_obs = X.shape[0]
    with numpyro.plate("data", n_obs):
        # Dense(1) yields a trailing output axis of size 1; drop it to match the plate.
        mu = numpyro.deterministic("mu", mlp(X).squeeze(-1))
        numpyro.sample("y", dist.Normal(mu, sigma), obs=y)

# Fit the model with NUTS: 2 chains, each with 1000 warmup + 1000 posterior draws.
kernel = NUTS(model, target_accept_prob=0.95)
mcmc = MCMC(
    kernel,
    num_warmup=1000,
    num_samples=1000,
    num_chains=2,
)
mcmc.run(random.PRNGKey(0), X=X, y=y)

# Reproduce the bug: re-evaluating the observed-site log likelihood at the
# posterior draws raises ValueError from nn_module.init inside flax_module.
# NOTE(review): wrapping `model` in numpyro.handlers.seed might sidestep this
# as a workaround — unverified.
numpyro.infer.util.log_likelihood(model, mcmc.get_samples(), X=X, y=y)

Here's the traceback

Click to expand Traceback
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[431], line 46
     42 mcmc = MCMC(kernel, num_warmup=1000, num_samples=1000, num_chains=2)
     43 mcmc.run(random.PRNGKey(0), X=X, y=y)
---> 46 numpyro.infer.util.log_likelihood(model, mcmc.get_samples(), X=X, y=y)

File ~/Desktop/numpyro_bug_flax/.venv/lib/python3.11/site-packages/numpyro/infer/util.py:1147, in log_likelihood(model, posterior_samples, parallel, batch_ndims, *args, **kwargs)
   1145 batch_size = int(np.prod(batch_shape))
   1146 chunk_size = batch_size if parallel else 1
-> 1147 return soft_vmap(single_loglik, posterior_samples, len(batch_shape), chunk_size)

File ~/Desktop/numpyro_bug_flax/.venv/lib/python3.11/site-packages/numpyro/util.py:453, in soft_vmap(fn, xs, batch_ndims, chunk_size)
    447     xs = jax.tree.map(
    448         lambda x: jnp.reshape(x, prepend_shape + (chunk_size,) + jnp.shape(x)[1:]),
    449         xs,
    450     )
    451     fn = vmap(fn)
--> 453 ys = lax.map(fn, xs) if num_chunks > 1 else fn(xs)
    454 map_ndims = int(num_chunks > 1) + int(chunk_size > 1)
    455 ys = jax.tree.map(
    456     lambda y: jnp.reshape(
    457         y, (int(np.prod(jnp.shape(y)[:map_ndims])),) + jnp.shape(y)[map_ndims:]
    458     )[:batch_size],
    459     ys,
    460 )

    [... skipping hidden 13 frame]

File ~/Desktop/numpyro_bug_flax/.venv/lib/python3.11/site-packages/numpyro/infer/util.py:1122, in log_likelihood.<locals>.single_loglik(samples)
   1118 def single_loglik(samples):
   1119     substituted_model = (
   1120         substitute(model, samples) if isinstance(samples, dict) else model
   1121     )
-> 1122     model_trace = trace(substituted_model).get_trace(*args, **kwargs)
   1123     return {
   1124         name: site["fn"].log_prob(site["value"])
   1125         for name, site in model_trace.items()
   1126         if site["type"] == "sample" and site["is_observed"]
   1127     }

File ~/Desktop/numpyro_bug_flax/.venv/lib/python3.11/site-packages/numpyro/handlers.py:191, in trace.get_trace(self, *args, **kwargs)
    183 def get_trace(self, *args, **kwargs) -> OrderedDict[str, Message]:
    184     """
    185     Run the wrapped callable and return the recorded trace.
    186 
   (...)    189     :return: `OrderedDict` containing the execution trace.
    190     """
--> 191     self(*args, **kwargs)
    192     return self.trace

File ~/Desktop/numpyro_bug_flax/.venv/lib/python3.11/site-packages/numpyro/primitives.py:121, in Messenger.__call__(self, *args, **kwargs)
    119     return self
    120 with self:
--> 121     return self.fn(*args, **kwargs)

File ~/Desktop/numpyro_bug_flax/.venv/lib/python3.11/site-packages/numpyro/primitives.py:121, in Messenger.__call__(self, *args, **kwargs)
    119     return self
    120 with self:
--> 121     return self.fn(*args, **kwargs)

Cell In[431], line 26, in model(X, y)
     24 def model(X, y=None):
     25     sigma = numpyro.sample("sigma", dist.HalfNormal(0.1))
---> 26     mlp = random_flax_module(
     27         "mlp", 
     28         Linear(), 
     29         prior={"Dense.bias": dist.Normal(0,2.5), "Dense.kernel": dist.Normal(0,1)},
     30         input_shape=(X.shape[1],))
     32     # params = numpyro.deterministic("params", mlp.args[0])
     34     with numpyro.plate("data", X.shape[0]):

File ~/Desktop/numpyro_bug_flax/.venv/lib/python3.11/site-packages/numpyro/contrib/module.py:371, in random_flax_module(name, nn_module, prior, input_shape, apply_rng, mutable, *args, **kwargs)
    252 def random_flax_module(
    253     name,
    254     nn_module,
   (...)    260     **kwargs,
    261 ):
    262     """
    263     A primitive to place a prior over the parameters of the Flax module `nn_module`.
    264 
   (...)    369         >>> assert np.sqrt(np.mean(np.square(y_test - y_pred))) < 1
    370     """
--> 371     nn = flax_module(
    372         name,
    373         nn_module,
    374         *args,
    375         input_shape=input_shape,
    376         apply_rng=apply_rng,
    377         mutable=mutable,
    378         **kwargs,
    379     )
    380     params = nn.args[0]
    381     new_params = deepcopy(params)

File ~/Desktop/numpyro_bug_flax/.venv/lib/python3.11/site-packages/numpyro/contrib/module.py:97, in flax_module(name, nn_module, input_shape, apply_rng, mutable, *args, **kwargs)
     94         rngs[kind] = subkey
     95 rngs["params"] = rng_key
---> 97 nn_vars = flax.core.unfreeze(nn_module.init(rngs, *args, **kwargs))
     98 if "params" not in nn_vars:
     99     raise ValueError(
    100         "Your nn_module does not have any parameter. Currently, it is not"
    101         " supported in NumPyro. Please make a github issue if you need"
    102         " that feature."
    103     )

    [... skipping hidden 4 frame]

File ~/Desktop/numpyro_bug_flax/.venv/lib/python3.11/site-packages/flax/core/scope.py:1107, in init.<locals>.wrapper(rngs, *args, **kwargs)
   1104 @functools.wraps(fn)
   1105 def wrapper(rngs, *args, **kwargs) -> tuple[Any, VariableDict]:
   1106   if not _is_valid_rng(rngs) and not _is_valid_rngs(rngs):
-> 1107     raise ValueError(
   1108       'First argument passed to an init function should be a '
   1109       '``jax.PRNGKey`` or a dictionary mapping strings to '
   1110       '``jax.PRNGKey``.'
   1111     )
   1112   if not isinstance(rngs, (dict, FrozenDict)):
   1113     rngs = {'params': rngs}

ValueError: First argument passed to an init function should be a ``jax.PRNGKey`` or a dictionary mapping strings to ``jax.PRNGKey``.

Expected Behavior

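`numpyro.infer.util.log_likelihood` should work on a model that uses `random_flax_module` (or `flax_module`). It should return the log likelihood of the observed site `y` for each posterior sample instead of raising during Flax module initialization.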

@kylejcaron kylejcaron added the bug Something isn't working label Feb 28, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bug Something isn't working
Projects
None yet
Development

Successfully merging a pull request may close this issue.

1 participant