Raise exceptions from jit-compiled functions #181

Merged
merged 3 commits into from
Jun 18, 2024

Conversation

diegoferigo
Member

@diegoferigo diegoferigo commented Jun 14, 2024

This PR:

  • Introduces a new helper that allows raising exceptions from within jit-compiled functions.
  • For each exception, injects into the compiled code a new dummy branch triggered by the same condition as the exception. In this way, the host callback (which can be slow even when it is a no-op) is not triggered unless needed.
  • The apparently redundant double check on the condition (both in the callback and in the low-level jax.lax.cond) is necessary because JAX compiles both branches, and it would otherwise raise an exception while tracing.
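
A minimal sketch of the pattern described in the list above, not the code merged in this PR. The names `raise_if_sketch`, `_raise_on_host` and `checked_sqrt` are hypothetical and chosen only for illustration. It assumes that `jax.debug.callback` is used as the host callback and that it is wrapped in a `jax.lax.cond` whose false branch is a no-op:

```python
import jax
import jax.numpy as jnp


def raise_if_sketch(condition, exception, msg, *args):
    """Hypothetical helper: raise `exception` on the host when `condition` is true."""

    def _raise_on_host(cond, *host_args):
        # Runs on the host. The condition is checked a second time here,
        # mirroring the double check mentioned in the PR description.
        if cond:
            raise exception(msg.format(*host_args))

    def _true_branch(cond, *operands):
        # Only this branch contains the (potentially slow) host callback.
        jax.debug.callback(_raise_on_host, cond, *operands)

    def _false_branch(cond, *operands):
        # Dummy branch: nothing to do when the condition is false.
        return None

    # The predicate is also passed as an operand so the callback can re-check it.
    jax.lax.cond(condition, _true_branch, _false_branch, condition, *args)


@jax.jit
def checked_sqrt(x):
    # Illustrative jit-compiled function that rejects negative scalar inputs.
    raise_if_sketch(x < 0, ValueError, "Negative input: {}", x)
    return jnp.sqrt(x)
```

Calling `checked_sqrt(jnp.array(4.0))` takes the dummy branch and returns normally, while a negative input is expected to abort the execution with a runtime error raised by JAX rather than with the `ValueError` itself.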

The caveat is that JAX raises an XlaRuntimeError to stop the execution of the jit-compiled function. The real exception raised in the callback is printed, together with the corresponding stack trace, earlier in the output.

Although this method is not capable of handling raised exceptions with a try statement (I don't see any way to do that, regardless), at least we can stop the execution by raising.


📚 Documentation preview 📚: https://jaxsim--181.org.readthedocs.build//181/

@diegoferigo diegoferigo self-assigned this Jun 14, 2024
@diegoferigo diegoferigo force-pushed the raise_exceptions_from_jit_compiled_functions branch from dcd6fff to 82b6d79 Compare June 14, 2024 16:32
@diegoferigo diegoferigo force-pushed the raise_exceptions_from_jit_compiled_functions branch from 82b6d79 to 9a295c5 Compare June 17, 2024 09:25
@diegoferigo diegoferigo marked this pull request as ready for review June 17, 2024 09:37
@diegoferigo diegoferigo requested a review from flferretti as a code owner June 17, 2024 09:37
@diegoferigo diegoferigo force-pushed the raise_exceptions_from_jit_compiled_functions branch 3 times, most recently from bbc54a3 to 6d10fa8 Compare June 17, 2024 11:08
It contains utility functions to raise exceptions within jit-compiled functions
@diegoferigo diegoferigo force-pushed the raise_exceptions_from_jit_compiled_functions branch from 6d10fa8 to 827e5d3 Compare June 17, 2024 11:27
Collaborator

@flferretti flferretti left a comment

Thanks Diego, I left some comments

tests/test_exceptions.py
src/jaxsim/exceptions.py
src/jaxsim/exceptions.py (outdated)
src/jaxsim/exceptions.py (outdated)
@diegoferigo
Member Author

On a MWE similar to the new test of this PR, the following is the output:

Output of raising exceptions in a jax callback
jax.debug_callback failed
Traceback (most recent call last):
  File "/jaxsim/lib/python3.12/site-packages/jax/_src/debugging.py", line 84, in debug_callback_impl
    callback(*args)
  File "/jaxsim/lib/python3.12/site-packages/jax/_src/debugging.py", line 246, in _flat_callback
    callback(*args, **kwargs)
  File "/home/dferigo/git/jaxsim/src/jaxsim/exceptions.py", line 46, in _raise_exception
    raise exception(msg.format(*args, **kwargs)).with_traceback(back_tb)
  File "/jaxsim/lib/python3.12/site-packages/jax/_src/interpreters/mlir.py", line 2486, in _wrapped_callback
    out_vals = callback(*args)
               ^^^^^^^^^^^^^^^
ValueError: This is a test exception for 42 data
---------------------------------------------------------------------------
XlaRuntimeError                           Traceback (most recent call last)
Cell In[1], line 81
     77     return data
     80 # _ = jit_compiled_function(data=40)
---> 81 _ = jit_compiled_function(data=42)

    [... skipping hidden 10 frame]

File /jaxsim/lib/python3.12/site-packages/jax/_src/interpreters/pxla.py:1178, in ExecuteReplicated.__call__(self, *args)
   1175 if (self.ordered_effects or self.has_unordered_effects
   1176     or self.has_host_callbacks):
   1177   input_bufs = self._add_tokens_to_inputs(input_bufs)
-> 1178   results = self.xla_executable.execute_sharded(
   1179       input_bufs, with_tokens=True
   1180   )
   1181   result_token_bufs = results.disassemble_prefix_into_single_device_arrays(
   1182       len(self.ordered_effects))
   1183   sharded_runtime_token = results.consume_token()

XlaRuntimeError: INTERNAL: Generated function failed: CpuCallback error: ValueError: This is a test exception for 42 data

At:
  /jaxsim/lib/python3.12/site-packages/jax/_src/interpreters/mlir.py(2486): _wrapped_callback
  /jaxsim/lib/python3.12/site-packages/jax/_src/interpreters/pxla.py(1178): __call__
  /jaxsim/lib/python3.12/site-packages/jax/_src/profiler.py(335): wrapper
  /jaxsim/lib/python3.12/site-packages/jax/_src/pjit.py(1488): _pjit_call_impl_python
  /jaxsim/lib/python3.12/site-packages/jax/_src/pjit.py(1534): call_impl_cache_miss
  /jaxsim/lib/python3.12/site-packages/jax/_src/pjit.py(1558): _pjit_call_impl
  /jaxsim/lib/python3.12/site-packages/jax/_src/core.py(879): process_primitive
  /jaxsim/lib/python3.12/site-packages/jax/_src/core.py(391): bind_with_trace
  /jaxsim/lib/python3.12/site-packages/jax/_src/core.py(2789): bind
  /jaxsim/lib/python3.12/site-packages/jax/_src/pjit.py(182): _python_pjit_helper
  /jaxsim/lib/python3.12/site-packages/jax/_src/pjit.py(305): cache_miss
  /jaxsim/lib/python3.12/site-packages/jax/_src/traceback_util.py(179): reraise_with_filtered_traceback
  <ipython-input-1-4ea84c6da2c9>(81): <module>
  /jaxsim/lib/python3.12/site-packages/IPython/core/interactiveshell.py(3577): run_code
  /jaxsim/lib/python3.12/site-packages/IPython/core/interactiveshell.py(3517): run_ast_nodes
  /jaxsim/lib/python3.12/site-packages/IPython/core/interactiveshell.py(3334): run_cell_async
  /jaxsim/lib/python3.12/site-packages/IPython/core/async_helpers.py(129): _pseudo_sync_runner
  /jaxsim/lib/python3.12/site-packages/IPython/core/interactiveshell.py(3130): _run_cell
  /jaxsim/lib/python3.12/site-packages/IPython/core/interactiveshell.py(3075): run_cell
  /jaxsim/lib/python3.12/site-packages/IPython/terminal/interactiveshell.py(910): interact
  /jaxsim/lib/python3.12/site-packages/IPython/terminal/interactiveshell.py(917): mainloop
  /jaxsim/lib/python3.12/site-packages/IPython/terminal/ipapp.py(317): start
  /jaxsim/lib/python3.12/site-packages/traitlets/config/application.py(1075): launch_instance
  /jaxsim/lib/python3.12/site-packages/IPython/__init__.py(130): start_ipython
  /jaxsim/bin/ipython(10): <module>

There are two pieces of output. Before the -------- line there is what seems to be the actual output of the callback (let's call it output 1), which raises the right type of exception; after the line there is the XlaRuntimeError exception that can be caught by the code (let's call it output 2).

Originally, in the test I was trying to capture output 1, but I couldn't find any way to do that within pytest (I've tried both redirecting the std{out/err} streams to a buffer and using the capsys fixture). I suspect that the callback runs in a different thread or similar, making it impossible to capture its output (at least, I couldn't figure out a way).

@flferretti your suggestion in #181 (comment) makes sense. I didn't think about it because it could only catch the content of the XlaRuntimeError, which is much longer than the original exception. However, as you can see from the output above, it contains the content of the original exception. I'll update the tests to use that, since it is good enough for testing purposes. Thanks! In any case, I wanted to provide all this information here instead of in the original comment, in order to have better visibility for future readers.
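
As a rough illustration of that testing approach (catching the error raised by JAX and inspecting its message), here is a self-contained sketch. It is not the test added in this PR: `_fail_on_host` and `error_message_of` are hypothetical names, and the callback raises unconditionally to keep the example short. The broad `except Exception` is an assumption made because the import path of the concrete error type may differ between JAX versions:

```python
import jax
import jax.numpy as jnp


def _fail_on_host(value):
    # Host-side callback that always raises, standing in for the real helper.
    raise ValueError(f"This is a test exception for {value} data")


@jax.jit
def jit_compiled_function(data):
    jax.debug.callback(_fail_on_host, data)
    return data


def error_message_of(fn, *args):
    """Return the message of the exception raised by `fn(*args)`, or None."""
    try:
        # Wait for the computation and its side effects to complete, so that
        # a failure in the callback surfaces inside this try block.
        jax.block_until_ready(fn(*args))
        jax.effects_barrier()
    except Exception as exc:  # XlaRuntimeError in the output shown above
        return str(exc)
    return None


message = error_message_of(jit_compiled_function, jnp.array(42))
```

Based on the output shown above, `message` would be expected to include the text of the original `ValueError`, so a test could search for that text in it. The exact wording that JAX wraps around it may vary between versions.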

Co-authored-by: Filippo Luca Ferretti <filippo.ferretti@iit.it>
@diegoferigo diegoferigo merged commit 954c5c6 into main Jun 18, 2024
43 checks passed
@diegoferigo diegoferigo deleted the raise_exceptions_from_jit_compiled_functions branch June 18, 2024 14:59