Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions invokeai/backend/model_manager/configs/lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,42 @@ class LoRA_LyCORIS_ZImage_Config(LoRA_LyCORIS_Config_Base, Config_Base):

base: Literal[BaseModelType.ZImage] = Field(default=BaseModelType.ZImage)

@classmethod
def _validate_looks_like_lora(cls, mod: ModelOnDisk) -> None:
"""Z-Image LoRAs have different key patterns than SD/SDXL LoRAs.

Z-Image LoRAs use keys like:
- diffusion_model.layers.X.attention.to_k.lora_down.weight (DoRA format)
- diffusion_model.layers.X.attention.to_k.lora_A.weight (PEFT format)
- diffusion_model.layers.X.attention.to_k.dora_scale (DoRA scale)
"""
state_dict = mod.load_state_dict()

# Check for Z-Image specific LoRA patterns
has_z_image_lora_keys = state_dict_has_any_keys_starting_with(
state_dict,
{
"diffusion_model.layers.", # Z-Image S3-DiT layer pattern
},
)

# Also check for LoRA weight suffixes (various formats)
has_lora_suffix = state_dict_has_any_keys_ending_with(
state_dict,
{
"lora_A.weight",
"lora_B.weight",
"lora_down.weight",
"lora_up.weight",
"dora_scale",
},
)

if has_z_image_lora_keys and has_lora_suffix:
return

raise NotAMatchError("model does not match Z-Image LoRA heuristics")

@classmethod
def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType:
"""Z-Image LoRAs are identified by their diffusion_model.layers structure.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,16 +140,50 @@ def _get_lora_layer_values(layer_dict: dict[str, torch.Tensor], alpha: float | N


def _group_by_layer(state_dict: Dict[str, torch.Tensor]) -> dict[str, dict[str, torch.Tensor]]:
"""Groups the keys in the state dict by layer."""
"""Groups the keys in the state dict by layer.

Z-Image LoRAs have keys like:
- diffusion_model.layers.17.attention.to_k.alpha
- diffusion_model.layers.17.attention.to_k.dora_scale
- diffusion_model.layers.17.attention.to_k.lora_down.weight
- diffusion_model.layers.17.attention.to_k.lora_up.weight

We need to group these by the full layer path (e.g., diffusion_model.layers.17.attention.to_k)
and extract the suffix (alpha, dora_scale, lora_down.weight, lora_up.weight).
"""
layer_dict: dict[str, dict[str, torch.Tensor]] = {}

# Known suffixes that indicate the end of a layer name
known_suffixes = [
".lora_A.weight",
".lora_B.weight",
".lora_down.weight",
".lora_up.weight",
".dora_scale",
".alpha",
]

for key in state_dict:
if not isinstance(key, str):
continue
# Split the 'lora_A.weight' or 'lora_B.weight' suffix from the layer name.
parts = key.rsplit(".", maxsplit=2)
layer_name = parts[0]
key_name = ".".join(parts[1:])

# Try to find a known suffix
layer_name = None
key_name = None
for suffix in known_suffixes:
if key.endswith(suffix):
layer_name = key[: -len(suffix)]
key_name = suffix[1:] # Remove leading dot
break

if layer_name is None:
# Fallback to original logic for unknown formats
parts = key.rsplit(".", maxsplit=2)
layer_name = parts[0]
key_name = ".".join(parts[1:])

if layer_name not in layer_dict:
layer_dict[layer_name] = {}
layer_dict[layer_name][key_name] = state_dict[key]

return layer_dict