
CLEAR

Paper: https://arxiv.org/abs/2412.16112

CLEAR: Conv-Like Linearization Revs Pre-Trained Diffusion Transformers Up
Songhua Liu, Zhenxiong Tan, and Xinchao Wang
Learning and Vision Lab, National University of Singapore

🔥News

[2024/12/20] We release the training and inference code of CLEAR, a simple yet effective strategy to linearize the complexity of pre-trained diffusion transformers such as FLUX and SD3.

Introduction

Diffusion Transformers (DiT) have become a leading architecture in image generation. However, the quadratic complexity of attention mechanisms, which are responsible for modeling token-wise relationships, results in significant latency when generating high-resolution images. To address this issue, we aim at a linear attention mechanism that reduces the complexity of pre-trained DiTs to linear. We begin our exploration with a comprehensive summary of existing efficient attention mechanisms and identify four key factors crucial for successful linearization of pre-trained DiTs: locality, formulation consistency, high-rank attention maps, and feature integrity. Based on these insights, we introduce a convolution-like local attention strategy termed CLEAR, which limits feature interactions to a local window around each query token and thus achieves linear complexity. Our experiments indicate that, by fine-tuning the attention layers on merely 10K self-generated samples for 10K iterations, we can effectively transfer knowledge from a pre-trained DiT to a student model with linear complexity, yielding results comparable to the teacher model. Simultaneously, it reduces attention computations by 99.5% and accelerates generation by 6.3 times for 8K-resolution images. Furthermore, we investigate favorable properties of the distilled attention layers, such as zero-shot generalization across various models and plugins, and improved support for multi-GPU parallel inference.

TL;DR: For pre-trained diffusion transformers, restricting each image token to interact only with tokens within a local window effectively reduces the complexity of the original models to a linear scale.
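Conceptually, this corresponds to attention under a locality mask on the 2D token grid. Below is a minimal PyTorch sketch of the idea, assuming a square window and a dense boolean mask for readability; the actual implementation relies on a sparse FlexAttention kernel (and the exact window definition follows the paper), so the full mask is never materialized.

    import torch
    import torch.nn.functional as F

    def local_window_mask(h, w, window_size):
        # 2D grid coordinates of every image token, flattened to (h*w, 2)
        ys, xs = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
        coords = torch.stack([ys.flatten(), xs.flatten()], dim=-1)
        # True where query i may attend to key j (Chebyshev distance <= window_size)
        dist = (coords[:, None, :] - coords[None, :, :]).abs().amax(-1)
        return dist <= window_size

    h = w = 32                                  # a 32x32 grid of latent tokens
    q = torch.randn(1, 24, h * w, 128)          # (batch, heads, tokens, head_dim)
    k, v = torch.randn_like(q), torch.randn_like(q)
    mask = local_window_mask(h, w, window_size=8)
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)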

Installation

  • CLEAR requires torch>=2.5.0, diffusers>=0.31.0, and the other packages listed in requirements.txt. You can set up a new environment with:

    conda create -n CLEAR python=3.12
    conda activate CLEAR
    pip install -r requirements.txt
  • Clone this repo to your project directory:

    git clone https://github.com/Huage001/CLEAR.git

Supported Models

We release a series of variants of linearized FLUX.1-dev with various local window sizes.

We find experimentally that when the local window size is small, e.g., 8, the model often produces repetitive patterns. To alleviate this, some variants also include down-sampled key-value tokens in addition to the local tokens for attention interaction.
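As a purely illustrative sketch of this combination (average pooling is an assumption here, and the dense mask is again only for readability), the key/value set for every query becomes its full-resolution local window plus a down-sampled copy of the whole grid:

    import torch
    import torch.nn.functional as F

    h = w = 32
    window_size, down_factor = 8, 4
    b, heads, dim = 1, 24, 128

    q = torch.randn(b, heads, h * w, dim)
    k, v = torch.randn_like(q), torch.randn_like(q)

    def downsample(x):
        # (b, heads, h*w, dim) -> (b, heads, (h/d)*(w/d), dim) via average pooling
        grid = x.transpose(-1, -2).reshape(b * heads, dim, h, w)
        pooled = F.avg_pool2d(grid, kernel_size=down_factor)
        return pooled.flatten(-2).transpose(-1, -2).reshape(b, heads, -1, dim)

    k_all = torch.cat([k, downsample(k)], dim=2)
    v_all = torch.cat([v, downsample(v)], dim=2)

    # Local-window mask over the full-resolution keys...
    ys, xs = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
    coords = torch.stack([ys.flatten(), xs.flatten()], dim=-1)
    local = (coords[:, None, :] - coords[None, :, :]).abs().amax(-1) <= window_size
    # ...while every query may additionally attend to all down-sampled tokens.
    global_ds = torch.ones(h * w, (h // down_factor) * (w // down_factor), dtype=torch.bool)
    mask = torch.cat([local, global_ds], dim=1)

    out = F.scaled_dot_product_attention(q, k_all, v_all, attn_mask=mask)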

The supported models and the download links are:

| window_size | down_factor | link |
|-------------|-------------|------|
| 32          | N/A         | here |
| 16          | N/A         | here |
| 8           | N/A         | here |
| 16          | 4           | here |
| 8           | 4           | here |

You are encouraged to download the model weights you need into the ckpt directory beforehand. For example:

    mkdir ckpt
    wget -P ckpt https://huggingface.co/Huage001/CLEAR/resolve/main/clear_local_8_down_4.safetensors
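Alternatively, the same file can be fetched with huggingface_hub (a small Python sketch using the same repo and file name as above):

    from huggingface_hub import hf_hub_download

    hf_hub_download(
        repo_id="Huage001/CLEAR",
        filename="clear_local_8_down_4.safetensors",
        local_dir="ckpt",
    )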

Inference

  • If you want to compare the linearized FLUX with the original model, please try inference_t2i.ipynb.

  • If you want to use CLEAR for high-resolution acceleration, please try inference_t2i_highres.ipynb. We currently adopt an SDEdit-style strategy: first generate a low-resolution result, then gradually upscale the image based on it.

  • Please configure down_factor and window_size in the notebooks to use different variants of CLEAR. If you do not want to include down-sampled key-value tokens, specify down_factor=1. The models will be downloaded automatically to ckpt if they are not already there; a rough loading sketch is shown after this list.

  • Currently, a GPU with 48GB of VRAM is recommended for high-resolution generation.
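For orientation, a rough sketch of the loading path, assuming the checkpoint downloaded above (the notebooks are the reference for how the linearized attention processors are actually attached):

    import torch
    from diffusers import FluxPipeline
    from safetensors.torch import load_file

    # Load the base FLUX.1-dev pipeline with diffusers.
    pipe = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
    ).to("cuda")

    # Fine-tuned CLEAR attention weights (window_size=8, down_factor=4 variant).
    clear_weights = load_file("ckpt/clear_local_8_down_4.safetensors")
    # ... swap in the CLEAR attention processors and load `clear_weights`
    # into pipe.transformer as done in inference_t2i.ipynb.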

Training

  • Configure /path/to/t2i_1024 in multiple .sh files.

  • Download the training images from here, a set of 10K 1024-resolution images generated by FLUX.1-dev itself, and extract them to /path/to/t2i_1024:

    tar -xvf data_000000.tar -C /path/to/t2i_1024
    
  • [Optional but Recommended] Cache T5 and CLIP text embeddings and VAE features beforehand:

    bash cache_prompt_embeds.sh
    bash cache_latent_codes.sh
  • Start Training:

    bash distill.sh

    By default, training uses 4 GPUs with 80GB of VRAM each, train_batch_size=2, and gradient_accumulation_steps=4. Feel free to adjust these in distill.sh and deepspeed_config.yaml according to your setup.
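    With these defaults, the effective batch size is 4 GPUs × 2 samples per GPU × 4 accumulation steps = 32 samples per optimizer step.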

Acknowledgement

  • FLUX for the source models.
  • FlexAttention for the kernel implementation.
  • diffusers for the codebase.
  • DeepSpeed for the training framework.
  • SDEdit for high-resolution image generation.
  • @Weihao Yu and @Xinyin Ma for valuable discussions.
  • NUS IT's Research Computing group for resources provided under grant number NUSREC-HPC-00001.

Citation

If you find this repo helpful, please consider citing:

@article{liu2024clear,
  title     = {CLEAR: Conv-Like Linearization Revs Pre-Trained Diffusion Transformers Up},
  author    = {Liu, Songhua and Tan, Zhenxiong and Wang, Xinchao},
  year      = {2024},
  eprint    = {2412.16112},
  archivePrefix={arXiv},
  primaryClass={cs.CV}
}
