diff --git a/examples/lora_dreambooth/convert_kohya_ss_sd_lora_to_peft.py b/examples/lora_dreambooth/convert_kohya_ss_sd_lora_to_peft.py
new file mode 100644
index 0000000000..34213a5409
--- /dev/null
+++ b/examples/lora_dreambooth/convert_kohya_ss_sd_lora_to_peft.py
@@ -0,0 +1,141 @@
+import argparse
+import os
+from typing import List, Optional
+
+import safetensors
+import torch
+import torch.nn as nn
+from diffusers import UNet2DConditionModel
+from transformers import CLIPTextModel
+
+from peft import LoraConfig, get_peft_model, get_peft_model_state_dict, set_peft_model_state_dict
+
+
+# Default kohya_ss LoRA replacement modules
+# https://github.com/kohya-ss/sd-scripts/blob/c924c47f374ac1b6e33e71f82948eb1853e2243f/networks/lora.py#L661
+UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention"]
+UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
+TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
+LORA_PREFIX_UNET = "lora_unet"
+LORA_PREFIX_TEXT_ENCODER = "lora_te"
+
+
+def get_modules_names(
+    root_module: nn.Module,
+    target_replace_modules_linear: Optional[List[str]] = [],
+    target_replace_modules_conv2d: Optional[List[str]] = [],
+):
+    # Combine replacement modules
+    target_replace_modules = target_replace_modules_linear + target_replace_modules_conv2d
+
+    # Store result
+    modules_names = set()
+    # https://github.com/kohya-ss/sd-scripts/blob/c924c47f374ac1b6e33e71f82948eb1853e2243f/networks/lora.py#L720
+    for name, module in root_module.named_modules():
+        if module.__class__.__name__ in target_replace_modules:
+            if len(name) == 0:
+                continue
+            for child_name, child_module in module.named_modules():
+                if len(child_name) == 0:
+                    continue
+                is_linear = child_module.__class__.__name__ == "Linear"
+                is_conv2d = child_module.__class__.__name__ == "Conv2d"
+
+                if (is_linear and module.__class__.__name__ in target_replace_modules_linear) or (
+                    is_conv2d and module.__class__.__name__ in target_replace_modules_conv2d
+                ):
+                    modules_names.add(f"{name}.{child_name}")
+
+    return sorted(modules_names)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument("--sd_checkpoint", default=None, type=str, required=True, help="SD checkpoint to use")
+
+    parser.add_argument(
+        "--kohya_lora_path", default=None, type=str, required=True, help="Path to kohya_ss trained LoRA"
+    )
+
+    parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
+
+    parser.add_argument("--half", action="store_true", help="Save weights in half precision.")
+    args = parser.parse_args()
+
+    # Find text encoder modules to add LoRA to
+    text_encoder = CLIPTextModel.from_pretrained(args.sd_checkpoint, subfolder="text_encoder")
+    text_encoder_modules_names = get_modules_names(
+        text_encoder, target_replace_modules_linear=TEXT_ENCODER_TARGET_REPLACE_MODULE
+    )
+
+    # Find unet2d modules to add LoRA to
+    unet = UNet2DConditionModel.from_pretrained(args.sd_checkpoint, subfolder="unet")
+    unet_modules_names = get_modules_names(
+        unet,
+        target_replace_modules_linear=UNET_TARGET_REPLACE_MODULE,
+        target_replace_modules_conv2d=UNET_TARGET_REPLACE_MODULE_CONV2D_3X3,
+    )
+
+    # Open kohya_ss checkpoint
+    with safetensors.safe_open(args.kohya_lora_path, framework="pt", device="cpu") as f:
+        # Extract information about LoRA structure
+        metadata = f.metadata()
+        lora_r = lora_text_encoder_r = int(metadata["ss_network_dim"])
+        lora_alpha = lora_text_encoder_alpha = float(metadata["ss_network_alpha"])
+
+        # Create LoRA for text encoder
+        text_encoder_config = LoraConfig(
+            r=lora_text_encoder_r,
+            lora_alpha=lora_text_encoder_alpha,
+            target_modules=text_encoder_modules_names,
+            lora_dropout=0.0,
+            bias="none",
+        )
+        text_encoder = get_peft_model(text_encoder, text_encoder_config)
+        text_encoder_lora_state_dict = {x: None for x in get_peft_model_state_dict(text_encoder).keys()}
+
+        # Load text encoder values from kohya_ss LoRA
+        for peft_te_key in text_encoder_lora_state_dict.keys():
+            kohya_ss_te_key = peft_te_key.replace("base_model.model", LORA_PREFIX_TEXT_ENCODER)
+            kohya_ss_te_key = kohya_ss_te_key.replace("lora_A", "lora_down")
+            kohya_ss_te_key = kohya_ss_te_key.replace("lora_B", "lora_up")
+            kohya_ss_te_key = kohya_ss_te_key.replace(".", "_", kohya_ss_te_key.count(".") - 2)
+            text_encoder_lora_state_dict[peft_te_key] = f.get_tensor(kohya_ss_te_key).to(text_encoder.dtype)
+
+        # Load converted kohya_ss text encoder LoRA back to PEFT
+        set_peft_model_state_dict(text_encoder, text_encoder_lora_state_dict)
+
+        if args.half:
+            text_encoder.to(torch.float16)
+
+        # Save text encoder result
+        text_encoder.save_pretrained(
+            os.path.join(args.dump_path, "text_encoder"),
+        )
+
+        # Create LoRA for unet2d
+        unet_config = LoraConfig(
+            r=lora_r, lora_alpha=lora_alpha, target_modules=unet_modules_names, lora_dropout=0.0, bias="none"
+        )
+        unet = get_peft_model(unet, unet_config)
+        unet_lora_state_dict = {x: None for x in get_peft_model_state_dict(unet).keys()}
+
+        # Load unet2d values from kohya_ss LoRA
+        for peft_unet_key in unet_lora_state_dict.keys():
+            kohya_ss_unet_key = peft_unet_key.replace("base_model.model", LORA_PREFIX_UNET)
+            kohya_ss_unet_key = kohya_ss_unet_key.replace("lora_A", "lora_down")
+            kohya_ss_unet_key = kohya_ss_unet_key.replace("lora_B", "lora_up")
+            kohya_ss_unet_key = kohya_ss_unet_key.replace(".", "_", kohya_ss_unet_key.count(".") - 2)
+            unet_lora_state_dict[peft_unet_key] = f.get_tensor(kohya_ss_unet_key).to(unet.dtype)
+
+        # Load converted kohya_ss unet LoRA back to PEFT
+        set_peft_model_state_dict(unet, unet_lora_state_dict)
+
+        if args.half:
+            unet.to(torch.float16)
+
+        # Save unet result
+        unet.save_pretrained(
+            os.path.join(args.dump_path, "unet"),
+        )
diff --git a/examples/lora_dreambooth/requirements.txt b/examples/lora_dreambooth/requirements.txt
index fcaffe6f43..0a5f78cdb4 100644
--- a/examples/lora_dreambooth/requirements.txt
+++ b/examples/lora_dreambooth/requirements.txt
@@ -7,4 +7,5 @@ datasets
 diffusers
 Pillow
 torchvision
-huggingface_hub
\ No newline at end of file
+huggingface_hub
+safetensors
\ No newline at end of file
diff --git a/src/peft/tuners/lora.py b/src/peft/tuners/lora.py
index 0fcaa6bf28..856f46d389 100644
--- a/src/peft/tuners/lora.py
+++ b/src/peft/tuners/lora.py
@@ -17,7 +17,7 @@
 import warnings
 from dataclasses import asdict, dataclass, field
 from enum import Enum
-from typing import List, Optional, Union
+from typing import List, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -262,6 +262,12 @@ def _create_new_module(self, lora_config, adapter_name, target):
             embedding_kwargs.pop("fan_in_fan_out", None)
             in_features, out_features = target.num_embeddings, target.embedding_dim
             new_module = Embedding(adapter_name, in_features, out_features, **embedding_kwargs)
+        elif isinstance(target, torch.nn.Conv2d):
+            out_channels, in_channels = target.weight.size()[:2]
+            kernel_size = target.weight.size()[2:]
+            stride = target.stride
+            padding = target.padding
+            new_module = Conv2d(adapter_name, in_channels, out_channels, kernel_size, stride, padding, **kwargs)
         else:
             if isinstance(target, torch.nn.Linear):
                 in_features, out_features = target.in_features, target.out_features
@@ -303,7 +309,15 @@ def _find_and_replace(self, adapter_name):
                 is_target_modules_in_base_model = True
                 parent, target, target_name = _get_submodules(self.model, key)
 
-                if isinstance(target, LoraLayer):
+                if isinstance(target, LoraLayer) and isinstance(target, torch.nn.Conv2d):
+                    target.update_layer_conv2d(
+                        adapter_name,
+                        lora_config.r,
+                        lora_config.lora_alpha,
+                        lora_config.lora_dropout,
+                        lora_config.init_lora_weights,
+                    )
+                elif isinstance(target, LoraLayer):
                     target.update_layer(
                         adapter_name,
                         lora_config.r,
@@ -487,11 +501,7 @@ def mark_only_lora_as_trainable(model: nn.Module, bias: str = "none") -> None:
 
 
 class LoraLayer:
-    def __init__(
-        self,
-        in_features: int,
-        out_features: int,
-    ):
+    def __init__(self, in_features: int, out_features: int, **kwargs):
         self.r = {}
         self.lora_alpha = {}
         self.scaling = {}
@@ -506,6 +516,7 @@ def __init__(
         self.disable_adapters = False
         self.in_features = in_features
         self.out_features = out_features
+        self.kwargs = kwargs
 
     def update_layer(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights):
         self.r[adapter_name] = r
@@ -525,6 +536,31 @@ def update_layer(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weig
             self.reset_lora_parameters(adapter_name)
         self.to(self.weight.device)
 
+    def update_layer_conv2d(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights):
+        self.r[adapter_name] = r
+        self.lora_alpha[adapter_name] = lora_alpha
+        if lora_dropout > 0.0:
+            lora_dropout_layer = nn.Dropout(p=lora_dropout)
+        else:
+            lora_dropout_layer = nn.Identity()
+
+        self.lora_dropout.update(nn.ModuleDict({adapter_name: lora_dropout_layer}))
+        # Actual trainable parameters
+        if r > 0:
+            kernel_size = self.kwargs["kernel_size"]
+            stride = self.kwargs["stride"]
+            padding = self.kwargs["padding"]
+            self.lora_A.update(
+                nn.ModuleDict({adapter_name: nn.Conv2d(self.in_features, r, kernel_size, stride, padding, bias=False)})
+            )
+            self.lora_B.update(
+                nn.ModuleDict({adapter_name: nn.Conv2d(r, self.out_features, (1, 1), (1, 1), bias=False)})
+            )
+            self.scaling[adapter_name] = lora_alpha / r
+        if init_lora_weights:
+            self.reset_lora_parameters(adapter_name)
+        self.to(self.weight.device)
+
     def update_layer_embedding(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights):
         self.r[adapter_name] = r
         self.lora_alpha[adapter_name] = lora_alpha
@@ -726,6 +762,148 @@ def forward(self, x: torch.Tensor):
         return nn.Embedding.forward(self, x)
 
 
+class Conv2d(nn.Conv2d, LoraLayer):
+    # Lora implemented in a conv2d layer
+    def __init__(
+        self,
+        adapter_name: str,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: Union[int, Tuple[int]],
+        stride: Union[int, Tuple[int]] = 1,
+        padding: Union[int, Tuple[int]] = 0,
+        r: int = 0,
+        lora_alpha: int = 1,
+        lora_dropout: float = 0.0,
+        **kwargs,
+    ):
+        init_lora_weights = kwargs.pop("init_lora_weights", True)
+
+        nn.Conv2d.__init__(self, in_channels, out_channels, kernel_size, stride, padding)
+        LoraLayer.__init__(
+            self,
+            in_features=in_channels,
+            out_features=out_channels,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=padding,
+        )
+        # Freezing the pre-trained weight matrix
+        self.weight.requires_grad = False
+
+        nn.Conv2d.reset_parameters(self)
+        self.update_layer_conv2d(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
+        self.active_adapter = adapter_name
+
+    def merge(self):
+        if self.active_adapter not in self.lora_A.keys():
+            return
+        if self.merged:
+            warnings.warn("Already merged. Nothing to do.")
+            return
+        if self.r[self.active_adapter] > 0:
+            # https://github.com/bmaltais/kohya_ss/blob/feb6728762a8f463d15ba936d189d4c3abfaa1ab/networks/lora.py#L117
+            if self.weight.size()[2:4] == (1, 1):
+                # conv2d 1x1
+                self.weight.data += (
+                    self.lora_B[self.active_adapter].weight.squeeze(3).squeeze(2)
+                    @ self.lora_A[self.active_adapter].weight.squeeze(3).squeeze(2)
+                ).unsqueeze(2).unsqueeze(3) * self.scaling[self.active_adapter]
+            else:
+                # conv2d 3x3
+                self.weight.data += (
+                    F.conv2d(
+                        self.lora_A[self.active_adapter].weight.permute(1, 0, 2, 3),
+                        self.lora_B[self.active_adapter].weight,
+                    ).permute(1, 0, 2, 3)
+                    * self.scaling[self.active_adapter]
+                )
+            self.merged = True
+
+    def unmerge(self):
+        if self.active_adapter not in self.lora_A.keys():
+            return
+        if not self.merged:
+            warnings.warn("Already unmerged. Nothing to do.")
+            return
+        if self.r[self.active_adapter] > 0:
+            if self.weight.size()[2:4] == (1, 1):
+                # conv2d 1x1
+                self.weight.data -= (
+                    self.lora_B[self.active_adapter].weight.squeeze(3).squeeze(2)
+                    @ self.lora_A[self.active_adapter].weight.squeeze(3).squeeze(2)
+                ).unsqueeze(2).unsqueeze(3) * self.scaling[self.active_adapter]
+            else:
+                # conv2d 3x3
+                self.weight.data -= (
+                    F.conv2d(
+                        self.lora_A[self.active_adapter].weight.permute(1, 0, 2, 3),
+                        self.lora_B[self.active_adapter].weight,
+                    ).permute(1, 0, 2, 3)
+                    * self.scaling[self.active_adapter]
+                )
+            self.merged = False
+
+    def forward(self, x: torch.Tensor):
+        previous_dtype = x.dtype
+
+        if self.active_adapter not in self.lora_A.keys():
+            return F.conv2d(
+                x,
+                self.weight,
+                bias=self.bias,
+                stride=self.stride,
+                padding=self.padding,
+                dilation=self.dilation,
+                groups=self.groups,
+            )
+        if self.disable_adapters:
+            if self.r[self.active_adapter] > 0 and self.merged:
+                self.unmerge()
+            result = F.conv2d(
+                x,
+                self.weight,
+                bias=self.bias,
+                stride=self.stride,
+                padding=self.padding,
+                dilation=self.dilation,
+                groups=self.groups,
+            )
+        elif self.r[self.active_adapter] > 0 and not self.merged:
+            result = F.conv2d(
+                x,
+                self.weight,
+                bias=self.bias,
+                stride=self.stride,
+                padding=self.padding,
+                dilation=self.dilation,
+                groups=self.groups,
+            )
+
+            x = x.to(self.lora_A[self.active_adapter].weight.dtype)
+
+            result += (
+                self.lora_B[self.active_adapter](
+                    self.lora_A[self.active_adapter](self.lora_dropout[self.active_adapter](x))
+                )
+                * self.scaling[self.active_adapter]
+            )
+        else:
+            result = F.conv2d(
+                x,
+                self.weight,
+                bias=self.bias,
+                stride=self.stride,
+                padding=self.padding,
+                dilation=self.dilation,
+                groups=self.groups,
+            )
+
+        result = result.to(previous_dtype)
+
+        return result
+
+
 if is_bnb_available():
 
     class Linear8bitLt(bnb.nn.Linear8bitLt, LoraLayer):
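
For reference, a minimal inference sketch (not part of the patch) showing how the `text_encoder` and `unet` adapter folders saved by the conversion script above could be loaded back with PEFT. The base checkpoint ID, dump path, and prompt are placeholders, not values taken from this PR.

```python
# Sketch only: base checkpoint ID, dump_path and prompt below are placeholders.
import os

from diffusers import StableDiffusionPipeline
from peft import PeftModel

base_checkpoint = "runwayml/stable-diffusion-v1-5"  # placeholder for --sd_checkpoint
dump_path = "./converted_kohya_lora"                # placeholder for --dump_path

pipe = StableDiffusionPipeline.from_pretrained(base_checkpoint)

# Wrap the UNet and text encoder with the converted PEFT adapters
pipe.unet = PeftModel.from_pretrained(pipe.unet, os.path.join(dump_path, "unet"))
pipe.text_encoder = PeftModel.from_pretrained(pipe.text_encoder, os.path.join(dump_path, "text_encoder"))

pipe.to("cuda")
image = pipe("a photo in the converted LoRA style", num_inference_steps=30).images[0]
image.save("sample.png")
```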
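The `merge()`/`unmerge()` math in the new `Conv2d` layer folds the low-rank convolution pair into the frozen kernel. A small self-contained check (illustrative shapes only, not part of the patch) that the delta-weight trick used for the non-1x1 case matches applying `lora_B(lora_A(x))` at runtime:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
in_c, out_c, rank, scaling = 8, 16, 4, 0.5

base = torch.nn.Conv2d(in_c, out_c, kernel_size=3, padding=1)
lora_A = torch.nn.Conv2d(in_c, rank, kernel_size=3, padding=1, bias=False)  # same geometry as base
lora_B = torch.nn.Conv2d(rank, out_c, kernel_size=1, bias=False)

x = torch.randn(2, in_c, 32, 32)
unmerged = base(x) + scaling * lora_B(lora_A(x))

# Same delta-weight computation as Conv2d.merge() for the non-1x1 branch
delta_w = F.conv2d(lora_A.weight.permute(1, 0, 2, 3), lora_B.weight).permute(1, 0, 2, 3)
merged = F.conv2d(x, base.weight + scaling * delta_w, base.bias, padding=1)

print(torch.allclose(unmerged, merged, atol=1e-5))  # expected: True
```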