Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosferrazza committed Jun 29, 2024
2 parents e950253 + fcb8b29 commit 6db8365
Show file tree
Hide file tree
Showing 7 changed files with 71 additions and 26 deletions.
10 changes: 9 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,12 @@ __pycache__/
dist

# macOS
.DS_Store
.DS_Store

# others
logdir
logs
models
outputs
runs
wandb
29 changes: 16 additions & 13 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,26 +19,29 @@ Structure of the repository:
* `tdmpc2`: Training code for TD-MPC2

## Installation
Create a clean conda environment:
```
# Install humanoid benchmark
conda create -n humanoidbench python=3.11
conda activate humanoidbench
```
Then, install the required packages:
```
# Install HumanoidBench
pip install -e .
# jax GPU version
pip install "jax[cuda12]==0.4.28"
# Or, jax CPU version
pip install "jax[cpu]==0.4.28"
# Install jaxrl
pip install -e jaxrl_m
pip install ml_collections flax distrax tf-keras
pip install -r requirements_jaxrl.txt
# Install dreamer
pip install -e dreamerv3
pip install ipdb wandb moviepy imageio opencv-python ruamel.yaml rich cloudpickle tensorflow tensorflow_probability dm-sonnet optax plotly msgpack zmq colored matplotlib
pip install -r requirements_dreamer.txt
# Install td-mpc2
pip install -e tdmpc2
pip install torch torchvision torchaudio hydra-core pyquaternion tensordict torchrl pandas hydra-submitit-launcher termcolor
# jax GPU version
pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
# Or, jax CPU version
pip install --upgrade "jax[cpu]"
pip install -r requirements_tdmpc.txt
```


Expand Down Expand Up @@ -145,7 +148,7 @@ python -m tdmpc2.train disable_wandb=False wandb_entity=[WANDB_ENTITY] exp_name=
python -m embodied.agents.dreamerv3.train --configs humanoid_benchmark --run.wandb True --run.wandb_entity [WANDB_ENTITY] --method dreamer --logdir logs --task humanoid_${TASK} --seed 0
# Train SAC
python ./jaxrl_m/examples/mujoco/run_mujoco_sac.py --env_name ${TASK} --wandb_entity [WANDB_ENTITY] --max_steps 5000000 --seed 0
python ./jaxrl_m/examples/mujoco/run_mujoco_sac.py --env_name ${TASK} --wandb_entity [WANDB_ENTITY] --seed 0
# Train PPO (not using MJX)
python ./ppo/run_sb3_ppo.py --env_name ${TASK} --wandb_entity [WANDB_ENTITY] --seed 0
Expand Down
2 changes: 1 addition & 1 deletion jaxrl_m/examples/mujoco/run_mujoco_sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
flags.DEFINE_integer('save_interval', 25000, 'Eval interval.')
flags.DEFINE_integer('render_interval', 250000, 'Render interval.')
flags.DEFINE_integer('batch_size', 64, 'Mini batch size.')
flags.DEFINE_integer('max_steps', int(1e6), 'Number of training steps.')
flags.DEFINE_integer('max_steps', int(1e7), 'Number of training steps.')
flags.DEFINE_integer('start_steps', int(1e4), 'Number of initial exploration steps.')
flags.DEFINE_string('wandb_entity', 'robot-learning', 'Wandb entity.')

Expand Down
18 changes: 18 additions & 0 deletions requirements_dreamer.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
--editable dreamerv3
ipdb==0.13.13
wandb==0.17.3
moviepy==1.0.3
imageio==2.34.2
opencv-python==4.10.0.84
ruamel.yaml==0.18.6
rich==13.7.1
cloudpickle==3.0.0
tensorflow==2.16.2
tensorflow-probability==0.24.0
dm-sonnet==2.0.2
optax==0.2.2
plotly==5.22.0
msgpack==1.0.8
pyzmq==26.0.3
colored==2.2.4
matplotlib==3.9.0
5 changes: 5 additions & 0 deletions requirements_jaxrl.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
--editable jaxrl_m
ml_collections==0.1.1
flax==0.8.5
distrax==0.1.5
tf_keras==2.16.0
11 changes: 11 additions & 0 deletions requirements_tdmpc.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
--editable tdmpc2
torch==2.3.1
torchaudio==2.3.1
torchrl==0.4.0
torchvision==0.18.1
hydra-core==1.3.2
hydra-submitit-launcher==1.2.0
pyquaternion==0.9.9
tensordict==0.4.0
pandas==2.2.2
termcolor==2.4.0
22 changes: 11 additions & 11 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,18 @@
long_description = (Path(__file__).parent / "README.md").read_text()

core_requirements = [
"gymnasium",
"rich",
"tqdm",
"ipdb",
"mujoco",
"mujoco-mjx",
"dm_control",
"imageio",
"gymnax",
"gymnasium==0.29.1",
"rich==13.7.1",
"tqdm==4.66.4",
"ipdb==0.13.13",
"mujoco==3.1.6",
"mujoco-mjx==3.1.6",
"dm_control==1.0.20",
"imageio==2.34.2",
"gymnax==0.0.8",
"brax==0.9.4",
"torch",
"opencv-python",
"torch==2.3.1",
"opencv-python==4.10.0.84",
]

setup(
Expand Down

0 comments on commit 6db8365

Please sign in to comment.