Skip to content

Commit

Permalink
Update grad of while_loop message. (#2976)
Browse files Browse the repository at this point in the history
The previous error message was misleading as of
ed8dbd2
(see #2414 (comment)
for context).
  • Loading branch information
skye authored May 6, 2020
1 parent 0534b65 commit aedf346
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions jax/lax/lax_control_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,8 +479,9 @@ def _while_partial_eval(trace: pe.JaxprTrace, *tracers: pe.Tracer, cond_nconsts:
return out_tracers

def _while_transpose_error(*_, **kwargs):
raise ValueError("Reverse-mode differentiation does not work for lax.while_loop. "
"Try using lax.scan, or lax.fori_loop with constant bounds.")
raise ValueError("Reverse-mode differentiation does not work for "
"lax.while_loop or lax.fori_loop. "
"Try using lax.scan instead.")

while_p = lax.Primitive('while')
while_p.multiple_results = True
Expand Down

0 comments on commit aedf346

Please sign in to comment.