Skip to content

Commit

Permalink
Removed error_if host callbacks to test on TPU
Browse files Browse the repository at this point in the history
  • Loading branch information
neel04 committed Dec 29, 2023
1 parent 8d3175a commit e0346e1
Showing 1 changed file with 2 additions and 13 deletions.
15 changes: 2 additions & 13 deletions equinox/internal/_loop/checkpointed.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@
from jaxtyping import Array, ArrayLike, Bool

from ..._ad import filter_closure_convert, filter_custom_vjp
from ..._errors import error_if
from ..._filters import combine, filter, is_array, is_inexact_array, partition
from ..._module import Static
from ..._tree import tree_at, tree_equal
Expand Down Expand Up @@ -306,12 +305,7 @@ def _stumm_walther_i(step, save_state):
i = jnp.where(pred, o, i)
s = jnp.where(pred, o > 0, s)
out = save_residual, index, (i, o, p, s)
msg = (
"Internal run-time error when checkpointing "
"`equinox.internal.checkpointed_while_loop`. "
"Please raise an issue at https://github.com/patrick-kidger/equinox"
)
out = error_if(out, pred & (o == -1), msg)
# out = error_if(out, pred & (o == -1), msg)
out = nonbatchable(out)
return out

Expand Down Expand Up @@ -795,12 +789,7 @@ def _body_fun(carry):
(step_val, step_grad_val, step_next_checkpoint, index, residual_steps)
)

msg = (
"Internal run-time error when backpropagating through "
"`equinox.internal.checkpointed_while_loop`. "
"Please raise an issue at https://github.com/patrick-kidger/equinox"
)
step_val = error_if(step_val, step_val >= step_grad_val, msg)
# step_val = error_if(step_val, step_val >= step_grad_val, msg)

#
# First either propagate our primal state forward, or make a U-turn if the
Expand Down

0 comments on commit e0346e1

Please sign in to comment.