feat: Support bin IP Adapter files
blessedcoolant authored and hipsterusername committed Jun 20, 2024
1 parent b03073d commit 4213ad4
Showing 2 changed files with 10 additions and 3 deletions.
invokeai/backend/ip_adapter/ip_adapter.py (2 changes: 2 additions, 0 deletions)

@@ -221,6 +221,8 @@ def load_ip_adapter_tensors(ip_adapter_ckpt_path: pathlib.Path, device: str) ->
                 state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = model[key]
             else:
                 raise RuntimeError(f"Encountered unexpected IP Adapter state dict key: '{key}'.")
+    elif ip_adapter_ckpt_path.suffix == ".bin":
+        state_dict = torch.load(ip_adapter_ckpt_path, map_location="cpu")
     else:
         ip_adapter_diffusers_checkpoint_path = ip_adapter_ckpt_path / "ip_adapter.bin"
         state_dict = torch.load(ip_adapter_diffusers_checkpoint_path, map_location="cpu")
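
Note: a minimal sketch of what the new ".bin" branch does, assuming a diffusers-format IP Adapter checkpoint whose top-level keys are "image_proj" and "ip_adapter"; the file path below is a hypothetical example, not part of the commit:

    import pathlib

    import torch

    # Hypothetical location of a single-file, diffusers-format IP Adapter checkpoint.
    ckpt_path = pathlib.Path("models/ip_adapter/ip-adapter_sd15.bin")

    if ckpt_path.suffix == ".bin":
        # torch.load returns the pickled dict unchanged; for diffusers-format IP Adapters
        # this is a nested dict keyed by "image_proj" and "ip_adapter".
        state_dict = torch.load(ckpt_path, map_location="cpu")
        print(sorted(state_dict.keys()))  # expected: ['image_proj', 'ip_adapter']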
invokeai/backend/model_manager/probe.py (11 changes: 8 additions, 3 deletions)

@@ -231,7 +231,7 @@ def get_model_type_from_checkpoint(cls, model_path: Path, checkpoint: Optional[C
                 return ModelType.LoRA
             elif any(key.startswith(v) for v in {"controlnet", "control_model", "input_blocks"}):
                 return ModelType.ControlNet
-            elif any(key.startswith(v) for v in {"image_proj.", "ip_adapter."}):
+            elif any(key.startswith(v) for v in {"image_proj", "ip_adapter"}):
                 return ModelType.IPAdapter
             elif key in {"emb_params", "string_to_param"}:
                 return ModelType.TextualInversion
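
Note: the relaxed prefixes matter because a single-file .bin checkpoint exposes the top-level keys "image_proj" and "ip_adapter" with no trailing dot, while a .safetensors checkpoint uses flattened keys such as "ip_adapter.1.to_k_ip.weight". A small illustrative check (the sample key lists are assumptions for illustration, not taken from the commit):

    # Flattened keys, as found in a .safetensors IP Adapter checkpoint.
    safetensors_keys = ["image_proj.proj.weight", "ip_adapter.1.to_k_ip.weight"]
    # Bare top-level keys, as found in a diffusers-format .bin checkpoint.
    bin_keys = ["image_proj", "ip_adapter"]

    for key in safetensors_keys + bin_keys:
        old_match = any(key.startswith(v) for v in {"image_proj.", "ip_adapter."})
        new_match = any(key.startswith(v) for v in {"image_proj", "ip_adapter"})
        print(f"{key!r}: old={old_match}, new={new_match}")
    # Only the dot-free prefixes match the bare "image_proj"/"ip_adapter" keys,
    # so .bin checkpoints are now classified as ModelType.IPAdapter.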
@@ -542,9 +542,14 @@ class IPAdapterCheckpointProbe(CheckpointProbeBase):
     def get_base_type(self) -> BaseModelType:
         checkpoint = self.checkpoint
         for key in checkpoint.keys():
-            if not key.startswith(("image_proj.", "ip_adapter.")):
+            if not key.startswith(("image_proj", "ip_adapter")):
                 continue
-            cross_attention_dim = checkpoint["ip_adapter.1.to_k_ip.weight"].shape[-1]
+
+            if key in ["image_proj", "ip_adapter"]:
+                cross_attention_dim = checkpoint["ip_adapter"]["1.to_k_ip.weight"].shape[-1]
+            else:
+                cross_attention_dim = checkpoint["ip_adapter.1.to_k_ip.weight"].shape[-1]
+
             if cross_attention_dim == 768:
                 return BaseModelType.StableDiffusion1
             elif cross_attention_dim == 1024:
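
Note: get_base_type() now reads the cross-attention width from either layout before mapping it to a base model (768 maps to BaseModelType.StableDiffusion1 in the hunk above; the remaining branches are truncated in this view). A rough sketch of the dual lookup, using zero tensors as stand-in checkpoint data rather than real weights:

    import torch

    # Nested layout, as produced by torch.load on a diffusers-format .bin checkpoint.
    bin_checkpoint = {"ip_adapter": {"1.to_k_ip.weight": torch.zeros(320, 768)}}
    # Flattened layout, as produced by loading a .safetensors checkpoint.
    safetensors_checkpoint = {"ip_adapter.1.to_k_ip.weight": torch.zeros(320, 768)}

    def cross_attention_dim(checkpoint: dict) -> int:
        # Mirror the probe: nested dict for .bin files, flattened keys otherwise.
        if "ip_adapter" in checkpoint and isinstance(checkpoint["ip_adapter"], dict):
            return checkpoint["ip_adapter"]["1.to_k_ip.weight"].shape[-1]
        return checkpoint["ip_adapter.1.to_k_ip.weight"].shape[-1]

    print(cross_attention_dim(bin_checkpoint))          # 768, i.e. StableDiffusion1
    print(cross_attention_dim(safetensors_checkpoint))  # 768, i.e. StableDiffusion1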
