Can’t import both jax and tensorflow (causes kernel restart) #120
Currently we can't import both jax and tensorflow. I'm not sure of the exact reason, but it has something to do with how we build XLA. It's definitely a bug we need to fix, but for now there's no workaround that I know of.
By the way, unless you have additional info to suggest otherwise, I think this is not an OOM. I'll revise the issue title to reflect my best understanding (that this is about importing both jax and tf). Let me know if we need to change it back. Thanks for bringing this up!
Thanks for the quick response, @mattjj. I think the bug is related to the C++ code.
I think the issue stems from the build process rather than a bug in the code itself. We build XLA as part of tensorflow, which means our XLA "jaxlib" .so file and tensorflow's .so files may have some redundant symbols, or something similar. @hawkinsp any thoughts to jot down here?
Just to follow up: I finally had time to look at this. It turns out that the jaxlib (i.e. XLA) Python extension depends on a few small pieces of TensorFlow (mostly some C++ support libraries). If both the jaxlib and tensorflow Python extensions export the same symbols, a crash occurs. The fix is to restrict the set of symbols we export from the XLA Python extension; there's no good reason for the XLA Python extension to also export bits of TF. I'll push a fix shortly.
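To see the kind of overlap described above, one can compare the dynamic symbols that each extension exports. This is a diagnostic sketch under stated assumptions, not the method used in this thread: it assumes a Linux system with GNU binutils `nm` on `PATH`, and the helper names `dynamic_exports` and `shared_exports` are hypothetical.

```python
from __future__ import annotations

import subprocess


def dynamic_exports(path: str) -> set[str]:
    """Names of dynamic symbols defined (exported) by a shared object.

    Assumes GNU binutils `nm`: `-D` reads the dynamic symbol table and
    `--defined-only` skips undefined (imported) symbols.
    """
    out = subprocess.run(
        ["nm", "-D", "--defined-only", path],
        check=True,
        capture_output=True,
        text=True,
    ).stdout
    names = set()
    for line in out.splitlines():
        fields = line.split()
        if fields:
            names.add(fields[-1])  # the symbol name is the last column
    return names


def shared_exports(a: set[str], b: set[str]) -> list[str]:
    """Symbols exported by both modules, i.e. candidates for a clash."""
    return sorted(a & b)
```

For example, `shared_exports(dynamic_exports(xla_ext_path), dynamic_exports(tf_ext_path))`, where both paths are placeholders for the two extension `.so` files. A non-empty list of TensorFlow-internal names would match the situation described in this comment.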
FYI @rsepassi
[XLA] Hide all symbols except PyInit_* from the pywrap_xla module. The XLA Python extension includes parts of TensorFlow core. If the TensorFlow symbols are exported by the XLA extension as well as TensorFlow itself, a segfault occurs if both tensorflow and jax are imported into the same Python interpreter. JAX issue jax-ml/jax#120 will be fixed when this change is integrated into the JAX repository. PiperOrigin-RevId: 229962366
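The commit message does not say how the symbols are hidden. One common mechanism on Linux, assumed here rather than taken from the change, is a GNU ld version script passed to the linker with `-Wl,--version-script=<file>`. The sketch pairs such a script with a hypothetical `policy_violations` helper that flags any exported name other than `PyInit_*`; its input could come from a symbol listing such as `nm -D --defined-only`.

```python
from __future__ import annotations

import fnmatch

# Assumption: one way to express "hide all symbols except PyInit_*" as a
# GNU ld version script; not necessarily what the commit itself uses.
VERSION_SCRIPT = """\
{
  global:
    PyInit_*;
  local:
    *;
};
"""


def policy_violations(exports: set[str]) -> list[str]:
    """Exported names that an 'only PyInit_*' policy would have hidden."""
    return sorted(s for s in exports if not fnmatch.fnmatchcase(s, "PyInit_*"))
```

Under this policy, only the module's `PyInit_<name>` entry point, which CPython looks up when importing an extension, remains visible, so the TensorFlow pieces linked into the XLA extension can no longer collide with the copies exported by TensorFlow itself.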
Includes tensorflow/tensorflow@7fce32e, which fixes jax-ml#120 once jaxlib is rebuilt.
PR #265 checked in a fix to jaxlib that should fix this problem. Note that for the moment you'll either need to rebuild jaxlib from source or wait until we build and push new jaxlib wheels. Hope that helps!
awesome! (replying to Peter Hawkins's comment of Fri, Jan 18, 2019, 12:01 PM)
I've built and pushed updated jaxlib wheels (0.1.6) to PyPI. If you update your jaxlib version (…
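A minimal sketch for checking whether the installed jaxlib is at least the 0.1.6 release mentioned above, using only the standard library (`importlib.metadata`). The `FIXED_IN` constant and the helper names are hypothetical, and the parser looks only at the leading numeric release components, so suffixes such as `+cuda12` or `.dev0` are ignored.

```python
from __future__ import annotations

import re
from importlib.metadata import PackageNotFoundError, version

# First jaxlib wheels said to include the symbol-visibility fix.
FIXED_IN = (0, 1, 6)


def release_tuple(v: str) -> tuple[int, ...]:
    """Leading numeric release components, e.g. '0.4.30+cuda12' -> (0, 4, 30)."""
    m = re.match(r"\d+(?:\.\d+)*", v)
    if m is None:
        raise ValueError(f"unrecognized version: {v!r}")
    return tuple(int(part) for part in m.group(0).split("."))


def has_symbol_fix(v: str) -> bool:
    """True if version string v is at or after FIXED_IN."""
    return release_tuple(v) >= FIXED_IN


def installed_jaxlib_has_fix() -> bool | None:
    """Check the installed jaxlib; None if jaxlib is not installed."""
    try:
        return has_symbol_fix(version("jaxlib"))
    except PackageNotFoundError:
        return None
```

Comparing tuples of integers rather than raw strings avoids the trap where `"0.1.10" < "0.1.6"` lexically.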
Hi,
after importing those libs:
I tried to run this cell, and these are the results: