diff --git a/invokeai/backend/model_manager/configs/lora.py b/invokeai/backend/model_manager/configs/lora.py
index 1d0bd646d27..1d7b4b304e9 100644
--- a/invokeai/backend/model_manager/configs/lora.py
+++ b/invokeai/backend/model_manager/configs/lora.py
@@ -150,11 +150,16 @@ def _validate_base(cls, mod: ModelOnDisk) -> None:
 
     @classmethod
     def _validate_looks_like_lora(cls, mod: ModelOnDisk) -> None:
-        # First rule out ControlLoRA and Diffusers LoRA
+        # First rule out ControlLoRA
         flux_format = _get_flux_lora_format(mod)
         if flux_format in [FluxLoRAFormat.Control]:
             raise NotAMatchError("model looks like Control LoRA")
 
+        # If it's a recognized Flux LoRA format (Kohya, Diffusers, OneTrainer, AIToolkit, XLabs, etc.),
+        # it's valid and we skip the heuristic check
+        if flux_format is not None:
+            return
+
         # Note: Existence of these key prefixes/suffixes does not guarantee that this is a LoRA.
         # Some main models have these keys, likely due to the creator merging in a LoRA.
         has_key_with_lora_prefix = state_dict_has_any_keys_starting_with(
diff --git a/invokeai/backend/model_manager/load/model_loaders/lora.py b/invokeai/backend/model_manager/load/model_loaders/lora.py
index 5dd35878479..2b22221151c 100644
--- a/invokeai/backend/model_manager/load/model_loaders/lora.py
+++ b/invokeai/backend/model_manager/load/model_loaders/lora.py
@@ -41,6 +41,10 @@
     is_state_dict_likely_in_flux_onetrainer_format,
     lora_model_from_flux_onetrainer_state_dict,
 )
+from invokeai.backend.patches.lora_conversions.flux_xlabs_lora_conversion_utils import (
+    is_state_dict_likely_in_flux_xlabs_format,
+    lora_model_from_flux_xlabs_state_dict,
+)
 from invokeai.backend.patches.lora_conversions.sd_lora_conversion_utils import lora_model_from_sd_state_dict
 from invokeai.backend.patches.lora_conversions.sdxl_lora_conversion_utils import convert_sdxl_keys_to_diffusers_format
 from invokeai.backend.patches.lora_conversions.z_image_lora_conversion_utils import lora_model_from_z_image_state_dict
@@ -118,6 +122,8 @@ def _load_model(
                 model = lora_model_from_flux_control_state_dict(state_dict=state_dict)
             elif is_state_dict_likely_in_flux_aitoolkit_format(state_dict=state_dict):
                 model = lora_model_from_flux_aitoolkit_state_dict(state_dict=state_dict)
+            elif is_state_dict_likely_in_flux_xlabs_format(state_dict=state_dict):
+                model = lora_model_from_flux_xlabs_state_dict(state_dict=state_dict)
             else:
                 raise ValueError("LoRA model is in unsupported FLUX format")
         else:
diff --git a/invokeai/backend/model_manager/taxonomy.py b/invokeai/backend/model_manager/taxonomy.py
index 4c06b430eea..bc7e11367a5 100644
--- a/invokeai/backend/model_manager/taxonomy.py
+++ b/invokeai/backend/model_manager/taxonomy.py
@@ -171,6 +171,7 @@ class FluxLoRAFormat(str, Enum):
     OneTrainer = "flux.onetrainer"
     Control = "flux.control"
     AIToolkit = "flux.aitoolkit"
+    XLabs = "flux.xlabs"
 
 
 AnyVariant: TypeAlias = Union[ModelVariantType, ClipVariantType, FluxVariantType]
diff --git a/invokeai/backend/patches/lora_conversions/flux_xlabs_lora_conversion_utils.py b/invokeai/backend/patches/lora_conversions/flux_xlabs_lora_conversion_utils.py
new file mode 100644
index 00000000000..b8abbb87635
--- /dev/null
+++ b/invokeai/backend/patches/lora_conversions/flux_xlabs_lora_conversion_utils.py
@@ -0,0 +1,92 @@
+import re
+from typing import Any, Dict
+
+import torch
+
+from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
+from invokeai.backend.patches.layers.utils import any_lora_layer_from_state_dict
+from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
+from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
+
+# A regex pattern that matches all of the transformer keys in the xlabs FLUX LoRA format.
+# Example keys:
+# double_blocks.0.processor.qkv_lora1.down.weight
+# double_blocks.0.processor.qkv_lora1.up.weight
+# double_blocks.0.processor.proj_lora1.down.weight
+# double_blocks.0.processor.proj_lora1.up.weight
+# double_blocks.0.processor.qkv_lora2.down.weight
+# double_blocks.0.processor.proj_lora2.up.weight
+FLUX_XLABS_KEY_REGEX = r"double_blocks\.(\d+)\.processor\.(qkv|proj)_lora([12])\.(down|up)\.weight"
+
+
+def is_state_dict_likely_in_flux_xlabs_format(state_dict: dict[str | int, Any]) -> bool:
+    """Checks if the provided state dict is likely in the xlabs FLUX LoRA format.
+
+    The xlabs format is characterized by keys matching the pattern:
+        double_blocks.{block_idx}.processor.{qkv|proj}_lora{1|2}.{down|up}.weight
+
+    Where:
+    - lora1 corresponds to the image attention stream (img_attn)
+    - lora2 corresponds to the text attention stream (txt_attn)
+    """
+    if not state_dict:
+        return False
+
+    # Check that all keys match the xlabs pattern
+    for key in state_dict.keys():
+        if not isinstance(key, str):
+            continue
+        if not re.match(FLUX_XLABS_KEY_REGEX, key):
+            return False
+
+    # Ensure we have at least some valid keys
+    return any(isinstance(k, str) and re.match(FLUX_XLABS_KEY_REGEX, k) for k in state_dict.keys())
+
+
+def lora_model_from_flux_xlabs_state_dict(state_dict: Dict[str, torch.Tensor]) -> ModelPatchRaw:
+    """Converts an xlabs FLUX LoRA state dict to the InvokeAI ModelPatchRaw format.
+
+    The xlabs format uses:
+    - lora1 for image attention stream (img_attn)
+    - lora2 for text attention stream (txt_attn)
+    - qkv for query/key/value projection
+    - proj for output projection
+
+    Key mapping:
+    - double_blocks.X.processor.qkv_lora1 -> double_blocks.X.img_attn.qkv
+    - double_blocks.X.processor.proj_lora1 -> double_blocks.X.img_attn.proj
+    - double_blocks.X.processor.qkv_lora2 -> double_blocks.X.txt_attn.qkv
+    - double_blocks.X.processor.proj_lora2 -> double_blocks.X.txt_attn.proj
+    """
+    # Group keys by layer (without the .down.weight/.up.weight suffix)
+    grouped_state_dict: dict[str, dict[str, torch.Tensor]] = {}
+
+    for key, value in state_dict.items():
+        match = re.match(FLUX_XLABS_KEY_REGEX, key)
+        if not match:
+            raise ValueError(f"Key '{key}' does not match the expected pattern for xlabs FLUX LoRA weights.")
+
+        block_idx = match.group(1)
+        component = match.group(2)  # qkv or proj
+        lora_stream = match.group(3)  # 1 or 2
+        direction = match.group(4)  # down or up
+
+        # Map lora1 -> img_attn, lora2 -> txt_attn
+        attn_type = "img_attn" if lora_stream == "1" else "txt_attn"
+
+        # Create the InvokeAI-style layer key
+        layer_key = f"double_blocks.{block_idx}.{attn_type}.{component}"
+
+        if layer_key not in grouped_state_dict:
+            grouped_state_dict[layer_key] = {}
+
+        # Map down/up to lora_down/lora_up
+        param_name = f"lora_{direction}.weight"
+        grouped_state_dict[layer_key][param_name] = value
+
+    # Create LoRA layers
+    layers: dict[str, BaseLayerPatch] = {}
+    for layer_key, layer_state_dict in grouped_state_dict.items():
+        layers[FLUX_LORA_TRANSFORMER_PREFIX + layer_key] = any_lora_layer_from_state_dict(layer_state_dict)
+
+    return ModelPatchRaw(layers=layers)
diff --git a/invokeai/backend/patches/lora_conversions/formats.py b/invokeai/backend/patches/lora_conversions/formats.py
index 4cde7c98f67..ae2e1b14596 100644
--- a/invokeai/backend/patches/lora_conversions/formats.py
+++ b/invokeai/backend/patches/lora_conversions/formats.py
@@ -14,6 +14,9 @@
 from invokeai.backend.patches.lora_conversions.flux_onetrainer_lora_conversion_utils import (
     is_state_dict_likely_in_flux_onetrainer_format,
 )
+from invokeai.backend.patches.lora_conversions.flux_xlabs_lora_conversion_utils import (
+    is_state_dict_likely_in_flux_xlabs_format,
+)
 
 
 def flux_format_from_state_dict(
@@ -30,5 +33,7 @@ def flux_format_from_state_dict(
         return FluxLoRAFormat.Control
     elif is_state_dict_likely_in_flux_aitoolkit_format(state_dict, metadata):
         return FluxLoRAFormat.AIToolkit
+    elif is_state_dict_likely_in_flux_xlabs_format(state_dict):
+        return FluxLoRAFormat.XLabs
     else:
         return None
diff --git a/tests/backend/patches/lora_conversions/lora_state_dicts/flux_lora_xlabs_format.py b/tests/backend/patches/lora_conversions/lora_state_dicts/flux_lora_xlabs_format.py
new file mode 100644
index 00000000000..9458f50ceb0
--- /dev/null
+++ b/tests/backend/patches/lora_conversions/lora_state_dicts/flux_lora_xlabs_format.py
@@ -0,0 +1,32 @@
+# A sample state dict in the xlabs FLUX LoRA format.
+# The xlabs format uses:
+# - lora1 for image attention stream (img_attn)
+# - lora2 for text attention stream (txt_attn)
+# - qkv for query/key/value projection
+# - proj for output projection
+state_dict_keys = {
+    "double_blocks.0.processor.proj_lora1.down.weight": [16, 3072],
+    "double_blocks.0.processor.proj_lora1.up.weight": [3072, 16],
+    "double_blocks.0.processor.proj_lora2.down.weight": [16, 3072],
+    "double_blocks.0.processor.proj_lora2.up.weight": [3072, 16],
+    "double_blocks.0.processor.qkv_lora1.down.weight": [16, 3072],
+    "double_blocks.0.processor.qkv_lora1.up.weight": [9216, 16],
+    "double_blocks.0.processor.qkv_lora2.down.weight": [16, 3072],
+    "double_blocks.0.processor.qkv_lora2.up.weight": [9216, 16],
+    "double_blocks.1.processor.proj_lora1.down.weight": [16, 3072],
+    "double_blocks.1.processor.proj_lora1.up.weight": [3072, 16],
+    "double_blocks.1.processor.proj_lora2.down.weight": [16, 3072],
+    "double_blocks.1.processor.proj_lora2.up.weight": [3072, 16],
+    "double_blocks.1.processor.qkv_lora1.down.weight": [16, 3072],
+    "double_blocks.1.processor.qkv_lora1.up.weight": [9216, 16],
+    "double_blocks.1.processor.qkv_lora2.down.weight": [16, 3072],
+    "double_blocks.1.processor.qkv_lora2.up.weight": [9216, 16],
+    "double_blocks.10.processor.proj_lora1.down.weight": [16, 3072],
+    "double_blocks.10.processor.proj_lora1.up.weight": [3072, 16],
+    "double_blocks.10.processor.proj_lora2.down.weight": [16, 3072],
+    "double_blocks.10.processor.proj_lora2.up.weight": [3072, 16],
+    "double_blocks.10.processor.qkv_lora1.down.weight": [16, 3072],
+    "double_blocks.10.processor.qkv_lora1.up.weight": [9216, 16],
+    "double_blocks.10.processor.qkv_lora2.down.weight": [16, 3072],
+    "double_blocks.10.processor.qkv_lora2.up.weight": [9216, 16],
+}
diff --git a/tests/backend/patches/lora_conversions/test_flux_xlabs_lora_conversion_utils.py b/tests/backend/patches/lora_conversions/test_flux_xlabs_lora_conversion_utils.py
new file mode 100644
index 00000000000..f6031457bfb
--- /dev/null
+++ b/tests/backend/patches/lora_conversions/test_flux_xlabs_lora_conversion_utils.py
@@ -0,0 +1,99 @@
+import accelerate
+import pytest
+import torch
+
+from invokeai.backend.flux.model import Flux
+from invokeai.backend.flux.util import get_flux_transformers_params
+from invokeai.backend.model_manager.taxonomy import FluxVariantType
+from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
+from invokeai.backend.patches.lora_conversions.flux_xlabs_lora_conversion_utils import (
+    is_state_dict_likely_in_flux_xlabs_format,
+    lora_model_from_flux_xlabs_state_dict,
+)
+from tests.backend.patches.lora_conversions.lora_state_dicts.flux_lora_diffusers_format import (
+    state_dict_keys as flux_diffusers_state_dict_keys,
+)
+from tests.backend.patches.lora_conversions.lora_state_dicts.flux_lora_kohya_format import (
+    state_dict_keys as flux_kohya_state_dict_keys,
+)
+from tests.backend.patches.lora_conversions.lora_state_dicts.flux_lora_xlabs_format import (
+    state_dict_keys as flux_xlabs_state_dict_keys,
+)
+from tests.backend.patches.lora_conversions.lora_state_dicts.utils import keys_to_mock_state_dict
+
+
+def test_is_state_dict_likely_in_flux_xlabs_format_true():
+    """Test that is_state_dict_likely_in_flux_xlabs_format() can identify a state dict in the xlabs FLUX LoRA format."""
+    state_dict = keys_to_mock_state_dict(flux_xlabs_state_dict_keys)
+    assert is_state_dict_likely_in_flux_xlabs_format(state_dict)
+
+
+@pytest.mark.parametrize("sd_keys", [flux_diffusers_state_dict_keys, flux_kohya_state_dict_keys])
+def test_is_state_dict_likely_in_flux_xlabs_format_false(sd_keys: dict[str, list[int]]):
+    """Test that is_state_dict_likely_in_flux_xlabs_format() returns False for state dicts in other formats."""
+    state_dict = keys_to_mock_state_dict(sd_keys)
+    assert not is_state_dict_likely_in_flux_xlabs_format(state_dict)
+
+
+def test_lora_model_from_flux_xlabs_state_dict():
+    """Test that a ModelPatchRaw can be created from a state dict in the xlabs FLUX LoRA format."""
+    state_dict = keys_to_mock_state_dict(flux_xlabs_state_dict_keys)
+
+    lora_model = lora_model_from_flux_xlabs_state_dict(state_dict)
+
+    # Verify the expected layer keys are created
+    expected_layer_keys = {
+        f"{FLUX_LORA_TRANSFORMER_PREFIX}double_blocks.0.img_attn.proj",
+        f"{FLUX_LORA_TRANSFORMER_PREFIX}double_blocks.0.img_attn.qkv",
+        f"{FLUX_LORA_TRANSFORMER_PREFIX}double_blocks.0.txt_attn.proj",
+        f"{FLUX_LORA_TRANSFORMER_PREFIX}double_blocks.0.txt_attn.qkv",
+        f"{FLUX_LORA_TRANSFORMER_PREFIX}double_blocks.1.img_attn.proj",
+        f"{FLUX_LORA_TRANSFORMER_PREFIX}double_blocks.1.img_attn.qkv",
+        f"{FLUX_LORA_TRANSFORMER_PREFIX}double_blocks.1.txt_attn.proj",
+        f"{FLUX_LORA_TRANSFORMER_PREFIX}double_blocks.1.txt_attn.qkv",
+        f"{FLUX_LORA_TRANSFORMER_PREFIX}double_blocks.10.img_attn.proj",
+        f"{FLUX_LORA_TRANSFORMER_PREFIX}double_blocks.10.img_attn.qkv",
+        f"{FLUX_LORA_TRANSFORMER_PREFIX}double_blocks.10.txt_attn.proj",
+        f"{FLUX_LORA_TRANSFORMER_PREFIX}double_blocks.10.txt_attn.qkv",
+    }
+
+    assert set(lora_model.layers.keys()) == expected_layer_keys
+
+
+def test_lora_model_from_flux_xlabs_state_dict_matches_model_keys():
+    """Test that the converted xlabs LoRA keys match the actual FLUX model keys."""
+    state_dict = keys_to_mock_state_dict(flux_xlabs_state_dict_keys)
+
+    lora_model = lora_model_from_flux_xlabs_state_dict(state_dict)
+
+    # Extract the layer prefixes (without the lora_transformer- prefix)
+    converted_key_prefixes: list[str] = []
+    for k in lora_model.layers.keys():
+        # Remove the transformer prefix
+        k = k.replace(FLUX_LORA_TRANSFORMER_PREFIX, "")
+        converted_key_prefixes.append(k)
+
+    # Initialize a FLUX model on the meta device.
+    with accelerate.init_empty_weights():
+        model = Flux(get_flux_transformers_params(FluxVariantType.Schnell))
+    model_keys = set(model.state_dict().keys())
+
+    # Assert that the converted keys match prefixes in the actual model.
+    for converted_key_prefix in converted_key_prefixes:
+        found_match = False
+        for model_key in model_keys:
+            if model_key.startswith(converted_key_prefix):
+                found_match = True
+                break
+        if not found_match:
+            raise AssertionError(f"Could not find a match for the converted key prefix: {converted_key_prefix}")
+
+
+def test_lora_model_from_flux_xlabs_state_dict_error():
+    """Test that an error is raised if the input state_dict contains unexpected keys."""
+    state_dict = {
+        "unexpected_key.down.weight": torch.empty(1),
+    }
+
+    with pytest.raises(ValueError):
+        lora_model_from_flux_xlabs_state_dict(state_dict)
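For reviewers unfamiliar with the XLabs layout, the following standalone Python sketch (not part of the patch itself) walks through the key mapping that lora_model_from_flux_xlabs_state_dict performs, using only the regex introduced above and toy tensors; the toy_state_dict contents, shapes, and variable names are illustrative assumptions, not taken from a real checkpoint.

import re

import torch

FLUX_XLABS_KEY_REGEX = r"double_blocks\.(\d+)\.processor\.(qkv|proj)_lora([12])\.(down|up)\.weight"

# Two illustrative XLabs-style LoRA layers (qkv on the image stream, proj on the text stream).
toy_state_dict = {
    "double_blocks.0.processor.qkv_lora1.down.weight": torch.zeros(16, 3072),
    "double_blocks.0.processor.qkv_lora1.up.weight": torch.zeros(9216, 16),
    "double_blocks.0.processor.proj_lora2.down.weight": torch.zeros(16, 3072),
    "double_blocks.0.processor.proj_lora2.up.weight": torch.zeros(3072, 16),
}

# Group the down/up pairs under InvokeAI-style layer keys, mirroring the converter in the patch.
grouped: dict[str, dict[str, torch.Tensor]] = {}
for key, value in toy_state_dict.items():
    block_idx, component, stream, direction = re.match(FLUX_XLABS_KEY_REGEX, key).groups()
    attn_type = "img_attn" if stream == "1" else "txt_attn"  # lora1 -> image stream, lora2 -> text stream
    layer_key = f"double_blocks.{block_idx}.{attn_type}.{component}"
    grouped.setdefault(layer_key, {})[f"lora_{direction}.weight"] = value

print(sorted(grouped))
# ['double_blocks.0.img_attn.qkv', 'double_blocks.0.txt_attn.proj']

Each grouped entry holds a lora_down.weight/lora_up.weight pair, which is the shape of input that any_lora_layer_from_state_dict expects when the converter builds the final ModelPatchRaw.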