Official PyTorch implementation of "Representation Alignment for Generation: Training Diffusion Transformers Is Easier Than You Think"

Representation Alignment for Generation:
Training Diffusion Transformers Is Easier Than You Think


Sihyun Yu¹ · Sangkyung Kwak¹ · Huiwon Jang¹ · Jongheon Jeong²
Jonathan Huang³ · Jinwoo Shin¹* · Saining Xie⁴*
¹KAIST  ²Korea University  ³Scaled Foundations  ⁴New York University
*Equal Advising

Summary: We propose REPresentation Alignment (REPA), a method that aligns noisy input states in diffusion models with representations from pretrained visual encoders. This significantly improves training efficiency and generation quality. REPA speeds up SiT training by 17.5x and achieves state-of-the-art FID=1.42.
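
For intuition, here is a minimal, hypothetical sketch of the alignment regularizer described above. The function and variable names are ours, not the repository's, and train.py may differ in details such as token matching and normalization.

import torch.nn.functional as F

def repa_alignment_loss(hidden_states, encoder_feats, projector):
    # Conceptual sketch only; names are ours, not the repository's.
    # hidden_states: (B, N, D) tokens from the diffusion transformer at layer
    #                --encoder-depth, computed on the noisy input.
    # encoder_feats: (B, N, D_enc) patch features of the clean image from a
    #                frozen pretrained encoder (e.g., DINOv2).
    # projector:     small trainable MLP mapping D -> D_enc.
    z = projector(hidden_states)                         # (B, N, D_enc)
    sim = F.cosine_similarity(z, encoder_feats, dim=-1)  # (B, N)
    return -sim.mean()                                   # maximize patch-wise similarity

# The total objective adds this term to the usual denoising loss, e.g.
# loss = denoising_loss + proj_coeff * repa_alignment_loss(h, enc_feats, projector)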

1. Environment setup

conda create -n repa python=3.9 -y
conda activate repa
pip install -r requirements.txt

2. Dataset

Dataset download

Currently, we provide experiments for ImageNet. You can place the dataset anywhere you want and specify its location via the --data-dir argument in the training scripts. Please refer to our preprocessing guide.

3. Training

accelerate launch train.py \
  --report-to="wandb" \
  --allow-tf32 \
  --mixed-precision="fp16" \
  --seed=0 \
  --path-type="linear" \
  --prediction="v" \
  --weighting="uniform" \
  --model="SiT-XL/2" \
  --enc-type="dinov2-vit-b" \
  --proj-coeff=0.5 \
  --encoder-depth=8 \
  --output-dir="exps" \
  --exp-name="linear-dinov2-b-enc8" \
  --data-dir=[YOUR_DATA_PATH]

This script will automatically create a folder under exps to save logs and checkpoints. You can adjust the following options:

  • --model: [SiT-B/2, SiT-L/2, SiT-XL/2]
  • --enc-type: [dinov2-vit-b, dinov2-vit-l, dinov2-vit-g, dinov1-vit-b, mocov3-vit-b, mocov3-vit-l, clip-vit-L, jepa-vit-h, mae-vit-l]
  • --proj-coeff: Any value larger than 0
  • --encoder-depth: Any value between 1 and the depth of the model
  • --output-dir: Any directory where you want to save checkpoints and logs
  • --exp-name: Any string (a folder with this name will be created under output-dir)

DINOv2 models are downloaded automatically from torch.hub, and CLIP models are downloaded automatically from the CLIP repository. For the other pretrained visual encoders, please download the model weights from the links below and place them at the following paths with these names (a short loading sketch follows the list):

  • dinov1: Download the ViT-B/16 model from the DINO repository and place it as ./ckpts/dinov1_vitb.pth
  • mocov3: Download the ViT-B/16 or ViT-L/16 model from the RCG repository and place them as ./ckpts/mocov3_vitb.pth or ./ckpts/mocov3_vitl.pth
  • jepa: Download the ViT-H/14 model (ImageNet-1K) from the I-JEPA repository and place it as ./ckpts/ijepa_vith.pth
  • mae: Download the ViT-L model from the MAE repository and place it as ./ckpts/mae_vitl.pth
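
As a rough, hypothetical illustration (the actual loading logic lives in the training code and may differ), DINOv2 encoders come from torch.hub while manually downloaded checkpoints are read from ./ckpts:

import torch

def load_encoder(enc_type):
    # Hypothetical helper; not the repository's exact loading code.
    if enc_type == "dinov2-vit-b":
        # DINOv2 weights are fetched automatically through torch.hub.
        return torch.hub.load("facebookresearch/dinov2", "dinov2_vitb14")
    if enc_type == "mocov3-vit-b":
        # Manually downloaded checkpoints are read from ./ckpts (see the list above);
        # mapping the state dict into the matching ViT definition is omitted here.
        return torch.load("./ckpts/mocov3_vitb.pth", map_location="cpu")
    raise ValueError(f"unsupported encoder type: {enc_type}")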

[12/17/2024]: We also support training at 512x512 resolution. Please use the following script:

accelerate launch train.py \
  --report-to="wandb" \
  --allow-tf32 \
  --mixed-precision="fp16" \
  --seed=0 \
  --path-type="linear" \
  --prediction="v" \
  --weighting="uniform" \
  --model="SiT-XL/2" \
  --enc-type="dinov2-vit-b" \
  --proj-coeff=0.5 \
  --encoder-depth=8 \
  --output-dir="exps" \
  --exp-name="linear-dinov2-b-enc8-in512" \
  --resolution=512 \
  --data-dir=[YOUR_DATA_PATH]

You also need a different data preprocessing step that resizes each image to 512x512 and encodes it into 64x64 latent vectors (using the Stable Diffusion VAE). This script is also provided in our preprocessing guide.
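
For orientation, below is a minimal standalone sketch of this preprocessing step, assuming the Stable Diffusion VAE from diffusers; the exact VAE checkpoint and transforms used in the official preprocessing guide may differ.

import torch
from PIL import Image
from torchvision import transforms
from diffusers import AutoencoderKL

# Standalone sketch; "stabilityai/sd-vae-ft-mse" and the transforms below are assumptions.
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").eval()

to_tensor = transforms.Compose([
    transforms.Resize(512),
    transforms.CenterCrop(512),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),   # scale pixels to [-1, 1]
])

img = to_tensor(Image.open("example.jpg").convert("RGB")).unsqueeze(0)
with torch.no_grad():
    # A 512x512 RGB image maps to a 4x64x64 latent (VAE downsampling factor 8).
    latent = vae.encode(img).latent_dist.sample() * 0.18215
print(latent.shape)   # torch.Size([1, 4, 64, 64])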

4. Evaluation

You can generate images (and an .npz file that can be used with the ADM evaluation suite) with the following script:

torchrun --nnodes=1 --nproc_per_node=8 generate.py \
  --model SiT-XL/2 \
  --num-fid-samples 50000 \
  --ckpt YOUR_CHECKPOINT_PATH \
  --path-type=linear \
  --encoder-depth=8 \
  --projector-embed-dims=768 \
  --per-proc-batch-size=64 \
  --mode=sde \
  --num-steps=250 \
  --cfg-scale=1.8 \
  --guidance-high=0.7

We also provide the SiT-XL/2 checkpoint (trained for 4M iterations) used in the final evaluation. It will be automatically downloaded if you do not specify --ckpt.
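
If you want a quick sanity check of the generated .npz before running the ADM evaluation suite, a minimal sketch is below; the arr_0 key is an assumption based on the ADM/DiT convention for sample batches.

import numpy as np

# Quick check of the generated sample batch; the arr_0 key and the expected
# shape are assumptions based on the ADM/DiT .npz convention.
samples = np.load("samples.npz")["arr_0"]
print(samples.shape, samples.dtype)   # e.g., (50000, 256, 256, 3) uint8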

Note

This code may not exactly reproduce the results reported in the paper due to potential human errors during the preparation and cleaning of the code for release. If you have difficulties reproducing our results, please do not hesitate to let us know. We will also carry out sanity-check experiments in the near future.

Acknowledgement

This code is mainly built upon the DiT, SiT, edm2, and RCG repositories.
We also appreciate Kyungmin Lee for providing the initial version of the implementation.

BibTeX

@article{yu2024repa,
  title={Representation Alignment for Generation: Training Diffusion Transformers Is Easier Than You Think},
  author={Sihyun Yu and Sangkyung Kwak and Huiwon Jang and Jongheon Jeong and Jonathan Huang and Jinwoo Shin and Saining Xie},
  year={2024},
  journal={arXiv preprint arXiv:2410.06940},
}
