AttributeError: module 'jax.random' has no attribute 'KeyArray' while fine tuning. #221
Comments
I've just encountered exactly the same error and was about to open an issue about it.
WARNING: Logging before InitGoogle() is written to STDERR
I am facing the same error. Kindly solve this!
I get the same error on "import optax":
WARNING: Logging before InitGoogle() is written to STDERR
Exact same issue here!
Same issue!
Chex 0.1.3 doesn't support JAX 0.2.12. You need to downgrade to Chex 0.1.2.
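A minimal sketch, assuming the only known-bad pairing is the one reported in this thread (chex 0.1.3 with JAX 0.2.12), that checks installed versions with the standard library's importlib.metadata before optax is imported. KNOWN_CONFLICTS, installed_versions and find_conflicts are hypothetical names for illustration; extend the table with any other combinations you confirm.

```python
from importlib import metadata  # stdlib, Python 3.8+

# Conflict reported in this thread (assumption: the only pairing known so far).
# Each entry: (package, bad version, dependency, dependency version, suggested pin)
KNOWN_CONFLICTS = [
    ("chex", "0.1.3", "jax", "0.2.12", "chex==0.1.2"),
]


def installed_versions(dists):
    """Map each distribution name to its installed version, or None if absent."""
    out = {}
    for name in dists:
        try:
            out[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            out[name] = None
    return out


def find_conflicts(versions):
    """Return a warning string for every known-bad combination found in `versions`."""
    warnings = []
    for pkg, bad, dep, dep_ver, pin in KNOWN_CONFLICTS:
        if versions.get(pkg) == bad and versions.get(dep) == dep_ver:
            warnings.append(
                f"{pkg} {bad} does not support {dep} {dep_ver}; try pip install {pin}"
            )
    return warnings


if __name__ == "__main__":
    found = find_conflicts(installed_versions(["jax", "jaxlib", "chex", "optax"]))
    print("\n".join(found) or "no known conflicts")
```

Running this before device_train.py only flags the exact pairing in the table; it does not prove that any other combination works.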
@vfbd That worked for me for inference, but apparently not for fine-tuning.
@samyakai I now see that you hit this error while fine-tuning; I hit the same error during inference. Sorry for the confusion, I've edited my previous comment. The issue is still unresolved for fine-tuning.
As suggested by @vfbd, if I downgrade chex to 0.1.2 I get "AttributeError: module 'jaxlib.xla_extension' has no attribute 'CpuDevice'". To get past that, https://github.com/google/brax/issues/187 suggests upgrading to the latest version, but if I do that I get "AttributeError: module 'jax.random' has no attribute 'KeyArray'" again.
I am following your repo to fine-tune GPT-J on TPU. When I run "python3 device_train.py --config=YOUR_CONFIG.json --tune-model-path=gs://YOUR-BUCKET/step_383500/" with my bucket name and the config file I created, I get the error "AttributeError: module 'jaxlib.xla_extension' has no attribute 'CpuDevice'". Specs: OS: Ubuntu 20.04. The error is raised at line 7 of device_train.py, where optax is imported ("import optax"). The error stack begins with: WARNING: Logging before InitGoogle() is written to STDERR
Please help us resolve this as soon as possible.
Which version of jaxlib (not jax) do you have? Maybe try again with jaxlib==0.1.68.
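The two errors in this thread point at two attributes: jax.random.KeyArray and jaxlib.xla_extension.CpuDevice. Before pinning jax, jaxlib and chex, a minimal diagnostic sketch like the one below (probe and CHECKS are hypothetical names, not part of this repo or of JAX) reports whether each attribute exists in the installed environment, which is also useful to paste into a reply here.

```python
import importlib


def probe(module_name, attr):
    """Return True/False if `module_name` imports and does/doesn't define `attr`,
    or a string describing the failure if the import itself raises."""
    try:
        mod = importlib.import_module(module_name)
    except Exception as exc:  # ImportError, or an error raised while the module loads
        return f"import failed: {type(exc).__name__}: {exc}"
    return hasattr(mod, attr)


# The attributes named in the error messages in this thread.
CHECKS = [
    ("jax.random", "KeyArray"),
    ("jaxlib.xla_extension", "CpuDevice"),
]

if __name__ == "__main__":
    for module_name, attr in CHECKS:
        print(f"{module_name}.{attr}: {probe(module_name, attr)}")
```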
@vfbd These are the library versions that resolved the error.
I just paid for Colab Pro to play around with this and ran into the same issues described here. I added "!pip install" lines for the library versions mentioned above and then got this error: `ImportError: cannot import name '_ensure_str_tuple' from 'jax.api_util' (/usr/local/lib/python3.8/dist-packages/jax/api_util.py)`. I appreciate that this is essentially alpha software, so meanwhile I'll go play with GPT-3. Thank you for your work.
I get "AttributeError: module 'jax.random' has no attribute 'KeyArray'" when trying to run "import jax_md"; it is raised at "from jax_md import rigid_body".
I am following your repo to fine-tune GPT-J on TPU. When I run "python3 device_train.py --config=YOUR_CONFIG.json --tune-model-path=gs://YOUR-BUCKET/step_383500/" with my bucket name and the config file I created, I get the error "AttributeError: module 'jax.random' has no attribute 'KeyArray'". These are some of the specs:
OS: Ubuntu 20.04
JAX version: 0.2.12
TPU: v3-8
Zone: us-central1-b
The error is raised at line 7 of device_train.py, where optax is imported ("import optax").
This is the error stack:
WARNING: Logging before InitGoogle() is written to STDERR
I0420 11:47:44.856002 10240 tpu_initializer_helper.cc:94] libtpu.so already in use by another process. Run "$ sudo lsof -w /dev/accel0" to figure out which process is using the TPU. Not attempting to load libtpu.so in this process.
Traceback (most recent call last):
  File "device_train.py", line 7, in <module>
    import optax
  File "/usr/local/lib/python3.8/dist-packages/optax/__init__.py", line 17, in <module>
    from optax._src.alias import adabelief
  File "/usr/local/lib/python3.8/dist-packages/optax/_src/alias.py", line 21, in <module>
    from optax._src import base
  File "/usr/local/lib/python3.8/dist-packages/optax/_src/base.py", line 18, in <module>
    import chex
  File "/usr/local/lib/python3.8/dist-packages/chex/__init__.py", line 17, in <module>
    from chex._src.asserts import assert_axis_dimension
  File "/usr/local/lib/python3.8/dist-packages/chex/_src/asserts.py", line 26, in <module>
    from chex._src import asserts_internal as _ai
  File "/usr/local/lib/python3.8/dist-packages/chex/_src/asserts_internal.py", line 32, in <module>
    from chex._src import pytypes
  File "/usr/local/lib/python3.8/dist-packages/chex/_src/pytypes.py", line 36, in <module>
    PRNGKey = jax.random.KeyArray
AttributeError: module 'jax.random' has no attribute 'KeyArray'
Any help is appreciated!
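The traceback shows the failure happens while chex is being imported (pytypes.py reads jax.random.KeyArray at module level), so it surfaces as a bare AttributeError on the first "import optax". A minimal sketch, assuming you want a clearer message at that point: import_with_hint and keyarray_hint are hypothetical helpers, not part of optax, chex or this repo, and they only re-word the error; they do not fix the version mismatch.

```python
import importlib


def keyarray_hint(module_name, exc):
    """Return a remediation message if `exc` is the missing-KeyArray error, else None."""
    if isinstance(exc, AttributeError) and "KeyArray" in str(exc):
        return (
            f"importing {module_name} failed: a dependency (e.g. chex) references "
            "jax.random.KeyArray, which the installed jax does not define. "
            "Align the chex and jax versions (this thread reports chex 0.1.2 "
            "working with jax 0.2.12 for inference)."
        )
    return None


def import_with_hint(module_name):
    """Import `module_name`, turning the KeyArray AttributeError into a clearer RuntimeError."""
    try:
        return importlib.import_module(module_name)
    except AttributeError as exc:
        hint = keyarray_hint(module_name, exc)
        if hint is None:
            raise  # unrelated AttributeError: re-raise unchanged
        raise RuntimeError(hint) from exc


# Hypothetical replacement for line 7 of device_train.py:
# optax = import_with_hint("optax")
```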