SimpleSDXL

This repository contains a simple and flexible PyTorch implementation of Stable Diffusion XL (SDXL) based on diffusers.

Preparation

  • Download the SDXL-base and SDXL-refiner checkpoints, including the scheduler, text_encoder, text_encoder_2, tokenizer, tokenizer_2, unet, and vae components, and put them in the ckpt folder.
  • We recommend using git-lfs to download the Hugging Face checkpoints directly:
yum install git-lfs
git lfs install
git clone https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0
git clone https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0
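After cloning, a quick sanity check can confirm that every component folder is present. The sketch below is not part of this repository; it assumes the standard diffusers checkpoint layout, where model_index.json lists each component as "name": [library, class] and marks unused components with a null library. It only checks that the folders exist, not that git-lfs actually fetched the weight files (an incomplete LFS pull leaves small pointer files behind).

```python
import json
import sys
from pathlib import Path


def missing_components(ckpt_dir):
    """Return the component subfolders listed in model_index.json that are absent.

    Assumes the diffusers convention: component entries look like
    "unet": ["diffusers", "UNet2DConditionModel"], and entries whose library
    is null (components the pipeline does not use) are skipped.
    """
    ckpt_dir = Path(ckpt_dir)
    index = json.loads((ckpt_dir / "model_index.json").read_text())
    missing = []
    for name, spec in index.items():
        # Keys such as "_class_name" are metadata; non-list values are flags.
        if name.startswith("_") or not isinstance(spec, list):
            continue
        library = spec[0] if spec else None
        if library is None:
            continue
        if not (ckpt_dir / name).is_dir():
            missing.append(name)
    return sorted(missing)


if __name__ == "__main__":
    # Each command-line argument is a checkpoint directory to check.
    for path in sys.argv[1:]:
        gaps = missing_components(path)
        print(f"{path}: {'OK' if not gaps else 'missing ' + ', '.join(gaps)}")
```

Saved as a hypothetical check_ckpt.py, and assuming both repositories were cloned inside ckpt, it could be run as:
python check_ckpt.py ckpt/stable-diffusion-xl-base-1.0 ckpt/stable-diffusion-xl-refiner-1.0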

Requirements

A suitable conda environment named ldm can be created and activated with:

conda env create -f environment.yaml
conda activate ldm

Dataset Preparation

  • You need to write a DataLoader for your own dataset; the repository only provides a simple example for testing the code.
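As a starting point, the sketch below shows a minimal map-style dataset. It is a hypothetical example, not the repository's own loader: it assumes each image sits next to a caption file with the same stem (for example data/0001.jpg and data/0001.txt), and the field names "image" and "caption" are illustrative. Check the repository's example to see which keys train.py actually expects. Because it implements __len__ and __getitem__, an instance can be passed to torch.utils.data.DataLoader.

```python
from pathlib import Path

# Assumed (hypothetical) image extensions; adjust to your data.
IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png", ".webp"}


class ImageCaptionDataset:
    """Map-style dataset pairing each image with a same-stem .txt caption."""

    def __init__(self, root, transform=None):
        self.root = Path(root)
        # Optional callable that turns an image path into a model input
        # (e.g. load, resize, and convert to a tensor).
        self.transform = transform
        self.items = []
        for image_path in sorted(self.root.iterdir()):
            caption_path = image_path.with_suffix(".txt")
            # Keep only images that have a matching caption file.
            if image_path.suffix.lower() in IMAGE_SUFFIXES and caption_path.is_file():
                self.items.append((image_path, caption_path))

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        image_path, caption_path = self.items[idx]
        caption = caption_path.read_text(encoding="utf-8").strip()
        if self.transform is not None:
            image = self.transform(image_path)
        else:
            # A plain string still batches under PyTorch's default collate.
            image = str(image_path)
        return {"image": image, "caption": caption}
```

With a hypothetical load_image function that returns a tensor, usage would look like DataLoader(ImageCaptionDataset("data", transform=load_image), batch_size=4, shuffle=True). Keeping image decoding in the transform lets the same dataset class serve both quick tests (paths only) and real training.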

Training

Single GPU:
CUDA_VISIBLE_DEVICES=0 python train.py

Multiple GPUs (here GPUs 0, 1, and 2) with Accelerate:
CUDA_VISIBLE_DEVICES=0,1,2 accelerate launch --multi_gpu train.py

Inference

CUDA_VISIBLE_DEVICES=0 python inference.py --prompt "A cat is running in the rain."
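To cross-check outputs, it can help to generate an image with the stock diffusers SDXL pipeline from the same checkpoint. The sketch below uses diffusers' StableDiffusionXLPipeline, not this repository's inference.py. The checkpoint path, output file name, and seed are assumptions, and it requires diffusers, a CUDA GPU, and the downloaded weights.

```python
def generate_reference(prompt, ckpt_dir="ckpt/stable-diffusion-xl-base-1.0",
                       out_path="reference.png", seed=0):
    """Generate one image with the stock diffusers SDXL pipeline (not this repo's code).

    ckpt_dir, out_path, and seed are illustrative defaults.
    """
    # Imported lazily so the function can be defined without torch/diffusers installed.
    import torch
    from diffusers import StableDiffusionXLPipeline

    pipe = StableDiffusionXLPipeline.from_pretrained(ckpt_dir, torch_dtype=torch.float16)
    pipe = pipe.to("cuda")
    # A fixed seed makes the reference image reproducible across runs.
    generator = torch.Generator(device="cuda").manual_seed(seed)
    image = pipe(prompt, generator=generator).images[0]
    image.save(out_path)
    return out_path
```

Calling generate_reference("A cat is running in the rain.") would write reference.png for side-by-side comparison with the output of inference.py.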

TODO

  • Base Model Training & Inference Code
  • (Soon) Refiner Model Training & Inference Code
  • (Soon) Fix Bugs in Mixed-Precision Training
  • (Soon) Fix Other Bugs

Acknowledgements

Many thanks to the diffusers and SimpleSDM code bases.