Skip to content

Commit

Permalink
a
Browse files Browse the repository at this point in the history
  • Loading branch information
frans committed Dec 6, 2023
1 parent 42bb6f2 commit 70bc8f3
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions numpyro/contrib/control_flow/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,11 @@ def _replay_wrapper(replay_trace, trace, i, length):
def get_ith_value(site):
value_shape = jnp.shape(site["value"])
site_len = value_shape[0] if value_shape else 0

if site["name"] not in trace or site_len != length or site["type"] not in ("sample", "deterministic"):
if (
site["name"] not in trace
or site_len != length
or site["type"] not in ("sample", "deterministic")
):
return site

site["value"] = site["value"][i]
Expand Down

0 comments on commit 70bc8f3

Please sign in to comment.