Ayush Jain*, Andrew Szot*, Joseph J. Lim at USC CLVR lab
[Paper website]
The structure of the repository:
- analysis: Scripts used for analysis figures and experiments.
- envs: The four subfolders in this folder contain the four environments.
- method: Implementation of all methods and baselines.
- rlf: Reinforcement Learning Framework. General RL / PPO training code.
- scripts: Miscellaneous scripts, including the script for generating the train / test action set splits.
- main.py: Entry point for running the policy.
- embedder.py: Entry point for training the embedder.
Log directories:
- data/trained_model/ENV-NAME_PREFIX/: Trained models.
- data/vids/ENV-NAME/: Evaluation videos.
- data/logs/ENV-NAME/PREFIX/: Tensorboard summary.
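The three log locations follow a fixed pattern. A minimal sketch for deriving them for one run, assuming ENV-NAME and PREFIX are replaced verbatim by the values passed to --env-name and --prefix; log_dirs is a hypothetical helper, not part of the repository.

```python
from pathlib import PurePosixPath
from typing import Dict


def log_dirs(env_name: str, prefix: str) -> Dict[str, str]:
    """Return where one run writes its outputs (hypothetical helper).

    Assumes ENV-NAME and PREFIX are substituted verbatim by the values
    given to --env-name and --prefix.
    """
    data = PurePosixPath("data")
    return {
        # data/trained_model/ENV-NAME_PREFIX/
        "models": str(data / "trained_model" / f"{env_name}_{prefix}"),
        # data/vids/ENV-NAME/
        "videos": str(data / "vids" / env_name),
        # data/logs/ENV-NAME/PREFIX/
        "tensorboard": str(data / "logs" / env_name / prefix),
    }


if __name__ == "__main__":
    for kind, path in log_dirs("CreateLevelPush-v0", "main").items():
        print(f"{kind}: {path}")
```

Under that assumption, tensorboard --logdir data/logs/CreateLevelPush-v0/main would show the curves of the run launched with --prefix main.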
- Python 3.7
- MuJoCo 2.0
All the Python package requirements are in requirements.txt. If you are using conda, you can use the following commands with Python 3.7.3:
conda create -n [your_name] python=3.7
source activate [your_name]
pip install -r requirements.txt
The experiment flow is similar for every environment and consists of the following steps:
- Generate train and test action splits:
python gen_action_sets.py --env-name $ENV_NAME
- Generate action datasets for the environment:
python embedder.py --env-name $PLAY_ENV_NAME --save-dataset
- Train the action embedder model:
python embedder.py --env-name $PLAY_ENV_NAME --save-emb-model-file $EMB_FILE_NAME --train-embeddings
- Generate embedding files:
python main.py --env-name $ENV_NAME --play-env-name $PLAY_ENV_NAME --load-emb-model-file $EMB_MODEL_NAME --save-embeddings-file $EMB_FILE_NAME --prefix main
- Train the policy with saved embeddings:
python main.py --env-name $ENV_NAME --load-embeddings-file $EMB_FILE_NAME
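The five steps above can be chained from one place. The sketch below is a hypothetical driver, not part of the repository: it reuses the flags listed above verbatim, assumes each script is reachable from the working directory exactly as written in those commands, and appends --prefix main to the last step as the per-environment examples do. By default it only prints the commands.

```python
import shlex
import subprocess
from typing import List


def pipeline_commands(env_name: str, play_env_name: str, emb_file_name: str,
                      emb_model_name: str, prefix: str = "main") -> List[List[str]]:
    """Return the five experiment steps as argv lists, in order."""
    return [
        # 1. Train / test action splits
        ["python", "gen_action_sets.py", "--env-name", env_name],
        # 2. Action dataset collected in the play environment
        ["python", "embedder.py", "--env-name", play_env_name, "--save-dataset"],
        # 3. Action embedder training
        ["python", "embedder.py", "--env-name", play_env_name,
         "--save-emb-model-file", emb_file_name, "--train-embeddings"],
        # 4. Embedding files exported with the trained embedder
        ["python", "main.py", "--env-name", env_name, "--play-env-name", play_env_name,
         "--load-emb-model-file", emb_model_name,
         "--save-embeddings-file", emb_file_name, "--prefix", prefix],
        # 5. Policy training on the saved embeddings
        ["python", "main.py", "--env-name", env_name,
         "--load-embeddings-file", emb_file_name, "--prefix", prefix],
    ]


def run_pipeline(commands: List[List[str]], dry_run: bool = True) -> None:
    """Print each step; execute it only when dry_run is False."""
    for argv in commands:
        print(" ".join(shlex.quote(arg) for arg in argv))
        if not dry_run:
            # check=True stops the pipeline at the first failing step
            subprocess.run(argv, check=True)


if __name__ == "__main__":
    run_pipeline(pipeline_commands("CreateLevelPush-v0", "StateCreateGameN1PlayNew-v0",
                                   "create_st", "create_st-htvae-5000.m"))
```

Keeping each step as an argv list avoids shell quoting issues; the printed form can be pasted into a terminal to run a single step by hand.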
Note:
(1) $EMB_MODEL_NAME must be $EMB_FILE_NAME-htvae-500.m if your model is trained for at least 500 epochs (specified by --emb-epochs).
(2) Use --n-trajectories 64 and --emb-epochs 500 for faster data generation and embedder training.
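Note (1) ties the embedder checkpoint name to an epoch count. A minimal sketch of that naming rule, assuming checkpoints are named EMB_FILE_NAME-htvae-EPOCH.m as in the note and in the 5000-epoch examples below; emb_model_file and the FAST_* constants are hypothetical names, and which command each fast flag belongs to is an assumption.

```python
from typing import List


def emb_model_file(emb_file_name: str, checkpoint_epoch: int, emb_epochs: int) -> str:
    """Name to pass as --load-emb-model-file (hypothetical helper).

    Assumes the rule in note (1): an embedder saved with
    --save-emb-model-file EMB_FILE_NAME and trained for at least
    checkpoint_epoch epochs (--emb-epochs) provides
    EMB_FILE_NAME-htvae-<checkpoint_epoch>.m.
    """
    if emb_epochs < checkpoint_epoch:
        raise ValueError(
            f"--emb-epochs {emb_epochs} is below {checkpoint_epoch}; "
            "that checkpoint is probably missing")
    return f"{emb_file_name}-htvae-{checkpoint_epoch}.m"


# Faster settings from note (2). Assumption: --n-trajectories goes to the
# data generation command and --emb-epochs to the embedder training command.
FAST_DATASET_ARGS: List[str] = ["--n-trajectories", "64"]
FAST_TRAINING_ARGS: List[str] = ["--emb-epochs", "500"]

if __name__ == "__main__":
    print(emb_model_file("create_st", 500, emb_epochs=500))    # fast setting
    print(emb_model_file("create_st", 5000, emb_epochs=5000))  # name used in the examples below
```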
Below are example commands for each environment and method.
CREATE
$ENV_NAME = 'CreateLevelPush-v0' or 'CreateLevelNavigate-v0' or 'CreateLevelObstacle-v0'.
$PLAY_ENV_NAME = 'StateCreateGameN1PlayNew-v0' (state-based) or 'CreateGamePlay-v0' (video-based).
$EMB_FILE_NAME = 'create_st' (state-based) or 'create_im' (video-based).
(1) Train policy directly with:
python main.py --env-name CreateLevelPush-v0 --prefix main
python main.py --env-name CreateLevelNavigate-v0 --prefix main
python main.py --env-name CreateLevelObstacle-v0 --prefix main
OR
(2) For full procedure, follow these commands:
- Generate Splits:
python gen_action_sets.py --env-name CreateLevelPush-v0
- Generate Data:
python embedder.py --env-name StateCreateGameN1PlayNew-v0 --save-dataset
- Train Action Embedder:
python embedder.py --env-name StateCreateGameN1PlayNew-v0 --save-emb-model-file create_st --train-embeddings
- Generate embedding files:
python main.py --env-name CreateLevelPush-v0 --play-env-name StateCreateGameN1PlayNew-v0 --load-emb-model-file create_st-htvae-5000.m --save-embeddings-file create_st --prefix main
- Train policy with saved embeddings:
python main.py --env-name CreateLevelPush-v0 --load-embeddings-file create_st --prefix main
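The CREATE choices above reduce to two small tables: three tasks, and one (play environment, embedding file) pair per observation type. The sketch below encodes them; the dictionary and function names are hypothetical, and only the string values come from this README.

```python
from typing import Dict, List, Tuple

# The three CREATE tasks accepted as $ENV_NAME
CREATE_TASKS: List[str] = [
    "CreateLevelPush-v0",
    "CreateLevelNavigate-v0",
    "CreateLevelObstacle-v0",
]

# ($PLAY_ENV_NAME, $EMB_FILE_NAME) for each kind of action observation
CREATE_PLAY: Dict[str, Tuple[str, str]] = {
    "state": ("StateCreateGameN1PlayNew-v0", "create_st"),
    "video": ("CreateGamePlay-v0", "create_im"),
}


def create_play_settings(modality: str) -> Tuple[str, str]:
    """Return (play env, embedding file name) for 'state' or 'video'."""
    if modality not in CREATE_PLAY:
        raise ValueError(f"modality must be one of {sorted(CREATE_PLAY)}, got {modality!r}")
    return CREATE_PLAY[modality]


def direct_commands(prefix: str = "main") -> List[str]:
    """Option (1): train a policy directly on every CREATE task."""
    return [f"python main.py --env-name {task} --prefix {prefix}" for task in CREATE_TASKS]


if __name__ == "__main__":
    print(create_play_settings("state"))
    print("\n".join(direct_commands()))
```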
Recommender System
There is no data generation or embedding learning for the recommender system.
$ENV_NAME = 'RecoEnv-v0'
(1) Train policy directly with:
python main.py --env-name RecoEnv-v0 --prefix main
OR
(2) For full procedure, follow these commands:
- Generate Splits:
python gen_action_sets.py --env-name RecoEnv-v0
- Train policy:
python main.py --env-name RecoEnv-v0 --prefix main
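Because the recommender system skips data generation and embedder training, only two steps of the general flow remain. A sketch of that per-environment step selection, assuming the step labels used here (splits, dataset, embedder, export, policy), which are illustrative rather than repository identifiers:

```python
from typing import Dict, List

# Hypothetical labels for the five steps of the general experiment flow
ALL_STEPS: List[str] = ["splits", "dataset", "embedder", "export", "policy"]

# Environments without data generation or embedding learning
NO_EMBEDDING_ENVS = {"RecoEnv-v0"}

# The two recommender commands from option (2) above
RECO_COMMANDS: Dict[str, str] = {
    "splits": "python gen_action_sets.py --env-name RecoEnv-v0",
    "policy": "python main.py --env-name RecoEnv-v0 --prefix main",
}


def steps_for(env_name: str) -> List[str]:
    """Return which steps of the general flow apply to env_name."""
    if env_name in NO_EMBEDDING_ENVS:
        return ["splits", "policy"]
    return list(ALL_STEPS)


if __name__ == "__main__":
    for step in steps_for("RecoEnv-v0"):
        print(f"{step}: {RECO_COMMANDS[step]}")
```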
Shape Stacking
$ENV_NAME = 'StackEnv-v0'
$PLAY_ENV_NAME = 'BlockPlayImg-v0'
$EMB_FILE_NAME = 'stack_im'
(1) Train policy directly with:
python main.py --env-name StackEnv-v0 --prefix main
OR
(2) For full procedure, follow these commands:
- Generate Splits:
python gen_action_sets.py --env-name StackEnv-v0
- Generate Data:
python embedder.py --env-name BlockPlayImg-v0 --save-dataset
- Train Action Embedder:
python embedder.py --env-name BlockPlayImg-v0 --save-emb-model-file stack_im --train-embeddings
- Generate embedding files:
python main.py --env-name StackEnv-v0 --play-env-name BlockPlayImg-v0 --load-emb-model-file stack_im-htvae-5000.m --save-embeddings-file stack_im --prefix main
- Train policy with saved embeddings:
python main.py --env-name StackEnv-v0 --load-embeddings-file stack_im --prefix main
Grid World
$ENV_NAME = 'MiniGrid-LavaCrossingS9N1-v0'
$PLAY_ENV_NAME = 'MiniGrid-Empty-Random-80x80-v0'
$EMB_FILE_NAME = 'gw_onehot_new'
(1) Train policy directly with:
python main.py --env-name MiniGrid-LavaCrossingS9N1-v0 --prefix main
OR
(2) For full procedure, follow these commands:
- Generate Splits:
python gen_action_sets.py --env-name MiniGrid-LavaCrossingS9N1-v0
- Generate Data:
python embedder.py --env-name MiniGrid-Empty-Random-80x80-v0 --save-dataset
- Train Action Embedder:
python embedder.py --env-name MiniGrid-Empty-Random-80x80-v0 --save-emb-model-file gw_onehot_new --train-embeddings
- Generate embedding files:
python main.py --env-name MiniGrid-LavaCrossingS9N1-v0 --play-env-name MiniGrid-Empty-Random-80x80-v0 --load-emb-model-file gw_onehot_new-htvae-5000.m --save-embeddings-file gw_onehot_new --prefix main
- Train policy with saved embeddings:
python main.py --env-name MiniGrid-LavaCrossingS9N1-v0 --load-embeddings-file gw_onehot_new --prefix main
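The full procedures for CREATE, Shape Stacking and Grid World differ only in three names. A sketch of a small registry that renders the last two steps from those names; EmbeddingSetup, SETUPS and the render functions are hypothetical, and the 5000-epoch checkpoint default mirrors the examples above.

```python
from typing import Dict, NamedTuple


class EmbeddingSetup(NamedTuple):
    env_name: str       # $ENV_NAME
    play_env_name: str  # $PLAY_ENV_NAME
    emb_file_name: str  # $EMB_FILE_NAME


# Values copied from the CREATE (state-based Push), Shape Stacking and Grid World sections
SETUPS: Dict[str, EmbeddingSetup] = {
    "create": EmbeddingSetup("CreateLevelPush-v0", "StateCreateGameN1PlayNew-v0", "create_st"),
    "stack": EmbeddingSetup("StackEnv-v0", "BlockPlayImg-v0", "stack_im"),
    "gridworld": EmbeddingSetup("MiniGrid-LavaCrossingS9N1-v0",
                                "MiniGrid-Empty-Random-80x80-v0", "gw_onehot_new"),
}


def export_command(s: EmbeddingSetup, checkpoint_epoch: int = 5000, prefix: str = "main") -> str:
    """'Generate embedding files' step for one setup."""
    return (f"python main.py --env-name {s.env_name} --play-env-name {s.play_env_name} "
            f"--load-emb-model-file {s.emb_file_name}-htvae-{checkpoint_epoch}.m "
            f"--save-embeddings-file {s.emb_file_name} --prefix {prefix}")


def policy_command(s: EmbeddingSetup, prefix: str = "main") -> str:
    """'Train policy with saved embeddings' step for one setup."""
    return (f"python main.py --env-name {s.env_name} "
            f"--load-embeddings-file {s.emb_file_name} --prefix {prefix}")


if __name__ == "__main__":
    for name, setup in SETUPS.items():
        print(name)
        print("  " + export_command(setup))
        print("  " + policy_command(setup))
```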
To run the baselines for any environment, add the following arguments to the main.py command:
Baselines
- Nearest-Neighbor (NN): --nearest-neighbor --fixed-action-set --action-random-sample False --prefix NN
- Distance-based Policy Architecture (Dist): --distance-based --prefix dist
- Non-hierarchical embeddings (VAE): --load-embeddings-file $FILE --prefix vae, where the embeddings file $FILE depends on the environment:
  - CREATE: create_fc_st_vae
  - Shape Stacking: stack_vae
  - Grid World: gw_onehot_vae
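Adding a baseline means appending its arguments to an existing main.py command. A sketch of that composition; baseline_command and drop_option are hypothetical helpers, and two behaviours are assumptions rather than documented repository behaviour: any --prefix already in the command is dropped so each baseline logs under its own prefix, and for VAE the listed file replaces any embeddings file already in the command.

```python
import shlex
from typing import Dict, List

# Extra arguments per baseline, copied from the list above
BASELINE_ARGS: Dict[str, List[str]] = {
    "NN": ["--nearest-neighbor", "--fixed-action-set", "--action-random-sample", "False",
           "--prefix", "NN"],
    "dist": ["--distance-based", "--prefix", "dist"],
}

# Environment-dependent $FILE for the non-hierarchical VAE baseline
VAE_FILES: Dict[str, str] = {
    "create": "create_fc_st_vae",
    "stack": "stack_vae",
    "gridworld": "gw_onehot_vae",
}


def drop_option(argv: List[str], name: str) -> List[str]:
    """Remove every `name value` pair from argv."""
    out: List[str] = []
    skip_next = False
    for tok in argv:
        if skip_next:
            skip_next = False
        elif tok == name:
            skip_next = True
        else:
            out.append(tok)
    return out


def baseline_command(main_command: str, baseline: str, env_family: str = "") -> List[str]:
    """Append a baseline's arguments to a main.py command (assumed behaviour, see above)."""
    argv = drop_option(shlex.split(main_command), "--prefix")
    if baseline == "vae":
        if env_family not in VAE_FILES:
            raise ValueError(f"no VAE embeddings listed for {env_family!r}")
        # Assumption: the VAE embeddings replace any embeddings file already given
        argv = drop_option(argv, "--load-embeddings-file")
        return argv + ["--load-embeddings-file", VAE_FILES[env_family], "--prefix", "vae"]
    return argv + BASELINE_ARGS[baseline]
```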
Ablations
- Fixed Action Space (FX):
--fixed-action-set --action-random-sample False --prefix FX
- Random-Sampling without clustering (RS):
--sample-clusters False --prefix RS
- No-entropy (NE):
--entropy-coef 0. --prefix NE
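Each ablation carries its own --prefix, so, assuming the data/logs/ENV-NAME/PREFIX/ layout from the log directories list, the three runs write separate TensorBoard summaries. A sketch that maps each ablation's log directory to its command; ablation_runs is hypothetical, and passing the embeddings file is optional here because the README does not spell out the full main command.

```python
from typing import Dict, List, Optional

# Ablation arguments from the list above, keyed by their --prefix
ABLATIONS: Dict[str, List[str]] = {
    "FX": ["--fixed-action-set", "--action-random-sample", "False"],
    "RS": ["--sample-clusters", "False"],
    "NE": ["--entropy-coef", "0."],
}


def ablation_runs(env_name: str, emb_file: Optional[str] = None) -> Dict[str, List[str]]:
    """Map each run's assumed TensorBoard directory to its main.py argv."""
    base = ["python", "main.py", "--env-name", env_name]
    if emb_file is not None:
        base += ["--load-embeddings-file", emb_file]
    runs: Dict[str, List[str]] = {}
    for prefix, flags in ABLATIONS.items():
        # Assumed layout: data/logs/ENV-NAME/PREFIX/
        runs[f"data/logs/{env_name}/{prefix}/"] = base + flags + ["--prefix", prefix]
    return runs


if __name__ == "__main__":
    for log_dir, argv in ablation_runs("StackEnv-v0", "stack_im").items():
        print(log_dir, " ".join(argv))
```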
Other embedding data formats
- CREATE: Video-based embeddings:
--load-embeddings-file create_fc_im --o-dim 128 --z-dim 128 --prefix im
- Grid World: (x,y) coordinate state-based embeddings:
--load-embeddings-file gw_st --prefix st
Ground-truth embeddings
- For CREATE and Grid World: --gt-embs --prefix GT
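The alternative embedding variants are each listed for specific environments only. A sketch that returns a variant's arguments and refuses combinations this README does not list; VARIANTS and variant_args are hypothetical names, and whether unlisted combinations would actually work is not stated here.

```python
from typing import Dict, List, Set, Tuple

# Variant prefix -> (arguments, environment families it is listed for above)
VARIANTS: Dict[str, Tuple[List[str], Set[str]]] = {
    "im": (["--load-embeddings-file", "create_fc_im", "--o-dim", "128", "--z-dim", "128"],
           {"create"}),
    "st": (["--load-embeddings-file", "gw_st"], {"gridworld"}),
    "GT": (["--gt-embs"], {"create", "gridworld"}),
}


def variant_args(variant: str, env_family: str) -> List[str]:
    """Arguments for an alternative embedding variant, refusing unlisted pairs."""
    args, families = VARIANTS[variant]
    if env_family not in families:
        raise ValueError(f"{variant!r} is only listed for {sorted(families)}")
    return args + ["--prefix", variant]
```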
To run the three analysis scripts, simply run analysis/analysis_dist.py, analysis/analysis_emb.py, and analysis/analysis_ratio.py.
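A small runner for the three analysis scripts, assuming they take no required arguments and are started from the repository root (neither is stated above). run_analysis is hypothetical and only prints the commands unless dry_run is False.

```python
import subprocess
import sys
from typing import List

ANALYSIS_SCRIPTS: List[str] = [
    "analysis/analysis_dist.py",
    "analysis/analysis_emb.py",
    "analysis/analysis_ratio.py",
]


def run_analysis(dry_run: bool = True) -> List[List[str]]:
    """Build (and optionally run) one command per analysis script."""
    # sys.executable reuses the interpreter of the active conda environment
    commands = [[sys.executable, script] for script in ANALYSIS_SCRIPTS]
    for argv in commands:
        print(" ".join(argv))
        if not dry_run:
            subprocess.run(argv, check=True)
    return commands


if __name__ == "__main__":
    run_analysis()
```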
- PPO code is based on the PyTorch implementation of PPO by Ilya Kostrikov
- The Grid world environment is from https://github.com/maximecb/gym-minigrid
- The recommender systems environment is from https://github.com/criteo-research/reco-gym