diff --git a/numpyro/infer/hmc_gibbs.py b/numpyro/infer/hmc_gibbs.py index 3df72eb98..2e9498e99 100644 --- a/numpyro/infer/hmc_gibbs.py +++ b/numpyro/infer/hmc_gibbs.py @@ -132,11 +132,16 @@ def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs): ).get_trace(*model_args, **model_kwargs) rng_key, key_z = random.split(rng_key) - gibbs_sites = { - name: site["value"] - for name, site in self._prototype_trace.items() - if name in self._gibbs_sites - } + + gibbs_sites = {} + + for name, site in self._prototype_trace.items(): + if init_params and (name in init_params) and (name in self._gibbs_sites): + gibbs_sites[name] = init_params.pop(name) + + elif name in self._gibbs_sites: + gibbs_sites[name] = site["value"] + model_kwargs["_gibbs_sites"] = gibbs_sites hmc_state = self.inner_kernel.init( key_z, num_warmup, init_params, model_args, model_kwargs