From 70a86f791fac2d44ae95747f2a147578652b9d65 Mon Sep 17 00:00:00 2001 From: Tsaiyue Date: Wed, 10 Jan 2024 13:17:22 +0800 Subject: [PATCH 1/2] [ppdiffusers]Kandinsky2_2 trainning support --- .../kandinsky2_2/text_to_image/README.md | 253 +++++ .../text_to_image/requirements.txt | 4 + .../train_text_to_image_decoder.py | 734 +++++++++++++++ .../train_text_to_image_decoder_lora.py | 888 ++++++++++++++++++ .../train_text_to_image_prior.py | 743 +++++++++++++++ .../train_text_to_image_prior_lora.py | 876 +++++++++++++++++ .../ppdiffusers/models/attention_processor.py | 4 +- 7 files changed, 3500 insertions(+), 2 deletions(-) create mode 100644 ppdiffusers/examples/kandinsky2_2/text_to_image/README.md create mode 100644 ppdiffusers/examples/kandinsky2_2/text_to_image/requirements.txt create mode 100644 ppdiffusers/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py create mode 100644 ppdiffusers/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder_lora.py create mode 100644 ppdiffusers/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py create mode 100644 ppdiffusers/examples/kandinsky2_2/text_to_image/train_text_to_image_prior_lora.py diff --git a/ppdiffusers/examples/kandinsky2_2/text_to_image/README.md b/ppdiffusers/examples/kandinsky2_2/text_to_image/README.md new file mode 100644 index 000000000..ab0b5cb2b --- /dev/null +++ b/ppdiffusers/examples/kandinsky2_2/text_to_image/README.md @@ -0,0 +1,253 @@ +# 微调Kandinsky2.2 text-to-image + +Kandinsky 2.2 包含一个根据文本提示生成图像嵌入表示的prior pipeline和一个根据图像嵌入表示生成输出图像的decoder pipeline。我们提供了 'train_text_to_image_prior.py '和 train_text_to_image_decoder.py '脚本,向您展示如何根据自己的数据集分别微调 Kandinsky prior模型和decoder模型。为了达到最佳效果,您应该对prior模型和decoder模型均进行微调。 + +___注意___: + +___这个脚本是试验性的。该脚本会对整个'decoder'或'prior'模型进行微调,有时模型会过度拟合,并遇到像`"catastrophic forgetting"`等问题。建议尝试不同的超参数,以便在自定义``数据集上获得最佳结果。___ + +## 1 安装依赖 + +在运行这个训练代码前,我们需要安装下面的训练依赖: + +```bash +# 安装2.5.2版本的paddlepaddle-gpu,当前我们选择了cuda11.7的版本,可以查看 https://www.paddlepaddle.org.cn/ 寻找自己适合的版本 +python -m pip install paddlepaddle-gpu==2.5.2.post117 -f https://www.paddlepaddle.org.cn/whl/linux/mkl/avx/stable.html + +# 安装所需的依赖, 如果提示权限不够,请在最后增加 --user 选项 +pip install -r requirements.txt +``` + +### 2 Pokemon训练示例 + +#### 2.1 微调 decoder + +```bash +export DATASET_NAME="lambdalabs/pokemon-blip-captions" + +python -u train_text_to_image_decoder.py \ + --dataset_name=$DATASET_NAME \ + --resolution=768 \ + --train_batch_size=1 \ + --gradient_accumulation_steps=4 \ + --gradient_checkpointing \ + --max_train_steps=15000 \ + --learning_rate=1e-05 \ + --max_grad_norm=1 \ + --checkpoints_total_limit=3 \ + --lr_scheduler="constant" --lr_warmup_steps=0 \ + --seed=42 \ + --output_dir="kandi2-decoder-pokemon-model" +``` + + + +如果用户想要在自己的数据集上进行训练,那么需要根据`huggingface的 datasets 库`所需的格式准备数据集,有关数据集的介绍可以查看 [HF dataset的文档](https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder-with-metadata). 
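+
+为方便检查自定义数据集能否被 🤗 Datasets 正确加载，可参考下面的最小示例（其中 `path_to_your_dataset` 与 `metadata.jsonl` 中的字段名仅为示意，请根据自己的数据调整）：
+
+```python
+from datasets import load_dataset
+
+# 目录下存放图片文件和一个 metadata.jsonl，每行形如：
+# {"file_name": "0001.png", "text": "a cartoon pokemon with blue wings"}
+dataset = load_dataset("imagefolder", data_files={"train": "path_to_your_dataset/**"})
+print(dataset["train"].column_names)  # 训练脚本将从 `--image_column` 指定的列读取图片
+```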
+ +如果用户想要修改代码中的部分训练逻辑,那么需要修改训练代码。 + +```bash +export TRAIN_DIR="path_to_your_dataset" + +python -u train_text_to_image_decoder.py \ + --train_data_dir=$TRAIN_DIR \ + --resolution=768 \ + --train_batch_size=1 \ + --gradient_accumulation_steps=4 \ + --gradient_checkpointing \ + --max_train_steps=15000 \ + --learning_rate=1e-05 \ + --max_grad_norm=1 \ + --checkpoints_total_limit=3 \ + --lr_scheduler="constant" --lr_warmup_steps=0 \ + --seed=42 \ + --output_dir="kandi2-decoder-pokemon-model" +``` + +训练完成后,模型将保存在命令中指定的 `output_dir` 目录中。在本例中是 `kandi22-decoder-pokemon-model`。要加载微调后的模型进行推理,后将该路径传递给 `AutoPipelineForText2Image` : + +```python +from ppdiffusers import AutoPipelineForText2Image + +pipe = AutoPipelineForText2Image.from_pretrained(output_dir) + +prompt='A robot pokemon, 4k photo' +images = pipe(prompt=prompt).images +images[0].save("robot-pokemon.png") +``` + +Checkpoints只保存 unet,因此要从checkpoints运行推理只需加载 unet: + +```python +from ppdiffusers import AutoPipelineForText2Image, UNet2DConditionModel + +model_path = "path_to_saved_model" + +unet = UNet2DConditionModel.from_pretrained(model_path + "/checkpoint-/unet") + +pipe = AutoPipelineForText2Image.from_pretrained("kandinsky-community/kandinsky-2-2-decoder", unet=unet) + +images = pipe(prompt="A robot pokemon, 4k photo").images +images[0].save("robot-pokemon.png") +``` + +#### 2.2 微调 prior + +您可以使用`train_text_to_image_prior.py`脚本对Kandinsky prior模型进行微调: + +```bash +export DATASET_NAME="lambdalabs/pokemon-blip-captions" + +python -u train_text_to_image_prior.py \ + --dataset_name=$DATASET_NAME \ + --resolution=768 \ + --train_batch_size=1 \ + --gradient_accumulation_steps=4 \ + --max_train_steps=15000 \ + --learning_rate=1e-05 \ + --max_grad_norm=1 \ + --checkpoints_total_limit=3 \ + --lr_scheduler="constant" --lr_warmup_steps=0 \ + --seed=42 \ + --output_dir="kandi2-prior-pokemon-model" +``` + + + +要使用微调prior模型进行推理,首先需要将 `output_dir`传给 `DiffusionPipeline`从而创建一个prior pipeline。然后,从一个预训练或微调decoder以及刚刚创建的prior pipeline的所有模块中创建一个`KandinskyV22CombinedPipeline`: + +```python +from ppdiffusers import AutoPipelineForText2Image, DiffusionPipeline + +pipe_prior = DiffusionPipeline.from_pretrained(output_dir) +prior_components = {"prior_" + k: v for k,v in pipe_prior.components.items()} +pipe = AutoPipelineForText2Image.from_pretrained("kandinsky-community/kandinsky-2-2-decoder", **prior_components) + +prompt='A robot pokemon, 4k photo' +images = pipe(prompt=prompt, negative_prompt=negative_prompt).images +images[0] +``` + +如果您想使用微调decoder和微调prior checkpoint,只需将上述代码中的 "kandinsky-community/kandinsky-2-2-decoder "替换为您自定义的模型库名称即可。请注意,要创建 `KandinskyV22CombinedPipeline`,您的模型 repo 必须有prior tag。如果您使用我们的训练脚本创建了模型 repo,则会自动包含prior tag。 + +#### 2.3 单机多卡训练 + +通过设置`--gpus`,我们可以指定 GPU 为 `0,1,2,3` 卡。这里我们只训练了`4000step`,因为这里的`4000 step x 4卡`近似于`单卡训练 16000 step`。 + +```bash +export DATASET_NAME="lambdalabs/pokemon-blip-captions" + +python -u -m paddle.distributed.launch --gpus "0,1,2,3" train_text_to_image_decoder.py \ + --dataset_name=$DATASET_NAME \ + --resolution=768 \ + --train_batch_size=1 \ + --gradient_accumulation_steps=4 \ + --gradient_checkpointing \ + --max_train_steps=4000 \ + --learning_rate=1e-05 \ + --max_grad_norm=1 \ + --checkpoints_total_limit=3 \ + --lr_scheduler="constant" --lr_warmup_steps=0 \ + --seed=42 \ + --output_dir="kandi2-decoder-pokemon-model" +``` + +# 使用 LoRA 和 Text-to-Image 技术进行模型训练 + +[LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685) 是微软研究员引入的一项新技术,主要用于处理大模型微调的问题。目前超过数十亿以上参数的具有强能力的大模型 (例如 GPT-3) 
通常在为了适应其下游任务的微调中会呈现出巨大开销。LoRA 建议冻结预训练模型的权重并在每个 Transformer 块中注入可训练层 (秩-分解矩阵)。因为不需要为大多数模型权重计算梯度,所以大大减少了需要训练参数的数量并且降低了 GPU 的内存要求。研究人员发现,通过聚焦大模型的 Transformer 注意力块,使用 LoRA 进行的微调质量与全模型微调相当,同时速度更快且需要更少的计算。 + +简而言之,LoRA允许通过向现有权重添加一对秩分解矩阵,并只训练这些新添加的权重来适应预训练的模型。这有几个优点: + +- 保持预训练的权重不变,这样模型就不容易出现灾难性遗忘 [catastrophic forgetting](https://www.pnas.org/doi/10.1073/pnas.1611835114); + +- 秩分解矩阵的参数比原始模型少得多,这意味着训练的 LoRA 权重很容易移植; + +- LoRA 注意力层允许通过一个 `scale` 参数来控制模型适应新训练图像的程度。 + +[cloneofsimo](https://github.com/cloneofsimo) 是第一个在 [LoRA GitHub](https://github.com/cloneofsimo/lora) 仓库中尝试使用 LoRA 训练 Stable Diffusion 的人。 + +利用 LoRA,可以在 T4、Tesla V100 等消费级 GPU 上,在自定义图文对数据集上对 Kandinsky 2.2 进行微调。 + +### 1 训练 + +#### 1.1 LoRA微调decoder + +```bash +export DATASET_NAME="lambdalabs/pokemon-blip-captions" + +python -u train_text_to_image_decoder_lora.py \ + --dataset_name=$DATASET_NAME \ + --resolution=768 \ + --train_batch_size=1 \ + --gradient_accumulation_steps=4 \ + --max_train_steps=15000 \ + --learning_rate=1e-04 \ + --max_grad_norm=1 \ + --checkpoints_total_limit=3 \ + --lr_scheduler="constant" --lr_warmup_steps=0 \ + --seed=42 \ + --lora_rank=4 \ + --validation_prompt="cute dragon creature" \ + --output_dir="kandi22-decoder-pokemon-lora" +``` + +#### 1.2 LoRA微调prior + +```bash +export DATASET_NAME="lambdalabs/pokemon-blip-captions" + +python -u train_text_to_image_prior_lora.py \ + --dataset_name=$DATASET_NAME \ + --resolution=768 \ + --train_batch_size=1 \ + --gradient_accumulation_steps=4 \ + --max_train_steps=15000 \ + --learning_rate=1e-04 \ + --max_grad_norm=1 \ + --checkpoints_total_limit=3 \ + --lr_scheduler="constant" --lr_warmup_steps=0 \ + --seed=42 \ + --lora_rank=4 \ + --validation_prompt="cute dragon creature" \ + --output_dir="kandi22-prior-pokemon-lora" +``` + +**___Note: 当我使用 LoRA 训练模型的时候,我们需要使用更大的学习率,因此我们这里使用 *1e-4* 而不是 *1e-5*.___** + +### 2 推理 + +#### 2.1 LoRA微调decoder推理流程 + +使用上述命令训练好Kandinsky decoder模型后,就可以在加载训练好 LoRA 权重后使用 `AutoPipelineForText2Image` 进行推理。您需要传递用于加载 LoRA 权重的 `output_dir` 目录,在本例中是 `kandi22-decoder-pokemon-lora`。 + +```python +from ppdiffusers import AutoPipelineForText2Image + +pipe = AutoPipelineForText2Image.from_pretrained("kandinsky-community/kandinsky-2-2-decoder") +pipe.unet.load_attn_procs(output_dir) + +prompt='A robot pokemon, 4k photo' +image = pipe(prompt=prompt).images[0] +image.save("robot_pokemon.png") +``` + +#### 2.2 LoRA微调prior推理流程 + +```python +from ppdiffusers import AutoPipelineForText2Image + +pipe = AutoPipelineForText2Image.from_pretrained("kandinsky-community/kandinsky-2-2-decoder") +pipe.prior_prior.load_attn_procs(output_dir) + +prompt='A robot pokemon, 4k photo' +image = pipe(prompt=prompt).images[0] +image.save("robot_pokemon.png") +``` + +# 参考资料 + +- https://github.com/huggingface/diffusers/tree/main/examples/kandinsky2_2/text_to_image + +- https://github.com/ai-forever/Kandinsky-2 + +- https://huggingface.co/kandinsky-community \ No newline at end of file diff --git a/ppdiffusers/examples/kandinsky2_2/text_to_image/requirements.txt b/ppdiffusers/examples/kandinsky2_2/text_to_image/requirements.txt new file mode 100644 index 000000000..d8666177e --- /dev/null +++ b/ppdiffusers/examples/kandinsky2_2/text_to_image/requirements.txt @@ -0,0 +1,4 @@ +paddlenlp>=2.6.0rc0 +Pillow +visualdl +ppdiffusers>=0.16.1 \ No newline at end of file diff --git a/ppdiffusers/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py b/ppdiffusers/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py new file mode 100644 index 000000000..7ba6057b4 --- 
/dev/null +++ b/ppdiffusers/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py @@ -0,0 +1,734 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +import math +import os +from pathlib import Path +from typing import Optional + +from PIL import Image +import numpy as np +import paddle +import paddle.nn as nn +import paddle.nn.functional as F +from huggingface_hub import HfFolder, Repository, create_repo, whoami +from paddle.distributed.fleet.utils.hybrid_parallel_util import ( + fused_allreduce_gradients, +) +from paddle.io import BatchSampler, DataLoader, DistributedBatchSampler +from paddle.optimizer import AdamW +from paddlenlp.trainer import set_seed +from paddlenlp.transformers import ( + CLIPImageProcessor, + CLIPVisionModelWithProjection, + CLIPTextModelWithProjection, + CLIPTokenizer, +) +from paddlenlp.utils.log import logger +from tqdm.auto import tqdm +from datasets import load_dataset +from ppdiffusers import ( + AutoPipelineForText2Image, + PriorTransformer, + DDPMScheduler, + UNet2DConditionModel, + VQModel, + is_ppxformers_available, +) +from ppdiffusers.optimization import get_scheduler +from ppdiffusers.training_utils import ( + EMAModel, + freeze_params, + main_process_first, + unwrap_model, +) +from ppdiffusers.utils import check_min_version + +check_min_version("0.16.1") + + +def url_or_path_join(*path_list): + return ( + os.path.join(*path_list) + if os.path.isdir(os.path.join(*path_list)) + else "/".join(path_list) + ) + + +def get_report_to(args): + if args.report_to == "visualdl": + from visualdl import LogWriter + + writer = LogWriter(logdir=args.logging_dir) + elif args.report_to == "tensorboard": + from tensorboardX import SummaryWriter + + writer = SummaryWriter(logdir=args.logging_dir) + else: + raise ValueError("report_to must be in ['visualdl', 'tensorboard']") + return writer + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Simple example of finetuning Kandinsky 2.2." + ) + parser.add_argument( + "--pretrained_decoder_model_name_or_path", + type=str, + default="kandinsky-community/kandinsky-2-2-decoder", + required=False, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--pretrained_prior_model_name_or_path", + type=str, + default="kandinsky-community/kandinsky-2-2-prior", + required=False, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--dataset_name", + type=str, + default=None, + help="The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private, dataset). 
It can also be a path pointing to a local copy of a dataset in your filesystem, or to a folder containing files that 🤗 Datasets can understand", + ) + parser.add_argument( + "--dataset_config_name", + type=str, + default=None, + help="The config of the Dataset, leave as None if there's only one config.", + ) + parser.add_argument( + "--train_data_dir", + type=str, + default=None, + help="A folder containing the training data. Folder contents must follow the structure described in https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file must exist to provide the captions for the images. Ignored if `dataset_name` is specified.", + ) + parser.add_argument( + "--image_column", + type=str, + default="image", + help="The column of the dataset containing an image.", + ) + parser.add_argument( + "--max_train_samples", + type=int, + default=None, + help="For debugging purposes or quicker training, truncate the number of training examples to this value if set.", + ) + parser.add_argument( + "--output_dir", + type=str, + default="kandi_2_2-model-finetuned", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument( + "--cache_dir", + type=str, + default=None, + help="The directory where the downloaded models and datasets will be stored.", + ) + parser.add_argument( + "--seed", type=int, default=None, help="A seed for reproducible training." + ) + parser.add_argument( + "--resolution", + type=int, + default=512, + help="The resolution for input images, all the images in the train/validation dataset will be resized to this resolution", + ) + parser.add_argument( + "--train_batch_size", + type=int, + default=1, + help="Batch size (per device) for the training dataloader.", + ) + parser.add_argument("--num_train_epochs", type=int, default=100) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", type=float, default=0.0001, help="learning rate" + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help='The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"]', + ) + parser.add_argument( + "--lr_warmup_steps", + type=int, + default=500, + help="Number of steps for the warmup in the lr scheduler.", + ) + parser.add_argument( + "--snr_gamma", + type=float, + default=None, + help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. More details here: https://arxiv.org/abs/2303.09556.", + ) + parser.add_argument( + "--use_8bit_adam", + action="store_true", + help="Whether or not to use 8-bit Adam from bitsandbytes.", + ) + parser.add_argument( + "--allow_tf32", + action="store_true", + help="Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. 
For more information, see https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices", + ) + parser.add_argument( + "--use_ema", action="store_true", help="Whether to use EMA model." + ) + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.", + ) + parser.add_argument( + "--adam_beta1", + type=float, + default=0.9, + help="The beta1 parameter for the Adam optimizer.", + ) + parser.add_argument( + "--adam_beta2", + type=float, + default=0.999, + help="The beta2 parameter for the Adam optimizer.", + ) + parser.add_argument( + "--adam_weight_decay", + type=float, + default=0.0, + required=False, + help="weight decay_to_use", + ) + parser.add_argument( + "--adam_epsilon", + type=float, + default=1e-08, + help="Epsilon value for the Adam optimizer", + ) + parser.add_argument( + "--max_grad_norm", default=1.0, type=float, help="Max gradient norm." + ) + parser.add_argument( + "--push_to_hub", + action="store_true", + help="Whether or not to push the model to the Hub.", + ) + parser.add_argument( + "--hub_token", + type=str, + default=None, + help="The token to use to push to the Model Hub.", + ) + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help="[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.", + ) + parser.add_argument( + "--report_to", + type=str, + default="visualdl", + choices=["tensorboard", "visualdl"], + help="Log writer type.", + ) + parser.add_argument( + "--local_rank", + type=int, + default=-1, + help="For distributed training: local_rank", + ) + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help="Save a checkpoint of the training state every X updates.", + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=None, + help="Max number of checkpoints to store.", + ) + parser.add_argument( + "--enable_xformers_memory_efficient_attention", + action="store_true", + help="Whether or not to use xformers.", + ) + parser.add_argument( + "--tracker_project_name", + type=str, + default="text2image-fine-tune", + help="The `project_name` argument passed to Accelerator.init_trackers for more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator", + ) + args = parser.parse_args() + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + if args.dataset_name is None and args.train_data_dir is None: + raise ValueError("Need either a dataset name or a training folder.") + return args + + +def get_full_repo_name( + model_id: str, organization: Optional[str] = None, token: Optional[str] = None +): + if token is None: + token = HfFolder.get_token() + if organization is None: + username = whoami(token)["name"] + return f"{username}/{model_id}" + else: + return f"{organization}/{model_id}" + + +DATASET_NAME_MAPPING = { + "lambdalabs/pokemon-blip-captions": ("image", "text"), +} + + +def main(): + args = parse_args() + rank = paddle.distributed.get_rank() + is_main_process = rank == 0 + num_processes = paddle.distributed.get_world_size() + if num_processes > 1: 
+ paddle.distributed.init_parallel_env() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Handle the repository creation + if is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + if args.push_to_hub: + if args.hub_model_id is None: + repo_name = get_full_repo_name( + Path(args.output_dir).name, token=args.hub_token + ) + else: + repo_name = args.hub_model_id + create_repo(repo_name, exist_ok=True, token=args.hub_token) + repo = Repository( + args.output_dir, clone_from=repo_name, token=args.hub_token + ) + + with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: + if "step_*" not in gitignore: + gitignore.write("step_*\n") + if "epoch_*" not in gitignore: + gitignore.write("epoch_*\n") + + # Load scheduler and models + noise_scheduler = DDPMScheduler.from_pretrained( + args.pretrained_decoder_model_name_or_path, subfolder="scheduler" + ) + image_processor = CLIPImageProcessor.from_pretrained( + args.pretrained_prior_model_name_or_path, subfolder="image_processor" + ) + + vae = VQModel.from_pretrained( + args.pretrained_decoder_model_name_or_path, subfolder="movq" + ) + image_encoder = CLIPVisionModelWithProjection.from_pretrained( + args.pretrained_prior_model_name_or_path, subfolder="image_encoder" + ) + unet = UNet2DConditionModel.from_pretrained( + args.pretrained_decoder_model_name_or_path, subfolder="unet" + ) + + freeze_params(vae.parameters()) + freeze_params(image_encoder.parameters()) + + if args.use_ema: + ema_unet = UNet2DConditionModel.from_pretrained( + args.pretrained_decoder_model_name_or_path, + subfolder="unet", + ) + ema_unet = EMAModel(ema_unet.parameters()) + + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + + if args.enable_xformers_memory_efficient_attention and is_ppxformers_available(): + try: + unet.enable_xformers_memory_efficient_attention() + except Exception as e: + logger.warn( + "Could not enable memory efficient attention. Make sure develop paddlepaddle is installed" + f" correctly and a GPU is available: {e}" + ) + + def compute_snr(timesteps): + """ + Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849 + """ + alphas_cumprod = noise_scheduler.alphas_cumprod + sqrt_alphas_cumprod = alphas_cumprod**0.5 + sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5 + + # Expand the tensors. + # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026 + sqrt_alphas_cumprod = sqrt_alphas_cumprod[timesteps].cast("float32") + while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape): + sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None] + alpha = sqrt_alphas_cumprod.expand(timesteps.shape) + + sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[timesteps].cast( + "float32" + ) + while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape): + sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None] + sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape) + + # Compute SNR. + snr = (alpha / sigma) ** 2 + return snr + + # Get the datasets: you can either provide your own training and evaluation files (see below) + # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub). 
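+    # In distributed training, the load_dataset function guarantees that only one local process can concurrently
+    # download the dataset. If --dataset_name is not given, an `imagefolder` dataset is built from every file
+    # found under --train_data_dir instead.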
+ if args.dataset_name is not None: + # Downloading and loading a dataset from the hub. + dataset = load_dataset( + args.dataset_name, + args.dataset_config_name, + cache_dir=args.cache_dir, + ) + else: + data_files = {} + if args.train_data_dir is not None: + data_files["train"] = os.path.join(args.train_data_dir, "**") + dataset = load_dataset( + "imagefolder", + data_files=data_files, + cache_dir=args.cache_dir, + ) + # See more about loading custom images at + # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder + + # Preprocessing the datasets. + # We need to tokenize inputs and targets. + column_names = dataset["train"].column_names + # 6. Get the column names for input/target. + dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None) + if args.image_column is None: + image_column = ( + dataset_columns[0] if dataset_columns is not None else column_names[0] + ) + else: + image_column = args.image_column + if image_column not in column_names: + raise ValueError( + f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}" + ) + + def center_crop(image): + width, height = image.size + new_size = min(width, height) + left = (width - new_size) / 2 + top = (height - new_size) / 2 + right = (width + new_size) / 2 + bottom = (height + new_size) / 2 + return image.crop((left, top, right, bottom)) + + def train_transforms(img): + img = center_crop(img) + img = img.resize( + (args.resolution, args.resolution), resample=Image.BICUBIC, reducing_gap=1 + ) + img = np.array(img).astype(np.float32) / 127.5 - 1 + img = paddle.to_tensor(data=np.transpose(img, [2, 0, 1])) + return img + + def preprocess_train(examples): + images = [image.convert("RGB") for image in examples[image_column]] + examples["pixel_values"] = [train_transforms(image) for image in images] + examples["clip_pixel_values"] = image_processor( + images, return_tensors="pd" + ).pixel_values + return examples + + with main_process_first(): + if args.max_train_samples is not None: + dataset["train"] = ( + dataset["train"] + .shuffle(seed=args.seed) + .select(range(args.max_train_samples)) + ) + # Set the training transforms + train_dataset = dataset["train"].with_transform(preprocess_train) + + def collate_fn(examples): + pixel_values = paddle.stack( + [example["pixel_values"] for example in examples] + ).cast("float32") + clip_pixel_values = paddle.stack( + [example["clip_pixel_values"] for example in examples] + ).cast("float32") + return {"pixel_values": pixel_values, "clip_pixel_values": clip_pixel_values} + + train_sampler = ( + DistributedBatchSampler( + train_dataset, batch_size=args.train_batch_size, shuffle=False + ) + if num_processes > 1 + else BatchSampler( + train_dataset, batch_size=args.train_batch_size, shuffle=False + ) + ) + train_dataloader = DataLoader( + train_dataset, + batch_sampler=train_sampler, + collate_fn=collate_fn, + num_workers=args.dataloader_num_workers, + ) + + # Scheduler and math around the number of training steps. 
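+    # One optimizer update is performed every `gradient_accumulation_steps` batches, so the number of updates
+    # per epoch is the dataloader length divided by the accumulation steps (rounded up). When --max_train_steps
+    # is not set it is derived from --num_train_epochs; otherwise the epoch count is recomputed from it.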
+ num_update_steps_per_epoch = math.ceil( + len(train_dataloader) / args.gradient_accumulation_steps + ) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + if num_processes > 1: + unet = paddle.DataParallel(unet) + + lr_scheduler = get_scheduler( + args.lr_scheduler, + learning_rate=args.learning_rate, + num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, + num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, + ) + # Initialize the optimizer + optimizer = AdamW( + learning_rate=lr_scheduler, + parameters=unet.parameters(), + beta1=args.adam_beta1, + beta2=args.adam_beta2, + weight_decay=args.adam_weight_decay, + epsilon=args.adam_epsilon, + grad_clip=nn.ClipGradByGlobalNorm(args.max_grad_norm) + if args.max_grad_norm > 0 + else None, + ) + + if is_main_process: + logger.info("----------- Configuration Arguments -----------") + for arg, value in sorted(vars(args).items()): + logger.info("%s: %s" % (arg, value)) + logger.info("------------------------------------------------") + writer = get_report_to(args) + + # Train! + total_batch_size = ( + args.train_batch_size * num_processes * args.gradient_accumulation_steps + ) + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num batches each epoch = {len(train_dataloader)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info( + f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}" + ) + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + + # Only show the progress bar once on each machine. + progress_bar = tqdm(range(args.max_train_steps), disable=not is_main_process) + progress_bar.set_description("Train Steps") + global_step = 0 + + # # Keep vae in eval model as we don't train these + vae.eval() + image_encoder.eval() + + unet.train() + + for epoch in range(args.num_train_epochs): + for step, batch in enumerate(train_dataloader): + images = batch["pixel_values"].cast("float32") + clip_images = batch["clip_pixel_values"].cast("float32") + latents = vae.encode(images).latents + image_embeds = image_encoder(clip_images).image_embeds + + noise = paddle.randn(shape=latents.shape, dtype=latents.dtype) + + bsz = latents.shape[0] + timesteps = paddle.randint(low=0, high=noise_scheduler.config.num_train_timesteps, shape=(bsz,)) + timesteps = timesteps.astype(dtype='int64') + + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + target = noise + + added_cond_kwargs = {"image_embeds": image_embeds} + # Predict the noise residual and compute loss + model_pred = unet( + noisy_latents, timesteps, None, added_cond_kwargs=added_cond_kwargs + ).sample[:, :4] + + if args.snr_gamma is None: + loss = F.mse_loss( + model_pred.cast("float32"), + target.cast("float32"), + reduction="mean", + ) + else: + # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556. + # Since we predict the noise instead of x_0, the original formulation is slightly changed. + # This is discussed in Section 4.2 of the same paper. 
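+                    # With epsilon prediction the resulting per-sample weight is min(SNR, snr_gamma) / SNR:
+                    # low-noise (high-SNR) timesteps are down-weighted to snr_gamma / SNR, while noisier
+                    # timesteps keep a weight of 1.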
+ snr = compute_snr(timesteps) + if noise_scheduler.config.prediction_type == "v_prediction": + # Velocity objective requires that we add one to SNR values before we divide by them. + snr = snr + 1 + mse_loss_weights = ( + paddle.stack( + [snr, args.snr_gamma * paddle.ones_like(timesteps)], + axis=1, + ).min(1)[0] + / snr + ) + # We first calculate the original loss. Then we mean over the non-batch dimensions and + # rebalance the sample-wise losses with their respective loss weights. + # Finally, we take the mean of the rebalanced loss. + loss = F.mse_loss( + model_pred.cast("float32"), + target.cast("float32"), + reduction="none", + ) + loss = ( + loss.mean(axis=list(range(1, len(loss.shape)))) * mse_loss_weights + ) + loss = loss.mean() + + if args.gradient_accumulation_steps > 1: + loss = loss / args.gradient_accumulation_steps + loss.backward() + + if (step + 1) % args.gradient_accumulation_steps == 0: + if num_processes > 1 and args.gradient_checkpointing: + fused_allreduce_gradients(unet.parameters(), None) + optimizer.step() + lr_scheduler.step() + optimizer.clear_grad() + progress_bar.update(1) + global_step += 1 + step_loss = loss.item() * args.gradient_accumulation_steps + + if args.use_ema: + ema_unet.step(unet.parameters()) + logs = { + "epoch": str(epoch).zfill(4), + "step_loss": round(step_loss, 10), + "lr": lr_scheduler.get_lr(), + } + progress_bar.set_postfix(**logs) + + if is_main_process: + for name, val in logs.items(): + if name == "epoch": + continue + writer.add_scalar(f"train/{name}", val, global_step) + + if global_step % args.checkpointing_steps == 0: + save_path = os.path.join( + args.output_dir, f"checkpoint-{global_step}" + ) + unwrap_model(unet).save_pretrained( + os.path.join(save_path, "unet") + ) + + if global_step >= args.max_train_steps: + break + + # Create the pipeline using the trained modules and save it. 
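+    # Only the main process writes the final artifacts: EMA weights (if enabled) are copied into the UNet, the
+    # Kandinsky prior components are loaded, and the decoder part of the assembled pipeline is saved to output_dir.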
+ if is_main_process: + writer.close() + unet = unwrap_model(unet) + if args.use_ema: + ema_unet.copy_to(unet.parameters()) + + text_encoder = CLIPTextModelWithProjection.from_pretrained( + args.pretrained_prior_model_name_or_path, subfolder="text_encoder" + ) + prior = PriorTransformer.from_pretrained( + args.pretrained_prior_model_name_or_path, subfolder="prior" + ) + prior_image_processor = CLIPImageProcessor.from_pretrained( + args.pretrained_prior_model_name_or_path, subfolder="image_processor" + ) + prior_scheduler = DDPMScheduler( + beta_schedule="squaredcos_cap_v2", prediction_type="sample" + ) + prior_tokenizer = CLIPTokenizer.from_pretrained( + url_or_path_join(args.pretrained_prior_model_name_or_path, "tokenizer") + ) + + pipeline = AutoPipelineForText2Image.from_pretrained( + args.pretrained_decoder_model_name_or_path, + vae=vae, + unet=unet, + prior_text_encoder=text_encoder, + prior_image_encoder=image_encoder, + prior_prior=prior, + prior_image_processor=prior_image_processor, + prior_scheduler=prior_scheduler, + prior_tokenizer=prior_tokenizer, + ) + + pipeline.decoder_pipe.save_pretrained(args.output_dir) + + if args.push_to_hub: + repo.push_to_hub( + commit_message="End of training", blocking=False, auto_lfs_prune=True + ) + + +if __name__ == "__main__": + main() diff --git a/ppdiffusers/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder_lora.py b/ppdiffusers/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder_lora.py new file mode 100644 index 000000000..dc45815fd --- /dev/null +++ b/ppdiffusers/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder_lora.py @@ -0,0 +1,888 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import argparse +import gc +import math +import os + +from pathlib import Path +from typing import Optional + +from PIL import Image +import numpy as np +import paddle +import paddle.nn as nn +import paddle.nn.functional as F +from huggingface_hub import HfFolder, Repository, create_repo, whoami +from paddle.distributed.fleet.utils.hybrid_parallel_util import ( + fused_allreduce_gradients, +) +from paddle.io import BatchSampler, DataLoader, DistributedBatchSampler +from paddle.optimizer import AdamW +from paddlenlp.trainer import set_seed +from paddlenlp.transformers import ( + CLIPImageProcessor, + CLIPVisionModelWithProjection, + CLIPTextModelWithProjection, + CLIPTokenizer, +) +from paddlenlp.utils.log import logger +from tqdm.auto import tqdm + +from datasets import load_dataset + +from ppdiffusers import ( + AutoPipelineForText2Image, + DDPMScheduler, + PriorTransformer, + UnCLIPScheduler, + UNet2DConditionModel, + VQModel, +) +from ppdiffusers.optimization import get_scheduler +from ppdiffusers.training_utils import ( + freeze_params, + main_process_first, + unwrap_model, +) +from ppdiffusers.utils import check_min_version +from ppdiffusers.loaders import AttnProcsLayers, LoraLoaderMixin +from ppdiffusers.models.attention_processor import LoRAAttnAddedKVProcessor + +check_min_version("0.16.1") + + +def url_or_path_join(*path_list): + return ( + os.path.join(*path_list) + if os.path.isdir(os.path.join(*path_list)) + else "/".join(path_list) + ) + + +def save_model_card( + repo_id: str, images=None, base_model=str, dataset_name=str, repo_folder=None +): + img_str = "" + for i, image in enumerate(images): + image.save(os.path.join(repo_folder, f"image_{i}.png")) + img_str += f"![img_{i}](./image_{i}.png)\n" + + yaml = f""" +--- +license: creativeml-openrail-m +base_model: {base_model} +tags: +- kandinsky +- text-to-image +- ppdiffusers +- lora +inference: true +--- + """ + model_card = f""" +# LoRA text2image fine-tuning - {repo_id} +These are LoRA adaption weights for {base_model}. The weights were fine-tuned on the {dataset_name} dataset. You can find some example images in the following. \n +{img_str} +""" + with open(os.path.join(repo_folder, "README.md"), "w") as f: + f.write(yaml + model_card) + + +def get_report_to(args): + if args.report_to == "visualdl": + from visualdl import LogWriter + + writer = LogWriter(logdir=args.logging_dir) + elif args.report_to == "tensorboard": + from tensorboardX import SummaryWriter + + writer = SummaryWriter(logdir=args.logging_dir) + else: + raise ValueError("report_to must be in ['visualdl', 'tensorboard']") + return writer + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Simple example of finetuning Kandinsky 2.2." + ) + parser.add_argument( + "--pretrained_decoder_model_name_or_path", + type=str, + default="kandinsky-community/kandinsky-2-2-decoder", + required=False, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--pretrained_prior_model_name_or_path", + type=str, + default="kandinsky-community/kandinsky-2-2-prior", + required=False, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--dataset_name", + type=str, + default=None, + help="The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private, dataset). 
It can also be a path pointing to a local copy of a dataset in your filesystem, or to a folder containing files that 🤗 Datasets can understand", + ) + parser.add_argument( + "--dataset_config_name", + type=str, + default=None, + help="The config of the Dataset, leave as None if there's only one config.", + ) + parser.add_argument( + "--train_data_dir", + type=str, + default=None, + help="A folder containing the training data. Folder contents must follow the structure described in https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file must exist to provide the captions for the images. Ignored if `dataset_name` is specified.", + ) + parser.add_argument( + "--image_column", + type=str, + default="image", + help="The column of the dataset containing an image.", + ) + parser.add_argument( + "--max_train_samples", + type=int, + default=None, + help="For debugging purposes or quicker training, truncate the number of training examples to this value if set.", + ) + + parser.add_argument( + "--output_dir", + type=str, + default="kandi_2_2-model-finetuned-lora", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument( + "--cache_dir", + type=str, + default=None, + help="The directory where the downloaded models and datasets will be stored.", + ) + parser.add_argument( + "--seed", type=int, default=None, help="A seed for reproducible training." + ) + parser.add_argument( + "--resolution", + type=int, + default=512, + help="The resolution for input images, all the images in the train/validation dataset will be resized to this resolution", + ) + parser.add_argument( + "--train_batch_size", + type=int, + default=1, + help="Batch size (per device) for the training dataloader.", + ) + parser.add_argument("--num_train_epochs", type=int, default=100) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", type=float, default=0.0001, help="learning rate" + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help='The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"]', + ) + parser.add_argument( + "--lr_warmup_steps", + type=int, + default=500, + help="Number of steps for the warmup in the lr scheduler.", + ) + parser.add_argument( + "--snr_gamma", + type=float, + default=None, + help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. More details here: https://arxiv.org/abs/2303.09556.", + ) + parser.add_argument( + "--use_8bit_adam", + action="store_true", + help="Whether or not to use 8-bit Adam from bitsandbytes.", + ) + parser.add_argument( + "--allow_tf32", + action="store_true", + help="Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. 
For more information, see https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices", + ) + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.", + ) + parser.add_argument( + "--adam_beta1", + type=float, + default=0.9, + help="The beta1 parameter for the Adam optimizer.", + ) + parser.add_argument( + "--adam_beta2", + type=float, + default=0.999, + help="The beta2 parameter for the Adam optimizer.", + ) + parser.add_argument( + "--adam_weight_decay", + type=float, + default=0.0, + required=False, + help="weight decay_to_use", + ) + parser.add_argument( + "--adam_epsilon", + type=float, + default=1e-08, + help="Epsilon value for the Adam optimizer", + ) + parser.add_argument( + "--max_grad_norm", default=1.0, type=float, help="Max gradient norm." + ) + parser.add_argument( + "--push_to_hub", + action="store_true", + help="Whether or not to push the model to the Hub.", + ) + parser.add_argument( + "--hub_token", + type=str, + default=None, + help="The token to use to push to the Model Hub.", + ) + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help="[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.", + ) + parser.add_argument( + "--report_to", + type=str, + default="visualdl", + choices=["tensorboard", "visualdl"], + help="Log writer type.", + ) + parser.add_argument( + "--local_rank", + type=int, + default=-1, + help="For distributed training: local_rank", + ) + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help="Save a checkpoint of the training state every X updates.", + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=None, + help="Max number of checkpoints to store.", + ) + + parser.add_argument( + "--validation_epochs", + type=int, + default=5, + help="Run validation every X epochs.", + ) + parser.add_argument( + "--validation_prompt", + type=str, + default=None, + help="A prompt that is sampled during training for inference.", + ) + parser.add_argument( + "--num_validation_images", + type=int, + default=4, + help="Number of images that should be generated during validation with `validation_prompt`.", + ) + parser.add_argument( + "--tracker_project_name", + type=str, + default="text2image-fine-tune", + help="The `project_name` argument passed to Accelerator.init_trackers for more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator", + ) + parser.add_argument( + "--lora_rank", + type=int, + default=4, + help=("The dimension of the LoRA update matrices."), + ) + args = parser.parse_args() + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + if args.dataset_name is None and args.train_data_dir is None: + raise ValueError("Need either a dataset name or a training folder.") + return args + + +def get_full_repo_name( + model_id: str, organization: Optional[str] = None, token: Optional[str] = None +): + if token is None: + token = HfFolder.get_token() + if organization is None: + username = whoami(token)["name"] + return 
f"{username}/{model_id}" + else: + return f"{organization}/{model_id}" + + +DATASET_NAME_MAPPING = { + "lambdalabs/pokemon-blip-captions": ("image", "text"), +} + + +def main(): + args = parse_args() + rank = paddle.distributed.get_rank() + is_main_process = rank == 0 + num_processes = paddle.distributed.get_world_size() + if num_processes > 1: + paddle.distributed.init_parallel_env() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Handle the repository creation + if is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + if args.push_to_hub: + if args.hub_model_id is None: + repo_name = get_full_repo_name( + Path(args.output_dir).name, token=args.hub_token + ) + else: + repo_name = args.hub_model_id + create_repo(repo_name, exist_ok=True, token=args.hub_token) + repo = Repository( + args.output_dir, clone_from=repo_name, token=args.hub_token + ) + + with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: + if "step_*" not in gitignore: + gitignore.write("step_*\n") + if "epoch_*" not in gitignore: + gitignore.write("epoch_*\n") + + # Load scheduler and models + noise_scheduler = DDPMScheduler.from_pretrained( + args.pretrained_decoder_model_name_or_path, subfolder="scheduler" + ) + image_processor = CLIPImageProcessor.from_pretrained( + args.pretrained_prior_model_name_or_path, subfolder="image_processor" + ) + + vae = VQModel.from_pretrained( + args.pretrained_decoder_model_name_or_path, subfolder="movq" + ) + image_encoder = CLIPVisionModelWithProjection.from_pretrained( + args.pretrained_prior_model_name_or_path, subfolder="image_encoder" + ) + unet = UNet2DConditionModel.from_pretrained( + args.pretrained_decoder_model_name_or_path, subfolder="unet" + ) + + freeze_params(vae.parameters()) + freeze_params(image_encoder.parameters()) + + # Set correct lora layers + unet_lora_attn_procs = {} + for name, attn_processor in unet.attn_processors.items(): + cross_attention_dim = ( + None + if name.endswith("attn1.processor") + else unet.config.cross_attention_dim + ) + if name.startswith("mid_block"): + hidden_size = unet.config.block_out_channels[-1] + elif name.startswith("up_blocks"): + block_id = int(name[len("up_blocks.")]) + hidden_size = list(reversed(unet.config.block_out_channels))[block_id] + elif name.startswith("down_blocks"): + block_id = int(name[len("down_blocks.")]) + hidden_size = unet.config.block_out_channels[block_id] + + unet_lora_attn_procs[name] = LoRAAttnAddedKVProcessor( + hidden_size=hidden_size, + cross_attention_dim=cross_attention_dim, + rank=args.lora_rank, + ) + + unet.set_attn_processor(unet_lora_attn_procs) + unet_lora_layers = AttnProcsLayers(unet.attn_processors) + + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + + def compute_snr(timesteps): + """ + Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849 + """ + alphas_cumprod = noise_scheduler.alphas_cumprod + sqrt_alphas_cumprod = alphas_cumprod**0.5 + sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5 + + # Expand the tensors. 
+ # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026 + sqrt_alphas_cumprod = sqrt_alphas_cumprod[timesteps].cast("float32") + while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape): + sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None] + alpha = sqrt_alphas_cumprod.expand(timesteps.shape) + + sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[timesteps].cast( + "float32" + ) + while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape): + sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None] + sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape) + + # Compute SNR. + snr = (alpha / sigma) ** 2 + return snr + + # Get the datasets: you can either provide your own training and evaluation files (see below) + # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub). + + # In distributed training, the load_dataset function guarantees that only one local process can concurrently + # download the dataset. + # if args.debug: + # file_path = get_path_from_url_with_filelock( + # "https://paddlenlp.bj.bcebos.com/models/community/junnyu/develop/pokemon-blip-captions.tar.gz", + # PPDIFFUSERS_CACHE, + # ) + # dataset = DatasetDict.load_from_disk(file_path) + # args.dataset_name = "lambdalabs/pokemon-blip-captions" + # else: + if args.dataset_name is not None: + # Downloading and loading a dataset from the hub. + dataset = load_dataset( + args.dataset_name, + args.dataset_config_name, + cache_dir=args.cache_dir, + ) + else: + data_files = {} + if args.train_data_dir is not None: + data_files["train"] = os.path.join(args.train_data_dir, "**") + dataset = load_dataset( + "imagefolder", + data_files=data_files, + cache_dir=args.cache_dir, + ) + # See more about loading custom images at + # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder + + # Preprocessing the datasets. + # We need to tokenize inputs and targets. + column_names = dataset["train"].column_names + # 6. Get the column names for input/target. 
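+    # When --image_column is not set, fall back to the column registered in DATASET_NAME_MAPPING (for known
+    # datasets such as lambdalabs/pokemon-blip-captions) or, failing that, to the first column of the dataset.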
+ dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None) + if args.image_column is None: + image_column = ( + dataset_columns[0] if dataset_columns is not None else column_names[0] + ) + else: + image_column = args.image_column + if image_column not in column_names: + raise ValueError( + f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}" + ) + + def center_crop(image): + width, height = image.size + new_size = min(width, height) + left = (width - new_size) / 2 + top = (height - new_size) / 2 + right = (width + new_size) / 2 + bottom = (height + new_size) / 2 + return image.crop((left, top, right, bottom)) + + def train_transforms(img): + img = center_crop(img) + img = img.resize( + (args.resolution, args.resolution), resample=Image.BICUBIC, reducing_gap=1 + ) + img = np.array(img).astype(np.float32) / 127.5 - 1 + img = paddle.to_tensor(data=np.transpose(img, [2, 0, 1])) + return img + + def preprocess_train(examples): + images = [image.convert("RGB") for image in examples[image_column]] + examples["pixel_values"] = [train_transforms(image) for image in images] + examples["clip_pixel_values"] = image_processor( + images, return_tensors="pd" + ).pixel_values + return examples + + with main_process_first(): + if args.max_train_samples is not None: + dataset["train"] = ( + dataset["train"] + .shuffle(seed=args.seed) + .select(range(args.max_train_samples)) + ) + # Set the training transforms + train_dataset = dataset["train"].with_transform(preprocess_train) + + def collate_fn(examples): + pixel_values = paddle.stack( + [example["pixel_values"] for example in examples] + ).cast("float32") + clip_pixel_values = paddle.stack( + [example["clip_pixel_values"] for example in examples] + ).cast("float32") + return {"pixel_values": pixel_values, "clip_pixel_values": clip_pixel_values} + + train_sampler = ( + DistributedBatchSampler( + train_dataset, batch_size=args.train_batch_size, shuffle=False + ) + if num_processes > 1 + else BatchSampler( + train_dataset, batch_size=args.train_batch_size, shuffle=False + ) + ) + train_dataloader = DataLoader( + train_dataset, + batch_sampler=train_sampler, + collate_fn=collate_fn, + num_workers=args.dataloader_num_workers, + ) + + # Scheduler and math around the number of training steps. 
+ num_update_steps_per_epoch = math.ceil( + len(train_dataloader) / args.gradient_accumulation_steps + ) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + if num_processes > 1: + unet = paddle.DataParallel(unet) + + lr_scheduler = get_scheduler( + args.lr_scheduler, + learning_rate=args.learning_rate, + num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, + num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, + ) + # Initialize the optimizer + optimizer = AdamW( + learning_rate=lr_scheduler, + parameters=unet_lora_layers.parameters(), + beta1=args.adam_beta1, + beta2=args.adam_beta2, + weight_decay=args.adam_weight_decay, + epsilon=args.adam_epsilon, + grad_clip=nn.ClipGradByGlobalNorm(args.max_grad_norm) + if args.max_grad_norm > 0 + else None, + ) + + if is_main_process: + logger.info("----------- Configuration Arguments -----------") + for arg, value in sorted(vars(args).items()): + logger.info("%s: %s" % (arg, value)) + logger.info("------------------------------------------------") + writer = get_report_to(args) + + # Train! + total_batch_size = ( + args.train_batch_size * num_processes * args.gradient_accumulation_steps + ) + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num batches each epoch = {len(train_dataloader)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info( + f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}" + ) + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + + # Only show the progress bar once on each machine. + progress_bar = tqdm(range(args.max_train_steps), disable=not is_main_process) + progress_bar.set_description("Train Steps") + global_step = 0 + + # # Keep vae in eval model as we don't train these + vae.eval() + image_encoder.eval() + + unet.train() + + for epoch in range(args.num_train_epochs): + for step, batch in enumerate(train_dataloader): + images = batch["pixel_values"].cast("float32") + + clip_images = batch["clip_pixel_values"].cast("float32") + + latents = vae.encode(images).latents + image_embeds = image_encoder(clip_images).image_embeds + # Sample noise that we'll add to the latents + + noise = paddle.randn(latents.shape, dtype=latents.dtype) + + bsz = latents.shape[0] + # Sample a random timestep for each image + timesteps = paddle.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,)) + timesteps = timesteps.astype("int64") + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + target = noise + + # Predict the noise residual and compute loss + added_cond_kwargs = {"image_embeds": image_embeds} + + model_pred = unet( + noisy_latents, timesteps, None, added_cond_kwargs=added_cond_kwargs + ).sample[:, :4] + + if args.snr_gamma is None: + loss = F.mse_loss( + model_pred.cast("float32"), + target.cast("float32"), + reduction="mean", + ) + else: + # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556. 
+ # Since we predict the noise instead of x_0, the original formulation is slightly changed. + # This is discussed in Section 4.2 of the same paper. + snr = compute_snr(timesteps) + if noise_scheduler.config.prediction_type == "v_prediction": + # Velocity objective requires that we add one to SNR values before we divide by them. + snr = snr + 1 + mse_loss_weights = ( + paddle.stack( + [snr, args.snr_gamma * paddle.ones_like(timesteps)], + axis=1, + ).min(1)[0] + / snr + ) + # We first calculate the original loss. Then we mean over the non-batch dimensions and + # rebalance the sample-wise losses with their respective loss weights. + # Finally, we take the mean of the rebalanced loss. + loss = F.mse_loss( + model_pred.cast("float32"), + target.cast("float32"), + reduction="none", + ) + loss = ( + loss.mean(axis=list(range(1, len(loss.shape)))) * mse_loss_weights + ) + loss = loss.mean() + step_loss = loss.item() + + if args.gradient_accumulation_steps > 1: + loss = loss / args.gradient_accumulation_steps + loss.backward() + + if (step + 1) % args.gradient_accumulation_steps == 0: + if num_processes > 1 and args.gradient_checkpointing: + fused_allreduce_gradients(unet_lora_layers.parameters(), None) + optimizer.step() + lr_scheduler.step() + optimizer.clear_grad() + progress_bar.update(1) + global_step += 1 + step_loss = loss.item() * args.gradient_accumulation_steps + + logs = { + "epoch": str(epoch).zfill(4), + "step_loss": round(step_loss, 10), + "lr": lr_scheduler.get_lr(), + } + progress_bar.set_postfix(**logs) + + if is_main_process: + for name, val in logs.items(): + if name == "epoch": + continue + writer.add_scalar(f"train/{name}", val, global_step) + + if global_step % args.checkpointing_steps == 0: + save_path = os.path.join( + args.output_dir, f"checkpoint-{global_step}" + ) + + LoraLoaderMixin.save_lora_weights( + save_directory=save_path, + unet_lora_layers=unet_lora_layers, + ) + logger.info(f"Saved lora weights to {save_path}") + + if global_step >= args.max_train_steps: + break + + if is_main_process: + text_encoder = CLIPTextModelWithProjection.from_pretrained( + args.pretrained_prior_model_name_or_path, subfolder="text_encoder" + ) + prior = PriorTransformer.from_pretrained( + args.pretrained_prior_model_name_or_path, subfolder="prior" + ) + prior_image_processor = CLIPImageProcessor.from_pretrained( + args.pretrained_prior_model_name_or_path, subfolder="image_processor" + ) + prior_scheduler = UnCLIPScheduler( + beta_schedule="squaredcos_cap_v2", prediction_type="sample" + ) + prior_tokenizer = CLIPTokenizer.from_pretrained( + url_or_path_join(args.pretrained_prior_model_name_or_path, "tokenizer") + ) + if ( + args.validation_prompt is not None + and epoch % args.validation_epochs == 0 + ): + logger.info( + f"Running validation... \n Generating {args.num_validation_images} images with prompt:" + f" {args.validation_prompt}." 
+ ) + # create pipeline + pipeline = AutoPipelineForText2Image.from_pretrained( + args.pretrained_decoder_model_name_or_path, + unet=unwrap_model(unet), + prior_text_encoder=text_encoder, + prior_image_encoder=image_encoder, + prior_prior=prior, + prior_image_processor=prior_image_processor, + prior_scheduler=prior_scheduler, + prior_tokenizer=prior_tokenizer, + ) + pipeline.set_progress_bar_config(disable=True) + + # run inference + generator = ( + paddle.Generator().manual_seed(args.seed) if args.seed else None + ) + images = [ + pipeline( + args.validation_prompt, + num_inference_steps=30, + generator=generator, + ).images[0] + for _ in range(args.num_validation_images) + ] + np_images = np.stack([np.asarray(img) for img in images]) + + if args.report_to == "tensorboard": + writer.add_images( + "validation", np_images, epoch, dataformats="NHWC" + ) + else: + writer.add_image("validation", np_images, epoch, dataformats="NHWC") + + del pipeline + gc.collect() + unet.train() + + # Create the pipeline using the trained modules and save it. + if is_main_process: + LoraLoaderMixin.save_lora_weights( + save_directory=args.output_dir, + unet_lora_layers=unet_lora_layers, + ) + if args.push_to_hub: + save_model_card( + repo_name, + images=images, + base_model=args.pretrained_decoder_model_name_or_path, + prompt=args.instance_prompt, + repo_folder=args.output_dir, + ) + repo.push_to_hub( + commit_message="End of training", blocking=False, auto_lfs_prune=True + ) + + # pipeline.save_pretrained(args.output_dir) + pipeline = AutoPipelineForText2Image.from_pretrained( + args.pretrained_decoder_model_name_or_path, + prior_text_encoder=text_encoder, + prior_image_encoder=image_encoder, + prior_prior=prior, + prior_image_processor=prior_image_processor, + prior_scheduler=prior_scheduler, + prior_tokenizer=prior_tokenizer, + ) + pipeline.unet.load_attn_procs(args.output_dir) + + # run inference + if args.validation_prompt and args.num_validation_images > 0: + generator = paddle.Generator().manual_seed(args.seed) if args.seed else None + images = [ + pipeline( + args.validation_prompt, num_inference_steps=30, generator=generator + ).images[0] + for _ in range(args.num_validation_images) + ] + np_images = np.stack([np.asarray(img) for img in images]) + + if args.report_to == "tensorboard": + writer.add_images("test", np_images, epoch, dataformats="NHWC") + else: + writer.add_image("test", np_images, epoch, dataformats="NHWC") + + writer.close() + if args.push_to_hub: + repo.push_to_hub( + commit_message="End of training", blocking=False, auto_lfs_prune=True + ) + + +if __name__ == "__main__": + main() diff --git a/ppdiffusers/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py b/ppdiffusers/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py new file mode 100644 index 000000000..9e9cbced2 --- /dev/null +++ b/ppdiffusers/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py @@ -0,0 +1,743 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +import math +import os +import random +from pathlib import Path +from typing import Optional + +import numpy as np +import paddle +import paddle.nn as nn +import paddle.nn.functional as F + +# from datasets import DatasetDict, load_dataset +from huggingface_hub import HfFolder, Repository, create_repo, whoami +from paddle.distributed.fleet.utils.hybrid_parallel_util import ( + fused_allreduce_gradients, +) +from paddle.io import BatchSampler, DataLoader, DistributedBatchSampler +from paddle.optimizer import AdamW +from paddlenlp.trainer import set_seed +from paddlenlp.transformers import ( + CLIPImageProcessor, + CLIPVisionModelWithProjection, + CLIPTokenizer, + CLIPTextModelWithProjection, +) +from paddlenlp.utils.log import logger +from tqdm.auto import tqdm + +from ppdiffusers import ( + AutoPipelineForText2Image, + DDPMScheduler, + PriorTransformer, +) +from ppdiffusers.optimization import get_scheduler +from ppdiffusers.training_utils import ( + EMAModel, + freeze_params, + main_process_first, + unwrap_model, +) +from ppdiffusers.utils import check_min_version +from datasets import load_dataset + +check_min_version("0.16.1") + + +def url_or_path_join(*path_list): + return ( + os.path.join(*path_list) + if os.path.isdir(os.path.join(*path_list)) + else "/".join(path_list) + ) + + +def get_report_to(args): + if args.report_to == "visualdl": + from visualdl import LogWriter + + writer = LogWriter(logdir=args.logging_dir) + elif args.report_to == "tensorboard": + from tensorboardX import SummaryWriter + + writer = SummaryWriter(logdir=args.logging_dir) + else: + raise ValueError("report_to must be in ['visualdl', 'tensorboard']") + return writer + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Simple example of finetuning Kandinsky 2.2." + ) + parser.add_argument( + "--pretrained_decoder_model_name_or_path", + type=str, + default="kandinsky-community/kandinsky-2-2-decoder", + required=False, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--pretrained_prior_model_name_or_path", + type=str, + default="kandinsky-community/kandinsky-2-2-prior", + required=False, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--dataset_name", + type=str, + default=None, + help="The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private, dataset). It can also be a path pointing to a local copy of a dataset in your filesystem, or to a folder containing files that 🤗 Datasets can understand", + ) + parser.add_argument( + "--dataset_config_name", + type=str, + default=None, + help="The config of the Dataset, leave as None if there's only one config.", + ) + parser.add_argument( + "--train_data_dir", + type=str, + default=None, + help="A folder containing the training data. Folder contents must follow the structure described in https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file must exist to provide the captions for the images. 
Ignored if `dataset_name` is specified.", + ) + parser.add_argument( + "--image_column", + type=str, + default="image", + help="The column of the dataset containing an image.", + ) + parser.add_argument( + "--caption_column", + type=str, + default="text", + help="The column of the dataset containing a caption or a list of captions.", + ) + parser.add_argument( + "--max_train_samples", + type=int, + default=None, + help="For debugging purposes or quicker training, truncate the number of training examples to this value if set.", + ) + parser.add_argument( + "--output_dir", + type=str, + default="kandi_2_2-model-finetuned", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument( + "--cache_dir", + type=str, + default=None, + help="The directory where the downloaded models and datasets will be stored.", + ) + parser.add_argument( + "--seed", type=int, default=None, help="A seed for reproducible training." + ) + parser.add_argument( + "--resolution", + type=int, + default=512, + help="The resolution for input images, all the images in the train/validation dataset will be resized to this resolution", + ) + parser.add_argument( + "--train_batch_size", + type=int, + default=1, + help="Batch size (per device) for the training dataloader.", + ) + parser.add_argument("--num_train_epochs", type=int, default=100) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", type=float, default=0.0001, help="learning rate" + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help='The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"]', + ) + parser.add_argument( + "--lr_warmup_steps", + type=int, + default=500, + help="Number of steps for the warmup in the lr scheduler.", + ) + parser.add_argument( + "--snr_gamma", + type=float, + default=None, + help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. More details here: https://arxiv.org/abs/2303.09556.", + ) + parser.add_argument( + "--use_8bit_adam", + action="store_true", + help="Whether or not to use 8-bit Adam from bitsandbytes.", + ) + parser.add_argument( + "--allow_tf32", + action="store_true", + help="Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices", + ) + parser.add_argument( + "--use_ema", action="store_true", help="Whether to use EMA model." + ) + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help="Number of subprocesses to use for data loading. 
0 means that the data will be loaded in the main process.", + ) + parser.add_argument( + "--adam_beta1", + type=float, + default=0.9, + help="The beta1 parameter for the Adam optimizer.", + ) + parser.add_argument( + "--adam_beta2", + type=float, + default=0.999, + help="The beta2 parameter for the Adam optimizer.", + ) + parser.add_argument( + "--adam_weight_decay", + type=float, + default=0.0, + required=False, + help="weight decay_to_use", + ) + parser.add_argument( + "--adam_epsilon", + type=float, + default=1e-08, + help="Epsilon value for the Adam optimizer", + ) + parser.add_argument( + "--max_grad_norm", default=1.0, type=float, help="Max gradient norm." + ) + parser.add_argument( + "--push_to_hub", + action="store_true", + help="Whether or not to push the model to the Hub.", + ) + parser.add_argument( + "--hub_token", + type=str, + default=None, + help="The token to use to push to the Model Hub.", + ) + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help="[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.", + ) + parser.add_argument( + "--report_to", + type=str, + default="visualdl", + choices=["tensorboard", "visualdl"], + help="Log writer type.", + ) + parser.add_argument( + "--local_rank", + type=int, + default=-1, + help="For distributed training: local_rank", + ) + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help="Save a checkpoint of the training state every X updates.", + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=None, + help="Max number of checkpoints to store.", + ) + parser.add_argument( + "--tracker_project_name", + type=str, + default="text2image-fine-tune", + help="The `project_name` argument passed to Accelerator.init_trackers for more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator", + ) + args = parser.parse_args() + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + if args.dataset_name is None and args.train_data_dir is None: + raise ValueError("Need either a dataset name or a training folder.") + return args + + +def get_full_repo_name( + model_id: str, organization: Optional[str] = None, token: Optional[str] = None +): + if token is None: + token = HfFolder.get_token() + if organization is None: + username = whoami(token)["name"] + return f"{username}/{model_id}" + else: + return f"{organization}/{model_id}" + + +DATASET_NAME_MAPPING = { + "lambdalabs/pokemon-blip-captions": ("image", "text"), +} + + +def main(): + args = parse_args() + rank = paddle.distributed.get_rank() + is_main_process = rank == 0 + num_processes = paddle.distributed.get_world_size() + if num_processes > 1: + paddle.distributed.init_parallel_env() + + # If passed along, set the training seed now. 
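+    # Seeding up front keeps data shuffling and the random noise/timestep
+    # sampling below reproducible across runs.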
+ if args.seed is not None: + set_seed(args.seed) + + # Handle the repository creation + if is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + if args.push_to_hub: + if args.hub_model_id is None: + repo_name = get_full_repo_name( + Path(args.output_dir).name, token=args.hub_token + ) + else: + repo_name = args.hub_model_id + create_repo(repo_name, exist_ok=True, token=args.hub_token) + repo = Repository( + args.output_dir, clone_from=repo_name, token=args.hub_token + ) + + with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: + if "step_*" not in gitignore: + gitignore.write("step_*\n") + if "epoch_*" not in gitignore: + gitignore.write("epoch_*\n") + + noise_scheduler = DDPMScheduler( + beta_schedule="squaredcos_cap_v2", prediction_type="sample" + ) + image_processor = CLIPImageProcessor.from_pretrained( + args.pretrained_prior_model_name_or_path, subfolder="image_processor" + ) + tokenizer = CLIPTokenizer.from_pretrained( + url_or_path_join(args.pretrained_prior_model_name_or_path, "tokenizer") + ) + + image_encoder = CLIPVisionModelWithProjection.from_pretrained( + args.pretrained_prior_model_name_or_path, subfolder="image_encoder" + ) + text_encoder = CLIPTextModelWithProjection.from_pretrained( + args.pretrained_prior_model_name_or_path, subfolder="text_encoder" + ) + prior = PriorTransformer.from_pretrained( + args.pretrained_prior_model_name_or_path, subfolder="prior" + ) + + freeze_params(image_encoder.parameters()) + freeze_params(text_encoder.parameters()) + + if args.use_ema: + ema_prior = PriorTransformer.from_pretrained( + args.pretrained_prior_model_name_or_path, subfolder="prior" + ) + ema_prior = EMAModel(ema_prior.parameters()) + + def compute_snr(timesteps): + """ + Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849 + """ + alphas_cumprod = noise_scheduler.alphas_cumprod + sqrt_alphas_cumprod = alphas_cumprod**0.5 + sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5 + + # Expand the tensors. + # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026 + sqrt_alphas_cumprod = sqrt_alphas_cumprod[timesteps].cast("float32") + while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape): + sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None] + alpha = sqrt_alphas_cumprod.expand(timesteps.shape) + + sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[timesteps].cast( + "float32" + ) + while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape): + sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None] + sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape) + + # Compute SNR. + snr = (alpha / sigma) ** 2 + return snr + + # Get the datasets: you can either provide your own training and evaluation files (see below) + # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub). + + if args.dataset_name is not None: + # Downloading and loading a dataset from the hub. 
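+        # `args.dataset_name` can be a hub dataset id (e.g. lambdalabs/pokemon-blip-captions)
+        # or a local path; otherwise the `else` branch below loads a local
+        # `imagefolder` whose captions come from a `metadata.jsonl` file.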
+ dataset = load_dataset( + args.dataset_name, + args.dataset_config_name, + cache_dir=args.cache_dir, + ) + else: + data_files = {} + if args.train_data_dir is not None: + data_files["train"] = os.path.join(args.train_data_dir, "**") + dataset = load_dataset( + "imagefolder", + data_files=data_files, + cache_dir=args.cache_dir, + ) + # See more about loading custom images at + # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder + + # Preprocessing the datasets. + # We need to tokenize inputs and targets. + column_names = dataset["train"].column_names + # 6. Get the column names for input/target. + dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None) + if args.image_column is None: + image_column = ( + dataset_columns[0] if dataset_columns is not None else column_names[0] + ) + else: + image_column = args.image_column + if image_column not in column_names: + raise ValueError( + f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}" + ) + + if args.caption_column is None: + caption_column = ( + dataset_columns[1] if dataset_columns is not None else column_names[1] + ) + else: + caption_column = args.caption_column + print("caption_column:", caption_column) + if caption_column not in column_names: + raise ValueError( + f"--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}" + ) + + def tokenize_captions(examples, is_train=True): + captions = [] + for caption in examples[caption_column]: + if isinstance(caption, str): + captions.append(caption) + elif isinstance(caption, (list, np.ndarray)): + captions.append(random.choice(caption) if is_train else caption[0]) + else: + raise ValueError( + f"Caption column `{caption_column}` should contain either strings or lists of strings." 
+ ) + inputs = tokenizer( + captions, + max_length=tokenizer.model_max_length, + padding="max_length", + truncation=True, + return_tensors="pd", + return_attention_mask=True, + ) + text_input_ids = inputs.input_ids + text_mask = inputs.attention_mask.cast(bool) + + return text_input_ids, text_mask + + def preprocess_train(examples): + images = [image.convert("RGB") for image in examples[image_column]] + examples["clip_pixel_values"] = image_processor( + images, return_tensors="pd" + ).pixel_values + examples["text_input_ids"], examples["text_mask"] = tokenize_captions(examples) + return examples + + with main_process_first(): + if args.max_train_samples is not None: + dataset["train"] = ( + dataset["train"] + .shuffle(seed=args.seed) + .select(range(args.max_train_samples)) + ) + # Set the training transforms + train_dataset = dataset["train"].with_transform(preprocess_train) + + def collate_fn(examples): + clip_pixel_values = paddle.stack( + x=[example["clip_pixel_values"] for example in examples] + ).cast("float32") + text_input_ids = paddle.stack( + x=[example["text_input_ids"] for example in examples] + ) + text_mask = paddle.stack(x=[example["text_mask"] for example in examples]) + return { + "clip_pixel_values": clip_pixel_values, + "text_input_ids": text_input_ids, + "text_mask": text_mask, + } + + train_sampler = ( + DistributedBatchSampler( + train_dataset, batch_size=args.train_batch_size, shuffle=False + ) + if num_processes > 1 + else BatchSampler( + train_dataset, batch_size=args.train_batch_size, shuffle=False + ) + ) + train_dataloader = DataLoader( + train_dataset, + batch_sampler=train_sampler, + collate_fn=collate_fn, + num_workers=args.dataloader_num_workers, + ) + + # Scheduler and math around the number of training steps. + num_update_steps_per_epoch = math.ceil( + len(train_dataloader) / args.gradient_accumulation_steps + ) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + if num_processes > 1: + prior = paddle.DataParallel(prior) + + lr_scheduler = get_scheduler( + args.lr_scheduler, + learning_rate=args.learning_rate, + num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, + num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, + ) + # Initialize the optimizer + optimizer = AdamW( + learning_rate=lr_scheduler, + parameters=prior.parameters(), + beta1=args.adam_beta1, + beta2=args.adam_beta2, + weight_decay=args.adam_weight_decay, + epsilon=args.adam_epsilon, + grad_clip=nn.ClipGradByGlobalNorm(args.max_grad_norm) + if args.max_grad_norm > 0 + else None, + ) + + clip_mean = prior.clip_mean.clone() + clip_std = prior.clip_std.clone() + + if is_main_process: + logger.info("----------- Configuration Arguments -----------") + for arg, value in sorted(vars(args).items()): + logger.info("%s: %s" % (arg, value)) + logger.info("------------------------------------------------") + writer = get_report_to(args) + + # Train! 
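+    # Effective (optimization) batch size = per-device batch size x number of
+    # processes x gradient accumulation steps.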
+ total_batch_size = ( + args.train_batch_size * num_processes * args.gradient_accumulation_steps + ) + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num batches each epoch = {len(train_dataloader)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info( + f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}" + ) + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + + # Only show the progress bar once on each machine. + progress_bar = tqdm(range(args.max_train_steps), disable=not is_main_process) + progress_bar.set_description("Train Steps") + global_step = 0 + + # # Keep vae in eval model as we don't train these + text_encoder.eval() + image_encoder.eval() + + prior.train() + + clip_mean = clip_mean.cast("float32") + clip_std = clip_std.cast("float32") + for epoch in range(args.num_train_epochs): + for step, batch in enumerate(train_dataloader): + text_input_ids, text_mask, clip_images = ( + batch["text_input_ids"], + batch["text_mask"], + batch["clip_pixel_values"].cast("float32"), + ) + with paddle.no_grad(): + text_encoder_output = text_encoder(text_input_ids) + prompt_embeds = text_encoder_output.text_embeds + text_encoder_hidden_states = text_encoder_output.last_hidden_state + image_embeds = image_encoder(clip_images).image_embeds + + noise = paddle.randn(shape=image_embeds.shape, dtype=image_embeds.dtype) + + bsz = image_embeds.shape[0] + timesteps = paddle.randint(low=0, high=noise_scheduler.config.num_train_timesteps, shape=(bsz,)) + timesteps = timesteps.astype(dtype='int64') + + image_embeds = (image_embeds - clip_mean) / clip_std + noisy_latents = noise_scheduler.add_noise( + image_embeds, noise, timesteps + ) + target = image_embeds + model_pred = prior( + noisy_latents, + timestep=timesteps, + proj_embedding=prompt_embeds, + encoder_hidden_states=text_encoder_hidden_states, + attention_mask=text_mask, + ).predicted_image_embedding + + if args.snr_gamma is None: + loss = F.mse_loss( + model_pred.cast("float32"), + target.cast("float32"), + reduction="mean", + ) + else: + # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556. + # Since we predict the noise instead of x_0, the original formulation is slightly changed. + # This is discussed in Section 4.2 of the same paper. + snr = compute_snr(timesteps) + if noise_scheduler.config.prediction_type == "v_prediction": + # Velocity objective requires that we add one to SNR values before we divide by them. + snr = snr + 1 + mse_loss_weights = ( + paddle.stack( + [snr, args.snr_gamma * paddle.ones_like(timesteps)], + axis=1, + ).min(1)[0] + / snr + ) + # We first calculate the original loss. Then we mean over the non-batch dimensions and + # rebalance the sample-wise losses with their respective loss weights. + # Finally, we take the mean of the rebalanced loss. 
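+                # Per-sample weight = min(SNR_t, snr_gamma) / SNR_t, which caps the
+                # contribution of high-SNR (easy) timesteps while leaving low-SNR
+                # timesteps at weight 1.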
+ loss = F.mse_loss( + model_pred.cast("float32"), + target.cast("float32"), + reduction="none", + ) + loss = ( + loss.mean(axis=list(range(1, len(loss.shape)))) * mse_loss_weights + ) + loss = loss.mean() + + if args.gradient_accumulation_steps > 1: + loss = loss / args.gradient_accumulation_steps + loss.backward() + + if (step + 1) % args.gradient_accumulation_steps == 0: + if num_processes > 1 and args.gradient_checkpointing: + fused_allreduce_gradients(prior.parameters(), None) + optimizer.step() + lr_scheduler.step() + optimizer.clear_grad() + progress_bar.update(1) + global_step += 1 + step_loss = loss.item() * args.gradient_accumulation_steps + + if args.use_ema: + ema_prior.step(prior.parameters()) + logs = { + "epoch": str(epoch).zfill(4), + "step_loss": round(step_loss, 10), + "lr": lr_scheduler.get_lr(), + } + progress_bar.set_postfix(**logs) + + if is_main_process: + for name, val in logs.items(): + if name == "epoch": + continue + writer.add_scalar(f"train/{name}", val, global_step) + + if global_step % args.checkpointing_steps == 0: + save_path = os.path.join( + args.output_dir, f"checkpoint-{global_step}" + ) + unwrap_model(prior).save_pretrained( + os.path.join(save_path, "prior") + ) + + if global_step >= args.max_train_steps: + break + + # Create the pipeline using the trained modules and save it. + if is_main_process: + writer.close() + prior = unwrap_model(prior) + if args.use_ema: + ema_prior.copy_to(prior.parameters()) + + pipeline = AutoPipelineForText2Image.from_pretrained( + args.pretrained_decoder_model_name_or_path, + prior_image_encoder=image_encoder, + prior_text_encoder=text_encoder, + prior_prior=prior, + prior_scheduler=noise_scheduler, + prior_tokenizer=tokenizer, + prior_image_processor=image_processor, + ) + pipeline.prior_pipe.save_pretrained(args.output_dir) + + if args.push_to_hub: + repo.push_to_hub( + commit_message="End of training", blocking=False, auto_lfs_prune=True + ) + + +if __name__ == "__main__": + main() diff --git a/ppdiffusers/examples/kandinsky2_2/text_to_image/train_text_to_image_prior_lora.py b/ppdiffusers/examples/kandinsky2_2/text_to_image/train_text_to_image_prior_lora.py new file mode 100644 index 000000000..230ba81f6 --- /dev/null +++ b/ppdiffusers/examples/kandinsky2_2/text_to_image/train_text_to_image_prior_lora.py @@ -0,0 +1,876 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import argparse +import gc +import math +import os +import random +from pathlib import Path +from typing import Optional + +import numpy as np +import paddle +import paddle.nn as nn +import paddle.nn.functional as F + +from huggingface_hub import HfFolder, Repository, create_repo, whoami +from paddle.distributed.fleet.utils.hybrid_parallel_util import ( + fused_allreduce_gradients, +) +from paddle.io import BatchSampler, DataLoader, DistributedBatchSampler +from paddle.optimizer import AdamW +from paddlenlp.trainer import set_seed +from paddlenlp.transformers import ( + CLIPImageProcessor, + CLIPVisionModelWithProjection, + CLIPTokenizer, + CLIPTextModelWithProjection, +) +from paddlenlp.utils.log import logger +from tqdm.auto import tqdm + +from ppdiffusers import ( + AutoPipelineForText2Image, + DDPMScheduler, + PriorTransformer, + UnCLIPScheduler, +) +from ppdiffusers.optimization import get_scheduler +from ppdiffusers.loaders import AttnProcsLayers, LoraLoaderMixin +from ppdiffusers.models.attention_processor import LoRAAttnProcessor +from ppdiffusers.training_utils import ( + freeze_params, + main_process_first, + unwrap_model, +) +from ppdiffusers.utils import check_min_version +from datasets import load_dataset + +check_min_version("0.16.1") + + +def url_or_path_join(*path_list): + return ( + os.path.join(*path_list) + if os.path.isdir(os.path.join(*path_list)) + else "/".join(path_list) + ) + + +def save_model_card( + repo_id: str, images=None, base_model=str, dataset_name=str, repo_folder=None +): + img_str = "" + for i, image in enumerate(images): + image.save(os.path.join(repo_folder, f"image_{i}.png")) + img_str += f"![img_{i}](./image_{i}.png)\n" + + yaml = f""" +--- +license: creativeml-openrail-m +base_model: {base_model} +tags: +- kandinsky +- text-to-image +- ppdiffusers +- lora +inference: true +--- + """ + model_card = f""" +# LoRA text2image fine-tuning - {repo_id} +These are LoRA adaption weights for {base_model}. The weights were fine-tuned on the {dataset_name} dataset. You can find some example images in the following. \n +{img_str} +""" + with open(os.path.join(repo_folder, "README.md"), "w") as f: + f.write(yaml + model_card) + + +def get_report_to(args): + if args.report_to == "visualdl": + from visualdl import LogWriter + + writer = LogWriter(logdir=args.logging_dir) + elif args.report_to == "tensorboard": + from tensorboardX import SummaryWriter + + writer = SummaryWriter(logdir=args.logging_dir) + else: + raise ValueError("report_to must be in ['visualdl', 'tensorboard']") + return writer + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Simple example of finetuning Kandinsky 2.2." + ) + parser.add_argument( + "--pretrained_decoder_model_name_or_path", + type=str, + default="kandinsky-community/kandinsky-2-2-decoder", + required=False, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--pretrained_prior_model_name_or_path", + type=str, + default="kandinsky-community/kandinsky-2-2-prior", + required=False, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--dataset_name", + type=str, + default=None, + help="The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private, dataset). 
It can also be a path pointing to a local copy of a dataset in your filesystem, or to a folder containing files that 🤗 Datasets can understand", + ) + parser.add_argument( + "--dataset_config_name", + type=str, + default=None, + help="The config of the Dataset, leave as None if there's only one config.", + ) + parser.add_argument( + "--train_data_dir", + type=str, + default=None, + help="A folder containing the training data. Folder contents must follow the structure described in https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file must exist to provide the captions for the images. Ignored if `dataset_name` is specified.", + ) + parser.add_argument( + "--image_column", + type=str, + default="image", + help="The column of the dataset containing an image.", + ) + parser.add_argument( + "--caption_column", + type=str, + default="text", + help="The column of the dataset containing a caption or a list of captions.", + ) + parser.add_argument( + "--max_train_samples", + type=int, + default=None, + help="For debugging purposes or quicker training, truncate the number of training examples to this value if set.", + ) + parser.add_argument( + "--output_dir", + type=str, + default="kandi_2_2-model-finetuned-lora", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument( + "--cache_dir", + type=str, + default=None, + help="The directory where the downloaded models and datasets will be stored.", + ) + parser.add_argument( + "--seed", type=int, default=None, help="A seed for reproducible training." + ) + parser.add_argument( + "--resolution", + type=int, + default=512, + help="The resolution for input images, all the images in the train/validation dataset will be resized to this resolution", + ) + parser.add_argument( + "--train_batch_size", + type=int, + default=1, + help="Batch size (per device) for the training dataloader.", + ) + parser.add_argument("--num_train_epochs", type=int, default=100) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", type=float, default=0.0001, help="learning rate" + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help='The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"]', + ) + parser.add_argument( + "--lr_warmup_steps", + type=int, + default=500, + help="Number of steps for the warmup in the lr scheduler.", + ) + parser.add_argument( + "--snr_gamma", + type=float, + default=None, + help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. More details here: https://arxiv.org/abs/2303.09556.", + ) + parser.add_argument( + "--use_8bit_adam", + action="store_true", + help="Whether or not to use 8-bit Adam from bitsandbytes.", + ) + parser.add_argument( + "--allow_tf32", + action="store_true", + help="Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. 
For more information, see https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices", + ) + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.", + ) + parser.add_argument( + "--adam_beta1", + type=float, + default=0.9, + help="The beta1 parameter for the Adam optimizer.", + ) + parser.add_argument( + "--adam_beta2", + type=float, + default=0.999, + help="The beta2 parameter for the Adam optimizer.", + ) + parser.add_argument( + "--adam_weight_decay", + type=float, + default=0.0, + required=False, + help="weight decay_to_use", + ) + parser.add_argument( + "--adam_epsilon", + type=float, + default=1e-08, + help="Epsilon value for the Adam optimizer", + ) + parser.add_argument( + "--max_grad_norm", default=1.0, type=float, help="Max gradient norm." + ) + parser.add_argument( + "--push_to_hub", + action="store_true", + help="Whether or not to push the model to the Hub.", + ) + parser.add_argument( + "--hub_token", + type=str, + default=None, + help="The token to use to push to the Model Hub.", + ) + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help="[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.", + ) + parser.add_argument( + "--report_to", + type=str, + default="visualdl", + choices=["tensorboard", "visualdl"], + help="Log writer type.", + ) + parser.add_argument( + "--local_rank", + type=int, + default=-1, + help="For distributed training: local_rank", + ) + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help="Save a checkpoint of the training state every X updates.", + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=None, + help="Max number of checkpoints to store.", + ) + parser.add_argument( + "--validation_epochs", + type=int, + default=5, + help="Run validation every X epochs.", + ) + parser.add_argument( + "--validation_prompt", + type=str, + default=None, + help="A prompt that is sampled during training for inference.", + ) + parser.add_argument( + "--num_validation_images", + type=int, + default=4, + help="Number of images that should be generated during validation with `validation_prompt`.", + ) + parser.add_argument( + "--tracker_project_name", + type=str, + default="text2image-fine-tune", + help="The `project_name` argument passed to Accelerator.init_trackers for more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator", + ) + parser.add_argument( + "--lora_rank", + type=int, + default=4, + help=("The dimension of the LoRA update matrices."), + ) + args = parser.parse_args() + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + if args.dataset_name is None and args.train_data_dir is None: + raise ValueError("Need either a dataset name or a training folder.") + return args + + +def get_full_repo_name( + model_id: str, organization: Optional[str] = None, token: Optional[str] = None +): + if token is None: + token = HfFolder.get_token() + if organization is None: + username = whoami(token)["name"] + return 
f"{username}/{model_id}" + else: + return f"{organization}/{model_id}" + + +DATASET_NAME_MAPPING = { + "lambdalabs/pokemon-blip-captions": ("image", "text"), +} + + +def main(): + args = parse_args() + rank = paddle.distributed.get_rank() + is_main_process = rank == 0 + num_processes = paddle.distributed.get_world_size() + if num_processes > 1: + paddle.distributed.init_parallel_env() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Handle the repository creation + if is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + if args.push_to_hub: + if args.hub_model_id is None: + repo_name = get_full_repo_name( + Path(args.output_dir).name, token=args.hub_token + ) + else: + repo_name = args.hub_model_id + create_repo(repo_name, exist_ok=True, token=args.hub_token) + repo = Repository( + args.output_dir, clone_from=repo_name, token=args.hub_token + ) + + with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: + if "step_*" not in gitignore: + gitignore.write("step_*\n") + if "epoch_*" not in gitignore: + gitignore.write("epoch_*\n") + + noise_scheduler = DDPMScheduler( + beta_schedule="squaredcos_cap_v2", prediction_type="sample" + ) + image_processor = CLIPImageProcessor.from_pretrained( + args.pretrained_prior_model_name_or_path, subfolder="image_processor" + ) + tokenizer = CLIPTokenizer.from_pretrained( + url_or_path_join(args.pretrained_prior_model_name_or_path, "tokenizer") + ) + + image_encoder = CLIPVisionModelWithProjection.from_pretrained( + args.pretrained_prior_model_name_or_path, subfolder="image_encoder" + ) + text_encoder = CLIPTextModelWithProjection.from_pretrained( + args.pretrained_prior_model_name_or_path, subfolder="text_encoder" + ) + prior = PriorTransformer.from_pretrained( + args.pretrained_prior_model_name_or_path, subfolder="prior" + ) + + freeze_params(image_encoder.parameters()) + freeze_params(text_encoder.parameters()) + + # Set correct lora layers + prior_lora_attn_procs = {} + for name, attn_processor in prior.attn_processors.items(): + prior_lora_attn_procs[name] = LoRAAttnProcessor( + hidden_size=2048, + rank=args.lora_rank, + ) + + prior.set_attn_processor(prior_lora_attn_procs) + prior_lora_layers = AttnProcsLayers(prior.attn_processors) + + def compute_snr(timesteps): + """ + Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849 + """ + alphas_cumprod = noise_scheduler.alphas_cumprod + sqrt_alphas_cumprod = alphas_cumprod**0.5 + sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5 + + # Expand the tensors. + # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026 + sqrt_alphas_cumprod = sqrt_alphas_cumprod[timesteps].cast("float32") + while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape): + sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None] + alpha = sqrt_alphas_cumprod.expand(timesteps.shape) + + sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[timesteps].cast( + "float32" + ) + while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape): + sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None] + sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape) + + # Compute SNR. 
+ snr = (alpha / sigma) ** 2 + return snr + + # Get the datasets: you can either provide your own training and evaluation files (see below) + # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub). + + if args.dataset_name is not None: + # Downloading and loading a dataset from the hub. + dataset = load_dataset( + args.dataset_name, + args.dataset_config_name, + cache_dir=args.cache_dir, + ) + else: + data_files = {} + if args.train_data_dir is not None: + data_files["train"] = os.path.join(args.train_data_dir, "**") + dataset = load_dataset( + "imagefolder", + data_files=data_files, + cache_dir=args.cache_dir, + ) + # See more about loading custom images at + # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder + + # Preprocessing the datasets. + # We need to tokenize inputs and targets. + column_names = dataset["train"].column_names + # 6. Get the column names for input/target. + dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None) + if args.image_column is None: + image_column = ( + dataset_columns[0] if dataset_columns is not None else column_names[0] + ) + else: + image_column = args.image_column + if image_column not in column_names: + raise ValueError( + f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}" + ) + + if args.caption_column is None: + caption_column = ( + dataset_columns[1] if dataset_columns is not None else column_names[1] + ) + else: + caption_column = args.caption_column + if caption_column not in column_names: + raise ValueError( + f"--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}" + ) + + def tokenize_captions(examples, is_train=True): + captions = [] + for caption in examples[caption_column]: + if isinstance(caption, str): + captions.append(caption) + elif isinstance(caption, (list, np.ndarray)): + captions.append(random.choice(caption) if is_train else caption[0]) + else: + raise ValueError( + f"Caption column `{caption_column}` should contain either strings or lists of strings." 
+ ) + inputs = tokenizer( + captions, + max_length=tokenizer.model_max_length, + padding="max_length", + truncation=True, + return_tensors="pd", + return_attention_mask=True, + ) + text_input_ids = inputs.input_ids + text_mask = inputs.attention_mask.cast(bool) + + return text_input_ids, text_mask + + def preprocess_train(examples): + images = [image.convert("RGB") for image in examples[image_column]] + examples["clip_pixel_values"] = image_processor( + images, return_tensors="pd" + ).pixel_values + examples["text_input_ids"], examples["text_mask"] = tokenize_captions(examples) + return examples + + with main_process_first(): + if args.max_train_samples is not None: + dataset["train"] = ( + dataset["train"] + .shuffle(seed=args.seed) + .select(range(args.max_train_samples)) + ) + # Set the training transforms + train_dataset = dataset["train"].with_transform(preprocess_train) + + def collate_fn(examples): + clip_pixel_values = paddle.stack( + x=[example["clip_pixel_values"] for example in examples] + ).cast("float32") + text_input_ids = paddle.stack( + x=[example["text_input_ids"] for example in examples] + ) + text_mask = paddle.stack(x=[example["text_mask"] for example in examples]) + return { + "clip_pixel_values": clip_pixel_values, + "text_input_ids": text_input_ids, + "text_mask": text_mask, + } + + train_sampler = ( + DistributedBatchSampler( + train_dataset, batch_size=args.train_batch_size, shuffle=False + ) + if num_processes > 1 + else BatchSampler( + train_dataset, batch_size=args.train_batch_size, shuffle=False + ) + ) + train_dataloader = DataLoader( + train_dataset, + batch_sampler=train_sampler, + collate_fn=collate_fn, + num_workers=args.dataloader_num_workers, + ) + + # Scheduler and math around the number of training steps. + num_update_steps_per_epoch = math.ceil( + len(train_dataloader) / args.gradient_accumulation_steps + ) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + if num_processes > 1: + prior = paddle.DataParallel(prior) + + lr_scheduler = get_scheduler( + args.lr_scheduler, + learning_rate=args.learning_rate, + num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, + num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, + ) + # Initialize the optimizer + optimizer = AdamW( + learning_rate=lr_scheduler, + parameters=prior_lora_layers.parameters(), + beta1=args.adam_beta1, + beta2=args.adam_beta2, + weight_decay=args.adam_weight_decay, + epsilon=args.adam_epsilon, + grad_clip=nn.ClipGradByGlobalNorm(args.max_grad_norm) + if args.max_grad_norm > 0 + else None, + ) + + clip_mean = prior.clip_mean.clone() + clip_std = prior.clip_std.clone() + + if is_main_process: + logger.info("----------- Configuration Arguments -----------") + for arg, value in sorted(vars(args).items()): + logger.info("%s: %s" % (arg, value)) + logger.info("------------------------------------------------") + writer = get_report_to(args) + + # Train! 
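+    # Only the LoRA attention-processor parameters (prior_lora_layers) were handed
+    # to the optimizer above, so the pretrained prior weights themselves are never
+    # updated.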
+ total_batch_size = ( + args.train_batch_size * num_processes * args.gradient_accumulation_steps + ) + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num batches each epoch = {len(train_dataloader)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info( + f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}" + ) + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + + # Only show the progress bar once on each machine. + progress_bar = tqdm(range(args.max_train_steps), disable=not is_main_process) + progress_bar.set_description("Train Steps") + global_step = 0 + + # # Keep vae in eval model as we don't train these + text_encoder.eval() + image_encoder.eval() + + prior.train() + + clip_mean = clip_mean.cast("float32") + clip_std = clip_std.cast("float32") + for epoch in range(args.num_train_epochs): + for step, batch in enumerate(train_dataloader): + text_input_ids, text_mask, clip_images = ( + batch["text_input_ids"], + batch["text_mask"], + batch["clip_pixel_values"].cast("float32"), + ) + with paddle.no_grad(): + text_encoder_output = text_encoder(text_input_ids) + prompt_embeds = text_encoder_output.text_embeds + text_encoder_hidden_states = text_encoder_output.last_hidden_state + image_embeds = image_encoder(clip_images).image_embeds + + noise = paddle.randn(shape=image_embeds.shape, dtype=image_embeds.dtype) + + bsz = image_embeds.shape[0] + timesteps = paddle.randint(low=0, high=noise_scheduler.config.num_train_timesteps, shape=(bsz,)) + timesteps = timesteps.astype(dtype='int64') + + image_embeds = (image_embeds - clip_mean) / clip_std + noisy_latents = noise_scheduler.add_noise( + image_embeds, noise, timesteps + ) + target = image_embeds + model_pred = prior( + noisy_latents, + timestep=timesteps, + proj_embedding=prompt_embeds, + encoder_hidden_states=text_encoder_hidden_states, + attention_mask=text_mask, + ).predicted_image_embedding + + if args.snr_gamma is None: + loss = F.mse_loss( + model_pred.cast("float32"), + target.cast("float32"), + reduction="mean", + ) + else: + # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556. + # Since we predict the noise instead of x_0, the original formulation is slightly changed. + # This is discussed in Section 4.2 of the same paper. + snr = compute_snr(timesteps) + if noise_scheduler.config.prediction_type == "v_prediction": + # Velocity objective requires that we add one to SNR values before we divide by them. + snr = snr + 1 + mse_loss_weights = ( + paddle.stack( + [snr, args.snr_gamma * paddle.ones_like(timesteps)], + axis=1, + ).min(1)[0] + / snr + ) + # We first calculate the original loss. Then we mean over the non-batch dimensions and + # rebalance the sample-wise losses with their respective loss weights. + # Finally, we take the mean of the rebalanced loss. 
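+                # E.g. with snr_gamma=5.0 a timestep with SNR=20 is weighted
+                # 5/20 = 0.25, while one with SNR=2 keeps its full weight of 1.0.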
+ loss = F.mse_loss( + model_pred.cast("float32"), + target.cast("float32"), + reduction="none", + ) + loss = ( + loss.mean(axis=list(range(1, len(loss.shape)))) * mse_loss_weights + ) + loss = loss.mean() + + if args.gradient_accumulation_steps > 1: + loss = loss / args.gradient_accumulation_steps + loss.backward() + + if (step + 1) % args.gradient_accumulation_steps == 0: + if num_processes > 1 and args.gradient_checkpointing: + fused_allreduce_gradients(prior_lora_layers.parameters(), None) + optimizer.step() + lr_scheduler.step() + optimizer.clear_grad() + progress_bar.update(1) + global_step += 1 + step_loss = loss.item() * args.gradient_accumulation_steps + + logs = { + "epoch": str(epoch).zfill(4), + "step_loss": round(step_loss, 10), + "lr": lr_scheduler.get_lr(), + } + progress_bar.set_postfix(**logs) + + if is_main_process: + for name, val in logs.items(): + if name == "epoch": + continue + writer.add_scalar(f"train/{name}", val, global_step) + + if global_step % args.checkpointing_steps == 0: + save_path = os.path.join( + args.output_dir, f"checkpoint-{global_step}" + ) + LoraLoaderMixin.save_lora_weights( + save_directory=save_path, + unet_lora_layers=prior_lora_layers, + ) + logger.info(f"Saved lora weights to {save_path}") + + if global_step >= args.max_train_steps: + break + if is_main_process: + prior_scheduler = UnCLIPScheduler( + beta_schedule="squaredcos_cap_v2", prediction_type="sample" + ) + if ( + args.validation_prompt is not None + and epoch % args.validation_epochs == 0 + ): + logger.info( + f"Running validation... \n Generating {args.num_validation_images} images with prompt:" + f" {args.validation_prompt}." + ) + # create pipeline + pipeline = AutoPipelineForText2Image.from_pretrained( + args.pretrained_decoder_model_name_or_path, + prior_prior=unwrap_model(prior), + prior_text_encoder=text_encoder, + prior_image_encoder=image_encoder, + prior_image_processor=image_processor, + prior_scheduler=prior_scheduler, + prior_tokenizer=tokenizer, + ) + pipeline.set_progress_bar_config(disable=True) + + # run inference + generator = ( + paddle.Generator().manual_seed(args.seed) if args.seed else None + ) + images = [ + pipeline( + args.validation_prompt, + num_inference_steps=30, + generator=generator, + ).images[0] + for _ in range(args.num_validation_images) + ] + np_images = np.stack([np.asarray(img) for img in images]) + + if args.report_to == "tensorboard": + writer.add_images( + "validation", np_images, epoch, dataformats="NHWC" + ) + else: + writer.add_image("validation", np_images, epoch, dataformats="NHWC") + + del pipeline + gc.collect() + prior.train() + + # Create the pipeline using the trained modules and save it. 
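+    # The LoRA weights saved below via LoraLoaderMixin.save_lora_weights are
+    # re-attached to a fresh pipeline with load_attn_procs for the final
+    # test-time generation.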
+ if is_main_process: + LoraLoaderMixin.save_lora_weights( + save_directory=args.output_dir, + unet_lora_layers=prior_lora_layers, + ) + + if args.push_to_hub: + save_model_card( + repo_name, + images=images, + base_model=args.pretrained_prior_model_name_or_path, + prompt=args.instance_prompt, + repo_folder=args.output_dir, + ) + repo.push_to_hub( + commit_message="End of training", blocking=False, auto_lfs_prune=True + ) + + pipeline = AutoPipelineForText2Image.from_pretrained( + args.pretrained_decoder_model_name_or_path, + prior_prior=unwrap_model(prior), + prior_text_encoder=text_encoder, + prior_image_encoder=image_encoder, + prior_image_processor=image_processor, + prior_scheduler=prior_scheduler, + prior_tokenizer=tokenizer, + ) + + pipeline.prior_prior.load_attn_procs(args.output_dir) + # run inference + if args.validation_prompt and args.num_validation_images > 0: + generator = paddle.Generator().manual_seed(args.seed) if args.seed else None + images = [ + pipeline( + args.validation_prompt, num_inference_steps=30, generator=generator + ).images[0] + for _ in range(args.num_validation_images) + ] + np_images = np.stack([np.asarray(img) for img in images]) + + if args.report_to == "tensorboard": + writer.add_images("test", np_images, epoch, dataformats="NHWC") + else: + writer.add_image("test", np_images, epoch, dataformats="NHWC") + + writer.close() + if args.push_to_hub: + repo.push_to_hub( + commit_message="End of training", blocking=False, auto_lfs_prune=True + ) + + +if __name__ == "__main__": + main() diff --git a/ppdiffusers/ppdiffusers/models/attention_processor.py b/ppdiffusers/ppdiffusers/models/attention_processor.py index 0ed4e1e5d..8c9cd8a80 100644 --- a/ppdiffusers/ppdiffusers/models/attention_processor.py +++ b/ppdiffusers/ppdiffusers/models/attention_processor.py @@ -786,8 +786,8 @@ def __call__( value = attn.to_v(hidden_states) + scale * self.to_v_lora(hidden_states) key = attn.head_to_batch_dim(key) value = attn.head_to_batch_dim(value) - key = paddle.concat([encoder_hidden_states_key_proj, key], axis=1) - value = paddle.concat([encoder_hidden_states_value_proj, value], axis=1) + key = paddle.concat([encoder_hidden_states_key_proj, key], axis=2) + value = paddle.concat([encoder_hidden_states_value_proj, value], axis=2) else: key = encoder_hidden_states_key_proj value = encoder_hidden_states_value_proj From d0051f89d0f7fdf9470de2a5c88fbed368b61111 Mon Sep 17 00:00:00 2001 From: Tsaiyue Date: Sun, 14 Jan 2024 16:29:18 +0800 Subject: [PATCH 2/2] update README & PriorTransformer --- .../kandinsky2_2/text_to_image/README.md | 28 +++++++++---------- .../ppdiffusers/models/prior_transformer.py | 4 +-- 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/ppdiffusers/examples/kandinsky2_2/text_to_image/README.md b/ppdiffusers/examples/kandinsky2_2/text_to_image/README.md index ab0b5cb2b..fa0c785e5 100644 --- a/ppdiffusers/examples/kandinsky2_2/text_to_image/README.md +++ b/ppdiffusers/examples/kandinsky2_2/text_to_image/README.md @@ -11,8 +11,8 @@ ___这个脚本是试验性的。该脚本会对整个'decoder'或'prior'模型 在运行这个训练代码前,我们需要安装下面的训练依赖: ```bash -# 安装2.5.2版本的paddlepaddle-gpu,当前我们选择了cuda11.7的版本,可以查看 https://www.paddlepaddle.org.cn/ 寻找自己适合的版本 -python -m pip install paddlepaddle-gpu==2.5.2.post117 -f https://www.paddlepaddle.org.cn/whl/linux/mkl/avx/stable.html +# 安装2.6.0版本的paddlepaddle-gpu,当前我们选择了cuda12.0的版本,可以查看 https://www.paddlepaddle.org.cn/ 寻找自己适合的版本 +python -m pip install paddlepaddle-gpu==2.6.0.post120 -f https://www.paddlepaddle.org.cn/whl/linux/mkl/avx/stable.html # 安装所需的依赖, 
如果提示权限不够,请在最后增加 --user 选项 pip install -r requirements.txt @@ -64,12 +64,12 @@ python -u train_text_to_image_decoder.py \ --output_dir="kandi2-decoder-pokemon-model" ``` -训练完成后,模型将保存在命令中指定的 `output_dir` 目录中。在本例中是 `kandi22-decoder-pokemon-model`。要加载微调后的模型进行推理,后将该路径传递给 `AutoPipelineForText2Image` : +训练完成后,模型将保存在命令中指定的 `output_dir` 目录中。在本例中是 `kandi22-decoder-pokemon-model`。要加载微调后的模型进行推理,后将该路径传递给 `KandinskyV22CombinedPipeline` : ```python -from ppdiffusers import AutoPipelineForText2Image +from ppdiffusers import KandinskyV22CombinedPipeline -pipe = AutoPipelineForText2Image.from_pretrained(output_dir) +pipe = KandinskyV22CombinedPipeline.from_pretrained(output_dir) prompt='A robot pokemon, 4k photo' images = pipe(prompt=prompt).images @@ -79,13 +79,13 @@ images[0].save("robot-pokemon.png") Checkpoints只保存 unet,因此要从checkpoints运行推理只需加载 unet: ```python -from ppdiffusers import AutoPipelineForText2Image, UNet2DConditionModel +from ppdiffusers import KandinskyV22CombinedPipeline, UNet2DConditionModel model_path = "path_to_saved_model" unet = UNet2DConditionModel.from_pretrained(model_path + "/checkpoint-/unet") -pipe = AutoPipelineForText2Image.from_pretrained("kandinsky-community/kandinsky-2-2-decoder", unet=unet) +pipe = KandinskyV22CombinedPipeline.from_pretrained("kandinsky-community/kandinsky-2-2-decoder", unet=unet) images = pipe(prompt="A robot pokemon, 4k photo").images images[0].save("robot-pokemon.png") @@ -117,11 +117,11 @@ python -u train_text_to_image_prior.py \ 要使用微调prior模型进行推理,首先需要将 `output_dir`传给 `DiffusionPipeline`从而创建一个prior pipeline。然后,从一个预训练或微调decoder以及刚刚创建的prior pipeline的所有模块中创建一个`KandinskyV22CombinedPipeline`: ```python -from ppdiffusers import AutoPipelineForText2Image, DiffusionPipeline +from ppdiffusers import KandinskyV22CombinedPipeline, DiffusionPipeline pipe_prior = DiffusionPipeline.from_pretrained(output_dir) prior_components = {"prior_" + k: v for k,v in pipe_prior.components.items()} -pipe = AutoPipelineForText2Image.from_pretrained("kandinsky-community/kandinsky-2-2-decoder", **prior_components) +pipe = KandinskyV22CombinedPipeline.from_pretrained("kandinsky-community/kandinsky-2-2-decoder", **prior_components) prompt='A robot pokemon, 4k photo' images = pipe(prompt=prompt, negative_prompt=negative_prompt).images @@ -218,12 +218,12 @@ python -u train_text_to_image_prior_lora.py \ #### 2.1 LoRA微调decoder推理流程 -使用上述命令训练好Kandinsky decoder模型后,就可以在加载训练好 LoRA 权重后使用 `AutoPipelineForText2Image` 进行推理。您需要传递用于加载 LoRA 权重的 `output_dir` 目录,在本例中是 `kandi22-decoder-pokemon-lora`。 +使用上述命令训练好Kandinsky decoder模型后,就可以在加载训练好 LoRA 权重后使用 `KandinskyV22CombinedPipeline` 进行推理。您需要传递用于加载 LoRA 权重的 `output_dir` 目录,在本例中是 `kandi22-decoder-pokemon-lora`。 ```python -from ppdiffusers import AutoPipelineForText2Image +from ppdiffusers import KandinskyV22CombinedPipeline -pipe = AutoPipelineForText2Image.from_pretrained("kandinsky-community/kandinsky-2-2-decoder") +pipe = KandinskyV22CombinedPipeline.from_pretrained("kandinsky-community/kandinsky-2-2-decoder") pipe.unet.load_attn_procs(output_dir) prompt='A robot pokemon, 4k photo' @@ -234,9 +234,9 @@ image.save("robot_pokemon.png") #### 2.2 LoRA微调prior推理流程 ```python -from ppdiffusers import AutoPipelineForText2Image +from ppdiffusers import KandinskyV22CombinedPipeline -pipe = AutoPipelineForText2Image.from_pretrained("kandinsky-community/kandinsky-2-2-decoder") +pipe = KandinskyV22CombinedPipeline.from_pretrained("kandinsky-community/kandinsky-2-2-decoder") pipe.prior_prior.load_attn_procs(output_dir) prompt='A robot pokemon, 4k photo' diff --git 
a/ppdiffusers/ppdiffusers/models/prior_transformer.py b/ppdiffusers/ppdiffusers/models/prior_transformer.py index d8f0b97d6..8515f17bb 100644 --- a/ppdiffusers/ppdiffusers/models/prior_transformer.py +++ b/ppdiffusers/ppdiffusers/models/prior_transformer.py @@ -26,7 +26,7 @@ from .attention_processor import AttentionProcessor, AttnProcessor from .embeddings import TimestepEmbedding, Timesteps from .modeling_utils import ModelMixin - +from ..loaders import UNet2DConditionLoadersMixin @dataclass class PriorTransformerOutput(BaseOutput): @@ -41,7 +41,7 @@ class PriorTransformerOutput(BaseOutput): predicted_image_embedding: paddle.Tensor -class PriorTransformer(ModelMixin, ConfigMixin): +class PriorTransformer(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): """ A Prior Transformer model.