Skip to content

Commit

Permalink
Update grad of while_loop message.
Browse files Browse the repository at this point in the history
The previous error message was misleading as of
jax-ml@ed8dbd2
(see jax-ml#2414 (comment)
for context).
  • Loading branch information
skye committed May 6, 2020
1 parent e51c7d7 commit 7cd644d
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion jax/lax/lax_control_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,7 +480,7 @@ def _while_partial_eval(trace: pe.JaxprTrace, *tracers: pe.Tracer, cond_nconsts:

def _while_transpose_error(*_, **kwargs):
raise ValueError("Reverse-mode differentiation does not work for lax.while_loop. "
"Try using lax.scan, or lax.fori_loop with constant bounds.")
"Try using lax.scan.")

while_p = lax.Primitive('while')
while_p.multiple_results = True
Expand Down

0 comments on commit 7cd644d

Please sign in to comment.