
VIPER_RL-torch

PyTorch implementation of Video Prediction Models as Rewards for Reinforcement Learning. VIPER leverages the next-frame log-likelihoods of a pre-trained video prediction model as rewards for downstream reinforcement learning tasks. The method is agnostic to the particular choice of video prediction model and reinforcement learning algorithm.
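As a minimal sketch of the idea (assuming a hypothetical video_model object exposing a log_prob method; this is not necessarily this repo's exact API), the reward for each transition is the frozen video model's log-likelihood of the observed next frame:

import torch

@torch.no_grad()
def viper_reward(video_model, frames):
    # frames: (batch, time, channels, height, width), preprocessed the same
    # way as the video model's training data.
    # log_prob is assumed to return per-frame log p(x_t | x_<t) with
    # shape (batch, time).
    log_likelihoods = video_model.log_prob(frames)
    return log_likelihoods  # passed to the downstream RL agent as rewards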

Installation

Create a conda environment with Python 3.10:

conda create -n viper python=3.10
conda activate viper

Install dependencies:

pip install -r requirements.txt

Downloading Data

Download the DeepMind Control Suite expert dataset with the following command:

python -m viper_rl_data.download dataset dmc

and the Atari dataset with:

python -m viper_rl_data.download dataset atari

This will produce datasets in <VIPER_INSTALL_PATH>/viper_rl_data/datasets/, which are used for training the video prediction model. The location of the datasets can be retrieved via the viper_rl_data.VIPER_DATASET_PATH variable.
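For example:

import viper_rl_data
print(viper_rl_data.VIPER_DATASET_PATH)  # e.g. <VIPER_INSTALL_PATH>/viper_rl_data/datasets/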

Video Model Training

First, train a VQ-GAN with the following command:

python scripts/train_vqgan.py -o viper_rl_data/checkpoints/dmc_vqgan -c viper_rl/configs/vqgan/dmc.yaml

To train the VideoGPT, update ae_ckpt in viper_rl/configs/videogpt/dmc.yaml to point to the VQ-GAN checkpoint from the previous step (e.g. viper_rl_data/checkpoints/dmc_vqgan), and then run:

python scripts/train_videogpt.py -o viper_rl_data/checkpoints/dmc_videogpt_l16_s1 -c viper_rl/configs/videogpt/dmc.yaml

Policy Training

Train a Dreamer policy with the video prediction model providing rewards:

python scripts/train_dreamer.py --configs=dmc_vision videogpt_prior_rb --task=dmc_walker_walk --reward_model=dmc_clen16_fskip1 --logdir=./logdir

Custom checkpoint directories can be specified with the $VIPER_CHECKPOINT_DIR environment variable. The default checkpoint path is set to viper_rl_data/checkpoints/.
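For example, to store checkpoints under a custom directory (the path shown is illustrative):

export VIPER_CHECKPOINT_DIR=/path/to/checkpoints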

Acknowledgments

This code is heavily inspired by the original JAX implementation of VIPER.
