Commit bbae964

hipsterusername authored and Mary Hipp committed
Support PEFT LoRAs with Base_Model.model prefix (#8433)
* Support PEFT LoRAs with Base_Model.model prefix
* update tests
* ruff
* fix python complaints
* update keys
* format keys
* remove unneeded test
1 parent: 2a796fe

File tree

4 files changed: +6909 -8 lines

invokeai/backend/patches/lora_conversions/flux_diffusers_lora_conversion_utils.py

23 additions & 6 deletions
@@ -18,16 +18,25 @@ def is_state_dict_likely_in_flux_diffusers_format(state_dict: Dict[str, torch.Te
     # First, check that all keys end in "lora_A.weight" or "lora_B.weight" (i.e. are in PEFT format).
     all_keys_in_peft_format = all(k.endswith(("lora_A.weight", "lora_B.weight")) for k in state_dict.keys())
 
-    # Next, check that this is likely a FLUX model by spot-checking a few keys.
-    expected_keys = [
+    # Check if keys use transformer prefix
+    transformer_prefix_keys = [
         "transformer.single_transformer_blocks.0.attn.to_q.lora_A.weight",
         "transformer.single_transformer_blocks.0.attn.to_q.lora_B.weight",
         "transformer.transformer_blocks.0.attn.add_q_proj.lora_A.weight",
         "transformer.transformer_blocks.0.attn.add_q_proj.lora_B.weight",
     ]
-    all_expected_keys_present = all(k in state_dict for k in expected_keys)
+    transformer_keys_present = all(k in state_dict for k in transformer_prefix_keys)
+
+    # Check if keys use base_model.model prefix
+    base_model_prefix_keys = [
+        "base_model.model.single_transformer_blocks.0.attn.to_q.lora_A.weight",
+        "base_model.model.single_transformer_blocks.0.attn.to_q.lora_B.weight",
+        "base_model.model.transformer_blocks.0.attn.add_q_proj.lora_A.weight",
+        "base_model.model.transformer_blocks.0.attn.add_q_proj.lora_B.weight",
+    ]
+    base_model_keys_present = all(k in state_dict for k in base_model_prefix_keys)
 
-    return all_keys_in_peft_format and all_expected_keys_present
+    return all_keys_in_peft_format and (transformer_keys_present or base_model_keys_present)
 
 
 def lora_model_from_flux_diffusers_state_dict(
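For reference, a minimal sketch (not part of the commit) of the updated check accepting a PEFT-exported FLUX LoRA whose keys carry the `base_model.model.` prefix. The zero tensors are placeholders, and the import path follows the file touched above:

```python
import torch

from invokeai.backend.patches.lora_conversions.flux_diffusers_lora_conversion_utils import (
    is_state_dict_likely_in_flux_diffusers_format,
)

# Every key ends in "lora_A.weight"/"lora_B.weight" (PEFT format) and uses the
# "base_model.model." prefix that PEFT adds when it wraps the base model.
state_dict = {
    f"base_model.model.{module}.{suffix}": torch.zeros(1)
    for module in (
        "single_transformer_blocks.0.attn.to_q",
        "transformer_blocks.0.attn.add_q_proj",
    )
    for suffix in ("lora_A.weight", "lora_B.weight")
}

# Before this commit the spot check only accepted "transformer."-prefixed keys,
# so this state dict would have been rejected; now it passes.
assert is_state_dict_likely_in_flux_diffusers_format(state_dict)
```

A state dict with the older `transformer.` prefix still passes the same check through the `transformer_keys_present` branch.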
@@ -49,8 +58,16 @@ def lora_layers_from_flux_diffusers_grouped_state_dict(
     https://github.com/huggingface/diffusers/blob/55ac421f7bb12fd00ccbef727be4dc2f3f920abb/scripts/convert_flux_to_diffusers.py
     """
 
-    # Remove the "transformer." prefix from all keys.
-    grouped_state_dict = {k.replace("transformer.", ""): v for k, v in grouped_state_dict.items()}
+    # Determine which prefix is used and remove it from all keys.
+    # Check if any key starts with "base_model.model." prefix
+    has_base_model_prefix = any(k.startswith("base_model.model.") for k in grouped_state_dict.keys())
+
+    if has_base_model_prefix:
+        # Remove the "base_model.model." prefix from all keys.
+        grouped_state_dict = {k.replace("base_model.model.", ""): v for k, v in grouped_state_dict.items()}
+    else:
+        # Remove the "transformer." prefix from all keys.
+        grouped_state_dict = {k.replace("transformer.", ""): v for k, v in grouped_state_dict.items()}
 
     # Constants for FLUX.1
     num_double_layers = 19
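The prefix normalization can be sketched standalone (a condensed restatement, not the repo's code; the placeholder dicts stand in for the grouped lora_A/lora_B tensors). One design note: the patch strips the prefix with `str.replace`, which removes every occurrence in the key. That is safe here only because the block names use underscores ("transformer_blocks", "single_transformer_blocks") rather than a dotted "transformer." segment later in the key:

```python
# Placeholder grouped state dict; real values are dicts of LoRA tensors.
grouped_state_dict = {
    "base_model.model.single_transformer_blocks.0.attn.to_q": {"lora_A.weight": None},
    "base_model.model.transformer_blocks.0.attn.add_q_proj": {"lora_A.weight": None},
}

# Pick the prefix the same way the patch does, then strip it from every key.
has_base_model_prefix = any(k.startswith("base_model.model.") for k in grouped_state_dict)
prefix = "base_model.model." if has_base_model_prefix else "transformer."
grouped_state_dict = {k.replace(prefix, ""): v for k, v in grouped_state_dict.items()}

assert "transformer_blocks.0.attn.add_q_proj" in grouped_state_dict
```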
