
Problem linking jax with existing CUDA, CUDNN installations #3503

Closed
gshartnett opened this issue Jun 22, 2020 · 3 comments

@gshartnett

I am having trouble getting jax to recognize my existing CUDA and CUDNN installations. I installed jax using

PYTHON_VERSION=cp38
CUDA_VERSION=cuda101
PLATFORM=linux_x86_64
BASE_URL='https://storage.googleapis.com/jax-releases'
pip install --upgrade $BASE_URL/$CUDA_VERSION/jaxlib-0.1.48-$PYTHON_VERSION-none-$PLATFORM.whl
pip install --upgrade jax
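(For anyone cross-checking the wheel tags above: the cp38 tag in the URL must match the target interpreter. A quick sanity check, assuming python3 is the interpreter you'll run jax under:)

```shell
# Print the CPython ABI tag (e.g. cp38) to compare against the jaxlib wheel name.
python3 -c 'import sys; print("cp{}{}".format(*sys.version_info[:2]))'
```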

When I try to import jax, I receive the following error message:

>>> import jax
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/gavin/anaconda3/envs/nvp_flow/lib/python3.8/site-packages/jax/__init__.py", line 17, in <module>
    from .api import (
  File "/home/gavin/anaconda3/envs/nvp_flow/lib/python3.8/site-packages/jax/api.py", line 38, in <module>
    from . import core
  File "/home/gavin/anaconda3/envs/nvp_flow/lib/python3.8/site-packages/jax/core.py", line 31, in <module>
    from . import dtypes
  File "/home/gavin/anaconda3/envs/nvp_flow/lib/python3.8/site-packages/jax/dtypes.py", line 31, in <module>
    from .lib import xla_client
  File "/home/gavin/anaconda3/envs/nvp_flow/lib/python3.8/site-packages/jax/lib/__init__.py", line 52, in <module>
    from jaxlib import xla_client
  File "/home/gavin/anaconda3/envs/nvp_flow/lib/python3.8/site-packages/jaxlib/xla_client.py", line 39, in <module>
    from . import xla_extension as _xla
ImportError: libcudnn.so.7: cannot open shared object file: No such file or directory

I am running Ubuntu 20.04, my CUDA version is 10.1, and my CUDA directory is /usr/lib/cuda. The file libcudnn.so.7 is located at /usr/lib/cuda/lib64/libcudnn.so.7. As suggested in the installation instructions here and in this issue, I have tried sym-linking the path where Jax expects to find the CUDA installation to the actual path; for example, I tried both of the following:

sudo ln -s /usr/lib/cuda /usr/local/cuda-10.1
sudo ln -s /usr/lib/cuda/lib64 /usr/local/cuda-10.1

but neither worked. I also tried setting the path within my jupyter notebook session, for example by setting

import os
os.environ["XLA_FLAGS"]="--xla_gpu_cuda_data_dir=/usr/lib/cuda/"
os.environ["CUDA_HOME"]="/usr/lib/cuda"

but this didn't work either. Another potentially useful piece of information is that my CUDA driver version differs from my CUDA runtime version. The driver version is 10.2, as confirmed by nvidia-smi, which returns

+-----------------------------------------------------------------------------+
| NVIDIA-SMI 440.64       Driver Version: 440.64       CUDA Version: 10.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|===============================+======================+======================|
|   0  GeForce GTX 108...  Off  | 00000000:08:00.0  On |                  N/A |
| 31%   46C    P0    58W / 250W |    444MiB / 11175MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID   Type   Process name                             Usage      |
|=============================================================================|
|    0      1055      G   /usr/lib/xorg/Xorg                            35MiB |
|    0      1595      G   /usr/lib/xorg/Xorg                           124MiB |
|    0      1792      G   /usr/bin/gnome-shell                         106MiB |
|    0      2341      G   ...AAAAAAAAAAAACAAAAAAAAAA= --shared-files   159MiB |
|    0      3870      G   /usr/bin/nvidia-settings                       3MiB |
+-----------------------------------------------------------------------------+

In contrast, nvcc -V returns

nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2019 NVIDIA Corporation
Built on Sun_Jul_28_19:07:16_PDT_2019
Cuda compilation tools, release 10.1, V10.1.243
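(A direct way to check whether the dynamic loader can resolve cuDNN at all, independent of jax; this is just a diagnostic sketch, not something jax itself does:)

```python
import ctypes

def can_load(name):
    """Return True if the dynamic loader can resolve the shared library."""
    try:
        ctypes.CDLL(name)
        return True
    except OSError:
        return False

# Prints False until the loader's search path (e.g. LD_LIBRARY_PATH or
# ldconfig cache) includes the directory containing libcudnn.so.7.
print(can_load("libcudnn.so.7"))
```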
@hawkinsp
Collaborator

Can you try pointing the LD_LIBRARY_PATH environment variable to the directory containing libcudnn.so.7 ?
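For example (using the path from this report; adjust it to wherever libcudnn.so.7 actually lives, and export it before starting Python so the loader sees it at process start):

```shell
# Prepend the cuDNN directory to the loader search path, then launch Python.
export LD_LIBRARY_PATH=/usr/lib/cuda/lib64:$LD_LIBRARY_PATH
python -c "import jax; print(jax.__version__)"
```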

@mattjj mattjj added the build label Jun 23, 2020
@gshartnett
Author

After doing this, the GPU is apparently found by Jax, but there is still an error:

import os
os.environ["LD_LIBRARY_PATH"]="/usr/lib/cuda/lib64/"

import jax
import jax.numpy as np
print("jax version {}".format(jax.__version__))
from jax.lib import xla_bridge
print("jax backend {}".format(xla_bridge.get_backend().platform))

from jax import random
key = random.PRNGKey(0)
x = random.normal(key, (5,5))
print(x)

produces the error message

jax version 0.1.70
jax backend gpu
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-1-229ce8972f18> in <module>
     10 from jax import random
     11 key = random.PRNGKey(0)
---> 12 x = random.normal(key, (5,5))
     13 print(x)

~/anaconda3/envs/nvp_flow/lib/python3.8/site-packages/jax/random.py in normal(key, shape, dtype)
    551   dtype = dtypes.canonicalize_dtype(dtype)
    552   shape = abstract_arrays.canonicalize_shape(shape)
--> 553   return _normal(key, shape, dtype)
    554 
    555 @partial(jit, static_argnums=(1, 2))

~/anaconda3/envs/nvp_flow/lib/python3.8/site-packages/jax/api.py in f_jitted(*args, **kwargs)
    164     for arg in args_flat: _check_arg(arg)
    165     flat_fun, out_tree = flatten_fun(f, in_tree)
--> 166     out = xla.xla_call(flat_fun, *args_flat, device=device, backend=backend,
    167                        name=flat_fun.__name__, donated_invars=donated_invars)
    168     return tree_unflatten(out_tree(), out)

~/anaconda3/envs/nvp_flow/lib/python3.8/site-packages/jax/core.py in _call_bind(processor, post_processor, primitive, f, *args, **params)
   1083   if top_trace is None:
   1084     with new_sublevel():
-> 1085       outs = primitive.impl(f, *args, **params)
   1086   else:
   1087     tracers = map(top_trace.full_raise, args)

~/anaconda3/envs/nvp_flow/lib/python3.8/site-packages/jax/interpreters/xla.py in _xla_call_impl(fun, device, backend, name, donated_invars, *args)
    532 
    533 def _xla_call_impl(fun: lu.WrappedFun, *args, device, backend, name, donated_invars):
--> 534   compiled_fun = _xla_callable(fun, device, backend, name, donated_invars, *map(arg_spec, args))
    535   try:
    536     return compiled_fun(*args)

~/anaconda3/envs/nvp_flow/lib/python3.8/site-packages/jax/linear_util.py in memoized_fun(fun, *args)
    219       fun.populate_stores(stores)
    220     else:
--> 221       ans = call(fun, *args)
    222       cache[key] = (ans, fun.stores)
    223     return ans

~/anaconda3/envs/nvp_flow/lib/python3.8/site-packages/jax/interpreters/xla.py in _xla_callable(fun, device, backend, name, donated_invars, *arg_specs)
    644       device_assignment=(device.id,) if device else None)
    645   options.parameter_is_tupled_arguments = tuple_args
--> 646   compiled = backend.compile(built, compile_options=options)
    647   if nreps == 1:
    648     return partial(_execute_compiled, compiled, uses_outfeed, result_handlers)

RuntimeError: Internal: libdevice not found at ./libdevice.10.bc
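(Side note: assigning to os.environ["LD_LIBRARY_PATH"] inside an already-running interpreter usually has no effect, because the dynamic loader reads that variable once at process startup. A sketch of a workaround is to set it in the environment of a fresh child process instead:)

```python
import os
import subprocess
import sys

# The dynamic loader caches LD_LIBRARY_PATH at process start, so pass it to a
# child process rather than mutating os.environ after startup.
env = dict(os.environ, LD_LIBRARY_PATH="/usr/lib/cuda/lib64")
out = subprocess.run(
    [sys.executable, "-c", "import os; print(os.environ['LD_LIBRARY_PATH'])"],
    env=env, capture_output=True, text=True,
)
print(out.stdout.strip())  # the child process sees /usr/lib/cuda/lib64
```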

I then tried the suggestions from this issue thread, where the same error message occurred, and I was able to get it to work. Specifically, I ran

sudo ln -s /usr/lib/cuda /usr/local/cuda-10.1

and now Jax loads just fine!
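(For the record: the libdevice error suggests XLA was searching its default CUDA data dir, which the symlink to /usr/local/cuda-10.1 now satisfies. Before deciding where to symlink, it can help to locate the file first; the root path below is just this machine's layout:)

```shell
# Locate libdevice under a candidate CUDA root; adjust the root to your install.
find /usr/lib/cuda -name 'libdevice*.bc' 2>/dev/null
```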

@surak

surak commented Oct 30, 2020

That doesn't help in data centers with multiple CUDA versions in non-standard locations. How can one actually add the CUDA path? The python build/build.py --enable_cuda approach doesn't work anymore.
