From 24c21b86bae696edd451ec43dcd890fd8747e3c1 Mon Sep 17 00:00:00 2001 From: Ami Falk <96739930+amifalk@users.noreply.github.com> Date: Fri, 10 Nov 2023 12:02:32 -0500 Subject: [PATCH] fix init_params bug in hmcgibbs (#1673) * fix init_params bug in hmcgibbs * improved concision * only set init_params to gibbs if param is gibbs --- numpyro/infer/hmc_gibbs.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/numpyro/infer/hmc_gibbs.py b/numpyro/infer/hmc_gibbs.py index 3df72eb98..2e9498e99 100644 --- a/numpyro/infer/hmc_gibbs.py +++ b/numpyro/infer/hmc_gibbs.py @@ -132,11 +132,16 @@ def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs): ).get_trace(*model_args, **model_kwargs) rng_key, key_z = random.split(rng_key) - gibbs_sites = { - name: site["value"] - for name, site in self._prototype_trace.items() - if name in self._gibbs_sites - } + + gibbs_sites = {} + + for name, site in self._prototype_trace.items(): + if init_params and (name in init_params) and (name in self._gibbs_sites): + gibbs_sites[name] = init_params.pop(name) + + elif name in self._gibbs_sites: + gibbs_sites[name] = site["value"] + model_kwargs["_gibbs_sites"] = gibbs_sites hmc_state = self.inner_kernel.init( key_z, num_warmup, init_params, model_args, model_kwargs