Unable to import checkpoints #3075

Closed
alfercorral opened this issue May 4, 2023 · 3 comments · Fixed by #3089
Labels
Priority: P2 - no schedule Best effort response and resolution. We have no plan to work on this at the moment.

Comments

@alfercorral


System information

  • Flax, jax, jaxlib versions (obtain with pip show flax jax jaxlib): all at their latest versions, as is orbax

Name: flax
Version: 0.6.9
Summary: Flax: A neural network library for JAX designed for flexibility
Home-page:
Author:
Author-email: Flax team flax-dev@google.com
License:
Location: /home/fernanda/.local/lib/python3.8/site-packages
Requires: jax, msgpack, numpy, optax, orbax-checkpoint, PyYAML, rich, tensorstore, typing-extensions
Required-by:

Name: jax
Version: 0.4.8
Summary: Differentiate, compile, and transform Numpy code.
Home-page: https://github.com/google/jax
Author: JAX team
Author-email: jax-dev@google.com
License: Apache-2.0
Location: /home/fernanda/.local/lib/python3.8/site-packages
Requires: ml-dtypes, numpy, opt-einsum, scipy
Required-by: chex, diffrax, equinox, flax, optax, orbax, orbax-checkpoint, richmol

Name: jaxlib
Version: 0.4.7
Summary: XLA library for JAX
Home-page: https://github.com/google/jax
Author: JAX team
Author-email: jax-dev@google.com
License: Apache-2.0
Location: /home/fernanda/.local/lib/python3.8/site-packages
Requires: ml-dtypes, numpy, scipy
Required-by: chex, optax, orbax, orbax-checkpoint

Name: orbax
Version: 0.1.7
Summary: Orbax
Home-page:
Author:
Author-email: Orbax Authors orbax-dev@google.com
License:
Location: /home/fernanda/.local/lib/python3.8/site-packages
Requires: absl-py, cached_property, etils, importlib_resources, jax, jaxlib, msgpack, nest_asyncio, numpy, pyyaml, tensorstore, typing_extensions

  • Python version: 3.8
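pip show reports what the pip executable on the PATH sees, which is not necessarily the environment of the Python process that runs the failing import. A minimal sketch that asks the running interpreter itself, using only the standard library's importlib.metadata (available from Python 3.8); the helper name installed_versions is hypothetical and not part of Flax or JAX:

```python
import importlib.metadata
import sys


def installed_versions(packages):
    """Return {package: version or None} as seen by *this* interpreter.

    Hypothetical helper: it only reads installed package metadata and never
    imports the packages themselves, so a broken install cannot crash it.
    """
    versions = {}
    for name in packages:
        try:
            versions[name] = importlib.metadata.version(name)
        except importlib.metadata.PackageNotFoundError:
            versions[name] = None  # not installed for this interpreter
    return versions


if __name__ == "__main__":
    # Which Python binary is running this code (compare with `which pip`).
    print("interpreter:", sys.executable)
    for name, version in installed_versions(
        ["flax", "jax", "jaxlib", "orbax", "orbax-checkpoint"]
    ).items():
        print(f"{name}: {version or 'not installed'}")
```

Running this from the same notebook or script that raises the error shows the versions that actually matter for that import.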

Problem you have encountered:

When importing flax.training.checkpoints, I get the following error:
"""

ModuleNotFoundError Traceback (most recent call last)
in
11 config.update("jax_enable_x64", True)
12 from flax import serialization
---> 13 from flax.training import checkpoints
14 from jax import numpy as jnp
15 import jax

/gpfs/cfel/group/cmi/common/psi4/psi4conda/lib//python3.8/site-packages/flax/training/checkpoints.py in
37 from jax import process_index
38 from jax import sharding
---> 39 from jax.experimental.global_device_array import GlobalDeviceArray
40 from jax.experimental.multihost_utils import sync_global_devices
41 import orbax.checkpoint as orbax

ModuleNotFoundError: No module named 'jax.experimental.global_device_array'

"""

I suspect it is a compatibility problem between jax and flax.

What you expected to happen:

The import succeeds as usual.

@cgarciae cgarciae added the Priority: P2 - no schedule Best effort response and resolution. We have no plan to work on this at the moment. label May 4, 2023
@cgarciae
Collaborator

cgarciae commented May 4, 2023

Can you try installing Flax from main and see if the issue is fixed?

pip install git+https://github.com/google/flax.git

We might need to create a new release.

@alfercorral
Author

@cgarciae Unfortunately, it does not make any difference.

@cgarciae
Collaborator

cgarciae commented May 4, 2023

Maybe try completely uninstalling flax first. The line in your error no longer exists on main:

from flax import traverse_util
from flax.training import orbax_utils
import jax
from jax import monitoring
from jax import process_index

@cgarciae cgarciae mentioned this issue May 8, 2023