This repository contains the PyTorch implementation for our papers StARformer: Transformer with State-Action-Reward Representations for Visual Reinforcement Learning (ECCV 2022) and StARformer: Transformer with State-Action-Reward Representations for Robot Learning (IEEE T-PAMI).
[Installation] [Usage] [Citation] [Update Notes]
We learn local State-Action-Reward representations (StAR-representations) to improve (long) sequence modeling for reinforcement learning (and imitation learning).
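To illustrate what "local" means here, the snippet below is a simplified, framework-free sketch. It is not the code in this repository. It assumes that each step t forms one group from the previous action a_{t-1}, the previous reward r_{t-1}, and the non-overlapping image patches of the current state s_t. The helper names `patchify` and `star_groups` and the placeholder values used at t = 0 are hypothetical. In the model, these tokens are embedded and processed by self-attention. The sketch only shows how they are grouped.

```python
def patchify(image, patch):
    """Split an H x W image (nested lists) into flattened, non-overlapping patches."""
    h, w = len(image), len(image[0])
    assert h % patch == 0 and w % patch == 0, "image size must be divisible by patch"
    patches = []
    for i in range(0, h, patch):
        for j in range(0, w, patch):
            # Row-major flattening of the patch x patch block at (i, j).
            patches.append([image[i + di][j + dj] for di in range(patch) for dj in range(patch)])
    return patches


def star_groups(states, actions, rewards, patch, start_action=None, start_reward=0.0):
    """Hypothetical per-step grouping: (a_{t-1}, r_{t-1}, patches of s_t).

    At t = 0 there is no previous action or reward, so assumed placeholder
    values (start_action, start_reward) are used instead.
    """
    groups = []
    for t, state in enumerate(states):
        prev_action = actions[t - 1] if t > 0 else start_action
        prev_reward = rewards[t - 1] if t > 0 else start_reward
        groups.append({
            "action": prev_action,
            "reward": prev_reward,
            "patches": patchify(state, patch),
        })
    return groups
```

Because each group only spans a single step, attention inside a group captures short-term state-action-reward structure. Longer-range dependencies are left to a model operating over the whole sequence.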
For details and unnormalized numbers, please check the supplementary material at the end of the paper, or here for convenience.
Dependencies can be installed with Conda.
For example, to install the environment used for Atari and DMC (image input):
conda env create -f atari_and_dmc/conda_env.yml
Then activate it with:
conda activate starformer
- Atari: To run on Atari environments, please install the Atari ROMs.
- DMC: Install dmc2gym with
pip install git+https://github.com/denisyarats/dmc2gym.git
Make sure you have MuJoCo installed. mujoco-py has already been installed in the conda env for you, but it is worth checking that the two are compatible.
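After installation, a quick stdlib-only check can confirm that the key packages are importable from the activated environment. The package list below is an assumption. The authoritative list is atari_and_dmc/conda_env.yml, so adjust it to match. `importlib.util.find_spec` only checks that a package can be located. It does not import mujoco_py, so it cannot detect a mujoco-py/MuJoCo version mismatch. Actually importing mujoco_py is the real compatibility test.

```python
import importlib.util

# Assumed package names; check atari_and_dmc/conda_env.yml for the real list.
DEFAULT_PACKAGES = ("torch", "numpy", "gym", "mujoco_py", "dmc2gym")


def missing_packages(names=DEFAULT_PACKAGES):
    """Return the top-level packages in `names` that cannot be found on the import path."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_packages()
    if missing:
        print("Missing packages:", ", ".join(missing))
    else:
        print("All listed packages were found.")
```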
Please follow these instructions to prepare the datasets.
See run.sh or the examples below:
- atari:
python run_star_atari.py --seed 123 --data_dir_prefix [data_directory] --epochs 10 --num_steps 500000 --num_buffers 50 --batch_size 64 --seq_len 30 --model_type 'star' --game 'Breakout'
[data_directory] is where you place the Atari dataset.
- dmc:
python run_star_dmc.py --seed 123 --data_dir_prefix [data_directory] --epochs 10 --seq_len 30 --model_type 'star' --batch_size 64 --domain ball_in_cup --task catch --lr 1e-4
Similarly, [data_directory] is where you place the DMC dataset. You can collect any replay buffer you like and modify StateActionReturnDatasetDMC in run_star_dmc.py to make it compatible with your buffers.
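The real interface of StateActionReturnDatasetDMC is defined in run_star_dmc.py. The class below is a hypothetical, framework-free sketch of the windowing such a dataset typically performs when adapted to a custom buffer. The names `WindowedBufferDataset` and `episode_ends` are not from the repository. The sketch cuts fixed-length windows of seq_len consecutive steps that never cross an episode boundary, and it returns per-episode timesteps alongside each window. In practice you would subclass `torch.utils.data.Dataset` (which needs the same `__len__`/`__getitem__` pair) and convert the slices to tensors.

```python
class WindowedBufferDataset:
    """Hypothetical sketch of a sequence dataset over a flat replay buffer.

    Assumes observations, actions and rewards are flat, time-aligned lists,
    and that `episode_ends` holds the index one past the last step of each
    episode (so the last entry equals the buffer length).
    """

    def __init__(self, observations, actions, rewards, episode_ends, seq_len):
        assert len(observations) == len(actions) == len(rewards)
        self.observations = observations
        self.actions = actions
        self.rewards = rewards
        self.seq_len = seq_len
        self.starts = []     # valid window start indices
        self.timesteps = []  # step index within its own episode
        episode_start = 0
        for end in episode_ends:
            self.timesteps.extend(range(end - episode_start))
            # Episodes shorter than seq_len contribute no windows.
            self.starts.extend(range(episode_start, end - seq_len + 1))
            episode_start = end
        assert len(self.timesteps) == len(observations), "episode_ends must cover the whole buffer"

    def __len__(self):
        return len(self.starts)

    def __getitem__(self, idx):
        s = self.starts[idx]
        e = s + self.seq_len
        return (self.observations[s:e], self.actions[s:e],
                self.rewards[s:e], self.timesteps[s:e])
```

Precomputing the valid start indices keeps `__getitem__` a constant number of slices and guarantees that no window mixes steps from two episodes.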
Available options for --model_type:
- 'star' (imitation)
- 'star_rwd' (offline RL)
- 'star_fusion' (see Figure 4a in our paper)
- 'star_stack' (see Figure 4b in our paper)
GPU usage reference with num_steps=500000, batch_size=64, and model_type=star_rwd on a single NVIDIA 3090Ti (24GB):
- --seq_len=10: 9685MB, ~25 min/epoch
- --seq_len=20: 17033MB, ~50 min/epoch
- --seq_len=30: 24007MB, ~66 min/epoch
If you run out of memory, you can reduce batch_size.
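The reported memory grows roughly linearly with seq_len: about (24007 - 9685) / 20 ≈ 716 MB per extra step at batch_size=64. The sketch below interpolates between the three reported measurements to get a rough estimate for an intermediate seq_len. It is only a back-of-the-envelope aid under that linearity assumption. It refuses to extrapolate, and real usage also depends on batch_size, model_type, and the environment.

```python
# Peak GPU memory (MB) and minutes per epoch reported above for
# num_steps=500000, batch_size=64, model_type=star_rwd on one 3090Ti.
MEASURED = {10: (9685, 25), 20: (17033, 50), 30: (24007, 66)}


def estimate_usage(seq_len):
    """Linearly interpolate (memory_mb, minutes_per_epoch) between measurements.

    Only valid inside the measured seq_len range; raises ValueError otherwise.
    """
    points = sorted(MEASURED)
    if not points[0] <= seq_len <= points[-1]:
        raise ValueError(f"seq_len must be within [{points[0]}, {points[-1]}]")
    for left, right in zip(points, points[1:]):
        if left <= seq_len <= right:
            frac = (seq_len - left) / (right - left)
            (mem_l, t_l), (mem_r, t_r) = MEASURED[left], MEASURED[right]
            return (mem_l + frac * (mem_r - mem_l), t_l + frac * (t_r - t_l))


if __name__ == "__main__":
    print(estimate_usage(15))  # midpoint between the seq_len=10 and seq_len=20 rows
```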
If you find our work useful for your research, please consider citing:
@InProceedings{starformer,
author="Shang, Jinghuan and Kahatapitiya, Kumara and Li, Xiang and Ryoo, Michael S.",
title="StARformer: Transformer with State-Action-Reward Representations for Visual Reinforcement Learning",
booktitle="Computer Vision -- ECCV 2022",
year="2022",
publisher="Springer Nature Switzerland",
pages="462--479",
}
@ARTICLE{starformer-robot,
author={Shang, Jinghuan and Li, Xiang and Kahatapitiya, Kumara and Lee, Yu-Cheol and Ryoo, Michael S.},
journal={IEEE Transactions on Pattern Analysis and Machine Intelligence},
title={StARformer: Transformer with State-Action-Reward Representations for Robot Learning},
year={2022},
pages={1-16},
doi={10.1109/TPAMI.2022.3204708}
}
- Apr 6, 2023:
  - fix bug in run_star_atari.py
  - fix conda env
  - provide GPU usage reference
- Nov 26, 2022:
  - update code for DMC environments
  - clean conda env file
This code is based on Decision-Transformer.