PEFT Integration for Text Encoder to handle multiple alphas/ranks, disable/enable adapters and support for multiple adapters #5147

Merged on Sep 27, 2023 (63 commits)

Changes from all commits:
ba24f2a - more fixes (younesbelkada, Sep 14, 2023)
c17634c - up (younesbelkada, Sep 15, 2023)
2a6e535 - up (younesbelkada, Sep 15, 2023)
01f6d1d - style (younesbelkada, Sep 15, 2023)
5a150b2 - add in setup (younesbelkada, Sep 15, 2023)
961e776 - oops (younesbelkada, Sep 15, 2023)
cdbe739 - more changes (younesbelkada, Sep 15, 2023)
691368b - v1 rzfactor CI (younesbelkada, Sep 18, 2023)
7918851 - Apply suggestions from code review (younesbelkada, Sep 18, 2023)
14db139 - few todos (younesbelkada, Sep 18, 2023)
c06c40b - Merge branch 'main' into peftpart-1 (younesbelkada, Sep 18, 2023)
d56a14d - protect torch import (younesbelkada, Sep 18, 2023)
ec87c19 - style (younesbelkada, Sep 18, 2023)
40a6028 - fix fuse text encoder (younesbelkada, Sep 18, 2023)
0c62ef3 - Merge remote-tracking branch 'upstream/main' into peftpart-1 (younesbelkada, Sep 18, 2023)
c4295c9 - Update src/diffusers/loaders.py (younesbelkada, Sep 19, 2023)
4162ddf - replace with `recurse_replace_peft_layers` (younesbelkada, Sep 19, 2023)
1d13f40 - keep old modules for BC (younesbelkada, Sep 19, 2023)
78a860d - adjustments on `adjust_lora_scale_text_encoder` (younesbelkada, Sep 19, 2023)
78a01d5 - Merge branch 'main' into peftpart-1 (younesbelkada, Sep 19, 2023)
ecbc714 - Merge remote-tracking branch 'upstream/main' into peftpart-1 (younesbelkada, Sep 19, 2023)
9d650c9 - Merge branch 'peftpart-1' of https://github.com/younesbelkada/diffuse… (younesbelkada, Sep 19, 2023)
6f1adcd - nit (younesbelkada, Sep 19, 2023)
f890906 - move tests (younesbelkada, Sep 19, 2023)
f8e87f6 - add conversion utils (younesbelkada, Sep 19, 2023)
3ba2d4e - Merge remote-tracking branch 'upstream/main' into peftpart-1 (younesbelkada, Sep 19, 2023)
dc83fa0 - remove unneeded methods (younesbelkada, Sep 19, 2023)
b83fcba - use class method instead (younesbelkada, Sep 19, 2023)
74e33a9 - oops (younesbelkada, Sep 19, 2023)
9cb8563 - use `base_version` (younesbelkada, Sep 19, 2023)
c90f85d - fix examples (younesbelkada, Sep 19, 2023)
40a4894 - fix CI (younesbelkada, Sep 19, 2023)
ea05959 - fix weird error with python 3.8 (younesbelkada, Sep 19, 2023)
27e3da6 - fix (younesbelkada, Sep 19, 2023)
3d7c567 - better fix (younesbelkada, Sep 19, 2023)
d01a292 - style (younesbelkada, Sep 19, 2023)
e836b14 - Apply suggestions from code review (younesbelkada, Sep 20, 2023)
cb48405 - Apply suggestions from code review (younesbelkada, Sep 20, 2023)
325462d - add comment (younesbelkada, Sep 20, 2023)
b412adc - Apply suggestions from code review (younesbelkada, Sep 20, 2023)
b72ef23 - conv2d support for recurse remove (younesbelkada, Sep 20, 2023)
e072655 - added docstrings (younesbelkada, Sep 20, 2023)
bd46ae9 - more docstring (younesbelkada, Sep 20, 2023)
724b52b - add deprecate (younesbelkada, Sep 20, 2023)
5e6f343 - revert (younesbelkada, Sep 20, 2023)
71650d4 - try to fix merge conflicts (younesbelkada, Sep 20, 2023)
920333f - Merge remote-tracking branch 'upstream/main' into peftpart-1 (younesbelkada, Sep 20, 2023)
0985d17 - peft integration features for text encoder (pacman100, Sep 22, 2023)
ece3b02 - Merge branch 'main' into peftpart-1 (pacman100, Sep 22, 2023)
01a15cc - fix bug (pacman100, Sep 22, 2023)
080db75 - Merge branch 'main' into smangrul/peft-integration (pacman100, Sep 22, 2023)
ffbac30 - fix code quality (pacman100, Sep 22, 2023)
916c31a - Apply suggestions from code review (pacman100, Sep 25, 2023)
5de0f1b - fix bugs (pacman100, Sep 25, 2023)
c32872e - Merge branch 'smangrul/peft-integration' of https://github.com/huggin… (pacman100, Sep 25, 2023)
0acb58c - Apply suggestions from code review (pacman100, Sep 25, 2023)
1ca4c62 - address comments (pacman100, Sep 26, 2023)
7c37788 - fix code quality (pacman100, Sep 26, 2023)
2fcf174 - address comments (pacman100, Sep 26, 2023)
a1f0128 - address comments (pacman100, Sep 26, 2023)
7b2ccff - Merge branch 'main' into smangrul/peft-integration (patrickvonplaten, Sep 27, 2023)
fd9bcfe - Apply suggestions from code review (patrickvonplaten, Sep 27, 2023)
9916ac6 - find and replace (patrickvonplaten, Sep 27, 2023)
117 changes: 105 additions & 12 deletions src/diffusers/loaders.py
@@ -35,18 +35,23 @@
convert_state_dict_to_diffusers,
convert_state_dict_to_peft,
deprecate,
get_adapter_name,
get_peft_kwargs,
is_accelerate_available,
is_omegaconf_available,
is_peft_available,
is_transformers_available,
logging,
recurse_remove_peft_layers,
scale_lora_layers,
set_adapter_layers,
set_weights_and_activate_adapters,
)
from .utils.import_utils import BACKENDS_MAPPING


if is_transformers_available():
from transformers import CLIPTextModel, CLIPTextModelWithProjection
from transformers import CLIPTextModel, CLIPTextModelWithProjection, PreTrainedModel

if is_accelerate_available():
from accelerate import init_empty_weights
@@ -1100,7 +1105,9 @@ class LoraLoaderMixin:
num_fused_loras = 0
use_peft_backend = USE_PEFT_BACKEND

def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
def load_lora_weights(
self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
):
"""
Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet` and
`self.text_encoder`.
@@ -1120,6 +1127,9 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
See [`~loaders.LoraLoaderMixin.lora_state_dict`].
kwargs (`dict`, *optional*):
See [`~loaders.LoraLoaderMixin.lora_state_dict`].
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
"""
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
@@ -1143,6 +1153,7 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
text_encoder=self.text_encoder,
lora_scale=self.lora_scale,
low_cpu_mem_usage=low_cpu_mem_usage,
adapter_name=adapter_name,
Review comment (Member): Let's keep the private variable at the end.
_pipeline=self,
)
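As a hedged usage sketch (not part of the diff), the new `adapter_name` argument lets callers register each LoRA checkpoint under an explicit name so it can be switched or mixed later; the LoRA repository ids below are hypothetical placeholders and the PEFT backend is assumed to be enabled:

from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
# Each checkpoint is loaded as a separately named adapter instead of overwriting the previous one.
pipe.load_lora_weights("some-user/toy-lora", adapter_name="toy")
pipe.load_lora_weights("some-user/pixel-lora", adapter_name="pixel")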

@@ -1500,6 +1511,7 @@ def load_lora_into_text_encoder(
prefix=None,
lora_scale=1.0,
low_cpu_mem_usage=None,
adapter_name=None,
_pipeline=None,
):
"""
@@ -1523,6 +1535,9 @@ def load_lora_into_text_encoder(
tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
argument to `True` will raise an error.
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
"""
low_cpu_mem_usage = low_cpu_mem_usage if low_cpu_mem_usage is not None else _LOW_CPU_MEM_USAGE_DEFAULT

@@ -1584,19 +1599,22 @@ def load_lora_into_text_encoder(
if cls.use_peft_backend:
from peft import LoraConfig

lora_rank = list(rank.values())[0]
# By definition, the scale should be alpha divided by rank.
# https://github.com/huggingface/peft/blob/ba0477f2985b1ba311b83459d29895c809404e99/src/peft/tuners/lora/layer.py#L71
alpha = lora_scale * lora_rank
lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict)

target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"]
if patch_mlp:
target_modules += ["fc1", "fc2"]
lora_config = LoraConfig(**lora_config_kwargs)

# TODO: support multi alpha / rank: https://github.com/huggingface/peft/pull/873
lora_config = LoraConfig(r=lora_rank, target_modules=target_modules, lora_alpha=alpha)
# adapter_name
if adapter_name is None:
adapter_name = get_adapter_name(text_encoder)

text_encoder.load_adapter(adapter_state_dict=text_encoder_lora_state_dict, peft_config=lora_config)
# inject LoRA layers and load the state dict
text_encoder.load_adapter(
adapter_name=adapter_name,
adapter_state_dict=text_encoder_lora_state_dict,
peft_config=lora_config,
)
# scale LoRA layers with `lora_scale`
scale_lora_layers(text_encoder, weight=lora_scale)

is_model_cpu_offload = False
is_sequential_cpu_offload = False
@@ -2178,6 +2196,81 @@ def unfuse_text_encoder_lora(text_encoder):

self.num_fused_loras -= 1

def set_adapter_for_text_encoder(
self,
adapter_names: Union[List[str], str],
text_encoder: Optional[PreTrainedModel] = None,
text_encoder_weights: List[float] = None,
):
"""
Sets the adapter layers for the text encoder.

Args:
adapter_names (`List[str]` or `str`):
The names of the adapters to use.
text_encoder (`torch.nn.Module`, *optional*):
The text encoder module to set the adapter layers for. If `None`, it will try to get the `text_encoder`
attribute.
text_encoder_weights (`List[float]`, *optional*):
The weights to use for the text encoder. If `None`, the weights are set to `1.0` for all the adapters.
"""
if not self.use_peft_backend:
raise ValueError("PEFT backend is required for this method.")

def process_weights(adapter_names, weights):
if weights is None:
weights = [1.0] * len(adapter_names)
elif isinstance(weights, float):
weights = [weights]

if len(adapter_names) != len(weights):
raise ValueError(
f"Length of adapter names {len(adapter_names)} is not equal to the length of the weights {len(weights)}"
)
return weights

adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names
text_encoder_weights = process_weights(adapter_names, text_encoder_weights)
text_encoder = text_encoder or getattr(self, "text_encoder", None)
if text_encoder is None:
raise ValueError(
"The pipeline does not have a default `pipe.text_encoder` class. Please make sure to pass a `text_encoder` instead."
)
set_weights_and_activate_adapters(text_encoder, adapter_names, text_encoder_weights)
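A minimal sketch of calling the method once two adapters have been loaded; the adapter names "toy" and "pixel" are hypothetical and the PEFT backend is assumed to be active:

# Activate a single adapter on the text encoder.
pipe.set_adapter_for_text_encoder("toy")
# Or activate both adapters and weight their contributions differently.
pipe.set_adapter_for_text_encoder(["toy", "pixel"], text_encoder_weights=[0.8, 0.5])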

def disable_lora_for_text_encoder(self, text_encoder: Optional[PreTrainedModel] = None):
"""
Disables the LoRA layers for the text encoder.

Args:
text_encoder (`torch.nn.Module`, *optional*):
The text encoder module to disable the LoRA layers for. If `None`, it will try to get the
`text_encoder` attribute.
"""
if not self.use_peft_backend:
raise ValueError("PEFT backend is required for this method.")

text_encoder = text_encoder or getattr(self, "text_encoder", None)
if text_encoder is None:
raise ValueError("Text Encoder not found.")
set_adapter_layers(text_encoder, enabled=False)

def enable_lora_for_text_encoder(self, text_encoder: Optional[PreTrainedModel] = None):
"""
Enables the LoRA layers for the text encoder.

Args:
text_encoder (`torch.nn.Module`, *optional*):
The text encoder module to enable the LoRA layers for. If `None`, it will try to get the `text_encoder`
attribute.
"""
if not self.use_peft_backend:
raise ValueError("PEFT backend is required for this method.")
text_encoder = text_encoder or getattr(self, "text_encoder", None)
if text_encoder is None:
raise ValueError("Text Encoder not found.")
set_adapter_layers(text_encoder, enabled=True)
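A short sketch of toggling LoRA on the text encoder without unloading the adapters; the pipeline and prompt are placeholders and the PEFT backend is assumed:

pipe.disable_lora_for_text_encoder()   # prompts are encoded with the base weights only
image_without_lora = pipe("a photo of a cat").images[0]
pipe.enable_lora_for_text_encoder()    # re-activate the injected LoRA layers
image_with_lora = pipe("a photo of a cat").images[0]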


class FromSingleFileMixin:
"""
8 changes: 2 additions & 6 deletions src/diffusers/models/lora.py
@@ -19,19 +19,15 @@
from torch import nn

from ..loaders import PatchedLoraProjection, text_encoder_attn_modules, text_encoder_mlp_modules
from ..utils import logging
from ..utils import logging, scale_lora_layers


logger = logging.get_logger(__name__) # pylint: disable=invalid-name


def adjust_lora_scale_text_encoder(text_encoder, lora_scale: float = 1.0, use_peft_backend: bool = False):
if use_peft_backend:
from peft.tuners.lora import LoraLayer

for module in text_encoder.modules():
if isinstance(module, LoraLayer):
module.scaling[module.active_adapter] = lora_scale
scale_lora_layers(text_encoder, weight=lora_scale)
else:
for _, attn_module in text_encoder_attn_modules(text_encoder):
if isinstance(attn_module.q_proj, PatchedLoraProjection):
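With the PEFT backend, `adjust_lora_scale_text_encoder` now simply forwards to `scale_lora_layers`, so every injected LoRA layer in the text encoder is scaled by the same factor. A minimal sketch, assuming `pipe` is a pipeline with a LoRA already loaded:

from diffusers.models.lora import adjust_lora_scale_text_encoder

adjust_lora_scale_text_encoder(pipe.text_encoder, lora_scale=0.5, use_peft_backend=True)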
9 changes: 8 additions & 1 deletion src/diffusers/utils/__init__.py
@@ -84,7 +84,14 @@
from .loading_utils import load_image
from .logging import get_logger
from .outputs import BaseOutput
from .peft_utils import recurse_remove_peft_layers
from .peft_utils import (
get_adapter_name,
get_peft_kwargs,
recurse_remove_peft_layers,
scale_lora_layers,
set_adapter_layers,
set_weights_and_activate_adapters,
)
from .pil_utils import PIL_INTERPOLATION, make_image_grid, numpy_to_pil, pt_to_pil
from .state_dict_utils import convert_state_dict_to_diffusers, convert_state_dict_to_peft

97 changes: 97 additions & 0 deletions src/diffusers/utils/peft_utils.py
@@ -14,6 +14,8 @@
"""
PEFT utilities: Utilities related to peft library
"""
import collections

from .import_utils import is_torch_available


@@ -68,3 +70,98 @@ def recurse_remove_peft_layers(model):
torch.cuda.empty_cache()

return model


def scale_lora_layers(model, weight):
"""
Adjust the weight given to the LoRA layers of the model.

Args:
model (`torch.nn.Module`):
The model to scale.
weight (`float`):
The weight to be given to the LoRA layers.
"""
from peft.tuners.tuners_utils import BaseTunerLayer

for module in model.modules():
if isinstance(module, BaseTunerLayer):
module.scale_layer(weight)


def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict):
rank_pattern = {}
alpha_pattern = {}
r = lora_alpha = list(rank_dict.values())[0]
if len(set(rank_dict.values())) > 1:
# get the rank that occurs most often
r = collections.Counter(rank_dict.values()).most_common()[0][0]

# for modules whose rank differs from the most common rank, record it in `rank_pattern`
rank_pattern = dict(filter(lambda x: x[1] != r, rank_dict.items()))
rank_pattern = {k.split(".lora_B.")[0]: v for k, v in rank_pattern.items()}

if network_alpha_dict is not None and len(set(network_alpha_dict.values())) > 1:
# get the alpha that occurs most often
lora_alpha = collections.Counter(network_alpha_dict.values()).most_common()[0][0]

# for modules whose alpha differs from the most common alpha, record it in `alpha_pattern`
alpha_pattern = dict(filter(lambda x: x[1] != lora_alpha, network_alpha_dict.items()))
alpha_pattern = {".".join(k.split(".down.")[0].split(".")[:-1]): v for k, v in alpha_pattern.items()}

# layer names without the Diffusers-specific LoRA suffix
target_modules = list({name.split(".lora")[0] for name in peft_state_dict.keys()})

lora_config_kwargs = {
"r": r,
"lora_alpha": lora_alpha,
"rank_pattern": rank_pattern,
"alpha_pattern": alpha_pattern,
"target_modules": target_modules,
}
return lora_config_kwargs
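A worked sketch of how heterogeneous ranks collapse into a single set of LoraConfig kwargs; the module names are illustrative, not taken from a real checkpoint:

from diffusers.utils import get_peft_kwargs

rank_dict = {
    "text_model.encoder.layers.0.self_attn.q_proj.lora_B.weight": 4,
    "text_model.encoder.layers.0.self_attn.k_proj.lora_B.weight": 4,
    "text_model.encoder.layers.1.self_attn.q_proj.lora_B.weight": 8,
}
kwargs = get_peft_kwargs(rank_dict, network_alpha_dict=None, peft_state_dict=rank_dict)
# kwargs["r"] == 4 (the most common rank), kwargs["rank_pattern"] maps
# "text_model.encoder.layers.1.self_attn.q_proj" to 8, and alpha_pattern stays empty.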


def get_adapter_name(model):
from peft.tuners.tuners_utils import BaseTunerLayer

for module in model.modules():
if isinstance(module, BaseTunerLayer):
return f"default_{len(module.r)}"
return "default_0"


def set_adapter_layers(model, enabled=True):
from peft.tuners.tuners_utils import BaseTunerLayer

for module in model.modules():
if isinstance(module, BaseTunerLayer):
# Recent versions of PEFT expose `enable_adapters`; propagate the requested state either way
if hasattr(module, "enable_adapters"):
module.enable_adapters(enabled=enabled)
else:
module.disable_adapters = not enabled


def set_weights_and_activate_adapters(model, adapter_names, weights):
from peft.tuners.tuners_utils import BaseTunerLayer

# iterate over each adapter, make it active and set the corresponding scaling weight
for adapter_name, weight in zip(adapter_names, weights):
for module in model.modules():
if isinstance(module, BaseTunerLayer):
# For backward compatibility with previous PEFT versions
if hasattr(module, "set_adapter"):
module.set_adapter(adapter_name)
else:
module.active_adapter = adapter_name
module.scale_layer(weight)

# set multiple active adapters
for module in model.modules():
if isinstance(module, BaseTunerLayer):
# For backward compatibility with previous PEFT versions
if hasattr(module, "set_adapter"):
module.set_adapter(adapter_names)
else:
module.active_adapter = adapter_names
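To tie the last two helpers together, a hedged end-to-end sketch; the adapter names are hypothetical and two adapters are assumed to be already injected into the text encoder via peft:

from diffusers.utils import set_adapter_layers, set_weights_and_activate_adapters

# Make both adapters active with different scaling weights.
set_weights_and_activate_adapters(pipe.text_encoder, ["toy", "pixel"], [1.0, 0.5])
# Temporarily bypass every injected adapter, then turn them back on.
set_adapter_layers(pipe.text_encoder, enabled=False)
set_adapter_layers(pipe.text_encoder, enabled=True)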