lax.while_loop (and probably other operations) crashes when passed non-hashable cond/body functions.
#13554
Labels: bug
produces
In practice, Equinox tends to pass around a lot of callable PyTrees, whose hash is computed from the hashes of their leaves. If any of those leaves is a JAX array, the PyTree is unhashable (whilst still being callable).
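The following is a minimal sketch of the situation described above, under stated assumptions. A frozen dataclass stands in for an Equinox-style callable PyTree. The class names `Step` and `Below` are hypothetical, and the class is not an actual `equinox.Module`. Its generated `__hash__` hashes its fields, so hashing it tries to hash a JAX array and fails. The sketch also shows a common workaround, which is my assumption and not the fix this issue has in mind: wrap each callable in a `lambda`. Lambdas hash by identity, so anything that caches on the function object no longer needs to hash the array leaves. It does not claim that passing `body`/`cond` directly still crashes in any particular JAX version.

```python
import dataclasses

import jax
import jax.numpy as jnp
from jax import lax


# Hypothetical stand-in for an Equinox-style callable PyTree. With
# frozen=True (and the default eq=True), dataclasses generates a __hash__
# that hashes the tuple of fields, i.e. it hashes the JAX array leaf.
@dataclasses.dataclass(frozen=True)
class Step:
    increment: jax.Array

    def __call__(self, x):
        return x + self.increment


@dataclasses.dataclass(frozen=True)
class Below:
    limit: jax.Array

    def __call__(self, x):
        return x < self.limit


body = Step(jnp.asarray(1.0))
cond = Below(jnp.asarray(5.0))

# The objects are callable...
print(body(jnp.asarray(0.0)))

# ...but not hashable, because JAX arrays are unhashable.
try:
    hash(body)
    hashable = True
except TypeError:
    hashable = False
print("hashable:", hashable)

# Assumed workaround: pass identity-hashed lambdas that close over the
# unhashable callables instead of passing the callables themselves.
result = lax.while_loop(lambda x: cond(x), lambda x: body(x), jnp.asarray(0.0))
print(result)
```

The lambdas are recreated on every call, so any cache keyed on them will miss each time. That costs a retrace but avoids the hash failure.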
Probably this is low-priority, but fixing it would be easy enough, so I'm opening this issue mostly as a reminder for myself in case I ever get time to come back and sort this out.