Gongfan Fang†, Kunjun Li†, Xinyin Ma, Xinchao Wang
National University of Singapore
†: Equal Contribution
[arXiv]
This work presents TinyFusion, a learnable depth pruning method for diffusion transformers. We highlight the following key findings:
- Is calibration loss a reliable indicator? Our experiments show that pruned models with low calibration loss do not necessarily achieve good performance after fine-tuning.
- Optimizing recoverability: TinyFusion directly optimizes the recoverability of the pruned model, which leads to better performance after fine-tuning.
- Training efficiency: TinyFusion trains only ~0.9% of the original model's parameters, making the pruning process highly efficient.
- Masked KD for enhanced recovery: we propose Masked Knowledge Distillation (Masked KD), which excludes extreme activations in DiTs during knowledge transfer and significantly improves performance compared to standard fine-tuning.
pip install -r requirements.txt
Download the pre-trained TinyDiT-D14 (14 layers):
mkdir -p pretrained
wget -P pretrained https://github.com/VainF/TinyFusion/releases/download/v1.0.0/TinyDiT-D14-MaskedKD-500K.pt
Then generate images with it:
python sample.py --model DiT-D14/2 --ckpt pretrained/TinyDiT-D14-MaskedKD-500K.pt --seed 5464
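To prepare data for pruning and fine-tuning, encode the ImageNet training set into latent features:
torchrun --nnodes=1 --nproc_per_node=1 extract_features.py --model DiT-XL/2 --data-path data/imagenet/train --features-path data/imagenet_encoded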
Download the original DiT-XL/2 (256x256) checkpoint to be pruned:
mkdir -p pretrained
wget -P pretrained https://dl.fbaipublicfiles.com/DiT/models/DiT-XL-2-256x256.pt
The script prune_by_learning.py learns which layers to remove and derives a shallow version of the specified model. To learn a 14-layer DiT, run:
torchrun --nnodes=1 --nproc_per_node=8 prune_by_learning.py \
--model DiT-XL-1-2 \
--load-weight pretrained/DiT-XL-2-256x256.pt \
--data-path data/imagenet_encoded \
--epochs 1 \
--global-batch-size 128 \
--delta-w \
--lora \
--save-model outputs/pruned/DiT-D14-Learned.pt
- --model: Specifies the model to be pruned. DiT-XL-1-2 learns a 14-layer model with a block size of 2, removing one layer from each block.
- --data-path: Path to the encoded ImageNet features.
- --delta-w: Allows weight updates during decision optimization.
- --lora: Uses LoRA (Low-Rank Adaptation) for the weight updates. If not specified, full fine-tuning is used.
- --save-model: Path to save the pruned model.
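To give a rough idea of what a learnable layer choice can look like, the sketch below samples a soft keep-mask for one block of M layers from learnable logits over all C(M, N) candidate masks, using the Gumbel-softmax trick (the acknowledgments credit NVlabs/MaskLLM for the Gumbel softmax). It is a minimal, framework-free sketch under assumptions: the function names, the temperature, and the use of one categorical distribution per block are illustrative and not taken from prune_by_learning.py.

```python
import itertools
import math
import random


def candidate_masks(n_keep, block_size):
    """All binary keep-masks for one block that keep exactly n_keep of block_size layers."""
    masks = []
    for kept in itertools.combinations(range(block_size), n_keep):
        masks.append([1 if i in kept else 0 for i in range(block_size)])
    return masks


def gumbel_softmax(logits, tau=1.0, rng=random):
    """Relaxed one-hot sample from Categorical(softmax(logits))."""
    # Perturb each logit with Gumbel(0, 1) noise g = -log(-log(u)), u ~ Uniform(0, 1).
    perturbed = []
    for logit in logits:
        u = min(max(rng.random(), 1e-12), 1 - 1e-12)  # keep log() finite
        perturbed.append((logit - math.log(-math.log(u))) / tau)
    # Numerically stable softmax over the perturbed, temperature-scaled logits.
    peak = max(perturbed)
    exps = [math.exp(p - peak) for p in perturbed]
    total = sum(exps)
    return [e / total for e in exps]


def sample_block_mask(logits, masks, tau=1.0, rng=random):
    """Weight the candidate masks by a Gumbel-softmax sample to get a soft layer mask."""
    weights = gumbel_softmax(logits, tau, rng)
    block_size = len(masks[0])
    soft_mask = [sum(w * mask[i] for w, mask in zip(weights, masks)) for i in range(block_size)]
    return weights, soft_mask


masks = candidate_masks(n_keep=1, block_size=2)  # 1:2 pattern: keep one of two layers
logits = [0.0, 0.0]  # hypothetical learnable parameters, one per candidate mask
weights, soft_mask = sample_block_mask(logits, masks, tau=0.5, rng=random.Random(0))
print(masks, [round(w, 3) for w in weights], [round(s, 3) for s in soft_mask])
```

Lowering tau pushes the weights toward a single mask. In a trainable PyTorch version, torch.nn.functional.gumbel_softmax provides the same relaxation, and the logits would be updated by backpropagating a training loss through the soft mask.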
The script supports several models, each corresponding to a specific pruning pattern. Below are the pre-defined options:
DiT_XL_1_2, # XL with 1:2 pruning => D14
DiT_XL_2_4, # XL with 2:4 pruning => D14
DiT_XL_7_14, # XL with 7:14 pruning => D14
DiT_XL_1_4, # XL with 1:4 pruning => D7
DiT_D14_1_2, # D14 with 1:2 pruning => D7
DiT_D14_2_4 # D14 with 2:4 pruning => D7
You can also customize the pruning pattern with the groups argument, where each [N, M] entry keeps N out of M consecutive layers. The following example crafts a hybrid pattern with 2:4 and 3:4 pruning, yielding a 16-layer model:
def DiT_XL_customized(**kwargs):
    # 7 groups of 4 layers cover all 28 layers; 5 groups keep 2 and 2 groups keep 3 => 16 layers
    return DiT(depth=28, hidden_size=1152, patch_size=2, num_heads=16, groups=[[2, 4], [2, 4], [3, 4], [2, 4], [3, 4], [2, 4], [2, 4]], **kwargs)
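Before training with a custom pattern, it can help to sanity-check the groups list. Assuming each [N, M] entry keeps N of M consecutive layers, the M values should add up to the original depth (28 for DiT-XL) and the N values give the pruned depth. The hypothetical helper below also counts the distinct layer selections the pattern allows (the product of C(M, N) over groups), i.e. the size of the discrete space a learned selection would choose from under that reading.

```python
from math import comb, prod


def summarize_groups(groups):
    """Return (original_depth, pruned_depth, num_candidate_configs) for [[N, M], ...] groups."""
    for n_keep, block_size in groups:
        if not 0 < n_keep <= block_size:
            raise ValueError(f"invalid group [{n_keep}, {block_size}]")
    original_depth = sum(m for _, m in groups)
    pruned_depth = sum(n for n, _ in groups)
    # Each group independently chooses which N of its M layers to keep.
    num_configs = prod(comb(m, n) for n, m in groups)
    return original_depth, pruned_depth, num_configs


# The hybrid 2:4 / 3:4 pattern from DiT_XL_customized above.
groups = [[2, 4], [2, 4], [3, 4], [2, 4], [3, 4], [2, 4], [2, 4]]
original_depth, pruned_depth, num_configs = summarize_groups(groups)
assert original_depth == 28, "groups must cover all 28 DiT-XL layers"
print(original_depth, pruned_depth, num_configs)  # 28 16 124416
```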
The script prune_by_score.py uses the similarity between each layer's input and output as its importance score and removes the least important layers (--n-pruned sets how many). Please refer to ShortGPT for more details.
python prune_by_score.py --model DiT-XL/2 --ckpt pretrained/DiT-XL-2-256x256.pt --save-model outputs/pruned/DiT-D14-Score.pt --n-pruned 14
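For intuition, ShortGPT's Block Influence (BI) scores layer i as one minus the average cosine similarity between the hidden states entering and leaving it: a layer that barely changes its input gets a low score and is removed first. The sketch below assumes that convention and works on plain lists standing in for per-token hidden states; how prune_by_score.py collects those states (timesteps, calibration batches) is not covered here.

```python
import math


def cosine(a, b):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(y * y for y in b))
    return dot / (na * nb + 1e-12)


def block_influence(layer_inputs, layer_outputs):
    """BI = 1 - average token-wise cosine similarity between a layer's input and output."""
    sims = [cosine(x, y) for x, y in zip(layer_inputs, layer_outputs)]
    return 1.0 - sum(sims) / len(sims)


def select_layers(hidden_states, n_pruned):
    """hidden_states[i] holds the tokens entering layer i; hidden_states[i + 1] leaves it.

    Returns (kept layer indices in order, per-layer BI scores) after dropping the
    n_pruned layers with the lowest BI.
    """
    num_layers = len(hidden_states) - 1
    scores = [block_influence(hidden_states[i], hidden_states[i + 1]) for i in range(num_layers)]
    ranked = sorted(range(num_layers), key=lambda i: scores[i])  # least important first
    pruned = set(ranked[:n_pruned])
    return [i for i in range(num_layers) if i not in pruned], scores


# Toy example: 3 layers, 2 tokens of dimension 2 each.
hidden_states = [
    [[1.0, 0.0], [0.0, 1.0]],  # input to layer 0
    [[0.0, 1.0], [1.0, 0.0]],  # layer 0 maps each token to an orthogonal vector (BI = 1)
    [[0.0, 2.0], [2.0, 0.0]],  # layer 1 only rescales (BI = 0)
    [[1.0, 1.0], [1.0, 1.0]],  # layer 2 rotates each token by 45 degrees (BI = 1 - cos 45)
]
kept, scores = select_layers(hidden_states, n_pruned=1)
print(kept, [round(s, 3) for s in scores])
```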
For oracle pruning, we follow BK-SDM, which keeps the first layer of each "encoder block" and the last layer of each "decoder block". For DiT, we treat the first 14 layers as the encoder and the last 14 layers as the decoder, and group every 2 layers into a block. The oracle pruning can be performed with:
python prune_by_index.py --model DiT-XL/2 --ckpt pretrained/DiT-XL-2-256x256.pt --kept-indices "[0, 2, 4, 6, 8, 10, 12, 14, 17, 19, 21, 23, 25, 27]" --save-model outputs/pruned/DiT-D14-Oracle.pt
To prune a model with predefined indices, use the following command:
python prune_by_index.py --model DiT-XL/2 --ckpt pretrained/DiT-XL-2-256x256.pt --save-model outputs/pruned/DiT-D14-by-Score.pt --kept-indices "[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]"
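Conceptually, pruning by index keeps the selected transformer blocks and renumbers them consecutively. The sketch below shows that renumbering on a checkpoint-style dictionary, assuming block keys of the form blocks.{i}.… as in facebookresearch/DiT, whose DiT class stores its layers in a ModuleList named blocks. It is an illustration, not the code of prune_by_index.py, and a plain dict with strings stands in for a PyTorch state dict.

```python
import re

BLOCK_KEY = re.compile(r"^blocks\.(\d+)\.(.+)$")


def keep_blocks(state_dict, kept_indices):
    """Drop transformer blocks not in kept_indices and renumber the rest 0..K-1.

    Non-block entries (embedders, final layer, ...) are copied unchanged.
    """
    new_index = {old: new for new, old in enumerate(sorted(kept_indices))}
    pruned = {}
    for key, value in state_dict.items():
        match = BLOCK_KEY.match(key)
        if match is None:
            pruned[key] = value  # not part of a transformer block
            continue
        old = int(match.group(1))
        if old in new_index:
            pruned[f"blocks.{new_index[old]}.{match.group(2)}"] = value
    return pruned


# Toy checkpoint with 4 blocks; strings stand in for tensors.
toy = {"x_embedder.proj.weight": "emb"}
for i in range(4):
    toy[f"blocks.{i}.attn.qkv.weight"] = f"qkv{i}"
toy["final_layer.linear.weight"] = "head"

pruned = keep_blocks(toy, kept_indices=[0, 2])
print(sorted(pruned))
```

With PyTorch, the same function could be applied to the dictionary returned by torch.load before loading it into a model built with depth=len(kept_indices).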
Fine-tune the pruned model with standard training:
torchrun --nnodes=1 --nproc_per_node=8 train.py --model DiT-D14/2 --load-weight outputs/pruned/DiT-D14-Learned.pt --data-path data/imagenet_encoded --epochs 100 --prefix D14-Learned-Finetuning
Alternatively, fine-tune with the proposed Masked KD, which masks massive activations in the teacher's and student's hidden states during distillation. Please see the paper for more details.
# Masked KD
torchrun --nnodes=1 --nproc_per_node=8 train_masked_kd.py --model DiT-D14/2 --load-weight outputs/pruned/DiT-D14-Learned.pt --data-path data/imagenet_encoded --epochs 100 --prefix D14-Learned-RepKD --teacher DiT-XL/2 --load-teacher pretrained/DiT-XL-2-256x256.pt
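A rough sketch of the masking idea: compute a distillation loss between student and teacher hidden states, but skip positions where either model shows a massive activation. Here "massive" is assumed to mean an absolute value above k times the standard deviation of that hidden state, with k a hypothetical hyperparameter. The actual criterion, the matched layers, and how this term is combined with the diffusion loss are defined in the paper and train_masked_kd.py and may differ.

```python
import math


def keep_mask(values, k):
    """1 where |value| <= k * std(values), 0 for 'massive' entries (assumed criterion)."""
    n = len(values)
    mean = sum(values) / n
    std = math.sqrt(sum((v - mean) ** 2 for v in values) / n)
    return [1 if abs(v) <= k * std else 0 for v in values]


def masked_kd_loss(student, teacher, k=2.0):
    """MSE between student and teacher hidden states, skipping massive positions."""
    keep_s = keep_mask(student, k)
    keep_t = keep_mask(teacher, k)
    total, count = 0.0, 0
    for s, t, ms, mt in zip(student, teacher, keep_s, keep_t):
        if ms and mt:  # drop the position if either model is extreme there
            total += (s - t) ** 2
            count += 1
    return total / count if count else 0.0


# Toy hidden states (one flattened feature vector each); the teacher has one massive value.
student = [0.0, 0.1, -0.1, 0.2, 0.0, 0.1, -0.2, 0.1]
teacher = [0.1, 0.1, -0.1, 0.1, 40.0, 0.1, -0.1, 0.1]

plain = sum((s - t) ** 2 for s, t in zip(student, teacher)) / len(student)
masked = masked_kd_loss(student, teacher, k=2.0)
print(f"plain MSE = {plain:.4f}, masked MSE = {masked:.4f}")
```

Without the mask, the single extreme teacher value dominates the loss; with it, the student is matched only on the remaining positions.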
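To evaluate a fine-tuned model, first generate samples on multiple GPUs (sample_ddp.py saves them as an .npz file):
torchrun --nnodes=1 --nproc_per_node=8 sample_ddp.py --model DiT-D14/2 --ckpt outputs/D14-Learned-Finetuning/checkpoints/0500000.pt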
Then download the reference batch VIRTUAL_imagenet256_labeled.npz from https://github.com/openai/guided-diffusion/tree/main/evaluations and compute the metrics:
python evaluator.py data/VIRTUAL_imagenet256_labeled.npz PATH_TO_YOUR.npz
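Among the metrics reported by evaluator.py is FID, which compares Gaussian fits of Inception features of the reference batch (mean \mu_r, covariance \Sigma_r) and of the generated samples (\mu_g, \Sigma_g); lower is better:

```latex
\mathrm{FID} = \lVert \mu_r - \mu_g \rVert_2^2
  + \operatorname{Tr}\!\left( \Sigma_r + \Sigma_g - 2 \left( \Sigma_r \Sigma_g \right)^{1/2} \right)
```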
We show that incorporating recoverability estimation improves performance after downstream fine-tuning.
Masked KD removes extreme activations in DiTs for better and more stable knowledge transfer.
We also extend our method to other models such as MAR and SiT, and the results suggest that TinyFusion generalizes well across diverse diffusion transformer architectures.
This project is built on facebookresearch/DiT. We also use NVlabs/MaskLLM for Gumbel softmax and openai/guided-diffusion for evaluation.
@article{fang2024tinyfusion,
title={TinyFusion: Diffusion Transformers Learned Shallow},
author={Fang, Gongfan and Li, Kunjun and Ma, Xinyin and Wang, Xinchao},
journal={arXiv preprint arXiv:2412.01199},
year={2024}
}