
AttributeError: module 'jax.random' has no attribute 'KeyArray' while fine tuning. #221

Closed
samyakai opened this issue Apr 20, 2022 · 15 comments


@samyakai

I am following your repo to fine-tune GPT-J on a TPU. When I run "python3 device_train.py --config=YOUR_CONFIG.json --tune-model-path=gs://YOUR-BUCKET/step_383500/" with my bucket name and the config file I created, I get the error "AttributeError: module 'jax.random' has no attribute 'KeyArray'". These are some of the specs:

OS: Ubuntu 20.04
jax version: 0.2.12
TPU: v3-8
Zone: us-central1-b

The error is triggered by line 7 of device_train.py, which imports optax ("import optax").

This is the error stack:

WARNING: Logging before InitGoogle() is written to STDERR
I0420 11:47:44.856002 10240 tpu_initializer_helper.cc:94] libtpu.so already in use by another process. Run "$ sudo lsof -w /dev/accel0" to figure out which process is using the TPU. Not attempting to load libtpu.so in this process.
Traceback (most recent call last):
  File "device_train.py", line 7, in <module>
    import optax
  File "/usr/local/lib/python3.8/dist-packages/optax/__init__.py", line 17, in <module>
    from optax._src.alias import adabelief
  File "/usr/local/lib/python3.8/dist-packages/optax/_src/alias.py", line 21, in <module>
    from optax._src import base
  File "/usr/local/lib/python3.8/dist-packages/optax/_src/base.py", line 18, in <module>
    import chex
  File "/usr/local/lib/python3.8/dist-packages/chex/__init__.py", line 17, in <module>
    from chex._src.asserts import assert_axis_dimension
  File "/usr/local/lib/python3.8/dist-packages/chex/_src/asserts.py", line 26, in <module>
    from chex._src import asserts_internal as _ai
  File "/usr/local/lib/python3.8/dist-packages/chex/_src/asserts_internal.py", line 32, in <module>
    from chex._src import pytypes
  File "/usr/local/lib/python3.8/dist-packages/chex/_src/pytypes.py", line 36, in <module>
    PRNGKey = jax.random.KeyArray
AttributeError: module 'jax.random' has no attribute 'KeyArray'

Any help is appreciated!

@mosmos6

mosmos6 commented Apr 20, 2022

I've just encountered exactly the same error and I was about to open an issue about this.

@KD1903

KD1903 commented Apr 20, 2022

I am facing the same error. Kindly solve this!

@jagruti-samyak

I am facing the same error on "import optax".

@abdelatifsd

Exact same issue here!

@hxiaoyang

Same issue!

@vfbd
Contributor

vfbd commented Apr 20, 2022

Chex 0.1.3 doesn't support JAX 0.2.12. You need to downgrade to Chex 0.1.2:

pip3 install chex==0.1.2

@mosmos6

mosmos6 commented Apr 21, 2022

@vfbd That worked for me for inference, but apparently not for fine-tuning.

@samyakai
Author

samyakai commented Apr 21, 2022

@mosmos6 @vfbd Now it is giving me this error: "AttributeError: module 'jaxlib.xla_extension' has no attribute 'CpuDevice'".
@mosmos6 How did it work for you? Are you training on TPU v3-8?

@mosmos6

mosmos6 commented Apr 21, 2022

@samyakai I've now noticed that you hit this error while fine-tuning; I hit the same error during inference. Sorry for the confusion; I've edited my previous comment. The issue hasn't been resolved for fine-tuning.

@samyakai changed the title from "AttributeError: module 'jax.random' has no attribute 'KeyArray'" to "AttributeError: module 'jax.random' has no attribute 'KeyArray' while fine tuning." on Apr 21, 2022
@samyakai
Author

As @vfbd suggested, I downgraded chex to 0.1.2, but then I get "AttributeError: module 'jaxlib.xla_extension' has no attribute 'CpuDevice'". To fix that, https://github.com/google/brax/issues/187 suggests upgrading to the latest version, but if I do that I get "AttributeError: module 'jax.random' has no attribute 'KeyArray'" again.
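The two tracebacks in this thread show chex 0.1.3 reading `jax.random.KeyArray` and chex 0.1.2 reading `jax.lib.xla_extension.CpuDevice`, so each chex version expects an attribute that only some jax/jaxlib combinations provide. Before choosing a chex version, it may help to check which of those two attributes the installed jax/jaxlib pair actually exposes. This is a minimal sketch, not part of this repo; the helper names `resolve` and `is_available` are hypothetical.

```python
import importlib
from typing import Any


def resolve(dotted: str) -> Any:
    """Resolve a dotted path such as 'jax.random.KeyArray' to an object.

    Imports the longest importable module prefix, then walks the remaining
    names with getattr. Raises ImportError or AttributeError on failure.
    """
    parts = dotted.split(".")
    for i in range(len(parts), 0, -1):
        try:
            obj = importlib.import_module(".".join(parts[:i]))
        except ImportError:
            continue  # this prefix is not an importable module; try a shorter one
        break
    else:
        raise ImportError(f"no importable module prefix in {dotted!r}")
    for name in parts[i:]:
        obj = getattr(obj, name)
    return obj


def is_available(dotted: str) -> bool:
    """Return True if the dotted path resolves without error."""
    try:
        resolve(dotted)
    except (ImportError, AttributeError):
        return False
    return True


if __name__ == "__main__":
    # Attributes that the chex tracebacks in this thread depend on.
    for path in ("jax.random.KeyArray", "jax.lib.xla_extension.CpuDevice"):
        print(f"{path}: {'present' if is_available(path) else 'MISSING'}")
```

If both report MISSING, no chex release in this thread will import cleanly against that jax/jaxlib pair, which matches the back-and-forth described above.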

@jagruti-samyak

I am following your repo to fine-tune GPT-J on a TPU. When I run "python3 device_train.py --config=YOUR_CONFIG.json --tune-model-path=gs://YOUR-BUCKET/step_383500/" with my bucket name and the config file I created, I get the error "AttributeError: module 'jaxlib.xla_extension' has no attribute 'CpuDevice'". These are some of the specs:

OS: Ubuntu 20.04
jax version: 0.2.12
chex version: 0.1.2
TPU: v3-8
Zone: us-central1-b

The error is triggered by line 7 of device_train.py, which imports optax ("import optax").

This is the error stack:

WARNING: Logging before InitGoogle() is written to STDERR
I0421 10:06:19.047791 8679 tpu_initializer_helper.cc:94] libtpu.so already in use by another process. Run "$ sudo lsof -w /dev/accel0" to figure out which process is using the TPU. Not attempting to load libtpu.so in this process.
Traceback (most recent call last):
  File "device_train.py", line 7, in <module>
    import optax
  File "/usr/local/lib/python3.8/dist-packages/optax/__init__.py", line 17, in <module>
    from optax import experimental
  File "/usr/local/lib/python3.8/dist-packages/optax/experimental/__init__.py", line 20, in <module>
    from optax._src.experimental.complex_valued import split_real_and_imaginary
  File "/usr/local/lib/python3.8/dist-packages/optax/_src/experimental/complex_valued.py", line 32, in <module>
    import chex
  File "/usr/local/lib/python3.8/dist-packages/chex/__init__.py", line 17, in <module>
    from chex._src.asserts import assert_axis_dimension
  File "/usr/local/lib/python3.8/dist-packages/chex/_src/asserts.py", line 26, in <module>
    from chex._src import asserts_internal as _ai
  File "/usr/local/lib/python3.8/dist-packages/chex/_src/asserts_internal.py", line 32, in <module>
    from chex._src import pytypes
  File "/usr/local/lib/python3.8/dist-packages/chex/_src/pytypes.py", line 40, in <module>
    CpuDevice = jax.lib.xla_extension.CpuDevice
AttributeError: module 'jaxlib.xla_extension' has no attribute 'CpuDevice'

Please help us resolve this as soon as possible.
Thank you

@vfbd
Contributor

vfbd commented Apr 22, 2022

Which version of jaxlib (not jax) do you have? Maybe try again with jaxlib==0.1.68
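To answer which jaxlib (as opposed to jax) is installed, one option is to read the installed distribution metadata instead of importing the packages, which matters here because the import itself is what fails. A minimal sketch, assuming Python 3.8+ for `importlib.metadata`; the helper name `installed_versions` is hypothetical.

```python
from importlib import metadata
from typing import Dict, Iterable, Optional


def installed_versions(packages: Iterable[str]) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if absent.

    Reads package metadata only, so it works even when importing the
    package itself would raise (as "import optax" does in this thread).
    """
    versions: Dict[str, Optional[str]] = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions(["jax", "jaxlib", "chex", "optax"]).items():
        print(f"{name}=={version}" if version else f"{name}: not installed")
```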

@samyakai
Author

samyakai commented Apr 22, 2022

@vfbd These library versions resolved the error:
jax==0.2.16
jaxlib==0.1.68
optax==0.1.2
chex==0.1.2
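One way to keep this working combination from drifting (for example, when a later install upgrades chex again) is to compare the environment against these pins before training starts. A minimal sketch: `KNOWN_GOOD` only restates the four versions listed above, and `pin_mismatches` is a hypothetical helper, not something the repo provides.

```python
from importlib import metadata
from typing import Dict, List, Optional

# Versions reported in this thread as resolving the import errors.
KNOWN_GOOD: Dict[str, str] = {
    "jax": "0.2.16",
    "jaxlib": "0.1.68",
    "optax": "0.1.2",
    "chex": "0.1.2",
}


def pin_mismatches(pins: Dict[str, str]) -> List[str]:
    """Return one human-readable line per package not at its pinned version."""
    problems: List[str] = []
    for name, wanted in pins.items():
        try:
            found: Optional[str] = metadata.version(name)
        except metadata.PackageNotFoundError:
            found = None
        if found != wanted:
            problems.append(f"{name}: want {wanted}, found {found or 'nothing'}")
    return problems


if __name__ == "__main__":
    for line in pin_mismatches(KNOWN_GOOD):
        print(line)
```

An empty result means the environment matches the pins; anything else lists exactly which packages need reinstalling.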

@kufton

kufton commented Jan 21, 2023

I just paid for Colab Pro to play around with this and ran into the same issues described here. I added "!pip install" lines for the library versions mentioned above, and then I got this error:

```
---------------------------------------------------------------------------
ImportError Traceback (most recent call last)
in
4 from jax.experimental import maps
5 import numpy as np
----> 6 import optax
7 import transformers
8

6 frames
/usr/local/lib/python3.8/dist-packages/jax/_src/api.py in
42 from . import dtypes
43 from ..core import eval_jaxpr
---> 44 from ..api_util import (flatten_fun, apply_flat_fun, flatten_fun_nokwargs,
45 flatten_fun_nokwargs2, argnums_partial,
46 argnums_partial_except, flatten_axes, donation_vector,

ImportError: cannot import name '_ensure_str_tuple' from 'jax.api_util' (/usr/local/lib/python3.8/dist-packages/jax/api_util.py)


NOTE: If your import is failing due to a missing package, you can
manually install dependencies using either !pip or !apt.

To view examples of installing some common dependencies, click the
"Open Examples" button below.
---------------------------------------------------------------------------
```

I appreciate that this is still alpha-quality, so I'll go play with GPT-3 in the meantime. Thank you for your work.

@mystiverv

  from jax_md import rigid_body
  File "C:\Users....\env\Lib\site-packages\jax_md\rigid_body.py", line 76, in <module>
    KeyArray = random.KeyArray

module 'jax.random' has no attribute 'KeyArray'

I get this error when running "import jax_md".
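Here the failing line is in jax_md itself rather than in chex. As far as I know, newer JAX releases deprecated and then removed `jax.random.KeyArray`, suggesting `jax.Array` for annotating PRNG keys instead, so a jax_md release that predates that change would fail this way against a newer JAX. If you need a stopgap in your own code while versions get sorted out, a version-tolerant lookup could look like the sketch below. It is an assumption-based illustration (the function name `resolve_key_array_type` is hypothetical), not a patch from either project.

```python
from typing import Any, Optional


def resolve_key_array_type() -> Optional[Any]:
    """Return a type usable for annotating PRNG keys, or None without JAX.

    Prefers jax.random.KeyArray (older JAX) and falls back to jax.Array,
    assumed here to be its replacement in newer JAX releases.
    """
    try:
        import jax.random  # also binds the top-level name `jax`
    except ImportError:
        return None  # JAX is not installed in this environment
    key_array = getattr(jax.random, "KeyArray", None)
    if key_array is None:
        key_array = getattr(jax, "Array", None)
    return key_array


KeyArray = resolve_key_array_type()
```

The cleaner fix is usually to install a jax_md and JAX pair that were released together, since the lookup above only papers over a type annotation.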

9 participants