Neuro-symbolic reinforcement learning (NS-RL) has emerged as a promising paradigm for explainable decision-making, characterized by the interpretability of symbolic policies. For tasks with visual observations, NS-RL relies on structured state representations, but previous methods cannot refine these structured states with rewards because doing so is too inefficient. Accessibility is also an issue, as extensive domain knowledge is required to interpret symbolic policies. In this paper, we present a framework that learns structured states and symbolic policies jointly. Its key idea is to distill vision foundation models into a scalable perception module and refine that module during policy learning. We also design a pipeline that uses large language models to generate natural-language explanations for policies and decisions. In experiments on nine Atari tasks we verify the efficacy of our approach, and we present example explanations for policies and decisions.
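The sketch below illustrates the joint objective at a high level: a small perception CNN regresses object coordinates distilled from a vision foundation model, and a policy is trained on that structured state with an extra sparsity penalty. Module names, sizes, and loss weights are illustrative assumptions, not the exact implementation in this repository.

```python
# Minimal sketch of the joint objective (hypothetical names and weights; not the
# repository's exact code). A perception CNN is kept aligned with coordinates
# distilled from a vision foundation model while the actor is trained on them.
import torch
import torch.nn as nn
import torch.nn.functional as F

class PerceptionCNN(nn.Module):
    """Maps stacked 84x84 frames to (x, y) coordinates for a fixed set of objects."""
    def __init__(self, num_objects=8):
        super().__init__()
        self.backbone = nn.Sequential(
            nn.Conv2d(4, 32, 8, stride=4), nn.ReLU(),
            nn.Conv2d(32, 64, 4, stride=2), nn.ReLU(),
            nn.Flatten(),
        )
        self.head = nn.Linear(64 * 9 * 9, num_objects * 2)  # sized for 4x84x84 inputs

    def forward(self, frames):
        return self.head(self.backbone(frames))

perception = PerceptionCNN()
# A plain MLP stands in here for the symbolic (EQL) actor used in the paper.
policy = nn.Sequential(nn.Linear(16, 64), nn.Tanh(), nn.Linear(64, 6))
optimizer = torch.optim.Adam(list(perception.parameters()) + list(policy.parameters()), lr=3e-4)

frames = torch.randn(32, 4, 84, 84)   # stacked Atari frames
fm_coords = torch.randn(32, 16)       # pseudo-labels distilled from the foundation model
advantages = torch.randn(32)          # placeholder advantage estimates
actions = torch.randint(0, 6, (32,))

coords = perception(frames)
logits = policy(coords)
policy_loss = (F.cross_entropy(logits, actions, reduction="none") * advantages).mean()
distill_loss = F.mse_loss(coords, fm_coords)                 # keep perception close to the labels
reg_loss = sum(p.abs().mean() for p in policy.parameters())  # sparsity pressure on the actor

loss = policy_loss + 1.0 * distill_loss + 1e-3 * reg_loss
optimizer.zero_grad()
loss.backward()
optimizer.step()
```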
Here are segmentation videos before and after policy learning on Freeway:
Prerequisites:
- Python==3.9.17
# core dependencies
pip install -r requirements/requirements.txt
pip install -r requirements/requirements-atari.txt
pip install torch==2.0.0+cu117 torchvision==0.15.1+cu117 torchaudio==2.0.1 --index-url https://download.pytorch.org/whl/cu117
pip install git+https://github.com/DLR-RM/stable-baselines3
# install MetaDrive from source
git clone https://github.com/metadriverse/metadrive.git
cd metadrive
pip install -e .
pip install -e .[cuda]
conda install -c nvidia cuda-python
cd ..
cd cleanrl
# install SAM-Track and download its model checkpoints
cd sam_track
bash script/install.sh
bash script/download_ckpt.sh
# install FastSAM and CLIP dependencies
cd FastSAM
pip install -r requirements.txt
pip install git+https://github.com/openai/CLIP.git
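After installation, a quick sanity check is to confirm that the CUDA build of PyTorch and the CLIP package load correctly (this assumes the CUDA 11.7 wheels installed above):

```python
# Post-install sanity check: verify that the CUDA build of torch and the CLIP
# package (installed from the OpenAI repository) import and see the GPU.
import torch
import clip

print("torch:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
print("CLIP models:", clip.available_models())
```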
To generate the dataset, run:
cd ..
python demo.py --video-name PongNoFrameskip-v4
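demo.py appears to be invoked once per game; if you want segmentation data for several games, a small wrapper such as the one below works (the game list is only an example):

```python
# Generate segmentation datasets for several games by invoking demo.py once per
# game with the same flag as above (the game list is only an example).
import subprocess

games = ["PongNoFrameskip-v4", "FreewayNoFrameskip-v4", "BreakoutNoFrameskip-v4"]
for game in games:
    subprocess.run(["python", "demo.py", "--video-name", game], check=True)
```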
To train the CNN, run:
cd ..
python train_cnn.py --wandb-project-name nsrl-eql --env-id PongNoFrameskip-v4 --run-name benchmark-pretrain-Pong-seed1 --seed 1
Alternatively, you can use a built-in dataset directly.
To train the policy, run:
python train_policy_atari.py --wandb-project-name nsrl-eql --env-id PongNoFrameskip-v4 --run-name benchmark-ng-reg-weight-1e-3-Pong-seed1 --ng True --reg_weight 1e-3 --seed 1 --load_cnn True
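The --reg_weight flag controls how strongly the symbolic actor is pushed toward sparse, readable equations. Below is a minimal sketch of an equation-learner (EQL) style layer with an L1 sparsity penalty; the operator set, layer sizes, and variable names are illustrative assumptions rather than the exact architecture used by train_policy_atari.py.

```python
# Sketch of an equation-learner (EQL) style actor: a linear map feeds a bank of
# fixed unary operators plus a product term, and an L1 penalty keeps the learned
# equations sparse and readable. Operator set and sizes are illustrative only.
import torch
import torch.nn as nn

class EQLLayer(nn.Module):
    def __init__(self, in_dim, num_units=4):
        super().__init__()
        self.num_units = num_units
        # 4 unary operators + 1 product (2 inputs) per unit -> 6 pre-activations per unit
        self.linear = nn.Linear(in_dim, num_units * 6)

    def forward(self, x):
        z = self.linear(x).view(x.shape[0], self.num_units, 6)
        unary = torch.stack(
            [z[..., 0], torch.sin(z[..., 1]), torch.cos(z[..., 2]), torch.tanh(z[..., 3])],
            dim=-1,
        )  # (batch, num_units, 4)
        product = z[..., 4] * z[..., 5]  # (batch, num_units)
        return torch.cat([unary.flatten(1), product], dim=-1)

    def l1(self):
        return self.linear.weight.abs().sum()

in_dim, num_actions, reg_weight = 16, 6, 1e-3
actor = EQLLayer(in_dim)
head = nn.Linear(4 * 4 + 4, num_actions)    # 4 unary outputs + 1 product per unit, 4 units

coords = torch.randn(8, in_dim)             # structured state from the perception CNN
logits = head(actor(coords))
sparsity_penalty = reg_weight * actor.l1()  # analogue of --reg_weight
```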
To train on MetaDrive, run:
python train_policy_metadrive.py --wandb-project-name nsrl-eql --run-name benchmark-INSIGHT-MetaDriveEnv-seed1 --env-id MetaDriveEnv --cnn_loss_weight 2 --distillation_loss_weight 1 --load_cnn True --seed 1 --learning-rate 5e-5 --clip-coef 0.2 --ent-coef 0.01 --ego_state True --num-envs 8 --num-steps 125 --update-epochs 4 --num-minibatches 10 --max-grad-norm 0.5 --anneal-lr False --kl-penalty-coef 0.2 --reg_weight 1e-4 --use_eql_actor True
If you find our code helpful for your research or work, please cite our paper:
@article{luo2024insight,
  title={INSIGHT: End-to-End Neuro-Symbolic Visual Reinforcement Learning with Language Explanations},
  author={Luo, Lirui and Zhang, Guoxi and Xu, Hongming and Yang, Yaodong and Fang, Cong and Li, Qing},
  journal={ICML},
  year={2024}
}
For any queries, please raise an issue or contact Qing Li.
This project is open-sourced under the MIT License.