A follow up on #3103. I am using cond in a function recursion because I have to iterate on sub-parts of matrices. However, in this case cond is evaluating both statements, leading to an error:
Hi - thanks for the question! lax.cond does its branching at runtime, not at trace time. At trace time, all branches must be traced, because the value of the predicate is not known until runtime. This means that lax.cond cannot be used as the base case that terminates a recursive program: the recursive branch is always traced, whatever the predicate turns out to be.
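To make the tracing behaviour concrete, here is a minimal sketch showing that both branches of lax.cond get traced under jax.jit, even though only one is executed at runtime. The names `true_branch`, `false_branch`, and `traced_branches` are illustrative assumptions and do not come from the original issue.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Python-side record of which branch functions were called during tracing.
traced_branches = []

def true_branch(x):
    traced_branches.append("true")   # side effect runs at trace time only
    return x + 1.0

def false_branch(x):
    traced_branches.append("false")  # side effect runs at trace time only
    return x - 1.0

@jax.jit
def f(x):
    # Under jit, `x > 0` is an abstract tracer, so JAX must trace both branches.
    return lax.cond(x > 0, true_branch, false_branch, x)

result = f(jnp.float32(2.0))

# Both branch functions were called while tracing, although only the
# true branch determines the runtime result.
print(sorted(set(traced_branches)))
print(result)
```

Because every branch is traced, a branch that indexes into an empty array or calls the function recursively is still processed at trace time.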
Fortunately, in this case you're branching on a static attribute (the length of the array), so you can use normal Python conditionals instead of lax.cond:
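A minimal sketch of this pattern, assuming a hypothetical function `recursive_sum` that recurses on halves of a 1-D array (this is not the code from the issue, only an illustration of branching on a static shape):

```python
import jax
import jax.numpy as jnp

def recursive_sum(x):
    # x.shape[0] is an ordinary Python int, known at trace time,
    # so a plain `if` can terminate the recursion.
    n = x.shape[0]
    if n == 0:
        return jnp.zeros((), dtype=x.dtype)  # empty input: nothing to index
    if n == 1:
        return x[0]
    mid = n // 2
    # Slicing with static bounds yields sub-arrays with static shapes.
    return recursive_sum(x[:mid]) + recursive_sum(x[mid:])

total = jax.jit(recursive_sum)(jnp.arange(5.0))
empty_total = recursive_sum(jnp.zeros((0,)))
print(total, empty_total)
```

Since the Python `if` is resolved while tracing, only the branch that matches the actual shape is ever traced, and the out-of-bounds indexing on an empty array never happens.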
Description
A follow up on #3103. I am using cond in a function recursion because I have to iterate on sub-parts of matrices. However, in this case cond is evaluating both statements, leading to an error:
This results in IndexError: index is out of bounds for axis 0 with size 0.
Is there any way to circumvent this?
Thanks for all this great work!
What jax/jaxlib version are you using?
jax/jaxlib=0.3.24
Which accelerator(s) are you using?
CPU
Additional system info
Python 3.10.8, Linux 6.0.12