
SMART: Self-supervised Multi-task pretrAining with contRol Transformers

This is the official codebase for the ICLR 2023 spotlight paper SMART: Self-supervised Multi-task pretrAining with contRol Transformers. If you use this code in an academic context, please use the following citation:

@inproceedings{sun2023smart,
  title={{SMART}: Self-supervised Multi-task pretrAining with contRol Transformers},
  author={Yanchao Sun and Shuang Ma and Ratnesh Madaan and Rogerio Bonatti and Furong Huang and Ashish Kapoor},
  booktitle={International Conference on Learning Representations},
  year={2023},
  url={https://openreview.net/forum?id=9piH3Hg8QEf}
}

Setting up

  • Using conda

    # dmc specific
    # create env
    conda env create --file docker/environment.yml
    
    # activate conda
    conda activate smart
    bash scripts/dmc_setup.sh
    
    # install this repo
    pip install -e .
    
  • Using docker

    # build image
    docker build \
          -f Dockerfile_base_azureml_dmc \
          --build-arg BASE_IMAGE=openmpi4.1.0-cuda11.3-cudnn8-ubuntu20.04:latest \
          -t smart:latest .
    
    # run image
    docker run -it -d --gpus=all --name=rl_pretrain_dmc_1 -v HOST_PATH:CONTAINER_PATH smart:latest
    
    # setup the repo (run inside the container)
    pip install -e .
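
  • Sanity check (either setup)

    Before moving on, it can help to verify that the core dependencies import correctly. This is a minimal sketch; torch and dm_control are assumed to be installed by the steps above:

    # quick check, run inside the activated env or the running container
    python -c "import torch; print('CUDA available:', torch.cuda.is_available())"
    python -c "import dm_control; print('dm_control OK')"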
    

Downloading data and pre-trained models from Azure

  • Install azcopy

    wget https://aka.ms/downloadazcopy-v10-linux
    tar -xvf downloadazcopy-v10-linux
    sudo cp ./azcopy_linux_amd64_*/azcopy /usr/bin/
    rm -rf *azcopy*
    
  • Downloading the full dataset (1.18TiB)

    # download to data/ directory
    azcopy copy 'https://smartrelease.blob.core.windows.net/smartrelease/data/dmc_ae' 'data' --recursive
    
  • Downloading a subset of the full dataset

    # download to data/ directory
    azcopy copy 'https://smartrelease.blob.core.windows.net/smartrelease/data/dmc_ae/TYPE_DOMAIN_TASK' 'data' --recursive
    

    where

    • TYPE: randcollect or fullcollect (note: fullcollect datasets are ~10x larger than randcollect datasets)

    • DOMAIN_TASK: cartpole_balance, cartpole_swingup, cheetah_run, finger_spin, hopper_hop, hopper_stand, pendulum_swingup, walker_run, walker_stand, or walker_walk (See Table 2 in the paper)

    Example (a loop for downloading several subsets at once is sketched below):

    # download to data/ directory (~ 9.7 GB each)
    azcopy copy 'https://smartrelease.blob.core.windows.net/smartrelease/data/dmc_ae/randcollect_walker_walk' 'data' --recursive
    azcopy copy 'https://smartrelease.blob.core.windows.net/smartrelease/data/dmc_ae/randcollect_cheetah_run' 'data' --recursive
    
  • Downloading the pretrained models

    # download to pretrained_models/ directory (236.34 MiB)
    azcopy copy 'https://smartrelease.blob.core.windows.net/smartrelease/pretrained_models' '.' --recursive
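
  • Batch-downloading several subsets

    To fetch multiple subsets in one go, you can loop over the azcopy command above. This is a sketch; the subset names must follow the TYPE_DOMAIN_TASK pattern described earlier, and randcollect_hopper_stand is included purely as an illustration:

    # download several subsets to data/ (sketch)
    for subset in randcollect_walker_walk randcollect_cheetah_run randcollect_hopper_stand; do
        azcopy copy "https://smartrelease.blob.core.windows.net/smartrelease/data/dmc_ae/${subset}" 'data' --recursive
    done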
    

Running the code

Testing on a small subset of the full dataset

Let us run the code on the small subsets downloaded above, randcollect_walker_walk and randcollect_cheetah_run.

python src/dmc_pretrain.py base=configs/pretrain.yaml \
        epochs=10 \
        data.num_steps=80000 \
        domain_and_task.source_data_type=rand \
        data.train_replay_id=1 \
        data.data_dir_prefix=data \
        model.model_type=naive \
        domain_and_task.source_envs="{'walker': ['walk'], 'cheetah': ['run']}" \
        output_dir=./outputs/pretrain_explore_subset
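
Before launching, it can help to confirm that the subsets sit where data.data_dir_prefix points. A minimal sketch; the exact directory layout produced by azcopy may differ:

# check that the expected subsets exist under data/ (sketch)
for subset in randcollect_walker_walk randcollect_cheetah_run; do
    [ -d "data/${subset}" ] && echo "found data/${subset}" || echo "missing data/${subset}"
done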

Pretraining on multiple domains and tasks

The set of pretraining tasks is specified via domain_and_task.source_envs, either in the config file or as a command-line override, as shown below (a sketch that runs both data settings in one sweep follows these examples):

  • Pretrain with offline data collected by exploratory policies
python src/dmc_pretrain.py base=configs/pretrain.yaml \
        epochs=10 \
        data.num_steps=80000 \
        data.train_replay_id=5 \
        data.data_dir_prefix=data \
        model.model_type=naive \
        domain_and_task.source_data_type=full \
        domain_and_task.source_envs="{'walker': ['walk'], 'cheetah': ['run']}" \
        output_dir=./outputs/pretrain_explore
  • Pretrain with offline data collected by random policies
python src/dmc_pretrain.py base=configs/pretrain.yaml \
        epochs=10 \
        data.num_steps=80000 \
        data.train_replay_id=5 \
        data.data_dir_prefix=data \
        model.model_type=naive \
        domain_and_task.source_data_type=rand \
        domain_and_task.source_envs="{'walker': ['walk'], 'cheetah': ['run']}" \
        output_dir=./outputs/pretrain_random
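
Since the two commands above differ only in source_data_type and output_dir, both settings can be run in one sweep. A sketch; it assumes the fullcollect and randcollect datasets for walker_walk and cheetah_run have both been downloaded, and it writes to pretrain_full/pretrain_rand rather than the directory names above:

# pretrain on both exploratory (full) and random (rand) data (sketch)
for dtype in full rand; do
    python src/dmc_pretrain.py base=configs/pretrain.yaml \
            epochs=10 \
            data.num_steps=80000 \
            data.train_replay_id=5 \
            data.data_dir_prefix=data \
            model.model_type=naive \
            domain_and_task.source_data_type=${dtype} \
            domain_and_task.source_envs="{'walker': ['walk'], 'cheetah': ['run']}" \
            output_dir=./outputs/pretrain_${dtype}
done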

Using a pretrained model to finetune the policy on a specific downstream task

Instead of pretraining yourself, you can use the pretrained models reported in the paper, downloaded with the azcopy command in the previous section.

## set the downstream domain and task
DOMAIN=cheetah
TASK=run

## behavior cloning as the learning algorithm
python src/dmc_downstream.py base=configs/downstream.yaml \
        epochs=30 \
        data.num_steps=1000000 \
        domain_and_task.domain=${DOMAIN} \
        domain_and_task.task=${TASK} \
        model.model_type=naive \
        no_load_action=True \
        load_model_from=./outputs/pretrain_explore/checkpoints/last.ckpt \
        output_dir=./outputs/${DOMAIN}_${TASK}_bc/

## RTG-conditioned learning as the learning algorithm
python src/dmc_downstream.py base=configs/downstream.yaml \
        epochs=30 \
        data.num_steps=1000000 \
        domain_and_task.domain=${DOMAIN} \
        domain_and_task.task=${TASK} \
        model.model_type=reward_conditioned \
        data.rand_select=True \
        no_load_action=True \
        load_model_from=./outputs/pretrain_explore/checkpoints/last.ckpt \
        output_dir=./outputs/${DOMAIN}_${TASK}_rtg/

Note that if load_model_from is not specified, the model is trained from scratch.
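
The same recipe extends to finetuning on several downstream tasks in sequence. A sketch using the behavior-cloning variant; the domain/task pairs are assumed to be among those listed in the data section:

# finetune the pretrained model on several downstream tasks with BC (sketch)
for pair in cheetah:run walker:walk hopper:stand; do
    DOMAIN=${pair%%:*}
    TASK=${pair##*:}
    python src/dmc_downstream.py base=configs/downstream.yaml \
            epochs=30 \
            data.num_steps=1000000 \
            domain_and_task.domain=${DOMAIN} \
            domain_and_task.task=${TASK} \
            model.model_type=naive \
            no_load_action=True \
            load_model_from=./outputs/pretrain_explore/checkpoints/last.ckpt \
            output_dir=./outputs/${DOMAIN}_${TASK}_bc/
done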
