https://arxiv.org/abs/2203.12119
This repository contains the official PyTorch implementation for Visual Prompt Tuning.
See `env_setup.sh`.
- `src/configs`: handles config parameters for the experiments.
  - 👉 `src/configs/config.py`: main config setup for the experiments, with an explanation of each parameter.
- `src/data`: loads and sets up the input datasets. The contents of `src/data/vtab_datasets` are borrowed from an external source.
- `src/engine`: main training and eval actions.
- `src/models`: handles backbone architectures and heads for the different fine-tuning protocols.
  - 👉 `src/models/vit_prompt`: a folder containing the same backbones as the `vit_backbones` folder, specialized for VPT. This folder should contain the same file names as those in `vit_backbones`.
  - 👉 `src/models/vit_models.py`: main model file for transformer-based models. ❗️Note❗️: the current version only supports ViT, Swin, and ViT with MAE or MoCo v3 pre-training.
  - `src/models/build_model.py`: uses the config to build the model for training / eval.
- `src/solver`: optimization, losses, and learning rate schedules.
- `src/utils`: helper functions for I/O, logging, training, and visualization.
- 👉 `train.py`: call this to train and evaluate a model with a specified transfer type.
- 👉 `tune_fgvc.py`: call this to tune the learning rate and weight decay for a model with a specified transfer type. We used this script for FGVC tasks.
- 👉 `tune_vtab.py`: call this to tune VTAB tasks. It uses an 800/200 split to find the best lr and wd, then uses the best lr/wd for the final runs.
- `launch.py`: contains functions used to launch the job.
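
The tuning procedure described for `tune_vtab.py` (search on an 800/200 split, then reuse the best values for the final runs) can be sketched roughly as follows. This is a minimal illustration and not the script's actual code: `split_800_200`, `pick_best`, and the `evaluate` callback are hypothetical names, the sketch assumes exactly 1000 labeled training examples, and the candidate grids are placeholders rather than the values used in the paper.

```python
import itertools
import random
from typing import Callable, List, Sequence, Tuple


def split_800_200(indices: Sequence[int], seed: int = 0) -> Tuple[List[int], List[int]]:
    """Shuffle 1000 training indices and split them into 800 train / 200 val."""
    if len(indices) != 1000:
        raise ValueError("this sketch assumes exactly 1000 training examples")
    shuffled = list(indices)
    random.Random(seed).shuffle(shuffled)
    return shuffled[:800], shuffled[800:]


def pick_best(
    lrs: Sequence[float],
    wds: Sequence[float],
    evaluate: Callable[[float, float], float],
) -> Tuple[float, float]:
    """Return the (lr, wd) pair with the highest validation score."""
    return max(itertools.product(lrs, wds), key=lambda pair: evaluate(*pair))


# Hypothetical stand-in for "train on the 800 split, score on the 200 split".
def toy_evaluate(lr: float, wd: float) -> float:
    return -abs(lr - 0.1) - abs(wd - 0.01)


train_idx, val_idx = split_800_200(range(1000))
best_lr, best_wd = pick_best([1.0, 0.1, 0.01], [0.01, 0.0001], toy_evaluate)
```

In a real run, `evaluate` would train on `train_idx` and report accuracy on `val_idx`; the selected `best_lr` and `best_wd` would then be used for the final runs on the full training set.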
- 🔥VPT related:
  - `MODEL.PROMPT.NUM_TOKENS`: prompt length
  - `MODEL.PROMPT.DEEP`: deep or shallow prompt
- Fine-tuning method specification:
  - `MODEL.TRANSFER_TYPE`
- Vision backbones:
  - `DATA.FEATURE`: specifies which representation to use
  - `MODEL.TYPE`: the general backbone type, e.g., "vit" or "swin"
  - `MODEL.MODEL_ROOT`: folder with pre-trained model checkpoints
- Optimization related:
  - `SOLVER.BASE_LR`: learning rate for the experiment
  - `SOLVER.WEIGHT_DECAY`: weight decay value for the experiment
  - `DATA.BATCH_SIZE`
- Datasets related:
  - `DATA.NAME`
  - `DATA.DATAPATH`: where you put the datasets
  - `DATA.NUMBER_CLASSES`
- Others:
  - `RUN_N_TIMES`: ensures a job runs only once in case of duplicated submission; not used during VTAB runs
  - `OUTPUT_DIR`: output directory for the final model and logs
  - `MODEL.SAVE_CKPT`: if set to `True`, saves model checkpoints and the final outputs for both the val and test sets
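
The dotted key names above suggest a nested config that can be overridden per run. The sketch below shows one way to turn a nested dictionary into a flat `KEY VALUE` list for a command line. It rests on assumptions not confirmed by this text: that `train.py` takes a `--config-file` flag followed by such overrides, and that `"prompt"` is a valid `MODEL.TRANSFER_TYPE` value. `flatten_overrides` is a hypothetical helper, and every path and value is a placeholder.

```python
from typing import Any, Dict, List


def flatten_overrides(nested: Dict[str, Any], prefix: str = "") -> List[str]:
    """Flatten a nested dict into [KEY, VALUE, KEY, VALUE, ...] with dotted keys."""
    flat: List[str] = []
    for key, value in nested.items():
        dotted = f"{prefix}.{key}" if prefix else key
        if isinstance(value, dict):
            flat.extend(flatten_overrides(value, dotted))
        else:
            flat.extend([dotted, str(value)])
    return flat


overrides = flatten_overrides({
    "MODEL": {
        "TRANSFER_TYPE": "prompt",             # assumed value name
        "TYPE": "vit",
        "PROMPT": {"NUM_TOKENS": 10, "DEEP": True},
        "MODEL_ROOT": "/path/to/checkpoints",  # placeholder path
    },
    "SOLVER": {"BASE_LR": 0.1, "WEIGHT_DECAY": 0.01},
    "DATA": {"BATCH_SIZE": 64, "DATAPATH": "/path/to/data"},
    "OUTPUT_DIR": "/path/to/output",
})

# Assumed invocation shape; check train.py for the flags it actually accepts.
command = ["python", "train.py", "--config-file", "path/to/config.yaml", *overrides]
print(" ".join(command))
```

Building the list programmatically keeps key names in one place, which helps when sweeping `SOLVER.BASE_LR` and `SOLVER.WEIGHT_DECAY` across many runs.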
See Table 8 in the Appendix for dataset details.
- Fine-Grained Visual Classification tasks (FGVC): The datasets can be downloaded following the official links. We split the training data if the public validation set is not available. The split datasets can be found here: Dropbox, Google Drive.
- Visual Task Adaptation Benchmark (VTAB): see `VTAB_SETUP.md` for detailed instructions and tips.
Download the pre-trained Transformer-based backbones and place them in `MODEL.MODEL_ROOT` (ConvNeXt-Base and ResNet-50 are downloaded automatically via the links in the code). Note that you also need to rename the downloaded ViT-B/16 checkpoint from `ViT-B_16.npz` to `imagenet21k_ViT-B_16.npz`.
See Table 9 in the Appendix for more details about pre-trained backbones.
Pre-trained Backbone | Pre-trained Objective | Link | md5sum |
---|---|---|---|
ViT-B/16 | Supervised | link | d9715d |
ViT-B/16 | MoCo v3 | link | 8f39ce |
ViT-B/16 | MAE | link | 8cad7c |
Swin-B | Supervised | link | bf9cc1 |
ConvNeXt-Base | Supervised | link | - |
ResNet-50 | Supervised | link | - |
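
The checksum and rename steps can be combined in a small helper. The sketch below is illustrative only: `md5_hexdigest` and `verify_and_rename` are hypothetical names, and it assumes the short md5sum values in the table are the leading characters of each file's full MD5 hex digest.

```python
import hashlib
from pathlib import Path


def md5_hexdigest(path: Path, chunk_size: int = 1 << 20) -> str:
    """Compute the MD5 hex digest of a file, reading it in chunks."""
    digest = hashlib.md5()
    with path.open("rb") as handle:
        for chunk in iter(lambda: handle.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify_and_rename(src: Path, dst_name: str, md5_prefix: str) -> Path:
    """Check that the file's MD5 starts with md5_prefix, then rename it in place."""
    actual = md5_hexdigest(src)
    if not actual.startswith(md5_prefix):
        raise ValueError(f"{src.name}: md5 {actual} does not start with {md5_prefix}")
    dst = src.with_name(dst_name)
    src.rename(dst)
    return dst
```

For the supervised ViT-B/16 checkpoint, a call might look like `verify_and_rename(Path(model_root) / "ViT-B_16.npz", "imagenet21k_ViT-B_16.npz", "d9715d")`, where `model_root` is whatever directory `MODEL.MODEL_ROOT` points to.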
See `demo.ipynb` for how to use this repo.
The hyperparameter values (prompt length for VPT, reduction rate for Adapters, base learning rate, and weight decay) used in Tables 1-2, Figs. 3-4, and Tables 4-5 can be found here: Dropbox / Google Drive.
If you find our work helpful in your research, please cite it as:
@inproceedings{jia2022vpt,
title={Visual Prompt Tuning},
author={Jia, Menglin and Tang, Luming and Chen, Bor-Chun and Cardie, Claire and Belongie, Serge and Hariharan, Bharath and Lim, Ser-Nam},
booktitle={European Conference on Computer Vision (ECCV)},
year={2022}
}
The majority of VPT is licensed under the CC-BY-NC 4.0 license (see LICENSE for details). Portions of the project are available under separate license terms: google-research/task_adaptation and huggingface/transformers are licensed under the Apache 2.0 license; Swin-Transformer, ConvNeXt and ViT-pytorch are licensed under the MIT license; and MoCo-v3 and MAE are licensed under the Attribution-NonCommercial 4.0 International license.