Conda Environment:
- Create an environment with `conda create -n serl python=3.10`
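Inside the new environment (after `conda activate serl`), a quick sanity check can confirm that the interpreter matches the requested Python version. This is a minimal sketch; `python_version_ok` and `REQUIRED` are hypothetical names, not part of serl.

```python
import sys

# (major, minor) version requested by `conda create -n serl python=3.10`.
REQUIRED = (3, 10)


def python_version_ok(version_info=sys.version_info, required=REQUIRED):
    """Return True if the interpreter's (major, minor) version equals `required`."""
    return tuple(version_info[:2]) == tuple(required)


if __name__ == "__main__":
    current = ".".join(str(part) for part in sys.version_info[:3])
    if python_version_ok():
        print(f"Python {current} matches the serl environment")
    else:
        print(f"Expected Python {REQUIRED[0]}.{REQUIRED[1]}, got {current}")
```

Comparing only `(major, minor)` means any 3.10.x patch release is accepted.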
Recommended:
This assumes the machine has the latest NVIDIA drivers and CUDA 12.1 or 11.x installed.
Run the following, using only the JAX command that matches your CUDA version:

    pip install --upgrade pip
    pip install -e .

    # CUDA 12 installation
    # Note: wheels only available on Linux.
    pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

    # CUDA 11 installation
    # Note: wheels only available on Linux.
    pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
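To pick between the CUDA 12 and CUDA 11 commands, one option is to read the "CUDA Version" field from the `nvidia-smi` header. Note that this field reports the highest CUDA version the installed driver supports, which is what matters for these pip wheels. The helper below is a hypothetical sketch (`cuda_major_from_smi` and `jax_install_command` are illustrative names) that only reproduces the two commands listed in this README.

```python
import re
import shutil
import subprocess

# Wheel index used by the install commands in this README.
JAX_RELEASES = "https://storage.googleapis.com/jax-releases/jax_cuda_releases.html"

# JAX extras from this README, keyed by CUDA major version.
EXTRAS = {12: "jax[cuda12_pip]", 11: "jax[cuda11_pip]"}


def cuda_major_from_smi(text):
    """Return the CUDA major version shown in `nvidia-smi` output, or None."""
    match = re.search(r"CUDA Version:\s*(\d+)\.(\d+)", text)
    return int(match.group(1)) if match else None


def jax_install_command(cuda_major):
    """Return the README's pip command for this CUDA major version, or None."""
    extra = EXTRAS.get(cuda_major)
    if extra is None:
        return None
    return f'pip install --upgrade "{extra}" -f {JAX_RELEASES}'


if __name__ == "__main__":
    if shutil.which("nvidia-smi") is None:
        print("nvidia-smi not found; is the NVIDIA driver installed?")
    else:
        out = subprocess.run(["nvidia-smi"], capture_output=True, text=True).stdout
        cmd = jax_install_command(cuda_major_from_smi(out))
        print(cmd or "CUDA version not detected or not covered by this README")
```

Any CUDA major version other than 11 or 12 returns `None`, since this README only covers those two.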
To install JAX against local CUDA and cuDNN installations instead, check the JAX installation guide; this route can be more complicated.
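Whichever installation route is used, a quick way to confirm that JAX actually sees the GPU is to query its default backend. This sketch uses `jax.default_backend()` and `jax.devices()`; the wrapper name `jax_backend` is hypothetical and simply returns `None` when JAX is not installed.

```python
def jax_backend():
    """Return JAX's default backend name (e.g. "gpu" or "cpu"), or None if JAX is missing."""
    try:
        import jax
    except ImportError:
        return None
    return jax.default_backend()


if __name__ == "__main__":
    backend = jax_backend()
    if backend is None:
        print("JAX is not installed in this environment")
    else:
        import jax

        print("default backend:", backend)
        print("devices:", jax.devices())
        if backend != "gpu":
            print("Warning: JAX is not using a GPU; check the CUDA installation")
```

A CUDA-enabled install should report `gpu` as the default backend; `cpu` usually means the CPU-only wheel was installed or the CUDA libraries were not found.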
For running vision-based experiments, please also `git clone` and `pip install -e .` this library: https://github.com/Leo428/efficientnet-jax. It is forked from https://github.com/rwightman/efficientnet-jax to support learning with pre-trained visual encoders (EfficientNet and MobileNets) in JAX and Flax.
This folder contains example usages of serl, as presented in the paper.