Not compatible with tensorflow-macos #219

Open
akbir opened this issue Dec 21, 2022 · 2 comments

@akbir
akbir commented Dec 21, 2022

Importing distrax fails with the following error (the top of the stack trace is elided):

...
    import distrax
  File "/Users/akbir.khan/pax/venv/lib/python3.9/site-packages/distrax/__init__.py", line 18, in <module>
    from distrax._src.bijectors.bijector import Bijector
  File "/Users/akbir.khan/pax/venv/lib/python3.9/site-packages/distrax/_src/bijectors/bijector.py", line 26, in <module>
    tfb = tfp.bijectors
  File "/Users/akbir.khan/pax/venv/lib/python3.9/site-packages/tensorflow_probability/python/internal/lazy_loader.py", line 53, in __getattr__
    module = self._load()
  File "/Users/akbir.khan/pax/venv/lib/python3.9/site-packages/tensorflow_probability/python/internal/lazy_loader.py", line 40, in _load
    module = importlib.import_module(self.__name__)
  File "/Users/akbir.khan/.pyenv/versions/3.9.12/lib/python3.9/importlib/__init__.py", line 127, in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
  File "/Users/akbir.khan/pax/venv/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/__init__.py", line 41, in <module>
    from tensorflow_probability.substrates.jax import bijectors
  File "/Users/akbir.khan/pax/venv/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/bijectors/__init__.py", line 19, in <module>
    from tensorflow_probability.substrates.jax.bijectors.absolute_value import AbsoluteValue
  File "/Users/akbir.khan/pax/venv/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/bijectors/absolute_value.py", line 17, in <module>
    from tensorflow_probability.python.internal.backend.jax.compat import v2 as tf
  File "/Users/akbir.khan/pax/venv/lib/python3.9/site-packages/tensorflow_probability/python/internal/backend/jax/__init__.py", line 18, in <module>
    from tensorflow_probability.python.internal.backend.jax import bitwise
  File "/Users/akbir.khan/pax/venv/lib/python3.9/site-packages/tensorflow_probability/python/internal/backend/jax/bitwise.py", line 19, in <module>
    from tensorflow_probability.python.internal.backend.jax import _utils as utils
  File "/Users/akbir.khan/pax/venv/lib/python3.9/site-packages/tensorflow_probability/python/internal/backend/jax/_utils.py", line 25, in <module>
    from tensorflow.python.ops import array_ops  # pylint: disable=g-direct-tensorflow-import,g-import-not-at-top,unused-import
  File "/Users/akbir.khan/pax/venv/lib/python3.9/site-packages/tensorflow/__init__.py", line 37, in <module>
    from tensorflow.python.tools import module_util as _module_util
  File "/Users/akbir.khan/pax/venv/lib/python3.9/site-packages/tensorflow/python/__init__.py", line 42, in <module>
    from tensorflow.python import data
  File "/Users/akbir.khan/pax/venv/lib/python3.9/site-packages/tensorflow/python/data/__init__.py", line 21, in <module>
    from tensorflow.python.data import experimental
  File "/Users/akbir.khan/pax/venv/lib/python3.9/site-packages/tensorflow/python/data/experimental/__init__.py", line 96, in <module>
    from tensorflow.python.data.experimental import service
  File "/Users/akbir.khan/pax/venv/lib/python3.9/site-packages/tensorflow/python/data/experimental/service/__init__.py", line 419, in <module>
    from tensorflow.python.data.experimental.ops.data_service_ops import distribute
  File "/Users/akbir.khan/pax/venv/lib/python3.9/site-packages/tensorflow/python/data/experimental/ops/data_service_ops.py", line 22, in <module>
    from tensorflow.python.data.experimental.ops import compression_ops
  File "/Users/akbir.khan/pax/venv/lib/python3.9/site-packages/tensorflow/python/data/experimental/ops/compression_ops.py", line 16, in <module>
    from tensorflow.python.data.util import structure
  File "/Users/akbir.khan/pax/venv/lib/python3.9/site-packages/tensorflow/python/data/util/structure.py", line 22, in <module>
    from tensorflow.python.data.util import nest
  File "/Users/akbir.khan/pax/venv/lib/python3.9/site-packages/tensorflow/python/data/util/nest.py", line 34, in <module>
    from tensorflow.python.framework import sparse_tensor as _sparse_tensor
  File "/Users/akbir.khan/pax/venv/lib/python3.9/site-packages/tensorflow/python/framework/sparse_tensor.py", line 24, in <module>
    from tensorflow.python.framework import constant_op
  File "/Users/akbir.khan/pax/venv/lib/python3.9/site-packages/tensorflow/python/framework/constant_op.py", line 25, in <module>
    from tensorflow.python.eager import execute
  File "/Users/akbir.khan/pax/venv/lib/python3.9/site-packages/tensorflow/python/eager/execute.py", line 21, in <module>
    from tensorflow.python.framework import dtypes
  File "/Users/akbir.khan/pax/venv/lib/python3.9/site-packages/tensorflow/python/framework/dtypes.py", line 34, in <module>
    _np_bfloat16 = _pywrap_bfloat16.TF_bfloat16_type()
TypeError: Unable to convert function return value to a Python type! The signature was
        () -> handle
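Reading the traceback from the bottom up, the `TypeError` is raised while `tensorflow` itself is being imported (`tensorflow/python/framework/dtypes.py`). TFP's JAX substrate pulls TensorFlow in from `internal/backend/jax/_utils.py`, and distrax only triggers that chain through `tfp.bijectors`. Below is a minimal diagnostic sketch that reports installed versions and imports each package of the chain on its own, so the first failing one is visible. It assumes the distribution names listed in it (`tensorflow-macos` is tried as a fallback for `tensorflow`). The helper names are hypothetical and not part of distrax or TFP.

```python
# Hypothetical diagnostic helpers (not part of distrax or TFP): report installed
# versions, then import each package of the distrax -> TFP -> TensorFlow chain
# on its own to find the first one that fails.
import importlib
import importlib.metadata


def installed_version(*dist_names):
    """Return the version of the first installed distribution in dist_names, else None."""
    for dist in dist_names:
        try:
            return importlib.metadata.version(dist)
        except importlib.metadata.PackageNotFoundError:
            continue
    return None


def first_failing_import(module_names):
    """Import modules in order; return (name, exception) for the first failure, else None."""
    for name in module_names:
        try:
            importlib.import_module(name)
        except Exception as exc:  # the reports here show TypeError/ValueError, not only ImportError
            return name, exc
    return None


if __name__ == "__main__":
    # Assumed distribution names; tensorflow-macos is tried if tensorflow is absent.
    for dists in (("jax",), ("jaxlib",), ("tensorflow", "tensorflow-macos"),
                  ("tensorflow-probability",), ("distrax",)):
        print(f"{dists[0]:<24}{installed_version(*dists)}")

    # Dependencies before dependents, so the first failure points at the culprit.
    failure = first_failing_import(
        ["jax", "tensorflow", "tensorflow_probability.substrates.jax", "distrax"]
    )
    if failure is None:
        print("all imports succeeded")
    else:
        name, exc = failure
        print(f"first failing import: {name} ({type(exc).__name__}: {exc})")
```

Run it as a script in the affected environment. If `tensorflow` already fails on its own, that would suggest the problem lies in the TensorFlow build (here `tensorflow-macos`) rather than in distrax.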

@thomaspinder

We're also facing this issue. Is there any planned fix/solution for this?

@chriscarmona

I'm experiencing a similar incompatibility between tensorflow and distrax.
In a clean environment, I installed jax, distrax and tensorflow:

pip install -U pip
pip install -U "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
pip install distrax tensorflow

When I then import distrax, I get this error:

>>> import distrax
2023-10-01 20:04:19.026058: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2023-10-01 20:04:19.026139: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2023-10-01 20:04:19.026198: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2023-10-01 20:04:20.099045: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/ubuntu/.local/lib/python3.10/site-packages/distrax/__init__.py", line 18, in <module>
    from distrax._src.bijectors.bijector import Bijector
  File "/home/ubuntu/.local/lib/python3.10/site-packages/distrax/_src/bijectors/bijector.py", line 27, in <module>
    tfb = tfp.bijectors
  File "/home/ubuntu/.local/lib/python3.10/site-packages/tensorflow_probability/python/internal/lazy_loader.py", line 53, in __getattr__
    module = self._load()
  File "/home/ubuntu/.local/lib/python3.10/site-packages/tensorflow_probability/python/internal/lazy_loader.py", line 40, in _load
    module = importlib.import_module(self.__name__)
  File "/usr/lib/python3.10/importlib/__init__.py", line 126, in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
  File "/home/ubuntu/.local/lib/python3.10/site-packages/tensorflow_probability/substrates/jax/__init__.py", line 41, in <module>
    from tensorflow_probability.substrates.jax import bijectors
  File "/home/ubuntu/.local/lib/python3.10/site-packages/tensorflow_probability/substrates/jax/bijectors/__init__.py", line 19, in <module>
    from tensorflow_probability.substrates.jax.bijectors.absolute_value import AbsoluteValue
  File "/home/ubuntu/.local/lib/python3.10/site-packages/tensorflow_probability/substrates/jax/bijectors/absolute_value.py", line 19, in <module>
    from tensorflow_probability.substrates.jax.bijectors import bijector
  File "/home/ubuntu/.local/lib/python3.10/site-packages/tensorflow_probability/substrates/jax/bijectors/bijector.py", line 26, in <module>
    from tensorflow_probability.substrates.jax.internal import batch_shape_lib
  File "/home/ubuntu/.local/lib/python3.10/site-packages/tensorflow_probability/substrates/jax/internal/batch_shape_lib.py", line 23, in <module>
    from tensorflow_probability.substrates.jax.internal import prefer_static as ps
  File "/home/ubuntu/.local/lib/python3.10/site-packages/tensorflow_probability/substrates/jax/internal/prefer_static.py", line 361, in <module>
    ones_like = _copy_docstring(tf.ones_like, _ones_like)
  File "/home/ubuntu/.local/lib/python3.10/site-packages/tensorflow_probability/substrates/jax/internal/prefer_static.py", line 84, in _copy_docstring
    raise ValueError(
ValueError: Arg specs do not match: original=FullArgSpec(args=['input', 'dtype', 'name', 'layout'], varargs=None, varkw=None, defaults=(None, None, None), kwonlyargs=[], kwonlydefaults=None, annotations={}), new=FullArgSpec(args=['input', 'dtype', 'name'], varargs=None, varkw=None, defaults=(None, None), kwonlyargs=[], kwonlydefaults=None, annotations={}), fn=<function ones_like_v2 at 0x7fb6b25dd360>
