This is a PyTorch implementation of the paper "Locality Guidance for Improving Vision Transformers on Tiny Datasets", supporting different Transformer models (including DeiT, T2T-ViT, PiT, PVT, PVTv2, ConViT, CvT) and different classification datasets (including CIFAR-100, Oxford Flowers, Tiny ImageNet, Chaoyang).
While the Vision Transformer (VT) architecture is becoming popular in computer vision, pure VT models perform poorly on tiny datasets. To address this issue, this paper proposes locality guidance to improve the performance of VTs on tiny datasets. We first show that local information, which is of great importance for understanding images, is hard to learn from limited data because of the high flexibility and intrinsic globality of the self-attention mechanism in VTs. Inspired by the built-in local-to-global hierarchy of CNNs, we realize locality guidance by having the VT imitate the features of an already trained convolutional neural network (CNN). Under our dual-task learning paradigm (classification plus feature imitation), the locality guidance provided by a lightweight CNN trained on low-resolution images is sufficient to accelerate convergence and substantially improve the performance of VTs. As a result, our locality guidance approach is simple and efficient, and can serve as a basic performance enhancement method for VTs on tiny datasets. Extensive experiments demonstrate that our method significantly improves VTs trained from scratch on tiny datasets and is compatible with different kinds of VTs and datasets. For example, it boosts the top-1 accuracy of various VTs on tiny datasets (e.g., by 13.07 percentage points for DeiT, 8.98 for T2T and 7.85 for PVT), and improves even the stronger PVTv2 baseline by 1.86 points to 79.30%, showing the potential of VTs on tiny datasets.
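The dual-task objective described above can be sketched as a classification loss plus a weighted feature-imitation term. This is a minimal, framework-free sketch, not the repository's code: it assumes a mean-squared-error imitation loss, VT features that are already aligned to the shape of the CNN features, and a hypothetical weight `beta`.

```python
import math


def cross_entropy(logits, label):
    """Softmax cross-entropy for a single sample (log-sum-exp for stability)."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]


def mse(a, b):
    """Mean squared error between two equally sized feature vectors."""
    if len(a) != len(b):
        raise ValueError("feature vectors must have the same length")
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


def dual_task_loss(logits, label, vt_feats, cnn_feats, beta=1.0):
    """Classification loss + beta * feature-imitation loss.

    vt_feats and cnn_feats are lists of per-stage feature vectors that are
    assumed to be aligned already; beta is a hypothetical weighting factor.
    """
    if len(vt_feats) != len(cnn_feats):
        raise ValueError("need one CNN target per VT feature")
    imitation = sum(mse(v, c) for v, c in zip(vt_feats, cnn_feats))
    return cross_entropy(logits, label) + beta * imitation


# Example: one sample, one feature stage that only partly matches the CNN.
loss = dual_task_loss([2.0, 0.5, -1.0], 0, [[0.2, 0.4]], [[0.0, 0.4]], beta=1.0)
print(round(loss, 4))
```

In a real PyTorch training loop, the CNN targets would come from the frozen guidance model, computed without gradients, so only the VT is updated by this loss.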
The base environment we used for experiments is:
- python = 3.8.12
- pytorch = 1.8.0
- cudatoolkit = 10.1
Other dependencies can be installed with:

```bash
pip install -r requirements.txt
```
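To compare a local setup against the tested versions listed above, a small check like the following may help. It is an illustrative sketch: `TESTED`, `installed_versions` and `report` are hypothetical helpers, not part of the repository. It reads the Python version and the installed `torch` distribution via `importlib.metadata` (available from Python 3.8); cudatoolkit is a conda package and is not checked here.

```python
import sys
from importlib import metadata

# Versions from the list above; these are the tested versions, not hard minimums.
TESTED = {"python": "3.8", "torch": "1.8.0"}


def installed_versions():
    """Collect the local Python (major.minor) and torch versions."""
    found = {"python": f"{sys.version_info.major}.{sys.version_info.minor}"}
    try:
        found["torch"] = metadata.version("torch")
    except metadata.PackageNotFoundError:
        found["torch"] = None
    return found


def report(tested=TESTED, found=None):
    """Return one message per package that differs from the tested version."""
    found = installed_versions() if found is None else found
    messages = []
    for name, want in tested.items():
        have = found.get(name)
        if have is None:
            messages.append(f"{name}: not installed (tested with {want})")
        elif not have.startswith(want):  # prefix match accepts tags like "+cu101"
            messages.append(f"{name}: {have} (tested with {want})")
    return messages


if __name__ == "__main__":
    for line in report():
        print(line)
```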
Step 1: download the datasets (CIFAR-100, Oxford Flowers, Tiny ImageNet, Chaoyang) from their official websites.
Step 2: move or link the datasets to the data/ directory. The expected layout of data/ is as follows:
```
data
├── cifar-100-python
│   ├── meta
│   ├── test
│   └── train
├── flowers
│   ├── jpg
│   ├── imagelabels.mat
│   └── setid.mat
├── tiny-imagenet-200
│   ├── train
│   │   ├── n01443537
│   │   └── ...
│   └── val
│       ├── images
│       └── val_annotations.txt
└── chaoyang
    ├── test
    ├── train
    ├── test.json
    └── train.json
```
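Before training, it can be worth checking that the directories and files in the layout above are in place. This is a hypothetical helper, not part of the repository; the expected entries are copied from the tree above, and `Path.exists()` follows symbolic links, so linked datasets pass the check too.

```python
from pathlib import Path

# Entries each dataset directory should contain, taken from the layout above.
EXPECTED = {
    "cifar-100-python": ["meta", "test", "train"],
    "flowers": ["jpg", "imagelabels.mat", "setid.mat"],
    "tiny-imagenet-200": ["train", "val/images", "val/val_annotations.txt"],
    "chaoyang": ["test", "train", "test.json", "train.json"],
}


def missing_entries(data_root="data", datasets=None):
    """Return expected paths that do not exist under data_root."""
    root = Path(data_root)
    names = list(EXPECTED) if datasets is None else datasets
    missing = []
    for name in names:
        for rel in EXPECTED[name]:
            path = root / name / rel
            if not path.exists():
                missing.append(str(path))
    return missing


if __name__ == "__main__":
    # Only check the datasets you actually downloaded, e.g. CIFAR-100.
    for path in missing_entries(datasets=["cifar-100-python"]):
        print("missing:", path)
```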
For example, you can train DeiT-Tiny from scratch using:

```bash
python run_net.py --mode train --cfg configs/deit/deit-ti_c100_base.yaml
```
Configurations for the other models and datasets are provided in configs/.
Training a VT with locality guidance takes two steps.

Step 1: train the CNN guidance model (e.g., ResNet-56). This step is fast and only needs to be run once per dataset.

```bash
python run_net.py --mode train --cfg configs/resnet/r-56_c100.yaml
```

Step 2: train the target VT.

```bash
python run_net.py --mode train --cfg configs/deit/deit-ti_c100_ours.yaml
```
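Because the guidance CNN only has to be trained once per dataset, the two steps can be chained in a small driver that skips Step 1 when a checkpoint already exists. This is a sketch under an assumption: it takes the checkpoint location to be `work_dirs/r-56_c100/model.pyth` (the path used in the offline example), which may differ if your config sets another output directory. It only prints the commands unless `dry_run=False`.

```python
import subprocess
from pathlib import Path

# Assumed checkpoint location of Step 1; adjust to your config's output dir.
GUIDANCE_CKPT = Path("work_dirs/r-56_c100/model.pyth")
STEP1 = ["python", "run_net.py", "--mode", "train",
         "--cfg", "configs/resnet/r-56_c100.yaml"]
STEP2 = ["python", "run_net.py", "--mode", "train",
         "--cfg", "configs/deit/deit-ti_c100_ours.yaml"]


def plan(ckpt_exists):
    """Return the commands to run; Step 1 is skipped if its checkpoint exists."""
    return [STEP2] if ckpt_exists else [STEP1, STEP2]


def main(dry_run=True):
    for cmd in plan(GUIDANCE_CKPT.exists()):
        print(" ".join(cmd))
        if not dry_run:
            subprocess.run(cmd, check=True)  # stop at the first failing step


if __name__ == "__main__":
    main(dry_run=True)
```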
As mentioned in the supplementary materials, locality guidance can also be run offline using pre-computed features. To run in this setting, use:

```bash
# Pre-compute features
python precompute_feature.py --cfg configs/resnet/r-56_c100.yaml --ckpt work_dirs/r-56_c100/model.pyth
# Train the model
python run_net.py --mode train --cfg configs/deit/deit-ti_c100_ours_offline.yaml
```
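The idea behind the offline setting is to run the guidance CNN once per training image, store its features, and look them up during VT training instead of running the CNN again. The sketch below illustrates only that caching idea; the on-disk format written by `precompute_feature.py` is not documented here, so the pickle file, the index keys and `toy_teacher` are all assumptions. Note that a cached feature belongs to the specific view of the image that was fed to the teacher during pre-computation.

```python
import pickle
import tempfile
from pathlib import Path


def precompute_features(teacher, images, out_path):
    """Run the frozen teacher once per image and save features keyed by index."""
    features = {idx: teacher(img) for idx, img in enumerate(images)}
    with open(out_path, "wb") as f:
        pickle.dump(features, f)
    return features


def load_features(path):
    """Load the index -> feature mapping written by precompute_features."""
    with open(path, "rb") as f:
        return pickle.load(f)


def toy_teacher(img):
    """Hypothetical stand-in for the trained CNN: one mean-pooled 'feature'."""
    return [sum(img) / len(img)]


def demo():
    images = [[1.0, 3.0], [2.0, 2.0, 5.0]]
    with tempfile.TemporaryDirectory() as tmp:
        path = Path(tmp) / "features.pkl"
        precompute_features(toy_teacher, images, path)
        cached = load_features(path)
    # During VT training, cached[idx] replaces a forward pass through the CNN.
    return cached


if __name__ == "__main__":
    print(demo())
```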
Multi-GPU or mixed-precision training only requires appending config overrides to the command, for example:

```bash
# Train DeiT from scratch with 2 GPUs
python run_net.py --mode train --cfg configs/deit/deit-ti_c100_base.yaml NUM_GPUS 2
# Train DeiT from scratch with 2 GPUs using mixed precision
python run_net.py --mode train --cfg configs/deit/deit-ti_c100_base.yaml NUM_GPUS 2 TRAIN.MIXED_PRECISION True
```
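The trailing `KEY VALUE` pairs override entries of the YAML config, with dotted keys addressing nested sections. We assume this follows the usual pycls convention of passing them to yacs's `CfgNode.merge_from_list`; the plain-dict sketch below only illustrates the mechanism and omits the type checking yacs performs.

```python
import ast


def parse_value(text):
    """Interpret 'True', '2', '0.1' as Python literals; otherwise keep the string."""
    try:
        return ast.literal_eval(text)
    except (ValueError, SyntaxError):
        return text


def merge_overrides(cfg, opts):
    """Apply KEY VALUE pairs (dotted keys) to a nested dict config in place."""
    if len(opts) % 2:
        raise ValueError("overrides must come in KEY VALUE pairs")
    for key, raw in zip(opts[0::2], opts[1::2]):
        *parents, leaf = key.split(".")
        node = cfg
        for name in parents:
            node = node[name]  # unknown sections raise KeyError
        if leaf not in node:
            raise KeyError(f"unknown config key: {key}")
        node[leaf] = parse_value(raw)
    return cfg


cfg = {"NUM_GPUS": 1, "TRAIN": {"MIXED_PRECISION": False}}
print(merge_overrides(cfg, ["NUM_GPUS", "2", "TRAIN.MIXED_PRECISION", "True"]))
```

The same mechanism sets `TEST.WEIGHTS` when evaluating a checkpoint.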
To evaluate a trained model, run:

```bash
python run_net.py --mode test --cfg configs/deit/deit-ti_c100_base.yaml TEST.WEIGHTS /path/to/model.pyth
```
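Top-1 accuracy, the metric reported in the results table, is the percentage of test images whose highest-scoring class matches the label. A minimal sketch of the computation on plain lists (the function name is illustrative, not the repository's API):

```python
def top1_accuracy(logits, labels):
    """Percentage of samples whose highest-scoring class equals the label."""
    if len(logits) != len(labels) or not labels:
        raise ValueError("need the same, non-zero number of predictions and labels")
    correct = sum(
        # max() over indices returns the first index on ties
        max(range(len(row)), key=row.__getitem__) == label
        for row, label in zip(logits, labels)
    )
    return 100.0 * correct / len(labels)


print(top1_accuracy([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7], [0.6, 0.4]], [1, 0, 0, 0]))
```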
| Model | Top-1 Acc. (%, Base) | Top-1 Acc. (%, Ours) |
|---|---|---|
| DeiT-Tiny | 65.08 (weights, log) | 78.15 (weights, log) |
| T2T-ViT-7 | 69.37 (weights, log) | 78.35 (weights, log) |
| PiT-Tiny | 73.58 (weights, log) | 78.48 (weights, log) |
| PVT-Tiny | 69.22 (weights, log) | 77.07 (weights, log) |
| PVTv2-B0 | 77.44 (weights, log) | 79.30 (weights, log) |
| ConViT-Tiny | 75.32 (weights, log) | 78.95 (weights, log) |
Base denotes training from scratch and Ours denotes training with locality guidance. We provide pre-trained weights and training logs (viewable with TensorBoard) for each model.
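The improvements quoted in the abstract are absolute differences in top-1 accuracy (percentage points) between the two columns of the table. They can be recomputed directly from the reported numbers:

```python
# Top-1 accuracies (%) copied from the table above: (base, ours).
RESULTS = {
    "DeiT-Tiny": (65.08, 78.15),
    "T2T-ViT-7": (69.37, 78.35),
    "PiT-Tiny": (73.58, 78.48),
    "PVT-Tiny": (69.22, 77.07),
    "PVTv2-B0": (77.44, 79.30),
    "ConViT-Tiny": (75.32, 78.95),
}


def gains(results):
    """Absolute improvement in percentage points, rounded to two decimals."""
    return {name: round(ours - base, 2) for name, (base, ours) in results.items()}


for name, delta in gains(RESULTS).items():
    print(f"{name:12s} +{delta:.2f}")
```

The DeiT, T2T-ViT, PVT and PVTv2 rows reproduce the 13.07, 8.98, 7.85 and 1.86 point gains stated in the abstract.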
This repository is built upon pycls and the official implementations of DeiT, T2T-ViT, PiT, PVTv1/v2, ConViT and CvT. We thank the authors of these open-source repositories.
```bibtex
@article{li2022locality,
  title={Locality Guidance for Improving Vision Transformers on Tiny Datasets},
  author={Li, Kehan and Yu, Runyi and Wang, Zhennan and Yuan, Li and Song, Guoli and Chen, Jie},
  journal={arXiv preprint arXiv:2207.10026},
  year={2022}
}
```