[Bug] DiscreteHMCGibbs does not properly initialize discrete latents with init_params argument #1672

Closed
amifalk opened this issue Nov 9, 2023 · 2 comments
Labels
bug Something isn't working

Comments

@amifalk
Contributor

amifalk commented Nov 9, 2023

Minimal working example:

```python
import jax.numpy as jnp
from jax.random import PRNGKey

import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, DiscreteHMCGibbs

def model(data):
    with numpyro.plate('n_obs', 10):
        # Discrete latent: one category per observation
        mu = numpyro.sample('mu', dist.Categorical(jnp.array([.25, .25, .25, .25])))
        numpyro.sample('dat', dist.Normal(mu, 1), obs=data)

mcmc = MCMC(DiscreteHMCGibbs(NUTS(model), modified=True),
            num_warmup=1000,
            num_samples=100,
            num_chains=1)

data = dist.Normal(0, 1).sample(PRNGKey(0), (10,))
# Integer initial values for the discrete site 'mu' (a Gibbs site)
unconstrained_params = {'mu': jnp.array([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])}

mcmc.run(PRNGKey(1), data, init_params=unconstrained_params)
```

yields the following error:

```
TypeError: grad requires real- or complex-valued inputs (input dtype that is a sub-dtype of np.inexact), but got int32. If you want to use Boolean- or integer-valued inputs, use vjp or set allow_int to True.
```
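The error itself comes from JAX: `jax.grad` only differentiates with respect to floating-point (or complex) inputs. Presumably the integer `mu` values in `init_params` reach the inner NUTS kernel, which differentiates the potential energy with respect to every entry it is given. The standalone sketch below reproduces the same `TypeError` with a toy function; `sq_norm` is a stand-in chosen for illustration, not NumPyro code.

```python
import jax
import jax.numpy as jnp

def sq_norm(x):
    # Toy stand-in for a potential energy function.
    return jnp.sum(x ** 2)

int_init = jnp.array([0, 0, 1, 1], dtype=jnp.int32)  # like the discrete 'mu' values
float_init = int_init.astype(jnp.float32)

try:
    jax.grad(sq_norm)(int_init)
except TypeError as err:
    # Same message as in the traceback above.
    print("TypeError:", str(err).splitlines()[0])

# Floating-point inputs are fine; the gradient is 2 * x.
print(jax.grad(sq_norm)(float_init))  # -> [0. 0. 2. 2.]
```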

I think the issue stems from the `init` method of `HMCGibbs`, which `DiscreteHMCGibbs` inherits. As shown below, `init` builds `gibbs_sites` only from the prototype trace and passes `init_params` straight to the inner kernel, without checking whether some of its keys are Gibbs sites.

```python
    def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs):
        model_kwargs = {} if model_kwargs is None else model_kwargs.copy()
        if self._prototype_trace is None:
            rng_key, key_u = random.split(rng_key)
            # We use init strategy to get around ImproperUniform which does not have
            # sample method.
            self._prototype_trace = trace(
                substitute(seed(self.model, key_u), substitute_fn=init_to_sample)
            ).get_trace(*model_args, **model_kwargs)

        rng_key, key_z = random.split(rng_key)
        gibbs_sites = {
            name: site["value"]
            for name, site in self._prototype_trace.items()
            if name in self._gibbs_sites
        }
        model_kwargs["_gibbs_sites"] = gibbs_sites
        hmc_state = self.inner_kernel.init(
            key_z, num_warmup, init_params, model_args, model_kwargs
        )

        z = {**gibbs_sites, **hmc_state.z}

        return device_put(HMCGibbsState(z, hmc_state, rng_key))
```

It should be pretty easy to fix this: for each key in `init_params` that matches a Gibbs site in the prototype trace, overwrite the corresponding value in `gibbs_sites`, then pop that key from `init_params` and pass the remaining entries to the inner kernel as normal. Thoughts?
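A minimal sketch of the proposed fix, written as a standalone helper so it runs without NumPyro. `split_init_params` is a hypothetical name, not part of NumPyro's API. The idea is that `HMCGibbs.init` would call it right after building `gibbs_sites` from the prototype trace, store the returned Gibbs values in `model_kwargs["_gibbs_sites"]`, and pass only the remainder to `self.inner_kernel.init`. Passing an empty remainder as `None` (so the inner kernel falls back to its own init strategy) is an assumption, and `sigma` in the usage example is a hypothetical continuous site.

```python
from typing import Any, Dict, Optional, Tuple


def split_init_params(
    init_params: Optional[Dict[str, Any]],
    gibbs_sites: Dict[str, Any],
) -> Tuple[Dict[str, Any], Optional[Dict[str, Any]]]:
    """Hypothetical helper: route user-supplied values for Gibbs sites into
    gibbs_sites and return the rest for the inner HMC/NUTS kernel."""
    # Copy so neither the prototype values nor the caller's dict are mutated.
    gibbs_sites = dict(gibbs_sites)
    if init_params is None:
        return gibbs_sites, None
    remaining = dict(init_params)
    for name in list(gibbs_sites):
        if name in remaining:
            # The user's value overrides the one drawn for the prototype trace.
            gibbs_sites[name] = remaining.pop(name)
    # Assumption: an empty remainder becomes None so the inner kernel
    # initializes the continuous sites with its own strategy.
    return gibbs_sites, (remaining or None)


prototype = {"mu": [3, 1, 2, 0]}                # values drawn via init_to_sample
user_init = {"mu": [0, 0, 1, 1], "sigma": 1.0}  # 'sigma' is hypothetical
gibbs, rest = split_init_params(user_init, prototype)
print(gibbs)  # {'mu': [0, 0, 1, 1]}
print(rest)   # {'sigma': 1.0}
```

Copying both dicts keeps the prototype trace values reusable across calls and leaves the caller's `init_params` untouched.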

@fehiepsi
Member

fehiepsi commented Nov 9, 2023

Good catch! Yeah, your solution looks reasonable to me. Could you make a PR?

@fehiepsi fehiepsi added the bug Something isn't working label Nov 9, 2023
@amifalk
Contributor Author

amifalk commented Nov 9, 2023

Sure thing!
