This is the PyTorch implementation for the paper: Token-Label Alignment for Vision Transformers.
Han Xiao*, Wenzhao Zheng*, Zheng Zhu, Jie Zhou, and Jiwen Lu
- Improve your ViTs by ~0.7% top-1 accuracy with a single `--tl-align` flag.
- Improvements in accuracy, generalization, and robustness without additional computation during inference.
- Efficient alignment of token labels without distillation.
| Model | Image Size | Params | FLOPs | Top-1 Acc. (%) | Top-5 Acc. (%) |
| --- | --- | --- | --- | --- | --- |
| DeiT-T | 224×224 | 5.7M | 1.6G | 72.2 | 91.3 |
| +TL-Align | 224×224 | 5.7M | 1.6G | 73.2 | 91.7 |
| DeiT-S | 224×224 | 22M | 4.6G | 79.8 | 95.0 |
| +TL-Align | 224×224 | 22M | 4.6G | 80.6 | 95.0 |
| DeiT-B | 224×224 | 86M | 17.5G | 81.8 | 95.5 |
| +TL-Align | 224×224 | 86M | 17.5G | 82.3 | 95.8 |
| Swin-T | 224×224 | 29M | 4.5G | 81.2 | 95.5 |
| +TL-Align | 224×224 | 29M | 4.5G | 81.4 | 95.7 |
| Swin-S | 224×224 | 50M | 8.8G | 83.0 | 96.3 |
| +TL-Align | 224×224 | 50M | 8.8G | 83.4 | 96.5 |
| Swin-B | 224×224 | 88M | 15.4G | 83.5 | 96.4 |
| +TL-Align | 224×224 | 88M | 15.4G | 83.7 | 96.5 |
This repository is built upon the Timm library and the DeiT repository.
You need to install PyTorch 1.7.0+, torchvision 0.8.1+, and pytorch-image-models (timm) 0.3.2:

```
conda install -c pytorch pytorch torchvision
pip install timm==0.3.2
```
Download and extract the ImageNet train and val images from http://image-net.org/. The directory structure is the standard layout expected by torchvision's `datasets.ImageFolder`: the training and validation data go in the `train/` and `val/` folders respectively:
```
/path/to/imagenet/
  train/
    class1/
      img1.jpeg
    class2/
      img2.jpeg
  val/
    class1/
      img3.jpeg
    class2/
      img4.jpeg
```
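To sanity-check the layout before a long training run, you can build and inspect the expected tree with the standard library alone. This is an illustrative sketch: the class and file names below are placeholders, not real ImageNet synsets, and `torchvision.datasets.ImageFolder` would map each subdirectory of a split to one class label.

```python
import os
import tempfile

# Build a tiny dummy tree matching the ImageFolder layout described above.
root = tempfile.mkdtemp()
for split in ("train", "val"):
    for cls in ("class1", "class2"):
        os.makedirs(os.path.join(root, split, cls), exist_ok=True)
        # Touch an empty placeholder file standing in for a JPEG.
        open(os.path.join(root, split, cls, "img.jpeg"), "w").close()

# ImageFolder derives class labels from these subdirectory names;
# here we only verify the directory shape with the standard library.
train_classes = sorted(os.listdir(os.path.join(root, "train")))
print(train_classes)  # ['class1', 'class2']
```

Running the same check against your real `/path/to/imagenet` should list 1000 class folders under both `train/` and `val/`.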
To enable token-label alignment during training, simply add the `--tl-align` flag to your training script. For example, for DeiT-Small, run:

```
python3 -m torch.distributed.launch --nproc_per_node=8 --use_env main_tla.py \
    --model deit_small_patch16_224 \
    --batch-size 128 \
    --mixup 0.0 \
    --tl-align \
    --data-path /path/to/imagenet \
    --output_dir /path/to/output
```

or

```
bash train_deit_small_tla.sh
```
This should give 80.6% top-1 accuracy after 300 epochs of training.
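Conceptually, token-mixing operations such as self-attention blend tokens from different source images (e.g. after a CutMix-style spatial mix), so the method propagates each token's label through the same mixing weights the tokens undergo. The NumPy sketch below illustrates this idea for a single attention matrix; the variable names and the fixed 0.5/0.5 residual blend are illustrative simplifications, not the authors' implementation:

```python
import numpy as np

rng = np.random.default_rng(0)
num_tokens, num_classes = 4, 3

# One-hot labels per token, e.g. from a spatial mix of two images.
token_labels = np.zeros((num_tokens, num_classes))
token_labels[:2, 0] = 1.0  # first two tokens come from a class-0 image
token_labels[2:, 1] = 1.0  # last two tokens come from a class-1 image

# A row-stochastic attention matrix: each output token mixes all inputs.
A = rng.random((num_tokens, num_tokens))
A /= A.sum(axis=1, keepdims=True)

# Align the labels with the same mixing applied to the tokens
# (attention plus a residual branch, blended 50/50 here for simplicity).
aligned = 0.5 * (token_labels + A @ token_labels)

# Each aligned label is still a valid distribution over classes.
print(np.allclose(aligned.sum(axis=1), 1.0))  # True
```

Because both the attention and residual updates are convex mixes, the aligned per-token labels remain probability distributions at every layer, which is what lets them supervise the output tokens directly.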
The evaluation of models trained with our token-label alignment is the same as in timm. Validation accuracy is also reported during training. For DeiT-Small, run:

```
python main_tla.py --eval --resume checkpoint.pth --model deit_small_patch16_224 --data-path /path/to/imagenet
```
If you find this project useful in your research, please cite:
```
@article{xiao2022token,
  title={Token-Label Alignment for Vision Transformers},
  author={Xiao, Han and Zheng, Wenzhao and Zhu, Zheng and Zhou, Jie and Lu, Jiwen},
  journal={arXiv preprint arXiv:2210.06455},
  year={2022}
}
```