System information
Flax, jax, jaxlib versions (obtained with pip show flax jax jaxlib): all are at their latest versions, as is orbax.
Name: flax
Version: 0.6.9
Summary: Flax: A neural network library for JAX designed for flexibility
Home-page:
Author:
Author-email: Flax team flax-dev@google.com
License:
Location: /home/fernanda/.local/lib/python3.8/site-packages
Requires: jax, msgpack, numpy, optax, orbax-checkpoint, PyYAML, rich, tensorstore, typing-extensions
Required-by:
Name: jax
Version: 0.4.8
Summary: Differentiate, compile, and transform Numpy code.
Home-page: https://github.com/google/jax
Author: JAX team
Author-email: jax-dev@google.com
License: Apache-2.0
Location: /home/fernanda/.local/lib/python3.8/site-packages
Requires: ml-dtypes, numpy, opt-einsum, scipy
Required-by: chex, diffrax, equinox, flax, optax, orbax, orbax-checkpoint, richmol
Name: jaxlib
Version: 0.4.7
Summary: XLA library for JAX
Home-page: https://github.com/google/jax
Author: JAX team
Author-email: jax-dev@google.com
License: Apache-2.0
Location: /home/fernanda/.local/lib/python3.8/site-packages
Requires: ml-dtypes, numpy, scipy
Required-by: chex, optax, orbax, orbax-checkpoint
Name: orbax
Version: 0.1.7
Summary: Orbax
Home-page:
Author:
Author-email: Orbax Authors orbax-dev@google.com
License:
Location: /home/fernanda/.local/lib/python3.8/site-packages
Requires: absl-py, cached_property, etils, importlib_resources, jax, jaxlib, msgpack, nest_asyncio, numpy, pyyaml, tensorstore, typing_extensions
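The Location reported by pip show above (/home/fernanda/.local/lib/python3.8/site-packages) is not the directory that appears in the traceback below (/gpfs/cfel/group/cmi/common/psi4/psi4conda/lib//python3.8/site-packages). pip show describes the installation seen by whichever interpreter pip belongs to, which is not necessarily the one running the failing code. A minimal sketch for checking from inside the same Python session which version and which file each package resolves to; the helper name `describe` is hypothetical, and it uses only the standard library (`importlib.metadata`, available since Python 3.8):

```python
import importlib.util
from importlib import metadata  # standard library since Python 3.8


def describe(dist_name, module_name=None):
    """Return (version from distribution metadata, file Python would import).

    Either element is None when nothing is found.
    """
    module_name = module_name or dist_name
    try:
        version = metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        version = None
    # For a top-level name, find_spec locates the module without executing it.
    spec = importlib.util.find_spec(module_name)
    # origin can be None (e.g. for namespace packages).
    origin = spec.origin if spec is not None else None
    return version, origin


if __name__ == "__main__":
    for name in ("flax", "jax", "jaxlib", "orbax"):
        version, origin = describe(name)
        print(f"{name:8} {version}  {origin}")
```

Running this in the same notebook or script that raises the error shows whether flax is loaded from the .local directory or from the psi4conda environment.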
Problem you have encountered:
When importing flax.training.checkpoints, I get the following error:
"""
ModuleNotFoundError Traceback (most recent call last)
in <module>
11 config.update("jax_enable_x64", True)
12 from flax import serialization
---> 13 from flax.training import checkpoints
14 from jax import numpy as jnp
15 import jax
/gpfs/cfel/group/cmi/common/psi4/psi4conda/lib//python3.8/site-packages/flax/training/checkpoints.py in <module>
37 from jax import process_index
38 from jax import sharding
---> 39 from jax.experimental.global_device_array import GlobalDeviceArray
40 from jax.experimental.multihost_utils import sync_global_devices
41 import orbax.checkpoint as orbax
ModuleNotFoundError: No module named 'jax.experimental.global_device_array'
"""
I suspect this is a version incompatibility between jax and flax.
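The statement that fails is from jax.experimental.global_device_array import GlobalDeviceArray inside flax/training/checkpoints.py, and the error says the installed jax has no such module. A small sketch for checking that import on its own, without going through flax; the helper name `module_available` is hypothetical. Note that find_spec on a dotted name imports the parent packages (here jax and jax.experimental), so it also fails if jax itself cannot be imported:

```python
import importlib.util


def module_available(name: str) -> bool:
    """Return True if `name` can be located by the import system."""
    try:
        return importlib.util.find_spec(name) is not None
    except ModuleNotFoundError:
        # Raised when a parent package (e.g. "jax") is itself missing.
        return False


if __name__ == "__main__":
    print(module_available("jax.experimental.global_device_array"))
```

If this prints False, the installed jax no longer provides the module, so the flax copy that imports it expects an older jax.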
What you expected to happen:
The import succeeds as usual.