
Error importing jax after certain tensorflow import #349

Closed
j-towns opened this issue Feb 11, 2019 · 3 comments · Fixed by #452

j-towns (Contributor) commented Feb 11, 2019

Running

from tensorflow.contrib.framework.python.ops import add_arg_scope
import jax

causes

TypeError: Couldn't build proto file into descriptor pool!
Invalid proto descriptor for file "tensorflow/compiler/xla/xla_data.proto":
  tensorflow/compiler/xla/xla_data.proto: A file with this name is already in the pool.

Full traceback below. I'm using the most recent protobuf and TensorFlow 1.12.0. This workaround makes the problem go away, but it doesn't seem like a permanent solution.

Edit: doing the imports in the opposite order causes a similar error during the TensorFlow import.

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-1-0480bfae9932> in <module>
      1 from tensorflow.contrib.framework.python.ops import add_arg_scope
----> 2 import jax

~/dev/jax/jax/__init__.py in <module>
     15 import os
     16 os.environ.setdefault('TF_CPP_MIN_LOG_LEVEL', '1')
---> 17 from jax.api import *
     18 import jax.numpy as np  # side-effecting import sets up operator overloads

~/dev/jax/jax/api.py in <module>
     43 from .util import (unzip2, unzip3, curry, partial, safe_map, safe_zip,
     44                    WrapHashably, prod)
---> 45 from .lib.xla_bridge import canonicalize_dtype
     46 from .abstract_arrays import ShapedArray
     47 from .interpreters import partial_eval as pe

~/dev/jax/jax/lib/xla_bridge.py in <module>
     31 import numpy as onp  # 'onp' rather than 'np' to distinguish from autograd.numpy
     32
---> 33 from jaxlib import xla_data_pb2
     34 from jaxlib import xla_client
     35

~/miniconda3/envs/python36/lib/python3.6/site-packages/jaxlib/xla_data_pb2.py in <module>
     21   syntax='proto3',
     22   serialized_options=_b('\370\001\001'),
---> 23   serialized_pb=_b('\n&tensorflow/compiler/xla/xla_data.proto\x12\x03xla...')
          [several kilobytes of serialized descriptor bytes elided]
     24 )
     25

~/miniconda3/envs/python36/lib/python3.6/site-packages/google/protobuf/descriptor.py in __new__(cls, name, package, options, serialized_options, serialized_pb, dependencies, public_dependencies, syntax, pool)
    876         # TODO(amauryfa): use the pool passed as argument. This will work only
    877         # for C++-implemented DescriptorPools.
--> 878         return _message.default_pool.AddSerializedFile(serialized_pb)
    879       else:
    880         return super(FileDescriptor, cls).__new__(cls)

TypeError: Couldn't build proto file into descriptor pool!
Invalid proto descriptor for file "tensorflow/compiler/xla/xla_data.proto":
  tensorflow/compiler/xla/xla_data.proto: A file with this name is already in the pool.
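
Both jaxlib and TensorFlow ship generated code for tensorflow/compiler/xla/xla_data.proto, and the protobuf runtime keeps a single registry of proto files per interpreter, so whichever module is imported second fails to register the file. The following is a minimal sketch of the same conflict, independent of jax and TensorFlow; the package names are stand-ins, and the exact exception type depends on whether the C++ or pure-Python protobuf implementation is active:

# Hypothetical standalone reproduction: registering two different definitions
# under the same .proto file name in the default descriptor pool conflicts,
# just like the jaxlib/TensorFlow double registration above.
from google.protobuf import descriptor_pb2, descriptor_pool

pool = descriptor_pool.Default()

first = descriptor_pb2.FileDescriptorProto(
    name='tensorflow/compiler/xla/xla_data.proto', package='xla',
    syntax='proto3')
second = descriptor_pb2.FileDescriptorProto(
    name='tensorflow/compiler/xla/xla_data.proto', package='xla_other',
    syntax='proto3')

pool.Add(first)   # first registration succeeds
pool.Add(second)  # raises: a file with this name is already in the pool
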
j-towns (Contributor, Author) commented Feb 11, 2019

Tagging #120, which is a similar issue.

hawkinsp (Collaborator) commented

I have four ideas for how to avoid this:

  1. have a single copy of XLA and its protos, and have JAX and TF both depend on it. Seems like a fine plan, but probably not feasible without lots of work on TF's dependency structure; parts of XLA depend on parts of TF.
  2. have JAX depend on TF's Python libraries and get XLA from TF. Adding a hard dependency on TF seems unlikely to be popular, but perhaps we could have an optional dependency (i.e., if TF is installed, you use the copy of XLA from TF; if not, you use jaxlib).
  3. do more work to rename the copy of xla_data.proto inside jaxlib so it doesn't clash with TF's copy. I'm not terribly happy with this option, but it could work.
  4. change the XLA client Python code (jaxlib) not to use protocol buffers. This actually seems pretty feasible: only a small number of protocol buffer messages are used, and none in any particularly essential way (see the sketch after this list).
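
A rough sketch of what option 4 could look like: a plain Python object that duck-types as the xla.ConvolutionDimensionNumbers message, with attribute names taken from the xla_data.proto descriptor in the traceback above. This is illustrative only, not the actual jaxlib change:

# Hypothetical stand-in for the xla.ConvolutionDimensionNumbers proto message.
# It exposes the same attribute names the client code reads and writes, but
# registers no descriptor, so there is nothing to collide with TF's copy.
class ConvolutionDimensionNumbers(object):
    def __init__(self):
        self.input_batch_dimension = 0
        self.input_feature_dimension = 0
        self.input_spatial_dimensions = []
        self.kernel_input_feature_dimension = 0
        self.kernel_output_feature_dimension = 0
        self.kernel_spatial_dimensions = []
        self.output_batch_dimension = 0
        self.output_feature_dimension = 0
        self.output_spatial_dimensions = []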

tensorflow-copybara pushed a commit to tensorflow/tensorflow that referenced this issue Feb 25, 2019
The XLA Python extension is packaged separately as "jaxlib", but XLA itself is part of TensorFlow. Some of the same basic protocol buffers are used by both (e.g., xla_data.proto), leading to a conflict if a proto is imported twice into the same Python interpreter via different routes (e.g., jax-ml/jax#349), since a single global C++ protocol buffer registry exists for the entire interpreter.

The simplest solution, short of a significant refactoring of the TensorFlow->XLA dependency structure, seems to be to change xla_client.py not to depend on any XLA protocol buffers. A few other possible alternatives are discussed in jax-ml/jax#349.

Fortunately, we don't use protocol buffers in any essential way in the XLA client; they mostly describe objects such as convolution dimension numbers. Instead, create Python objects that play the same role and that duck-type as protocol buffers well enough to keep the SWIG bindings happy.

Remove an unused function, OpMetadataToProto.
Change Computation.GetProto() to Computation.GetSerializedProto() (see the sketch after this commit message).

In passing, remove a comment duplicated between xla_data.i and local_computation_builder.i.

PiperOrigin-RevId: 235560841
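
A hedged illustration of the GetProto() to GetSerializedProto() change; `computation` stands for an xla_client Computation handle, and the wrapper function is made up:

# Hypothetical helper showing the new calling convention. GetSerializedProto()
# returns the computation as raw serialized bytes, so jaxlib never parses or
# registers the proto itself; callers that need a parsed message can
# deserialize the bytes with whichever generated proto module their process
# already has loaded.
def get_serialized_hlo(computation):
    # Old (removed): proto = computation.GetProto()
    return computation.GetSerializedProto()
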
hawkinsp added a commit to hawkinsp/jax that referenced this issue Feb 26, 2019
Updates XLA to tensorflow/tensorflow@00afc7b.

The new XLA release removes the use of protocol buffers from the XLA client. Fixes jax-ml#349.
Add backward-compatibility shims to jaxlib to allow older jax releases to still work on an up-to-date jaxlib.

The new XLA release also incorporates a fix that avoids a host-device copy for every iteration of a `lax.fori_loop()` on GPU. Fixes jax-ml#402.

Add a new jaxlib.__version__ field, and change the jax/jaxlib compatibility logic to check for it (a rough sketch of such a check follows).
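
A hypothetical sketch of what the jax-side compatibility check could look like; the minimum version tuple and error messages are made up:

# Sketch of a version check against the new jaxlib.__version__ field.
import jaxlib

_MINIMUM_JAXLIB_VERSION = (0, 1, 9)  # hypothetical minimum

def check_jaxlib_version():
    version_str = getattr(jaxlib, '__version__', None)
    if version_str is None:
        # jaxlib releases that predate the __version__ field are too old.
        raise RuntimeError('jaxlib is too old; please upgrade jaxlib')
    version = tuple(int(part) for part in version_str.split('.'))
    if version < _MINIMUM_JAXLIB_VERSION:
        raise RuntimeError('jaxlib %s is incompatible with this version of '
                           'jax; please upgrade jaxlib' % version_str)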
hawkinsp mentioned this issue Feb 26, 2019
hawkinsp (Collaborator) commented

This is now fixed, but to get the fix you need either to rebuild jaxlib from source or to wait until we push new binary wheels to PyPI (probably later this week).
