-
Notifications
You must be signed in to change notification settings - Fork 13
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Raise exceptions from jit-compiled functions #181
Raise exceptions from jit-compiled functions #181
Conversation
dcd6fff
to
82b6d79
Compare
82b6d79
to
9a295c5
Compare
bbc54a3
to
6d10fa8
Compare
It contains utility functions to raise exceptions within jit-compiled functions
6d10fa8
to
827e5d3
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks Diego, I left some comments
On a MWE similar to the new test of this PR, the following is the output: Output of raising exceptions in a jax callback
There are two piece of outputs. Before the line Originally, in the test I was trying to capture output 1, but I couldn't find any way to do that within pytest (I've tried both by redirecting the std{out/err} streams to a buffer, and using the @flferretti your suggestion in #181 (comment) makes sense, I didn't think about it because it could only catch the content of the |
Co-authored-by: Filippo Luca Ferretti <filippo.ferretti@iit.it>
This PR:
condition
(both in the callback and the low-leveljax.lax.cond
) is necessary because JAX compiles both branches and it would raise an exception while tracing.The caveat is that JAX raises a
XlaRuntimeError
to stop the execution of the jit-compiled function. The real exception raised in the callback is printed together with the corresponding stack trace earlier in the output.Although this method is not capable of handling raised exceptions with a
try
statement (I don't see any way to do that, regardless), at least we can stop the execution by raising.📚 Documentation preview 📚: https://jaxsim--181.org.readthedocs.build//181/