-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
29 changed files
with
2,951 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
|
||
outputs/ | ||
__pycache__/ | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
# SMART: Self-supervised Multi-task pretrAining with contRol Transformers | ||
|
||
|
||
This is the official codebase for the ICLR 2023 spotlight paper "SMART: Self-supervised Multi-task pretrAining with contRol Transformers". Pretrained models can be downloaded [here](https://link-url-here.org). Dataset can be downloaded [here](https://link-url-here.org). | ||
|
||
## Setting up | ||
|
||
- Using conda | ||
|
||
``` | ||
# dmc specific | ||
# create env | ||
conda env create --file docker/environment.yml | ||
# activate conda | ||
conda activate smart | ||
bash src/scripts/dmc_setup.sh | ||
# install this repo | ||
(smart) $ pip install -e . | ||
``` | ||
|
||
- Using docker | ||
|
||
``` | ||
# dmc specific | ||
docker pull PUBLIC_DOCKER_IMAGE | ||
# run image | ||
docker run -it -d --gpus=all --name=rl_pretrain_dmc_1 -v HOST_PATH:CONTAINER_PATH commondockerimages.azurecr.io/atari_pretrain:latest-azureml-dmc | ||
# setup the repo (run inside the container) | ||
pip install -e . | ||
``` | ||
|
||
## Preparing the dataset | ||
|
||
Download dataset to PATH_TO_DATASET, or collect data following this instruction. | ||
|
||
## Running the code | ||
|
||
**Pretraining on multiple domains and tasks** (selection of pretraining tasks can be specified in the config file as shown below): | ||
``` | ||
## pretrain with offline data collected by exploratory policies | ||
python src/dmc_multidomain_train.py \ | ||
--epochs 10 --num_steps 80000 --train_replay_id 5 --model_type naive \ | ||
--multi_config configs/train_configs/multipretrain_source_v1.json \ | ||
--output_dir ./outputs/pretrain_explore/ \ | ||
--data_dir_prefix PATH_TO_DATASET | ||
## pretrain with offline data collected by random policies | ||
python src/dmc_multidomain_train.py \ | ||
--epochs 10 --num_steps 80000 --train_replay_id 5 --model_type naive \ | ||
--multi_config configs/train_configs/multipretrain_source_v1.json --source_data_type rand \ | ||
--output_dir ./outputs/pretrain_random/ \ | ||
--data_dir_prefix PATH_TO_DATASET | ||
``` | ||
You can also download our pretrained models as reported in the paper in this [here](https://link-url-here.org). | ||
|
||
|
||
|
||
|
||
The command below **loads the pretrained model and finetunes the policy on a specific downstream task**: | ||
``` | ||
## (example) set the downstream domain and task | ||
DOMAIN=cheetah | ||
TASK=run | ||
## behavior cloning as the learning algorithm | ||
python src/dmc_train.py \ | ||
--epochs 30 --num_steps 1000000 --domain ${DOMAIN} --task ${TASK} \ | ||
--model_type naive --no_load_action \ | ||
--load_model_from ./outputs/pretrain_explore/checkpoints/last.ckpt \ | ||
--output_dir ./outputs/${DOMAIN}_${TASK}_bc/ | ||
## RTG-conditioned learning as the learning algorithm | ||
python src/dmc_train.py \ | ||
--epochs 30 --num_steps 1000000 --domain ${DOMAIN} --task ${TASK} \ | ||
--model_type reward_conditioned --rand_select --no_load_action \ | ||
--load_model_from ./outputs/pretrain_explore/checkpoints/last.ckpt \ | ||
--output_dir ./outputs/${DOMAIN}_${TASK}_rtg/ | ||
``` | ||
|
||
Note that if *--load_model_from* is not specified, the model is trained from scratch. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
{ | ||
"walker": ["walk"], | ||
"cheetah": ["run"] | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
{ | ||
"walker": ["stand", "run"], | ||
"cheetah": ["run"], | ||
"cartpole": ["swingup"], | ||
"hopper": ["hop"] | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
{ | ||
"walker": ["run"], | ||
"hopper": ["stand"], | ||
"finger": ["spin"], | ||
"swimmer": ["swimmer15"], | ||
"fish": ["swim"] | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
ARG BASE_IMAGE=openmpi4.1.0-cuda11.3-cudnn8-ubuntu20.04:latest | ||
|
||
FROM mcr.microsoft.com/azureml/${BASE_IMAGE} | ||
|
||
ARG DEBIAN_FRONTEND=noninteractive | ||
|
||
RUN apt-get update && apt-get install -y --allow-downgrades --allow-change-held-packages --no-install-recommends \ | ||
build-essential \ | ||
cmake \ | ||
g++-7 \ | ||
git \ | ||
gpg \ | ||
curl \ | ||
vim \ | ||
wget \ | ||
ca-certificates \ | ||
libjpeg-dev \ | ||
libpng-dev \ | ||
librdmacm1 \ | ||
libibverbs1 \ | ||
ibverbs-providers \ | ||
openssh-client \ | ||
openssh-server \ | ||
libsm6 \ | ||
libxext6 \ | ||
ffmpeg \ | ||
libfontconfig1 \ | ||
libxrender1 \ | ||
libgl1-mesa-glx &&\ | ||
apt-get clean && rm -rf /var/lib/apt/lists/* | ||
|
||
RUN gpg --version && apt-key adv --keyserver hkp://keyserver.ubuntu.com:80 --recv-keys A4B469963BF863CC && \ | ||
apt update -y && apt install -y libgl1-mesa-glx libosmesa6 libosmesa6-dev libglfw3-dev libgles2-mesa-dev freeglut3-dev | ||
|
||
ADD environment.yml /tmp/environment.yml | ||
RUN conda env update -n base -f /tmp/environment.yml | ||
|
||
RUN pip install PyOpenGL==3.1.4 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
absl-py==0.9.0 | ||
astunparse==1.6.3 | ||
atari-py==0.2.6 | ||
cachetools==4.1.1 | ||
certifi==2020.6.20 | ||
chardet==3.0.4 | ||
cloudpickle==1.3.0 | ||
decorator==4.4.2 | ||
dm-control==1.0.3.post1 | ||
dm-acme | ||
dm-reverb-nightly | ||
dopamine-rl==3.1.2 | ||
future==0.18.2 | ||
gast==0.3.3 | ||
gin-config==0.3.0 | ||
glfw==1.11.2 | ||
google-auth==1.18.0 | ||
google-auth-oauthlib==0.4.1 | ||
google-pasta==0.2.0 | ||
grpcio==1.30.0 | ||
gym==0.25.2 | ||
h5py==2.10.0 | ||
idna==2.10 | ||
Keras-Preprocessing==1.1.2 | ||
lxml==4.5.1 | ||
Markdown==3.2.2 | ||
numpy==1.19.0 | ||
oauthlib==3.1.0 | ||
opencv-python==4.3.0.36 | ||
opt-einsum==3.2.1 | ||
Pillow==7.2.0 | ||
portpicker==1.3.1 | ||
protobuf==3.20.1 | ||
pyasn1==0.4.8 | ||
pyasn1-modules==0.2.8 | ||
pyglet==1.5.0 | ||
PyOpenGL==3.1.5 | ||
pyparsing==2.4.7 | ||
requests==2.24.0 | ||
requests-oauthlib==1.3.0 | ||
rsa==4.6 | ||
scipy==1.4.1 | ||
six==1.15.0 | ||
tabulate==0.8.7 | ||
tf-nightly | ||
tb-nightly | ||
tensorboard-plugin-wit | ||
termcolor==1.1.0 | ||
tf-estimator-nightly | ||
tfp-nightly | ||
trfl==1.1.0 | ||
urllib3==1.25.9 | ||
Werkzeug==1.0.1 | ||
wrapt==1.12.1 |
Oops, something went wrong.