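Passing initial values for the discrete (Gibbs) sites through `init_params` to `DiscreteHMCGibbs` fails with:

```
TypeError: grad requires real- or complex-valued inputs (input dtype that is a sub-dtype of np.inexact), but got int32. If you want to use Boolean- or integer-valued inputs, use vjp or set allow_int to True.
```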
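I think the issue stems from the `init` method of `HMCGibbs`, which `DiscreteHMCGibbs` inherits. As shown below, `gibbs_sites` is filled only from the prototype trace, and `init_params` is forwarded unchanged to the inner HMC kernel. The method doesn't account for the case where `init_params` also contains values for the Gibbs sites.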
```python
def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs):
    model_kwargs = {} if model_kwargs is None else model_kwargs.copy()
    if self._prototype_trace is None:
        rng_key, key_u = random.split(rng_key)
        # We use init strategy to get around ImproperUniform which does not have
        # sample method.
        self._prototype_trace = trace(
            substitute(seed(self.model, key_u), substitute_fn=init_to_sample)
        ).get_trace(*model_args, **model_kwargs)

    rng_key, key_z = random.split(rng_key)
    gibbs_sites = {
        name: site["value"]
        for name, site in self._prototype_trace.items()
        if name in self._gibbs_sites
    }
    model_kwargs["_gibbs_sites"] = gibbs_sites
    hmc_state = self.inner_kernel.init(
        key_z, num_warmup, init_params, model_args, model_kwargs
    )

    z = {**gibbs_sites, **hmc_state.z}

    return device_put(HMCGibbsState(z, hmc_state, rng_key))
```
It should be fairly easy to fix this: for every key in `init_params` that matches a Gibbs site in the prototype trace, use the supplied value in `gibbs_sites` instead of the prototype value. Then pop those keys from the `init_params` dict and proceed as normal, so the inner HMC kernel only sees the continuous sites. Thoughts?
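A minimal sketch of the proposed fix, written in plain Python so it runs without NumPyro. `split_init_params` is a hypothetical helper, not part of NumPyro's API. It takes the user's `init_params` and the `gibbs_sites` dict built from the prototype trace, overrides the Gibbs-site values the user supplied, and returns what is left for the inner HMC kernel. It assumes `init_params` is either `None` or a dict keyed by site name, and that `None` tells the inner kernel to use its default initialization.

```python
from typing import Any, Dict, Optional, Tuple


def split_init_params(
    init_params: Optional[Dict[str, Any]],
    gibbs_sites: Dict[str, Any],
) -> Tuple[Dict[str, Any], Optional[Dict[str, Any]]]:
    """Hypothetical helper: move user-supplied values for Gibbs sites out of
    `init_params` and into `gibbs_sites`.

    Returns (updated_gibbs_sites, remaining_hmc_init_params).
    """
    # Copy so neither the caller's dicts nor the prototype values are mutated.
    gibbs_sites = dict(gibbs_sites)
    if init_params is None:
        return gibbs_sites, None

    hmc_params = dict(init_params)
    for name in list(hmc_params):
        if name in gibbs_sites:
            # Use the user's value for this discrete site and drop it from the
            # HMC params, so the inner kernel never takes grad w.r.t. it.
            gibbs_sites[name] = hmc_params.pop(name)

    # Assumption: if only Gibbs sites were given, return None so the inner
    # kernel falls back to its default initialization for continuous sites.
    return gibbs_sites, hmc_params or None


if __name__ == "__main__":
    # Toy values standing in for the prototype trace and the user's input.
    prototype_gibbs = {"z": 0}
    user_init = {"z": 2, "mu": 0.5}
    gibbs, hmc = split_init_params(user_init, prototype_gibbs)
    print(gibbs)  # {'z': 2}
    print(hmc)  # {'mu': 0.5}
```

Inside `init`, the call would go right before `model_kwargs["_gibbs_sites"] = gibbs_sites`, as `gibbs_sites, init_params = split_init_params(init_params, gibbs_sites)`. The existing code would then pass the trimmed `init_params` to `self.inner_kernel.init` unchanged.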