onediffx.lora supports diffusers>=0.19.3 (#740)

marigoold authored Mar 20, 2024
1 parent 8b382a6 commit 5f068ea
Showing 6 changed files with 245 additions and 130 deletions.
7 changes: 6 additions & 1 deletion onediff_diffusers_extensions/onediffx/lora/__init__.py
@@ -1 +1,6 @@
from .lora import load_and_fuse_lora, unfuse_lora, set_and_fuse_adapters, delete_adapters
from .lora import (
load_and_fuse_lora,
unfuse_lora,
set_and_fuse_adapters,
delete_adapters,
)
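
For orientation, a minimal usage sketch of the helpers re-exported above. This is not part of the commit: the pipeline, LoRA path, and weight file name are placeholder assumptions, and the argument names follow what appears in this diff rather than the full signatures.

import torch
from diffusers import StableDiffusionXLPipeline
from onediffx.lora import load_and_fuse_lora, unfuse_lora

# Placeholder model and LoRA locations.
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

# Fuse the LoRA weights into the UNet/text encoders up front; use_cache reuses an
# already-loaded state dict on repeated calls (see load_state_dict_cached below).
load_and_fuse_lora(
    pipe,
    "path/to/lora_dir",
    weight_name="pytorch_lora_weights.safetensors",
    use_cache=True,
)
image = pipe("a photo of a cat").images[0]

# Restore the original weights when switching or dropping the LoRA.
unfuse_lora(pipe)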
48 changes: 25 additions & 23 deletions onediff_diffusers_extensions/onediffx/lora/lora.py
@@ -9,20 +9,26 @@

import diffusers
from diffusers.loaders import LoraLoaderMixin
from diffusers.models.lora import PatchedLoraProjection

if version.parse(diffusers.__version__) >= version.parse("0.21.0"):
from diffusers.models.lora import PatchedLoraProjection
else:
from diffusers.loaders import PatchedLoraProjection

from .utils import _unfuse_lora, _set_adapter, _delete_adapter

from .utils import (
_unfuse_lora,
_set_adapter,
_delete_adapter,
_maybe_map_sgm_blocks_to_diffusers,
is_peft_available,
)
from .text_encoder import load_lora_into_text_encoder
from .unet import load_lora_into_unet

from diffusers.utils.import_utils import is_peft_available

if is_peft_available():
import peft
is_onediffx_lora_available = version.parse(diffusers.__version__) >= version.parse(
"0.21.0"
)
is_onediffx_lora_available = version.parse(diffusers.__version__) >= version.parse("0.19.3")


USE_PEFT_BACKEND = False
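
A side note on the version gates above (illustrative, not part of the commit): packaging.version.parse compares release segments numerically, so the new 0.19.3 floor admits 0.19.x and later releases while pre-release builds of 0.19.3 itself still fall below it.

from packaging import version

# Numeric, not lexicographic, comparison of release segments.
assert version.parse("0.21.0") >= version.parse("0.19.3")
assert version.parse("0.19.10") > version.parse("0.19.3")
# PEP 440 dev/pre-releases sort below the final release they precede.
assert version.parse("0.19.3.dev0") < version.parse("0.19.3")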
@@ -40,23 +46,25 @@ def load_and_fuse_lora(
) -> None:
if not is_onediffx_lora_available:
raise RuntimeError(
"onediffx.lora only supports diffusers of at least version 0.21.0"
"onediffx.lora only supports diffusers of at least version 0.19.3"
)

self = pipeline

if use_cache:
state_dict, network_alphas = load_state_dict_cached(
pretrained_model_name_or_path_or_dict,
unet_config=self.unet.config,
**kwargs,
pretrained_model_name_or_path_or_dict, unet_config=self.unet.config, **kwargs,
)
else:
# for diffusers <= 0.20
if hasattr(LoraLoaderMixin, "_map_sgm_blocks_to_diffusers"):
orig_func = getattr(LoraLoaderMixin, "_map_sgm_blocks_to_diffusers")
LoraLoaderMixin._map_sgm_blocks_to_diffusers = _maybe_map_sgm_blocks_to_diffusers
state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(
pretrained_model_name_or_path_or_dict,
unet_config=self.unet.config,
**kwargs,
pretrained_model_name_or_path_or_dict, unet_config=self.unet.config, **kwargs,
)
if hasattr(LoraLoaderMixin, "_map_sgm_blocks_to_diffusers"):
LoraLoaderMixin._map_sgm_blocks_to_diffusers = orig_func

is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
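
The hunk above temporarily swaps LoraLoaderMixin._map_sgm_blocks_to_diffusers for the local _maybe_map_sgm_blocks_to_diffusers while lora_state_dict runs, then restores the original. Below is a sketch of the same idea with the restore wrapped in try/finally, so the patch cannot leak if loading raises; it is illustrative only and not the commit's code.

from diffusers.loaders import LoraLoaderMixin
from onediffx.lora.utils import _maybe_map_sgm_blocks_to_diffusers

def _lora_state_dict_with_sgm_patch(pretrained, **kwargs):
    # Only older diffusers (<= 0.20) expose this hook on LoraLoaderMixin.
    needs_patch = hasattr(LoraLoaderMixin, "_map_sgm_blocks_to_diffusers")
    if needs_patch:
        orig_func = LoraLoaderMixin._map_sgm_blocks_to_diffusers
        LoraLoaderMixin._map_sgm_blocks_to_diffusers = _maybe_map_sgm_blocks_to_diffusers
    try:
        return LoraLoaderMixin.lora_state_dict(pretrained, **kwargs)
    finally:
        if needs_patch:
            # Always put the original method back, even if loading failed.
            LoraLoaderMixin._map_sgm_blocks_to_diffusers = orig_func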
@@ -75,9 +83,7 @@ def load_and_fuse_lora(
)

# load lora weights into text encoder
text_encoder_state_dict = {
k: v for k, v in state_dict.items() if "text_encoder." in k
}
text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
if len(text_encoder_state_dict) > 0:
load_lora_into_text_encoder(
self,
@@ -90,9 +96,7 @@ def load_and_fuse_lora(
_pipeline=self,
)

text_encoder_2_state_dict = {
k: v for k, v in state_dict.items() if "text_encoder_2." in k
}
text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k}
if len(text_encoder_2_state_dict) > 0 and hasattr(self, "text_encoder_2"):
load_lora_into_text_encoder(
self,
@@ -194,9 +198,7 @@ def load_state_dict_cached(

lora_name = str(lora) + (f"/{weight_name}" if weight_name else "")
if lora_name in CachedLoRAs:
logger.debug(
f"[OneDiffX Cached LoRA] get cached lora of name: {str(lora_name)}"
)
logger.debug(f"[OneDiffX Cached LoRA] get cached lora of name: {str(lora_name)}")
return CachedLoRAs[lora_name]

state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(lora, **kwargs,)
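
For reference, the caching idea behind load_state_dict_cached above is a module-level dict keyed by the LoRA path plus weight file name, so repeated loads of the same LoRA skip the disk read and key conversion. A minimal sketch with assumed names follows; the real implementation lives in lora.py and may differ (for example by capping the cache size).

from diffusers.loaders import LoraLoaderMixin

_CACHED_LORAS = {}  # lora_name -> (state_dict, network_alphas)

def _load_state_dict_cached_sketch(lora, weight_name=None, **kwargs):
    lora_name = str(lora) + (f"/{weight_name}" if weight_name else "")
    if lora_name in _CACHED_LORAS:
        # Cache hit: return the previously converted state dict and alphas.
        return _CACHED_LORAS[lora_name]
    state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(
        lora, weight_name=weight_name, **kwargs
    )
    _CACHED_LORAS[lora_name] = (state_dict, network_alphas)
    return state_dict, network_alphas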
77 changes: 25 additions & 52 deletions onediff_diffusers_extensions/onediffx/lora/text_encoder.py
@@ -8,7 +8,14 @@
from diffusers.utils import convert_state_dict_to_diffusers
else:
from .state_dict_utils import convert_state_dict_to_diffusers
from diffusers.models.lora import text_encoder_attn_modules, text_encoder_mlp_modules

if version.parse(diffusers.__version__) >= version.parse("0.24.0"):
from diffusers.models.lora import (
text_encoder_attn_modules,
text_encoder_mlp_modules,
)
else:
from diffusers.loaders import text_encoder_attn_modules, text_encoder_mlp_modules
from diffusers.utils import is_accelerate_available

from diffusers.models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT
@@ -61,9 +68,7 @@ def load_lora_into_text_encoder(
`default_{i}` where i is the total number of adapters being loaded.
"""
low_cpu_mem_usage = (
low_cpu_mem_usage
if low_cpu_mem_usage is not None
else _LOW_CPU_MEM_USAGE_DEFAULT
low_cpu_mem_usage if low_cpu_mem_usage is not None else _LOW_CPU_MEM_USAGE_DEFAULT
)

if adapter_name is None:
@@ -90,13 +95,9 @@ def load_lora_into_text_encoder(
# Safe prefix to check with.
if any(cls.text_encoder_name in key for key in keys):
# Load the layers corresponding to text encoder and make necessary adjustments.
text_encoder_keys = [
k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix
]
text_encoder_keys = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix]
text_encoder_lora_state_dict = {
k.replace(f"{prefix}.", ""): v
for k, v in state_dict.items()
if k in text_encoder_keys
k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
}

if len(text_encoder_lora_state_dict) > 0:
@@ -116,40 +117,26 @@ def load_lora_into_text_encoder(
rank_key = f"{name}.out_proj.lora_B.weight"
rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]

patch_mlp = any(
".mlp." in key for key in text_encoder_lora_state_dict.keys()
)
patch_mlp = any(".mlp." in key for key in text_encoder_lora_state_dict.keys())
if patch_mlp:
for name, _ in text_encoder_mlp_modules(text_encoder):
rank_key_fc1 = f"{name}.fc1.lora_B.weight"
rank_key_fc2 = f"{name}.fc2.lora_B.weight"

rank[rank_key_fc1] = text_encoder_lora_state_dict[
rank_key_fc1
].shape[1]
rank[rank_key_fc2] = text_encoder_lora_state_dict[
rank_key_fc2
].shape[1]
rank[rank_key_fc1] = text_encoder_lora_state_dict[rank_key_fc1].shape[1]
rank[rank_key_fc2] = text_encoder_lora_state_dict[rank_key_fc2].shape[1]
else:
for name, _ in text_encoder_attn_modules(text_encoder):
rank_key = f"{name}.out_proj.lora_linear_layer.up.weight"
rank.update(
{rank_key: text_encoder_lora_state_dict[rank_key].shape[1]}
)
rank.update({rank_key: text_encoder_lora_state_dict[rank_key].shape[1]})

patch_mlp = any(
".mlp." in key for key in text_encoder_lora_state_dict.keys()
)
patch_mlp = any(".mlp." in key for key in text_encoder_lora_state_dict.keys())
if patch_mlp:
for name, _ in text_encoder_mlp_modules(text_encoder):
rank_key_fc1 = f"{name}.fc1.lora_linear_layer.up.weight"
rank_key_fc2 = f"{name}.fc2.lora_linear_layer.up.weight"
rank[rank_key_fc1] = text_encoder_lora_state_dict[
rank_key_fc1
].shape[1]
rank[rank_key_fc2] = text_encoder_lora_state_dict[
rank_key_fc2
].shape[1]
rank[rank_key_fc1] = text_encoder_lora_state_dict[rank_key_fc1].shape[1]
rank[rank_key_fc2] = text_encoder_lora_state_dict[rank_key_fc2].shape[1]

# group text encoder lora state_dict
te_lora_grouped_dict = defaultdict(dict)
@@ -207,23 +194,13 @@ def load_lora_into_text_encoder(
is_network_alphas_populated = len(network_alphas) > 0

for name, attn_module in text_encoder_attn_modules(text_encoder):
query_alpha = network_alphas.pop(
name + ".to_q_lora.down.weight.alpha", None
)
key_alpha = network_alphas.pop(
name + ".to_k_lora.down.weight.alpha", None
)
value_alpha = network_alphas.pop(
name + ".to_v_lora.down.weight.alpha", None
)
out_alpha = network_alphas.pop(
name + ".to_out_lora.down.weight.alpha", None
)
query_alpha = network_alphas.pop(name + ".to_q_lora.down.weight.alpha", None)
key_alpha = network_alphas.pop(name + ".to_k_lora.down.weight.alpha", None)
value_alpha = network_alphas.pop(name + ".to_v_lora.down.weight.alpha", None)
out_alpha = network_alphas.pop(name + ".to_out_lora.down.weight.alpha", None)

if isinstance(rank, dict):
current_rank = rank.pop(
f"{name}.out_proj.lora_linear_layer.up.weight"
)
current_rank = rank.pop(f"{name}.out_proj.lora_linear_layer.up.weight")
else:
current_rank = rank

@@ -273,12 +250,8 @@ def load_lora_into_text_encoder(
name + ".fc2.lora_linear_layer.down.weight.alpha", None
)

current_rank_fc1 = rank.pop(
f"{name}.fc1.lora_linear_layer.up.weight"
)
current_rank_fc2 = rank.pop(
f"{name}.fc2.lora_linear_layer.up.weight"
)
current_rank_fc1 = rank.pop(f"{name}.fc1.lora_linear_layer.up.weight")
current_rank_fc2 = rank.pop(f"{name}.fc2.lora_linear_layer.up.weight")

fuse_lora(
mlp_module.fc1,
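
Background for the rank bookkeeping in text_encoder.py above: a LoRA pair has a down/A weight of shape [rank, in_features] and an up/B (lora_B) weight of shape [out_features, rank], so the rank is the second dimension of the up weight, which is exactly what the shape[1] lookups read. Fusing folds scale * (alpha / rank) * (B @ A) into the base weight once; unfusing subtracts it again. A toy illustration, independent of this commit:

import torch

out_features, in_features, rank, alpha, lora_scale = 8, 4, 2, 2.0, 1.0
W = torch.randn(out_features, in_features)   # base linear weight
lora_A = torch.randn(rank, in_features)      # "down" projection
lora_B = torch.randn(out_features, rank)     # "up" projection; shape[1] == rank

delta = lora_scale * (alpha / rank) * (lora_B @ lora_A)
W_fused = W + delta           # fuse: one-time fold, no extra matmul at inference
W_restored = W_fused - delta  # unfuse: recover the original weight
assert torch.allclose(W, W_restored, atol=1e-6)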
(Diffs for the remaining changed files are not rendered on this page.)
