From a045e7b8800e26484d441b35a8774fb65161eedf Mon Sep 17 00:00:00 2001 From: Carlo Date: Fri, 28 Jun 2024 22:35:11 -0700 Subject: [PATCH 1/2] Updated instructions --- .gitignore | 10 +++++++++- README.md | 27 +++++++++++++++------------ requirements_dreamer.txt | 18 ++++++++++++++++++ requirements_jaxrl.txt | 5 +++++ requirements_tdmpc.txt | 11 +++++++++++ setup.py | 22 +++++++++++----------- 6 files changed, 69 insertions(+), 24 deletions(-) create mode 100644 requirements_dreamer.txt create mode 100644 requirements_jaxrl.txt create mode 100644 requirements_tdmpc.txt diff --git a/.gitignore b/.gitignore index 038c68c..8ff2010 100644 --- a/.gitignore +++ b/.gitignore @@ -6,4 +6,12 @@ __pycache__/ dist # macOS -.DS_Store \ No newline at end of file +.DS_Store + +# others +logdir +logs +models +outputs +runs +wandb \ No newline at end of file diff --git a/README.md b/README.md index cd39ed8..e81bbf2 100644 --- a/README.md +++ b/README.md @@ -19,26 +19,29 @@ Structure of the repository: * `tdmpc2`: Training code for TD-MPC2 ## Installation +Create a clean conda environment: ``` -# Install humanoid benchmark +conda create -n humanoidbench python=3.11 +conda activate humanoidbench +``` +Then, install the required packages: +``` +# Install HumanoidBench pip install -e . +# jax GPU version +pip install "jax[cuda12]==0.4.28" +# Or, jax CPU version +pip install "jax[cpu]==0.4.28" + # Install jaxrl -pip install -e jaxrl_m -pip install ml_collections flax distrax tf-keras +pip install -r requirements_jaxrl.txt # Install dreamer -pip install -e dreamerv3 -pip install ipdb wandb moviepy imageio opencv-python ruamel.yaml rich cloudpickle tensorflow tensorflow_probability dm-sonnet optax plotly msgpack zmq colored matplotlib +pip install -r requirements_dreamer.txt # Install td-mpc2 -pip install -e tdmpc2 -pip install torch torchvision torchaudio hydra-core pyquaternion tensordict torchrl pandas hydra-submitit-launcher termcolor - -# jax GPU version -pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html -# Or, jax CPU version -pip install --upgrade "jax[cpu]" +pip install -r requirements_tdmpc.txt ``` diff --git a/requirements_dreamer.txt b/requirements_dreamer.txt new file mode 100644 index 0000000..f796448 --- /dev/null +++ b/requirements_dreamer.txt @@ -0,0 +1,18 @@ +--editable dreamerv3 +ipdb==0.13.13 +wandb==0.17.3 +moviepy==1.0.3 +imageio==2.34.2 +opencv-python==4.10.0.84 +ruamel.yaml==0.18.6 +rich==13.7.1 +cloudpickle==3.0.0 +tensorflow==2.16.2 +tensorflow-probability==0.24.0 +dm-sonnet==2.0.2 +optax==0.2.2 +plotly==5.22.0 +msgpack==1.0.8 +pyzmq==26.0.3 +colored==2.2.4 +matplotlib==3.9.0 diff --git a/requirements_jaxrl.txt b/requirements_jaxrl.txt new file mode 100644 index 0000000..fc64952 --- /dev/null +++ b/requirements_jaxrl.txt @@ -0,0 +1,5 @@ +--editable jaxrl_m +ml_collections==0.1.1 +flax==0.8.5 +distrax==0.1.5 +tf_keras==2.16.0 \ No newline at end of file diff --git a/requirements_tdmpc.txt b/requirements_tdmpc.txt new file mode 100644 index 0000000..6ac36e1 --- /dev/null +++ b/requirements_tdmpc.txt @@ -0,0 +1,11 @@ +--editable tdmpc2 +torch==2.3.1 +torchaudio==2.3.1 +torchrl==0.4.0 +torchvision==0.18.1 +hydra-core==1.3.2 +hydra-submitit-launcher==1.2.0 +pyquaternion==0.9.9 +tensordict==0.4.0 +pandas==2.2.2 +termcolor==2.4.0 diff --git a/setup.py b/setup.py index 1b62cf4..3b563c3 100644 --- a/setup.py +++ b/setup.py @@ -5,18 +5,18 @@ long_description = (Path(__file__).parent / "README.md").read_text() core_requirements = [ - "gymnasium", - "rich", - "tqdm", - "ipdb", - "mujoco", - "mujoco-mjx", - "dm_control", - "imageio", - "gymnax", + "gymnasium==0.29.1", + "rich==13.7.1", + "tqdm==4.66.4", + "ipdb==0.13.13", + "mujoco==3.1.6", + "mujoco-mjx==3.1.6", + "dm_control==1.0.20", + "imageio==2.34.2", + "gymnax==0.0.8", "brax==0.9.4", - "torch", - "opencv-python", + "torch==2.3.1", + "opencv-python==4.10.0.84", ] setup( From fcb8b29144d009162798bbe4a993f2e7b9c35421 Mon Sep 17 00:00:00 2001 From: Carlo Date: Fri, 28 Jun 2024 22:42:59 -0700 Subject: [PATCH 2/2] Updated SAC config --- README.md | 2 +- jaxrl_m/examples/mujoco/run_mujoco_sac.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index e81bbf2..c9309ad 100644 --- a/README.md +++ b/README.md @@ -148,7 +148,7 @@ python -m tdmpc2.train disable_wandb=False wandb_entity=[WANDB_ENTITY] exp_name= python -m embodied.agents.dreamerv3.train --configs humanoid_benchmark --run.wandb True --run.wandb_entity [WANDB_ENTITY] --method dreamer --logdir logs --task humanoid_${TASK} --seed 0 # Train SAC -python ./jaxrl_m/examples/mujoco/run_mujoco_sac.py --env_name ${TASK} --wandb_entity [WANDB_ENTITY] --max_steps 5000000 --seed 0 +python ./jaxrl_m/examples/mujoco/run_mujoco_sac.py --env_name ${TASK} --wandb_entity [WANDB_ENTITY] --seed 0 # Train PPO (not using MJX) python ./ppo/run_sb3_ppo.py --env_name ${TASK} --wandb_entity [WANDB_ENTITY] --seed 0 diff --git a/jaxrl_m/examples/mujoco/run_mujoco_sac.py b/jaxrl_m/examples/mujoco/run_mujoco_sac.py index 199fa15..5a973b0 100644 --- a/jaxrl_m/examples/mujoco/run_mujoco_sac.py +++ b/jaxrl_m/examples/mujoco/run_mujoco_sac.py @@ -31,7 +31,7 @@ flags.DEFINE_integer('save_interval', 25000, 'Eval interval.') flags.DEFINE_integer('render_interval', 250000, 'Render interval.') flags.DEFINE_integer('batch_size', 64, 'Mini batch size.') -flags.DEFINE_integer('max_steps', int(1e6), 'Number of training steps.') +flags.DEFINE_integer('max_steps', int(1e7), 'Number of training steps.') flags.DEFINE_integer('start_steps', int(1e4), 'Number of initial exploration steps.') flags.DEFINE_string('wandb_entity', 'robot-learning', 'Wandb entity.')