CLEAR: Conv-Like Linearization Revs Pre-Trained Diffusion Transformers Up
Songhua Liu, Zhenxiong Tan, and Xinchao Wang
Learning and Vision Lab, National University of Singapore
[2024/12/20] We release the training and inference code of CLEAR, a simple-yet-effective strategy to linearize the complexity of pre-trained diffusion transformers, such as FLUX and SD3.
Diffusion Transformers (DiT) have become a leading architecture in image generation. However, the quadratic complexity of attention mechanisms, which are responsible for modeling token-wise relationships, results in significant latency when generating high-resolution images. To address this issue, in this paper we aim at a linear attention mechanism that reduces the complexity of pre-trained DiTs to linear scale. We begin our exploration with a comprehensive summary of existing efficient attention mechanisms and identify four key factors crucial for successful linearization of pre-trained DiTs: locality, formulation consistency, high-rank attention maps, and feature integrity. Based on these insights, we introduce a convolution-like local attention strategy termed CLEAR, which limits feature interactions to a local window around each query token and thus achieves linear complexity. Our experiments indicate that, by fine-tuning the attention layer on merely 10K self-generated samples for 10K iterations, we can effectively transfer knowledge from a pre-trained DiT to a student model with linear complexity, yielding results comparable to the teacher model. At the same time, it reduces attention computations by 99.5% and accelerates generation by 6.3 times when generating 8K-resolution images. Furthermore, we investigate favorable properties of the distilled attention layers, such as zero-shot generalization across various models and plugins, and improved support for multi-GPU parallel inference.
TL;DR: For pre-trained diffusion transformers, restricting each image token to interact only with tokens within a local window effectively reduces the complexity of the original models to a linear scale.
- CLEAR requires `torch>=2.5.0`, `diffusers>=0.31.0`, and other packages listed in `requirements.txt`. You can set up a new environment with:

  conda create -n CLEAR python=3.12
  conda activate CLEAR
  pip install -r requirements.txt
- Clone this repo to your project directory:

  git clone https://github.com/Huage001/CLEAR.git
We release a series of variants of linearized FLUX-1.dev with various local window sizes.
We find experimentally that when the local window size is small, e.g., 8, the model often produces repetitive patterns. To alleviate this problem, some variants also include down-sampled key-value tokens, in addition to the local tokens, for attention interaction (a minimal sketch of the local-window idea follows the table below).
The supported models and the download links are:
| window_size | down_factor | link |
|---|---|---|
| 32 | NA | here |
| 16 | NA | here |
| 8 | NA | here |
| 16 | 4 | here |
| 8 | 4 | here |
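To make the mechanism concrete, the snippet below is a minimal sketch of convolution-like local attention using PyTorch's FlexAttention (the kernel backend credited in the acknowledgments), not the repository's exact implementation: it uses a square local window for simplicity, omits the down-sampled key-value tokens, and the grid size, head count, and head dimension are made-up values.

```python
# Minimal sketch (not the repo's exact kernel): convolution-like local attention
# over a flattened latent token grid using FlexAttention (requires torch>=2.5).
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

grid_h = grid_w = 64      # assumed latent grid size (illustrative)
window_size = 8           # local window size, matching the "window_size" column above

def local_window(b, h, q_idx, kv_idx):
    # Recover 2D coordinates from flattened token indices and keep only
    # key/value tokens within window_size // 2 of the query along each axis.
    qy, qx = q_idx // grid_w, q_idx % grid_w
    ky, kx = kv_idx // grid_w, kv_idx % grid_w
    return ((qy - ky).abs() <= window_size // 2) & ((qx - kx).abs() <= window_size // 2)

n_tokens = grid_h * grid_w
block_mask = create_block_mask(local_window, B=None, H=None,
                               Q_LEN=n_tokens, KV_LEN=n_tokens)

q = torch.randn(1, 24, n_tokens, 128, device="cuda", dtype=torch.bfloat16)
k, v = torch.randn_like(q), torch.randn_like(q)
out = flex_attention(q, k, v, block_mask=block_mask)  # only local query-key pairs are computed
```

Because each query attends to a constant number of neighbors, the cost grows linearly with the number of tokens instead of quadratically.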
You are encouraged to download the model weights you need to `ckpt` beforehand. For example:
mkdir ckpt
wget https://huggingface.co/Huage001/CLEAR/resolve/main/clear_local_8_down_4.safetensors
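For reference, each checkpoint is a plain safetensors state dict. Below is a hypothetical loading sketch; the notebooks handle the real setup (including swapping in the local-attention processors), and the key-to-module mapping assumed here may not match the actual checkpoint.

```python
# Hypothetical sketch: inspect and load a downloaded CLEAR checkpoint.
# The key layout is an assumption; see the notebooks for the actual loading logic.
import torch
from diffusers import FluxPipeline
from safetensors.torch import load_file

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev",
                                    torch_dtype=torch.bfloat16)

state_dict = load_file("ckpt/clear_local_8_down_4.safetensors")
print(f"{len(state_dict)} tensors, e.g. {next(iter(state_dict))}")

# Load the fine-tuned attention weights on top of the pre-trained transformer,
# reporting any keys that do not line up.
missing, unexpected = pipe.transformer.load_state_dict(state_dict, strict=False)
print(f"missing: {len(missing)}, unexpected: {len(unexpected)}")
```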
- If you want to compare the linearized FLUX with the original model, please try `inference_t2i.ipynb`.
- If you want to use CLEAR for high-resolution acceleration, please try `inference_t2i_highres.ipynb`. We currently adopt the SDEdit strategy: a low-resolution result is generated first and then gradually upscaled (a simplified sketch follows this list).
- Please configure `down_factor` and `window_size` in the notebooks to use different variants of CLEAR. If you do not want to include down-sampled key-value tokens, specify `down_factor=1`. The models will be downloaded automatically to `ckpt` if they are not already there.
- Currently, a GPU with 48GB of VRAM is recommended for high-resolution generation.
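The SDEdit-style upscaling mentioned above can be sketched as follows. This is a simplified stand-in for `inference_t2i_highres.ipynb`, not its actual code: it generates a base image with the stock FLUX pipeline, upsamples it, and re-denoises it with an image-to-image pass; the model ID, resolutions, step counts, and `strength` value are illustrative assumptions, and loading the CLEAR attention weights is omitted.

```python
# Simplified SDEdit-style upscaling sketch (not the notebook's actual code):
# generate at low resolution, upsample, then partially re-noise and denoise at high resolution.
import torch
from diffusers import FluxPipeline, FluxImg2ImgPipeline

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev",
                                    torch_dtype=torch.bfloat16).to("cuda")

prompt = "a detailed oil painting of a harbor at sunset"
base = pipe(prompt, height=1024, width=1024, num_inference_steps=28).images[0]

# Re-denoise the upsampled image; `strength` controls how much noise is re-added.
refiner = FluxImg2ImgPipeline.from_pipe(pipe)
upscaled = base.resize((2048, 2048))
final = refiner(prompt, image=upscaled, height=2048, width=2048,
                strength=0.6, num_inference_steps=28).images[0]
final.save("highres_2048.png")
```

Repeating the upsample-and-refine pass takes the result to progressively higher resolutions, which is where the linearized attention pays off most.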
- Configure `/path/to/t2i_1024` in the multiple `.sh` files.
- Download the training images from here, which contain 10K 1024-resolution images generated by `FLUX-1.dev` itself, and extract them to `/path/to/t2i_1024`:

  tar -xvf data_000000.tar -C /path/to/t2i_1024

- [Optional but Recommended] Cache the T5 and CLIP text embeddings and VAE features beforehand (a minimal caching sketch follows this list):

  bash cache_prompt_embeds.sh
  bash cache_latent_codes.sh

- Start training:

  bash distill.sh

  By default, it uses 4 GPUs with 80GB of VRAM each, with `train_batch_size=2` and `gradient_accumulation_steps=4`. Feel free to adjust these in `distill.sh` and `deepspeed_config.yaml` according to your setup.
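The caching step amounts to pre-computing fixed encoder outputs once so the distillation loop only runs the transformer. Below is a minimal sketch of caching VAE latents, assuming each training image is a 1024-resolution JPEG in `/path/to/t2i_1024`; caching prompt embeddings is analogous with the T5/CLIP text encoders, and the actual scripts (`cache_prompt_embeds.sh`, `cache_latent_codes.sh`) may organize files differently.

```python
# Minimal sketch of caching VAE latent codes (the repo's cache_latent_codes.sh may differ):
# encode each training image once and save the latents to disk.
import glob, os
import torch
from diffusers import AutoencoderKL
from PIL import Image
from torchvision import transforms

data_dir = "/path/to/t2i_1024"  # same path configured in the .sh files
vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev",
                                    subfolder="vae",
                                    torch_dtype=torch.bfloat16).to("cuda")

to_tensor = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),  # map pixels to [-1, 1], as the VAE expects
])

for path in glob.glob(os.path.join(data_dir, "*.jpg")):  # assumed file extension
    image = to_tensor(Image.open(path).convert("RGB")).unsqueeze(0).to("cuda", torch.bfloat16)
    with torch.no_grad():
        latents = vae.encode(image).latent_dist.sample()
    torch.save(latents.cpu(), path.rsplit(".", 1)[0] + "_latent.pt")
```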
- FLUX for the source models.
- FlexAttention for the kernel implementation.
- diffusers for the codebase.
- DeepSpeed for the training framework.
- SDEdit for high-resolution image generation.
- @Weihao Yu and @Xinyin Ma for valuable discussions.
- NUS IT’s Research Computing group, under grant number NUSREC-HPC-00001.
If you find this repo helpful, please consider citing:
@article{liu2024clear,
  title={CLEAR: Conv-Like Linearization Revs Pre-Trained Diffusion Transformers Up},
  author={Liu, Songhua and Tan, Zhenxiong and Wang, Xinchao},
  year={2024},
  eprint={2412.16112},
  archivePrefix={arXiv},
  primaryClass={cs.CV}
}