
`exclude_deterministic` argument in `Predictive` is not applied for models with discrete latents #1861

Open
fehiepsi opened this issue Sep 21, 2024 · 0 comments
Labels: bug (Something isn't working), good first issue (Good for newcomers)

@fehiepsi (Member):

See the forum thread: https://forum.pyro.ai/t/enumerate-support-for-batch-dimensions-of-custom-distribution/7656/4

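Currently `exclude_deterministic` is only handled in the `else` branch of the prediction logic, so it has no effect when `infer_discrete=True`. We need to move the `exclude_deterministic` logic so that it also applies to the `infer_discrete` branch: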

```python
if infer_discrete:
    from numpyro.contrib.funsor import config_enumerate
    from numpyro.contrib.funsor.discrete import _sample_posterior

    # exclude_deterministic is never consulted in this branch
    model_trace = prototype_trace
    temperature = 1
    pred_samples = _sample_posterior(
        config_enumerate(condition(model, samples)),
        first_available_dim,
        temperature,
        rng_key,
        *model_args,
        **model_kwargs,
    )
else:

    def _samples_wo_deterministic(msg):
        # Return None for deterministic sites so they are recomputed by the model
        return (
            samples.get(msg["name"]) if msg["type"] != "deterministic" else None
        )

    substituted_model = (
        substitute(masked_model, substitute_fn=_samples_wo_deterministic)
        if exclude_deterministic
        else substitute(masked_model, samples)
    )
```
