This repo contains the code for our paper "Less is More: Task-aware Layer-wise Distillation for Language Model Compression" (ICML2023). We propose TED, a task-aware layerwise distillation approach for language model compression.
[Aug 27, 2023] We have released the code for GLUE experiments (including KD and LWD baselines) upon requests. All teacher and student checkpoints for GLUE and SQuAD 2.0 have been released. Please submit an issue if you need examples for other tasks and models.
- Pull and run docker
pytorch/pytorch:2.0.0-cuda11.7-cudnn8-devel
- Install requirements
cd task-aware-distillation pip install -r requirements.txt
Given a teacher model and a student model, we first fine-tune them on the target task if they have not yet been fine-tuned. You can run
./scripts/finetune_glue.sh
We will then use these fine-tuned models to initialize the teacher and student in Stage I. Alternatively, you may directly use the fine-tuned models we released on Huggingface model hub:
Task | DeBERTaV3-base | DeBERTaV3-xs |
---|---|---|
MNLI | 90.6 [ckpt] | 88.3 [ckpt] |
QQP | 89.9 [ckpt] | 88.9 [ckpt] |
QNLI | 94.1 [ckpt] | 92.6 [ckpt] |
SST-2 | 96.4 [ckpt] | 93.6 [ckpt] |
RTE | 86.3 [ckpt] | 83.4 [ckpt] |
CoLA | 68.4 [ckpt] | 66.0 [ckpt] |
MRPC | 93.5 [ckpt] | 93.0 [ckpt] |
STS-B | 91.5 [ckpt] | 90.4 [ckpt] |
Given properly initialized teacher and student models, we add a set of layerwise task-aware filters to each of them. We train the filters on the target task to extract the task-relevant knowledge, while keeping the model frozen. You can run
./scripts/learn_filters_glue.sh
Stage I has several important hyperparameters:
--filter_interval
: The layer interval between consecutive filters. Default to be one. If it is not necessary to learn a filter for every layer (e.g., a 12-layer teacher only needs 6 filters to match with a 6-layer student), you may specify a largerfilter_interval
to save the training cost (e.g., 12/6 = 2).--filter_output_dim
: The output dimension of each filter. Default to be the hidden dimension of the model. If the teacher and the student models have different hidden dimensions, you may specify a singlefilter_output_dim
for both of them for dimension matching.--filter_nonlinear
: A boolean value of whether to add nonlinearities in filters. Default to be false.
We will then use the models equipped with learned filters to initialize the teacher and student in Stage II. Alternatively, you may directly use the models with learned filters we released:
Task | DeBERTaV3-base (Stage I) | DeBERTaV3-xs (Stage I) |
---|---|---|
MNLI | [ckpt] | [ckpt] |
QQP | [ckpt] | [ckpt] |
QNLI | [ckpt] | [ckpt] |
SST-2 | [ckpt] | [ckpt] |
RTE | [ckpt] | [ckpt] |
CoLA | [ckpt] | [ckpt] |
MRPC | [ckpt] | [ckpt] |
STS-B | [ckpt] | [ckpt] |
Finally, we conduct task-aware distillation (TED). We distill the student model and its filters to match its filtered output representations to those of the teacher model, while keeping the teacher model and its filters frozen. You can run
./scripts/ted_glue.sh
In Stage II, there are two new hyperparameters:
--kl_alpha
: The weighting of the KL-divergence loss between the teacher's and the student's final-layer output predictions.--mse_alpha
: The weighting of the MSE loss between the output representations of the teacher's and the student's filters. The student is optimized based onpred_loss + kl_alpha * kl_loss + mse_alpha * layer_averaged_mse_loss
.
The final task-aware distilled student checkpoints have been released at:
Task | DeBERTaV3-xs (Stage II) |
---|---|
MNLI | 88.8 [ckpt] |
QQP | 89.5 [ckpt] |
QNLI | 93.1 [ckpt] |
SST-2 | 94.4 [ckpt] |
RTE | 83.0 [ckpt] |
CoLA | 68.3 [ckpt] |
MRPC | 93.9 [ckpt] |
STS-B | 91.1 [ckpt] |
Our codebase also supports vanilla knowledge distillation (KD) and layerwise distillation (LWD) methods. You can run KD with
./scripts/kd_glue.sh
by adding --filter_disabled
in the argument list and setting --mse_alpha = 0
. You can run LWD with
./scripts/lwd_glue.sh
by adding --filter_disabled
in the argument list.
We first properly intialize the teacher and the student models on the target task by running:
./scripts/finetune_qa.sh
Alternatively, you may directly use the fine-tuned models we released:
Task | DeBERTaV3-base | DeBERTaV3-xs |
---|---|---|
SQuAD 2.0 | 88.4 [ckpt] | 84.5 [ckpt] |
For Stage I, we learn filters by running:
./scripts/learn_filters_qa.sh
Alternatively, you can directly use the models with learned filters we released:
Task | DeBERTaV3-base (Stage I) | DeBERTaV3-xs (Stage I) |
---|---|---|
SQuAD 2.0 | [ckpt] | [ckpt] |
For Stage II, we conduct task-aware distillation by running:
./scripts/ted_qa.sh
The final task-aware distilled student checkpoints have been released at:
Task | DeBERTaV3-xs (Stage II) |
---|---|
SQuAD 2.0 | 86.4 [ckpt] |
Our codebase also supports vanilla knowledge distillation (KD) and layerwise distillation (LWD) methods. You can run:
./scripts/kd_qa.sh
./scripts/lwd_qa.sh
@article{liang2022less,
title={Less is More: Task-aware Layer-wise Distillation for Language Model Compression},
author={Liang, Chen and Zuo, Simiao and Zhang, Qingru and He, Pengcheng and Chen, Weizhu and Zhao, Tuo},
journal={arXiv preprint arXiv:2210.01351},
year={2022}
}
Please submit a GitHub issue if you need examples for other tasks and models, or other issues related to this package. For questions related to this paper, please contact Chen Liang (cliang73@gatech.edu).