From 5b7b7a06e6f81f61f36d1f976d189bd5e60da9b8 Mon Sep 17 00:00:00 2001 From: Alexander Kovalchuk Date: Tue, 19 Dec 2023 15:44:51 +0300 Subject: [PATCH] Fixed several errors in StableDiffusion adapter conversion script --- .../stable_diffusion/convert_sd_adapter_to_peft.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/examples/stable_diffusion/convert_sd_adapter_to_peft.py b/examples/stable_diffusion/convert_sd_adapter_to_peft.py index 3150d9e748..e9ca9d14d1 100644 --- a/examples/stable_diffusion/convert_sd_adapter_to_peft.py +++ b/examples/stable_diffusion/convert_sd_adapter_to_peft.py @@ -39,7 +39,7 @@ def peft_state_dict(self) -> Dict[str, torch.Tensor]: if self.lora_A is None or self.lora_B is None: raise ValueError("At least one of lora_A or lora_B is None, they must both be provided") return { - f"base_model.model{self.peft_key}.lora_A.weight": self.lora_A, + f"base_model.model.{self.peft_key}.lora_A.weight": self.lora_A, f"base_model.model.{self.peft_key}.lora_B.weight": self.lora_B, } @@ -483,6 +483,10 @@ def detect_adapter_type(keys: List[str]) -> PeftType: # Process each model sequentially for model, model_name in [(text_encoder, "text_encoder"), (unet, "unet")]: + # Skip model if no data was provided + if len(adapter_info[model_name]) == 0: + continue + config = construct_config_fn(adapter_info[model_name], decompose_factor=decompose_factor) # Output warning for LoHa with use_effective_conv2d @@ -497,7 +501,11 @@ def detect_adapter_type(keys: List[str]) -> PeftType: ) model = get_peft_model(model, config) - set_peft_model_state_dict(model, combine_peft_state_dict(adapter_info[model_name])) + missing_keys, unexpected_keys = set_peft_model_state_dict( + model, combine_peft_state_dict(adapter_info[model_name]) + ) + if len(unexpected_keys) > 0: + raise ValueError(f"Unexpected keys {unexpected_keys} found during conversion") if args.half: model.to(torch.float16)