Knowledge Distillation as Efficient Pretraining: Faster Convergence, Higher Data-efficiency, and Better Transferability
This repository contains the code and models necessary to replicate the results of our paper:
@inproceedings{he2022knowledge,
title={Knowledge Distillation as Efficient Pre-training: Faster Convergence, Higher Data-efficiency, and Better Transferability
},
author={He, Ruifei and Sun, Shuyang, and Yang, Jihan, and Bai, Song and Qi, Xiaojuan},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
year={2022}
}
Large-scale pre-training has been proven to be crucial for various computer vision tasks. However, with the increase of pre-training data amount, model architecture amount, and the private/inaccessible data, it is not very efficient or possible to pre-train all the model architectures on large-scale datasets. In this work, we investigate an alternative strategy for pre-training, namely Knowledge Distillation as Efficient Pre-training (KDEP), aiming to efficiently transfer the learned feature representation from existing pre-trained models to new student models for future downstream tasks. We observe that existing Knowledge Distillation (KD) methods are unsuitable towards pre-training since they normally distill the logits that are going to be discarded when transferred to downstream tasks. To resolve this problem, we propose a feature-based KD method with non-parametric feature dimension aligning. Notably, our method performs comparably with supervised pre-training counterparts in 3 downstream tasks and 9 downstream datasets requiring 10x less data and 5x less pre-training time.
-
Clone our repo:
git clone https://github.com/CVMI-Lab/KDEP.git
-
Install dependencies:
conda create -n KDEP python=3.7 conda activate KDEP pip install -r requirements.txt
- ImageNet-1K (Download)
- Caltech256 (Download)
- Cifar100 (Automatically downloaded when you run the code)
- DTD (Download)
- CUB-200 (Download)
- Cityscapes (Download)
- VOC (segmentation and detection, Download)
- ADE20K (Download)
- COCO (Download)
For image classification datasets (except for Caltech256), the folder structure should follow ImageNet:
data root
├─ train/
├── n01440764
│ ├── n01440764_10026.JPEG
│ ├── n01440764_10027.JPEG
│ ├── ......
├── ......
├─ val/
├── n01440764
│ ├── ILSVRC2012_val_00000293.JPEG
│ ├── ILSVRC2012_val_00002138.JPEG
│ ├── ......
├── ......
For semantic segmentation datasets, please refer to PyTorch Semantic Segmentation.
For object detection datasets, please refer to Detectron2.
-
Download teacher models (Download), and put them under
pretrained-models/
. -
You can use a provided python file
scripts/make-imgnet-subset.py
to create the 10% of ImageNet-1K data. -
Update the path of the dataset for KDEP (10% or 100% of ImageNet-1K) in
src/utils/constants.py
. -
Prepare the SVD weights for teacher models. You can download the weights we provide (Download) or generate using our provided script
scripts/gen_svd_weights.sh
.sh scripts/gen_svd_weights.sh imgnet_128k ex_gen_svd 0
-
Scripts of pre-training with KDEP are in
scripts/
. For example, you can use teacher-student pair of Microsoft ResNet50 -> ResNet18 withscripts/KDEP_MS-R50_R18.sh
by:sh scripts/KDEP_MS-R50_R18.sh imgnet_128k exp_name 90 30 5e-4 0,1,2,3 ### imgnet_128k or imgnet_full to select 10% or 100% ImageNet-1K data ### 90 is #epoch, 30 is step-lr ### 5e-4 is weight decay ### 0,1,2,3 is GPU id
You can run KDEP with different data amount and training schedules by changing the data name (imgnet_128k or imgnet_full), #epoch and step-lr, and weight decay.
Note that we do not generate the svd weights for 100% ImageNet-1K data, but directly use the svd weights generated from 10% data.
-
We use four image classification tasks: CIFAR100, DTD, Caltech256, CUB-200.
-
Scripts (
scripts/TL_img-cls_R18.sh
andscripts/TL_img-cls_mnv2.sh
) are provided for running all four tasks twice for a distilled student (R18/mnv2).sh scripts/TL_img-cls_R18.sh exp_name # note the exp_name here should be identical to that of the distilled student
-
We use three semantic segmentation tasks: Cityscapes, VOC2012, ADE20K.
-
Transform the checkpoint into segmentation code format by
src/transform_ckpt_custom2seg.py
cd src python3 transform_ckpt_custom2seg.py exp_name # note the exp_name here should be identical to that of the distilled student
Move the transformed checkpoint to
semseg/initmodel/
. -
Scripts (
semseg/tool/TL_seg_R18.sh
andsemseg/tool/TL_seg_mnv2.sh
) are provided for running all three tasks twice for a distilled student (R18/mnv2).cd semseg sh tool/TL_seg_R18.sh ckpt_name # note the ckpt_name should be what you put into the semseg/initmodel/ in step1.
-
We use two object detection tasks: COCO and VOC.
-
Transform the checkpoint into Detectron2 format by
src/transform_ckpt_custom2det.py
cd src python3 transform_ckpt_custom2det.py exp_name R18 # note the exp_name here should be identical to that of the distilled student # R18 could be changed to mnv2
Move the transformed checkpoint to
detectron2/ckpts/
. -
Install Detectron2, and export dataset path
python3 -m pip install -e detectron2 export DETECTRON2_DATASETS='path/to/datasets'
-
Scripts (
detectron2/tool/TL_det_R18.sh
anddetectron2/tool/TL_det_mnv2.sh
) are provided for running all two tasks twice for a distilled student (R18/mnv2).cd detectron2/tool sh TL_det_R18.sh ckpt_name # note the ckpt_name should be what you put into the semseg/initmodel/ in step1.
We provide some distilled models of KDEP here.
- (Download) ResNet18, KDEP(SVD+PTS) from MS-R50 teacher on 100% ImageNet-1K data for 90 epochs.
- (Download) MobileNet-V2, KDEP(SVD+PTS) from MS-R50 teacher on 100% ImageNet-1K data for 90 epochs.
Our code is mainly based on robust-models-transfer, we also thank the open source code from PyTorch Semantic Segmentation and Detectron2.