-
Notifications
You must be signed in to change notification settings - Fork 6.4k
[IP-Adapter] Support multiple IP-Adapters #6573
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
7e1bf9b
3b55a1e
345b4d6
f6bae6e
baa7b83
9024698
27fc796
67908bf
afd91e3
7d42455
1fdcd7b
45fb582
1a4c6b1
d924c47
cc2aa1b
2c86534
0049e44
4a3df90
ff96407
193d6e8
f7f2465
efa704a
9abdcf9
5e47ceb
c6670de
711387e
bce309f
fae861e
21da205
accee6b
816578f
e57103f
6cfa34b
1e68c64
2cc1561
475046e
98fa0c2
3a52ecb
e742cf4
dcdde9c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -11,8 +11,9 @@ | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| from pathlib import Path | ||
| from typing import Dict, Union | ||
| from typing import Dict, List, Union | ||
|
|
||
| import torch | ||
| from huggingface_hub.utils import validate_hf_hub_args | ||
|
|
@@ -45,9 +46,9 @@ class IPAdapterMixin: | |
| @validate_hf_hub_args | ||
| def load_ip_adapter( | ||
| self, | ||
| pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], | ||
| subfolder: str, | ||
| weight_name: str, | ||
| pretrained_model_name_or_path_or_dict: Union[str, List[str], Dict[str, torch.Tensor]], | ||
| subfolder: Union[str, List[str]], | ||
| weight_name: Union[str, List[str]], | ||
| **kwargs, | ||
| ): | ||
| """ | ||
|
|
@@ -87,6 +88,26 @@ def load_ip_adapter( | |
| The subfolder location of a model file within a larger model repository on the Hub or locally. | ||
| """ | ||
|
|
||
| # handle the list inputs for multiple IP Adapters | ||
| if not isinstance(weight_name, list): | ||
| weight_name = [weight_name] | ||
|
|
||
| if not isinstance(pretrained_model_name_or_path_or_dict, list): | ||
| pretrained_model_name_or_path_or_dict = [pretrained_model_name_or_path_or_dict] | ||
| if len(pretrained_model_name_or_path_or_dict) == 1: | ||
| pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict * len(weight_name) | ||
|
Comment on lines
+97
to
+98
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Does it cater to the case where you have a single model_id and multiple weight names of different IP Adapters? |
||
|
|
||
| if not isinstance(subfolder, list): | ||
| subfolder = [subfolder] | ||
| if len(subfolder) == 1: | ||
| subfolder = subfolder * len(weight_name) | ||
sayakpaul marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| if len(weight_name) != len(pretrained_model_name_or_path_or_dict): | ||
| raise ValueError("`weight_name` and `pretrained_model_name_or_path_or_dict` must have the same length.") | ||
|
|
||
| if len(weight_name) != len(subfolder): | ||
yiyixuxu marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| raise ValueError("`weight_name` and `subfolder` must have the same length.") | ||
|
|
||
| # Load the main state dict first. | ||
| cache_dir = kwargs.pop("cache_dir", None) | ||
| force_download = kwargs.pop("force_download", False) | ||
|
|
@@ -100,61 +121,68 @@ def load_ip_adapter( | |
| "file_type": "attn_procs_weights", | ||
| "framework": "pytorch", | ||
| } | ||
|
|
||
| if not isinstance(pretrained_model_name_or_path_or_dict, dict): | ||
| model_file = _get_model_file( | ||
| pretrained_model_name_or_path_or_dict, | ||
| weights_name=weight_name, | ||
| cache_dir=cache_dir, | ||
| force_download=force_download, | ||
| resume_download=resume_download, | ||
| proxies=proxies, | ||
| local_files_only=local_files_only, | ||
| token=token, | ||
| revision=revision, | ||
| subfolder=subfolder, | ||
| user_agent=user_agent, | ||
| ) | ||
| if weight_name.endswith(".safetensors"): | ||
| state_dict = {"image_proj": {}, "ip_adapter": {}} | ||
| with safe_open(model_file, framework="pt", device="cpu") as f: | ||
| for key in f.keys(): | ||
| if key.startswith("image_proj."): | ||
| state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key) | ||
| elif key.startswith("ip_adapter."): | ||
| state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key) | ||
| else: | ||
| state_dict = torch.load(model_file, map_location="cpu") | ||
| else: | ||
| state_dict = pretrained_model_name_or_path_or_dict | ||
|
|
||
| keys = list(state_dict.keys()) | ||
| if keys != ["image_proj", "ip_adapter"]: | ||
| raise ValueError("Required keys are (`image_proj` and `ip_adapter`) missing from the state dict.") | ||
|
|
||
| # load CLIP image encoder here if it has not been registered to the pipeline yet | ||
| if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is None: | ||
| state_dicts = [] | ||
| for pretrained_model_name_or_path_or_dict, weight_name, subfolder in zip( | ||
| pretrained_model_name_or_path_or_dict, weight_name, subfolder | ||
| ): | ||
| if not isinstance(pretrained_model_name_or_path_or_dict, dict): | ||
| logger.info(f"loading image_encoder from {pretrained_model_name_or_path_or_dict}") | ||
| image_encoder = CLIPVisionModelWithProjection.from_pretrained( | ||
| model_file = _get_model_file( | ||
| pretrained_model_name_or_path_or_dict, | ||
| subfolder=Path(subfolder, "image_encoder").as_posix(), | ||
| ).to(self.device, dtype=self.dtype) | ||
| self.image_encoder = image_encoder | ||
| self.register_to_config(image_encoder=["transformers", "CLIPVisionModelWithProjection"]) | ||
| weights_name=weight_name, | ||
| cache_dir=cache_dir, | ||
| force_download=force_download, | ||
| resume_download=resume_download, | ||
| proxies=proxies, | ||
| local_files_only=local_files_only, | ||
| token=token, | ||
| revision=revision, | ||
| subfolder=subfolder, | ||
| user_agent=user_agent, | ||
| ) | ||
| if weight_name.endswith(".safetensors"): | ||
| state_dict = {"image_proj": {}, "ip_adapter": {}} | ||
| with safe_open(model_file, framework="pt", device="cpu") as f: | ||
| for key in f.keys(): | ||
| if key.startswith("image_proj."): | ||
| state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key) | ||
| elif key.startswith("ip_adapter."): | ||
| state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key) | ||
yiyixuxu marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| else: | ||
| state_dict = torch.load(model_file, map_location="cpu") | ||
| else: | ||
| raise ValueError("`image_encoder` cannot be None when using IP Adapters.") | ||
|
|
||
| # create feature extractor if it has not been registered to the pipeline yet | ||
| if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is None: | ||
| self.feature_extractor = CLIPImageProcessor() | ||
| self.register_to_config(feature_extractor=["transformers", "CLIPImageProcessor"]) | ||
|
|
||
| # load ip-adapter into unet | ||
| state_dict = pretrained_model_name_or_path_or_dict | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is the case when someone passes the state_dict directly. Shouldn't this statement then be guarded with a check? Something like There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think it is already done so it's part of this if else |
||
|
|
||
| keys = list(state_dict.keys()) | ||
| if keys != ["image_proj", "ip_adapter"]: | ||
| raise ValueError("Required keys are (`image_proj` and `ip_adapter`) missing from the state dict.") | ||
|
Comment on lines
+155
to
+157
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We are relying on the In this case, that could be before There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. moved up:) |
||
|
|
||
| state_dicts.append(state_dict) | ||
|
|
||
| # load CLIP image encoder here if it has not been registered to the pipeline yet | ||
| if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is None: | ||
| if not isinstance(pretrained_model_name_or_path_or_dict, dict): | ||
| logger.info(f"loading image_encoder from {pretrained_model_name_or_path_or_dict}") | ||
| image_encoder = CLIPVisionModelWithProjection.from_pretrained( | ||
| pretrained_model_name_or_path_or_dict, | ||
| subfolder=Path(subfolder, "image_encoder").as_posix(), | ||
| ).to(self.device, dtype=self.dtype) | ||
| self.image_encoder = image_encoder | ||
| self.register_to_config(image_encoder=["transformers", "CLIPVisionModelWithProjection"]) | ||
| else: | ||
| raise ValueError("`image_encoder` cannot be None when using IP Adapters.") | ||
|
|
||
| # create feature extractor if it has not been registered to the pipeline yet | ||
| if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is None: | ||
| feature_extractor = CLIPImageProcessor() | ||
| self.register_modules(feature_extractor=feature_extractor) | ||
|
|
||
| # load ip-adapter into unet | ||
| unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet | ||
| unet._load_ip_adapter_weights(state_dict) | ||
| unet._load_ip_adapter_weights(state_dicts) | ||
|
|
||
| def set_ip_adapter_scale(self, scale): | ||
| if not isinstance(scale, list): | ||
| scale = [scale] | ||
| unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet | ||
| for attn_processor in unet.attn_processors.values(): | ||
| if isinstance(attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)): | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -25,7 +25,12 @@ | |
| from huggingface_hub.utils import validate_hf_hub_args | ||
| from torch import nn | ||
|
|
||
| from ..models.embeddings import ImageProjection, IPAdapterFullImageProjection, IPAdapterPlusImageProjection | ||
| from ..models.embeddings import ( | ||
| ImageProjection, | ||
| IPAdapterFullImageProjection, | ||
| IPAdapterPlusImageProjection, | ||
| MultiIPAdapterImageProjection, | ||
| ) | ||
| from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta | ||
| from ..utils import ( | ||
| USE_PEFT_BACKEND, | ||
|
|
@@ -763,28 +768,14 @@ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict): | |
| image_projection.load_state_dict(updated_state_dict) | ||
| return image_projection | ||
|
|
||
| def _load_ip_adapter_weights(self, state_dict): | ||
| def _convert_ip_adapter_attn_to_diffusers(self, state_dicts): | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 100 percent the right choice! |
||
| from ..models.attention_processor import ( | ||
| AttnProcessor, | ||
| AttnProcessor2_0, | ||
| IPAdapterAttnProcessor, | ||
| IPAdapterAttnProcessor2_0, | ||
| ) | ||
|
Comment on lines
772
to
777
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we can move the imports to the top no? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I was following the same pattern in this file https://github.com/huggingface/diffusers to/blob/87a92f779c5ba9c180aec4b90c38149eb108d888/src/diffusers/loaders/unet.py#L449 I thought it was to avoid circular import or something, but I'm not really sure. I can look into maybe in a separate PR to see if we can move all the |
||
|
|
||
| if "proj.weight" in state_dict["image_proj"]: | ||
| # IP-Adapter | ||
| num_image_text_embeds = 4 | ||
| elif "proj.3.weight" in state_dict["image_proj"]: | ||
| # IP-Adapter Full Face | ||
| num_image_text_embeds = 257 # 256 CLIP tokens + 1 CLS token | ||
| else: | ||
| # IP-Adapter Plus | ||
| num_image_text_embeds = state_dict["image_proj"]["latents"].shape[1] | ||
|
|
||
| # Set encoder_hid_proj after loading ip_adapter weights, | ||
| # because `IPAdapterPlusImageProjection` also has `attn_processors`. | ||
| self.encoder_hid_proj = None | ||
yiyixuxu marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| # set ip-adapter cross-attention processors & load state_dict | ||
| attn_procs = {} | ||
| key_id = 1 | ||
|
|
@@ -798,6 +789,7 @@ def _load_ip_adapter_weights(self, state_dict): | |
| elif name.startswith("down_blocks"): | ||
| block_id = int(name[len("down_blocks.")]) | ||
| hidden_size = self.config.block_out_channels[block_id] | ||
|
|
||
| if cross_attention_dim is None or "motion_modules" in name: | ||
| attn_processor_class = ( | ||
| AttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else AttnProcessor | ||
|
|
@@ -807,6 +799,18 @@ def _load_ip_adapter_weights(self, state_dict): | |
| attn_processor_class = ( | ||
| IPAdapterAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else IPAdapterAttnProcessor | ||
| ) | ||
| num_image_text_embeds = [] | ||
| for state_dict in state_dicts: | ||
| if "proj.weight" in state_dict["image_proj"]: | ||
| # IP-Adapter | ||
| num_image_text_embeds += [4] | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's give "4" a variable. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this is yet to be addressed? |
||
| elif "proj.3.weight" in state_dict["image_proj"]: | ||
| # IP-Adapter Full Face | ||
| num_image_text_embeds += [257] # 256 CLIP tokens + 1 CLS token | ||
| else: | ||
| # IP-Adapter Plus | ||
| num_image_text_embeds += [state_dict["image_proj"]["latents"].shape[1]] | ||
|
|
||
| attn_procs[name] = attn_processor_class( | ||
| hidden_size=hidden_size, | ||
| cross_attention_dim=cross_attention_dim, | ||
|
|
@@ -815,16 +819,31 @@ def _load_ip_adapter_weights(self, state_dict): | |
| ).to(dtype=self.dtype, device=self.device) | ||
|
|
||
| value_dict = {} | ||
| for k, w in attn_procs[name].state_dict().items(): | ||
| value_dict.update({f"{k}": state_dict["ip_adapter"][f"{key_id}.{k}"]}) | ||
yiyixuxu marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| for i, state_dict in enumerate(state_dicts): | ||
| value_dict.update({f"to_k_ip.{i}.weight": state_dict["ip_adapter"][f"{key_id}.to_k_ip.weight"]}) | ||
| value_dict.update({f"to_v_ip.{i}.weight": state_dict["ip_adapter"][f"{key_id}.to_v_ip.weight"]}) | ||
|
|
||
| attn_procs[name].load_state_dict(value_dict) | ||
| key_id += 2 | ||
|
|
||
| return attn_procs | ||
|
|
||
| def _load_ip_adapter_weights(self, state_dicts): | ||
| if not isinstance(state_dicts, list): | ||
| state_dicts = [state_dicts] | ||
|
Comment on lines
+832
to
+833
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is it to ensure BW compatibility? I don't see how There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yes it is |
||
| # Set encoder_hid_proj after loading ip_adapter weights, | ||
| # because `IPAdapterPlusImageProjection` also has `attn_processors`. | ||
| self.encoder_hid_proj = None | ||
|
Comment on lines
+834
to
+836
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sorry I don't understand the comment fully. Could you elaborate? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I can't 😛 I assume it's notes from contributor of Ip adapter plus |
||
|
|
||
| attn_procs = self._convert_ip_adapter_attn_to_diffusers(state_dicts) | ||
| self.set_attn_processor(attn_procs) | ||
|
|
||
| # convert IP-Adapter Image Projection layers to diffusers | ||
| image_projection = self._convert_ip_adapter_image_proj_to_diffusers(state_dict["image_proj"]) | ||
| image_projection_layers = [] | ||
| for state_dict in state_dicts: | ||
| image_projection_layer = self._convert_ip_adapter_image_proj_to_diffusers(state_dict["image_proj"]) | ||
| image_projection_layer.to(device=self.device, dtype=self.dtype) | ||
| image_projection_layers.append(image_projection_layer) | ||
|
Comment on lines
+842
to
+846
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Very nice delegation. First convert to attention procs and then handle the rest of the stuff like projection with a dedicated method. |
||
|
|
||
| self.encoder_hid_proj = image_projection.to(device=self.device, dtype=self.dtype) | ||
| self.encoder_hid_proj = MultiIPAdapterImageProjection(image_projection_layers) | ||
| self.config.encoder_hid_dim_type = "ip_image_proj" | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
so the user doesn't have to remember the config, they can just use
from_configinstead - it uses a DDIM but with the same config as the default scheduler