From f3154ef4f9e7c45a06285116d7fec6b235a7af0f Mon Sep 17 00:00:00 2001 From: Skye Wanderman-Milne Date: Wed, 27 May 2020 17:05:36 -0700 Subject: [PATCH] Update installation directions in README to mention expected CUDA location. (#3190) See https://github.com/google/jax/issues/989 --- README.md | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index e15b3aa39102..3066a0b36b4f 100644 --- a/README.md +++ b/README.md @@ -429,10 +429,24 @@ nvcc --version grep CUDNN_MAJOR -A 2 /usr/local/cuda/include/cudnn.h # might need different path ``` +Note that some GPU functionality expects the CUDA installation to be at +`/usr/local/cuda-X.X`, where X.X should be replaced with the CUDA version number +(e.g. `cuda-10.2`). If CUDA is installed elsewhere on your system, you can either +create a symlink: + +```bash +sudo ln -s /path/to/cuda /usr/local/cuda-X.X +``` + +Or set the following environment variable before importing JAX: + +```bash +XLA_FLAGS=--xla_gpu_cuda_data_dir=/path/to/cuda +``` + The Python version must match your Python interpreter. There are prebuilt wheels -for Python 3.6, 3.7, and 3.8; for anything else, you must build from -source. Jax requires Python 3.6 or above. Jax does not support Python 2 any -more. +for Python 3.6, 3.7, and 3.8; for anything else, you must build from source. Jax +requires Python 3.6 or above. Jax does not support Python 2 any more. To try automatic detection of the correct version for your system, you can run: