LoRA for Conv2d layer, script to convert kohya_ss LoRA to PEFT #461

Merged (3 commits) on Jun 15, 2023

141 changes: 141 additions & 0 deletions examples/lora_dreambooth/convert_kohya_ss_sd_lora_to_peft.py
@@ -0,0 +1,141 @@
import argparse
import os
from typing import List, Optional

import safetensors
import torch
import torch.nn as nn
from diffusers import UNet2DConditionModel
from transformers import CLIPTextModel

from peft import LoraConfig, get_peft_model, get_peft_model_state_dict, set_peft_model_state_dict


# Default kohya_ss LoRA replacement modules
# https://github.com/kohya-ss/sd-scripts/blob/c924c47f374ac1b6e33e71f82948eb1853e2243f/networks/lora.py#L661
UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention"]
UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
LORA_PREFIX_UNET = "lora_unet"
LORA_PREFIX_TEXT_ENCODER = "lora_te"


def get_modules_names(
    root_module: nn.Module,
    target_replace_modules_linear: Optional[List[str]] = [],
    target_replace_modules_conv2d: Optional[List[str]] = [],
):
    # Combine replacement modules
    target_replace_modules = target_replace_modules_linear + target_replace_modules_conv2d

    # Store result
    modules_names = set()
    # https://github.com/kohya-ss/sd-scripts/blob/c924c47f374ac1b6e33e71f82948eb1853e2243f/networks/lora.py#L720
    for name, module in root_module.named_modules():
        if module.__class__.__name__ in target_replace_modules:
            if len(name) == 0:
                continue
            for child_name, child_module in module.named_modules():
                if len(child_name) == 0:
                    continue
                is_linear = child_module.__class__.__name__ == "Linear"
                is_conv2d = child_module.__class__.__name__ == "Conv2d"

                if (is_linear and module.__class__.__name__ in target_replace_modules_linear) or (
                    is_conv2d and module.__class__.__name__ in target_replace_modules_conv2d
                ):
                    modules_names.add(f"{name}.{child_name}")

    return sorted(modules_names)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument("--sd_checkpoint", default=None, type=str, required=True, help="SD checkpoint to use")

    parser.add_argument(
        "--kohya_lora_path", default=None, type=str, required=True, help="Path to kohya_ss trained LoRA"
    )

    parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")

    parser.add_argument("--half", action="store_true", help="Save weights in half precision.")
    args = parser.parse_args()

    # Find text encoder modules to add LoRA to
    text_encoder = CLIPTextModel.from_pretrained(args.sd_checkpoint, subfolder="text_encoder")
    text_encoder_modules_names = get_modules_names(
        text_encoder, target_replace_modules_linear=TEXT_ENCODER_TARGET_REPLACE_MODULE
    )

    # Find unet2d modules to add LoRA to
    unet = UNet2DConditionModel.from_pretrained(args.sd_checkpoint, subfolder="unet")
    unet_modules_names = get_modules_names(
        unet,
        target_replace_modules_linear=UNET_TARGET_REPLACE_MODULE,
        target_replace_modules_conv2d=UNET_TARGET_REPLACE_MODULE,
    )

    # Open kohya_ss checkpoint
    with safetensors.safe_open(args.kohya_lora_path, framework="pt", device="cpu") as f:
        # Extract information about LoRA structure
        metadata = f.metadata()
        lora_r = lora_text_encoder_r = int(metadata["ss_network_dim"])
        lora_alpha = lora_text_encoder_alpha = float(metadata["ss_network_alpha"])

        # Create LoRA for text encoder
        text_encoder_config = LoraConfig(
            r=lora_text_encoder_r,
            lora_alpha=lora_text_encoder_alpha,
            target_modules=text_encoder_modules_names,
            lora_dropout=0.0,
            bias="none",
        )
        text_encoder = get_peft_model(text_encoder, text_encoder_config)
        text_encoder_lora_state_dict = {x: None for x in get_peft_model_state_dict(text_encoder).keys()}

        # Load text encoder values from kohya_ss LoRA
        for peft_te_key in text_encoder_lora_state_dict.keys():
            kohya_ss_te_key = peft_te_key.replace("base_model.model", LORA_PREFIX_TEXT_ENCODER)
            kohya_ss_te_key = kohya_ss_te_key.replace("lora_A", "lora_down")
            kohya_ss_te_key = kohya_ss_te_key.replace("lora_B", "lora_up")
            kohya_ss_te_key = kohya_ss_te_key.replace(".", "_", kohya_ss_te_key.count(".") - 2)
            text_encoder_lora_state_dict[peft_te_key] = f.get_tensor(kohya_ss_te_key).to(text_encoder.dtype)

        # Load converted kohya_ss text encoder LoRA back to PEFT
        set_peft_model_state_dict(text_encoder, text_encoder_lora_state_dict)

        if args.half:
            text_encoder.to(torch.float16)

        # Save text encoder result
        text_encoder.save_pretrained(
            os.path.join(args.dump_path, "text_encoder"),
        )

        # Create LoRA for unet2d
        unet_config = LoraConfig(
            r=lora_r, lora_alpha=lora_alpha, target_modules=unet_modules_names, lora_dropout=0.0, bias="none"
        )
        unet = get_peft_model(unet, unet_config)
        unet_lora_state_dict = {x: None for x in get_peft_model_state_dict(unet).keys()}

        # Load unet2d values from kohya_ss LoRA
        for peft_unet_key in unet_lora_state_dict.keys():
            kohya_ss_unet_key = peft_unet_key.replace("base_model.model", LORA_PREFIX_UNET)
            kohya_ss_unet_key = kohya_ss_unet_key.replace("lora_A", "lora_down")
            kohya_ss_unet_key = kohya_ss_unet_key.replace("lora_B", "lora_up")
            kohya_ss_unet_key = kohya_ss_unet_key.replace(".", "_", kohya_ss_unet_key.count(".") - 2)
            unet_lora_state_dict[peft_unet_key] = f.get_tensor(kohya_ss_unet_key).to(unet.dtype)

        # Load converted kohya_ss unet LoRA back to PEFT
        set_peft_model_state_dict(unet, unet_lora_state_dict)

        if args.half:
            unet.to(torch.float16)

        # Save unet result
        unet.save_pretrained(
            os.path.join(args.dump_path, "unet"),
        )
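To make the key translation above concrete, here is a small standalone snippet (not part of the diff) that traces one text encoder key through the same replace chain. The input key is a hypothetical but typical PEFT LoRA key; the output follows the kohya_ss naming scheme.

# Standalone illustration of the key renaming performed by the script above.
# The example key is hypothetical but follows the PEFT naming scheme.
peft_key = "base_model.model.text_model.encoder.layers.0.self_attn.q_proj.lora_A.weight"

kohya_key = peft_key.replace("base_model.model", "lora_te")
kohya_key = kohya_key.replace("lora_A", "lora_down")
kohya_key = kohya_key.replace("lora_B", "lora_up")
# Turn every "." into "_" except the last two, which separate the
# lora_down/lora_up sub-module and the parameter name.
kohya_key = kohya_key.replace(".", "_", kohya_key.count(".") - 2)

print(kohya_key)
# -> lora_te_text_model_encoder_layers_0_self_attn_q_proj.lora_down.weight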
3 changes: 2 additions & 1 deletion examples/lora_dreambooth/requirements.txt
@@ -7,4 +7,5 @@ datasets
diffusers
Pillow
torchvision
huggingface_hub
safetensors
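After conversion, the script writes two ordinary PEFT adapter folders, dump_path/text_encoder and dump_path/unet (it is invoked as python convert_kohya_ss_sd_lora_to_peft.py --sd_checkpoint <base model> --kohya_lora_path <lora.safetensors> --dump_path <output dir> [--half]). Below is a minimal loading sketch; the base model id and output directory names are assumptions for illustration, not part of this PR.

# Minimal sketch of loading the converted adapters (model id and paths are assumptions).
from diffusers import StableDiffusionPipeline
from peft import PeftModel

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

# Wrap the text encoder and unet with the converted PEFT adapters.
pipe.text_encoder = PeftModel.from_pretrained(pipe.text_encoder, "converted_lora/text_encoder")
pipe.unet = PeftModel.from_pretrained(pipe.unet, "converted_lora/unet")

image = pipe("a photo of sks dog", num_inference_steps=30).images[0]
image.save("kohya_lora_test.png")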
192 changes: 185 additions & 7 deletions src/peft/tuners/lora.py
@@ -17,7 +17,7 @@
import warnings
from dataclasses import asdict, dataclass, field
from enum import Enum
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
@@ -262,6 +262,12 @@ def _create_new_module(self, lora_config, adapter_name, target):
embedding_kwargs.pop("fan_in_fan_out", None)
in_features, out_features = target.num_embeddings, target.embedding_dim
new_module = Embedding(adapter_name, in_features, out_features, **embedding_kwargs)
elif isinstance(target, torch.nn.Conv2d):
out_channels, in_channels = target.weight.size()[:2]
kernel_size = target.weight.size()[2:]
stride = target.stride
padding = target.padding
new_module = Conv2d(adapter_name, in_channels, out_channels, kernel_size, stride, padding, **kwargs)
else:
if isinstance(target, torch.nn.Linear):
in_features, out_features = target.in_features, target.out_features
@@ -303,7 +309,15 @@ def _find_and_replace(self, adapter_name):
is_target_modules_in_base_model = True
parent, target, target_name = _get_submodules(self.model, key)

if isinstance(target, LoraLayer) and isinstance(target, torch.nn.Conv2d):
target.update_layer_conv2d(
adapter_name,
lora_config.r,
lora_config.lora_alpha,
lora_config.lora_dropout,
lora_config.init_lora_weights,
)
elif isinstance(target, LoraLayer):
target.update_layer(
adapter_name,
lora_config.r,
@@ -487,11 +501,7 @@ def mark_only_lora_as_trainable(model: nn.Module, bias: str = "none") -> None:


class LoraLayer:
def __init__(self, in_features: int, out_features: int, **kwargs):
self.r = {}
self.lora_alpha = {}
self.scaling = {}
@@ -506,6 +516,7 @@ def __init__(
self.disable_adapters = False
self.in_features = in_features
self.out_features = out_features
self.kwargs = kwargs

def update_layer(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights):
self.r[adapter_name] = r
@@ -525,6 +536,31 @@ def update_layer(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights):
self.reset_lora_parameters(adapter_name)
self.to(self.weight.device)

def update_layer_conv2d(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights):
self.r[adapter_name] = r
self.lora_alpha[adapter_name] = lora_alpha
if lora_dropout > 0.0:
lora_dropout_layer = nn.Dropout(p=lora_dropout)
else:
lora_dropout_layer = nn.Identity()

self.lora_dropout.update(nn.ModuleDict({adapter_name: lora_dropout_layer}))
# Actual trainable parameters
if r > 0:
kernel_size = self.kwargs["kernel_size"]
stride = self.kwargs["stride"]
padding = self.kwargs["padding"]
self.lora_A.update(
nn.ModuleDict({adapter_name: nn.Conv2d(self.in_features, r, kernel_size, stride, padding, bias=False)})
)
self.lora_B.update(
nn.ModuleDict({adapter_name: nn.Conv2d(r, self.out_features, (1, 1), (1, 1), bias=False)})
)
self.scaling[adapter_name] = lora_alpha / r
if init_lora_weights:
self.reset_lora_parameters(adapter_name)
self.to(self.weight.device)

def update_layer_embedding(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights):
self.r[adapter_name] = r
self.lora_alpha[adapter_name] = lora_alpha
@@ -726,6 +762,148 @@ def forward(self, x: torch.Tensor):
return nn.Embedding.forward(self, x)


class Conv2d(nn.Conv2d, LoraLayer):
# Lora implemented in a conv2d layer
def __init__(
self,
adapter_name: str,
in_channels: int,
out_channels: int,
kernel_size: Union[int, Tuple[int]],
stride: Union[int, Tuple[int]] = 1,
padding: Union[int, Tuple[int]] = 0,
r: int = 0,
lora_alpha: int = 1,
lora_dropout: float = 0.0,
**kwargs,
):
init_lora_weights = kwargs.pop("init_lora_weights", True)

nn.Conv2d.__init__(self, in_channels, out_channels, kernel_size, stride, padding)
LoraLayer.__init__(
self,
in_features=in_channels,
out_features=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
)
# Freezing the pre-trained weight matrix
self.weight.requires_grad = False

nn.Conv2d.reset_parameters(self)
self.update_layer_conv2d(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
self.active_adapter = adapter_name

def merge(self):
if self.active_adapter not in self.lora_A.keys():
return
if self.merged:
warnings.warn("Already merged. Nothing to do.")
return
if self.r[self.active_adapter] > 0:
# https://github.com/bmaltais/kohya_ss/blob/feb6728762a8f463d15ba936d189d4c3abfaa1ab/networks/lora.py#L117
if self.weight.size()[2:4] == (1, 1):
# conv2d 1x1
self.weight.data += (
self.lora_B[self.active_adapter].weight.squeeze(3).squeeze(2)
@ self.lora_A[self.active_adapter].weight.squeeze(3).squeeze(2)
).unsqueeze(2).unsqueeze(3) * self.scaling[self.active_adapter]
else:
# conv2d 3x3
self.weight.data += (
F.conv2d(
self.lora_A[self.active_adapter].weight.permute(1, 0, 2, 3),
self.lora_B[self.active_adapter].weight,
).permute(1, 0, 2, 3)
* self.scaling[self.active_adapter]
)
self.merged = True

def unmerge(self):
if self.active_adapter not in self.lora_A.keys():
return
if not self.merged:
warnings.warn("Already unmerged. Nothing to do.")
return
if self.r[self.active_adapter] > 0:
if self.weight.size()[2:4] == (1, 1):
# conv2d 1x1
self.weight.data -= (
self.lora_B[self.active_adapter].weight.squeeze(3).squeeze(2)
@ self.lora_A[self.active_adapter].weight.squeeze(3).squeeze(2)
).unsqueeze(2).unsqueeze(3) * self.scaling[self.active_adapter]
else:
# conv2d 3x3
self.weight.data -= (
F.conv2d(
self.lora_A[self.active_adapter].weight.permute(1, 0, 2, 3),
self.lora_B[self.active_adapter].weight,
).permute(1, 0, 2, 3)
* self.scaling[self.active_adapter]
)
self.merged = False

def forward(self, x: torch.Tensor):
previous_dtype = x.dtype

if self.active_adapter not in self.lora_A.keys():
return F.conv2d(
x,
self.weight,
bias=self.bias,
stride=self.stride,
padding=self.padding,
dilation=self.dilation,
groups=self.groups,
)
if self.disable_adapters:
if self.r[self.active_adapter] > 0 and self.merged:
self.unmerge()
result = F.conv2d(
x,
self.weight,
bias=self.bias,
stride=self.stride,
padding=self.padding,
dilation=self.dilation,
groups=self.groups,
)
elif self.r[self.active_adapter] > 0 and not self.merged:
result = F.conv2d(
x,
self.weight,
bias=self.bias,
stride=self.stride,
padding=self.padding,
dilation=self.dilation,
groups=self.groups,
)

x = x.to(self.lora_A[self.active_adapter].weight.dtype)

result += (
self.lora_B[self.active_adapter](
self.lora_A[self.active_adapter](self.lora_dropout[self.active_adapter](x))
)
* self.scaling[self.active_adapter]
)
else:
result = F.conv2d(
x,
self.weight,
bias=self.bias,
stride=self.stride,
padding=self.padding,
dilation=self.dilation,
groups=self.groups,
)

result = result.to(previous_dtype)

return result


if is_bnb_available():

class Linear8bitLt(bnb.nn.Linear8bitLt, LoraLayer):
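The merge()/unmerge() paths above fold the low-rank pair into the base convolution weight. As a sanity check on the non-1x1 formula (the F.conv2d over permuted LoRA weights), the standalone sketch below verifies numerically that the folded delta weight reproduces the two-step lora_B(lora_A(x)) computation; shapes and values are arbitrary test inputs, and dropout/scaling are omitted.

# Standalone numerical check of the Conv2d LoRA folding formula (arbitrary test shapes).
import torch
import torch.nn as nn
import torch.nn.functional as F

in_ch, out_ch, rank = 4, 8, 2
x = torch.randn(1, in_ch, 16, 16)

lora_A = nn.Conv2d(in_ch, rank, kernel_size=3, stride=1, padding=1, bias=False)
lora_B = nn.Conv2d(rank, out_ch, kernel_size=1, stride=1, bias=False)

# Two-step LoRA path, as in Conv2d.forward when not merged (dropout and scaling omitted).
y_two_step = lora_B(lora_A(x))

# Folded delta weight, same expression as the non-1x1 branch of Conv2d.merge.
delta_w = F.conv2d(lora_A.weight.permute(1, 0, 2, 3), lora_B.weight).permute(1, 0, 2, 3)
y_folded = F.conv2d(x, delta_w, stride=1, padding=1)

assert torch.allclose(y_two_step, y_folded, atol=1e-5)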