CRoFT: Robust Fine-Tuning with Concurrent Optimization for OOD Generalization and Open-Set OOD Detection
The official implementation of *CRoFT: Robust Fine-Tuning with Concurrent Optimization for OOD Generalization and Open-Set OOD Detection* (ICML 2024; paper available on OpenReview).
Recent vision-language pre-trained models (VL-PTMs) have shown remarkable success in open-vocabulary tasks. However, downstream use cases often involve further fine-tuning of VL-PTMs, which may distort their general knowledge and impair their ability to handle distribution shifts. In real-world scenarios, machine learning systems inevitably encounter both covariate shifts (e.g., changes in image styles) and semantic shifts (e.g., test-time unseen classes). This highlights the importance of enhancing out-of-distribution (OOD) generalization on covariate shifts and simultaneously detecting semantic-shifted unseen classes. Thus a critical but underexplored question arises: How to improve VL-PTMs' generalization ability to closed-set OOD data, while effectively detecting open-set unseen classes during fine-tuning? In this paper, we propose a novel objective function of OOD detection that also serves to improve OOD generalization. We show that minimizing the gradient magnitude of energy scores on training data leads to domain-consistent Hessians of classification loss, a strong indicator for OOD generalization revealed by theoretical analysis. Based on this finding, we have developed a unified fine-tuning framework that allows for concurrent optimization of both tasks. Extensive experiments have demonstrated the superiority of our method.
Overview of the CRoFT framework
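As a rough illustration of the objective described above, here is a minimal PyTorch sketch. It is not the official training code: the regularization weight `lam`, the function names, and the choice of differentiating the energy w.r.t. the image features are our assumptions.

```python
import torch
import torch.nn.functional as F

def energy_score(logits, temperature=1.0):
    # Free energy of the logits: E(x) = -T * logsumexp(f(x) / T).
    return -temperature * torch.logsumexp(logits / temperature, dim=-1)

def croft_style_loss(features, classifier, labels, lam=0.1):
    # `lam` and differentiating w.r.t. the image features are illustrative
    # assumptions, not necessarily the repo's exact choices.
    if not features.requires_grad:
        features = features.requires_grad_(True)
    logits = classifier(features)
    ce = F.cross_entropy(logits, labels)
    # Gradient of the summed energy scores w.r.t. the features;
    # create_graph=True keeps the penalty itself differentiable.
    grad = torch.autograd.grad(energy_score(logits).sum(), features,
                               create_graph=True)[0]
    return ce + lam * grad.norm(dim=-1).mean()
```

Per the abstract, driving this gradient magnitude toward zero on training data encourages domain-consistent Hessians of the classification loss, which is what lets a single objective serve both OOD generalization and open-set OOD detection.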
This code is built on top of the awesome Dassl and [CoOp](https://github.com/KaiyangZhou/CoOp). Run `pip install -r requirements.txt` under `CRoFT/CoOp/` to install the required packages:
```bash
git clone https://github.com/LinLLLL/CRoFT
cd CRoFT/CoOp
conda create -n CRoFT python=3.9
conda activate CRoFT
pip install -r requirements.txt
# Install the matching versions of torch and torchvision
conda install pytorch torchvision cudatoolkit
```
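Optionally, you can sanity-check the environment afterwards (a quick check of ours, not a repo script):

```python
import torch
import torchvision

print("torch:", torch.__version__, "| torchvision:", torchvision.__version__)
print("CUDA available:", torch.cuda.is_available())
```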
Follow DATASET.md to install ImageNet, ImageNetV2, ImageNet-Sketch, ImageNet-A, ImageNet-R, and the 10 other datasets used in CoOp.
The OOD datasets, such as PACS and VLCS, are publicly available but need to be downloaded manually. Please refer to [this instruction](https://github.com/m-Just/OoD-Bench/blob/main/data/README.md) for downloading them. Please make sure that the directory structure of each dataset is arranged as follows:
PACS
```
PACS
├── images
│   ├── art_painting
│   ├── cartoon
│   ├── photo
│   └── sketch
├── test_on_art_painting.json
├── test_on_cartoon.json
├── test_on_photo.json
└── test_on_sketch.json
```
VLCS
```
VLCS
├── images
│   ├── CALTECH
│   ├── LABELME
│   ├── PASCAL
│   └── SUN
├── test_on_caltech.json
├── test_on_labelme.json
├── test_on_pascal.json
└── test_on_sun.json
```
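If you want to verify the layout programmatically before training, a small helper like this can catch misplaced folders (our sketch, not shipped with the repo; it assumes the split files are named after the lower-cased domain folders, as shown above):

```python
import os

def check_layout(root, domains):
    """Verify images/<domain> folders and test_on_<domain>.json split files."""
    ok = True
    for d in domains:
        img_dir = os.path.join(root, "images", d)
        split = os.path.join(root, f"test_on_{d.lower()}.json")
        if not os.path.isdir(img_dir):
            print("missing directory:", img_dir)
            ok = False
        if not os.path.isfile(split):
            print("missing split file:", split)
            ok = False
    return ok

check_layout("PACS", ["art_painting", "cartoon", "photo", "sketch"])
check_layout("VLCS", ["CALTECH", "LABELME", "PASCAL", "SUN"])
```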
We provide the running scripts in `CoOp/scripts`. We take CRoFT as an example; other methods can be evaluated similarly. Make sure you change the `DATA` path in the shell files under `CoOp/scripts/CRoFT`, then run the commands under `CoOp/scripts/CRoFT`:
```bash
# SETUP-I
python run_setup1.py
python test_setup1.py
bash test_energy.sh

# SETUP-II
python run_setup2.py
python test_setup2_energy.py
```
```bash
# run the commands under CoOp/
python collect_result_set1_oodg.py
python collect_result_set1_oodd.py
python collect_result_set2_oodg.py
```
We provide two OOD detection methods in SETUP-II, i.e., inferring the energy score and the KNN distance (a sketch of both scores follows the commands below). Make sure you have completed the evaluation process of `python test_setup2_energy.py` before you run `python test_setup2_knn.py`.
```bash
# run the commands under CoOp/
# inferring the energy score
python collect_result_set2_oodd.py
```

```bash
# run the commands under CoOp/scripts/CRoFT
# inferring the KNN distance
python test_setup2_knn.py
```
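For reference, the two detection scores boil down to the following (a hedged sketch; the feature L2-normalization and `k=50` are our assumptions, not necessarily the repo's exact settings):

```python
import torch
import torch.nn.functional as F

def energy_ood_score(logits, temperature=1.0):
    # Free energy E(x) = -T * logsumexp(f(x) / T); in-distribution samples
    # tend to have lower energy than OOD samples.
    return -temperature * torch.logsumexp(logits / temperature, dim=-1)

def knn_ood_score(test_feats, train_feats, k=50):
    # Distance from each test feature to its k-th nearest training feature;
    # a larger distance indicates a more OOD-like sample.
    test_feats = F.normalize(test_feats, dim=-1)
    train_feats = F.normalize(train_feats, dim=-1)
    dists = torch.cdist(test_feats, train_feats)
    return dists.topk(k, dim=-1, largest=False).values[:, -1]
```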
The evaluation results are then saved to the folders `output` and `eval_open_ood`, or displayed directly on your screen.
This repo benefits from CLIP, [CoOp](https://github.com/KaiyangZhou/CoOp), [CoCoOp](https://github.com/KaiyangZhou/CoOp), [Tip-Adapter-F](https://github.com/gaopengcuhk/Tip-Adapter), [DPLCLIP](https://github.com/shogi880/DPLCLIP), CLIP-Adapter, and the OOD generalization benchmark [OoD-Bench](https://github.com/ynysjtu/ood_bench). Thanks for their wonderful works.
If you use this code in your research, please kindly cite this paper:
```bibtex
@article{zhu2024croft,
  title={CRoFT: Robust Fine-Tuning with Concurrent Optimization for OOD Generalization and Open-Set OOD Detection},
  author={Zhu, Lin and Yang, Yifeng and Gu, Qinying and Wang, Xinbing and Zhou, Chenghu and Ye, Nanyang},
  journal={arXiv preprint arXiv:2405.16417},
  year={2024}
}
```
If you have any questions about this project, please feel free to contact zhulin_sjtu@sjtu.edu.cn.