diff --git a/src/diffusers/loaders/lora_conversion_utils.py b/src/diffusers/loaders/lora_conversion_utils.py
index aab87b8f4dba..bf1433aa108e 100644
--- a/src/diffusers/loaders/lora_conversion_utils.py
+++ b/src/diffusers/loaders/lora_conversion_utils.py
@@ -665,6 +665,251 @@ def handle_qkv(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
     return new_state_dict
 
 
+def _convert_non_diffusers_sd3_lora_to_diffusers(state_dict, prefix=None):
+    new_state_dict = {}
+
+    # The original SD3 implementation of AdaLayerNormContinuous splits the linear projection output into
+    # (shift, scale), while diffusers splits it into (scale, shift). Swap the projection weights so that the
+    # diffusers implementation can be used.
+    def swap_scale_shift(weight):
+        shift, scale = weight.chunk(2, dim=0)
+        new_weight = torch.cat([scale, shift], dim=0)
+        return new_weight
+
+    def calculate_scales(key):
+        lora_rank = state_dict[f"{key}.lora_down.weight"].shape[0]
+        alpha = state_dict.pop(key + ".alpha")
+        scale = alpha / lora_rank
+
+        # calculate scale_down and scale_up so that their product stays `scale` while keeping the two factors
+        # at a comparable magnitude
+        scale_down = scale
+        scale_up = 1.0
+        while scale_down * 2 < scale_up:
+            scale_down *= 2
+            scale_up /= 2
+
+        return scale_down, scale_up
+
+    def weight_is_sparse(key, rank, num_splits, up_weight):
+        dims = [up_weight.shape[0] // num_splits] * num_splits
+
+        is_sparse = False
+        requested_rank = rank
+        if rank % num_splits == 0:
+            requested_rank = rank // num_splits
+            is_sparse = True
+            i = 0
+            for j in range(len(dims)):
+                for k in range(len(dims)):
+                    if j == k:
+                        continue
+                    is_sparse = is_sparse and torch.all(
+                        up_weight[i : i + dims[j], k * requested_rank : (k + 1) * requested_rank] == 0
+                    )
+                i += dims[j]
+            if is_sparse:
+                logger.info(f"weight is sparse: {key}")
+
+        return is_sparse, requested_rank
+
+    # handle only transformer blocks for now.
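+    # Keys look like "lora_unet_joint_blocks_{i}_x_block_attn_qkv.lora_down.weight" (or "_context_block_" for
+    # the text stream); the number of joint blocks is inferred from the largest block index in the checkpoint.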
+    layers = set()
+    for k in state_dict:
+        if "joint_blocks" in k:
+            idx = int(k.split("_", 4)[-1].split("_", 1)[0])
+            layers.add(idx)
+    num_layers = max(layers) + 1
+
+    for i in range(num_layers):
+        # norms
+        for diffusers_key, orig_key in [
+            (f"transformer_blocks.{i}.norm1.linear", f"lora_unet_joint_blocks_{i}_x_block_adaLN_modulation_1")
+        ]:
+            scale_down, scale_up = calculate_scales(orig_key)
+            new_state_dict[f"{diffusers_key}.lora_A.weight"] = (
+                state_dict.pop(f"{orig_key}.lora_down.weight") * scale_down
+            )
+            new_state_dict[f"{diffusers_key}.lora_B.weight"] = state_dict.pop(f"{orig_key}.lora_up.weight") * scale_up
+
+        if not (i == num_layers - 1):
+            for diffusers_key, orig_key in [
+                (
+                    f"transformer_blocks.{i}.norm1_context.linear",
+                    f"lora_unet_joint_blocks_{i}_context_block_adaLN_modulation_1",
+                )
+            ]:
+                scale_down, scale_up = calculate_scales(orig_key)
+                new_state_dict[f"{diffusers_key}.lora_A.weight"] = (
+                    state_dict.pop(f"{orig_key}.lora_down.weight") * scale_down
+                )
+                new_state_dict[f"{diffusers_key}.lora_B.weight"] = (
+                    state_dict.pop(f"{orig_key}.lora_up.weight") * scale_up
+                )
+        else:
+            # the last context block uses AdaLayerNormContinuous, so the (shift, scale) order must be swapped;
+            # only the up (lora_B) weight carries the output channels, so only it needs the swap.
+            for diffusers_key, orig_key in [
+                (
+                    f"transformer_blocks.{i}.norm1_context.linear",
+                    f"lora_unet_joint_blocks_{i}_context_block_adaLN_modulation_1",
+                )
+            ]:
+                scale_down, scale_up = calculate_scales(orig_key)
+                new_state_dict[f"{diffusers_key}.lora_A.weight"] = (
+                    state_dict.pop(f"{orig_key}.lora_down.weight") * scale_down
+                )
+                new_state_dict[f"{diffusers_key}.lora_B.weight"] = (
+                    swap_scale_shift(state_dict.pop(f"{orig_key}.lora_up.weight")) * scale_up
+                )
+
+        # output projections
+        for diffusers_key, orig_key in [
+            (f"transformer_blocks.{i}.attn.to_out.0", f"lora_unet_joint_blocks_{i}_x_block_attn_proj")
+        ]:
+            scale_down, scale_up = calculate_scales(orig_key)
+            new_state_dict[f"{diffusers_key}.lora_A.weight"] = (
+                state_dict.pop(f"{orig_key}.lora_down.weight") * scale_down
+            )
+            new_state_dict[f"{diffusers_key}.lora_B.weight"] = state_dict.pop(f"{orig_key}.lora_up.weight") * scale_up
+        if not (i == num_layers - 1):
+            for diffusers_key, orig_key in [
+                (f"transformer_blocks.{i}.attn.to_add_out", f"lora_unet_joint_blocks_{i}_context_block_attn_proj")
+            ]:
+                scale_down, scale_up = calculate_scales(orig_key)
+                new_state_dict[f"{diffusers_key}.lora_A.weight"] = (
+                    state_dict.pop(f"{orig_key}.lora_down.weight") * scale_down
+                )
+                new_state_dict[f"{diffusers_key}.lora_B.weight"] = (
+                    state_dict.pop(f"{orig_key}.lora_up.weight") * scale_up
+                )
+
+        # ffs
+        for diffusers_key, orig_key in [
+            (f"transformer_blocks.{i}.ff.net.0.proj", f"lora_unet_joint_blocks_{i}_x_block_mlp_fc1")
+        ]:
+            scale_down, scale_up = calculate_scales(orig_key)
+            new_state_dict[f"{diffusers_key}.lora_A.weight"] = (
+                state_dict.pop(f"{orig_key}.lora_down.weight") * scale_down
+            )
+            new_state_dict[f"{diffusers_key}.lora_B.weight"] = state_dict.pop(f"{orig_key}.lora_up.weight") * scale_up
+
+        for diffusers_key, orig_key in [
+            (f"transformer_blocks.{i}.ff.net.2", f"lora_unet_joint_blocks_{i}_x_block_mlp_fc2")
+        ]:
+            scale_down, scale_up = calculate_scales(orig_key)
+            new_state_dict[f"{diffusers_key}.lora_A.weight"] = (
+                state_dict.pop(f"{orig_key}.lora_down.weight") * scale_down
+            )
+            new_state_dict[f"{diffusers_key}.lora_B.weight"] = state_dict.pop(f"{orig_key}.lora_up.weight") * scale_up
+
+        if not (i == num_layers - 1):
+            for diffusers_key, orig_key in [
+                (f"transformer_blocks.{i}.ff_context.net.0.proj", f"lora_unet_joint_blocks_{i}_context_block_mlp_fc1")
+            ]:
+                scale_down, scale_up = calculate_scales(orig_key)
+                new_state_dict[f"{diffusers_key}.lora_A.weight"] = (
+                    state_dict.pop(f"{orig_key}.lora_down.weight") * scale_down
+                )
+                new_state_dict[f"{diffusers_key}.lora_B.weight"] = (
+                    state_dict.pop(f"{orig_key}.lora_up.weight") * scale_up
+                )
+
+            for diffusers_key, orig_key in [
+                (f"transformer_blocks.{i}.ff_context.net.2", f"lora_unet_joint_blocks_{i}_context_block_mlp_fc2")
+            ]:
+                scale_down, scale_up = calculate_scales(orig_key)
+                new_state_dict[f"{diffusers_key}.lora_A.weight"] = (
+                    state_dict.pop(f"{orig_key}.lora_down.weight") * scale_down
+                )
+                new_state_dict[f"{diffusers_key}.lora_B.weight"] = (
+                    state_dict.pop(f"{orig_key}.lora_up.weight") * scale_up
+                )
+
+        # core transformer blocks.
+        # sample blocks.
+        scale_down, scale_up = calculate_scales(f"lora_unet_joint_blocks_{i}_x_block_attn_qkv")
+        is_sparse, requested_rank = weight_is_sparse(
+            key=f"lora_unet_joint_blocks_{i}_x_block_attn_qkv",
+            rank=state_dict[f"lora_unet_joint_blocks_{i}_x_block_attn_qkv.lora_down.weight"].shape[0],
+            num_splits=3,
+            up_weight=state_dict[f"lora_unet_joint_blocks_{i}_x_block_attn_qkv.lora_up.weight"],
+        )
+        num_splits = 3
+        sample_qkv_lora_down = (
+            state_dict.pop(f"lora_unet_joint_blocks_{i}_x_block_attn_qkv.lora_down.weight") * scale_down
+        )
+        sample_qkv_lora_up = state_dict.pop(f"lora_unet_joint_blocks_{i}_x_block_attn_qkv.lora_up.weight") * scale_up
+        dims = [sample_qkv_lora_up.shape[0] // num_splits] * num_splits  # 3 = num_splits
+        if not is_sparse:
+            for attn_k in ["to_q", "to_k", "to_v"]:
+                new_state_dict[f"transformer_blocks.{i}.attn.{attn_k}.lora_A.weight"] = sample_qkv_lora_down
+            for attn_k, v in zip(["to_q", "to_k", "to_v"], torch.split(sample_qkv_lora_up, dims, dim=0)):
+                new_state_dict[f"transformer_blocks.{i}.attn.{attn_k}.lora_B.weight"] = v
+        else:
+            # down_weight is chunked to each split
+            new_state_dict.update(
+                {
+                    f"transformer_blocks.{i}.attn.{k}.lora_A.weight": v
+                    for k, v in zip(["to_q", "to_k", "to_v"], torch.chunk(sample_qkv_lora_down, num_splits, dim=0))
+                }
+            )  # noqa: C416
+
+            # up_weight is sparse: only non-zero values are copied to each split
+            offset = 0
+            for j, attn_k in enumerate(["to_q", "to_k", "to_v"]):
+                new_state_dict[f"transformer_blocks.{i}.attn.{attn_k}.lora_B.weight"] = sample_qkv_lora_up[
+                    offset : offset + dims[j], j * requested_rank : (j + 1) * requested_rank
+                ].contiguous()
+                offset += dims[j]
+
+        # context blocks.
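+        # the context (encoder hidden states) stream: the fused QKV projection maps onto diffusers'
+        # add_q_proj / add_k_proj / add_v_proj, mirroring the sample-stream handling above.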
+        scale_down, scale_up = calculate_scales(f"lora_unet_joint_blocks_{i}_context_block_attn_qkv")
+        is_sparse, requested_rank = weight_is_sparse(
+            key=f"lora_unet_joint_blocks_{i}_context_block_attn_qkv",
+            rank=state_dict[f"lora_unet_joint_blocks_{i}_context_block_attn_qkv.lora_down.weight"].shape[0],
+            num_splits=3,
+            up_weight=state_dict[f"lora_unet_joint_blocks_{i}_context_block_attn_qkv.lora_up.weight"],
+        )
+        num_splits = 3
+        context_qkv_lora_down = (
+            state_dict.pop(f"lora_unet_joint_blocks_{i}_context_block_attn_qkv.lora_down.weight") * scale_down
+        )
+        context_qkv_lora_up = (
+            state_dict.pop(f"lora_unet_joint_blocks_{i}_context_block_attn_qkv.lora_up.weight") * scale_up
+        )
+        dims = [context_qkv_lora_up.shape[0] // num_splits] * num_splits  # 3 = num_splits
+        if not is_sparse:
+            for attn_k in ["add_q_proj", "add_k_proj", "add_v_proj"]:
+                new_state_dict[f"transformer_blocks.{i}.attn.{attn_k}.lora_A.weight"] = context_qkv_lora_down
+            for attn_k, v in zip(
+                ["add_q_proj", "add_k_proj", "add_v_proj"], torch.split(context_qkv_lora_up, dims, dim=0)
+            ):
+                new_state_dict[f"transformer_blocks.{i}.attn.{attn_k}.lora_B.weight"] = v
+        else:
+            # down_weight is chunked to each split
+            new_state_dict.update(
+                {
+                    f"transformer_blocks.{i}.attn.{k}.lora_A.weight": v
+                    for k, v in zip(
+                        ["add_q_proj", "add_k_proj", "add_v_proj"],
+                        torch.chunk(context_qkv_lora_down, num_splits, dim=0),
+                    )
+                }
+            )  # noqa: C416
+
+            # up_weight is sparse: only non-zero values are copied to each split
+            offset = 0
+            for j, attn_k in enumerate(["add_q_proj", "add_k_proj", "add_v_proj"]):
+                new_state_dict[f"transformer_blocks.{i}.attn.{attn_k}.lora_B.weight"] = context_qkv_lora_up[
+                    offset : offset + dims[j], j * requested_rank : (j + 1) * requested_rank
+                ].contiguous()
+                offset += dims[j]
+
+    if len(state_dict) > 0:
+        raise ValueError(f"`state_dict` should be empty at this point but has: {list(state_dict.keys())}.")
+
+    prefix = prefix or "transformer"
+    new_state_dict = {f"{prefix}.{k}": v for k, v in new_state_dict.items()}
+    return new_state_dict
+
+
 def _convert_bfl_flux_control_lora_to_diffusers(original_state_dict):
     converted_state_dict = {}
     original_state_dict_keys = list(original_state_dict.keys())
diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py
index eb9b42c5fbb7..cf36145847d2 100644
--- a/src/diffusers/loaders/lora_pipeline.py
+++ b/src/diffusers/loaders/lora_pipeline.py
@@ -38,6 +38,7 @@
     _convert_bfl_flux_control_lora_to_diffusers,
     _convert_kohya_flux_lora_to_diffusers,
     _convert_non_diffusers_lora_to_diffusers,
+    _convert_non_diffusers_sd3_lora_to_diffusers,
     _convert_xlabs_flux_lora_to_diffusers,
     _maybe_map_sgm_blocks_to_diffusers,
 )
@@ -1239,6 +1240,27 @@ def lora_state_dict(
             logger.warning(warn_msg)
             state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
 
+        is_non_diffusers = any("lora_unet" in k for k in state_dict)
+        if is_non_diffusers:
+            has_only_transformer = all(k.startswith("lora_unet") for k in state_dict)
+            if not has_only_transformer:
+                state_dict = {k: v for k, v in state_dict.items() if k.startswith("lora_unet")}
+                logger.warning(
+                    "Some keys in the LoRA checkpoint are not related to transformer blocks and will be filtered out during loading. Please open a new issue with the LoRA checkpoint you are trying to load with a reproducible snippet - https://github.com/huggingface/diffusers/issues/new."
+                )
+
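+            # the converter only understands the MMDiT `joint_blocks` layers, so reject anything else
+            # (e.g. SD3.5 dual-attention `attn2` layers) up front rather than silently producing a partial LoRA.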
+            all_joint_blocks = all("joint_blocks" in k for k in state_dict)
+            if not all_joint_blocks:
+                raise ValueError(
+                    "Only LoRAs that target the transformer blocks are supported at this point. Please open a new issue with the LoRA checkpoint you are trying to load with a reproducible snippet - https://github.com/huggingface/diffusers/issues/new."
+                )
+
+            has_dual_attention_layers = any("attn2" in k for k in state_dict)
+            if has_dual_attention_layers:
+                raise ValueError("LoRA state dicts with dual attention layers are not supported.")
+
+            state_dict = _convert_non_diffusers_sd3_lora_to_diffusers(state_dict, prefix=cls.transformer_name)
+
         return state_dict
 
     def load_lora_weights(
@@ -1283,12 +1305,11 @@ def load_lora_weights(
         # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
         state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
 
-        is_correct_format = all("lora" in key for key in state_dict.keys())
-        if not is_correct_format:
-            raise ValueError("Invalid LoRA checkpoint.")
-
-        transformer_state_dict = {k: v for k, v in state_dict.items() if "transformer." in k}
+        transformer_state_dict = {k: v for k, v in state_dict.items() if k.startswith(f"{self.transformer_name}.")}
         if len(transformer_state_dict) > 0:
             self.load_lora_into_transformer(
                 state_dict,
@@ -1299,8 +1320,10 @@ def load_lora_weights(
                 _pipeline=self,
                 low_cpu_mem_usage=low_cpu_mem_usage,
             )
+        else:
+            logger.debug("No LoRA keys were found for the transformer.")
 
-        text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
+        text_encoder_state_dict = {k: v for k, v in state_dict.items() if k.startswith(f"{self.text_encoder_name}.")}
         if len(text_encoder_state_dict) > 0:
             self.load_lora_into_text_encoder(
                 text_encoder_state_dict,
@@ -1312,8 +1335,10 @@ def load_lora_weights(
                 _pipeline=self,
                 low_cpu_mem_usage=low_cpu_mem_usage,
             )
+        else:
+            logger.debug("No LoRA keys were found for the first text encoder.")
 
-        text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k}
+        text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if k.startswith("text_encoder_2.")}
        if len(text_encoder_2_state_dict) > 0:
             self.load_lora_into_text_encoder(
                 text_encoder_2_state_dict,
@@ -1325,6 +1350,8 @@ def load_lora_weights(
                 _pipeline=self,
                 low_cpu_mem_usage=low_cpu_mem_usage,
             )
+        else:
+            logger.debug("No LoRA keys were found for the second text encoder.")
 
     @classmethod
     def load_lora_into_transformer(
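
Note on the alpha handling: `calculate_scales` follows the same pattern as the existing kohya/Flux converter in this file. Instead of folding the whole `alpha / rank` factor into one matrix, it splits the scale across the down (lora_A) and up (lora_B) weights so that neither matrix is multiplied by a disproportionately small number. A standalone sketch of the invariant (the helper name and the sample values below are illustrative only):

    import math

    def split_scale(alpha: float, rank: int):
        # illustrative re-implementation of the scale_down / scale_up split used in calculate_scales
        scale = alpha / rank
        scale_down, scale_up = scale, 1.0
        while scale_down * 2 < scale_up:
            scale_down *= 2
            scale_up /= 2
        return scale_down, scale_up

    for alpha, rank in [(16, 64), (8, 8), (1, 32)]:
        down, up = split_scale(alpha, rank)
        # the product always equals alpha / rank, so the effective LoRA delta is unchanged
        assert math.isclose(down * up, alpha / rank)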
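
End to end, the new branch in `lora_state_dict` means an sd-scripts-style SD3 LoRA (keys prefixed with `lora_unet_joint_blocks_...`) can be loaded through the regular API. A minimal usage sketch; the checkpoint filename and prompt are placeholders:

    import torch
    from diffusers import StableDiffusion3Pipeline

    pipe = StableDiffusion3Pipeline.from_pretrained(
        "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
    ).to("cuda")

    # "sd3_lora_kohya.safetensors" is a placeholder for a non-diffusers (kohya/sd-scripts) SD3 LoRA;
    # lora_state_dict() detects the "lora_unet" prefix and routes the checkpoint through
    # _convert_non_diffusers_sd3_lora_to_diffusers before the usual loading machinery runs.
    pipe.load_lora_weights("sd3_lora_kohya.safetensors")

    image = pipe("a photo of a corgi wearing a tiny wizard hat", num_inference_steps=28).images[0]
    image.save("corgi.png")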