Can’t import both jax and tensorflow (causes kernel restart) #120

Closed
zeevikal opened this issue Dec 16, 2018 · 9 comments · Fixed by #265
Labels
bug Something isn't working

Comments

@zeevikal

Hi,
after importing these libraries:

import pandas as pd
import jax.numpy as np
import tensorflow as tf

I tried to run this cell and got the following result:

image

@mattjj
Collaborator

mattjj commented Dec 16, 2018

Currently we can’t import both jax and tensorflow. I’m not sure of the exact reason, but it has something to do with how we build XLA. It's definitely a bug we need to fix, but for now there’s no workaround that I know of.

@mattjj
Collaborator

mattjj commented Dec 16, 2018

By the way, unless you have additional info to suggest otherwise, I think this is not an OOM. I’ll revise the issue title to reflect my best understanding (that this is about importing both jax and tf). Let me know if we need to change it back.

Thanks for bringing this up!
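
One way to check the "not an OOM" reading without losing the notebook kernel is to run the imports in a separate interpreter and inspect how it exits. The sketch below is only an illustration: the helper names `run_in_fresh_interpreter` and `try_imports` are hypothetical, and it assumes a POSIX system, where a negative return code from `subprocess` means the child was killed by that signal. A segfault would typically show up as SIGSEGV, whereas the Linux OOM killer sends SIGKILL.

```python
import signal
import subprocess
import sys


def run_in_fresh_interpreter(code):
    """Run `code` in a new Python process; return (returncode, description)."""
    proc = subprocess.run(
        [sys.executable, "-c", code], capture_output=True, text=True
    )
    rc = proc.returncode
    if rc == 0:
        return rc, "ok"
    if rc < 0:
        # POSIX only: a negative return code is the negated signal number.
        try:
            name = signal.Signals(-rc).name
        except ValueError:
            name = f"signal {-rc}"
        return rc, f"killed by {name}"
    # Positive return code: an ordinary Python error; report the last stderr line.
    lines = proc.stderr.strip().splitlines()
    return rc, "python error: " + (lines[-1] if lines else "no stderr output")


def try_imports(modules):
    """Import the given modules, in order, in a fresh interpreter."""
    return run_in_fresh_interpreter("; ".join(f"import {m}" for m in modules))
```

Calling `try_imports(["jax.numpy", "tensorflow"])` would then report either `"ok"`, a Python-level error, or the signal that ended the child process, while the calling notebook keeps running.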

@mattjj changed the title from "OOM - running jax in jupyter notebook" to "Can’t import both jax and tensorflow (causes kernel restart)" Dec 16, 2018
@zeevikal
Author

Thanks for the quick response, @mattjj.
I tried running the example from the jax repo with the tf import added and got the following error:
image

I think the bug is related to the C++ code.
Maybe the program calls free() on a pointer that wasn't obtained from malloc(), or maybe some other memory-related bug in the code triggers this behavior.

@mattjj added the bug label Dec 16, 2018
@mattjj
Collaborator

mattjj commented Dec 16, 2018

I think the issue stems from the build process, rather than a bug in the code itself. We build XLA as part of tensorflow, but that means our XLA "jaxlib" .so file and tensorflow's .so files may have some symbol redundancies, or something.

@hawkinsp any thoughts to jot down here?

@hawkinsp
Collaborator

Just to follow up: I finally had time to look at this. It turns out that the jaxlib (i.e. XLA) Python extension depends on a few small pieces of TensorFlow (mostly some C++ support libraries). If both the jaxlib and tensorflow Python extensions export the same symbols, a crash occurs. The fix is to restrict the set of symbols we export from the XLA python extension; there's no good reason for the XLA python extension to also export bits of TF.
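
To illustrate the diagnosis above, here is a rough sketch of how one might list the symbols that two shared libraries both export. It is not the procedure used for the actual fix. It assumes Linux with GNU binutils `nm` on the PATH, the helper names are hypothetical, and you would need to supply the paths to the two `.so` files yourself.

```python
import subprocess


def parse_nm_output(text):
    """Extract symbol names from `nm -D --defined-only` output.

    Each defined symbol is expected on a line of the form
    "<address> <type> <name>"; other lines are ignored.
    """
    symbols = set()
    for line in text.splitlines():
        parts = line.split()
        if len(parts) == 3:
            symbols.add(parts[2])
    return symbols


def exported_symbols(path):
    """Return the dynamic symbols defined (exported) by the library at `path`."""
    result = subprocess.run(
        ["nm", "-D", "--defined-only", path],
        check=True, capture_output=True, text=True,
    )
    return parse_nm_output(result.stdout)


def shared_exports(path_a, path_b):
    """Symbols exported by both libraries: candidates for a clash."""
    return exported_symbols(path_a) & exported_symbols(path_b)


def non_init_exports(symbols):
    """Exports other than the Python 3 module entry points (PyInit_*)."""
    return {s for s in symbols if not s.startswith("PyInit_")}
```

With symbol export restricted, one would expect `shared_exports` on the jaxlib and tensorflow extension files to shrink considerably, though a few linker-provided symbols may still appear in both.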

I'll push a fix shortly.

@mattjj
Collaborator

mattjj commented Jan 18, 2019

FYI @rsepassi

tensorflow-copybara pushed a commit to tensorflow/tensorflow that referenced this issue Jan 18, 2019
[XLA] Hide all symbols except PyInit_* from the pywrap_xla module.

The XLA Python extension includes parts of TensorFlow core. If the TensorFlow symbols are exported by the XLA extension as well as TensorFlow itself, a segfault occurs if both tensorflow and jax are imported into the same Python interpreter.

JAX issue jax-ml/jax#120 will be fixed when this change is integrated into the JAX repository.

PiperOrigin-RevId: 229962366
hawkinsp added a commit to hawkinsp/jax that referenced this issue Jan 18, 2019
Includes tensorflow/tensorflow@7fce32e, which fixes jax-ml#120 once jaxlib is rebuilt.
@hawkinsp
Collaborator

PR #265 checked in a fix to jaxlib that should fix this problem.

Note that for the moment you'll need to either rebuild jaxlib from source or wait until we build and push new jaxlib wheels.

Hope that helps!

@rsepassi
Contributor

rsepassi commented Jan 18, 2019 via email

@hawkinsp
Collaborator

I've built and pushed updated jaxlib wheels (0.1.6) to PyPI. If you update your jaxlib version (pip install --upgrade jaxlib, if you aren't using a CUDA build), you should be able to import both tensorflow and jax with no problems.
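
If you are unsure whether your environment already has the fixed wheel, a check along these lines may help. It is a minimal sketch: `version_tuple` and `jaxlib_has_fix` are hypothetical helper names, the version parsing is deliberately simplistic (it only reads the leading digits of each dot-separated piece), and it assumes Python 3.8 or newer for `importlib.metadata`.

```python
from importlib import metadata
from itertools import takewhile


def version_tuple(version):
    """Turn a version string such as "0.1.6" into a tuple of ints: (0, 1, 6)."""
    parts = []
    for piece in version.split("."):
        digits = "".join(takewhile(str.isdigit, piece))
        parts.append(int(digits) if digits else 0)
    return tuple(parts)


def jaxlib_has_fix(min_version="0.1.6"):
    """Return True/False for the installed jaxlib, or None if it is not installed."""
    try:
        installed = metadata.version("jaxlib")
    except metadata.PackageNotFoundError:
        return None
    return version_tuple(installed) >= version_tuple(min_version)
```

A result of `False` suggests running the upgrade command above before importing jax and tensorflow together.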
