diff --git a/numpyro/contrib/control_flow/scan.py b/numpyro/contrib/control_flow/scan.py index 9d6f4da1a..cc3a251f8 100644 --- a/numpyro/contrib/control_flow/scan.py +++ b/numpyro/contrib/control_flow/scan.py @@ -19,8 +19,11 @@ def _replay_wrapper(replay_trace, trace, i, length): def get_ith_value(site): value_shape = jnp.shape(site["value"]) site_len = value_shape[0] if value_shape else 0 - - if site["name"] not in trace or site_len != length or site["type"] not in ("sample", "deterministic"): + if ( + site["name"] not in trace + or site_len != length + or site["type"] not in ("sample", "deterministic") + ): return site site["value"] = site["value"][i]