From da17ac0f484b28a8471004b47bddfc408969ae04 Mon Sep 17 00:00:00 2001 From: takuoko Date: Fri, 1 Dec 2023 00:58:42 +0900 Subject: [PATCH] [Feature] Support OFT (#1160) * Support OFT * add test * Update README * fix code quality * fix test * Skip 1 test * fix eps rule and add more test * feat: added examples to new OFT method * fix: removed wrong arguments from model example * fix: changed name of inference file * fix: changed prompt variable * fix docs * fix: dreambooth inference revision based on feedback * fix: review from BenjaminBossan * apply safe merge * del partially * refactor oft * refactor oft * del unused line * del unused line * fix skip in windows * skip test * Add comments about bias added place * rename orig_weights to new_weights * use inverse instead of linalg.inv * delete alpha and scaling --------- Co-authored-by: Lukas Kuhn Co-authored-by: Lukas Kuhn --- README.md | 7 +- .../oft_dreambooth_inference.ipynb | 89 ++ examples/oft_dreambooth/train_dreambooth.py | 1112 +++++++++++++++++ src/peft/__init__.py | 2 + src/peft/mapping.py | 4 + src/peft/peft_model.py | 2 + src/peft/tuners/__init__.py | 1 + src/peft/tuners/oft/__init__.py | 21 + src/peft/tuners/oft/config.py | 109 ++ src/peft/tuners/oft/layer.py | 375 ++++++ src/peft/tuners/oft/model.py | 108 ++ src/peft/utils/peft_types.py | 1 + src/peft/utils/save_and_load.py | 5 +- tests/test_config.py | 4 +- tests/test_custom_models.py | 103 +- tests/test_stablediffusion.py | 23 +- tests/testing_common.py | 6 +- 17 files changed, 1959 insertions(+), 13 deletions(-) create mode 100644 examples/oft_dreambooth/oft_dreambooth_inference.ipynb create mode 100644 examples/oft_dreambooth/train_dreambooth.py create mode 100644 src/peft/tuners/oft/__init__.py create mode 100644 src/peft/tuners/oft/config.py create mode 100644 src/peft/tuners/oft/layer.py create mode 100644 src/peft/tuners/oft/model.py diff --git a/README.md b/README.md index 79259f98ee..09846dc61c 100644 --- a/README.md +++ b/README.md @@ -35,6 +35,7 @@ Supported methods: 8. LoHa: [FedPara: Low-Rank Hadamard Product for Communication-Efficient Federated Learning](https://arxiv.org/abs/2108.06098) 9. LoKr: [KronA: Parameter Efficient Tuning with Kronecker Adapter](https://arxiv.org/abs/2212.10650) based on [Navigating Text-To-Image Customization:From LyCORIS Fine-Tuning to Model Evaluation](https://arxiv.org/abs/2309.14859) implementation 10. LoftQ: [LoftQ: LoRA-Fine-Tuning-aware Quantization for Large Language Models](https://arxiv.org/abs/2310.08659) +11. OFT: [Controlling Text-to-Image Diffusion by Orthogonal Finetuning](https://arxiv.org/abs/2306.07280) ## Getting started @@ -278,9 +279,9 @@ Find models that are supported out of the box below. Note that PEFT works with a ### Text-to-Image Generation -| Model | LoRA | LoHa | LoKr | Prefix Tuning | P-Tuning | Prompt Tuning | IA3 | -| --------- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | -| Stable Diffusion | ✅ | ✅ | ✅ | | | | +| Model | LoRA | LoHa | LoKr | OFT | Prefix Tuning | P-Tuning | Prompt Tuning | IA3 | +| --------- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | +| Stable Diffusion | ✅ | ✅ | ✅ | ✅ | | | | ### Image Classification diff --git a/examples/oft_dreambooth/oft_dreambooth_inference.ipynb b/examples/oft_dreambooth/oft_dreambooth_inference.ipynb new file mode 100644 index 0000000000..4a28c4040e --- /dev/null +++ b/examples/oft_dreambooth/oft_dreambooth_inference.ipynb @@ -0,0 +1,89 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "acd7b15e", + "metadata": {}, + "source": [ + "# Dreambooth with OFT\n", + "This Notebook assumes that you already ran the train_dreambooth.py script to create your own adapter." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "acab479f", + "metadata": {}, + "outputs": [], + "source": [ + "from diffusers import DiffusionPipeline\n", + "from diffusers.utils import check_min_version, get_logger\n", + "from peft import PeftModel\n", + "\n", + "# Will error if the minimal version of diffusers is not installed. Remove at your own risks.\n", + "check_min_version(\"0.10.0.dev0\")\n", + "\n", + "logger = get_logger(__name__)\n", + "\n", + "BASE_MODEL_NAME = \"stabilityai/stable-diffusion-2-1-base\"\n", + "ADAPTER_MODEL_PATH = \"INSERT MODEL PATH HERE\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "pipe = DiffusionPipeline.from_pretrained(\n", + " BASE_MODEL_NAME,\n", + ")\n", + "pipe.to('cuda')\n", + "pipe.unet = PeftModel.from_pretrained(pipe.unet, ADAPTER_MODEL_PATH + \"/unet\", adapter_name=\"default\")\n", + "pipe.text_encoder = PeftModel.from_pretrained(pipe.text_encoder, ADAPTER_MODEL_PATH + \"/text_encoder\", adapter_name=\"default\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "prompt = \"A photo of a sks dog\"\n", + "image = pipe(\n", + " prompt,\n", + " num_inference_steps=50,\n", + " height=512,\n", + " width=512,\n", + ").images[0]\n", + "image" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.11" + }, + "vscode": { + "interpreter": { + "hash": "aee8b7b246df8f9039afb4144a1f6fd8d2ca17a180786b69acc140d282b71a49" + } + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/oft_dreambooth/train_dreambooth.py b/examples/oft_dreambooth/train_dreambooth.py new file mode 100644 index 0000000000..cacce70647 --- /dev/null +++ b/examples/oft_dreambooth/train_dreambooth.py @@ -0,0 +1,1112 @@ +import argparse +import gc +import hashlib +import itertools +import logging +import math +import os +import threading +import warnings +from contextlib import nullcontext +from pathlib import Path +from typing import Optional + +import datasets +import diffusers +import numpy as np +import psutil +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +import transformers +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import set_seed +from diffusers import ( + AutoencoderKL, + DDPMScheduler, + DiffusionPipeline, + DPMSolverMultistepScheduler, + UNet2DConditionModel, +) +from diffusers.optimization import get_scheduler +from diffusers.utils import check_min_version +from diffusers.utils.import_utils import is_xformers_available +from huggingface_hub import HfFolder, Repository, whoami +from PIL import Image +from torch.utils.data import Dataset +from torchvision import transforms +from tqdm.auto import tqdm +from transformers import AutoTokenizer, PretrainedConfig + +from peft import get_peft_model +from peft.tuners.oft.config import OFTConfig + + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. +check_min_version("0.10.0.dev0") + +logger = get_logger(__name__) + +UNET_TARGET_MODULES = ["to_q", "to_v", "query", "value"] # , "ff.net.0.proj"] +TEXT_ENCODER_TARGET_MODULES = ["q_proj", "v_proj"] + + +def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str): + text_encoder_config = PretrainedConfig.from_pretrained( + pretrained_model_name_or_path, + subfolder="text_encoder", + revision=revision, + ) + model_class = text_encoder_config.architectures[0] + + if model_class == "CLIPTextModel": + from transformers import CLIPTextModel + + return CLIPTextModel + elif model_class == "RobertaSeriesModelWithTransformation": + from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation + + return RobertaSeriesModelWithTransformation + else: + raise ValueError(f"{model_class} is not supported.") + + +def parse_args(input_args=None): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--tokenizer_name", + type=str, + default=None, + help="Pretrained tokenizer name or path if not the same as model_name", + ) + parser.add_argument( + "--instance_data_dir", + type=str, + default=None, + required=True, + help="A folder containing the training data of instance images.", + ) + parser.add_argument( + "--class_data_dir", + type=str, + default=None, + required=False, + help="A folder containing the training data of class images.", + ) + parser.add_argument( + "--instance_prompt", + type=str, + default=None, + required=True, + help="The prompt with identifier specifying the instance", + ) + parser.add_argument( + "--class_prompt", + type=str, + default=None, + help="The prompt to specify images in the same class as provided instance images.", + ) + parser.add_argument( + "--with_prior_preservation", + default=False, + action="store_true", + help="Flag to add prior preservation loss.", + ) + parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.") + parser.add_argument( + "--num_class_images", + type=int, + default=100, + help=( + "Minimal class images for prior preservation loss. If there are not enough images already present in" + " class_data_dir, additional images will be sampled with class_prompt." + ), + ) + parser.add_argument( + "--validation_prompt", + type=str, + default=None, + help="A prompt that is used during validation to verify that the model is learning.", + ) + parser.add_argument( + "--num_validation_images", + type=int, + default=4, + help="Number of images that should be generated during validation with `validation_prompt`.", + ) + parser.add_argument( + "--validation_steps", + type=int, + default=100, + help=( + "Run dreambooth validation every X steps. Dreambooth validation consists of running the prompt" + " `args.validation_prompt` multiple times: `args.num_validation_images`." + ), + ) + parser.add_argument( + "--output_dir", + type=str, + default="text-inversion-model", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--resolution", + type=int, + default=512, + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution" + ) + parser.add_argument("--train_text_encoder", action="store_true", help="Whether to train the text encoder") + + # oft args + parser.add_argument("--use_oft", action="store_true", help="Whether to use OFT for parameter efficient tuning") + parser.add_argument("--oft_r", type=int, default=8, help="OFT rank, only used if use_oft is True") + parser.add_argument("--oft_alpha", type=int, default=32, help="OFT alpha, only used if use_oft is True") + parser.add_argument("--oft_dropout", type=float, default=0.0, help="OFT dropout, only used if use_oft is True") + parser.add_argument( + "--oft_use_coft", action="store_true", help="Using constrained OFT, only used if use_oft is True" + ) + parser.add_argument( + "--oft_eps", + type=float, + default=0.0, + help="The control strength of COFT. Only has an effect if `oft_use_coft` is set to True.", + ) + + parser.add_argument( + "--oft_text_encoder_r", + type=int, + default=8, + help="OFT rank for text encoder, only used if `use_oft` and `train_text_encoder` are True", + ) + parser.add_argument( + "--oft_text_encoder_alpha", + type=int, + default=32, + help="OFT alpha for text encoder, only used if `use_oft` and `train_text_encoder` are True", + ) + parser.add_argument( + "--oft_text_encoder_dropout", + type=float, + default=0.0, + help="OFT dropout for text encoder, only used if `use_oft` and `train_text_encoder` are True", + ) + parser.add_argument( + "--oft_text_encoder_use_coft", + action="store_true", + help="Using constrained OFT on the text encoder, only used if use_oft is True", + ) + parser.add_argument( + "--oft_text_encoder_eps", + type=float, + default=0.0, + help="The control strength of COFT on the text encoder. Only has an effect if `oft_text_encoder_use_coft` is set to True.", + ) + + parser.add_argument( + "--num_dataloader_workers", type=int, default=1, help="Num of workers for the training dataloader." + ) + + parser.add_argument( + "--no_tracemalloc", + default=False, + action="store_true", + help="Flag to stop memory allocation tracing during training. This could speed up training on Windows.", + ) + + parser.add_argument( + "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument( + "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images." + ) + parser.add_argument("--num_train_epochs", type=int, default=1) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final" + " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming" + " training using `--resume_from_checkpoint`." + ), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=5e-6, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--lr_num_cycles", + type=int, + default=1, + help="Number of hard resets of the lr in cosine_with_restarts scheduler.", + ) + parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.") + parser.add_argument( + "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." + ) + parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") + parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") + parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") + parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + parser.add_argument( + "--wandb_key", + type=str, + default=None, + help=("If report to option is set to wandb, api-key for wandb used for login to wandb "), + ) + parser.add_argument( + "--wandb_project_name", + type=str, + default=None, + help=("If report to option is set to wandb, project name in wandb for log tracking "), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument( + "--prior_generation_precision", + type=str, + default=None, + choices=["no", "fp32", "fp16", "bf16"], + help=( + "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32." + ), + ) + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + parser.add_argument( + "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." + ) + + if input_args is not None: + args = parser.parse_args(input_args) + else: + args = parser.parse_args() + + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + + if args.with_prior_preservation: + if args.class_data_dir is None: + raise ValueError("You must specify a data directory for class images.") + if args.class_prompt is None: + raise ValueError("You must specify prompt for class images.") + else: + # logger is not available yet + if args.class_data_dir is not None: + warnings.warn("You need not use --class_data_dir without --with_prior_preservation.") + if args.class_prompt is not None: + warnings.warn("You need not use --class_prompt without --with_prior_preservation.") + + return args + + +# Converting Bytes to Megabytes +def b2mb(x): + return int(x / 2**20) + + +# This context manager is used to track the peak memory usage of the process +class TorchTracemalloc: + def __enter__(self): + gc.collect() + torch.cuda.empty_cache() + torch.cuda.reset_max_memory_allocated() # reset the peak gauge to zero + self.begin = torch.cuda.memory_allocated() + self.process = psutil.Process() + + self.cpu_begin = self.cpu_mem_used() + self.peak_monitoring = True + peak_monitor_thread = threading.Thread(target=self.peak_monitor_func) + peak_monitor_thread.daemon = True + peak_monitor_thread.start() + return self + + def cpu_mem_used(self): + """get resident set size memory for the current process""" + return self.process.memory_info().rss + + def peak_monitor_func(self): + self.cpu_peak = -1 + + while True: + self.cpu_peak = max(self.cpu_mem_used(), self.cpu_peak) + + # can't sleep or will not catch the peak right (this comment is here on purpose) + # time.sleep(0.001) # 1msec + + if not self.peak_monitoring: + break + + def __exit__(self, *exc): + self.peak_monitoring = False + + gc.collect() + torch.cuda.empty_cache() + self.end = torch.cuda.memory_allocated() + self.peak = torch.cuda.max_memory_allocated() + self.used = b2mb(self.end - self.begin) + self.peaked = b2mb(self.peak - self.begin) + + self.cpu_end = self.cpu_mem_used() + self.cpu_used = b2mb(self.cpu_end - self.cpu_begin) + self.cpu_peaked = b2mb(self.cpu_peak - self.cpu_begin) + # print(f"delta used/peak {self.used:4d}/{self.peaked:4d}") + + +class DreamBoothDataset(Dataset): + """ + A dataset to prepare the instance and class images with the prompts for fine-tuning the model. + It pre-processes the images and the tokenizes prompts. + """ + + def __init__( + self, + instance_data_root, + instance_prompt, + tokenizer, + class_data_root=None, + class_prompt=None, + size=512, + center_crop=False, + ): + self.size = size + self.center_crop = center_crop + self.tokenizer = tokenizer + + self.instance_data_root = Path(instance_data_root) + if not self.instance_data_root.exists(): + raise ValueError("Instance images root doesn't exists.") + + self.instance_images_path = list(Path(instance_data_root).iterdir()) + self.num_instance_images = len(self.instance_images_path) + self.instance_prompt = instance_prompt + self._length = self.num_instance_images + + if class_data_root is not None: + self.class_data_root = Path(class_data_root) + self.class_data_root.mkdir(parents=True, exist_ok=True) + self.class_images_path = list(self.class_data_root.iterdir()) + self.num_class_images = len(self.class_images_path) + self._length = max(self.num_class_images, self.num_instance_images) + self.class_prompt = class_prompt + else: + self.class_data_root = None + + self.image_transforms = transforms.Compose( + [ + transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + + def __len__(self): + return self._length + + def __getitem__(self, index): + example = {} + instance_image = Image.open(self.instance_images_path[index % self.num_instance_images]) + if not instance_image.mode == "RGB": + instance_image = instance_image.convert("RGB") + example["instance_images"] = self.image_transforms(instance_image) + example["instance_prompt_ids"] = self.tokenizer( + self.instance_prompt, + truncation=True, + padding="max_length", + max_length=self.tokenizer.model_max_length, + return_tensors="pt", + ).input_ids + + if self.class_data_root: + class_image = Image.open(self.class_images_path[index % self.num_class_images]) + if not class_image.mode == "RGB": + class_image = class_image.convert("RGB") + example["class_images"] = self.image_transforms(class_image) + example["class_prompt_ids"] = self.tokenizer( + self.class_prompt, + truncation=True, + padding="max_length", + max_length=self.tokenizer.model_max_length, + return_tensors="pt", + ).input_ids + + return example + + +def collate_fn(examples, with_prior_preservation=False): + input_ids = [example["instance_prompt_ids"] for example in examples] + pixel_values = [example["instance_images"] for example in examples] + + # Concat class and instance examples for prior preservation. + # We do this to avoid doing two forward passes. + if with_prior_preservation: + input_ids += [example["class_prompt_ids"] for example in examples] + pixel_values += [example["class_images"] for example in examples] + + pixel_values = torch.stack(pixel_values) + pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() + + input_ids = torch.cat(input_ids, dim=0) + + batch = { + "input_ids": input_ids, + "pixel_values": pixel_values, + } + return batch + + +class PromptDataset(Dataset): + "A simple dataset to prepare the prompts to generate class images on multiple GPUs." + + def __init__(self, prompt, num_samples): + self.prompt = prompt + self.num_samples = num_samples + + def __len__(self): + return self.num_samples + + def __getitem__(self, index): + example = {} + example["prompt"] = self.prompt + example["index"] = index + return example + + +def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None): + if token is None: + token = HfFolder.get_token() + if organization is None: + username = whoami(token)["name"] + return f"{username}/{model_id}" + else: + return f"{organization}/{model_id}" + + +def main(args): + logging_dir = Path(args.output_dir, args.logging_dir) + + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_dir=logging_dir, + ) + if args.report_to == "wandb": + import wandb + + wandb.login(key=args.wandb_key) + wandb.init(project=args.wandb_project_name) + # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate + # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models. + # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate. + if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1: + raise ValueError( + "Gradient accumulation is not supported when training the text encoder in distributed training. " + "Please set gradient_accumulation_steps to 1. This feature will be supported in the future." + ) + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + datasets.utils.logging.set_verbosity_warning() + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + datasets.utils.logging.set_verbosity_error() + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Generate class images if prior preservation is enabled. + if args.with_prior_preservation: + class_images_dir = Path(args.class_data_dir) + if not class_images_dir.exists(): + class_images_dir.mkdir(parents=True) + cur_class_images = len(list(class_images_dir.iterdir())) + + if cur_class_images < args.num_class_images: + torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32 + if args.prior_generation_precision == "fp32": + torch_dtype = torch.float32 + elif args.prior_generation_precision == "fp16": + torch_dtype = torch.float16 + elif args.prior_generation_precision == "bf16": + torch_dtype = torch.bfloat16 + pipeline = DiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + torch_dtype=torch_dtype, + safety_checker=None, + revision=args.revision, + ) + pipeline.set_progress_bar_config(disable=True) + + num_new_images = args.num_class_images - cur_class_images + logger.info(f"Number of class images to sample: {num_new_images}.") + + sample_dataset = PromptDataset(args.class_prompt, num_new_images) + sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size) + + sample_dataloader = accelerator.prepare(sample_dataloader) + pipeline.to(accelerator.device) + + for example in tqdm( + sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process + ): + images = pipeline(example["prompt"]).images + + for i, image in enumerate(images): + hash_image = hashlib.sha1(image.tobytes()).hexdigest() + image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg" + image.save(image_filename) + + del pipeline + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + # Handle the repository creation + if accelerator.is_main_process: + if args.push_to_hub: + if args.hub_model_id is None: + repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) + else: + repo_name = args.hub_model_id + repo = Repository(args.output_dir, clone_from=repo_name) # noqa: F841 + + with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: + if "step_*" not in gitignore: + gitignore.write("step_*\n") + if "epoch_*" not in gitignore: + gitignore.write("epoch_*\n") + elif args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + # Load the tokenizer + if args.tokenizer_name: + tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False) + elif args.pretrained_model_name_or_path: + tokenizer = AutoTokenizer.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="tokenizer", + revision=args.revision, + use_fast=False, + ) + + # import correct text encoder class + text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision) + + # Load scheduler and models + noise_scheduler = DDPMScheduler( + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + num_train_timesteps=1000, + ) # DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") + text_encoder = text_encoder_cls.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision + ) + vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision) + unet = UNet2DConditionModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision + ) + + if args.use_oft: + config = OFTConfig( + r=args.oft_r, + alpha=args.oft_alpha, + target_modules=UNET_TARGET_MODULES, + module_dropout=args.oft_dropout, + init_weights=True, + coft=args.oft_use_coft, + eps=args.oft_eps, + ) + unet = get_peft_model(unet, config) + unet.print_trainable_parameters() + print(unet) + + vae.requires_grad_(False) + if not args.train_text_encoder: + text_encoder.requires_grad_(False) + elif args.train_text_encoder and args.use_oft: + config = OFTConfig( + r=args.oft_text_encoder_r, + alpha=args.oft_text_encoder_alpha, + target_modules=TEXT_ENCODER_TARGET_MODULES, + module_dropout=args.oft_text_encoder_dropout, + init_weights=True, + coft=args.oft_text_encoder_use_coft, + eps=args.oft_text_encoder_eps, + ) + text_encoder = get_peft_model(text_encoder, config) + text_encoder.print_trainable_parameters() + print(text_encoder) + + if args.enable_xformers_memory_efficient_attention: + if is_xformers_available(): + unet.enable_xformers_memory_efficient_attention() + else: + raise ValueError("xformers is not available. Make sure it is installed correctly") + + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + # below fails when using oft so commenting it out + if args.train_text_encoder and not args.use_oft: + text_encoder.gradient_checkpointing_enable() + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if args.allow_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) + + # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." + ) + + optimizer_class = bnb.optim.AdamW8bit + else: + optimizer_class = torch.optim.AdamW + + # Optimizer creation + params_to_optimize = ( + itertools.chain(unet.parameters(), text_encoder.parameters()) if args.train_text_encoder else unet.parameters() + ) + optimizer = optimizer_class( + params_to_optimize, + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + # Dataset and DataLoaders creation: + train_dataset = DreamBoothDataset( + instance_data_root=args.instance_data_dir, + instance_prompt=args.instance_prompt, + class_data_root=args.class_data_dir if args.with_prior_preservation else None, + class_prompt=args.class_prompt, + tokenizer=tokenizer, + size=args.resolution, + center_crop=args.center_crop, + ) + + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + batch_size=args.train_batch_size, + shuffle=True, + collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation), + num_workers=args.num_dataloader_workers, + ) + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, + num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, + num_cycles=args.lr_num_cycles, + power=args.lr_power, + ) + + # Prepare everything with our `accelerator`. + if args.train_text_encoder: + unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, text_encoder, optimizer, train_dataloader, lr_scheduler + ) + else: + unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, optimizer, train_dataloader, lr_scheduler + ) + + # For mixed precision training we cast the text_encoder and vae weights to half-precision + # as these models are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + # Move vae and text_encoder to device and cast to weight_dtype + vae.to(accelerator.device, dtype=weight_dtype) + if not args.train_text_encoder: + text_encoder.to(accelerator.device, dtype=weight_dtype) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + accelerator.init_trackers("dreambooth", config=vars(args)) + + # Train! + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num batches each epoch = {len(train_dataloader)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the mos recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + + resume_global_step = global_step * args.gradient_accumulation_steps + first_epoch = resume_global_step // num_update_steps_per_epoch + resume_step = resume_global_step % num_update_steps_per_epoch + + # Only show the progress bar once on each machine. + progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process) + progress_bar.set_description("Steps") + + for epoch in range(first_epoch, args.num_train_epochs): + unet.train() + if args.train_text_encoder: + text_encoder.train() + with TorchTracemalloc() if not args.no_tracemalloc else nullcontext() as tracemalloc: + for step, batch in enumerate(train_dataloader): + # Skip steps until we reach the resumed step + if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step: + if step % args.gradient_accumulation_steps == 0: + progress_bar.update(1) + if args.report_to == "wandb": + accelerator.print(progress_bar) + continue + + with accelerator.accumulate(unet): + # Convert images to latent space + latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample() + latents = latents * 0.18215 + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + bsz = latents.shape[0] + # Sample a random timestep for each image + timesteps = torch.randint( + 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device + ) + timesteps = timesteps.long() + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + # Get the text embedding for conditioning + encoder_hidden_states = text_encoder(batch["input_ids"])[0] + + # Predict the noise residual + model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + + # Get the target for loss depending on the prediction type + if noise_scheduler.config.prediction_type == "epsilon": + target = noise + elif noise_scheduler.config.prediction_type == "v_prediction": + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") + + if args.with_prior_preservation: + # Chunk the noise and model_pred into two parts and compute the loss on each part separately. + model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) + target, target_prior = torch.chunk(target, 2, dim=0) + + # Compute instance loss + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + + # Compute prior loss + prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") + + # Add the prior loss to the instance loss. + loss = loss + args.prior_loss_weight * prior_loss + else: + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + + accelerator.backward(loss) + if accelerator.sync_gradients: + params_to_clip = ( + itertools.chain(unet.parameters(), text_encoder.parameters()) + if args.train_text_encoder + else unet.parameters() + ) + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + if args.report_to == "wandb": + accelerator.print(progress_bar) + global_step += 1 + + logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + accelerator.log(logs, step=global_step) + + if ( + args.validation_prompt is not None + and (step + num_update_steps_per_epoch * epoch) % args.validation_steps == 0 + ): + logger.info( + f"Running validation... \n Generating {args.num_validation_images} images with prompt:" + f" {args.validation_prompt}." + ) + # create pipeline + pipeline = DiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + safety_checker=None, + revision=args.revision, + ) + # set `keep_fp32_wrapper` to True because we do not want to remove + # mixed precision hooks while we are still training + pipeline.unet = accelerator.unwrap_model(unet, keep_fp32_wrapper=True) + pipeline.text_encoder = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True) + pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config) + pipeline = pipeline.to(accelerator.device) + pipeline.set_progress_bar_config(disable=True) + + # run inference + if args.seed is not None: + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) + else: + generator = None + images = [] + for _ in range(args.num_validation_images): + image = pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0] + images.append(image) + + for tracker in accelerator.trackers: + if tracker.name == "tensorboard": + np_images = np.stack([np.asarray(img) for img in images]) + tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC") + if tracker.name == "wandb": + import wandb + + tracker.log( + { + "validation": [ + wandb.Image(image, caption=f"{i}: {args.validation_prompt}") + for i, image in enumerate(images) + ] + } + ) + + del pipeline + torch.cuda.empty_cache() + + if global_step >= args.max_train_steps: + break + # Printing the GPU memory usage details such as allocated memory, peak memory, and total memory usage + + if not args.no_tracemalloc: + accelerator.print("GPU Memory before entering the train : {}".format(b2mb(tracemalloc.begin))) + accelerator.print("GPU Memory consumed at the end of the train (end-begin): {}".format(tracemalloc.used)) + accelerator.print("GPU Peak Memory consumed during the train (max-begin): {}".format(tracemalloc.peaked)) + accelerator.print( + "GPU Total Peak Memory consumed during the train (max): {}".format( + tracemalloc.peaked + b2mb(tracemalloc.begin) + ) + ) + + accelerator.print("CPU Memory before entering the train : {}".format(b2mb(tracemalloc.cpu_begin))) + accelerator.print( + "CPU Memory consumed at the end of the train (end-begin): {}".format(tracemalloc.cpu_used) + ) + accelerator.print( + "CPU Peak Memory consumed during the train (max-begin): {}".format(tracemalloc.cpu_peaked) + ) + accelerator.print( + "CPU Total Peak Memory consumed during the train (max): {}".format( + tracemalloc.cpu_peaked + b2mb(tracemalloc.cpu_begin) + ) + ) + + # Create the pipeline using using the trained modules and save it. + accelerator.wait_for_everyone() + if accelerator.is_main_process: + if args.use_oft: + unwarpped_unet = accelerator.unwrap_model(unet) + unwarpped_unet.save_pretrained( + os.path.join(args.output_dir, "unet"), state_dict=accelerator.get_state_dict(unet) + ) + if args.train_text_encoder: + unwarpped_text_encoder = accelerator.unwrap_model(text_encoder) + unwarpped_text_encoder.save_pretrained( + os.path.join(args.output_dir, "text_encoder"), + state_dict=accelerator.get_state_dict(text_encoder), + ) + else: + pipeline = DiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + unet=accelerator.unwrap_model(unet), + text_encoder=accelerator.unwrap_model(text_encoder), + revision=args.revision, + ) + pipeline.save_pretrained(args.output_dir) + + if args.push_to_hub: + repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True) + + accelerator.end_training() + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/src/peft/__init__.py b/src/peft/__init__.py index 4d9380e697..75ddda498c 100644 --- a/src/peft/__init__.py +++ b/src/peft/__init__.py @@ -68,6 +68,8 @@ PromptTuningInit, MultitaskPromptTuningConfig, MultitaskPromptTuningInit, + OFTConfig, + OFTModel, ) from .utils import ( TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING, diff --git a/src/peft/mapping.py b/src/peft/mapping.py index f69e89ec3e..60503fa985 100644 --- a/src/peft/mapping.py +++ b/src/peft/mapping.py @@ -42,6 +42,8 @@ LoraConfig, LoraModel, MultitaskPromptTuningConfig, + OFTConfig, + OFTModel, PrefixTuningConfig, PromptEncoderConfig, PromptTuningConfig, @@ -73,6 +75,7 @@ "ADALORA": AdaLoraConfig, "IA3": IA3Config, "MULTITASK_PROMPT_TUNING": MultitaskPromptTuningConfig, + "OFT": OFTConfig, } PEFT_TYPE_TO_TUNER_MAPPING = { @@ -81,6 +84,7 @@ "LOKR": LoKrModel, "ADALORA": AdaLoraModel, "IA3": IA3Model, + "OFT": OFTModel, } diff --git a/src/peft/peft_model.py b/src/peft/peft_model.py index 24ef48c22e..79bf8e4610 100644 --- a/src/peft/peft_model.py +++ b/src/peft/peft_model.py @@ -44,6 +44,7 @@ LoKrModel, LoraModel, MultitaskPromptEmbedding, + OFTModel, PrefixEncoder, PromptEmbedding, PromptEncoder, @@ -77,6 +78,7 @@ PeftType.ADALORA: AdaLoraModel, PeftType.ADAPTION_PROMPT: AdaptionPromptModel, PeftType.IA3: IA3Model, + PeftType.OFT: OFTModel, } diff --git a/src/peft/tuners/__init__.py b/src/peft/tuners/__init__.py index 666e29d997..f5f665dd99 100644 --- a/src/peft/tuners/__init__.py +++ b/src/peft/tuners/__init__.py @@ -27,3 +27,4 @@ from .prefix_tuning import PrefixEncoder, PrefixTuningConfig from .prompt_tuning import PromptEmbedding, PromptTuningConfig, PromptTuningInit from .multitask_prompt_tuning import MultitaskPromptEmbedding, MultitaskPromptTuningConfig, MultitaskPromptTuningInit +from .oft import OFTConfig, OFTModel diff --git a/src/peft/tuners/oft/__init__.py b/src/peft/tuners/oft/__init__.py new file mode 100644 index 0000000000..456c46ee07 --- /dev/null +++ b/src/peft/tuners/oft/__init__.py @@ -0,0 +1,21 @@ +# coding=utf-8 +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .config import OFTConfig +from .layer import Conv2d, Linear, OFTLayer +from .model import OFTModel + + +__all__ = ["OFTConfig", "OFTModel", "Conv2d", "Linear", "OFTLayer"] diff --git a/src/peft/tuners/oft/config.py b/src/peft/tuners/oft/config.py new file mode 100644 index 0000000000..6b43255d1d --- /dev/null +++ b/src/peft/tuners/oft/config.py @@ -0,0 +1,109 @@ +# coding=utf-8 +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field +from typing import List, Optional, Union + +from peft.tuners.lycoris_utils import LycorisConfig +from peft.utils import PeftType + + +@dataclass +class OFTConfig(LycorisConfig): + """ + This is the configuration class to store the configuration of a [`OFTModel`]. + + Args: + r (`int`): OFT rank. + module_dropout (`int`): The dropout probability for disabling OFT modules during training. + target_modules (`Union[List[str],str]`): The names of the modules to apply OFT to. + init_weights (`bool`): Whether to perform initialization of OFT weights. + layers_to_transform (`Union[List[int],int]`): + The layer indexes to transform, if this argument is specified, it will apply the OFT transformations on the + layer indexes that are specified in this list. If a single integer is passed, it will apply the OFT + transformations on the layer at this index. + layers_pattern (`str`): + The layer pattern name, used only if `layers_to_transform` is different from `None` and if the layer + pattern is not in the common layers pattern. + rank_pattern (`dict`): + The mapping from layer names or regexp expression to ranks which are different from the default rank + specified by `r`. + modules_to_save (`List[str]`): The names of modules to be set as trainable except OFT parameters. + coft (`bool`): Whether to use the constrainted variant of OFT or not. + eps (`float`): + The control strength of COFT. The freedom of rotation. Only has an effect if `coft` is set to True. + block_share (`bool`): Whether to share the OFT parameters between blocks or not. + """ + + r: int = field(default=8, metadata={"help": "OFT rank"}) + module_dropout: float = field( + default=0.0, metadata={"help": "The dropout probability for disabling OFT modules during training"} + ) + target_modules: Optional[Union[List[str], str]] = field( + default=None, + metadata={ + "help": "List of module names or regex expression of the module names to replace with OFT." + "For example, ['q', 'v'] or '.*decoder.*(SelfAttention|EncDecAttention).*(q|v)$' " + }, + ) + init_weights: bool = field( + default=True, + metadata={ + "help": ( + "Whether to initialize the weights of the OFT layers with their default initialization. Don't change " + "this setting, except if you know exactly what you're doing." + ), + }, + ) + layers_to_transform: Optional[Union[List[int], int]] = field( + default=None, + metadata={ + "help": "The layer indexes to transform, is this argument is specified, PEFT will transform only the layers indexes that are specified inside this list. If a single integer is passed, PEFT will transform only the layer at this index." + }, + ) + layers_pattern: Optional[str] = field( + default=None, + metadata={ + "help": "The layer pattern name, used only if `layers_to_transform` is different to None and if the layer pattern is not in the common layers pattern." + }, + ) + modules_to_save: Optional[List[str]] = field( + default=None, + metadata={ + "help": "List of modules apart from OFT layers to be set as trainable and saved in the final checkpoint. " + "For example, in Sequence Classification or Token Classification tasks, " + "the final layer `classifier/score` are randomly initialized and as such need to be trainable and saved." + }, + ) + coft: bool = field( + default=False, + metadata={"help": "Whether to use the constrainted variant of OFT or not."}, + ) + eps: float = field( + default=6e-5, + metadata={ + "help": "The control strength of COFT. The freedom of rotation. Only has an effect if `coft` is set to True." + }, + ) + block_share: bool = field( + default=False, + metadata={"help": "Whether to share the OFT parameters between blocks or not."}, + ) + + def __post_init__(self): + self.peft_type = PeftType.OFT + self.target_modules = ( + set(self.target_modules) if isinstance(self.target_modules, list) else self.target_modules + ) diff --git a/src/peft/tuners/oft/layer.py b/src/peft/tuners/oft/layer.py new file mode 100644 index 0000000000..b9e0d011b3 --- /dev/null +++ b/src/peft/tuners/oft/layer.py @@ -0,0 +1,375 @@ +# coding=utf-8 +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import warnings +from typing import Any, List, Optional, Set, Tuple + +import torch +import torch.nn as nn + +from peft.tuners.lycoris_utils import LycorisLayer + + +class OFTLayer(nn.Module, LycorisLayer): + # All names of layers that may contain adapter weights + adapter_layer_names = ("oft_r",) + # other_param_names is defined on parent class + + def __init__(self, base_layer: nn.Module): + super().__init__() + LycorisLayer.__init__(self, base_layer) + + # OFT info + self.oft_r = nn.ParameterDict({}) + self.coft = {} + self.eps = {} + self.block_share = {} + + @property + def _available_adapters(self) -> Set[str]: + return {*self.oft_r} + + def create_adapter_parameters(self, adapter_name: str, r: int, shape: Tuple[int, ...], block_share: bool): + if block_share: + self.oft_r[adapter_name] = nn.Parameter(torch.empty(1, math.ceil(shape[0] / r), math.ceil(shape[0] / r))) + else: + self.oft_r[adapter_name] = nn.Parameter(torch.empty(r, math.ceil(shape[0] / r), math.ceil(shape[0] / r))) + + def reset_adapter_parameters(self, adapter_name: str): + nn.init.zeros_(self.oft_r[adapter_name]) + + def reset_adapter_parameters_random(self, adapter_name: str): + nn.init.kaiming_uniform_(self.oft_r[adapter_name], a=math.sqrt(5)) + + def update_layer( + self, + adapter_name: str, + r: int, + module_dropout: float, + init_weights: bool, + coft: bool = False, + eps: float = 6e-5, + block_share: bool = False, + **kwargs, + ) -> None: + """Internal function to create oft adapter + + Args: + adapter_name (`str`): Name for the adapter to add. + r (`int`): Rank for the added adapter. + module_dropout (`float`): The dropout probability for disabling adapter during training. + init_weights (`bool`): Whether to initialize weights. + coft (`bool`): Whether to use the constrainted variant of OFT or not. + eps (`float`): + The control strength of COFT. The freedom of rotation. Only has an effect if `coft` is set to True. + block_share (`bool`): Whether to share the OFT parameters between blocks or not. + """ + + self.r[adapter_name] = r + self.module_dropout[adapter_name] = module_dropout + self.coft[adapter_name] = coft + self.block_share[adapter_name] = block_share + + # Determine shape of OFT weights + base_layer = self.get_base_layer() + if isinstance(base_layer, nn.Linear): + shape = tuple(base_layer.weight.shape) + elif isinstance(base_layer, nn.Conv2d): + shape = ( + base_layer.out_channels, + base_layer.in_channels * base_layer.kernel_size[0] * base_layer.kernel_size[1], + ) + else: + raise TypeError(f"OFT is not implemented for base layers of type {type(base_layer).__name__}") + + self.eps[adapter_name] = eps * math.ceil(shape[0] / r) * math.ceil(shape[0] / r) + + # Create weights with provided shape + self.create_adapter_parameters(adapter_name, r, shape, block_share) + + # Initialize weights + if init_weights: + self.reset_adapter_parameters(adapter_name) + else: + self.reset_adapter_parameters_random(adapter_name) + + # Move new weights to device + weight = getattr(self.get_base_layer(), "weight", None) + if weight is not None: + # the layer is already completely initialized, this is an update + if weight.dtype.is_floating_point or weight.dtype.is_complex: + self.to(weight.device, dtype=weight.dtype) + else: + self.to(weight.device) + self.set_adapter(self.active_adapters) + + def unscale_layer(self, scale=None) -> None: + # scale is not used + pass + + def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = None) -> None: + if self.merged: + warnings.warn( + f"Already following adapters were merged {','.join(self.merged_adapters)}. " + f"You are now additionally merging {','.join(self.active_adapters)}." + ) + if adapter_names is None: + adapter_names = self.active_adapters + + for active_adapter in adapter_names: + if active_adapter in self._available_adapters: + base_layer = self.get_base_layer() + + orig_weights = base_layer.weight.data + if isinstance(base_layer, nn.Linear): + orig_weights = torch.transpose(orig_weights, 0, 1) + elif isinstance(base_layer, nn.Conv2d): + orig_weights = orig_weights.view( + [ + base_layer.out_channels, + base_layer.in_channels * base_layer.kernel_size[0] * base_layer.kernel_size[1], + ] + ) + orig_weights = torch.transpose(orig_weights, 0, 1) + delta_weight = self.get_delta_weight(active_adapter) + if orig_weights.shape[1] != delta_weight.shape[1]: + # when in channels is not divisible by r + delta_weight = delta_weight[: orig_weights.shape[1], : orig_weights.shape[1]] + new_weights = torch.mm(orig_weights, delta_weight) + if isinstance(base_layer, nn.Linear): + new_weights = torch.transpose(new_weights, 0, 1) + elif isinstance(base_layer, nn.Conv2d): + new_weights = torch.transpose(new_weights, 0, 1) + new_weights = new_weights.view( + [ + base_layer.out_channels, + base_layer.in_channels, + base_layer.kernel_size[0], + base_layer.kernel_size[1], + ] + ) + + if safe_merge and not torch.isfinite(new_weights).all(): + raise ValueError( + f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken" + ) + + base_layer.weight.data = new_weights + self.merged_adapters.append(active_adapter) + + def unmerge(self) -> None: + if not self.merged: + warnings.warn("Already unmerged. Nothing to do.") + return + while len(self.merged_adapters) > 0: + active_adapter = self.merged_adapters.pop() + if active_adapter in self._available_adapters: + base_layer = self.get_base_layer() + new_weights = base_layer.weight.data + if isinstance(base_layer, nn.Linear): + new_weights = torch.transpose(new_weights, 0, 1) + elif isinstance(base_layer, nn.Conv2d): + new_weights = new_weights.view( + [ + base_layer.out_channels, + base_layer.in_channels * base_layer.kernel_size[0] * base_layer.kernel_size[1], + ] + ) + new_weights = torch.transpose(new_weights, 0, 1) + delta_weight = self.get_delta_weight(active_adapter) + if new_weights.shape[1] != delta_weight.shape[1]: + # when in channels is not divisible by r + delta_weight = delta_weight[: new_weights.shape[1], : new_weights.shape[1]] + delta_inv = torch.inverse(delta_weight) + orig_weights = torch.mm(new_weights, delta_inv) + + if isinstance(base_layer, nn.Linear): + orig_weights = torch.transpose(orig_weights, 0, 1) + elif isinstance(base_layer, nn.Conv2d): + orig_weights = torch.transpose(orig_weights, 0, 1) + orig_weights = orig_weights.reshape( + [ + base_layer.out_channels, + base_layer.in_channels, + base_layer.kernel_size[0], + base_layer.kernel_size[1], + ] + ) + base_layer.weight.data = orig_weights + + def get_delta_weight(self, adapter_name: str) -> torch.Tensor: + rank = self.r[adapter_name] + coft = self.coft[adapter_name] + eps = self.eps[adapter_name] + opt_r = self.oft_r[adapter_name] + + if coft: + with torch.no_grad(): + opt_r.copy_(self._project_batch(opt_r, eps=eps)) + + orth_rotate = self._cayley_batch(opt_r) + weight = self._block_diagonal(orth_rotate, rank) + + return weight + + # Copied from https://github.com/Zeju1997/oft/blob/84cebb965df69781e3d9c3c875f5980b421eaf24/oft-control/oft.py#L144 + def _cayley_batch(self, data: torch.Tensor) -> torch.Tensor: + b, r, c = data.shape + # Ensure the input matrix is skew-symmetric + skew = 0.5 * (data - data.transpose(1, 2)) + I = torch.eye(r, device=data.device).unsqueeze(0).expand(b, r, c) + + # Perform the Cayley parametrization + Q = torch.bmm(I - skew, torch.inverse(I + skew)) + + return Q + + # Copied from https://github.com/Zeju1997/oft/blob/84cebb965df69781e3d9c3c875f5980b421eaf24/oft-control/oft.py#L155 + def _block_diagonal(self, oft_r: torch.Tensor, rank: int) -> torch.Tensor: + if oft_r.shape[0] == 1: + # block share + blocks = [oft_r[0, ...] for i in range(rank)] + else: + blocks = [oft_r[i, ...] for i in range(rank)] + + # Use torch.block_diag to create the block diagonal matrix + A = torch.block_diag(*blocks) + + return A + + # Copied from https://github.com/Zeju1997/oft/blob/84cebb965df69781e3d9c3c875f5980b421eaf24/oft-control/oft.py#L52 + def _project_batch(self, oft_r, eps=1e-5): + # scaling factor for each of the smaller block matrix + eps = eps * 1 / torch.sqrt(torch.tensor(oft_r.shape[0])) + I = ( + torch.zeros((oft_r.size(1), oft_r.size(1)), device=oft_r.device, dtype=oft_r.dtype) + .unsqueeze(0) + .expand_as(oft_r) + ) + diff = oft_r - I + norm_diff = torch.norm(oft_r - I, dim=(1, 2), keepdim=True) + mask = (norm_diff <= eps).bool() + out = torch.where(mask, oft_r, I + eps * (diff / norm_diff)) + return out + + def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: + previous_dtype = x.dtype + + if self.disable_adapters: + if self.merged: + self.unmerge() + result = self.base_layer(x, *args, **kwargs) + elif self.merged: + result = self.base_layer(x, *args, **kwargs) + else: + result = self.base_layer(x, *args, **kwargs) + if len(result.shape) == 4: + result = result.permute(0, 2, 3, 1) + + base_layer = self.get_base_layer() + base_bias = base_layer.bias + if base_bias is not None: + # Bias should be added after OFT forward + result = result - base_bias.data + + # Execute all the adapters + for active_adapter in self.active_adapters: + if active_adapter not in self._available_adapters: + continue + + module_dropout = self.module_dropout[active_adapter] + + # Modify current execution weights + if (not self.training) or (self.training and torch.rand(1) > module_dropout): + result = self._get_delta_activations(active_adapter, result, *args, **kwargs) + + if base_bias is not None: + result = result + base_bias.data + if len(result.shape) == 4: + result = result.permute(0, 3, 1, 2) + + result = result.to(previous_dtype) + return result + + +class Linear(OFTLayer): + """OFT implemented in Linear layer""" + + def __init__( + self, + base_layer: nn.Module, + adapter_name: str = "default", + r: int = 0, + module_dropout: float = 0.0, + init_weights: bool = True, + **kwargs, + ): + super().__init__(base_layer) + + # Create adapter and set it active + self._active_adapter = adapter_name + self.update_layer(adapter_name, r, module_dropout, init_weights, **kwargs) + + def _get_delta_activations( + self, adapter_name: str, input: torch.Tensor, *args: Any, **kwargs: Any + ) -> torch.Tensor: + delta_weight = self.get_delta_weight(adapter_name) + + base_layer = self.get_base_layer() + base_weight = base_layer.weight.data + delta_weight = delta_weight[: base_weight.shape[0], : base_weight.shape[0]] + + # don't add bias here, because the bias will be added after OFT forward + return torch.matmul(input, delta_weight) + + def __repr__(self) -> str: + rep = super().__repr__() + return "oft." + rep + + +class Conv2d(OFTLayer): + """OFT implemented in Conv2d layer""" + + def __init__( + self, + base_layer: nn.Module, + adapter_name: str = "default", + r: int = 0, + module_dropout: float = 0.0, + init_weights: bool = True, + **kwargs, + ): + super().__init__(base_layer) + + # Create adapter and set it active + self._active_adapter = adapter_name + self.update_layer(adapter_name, r, module_dropout, init_weights, **kwargs) + + def _get_delta_activations( + self, adapter_name: str, input: torch.Tensor, *args: Any, **kwargs: Any + ) -> torch.Tensor: + delta_weight = self.get_delta_weight(adapter_name) + + base_layer = self.get_base_layer() + base_weight = base_layer.weight.data + delta_weight = delta_weight[: base_weight.shape[0], : base_weight.shape[0]] + + # don't add bias here, because the bias will be added after OFT forward + return torch.matmul(input, delta_weight) + + def __repr__(self) -> str: + rep = super().__repr__() + return "oft." + rep diff --git a/src/peft/tuners/oft/model.py b/src/peft/tuners/oft/model.py new file mode 100644 index 0000000000..4b7953daa9 --- /dev/null +++ b/src/peft/tuners/oft/model.py @@ -0,0 +1,108 @@ +# coding=utf-8 +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +from typing import Dict, Type, Union + +import torch +from torch import nn + +from peft.tuners.lycoris_utils import LycorisConfig, LycorisTuner + +from .layer import Conv2d, Linear, OFTLayer + + +class OFTModel(LycorisTuner): + """ + Creates Orthogonal Finetuning model from a pretrained model. The method is described in + https://arxiv.org/abs/2306.07280 + + Args: + model (`torch.nn.Module`): The model to which the adapter tuner layers will be attached. + config ([`OFTConfig`]): The configuration of the OFT model. + adapter_name (`str`): The name of the adapter, defaults to `"default"`. + + Returns: + `torch.nn.Module`: The OFT model. + + Example: + ```py + >>> from diffusers import StableDiffusionPipeline + >>> from peft import OFTModel, OFTConfig + + >>> config_te = OFTConfig( + ... r=8, + ... target_modules=["k_proj", "q_proj", "v_proj", "out_proj", "fc1", "fc2"], + ... module_dropout=0.0, + ... init_weights=True, + ... ) + >>> config_unet = OFTConfig( + ... r=8, + ... target_modules=[ + ... "proj_in", + ... "proj_out", + ... "to_k", + ... "to_q", + ... "to_v", + ... "to_out.0", + ... "ff.net.0.proj", + ... "ff.net.2", + ... ], + ... module_dropout=0.0, + ... init_weights=True, + ... ) + + >>> model = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5") + >>> model.text_encoder = OFTModel(model.text_encoder, config_te, "default") + >>> model.unet = OFTModel(model.unet, config_unet, "default") + ``` + + **Attributes**: + - **model** ([`~torch.nn.Module`]) -- The model to be adapted. + - **peft_config** ([`OFTConfig`]): The configuration of the OFT model. + """ + + prefix: str = "oft_" + layers_mapping: Dict[Type[torch.nn.Module], Type[OFTLayer]] = { + torch.nn.Conv2d: Conv2d, + torch.nn.Linear: Linear, + } + + def _create_and_replace( + self, + config: LycorisConfig, + adapter_name: str, + target: Union[OFTLayer, nn.Module], + target_name: str, + parent: nn.Module, + current_key: str, + **optional_kwargs, + ) -> None: + """ + A private method to create and replace the target module with the adapter module. + """ + + # Regexp matching - Find key which matches current target_name in patterns provided + pattern_keys = list(config.rank_pattern.keys()) + target_name_key = next(filter(lambda key: re.match(f"(.*\.)?{key}$", current_key), pattern_keys), target_name) + + kwargs = config.to_dict() + kwargs["r"] = config.rank_pattern.get(target_name_key, config.r) + + if isinstance(target, OFTLayer): + target.update_layer(adapter_name, **kwargs) + else: + new_module = self._create_new_module(config, adapter_name, target, **kwargs) + self._replace_module(parent, target_name, new_module, target) diff --git a/src/peft/utils/peft_types.py b/src/peft/utils/peft_types.py index 29c764a08f..93b892d9e5 100644 --- a/src/peft/utils/peft_types.py +++ b/src/peft/utils/peft_types.py @@ -30,6 +30,7 @@ class PeftType(str, enum.Enum): IA3 = "IA3" LOHA = "LOHA" LOKR = "LOKR" + OFT = "OFT" class TaskType(str, enum.Enum): diff --git a/src/peft/utils/save_and_load.py b/src/peft/utils/save_and_load.py index 97bde0d6fe..c5da274085 100644 --- a/src/peft/utils/save_and_load.py +++ b/src/peft/utils/save_and_load.py @@ -113,6 +113,8 @@ def get_peft_model_state_dict( to_return["prompt_embeddings"] = prompt_embeddings elif config.peft_type == PeftType.IA3: to_return = {k: state_dict[k] for k in state_dict if "ia3_" in k} + elif config.peft_type == PeftType.OFT: + to_return = {k: state_dict[k] for k in state_dict if "oft_" in k} else: raise NotImplementedError if getattr(model, "modules_to_save", None) is not None: @@ -166,7 +168,7 @@ def set_peft_model_state_dict(model, peft_model_state_dict, adapter_name="defaul else: state_dict = peft_model_state_dict - if config.peft_type in (PeftType.LORA, PeftType.LOHA, PeftType.LOKR, PeftType.ADALORA, PeftType.IA3): + if config.peft_type in (PeftType.LORA, PeftType.LOHA, PeftType.LOKR, PeftType.ADALORA, PeftType.IA3, PeftType.OFT): peft_model_state_dict = {} parameter_prefix = { PeftType.IA3: "ia3_", @@ -174,6 +176,7 @@ def set_peft_model_state_dict(model, peft_model_state_dict, adapter_name="defaul PeftType.ADALORA: "lora_", PeftType.LOHA: "hada_", PeftType.LOKR: "lokr_", + PeftType.OFT: "oft_", }[config.peft_type] for k, v in state_dict.items(): if parameter_prefix in k: diff --git a/tests/test_config.py b/tests/test_config.py index 34f04232a9..06e72dae8e 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -30,6 +30,7 @@ LoHaConfig, LoraConfig, MultitaskPromptTuningConfig, + OFTConfig, PeftConfig, PrefixTuningConfig, PromptEncoder, @@ -51,6 +52,7 @@ PrefixTuningConfig, PromptEncoderConfig, PromptTuningConfig, + OFTConfig, ) @@ -189,7 +191,7 @@ def test_prompt_encoder_warning_num_layers(self): expected_msg = "for MLP, the argument `encoder_num_layers` is ignored. Exactly 2 MLP layers are used." assert str(record.list[0].message) == expected_msg - @parameterized.expand([LoHaConfig, LoraConfig, IA3Config]) + @parameterized.expand([LoHaConfig, LoraConfig, IA3Config, OFTConfig]) def test_save_pretrained_with_target_modules(self, config_class): # See #1041, #1045 config = config_class(target_modules=["a", "list"]) diff --git a/tests/test_custom_models.py b/tests/test_custom_models.py index b298388a84..4785526b26 100644 --- a/tests/test_custom_models.py +++ b/tests/test_custom_models.py @@ -24,7 +24,7 @@ from torch import nn from transformers.pytorch_utils import Conv1D -from peft import AdaLoraConfig, IA3Config, LoHaConfig, LoKrConfig, LoraConfig, PeftModel, get_peft_model +from peft import AdaLoraConfig, IA3Config, LoHaConfig, LoKrConfig, LoraConfig, OFTConfig, PeftModel, get_peft_model from peft.tuners.tuners_utils import BaseTunerLayer from .testing_common import PeftCommonTester @@ -191,6 +191,28 @@ "decompose_factor": 4, }, ), + ######## + # OFT # + ######## + ("Vanilla MLP 1 OFT", "MLP", OFTConfig, {"target_modules": "lin0"}), + ("Vanilla MLP 2 OFT", "MLP", OFTConfig, {"target_modules": ["lin0"]}), + ("Vanilla MLP 5 OFT", "MLP", OFTConfig, {"target_modules": ["lin0"], "modules_to_save": ["lin1"]}), + ( + "Vanilla MLP 6 OFT", + "MLP", + OFTConfig, + { + "target_modules": ["lin0"], + "module_dropout": 0.1, + }, + ), + ("Vanilla MLP 7 OFT", "MLP", OFTConfig, {"target_modules": ["lin0"], "coft": True}), + ("Vanilla MLP 8 OFT", "MLP", OFTConfig, {"target_modules": ["lin0"], "block_share": True}), + ("Vanilla MLP 9 OFT", "MLP", OFTConfig, {"target_modules": ["lin0"], "coft": True, "block_share": True}), + ("Conv2d 1 OFT", "Conv2d", OFTConfig, {"target_modules": ["conv2d"]}), + ("Conv2d 3 OFT", "Conv2d", OFTConfig, {"target_modules": ["conv2d"], "coft": True}), + ("Conv2d 4 OFT", "Conv2d", OFTConfig, {"target_modules": ["conv2d"], "block_share": True}), + ("Conv2d 5 OFT", "Conv2d", OFTConfig, {"target_modules": ["conv2d"], "coft": True, "block_share": True}), ] MULTIPLE_ACTIVE_ADAPTERS_TEST_CASES = [ @@ -258,6 +280,7 @@ LoraConfig: "lora_", LoHaConfig: "hada_", LoKrConfig: "lokr_", + OFTConfig: "oft_", } @@ -833,6 +856,7 @@ def test_targeting_lora_to_embedding_layer_non_transformers(self, save_embedding LoHaConfig(target_modules=["lin0"], init_weights=False), AdaLoraConfig(target_modules=["lin0"], init_lora_weights=False), IA3Config(target_modules=["lin0"], feedforward_modules=["lin0"], init_ia3_weights=False), + OFTConfig(target_modules=["lin0"], init_weights=False), ] ) def test_adapter_name_makes_no_difference(self, config0): @@ -1852,3 +1876,80 @@ def test_requires_grad_lokr_same_targets(self): "base_model.model.lin0.lokr_w1.adapter1", "base_model.model.lin0.lokr_w2.adapter1", ) + + def test_requires_grad_oft_different_targets(self): + # test two different OFT adapters that target different modules + config0 = OFTConfig(target_modules=["lin0"]) + peft_model = get_peft_model(MLP(), config0) + + config1 = OFTConfig(target_modules=["lin1"], inference_mode=True) + peft_model.add_adapter("adapter1", config1) + + # active pter is still "default" + self.check_requires_grad( + peft_model, + "base_model.model.lin0.oft_r.default", + ) + + # set config0 as active, should not change anything + peft_model.set_adapter("default") + self.check_requires_grad( + peft_model, + "base_model.model.lin0.oft_r.default", + ) + + # change activate pter to pter1 + peft_model.set_adapter("adapter1") + self.check_requires_grad( + peft_model, + "base_model.model.lin1.oft_r.adapter1", + ) + + # disable all pters + with peft_model.disable_adapter(): + self.check_requires_grad(peft_model) + + # after context is exited, return to the previous state + self.check_requires_grad( + peft_model, + "base_model.model.lin1.oft_r.adapter1", + ) + + def test_requires_grad_oft_same_targets(self): + # same as previous test, except that OFT adapters target the same layer + config0 = OFTConfig(target_modules=["lin0"]) + peft_model = get_peft_model(MLP(), config0) + + config1 = OFTConfig(target_modules=["lin0"], inference_mode=True) + peft_model.add_adapter("adapter1", config1) + + # active adapter is still "default" + self.check_requires_grad( + peft_model, + "base_model.model.lin0.oft_r.default", + ) + + # set config0 as active, should not change anything + peft_model.set_adapter("default") + self.check_requires_grad( + peft_model, + "base_model.model.lin0.oft_r.default", + ) + + # change activate adapter to adapter1 + peft_model.set_adapter("adapter1") + self.check_requires_grad( + peft_model, + "base_model.model.lin0.oft_r.adapter1", + ) + + # disable all adapters + with peft_model.disable_adapter(): + self.check_requires_grad(peft_model) + + # after context is exited, return to the previous state + peft_model.set_adapter("adapter1") + self.check_requires_grad( + peft_model, + "base_model.model.lin0.oft_r.adapter1", + ) diff --git a/tests/test_stablediffusion.py b/tests/test_stablediffusion.py index 830614a7ab..660c17caea 100644 --- a/tests/test_stablediffusion.py +++ b/tests/test_stablediffusion.py @@ -20,7 +20,7 @@ from diffusers import StableDiffusionPipeline from parameterized import parameterized -from peft import LoHaConfig, LoraConfig, get_peft_model +from peft import LoHaConfig, LoraConfig, OFTConfig, get_peft_model from .testing_common import ClassInstantier, PeftCommonTester from .testing_utils import temp_seed @@ -60,11 +60,24 @@ "module_dropout": 0.0, }, }, + { + "text_encoder": { + "r": 8, + "target_modules": ["k_proj", "q_proj", "v_proj", "out_proj", "fc1", "fc2"], + "module_dropout": 0.0, + }, + "unet": { + "r": 8, + "target_modules": ["proj_in", "proj_out", "to_k", "to_q", "to_v", "to_out.0", "ff.net.0.proj", "ff.net.2"], + "module_dropout": 0.0, + }, + }, ) CLASSES_MAPPING = { "lora": (LoraConfig, CONFIG_TESTING_KWARGS[0]), "loha": (LoHaConfig, CONFIG_TESTING_KWARGS[1]), "lokr": (LoHaConfig, CONFIG_TESTING_KWARGS[1]), + "oft": (OFTConfig, CONFIG_TESTING_KWARGS[2]), } @@ -115,13 +128,14 @@ def prepare_inputs_for_testing(self): "model_ids": PEFT_DIFFUSERS_SD_MODELS_TO_TEST, "lora_kwargs": {"init_lora_weights": [False]}, "loha_kwargs": {"init_weights": [False]}, + "oft_kwargs": {"init_weights": [False]}, }, ) ) def test_merge_layers(self, test_name, model_id, config_cls, config_kwargs): - if config_cls == LoHaConfig: + if config_cls in [LoHaConfig, OFTConfig]: # TODO: This test is flaky with PyTorch 2.1 on Windows, we need to figure out what is going on - self.skipTest("LoHaConfig test is flaky") + self.skipTest("LoHaConfig and OFTConfig test is flaky") # Instantiate model & adapters model = self.instantiate_sd_peft(model_id, config_cls, config_kwargs) @@ -148,7 +162,7 @@ def test_merge_layers(self, test_name, model_id, config_cls, config_kwargs): "model_ids": PEFT_DIFFUSERS_SD_MODELS_TO_TEST, "lora_kwargs": {"init_lora_weights": [False]}, }, - filter_params_func=lambda tests: [x for x in tests if all(s not in x[0] for s in ["loha", "lokr"])], + filter_params_func=lambda tests: [x for x in tests if all(s not in x[0] for s in ["loha", "lokr", "oft"])], ) ) def test_add_weighted_adapter_base_unchanged(self, test_name, model_id, config_cls, config_kwargs): @@ -178,6 +192,7 @@ def test_add_weighted_adapter_base_unchanged(self, test_name, model_id, config_c "lora_kwargs": {"init_lora_weights": [False]}, "loha_kwargs": {"init_weights": [False]}, "lokr_kwargs": {"init_weights": [False]}, + "oft_kwargs": {"init_weights": [False]}, }, ) ) diff --git a/tests/testing_common.py b/tests/testing_common.py index 00809c2bc1..0c081cde2c 100644 --- a/tests/testing_common.py +++ b/tests/testing_common.py @@ -574,7 +574,7 @@ def _test_merge_layers(self, model_id, config_cls, config_kwargs): self.assertTrue(torch.allclose(logits_merged, logits_merged_from_pretrained, atol=atol, rtol=rtol)) def _test_merge_layers_multi(self, model_id, config_cls, config_kwargs): - supported_peft_types = [PeftType.LORA, PeftType.LOHA, PeftType.LOKR, PeftType.IA3] + supported_peft_types = [PeftType.LORA, PeftType.LOHA, PeftType.LOKR, PeftType.IA3, PeftType.OFT] if ("gpt2" in model_id.lower()) and (config_cls == IA3Config): self.skipTest("Merging GPT2 adapters not supported for IA³ (yet)") @@ -886,7 +886,7 @@ def _test_training_prompt_learning_tasks(self, model_id, config_cls, config_kwar self.assertIsNotNone(param.grad) def _test_delete_adapter(self, model_id, config_cls, config_kwargs): - supported_peft_types = [PeftType.LORA, PeftType.LOHA, PeftType.LOKR, PeftType.IA3] + supported_peft_types = [PeftType.LORA, PeftType.LOHA, PeftType.LOKR, PeftType.IA3, PeftType.OFT] # IA3 does not support deleting adapters yet, but it just needs to be added # AdaLora does not support multiple adapters config = config_cls( @@ -924,7 +924,7 @@ def _test_delete_adapter(self, model_id, config_cls, config_kwargs): def _test_delete_inactive_adapter(self, model_id, config_cls, config_kwargs): # same as test_delete_adapter, but this time an inactive adapter is deleted - supported_peft_types = [PeftType.LORA, PeftType.LOHA, PeftType.LOKR, PeftType.IA3] + supported_peft_types = [PeftType.LORA, PeftType.LOHA, PeftType.LOKR, PeftType.IA3, PeftType.OFT] # IA3 does not support deleting adapters yet, but it just needs to be added # AdaLora does not support multiple adapters config = config_cls(