diff --git a/equinox/internal/_loop/checkpointed.py b/equinox/internal/_loop/checkpointed.py index 51e3c4ef..d3a54455 100644 --- a/equinox/internal/_loop/checkpointed.py +++ b/equinox/internal/_loop/checkpointed.py @@ -311,7 +311,8 @@ def _stumm_walther_i(step, save_state): "`equinox.internal.checkpointed_while_loop`. " "Please raise an issue at https://github.com/patrick-kidger/equinox" ) - out = error_if(out, pred & (o == -1), msg) + if getattr(typing, "TESTING", False): + out = error_if(out, pred & (o == -1), msg) out = nonbatchable(out) return out @@ -800,7 +801,8 @@ def _body_fun(carry): "`equinox.internal.checkpointed_while_loop`. " "Please raise an issue at https://github.com/patrick-kidger/equinox" ) - step_val = error_if(step_val, step_val >= step_grad_val, msg) + if getattr(typing, "TESTING", False): + step_val = error_if(step_val, step_val >= step_grad_val, msg) # # First either propagate our primal state forward, or make a U-turn if the