init
ycsun2017 authored and madratman committed Feb 24, 2023
0 parents commit 9d87e19
Showing 29 changed files with 2,951 additions and 0 deletions.
4 changes: 4 additions & 0 deletions .gitignore
@@ -0,0 +1,4 @@

outputs/
__pycache__/

84 changes: 84 additions & 0 deletions README.md
@@ -0,0 +1,84 @@
# SMART: Self-supervised Multi-task pretrAining with contRol Transformers


This is the official codebase for the ICLR 2023 spotlight paper "SMART: Self-supervised Multi-task pretrAining with contRol Transformers". Pretrained models can be downloaded [here](https://link-url-here.org), and the dataset can be downloaded [here](https://link-url-here.org).

## Setting up

- Using conda

```
# dmc specific
# create env
conda env create --file docker/environment.yml
# activate conda
conda activate smart
bash src/scripts/dmc_setup.sh
# install this repo
(smart) $ pip install -e .
```

- Using docker

```
# dmc specific
docker pull PUBLIC_DOCKER_IMAGE
# run image
docker run -it -d --gpus=all --name=rl_pretrain_dmc_1 -v HOST_PATH:CONTAINER_PATH commondockerimages.azurecr.io/atari_pretrain:latest-azureml-dmc
# setup the repo (run inside the container)
pip install -e .
```

## Preparing the dataset

Download the dataset to PATH_TO_DATASET, or collect the data yourself following these instructions.
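The expected layout under PATH_TO_DATASET is not spelled out here. Assuming, purely as an illustration, one subdirectory per domain/task pair (e.g. `walker_walk/`), a quick sanity check of the dataset root could look like:

```python
import os


def find_task_dirs(data_root):
    """List the immediate subdirectories of the dataset root.

    Hypothetical helper: it assumes one subdirectory per domain/task
    pair under PATH_TO_DATASET, which is an assumption about the
    layout, not something documented in this README.
    """
    if not os.path.isdir(data_root):
        raise FileNotFoundError(f"dataset root not found: {data_root}")
    return sorted(
        d for d in os.listdir(data_root)
        if os.path.isdir(os.path.join(data_root, d))
    )
```

Running this before launching a long pretraining job catches a mistyped `--data_dir_prefix` early.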

## Running the code

**Pretraining on multiple domains and tasks** (selection of pretraining tasks can be specified in the config file as shown below):
```
## pretrain with offline data collected by exploratory policies
python src/dmc_multidomain_train.py \
--epochs 10 --num_steps 80000 --train_replay_id 5 --model_type naive \
--multi_config configs/train_configs/multipretrain_source_v1.json \
--output_dir ./outputs/pretrain_explore/ \
--data_dir_prefix PATH_TO_DATASET
## pretrain with offline data collected by random policies
python src/dmc_multidomain_train.py \
--epochs 10 --num_steps 80000 --train_replay_id 5 --model_type naive \
--multi_config configs/train_configs/multipretrain_source_v1.json --source_data_type rand \
--output_dir ./outputs/pretrain_random/ \
--data_dir_prefix PATH_TO_DATASET
```
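The file passed to `--multi_config` maps each domain to the list of tasks to include in pretraining (see the JSON files under `configs/` in this commit). As a minimal sketch, such a `{domain: [tasks]}` file can be expanded into `(domain, task)` pairs as below; `expand_config` is a hypothetical helper for illustration, not the repo's actual loading code:

```python
import json


def expand_config(path):
    """Expand a {domain: [tasks]} JSON config into (domain, task) pairs.

    Hypothetical helper mirroring the shape of the multipretrain
    config files; the actual code in dmc_multidomain_train.py may differ.
    """
    with open(path) as f:
        config = json.load(f)
    return [(domain, task)
            for domain, tasks in config.items()
            for task in tasks]
```

For `multipretrain_source_v1.json` this would yield five pairs, one per pretraining task.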
You can also download the pretrained models reported in the paper [here](https://link-url-here.org).

The command below **loads the pretrained model and finetunes the policy on a specific downstream task**:
```
## (example) set the downstream domain and task
DOMAIN=cheetah
TASK=run
## behavior cloning as the learning algorithm
python src/dmc_train.py \
--epochs 30 --num_steps 1000000 --domain ${DOMAIN} --task ${TASK} \
--model_type naive --no_load_action \
--load_model_from ./outputs/pretrain_explore/checkpoints/last.ckpt \
--output_dir ./outputs/${DOMAIN}_${TASK}_bc/
## RTG-conditioned learning as the learning algorithm
python src/dmc_train.py \
--epochs 30 --num_steps 1000000 --domain ${DOMAIN} --task ${TASK} \
--model_type reward_conditioned --rand_select --no_load_action \
--load_model_from ./outputs/pretrain_explore/checkpoints/last.ckpt \
--output_dir ./outputs/${DOMAIN}_${TASK}_rtg/
```

Note that if `--load_model_from` is not specified, the model is trained from scratch.
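The load-or-train-from-scratch behaviour can be sketched as follows. This is an assumption-laden illustration, not the repo's actual code: `build_model` returns a plain dict standing in for a real network, and `pretrained_state` stands in for a loaded checkpoint such as the `last.ckpt` produced by pretraining.

```python
def init_model(build_model, pretrained_state=None):
    """Build a model and optionally warm-start it from pretrained weights.

    Illustrative only: `build_model` returns a dict of parameters here,
    standing in for a real network, and `pretrained_state` stands in
    for a checkpoint loaded from --load_model_from.
    """
    model = build_model()
    if pretrained_state is not None:
        # Corresponds to passing --load_model_from; without it, the
        # freshly initialized parameters are kept (training from scratch).
        model.update(pretrained_state)
    return model
```

The point of the sketch: finetuning only changes the initialization, so the same downstream training command works with or without a pretrained checkpoint.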
4 changes: 4 additions & 0 deletions configs/pretrain_configs/multipretrain_source_test.json
@@ -0,0 +1,4 @@
{
"walker": ["walk"],
"cheetah": ["run"]
}
6 changes: 6 additions & 0 deletions configs/pretrain_configs/multipretrain_source_v1.json
@@ -0,0 +1,6 @@
{
"walker": ["stand", "run"],
"cheetah": ["run"],
"cartpole": ["swingup"],
"hopper": ["hop"]
}
7 changes: 7 additions & 0 deletions configs/pretrain_configs/multipretrain_source_v2.json
@@ -0,0 +1,7 @@
{
"walker": ["run"],
"hopper": ["stand"],
"finger": ["spin"],
"swimmer": ["swimmer15"],
"fish": ["swim"]
}
38 changes: 38 additions & 0 deletions docker/Dockerfile_base_azureml_dmc
@@ -0,0 +1,38 @@
ARG BASE_IMAGE=openmpi4.1.0-cuda11.3-cudnn8-ubuntu20.04:latest

FROM mcr.microsoft.com/azureml/${BASE_IMAGE}

ARG DEBIAN_FRONTEND=noninteractive

RUN apt-get update && apt-get install -y --allow-downgrades --allow-change-held-packages --no-install-recommends \
build-essential \
cmake \
g++-7 \
git \
gpg \
curl \
vim \
wget \
ca-certificates \
libjpeg-dev \
libpng-dev \
librdmacm1 \
libibverbs1 \
ibverbs-providers \
openssh-client \
openssh-server \
libsm6 \
libxext6 \
ffmpeg \
libfontconfig1 \
libxrender1 \
libgl1-mesa-glx &&\
apt-get clean && rm -rf /var/lib/apt/lists/*

RUN gpg --version && apt-key adv --keyserver hkp://keyserver.ubuntu.com:80 --recv-keys A4B469963BF863CC && \
apt update -y && apt install -y libgl1-mesa-glx libosmesa6 libosmesa6-dev libglfw3-dev libgles2-mesa-dev freeglut3-dev

ADD environment.yml /tmp/environment.yml
RUN conda env update -n base -f /tmp/environment.yml

RUN pip install PyOpenGL==3.1.4
54 changes: 54 additions & 0 deletions docker/dmc_requirements.txt
@@ -0,0 +1,54 @@
absl-py==0.9.0
astunparse==1.6.3
atari-py==0.2.6
cachetools==4.1.1
certifi==2020.6.20
chardet==3.0.4
cloudpickle==1.3.0
decorator==4.4.2
dm-control==1.0.3.post1
dm-acme
dm-reverb-nightly
dopamine-rl==3.1.2
future==0.18.2
gast==0.3.3
gin-config==0.3.0
glfw==1.11.2
google-auth==1.18.0
google-auth-oauthlib==0.4.1
google-pasta==0.2.0
grpcio==1.30.0
gym==0.25.2
h5py==2.10.0
idna==2.10
Keras-Preprocessing==1.1.2
lxml==4.5.1
Markdown==3.2.2
numpy==1.19.0
oauthlib==3.1.0
opencv-python==4.3.0.36
opt-einsum==3.2.1
Pillow==7.2.0
portpicker==1.3.1
protobuf==3.20.1
pyasn1==0.4.8
pyasn1-modules==0.2.8
pyglet==1.5.0
PyOpenGL==3.1.5
pyparsing==2.4.7
requests==2.24.0
requests-oauthlib==1.3.0
rsa==4.6
scipy==1.4.1
six==1.15.0
tabulate==0.8.7
tf-nightly
tb-nightly
tensorboard-plugin-wit
termcolor==1.1.0
tf-estimator-nightly
tfp-nightly
trfl==1.1.0
urllib3==1.25.9
Werkzeug==1.0.1
wrapt==1.12.1
