# Install the Python dependencies listed in requirements.txt.
pip install --requirement requirements.txt
In this step, a LoRA is trained on SDXL; it is later used to initialize the Distribution LoRA.
For example, the training command could look like this:
# Step 1: train an SDXL LoRA on the instance images in INSTANCE_DIR.
# The weights written to OUTPUT_DIR are what the DistLoRA step reads back
# through --lora_name_or_path.
# --report_to="wandb" needs a configured Weights & Biases login and
# --push_to_hub needs an authenticated Hugging Face Hub session.
export MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0"
export OUTPUT_DIR="checkpoints/lora-sdxl-dog"
export INSTANCE_DIR="dog"
export PROMPT="a close-up photo of a sbu dog"
export VALID_PROMPT="a sbu dog"
# The final flag must NOT end in "\": a trailing backslash would splice the
# next line of the script into this command's argument list.
accelerate launch train_dreambooth_lora_sdxl.py \
  --pretrained_model_name_or_path="$MODEL_NAME" \
  --instance_data_dir="$INSTANCE_DIR" \
  --output_dir="$OUTPUT_DIR" \
  --instance_prompt="${PROMPT}" \
  --rank=64 \
  --resolution=1024 \
  --train_batch_size=1 \
  --learning_rate=5e-5 \
  --report_to="wandb" \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --max_train_steps=1000 \
  --validation_prompt="${VALID_PROMPT}" \
  --validation_epochs=50 \
  --seed="0" \
  --mixed_precision="fp16" \
  --enable_xformers_memory_efficient_attention \
  --gradient_checkpointing \
  --use_8bit_adam \
  --push_to_hub
# Step 2: train the Distribution LoRA, initialized from the step-1 LoRA.
export MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0"
# Subject-specific inputs: the step-1 LoRA weights (step 1's OUTPUT_DIR),
# the instance image directory, and the instance prompt.
export LORA_PATH="checkpoints/lora-sdxl-dog"
export INSTANCE_DIR="dog"
export PROMPT="a close-up photo of a sbu dog"
# Run-level settings: the --output_dir for this run and the prompt passed
# as --validation_prompt.
export OUTPUT_DIR="distlora-sdxl-dog"
export VALID_PROMPT="a sbu dog"
accelerate launch train_dreambooth_distlora_sdxl.py \
  --pretrained_model_name_or_path="$MODEL_NAME" \
  --output_dir="$OUTPUT_DIR" \
  --lora_name_or_path="$LORA_PATH" \
  --instance_prompt="${PROMPT}" \
  --instance_data_dir="$INSTANCE_DIR" \
  --resolution=1024 \
  --train_batch_size=4 \
  --learning_rate=5e-5 \
  --similarity_lambda=0.00001 \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --max_train_steps=100 \
  --validation_prompt="${VALID_PROMPT}" \
  --validation_epochs=10 \
  --report_to="wandb" \
  --gradient_checkpointing
# Step 3: run inference.py with the trained Distribution LoRA.
# DISTLORA_PATH="..." is a placeholder to replace with the trained DistLoRA
# weights. Presumably that is the step-2 output directory
# (distlora-sdxl-dog) or a checkpoint inside it; confirm against inference.py.
# NOTE(review): OUTPUT_DIR matches step 2's OUTPUT_DIR, so inference output
# goes into the same directory as the training output; confirm this is
# intended.
export MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0"
export DISTLORA_PATH="..."
export OUTPUT_DIR="distlora-sdxl-dog"
export PROMPT="a close-up photo of a sbu dog"
python inference.py \
  --pretrained_model_name_or_path="$MODEL_NAME" \
  --distlora_name_or_path="$DISTLORA_PATH" \
  --output_dir="$OUTPUT_DIR" \
  --prompt="${PROMPT}"