
DeiT

Training data-efficient image transformers & distillation through attention

Abstract

Recently, neural networks purely based on attention were shown to address image understanding tasks such as image classification. However, these visual transformers are pre-trained with hundreds of millions of images using an expensive infrastructure, thereby limiting their adoption. In this work, we produce a competitive convolution-free transformer by training on ImageNet only. We train them on a single computer in less than 3 days. Our reference vision transformer (86M parameters) achieves top-1 accuracy of 83.1% (single-crop evaluation) on ImageNet with no external data. More importantly, we introduce a teacher-student strategy specific to transformers. It relies on a distillation token ensuring that the student learns from the teacher through attention. We show the interest of this token-based distillation, especially when using a convnet as a teacher. This leads us to report results competitive with convnets for both ImageNet (where we obtain up to 85.2% accuracy) and when transferring to other tasks. We share our code and models.
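In the distilled DeiT models, the distillation token gets its own classification head alongside the class token. The following is a minimal plain-Python sketch of the hard-label distillation objective described in the DeiT paper, for a single sample: the class-token head is trained against the ground-truth label and the distillation-token head against the teacher's argmax prediction, with equal weights. The function names are hypothetical, and batching, label smoothing and the transformer itself are omitted; this is an illustration of the idea, not MMPretrain code.

```python
import math
from typing import Sequence


def log_softmax(logits: Sequence[float]) -> list[float]:
    """Numerically stable log-softmax over one vector of logits."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def cross_entropy(logits: Sequence[float], target: int) -> float:
    """Cross-entropy of a single sample against an integer class label."""
    return -log_softmax(logits)[target]


def hard_distillation_loss(
    cls_logits: Sequence[float],
    dist_logits: Sequence[float],
    teacher_logits: Sequence[float],
    label: int,
) -> float:
    """Hard-label distillation loss for one sample (hypothetical helper).

    cls_logits:     output of the class-token head of the student
    dist_logits:    output of the distillation-token head of the student
    teacher_logits: output of the (e.g. convnet) teacher
    """
    # The teacher's hard decision acts as the target for the distillation head.
    teacher_label = max(range(len(teacher_logits)), key=teacher_logits.__getitem__)
    return 0.5 * cross_entropy(cls_logits, label) + 0.5 * cross_entropy(dist_logits, teacher_label)


# Class head agrees with the label (0), distillation head agrees with the teacher (1).
print(hard_distillation_loss([10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 3.0, 1.0], label=0))
```

Note that the two heads can be pushed toward different classes when the teacher disagrees with the ground truth; that is what lets the distillation token carry information the label alone does not.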

How to use it?

Predict image

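from mmpretrain import inference_model

# 'demo/bird.JPEG' is relative to the MMPretrain repository root.
predict = inference_model('deit-tiny_4xb256_in1k', 'demo/bird.JPEG')
print(predict['pred_class'])  # predicted class name
print(predict['pred_score'])  # confidence of the predicted class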

Use the model

import torch
from mmpretrain import get_model

# Build DeiT-tiny and load its ImageNet-1k checkpoint.
model = get_model('deit-tiny_4xb256_in1k', pretrained=True)
inputs = torch.rand(1, 3, 224, 224)  # dummy batch: one RGB image at 224x224
out = model(inputs)
print(type(out))
# To extract features.
feats = model.extract_feat(inputs)
print(type(feats))

Train/Test Command

Prepare your dataset according to the MMPretrain dataset preparation docs.

Train:

python tools/train.py configs/deit/deit-tiny_4xb256_in1k.py
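Config names follow the OpenMMLab convention where a token such as `4xb256` means 4 GPUs with 256 images per GPU, so both `deit-tiny_4xb256_in1k` and `deit-base_16xb64_in1k` train with a total batch of 1024. The DeiT paper scales the learning rate linearly as lr = 0.0005 × batch size / 512. The sketch below (hypothetical helper names) computes both values; it assumes the config follows the paper's rule, so check the actual config file before relying on the number.

```python
import re


def total_batch_size(config_name: str) -> int:
    """Parse an OpenMMLab-style '<gpus>xb<per_gpu_batch>' token from a config name."""
    match = re.search(r"_(\d+)xb(\d+)_", config_name)
    if match is None:
        raise ValueError(f"no '<N>xb<B>' token in {config_name!r}")
    gpus, per_gpu = map(int, match.groups())
    return gpus * per_gpu


def deit_scaled_lr(total_batch: int, base_lr: float = 5e-4, base_batch: int = 512) -> float:
    """Linear learning-rate scaling rule from the DeiT paper."""
    return base_lr * total_batch / base_batch


for name in ("deit-tiny_4xb256_in1k", "deit-base_16xb64_in1k"):
    batch = total_batch_size(name)
    print(name, batch, deit_scaled_lr(batch))
```

If you train on a different number of GPUs than the name indicates, the total batch changes, and under this rule the learning rate should change with it.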

Test:

python tools/test.py configs/deit/deit-tiny_4xb256_in1k.py https://download.openmmlab.com/mmclassification/v0/deit/deit-tiny_pt-4xb256_in1k_20220218-13b382a0.pth
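The test command reports Top-1 and Top-5 accuracy, as in the results table. As a reminder of what those numbers mean, here is a small plain-Python sketch of top-k accuracy: a sample counts as correct if its true label is among the k highest-scoring classes. It is an illustration of the metric, not MMPretrain's evaluator, and the function name is hypothetical.

```python
from typing import Sequence


def topk_accuracy(scores: Sequence[Sequence[float]], labels: Sequence[int], k: int) -> float:
    """Percentage of samples whose true label is among the k highest-scoring classes."""
    hits = 0
    for row, label in zip(scores, labels):
        # Indices of the k largest scores for this sample.
        top_k = sorted(range(len(row)), key=row.__getitem__, reverse=True)[:k]
        hits += label in top_k
    return 100.0 * hits / len(labels)


scores = [[0.1, 0.7, 0.2], [0.5, 0.3, 0.2], [0.2, 0.3, 0.5]]
labels = [1, 1, 0]
print(topk_accuracy(scores, labels, k=1), topk_accuracy(scores, labels, k=2))
```

Top-5 is always at least as high as Top-1, since every top-1 hit is also a top-5 hit.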

Models and results

Image Classification on ImageNet-1k

| Model | Pretrain | Params (M) | Flops (G) | Top-1 (%) | Top-5 (%) | Config | Download |
| :--- | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
| deit-tiny_4xb256_in1k | From scratch | 5.72 | 1.26 | 74.50 | 92.24 | config | model, log |
| deit-tiny-distilled_3rdparty_in1k* | From scratch | 5.91 | 1.27 | 74.51 | 91.90 | config | model |
| deit-small_4xb256_in1k | From scratch | 22.05 | 4.61 | 80.69 | 95.06 | config | model, log |
| deit-small-distilled_3rdparty_in1k* | From scratch | 22.44 | 4.63 | 81.17 | 95.40 | config | model |
| deit-base_16xb64_in1k | From scratch | 86.57 | 17.58 | 81.76 | 95.81 | config | model, log |
| deit-base_3rdparty_in1k* | From scratch | 86.57 | 17.58 | 81.79 | 95.59 | config | model |
| deit-base-distilled_3rdparty_in1k* | From scratch | 87.34 | 17.67 | 83.33 | 96.49 | config | model |
| deit-base_224px-pre_3rdparty_in1k-384px* | 224px | 86.86 | 55.54 | 83.04 | 96.31 | config | model |
| deit-base-distilled_224px-pre_3rdparty_in1k-384px* | 224px | 87.63 | 55.65 | 85.55 | 97.35 | config | model |

Models with * are converted from the official repo. Their config files are for inference only; we have not reproduced their training results.
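The Params (M) column can be reproduced from the architecture alone. The sketch below counts parameters of a standard ViT-style DeiT, assuming patch size 16, depth 12, MLP ratio 4, biased qkv and projection layers, a 1000-class linear head, and embedding widths of 192, 384 and 768 for tiny, small and base. A distilled model adds one distillation token, one more position embedding and a second head; the 384px models only enlarge the position embedding (576 patches instead of 196). The function name is hypothetical.

```python
def deit_params(embed_dim: int, *, depth: int = 12, img_size: int = 224, patch_size: int = 16,
                mlp_ratio: int = 4, num_classes: int = 1000, distilled: bool = False) -> int:
    """Count learnable parameters of a ViT/DeiT classifier (hypothetical helper)."""
    d = embed_dim
    extra_tokens = 2 if distilled else 1                  # class token (+ distillation token)
    num_patches = (img_size // patch_size) ** 2
    patch_embed = 3 * patch_size**2 * d + d                # conv weight + bias
    tokens = extra_tokens * d                              # learnable class/distillation tokens
    pos_embed = (num_patches + extra_tokens) * d           # one position embedding per token
    hidden = mlp_ratio * d
    block = (
        2 * (2 * d)                # two LayerNorms (weight + bias)
        + (d * 3 * d + 3 * d)      # qkv projection
        + (d * d + d)              # attention output projection
        + (d * hidden + hidden)    # MLP fc1
        + (hidden * d + d)         # MLP fc2
    )
    final_norm = 2 * d
    head = extra_tokens * (d * num_classes + num_classes)  # one linear head per token
    return patch_embed + tokens + pos_embed + depth * block + final_norm + head


for name, dim, kwargs in [
    ("tiny", 192, {}), ("tiny-distilled", 192, {"distilled": True}),
    ("small", 384, {}), ("small-distilled", 384, {"distilled": True}),
    ("base", 768, {}), ("base-distilled", 768, {"distilled": True}),
    ("base-384px", 768, {"img_size": 384}),
    ("base-distilled-384px", 768, {"img_size": 384, "distilled": True}),
]:
    print(f"{name:22s} {deit_params(dim, **kwargs) / 1e6:.2f}M")
```

Under these assumptions every count rounds to the value in the table, which suggests the distilled checkpoints differ from the plain ones only by the extra token and head.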

MMPretrain doesn't support training the distilled versions of DeiT; the distilled checkpoints are provided for inference only.
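At inference time a distilled DeiT has two heads, one on the class token and one on the distillation token, whose outputs must be fused into a single prediction. The reference DeiT implementation averages the two logit vectors; the plain-Python sketch below assumes that rule (whether MMPretrain's head fuses them identically is an assumption here), and the function name is hypothetical.

```python
from typing import Sequence


def fuse_distilled_logits(cls_logits: Sequence[float], dist_logits: Sequence[float]) -> list[float]:
    """Average the class-token and distillation-token logits (assumed fusion rule)."""
    if len(cls_logits) != len(dist_logits):
        raise ValueError("both heads must produce the same number of classes")
    return [(c + d) / 2 for c, d in zip(cls_logits, dist_logits)]


fused = fuse_distilled_logits([2.0, 0.0, 1.0], [0.0, 4.0, 1.0])
prediction = max(range(len(fused)), key=fused.__getitem__)  # index of the highest fused score
print(fused, prediction)
```

Because only this fused forward pass is needed, inference works without a teacher model, whereas training the distillation head would also require running the teacher.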

Citation

@InProceedings{pmlr-v139-touvron21a,
  title =     {Training data-efficient image transformers & distillation through attention},
  author =    {Touvron, Hugo and Cord, Matthieu and Douze, Matthijs and Massa, Francisco and Sablayrolles, Alexandre and Jegou, Herve},
  booktitle = {International Conference on Machine Learning},
  pages =     {10347--10357},
  year =      {2021},
  volume =    {139},
  month =     {July}
}