Open-source code for the paper CDT: Cascading Decision Trees for Explainable Reinforcement Learning.
Data folder: the data folders (data and data_no_norm) must be placed at the root of the repo to run the code; see issue #4. The data folders are stored on Google Drive.
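Because the scripts expect both folders at the repo root, a quick check before running anything can save a failed run. The following is a minimal sketch, not part of the repository; the helper name `missing_data_dirs` is hypothetical, and only the folder names `data` and `data_no_norm` come from the text above.

```python
from pathlib import Path

# Folder names taken from this README; everything else here is illustrative.
REQUIRED_DATA_DIRS = ("data", "data_no_norm")


def missing_data_dirs(repo_root="."):
    """Return the required data folders that are absent under repo_root."""
    root = Path(repo_root)
    return [name for name in REQUIRED_DATA_DIRS if not (root / name).is_dir()]
```

Calling `missing_data_dirs(".")` from the repo root returns an empty list when both folders are in place.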
- data: all data for experiments (not maintained in the repo, but can be collected with the given scripts below)
- mlp: data for MLP model;
- cdt: data for CDT model;
- sdt: data for SDT model;
- il: data for general Imitation Learning (IL);
- rl: data for general Reinforcement Learning (RL);
- cdt_compare_depth: data for cdt with different depths in RL;
- sdt_compare_depth: data for sdt with different depths in RL;
- src: source code
- mlp: training configurations for MLP as policy function approximator;
- cdt: the Cascading Decision Tree (CDT) class and necessary functions;
- sdt: the Soft Decision Tree (SDT) class and necessary functions;
- hdt: the heuristic agents;
- il: configurations for Imitation Learning (IL);
- rl: configurations for Reinforcement Learning (RL) and RL agents (e.g., PPO) etc;
- utils: some common functions
: collect dataset (state-action from heuristic or well-trained policy) for IL;
: collect dataset (states during training for calculating normalization statistics) for RL;
: train IL agent with different function approximators (e.g., SDT, CDT);
: train RL agent with different function approximators (e.g., SDT, CDT, MLP);
: evaluate the trained IL agents before and after tree discretization, based on prediction accuracy;
: evaluate the trained RL agents before and after tree discretization, based on episodic reward;
: bash to run IL test with different models on server;
: bash to run RL test with different models on server;
: train RL agent with SDT;
: train RL agent with CDT;
: bash to run RL test with SDT of different depths on server;
: bash to run RL test with CDT of different depths on server;
- visual
: plot learning curves, etc.
params.ipynb: quantitative analysis of model parameters (SDT and CDT).
stability_analysis.ipynb: the stability analysis from the paper; compares the tree weights.
To fully replicate the experiments in the paper, the code needs to be run in several stages.
Collect dataset: for state normalization
cd ./src
python
Get statistics on dataset
cd rl
jupyter notebook
and run the cells in it to generate the dataset statistics files. Steps 1 and 2 can be skipped if state normalization is not used.
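For orientation, state normalization statistics are typically a per-dimension mean and standard deviation computed over the collected states. The sketch below assumes exactly that; the function names, the `eps` constant, and the array layout (one state per row) are assumptions for illustration and may not match what the notebook produces.

```python
import numpy as np


def compute_state_stats(states, eps=1e-8):
    """Per-dimension mean and std over collected states (one state per row)."""
    states = np.asarray(states, dtype=np.float64)
    mean = states.mean(axis=0)
    # eps keeps the later division safe when a dimension never varies.
    std = states.std(axis=0) + eps
    return mean, std


def normalize_state(state, mean, std):
    """Standardize a single state with the precomputed statistics."""
    return (np.asarray(state, dtype=np.float64) - mean) / std
```

The statistics are computed once from the collected dataset and then applied to every state the policy sees during training and evaluation.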
Train RL agents with different policy function approximators: SDT, CDT, MLP
cd ..
python --train --env='CartPole-v1' --method='sdt' --id=0
python --train --env='LunarLander-v2' --method='cdt' --id=0
python --train --env='MountainCar-v0' --method='mlp' --id=0
or simply run with:
Evaluate the trained agents (with discretization operation)
python --env='CartPole-v1' --method='sdt'
python --env='LunarLander-v2' --method='cdt'
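To make the discretization step concrete, here is a minimal sketch of one common way to turn a soft decision node into a hard, axis-aligned one: keep only the feature with the largest absolute weight and replace the sigmoid with a threshold test. This is an illustration under stated assumptions, not the repository's implementation; the function names and the convention that a high sigmoid output means "go left" are hypothetical.

```python
import math


def soft_left_prob(weights, bias, state):
    """Soft node: probability of routing left, sigmoid(w . x + b)."""
    logit = sum(w * s for w, s in zip(weights, state)) + bias
    return 1.0 / (1.0 + math.exp(-logit))


def discretize_node(weights, bias):
    """Reduce a soft node to a single-feature threshold test.

    Returns (feature_index, threshold, left_if_greater).
    """
    k = max(range(len(weights)), key=lambda i: abs(weights[i]))
    if weights[k] == 0.0:
        raise ValueError("all weights are zero; no feature to split on")
    # w_k * x_k + b > 0 is x_k > -b/w_k when w_k > 0, and x_k < -b/w_k otherwise.
    threshold = -bias / weights[k]
    return k, threshold, weights[k] > 0


def hard_goes_left(k, threshold, left_if_greater, state):
    """Hard node: deterministic routing on one feature."""
    if left_if_greater:
        return state[k] > threshold
    return state[k] < threshold
```

Evaluating before and after discretization then amounts to running the same episodes with soft routing and with hard routing and comparing the episodic rewards.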
Results visualization
cd ../visual
jupyter notebook
see in
Collect dataset: used (1) for state normalization and (2) as the imitation learning dataset
cd ./src
python
Train IL agents with different policy function approximators: SDT, CDT
python --train --env='CartPole-v1' --method='sdt' --id=0
python --train --env='LunarLander-v2' --method='cdt' --id=0
or simply run with:
Evaluate the trained agents
python --env='CartPole-v1' --method='sdt'
python --env='LunarLander-v2' --method='cdt'
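IL agents are evaluated by prediction accuracy, that is, how often the agent's action matches the action in the dataset for the same state. A minimal sketch of that metric for discrete actions follows; the function name `action_accuracy` is hypothetical and not taken from the repository.

```python
def action_accuracy(predicted_actions, expert_actions):
    """Fraction of states where the predicted action equals the expert action."""
    if len(predicted_actions) != len(expert_actions):
        raise ValueError("prediction and label sequences differ in length")
    if not expert_actions:
        raise ValueError("cannot compute accuracy on an empty dataset")
    matches = sum(p == e for p, e in zip(predicted_actions, expert_actions))
    return matches / len(expert_actions)
```

Computing this once with the soft tree and once with the discretized tree gives the before/after comparison.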
Results visualization
cd ../visual
jupyter notebook
see in
The DAGGER and Q-DAGGER methods from VIPER are also compared in the paper under the imitation learning setting. The code is in ./src/viper/. Credit goes to Hangrui (Henry) Bi.
Run the comparison with different tree depths:
For SDT:
For CDT:
Compare the tree weights of different agents in IL:
cd ./visual
jupyter notebook
See in stability_analysis.ipynb
Quantitative analysis of number of model parameters:
cd ./visual
jupyter notebook
See in params.ipynb
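As a rough guide to what such a parameter count involves, the sketch below counts parameters for a plain soft decision tree and a fully connected MLP. It assumes a complete binary SDT in which every inner node holds one weight per input feature plus a bias and every leaf holds one value per output, with no extra per-node parameters; the repository's SDT and CDT classes may count differently, and both function names are hypothetical.

```python
def sdt_param_count(input_dim, output_dim, depth):
    """Parameters of a complete binary soft decision tree of the given depth."""
    inner_nodes = 2 ** depth - 1          # each: input_dim weights + 1 bias
    leaves = 2 ** depth                   # each: output_dim values
    return inner_nodes * (input_dim + 1) + leaves * output_dim


def mlp_param_count(layer_sizes):
    """Parameters of a fully connected MLP, e.g. layer_sizes=[4, 16, 2]."""
    return sum(
        n_in * n_out + n_out              # weight matrix + bias vector
        for n_in, n_out in zip(layer_sizes, layer_sizes[1:])
    )
```

For a CartPole-sized problem (4 inputs, 2 actions), `sdt_param_count(4, 2, 2)` gives 23 and `mlp_param_count([4, 16, 2])` gives 114.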
title={Cdt: Cascading decision trees for explainable reinforcement learning},
author={Ding, Zihan and Hernandez-Leal, Pablo and Ding, Gavin Weiguang and Li, Changjian and Huang, Ruitong},
journal={arXiv preprint arXiv:2011.07553},