Closest Among Top-K (CAT-K) rollouts unroll the policy during fine-tuning such that the visited states remain close to the ground truth (GT). At each time step, CAT-K first takes the top-K most likely action tokens according to the policy, then chooses the one leading to the state closest to the GT. As a result, CAT-K rollouts follow the mode of the GT (e.g., turning left), while random or top-K rollouts can deviate widely (e.g., going straight or right). Since the policy is essentially trained to minimize the distance between the rollout states and the GT states, the GT-based supervision remains effective for CAT-K rollouts, but not for random or top-K rollouts. A conceptual sketch of one CAT-K rollout step is given below, after the citation.
Closed-Loop Supervised Fine-Tuning of Tokenized Traffic Models
Zhejun Zhang, Peter Karkus, Maximilian Igl, Wenhao Ding, Yuxiao Chen, Boris Ivanovic and Marco Pavone.
@article{zhang2024closed,
  title   = {Closed-Loop Supervised Fine-Tuning of Tokenized Traffic Models},
  author  = {Zhang, Zhejun and Karkus, Peter and Igl, Maximilian and Ding, Wenhao and Chen, Yuxiao and Ivanovic, Boris and Pavone, Marco},
  journal = {arXiv preprint arXiv:2412.05334},
  year    = {2024},
}
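For intuition, here is a minimal PyTorch sketch of one CAT-K rollout step. It is not the repository's implementation: the helper `decode_token_to_state` (mapping the current state and an action token to the next state) and the tensors `policy_logits` and `gt_state` are hypothetical placeholders.

```python
import torch

def cat_k_step(policy_logits, current_state, gt_state, decode_token_to_state, k=5):
    """One CAT-K rollout step: among the top-K most likely action tokens,
    pick the one whose resulting state is closest to the ground truth (GT)."""
    # Top-K most likely action tokens under the current policy.
    _, topk_tokens = torch.topk(policy_logits, k, dim=-1)  # [K]

    # Roll each candidate token forward to the state it would lead to.
    candidate_states = torch.stack(
        [decode_token_to_state(current_state, token) for token in topk_tokens]
    )  # [K, state_dim]

    # Distance of every candidate state to the GT state at this time step.
    distances = torch.linalg.norm(candidate_states - gt_state, dim=-1)  # [K]

    # Keep the candidate that stays closest to the GT trajectory.
    closest = torch.argmin(distances)
    return topk_tokens[closest], candidate_states[closest]
```

Because the rollout selected this way stays close to the GT trajectory, the standard GT-based loss used during fine-tuning remains a meaningful supervision signal.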
- The easiest way to set up the environment is to create a conda environment with the following commands:
conda create -y -n catk python=3.11.9
conda activate catk
conda install -y -c conda-forge ffmpeg=4.3.2
pip install -r install/requirements.txt
pip install torch_geometric
pip install torch_scatter torch_cluster -f https://data.pyg.org/whl/torch-2.4.0+cu121.html
pip install --no-deps waymo-open-dataset-tf-2-12-0==1.6.4
- Alternatively, and preferably, use the Dockerfile to build your own Docker image. We found that the code runs faster inside Docker, though we are not sure why.
- We use WandB for logging. You can register an account for free.
- Be aware of the following:
- We use 8 NVIDIA A100 (80GB) GPUs for training and validation. Training and fine-tuning take a few days, whereas validation and testing take a few hours.
- We cannot share pre-trained models according to the terms of the Waymo Open Motion Dataset.
- Download the Waymo Open Motion Dataset. We use v1.2.1.
- Use scripts/cache_womd.sh to preprocess the dataset into pickle files to accelerate data loading during training and evaluation.
- You should pack three datasets: training, validation, and testing.
In the scripts folder, we provide:
- scripts/train.sh for training and fine-tuning.
- scripts/local_val.sh for local validation.
- scripts/wosac_sub.sh for packing submission files.
The default scripts run on a single GPU. We use DDP for multi-GPU training and validation; the corresponding commands are also included in the bash scripts. To reproduce our final results, follow these steps:
- Use scripts/train.sh with the BC pre-training config to pre-train the SMART-tiny 7M model.
- Use scripts/train.sh with the CLSFT with CAT-K config to fine-tune the SMART-tiny model pre-trained in step 1.
- Use scripts/wosac_sub.sh to pack the submission file for the validation or testing split. Upload the wosac_submission.tar.gz file located in the logs folder to the WOSAC leaderboard to evaluate the model fine-tuned in step 2.
- Alternatively, you can run local validation with scripts/local_val.sh.
For the Gaussian Mixture Model (GMM) based ego policy, the procedure is similar; just use the following configs (a conceptual sketch follows this list):
- BC pre-training config for GMM-based ego policy
- CLSFT with CAT-K config for GMM-based ego policy
- Local validation config for GMM-based ego policy
- There is no submission option for the ego policy.
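For intuition only, the sketch below shows how the same closest-among-top-K idea could be applied to a GMM ego policy, where the policy outputs mixture components over continuous actions instead of a categorical distribution over tokens. The names `mixture_logits`, `means`, and `step_dynamics` are hypothetical placeholders, not the repository's API; see the configs above for the actual implementation.

```python
import torch

def cat_k_gmm_step(mixture_logits, means, current_state, gt_state, step_dynamics, k=3):
    """Among the K most probable GMM components, pick the one whose mean
    action leads to the state closest to the ground truth (GT)."""
    # K most likely mixture components of the ego policy.
    _, topk_components = torch.topk(mixture_logits, k, dim=-1)

    # Roll the mean action of each selected component through the dynamics.
    candidate_states = torch.stack(
        [step_dynamics(current_state, means[i]) for i in topk_components]
    )

    # Keep the component whose resulting state is closest to the GT state.
    distances = torch.linalg.norm(candidate_states - gt_state, dim=-1)
    closest = torch.argmin(distances)
    return topk_components[closest], candidate_states[closest]
```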
The submission of our CAT-K fine-tuned SMART to the WOSAC leaderboard can be found here. The submission of our reproduced SMART to the test split can be found here; note that it is not published to the leaderboard.
Please refer to docs/ablation_models.md for the configurations of the ablation models. Specifically, you will find the data augmentation methods used by SMART and Trajeglish.
Our code is based on SMART. We thank the authors for their valuable open-source code; please don't forget to cite their amazing work as well!