
Installing or updating JAX

Hajime Kawahara edited this page Feb 21, 2022 · 5 revisions

Installation of JAX/jaxlib on GPU

See the official JAX installation instructions: https://github.com/google/jax#installation

To build from source, see https://jax.readthedocs.io/en/latest/developer.html#building-from-source

Some notes

JAX update for CUDA 11.5/cuDNN (Ubuntu 20.04/A100)

Author: Hajime Kawahara, Feb 21, 2022

  • Background: I wanted to use jax.experimental.sparse, which required installing the latest version of JAX.

Some CUDA package dependencies were broken, so I used aptitude instead of apt to install the toolkit. The cuDNN package can be downloaded from the NVIDIA website.

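# Install the CUDA 11.5 toolkit (aptitude can propose resolutions for broken dependencies)
sudo aptitude install cuda-11-5
# Install the cuDNN local-repo package downloaded from the NVIDIA website
sudo dpkg -i cudnn-local-repo-ubuntu2004-8.3.1.22_1.0-1_amd64.deb
# Remove the old JAX, then install the CUDA 11 / cuDNN 8.2 build of jaxlib
pip uninstall jax
pip install "jax[cuda11_cudnn82]" -f https://storage.googleapis.com/jax-releases/jax_releases.html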
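To confirm that the CUDA-enabled jaxlib is actually being picked up, a short check like the following can be run after installation. This is a minimal sketch and not part of the original procedure; it only uses standard JAX calls (jax.default_backend, jax.devices) and reports "cpu" instead of failing when no GPU is visible.

```python
import jax
import jax.numpy as jnp

# Installed JAX version; it should match the release you just installed.
print("jax version:", jax.__version__)

# "gpu" when the CUDA-enabled jaxlib found the GPU; "cpu" usually means a
# CPU-only jaxlib is installed or the CUDA/cuDNN libraries were not found.
print("default backend:", jax.default_backend())
print("devices:", jax.devices())

# Run a small computation on the default device to check that kernels execute.
x = jnp.arange(10.0)
result = float(jnp.dot(x, x))  # sum of squares 0^2 + ... + 9^2
print("sum of squares:", result)
```

If the backend shows "cpu" on a GPU machine, the jaxlib wheel and the installed CUDA/cuDNN versions are the first things to recheck.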

After that, the new sparse-matrix module could be imported:

manbou ~/jax(main)>python
Python 3.8.5 (default, Sep  4 2020, 07:30:14)
[GCC 7.3.0] :: Anaconda, Inc. on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> from jax.experimental.sparse import coo
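As a slightly fuller smoke test than the import above, the sketch below builds a small sparse matrix and multiplies it by a dense vector. It is an illustration rather than the author's test: it uses the BCOO class (sparse.BCOO) instead of the coo module imported above, and the matrix and vector values are arbitrary.

```python
import jax.numpy as jnp
from jax.experimental import sparse

# A small, mostly-zero dense matrix (illustrative values).
M = jnp.array([[0.0, 1.0, 0.0, 2.0],
               [0.0, 0.0, 0.0, 0.0],
               [3.0, 0.0, 0.0, 0.0]])

# Convert to BCOO (batched COO) format; only the nonzero entries are stored.
M_sp = sparse.BCOO.fromdense(M)

v = jnp.array([1.0, 2.0, 3.0, 4.0])

# Sparse-dense matrix-vector product.
y = M_sp @ v

print("stored elements:", M_sp.nse)
print("M @ v:", y)
```

With these values the matrix has 3 stored elements and the product should equal [10, 0, 3].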

All unit tests on the develop branch passed.