@@ -18,16 +18,25 @@ def is_state_dict_likely_in_flux_diffusers_format(state_dict: Dict[str, torch.Te
     # First, check that all keys end in "lora_A.weight" or "lora_B.weight" (i.e. are in PEFT format).
     all_keys_in_peft_format = all(k.endswith(("lora_A.weight", "lora_B.weight")) for k in state_dict.keys())

-    # Next, check that this is likely a FLUX model by spot-checking a few keys.
-    expected_keys = [
+    # Check if keys use the "transformer." prefix.
+    transformer_prefix_keys = [
         "transformer.single_transformer_blocks.0.attn.to_q.lora_A.weight",
         "transformer.single_transformer_blocks.0.attn.to_q.lora_B.weight",
         "transformer.transformer_blocks.0.attn.add_q_proj.lora_A.weight",
         "transformer.transformer_blocks.0.attn.add_q_proj.lora_B.weight",
     ]
-    all_expected_keys_present = all(k in state_dict for k in expected_keys)
+    transformer_keys_present = all(k in state_dict for k in transformer_prefix_keys)
+
+    # Check if keys use the "base_model.model." prefix.
+    base_model_prefix_keys = [
+        "base_model.model.single_transformer_blocks.0.attn.to_q.lora_A.weight",
+        "base_model.model.single_transformer_blocks.0.attn.to_q.lora_B.weight",
+        "base_model.model.transformer_blocks.0.attn.add_q_proj.lora_A.weight",
+        "base_model.model.transformer_blocks.0.attn.add_q_proj.lora_B.weight",
+    ]
+    base_model_keys_present = all(k in state_dict for k in base_model_prefix_keys)

-    return all_keys_in_peft_format and all_expected_keys_present
+    return all_keys_in_peft_format and (transformer_keys_present or base_model_keys_present)


 def lora_model_from_flux_diffusers_state_dict(
@@ -49,8 +58,16 @@ def lora_layers_from_flux_diffusers_grouped_state_dict(
     https://github.com/huggingface/diffusers/blob/55ac421f7bb12fd00ccbef727be4dc2f3f920abb/scripts/convert_flux_to_diffusers.py
     """

-    # Remove the "transformer." prefix from all keys.
-    grouped_state_dict = {k.replace("transformer.", ""): v for k, v in grouped_state_dict.items()}
+    # Determine which prefix is used and remove it from all keys.
+    # Check if any key starts with the "base_model.model." prefix.
+    has_base_model_prefix = any(k.startswith("base_model.model.") for k in grouped_state_dict.keys())
+
+    if has_base_model_prefix:
+        # Remove the "base_model.model." prefix from all keys.
+        grouped_state_dict = {k.replace("base_model.model.", ""): v for k, v in grouped_state_dict.items()}
+    else:
+        # Remove the "transformer." prefix from all keys.
+        grouped_state_dict = {k.replace("transformer.", ""): v for k, v in grouped_state_dict.items()}

     # Constants for FLUX.1
     num_double_layers = 19
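
For reference, a minimal usage sketch (not part of the diff) of how the updated check behaves on a state dict that uses the "base_model.model." prefix. The zero tensors are placeholders, since only the key names matter for the check, and the sketch assumes is_state_dict_likely_in_flux_diffusers_format from the patched module is in scope.

import torch

# Hypothetical example state dict: keys use the "base_model.model." prefix and
# end in "lora_A.weight" / "lora_B.weight", so they are in PEFT format.
state_dict = {
    "base_model.model.single_transformer_blocks.0.attn.to_q.lora_A.weight": torch.zeros(1),
    "base_model.model.single_transformer_blocks.0.attn.to_q.lora_B.weight": torch.zeros(1),
    "base_model.model.transformer_blocks.0.attn.add_q_proj.lora_A.weight": torch.zeros(1),
    "base_model.model.transformer_blocks.0.attn.add_q_proj.lora_B.weight": torch.zeros(1),
}

# Before this change, the spot-check only accepted "transformer."-prefixed keys, so this
# state dict was not detected; with the change, detection succeeds.
assert is_state_dict_likely_in_flux_diffusers_format(state_dict)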