Open-source code for the paper CDT: Cascading Decision Trees for Explainable Reinforcement Learning (https://arxiv.org/abs/2011.07553).
Data folders: the data folders (`data` and `data_no_norm`) should be placed at the root of the repo to run the code (see issue #4). The data folders are stored on Google Drive.
- `data`: all data for experiments (not maintained in the repo, but can be collected with the scripts described below)
  - `mlp`: data for the MLP model;
  - `cdt`: data for the CDT model;
  - `sdt`: data for the SDT model;
  - `il`: data for general Imitation Learning (IL);
  - `rl`: data for general Reinforcement Learning (RL);
  - `cdt_compare_depth`: data for CDT with different depths in RL;
  - `sdt_compare_depth`: data for SDT with different depths in RL;
- `src`: source code
  - `mlp`: training configurations for MLP as the policy function approximator;
  - `cdt`: the Cascading Decision Tree (CDT) class and necessary functions;
  - `sdt`: the Soft Decision Tree (SDT) class and necessary functions (a minimal forward-pass sketch follows this listing);
  - `hdt`: the heuristic agents;
  - `il`: configurations for Imitation Learning (IL);
  - `rl`: configurations for Reinforcement Learning (RL) and the RL agents (e.g., PPO), etc.;
  - `utils`: some common functions;
  - `il_data_collect.py`: collect a dataset (state-action pairs from a heuristic or well-trained policy) for IL;
  - `rl_data_collect.py`: collect a dataset (states encountered during training, for calculating normalization statistics) for RL;
  - `il_train.py`: train an IL agent with different function approximators (e.g., SDT, CDT);
  - `rl_train.py`: train an RL agent with different function approximators (e.g., SDT, CDT, MLP);
  - `il_eval.py`: evaluate the trained IL agents before and after tree discretization, based on prediction accuracy;
  - `rl_eval.py`: evaluate the trained RL agents before and after tree discretization, based on episodic reward;
  - `il_train.sh`: bash script to run the IL tests with different models on a server;
  - `rl_train.sh`: bash script to run the RL tests with different models on a server;
  - `rl_train_compare_sdt.py`: train an RL agent with SDT;
  - `rl_train_compare_cdt.py`: train an RL agent with CDT;
  - `rl_train_compare_sdt.sh`: bash script to run the RL tests with SDTs of different depths on a server;
  - `rl_train_compare_cdt.sh`: bash script to run the RL tests with CDTs of different depths on a server;
- `visual`:
  - `plot.ipynb`: plot learning curves, etc.;
  - `params.ipynb`: quantitative analysis of model parameters (SDT and CDT);
  - `stability_analysis.ipynb`: the stability analysis in the paper, comparing tree weights.
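For intuition about what lives in `src/sdt/` (a CDT in `src/cdt/` cascades a feature-learning tree into a decision-making tree of the same flavor), here is a minimal, self-contained sketch of a soft-decision-tree forward pass. It is an illustration only, assuming one learned hyperplane per inner node and one action distribution per leaf; the class and attribute names are invented and do not match the repo's `SDT` class.

```python
import torch
import torch.nn as nn

class TinySDT(nn.Module):
    """Toy soft decision tree: sigmoid routing at inner nodes,
    a learned action distribution at each leaf."""
    def __init__(self, input_dim, output_dim, depth=2):
        super().__init__()
        self.depth = depth
        n_inner, n_leaf = 2 ** depth - 1, 2 ** depth
        self.inner = nn.Linear(input_dim, n_inner)            # one hyperplane per inner node
        self.leaf_logits = nn.Parameter(torch.randn(n_leaf, output_dim))

    def forward(self, x):
        p_right = torch.sigmoid(self.inner(x))                # (batch, n_inner): P(go right)
        path = x.new_ones(x.shape[0], 1)                      # P(reach root) = 1
        for d in range(self.depth):
            nodes = p_right[:, 2 ** d - 1: 2 ** (d + 1) - 1]  # routing probs at depth d
            path = torch.stack([path * (1 - nodes), path * nodes], dim=-1).flatten(1)
        # output = mixture of leaf action distributions, weighted by path probabilities
        return path @ torch.softmax(self.leaf_logits, dim=-1)

policy = TinySDT(input_dim=4, output_dim=2)                   # e.g. CartPole-v1 sizes
print(policy(torch.randn(5, 4)).shape)                        # torch.Size([5, 2])
```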
To fully replicate the experiments in the paper, the code needs to be run in several stages. First, the reinforcement learning experiments:
- Collect a dataset for state normalization:

  ```bash
  cd ./src
  python rl_data_collect.py
  ```
- Get statistics of the dataset:

  ```bash
  cd rl
  jupyter notebook
  ```

  Open `stats.ipynb` and run the cells in it to generate the dataset-statistics files (a minimal equivalent is sketched below). The first two steps can be skipped if not using state normalization.
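  In essence, the notebook computes per-dimension normalization statistics over the collected states. A minimal sketch, assuming the states were dumped as a NumPy array; the file names here are placeholders, not the notebook's actual inputs and outputs:

  ```python
  import numpy as np

  states = np.load('rl_states.npy')        # hypothetical dump from rl_data_collect.py
  mean = states.mean(axis=0)
  std = states.std(axis=0) + 1e-8          # epsilon guards against zero-variance dims
  np.save('state_mean.npy', mean)
  np.save('state_std.npy', std)
  # at train/eval time, each state is normalized as (s - mean) / std
  ```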
- Train RL agents with different policy function approximators (SDT, CDT, MLP):

  ```bash
  cd ..
  python rl_train.py --train --env='CartPole-v1' --method='sdt' --id=0
  python rl_train.py --train --env='LunarLander-v2' --method='cdt' --id=0
  python rl_train.py --train --env='MountainCar-v0' --method='mlp' --id=0
  ```

  or simply run:

  ```bash
  ./rl_train.sh
  ```
- Evaluate the trained agents (with the discretization operation; a sketch of the idea follows):

  ```bash
  python rl_eval.py --env='CartPole-v1' --method='sdt'
  python rl_eval.py --env='LunarLander-v2' --method='cdt'
  ```
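  Discretization turns each soft inner node into a hard, univariate split. A minimal sketch of one such recipe, keeping only the largest-magnitude weight per node; the function name is illustrative, and the repo's actual procedure may differ in details:

  ```python
  import numpy as np

  def discretize_node(w, b):
      """Hard split approximating the soft test sigmoid(w @ x + b) > 0.5."""
      j = int(np.argmax(np.abs(w)))     # feature with the largest-magnitude weight
      threshold = -b / w[j]             # decision boundary restricted to feature j
      rule = f'x[{j}] {">" if w[j] > 0 else "<"} {threshold:.3f}'
      return j, threshold, rule

  print(discretize_node(np.array([0.1, -2.0, 0.3]), 0.5))  # (1, 0.25, 'x[1] < 0.250')
  ```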
- Results visualization:

  ```bash
  cd ../visual
  jupyter notebook
  ```

  See `plot.ipynb`.
For the imitation learning experiments:

- Collect a dataset, used both for (1) state normalization and (2) as the imitation learning dataset (a conceptual sketch follows):

  ```bash
  cd ./src
  python il_data_collect.py
  ```
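  Conceptually, the script rolls out a heuristic or well-trained policy and records the visited state-action pairs as the IL dataset. The sketch below uses the classic Gym API (of the repo's era) and a random policy as a stand-in; the actual script's policy source and output format may differ:

  ```python
  import gym
  import numpy as np

  env = gym.make('CartPole-v1')
  states, actions = [], []
  for _ in range(10):                    # a handful of episodes, for illustration
      s, done = env.reset(), False
      while not done:
          a = env.action_space.sample()  # stand-in for the heuristic/trained policy
          states.append(s)
          actions.append(a)
          s, _, done, _ = env.step(a)
  np.savez('il_dataset.npz', states=np.array(states), actions=np.array(actions))
  ```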
- Train IL agents with different policy function approximators (SDT, CDT):

  ```bash
  python il_train.py --train --env='CartPole-v1' --method='sdt' --id=0
  python il_train.py --train --env='LunarLander-v2' --method='cdt' --id=0
  ```

  or simply run:

  ```bash
  ./il_train.sh
  ```
- Evaluate the trained agents, based on prediction accuracy (sketched below):

  ```bash
  python il_eval.py --env='CartPole-v1' --method='sdt'
  python il_eval.py --env='LunarLander-v2' --method='cdt'
  ```
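  The IL criterion is simply agreement with the dataset actions, measured for the tree before and after discretization. A toy version, with illustrative names:

  ```python
  import numpy as np

  def prediction_accuracy(model_actions, dataset_actions):
      """Fraction of states where the tree picks the same action as the dataset."""
      return float(np.mean(np.asarray(model_actions) == np.asarray(dataset_actions)))

  # e.g. on the same held-out states:
  # acc_soft = prediction_accuracy(soft_tree_actions, dataset_actions)
  # acc_hard = prediction_accuracy(discretized_tree_actions, dataset_actions)
  ```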
- Results visualization:

  ```bash
  cd ../visual
  jupyter notebook
  ```

  See `plot.ipynb`.
The DAgger and Q-DAgger methods in VIPER are also compared against in the paper under the imitation learning setting; the code is in `./src/viper/`. Credit goes to Hangrui (Henry) Bi.
Run the comparison with different tree depths:

For SDT:

```bash
./rl_train_compare_sdt.sh
```

For CDT:

```bash
./rl_train_compare_cdt.sh
```
Compare the tree weights of different agents in IL:

```bash
cd ./visual
jupyter notebook
```

See `stability_analysis.ipynb`.
Quantitative analysis of the number of model parameters:

```bash
cd ./visual
jupyter notebook
```

See `params.ipynb`. A back-of-the-envelope formula for the SDT case is sketched below.
To cite the paper:

```bibtex
@article{ding2020cdt,
  title={{CDT}: Cascading Decision Trees for Explainable Reinforcement Learning},
  author={Ding, Zihan and Hernandez-Leal, Pablo and Ding, Gavin Weiguang and Li, Changjian and Huang, Ruitong},
  journal={arXiv preprint arXiv:2011.07553},
  year={2020}
}
```