Add implementation of LyCORIS LoKr (KronA-like adapter) for SD&SDXL models #978

Merged on Oct 30, 2023
Changes from all commits (37 commits)
41a71e3
Initial commit for LoKr implementation
kovalexal Sep 29, 2023
1762d10
Merged 'main' into 'lokr'
kovalexal Oct 2, 2023
0f60893
Added current implementation of LoKr
kovalexal Oct 3, 2023
e7d6e23
Fixed setting requires_grad for lokr modules
kovalexal Oct 3, 2023
bb45764
Updated initialization of LoKr adapter weights
kovalexal Oct 4, 2023
0c33d8c
Updated docstrings for LoKr params
kovalexal Oct 4, 2023
84b890b
Removed unnecessary comments
kovalexal Oct 4, 2023
fd4a754
Modified sd dreambooth script to be able to train LoRA, LoHa, LoKr ad…
kovalexal Oct 4, 2023
ddfae52
Updated conversion script to incorporate LoKr
kovalexal Oct 4, 2023
7526aa2
Added simple tests for LoKr adapter
kovalexal Oct 4, 2023
fd5daad
Merge branch 'main' into lokr
kovalexal Oct 9, 2023
8dc5e98
Modified 'merged' property
kovalexal Oct 9, 2023
c1cef38
Removed duplicated comments
kovalexal Oct 9, 2023
ad525e4
Replaced wrong keys for LoKr
kovalexal Oct 10, 2023
ba45881
Refactored LoHaModel and LoKrModel
kovalexal Oct 10, 2023
2401bf1
Refactored LoHaModel and LoKrModel again
kovalexal Oct 10, 2023
1fad986
Refactored LoHaLayer and LoKrLayer a bit
kovalexal Oct 10, 2023
39e87ce
Removed unnecessary comments
kovalexal Oct 10, 2023
4171c64
Addressed comments on _available_adapters property
kovalexal Oct 13, 2023
b24bdbf
Replaced te with text_encoder
kovalexal Oct 13, 2023
d8f2a83
Apply suggestions from code review
kovalexal Oct 13, 2023
6465597
Changed exception type raised when creating adapter for unsupported l…
kovalexal Oct 13, 2023
a63d249
Added additional tests for use_effective_conv2d/decompose_both/decomp…
kovalexal Oct 13, 2023
2b70fc0
Removed classmethod
kovalexal Oct 13, 2023
30a5a85
Merge branch 'main' into lokr
kovalexal Oct 13, 2023
3033a75
Addressed conversion script review comments
kovalexal Oct 13, 2023
299de88
Replaced factorization docstring
kovalexal Oct 13, 2023
d518728
LyCORIS -> Lycoris
kovalexal Oct 13, 2023
c9a8457
Merged 'main' into 'lokr'
kovalexal Oct 24, 2023
63aba4e
Updated README to include LoKr adapter
kovalexal Oct 25, 2023
6700baf
Addressed some code review comments
kovalexal Oct 25, 2023
fa6b522
Addressed some code review comments
kovalexal Oct 25, 2023
e76182f
Addressed some code review comments
kovalexal Oct 25, 2023
25077b2
Updated check_target_modules docstring, increased test coverage
kovalexal Oct 27, 2023
9f05024
Added delete_adapter method for LoKr and LoHa
kovalexal Oct 27, 2023
f6e7335
Fixed typo in delete_adapter
kovalexal Oct 27, 2023
69ae74c
Provide default value for
kovalexal Oct 29, 2023
9 changes: 5 additions & 4 deletions README.md
@@ -33,6 +33,7 @@ Supported methods:
6. $(IA)^3$: [Few-Shot Parameter-Efficient Fine-Tuning is Better and Cheaper than In-Context Learning](https://arxiv.org/abs/2205.05638)
7. MultiTask Prompt Tuning: [Multitask Prompt Tuning Enables Parameter-Efficient Transfer Learning](https://arxiv.org/abs/2303.02861)
8. LoHa: [FedPara: Low-Rank Hadamard Product for Communication-Efficient Federated Learning](https://arxiv.org/abs/2108.06098)
9. LoKr: [KronA: Parameter Efficient Tuning with Kronecker Adapter](https://arxiv.org/abs/2212.10650) based on [Navigating Text-To-Image Customization: From LyCORIS Fine-Tuning to Model Evaluation](https://arxiv.org/abs/2309.14859) implementation
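A minimal sketch of the decomposition assumed here (standard LyCORIS LoKr formulation; the $\alpha/r$ scaling and factor naming mirror the `lokr_*` tensors handled in the conversion script below and are not spelled out verbatim in this diff):

$$W = W_0 + \frac{\alpha}{r}\,(w_1 \otimes w_2), \qquad w_1 = w_{1a} w_{1b}\ (\text{optional, decompose\_both}), \qquad w_2 = w_{2a} w_{2b}\ (\text{rank-}r\ \text{factors})$$

The shapes of $w_1$ and $w_2$ come from splitting the original weight dimensions into two factors, controlled by `decompose_factor`.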

## Getting started

@@ -134,7 +135,7 @@ Try out the 🤗 Gradio Space which should run seamlessly on a T4 instance:
**NEW** ✨ Multi Adapter support and combining multiple LoRA adapters in a weighted combination
![peft lora dreambooth weighted adapter](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/peft/weighted_adapter_dreambooth_lora.png)

**NEW** ✨ Dreambooth training for Stable Diffusion using LoHa adapter [`examples/stable_diffusion/train_dreambooth_loha.py`](examples/stable_diffusion/train_dreambooth_loha.py)
**NEW** ✨ Dreambooth training for Stable Diffusion using LoHa and LoKr adapters [`examples/stable_diffusion/train_dreambooth.py`](examples/stable_diffusion/train_dreambooth.py)

### Parameter Efficient Tuning of LLMs for RLHF components such as Ranker and Policy
- Here is an example in [trl](https://github.com/lvwerra/trl) library using PEFT+INT8 for tuning policy model: [gpt2-sentiment_peft.py](https://github.com/lvwerra/trl/blob/main/examples/sentiment/scripts/gpt2-sentiment_peft.py) and corresponding [Blog](https://huggingface.co/blog/trl-peft)
@@ -273,9 +274,9 @@ An example is provided in `~examples/causal_language_modeling/peft_lora_clm_acce

### Text-to-Image Generation

| Model | LoRA | LoHa | Prefix Tuning | P-Tuning | Prompt Tuning | IA3 |
| --------- | ---- | ---- | ---- | ---- | ---- | ---- |
| Stable Diffusion | ✅ | ✅ | | | |
| Model | LoRA | LoHa | LoKr | Prefix Tuning | P-Tuning | Prompt Tuning | IA3 |
| --------- | ---- | ---- | ---- | ---- | ---- | ---- | ---- |
| Stable Diffusion | ✅ | ✅ | ✅ | | | |
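As a usage sketch, attaching a LoKr adapter to a Stable Diffusion UNet could look like the following (the base checkpoint id and `target_modules` are illustrative assumptions, not defaults shipped by this PR):

```python
from diffusers import UNet2DConditionModel
from peft import LoKrConfig, get_peft_model

# Load only the UNet of a Stable Diffusion checkpoint (illustrative model id)
unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"
)

config = LoKrConfig(
    r=8,
    alpha=8,
    target_modules=["to_q", "to_k", "to_v", "to_out.0"],  # illustrative attention projections
    rank_dropout=0.0,
    module_dropout=0.0,
    use_effective_conv2d=False,  # only relevant when Conv2d modules are targeted
    decompose_both=False,        # keep the first Kronecker factor as a full (small) matrix
    decompose_factor=-1,         # pick the most balanced factorization of each dimension
)

unet = get_peft_model(unet, config)
unet.print_trainable_parameters()  # only the LoKr factors are trainable
```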


### Image Classification
185 changes: 176 additions & 9 deletions examples/stable_diffusion/convert_sd_adapter_to_peft.py
@@ -1,16 +1,20 @@
import argparse
import json
import logging
import os
from collections import Counter
from dataclasses import dataclass
from operator import attrgetter
from typing import Dict, List, Optional, Union

import safetensors
import torch
import torch.nn as nn
from diffusers import UNet2DConditionModel
from transformers import CLIPTextModel

from peft import LoHaConfig, LoraConfig, PeftType, get_peft_model, set_peft_model_state_dict
from peft import LoHaConfig, LoKrConfig, LoraConfig, PeftType, get_peft_model, set_peft_model_state_dict
from peft.tuners.lokr.layer import factorization


# Default kohya_ss LoRA replacement modules
@@ -74,7 +78,48 @@ def peft_state_dict(self) -> Dict[str, torch.Tensor]:
return state_dict


def construct_peft_loraconfig(info: Dict[str, LoRAInfo]) -> LoraConfig:
@dataclass
class LoKrInfo:
kohya_key: str
peft_key: str
alpha: Optional[float] = None
rank: Optional[int] = None
lokr_w1: Optional[torch.Tensor] = None
lokr_w1_a: Optional[torch.Tensor] = None
lokr_w1_b: Optional[torch.Tensor] = None
lokr_w2: Optional[torch.Tensor] = None
lokr_w2_a: Optional[torch.Tensor] = None
lokr_w2_b: Optional[torch.Tensor] = None
lokr_t2: Optional[torch.Tensor] = None

def peft_state_dict(self) -> Dict[str, torch.Tensor]:
if (self.lokr_w1 is None) and ((self.lokr_w1_a is None) or (self.lokr_w1_b is None)):
raise ValueError("Either lokr_w1 or both lokr_w1_a and lokr_w1_b should be provided")

if (self.lokr_w2 is None) and ((self.lokr_w2_a is None) or (self.lokr_w2_b is None)):
raise ValueError("Either lokr_w2 or both lokr_w2_a and lokr_w2_b should be provided")

state_dict = {}

if self.lokr_w1 is not None:
state_dict[f"base_model.model.{self.peft_key}.lokr_w1"] = self.lokr_w1
elif self.lokr_w1_a is not None:
state_dict[f"base_model.model.{self.peft_key}.lokr_w1_a"] = self.lokr_w1_a
state_dict[f"base_model.model.{self.peft_key}.lokr_w1_b"] = self.lokr_w1_b

if self.lokr_w2 is not None:
state_dict[f"base_model.model.{self.peft_key}.lokr_w2"] = self.lokr_w2
elif self.lokr_w2_a is not None:
state_dict[f"base_model.model.{self.peft_key}.lokr_w2_a"] = self.lokr_w2_a
state_dict[f"base_model.model.{self.peft_key}.lokr_w2_b"] = self.lokr_w2_b

if self.lokr_t2 is not None:
state_dict[f"base_model.model.{self.peft_key}.lokr_t2"] = self.lokr_t2

return state_dict
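
For illustration, a hypothetical single-module `LoKrInfo` (module name, shapes, and tensors are invented) and the PEFT state-dict keys its `peft_state_dict()` yields:

```python
import torch  # already imported at the top of this script

info = LoKrInfo(
    kohya_key="lora_unet_down_blocks_0_attentions_0_proj_in",
    peft_key="down_blocks.0.attentions.0.proj_in",
    alpha=8.0,
    rank=8,
    lokr_w1=torch.randn(10, 8),    # first Kronecker factor, stored whole
    lokr_w2_a=torch.randn(32, 8),  # second factor, low-rank part A
    lokr_w2_b=torch.randn(8, 40),  # second factor, low-rank part B
)

print(sorted(info.peft_state_dict().keys()))
# ['base_model.model.down_blocks.0.attentions.0.proj_in.lokr_w1',
#  'base_model.model.down_blocks.0.attentions.0.proj_in.lokr_w2_a',
#  'base_model.model.down_blocks.0.attentions.0.proj_in.lokr_w2_b']
```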


def construct_peft_loraconfig(info: Dict[str, LoRAInfo], **kwargs) -> LoraConfig:
"""Constructs LoraConfig from data extracted from adapter checkpoint

Args:
@@ -113,7 +158,7 @@ def construct_peft_loraconfig(info: Dict[str, LoRAInfo]) -> LoraConfig:
return config


def construct_peft_lohaconfig(info: Dict[str, LoHaInfo]) -> LoHaConfig:
def construct_peft_lohaconfig(info: Dict[str, LoHaInfo], **kwargs) -> LoHaConfig:
"""Constructs LoHaConfig from data extracted from adapter checkpoint

Args:
@@ -156,6 +201,77 @@ def construct_peft_lohaconfig(info: Dict[str, LoHaInfo]) -> LoHaConfig:
return config


def construct_peft_lokrconfig(info: Dict[str, LoKrInfo], decompose_factor: int = -1, **kwargs) -> LoKrConfig:
"""Constructs LoKrConfig from data extracted from adapter checkpoint

Args:
info (Dict[str, LoKrInfo]): Information extracted from adapter checkpoint

Returns:
LoKrConfig: config for constructing LoKr
"""

# Unpack all ranks and alphas
ranks = {x[0]: x[1].rank for x in info.items()}
alphas = {x[0]: x[1].alpha or x[1].rank for x in info.items()}

# Determine which modules need to be transformed
target_modules = sorted(info.keys())

# Determine most common rank and alpha
r = int(Counter(ranks.values()).most_common(1)[0][0])
alpha = Counter(alphas.values()).most_common(1)[0][0]

# Determine which modules have different rank and alpha
rank_pattern = dict(sorted(filter(lambda x: x[1] != r, ranks.items()), key=lambda x: x[0]))
alpha_pattern = dict(sorted(filter(lambda x: x[1] != alpha, alphas.items()), key=lambda x: x[0]))

# Determine whether any of the modules use the effective conv2d decomposition
use_effective_conv2d = any(((val.lokr_t2 is not None) for val in info.values()))

# decompose_both should be enabled if any w1 matrix in any layer is decomposed into 2
decompose_both = any((val.lokr_w1_a is not None and val.lokr_w1_b is not None) for val in info.values())

# Determining the decompose factor is a bit tricky (but it is most often -1); see the factorization sketch after this function
# Check that decompose_factor matches the provided value
for val in info.values():
# Determine shape of first matrix
if val.lokr_w1 is not None:
w1_shape = tuple(val.lokr_w1.shape)
else:
w1_shape = (val.lokr_w1_a.shape[0], val.lokr_w1_b.shape[1])

# Determine shape of second matrix
if val.lokr_w2 is not None:
w2_shape = tuple(val.lokr_w2.shape[:2])
elif val.lokr_t2 is not None:
w2_shape = (val.lokr_w2_a.shape[1], val.lokr_w2_b.shape[1])
else:
# We may be iterating over a Conv2d layer, whose second shape dimension is multiplied by ksize^2
w2_shape = (val.lokr_w2_a.shape[0], val.lokr_w2_b.shape[1])

# We need to check whether decompose_factor is really -1 or not
shape = (w1_shape[0], w2_shape[0])
if factorization(shape[0] * shape[1], factor=-1) != shape:
raise ValueError("Cannot infer decompose_factor, probably it is not equal to -1")

config = LoKrConfig(
r=r,
alpha=alpha,
target_modules=target_modules,
rank_dropout=0.0,
module_dropout=0.0,
init_weights=False,
rank_pattern=rank_pattern,
alpha_pattern=alpha_pattern,
use_effective_conv2d=use_effective_conv2d,
decompose_both=decompose_both,
decompose_factor=decompose_factor,
)

return config
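
For intuition about the `decompose_factor` check in the loop above, here is a self-contained sketch of a balanced factorization. It mirrors the idea behind the imported `factorization` helper when `factor=-1`, not necessarily its exact implementation:

```python
def balanced_factorization(dimension: int) -> tuple:
    # Split `dimension` into (m, n) with m * n == dimension and m <= n,
    # keeping the two factors as close to each other as possible.
    m = int(dimension ** 0.5)
    while dimension % m != 0:
        m -= 1
    return m, dimension // m


print(balanced_factorization(768))  # (24, 32)
print(balanced_factorization(320))  # (16, 20)
```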


def combine_peft_state_dict(info: Dict[str, Union[LoRAInfo, LoHaInfo]]) -> Dict[str, torch.Tensor]:
result = {}
for key_info in info.values():
@@ -180,7 +296,7 @@ def detect_adapter_type(keys: List[str]) -> PeftType:
elif any(x in key for x in ["lokr_w1", "lokr_w2", "lokr_t1", "lokr_t2"]):
# LoKr may have the following keys:
# lokr_w1, lokr_w2, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t1, lokr_t2
raise ValueError("Currently LoKr adapters are not implemented")
return PeftType.LOKR
elif "diff" in key:
raise ValueError("Currently full diff adapters are not implemented")
else:
@@ -231,22 +347,40 @@ def detect_adapter_type(keys: List[str]) -> PeftType:
}
)

# Store conversion info (model_type -> peft_key -> LoRAInfo | LoHaInfo)
adapter_info: Dict[str, Dict[str, Union[LoRAInfo, LoHaInfo]]] = {
# Store conversion info (model_type -> peft_key -> LoRAInfo | LoHaInfo | LoKrInfo)
adapter_info: Dict[str, Dict[str, Union[LoRAInfo, LoHaInfo, LoKrInfo]]] = {
"text_encoder": {},
"unet": {},
}

# Store decompose_factor for LoKr
decompose_factor = -1

# Open adapter checkpoint
with safetensors.safe_open(args.adapter_path, framework="pt", device="cpu") as f:
# Extract information about adapter structure
metadata = f.metadata()

# It may be difficult to determine the rank for LoKr adapters:
# if the checkpoint was trained with a large rank, the rank may not be used during weight creation at all,
# so we need to get it from the checkpoint metadata (along with decompose_factor)
rank, conv_rank = None, None
if metadata is not None:
rank = metadata.get("ss_network_dim", None)
rank = int(rank) if rank else None
if "ss_network_args" in metadata:
network_args = json.loads(metadata["ss_network_args"])
conv_rank = network_args.get("conv_dim", None)
conv_rank = int(conv_rank) if conv_rank else rank
decompose_factor = network_args.get("factor", -1)
decompose_factor = int(decompose_factor)
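# Illustrative (hypothetical) metadata a kohya-ss LoKr checkpoint might carry:
#   ss_network_dim:  "8"
#   ss_network_args: '{"conv_dim": "8", "factor": "-1", "algo": "lokr"}'
# The code above would turn this into rank=8, conv_rank=8, decompose_factor=-1.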

# Detect adapter type based on keys
adapter_type = detect_adapter_type(f.keys())
adapter_info_cls = {
PeftType.LORA: LoRAInfo,
PeftType.LOHA: LoHaInfo,
PeftType.LOKR: LoKrInfo,
}[adapter_type]

# Iterate through available info and unpack all the values
@@ -255,9 +389,9 @@ def detect_adapter_type(keys: List[str]) -> PeftType:

# Find which model this key belongs to
if kohya_key.startswith(PREFIX_TEXT_ENCODER):
model_type = "text_encoder"
model_type, model = "text_encoder", text_encoder
elif kohya_key.startswith(PREFIX_UNET):
model_type = "unet"
model_type, model = "unet", unet
else:
raise ValueError(f"Cannot determine model for key: {key}")

@@ -266,6 +400,9 @@ def detect_adapter_type(keys: List[str]) -> PeftType:
raise ValueError(f"Cannot find corresponding key for diffusers/transformers model: {kohya_key}")
peft_key = models_keys[kohya_key]

# Retrieve corresponding layer of model
layer = attrgetter(peft_key)(model)

# Create a corresponding adapter info
if peft_key not in adapter_info[model_type]:
adapter_info[model_type][peft_key] = adapter_info_cls(kohya_key=kohya_key, peft_key=peft_key)
@@ -305,18 +442,48 @@ def detect_adapter_type(keys: List[str]) -> PeftType:
elif kohya_type == "hada_t2":
adapter_info[model_type][peft_key].hada_t2 = tensor
adapter_info[model_type][peft_key].rank = tensor.shape[0]
elif kohya_type == "lokr_t2":
adapter_info[model_type][peft_key].lokr_t2 = tensor
adapter_info[model_type][peft_key].rank = tensor.shape[0]
elif kohya_type == "lokr_w1":
adapter_info[model_type][peft_key].lokr_w1 = tensor
if isinstance(layer, nn.Linear) or (
isinstance(layer, nn.Conv2d) and tuple(layer.weight.shape[2:]) == (1, 1)
):
adapter_info[model_type][peft_key].rank = rank
elif isinstance(layer, nn.Conv2d):
adapter_info[model_type][peft_key].rank = conv_rank
elif kohya_type == "lokr_w2":
adapter_info[model_type][peft_key].lokr_w2 = tensor
if isinstance(layer, nn.Linear) or (
isinstance(layer, nn.Conv2d) and tuple(layer.weight.shape[2:]) == (1, 1)
):
adapter_info[model_type][peft_key].rank = rank
elif isinstance(layer, nn.Conv2d):
adapter_info[model_type][peft_key].rank = conv_rank
elif kohya_type == "lokr_w1_a":
adapter_info[model_type][peft_key].lokr_w1_a = tensor
adapter_info[model_type][peft_key].rank = tensor.shape[1]
elif kohya_type == "lokr_w1_b":
adapter_info[model_type][peft_key].lokr_w1_b = tensor
adapter_info[model_type][peft_key].rank = tensor.shape[0]
elif kohya_type == "lokr_w2_a":
adapter_info[model_type][peft_key].lokr_w2_a = tensor
elif kohya_type == "lokr_w2_b":
adapter_info[model_type][peft_key].lokr_w2_b = tensor
else:
raise ValueError(f"Unknown weight name in key: {key} - {kohya_type}")

# Get function which will create adapter config based on extracted info
construct_config_fn = {
PeftType.LORA: construct_peft_loraconfig,
PeftType.LOHA: construct_peft_lohaconfig,
PeftType.LOKR: construct_peft_lokrconfig,
}[adapter_type]

# Process each model sequentially
for model, model_name in [(text_encoder, "text_encoder"), (unet, "unet")]:
config = construct_config_fn(adapter_info[model_name])
config = construct_config_fn(adapter_info[model_name], decompose_factor=decompose_factor)

# Output warning for LoHa with use_effective_conv2d
if (