
Closest Among Top-K (CAT-K) rollouts unroll the policy during fine-tuning such that the visited states remain close to the ground truth (GT). At each time step, CAT-K first takes the top-K most likely action tokens according to the policy, then chooses the one leading to the state closest to the GT. As a result, CAT-K rollouts follow the mode of the GT (e.g., turning left), while random or top-K rollouts can deviate significantly (e.g., going straight or right). Since the policy is essentially trained to minimize the distance between the rollout states and the GT states, GT-based supervision remains effective for CAT-K rollouts, but not for random or top-K rollouts.
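
In code terms, one CAT-K rollout step can be sketched as follows. This is a minimal illustration with hypothetical helper names (cat_k_step, step_fn), not the repo's actual implementation:

# Minimal sketch of one CAT-K rollout step (hypothetical names, not the
# repo's API): take the K most likely action tokens, simulate each, and
# commit to the one whose next state is closest to the ground truth.
import torch

def cat_k_step(logits, state, gt_next_state, step_fn, k=5):
    """logits: (num_tokens,) policy logits at the current state;
    step_fn(state, token) -> next_state; gt_next_state: GT next state."""
    topk_tokens = torch.topk(logits, k).indices            # top-K candidate tokens
    next_states = torch.stack([step_fn(state, t) for t in topk_tokens])
    dists = torch.linalg.norm(next_states - gt_next_state, dim=-1)
    best = torch.argmin(dists)                             # closest-to-GT candidate
    return topk_tokens[best], next_states[best]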

Closed-Loop Supervised Fine-Tuning of Tokenized Traffic Models
Zhejun Zhang, Peter Karkus, Maximilian Igl, Wenhao Ding, Yuxiao Chen, Boris Ivanovic and Marco Pavone.

Project Page
arXiv Paper

@article{zhang2024closed,
  title = {Closed-Loop Supervised Fine-Tuning of Tokenized Traffic Models},
  author = {Zhang, Zhejun and Karkus, Peter and Igl, Maximilian and Ding, Wenhao and Chen, Yuxiao and Ivanovic, Boris and Pavone, Marco},
  journal = {arXiv preprint arXiv:2412.05334},
  year = {2024},
}

Installation

  • The easiest way to set up the environment is to create a conda environment with the following commands (a quick sanity check follows this list):
    conda create -y -n catk python=3.11.9
    conda activate catk
    conda install -y -c conda-forge ffmpeg=4.3.2
    pip install -r install/requirements.txt
    pip install torch_geometric
    pip install torch_scatter torch_cluster -f https://data.pyg.org/whl/torch-2.4.0+cu121.html
    pip install --no-deps waymo-open-dataset-tf-2-12-0==1.6.4
    
  • Alternatively, a better option is to use the Dockerfile and build your own Docker image. We found that the code runs faster inside Docker, though we have not pinned down why.
  • We use WandB for logging. You can register an account for free.
  • Be aware:
    • We use 8 NVIDIA A100 (80 GB) GPUs for training and validation. Pre-training and fine-tuning take a few days, whereas validation and testing take a few hours.
    • We cannot share pre-trained models due to the terms of the Waymo Open Motion Dataset.
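
After installation, a quick import check helps confirm the pinned PyTorch build and the PyG extensions work together; the expected version string below follows from the wheel index used above:

# Post-install sanity check: verify the PyTorch build, CUDA visibility,
# and that the PyG extension packages import cleanly.
import torch
import torch_geometric
import torch_scatter
import torch_cluster

print("torch:", torch.__version__)          # expect 2.4.0+cu121 per the wheel index
print("cuda available:", torch.cuda.is_available())
print("pyg:", torch_geometric.__version__)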

Dataset preparation

  • Download the Waymo Open Motion Dataset. We use v1.2.1.
  • Use scripts/cache_womd.sh to preprocess the dataset into pickle files, which accelerates data loading during training and evaluation (see the sketch after this list).
  • You should preprocess all three splits: training, validation, and testing.
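
For intuition, the cache turns per-step TFRecord parsing into a single pickle load. Below is a minimal sketch of the resulting read pattern; the class name, file layout, and field contents are hypothetical, not the repo's actual schema:

# Hypothetical read pattern enabled by the pickle cache: each scenario is
# stored pre-decoded, so __getitem__ is a single deserialization instead
# of parsing TFRecords on every epoch.
import pickle
from pathlib import Path
from torch.utils.data import Dataset

class CachedWOMDDataset(Dataset):  # hypothetical name, not the repo's class
    def __init__(self, cache_dir):
        self.files = sorted(Path(cache_dir).glob("*.pkl"))

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        with open(self.files[idx], "rb") as f:
            return pickle.load(f)  # pre-processed scenario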

Run the code

We provide bash scripts for pre-training, fine-tuning, submission packing, and local validation in the scripts folder.

By default, each script runs on a single GPU. We use DDP for multi-GPU training and validation; the corresponding commands are also included in the bash scripts, and a minimal launch sketch follows the steps below. To reproduce our final results, follow these steps:

  1. Use scripts/train.sh with the BC pre-training config to pre-train the SMART-tiny 7M model.
  2. Use scripts/train.sh with the CAT-K CLSFT config to fine-tune the SMART-tiny model pre-trained in step 1.
  3. Use scripts/wosac_sub.sh to pack the submission file for the validation or test split, then upload the wosac_submission.tar.gz file from the logs folder to the WOSAC leaderboard to evaluate the model fine-tuned in step 2.
  4. Alternatively, you can do local validation with scripts/local_val.sh.
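
For reference, here is a minimal sketch of an 8-GPU DDP run with PyTorch Lightning. Whether the repo's training stack is Lightning-based is an assumption here; the actual entry points and arguments live in the bash scripts:

# Minimal DDP launch sketch (assumes a Lightning-based training stack;
# scripts/train.sh encodes the actual commands and configs).
import pytorch_lightning as pl

trainer = pl.Trainer(
    accelerator="gpu",
    devices=8,          # 8x NVIDIA A100 (80 GB), as used in the paper
    strategy="ddp",     # DistributedDataParallel across the 8 GPUs
)
# trainer.fit(model, datamodule)  # model and datamodule come from the config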

For the Gaussian Mixture Model (GMM) based ego policy, the procedure is similar; just use the corresponding GMM configs.
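
For intuition, a GMM ego-policy head outputs mixture weights, means, and scales over continuous actions rather than discrete tokens. Below is a minimal sketch using torch.distributions; the shapes and the action parameterization (acceleration, steering) are hypothetical, not the repo's actual head:

# Hypothetical sketch of a GMM action head: the network predicts mixture
# logits, per-component means, and scales; sampling draws from the mixture.
import torch
import torch.distributions as D

K, action_dim = 6, 2                       # hypothetical: 6 modes over (accel, steer)
logits = torch.randn(K)                    # mixture weights (network output)
means = torch.randn(K, action_dim)         # component means (network output)
scales = torch.rand(K, action_dim) + 0.1   # component std devs (network output)

gmm = D.MixtureSameFamily(
    D.Categorical(logits=logits),
    D.Independent(D.Normal(means, scales), 1),
)
action = gmm.sample()                      # one continuous ego action
log_prob = gmm.log_prob(action)            # likelihood used for training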

Performance

The submission of our CAT-K fine-tuned SMART to the WOSAC leaderboard is found here. The submission of our reproduced SMART to the test split is found here; note that it is not published to the leaderboard.

Ablation configs

Please refer to docs/ablation_models.md for the configurations of the ablation models. In particular, you will find the data augmentation methods used by SMART and Trajeglish.

Acknowledgement

Our code is based on SMART. We thank the authors for their valuable open-source code! Please don't forget to cite their amazing work as well!