Error importing jax after certain tensorflow import #349
Comments
Tagging #120, which is a similar issue.

I have four ideas for how to avoid this:
tensorflow-copybara pushed a commit to tensorflow/tensorflow that referenced this issue on Feb 25, 2019:
The XLA Python extension is packaged separately as "jaxlib", but XLA itself is part of TensorFlow. Some of the same basic protocol buffers are used by both (e.g., xla_data.proto), leading to a conflict if a proto is imported twice into the same Python interpreter via different routes (e.g., jax-ml/jax#349), since a single global C++ protocol buffer registry exists for the entire interpreter.

The simplest solution, short of a significant refactoring of the TensorFlow->XLA dependency structure, seems to be to change xla_client.py not to depend on any XLA protocol buffers. A few other possible alternatives are discussed in jax-ml/jax#349. Fortunately, we don't use protocol buffers in any essential way in the XLA client, mostly for objects such as convolution dimension numbers. Instead, create Python objects that play the same role and that duck-type as protocol buffers well enough to keep the SWIG bindings happy.

Remove an unused function, OpMetadataToProto. Change Computation.GetProto() to Computation.GetSerializedProto(). In passing, remove a comment duplicated between xla_data.i and local_computation_builder.i.

PiperOrigin-RevId: 235560841
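To make the approach concrete, here is a minimal illustrative sketch (not the actual xla_client.py code) of a plain Python class that duck-types as a protocol buffer message; the class and field names are assumptions modeled on the XLA ConvolutionDimensionNumbers proto mentioned above:

```python
# Illustrative sketch only, not the real xla_client implementation.
# A plain Python object whose attributes mirror a proto message closely
# enough to "duck type" as one for the SWIG bindings, without importing
# any .proto definitions into the interpreter (and so without touching
# the global C++ protobuf registry).
class ConvolutionDimensionNumbers:
  """Python stand-in for the xla_data ConvolutionDimensionNumbers proto."""

  def __init__(self):
    # Scalar fields default to 0 and repeated fields to empty lists,
    # matching protobuf default semantics.
    self.input_batch_dimension = 0
    self.input_feature_dimension = 0
    self.input_spatial_dimensions = []
    self.kernel_input_feature_dimension = 0
    self.kernel_output_feature_dimension = 0
    self.kernel_spatial_dimensions = []
    self.output_batch_dimension = 0
    self.output_feature_dimension = 0
    self.output_spatial_dimensions = []
```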
hawkinsp added a commit to hawkinsp/jax that referenced this issue on Feb 26, 2019:
Updates XLA to tensorflow/tensorflow@00afc7b. The new XLA release removes the use of protocol buffers from the XLA client. Fixes jax-ml#349.

Add backward compatibility shims to jaxlib to allow older jax releases to still work on an up-to-date jaxlib. The new XLA release also incorporates a fix that avoids a host-device copy for every iteration of a `lax.fori_loop()` on GPU. Fixes jax-ml#402.

Add a new jaxlib.__version__ field, and change the jax/jaxlib compatibility logic to check for it.
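As a rough sketch of what a version check against the new jaxlib.__version__ field could look like (the function name and minimum version below are assumptions for illustration, not jax's actual compatibility code):

```python
# Hedged sketch of a jax/jaxlib compatibility check; names and the
# minimum version are illustrative assumptions.
import jaxlib

_MINIMUM_JAXLIB_VERSION = (0, 1, 8)  # assumed value for illustration

def check_jaxlib_version():
  version_str = getattr(jaxlib, "__version__", None)
  if version_str is None:
    # Older jaxlib wheels predate the __version__ attribute entirely.
    raise RuntimeError("jaxlib is too old; please install a newer wheel.")
  version = tuple(int(part) for part in version_str.split("."))
  if version < _MINIMUM_JAXLIB_VERSION:
    minimum = ".".join(map(str, _MINIMUM_JAXLIB_VERSION))
    raise RuntimeError(
        f"jaxlib {version_str} is older than the minimum supported "
        f"version {minimum}.")
```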
This is now fixed, but to access the fix, you either need to rebuild jaxlib from source or wait until we push new binary wheels to PyPI (probably later this week).
Running the imports causes an error; the full traceback is below. I'm using the most recent protobuf and TensorFlow 1.12.0. This workaround makes the problem go away, but it doesn't seem like a permanent solution.

Edit: Doing the imports in the opposite order causes a similar error during the tf import.
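A minimal sketch of the kind of reproduction being described; the exact tensorflow import from the original report is not quoted in this thread, so the specific module imported below is an assumption:

```python
# Hypothetical reproduction sketch. Importing a TensorFlow module that
# registers the shared XLA protos (e.g., xla_data.proto), then importing
# jax, causes jaxlib to try to register the same protos in the single
# interpreter-wide C++ protobuf registry, triggering the error.
import tensorflow as tf  # the exact submodule in the original report may differ
import jax               # fails here when the proto registration conflicts
```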