From 4213ad4ad7a4d0af125b35f78ca6d7d96071d375 Mon Sep 17 00:00:00 2001 From: blessedcoolant <54517381+blessedcoolant@users.noreply.github.com> Date: Wed, 29 May 2024 01:12:30 +0530 Subject: [PATCH] feat: Support bin IP Adapter files --- invokeai/backend/ip_adapter/ip_adapter.py | 2 ++ invokeai/backend/model_manager/probe.py | 11 ++++++++--- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/invokeai/backend/ip_adapter/ip_adapter.py b/invokeai/backend/ip_adapter/ip_adapter.py index c33cb3f4ab4..a649f336fdd 100644 --- a/invokeai/backend/ip_adapter/ip_adapter.py +++ b/invokeai/backend/ip_adapter/ip_adapter.py @@ -221,6 +221,8 @@ def load_ip_adapter_tensors(ip_adapter_ckpt_path: pathlib.Path, device: str) -> state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = model[key] else: raise RuntimeError(f"Encountered unexpected IP Adapter state dict key: '{key}'.") + elif ip_adapter_ckpt_path.suffix == ".bin": + state_dict = torch.load(ip_adapter_ckpt_path, map_location="cpu") else: ip_adapter_diffusers_checkpoint_path = ip_adapter_ckpt_path / "ip_adapter.bin" state_dict = torch.load(ip_adapter_diffusers_checkpoint_path, map_location="cpu") diff --git a/invokeai/backend/model_manager/probe.py b/invokeai/backend/model_manager/probe.py index a19a7727642..dec11687e2e 100644 --- a/invokeai/backend/model_manager/probe.py +++ b/invokeai/backend/model_manager/probe.py @@ -231,7 +231,7 @@ def get_model_type_from_checkpoint(cls, model_path: Path, checkpoint: Optional[C return ModelType.LoRA elif any(key.startswith(v) for v in {"controlnet", "control_model", "input_blocks"}): return ModelType.ControlNet - elif any(key.startswith(v) for v in {"image_proj.", "ip_adapter."}): + elif any(key.startswith(v) for v in {"image_proj", "ip_adapter"}): return ModelType.IPAdapter elif key in {"emb_params", "string_to_param"}: return ModelType.TextualInversion @@ -542,9 +542,14 @@ class IPAdapterCheckpointProbe(CheckpointProbeBase): def get_base_type(self) -> BaseModelType: checkpoint = self.checkpoint for key in checkpoint.keys(): - if not key.startswith(("image_proj.", "ip_adapter.")): + if not key.startswith(("image_proj", "ip_adapter")): continue - cross_attention_dim = checkpoint["ip_adapter.1.to_k_ip.weight"].shape[-1] + + if key in ["image_proj", "ip_adapter"]: + cross_attention_dim = checkpoint["ip_adapter"]["1.to_k_ip.weight"].shape[-1] + else: + cross_attention_dim = checkpoint["ip_adapter.1.to_k_ip.weight"].shape[-1] + if cross_attention_dim == 768: return BaseModelType.StableDiffusion1 elif cross_attention_dim == 1024: