
Requirements version error #3

Closed
heyitsguay opened this issue Feb 14, 2024 · 11 comments

Comments

@heyitsguay commented Feb 14, 2024

Trying to install on an Ubuntu 22.04 system with pip 24.0 and python 3.11.

Running `pip install -r requirements.txt` yields the error: Could not find a version that satisfies the requirement tensorflow==2.11.0. The minimum version that shows up for me is 2.12.0rc0.

@heyitsguay (Author)

Switching the tensorflow version to 2.12.1 (though I don't yet know whether that version still works) creates another version conflict:

    tensorflow 2.12.1 depends on numpy<=1.24.3 and >=1.22
    chex 0.1.82 depends on numpy>=1.25.0
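The impossibility of this conflict can be checked mechanically with the `packaging` library (the same version-specifier logic pip's resolver uses); the pins below are copied from the error message above:

```python
from packaging.specifiers import SpecifierSet
from packaging.version import Version

tf_pin = SpecifierSet(">=1.22,<=1.24.3")   # tensorflow 2.12.1's numpy requirement
chex_pin = SpecifierSet(">=1.25.0")        # chex 0.1.82's numpy requirement

# No numpy release can satisfy both: tf_pin caps numpy at 1.24.3,
# while chex_pin only starts at 1.25.0.
compatible = [v for v in ("1.22.0", "1.24.3", "1.25.0")
              if Version(v) in tf_pin and Version(v) in chex_pin]
print(compatible)  # -> []
```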

@heyitsguay (Author) commented Feb 14, 2024

Bumping tensorflow to 2.14.1 allows for successful environment setup. Downloading LWM-Chat-32K-Jax and attempting to run `bash scripts/run_vision_chat.sh` with a small test video raises the following error:

ImportError: cannot import name 'linear_util' from 'jax'

This is with flax==0.7.0, jax==0.4.24, and jaxlib==0.4.24.

Some fiddling around reveals that linear_util is available under jax.extend but not under jax. Not sure if this is related to using the different tensorflow version, but I would appreciate some advice on getting this up and running outside a TPU environment! Thank you.
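One way to tolerate the relocation without hand-editing sources is a fallback import helper that tries the new module path first, then the old one. A minimal sketch (the jax paths in the comment are the ones from the error above; the mechanism itself is plain stdlib):

```python
import importlib

def import_first(candidates):
    """Import and return the first module path in `candidates` that resolves."""
    for path in candidates:
        try:
            return importlib.import_module(path)
        except ImportError:
            continue
    raise ImportError(f"none of {candidates} could be imported")

# Newer jax moved linear_util under jax.extend; older releases exposed it at the
# top level, so a version-tolerant lookup would be:
# lu = import_first(["jax.extend.linear_util", "jax.linear_util"])
```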

@heyitsguay (Author) commented Feb 14, 2024

Using Python 3.10 instead of Python 3.11, I'm able to install the requirements as stated in the repo (i.e. tensorflow 2.11.0). However, I still run into the

ImportError: cannot import name 'linear_util' from 'jax'

error when trying to run run_vision_chat.sh

@heyitsguay (Author)

Taking the naive approach of going into each flax source file that has a `from jax import linear_util as lu` line and replacing it with `from jax.extend import linear_util as lu` brings the script to a new error:

AttributeError: module 'jax.random' has no attribute 'KeyArray'

This one does have some hits on Google, such as this and this. It looks to be a versioning issue, though there appear to be differences in the suggested fixes.
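The manual find-and-replace over the flax sources described above can be scripted. A minimal sketch, using the exact import strings from the comment (the target directory, e.g. your site-packages `flax/` folder, is up to you):

```python
import pathlib

OLD = "from jax import linear_util as lu"
NEW = "from jax.extend import linear_util as lu"

def patch_tree(root):
    """Rewrite the moved linear_util import in every .py file under `root`.

    Returns the names of the files that were changed.
    """
    changed = []
    for path in sorted(pathlib.Path(root).rglob("*.py")):
        text = path.read_text()
        if OLD in text:
            path.write_text(text.replace(OLD, NEW))
            changed.append(path.name)
    return changed
```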

Have you verified that LWM works outside a TPU environment? Can you share an example environment and params for run_vision_chat.sh that works on a Linux GPU or multi-GPU machine? Thank you!

@Alpslee commented Feb 14, 2024

> Using Python 3.10 instead of Python 3.11, I'm able to install the requirements as stated in the repo (i.e. tensorflow 2.11.0). However, I still run into the
>
> ImportError: cannot import name 'linear_util' from 'jax'
>
> error when trying to run run_vision_chat.sh

do you use pyenv or conda? I switched to python 3.10, still got error : Could not find a version that satisfies the requirement tensorflow==2.11.0 (from versions: 2.13.0rc0, 2.13.0rc1, 2.13.0rc2, 2.13.0, 2.13.1, 2.14.0rc0, 2.14.0rc1, 2.14.0, 2.14.1, 2.15.0rc0, 2.15.0rc1, 2.15.0)

@jayavanth commented Feb 14, 2024

> do you use pyenv or conda? I switched to python 3.10, still got error : Could not find a version that satisfies the requirement tensorflow==2.11.0 (from versions: 2.13.0rc0, 2.13.0rc1, 2.13.0rc2, 2.13.0, 2.13.1, 2.14.0rc0, 2.14.0rc1, 2.14.0, 2.14.1, 2.15.0rc0, 2.15.0rc1, 2.15.0)

Same issue with Python 3.10.13. Seems to work with Python 3.10.12 on Colab

@heyitsguay (Author)

> do you use pyenv or conda? I switched to python 3.10, still got error : Could not find a version that satisfies the requirement tensorflow==2.11.0 (from versions: 2.13.0rc0, 2.13.0rc1, 2.13.0rc2, 2.13.0, 2.13.1, 2.14.0rc0, 2.14.0rc1, 2.14.0, 2.14.1, 2.15.0rc0, 2.15.0rc1, 2.15.0)

I'm using Python 3.10.6 with neither pyenv nor conda, just installing everything in a venv. I presume each tensorflow release just declares a range of compatible Python versions, which is why older releases disappear from the candidate list on newer Pythons.

Using tensorflow 2.14.1 doesn't seem to have affected the flax/jax errors I encountered down the line, though who knows if it would cause other problems eventually.

@anubhavashok commented Feb 14, 2024

Updating flax, jax, chex and tux to the latest versions worked for me.

pip install flax -U
pip install tux -U
pip install chex -U

When updating jax, make sure to install the GPU-compatible version if you're using a GPU:

pip install -U "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
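After the upgrades, a quick import-resolution check can confirm the environment is complete before launching the script. A sketch (the module list in the comment mirrors the packages above):

```python
import importlib.util

def missing(modules):
    """Return the subset of `modules` that the import system cannot find."""
    return [m for m in modules if importlib.util.find_spec(m) is None]

# e.g. missing(["jax", "jaxlib", "flax", "chex", "tux"]) should be []
# after the upgrades succeed.
```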

@heyitsguay (Author)

> Updating flax, jax, chex and tux to the latest versions worked for me.
>
> pip install flax -U
> pip install tux -U
> pip install chex -U
>
> When updating jax make sure to install the GPU compatible version if you're using GPU
>
> pip install -U "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

This worked for me! Now run_vision_chat.sh runs, though it appears to hang after something completes. I get:

I0214 22:05:40.405170 140030517383168 xla_bridge.py:689] Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: "rocm". Available platform names are: CUDA
I0214 22:05:40.408375 140030517383168 xla_bridge.py:689] Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
2024-02-14 22:05:42.728080: W external/xla/xla/service/gpu/nvptx_compiler.cc:744] The NVIDIA driver's CUDA version is 12.0 which is older than the ptxas CUDA version (12.3.107). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.
100%|██████████| 1/1 [00:05<00:00,  5.89s/it]

There's ~18846 MB of VRAM allocated on each of my 4 old GPUs on the server (P40s), but no activity.

@wilson1yan (Contributor)

I think I ran into a similar hanging issue on GPU before, due to something in the transformers package stalling on a FlaxSampleOutput or something. I didn't get a chance to look that deeply into it at the time, since we were using almost all TPUs anyway, but I'll look into it now.

@wilson1yan (Contributor) commented Feb 14, 2024

The stalling seems to be due to a weird bug with importing torch after decord (link). I've updated the requirements.txt to remove the torch dependency and it seems to run fine on GPU now (tested on an A100, CUDA 12.3). You will need to delete / reinstall your environment, or uninstall torch / torchvision.
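If torch has to stay installed for other reasons, an alternative workaround for import-order hazards like this is to force the safe order explicitly at program start. A sketch of the general pattern (the decord/torch ordering in the comment is taken from the report above and is not tested here):

```python
import importlib

def import_in_order(*names):
    """Import modules strictly left-to-right, guaranteeing earlier ones load first."""
    return [importlib.import_module(n) for n in names]

# For the decord/torch interaction reported above, torch would need to load
# before decord, e.g.:
# torch, decord = import_in_order("torch", "decord")
```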

I also added more detailed installation instructions which worked for me to the README (also shown below):

conda create -n lwm python=3.10
pip install -U "jax[cuda12_pip]==0.4.23" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
pip install -r requirements.txt
