Skip to content

Commit

Permalink
Removed runtime errors from eqxi.while_loop.
Browse files Browse the repository at this point in the history
This is to fix a crash on TPUs, see #628.
  • Loading branch information
patrick-kidger committed Dec 30, 2023
1 parent 598c8aa commit f893428
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions equinox/internal/_loop/checkpointed.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,8 @@ def _stumm_walther_i(step, save_state):
"`equinox.internal.checkpointed_while_loop`. "
"Please raise an issue at https://github.com/patrick-kidger/equinox"
)
out = error_if(out, pred & (o == -1), msg)
if getattr(typing, "TESTING", False):
out = error_if(out, pred & (o == -1), msg)
out = nonbatchable(out)
return out

Expand Down Expand Up @@ -800,7 +801,8 @@ def _body_fun(carry):
"`equinox.internal.checkpointed_while_loop`. "
"Please raise an issue at https://github.com/patrick-kidger/equinox"
)
step_val = error_if(step_val, step_val >= step_grad_val, msg)
if getattr(typing, "TESTING", False):
step_val = error_if(step_val, step_val >= step_grad_val, msg)

#
# First either propagate our primal state forward, or make a U-turn if the
Expand Down

0 comments on commit f893428

Please sign in to comment.