Skip to content

Commit

Permalink
fix init_params bug in hmcgibbs (#1673)
Browse files Browse the repository at this point in the history
* fix init_params bug in hmcgibbs

* improved concision

* only set init_params to gibbs if param is gibbs
  • Loading branch information
amifalk authored Nov 10, 2023
1 parent aeeeb72 commit 24c21b8
Showing 1 changed file with 10 additions and 5 deletions.
15 changes: 10 additions & 5 deletions numpyro/infer/hmc_gibbs.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,11 +132,16 @@ def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs):
).get_trace(*model_args, **model_kwargs)

rng_key, key_z = random.split(rng_key)
gibbs_sites = {
name: site["value"]
for name, site in self._prototype_trace.items()
if name in self._gibbs_sites
}

gibbs_sites = {}

for name, site in self._prototype_trace.items():
if init_params and (name in init_params) and (name in self._gibbs_sites):
gibbs_sites[name] = init_params.pop(name)

elif name in self._gibbs_sites:
gibbs_sites[name] = site["value"]

model_kwargs["_gibbs_sites"] = gibbs_sites
hmc_state = self.inner_kernel.init(
key_z, num_warmup, init_params, model_args, model_kwargs
Expand Down

0 comments on commit 24c21b8

Please sign in to comment.