-
Notifications
You must be signed in to change notification settings - Fork 11
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Initialize JAX configuration as per the v0.4.25 release #92
Conversation
Some tests fail raise segmentation fault: tests/test_jax_oop.py::test_decorators_vmap FAILED [100%]
=================================== FAILURES ===================================
__________________________________ test_data ___________________________________
:-1: running the test CRASHED with signal 11
_______________________________ test_mutability ________________________________
:-1: running the test CRASHED with signal 11
_______________________ test_decorators_jit_compilation ________________________
:-1: running the test CRASHED with signal 11
_____________________________ test_decorators_vmap _____________________________
:-1: running the test CRASHED with signal 11
=============================== warnings summary ===============================
tests/test_ad_physics.py: 3 warnings
tests/test_eom.py: 12 warnings
tests/test_forward_dynamics.py: 12 warnings
tests/test_jax_oop.py: 4 warnings
/opt/hostedtoolcache/Python/3.12.2/x64/lib/python3.12/site-packages/py/_process/forkedfunc.py:45: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.
pid = os.fork()
tests/test_ad_physics.py: 3 warnings
tests/test_eom.py: 12 warnings
tests/test_forward_dynamics.py: 12 warnings
tests/test_jax_oop.py: 4 warnings
/opt/hostedtoolcache/Python/3.12.2/x64/lib/python3.12/site-packages/py/_process/forkedfunc.py:45: DeprecationWarning: This process (pid=4701) is multi-threaded, use of fork() may lead to deadlocks in the child.
pid = os.fork()
-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
============ 4 failed, 27 passed, 62 warnings in 2637.68s (0:43:57) ============
Error: Process completed with exit code 1. It may not be related to the changes of this PR C.C. @traversaro |
The last working job is: https://github.com/ami-iit/jaxsim/actions/runs/7985799317 . |
All the jobs will fail if re-run as Is is possible that |
Indeed, see #93 . I guess that pinning jax to 0.4.24 everything should work fine? |
Do you know if we can get a stacktrace of the segfault? |
Ok, reproduced locally:
with:
|
Initially I thought the |
Actually I think I was barking at the wrong tree. If we remove the |
@flferretti I opened a proposal for a workaround in #94 . |
Remove --forked from pytest as workaround for jax 0.4.25
Thanks a lot @traversaro, I checked #94 and it seems to be the reason why the error is been raised |
To be honest, I am not sure if this is a regression of jax or simply one should not use forked with jax, but for the time being it should unblock the situation. |
By the way, this seems related: jax-ml/jax#10242 . |
Mmh I see CI failing with:
And this is the reason why using |
I will revert #94 then |
The problem is that with |
If I understood, that was happening with We can likely remove this pinning as soon as the transition to functional will be finalized. I expect that we don't need |
Exactly, that was our idea as well, see #93 (comment) . |
This PR will update the JAX configuration initialization as per JAX v0.4.25, in which the access to
jax.config
has been deprecated, leading toImportError: cannot import name 'config' from 'jax.config'
.📚 Documentation preview 📚: https://jaxsim--92.org.readthedocs.build//92/